In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoImageProcessor, Dinov2Model
from PIL import Image
from pathlib import Path
import pandas as pd
import numpy as np
from tqdm import tqdm
from sklearn.neighbors import KNeighborsClassifier
import argparse
import tarfile
import urllib.request
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset, Subset
from torchvision.transforms import v2

In [2]:
# def collate_fn(batch):
#     """Custom collate function to handle PIL images"""
#     if len(batch[0]) == 3:  # train/val (image, label, filename)
#         images = [item[0] for item in batch]
#         labels = [item[1] for item in batch]
#         filenames = [item[2] for item in batch]
#         return images, labels, filenames
#     else:  # test (image, filename)
#         images = [item[0] for item in batch]
#         filenames = [item[1] for item in batch]
#         return images, filenames

In [3]:
# class ImageDataset(Dataset):
#     def __init__(self, image_dir, image_list, labels=None,
#                  resolution=224, split="train", apply_transforms=True):
#         self.image_dir = Path(image_dir)
#         self.image_list = image_list
#         self.labels = labels
#         self.split = split
#         self.resolution = resolution
#         self.apply_transforms = apply_transforms

#         imagenet_mean = [0.485, 0.456, 0.406]
#         imagenet_std = [0.229, 0.224, 0.225]

#         if apply_transforms:
#             if split == "train":
#                 self.transform = v2.Compose([
#                     v2.RandomResizedCrop(resolution, scale=(0.8, 1.0)),
#                     v2.RandomHorizontalFlip(p=0.5),
#                     v2.ColorJitter(
#                         brightness=0.4,
#                         contrast=0.4,
#                         saturation=0.4,
#                         hue=0.1
#                     ),
#                     v2.ToImage(),
#                     v2.ToDtype(torch.float32, scale=True),
#                     v2.Normalize(mean=imagenet_mean, std=imagenet_std),
#                 ])
#             else:
#                 self.transform = v2.Compose([
#                     v2.Resize(256),
#                     v2.CenterCrop(resolution),
#                     v2.ToImage(),
#                     v2.ToDtype(torch.float32, scale=True),
#                     v2.Normalize(mean=imagenet_mean, std=imagenet_std),
#                 ])
#         else:
#             self.transform = None   # <-- important

#     def __len__(self):
#         return len(self.image_list)

#     def __getitem__(self, idx):
#         img_name = self.image_list[idx]
#         img_path = self.image_dir / img_name

#         img = Image.open(img_path).convert('RGB')

#         if self.transform is not None:
#             img = self.transform(img)   # -> Tensor (for supervised)
#         # else: keep img as PIL (for SSL)

#         if self.labels is not None:
#             return img, self.labels[idx], img_name
#         return img, img_name



class ImageDataset(Dataset):
    def __init__(self, image_dir, image_list, labels=None,
                 resolution=224, split="train", apply_transforms=True):
        self.image_dir = Path(image_dir)
        self.image_list = image_list
        self.labels = labels
        self.split = split
        self.resolution = resolution
        self.apply_transforms = apply_transforms

        imagenet_mean = [0.485, 0.456, 0.406]
        imagenet_std = [0.229, 0.224, 0.225]

        if apply_transforms:
            if split == "train":
                self.transform = v2.Compose([
                    v2.RandomResizedCrop(resolution, scale=(0.8, 1.0)),
                    v2.RandomHorizontalFlip(p=0.5),
                    v2.ColorJitter(
                        brightness=0.4,
                        contrast=0.4,
                        saturation=0.4,
                        hue=0.1
                    ),
                    v2.ToImage(),
                    v2.ToDtype(torch.float32, scale=True),
                    v2.Normalize(mean=imagenet_mean, std=imagenet_std),
                ])
            else:
                self.transform = v2.Compose([
                    v2.Resize(256),
                    v2.CenterCrop(resolution),
                    v2.ToImage(),
                    v2.ToDtype(torch.float32, scale=True),
                    v2.Normalize(mean=imagenet_mean, std=imagenet_std),
                ])
        else:
            self.transform = None

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        img_name = self.image_list[idx]
        img_path = self.image_dir / img_name

        img = Image.open(img_path).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        if self.labels is not None:
            if self.split == "test":
                # For test: return (img, label, idx) so we can align with test_df
                return img, self.labels[idx], idx
            else:
                # For train/val: keep old behavior (img, label, img_name)
                return img, self.labels[idx], img_name

        # Unlabeled case (e.g., SSL)
        if self.split == "test":
            return img, idx
        else:
            return img, img_name


In [13]:
from pathlib import Path
import pandas as pd
from torch.utils.data import DataLoader

# ============================================================
# Hyperparameters (replace args.*)
# ============================================================
batch_size = 64
num_workers = 4
resolution = 224   # or whatever you want for training
# ============================================================

# Load CSV files
data_dir = Path("/home/long/code/amogh/data/testset_3")

print("\nLoading dataset metadata...")
train_df = pd.read_csv(data_dir / 'train_labels.csv')
val_df = pd.read_csv(data_dir / 'val_labels.csv')
test_df = pd.read_csv(data_dir / 'test_labels_INTERNAL.csv')

print(f"  Train: {len(train_df)} images")
print(f"  Val:   {len(val_df)} images")
print(f"  Test:  {len(test_df)} images")
print(f"  Classes: {train_df['class_id'].nunique()}")

train_dataset1 = ImageDataset(
    data_dir / 'train',
    train_df['filename'].tolist(),
    train_df['class_id'].tolist(),
    resolution=resolution,
    apply_transforms=True,
)

val_dataset1 = ImageDataset(
    data_dir / 'val',
    val_df['filename'].tolist(),
    val_df['class_id'].tolist(),
    resolution=resolution,
    apply_transforms=True,
)

test_dataset1 = ImageDataset(
    data_dir / 'test',
    test_df['filename'].tolist(),
    labels=test_df['class_id'].tolist(),
    resolution=resolution,
    apply_transforms=True,
    split="test"
)

train_loader1 = DataLoader(
    train_dataset1,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    # collate_fn=collate_fn,
)

val_loader1 = DataLoader(
    val_dataset1,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    # collate_fn=collate_fn,
)

test_loader1 = DataLoader(
    test_dataset1,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    # collate_fn=collate_fn,
)



Loading dataset metadata...
  Train: 13895 images
  Val:   2977 images
  Test:  2978 images
  Classes: 397


In [14]:
#load the model

In [15]:
import tarfile
import urllib.request
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset, Subset
from torchvision.transforms import v2

