In [1]:
!pip install albumentations opencv-python-headless timm

Defaulting to user installation because normal site-packages is not writeable


In [2]:
import torch
import torch.nn as nn
import numpy as np
import os
from sklearn.metrics import f1_score
# os.system("sudo pip install --user albumentations opencv-python-headless timm")
import timm
import pandas as pd

In [3]:
# --- Run configuration used by the cells below ---
NICKNAME = "Andrew"
OUTPUTS_a = 10  # Number of classes
IMAGE_SIZE, CHANNELS = 224, 3

# Use the first CUDA device when one is visible to torch, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

In [4]:
def load_model(checkpoint_path, num_classes=10, model_type='vit'):
    """
    Load a trained timm model from a checkpoint file.

    Args:
        checkpoint_path: Path to a file written with ``torch.save``. It may hold
            either a full training checkpoint (a dict with ``'model_state_dict'``
            and optionally ``'epoch'``, ``'val_f1_micro'`` and ``'thresholds'``)
            or a bare model state dict.
        num_classes: Number of outputs of the classification head.
        model_type: Architecture key: ``'vit'``, ``'efficientnet'``,
            ``'convnext'`` or ``'swin'``.

    Returns:
        ``(model, thresholds)`` when the checkpoint contains a ``'thresholds'``
        entry, otherwise just ``model``. In both cases the model has been moved
        to the module-level ``device`` and switched to eval mode.

    Raises:
        ValueError: If ``model_type`` is not one of the supported keys.
    """
    # Map the short architecture key to the timm model name.
    architectures = {
        'vit': 'vit_base_patch16_224',
        'efficientnet': 'tf_efficientnet_b4_ns',
        'convnext': 'convnext_base',
        'swin': 'swin_base_patch4_window7_224',
    }
    if model_type not in architectures:
        raise ValueError(
            f"Unknown model_type {model_type!r}; expected one of {sorted(architectures)}"
        )

    # Recreate model architecture (weights come from the checkpoint, not from timm)
    model = timm.create_model(architectures[model_type], pretrained=False, num_classes=num_classes)

    # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects, so
    # only load checkpoint files you trust. It is kept here because the full
    # checkpoints store non-tensor objects (e.g. the thresholds list).
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    has_thresholds = False
    thresholds = None

    # Decide by content rather than by file name: a full checkpoint is a dict
    # that wraps the weights under 'model_state_dict'.
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        # Full checkpoint with optimizer and scheduler
        model.load_state_dict(checkpoint['model_state_dict'])
        if 'epoch' in checkpoint:
            print(f"Loaded checkpoint from epoch {checkpoint['epoch']}")
        if 'val_f1_micro' in checkpoint:
            print(f"Val F1-micro: {checkpoint['val_f1_micro']:.4f}")
        if 'thresholds' in checkpoint:
            has_thresholds = True
            thresholds = checkpoint['thresholds']
            print(f"Saved thresholds: {thresholds}")
    else:
        # Simple state dict
        model.load_state_dict(checkpoint)
        print(f"Loaded model from {checkpoint_path}")

    # Always prepare the model for inference, including on the path that also
    # returns thresholds.
    model = model.to(device)
    model.eval()

    if has_thresholds:
        return model, thresholds
    return model


In [5]:
# Load the epoch-7 Swin checkpoint. Note that the second value unpacked from
# load_model is the list of saved per-class thresholds, even though it is bound
# to the name `checkpoint`.
checkpoint_path = "/home/ubuntu/deep_learning_exam2/Code/All_Submissions/day4-trial2/checkpoint_Andrew_epoch_7.pt"
model, checkpoint = load_model(
    checkpoint_path,
    num_classes=OUTPUTS_a,
    model_type='swin',
)

Loaded checkpoint from epoch 7
Val F1-micro: 0.7901
Saved thresholds: [np.float64(0.5000000000000001), np.float64(0.6000000000000002), np.float64(0.6000000000000002), np.float64(0.6500000000000001), np.float64(0.7000000000000002), np.float64(0.6500000000000001), np.float64(0.6000000000000002), np.float64(0.6500000000000001), np.float64(0.6500000000000001), np.float64(0.7000000000000002)]


In [11]:
def find_optimal_thresholds(model, dataloader, device, num_classes=10):
    """Find the per-class probability threshold that maximises F1.

    The model is run over every batch of ``dataloader``; its outputs are passed
    through a sigmoid and, for each class column, a grid of thresholds from
    0.10 to 0.90 in steps of 0.05 is scanned for the best binary F1. Micro and
    macro F1 using the selected thresholds are printed at the end.

    Args:
        model: Callable module mapping an image batch to per-class logits.
        dataloader: Iterable yielding ``(images, labels)`` tensor pairs.
            Labels are assumed to be 2-D multi-hot, one column per class
            (TODO confirm against the dataset).
        device: Device the image batches are moved to before the forward pass.
        num_classes: Number of class columns to tune.

    Returns:
        List of ``num_classes`` plain Python floats, one threshold per class.
        A class whose F1 is 0 at every grid point keeps the default 0.5.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.eval()

    all_probs = []
    all_labels = []

    print("Collecting predictions from validation set...")
    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device)
            outputs = model(images)
            all_probs.append(torch.sigmoid(outputs).cpu().numpy())
            # Bring labels to host memory first so the conversion also works
            # when the loader yields tensors that live on an accelerator.
            all_labels.append(labels.detach().cpu().numpy())

    if not all_probs:
        raise ValueError("dataloader is empty: no batches to tune thresholds on")

    all_probs = np.vstack(all_probs)
    all_labels = np.vstack(all_labels)

    print(f"Total samples: {len(all_probs)}")

    # Build the grid once. Rounding removes float drift from np.arange
    # (e.g. 0.6000000000000002), so the returned thresholds are clean values.
    threshold_range = np.round(np.arange(0.1, 0.91, 0.05), 2)

    best_thresholds = []
    print("\nFinding optimal thresholds for each class:")
    print("=" * 60)

    for class_idx in range(num_classes):
        best_f1 = 0
        best_threshold = 0.5

        class_probs = all_probs[:, class_idx]
        class_labels = all_labels[:, class_idx]

        for threshold in threshold_range:
            preds = (class_probs >= threshold).astype(int)
            f1 = f1_score(class_labels, preds, zero_division=0)

            # Strict '>' keeps the lowest threshold among ties.
            if f1 > best_f1:
                best_f1 = f1
                best_threshold = float(threshold)

        best_thresholds.append(best_threshold)
        print(f"Class {class_idx}: Threshold={best_threshold:.2f}, F1={best_f1:.4f}")

    print("=" * 60)

    # Calculate overall metrics with the selected per-class thresholds
    pred_labels = np.zeros_like(all_probs)
    for i, threshold in enumerate(best_thresholds):
        pred_labels[:, i] = (all_probs[:, i] >= threshold).astype(int)

    overall_f1_micro = f1_score(all_labels, pred_labels, average='micro', zero_division=0)
    overall_f1_macro = f1_score(all_labels, pred_labels, average='macro', zero_division=0)

    print(f"\nOverall Performance:")
    print(f"F1-Micro: {overall_f1_micro:.4f}")
    print(f"F1-Macro: {overall_f1_macro:.4f}")

    return best_thresholds

In [6]:
# Round each saved threshold to two decimals and convert it to a plain Python
# float, then display the resulting list.
checkpoint = list(map(lambda value: float(round(value, 2)), checkpoint))
checkpoint

[0.5, 0.6, 0.6, 0.65, 0.7, 0.65, 0.6, 0.65, 0.65, 0.7]