In [1]:
import os
import copy
os.environ["CUDA_AVAILABLE_DEVICES"] = "0"
import numpy as np
import matplotlib.pyplot as plt
import torch
from lightly.loss import DINOLoss
from lightly.models.modules import DINOProjectionHead
from lightly.models.utils import deactivate_requires_grad, update_momentum
from lightly.transforms.dino_transform import DINOTransform
from lightly.utils.scheduler import cosine_schedule
import torchvision
from torchvision.models import resnet18,efficientnet_b0
from torchvision import datasets, transforms
import torch.nn as nn
from copy import deepcopy
from torchvision import datasets, transforms
from astra.torch.metrics import f1_score,precision_score,recall_score
from astra.torch.utils import train_fn
from torch.utils.data import DataLoader, Dataset, TensorDataset
from tqdm import tqdm,trange
import pandas as pd

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
train_data_path = "/home/rishabh.mondal/Brick-Kilns-project/albk_rishabh/tensor_data/data.pt"
test_data_path = "/home/rishabh.mondal/Brick-Kilns-project/albk_rishabh/tensor_data_final/test_data.pt" 

In [None]:
dataset = torchvision.datasets.ImageNet

In [5]:
train_data = torch.load(train_data_path)
test_data = torch.load(test_data_path)
train_images = train_data['images']
train_labels = train_data['labels']
test_images = test_data['images']
test_labels = test_data['labels']

In [6]:
train_labels.dtype, test_labels.dtype

(torch.uint8, torch.uint8)

In [7]:
train_images.shape,train_labels.shape

(torch.Size([25500, 3, 224, 224]), torch.Size([25500]))

In [8]:
test_images.shape, test_labels.shape

(torch.Size([4500, 3, 224, 224]), torch.Size([4500]))

