In [1]:
import os
import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision

from torch.utils.data import Dataset, DataLoader, BatchSampler, random_split
from torchvision import transforms
from PIL import Image

In [2]:
# Mount Google Drive so the dataset stored under MyDrive is reachable from the Colab VM.
from google.colab import drive
drive.mount('/content/drive')
# NOTE(review): shutil is not referenced in the visible cells; the copy below uses the shell.
import shutil

# Copy train_images folder from Drive to local fast disk
# NOTE(review): if /content/train_images already exists (cell re-run), `cp -r` nests the
# copy inside it as /content/train_images/train_images instead of refreshing it.
!cp -r /content/drive/MyDrive/nndl_colab/data/train_images /content/train_images
# data_dir keeps pointing at Drive; only the image folder was copied to local disk.
data_dir = '/content/drive/MyDrive/nndl_colab/data'

Mounted at /content/drive


In [54]:
def _sample_rows_with_distinct(frame, column, rng, count=2):
    """Rejection-sample ``count`` rows of ``frame`` whose ``column`` values all differ.

    The caller must guarantee that ``frame`` contains at least ``count``
    distinct values in ``column``; otherwise the loop cannot terminate.
    """
    picked, seen_values = [], set()
    while len(picked) < count:
        candidate = frame.iloc[rng.randint(0, len(frame) - 1)]
        if candidate[column] not in seen_values:
            picked.append(candidate)
            seen_values.add(candidate[column])
    return picked


def _blend_image_pair(first_path, second_path):
    """Return the pixel-wise average of two image files as an RGB PIL image."""
    # Imported locally so the function does not depend on a later notebook
    # cell having already imported these converters.
    from torchvision.transforms import ToPILImage, ToTensor

    # Context managers close the underlying files; convert() returns a
    # fully loaded copy that stays valid afterwards.
    with Image.open(first_path) as handle:
        first = handle.convert('RGB')
    with Image.open(second_path) as handle:
        second = handle.convert('RGB')
    if second.size != first.size:
        # Tensors of different spatial size cannot be added; align to the first image.
        second = second.resize(first.size)

    to_tensor = ToTensor()
    blended = (to_tensor(first) + to_tensor(second)) / 2
    return ToPILImage()(torch.clamp(blended, 0, 1))


def generate_blended_novel_images(ann_df, img_dir, output_dir, num_samples_each=250,
                                  novel_super_idx=3, novel_sub_idx=87, seed=None):
    """Synthesize "novel" training images by averaging pairs of known images.

    Two kinds of blends are produced, up to ``num_samples_each`` of each:

    * sub-novel: two images of the same superclass but different subclasses.
      Labelled with the true superclass and ``novel_sub_idx``.
    * super-novel: two images from different superclasses. Labelled with
      ``novel_super_idx`` and ``novel_sub_idx``.

    Args:
        ann_df: Annotations with ``image``, ``superclass_index`` and
            ``subclass_index`` columns. It is not modified.
        img_dir: Directory holding the files named in ``ann_df['image']``.
        output_dir: Directory (created if missing) receiving ``blend_*.jpg`` files.
        num_samples_each: Number of attempts per blend type. A sub-novel attempt
            is skipped when the drawn superclass has fewer than two subclasses,
            so fewer sub-novel images may be written.
        novel_super_idx: Label used for the novel superclass.
        novel_sub_idx: Label used for the novel subclass.
        seed: Optional seed for a private random generator; when ``None`` the
            module-level ``random`` state is used.

    Returns:
        DataFrame with columns ``image``, ``superclass_index``,
        ``subclass_index``, ``is_super_novel`` and ``is_sub_novel``, one row
        per written image.

    Raises:
        ValueError: If samples are requested but ``ann_df`` has fewer than two
            distinct superclasses (super-novel sampling could never finish).
    """
    os.makedirs(output_dir, exist_ok=True)

    # astype returns a new frame, so the caller's DataFrame is left untouched.
    ann_df = ann_df.astype({'superclass_index': int, 'subclass_index': int})
    superclasses = ann_df['superclass_index'].unique().tolist()
    if num_samples_each > 0 and len(superclasses) < 2:
        raise ValueError(
            "ann_df must contain at least two distinct superclass_index values "
            "to build super-novel blends"
        )

    rng = random.Random(seed) if seed is not None else random
    novel_ann = []

    for novel_type in ['sub_novel', 'super_novel']:
        for i in range(num_samples_each):
            if novel_type == 'sub_novel':
                super_cls = rng.choice(superclasses)
                subset = ann_df[ann_df['superclass_index'] == super_cls]
                if len(subset['subclass_index'].unique()) < 2:
                    continue
                rows = _sample_rows_with_distinct(subset, 'subclass_index', rng)
                tag, label_super = 'subnovel', super_cls
            else:  # super_novel
                rows = _sample_rows_with_distinct(ann_df, 'superclass_index', rng)
                tag, label_super = 'supernovel', novel_super_idx

            blended = _blend_image_pair(
                os.path.join(img_dir, rows[0]['image']),
                os.path.join(img_dir, rows[1]['image']),
            )

            filename = f"blend_{tag}_{i}.jpg"
            blended.save(os.path.join(output_dir, filename))

            novel_ann.append({
                'image': filename,
                'superclass_index': label_super,
                'subclass_index': novel_sub_idx,
                'is_super_novel': 1 if tag == 'supernovel' else 0,
                'is_sub_novel': 1
            })

    # Explicit columns keep the schema stable even when nothing was generated.
    return pd.DataFrame(
        novel_ann,
        columns=['image', 'superclass_index', 'subclass_index', 'is_super_novel', 'is_sub_novel'],
    )






In [57]:
from torchvision import transforms
from torchvision.transforms import ToTensor, ToPILImage

# Load base data
train_ann_df = pd.read_csv(f'{data_dir}/train_data.csv')
# NOTE(review): this reads the same train_data.csv as train_ann_df, so it is a second
# copy of the training annotations rather than an independent test split.
test_ann_df = pd.read_csv(f'{data_dir}/train_data.csv')  # If no test split
super_map_df = pd.read_csv(f'{data_dir}/superclass_mapping.csv')
sub_map_df = pd.read_csv(f'{data_dir}/subclass_mapping.csv')

# Image paths
# Both point at the local copy of the training images.
train_img_dir = '/content/train_images'
test_img_dir = '/content/train_images'

# Image transform
# Resize to 224x224, convert to a float tensor, then shift/scale each channel
# with mean 0.5 and std 0.5.
image_preprocessing = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * 3, [0.5] * 3)
])


# Write blended "novel" images to /content/novel_blended_images and collect their
# annotations (250 attempts for each of the two blend types).
novel_df = generate_blended_novel_images(
    ann_df=train_ann_df,
    img_dir=train_img_dir,
    output_dir='/content/novel_blended_images',
    num_samples_each=250
)
#

In [58]:
# Every original (non-blended) row is a known sample at both hierarchy levels
train_ann_df['is_super_novel'] = train_ann_df['is_sub_novel'] = 0

# Stack the known rows and the generated novel rows under a fresh 0..N-1 index
combined_ann_df = pd.concat(objs=[train_ann_df, novel_df], ignore_index=True)



In [62]:
# Create Dataset class for multilabel classification
class MultiClassImageDataset(Dataset):
    """Labelled image dataset returning superclass and subclass targets.

    Args:
        ann_df: Annotations with ``image``, ``superclass_index`` and
            ``subclass_index`` columns.
        super_map_df: Mapping frame whose ``class`` column is looked up by
            superclass index.
        sub_map_df: Mapping frame whose ``class`` column is looked up by
            subclass index.
        img_dir: Directory containing the files named in ``ann_df['image']``.
        transform: Optional callable applied to the loaded PIL image.
    """

    def __init__(self, ann_df, super_map_df, sub_map_df, img_dir, transform=None):
        self.ann_df = ann_df
        self.super_map_df = super_map_df
        self.sub_map_df = sub_map_df
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.ann_df)

    def _annotation_at(self, idx):
        """Return ``(img_name, super_idx, super_label, sub_idx, sub_label)`` for row ``idx``.

        Rows are addressed by position (``.iloc``), so the lookup stays correct
        when ``ann_df`` was filtered or shuffled and no longer carries a
        0..N-1 index.
        """
        position = int(idx)
        img_name = self.ann_df['image'].iloc[position]

        super_idx = int(self.ann_df['superclass_index'].iloc[position])
        super_label = self.super_map_df['class'][super_idx]

        sub_idx = int(self.ann_df['subclass_index'].iloc[position])
        sub_label = self.sub_map_df['class'][sub_idx]

        return img_name, super_idx, super_label, sub_idx, sub_label

    def __getitem__(self, idx):
        """Load sample ``idx`` as ``(image, super_idx, super_label, sub_idx, sub_label)``."""
        img_name, super_idx, super_label, sub_idx, sub_label = self._annotation_at(idx)

        # The context manager closes the file; convert() returns a loaded copy.
        with Image.open(os.path.join(self.img_dir, img_name)) as handle:
            image = handle.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, super_idx, super_label, sub_idx, sub_label

