In [3]:
import pandas as pd

import torch
import torch.nn as nn

import torchvision.transforms as T

from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

from PIL import Image

from tqdm.notebook import tqdm

In [4]:
# Directory (relative to this notebook) that holds the metadata CSV files read below.
DATA_PATH = '../data/'

In [5]:
# Read the four metadata tables stored under DATA_PATH, in a fixed order.
reference, query, training, gt = (
    pd.read_csv(DATA_PATH + csv_name)
    for csv_name in (
        'reference_images_metadata.csv',
        'query_images_metadata.csv',
        'training_images_metadata.csv',
        'public_ground_truth.csv',
    )
)

# Print the (rows, columns) shape of every table, one per line.
print(reference.shape, query.shape, training.shape, gt.shape, sep='\n')

# Preview of the query metadata (rendered as the cell output).
query.head()

(1000000, 2)
(50000, 2)
(1000000, 2)
(25000, 2)


Unnamed: 0,image_id,md5_checksum
0,Q00000,de21a560619005c56dcbd3a7e6c00fd9
1,Q00001,7a68c7f40674a463d14d74c8f8033cc7
2,Q00002,2005093a0ca9b1a33194561b219a0c49
3,Q00003,9b4f2a7cf20d4256b6d46dbba49dd86d
4,Q00004,9038b363055284ba8882943d707a4d06


In [6]:
import random
import augly.image as imaugs

# Colour-jitter factors, each drawn uniformly from [0.5, 1.5].
# NOTE: the draw happens once, when this cell runs, so the ColorJitter below
# applies the same three factors to every image until the cell is re-executed.
COLOR_JITTER_PARAMS = {
    f"{prop}_factor": random.uniform(0.5, 1.5)
    for prop in ("brightness", "contrast", "saturation")
}

# Augmentation steps used to turn an anchor image into its positive:
#   1. one of blur / pixelization,
#   2. the colour jitter configured above,
#   3. one of screenshot / emoji / text overlay.
AUGMENTATIONS = [
    imaugs.OneOf([
        imaugs.Blur(),
        imaugs.Pixelization(),
    ]),
    imaugs.ColorJitter(
        brightness_factor=COLOR_JITTER_PARAMS["brightness_factor"],
        contrast_factor=COLOR_JITTER_PARAMS["contrast_factor"],
        saturation_factor=COLOR_JITTER_PARAMS["saturation_factor"],
    ),
    imaugs.OneOf([
        imaugs.OverlayOntoScreenshot(),
        imaugs.OverlayEmoji(),
        imaugs.OverlayText(),
    ]),
]

# Network input pipeline: resize to 256 (interpolation code 3, i.e. PIL bicubic),
# centre-crop to 224, convert to a tensor and normalise with the ImageNet statistics.
NN_TRANSFORMS = T.Compose(
    [
        T.Resize(256, interpolation=3),
        T.CenterCrop(224),
        T.ToTensor(),
        T.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
    ]
)



In [7]:
from torch.utils.data import Dataset

# Each item stacks (anchor, positive, negatives) along dim 0; a batch is batch_size x that stack.
# SOLAR labels: -1 = anchor, 1 = positive, 0 = negative
class SOLARDataset(Dataset):
    """Training dataset that yields one (anchor, positive, negatives) stack per image.

    For the image at position ``idx`` an item contains, stacked along dim 0:
      * the anchor (the image itself),
      * one positive (the anchor passed through ``AUGMENTATIONS``),
      * ``num_negatives`` negatives (other images drawn at random, with replacement).

    Args:
        images: sequence of image ids; each id maps to ``../data/images/<id>.jpg``.
        AUGMENTATIONS: list of augly transforms, composed in order to build the positive.
        NN_TRANSFORMS: callable turning a PIL image into the tensor fed to the network.
    """

    def __init__(self, images, AUGMENTATIONS, NN_TRANSFORMS):
        self.img_list = images

        self.transform = NN_TRANSFORMS
        self.augmentations = AUGMENTATIONS

        self.num_augmentations = 1 # 1 for SOLAR
        self.num_negatives = 4 # look default in SOLAR

    def __len__(self):
        return len(self.img_list)

    def _image_path(self, image_id):
        """Return the on-disk path of the JPEG that belongs to ``image_id``."""
        return '../data/images/' + image_id + '.jpg'

    def _sample_negative_index(self, idx):
        """Draw a random dataset index that is different from ``idx``.

        Raises:
            ValueError: if the dataset holds fewer than two images, in which
                case no index other than ``idx`` exists.
        """
        n_images = len(self)
        if n_images < 2:
            raise ValueError('SOLARDataset needs at least 2 images to sample a negative')

        neg_idx = idx
        while neg_idx == idx:
            neg_idx = random.randint(0, n_images - 1)
        return neg_idx

    def __getitem__(self, idx):
        """Build the training stack for the image at ``idx``.

        Returns:
            dict with
              'x': tensor of shape (2 + num_negatives, *transform_output_shape),
                   ordered anchor, positive, negatives;
              'labels': 1-D tensor ``[-1, 1, 0, ..., 0]`` marking that order.
        """
        id_ = self.img_list[idx]

        # The augmentation list is applied exactly as it was passed in; any
        # per-sample randomness has to come from the transforms themselves.
        AUG = imaugs.Compose(self.augmentations)

        # Anchor and positive come from the same source image.
        with Image.open(self._image_path(id_)) as anchor:
            query = self.transform(anchor.convert('RGB'))
            positive = self.transform(AUG(anchor).convert('RGB'))

        # Negatives: other images, looked up through the sampled index
        # (not through ``idx``, which would just reload the anchor).
        negatives = []
        for _ in range(self.num_negatives):
            neg_id = self.img_list[self._sample_negative_index(idx)]
            with Image.open(self._image_path(neg_id)) as neg:
                negatives.append(self.transform(neg.convert('RGB')))

        # Compose
        x = torch.stack([query, positive] + negatives, dim=0)
        labels = [-1, 1] + ([0] * self.num_negatives)
        return {
            'x': x,
            'labels': torch.tensor(labels)
        }

