In [None]:
from google.colab import drive

GDRIVE_MOUNT_POINT = '/content/gdrive'  # every dataset path below lives under this mount
drive.mount(GDRIVE_MOUNT_POINT)

Mounted at /content/gdrive


In [None]:
! pip install timm

Collecting timm
  Downloading timm-0.9.11-py3-none-any.whl (2.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m31.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: timm
Successfully installed timm-0.9.11


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
from sklearn.metrics import accuracy_score
from sklearn.neighbors import KNeighborsClassifier

import json
import pathlib

import timm
import tqdm
from torch.utils.data import DataLoader, SubsetRandomSampler
from torch.utils.tensorboard import SummaryWriter
from torchvision.datasets import ImageFolder
import torchvision
import random


## SUPERVISED LEARNING ON 20% OF THE DATASET

In [None]:
# Seed every RNG this notebook draws from: Python's `random` (used to shuffle
# the subset indices), NumPy, and PyTorch (weight init, DataLoader shuffling,
# augmentations). Seeding torch alone leaves the subset split unreproducible.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x798016a23290>

In [None]:
path_dataset_train = pathlib.Path("/content/gdrive/MyDrive/imagenette2-320/train")
path_dataset_val = pathlib.Path("/content/gdrive/MyDrive/imagenette2-320/val")
path_labels = pathlib.Path("/content/gdrive/MyDrive/imagenette2-320/imagenette_labels.json")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

n_workers = 2

# Data related
with path_labels.open("r") as f:
    label_mapping = json.load(f)

# Evaluation-time preprocessing: no augmentation, ImageNet mean/std
# normalisation, then a fixed 224x224 resize.
transform_plain = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        transforms.Resize((224, 224), antialias=True),
    ]
)
dataset_train_plain = ImageFolder(path_dataset_train, transform=transform_plain)
dataset_val_plain = ImageFolder(path_dataset_val, transform=transform_plain)

if dataset_train_plain.classes != dataset_val_plain.classes:
    raise ValueError("Inconsistent classes")

# All three loaders share `batch_size`. Validation loaders are read in a
# fixed order: evaluation metrics do not depend on order, so shuffling only
# adds non-determinism.
batch_size = 256
data_loader_train_plain = DataLoader(
    dataset_train_plain,
    batch_size=batch_size,
    shuffle=True,
    drop_last=False,
    num_workers=n_workers,
)
data_loader_val_plain = DataLoader(
    dataset_val_plain,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
    num_workers=n_workers,
)
# Every 50th validation image, for quick checks.
data_loader_val_plain_subset = DataLoader(
    dataset_val_plain,
    batch_size=batch_size,
    drop_last=False,
    sampler=SubsetRandomSampler(list(range(0, len(dataset_val_plain), 50))),
    num_workers=n_workers,
)


In [None]:
# Display the validation ImageFolder summary (size, root, transforms).
dataset_val_plain

Dataset ImageFolder
    Number of datapoints: 3925
    Root location: /content/gdrive/MyDrive/imagenette2-320/val
    StandardTransform
Transform: Compose(
               ToTensor()
               Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
               Resize(size=(224, 224), interpolation=bilinear, max_size=None, antialias=True)
           )

In [None]:
class resnet(nn.Module):
    """Image classifier: torchvision ResNet-18, a ReLU, then a linear layer.

    The ResNet-18 keeps its own 1000-way output layer, so ``classifier``
    maps those 1000 values to ``num_classes`` logits.

    Parameters
    ----------
    num_classes : int
        Number of output classes.
    pretrained : bool
        If true, load torchvision's default pretrained weights for the
        backbone; otherwise it is randomly initialised.
    """

    def __init__(self, num_classes, pretrained=False):
        super().__init__()
        self.num_classes = num_classes
        backbone_weights = "DEFAULT" if pretrained else None
        self.encoder = torchvision.models.resnet18(weights=backbone_weights)
        self.activation = nn.ReLU()
        self.classifier = nn.Linear(1000, num_classes)

    def forward(self, x):
        """Map a batch of images to ``(batch, num_classes)`` logits."""
        features = self.activation(self.encoder(x))
        return self.classifier(features)

class vit_base(nn.Module):
    """Image classifier: timm ``vit_base_patch16_224``, a ReLU, then a linear layer.

    Parameters
    ----------
    num_classes : int
        Number of output classes.
    pretrained : bool
        Forwarded to ``timm.create_model`` (load pretrained weights or not).
    """

    def __init__(self, num_classes, pretrained=False):
        # Zero-argument super(): vit_base is not a subclass of resnet, so the
        # previous super(resnet, self) call raised TypeError.
        super().__init__()
        self.num_classes = num_classes

        self.encoder = timm.create_model('vit_base_patch16_224', pretrained=pretrained)

        self.activation = nn.ReLU()
        # Assumes the timm model keeps its default 1000-way head -- TODO confirm
        # if create_model is ever called with a different num_classes.
        self.classifier = nn.Linear(1000, num_classes)

    def forward(self, x):
        """Map a batch of images to ``(batch, num_classes)`` logits."""
        x = self.encoder(x)
        x = self.activation(x)
        x = self.classifier(x)

        return x

In [None]:
# Splitting dataset: reproducible 20% subsets of the train split (supervised
# training) and of the val split (evaluation). A dedicated seeded RNG keeps
# the subsets identical across kernel restarts; torch.manual_seed does not
# seed Python's `random` module.
subset_fraction = 0.2
split_rng = random.Random(0)

n_train = len(dataset_train_plain)
indices = list(range(n_train))
split_rng.shuffle(indices)

split = int(np.floor(subset_fraction * n_train))
train_indices = indices[:split]

n_test = len(dataset_val_plain)
indices = list(range(n_test))
split_rng.shuffle(indices)

split = int(np.floor(subset_fraction * n_test))
test_indices = indices[:split]

train_supervised_sampler = SubsetRandomSampler(train_indices)
test_supervised_sampler = SubsetRandomSampler(test_indices)

supervised_batch_size = 128
data_loader_train_supervised = DataLoader(
    dataset_train_plain,
    batch_size=supervised_batch_size,
    sampler=train_supervised_sampler,
    num_workers=n_workers,
)
data_loader_test_supervised = DataLoader(
    dataset_val_plain,
    batch_size=supervised_batch_size,
    sampler=test_supervised_sampler,
    num_workers=n_workers,
)



In [None]:
# Count the images the samplers actually draw. Multiplying the number of
# batches by the batch size over-counts because the last batch is partial.
print(f'Size of Training Dataset :{len(data_loader_train_supervised.sampler)}')
print(f'Size of Validation Dataset :{len(data_loader_test_supervised.sampler)}')

Size of Training Dataset :1920
Size of Validation Dataset :896


In [None]:
# Loss shared by the supervised experiments.
criterion = torch.nn.CrossEntropyLoss()

# Randomly initialised ResNet-18 classifier for the 10 Imagenette classes.
supervised_model = resnet(num_classes=10, pretrained=False).to(device)
optimizer_supervised = torch.optim.Adam(supervised_model.parameters(), lr=0.001)

In [None]:
def test_model_accuracy(test_loader, model, criterion):
    correct_predictions = 0
    total_samples = 0
    average_loss = 0
    model.eval()


    with torch.no_grad():
        for idx, (img, label) in enumerate(test_loader):
            img = img.to(device)
            label = label.to(device)


            output = model(img)
            loss = criterion(output, label)
            average_loss += loss.item()

            # Calculate the accuracy
            _, predicted_labels = output.max(1)
            correct_predictions += (predicted_labels == label).sum().item()

            total_samples += label.size(0)

    test_accuracy = correct_predictions / total_samples
    average_loss /= len(test_loader)
    print(f'Test Accuracy: {test_accuracy:.4f}')
    print(f'Average Test Loss : {average_loss:.4f}')
    return test_accuracy,average_loss

def train(train_loader, val_loader, model, optimizer, criterion, num_epochs):
    """Supervised training loop with an evaluation pass after every epoch.

    Parameters
    ----------
    train_loader, val_loader : iterable of (images, labels)
        Training and evaluation batches.
    model : nn.Module
        Classifier producing ``(batch, n_classes)`` logits.
    optimizer : torch.optim.Optimizer
        Optimizer over ``model``'s parameters.
    criterion : callable
        Loss taking ``(logits, labels)``.
    num_epochs : int
        Number of passes over ``train_loader``.

    Returns
    -------
    training_loss, training_accuracy, testing_accuracy, testing_loss : list of float
        One entry per epoch. Training loss is the mean of the per-batch losses;
        training accuracy is correct predictions / samples seen in that epoch;
        the testing values come from ``test_model_accuracy``.

    Raises
    ------
    ValueError
        If ``train_loader`` yields no samples.
    """
    # Batches go to wherever the model lives.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    training_loss = []
    training_accuracy = []
    testing_accuracy = []
    testing_loss = []
    for i in range(num_epochs):
        running_loss = 0.0
        correct_predictions = 0
        seen_samples = 0
        n_batches = 0
        model.train()
        for idx, (img, label) in enumerate(train_loader):
            bsz = label.shape[0]
            optimizer.zero_grad()
            img = img.to(device)
            label = label.to(device)
            prediction = model(img)

            # Calculate the loss
            loss = criterion(prediction, label)
            running_loss += loss.item()

            # Calculate the accuracy
            curr_predicted_labels = (prediction.argmax(dim=1) == label).sum().item()
            correct_predictions += curr_predicted_labels
            seen_samples += bsz
            n_batches += 1

            loss.backward()
            optimizer.step()

            if idx % 5 == 0:
                print(f"idx : {idx} , Current Loss(batch) : {loss.item()}, Correct Predictions(batch) : {curr_predicted_labels}/{bsz}")

        if seen_samples == 0:
            raise ValueError("train_loader yielded no samples")

        running_loss /= n_batches
        # Divide by the samples actually seen, not a fixed dataset size, so the
        # accuracy is correct for any loader/subset.
        epoch_accuracy = correct_predictions / seen_samples

        print(f'Epoch: {i}, Loss (per batch) : {running_loss:.4f}, Accuracy: {epoch_accuracy:.4f}')

        training_loss.append(running_loss)
        training_accuracy.append(epoch_accuracy)

        t, l = test_model_accuracy(val_loader, model, criterion)
        testing_accuracy.append(t)
        testing_loss.append(l)

    return training_loss, training_accuracy, testing_accuracy, testing_loss

# **TRAINING RESNET FROM SCRATCH**
- Training loss: 0.03
- Training accuracy: 99%
- Best test accuracy: 70%




In [None]:
# ResNet-18 trained from random initialisation on the 20% training subset.
criterion = torch.nn.CrossEntropyLoss()
supervised_model = resnet(num_classes=10, pretrained=False).to(device)
optimizer_supervised = torch.optim.Adam(
    supervised_model.parameters(), lr=0.00005, weight_decay=0.0001
)
a1, b1, c1, d1 = train(
    data_loader_train_supervised,
    data_loader_test_supervised,
    supervised_model,
    optimizer_supervised,
    criterion,
    50,
)


