In [None]:
import argparse
import copy
import glob
import math
import os
import tarfile
import time
import urllib.request
from pathlib import Path

import torch
import torchvision
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torchvision.transforms import v2
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import Dataset, DataLoader, Subset, ConcatDataset
import torch.distributed as dist

from PIL import Image
import pandas as pd
import numpy as np
from tqdm import tqdm
from sklearn.neighbors import KNeighborsClassifier
from transformers import AutoImageProcessor, Dinov2Model
from huggingface_hub import snapshot_download
from timm.models.vision_transformer import vit_base_patch32_224


from lightly.loss import DINOLoss
from lightly.models.modules import DINOProjectionHead
from lightly.models.utils import deactivate_requires_grad, update_momentum
from lightly.transforms.dino_transform import DINOTransform
from lightly.utils.scheduler import cosine_schedule

In [10]:
# ---- Global configuration ----
# Dataset roots. Each one holds train/ val/ test/ image folders together with
# train_labels.csv, val_labels.csv and test_labels_INTERNAL.csv.
dataset_1_dir = './data/testset_1'
dataset_2_dir = './data/testset_2'
dataset_3_dir = './data/testset_3'

# Output root. pretrain_weights is a path relative to output_dir.
output_dir = './outputs'
pretrain_weights = 'dino-v1/dino-v1_small_100.pt'

# Backbone architecture, built without pretrained weights. The DINO checkpoint
# supplies the weights later. torchvision.models.resnet34() is a drop-in alternative.
resnet = torchvision.models.resnet18()

# Batch size per process for the SSL dataloader.
per_gpu_batch_size = 128

In [3]:
class ImageDataset(Dataset):
    """Dataset over an explicit list of image files in one directory.

    Items are ``(image, label, filename)`` when ``labels`` is given and
    ``(image, filename)`` otherwise. With ``apply_transforms=False`` the image is
    returned as a PIL RGB image, for the SSL pipeline, which applies its own
    transform. Otherwise it is a normalized float tensor.

    Args:
        image_dir: directory containing the files named in ``image_list``.
        image_list: filenames relative to ``image_dir``.
        labels: optional labels, aligned index-by-index with ``image_list``.
        resolution: output crop size (int, or an (h, w) pair).
        split: ``"train"`` selects random augmentation. Any other value selects
            a deterministic resize followed by a center crop.
        apply_transforms: when False, no transform is built or applied.

    Raises:
        ValueError: if ``labels`` is given and its length differs from
            ``image_list``.
    """

    def __init__(self, image_dir, image_list, labels=None,
                 resolution=224, split="train", apply_transforms=True):
        # Misaligned labels would otherwise pair images with the wrong class.
        if labels is not None and len(labels) != len(image_list):
            raise ValueError(
                f"labels has {len(labels)} entries but image_list has "
                f"{len(image_list)}; they must be aligned one-to-one"
            )
        self.image_dir = Path(image_dir)
        self.image_list = image_list
        self.labels = labels
        self.split = split
        self.resolution = resolution
        self.apply_transforms = apply_transforms

        # None means "return the raw PIL image" (the SSL path relies on this).
        self.transform = (
            self._build_transform(resolution, split) if apply_transforms else None
        )

    @staticmethod
    def _eval_resize_size(resolution):
        """Return the shorter-side resize for evaluation: 256/224 of the crop size.

        Gives 256 for the default 224, matching the previous fixed value. It grows
        with larger crops, so CenterCrop never has to pad past the resized image.
        A (h, w) resolution is scaled from its larger side.
        """
        crop = resolution if isinstance(resolution, int) else max(resolution)
        return round(crop * 256 / 224)

    @staticmethod
    def _build_transform(resolution, split):
        """Build random augmentation for ``split == "train"``, else the eval transform."""
        imagenet_mean = [0.485, 0.456, 0.406]
        imagenet_std = [0.229, 0.224, 0.225]
        # Shared tail: PIL -> float tensor in [0, 1] -> ImageNet normalization.
        to_normalized_tensor = [
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=imagenet_mean, std=imagenet_std),
        ]
        if split == "train":
            return v2.Compose([
                v2.RandomResizedCrop(resolution, scale=(0.8, 1.0)),
                v2.RandomHorizontalFlip(p=0.5),
                v2.ColorJitter(
                    brightness=0.4,
                    contrast=0.4,
                    saturation=0.4,
                    hue=0.1
                ),
                *to_normalized_tensor,
            ])
        return v2.Compose([
            v2.Resize(ImageDataset._eval_resize_size(resolution)),
            v2.CenterCrop(resolution),
            *to_normalized_tensor,
        ])

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        img_name = self.image_list[idx]
        img_path = self.image_dir / img_name

        # The context manager closes the file handle. convert() returns a separate,
        # fully loaded image that stays valid after the file is closed.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')

        if self.transform is not None:
            img = self.transform(img)   # -> Tensor (supervised path)
        # else: keep img as PIL (SSL path)

        if self.labels is not None:
            return img, self.labels[idx], img_name
        return img, img_name



