In [16]:
import pandas as pd

import torch
import torch.nn as nn

import timm
import torchvision
import torchvision.transforms as T

from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

from PIL import Image

from tqdm.notebook import tqdm

In [17]:
# Root folder holding the competition CSVs (and the images/ sub-folder).
DATA_PATH = "../data/"

In [18]:
# Load the four metadata tables, in the same order as before.
reference, query, training, gt = (
    pd.read_csv(DATA_PATH + csv_name)
    for csv_name in (
        'reference_images_metadata.csv',
        'query_images_metadata.csv',
        'training_images_metadata.csv',
        'public_ground_truth.csv',
    )
)

for frame in (reference, query, training, gt):
    print(frame.shape)

query.head()

(1000000, 2)
(50000, 2)
(1000000, 2)
(25000, 2)


Unnamed: 0,image_id,md5_checksum
0,Q00000,de21a560619005c56dcbd3a7e6c00fd9
1,Q00001,7a68c7f40674a463d14d74c8f8033cc7
2,Q00002,2005093a0ca9b1a33194561b219a0c49
3,Q00003,9b4f2a7cf20d4256b6d46dbba49dd86d
4,Q00004,9038b363055284ba8882943d707a4d06


In [19]:
import random
import augly.image as imaugs

# Range for the colour-jitter factors. The factors are drawn fresh every time
# the augmentation pipeline runs (see `random_color_jitter`). Sampling them once
# at module level would give every augmented image identical
# brightness/contrast/saturation changes.
COLOR_JITTER_RANGE = (0.5, 1.5)


def random_color_jitter(image):
    """Apply augly colour jitter with factors drawn uniformly from COLOR_JITTER_RANGE.

    Wrapped in ``imaugs.ApplyLambda`` below so new factors are sampled on every call.
    """
    low, high = COLOR_JITTER_RANGE
    return imaugs.color_jitter(
        image,
        brightness_factor=random.uniform(low, high),
        contrast_factor=random.uniform(low, high),
        saturation_factor=random.uniform(low, high),
    )


# augly pipeline used to build positives/negatives:
# one of blur/pixelization -> random colour jitter -> one overlay.
AUGMENTATIONS = [
    imaugs.OneOf([imaugs.Blur(), imaugs.Pixelization()]),
    imaugs.ApplyLambda(aug_function=random_color_jitter),
    imaugs.OneOf(
        [imaugs.OverlayOntoScreenshot(), imaugs.OverlayEmoji(), imaugs.OverlayText()]
    ),
]

# Model preprocessing: ImageNet-style resize/crop/normalize for a 224px input.
NN_TRANSFORMS = T.Compose([
    T.Resize(256, interpolation=T.InterpolationMode.BICUBIC),  # same as the deprecated int 3
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
])



In [20]:
from torch.utils.data import Dataset

# Each item expands one anchor image into `num_augmentations` triplets:
#   anchors   - the clean anchor, repeated
#   positives - independently augmented copies of the anchor
#   negatives - augmented copies of a different, randomly chosen image
# `targets` is [-1, 1] repeated num_augmentations times (1 similar, -1 dissimilar).
class SimDataset(Dataset):
    """Triplet dataset generated on the fly from augmented images.

    Args:
        images: list of image ids; each is loaded from ``image_dir + id + '.jpg'``.
        AUGMENTATIONS: list of augly transforms, composed to create positives/negatives.
        NN_TRANSFORMS: callable turning a PIL image into the model input tensor.
        image_dir: directory prefix for image files (defaults to the original path).
    """

    def __init__(self, images, AUGMENTATIONS, NN_TRANSFORMS, image_dir='../data/images/'):
        self.img_list = images
        self.transform = NN_TRANSFORMS
        self.augmentations = AUGMENTATIONS
        self.image_dir = image_dir
        self.num_augmentations = 5

    def __len__(self):
        return len(self.img_list)

    def _load_rgb(self, image_id):
        """Open one image by id and convert it to 3-channel RGB."""
        return Image.open(self.image_dir + image_id + '.jpg').convert('RGB')

    def _sample_negative_index(self, idx):
        """Return a uniformly random index in [0, len) that differs from ``idx``.

        Raises:
            ValueError: if there are fewer than two images, so no distinct negative
                exists. The previous rejection loop would spin forever in that case.
        """
        n = len(self.img_list)
        if n < 2:
            raise ValueError('SimDataset needs at least two images to sample negatives')
        # Draw from the n-1 other indices and skip over idx: uniform, no retry loop.
        neg_idx = random.randrange(n - 1)
        return neg_idx + 1 if neg_idx >= idx else neg_idx

    def _augmented_views(self, image):
        """Return ``num_augmentations`` independently augmented, transformed views of ``image``."""
        pipeline = imaugs.Compose(self.augmentations)
        return [
            self.transform(pipeline(image).convert('RGB'))
            for _ in range(self.num_augmentations)
        ]

    def __getitem__(self, idx):
        anchor = self._load_rgb(self.img_list[idx])
        # The clean anchor is the same for every triplet, so transform it once.
        anchors = [self.transform(anchor)] * self.num_augmentations
        positives = self._augmented_views(anchor)

        # Negatives must come from a *different* image. Previously img_list[idx]
        # was looked up and the anchor itself was augmented, so the "negatives"
        # were really positives.
        negative = self._load_rgb(self.img_list[self._sample_negative_index(idx)])
        negatives = self._augmented_views(negative)

        return {
            'anchors': torch.stack(anchors),
            'positives': torch.stack(positives),
            'negatives': torch.stack(negatives),
            'targets': torch.Tensor([-1, 1]).repeat(self.num_augmentations),
        }

class IterDataset(Dataset):
    """Plain image dataset used to extract embeddings during validation.

    Each item is a dict with the image ``id`` and its ``anchor`` tensor. The
    tensor carries an extra leading dimension of size 1.
    """

    def __init__(self, images, NN_TRANSFORMS):
        self.img_list = images
        self.transform = NN_TRANSFORMS

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        image_id = self.img_list[idx]
        path = '../data/images/' + image_id + '.jpg'
        tensor = self.transform(Image.open(path).convert('RGB'))
        # `[None]` inserts the leading singleton dimension (same as unsqueeze(0)).
        return {'id': image_id, 'anchor': tensor[None]}

In [21]:
# Small subsets for quick experiments: the first 1000 training ids (T000000...),
# plus the first 1000 query (Q00000...) and reference (R000000...) ids for validation.
# NOTE(review): ground-truth references for these queries may fall outside
# R000000-R000999 — confirm the validation subset actually contains the matches.
train_images = [f"T{i:06d}" for i in range(1000)]
val_q_images = [f"Q{i:05d}" for i in range(1000)]
val_r_images = [f"R{i:06d}" for i in range(1000)]

train_dataset = SimDataset(train_images, AUGMENTATIONS, NN_TRANSFORMS)
val_q_dataset = IterDataset(val_q_images, NN_TRANSFORMS)
val_r_dataset = IterDataset(val_r_images, NN_TRANSFORMS)

In [22]:
# Sanity check: one training item stacks num_augmentations anchor tensors.
item = train_dataset[0]
item['anchors'].shape

torch.Size([5, 3, 224, 224])

In [23]:
# Sanity check: a validation item is a single image with a leading singleton dim.
item = val_q_dataset[0]
item['anchor'].shape

torch.Size([1, 3, 224, 224])

In [24]:
from torch.utils.data import DataLoader

NW = 3  # worker processes per DataLoader
BS = 8  # items per batch (each training item holds several augmented triplets)

train_dataloader = DataLoader(train_dataset, batch_size=BS, shuffle=True, num_workers=NW)

# Validation loaders keep dataset order so ids and embeddings stay aligned.
valid_q_dataloader, valid_r_dataloader = (
    DataLoader(split, batch_size=BS, shuffle=False, num_workers=NW)
    for split in (val_q_dataset, val_r_dataset)
)

In [25]:
# Pretrained DeiT-Tiny from torch.hub. The classification head is replaced by
# Identity, so forward() returns the features that previously fed the head.
model = torch.hub.load('facebookresearch/deit:main', 'deit_tiny_patch16_224', pretrained=True)
model.eval()
model.head = nn.Identity()

Using cache found in /home/victor/.cache/torch/hub/facebookresearch_deit_main


In [26]:
# Prefer the GPU when one is visible to PyTorch.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [33]:
import sys
sys.path.append("..")
from src.utils.losses import TripletLoss

# Plain SGD (no momentum) over every model parameter, with the project's TripletLoss.
optimizer = torch.optim.SGD(model.parameters(), lr=2e-4)
criterion = TripletLoss()

# Run validation every `val_steps` training batches.
val_steps = 50

In [34]:
import neptune.new as neptune
# Neptune run used for metric logging. It stays open until run.stop() is
# called or the kernel terminates.
run = neptune.init(
    project='victorcallejas/FBSim',
)

https://app.neptune.ai/victorcallejas/FBSim/e/FBSIM-31
Remember to stop your run once you’ve finished logging your metadata (https://docs.neptune.ai/api-reference/run#stop). It will be stopped automatically only when the notebook kernel/interactive console is terminated.


In [35]:
 # Validation
import h5py
import numpy as np

from src.eval_metrics_script.eval_metrics import get_matching_from_descs, evaluate_metrics

def _embed_loader(model, dataloader):
    """Run ``model`` over an IterDataset-style loader; return (ids, float32 embedding matrix)."""
    ids, embeddings = [], []
    with torch.no_grad():
        for batch in tqdm(dataloader, total=len(dataloader)):
            # Each item's 'anchor' has a leading singleton dim; merge it into the batch dim.
            anchors = batch['anchor'].to(device).flatten(start_dim=0, end_dim=1)
            ids.extend(batch['id'])
            embeddings.extend(model(anchors).cpu().numpy())
    return ids, np.asarray(embeddings, dtype=np.float32)


def valid(model, valid_q_dataloader, valid_r_dataloader):
    """Embed validation queries/references, score the matching, and log to Neptune.

    Uses the notebook-level ``device``, ``gt`` and ``run`` objects. The model's
    train/eval mode is restored on exit, so calling this in the middle of a
    training epoch does not leave the model in eval mode.

    Returns:
        ``(ap, rp90)`` as returned by ``evaluate_metrics``. Callers that ignore
        the return value are unaffected.
    """
    was_training = model.training
    model.eval()
    try:
        qry_ids, M_query = _embed_loader(model, valid_q_dataloader)
        ref_ids, M_ref = _embed_loader(model, valid_r_dataloader)
    finally:
        model.train(was_training)

    submission_df = get_matching_from_descs(M_query, M_ref, qry_ids, ref_ids, gt)
    ap, rp90 = evaluate_metrics(submission_df, gt)

    run["dev/ap"].log(ap)
    run["dev/rp90"].log(rp90)
    return ap, rp90


In [36]:
model = model.to(device)

for epoch in range(1, 1000):
    print('EPOCH - ', epoch)

    total_train_loss = 0.0
    model.train()
    optimizer.zero_grad()

    for step, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):

        # Periodic validation. valid() puts the model in eval mode, so switch
        # back to train mode before continuing the epoch.
        if (step + 1) % val_steps == 0:
            valid(model, valid_q_dataloader, valid_r_dataloader)
            model.train()

        # Each item carries several triplets; merge them into the batch dimension.
        anchors = batch['anchors'].to(device).flatten(start_dim=0, end_dim=1)
        positives = batch['positives'].to(device).flatten(start_dim=0, end_dim=1)
        negatives = batch['negatives'].to(device).flatten(start_dim=0, end_dim=1)

        b_anchor_emb = model(anchors)
        b_pos_emb = model(positives)
        b_neg_emb = model(negatives)

        loss = criterion(b_anchor_emb, b_pos_emb, b_neg_emb)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Log a plain Python float rather than the graph-attached tensor.
        loss_value = loss.item()
        run["train/batch_loss"].log(loss_value)
        total_train_loss += loss_value

    run["train/epoch_loss"].log(total_train_loss / len(train_dataloader))

    # TODO: checkpointing is disabled because no dev loss is computed yet.
    # Reinstate torch.save of model/optimizer state (e.g. to
    # '../artifacts/tmp/<epoch>.ckpt') keyed on a validation metric.


EPOCH -  1


  0%|          | 0/125 [00:00<?, ?it/s]

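## GENERATE SUBMISSION

Embed all 50,000 query and 1,000,000 reference images with the trained model, then write the descriptors and their image ids to an HDF5 submission file.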

In [None]:
import h5py
import numpy as np

# Submission ids cover the full query (Q00000-Q49999) and reference
# (R000000-R999999) sets, in index order.
qry_ids = [f"Q{n:05d}" for n in range(50_000)]
ref_ids = [f"R{n:06d}" for n in range(1_000_000)]

In [None]:
from torch.utils.data import Dataset, DataLoader

BATCH_SIZE = 16


class SubmDataset(Dataset):
    """Dataset that loads submission images by id and applies ``transform``.

    Args:
        image_names: image ids; each is loaded from ``image_path + id + image_ext``.
        transform: callable applied to the RGB PIL image (e.g. NN_TRANSFORMS).
        image_path: directory prefix (default ``'../data/images/'``).
        image_ext: file extension (default ``'.jpg'``).
    """

    def __init__(self, image_names, transform, image_path='../data/images/', image_ext='.jpg'):
        self.image_names = image_names
        self.transform = transform
        self.image_path = image_path
        self.image_ext = image_ext

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        img_path = self.image_path + self.image_names[idx] + self.image_ext
        # Convert to RGB, as the training/validation datasets do, so that any
        # grayscale or CMYK file still yields a 3-channel tensor for Normalize.
        return self.transform(Image.open(img_path).convert('RGB'))

In [None]:
# Use the same preprocessing as validation. `transform` was never defined here.
query_dataset = SubmDataset(qry_ids, NN_TRANSFORMS)
reference_dataset = SubmDataset(ref_ids, NN_TRANSFORMS)

# shuffle=False iterates each dataset in its own order, so descriptor rows line
# up with qry_ids / ref_ids. The reference loader previously used a
# SequentialSampler over query_dataset, which visited only the first 50k references.
query_dataloader = DataLoader(query_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
reference_dataloader = DataLoader(reference_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

In [None]:
 def getDescriptors(model,dataloader,device):

    model.to(device)

    model.eval()

    features = []

    for step, batch in tqdm(enumerate(dataloader),total=len(dataloader)):
        batch = batch.to(device)
        with torch.cuda.amp.autocast(enabled=False):
            with torch.no_grad(): 
                b_logits = model(batch)

        features.extend(b_logits)

    return features.to_numpy()

In [None]:
# Fall back to CPU when no GPU is available, instead of failing on the first .to(device).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Queries first, then references, each as a (num_images, dim) descriptor matrix.
query_feats, reference_feats = (
    getDescriptors(model, loader, device)
    for loader in (query_dataloader, reference_dataloader)
)

In [None]:
out = "../submissions/fb-isc-submission.h5"

# Refuse to write a submission whose descriptor rows don't line up with the ids
# (e.g. a loader that iterated the wrong dataset).
for feats, ids, label in ((query_feats, qry_ids, 'query'), (reference_feats, ref_ids, 'reference')):
    if len(feats) != len(ids):
        raise ValueError(f'{label}: {len(feats)} descriptors for {len(ids)} ids')

with h5py.File(out, "w") as f:
    f.create_dataset("query", data=np.asarray(query_feats, dtype=np.float32))
    f.create_dataset("reference", data=np.asarray(reference_feats, dtype=np.float32))
    f.create_dataset('query_ids', data=qry_ids)
    f.create_dataset('reference_ids', data=ref_ids)