<a href="https://colab.research.google.com/github/GersteinJo/CV-frameworks/blob/main/finetuning_dino.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
from torchvision import datasets, transforms
from PIL import Image
import matplotlib.pyplot as plt
from torchvision.io import read_image

In [3]:
class DataAugmentationDINO(object):
    """Multi-crop augmentation that turns one PIL image into a list of views.

    Calling the instance returns ``2 + local_crops_number`` normalized
    tensors: two large "global" crops followed by the small "local" crops.

    Args:
        global_crops_scale: ``(min, max)`` area fraction for the global crops.
        local_crops_scale: ``(min, max)`` area fraction for the local crops.
        local_crops_number: how many local crops to produce per image.
        global_crop_size: output side length of the global crops.
        local_crop_size: output side length of the local crops.
            The defaults (224 and 98) are both multiples of 14.
    """

    def __init__(self, global_crops_scale, local_crops_scale, local_crops_number,
                 global_crop_size=224, local_crop_size=98):
        # Enum form of bicubic resampling; passing the PIL integer constant
        # to torchvision is deprecated.
        bicubic = transforms.InterpolationMode.BICUBIC

        flip_and_color_jitter = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply(
                [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
                p=0.8
            ),
            transforms.RandomGrayscale(p=0.2),
        ])
        normalize = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])

        def random_blur(p):
            # The DINO recipe blurs with *probability* p and a sigma drawn from
            # [0.1, 2.0]. torchvision's GaussianBlur takes (kernel_size, sigma),
            # so the probability has to be expressed through RandomApply.
            return transforms.RandomApply(
                [transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))],
                p=p,
            )

        # first global crop: always blurred
        self.global_transfo1 = transforms.Compose([
            transforms.RandomResizedCrop(global_crop_size, scale=global_crops_scale, interpolation=bicubic),
            flip_and_color_jitter,
            random_blur(1.0),
            normalize,
        ])
        # second global crop: rarely blurred, sometimes solarized
        self.global_transfo2 = transforms.Compose([
            transforms.RandomResizedCrop(global_crop_size, scale=global_crops_scale, interpolation=bicubic),
            flip_and_color_jitter,
            random_blur(0.1),
            transforms.RandomSolarize(128, 0.2),
            normalize,
        ])
        # transformation for the local small crops: blurred half of the time
        self.local_crops_number = local_crops_number
        self.local_transfo = transforms.Compose([
            transforms.RandomResizedCrop(local_crop_size, scale=local_crops_scale, interpolation=bicubic),
            flip_and_color_jitter,
            random_blur(0.5),
            normalize,
        ])

    def __call__(self, image):
        """Return ``[global1, global2, local1, ..., localN]`` for ``image``."""
        crops = [self.global_transfo1(image), self.global_transfo2(image)]
        crops.extend(self.local_transfo(image) for _ in range(self.local_crops_number))
        return crops

In [4]:
from torch import nn
import numpy as np
import torch.nn.functional as F

In [5]:

