In [None]:
%pip install sentence-transformers

In [None]:
from io import BytesIO
import base64
from pathlib import Path
import pandas as pd
import numpy as np
from transformers import ViTImageProcessor, ViTModel, ViTFeatureExtractor
from PIL import Image
import torch
import requests
import tqdm as tqdm
from sentence_transformers import SentenceTransformer, util
from PIL import Image, ImageFile
import requests
import torch
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
BATCH_SIZE = 64

In [None]:
merged_df = pd.read_csv('/kaggle/input/nto-cv-olympiad-final-dataset/merged/merged.csv')
text_df = pd.read_csv('/kaggle/input/nto-cv-olympiad-final-dataset/descriptionsv2.csv')
df = pd.concat([merged_df, text_df[['text']]], axis=1)

In [None]:
assert all(merged_df.XID == text_df.XID)

In [None]:
def img_from_base64(data: str) -> Image.Image:
    return Image.open(BytesIO(base64.b64decode(data, validate=False)))

In [None]:
test_df = df.sample(100).reset_index(drop=True)
test_df = df

In [None]:
images = list(test_df.image.apply(img_from_base64))

In [None]:
texts = list(test_df.text)

In [None]:
# We use the original clip-ViT-B-32 for encoding images
img_model = SentenceTransformer('clip-ViT-B-32')

# Our text embedding model is aligned to the img_model and maps 50+
# languages to the same vector space
text_model = SentenceTransformer('sentence-transformers/clip-ViT-B-32-multilingual-v1')

# Map images to the vector space
# img_embeddings = img_model.encode(images[:10])

# # Now we encode our text:
# texts = [
#     "A dog in the snow",
#     "Eine Katze",  # German: A cat
#     "Una playa con palmeras."  # Spanish: a beach with palm trees
# ]

# text_embeddings = text_model.encode(texts)

In [None]:
img_features = img_model.encode(images, convert_to_tensor=True)

In [None]:
text_features = text_model.encode(texts, convert_to_tensor=True)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

In [None]:
class EmbeddingsDataset(Dataset):
    def __init__(self, img_features, text_features, xids, n, p=.5):
        self.imgs = img_features
        self.texts = text_features
        self.all = torch.concatenate([img_features, text_features], dim=0)
        self.xids = list(xids) * 2
        self.p = p
        self.n = n
        self.pre = [self.gen() for _ in tqdm.trange(self.n)]
    def gen(self):
        i = np.random.randint(len(self.all))
        ixid = self.xids[i]
        if np.random.random() < self.p:
            idxs = [j for j, xid in enumerate(self.xids) if xid != ixid]
            j = np.random.choice(idxs)
            return self.all[i], self.all[j], -1
        else:
            idxs = [j for j, xid in enumerate(self.xids) if xid == ixid]
            j = np.random.choice(idxs)
            return self.all[i], self.all[j], 1
    
    def __len__(self):
        return self.n
    
    def __getitem__(self, idx):
        return self.pre[idx]
    

In [None]:
import pytorch_lightning as pl

In [None]:
train_dataset = EmbeddingsDataset(img_features, text_features, test_df.XID, 100000)
val_dataset = EmbeddingsDataset(img_features, text_features, test_df.XID, 1000)
test_dataset = EmbeddingsDataset(img_features, text_features, test_df.XID, 100)

In [None]:
num_workers = 0
batch_size = 1024
train_loader = DataLoader(train_dataset, num_workers=num_workers, batch_size=batch_size)
val_loader = DataLoader(val_dataset, num_workers=num_workers, batch_size=batch_size)
test_loader = DataLoader(test_dataset, num_workers=num_workers, batch_size=batch_size)

In [None]:
class Adapter(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.body = nn.Sequential(
            nn.Linear(512, 2048),
            nn.ReLU(),
            nn.Linear(2048, 512),
        )
        self.loss = nn.CosineEmbeddingLoss(margin=.0)
    
    def forward(self, x):
        return x + self.body(x)
    
    def training_step(self, batch, batch_idx):
        x, y, target = batch
        x = self(x)
        loss = self.loss(x, y, target)
        self.log('train_loss', loss)
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y, target = batch
        x = self(x)
        loss = self.loss(x, y, target)
        self.log('val_loss', loss)
        return loss
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters())
        return optimizer

In [None]:
model_kwargs = {
}
model = Adapter(**model_kwargs)
checkpoint_callback = pl.callbacks.ModelCheckpoint(monitor='val_loss', mode='min', save_last=True)
trainer = pl.Trainer(max_epochs=20, callbacks=[checkpoint_callback])

In [None]:
trainer.fit(model, train_loader, val_loader)

In [None]:
checkpoint_callback.best_model_score

In [None]:
checkpoint_callback.current_score

In [None]:
# adapter = Adapter.load_from_checkpoint(checkpoint_callback.best_model_path, **model_kwargs).cuda()
adapter = model.cuda()

In [None]:
def get_first_relevant_rank(e, emb, xid):
    dots = np.dot(emb, e[..., None])
    norms = np.linalg.norm(emb)[..., None]
    sims = (dots / norms / np.linalg.norm(e))[:, 0]
    top = np.argsort(sims)[::-1][:]
    df_rearranged = df.iloc[top]
    df_rearranged.reset_index(drop=True, inplace=True)
    return df_rearranged.loc[df_rearranged.XID == xid].index[0] + 1

In [None]:
adapted_text = adapter(text_features)

In [None]:
def get_ranks(features, stride=1):
    n = len(features)
    res = []
    for i in tqdm.trange(0, n, stride):
        rank = get_first_relevant_rank(features[i].detach().cpu().numpy(), img_features.detach().cpu().numpy(), test_df.iloc[i].XID)
        res.append(rank)
    return res

In [None]:
import tqdm.notebook as tqdm
text_ranks = get_ranks(text_features,1)

In [None]:
sns.histplot(text_ranks, log_scale=True)
plt.savefig('before.png')

In [None]:
adapted_ranks = get_ranks(adapted_text, 1)

In [None]:
sns.histplot(adapted_ranks, log_scale=True)
plt.savefig('after.png')

In [None]:
(1 / np.array(text_ranks)).mean()

In [None]:
(1 / np.array(adapted_ranks)).mean()

In [None]:
torch.save(adapter.state_dict(), 'adapter.pt')

In [None]:
def get_top_k(e, emb, xid, k):
    dots = np.dot(emb, e[..., None])
    norms = np.linalg.norm(emb)[..., None]
    sims = (dots / norms / np.linalg.norm(e))[:, 0]
    top = np.argsort(sims)[::-1][:k]
    accurate = (df.iloc[top].XID == xid).sum()
    return accurate / k

In [None]:
def get_accuracy(features, k, stride=1):
    n = len(features)
    res = []
    for i in tqdm.trange(0, n, stride):
        rank = get_top_k(features[i].detach().cpu().numpy(), img_features.detach().cpu().numpy(), test_df.iloc[i].XID, k)
        res.append(rank)
    return res

In [None]:
accuracies = get_accuracy(text_features, 10)

In [155]:
np.mean(accuracies)

0.08282735558324043

In [None]:
adapted_accuracies = get_accuracy(adapted_text, 10)

In [154]:
np.mean(adapted_accuracies)

0.10300963871221559