<a href="https://colab.research.google.com/github/GersteinJo/CV-frameworks/blob/main/finetuning_dino.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
from torchvision import datasets, transforms
from PIL import Image
import matplotlib.pyplot as plt
from torchvision.io import read_image

In [2]:
class DataAugmentationDINO(object):
    def __init__(self, global_crops_scale, local_crops_scale, local_crops_number):
        flip_and_color_jitter = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply(
                [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
                p=0.8
            ),
            transforms.RandomGrayscale(p=0.2),
        ])
        normalize = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])

        # first global crop
        self.global_transfo1 = transforms.Compose([
            transforms.RandomResizedCrop(224, scale=global_crops_scale, interpolation=Image.BICUBIC),
            flip_and_color_jitter,
            transforms.GaussianBlur(3, 1.0),
            normalize,
        ])
        # second global crop
        self.global_transfo2 = transforms.Compose([
            transforms.RandomResizedCrop(224, scale=global_crops_scale, interpolation=Image.BICUBIC),
            flip_and_color_jitter,
            transforms.GaussianBlur(3, 0.1),
            transforms.RandomSolarize(128, 0.2),
            normalize,
        ])
        # transformation for the local small crops
        self.local_crops_number = local_crops_number
        self.local_transfo = transforms.Compose([
            transforms.RandomResizedCrop(98, scale=local_crops_scale, interpolation=Image.BICUBIC),
            flip_and_color_jitter,
            transforms.GaussianBlur(3, 0.5),
            normalize,
        ])

    def __call__(self, image):
        crops = []
        crops.append(self.global_transfo1(image))
        crops.append(self.global_transfo2(image))
        for _ in range(self.local_crops_number):
            crops.append(self.local_transfo(image))
        return crops

In [3]:
from torch import nn
import numpy as np
import torch.nn.functional as F

In [4]:
class DINOLoss(nn.Module):
    def __init__(self, out_dim, ncrops, warmup_teacher_temp, teacher_temp,
                 warmup_teacher_temp_epochs, nepochs, student_temp=0.1,
                 center_momentum=0.9):
        super().__init__()
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        self.ncrops = ncrops
        self.register_buffer("center", torch.zeros(1, out_dim))
        # we apply a warm up for the teacher temperature because
        # a too high temperature makes the training instable at the beginning
        self.teacher_temp_schedule = np.concatenate((
            np.linspace(warmup_teacher_temp,
                        teacher_temp, warmup_teacher_temp_epochs),
            np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp
        ))

    def forward(self, student_output, teacher_output, epoch):
        """
        Cross-entropy between softmax outputs of the teacher and student networks.
        student_output [b, n, k]
        teacher_output [b, n, k]

        """
        student_out = student_output / self.student_temp
        student_out = student_out.chunk(self.ncrops)

        # teacher centering and sharpening
        temp = self.teacher_temp_schedule[epoch]
        teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1)
        teacher_out = teacher_out.detach().chunk(2)

        total_loss = 0
        n_loss_terms = 0
        for iq, q in enumerate(teacher_out):
            for v in range(len(student_out)):
                if v == iq:
                    # we skip cases where student and teacher operate on the same view
                    continue
                loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1)
                total_loss += loss.mean()
                n_loss_terms += 1
        total_loss /= n_loss_terms
        self.update_center(teacher_output)
        return total_loss

    @torch.no_grad()
    def update_center(self, teacher_output):
        """
        Update center used for teacher output.
        """
        batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
        batch_center = batch_center / len(teacher_output)

        # ema update
        self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)


In [5]:
!mkdir -p ~/.kaggle
!mv kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json  # Restrict permissions
!kaggle datasets download -d kipshidze/drunk-vs-sober-infrared-image-dataset
!unzip -q "drunk-vs-sober-infrared-image-dataset.zip" -d /content/

Dataset URL: https://www.kaggle.com/datasets/kipshidze/drunk-vs-sober-infrared-image-dataset
License(s): MIT
Downloading drunk-vs-sober-infrared-image-dataset.zip to /content
  0% 0.00/1.45M [00:00<?, ?B/s]
100% 1.45M/1.45M [00:00<00:00, 869MB/s]


