In [1]:
import pandas as pd

import torch
import torch.nn as nn

import torchvision.transforms as T

from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

from PIL import Image

from tqdm.notebook import tqdm

In [2]:
# Root folder of the challenge data. The trailing slash matters: file paths are
# built from this value by plain string concatenation.
DATA_PATH = "../data/"

In [3]:
def _read_metadata(filename):
    """Read one metadata CSV stored directly under DATA_PATH."""
    return pd.read_csv(DATA_PATH + filename)


# One table per image split, plus the public ground truth.
reference = _read_metadata('reference_images_metadata.csv')
query = _read_metadata('query_images_metadata.csv')
training = _read_metadata('training_images_metadata.csv')
gt = _read_metadata('public_ground_truth.csv')

# Print the (rows, columns) shape of every table, in load order.
for _frame in (reference, query, training, gt):
    print(_frame.shape)

# Cell output: preview of the query metadata.
query.head()

(1000000, 2)
(50000, 2)
(1000000, 2)
(25000, 2)


Unnamed: 0,image_id,md5_checksum
0,Q00000,de21a560619005c56dcbd3a7e6c00fd9
1,Q00001,7a68c7f40674a463d14d74c8f8033cc7
2,Q00002,2005093a0ca9b1a33194561b219a0c49
3,Q00003,9b4f2a7cf20d4256b6d46dbba49dd86d
4,Q00004,9038b363055284ba8882943d707a4d06


In [4]:
import random
import augly.image as imaugs

# Preprocessing applied to PIL images before they reach the network:
# resize, centre-crop to 224x224, convert to a tensor and normalise with the
# ImageNet mean/std constants imported from timm.
NN_TRANSFORMS = T.Compose([
    # Bicubic resampling, spelled with the named enum instead of the bare
    # integer code 3 (the legacy PIL constant for the same filter).
    T.Resize(256, interpolation=T.InterpolationMode.BICUBIC),
    T.CenterCrop(224),
    # ToTensor only accepts PIL images / ndarrays, so this pipeline is meant
    # for PIL input, not for already-decoded tensors.
    T.ToTensor(),
    T.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
])



In [36]:
from torch.utils.data import Dataset
import torchvision

# Every item is one group of images: anchor, positive, then the negatives,
# with one label per image.
# Label convention: -1 = anchor, 1 = positive, 0 = negative.
class SOLARDataset(Dataset):
    """Dataset of (anchor, augmented positive, random negatives) image groups.

    Args:
        images: sequence of image ids; the file of an id is
            ``image_dir + id + '.jpg'``.
        NN_TRANSFORMS: callable applied to every RGB PIL image to turn it
            into a tensor.
        image_dir: directory prefix of the image files (must end with a path
            separator because paths are built by string concatenation).

    Each item is a dict with:
        'x': tensor stacking anchor, positive and ``num_negatives`` negatives
            along a new leading dimension.
        'labels': 1-D tensor ``[-1, 1, 0, ..., 0]`` matching that order.
    """

    def __init__(self, images, NN_TRANSFORMS, image_dir='../data/images/'):
        self.img_list = images
        self.image_dir = image_dir

        self.transform = NN_TRANSFORMS

        self.num_augmentations = 1 # 1 for SOLAR
        self.num_negatives = 4 # look default in SOLAR

    def __len__(self):
        return len(self.img_list)

    def _image_path(self, image_id):
        """Return the path of the JPEG file that belongs to ``image_id``."""
        return self.image_dir + image_id + '.jpg'

    def _sample_negative_indices(self, idx):
        """Draw ``num_negatives`` random dataset indices, none equal to ``idx``.

        Raises:
            ValueError: if negatives are requested but the dataset holds fewer
                than two images (no index other than ``idx`` exists, so the
                rejection loop below could never terminate).
        """
        size = len(self)
        if self.num_negatives > 0 and size < 2:
            raise ValueError(
                'SOLARDataset needs at least 2 images to sample negatives, got '
                + str(size)
            )

        picked = []
        for _ in range(self.num_negatives):
            neg_idx = idx
            # Rejection sampling: redraw until the index differs from the anchor.
            while neg_idx == idx:
                neg_idx = random.randint(0, size - 1)
            picked.append(neg_idx)
        return picked

    @staticmethod
    def _build_positive_augmentation():
        """Build the augly pipeline used to derive the positive from the anchor.

        The colour-jitter factors are re-drawn on every call, so each sample
        gets its own jitter strength.
        """
        color_jitter_params = {
            "brightness_factor": random.uniform(0.3, 1.7),
            "contrast_factor": random.uniform(0.3, 1.7),
            "saturation_factor": random.uniform(0.3, 1.7),
        }
        return imaugs.Compose([
            imaugs.OneOf(
                [imaugs.Blur(), imaugs.Pixelization()]
            ),
            imaugs.ColorJitter(**color_jitter_params),
            imaugs.OneOf(
                [imaugs.OverlayOntoScreenshot(), imaugs.OverlayEmoji(), imaugs.OverlayText()]
            ),
        ])

    def __getitem__(self, idx):
        id_ = self.img_list[idx]

        augment = self._build_positive_augmentation()

        # The context manager closes the file handle once both views are built.
        with Image.open(self._image_path(id_)) as anchor:
            # Anchor: the untouched image.
            query = self.transform(anchor.convert('RGB'))
            # Positive: an augmented copy of the same image.
            positive = self.transform(augment(anchor).convert('RGB'))

        # Negatives: other images of the dataset. The lookup must use the
        # sampled index, otherwise the "negatives" are the anchor itself.
        negatives = []
        for neg_idx in self._sample_negative_indices(idx):
            with Image.open(self._image_path(self.img_list[neg_idx])) as neg:
                negatives.append(self.transform(neg.convert('RGB')))

        # Compose: anchor, positive, negatives along a new leading dimension.
        x = torch.stack([query, positive] + negatives, dim=0)
        labels = [-1, 1] + ([0] * self.num_negatives)
        return {
            'x': x,
            'labels': torch.tensor(labels)
        }
        