class DINOLoss(nn.Module):
    """DINO self-distillation loss with teacher centering and sharpening.

    Args:
        out_dim: size of the last dimension of student/teacher outputs.
        ncrops: number of crops per image fed to the student (global + local).
        warmup_teacher_temp: teacher temperature at epoch 0.
        teacher_temp: teacher temperature after the warm-up.
        warmup_teacher_temp_epochs: number of epochs of linear warm-up.
        nepochs: total number of epochs covered by the schedule.
        student_temp: temperature applied to the student logits.
        center_momentum: EMA momentum of the teacher-output center.
    """

    def __init__(self, out_dim, ncrops, warmup_teacher_temp, teacher_temp,
                 warmup_teacher_temp_epochs, nepochs, student_temp=0.1,
                 center_momentum=0.9):
        super().__init__()
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        self.ncrops = ncrops
        # Registered as a buffer so it follows .to(device) and is saved in
        # the state dict without being a trainable parameter.
        self.register_buffer("center", torch.zeros(1, out_dim))
        # we apply a warm up for the teacher temperature because
        # a too high temperature makes the training unstable at the beginning
        self.teacher_temp_schedule = np.concatenate((
            np.linspace(warmup_teacher_temp,
                        teacher_temp, warmup_teacher_temp_epochs),
            np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp
        ))

    def forward(self, student_output, teacher_output, epoch):
        """
        Cross-entropy between softmax outputs of the teacher and student networks.

        student_output: [ncrops * b, out_dim], the per-crop outputs concatenated
            along dim 0 with the two global crops first.
        teacher_output: [2 * b, out_dim], outputs for the two global crops.
        epoch: index into the teacher temperature schedule; values past the
            end of the schedule reuse its last temperature.

        Returns the scalar loss averaged over all (teacher view, student view)
        pairs that are not the same view. Also updates ``self.center``.
        """
        # No gradient must ever flow into the teacher.
        teacher_output = teacher_output.detach()

        student_out = (student_output / self.student_temp).chunk(self.ncrops)

        # teacher centering and sharpening
        schedule_idx = min(epoch, len(self.teacher_temp_schedule) - 1)
        temp = self.teacher_temp_schedule[schedule_idx]
        teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1)
        teacher_out = teacher_out.chunk(2)

        total_loss = 0
        n_loss_terms = 0
        for iq, q in enumerate(teacher_out):
            for v, student_view in enumerate(student_out):
                if v == iq:
                    # we skip cases where student and teacher operate on the same view
                    continue
                loss = torch.sum(-q * F.log_softmax(student_view, dim=-1), dim=-1)
                total_loss += loss.mean()
                n_loss_terms += 1
        total_loss /= n_loss_terms
        self.update_center(teacher_output)
        return total_loss

    @torch.no_grad()
    def update_center(self, teacher_output):
        """
        Update center used for teacher output.

        The center is an exponential moving average of the batch mean of the
        raw teacher outputs.
        """
        batch_center = teacher_output.mean(dim=0, keepdim=True)

        # ema update
        self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)


In [7]:
# Move the Kaggle API token (kaggle.json in the working directory) into ~/.kaggle/.
!mkdir -p ~/.kaggle
!mv kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json  # Restrict permissions
# Download the cat/dog dataset archive and unpack it quietly into /content/.
!kaggle datasets download -d bhavikjikadara/dog-and-cat-classification-dataset
!unzip -q "dog-and-cat-classification-dataset.zip" -d /content/

Dataset URL: https://www.kaggle.com/datasets/bhavikjikadara/dog-and-cat-classification-dataset
License(s): apache-2.0
Downloading dog-and-cat-classification-dataset.zip to /content
 97% 752M/775M [00:02<00:00, 197MB/s]
100% 775M/775M [00:03<00:00, 266MB/s]


In [8]:
from google.colab import drive
# Mount Google Drive under /content/drive; later cells write checkpoints and CSVs there.
drive.mount('/content/drive')

import torch
from torchvision import datasets, transforms

# Select the compute device: CUDA when available, otherwise CPU.
# No model is loaded in this cell; the hub load below is left commented out.
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
# model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14').to(device)




Mounted at /content/drive


In [9]:
# Modify the SelfSupervisedDataset class:

from torch.utils.data import Dataset, DataLoader
import pandas as pd
import os
from PIL import Image