In [6]:
from google.colab import drive
drive.mount('/content/drive')

import torch
from torchvision import datasets, transforms

# Load DINOv2 onto GPU
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14').to(device)




Mounted at /content/drive


Downloading: "https://github.com/facebookresearch/dinov2/zipball/main" to /root/.cache/torch/hub/main.zip
Downloading: "https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth" to /root/.cache/torch/hub/checkpoints/dinov2_vits14_pretrain.pth
100%|██████████| 84.2M/84.2M [00:00<00:00, 322MB/s]


In [7]:
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import os
from PIL import Image

class SelfSupervisedDataset(Dataset):
    def __init__(self, csv_file, root_dir, augmentation=None):
        self.annotations = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.augmentation = augmentation

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        img_path = os.path.join(self.root_dir, self.annotations.iloc[idx, 0])
        image = Image.open(img_path).convert('RGB')

        if self.augmentation:
            crops = self.augmentation(image)  # Returns list of augmented crops
            return crops  # List of tensors [global1, global2, local1, local2, ...]
        return image

# DINO Augmentation (from your code)
global_crops_scale = (0.4, 1.0)
local_crops_scale = (0.05, 0.4)
local_crops_number = 8  # 8 small local crops (adjust if needed)

augmentation = DataAugmentationDINO(
    global_crops_scale=global_crops_scale,
    local_crops_scale=local_crops_scale,
    local_crops_number=local_crops_number,
)

# Create dataset and dataloader
train_dataset = SelfSupervisedDataset(csv_file='/content/train.csv', root_dir='/content/', augmentation=augmentation)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=2)

In [10]:
import torch
import torch.nn as nn

# Load DINOv2
model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')

# Projection head (for DINO loss)
class DINOHead(nn.Module):
    def __init__(self, in_dim, out_dim=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, 2048),
            nn.GELU(),
            nn.Linear(2048, 2048),
            nn.GELU(),
            nn.Linear(2048, out_dim),
        )

    def forward(self, x):
        return self.mlp(x)

# Student & Teacher models
student = model.to(device)
teacher = model.to(device)

# Add projection heads
out_dim = 256  # Dimension for DINO loss
student.head = DINOHead(model.embed_dim, out_dim).to(device)
teacher.head = DINOHead(model.embed_dim, out_dim).to(device)

# Teacher does not get updated by gradients
for param in teacher.parameters():
    param.requires_grad = False

Using cache found in /root/.cache/torch/hub/facebookresearch_dinov2_main


In [11]:
# Hyperparameters (adjust as needed)
from tqdm import tqdm
# Hyperparameters (adjust as needed)
lr = 1e-4
epochs = 100
warmup_epochs = 10
momentum_teacher = 0.996  # EMA decay rate

optimizer = torch.optim.AdamW(student.parameters(), lr=lr)

# Your DINOLoss
loss_fn = DINOLoss(
    out_dim=out_dim,
    ncrops=2 + local_crops_number,  # 2 global + N local crops
    warmup_teacher_temp=0.04,
    teacher_temp=0.07,
    warmup_teacher_temp_epochs=warmup_epochs,
    nepochs=epochs,
).to(device)

student.train()
# Training loop
for epoch in range(epochs):
    student.train()
    teacher.train()  # Teacher in train mode (for BatchNorm), but no gradients

    for crops in tqdm(train_loader):
        # crops: List of tensors [global1, global2, local1, local2, ...]
        crops = [crop.to(device).requires_grad_(True) for crop in crops]

        # Forward passes
        teacher_outputs = [teacher(crop) for crop in crops[:2]]  # Only global crops for teacher
        student_outputs = [student(crop) for crop in crops]  # All crops for student

        # Concatenate outputs
        teacher_output = torch.cat(teacher_outputs)
        student_output = torch.cat(student_outputs)

        # Compute DINO loss
        loss = loss_fn(student_output, teacher_output, epoch)

        # Backpropagation
        optimizer.zero_grad()

        loss.backward()
        optimizer.step()

        # EMA update for teacher
        with torch.no_grad():
            for s_param, t_param in zip(student.parameters(), teacher.parameters()):
                t_param.data = momentum_teacher * t_param.data + (1 - momentum_teacher) * s_param.data

    print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

