In [1]:
#%% -------- 1. Configuration & Imports --------
import lightning.pytorch as pl
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, models
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision.datasets import ImageFolder
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report
import os
import splitfolders
from torch.optim.lr_scheduler import OneCycleLR
import torchmetrics
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.cuda.amp import autocast

# Configuration
NUM_CLASSES = 4
IMG_SIZE = 384
BATCH_SIZE = 32
NUM_WORKERS = 4
PRECISION = '16-mixed'
EPOCHS = 150
LR = 1e-4
BATCH_SIZE = 32
DATA_ROOT = "C:\\Users\\DELL 5540\\Desktop\\Brachial\\2nd Classification"
SPLIT_ROOT = "C:\\Users\\DELL 5540\\Desktop\\Brachial\\Split_Dataset"
SEED = 42
CLASS_NAMES = ['Type 0', 'Type 1 Neurapraxia', 'Type 2 Axonotemsis', 'Type 3 Neurotmesis']

# Reproducibility
pl.seed_everything(SEED)
torch.use_deterministic_algorithms(False)  # Add this line
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

Seed set to 42


In [2]:
#%% -------- 2. Enhanced Data Preparation Cell --------
def prepare_dataset():
    splitfolders.ratio(
        DATA_ROOT,
        output=SPLIT_ROOT,
        seed=SEED,
        ratio=(0.7, 0.15, 0.15),
        group_prefix=None,
        move=False
    )

if not os.path.exists(os.path.join(SPLIT_ROOT, 'train')):
    prepare_dataset()

# Enhanced MRI-specific transforms
train_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize(IMG_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
    transforms.GaussianBlur(kernel_size=3),
    transforms.RandomAdjustSharpness(sharpness_factor=2),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)  # Fixed normalization for 3 channels
])

val_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])

# Load datasets
train_dataset = ImageFolder(os.path.join(SPLIT_ROOT, 'train'), train_transform)
val_dataset = ImageFolder(os.path.join(SPLIT_ROOT, 'val'), val_transform)
test_dataset = ImageFolder(os.path.join(SPLIT_ROOT, 'test'), val_transform)

# Class balancing with pos_weight
class_counts = np.bincount(train_dataset.targets)
class_weights = 1. / torch.tensor(class_counts, dtype=torch.float32)  # Add this line
pos_weight = torch.tensor([class_counts[0]/class_counts[1]], dtype=torch.float32)
samples_weights = class_weights[train_dataset.targets]
sampler = WeightedRandomSampler(samples_weights, len(samples_weights), replacement=True)

In [3]:
#%% -------- 2. Medical Data Pipeline --------
class MedicalAugmentation:
    """MRI-optimized augmentation pipeline"""
    def __init__(self):
        self.affine = transforms.RandomAffine(
            degrees=(-7, 7), 
            translate=(0.05, 0.05),
            scale=(0.9, 1.1)
        )
        self.elastic = transforms.ElasticTransform(
            alpha=50.0,
            sigma=5.0
        )
        self.color_jitter = transforms.ColorJitter(
            brightness=0.1,
            contrast=0.2,
            saturation=0.1
        )
        
    def __call__(self, img):
        img = self.affine(img)
        if np.random.rand() > 0.7:
            img = self.elastic(img)
        img = self.color_jitter(img)
        return img

