In [2]:
import os
import sys
import numpy as np
import torch

from datetime import datetime
from typing import Tuple
from torch import nn
from torch.nn import Module
import torch.nn.functional as F
from torch.nn import KLDivLoss, CrossEntropyLoss, CosineEmbeddingLoss, MSELoss
from torch.optim import Adam, SGD
from torch.optim.lr_scheduler import ReduceLROnPlateau

### Combined Image and TextDistillation with Bigger Dataset with Loss3


In [3]:
loss_name = "loss3"

#### Loading Teacher models

In [4]:
import clip

model_name = "ViT-B/32"

# model is the torch model.
# preprocess function is for image preprocessing.

model, preprocess = clip.load(model_name)

# Get only the visual model
visual_teacher_model = model.visual
text_teacher_model = model.transformer


input_resolution = model.visual.input_resolution

### Instantiating Student models

[VisionTransformer](https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/model.py#L206)
[Transformer](https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/model.py#L195)

In [5]:
from clip.model import Transformer, VisionTransformer
from clip.model import convert_weights  # Make them float16

# Set Student Configuration

patch_size = 32
width = 384
layers = 6
heads = 12
output_dim = 512

visual_student_model = VisionTransformer(
    input_resolution=input_resolution,
    patch_size=patch_size,
    width=width,
    layers=layers,
    heads=heads,
    output_dim=output_dim,
)

width = 512
layers = 6
heads = 8  # More Number of Heads


def build_attention_mask():
    context_length = 77
    mask = torch.empty(context_length, context_length)
    mask.fill_(float("-inf"))
    mask.triu_(1)  # zero out the lower diagonal
    return mask


text_student_model = Transformer(
    width=width, layers=layers, heads=heads, attn_mask=build_attention_mask()
)


convert_weights(visual_student_model)
convert_weights(text_student_model)

In [6]:
def encode_text(transformer, text):

    x = model.token_embedding(text).type(model.dtype)  # [batch_size, n_ctx, d_model]

    x = x + model.positional_embedding.type(model.dtype)
    x = x.permute(1, 0, 2)  # NLD -> LND

    x = transformer(x)

    x = x.permute(1, 0, 2)  # LND -> NLD
    x = model.ln_final(x).type(model.dtype)

    # x.shape = [batch_size, n_ctx, transformer.width]
    # take features from the eot embedding (eot_token is the highest number in each sequence)
    x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ model.text_projection

    return x

In [7]:
class ProjectionHead(nn.Module):
    def __init__(self, embedding_dim=512, projection_dim=512, dropout=0.2):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        projected = self.projection(x)
        x = self.gelu(projected)
        x = self.fc(x)
        x = self.dropout(x)
        x = x + projected
        x = self.layer_norm(x)
        return x

### Load the WIT Dataset

### Convert Dataset into Torch Dataloader

In [8]:
from torch.utils.data import DataLoader
from datasets import load_from_disk

dset = load_from_disk("./data/processed")


def transform_func(examples):
    examples["image"] = [preprocess(img) for img in examples["image"]]
    return examples


dset = dset.with_transform(transform_func)

train_dataloader = DataLoader(dset, batch_size=16, shuffle=True, num_workers=8)

In [9]:
dset

Dataset({
    features: ['caption', 'image'],
    num_rows: 7012
})

In [10]:
class DistillationTrainer:
    def __init__(self, *args, **kwargs):

        self.visual_teacher = visual_teacher_model
        self.text_teacher = text_teacher_model

        self.visual_student = visual_student_model
        self.text_student = text_student_model

        self.image_projection = ProjectionHead().half()
        self.text_projection = ProjectionHead().half()
        self.temperature = 1

        self.train_dataloader = train_dataloader

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.visual_teacher = self.visual_teacher.to(self.device)
        self.text_teacher = self.text_teacher.to(self.device)
        self.visual_student = self.visual_student.to(self.device)
        self.text_student = self.text_student.to(self.device)
        self.image_projection = self.image_projection.to(self.device)
        self.text_projection = self.text_projection.to(self.device)

        self.visual_teacher.eval()
        self.text_teacher.eval()

        self.epochs = 30
        self.start_epoch = 1

        # set up optimizer
        self.optimizer = SGD(
            list(self.visual_student.parameters())
            + list(self.text_student.parameters()),
            lr=0.001,
        )

    def compute_loss(self, images, texts, return_outputs=False):
        texts = clip.tokenize(texts)

        texts = texts.to(self.device)

        images = images.to(self.device).half()

        visual_outputs_student = self.visual_student(images)
        text_outputs_student = encode_text(self.text_student, texts)

        # compute teacher output
        with torch.no_grad():
            visual_outputs_teacher = self.visual_teacher(images)
            text_outputs_teacher = model.encode_text(texts)

        # assert size
        assert visual_outputs_student.size() == visual_outputs_teacher.size()
        assert text_outputs_student.size() == text_outputs_teacher.size()

        # Get the image and text embeddings

        image_embeddings = self.image_projection(visual_outputs_teacher)
        text_embeddings = self.text_projection(text_outputs_teacher)
        teacher_logits = (text_embeddings @ image_embeddings.T) / self.temperature

        image_embeddings = self.image_projection(visual_outputs_student)
        text_embeddings = self.text_projection(text_outputs_student)
        logits = (text_embeddings @ image_embeddings.T) / self.temperature

        # KL Divergence Loss between logits of teacher and student

        kl_loss = KLDivLoss(reduction="batchmean", log_target=True)

        logits_kl_loss = kl_loss(F.log_softmax(teacher_logits), F.log_softmax(logits))

        # Push visual_outputs_student and text_outputs_student closer

        images_similarity = image_embeddings @ image_embeddings.T
        texts_similarity = text_embeddings @ text_embeddings.T

        targets = F.softmax(
            (images_similarity + texts_similarity) / 2 * self.temperature, dim=-1
        )

        texts_loss = self.cross_entropy(logits, targets, reduction="none")
        images_loss = self.cross_entropy(logits.T, targets.T, reduction="none")
        con_loss = (images_loss + texts_loss) / 2.0
        con_loss = con_loss.mean()

        loss = (logits_kl_loss + con_loss) / 2

        return loss

    def cross_entropy(self, preds, targets, reduction="none"):
        log_softmax = nn.LogSoftmax(dim=-1)
        loss = (-targets * log_softmax(preds)).sum(1)
        if reduction == "none":
            return loss
        elif reduction == "mean":
            return loss.mean()

    def train(self):
        for epoch in range(self.start_epoch, self.epochs + 1):
            print(f"Starting Epoch {epoch} ------------------------------------------")
            loss_value = self._train_epoch(epoch)
            print(f"Combined Loss Value after {epoch} Epoch is {loss_value}")

    def _train_epoch(self, epoch):
        loss_value = 0
        for batch_idx, data in enumerate(self.train_dataloader):

            self.optimizer.zero_grad()

            texts = data["caption"]
            images = data["image"]

            loss = self.compute_loss(images, texts)

            loss_value += loss

            loss.backward()

            self.optimizer.step()

            if batch_idx % 100 == 0:
                print(
                    f"Loss after {batch_idx}/{len(self.train_dataloader)} Batch is {loss_value/(batch_idx+1)} "
                )

        return loss_value.detach().cpu().numpy() / len(self.train_dataloader)

In [11]:
Trainer = DistillationTrainer(
    visual_teacher_model=visual_teacher_model,
    text_teacher_model=text_teacher_model,
    text_student_model=text_student_model,
    visual_student_model=visual_student_model,
    train_dataloader=train_dataloader,
)

In [12]:
Trainer.train()

Starting Epoch 1 ------------------------------------------




Loss after 0/439 Batch is 21.609375 
Loss after 100/439 Batch is 24.75 
Loss after 200/439 Batch is 22.328125 
Loss after 300/439 Batch is 21.625 
Loss after 400/439 Batch is 21.25 
Combined Loss Value after 1 Epoch is 20.9749430523918
Starting Epoch 2 ------------------------------------------
Loss after 0/439 Batch is 14.9375 
Loss after 100/439 Batch is 17.875 
Loss after 200/439 Batch is 17.671875 
Loss after 300/439 Batch is 17.25 
Loss after 400/439 Batch is 17.203125 
Combined Loss Value after 2 Epoch is 17.047835990888384
Starting Epoch 3 ------------------------------------------
Loss after 0/439 Batch is 6.7421875 
Loss after 100/439 Batch is 15.421875 
Loss after 200/439 Batch is 15.515625 
Loss after 300/439 Batch is 15.75 
Loss after 400/439 Batch is 15.7890625 
Combined Loss Value after 3 Epoch is 15.854214123006834
Starting Epoch 4 ------------------------------------------
Loss after 0/439 Batch is 19.09375 
Loss after 100/439 Batch is 15.109375 
Loss after 200/439 Batc

Loss after 0/439 Batch is 8.875 
Loss after 100/439 Batch is 10.75 
Loss after 200/439 Batch is 10.796875 
Loss after 300/439 Batch is 11.015625 
Loss after 400/439 Batch is 10.8203125 
Combined Loss Value after 28 Epoch is 10.76993166287016
Starting Epoch 29 ------------------------------------------
Loss after 0/439 Batch is 9.734375 
Loss after 100/439 Batch is 11.0 
Loss after 200/439 Batch is 10.65625 
Loss after 300/439 Batch is 10.65625 
Loss after 400/439 Batch is 10.796875 
Combined Loss Value after 29 Epoch is 10.788154897494305
Starting Epoch 30 ------------------------------------------
Loss after 0/439 Batch is 9.2578125 
Loss after 100/439 Batch is 10.640625 
Loss after 200/439 Batch is 11.0859375 
Loss after 300/439 Batch is 10.984375 
Loss after 400/439 Batch is 10.8046875 
Combined Loss Value after 30 Epoch is 10.751708428246014


In [13]:
from datetime import datetime

format_data = "date_%d_%m_%y_time_%H_%M_%S"
timestamp = datetime.strftime(datetime.now(), format_data)

torch.save(
    Trainer.visual_student.state_dict(),
    f"results/{timestamp}_{loss_name}_CombinedVisual_DistilledModel.pt",
)
torch.save(
    Trainer.text_student.state_dict(),
    f"results/{timestamp}_{loss_name}_CombinedText_DistilledModel.pt",
)