100%|██████████| 25/25 [00:21<00:00,  1.15it/s]


Epoch 0, Loss: 6.1673


100%|██████████| 25/25 [00:23<00:00,  1.05it/s]


Epoch 1, Loss: 6.1956


100%|██████████| 25/25 [00:22<00:00,  1.10it/s]


Epoch 2, Loss: 6.4770


100%|██████████| 25/25 [00:22<00:00,  1.09it/s]


Epoch 3, Loss: 6.2633


100%|██████████| 25/25 [00:23<00:00,  1.07it/s]


Epoch 4, Loss: 6.3069


100%|██████████| 25/25 [00:22<00:00,  1.12it/s]


Epoch 5, Loss: 6.4187


100%|██████████| 25/25 [00:23<00:00,  1.08it/s]


Epoch 6, Loss: 6.6434


100%|██████████| 25/25 [00:23<00:00,  1.06it/s]


Epoch 7, Loss: 6.2988


100%|██████████| 25/25 [00:22<00:00,  1.10it/s]


Epoch 8, Loss: 6.2691


100%|██████████| 25/25 [00:23<00:00,  1.05it/s]


Epoch 9, Loss: 6.3628


100%|██████████| 25/25 [00:23<00:00,  1.05it/s]


Epoch 10, Loss: 6.3501


100%|██████████| 25/25 [00:22<00:00,  1.11it/s]


Epoch 11, Loss: 6.3180


100%|██████████| 25/25 [00:23<00:00,  1.05it/s]


Epoch 12, Loss: 6.2548


100%|██████████| 25/25 [00:23<00:00,  1.08it/s]


Epoch 13, Loss: 6.6210


100%|██████████| 25/25 [00:22<00:00,  1.13it/s]


Epoch 14, Loss: 6.4417


100%|██████████| 25/25 [00:23<00:00,  1.06it/s]


Epoch 15, Loss: 6.3641


100%|██████████| 25/25 [00:22<00:00,  1.11it/s]


Epoch 16, Loss: 6.5674


100%|██████████| 25/25 [00:22<00:00,  1.10it/s]


Epoch 17, Loss: 6.5832


100%|██████████| 25/25 [00:23<00:00,  1.06it/s]


Epoch 18, Loss: 6.3913


100%|██████████| 25/25 [00:22<00:00,  1.12it/s]


Epoch 19, Loss: 6.3283


100%|██████████| 25/25 [00:23<00:00,  1.08it/s]


Epoch 20, Loss: 6.4720


100%|██████████| 25/25 [00:23<00:00,  1.05it/s]


Epoch 21, Loss: 6.4100


100%|██████████| 25/25 [00:23<00:00,  1.09it/s]


Epoch 22, Loss: 6.5172


100%|██████████| 25/25 [00:23<00:00,  1.06it/s]


Epoch 23, Loss: 6.4985


100%|██████████| 25/25 [00:23<00:00,  1.06it/s]


Epoch 24, Loss: 6.6639


100%|██████████| 25/25 [00:22<00:00,  1.10it/s]


Epoch 25, Loss: 6.4253


100%|██████████| 25/25 [00:22<00:00,  1.10it/s]


Epoch 26, Loss: 6.6434


100%|██████████| 25/25 [00:23<00:00,  1.05it/s]


Epoch 27, Loss: 6.4239


100%|██████████| 25/25 [00:22<00:00,  1.10it/s]


Epoch 28, Loss: 6.5823


100%|██████████| 25/25 [00:23<00:00,  1.06it/s]


Epoch 29, Loss: 6.2753


100%|██████████| 25/25 [00:23<00:00,  1.05it/s]


Epoch 30, Loss: 6.4159


100%|██████████| 25/25 [00:22<00:00,  1.11it/s]


Epoch 31, Loss: 6.2491


100%|██████████| 25/25 [00:23<00:00,  1.08it/s]


Epoch 32, Loss: 6.5995


100%|██████████| 25/25 [00:23<00:00,  1.06it/s]


Epoch 33, Loss: 6.2737


100%|██████████| 25/25 [00:22<00:00,  1.11it/s]


Epoch 34, Loss: 6.3580