In [4]:
def collate_fn(batch):
    """Group a list of dataset items into per-field lists, leaving PIL images untouched.

    Labelled items ``(image, label, filename)`` produce ``(images, labels, filenames)``.
    Unlabelled items ``(image, filename)`` produce ``(images, filenames)``.
    The layout is decided by the first item.
    """
    field_count = 3 if len(batch[0]) == 3 else 2
    return tuple([sample[k] for sample in batch] for k in range(field_count))

In [5]:
# ============================================================
# Dataset 1: loader hyperparameters
# ============================================================
batch_size = 64
num_workers = 4
resolution = 224   # crop size; only used when ImageDataset transforms are enabled
# ============================================================

# Load CSV files (columns used below: filename, class_id)
data_dir = Path(dataset_1_dir)

print("\nLoading dataset metadata...")
train_df = pd.read_csv(data_dir / 'train_labels.csv')
val_df = pd.read_csv(data_dir / 'val_labels.csv')
test_df = pd.read_csv(data_dir / 'test_labels_INTERNAL.csv')

print(f"  Train: {len(train_df)} images")
print(f"  Val:   {len(val_df)} images")
print(f"  Test:  {len(test_df)} images")
print(f"  Classes: {train_df['class_id'].nunique()}")

# apply_transforms=False keeps images as PIL, so the SSL wrapper can apply its own
# multi-crop transform. `split` is passed explicitly so that, if transforms are ever
# enabled, only the train split gets random augmentation.
train_dataset1 = ImageDataset(
    data_dir / 'train',
    train_df['filename'].tolist(),
    train_df['class_id'].tolist(),
    resolution=resolution,
    split='train',
    apply_transforms=False,
)

val_dataset1 = ImageDataset(
    data_dir / 'val',
    val_df['filename'].tolist(),
    val_df['class_id'].tolist(),
    resolution=resolution,
    split='val',
    apply_transforms=False,
)

test_dataset1 = ImageDataset(
    data_dir / 'test',
    test_df['filename'].tolist(),
    labels=test_df['class_id'].tolist(),
    resolution=resolution,
    split='test',
    apply_transforms=False,
)

# Same settings for every split: deterministic order, and collate_fn keeps PIL
# images as plain lists.
loader_kwargs = dict(
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    collate_fn=collate_fn,
)
train_loader1 = DataLoader(train_dataset1, **loader_kwargs)
val_loader1 = DataLoader(val_dataset1, **loader_kwargs)
test_loader1 = DataLoader(test_dataset1, **loader_kwargs)



Loading dataset metadata...
  Train: 8232 images
  Val:   1727 images
  Test:  1829 images
  Classes: 200


In [6]:
# ============================================================
# Dataset 2: loader hyperparameters
# ============================================================
batch_size = 64
num_workers = 4
resolution = 224   # crop size; only used when ImageDataset transforms are enabled
# ============================================================

# Load CSV files (columns used below: filename, class_id)
data_dir = Path(dataset_2_dir)

print("\nLoading dataset metadata...")
train_df = pd.read_csv(data_dir / 'train_labels.csv')
val_df = pd.read_csv(data_dir / 'val_labels.csv')
test_df = pd.read_csv(data_dir / 'test_labels_INTERNAL.csv')

