In [35]:
import os
import sys
import numpy as np
import torch

from datetime import datetime
from typing import Tuple

from torch.nn import Module
from torch.nn import KLDivLoss, CrossEntropyLoss, CosineEmbeddingLoss
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau

### Loading Teacher model ---> CLIP Image Extractor 

In [36]:
import clip

model_name = "ViT-B/32"

# model is the torch model.
# preprocess function is for image preprocessing.

model, preprocess = clip.load(model_name)

# Get only the visual model
teacher_model = model.visual
input_resolution = model.visual.input_resolution

print(
    "Model parameters:",
    f"{np.sum([int(np.prod(p.shape)) for p in model.visual.parameters()]):,}",
)
print("Input resolution:", input_resolution)

Model parameters: 87,849,216
Input resolution: 224


### Instantiating Student model 

[VisionTransformer](https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/model.py#L206)

In [37]:
from clip.model import VisionTransformer
from clip.model import convert_weights # Make them float16
# Set Student Configuration

patch_size = 32
width = 384
layers = 6
heads = 12
output_dim = 512

student_model = VisionTransformer(
    input_resolution=input_resolution,
    patch_size=patch_size,
    width=width,
    layers=layers,
    heads=heads,
    output_dim=output_dim,
)



convert_weights(student_model)


print(
    "Model parameters:",
    f"{np.sum([int(np.prod(p.shape)) for p in student_model.parameters()]):,}",
)


Model parameters: 12,044,160


### Load the WIT Dataset

In [38]:
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import io
import urllib

import PIL.Image

from datasets import load_dataset
from datasets.utils.file_utils import get_datasets_user_agent


def fetch_single_image(image_url, timeout=None, retries=0):
    for _ in range(retries + 1):
        try:
            request = urllib.request.Request(
                image_url,
                data=None,
                headers={"user-agent": get_datasets_user_agent()},
            )
            with urllib.request.urlopen(request, timeout=timeout) as req:
                image = PIL.Image.open(io.BytesIO(req.read()))
            break
        except Exception:
            image = None
    return image


def fetch_images(batch, num_threads, timeout=None, retries=0):
    fetch_single_image_with_args = partial(
        fetch_single_image, timeout=timeout, retries=retries
    )
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        batch["image"] = list(
            executor.map(fetch_single_image_with_args, batch["image_url"])
        )
    return batch


num_threads = 20
dset = load_dataset("cifar10")
# dset = dset.map(
#     fetch_images, batched=True, batch_size=100, fn_kwargs={"num_threads": num_threads}
# )

Reusing dataset cifar10 (/home/ecbm4040/.cache/huggingface/datasets/cifar10/plain_text/1.0.0/447d6ec4733dddd1ce3bb577c7166b986eaa4c538dcd9e805ba61f35674a9de4)


  0%|          | 0/2 [00:00<?, ?it/s]

In [39]:
import torchvision
import torchvision.transforms as transforms

transform = transforms.Compose([
                transforms.Resize(224),
                transforms.RandomCrop(224),
                #transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

cifar100 = torchvision.datasets.CIFAR100('data/',download=True,train=True,transform=transform)
train_dataloader = torch.utils.data.DataLoader(cifar100,
                                          batch_size=4,
                                          shuffle=True,
                                          num_workers=8)

Files already downloaded and verified


In [40]:
class DistillationTrainer:
    def __init__(self, *args, **kwargs):
        self.teacher = teacher_model
        self.student = student_model
        self.preprocess = preprocess
        self.train_dataloader = train_dataloader

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.teacher = self.teacher.to(self.device)
        self.student = self.student.to(self.device)
        self.teacher.eval()

        self.epochs = 30
        self.start_epoch = 1

        # set up optimizer
        self.optimizer = Adam(self.student.parameters())

        # Set up LR Scheduler
        self.lr_scheduler = ReduceLROnPlateau(self.optimizer, "min")

    def compute_loss(self, images, return_outputs=False):
        images = images.to(self.device).half()
        import pdb; pdb.set_trace()
        outputs_student = self.student(images)

        # compute teacher output
        with torch.no_grad():
            outputs_teacher = self.teacher(images)

        # assert size
        assert outputs_student.size() == outputs_teacher.size()

        # Soften probabilities and compute distillation loss

        # KL Divergence Loss
        kl_loss = KLDivLoss(reduction="batchmean")
        loss = kl_loss(outputs_student, outputs_teacher)
        # Cosine loss
        loss = loss + CosineEmbeddingLoss()(
            outputs_teacher, outputs_student, torch.ones(outputs_teacher.size()[0]).to(self.device)
        )
        print(f"Loss is {loss}")
        return (loss, outputs_student) if return_outputs else loss

    def train(self):
        for epoch in range(self.start_epoch, self.epochs + 1):
            loss_value = self._train_epoch(epoch)
            print(f"KLD-CosineLoss after {epoch} Epoch is {loss_value}")

    def _train_epoch(self, epoch):
        loss_value = 0
        for batch_idx, (images, _) in enumerate(self.train_dataloader):
            
            loss = self.compute_loss( images)
            print(loss)
            loss_value += loss

            self.optimizer.zero_grad()
            
            loss.backward()
            
            self.optimizer.step()
            
            if batch_idx % 10 == 0:
                print(f"Loss after {batch_idx} Batch is {loss_value/(batch_idx+1)} ")

        return loss_value.detach().cpu().numpy() / len(self.train_dataloader)

In [41]:
Trainer = DistillationTrainer(
    teacher_model=teacher_model,
    student_model=student_model,
    train_dataloader=train_dataloader,
    preprocess = preprocess,
)

In [42]:
Trainer.train()

Loss is -59.8125
tensor(-59.8125, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 0 Batch is -59.8125 
Loss is nan
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss is nan
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss is nan
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss is nan
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss is nan
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss is nan
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss is nan
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss is nan
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss is nan
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss is nan
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 1

Loss is nan
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss is nan
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss is nan
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss is nan
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss is nan
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss is nan
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss is nan
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 100 Batch is nan 
Loss is nan
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss is nan
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss is nan
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss is nan
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss is nan
tensor(nan, d

KeyboardInterrupt: 

In [None]:
for x in train_dataloader:
    print(x[0].shape,x[1].shape)
    break

In [None]:
# torch.save(
#     {
#         "distilbert": student_distil_bert.state_dict(),
#     },
#     f"distiled_distilbert_sst2_{datetime.now():%Y-%m-%d_%H-%M-%S%z}.pt",
# )