In [1]:
import os
import sys
import time
import math
import csv
import random
import warnings
import numpy as np
import pandas as pd
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import nn
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torchvision import transforms, models
from PIL import Image
from functools import partial
from typing import Dict, Callable, List, Tuple, Optional
from einops import rearrange
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import label_binarize
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, cohen_kappa_score, balanced_accuracy_score,
    confusion_matrix, log_loss
)
from ptflops import get_model_complexity_info
import matplotlib.pyplot as plt

In [2]:
import torchvision
import sklearn
import PIL
import matplotlib
import einops

print("Python:", sys.version)
print("Torch:", torch.__version__)
print("Torchvision:", torchvision.__version__)
print("NumPy:", np.__version__)
print("Pandas:", pd.__version__)
print("scikit-learn:", sklearn.__version__)
print("Pillow:", PIL.__version__)
print("Matplotlib:", matplotlib.__version__)
print("einops:", einops.__version__)

Python: 3.10.8 (main, Nov 24 2022, 14:13:03) [GCC 11.2.0]
Torch: 2.1.2+cu121
Torchvision: 0.16.2+cu121
NumPy: 1.26.3
Pandas: 2.2.3
scikit-learn: 1.6.1
Pillow: 10.2.0
Matplotlib: 3.8.2
einops: 0.8.1


In [3]:
data = np.load('/root/autodl-fs/retinamnist_224.npz')
train_images = data['train_images']
train_labels = data['train_labels']
val_images = data['val_images']
val_labels = data['val_labels']
test_images = data['test_images']
test_labels = data['test_labels']
data.close()

def to_tensor(images, labels):
    images = torch.from_numpy(images).float()
    print(images.shape)
    if images.dim() == 3:
        images = images.unsqueeze(-1).repeat(1, 1, 1, 3)
    images = images.permute(0, 3, 1, 2)
    labels = torch.from_numpy(labels).long().squeeze()
    return images, labels

train_images, train_labels = to_tensor(train_images, train_labels)
val_images, val_labels = to_tensor(val_images, val_labels)
test_images, test_labels = to_tensor(test_images, test_labels)