class IterDataset(Dataset):
    """Inference dataset: returns each image, transformed, together with its id.

    Args:
        images: sequence of image ids; each id maps to ``../data/images/<id>.jpg``.
        NN_TRANSFORMS: callable applied to the RGB PIL image.
    """

    def __init__(self, images, NN_TRANSFORMS):
        self.transform = NN_TRANSFORMS
        self.img_list = images

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        """Return ``{'id': <image id>, 'x': <transformed image with a leading dim of size 1>}``."""
        image_id = self.img_list[idx]
        image_file = '../data/images/' + image_id + '.jpg'

        rgb_image = Image.open(image_file).convert('RGB')
        transformed = self.transform(rgb_image)

        # The extra leading dimension mirrors the stacked layout of the training items.
        return {'id': image_id, 'x': transformed.unsqueeze(0)}

In [8]:
# Ids of the first 1000 training images: 'T' followed by a 6-digit zero-padded index.
train_images = [f"T{n:06d}" for n in range(1000)]

# Ids of the first 1000 query ('Q' + 5 digits) and reference ('R' + 6 digits) images,
# used as the validation subset.
val_q_images = [f"Q{n:05d}" for n in range(1000)]
val_r_images = [f"R{n:06d}" for n in range(1000)]


# Training set: anchor / positive / negatives stacks built from the training ids.
train_dataset = SOLARDataset(
    images=train_images,
    AUGMENTATIONS=AUGMENTATIONS,
    NN_TRANSFORMS=NN_TRANSFORMS,
)

# Validation sets: plain transformed images for the query and reference subsets.
val_q_dataset = IterDataset(images=val_q_images, NN_TRANSFORMS=NN_TRANSFORMS)
val_r_dataset = IterDataset(images=val_r_images, NN_TRANSFORMS=NN_TRANSFORMS)

In [9]:
# Smoke test: fetch the first training item and show the shapes of its stack and labels.
item = train_dataset[0]
print(item['x'].shape, item['labels'].shape, sep='\n')

torch.Size([6, 3, 224, 224])
torch.Size([6])


In [10]:
# Smoke test: fetch the first validation query item and show the shape of its tensor.
item = val_q_dataset[0]
print(item['x'].shape)

torch.Size([1, 3, 224, 224])


In [11]:
from torch.utils.data import DataLoader

NW = 0  # DataLoader worker processes (0 = load in the main process)
BS = 6  # items per batch; every training item is itself a stack of 6 images

# Arguments shared by the three loaders below.
_loader_kwargs = dict(batch_size=BS, num_workers=NW)

# Training batches are reshuffled on every pass.
train_dataloader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)

# Validation loaders keep dataset order so embeddings stay aligned with their ids.
valid_q_dataloader = DataLoader(val_q_dataset, shuffle=False, **_loader_kwargs)
valid_r_dataloader = DataLoader(val_r_dataset, shuffle=False, **_loader_kwargs)

In [12]:
# Fetch the pretrained 'deit_tiny_patch16_224' entry point from the facebookresearch/deit
# hub repository (torch.hub downloads it on first use and caches it afterwards).
model = torch.hub.load('facebookresearch/deit:main', 'deit_tiny_patch16_224', pretrained=True)
model.eval()
# Swap the head for a pass-through module; presumably this makes model(x) return the
# features that used to feed the classifier - TODO confirm against the DeiT forward().
model.head = torch.nn.Identity()
#model