In [16]:
import torch
import torchvision
from timm.models.vision_transformer import vit_base_patch32_224
from torch import nn
from lightly.models import utils
from lightly.models.modules import MAEDecoderTIMM, MaskedVisionTransformerTIMM
from lightly.transforms import MAETransform
import copy
from lightly.models.modules import DINOProjectionHead
from lightly.loss import DINOLoss  # only needed if you re-train SSL
from lightly.models.utils import deactivate_requires_grad

In [17]:
class DINO(nn.Module):
    def __init__(self, backbone, input_dim):
        super().__init__()
        self.student_backbone = backbone
        self.student_head = DINOProjectionHead(
            input_dim, 512, 64, 2048, freeze_last_layer=1
        )
        self.teacher_backbone = copy.deepcopy(backbone)
        self.teacher_head = DINOProjectionHead(input_dim, 512, 64, 2048)
        deactivate_requires_grad(self.teacher_backbone)
        deactivate_requires_grad(self.teacher_head)

    def forward(self, x):
        y = self.student_backbone(x).flatten(start_dim=1)
        z = self.student_head(y)
        return z

    def forward_teacher(self, x):
        y = self.teacher_backbone(x).flatten(start_dim=1)
        z = self.teacher_head(y)
        return z


In [18]:
import torchvision

# --- Build same backbone as used for DINO pretraining ---
resnet = torchvision.models.resnet18()
# resnet = torchvision.models.resnet34()
backbone = nn.Sequential(*list(resnet.children())[:-1])  # (B, 512, 1, 1)
input_dim = 512

dino_model = DINO(backbone, input_dim)

# --- Load your pre-trained DINO checkpoint ---
# ckpt = torch.load(
#     "/home/long/code/dl_project1/experiments/outputs/dino-v1/dino-v1_small_100.pt",
#     map_location="cpu",
# )

# ckpt = torch.load(
#     "/home/long/code/amogh/data/models/dino-v1_full_finetuned_18.pt",
#     map_location="cpu",
# )

# ckpt18_path = 
ckpt = torch.load(
    "/home/long/code/amogh/data/models/dino-v1_full_finetuned_18_test2fixed_data3.pt",
    map_location="cpu",
)

dino_model.load_state_dict(ckpt["model_state"], strict=True)



  WeightNorm.apply(module, name, dim)


<All keys matched successfully>

In [19]:
from lightly.loss import DINOLoss
from lightly.models.modules import DINOProjectionHead
from lightly.models.utils import deactivate_requires_grad, update_momentum
from lightly.transforms.dino_transform import DINOTransform
from lightly.utils.scheduler import cosine_schedule

In [20]:
# Freeze everything (linear probe only)
for p in dino_model.parameters():
    p.requires_grad = False

dino_model.eval()
print("done")

done


In [21]:
# Only train the classifier
device = "cuda:2" if torch.cuda.is_available() else "cpu"

In [22]:
class DINOEncoderWrapper(nn.Module):
    """Wraps the DINO student backbone and returns a flat feature vector."""

    def __init__(self, dino_model):
        super().__init__()
        self.backbone = dino_model.student_backbone

    def forward(self, x):
        feats = self.backbone(x)          # (B, 512, 1, 1) for ResNet18 backbone
        if isinstance(feats, (list, tuple)):
            feats = feats[0]
        feats = feats.flatten(1)          # (B, 512)
        return feats

class LinearProbeModel(nn.Module):
    def __init__(self, encoder, classifier):
        super().__init__()
        self.encoder = encoder
        self.classifier = classifier

    def forward(self, x):
        with torch.no_grad():              # encoder frozen
            feats = self.encoder(x)
            # feats = feats.flatten(1)
        logits = self.classifier(feats)
        return logits



import math
import torch
import torch.nn as nn

class DINOEncoderWrapper(nn.Module):
    """Wraps the DINO student backbone and returns a 1024-dim flat feature vector."""
    def __init__(self, dino_model, in_dim=512, proj_dim=1024):
        super().__init__()
        self.backbone = dino_model.student_backbone

        # ----- Fixed, non-trainable projection: (in_dim -> proj_dim) -----
        proj = torch.randn(in_dim, proj_dim) / math.sqrt(in_dim)
        # register as a buffer so it's moved with .to(device) but NOT a parameter
        self.register_buffer("proj", proj)

        # (optional) also freeze backbone params explicitly
        for p in self.backbone.parameters():
            p.requires_grad = False

    def forward(self, x):
        feats = self.backbone(x)          # (B, 512, 1, 1) for ResNet18 backbone
        if isinstance(feats, (list, tuple)):
            feats = feats[0]
        feats = feats.flatten(1)          # (B, 512)
        feats = feats @ self.proj         # (B, 1024), fixed linear map
        return feats


NUM_CLASSES = train_df['class_id'].nunique()

print(NUM_CLASSES," total classes")

# feat_dim   = 512  # ResNet18 backbone
feat_dim   = 512  # ResNet18 backbone
classifier = nn.Linear(feat_dim, NUM_CLASSES)
# model = dino_model.to(device)
encoder = DINOEncoderWrapper(dino_model)
model   = LinearProbeModel(encoder, classifier).to(device)


397  total classes


In [23]:
import math
import torch
import torch.nn as nn
import torchvision

device = "cuda:3" if torch.cuda.is_available() else "cpu"

# -------------------------------
# 1. DINO encoder wrapper (same idea as yours)
# -------------------------------
class DINOEncoderWrapper(nn.Module):
    """Wraps the DINO student backbone and returns a flat feature vector."""
    def __init__(self, dino_model):
        super().__init__()
        self.backbone = dino_model.student_backbone

        # (optional) extra safety: freeze backbone params
        for p in self.backbone.parameters():
            p.requires_grad = False

    def forward(self, x):
        feats = self.backbone(x)          # (B, 512, 1, 1) for ResNet18/34
        if isinstance(feats, (list, tuple)):
            feats = feats[0]
        feats = feats.flatten(1)          # (B, 512)
        return feats

# -------------------------------
# 2. Build DINO18 and DINO34 models from ckpts
# -------------------------------
# paths to your checkpoints
# ckpt18_path = "/home/long/code/amogh/data/models/dino-v1_full_finetuned_18.pt"
# ckpt18_path = "/home/long/code/amogh/data/models/dino-v1_full_finetuned_18_test2fixed_data2.pt"
ckpt18_path = "/home/long/code/amogh/data/models/dino-v1_full_finetuned_18_test2fixed_data3.pt"

ckpt34_path = "/home/long/code/amogh/data/models/dino-v1_full_finetuned_big_boi_test_fixed.pt"