class SelfSupervisedDataset(Dataset):
    """Label-free image dataset read from ``root_dir/<class_dir>/*``.

    Args:
        root_dir: directory that contains one sub-folder per entry of
            ``class_dirs``.
        augmentation: optional callable applied to the RGB PIL image; with
            the DINO augmentation it returns a list of crop tensors.
        class_dirs: names of the sub-folders to scan. A missing folder
            raises ``FileNotFoundError``.

    Only ``.png``, ``.jpg`` and ``.jpeg`` files (case-insensitive) are kept.
    """

    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, augmentation=None, class_dirs=('Cat', 'Dog')):
        self.root_dir = root_dir
        self.augmentation = augmentation
        self.image_paths = []

        # Collect image paths folder by folder. os.listdir() gives no ordering
        # guarantee, so names are sorted to make sample indices reproducible.
        for class_dir in class_dirs:
            class_path = os.path.join(root_dir, class_dir)
            self.image_paths.extend(
                os.path.join(class_path, img_name)
                for img_name in sorted(os.listdir(class_path))
                if img_name.lower().endswith(self.IMAGE_EXTENSIONS)
            )

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return the augmented crops for sample ``idx`` (or the raw RGB image)."""
        img_path = self.image_paths[idx]
        # The context manager closes the underlying file right after decoding;
        # convert() returns an independent in-memory copy.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.augmentation:
            # List of tensors [global1, global2, local1, local2, ...]
            return self.augmentation(image)
        return image

# Multi-crop settings for the DINO augmentation.
local_crops_number = 8  # number of small local crops per image (adjust if needed)
local_crops_scale = (0.05, 0.4)  # area fraction range of a local crop
global_crops_scale = (0.4, 1.0)  # area fraction range of a global crop

augmentation = DataAugmentationDINO(
    global_crops_scale,
    local_crops_scale,
    local_crops_number,
)

# Self-supervised training data: every image under the PetImages folder,
# served in shuffled batches of 16 by two worker processes.
train_dataset = SelfSupervisedDataset('/content/PetImages', augmentation=augmentation)
train_loader = DataLoader(dataset=train_dataset, num_workers=2, shuffle=True, batch_size=16)

In [12]:
import torch
import torch.nn as nn

# Load the 'dinov2_vits14' DINOv2 backbone through torch.hub
# (the hub checkout is reused from the local cache when present).
model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')

# Projection head (for DINO loss)
class DINOHead(nn.Module):
    """Three-layer MLP projecting backbone features to the DINO output space.

    Layout: Linear(in_dim, 2048) -> GELU -> Linear(2048, 2048) -> GELU ->
    Linear(2048, out_dim). The layers live in ``self.mlp``.
    """

    def __init__(self, in_dim, out_dim=256):
        super().__init__()
        hidden_dim = 2048
        layers = [
            nn.Linear(in_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, out_dim),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Project ``x`` (last dimension ``in_dim``) to ``out_dim`` features."""
        projected = self.mlp(x)
        return projected

# Student & Teacher models
import copy

out_dim = 256  # Dimension for DINO loss

# Student: the loaded backbone with a fresh projection head.
# nn.Module.to() returns the module itself, so `student` and `model` are the
# same object.
student = model.to(device)
student.head = DINOHead(model.embed_dim, out_dim).to(device)

# Teacher: an independent deep copy of the student (backbone AND head), so
# both networks start from identical weights but are updated separately.
# Assigning `model.to(device)` to both names would alias a single network:
# the teacher head would replace the student head and freezing the teacher
# would freeze the student as well.
teacher = copy.deepcopy(student)

# Teacher does not get updated by gradients (only by the EMA of the student)
for param in teacher.parameters():
    param.requires_grad = False

Using cache found in /root/.cache/torch/hub/facebookresearch_dinov2_main


In [11]:
from tqdm import tqdm

# Hyperparameters (adjust as needed)
lr = 1e-4
epochs = 2
warmup_epochs = 1
momentum_teacher = 0.996  # EMA decay rate

# Optimise only parameters that actually receive gradients. An empty list
# means the student was frozen (e.g. because student and teacher are the same
# object), in which case training would silently do nothing.
trainable_params = [p for p in student.parameters() if p.requires_grad]
if not trainable_params:
    raise RuntimeError(
        "student has no trainable parameters; student and teacher must be "
        "separate modules and only the teacher may be frozen"
    )
optimizer = torch.optim.AdamW(trainable_params, lr=lr)

# DINO loss; the student sees every crop, the teacher only the two global ones
loss_fn = DINOLoss(
    out_dim=out_dim,
    ncrops=2 + local_crops_number,  # 2 global + N local crops
    warmup_teacher_temp=0.04,
    teacher_temp=0.07,
    warmup_teacher_temp_epochs=warmup_epochs,
    nepochs=epochs,
).to(device)

