In [2]:
# packages
import os
import cv2
import gc
import yaml
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import itertools
from tqdm.autonotebook import tqdm
import albumentations as A

import torch
from torch import nn
import torch.nn.functional as F
import timm
from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer

  from tqdm.autonotebook import tqdm


## Config

In [3]:
# Load the experiment configuration (the 'paths', 'model' and 'training' sections
# are read throughout the notebook).
with open("CLIP_config.yaml", mode="r") as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)

## Utils

In [4]:
class AvgMeter:
    """
    Running weighted average of a scalar metric.

    Attributes:
        name: label used when the meter is printed.
        sum: accumulated ``val * count`` over all updates.
        count: accumulated sample count.
        avg: ``sum / count`` after the most recent update (0 right after a reset).
    """
    def __init__(self, name="Metric"):
        self.name = name
        self.reset()

    def reset(self):
        """Zero the accumulated sum, sample count and average."""
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, count=1):
        """Fold in a value `val` observed over `count` samples and refresh the average."""
        self.sum += val * count
        self.count += count
        self.avg = self.sum / self.count

    def __repr__(self):
        return f"{self.name}: {self.avg:.4f}"

def get_lr(optimizer):
    """
    Return the learning rate of the optimizer's first parameter group.

    Returns None when the optimizer has no parameter groups.
    """
    first_group_lr = next((group["lr"] for group in optimizer.param_groups), None)
    return first_group_lr

## Dataset

In [5]:
class CLIPDataset(torch.utils.data.Dataset):
    """
    Image/caption pairs for CLIP training.

    All captions are tokenized once at construction time; images are read from
    disk lazily in ``__getitem__``.

    Args:
        image_filenames: sequence of image file names, resolved relative to
            ``config['paths']['image_path']``.
        captions: sequence of caption strings, aligned with ``image_filenames``.
        tokenizer: callable mapping a list of strings to a dict-like of
            per-caption token lists (called with ``padding``, ``truncation``
            and ``max_length``).
        transforms: optional callable invoked as ``transforms(image=img)['image']``;
            skipped when falsy.
    """
    def __init__(self, image_filenames, captions, tokenizer, transforms):
        self.image_names = image_filenames
        self.captions = captions
        # padding=True pads every caption to a common length, so per-item
        # tensors share one shape.
        self.encoded_captions = tokenizer(
            list(captions), padding=True, truncation=True, max_length=config['model']['max_length']
        )
        self.transforms = transforms

    def __getitem__(self, idx):
        """
        Return a dict with one tensor per tokenizer output key, the image as a
        float tensor under ``'image'`` and the raw caption under ``'caption'``.

        Raises:
            FileNotFoundError: if the image file cannot be read.
        """
        # Tokenizer outputs for this caption.
        item = {
            key: torch.tensor(values[idx]) for key, values in self.encoded_captions.items()
        }

        # Build the path with os.path.join: the previous f-string put spaces
        # around the separator ("dir / name"), which never matched a real file.
        image_path = os.path.join(config['paths']['image_path'], self.image_names[idx])
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transforms:
            image = self.transforms(image=image)['image']

        # Move the last axis (channels) to the front.
        item['image'] = torch.tensor(image).permute(2, 0, 1).float()

        # Raw caption, kept for display.
        item['caption'] = self.captions[idx]

        return item

    def __len__(self):
        return len(self.captions)

In [6]:
# transforms
def get_transforms(mode="train"):
    """
    Build the image preprocessing pipeline for the given split.

    The "train" pipeline and the pipeline for every other mode consist of the
    same two steps (square resize to ``config['model']['size']`` followed by
    normalization), so ``mode`` does not currently change the result.
    """
    side = config['model']['size']
    steps = [
        A.Resize(side, side, always_apply=True),
        A.Normalize(max_pixel_value=255.0, always_apply=True),
    ]
    return A.Compose(steps)

## Image Encoder

In [7]:
class ImageEncoder(nn.Module):
    """
    Image backbone (created through timm) that turns an image batch into one
    fixed-size feature vector per image.

    Args:
        model_name: timm model identifier.
        pretrained: forwarded to ``timm.create_model``.
        trainable: value assigned to ``requires_grad`` of every backbone parameter.
    """

    def __init__(self,
                 model_name=config['model']['model_name'],
                 pretrained=config['training']['pretrained'],
                 trainable=config['training']['trainable']):
        super().__init__()
        backbone_kwargs = dict(pretrained=pretrained, num_classes=0, global_pool='avg')
        self.model = timm.create_model(model_name, **backbone_kwargs)
        # Freeze or unfreeze the whole backbone in one pass.
        for param in self.model.parameters():
            param.requires_grad = trainable

    def forward(self, x):
        features = self.model(x)
        return features

