# In [None]:  (notebook cell marker, kept as a comment so the file parses as Python)
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    classification_report, 
    confusion_matrix, 
    precision_recall_fscore_support,
    roc_curve, 
    auc
)
import pylidc as pl
from skimage.transform import resize
import matplotlib.pyplot as plt
import seaborn as sns
import copy

# Optional: Use Albumentations for advanced augmentation
USE_ALBUMENTATIONS = True
if USE_ALBUMENTATIONS:
    import albumentations as A
    from albumentations.pytorch import ToTensorV2

# -----------------------------
# Data Loading and Preprocessing
# -----------------------------
def load_nodule_data(
    max_scans=400,
    print_counts=True, 
    min_nodule_size=10, 
    max_nodule_size=50,
    seed=None,
):
    """Extract one normalised 2-D patch per annotated nodule from pylidc scans.

    For each scan (the first ``max_scans`` returned by ``pl.query(pl.Scan)``),
    every cluster from ``scan.cluster_annotations()`` is treated as one nodule.
    The nodule is labelled malignant (1) when the mean of its non-missing
    malignancy ratings is >= 3, otherwise benign (0). A cube with a random edge
    length is cropped around the mean annotation centroid. Its central slice
    along the third volume axis is resized to 32x32 and z-score normalised.

    Args:
        max_scans: Maximum number of scans to process.
        print_counts: Print processing statistics (and the first extraction
            error, if any) when done.
        min_nodule_size: Inclusive lower bound of the random crop edge (voxels).
        max_nodule_size: Exclusive upper bound of the random crop edge. If it is
            not greater than ``min_nodule_size``, ``min_nodule_size`` is used.
        seed: Optional seed for the crop-size RNG. ``None`` draws from NumPy's
            global RNG, as before.

    Returns:
        ``(images, labels)``: a float array of shape ``(N, 32, 32)`` and an int
        array of shape ``(N,)``. Both are empty (N == 0) when nothing is kept.
    """
    # None -> NumPy's global RNG (original behaviour); otherwise a private,
    # reproducible stream. Both expose ``randint(low, high)``.
    rng = np.random if seed is None else np.random.RandomState(seed)
    scans = pl.query(pl.Scan).all()
    nodule_images, labels = [], []
    processed_scan_ids = set()
    first_error = None
    stats = {
        'total_scans': 0,
        'total_nodules': 0,
        'benign_nodules': 0,
        'malignant_nodules': 0,
        'discarded_nodules': 0,
        'failed_nodules': 0,
    }

    for scan in scans[:max_scans]:
        # Defensive de-duplication: process each scan id at most once.
        if scan.id in processed_scan_ids:
            continue
        processed_scan_ids.add(scan.id)
        stats['total_scans'] += 1
        volume = scan.to_volume()
        for nodule in scan.cluster_annotations():
            stats['total_nodules'] += 1
            try:
                # Average the ratings of all annotations in the cluster; a mean
                # of exactly 3 is counted as malignant.
                malignancy_scores = [anno.malignancy for anno in nodule if anno.malignancy is not None]
                if not malignancy_scores:
                    stats['discarded_nodules'] += 1
                    continue
                label = 1 if np.mean(malignancy_scores) >= 3 else 0
                coords = np.mean([anno.centroid for anno in nodule], axis=0)
                if np.any(np.isnan(coords)):
                    stats['discarded_nodules'] += 1
                    continue
                x, y, z = np.round(coords).astype(int)
                # randint's upper bound is exclusive and must exceed the lower
                # one; otherwise it raises and every nodule would be dropped.
                if max_nodule_size > min_nodule_size:
                    size = rng.randint(min_nodule_size, max_nodule_size)
                else:
                    size = min_nodule_size
                half = size // 2
                # Clamp both ends of each axis to [0, axis_len] so a negative stop
                # can never wrap around to the far end of the volume.
                window = tuple(
                    slice(max(0, centre - half), max(0, min(length, centre + half)))
                    for centre, length in zip((x, y, z), volume.shape)
                )
                nodule_patch = volume[window]
                # Slicing a 3-D volume always returns a 3-D array; a centroid at
                # or beyond an edge shows up as an axis of length 0 instead.
                if 0 in nodule_patch.shape:
                    stats['discarded_nodules'] += 1
                    continue
                central_slice = nodule_patch[:, :, nodule_patch.shape[2] // 2]
                resized_slice = resize(central_slice, (32, 32), mode='constant', anti_aliasing=True)
                # Normalize robustly (results in floats, not 0-255)
                resized_slice = (resized_slice - np.mean(resized_slice)) / (np.std(resized_slice) + 1e-8)
            except Exception as exc:
                # Best-effort extraction: one malformed annotation must not abort
                # the whole run, but keep a record instead of failing silently.
                stats['discarded_nodules'] += 1
                stats['failed_nodules'] += 1
                if first_error is None:
                    first_error = f"scan {scan.id}: {type(exc).__name__}: {exc}"
                continue
            nodule_images.append(resized_slice)
            labels.append(label)
            if label == 1:
                stats['malignant_nodules'] += 1
            else:
                stats['benign_nodules'] += 1

    # Keep a consistent (N, 32, 32) shape even when no nodule survived.
    images = np.array(nodule_images) if nodule_images else np.empty((0, 32, 32))
    labels = np.array(labels, dtype=int)
    if print_counts:
        print("\nNodule Processing Statistics:")
        for key, value in stats.items():
            print(f"{key.replace('_', ' ').title()}: {value}")
        if first_error is not None:
            print(f"First extraction error ({stats['failed_nodules']} total) - {first_error}")
    return images, labels

# -----------------------------
# Dataset with Enhanced Data Augmentation
# -----------------------------
class LungNoduleDataset(Dataset):
    """Torch dataset over pre-extracted 2-D nodule patches.

    Each stored image (a 2-D array, e.g. the z-scored floats produced by
    ``load_nodule_data``) is min-max rescaled to uint8. It is then passed
    through an Albumentations or torchvision pipeline, chosen by the
    module-level ``USE_ALBUMENTATIONS`` flag. Both pipelines end with
    Normalize(mean=0.5, std=0.5), so returned tensors have shape
    ``(1, H, W)`` and values in [-1, 1].

    Args:
        images: Indexable collection of 2-D images.
        labels: Integer class labels aligned with ``images``.
        mode: ``'train'`` enables augmentation; any other value selects the
            deterministic evaluation pipeline.
    """
    def __init__(self, images, labels, mode='train'):
        self.images = images
        self.labels = labels
        self.mode = mode
        if USE_ALBUMENTATIONS:
            if self.mode == 'train':
                # NOTE(review): ``GaussNoise(var_limit=...)`` and
                # ``ElasticTransform(alpha_affine=...)`` are deprecated/removed in
                # newer Albumentations releases -- confirm against the installed version.
                self.transform = A.Compose([
                    A.RandomRotate90(p=0.5),
                    A.HorizontalFlip(p=0.5),
                    A.Transpose(p=0.5),
                    A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.1, rotate_limit=15, p=0.5),
                    A.GaussNoise(var_limit=(10.0, 50.0), p=0.3),
                    A.RandomBrightnessContrast(p=0.3),
                    A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=0.2),
                    A.Normalize(mean=(0.5,), std=(0.5,)),
                    ToTensorV2()
                ])
            else:
                self.transform = A.Compose([
                    A.Resize(32, 32),
                    A.Normalize(mean=(0.5,), std=(0.5,)),
                    ToTensorV2()
                ])
        else:
            if self.mode == 'train':
                self.transform = transforms.Compose([
                    transforms.ToPILImage(),
                    transforms.RandomRotation(degrees=15),
                    transforms.RandomHorizontalFlip(p=0.5),
                    transforms.RandomVerticalFlip(p=0.5),
                    transforms.RandomAffine(degrees=10, shear=5, scale=(0.9, 1.1)),
                    transforms.RandomCrop(32, padding=4),
                    transforms.RandomApply([transforms.GaussianBlur(kernel_size=3)], p=0.2),
                    transforms.Grayscale(num_output_channels=1),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.5], std=[0.5]),
                ])
            else:
                self.transform = transforms.Compose([
                    transforms.ToPILImage(),
                    transforms.Grayscale(num_output_channels=1),
                    transforms.Resize((32, 32)),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.5], std=[0.5]),
                ])

    @staticmethod
    def _to_uint8(img):
        """Min-max rescale an image to the full 0-255 range as uint8.

        A constant image maps to all zeros; the 1e-8 term avoids 0/0.
        """
        img = np.asarray(img)
        lo, hi = np.min(img), np.max(img)
        scaled = (img - lo) / (hi - lo + 1e-8)
        return np.clip(scaled * 255, 0, 255).astype(np.uint8)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Both pipelines consume uint8 pixels, but the stored images may be
        # z-scored floats with negative values. Casting those straight to uint8
        # would wrap/truncate them, so rescale first in BOTH paths.
        img_uint8 = self._to_uint8(self.images[idx])[..., np.newaxis]  # (H, W, 1)
        label = self.labels[idx]
        if USE_ALBUMENTATIONS:
            # Albumentations expects image shape (H, W, C) even for grayscale.
            img = self.transform(image=img_uint8)['image']
        else:
            # ToPILImage turns an (H, W, 1) uint8 array into a grayscale image.
            img = self.transform(img_uint8)
        return img, torch.tensor(label, dtype=torch.long)