Using cache found in /home/victor/.cache/torch/hub/facebookresearch_deit_main


In [13]:
import neptune.new as neptune
#run = neptune.init(project='victorcallejas/FBSim')

In [14]:
# Use the GPU when one is available and fall back to the CPU otherwise, so the
# notebook also runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [15]:
import sys
sys.path.append("..")
from src.utils.losses.contrastive import TripletLoss
from src.utils.losses.SOLAR import SOLARLoss

# AdamW over every model parameter, with a fixed learning rate of 2e-5.
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

# SOLAR loss, placed on the same device as the model inputs.
criterion = SOLARLoss().to(device)

# Validation runs whenever (step + 1) is a multiple of this value.
val_steps = 50

In [16]:
 # Validation
import h5py
import numpy as np

from src.eval_metrics_script.eval_metrics import get_matching_from_descs, evaluate_metrics

def valid(model,valid_q_dataloader,valid_r_dataloader):
    """Embed the validation queries and references, score them and print the metrics.

    The model is put in eval mode while embedding and is returned to whatever
    mode (train / eval) it was in when the function was called, so calling this
    from inside a training loop does not silently disable training behaviour.

    Relies on the notebook-level globals ``device`` and ``gt`` (ground truth table).

    Args:
        model: network mapping an image batch to one embedding per image.
        valid_q_dataloader: loader yielding dicts with 'x' (images) and 'id' for queries.
        valid_r_dataloader: same, for reference images.

    Returns:
        Tuple ``(ap, rp90)`` as produced by ``evaluate_metrics``.
    """
    was_training = model.training
    model.eval()
    try:
        qry_ids, M_query = _embed_loader(model, valid_q_dataloader)
        ref_ids, M_ref = _embed_loader(model, valid_r_dataloader)
    finally:
        # Restore the caller's mode even if embedding fails half-way.
        model.train(was_training)

    submission_df = get_matching_from_descs(M_query, M_ref, qry_ids, ref_ids, gt)
    ap, rp90 = evaluate_metrics(submission_df, gt)

    #run["dev/ap"].log(ap)
    #run["dev/rp90"].log(rp90)
    print('VALID - AP: ',ap, 'rp90: ',rp90)
    return ap, rp90


def _embed_loader(model, dataloader):
    """Run ``model`` over every batch of ``dataloader``; return (ids, float32 embedding matrix)."""
    ids, embeddings = [], []

    for batch in tqdm(dataloader, total=len(dataloader)):
        # Items carry a leading per-item dimension; merge it into the batch dimension.
        x = batch['x'].to(device).flatten(start_dim=0, end_dim=1)

        with torch.cuda.amp.autocast(enabled=False):
            with torch.no_grad():
                b_emb = model(x)

        ids.extend(batch['id'])
        # Extending with a 2-D array appends one row (one embedding) per image.
        embeddings.extend(b_emb.detach().cpu().numpy())

    return ids, np.asarray(embeddings, dtype=np.float32)


In [17]:
model = model.to(device)

# Number of passes over the training loader (epochs are numbered from 1).
NUM_EPOCHS = 999

for epoch in range(1, NUM_EPOCHS + 1):
    print('EPOCH - ',epoch)

    # Training
    total_train_loss = 0

    model.train()
    optimizer.zero_grad()

    with tqdm(total=len(train_dataloader)) as t_bar:
        for step, batch in enumerate(train_dataloader):

            # Periodic validation, run before the current batch is trained on.
            if (step+1) % val_steps == 0:
                valid(model,valid_q_dataloader,valid_r_dataloader)
                # valid() calls model.eval(); go back to train mode so the
                # remaining steps of the epoch are not trained in eval mode.
                model.train()

            # Merge the per-item stack dimension into the batch dimension,
            # for both the images and their labels.
            x = batch['x'].flatten(0,1).to(device)
            targets = batch['labels'].to(device).view(-1)

            with torch.cuda.amp.autocast(enabled=False):
                b_emb = model(x)

            # The embeddings are handed over transposed; presumably SOLARLoss
            # expects one column per image - TODO confirm against its signature.
            loss = criterion(b_emb.permute(1,0), targets)
            #run["train/batch_loss"].log(loss)

            loss.backward()
            batch_loss = loss.item()
            total_train_loss += batch_loss

            t_bar.set_description("Batch Loss: "+str(batch_loss), refresh=True)

            optimizer.step()
            optimizer.zero_grad()

            t_bar.update()

    print('TRAIN - Loss: ',total_train_loss/len(train_dataloader))
    #run["train/epoch_loss"].log(total_train_loss/len(train_dataloader))

    # TODO: checkpoint the best model. The earlier sketch compared a dev loss
    # against the best one seen and saved model/optimizer state under
    # '../artifacts/tmp/<epoch>.ckpt', but no dev loss is computed in this loop yet.