idx : 0 , Current Loss(batch) : 2.0837905406951904, Correct Predictions(batch) : 39/128
idx : 5 , Current Loss(batch) : 1.90829598903656, Correct Predictions(batch) : 41/128
idx : 10 , Current Loss(batch) : 1.9454395771026611, Correct Predictions(batch) : 40/128
Epoch: 0, Loss (per batch) : 1.9312, Accuracy: 0.3170
Test Accuracy: 0.1287
Average Test Loss : 4.8453
idx : 0 , Current Loss(batch) : 1.6776411533355713, Correct Predictions(batch) : 52/128
idx : 5 , Current Loss(batch) : 1.8509056568145752, Correct Predictions(batch) : 42/128
idx : 10 , Current Loss(batch) : 1.7008922100067139, Correct Predictions(batch) : 53/128
Epoch: 1, Loss (per batch) : 1.7340, Accuracy: 0.3941
Test Accuracy: 0.2599
Average Test Loss : 2.2175
idx : 0 , Current Loss(batch) : 1.5756869316101074, Correct Predictions(batch) : 61/128
idx : 5 , Current Loss(batch) : 1.5255651473999023, Correct Predictions(batch) : 54/128
idx : 10 , Current Loss(batch) : 1.4662233591079712, Correct Predictions(batch) : 66/128
E

# **FINE-TUNING USING IMAGENET PRETRAINED WEIGHTS**
- Training loss: 0.03
- Training accuracy: 100%
- Best test accuracy: 96.31%

In [None]:
# ResNet-18 fine-tuned from pretrained weights on the 20% training subset.
criterion = torch.nn.CrossEntropyLoss()
supervised_model = resnet(num_classes=10, pretrained=True).to(device)
optimizer_supervised = torch.optim.Adam(
    supervised_model.parameters(), lr=0.00005, weight_decay=0.0001
)
a, b, c, d = train(
    data_loader_train_supervised,
    data_loader_test_supervised,
    supervised_model,
    optimizer_supervised,
    criterion,
    50,
)

idx : 0 , Current Loss(batch) : 2.6880009174346924, Correct Predictions(batch) : 18/128
idx : 5 , Current Loss(batch) : 1.6161912679672241, Correct Predictions(batch) : 59/128
idx : 10 , Current Loss(batch) : 1.0312111377716064, Correct Predictions(batch) : 94/128
Epoch: 0, Loss (per batch) : 1.4793, Accuracy: 0.5388
Test Accuracy: 0.8955
Average Test Loss : 0.5593
idx : 0 , Current Loss(batch) : 0.4932309091091156, Correct Predictions(batch) : 116/128
idx : 5 , Current Loss(batch) : 0.30241039395332336, Correct Predictions(batch) : 123/128
idx : 10 , Current Loss(batch) : 0.3053094744682312, Correct Predictions(batch) : 119/128
Epoch: 1, Loss (per batch) : 0.3401, Accuracy: 0.9377
Test Accuracy: 0.9414
Average Test Loss : 0.2514
idx : 0 , Current Loss(batch) : 0.1299310177564621, Correct Predictions(batch) : 127/128
idx : 5 , Current Loss(batch) : 0.14610449969768524, Correct Predictions(batch) : 124/128
idx : 10 , Current Loss(batch) : 0.14258304238319397, Correct Predictions(batch) 

## **THE CELLS BELOW WERE LOST DUE TO A SYNCHRONISATION ERROR IN GOOGLE COLAB. THE CODE HAS BEEN REPRODUCED, AND THE KEY RESULTS FROM THE LOST OUTPUTS ARE SUMMARISED.**

## **TRAINING VIT FROM SCRATCH**


In [None]:
# ViT-Base/16 trained from random initialisation on the 20% training subset.
criterion = torch.nn.CrossEntropyLoss()
supervised_model = vit_base(num_classes=10, pretrained=False).to(device)
optimizer_supervised = torch.optim.Adam(
    supervised_model.parameters(), lr=0.00005, weight_decay=0.0001
)
a1, b1, c1, d1 = train(
    data_loader_train_supervised,
    data_loader_test_supervised,
    supervised_model,
    optimizer_supervised,
    criterion,
    50,
)

## VALIDATION ACCURACY : 52.35 %

## **TRAINING VIT USING PRETRAINED WEIGHTS (TRANSFER LEARNING)**

In [None]:
# ViT-Base/16 fine-tuned from pretrained weights on the 20% training subset.
criterion = torch.nn.CrossEntropyLoss()
supervised_model = vit_base(num_classes=10, pretrained=True).to(device)
optimizer_supervised = torch.optim.Adam(
    supervised_model.parameters(), lr=0.00005, weight_decay=0.0001
)
a1, b1, c1, d1 = train(
    data_loader_train_supervised,
    data_loader_test_supervised,
    supervised_model,
    optimizer_supervised,
    criterion,
    50,
)

## VALIDATION ACCURACY : 98.34%

# DINO code adapted from:
<ol>
  <li>https://github.com/jankrepl/mildlyoverfitted/tree/master/github_adventures/dino</li>
  <li>https://github.com/facebookresearch/dino</li>

</ol>

In [None]:
class DataAugmentation:
    """Multi-crop augmentation for DINO self-distillation.

    Every call returns two global crops followed by ``n_local_crops`` local
    crops of the same input image.

    Parameters
    ----------
    global_crops_scale : tuple
        (min, max) fraction of the image area covered by a global crop.

    local_crops_scale : tuple
        (min, max) fraction of the image area covered by a local crop.

    n_local_crops : int
        How many local crops to produce per image.

    size : int
        Output side length of every crop, global and local alike.

    Attributes
    ----------
    global_1, global_2 : transforms.Compose
        The two global pipelines; they differ in blur probability, and only
        ``global_2`` may solarize.

    local : transforms.Compose
        The local pipeline. It is stochastic, so one instance yields a
        different crop on every call.
    """

    def __init__(
        self,
        global_crops_scale=(0.4, 1),
        local_crops_scale=(0.05, 0.4),
        n_local_crops=8,
        size=224,
    ):
        self.n_local_crops = n_local_crops

        def random_crop(scale):
            # Bicubic random-resized crop to size x size.
            return transforms.RandomResizedCrop(
                size, scale=scale, interpolation=Image.BICUBIC
            )

        def maybe_blur(p):
            # Gaussian blur applied with probability p.
            return transforms.RandomApply(
                [transforms.GaussianBlur(kernel_size=5, sigma=(0.1, 2))], p=p
            )

        color_jitter = transforms.ColorJitter(
            brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1
        )
        # One Compose instance shared by all three pipelines.
        flip_and_jitter = transforms.Compose(
            [
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomApply([color_jitter]),
                transforms.RandomGrayscale(p=0.2),
            ]
        )
        normalize = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
            ]
        )

        self.global_1 = transforms.Compose(
            [
                random_crop(global_crops_scale),
                flip_and_jitter,
                maybe_blur(1.0),  # always blurred
                normalize,
            ],
        )
        self.global_2 = transforms.Compose(
            [
                random_crop(global_crops_scale),
                flip_and_jitter,
                maybe_blur(0.1),
                transforms.RandomSolarize(170, p=0.2),
                normalize,
            ],
        )
        self.local = transforms.Compose(
            [
                random_crop(local_crops_scale),
                flip_and_jitter,
                maybe_blur(0.5),
                normalize,
            ],
        )

    def __call__(self, img):
        """Produce every augmented view of ``img``.

        Parameters
        ----------
        img : PIL.Image
            Input image.

        Returns
        -------
        all_crops : list
            ``[global_1(img), global_2(img)]`` followed by ``n_local_crops``
            outputs of ``local(img)``.
        """
        global_views = [self.global_1(img), self.global_2(img)]
        local_views = [self.local(img) for _ in range(self.n_local_crops)]
        return global_views + local_views

