In [1]:
import os
import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision

from torch.utils.data import Dataset, DataLoader, BatchSampler, random_split
from torchvision import transforms
from PIL import Image
import copy

In [2]:
# Create Dataset class for multilabel classification
class MultiClassImageDataset(Dataset):
    def __init__(self, ann_df, super_map_df, sub_map_df, img_dir, transform=None):
        self.ann_df = ann_df
        self.super_map_df = super_map_df
        self.sub_map_df = sub_map_df
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.ann_df)

    def __getitem__(self, idx):
        img_name = self.ann_df['image'][idx]
        img_path = os.path.join(self.img_dir, img_name)
        image = Image.open(img_path).convert('RGB')

        super_idx = self.ann_df['superclass_index'][idx]
        super_label = self.super_map_df['class'][super_idx]

        sub_idx = self.ann_df['subclass_index'][idx]
        sub_label = self.sub_map_df['class'][sub_idx]

        if self.transform:
            image = self.transform(image)

        return image, super_idx, super_label, sub_idx, sub_label

class MultiClassImageTestDataset(Dataset):
    def __init__(self, super_map_df, sub_map_df, img_dir, transform=None):
        self.super_map_df = super_map_df
        self.sub_map_df = sub_map_df
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self): # Count files in img_dir
        return len([fname for fname in os.listdir(self.img_dir)])

    def __getitem__(self, idx):
        img_name = str(idx) + '.jpg'
        img_path = os.path.join(self.img_dir, img_name)
        image = Image.open(img_path).convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, img_name

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
train_ann_df = pd.read_csv('/content/drive/My Drive/train_data.csv')
super_map_df = pd.read_csv('/content/drive/My Drive/superclass_mapping.csv')
sub_map_df = pd.read_csv('/content/drive/My Drive/subclass_mapping.csv')

In [5]:
#Test Dataset
#test_ann_df = pd.read_csv('/content/drive/My Drive/test_data.csv')

train_img_dir = '/content/drive/My Drive/train_images/train_images/'
test_img_dir = '/content/drive/My Drive/test_images/test_images/'

import os
from tqdm import tqdm
import shutil

# Create a local cache directory
local_cache_dir = "/content/local_train_image_cache"
os.makedirs(local_cache_dir, exist_ok=True)

# Copy your dataset from Google Drive to local storage once
if len(os.listdir(local_cache_dir)) == 0:  # Only copy if cache is empty
    print("Caching images locally from Google Drive...")
    source_dir = train_img_dir

    # Get list of image files
    image_files = [f for f in os.listdir(source_dir)
                  if f.lower().endswith(('.png', '.jpg', '.jpeg'))]

    # Copy files with progress bar
    for img in tqdm(image_files):
        shutil.copy(os.path.join(source_dir, img),
                   os.path.join(local_cache_dir, img))

    print(f"Cached {len(image_files)} images to local storage")


image_preprocessing = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0), std=(1)),
])

# Update this in your original code where you define image_preprocessing
# image_preprocessing = transforms.Compose([
#     transforms.Resize((224, 224)),  # Resize to 224x224 for ViT
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet normalization
# ])

# Create train and val split
full_dataset = MultiClassImageDataset(train_ann_df, super_map_df, sub_map_df, local_cache_dir, transform=image_preprocessing)
train_dataset, val_dataset = random_split(full_dataset, [0.9, 0.1])

#Create test dataset
test_dataset = MultiClassImageTestDataset(super_map_df, sub_map_df, test_img_dir, transform=image_preprocessing)

# Create dataloaders
batch_size = 64

train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True)


val_loader = DataLoader(val_dataset,
                        batch_size=batch_size,
                        shuffle=True)

test_loader = DataLoader(test_dataset,
                         batch_size=1,
                         shuffle=False)




Caching images locally from Google Drive...


100%|██████████| 6288/6288 [02:44<00:00, 38.32it/s] 

Cached 6288 images to local storage





In [13]:
# @title Default title text
#BEST MODEL

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
import random
from torch.utils.data import DataLoader, Subset, random_split

class CNN(nn.Module):
    def __init__(self, input_size=64, num_superclasses=4, num_subclasses=88):
        super().__init__()

        self.feature_size = input_size // (2**3)


        self.block1 = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 32, 3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 32, 3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(2, 2)
        )


        self.block2 = nn.Sequential(
            nn.Conv2d(32, 64, 3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 64, 3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 64, 3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(2, 2)
        )


        self.block3 = nn.Sequential(
            nn.Conv2d(64, 128, 3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 128, 3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 128, 3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.MaxPool2d(2, 2)
        )


        self.fc1 = nn.Linear(self.feature_size * self.feature_size * 128, 256)
        self.dropout1 = nn.Dropout(0.2)
        self.fc2 = nn.Linear(256, 128)
        self.dropout2 = nn.Dropout(0.2)


        self.fc3a = nn.Linear(128, num_superclasses)
        self.fc3b = nn.Linear(128, num_subclasses)

    def forward(self, x):
        """Forward pass through the network"""

        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)

        x = torch.flatten(x, 1)


        x = F.relu(self.fc1(x))
        x = self.dropout1(x)
        x = F.relu(self.fc2(x))
        x = self.dropout2(x)


        super_out = self.fc3a(x)
        sub_out = self.fc3b(x)

        return super_out, sub_out

    def get_features(self, x):
        """Extract features before the final classification layer"""

        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)

        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout1(x)
        x = F.relu(self.fc2(x))
        x = self.dropout2(x)

        return x


class NoveltyDetectionTrainer:
    def __init__(self, full_dataset, image_preprocessing, device='cuda', batch_size=64):
        self.full_dataset = full_dataset
        self.image_preprocessing = image_preprocessing
        self.device = device
        self.batch_size = batch_size


        self.energy_mean = 0
        self.energy_std = 1


        self.superclass_indices = set()
        for i in range(len(full_dataset)):
            _, super_idx, _, _, _ = full_dataset[i]
            if hasattr(super_idx, 'item'):
                super_idx = super_idx.item()
            self.superclass_indices.add(super_idx)

        self.superclass_indices = sorted(list(self.superclass_indices))
        print(f"Found superclasses with indices: {self.superclass_indices}")

    def cross_validate_novelty_detection(self, epochs=5, confidence_threshold=0.0):
        results = []

        for fold, novel_idx in enumerate(self.superclass_indices):
            print(f"\n=== Fold {fold+1}/{len(self.superclass_indices)}: Treating superclass {novel_idx} as novel ===")

            known_indices, novel_indices = self._split_by_superclass(novel_idx)

            np.random.shuffle(known_indices)
            train_size = int(0.9 * len(known_indices))
            train_indices = known_indices[:train_size]
            val_known_indices = known_indices[train_size:]

            train_dataset = Subset(self.full_dataset, train_indices)
            val_known_dataset = Subset(self.full_dataset, val_known_indices)
            val_novel_dataset = Subset(self.full_dataset, novel_indices)


            train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
            val_known_loader = DataLoader(val_known_dataset, batch_size=self.batch_size, shuffle=False)
            val_novel_loader = DataLoader(val_novel_dataset, batch_size=self.batch_size, shuffle=False)


            model = CNN(input_size=64, num_superclasses=len(self.superclass_indices)+1).to(self.device)
            criterion = nn.CrossEntropyLoss()
            optimizer = optim.Adam(model.parameters(), lr=1e-3)


            self._train_model(model, criterion, optimizer, train_loader, epochs)


            self._calibrate_energy_stats(model, train_loader)


            metrics = self._evaluate_novelty_detection(model, val_known_loader, val_novel_loader, confidence_threshold)
            results.append(metrics)

            print(f"Fold {fold+1} results:")
            for key, value in metrics.items():
                print(f"  {key}: {value:.4f}")


        avg_results = {}
        for key in results[0].keys():
            avg_results[key] = sum(r[key] for r in results) / len(results)


        for key, value in avg_results.items():
            print(f"{key}: {value:.4f}")

        return avg_results, results

    def find_optimal_threshold(self, fold_index=0, threshold_range=np.arange(-3.0, 3.0, 0.1)):
        novel_idx = self.superclass_indices[fold_index]
        print(f"\n=== Finding optimal threshold for fold {fold_index+1}: Superclass {novel_idx} as novel ===")


        known_indices, novel_indices = self._split_by_superclass(novel_idx)


        np.random.shuffle(known_indices)
        train_size = int(0.9 * len(known_indices))
        train_indices = known_indices[:train_size]
        val_known_indices = known_indices[train_size:]


        train_dataset = Subset(self.full_dataset, train_indices)
        val_known_dataset = Subset(self.full_dataset, val_known_indices)
        val_novel_dataset = Subset(self.full_dataset, novel_indices)

        train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
        val_known_loader = DataLoader(val_known_dataset, batch_size=self.batch_size, shuffle=False)
        val_novel_loader = DataLoader(val_novel_dataset, batch_size=self.batch_size, shuffle=False)


        model = CNN(input_size=64, num_superclasses=len(self.superclass_indices)+1).to(self.device)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=1e-3)


        self._train_model(model, criterion, optimizer, train_loader, epochs=5)


        self._calibrate_energy_stats(model, train_loader)


        known_energies, novel_energies = self._collect_energies(model, val_known_loader, val_novel_loader)


        results = []
        for threshold in threshold_range:

            known_correct = sum(1 for e in known_energies if e <= threshold)
            known_accuracy = known_correct / len(known_energies) if known_energies else 0


            novel_correct = sum(1 for e in novel_energies if e > threshold)
            novel_accuracy = novel_correct / len(novel_energies) if novel_energies else 0

            balanced_accuracy = (known_accuracy + novel_accuracy) / 2

            results.append({
                'threshold': threshold,
                'known_accuracy': known_accuracy,
                'novel_accuracy': novel_accuracy,
                'balanced_accuracy': balanced_accuracy
            })

            print(f"Threshold {threshold:.2f}: Known Acc={known_accuracy:.4f}, Novel Acc={novel_accuracy:.4f}, Balanced Acc={balanced_accuracy:.4f}")


        best_result = max(results, key=lambda x: x['balanced_accuracy'])

        print(f"\nBest threshold: {best_result['threshold']:.2f}")
        print(f"Known accuracy: {best_result['known_accuracy']:.4f}")
        print(f"Novel accuracy: {best_result['novel_accuracy']:.4f}")
        print(f"Balanced accuracy: {best_result['balanced_accuracy']:.4f}")

        return best_result['threshold'], results

    def _calibrate_energy_stats(self, model, loader):
      """Calculate energy statistics on a dataset for normalization"""
      model.eval()
      all_energies = []

      with torch.no_grad():
          for data in loader:
              inputs = data[0].to(self.device)


              super_outputs, _ = model(inputs)


              energies = -torch.logsumexp(super_outputs, dim=1)
              all_energies.extend(energies.cpu().numpy())

      all_energies = np.array(all_energies)
      self.energy_mean = float(np.mean(all_energies))
      self.energy_std = float(np.std(all_energies) + 1e-6)  # Add epsilon to avoid division by zero

      print(f"Calibrated energy statistics: mean={self.energy_mean:.4f}, std={self.energy_std:.4f}")

    def _compute_normalized_energy(self, logits):


      raw_energy = -torch.logsumexp(logits, dim=1)


      normalized_energy = (raw_energy - self.energy_mean) / self.energy_std

      return normalized_energy

    def _split_by_superclass(self, novel_superclass_idx):
        """Split dataset indices into known and novel based on superclass"""
        known_indices = []
        novel_indices = []

        for i in range(len(self.full_dataset)):
            _, super_idx, _, _, _ = self.full_dataset[i]
            if hasattr(super_idx, 'item'):
                super_idx = super_idx.item()

            if super_idx == novel_superclass_idx:
                novel_indices.append(i)
            else:
                known_indices.append(i)

        return known_indices, novel_indices

    def _train_model(self, model, criterion, optimizer, train_loader, epochs):
        """Train the model on known classes"""
        model.train()
        for epoch in range(epochs):
            running_loss = 0.0
            for i, data in enumerate(train_loader):
                inputs, super_labels, _, sub_labels, _ = data
                inputs = inputs.to(self.device)
                super_labels = super_labels.to(self.device)
                sub_labels = sub_labels.to(self.device)

                optimizer.zero_grad()
                super_outputs, sub_outputs = model(inputs)
                loss = criterion(super_outputs, super_labels) + criterion(sub_outputs, sub_labels)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()

            print(f'Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(train_loader):.4f}')

    def _evaluate_novelty_detection(self, model, known_loader, novel_loader, threshold):
      """Evaluate novelty detection performance using balanced ensemble approach."""
      model.eval()


      self._calibrate_energy_stats(model, known_loader)

      def eval_loader(loader, is_novel):
          super_correct, sub_correct = 0, 0
          super_total, sub_total = 0, 0

          with torch.no_grad():
              for data in loader:
                  inputs, _, _, _, _ = data
                  inputs = inputs.to(self.device)

                  super_outputs, sub_outputs = model(inputs)


                  super_energies = self._compute_normalized_energy(super_outputs)
                  energy_novel = super_energies > threshold


                  super_probs = F.softmax(super_outputs, dim=1)
                  super_confidences, _ = torch.max(super_probs, dim=1)
                  confidence_novel = super_confidences < 0.7


                  energy_weight = 0.6
                  confidence_weight = 0.4

                  novelty_score = energy_weight * energy_novel.float() + confidence_weight * confidence_novel.float()

                  is_novel_super = novelty_score > 0.5

                  sub_probs = F.softmax(sub_outputs, dim=1)
                  sub_confidences, _ = torch.max(sub_probs, dim=1)
                  is_novel_sub = sub_confidences < 0.5

                  if is_novel:
                      super_correct += is_novel_super.sum().item()
                      sub_correct += is_novel_sub.sum().item()
                  else:
                      super_correct += (~is_novel_super).sum().item()
                      sub_correct += (~is_novel_sub).sum().item()

                  super_total += inputs.size(0)
                  sub_total += inputs.size(0)

          return (
              super_correct / super_total if super_total else 0,
              sub_correct / sub_total if sub_total else 0
          )


      known_super_acc, known_sub_acc = eval_loader(known_loader, is_novel=False)
      novel_super_acc, novel_sub_acc = eval_loader(novel_loader, is_novel=True)

      balanced_super_acc = (known_super_acc + novel_super_acc) / 2
      balanced_sub_acc = (known_sub_acc + novel_sub_acc) / 2

      return {
          'known_superclass_accuracy': known_super_acc,
          'novel_superclass_accuracy': novel_super_acc,
          'balanced_superclass_accuracy': balanced_super_acc,
          'known_subclass_accuracy': known_sub_acc,
          'novel_subclass_accuracy': novel_sub_acc,
          'balanced_subclass_accuracy': balanced_sub_acc
      }


    def _collect_energies(self, model, known_loader, novel_loader):
        """Collect normalized energy scores for known and novel classes"""
        model.eval()

        known_energies = []
        novel_energies = []

        with torch.no_grad():

            for data in known_loader:
                inputs, _, _, _, _ = data
                inputs = inputs.to(self.device)

                super_outputs, _ = model(inputs)
                energies = self._compute_normalized_energy(super_outputs)
                known_energies.extend(energies.cpu().numpy())


            for data in novel_loader:
                inputs, _, _, _, _ = data
                inputs = inputs.to(self.device)

                super_outputs, _ = model(inputs)
                energies = self._compute_normalized_energy(super_outputs)
                novel_energies.extend(energies.cpu().numpy())

        return known_energies, novel_energies