EPOCH -  1


  0%|          | 0/167 [00:00<?, ?it/s]

## GENERATE SUBMISSION

In [None]:
import h5py
import numpy as np

#M_ref = np.random.rand(1_000_000, 256).astype('float32')
#M_query = np.random.rand(50_000, 256).astype('float32')

# Submission ids: 'Q' + 5-digit index for the 50,000 queries,
# 'R' + 6-digit index for the 1,000,000 references.
qry_ids = [f"Q{n:05d}" for n in range(50_000)]
ref_ids = [f"R{n:06d}" for n in range(1_000_000)]

In [None]:
from torch.utils.data import Dataset, DataLoader

# Number of images per batch for the submission dataloaders.
BATCH_SIZE = 16

class SubmDataset(Dataset):
    """Dataset used to embed images for the submission file.

    Args:
        image_names: sequence of image ids; each maps to ``../data/images/<id>.jpg``.
        transform: callable applied to the RGB PIL image; its result is the item.
    """

    def __init__(self, image_names, transform):
        self.image_names = image_names
        self.transform = transform
        self.image_path = '../data/images/'
        self.image_ext = '.jpg'

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        """Return the transformed image stored at position ``idx``."""
        img_path = self.image_path + self.image_names[idx] + self.image_ext
        with Image.open(img_path) as im:
            # Force 3 channels, like the other datasets in this notebook do:
            # greyscale / CMYK / RGBA files would otherwise yield tensors with a
            # different channel count than the 3-channel normalisation expects.
            return self.transform(im.convert('RGB'))

In [None]:
# Both submission datasets use the same network preprocessing as training and
# validation (NN_TRANSFORMS); no separate `transform` object exists in this notebook.
query_dataset = SubmDataset(qry_ids, NN_TRANSFORMS)
reference_dataset = SubmDataset(ref_ids, NN_TRANSFORMS)

# Sequential order keeps the embeddings aligned with qry_ids / ref_ids.
# Each sampler must be built over its own dataset: a sampler over the query set
# would only visit as many reference images as there are queries.
query_dataloader = DataLoader(
    query_dataset,
    batch_size=BATCH_SIZE,
    sampler=torch.utils.data.SequentialSampler(query_dataset),
    num_workers=0,
)
reference_dataloader = DataLoader(
    reference_dataset,
    batch_size=BATCH_SIZE,
    sampler=torch.utils.data.SequentialSampler(reference_dataset),
    num_workers=0,
)

In [None]:
 def getDescriptors(model,dataloader,device):
     """Embed every batch of ``dataloader`` and return the stacked descriptors.

     Args:
         model: network mapping an image batch to one descriptor per image.
         dataloader: iterable of image-batch tensors that also supports ``len()``.
         device: torch device the model and the batches are moved to.

     Returns:
         float32 numpy array with one row per image, in dataloader order
         (an empty ``(0, 0)`` array when the dataloader yields nothing).
     """
     model.to(device)
     model.eval()

     features = []

     for batch in tqdm(dataloader, total=len(dataloader)):
         batch = batch.to(device)
         with torch.cuda.amp.autocast(enabled=False):
             with torch.no_grad():
                 b_logits = model(batch)

         # Move each batch off the device right away so results accumulate as
         # host-side arrays rather than as a Python list of device tensors.
         features.append(b_logits.detach().cpu().numpy())

     if not features:
         return np.empty((0, 0), dtype=np.float32)

     return np.concatenate(features, axis=0).astype(np.float32, copy=False)

In [None]:
# Use the GPU when one is available and fall back to the CPU otherwise.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Descriptors for every query image, then for every reference image.
query_feats = getDescriptors(
    model=model,
    dataloader=query_dataloader,
    device=device,
)
reference_feats = getDescriptors(
    model=model,
    dataloader=reference_dataloader,
    device=device,
)

In [None]:
out = "../submissions/fb-isc-submission.h5"

# Dataset name -> payload, listed in the order the datasets are written.
_submission_datasets = {
    "query": query_feats,
    "reference": reference_feats,
    'query_ids': qry_ids,
    'reference_ids': ref_ids,
}

# Write the descriptor matrices and their id lists into one HDF5 file.
with h5py.File(out, "w") as f:
    for _dataset_name, _payload in _submission_datasets.items():
        f.create_dataset(_dataset_name, data=_payload)