In [None]:
import os
import sys
from tempfile import NamedTemporaryFile
from urllib.request import urlopen
from urllib.parse import unquote, urlparse
from urllib.error import HTTPError
from zipfile import ZipFile
import tarfile
import shutil

# Kaggle directory layout used by the download code below.
KAGGLE_INPUT_PATH='/kaggle/input'
KAGGLE_WORKING_PATH='/kaggle/working'
KAGGLE_SYMLINK='kaggle'  # NOTE(review): not referenced anywhere in this notebook

# Number of bytes read from the HTTP response per loop iteration while downloading.
CHUNK_SIZE = 40960

# Comma-separated "<target directory>:<url-encoded signed download URL>" entries.
# The signed URLs carry an expiry (X-Goog-Expires), so downloads fail once they lapse.
DATA_SOURCE_MAPPING = 'random-image-for-testing-classification:https%3A%2F%2Fstorage.googleapis.com%2Fkaggle-data-sets%2F823708%2F1413358%2Fbundle%2Farchive.zip%3FX-Goog-Algorithm%3DGOOG4-RSA-SHA256%26X-Goog-Credential%3Dgcp-kaggle-com%2540kaggle-161607.iam.gserviceaccount.com%252F20240811%252Fauto%252Fstorage%252Fgoog4_request%26X-Goog-Date%3D20240811T232501Z%26X-Goog-Expires%3D259200%26X-Goog-SignedHeaders%3Dhost%26X-Goog-Signature%3D778aca31c9d2d45b1d7a035fa239af647f68021731df4636ce729bd34fb487fde940363f7e22f3a432acdbdaf26db1ad1b9e76e1b11bb078cdeb13c3366d0e92349644553c3fcfc22056613faf6fabe31578e3e9715784989a5c267098ea12de9894ae62cff6017365fe503c50e15531c8d1a9e45b7de9f06e2a1e5c1d97b058f814bbf888a77354a80c17e679664ce4c314e2c00399fa004f00a73b9d2d3b99add322599284a265afc3dbb1aab016cb6c82a0434ab6c98ed12f0f8d4add417de2a1fbb38fc43daa1bc2b5fc5d357bc958d5ba585d1b5d90f40c99e01ef1690dc64203d949cfa1c3b610b1eca3cb1f4e002acb521b57897c14de037b7600eb5c,flickr30k-dataset:https%3A%2F%2Fstorage.googleapis.com%2Fkaggle-data-sets%2F4401550%2F7558580%2Fbundle%2Farchive.zip%3FX-Goog-Algorithm%3DGOOG4-RSA-SHA256%26X-Goog-Credential%3Dgcp-kaggle-com%2540kaggle-161607.iam.gserviceaccount.com%252F20240811%252Fauto%252Fstorage%252Fgoog4_request%26X-Goog-Date%3D20240811T232501Z%26X-Goog-Expires%3D259200%26X-Goog-SignedHeaders%3Dhost%26X-Goog-Signature%3D532be343a3760455e93cab4c13a61a3971c78c39e940a4db282e30b3748f6251a5cea848afb593810a55895dd9725ed8d014b99dc48cf27c85551d4522fb2400b8f0e5a8e1993146d0120cd40048e903021707f30e86d829bac4b7a68dc0ca3e282151d42d2b394729530ce65c80df977318a08498f44374b46aa9145b3b8d6ee62fa4a5e5788857758b1a0978347fedb96b9ed1c9b3999c2a3fd0c1dc3c60448b925de35f60772b43e3c3aa6666dff49f184cdba8a1ef56713e11c01256e61d557fb8468e280138b3fed2b8cc056322b22da2ccae6cac4e4a80506db71bab4f5aec38c9332e6d3df1da4cc63f51179480b9eb5ad634b9fa5eab2e1c384cfd27'

!umount /kaggle/input/ 2> /dev/null
shutil.rmtree('/kaggle/input', ignore_errors=True)
os.makedirs(KAGGLE_INPUT_PATH, 0o777, exist_ok=True)
os.makedirs(KAGGLE_WORKING_PATH, 0o777, exist_ok=True)

try:
  os.symlink(KAGGLE_INPUT_PATH, os.path.join("..", 'input'), target_is_directory=True)
