In [1]:
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [2]:
! pip install timm

Collecting timm
  Downloading timm-0.9.8-py3-none-any.whl (2.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m15.5 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub (from timm)
  Downloading huggingface_hub-0.18.0-py3-none-any.whl (301 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m302.0/302.0 kB[0m [31m21.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting safetensors (from timm)
  Downloading safetensors-0.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m34.5 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: safetensors, huggingface-hub, timm
Successfully installed huggingface-hub-0.18.0 safetensors-0.4.0 timm-0.9.8


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
from sklearn.metrics import accuracy_score
from sklearn.neighbors import KNeighborsClassifier

import json
import pathlib

import timm
import tqdm
from torch.utils.data import DataLoader, SubsetRandomSampler
from torch.utils.tensorboard import SummaryWriter
from torchvision.datasets import ImageFolder
import torchvision
import random


## SUPERVISED LEARNING ON 20% OF THE DATASET

In [4]:
path_dataset_train = pathlib.Path("/content/gdrive/MyDrive/imagenette2-320/train")
path_dataset_val = pathlib.Path("/content/gdrive/MyDrive/imagenette2-320/val")
path_labels = pathlib.Path("/content/gdrive/MyDrive/imagenette2-320/imagenette_labels.json")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

n_workers = 2

# Data related
with path_labels.open("r") as f:
    label_mapping = json.load(f)

transform_plain = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        transforms.Resize((224, 224),antialias=True),
    ]
)
dataset_train_plain = ImageFolder(path_dataset_train, transform=transform_plain)
dataset_val_plain = ImageFolder(path_dataset_val, transform=transform_plain)

if dataset_train_plain.classes != dataset_val_plain.classes:
    raise ValueError("Inconsistent classes")

batch_size = 256
data_loader_train_plain = DataLoader(dataset_train_plain,batch_size=batch_size, shuffle = True,
    drop_last=False,num_workers=n_workers,)
data_loader_val_plain = DataLoader(dataset_val_plain,batch_size=batch_size,
    drop_last=True,num_workers=n_workers,shuffle = True)
data_loader_val_plain_subset = DataLoader(dataset_val_plain,batch_size=batch_size,drop_last=False,
    sampler=SubsetRandomSampler(list(range(0, len(dataset_val_plain), 50))),num_workers=n_workers,)


In [5]:
class resnet(nn.Module):

    def __init__(self,num_classes,pretrained = False):

        super(resnet, self).__init__()
        self.num_classes = num_classes
        if pretrained :
          self.encoder = torchvision.models.resnet18(weights = "DEFAULT")
        else:
          self.encoder = torchvision.models.resnet18()
        self.activation = nn.ReLU()
        self.classifier = nn.Linear(1000,num_classes)

    def forward(self, x):

        x = self.encoder(x)
        x =  self.activation(x)
        x = self.classifier(x)

        return x

In [None]:
# Splitting dataset
n_train = len(dataset_train_plain)
indices = list(range(n_train))
random.shuffle(indices)

split = int(np.floor(0.2 * n_train))
train_indices, _ = indices[:split], indices[split:]

n_test = len(dataset_val_plain)
indices = list(range(n_test))
random.shuffle(indices)

split = int(np.floor(0.2 * n_test))
test_indices, _ = indices[:split], indices[split:]

train_supervised_sampler = SubsetRandomSampler(train_indices)
test_supervised_sampler = SubsetRandomSampler(test_indices)

data_loader_train_supervised = DataLoader(dataset_train_plain, batch_size=128,sampler=train_supervised_sampler, num_workers=n_workers)
data_loader_test_supervised = DataLoader(dataset_val_plain, batch_size=128,sampler=test_supervised_sampler, num_workers=n_workers)



In [6]:
# Create a model for supervised learning
supervised_model = resnet(num_classes = 10,pretrained = False)
# supervised_model = LeNet5(num_classes = 10)
supervised_model = supervised_model.to(device)

optimizer_supervised = torch.optim.Adam(supervised_model.parameters(), lr= 0.001)
criterion = torch.nn.CrossEntropyLoss()

In [8]:
def test_model_accuracy(test_loader, model, criterion):
    correct_predictions = 0
    total_samples = 0
    average_loss = 0
    model.eval()


    with torch.no_grad():
        for idx, (img, label) in enumerate(test_loader):
            img = img.to(device)
            label = label.to(device)


            output = model(img)
            loss = criterion(output, label)
            average_loss += loss.item()

            # Calculate the accuracy
            _, predicted_labels = output.max(1)
            correct_predictions += (predicted_labels == label).sum().item()

            total_samples += label.size(0)

    test_accuracy = correct_predictions / total_samples
    average_loss /= len(test_loader)
    print(f'Test Accuracy: {test_accuracy:.4f}')
    print(f'Average Test Loss : {average_loss:.4f}')
    return test_accuracy,average_loss

def train(train_loader,val_loader,model,optimizer,criterion,num_epochs):

    training_loss = []
    training_accuracy = []
    testing_accuracy =[]
    testing_loss = []
    for i in range(num_epochs):
        running_loss = 0
        correct_predictions = 0
        model.train()
        for idx, (img,label) in enumerate(train_loader):

            # label = torch.tensor(label)
            bsz = label.shape[0]
            optimizer.zero_grad()
            img = img.to(device)
            label = label.to(device)
            prediction = model(img)

            # Calculate the loss
            loss = criterion(prediction, label)
            running_loss += loss.item()

            # Calculate the accuracy
            _, predicted_labels = prediction.max(1)
            curr_predicted_labels = (predicted_labels == label).sum().item()
            correct_predictions += curr_predicted_labels

            loss.backward()
            optimizer.step()

            if idx%5 == 0:
                print(f"idx : {idx} , Current Loss(batch) : {loss.item()}, Correct Predictions(batch) : {curr_predicted_labels}/{bsz}")

        running_loss /= len(train_loader)
        epoch_accuracy = correct_predictions / (1893)

        print(f'Epoch: {i}, Loss (per batch) : {running_loss:.4f}, Accuracy: {epoch_accuracy:.4f}')

        training_loss.append(running_loss)
        training_accuracy.append(epoch_accuracy)

        # test_accuracy(data_loader_val_subset,model)

        t,l = test_model_accuracy(val_loader,model,criterion)
        testing_accuracy.append(t)
        testing_loss.append(l)

    return training_loss, training_accuracy,testing_accuracy,testing_loss

# **TRAINING RESNET FROM SCRATCH**
TRAINING LOSS : 0.03 \\
TRAINING ACCURACY : 99% \\
BEST TEST ACCURACY : 70%




In [None]:
a1,b1,c1,d1 = train(data_loader_train_supervised,data_loader_test_supervised,supervised_model,optimizer_supervised,criterion,50)


idx : 0 , Current Loss(batch) : 2.0837905406951904, Correct Predictions(batch) : 39/128
idx : 5 , Current Loss(batch) : 1.90829598903656, Correct Predictions(batch) : 41/128
idx : 10 , Current Loss(batch) : 1.9454395771026611, Correct Predictions(batch) : 40/128
Epoch: 0, Loss (per batch) : 1.9312, Accuracy: 0.3170
Test Accuracy: 0.1287
Average Test Loss : 4.8453
idx : 0 , Current Loss(batch) : 1.6776411533355713, Correct Predictions(batch) : 52/128
idx : 5 , Current Loss(batch) : 1.8509056568145752, Correct Predictions(batch) : 42/128
idx : 10 , Current Loss(batch) : 1.7008922100067139, Correct Predictions(batch) : 53/128
Epoch: 1, Loss (per batch) : 1.7340, Accuracy: 0.3941
Test Accuracy: 0.2599
Average Test Loss : 2.2175
idx : 0 , Current Loss(batch) : 1.5756869316101074, Correct Predictions(batch) : 61/128
idx : 5 , Current Loss(batch) : 1.5255651473999023, Correct Predictions(batch) : 54/128
idx : 10 , Current Loss(batch) : 1.4662233591079712, Correct Predictions(batch) : 66/128
E

# **FINE-TUNING USING IMAGENET PRETRAINED WEIGHTS**
TRAINING LOSS : 0.03 \\
TRAINING ACCURACY : 99% \\
BEST TEST ACCURACY : 85.5%

In [None]:
a,b,c,d = train(data_loader_train_supervised,data_loader_test_supervised,supervised_model,optimizer_supervised,criterion,50)


idx : 0 , Current Loss(batch) : 2.985105037689209, Correct Predictions(batch) : 5/128
idx : 5 , Current Loss(batch) : 0.2605835497379303, Correct Predictions(batch) : 118/128
idx : 10 , Current Loss(batch) : 0.4118060767650604, Correct Predictions(batch) : 110/128
Epoch: 0, Loss (per batch) : 0.5977, Accuracy: 0.8267
Test Accuracy: 0.5924
Average Test Loss : 3.3110
idx : 0 , Current Loss(batch) : 0.21699590981006622, Correct Predictions(batch) : 119/128
idx : 5 , Current Loss(batch) : 0.18178032338619232, Correct Predictions(batch) : 121/128
idx : 10 , Current Loss(batch) : 0.30990925431251526, Correct Predictions(batch) : 118/128
Epoch: 1, Loss (per batch) : 0.2742, Accuracy: 0.9149
Test Accuracy: 0.6471
Average Test Loss : 2.4216
idx : 0 , Current Loss(batch) : 0.18378369510173798, Correct Predictions(batch) : 123/128
idx : 5 , Current Loss(batch) : 0.20953992009162903, Correct Predictions(batch) : 121/128
idx : 10 , Current Loss(batch) : 0.1788659244775772, Correct Predictions(batch

## SUPERVISED LEARNING ON 100% TRAINING DATASET

In [10]:
batch_size = 128
data_loader_train_full = DataLoader(dataset_train_plain,batch_size=batch_size, shuffle = True,
    drop_last=False,num_workers=n_workers,)
data_loader_val_full = DataLoader(dataset_val_plain,batch_size=batch_size,
    drop_last=True,num_workers=n_workers,shuffle = True)

In [None]:
a2,b2,c2,d2 = train(data_loader_train_full,data_loader_val_full,supervised_model,optimizer_supervised,criterion,50)


idx : 0 , Current Loss(batch) : 1.6387784481048584, Correct Predictions(batch) : 56/128
idx : 5 , Current Loss(batch) : 1.6421983242034912, Correct Predictions(batch) : 59/128
idx : 10 , Current Loss(batch) : 1.520270824432373, Correct Predictions(batch) : 64/128
idx : 15 , Current Loss(batch) : 1.5230779647827148, Correct Predictions(batch) : 62/128
idx : 20 , Current Loss(batch) : 1.5206944942474365, Correct Predictions(batch) : 65/128
idx : 25 , Current Loss(batch) : 1.6448556184768677, Correct Predictions(batch) : 58/128


In [None]:
class DataAugmentation:
    """Create crops of an input image together with additional augmentation.

    It generates 2 global crops and `n_local_crops` local crops.

    Parameters
    ----------
    global_crops_scale : tuple
        Range of sizes for the global crops.

    local_crops_scale : tuple
        Range of sizes for the local crops.

    n_local_crops : int
        Number of local crops to create.

    size : int
        The size of the final image.

    Attributes
    ----------
    global_1, global_2 : transforms.Compose
        Two global transforms.

    local : transforms.Compose
        Local transform. Note that the augmentation is stochastic so one
        instance is enough and will lead to different crops.
    """
    def __init__(
        self,
        global_crops_scale=(0.4, 1),
        local_crops_scale=(0.05, 0.4),
        n_local_crops=8,
        size=224,
    ):
        self.n_local_crops = n_local_crops
        RandomGaussianBlur = lambda p: transforms.RandomApply(  # noqa
            [transforms.GaussianBlur(kernel_size=5, sigma=(0.1, 2))],
            p=p,
        )

        flip_and_jitter = transforms.Compose(
            [
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomApply(
                    [
                        transforms.ColorJitter(
                            brightness=0.4,
                            contrast=0.4,
                            saturation=0.2,
                            hue=0.1,
                        ),
                    ]
                ),
                transforms.RandomGrayscale(p=0.2),
            ]
        )

        normalize = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
            ]
        )

        self.global_1 = transforms.Compose(
            [
                transforms.RandomResizedCrop(
                    size,
                    scale=global_crops_scale,
                    interpolation=Image.BICUBIC,
                ),
                flip_and_jitter,
                RandomGaussianBlur(1.0),  # always apply
                normalize,
            ],
        )

        self.global_2 = transforms.Compose(
            [
                transforms.RandomResizedCrop(
                    size,
                    scale=global_crops_scale,
                    interpolation=Image.BICUBIC,
                ),
                flip_and_jitter,
                RandomGaussianBlur(0.1),
                transforms.RandomSolarize(170, p=0.2),
                normalize,
            ],
        )

        self.local = transforms.Compose(
            [
                transforms.RandomResizedCrop(
                    size,
                    scale=local_crops_scale,
                    interpolation=Image.BICUBIC,
                ),
                flip_and_jitter,
                RandomGaussianBlur(0.5),
                normalize,
            ],
        )

    def __call__(self, img):
        """Apply transformation.

        Parameters
        ----------
        img : PIL.Image
            Input image.

        Returns
        -------
        all_crops : list
            List of `torch.Tensor` representing different views of
            the input `img`.
        """
        all_crops = []
        all_crops.append(self.global_1(img))
        all_crops.append(self.global_2(img))

        all_crops.extend([self.local(img) for _ in range(self.n_local_crops)])

        return all_crops

class Head(nn.Module):
    """Network hooked up to the CLS token embedding.

    Just a MLP with the last layer being normalized in a particular way.

    Parameters
    ----------
    in_dim : int
        The dimensionality of the token embedding.

    out_dim : int
        The dimensionality of the final layer (we compute the softmax over).

    hidden_dim : int
        Dimensionality of the hidden layers.

    bottleneck_dim : int
        Dimensionality of the second last layer.

    n_layers : int
        The number of layers.

    norm_last_layer : bool
        If True, then we freeze the norm of the weight of the last linear layer
        to 1.

    Attributes
    ----------
    mlp : nn.Sequential
        Vanilla multi-layer perceptron.

    last_layer : nn.Linear
        Reparametrized linear layer with weight normalization. That means
        that that it will have `weight_g` and `weight_v` as learnable
        parameters instead of a single `weight`.
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        hidden_dim=2048,
        bottleneck_dim=256,
        n_layers=3,
        norm_last_layer=False,
    ):
        super().__init__()
        if n_layers == 1:
            self.mlp = nn.Linear(in_dim, bottleneck_dim)
        else:
            layers = [nn.Linear(in_dim, hidden_dim)]
            layers.append(nn.GELU())
            for _ in range(n_layers - 2):
                layers.append(nn.Linear(hidden_dim, hidden_dim))
                layers.append(nn.GELU())
            layers.append(nn.Linear(hidden_dim, bottleneck_dim))
            self.mlp = nn.Sequential(*layers)

        self.apply(self._init_weights)

        self.last_layer = nn.utils.weight_norm(
            nn.Linear(bottleneck_dim, out_dim, bias=False)
        )
        self.last_layer.weight_g.data.fill_(1)
        if norm_last_layer:
            self.last_layer.weight_g.requires_grad = False

    def _init_weights(self, m):
        """Initialize learnable parameters."""
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Run forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Of shape `(n_samples, in_dim)`.

        Returns
        -------
        torch.Tensor
            Of shape `(n_samples, out_dim)`.
        """
        x = self.mlp(x)  # (n_samples, bottleneck_dim)
        x = nn.functional.normalize(x, dim=-1, p=2)  # (n_samples, bottleneck_dim)
        x = self.last_layer(x)  # (n_samples, out_dim)

        return x

class MultiCropWrapper(nn.Module):
  """Convenience class for forward pass of multiple crops.

  Parameters
  ----------
  backbone : timm.models.vision_transformer.VisionTransformer
      Instantiated Vision Transformer. Note that we will take the `head`
      attribute and replace it with `nn.Identity`.

  new_head : Head
      New head that is going to be put on top of the `backbone`.
  """
  def __init__(self, backbone, new_head):
      super().__init__()
      backbone.head = nn.Identity()  # deactivate original head
      self.backbone = backbone
      self.new_head = new_head

  def forward(self, x):
      """Run the forward pass.

      The different crops are concatenated along the batch dimension
      and then a single forward pass is fun. The resulting tensor
      is then chunked back to per crop tensors.

      Parameters
      ----------
      x : list
          List of `torch.Tensor` each of shape `(n_samples, 3, size, size)`.

      Returns
      -------
      tuple
          Tuple of `torch.Tensor` each of shape `(n_samples, out_dim)` where
          `output_dim` is determined by `Head`.
      """
      n_crops = len(x)
      concatenated = torch.cat(x, dim=0)  # (n_samples * n_crops, 3, size, size)
      cls_embedding = self.backbone(concatenated)  # (n_samples * n_crops, in_dim)
      logits = self.new_head(cls_embedding)  # (n_samples * n_crops, out_dim)
      chunks = logits.chunk(n_crops)  # n_crops * (n_samples, out_dim)

      return chunks

class Loss(nn.Module):
    """The loss function.

    We subclass the `nn.Module` becuase we want to create a buffer for the
    logits center of the teacher.

    Parameters
    ----------
    out_dim : int
        The dimensionality of the final layer (we computed the softmax over).

    teacher_temp, student_temp : float
        Softmax temperature of the teacher resp. student.

    center_momentum : float
        Hyperparameter for the exponential moving average that determines
        the center logits. The higher the more the running average matters.
    """
    def __init__(
        self, out_dim, teacher_temp=0.04, student_temp=0.1, center_momentum=0.9
    ):
        super().__init__()
        self.student_temp = student_temp
        self.teacher_temp = teacher_temp
        self.center_momentum = center_momentum
        self.register_buffer("center", torch.zeros(1, out_dim))

    def forward(self, student_output, teacher_output):
        """Evaluate loss.

        Parameters
        ----------
        student_output, teacher_output : tuple
            Tuple of tensors of shape `(n_samples, out_dim)` representing
            logits. The length is equal to number of crops.
            Note that student processed all crops and that the two initial crops
            are the global ones.

        Returns
        -------
        loss : torch.Tensor
            Scalar representing the average loss.
        """
        student_temp = [s / self.student_temp for s in student_output]
        teacher_temp = [(t - self.center) / self.teacher_temp for t in teacher_output]

        student_sm = [F.log_softmax(s, dim=-1) for s in student_temp]
        teacher_sm = [F.softmax(t, dim=-1).detach() for t in teacher_temp]

        total_loss = 0
        n_loss_terms = 0

        for t_ix, t in enumerate(teacher_sm):
            for s_ix, s in enumerate(student_sm):
                if t_ix == s_ix:
                    continue

                loss = torch.sum(-t * s, dim=-1)  # (n_samples,)
                total_loss += loss.mean()  # scalar
                n_loss_terms += 1

        total_loss /= n_loss_terms
        self.update_center(teacher_output)

        return total_loss

    @torch.no_grad()
    def update_center(self, teacher_output):
        """Update center used for teacher output.

        Compute the exponential moving average.

        Parameters
        ----------
        teacher_output : tuple
            Tuple of tensors of shape `(n_samples, out_dim)` where each
            tensor represents a different crop.
        """
        batch_center = torch.cat(teacher_output).mean(
            dim=0, keepdim=True
        )  # (1, out_dim)
        self.center = self.center * self.center_momentum + batch_center * (
            1 - self.center_momentum
        )


def clip_gradients(model, clip=2.0):
    """Rescale norm of computed gradients.

    Parameters
    ----------
    model : nn.Module
        Module.

    clip : float
        Maximum norm.
    """
    for p in model.parameters():
        if p.grad is not None:
            param_norm = p.grad.data.norm(2)
            clip_coef = clip / (param_norm + 1e-6)
            if clip_coef < 1:
                p.grad.data.mul_(clip_coef)

def compute_knn(backbone, data_loader_train, data_loader_val):
    """Get CLS embeddings and use KNN classifier on them.

    We load all embeddings in memory and use sklearn. Should
    be doable.

    Parameters
    ----------
    backbone : timm.models.vision_transformer.VisionTransformer
        Vision transformer whose head is just an identity
        mapping.

    data_loader_train, data_loader_val : torch.utils.data.DataLoader
        Training and validation dataloader that does not apply any
        augmentations. Just casting to tensor and then normalizing.

    Returns
    -------
    val_accuracy : float
        Validation accuracy.
    """
    device = next(backbone.parameters()).device

    data_loaders = {
        "train": data_loader_train,
        "val": data_loader_val,
    }
    lists = {
        "X_train": [],
        "y_train": [],
        "X_val": [],
        "y_val": [],
    }

    for name, data_loader in data_loaders.items():
        for imgs, y in data_loader:
            imgs = imgs.to(device)
            lists[f"X_{name}"].append(backbone(imgs).detach().cpu().numpy())
            lists[f"y_{name}"].append(y.detach().cpu().numpy())

    arrays = {k: np.concatenate(l) for k, l in lists.items()}

    estimator = KNeighborsClassifier()
    estimator.fit(arrays["X_train"], arrays["y_train"])
    y_val_pred = estimator.predict(arrays["X_val"])

    acc = accuracy_score(arrays["y_val"], y_val_pred)

    return acc

def compute_embedding(backbone, data_loader):
    """Compute CLS embedding and prepare for TensorBoard.

    Parameters
    ----------
    backbone : timm.models.vision_transformer.VisionTransformer
        Vision transformer. The head should be an identity mapping.

    data_loader : torch.utils.data.DataLoader
        Validation dataloader that does not apply any augmentations. Just
        casting to tensor and then normalizing.

    Returns
    -------
    embs : torch.Tensor
        Embeddings of shape `(n_samples, out_dim)`.

    imgs : torch.Tensor
        Images of shape `(n_samples, 3, height, width)`.

    labels : list
        List of strings representing the classes.
    """
    device = next(backbone.parameters()).device

    embs_l = []
    imgs_l = []
    labels = []

    for img, y in data_loader:
        img = img.to(device)
        embs_l.append(backbone(img).detach().cpu())
        imgs_l.append(((img * 0.224) + 0.45).cpu())  # undo norm
        labels.extend([data_loader.dataset.classes[i] for i in y.tolist()])

    embs = torch.cat(embs_l, dim=0)
    imgs = torch.cat(imgs_l, dim=0)

    return embs, imgs, labels

In [None]:
batch_size = 64
device = "cuda"
logging_freq = 200
n_crops = 4
momentum_teacher = 0.995
n_epochs  = 40
dim = 1000
out_dim = 1024
clip_grad = 2.0
norm_last_layer = True
teacher_temp = 0.04
student_temp = 0.1
weight_decay = 0.4

In [None]:
transform_aug = DataAugmentation(size=224, n_local_crops=n_crops - 2)
dataset_train_aug = ImageFolder(path_dataset_train, transform=transform_aug)
data_loader_train_aug = DataLoader(dataset_train_aug,batch_size=batch_size,shuffle=True,
    drop_last=True,num_workers=n_workers, pin_memory=True,)

In [None]:

student_vit = torchvision.models.resnet18()
teacher_vit = torchvision.models.resnet18()

student = MultiCropWrapper(student_vit,Head(dim, out_dim, norm_last_layer=True, hidden_dim=512,bottleneck_dim=256,))

teacher = MultiCropWrapper(teacher_vit, Head(dim, out_dim, norm_last_layer=True, hidden_dim=512,bottleneck_dim=256,))
student, teacher = student.to(device), teacher.to(device)

teacher.load_state_dict(student.state_dict())

for p in teacher.parameters():
        p.requires_grad = False

In [None]:
# Loss related
loss_inst = Loss(out_dim, teacher_temp=teacher_temp,student_temp=student_temp,).to(device)

lr = 0.001 * batch_size / 256
optimizer = torch.optim.AdamW(student.parameters(),lr=lr,weight_decay=weight_decay)

# Training loop
n_batches = len(dataset_train_aug) //batch_size
best_acc = 0
n_steps = 0

In [None]:
save_best_path = '/content/gdrive/MyDrive/imagenette2-320/resnet_weights/best_model.pth'
best_loss = 1e3

In [None]:
for e in range(n_epochs):
    for i, (images, _) in tqdm.tqdm(enumerate(data_loader_train_aug), total=n_batches):

        images = [img.to(device) for img in images]

        teacher_output = teacher(images[:2])
        student_output = student(images)

        loss = loss_inst(student_output, teacher_output)

        optimizer.zero_grad()
        loss.backward()
        clip_gradients(student, clip_grad)
        optimizer.step()
        if loss.item() < best_loss:
          torch.save(student.state_dict(),save_best_path+ f'_str{e+100}.pth')
          best_loss = loss.item()
        if i%100 == 0:
          print(f"epoch : {e}  i : {i}  Loss : {loss.item()}")
        with torch.no_grad():
            for student_ps, teacher_ps in zip(student.parameters(), teacher.parameters()):

                teacher_ps.data.mul_(momentum_teacher)
                teacher_ps.data.add_(
                    (1 - momentum_teacher) * student_ps.detach().data
                )
        n_steps += 1

    if e % 5 == 0:
        save_path = f'/content/gdrive/MyDrive/imagenette2-320/resnet_weights/model_{e}.pth'
        torch.save(student.state_dict(),save_path)
        student.eval()
        current_acc = compute_knn(student.backbone,data_loader_train_plain,data_loader_val_plain,)
        print(f"epoch :{e}  current_acc : {current_acc}")
        student.train()


  1%|          | 1/147 [00:06<15:35,  6.41s/it]

epoch : 0  i : 0  Loss : 8.49441146850586


 69%|██████▊   | 101/147 [04:46<02:34,  3.36s/it]

epoch : 0  i : 100  Loss : 8.991518020629883


100%|██████████| 147/147 [06:56<00:00,  2.83s/it]


epoch :100  current_acc : 0.33203125


  1%|          | 1/147 [00:02<06:27,  2.66s/it]

epoch : 1  i : 0  Loss : 8.986802101135254


 69%|██████▊   | 101/147 [01:48<00:55,  1.20s/it]

epoch : 1  i : 100  Loss : 8.98403549194336


100%|██████████| 147/147 [02:37<00:00,  1.07s/it]
  1%|          | 1/147 [00:02<06:20,  2.61s/it]

epoch : 2  i : 0  Loss : 8.968889236450195


 69%|██████▊   | 101/147 [01:51<00:56,  1.24s/it]

epoch : 2  i : 100  Loss : 8.92247200012207


100%|██████████| 147/147 [02:42<00:00,  1.10s/it]
  1%|          | 1/147 [00:02<06:48,  2.80s/it]

epoch : 3  i : 0  Loss : 8.899709701538086


  5%|▌         | 8/147 [00:11<03:14,  1.40s/it]


KeyboardInterrupt: ignored

In [None]:
best_path = '/content/gdrive/MyDrive/imagenette2-320/Resnet_weights/best_resnet_model.pth'
student.load_state_dict(torch.load(best_path))

<All keys matched successfully>

In [None]:
student.eval()
current_acc = compute_knn(student.backbone,data_loader_train_plain,data_loader_val_plain)
student.train()
print(f"current_acc : {current_acc}")

KeyboardInterrupt: ignored

In [None]:
class dino_resnet(nn.Module):

    def __init__(self,num_classes,student,freeze):

        super(dino_resnet, self).__init__()
        self.num_classes = num_classes
        self.encoder = student.backbone

        if freeze:
          for p in self.encoder.parameters():
            p.requires_grad = False
        else:
          for p in self.encoder.parameters():
            p.requires_grad = True

        self.activation = nn.ReLU()
        self.classifier = nn.Linear(1000,num_classes)

    def forward(self, x):

        x = self.encoder(x)
        x =  self.activation(x)
        x = self.classifier(x)

        return x

In [None]:
device = torch.device("cuda" if torch.cuda.is_available else "cpu")
dino_model = dino_resnet(num_classes=10,student = student,freeze = False)
# model = resnet(num_classes = 10)
dino_model = dino_model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(dino_model.parameters(),lr = 0.001)

In [None]:
a_,b_,c_,d_ = train(data_loader_train_supervised,data_loader_test_supervised,dino_model,optimizer,criterion,50)

idx : 0 , Current Loss(batch) : 2.309063673019409, Correct Predictions(batch) : 15/128
idx : 5 , Current Loss(batch) : 1.6374412775039673, Correct Predictions(batch) : 79/128
idx : 10 , Current Loss(batch) : 1.0436230897903442, Correct Predictions(batch) : 94/128
Epoch: 0, Loss (per batch) : 1.5277, Accuracy: 0.5747
Test Accuracy: 0.6102
Average Test Loss : 1.2284
idx : 0 , Current Loss(batch) : 0.7566370964050293, Correct Predictions(batch) : 97/128
idx : 5 , Current Loss(batch) : 0.6162517666816711, Correct Predictions(batch) : 102/128
idx : 10 , Current Loss(batch) : 0.4694118797779083, Correct Predictions(batch) : 107/128
Epoch: 1, Loss (per batch) : 0.5575, Accuracy: 0.8146
Test Accuracy: 0.5439
Average Test Loss : 1.7321
idx : 0 , Current Loss(batch) : 0.40690839290618896, Correct Predictions(batch) : 107/128
idx : 5 , Current Loss(batch) : 0.29295873641967773, Correct Predictions(batch) : 118/128
idx : 10 , Current Loss(batch) : 0.2980726361274719, Correct Predictions(batch) : 1

In [None]:
device = torch.device("cuda" if torch.cuda.is_available else "cpu")
dino_model = dino_resnet(num_classes=10,student = student,freeze = True)
# model = resnet(num_classes = 10)
dino_model = dino_model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(dino_model.parameters(),lr = 0.001)

In [None]:
a1_,b1_,c1_,d1_ = train(data_loader_train_supervised,data_loader_test_supervised,dino_model,optimizer,criterion,50)

idx : 0 , Current Loss(batch) : 2.5271718502044678, Correct Predictions(batch) : 2/128
idx : 5 , Current Loss(batch) : 0.855148196220398, Correct Predictions(batch) : 127/128
idx : 10 , Current Loss(batch) : 0.2452448010444641, Correct Predictions(batch) : 128/128
Epoch: 0, Loss (per batch) : 0.8205, Accuracy: 0.8685
Test Accuracy: 0.8115
Average Test Loss : 0.6565
idx : 0 , Current Loss(batch) : 0.0667416974902153, Correct Predictions(batch) : 128/128
idx : 5 , Current Loss(batch) : 0.040391165763139725, Correct Predictions(batch) : 128/128
idx : 10 , Current Loss(batch) : 0.022342432290315628, Correct Predictions(batch) : 128/128
Epoch: 1, Loss (per batch) : 0.0376, Accuracy: 1.0000
Test Accuracy: 0.8115
Average Test Loss : 0.6522
idx : 0 , Current Loss(batch) : 0.024831531569361687, Correct Predictions(batch) : 128/128
idx : 5 , Current Loss(batch) : 0.013432176783680916, Correct Predictions(batch) : 128/128
idx : 10 , Current Loss(batch) : 0.010921841487288475, Correct Predictions(

In [None]:
np.save(f'/content/gdrive/MyDrive/imagenette2-320/Resnet_weights/a_.npy',a_)
np.save(f'/content/gdrive/MyDrive/imagenette2-320/Resnet_weights/b_.npy',b_)
np.save(f'/content/gdrive/MyDrive/imagenette2-320/Resnet_weights/c_.npy',c_)
np.save(f'/content/gdrive/MyDrive/imagenette2-320/Resnet_weights/d_.npy',d_)
np.save(f'/content/gdrive/MyDrive/imagenette2-320/Resnet_weights/a1_.npy',a1_)
np.save(f'/content/gdrive/MyDrive/imagenette2-320/Resnet_weights/b1_.npy',b1_)
np.save(f'/content/gdrive/MyDrive/imagenette2-320/Resnet_weights/c1_.npy',c1_)
np.save(f'/content/gdrive/MyDrive/imagenette2-320/Resnet_weights/d1_.npy',d1_)




In [None]:
np.save(f'/content/gdrive/MyDrive/imagenette2-320/Resnet_weights/a.npy',a)
np.save(f'/content/gdrive/MyDrive/imagenette2-320/Resnet_weights/b.npy',b)
np.save(f'/content/gdrive/MyDrive/imagenette2-320/Resnet_weights/c.npy',c)
np.save(f'/content/gdrive/MyDrive/imagenette2-320/Resnet_weights/d.npy',d)


In [None]:
np.save(f'/content/gdrive/MyDrive/imagenette2-320/Resnet_weights/a1.npy',a1)
np.save(f'/content/gdrive/MyDrive/imagenette2-320/Resnet_weights/b1.npy',b1)
np.save(f'/content/gdrive/MyDrive/imagenette2-320/Resnet_weights/c1.npy',c1)
np.save(f'/content/gdrive/MyDrive/imagenette2-320/Resnet_weights/d1.npy',d1)