# ---- ResNet18 backbone ----
resnet18 = torchvision.models.resnet18()
backbone18 = nn.Sequential(*list(resnet18.children())[:-1])  # (B, 512, 1, 1)
in_dim18 = 512

dino18_model = DINO(backbone18, in_dim18)
ckpt18 = torch.load(ckpt18_path, map_location="cpu")
dino18_model.load_state_dict(ckpt18["model_state"], strict=True)
dino18_model.to(device)

# ---- ResNet34 backbone ----
resnet34 = torchvision.models.resnet34()
backbone34 = nn.Sequential(*list(resnet34.children())[:-1])  # (B, 512, 1, 1)
in_dim34 = 512

dino34_model = DINO(backbone34, in_dim34)
ckpt34 = torch.load(ckpt34_path, map_location="cpu")
dino34_model.load_state_dict(ckpt34["model_state"], strict=True)
dino34_model.to(device)

# -------------------------------
# 3. Wrap each into an encoder
# -------------------------------
encoder18 = DINOEncoderWrapper(dino18_model).to(device)
encoder34 = DINOEncoderWrapper(dino34_model).to(device)

# -------------------------------
# 4. Concat encoder
# -------------------------------
class ConcatEncoder(nn.Module):
    """Runs both encoders and concatenates their features."""
    def __init__(self, enc18, enc34):
        super().__init__()
        self.enc18 = enc18
        self.enc34 = enc34

        # freeze everything inside concat encoder (just in case)
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, x):
        f18 = self.enc18(x)   # (B, 512)
        f34 = self.enc34(x)   # (B, 512)
        return torch.cat([f18, f34], dim=1)  # (B, 1024)

concat_encoder = ConcatEncoder(encoder18, encoder34).to(device)

# -------------------------------
# 5. Linear probe on top of concatenated features
# -------------------------------
NUM_CLASSES = train_df["class_id"].nunique()
feat_dim_concat = 512 * 2  # 1024

classifier_concat = nn.Linear(feat_dim_concat, NUM_CLASSES).to(device)

class LinearProbeModel(nn.Module):
    def __init__(self, encoder, classifier):
        super().__init__()
        self.encoder = encoder
        self.classifier = classifier

    def forward(self, x):
        with torch.no_grad():          # encoder frozen
            feats = self.encoder(x)    # (B, 1024)
        logits = self.classifier(feats)
        return logits

model = LinearProbeModel(concat_encoder, classifier_concat).to(device)


In [24]:
import math
import torch
import torch.nn as nn
import torchvision

device = "cuda:3" if torch.cuda.is_available() else "cpu"

# -------------------------------
# 1. DINO encoder wrapper
# -------------------------------
class DINOEncoderWrapper(nn.Module):
    """Wraps the DINO student backbone and returns a flat feature vector."""
    def __init__(self, dino_model):
        super().__init__()
        self.backbone = dino_model.student_backbone

        # freeze backbone params
        for p in self.backbone.parameters():
            p.requires_grad = False

    def forward(self, x):
        feats = self.backbone(x)          # (B, 512, 1, 1)
        if isinstance(feats, (list, tuple)):
            feats = feats[0]
        feats = feats.flatten(1)          # (B, 512)
        return feats

# -------------------------------
# 2. Build DINO18 (A), DINO18 (B) and DINO34 from ckpts
# -------------------------------
# ckpt18_a_path = "/home/long/code/amogh/data/models/dino-v1_full_finetuned_18_test2fixed_data2.pt"
ckpt18_a_path = "/home/long/code/amogh/data/models/dino-v1_full_finetuned_18_test2fixed_data3.pt"

ckpt18_b_path = "/home/long/code/amogh/data/models/dino-v1_full_finetuned_18_test2fixed.pt"
ckpt34_path   = "/home/long/code/amogh/data/models/dino-v1_full_finetuned_big_boi_test_fixed.pt"

# ---- ResNet18 backbone A ----
resnet18_a = torchvision.models.resnet18()
backbone18_a = nn.Sequential(*list(resnet18_a.children())[:-1])  # (B, 512, 1, 1)
in_dim18 = 512

dino18_a = DINO(backbone18_a, in_dim18)
ckpt18_a = torch.load(ckpt18_a_path, map_location="cpu")
dino18_a.load_state_dict(ckpt18_a["model_state"], strict=True)
dino18_a.to(device)

# ---- ResNet18 backbone B (second checkpoint) ----
resnet18_b = torchvision.models.resnet18()
backbone18_b = nn.Sequential(*list(resnet18_b.children())[:-1])
dino18_b = DINO(backbone18_b, in_dim18)
ckpt18_b = torch.load(ckpt18_b_path, map_location="cpu")
dino18_b.load_state_dict(ckpt18_b["model_state"], strict=True)
dino18_b.to(device)

# ---- ResNet34 backbone ----
resnet34 = torchvision.models.resnet34()
backbone34 = nn.Sequential(*list(resnet34.children())[:-1])  # (B, 512, 1, 1)
in_dim34 = 512

dino34 = DINO(backbone34, in_dim34)
ckpt34 = torch.load(ckpt34_path, map_location="cpu")
dino34.load_state_dict(ckpt34["model_state"], strict=True)
dino34.to(device)

# -------------------------------
# 3. Wrap each into an encoder
# -------------------------------
encoder18_a = DINOEncoderWrapper(dino18_a).to(device)
encoder18_b = DINOEncoderWrapper(dino18_b).to(device)
encoder34   = DINOEncoderWrapper(dino34).to(device)

# -------------------------------
# 4. Concat encoder (3-way)
# -------------------------------
class ConcatEncoder(nn.Module):
    """Runs three encoders and concatenates their features."""
    def __init__(self, enc18_a, enc18_b, enc34):
        super().__init__()
        self.enc18_a = enc18_a
        self.enc18_b = enc18_b
        self.enc34   = enc34

        # freeze everything inside concat encoder
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, x):
        f18_a = self.enc18_a(x)   # (B, 512)
        f18_b = self.enc18_b(x)   # (B, 512)
        f34   = self.enc34(x)     # (B, 512)
        return torch.cat([f18_a, f18_b, f34], dim=1)  # (B, 1536)

concat_encoder = ConcatEncoder(encoder18_a, encoder18_b, encoder34).to(device)

# -------------------------------
# 5. Linear probe on top of concatenated features
# -------------------------------
NUM_CLASSES = train_df["class_id"].nunique()
feat_dim_concat = 512 * 3  # 1536

classifier_concat = nn.Linear(feat_dim_concat, NUM_CLASSES).to(device)