class MultiClassImageTestDataset(Dataset):
    """Labelled evaluation dataset backed by an annotation frame.

    Args:
        super_map_df: Mapping frame whose ``class`` column is read by position
            using the superclass index.
        sub_map_df: Mapping frame whose ``class`` column is read by position
            using the subclass index.
        img_dir: Directory containing the image files.
        transform: Optional callable applied to the loaded PIL image.
        ann_df: Annotations with ``image``, ``superclass_index`` and
            ``subclass_index`` columns. ``__getitem__`` requires it; it is a
            trailing optional argument only so that existing four-argument
            constructor calls keep working.
    """

    def __init__(self, super_map_df, sub_map_df, img_dir, transform=None, ann_df=None):
        self.super_map_df = super_map_df
        self.sub_map_df = sub_map_df
        self.img_dir = img_dir
        self.transform = transform
        self.ann_df = ann_df

    def __len__(self):
        # With annotations, the length must match what __getitem__ can serve.
        if self.ann_df is not None:
            return len(self.ann_df)
        # Without annotations, fall back to counting the entries in img_dir.
        return len(os.listdir(self.img_dir))

    def __getitem__(self, idx):
        """Load sample ``idx`` as ``(image, super_idx, super_label, sub_idx, sub_label)``.

        Raises:
            ValueError: If the dataset was built without ``ann_df``.
            FileNotFoundError: If the annotated image file does not exist.
        """
        if self.ann_df is None:
            raise ValueError(
                "MultiClassImageTestDataset needs ann_df to look up images and labels; "
                "pass ann_df=... when constructing it"
            )

        try:
            position = int(idx)
            img_name = self.ann_df['image'].iloc[position]
            img_path = os.path.join(self.img_dir, img_name)

            if not os.path.exists(img_path):
                raise FileNotFoundError(f"❌ Image not found: {img_path}")

            with Image.open(img_path) as handle:
                image = handle.convert('RGB')

            super_idx = int(self.ann_df['superclass_index'].iloc[position])
            super_label = self.super_map_df['class'].iloc[super_idx]

            sub_idx = int(self.ann_df['subclass_index'].iloc[position])
            sub_label = self.sub_map_df['class'].iloc[sub_idx]

            if self.transform:
                image = self.transform(image)

            return image, super_idx, super_label, sub_idx, sub_label

        except Exception as e:
            print(f"❗ Error at idx {idx}: {e}")
            raise  # re-raise so DataLoader stops with the original traceback


In [63]:
import os
from torch.utils.data import Dataset
from PIL import Image

class MultiClassImageDatasetWithNovelty(Dataset):
    """Dataset over known and blended-novel images with novelty flags.

    Args:
        ann_df: Annotations with ``image``, ``superclass_index`` and
            ``subclass_index`` columns, optionally ``is_super_novel`` and
            ``is_sub_novel`` (treated as 0 when absent).
        super_map_df: Superclass mapping frame (stored, not used when loading).
        sub_map_df: Subclass mapping frame (stored, not used when loading).
        transform: Optional callable applied to the loaded PIL image.
        img_dir: Directory of the original images.
        novel_img_dir: Directory of the generated images, i.e. files whose name
            starts with ``blend_``.
    """

    def __init__(self, ann_df, super_map_df, sub_map_df, transform=None,
                 img_dir='/content/train_images',
                 novel_img_dir='/content/novel_blended_images'):
        self.ann_df = ann_df
        self.super_map_df = super_map_df
        self.sub_map_df = sub_map_df
        self.transform = transform
        self.img_dir = img_dir
        self.novel_img_dir = novel_img_dir

    def __len__(self):
        return len(self.ann_df)

    def _image_path(self, filename):
        """Resolve ``filename`` against the novel or the original image directory."""
        base_dir = self.novel_img_dir if filename.startswith("blend_") else self.img_dir
        return os.path.join(base_dir, filename)

    def __getitem__(self, idx):
        """Load row ``idx`` as ``(image, super_idx, sub_idx, is_super_novel, is_sub_novel)``."""
        row = self.ann_df.iloc[idx]

        # The context manager closes the file; convert() returns a loaded copy.
        with Image.open(self._image_path(row['image'])) as handle:
            image = handle.convert('RGB')
        if self.transform:
            image = self.transform(image)

        super_idx = int(row['superclass_index'])
        sub_idx = int(row['subclass_index'])
        is_super_novel = int(row.get('is_super_novel', 0))
        is_sub_novel = int(row.get('is_sub_novel', 0))

        return image, super_idx, sub_idx, is_super_novel, is_sub_novel



In [64]:
from torch.utils.data import random_split, DataLoader

# Wrap the merged annotations (known + super-novel + sub-novel rows) in a dataset
full_dataset = MultiClassImageDatasetWithNovelty(
    combined_ann_df, super_map_df, sub_map_df, transform=image_preprocessing
)

# Seed torch's global RNG so the random split below is repeatable
torch.manual_seed(42)

# 80% train / 10% validation / whatever remains for test
train_size = int(0.8 * len(full_dataset))
val_size = int(0.1 * len(full_dataset))
test_size = len(full_dataset) - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(
    full_dataset, [train_size, val_size, test_size]
)

# Batched loaders for training/validation, single-sample loader for the test split
batch_size = 64

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                          num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                        num_workers=2, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)


In [66]:
import timm
import torch
import torch.nn as nn

class ViTMultiTaskWithNovelty(nn.Module):
    """ViT backbone shared by four linear heads.

    The heads produce superclass logits, subclass logits and two single-logit
    novelty scores (one per hierarchy level).

    Args:
        num_superclasses: Output size of the superclass head.
        num_subclasses: Output size of the subclass head.
        pretrained: Forwarded to ``timm.create_model``.
    """

    def __init__(self, num_superclasses=4, num_subclasses=88, pretrained=True):
        super().__init__()
        self.base = timm.create_model('vit_base_patch32_224', pretrained=pretrained)
        # Drop timm's classifier so the backbone returns features.
        self.base.head = nn.Identity()
        feat_dim = self.base.num_features

        # Main classification heads
        self.super_head = nn.Linear(feat_dim, num_superclasses)
        self.sub_head = nn.Linear(feat_dim, num_subclasses)

        # Novelty detection heads
        self.super_novel_head = nn.Linear(feat_dim, 1)  # detects novel superclass
        self.sub_novel_head = nn.Linear(feat_dim, 1)    # detects novel subclass

        # Freeze the backbone except parameters whose name contains one of the
        # markers below (blocks 10 and 11, plus anything with "norm" in its name).
        # NOTE(review): in timm's ViT the per-block LayerNorms are presumably named
        # blocks.<i>.norm1/norm2, so the "norm" substring likely unfreezes them in
        # every block, not only the final norm. Confirm this is intended.
        for name, param in self.base.named_parameters():
            param.requires_grad = (
                "blocks.10" in name or "blocks.11" in name or "norm" in name
            )

    def forward(self, x):
        """Return ``(super_logits, sub_logits, super_novel_logit, sub_novel_logit)``."""
        feats = self.base(x)
        return (
            self.super_head(feats),         # logits for superclass
            self.sub_head(feats),           # logits for subclass
            self.super_novel_head(feats),   # binary novelty score
            self.sub_novel_head(feats)      # binary novelty score
        )

# Device setup
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Build the multi-task model with its default arguments and move it to that device.
model = ViTMultiTaskWithNovelty().to(device)



In [67]:
# Two parameter groups: the four task heads and the backbone parameters selected by name
head_params = [
    weight
    for head in (model.super_head, model.sub_head,
                 model.super_novel_head, model.sub_novel_head)
    for weight in head.parameters()
]

backbone_params = [
    weight
    for param_name, weight in model.named_parameters()
    if any(marker in param_name for marker in ("blocks.10", "blocks.11", "norm"))
]

# Heads train at 1e-3, the selected backbone parameters at 1e-5
optimizer = torch.optim.AdamW([
    dict(params=head_params, lr=1e-3),
    dict(params=backbone_params, lr=1e-5),
])

# Losses: cross-entropy for both class heads, one BCE-with-logits per novelty head
criterion = nn.CrossEntropyLoss()
super_novel_criterion = nn.BCEWithLogitsLoss()
sub_novel_criterion = nn.BCEWithLogitsLoss()



In [68]:
import time
import torch.nn.functional as F