except FileExistsError:
  pass
try:
  os.symlink(KAGGLE_WORKING_PATH, os.path.join("..", 'working'), target_is_directory=True)
except FileExistsError:
  pass

# Download each "<directory>:<url-encoded URL>" entry and unpack it under KAGGLE_INPUT_PATH.
# A failed source is reported and skipped so the remaining sources are still imported.
for data_source_mapping in DATA_SOURCE_MAPPING.split(','):
    # Split only on the first ':' so a URL can never be cut in two.
    directory, download_url_encoded = data_source_mapping.split(':', 1)
    download_url = unquote(download_url_encoded)
    filename = urlparse(download_url).path
    destination_path = os.path.join(KAGGLE_INPUT_PATH, directory)
    try:
        with urlopen(download_url) as fileres, NamedTemporaryFile() as tfile:
            # content-length can be absent; then only a byte counter is shown.
            total_length = fileres.headers['content-length']
            total_bytes = int(total_length) if total_length else 0
            print(f'Downloading {directory}, {total_length} bytes compressed')
            dl = 0
            data = fileres.read(CHUNK_SIZE)
            while len(data) > 0:
                dl += len(data)
                tfile.write(data)
                if total_bytes:
                    done = int(50 * dl / total_bytes)
                    sys.stdout.write(f"\r[{'=' * done}{' ' * (50-done)}] {dl} bytes downloaded")
                else:
                    sys.stdout.write(f"\r{dl} bytes downloaded")
                sys.stdout.flush()
                data = fileres.read(CHUNK_SIZE)
            # Push buffered bytes to disk: tarfile reopens the file by name and would
            # otherwise see a truncated archive.
            tfile.flush()
            if filename.endswith('.zip'):
                with ZipFile(tfile) as zfile:
                    zfile.extractall(destination_path)
            else:
                # Bind to a new name; `as tarfile` would overwrite the module name globally.
                # NOTE(review): extractall trusts member paths; only use with trusted archives.
                with tarfile.open(tfile.name) as tar_archive:
                    tar_archive.extractall(destination_path)
            print(f'\nDownloaded and uncompressed: {directory}')
    except HTTPError as e:
        print(f'Failed to load (likely expired) {download_url} to path {destination_path}: {e}')
    except OSError as e:
        print(f'Failed to load {download_url} to path {destination_path}: {e}')

print('Data source import complete.')


# 🛠 | Install Libraries

In [None]:
!pip install lightning timm torchinfo wandb

In [None]:
!wandb login 55d0dbf56bdd3224ddc3a254f8cdea33f62cdb72

# 📚 | Import Libraries

In [None]:
import os
import timm
from lightning.pytorch.loggers import WandbLogger
import torch, torch.nn as nn,torch.nn.functional as F
from lightning.pytorch import seed_everything
from lightning import LightningDataModule
from lightning import Trainer
from lightning import LightningModule
import lightning.pytorch as pl
import cv2
import pandas as pd
import numpy as np
from glob import glob
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import AutoTokenizer, AutoModel, DistilBertForSequenceClassification, DistilBertModel, DistilBertTokenizer
from PIL import Image
from sklearn.model_selection import GroupKFold
from torchinfo import summary
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim import AdamW
from sklearn.model_selection import StratifiedKFold
from lightning.pytorch.callbacks import ModelCheckpoint

# ⚙️ | Configuration

In [None]:
class CFG:
    """Run-wide configuration, read as class attributes (``CFG.lr``, ``CFG.epochs``, ...)."""

    # -- run control --
    seed = 42
    debug = False  # True -> the debug cell keeps only the first 5000 caption rows
    device = 'cuda:0'

    # -- data --
    image_path = "/kaggle/input/flickr30k-dataset/Images"
    caption_path = "/kaggle/input/flickr30k-dataset"
    image_size = [224, 224]
    sequence_length = 200  # max caption token length; keep equal to any literal max_length used when tokenizing

    # -- models --
    image_preset = "tiny_vit_21m_224.dist_in22k_ft_in1k"
    text_preset = "SmartComponents/bge-micro-v2"
    embedding_dim = 256  # width of the shared image/text embedding space
    dropout = 0.1

    # -- optimisation --
    batch_size = 64
    epochs = 8
    lr = 3e-4
    # NOTE(review): the cosine scheduler is stepped per epoch; with epochs=8 and T_max=3
    # the LR bottoms out at epoch 3 and then rises again — confirm intended (T_max=epochs is typical).
    T_max = 3
    eta_min = 1e-6