class Head(nn.Module):
    """Projection head placed on top of the backbone embedding.

    An MLP (GELU between linear layers) down to ``bottleneck_dim``, then L2
    normalisation, then a weight-normalised linear layer to ``out_dim``.

    Parameters
    ----------
    in_dim : int
        Width of the incoming embedding.

    out_dim : int
        Width of the output (the dimension the softmax is taken over).

    hidden_dim : int
        Width of the hidden MLP layers.

    bottleneck_dim : int
        Width of the MLP output, i.e. the input of the last layer.

    n_layers : int
        Number of linear layers in the MLP. 1 gives a single ``nn.Linear``;
        values below 2 other than 1 behave like 2.

    norm_last_layer : bool
        If True, the weight-norm magnitude ``weight_g`` of the last layer is
        frozen at 1.

    Attributes
    ----------
    mlp : nn.Linear or nn.Sequential
        The projection MLP.

    last_layer : nn.Linear
        Linear layer reparametrised by weight normalisation, so its learnable
        parameters are ``weight_g`` and ``weight_v`` instead of ``weight``.
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        hidden_dim=2048,
        bottleneck_dim=256,
        n_layers=3,
        norm_last_layer=False,
    ):
        super().__init__()
        if n_layers == 1:
            self.mlp = nn.Linear(in_dim, bottleneck_dim)
        else:
            n_linear = max(n_layers, 2)
            widths = [in_dim] + [hidden_dim] * (n_linear - 1) + [bottleneck_dim]
            modules = []
            for fan_in, fan_out in zip(widths[:-1], widths[1:]):
                modules.append(nn.Linear(fan_in, fan_out))
                modules.append(nn.GELU())
            modules.pop()  # no activation after the bottleneck projection
            self.mlp = nn.Sequential(*modules)

        # Re-initialise the MLP only; last_layer is created afterwards.
        self.apply(self._init_weights)

        self.last_layer = nn.utils.weight_norm(
            nn.Linear(bottleneck_dim, out_dim, bias=False)
        )
        nn.init.ones_(self.last_layer.weight_g)
        if norm_last_layer:
            self.last_layer.weight_g.requires_grad_(False)

    def _init_weights(self, m):
        """Give every nn.Linear N(0, 0.02) weights and zero biases."""
        if not isinstance(m, nn.Linear):
            return
        nn.init.normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

    def forward(self, x):
        """Project embeddings to head outputs.

        Parameters
        ----------
        x : torch.Tensor
            Of shape ``(n_samples, in_dim)``.

        Returns
        -------
        torch.Tensor
            Of shape ``(n_samples, out_dim)``.
        """
        projected = self.mlp(x)  # (n_samples, bottleneck_dim)
        unit = nn.functional.normalize(projected, p=2, dim=-1)
        return self.last_layer(unit)  # (n_samples, out_dim)

class MultiCropWrapper(nn.Module):
    """Run a backbone and a head over several crops in as few passes as possible.

    Parameters
    ----------
    backbone : nn.Module
        Feature extractor, e.g. a timm VisionTransformer. Its ``head``
        attribute is replaced with ``nn.Identity`` so that it returns the
        embedding instead of class logits.
        NOTE(review): torchvision ResNets name their classifier ``fc``, not
        ``head``. For the ResNet-18 backbones used in this notebook, the
        assignment therefore only adds an unused Identity, and the backbone's
        1000-d output is what ``new_head`` receives (consistent with
        ``dim = 1000``).

    new_head : nn.Module
        Head (e.g. ``Head``) applied on top of the backbone output.
    """

    def __init__(self, backbone, new_head):
        super().__init__()
        backbone.head = nn.Identity()  # deactivate original head (timm ViT)
        self.backbone = backbone
        self.new_head = new_head

    def forward(self, x):
        """Run the forward pass over a list of crops.

        Consecutive crops with the same spatial size are concatenated along
        the batch dimension and share one backbone pass. Crops of different
        resolutions (e.g. small local crops) are therefore supported, as long
        as the backbone accepts them. All embeddings go through the head in
        one call, and the result is split back into one tensor per crop.

        Parameters
        ----------
        x : list
            List of ``torch.Tensor`` each of shape ``(n_samples, 3, h, w)``;
            every crop must have the same ``n_samples``.

        Returns
        -------
        tuple
            Tuple of ``torch.Tensor`` each of shape ``(n_samples, out_dim)``,
            where ``out_dim`` is determined by the head.
        """
        n_crops = len(x)

        # Start index of each run of equally sized crops, plus the end.
        boundaries = [0]
        for idx in range(1, n_crops):
            if x[idx].shape[-2:] != x[idx - 1].shape[-2:]:
                boundaries.append(idx)
        boundaries.append(n_crops)

        embeddings = [
            self.backbone(torch.cat(x[start:end], dim=0))  # (n_samples * run_len, in_dim)
            for start, end in zip(boundaries[:-1], boundaries[1:])
        ]
        cls_embedding = torch.cat(embeddings, dim=0)  # (n_samples * n_crops, in_dim)
        logits = self.new_head(cls_embedding)  # (n_samples * n_crops, out_dim)
        return logits.chunk(n_crops)  # n_crops * (n_samples, out_dim)

class Loss(nn.Module):
    """DINO cross-entropy between teacher and student crop distributions.

    Subclasses ``nn.Module`` so that the running centre of the teacher
    logits can be kept as a registered buffer.

    Parameters
    ----------
    out_dim : int
        Width of the logits the softmax is taken over.

    teacher_temp, student_temp : float
        Softmax temperatures of the teacher and the student.

    center_momentum : float
        Exponential-moving-average factor for the centre; the higher it is,
        the more the running average outweighs the current batch.
    """

    def __init__(
        self, out_dim, teacher_temp=0.04, student_temp=0.1, center_momentum=0.9
    ):
        super().__init__()
        self.student_temp = student_temp
        self.teacher_temp = teacher_temp
        self.center_momentum = center_momentum
        self.register_buffer("center", torch.zeros(1, out_dim))

    def forward(self, student_output, teacher_output):
        """Average cross-entropy over every (teacher crop, student crop) pair
        whose crop indices differ; also updates the centre.

        Parameters
        ----------
        student_output, teacher_output : tuple
            Tuples of logits tensors of shape ``(n_samples, out_dim)``, one
            per crop. The student sees all crops; the teacher's crops line up
            with the student's first ones (the global crops).

        Returns
        -------
        loss : torch.Tensor
            Scalar average loss.
        """
        student_log_probs = [
            F.log_softmax(logits / self.student_temp, dim=-1)
            for logits in student_output
        ]
        # Centre and sharpen the teacher; it is a fixed target (no gradient).
        teacher_probs = [
            F.softmax((logits - self.center) / self.teacher_temp, dim=-1).detach()
            for logits in teacher_output
        ]

        pair_losses = [
            torch.sum(-p_teacher * log_p_student, dim=-1).mean()
            for t_ix, p_teacher in enumerate(teacher_probs)
            for s_ix, log_p_student in enumerate(student_log_probs)
            if t_ix != s_ix  # never compare a crop with itself
        ]
        total_loss = sum(pair_losses) / len(pair_losses)

        self.update_center(teacher_output)
        return total_loss

    @torch.no_grad()
    def update_center(self, teacher_output):
        """EMA update of ``center`` from this batch's mean teacher logits.

        Parameters
        ----------
        teacher_output : tuple
            Tuple of tensors of shape ``(n_samples, out_dim)``, one per crop.
        """
        batch_center = torch.cat(teacher_output).mean(dim=0, keepdim=True)  # (1, out_dim)
        momentum = self.center_momentum
        self.center = self.center * momentum + batch_center * (1 - momentum)


def clip_gradients(model, clip=2.0):
    """Scale down, parameter by parameter, every gradient whose L2 norm exceeds ``clip``.

    Each parameter's gradient is clipped against its own norm, not a global
    norm. Parameters without a gradient are skipped.

    Parameters
    ----------
    model : nn.Module
        Module whose ``.grad`` tensors are rescaled in place.

    clip : float
        Maximum allowed per-parameter gradient norm.
    """
    grads = [param.grad for param in model.parameters() if param.grad is not None]
    for grad in grads:
        coef = clip / (grad.data.norm(2) + 1e-6)
        if coef < 1:
            grad.data.mul_(coef)

def compute_knn(backbone, data_loader_train, data_loader_val, n_neighbors=5):
    """Fit a KNN classifier on training embeddings and score it on validation.

    All embeddings are held in memory and classified with scikit-learn.

    Parameters
    ----------
    backbone : nn.Module
        Embedding network, e.g. a vision transformer whose head is an
        identity mapping, or the DINO student's backbone.

    data_loader_train, data_loader_val : torch.utils.data.DataLoader
        Training and validation loaders yielding ``(images, labels)`` without
        augmentation (tensor conversion and normalisation only).

    n_neighbors : int
        Number of neighbours for ``KNeighborsClassifier``. The default of 5 is
        scikit-learn's default, which this function used before.

    Returns
    -------
    val_accuracy : float
        Validation accuracy.
    """
    device = next(backbone.parameters()).device

    data_loaders = {
        "train": data_loader_train,
        "val": data_loader_val,
    }
    lists = {
        "X_train": [],
        "y_train": [],
        "X_val": [],
        "y_val": [],
    }

    # Pure inference. Without no_grad, autograd records a graph for every
    # batch whenever the backbone's parameters require grad (the DINO student).
    with torch.no_grad():
        for name, data_loader in data_loaders.items():
            for imgs, y in data_loader:
                imgs = imgs.to(device)
                lists[f"X_{name}"].append(backbone(imgs).cpu().numpy())
                lists[f"y_{name}"].append(y.cpu().numpy())

    arrays = {k: np.concatenate(l) for k, l in lists.items()}

    estimator = KNeighborsClassifier(n_neighbors=n_neighbors)
    estimator.fit(arrays["X_train"], arrays["y_train"])
    y_val_pred = estimator.predict(arrays["X_val"])

    return accuracy_score(arrays["y_val"], y_val_pred)

def compute_embedding(
    backbone,
    data_loader,
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
):
    """Compute backbone embeddings and un-normalised images for TensorBoard.

    Parameters
    ----------
    backbone : nn.Module
        Embedding network (e.g. a vision transformer whose head is an
        identity mapping).

    data_loader : torch.utils.data.DataLoader
        Loader yielding ``(images, labels)`` without augmentation (tensor
        conversion and normalisation only). Its dataset must expose a
        ``classes`` list, as ``ImageFolder`` does.

    mean, std : sequence of float
        Per-channel statistics used by the loader's ``Normalize``. They are
        inverted exactly, so the returned images are back in the original
        pixel range. The defaults match the ImageNet values used in this
        notebook.

    Returns
    -------
    embs : torch.Tensor
        Embeddings of shape ``(n_samples, emb_dim)``, on CPU.

    imgs : torch.Tensor
        Images of shape ``(n_samples, channels, height, width)`` with the
        normalisation undone, on CPU.

    labels : list
        Class names (strings), one per sample.
    """
    device = next(backbone.parameters()).device
    # Broadcast over (batch, channel, height, width).
    mean_t = torch.tensor(mean, device=device).view(1, -1, 1, 1)
    std_t = torch.tensor(std, device=device).view(1, -1, 1, 1)
    classes = data_loader.dataset.classes

    embs_l = []
    imgs_l = []
    labels = []

    # Inference only: no autograd graph is needed for the embeddings.
    with torch.no_grad():
        for img, y in data_loader:
            img = img.to(device)
            embs_l.append(backbone(img).cpu())
            imgs_l.append((img * std_t + mean_t).cpu())  # exact inverse of Normalize
            labels.extend(classes[i] for i in y.tolist())

    embs = torch.cat(embs_l, dim=0)
    imgs = torch.cat(imgs_l, dim=0)

    return embs, imgs, labels

In [None]:
batch_size = 64
# Fall back to CPU instead of crashing when no GPU is attached.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging_freq = 400
n_crops = 4  # 2 global crops + (n_crops - 2) local crops per image
momentum_teacher = 0.995  # EMA factor for the teacher weights
n_epochs = 40
dim = 1000  # width of the backbone output fed into Head
out_dim = 8192  # Head output width (softmax dimension of the DINO loss)
clip_grad = 2.0  # per-parameter gradient-norm limit
norm_last_layer = True
teacher_temp = 0.04
student_temp = 0.1
weight_decay = 0.0001

In [None]:
# Multi-crop DINO views of the training split.
batch_size = 32
transform_aug = DataAugmentation(size=224, n_local_crops=n_crops - 2)
dataset_train_aug = ImageFolder(path_dataset_train, transform=transform_aug)
data_loader_train_aug = DataLoader(
    dataset_train_aug,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,  # every batch holds exactly batch_size images
    num_workers=n_workers,
    pin_memory=True,
)


## STUDENT AND TEACHER WITH RESNET18 AS BACKBONE



In [None]:
def _make_dino_network(backbone):
    """Wrap ``backbone`` with a freshly initialised DINO Head from the config cell."""
    return MultiCropWrapper(
        backbone,
        Head(
            dim,
            out_dim,
            norm_last_layer=norm_last_layer,  # honour the config value
            hidden_dim=512,
            bottleneck_dim=256,
        ),
    )


# Backbones are created before the heads, as before, so the random
# initialisation order is unchanged.
student_cnn = torchvision.models.resnet18()
teacher_cnn = torchvision.models.resnet18()
student = _make_dino_network(student_cnn).to(device)
teacher = _make_dino_network(teacher_cnn).to(device)

# The teacher starts as an exact copy of the student and is then updated only
# by the EMA step in the training loop, never by gradients.
teacher.load_state_dict(student.state_dict())
teacher.requires_grad_(False)



In [None]:
# DINO loss; it keeps the running centre of the teacher logits as a buffer.
loss_inst = Loss(
    out_dim,
    teacher_temp=teacher_temp,
    student_temp=student_temp,
).to(device)

# Linear learning-rate scaling with the batch size (0.0005 per 256 images).
lr = 0.0005 * batch_size / 256
optimizer = torch.optim.AdamW(
    student.parameters(),
    lr=lr,
    weight_decay=weight_decay,
)

# Training-loop bookkeeping.
n_steps = 0
best_acc = 0
n_batches = len(dataset_train_aug) // batch_size

In [None]:
save_best_path = '/content/gdrive/MyDrive/imagenette2-320/resnet_weights/best_model.pth'
# torch.save does not create missing directories; make sure the checkpoint
# folder exists before training starts writing into it.
pathlib.Path(save_best_path).parent.mkdir(parents=True, exist_ok=True)
best_loss = 1e3
n_epochs = 50

In [None]:
# For a ViT backbone, set save_best_path to
# '/content/gdrive/MyDrive/imagenette2-320/VIT_weights/best_model.pth' instead.
dino_loss = []
KNN_accuracy = []


def train_Dino(best_loss, n_epochs, start_epoch, student, teacher, optimizer, loss_inst, validate_every=5):
    """Run DINO self-distillation for ``n_epochs`` epochs.

    For each batch, the teacher sees the two global crops and the student
    sees every crop. The DINO loss is back-propagated through the student
    with per-parameter gradient clipping, and the teacher is then updated as
    an EMA of the student with factor ``momentum_teacher``.

    Every ``validate_every`` epochs, counted on the absolute epoch index
    (``start_epoch`` included), the student is checkpointed and its backbone
    is evaluated with a KNN classifier.

    Reads the module-level ``data_loader_train_aug``, ``data_loader_train_plain``,
    ``data_loader_val_plain``, ``device``, ``clip_grad``, ``momentum_teacher``
    and ``save_best_path``. Appends every batch loss to ``dino_loss`` and every
    KNN accuracy to ``KNN_accuracy``.

    Parameters
    ----------
    best_loss : float
        Lowest batch loss seen so far. Whenever a batch beats it, the student
        weights are written to ``save_best_path``.
    n_epochs : int
        Number of epochs to run in this call.
    start_epoch : int
        Absolute index of the first epoch (non-zero when resuming). It is used
        for logging, the validation schedule and checkpoint file names.
    student, teacher : MultiCropWrapper
        Student (trained by ``optimizer``) and teacher (EMA copy).
    optimizer : torch.optim.Optimizer
        Optimizer over the student's parameters.
    loss_inst : Loss
        DINO loss module.
    validate_every : int
        Validation / checkpoint period in epochs.

    Returns
    -------
    nn.Module
        The trained student.
    """
    # Periodic checkpoints go next to the best-loss checkpoint.
    checkpoint_dir = pathlib.Path(save_best_path).parent
    for e in range(n_epochs):
        epoch = e + start_epoch
        student.train()
        for i, (images, _) in enumerate(data_loader_train_aug):

            images = [img.to(device) for img in images]

            teacher_output = teacher(images[:2])  # global crops only
            student_output = student(images)

            loss = loss_inst(student_output, teacher_output)
            loss_value = loss.item()  # single host sync per step
            dino_loss.append(loss_value)

            optimizer.zero_grad()
            loss.backward()
            clip_gradients(student, clip_grad)
            optimizer.step()

            if loss_value < best_loss:
                torch.save(student.state_dict(), save_best_path)
                best_loss = loss_value

            if i % 100 == 0:
                print(f"  epoch : {epoch}  i : {i}  Loss : {loss_value}")

            # EMA update: teacher <- m * teacher + (1 - m) * student.
            with torch.no_grad():
                for student_ps, teacher_ps in zip(student.parameters(), teacher.parameters()):
                    teacher_ps.mul_(momentum_teacher)
                    teacher_ps.add_((1 - momentum_teacher) * student_ps)

        if epoch % validate_every == 0:
            # Name checkpoints by absolute epoch, so a resumed run does not
            # overwrite the checkpoints written by earlier runs.
            save_path = checkpoint_dir / f"model_{epoch}.pth"
            torch.save(student.state_dict(), save_path)
            student.eval()
            with torch.no_grad():  # evaluation only; no autograd graph needed
                current_acc = compute_knn(student.backbone, data_loader_train_plain, data_loader_val_plain)
            KNN_accuracy.append(current_acc)
            print(f"epoch :{epoch} KNN Accuracy : {current_acc}")

    return student

In [None]:
# ResNet DINO training starting at epoch 0 for n_epochs epochs.
start_epoch = 0
n_epochs = 50
student = train_Dino(
    best_loss=best_loss,
    n_epochs=n_epochs,
    start_epoch=start_epoch,
    student=student,
    teacher=teacher,
    optimizer=optimizer,
    loss_inst=loss_inst,
)

  epoch : 0  i : 0  Loss : 9.02420711517334
  epoch : 0  i : 100  Loss : 8.976109504699707
epoch :0 KNN Accuracy : 0.3543949044585987
  epoch : 1  i : 0  Loss : 8.952603340148926
  epoch : 1  i : 100  Loss : 8.882328987121582
  epoch : 2  i : 0  Loss : 8.82746410369873
  epoch : 2  i : 100  Loss : 8.526862144470215
  epoch : 3  i : 0  Loss : 8.311239242553711
  epoch : 3  i : 100  Loss : 8.201010704040527
  epoch : 4  i : 0  Loss : 7.843644618988037
  epoch : 4  i : 100  Loss : 7.473136901855469
  epoch : 5  i : 0  Loss : 8.027811050415039
  epoch : 5  i : 100  Loss : 7.8927321434021
epoch :5 KNN Accuracy : 0.3819108280254777
  epoch : 6  i : 0  Loss : 7.878645896911621
  epoch : 6  i : 100  Loss : 7.714883327484131
  epoch : 7  i : 0  Loss : 7.267729759216309
  epoch : 7  i : 100  Loss : 7.498587608337402
  epoch : 8  i : 0  Loss : 7.465859889984131
  epoch : 8  i : 100  Loss : 7.473573684692383
  epoch : 9  i : 0  Loss : 7.3529510498046875
  epoch : 9  i : 100  Loss : 7.0195379257202

FileNotFoundError: ignored

In [None]:
# Resume ResNet DINO training, logging epochs from 42 onwards.
# NOTE(review): train_Dino does not return its running best loss, so `best_loss` is
# still the notebook-level value (1e3 here) and the first batch will overwrite best_model.pth.
start_epoch = 42
n_epochs = 20
student = train_Dino(
    best_loss=best_loss,
    n_epochs=n_epochs,
    start_epoch=start_epoch,
    student=student,
    teacher=teacher,
    optimizer=optimizer,
    loss_inst=loss_inst,
)

  epoch : 42  i : 0  Loss : 3.9818437099456787
  epoch : 42  i : 100  Loss : 4.130667686462402
epoch :42 KNN Accuracy : 0.6878980891719745
  epoch : 43  i : 0  Loss : 3.9896960258483887
  epoch : 43  i : 100  Loss : 4.026087284088135
  epoch : 44  i : 0  Loss : 4.018218994140625
  epoch : 44  i : 100  Loss : 4.055819988250732
  epoch : 45  i : 0  Loss : 4.135601043701172
  epoch : 45  i : 100  Loss : 4.053676605224609
  epoch : 46  i : 0  Loss : 3.822122097015381
  epoch : 46  i : 100  Loss : 4.155665874481201
  epoch : 47  i : 0  Loss : 3.6517083644866943
  epoch : 47  i : 100  Loss : 3.9256253242492676
epoch :47 KNN Accuracy : 0.6993630573248407
  epoch : 48  i : 0  Loss : 4.020618438720703
  epoch : 48  i : 100  Loss : 3.6503639221191406
  epoch : 49  i : 0  Loss : 3.4848246574401855
  epoch : 49  i : 100  Loss : 3.5844075679779053
  epoch : 50  i : 0  Loss : 3.855245590209961
  epoch : 50  i : 100  Loss : 3.5310707092285156
  epoch : 51  i : 0  Loss : 3.7953052520751953
  epoch : 5

In [None]:
# Resume from epoch 72 with a fresh AdamW: base LR 1e-4 per 256 images, weight decay 1e-4.
# Re-creating the optimizer discards the previous Adam moment estimates.
start_epoch = 72
n_epochs = 30
lr = 0.0001 * batch_size / 256
optimizer = torch.optim.AdamW(student.parameters(), lr=lr, weight_decay=1e-4)
student = train_Dino(
    best_loss=best_loss,
    n_epochs=n_epochs,
    start_epoch=start_epoch,
    student=student,
    teacher=teacher,
    optimizer=optimizer,
    loss_inst=loss_inst,
)

  epoch : 72  i : 0  Loss : 3.1728665828704834
  epoch : 72  i : 100  Loss : 3.463533878326416
epoch :72 KNN Accuracy : 0.7230573248407643
  epoch : 73  i : 0  Loss : 2.776732921600342
  epoch : 73  i : 100  Loss : 3.240694046020508
  epoch : 74  i : 0  Loss : 3.0137243270874023
  epoch : 74  i : 100  Loss : 3.3743350505828857
  epoch : 75  i : 0  Loss : 3.2060632705688477
  epoch : 75  i : 100  Loss : 2.914559841156006
  epoch : 76  i : 0  Loss : 3.091904401779175
  epoch : 76  i : 100  Loss : 3.1865782737731934
  epoch : 77  i : 0  Loss : 3.4295878410339355
  epoch : 77  i : 100  Loss : 3.1751461029052734
epoch :77 KNN Accuracy : 0.7319745222929936
  epoch : 78  i : 0  Loss : 3.029054641723633
  epoch : 78  i : 100  Loss : 3.1192684173583984
  epoch : 79  i : 0  Loss : 3.183594226837158
  epoch : 79  i : 100  Loss : 3.3319945335388184
  epoch : 80  i : 0  Loss : 2.818366765975952
  epoch : 80  i : 100  Loss : 2.9077746868133545
  epoch : 81  i : 0  Loss : 3.204420566558838
  epoch : 

In [None]:
# Resume from epoch 102 with a fresh AdamW: base LR 1e-3 per 256 images, weight decay 0.1.
start_epoch = 102
n_epochs = 30
lr = 0.001 * batch_size / 256
optimizer = torch.optim.AdamW(student.parameters(), lr=lr, weight_decay=0.1)
student = train_Dino(
    best_loss=best_loss,
    n_epochs=n_epochs,
    start_epoch=start_epoch,
    student=student,
    teacher=teacher,
    optimizer=optimizer,
    loss_inst=loss_inst,
)

  epoch : 102  i : 0  Loss : 2.9144911766052246
  epoch : 102  i : 100  Loss : 2.9242312908172607
epoch :102 KNN Accuracy : 0.7388535031847133
  epoch : 103  i : 0  Loss : 2.8352274894714355
  epoch : 103  i : 100  Loss : 3.088395595550537
  epoch : 104  i : 0  Loss : 3.242990016937256
  epoch : 104  i : 100  Loss : 2.7682299613952637
  epoch : 105  i : 0  Loss : 2.7257182598114014
  epoch : 105  i : 100  Loss : 2.7173056602478027
  epoch : 106  i : 0  Loss : 2.6286368370056152
  epoch : 106  i : 100  Loss : 2.799177646636963
  epoch : 107  i : 0  Loss : 2.78991961479187
  epoch : 107  i : 100  Loss : 2.737703323364258
epoch :107 KNN Accuracy : 0.7541401273885351
  epoch : 108  i : 0  Loss : 2.8545520305633545
  epoch : 108  i : 100  Loss : 2.81748628616333
  epoch : 109  i : 0  Loss : 2.494121551513672
  epoch : 109  i : 100  Loss : 2.9271163940429688
  epoch : 110  i : 0  Loss : 2.7229373455047607
  epoch : 110  i : 100  Loss : 2.977524757385254
  epoch : 111  i : 0  Loss : 2.6828703

In [None]:
# Resume from epoch 132 with the same settings as the epoch-102 run
# (fresh AdamW, base LR 1e-3 per 256 images, weight decay 0.1).
start_epoch = 132
n_epochs = 30
lr = 0.001 * batch_size / 256
optimizer = torch.optim.AdamW(student.parameters(), lr=lr, weight_decay=0.1)
student = train_Dino(
    best_loss=best_loss,
    n_epochs=n_epochs,
    start_epoch=start_epoch,
    student=student,
    teacher=teacher,
    optimizer=optimizer,
    loss_inst=loss_inst,
)

  epoch : 132  i : 0  Loss : 2.29766845703125
  epoch : 132  i : 100  Loss : 2.1402668952941895
epoch :132 KNN Accuracy : 0.7747770700636942
  epoch : 133  i : 0  Loss : 2.4085488319396973
  epoch : 133  i : 100  Loss : 2.331089496612549
  epoch : 134  i : 0  Loss : 2.3805415630340576
  epoch : 134  i : 100  Loss : 2.671372890472412
  epoch : 135  i : 0  Loss : 2.4093151092529297
  epoch : 135  i : 100  Loss : 2.1754777431488037
  epoch : 136  i : 0  Loss : 2.232269525527954
  epoch : 136  i : 100  Loss : 2.4026126861572266
  epoch : 137  i : 0  Loss : 2.3166816234588623
  epoch : 137  i : 100  Loss : 2.2506725788116455
epoch :137 KNN Accuracy : 0.781656050955414
  epoch : 138  i : 0  Loss : 2.136922597885132
  epoch : 138  i : 100  Loss : 2.4002652168273926
  epoch : 139  i : 0  Loss : 2.1796789169311523
  epoch : 139  i : 100  Loss : 2.3244190216064453
  epoch : 140  i : 0  Loss : 2.2618346214294434
  epoch : 140  i : 100  Loss : 2.2655889987945557
  epoch : 141  i : 0  Loss : 2.3687

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x797218c46cb0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1478, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1461, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x797218c46cb0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1478, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1461, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/

epoch :142 KNN Accuracy : 0.7806369426751593
  epoch : 143  i : 0  Loss : 2.0087318420410156
  epoch : 143  i : 100  Loss : 2.3552968502044678
  epoch : 144  i : 0  Loss : 2.1430976390838623
  epoch : 144  i : 100  Loss : 2.419816017150879
  epoch : 145  i : 0  Loss : 2.413015842437744
  epoch : 145  i : 100  Loss : 2.2341041564941406
  epoch : 146  i : 0  Loss : 2.066159963607788
  epoch : 146  i : 100  Loss : 2.2251124382019043
  epoch : 147  i : 0  Loss : 2.216953992843628
  epoch : 147  i : 100  Loss : 2.4136812686920166


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x797218c46cb0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1478, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1461, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x797218c46cb0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1478, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1461, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/

epoch :147 KNN Accuracy : 0.7921019108280255
  epoch : 148  i : 0  Loss : 2.2456014156341553
  epoch : 148  i : 100  Loss : 2.2208328247070312
  epoch : 149  i : 0  Loss : 2.263960838317871
  epoch : 149  i : 100  Loss : 2.1827526092529297
  epoch : 150  i : 0  Loss : 1.876591682434082
  epoch : 150  i : 100  Loss : 2.3720602989196777
  epoch : 151  i : 0  Loss : 2.2960333824157715
  epoch : 151  i : 100  Loss : 2.3225386142730713
  epoch : 152  i : 0  Loss : 2.177553176879883
  epoch : 152  i : 100  Loss : 2.207897901535034


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x797218c46cb0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1478, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1461, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x797218c46cb0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1478, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1461, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/

epoch :152 KNN Accuracy : 0.793375796178344
  epoch : 153  i : 0  Loss : 1.8615624904632568
  epoch : 153  i : 100  Loss : 2.1489360332489014
  epoch : 154  i : 0  Loss : 1.9984210729599
  epoch : 154  i : 100  Loss : 1.9416263103485107
  epoch : 155  i : 0  Loss : 2.5671677589416504
  epoch : 155  i : 100  Loss : 2.025700092315674
  epoch : 156  i : 0  Loss : 1.8957937955856323
  epoch : 156  i : 100  Loss : 2.2548880577087402
  epoch : 157  i : 0  Loss : 2.029087781906128
  epoch : 157  i : 100  Loss : 2.549213171005249


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x797218c46cb0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1478, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1461, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x797218c46cb0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1478, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1461, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/

epoch :157 KNN Accuracy : 0.7882802547770701
  epoch : 158  i : 0  Loss : 2.089461326599121


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x797218c46cb0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1478, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1461, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x797218c46cb0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1478, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1461, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/

  epoch : 158  i : 100  Loss : 2.0968079566955566
  epoch : 159  i : 0  Loss : 1.997278094291687


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x797218c46cb0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1478, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1461, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x797218c46cb0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1478, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1461, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/

  epoch : 159  i : 100  Loss : 2.4767284393310547
  epoch : 160  i : 0  Loss : 2.3531360626220703


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x797218c46cb0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1478, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1461, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x797218c46cb0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1478, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1461, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/

  epoch : 160  i : 100  Loss : 2.0295796394348145
  epoch : 161  i : 0  Loss : 2.374478816986084


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x797218c46cb0>
Exception ignored in: Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1478, in __del__
<function _MultiProcessingDataLoaderIter.__del__ at 0x797218c46cb0>    
self._shutdown_workers()Traceback (most recent call last):

  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1478, in __del__
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1461, in _shutdown_workers
        self._shutdown_workers()if w.is_alive():

  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1461, in _shutdown_workers
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
        assert self._parent_pid == os.getpid(), 'can only test a child process'if w.is_alive():

  File "/usr/lib/python3.10/multiprocessing/process.py", line 1

  epoch : 161  i : 100  Loss : 2.0741348266601562


In [None]:
# Restore the lowest-loss ResNet student. map_location lets a checkpoint saved on GPU
# load onto whatever `device` is currently active.
curr_path = '/content/gdrive/MyDrive/imagenette2-320/resnet_weights/best_model.pth'
student.load_state_dict(torch.load(curr_path, map_location=device))

<All keys matched successfully>

In [None]:
# Reload the best ResNet student and report its k-NN accuracy on the plain loaders.
curr_path = '/content/gdrive/MyDrive/imagenette2-320/resnet_weights/best_model.pth'
student.load_state_dict(torch.load(curr_path, map_location=device))
student.eval()  # same mode train_Dino uses before calling compute_knn
current_acc = compute_knn(student.backbone, data_loader_train_plain, data_loader_val_plain)
print("KNN Accuracy :", current_acc)

KNN Accuracy : 0.7903184713375796


### ViT as backbone

In [None]:
# DeiT-small (distilled) student/teacher, trained from scratch. Both classification
# heads are replaced with Identity so the backbone outputs `dim`-dimensional features.
vit_name, dim = "deit_small_distilled_patch16_224", 384
student_vit = timm.create_model(vit_name, pretrained=False)
teacher_vit = timm.create_model(vit_name, pretrained=False)
student_vit.head = nn.Identity()
student_vit.head_dist = nn.Identity()
teacher_vit.head = nn.Identity()
teacher_vit.head_dist = nn.Identity()

# The head now honours the `norm_last_layer` config value instead of a literal True.
student = MultiCropWrapper(
    student_vit,
    Head(dim, out_dim, norm_last_layer=norm_last_layer, hidden_dim=512, bottleneck_dim=256),
)
teacher = MultiCropWrapper(
    teacher_vit,
    Head(dim, out_dim, norm_last_layer=norm_last_layer, hidden_dim=512, bottleneck_dim=256),
)
student, teacher = student.to(device), teacher.to(device)

# Teacher starts identical to the student and is updated only by EMA in train_Dino.
teacher.load_state_dict(student.state_dict())

for p in teacher.parameters():
    p.requires_grad = False



In [None]:
# Plain (non-augmented) loaders passed to compute_knn during ViT training, with batch 64 and 4 workers.
data_loader_train_plain = DataLoader(dataset_train_plain, batch_size=64, shuffle=True, drop_last=False, num_workers=4)
# Evaluation data needs no shuffling; a fixed order keeps the extracted embeddings reproducible.
data_loader_val_plain = DataLoader(dataset_val_plain, batch_size=64, shuffle=False, drop_last=False, num_workers=4)


In [None]:
# Loss related
loss_inst = Loss(out_dim, teacher_temp=teacher_temp,student_temp=student_temp,).to(device)
batch_size = 128
lr = 0.001 * batch_size / 256
optimizer = torch.optim.AdamW(student.parameters(),lr=lr,weight_decay=weight_decay)

# Training loop
n_batches = len(dataset_train_aug) //batch_size
best_acc = 0

In [None]:
# Baseline: k-NN accuracy of the randomly initialised ViT backbone before DINO training.
# Use eval mode, as train_Dino does before every compute_knn call.
student.eval()
compute_knn(student.backbone, data_loader_train_plain, data_loader_val_plain)

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7ececd22b1c0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1478, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1461, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7ececd22b1c0>
Traceback (most recent call last):
Exception ignored in:   File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1478, in __del__
<function _MultiProcessingDataLoaderIter.__del__ at 0x7ececd22b1c0>    
self._shutdown_workers()Traceback (most recent call last):

  File "/usr/local/lib/pyt

0.292484076433121

In [None]:
# First ViT DINO run (epochs 0..49); best-loss checkpoints go to VIT_weights/best_model.pth.
# NOTE(review): `batch_size` is the last notebook-level value; confirm it matches
# data_loader_train_aug's batch size so the LR scaling is meaningful.
save_best_path = '/content/gdrive/MyDrive/imagenette2-320/VIT_weights/best_model.pth'
best_loss = 1e3
start_epoch = 0
n_epochs = 50
lr = 0.001 * batch_size / 256
optimizer = torch.optim.AdamW(student.parameters(), lr=lr, weight_decay=0.1)
student = train_Dino(
    best_loss=best_loss,
    n_epochs=n_epochs,
    start_epoch=start_epoch,
    student=student,
    teacher=teacher,
    optimizer=optimizer,
    loss_inst=loss_inst,
    validate_every=10,
)

  epoch : 0  i : 0  Loss : 8.747684478759766
  epoch : 0  i : 100  Loss : 8.750885009765625
  epoch : 0  i : 200  Loss : 8.641149520874023
epoch :0 KNN Accuracy : 0.2835668789808917
  epoch : 1  i : 0  Loss : 8.653779029846191
  epoch : 1  i : 100  Loss : 8.589271545410156
  epoch : 1  i : 200  Loss : 8.376199722290039
  epoch : 2  i : 0  Loss : 8.296825408935547
  epoch : 2  i : 100  Loss : 8.493345260620117
  epoch : 2  i : 200  Loss : 8.406209945678711
  epoch : 3  i : 0  Loss : 8.287588119506836
  epoch : 3  i : 100  Loss : 8.214757919311523
  epoch : 3  i : 200  Loss : 7.724208354949951
  epoch : 4  i : 0  Loss : 7.922868728637695
  epoch : 4  i : 100  Loss : 8.142062187194824
  epoch : 4  i : 200  Loss : 8.096233367919922
  epoch : 5  i : 0  Loss : 7.5859503746032715
  epoch : 5  i : 100  Loss : 7.683028697967529
  epoch : 5  i : 200  Loss : 7.212411403656006
  epoch : 6  i : 0  Loss : 7.3705058097839355
  epoch : 6  i : 100  Loss : 7.771353721618652
  epoch : 6  i : 200  Loss : 

In [None]:
# Rebuild the multi-crop training loader for the ViT run with 4 worker processes.
batch_size = 32
transform_aug = DataAugmentation(size=224, n_local_crops=n_crops - 2)
dataset_train_aug = ImageFolder(path_dataset_train, transform=transform_aug)
data_loader_train_aug = DataLoader(
    dataset_train_aug,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=4,
    pin_memory=True,
)

In [None]:
# Restore the lowest-loss ViT student; map_location makes the load device-agnostic.
save_best_path = '/content/gdrive/MyDrive/imagenette2-320/VIT_weights/best_model.pth'
student.load_state_dict(torch.load(save_best_path, map_location=device))

<All keys matched successfully>

In [None]:
# Resume ViT training from epoch 50; only batches with loss below 5 replace the best checkpoint.
start_epoch = 50
n_epochs = 50
best_loss = 5
lr = 0.001 * batch_size / 256
optimizer = torch.optim.AdamW(student.parameters(), lr=lr, weight_decay=0.1)
student = train_Dino(
    best_loss=best_loss,
    n_epochs=n_epochs,
    start_epoch=start_epoch,
    student=student,
    teacher=teacher,
    optimizer=optimizer,
    loss_inst=loss_inst,
    validate_every=10,
)

  epoch : 50  i : 0  Loss : 12.271464347839355
  epoch : 50  i : 100  Loss : 8.887006759643555
epoch :50 KNN Accuracy : 0.4435668789808917
  epoch : 51  i : 0  Loss : 8.531305313110352
  epoch : 51  i : 100  Loss : 8.247798919677734
  epoch : 52  i : 0  Loss : 8.640562057495117
  epoch : 52  i : 100  Loss : 7.804385185241699
  epoch : 53  i : 0  Loss : 7.2993621826171875
  epoch : 53  i : 100  Loss : 6.740617275238037
  epoch : 54  i : 0  Loss : 6.533051490783691
  epoch : 54  i : 100  Loss : 5.355316162109375
  epoch : 55  i : 0  Loss : 5.454911708831787
  epoch : 55  i : 100  Loss : 5.337714672088623
  epoch : 56  i : 0  Loss : 4.813057899475098
  epoch : 56  i : 100  Loss : 4.965258598327637
  epoch : 57  i : 0  Loss : 4.933638572692871
  epoch : 57  i : 100  Loss : 4.813665390014648
  epoch : 58  i : 0  Loss : 5.190382480621338
  epoch : 58  i : 100  Loss : 4.824418067932129
  epoch : 59  i : 0  Loss : 4.764100551605225
  epoch : 59  i : 100  Loss : 4.9516472816467285
  epoch : 60 

In [None]:
# Snapshot the current ViT student under its own variable. Rebinding `save_best_path`
# here would make every later train_Dino call write its "best" checkpoints to
# latest_model.pth, because train_Dino reads that global.
save_latest_path = '/content/gdrive/MyDrive/imagenette2-320/VIT_weights/latest_model.pth'
torch.save(student.state_dict(), save_latest_path)

In [None]:
# Restore the latest ViT snapshot. load_state_dict requires a state dict; calling it
# with no argument raises TypeError.
student.load_state_dict(
    torch.load(
        '/content/gdrive/MyDrive/imagenette2-320/VIT_weights/latest_model.pth',
        map_location=device,
    )
)

In [None]:
# Resume ViT training from epoch 100 with a fresh AdamW (weight decay 0.1).
start_epoch = 100
n_epochs = 50
best_loss = 5
lr = 0.001 * batch_size / 256
optimizer = torch.optim.AdamW(student.parameters(), lr=lr, weight_decay=0.1)
student = train_Dino(
    best_loss=best_loss,
    n_epochs=n_epochs,
    start_epoch=start_epoch,
    student=student,
    teacher=teacher,
    optimizer=optimizer,
    loss_inst=loss_inst,
    validate_every=10,
)

  epoch : 100  i : 0  Loss : 12.947257995605469
epoch :100 KNN Accuracy : 0.34878980891719746
  epoch : 101  i : 0  Loss : 8.884957313537598
  epoch : 102  i : 0  Loss : 8.831358909606934
  epoch : 103  i : 0  Loss : 8.723714828491211
  epoch : 104  i : 0  Loss : 8.526971817016602
  epoch : 105  i : 0  Loss : 8.32490062713623
  epoch : 106  i : 0  Loss : 8.04737377166748
  epoch : 107  i : 0  Loss : 8.157147407531738
  epoch : 108  i : 0  Loss : 7.973672389984131
  epoch : 109  i : 0  Loss : 7.221358299255371
  epoch : 110  i : 0  Loss : 7.382135391235352
epoch :110 KNN Accuracy : 0.5559235668789809
  epoch : 111  i : 0  Loss : 7.059484481811523
  epoch : 112  i : 0  Loss : 6.916423320770264
  epoch : 113  i : 0  Loss : 6.730546474456787
  epoch : 114  i : 0  Loss : 6.23309326171875
  epoch : 115  i : 0  Loss : 6.289758682250977
  epoch : 116  i : 0  Loss : 6.085557460784912
  epoch : 117  i : 0  Loss : 5.436591148376465
  epoch : 118  i : 0  Loss : 5.2024431228637695
  epoch : 119  i 

In [None]:
# Resume ViT training from epoch 150 with weaker weight decay (0.01).
start_epoch = 150
n_epochs = 50
best_loss = 3.4
lr = 0.001 * batch_size / 256
optimizer = torch.optim.AdamW(student.parameters(), lr=lr, weight_decay=0.01)
student = train_Dino(
    best_loss=best_loss,
    n_epochs=n_epochs,
    start_epoch=start_epoch,
    student=student,
    teacher=teacher,
    optimizer=optimizer,
    loss_inst=loss_inst,
    validate_every=10,
)

  epoch : 150  i : 0  Loss : 3.5474050045013428
epoch :150 KNN Accuracy : 0.660891719745223
  epoch : 151  i : 0  Loss : 3.4862780570983887
  epoch : 152  i : 0  Loss : 3.3241653442382812
  epoch : 153  i : 0  Loss : 3.400827169418335
  epoch : 154  i : 0  Loss : 3.2651419639587402
  epoch : 155  i : 0  Loss : 3.365139961242676
  epoch : 156  i : 0  Loss : 3.2948503494262695
  epoch : 157  i : 0  Loss : 3.3898301124572754
  epoch : 158  i : 0  Loss : 3.340912342071533
  epoch : 159  i : 0  Loss : 3.5738320350646973
  epoch : 160  i : 0  Loss : 3.201122999191284
epoch :160 KNN Accuracy : 0.6759235668789809
  epoch : 161  i : 0  Loss : 3.5490455627441406
  epoch : 162  i : 0  Loss : 3.3544507026672363
  epoch : 163  i : 0  Loss : 3.134982109069824
  epoch : 164  i : 0  Loss : 3.085092306137085
  epoch : 165  i : 0  Loss : 3.251201629638672
  epoch : 166  i : 0  Loss : 3.2641971111297607
  epoch : 167  i : 0  Loss : 3.3597631454467773
  epoch : 168  i : 0  Loss : 2.8766307830810547
  epoc

FileNotFoundError: ignored

In [None]:
# Resume ViT training from epoch 180 (weight decay 0.01, best-loss threshold 3).
start_epoch = 180
n_epochs = 50
best_loss = 3
lr = 0.001 * batch_size / 256
optimizer = torch.optim.AdamW(student.parameters(), lr=lr, weight_decay=0.01)
student = train_Dino(
    best_loss=best_loss,
    n_epochs=n_epochs,
    start_epoch=start_epoch,
    student=student,
    teacher=teacher,
    optimizer=optimizer,
    loss_inst=loss_inst,
    validate_every=10,
)

  epoch : 180  i : 0  Loss : 2.813631057739258
epoch :180 KNN Accuracy : 0.6963057324840765
  epoch : 181  i : 0  Loss : 2.880256175994873
  epoch : 182  i : 0  Loss : 2.753526210784912
  epoch : 183  i : 0  Loss : 2.6740455627441406
  epoch : 184  i : 0  Loss : 2.647549867630005
  epoch : 185  i : 0  Loss : 2.806070327758789
  epoch : 186  i : 0  Loss : 2.674097776412964
  epoch : 187  i : 0  Loss : 2.588306427001953
  epoch : 188  i : 0  Loss : 2.969336986541748
  epoch : 189  i : 0  Loss : 2.6850335597991943
  epoch : 190  i : 0  Loss : 2.6343657970428467
epoch :190 KNN Accuracy : 0.6991082802547771
  epoch : 191  i : 0  Loss : 2.8266971111297607
  epoch : 192  i : 0  Loss : 2.51401948928833
  epoch : 193  i : 0  Loss : 2.8059165477752686
  epoch : 194  i : 0  Loss : 2.688098669052124
  epoch : 195  i : 0  Loss : 2.662720203399658
  epoch : 196  i : 0  Loss : 2.736356735229492
  epoch : 197  i : 0  Loss : 2.709721565246582
  epoch : 198  i : 0  Loss : 2.713698387145996
  epoch : 199

**Final k-NN accuracy of the ViT DINO backbone: 70.7 %**

## Transfer learning / fine-tuning with the DINO self-supervised model

In [None]:
class dino_resnet(nn.Module):
    """Classifier on top of a DINO-pretrained backbone (linear probe or fine-tuning).

    Args:
        num_classes: number of output classes.
        student: DINO student wrapper; only its ``backbone`` attribute is used
            (the module is shared, not copied).
        freeze: if True, backbone parameters get ``requires_grad=False`` and the
            backbone is kept in eval mode, so neither its weights nor its
            normalisation running statistics change while the classifier trains.
            If False, the whole backbone is trainable.
        embed_dim: width of the backbone output fed to the classifier
            (1000 for torchvision resnet18 with its fc layer kept).
    """

    def __init__(self, num_classes, student, freeze, embed_dim=1000):

        super(dino_resnet, self).__init__()
        self.num_classes = num_classes
        self.freeze = freeze
        self.encoder = student.backbone

        # Trainability of every backbone parameter mirrors `freeze`.
        for p in self.encoder.parameters():
            p.requires_grad = not freeze

        self.activation = nn.ReLU()
        self.classifier = nn.Linear(embed_dim, num_classes)
        if freeze:
            self.encoder.eval()

    def train(self, mode=True):
        """Switch train/eval mode, but keep a frozen encoder in eval mode.

        Without this, ``model.train()`` would put a frozen backbone back in
        training mode, and its BatchNorm running statistics would keep updating.
        """
        super().train(mode)
        if self.freeze:
            self.encoder.eval()
        return self

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for input batch ``x``."""
        features = self.encoder(x)
        features = self.activation(features)
        return self.classifier(features)

