In [1]:
import os
import sys
import numpy as np
import torch

from datetime import datetime
from typing import Tuple
from torch import nn
from torch.nn import Module
import torch.nn.functional as F
from torch.nn import KLDivLoss, CrossEntropyLoss, CosineEmbeddingLoss, MSELoss
from torch.optim import Adam, SGD
from torch.optim.lr_scheduler import ReduceLROnPlateau

### Loading Teacher model ---> CLIP Image Extractor 

In [2]:
import clip

model_name = "ViT-B/32"

# model is the torch model.
# preprocess function is for image preprocessing.

model, preprocess = clip.load(model_name)

# Get only the visual model
teacher_model = model.visual
input_resolution = model.visual.input_resolution

print(
    "Model parameters:",
    f"{np.sum([int(np.prod(p.shape)) for p in model.visual.parameters()]):,}",
)
print("Input resolution:", input_resolution)

Model parameters: 87,849,216
Input resolution: 224


In [3]:
model.transformer

Transformer(
  (resblocks): Sequential(
    (0): ResidualAttentionBlock(
      (attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
      )
      (ln_1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (mlp): Sequential(
        (c_fc): Linear(in_features=512, out_features=2048, bias=True)
        (gelu): QuickGELU()
        (c_proj): Linear(in_features=2048, out_features=512, bias=True)
      )
      (ln_2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    )
    (1): ResidualAttentionBlock(
      (attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
      )
      (ln_1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (mlp): Sequential(
        (c_fc): Linear(in_features=512, out_features=2048, bias=True)
        (gelu): QuickGELU()
        (c_proj): Linear(in_features=2048, out_features=512, bias=True)
      

### Instantiating Student model 

[VisionTransformer](https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/model.py#L206)

In [4]:
from clip.model import VisionTransformer
from clip.model import convert_weights  # Make them float16

# Set Student Configuration

patch_size = 32
width = 384
layers = 6
heads = 12
output_dim = 512

student_model = VisionTransformer(
    input_resolution=input_resolution,
    patch_size=patch_size,
    width=width,
    layers=layers,
    heads=heads,
    output_dim=output_dim,
)


convert_weights(student_model)


print(
    "Model parameters:",
    f"{np.sum([int(np.prod(p.shape)) for p in student_model.parameters()]):,}",
)

Model parameters: 12,044,160


### Load the WIT Dataset

In [None]:
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import io
import urllib

import PIL.Image

from datasets import load_dataset
from datasets.utils.file_utils import get_datasets_user_agent


def fetch_single_image(image_url, timeout=None, retries=0):
    for _ in range(retries + 1):
        try:
            request = urllib.request.Request(
                image_url,
                data=None,
                headers={"user-agent": get_datasets_user_agent()},
            )
            with urllib.request.urlopen(request, timeout=timeout) as req:
                image = PIL.Image.open(io.BytesIO(req.read()))
            break
        except Exception:
            image = None
    return image


def fetch_images(batch, num_threads, timeout=None, retries=0):
    fetch_single_image_with_args = partial(
        fetch_single_image, timeout=timeout, retries=retries
    )
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        batch["image"] = list(
            executor.map(fetch_single_image_with_args, batch["image_url"])
        )
    return batch


num_threads = 1
# dset = load_dataset("cifar10")
dset = load_dataset("conceptual_captions")
dset = dset.map(
    fetch_images, batched=True, batch_size=4, fn_kwargs={"num_threads": num_threads}
)

No config specified, defaulting to: conceptual_captions/unlabeled
Reusing dataset conceptual_captions (/home/ecbm4040/.cache/huggingface/datasets/conceptual_captions/unlabeled/1.0.0/05266784888422e36944016874c44639bccb39069c2227435168ad8b02d600d8)


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/829584 [00:00<?, ?ba/s]

In [334]:
import torchvision
import torchvision.transforms as transforms

transform = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.RandomCrop(224),
        # transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0, 0, 0), (1, 1, 1)),
    ]
)

cifar100 = torchvision.datasets.CIFAR100(
    "data/", download=True, train=True, transform=transform
)
train_dataloader = torch.utils.data.DataLoader(
    cifar100, batch_size=32, shuffle=True, num_workers=8
)

Files already downloaded and verified


In [335]:
class DistillationTrainer:
    def __init__(self, *args, **kwargs):
        self.teacher = teacher_model
        self.student = student_model
        self.train_dataloader = train_dataloader

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.teacher = self.teacher.to(self.device)
        self.student = self.student.to(self.device)
        self.teacher.eval()

        self.epochs = 30
        self.start_epoch = 1

        # set up optimizer
        self.optimizer = SGD(self.student.parameters(), lr=0.001)

        # Set up LR Scheduler

    #         self.lr_scheduler = ReduceLROnPlateau(self.optimizer, "min")

    def compute_loss(self, images, return_outputs=False):
        images = images.to(self.device).half()

        outputs_student = self.student(images)

        # compute teacher output
        with torch.no_grad():
            outputs_teacher = self.teacher(images)
        #             outputs_teacher = torch.tensor(outputs_teacher.detach().cpu().numpy()).to(self.device)

        # assert size
        assert outputs_student.size() == outputs_teacher.size()

        # Soften probabilities and compute distillation loss

        #         KL Divergence Loss
        kl_loss = KLDivLoss(reduction="batchmean", log_target=True)
        loss = kl_loss(F.log_softmax(outputs_student), F.log_softmax(outputs_teacher))

        # Cosine loss
        loss += CosineEmbeddingLoss()(
            outputs_teacher,
            outputs_student,
            torch.ones(outputs_teacher.size()[0]).to(self.device),
        )

        #         #MSE Loss
        #         mse_loss = MSELoss()
        #         loss += mse_loss(outputs_teacher, outputs_student)

        return loss

    def train(self):
        for epoch in range(self.start_epoch, self.epochs + 1):
            loss_value = self._train_epoch(epoch)
            print(f"KLD-CosineLoss after {epoch} Epoch is {loss_value}")

    def _train_epoch(self, epoch):
        loss_value = 0
        for batch_idx, (images, _) in enumerate(self.train_dataloader):

            self.optimizer.zero_grad()
            #             labels = torch.nn.functional.one_hot(labels,num_classes=100)

            loss = self.compute_loss(images)

            loss_value += loss

            loss.backward()

            #             torch.autograd.set_detect_anomaly(True)

            self.optimizer.step()

            #             print("After",self.student.conv1.weight)

            if batch_idx % 10 == 0:
                print(f"Loss after {batch_idx} Batch is {loss_value/(batch_idx+1)} ")

        return loss_value.detach().cpu().numpy() / len(self.train_dataloader)

In [336]:
Trainer = DistillationTrainer(
    teacher_model=teacher_model,
    student_model=student_model,
    train_dataloader=train_dataloader,
    preprocess=preprocess,
)

In [None]:
Trainer.train()



Loss after 0 Batch is 1.5751953125 
Loss after 10 Batch is 1.4755859375 
Loss after 20 Batch is 1.38671875 
Loss after 30 Batch is 1.3125 
Loss after 40 Batch is 1.2490234375 
Loss after 50 Batch is 1.193359375 
Loss after 60 Batch is 1.1455078125 
Loss after 70 Batch is 1.1044921875 
Loss after 80 Batch is 1.068359375 
Loss after 90 Batch is 1.033203125 
Loss after 100 Batch is 1.001953125 
Loss after 110 Batch is 0.9736328125 
Loss after 120 Batch is 0.9462890625 
Loss after 130 Batch is 0.921875 
Loss after 140 Batch is 0.89990234375 
Loss after 150 Batch is 0.880859375 
Loss after 160 Batch is 0.86083984375 
Loss after 170 Batch is 0.83984375 
Loss after 180 Batch is 0.8212890625 
Loss after 190 Batch is 0.80419921875 
Loss after 200 Batch is 0.7890625 
Loss after 210 Batch is 0.775390625 
Loss after 220 Batch is 0.76318359375 
Loss after 230 Batch is 0.75146484375 
Loss after 240 Batch is 0.7412109375 
Loss after 250 Batch is 0.72900390625 
Loss after 260 Batch is 0.71630859375 
L

Loss after 520 Batch is 0.1578369140625 
Loss after 530 Batch is 0.1578369140625 
Loss after 540 Batch is 0.157470703125 
Loss after 550 Batch is 0.1573486328125 
Loss after 560 Batch is 0.156982421875 
Loss after 570 Batch is 0.15673828125 
Loss after 580 Batch is 0.1563720703125 
Loss after 590 Batch is 0.156005859375 
Loss after 600 Batch is 0.155517578125 
Loss after 610 Batch is 0.155029296875 
Loss after 620 Batch is 0.1546630859375 
Loss after 630 Batch is 0.1544189453125 
Loss after 640 Batch is 0.1541748046875 
Loss after 650 Batch is 0.154052734375 
Loss after 660 Batch is 0.1536865234375 
Loss after 670 Batch is 0.1531982421875 
Loss after 680 Batch is 0.1529541015625 
Loss after 690 Batch is 0.1527099609375 
Loss after 700 Batch is 0.152587890625 
Loss after 710 Batch is 0.1522216796875 
Loss after 720 Batch is 0.15185546875 
Loss after 730 Batch is 0.1514892578125 
Loss after 740 Batch is 0.1512451171875 
Loss after 750 Batch is 0.15087890625 
Loss after 760 Batch is 0.150

In [None]:
# torch.save(
#     {
#         "distilbert": student_distil_bert.state_dict(),
#     },
#     f"distiled_distilbert_sst2_{datetime.now():%Y-%m-%d_%H-%M-%S%z}.pt",
# )