## Text Encoder

In [8]:
class TextEncoder(nn.Module):
    """
    Encode text (captions) with a `DistilBertModel`.

    The hidden state at token position 0 of the last layer is returned as the
    sentence embedding.

    Args:
        model_name: checkpoint name passed to ``DistilBertModel.from_pretrained``
            (only used when ``pretrained`` is truthy).
        pretrained: load pretrained weights if truthy, otherwise build the model
            from a default ``DistilBertConfig``.
        trainable: value assigned to ``requires_grad`` of every model parameter.
    """

    def __init__(self,
                 model_name=config['model']['text_encoder_model'],
                 pretrained=config['training']['pretrained'],
                 trainable=config['training']['trainable']):
        super().__init__()
        self.model = (
            DistilBertModel.from_pretrained(model_name)
            if pretrained
            else DistilBertModel(config=DistilBertConfig())
        )

        for param in self.model.parameters():
            param.requires_grad = trainable

        # Position of the token used as the sentence embedding (the CLS token).
        self.target_token_idx = 0

    def forward(self, input_ids, attention_mask):
        encoded = self.model(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = encoded.last_hidden_state
        return hidden_states[:, self.target_token_idx, :]

## Projection Head

In [9]:
class ProjectionHead(nn.Module):
    """
    Map an encoder embedding to the shared projection space.

    The input is linearly projected, passed through GELU -> Linear -> Dropout,
    added back to the linear projection (residual connection) and layer-normalized.

    Args:
        embedding_dim: size of the incoming embedding.
        projection_dim: size of the output embedding.
        dropout: dropout probability applied before the residual addition.
    """

    def __init__(self,
                 embedding_dim,
                 projection_dim=config['model']['projection_dim'],
                 dropout=config['model']['dropout']):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        projected = self.projection(x)
        refined = self.dropout(self.fc(self.gelu(projected)))
        return self.layer_norm(refined + projected)

## CLIP

![CLIP_structure](CLIP.png)

- similarity
$$
\begin{aligned}
    \text{logits} &= \frac{1}{T} \, E_{\text{text}} \, E_{\text{image}}^T \\
    \text{targets} &= \text{softmax}\left(\frac{T}{2} \left(E_{\text{text}} \, E_{\text{text}}^T + E_{\text{image}} \, E_{\text{image}}^T\right)\right) \\
    \text{Loss} &= \frac{1}{2 |\mathcal{B}|} \left(\text{CrossEntropy}(\text{logits}, \text{targets}) + \text{CrossEntropy}(\text{logits}^T, \text{targets}^T)\right)
\end{aligned}
$$


In [10]:
def cross_entropy(preds, targets, reduction='none'):
    """
    Cross entropy between logits and soft (probability) targets.

    Args:
        preds: logits; log-softmax is taken over the last dimension.
        targets: target weights with the same shape as ``preds``.
        reduction: ``'none'`` returns one loss per row (sum over dimension 1);
            ``'mean'`` returns the mean of those per-row losses.

    Returns:
        The per-row loss tensor, or its scalar mean.

    Raises:
        ValueError: if ``reduction`` is neither ``'none'`` nor ``'mean'``
            (previously this case silently returned None).
    """
    # Functional form: no need to build an nn.LogSoftmax module on every call.
    loss = (-targets * F.log_softmax(preds, dim=-1)).sum(1)
    if reduction == 'none':
        return loss
    if reduction == 'mean':
        return loss.mean()
    raise ValueError(f"Unsupported reduction: {reduction!r} (expected 'none' or 'mean')")



class CLIPModel(nn.Module):
    """
    CLIP model: an image encoder and a text encoder, each followed by a
    projection head, trained with a symmetric soft-target cross-entropy loss.

    Args:
        temperature: divides the text/image logits and multiplies the averaged
            self-similarities before the target softmax.
        image_embedding: output size of the image encoder (input size of the
            image projection head).
        text_embedding: output size of the text encoder (input size of the
            text projection head).
    """
    def __init__(self,
                 temperature=config['training']['temperature'],
                 image_embedding=config['model']['image_embedding'],
                 text_embedding=config['model']['text_embedding']):
        super().__init__()
        self.text_encoder = TextEncoder()
        self.text_proj = ProjectionHead(embedding_dim=text_embedding)
        self.image_encoder = ImageEncoder()
        self.image_proj = ProjectionHead(embedding_dim=image_embedding)
        self.temperature = temperature

    # Read-only aliases: the heads are registered as `image_proj` / `text_proj`
    # (those are the state-dict names), but parts of the notebook refer to them
    # as `image_projection` / `text_projection`. Properties avoid registering
    # the same module twice.
    @property
    def image_projection(self):
        """Alias of ``image_proj``."""
        return self.image_proj

    @property
    def text_projection(self):
        """Alias of ``text_proj``."""
        return self.text_proj

    def forward(self, batch):
        """
        Compute the mean contrastive loss for one batch.

        Args:
            batch: dict with ``'image'``, ``'input_ids'`` and ``'attention_mask'``
                tensors.

        Returns:
            Scalar loss tensor.
        """
        # Encoder features.
        image_features = self.image_encoder(batch['image'])
        text_features = self.text_encoder(
            input_ids=batch['input_ids'], attention_mask=batch['attention_mask']
        )

        # Project both modalities into the same embedding space.
        image_embeddings = self.image_proj(image_features)
        text_embeddings = self.text_proj(text_features)

        # Rows index texts, columns index images: (batch_size, batch_size).
        logits = text_embeddings @ image_embeddings.T / self.temperature
        image_sim = image_embeddings @ image_embeddings.T  # (batch_size, batch_size)
        text_sim = text_embeddings @ text_embeddings.T
        # Soft targets built from the average of the two self-similarity matrices.
        targets = F.softmax((image_sim + text_sim) / 2 * self.temperature, dim=-1)

        text_loss = cross_entropy(logits, targets)
        image_loss = cross_entropy(logits.T, targets.T)

        loss = (image_loss + text_loss) / 2  # (batch_size)
        return loss.mean()


## Training

In [11]:
def get_dataframes():
    """
    Load the captions table and split it into training and validation frames.

    Reads ``captions.csv`` from ``config['paths']['captions_path']``, assigns a
    1-based ``id`` to every row and holds out 20% of the candidate ids (chosen
    with NumPy's global RNG seeded to 42) for validation. When
    ``config['training']['debug']`` is truthy only ids below 100 are used.

    Returns:
        (train_dataframe, valid_dataframe), both with a fresh 0-based index.
    """
    dataframe = pd.read_csv(f"{config['paths']['captions_path']}/captions.csv")
    dataframe['id'] = range(1, len(dataframe) + 1)
    max_id = dataframe["id"].max() + 1 if not config['training']['debug'] else 100
    # Row ids start at 1, so the candidates must too; starting at 0 put a
    # phantom id into the pool that could be drawn into the validation ids.
    image_ids = np.arange(1, max_id)
    np.random.seed(42)

    # Training set & validation set split.
    valid_ids = np.random.choice(
        image_ids, size=int(0.2 * len(image_ids)), replace=False
    )
    # Vectorized membership masks replace the per-id `x in ndarray` scan,
    # which was quadratic in the number of rows.
    is_candidate = dataframe["id"].isin(image_ids)
    is_valid = dataframe["id"].isin(valid_ids)
    train_dataframe = dataframe[is_candidate & ~is_valid].reset_index(drop=True)
    valid_dataframe = dataframe[is_valid].reset_index(drop=True)
    return train_dataframe, valid_dataframe

def build_loaders(dataframe, tokenizer, mode):
    """
    Wrap a captions dataframe in a `CLIPDataset` and return a DataLoader over it.

    Batch size and worker count come from ``config['training']``; batches are
    shuffled only when ``mode`` is "train".
    """
    pipeline = get_transforms(mode=mode)
    clip_dataset = CLIPDataset(
        dataframe["image"].values,
        dataframe["caption"].values,
        tokenizer=tokenizer,
        transforms=pipeline,
    )
    loader_options = dict(
        batch_size=config['training']['batch_size'],
        num_workers=config['training']['num_workers'],
        shuffle=(mode == "train"),
    )
    return torch.utils.data.DataLoader(clip_dataset, **loader_options)

In [12]:
def train(model, train_loader, optimizer, lr_scheduler, step, device=config['training']['device']):
    """
    Run one training epoch.

    Args:
        model: module whose forward takes a batch dict and returns a scalar loss.
        train_loader: iterable of batch dicts; the ``'caption'`` entry is
            dropped and every other value is moved to ``device``.
        optimizer: optimizer stepped once per batch.
        lr_scheduler: scheduler stepped once per batch when ``step == 'batch'``.
        step: scheduler granularity; any value other than ``'batch'`` leaves
            the scheduler untouched here.
        device: device the batch tensors are moved to.

    Returns:
        AvgMeter holding the sample-weighted mean training loss of the epoch.
    """
    loss_meter = AvgMeter()
    tqdm_obj = tqdm(train_loader, total=len(train_loader))
    for batch in tqdm_obj:
        # Drop the raw caption strings (kept only for display); they are not tensors.
        batch = {k: v.to(device) for k, v in batch.items() if k != 'caption'}
        # Clear gradients left over from the previous batch; without this they
        # accumulate and every step uses the sum of all earlier gradients.
        optimizer.zero_grad()
        loss = model(batch)
        loss.backward()
        optimizer.step()
        if step == 'batch':
            lr_scheduler.step()
        count = batch["image"].size(0)
        loss_meter.update(loss.item(), count)

        tqdm_obj.set_postfix(train_loss=loss_meter.avg, lr=get_lr(optimizer))
    return loss_meter

def valid_epoch(model, valid_loader, device=config['training']['device']):
    """
    Evaluate the model over one pass of the validation loader.

    The ``'caption'`` entry of each batch is dropped and the remaining values
    are moved to ``device``. Returns an AvgMeter with the sample-weighted mean
    validation loss.
    """
    meter = AvgMeter()

    progress = tqdm(valid_loader, total=len(valid_loader))
    for raw_batch in progress:
        inputs = {}
        for key, value in raw_batch.items():
            if key == "caption":
                continue
            inputs[key] = value.to(device)

        batch_loss = model(inputs)
        meter.update(batch_loss.item(), inputs["image"].size(0))

        progress.set_postfix(valid_loss=meter.avg)
    return meter

# training
def training():
    """
    Train the CLIP model end to end.

    Builds the train/validation loaders, creates the model and an AdamW
    optimizer with separate learning rates for the image encoder, the text
    encoder and the two projection heads, then trains for
    ``config['training']['epochs']`` epochs. After every epoch the model is
    validated, the weights are saved to "best.pt" whenever the validation loss
    improves, and the plateau scheduler is stepped on the validation loss.
    """
    train_cfg = config['training']

    train_df, valid_df = get_dataframes()
    tokenizer = DistilBertTokenizer.from_pretrained(config['model']['text_tokenizer'])
    train_loader = build_loaders(train_df, tokenizer, mode="train")
    valid_loader = build_loaders(valid_df, tokenizer, mode="valid")

    model = CLIPModel().to(train_cfg['device'])

    # Both projection heads share one parameter group (and its weight decay).
    head_params = itertools.chain(
        model.image_proj.parameters(), model.text_proj.parameters()
    )
    param_groups = [
        {"params": model.image_encoder.parameters(), "lr": train_cfg['image_encoder_lr']},
        {"params": model.text_encoder.parameters(), "lr": train_cfg['text_encoder_lr']},
        {"params": head_params, "lr": train_cfg['head_lr'], "weight_decay": train_cfg['weight_decay']},
    ]
    optimizer = torch.optim.AdamW(param_groups, weight_decay=0.)
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", patience=train_cfg['patience'], factor=train_cfg['factor']
    )
    # The scheduler is stepped once per epoch (below), not per batch.
    step = "epoch"

    best_loss = float('inf')

    for epoch_idx in range(train_cfg['epochs']):
        print(f"Epoch: {epoch_idx + 1}")

        model.train()
        train(model, train_loader, optimizer, lr_scheduler, step)

        model.eval()
        with torch.no_grad():
            valid_loss = valid_epoch(model, valid_loader)

        improved = valid_loss.avg < best_loss
        if improved:
            best_loss = valid_loss.avg
            torch.save(model.state_dict(), "best.pt")
            print("Saved Best Model!")

        lr_scheduler.step(valid_loss.avg)