print(f"  Train: {len(train_df)} images")
print(f"  Val:   {len(val_df)} images")
print(f"  Test:  {len(test_df)} images")
print(f"  Classes: {train_df['class_id'].nunique()}")

# apply_transforms=False keeps images as PIL, so the SSL wrapper can apply its own
# multi-crop transform. `split` is passed explicitly so that, if transforms are ever
# enabled, only the train split gets random augmentation.
train_dataset2 = ImageDataset(
    data_dir / 'train',
    train_df['filename'].tolist(),
    train_df['class_id'].tolist(),
    resolution=resolution,
    split='train',
    apply_transforms=False,
)

val_dataset2 = ImageDataset(
    data_dir / 'val',
    val_df['filename'].tolist(),
    val_df['class_id'].tolist(),
    resolution=resolution,
    split='val',
    apply_transforms=False,
)

test_dataset2 = ImageDataset(
    data_dir / 'test',
    test_df['filename'].tolist(),
    labels=test_df['class_id'].tolist(),
    resolution=resolution,
    split='test',
    apply_transforms=False,
)

# Same settings for every split: deterministic order, and collate_fn keeps PIL
# images as plain lists.
loader_kwargs = dict(
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    collate_fn=collate_fn,
)
train_loader2 = DataLoader(train_dataset2, **loader_kwargs)
val_loader2 = DataLoader(val_dataset2, **loader_kwargs)
test_loader2 = DataLoader(test_dataset2, **loader_kwargs)



Loading dataset metadata...
  Train: 26880 images
  Val:   5760 images
  Test:  5760 images
  Classes: 64


In [None]:
# ============================================================
# Dataset 3: loader hyperparameters
# ============================================================
batch_size = 64
num_workers = 4
resolution = 224   # crop size; only used when ImageDataset transforms are enabled
# ============================================================

# Load CSV files (columns used below: filename, class_id)
data_dir = Path(dataset_3_dir)

print("\nLoading dataset metadata...")
train_df = pd.read_csv(data_dir / 'train_labels.csv')
val_df = pd.read_csv(data_dir / 'val_labels.csv')
test_df = pd.read_csv(data_dir / 'test_labels_INTERNAL.csv')

print(f"  Train: {len(train_df)} images")
print(f"  Val:   {len(val_df)} images")
print(f"  Test:  {len(test_df)} images")
print(f"  Classes: {train_df['class_id'].nunique()}")

# apply_transforms=False keeps images as PIL, so the SSL wrapper can apply its own
# multi-crop transform. `split` is passed explicitly so that, if transforms are ever
# enabled, only the train split gets random augmentation.
train_dataset3 = ImageDataset(
    data_dir / 'train',
    train_df['filename'].tolist(),
    train_df['class_id'].tolist(),
    resolution=resolution,
    split='train',
    apply_transforms=False,
)

val_dataset3 = ImageDataset(
    data_dir / 'val',
    val_df['filename'].tolist(),
    val_df['class_id'].tolist(),
    resolution=resolution,
    split='val',
    apply_transforms=False,
)

test_dataset3 = ImageDataset(
    data_dir / 'test',
    test_df['filename'].tolist(),
    labels=test_df['class_id'].tolist(),
    resolution=resolution,
    split='test',
    apply_transforms=False,
)

# Same settings for every split: deterministic order, and collate_fn keeps PIL
# images as plain lists.
loader_kwargs = dict(
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    collate_fn=collate_fn,
)
train_loader3 = DataLoader(train_dataset3, **loader_kwargs)
val_loader3 = DataLoader(val_dataset3, **loader_kwargs)
test_loader3 = DataLoader(test_dataset3, **loader_kwargs)



Loading dataset metadata...
  Train: 13895 images
  Val:   2977 images
  Test:  2978 images
  Classes: 397


