In [1]:
!pip install open-clip-torch

import open_clip
import tqdm
import torch
import timm
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import DataLoader
from torchvision import models, datasets, transforms
import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model, _ , preprocess = open_clip.create_model_and_transforms(
    "ViT-B-32",
    pretrained="laion2b_s34b_b79k"
)

model = model.to(device)

tokenizer = open_clip.get_tokenizer("ViT-B-32")

Collecting open-clip-torch
  Downloading open_clip_torch-2.24.0-py3-none-any.whl (1.5 MB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.5 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.2/1.5 MB[0m [31m6.5 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m1.5/1.5 MB[0m [31m28.7 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m20.1 MB/s[0m eta [36m0:00:00[0m
Collecting ftfy (from open-clip-torch)
  Downloading ftfy-6.2.0-py3-none-any.whl (54 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.4/54.4 kB[0m [31m6.4 MB/s[0m eta [36m0:00:00[0m
Collecting timm (from open-clip-torch)
  Downloading timm-0.9.16-py3-none-any.whl (2.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m56.5 MB/s[0m eta [36m0:00:00

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


open_clip_pytorch_model.bin:   0%|          | 0.00/605M [00:00<?, ?B/s]

In [2]:
import torch.nn.functional as F

def embeddings_to_class_probs(vision_embeddings, text_embeddings):
    vision_embeddings = vision_embeddings / vision_embeddings.norm(dim=-1, keepdim=True)
    text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
    logits = vision_embeddings @ text_embeddings.T
    class_probs = F.softmax(100. * logits, dim=-1)
    return class_probs

### SLT10 Dataset

In [None]:
import torch.nn as nn

labels = [
    "an airplane",
    "a bird",
    "a car",
    "a cat",
    "a deer",
    "a dog",
    "a horse",
    "a monkey",
    "a ship",
    "a truck"
]

text = tokenizer(labels).to(device)
text_embeddings = model.encode_text(text).to(device)

linear_probe = nn.Linear(512, len(labels))

In [None]:
from torch.utils.data import DataLoader
from torchvision.datasets import STL10

optimizer = torch.optim.Adam(linear_probe.parameters(), lr=3e-4)

dataset_path = '.'
# Define batch size for the DataLoader
batch_size = 64

train_dataset = STL10(
    root=dataset_path,
    download=True,
    split="train",
    transform=preprocess
)

# Create a DataLoader for training dataset
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2
)

test_dataset = STL10(
    root=dataset_path,
    download=True,
    split="test"
)

test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2
)

Files already downloaded and verified
Files already downloaded and verified


### Zero Shot

In [None]:
num_correct = 0

for image, label in tqdm.tqdm(test_dataset):
    input_tensor = preprocess(image).unsqueeze(0)
    vision_embeddings = model.encode_image(input_tensor)
    output_class_probs = embeddings_to_class_probs(vision_embeddings, text_embeddings)
    output_label = torch.argmax(output_class_probs, dim=-1)
    num_correct += int(torch.count_nonzero(output_label == label))

accuracy = 100. * num_correct / len(test_dataset)
print(accuracy)

### Linear head for *classification*

In [None]:
# Eval the model

num_correct = 0

for image, label in tqdm.tqdm(test_dataset):
    input_tensor = preprocess(image).unsqueeze(0)
    vision_embeddings = model.encode_image(input_tensor)
    output_logits = linear_probe(vision_embeddings)
    output_logprob = F.log_softmax(output_logits, dim=-1)
    output_label = torch.argmax(output_logprob, dim=-1)
    num_correct += int(torch.count_nonzero(output_label == label))

accuracy = 100. * num_correct / len(dataset)
print(accuracy) ## should be 98 (bound to be verified)

### CIFAR10 Dataset

In [14]:
labels = [
    "an airplane",
    "an automobile",
    "a bird",
    "a cat",
    "a deer",
    "a dog",
    "a frog",
    "a horse",
    "a ship",
    "a truck"
]

text = tokenizer(labels).to(device)
text_embeddings = model.encode_text(text).to(device)

linear_probe = nn.Linear(512, len(labels)).to(device)

In [11]:
import tqdm
import torch
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader

dataset_path = '.'

optimizer = torch.optim.Adam(linear_probe.parameters(), lr=3e-4)

train_dataset = CIFAR10(
    root=dataset_path,
    download=True,
    train=True,
    transform=preprocess
)

test_dataset = CIFAR10(
    root=dataset_path,
    download=True,
    train=False,
    transform=None
)

batch_size = 64

# Create a DataLoader for training dataset
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2
)

Files already downloaded and verified
Files already downloaded and verified


#### Zero Shot

In [None]:
num_correct = 0

for image, label in tqdm.tqdm(test_dataset):
    input_tensor = preprocess(image).unsqueeze(0).to(device)
    vision_embeddings = model.encode_image(input_tensor).to(device)
    output_class_probs = embeddings_to_class_probs(vision_embeddings, text_embeddings)
    output_label = torch.argmax(output_class_probs, dim=-1)
    num_correct += int(torch.count_nonzero(output_label == label))

accuracy = 100. * num_correct / len(test_dataset)
print(accuracy)

#### Linear head for classification


In [18]:
num_epochs = 10

model.eval() # freeze the clip model, we are only training the linear layer

losses = []

for epoch in range(num_epochs):
    epoch_losses = []
    for input_tensor, label in iter(tqdm.tqdm(train_loader)):
        input_tensor , label = input_tensor.to(device), label.to(device)
        vision_embeddings = model.encode_image(input_tensor)
        optimizer.zero_grad()
        output_logits = linear_probe(vision_embeddings)
        output_logprob = F.log_softmax(output_logits, dim=-1)
        loss = F.nll_loss(output_logprob, label)

        epoch_losses.append(loss.item())

        loss.backward()
        optimizer.step()

    epoch_average_loss = sum(epoch_losses) / len(epoch_losses)
    losses.append(epoch_average_loss)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

100%|██████████| 782/782 [06:51<00:00,  1.90it/s]


Epoch [1/10], Loss: 2.4414


100%|██████████| 782/782 [06:50<00:00,  1.90it/s]


Epoch [2/10], Loss: 2.3688


100%|██████████| 782/782 [06:50<00:00,  1.90it/s]


Epoch [3/10], Loss: 2.2351


100%|██████████| 782/782 [06:50<00:00,  1.90it/s]


Epoch [4/10], Loss: 2.3377


100%|██████████| 782/782 [06:50<00:00,  1.90it/s]


Epoch [5/10], Loss: 2.3204


100%|██████████| 782/782 [06:51<00:00,  1.90it/s]


Epoch [6/10], Loss: 2.2168


100%|██████████| 782/782 [06:51<00:00,  1.90it/s]


Epoch [7/10], Loss: 2.3108


100%|██████████| 782/782 [06:50<00:00,  1.90it/s]


Epoch [8/10], Loss: 2.2826


100%|██████████| 782/782 [06:48<00:00,  1.91it/s]


Epoch [9/10], Loss: 2.3488


100%|██████████| 782/782 [06:50<00:00,  1.90it/s]

Epoch [10/10], Loss: 2.2735





In [1]:
num_correct = 0

for image, label in tqdm.tqdm(test_dataset):
    input_tensor = preprocess(image).unsqueeze(0)
    vision_embeddings = model.encode_image(input_tensor)
    output_class_probs = embeddings_to_class_probs(vision_embeddings, text_embeddings)
    output_label = torch.argmax(output_class_probs, dim=-1)
    num_correct += int(torch.count_nonzero(output_label == label))

accuracy = 100. * num_correct / len(test_dataset)
print(accuracy)

NameError: name 'tqdm' is not defined

### Person ReId

In [None]:
from PIL import Image, ImageFile

from torch.utils.data import Dataset
import os.path as osp
import random
import torch
ImageFile.LOAD_TRUNCATED_IMAGES = True


def read_image(img_path):
    """Keep reading image until succeed.
    This can avoid IOError incurred by heavy IO process."""
    got_img = False
    if not osp.exists(img_path):
        raise IOError("{} does not exist".format(img_path))
    while not got_img:
        try:
            img = Image.open(img_path).convert('RGB')
            got_img = True
        except IOError:
            print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path))
            pass
    return img


class BaseDataset(object):
    """
    Base class of reid dataset
    """

    def get_imagedata_info(self, data):
        pids, cams, tracks = [], [], []
        for _, pid, camid, trackid in data:
            pids += [pid]
            cams += [camid]
            tracks += [trackid]
        pids = set(pids)
        cams = set(cams)
        tracks = set(tracks)
        num_pids = len(pids)
        num_cams = len(cams)
        num_imgs = len(data)
        num_views = len(tracks)
        return num_pids, num_imgs, num_cams, num_views

    def print_dataset_statistics(self):
        raise NotImplementedError


class BaseImageDataset(BaseDataset):
    """
    Base class of image reid dataset
    """

    def print_dataset_statistics(self, train, query, gallery):
        num_train_pids, num_train_imgs, num_train_cams, num_train_views = self.get_imagedata_info(train)
        num_query_pids, num_query_imgs, num_query_cams, num_train_views = self.get_imagedata_info(query)
        num_gallery_pids, num_gallery_imgs, num_gallery_cams, num_train_views = self.get_imagedata_info(gallery)

        print("Dataset statistics:")
        print("  ----------------------------------------")
        print("  subset   | # ids | # images | # cameras")
        print("  ----------------------------------------")
        print("  train    | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, num_train_cams))
        print("  query    | {:5d} | {:8d} | {:9d}".format(num_query_pids, num_query_imgs, num_query_cams))
        print("  gallery  | {:5d} | {:8d} | {:9d}".format(num_gallery_pids, num_gallery_imgs, num_gallery_cams))
        print("  ----------------------------------------")


class ImageDataset(Dataset):
    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        img_path, pid, camid, trackid = self.dataset[index]
        img = read_image(img_path)

        if self.transform is not None:
            img = self.transform(img)

        return img, pid, camid, trackid, img_path.split('/')[-1]

In [None]:
# encoding: utf-8
"""
@author:  sherlock
@contact: sherlockliao01@gmail.com
"""

import glob
import re

import os.path as osp

from collections import defaultdict
import pickle
class Market1501(BaseImageDataset):
    """
    Market1501
    Reference:
    Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015.
    URL: http://www.liangzheng.org/Project/project_reid.html

    Dataset statistics:
    # identities: 1501 (+1 for background)
    # images: 12936 (train) + 3368 (query) + 15913 (gallery)
    """
    dataset_dir = 'Market-1501-v15.09.15'

    def __init__(self, root='', verbose=True, pid_begin = 0, **kwargs):
        super(Market1501, self).__init__()
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')
        self.query_dir = osp.join(self.dataset_dir, 'query')
        self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')

        self._check_before_run()
        self.pid_begin = pid_begin
        train = self._process_dir(self.train_dir, relabel=True)
        query = self._process_dir(self.query_dir, relabel=False)
        gallery = self._process_dir(self.gallery_dir, relabel=False)

        if verbose:
            print("=> Market1501 loaded")
            self.print_dataset_statistics(train, query, gallery)

        self.train = train
        self.query = query
        self.gallery = gallery

        self.num_train_pids, self.num_train_imgs, self.num_train_cams, self.num_train_vids = self.get_imagedata_info(self.train)
        self.num_query_pids, self.num_query_imgs, self.num_query_cams, self.num_query_vids = self.get_imagedata_info(self.query)
        self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams, self.num_gallery_vids = self.get_imagedata_info(self.gallery)

    def _check_before_run(self):
        """Check if all files are available before going deeper"""
        if not osp.exists(self.dataset_dir):
            raise RuntimeError("'{}' is not available".format(self.dataset_dir))
        if not osp.exists(self.train_dir):
            raise RuntimeError("'{}' is not available".format(self.train_dir))
        if not osp.exists(self.query_dir):
            raise RuntimeError("'{}' is not available".format(self.query_dir))
        if not osp.exists(self.gallery_dir):
            raise RuntimeError("'{}' is not available".format(self.gallery_dir))

    def _process_dir(self, dir_path, relabel=False):
        img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
        pattern = re.compile(r'([-\d]+)_c(\d)')

        pid_container = set()
        for img_path in sorted(img_paths):
            pid, _ = map(int, pattern.search(img_path).groups())
            if pid == -1: continue  # junk images are just ignored
            pid_container.add(pid)
        pid2label = {pid: label for label, pid in enumerate(pid_container)}
        dataset = []
        for img_path in sorted(img_paths):
            pid, camid = map(int, pattern.search(img_path).groups())
            if pid == -1: continue  # junk images are just ignored
            assert 0 <= pid <= 1501  # pid == 0 means background
            assert 1 <= camid <= 6
            camid -= 1  # index starts from 0
            if relabel: pid = pid2label[pid]

            dataset.append((img_path, self.pid_begin + pid, camid, 0))
        return dataset

In [None]:
!pip install yacs

import torchvision.transforms as T
from yacs.config import CfgNode as CN

# cfg = CN()

# cfg.merge_from_file('./vit_clipreid.yml')
# cfg.freeze()

val_transforms = T.Compose([
    # T.Resize(cfg.INPUT.SIZE_TEST),
    T.Resize([256, 128]),
    T.ToTensor(),
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

dataset = Market1501(root=(''))

market1501_val_set = ImageDataset(dataset.query + dataset.gallery, val_transforms)

market1501_val_loader = DataLoader(
    market1501_val_set, batch_size=32, shuffle=False,
    # collate_fn=val_collate_fn
)

In [None]:
train_transforms = T.Compose([
            T.Resize([256, 128], interpolation=3),
            T.RandomHorizontalFlip(p=0.5),
            T.Pad(10),
            T.RandomCrop([256, 128]),
            T.ToTensor(),
            T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            # RandomErasing(probability=cfg.INPUT.RE_PROB, mode='pixel', max_count=1, device='cpu'),
            # RandomErasing(probability=cfg.INPUT.RE_PROB, mean=cfg.INPUT.PIXEL_MEAN)
        ])

train_set = ImageDataset(dataset.train, train_transforms)
train_set_normal = ImageDataset(dataset.train, val_transforms)

market1501_train_loader = DataLoader(
    train_set, batch_size=32,
    sampler=RandomIdentitySampler(dataset.train, 64, 4),
    num_workers = 8,
    # collate_fn=train_collate_fn
)


## Linear head for classification



In [None]:
import torch.nn as nn

linear_probe = nn.Linear(512, 1501)

In [None]:
import tqdm
import torch
from torch.utils.data import Subset
from torch.utils.data import DataLoader
from torchvision.datasets import STL10

optimizer = torch.optim.Adam(linear_probe.parameters(), lr=3e-4)

num_epochs = 3

for epoch in range(num_epochs):
    for input_tensor, label in iter(tqdm.tqdm(market1501_train_loader)):
        vision_embeddings = model.encode_image(input_tensor)
        optimizer.zero_grad()
        output_logits = linear_probe(vision_embeddings)
        output_logprob = F.log_softmax(output_logits, dim=-1)
        loss = F.nll_loss(output_logprob, label)
        loss.backward()
        optimizer.step()

In [None]:
import tqdm
import torch
from torchvision.datasets import CIFAR10
from torch.utils.data import Subset

num_correct = 0

for image, label in tqdm.tqdm(market1501_val_loader):
    input_tensor = preprocess(image).unsqueeze(0)
    vision_embeddings = model.encode_image(input_tensor)
    output_class_probs = embeddings_to_class_probs(vision_embeddings, text_embeddings)
    output_label = torch.argmax(output_class_probs, dim=-1)
    num_correct += int(torch.count_nonzero(output_label == label))

accuracy = 100. * num_correct / len(dataset)
print(accuracy)

## Knowledge Distillation


In [None]:
import torch.nn as nn

# Define the ResNet18 model as the student model
class ResNet18(nn.Module):
    def __init__(self):
        super(ResNet18, self).__init__()
        self.resnet18 = models.resnet18(pretrained=True)
        # Modify the classifier layer for your specific task

    def forward(self, x):
        return self.resnet18(x)

# Knowledge Distillation Loss
class DistillationLoss(nn.Module):
    def __init__(self):
        super(DistillationLoss, self).__init__()

    def forward(self, outputs_student, outputs_teacher):
        return nn.MSELoss()(outputs_student, outputs_teacher)

class ProjectionHead(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(ProjectionHead, self).__init__()
        self.projection_head = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.projection_head(x)

### *Contrastive* Relational Distillation

In [None]:
import timm
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import DataLoader
from torchvision import models, datasets, transforms
import matplotlib.pyplot as plt

## Knowledge Distillation
student_model = ResNet18()
projection_head = ProjectionHead(1000, 512)  # Projection head outside ViT
distillation_loss = DistillationLoss()

# Define optimizer and learning rate
optimizer = optim.Adam(student_model.parameters(), lr=0.001)

losses = []

# Training loop
num_epochs = 3  # Adjust as needed
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model.to(device)
model.eval()
student_model.to(device)
projection_head.to(device)

for epoch in range(num_epochs):
    epoch_losses = []  # Store losses for each epoch
    for inputs, labels in tqdm.tqdm(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)

        # Forward pass on the teacher model
        with torch.no_grad():
            teacher_outputs = model.encode_image(inputs)

        # Forward pass on the student model
        student_outputs = student_model(inputs)
        student_proj = projection_head(student_outputs)

        # Compute the distillation loss
        loss = distillation_loss(student_proj, teacher_outputs)

        epoch_losses.append(loss.item())

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    epoch_average_loss = sum(epoch_losses) / len(epoch_losses)
    losses.append(epoch_average_loss)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

In [None]:
## Zero shot evaluation of ResNet18

num_correct = 0

for image, label in tqdm.tqdm(dataset):
    input_tensor = preprocess(image).unsqueeze(0)
    student_outputs = student_model(input_tensor)
    vistion_embeddings = projection_head(student_outputs)
    output_class_probs = embeddings_to_class_probs(vision_embeddings, text_embeddings)
    output_label = torch.argmax(output_class_probs, dim=-1)
    num_correct += int(torch.count_nonzero(output_label == label))

accuracy = 100. * num_correct / len(dataset)
print(accuracy)

### Feature Distillation

In [None]:
import timm
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import DataLoader
from torchvision import models, datasets, transforms
import matplotlib.pyplot as plt

## Knowledge Distillation
student_model = ResNet18()
student_model.load_state_dict(torch.load('resnet18_student_model.pth'))
projection_head = ProjectionHead(1000, 512)  # Projection head outside ViT
distillation_loss = DistillationLoss()

# Define optimizer and learning rate
optimizer = optim.Adam(student_model.parameters(), lr=0.001)

losses = []

# Training loop
num_epochs = 20  # Adjust as needed
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model.to(device)
model.eval()
student_model.to(device)
projection_head.to(device)

for epoch in range(num_epochs):
    epoch_losses = []  # Store losses for each epoch
    for inputs, labels in tqdm.tqdm(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)

        # Forward pass on the teacher model
        with torch.no_grad():
            teacher_outputs = model.encode_image(inputs)

        # Forward pass on the student model
        student_outputs = student_model(inputs)
        student_proj = projection_head(student_outputs)

        # Compute the distillation loss
        loss = distillation_loss(student_proj, teacher_outputs)

        epoch_losses.append(loss.item())

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    epoch_average_loss = sum(epoch_losses) / len(epoch_losses)
    losses.append(epoch_average_loss)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

100%|██████████| 79/79 [00:31<00:00,  2.47it/s]


Epoch [1/20], Loss: 0.1050


100%|██████████| 79/79 [00:30<00:00,  2.55it/s]


Epoch [2/20], Loss: 0.0880


100%|██████████| 79/79 [00:30<00:00,  2.61it/s]


Epoch [3/20], Loss: 0.1086


100%|██████████| 79/79 [00:30<00:00,  2.55it/s]


Epoch [4/20], Loss: 0.0880


100%|██████████| 79/79 [00:30<00:00,  2.58it/s]


Epoch [5/20], Loss: 0.0799


100%|██████████| 79/79 [00:30<00:00,  2.58it/s]


Epoch [6/20], Loss: 0.0681


100%|██████████| 79/79 [00:30<00:00,  2.57it/s]


Epoch [7/20], Loss: 0.0726


100%|██████████| 79/79 [00:30<00:00,  2.59it/s]


Epoch [8/20], Loss: 0.0698


 44%|████▍     | 35/79 [00:13<00:17,  2.59it/s]

In [None]:
# Plot the loss curve
plt.plot(losses, label='Distillation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Knowledge Distillation Loss Over Epochs')
plt.legend()
plt.show()

# Save the trained student model
torch.save(student_model.state_dict(), 'resnet18_student_model.pth')

In [None]:
## Zero shot evaluation of ResNet18
num_correct = 0

for image, label in tqdm.tqdm(test_dataset):
    input_tensor = preprocess(image).unsqueeze(0).to(device)
    student_outputs = student_model(input_tensor)
    vision_embeddings = projection_head(student_outputs).to(device)
    output_class_probs = embeddings_to_class_probs(vision_embeddings, text_embeddings)
    output_label = torch.argmax(output_class_probs, dim=-1)
    num_correct += int(torch.count_nonzero(output_label == label))

accuracy = 100. * num_correct / len(test_dataset)
print(accuracy)