In [None]:
# Run the full training loop; the weights with the lowest validation loss are
# written to "best.pt".
training()

Epoch: 1


  0%|          | 0/6473 [00:00<?, ?it/s]

## Inference

In [None]:
def get_image_embeddings(valid_df, model_path):
    """
    Load a trained CLIP model and embed every image of a dataframe.

    Args:
        valid_df: dataframe with ``image`` and ``caption`` columns (as consumed
            by `build_loaders`).
        model_path: path of a state dict saved from a `CLIPModel`.

    Returns:
        (model, image_embeddings): the model in eval mode and the projected
        image embeddings of all batches concatenated along dimension 0, in
        loader order (the "valid" loader is not shuffled).
    """
    # The tokenizer is only needed because the dataset tokenizes captions on construction.
    tokenizer = DistilBertTokenizer.from_pretrained(config['model']['text_tokenizer'])
    valid_loader = build_loaders(valid_df, tokenizer, mode="valid")

    model = CLIPModel().to(config['training']['device'])
    model.load_state_dict(torch.load(model_path, map_location=config['training']['device']))
    model.eval()

    valid_image_embeddings = []
    with torch.no_grad():
        for batch in tqdm(valid_loader):
            image_features = model.image_encoder(batch["image"].to(config['training']['device']))
            # The head is registered on CLIPModel as `image_proj`
            # (`image_projection` is not an attribute set in its __init__).
            image_embeddings = model.image_proj(image_features)
            valid_image_embeddings.append(image_embeddings)
    return model, torch.cat(valid_image_embeddings)