class LinearProbeModel(nn.Module):
    def __init__(self, encoder, classifier):
        super().__init__()
        self.encoder = encoder
        self.classifier = classifier

    def forward(self, x):
        with torch.no_grad():          # encoder frozen
            feats = self.encoder(x)    # (B, 1536)
        logits = self.classifier(feats)
        return logits

model = LinearProbeModel(concat_encoder, classifier_concat).to(device)


In [25]:
#all encoders
import math
import torch
import torch.nn as nn
import torchvision

device = "cuda:3" if torch.cuda.is_available() else "cpu"

# -------------------------------
# 1. DINO encoder wrapper
# -------------------------------
class DINOEncoderWrapper(nn.Module):
    """Wraps the DINO student backbone and returns a flat feature vector."""
    def __init__(self, dino_model):
        super().__init__()
        self.backbone = dino_model.student_backbone

        # freeze backbone params
        for p in self.backbone.parameters():
            p.requires_grad = False

    def forward(self, x):
        feats = self.backbone(x)          # (B, 512, 1, 1)
        if isinstance(feats, (list, tuple)):
            feats = feats[0]
        feats = feats.flatten(1)          # (B, 512)
        return feats

# -------------------------------
# 2. Build DINO18 (A,B,C,D) and DINO34 from ckpts
# -------------------------------
ckpt18_a_path = "/home/long/code/amogh/data/models/dino-v1_full_finetuned_18_test2fixed_data3.pt"
ckpt18_b_path = "/home/long/code/amogh/data/models/dino-v1_full_finetuned_18_test2fixed.pt"
ckpt18_c_path = "/home/long/code/amogh/data/models/dino-v1_full_finetuned_18_test2fixed_data2.pt"
ckpt18_d_path = "/home/long/code/amogh/data/models/dino-v1_full_finetuned_18_test2fixed_data1_data1againidk.pt"
ckpt34_path   = "/home/long/code/amogh/data/models/dino-v1_full_finetuned_big_boi_test_fixed.pt"

in_dim18 = 512
in_dim34 = 512

# ---- ResNet18 backbone A ----
resnet18_a = torchvision.models.resnet18()
backbone18_a = nn.Sequential(*list(resnet18_a.children())[:-1])  # (B, 512, 1, 1)
dino18_a = DINO(backbone18_a, in_dim18)
ckpt18_a = torch.load(ckpt18_a_path, map_location="cpu")
dino18_a.load_state_dict(ckpt18_a["model_state"], strict=True)
dino18_a.to(device)

# ---- ResNet18 backbone B ----
resnet18_b = torchvision.models.resnet18()
backbone18_b = nn.Sequential(*list(resnet18_b.children())[:-1])
dino18_b = DINO(backbone18_b, in_dim18)
ckpt18_b = torch.load(ckpt18_b_path, map_location="cpu")
dino18_b.load_state_dict(ckpt18_b["model_state"], strict=True)
dino18_b.to(device)

# ---- ResNet18 backbone C ----
resnet18_c = torchvision.models.resnet18()
backbone18_c = nn.Sequential(*list(resnet18_c.children())[:-1])
dino18_c = DINO(backbone18_c, in_dim18)
ckpt18_c = torch.load(ckpt18_c_path, map_location="cpu")
dino18_c.load_state_dict(ckpt18_c["model_state"], strict=True)
dino18_c.to(device)

# ---- ResNet18 backbone D ----
resnet18_d = torchvision.models.resnet18()
backbone18_d = nn.Sequential(*list(resnet18_d.children())[:-1])
dino18_d = DINO(backbone18_d, in_dim18)
ckpt18_d = torch.load(ckpt18_d_path, map_location="cpu")
dino18_d.load_state_dict(ckpt18_d["model_state"], strict=True)
dino18_d.to(device)

# ---- ResNet34 backbone ----
resnet34 = torchvision.models.resnet34()
backbone34 = nn.Sequential(*list(resnet34.children())[:-1])  # (B, 512, 1, 1)
dino34 = DINO(backbone34, in_dim34)
ckpt34 = torch.load(ckpt34_path, map_location="cpu")
dino34.load_state_dict(ckpt34["model_state"], strict=True)
dino34.to(device)

# -------------------------------
# 3. Wrap each into an encoder
# -------------------------------
encoder18_a = DINOEncoderWrapper(dino18_a).to(device)
encoder18_b = DINOEncoderWrapper(dino18_b).to(device)
encoder18_c = DINOEncoderWrapper(dino18_c).to(device)
encoder18_d = DINOEncoderWrapper(dino18_d).to(device)
encoder34   = DINOEncoderWrapper(dino34).to(device)

# -------------------------------
# 4. Concat encoder (5-way)
# -------------------------------
class ConcatEncoder(nn.Module):
    """Runs five encoders and concatenates their features."""
    def __init__(self, enc18_a, enc18_b, enc18_c, enc18_d, enc34):
        super().__init__()
        self.enc18_a = enc18_a
        self.enc18_b = enc18_b
        self.enc18_c = enc18_c
        self.enc18_d = enc18_d
        self.enc34   = enc34

        # freeze everything inside concat encoder
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, x):
        f18_a = self.enc18_a(x)   # (B, 512)
        f18_b = self.enc18_b(x)   # (B, 512)
        f18_c = self.enc18_c(x)   # (B, 512)
        f18_d = self.enc18_d(x)   # (B, 512)
        f34   = self.enc34(x)     # (B, 512)
        return torch.cat([f18_a, f18_b, f18_c, f18_d, f34], dim=1)  # (B, 2560)

concat_encoder = ConcatEncoder(
    encoder18_a, encoder18_b, encoder18_c, encoder18_d, encoder34
).to(device)

# -------------------------------
# 5. Linear probe on top of concatenated features
# -------------------------------
NUM_CLASSES = train_df["class_id"].nunique()
feat_dim_concat = 512 * 5  # 2560

classifier_concat = nn.Linear(feat_dim_concat, NUM_CLASSES).to(device)

class LinearProbeModel(nn.Module):
    def __init__(self, encoder, classifier):
        super().__init__()
        self.encoder = encoder
        self.classifier = classifier

    def forward(self, x):
        with torch.no_grad():          # encoder frozen
            feats = self.encoder(x)    # (B, 2560)
        logits = self.classifier(feats)
        return logits

model = LinearProbeModel(concat_encoder, classifier_concat).to(device)


In [26]:
def count_params(model):
    return sum(p.numel() for p in model.parameters())


In [27]:
total_params = count_params(concat_encoder)
print("Total parameters in concat encoder:", total_params)