In [8]:
def init_distributed():
    """Initialise torch.distributed when launched with RANK/WORLD_SIZE set, else use one device.

    Reads RANK, WORLD_SIZE and LOCAL_RANK (default 0) from the environment.
    Safe to call more than once in the same kernel: an already-initialised
    default process group is reused rather than re-created.

    Returns:
        (is_distributed, rank, world_size, device)
    """
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        use_cuda = torch.cuda.is_available()
        # init_process_group raises if the default group already exists, which
        # happens when this notebook cell is re-run.
        if not (dist.is_available() and dist.is_initialized()):
            # NCCL requires CUDA. Fall back to gloo for CPU-only runs.
            dist.init_process_group(backend="nccl" if use_cuda else "gloo")
        if use_cuda:
            torch.cuda.set_device(local_rank)
            device = torch.device(f"cuda:{local_rank}")
        else:
            device = torch.device("cpu")
        return True, rank, world_size, device

    # Fallback: single GPU / CPU
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    return False, 0, 1, device

is_distributed, rank, world_size, device = init_distributed()
print(f"is_distributed={is_distributed}, rank={rank}, world_size={world_size}, device={device}")

is_distributed=False, rank=0, world_size=1, device=cuda:0


In [None]:
class DINO(nn.Module):
    """Student/teacher pair for DINO self-supervised training.

    The student is ``backbone`` plus a trainable projection head. The teacher is a
    deep copy of ``backbone`` plus its own projection head, and both teacher parts
    have gradients deactivated. The training loop updates them from the student
    with ``update_momentum``.

    Args:
        backbone: feature extractor. Its output is flattened from dim 1 before
            the head.
        input_dim: size of the flattened backbone features.
        hidden_dim, bottleneck_dim, output_dim: projection-head sizes. The defaults
            reproduce the original 512 / 64 / 2048. ``output_dim`` must equal the
            ``output_dim`` given to ``DINOLoss``.
        freeze_last_layer: forwarded to the student head only. The training loop
            calls ``student_head.cancel_last_layer_gradients(current_epoch=...)``,
            which presumably uses it to freeze the last layer for the first epochs
            (see lightly's DINOProjectionHead docs).
    """

    def __init__(self, backbone, input_dim, hidden_dim=512, bottleneck_dim=64,
                 output_dim=2048, freeze_last_layer=1):
        super().__init__()
        self.student_backbone = backbone
        self.student_head = DINOProjectionHead(
            input_dim, hidden_dim, bottleneck_dim, output_dim,
            freeze_last_layer=freeze_last_layer,
        )
        # The teacher backbone starts as an exact copy. The teacher head is
        # constructed separately with the same sizes.
        self.teacher_backbone = copy.deepcopy(backbone)
        self.teacher_head = DINOProjectionHead(
            input_dim, hidden_dim, bottleneck_dim, output_dim
        )
        deactivate_requires_grad(self.teacher_backbone)
        deactivate_requires_grad(self.teacher_head)

    @staticmethod
    def _embed(backbone, head, x):
        """Run ``backbone``, flatten all dims after the batch dim, then project with ``head``."""
        return head(backbone(x).flatten(start_dim=1))

    def forward(self, x):
        """Return the student projection for a batch of views."""
        return self._embed(self.student_backbone, self.student_head, x)

    def forward_teacher(self, x):
        """Return the teacher projection. The training loop calls this under torch.no_grad()."""
        return self._embed(self.teacher_backbone, self.teacher_head, x)


In [None]:
# --- Build the same backbone as used for DINO pretraining ---
# Drop the final fc layer. For resnet18 the remaining stack outputs (B, 512, 1, 1).
backbone = nn.Sequential(*list(resnet.children())[:-1])
# Feature width feeding the removed fc layer: 512 for resnet18/34. Deriving it keeps
# this cell correct if the `resnet` global is swapped for a wider model.
input_dim = resnet.fc.in_features

dino_model = DINO(backbone, input_dim)

# --- Load the pre-trained DINO checkpoint ---
# output_dir and pretrain_weights are separate path components. Plain string `+`
# drops the separator and yields './outputsdino-v1/...'.
pretrain_ckpt_path = os.path.join(output_dir, pretrain_weights)
ckpt = torch.load(
    pretrain_ckpt_path,
    map_location="cpu",
)

# strict=True: every key must match, so the architecture built above must be the
# one that produced the checkpoint.
dino_model.load_state_dict(ckpt["model_state"], strict=True)



In [11]:
# 2. SSL transform
# lightly's DINO multi-crop augmentation, constructed with its default settings.
# It is applied per image by CUB_SSL_Dataset below and returns a list of view
# tensors. The training loop treats the first two views as the global (teacher) crops.
ssl_transform = DINOTransform()