class Trainer():
    def __init__(self, model, criterion, optimizer, train_loader, val_loader, test_loader=None, device='cuda'):
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.device = device


        self.energy_mean = 0
        self.energy_std = 1
        self.energy_calibrated = False


        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=3, verbose=True
        )


        self.temperature = 1.5

    def train_epoch(self):
        self.model.train()
        running_loss = 0.0
        for i, data in enumerate(self.train_loader):
            inputs, super_labels, sub_labels = data[0].to(self.device), data[1].to(self.device), data[3].to(self.device)

            self.optimizer.zero_grad()
            super_outputs, sub_outputs = self.model(inputs)
            loss = self.criterion(super_outputs, super_labels) + self.criterion(sub_outputs, sub_labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()

            running_loss += loss.item()

        print(f'Training loss: {running_loss/(i+1):.3f}')
        avg_loss = running_loss/(i+1)
        self.scheduler.step(avg_loss)
        return avg_loss




        self._calibrate_energy_stats()

    def _calibrate_energy_stats(self):
        """Calculate energy statistics on training data for normalization"""
        self.model.eval()
        all_energies = []

        with torch.no_grad():
            for data in self.train_loader:
                inputs = data[0].to(self.device)


                super_outputs, _ = self.model(inputs)


                energies = -torch.logsumexp(super_outputs, dim=1)
                all_energies.extend(energies.cpu().numpy())


        all_energies = np.array(all_energies)
        self.energy_mean = float(np.mean(all_energies))
        self.energy_std = float(np.std(all_energies) + 1e-6)
        self.energy_calibrated = True

        print(f"Calibrated energy statistics: mean={self.energy_mean:.4f}, std={self.energy_std:.4f}")

    def compute_normalized_energy(self, logits):


        raw_energy = -torch.logsumexp(logits, dim=1)


        if not self.energy_calibrated:

            print("Warning: Energy statistics not calibrated, using raw energy")
            return raw_energy

        normalized_energy = (raw_energy - self.energy_mean) / self.energy_std

        return normalized_energy

    def validate_epoch(self, novel_superclass_idx=3, novel_subclass_idx=87, confidence_threshold=0.0, temperature=1.0):

      # Make sure energy statistics are calibrated
      if not self.energy_calibrated:
          self._calibrate_energy_stats()

      self.model.eval()

      # Metrics to track
      correct_with_novelty = 0
      super_correct_standard = 0
      sub_correct = 0

      novel_total = 0
      known_total = 0
      novel_correct = 0
      known_correct = 0

      total = 0

      novel_super_predictions = 0
      novel_sub_predictions = 0

      all_super_energies = []
      all_sub_confidences = []

      running_loss = 0.0

      with torch.no_grad():
          for i, data in enumerate(self.val_loader):
              inputs, super_labels, _, sub_labels, _ = data
              inputs = inputs.to(self.device)
              super_labels = super_labels.to(self.device)
              sub_labels = sub_labels.to(self.device)

              super_outputs, sub_outputs = self.model(inputs)


              super_energies = self.compute_normalized_energy(super_outputs)
              energy_novel = super_energies > confidence_threshold


              super_probs = F.softmax(super_outputs, dim=1)
              super_confidences, super_predicted = torch.max(super_probs, dim=1)

              conf_threshold = 0.7
              confidence_novel = super_confidences < conf_threshold

              energy_weight = 0.6
              confidence_weight = 0.4


              novelty_score = energy_weight * energy_novel.float() + confidence_weight * confidence_novel.float()


              decision_threshold = 0.5
              novel_super_mask = novelty_score > decision_threshold


              final_super_preds = torch.where(
                  novel_super_mask,
                  torch.full_like(super_predicted, novel_superclass_idx),
                  super_predicted
              )


              sub_probs = F.softmax(sub_outputs, dim=1)
              sub_confidences, sub_predicted = torch.max(sub_probs, dim=1)
              sub_threshold = 0.5
              novel_sub_mask = sub_confidences < sub_threshold

              final_sub_preds = torch.where(
                  novel_sub_mask,
                  torch.full_like(sub_predicted, novel_subclass_idx),
                  sub_predicted
              )


              total += super_labels.size(0)


              correct_with_novelty += (final_super_preds == super_labels).sum().item()
              super_correct_standard += (super_predicted == super_labels).sum().item()
              sub_correct += (final_sub_preds == sub_labels).sum().item()


              is_novel_label = super_labels == novel_superclass_idx
              novel_total += is_novel_label.sum().item()
              known_total += (~is_novel_label).sum().item()

              novel_correct += ((final_super_preds == super_labels) & is_novel_label).sum().item()
              known_correct += ((final_super_preds == super_labels) & ~is_novel_label).sum().item()


              novel_super_predictions += novel_super_mask.sum().item()
              novel_sub_predictions += novel_sub_mask.sum().item()


              all_super_energies.extend(super_energies.cpu().numpy())
              all_sub_confidences.extend(sub_confidences.cpu().numpy())


              loss = self.criterion(super_outputs, super_labels) + self.criterion(sub_outputs, sub_labels)
              running_loss += loss.item()


      super_acc = 100 * correct_with_novelty / total if total > 0 else 0
      sub_acc = 100 * sub_correct / total if total > 0 else 0

      novel_acc = 100 * novel_correct / novel_total if novel_total > 0 else 0
      known_acc = 100 * known_correct / known_total if known_total > 0 else 0
      balanced_acc = (novel_acc + known_acc) / 2 if novel_total > 0 and known_total > 0 else 0

      avg_super_energy = sum(all_super_energies) / len(all_super_energies) if all_super_energies else 0
      avg_sub_conf = sum(all_sub_confidences) / len(all_sub_confidences) if all_sub_confidences else 0

      novel_super_perc = 100 * novel_super_predictions / total if total > 0 else 0
      novel_sub_perc = 100 * novel_sub_predictions / total if total > 0 else 0

      # Display metrics
      print(f'Validation loss: {running_loss/(i+1):.3f}')
      print(f'Validation superclass acc: {super_acc:.2f}%')
      print(f'Validation subclass acc: {sub_acc:.2f}%')
      print(f'Novel superclass acc: {novel_acc:.2f}%, Known superclass acc: {known_acc:.2f}%')
      print(f'Balanced superclass acc: {balanced_acc:.2f}%')
      print(f'Average normalized superclass energy: {avg_super_energy:.4f}')
      print(f'Average subclass confidence: {avg_sub_conf:.4f}')
      print(f'Samples predicted as novel superclass: {novel_super_predictions} ({novel_super_perc:.2f}%)')
      print(f'Samples predicted as novel subclass: {novel_sub_predictions} ({novel_sub_perc:.2f}%)')

      return {
          'loss': running_loss/(i+1),
          'accuracy': super_acc,
          'novel_acc': novel_acc,
          'known_acc': known_acc,
          'balanced_acc': balanced_acc
      }

    def test(self, save_to_csv=False, return_predictions=False, confidence_threshold=0.0):
      if not self.test_loader:
          raise NotImplementedError('test_loader not specified')

      # Make sure energy statistics are calibrated
      if not self.energy_calibrated:
          self._calibrate_energy_stats()

      self.model.eval()
      novel_superclass_idx = 3  # Index for novel superclass
      novel_subclass_idx = 87   # Index for novel subclass

      # Create full data structure for internal use
      full_test_predictions = {
          'image': [],
          'superclass_index': [],
          'subclass_index': [],
          'superclass_energy': [],
          'subclass_confidence': [],
          'novelty_score': []
      }

      with torch.no_grad():
          for i, data in enumerate(self.test_loader):
              inputs, img_name = data[0].to(self.device), data[1]

              super_outputs, sub_outputs = self.model(inputs)


              super_energies = self.compute_normalized_energy(super_outputs)
              energy_novel = super_energies > confidence_threshold


              super_probs = F.softmax(super_outputs, dim=1)
              super_confidences, super_predicted = torch.max(super_probs, dim=1)


              conf_threshold = 0.7
              confidence_novel = super_confidences < conf_threshold


              energy_weight = 0.6
              confidence_weight = 0.4

              novelty_score = energy_weight * energy_novel.float() + confidence_weight * confidence_novel.float()
              decision_threshold = 0.5
              novel_super_mask = novelty_score > decision_threshold


              sub_probs = F.softmax(sub_outputs, dim=1)
              sub_confidences, sub_predicted = torch.max(sub_probs, dim=1)
              sub_threshold = 0.5
              novel_sub_mask = sub_confidences < sub_threshold

              for j in range(inputs.size(0)):
                  img = img_name[j] if isinstance(img_name, list) else img_name[0]


                  super_pred = novel_superclass_idx if novel_super_mask[j] else super_predicted[j].item()
                  sub_pred = novel_subclass_idx if novel_sub_mask[j] else sub_predicted[j].item()

                  full_test_predictions['image'].append(img)
                  full_test_predictions['superclass_index'].append(super_pred)
                  full_test_predictions['subclass_index'].append(sub_pred)
                  full_test_predictions['superclass_energy'].append(super_energies[j].item())
                  full_test_predictions['subclass_confidence'].append(sub_confidences[j].item())
                  full_test_predictions['novelty_score'].append(novelty_score[j].item())

      full_predictions_df = pd.DataFrame(data=full_test_predictions)


      simplified_test_predictions = {
          'image': full_test_predictions['image'],
          'superclass_index': full_test_predictions['superclass_index'],
          'subclass_index': full_test_predictions['subclass_index']
      }
      simplified_predictions_df = pd.DataFrame(data=simplified_test_predictions)


      novel_super_count = sum(1 for idx in full_test_predictions['superclass_index'] if idx == novel_superclass_idx)
      novel_sub_count = sum(1 for idx in full_test_predictions['subclass_index'] if idx == novel_subclass_idx)

      total_count = len(full_test_predictions['image'])
      novel_super_perc = 100 * novel_super_count / total_count if total_count > 0 else 0
      novel_sub_perc = 100 * novel_sub_count / total_count if total_count > 0 else 0

      print(f'Test set predictions:')
      print(f'Images predicted as novel superclass: {novel_super_count} ({novel_super_perc:.2f}%)')
      print(f'Images predicted as novel subclass: {novel_sub_count} ({novel_sub_perc:.2f}%)')


      print(f'Novelty score distribution:')
      bins = [0, 0.2, 0.4, 0.5, 0.6, 0.8, 1.0]
      for i in range(len(bins)-1):
          count = sum(1 for score in full_test_predictions['novelty_score']
                    if bins[i] <= score < bins[i+1])
          print(f'  {bins[i]:.1f}-{bins[i+1]:.1f}: {count} ({100*count/total_count:.2f}%)')

      if save_to_csv:

          simplified_predictions_df.to_csv('example_test_predictions.csv', index=False)
          print("Predictions saved to 'example_test_predictions.csv'")

      if return_predictions:

          return full_predictions_df


def train_with_novelty_detection(full_dataset, image_preprocessing, device='cuda', batch_size=64, epochs=5):

    novelty_trainer = NoveltyDetectionTrainer(
        full_dataset=full_dataset,
        image_preprocessing=image_preprocessing,
        device=device,
        batch_size=batch_size
    )


    print("Running cross-validation for novelty detection...")
    avg_results, fold_results = novelty_trainer.cross_validate_novelty_detection(epochs=epochs)


    print("\nFinding optimal energy threshold...")
    best_threshold, threshold_results = novelty_trainer.find_optimal_threshold()

    return avg_results, best_threshold


In [14]:
# 1. First run cross-validation to find the optimal threshold
results, threshold = train_with_novelty_detection(full_dataset, image_preprocessing)
device = 'cuda'

# 2. Then train your final model on all data and use the threshold for inference
model = CNN(input_size=64, num_superclasses=3, num_subclasses=87).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
trainer = Trainer(model, criterion, optimizer, train_loader, val_loader, test_loader, device=device)


trainer.energy_mean = 0
trainer.energy_std = 1

for epoch in range(20):
    print(f'Epoch {epoch+1}')
    trainer.train_epoch()
    trainer.validate_epoch(confidence_threshold=threshold)  # Use optimized threshold
    print('')

print('Finished Training')


predictions = trainer.test(save_to_csv=True, confidence_threshold=threshold)

Found superclasses with indices: [0, 1, 2]
Running cross-validation for novelty detection...

=== Fold 1/3: Treating superclass 0 as novel ===
Epoch 1/5, Loss: 3.5685
Epoch 2/5, Loss: 2.1151
Epoch 3/5, Loss: 1.4693
Epoch 4/5, Loss: 1.1098
Epoch 5/5, Loss: 0.8552
Calibrated energy statistics: mean=-7.2986, std=2.6351
Calibrated energy statistics: mean=-7.1025, std=2.7703
Fold 1 results:
  known_superclass_accuracy: 0.4595
  novel_superclass_accuracy: 0.9330
  balanced_superclass_accuracy: 0.6962
  known_subclass_accuracy: 0.7432
  novel_subclass_accuracy: 0.6173
  balanced_subclass_accuracy: 0.6803

=== Fold 2/3: Treating superclass 1 as novel ===
Epoch 1/5, Loss: 3.7484
Epoch 2/5, Loss: 2.2629
Epoch 3/5, Loss: 1.5926
Epoch 4/5, Loss: 1.2059
Epoch 5/5, Loss: 1.0068
Calibrated energy statistics: mean=-7.6700, std=2.9918
Calibrated energy statistics: mean=-7.4149, std=2.9566
Fold 2 results:
  known_superclass_accuracy: 0.4537
  novel_superclass_accuracy: 0.9621
  balanced_superclass_accur

In [8]:
# @title Default title text
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
from scipy.spatial.distance import cdist
from scipy.stats import weibull_min
from sklearn.preprocessing import normalize
from collections import defaultdict

class CNN(nn.Module):
    def __init__(self, input_size=64, num_superclasses=4, num_subclasses=88):
        super().__init__()

        self.feature_size = input_size // (2**3)

        self.block1 = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 32, 3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 32, 3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(2, 2)
        )

        self.block2 = nn.Sequential(
            nn.Conv2d(32, 64, 3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 64, 3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 64, 3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(2, 2)
        )

        self.block3 = nn.Sequential(
            nn.Conv2d(64, 128, 3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 128, 3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 128, 3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.MaxPool2d(2, 2)
        )

        self.fc1 = nn.Linear(self.feature_size * self.feature_size * 128, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3a = nn.Linear(128, num_superclasses)  # 4 superclasses: bird, dog, reptile, novel
        self.fc3b = nn.Linear(128, num_subclasses)    # All subclasses + novel

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)

        x = torch.flatten(x, 1)  # flatten all dimensions except batch

        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        super_out = self.fc3a(x)
        sub_out = self.fc3b(x)
        return super_out, sub_out

    def get_features(self, x):
        """Extract features before the final classification layer"""
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return x

    def get_logits(self, features):
        """Get class logits from features"""
        super_out = self.fc3a(features)
        sub_out = self.fc3b(features)
        return super_out, sub_out


class OpenMaxModel:
    def __init__(self, model, num_superclasses=3, num_subclasses=87, tailsize=20, alpha=10):
        """
        Initialize OpenMax model with a pre-trained CNN

        Args:
            model: Pre-trained CNN model
            num_superclasses: Number of known superclasses
            num_subclasses: Number of known subclasses
            tailsize: Number of extremal samples to use for Weibull fitting
            alpha: Number of top activations to consider for recalibration
        """
        self.model = model
        self.num_superclasses = num_superclasses
        self.num_subclasses = num_subclasses
        self.tailsize = tailsize
        self.alpha = alpha

        # Storage for class means and Weibull models
        self.super_means = None
        self.sub_means = None
        self.super_weibull_models = None
        self.sub_weibull_models = None

    def fit(self, train_loader, device):

        self.model.eval()

        # Collect activations for each class
        super_activations = defaultdict(list)
        sub_activations = defaultdict(list)

        # Collect activations for each class
        with torch.no_grad():
            for data in train_loader:
                inputs, super_labels, _, sub_labels, _ = data
                inputs = inputs.to(device)
                super_labels = super_labels.to(device)
                sub_labels = sub_labels.to(device)

                # Get features
                features = self.model.get_features(inputs)
                super_logits, sub_logits = self.model.get_logits(features)

                # Store activations for each class
                for i in range(inputs.size(0)):
                    super_class = super_labels[i].item()
                    sub_class = sub_labels[i].item()

                    super_activations[super_class].append(super_logits[i].cpu().numpy())
                    sub_activations[sub_class].append(sub_logits[i].cpu().numpy())

        # Compute means for each class
        self.super_means = {}
        self.sub_means = {}

        for c in range(self.num_superclasses):
            if c in super_activations and len(super_activations[c]) > 0:
                self.super_means[c] = np.mean(super_activations[c], axis=0)

        for c in range(self.num_subclasses):
            if c in sub_activations and len(sub_activations[c]) > 0:
                self.sub_means[c] = np.mean(sub_activations[c], axis=0)

        # Compute distances to mean for Weibull fitting
        super_dists = defaultdict(list)
        sub_dists = defaultdict(list)

        with torch.no_grad():
            for data in train_loader:
                inputs, super_labels, _, sub_labels, _ = data
                inputs = inputs.to(device)
                super_labels = super_labels.to(device)
                sub_labels = sub_labels.to(device)

                # Get features
                features = self.model.get_features(inputs)
                super_logits, sub_logits = self.model.get_logits(features)

                # Compute distances
                for i in range(inputs.size(0)):
                    super_class = super_labels[i].item()
                    sub_class = sub_labels[i].item()

                    if super_class in self.super_means:
                        super_mean = self.super_means[super_class]
                        super_logit = super_logits[i].cpu().numpy()
                        super_dist = np.linalg.norm(super_logit - super_mean)
                        super_dists[super_class].append(super_dist)

                    if sub_class in self.sub_means:
                        sub_mean = self.sub_means[sub_class]
                        sub_logit = sub_logits[i].cpu().numpy()
                        sub_dist = np.linalg.norm(sub_logit - sub_mean)
                        sub_dists[sub_class].append(sub_dist)

        # Fit Weibull models
        self.super_weibull_models = {}
        self.sub_weibull_models = {}

        for c in range(self.num_superclasses):
            if c in super_dists and len(super_dists[c]) > self.tailsize:
                # Sort distances and take tailsize largest
                sorted_dists = sorted(super_dists[c])
                tail_dists = sorted_dists[-self.tailsize:]

                # Fit Weibull distribution
                try:
                    shape, loc, scale = weibull_min.fit(tail_dists, floc=0)
                    self.super_weibull_models[c] = (shape, loc, scale)
                except:
                    print(f"Warning: Failed to fit Weibull for superclass {c}")

        for c in range(self.num_subclasses):
            if c in sub_dists and len(sub_dists[c]) > self.tailsize:
                # Sort distances and take tailsize largest
                sorted_dists = sorted(sub_dists[c])
                tail_dists = sorted_dists[-self.tailsize:]

                # Fit Weibull distribution
                try:
                    shape, loc, scale = weibull_min.fit(tail_dists, floc=0)
                    self.sub_weibull_models[c] = (shape, loc, scale)
                except:
                    print(f"Warning: Failed to fit Weibull for subclass {c}")

    def predict(self, inputs, device):

        self.model.eval()

        with torch.no_grad():
            # Get features and logits
            features = self.model.get_features(inputs)
            super_logits, sub_logits = self.model.get_logits(features)

            # Convert to numpy for processing
            super_logits_np = super_logits.cpu().numpy()
            sub_logits_np = sub_logits.cpu().numpy()

            # Process each sample
            super_preds = []
            sub_preds = []

            for i in range(inputs.size(0)):
                # Recalibrate superclass logits
                super_logit = super_logits_np[i]
                super_pred = self._recalibrate_sample(super_logit, self.super_means, self.super_weibull_models, self.num_superclasses)
                super_preds.append(super_pred)

                # Recalibrate subclass logits
                sub_logit = sub_logits_np[i]
                sub_pred = self._recalibrate_sample(sub_logit, self.sub_means, self.sub_weibull_models, self.num_subclasses)
                sub_preds.append(sub_pred)

            # Convert back to tensors
            super_preds = torch.tensor(super_preds, device=device)
            sub_preds = torch.tensor(sub_preds, device=device)

        return super_preds, sub_preds

    def _recalibrate_sample(self, logits, means, weibull_models, num_classes):

        # Get top alpha class indices
        top_alpha_idx = np.argsort(logits)[-self.alpha:]

        # Compute distances to class means
        distances = {}
        for c in range(num_classes):
            if c in means:
                distances[c] = np.linalg.norm(logits - means[c])

        # Recalibrate activations
        recalibrated = np.copy(logits)
        for c in top_alpha_idx:
            if c < num_classes and c in weibull_models:
                # Get Weibull parameters
                shape, loc, scale = weibull_models[c]

                # Compute probability of being an outlier
                dist = distances.get(c, 0)
                weibull_score = 1 - weibull_min.cdf(dist, shape, loc, scale)

                # Adjust activation
                recalibrated[c] = logits[c] * (1 - weibull_score)

        # Apply softmax to get probabilities
        recalibrated_probs = np.exp(recalibrated) / np.sum(np.exp(recalibrated))

        # Compute probability of being unknown
        unknown_prob = 1.0 - np.sum(recalibrated_probs)

        # Make final prediction (including novel class)
        if unknown_prob > 0.5:  # Threshold for novel class detection
            return num_classes  # Return novel class index
        else:
            return np.argmax(recalibrated_probs)


class OpenMaxTrainer:
    def __init__(self, model, openmax_model, criterion, optimizer, train_loader, val_loader, test_loader=None, device='cuda'):
        self.model = model
        self.openmax_model = openmax_model
        self.criterion = criterion
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.device = device

    def train_epoch(self):
        self.model.train()
        running_loss = 0.0

        for i, data in enumerate(self.train_loader):
            inputs, super_labels, _, sub_labels, _ = data
            inputs = inputs.to(self.device)
            super_labels = super_labels.to(self.device)
            sub_labels = sub_labels.to(self.device)

            self.optimizer.zero_grad()
            super_outputs, sub_outputs = self.model(inputs)
            loss = self.criterion(super_outputs, super_labels) + self.criterion(sub_outputs, sub_labels)
            loss.backward()
            self.optimizer.step()

            running_loss += loss.item()

        print(f'Training loss: {running_loss/(i+1):.3f}')

    def validate_epoch(self):
        self.model.eval()

        super_correct = 0
        sub_correct = 0
        total = 0
        running_loss = 0.0

        with torch.no_grad():
            for i, data in enumerate(self.val_loader):
                inputs, super_labels, _, sub_labels, _ = data
                inputs = inputs.to(self.device)
                super_labels = super_labels.to(self.device)
                sub_labels = sub_labels.to(self.device)

                # Standard forward pass for loss calculation
                super_outputs, sub_outputs = self.model(inputs)
                loss = self.criterion(super_outputs, super_labels) + self.criterion(sub_outputs, sub_labels)

                # OpenMax prediction
                super_preds, sub_preds = self.openmax_model.predict(inputs, self.device)

                total += super_labels.size(0)
                super_correct += (super_preds == super_labels).sum().item()
                sub_correct += (sub_preds == sub_labels).sum().item()

                running_loss += loss.item()

        print(f'Validation loss: {running_loss/(i+1):.3f}')
        print(f'Validation superclass acc: {100 * super_correct / total:.2f}%')
        print(f'Validation subclass acc: {100 * sub_correct / total:.2f}%')

    def test(self, save_to_csv=False, return_predictions=False):
        if not self.test_loader:
            raise NotImplementedError('test_loader not specified')

        self.model.eval()

        # Evaluate on test set with OpenMax
        test_predictions = {
            'image': [],
            'superclass_index': [],
            'subclass_index': []
        }

        with torch.no_grad():
            for i, data in enumerate(self.test_loader):
                inputs, img_name = data[0].to(self.device), data[1]

                # OpenMax prediction
                super_preds, sub_preds = self.openmax_model.predict(inputs, self.device)

                for j in range(inputs.size(0)):
                    img = img_name[j] if isinstance(img_name, list) else img_name[0]

                    test_predictions['image'].append(img)
                    test_predictions['superclass_index'].append(super_preds[j].item())
                    test_predictions['subclass_index'].append(sub_preds[j].item())

        test_predictions = pd.DataFrame(data=test_predictions)

        # Print summary of novel predictions
        novel_super_count = sum(1 for idx in test_predictions['superclass_index'] if idx == self.openmax_model.num_superclasses)
        novel_sub_count = sum(1 for idx in test_predictions['subclass_index'] if idx == self.openmax_model.num_subclasses)

        print(f'Test set predictions:')
        print(f'Images predicted as novel superclass: {novel_super_count} ({100*novel_super_count/len(test_predictions):.2f}%)')
        print(f'Images predicted as novel subclass: {novel_sub_count} ({100*novel_sub_count/len(test_predictions):.2f}%)')

        if save_to_csv:
            test_predictions.to_csv('openmax_test_predictions.csv', index=False)

        if return_predictions:
            return test_predictions


# Example usage:
def train_with_openmax(full_dataset, device='cuda', batch_size=64, num_epochs=20):
    # Create cross-validation split
    from torch.utils.data import random_split

    # Split into train and validation
    train_size = int(0.9 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

    # Create dataloaders
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Initialize model
    model = CNN(input_size=64, num_superclasses=4, num_subclasses=88).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    # Train standard model first
    trainer = Trainer(model, criterion, optimizer, train_loader, val_loader, device=device)

    print("Training standard model...")
    for epoch in range(10):  # Pre-train for 10 epochs
        print(f'Epoch {epoch+1}/10')
        trainer.train_epoch()

    # Initialize OpenMax model
    openmax_model = OpenMaxModel(model, num_superclasses=3, num_subclasses=87)

    # Fit OpenMax parameters
    print("\nFitting OpenMax parameters...")
    openmax_model.fit(train_loader, device)

    # Continue training with OpenMax
    openmax_trainer = OpenMaxTrainer(model, openmax_model, criterion, optimizer, train_loader, val_loader, device=device)

    print("\nTraining with OpenMax...")
    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        openmax_trainer.train_epoch()
        openmax_trainer.validate_epoch()
        print()

    print("Finished Training")

    return model, openmax_model


# Cross-validation for OpenMax
def openmax_cross_validation(full_dataset, image_preprocessing, device='cuda', batch_size=64):
    """Run cross-validation for OpenMax novel detection"""
    # Get all unique superclass indices
    superclass_indices = set()
    for i in range(len(full_dataset)):
        _, super_idx, _, _, _ = full_dataset[i]
        if hasattr(super_idx, 'item'):
            super_idx = super_idx.item()
        superclass_indices.add(super_idx)

    superclass_indices = sorted(list(superclass_indices))
    print(f"Found superclasses with indices: {superclass_indices}")

    results = []

    # For each superclass, treat it as novel and others as known
    for fold, novel_idx in enumerate(superclass_indices):
        print(f"\n=== Fold {fold+1}/{len(superclass_indices)}: Treating superclass {novel_idx} as novel ===")

        # Create data splits
        known_indices = []
        novel_indices = []

        for i in range(len(full_dataset)):
            _, super_idx, _, _, _ = full_dataset[i]
            if hasattr(super_idx, 'item'):
                super_idx = super_idx.item()

            if super_idx == novel_idx:
                novel_indices.append(i)
            else:
                known_indices.append(i)

        # Further split known indices into train/validation
        np.random.shuffle(known_indices)
        train_size = int(0.9 * len(known_indices))
        train_indices = known_indices[:train_size]
        val_known_indices = known_indices[train_size:]

        # Create datasets
        train_dataset = torch.utils.data.Subset(full_dataset, train_indices)
        val_known_dataset = torch.utils.data.Subset(full_dataset, val_known_indices)
        val_novel_dataset = torch.utils.data.Subset(full_dataset, novel_indices)

        # Create dataloaders
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_known_loader = torch.utils.data.DataLoader(val_known_dataset, batch_size=batch_size, shuffle=False)
        val_novel_loader = torch.utils.data.DataLoader(val_novel_dataset, batch_size=batch_size, shuffle=False)

        # Initialize model
        model = CNN(input_size=64, num_superclasses=len(superclass_indices)+1).to(device)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=1e-3)

        # Train standard model first
        print("Training standard model...")
        for epoch in range(5):
            running_loss = 0.0
            for i, data in enumerate(train_loader):
                inputs, super_labels, _, sub_labels, _ = data
                inputs = inputs.to(device)
                super_labels = super_labels.to(device)
                sub_labels = sub_labels.to(device)

                optimizer.zero_grad()
                super_outputs, sub_outputs = model(inputs)
                loss = criterion(super_outputs, super_labels) + criterion(sub_outputs, sub_labels)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()

            print(f'Epoch {epoch+1}/5, Loss: {running_loss/(i+1):.4f}')

        # Initialize OpenMax model
        num_known_classes = len(superclass_indices) - 1  # Excluding the novel class
        openmax_model = OpenMaxModel(model, num_superclasses=num_known_classes)

        # Fit OpenMax parameters
        print("\nFitting OpenMax parameters...")
        openmax_model.fit(train_loader, device)

        # Evaluate OpenMax on known and novel samples
        model.eval()

        # Test on known classes
        known_correct = 0
        known_total = 0

        with torch.no_grad():
            for data in val_known_loader:
                inputs, super_labels, _, _, _ = data
                inputs = inputs.to(device)
                super_labels = super_labels.to(device)

                # OpenMax prediction
                super_preds, _ = openmax_model.predict(inputs, device)

                known_total += super_labels.size(0)
                known_correct += (super_preds == super_labels).sum().item()

        # Test on novel classes
        novel_correct = 0
        novel_total = 0

        with torch.no_grad():
            for data in val_novel_loader:
                inputs, _, _, _, _ = data
                inputs = inputs.to(device)

                # OpenMax prediction
                super_preds, _ = openmax_model.predict(inputs, device)

                # For novel classes, prediction should be the novel class index
                novel_total += inputs.size(0)
                novel_correct += (super_preds == num_known_classes).sum().item()

        known_acc = known_correct / known_total if known_total > 0 else 0
        novel_acc = novel_correct / novel_total if novel_total > 0 else 0
        balanced_acc = (known_acc + novel_acc) / 2

        results.append({
            'fold': fold,
            'novel_class': novel_idx,
            'known_accuracy': known_acc,
            'novel_accuracy': novel_acc,
            'balanced_accuracy': balanced_acc
        })

        print(f"Fold {fold+1} results:")
        print(f"  Known class accuracy: {known_acc:.4f}")
        print(f"  Novel class accuracy: {novel_acc:.4f}")
        print(f"  Balanced accuracy: {balanced_acc:.4f}")

    # Calculate average results
    avg_known_acc = sum(r['known_accuracy'] for r in results) / len(results)
    avg_novel_acc = sum(r['novel_accuracy'] for r in results) / len(results)
    avg_balanced_acc = sum(r['balanced_accuracy'] for r in results) / len(results)

    print("\n=== OpenMax Cross-Validation Summary ===")
    print(f"Average known accuracy: {avg_known_acc:.4f}")
    print(f"Average novel accuracy: {avg_novel_acc:.4f}")
    print(f"Average balanced accuracy: {avg_balanced_acc:.4f}")

    return results

In [9]:
# 1. First, cross-validate to evaluate OpenMax performance
results = openmax_cross_validation(full_dataset, image_preprocessing)

# 2. Then train your final model with OpenMax
model, openmax_model = train_with_openmax(full_dataset)

# 3. Test with OpenMax on the test set
openmax_trainer = OpenMaxTrainer(model, openmax_model, criterion, optimizer, train_loader, val_loader, test_loader, device)
predictions = openmax_trainer.test(save_to_csv=True)

Found superclasses with indices: [0, 1, 2]

=== Fold 1/3: Treating superclass 0 as novel ===
Training standard model...
Epoch 1/5, Loss: 2.9957
Epoch 2/5, Loss: 1.3963
Epoch 3/5, Loss: 0.8281
Epoch 4/5, Loss: 0.5496
Epoch 5/5, Loss: 0.4145

Fitting OpenMax parameters...
Fold 1 results:
  Known class accuracy: 0.7568
  Novel class accuracy: 0.9946
  Balanced accuracy: 0.8757

=== Fold 2/3: Treating superclass 1 as novel ===
Training standard model...
Epoch 1/5, Loss: 3.0070
Epoch 2/5, Loss: 1.3814
Epoch 3/5, Loss: 0.9175
Epoch 4/5, Loss: 0.5548
Epoch 5/5, Loss: 0.3574

Fitting OpenMax parameters...
Fold 2 results:
  Known class accuracy: 0.6556
  Novel class accuracy: 0.9827
  Balanced accuracy: 0.8192

=== Fold 3/3: Treating superclass 2 as novel ===
Training standard model...
Epoch 1/5, Loss: 3.2028
Epoch 2/5, Loss: 1.4771
Epoch 3/5, Loss: 0.8684
Epoch 4/5, Loss: 0.5534
Epoch 5/5, Loss: 0.3015

Fitting OpenMax parameters...
Fold 3 results:
  Known class accuracy: 0.5888
  Novel class 



Training loss: 3.162
Epoch 2/10
Training loss: 1.460
Epoch 3/10
Training loss: 0.947
Epoch 4/10
Training loss: 0.635
Epoch 5/10
Training loss: 0.428
Epoch 6/10
Training loss: 0.327
Epoch 7/10
Training loss: 0.288
Epoch 8/10
Training loss: 0.231
Epoch 9/10
Training loss: 0.189
Epoch 10/10
Training loss: 0.171

Fitting OpenMax parameters...

Training with OpenMax...
Epoch 1/20
Training loss: 1.105
Validation loss: 1.608
Validation superclass acc: 74.09%
Validation subclass acc: 48.49%

Epoch 2/20
Training loss: 0.357
Validation loss: 1.076
Validation superclass acc: 78.06%
Validation subclass acc: 58.51%

Epoch 3/20
Training loss: 0.119
Validation loss: 0.981
Validation superclass acc: 78.38%
Validation subclass acc: 60.41%

Epoch 4/20
Training loss: 0.092
Validation loss: 1.003
Validation superclass acc: 81.88%
Validation subclass acc: 65.98%

Epoch 5/20
Training loss: 0.082
Validation loss: 1.092
Validation superclass acc: 83.31%
Validation subclass acc: 66.30%

Epoch 6/20
Training los

In [None]:
test_predictions = trainer.test(save_to_csv=True, return_predictions=True)

In [None]:
# @title Default title text
# Function for OpenMax cross-validation
def openmax_cross_validation(full_dataset, device='cuda', batch_size=64, epochs=10):
    """
    Run enhanced cross-validation for OpenMax novel detection

    Args:
        full_dataset: The complete dataset
        device: Device to use (cuda or cpu)
        batch_size: Batch size for training
        epochs: Number of training epochs

    Returns:
        Results dictionary with cross-validation metrics
    """
    # Get all unique superclass indices
    superclass_indices = set()
    for i in range(len(full_dataset)):
        _, super_idx, _, _, _ = full_dataset[i]
        if hasattr(super_idx, 'item'):
            super_idx = super_idx.item()
        superclass_indices.add(super_idx)

    superclass_indices = sorted(list(superclass_indices))
    print(f"Found superclasses with indices: {superclass_indices}")

    results = []

    # For each superclass, treat it as novel and others as known
    for fold, novel_idx in enumerate(superclass_indices):
        print(f"\n=== Fold {fold+1}/{len(superclass_indices)}: Treating superclass {novel_idx} as novel ===")

        # Create data splits
        known_indices = []
        novel_indices = []

        for i in range(len(full_dataset)):
            _, super_idx, _, _, _ = full_dataset[i]
            if hasattr(super_idx, 'item'):
                super_idx = super_idx.item()

            if super_idx == novel_idx:
                novel_indices.append(i)
            else:
                known_indices.append(i)

        # Further split known indices into train/validation with stratification
        np.random.shuffle(known_indices)
        train_size = int(0.9 * len(known_indices))
        train_indices = known_indices[:train_size]
        val_known_indices = known_indices[train_size:]

        # Create datasets
        from torch.utils.data import Subset
        train_dataset = Subset(full_dataset, train_indices)
        val_known_dataset = Subset(full_dataset, val_known_indices)
        val_novel_dataset = Subset(full_dataset, novel_indices)

        # Combine known and novel validation sets
        val_indices = val_known_indices + novel_indices
        val_dataset = Subset(full_dataset, val_indices)

        # Create dataloaders
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=4,
            pin_memory=True
        )

        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=4,
            pin_memory=True
        )

        # Separate loaders for evaluation
        val_known_loader = torch.utils.data.DataLoader(
            val_known_dataset,
            batch_size=batch_size,
            shuffle=False,
            pin_memory=True
        )

        val_novel_loader = torch.utils.data.DataLoader(
            val_novel_dataset,
            batch_size=batch_size,
            shuffle=False,
            pin_memory=True
        )

        # Initialize enhanced model
        model = CNN(
            input_size=64,
            num_superclasses=len(superclass_indices)+1
        ).to(device)

        # Use label smoothing for regularization
        criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

        # Use AdamW optimizer with weight decay
        optimizer = optim.AdamW(
            model.parameters(),
            lr=0.001,
            weight_decay=0.0001
        )

        # Train the model
        print("Training model...")
        for epoch in range(epochs):
            running_loss = 0.0
            model.train()

            for i, data in enumerate(train_loader):
                inputs, super_labels, _, sub_labels, _ = data
                inputs = inputs.to(device)
                super_labels = super_labels.to(device)
                sub_labels = sub_labels.to(device)

                optimizer.zero_grad()
                super_outputs, sub_outputs = model(inputs)
                loss = criterion(super_outputs, super_labels) + criterion(sub_outputs, sub_labels)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()

            # Print progress every epoch
            print(f'Epoch {epoch+1}/{epochs}, Loss: {running_loss/(i+1):.4f}')

        # Initialize OpenMax model with adjusted parameters for better novel detection
        num_known_classes = len(superclass_indices) - 1  # Excluding the novel class
        openmax_model = EnhancedOpenMaxModel(
            model,
            num_superclasses=num_known_classes,
            tailsize=30,      # Increased tailsize for better statistics
            alpha=5,          # Reduced alpha to consider fewer top classes
            threshold=0.7     # Adjusted threshold for better balance
        )

        # Fit OpenMax parameters
        print("\nFitting OpenMax parameters...")
        openmax_model.fit(train_loader, device)

        # Evaluate OpenMax on known and novel samples
        model.eval()

        # Test on known classes
        known_correct = 0
        known_total = 0

        with torch.no_grad():
            for data in val_known_loader:
                inputs, super_labels, _, _, _ = data
                inputs = inputs.to(device)
                super_labels = super_labels.to(device)

                # OpenMax prediction
                super_preds, _, _, _ = openmax_model.predict(inputs, device)

                known_total += super_labels.size(0)
                known_correct += (super_preds == super_labels).sum().item()

        # Test on novel classes
        novel_correct = 0
        novel_total = 0

        with torch.no_grad():
            for data in val_novel_loader:
                inputs, _, _, _, _ = data
                inputs = inputs.to(device)

                # OpenMax prediction
                super_preds, _, unknown_probs, _ = openmax_model.predict(inputs, device)

                # For novel classes, prediction should be the novel class index
                novel_total += inputs.size(0)
                novel_correct += (super_preds == num_known_classes).sum().item()

                # Print statistics of unknown probabilities for novel samples
                if novel_total > 0:
                    print(f"Novel samples unknown probability: mean={np.mean(unknown_probs):.4f}, "
                          f"min={np.min(unknown_probs):.4f}, max={np.max(unknown_probs):.4f}")

        known_acc = known_correct / known_total if known_total > 0 else 0
        novel_acc = novel_correct / novel_total if novel_total > 0 else 0
        balanced_acc = (known_acc + novel_acc) / 2

        results.append({
            'fold': fold,
            'novel_class': novel_idx,
            'known_accuracy': known_acc,
            'novel_accuracy': novel_acc,
            'balanced_accuracy': balanced_acc
        })

        print(f"Fold {fold+1} results:")
        print(f"  Known class accuracy: {known_acc:.4f}")
        print(f"  Novel class accuracy: {novel_acc:.4f}")
        print(f"  Balanced accuracy: {balanced_acc:.4f}")

    # Calculate average results
    avg_known_acc = sum(r['known_accuracy'] for r in results) / len(results)
    avg_novel_acc = sum(r['novel_accuracy'] for r in results) / len(results)
    avg_balanced_acc = sum(r['balanced_accuracy'] for r in results) / len(results)

    print("\n=== Enhanced OpenMax Cross-Validation Summary ===")
    print(f"Average known accuracy: {avg_known_acc:.4f}")
    print(f"Average novel accuracy: {avg_novel_acc:.4f}")
    print(f"Average balanced accuracy: {avg_balanced_acc:.4f}")

    return results


