<a href="https://colab.research.google.com/github/abyssmu/CS598-Hadi-Womack/blob/main/CS598_nhadi3_sarobin2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [2]:
class AgeEncoder(nn.Module):
    """
    Age encoder module as described in the paper.
    Inspired by positional encoding in the Transformer model.

    The age is expanded into a ``d_model``-dimensional sinusoidal code
    (sine on even indices, cosine on odd indices) and then projected through
    ``Linear(d_model, 512) -> LayerNorm(512) -> Linear(512, 1024)``.
    """
    def __init__(self, d_model=512):
        super(AgeEncoder, self).__init__()
        self.d_model = d_model
        self.fc1 = nn.Linear(d_model, 512)
        self.layer_norm = nn.LayerNorm(512)
        self.fc2 = nn.Linear(512, 1024)

    def forward(self, age):
        """
        Args:
            age: age value(s), expected to be normalized to the 0-1 range.
                May be a Python number, a 0-dim tensor, or a tensor holding
                one age per sample (e.g. shape [batch] or [batch, 1]).

        Returns:
            Tensor of shape [1024] for a scalar / 0-dim age, otherwise
            [N, 1024] where N is the number of ages supplied.
        """
        weight = self.fc1.weight
        if not torch.is_tensor(age):
            # Plain numbers have no device; place them next to the weights.
            age = torch.as_tensor(age, device=weight.device)
        # Match the layer dtype so fc1 accepts the encoding (e.g. float64 ages).
        age = age.to(dtype=weight.dtype)

        is_batched = age.dim() > 0
        ages = age.reshape(-1, 1)  # one row per age: (N, 1)

        # Frequencies of the sinusoidal code, one per sin/cos pair
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, device=ages.device, dtype=ages.dtype)
            * -(math.log(10000.0) / self.d_model)
        )
        angles = ages * div_term  # (N, ceil(d_model / 2))

        # Apply sinusoidal encoding for the whole batch at once
        encoding = torch.zeros(
            ages.size(0), self.d_model, device=ages.device, dtype=ages.dtype
        )
        encoding[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd slot than even slots.
        encoding[:, 1::2] = torch.cos(angles[:, : self.d_model // 2])

        # Transform through linear layers
        x = self.fc1(encoding)
        x = self.layer_norm(x)
        x = self.fc2(x)

        # Scalar in -> unbatched feature vector out.
        return x if is_batched else x.squeeze(0)

In [3]:
class AlzheimerNet(nn.Module):
    """
    CNN architecture for Alzheimer's disease detection as described in the paper.

    Four Conv3d -> InstanceNorm3d -> ReLU -> MaxPool3d blocks are followed by
    a fully connected layer producing a 1024-dimensional feature, an optional
    age branch, and a final linear classifier.

    Args:
        num_classes: Number of output logits.
        widening_factor: Multiplier applied to the channel count of every
            convolutional block.
        use_age: If True, build the age encoder branch.
    """

    # Spatial edge length that ``fc1`` expects after the convolutional stack.
    _FC_SPATIAL = 5

    def __init__(self, num_classes=3, widening_factor=8, use_age=True):
        super(AlzheimerNet, self).__init__()
        self.use_age = use_age
        self.widening_factor = widening_factor

        # Block 1
        self.conv1 = nn.Conv3d(1, 4 * widening_factor, kernel_size=1, stride=1, padding=0, dilation=1)
        self.in1 = nn.InstanceNorm3d(4 * widening_factor)
        self.relu1 = nn.ReLU(inplace=True)
        self.pool1 = nn.MaxPool3d(kernel_size=3, stride=2)

        # Block 2
        self.conv2 = nn.Conv3d(4 * widening_factor, 32 * widening_factor, kernel_size=3, stride=1, padding=0, dilation=2)
        self.in2 = nn.InstanceNorm3d(32 * widening_factor)
        self.relu2 = nn.ReLU(inplace=True)
        self.pool2 = nn.MaxPool3d(kernel_size=3, stride=2)

        # Block 3
        self.conv3 = nn.Conv3d(32 * widening_factor, 64 * widening_factor, kernel_size=5, stride=1, padding=2, dilation=2)
        self.in3 = nn.InstanceNorm3d(64 * widening_factor)
        self.relu3 = nn.ReLU(inplace=True)
        self.pool3 = nn.MaxPool3d(kernel_size=3, stride=2)

        # Block 4
        self.conv4 = nn.Conv3d(64 * widening_factor, 64 * widening_factor, kernel_size=3, stride=1, padding=1, dilation=2)
        self.in4 = nn.InstanceNorm3d(64 * widening_factor)
        self.relu4 = nn.ReLU(inplace=True)
        self.pool4 = nn.MaxPool3d(kernel_size=5, stride=2)

        # With the layer settings above a 96^3 input shrinks to 1x1x1
        # (96 -> 47 -> 43 -> 21 -> 17 -> 8 -> 6 -> 1), which does not match the
        # 5x5x5 grid fc1 is sized for. This parameter-free pool maps whatever
        # the conv stack produces onto a 5x5x5 grid; it is an identity when the
        # grid is already 5x5x5 and adds no entries to the state dict.
        self.feature_pool = nn.AdaptiveAvgPool3d(self._FC_SPATIAL)

        # Fully connected layers
        self.fc1 = nn.Linear(64 * widening_factor * self._FC_SPATIAL ** 3, 1024)

        # Age encoder
        if self.use_age:
            self.age_encoder = AgeEncoder(d_model=512)
            self.age_fc1 = nn.Linear(1024, 1024)

        self.fc2 = nn.Linear(1024, num_classes)

    def forward(self, x, age=None):
        """
        Args:
            x: Input MRI scan [batch_size, 1, 96, 96, 96]
            age: Age of the patient (optional). Only used when the model was
                built with ``use_age=True``; if omitted, the age branch
                (including ``age_fc1``) is skipped.

        Returns:
            Class logits of shape [batch_size, num_classes].
        """
        # Block 1
        x = self.pool1(self.relu1(self.in1(self.conv1(x))))

        # Block 2
        x = self.pool2(self.relu2(self.in2(self.conv2(x))))

        # Block 3
        x = self.pool3(self.relu3(self.in3(self.conv3(x))))

        # Block 4
        x = self.pool4(self.relu4(self.in4(self.conv4(x))))

        # Bring the feature map to the 5x5x5 grid fc1 was built for, then flatten
        x = self.feature_pool(x)
        x = torch.flatten(x, 1)

        # Fully connected layer
        x = self.fc1(x)

        # Incorporate age if provided
        if self.use_age and age is not None:
            # Encode age
            age_feat = self.age_encoder(age)

            # Combine features
            x = x + age_feat
            x = self.age_fc1(x)

        # Final classification layer
        x = self.fc2(x)

        return x

In [5]:
def preprocess_mri_scan(scan, target_size=(96, 96, 96)):
    """
    Preprocess MRI scan for the model.

    Args:
        scan: Input MRI scan as a numpy array or torch tensor, either
            3-D (depth, height, width) or 4-D with a leading channel axis.
        target_size: Target spatial size for the model (any 3-item sequence).

    Returns:
        Preprocessed scan as a float32 torch tensor of shape
        (channels, *target_size); a channel axis of size 1 is added for
        3-D input.

    Raises:
        ValueError: If ``scan`` is not 3-D or 4-D, or ``target_size`` does
            not have exactly three entries.
    """
    # Convert to torch tensor
    if not isinstance(scan, torch.Tensor):
        scan = torch.as_tensor(scan)
    # Always cast: trilinear interpolation (and the model) need a float
    # tensor, and tensor inputs may arrive with an integer dtype.
    scan = scan.float()

    # Add channel dimension if not present
    if scan.dim() == 3:
        scan = scan.unsqueeze(0)
    elif scan.dim() != 4:
        raise ValueError(
            f"scan must be 3-D or 4-D (channel first), got {scan.dim()} dimensions"
        )

    # Normalise to a tuple so a list-valued target_size compares correctly.
    target_size = tuple(target_size)
    if len(target_size) != 3:
        raise ValueError(f"target_size must have 3 entries, got {len(target_size)}")

    # Ensure scan has the right dimensions
    if tuple(scan.shape[1:]) != target_size:
        # F.interpolate expects a batch axis; add it for the call and drop it after.
        scan = F.interpolate(
            scan.unsqueeze(0), size=target_size, mode='trilinear', align_corners=False
        ).squeeze(0)

    return scan

In [6]:
def example_usage():
    """Run one forward pass on random data and return the predicted class index.

    Builds an ``AlzheimerNet`` with three classes, feeds it a random
    96x96x96 volume together with a normalized age, and returns the index of
    the largest logit as a Python int.
    """
    # Create the model
    model = AlzheimerNet(num_classes=3, widening_factor=8, use_age=True)
    # This is inference only: switch to eval mode and skip autograd
    # bookkeeping so no activations are kept for a backward pass.
    model.eval()

    # Simulate an MRI scan (batch_size, channels, depth, height, width)
    dummy_scan = torch.randn(1, 1, 96, 96, 96)

    # Simulate patient age (normalized)
    patient_age = torch.tensor([75.0 / 120.0], dtype=torch.float32)

    # Forward pass
    with torch.no_grad():
        output = model(dummy_scan, patient_age)

    # Get prediction
    predicted_class = output.argmax(dim=1)

    return predicted_class.item()

In [7]:
def train_model(model, train_loader, criterion, optimizer, device, epochs=10):
    """Train ``model`` for a number of epochs and return it.

    Args:
        model: Module called as ``model(inputs, ages)`` that returns
            per-class scores of shape [batch, num_classes].
        train_loader: Re-iterable yielding ``(inputs, ages, labels)`` batches.
            It does not need to support ``len()``.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer already bound to the model's parameters.
        device: Device the batches are moved to before the forward pass.
            The model itself is not moved here.
        epochs: Number of passes over ``train_loader``.

    Returns:
        The same ``model`` instance, left in training mode.

    After every epoch the mean per-batch loss and the accuracy are printed;
    both are reported as ``nan`` if the loader yielded no batches.
    """
    model.train()

    for epoch in range(epochs):
        running_loss = 0.0
        correct = 0
        total = 0
        num_batches = 0  # counted by hand so loaders without __len__ work

        for inputs, ages, labels in train_loader:
            inputs, ages, labels = inputs.to(device), ages.to(device), labels.to(device)

            optimizer.zero_grad()

            outputs = model(inputs, ages)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

            # detach(): the accuracy bookkeeping must not touch the graph.
            predicted = outputs.detach().argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        # Guard the averages so an empty loader does not raise ZeroDivisionError.
        avg_loss = running_loss / num_batches if num_batches else float('nan')
        accuracy = 100 * correct / total if total else float('nan')
        print(f'Epoch {epoch+1}, Loss: {avg_loss:.4f}, Acc: {accuracy:.2f}%')

    return model

In [8]:
def get_best_model_config():
    """Return the default model/training hyperparameters as a new dict.

    A fresh dictionary is built on every call, so callers may modify the
    result without affecting later calls.
    """
    return dict(
        num_classes=3,  # CN, MCI, AD
        widening_factor=8,  # Best performing according to paper
        learning_rate=0.01,
        momentum=0.9,
        batch_size=4,  # For memory considerations
        epochs=100,
        use_age=True,  # Whether to incorporate age information
    )