def _run_novelty_epoch(model, loader, criterion, super_novel_criterion,
                       sub_novel_criterion, device, optimizer=None):
    """Run one pass over ``loader`` and return ``(mean_batch_loss, accuracies)``.

    Weights are updated only when ``optimizer`` is given; otherwise the model
    is put in eval mode and gradients are disabled. ``accuracies`` is a list
    ordered as superclass, subclass, super-novelty, sub-novelty.
    """
    training = optimizer is not None
    model.train(training)

    loss_sum = 0.0
    hits = [0, 0, 0, 0]
    seen = 0

    with torch.set_grad_enabled(training):
        for images, super_idx, sub_idx, is_super_novel, is_sub_novel in loader:
            images = images.to(device)
            super_idx = super_idx.to(device)
            sub_idx = sub_idx.to(device)
            is_super_novel = is_super_novel.float().to(device)
            is_sub_novel = is_sub_novel.float().to(device)

            out_super, out_sub, out_super_novel, out_sub_novel = model(images)
            # reshape(-1) keeps the batch dimension even for a batch of one;
            # a bare squeeze() would give a 0-dim tensor that BCEWithLogitsLoss
            # rejects against the (1,)-shaped target.
            super_novel_logit = out_super_novel.reshape(-1)
            sub_novel_logit = out_sub_novel.reshape(-1)

            loss = (
                criterion(out_super, super_idx) +
                criterion(out_sub, sub_idx) +
                super_novel_criterion(super_novel_logit, is_super_novel) +
                sub_novel_criterion(sub_novel_logit, is_sub_novel)
            )

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            pred_supernovel = (torch.sigmoid(super_novel_logit) > 0.5).long()
            pred_subnovel = (torch.sigmoid(sub_novel_logit) > 0.5).long()

            hits[0] += (out_super.argmax(1) == super_idx).sum().item()
            hits[1] += (out_sub.argmax(1) == sub_idx).sum().item()
            hits[2] += (pred_supernovel == is_super_novel.long()).sum().item()
            hits[3] += (pred_subnovel == is_sub_novel.long()).sum().item()

            seen += images.size(0)
            loss_sum += loss.item()

    # max(..., 1) avoids a ZeroDivisionError when the loader is empty.
    mean_loss = loss_sum / max(len(loader), 1)
    return mean_loss, [count / max(seen, 1) for count in hits]


def train_and_validate(model, train_loader, val_loader, optimizer,
                       criterion, super_novel_criterion, sub_novel_criterion,
                       device, num_epochs=10):
    """Train the four-head model and validate after every epoch.

    Each loader must yield ``(images, super_idx, sub_idx, is_super_novel,
    is_sub_novel)`` batches. The loss is the unweighted sum of two
    cross-entropy terms (``criterion``) and the two novelty losses. One
    training and one validation summary line, plus the epoch's wall time, are
    printed per epoch. Nothing is returned; ``model`` is updated in place and
    is left in eval mode.
    """
    summary = (" Epoch {} {} Summary: Loss = {:.4f} | Super Acc = {:.4f} | "
               "Sub Acc = {:.4f} | SuperNovel Acc = {:.4f} | SubNovel Acc = {:.4f}")

    for epoch in range(num_epochs):
        print(f"\n Epoch {epoch+1}/{num_epochs}")
        start_time = time.time()

        # ---- TRAIN ----
        train_loss, train_acc = _run_novelty_epoch(
            model, train_loader, criterion, super_novel_criterion,
            sub_novel_criterion, device, optimizer=optimizer)
        print(summary.format(epoch + 1, "Train", train_loss, *train_acc))

        # ---- VALIDATION ----
        val_loss, val_acc = _run_novelty_epoch(
            model, val_loader, criterion, super_novel_criterion,
            sub_novel_criterion, device)
        print(summary.format(epoch + 1, "Val", val_loss, *val_acc))
        print(f"Time: {time.time() - start_time:.2f}s")



In [70]:
# Fine-tune the single multi-task model on the train split and report validation metrics.
# Author's note: "15 is best for now" (presumably refers to num_epochs=15 below).
train_and_validate(
    model,
    train_loader,
    val_loader,
    optimizer,
    criterion,               # CrossEntropyLoss (shared)
    super_novel_criterion,   # BCEWithLogitsLoss for super_novel
    sub_novel_criterion,     # BCEWithLogitsLoss for sub_novel
    device,
    num_epochs=15
)




 Epoch 1/15
 Epoch 1 Train Summary: Loss = 0.1929 | Super Acc = 0.9912 | Sub Acc = 0.9773 | SuperNovel Acc = 0.9816 | SubNovel Acc = 0.9976
 Epoch 1 Val Summary: Loss = 0.1874 | Super Acc = 0.9882 | Sub Acc = 0.9705 | SuperNovel Acc = 0.9838 | SubNovel Acc = 0.9971
Time: 15.82s

 Epoch 2/15
 Epoch 2 Train Summary: Loss = 0.1176 | Super Acc = 0.9956 | Sub Acc = 0.9886 | SuperNovel Acc = 0.9843 | SubNovel Acc = 0.9993
 Epoch 2 Val Summary: Loss = 0.1799 | Super Acc = 0.9912 | Sub Acc = 0.9646 | SuperNovel Acc = 0.9823 | SubNovel Acc = 0.9985
Time: 15.68s

 Epoch 3/15
 Epoch 3 Train Summary: Loss = 0.0796 | Super Acc = 0.9993 | Sub Acc = 0.9932 | SuperNovel Acc = 0.9897 | SubNovel Acc = 0.9998
 Epoch 3 Val Summary: Loss = 0.1568 | Super Acc = 0.9882 | Sub Acc = 0.9735 | SuperNovel Acc = 0.9838 | SubNovel Acc = 1.0000
Time: 15.60s

 Epoch 4/15
 Epoch 4 Train Summary: Loss = 0.0527 | Super Acc = 0.9998 | Sub Acc = 0.9972 | SuperNovel Acc = 0.9932 | SubNovel Acc = 1.0000
 Epoch 4 Val Summar

In [71]:
def evaluate(model, dataloader, device):
    """Print the four accuracies of the multi-head model over ``dataloader``.

    Batches must be ``(images, super_idx, sub_idx, is_super_novel,
    is_sub_novel)``. Novelty predictions threshold the sigmoid of the
    corresponding logit at 0.5. The function only prints; it returns nothing.
    """
    model.eval()
    seen = 0
    hits = {'super': 0, 'sub': 0, 'super_novel': 0, 'sub_novel': 0}

    with torch.no_grad():
        for batch in dataloader:
            images, super_idx, sub_idx, is_super_novel, is_sub_novel = (
                tensor.to(device) for tensor in batch
            )

            super_logits, sub_logits, super_novel_logits, sub_novel_logits = model(images)

            super_novel_flag = (torch.sigmoid(super_novel_logits.squeeze()) > 0.5).long()
            sub_novel_flag = (torch.sigmoid(sub_novel_logits.squeeze()) > 0.5).long()

            hits['super'] += (super_logits.argmax(1) == super_idx).sum().item()
            hits['sub'] += (sub_logits.argmax(1) == sub_idx).sum().item()
            hits['super_novel'] += (super_novel_flag == is_super_novel.long()).sum().item()
            hits['sub_novel'] += (sub_novel_flag == is_sub_novel.long()).sum().item()

            seen += images.size(0)

    print(f"   Evaluation Summary:")
    print(f"   Superclass Accuracy      = {hits['super'] / seen:.4f}")
    print(f"   Subclass Accuracy        = {hits['sub'] / seen:.4f}")
    print(f"   Super Novelty Accuracy   = {hits['super_novel'] / seen:.4f}")
    print(f"   Sub Novelty Accuracy     = {hits['sub_novel'] / seen:.4f}")



In [72]:
# Report accuracies of the fine-tuned multi-task model on the held-out split.
evaluate(model, test_loader, device)

   Evaluation Summary:
   Superclass Accuracy      = 0.9897
   Subclass Accuracy        = 0.9632
   Super Novelty Accuracy   = 0.9824
   Sub Novelty Accuracy     = 1.0000


In [73]:
class SuperNet(nn.Module):
    """First stage of the two-stage model: superclass logits plus a novelty logit.

    Args:
        num_superclasses: Output size of the superclass head.
        pretrained: Forwarded to ``timm.create_model``.
    """

    def __init__(self, num_superclasses=4, pretrained=True):
        super().__init__()
        self.base = timm.create_model('vit_base_patch32_224', pretrained=pretrained)
        # Drop timm's classifier so the backbone returns features.
        self.base.head = nn.Identity()
        feat_dim = self.base.num_features

        self.super_head = nn.Linear(feat_dim, num_superclasses)
        self.super_novel_head = nn.Linear(feat_dim, 1)

        # Freeze the backbone except blocks 8-11 and every parameter whose name
        # contains "norm".
        # NOTE(review): the "norm" substring presumably also matches the per-block
        # norm1/norm2 layers of the frozen blocks. Confirm this is intended.
        trainable_markers = tuple(f"blocks.{i}" for i in range(8, 12))
        for name, param in self.base.named_parameters():
            param.requires_grad = (
                any(marker in name for marker in trainable_markers) or "norm" in name
            )

    def forward(self, x):
        """Return ``(superclass_logits, super_novelty_logit)`` for the batch ``x``."""
        feats = self.base(x)
        return self.super_head(feats), self.super_novel_head(feats)


