In [None]:
import torch
import numpy as np
from torch.utils.data import Dataset

class MRIDataset(Dataset):
    def __init__(self, data, transform=None):
        """
        Args:
            mri_path (str): Path to NumPy array of MRI scans.
            transform (callable, optional): Optional transform to apply to MRI scans.
        """
        self.mri_data = data 
        assert self.mri_data.shape[0] % 3 == 0, "Total scans must be a multiple of 3 (each patient has 3 scans)."
        
        self.num_patients = self.mri_data.shape[0] // 3  # Each patient has 3 scans
        self.transform = transform

    def __len__(self):
        return self.num_patients  # Number of patients

    def __getitem__(self, idx):
        """
        Returns a triplet of MRI scans for a patient.
        Output shape: [3, C, H, W, D]
        """
        start_idx = idx * 3
        scans = self.mri_data[start_idx : start_idx + 3]  # Grab three scans (0, 12, 24 months)
        scans = torch.tensor(scans, dtype=torch.float32)  # Convert to PyTorch tensor


        return scans  # Shape: [3, C, H, W, D]


In [None]:
from torch.utils.data import DataLoader

mri_path = "scan_data.npy"
mri_data = np.load(mri_path) 
mri_data = mri_data[0:21]


# Create dataset
dataset = MRIDataset(mri_data)




In [4]:
print(len(dataset))

7


In [5]:
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)

In [6]:
dataloader

<torch.utils.data.dataloader.DataLoader at 0x10d656fc0>

In [7]:
# Test one batch
for batch in dataloader:
    print("Batch Shape:", batch.shape)  # Expected: [4, 3, C, H, W, D] for batch size 4
    break

Batch Shape: torch.Size([4, 3, 1, 176, 256, 240])


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, Dataset

# Variational Autoencoder (VAE) for Encoding MRI Scans
import torch
import torch.nn as nn
import torch.optim as optim

import torch
import torch.nn as nn