# Function to set up and train a complete OpenMax model
def setup_and_train_openmax(train_loader, val_loader, test_loader=None, device='cuda', epochs=15):
    """
    Set up and train a complete OpenMax model

    Args:
        train_loader: Training data loader
        val_loader: Validation data loader
        test_loader: Test data loader (optional)
        device: Device to use (cuda or cpu)
        epochs: Number of training epochs

    Returns:
        Tuple of (model, openmax_model, trainer)
    """
    # Initialize enhanced model
    model = CNN(
        input_size=64,
        num_superclasses=4,
        num_subclasses=88
    ).to(device)

    # Use label smoothing for regularization
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

    # Use AdamW optimizer with weight decay
    optimizer = optim.AdamW(
        model.parameters(),
        lr=0.001,
        weight_decay=0.0001
    )

    # Train the base model first
    print("Training base model...")
    for epoch in range(5):  # Pre-train for 5 epochs
        model.train()
        running_loss = 0.0

        for i, data in enumerate(train_loader):
            inputs, super_labels, _, sub_labels, _ = data
            inputs = inputs.to(device)
            super_labels = super_labels.to(device)
            sub_labels = sub_labels.to(device)

            optimizer.zero_grad()
            super_outputs, sub_outputs = model(inputs)
            loss = criterion(super_outputs, super_labels) + criterion(sub_outputs, sub_labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f'Epoch {epoch+1}/5, Loss: {running_loss/(i+1):.4f}')

    # Initialize OpenMax model (assuming 3 known superclasses, excluding novel)
    openmax_model = EnhancedOpenMaxModel(
        model,
        num_superclasses=3,
        num_subclasses=87,
        tailsize=30,
        alpha=5,
        threshold=0.7
    )

    # Fit OpenMax parameters
    print("\nFitting OpenMax parameters...")
    openmax_model.fit(train_loader, device)

    # Create OpenMax trainer
    trainer = EnhancedOpenMaxTrainer(
        model=model,
        openmax_model=openmax_model,
        criterion=criterion,
        optimizer=optimizer,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        device=device
    )

    # Continue training with OpenMax
    print("\nTraining with OpenMax...")
    best_balanced_acc = 0
    best_state_dict = None

    for epoch in range(epochs):
        print(f'Epoch {epoch+1}/{epochs}')
        trainer.train_epoch()
        metrics = trainer.validate_epoch()

        # Save best model
        if metrics['balanced_acc'] > best_balanced_acc:
            best_balanced_acc = metrics['balanced_acc']
            best_state_dict = model.state_dict().copy()
            print(f"New best model! Balanced accuracy: {best_balanced_acc:.4f}")

    # Load best model if found
    if best_state_dict is not None:
        model.load_state_dict(best_state_dict)

        # Re-fit OpenMax with best model
        openmax_model.fit(train_loader, device)

    return model, openmax_model, trainer