# Seed the global RNGs through Lightning's helper, then resolve the configured torch device.
seed_everything(CFG.seed)
device = torch.device(CFG.device)

# ♻️ | Reproducibility

# 📖 | Metadata

In [None]:
df = pd.read_csv(f"{CFG.caption_path}/captions.txt")
# Full path of every image, prefixing the `image` filename column with the image directory.
df["image_path"] = f"{CFG.image_path}/" + df["image"]
df

In [None]:
# Debug runs use only the first 5000 caption rows for a fast smoke test.
if CFG.debug:
    df = df.head(5000)

In [None]:
print(df.shape)
# Drop rows without a caption, trim surrounding whitespace, and renumber the index.
df = (
    df.dropna(subset=['caption'])
      .assign(caption=lambda frame: frame['caption'].str.strip())
      .reset_index(drop=True)
)
df.shape

In [None]:
# Split by image, not by caption. Each image has several captions; stratifying on
# `image_path` spreads an image's captions across folds, so every validation image
# would also be seen in training (leakage). GroupKFold keeps all captions of an
# image in the same fold.
n_splits = 5
gkf = GroupKFold(n_splits=n_splits)

df['fold'] = -1
for fold, (_, valid_index) in enumerate(gkf.split(X=df, groups=df['image_path'])):
    df.loc[valid_index, 'fold'] = fold

train_df = df[df['fold'] != 0].reset_index(drop=True)
valid_df = df[df['fold'] == 0].reset_index(drop=True)
print(f"# Num Train: {len(train_df)} | Num Valid: {len(valid_df)}")

# No image may appear on both sides of the split.
assert set(train_df['image_path']).isdisjoint(valid_df['image_path'])

train_paths = train_df.image_path.values
train_texts = train_df.caption.values
valid_paths = valid_df.image_path.values
valid_texts = valid_df.caption.values

# 🔪 | Data Split

# 🍚 | DataLoader

In [None]:
class ImageTextDataset(Dataset):
    """Map-style dataset of (transformed image, caption) pairs.

    Args:
        image_filenames: sequence of image file paths.
        captions: sequence of caption strings, aligned index-by-index with ``image_filenames``.
        transforms: callable applied to each loaded PIL image (e.g. a torchvision ``Compose``).
        image_size: accepted for backward compatibility but unused; resizing is done by ``transforms``.

    Raises:
        ValueError: if ``image_filenames`` and ``captions`` differ in length.
    """

    def __init__(self, image_filenames, captions, transforms, image_size=CFG.image_size[0]):
        if len(image_filenames) != len(captions):
            raise ValueError(
                f"image_filenames ({len(image_filenames)}) and captions ({len(captions)}) must have the same length"
            )
        self.image_filenames = image_filenames
        self.captions = captions
        self.transforms = transforms

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, index):
        # Force 3-channel RGB and close the file handle: grayscale/RGBA/CMYK files would
        # otherwise yield tensors with a different channel count and fail to batch.
        with Image.open(self.image_filenames[index]) as img:
            image = img.convert("RGB")
        image = self.transforms(image)
        text = self.captions[index]
        return image, text


