In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoImageProcessor, Dinov2Model
from PIL import Image
from pathlib import Path
import pandas as pd
import numpy as np
from tqdm import tqdm
from sklearn.neighbors import KNeighborsClassifier
import argparse
import tarfile
import urllib.request
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset, Subset
from torchvision.transforms import v2

In [5]:
def collate_fn(batch):
    """Custom collate function to handle PIL images"""
    if len(batch[0]) == 3:  # train/val (image, label, filename)
        images = [item[0] for item in batch]
        labels = [item[1] for item in batch]
        filenames = [item[2] for item in batch]
        return images, labels, filenames
    else:  # test (image, filename)
        images = [item[0] for item in batch]
        filenames = [item[1] for item in batch]
        return images, filenames

In [6]:
class ImageDataset(Dataset):
    def __init__(self, image_dir, image_list, labels=None,
                 resolution=224, split="train", apply_transforms=True):
        self.image_dir = Path(image_dir)
        self.image_list = image_list
        self.labels = labels
        self.split = split
        self.resolution = resolution
        self.apply_transforms = apply_transforms

        imagenet_mean = [0.485, 0.456, 0.406]
        imagenet_std = [0.229, 0.224, 0.225]

        if apply_transforms:
            if split == "train":
                self.transform = v2.Compose([
                    v2.RandomResizedCrop(resolution, scale=(0.8, 1.0)),
                    v2.RandomHorizontalFlip(p=0.5),
                    v2.ColorJitter(
                        brightness=0.4,
                        contrast=0.4,
                        saturation=0.4,
                        hue=0.1
                    ),
                    v2.ToImage(),
                    v2.ToDtype(torch.float32, scale=True),
                    v2.Normalize(mean=imagenet_mean, std=imagenet_std),
                ])
            else:
                self.transform = v2.Compose([
                    v2.Resize(256),
                    v2.CenterCrop(resolution),
                    v2.ToImage(),
                    v2.ToDtype(torch.float32, scale=True),
                    v2.Normalize(mean=imagenet_mean, std=imagenet_std),
                ])
        else:
            self.transform = None   # <-- important

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        img_name = self.image_list[idx]
        img_path = self.image_dir / img_name

        img = Image.open(img_path).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)   # -> Tensor (for supervised)
        # else: keep img as PIL (for SSL)

        if self.labels is not None:
            return img, self.labels[idx], img_name
        return img, img_name



In [10]:
from pathlib import Path
import pandas as pd
from torch.utils.data import DataLoader

# ============================================================
# Hyperparameters (replace args.*)
# ============================================================
batch_size = 64
num_workers = 4
resolution = 224   # or whatever you want for training
# ============================================================

# Load CSV files
data_dir = Path("/home/long/code/amogh/data/testset_1")

print("\nLoading dataset metadata...")
train_df = pd.read_csv(data_dir / 'train_labels.csv')
val_df = pd.read_csv(data_dir / 'val_labels.csv')
test_df = pd.read_csv(data_dir / 'test_labels_INTERNAL.csv')

print(f"  Train: {len(train_df)} images")
print(f"  Val:   {len(val_df)} images")
print(f"  Test:  {len(test_df)} images")
print(f"  Classes: {train_df['class_id'].nunique()}")

train_dataset1 = ImageDataset(
    data_dir / 'train',
    train_df['filename'].tolist(),
    train_df['class_id'].tolist(),
    resolution=resolution,
    apply_transforms=True,
)

val_dataset1 = ImageDataset(
    data_dir / 'val',
    val_df['filename'].tolist(),
    val_df['class_id'].tolist(),
    resolution=resolution,
    apply_transforms=True,
)

test_dataset1 = ImageDataset(
    data_dir / 'test',
    test_df['filename'].tolist(),
    labels=test_df['class_id'].tolist(),
    resolution=resolution,
    apply_transforms=True,
)

train_loader1 = DataLoader(
    train_dataset1,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    collate_fn=collate_fn,
)

val_loader1 = DataLoader(
    val_dataset1,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    collate_fn=collate_fn,
)

test_loader1 = DataLoader(
    test_dataset1,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    collate_fn=collate_fn,
)



Loading dataset metadata...
  Train: 8232 images
  Val:   1727 images
  Test:  1829 images
  Classes: 200


In [11]:
#load the model

In [12]:
import tarfile
import urllib.request
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset, Subset
from torchvision.transforms import v2

In [15]:
import torch
import torchvision
from timm.models.vision_transformer import vit_base_patch32_224
from torch import nn
from lightly.models import utils
from lightly.models.modules import MAEDecoderTIMM, MaskedVisionTransformerTIMM
from lightly.transforms import MAETransform
import copy
from lightly.models.modules import DINOProjectionHead
from lightly.loss import DINOLoss  # only needed if you re-train SSL
from lightly.models.utils import deactivate_requires_grad

In [16]:
class DINO(nn.Module):
    def __init__(self, backbone, input_dim):
        super().__init__()
        self.student_backbone = backbone
        self.student_head = DINOProjectionHead(
            input_dim, 512, 64, 2048, freeze_last_layer=1
        )
        self.teacher_backbone = copy.deepcopy(backbone)
        self.teacher_head = DINOProjectionHead(input_dim, 512, 64, 2048)
        deactivate_requires_grad(self.teacher_backbone)
        deactivate_requires_grad(self.teacher_head)

    def forward(self, x):
        y = self.student_backbone(x).flatten(start_dim=1)
        z = self.student_head(y)
        return z

    def forward_teacher(self, x):
        y = self.teacher_backbone(x).flatten(start_dim=1)
        z = self.teacher_head(y)
        return z