train_transform = transforms.Compose([
    transforms.RandomResizedCrop((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

class MNISTDataset(Dataset):
    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform
    def __len__(self):
        return len(self.labels)
    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label

batch_size = 32
train_dataset = MNISTDataset(train_images, train_labels, transform=train_transform)
test_dataset = MNISTDataset(test_images, test_labels, transform=test_transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
print("Shape of the training set images:", train_images.shape) 
print("Shape of training set labels:", train_labels.shape)

torch.Size([1080, 224, 224, 3])
torch.Size([120, 224, 224, 3])
torch.Size([400, 224, 224, 3])
Shape of the training set images: torch.Size([1080, 3, 224, 224])
Shape of training set labels: torch.Size([1080])


In [4]:
NUM_CLASSES = 5
BASE_PATH = Path("/root/models")
HEAD_LAYER_NAMES = ("head", "classifier", "fc", "logits")

def dynamic_import(model_path: Path, module_path: str):
    sys.path.insert(0, str(model_path))
    try:
        module = __import__(module_path)
        return module
    finally:
        sys.path.pop(0)

_ = dynamic_import(f"{BASE_PATH}/EMO", "emo_models")
from emo_models import EMO_2M

def _filter_head_keys(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    return {k: v for k, v in state_dict.items() if not any(layer_name in k for layer_name in HEAD_LAYER_NAMES)}

def load_pretrained_weights(model: nn.Module, ckpt_path: Path) -> None:
    checkpoint = torch.load(str(ckpt_path), map_location="cpu")
    state_dict = checkpoint.get("model", checkpoint)
    filtered = _filter_head_keys(state_dict)
    result = model.load_state_dict(filtered, strict=False)
    if result.missing_keys:
        print(f"[Missing] {', '.join(result.missing_keys[:3])}...")
    if result.unexpected_keys:
        print(f"[Extra] {', '.join(result.unexpected_keys[:3])}...")

def init_classification_head(model: nn.Module) -> None:
    for name, module in model.named_modules():
        if any(name.endswith(layer) for layer in HEAD_LAYER_NAMES):
            if isinstance(module, nn.Linear):
                nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
                print(f"Reinitialized: {name}")

def load_simple_checkpoint(model_constructor: Callable[..., nn.Module], ckpt_path: Path) -> nn.Module:
    model = model_constructor(num_classes=NUM_CLASSES)
    load_pretrained_weights(model, ckpt_path)
    init_classification_head(model)
    return model

def load_complex_checkpoint(model_constructor: Callable[..., nn.Module], ckpt_path: Path, key: str = "model", shape_transform: bool = False) -> nn.Module:
    model = model_constructor(num_classes=NUM_CLASSES)
    checkpoint = torch.load(str(ckpt_path), map_location="cpu")
    state_dict = checkpoint.get(key, checkpoint)
    if shape_transform:
        target = model.state_dict()
        for k in list(state_dict.keys()):
            if k in target and state_dict[k].ndim != target[k].ndim:
                state_dict[k] = state_dict[k].view(*target[k].shape)
    filtered = _filter_head_keys(state_dict)
    model.load_state_dict(filtered, strict=False)
    init_classification_head(model)
    return model

model_configs = [
    ("EMO-2M",
     lambda: load_simple_checkpoint(
         EMO_2M,
         f"{BASE_PATH}/EMO/weights/net.pth"
     )),
]

def build_models(configs: list[tuple[str, Callable[[], nn.Module]]]) -> Dict[str, Optional[nn.Module]]:
    built = {}
    for name, builder in configs:
        try:
            built[name] = builder()
            print(f"✅ Success: {name}")
        except Exception as e:
            warnings.warn(f"❌ Failed to load {name}: {e}")
            built[name] = None
    return built

models = build_models(model_configs)

if __name__ == "__main__":
    print("\nLoaded Models:")
    for name, model in models.items():
        status = "Loaded" if model else "Failed"
        print(f"{name.ljust(16)}: {status}")
    sample = models.get("EMO_2M")
    if sample is not None:
        print("\nSample Model Architecture:")
        print(sample)

[Missing] head.weight, head.bias...
Reinitialized: head
✅ Success: EMO-2M

Loaded Models:
EMO-2M          : Loaded


In [5]:
def _softplus_floor(x, eps=1e-6):
    return F.softplus(x) + eps

def _collect_logits_labels(model: nn.Module, data_loader, device):
    model.eval()
    logits_list, labels_list = [], []
    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            logits_list.append(outputs)
            labels_list.append(labels)
    return torch.cat(logits_list), torch.cat(labels_list)

class MiTLoss_WithTrainCalibration(nn.Module):
    def __init__(self, num_classes: int, train_loader, model: nn.Module, device):
        super().__init__()
        self.num_classes = int(num_classes)
        self.device = device
        self.model = model
        self.train_loader = train_loader

        # ---- Initialize temperature using training set ----
        init_tau = self._initialize_temperature()
        self.tau = nn.Parameter(init_tau)

        # ---- Running class histogram (for empirical label entropy H*) ----
        self.register_buffer("class_counts", torch.ones(self.num_classes))
        self.register_buffer("total_seen", torch.tensor(self.num_classes, dtype=torch.long))

        # ---- Dual-averaged λ ----
        self.register_buffer("lambda_entropy", torch.tensor(0.1))
        self.register_buffer("dual_updates", torch.tensor(0.1, dtype=torch.long))

        self.ce = nn.CrossEntropyLoss(reduction="mean")

    def _initialize_temperature(self):
        warnings.warn("Initializing temperature using the training set.")
        logits, labels = _collect_logits_labels(self.model, self.train_loader, self.device)
        logits, labels = logits.to(self.device), labels.to(self.device)

        logT = torch.tensor(0.0, device=self.device, requires_grad=True)
        opt = torch.optim.LBFGS([logT], lr=0.1, max_iter=50, line_search_fn="strong_wolfe")

        def closure():
            opt.zero_grad()
            T = torch.exp(logT)
            log_probs = F.log_softmax(logits / T, dim=1)
            nll = F.nll_loss(log_probs, labels, reduction='mean')
            nll.backward()
            return nll

        opt.step(closure)
        T_star = torch.exp(logT).detach()               
        tau0 = torch.log(torch.expm1(T_star).clamp_min(1e-12))
        return tau0

    @torch.no_grad()
    def _update_label_entropy(self, targets: torch.Tensor):
        dev = self.class_counts.device
        targets = targets.to(dev, dtype=torch.long)
        binc = torch.bincount(targets, minlength=self.num_classes).to(dev, dtype=self.class_counts.dtype)
        self.class_counts += binc
        self.total_seen += targets.numel()

    @torch.no_grad()
    def _empirical_label_entropy(self) -> torch.Tensor:
        probs = self.class_counts / self.class_counts.sum()
        logp = torch.log(probs.clamp_min(1e-12))
        H_star = -(probs * logp).sum()
        return H_star

    def forward(self, logits: torch.Tensor, targets: torch.Tensor):
        dev = logits.device
        targets = targets.to(dev, dtype=torch.long)

        # 1) Update H* from data
        self._update_label_entropy(targets)
        H_star = self._empirical_label_entropy()
        H_max = math.log(self.num_classes + 1e-12)

        # 2) Temperature scaling
        T = _softplus_floor(self.tau).clamp(1e-3, 500.0)
        scaled = logits / T

        # 3) Cross-entropy loss
        ce_loss = self.ce(scaled, targets)

        # 4) Entropy of predictive distribution
        log_probs = F.log_softmax(scaled, dim=1)
        probs = log_probs.exp()
        H = -(probs * log_probs).sum(dim=1).mean()

        # 5) Update λ via dual-averaging
        with torch.no_grad():
            d = (H_star - H) / max(H_max, 1e-12)
            d = torch.clamp(d, min=0.0)
            self.dual_updates += 1
            new_lambda = (self.lambda_entropy * (self.dual_updates - 1) + d) / self.dual_updates
            new_lambda = torch.clamp(new_lambda, 0.0, 0.5)
            self.lambda_entropy.copy_(new_lambda)

        # 6) Combined loss
        loss = ce_loss - self.lambda_entropy * H

        stats = {
            "loss": loss.detach(),
            "ce": ce_loss.detach(),
            "H": H.detach(),
            "H_star": H_star.detach(),
            "lambda": self.lambda_entropy.detach(),
            "T": T.detach()
        }
        return loss, stats

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def calculate_specificity(y_true, y_pred):
    cm = confusion_matrix(y_true, y_pred)
    specificity = 0.0
    for i in range(cm.shape[0]):
        tn = np.sum(np.delete(np.delete(cm, i, axis=0), i, axis=1))
        fp = np.sum(cm[:, i]) - cm[i, i]
        denominator = tn + fp
        specificity += tn / denominator if denominator != 0 else 0.0
    return specificity / cm.shape[0]

def train(epoch, net, optimizer, criterion, train_metrics):
    print('\nEpoch:', epoch)
    net.train()
    train_loss, correct, total = 0, 0, 0
    targets_all, predicted_all = [], []
    train_T_list, train_lambda_list = [], []
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss, stats = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        T_batch = float(stats["T"].item())
        lambda_batch = float(stats["lambda"].item())
        logits_scaled = outputs / T_batch
        probs = torch.softmax(logits_scaled, dim=1)
        _, predicted = probs.max(1)
        train_loss += loss.item()
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        targets_all.extend(targets.detach().cpu().numpy())
        predicted_all.extend(predicted.detach().cpu().numpy())
        train_T_list.append(T_batch)
        train_lambda_list.append(lambda_batch)
    OA = 100 * accuracy_score(targets_all, predicted_all)
    P = 100 * precision_score(targets_all, predicted_all, average='macro')
    Se = 100 * recall_score(targets_all, predicted_all, average='macro')
    Sp = 100 * calculate_specificity(targets_all, predicted_all)
    F1 = 100 * f1_score(targets_all, predicted_all, average='macro')
    Kappa = 100 * cohen_kappa_score(targets_all, predicted_all)
    T_epoch = float(np.mean(train_T_list)) if train_T_list else float(stats["T"].item())
    lambda_epoch = float(np.mean(train_lambda_list)) if train_lambda_list else float(stats["lambda"].item())
    train_metrics['loss'].append(train_loss / len(train_loader))
    train_metrics['P'].append(P)
    train_metrics['Se'].append(Se)
    train_metrics['Sp'].append(Sp)
    train_metrics['F1'].append(F1)
    train_metrics['OA'].append(OA)
    train_metrics['Kappa'].append(Kappa)
    train_metrics['T'].append(T_epoch)
    train_metrics['lambda'].append(lambda_epoch)
    print(f'Train Loss: {train_loss/len(train_loader):.3f} | OA: {OA:.1f}% | P: {P:.1f} | Se: {Se:.1f} | Sp: {Sp:.1f} | F1: {F1:.1f} | Kappa: {Kappa:.1f} | T(avg): {T_epoch:.3f} | lambda(avg): {lambda_epoch:.3f}')

class ModelSaver:
    def __init__(self, model_name):
        self.best_metrics = {'OA': 0.0, 'AUC': 0.0, 'F1': 0.0, 'P': 0.0, 'Se': 0.0, 'Sp': 0.0, 'Kappa': 0.0}
        self.save_dir = os.path.join('best_models', model_name)
        os.makedirs(self.save_dir, exist_ok=True)
    def check_and_save(self, net, criterion, current_metrics):
        for metric in ['OA', 'AUC', 'F1', 'Kappa']:
            if current_metrics[metric] > self.best_metrics[metric]:
                self.best_metrics[metric] = current_metrics[metric]
                torch.save({"model": net.state_dict(), "criterion": criterion.state_dict()}, os.path.join(self.save_dir, f'best_{metric}.pth'))
        torch.save({"model": net.state_dict(), "criterion": criterion.state_dict()}, os.path.join(self.save_dir, 'final_model.pth'))

def val(epoch, net, criterion, val_metrics, model_saver):
    net.eval()
    val_loss, correct, total = 0, 0, 0
    targets_all, predicted_all, probabilities_all = [], [], []
    val_T_list = []
    val_lambda_list = []
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss, stats = criterion(outputs, targets)
            val_loss += loss.item()
            T_batch = float(stats["T"].item())
            lambda_batch = float(stats["lambda"].item())
            logits_scaled = outputs / T_batch
            probs = torch.softmax(logits_scaled, dim=1)
            _, predicted = probs.max(1)
            probabilities_all.extend(probs.detach().cpu().numpy())
            targets_all.extend(targets.detach().cpu().numpy())
            predicted_all.extend(predicted.detach().cpu().numpy())
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            val_T_list.append(T_batch)
            val_lambda_list.append(lambda_batch)
    OA = 100 * accuracy_score(targets_all, predicted_all)
    P = 100 * precision_score(targets_all, predicted_all, average='macro')
    Se = 100 * recall_score(targets_all, predicted_all, average='macro')
    Sp = 100 * calculate_specificity(targets_all, predicted_all)
    F1 = 100 * f1_score(targets_all, predicted_all, average='macro')
    Kappa = 100 * cohen_kappa_score(targets_all, predicted_all)
    try:
        n_classes = len(np.unique(targets_all))
        if n_classes == 2:
            AUC = 100 * roc_auc_score(targets_all, np.array(probabilities_all)[:, 1])
        else:
            AUC = 100 * roc_auc_score(targets_all, probabilities_all, multi_class='ovr', average='macro')
    except Exception as e:
        print(f"AUC calculation failed: {str(e)}")
        AUC = 0.0
    T_avg = float(np.mean(val_T_list)) if val_T_list else float(stats["T"].item())
    lambda_avg = float(np.mean(val_lambda_list)) if val_lambda_list else float(stats["lambda"].item())
    val_metrics['loss'].append(val_loss / len(test_loader))
    val_metrics['P'].append(P)
    val_metrics['Se'].append(Se)
    val_metrics['Sp'].append(Sp)
    val_metrics['F1'].append(F1)
    val_metrics['OA'].append(OA)
    val_metrics['AUC'].append(AUC)
    val_metrics['Kappa'].append(Kappa)
    val_metrics['T'].append(T_avg)
    val_metrics['lambda'].append(lambda_avg)
    current_metrics = {'OA': OA, 'AUC': AUC, 'F1': F1, 'P': P, 'Se': Se, 'Sp': Sp, 'Kappa': Kappa, 'T': T_avg, 'lambda': lambda_avg}
    model_saver.check_and_save(net, criterion, current_metrics)
    print(f'Val  Loss: {val_loss/len(test_loader):.3f} | OA: {OA:.1f}% | P: {P:.1f} | Se: {Se:.1f} | Sp: {Sp:.1f} | F1: {F1:.1f} | AUC: {AUC:.1f} | Kappa: {Kappa:.1f} | T(avg): {T_avg:.3f} | lambda(avg): {lambda_avg:.3f}')

def train_and_save_model(model, model_name, train_loader, test_loader, num_classes, num_epochs=10):
    train_metrics = {key: [] for key in ['loss', 'P', 'Se', 'Sp', 'F1', 'OA', 'Kappa', 'T', 'lambda']}
    val_metrics = {key: [] for key in ['loss', 'P', 'Se', 'Sp', 'F1', 'OA', 'AUC', 'Kappa', 'T', 'lambda']}
    net = model.to(device)
    criterion = MiTLoss_WithTrainCalibration(num_classes, train_loader, model, device)
    optimizer = torch.optim.Adam(list(net.parameters()) + list(criterion.parameters()), lr=0.0001)
    model_saver = ModelSaver(model_name)
    for epoch in range(num_epochs):
        start_time = time.time()
        train(epoch, net, optimizer, criterion, train_metrics)
        val(epoch, net, criterion, val_metrics, model_saver)
        epoch_time = time.time() - start_time
        print(f'Epoch {epoch+1}/{num_epochs} Time: {epoch_time:.1f}s')
    os.makedirs(model_name, exist_ok=True)
    save_metrics_to_csv(train_metrics, os.path.join(model_name, 'train_metrics.csv'))
    save_metrics_to_csv(val_metrics, os.path.join(model_name, 'val_metrics.csv'))

def save_metrics_to_csv(metrics, file_path):
    fieldnames = ['epoch'] + list(metrics.keys())
    epochs = range(1, len(next(iter(metrics.values()))) + 1)
    with open(file_path, 'w', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        for epoch in epochs:
            row = {'epoch': epoch}
            for key in metrics.keys():
                row[key] = metrics[key][epoch-1]
            writer.writerow(row)

In [None]:
if __name__ == "__main__":
    warnings.filterwarnings("ignore")
    for config in model_configs:
        model_name, model_builder = config
        torch.cuda.empty_cache()
        print(f"\n{'='*20} Training {model_name} {'='*20}")
        model = model_builder().to(device)
        train_and_save_model(
                model=model,
                model_name=model_name,
                train_loader=train_loader,
                test_loader=test_loader,
                num_classes = NUM_CLASSES,
                num_epochs=10,
            )


[Missing] head.weight, head.bias...
Reinitialized: head

Epoch: 0
Train Loss: 1.152 | OA: 49.8% | P: 34.5 | Se: 32.8 | Sp: 85.2 | F1: 32.5 | Kappa: 24.8 | T(avg): 2.472 | lambda(avg): 0.071
Val  Loss: 0.914 | OA: 54.8% | P: 30.0 | Se: 33.8 | Sp: 86.0 | F1: 31.2 | AUC: 81.5 | Kappa: 29.9 | T(avg): 2.473 | lambda(avg): 0.125
Epoch 1/10 Time: 2.1s

Epoch: 1
Train Loss: 0.841 | OA: 60.6% | P: 51.9 | Se: 44.7 | Sp: 88.7 | F1: 44.7 | Kappa: 42.3 | T(avg): 2.472 | lambda(avg): 0.142
Val  Loss: 0.809 | OA: 61.0% | P: 49.9 | Se: 47.3 | Sp: 88.8 | F1: 46.7 | AUC: 85.3 | Kappa: 43.2 | T(avg): 2.472 | lambda(avg): 0.155
Epoch 2/10 Time: 2.1s

Epoch: 2
Train Loss: 0.726 | OA: 64.0% | P: 57.2 | Se: 51.0 | Sp: 89.6 | F1: 52.4 | Kappa: 47.6 | T(avg): 2.472 | lambda(avg): 0.169
Val  Loss: 0.698 | OA: 64.2% | P: 45.8 | Se: 46.4 | Sp: 89.7 | F1: 43.4 | AUC: 87.6 | Kappa: 47.7 | T(avg): 2.472 | lambda(avg): 0.185
Epoch 3/10 Time: 2.1s

Epoch: 3
Train Loss: 0.672 | OA: 65.6% | P: 57.5 | Se: 53.6 | Sp: 90.