In [88]:
!pip install monai
!pip install torch torchvision



In [89]:
## Train Test Val Split

# 80 - 10 - 10 split

import random
import os
## train test split
np_dir = '/Users/costanzasiniscalchi/Documents/MS/ADL/Project/ADL_Project/data/numpy_conversions_3_scans'
ids = os.listdir(np_dir)
print(ids)
print(len(ids))
def split_data(ids, train_size=0.7, test_size=0.15, validation_size=0.15):
    # Shuffle the list of IDs
    random.shuffle(ids)
    
    # Calculate split indices
    total_size = len(ids)
    train_end = int(train_size * total_size)
    test_end = train_end + int(test_size * total_size)
    
    # Split the data
    train_ids = ids[:train_end]
    test_ids = ids[train_end:test_end]
    validation_ids = ids[test_end:]
    
    return train_ids, test_ids, validation_ids

# Example usage:
train, test, validation = split_data(ids)

print("Train IDs:", train)
print("Test IDs:", test)
print("Validation IDs:", validation)


['9539210', '6008569', '4276824', '9909448', '2887681', '5030375', '8478383', '1984879', '8052813', '4773593', '3343577', '7125565', '5158901', '4210363', '3301724', '2963960', '5187625', '7755697', '4317780', '1004359', '2823276', '7237992', '3475739', '1369125', '6967785', '1016072', '3048898', '3191214', '3730353', '3162878', '3692881', '9827494', '4052945', '5160587', '4040157', '5692079', '3165520', '8120729', '7760229', '2448082', '4943065', '2924615', '7863867', '9380004', '1635604', '4136011', '7550757', '8686311', '2599481', '3705605', '4210489', '7672530', '5730499', '4532706', '6433158']
55
Train IDs: ['1004359', '3692881', '7125565', '5158901', '4210363', '4136011', '4276824', '3301724', '3343577', '4052945', '2599481', '3191214', '7550757', '5730499', '1016072', '3705605', '1369125', '2963960', '3048898', '7672530', '4040157', '2887681', '9909448', '5187625', '1984879', '1635604', '3162878', '8120729', '3475739', '4943065', '9539210', '3730353', '6433158', '8686311', '3165

In [90]:
import os
import numpy as np
from monai.data import Dataset, DataLoader
from monai.transforms import Compose, ScaleIntensity, ToTensor
import torch

class MRIDataLoader:
    def __init__(self, numpy_dir, id_list, random_order = True, transform=None):
        self.numpy_dir = numpy_dir
        self.transform = transform
        self.patient_ids = id_list
        self.random_order = random_order
        self.data, self.labels = self._prepare_data()

    def _prepare_data(self):
        all_data = []
        labels = []
        for patient_id in self.patient_ids:
            # List the scan dates for each patient
            scan_dates = sorted(os.listdir(os.path.join(self.numpy_dir, patient_id)))
            patient_data = []
            patient_label = []
            for scan_date in scan_dates:
                # Get the corresponding numpy file path
                numpy_file = os.path.join(self.numpy_dir, patient_id, scan_date, f"preventad_{patient_id}_{scan_date}_t1w_001_t1w-defaced_001.npy")
                
                # Check if the numpy file exists
                if os.path.exists(numpy_file):
                    patient_data.append(numpy_file)  # Append MRI file path
                    # Assign the label based on the scan date
                    if scan_date == 'PREBL00':
                        patient_label.append(0)
                    elif scan_date == 'PREFU12':
                        patient_label.append(1)
                    else:
                        patient_label.append(2)
            all_data.append(patient_data)
            labels.append(patient_label)

        return all_data, labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        numpy_files = self.data[idx]
        
        label = self.labels[idx]  # Get the corresponding label
        
        if self.random_order:
            # Shuffle the files and labels together
            combined = list(zip(numpy_files, label))
            random.shuffle(combined)
            numpy_files, label = zip(*combined)  # Unzip back to separate lists

        numpy_data = []
        # Load corresponding numpy file
        for file in numpy_files:
            # Apply transformations if provided
            data = np.load(file)  # Load the scan as a numpy array
            if self.transform:
                data = self.transform(data)  # Apply transformations
            numpy_data.append(data)
        
        # Stack the data for the sequence of scans
        numpy_data = np.stack(numpy_data, axis=0)  # Shape will be (num_scans, channels, D, H, W)

        # Return the data as a dictionary
        return {"labels": torch.tensor(label, dtype=torch.long), "numpy": torch.tensor(numpy_data, dtype=torch.float32)}


In [91]:
# Define a transformation pipeline for MRI data
transform = Compose([ScaleIntensity(), ToTensor()])

# Initialize the dataset loader
train_mri_loader = MRIDataLoader(numpy_dir=np_dir, id_list = train, transform=transform)
test_mri_loader = MRIDataLoader(numpy_dir=np_dir, id_list = test, transform=transform)
val_mri_loader = MRIDataLoader(numpy_dir=np_dir, id_list = validation, transform=transform)

# Create DataLoader
train_loader = DataLoader(train_mri_loader, batch_size=4, shuffle=True)
test_loader = DataLoader(test_mri_loader, batch_size=4, shuffle=True)
val_loader = DataLoader(val_mri_loader, batch_size=4, shuffle=True)


print(len(train_mri_loader))
sample = next(iter(train_loader)) # check if iterating works
print("Sample data shape:", sample['numpy'].shape)  # Check the shape of the data

# Sample label
print("Sample labels:", sample['labels'])

38
Sample data shape: torch.Size([4, 3, 176, 256, 240])
Sample labels: tensor([[2, 1, 0],
        [0, 1, 2],
        [1, 2, 0],
        [1, 2, 0]])


In [92]:
import torch
from monai.networks.nets import AutoencoderKL

# Define the autoencoder model (make sure it matches the saved model architecture)
autoencoder_model = AutoencoderKL(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    latent_channels=8,
    channels=[64, 128, 256],
    num_res_blocks=2,
    norm_num_groups=32,
    norm_eps=1e-06,
    attention_levels=[False, False, False],
    with_encoder_nonlocal_attn=False,
    with_decoder_nonlocal_attn=False,
    include_fc=False
)

# Load the pretrained weights
autoencoder_path = '/Users/costanzasiniscalchi/Documents/MS/ADL/Project/ADL_Project/monai_brats_mri_generative_diffusion_1.1.2/models/model_autoencoder.pt'
state_dict = torch.load(autoencoder_path, map_location=torch.device('cpu'))

# Load the state dict with strict=False to ignore mismatched keys
autoencoder_model.load_state_dict(state_dict, strict=False)

# Set the model to evaluation mode
autoencoder_model.eval()


AutoencoderKL(
  (encoder): Encoder(
    (blocks): ModuleList(
      (0): Convolution(
        (conv): Conv3d(1, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      )
      (1-2): 2 x AEKLResBlock(
        (norm1): GroupNorm(32, 64, eps=1e-06, affine=True)
        (conv1): Convolution(
          (conv): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        )
        (norm2): GroupNorm(32, 64, eps=1e-06, affine=True)
        (conv2): Convolution(
          (conv): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        )
        (nin_shortcut): Identity()
      )
      (3): AEKLDownsample(
        (pad): AsymmetricPad()
        (conv): Convolution(
          (conv): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(2, 2, 2))
        )
      )
      (4): AEKLResBlock(
        (norm1): GroupNorm(32, 64, eps=1e-06, affine=True)
        (conv1): Convolution(
          (conv): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=

In [93]:
# Train function
def train_model(train_loader, model, encoder, criterion, optimizer, device):
    model.train()
    total_loss = 0
    for batch in train_loader:
        inputs = batch["numpy"].to(device)  # Shape: (batch_size, num_scans, channels, D, H, W)
        labels = batch["labels"].to(device)  # Shape: (batch_size, num_scans)

        batch_loss = 0
        for i in range(inputs.size(1)):  # Loop over each scan in the batch
            # Get the latent representation for the current scan
            scan_input = inputs[:, i, :, :, :]  # Shape: (batch_size, channels, D, H, W)
            print(f"Shape of scan_input {i}: {scan_input.shape}")
            # Pass the scan through the encoder
            latent_rep = encoder(scan_input)  # Get the encoded features
            print(f"Latent representation shape: {latent_rep.shape}")
            
            # Flatten the latent representation
            latent_rep = latent_rep.view(latent_rep.size(0), -1)  # Flatten to (batch_size, latent_dim)

            # Forward pass through the TemporalOrderModel
            optimizer.zero_grad()
            outputs = model(latent_rep)

            # Compute the loss for the current scan
            loss = criterion(outputs, labels[:, i])
            loss.backward()

            # Add this loss to the batch loss
            batch_loss += loss.item()

        # Average loss over the batch
        optimizer.step()  # Apply gradients after processing all scans in the batch
        total_loss += batch_loss / inputs.size(1)  # Average loss per scan

    avg_loss = total_loss / len(train_loader)
    return avg_loss

# Validation model
def validate_model(val_loader, model, encoder, criterion, device):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in val_loader:
            inputs = batch["numpy"].to(device)
            labels = batch["labels"].to(device)

            batch_loss = 0
            for i in range(inputs.size(1)):  # Loop over each scan in the batch
                scan_input = inputs[:, i, :, :, :]
                latent_rep = encoder(scan_input)  # Get the latent features

                latent_rep = latent_rep.view(latent_rep.size(0), -1)  # Flatten the latent representation

                # Forward pass through TemporalOrderModel
                outputs = model(latent_rep)

                # Compute the loss
                loss = criterion(outputs, labels[:, i])
                batch_loss += loss.item()

                # Get predictions
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels[:, i]).sum().item()

            total_loss += batch_loss / inputs.size(1)  # Average loss per scan

    avg_loss = total_loss / len(val_loader)
    accuracy = 100 * correct / total
    return avg_loss, accuracy


In [94]:
import torch
import torch.optim as optim
import torch.nn as nn
from monai.networks.nets import AutoencoderKL

# Define the TemporalOrderModel class
class TemporalOrderModel(nn.Module):
    def __init__(self, input_size):
        super(TemporalOrderModel, self).__init__()
        self.fc1 = nn.Linear(input_size, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 3)  # Output 3 classes for 3 possible orders

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()

# Initialize the encoder and model
autoencoder_model = AutoencoderKL(
    spatial_dims=3,
    in_channels=4,  # Update to 4 channels if your data has 4 channels
    out_channels=1,
    latent_channels=8,
    channels=[64, 128, 256],
    num_res_blocks=2,
    norm_num_groups=32,
    norm_eps=1e-06,
    attention_levels=[False, False, False],
    with_encoder_nonlocal_attn=False,
    with_decoder_nonlocal_attn=False,
    include_fc=False
)


input_size = 8 * 64 * 64 * 64  # Update based on actual latent dimensions (this is an example)
temporal_order_model = TemporalOrderModel(input_size=input_size)

# Move models to the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
temporal_order_model.to(device)
autoencoder_model.to(device)

# Define optimizer, using both temporal_order_model and autoencoder_model (if fine-tuning the encoder)
optimizer = optim.Adam(
    list(temporal_order_model.parameters()) + list(autoencoder_model.parameters()), 
    lr=0.001
)

# Check model architecture for debugging
print("TemporalOrderModel architecture:", temporal_order_model)
print("Autoencoder architecture:", autoencoder_model)

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    train_loss = train_model(train_loader, temporal_order_model, autoencoder_model.encoder, criterion, optimizer, device)
    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}")
    
    # Validation
    val_loss, val_accuracy = validate_model(val_loader, temporal_order_model, autoencoder_model.encoder, criterion, device)
    print(f"Epoch {epoch+1}/{num_epochs}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%")


TemporalOrderModel architecture: TemporalOrderModel(
  (fc1): Linear(in_features=2097152, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=3, bias=True)
)
Autoencoder architecture: AutoencoderKL(
  (encoder): Encoder(
    (blocks): ModuleList(
      (0): Convolution(
        (conv): Conv3d(4, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      )
      (1-2): 2 x AEKLResBlock(
        (norm1): GroupNorm(32, 64, eps=1e-06, affine=True)
        (conv1): Convolution(
          (conv): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        )
        (norm2): GroupNorm(32, 64, eps=1e-06, affine=True)
        (conv2): Convolution(
          (conv): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        )
        (nin_shortcut): Identity()
      )
      (3): AEKLDownsample(
        (pad): AsymmetricPad()
        (conv): Convolution(
   

RuntimeError: Expected number of channels in input to be divisible by num_groups, but got input of shape [64, 176, 256, 240] and num_groups=32