# Example usage:
# model, openmax_model, trainer = setup_and_train_openmax(train_loader, val_loader, test_loader)
# test_predictions = trainer.test(save_to_csv=True)import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
from scipy.spatial.distance import cdist
from scipy.stats import weibull_min
from sklearn.preprocessing import normalize
from collections import defaultdict

class CNN(nn.Module):
    def __init__(self, input_size=64, num_superclasses=4, num_subclasses=88):
        super().__init__()

        self.feature_size = input_size // (2**3)

        # First convolutional block with increased capacity
        self.block1 = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding='same'),  # Increased from 32 to 64
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 64, 3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 64, 3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(2, 2)
        )

        # Second convolutional block with increased capacity
        self.block2 = nn.Sequential(
            nn.Conv2d(64, 128, 3, padding='same'),  # Increased from 64 to 128
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 128, 3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 128, 3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.MaxPool2d(2, 2)
        )

        # Third convolutional block with increased capacity
        self.block3 = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding='same'),  # Increased from 128 to 256
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.Conv2d(256, 256, 3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.Conv2d(256, 256, 3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.MaxPool2d(2, 2)
        )

        # Fully connected layers with dropout
        self.fc1 = nn.Linear(self.feature_size * self.feature_size * 256, 512)  # Increased capacity
        self.dropout1 = nn.Dropout(0.25)
        self.fc2 = nn.Linear(512, 256)
        self.dropout2 = nn.Dropout(0.25)

        # Classification heads
        self.fc3a = nn.Linear(256, num_superclasses)  # Superclass prediction
        self.fc3b = nn.Linear(256, num_subclasses)    # Subclass prediction

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)

        x = torch.flatten(x, 1)  # flatten all dimensions except batch

        x = F.relu(self.fc1(x))
        x = self.dropout1(x)
        x = F.relu(self.fc2(x))
        x = self.dropout2(x)

        super_out = self.fc3a(x)
        sub_out = self.fc3b(x)

        return super_out, sub_out

    def get_features(self, x):
        """Extract features before the final classification layer"""
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)

        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout1(x)
        x = F.relu(self.fc2(x))
        x = self.dropout2(x)

        return x

    def get_logits(self, features):
        """Get class logits from features"""
        super_out = self.fc3a(features)
        sub_out = self.fc3b(features)
        return super_out, sub_out


class EnhancedOpenMaxModel:
    def __init__(self, model, num_superclasses=3, num_subclasses=87, tailsize=20, alpha=6, threshold=0.8):
        """
        Initialize Enhanced OpenMax model with a pre-trained CNN

        Args:
            model: Pre-trained CNN model
            num_superclasses: Number of known superclasses
            num_subclasses: Number of known subclasses
            tailsize: Number of extremal samples to use for Weibull fitting
            alpha: Number of top activations to consider for recalibration
            threshold: Threshold for unknown class probability (higher is more conservative)
        """
        self.model = model
        self.num_superclasses = num_superclasses  # Number of known superclasses
        self.num_subclasses = num_subclasses  # Number of known subclasses
        self.tailsize = tailsize
        self.alpha = alpha
        self.threshold = threshold  # Increased threshold for better novel detection

        # Storage for class means and Weibull models
        self.super_means = None
        self.sub_means = None
        self.super_weibull_models = None
        self.sub_weibull_models = None

        # Statistics to better normalize the data
        self.super_logits_mean = None
        self.super_logits_std = None
        self.sub_logits_mean = None
        self.sub_logits_std = None

        # For MAV (Mean Activation Vector) calibration
        self.super_mavs = None
        self.sub_mavs = None

        # Class distance statistics
        self.super_dist_mean = None
        self.super_dist_std = None
        self.sub_dist_mean = None
        self.sub_dist_std = None

    def fit(self, train_loader, device):
        """
        Fit Weibull distributions to activation vectors for known classes
        """
        self.model.eval()

        # Collect activations for each class
        super_activations = defaultdict(list)
        sub_activations = defaultdict(list)
        all_super_logits = []
        all_sub_logits = []

        # Collect activations for each class
        with torch.no_grad():
            for data in train_loader:
                inputs, super_labels, _, sub_labels, _ = data
                inputs = inputs.to(device)
                super_labels = super_labels.to(device)
                sub_labels = sub_labels.to(device)

                # Get features
                features = self.model.get_features(inputs)
                super_logits, sub_logits = self.model.get_logits(features)

                # Store activations for each class
                for i in range(inputs.size(0)):
                    super_class = super_labels[i].item()
                    sub_class = sub_labels[i].item()

                    # Only store for known classes
                    if super_class < self.num_superclasses:
                        super_activations[super_class].append(super_logits[i].cpu().numpy())
                        all_super_logits.append(super_logits[i].cpu().numpy())

                    if sub_class < self.num_subclasses:
                        sub_activations[sub_class].append(sub_logits[i].cpu().numpy())
                        all_sub_logits.append(sub_logits[i].cpu().numpy())

        # Compute logits statistics for better normalization
        all_super_logits = np.array(all_super_logits)
        all_sub_logits = np.array(all_sub_logits)

        self.super_logits_mean = np.mean(all_super_logits, axis=0)
        self.super_logits_std = np.std(all_super_logits, axis=0) + 1e-6

        self.sub_logits_mean = np.mean(all_sub_logits, axis=0)
        self.sub_logits_std = np.std(all_sub_logits, axis=0) + 1e-6

        # Compute means for each class (MAVs)
        self.super_mavs = {}
        self.sub_mavs = {}

        for c in range(self.num_superclasses):
            if c in super_activations and len(super_activations[c]) > 0:
                self.super_mavs[c] = np.mean(super_activations[c], axis=0)

        for c in range(self.num_subclasses):
            if c in sub_activations and len(sub_activations[c]) > 0:
                self.sub_mavs[c] = np.mean(sub_activations[c], axis=0)

        # Compute distances to mean for Weibull fitting
        super_dists = defaultdict(list)
        sub_dists = defaultdict(list)
        all_super_dists = []
        all_sub_dists = []

        with torch.no_grad():
            for data in train_loader:
                inputs, super_labels, _, sub_labels, _ = data
                inputs = inputs.to(device)
                super_labels = super_labels.to(device)
                sub_labels = sub_labels.to(device)

                # Get features
                features = self.model.get_features(inputs)
                super_logits, sub_logits = self.model.get_logits(features)

                # Compute distances
                for i in range(inputs.size(0)):
                    super_class = super_labels[i].item()
                    sub_class = sub_labels[i].item()

                    # Only process known classes
                    if super_class < self.num_superclasses and super_class in self.super_mavs:
                        super_mav = self.super_mavs[super_class]
                        super_logit = super_logits[i].cpu().numpy()

                        # Normalize logits for better distance calculation
                        norm_super_logit = (super_logit - self.super_logits_mean) / self.super_logits_std
                        norm_super_mav = (super_mav - self.super_logits_mean) / self.super_logits_std

                        # Use cosine distance instead of euclidean for better results
                        super_dist = 1 - np.dot(norm_super_logit, norm_super_mav) / (
                            np.linalg.norm(norm_super_logit) * np.linalg.norm(norm_super_mav) + 1e-10)

                        super_dists[super_class].append(super_dist)
                        all_super_dists.append(super_dist)

                    if sub_class < self.num_subclasses and sub_class in self.sub_mavs:
                        sub_mav = self.sub_mavs[sub_class]
                        sub_logit = sub_logits[i].cpu().numpy()

                        # Normalize logits
                        norm_sub_logit = (sub_logit - self.sub_logits_mean) / self.sub_logits_std
                        norm_sub_mav = (sub_mav - self.sub_logits_mean) / self.sub_logits_std

                        # Use cosine distance
                        sub_dist = 1 - np.dot(norm_sub_logit, norm_sub_mav) / (
                            np.linalg.norm(norm_sub_logit) * np.linalg.norm(norm_sub_mav) + 1e-10)

                        sub_dists[sub_class].append(sub_dist)
                        all_sub_dists.append(sub_dist)

        # Compute global distance statistics
        all_super_dists = np.array(all_super_dists)
        all_sub_dists = np.array(all_sub_dists)

        self.super_dist_mean = np.mean(all_super_dists)
        self.super_dist_std = np.std(all_super_dists) + 1e-6

        self.sub_dist_mean = np.mean(all_sub_dists)
        self.sub_dist_std = np.std(all_sub_dists) + 1e-6

        print(f"Super distance stats: mean={self.super_dist_mean:.4f}, std={self.super_dist_std:.4f}")
        print(f"Sub distance stats: mean={self.sub_dist_mean:.4f}, std={self.sub_dist_std:.4f}")

        # Fit Weibull models
        self.super_weibull_models = {}
        self.sub_weibull_models = {}

        for c in range(self.num_superclasses):
            if c in super_dists and len(super_dists[c]) >= self.tailsize:
                # Sort distances and take tailsize largest
                sorted_dists = sorted(super_dists[c])
                tail_dists = sorted_dists[-self.tailsize:]

                # Fit Weibull distribution
                try:
                    shape, loc, scale = weibull_min.fit(tail_dists, floc=0)
                    self.super_weibull_models[c] = (shape, loc, scale)
                    print(f"Fitted Weibull for superclass {c}: shape={shape:.4f}, scale={scale:.4f}")
                except Exception as e:
                    print(f"Warning: Failed to fit Weibull for superclass {c}: {e}")

        for c in range(self.num_subclasses):
            if c in sub_dists and len(sub_dists[c]) >= self.tailsize:
                # Sort distances and take tailsize largest
                sorted_dists = sorted(sub_dists[c])
                tail_dists = sorted_dists[-self.tailsize:]

                # Fit Weibull distribution
                try:
                    shape, loc, scale = weibull_min.fit(tail_dists, floc=0)
                    self.sub_weibull_models[c] = (shape, loc, scale)
                except Exception as e:
                    print(f"Warning: Failed to fit Weibull for subclass {c}: {e}")

    def predict(self, inputs, device):
        """
        Predict with Enhanced OpenMax recalibration
        """
        self.model.eval()

        with torch.no_grad():
            # Get features and logits
            features = self.model.get_features(inputs)
            super_logits, sub_logits = self.model.get_logits(features)

            # Convert to numpy for processing
            super_logits_np = super_logits.cpu().numpy()
            sub_logits_np = sub_logits.cpu().numpy()

            # Process each sample
            super_preds = []
            sub_preds = []
            super_unknown_probs = []
            sub_unknown_probs = []

            for i in range(inputs.size(0)):
                # Recalibrate superclass logits
                super_logit = super_logits_np[i]
                super_pred, super_unknown_prob = self._recalibrate_sample(
                    super_logit,
                    self.super_mavs,
                    self.super_weibull_models,
                    self.num_superclasses,
                    self.super_logits_mean,
                    self.super_logits_std,
                    self.super_dist_mean,
                    self.super_dist_std
                )
                super_preds.append(super_pred)
                super_unknown_probs.append(super_unknown_prob)

                # Recalibrate subclass logits
                sub_logit = sub_logits_np[i]
                sub_pred, sub_unknown_prob = self._recalibrate_sample(
                    sub_logit,
                    self.sub_mavs,
                    self.sub_weibull_models,
                    self.num_subclasses,
                    self.sub_logits_mean,
                    self.sub_logits_std,
                    self.sub_dist_mean,
                    self.sub_dist_std
                )
                sub_preds.append(sub_pred)
                sub_unknown_probs.append(sub_unknown_prob)

            # Convert back to tensors
            super_preds = torch.tensor(super_preds, device=device)
            sub_preds = torch.tensor(sub_preds, device=device)

        return super_preds, sub_preds, super_unknown_probs, sub_unknown_probs

    def _recalibrate_sample(self, logits, mavs, weibull_models, num_classes,
                           logits_mean, logits_std, dist_mean, dist_std):
        """
        Recalibrate logits for a single sample with enhanced techniques

        Args:
            logits: Logits from the model for a single sample
            mavs: Mean Activation Vectors for each class
            weibull_models: Fitted Weibull models
            num_classes: Number of known classes
            logits_mean: Mean of all logits for normalization
            logits_std: Standard deviation of all logits for normalization
            dist_mean: Mean of all distances for normalization
            dist_std: Standard deviation of all distances for normalization

        Returns:
            Tuple of (predicted class including novel class, unknown probability)
        """
        # Get top alpha class indices
        top_alpha_idx = np.argsort(logits)[-self.alpha:]

        # Normalize logits
        norm_logits = (logits - logits_mean) / logits_std

        # Compute distances to class means and weibull scores
        distances = {}
        weibull_scores = {}

        # Calculate distance to each known class
        for c in range(num_classes):
            if c in mavs:
                mav = mavs[c]
                norm_mav = (mav - logits_mean) / logits_std

                # Use cosine distance
                dist = 1 - np.dot(norm_logits, norm_mav) / (
                    np.linalg.norm(norm_logits) * np.linalg.norm(norm_mav) + 1e-10)

                # Normalize distance
                norm_dist = (dist - dist_mean) / dist_std
                distances[c] = norm_dist

                # Calculate weibull score (probability of being an outlier)
                if c in weibull_models:
                    shape, loc, scale = weibull_models[c]
                    try:
                        # Higher score means more likely to be an outlier
                        weibull_scores[c] = 1 - weibull_min.cdf(dist, shape, loc, scale)
                    except:
                        weibull_scores[c] = 0.5  # Default if calculation fails

        # Recalibrate activations using Weibull scores
        recalibrated = np.copy(logits)

        # Weight for outlier evidence
        evidence_weights = []

        for c in top_alpha_idx:
            if c < num_classes and c in weibull_models:
                # Get weibull score
                w_score = weibull_scores.get(c, 0.5)

                # Adjust activation - more reduction for higher weibull scores
                recalibrated[c] = logits[c] * (1 - w_score)

                # Collect evidence for being an outlier
                evidence_weights.append(w_score)

        # Apply softmax to get probabilities (excluding unknown)
        # Add small epsilon to prevent underflow
        recalibrated_exp = np.exp(recalibrated - np.max(recalibrated))
        recalibrated_probs = recalibrated_exp / (np.sum(recalibrated_exp) + 1e-10)

        # Compute probability of being unknown using average of evidence weights
        # Higher values of evidence_weights indicate stronger evidence of being novel
        if evidence_weights:
            avg_evidence = np.mean(evidence_weights)

            # Apply sigmoid scaling to convert evidence to probability
            unknown_prob = 1.0 / (1.0 + np.exp(-10 * (avg_evidence - 0.5)))
        else:
            unknown_prob = 0.5  # Default if no evidence is available

        # Make final prediction (including novel class)
        if unknown_prob > self.threshold:
            return num_classes, unknown_prob  # Return novel class index
        else:
            return np.argmax(recalibrated_probs), unknown_prob


