In [1]:
import os
import sys
import numpy as np
import torch

from datetime import datetime
from typing import Tuple
from torch import nn
from torch.nn import Module
import torch.nn.functional as F
from torch.nn import KLDivLoss, CrossEntropyLoss, CosineEmbeddingLoss,MSELoss
from torch.optim import Adam,SGD
from torch.optim.lr_scheduler import ReduceLROnPlateau

### Load the teacher model: the CLIP text transformer

In [32]:
import clip

model_name = "ViT-B/32"

# `model` is the full CLIP model; `preprocess` is its image preprocessing
# transform (not used by the text-only cells shown here).
model, preprocess = clip.load(model_name)

# Freeze every CLIP parameter. The teacher is never trained, and `encode_text`
# routes the student through CLIP's token/positional embeddings, final LayerNorm
# and text projection. Without freezing, every student backward pass would add
# gradients to those shared tensors. The optimizer never zeroes them, because it
# only owns the student's parameters, so they would keep accumulating.
model.requires_grad_(False)

# The teacher is CLIP's text transformer stack.
teacher_model = model.transformer

print(
    "Model parameters:",
    f"{sum(p.numel() for p in teacher_model.parameters()):,}",
)



Model parameters: 37,828,608


### Instantiate the student model

[Transformer](https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/model.py#L195)

In [79]:
from clip.model import Transformer
from clip.model import convert_weights # Make them float16
# Student configuration. The width must equal CLIP's embedding width, because
# `encode_text` feeds CLIP's token embeddings straight into the student stack.
# 8 heads over a width of 512 gives 64 dimensions per head; depth is reduced to 6 layers.
width, layers, heads = 512, 6, 8

def build_attention_mask(context_length=77):
    """Build a causal (look-ahead) additive attention mask.

    Args:
        context_length: sequence length of the mask. Defaults to 77, the value the
            original notebook hard-coded (CLIP's text context length).

    Returns:
        A float32 tensor of shape (context_length, context_length). It is 0 on and
        below the diagonal and -inf strictly above it. When added to attention
        logits, it stops position i from attending to any position j > i.
    """
    mask = torch.full((context_length, context_length), float("-inf"))
    # Keep -inf strictly above the diagonal; zero the diagonal and everything below it.
    return mask.triu_(1)
    
# The student reuses CLIP's Transformer class with the configuration above
# and a causal attention mask.
student_model = Transformer(
    width=width,
    layers=layers,
    heads=heads,
    attn_mask=build_attention_mask(),
)

# Cast the student's weights to half precision in place.
convert_weights(student_model)

print(
    "Model parameters:",
    f"{sum(p.numel() for p in student_model.parameters()):,}",
)


Model parameters: 18,914,304


In [80]:
def encode_text(transformer, text, clip_model=None):
    """Encode tokenised text with CLIP's text pipeline around a pluggable transformer.

    The token embedding, positional embedding, final LayerNorm and text
    projection always come from ``clip_model``. Only the transformer stack in the
    middle is swapped, so teacher and student are compared on identical inputs.

    Args:
        transformer: module applied to a (seq_len, batch, width) tensor and
            expected to return the same layout.
        text: integer token ids of shape (batch, seq_len).
        clip_model: CLIP model that supplies the shared layers. Defaults to the
            notebook-level ``model``.

    Returns:
        Tensor of shape (batch, embed_dim) holding the projected feature at the
        position of each sequence's highest token id (the end-of-text token).
    """
    # Only touch the notebook global when no model was passed explicitly.
    clip_model = model if clip_model is None else clip_model
    dtype = clip_model.dtype

    x = clip_model.token_embedding(text).type(dtype)  # [batch_size, n_ctx, d_model]
    x = x + clip_model.positional_embedding.type(dtype)

    x = x.permute(1, 0, 2)  # NLD -> LND
    x = transformer(x)
    x = x.permute(1, 0, 2)  # LND -> NLD

    x = clip_model.ln_final(x).type(dtype)

    # Take features from the EOT embedding (the EOT token has the highest id in each sequence).
    eot_positions = text.argmax(dim=-1)
    return x[torch.arange(x.shape[0]), eot_positions] @ clip_model.text_projection

### Load the Conceptual Captions dataset

In [81]:
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import io
import urllib

import PIL.Image

from datasets import load_dataset
from datasets.utils.file_utils import get_datasets_user_agent

num_threads = 20  # not used by the cells shown here

# First 10,000 Conceptual Captions training rows; only the captions are kept in use.
dset = (
    load_dataset("conceptual_captions", split='train[:10000]')
    .remove_columns("image_url")
    # Keep captions shorter than 75 *characters* (not tokens), presumably as a
    # conservative proxy for CLIP's 77-token context limit.
    .filter(lambda example: len(example["caption"]) < 75)
)

No config specified, defaulting to: conceptual_captions/unlabeled
Reusing dataset conceptual_captions (/home/ecbm4040/.cache/huggingface/datasets/conceptual_captions/unlabeled/1.0.0/05266784888422e36944016874c44639bccb39069c2227435168ad8b02d600d8)
Loading cached processed dataset at /home/ecbm4040/.cache/huggingface/datasets/conceptual_captions/unlabeled/1.0.0/05266784888422e36944016874c44639bccb39069c2227435168ad8b02d600d8/cache-d49e7748d48fd605.arrow


In [82]:
from torch.utils.data import DataLoader

# Shuffled batches of 16 caption rows, loaded by 8 DataLoader workers.
train_dataloader = DataLoader(
    dset,
    batch_size=16,
    shuffle=True,
    num_workers=8,
)

In [83]:
class TextDistillationTrainer:
    """Distil a CLIP text transformer (teacher) into a smaller transformer (student).

    Both models are run through ``encode_text``, so they share CLIP's token
    embedding, positional embedding, final LayerNorm and text projection. Only
    the transformer stack differs, and only the student's parameters are optimised.

    Keyword arguments (all optional; positional args and unknown keys are ignored):
        teacher_model: transformer used for the teacher forward pass
            (defaults to the notebook-level ``teacher_model``).
        student_model: transformer being trained (defaults to ``student_model``).
        train_dataloader: iterable of batches with a ``"caption"`` entry
            (defaults to ``train_dataloader``).
        epochs: last epoch number to run, inclusive (default 30).
        start_epoch: first epoch number (default 1).
        lr: SGD learning rate (default 0.001).
    """

    def __init__(self, *args, **kwargs):
        # Honour the components passed by the caller, falling back to the notebook
        # globals. Each conditional reads a global only when its key is absent.
        self.teacher = kwargs["teacher_model"] if "teacher_model" in kwargs else teacher_model
        self.student = kwargs["student_model"] if "student_model" in kwargs else student_model
        self.train_dataloader = (
            kwargs["train_dataloader"] if "train_dataloader" in kwargs else train_dataloader
        )

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.teacher = self.teacher.to(self.device)
        self.student = self.student.to(self.device)
        self.teacher.eval()

        self.epochs = kwargs.get("epochs", 30)
        self.start_epoch = kwargs.get("start_epoch", 1)

        # Only the student is optimised.
        # NOTE(review): if the student weights are float16 (convert_weights), small
        # SGD steps may underflow; consider fp32 master weights — TODO confirm.
        self.optimizer = SGD(self.student.parameters(), lr=kwargs.get("lr", 0.001))

        # Stateless loss modules, built once instead of on every batch.
        self.kl_loss = KLDivLoss(reduction="batchmean", log_target=True)
        self.cosine_loss = CosineEmbeddingLoss()

    def compute_loss(self, texts, return_outputs=False):
        """Return the distillation loss (KL divergence + cosine embedding) for a batch.

        Args:
            texts: batch of caption strings, tokenised with ``clip.tokenize``.
            return_outputs: accepted for interface compatibility; unused.

        Returns:
            Scalar loss tensor attached to the student's autograd graph.
        """
        tokens = clip.tokenize(texts).to(self.device)

        outputs_student = encode_text(self.student, tokens)

        # Teacher embeddings are targets only, so no graph is built for them.
        with torch.no_grad():
            outputs_teacher = encode_text(self.teacher, tokens)

        assert outputs_student.size() == outputs_teacher.size()

        # KL divergence between the softmax-normalised embeddings over the feature
        # axis. The dim is explicit because the implicit-dim form is deprecated.
        loss = self.kl_loss(
            F.log_softmax(outputs_student, dim=-1),
            F.log_softmax(outputs_teacher, dim=-1),
        )

        # Cosine loss with target +1 pulls each student embedding towards its teacher's.
        target = torch.ones(outputs_teacher.size(0), device=self.device)
        loss = loss + self.cosine_loss(outputs_teacher, outputs_student, target)

        return loss

    def train(self):
        """Run epochs ``start_epoch`` .. ``epochs`` (inclusive), printing each epoch's mean loss."""
        self.student.train()
        for epoch in range(self.start_epoch, self.epochs + 1):
            loss_value = self._train_epoch(epoch)
            print(f"KLD-CosineLoss after {epoch} Epoch is {loss_value}")

    def _train_epoch(self, epoch):
        """Train the student for one pass over ``self.train_dataloader``.

        Args:
            epoch: epoch number (informational only).

        Returns:
            Mean per-batch loss as a Python float (0.0 for an empty dataloader).
        """
        running_loss = 0.0
        num_batches = 0
        for batch_idx, data in enumerate(self.train_dataloader):
            self.optimizer.zero_grad()

            loss = self.compute_loss(data["caption"])
            loss.backward()
            self.optimizer.step()

            # Accumulate a detached Python float. Summing the tensor itself would
            # chain every batch's autograd graph onto the running total, and with a
            # half-precision loss the sum would also lose accuracy as it grows.
            running_loss += loss.item()
            num_batches += 1

            if batch_idx % 100 == 0:
                print(f"Loss after {batch_idx} Batch is {running_loss/(batch_idx+1)} ")

        return running_loss / num_batches if num_batches else 0.0

In [84]:
# Construct the trainer, naming the dataloader, student and teacher explicitly.
Trainer = TextDistillationTrainer(
    train_dataloader=train_dataloader,
    student_model=student_model,
    teacher_model=teacher_model,
)

In [85]:
# Run the full distillation schedule. Every 100 batches this prints the running
# mean batch loss, and after each epoch it prints that epoch's mean loss.
Trainer.train()



Loss after 0 Batch is 2.30859375 
Loss after 100 Batch is 0.6328125 
Loss after 200 Batch is 0.595703125 
Loss after 300 Batch is 0.5751953125 
Loss after 400 Batch is 0.5634765625 
Loss after 500 Batch is 0.5537109375 
KLD-CosineLoss after 1 Epoch is 0.5530451866404715
Loss after 0 Batch is 0.453369140625 
Loss after 100 Batch is 0.4990234375 
Loss after 200 Batch is 0.4951171875 
Loss after 300 Batch is 0.490966796875 
Loss after 400 Batch is 0.489501953125 
Loss after 500 Batch is 0.486083984375 
KLD-CosineLoss after 2 Epoch is 0.4862475442043222
Loss after 0 Batch is 0.47802734375 
Loss after 100 Batch is 0.46533203125 
Loss after 200 Batch is 0.46044921875 
Loss after 300 Batch is 0.460205078125 
Loss after 400 Batch is 0.456298828125 
Loss after 500 Batch is 0.453369140625 
KLD-CosineLoss after 3 Epoch is 0.4535854616895874
Loss after 0 Batch is 0.45654296875 
Loss after 100 Batch is 0.431396484375 
Loss after 200 Batch is 0.429443359375 
Loss after 300 Batch is 0.432373046875 
[... log for the rest of epoch 4 through epoch 29 truncated ...]

Loss after 0 Batch is 0.342529296875 
Loss after 100 Batch is 0.318115234375 
Loss after 200 Batch is 0.31396484375 
Loss after 300 Batch is 0.3134765625 
Loss after 400 Batch is 0.31396484375 
Loss after 500 Batch is 0.315185546875 
KLD-CosineLoss after 30 Epoch is 0.3153241650294695


In [87]:
# Persist only the student's weights (its state_dict), relative to the working directory.
torch.save(obj=Trainer.student.state_dict(), f="Text_DistilledModel.pt")