train_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    MedicalAugmentation(),
    transforms.RandomErasing(p=0.4, scale=(0.02, 0.15)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

val_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Dataset preparation
if not os.path.exists(os.path.join(SPLIT_ROOT, 'train')):
    splitfolders.ratio(
        DATA_ROOT,
        output=SPLIT_ROOT,
        seed=SEED,
        ratio=(0.8, 0.1, 0.1),
        group_prefix=None,
        move=False,
        oversample='auto'
    )

train_dataset = ImageFolder(os.path.join(SPLIT_ROOT, 'train'), train_transform)
val_dataset = ImageFolder(os.path.join(SPLIT_ROOT, 'val'), val_transform)
test_dataset = ImageFolder(os.path.join(SPLIT_ROOT, 'test'), val_transform)

# Class balancing
class_counts = np.bincount(train_dataset.targets)
class_weights = torch.tensor([len(train_dataset) / (count * NUM_CLASSES) for count in class_counts], dtype=torch.float32)
sampler = WeightedRandomSampler(weights=class_weights[train_dataset.targets], num_samples=len(train_dataset), replacement=True)

In [7]:
#%% -------- 4. Enhanced Model Definition Cell --------
class FocalLoss(nn.Module):
    def __init__(self, alpha=None, gamma=2):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        ce_loss = nn.functional.cross_entropy(inputs, targets, reduction='none', weight=self.alpha)
        pt = torch.exp(-ce_loss)
        loss = (1-pt)**self.gamma * ce_loss
        return loss.mean()

class BrachialPlexusResMT(nn.Module):
    def __init__(self, num_classes=4, img_size=224, patch_size=32, embed_dim=512, 
                 num_heads=8, num_layers=3, dropout=0.2):
        super().__init__()
        
        # ResNet-50 Backbone
        self.cnn_backbone = models.resnet50(pretrained=True)
        self.cnn_backbone = nn.Sequential(*list(self.cnn_backbone.children())[:-2])
        
        # Freeze initial layers
        for param in self.cnn_backbone.parameters():
            param.requires_grad = False
            
        # Unfreeze last residual block
        for param in self.cnn_backbone[-1].parameters():
            param.requires_grad = True

        # Feature adaptation
        self.feature_adapt = nn.Sequential(
            nn.Conv2d(2048, embed_dim, 1),
            nn.BatchNorm2d(embed_dim),
            nn.GELU()
        )
        
        # Positional Encoding
        self.positional_encoding = nn.Parameter(torch.randn(49, 1, embed_dim))
        
        # Transformer Encoder
        encoder_layers = TransformerEncoderLayer(
            embed_dim, num_heads, dim_feedforward=2048, dropout=dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, num_layers)
        
        # Classifier
        self.classifier = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Dropout(0.5),
            nn.Linear(embed_dim, num_classes)  # Changed to 4 outputs
        )

    def forward(self, x):
        x = self.cnn_backbone(x)
        x = self.feature_adapt(x)
        x = x.flatten(2).permute(2, 0, 1)
        x = x + self.positional_encoding
        x = self.transformer_encoder(x)
        x = x.mean(dim=0)
        return self.classifier(x)

