In [3]:
#%% -------- 1. Configuration & Imports --------
import lightning.pytorch as pl
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, models
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision.datasets import ImageFolder
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report
import os
import splitfolders
from torch.optim.lr_scheduler import OneCycleLR
import torchmetrics
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.cuda.amp import autocast

# Configuration
NUM_CLASSES = 4
IMG_SIZE = 384
BATCH_SIZE = 32
NUM_WORKERS = 4
PRECISION = '16-mixed'
EPOCHS = 150
LR = 2e-4
WARMUP_PCT = 0.1
DATA_ROOT = "C:\\Users\\DELL 5540\\Desktop\\Brachial\\2nd Classification"
SPLIT_ROOT = "C:\\Users\\DELL 5540\\Desktop\\Brachial\\Split_Dataset"
SEED = 42
CLASS_NAMES = ['Type 0', 'Type 1 Neurapraxia', 'Type 2 Axonotemsis', 'Type 3 Neurotmesis']

# Reproducibility
pl.seed_everything(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

Seed set to 42


In [4]:
#%% -------- 2. Medical Data Pipeline --------
class MedicalAugmentation:
    """MRI-optimized augmentation pipeline"""
    def __init__(self):
        self.affine = transforms.RandomAffine(
            degrees=(-7, 7), 
            translate=(0.05, 0.05),
            scale=(0.9, 1.1)
        )
        self.elastic = transforms.ElasticTransform(
            alpha=50.0,
            sigma=5.0
        )
        self.color_jitter = transforms.ColorJitter(
            brightness=0.1,
            contrast=0.2,
            saturation=0.1
        )
        
    def __call__(self, img):
        img = self.affine(img)
        if np.random.rand() > 0.7:
            img = self.elastic(img)
        img = self.color_jitter(img)
        return img

train_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    MedicalAugmentation(),
    transforms.RandomErasing(p=0.4, scale=(0.02, 0.15)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

val_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Dataset preparation
if not os.path.exists(os.path.join(SPLIT_ROOT, 'train')):
    splitfolders.ratio(
        DATA_ROOT,
        output=SPLIT_ROOT,
        seed=SEED,
        ratio=(0.8, 0.1, 0.1),
        group_prefix=None,
        move=False,
        oversample='auto'
    )

train_dataset = ImageFolder(os.path.join(SPLIT_ROOT, 'train'), train_transform)
val_dataset = ImageFolder(os.path.join(SPLIT_ROOT, 'val'), val_transform)
test_dataset = ImageFolder(os.path.join(SPLIT_ROOT, 'test'), val_transform)

# Class balancing
class_counts = np.bincount(train_dataset.targets)
class_weights = torch.tensor([1.0 / (count + 1e-5) for count in class_counts], dtype=torch.float32)
sampler = WeightedRandomSampler(weights=class_weights[train_dataset.targets], num_samples=len(train_dataset), replacement=True)

In [8]:
#%% -------- 3. Enhanced ResMT Architecture --------
class MedicalResMT(nn.Module):
    def __init__(self, num_classes=4):
        super().__init__()
        
        # 1. Hybrid Backbone
        self.cnn = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
        self.cnn = nn.Sequential(*list(self.cnn.children())[:-3])  # Until layer3
        
        # 2. Feature Adaptation
        self.adapt = nn.Sequential(
            nn.Conv2d(1024, 512, 3, padding=1),
            nn.BatchNorm2d(512),
            nn.GELU(),
            nn.AdaptiveAvgPool2d((14, 14)))
        
        # 3. Positional Encoding
        self.pos_embed = nn.Parameter(torch.randn(1, 196, 512) * 0.02)
        
        # 4. Transformer Encoder
        self.transformer = nn.Sequential(
            TransformerEncoderLayer(
                d_model=512,
                nhead=8,
                dim_feedforward=2048,
                dropout=0.1,
                batch_first=True
            ),
            TransformerEncoderLayer(
                d_model=512,
                nhead=8,
                dim_feedforward=2048,
                dropout=0.1,
                batch_first=True
            )
        )
        
        # 5. Attention Gate
        self.attention = nn.Sequential(
            nn.Linear(512, 256),
            nn.GELU(),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
        
        # 6. Classifier Head
        self.head = nn.Sequential(
            nn.LayerNorm(512),
            nn.Linear(512, 256),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes))
        
        self._init_weights()

    def _init_weights(self):
        # Transformer initialization
        for m in self.transformer.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
                    
        # Attention initialization
        nn.init.kaiming_normal_(self.attention[0].weight)
        nn.init.kaiming_normal_(self.attention[2].weight)

    def forward(self, x):
        # CNN Features
        x = self.cnn(x)  # [B, 1024, 28, 28]
        x = self.adapt(x)  # [B, 512, 14, 14]
        
        # Prepare for Transformer
        B, C, H, W = x.shape
        x = x.view(B, C, -1).permute(0, 2, 1)  # [B, 196, 512]
        x = x + self.pos_embed
        
        # Transformer Processing
        x = self.transformer(x)
        
        # Attention-based Pooling
        attn_weights = self.attention(x)  # [B, 196, 1]
        x = torch.sum(x * attn_weights, dim=1)  # Weighted sum
        
        return self.head(x)

In [9]:
#%% -------- 4. Optimized Training Setup --------
class FocalLoss(nn.Module):
    def __init__(self, alpha=None, gamma=2.0):
        super().__init__()
        self.register_buffer('alpha', alpha)  # Store as buffer
        self.gamma = gamma

    def forward(self, inputs, targets):
        if self.alpha is not None:
            # Ensure alpha is on same device as inputs
            alpha = self.alpha.to(inputs.device)
        else:
            alpha = None
            
        ce_loss = nn.functional.cross_entropy(
            inputs, 
            targets, 
            reduction='none', 
            weight=alpha
        )
        pt = torch.exp(-ce_loss)
        loss = (1 - pt) ** self.gamma * ce_loss
        return loss.mean()
    

class LitMedicalModel(pl.LightningModule):
    def __init__(self, total_steps):
        super().__init__()
        self.save_hyperparameters()
        self.model = MedicalResMT(NUM_CLASSES)
        
        # Move alpha to device using register_buffer
        self.register_buffer(
            'class_weights',
            torch.tensor([1.0, 2.5, 3.0, 4.0], dtype=torch.float32)
        )
        
        self.criterion = FocalLoss(
            alpha=self.class_weights,  # Use registered buffer
            gamma=2.5
        )
        
        # Metrics
        self.train_acc = torchmetrics.Accuracy(task='multiclass', num_classes=NUM_CLASSES)
        self.val_auc = torchmetrics.AUROC(task='multiclass', num_classes=NUM_CLASSES)
        self.val_f1 = torchmetrics.F1Score(task='multiclass', num_classes=NUM_CLASSES)
        
        # Training state
        self.total_steps = total_steps

    def configure_optimizers(self):
        # Layer-wise optimization
        optimizer = optim.AdamW([
            {'params': self.model.cnn.parameters(), 'lr': LR/10},
            {'params': self.model.adapt.parameters(), 'lr': LR/5},
            {'params': self.model.transformer.parameters(), 'lr': LR},
            {'params': self.model.head.parameters(), 'lr': LR}
        ], weight_decay=0.05)
        
        # OneCycle schedule
        scheduler = OneCycleLR(
            optimizer,
            max_lr=[LR/10, LR/5, LR, LR],
            total_steps=self.total_steps,
            pct_start=WARMUP_PCT,
            anneal_strategy='cos'
        )
        return [optimizer], [{'scheduler': scheduler, 'interval': 'step'}]

    @autocast()  
    def training_step(self, batch, batch_idx):
        # Medical Mixup
        x, y = batch
        lam = np.random.beta(0.2, 0.2)
        idx = torch.randperm(x.size(0))
        
        mixed_x = lam * x + (1 - lam) * x[idx]
        logits = self.model(mixed_x)
        
        loss = lam * self.criterion(logits, y) + (1 - lam) * self.criterion(logits, y[idx])
        
        # Metrics
        self.train_acc(logits.softmax(dim=1), y)
        self.log('train_loss', loss, prog_bar=True)
        self.log('train_acc', self.train_acc, prog_bar=True)
        return loss

    @autocast()
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self.model(x)
        loss = self.criterion(logits, y)
        
        self.val_auc(logits.softmax(dim=1), y)
        self.val_f1(logits.softmax(dim=1), y)
        self.log_dict({
            'val_loss': loss,
            'val_auc': self.val_auc,
            'val_f1': self.val_f1
        }, prog_bar=True)
        return loss

  @autocast()
  @autocast()


In [10]:
#%% -------- 5. Training Execution --------
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    sampler=sampler,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    persistent_workers=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    persistent_workers=True
)

model = LitMedicalModel(total_steps=EPOCHS * len(train_loader))

# Callbacks
checkpoint = pl.callbacks.ModelCheckpoint(
    monitor='val_auc',
    mode='max',
    save_top_k=3,
    filename='best-{epoch}-{val_auc:.4f}'
)

early_stop = pl.callbacks.EarlyStopping(
    monitor='val_auc',
    patience=25,
    mode='max',
    min_delta=0.005
)

trainer = pl.Trainer(
    accelerator='gpu',
    devices=1,
    max_epochs=EPOCHS,
    precision=PRECISION,
    callbacks=[checkpoint, early_stop],
    gradient_clip_val=0.5,
    accumulate_grad_batches=2,
    deterministic=True
)

trainer.fit(model, train_loader, val_loader)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type               | Params
-------------------------------------------------
0 | model     | MedicalResMT       | 19.9 M
1 | criterion | FocalLoss          | 0     
2 | train_acc | MulticlassAccuracy | 0     
3 | val_auc   | MulticlassAUROC    | 0     
4 | val_f1    | MulticlassF1Score  | 0     
-------------------------------------------------
19.9 M    Trainable params
0         Non-trainable params
19.9 M    Total params
79.734    Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

RuntimeError: cumsum_cuda_kernel does not have a deterministic implementation, but you set 'torch.use_deterministic_algorithms(True)'. You can turn off determinism just for this operation, or you can use the 'warn_only=True' option, if that's acceptable for your application. You can also file an issue at https://github.com/pytorch/pytorch/issues to help us prioritize adding deterministic support for this operation.

In [None]:
#%% -------- 6. Advanced Evaluation --------
from sklearn.metrics import auc, roc_curve


class MedicalEvaluator:
    def __init__(self, model):
        self.model = model
        self.model.eval()
        self.tta_transforms = [
            lambda x: x,
            lambda x: torch.flip(x, [-1]),
            lambda x: torch.rot90(x, 1, [-2, -1]),
            lambda x: torch.rot90(x, -1, [-2, -1])
        ]
        
    def predict(self, x):
        with torch.no_grad(), autocast():
            logits = torch.zeros((x.size(0), NUM_CLASSES), device=x.device)
            for transform in self.tta_transforms:
                augmented = transform(x)
                logits += self.model(augmented)
            return logits / len(self.tta_transforms)
    
    def evaluate(self, loader):
        all_preds = []
        all_targets = []
        all_probs = []
        
        for x, y in loader:
            x = x.to(self.model.device)
            logits = self.predict(x)
            
            probs = logits.softmax(1)
            preds = probs.argmax(1)
            
            all_probs.append(probs.cpu())
            all_preds.append(preds.cpu())
            all_targets.append(y)
            
        return (
            torch.cat(all_preds),
            torch.cat(all_targets),
            torch.cat(all_probs)
        )

# Load best model
best_model = LitMedicalModel.load_from_checkpoint(
    trainer.checkpoint_callback.best_model_path
)
evaluator = MedicalEvaluator(best_model)

# Test evaluation
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS
)