Total parameters in concat encoder: 65990720


In [28]:
device

'cuda:3'

In [29]:

print(f"Using device: {device}")

params = list(model.classifier.parameters())
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(
    params,
    lr=3e-3,          # tune between 1e-3 and 3e-3
    weight_decay=1e-4
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=50,         # num_epochs
)

scaler = torch.cuda.amp.GradScaler()

Using device: cuda:3


  scaler = torch.cuda.amp.GradScaler()


In [30]:
# --- Combine train + val ---
from torch.utils.data import ConcatDataset, DataLoader

trainval_ds = ConcatDataset([train_dataset1, val_dataset1])

trainval_loader = DataLoader(
    trainval_ds,
    batch_size=64,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    persistent_workers=True,
)


In [31]:
# Only train the classifier
# device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

params = list(model.classifier.parameters())
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(
    params,
    lr=3e-3,          # tune between 1e-3 and 3e-3
    weight_decay=1e-4
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=50,         # num_epochs
)

scaler = torch.cuda.amp.GradScaler()

Using device: cuda:3


  scaler = torch.cuda.amp.GradScaler()


In [32]:
import torch
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm

encoder = model.encoder.to(device)
encoder.eval()
for p in encoder.parameters():
    p.requires_grad = False


@torch.no_grad()
def extract_features(encoder, loader, device, desc="Extracting"):
    encoder.eval()
    all_feats = []
    all_labels = []

    # tqdm wrapper around the dataloader
    for batch in tqdm(loader, desc=desc):
        if len(batch) == 3:
            images, labels, _ = batch
        else:
            images, labels = batch

        images = images.to(device, non_blocking=True)

        feats = encoder(images)   # shape: (B, C, H, W) or (B, D)
        if feats.dim() > 2:
            feats = feats.flatten(1)  # (B, D)

        all_feats.append(feats.cpu())
        all_labels.append(labels.cpu())

    all_feats = torch.cat(all_feats, dim=0)   # (N, D)
    all_labels = torch.cat(all_labels, dim=0) # (N,)
    return all_feats, all_labels


print("Extracting train+val features...")
trainval_feats, trainval_labels = extract_features(
    encoder, trainval_loader, device, desc="Train/Val"
)

print("Extracting test features...")
test_feats, test_labels = extract_features(
    encoder, test_loader1, device, desc="Test"
)

feat_dim = trainval_feats.shape[1]
num_classes = int(trainval_labels.max().item() + 1)

print(f"Feature dim = {feat_dim}, num_classes = {num_classes}")


Extracting train+val features...


Train/Val: 100%|████████████████████████████████████████████████████████████████████████████████████| 264/264 [00:48<00:00,  5.47it/s]


Extracting test features...


Test: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 47/47 [00:10<00:00,  4.60it/s]


Feature dim = 2560, num_classes = 397


In [33]:
trainval_feat_ds = TensorDataset(trainval_feats, trainval_labels)
test_feat_ds     = TensorDataset(test_feats,     test_labels)

feat_batch_size = 512  # can be big, it's cheap now
# feat_batch_size = 512  # can be big, it's cheap now

trainval_feat_loader = DataLoader(
    trainval_feat_ds, batch_size=feat_batch_size,
    shuffle=True, num_workers=0
)

test_feat_loader = DataLoader(
    test_feat_ds, batch_size=feat_batch_size,
    shuffle=False, num_workers=0
)


In [55]:
import torch.nn as nn

# linear_head = nn.Linear(feat_dim, num_classes).to(device)
import torch.nn as nn

dropout_p = 0.25  # try 0.2 / 0.3 / 0.5

linear_head = nn.Sequential(
    nn.Dropout(p=dropout_p),
    nn.Linear(feat_dim, num_classes),
).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(linear_head.parameters(), lr=2e-3, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=500)
scaler = torch.cuda.amp.GradScaler()


  scaler = torch.cuda.amp.GradScaler()