In [None]:
# Rebuild the validation split (get_dataframes seeds its RNG, so the split is
# repeatable) and embed its images with the saved best checkpoint.
_, valid_df = get_dataframes()
model, image_embeddings = get_image_embeddings(valid_df, "best.pt")

## Finding Matches

In [None]:
def find_matches(model, image_embeddings, query, image_filenames, n=9):
    """
    Retrieve and plot the images that best match a text query.

    The query is embedded with the model's text encoder and projection head,
    compared with ``image_embeddings`` by cosine similarity, and the selected
    images are drawn on a 3x3 grid.

    Args:
        model: trained `CLIPModel`.
        image_embeddings: projected image embeddings, one row per entry of
            ``image_filenames``.
        query: free-text query string.
        image_filenames: file names relative to ``config['paths']['image_path']``,
            aligned with ``image_embeddings``.
        n: number of images to select (the grid holds at most 9).
    """
    tokenizer = DistilBertTokenizer.from_pretrained(config['model']['text_tokenizer'])
    encoded_query = tokenizer([query])
    batch = {
        key: torch.tensor(values).to(config['training']['device'])
        for key, values in encoded_query.items()
    }
    with torch.no_grad():
        text_features = model.text_encoder(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )
        # The head is registered on CLIPModel as `text_proj`
        # (`text_projection` is not an attribute set in its __init__).
        text_embeddings = model.text_proj(text_features)

    # L2-normalize both sides so the dot product is a cosine similarity.
    image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
    text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
    dot_similarity = text_embeddings_n @ image_embeddings_n.T

    # Take 5x more hits than needed and keep every 5th one — presumably each
    # image appears 5 times (once per caption), so this skips duplicates;
    # TODO confirm against the captions file.
    # Clamp k so topk does not fail when fewer embeddings are available.
    k = min(n * 5, dot_similarity.shape[-1])
    values, indices = torch.topk(dot_similarity.squeeze(0), k)
    matches = [image_filenames[idx] for idx in indices[::5]]

    _, axes = plt.subplots(3, 3, figsize=(10, 10))
    # Hide every axis up front so unused grid cells stay blank.
    for ax in axes.flatten():
        ax.axis("off")
    for match, ax in zip(matches, axes.flatten()):
        # `config['paths']['image_path']` is already the directory string;
        # the previous `.image_path` attribute access on it raised AttributeError.
        image = cv2.imread(f"{config['paths']['image_path']}/{match}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        ax.imshow(image)

    plt.show()

In [None]:
# Show the best-matching validation images for a free-text query.
find_matches(model,
             image_embeddings,
             query="a group of people dancing in a party",
             image_filenames=valid_df['image'].values,
             n=9)