class IterDataset(Dataset):
    """Dataset of raw RGB images addressed by image id.

    Args:
        images: sequence of image ids; the file of an id is
            ``image_dir + id + '.jpg'``.
        preload: when True (the default, and the historical behaviour) every
            image is decoded in the constructor and kept in ``self.images``.
            When False nothing is read up front and each image is decoded on
            access, which keeps memory flat for large splits.
        image_dir: directory prefix of the image files (must end with a path
            separator because paths are built by string concatenation).

    Each item is a dict with:
        'id': the image id.
        'x': the decoded image with an extra leading dimension of size 1.
    """

    def __init__(self, images, preload=True, image_dir='../data/images/'):
        self.id_list = images
        self.image_dir = image_dir
        self.preload = preload
        self.images = []
        if preload:
            for id_ in tqdm(self.id_list, total=len(self.id_list)):
                self.images.append(self._read(id_))

    def _read(self, id_):
        """Decode one JPEG as RGB and add a leading singleton dimension."""
        return torchvision.io.read_image(
            self.image_dir + id_ + '.jpg',
            mode=torchvision.io.ImageReadMode.RGB,
        ).unsqueeze(0)

    def __len__(self):
        return len(self.id_list)

    def __getitem__(self, idx):
        id_ = self.id_list[idx]
        image = self.images[idx] if self.preload else self._read(id_)
        # Compose
        return {
            'id': id_,
            'x': image
        }

In [37]:
# Image ids of each split, read from the `image_id` column of its metadata table.
train_images = training['image_id'].tolist()
val_q_images = query['image_id'].tolist()
val_r_images = reference['image_id'].tolist()
# The string literal below is an inert, disabled alternative that builds small
# id lists by hand; it is evaluated and discarded.
"""
train_images = []

val_q_images = []
val_r_images = []
for i in range(0,1000):
    train_images.append('T'+f"{i:06d}")


for i in range(0,1000):
    val_q_images.append('Q'+f"{i:05d}")
    val_r_images.append('R'+f"{i:06d}")
"""

# Training groups for the SOLAR loss.
train_dataset = SOLARDataset(train_images, NN_TRANSFORMS)

# Validation datasets: query images first, then reference images.
val_q_dataset = IterDataset(val_q_images)
val_r_dataset = IterDataset(val_r_images)

  0%|          | 0/50000 [00:00<?, ?it/s]

In [27]:
# Sanity check: fetch the first training group and print both tensor shapes.
item = train_dataset[0]
for _key in ('x', 'labels'):
    print(item[_key].shape)

torch.Size([6, 3, 224, 224])
torch.Size([6])


In [28]:
# Sanity check: shape of the first validation query image.
item = val_q_dataset[0]
print(item['x'].shape)

torch.Size([1, 3, 768, 1024])


In [34]:
from torch.utils.data import DataLoader

NW = 8     # worker processes per DataLoader
T_BS = 32  # training batch size (number of image groups per batch)
V_BS = 32  # validation batch size

# Settings shared by the three loaders (no custom collate_fn is used).
_LOADER_KWARGS = dict(
    num_workers=NW,
    pin_memory=True,
    #collate_fn=collate_fn,
)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=T_BS,
    shuffle=True,
    **_LOADER_KWARGS,
)

valid_q_dataloader = DataLoader(
    val_q_dataset,
    batch_size=V_BS,
    shuffle=False,
    **_LOADER_KWARGS,
)