In [17]:
import torchvision

# --- Build same backbone as used for DINO pretraining ---
resnet = torchvision.models.resnet18()
# resnet = torchvision.models.resnet34()
backbone = nn.Sequential(*list(resnet.children())[:-1])  # (B, 512, 1, 1)
input_dim = 512

dino_model = DINO(backbone, input_dim)

# --- Load your pre-trained DINO checkpoint ---
ckpt = torch.load(
    "/home/long/code/dl_project1/experiments/outputs/dino-v1/dino-v1_small_100.pt",
    map_location="cpu",
)

dino_model.load_state_dict(ckpt["model_state"], strict=True)



  WeightNorm.apply(module, name, dim)


<All keys matched successfully>

In [18]:
from lightly.loss import DINOLoss
from lightly.models.modules import DINOProjectionHead
from lightly.models.utils import deactivate_requires_grad, update_momentum
from lightly.transforms.dino_transform import DINOTransform
from lightly.utils.scheduler import cosine_schedule

In [21]:
# Freeze everything (linear probe only)
for p in dino_model.parameters():
    p.requires_grad = False

dino_model.eval()
print("done")

done


In [24]:
class DINOEncoderWrapper(nn.Module):
    """Wraps the DINO student backbone and returns a flat feature vector."""

    def __init__(self, dino_model):
        super().__init__()
        self.backbone = dino_model.student_backbone

    def forward(self, x):
        feats = self.backbone(x)          # (B, 512, 1, 1) for ResNet18 backbone
        if isinstance(feats, (list, tuple)):
            feats = feats[0]
        feats = feats.flatten(1)          # (B, 512)
        return feats

class LinearProbeModel(nn.Module):
    def __init__(self, encoder, classifier):
        super().__init__()
        self.encoder = encoder
        self.classifier = classifier

    def forward(self, x):
        with torch.no_grad():              # encoder frozen
            feats = self.encoder(x)
            # feats = feats.flatten(1)
        logits = self.classifier(feats)
        return logits

NUM_CLASSES = train_df['class_id'].nunique()

print(NUM_CLASSES," total classes")

feat_dim   = 512  # ResNet18 backbone
classifier = nn.Linear(feat_dim, NUM_CLASSES)
# model = dino_model.to(device)
encoder = DINOEncoderWrapper(dino_model)
model   = LinearProbeModel(encoder, classifier).to("cuda")



200  total classes


In [25]:
# Only train the classifier
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

params = list(model.classifier.parameters())
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(
    params,
    lr=3e-3,          # tune between 1e-3 and 3e-3
    weight_decay=1e-4
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=50,         # num_epochs
)

scaler = torch.cuda.amp.GradScaler()

Using device: cuda


  scaler = torch.cuda.amp.GradScaler()


In [27]:
# --- Combine train + val ---
from torch.utils.data import ConcatDataset, DataLoader

trainval_ds = ConcatDataset([train_dataset1, val_dataset1])

trainval_loader = DataLoader(
    trainval_ds,
    batch_size=64,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    persistent_workers=True,
)


In [None]:
import torch
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm

encoder = model.encoder.to(device)
encoder.eval()
for p in encoder.parameters():
    p.requires_grad = False


@torch.no_grad()
def extract_features(encoder, loader, device, desc="Extracting"):
    encoder.eval()
    all_feats = []
    all_labels = []

    # tqdm wrapper around the dataloader
    for batch in tqdm(loader, desc=desc):
        if len(batch) == 3:
            images, labels, _ = batch
        else:
            images, labels = batch

        images = images.to(device, non_blocking=True)

        feats = encoder(images)   # shape: (B, C, H, W) or (B, D)
        if feats.dim() > 2:
            feats = feats.flatten(1)  # (B, D)

        all_feats.append(feats.cpu())
        all_labels.append(labels.cpu())

    all_feats = torch.cat(all_feats, dim=0)   # (N, D)
    all_labels = torch.cat(all_labels, dim=0) # (N,)
    return all_feats, all_labels


print("Extracting train+val features...")
trainval_feats, trainval_labels = extract_features(
    encoder, trainval_loader, device, desc="Train/Val"
)

print("Extracting test features...")
test_feats, test_labels = extract_features(
    encoder, test_loader1, device, desc="Test"
)

feat_dim = trainval_feats.shape[1]
num_classes = int(trainval_labels.max().item() + 1)

print(f"Feature dim = {feat_dim}, num_classes = {num_classes}")


Extracting train+val features...


Train/Val:  96%|█████████████████████████████████████████████████████████████████████████████████▏   | 149/156 [00:16<00:00,  8.42it/s]

In [None]:
trainval_feat_ds = TensorDataset(trainval_feats, trainval_labels)
test_feat_ds     = TensorDataset(test_feats,     test_labels)

feat_batch_size = 512  # can be big, it's cheap now

trainval_feat_loader = DataLoader(
    trainval_feat_ds, batch_size=feat_batch_size,
    shuffle=True, num_workers=0
)

test_feat_loader = DataLoader(
    test_feat_ds, batch_size=feat_batch_size,
    shuffle=False, num_workers=0
)