In [None]:
class ImageTextDataModule(LightningDataModule):
    """Serve (image tensor, caption) batches for training and validation.

    Args:
        train_paths, train_texts: aligned image paths / captions for training.
        valid_paths, valid_texts: aligned image paths / captions for validation.
        batch_size: samples per batch for both loaders.
        image_size: resize target, interpreted as (height, width) — the order torchvision ``Resize`` expects.
        num_workers: DataLoader worker processes (default 4).
    """

    def __init__(self, train_paths, train_texts, valid_paths, valid_texts, batch_size=CFG.batch_size, image_size=CFG.image_size, num_workers=4):
        super().__init__()
        self.train_paths = train_paths
        self.train_texts = train_texts
        self.valid_paths = valid_paths
        self.valid_texts = valid_texts
        self.batch_size = batch_size
        self.image_size = image_size
        self.num_workers = num_workers

        self.transform = transforms.Compose([
            # Use both entries of image_size; reading [0] twice silently ignores the width.
            transforms.Resize((self.image_size[0], self.image_size[1])),
            transforms.ToTensor()
        ])
        # The path/caption arrays hold one entry per sample; keep them out of the
        # saved hyperparameters (and the logger config).
        self.save_hyperparameters(ignore=["train_paths", "train_texts", "valid_paths", "valid_texts"])

    def setup(self, stage=None):
        self.train_dataset = ImageTextDataset(self.train_paths, self.train_texts, self.transform)
        self.valid_dataset = ImageTextDataset(self.valid_paths, self.valid_texts, self.transform)

    def train_dataloader(self):
        # drop_last keeps every training batch at exactly batch_size pairs.
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, drop_last=True, num_workers=self.num_workers, pin_memory=True)

    def val_dataloader(self):
        # NOTE(review): drop_last discards the final partial validation batch, presumably so
        # every batch matches get_ground_truth()'s default CFG.batch_size label matrix.
        return DataLoader(self.valid_dataset, batch_size=self.batch_size, shuffle=False, drop_last=True, num_workers=self.num_workers, pin_memory=True)


# 🔍 | Loss

# 🤖 | Modeling

## Projection Head