class EnhancedOpenMaxTrainer:
    def __init__(self, model, openmax_model, criterion, optimizer, train_loader, val_loader, test_loader=None, device='cuda'):
        self.model = model
        self.openmax_model = openmax_model
        self.criterion = criterion
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.device = device

        # Add scheduler for better training
        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=10, eta_min=1e-6
        )

        # For center loss
        self.centers = None
        self.center_loss_weight = 0.05

    def _init_centers(self, num_classes, feature_dim):
        """Initialize class centers for center loss"""
        self.centers = torch.zeros(num_classes, feature_dim, device=self.device)

    def _update_centers(self, features, labels, alpha=0.1):
        """Update class centers based on current batch"""
        if self.centers is None:
            # Initialize centers if not done already
            self._init_centers(4, features.size(1))  # Assuming 4 superclasses

        # Create a mask for each class in the batch
        batch_size = features.size(0)
        unique_labels = torch.unique(labels)

        for label in unique_labels:
            if label >= self.centers.size(0):
                continue  # Skip if label is out of range

            mask = (labels == label).float()
            mask_sum = mask.sum()

            if mask_sum > 0:
                # Compute mean features for this class in the batch
                class_features = torch.sum(features * mask.unsqueeze(1), dim=0) / mask_sum

                # Update center with moving average
                self.centers[label] = alpha * class_features + (1 - alpha) * self.centers[label]

    def _center_loss(self, features, labels):
        """Calculate center loss"""
        if self.centers is None:
            return 0.0

        batch_size = features.size(0)

        # Gather features by labels
        centers_batch = self.centers[labels]

        # Calculate center loss
        return 0.5 * torch.sum(torch.pow(features - centers_batch, 2)) / batch_size

    def train_epoch(self):
        self.model.train()
        running_loss = 0.0

        for i, data in enumerate(self.train_loader):
            inputs, super_labels, _, sub_labels, _ = data
            inputs = inputs.to(self.device)
            super_labels = super_labels.to(self.device)
            sub_labels = sub_labels.to(self.device)

            # Forward pass
            features = self.model.get_features(inputs)
            super_outputs, sub_outputs = self.model.get_logits(features)

            # Standard cross-entropy loss
            ce_loss = self.criterion(super_outputs, super_labels) + self.criterion(sub_outputs, sub_labels)

            # Center loss for better feature separation
            c_loss = self._center_loss(features, super_labels)

            # Combined loss
            loss = ce_loss + self.center_loss_weight * c_loss

            # Backward and optimize
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()

            # Update center representations
            with torch.no_grad():
                self._update_centers(features.detach(), super_labels)

            running_loss += loss.item()

        avg_loss = running_loss / (i+1)
        print(f'Training loss: {avg_loss:.4f}')

        # Step the scheduler
        self.scheduler.step()

        return avg_loss

    def validate_epoch(self):
        self.model.eval()

        super_correct = 0
        sub_correct = 0
        total = 0
        running_loss = 0.0

        # For separate known/novel evaluation
        novel_total = 0
        known_total = 0
        novel_correct = 0
        known_correct = 0

        all_super_unknown_probs = []
        all_is_novel = []

        with torch.no_grad():
            for i, data in enumerate(self.val_loader):
                inputs, super_labels, _, sub_labels, _ = data
                inputs = inputs.to(self.device)
                super_labels = super_labels.to(self.device)
                sub_labels = sub_labels.to(self.device)

                # Get features for loss calculation
                features = self.model.get_features(inputs)
                super_outputs, sub_outputs = self.model.get_logits(features)

                # Standard cross-entropy loss
                ce_loss = self.criterion(super_outputs, super_labels) + self.criterion(sub_outputs, sub_labels)

                # Center loss
                c_loss = self._center_loss(features, super_labels)

                # Combined loss
                loss = ce_loss + self.center_loss_weight * c_loss

                # OpenMax prediction
                super_preds, sub_preds, super_unknown_probs, _ = self.openmax_model.predict(inputs, self.device)

                # Track overall accuracy
                total += super_labels.size(0)
                super_correct += (super_preds == super_labels).sum().item()
                sub_correct += (sub_preds == sub_labels).sum().item()

                # Track novel vs known separately
                is_novel = super_labels >= self.openmax_model.num_superclasses
                novel_total += is_novel.sum().item()
                known_total += (~is_novel).sum().item()

                # A novel sample is correctly classified if detected as novel
                novel_correct += (is_novel & (super_preds == self.openmax_model.num_superclasses)).sum().item()

                # A known sample is correctly classified if prediction matches label
                known_correct += ((~is_novel) & (super_preds == super_labels)).sum().item()

                # Store probabilities and labels for threshold analysis
                all_super_unknown_probs.extend(super_unknown_probs)
                all_is_novel.extend(is_novel.cpu().numpy())

                running_loss += loss.item()

        # Calculate metrics
        avg_loss = running_loss / (i+1)
        super_acc = 100 * super_correct / total if total > 0 else 0
        sub_acc = 100 * sub_correct / total if total > 0 else 0

        novel_acc = 100 * novel_correct / novel_total if novel_total > 0 else 0
        known_acc = 100 * known_correct / known_total if known_total > 0 else 0
        balanced_acc = (novel_acc + known_acc) / 2 if novel_total > 0 and known_total > 0 else 0

        print(f'Validation loss: {avg_loss:.4f}')
        print(f'Validation superclass acc: {super_acc:.2f}%, subclass acc: {sub_acc:.2f}%')

        if novel_total > 0:
            print(f'Novel acc: {novel_acc:.2f}%, Known acc: {known_acc:.2f}%, Balanced acc: {balanced_acc:.2f}%')

        # Return metrics dictionary
        return {
            'loss': avg_loss,
            'super_acc': super_acc,
            'sub_acc': sub_acc,
            'novel_acc': novel_acc,
            'known_acc': known_acc,
            'balanced_acc': balanced_acc
        }

    def test(self, save_to_csv=False, return_predictions=False, output_file='openmax_test_predictions.csv'):
        if not self.test_loader:
            raise ValueError('test_loader not specified')

        self.model.eval()

        # Evaluate on test set with OpenMax
        test_predictions = {
            'image': [],
            'superclass_index': [],
            'subclass_index': [],
            'unknown_probability': []
        }

        with torch.no_grad():
            for i, data in enumerate(self.test_loader):
                inputs, img_name = data[0].to(self.device), data[1]

                # OpenMax prediction
                super_preds, sub_preds, super_unknown_probs, sub_unknown_probs = self.openmax_model.predict(inputs, self.device)

                for j in range(inputs.size(0)):
                    img = img_name[j] if isinstance(img_name, list) else img_name[0]

                    test_predictions['image'].append(img)
                    test_predictions['superclass_index'].append(super_preds[j].item())
                    test_predictions['subclass_index'].append(sub_preds[j].item())
                    test_predictions['unknown_probability'].append(super_unknown_probs[j])

        # Create full DataFrame with all information
        full_predictions_df = pd.DataFrame(data=test_predictions)

        # Create simplified DataFrame for output (compatible with original code)
        simplified_predictions = {
            'image': test_predictions['image'],
            'superclass_index': test_predictions['superclass_index'],
            'subclass_index': test_predictions['subclass_index']
        }
        simplified_predictions_df = pd.DataFrame(data=simplified_predictions)

        # Print summary of novel predictions
        novel_super_count = sum(1 for idx in test_predictions['superclass_index'] if idx == self.openmax_model.num_superclasses)
        novel_sub_count = sum(1 for idx in test_predictions['subclass_index'] if idx == self.openmax_model.num_subclasses)

        total_count = len(test_predictions['image'])
        novel_super_perc = 100 * novel_super_count / total_count if total_count > 0 else 0
        novel_sub_perc = 100 * novel_sub_count / total_count if total_count > 0 else 0

        print(f'Test set predictions:')
        print(f'Images predicted as novel superclass: {novel_super_count} ({novel_super_perc:.2f}%)')
        print(f'Images predicted as novel subclass: {novel_sub_count} ({novel_sub_perc:.2f}%)')

        # Also print unknown probability statistics
        unknown_probs = np.array(test_predictions['unknown_probability'])
        print(f'Unknown probability statistics:')
        print(f'  Mean: {np.mean(unknown_probs):.4f}')
        print(f'  Std: {np.std(unknown_probs):.4f}')
        print(f'  Min: {np.min(unknown_probs):.4f}')
        print(f'  Max: {np.max(unknown_probs):.4f}')

        if save_to_csv:
            # Save in the same format as the original code
            simplified_predictions_df.to_csv(output_file, index=False)
            print(f"Predictions saved to '{output_file}'")

            # # Save detailed version with probabilities
            # full_predictions_df.to_csv(f'detailed_{output_file}', index=False)
            # print(f"Detailed predictions saved to 'detailed_{output_file}'")

        if return_predictions:
            # Return the predictions for further analysis
            return full_predictions_df

In [None]:
results = openmax_cross_validation(full_dataset, device='cuda', batch_size=64, epochs=10)
model, openmax_model, trainer = setup_and_train_openmax(train_loader, val_loader, test_loader)
test_predictions = trainer.test(save_to_csv=True)

Found superclasses with indices: [0, 1, 2]

=== Fold 1/3: Treating superclass 0 as novel ===
Training model...
Epoch 1/10, Loss: 4.4512
Epoch 2/10, Loss: 3.2227
Epoch 3/10, Loss: 2.7819
Epoch 4/10, Loss: 2.4416
Epoch 5/10, Loss: 2.2537
Epoch 6/10, Loss: 2.0857
Epoch 7/10, Loss: 1.9910
Epoch 8/10, Loss: 1.8896
Epoch 9/10, Loss: 1.8026
Epoch 10/10, Loss: 1.7324

Fitting OpenMax parameters...
Super distance stats: mean=1.0000, std=0.0000
Sub distance stats: mean=0.1782, std=0.1414
Fitted Weibull for superclass 1: shape=2886523556.1939, scale=1.0000
Novel samples unknown probability: mean=0.2574, min=0.2574, max=0.2574
Novel samples unknown probability: mean=0.2574, min=0.2574, max=0.2574
Novel samples unknown probability: mean=0.2574, min=0.2574, max=0.2574
Novel samples unknown probability: mean=0.2574, min=0.2574, max=0.2574
Novel samples unknown probability: mean=0.2574, min=0.2574, max=0.2574
Novel samples unknown probability: mean=0.2574, min=0.2574, max=0.2574
Novel samples unknown 

In [None]:
# @title Default title text
#BEST OPEN MAX ENHANCED MODEL
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
from scipy.spatial.distance import cdist
from scipy.stats import weibull_min
from sklearn.preprocessing import normalize
from collections import defaultdict

class OptimizedCNN(nn.Module):
    def __init__(self, input_size=64, num_superclasses=4, num_subclasses=88):
        super().__init__()

        self.feature_size = input_size // (2**3)

        # First convolutional block with increased capacity
        self.block1 = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 64, 3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 64, 3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(2, 2)
        )

        # Second convolutional block with increased capacity
        self.block2 = nn.Sequential(
            nn.Conv2d(64, 128, 3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 128, 3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 128, 3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.MaxPool2d(2, 2)
        )

        # Third convolutional block with increased capacity
        self.block3 = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.Conv2d(256, 256, 3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.Conv2d(256, 256, 3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.MaxPool2d(2, 2)
        )

        # Fully connected layers with dropout
        self.fc1 = nn.Linear(self.feature_size * self.feature_size * 256, 512)
        self.bn1 = nn.BatchNorm1d(512)  # Added BatchNorm after FC
        self.dropout1 = nn.Dropout(0.3)  # Increased dropout rate
        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.dropout2 = nn.Dropout(0.3)

        # Classification heads
        self.fc3a = nn.Linear(256, num_superclasses)
        self.fc3b = nn.Linear(256, num_subclasses)

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)

        x = torch.flatten(x, 1)

        x = F.relu(self.bn1(self.fc1(x)))
        x = self.dropout1(x)
        x = F.relu(self.bn2(self.fc2(x)))
        x = self.dropout2(x)

        super_out = self.fc3a(x)
        sub_out = self.fc3b(x)

        return super_out, sub_out

    def get_features(self, x):
        """Extract features before the final classification layer"""
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)

        x = torch.flatten(x, 1)
        x = F.relu(self.bn1(self.fc1(x)))
        x = self.dropout1(x)
        x = F.relu(self.bn2(self.fc2(x)))
        x = self.dropout2(x)

        return x

    def get_logits(self, features):
        """Get class logits from features"""
        super_out = self.fc3a(features)
        sub_out = self.fc3b(features)
        return super_out, sub_out


class OptimizedOpenMaxModel:
    def __init__(self, model, num_superclasses=3, num_subclasses=87, tailsize=40, alpha=4,
                 threshold=0.6, distance_multiplier=1.5):

        self.model = model
        self.num_superclasses = num_superclasses
        self.num_subclasses = num_subclasses
        self.tailsize = tailsize
        self.alpha = alpha
        self.threshold = threshold
        self.distance_multiplier = distance_multiplier

        # Storage for class means and Weibull models
        self.super_mavs = None  # Mean Activation Vectors
        self.sub_mavs = None
        self.super_weibull_models = None
        self.sub_weibull_models = None

        # Statistics for normalization
        self.super_logits_mean = None
        self.super_logits_std = None
        self.sub_logits_mean = None
        self.sub_logits_std = None
        self.super_dist_mean = None
        self.super_dist_std = None
        self.sub_dist_mean = None
        self.sub_dist_std = None

    def fit(self, train_loader, device):
        """
        Fit Weibull distributions to activation vectors for known classes
        with improved normalization and calibration
        """
        self.model.eval()
        print("\nFitting OpenMax parameters with distance_multiplier =", self.distance_multiplier)

        # Collect activations for each class
        super_activations = defaultdict(list)
        sub_activations = defaultdict(list)
        all_super_logits = []
        all_sub_logits = []

        # Collect activations for each class
        print("Collecting class activations...")
        with torch.no_grad():
            for data in train_loader:
                inputs, super_labels, _, sub_labels, _ = data
                inputs = inputs.to(device)
                super_labels = super_labels.to(device)
                sub_labels = sub_labels.to(device)

                # Get features
                features = self.model.get_features(inputs)
                super_logits, sub_logits = self.model.get_logits(features)

                # Store activations for each class
                for i in range(inputs.size(0)):
                    super_class = super_labels[i].item()
                    sub_class = sub_labels[i].item()

                    # Only store for known classes
                    if super_class < self.num_superclasses:
                        super_activations[super_class].append(super_logits[i].cpu().numpy())
                        all_super_logits.append(super_logits[i].cpu().numpy())

                    if sub_class < self.num_subclasses:
                        sub_activations[sub_class].append(sub_logits[i].cpu().numpy())
                        all_sub_logits.append(sub_logits[i].cpu().numpy())

        # Compute logits statistics for normalization
        all_super_logits = np.array(all_super_logits)
        all_sub_logits = np.array(all_sub_logits)

        self.super_logits_mean = np.mean(all_super_logits, axis=0)
        self.super_logits_std = np.std(all_super_logits, axis=0) + 1e-6

        self.sub_logits_mean = np.mean(all_sub_logits, axis=0)
        self.sub_logits_std = np.std(all_sub_logits, axis=0) + 1e-6

        # Compute Mean Activation Vectors (MAVs)
        self.super_mavs = {}
        self.sub_mavs = {}

        for c in range(self.num_superclasses):
            if c in super_activations and len(super_activations[c]) > 0:
                self.super_mavs[c] = np.mean(super_activations[c], axis=0)
                print(f"Class {c} MAV stats: mean={np.mean(self.super_mavs[c]):.4f}, std={np.std(self.super_mavs[c]):.4f}")

        for c in range(self.num_subclasses):
            if c in sub_activations and len(sub_activations[c]) > 0:
                self.sub_mavs[c] = np.mean(sub_activations[c], axis=0)

        # Compute distances to mean for Weibull fitting
        print("Computing distances for Weibull fitting...")
        super_dists = defaultdict(list)
        sub_dists = defaultdict(list)
        all_super_dists = []
        all_sub_dists = []

        with torch.no_grad():
            for data in train_loader:
                inputs, super_labels, _, sub_labels, _ = data
                inputs = inputs.to(device)
                super_labels = super_labels.to(device)
                sub_labels = sub_labels.to(device)

                # Get features
                features = self.model.get_features(inputs)
                super_logits, sub_logits = self.model.get_logits(features)

                # Compute distances
                for i in range(inputs.size(0)):
                    super_class = super_labels[i].item()
                    sub_class = sub_labels[i].item()

                    # Only process known classes
                    if super_class < self.num_superclasses and super_class in self.super_mavs:
                        super_mav = self.super_mavs[super_class]
                        super_logit = super_logits[i].cpu().numpy()

                        # Normalize logits for better distance calculation
                        norm_super_logit = (super_logit - self.super_logits_mean) / self.super_logits_std
                        norm_super_mav = (super_mav - self.super_logits_mean) / self.super_logits_std

                        # Use cosine distance for better performance
                        super_dist = 1 - np.dot(norm_super_logit, norm_super_mav) / (
                            np.linalg.norm(norm_super_logit) * np.linalg.norm(norm_super_mav) + 1e-10)

                        # Apply distance multiplier
                        super_dist = super_dist * self.distance_multiplier

                        super_dists[super_class].append(super_dist)
                        all_super_dists.append(super_dist)

                    if sub_class < self.num_subclasses and sub_class in self.sub_mavs:
                        sub_mav = self.sub_mavs[sub_class]
                        sub_logit = sub_logits[i].cpu().numpy()

                        # Normalize logits
                        norm_sub_logit = (sub_logit - self.sub_logits_mean) / self.sub_logits_std
                        norm_sub_mav = (sub_mav - self.sub_logits_mean) / self.sub_logits_std

                        # Use cosine distance
                        sub_dist = 1 - np.dot(norm_sub_logit, norm_sub_mav) / (
                            np.linalg.norm(norm_sub_logit) * np.linalg.norm(norm_sub_mav) + 1e-10)

                        sub_dists[sub_class].append(sub_dist)
                        all_sub_dists.append(sub_dist)

        # Compute global distance statistics
        all_super_dists = np.array(all_super_dists)
        all_sub_dists = np.array(all_sub_dists)

        self.super_dist_mean = np.mean(all_super_dists)
        self.super_dist_std = np.std(all_super_dists) + 1e-6

        self.sub_dist_mean = np.mean(all_sub_dists)
        self.sub_dist_std = np.std(all_sub_dists) + 1e-6

        print(f"Super distance stats: mean={self.super_dist_mean:.4f}, std={self.super_dist_std:.4f}")
        print(f"Sub distance stats: mean={self.sub_dist_mean:.4f}, std={self.sub_dist_std:.4f}")

        # Fit Weibull models with improved robustness
        self.super_weibull_models = {}
        self.sub_weibull_models = {}

        print("Fitting Weibull distributions...")
        for c in range(self.num_superclasses):
            if c in super_dists and len(super_dists[c]) >= self.tailsize:
                # Sort distances and take tailsize largest
                sorted_dists = sorted(super_dists[c])
                tail_dists = sorted_dists[-self.tailsize:]

                # Fit Weibull with multiple methods for robustness
                self.super_weibull_models[c] = self._robust_weibull_fit(tail_dists)
                shape, loc, scale = self.super_weibull_models[c]
                print(f"Fitted Weibull for superclass {c}: shape={shape:.4f}, scale={scale:.4f}")

        for c in range(self.num_subclasses):
            if c in sub_dists and len(sub_dists[c]) >= self.tailsize:
                # Sort distances and take tailsize largest
                sorted_dists = sorted(sub_dists[c])
                tail_dists = sorted_dists[-self.tailsize:]

                # Fit Weibull distribution
                self.sub_weibull_models[c] = self._robust_weibull_fit(tail_dists)

    def _robust_weibull_fit(self, distances):
        """
        Fit Weibull distribution with multiple methods and fallbacks for robustness
        """
        # Try multiple fitting methods
        try:
            # First try MLE with fixed location parameter
            shape, loc, scale = weibull_min.fit(distances, floc=0)

            # Check if shape is reasonable (not too extreme)
            if 0.1 <= shape <= 20:
                return shape, loc, scale

            # If shape is extreme, try moment matching
            shape, loc, scale = weibull_min.fit(distances, floc=0, method='mm')
            return shape, loc, scale

        except Exception as e:
            # Fallback to reasonable defaults
            print(f"Warning: Weibull fitting failed, using defaults: {e}")
            # Use empirical mean and standard deviation for defaults
            mean_dist = np.mean(distances)
            std_dist = np.std(distances)
            # Approximate Weibull parameters
            shape = 2.0  # Reasonable default shape
            scale = mean_dist # Scale based on mean
            return shape, 0, scale

    def calibrate_threshold(self, val_loader, device):
        """
        Calibrate the threshold for better novelty detection using validation data
        """
        print("\nCalibrating novelty detection threshold...")
        self.model.eval()

        # Collect validation data
        all_probs = []
        is_novel = []

        with torch.no_grad():
            for data in val_loader:
                inputs, super_labels, _, _, _ = data
                inputs = inputs.to(device)
                super_labels = super_labels.to(device)

                # Get predictions
                _, _, unknown_probs, _ = self.predict(inputs, device)

                # Store results
                novel_labels = super_labels >= self.num_superclasses
                all_probs.extend(unknown_probs)
                is_novel.extend(novel_labels.cpu().numpy())

        # Find optimal threshold
        thresholds = np.linspace(0.1, 0.9, 17)  # 0.1, 0.15, 0.2, ..., 0.9
        best_threshold = 0.5
        best_balanced_acc = 0

        for threshold in thresholds:
            # Make predictions at this threshold
            pred_novel = np.array(all_probs) > threshold
            true_novel = np.array(is_novel)

            # Calculate metrics
            novel_correct = np.sum(pred_novel & true_novel)
            novel_total = np.sum(true_novel)
            novel_acc = novel_correct / novel_total if novel_total > 0 else 0

            known_correct = np.sum((~pred_novel) & (~true_novel))
            known_total = np.sum(~true_novel)
            known_acc = known_correct / known_total if known_total > 0 else 0

            balanced_acc = (novel_acc + known_acc) / 2

            print(f"Threshold {threshold:.2f}: Known={known_acc:.4f}, Novel={novel_acc:.4f}, Balanced={balanced_acc:.4f}")

            if balanced_acc > best_balanced_acc:
                best_balanced_acc = balanced_acc
                best_threshold = threshold

        # Set the calibrated threshold
        old_threshold = self.threshold
        self.threshold = best_threshold
        print(f"Updated threshold from {old_threshold:.2f} to {best_threshold:.2f}, Balanced accuracy: {best_balanced_acc:.4f}")

        return best_threshold, best_balanced_acc

    def predict(self, inputs, device):
        """
        Predict with Enhanced OpenMax recalibration
        """
        self.model.eval()

        with torch.no_grad():
            # Get features and logits
            features = self.model.get_features(inputs)
            super_logits, sub_logits = self.model.get_logits(features)

            # Convert to numpy for processing
            super_logits_np = super_logits.cpu().numpy()
            sub_logits_np = sub_logits.cpu().numpy()

            # Process each sample
            super_preds = []
            sub_preds = []
            super_unknown_probs = []
            sub_unknown_probs = []

            for i in range(inputs.size(0)):
                # Recalibrate superclass logits
                super_logit = super_logits_np[i]
                super_pred, super_unknown_prob = self._recalibrate_sample(
                    super_logit,
                    self.super_mavs,
                    self.super_weibull_models,
                    self.num_superclasses,
                    self.super_logits_mean,
                    self.super_logits_std,
                    self.super_dist_mean,
                    self.super_dist_std
                )
                super_preds.append(super_pred)
                super_unknown_probs.append(super_unknown_prob)

                # Recalibrate subclass logits
                sub_logit = sub_logits_np[i]
                sub_pred, sub_unknown_prob = self._recalibrate_sample(
                    sub_logit,
                    self.sub_mavs,
                    self.sub_weibull_models,
                    self.num_subclasses,
                    self.sub_logits_mean,
                    self.sub_logits_std,
                    self.sub_dist_mean,
                    self.sub_dist_std
                )
                sub_preds.append(sub_pred)
                sub_unknown_probs.append(sub_unknown_prob)

            # Convert back to tensors
            super_preds = torch.tensor(super_preds, device=device)
            sub_preds = torch.tensor(sub_preds, device=device)

        return super_preds, sub_preds, super_unknown_probs, sub_unknown_probs

    def _recalibrate_sample(self, logits, mavs, weibull_models, num_classes,
                           logits_mean, logits_std, dist_mean, dist_std):
        """
        Recalibrate logits for a single sample with improved techniques
        """
        # Get top alpha class indices
        top_alpha_idx = np.argsort(logits)[-self.alpha:]

        # Normalize logits for better distance calculation
        norm_logits = (logits - logits_mean) / logits_std
        norm_logits_unit = norm_logits / (np.linalg.norm(norm_logits) + 1e-10)  # Unit normalize

        # Compute distances to class means and weibull scores
        distances = {}
        weibull_scores = {}

        # Calculate distance to each known class
        for c in range(num_classes):
            if c in mavs:
                mav = mavs[c]
                # Normalize MAV the same way as logits
                norm_mav = (mav - logits_mean) / logits_std
                norm_mav_unit = norm_mav / (np.linalg.norm(norm_mav) + 1e-10)

                # Use cosine distance
                dist = 1 - np.dot(norm_logits_unit, norm_mav_unit)

                # Apply distance multiplier
                dist = dist * self.distance_multiplier

                # Store distance
                distances[c] = dist

                # Calculate weibull score if model exists
                if c in weibull_models:
                    shape, loc, scale = weibull_models[c]
                    try:
                        # Higher score means more likely to be an outlier
                        weibull_scores[c] = 1 - weibull_min.cdf(dist, shape, loc, scale)
                    except:
                        weibull_scores[c] = 0.5  # Default if calculation fails

        # Recalibrate activations using Weibull scores
        recalibrated = np.copy(logits)

        # Weight for outlier evidence
        evidence_weights = []

        for c in top_alpha_idx:
            if c < num_classes and c in weibull_models:
                # Get weibull score
                w_score = weibull_scores.get(c, 0.5)

                # Adjust activation - more reduction for higher weibull scores
                recalibrated[c] = logits[c] * (1 - w_score)

                # Collect evidence for being an outlier
                evidence_weights.append(w_score)

        # Apply softmax to get probabilities (excluding unknown)
        recalibrated_exp = np.exp(recalibrated - np.max(recalibrated))
        recalibrated_probs = recalibrated_exp / (np.sum(recalibrated_exp) + 1e-10)

        # Compute probability of being unknown using average of evidence weights
        if evidence_weights:
            avg_evidence = np.mean(evidence_weights)

            # Apply sigmoid scaling to convert evidence to probability
            unknown_prob = 1.0 / (1.0 + np.exp(-10 * (avg_evidence - 0.5)))
        else:
            unknown_prob = 0.5  # Default if no evidence is available

        # Make final prediction (including novel class)
        if unknown_prob > self.threshold:
            return num_classes, unknown_prob  # Return novel class index
        else:
            return np.argmax(recalibrated_probs), unknown_prob