preds, targets, probs = evaluator.evaluate(test_loader)

# Classification report
print("Medical Diagnostic Report:")
print(classification_report(
    targets.numpy(),
    preds.numpy(),
    target_names=CLASS_NAMES,
    digits=4
))

# Confusion matrix
cm = confusion_matrix(targets.numpy(), preds.numpy())
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES)
plt.title('Medical Diagnostic Confusion Matrix')
plt.show()

# ROC curves
fpr, tpr, _ = roc_curve(targets.numpy(), probs.numpy()[:, 1], pos_label=1)
roc_auc = auc(fpr, tpr)
plt.figure()
plt.plot(fpr, tpr, label=f'AUC = {roc_auc:.2f}')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve')
plt.legend()
plt.show()

In [None]:
#%% -------- 7. Visualization Utilities --------
def visualize_attention(model, img_tensor):
    model.eval()
    with torch.no_grad():
        features = model.model.cnn(img_tensor.unsqueeze(0))
        features = model.model.adapt(features)
        B, C, H, W = features.shape
        spatial = features.view(B, C, -1).permute(0, 2, 1)
        spatial += model.model.pos_embed
        attn_weights = model.model.attention(spatial).squeeze()
        
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.imshow(img_tensor.permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5)
    plt.title("Original Image")
    plt.axis('off')
    
    plt.subplot(1, 2, 2)
    plt.imshow(attn_weights.view(H, W).cpu().numpy(), cmap='jet')
    plt.title("Attention Map")
    plt.axis('off')
    plt.show()

# Example usage
sample, _ = test_dataset[0]
visualize_attention(best_model, sample)