In [77]:
class SubNet(nn.Module):
    """Second stage of the two-stage model: subclass logits conditioned on the superclass.

    Backbone features are concatenated with a superclass vector, projected to
    ``hidden_dim`` and fed to a subclass head and a sub-novelty head.

    Args:
        super_dim: Width of the superclass vector passed to ``forward``.
        num_subclasses: Output size of the subclass head.
        hidden_dim: Width of the projection shared by both heads.
        pretrained: Forwarded to ``timm.create_model``.
    """

    def __init__(self, super_dim=4, num_subclasses=88, hidden_dim=512, pretrained=True):
        super().__init__()
        self.base = timm.create_model('vit_base_patch32_224', pretrained=pretrained)
        # Drop timm's classifier so the backbone returns features.
        self.base.head = nn.Identity()

        # NOTE(review): there is no non-linearity between input_proj and the heads,
        # so the stack is equivalent to a single linear map. Confirm this is intended.
        self.input_proj = nn.Linear(self.base.num_features + super_dim, hidden_dim)
        self.sub_head = nn.Linear(hidden_dim, num_subclasses)
        self.sub_novel_head = nn.Linear(hidden_dim, 1)

        # Freeze the backbone except blocks 8-11 and every parameter whose name
        # contains "norm".
        trainable_markers = tuple(f"blocks.{i}" for i in range(8, 12))
        for name, param in self.base.named_parameters():
            param.requires_grad = (
                any(marker in name for marker in trainable_markers) or "norm" in name
            )

    def forward(self, x, super_onehot):
        """Return ``(subclass_logits, sub_novelty_logit)``.

        ``super_onehot`` is concatenated to the image features along dim 1, so
        it must have one row per image and ``super_dim`` columns.
        """
        feats = self.base(x)
        combined = torch.cat([feats, super_onehot], dim=1)
        hidden = self.input_proj(combined)
        return self.sub_head(hidden), self.sub_novel_head(hidden)


In [74]:
# Losses shared by both stages; train_two_stage reads these module-level names.
criterion = nn.CrossEntropyLoss()
bce = nn.BCEWithLogitsLoss()

def train_two_stage(supernet, subnet, train_loader, optimizer_super, optimizer_sub, device, num_epochs=10):
    """Train the superclass and subclass networks side by side.

    For every batch, ``supernet`` is updated first on superclass cross-entropy
    plus super-novelty BCE. ``subnet`` is then updated on subclass
    cross-entropy plus sub-novelty BCE, receiving the ground-truth superclass
    as a 4-way one-hot vector (teacher forcing). Running accuracies are
    printed once per epoch; nothing is returned.

    Batches must be ``(images, super_idx, sub_idx, is_super_novel, is_sub_novel)``.
    """
    supernet.train()
    subnet.train()

    for epoch in range(num_epochs):
        print(f"\n🧪 Epoch {epoch+1}/{num_epochs}")
        total, correct_super, correct_sub = 0, 0, 0
        correct_supernovel, correct_subnovel = 0, 0

        for images, super_idx, sub_idx, is_super_novel, is_sub_novel in train_loader:
            images = images.to(device)
            super_idx = super_idx.to(device)
            sub_idx = sub_idx.to(device)
            is_super_novel = is_super_novel.float().to(device)
            is_sub_novel = is_sub_novel.float().to(device)

            # --- SuperNet forward ---
            super_logits, super_novel = supernet(images)
            # reshape(-1) keeps the batch dimension for a batch of one; a bare
            # squeeze() would give a 0-dim tensor that BCEWithLogitsLoss rejects.
            super_novel_logit = super_novel.reshape(-1)
            super_loss = (
                criterion(super_logits, super_idx) +
                bce(super_novel_logit, is_super_novel)
            )
            optimizer_super.zero_grad()
            super_loss.backward()
            optimizer_super.step()

            # --- SubNet forward (teacher forcing with GT super_idx) ---
            super_onehot = F.one_hot(super_idx, num_classes=4).float().to(device)
            sub_logits, sub_novel = subnet(images, super_onehot)
            sub_novel_logit = sub_novel.reshape(-1)
            sub_loss = (
                criterion(sub_logits, sub_idx) +
                bce(sub_novel_logit, is_sub_novel)
            )
            optimizer_sub.zero_grad()
            sub_loss.backward()
            optimizer_sub.step()

            # --- Metrics ---
            pred_super = super_logits.argmax(1)
            pred_sub = sub_logits.argmax(1)
            pred_supernovel = (torch.sigmoid(super_novel_logit) > 0.5).long()
            pred_subnovel = (torch.sigmoid(sub_novel_logit) > 0.5).long()

            correct_super += (pred_super == super_idx).sum().item()
            correct_sub += (pred_sub == sub_idx).sum().item()
            correct_supernovel += (pred_supernovel == is_super_novel.long()).sum().item()
            correct_subnovel += (pred_subnovel == is_sub_novel.long()).sum().item()
            total += images.size(0)

        # max(..., 1) avoids a ZeroDivisionError when the loader is empty.
        denom = max(total, 1)
        print(f" Super Acc = {correct_super/denom:.4f} | Sub Acc = {correct_sub/denom:.4f}")
        print(f" SuperNovel Acc = {correct_supernovel/denom:.4f} | SubNovel Acc = {correct_subnovel/denom:.4f}")


In [75]:
def evaluate_two_stage(supernet, subnet, dataloader, device):
    """Print the accuracies of the supernet -> subnet pipeline over ``dataloader``.

    The subnet is conditioned on the superclass predicted by the supernet
    (4-way one-hot), not on the ground truth. Novelty predictions threshold
    the sigmoid of the corresponding logit at 0.5. Returns nothing.
    """
    for network in (supernet, subnet):
        network.eval()

    seen = 0
    hits = {'super': 0, 'sub': 0, 'super_novel': 0, 'sub_novel': 0}

    with torch.no_grad():
        for batch in dataloader:
            images, super_idx, sub_idx, is_super_novel, is_sub_novel = (
                tensor.to(device) for tensor in batch
            )

            super_logits, super_novel = supernet(images)
            pred_super = super_logits.argmax(dim=1)

            # Condition the subnet on the predicted superclass
            condition = F.one_hot(pred_super, num_classes=4).float().to(device)
            sub_logits, sub_novel = subnet(images, condition)
            pred_sub = sub_logits.argmax(dim=1)

            super_novel_flag = (torch.sigmoid(super_novel.squeeze()) > 0.5).long()
            sub_novel_flag = (torch.sigmoid(sub_novel.squeeze()) > 0.5).long()

            hits['super'] += (pred_super == super_idx).sum().item()
            hits['sub'] += (pred_sub == sub_idx).sum().item()
            hits['super_novel'] += (super_novel_flag == is_super_novel.long()).sum().item()
            hits['sub_novel'] += (sub_novel_flag == is_sub_novel.long()).sum().item()
            seen += images.size(0)

    print(f" EVAL: Super Acc = {hits['super']/seen:.4f} | Sub Acc = {hits['sub']/seen:.4f}")
    print(f"          SuperNovel Acc = {hits['super_novel']/seen:.4f} | SubNovel Acc = {hits['sub_novel']/seen:.4f}")


In [None]:
# Build both stages of the two-stage model on the selected device.
supernet = SuperNet().to(device)
subnet = SubNet().to(device)

# One AdamW optimizer per network, both with a single learning rate of 1e-4.
# All parameters are passed in; those with requires_grad=False get no gradient.
optimizer_super = torch.optim.AdamW(supernet.parameters(), lr=1e-4)
optimizer_sub = torch.optim.AdamW(subnet.parameters(), lr=1e-4)

# Train both networks for 5 epochs on the shared training loader.
train_two_stage(supernet, subnet, train_loader, optimizer_super, optimizer_sub, device, num_epochs=5)



In [None]:
# Report accuracies of the two-stage pipeline on the held-out split.
evaluate_two_stage(supernet, subnet, test_loader, device)

In [80]:
#!cp -r /content/drive/MyDrive/nndl_colab/data/test_images /content/test_images

In [81]:
def test_model_1stage(model, test_loader, device, save_to_csv=True, save_path='test_predictions.csv'):
    model.eval()
    test_predictions = {'image': [], 'superclass_index': [], 'subclass_index': []}

    novel_super_idx = 3
    novel_sub_idx = 87

    with torch.no_grad():
        for i, (images, img_names) in enumerate(test_loader):
            images = images.to(device)

            super_logits, sub_logits, super_novel_logits, sub_novel_logits = model(images)

            super_pred = torch.argmax(super_logits, dim=1).item()
            sub_pred = torch.argmax(sub_logits, dim=1).item()

            super_novel_score = torch.sigmoid(super_novel_logits.squeeze()).item()
            sub_novel_score = torch.sigmoid(sub_novel_logits.squeeze()).item()

            if super_novel_score > 0.5:
                final_super = novel_super_idx
                final_sub = novel_sub_idx
            elif sub_novel_score > 0.5:
                final_super = super_pred
                final_sub = novel_sub_idx
            else:
                final_super = super_pred
                final_sub = sub_pred

            test_predictions['image'].append(img_names[0])
            test_predictions['superclass_index'].append(final_super)
            test_predictions['subclass_index'].append(final_sub)

    test_predictions = pd.DataFrame(test_predictions)

    if save_to_csv:
        test_predictions.to_csv(save_path, index=False)
        print(f" Test predictions saved to {save_path}")

    return test_predictions



