In [None]:
import pandas as pd

import torch
import torch.nn as nn

from tqdm.notebook import tqdm

In [None]:
# Directory holding the raw metadata / ground-truth CSV files, relative to the
# notebook's working directory. File names are appended with plain string
# concatenation, so the trailing slash is required.
DATA_PATH = '../data/raw/'

In [None]:
# Read the three image-metadata tables and the public ground truth, in that order.
reference, query, training, gt = (
    pd.read_csv(DATA_PATH + csv_name)
    for csv_name in (
        'reference_images_metadata.csv',
        'query_images_metadata.csv',
        'training_images_metadata.csv',
        'public_ground_truth.csv',
    )
)

# One (rows, columns) tuple per line, same order as the reads above.
print(reference.shape, query.shape, training.shape, gt.shape, sep='\n')

# Last expression of the cell: preview of the query metadata.
query.head()

In [None]:
from torch.utils.data import Dataset
import random

# Each item stacks (anchor, positive, negatives) along dim 0, with one label per stacked image.
# SOLAR label convention: -1 = anchor, 1 = positive, 0 = negative.

# dtype every image tensor is cast to before it is handed to the model.
DTYPE = torch.float32

class SOLARDataset(Dataset):
    """Training dataset yielding one anchor / positive / negatives group per image id.

    For index ``idx`` the item stacks, along a new leading dimension:

    * the anchor tensor loaded from ``IMAGE_DIR/<id>.pt``,
    * its positive loaded from ``AUGMENTED_DIR/<id>.pt`` (same id),
    * ``num_negatives`` tensors of *other* ids, drawn at random (with
      replacement) from ``IMAGE_DIR``.

    Args:
        images: sequence of image ids (file stems, without the ``.pt`` suffix).
        num_negatives: number of negatives sampled per anchor (default 4).
    """

    IMAGE_DIR = '../data/processed/images/'
    AUGMENTED_DIR = '../data/processed/augmented/'

    def __init__(self, images, num_negatives=4):
        self.img_ids = images

        self.num_augmentations = 1 # 1 for SOLAR
        self.num_negatives = num_negatives # look default in SOLAR

    def __len__(self):
        return len(self.img_ids)

    def _sample_negative_index(self, idx):
        """Return a random dataset index that is different from ``idx``."""
        neg_idx = idx
        while neg_idx == idx:
            neg_idx = random.randint(0, len(self) - 1)
        return neg_idx

    def __getitem__(self, idx):
        """Build the stacked group for ``idx``.

        Returns:
            dict with
            ``'x'``: tensor of shape ``(2 + num_negatives, *image_shape)`` cast to ``DTYPE``,
            ``'labels'``: tensor ``[-1, 1, 0, ..., 0]`` (anchor, positive, negatives).

        Raises:
            ValueError: if negatives are requested but the dataset has fewer than
                two images (no index other than the anchor exists).
        """
        if self.num_negatives > 0 and len(self) < 2:
            # The rejection sampling below could never terminate in this case.
            raise ValueError('SOLARDataset needs at least 2 images to sample negatives')

        id_ = self.img_ids[idx]

        # Anchor and its augmented (positive) counterpart share the same id.
        anchor = torch.load(self.IMAGE_DIR + id_ + '.pt')
        positive = torch.load(self.AUGMENTED_DIR + id_ + '.pt')

        # Negatives: images of *other* ids. The sampled index (not the anchor
        # index) must be used for the lookup, otherwise every negative would be
        # the anchor image itself.
        negatives = []
        for _ in range(self.num_negatives):
            neg_id = self.img_ids[self._sample_negative_index(idx)]
            negatives.append(torch.load(self.IMAGE_DIR + neg_id + '.pt'))

        # Compose: stack adds the leading "group" dimension.
        x = torch.stack([anchor, positive, *negatives], dim=0).to(DTYPE)
        labels = [-1, 1] + ([0] * self.num_negatives)
        return {
            'x': x,
            'labels': torch.tensor(labels)
        }
        
class IterDataset(Dataset):
    """Inference dataset: one pre-processed image tensor per id, returned with its id."""

    def __init__(self, images):
        # Sequence of image ids (file stems without the '.pt' suffix).
        self.img_ids = images

    def __len__(self):
        return len(self.img_ids)

    def __getitem__(self, idx):
        """Return ``{'id': <image id>, 'x': <tensor with a leading dim of size 1, cast to DTYPE>}``."""
        image_id = self.img_ids[idx]
        tensor_path = '../data/processed/images/' + image_id + '.pt'
        image = torch.load(tensor_path)
        sample = {
            'id': image_id,
            'x': image.unsqueeze(0).to(DTYPE),
        }
        return sample

In [None]:
"""
train_images = training.image_id.tolist()
val_q_images = query.image_id.tolist()
val_r_images = reference.image_id.tolist()
"""
# Small fixed subset: the first 1000 ids of each split, built from the
# zero-padded naming scheme (T + 6 digits, Q + 5 digits, R + 6 digits).
train_images = ['T' + f"{i:06d}" for i in range(0, 1000)]