100%|██████████| 25/25 [00:23<00:00,  1.05it/s]


Epoch 35, Loss: 6.5059


100%|██████████| 25/25 [00:23<00:00,  1.06it/s]


Epoch 36, Loss: 6.6398


100%|██████████| 25/25 [00:22<00:00,  1.10it/s]


Epoch 37, Loss: 6.5693


100%|██████████| 25/25 [00:23<00:00,  1.07it/s]


Epoch 38, Loss: 6.3149


100%|██████████| 25/25 [00:24<00:00,  1.03it/s]


Epoch 39, Loss: 6.7736


100%|██████████| 25/25 [00:23<00:00,  1.08it/s]


Epoch 40, Loss: 6.7250


100%|██████████| 25/25 [00:23<00:00,  1.06it/s]


Epoch 41, Loss: 6.8468


100%|██████████| 25/25 [00:23<00:00,  1.06it/s]


Epoch 42, Loss: 6.5508


100%|██████████| 25/25 [00:22<00:00,  1.10it/s]


Epoch 43, Loss: 6.6081


100%|██████████| 25/25 [00:23<00:00,  1.07it/s]


Epoch 44, Loss: 6.3935


100%|██████████| 25/25 [00:23<00:00,  1.06it/s]


Epoch 45, Loss: 6.1805


100%|██████████| 25/25 [00:22<00:00,  1.10it/s]


Epoch 46, Loss: 6.3477


100%|██████████| 25/25 [00:22<00:00,  1.09it/s]


Epoch 47, Loss: 6.7621


100%|██████████| 25/25 [00:23<00:00,  1.06it/s]


Epoch 48, Loss: 6.5100


100%|██████████| 25/25 [00:23<00:00,  1.08it/s]


Epoch 49, Loss: 6.6374


100%|██████████| 25/25 [00:22<00:00,  1.09it/s]


Epoch 50, Loss: 6.2127


100%|██████████| 25/25 [00:23<00:00,  1.07it/s]


Epoch 51, Loss: 6.5505


100%|██████████| 25/25 [00:22<00:00,  1.11it/s]


Epoch 52, Loss: 6.4079


100%|██████████| 25/25 [00:23<00:00,  1.08it/s]


Epoch 53, Loss: 6.2575


100%|██████████| 25/25 [00:23<00:00,  1.06it/s]


Epoch 54, Loss: 6.5221


100%|██████████| 25/25 [00:22<00:00,  1.11it/s]


Epoch 55, Loss: 6.5759


100%|██████████| 25/25 [00:23<00:00,  1.08it/s]


Epoch 56, Loss: 6.6573


100%|██████████| 25/25 [00:23<00:00,  1.07it/s]


Epoch 57, Loss: 6.4814


100%|██████████| 25/25 [00:22<00:00,  1.11it/s]


Epoch 58, Loss: 6.5277


100%|██████████| 25/25 [00:23<00:00,  1.05it/s]


Epoch 59, Loss: 6.6413


100%|██████████| 25/25 [00:23<00:00,  1.07it/s]


Epoch 60, Loss: 6.0232


100%|██████████| 25/25 [00:22<00:00,  1.11it/s]


Epoch 61, Loss: 6.4754


100%|██████████| 25/25 [00:23<00:00,  1.08it/s]


Epoch 62, Loss: 6.5966


100%|██████████| 25/25 [00:23<00:00,  1.08it/s]


Epoch 63, Loss: 6.4156


100%|██████████| 25/25 [00:22<00:00,  1.11it/s]


Epoch 64, Loss: 6.6075


100%|██████████| 25/25 [00:23<00:00,  1.06it/s]


Epoch 65, Loss: 6.5081


100%|██████████| 25/25 [00:22<00:00,  1.10it/s]


Epoch 66, Loss: 6.3986


100%|██████████| 25/25 [00:22<00:00,  1.12it/s]


Epoch 67, Loss: 6.2620


100%|██████████| 25/25 [00:23<00:00,  1.06it/s]


Epoch 68, Loss: 6.5838


100%|██████████| 25/25 [00:23<00:00,  1.09it/s]


Epoch 69, Loss: 6.9229