In [9]:
data_transforms = transforms.Compose([
            transforms.Resize(224),
            transforms.RandomRotation(360),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
       
train_images = train_images.float()
test_images = test_images.float()
# transform data and convert into pytorch dataset fromat to load into dataloader
transformed_images = data_transforms(train_images)
transformed_images.shape

torch.Size([25500, 3, 224, 224])

In [10]:
train_labels.dtype, test_labels.dtype

(torch.uint8, torch.uint8)

In [3]:
dinov2_vits14 = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14", pretrained=True)

Using cache found in /home/vannsh.jani/.cache/torch/hub/facebookresearch_dinov2_main


In [11]:
dinov2_vits14

DinoVisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 384, kernel_size=(14, 14), stride=(14, 14))
    (norm): Identity()
  )
  (blocks): ModuleList(
    (0-11): 12 x NestedTensorBlock(
      (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (attn): MemEffAttention(
        (qkv): Linear(in_features=384, out_features=1152, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=384, out_features=384, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): LayerScale()
      (drop_path1): Identity()
      (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=384, out_features=1536, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=1536, out_features=384, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
      (ls2): LayerScale()
      (drop_path2): Identity()
    )
  )
  (n

In [4]:
class DinoVisionTransformerClassifier(nn.Module):
    def __init__(self):
        super(DinoVisionTransformerClassifier, self).__init__()
        self.transformer = dinov2_vits14
        self.classifier = nn.Sequential(nn.Linear(384, 256), nn.ReLU(), nn.Linear(256, 2))

    def forward(self, x):
        x = self.transformer(x)
        x = self.transformer.norm(x)
        x = self.classifier(x)
        return x


model = DinoVisionTransformerClassifier()

In [13]:
from sklearn.model_selection import train_test_split

In [14]:
x1,x2,y1,y2 = train_test_split(train_images,train_labels,test_size=0.2,random_state=42,stratify=train_labels)
x1.shape,x2.shape,y1.shape,y2.shape

(torch.Size([20400, 3, 224, 224]),
 torch.Size([5100, 3, 224, 224]),
 torch.Size([20400]),
 torch.Size([5100]))

In [15]:
def train(model, X_train, y_train, criterion, epochs, lr, verbose, batch_size=32):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.train()
    model.to(device)
    epoch_losses = []
    train_dataset = TensorDataset(X_train, y_train)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    for epoch in trange(epochs, desc="Epochs"):
        batch_losses = []  # Track losses per batch

        for i in range(0, len(X_train), batch_size):
            optimizer.zero_grad()
            X_batch = X_train[i:i+batch_size].to(device)
            y_batch = y_train[i:i+batch_size].to(device)
            y_pred = model(X_batch)
            loss = criterion(y_pred, y_batch)
            loss.backward()
            optimizer.step()
            batch_losses.append(loss.item())

        # Calculate average loss for the epoch
        epoch_loss = sum(batch_losses) / len(batch_losses)
        epoch_losses.append(epoch_loss)
        
        if verbose:
            print(f"Epoch: {epoch} Loss: {epoch_loss}")

    return epoch_losses


In [8]:
total = 0
for p in list(model.parameters()):
    # print(p.size())
    total += p.numel()
total


22155650

In [16]:
epoch_losses = train(model,train_images,train_labels,nn.CrossEntropyLoss(),250,0.0001,True,256)

Epochs:   0%|          | 1/250 [01:21<5:37:28, 81.32s/it]

Epoch: 0 Loss: 0.25385498493909836


Epochs:   1%|          | 2/250 [02:37<5:24:26, 78.50s/it]

Epoch: 1 Loss: 0.2356398443877697


Epochs:   1%|          | 3/250 [03:53<5:18:08, 77.28s/it]

Epoch: 2 Loss: 0.20249363876879214


Epochs:   2%|▏         | 4/250 [05:09<5:14:32, 76.72s/it]

Epoch: 3 Loss: 0.17608327999711038


Epochs:   2%|▏         | 5/250 [06:25<5:12:00, 76.41s/it]

Epoch: 4 Loss: 0.16226824067533016


Epochs:   2%|▏         | 6/250 [07:41<5:09:59, 76.23s/it]

Epoch: 5 Loss: 0.15793476402759551


Epochs:   3%|▎         | 7/250 [08:57<5:08:17, 76.12s/it]

Epoch: 6 Loss: 0.15115561556071044


Epochs:   3%|▎         | 8/250 [10:13<5:06:51, 76.08s/it]

Epoch: 7 Loss: 0.14693101592361926


Epochs:   4%|▎         | 9/250 [11:29<5:05:22, 76.03s/it]

Epoch: 8 Loss: 0.14504899736493826


Epochs:   4%|▍         | 10/250 [12:44<5:03:56, 75.99s/it]

Epoch: 9 Loss: 0.14129988949745895


Epochs:   4%|▍         | 11/250 [14:00<5:02:35, 75.97s/it]

Epoch: 10 Loss: 0.13824550811201333


Epochs:   5%|▍         | 12/250 [15:16<5:01:15, 75.95s/it]

Epoch: 11 Loss: 0.13519778810441493


Epochs:   5%|▌         | 13/250 [16:32<4:59:56, 75.94s/it]

Epoch: 12 Loss: 0.13049333224073054


Epochs:   6%|▌         | 14/250 [17:48<4:58:39, 75.93s/it]

Epoch: 13 Loss: 0.127298887334764


Epochs:   6%|▌         | 15/250 [19:04<4:57:23, 75.93s/it]

Epoch: 14 Loss: 0.12209689669311047


Epochs:   6%|▋         | 16/250 [20:20<4:56:10, 75.94s/it]

Epoch: 15 Loss: 0.1148313501663506


Epochs:   7%|▋         | 17/250 [21:36<4:54:56, 75.95s/it]

Epoch: 16 Loss: 0.11016385709866881


Epochs:   7%|▋         | 18/250 [22:52<4:53:37, 75.94s/it]

Epoch: 17 Loss: 0.10156538253650069


Epochs:   8%|▊         | 19/250 [24:08<4:52:20, 75.93s/it]

Epoch: 18 Loss: 0.09951435018330812


Epochs:   8%|▊         | 20/250 [25:24<4:51:01, 75.92s/it]

Epoch: 19 Loss: 0.0924288378842175


Epochs:   8%|▊         | 21/250 [26:40<4:49:41, 75.90s/it]

Epoch: 20 Loss: 0.08671568250283598


Epochs:   9%|▉         | 22/250 [27:55<4:48:26, 75.90s/it]

Epoch: 21 Loss: 0.08631590210366995


Epochs:   9%|▉         | 23/250 [29:11<4:47:13, 75.92s/it]

Epoch: 22 Loss: 0.08370115087367594


Epochs:  10%|▉         | 24/250 [30:27<4:45:58, 75.92s/it]

Epoch: 23 Loss: 0.08184430656023324


Epochs:  10%|█         | 25/250 [31:43<4:44:42, 75.92s/it]

Epoch: 24 Loss: 0.07334069529548287


Epochs:  10%|█         | 26/250 [32:59<4:43:27, 75.93s/it]

Epoch: 25 Loss: 0.06961551800835877


Epochs:  11%|█         | 27/250 [34:15<4:42:12, 75.93s/it]

Epoch: 26 Loss: 0.06957950656302274


Epochs:  11%|█         | 28/250 [35:31<4:40:59, 75.94s/it]

Epoch: 27 Loss: 0.06633375048171729


Epochs:  12%|█▏        | 29/250 [36:47<4:39:41, 75.94s/it]

Epoch: 28 Loss: 0.08394828253425658


Epochs:  12%|█▏        | 30/250 [38:03<4:38:22, 75.92s/it]

Epoch: 29 Loss: 0.06831895485520363


Epochs:  12%|█▏        | 31/250 [39:19<4:37:06, 75.92s/it]

Epoch: 30 Loss: 0.060278645539656284


Epochs:  13%|█▎        | 32/250 [40:35<4:35:50, 75.92s/it]

Epoch: 31 Loss: 0.0557171257562004


Epochs:  13%|█▎        | 33/250 [41:51<4:34:33, 75.92s/it]

Epoch: 32 Loss: 0.05199863213347271


Epochs:  14%|█▎        | 34/250 [43:07<4:33:18, 75.92s/it]

Epoch: 33 Loss: 0.05015220785746351


Epochs:  14%|█▍        | 35/250 [44:23<4:32:05, 75.93s/it]

Epoch: 34 Loss: 0.04533806032733992


Epochs:  14%|█▍        | 36/250 [45:39<4:30:51, 75.94s/it]

Epoch: 35 Loss: 0.049862908425275236


Epochs:  15%|█▍        | 37/250 [46:55<4:29:40, 75.97s/it]

Epoch: 36 Loss: 0.042837098406162115


Epochs:  15%|█▌        | 38/250 [48:11<4:28:25, 75.97s/it]

Epoch: 37 Loss: 0.04883116844343022


Epochs:  16%|█▌        | 39/250 [49:26<4:27:08, 75.97s/it]

Epoch: 38 Loss: 0.045396200213581324


Epochs:  16%|█▌        | 40/250 [50:42<4:25:47, 75.94s/it]

Epoch: 39 Loss: 0.05361553121358156


Epochs:  16%|█▋        | 41/250 [51:58<4:24:27, 75.92s/it]

Epoch: 40 Loss: 0.04992902347119525


Epochs:  17%|█▋        | 42/250 [53:14<4:23:10, 75.92s/it]

Epoch: 41 Loss: 0.042141640632180496


Epochs:  17%|█▋        | 43/250 [54:30<4:21:52, 75.91s/it]

Epoch: 42 Loss: 0.05199898059014231


Epochs:  18%|█▊        | 44/250 [55:46<4:20:34, 75.89s/it]

Epoch: 43 Loss: 0.040624696258455516


Epochs:  18%|█▊        | 45/250 [57:02<4:19:17, 75.89s/it]

Epoch: 44 Loss: 0.0401305298646912


Epochs:  18%|█▊        | 46/250 [58:18<4:18:00, 75.89s/it]

Epoch: 45 Loss: 0.03114296618849039


Epochs:  19%|█▉        | 47/250 [59:34<4:16:44, 75.89s/it]

Epoch: 46 Loss: 0.024619197994470595


Epochs:  19%|█▉        | 48/250 [1:00:49<4:15:28, 75.88s/it]

Epoch: 47 Loss: 0.02663301958120428


Epochs:  20%|█▉        | 49/250 [1:02:05<4:14:14, 75.89s/it]

Epoch: 48 Loss: 0.024082945478148758


Epochs:  20%|██        | 50/250 [1:03:21<4:12:59, 75.90s/it]

Epoch: 49 Loss: 0.020992480999557302


Epochs:  20%|██        | 51/250 [1:04:37<4:11:43, 75.90s/it]

Epoch: 50 Loss: 0.02430309049319476


Epochs:  21%|██        | 52/250 [1:05:53<4:10:28, 75.90s/it]

Epoch: 51 Loss: 0.0251440140302293


Epochs:  21%|██        | 53/250 [1:07:09<4:09:11, 75.90s/it]

Epoch: 52 Loss: 0.021050889422767796


Epochs:  22%|██▏       | 54/250 [1:08:25<4:07:55, 75.90s/it]

Epoch: 53 Loss: 0.019743605785188266


Epochs:  22%|██▏       | 55/250 [1:09:41<4:06:39, 75.89s/it]

Epoch: 54 Loss: 0.020592022595519666


Epochs:  22%|██▏       | 56/250 [1:10:57<4:05:25, 75.91s/it]

Epoch: 55 Loss: 0.016806966924195876


Epochs:  23%|██▎       | 57/250 [1:12:13<4:04:09, 75.90s/it]

Epoch: 56 Loss: 0.017233660049969332


Epochs:  23%|██▎       | 58/250 [1:13:28<4:02:53, 75.90s/it]

Epoch: 57 Loss: 0.01674878219666425


Epochs:  24%|██▎       | 59/250 [1:14:44<4:01:36, 75.90s/it]

Epoch: 58 Loss: 0.014363283088750905


Epochs:  24%|██▍       | 60/250 [1:16:00<4:00:21, 75.90s/it]

Epoch: 59 Loss: 0.012872870761784725


Epochs:  24%|██▍       | 61/250 [1:17:16<3:59:07, 75.91s/it]

Epoch: 60 Loss: 0.017214299075421876


Epochs:  25%|██▍       | 62/250 [1:18:32<3:57:49, 75.90s/it]

Epoch: 61 Loss: 0.01565613516489975


Epochs:  25%|██▌       | 63/250 [1:19:48<3:56:31, 75.89s/it]

Epoch: 62 Loss: 0.017150674771983176


Epochs:  26%|██▌       | 64/250 [1:21:04<3:55:16, 75.89s/it]

Epoch: 63 Loss: 0.017684966832748613


Epochs:  26%|██▌       | 65/250 [1:22:20<3:53:59, 75.89s/it]

Epoch: 64 Loss: 0.016180212182807737


Epochs:  26%|██▋       | 66/250 [1:23:36<3:52:42, 75.88s/it]

Epoch: 65 Loss: 0.017371032976079732


Epochs:  27%|██▋       | 67/250 [1:24:52<3:51:28, 75.89s/it]

Epoch: 66 Loss: 0.019668910001055338


Epochs:  27%|██▋       | 68/250 [1:26:07<3:50:12, 75.89s/it]

Epoch: 67 Loss: 0.01564921628814773


Epochs:  28%|██▊       | 69/250 [1:27:23<3:48:58, 75.90s/it]

Epoch: 68 Loss: 0.020180614302516914


Epochs:  28%|██▊       | 70/250 [1:28:39<3:47:43, 75.91s/it]

Epoch: 69 Loss: 0.01753228018351365


Epochs:  28%|██▊       | 71/250 [1:29:55<3:46:27, 75.91s/it]

Epoch: 70 Loss: 0.016634115726919845


Epochs:  29%|██▉       | 72/250 [1:31:11<3:45:11, 75.91s/it]

Epoch: 71 Loss: 0.018514700509258547


Epochs:  29%|██▉       | 73/250 [1:32:27<3:43:54, 75.90s/it]

Epoch: 72 Loss: 0.013100327029824256


Epochs:  30%|██▉       | 74/250 [1:33:43<3:42:38, 75.90s/it]

Epoch: 73 Loss: 0.013116178498894442


Epochs:  30%|███       | 75/250 [1:34:59<3:41:22, 75.90s/it]

Epoch: 74 Loss: 0.010255603932746453


Epochs:  30%|███       | 76/250 [1:36:15<3:40:05, 75.89s/it]

Epoch: 75 Loss: 0.010626369091041851


Epochs:  31%|███       | 77/250 [1:37:30<3:38:48, 75.89s/it]

Epoch: 76 Loss: 0.009811325753980782


Epochs:  31%|███       | 78/250 [1:38:46<3:37:33, 75.89s/it]

Epoch: 77 Loss: 0.008652451524394564


Epochs:  32%|███▏      | 79/250 [1:40:02<3:36:19, 75.90s/it]

Epoch: 78 Loss: 0.008127623828768265


Epochs:  32%|███▏      | 80/250 [1:41:18<3:35:06, 75.92s/it]

Epoch: 79 Loss: 0.011252497236491764


Epochs:  32%|███▏      | 81/250 [1:42:34<3:33:53, 75.94s/it]

Epoch: 80 Loss: 0.007763421248382656


Epochs:  33%|███▎      | 82/250 [1:43:50<3:32:39, 75.95s/it]

Epoch: 81 Loss: 0.006352471606296603


Epochs:  33%|███▎      | 83/250 [1:45:06<3:31:25, 75.96s/it]

Epoch: 82 Loss: 0.006771935470533208


Epochs:  34%|███▎      | 84/250 [1:46:22<3:30:10, 75.97s/it]

Epoch: 83 Loss: 0.0072323688150208905


Epochs:  34%|███▍      | 85/250 [1:47:38<3:28:49, 75.94s/it]

Epoch: 84 Loss: 0.009240988348246902


Epochs:  34%|███▍      | 86/250 [1:48:54<3:27:34, 75.94s/it]

Epoch: 85 Loss: 0.007680728758205077


Epochs:  35%|███▍      | 87/250 [1:50:10<3:26:18, 75.94s/it]

Epoch: 86 Loss: 0.013940382988439524


Epochs:  35%|███▌      | 88/250 [1:51:26<3:25:03, 75.95s/it]

Epoch: 87 Loss: 0.005527182729929336


Epochs:  36%|███▌      | 89/250 [1:52:42<3:23:47, 75.95s/it]

Epoch: 88 Loss: 0.007723114323089249


Epochs:  36%|███▌      | 90/250 [1:53:58<3:22:30, 75.94s/it]

Epoch: 89 Loss: 0.0054643683148970015


Epochs:  36%|███▋      | 91/250 [1:55:14<3:21:12, 75.93s/it]

Epoch: 90 Loss: 0.005353055291761848


Epochs:  37%|███▋      | 92/250 [1:56:30<3:19:54, 75.92s/it]

Epoch: 91 Loss: 0.00541404091174627


Epochs:  37%|███▋      | 93/250 [1:57:46<3:18:39, 75.92s/it]

Epoch: 92 Loss: 0.00819218242620991


Epochs:  38%|███▊      | 94/250 [1:59:01<3:17:22, 75.92s/it]

Epoch: 93 Loss: 0.006990799202976632


Epochs:  38%|███▊      | 95/250 [2:00:17<3:16:07, 75.92s/it]

Epoch: 94 Loss: 0.007144064451640588


Epochs:  38%|███▊      | 96/250 [2:01:33<3:14:50, 75.91s/it]

Epoch: 95 Loss: 0.00634813850863793


Epochs:  39%|███▉      | 97/250 [2:02:49<3:13:34, 75.91s/it]

Epoch: 96 Loss: 0.00699358410325658


Epochs:  39%|███▉      | 98/250 [2:04:05<3:12:18, 75.91s/it]

Epoch: 97 Loss: 0.00642144497098343


Epochs:  40%|███▉      | 99/250 [2:05:21<3:11:03, 75.91s/it]

Epoch: 98 Loss: 0.007201146289808094


Epochs:  40%|████      | 100/250 [2:06:37<3:09:47, 75.92s/it]

Epoch: 99 Loss: 0.008578206957899966


Epochs:  40%|████      | 101/250 [2:07:53<3:08:31, 75.92s/it]

Epoch: 100 Loss: 0.006643479610356735


Epochs:  41%|████      | 102/250 [2:09:09<3:07:15, 75.91s/it]

Epoch: 101 Loss: 0.005426391870432781


Epochs:  41%|████      | 103/250 [2:10:25<3:06:00, 75.92s/it]

Epoch: 102 Loss: 0.006426498476685083


Epochs:  42%|████▏     | 104/250 [2:11:41<3:04:46, 75.94s/it]

Epoch: 103 Loss: 0.006831397162895883


Epochs:  42%|████▏     | 105/250 [2:12:57<3:03:33, 75.96s/it]

Epoch: 104 Loss: 0.008257537404242612


Epochs:  42%|████▏     | 106/250 [2:14:13<3:02:20, 75.98s/it]

Epoch: 105 Loss: 0.005844958982343087


Epochs:  43%|████▎     | 107/250 [2:15:29<3:01:04, 75.98s/it]

Epoch: 106 Loss: 0.0059300198342680235


Epochs:  43%|████▎     | 108/250 [2:16:45<2:59:49, 75.98s/it]

Epoch: 107 Loss: 0.008385208913532552


Epochs:  44%|████▎     | 109/250 [2:18:01<2:58:34, 75.99s/it]

Epoch: 108 Loss: 0.00704836041724775


Epochs:  44%|████▍     | 110/250 [2:19:17<2:57:18, 75.99s/it]

Epoch: 109 Loss: 0.006461721303348895


Epochs:  44%|████▍     | 111/250 [2:20:33<2:56:05, 76.01s/it]

Epoch: 110 Loss: 0.004557399111217819


Epochs:  45%|████▍     | 112/250 [2:21:49<2:54:52, 76.03s/it]

Epoch: 111 Loss: 0.007164792363910238


Epochs:  45%|████▌     | 113/250 [2:23:05<2:53:36, 76.03s/it]

Epoch: 112 Loss: 0.008456054020134616


Epochs:  46%|████▌     | 114/250 [2:24:21<2:52:24, 76.06s/it]

Epoch: 113 Loss: 0.005663661411090288


Epochs:  46%|████▌     | 115/250 [2:25:37<2:51:08, 76.07s/it]

Epoch: 114 Loss: 0.00485388648485241


Epochs:  46%|████▋     | 116/250 [2:26:53<2:49:55, 76.08s/it]

Epoch: 115 Loss: 0.006693782485908742


Epochs:  47%|████▋     | 117/250 [2:28:09<2:48:40, 76.09s/it]

Epoch: 116 Loss: 0.0077156853980704905


Epochs:  47%|████▋     | 118/250 [2:29:25<2:47:24, 76.09s/it]

Epoch: 117 Loss: 0.007499015828070697


Epochs:  48%|████▊     | 119/250 [2:30:41<2:46:08, 76.09s/it]

Epoch: 118 Loss: 0.005439545958215604


Epochs:  48%|████▊     | 120/250 [2:31:57<2:44:50, 76.08s/it]

Epoch: 119 Loss: 0.004656282649921195


Epochs:  48%|████▊     | 121/250 [2:33:14<2:43:33, 76.08s/it]

Epoch: 120 Loss: 0.006350424746015051


Epochs:  49%|████▉     | 122/250 [2:34:30<2:42:15, 76.06s/it]

Epoch: 121 Loss: 0.0030073195903241867


Epochs:  49%|████▉     | 123/250 [2:35:46<2:41:04, 76.10s/it]

Epoch: 122 Loss: 0.008395577073752065


Epochs:  50%|████▉     | 124/250 [2:37:02<2:39:45, 76.07s/it]

Epoch: 123 Loss: 0.004730730010633124


Epochs:  50%|█████     | 125/250 [2:38:24<2:42:13, 77.87s/it]

Epoch: 124 Loss: 0.00591265849499905


Epochs:  50%|█████     | 126/250 [2:40:18<3:03:42, 88.89s/it]

Epoch: 125 Loss: 0.005446548876770976


Epochs:  50%|█████     | 126/250 [2:41:39<2:39:05, 76.98s/it]


KeyboardInterrupt: 

In [None]:
plt.plot(epoch_losses)

In [17]:
def get_metrics(y_pred,y_label):
    with torch.no_grad():
        f1 = f1_score(y_pred,y_label)
        precision = precision_score(y_pred,y_label)
        recall = recall_score(y_pred,y_label)
        return f1,precision,recall

def predict(model,X_test,y_test,batch_size = 32):
    result_dict = {}
    with torch.no_grad():
        model.eval()
        # incorporating batch size
        test_dataset = TensorDataset(X_test, y_test)
        test_loader = DataLoader(test_dataset, batch_size=batch_size)
        y_pred = []
        for x, y in tqdm(test_loader):
            x = x.to(device)
            y = y.to(device)
            output = model(x)
            y_pred.append(output)
        y_pred = torch.cat(y_pred)
        print(y_pred)
        y_pred = torch.argmax(y_pred,dim=1)
        print(y_pred)
        y_pred = y_pred.cpu()
        y_test = y_test.cpu()
        # print(y_pred.dtype,y_test.dtype)
        f1,precision,recall = get_metrics(y_pred,y_test)


        result_dict = {"f1":f1,"precision":precision,"recall":recall}
        return result_dict,y_pred,y_test


In [18]:
result,y_pred,y_test = predict(model,test_images,test_labels,batch_size=128)

100%|██████████| 36/36 [00:55<00:00,  1.54s/it]

tensor([[ 3.1660, -2.5737],
        [ 0.4591, -0.0945],
        [ 5.3940, -4.7773],
        ...,
        [ 0.2043,  0.5824],
        [ 0.3519,  0.3633],
        [ 0.8291, -0.1084]], device='cuda:0')
tensor([0, 0, 0,  ..., 1, 1, 0], device='cuda:0')





In [19]:
df = pd.DataFrame(result,index=[0])
df

Unnamed: 0,f1,precision,recall
0,tensor(0.0553),tensor(0.0532),tensor(0.0576)


In [None]:
y_test[:100]

In [None]:
for i in y_pred:
    if i.item() == 1:
        print(i)

In [None]:
train_labels

In [None]:
test_labels

In [1]:
%pip install pytorch-lightning

Collecting pytorch-lightning
  Downloading pytorch_lightning-2.2.1-py3-none-any.whl.metadata (21 kB)
Collecting torchmetrics>=0.7.0 (from pytorch-lightning)
  Downloading torchmetrics-1.3.2-py3-none-any.whl.metadata (19 kB)
Collecting lightning-utilities>=0.8.0 (from pytorch-lightning)
  Downloading lightning_utilities-0.11.2-py3-none-any.whl.metadata (4.7 kB)
Downloading pytorch_lightning-2.2.1-py3-none-any.whl (801 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m801.6/801.6 kB[0m [31m757.4 kB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hDownloading lightning_utilities-0.11.2-py3-none-any.whl (26 kB)
Downloading torchmetrics-1.3.2-py3-none-any.whl (841 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m841.5/841.5 kB[0m [31m1.3 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling collected packages: lightning-utilities, torchmetrics, pytorch-lightning
Successfully installed lightning-utilities-0.11.2 pytorch-lightning-2.2.1 torchme

In [2]:
%pip install lightly

Note: you may need to restart the kernel to use updated packages.


In [11]:
class DINO(torch.nn.Module):
    def __init__(self, backbone, input_dim):
        super().__init__()
        self.student_backbone = backbone
        self.student_head = DINOProjectionHead(
            input_dim, 512, 64, 2048, freeze_last_layer=1
        )
        self.teacher_backbone = copy.deepcopy(backbone)
        self.teacher_head = DINOProjectionHead(input_dim, 512, 64, 2048)
        deactivate_requires_grad(self.teacher_backbone)
        deactivate_requires_grad(self.teacher_head)

    def forward(self, x):
        y = self.student_backbone(x).flatten(start_dim=1)
        z = self.student_head(y)
        return z

    def forward_teacher(self, x):
        y = self.teacher_backbone(x).flatten(start_dim=1)
        z = self.teacher_head(y)
        return z

In [12]:
efficientnet_b0 = torchvision.models.efficientnet_b0()
backbone = nn.Sequential(*list(efficientnet_b0.children())[:-1])

In [13]:
backbone

Sequential(
  (0): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): SiLU(inplace=True)
    )
    (1): Sequential(
      (0): MBConv(
        (block): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
            (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): SiLU(inplace=True)
          )
          (1): SqueezeExcitation(
            (avgpool): AdaptiveAvgPool2d(output_size=1)
            (fc1): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
            (fc2): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
            (activation): SiLU(inplace=True)
            (scale_activation): Sigmoid()
          )
          (2): Conv2dNormActivation(
    

In [14]:
input_dim = 1280
model = DINO(backbone, input_dim).to(device)



In [15]:
model

DINO(
  (student_backbone): Sequential(
    (0): Sequential(
      (0): Conv2dNormActivation(
        (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): SiLU(inplace=True)
      )
      (1): Sequential(
        (0): MBConv(
          (block): Sequential(
            (0): Conv2dNormActivation(
              (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
              (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): SiLU(inplace=True)
            )
            (1): SqueezeExcitation(
              (avgpool): AdaptiveAvgPool2d(output_size=1)
              (fc1): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
              (fc2): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
              (activation): SiLU(inplace=True)
              (scale_activati

In [16]:
transform = DINOTransform()

In [17]:
train_dataset = TensorDataset(transformed_images, train_labels)
dataloader = DataLoader(
    train_dataset,
    batch_size=256,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)

In [18]:
criterion = DINOLoss(
    output_dim=2048,
    warmup_teacher_temp_epochs=5,
)
criterion = criterion.to(device)

In [19]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [20]:
epochs = 10

In [21]:
print("Starting Training")
for epoch in range(epochs):
    total_loss = 0
    momentum_val = cosine_schedule(epoch, epochs, 0.996, 1)
    for batch in tqdm(dataloader):
        views = batch[0]
        update_momentum(model.student_backbone, model.teacher_backbone, m=momentum_val)
        update_momentum(model.student_head, model.teacher_head, m=momentum_val)
        views = [view.to(device) for view in views]
        global_views = views[:2]
        teacher_out = [model.forward_teacher(view.unsqueeze(0)) for view in global_views]
        student_out = [model.forward(view.unsqueeze(0)) for view in views]
        loss = criterion(teacher_out, student_out, epoch=epoch)
        total_loss += loss.detach()
        loss.backward()
        # We only cancel gradients of student head.
        model.student_head.cancel_last_layer_gradients(current_epoch=epoch)
        optimizer.step()
        optimizer.zero_grad()

    avg_loss = total_loss / len(dataloader)
    print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")

Starting Training


100%|██████████| 99/99 [10:51<00:00,  6.58s/it] 


epoch: 00, loss: 8.04164


100%|██████████| 99/99 [10:17<00:00,  6.24s/it]


epoch: 01, loss: 7.62096


100%|██████████| 99/99 [10:11<00:00,  6.17s/it]


epoch: 02, loss: 7.61836


100%|██████████| 99/99 [10:29<00:00,  6.35s/it]


epoch: 03, loss: 7.62281


100%|██████████| 99/99 [10:08<00:00,  6.15s/it]


epoch: 04, loss: 7.62415


 16%|█▌        | 16/99 [01:42<08:26,  6.10s/it]

In [47]:
transformed_images.shape

torch.Size([25500, 3, 224, 224])

In [2]:
weights = torch.load("/home/vannsh.jani/brick_kilns/dinov2_vits14_pretrain.pth")

In [4]:
weights.keys()

dict_keys(['cls_token', 'pos_embed', 'mask_token', 'patch_embed.proj.weight', 'patch_embed.proj.bias', 'blocks.0.norm1.weight', 'blocks.0.norm1.bias', 'blocks.0.attn.qkv.weight', 'blocks.0.attn.qkv.bias', 'blocks.0.attn.proj.weight', 'blocks.0.attn.proj.bias', 'blocks.0.ls1.gamma', 'blocks.0.norm2.weight', 'blocks.0.norm2.bias', 'blocks.0.mlp.fc1.weight', 'blocks.0.mlp.fc1.bias', 'blocks.0.mlp.fc2.weight', 'blocks.0.mlp.fc2.bias', 'blocks.0.ls2.gamma', 'blocks.1.norm1.weight', 'blocks.1.norm1.bias', 'blocks.1.attn.qkv.weight', 'blocks.1.attn.qkv.bias', 'blocks.1.attn.proj.weight', 'blocks.1.attn.proj.bias', 'blocks.1.ls1.gamma', 'blocks.1.norm2.weight', 'blocks.1.norm2.bias', 'blocks.1.mlp.fc1.weight', 'blocks.1.mlp.fc1.bias', 'blocks.1.mlp.fc2.weight', 'blocks.1.mlp.fc2.bias', 'blocks.1.ls2.gamma', 'blocks.2.norm1.weight', 'blocks.2.norm1.bias', 'blocks.2.attn.qkv.weight', 'blocks.2.attn.qkv.bias', 'blocks.2.attn.proj.weight', 'blocks.2.attn.proj.bias', 'blocks.2.ls1.gamma', 'blocks.2

In [5]:
weights["norm.weight"].shape,weights["norm.bias"].shape

(torch.Size([384]), torch.Size([384]))