val_q_images = ['Q' + f"{i:05d}" for i in range(0, 1000)]
val_r_images = ['R' + f"{i:06d}" for i in range(0, 1000)]


# Training groups (anchor / positive / negatives) and plain per-image
# datasets for the query and reference validation sets.
train_dataset = SOLARDataset(train_images)

val_q_dataset = IterDataset(val_q_images)
val_r_dataset = IterDataset(val_r_images)

In [None]:
# Smoke check: fetch the first training group and show the tensor shapes.
item = train_dataset[0]
print(item['x'].shape, item['labels'].shape, sep='\n')

In [None]:
# Smoke check: fetch the first query sample and show the image tensor shape.
item = val_q_dataset[0]
print(item['x'].shape)

In [None]:
from torch.utils.data import DataLoader

NW = 8  # worker processes per DataLoader
BS = 8  # items per batch

# Training batches are reshuffled every epoch; the default collate function is used.
train_dataloader = DataLoader(
    train_dataset, batch_size=BS, shuffle=True, num_workers=NW, pin_memory=True
)

# Query and reference loaders share the same settings and keep dataset order
# (shuffle=False); the default collate function is used.
valid_q_dataloader, valid_r_dataloader = (
    DataLoader(
        eval_dataset, batch_size=BS, shuffle=False, num_workers=NW, pin_memory=True
    )
    for eval_dataset in (val_q_dataset, val_r_dataset)
)

In [None]:
# Fetch the 'deit_tiny_patch16_224' entry point of the facebookresearch/deit
# repository through torch.hub, with pretrained weights.
model = torch.hub.load('facebookresearch/deit:main', 'deit_tiny_patch16_224', pretrained=True)
# Start in eval mode; the training loop switches to train mode itself.
model.eval()
# Replace the head with a pass-through so the model outputs whatever was fed
# to the original head (used as the image embedding in this notebook).
model.head = torch.nn.Identity()
#model

In [None]:
import os

import neptune.new as neptune

# Credentials must never live in the notebook: the API token is read from the
# NEPTUNE_API_TOKEN environment variable (when it is unset, None is passed and
# Neptune performs its own lookup / raises its own "missing token" error).
# NOTE(security): a token was previously committed in this cell; it should be
# revoked in the Neptune UI and replaced by a fresh one exported in the shell.
run = neptune.init(
    project='victorcallejas/FBSim',
    api_token=os.environ.get('NEPTUNE_API_TOKEN'),
)

In [None]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs (slowly)
# on machines without CUDA instead of failing at the first `.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

In [None]:
import sys
sys.path.append("..")
from src.utils.losses.contrastive import TripletLoss
from src.utils.losses.SOLAR import SOLARLoss

# AdamW over every model parameter, fixed learning rate.
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

# SOLAR loss module, moved to the compute device.
criterion = SOLARLoss().to(device)

# Validation is triggered whenever (step + 1) is a multiple of this value.
val_steps = 150

In [None]:
 # Validation
import h5py
import numpy as np

from src.eval_metrics_script.eval_metrics import get_matching_from_descs, evaluate_metrics

def _embed_loader(model, dataloader):
    """Embed every batch of ``dataloader`` with ``model`` (no gradients).

    Returns:
        (ids, embeddings): the list of ids in iteration order and a float32
        NumPy array with one row per embedded image.
    """
    ids, chunks = [], []

    for batch in tqdm(dataloader, total=len(dataloader)):
        # flatten(0, 1) merges the batch dim with the next one.
        # assumes batch['x'] carries the extra size-1 dim added by IterDataset — TODO confirm for other datasets
        x = batch['x'].to(device).flatten(0, 1)

        with torch.no_grad(), torch.cuda.amp.autocast(enabled=False):
            b_emb = model(x)

        ids.extend(batch['id'])
        # Move to host memory right away: NumPy cannot read CUDA tensors and
        # keeping every batch on the GPU would also waste device memory.
        chunks.append(b_emb.detach().cpu().numpy())

    if not chunks:
        return ids, np.empty((0, 0), dtype=np.float32)
    return ids, np.concatenate(chunks, axis=0).astype(np.float32, copy=False)


def valid(model,valid_q_dataloader,valid_r_dataloader):
    """Embed the query and reference sets, score the matching and log the metrics.

    Args:
        model: network mapping an image batch to a 2-D embedding batch.
        valid_q_dataloader: loader yielding dicts with 'id' and 'x' for query images.
        valid_r_dataloader: same, for reference images.

    Side effects:
        Logs "dev/ap" and "dev/rp90" to the Neptune ``run`` and prints both values.
        The model's train/eval mode is restored to what it was on entry, so a
        call made in the middle of a training epoch does not silently leave the
        model in eval mode.
    """
    was_training = model.training
    model.eval()
    try:
        qry_ids, M_query = _embed_loader(model, valid_q_dataloader)
        ref_ids, M_ref = _embed_loader(model, valid_r_dataloader)
    finally:
        model.train(was_training)

    submission_df = get_matching_from_descs(M_query, M_ref, qry_ids, ref_ids, gt)
    ap, rp90 = evaluate_metrics(submission_df, gt)

    run["dev/ap"].log(ap)
    run["dev/rp90"].log(rp90)
    print('VALID - AP: ',ap, 'rp90: ',rp90)