In [58]:
def run_epoch_head(head, loader, train=True):
    if train:
        head.train()
    else:
        head.eval()

    running_loss = 0.0
    correct = 0
    total = 0

    for feats, labels in loader:
        feats = feats.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        if train:
            optimizer.zero_grad(set_to_none=True)

        with torch.cuda.amp.autocast():
            logits = head(feats)
            loss = criterion(logits, labels)

        if train:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

        running_loss += loss.item() * labels.size(0)
        _, preds = torch.max(logits, dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    avg_loss = running_loss / total
    acc = 100.0 * correct / total
    return avg_loss, acc


In [59]:
num_epochs = 500
best_test_acc = 0.0
best_epoch = 0

In [60]:
# num_epochs = 500
# best_test_acc = 0.0
# best_epoch = 0

for epoch in range(1, num_epochs + 1):
    train_loss, train_acc = run_epoch_head(linear_head, trainval_feat_loader, train=True)
    test_loss, test_acc   = run_epoch_head(linear_head, test_feat_loader,   train=False)
    scheduler.step()

    if test_acc > best_test_acc:
        best_test_acc = test_acc
        best_epoch = epoch
        # you can save the head weights here if you want
        torch.save(linear_head.state_dict(), "best_linear_head_2.pt")

    print(
        f"Epoch {epoch:02d} | "
        f"train loss: {train_loss:.4f}, acc: {train_acc:.2f}% | "
        f"test loss: {test_loss:.4f}, acc: {test_acc:.2f}%"
    )

print(f"BEST TEST ACC = {best_test_acc:.2f}% at epoch {best_epoch}")


  with torch.cuda.amp.autocast():


Epoch 01 | train loss: 5.2999, acc: 7.87% | test loss: 4.2723, acc: 16.12%
Epoch 02 | train loss: 3.7108, acc: 23.33% | test loss: 3.6290, acc: 24.38%
Epoch 03 | train loss: 3.1834, acc: 31.32% | test loss: 3.4668, acc: 26.76%
Epoch 04 | train loss: 2.8600, acc: 36.89% | test loss: 3.2599, acc: 28.91%
Epoch 05 | train loss: 2.5803, acc: 41.23% | test loss: 3.1906, acc: 30.96%
Epoch 06 | train loss: 2.4021, acc: 45.56% | test loss: 3.1740, acc: 30.62%
Epoch 07 | train loss: 2.2752, acc: 48.06% | test loss: 3.1362, acc: 31.40%
Epoch 08 | train loss: 2.1708, acc: 50.63% | test loss: 3.1033, acc: 31.93%
Epoch 09 | train loss: 1.9955, acc: 53.59% | test loss: 2.9992, acc: 33.21%
Epoch 10 | train loss: 1.8763, acc: 56.47% | test loss: 3.0498, acc: 33.51%
Epoch 11 | train loss: 1.8032, acc: 58.45% | test loss: 3.0593, acc: 32.74%
Epoch 12 | train loss: 1.6828, acc: 60.78% | test loss: 3.0030, acc: 34.39%
Epoch 13 | train loss: 1.6269, acc: 61.58% | test loss: 2.9663, acc: 34.12%
Epoch 14 | tr

In [61]:
from pathlib import Path
import pandas as pd
import torch
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

@torch.no_grad()
def create_submission_from_features(head, test_feats, test_df, device, output_dir,
                                    batch_size=256):
    """
    Uses precomputed test_feats (N, D) and a trained head to create submission.csv.
    This exactly matches the sanity-checked evaluation.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    head = head.to(device).eval()

    N = test_feats.shape[0]
    assert N == len(test_df), f"Mismatch: test_feats={N}, test_df={len(test_df)}"

    # Dataset: (feats, idx) so we preserve order
    idx_tensor = torch.arange(N, dtype=torch.long)
    ds = TensorDataset(test_feats, idx_tensor)
    loader = DataLoader(ds, batch_size=batch_size, shuffle=False)

    pred_array = torch.zeros(N, dtype=torch.long)

    for feats_batch, idx in tqdm(loader, desc="Submission (features)"):
        feats_batch = feats_batch.to(device, non_blocking=True)
        logits = head(feats_batch)
        preds  = torch.argmax(logits, dim=1).cpu()
        pred_array[idx] = preds

    # optional sanity: check class distribution
    bincount = torch.bincount(pred_array)
    print("Prediction class counts:", bincount.tolist())

    submission = pd.DataFrame({
        "id": test_df["filename"].values,
        "class_id": pred_array.numpy(),
    })

    out_path = output_dir / "submission3.csv"
    submission.to_csv(out_path, index=False)
    print(f"\nSubmission written to: {out_path}")
    return submission


In [62]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# IMPORTANT: make sure you're using the trained/best head
linear_head.load_state_dict(torch.load("best_linear_head.pt", map_location=device))

submission = create_submission_from_features(
    head=linear_head,
    test_feats=test_feats,       # from your extract_features(...)
    test_df=test_df,
    device=device,
    output_dir=".",
)


RuntimeError: Error(s) in loading state_dict for Sequential:
	size mismatch for 1.weight: copying a param with shape torch.Size([64, 2560]) from checkpoint, the shape in current model is torch.Size([397, 2560]).
	size mismatch for 1.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([397]).

In [None]:
submission

In [None]:
linear_head.eval()
with torch.no_grad():
    logits = linear_head(test_feats.to(device))
    preds  = torch.argmax(logits, dim=1).cpu()

correct = (preds == test_labels).sum().item()
acc = 100.0 * correct / len(test_labels)
print(f"Sanity check on stored test_feats: {acc:.2f}%")


In [None]:
from pathlib import Path
import pandas as pd
import torch
from tqdm import tqdm

@torch.no_grad()
def create_submission(encoder, head, test_loader, test_df, device, output_dir):
    """
    Runs encoder + head on test_loader and writes submission.csv,
    guaranteeing exact 1:1 alignment with test_df.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    encoder = encoder.to(device).eval()
    head    = head.to(device).eval()

    N = len(test_df)
    pred_array = torch.zeros(N, dtype=torch.long)

    for batch in tqdm(test_loader, desc="Test inference (order-safe)"):
        # After the ImageDataset change, test batch = (images, labels, idx)
        images, _, idx = batch

        images = images.to(device, non_blocking=True)

        feats = encoder(images)
        if feats.dim() > 2:
            feats = feats.flatten(1)

        logits = head(feats)
        preds  = torch.argmax(logits, dim=1).cpu()

        # Place predictions into the correct positions
        pred_array[idx] = preds

    submission = pd.DataFrame({
        "id": test_df["filename"],      # matches the Kaggle-style column
        "class_id": pred_array.numpy(), # one prediction per row
    })

    out_path = output_dir / "submission.csv"
    submission.to_csv(out_path, index=False)
    print(f"\nSubmission written to: {out_path}")


In [193]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

encoder = model.encoder     # frozen encoder from your LinearProbeModel
head    = linear_head       # or model.classifier, whichever you trained

create_submission(
    encoder=encoder,
    head=head,
    test_loader=test_loader1,
    test_df=test_df,
    device=device,
    output_dir=".",
)


Test inference (order-safe): 100%|███████| 47/47 [00:03<00:00, 13.87it/s]


Submission written to: submission.csv





In [None]:
 Saved checkpoint: /content/drive/MyDrive/dino-v1/dino-v1_latest_finetuned_18_big_data_texture.pt

In [220]:
!ls /home/long/code/dl_project1/experiments/outputs/dino-v1/ -lrth

total 517M
-rw-rw-r-- 1 long long 182M Dec  6 04:28 dino-v1_small_100.pt
-rw-rw-r-- 1 long long 336M Dec  6 14:09 dino-v1_latest.pt


In [32]:
import torch
import pandas as pd
from tqdm import tqdm
import os


def create_submission(
    test_loader,
    head,                         # this is your linear_head
    output_path="submission.csv",
    device="cuda",
    encoder=None,                 # <-- NEW: optional encoder
):
    """
    Create submission.csv from test_loader and a trained linear head.

    Args:
        test_loader:
            - If encoder is not None:
                yields (images, filenames) or (images, labels, filenames)
                images: preprocessed tensors ready for encoder
            - If encoder is None:
                yields (features, filenames) or (features, labels, filenames)
                features: outputs of encoder (flattened)
        head: trained linear head; takes encoder features as input and outputs logits
        encoder: optional frozen encoder. If provided, we do:
                 feats = encoder(images); feats = feats.flatten(1) if needed
        output_path: where to save submission CSV
        device: 'cuda' or 'cpu'

    Returns:
        submission_df: DataFrame with 'id' and 'class_id'
        accuracy: test accuracy if labels are present, else None
    """

    device = device if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    head = head.to(device)
    head.eval()

    if encoder is not None:
        encoder = encoder.to(device)
        encoder.eval()
        for p in encoder.parameters():
            p.requires_grad = False

    # Check if dataloader is shuffled (common issue)
    if hasattr(test_loader, "sampler") and hasattr(test_loader.sampler, "shuffle"):
        if test_loader.sampler.shuffle:
            print("⚠️  WARNING: test_loader appears to be shuffled!")
            print("Make sure shuffle=False for submission.")

    all_predictions = []
    all_filenames = []
    correct = 0
    total = 0
    has_labels = False

    print("Generating predictions...")

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Inference"):
            # -------------------------
            # 1. Unpack batch
            # -------------------------
            if isinstance(batch, (list, tuple)):
                if len(batch) == 3:
                    x, labels, filenames = batch
                    has_labels = True
                elif len(batch) == 2:
                    x, second = batch
                    # Heuristic: tensor → labels, list/str → filenames
                    if isinstance(second, torch.Tensor):
                        x, labels = batch
                        filenames = None
                        has_labels = True
                    else:
                        x, filenames = batch
                        labels = None
                else:
                    raise ValueError(f"Unexpected batch length: {len(batch)}")
            else:
                # Dict style (not typical in your code, but kept for safety)
                x = batch["image"]
                filenames = batch.get("filename", batch.get("id"))
                labels = batch.get("label", batch.get("class_id"))
                if labels is not None:
                    has_labels = True

            # -------------------------
            # 2. Move to device
            # -------------------------
            if isinstance(x, torch.Tensor):
                x = x.to(device, non_blocking=True)

            if labels is not None and isinstance(labels, torch.Tensor):
                labels = labels.to(device, non_blocking=True)

            # -------------------------
            # 3. Forward: encoder (optional) + head
            # -------------------------
            with torch.cuda.amp.autocast():
                if encoder is not None:
                    feats = encoder(x)
                    if feats.dim() > 2:
                        feats = feats.flatten(1)
                    logits = head(feats)
                else:
                    # x is already features
                    logits = head(x)

            # -------------------------
            # 4. Predictions
            # -------------------------
            if logits.dim() == 2:
                _, preds = torch.max(logits, dim=1)
            else:  # already predictions
                preds = logits

            # accuracy if labels available
            if labels is not None:
                correct += (preds == labels).sum().item()
                total += labels.size(0)

            all_predictions.extend(preds.cpu().numpy())

            # -------------------------
            # 5. Filenames
            # -------------------------
            clean_names = []

            if filenames is None:
                # If you ever call this on a feature-loader without filenames,
                # you should instead pass filenames separately or avoid CSV here.
                raise ValueError(
                    "Filenames are None. For submission, test_loader must yield filenames."
                )

            if isinstance(filenames, torch.Tensor):
                for f in filenames.cpu().tolist():
                    clean_names.append(os.path.basename(str(f)))
            elif isinstance(filenames, list):
                for f in filenames:
                    clean_names.append(os.path.basename(str(f)))
            else:  # single string
                clean_names.append(os.path.basename(str(filenames)))

            all_filenames.extend(clean_names)

    # -------------------------
    # 6. Build submission
    # -------------------------

    import pdb
    pdb.set_trace()
    submission_df = pd.DataFrame({
        "id": all_filenames,
        "class_id": all_predictions,
    })

    # Deduplicate IDs if needed
    duplicates = submission_df[submission_df.duplicated(subset=["id"], keep=False)]
    if len(duplicates) > 0:
        print("\n⚠️  WARNING: Duplicate IDs found!")
        print(f"Number of duplicate entries: {len(duplicates)}")
        print("\nRemoving duplicates (keeping first occurrence)...")
        submission_df = submission_df.drop_duplicates(subset=["id"], keep="first")

    submission_df.to_csv(output_path, index=False)

    # -------------------------
    # 7. Accuracy if labels given
    # -------------------------
    accuracy = None
    if total > 0:
        accuracy = correct / total
        print(f"\n{'='*60}")
        print(f"TEST ACCURACY: {accuracy:.4f} ({accuracy*100:.2f}%)")
        print(f"Correct: {correct} / {total}")
        print(f"{'='*60}")

    # Summary
    print(f"\n{'='*60}")
    print(f"✓ Submission saved to: {output_path}")
    print(f"{'='*60}")
    print(f"Total predictions: {len(submission_df)}")
    print(f"Unique classes predicted: {submission_df['class_id'].nunique()}")
    print(f"\nClass distribution (top 10):")
    print(submission_df["class_id"].value_counts().head(10))

    # Validate format
    print("\nValidating submission format...")
    assert list(submission_df.columns) == ["id", "class_id"], "Invalid columns!"
    assert submission_df["class_id"].min() >= 0, "Invalid class_id < 0"
    assert submission_df.isnull().sum().sum() == 0, "Missing values found!"
    print("✓ Submission format is valid!")

    return submission_df, accuracy


In [33]:
linear_head

Linear(in_features=512, out_features=100, bias=True)

In [39]:
len(test_loader1)

141

In [45]:
# reload best head
linear_head.load_state_dict(torch.load("best_linear_head.pt"))

# test_loader_comp should yield (images, filenames) with shuffle=False
submission_df, _ = create_submission(
    test_loader=test_loader1,
    head=linear_head,
    encoder=encoder,                # the frozen encoder you used for features
    output_path="submission.csv",
    device="cuda:1",
)


Using device: cuda:1
Generating predictions...


  with torch.cuda.amp.autocast():
Inference: 100%|██████████████████████████████████████████████████████████████████████████████████████| 141/141 [00:12<00:00, 11.37it/s]

> [32m/tmp/ipykernel_703917/3071104241.py[39m([92m157[39m)[36mcreate_submission[39m[34m()[39m
[32m    155[39m     [38;5;28;01mimport[39;00m pdb
[32m    156[39m     pdb.set_trace()
[32m--> 157[39m     submission_df = pd.DataFrame({
[32m    158[39m         [33m"id"[39m: all_filenames,
[32m    159[39m         [33m"class_id"[39m: all_predictions,






ipdb>  exit


In [46]:
submission_df, _ = create_submission(
    test_loader=,
    head=linear_head,     # no encoder
    encoder=None,
    output_path="submission.csv",
    device="cuda",
)


SyntaxError: expected argument value expression (2858673219.py, line 2)

In [47]:
import torch
import pandas as pd
from tqdm import tqdm

def create_submission_from_feats(
    test_feats,
    head,
    filenames,
    output_path="submission.csv",
    device="cuda",
    batch_size=512,
):
    """
    Create submission.csv using precomputed test_feats and a trained linear head.

    Args:
        test_feats: Tensor of shape (N, D), features from frozen encoder
        head: trained linear head mapping D -> num_classes
        filenames: list/array of length N with image ids (e.g. test_df['filename'])
        output_path: where to save submission.csv
        device: 'cuda' or 'cpu'
        batch_size: inference batch size

    Returns:
        submission_df: DataFrame with columns ['id', 'class_id']
    """
    device = device if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    head = head.to(device)
    head.eval()

    N = test_feats.shape[0]
    assert N == len(filenames), f"Mismatch: {N} feats vs {len(filenames)} filenames"

    all_preds = []

    with torch.no_grad():
        for start in tqdm(range(0, N, batch_size), desc="Inference (features)"):
            end = min(start + batch_size, N)
            batch_feats = test_feats[start:end].to(device, non_blocking=True)

            with torch.cuda.amp.autocast():
                logits = head(batch_feats)

            preds = logits.argmax(dim=1)
            all_preds.append(preds.cpu())

    all_preds = torch.cat(all_preds, dim=0).numpy()
    assert len(all_preds) == N

    submission_df = pd.DataFrame({
        "id": filenames,
        "class_id": all_preds,
    })

    submission_df.to_csv(output_path, index=False)

    print(f"\nSaved submission to {output_path}")
    print(f"Total predictions: {len(submission_df)}")
    print(f"Unique classes predicted: {submission_df['class_id'].nunique()}")
    print("\nClass distribution (top 10):")
    print(submission_df["class_id"].value_counts().head(10))

    # basic sanity checks
    print("\nValidating submission format...")
    assert list(submission_df.columns) == ["id", "class_id"], "Invalid columns!"
    assert submission_df["class_id"].min() >= 0, "Invalid class_id < 0"
    assert submission_df.isnull().sum().sum() == 0, "Missing values found!"
    print("✓ Submission format is valid!")

    return submission_df


In [48]:
device

'cuda:1'

In [49]:
# from earlier
test_feats, test_labels = extract_features(encoder, test_loader1, device)
# and
test_df = pd.read_csv(data_dir / 'test_labels_INTERNAL.csv')  # or test_images.csv for Kaggle


Extracting: 100%|█████████████████████████████████████████████████████████████████████████████████████| 141/141 [00:06<00:00, 21.93it/s]


In [50]:
# Load best head
linear_head.load_state_dict(torch.load("best_linear_head.pt"))

# Filenames in the same order as test_feats / test_loader1
filenames = test_df["filename"].tolist()   # or 'id' for Kaggle test csv

submission_df = create_submission_from_feats(
    test_feats=test_feats,
    head=linear_head,
    filenames=filenames,
    output_path="submission.csv",
    device="cuda",
)


Using device: cuda


  with torch.cuda.amp.autocast():
Inference (features): 100%|█████████████████████████████████████████████████████████████████████████████| 18/18 [00:00<00:00, 76.50it/s]


Saved submission to submission.csv
Total predictions: 9000
Unique classes predicted: 1

Class distribution (top 10):
class_id
0    9000
Name: count, dtype: int64

Validating submission format...
✓ Submission format is valid!





In [130]:
feat_dim = trainval_feats.shape[1]
num_classes = int(trainval_labels.max().item() + 1)

linear_head = nn.Linear(feat_dim, num_classes).to(device)

criterion = torch.nn.CrossEntropyLoss()

# 🔴 IMPORTANT: optimizer must use linear_head.parameters()
optimizer = torch.optim.AdamW(
    linear_head.parameters(),
    lr=1e-3,
    weight_decay=1e-4,
)

scaler = torch.cuda.amp.GradScaler()

  scaler = torch.cuda.amp.GradScaler()


In [131]:
def run_epoch_head(head, loader, train=True):
    if train:
        head.train()
    else:
        head.eval()

    running_loss = 0.0
    correct = 0
    total = 0

    for feats, labels in loader:
        feats = feats.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        if train:
            optimizer.zero_grad(set_to_none=True)

        with torch.cuda.amp.autocast():
            logits = head(feats)
            loss = criterion(logits, labels)

        if train:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

        running_loss += loss.item() * labels.size(0)
        _, preds = torch.max(logits, dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    avg_loss = running_loss / total
    acc = 100.0 * correct / total
    return avg_loss, acc


In [132]:
num_epochs = 50
best_test_acc = 0.0
best_epoch = 0

for epoch in range(1, num_epochs + 1):
    train_loss, train_acc = run_epoch_head(linear_head, trainval_feat_loader, train=True)
    test_loss, test_acc   = run_epoch_head(linear_head, test_feat_loader,   train=False)

    # if you still want scheduler:
    scheduler.step()

    if test_acc > best_test_acc:
        best_test_acc = test_acc
        best_epoch = epoch
        torch.save(linear_head.state_dict(), "best_linear_head.pt")

    print(
        f"Epoch {epoch:02d} | "
        f"train loss: {train_loss:.4f}, acc: {train_acc:.2f}% | "
        f"test loss: {test_loss:.4f}, acc: {test_acc:.2f}%"
    )
    with torch.no_grad():
        w_norm = linear_head.weight.norm().item()
        print("Head weight norm:", w_norm)


print(f"BEST TEST ACC = {best_test_acc:.2f}% at epoch {best_epoch}")


  with torch.cuda.amp.autocast():


Epoch 01 | train loss: 5.3503, acc: 0.99% | test loss: 5.1610, acc: 1.97%
Head weight norm: 8.72060489654541
Epoch 02 | train loss: 5.0719, acc: 2.69% | test loss: 4.9822, acc: 4.48%
Head weight norm: 9.860909461975098
Epoch 03 | train loss: 4.9029, acc: 4.48% | test loss: 4.8600, acc: 5.80%
Head weight norm: 11.409844398498535
Epoch 04 | train loss: 4.7823, acc: 6.33% | test loss: 4.7703, acc: 6.23%
Head weight norm: 13.110548973083496
Epoch 05 | train loss: 4.6873, acc: 7.51% | test loss: 4.6988, acc: 7.11%
Head weight norm: 14.873912811279297
Epoch 06 | train loss: 4.6122, acc: 8.43% | test loss: 4.6392, acc: 7.71%
Head weight norm: 16.646989822387695
Epoch 07 | train loss: 4.5452, acc: 9.22% | test loss: 4.5899, acc: 8.69%
Head weight norm: 18.419111251831055
Epoch 08 | train loss: 4.4865, acc: 10.18% | test loss: 4.5480, acc: 8.86%
Head weight norm: 20.18934440612793
Epoch 09 | train loss: 4.4326, acc: 10.46% | test loss: 4.5092, acc: 8.91%
Head weight norm: 21.93984603881836
Epoc

In [73]:
print("num classes (unique labels):", trainval_labels.unique().numel())
print("min label:", trainval_labels.min().item())
print("max label:", trainval_labels.max().item())


num classes (unique labels): 200
min label: 0
max label: 199