class OptimizedOpenMaxTrainer:
    def __init__(self, model, openmax_model, criterion, optimizer, train_loader, val_loader,
                 test_loader=None, device='cuda', center_loss_weight=0.0005):
        self.model = model
        self.openmax_model = openmax_model
        self.criterion = criterion
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.device = device

        # Critical fix: Reduce center loss weight dramatically (from 0.05 to 0.0005)
        self.center_loss_weight = center_loss_weight

        # Add scheduler for better convergence
        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=10, eta_min=1e-6
        )

        # For center loss
        self.centers = None

        # For tracking history
        self.history = {
            'train_loss': [],
            'val_loss': [],
            'super_acc': [],
            'sub_acc': [],
            'novel_acc': [],
            'known_acc': [],
            'balanced_acc': []
        }

    def _init_centers(self, num_classes, feature_dim):
        """Initialize class centers for center loss"""
        self.centers = torch.zeros(num_classes, feature_dim, device=self.device)

    def _update_centers(self, features, labels, alpha=0.1):
        """Update class centers based on current batch"""
        if self.centers is None:
            # Initialize centers if not done already
            self._init_centers(4, features.size(1))  # Assuming 4 superclasses

        # Create a mask for each class in the batch
        batch_size = features.size(0)
        unique_labels = torch.unique(labels)

        for label in unique_labels:
            if label >= self.centers.size(0):
                continue  # Skip if label is out of range

            mask = (labels == label).float()
            mask_sum = mask.sum()

            if mask_sum > 0:
                # Compute mean features for this class in the batch
                class_features = torch.sum(features * mask.unsqueeze(1), dim=0) / mask_sum

                # Update center with moving average
                self.centers[label] = alpha * class_features + (1 - alpha) * self.centers[label]

    def _center_loss(self, features, labels):
        """Calculate center loss"""
        if self.centers is None:
            # Return a tensor with 0 value, not a Python float
            return torch.tensor(0.0, device=self.device)

        batch_size = features.size(0)

        # Gather features by labels
        centers_batch = self.centers[labels]

        # Calculate center loss
        return 0.5 * torch.sum(torch.pow(features - centers_batch, 2)) / batch_size

    def train_epoch(self):
        """Train the model for one epoch with improved techniques"""
        self.model.train()
        running_loss = 0.0
        running_ce_loss = 0.0
        running_center_loss = 0.0

        for i, data in enumerate(self.train_loader):
            inputs, super_labels, _, sub_labels, _ = data
            inputs = inputs.to(self.device)
            super_labels = super_labels.to(self.device)
            sub_labels = sub_labels.to(self.device)

            # Forward pass with feature extraction
            features = self.model.get_features(inputs)
            super_outputs, sub_outputs = self.model.get_logits(features)

            # Calculate standard cross-entropy loss
            ce_loss = self.criterion(super_outputs, super_labels) + self.criterion(sub_outputs, sub_labels)

            # Calculate center loss
            c_loss = self._center_loss(features, super_labels)

            # Combined loss
            loss = ce_loss + self.center_loss_weight * c_loss

            # Backward and optimize
            self.optimizer.zero_grad()
            loss.backward()

            # Gradient clipping to prevent exploding gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

            self.optimizer.step()

            # Update center representations
            with torch.no_grad():
                self._update_centers(features.detach(), super_labels)

            running_loss += loss.item()
            running_ce_loss += ce_loss.item()
            running_center_loss += c_loss.item()

        avg_loss = running_loss / (i+1)
        avg_ce_loss = running_ce_loss / (i+1)
        avg_center_loss = running_center_loss / (i+1)

        print(f'Training loss: {avg_loss:.4f} (CE: {avg_ce_loss:.4f}, Center: {avg_center_loss*self.center_loss_weight:.4f})')

        # Step the scheduler if available
        if self.scheduler is not None:
            self.scheduler.step()

        # Update history
        self.history['train_loss'].append(avg_loss)

        return avg_loss

    def validate_epoch(self):
        """Validate with improved metrics tracking"""
        if self.openmax_model is None:
            return self._validate_without_openmax()

        self.model.eval()

        super_correct = 0
        sub_correct = 0
        total = 0
        running_loss = 0.0

        # For separate known/novel evaluation
        novel_total = 0
        known_total = 0
        novel_correct = 0
        known_correct = 0

        all_super_unknown_probs = []
        all_is_novel = []

        with torch.no_grad():
            for i, data in enumerate(self.val_loader):
                inputs, super_labels, _, sub_labels, _ = data
                inputs = inputs.to(self.device)
                super_labels = super_labels.to(self.device)
                sub_labels = sub_labels.to(self.device)

                # Get features for loss calculation
                features = self.model.get_features(inputs)
                super_outputs, sub_outputs = self.model.get_logits(features)

                # Standard cross-entropy loss
                ce_loss = self.criterion(super_outputs, super_labels) + self.criterion(sub_outputs, sub_labels)

                # Center loss
                c_loss = self._center_loss(features, super_labels)

                # Combined loss
                loss = ce_loss + self.center_loss_weight * c_loss

                # OpenMax prediction
                super_preds, sub_preds, super_unknown_probs, _ = self.openmax_model.predict(inputs, self.device)

                # Track overall accuracy
                total += super_labels.size(0)
                super_correct += (super_preds == super_labels).sum().item()
                sub_correct += (sub_preds == sub_labels).sum().item()

                # Track novel vs known separately
                is_novel = super_labels >= self.openmax_model.num_superclasses
                novel_total += is_novel.sum().item()
                known_total += (~is_novel).sum().item()

                # A novel sample is correctly classified if detected as novel
                novel_correct += (is_novel & (super_preds == self.openmax_model.num_superclasses)).sum().item()

                # A known sample is correctly classified if prediction matches label
                known_correct += ((~is_novel) & (super_preds == super_labels)).sum().item()

                # Store probabilities and labels for threshold analysis
                all_super_unknown_probs.extend(super_unknown_probs)
                all_is_novel.extend(is_novel.cpu().numpy())

                running_loss += loss.item()

        # Calculate metrics
        avg_loss = running_loss / (i+1)
        super_acc = 100 * super_correct / total if total > 0 else 0
        sub_acc = 100 * sub_correct / total if total > 0 else 0

        novel_acc = 100 * novel_correct / novel_total if novel_total > 0 else 0
        known_acc = 100 * known_correct / known_total if known_total > 0 else 0
        balanced_acc = (novel_acc + known_acc) / 2 if novel_total > 0 and known_total > 0 else 0

        print(f'Validation loss: {avg_loss:.4f}')
        print(f'Overall superclass acc: {super_acc:.2f}%, subclass acc: {sub_acc:.2f}%')

        if novel_total > 0:
            print(f'Novel acc: {novel_acc:.2f}% ({novel_correct}/{novel_total})')
        if known_total > 0:
            print(f'Known acc: {known_acc:.2f}% ({known_correct}/{known_total})')
        if novel_total > 0 and known_total > 0:
            print(f'Balanced acc: {balanced_acc:.2f}%')

        # Update history
        self.history['val_loss'].append(avg_loss)
        self.history['super_acc'].append(super_acc)
        self.history['sub_acc'].append(sub_acc)
        self.history['novel_acc'].append(novel_acc)
        self.history['known_acc'].append(known_acc)
        self.history['balanced_acc'].append(balanced_acc)

        return {
            'loss': avg_loss,
            'super_acc': super_acc,
            'sub_acc': sub_acc,
            'novel_acc': novel_acc,
            'known_acc': known_acc,
            'balanced_acc': balanced_acc
        }

    def _validate_without_openmax(self):
        """Validate during pre-training (before OpenMax is initialized)"""
        self.model.eval()

        correct = 0
        sub_correct = 0
        total = 0
        running_loss = 0.0

        with torch.no_grad():
            for i, data in enumerate(self.val_loader):
                inputs, super_labels, _, sub_labels, _ = data
                inputs = inputs.to(self.device)
                super_labels = super_labels.to(self.device)
                sub_labels = sub_labels.to(self.device)

                # Forward pass
                super_outputs, sub_outputs = self.model(inputs)

                # Compute loss
                loss = self.criterion(super_outputs, super_labels) + self.criterion(sub_outputs, sub_labels)

                # Get predictions
                _, super_predicted = torch.max(super_outputs, 1)
                _, sub_predicted = torch.max(sub_outputs, 1)

                # Count correct
                total += super_labels.size(0)
                correct += (super_predicted == super_labels).sum().item()
                sub_correct += (sub_predicted == sub_labels).sum().item()

                running_loss += loss.item()

        # Calculate accuracy
        accuracy = 100 * correct / total if total > 0 else 0
        sub_accuracy = 100 * sub_correct / total if total > 0 else 0
        avg_loss = running_loss / (i+1)

        print(f'Pre-training validation - Loss: {avg_loss:.4f}, Super acc: {accuracy:.2f}%, Sub acc: {sub_accuracy:.2f}%')

        # Update history
        self.history['val_loss'].append(avg_loss)
        self.history['super_acc'].append(accuracy)
        self.history['sub_acc'].append(sub_accuracy)
        self.history['novel_acc'].append(0)
        self.history['known_acc'].append(accuracy)
        self.history['balanced_acc'].append(0)

        return {
            'loss': avg_loss,
            'super_acc': accuracy,
            'sub_acc': sub_accuracy,
            'novel_acc': 0,
            'known_acc': accuracy,
            'balanced_acc': 0
        }

    def test(self, save_to_csv=False, return_predictions=False, output_file='optimized_openmax_predictions.csv'):
        """Test the model and generate predictions with detailed analysis"""
        if not self.test_loader:
            raise ValueError('test_loader not specified')

        if self.openmax_model is None:
            raise ValueError('OpenMax model not initialized')

        self.model.eval()

        # Create dictionaries for predictions
        test_predictions = {
            'image': [],
            'superclass_index': [],
            'subclass_index': [],
            'unknown_probability': []
        }

        with torch.no_grad():
            for i, data in enumerate(self.test_loader):
                inputs, img_name = data[0].to(self.device), data[1]

                # OpenMax prediction
                super_preds, sub_preds, super_unknown_probs, sub_unknown_probs = self.openmax_model.predict(inputs, self.device)

                for j in range(inputs.size(0)):
                    img = img_name[j] if isinstance(img_name, list) else img_name[0]

                    test_predictions['image'].append(img)
                    test_predictions['superclass_index'].append(super_preds[j].item())
                    test_predictions['subclass_index'].append(sub_preds[j].item())
                    test_predictions['unknown_probability'].append(super_unknown_probs[j])

                # Print progress for large datasets
                if (i+1) % 100 == 0:
                    print(f"Processed {i+1} batches...")

        # Create full DataFrame with all information
        full_predictions_df = pd.DataFrame(data=test_predictions)

        # Create simplified DataFrame for output (compatible with original code)
        simplified_predictions = {
            'image': test_predictions['image'],
            'superclass_index': test_predictions['superclass_index'],
            'subclass_index': test_predictions['subclass_index']
        }
        simplified_predictions_df = pd.DataFrame(data=simplified_predictions)

        # Print summary of novel predictions
        novel_super_count = sum(1 for idx in test_predictions['superclass_index'] if idx == self.openmax_model.num_superclasses)
        novel_sub_count = sum(1 for idx in test_predictions['subclass_index'] if idx == self.openmax_model.num_subclasses)

        total_count = len(test_predictions['image'])
        novel_super_perc = 100 * novel_super_count / total_count if total_count > 0 else 0
        novel_sub_perc = 100 * novel_sub_count / total_count if total_count > 0 else 0

        print(f'Test set predictions:')
        print(f'Images predicted as novel superclass: {novel_super_count} ({novel_super_perc:.2f}%)')
        print(f'Images predicted as novel subclass: {novel_sub_count} ({novel_sub_perc:.2f}%)')

        # Print unknown probability statistics
        unknown_probs = np.array(test_predictions['unknown_probability'])
        print(f'Unknown probability statistics:')
        print(f'  Mean: {np.mean(unknown_probs):.4f}')
        print(f'  Std: {np.std(unknown_probs):.4f}')
        print(f'  Min: {np.min(unknown_probs):.4f}')
        print(f'  Max: {np.max(unknown_probs):.4f}')

        # Print distribution of unknown probabilities
        print(f'Unknown probability distribution:')
        bins = [0, 0.2, 0.4, 0.5, 0.6, 0.8, 1.0]
        for i in range(len(bins)-1):
            count = sum(1 for p in unknown_probs if bins[i] <= p < bins[i+1])
            print(f'  {bins[i]:.1f}-{bins[i+1]:.1f}: {count} ({100 * count / len(unknown_probs):.2f}%)')

        if save_to_csv:
            # Save in the same format as the original code
            simplified_predictions_df.to_csv(output_file, index=False)
            print(f"Predictions saved to '{output_file}'")

            # Save detailed version with probabilities
            full_predictions_df.to_csv(f'detailed_{output_file}', index=False)
            print(f"Detailed predictions saved to 'detailed_{output_file}'")

        if return_predictions:
            # Return the predictions for further analysis
            return full_predictions_df, unknown_probs