In [None]:
model = model.to(device)


for epoch in range(1,1000):
    print('EPOCH - ',epoch)

    # Training
    total_train_loss = 0

    model.train()
    optimizer.zero_grad()

    with tqdm(total=len(train_dataloader)) as t_bar:
        for step, batch in enumerate(train_dataloader):

            if (step+1) % val_steps == 0:
                valid(model,valid_q_dataloader,valid_r_dataloader)
                # Validation switches the model to eval mode; go back to train
                # mode so the rest of the epoch is not trained with
                # train-only layers disabled.
                model.train()

            # Merge the batch dim with the per-item group dim
            # (anchor / positive / negatives) into one image batch.
            x = batch['x'].flatten(0,1).to(device)
            targets = batch['labels'].to(device).view(-1)

            with torch.cuda.amp.autocast(enabled=False):
                b_emb = model(x)

            # The criterion is given the transposed embedding matrix.
            loss = criterion(b_emb.permute(1,0), targets)
            # Log / accumulate a plain float, not the graph-attached tensor.
            batch_loss = loss.item()
            run["train/batch_loss"].log(batch_loss)

            loss.backward()
            total_train_loss += batch_loss

            t_bar.set_description("Batch Loss: "+str(batch_loss), refresh=True)

            optimizer.step()
            optimizer.zero_grad()

            t_bar.update()

    print('TRAIN - Loss: ',total_train_loss/len(train_dataloader))
    run["train/epoch_loss"].log(total_train_loss/len(train_dataloader))

    # TODO: checkpointing is disabled. The former draft compared a dev loss
    # against the best one so far and saved model / optimizer state dicts to
    # '../artifacts/tmp/<epoch>.ckpt'; no dev loss is computed in this loop
    # yet, so it cannot be enabled as-is.


## GENERATE SUBMISSION

In [None]:
 def getDescriptors(model,dataloader,device):
    """Compute one descriptor per image of ``dataloader`` and return them as an array.

    Args:
        model: network mapping an image batch to a 2-D descriptor batch.
        dataloader: yields either plain image tensors or dicts holding the
            images under ``'x'`` (as the IterDataset loaders of this notebook do).
        device: torch device the model and the batches are moved to.

    Returns:
        float32 NumPy array with the descriptors concatenated in loader order
        (shape ``(0, 0)`` when the loader is empty).
    """
    model.to(device)

    model.eval()

    features = []

    for batch in tqdm(dataloader,total=len(dataloader)):
        if isinstance(batch, dict):
            # Dict batches cannot be moved with .to(); take the image tensor
            # and merge the batch dim with the next one.
            # assumes 'x' carries the extra size-1 dim added by IterDataset — TODO confirm for other datasets
            x = batch['x'].flatten(0, 1)
        else:
            x = batch
        x = x.to(device)

        with torch.cuda.amp.autocast(enabled=False):
            with torch.no_grad():
                b_logits = model(x)

        # Keep results on the host: NumPy cannot read CUDA tensors.
        features.append(b_logits.detach().cpu().numpy())

    if not features:
        return np.empty((0, 0), dtype=np.float32)
    return np.concatenate(features, axis=0).astype(np.float32, copy=False)

In [None]:
# The id lists written to the submission must line up one-to-one with the
# descriptor rows, so they are taken from the datasets that feed the loaders
# (both loaders use shuffle=False, hence dataset order == descriptor order)
# instead of being hard-coded to the full 50k / 1M competition sizes.
qry_ids = list(valid_q_dataloader.dataset.img_ids)
ref_ids = list(valid_r_dataloader.dataset.img_ids)

query_feats = getDescriptors(model,valid_q_dataloader,device)
reference_feats = getDescriptors(model,valid_r_dataloader,device)

assert len(qry_ids) == len(query_feats), "query ids / descriptors length mismatch"
assert len(ref_ids) == len(reference_feats), "reference ids / descriptors length mismatch"

In [None]:
# Write descriptors and their ids into a single HDF5 submission file;
# datasets are created in the order listed.
out = "../submissions/fb-isc-submission.h5"
with h5py.File(out, "w") as f:
    for dataset_name, payload in (
        ("query", query_feats),
        ("reference", reference_feats),
        ('query_ids', qry_ids),
        ('reference_ids', ref_ids),
    ):
        f.create_dataset(dataset_name, data=payload)