In [2]:
import os
import sys
import numpy as np
import torch

from datetime import datetime
from typing import Tuple
from torch import nn
from torch.nn import Module
import torch.nn.functional as F
from torch.nn import KLDivLoss, CrossEntropyLoss, CosineEmbeddingLoss, MSELoss
from torch.optim import Adam, SGD
from torch.optim.lr_scheduler import ReduceLROnPlateau

###  Combined Image and TextDistillation with Bigger Dataset with Loss2

Loss2 Definition : (MSE Loss + Cosine Loss + Constrastive Loss )


In [3]:
loss_name = "loss2"

### Loading Teacher models

In [4]:
import clip

model_name = "ViT-B/32"

model, preprocess = clip.load(model_name)

visual_teacher_model = model.visual
text_teacher_model = model.transformer


input_resolution = model.visual.input_resolution

### Instantiating Student models

[VisionTransformer](https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/model.py#L206)
[Transformer](https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/model.py#L195)

In [5]:
from clip.model import Transformer, VisionTransformer
from clip.model import convert_weights  # Make them float16

# Set Student Configuration

patch_size = 32
width = 384
layers = 6
heads = 12
output_dim = 512

visual_student_model = VisionTransformer(
    input_resolution=input_resolution,
    patch_size=patch_size,
    width=width,
    layers=layers,
    heads=heads,
    output_dim=output_dim,
)

width = 512
layers = 6
heads = 8  # More Number of Heads


def build_attention_mask():
    context_length = 77
    mask = torch.empty(context_length, context_length)
    mask.fill_(float("-inf"))
    mask.triu_(1)  # zero out the lower diagonal
    return mask


text_student_model = Transformer(
    width=width, layers=layers, heads=heads, attn_mask=build_attention_mask()
)


convert_weights(visual_student_model)
convert_weights(text_student_model)

In [6]:
def encode_text(transformer, text):

    x = model.token_embedding(text).type(model.dtype)  # [batch_size, n_ctx, d_model]

    x = x + model.positional_embedding.type(model.dtype)
    x = x.permute(1, 0, 2)  # NLD -> LND

    x = transformer(x)

    x = x.permute(1, 0, 2)  # LND -> NLD
    x = model.ln_final(x).type(model.dtype)

    # x.shape = [batch_size, n_ctx, transformer.width]
    # take features from the eot embedding (eot_token is the highest number in each sequence)
    x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ model.text_projection

    return x

In [7]:
class ProjectionHead(nn.Module):
    def __init__(self, embedding_dim=512, projection_dim=512, dropout=0.2):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        projected = self.projection(x)
        x = self.gelu(projected)
        x = self.fc(x)
        x = self.dropout(x)
        x = x + projected
        x = self.layer_norm(x)
        return x

### Load the Conceptual Captions Dataset

In [8]:
# from concurrent.futures import ThreadPoolExecutor
# from functools import partial
# import io
# import urllib

# import PIL.Image

# from datasets import load_dataset
# from datasets.utils.file_utils import get_datasets_user_agent


# def fetch_single_image(image_url, timeout=None, retries=0):
#     for _ in range(retries + 1):
#         try:
#             request = urllib.request.Request(
#                 image_url,
#                 data=None,
#                 headers={"user-agent": get_datasets_user_agent()},
#             )
#             with urllib.request.urlopen(request, timeout=timeout) as req:
#                 image = PIL.Image.open(io.BytesIO(req.read()))
#             break
#         except Exception:
#             image = None
#     return image


# def fetch_images(batch, num_threads, timeout=None, retries=0):
#     fetch_single_image_with_args = partial(
#         fetch_single_image, timeout=timeout, retries=retries
#     )
#     with ThreadPoolExecutor(max_workers=num_threads) as executor:
#         batch["image"] = list(
#             executor.map(fetch_single_image_with_args, batch["image_url"])
#         )
#     return batch


# num_threads = 8
# dset = load_dataset("conceptual_captions",split='train[:50000]')#,cache_dir = "./data/ConceptualCaptions")


# dset = dset.filter(lambda example: len(example["caption"]) < 75)

# dset = dset.map(
#     fetch_images, batched=True, batch_size=100, fn_kwargs={"num_threads": num_threads}
# )


# dset = dset.remove_columns("image_url")
# # dset = dset.filter(lambda example : example["image"] is not None)

# # dset.save_to_disk("./data/processed")

#### Convert Dataset into Torch Dataloader

In [10]:
from torch.utils.data import DataLoader
from datasets import load_from_disk

dset = load_from_disk("./data/processed")


def transform_func(examples):
    examples["image"] = [preprocess(img) for img in examples["image"]]
    return examples


dset = dset.with_transform(transform_func)

train_dataloader = DataLoader(dset, batch_size=16, shuffle=True, num_workers=8)

In [11]:
dset

Dataset({
    features: ['caption', 'image'],
    num_rows: 7012
})

In [12]:
class DistillationTrainer:
    def __init__(self, *args, **kwargs):

        self.visual_teacher = visual_teacher_model
        self.text_teacher = text_teacher_model

        self.visual_student = visual_student_model
        self.text_student = text_student_model

        self.image_projection = ProjectionHead().half()
        self.text_projection = ProjectionHead().half()
        self.temperature = 1

        self.train_dataloader = train_dataloader

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.visual_teacher = self.visual_teacher.to(self.device)
        self.text_teacher = self.text_teacher.to(self.device)
        self.visual_student = self.visual_student.to(self.device)
        self.text_student = self.text_student.to(self.device)
        self.image_projection = self.image_projection.to(self.device)
        self.text_projection = self.text_projection.to(self.device)

        self.visual_teacher.eval()
        self.text_teacher.eval()

        self.epochs = 30
        self.start_epoch = 1

        # set up optimizer
        self.optimizer = SGD(
            list(self.visual_student.parameters())
            + list(self.text_student.parameters()),
            lr=0.001,
        )

    def compute_loss(self, images, texts, return_outputs=False):
        texts = clip.tokenize(texts)

        texts = texts.to(self.device)

        images = images.to(self.device).half()

        visual_outputs_student = self.visual_student(images)
        text_outputs_student = encode_text(self.text_student, texts)

        # compute teacher output
        with torch.no_grad():
            visual_outputs_teacher = self.visual_teacher(images)
            text_outputs_teacher = model.encode_text(texts)

        # assert size
        assert visual_outputs_student.size() == visual_outputs_teacher.size()
        assert text_outputs_student.size() == text_outputs_teacher.size()

        # MSE Loss between the embeddings

        mse_loss = MSELoss()
        image_mse_loss = mse_loss(visual_outputs_student, visual_outputs_teacher)
        text_mse_loss = mse_loss(text_outputs_student, text_outputs_teacher)

        # Cosine Loss
        image_cosine_loss = CosineEmbeddingLoss()(
            visual_outputs_teacher,
            visual_outputs_student,
            torch.ones(visual_outputs_teacher.size()[0]).to(self.device),
        )

        text_cosine_loss = CosineEmbeddingLoss()(
            text_outputs_teacher,
            text_outputs_student,
            torch.ones(text_outputs_teacher.size()[0]).to(self.device),
        )

        # Push visual_outputs_student and text_outputs_student closer
        image_embeddings = self.image_projection(visual_outputs_student)
        text_embeddings = self.text_projection(text_outputs_student)

        logits = (text_embeddings @ image_embeddings.T) / self.temperature

        images_similarity = image_embeddings @ image_embeddings.T
        texts_similarity = text_embeddings @ text_embeddings.T

        targets = F.softmax(
            (images_similarity + texts_similarity) / 2 * self.temperature, dim=-1
        )

        texts_loss = self.cross_entropy(logits, targets, reduction="none")
        images_loss = self.cross_entropy(logits.T, targets.T, reduction="none")
        con_loss = (images_loss + texts_loss) / 2.0
        con_loss = con_loss.mean()

        loss = (
            image_mse_loss
            + text_mse_loss
            + image_cosine_loss
            + text_cosine_loss
            + con_loss
        )

        loss = loss / 5

        return loss

    def cross_entropy(self, preds, targets, reduction="none"):
        log_softmax = nn.LogSoftmax(dim=-1)
        loss = (-targets * log_softmax(preds)).sum(1)
        if reduction == "none":
            return loss
        elif reduction == "mean":
            return loss.mean()

    def train(self):
        for epoch in range(self.start_epoch, self.epochs + 1):
            print(f"Starting Epoch {epoch} ------------------------------------------")
            loss_value = self._train_epoch(epoch)
            print(f"Combined Loss Value after {epoch} Epoch is {loss_value}")

    def _train_epoch(self, epoch):
        loss_value = 0
        for batch_idx, data in enumerate(self.train_dataloader):

            self.optimizer.zero_grad()

            texts = data["caption"]
            images = data["image"]

            loss = self.compute_loss(images, texts)

            loss_value += loss

            loss.backward()

            self.optimizer.step()

            if batch_idx % 100 == 0:
                print(
                    f"Loss after {batch_idx}/{len(self.train_dataloader)} Batch is {loss_value/(batch_idx+1)} "
                )

        return loss_value.detach().cpu().numpy() / len(self.train_dataloader)

In [13]:
Trainer = DistillationTrainer(
    visual_teacher_model=visual_teacher_model,
    text_teacher_model=text_teacher_model,
    text_student_model=text_student_model,
    visual_student_model=visual_student_model,
    train_dataloader=train_dataloader,
)

In [14]:
Trainer.train()

Starting Epoch 1 ------------------------------------------
Loss after 0/439 Batch is 3.802734375 
Loss after 100/439 Batch is 2.73046875 
Loss after 200/439 Batch is 2.310546875 
Loss after 300/439 Batch is 2.09375 
Loss after 400/439 Batch is 1.9541015625 
Combined Loss Value after 1 Epoch is 1.9145785876993167
Starting Epoch 2 ------------------------------------------
Loss after 0/439 Batch is 1.3642578125 
Loss after 100/439 Batch is 1.45703125 
Loss after 200/439 Batch is 1.421875 
Loss after 300/439 Batch is 1.39453125 
Loss after 400/439 Batch is 1.373046875 
Combined Loss Value after 2 Epoch is 1.356492027334852
Starting Epoch 3 ------------------------------------------
Loss after 0/439 Batch is 1.2216796875 
Loss after 100/439 Batch is 1.2431640625 
Loss after 200/439 Batch is 1.220703125 
Loss after 300/439 Batch is 1.2177734375 
Loss after 400/439 Batch is 1.2080078125 
Combined Loss Value after 3 Epoch is 1.1993166287015946
Starting Epoch 4 -------------------------------

Loss after 200/439 Batch is 0.79833984375 
Loss after 300/439 Batch is 0.79638671875 
Loss after 400/439 Batch is 0.7900390625 
Combined Loss Value after 26 Epoch is 0.7864464692482915
Starting Epoch 27 ------------------------------------------
Loss after 0/439 Batch is 0.79345703125 
Loss after 100/439 Batch is 0.783203125 
Loss after 200/439 Batch is 0.79052734375 
Loss after 300/439 Batch is 0.79296875 
Loss after 400/439 Batch is 0.78662109375 
Combined Loss Value after 27 Epoch is 0.7835990888382688
Starting Epoch 28 ------------------------------------------
Loss after 0/439 Batch is 0.86572265625 
Loss after 100/439 Batch is 0.796875 
Loss after 200/439 Batch is 0.791015625 
Loss after 300/439 Batch is 0.78759765625 
Loss after 400/439 Batch is 0.78125 
Combined Loss Value after 28 Epoch is 0.7790432801822323
Starting Epoch 29 ------------------------------------------
Loss after 0/439 Batch is 0.8515625 
Loss after 100/439 Batch is 0.79150390625 
Loss after 200/439 Batch is 0.

In [15]:
from datetime import datetime

format_data = "date_%d_%m_%y_time_%H_%M_%S"
timestamp = datetime.strftime(datetime.now(), format_data)

torch.save(
    Trainer.visual_student.state_dict(),
    f"results/{timestamp}_{loss_name}_CombinedVisual_DistilledModel.pt",
)
torch.save(
    Trainer.text_student.state_dict(),
    f"results/{timestamp}_{loss_name}_CombinedText_DistilledModel.pt",
)

In [16]:
print(f"results/{timestamp}_{loss_name}_CombinedVisual_DistilledModel.pt")

results/date_08_12_22_time_04_47_47_loss2_CombinedVisual_DistilledModel.pt