# 3. SSL dataset wrapper
class CUB_SSL_Dataset(torch.utils.data.Dataset):
    """Label-free view of an image dataset for self-supervised training.

    Wraps ``subset``, any indexable dataset whose items are tuples with the image
    first, such as ImageDataset's ``(img, label, name)`` or ``(img, name)``. Each
    item is ``transform(img)``, which for DINOTransform is the list of augmented
    views. Despite the name, nothing here is specific to CUB.
    """

    def __init__(self, subset, transform):
        self.subset = subset
        self.transform = transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        # Take the image by position so that labelled (img, label, name) and
        # unlabelled (img, name) items both work. Labels and filenames are discarded.
        img = self.subset[idx][0]
        return self.transform(img)  # returns list[Tensor] from DINOTransform

In [None]:
# 4. Combine the train + val splits of all three datasets (labels unused).
#    Test splits are not included.
ssl_trainval_raw = ConcatDataset([
    train_dataset1, val_dataset1,
    train_dataset2, val_dataset2,
    train_dataset3, val_dataset3,
])


# 5. Build SSL dataset: each item becomes the list of views produced by ssl_transform.
ssl_trainval_ds = CUB_SSL_Dataset(
    subset=ssl_trainval_raw,
    transform=ssl_transform,
)

# Under torch.distributed each rank must read a disjoint shard. DistributedSampler
# provides that and also shuffles, so the DataLoader itself must not shuffle.
if is_distributed:
    train_sampler = DistributedSampler(
        ssl_trainval_ds,
        num_replicas=world_size,
        rank=rank,
        shuffle=True,
        drop_last=True,
    )
    shuffle = False  # sampler handles shuffling
else:
    train_sampler = None
    shuffle = True

# 6. Build SSL dataloader. The sampler must be passed here; otherwise every rank
#    iterates over the full dataset, and without shuffling in distributed mode.
#    No collate_fn: the default collation stacks each view position into one tensor.
ssl_loader = DataLoader(
    ssl_trainval_ds,
    batch_size=per_gpu_batch_size,
    shuffle=shuffle,
    sampler=train_sampler,
    drop_last=True,
    num_workers=8,
    pin_memory=True,
    persistent_workers=True,
)

print("SSL training images:", len(ssl_trainval_ds))
print(f"Dataset found. Total samples: {len(ssl_loader.dataset)}")
print(f"Batch size: {ssl_loader.batch_size}")
print(f"Number of workers: {ssl_loader.num_workers}")

In [None]:
# ---- Inspect one batch from the SSL dataloader ----
# Expected layout (an assumption based on DINOTransform + default collation; confirm
# against the printout): the batch is a list with one tensor per view, each of shape
# (B, C, H, W). The training loop treats views 0 and 1 as the global crops, and the
# remaining views are smaller local crops.
batch = next(iter(ssl_loader))

print("\n=== SSL Batch Debug ===")

if isinstance(batch, (list, tuple)):
    print(f"Batch is a {type(batch)} with length {len(batch)}")

for i, view in enumerate(batch):
    print(f"\n--- View {i} ---")
    print(f"Type: {type(view)}")
    # Check for .shape explicitly instead of swallowing every exception.
    if hasattr(view, "shape"):
        print(f"Shape: {view.shape}")
    else:
        print("View has no `.shape` attribute")
    print(f"Dtype: {getattr(view, 'dtype', None)}")

print("\n=== End Debug ===\n")


In [None]:
# Compact structural check of one SSL batch: the container type, then the
# type (and shape, when present) of each element.
batch = next(iter(ssl_loader))
print("Type of batch:", type(batch))

if not isinstance(batch, (list, tuple)):
    print("Batch shape:", batch.shape)
else:
    print("Batch length:", len(batch))
    for i, item in enumerate(batch):
        print(f"  item[{i}] type: {type(item)}")
        if hasattr(item, "shape"):
            print(f"  item[{i}] shape:", item.shape)


In [None]:
# ---------- Project / save dir ----------
PROJECT_NAME = "dino-v1"  # sub-folder of output_dir and checkpoint filename prefix
# Tag used in the checkpoint filename; keep it in sync with the `resnet` global.
# Formatting the nn.Sequential `backbone` itself would paste its multi-line repr
# into the filename.
BACKBONE_NAME = "resnet18"