In [None]:
class ProjectionHead(nn.Module):
    """Project encoder features of width ``in_dim`` into the shared ``embedding_dim`` space.

    A linear projection is refined by GELU -> Linear -> Dropout, added back onto the
    projection as a residual, and finished with LayerNorm.
    """

    def __init__(self, in_dim, embedding_dim=CFG.embedding_dim, dropout=CFG.dropout):
        super().__init__()
        # Keep this creation order and these attribute names: they fix the seeded
        # parameter-initialisation order and the state_dict keys.
        self.projection = nn.Linear(in_dim, embedding_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(embedding_dim, embedding_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(embedding_dim)

    def forward(self, x):
        residual = self.projection(x)
        refined = self.dropout(self.fc(self.gelu(residual)))
        return self.layer_norm(refined + residual)

## Image Encoder

In [None]:
class ModifiedConvNeXt(nn.Module):
    """Image encoder: a timm backbone whose ``head.fc`` is replaced by ``nn.Identity``,
    followed by a ``ProjectionHead`` into the shared embedding space.

    NOTE(review): despite the class name, the backbone is whatever ``backbone_name``
    selects (CFG.image_preset is a TinyViT). The code assumes the model exposes
    ``head.fc`` and that its output after removing it is 576 wide — presumably
    matching the configured tiny_vit_21m preset; confirm when switching backbones.
    """

    def __init__(self, backbone_name, pretrained=True):
        super().__init__()
        self.backbone = timm.create_model(backbone_name, pretrained=pretrained)
        # Replace the final classifier layer so the backbone emits features, not class scores.
        self.backbone.head.fc = nn.Identity()
        feature_width = 576
        self.projection_head = ProjectionHead(feature_width)

    def forward(self, x):
        return self.projection_head(self.backbone(x))


def build_image_encoder(model_name=CFG.image_preset):
    """Build the image encoder for ``model_name`` with pretrained backbone weights."""
    return ModifiedConvNeXt(backbone_name=model_name, pretrained=True)

## Text Encoder

In [None]:
class ModifiedDistilBertModel(nn.Module):
    def __init__(self):
        super(ModifiedDistilBertModel, self).__init__()


        self.backbone = AutoModel.from_pretrained(CFG.text_preset)
        self.projection_head = ProjectionHead(384)

    def forward(self, x):
        x = self.backbone(**x).last_hidden_state.mean(dim=1)
        x = self.projection_head(x)
        return x


def build_text_encoder():
    model = ModifiedDistilBertModel()
    return model

## SigLIP Model

In [None]:
def get_ground_truth(batch_size=CFG.batch_size):
    """Return a (batch_size, batch_size) float matrix: +1 on the diagonal, -1 elsewhere.

    Entry (i, j) labels the pair (image i, caption j); only matching indices are positives.
    """
    return 2 * torch.eye(batch_size) - 1

class SigLIPLoss(nn.Module):
    """Pairwise sigmoid loss.

    ``forward(y_true, y_pred)`` takes +1/-1 labels and logits of the same shape, sums
    ``-log(sigmoid(y_true * y_pred))`` along the last axis, and averages those row sums.
    """

    def __init__(self, name="siglip_loss"):
        super().__init__()
        self.name = name

    def forward(self, y_true, y_pred):
        row_sums = F.logsigmoid(y_pred * y_true).sum(dim=-1)
        return -row_sums.mean()


class SigLIPModule(LightningModule):
    """Image/text dual encoder trained with the SigLIP pairwise sigmoid loss.

    Image and text embeddings are L2-normalised and compared by dot product. The
    results become logits ``exp(logit_scale) * cos_sim + logit_bias``, and every
    image/caption pair in the batch is scored by ``SigLIPLoss``.

    Args:
        image_encoder: module mapping an image batch to embedding vectors.
        text_encoder: module mapping tokenizer output to embedding vectors.
        tokenizer: callable used inside the steps to tokenize raw captions.
        logit_scale_init: initial *log* temperature; 2.30 ~= ln(10), i.e. a starting scale of ~10.
        logit_bias_init: initial additive logit bias.

    The encoders and tokenizer are excluded from the saved hyperparameters. Pass them
    explicitly to ``load_from_checkpoint``.
    """

    def __init__(self, image_encoder, text_encoder, tokenizer, logit_scale_init=2.30, logit_bias_init=-10.0):
        super().__init__()
        self.image_encoder = image_encoder
        self.text_encoder = text_encoder
        self.tokenizer = tokenizer
        self.logit_scale = torch.nn.Parameter(torch.tensor(logit_scale_init))
        self.logit_bias = torch.nn.Parameter(torch.tensor(logit_bias_init))
        self.loss_fn = SigLIPLoss()
        # Encoder weights already live in the checkpoint state_dict and the tokenizer is
        # not a hyperparameter; don't pickle them into hparams / the logger config.
        self.save_hyperparameters(ignore=["image_encoder", "text_encoder", "tokenizer"])

    def forward(self, images, encoded_texts):
        """Return the (num_images, num_texts) matrix of pairwise logits."""
        image_features = F.normalize(self.image_encoder(images), dim=-1)
        text_features = F.normalize(self.text_encoder(encoded_texts), dim=-1)
        logits = image_features @ text_features.T
        # logit_scale is learned in log space (initialised at ln 10 as in the SigLIP paper).
        return self.logit_scale.exp() * logits + self.logit_bias

    def _shared_step(self, batch):
        """Tokenize captions, compute logits and the loss; return (loss, batch size)."""
        images, texts = batch
        encoded_texts = self.tokenizer(
            texts,
            return_tensors='pt',
            padding="max_length",
            max_length=CFG.sequence_length,
            truncation=True,
        ).to(self.device)
        logits = self(images, encoded_texts)
        # Size the label matrix from the actual batch so partial batches are handled too.
        ground_truth = get_ground_truth(logits.size(0)).to(self.device)
        return self.loss_fn(ground_truth, logits), logits.size(0)

    def training_step(self, batch, batch_idx):
        loss, batch_size = self._shared_step(batch)
        self.log('train_loss', loss, on_step=True, on_epoch=False, logger=True, batch_size=batch_size)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, batch_size = self._shared_step(batch)
        # Report one epoch-level average rather than noisy per-step validation values.
        self.log('val_loss', loss, on_step=False, on_epoch=True, logger=True, batch_size=batch_size)
        return loss

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=CFG.lr)
        scheduler = {
            'scheduler': CosineAnnealingLR(optimizer, T_max=CFG.T_max, eta_min=CFG.eta_min),
            'interval': 'epoch',
        }
        return [optimizer], [scheduler]

In [None]:
# Data, experiment logger and caption tokenizer.
data_module = ImageTextDataModule(
    train_paths, train_texts, valid_paths, valid_texts,
    batch_size=CFG.batch_size,
)
wandb_logger = WandbLogger(project="SigLIP")

tokenizer = AutoTokenizer.from_pretrained(CFG.text_preset)

# Build the image encoder before the text encoder (this fixes the seeded init order).
image_encoder = build_image_encoder()
text_encoder = build_text_encoder()
model = SigLIPModule(
    image_encoder=image_encoder,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
)
model = model.to(CFG.device)

In [None]:
# Keep every epoch's checkpoint (save_top_k=-1) under checkpoints/, named from the
# epoch number; the evaluation cell loads them as checkpoints/epoch=<n>.ckpt.
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints/",
    filename="{epoch}",
    every_n_epochs=1,
    save_top_k=-1,
)