100%|██████████| 25/25 [00:22<00:00,  1.13it/s]


Epoch 70, Loss: 6.3074


100%|██████████| 25/25 [00:23<00:00,  1.06it/s]


Epoch 71, Loss: 6.4410


100%|██████████| 25/25 [00:22<00:00,  1.10it/s]


Epoch 72, Loss: 6.3231


100%|██████████| 25/25 [00:22<00:00,  1.11it/s]


Epoch 73, Loss: 6.5585


100%|██████████| 25/25 [00:23<00:00,  1.08it/s]


Epoch 74, Loss: 6.4080


100%|██████████| 25/25 [00:28<00:00,  1.16s/it]


Epoch 75, Loss: 6.4482


100%|██████████| 25/25 [00:22<00:00,  1.12it/s]


Epoch 76, Loss: 6.4301


100%|██████████| 25/25 [00:24<00:00,  1.04it/s]


Epoch 77, Loss: 6.5608


100%|██████████| 25/25 [00:26<00:00,  1.07s/it]


Epoch 78, Loss: 6.5887


100%|██████████| 25/25 [00:25<00:00,  1.00s/it]


Epoch 79, Loss: 6.5279


100%|██████████| 25/25 [00:23<00:00,  1.05it/s]


Epoch 80, Loss: 6.4717


100%|██████████| 25/25 [00:22<00:00,  1.11it/s]


Epoch 81, Loss: 6.4864


100%|██████████| 25/25 [00:26<00:00,  1.05s/it]


Epoch 82, Loss: 6.2755


100%|██████████| 25/25 [00:23<00:00,  1.07it/s]


Epoch 83, Loss: 6.3391


100%|██████████| 25/25 [00:22<00:00,  1.13it/s]


Epoch 84, Loss: 6.7384


100%|██████████| 25/25 [00:23<00:00,  1.06it/s]


Epoch 85, Loss: 6.3151


100%|██████████| 25/25 [00:23<00:00,  1.05it/s]


Epoch 86, Loss: 6.4315


100%|██████████| 25/25 [00:22<00:00,  1.11it/s]


Epoch 87, Loss: 6.3800


100%|██████████| 25/25 [00:24<00:00,  1.02it/s]


Epoch 88, Loss: 6.7095


100%|██████████| 25/25 [00:23<00:00,  1.08it/s]


Epoch 89, Loss: 6.3075


100%|██████████| 25/25 [00:22<00:00,  1.11it/s]


Epoch 90, Loss: 6.2375


100%|██████████| 25/25 [00:23<00:00,  1.05it/s]


Epoch 91, Loss: 6.6491


100%|██████████| 25/25 [00:23<00:00,  1.06it/s]


Epoch 92, Loss: 6.3722


100%|██████████| 25/25 [00:22<00:00,  1.12it/s]


Epoch 93, Loss: 6.3840


100%|██████████| 25/25 [00:23<00:00,  1.06it/s]


Epoch 94, Loss: 6.2622


100%|██████████| 25/25 [00:23<00:00,  1.08it/s]


Epoch 95, Loss: 6.4482


100%|██████████| 25/25 [00:22<00:00,  1.12it/s]


Epoch 96, Loss: 6.7534


100%|██████████| 25/25 [00:23<00:00,  1.06it/s]


Epoch 97, Loss: 6.4741


100%|██████████| 25/25 [00:22<00:00,  1.10it/s]


Epoch 98, Loss: 6.3201


100%|██████████| 25/25 [00:22<00:00,  1.11it/s]

Epoch 99, Loss: 6.4824





In [12]:
!ls '/content/drunk_sober_data/drunk_sober_data/60mins/36_krod_4_e_M_54_90.jpg'

/content/drunk_sober_data/60mins/36_krod_4_e_M_54_90.jpg


In [12]:
import os

os.makedirs("/content/drive/MyDrive/pretrainedModelsCV", exist_ok=True)

torch.save(teacher.state_dict(), "/content/drive/MyDrive/pretrainedModelsCV/dino_finetuned_weights.pth")

# Alternatively, save the entire model (less flexible)
torch.save(teacher, "/content/drive/MyDrive/pretrainedModelsCV/dino_finetuned_model.pth")