# output_dir is a plain string, so wrap it in Path before joining with `/`.
save_dir = Path(output_dir) / PROJECT_NAME
save_dir.mkdir(parents=True, exist_ok=True)
print(f"Save directory ready: {save_dir}")

# Reuse the device chosen by init_distributed() (cuda:<LOCAL_RANK> under
# torch.distributed) instead of re-deriving a generic one here.
dino_model.to(device)
# NOTE(review): dino_model is not wrapped in DistributedDataParallel, so in a
# multi-process run each rank would optimise its own unsynchronised copy.
# TODO before multi-GPU training.

# Only rank 0 shows progress, logs epochs and writes checkpoints.
is_main_process = rank == 0

ckpt_path = save_dir / f"{PROJECT_NAME}_full_finetuned_{BACKBONE_NAME}.pt"
start_epoch = 0
num_epochs = 100
save_every = 1  # checkpoint frequency, in epochs

optimizer = torch.optim.Adam(dino_model.parameters(), lr=0.001)

# output_dim must equal the projection heads' output size (2048 in DINO).
criterion = DINOLoss(
    output_dim=2048,
    warmup_teacher_temp_epochs=5,
)
# DINOLoss is an nn.Module with its own internal tensors; keep them on the model's device.
criterion = criterion.to(device)

# ---------- Training loop ----------
print("Starting SSL Training (DINO)")

global_step = 0

for epoch in range(start_epoch, num_epochs):
    dino_model.train()
    if train_sampler is not None:
        # DistributedSampler only reshuffles when told the current epoch.
        train_sampler.set_epoch(epoch)

    total_loss = 0.0
    n_samples = 0
    epoch_start = time.time()

    # Teacher EMA momentum for this epoch, cosine-scheduled from 0.996 to 1.0.
    momentum_val = cosine_schedule(epoch, num_epochs, 0.996, 1.0)

    progress = tqdm(
        ssl_loader,
        desc=f"Epoch {epoch + 1}/{num_epochs}",
        disable=not is_main_process,
    )
    for batch in progress:
        # ssl_loader should yield a list of crops (views). Also accept (views, extra).
        if isinstance(batch, (list, tuple)) and isinstance(batch[0], torch.Tensor):
            views = batch
        elif isinstance(batch, (list, tuple)):
            views = batch[0]
        else:
            views = batch

        # Move all crops to the training device.
        views = [v.to(device, non_blocking=True) for v in views]

        # ---- EMA update of the teacher from the student ----
        update_momentum(dino_model.student_backbone, dino_model.teacher_backbone, m=momentum_val)
        update_momentum(dino_model.student_head, dino_model.teacher_head, m=momentum_val)

        # The first two views are the global crops; only those go through the teacher.
        global_views = views[:2]
        with torch.no_grad():
            teacher_out = [dino_model.forward_teacher(v) for v in global_views]

        # The student sees every crop (global + local).
        student_out = [dino_model(v) for v in views]

        # ---- DINO loss ----
        loss = criterion(teacher_out, student_out, epoch=epoch)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        # Let the student head drop last-layer gradients during its freeze period.
        dino_model.student_head.cancel_last_layer_gradients(current_epoch=epoch)
        optimizer.step()

        bsz = views[0].size(0)
        total_loss += loss.item() * bsz
        n_samples += bsz
        global_step += 1

    avg_loss = total_loss / max(n_samples, 1)
    elapsed = time.time() - epoch_start

    if is_main_process:
        # 1-based epoch numbering, matching the tqdm description above.
        print(f"Epoch {epoch + 1:03d}/{num_epochs} | train loss: {avg_loss:.5f} | time: {elapsed:.1f}s")

    if is_main_process and (epoch + 1) % save_every == 0:
        # ---- Save checkpoint (same filename each time, so only the latest is kept) ----
        ckpt = {
            "epoch": epoch,
            "model_state": dino_model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "avg_loss": avg_loss,
        }
        torch.save(ckpt, ckpt_path)
        print(f"✓ Saved checkpoint: {ckpt_path}")