class ConvVAE(nn.Module):
    def __init__(self, channels, height, width, depth, latent_dim):
        super(ConvVAE, self).__init__()

        # Encoder with Conv layers
        self.encoder = nn.Sequential(
            nn.Conv3d(channels, 32, kernel_size=4, stride=2, padding=1),  # (32, H/2, W/2, D/2)
            nn.ReLU(),
            nn.Conv3d(32, 64, kernel_size=4, stride=2, padding=1),         # (64, H/4, W/4, D/4)
            nn.ReLU(),
            nn.Conv3d(64, 128, kernel_size=4, stride=2, padding=1),        # (128, H/8, W/8, D/8)
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(128 * (height // 8) * (width // 8) * (depth // 8), latent_dim * 2)  # Mean & log-variance
        )

        # Decoder with ConvTranspose layers
        self.decoder_input = nn.Linear(latent_dim, 128 * (height // 8) * (width // 8) * (depth // 8))
        
        self.decoder = nn.Sequential(
            nn.Unflatten(1, (128, height // 8, width // 8, depth // 8)),
            nn.ConvTranspose3d(128, 64, kernel_size=4, stride=2, padding=1),  # Upsample
            nn.ReLU(),
            nn.ConvTranspose3d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose3d(32, channels, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid()  # Output range [0, 1]
        )

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        # Encode
        h = self.encoder(x)
        mu, logvar = h[:, :h.shape[1]//2], h[:, h.shape[1]//2:]

        # Reparameterize
        z = self.reparameterize(mu, logvar)

        # Decode
        x_recon = self.decoder_input(z)
        x_recon = self.decoder(x_recon)

        return x_recon, mu, logvar, z



# Latent Diffusion Model (LDM) for Temporal Ordering and Prediction
class LDM(nn.Module):
    def __init__(self, latent_dim):
        super(LDM, self).__init__()
        
        self.model = nn.Sequential(
            nn.Linear(1, 512), nn.ReLU(), 
            nn.Linear(512, 512), nn.ReLU(),
            nn.Linear(512, latent_dim)
        )

    def forward(self, z_noisy, conditioning):
        print(f"z_noisy shape: {z_noisy.shape}, conditioning shape: {conditioning.shape}")

        # Dynamically update input layer size on first forward pass
        if self.model[0].in_features != z_noisy.shape[1] + conditioning.shape[1]:
            input_dim = z_noisy.shape[1] + conditioning.shape[1]
            
            # Update the first linear layer with correct dimensions
            self.model[0] = nn.Linear(input_dim, 512).to(z_noisy.device)

        return self.model(torch.cat([z_noisy, conditioning], dim=1))

# Next Scan Prediction Model
class NextScanPredictor(nn.Module):
    def __init__(self, latent_dim):
        super(NextScanPredictor, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim * 2, 512), nn.ReLU(),
            nn.Linear(512, 512), nn.ReLU(),
            nn.Linear(512, latent_dim)  # Ensure output is `latent_dim`
        )

    def forward(self, z_t):
        return self.model(z_t)  # Output shape: (batch_size, latent_dim)



# Training Loops

def pretrain_ldm(vae, ldm, dataloader, epochs=10, lr=1e-3, device='cuda'):
    optimizer = torch.optim.Adam(ldm.parameters(), lr=lr)
    loss_fn = nn.MSELoss()

    for epoch in range(epochs):
        vae.train()
        ldm.train()

        for mri_seq in dataloader:
            
            batch_size, seq_len, channels, height, width, depth = mri_seq.shape
            
            # Initialize lists for latent representations
            latents = []

            # Encode MRI scans into latent space
            print("encode")
            for i in range(seq_len):
                with torch.no_grad():  # No need to track gradients for encoding
                    _, _, _, z = vae(mri_seq[:, i])
                latents.append(z)

            # Stack latent representations: [batch_size, seq_len, latent_dim]

            print("noise")
            z_latents = torch.stack(latents, dim=1)

            # Shuffle and add noise in latent space
            perm = torch.randperm(seq_len)
            z_shuffled = z_latents[:, perm]
            
            noise = torch.randn_like(z_shuffled)
            z_noisy = z_shuffled + noise

            # Pass the latent representations to LDM
            print("denoise")
            z_denoised = ldm(z_noisy.view(batch_size * seq_len, -1), conditioning=noise.view(batch_size * seq_len, -1))

            # Reshape back to original latent structure
            z_denoised = z_denoised.view(batch_size, seq_len, -1)

            # Compute loss
            diffusion_loss = loss_fn(z_denoised, z_shuffled)

            optimizer.zero_grad()
            diffusion_loss.backward()
            optimizer.step()

        print(f"Pretraining Epoch {epoch + 1}, Diffusion Loss: {diffusion_loss.item():.4f}")

    return vae, ldm




# Set hyperparameters


latent_dim = 32



In [9]:
def fine_tune_ldm(vae, ldm, predictor, dataloader, epochs=10, lr=1e-3, device='cuda'):
    optimizer_ldm = torch.optim.Adam(ldm.parameters(), lr=lr)
    optimizer_pred = torch.optim.Adam(predictor.parameters(), lr=lr)
    loss_fn = nn.MSELoss()

    vae.eval()  # Freeze VAE
    ldm.train()  # Train LDM
    predictor.train()  # Train Next Scan Predictor

    for epoch in range(epochs):
        total_loss = 0

        for mri_seq in dataloader:
            batch_size, seq_len, channels, height, width, depth = mri_seq.shape

            latents = []

            # Encode MRI scans into latent space
            for i in range(seq_len):
                with torch.no_grad():
                    _, _, _, z = vae(mri_seq[:, i])
                latents.append(z)

            z_latents = torch.stack(latents, dim=1)  # [batch_size, seq_len, latent_dim]

            # Add noise for diffusion process
            noise = torch.randn_like(z_latents) * 0.1
            z_noisy = z_latents + noise

            # LDM learns to denoise
            z_denoised = ldm(z_noisy.view(batch_size * seq_len, -1), conditioning=noise.view(batch_size * seq_len, -1))
            z_denoised = z_denoised.view(batch_size, seq_len, -1)

            # **Fix the shape before passing to predictor**
            # Select last two time steps and concatenate them
            z_pred_input = torch.cat([z_denoised[:, -2], z_denoised[:, -1]], dim=-1)  # [batch_size, latent_dim * 2]

            # **Predict next latent representation**
            # Ensure correct latent shape before decoding
            z_next_pred = predictor(z_pred_input)  # Output: [batch_size, latent_dim]

            # Pass correctly through VAE decoder
            z_next_pred = z_next_pred.view(batch_size, -1)  # Ensure proper shape
            x_next_pred = vae.decoder_input(z_next_pred)  # Linear layer to expand
            x_next_pred = vae.decoder(x_next_pred)  # Decode to MRI scan


            # Compute loss (predictor should match actual next scan's latent)
            pred_loss = loss_fn(z_next_pred, z_latents[:, -1])

            optimizer_ldm.zero_grad()
            optimizer_pred.zero_grad()
            pred_loss.backward()
            optimizer_ldm.step()
            optimizer_pred.step()

            total_loss += pred_loss.item()

        avg_loss = total_loss / len(dataloader)
        print(f"Fine-tuning Epoch {epoch + 1}/{epochs}, Prediction Loss: {avg_loss:.4f}")

    return vae, ldm, predictor


In [10]:
vae = ConvVAE(1, 176, 256, 240, latent_dim)


In [11]:
ldm = LDM(latent_dim)

In [12]:
vae, ldm = pretrain_ldm(vae, ldm, dataloader, epochs=10, lr=1e-3)

encode
noise
denoise
z_noisy shape: torch.Size([12, 32]), conditioning shape: torch.Size([12, 32])
encode
noise
denoise
z_noisy shape: torch.Size([9, 32]), conditioning shape: torch.Size([9, 32])
Pretraining Epoch 1, Diffusion Loss: 0.9971
encode
noise
denoise
z_noisy shape: torch.Size([12, 32]), conditioning shape: torch.Size([12, 32])
encode
noise
denoise
z_noisy shape: torch.Size([9, 32]), conditioning shape: torch.Size([9, 32])
Pretraining Epoch 2, Diffusion Loss: 0.8842
encode
noise
denoise
z_noisy shape: torch.Size([12, 32]), conditioning shape: torch.Size([12, 32])
encode
noise
denoise
z_noisy shape: torch.Size([9, 32]), conditioning shape: torch.Size([9, 32])
Pretraining Epoch 3, Diffusion Loss: 0.9172
encode
noise
denoise
z_noisy shape: torch.Size([12, 32]), conditioning shape: torch.Size([12, 32])
encode
noise
denoise
z_noisy shape: torch.Size([9, 32]), conditioning shape: torch.Size([9, 32])
Pretraining Epoch 4, Diffusion Loss: 0.9923
encode
noise
denoise
z_noisy shape: torc

In [13]:
predictor = NextScanPredictor(latent_dim=latent_dim)



In [14]:
vae, ldm, predictor = fine_tune_ldm(vae, ldm, predictor, dataloader, epochs=5, lr=1e-3)


z_noisy shape: torch.Size([12, 32]), conditioning shape: torch.Size([12, 32])
z_noisy shape: torch.Size([9, 32]), conditioning shape: torch.Size([9, 32])
Fine-tuning Epoch 1/5, Prediction Loss: 1.0506
z_noisy shape: torch.Size([12, 32]), conditioning shape: torch.Size([12, 32])
z_noisy shape: torch.Size([9, 32]), conditioning shape: torch.Size([9, 32])
Fine-tuning Epoch 2/5, Prediction Loss: 0.8584
z_noisy shape: torch.Size([12, 32]), conditioning shape: torch.Size([12, 32])
z_noisy shape: torch.Size([9, 32]), conditioning shape: torch.Size([9, 32])
Fine-tuning Epoch 3/5, Prediction Loss: 0.9501
z_noisy shape: torch.Size([12, 32]), conditioning shape: torch.Size([12, 32])
z_noisy shape: torch.Size([9, 32]), conditioning shape: torch.Size([9, 32])
Fine-tuning Epoch 4/5, Prediction Loss: 0.8819
z_noisy shape: torch.Size([12, 32]), conditioning shape: torch.Size([12, 32])
z_noisy shape: torch.Size([9, 32]), conditioning shape: torch.Size([9, 32])
Fine-tuning Epoch 5/5, Prediction Loss: 0.

Remaining TODO: 
* Add in Patient Info 
* Evaluation Metrics
* Add in remaining data (currently model accounts for 10%)
* Hyperparameter tune VAE, LDM, and predictor