# Complete integrated solution for training and optimizing OpenMax
def optimize_openmax_model(train_loader, val_loader, test_loader=None, device='cuda', epochs=15):
    """
    End-to-end workflow to train and optimize OpenMax model for both seen and unseen class accuracy

    Args:
        train_loader: Training data loader
        val_loader: Validation data loader
        test_loader: Test data loader (optional)
        device: Device to use (cuda or cpu)
        epochs: Number of training epochs

    Returns:
        Tuple of (model, openmax_model, trainer)
    """
    print("=== Starting Optimized OpenMax Training ===")

    # Phase 1: Initialize model and pre-train
    print("\nPhase 1: Initializing and pre-training classification model")
    model = OptimizedCNN(
        input_size=64,
        num_superclasses=4,
        num_subclasses=88
    ).to(device)

    # Check if we need to handle class imbalance
    class_counts = [0, 0, 0, 0]  # Count samples per superclass
    for _, super_label, _, _, _ in train_loader.dataset:
        if hasattr(super_label, 'item'):
            super_label = super_label.item()
        if super_label < len(class_counts):
            class_counts[super_label] += 1

    print(f"Class distribution: {class_counts}")

    # Create weighted loss if there's imbalance
    non_zero_counts = [c for c in class_counts if c > 0]
    if len(non_zero_counts) > 1 and max(non_zero_counts) / min(non_zero_counts) > 1.5:
        total = sum(class_counts)
        weights = [total / (c * len(non_zero_counts)) if c > 0 else 0.0 for c in class_counts]
        class_weights = torch.tensor(weights, device=device)
        print(f"Using class weights: {weights}")
        criterion = nn.CrossEntropyLoss(weight=class_weights)
    else:
        # Use label smoothing for regularization
        criterion = nn.CrossEntropyLoss(label_smoothing=0.1)


    # Use AdamW optimizer with weight decay
    optimizer = optim.AdamW(
        model.parameters(),
        lr=0.001,
        weight_decay=0.0001
    )

    # Initialize trainer with very small center loss weight
    trainer = OptimizedOpenMaxTrainer(
        model=model,
        openmax_model=None,  # Will be created after pre-training
        criterion=criterion,
        optimizer=optimizer,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        device=device,
        center_loss_weight=0.0005  # Critical fix: Reduce from 0.05 to 0.0005
    )

    # Pre-train for a few epochs
    print("\nPre-training classification model...")
    for epoch in range(5):
        trainer.train_epoch()
        metrics = trainer._validate_without_openmax()
        print(f"Pre-training Epoch {epoch+1}/5: Accuracy: {metrics['super_acc']:.2f}%")

    # Phase 2: Initialize and fit OpenMax
    print("\nPhase 2: Initializing OpenMax model")
    openmax_model = OptimizedOpenMaxModel(
        model,
        num_superclasses=3,
        num_subclasses=87,
        tailsize=40,                  # Larger tailsize for better statistics
        alpha=4,                      # Fewer top activations to consider
        threshold=0.6,                # Initial threshold (will be calibrated)
        distance_multiplier=1.5       # Scale distances for better sensitivity
    )

    # Fit OpenMax parameters
    openmax_model.fit(train_loader, device)

    # Update trainer with OpenMax model
    trainer.openmax_model = openmax_model

    # Phase 3: Calibrate threshold for better novelty detection
    print("\nPhase 3: Calibrating novelty detection threshold")
    best_threshold, _ = openmax_model.calibrate_threshold(val_loader, device)

    # Phase 4: Continue training with OpenMax
    print("\nPhase 4: Continuing training with OpenMax and optimized parameters")
    best_metrics = {'balanced_acc': 0}
    best_state_dict = None

    for epoch in range(epochs):
        print(f"\nEpoch {epoch+1}/{epochs}")
        trainer.train_epoch()
        metrics = trainer.validate_epoch()

        # Save best model based on balanced accuracy
        if metrics['balanced_acc'] > best_metrics['balanced_acc']:
            best_metrics = metrics.copy()
            best_state_dict = {
                'model': model.state_dict().copy(),
                'threshold': openmax_model.threshold
            }
            print(f"New best model! Balanced accuracy: {best_metrics['balanced_acc']:.4f}")

        # Re-calibrate every 5 epochs
        if (epoch + 1) % 5 == 0:
            openmax_model.calibrate_threshold(val_loader, device)

    # Load best model
    if best_state_dict is not None:
        model.load_state_dict(best_state_dict['model'])
        openmax_model.threshold = best_state_dict['threshold']
        print(f"Loaded best model with balanced accuracy: {best_metrics['balanced_acc']:.4f}")
        print(f"Using threshold: {openmax_model.threshold:.4f}")

        # Re-fit OpenMax with best model
        openmax_model.fit(train_loader, device)

    return model, openmax_model, trainer


# Example usage:
# model, openmax_model, trainer = optimize_openmax_model(train_loader, val_loader, test_loader)
# predictions, unknown_probs = trainer.test(save_to_csv=True)

In [None]:
# @title Default title text
model, openmax_model, trainer = optimize_openmax_model(train_loader, val_loader, test_loader)
predictions, unknown_probs = trainer.test(save_to_csv=True)

=== Starting Optimized OpenMax Training ===

Phase 1: Initializing and pre-training classification model
Class distribution: [1657, 1878, 2125, 0]

Pre-training classification model...
Training loss: 3.7936 (CE: 3.7591, Center: 0.0345)
Pre-training validation - Loss: 3.0374, Super acc: 96.82%, Sub acc: 38.54%
Pre-training Epoch 1/5: Accuracy: 96.82%
Training loss: 2.5728 (CE: 2.5312, Center: 0.0416)
Pre-training validation - Loss: 2.3244, Super acc: 96.97%, Sub acc: 64.33%
Pre-training Epoch 2/5: Accuracy: 96.97%
Training loss: 2.1408 (CE: 2.0949, Center: 0.0459)
Pre-training validation - Loss: 2.0217, Super acc: 98.73%, Sub acc: 73.25%
Pre-training Epoch 3/5: Accuracy: 98.73%
Training loss: 1.8796 (CE: 1.8317, Center: 0.0479)
Pre-training validation - Loss: 1.9338, Super acc: 97.93%, Sub acc: 77.07%
Pre-training Epoch 4/5: Accuracy: 97.93%
Training loss: 1.7215 (CE: 1.6727, Center: 0.0488)
Pre-training validation - Loss: 1.8164, Super acc: 98.57%, Sub acc: 79.46%
Pre-training Epoch 5/

TypeError: cannot unpack non-iterable NoneType object

In [None]:
# Install the CLIP package
!pip install ftfy regex tqdm
!pip install git+https://github.com/openai/CLIP.git

Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-cs832wq9
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-cs832wq9
  Resolved https://github.com/openai/CLIP.git to commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
  Preparing metadata (setup.py) ... [?25l[?25hdone


In [15]:
# @title Novelty Detection with Enhanced Threshold Selection

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
import random
from torch.utils.data import DataLoader, Subset, random_split

class CNN(nn.Module):
    def __init__(self, input_size=64, num_superclasses=4, num_subclasses=88):
        super().__init__()

        # Calculate feature size based on input size and pooling operations
        # 3 max pooling layers with stride 2 each reduces size by factor of 2^3
        self.feature_size = input_size // (2**3)

        # First convolutional block
        self.block1 = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 32, 3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 32, 3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(2, 2)
        )

        # Second convolutional block
        self.block2 = nn.Sequential(
            nn.Conv2d(32, 64, 3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 64, 3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 64, 3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(2, 2)
        )

        # Third convolutional block
        self.block3 = nn.Sequential(
            nn.Conv2d(64, 128, 3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 128, 3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 128, 3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.MaxPool2d(2, 2)
        )

        # Fully connected layers with dropout
        self.fc1 = nn.Linear(self.feature_size * self.feature_size * 128, 256)
        self.dropout1 = nn.Dropout(0.2)  # Add dropout for regularization
        self.fc2 = nn.Linear(256, 128)
        self.dropout2 = nn.Dropout(0.2)  # Add dropout for regularization

        # Classification heads
        self.fc3a = nn.Linear(128, num_superclasses)  # 4 superclasses: bird, dog, reptile, novel
        self.fc3b = nn.Linear(128, num_subclasses)    # All subclasses + novel

    def forward(self, x):
        """Forward pass through the network"""
        # Pass through convolutional blocks
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)

        # Flatten for fully connected layers
        x = torch.flatten(x, 1)  # flatten all dimensions except batch

        # Fully connected layers with dropout
        x = F.relu(self.fc1(x))
        x = self.dropout1(x)  # Apply dropout after activation
        x = F.relu(self.fc2(x))
        x = self.dropout2(x)  # Apply dropout after activation

        # Classification heads
        super_out = self.fc3a(x)
        sub_out = self.fc3b(x)

        return super_out, sub_out

    def get_features(self, x):
        """Extract features before the final classification layer"""
        # Pass through convolutional blocks
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)

        # Flatten and pass through FC layers (without final classification)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout1(x)  # Apply dropout
        x = F.relu(self.fc2(x))
        x = self.dropout2(x)  # Apply dropout

        return x


class NoveltyDetectionTrainer:
    def __init__(self, full_dataset, image_preprocessing, device='cuda', batch_size=64,
                 min_known_acc=95, min_novel_acc=20):
        self.full_dataset = full_dataset
        self.image_preprocessing = image_preprocessing
        self.device = device
        self.batch_size = batch_size

        # Store energy normalization parameters
        self.energy_mean = 0
        self.energy_std = 1

        # Required accuracy thresholds
        self.min_known_acc = min_known_acc
        self.min_novel_acc = min_novel_acc

        # Fixed energy threshold (based on cross-validation) for high known accuracy (≥95%)
        # and reasonable novel detection (≥20%)
        self.energy_threshold = 1.9  # Fixed threshold from cross-validation

        # Get all unique superclass indices
        self.superclass_indices = set()
        for i in range(len(full_dataset)):
            _, super_idx, _, _, _ = full_dataset[i]
            if hasattr(super_idx, 'item'):
                super_idx = super_idx.item()
            self.superclass_indices.add(super_idx)

        self.superclass_indices = sorted(list(self.superclass_indices))
        print(f"Found superclasses with indices: {self.superclass_indices}")

    def cross_validate_novelty_detection(self, epochs=5, confidence_threshold=0.0):
        """Run cross-validation for novelty detection"""
        results = []

        # For each superclass, treat it as novel and others as known
        for fold, novel_idx in enumerate(self.superclass_indices):
            print(f"\n=== Fold {fold+1}/{len(self.superclass_indices)}: Treating superclass {novel_idx} as novel ===")

            # Create data splits
            known_indices, novel_indices = self._split_by_superclass(novel_idx)

            # Further split known indices into train/validation
            np.random.shuffle(known_indices)
            train_size = int(0.9 * len(known_indices))
            train_indices = known_indices[:train_size]
            val_known_indices = known_indices[train_size:]

            # Create datasets
            train_dataset = Subset(self.full_dataset, train_indices)
            val_known_dataset = Subset(self.full_dataset, val_known_indices)
            val_novel_dataset = Subset(self.full_dataset, novel_indices)

            # Create dataloaders
            train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
            val_known_loader = DataLoader(val_known_dataset, batch_size=self.batch_size, shuffle=False)
            val_novel_loader = DataLoader(val_novel_dataset, batch_size=self.batch_size, shuffle=False)

            # Initialize model, loss, optimizer
            model = CNN(input_size=64, num_superclasses=len(self.superclass_indices)+1).to(self.device)
            criterion = nn.CrossEntropyLoss()
            optimizer = optim.Adam(model.parameters(), lr=1e-3)

            # Train the model
            self._train_model(model, criterion, optimizer, train_loader, epochs)

            # Calibrate energy statistics on training data
            self._calibrate_energy_stats(model, train_loader)

            # Evaluate novelty detection
            metrics = self._evaluate_novelty_detection(model, val_known_loader, val_novel_loader, confidence_threshold)
            results.append(metrics)

            print(f"Fold {fold+1} results:")
            for key, value in metrics.items():
                print(f"  {key}: {value:.4f}")

        # Calculate average results across folds
        avg_results = {}
        for key in results[0].keys():
            avg_results[key] = sum(r[key] for r in results) / len(results)

        print("\n=== Cross-Validation Summary ===")
        for key, value in avg_results.items():
            print(f"{key}: {value:.4f}")

        return avg_results, results

    def find_optimal_threshold(self, fold_index=0, threshold_range=np.arange(-3.0, 3.0, 0.1)):
        """Find the optimal normalized energy threshold for a given fold"""
        novel_idx = self.superclass_indices[fold_index]
        print(f"\n=== Finding optimal threshold for fold {fold_index+1}: Superclass {novel_idx} as novel ===")

        # Create data splits
        known_indices, novel_indices = self._split_by_superclass(novel_idx)

        # Further split known indices into train/validation
        np.random.shuffle(known_indices)
        train_size = int(0.9 * len(known_indices))
        train_indices = known_indices[:train_size]
        val_known_indices = known_indices[train_size:]

        # Create datasets
        train_dataset = Subset(self.full_dataset, train_indices)
        val_known_dataset = Subset(self.full_dataset, val_known_indices)
        val_novel_dataset = Subset(self.full_dataset, novel_indices)

        # Create dataloaders
        train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
        val_known_loader = DataLoader(val_known_dataset, batch_size=self.batch_size, shuffle=False)
        val_novel_loader = DataLoader(val_novel_dataset, batch_size=self.batch_size, shuffle=False)

        # Initialize model, loss, optimizer
        model = CNN(input_size=64, num_superclasses=len(self.superclass_indices)+1).to(self.device)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=1e-3)

        # Train the model
        self._train_model(model, criterion, optimizer, train_loader, epochs=5)

        # Calibrate energy statistics on training data
        self._calibrate_energy_stats(model, train_loader)

        # Collect all normalized energy scores
        known_energies, novel_energies = self._collect_energies(model, val_known_loader, val_novel_loader)

        # Evaluate different thresholds
        results = []
        for threshold in threshold_range:
            # For known classes, prediction is "known" if energy <= threshold
            known_correct = sum(1 for e in known_energies if e <= threshold)
            known_accuracy = known_correct / len(known_energies) if known_energies else 0

            # For novel classes, prediction is "novel" if energy > threshold
            novel_correct = sum(1 for e in novel_energies if e > threshold)
            novel_accuracy = novel_correct / len(novel_energies) if novel_energies else 0

            # Balanced accuracy (average of known and novel accuracies)
            balanced_accuracy = (known_accuracy + novel_accuracy) / 2

            results.append({
                'threshold': threshold,
                'known_accuracy': known_accuracy * 100,  # Convert to percentage
                'novel_accuracy': novel_accuracy * 100,  # Convert to percentage
                'balanced_accuracy': balanced_accuracy * 100  # Convert to percentage
            })

            print(f"Threshold {threshold:.2f}: Known Acc={known_accuracy:.4f}, Novel Acc={novel_accuracy:.4f}, Balanced Acc={balanced_accuracy:.4f}")

        # Find thresholds that meet our criteria
        valid_thresholds = []
        for result in results:
            # Check if thresholds meet minimum requirements
            if result['known_accuracy'] >= self.min_known_acc and result['novel_accuracy'] >= self.min_novel_acc:
                valid_thresholds.append(result)

        if valid_thresholds:
            # Choose the threshold with best balanced accuracy from valid ones
            best_result = max(valid_thresholds, key=lambda x: x['balanced_accuracy'])
            print(f"\nFound threshold meeting criteria (known ≥{self.min_known_acc}%, novel ≥{self.min_novel_acc}%):")
        else:
            # No threshold meets criteria, use threshold 1.9 from cross-validation
            print(f"\nNo threshold meets both criteria (known ≥{self.min_known_acc}%, novel ≥{self.min_novel_acc}%)")
            print("Using pre-selected threshold of 1.9 from cross-validation...")

            # Find result closest to threshold 1.9
            best_result = min(results, key=lambda x: abs(x['threshold'] - 1.9))

        print(f"Best threshold: {best_result['threshold']:.2f}")
        print(f"Known accuracy: {best_result['known_accuracy']:.4f}")
        print(f"Novel accuracy: {best_result['novel_accuracy']:.4f}")
        print(f"Balanced accuracy: {best_result['balanced_accuracy']:.4f}")

        return best_result['threshold'], results

    def _calibrate_energy_stats(self, model, loader):
      """Calculate energy statistics on a dataset for normalization"""
      model.eval()
      all_energies = []

      with torch.no_grad():
          for data in loader:
              inputs = data[0].to(self.device)

              # Get model outputs
              super_outputs, _ = model(inputs)

              # Calculate raw energy
              energies = -torch.logsumexp(super_outputs, dim=1)
              all_energies.extend(energies.cpu().numpy())

      # Compute mean and standard deviation
      all_energies = np.array(all_energies)
      self.energy_mean = float(np.mean(all_energies))
      self.energy_std = float(np.std(all_energies) + 1e-6)  # Add epsilon to avoid division by zero

      print(f"Calibrated energy statistics: mean={self.energy_mean:.4f}, std={self.energy_std:.4f}")

    def _compute_normalized_energy(self, logits):
      """Compute normalized energy scores"""
      # Calculate raw energy
      raw_energy = -torch.logsumexp(logits, dim=1)

      # Normalize using stored statistics
      normalized_energy = (raw_energy - self.energy_mean) / self.energy_std

      return normalized_energy

    def _split_by_superclass(self, novel_superclass_idx):
        """Split dataset indices into known and novel based on superclass"""
        known_indices = []
        novel_indices = []

        for i in range(len(self.full_dataset)):
            _, super_idx, _, _, _ = self.full_dataset[i]
            if hasattr(super_idx, 'item'):
                super_idx = super_idx.item()

            if super_idx == novel_superclass_idx:
                novel_indices.append(i)
            else:
                known_indices.append(i)

        return known_indices, novel_indices

    def _train_model(self, model, criterion, optimizer, train_loader, epochs):
        """Train the model on known classes"""
        model.train()
        for epoch in range(epochs):
            running_loss = 0.0
            for i, data in enumerate(train_loader):
                inputs, super_labels, _, sub_labels, _ = data
                inputs = inputs.to(self.device)
                super_labels = super_labels.to(self.device)
                sub_labels = sub_labels.to(self.device)

                optimizer.zero_grad()
                super_outputs, sub_outputs = model(inputs)
                loss = criterion(super_outputs, super_labels) + criterion(sub_outputs, sub_labels)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()

            print(f'Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(train_loader):.4f}')

    def _evaluate_novelty_detection(self, model, known_loader, novel_loader, threshold):
      """Evaluate novelty detection performance using balanced ensemble approach."""
      model.eval()

      # First calibrate energy statistics on known data
      self._calibrate_energy_stats(model, known_loader)

      def eval_loader(loader, is_novel):
          super_correct, sub_correct = 0, 0
          super_total, sub_total = 0, 0

          with torch.no_grad():
              for data in loader:
                  inputs, _, _, _, _ = data
                  inputs = inputs.to(self.device)

                  super_outputs, sub_outputs = model(inputs)

                  # --- Energy-based detection with FIXED threshold ---
                  super_energies = self._compute_normalized_energy(super_outputs)
                  # Use the fixed threshold of 1.9 to maintain high known accuracy
                  energy_novel = super_energies > self.energy_threshold  # self.energy_threshold is 1.9

                  # --- Confidence-based detection ---
                  super_probs = F.softmax(super_outputs, dim=1)
                  super_confidences, _ = torch.max(super_probs, dim=1)
                  confidence_novel = super_confidences < 0.7  # Adjust this threshold

                  # --- Balanced approach - weight both signals ---
                  # Energy gets 60% weight, confidence gets 40% weight
                  energy_weight = 0.6
                  confidence_weight = 0.4

                  # Calculate weighted score (1 = novel, 0 = known)
                  novelty_score = energy_weight * energy_novel.float() + confidence_weight * confidence_novel.float()

                  # Consider novel if score > 0.5 (adjustable threshold)
                  is_novel_super = novelty_score > 0.5

                  # --- Subclass detection remains the same ---
                  sub_probs = F.softmax(sub_outputs, dim=1)
                  sub_confidences, _ = torch.max(sub_probs, dim=1)
                  is_novel_sub = sub_confidences < 0.5

                  # Count correct predictions
                  if is_novel:
                      super_correct += is_novel_super.sum().item()
                      sub_correct += is_novel_sub.sum().item()
                  else:
                      super_correct += (~is_novel_super).sum().item()
                      sub_correct += (~is_novel_sub).sum().item()

                  super_total += inputs.size(0)
                  sub_total += inputs.size(0)

          return (
              super_correct / super_total if super_total else 0,
              sub_correct / sub_total if sub_total else 0
          )

      # Evaluate known and novel sets
      known_super_acc, known_sub_acc = eval_loader(known_loader, is_novel=False)
      novel_super_acc, novel_sub_acc = eval_loader(novel_loader, is_novel=True)

      balanced_super_acc = (known_super_acc + novel_super_acc) / 2
      balanced_sub_acc = (known_sub_acc + novel_sub_acc) / 2

      # Check if requirements are met
      known_req_met = known_super_acc * 100 >= self.min_known_acc
      novel_req_met = novel_super_acc * 100 >= self.min_novel_acc

      if known_req_met and novel_req_met:
          print(f"✓ Requirements met: known={known_super_acc*100:.2f}%, novel={novel_super_acc*100:.2f}%")
      else:
          print(f"✗ Requirements not met:")
          if not known_req_met:
              print(f"  Known accuracy {known_super_acc*100:.2f}% < {self.min_known_acc}% requirement")
          if not novel_req_met:
              print(f"  Novel accuracy {novel_super_acc*100:.2f}% < {self.min_novel_acc}% requirement")

      return {
          'known_superclass_accuracy': known_super_acc,
          'novel_superclass_accuracy': novel_super_acc,
          'balanced_superclass_accuracy': balanced_super_acc,
          'known_subclass_accuracy': known_sub_acc,
          'novel_subclass_accuracy': novel_sub_acc,
          'balanced_subclass_accuracy': balanced_sub_acc
      }

    def _collect_energies(self, model, known_loader, novel_loader):
        """Collect normalized energy scores for known and novel classes"""
        model.eval()

        known_energies = []
        novel_energies = []

        with torch.no_grad():
            # Known classes
            for data in known_loader:
                inputs, _, _, _, _ = data
                inputs = inputs.to(self.device)

                super_outputs, _ = model(inputs)
                energies = self._compute_normalized_energy(super_outputs)
                known_energies.extend(energies.cpu().numpy())

            # Novel classes
            for data in novel_loader:
                inputs, _, _, _, _ = data
                inputs = inputs.to(self.device)

                super_outputs, _ = model(inputs)
                energies = self._compute_normalized_energy(super_outputs)
                novel_energies.extend(energies.cpu().numpy())

        return known_energies, novel_energies