# The reference loader is a validation loader as well, so it takes the
# validation batch size (it previously reused the training one).
valid_r_dataloader = DataLoader(
    val_r_dataset,
    batch_size=V_BS,
    shuffle=False,
    **_LOADER_KWARGS,
)



In [10]:
# Pretrained DeiT-tiny fetched through torch.hub.
model = torch.hub.load(
    'facebookresearch/deit:main',
    'deit_tiny_patch16_224',
    pretrained=True,
)
model.eval()
# Swap the `head` attribute for a pass-through module, so the model returns
# whatever used to be fed into the head.
model.head = nn.Identity()
#model

Using cache found in /home/ubuntu/.cache/torch/hub/facebookresearch_deit_main


In [11]:
import neptune.new as neptune
# Credentials must not live in the notebook. With no `api_token` argument the
# Neptune client reads the token from the NEPTUNE_API_TOKEN environment
# variable, so export it in the shell that starts the kernel.
# NOTE(review): the token that was hard-coded here has been exposed in the
# notebook history and should be revoked and regenerated.
run = neptune.init(project='victorcallejas/FBSim')

https://app.neptune.ai/victorcallejas/FBSim/e/FBSIM-38
Remember to stop your run once you’ve finished logging your metadata (https://docs.neptune.ai/api-reference/run#stop). It will be stopped automatically only when the notebook kernel/interactive console is terminated.


In [12]:
# Use the GPU when one is available and fall back to the CPU otherwise, so the
# notebook can still be executed (slowly) on a machine without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [13]:
import sys
sys.path.append("..")
from src.utils.losses.contrastive import TripletLoss
from src.utils.losses.SOLAR import SOLARLoss

# AdamW over all model parameters.
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

# SOLAR loss module, moved to the selected device.
criterion = SOLARLoss().to(device)

# The training loop runs validation whenever (step + 1) is a multiple of this.
val_steps = 150

In [32]:
 # Validation
import h5py
import numpy as np

from src.eval_metrics_script.eval_metrics import get_matching_from_descs, evaluate_metrics

def _embed_loader(model, dataloader, transform):
    """Embed every image served by ``dataloader``.

    Returns:
        (ids, embeddings): the ids in iteration order and a list with one
        numpy embedding row per image.
    """
    ids, embeddings = [], []

    for batch in tqdm(dataloader, total=len(dataloader)):
        # Merge the two leading dimensions of the batch before preprocessing.
        # assumes batch['x'] is (B, 1, C, H, W) as produced by IterDataset — TODO confirm
        x = transform(batch['x'].to(device).flatten(0, 1))

        with torch.cuda.amp.autocast(enabled=False), torch.no_grad():
            b_emb = model(x)

        ids.extend(batch['id'])
        embeddings.extend(b_emb.cpu().numpy())

    return ids, embeddings


def valid(model,valid_q_dataloader,valid_r_dataloader):
    """Embed query and reference images, score the matching and log the metrics.

    Args:
        model: network mapping a preprocessed image batch to one embedding per image.
        valid_q_dataloader: loader over query images, yielding dicts with 'id' and 'x'.
        valid_r_dataloader: loader over reference images, same batch format.

    Returns:
        (ap, rp90) as returned by ``evaluate_metrics``. Both values are also
        logged to the global Neptune ``run`` and printed.

    The model is switched to eval mode for the duration of the call and put
    back into whatever mode it was in before, so calling this in the middle
    of training does not leave the model in eval mode.
    """
    # NN_TRANSFORMS cannot be reused here: it contains ToTensor, which only
    # accepts PIL images / ndarrays, while the loaders deliver tensors. This is
    # the same resize / crop / normalise recipe for tensor input.
    # NOTE(review): tensor and PIL resizing may differ slightly numerically.
    tensor_transforms = T.Compose([
        T.Resize(256, interpolation=T.InterpolationMode.BICUBIC),
        T.CenterCrop(224),
        T.ConvertImageDtype(torch.float32),
        T.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
    ])

    was_training = model.training
    model.eval()
    try:
        qry_ids, M_query = _embed_loader(model, valid_q_dataloader, tensor_transforms)
        ref_ids, M_ref = _embed_loader(model, valid_r_dataloader, tensor_transforms)
    finally:
        # Restore the caller's mode even if embedding fails half-way.
        model.train(was_training)

    M_query, M_ref = np.asarray(M_query,dtype=np.float32), np.asarray(M_ref,dtype=np.float32)

    submission_df = get_matching_from_descs(M_query, M_ref, qry_ids, ref_ids, gt)
    ap, rp90 = evaluate_metrics(submission_df, gt)

    run["dev/ap"].log(ap)
    run["dev/rp90"].log(rp90)
    print('VALID - AP: ',ap, 'rp90: ',rp90)

    return ap, rp90


In [35]:
model = model.to(device)


for epoch in range(1,1000):
    print('EPOCH - ',epoch)

    # Training
    total_train_loss = 0

    model.train()
    optimizer.zero_grad()

    with tqdm(total=len(train_dataloader)) as t_bar:
        for step, batch in enumerate(train_dataloader):

            if (step+1) % val_steps == 0:
                valid(model,valid_q_dataloader,valid_r_dataloader)
                # valid() runs the model in eval mode; make sure training mode
                # is active again before the next optimisation step.
                model.train()

            # Merge the batch and group dimensions: every image of every group
            # becomes one row, with its label at the same position.
            x = batch['x'].flatten(0,1).to(device)
            targets = batch['labels'].to(device).view(-1)

            with torch.cuda.amp.autocast(enabled=False):
                b_emb = model(x)

            # The loss receives the embeddings transposed.
            loss = criterion(b_emb.permute(1,0), targets)

            # Log and accumulate a plain Python float, not the tensor that
            # still carries the autograd graph.
            batch_loss = loss.item()
            run["train/batch_loss"].log(batch_loss)

            loss.backward()
            total_train_loss += batch_loss

            t_bar.set_description("Batch Loss: "+str(batch_loss), refresh=True)

            optimizer.step()
            optimizer.zero_grad()

            t_bar.update()

    epoch_loss = total_train_loss/len(train_dataloader)
    print('TRAIN - Loss: ',epoch_loss)
    run["train/epoch_loss"].log(epoch_loss)

    # Disabled checkpointing, kept for reference. It cannot be enabled as is:
    # total_dev_loss, best_dev_loss and dev_loss are not defined in this cell.
    #
    # if total_dev_loss < best_dev_loss:
    #     best_dev_loss = total_dev_loss
    #     path = '../artifacts/tmp/'+str(epoch)+'.ckpt'
    #     torch.save({
    #         'model':model.state_dict(),
    #         'opt':optimizer.state_dict(),
    #         'dev_loss':dev_loss
    #     },path)


EPOCH -  1


  0%|          | 0/31250 [00:00<?, ?it/s]



AttributeError: 'IterDataset' object has no attribute 'img_list'

## GENERATE SUBMISSION

In [None]:
 def getDescriptors(model,dataloader,device):
     """Run ``model`` over ``dataloader`` and return all outputs as one array.

     Args:
         model: network producing one descriptor row per input image.
         dataloader: yields either tensors or dicts holding the images under
             the key 'x' (the format produced by IterDataset).
         device: device on which the model is evaluated.

     Returns:
         np.ndarray of dtype float32 with the descriptors of all batches
         concatenated along axis 0, in iteration order (an empty (0, 0) array
         if the dataloader yields nothing).
     """
     # Same resize / crop / normalise recipe as NN_TRANSFORMS, for tensor input
     # (NN_TRANSFORMS itself contains ToTensor, which rejects tensors).
     tensor_transforms = T.Compose([
         T.Resize(256, interpolation=T.InterpolationMode.BICUBIC),
         T.CenterCrop(224),
         T.ConvertImageDtype(torch.float32),
         T.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
     ])

     model.to(device)

     model.eval()

     features = []

     for batch in tqdm(dataloader,total=len(dataloader)):
         # A dict batch has no .to(); pull the image tensor out first.
         x = batch['x'] if isinstance(batch, dict) else batch
         x = x.to(device)

         # 5-D input carries an extra per-item dimension; merge it into the batch.
         if x.dim() == 5:
             x = x.flatten(0, 1)

         # Integer images are raw pixels and still need preprocessing; floating
         # point input is passed to the model unchanged.
         if not torch.is_floating_point(x):
             x = tensor_transforms(x)

         with torch.cuda.amp.autocast(enabled=False), torch.no_grad():
             b_logits = model(x)

         # Move every batch to host memory right away so outputs do not pile
         # up on the device.
         features.append(b_logits.cpu().numpy())

     if not features:
         return np.empty((0, 0), dtype=np.float32)

     return np.concatenate(features, axis=0).astype(np.float32, copy=False)

In [None]:
# Submission ids: 'Q' + 5-digit and 'R' + 6-digit zero-padded counters.
qry_ids = [f"Q{i:05d}" for i in range(50_000)]
ref_ids = [f"R{i:06d}" for i in range(1_000_000)]

# Descriptors of the query split first, then of the reference split.
query_feats = getDescriptors(model, valid_q_dataloader, device)
reference_feats = getDescriptors(model, valid_r_dataloader, device)

In [None]:
# Destination of the submission file.
out = "../submissions/fb-isc-submission.h5"

with h5py.File(out, "w") as f:
    # Descriptor matrices first, then the matching id lists.
    for _name, _payload in (
        ("query", query_feats),
        ("reference", reference_feats),
        ('query_ids', qry_ids),
        ('reference_ids', ref_ids),
    ):
        f.create_dataset(_name, data=_payload)