In [8]:
#%% -------- 5. Enhanced Training Setup Cell --------
class LitModel(pl.LightningModule):
    def __init__(self, total_steps):
        super().__init__()
        self.save_hyperparameters()
        self.model = BrachialPlexusResMT(num_classes=NUM_CLASSES)
        
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.criterion = FocalLoss(alpha=class_weights.to(device))

        # Metrics
        self.train_acc = torchmetrics.Accuracy(task='multiclass', num_classes=NUM_CLASSES)
        self.train_auc = torchmetrics.AUROC(task='multiclass', num_classes=NUM_CLASSES)
        self.val_acc = torchmetrics.Accuracy(task='multiclass', num_classes=NUM_CLASSES)
        self.val_auc = torchmetrics.AUROC(task='multiclass', num_classes=NUM_CLASSES)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        
        preds = torch.argmax(logits, dim=1)
        self.train_acc(preds, y)
        self.train_auc(logits.softmax(dim=1), y)
        
        self.log_dict({'train_loss': loss, 'train_acc': self.train_acc, 'train_auc': self.train_auc},
                     prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        
        preds = torch.argmax(logits, dim=1)
        self.val_acc(preds, y)
        self.val_auc(logits.softmax(dim=1), y)
        
        self.log_dict({'val_loss': loss, 'val_acc': self.val_acc, 'val_auc': self.val_auc},
                     prog_bar=True)
        return loss

    def configure_optimizers(self):
        # Differential learning rates
        backbone_params = []
        new_params = []
        for name, param in self.named_parameters():
            if 'cnn_backbone.layer4' in name:
                backbone_params.append(param)
            else:
                new_params.append(param)
        
        optimizer = optim.AdamW(
            [
                {'params': backbone_params, 'lr': LR/10},
                {'params': new_params, 'lr': LR}
            ],
            weight_decay=0.01
        )
        
        scheduler = OneCycleLR(
            optimizer,
            max_lr=[LR/10, LR],
            total_steps=self.hparams.total_steps
        )
        return [optimizer], [scheduler]

# Initialize
train_loader = DataLoader(train_dataset, BATCH_SIZE, sampler=sampler, 
                         num_workers=NUM_WORKERS, pin_memory=True)
val_loader = DataLoader(val_dataset, BATCH_SIZE, num_workers=NUM_WORKERS,
                       pin_memory=True)
model = LitModel(total_steps=EPOCHS*len(train_loader))

# Callbacks
checkpoint = pl.callbacks.ModelCheckpoint(
    monitor='val_auc',
    mode='max',
    filename='best-{epoch}-{val_auc:.3f}',
    save_top_k=2
)

early_stop = pl.callbacks.EarlyStopping(
    monitor='val_auc',
    patience=25,  # Increased patience
    mode='max',
    min_delta=0.001
)

trainer = pl.Trainer(
    accelerator='gpu',
    devices=1,
    max_epochs=EPOCHS,
    precision=PRECISION,
    callbacks=[checkpoint, early_stop],
    accumulate_grad_batches=4,
    gradient_clip_val=1.0,
    log_every_n_steps=10,
    deterministic="warn"
)

Using 16bit Automatic Mixed Precision (AMP)
C:\Users\DELL 5540\AppData\Roaming\Python\Python311\site-packages\lightning\pytorch\plugins\precision\amp.py:54: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [9]:
#%% -------- 6. Training Execution Cell --------
trainer.fit(model, train_loader, val_loader)
best_model = LitModel.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type                | Params
--------------------------------------------------
0 | model     | BrachialPlexusResMT | 34.0 M
1 | criterion | FocalLoss           | 0     
2 | train_acc | MulticlassAccuracy  | 0     
3 | train_auc | MulticlassAUROC     | 0     
4 | val_acc   | MulticlassAccuracy  | 0     
5 | val_auc   | MulticlassAUROC     | 0     
--------------------------------------------------
25.5 M    Trainable params
8.5 M     Non-trainable params
34.0 M    Total params
136.174   Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

C:\Users\DELL 5540\AppData\Roaming\Python\Python311\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:436: Consider setting `persistent_workers=True` in 'val_dataloader' to speed up the dataloader worker initialization.


RuntimeError: The size of tensor a (144) must match the size of tensor b (49) at non-singleton dimension 0

In [None]:
#%% -------- 6. Advanced Evaluation --------
from sklearn.metrics import auc, roc_curve


class MedicalEvaluator:
    def __init__(self, model):
        self.model = model
        self.model.eval()
        self.tta_transforms = [
            lambda x: x,
            lambda x: torch.flip(x, [-1]),
            lambda x: torch.rot90(x, 1, [-2, -1]),
            lambda x: torch.rot90(x, -1, [-2, -1])
        ]
        
    def predict(self, x):
        with torch.no_grad(), autocast():
            logits = torch.zeros((x.size(0), NUM_CLASSES), device=x.device)
            for transform in self.tta_transforms:
                augmented = transform(x)
                logits += self.model(augmented)
            return logits / len(self.tta_transforms)
    
    def evaluate(self, loader):
        all_preds = []
        all_targets = []
        all_probs = []
        
        for x, y in loader:
            x = x.to(self.model.device)
            logits = self.predict(x)
            
            probs = logits.softmax(1)
            preds = probs.argmax(1)
            
            all_probs.append(probs.cpu())
            all_preds.append(preds.cpu())
            all_targets.append(y)
            
        return (
            torch.cat(all_preds),
            torch.cat(all_targets),
            torch.cat(all_probs)
        )

# Load best model
best_model = LitMedicalModel.load_from_checkpoint(
    trainer.checkpoint_callback.best_model_path
)
evaluator = MedicalEvaluator(best_model)

# Test evaluation
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS
)

preds, targets, probs = evaluator.evaluate(test_loader)

# Classification report
print("Medical Diagnostic Report:")
print(classification_report(
    targets.numpy(),
    preds.numpy(),
    target_names=CLASS_NAMES,
    digits=4
))

# Confusion matrix
cm = confusion_matrix(targets.numpy(), preds.numpy())
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES)
plt.title('Medical Diagnostic Confusion Matrix')
plt.show()

# ROC curves
fpr, tpr, _ = roc_curve(targets.numpy(), probs.numpy()[:, 1], pos_label=1)
roc_auc = auc(fpr, tpr)
plt.figure()
plt.plot(fpr, tpr, label=f'AUC = {roc_auc:.2f}')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve')
plt.legend()
plt.show()

In [None]:
#%% -------- 7. Visualization Utilities --------
def visualize_attention(model, img_tensor):
    model.eval()
    with torch.no_grad():
        features = model.model.cnn(img_tensor.unsqueeze(0))
        features = model.model.adapt(features)
        B, C, H, W = features.shape
        spatial = features.view(B, C, -1).permute(0, 2, 1)
        spatial += model.model.pos_embed
        attn_weights = model.model.attention(spatial).squeeze()
        
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.imshow(img_tensor.permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5)
    plt.title("Original Image")
    plt.axis('off')
    
    plt.subplot(1, 2, 2)
    plt.imshow(attn_weights.view(H, W).cpu().numpy(), cmap='jet')
    plt.title("Attention Map")
    plt.axis('off')
    plt.show()

# Example usage
sample, _ = test_dataset[0]
visualize_attention(best_model, sample)