trainer = Trainer(
    callbacks=[checkpoint_callback],
    logger=wandb_logger,
    max_epochs=CFG.epochs,
)

trainer.fit(model, datamodule=data_module)

In [None]:
def process_image(path):
    """Load one image as a (1, C, H, W) tensor on ``CFG.device`` using the training transform.

    The image is converted to RGB first so grayscale/RGBA/CMYK files give 3 channels,
    as the image encoder expects.
    """
    with Image.open(path) as img:
        rgb = img.convert("RGB")
    tensor = data_module.transform(rgb)
    return tensor.unsqueeze(0).to(CFG.device)

def softmax(x):
    """Numerically stable softmax over the last axis of ``x``.

    The maximum is subtracted per row (not globally). Otherwise a row whose values lie
    far below the global maximum underflows to 0/0 = NaN.
    """
    shifted = x - np.max(x, axis=-1, keepdims=True)
    e_x = np.exp(shifted)
    return e_x / e_x.sum(axis=-1, keepdims=True)

def process_text(text):
    """Wrap each label as "a photo of a {label}" and tokenize the batch onto ``CFG.device``.

    Uses the same padding/truncation length as training (``CFG.sequence_length``).
    """
    prompts = [f"a photo of a {x}" for x in text]
    return tokenizer(prompts, return_tensors='pt', padding="max_length", max_length=CFG.sequence_length, truncation=True).to(CFG.device)

def zero_shot_classifier(image_path, labels):
    """Score ``labels`` for one image; return {label: percentage rounded to 2 decimals}.

    Percentages are a softmax of the model's image/text logits across the given labels.
    The global ``model`` is used as-is, so the caller must put it in eval mode.
    """
    # Materialise once: labels are iterated twice (prompt building and result pairing).
    labels = list(labels)
    image = process_image(image_path)
    text = process_text(labels)
    with torch.no_grad():
        logits = model(image, text).cpu().numpy()

    # One image -> logits has a single row. Index it instead of squeeze(), which would
    # collapse a single-label result to a 0-d array that zip() cannot iterate.
    probabilities = softmax(logits)[0]
    return {label: round(float(p) * 100, 2) for label, p in zip(labels, probabilities)}

In [None]:
TEST_IMAGE_DIR = '/kaggle/input/random-image-for-testing-classification'

# Each test image's file name without ".jpg" is its ground-truth label, and the set of
# those names is the candidate label list. sorted() makes the report order reproducible
# (os.listdir order is arbitrary).
files = sorted(i for i in os.listdir(TEST_IMAGE_DIR) if i.endswith('.jpg'))
classes = [i.replace('.jpg', '') for i in files]
for epoch in range(CFG.epochs):
    print(f'\nEPOCH: {epoch}\n')
    ckpt_path = f'checkpoints/epoch={epoch}.ckpt'
    if not os.path.exists(ckpt_path):
        print(f'Checkpoint {ckpt_path} not found, skipping')
        continue
    # Rebind the global `model`: zero_shot_classifier reads it.
    model = SigLIPModule.load_from_checkpoint(ckpt_path, image_encoder=image_encoder, text_encoder=text_encoder, tokenizer=tokenizer, logit_scale_init=2.30, logit_bias_init=-10.0).to(CFG.device).eval()
    c = 0
    for file in files:
        ans = file.replace('.jpg', '')
        out = zero_shot_classifier(os.path.join(TEST_IMAGE_DIR, file), classes)

        top_prediction = max(out, key=out.get)
        print("Top prediction:", top_prediction, out[top_prediction])
        if top_prediction == ans:
            c += 1
        print('GT pred:', ans, out[ans], '\n')
    print(f'Right preds {c}/{len(files)}')