In [None]:
class dino_vit(nn.Module):
    """Linear classifier on top of a DINO student's (ViT) backbone.

    The forward pass is ``classifier(relu(encoder(x)))``.

    Args:
        num_classes: Number of output classes of the linear classifier.
        student: Object exposing a ``backbone`` attribute (an ``nn.Module``)
            whose output width must equal ``in_features``.
        freeze: If truthy, the encoder's parameters get ``requires_grad=False``
            (linear probing); otherwise they are made trainable (finetuning).
        in_features: Width of the encoder output fed to the classifier.
            Defaults to 384, the width previously hard-coded here (presumably
            the ViT embedding dimension — TODO confirm against the backbone).
        copy_encoder: If True, deep-copy ``student.backbone`` so that freezing
            or finetuning does not modify ``student``. Defaults to False, which
            shares the student's backbone module (the original behaviour).
    """

    def __init__(self, num_classes, student, freeze, in_features=384, copy_encoder=False):
        super(dino_vit, self).__init__()
        self.num_classes = num_classes

        encoder = student.backbone
        if copy_encoder:
            import copy
            encoder = copy.deepcopy(encoder)
        self.encoder = encoder

        # With a shared encoder (copy_encoder=False) this also toggles the
        # gradients of the student's own backbone parameters.
        for p in self.encoder.parameters():
            p.requires_grad = not freeze

        self.activation = nn.ReLU()
        self.classifier = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for input ``x``."""
        x = self.encoder(x)
        x = self.activation(x)
        x = self.classifier(x)
        return x


## RESNET DINO SSL

In [None]:
# `is_available` must be *called*: the bare function object is always truthy,
# which would select "cuda" even on a CPU-only runtime.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Linear probe: the DINO ResNet backbone is frozen, only the classifier trains.
# NOTE: dino_resnet reuses `student.backbone`, so freezing here also sets
# requires_grad=False on the student's own backbone parameters.
dino_model = dino_resnet(num_classes=10, student=student, freeze=True)
dino_model = dino_model.to(device)
criterion = nn.CrossEntropyLoss()

In [None]:
# Train the frozen-backbone ResNet probe for 50 epochs with SGD + momentum.
optimizer_supervised = torch.optim.SGD(
    dino_model.parameters(),
    lr=0.0001,
    momentum=0.9,
)
a_, b_, c_, d_ = train(
    data_loader_train_supervised,
    data_loader_test_supervised,
    dino_model,
    optimizer_supervised,
    criterion,
    50,
)

idx : 0 , Current Loss(batch) : 2.343250036239624, Correct Predictions(batch) : 17/128
idx : 5 , Current Loss(batch) : 2.3718113899230957, Correct Predictions(batch) : 10/128
idx : 10 , Current Loss(batch) : 2.373744010925293, Correct Predictions(batch) : 10/128
Epoch: 0, Loss (per batch) : 2.3722, Accuracy: 0.1051
Test Accuracy: 0.1083
Average Test Loss : 2.3389
idx : 0 , Current Loss(batch) : 2.3546769618988037, Correct Predictions(batch) : 17/128
idx : 5 , Current Loss(batch) : 2.2986247539520264, Correct Predictions(batch) : 20/128
idx : 10 , Current Loss(batch) : 2.353295087814331, Correct Predictions(batch) : 13/128
Epoch: 1, Loss (per batch) : 2.3187, Accuracy: 0.1183
Test Accuracy: 0.1274
Average Test Loss : 2.2828
idx : 0 , Current Loss(batch) : 2.2513833045959473, Correct Predictions(batch) : 23/128
idx : 5 , Current Loss(batch) : 2.270355463027954, Correct Predictions(batch) : 23/128
idx : 10 , Current Loss(batch) : 2.2247743606567383, Correct Predictions(batch) : 21/128
Epo

In [None]:
# Re-run of the SGD experiment: `dino_model` is not rebuilt, so training
# continues from its current weights, but with a brand-new SGD optimizer
# (fresh momentum buffers).
optimizer_supervised = torch.optim.SGD(
    dino_model.parameters(),
    lr=0.0001,
    momentum=0.9,
)
a_, b_, c_, d_ = train(
    data_loader_train_supervised,
    data_loader_test_supervised,
    dino_model,
    optimizer_supervised,
    criterion,
    50,
)

idx : 0 , Current Loss(batch) : 1.2266178131103516, Correct Predictions(batch) : 86/128
idx : 5 , Current Loss(batch) : 1.131690263748169, Correct Predictions(batch) : 101/128
idx : 10 , Current Loss(batch) : 1.065451979637146, Correct Predictions(batch) : 98/128
Epoch: 0, Loss (per batch) : 1.1060, Accuracy: 0.7501
Test Accuracy: 0.7299
Average Test Loss : 1.1807
idx : 0 , Current Loss(batch) : 1.0703883171081543, Correct Predictions(batch) : 97/128
idx : 5 , Current Loss(batch) : 1.111220121383667, Correct Predictions(batch) : 94/128
idx : 10 , Current Loss(batch) : 1.0724796056747437, Correct Predictions(batch) : 99/128
Epoch: 1, Loss (per batch) : 1.0971, Accuracy: 0.7549
Test Accuracy: 0.7350
Average Test Loss : 1.1620
idx : 0 , Current Loss(batch) : 1.1764750480651855, Correct Predictions(batch) : 90/128
idx : 5 , Current Loss(batch) : 1.0622411966323853, Correct Predictions(batch) : 97/128
idx : 10 , Current Loss(batch) : 1.034711480140686, Correct Predictions(batch) : 101/128
E

KeyboardInterrupt: ignored

In [None]:
# Switch the same (not rebuilt) `dino_model` to Adam with lr=1e-3.
optimizer_supervised = torch.optim.Adam(
    dino_model.parameters(),
    lr=0.001,
)
a_, b_, c_, d_ = train(
    data_loader_train_supervised,
    data_loader_test_supervised,
    dino_model,
    optimizer_supervised,
    criterion,
    50,
)

idx : 0 , Current Loss(batch) : 2.2249107360839844, Correct Predictions(batch) : 12/128
idx : 5 , Current Loss(batch) : 0.12002411484718323, Correct Predictions(batch) : 127/128
idx : 10 , Current Loss(batch) : 0.07858672738075256, Correct Predictions(batch) : 127/128
Epoch: 0, Loss (per batch) : 0.3747, Accuracy: 0.9033
Test Accuracy: 0.6599
Average Test Loss : 1.4702
idx : 0 , Current Loss(batch) : 0.023073842748999596, Correct Predictions(batch) : 127/128
idx : 5 , Current Loss(batch) : 0.03469126671552658, Correct Predictions(batch) : 127/128
idx : 10 , Current Loss(batch) : 0.09047216922044754, Correct Predictions(batch) : 125/128
Epoch: 1, Loss (per batch) : 0.0935, Accuracy: 0.9709
Test Accuracy: 0.5121
Average Test Loss : 3.7539
idx : 0 , Current Loss(batch) : 0.11606542766094208, Correct Predictions(batch) : 125/128
idx : 5 , Current Loss(batch) : 0.026564331725239754, Correct Predictions(batch) : 126/128
idx : 10 , Current Loss(batch) : 0.07669920474290848, Correct Prediction

In [None]:
# Continue training the same `dino_model` with a fresh Adam at lr=1e-4.
optimizer_supervised = torch.optim.Adam(
    dino_model.parameters(),
    lr=0.0001,
)
a_, b_, c_, d_ = train(
    data_loader_train_supervised,
    data_loader_test_supervised,
    dino_model,
    optimizer_supervised,
    criterion,
    50,
)

idx : 0 , Current Loss(batch) : 8.131758659146726e-05, Correct Predictions(batch) : 128/128
idx : 5 , Current Loss(batch) : 3.8776881410740316e-05, Correct Predictions(batch) : 128/128
idx : 10 , Current Loss(batch) : 1.2073081961716525e-05, Correct Predictions(batch) : 128/128
Epoch: 0, Loss (per batch) : 0.0001, Accuracy: 1.0000
Test Accuracy: 0.8064
Average Test Loss : 1.2446
idx : 0 , Current Loss(batch) : 0.00013744711759500206, Correct Predictions(batch) : 128/128
idx : 5 , Current Loss(batch) : 4.138654276175657e-06, Correct Predictions(batch) : 128/128
idx : 10 , Current Loss(batch) : 9.37211461859988e-06, Correct Predictions(batch) : 128/128
Epoch: 1, Loss (per batch) : 0.0000, Accuracy: 1.0000
Test Accuracy: 0.8064
Average Test Loss : 1.2684
idx : 0 , Current Loss(batch) : 9.15104101295583e-06, Correct Predictions(batch) : 128/128
idx : 5 , Current Loss(batch) : 3.3917608561750967e-06, Correct Predictions(batch) : 128/128
idx : 10 , Current Loss(batch) : 1.7406917322659865e-0

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x791f638770a0>
Exception ignored in: Traceback (most recent call last):
<function _MultiProcessingDataLoaderIter.__del__ at 0x791f638770a0>  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1478, in __del__

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1478, in __del__
        self._shutdown_workers()self._shutdown_workers()

  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1461, in _shutdown_workers
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1461, in _shutdown_workers
    if w.is_alive():    
if w.is_alive():  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive

      File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
assert self._parent_pid == os.getpid(), 'can only test a

Test Accuracy: 0.8178
Average Test Loss : 1.1106
idx : 0 , Current Loss(batch) : 2.77533558801224e-07, Correct Predictions(batch) : 128/128
idx : 5 , Current Loss(batch) : 8.009370588979436e-08, Correct Predictions(batch) : 128/128
idx : 10 , Current Loss(batch) : 4.731035687655094e-07, Correct Predictions(batch) : 128/128
Epoch: 30, Loss (per batch) : 0.0000, Accuracy: 1.0000
Test Accuracy: 0.8153
Average Test Loss : 1.1987
idx : 0 , Current Loss(batch) : 2.5425049443583703e-07, Correct Predictions(batch) : 128/128
idx : 5 , Current Loss(batch) : 1.9743994528198527e-07, Correct Predictions(batch) : 128/128
idx : 10 , Current Loss(batch) : 1.7601976765035943e-07, Correct Predictions(batch) : 128/128
Epoch: 31, Loss (per batch) : 0.0000, Accuracy: 1.0000
Test Accuracy: 0.8153
Average Test Loss : 1.2262
idx : 0 , Current Loss(batch) : 4.358577996299573e-07, Correct Predictions(batch) : 128/128
idx : 5 , Current Loss(batch) : 1.3690424793821876e-07, Correct Predictions(batch) : 128/128
id

In [None]:
# Build the optimizer explicitly over `dino_model`'s parameters instead of
# relying on whatever the global name `optimizer` last referred to: it may be
# the DINO AdamW created over `student.parameters()`, whose parameter groups do
# not include `dino_model.classifier`. Adam with lr=1e-3 matches the optimizer
# used for the other Adam ResNet-probe runs in this notebook.
optimizer_supervised = torch.optim.Adam(dino_model.parameters(), lr=0.001)
a_, b_, c_, d_ = train(
    data_loader_train_supervised,
    data_loader_test_supervised,
    dino_model,
    optimizer_supervised,
    criterion,
    50,
)

idx : 0 , Current Loss(batch) : 2.309063673019409, Correct Predictions(batch) : 15/128
idx : 5 , Current Loss(batch) : 1.6374412775039673, Correct Predictions(batch) : 79/128
idx : 10 , Current Loss(batch) : 1.0436230897903442, Correct Predictions(batch) : 94/128
Epoch: 0, Loss (per batch) : 1.5277, Accuracy: 0.5747
Test Accuracy: 0.6102
Average Test Loss : 1.2284
idx : 0 , Current Loss(batch) : 0.7566370964050293, Correct Predictions(batch) : 97/128
idx : 5 , Current Loss(batch) : 0.6162517666816711, Correct Predictions(batch) : 102/128
idx : 10 , Current Loss(batch) : 0.4694118797779083, Correct Predictions(batch) : 107/128
Epoch: 1, Loss (per batch) : 0.5575, Accuracy: 0.8146
Test Accuracy: 0.5439
Average Test Loss : 1.7321
idx : 0 , Current Loss(batch) : 0.40690839290618896, Correct Predictions(batch) : 107/128
idx : 5 , Current Loss(batch) : 0.29295873641967773, Correct Predictions(batch) : 118/128
idx : 10 , Current Loss(batch) : 0.2980726361274719, Correct Predictions(batch) : 1

In [None]:
# `is_available` must be *called*: the bare function object is always truthy,
# which would select "cuda" even on a CPU-only runtime.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Fresh frozen-backbone ResNet probe; only the new linear classifier receives
# gradients. dino_resnet reuses `student.backbone`, so freezing also affects
# the student's backbone parameters.
dino_model = dino_resnet(num_classes=10, student=student, freeze=True)
dino_model = dino_model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(dino_model.parameters(), lr=0.001)

In [None]:
# Train the freshly built ResNet probe for 50 epochs; results are kept in the
# `*1_` variables so the earlier `a_`..`d_` runs are not overwritten.
a1_, b1_, c1_, d1_ = train(
    data_loader_train_supervised,
    data_loader_test_supervised,
    dino_model,
    optimizer,
    criterion,
    50,
)

idx : 0 , Current Loss(batch) : 2.5271718502044678, Correct Predictions(batch) : 2/128
idx : 5 , Current Loss(batch) : 0.855148196220398, Correct Predictions(batch) : 127/128
idx : 10 , Current Loss(batch) : 0.2452448010444641, Correct Predictions(batch) : 128/128
Epoch: 0, Loss (per batch) : 0.8205, Accuracy: 0.8685
Test Accuracy: 0.8115
Average Test Loss : 0.6565
idx : 0 , Current Loss(batch) : 0.0667416974902153, Correct Predictions(batch) : 128/128
idx : 5 , Current Loss(batch) : 0.040391165763139725, Correct Predictions(batch) : 128/128
idx : 10 , Current Loss(batch) : 0.022342432290315628, Correct Predictions(batch) : 128/128
Epoch: 1, Loss (per batch) : 0.0376, Accuracy: 1.0000
Test Accuracy: 0.8115
Average Test Loss : 0.6522
idx : 0 , Current Loss(batch) : 0.024831531569361687, Correct Predictions(batch) : 128/128
idx : 5 , Current Loss(batch) : 0.013432176783680916, Correct Predictions(batch) : 128/128
idx : 10 , Current Loss(batch) : 0.010921841487288475, Correct Predictions(

# VIT DINO SSL MODEL

## **FINETUNING**

In [None]:
import copy

# `is_available` must be *called*: the bare function object is always truthy,
# which would select "cuda" even on a CPU-only runtime.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Full finetuning (freeze=False) updates the backbone weights. dino_vit uses
# `student.backbone` directly, so finetune a deep copy of the student; this
# keeps the self-supervised weights intact for the frozen-backbone probe later.
dino_model = dino_vit(num_classes=10, student=copy.deepcopy(student), freeze=False)
dino_model = dino_model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(dino_model.parameters(), lr=0.001)

In [None]:
# Finetune the ViT (backbone + classifier) for 50 epochs.
a1_v, b1_v, c1_v, d1_v = train(
    data_loader_train_supervised,
    data_loader_test_supervised,
    dino_model,
    optimizer,
    criterion,
    50,
)

idx : 0 , Current Loss(batch) : 2.243157386779785, Correct Predictions(batch) : 34/128
idx : 5 , Current Loss(batch) : 2.145616292953491, Correct Predictions(batch) : 36/128
idx : 10 , Current Loss(batch) : 1.9251033067703247, Correct Predictions(batch) : 74/128
Epoch: 0, Loss (per batch) : 2.0451, Accuracy: 0.4036
Test Accuracy: 0.5503
Average Test Loss : 1.9198
idx : 0 , Current Loss(batch) : 1.7066307067871094, Correct Predictions(batch) : 104/128
idx : 5 , Current Loss(batch) : 1.5666342973709106, Correct Predictions(batch) : 111/128
idx : 10 , Current Loss(batch) : 1.5042965412139893, Correct Predictions(batch) : 109/128
Epoch: 1, Loss (per batch) : 1.5433, Accuracy: 0.8468
Test Accuracy: 0.6471
Average Test Loss : 1.6029
idx : 0 , Current Loss(batch) : 1.3037490844726562, Correct Predictions(batch) : 119/128
idx : 5 , Current Loss(batch) : 1.1758002042770386, Correct Predictions(batch) : 127/128
idx : 10 , Current Loss(batch) : 1.0561203956604004, Correct Predictions(batch) : 126

## **VALIDATION ACCURACY : 72.87 %**

## **LINEAR PROBING (FROZEN PRETRAINED BACKBONE)**

In [None]:
# Frozen-backbone probe on the DINO ViT. The optimizer covers all parameters,
# but only the classifier has requires_grad=True.
# NOTE(review): dino_vit reuses `student.backbone`; if the finetuning cell above
# trained that same module in place, this probe sees finetuned rather than
# purely self-supervised features.
dino_model = dino_vit(
    num_classes=10,
    student=student,
    freeze=True,
)
dino_model = dino_model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    dino_model.parameters(),
    lr=0.0001,
)

In [None]:
# Train the frozen-backbone ViT probe for 50 epochs.
a1_v, b1_v, c1_v, d1_v = train(
    data_loader_train_supervised,
    data_loader_test_supervised,
    dino_model,
    optimizer,
    criterion,
    50,
)

idx : 0 , Current Loss(batch) : 2.3060736656188965, Correct Predictions(batch) : 14/128
idx : 5 , Current Loss(batch) : 2.299527168273926, Correct Predictions(batch) : 12/128
idx : 10 , Current Loss(batch) : 2.3207600116729736, Correct Predictions(batch) : 7/128
Epoch: 0, Loss (per batch) : 2.3015, Accuracy: 0.1057
Test Accuracy: 0.1108
Average Test Loss : 2.2741
idx : 0 , Current Loss(batch) : 2.2829644680023193, Correct Predictions(batch) : 16/128
idx : 5 , Current Loss(batch) : 2.289402961730957, Correct Predictions(batch) : 12/128
idx : 10 , Current Loss(batch) : 2.264461040496826, Correct Predictions(batch) : 17/128
Epoch: 1, Loss (per batch) : 2.2654, Accuracy: 0.1337
Test Accuracy: 0.1325
Average Test Loss : 2.2643
idx : 0 , Current Loss(batch) : 2.2682061195373535, Correct Predictions(batch) : 17/128
idx : 5 , Current Loss(batch) : 2.2222723960876465, Correct Predictions(batch) : 24/128
idx : 10 , Current Loss(batch) : 2.2379138469696045, Correct Predictions(batch) : 20/128
Epo

## VALIDATION ACCURACY : 66.24 %

In [None]:
# Save the four values returned by the frozen-backbone ViT train() run.
# All files use the `<name>_freeze.npy` pattern. The last path was previously
# 'd1_v.npy_freeze', which np.save turns into 'd1_v.npy_freeze.npy'.
# NOTE(review): the directory is named VIT_weights, but these arrays come from
# train(); presumably metric histories rather than model weights — confirm.
vit_dir = '/content/gdrive/MyDrive/imagenette2-320/VIT_weights'
np.save(f'{vit_dir}/a1_v_freeze.npy', a1_v)
np.save(f'{vit_dir}/b1_v_freeze.npy', b1_v)
np.save(f'{vit_dir}/c1_v_freeze.npy', c1_v)
np.save(f'{vit_dir}/d1_v_freeze.npy', d1_v)

In [None]:
# Persist the results of both ResNet DINO probe runs (`*_` and `*1_`), one
# .npy file per variable, named after the variable.
resnet_dir = '/content/gdrive/MyDrive/imagenette2-320/Resnet_weights'
for stem, arr in (
    ('a_', a_), ('b_', b_), ('c_', c_), ('d_', d_),
    ('a1_', a1_), ('b1_', b1_), ('c1_', c1_), ('d1_', d1_),
):
    np.save(f'{resnet_dir}/{stem}.npy', arr)

In [None]:
# Save the a/b/c/d arrays under matching file names. The first line previously
# wrote `a_` (a DINO-ResNet probe result) to a.npy, so a.npy did not belong
# to the same run as b.npy, c.npy and d.npy.
resnet_dir = '/content/gdrive/MyDrive/imagenette2-320/Resnet_weights'
np.save(f'{resnet_dir}/a.npy', a)
np.save(f'{resnet_dir}/b.npy', b)
np.save(f'{resnet_dir}/c.npy', c)
np.save(f'{resnet_dir}/d.npy', d)

In [None]:
# Persist the a/b/c/d arrays, one .npy file per variable.
resnet_dir = '/content/gdrive/MyDrive/imagenette2-320/Resnet_weights'
for stem, arr in (('a', a), ('b', b), ('c', c), ('d', d)):
    np.save(f'{resnet_dir}/{stem}.npy', arr)


In [None]:
# Persist the a1/b1/c1/d1 arrays, one .npy file per variable.
resnet_dir = '/content/gdrive/MyDrive/imagenette2-320/Resnet_weights'
for stem, arr in (('a1', a1), ('b1', b1), ('c1', c1), ('d1', d1)):
    np.save(f'{resnet_dir}/{stem}.npy', arr)
