# Notebook Description

Below is the complete training code for **AlterNet-LC**, covering data preprocessing, model definition, training, validation, and testing.

The notebook also includes comparative experiments on the **PneumoniaMNIST** dataset with three models: **DenseNet**, **GLCM+SVM**, and **LBP+RF**.

Finally, the trained **AlterNet-LC** model is tested for generalization capability on a **Kaggle dataset**.

#### **Warning:**
The code in this notebook is for reference only. It has **not undergone strict data leakage prevention or logic optimization** and should **not be directly used in production environments**.

Results may vary due to device differences. Default configurations are provided for reference; adjust and debug according to your actual setup.

#### Authors: Li Jiawei; Chen Mingfang; Yao Zehan

MedMNIST provides a corresponding Python library for accessing its datasets.

First, install `medmnist`:
```bash
pip install medmnist
```

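Then download the dataset locally via Python (this may require a VPN in some regions). The training code later reads `pneumoniamnist_224.npz` from the working directory. Copy the downloaded file there (MedMNIST stores it under its data root, `~/.medmnist` by default), or adjust `data_path`: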
```python
import medmnist
from medmnist import INFO
import numpy as np
import torch
import torchvision.transforms as transforms

data_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[.5], std=[.5])
])

data_flag = 'pneumoniamnist'
info = INFO[data_flag]
DataClass = getattr(medmnist, info['python_class'])

train_dataset = DataClass(split='train', transform=data_transform, download=True, size=224, mmap_mode='r')
```

For generalization testing, the external test dataset can be downloaded via the Kaggle API. First, install `kagglehub`:
```bash
pip install kagglehub
```

The first time you use `kagglehub`, you may need to log in to your Kaggle account:
```python
import kagglehub
kagglehub.login()
```

After logging in, download the dataset:
```python
import kagglehub

# Download latest version
path = kagglehub.dataset_download("paultimothymooney/chest-xray-pneumonia")

print("Path to dataset files:", path)
```

# AlterNet-LC

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torchvision import transforms, models
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix, classification_report
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import LinearSegmentedColormap
from tqdm import tqdm
import os
import datetime
import time
# Matplotlib: use the SimHei font so Chinese labels render, and draw minus signs
# with a plain ASCII hyphen instead of the Unicode minus glyph.
plt.rcParams.update({
    'font.sans-serif': ['SimHei'],
    'axes.unicode_minus': False,
})
# Fixed seeds for NumPy and torch so runs are repeatable
np.random.seed(101010)
torch.manual_seed(101010)
# Prefer the GPU whenever PyTorch can see one
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Device: {device}")

In [None]:
# Location of the PneumoniaMNIST archive (224-pixel version), relative to the working directory
data_path = "pneumoniamnist_224.npz"
def ensure_dir(dir_path):
    """Create ``dir_path`` (including missing parents) unless it is already a directory.

    A message is printed only when a directory is actually created. ``exist_ok=True``
    makes the call safe if another process creates the directory between the check
    and ``os.makedirs``.

    Raises:
        FileExistsError: if ``dir_path`` exists but is not a directory.
    """
    if os.path.isdir(dir_path):
        return
    os.makedirs(dir_path, exist_ok=True)
    print(f"make dir: {dir_path}")

# Output folders: evaluation artifacts, with figures in an "images" subfolder
results_dir = "evaluation_results"
img_dir = os.path.join(results_dir, "images")
ensure_dir(results_dir)
ensure_dir(img_dir)

## 1. Window Attention (window partition / merge helpers)

In [None]:
# WindowAttention
class WindowAttention:
    """Static helpers that split a (B, C, H, W) feature map into non-overlapping
    square windows and stitch such windows back into a feature map."""

    @staticmethod
    def partition(x, window_size):
        """
        Split ``x`` into non-overlapping ``window_size`` x ``window_size`` windows.

        The map is zero-padded on the bottom/right so both spatial sizes become
        multiples of ``window_size``.

        Args:
            x: input tensor (B, C, H, W)
            window_size: window side length
        Returns:
            windows: tensor (B*num_windows, C, window_size, window_size), ordered by
                batch, then window row, then window column
            padded_shape: (Hp, Wp) after padding
        """
        batch, channels, height, width = x.shape

        # Padding needed to reach the next multiple of window_size (0 if already aligned)
        extra_h = -height % window_size
        extra_w = -width % window_size
        if extra_h or extra_w:
            x = F.pad(x, (0, extra_w, 0, extra_h))

        padded_h = height + extra_h
        padded_w = width + extra_w
        rows = padded_h // window_size
        cols = padded_w // window_size

        # (B, C, rows, ws, cols, ws) -> (B, rows, cols, C, ws, ws) -> one entry per window
        tiles = x.view(batch, channels, rows, window_size, cols, window_size)
        tiles = tiles.permute(0, 2, 4, 1, 3, 5).contiguous()
        return tiles.view(-1, channels, window_size, window_size), (padded_h, padded_w)

    @staticmethod
    def reverse(windows, window_size, H, W, Hp, Wp):
        """
        Inverse of ``partition``: reassemble windows and crop away the padding.

        Args:
            windows: tensor (B*num_windows, C, window_size, window_size)
            window_size: window side length
            H, W: original (unpadded) height and width
            Hp, Wp: padded height and width returned by ``partition``
        Returns:
            tensor (B, C, H, W)
        """
        n_windows, channels, _, _ = windows.shape
        rows, cols = Hp // window_size, Wp // window_size
        batch = n_windows // (rows * cols)

        # (B, rows, cols, C, ws, ws) -> (B, C, rows, ws, cols, ws) -> (B, C, Hp, Wp)
        grid = windows.view(batch, rows, cols, channels, window_size, window_size)
        merged = grid.permute(0, 3, 1, 4, 2, 5).contiguous().view(batch, channels, Hp, Wp)

        if Hp > H or Wp > W:
            merged = merged[:, :, :H, :W]
        return merged

## 2. Multi-Head Self-Attention (MSA) Block

In [None]:
class MSABlock(nn.Module):
    """Windowed multi-head self-attention block with LayerNorm and residual connections.

    Operates on (B, C, H, W) feature maps with C == ``dim``:
    LayerNorm -> windowed multi-head self-attention -> residual add,
    then LayerNorm -> MLP -> residual add.

    Args:
        dim: number of input/output channels.
        num_heads: number of attention heads.
        head_dim: per-head width; the attention width is ``num_heads * head_dim``.
        window_size: side length of the non-overlapping attention windows.
        mlp_ratio: MLP hidden width as a multiple of ``dim``.
        dropout_rate: dropout after the attention projection and inside the MLP.
    """

    def __init__(self, dim, num_heads, head_dim, window_size=7, mlp_ratio=4., dropout_rate=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.window_size = window_size
        self.mlp_ratio = mlp_ratio
        self.dropout_rate = dropout_rate
        self.total_head_dim = num_heads * head_dim

        # attention layer
        self.qkv = nn.Linear(dim, self.total_head_dim * 3, bias=False)
        self.proj = nn.Linear(self.total_head_dim, dim)
        self.proj_drop = nn.Dropout(dropout_rate)

        # MLP layer
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(mlp_hidden_dim, dim),
            nn.Dropout(dropout_rate)
        )

        # normalization layer
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)

        # Weight Initialization
        nn.init.xavier_uniform_(self.qkv.weight)
        nn.init.xavier_uniform_(self.proj.weight)

    def forward(self, x):
        """Apply windowed self-attention and the MLP to ``x`` of shape (B, C, H, W).

        If H or W is not a multiple of ``window_size``, the map is zero-padded before
        windowing. Padded positions are masked out as attention keys so they do not
        dilute the attention weights of real positions.
        """
        B, C, H, W = x.shape
        shortcut1 = x

        # normalization over channels (LayerNorm needs C as the last dim)
        x = x.permute(0, 2, 3, 1)  # B, H, W, C
        x = self.norm1(x)

        # Window segmentation -> (B*num_windows, ws*ws, C) token sequences
        x_windows, (Hp, Wp) = WindowAttention.partition(x.permute(0, 3, 1, 2), self.window_size)
        x_windows = x_windows.permute(0, 2, 3, 1)
        win_seq_len = self.window_size * self.window_size
        x_windows_seq = x_windows.reshape(-1, win_seq_len, C)

        # generate QKV: (B_win, seq, 3, heads, head_dim)
        qkv = self.qkv(x_windows_seq)
        qkv = qkv.reshape(-1, win_seq_len, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(dim=2)

        # attention calculation: (B_win, heads, seq, head_dim)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        scale = self.head_dim ** -0.5
        attn = (q @ k.transpose(-2, -1)) * scale  # (B_win, heads, seq, seq)

        if Hp != H or Wp != W:
            # Mark real pixels with 1 and window them exactly like the features, then
            # forbid attending to padded keys. Padding is smaller than one window, so
            # every window keeps at least one real key and no row becomes all -inf.
            valid = x.new_ones(B, 1, H, W)
            valid_windows, _ = WindowAttention.partition(valid, self.window_size)
            key_is_valid = valid_windows.reshape(-1, 1, 1, win_seq_len) > 0
            attn = attn.masked_fill(~key_is_valid, float('-inf'))

        attn = attn.softmax(dim=-1)

        # weighted sum, heads concatenated back to (B_win, seq, heads*head_dim)
        attn_output = (attn @ v)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(-1, win_seq_len, self.total_head_dim)

        # projection back to dim channels
        attn_output = self.proj(attn_output)
        attn_output = self.proj_drop(attn_output)

        # Restore window format (B_win, C, ws, ws)
        attn_output = attn_output.reshape(-1, self.window_size, self.window_size, C)
        attn_output = attn_output.permute(0, 3, 1, 2)

        # merge windows and crop padding
        attn_merged = WindowAttention.reverse(attn_output, self.window_size, H, W, Hp, Wp)

        # first residual link
        x = shortcut1 + attn_merged

        # MLP
        shortcut2 = x
        x = x.permute(0, 2, 3, 1)
        x = self.norm2(x)
        x = self.mlp(x)
        x = x.permute(0, 3, 1, 2)

        # second residual link
        x = shortcut2 + x

        return x

## 3. Contrastive Self-Attention Block

In [None]:
class ContrastiveMSABlock(MSABlock):
    """MSA block with an extra supervised-contrastive projection head.

    The block's feature output is that of ``MSABlock``. When labels are given, the
    output is globally average-pooled, projected to ``proj_dim`` dimensions, and a
    supervised contrastive loss with temperature ``temperature`` is computed.
    """

    def __init__(self, dim, num_heads, head_dim, window_size=7, mlp_ratio=4., dropout_rate=0.1,
                 proj_dim=128, temperature=0.07):
        super().__init__(dim, num_heads, head_dim, window_size, mlp_ratio, dropout_rate)

        # Contrastive learning temperature (divides the cosine similarities)
        self.temperature = temperature

        # Contrastive projection head: dim -> dim // 2 -> proj_dim
        self.contrast_projection = nn.Sequential(
            nn.Linear(dim, dim // 2),
            nn.LayerNorm(dim // 2),
            nn.ReLU(inplace=True),
            nn.Linear(dim // 2, proj_dim)
        )

        # Initialize projection head weights
        for m in self.contrast_projection.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def contrastive_loss(self, features, labels):
        """Supervised contrastive loss over a batch of embeddings.

        Args:
            features: (N, D) embeddings; L2-normalized here.
            labels: N class labels; samples sharing a label form positive pairs.

        Returns:
            Scalar tensor: the negative mean log-probability of each sample's positives
            (softmax over all other samples), averaged over samples that have at least
            one positive. Returns 0 when no sample has a positive.
        """
        # Normalize so the dot product is cosine similarity
        features = F.normalize(features, dim=1)
        similarity_matrix = torch.matmul(features, features.T) / self.temperature

        # Same-class mask without the diagonal = positive pairs
        labels = labels.view(-1, 1)
        mask_same_class = torch.eq(labels, labels.T)
        mask_self = torch.eye(mask_same_class.shape[0], dtype=torch.bool, device=mask_same_class.device)
        mask_positives = mask_same_class & ~mask_self

        positive_per_sample = mask_positives.sum(1)
        valid_samples = positive_per_sample > 0

        if valid_samples.sum() == 0:
            return torch.tensor(0.0, device=features.device)

        # Log-softmax over all *other* samples. logsumexp subtracts the row maximum
        # internally, so small temperatures cannot overflow exp() the way a direct
        # exp(similarity) would.
        logits = similarity_matrix.masked_fill(mask_self, float('-inf'))
        log_prob = logits - torch.logsumexp(logits, dim=1, keepdim=True)

        # Select positives with torch.where rather than multiplying by the mask:
        # the diagonal is -inf, and 0 * -inf would give NaN.
        log_prob_positives = torch.where(mask_positives, log_prob, torch.zeros_like(log_prob))
        mean_log_prob_positives = log_prob_positives.sum(1) / positive_per_sample.clamp(min=1)

        # Average only over samples that actually have positives
        return -mean_log_prob_positives[valid_samples].mean()

    def forward(self, x, labels=None):
        """Forward Propagation

        Args:
            x: Input features [B, C, H, W]
            labels: Optional, class labels [B]

        Returns:
            x: Processed features
            contrast_loss: Contrastive loss value (0.0 when labels is None)
        """
        # Basic MSA block output
        x = super().forward(x)

        # Without labels there is no contrastive objective
        if labels is None:
            return x, 0.0

        # Global average pooling gives one feature vector per sample
        batch_size = x.shape[0]
        pooled_features = F.adaptive_avg_pool2d(x, (1, 1)).view(batch_size, -1)

        # Map into the contrastive embedding space
        projected_features = self.contrast_projection(pooled_features)

        contrast_loss = self.contrastive_loss(projected_features, labels)

        return x, contrast_loss

## 4. Pre-Activation Residual Block

In [None]:
class PreActResidualBlock(nn.Module):
    """Pre-activation residual block: (BN -> ReLU -> 3x3 conv) twice plus a shortcut.

    Args:
        in_channels: input channels.
        out_channels: output channels.
        stride: stride of the first 3x3 convolution (and of the projection shortcut).
        use_projection: force a 1x1-conv projection shortcut. A projection is also
            enabled automatically whenever an identity shortcut could not match the
            residual branch's shape (``stride != 1`` or ``in_channels != out_channels``).
    """

    def __init__(self, in_channels, out_channels, stride=1, use_projection=False):
        super().__init__()
        # An identity shortcut only works when input and output shapes agree
        self.use_projection = use_projection or stride != 1 or in_channels != out_channels

        # Projection Shortcut
        if self.use_projection:
            self.proj_conv = nn.Conv2d(
                in_channels, out_channels, kernel_size=1,
                stride=stride, padding=0, bias=False
            )

        # primary path
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3,
            stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3,
            stride=1, padding=1, bias=False
        )

        # Weight Initialisation
        nn.init.kaiming_normal_(self.conv1.weight, mode='fan_out', nonlinearity='relu')
        nn.init.kaiming_normal_(self.conv2.weight, mode='fan_out', nonlinearity='relu')
        if self.use_projection:
            nn.init.kaiming_normal_(self.proj_conv.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, x):
        shortcut = x

        # pre-activation
        x = self.bn1(x)
        x = F.relu(x)

        # The projection shortcut consumes the pre-activated input
        if self.use_projection:
            shortcut = self.proj_conv(x)

        # The first convolution
        x = self.conv1(x)
        x = self.bn2(x)
        x = F.relu(x)

        # The second convolution
        x = self.conv2(x)

        # Shortcut Connection
        return shortcut + x

## 5. AlterNet-LC Model Architecture

In [None]:
class AlterNet_LC(nn.Module):
    """AlterNet-LC: pre-activation residual stages, with stages 2+ each ending in a
    windowed self-attention block that carries a supervised contrastive head.

    Args:
        in_channels: input image channels.
        num_classes: number of output logits.
        blocks_per_stage: blocks per stage. Stage 1 is pure CNN. Every later stage
            needs at least 2 blocks: a downsampling residual block and a final
            ``ContrastiveMSABlock``.
        initial_filters: channel width after the stem; doubled in each later stage.
        head_counts: attention heads for each stage after the first.
        head_dim: per-head dimension of the attention blocks.
        window_size: attention window side length.
        dropout_rate: dropout before the final linear classifier.

    Raises:
        ValueError: if a stage after the first has fewer than 2 blocks, or
            ``head_counts`` has fewer entries than there are attention stages.
    """

    def __init__(
        self,
        in_channels=3,
        num_classes=2,
        blocks_per_stage=(3, 4, 6, 3),
        initial_filters=128,
        head_counts=(6, 12, 24),
        head_dim=64,
        window_size=7,
        dropout_rate=0.3
    ):
        super().__init__()

        # Validate configuration up front instead of building a network that fails in forward()
        attention_stage_blocks = list(blocks_per_stage[1:])
        if any(n < 2 for n in attention_stage_blocks):
            raise ValueError(
                "each stage after the first needs at least 2 blocks "
                "(a downsampling block and an attention block)"
            )
        if len(head_counts) < len(attention_stage_blocks):
            raise ValueError("head_counts must provide one entry per stage after the first")

        # Stem: 7x7 stride-2 conv + 3x3 stride-2 max-pool
        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, initial_filters, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(initial_filters),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

        self.stages = nn.ModuleList()
        current_filters = initial_filters

        # Stage 1 (pure CNN)
        stage1 = nn.Sequential()
        stage1.add_module('block0', PreActResidualBlock(current_filters, current_filters, use_projection=True))
        for i in range(1, blocks_per_stage[0]):
            stage1.add_module(f'block{i}', PreActResidualBlock(current_filters, current_filters))
        self.stages.append(stage1)

        # Stages 2+ (CNN + MSA with contrastive head). The attention blocks are also
        # collected in contrastive_blocks (kept for state_dict compatibility).
        self.contrastive_blocks = nn.ModuleList()

        for stage_idx, num_blocks in enumerate(attention_stage_blocks):
            next_filters = current_filters * 2
            num_heads = head_counts[stage_idx]

            stage = nn.Sequential()

            # First block of the stage halves the resolution and doubles the channels
            stage.add_module(
                'block0',
                PreActResidualBlock(current_filters, next_filters, stride=2, use_projection=True)
            )

            # Intermediate CNN blocks
            for i in range(1, num_blocks - 1):
                stage.add_module(f'block{i}', PreActResidualBlock(next_filters, next_filters))

            # The last block is MSA with contrastive learning
            contrastive_block = ContrastiveMSABlock(
                dim=next_filters,
                num_heads=num_heads,
                head_dim=head_dim,
                window_size=window_size,
                temperature=0.07
            )
            stage.add_module(f'block{num_blocks-1}', contrastive_block)
            self.contrastive_blocks.append(contrastive_block)

            self.stages.append(stage)
            current_filters = next_filters

        # Classification Head
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(dropout_rate)
        self.fc = nn.Linear(current_filters, num_classes)

        # Weight Initialisation (overrides per-block conv init)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x, labels=None):
        """Return logits, plus the summed contrastive loss in training mode with labels.

        Returns:
            ``(logits, contrastive_loss_sum)`` when ``self.training`` and ``labels`` is
            given; otherwise ``logits`` only.
        """
        x = self.stem(x)

        contrastive_loss_sum = 0.0

        # Stage 1
        x = self.stages[0](x)

        # Stages 2+: attention blocks also return their contrastive loss
        for stage in self.stages[1:]:
            for block in stage.children():
                if isinstance(block, ContrastiveMSABlock):
                    x, stage_contrast_loss = block(x, labels)
                    contrastive_loss_sum += stage_contrast_loss
                else:
                    x = block(x)

        # Classification Head
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        x = self.fc(x)

        if self.training and labels is not None:
            return x, contrastive_loss_sum
        return x

## 6. Pneumonia Imaging Dataset Wrapper

In [None]:
class MedicalDataset(torch.utils.data.Dataset):
    """Map-style dataset over in-memory image and label tensors.

    Args:
        images: indexable collection of image tensors stored channels-first (C, H, W).
        labels: indexable collection of labels aligned with ``images``.
        transform: optional callable. It receives each image as a channels-last
            (H, W, C) NumPy array and must return a tensor.

    Items are ``(image.float(), label)``.
    """

    def __init__(self, images, labels, transform=None):
        self.images, self.labels, self.transform = images, labels, transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        sample, target = self.images[idx], self.labels[idx]
        if not self.transform:
            return sample.float(), target
        # PIL-based transforms (e.g. ToPILImage) expect channels-last arrays
        channels_last = sample.numpy().transpose(1, 2, 0)
        return self.transform(channels_last).float(), target

## 7. Data Loading and Preprocessing

In [None]:
def load_and_preprocess_data(data_path, batch_size=64):
    """Load the train/val/test splits from an ``.npz`` archive and wrap them in DataLoaders.

    The archive must contain ``train_images``, ``train_labels``, ``val_images``,
    ``val_labels``, ``test_images`` and ``test_labels``. Images are repeated to
    3 channels if they have one (or no) channel axis, transposed to (N, C, H, W),
    and stored as uint8 tensors. The training split uses random horizontal flip,
    +/-15 degree rotation and brightness/contrast jitter. All splits finish with
    ``ToTensor`` and ``Normalize(mean=0.5, std=0.5)``.

    Args:
        data_path: path to the ``.npz`` file.
        batch_size: batch size for all three loaders.

    Returns:
        (train_loader, val_loader, test_loader); only the training loader shuffles.
    """
    # Context manager closes the NpzFile handle once the arrays have been read
    with np.load(data_path) as data:
        train_images = data['train_images']
        train_labels = data['train_labels'].squeeze()
        test_images = data['test_images']
        test_labels = data['test_labels'].squeeze()
        val_images = data['val_images']
        val_labels = data['val_labels'].squeeze()

        # Check data distribution
        print("Original label values:", np.unique(data['train_labels']))
    print("Training labels distribution:", np.unique(train_labels, return_counts=True))
    print("Validation labels distribution:", np.unique(val_labels, return_counts=True))
    print("Test labels distribution:", np.unique(test_labels, return_counts=True))

    # Ensure images have 3 channels (RGB), channels-last
    def ensure_3channel(images):
        if images.ndim == 3:
            images = np.expand_dims(images, axis=-1)
        if images.shape[-1] == 1:
            images = np.repeat(images, 3, axis=-1)
        return images

    # Convert to PyTorch layout (N, C, H, W); normalization happens in the transforms
    def to_channels_first(images):
        return np.transpose(images, (0, 3, 1, 2))

    train_images = to_channels_first(ensure_3channel(train_images))
    val_images = to_channels_first(ensure_3channel(val_images))
    test_images = to_channels_first(ensure_3channel(test_images))

    # Data augmentation transforms
    train_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])

    basic_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])

    # Create PyTorch datasets
    train_dataset = MedicalDataset(
        torch.tensor(train_images, dtype=torch.uint8),
        torch.tensor(train_labels, dtype=torch.long),
        transform=train_transform
    )

    val_dataset = MedicalDataset(
        torch.tensor(val_images, dtype=torch.uint8),
        torch.tensor(val_labels, dtype=torch.long),
        transform=basic_transform
    )

    test_dataset = MedicalDataset(
        torch.tensor(test_images, dtype=torch.uint8),
        torch.tensor(test_labels, dtype=torch.long),
        transform=basic_transform
    )

    # Pinned memory only helps host-to-GPU copies; skip it on CPU-only machines
    pin_memory = torch.cuda.is_available()

    def make_loader(dataset, shuffle):
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            pin_memory=pin_memory,
        )

    train_loader = make_loader(train_dataset, shuffle=True)
    val_loader = make_loader(val_dataset, shuffle=False)
    test_loader = make_loader(test_dataset, shuffle=False)

    return train_loader, val_loader, test_loader

## 8. LMF Loss

In [None]:
class LMFLoss(nn.Module):
    """Binary focal loss plus a probability-margin penalty, computed per sample.

    For each sample the loss is
    ``(w[y] * (1 - p_y) ** gamma * CE + alpha * margin_term) * sample_weight``,
    and the batch loss is the mean over samples. Here ``p_y`` is the predicted
    probability of the true class. ``sample_weight`` is 1.1 for class-0 samples
    whose class-1 probability exceeds 0.25, and 1.0 otherwise.

    Args:
        class_weights: optional tensor of 2 per-class weights applied to the focal
            term (``None`` = unweighted).
        gamma: focal focusing exponent.
        margin: base margin. Class-0 targets use ``margin * 1.15``; class-1 targets
            use ``margin``.
        alpha: weight of the margin term relative to the focal term.
    """

    def __init__(self, class_weights, gamma=2.6, margin=0.7, alpha=0.27):
        super().__init__()
        self.gamma = gamma
        self.margin = margin
        self.alpha = alpha
        self.class_weights = class_weights

    def forward(self, inputs, targets):
        """Return the scalar batch loss for logits ``inputs`` (N, 2) and integer ``targets`` (N,)."""
        # Per-sample, unweighted CE so that exp(-CE) is the true-class probability.
        # Focal modulation must apply to each sample, not to the batch-mean loss.
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = (1 - pt) ** self.gamma * ce_loss
        if self.class_weights is not None:
            weights = self.class_weights.to(device=focal_loss.device, dtype=focal_loss.dtype)
            focal_loss = focal_loss * weights[targets]

        # Margin loss
        targets_onehot = F.one_hot(targets, num_classes=2).float()
        probs = inputs.softmax(dim=1)

        # Larger margin for negative (class 0) targets
        neg_margin = self.margin * 1.15
        pos_margin = self.margin

        margins = torch.zeros_like(targets_onehot)
        margins[:, 0] = neg_margin
        margins[:, 1] = pos_margin

        # Margin of each sample's target class, shrunk as predictions approach the target
        selected_margins = (margins * targets_onehot).sum(1, keepdim=True)
        margins = selected_margins * (1 - torch.abs(targets_onehot - probs))

        # Per-sample margin loss (mean over the two class columns)
        margin_loss = F.relu(1 - (probs * (2 * targets_onehot - 1) - margins)).mean(dim=1)

        # Extra weight for negatives the model leans towards calling positive
        neg_samples = (targets == 0)
        misclassified_neg = neg_samples & (probs[:, 1] > 0.25)

        sample_weights = torch.ones_like(targets, dtype=torch.float32)
        sample_weights[misclassified_neg] = 1.1

        weighted_loss = (focal_loss + self.alpha * margin_loss) * sample_weights

        return weighted_loss.mean()

## 9. `train_model` Function

In [None]:
def train_model(model, train_loader, val_loader, epochs=100, lr=5e-5, weight_decay=1.6e-4):
    """Train and validate a PyTorch model with LMF loss plus a contrastive auxiliary loss.

    Each training step calls ``model(inputs, labels)``, which in train mode must return
    ``(logits, contrastive_loss)``. The optimized loss is
    ``LMF + 0.45 * contrastive_loss``. Validation calls ``model(inputs)`` and uses the
    LMF loss only. Batches are moved to the module-level ``device``.

    The learning rate is halved by ReduceLROnPlateau after 5 epochs without
    validation-loss improvement. Training stops early after 15 such epochs. Whenever the
    validation loss improves, a snapshot of the weights is kept in memory and saved to
    ``alternet_contrastive_<YYYYMMDD>.pth``.

    Args:
        model: network to train (modified in place).
        train_loader: loader of ``(inputs, labels)`` training batches.
        val_loader: loader of ``(inputs, labels)`` validation batches.
        epochs: maximum number of epochs.
        lr: initial AdamW learning rate.
        weight_decay: AdamW weight decay.

    Returns:
        (model, history): the model with the best-validation-loss weights loaded, and a
        dict of per-epoch lists 'train_loss', 'train_acc', 'train_auc', 'val_loss',
        'val_acc', 'val_auc'. Reported losses are the LMF term only.
    """

    # Optimizer
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    # Learning rate scheduler (LR reductions are logged manually below; the
    # ``verbose`` argument is deprecated in recent PyTorch releases)
    scheduler = ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5, min_lr=1e-6
    )

    # Loss function
    criterion = LMFLoss(
        class_weights=torch.tensor([1.3, 1.25]).to(device),
        gamma=2.6,
        margin=0.7,
        alpha=0.27
    )

    # Contrastive learning weight
    contrastive_weight = 0.45

    # Track best model
    best_val_loss = float('inf')
    best_model_state = None
    patience_counter = 0
    max_patience = 15

    # Record training history
    history = {
        'train_loss': [], 'train_acc': [], 'train_auc': [],
        'val_loss': [], 'val_acc': [], 'val_auc': []
    }

    for epoch in range(epochs):
        # Training phase
        model.train()
        train_loss, train_correct, train_total = 0.0, 0, 0
        train_outputs_all, train_labels_all = [], []

        for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} [Train]"):
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs, contrastive_loss = model(inputs, labels)

            lmf_loss = criterion(outputs, labels)

            # Total loss = LMF loss + contrastive learning weight * contrastive loss
            total_loss = lmf_loss + contrastive_weight * contrastive_loss

            total_loss.backward()
            optimizer.step()

            # Statistics collection (LMF term only, comparable with validation)
            train_loss += lmf_loss.item() * inputs.size(0)
            _, preds = torch.max(outputs, 1)
            train_correct += torch.sum(preds == labels).item()
            train_total += inputs.size(0)
            train_outputs_all.append(F.softmax(outputs, dim=1).detach().cpu().numpy())
            train_labels_all.append(labels.cpu().numpy())

        train_loss = train_loss / train_total
        train_acc = train_correct / train_total

        train_outputs_all = np.vstack(train_outputs_all)
        train_labels_all = np.concatenate(train_labels_all)

        # Training AUC from class-1 probabilities
        train_fpr, train_tpr, _ = roc_curve(
            (train_labels_all == 1).astype(int),
            train_outputs_all[:, 1]
        )
        train_auc_score = auc(train_fpr, train_tpr)

        # Validation phase
        model.eval()
        val_loss, val_correct, val_total = 0.0, 0, 0
        val_outputs_all, val_labels_all = [], []

        with torch.no_grad():
            for inputs, labels in tqdm(val_loader, desc=f"Epoch {epoch+1}/{epochs} [Val]"):
                inputs = inputs.to(device)
                labels = labels.to(device)
                outputs = model(inputs)

                # Only use classification loss in validation phase
                lmf_loss = criterion(outputs, labels)
                val_loss += lmf_loss.item() * inputs.size(0)

                _, preds = torch.max(outputs, 1)
                val_correct += torch.sum(preds == labels).item()
                val_total += inputs.size(0)
                val_outputs_all.append(F.softmax(outputs, dim=1).detach().cpu().numpy())
                val_labels_all.append(labels.cpu().numpy())

        val_loss = val_loss / val_total
        val_acc = val_correct / val_total

        val_outputs_all = np.vstack(val_outputs_all)
        val_labels_all = np.concatenate(val_labels_all)

        val_fpr, val_tpr, _ = roc_curve(
            (val_labels_all == 1).astype(int),
            val_outputs_all[:, 1]
        )
        val_auc_score = auc(val_fpr, val_tpr)

        # Update learning rate and report any reduction
        lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(val_loss)
        lr_after = optimizer.param_groups[0]['lr']

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['train_auc'].append(train_auc_score)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        history['val_auc'].append(val_auc_score)

        print(f"Epoch {epoch+1}/{epochs}:")
        print(f"  Train - Loss: {train_loss:.4f}, Acc: {train_acc:.4f}, AUC: {train_auc_score:.4f}")
        print(f"  Val   - Loss: {val_loss:.4f}, Acc: {val_acc:.4f}, AUC: {val_auc_score:.4f}")
        if lr_after < lr_before:
            print(f"  Learning rate reduced: {lr_before:.2e} -> {lr_after:.2e}")

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # state_dict() holds references to the live parameter/buffer tensors, so a
            # shallow dict copy would keep changing as training continues. Clone every
            # tensor to freeze this snapshot.
            best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            patience_counter = 0
            print(f"  New best model saved! (Val Loss: {val_loss:.4f})")

            torch.save(best_model_state, f'alternet_contrastive_{datetime.datetime.now().strftime("%Y%m%d")}.pth')
        else:
            patience_counter += 1
            print(f"  No improvement for {patience_counter} epochs")

        # Early stopping
        if patience_counter >= max_patience:
            print(f"Early stopping triggered after {epoch+1} epochs")
            break

    # Restore the best weights seen during training
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    return model, history


## 10. Training Curve Plotting

In [None]:
def plot_training(history):
    """Plot training/validation accuracy and loss curves side by side.

    Args:
        history: dict with per-epoch lists under 'train_acc', 'val_acc',
            'train_loss' and 'val_loss' (as returned by ``train_model``).

    Draws into the first two cells of a 2x2 grid and shows the figure.
    """
    # (grid position, history key suffix, human-readable metric name)
    panels = [
        (1, 'acc', 'Accuracy'),
        (2, 'loss', 'Loss'),
    ]

    plt.figure(figsize=(16, 10))
    for position, key, metric in panels:
        plt.subplot(2, 2, position)
        plt.plot(history[f'train_{key}'], label=f'Training {metric}')
        plt.plot(history[f'val_{key}'], label=f'Validation {metric}')
        plt.title(f'{metric} vs. Epochs')
        plt.xlabel('Epoch')
        plt.ylabel(metric)
        plt.legend()

    plt.tight_layout()
    plt.show()

## 11. Model Evaluation & Result Saving

In [None]:
def evaluate_model(model, test_loader, save_results=False, threshold=0.985):
    """Evaluate a trained binary classifier on a test loader and report metrics.

    Runs the model in eval mode without gradients and computes the test loss
    with a fixed ``LMFLoss`` configuration. The softmax probability of
    class 1 is turned into a hard prediction with a custom decision threshold.
    Results are printed to stdout and, optionally, written to a text report.

    Args:
        model: Model whose forward returns logits with the positive class in
            column 1, or a tuple whose first element is those logits.
        test_loader: Iterable yielding ``(inputs, labels)`` batches.
        save_results: If True, also write a report to
            ``results_dir/evaluation_results_<timestamp>.txt``.
        threshold: Positive-class probability at or above which a sample is
            predicted positive (default 0.985).

    Returns:
        dict with the scalar metrics 'accuracy', 'precision', 'recall',
        'specificity', 'f1', 'npv', 'auc', 'pr_auc' and 'loss'. It also holds:

        * 'confusion_matrix': a 2x2 array with rows as true classes and
          columns as predicted classes.
        * 'class_report': the classification report string.
        * ROC/PR curve arrays: 'fpr', 'tpr', 'pr_curve', 'recall_curve'.
        * Per-sample arrays: 'all_preds', 'all_labels', and 'all_outputs'
          (positive-class probabilities).

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    # Timestamp that identifies this evaluation's report file.
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

    model.eval()
    test_loss = 0.0
    test_correct = 0
    test_total = 0
    all_preds = []
    all_labels = []
    all_outputs = []

    # Loss used only to report the test loss.
    criterion = LMFLoss(
        class_weights=torch.tensor([1.3, 1.25]).to(device),
        gamma=2.6,
        margin=0.7,
        alpha=0.27
    )

    with torch.no_grad():
        for inputs, labels in tqdm(test_loader, desc="Evaluating Model"):
            inputs = inputs.to(device)
            labels = labels.to(device)

            # No contrastive loss is needed during evaluation. If the model
            # returns a tuple, its first element is taken as the logits.
            outputs = model(inputs)
            if isinstance(outputs, tuple):
                outputs = outputs[0]

            loss = criterion(outputs, labels)

            # Positive-class probability, thresholded with a custom cut-off.
            probs = F.softmax(outputs, dim=1)[:, 1]
            preds = (probs >= threshold).long()

            # Flatten labels so that a (batch, 1) label tensor would be
            # compared element-wise with the (batch,) predictions instead of
            # broadcasting to (batch, batch). No-op for already-flat labels.
            flat_labels = labels.reshape(-1)

            test_loss += loss.item() * inputs.size(0)
            test_correct += torch.sum(preds == flat_labels).item()
            test_total += inputs.size(0)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(flat_labels.cpu().numpy())
            all_outputs.extend(probs.cpu().numpy())

    if test_total == 0:
        raise ValueError("test_loader yielded no samples; cannot compute metrics")

    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)
    all_outputs = np.array(all_outputs)

    # Base metrics
    accuracy = test_correct / test_total
    mean_loss = test_loss / test_total
    precision = precision_score(all_labels, all_preds)
    recall = recall_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds)

    # Pin the label set so the matrix is always 2x2 even if one class never
    # appears; otherwise ravel() cannot unpack four values.
    # sklearn convention: rows = true class, columns = predicted class.
    cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()
    specificity = _safe_div(tn, tn + fp)
    npv = _safe_div(tn, tn + fn)

    # AUC and AP
    fpr, tpr, _ = roc_curve(all_labels, all_outputs)
    roc_auc = auc(fpr, tpr)
    precision_curve, recall_curve, _ = precision_recall_curve(all_labels, all_outputs)
    pr_auc = average_precision_score(all_labels, all_outputs)

    class_report = classification_report(all_labels, all_preds, target_names=['Negative', 'Positive'])

    # Per-class accuracy: fraction of samples of each TRUE class (a row of
    # cm) that were classified correctly.
    neg_total = cm[0, 0] + cm[0, 1]
    pos_total = cm[1, 0] + cm[1, 1]
    neg_acc = _safe_div(cm[0, 0], neg_total)
    pos_acc = _safe_div(cm[1, 1], pos_total)

    cm_lines = [
        "True\\Prediction  Negative(0)  Positive(1)",
        f"Negative(0)    {cm[0, 0]}      {cm[0, 1]}",
        f"Positive(1)    {cm[1, 0]}      {cm[1, 1]}",
    ]
    class_lines = [
        f"Accuracy of Negative Samples: {neg_acc*100:.2f}% ({cm[0, 0]}/{neg_total})",
        f"Accuracy of Positive Samples: {pos_acc*100:.2f}% ({cm[1, 1]}/{pos_total})",
    ]

    evaluation_results = {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'specificity': specificity,
        'f1': f1,
        'npv': npv,
        'auc': roc_auc,
        'pr_auc': pr_auc,
        'confusion_matrix': cm,
        'loss': mean_loss,
        'class_report': class_report,
        'fpr': fpr,
        'tpr': tpr,
        'pr_curve': precision_curve,
        'recall_curve': recall_curve,
        'all_preds': all_preds,
        'all_labels': all_labels,
        'all_outputs': all_outputs
    }

    if save_results:
        result_file = os.path.join(results_dir, f"evaluation_results_{timestamp}.txt")
        _write_evaluation_report(
            result_file, evaluation_results, cm_lines, class_lines,
            fnr=_safe_div(fn, fn + tp),
        )

    # Console output
    print("\nTest Results:")
    print(f"  Loss: {mean_loss:.4f}")
    print(f"  Accuracy: {accuracy*100:.2f}%")
    print(f"  Precision: {precision:.4f}")
    print(f"  Recall: {recall:.4f}")
    print(f"  Specificity: {specificity:.4f}")
    print(f"  F1 Score: {f1:.4f}")
    print(f"  AUC: {roc_auc:.4f}")

    print("\nConfusion Matrix:")
    print("\n".join(cm_lines))

    print(f"\n  {class_lines[0]}")
    print(f"  {class_lines[1]}")

    return evaluation_results


def _safe_div(num, den):
    """Return ``num / den`` as a float, or NaN when ``den`` is zero."""
    return float(num) / float(den) if den else float('nan')


def _write_evaluation_report(result_file, results, cm_lines, class_lines, fnr):
    """Write the plain-text evaluation report produced by ``evaluate_model``."""
    accuracy = results['accuracy']
    verdict = 'excellent' if accuracy > 0.9 else 'good' if accuracy > 0.8 else 'fair'

    with open(result_file, 'w', encoding='utf-8') as f:
        # Basic information
        f.write("=" * 50 + "\n")
        f.write("Pneumonia X-ray Image Classification Model Evaluation Results\n")
        f.write(f"Evaluation Time: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write("=" * 50 + "\n\n")

        # Performance metrics
        f.write("Basic Performance Metrics:\n")
        f.write("-" * 40 + "\n")
        f.write(f"Test Loss: {results['loss']:.4f}\n")
        f.write(f"Accuracy (Accuracy): {accuracy*100:.2f}%\n")
        f.write(f"Precision (Precision): {results['precision']:.4f}\n")
        f.write(f"Recall/Sensitivity (Recall/Sensitivity): {results['recall']:.4f}\n")
        f.write(f"Specificity (Specificity): {results['specificity']:.4f}\n")
        f.write(f"F1 Score: {results['f1']:.4f}\n")
        f.write(f"Negative Predictive Value (NPV): {results['npv']:.4f}\n")
        f.write(f"AUC: {results['auc']:.4f}\n")
        f.write(f"PR-AUC: {results['pr_auc']:.4f}\n\n")

        # Confusion matrix
        f.write("Confusion Matrix:\n")
        f.write("-" * 40 + "\n")
        f.write("\n".join(cm_lines) + "\n\n")

        # Per-class performance
        f.write("Class Detailed Performance:\n")
        f.write("-" * 40 + "\n")
        f.write("\n".join(class_lines) + "\n\n")

        # Classification report
        f.write("Classification Report:\n")
        f.write("-" * 40 + "\n")
        f.write(results['class_report'] + "\n\n")

        # Conclusion
        f.write("Evaluation Conclusion:\n")
        f.write("-" * 40 + "\n")
        f.write(f"The model's performance in pneumonia detection is {verdict}.\n")
        f.write(f"It is worth noting that the false negative rate is {fnr:.4f}, which is particularly important in medical diagnosis.\n")

        # Placeholder section for generated chart files
        f.write("\nGenerated Chart Files:\n")
        f.write("-" * 40 + "\n")

## 12. Run

In [None]:
# Build the train/val/test loaders.
print("Load Data...")
train_loader, val_loader, test_loader = load_and_preprocess_data(data_path, batch_size=64)

# AlterNet-LC architecture hyperparameters for this run.
model_config = dict(
    in_channels=3,
    num_classes=2,
    blocks_per_stage=[3, 4, 6, 3],
    initial_filters=128,
    head_counts=[6, 12, 24],
    head_dim=64,
    window_size=7,
    dropout_rate=0.3,
)
model = AlterNet_LC(**model_config).to(device)

# Report the total number of parameters (trainable and frozen).
total_params = sum(param.numel() for param in model.parameters())
print(f"Total model parameters: {total_params:,}")

In [None]:
# Optimisation settings for this training run.
print("Start Training...")
training_kwargs = {'epochs': 100, 'lr': 5e-5, 'weight_decay': 1.6e-4}
model, history = train_model(model, train_loader, val_loader, **training_kwargs)

In [None]:
# Visualise the per-epoch metrics returned by train_model.
plot_training(history=history)

In [None]:
# Evaluate on the test split and unpack what the plotting cells below need.
evaluate_results = evaluate_model(model, test_loader)
fpr, tpr = evaluate_results['fpr'], evaluate_results['tpr']
pr_auc = evaluate_results['pr_auc']
roc_auc = auc(fpr, tpr)
recall_curve, precision_curve = evaluate_results['recall_curve'], evaluate_results['pr_curve']
all_labels = evaluate_results['all_labels']

In [None]:
# 1. ROC curve against the random-guess diagonal.
fig, ax = plt.subplots(figsize=(10, 8))
ax.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC Curve (AUC = {roc_auc:.4f})')
ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='random guess')
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel('false positive rate')
ax.set_ylabel('true positive rate')
ax.set_title('ROC Curve')
ax.legend(loc="lower right")
ax.grid(True, linestyle='--', alpha=0.7)

plt.show()

In [None]:
# 2. Precision-recall curve; the random-guess baseline is the positive-class rate.
positive_rate = np.sum(all_labels) / len(all_labels)

fig, ax = plt.subplots(figsize=(10, 8))
ax.plot(recall_curve, precision_curve, color='green', lw=2,
        label=f'PR Curve (AP = {pr_auc:.4f})')
ax.axhline(y=positive_rate, color='navy', linestyle='--', label='random guess')
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel('recall rate')
ax.set_ylabel('precision rate')
ax.set_title('PR Curve')
ax.legend(loc="lower left")
ax.grid(True, linestyle='--', alpha=0.7)

plt.show()

In [None]:
# Save the final model weights. The filename is built once so that the
# printed name always matches the file actually written; two separate
# datetime.now() calls could straddle midnight and disagree.
print("Saving the final model...")
final_model_path = f'alternet_lc_{datetime.datetime.now().strftime("%Y%m%d")}.pth'
torch.save(model.state_dict(), final_model_path)
print(f"Final model saved: {final_model_path}")