# -----------------------------
# Advanced Capsule Network with Decoder (Reconstruction) Branch
# -----------------------------
class CapsNetWithDecoder(nn.Module):
    """Capsule-style classifier with an auxiliary reconstruction decoder.

    Pipeline (see ``forward``): a 9x9 convolution, a strided 9x9
    "primary capsule" convolution, dynamic routing over the resulting capsule
    vectors, and then two heads that both read the flattened routed capsules:

    * ``classification_head``: an MLP producing ``num_classes`` raw scores
      (no softmax or capsule-length step is applied in this class);
    * ``decoder``: an MLP ending in ``Sigmoid`` that produces a 1x32x32
      reconstruction with values in (0, 1).

    Args:
        num_capsules: Number of routed capsules; the primary-capsule conv emits
            ``num_capsules * capsule_dim`` channels.
        capsule_dim: Length of each capsule vector.
        num_classes: Output width of the classification head.
        reconstruction_weight: Stored on the module for the training code to
            weight the reconstruction loss; not used inside this class.
        routing_iters: Number of dynamic-routing iterations (must be >= 1).
    """
    def __init__(self, num_capsules, capsule_dim, num_classes, reconstruction_weight=0.0005, routing_iters=3):
        super(CapsNetWithDecoder, self).__init__()
        self.reconstruction_weight = reconstruction_weight
        self.routing_iters = routing_iters

        # 1 input channel -> 256 feature maps. A 9x9 kernel with stride 1 and no
        # padding trims 8 pixels per spatial side (32 -> 24 for a 32x32 input).
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 256, kernel_size=9, stride=1, padding=0),
            nn.BatchNorm2d(256),
            nn.PReLU(),
            nn.Dropout2d(0.3)
        )
        # Primary capsules: num_capsules * capsule_dim channels, reshaped into
        # capsule vectors in forward(). 9x9, stride 2, no padding maps 24 -> 8
        # per side, i.e. 64 spatial positions for a 32x32 input.
        self.primary_capsules = nn.Sequential(
            nn.Conv2d(256, num_capsules * capsule_dim, kernel_size=9, stride=2, padding=0),
            nn.BatchNorm2d(num_capsules * capsule_dim),
            nn.PReLU(),
            nn.Dropout2d(0.3)
        )
        self.num_capsules = num_capsules
        self.capsule_dim = capsule_dim
        self.num_classes = num_classes

        # Classifier MLP over the flattened routed capsules; outputs raw,
        # unnormalised class scores.
        self.classification_head = nn.Sequential(
            nn.Linear(num_capsules * capsule_dim, 256),
            nn.BatchNorm1d(256),
            nn.PReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.PReLU(),
            nn.Dropout(0.4),
            nn.Linear(128, num_classes)
        )

        # Reconstruction MLP over the same flattened capsules. The final Sigmoid
        # keeps every pixel in (0, 1); forward() reshapes the output to 1x32x32.
        # NOTE(review): the reconstruction target used by the caller should lie
        # in the same (0, 1) range.
        self.decoder = nn.Sequential(
            nn.Linear(num_capsules * capsule_dim, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 32 * 32),
            nn.Sigmoid()
        )

        # Recursively initialise every submodule (see _init_weights).
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise a single submodule; called for every module via ``self.apply``.

        Conv2d/Linear: Kaiming-normal weights (fan_out, ReLU gain) and zero bias.
        BatchNorm1d/BatchNorm2d: weight 1, bias 0. Other modules (e.g. PReLU)
        keep their default initialisation.
        """
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    def squash(self, x: torch.Tensor) -> torch.Tensor:
        """Capsule squashing non-linearity along the last dimension.

        Computes ``|x|^2 / (1 + |x|^2) * x / |x|``. The direction is preserved
        and the length is mapped into [0, 1). The ``1e-8`` terms avoid division
        by zero for all-zero vectors.
        """
        squared_norm = (x ** 2).sum(dim=-1, keepdim=True)
        scale = squared_norm / (1 + squared_norm + 1e-8)
        return scale * x / (torch.sqrt(squared_norm) + 1e-8)

    def dynamic_routing(self, u_hat: torch.Tensor) -> torch.Tensor:
        """Routing-by-agreement over prediction vectors.

        Args:
            u_hat: Tensor of shape (batch, input_capsules, output_capsules, capsule_dim).

        Returns:
            v_j of shape (batch, output_capsules, capsule_dim).

        The routing logits ``b_ij`` start at zero. Each iteration does the following:
        a softmax over the output capsules turns the logits into coupling
        coefficients; the coefficients weight a sum ``s_j`` over input capsules;
        ``s_j`` is squashed into ``v_j``; and, except on the last iteration, the
        agreement ``<u_hat_ij, v_j>`` is added to ``b_ij``.
        Requires ``self.routing_iters >= 1``. With zero iterations ``v_j`` is
        never assigned and the return raises ``UnboundLocalError``.
        """
        batch_size, input_capsules, output_capsules, capsule_dim = u_hat.shape
        # One routing logit per (input capsule, output capsule) pair.
        b_ij = torch.zeros(batch_size, input_capsules, output_capsules, device=u_hat.device)
        for iteration in range(self.routing_iters):
            # Each input capsule distributes its vote over the output capsules.
            c_ij = F.softmax(b_ij, dim=2)
            # Weighted vote per output capsule: (batch, output_capsules, capsule_dim).
            s_j = torch.sum(c_ij.unsqueeze(-1) * u_hat, dim=1)
            v_j = self.squash(s_j)
            if iteration < self.routing_iters - 1:
                # Agreement (dot product) between each prediction and its output.
                b_ij = b_ij + torch.sum(u_hat * v_j.unsqueeze(1), dim=-1)
        return v_j

    def forward(self, x: torch.Tensor):
        """Classify and reconstruct a batch of single-channel images.

        Args:
            x: Tensor of shape (batch, 1, H, W). The decoder always emits 32x32,
               so 32x32 inputs are presumably intended (TODO confirm).

        Returns:
            ``(class_output, reconstruction)`` with shapes (batch, num_classes)
            and (batch, 1, 32, 32).

        The conv features are used directly as prediction vectors; no learned
        per-pair transformation matrix is applied before routing. Spatial
        positions of the primary-capsule map act as the routing's input capsules,
        and the ``num_capsules`` channel groups act as its output capsules.
        """
        x = self.conv1(x)
        u_hat = self.primary_capsules(x)
        batch_size = u_hat.size(0)
        # (batch, num_capsules*capsule_dim, H', W') -> (batch, num_capsules, capsule_dim, H'*W')
        u_hat = u_hat.view(batch_size, self.num_capsules, self.capsule_dim, -1)
        # -> (batch, H'*W', num_capsules, capsule_dim), the layout dynamic_routing expects.
        u_hat = u_hat.permute(0, 3, 1, 2)
        v_j = self.dynamic_routing(u_hat)
        # (batch, num_capsules * capsule_dim) feeds both heads.
        v_j_flat = v_j.view(batch_size, -1)
        class_output = self.classification_head(v_j_flat)
        reconstruction = self.decoder(v_j_flat)
        reconstruction = reconstruction.view(-1, 1, 32, 32)
        return class_output, reconstruction

# -----------------------------
# Weighted Margin Loss Function
# -----------------------------
def margin_loss(y_true, y_pred, m_plus=0.9, m_minus=0.1, lambda_=0.5, class_weights=None):
    """
    Margin (squared-hinge) loss for capsule-style outputs, with optional class weights.

    For the true class the per-element loss is ``relu(m_plus - y)^2``. For every
    other class it is ``lambda_ * relu(y - m_minus)^2``.

    Args:
        y_true: Integer class indices, shape ``(batch,)``.
        y_pred: Per-class scores, shape ``(batch, num_classes)``.
        m_plus: Score the true class should reach.
        m_minus: Score the other classes should stay below.
        lambda_: Down-weighting of the absent-class term.
        class_weights: Optional list/array/tensor of length ``num_classes``.
            When given, each sample's loss (averaged over classes) is multiplied
            by the weight of its true class before averaging over the batch.

    Returns:
        Scalar tensor: the mean loss.

    Note:
        Scores are deliberately NOT clamped to [0, 1]. A hard clamp has zero
        gradient outside that range, so a true-class score below 0 (or an
        absent-class score above 1) could never be pushed back. The ReLU hinges
        already make scores beyond the margins loss-free in the desired direction.
    """
    y_true_one_hot = F.one_hot(y_true, num_classes=y_pred.size(1)).to(y_pred.dtype)
    positive_loss = F.relu(m_plus - y_pred).pow(2) * y_true_one_hot
    negative_loss = F.relu(y_pred - m_minus).pow(2) * (1 - y_true_one_hot)
    loss = positive_loss + lambda_ * negative_loss
    if class_weights is None:
        return loss.mean()
    # as_tensor accepts lists, arrays and tensors without the copy/warning that
    # torch.tensor(existing_tensor) triggers.
    weights = torch.as_tensor(class_weights, dtype=loss.dtype, device=loss.device)
    sample_weights = (y_true_one_hot * weights).sum(dim=1)
    return (loss.mean(dim=1) * sample_weights).mean()

# -----------------------------
# Evaluation Functions
# -----------------------------
def evaluate_model(model, val_loader, device):
    """Run ``model`` over ``val_loader`` and gather binary-classification metrics.

    The model is put in eval mode and run without gradients. A prediction is
    the arg-max of the model's first output. ROC/AUC use the softmax
    probability of class 1.

    Returns:
        A dict with 'precision', 'recall', 'f1_score', 'confusion_matrix',
        'roc_auc', 'fpr', 'tpr', plus the raw per-sample 'predictions',
        'targets' and 'probabilities' lists.
    """
    model.eval()
    predictions, targets, probabilities = [], [], []
    with torch.no_grad():
        for batch, batch_targets in val_loader:
            logits = model(batch.to(device))[0]
            predictions.extend(logits.argmax(1).cpu().numpy())
            targets.extend(batch_targets.to(device).cpu().numpy())
            probabilities.extend(F.softmax(logits, dim=1).cpu().numpy())

    precision, recall, f1, _support = precision_recall_fscore_support(
        targets, predictions, average='binary'
    )
    matrix = confusion_matrix(targets, predictions)
    positive_scores = [row[1] for row in probabilities]
    fpr, tpr, _thresholds = roc_curve(targets, positive_scores)
    metrics = {
        'precision': precision,
        'recall': recall,
        'f1_score': f1,
        'confusion_matrix': matrix,
        'roc_auc': auc(fpr, tpr),
        'fpr': fpr,
        'tpr': tpr,
    }
    metrics['predictions'] = predictions
    metrics['targets'] = targets
    metrics['probabilities'] = probabilities
    return metrics

def visualize_training_metrics(history):
    """Plot loss, accuracy and learning-rate curves in three side-by-side panels.

    ``history`` must contain the sequences 'train_loss', 'val_loss',
    'train_accuracy', 'val_accuracy' and 'learning_rate'. The figure is saved
    to ``training_metrics.png`` and then shown.
    """
    # (subplot position, title, [(history key, legend label), ...], y-axis label)
    panels = (
        (131, 'Loss', (('train_loss', 'Train Loss'), ('val_loss', 'Val Loss')), 'Loss'),
        (132, 'Accuracy',
         (('train_accuracy', 'Train Accuracy'), ('val_accuracy', 'Val Accuracy')), 'Accuracy'),
        (133, 'Learning Rate', (('learning_rate', 'Learning Rate'),), 'Learning Rate'),
    )
    plt.figure(figsize=(15, 5))
    for position, title, series, y_label in panels:
        plt.subplot(position)
        plt.title(title)
        for key, legend_label in series:
            plt.plot(history[key], label=legend_label)
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.legend()
    plt.tight_layout()
    plt.savefig('training_metrics.png')
    plt.show()
    plt.close()

def visualize_nodules(images, labels, num_samples=3):
    """Show a few randomly chosen nodule patches titled with their label.

    Samples are drawn without replacement, so no patch appears twice, and at
    most ``len(images)`` panels are drawn. The figure is saved to
    ``nodule_samples.png`` and shown. Nothing is drawn when ``images`` is empty
    or ``num_samples`` is not positive.

    Args:
        images: Indexable collection of images (squeezed before display).
        labels: Labels aligned with ``images``; 1 is shown as "Malignant",
            anything else as "Benign".
        num_samples: Requested number of panels.
    """
    count = min(num_samples, len(images))
    if count <= 0:
        return
    picks = np.random.choice(len(images), size=count, replace=False)
    plt.figure(figsize=(15, 5))
    for panel, sample_idx in enumerate(picks, start=1):
        plt.subplot(1, count, panel)
        plt.imshow(images[sample_idx].squeeze(), cmap='gray')
        plt.title(f'Nodule {"Malignant" if labels[sample_idx] == 1 else "Benign"}')
        plt.axis('off')
    plt.tight_layout()
    plt.savefig('nodule_samples.png')
    plt.show()
    plt.close()

def plot_confusion_matrix(cm):
    """Draw an annotated benign/malignant confusion-matrix heatmap.

    ``cm`` is expected to be an integer 2x2 matrix with true labels on the rows
    and predicted labels on the columns (the axis titles below say so). The
    figure is saved to ``confusion_matrix.png`` and shown.
    """
    class_names = ['Benign', 'Malignant']
    plt.figure(figsize=(8, 6))
    sns.heatmap(
        cm,
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=class_names,
        yticklabels=class_names,
    )
    for set_text, text in (
        (plt.title, 'Confusion Matrix'),
        (plt.xlabel, 'Predicted Label'),
        (plt.ylabel, 'True Label'),
    ):
        set_text(text)
    plt.tight_layout()
    plt.savefig('confusion_matrix.png')
    plt.show()
    plt.close()

def plot_roc_curve(fpr, tpr, roc_auc):
    """Plot an ROC curve with its AUC in the legend, plus the chance diagonal.

    The figure is saved to ``roc_curve.png`` and shown.
    """
    curve_label = f'ROC curve (AUC = {roc_auc:.2f})'
    chance_line = [0, 1]
    plt.figure()
    plt.plot(fpr, tpr, color='darkorange', label=curve_label)
    plt.plot(chance_line, chance_line, color='navy', linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic (ROC)')
    plt.legend(loc='lower right')
    plt.savefig('roc_curve.png')
    plt.show()
    plt.close()

# -----------------------------
# Main Training and Evaluation Loop with Early Stopping
# -----------------------------
def _compute_losses(model, data, target, class_weights):
    """Forward one batch and return ``(total_loss, class_output)``.

    The total loss is the weighted margin loss plus
    ``model.reconstruction_weight`` times the MSE reconstruction loss. Both
    dataset pipelines normalise inputs with mean=0.5 / std=0.5, giving the
    range [-1, 1], while the decoder ends in a Sigmoid with range (0, 1). The
    input is therefore mapped back to [0, 1] before being used as the
    reconstruction target; against the raw input, every negative pixel would
    be unreachable.
    """
    class_output, reconstruction = model(data)
    loss_margin = margin_loss(target, class_output, class_weights=class_weights)
    reconstruction_target = data * 0.5 + 0.5
    loss_reconstruction = F.mse_loss(reconstruction, reconstruction_target)
    loss = loss_margin + model.reconstruction_weight * loss_reconstruction
    return loss, class_output


def _run_epoch(model, loader, device, class_weights, optimizer=None):
    """Run one pass over ``loader``; train when ``optimizer`` is given, else evaluate.

    Returns ``(mean_loss, accuracy)`` over the samples actually seen. This stays
    correct when the loader drops a final partial batch. Both values are NaN if
    the loader yields no batches.
    """
    training = optimizer is not None
    model.train(training)
    total_loss, total_correct, total_seen = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            if training:
                optimizer.zero_grad()
            loss, class_output = _compute_losses(model, data, target, class_weights)
            if training:
                loss.backward()
                optimizer.step()
            batch = data.size(0)
            total_loss += loss.item() * batch
            total_correct += (class_output.argmax(1) == target).sum().item()
            total_seen += batch
    if total_seen == 0:
        return float('nan'), float('nan')
    return total_loss / total_seen, total_correct / total_seen


def main():
    """Train and evaluate the capsule network on LIDC-IDRI nodule patches.

    The steps are:
    - load the patches and make a stratified 80/20 train/validation split;
    - train with class-balanced sampling, cosine warm restarts and early
      stopping on validation loss;
    - restore the best weights, plot diagnostics, print a classification
      report and save the weights.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    images, labels = load_nodule_data(max_scans=400)
    if len(images) == 0:
        raise RuntimeError("No nodule patches were extracted; nothing to train on.")
    X_train, X_val, y_train, y_val = train_test_split(
        images, labels, test_size=0.2, random_state=42, stratify=labels
    )
    visualize_nodules(images, labels)
    train_dataset = LungNoduleDataset(X_train, y_train, mode='train')
    val_dataset = LungNoduleDataset(X_val, y_val, mode='val')

    num_classes = 2
    batch_size = 32
    # Class frequencies in the training split (np.unique returns sorted classes).
    classes, counts = np.unique(y_train, return_counts=True)
    class_index = np.searchsorted(classes, y_train)

    # Weighted sampler: each sample is drawn with probability proportional to
    # 1 / (its class count), so batches are class-balanced in expectation.
    samples_weight = torch.from_numpy((1.0 / counts)[class_index]).float()
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight), replacement=True)

    # BatchNorm1d in the classification head raises on a training batch of a
    # single sample, so drop the last batch exactly when it would have size 1.
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=sampler,
        drop_last=len(train_dataset) % batch_size == 1,
    )
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    model = CapsNetWithDecoder(
        num_capsules=12, 
        capsule_dim=16, 
        num_classes=num_classes,
        reconstruction_weight=0.0005,
        routing_iters=3
    ).to(device)

    # Use a slightly lower learning rate and add weight decay for regularization.
    optimizer = optim.Adam(model.parameters(), lr=0.0008, weight_decay=1e-4)
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=5, T_mult=1)

    # "Balanced" class weights n / (k * count_c). They have the same ratios as
    # 1 / count_c but average to 1 over the training samples instead of ~1/n.
    # This keeps the margin loss from being shrunk by orders of magnitude
    # relative to the reconstruction term and Adam's weight decay.
    # NOTE(review): the sampler already balances classes; these weights apply a
    # second correction on top of it -- confirm the double weighting is intended.
    class_weight_vec = np.zeros(num_classes)
    class_weight_vec[classes] = len(y_train) / (len(classes) * counts)
    class_weights = {int(cls): float(class_weight_vec[cls]) for cls in classes}
    print("Class weights:", class_weights)
    # Indexed by class id, so the order never depends on dict insertion order.
    class_weight_list = class_weight_vec.tolist()

    history = {
        'train_loss': [],
        'val_loss': [],
        'train_accuracy': [],
        'val_accuracy': [],
        'learning_rate': []
    }
    num_epochs = 100
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = float('inf')
    patience = 15
    epochs_no_improve = 0

    for epoch in range(num_epochs):
        train_loss, train_acc = _run_epoch(
            model, train_loader, device, class_weight_list, optimizer=optimizer
        )
        val_loss, val_acc = _run_epoch(model, val_loader, device, class_weight_list)
        current_lr = optimizer.param_groups[0]['lr']
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['train_accuracy'].append(train_acc)
        history['val_accuracy'].append(val_acc)
        history['learning_rate'].append(current_lr)
        print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f} | LR: {current_lr:.6f}")
        if val_loss < best_loss:
            best_loss = val_loss
            best_model_wts = copy.deepcopy(model.state_dict())
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print("Early stopping triggered!")
            break
        scheduler.step(epoch + 1)

    model.load_state_dict(best_model_wts)
    visualize_training_metrics(history)
    eval_results = evaluate_model(model, val_loader, device)
    plot_confusion_matrix(eval_results['confusion_matrix'])
    plot_roc_curve(eval_results['fpr'], eval_results['tpr'], eval_results['roc_auc'])
    print("\nClassification Report:")
    print(classification_report(
        eval_results['targets'], 
        eval_results['predictions'], 
        target_names=['Benign', 'Malignant']
    ))
    torch.save(model.state_dict(), "lung_capsnet_model_desktop_v1_improved.pth")
    print("Model saved as lung_capsnet_model_desktop_v1_improved.pth")

# Run the full training pipeline only when this file is executed as a script.
if __name__ == '__main__':
    main()