In [82]:
class TestImageDataset(Dataset):
    """Unlabelled image dataset over the regular files of a directory.

    Files are served in sorted-name order. Sub-directories (for example a
    notebook checkpoint folder) are skipped because they cannot be opened as
    images.

    Args:
        img_dir: Directory containing the image files.
        transform: Optional callable applied to the loaded PIL image.
    """

    def __init__(self, img_dir, transform=None):
        self.img_dir = img_dir
        self.image_files = sorted(
            entry for entry in os.listdir(img_dir)
            if os.path.isfile(os.path.join(img_dir, entry))
        )
        self.transform = transform

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(image, img_name)`` for the ``idx``-th file."""
        img_name = self.image_files[idx]
        # The context manager closes the file; convert() returns a loaded copy.
        with Image.open(os.path.join(self.img_dir, img_name)) as handle:
            image = handle.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, img_name


In [83]:
# Unlabelled test images, preprocessed with the same transform as the training data.
# NOTE(review): /content/test_images must already exist; the cell that would copy it
# from Drive is commented out above.
open_set_test_dataset = TestImageDataset(img_dir='/content/test_images', transform=image_preprocessing)

# One image per step, in the dataset's own order (no shuffling).
open_set_test_loader = DataLoader(
    open_set_test_dataset,
    batch_size=1,
    shuffle=False
)

In [None]:
# Predict the open-set test images with the single multi-task model and write them to CSV.
test_predictions = test_model_1stage(model, open_set_test_loader, device, save_to_csv=True, save_path='test_predictions_one_stage_ViT.csv')

In [85]:
def test_model_2stage(supernet, subnet, test_loader, device, save_to_csv=True, save_path='test_predictions.csv'):
    supernet.eval()
    subnet.eval()
    test_predictions = {'image': [], 'superclass_index': [], 'subclass_index': []}

    novel_super_idx = 3
    novel_sub_idx = 87

    with torch.no_grad():
        for i, (images, img_names) in enumerate(test_loader):
            images = images.to(device)

            super_logits, super_novel_logits = supernet(images)
            super_pred = torch.argmax(super_logits, dim=1)

            super_novel_score = torch.sigmoid(super_novel_logits.squeeze()).item()

            if super_novel_score > 0.5:
                final_super = novel_super_idx
                final_sub = novel_sub_idx
            else:
                super_onehot = torch.nn.functional.one_hot(super_pred, num_classes=4).float()
                sub_logits, sub_novel_logits = subnet(images, super_onehot.to(device))
                sub_pred = torch.argmax(sub_logits, dim=1)

                sub_novel_score = torch.sigmoid(sub_novel_logits.squeeze()).item()

                if sub_novel_score > 0.5:
                    final_super = super_pred.item()
                    final_sub = novel_sub_idx
                else:
                    final_super = super_pred.item()
                    final_sub = sub_pred.item()

            test_predictions['image'].append(img_names[0])
            test_predictions['superclass_index'].append(final_super)
            test_predictions['subclass_index'].append(final_sub)

    test_predictions = pd.DataFrame(test_predictions)

    if save_to_csv:
        test_predictions.to_csv(save_path, index=False)
        print(f" Test predictions saved to {save_path}")

    return test_predictions


In [None]:
# Predict the open-set test images with the two-stage pipeline and write them to CSV.
test_predictions = test_model_2stage(
    supernet,
    subnet,
    open_set_test_loader,
    device,
    save_to_csv=True,
    save_path='test_predictions_2stage.csv'
)


In [87]:
import torch.nn.functional as F

def test_model_1stage_with_softmax_threshold(model, test_loader, device, threshold=0.7, save_to_csv=True, save_path='test_predictions.csv'):
    model.eval()
    test_predictions = {'image': [], 'superclass_index': [], 'subclass_index': []}

    novel_super_idx = 3
    novel_sub_idx = 87

    with torch.no_grad():
        for i, (images, img_names) in enumerate(test_loader):
            images = images.to(device)

            super_logits, sub_logits, super_novel_logits, sub_novel_logits = model(images)

            super_probs = F.softmax(super_logits, dim=1)
            sub_probs = F.softmax(sub_logits, dim=1)

            super_max_prob, super_pred = torch.max(super_probs, dim=1)
            sub_max_prob, sub_pred = torch.max(sub_probs, dim=1)

            super_novel_score = torch.sigmoid(super_novel_logits.squeeze()).item()
            sub_novel_score = torch.sigmoid(sub_novel_logits.squeeze()).item()

            if super_max_prob.item() < threshold or super_novel_score > 0.5:
                final_super = novel_super_idx
                final_sub = novel_sub_idx
            elif sub_max_prob.item() < threshold or sub_novel_score > 0.5:
                final_super = super_pred.item()
                final_sub = novel_sub_idx
            else:
                final_super = super_pred.item()
                final_sub = sub_pred.item()

            test_predictions['image'].append(img_names[0])
            test_predictions['superclass_index'].append(final_super)
            test_predictions['subclass_index'].append(final_sub)

    test_predictions = pd.DataFrame(test_predictions)

    if save_to_csv:
        test_predictions.to_csv(save_path, index=False)
        print(f" Test predictions saved to {save_path}")

    return test_predictions


In [None]:
# Predict the open-set test images with the single model, additionally flagging
# predictions whose top softmax probability is below 0.8 as novel.
test_predictions = test_model_1stage_with_softmax_threshold(
    model,
    open_set_test_loader,
    device,
    threshold=0.8,
    save_to_csv=True,
    save_path='1_stage_thresholded.csv'
)


In [91]:
import torch.nn.functional as F

def test_model_1stage_multi_threshold(model, test_loader, device, super_threshold=0.7, sub_threshold=0.6, save_to_csv=True, save_path='test_predictions.csv'):
    model.eval()
    test_predictions = {'image': [], 'superclass_index': [], 'subclass_index': []}

    novel_super_idx = 3
    novel_sub_idx = 87

    with torch.no_grad():
        for i, (images, img_names) in enumerate(test_loader):
            images = images.to(device)

            super_logits, sub_logits, super_novel_logits, sub_novel_logits = model(images)

            super_probs = F.softmax(super_logits, dim=1)
            sub_probs = F.softmax(sub_logits, dim=1)

            super_max_prob, super_pred = torch.max(super_probs, dim=1)
            sub_max_prob, sub_pred = torch.max(sub_probs, dim=1)

            super_novel_score = torch.sigmoid(super_novel_logits.squeeze()).item()
            sub_novel_score = torch.sigmoid(sub_novel_logits.squeeze()).item()

            if super_max_prob.item() < super_threshold or super_novel_score > 0.5:
                final_super = novel_super_idx
                final_sub = novel_sub_idx
            elif sub_max_prob.item() < sub_threshold or sub_novel_score > 0.5:
                final_super = super_pred.item()
                final_sub = novel_sub_idx
            else:
                final_super = super_pred.item()
                final_sub = sub_pred.item()

            test_predictions['image'].append(img_names[0])
            test_predictions['superclass_index'].append(final_super)
            test_predictions['subclass_index'].append(final_sub)

    test_predictions = pd.DataFrame(test_predictions)

    if save_to_csv:
        test_predictions.to_csv(save_path, index=False)
        print(f" Test predictions saved to {save_path}")

    return test_predictions


In [None]:
# Open-set inference with separate cut-offs: superclass confidence must reach 0.9,
# subclass confidence 0.75; predictions go to 'test_predictions_multi_threshold.csv'.
# NOTE(review): `model`, `device` and `open_set_test_loader` are not defined in
# this part of the notebook — they must already exist in the kernel.
test_predictions = test_model_1stage_multi_threshold(
    model,
    open_set_test_loader,
    device,
    super_threshold=0.9,
    sub_threshold=0.75,
    save_to_csv=True,
    save_path='test_predictions_multi_threshold.csv'
)


In [95]:
from torchvision import transforms
from torchvision.transforms import ToTensor, ToPILImage
from torch.utils.data import random_split, DataLoader

# Load base data: the annotation table plus the superclass/subclass mapping tables.
train_ann_df = pd.read_csv(f'{data_dir}/train_data.csv')
# NOTE(review): this re-reads train_data.csv, and test_img_dir below points at the
# training images; test_ann_df, train_img_dir and test_img_dir are not used
# anywhere else in this cell.
test_ann_df = pd.read_csv(f'{data_dir}/train_data.csv')  # (if no separate test set)
super_map_df = pd.read_csv(f'{data_dir}/superclass_mapping.csv')
sub_map_df = pd.read_csv(f'{data_dir}/subclass_mapping.csv')

train_img_dir = '/content/train_images'
test_img_dir = '/content/train_images'

# Image transform: resize to 224x224, convert to a tensor, then apply
# (x - 0.5) / 0.5 to each of the three channels.
image_preprocessing = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * 3, [0.5] * 3)
])

# Annotate original known data: every row is flagged as not novel (both flags 0).
# These columns are added to train_ann_df in place.
train_ann_df['is_super_novel'] = 0
train_ann_df['is_sub_novel'] = 0

# Only known data
# NOTE(review): MultiClassImageDatasetWithNovelty is defined outside this section;
# no image directory is passed here, so it presumably resolves paths on its own — confirm.
full_dataset = MultiClassImageDatasetWithNovelty(
    ann_df=train_ann_df,
    super_map_df=super_map_df,
    sub_map_df=sub_map_df,
    transform=image_preprocessing
)

# Seed the global torch RNG immediately before the split so the split below is
# drawn from a fixed seed.
torch.manual_seed(42)

# 80% train / 10% validation / remainder test (the remainder absorbs rounding).
train_size = int(0.8 * len(full_dataset))
val_size = int(0.1 * len(full_dataset))
test_size = len(full_dataset) - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(
    full_dataset, [train_size, val_size, test_size]
)

batch_size = 64

# Training batches are reshuffled; validation and test keep dataset order.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
    pin_memory=True
)

# The held-out split is served one sample at a time.
test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False
)


In [97]:
def _run_novelty_epoch(model, loader, criterion, super_novel_criterion, sub_novel_criterion,
                       device, with_novelty_loss, optimizer=None):
    """Run one pass over ``loader`` for the four-headed (class + novelty) model.

    When ``optimizer`` is given the model is put in train mode and updated after
    every batch; otherwise the model is put in eval mode and gradients are
    disabled.  Returns ``(mean_batch_loss, accuracies)`` where ``accuracies``
    maps ``'super'``, ``'sub'``, ``'super_novel'`` and ``'sub_novel'`` to the
    fraction of correctly predicted samples.
    """
    is_training = optimizer is not None
    model.train(is_training)

    loss_sum = 0.0
    n_seen = 0
    n_correct = {'super': 0, 'sub': 0, 'super_novel': 0, 'sub_novel': 0}

    with torch.set_grad_enabled(is_training):
        for images, super_idx, sub_idx, is_super_novel, is_sub_novel in loader:
            images = images.to(device)
            super_idx = super_idx.to(device)
            sub_idx = sub_idx.to(device)
            is_super_novel = is_super_novel.float().to(device).reshape(-1)
            is_sub_novel = is_sub_novel.float().to(device).reshape(-1)

            out_super, out_sub, out_super_novel, out_sub_novel = model(images)

            # reshape(-1) instead of squeeze(): squeeze() turns a (1, 1) output into a
            # 0-d tensor, so a final batch of one sample no longer matches its (1,)
            # target and the BCE criterion raises.
            super_novel_logit = out_super_novel.reshape(-1)
            sub_novel_logit = out_sub_novel.reshape(-1)

            loss = criterion(out_super, super_idx) + criterion(out_sub, sub_idx)
            if with_novelty_loss:
                loss = (
                    loss +
                    super_novel_criterion(super_novel_logit, is_super_novel) +
                    sub_novel_criterion(sub_novel_logit, is_sub_novel)
                )

            if is_training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            pred_supernovel = (torch.sigmoid(super_novel_logit) > 0.5).long()
            pred_subnovel = (torch.sigmoid(sub_novel_logit) > 0.5).long()

            n_correct['super'] += (out_super.argmax(1) == super_idx).sum().item()
            n_correct['sub'] += (out_sub.argmax(1) == sub_idx).sum().item()
            n_correct['super_novel'] += (pred_supernovel == is_super_novel.long()).sum().item()
            n_correct['sub_novel'] += (pred_subnovel == is_sub_novel.long()).sum().item()

            n_seen += images.size(0)
            loss_sum += loss.item()

    mean_loss = loss_sum / len(loader)
    accuracies = {head: hits / n_seen for head, hits in n_correct.items()}
    return mean_loss, accuracies


def train_and_validate_pure(model, train_loader, val_loader, optimizer,
                       criterion, super_novel_criterion, sub_novel_criterion,
                       device, num_epochs=10, warmup_epochs=5):
    """Train the four-headed novelty model and report train/val metrics per epoch.

    During the first ``warmup_epochs`` epochs only the two classification losses
    are optimized; afterwards the two novelty losses are added.  Novelty
    accuracies are printed in every epoch, including warmup epochs in which the
    novelty heads receive no training signal.

    NOTE: a later cell in this notebook defines another function with the same
    name but a different signature (no novelty criteria); whichever cell ran
    last owns the name.

    Args:
        model: Module returning ``(super_logits, sub_logits, super_novel_logits,
            sub_novel_logits)``.
        train_loader, val_loader: Iterables of ``(images, super_idx, sub_idx,
            is_super_novel, is_sub_novel)`` batches.
        optimizer: Optimizer stepped once per training batch.
        criterion: Loss applied to both classification heads.
        super_novel_criterion, sub_novel_criterion: Losses applied to the novelty
            logits against the float novelty flags.
        device: Device the batches are moved to.
        num_epochs: Number of epochs to run.
        warmup_epochs: Number of initial epochs without the novelty losses.

    Returns:
        None.  Progress is printed.
    """
    for epoch in range(num_epochs):
        print(f"\n Epoch {epoch+1}/{num_epochs}")
        start_time = time.time()

        with_novelty_loss = epoch >= warmup_epochs
        phases = (("Train", train_loader, optimizer), ("Val", val_loader, None))

        for phase, loader, phase_optimizer in phases:
            mean_loss, acc = _run_novelty_epoch(
                model, loader, criterion, super_novel_criterion, sub_novel_criterion,
                device, with_novelty_loss, optimizer=phase_optimizer
            )
            print(f" Epoch {epoch+1} {phase} Summary: Loss = {mean_loss:.4f} | "
                  f"Super Acc = {acc['super']:.4f} | "
                  f"Sub Acc = {acc['sub']:.4f} | "
                  f"SuperNovel Acc = {acc['super_novel']:.4f} | "
                  f"SubNovel Acc = {acc['sub_novel']:.4f}")

        print(f"Time: {time.time() - start_time:.2f}s")


In [99]:
# 1. Rebuild Model
# NOTE(review): ViTMultiTaskWithNovelty, `device`, `criterion`, `super_novel_criterion`
# and `sub_novel_criterion` are defined outside this section of the notebook.
model = ViTMultiTaskWithNovelty().to(device)

# 2. Rebuild Optimizer
# Two parameter groups: the four task heads (lr=1e-3) and a slice of the backbone
# (lr=1e-5).  The backbone slice is chosen by substring match on parameter names:
# blocks 8-11 plus every parameter whose name contains "norm".
head_params = (
    list(model.super_head.parameters()) +
    list(model.sub_head.parameters()) +
    list(model.super_novel_head.parameters()) +
    list(model.sub_novel_head.parameters())
)
backbone_params = [
    p for n, p in model.named_parameters()
    if ("blocks.8" in n or "blocks.9" in n or "blocks.10" in n or "blocks.11" in n or "norm" in n)
]
optimizer = torch.optim.AdamW([
    {'params': head_params, 'lr': 1e-3},
    {'params': backbone_params, 'lr': 1e-5}
])
# 3. Train.  num_epochs (3) is smaller than warmup_epochs (5), so every epoch stays
# in the warmup branch and only the two classification losses are optimized; the
# novelty heads get no training signal in this run.
train_and_validate_pure(
    model,
    train_loader,
    val_loader,
    optimizer,
    criterion,
    super_novel_criterion,
    sub_novel_criterion,
    device,
    num_epochs=3,
    warmup_epochs=5
)


 Epoch 1/3
 Epoch 1 Train Summary: Loss = 0.9985 | Super Acc = 0.9738 | Sub Acc = 0.7845 | SuperNovel Acc = 0.5851 | SubNovel Acc = 0.7636
 Epoch 1 Val Summary: Loss = 0.2322 | Super Acc = 1.0000 | Sub Acc = 0.9252 | SuperNovel Acc = 0.5685 | SubNovel Acc = 0.7532
Time: 14.57s

 Epoch 2/3
 Epoch 2 Train Summary: Loss = 0.1099 | Super Acc = 1.0000 | Sub Acc = 0.9744 | SuperNovel Acc = 0.5831 | SubNovel Acc = 0.7312
 Epoch 2 Val Summary: Loss = 0.1636 | Super Acc = 1.0000 | Sub Acc = 0.9427 | SuperNovel Acc = 0.5653 | SubNovel Acc = 0.7373
Time: 14.96s

 Epoch 3/3
 Epoch 3 Train Summary: Loss = 0.0594 | Super Acc = 1.0000 | Sub Acc = 0.9881 | SuperNovel Acc = 0.5819 | SubNovel Acc = 0.7149
 Epoch 3 Val Summary: Loss = 0.1382 | Super Acc = 1.0000 | Sub Acc = 0.9506 | SuperNovel Acc = 0.5669 | SubNovel Acc = 0.7150
Time: 14.81s


In [100]:
# Score the freshly trained model on the held-out split of known data (test_loader).
# NOTE(review): `evaluate` is defined outside this section of the notebook.
evaluate(model, test_loader, device)

   Evaluation Summary:
   Superclass Accuracy      = 1.0000
   Subclass Accuracy        = 0.9603
   Super Novelty Accuracy   = 0.5714
   Sub Novelty Accuracy     = 0.6762


In [None]:
# Open-set inference with the retrained one-stage model: superclass cut-off 0.9,
# subclass cut-off 0.85.  This writes to the same
# 'test_predictions_multi_threshold.csv' path as the earlier multi-threshold run,
# so that file is overwritten.
# NOTE(review): `open_set_test_loader` is not defined in this part of the notebook.
test_predictions = test_model_1stage_multi_threshold(
    model,
    open_set_test_loader,
    device,
    super_threshold=0.9,
    sub_threshold=0.85,
    save_to_csv=True,
    save_path='test_predictions_multi_threshold.csv'
)


In [103]:
def train_two_stage_pure(supernet, subnet, train_loader, optimizer_super, optimizer_sub, device,
                    num_epochs=10, warmup_epochs=5, class_criterion=None, novelty_criterion=None):
    """Jointly train the superclass network and the superclass-conditioned subclass network.

    For every batch the superclass network is updated first; the subclass
    network is then updated using the ground-truth superclass (one-hot with 4
    classes) as conditioning input.  During the first ``warmup_epochs`` epochs
    only the classification losses are used; afterwards a novelty loss on each
    network's novelty logit is added.

    Args:
        supernet: Module returning ``(super_logits, super_novel_logits)``.
        subnet: Module called as ``subnet(images, super_onehot)`` and returning
            ``(sub_logits, sub_novel_logits)``.
        train_loader: Iterable of ``(images, super_idx, sub_idx, is_super_novel,
            is_sub_novel)`` batches.
        optimizer_super, optimizer_sub: Optimizers for the two networks.
        device: Device the batches are moved to.
        num_epochs: Number of epochs to run.
        warmup_epochs: Number of initial epochs without the novelty loss.
        class_criterion: Classification loss.  When ``None`` the notebook-level
            global ``criterion`` is used (the original behavior).
        novelty_criterion: Loss for the novelty logits.  When ``None`` the
            notebook-level global ``bce`` is used; it is only looked up once the
            warmup phase is over.

    Returns:
        None.  Per-epoch accuracies are printed.
    """
    for epoch in range(num_epochs):
        print(f"\n Epoch {epoch+1}/{num_epochs}")
        supernet.train()
        subnet.train()

        use_novelty_loss = epoch >= warmup_epochs
        # Fall back to the notebook globals so existing calls behave as before.
        class_loss_fn = criterion if class_criterion is None else class_criterion
        novelty_loss_fn = None
        if use_novelty_loss:
            novelty_loss_fn = bce if novelty_criterion is None else novelty_criterion

        total, correct_super, correct_sub = 0, 0, 0
        correct_supernovel, correct_subnovel = 0, 0

        for images, super_idx, sub_idx, is_super_novel, is_sub_novel in train_loader:
            images = images.to(device)
            super_idx = super_idx.to(device)
            sub_idx = sub_idx.to(device)
            is_super_novel = is_super_novel.float().to(device).reshape(-1)
            is_sub_novel = is_sub_novel.float().to(device).reshape(-1)

            # --- SuperNet forward ---
            super_logits, super_novel = supernet(images)
            # reshape(-1) instead of squeeze(): squeeze() turns a (1, 1) output into a
            # 0-d tensor, which breaks the loss on a final batch of one sample.
            super_novel = super_novel.reshape(-1)

            super_loss = class_loss_fn(super_logits, super_idx)
            if use_novelty_loss:
                super_loss = super_loss + novelty_loss_fn(super_novel, is_super_novel)

            optimizer_super.zero_grad()
            super_loss.backward()
            optimizer_super.step()

            # --- SubNet forward (teacher forcing with GT super_idx) ---
            super_onehot = F.one_hot(super_idx, num_classes=4).float().to(device)
            sub_logits, sub_novel = subnet(images, super_onehot)
            sub_novel = sub_novel.reshape(-1)

            sub_loss = class_loss_fn(sub_logits, sub_idx)
            if use_novelty_loss:
                sub_loss = sub_loss + novelty_loss_fn(sub_novel, is_sub_novel)

            optimizer_sub.zero_grad()
            sub_loss.backward()
            optimizer_sub.step()

            # --- Metrics (computed from the pre-update forward passes) ---
            pred_super = super_logits.argmax(1)
            pred_sub = sub_logits.argmax(1)
            pred_supernovel = (torch.sigmoid(super_novel) > 0.5).long()
            pred_subnovel = (torch.sigmoid(sub_novel) > 0.5).long()

            correct_super += (pred_super == super_idx).sum().item()
            correct_sub += (pred_sub == sub_idx).sum().item()
            correct_supernovel += (pred_supernovel == is_super_novel.long()).sum().item()
            correct_subnovel += (pred_subnovel == is_sub_novel.long()).sum().item()
            total += images.size(0)

        print(f" Super Acc = {correct_super/total:.4f} | Sub Acc = {correct_sub/total:.4f}")
        print(f" SuperNovel Acc = {correct_supernovel/total:.4f} | SubNovel Acc = {correct_subnovel/total:.4f}")


In [None]:
# Build the two-stage pair and give each network its own AdamW optimizer (lr=1e-4).
# NOTE(review): SuperNet and SubNet are defined outside this section of the notebook.
supernet = SuperNet().to(device)
subnet = SubNet().to(device)

optimizer_super = torch.optim.AdamW(supernet.parameters(), lr=1e-4)
optimizer_sub = torch.optim.AdamW(subnet.parameters(), lr=1e-4)

# num_epochs equals warmup_epochs (5), so `epoch < warmup_epochs` holds for every
# epoch and the novelty loss terms are never added in this run.
train_two_stage_pure(
    supernet,
    subnet,
    train_loader,
    optimizer_super,
    optimizer_sub,
    device,
    num_epochs=5,
    warmup_epochs=5
)



In [None]:
# Score the two-stage pair on the held-out split of known data (test_loader).
# NOTE(review): `evaluate_two_stage` is defined outside this section of the notebook.
evaluate_two_stage(supernet, subnet, test_loader, device)

In [109]:
import torch.nn.functional as F

def test_model_2stage_multi_threshold(supernet, subnet, test_loader, device,
                                       super_threshold=0.75, sub_threshold=0.65,
                                       save_to_csv=True, save_path='test_predictions.csv'):
    supernet.eval()
    subnet.eval()

    test_predictions = {'image': [], 'superclass_index': [], 'subclass_index': []}

    novel_super_idx = 3
    novel_sub_idx = 87
    num_super_classes = 4

    with torch.no_grad():
        for i, (images, img_names) in enumerate(test_loader):
            images = images.to(device)

            # ----- Stage 1: Superclass prediction -----
            super_logits, super_novel_logits = supernet(images)

            super_probs = F.softmax(super_logits, dim=1)
            super_max_prob, super_pred = torch.max(super_probs, dim=1)

            super_novel_score = torch.sigmoid(super_novel_logits.squeeze()).item()

            if super_max_prob.item() < super_threshold or super_novel_score > 0.5:
                final_super = novel_super_idx
                final_sub = novel_sub_idx
            else:
                final_super = super_pred.item()

                # ----- Stage 2: Subclass prediction -----
                super_onehot = F.one_hot(torch.tensor([final_super], device=device), num_classes=num_super_classes).float()
                sub_logits, sub_novel_logits = subnet(images, super_onehot)


                sub_probs = F.softmax(sub_logits, dim=1)
                sub_max_prob, sub_pred = torch.max(sub_probs, dim=1)

                sub_novel_score = torch.sigmoid(sub_novel_logits.squeeze()).item()

                if sub_max_prob.item() < sub_threshold or sub_novel_score > 0.5:
                    final_sub = novel_sub_idx
                else:
                    final_sub = sub_pred.item()

            test_predictions['image'].append(img_names[0])
            test_predictions['superclass_index'].append(final_super)
            test_predictions['subclass_index'].append(final_sub)

    test_predictions = pd.DataFrame(test_predictions)

    if save_to_csv:
        test_predictions.to_csv(save_path, index=False)
        print(f" 2-stage multi-threshold test predictions saved to {save_path}")

    return test_predictions


In [None]:
# Open-set inference with the two-stage pair: superclass cut-off 0.75, subclass
# cut-off 0.65; predictions go to 'test_predictions_2stage_multi_threshold.csv'.
# NOTE(review): `open_set_test_loader` is not defined in this part of the notebook.
test_predictions = test_model_2stage_multi_threshold(
    supernet,
    subnet,
    open_set_test_loader,
    device,
    super_threshold=0.75,
    sub_threshold=0.65,
    save_to_csv=True,
    save_path='test_predictions_2stage_multi_threshold.csv'
)


In [139]:
import timm
import torch.nn as nn

import timm
import torch.nn as nn

class TinyViTMultiTaskNoNoveltyWithDropout(nn.Module):
    """ViT-tiny backbone with dropout and two linear heads (superclass, subclass).

    The backbone's own classification head is replaced by an identity so the
    backbone output feeds both heads.  Backbone parameters are frozen except
    those whose name contains one of ``trainable_keywords``.

    Args:
        dropout_rate: Dropout probability applied to the backbone features.
        num_superclasses: Output size of the superclass head (default 4).
        num_subclasses: Output size of the subclass head (default 88).
        trainable_keywords: Substrings selecting the backbone parameters that
            stay trainable.  Matching is by plain substring, so the default
            ``"norm"`` entry also selects any normalization parameters that
            live inside otherwise frozen blocks.
            NOTE(review): confirm that unfreezing those norms is intended.
        pretrained: Forwarded to ``timm.create_model``.
    """

    def __init__(self, dropout_rate=0.5, num_superclasses=4, num_subclasses=88,
                 trainable_keywords=("blocks.8", "blocks.9", "blocks.10", "blocks.11", "norm"),
                 pretrained=True):
        super().__init__()
        self.base = timm.create_model('vit_tiny_patch16_224', pretrained=pretrained)
        self.base.head = nn.Identity()

        self.dropout = nn.Dropout(dropout_rate)

        self.super_head = nn.Linear(self.base.num_features, num_superclasses)
        self.sub_head = nn.Linear(self.base.num_features, num_subclasses)

        # Freeze the backbone except for parameters matching a keyword.
        for name, param in self.base.named_parameters():
            param.requires_grad = any(keyword in name for keyword in trainable_keywords)

    def forward(self, x):
        """Return ``(super_logits, sub_logits)`` for a batch of images."""
        feats = self.base(x)
        feats = self.dropout(feats)

        return self.super_head(feats), self.sub_head(feats)





In [140]:
from torch.optim.lr_scheduler import StepLR

# Rebinds the notebook-level `criterion` to cross-entropy with label smoothing 0.1.
# Any function that reads the global `criterion` rather than taking it as an
# argument (train_two_stage_pure does) will use this loss if re-run after this cell.
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)


In [141]:
def train_and_validate_pure(model, train_loader, val_loader, optimizer, criterion, device, num_epochs=15, scheduler=None):
    """Train a two-headed (superclass, subclass) classifier and validate after each epoch.

    Batches are 5-tuples ``(images, super_idx, sub_idx, _, _)``; the last two
    entries are ignored.  The loss is the sum of ``criterion`` over both heads.
    If a scheduler is given it is stepped once per epoch, right after the
    training pass.  Loss and accuracies are printed; nothing is returned.

    NOTE: this definition reuses the name of the earlier novelty-aware trainer
    in this notebook and replaces it once this cell runs.
    """

    def _sweep(loader, learn):
        """One pass over `loader`; returns (summed loss, super hits, sub hits, samples)."""
        loss_total, hits_super, hits_sub, seen = 0.0, 0, 0, 0
        for batch in loader:
            images, super_idx, sub_idx, _, _ = batch
            images = images.to(device)
            super_idx = super_idx.to(device)
            sub_idx = sub_idx.to(device)

            logits_super, logits_sub = model(images)
            batch_loss = criterion(logits_super, super_idx) + criterion(logits_sub, sub_idx)

            if learn:
                optimizer.zero_grad()
                batch_loss.backward()
                optimizer.step()

            hits_super += (logits_super.argmax(1) == super_idx).sum().item()
            hits_sub += (logits_sub.argmax(1) == sub_idx).sum().item()
            seen += images.size(0)
            loss_total += batch_loss.item()
        return loss_total, hits_super, hits_sub, seen

    for epoch_no in range(1, num_epochs + 1):
        print(f"\n Epoch {epoch_no}/{num_epochs}")
        started = time.time()

        # ---- TRAINING ----
        model.train()
        tr_loss, tr_super, tr_sub, tr_seen = _sweep(train_loader, learn=True)

        if scheduler:
            scheduler.step()

        print(f" Train Loss: {tr_loss/len(train_loader):.4f} | Super Acc: {tr_super/tr_seen:.4f} | Sub Acc: {tr_sub/tr_seen:.4f}")

        # ---- VALIDATION ----
        model.eval()
        with torch.no_grad():
            va_loss, va_super, va_sub, va_seen = _sweep(val_loader, learn=False)

        print(f"📊 Val Loss: {va_loss/len(val_loader):.4f} | Super Acc: {va_super/va_seen:.4f} | Sub Acc: {va_sub/va_seen:.4f}")
        print(f"⏱️ Time: {time.time() - started:.2f}s")


In [None]:
# 1. Rebuild Model
model = TinyViTMultiTaskNoNoveltyWithDropout().to(device)

# 2. Rebuild Optimizer
# Two parameter groups, both with weight_decay=0.01: the two task heads (lr=1e-4)
# and a slice of the backbone (lr=1e-6).  The backbone slice is chosen by substring
# match on parameter names: blocks 8-11 plus every parameter whose name contains "norm".
head_params = (
    list(model.super_head.parameters()) +
    list(model.sub_head.parameters())
)
backbone_params = [
    p for n, p in model.named_parameters()
    if ("blocks.8" in n or "blocks.9" in n or "blocks.10" in n or "blocks.11" in n or "norm" in n)
]
optimizer = torch.optim.AdamW([
    {'params': head_params, 'lr': 1e-4, 'weight_decay': 0.01},
    {'params': backbone_params, 'lr': 1e-6, 'weight_decay': 0.01}
])
# Halve both learning rates every 5 scheduler steps (one step per epoch).
scheduler = StepLR(optimizer, step_size=5, gamma=0.5)

# 3. Train for 8 epochs with the classification-only trainer.
train_and_validate_pure(
    model,
    train_loader,
    val_loader,
    optimizer,
    criterion,
    device,
    num_epochs=8,
    scheduler=scheduler
)


In [None]:
# Continue training the same model for 12 more epochs.  The optimizer and scheduler
# objects from the previous cell are passed in again, so their state carries over;
# the printed epoch counter restarts at 1.
train_and_validate_pure(
    model,
    train_loader,
    val_loader,
    optimizer,
    criterion,
    device,
    num_epochs=12,
    scheduler=scheduler
)

In [150]:
def evaluate_pure(model, dataloader, device):
    """Print superclass and subclass accuracy of a two-headed model on ``dataloader``.

    Batches are 5-tuples ``(images, super_idx, sub_idx, _, _)``; the last two
    entries are ignored.  The model is switched to eval mode and run without
    gradients.  Nothing is returned.
    """
    model.eval()
    hits_super, hits_sub, seen = 0, 0, 0

    with torch.no_grad():
        for batch in dataloader:
            images, super_idx, sub_idx, _, _ = batch
            images = images.to(device)
            super_idx = super_idx.to(device)
            sub_idx = sub_idx.to(device)

            logits_super, logits_sub = model(images)

            hits_super += (logits_super.argmax(1) == super_idx).sum().item()
            hits_sub += (logits_sub.argmax(1) == sub_idx).sum().item()
            seen += images.size(0)

    # Both ratios are computed before anything is printed.
    super_acc, sub_acc = hits_super / seen, hits_sub / seen

    print(f"🧪 Evaluation:")
    print(f"   Superclass Accuracy = {super_acc:.4f}")
    print(f"   Subclass Accuracy   = {sub_acc:.4f}")


In [None]:
# Closed-set accuracy of the classification-only model on the held-out split (test_loader).
evaluate_pure(model, test_loader, device)

In [152]:
import torch.nn.functional as F

def test_model_1stage_pure(model, test_loader, device,
                            super_threshold=0.75, sub_threshold=0.65,
                            save_to_csv=True, save_path='test_predictions_pure.csv'):
    model.eval()
    test_predictions = {'image': [], 'superclass_index': [], 'subclass_index': []}

    novel_super_idx = 3
    novel_sub_idx = 87

    with torch.no_grad():
        for i, (images, img_names) in enumerate(test_loader):
            images = images.to(device)

            out_super, out_sub = model(images)

            super_probs = F.softmax(out_super, dim=1)
            super_max_prob, super_pred = torch.max(super_probs, dim=1)

            sub_probs = F.softmax(out_sub, dim=1)
            sub_max_prob, sub_pred = torch.max(sub_probs, dim=1)

            # Decision based on softmax confidence
            if super_max_prob.item() < super_threshold:
                final_super = novel_super_idx
                final_sub = novel_sub_idx
            else:
                final_super = super_pred.item()

                if sub_max_prob.item() < sub_threshold:
                    final_sub = novel_sub_idx
                else:
                    final_sub = sub_pred.item()

            test_predictions['image'].append(img_names[0])
            test_predictions['superclass_index'].append(final_super)
            test_predictions['subclass_index'].append(final_sub)

    test_predictions = pd.DataFrame(test_predictions)

    if save_to_csv:
        test_predictions.to_csv(save_path, index=False)
        print(f" Test predictions saved to {save_path}")

    return test_predictions


In [None]:
# Open-set inference with the classification-only model: superclass cut-off 0.75,
# subclass cut-off 0.65; predictions go to 'test_predictions_1stage_pure.csv'.
# NOTE(review): `open_set_test_loader` is not defined in this part of the notebook.
test_predictions = test_model_1stage_pure(
    model,
    open_set_test_loader,
    device,
    super_threshold=0.75,
    sub_threshold=0.65,
    save_to_csv=True,
    save_path='test_predictions_1stage_pure.csv'
)