# Training loop
for epoch in range(epochs):
    student.train()
    teacher.train()  # kept in train mode as before; no gradients are taken through it

    running_loss = 0.0
    num_batches = 0

    for crops in tqdm(train_loader):
        # crops: List of batched tensors [global1, global2, local1, local2, ...]
        # Inputs do not need gradients; only the student weights do.
        crops = [crop.to(device) for crop in crops]

        # Teacher sees only the two global crops and never builds a graph.
        with torch.no_grad():
            teacher_output = torch.cat([teacher(crop) for crop in crops[:2]])
        # Student sees all crops; crops are forwarded one resolution at a time.
        student_output = torch.cat([student(crop) for crop in crops])

        # Compute DINO loss
        loss = loss_fn(student_output, teacher_output, epoch)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # EMA update for teacher: t <- m * t + (1 - m) * s
        with torch.no_grad():
            for s_param, t_param in zip(student.parameters(), teacher.parameters()):
                t_param.mul_(momentum_teacher).add_(s_param.detach(), alpha=1 - momentum_teacher)

        running_loss += loss.item()
        num_batches += 1

    # Report the mean loss over the epoch rather than the last batch only.
    mean_loss = running_loss / max(num_batches, 1)
    print(f"Epoch {epoch}, Loss: {mean_loss:.4f}")

  0%|          | 0/1563 [00:04<?, ?it/s]


KeyboardInterrupt: 

In [9]:
from google.colab import drive
# Mount Google Drive under /content/drive (re-running only reports "already mounted").
drive.mount('/content/drive')

import torch
from torchvision import datasets, transforms

# Select the compute device: CUDA when available, otherwise CPU.
# No model is loaded in this cell.
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [13]:
import os

os.makedirs("/content/drive/MyDrive/pretrainedModelsCV", exist_ok=True)

# Two checkpoints are written to Drive, in this order:
#   1. the teacher's state dict (weights only);
#   2. the pickled teacher module itself (less flexible to reload).
for checkpoint, destination in (
    (teacher.state_dict(), "/content/drive/MyDrive/pretrainedModelsCV/dino_finetuned_weights_cats_dogs.pth"),
    (teacher, "/content/drive/MyDrive/pretrainedModelsCV/dino_finetuned_model_cats_dogs.pth"),
):
    torch.save(checkpoint, destination)

In [12]:
!ls '/content/drunk_sober_data/drunk_sober_data/60mins/36_krod_4_e_M_54_90.jpg'

/content/drunk_sober_data/60mins/36_krod_4_e_M_54_90.jpg


In [13]:
from torch.hub import load

# 1. Load base model (stock DINOv2 backbone without the projection head)
teacher = load('facebookresearch/dinov2', 'dinov2_vits14')

# 2. Load the fine-tuned weights.
#    weights_only=True restricts unpickling to tensors and plain containers,
#    which is all a state dict contains.
state_dict = torch.load(
    "/content/drive/MyDrive/pretrainedModelsCV/dino_finetuned_weights_cats_dogs.pth",
    map_location=device,
    weights_only=True,
)
#    strict=False lets the checkpoint's extra projection-head keys be ignored,
#    but it would also hide *missing* backbone weights, so check those explicitly.
load_result = teacher.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    raise RuntimeError(
        f"Checkpoint does not provide these backbone weights: {load_result.missing_keys}"
    )

# 3. Set to eval mode
teacher.eval()

teacher.to(device)

Using cache found in /root/.cache/torch/hub/facebookresearch_dinov2_main


DinoVisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 384, kernel_size=(14, 14), stride=(14, 14))
    (norm): Identity()
  )
  (blocks): ModuleList(
    (0-11): 12 x NestedTensorBlock(
      (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (attn): MemEffAttention(
        (qkv): Linear(in_features=384, out_features=1152, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=384, out_features=384, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): LayerScale()
      (drop_path1): Identity()
      (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=384, out_features=1536, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=1536, out_features=384, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
      (ls2): LayerScale()
      (drop_path2): Identity()
    )
  )
  (n

In [14]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import os
from PIL import Image
import torch

# Deterministic preprocessing for non-augmented images:
# resize to 256, center-crop to 224, convert to a tensor and normalize
# each channel with the fixed mean/std below.
_resize_and_crop = [transforms.Resize(256), transforms.CenterCrop(224)]
_to_normalized_tensor = [
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
]
default_transform = transforms.Compose(_resize_and_crop + _to_normalized_tensor)

class CatDogDataset(Dataset):
    """Labelled cat/dog dataset read from ``root_dir/Cat`` and ``root_dir/Dog``.

    Args:
        root_dir: directory containing the ``Cat`` and ``Dog`` sub-folders.
        augmentation: optional callable returning a list of crops for a PIL
            image. When omitted, ``default_transform`` is applied instead.

    Items are ``(image_tensor, label)`` without augmentation and
    ``([crop1, crop2], label)`` with it; labels are 0 for Cat and 1 for Dog.
    Images that cannot be decoded are replaced by an all-zero image.
    """

    def __init__(self, root_dir, augmentation=None):
        self.root_dir = root_dir
        self.augmentation = augmentation
        self.transform = default_transform if augmentation is None else None
        self.class_map = {'Cat': 0, 'Dog': 1}
        self.image_paths = []
        self.labels = []

        # Load all image paths and labels. Names are sorted because
        # os.listdir() gives no ordering guarantee and separate dataset
        # instances over the same folder are expected to line up index by index.
        for class_name, class_idx in self.class_map.items():
            class_dir = os.path.join(root_dir, class_name)
            for img_name in sorted(os.listdir(class_dir)):
                if img_name.lower().endswith(('.png', '.jpg', '.jpeg')):
                    self.image_paths.append(os.path.join(class_dir, img_name))
                    self.labels.append(class_idx)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return the (transformed image or global crops, label) pair for ``idx``."""
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        try:
            # Close the file handle as soon as the pixels are decoded.
            with Image.open(img_path) as img:
                image = img.convert('RGB')
        except Exception as e:
            # Deliberate best effort: an undecodable file is reported and
            # replaced by a dummy image instead of aborting the whole run.
            print(f"Error loading image {img_path}: {e}")
            dummy_img = torch.zeros(3, 224, 224)
            if self.augmentation:
                # Keep the same ([crop1, crop2], label) structure as real samples.
                return [dummy_img, dummy_img.clone()], label
            # The sample keeps its real label so the class balance is unchanged.
            return dummy_img, label

        if self.augmentation:
            crops = self.augmentation(image)
            return crops[:2], label  # Return just global crops + label
        return self.transform(image), label  # Apply default transform

def simple_collate_fn(batch):
    """Collate ``(image, label)`` samples into ``(image_batch, label_batch)``.

    A sample's first element is either a tensor (non-augmented) or a list of
    crop tensors (augmented); for a list only its first entry is used.
    """

    def _select_image(sample):
        payload = sample[0]
        # Augmented samples carry a list of crops: keep the first global crop.
        return payload[0] if isinstance(payload, list) else payload

    images = [_select_image(sample) for sample in batch]
    labels = [sample[1] for sample in batch]
    return torch.stack(images), torch.tensor(labels)

# Two datasets over the same folder, both without augmentation
# (pass `augmentation=augmentation` to the first one to use the DINO crops).
train_dataset = CatDogDataset(root_dir='/content/PetImages')
train_dataset_u = CatDogDataset(root_dir='/content/PetImages')

# Both loaders share the same settings: fixed order, batches of 16,
# loading in the main process (num_workers=0 avoids shared memory issues)
# and the simplified collate function.
_loader_options = dict(
    batch_size=16,
    shuffle=False,
    num_workers=0,
    collate_fn=simple_collate_fn,
    pin_memory=True,
)

train_loader = DataLoader(train_dataset, **_loader_options)
train_loader_u = DataLoader(train_dataset_u, **_loader_options)

In [15]:
import metric
from metric import correlation_dissimilarity, train_linear_classifier, encode_set, embedding_plotter, get_tsne



from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

from sklearn import preprocessing


from tqdm import tqdm
def train_linear_classifier1(X, y, test_size=0.2, random_state=42, **kwargs):
    """
    Fit a logistic-regression probe on standardized features and score it.

    Parameters:
    X: Feature matrix (must expose ``.shape``)
    y: Target vector (must expose ``.shape``)
    test_size (float): Fraction of the samples held out for scoring (default: 0.2)
    random_state (int): Seed of the train/test split (default: 42)
    **kwargs: Forwarded unchanged to LogisticRegression

    Returns:
    tuple: (fitted LogisticRegression, accuracy on the held-out split)
    """
    print(X.shape, y.shape)

    # Hold out part of the data for scoring
    split = train_test_split(X, y, test_size=test_size, random_state=random_state)
    features_fit, features_eval, targets_fit, targets_eval = split

    # Standardize with statistics estimated on the fitting split only
    standardizer = preprocessing.StandardScaler().fit(features_fit)
    features_fit = standardizer.transform(features_fit)
    features_eval = standardizer.transform(features_eval)

    # Fit the linear classifier and score it on the held-out split
    model = LogisticRegression(**kwargs).fit(features_fit, targets_fit)
    accuracy = accuracy_score(targets_eval, model.predict(features_eval))

    return model, accuracy

def encode_set1(encoder_function: callable, loader, original_loader, device="cpu"):
    """Encode every batch of ``loader`` and gather the images of ``original_loader``.

    Both loaders must yield ``(images, targets)`` pairs of tensors. Returns
    three NumPy arrays: the concatenated encoder outputs, the concatenated
    targets of ``loader``, and the images of ``original_loader`` flattened to
    one row per image.
    """
    embedding_batches, label_batches, image_batches = [], [], []

    with torch.no_grad():
        for batch, targets in tqdm(loader):
            on_device = batch.to(device, non_blocking=True)
            # Move results back to the CPU right away so they can be turned into NumPy arrays.
            embedding_batches.append(encoder_function(on_device).cpu())
            label_batches.append(targets)
        image_batches.extend(images for images, _ in tqdm(original_loader))

    embeddings = torch.cat(embedding_batches)  # [N, D]
    labels = torch.cat(label_batches)
    stacked_images = torch.cat(image_batches)
    flat_images = stacked_images.reshape(stacked_images.shape[0], -1)

    return embeddings.numpy(), labels.numpy(), flat_images.numpy()

def run(encoder_function: callable, loader: torch.utils.data.DataLoader, original_loader: torch.utils.data.DataLoader,
        logger = None, device = "cpu", embeddings_np = None, labels_np = None, original_images_np = None, **kwargs):
    """Compute the evaluation metrics for an encoder and optionally log them.

    Embeddings, labels and original images are computed with ``encode_set``
    unless all three are passed in. ``**kwargs`` go to
    ``train_linear_classifier``. When ``logger`` is given, each metric is
    sent through its own ``logger.log`` call.

    Returns ``(embeddings_np, labels_np, original_images_np, data_df)`` where
    ``data_df`` is the result of ``get_tsne``.
    """
    precomputed = (embeddings_np, labels_np, original_images_np)
    if any(array is None for array in precomputed):
        embeddings_np, labels_np, original_images_np = encode_set(encoder_function, loader, original_loader, device)
    print(embeddings_np.shape, labels_np.shape)

    # Second element of the classifier result is used as the accuracy.
    cl_acc = train_linear_classifier(embeddings_np, labels_np, **kwargs)[1]
    # Dissimilarity is computed on the first 10 samples only.
    cor_diss = correlation_dissimilarity(embeddings_np[:10], original_images_np[:10])

    print(embeddings_np.shape)
    if logger is not None:
        reported = (
            ("classification_accuracy", cl_acc),
            ("second_order_similarity", cor_diss),
        )
        for key, value in reported:
            logger.log({key: value})
    print(f"classification_accuracy : {cl_acc}, \nsecond_order_similarity : {cor_diss}")

    data_df = get_tsne(embeddings_np, labels_np)
    return embeddings_np, labels_np, original_images_np, data_df



In [None]:
import os

# Encode the whole dataset with the teacher and collect the flattened input images.
embeddings, labels, original_images = encode_set1(teacher, train_loader, train_loader_u, device)

file_path = "/content/drive/MyDrive/Data/dino_counter_cat_dog.txt"

# The counter file is created with "0" on first use. The incremented value is
# only used for the output file name below; it is not written back to the file.
if not os.path.exists(file_path):
    with open(file_path, 'w') as fh:
        fh.write("0")
with open(file_path, 'r') as fh:
    counter = int(fh.read()) + 1

# One row per image: embedding columns first, then the flattened pixels, then the label.
emb_columns = [f"emb_{i}" for i in range(embeddings.shape[1])]
pixel_columns = [f"or_{i}" for i in range(original_images.shape[1])]
df = pd.DataFrame(np.hstack([embeddings, original_images]), columns=emb_columns + pixel_columns)
df['label'] = labels

df.to_csv(f'/content/drive/MyDrive/Data/dino_embedding{counter}.csv')

 13%|█▎        | 199/1563 [16:14<1:58:33,  5.22s/it]

In [26]:
import os

import numpy as np
import pandas as pd
import wandb

run_name = 'finetuned_dino_metric_cat_dog'
config = {
    "encoder" : "dino_finetuned_cat_dog",
    "type_log" : "metric",
}

file_path = "/content/drive/MyDrive/Data/dino_counter_cat_dog.txt"

# The counter file is created with "0" on first use; the incremented value
# only selects the file names used below.
if not os.path.exists(file_path):
    with open(file_path, 'w') as file:
        file.write("0")
with open(file_path, 'r') as file:
    counter = int(file.read()) + 1

# Reuse previously exported embeddings when the CSV for this counter exists,
# instead of re-encoding the whole dataset. The cache is looked up under the
# same name this notebook writes (`dino_embedding{counter}.csv`).
embedding_csv = f'/content/drive/MyDrive/Data/dino_embedding{counter}.csv'
loaded_from_cache = os.path.exists(embedding_csv)
if loaded_from_cache:
    df = pd.read_csv(embedding_csv)
    # Match on the exact column prefixes written at export time; a plain
    # substring test would also pick up unrelated columns containing "or"/"emb".
    emb_columns = [c for c in df.columns if c.startswith('emb_')]
    pixel_columns = [c for c in df.columns if c.startswith('or_')]
    embeddings = df[emb_columns].to_numpy()
    labels = df['label'].to_numpy()
    original_images = df[pixel_columns].to_numpy()
else:
    embeddings, labels, original_images = None, None, None

logger = wandb.init(project = 'CV_frameworks', config = config, name = run_name)

# for batch, _ in all_loader:
#     print(batch.size(), teacher(batch).size(), _.size())
try:
    embeddings, labels, original_images, tsne_data = run(teacher, train_loader, train_loader_u, logger = logger, device= device,
                                                         embeddings_np = embeddings, labels_np = labels, original_images_np = original_images,
                                                         max_iter = 1000)
finally:
    # Close the wandb run even when the evaluation fails or is interrupted.
    logger.finish()

# Export freshly computed embeddings; a cache hit means the file already holds them.
if not loaded_from_cache:
    df = pd.DataFrame(
        np.hstack([embeddings, original_images]),
        columns = [f"emb_{i}" for i in range(embeddings.shape[1])]+[f"or_{i}" for i in range(original_images.shape[1])],
    )
    df['label'] = labels
    df.to_csv(embedding_csv)

tsne_data.to_csv(f'/content/drive/MyDrive/Data/dino_tsne_embedding{counter}.csv')

KeyboardInterrupt: 