class Trainer():
    def __init__(self, model, criterion, optimizer, train_loader, val_loader, test_loader=None, device='cuda',
                min_known_acc=95, min_novel_acc=20):
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.device = device

        # For energy normalization
        self.energy_mean = 0
        self.energy_std = 1
        self.energy_calibrated = False

        # Accuracy requirements
        self.min_known_acc = min_known_acc
        self.min_novel_acc = min_novel_acc

        # Fixed energy threshold for high known accuracy (≥95%) and reasonable novel detection (≥20%)
        # This comes from cross-validation
        self.energy_threshold = 1.9

        # Add scheduler
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=3, verbose=True
        )

        # Store temperature parameter
        self.temperature = 1.5

    def train_epoch(self):
        self.model.train()
        running_loss = 0.0
        for i, data in enumerate(self.train_loader):
            inputs, super_labels, sub_labels = data[0].to(self.device), data[1].to(self.device), data[3].to(self.device)

            self.optimizer.zero_grad()
            super_outputs, sub_outputs = self.model(inputs)
            loss = self.criterion(super_outputs, super_labels) + self.criterion(sub_outputs, sub_labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()

            running_loss += loss.item()

        print(f'Training loss: {running_loss/(i+1):.3f}')
        avg_loss = running_loss/(i+1)
        self.scheduler.step(avg_loss)

        # Recalibrate energy statistics after each epoch
        self._calibrate_energy_stats()

        return avg_loss

    def _calibrate_energy_stats(self):
        """Calculate energy statistics on training data for normalization"""
        self.model.eval()
        all_energies = []

        with torch.no_grad():
            for data in self.train_loader:
                inputs = data[0].to(self.device)

                # Get model outputs
                super_outputs, _ = self.model(inputs)

                # Calculate raw energy
                energies = -torch.logsumexp(super_outputs, dim=1)
                all_energies.extend(energies.cpu().numpy())

        # Compute mean and standard deviation
        all_energies = np.array(all_energies)
        self.energy_mean = float(np.mean(all_energies))
        self.energy_std = float(np.std(all_energies) + 1e-6)  # Add epsilon to avoid division by zero
        self.energy_calibrated = True

        print(f"Calibrated energy statistics: mean={self.energy_mean:.4f}, std={self.energy_std:.4f}")

    def compute_normalized_energy(self, logits):
        """Compute normalized energy scores"""
        # Calculate raw energy
        raw_energy = -torch.logsumexp(logits, dim=1)

        # Normalize using stored statistics
        if not self.energy_calibrated:
            # If not calibrated, just return raw energy
            print("Warning: Energy statistics not calibrated, using raw energy")
            return raw_energy

        normalized_energy = (raw_energy - self.energy_mean) / self.energy_std

        return normalized_energy

    def validate_epoch(self, novel_superclass_idx=3, novel_subclass_idx=87):
      """
      Validate the model with balanced ensemble novelty detection approach
      using FIXED threshold of 1.9 for better known accuracy.
      """
      # Make sure energy statistics are calibrated
      if not self.energy_calibrated:
          self._calibrate_energy_stats()

      self.model.eval()

      # Metrics to track
      correct_with_novelty = 0
      super_correct_standard = 0
      sub_correct = 0

      novel_total = 0
      known_total = 0
      novel_correct = 0
      known_correct = 0

      total = 0

      novel_super_predictions = 0
      novel_sub_predictions = 0

      all_super_energies = []
      all_sub_confidences = []

      running_loss = 0.0

      with torch.no_grad():
          for i, data in enumerate(self.val_loader):
              inputs, super_labels, _, sub_labels, _ = data
              inputs = inputs.to(self.device)
              super_labels = super_labels.to(self.device)
              sub_labels = sub_labels.to(self.device)

              super_outputs, sub_outputs = self.model(inputs)

              # Normalized energy scores with FIXED threshold (1.9)
              super_energies = self.compute_normalized_energy(super_outputs)
              energy_novel = super_energies > self.energy_threshold  # Using fixed threshold 1.9

              # Confidence scores
              super_probs = F.softmax(super_outputs, dim=1)
              super_confidences, super_predicted = torch.max(super_probs, dim=1)

              # Confidence threshold can be adjusted
              conf_threshold = 0.7
              confidence_novel = super_confidences < conf_threshold

              # BALANCED APPROACH - weighted voting
              # Adjust these weights to control balance
              energy_weight = 0.6    # More weight to energy scores
              confidence_weight = 0.4 # Less weight to confidence

              # Calculate weighted novelty score (0-1 range)
              novelty_score = energy_weight * energy_novel.float() + confidence_weight * confidence_novel.float()

              # Decision threshold - adjust to control sensitivity
              decision_threshold = 0.5  # 0.5 is balanced
              novel_super_mask = novelty_score > decision_threshold

              # Create final predictions
              final_super_preds = torch.where(
                  novel_super_mask,
                  torch.full_like(super_predicted, novel_superclass_idx),
                  super_predicted
              )

              # Subclass confidence-based detection
              sub_probs = F.softmax(sub_outputs, dim=1)
              sub_confidences, sub_predicted = torch.max(sub_probs, dim=1)
              sub_threshold = 0.5
              novel_sub_mask = sub_confidences < sub_threshold

              final_sub_preds = torch.where(
                  novel_sub_mask,
                  torch.full_like(sub_predicted, novel_subclass_idx),
                  sub_predicted
              )

              # Count total
              total += super_labels.size(0)

              # Count correct predictions
              correct_with_novelty += (final_super_preds == super_labels).sum().item()
              super_correct_standard += (super_predicted == super_labels).sum().item()
              sub_correct += (final_sub_preds == sub_labels).sum().item()

              # Count novel vs known separately
              is_novel_label = super_labels == novel_superclass_idx
              novel_total += is_novel_label.sum().item()
              known_total += (~is_novel_label).sum().item()

              novel_correct += ((final_super_preds == super_labels) & is_novel_label).sum().item()
              known_correct += ((final_super_preds == super_labels) & ~is_novel_label).sum().item()

              # Count samples predicted as novel
              novel_super_predictions += novel_super_mask.sum().item()
              novel_sub_predictions += novel_sub_mask.sum().item()

              # Store energy and confidence scores
              all_super_energies.extend(super_energies.cpu().numpy())
              all_sub_confidences.extend(sub_confidences.cpu().numpy())

              # Calculate loss
              loss = self.criterion(super_outputs, super_labels) + self.criterion(sub_outputs, sub_labels)
              running_loss += loss.item()

      # Calculate metrics
      super_acc = 100 * correct_with_novelty / total if total > 0 else 0
      sub_acc = 100 * sub_correct / total if total > 0 else 0

      novel_acc = 100 * novel_correct / novel_total if novel_total > 0 else 0
      known_acc = 100 * known_correct / known_total if known_total > 0 else 0
      balanced_acc = (novel_acc + known_acc) / 2 if novel_total > 0 and known_total > 0 else 0

      avg_super_energy = sum(all_super_energies) / len(all_super_energies) if all_super_energies else 0
      avg_sub_conf = sum(all_sub_confidences) / len(all_sub_confidences) if all_sub_confidences else 0

      novel_super_perc = 100 * novel_super_predictions / total if total > 0 else 0
      novel_sub_perc = 100 * novel_sub_predictions / total if total > 0 else 0

      # Display metrics
      print(f'Validation loss: {running_loss/(i+1):.3f}')
      print(f'Validation superclass acc: {super_acc:.2f}%')
      print(f'Validation subclass acc: {sub_acc:.2f}%')
      print(f'Novel superclass acc: {novel_acc:.2f}%, Known superclass acc: {known_acc:.2f}%')
      print(f'Balanced superclass acc: {balanced_acc:.2f}%')
      print(f'Average normalized superclass energy: {avg_super_energy:.4f}')
      print(f'Average subclass confidence: {avg_sub_conf:.4f}')
      print(f'Samples predicted as novel superclass: {novel_super_predictions} ({novel_super_perc:.2f}%)')
      print(f'Samples predicted as novel subclass: {novel_sub_predictions} ({novel_sub_perc:.2f}%)')

      # Check if requirements are met
      requirements_met = known_acc >= self.min_known_acc and novel_acc >= self.min_novel_acc

      if requirements_met:
          print(f"✓ REQUIREMENTS MET: known={known_acc:.2f}% ≥ {self.min_known_acc}%, novel={novel_acc:.2f}% ≥ {self.min_novel_acc}%")
      else:
          print(f"✗ REQUIREMENTS NOT MET:")
          if known_acc < self.min_known_acc:
              print(f"  Known accuracy {known_acc:.2f}% < {self.min_known_acc}% requirement")
          if novel_acc < self.min_novel_acc:
              print(f"  Novel accuracy {novel_acc:.2f}% < {self.min_novel_acc}% requirement")

      return {
          'loss': running_loss/(i+1),
          'accuracy': super_acc,
          'novel_acc': novel_acc,
          'known_acc': known_acc,
          'balanced_acc': balanced_acc
      }

    def test(self, save_to_csv=False, return_predictions=False, output_file='example_test_predictions.csv'):
      """
      Test the model with fixed threshold of 1.9 for higher known accuracy
      """
      if not self.test_loader:
          raise NotImplementedError('test_loader not specified')

      # Make sure energy statistics are calibrated
      if not self.energy_calibrated:
          self._calibrate_energy_stats()

      self.model.eval()
      novel_superclass_idx = 3  # Index for novel superclass
      novel_subclass_idx = 87   # Index for novel subclass

      # Create full data structure for internal use
      full_test_predictions = {
          'image': [],
          'superclass_index': [],
          'subclass_index': [],
          'superclass_energy': [],
          'subclass_confidence': [],
          'novelty_score': []
      }

      with torch.no_grad():
          for i, data in enumerate(self.test_loader):
              inputs, img_name = data[0].to(self.device), data[1]

              super_outputs, sub_outputs = self.model(inputs)

              # Normalized energy with FIXED threshold (1.9)
              super_energies = self.compute_normalized_energy(super_outputs)
              energy_novel = super_energies > self.energy_threshold  # Fixed at 1.9

              # Confidence scores for superclasses
              super_probs = F.softmax(super_outputs, dim=1)
              super_confidences, super_predicted = torch.max(super_probs, dim=1)

              # Confidence threshold
              conf_threshold = 0.7
              confidence_novel = super_confidences < conf_threshold

              # BALANCED APPROACH - weighted voting
              energy_weight = 0.6
              confidence_weight = 0.4

              novelty_score = energy_weight * energy_novel.float() + confidence_weight * confidence_novel.float()
              decision_threshold = 0.5
              novel_super_mask = novelty_score > decision_threshold

              # Subclass confidence-based detection
              sub_probs = F.softmax(sub_outputs, dim=1)
              sub_confidences, sub_predicted = torch.max(sub_probs, dim=1)
              sub_threshold = 0.5
              novel_sub_mask = sub_confidences < sub_threshold

              for j in range(inputs.size(0)):
                  img = img_name[j] if isinstance(img_name, list) else img_name[0]

                  # Apply novelty detection
                  super_pred = novel_superclass_idx if novel_super_mask[j] else super_predicted[j].item()
                  sub_pred = novel_subclass_idx if novel_sub_mask[j] else sub_predicted[j].item()

                  full_test_predictions['image'].append(img)
                  full_test_predictions['superclass_index'].append(super_pred)
                  full_test_predictions['subclass_index'].append(sub_pred)
                  full_test_predictions['superclass_energy'].append(super_energies[j].item())
                  full_test_predictions['subclass_confidence'].append(sub_confidences[j].item())
                  full_test_predictions['novelty_score'].append(novelty_score[j].item())

      # Create complete predictions dataframe
      full_predictions_df = pd.DataFrame(data=full_test_predictions)

      # Create simplified dataframe with only the columns that match the first method
      simplified_test_predictions = {
          'image': full_test_predictions['image'],
          'superclass_index': full_test_predictions['superclass_index'],
          'subclass_index': full_test_predictions['subclass_index']
      }
      simplified_predictions_df = pd.DataFrame(data=simplified_test_predictions)

      # Summarize
      novel_super_count = sum(1 for idx in full_test_predictions['superclass_index'] if idx == novel_superclass_idx)
      novel_sub_count = sum(1 for idx in full_test_predictions['subclass_index'] if idx == novel_subclass_idx)

      total_count = len(full_test_predictions['image'])
      novel_super_perc = 100 * novel_super_count / total_count if total_count > 0 else 0
      novel_sub_perc = 100 * novel_sub_count / total_count if total_count > 0 else 0

      print(f'Test set predictions:')
      print(f'Images predicted as novel superclass: {novel_super_count} ({novel_super_perc:.2f}%)')
      print(f'Images predicted as novel subclass: {novel_sub_count} ({novel_sub_perc:.2f}%)')

      # Print distribution of novelty scores to help with threshold tuning
      print(f'Novelty score distribution:')
      bins = [0, 0.2, 0.4, 0.5, 0.6, 0.8, 1.0]
      for i in range(len(bins)-1):
          count = sum(1 for score in full_test_predictions['novelty_score']
                    if bins[i] <= score < bins[i+1])
          print(f'  {bins[i]:.1f}-{bins[i+1]:.1f}: {count} ({100*count/total_count:.2f}%)')

      if save_to_csv:
          # Save in the same format as the first method
          simplified_predictions_df.to_csv(output_file, index=False)
          print(f"Predictions saved to '{output_file}'")

      if return_predictions:
          # Return the full predictions for internal use
          return full_predictions_df


# Helper function to run the cross-validation and find optimal threshold
def train_with_novelty_detection(full_dataset, image_preprocessing, device='cuda', batch_size=64, epochs=5,
                               min_known_acc=95, min_novel_acc=20):
    # Initialize novelty detection trainer with requirements
    novelty_trainer = NoveltyDetectionTrainer(
        full_dataset=full_dataset,
        image_preprocessing=image_preprocessing,
        device=device,
        batch_size=batch_size,
        min_known_acc=min_known_acc,  # Minimum known accuracy requirement
        min_novel_acc=min_novel_acc   # Minimum novel accuracy requirement
    )

    # Run cross-validation to evaluate novelty detection
    print("Running cross-validation for novelty detection...")
    avg_results, fold_results = novelty_trainer.cross_validate_novelty_detection(epochs=epochs)

    # Find optimal threshold
    print("\nFinding optimal energy threshold...")
    best_threshold, threshold_results = novelty_trainer.find_optimal_threshold()

    return avg_results, best_threshold

In [17]:
# Run cross-validation to find best threshold
avg_results, best_threshold = train_with_novelty_detection(
    full_dataset=full_dataset,
    image_preprocessing=image_preprocessing,
    device=device,
    batch_size=64,
    epochs=5,
    min_known_acc=95,
    min_novel_acc=20
)

print(f"Cross-validation found best threshold: {best_threshold}")

# Initialize model and trainer for main training
model = CNN(input_size=64, num_superclasses=4, num_subclasses=88).to(device)
criterion = nn.CrossEntropyLoss(label_smoothing=0.05)
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.00005)

# Create trainer and SET THE BEST THRESHOLD
trainer = Trainer(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    device=device,
    min_known_acc=95,
    min_novel_acc=20
)

# Important - use the threshold found during cross-validation
trainer.energy_threshold = best_threshold
print(f"Training with energy threshold: {trainer.energy_threshold}")

# Train for N epochs
num_epochs = 15
for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    trainer.train_epoch()
    metrics = trainer.validate_epoch()

# Test the final model
test_results = trainer.test(save_to_csv=True, output_file='best_threshold_predictions.csv')

Found superclasses with indices: [0, 1, 2]
Running cross-validation for novelty detection...

=== Fold 1/3: Treating superclass 0 as novel ===
Epoch 1/5, Loss: 3.5607
Epoch 2/5, Loss: 2.1836
Epoch 3/5, Loss: 1.5421
Epoch 4/5, Loss: 1.1159
Epoch 5/5, Loss: 0.9707
Calibrated energy statistics: mean=-5.9108, std=2.5504
Calibrated energy statistics: mean=-5.7555, std=2.6191
✗ Requirements not met:
  Novel accuracy 0.43% < 20% requirement
Fold 1 results:
  known_superclass_accuracy: 0.9932
  novel_superclass_accuracy: 0.0043
  balanced_superclass_accuracy: 0.4988
  known_subclass_accuracy: 0.6937
  novel_subclass_accuracy: 0.6086
  balanced_subclass_accuracy: 0.6512

=== Fold 2/3: Treating superclass 1 as novel ===
Epoch 1/5, Loss: 3.6545
Epoch 2/5, Loss: 2.2190
Epoch 3/5, Loss: 1.5874
Epoch 4/5, Loss: 1.2041
Epoch 5/5, Loss: 0.9888
Calibrated energy statistics: mean=-7.3367, std=2.8228
Calibrated energy statistics: mean=-7.0493, std=2.8571
✗ Requirements not met:
  Novel accuracy 7.92% < 2



Training loss: 4.063
Calibrated energy statistics: mean=-3.4238, std=1.0056
Validation loss: 2.935
Validation superclass acc: 90.92%
Validation subclass acc: 9.71%
Novel superclass acc: 0.00%, Known superclass acc: 90.92%
Balanced superclass acc: 0.00%
Average normalized superclass energy: 0.0961
Average subclass confidence: 0.2777
Samples predicted as novel superclass: 30 (4.78%)
Samples predicted as novel subclass: 544 (86.62%)
✗ REQUIREMENTS NOT MET:
  Known accuracy 90.92% < 95% requirement
  Novel accuracy 0.00% < 20% requirement

Epoch 2/15
Training loss: 2.656
Calibrated energy statistics: mean=-3.1093, std=1.0720
Validation loss: 2.489
Validation superclass acc: 90.29%
Validation subclass acc: 25.00%
Novel superclass acc: 0.00%, Known superclass acc: 90.29%
Balanced superclass acc: 0.00%
Average normalized superclass energy: 0.0265
Average subclass confidence: 0.4171
Samples predicted as novel superclass: 14 (2.23%)
Samples predicted as novel subclass: 423 (67.36%)
✗ REQUIREMEN