In [1]:
import itertools
import os

import h5py
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.cluster import KMeans
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset, Subset

KeyboardInterrupt: 

In [None]:
base_path = "/home/fgomezacebo/scratch"   # Directory with the HDF5 files
# NOTE(review): hard-coded user-specific absolute path; consider a config constant.

# Sort the file list: os.listdir returns entries in arbitrary order, which would
# make the subject order (and therefore the CV fold indices) irreproducible.
file_names = sorted(f for f in os.listdir(base_path) if f.endswith('_masked_data.h5'))
if not file_names:
    raise FileNotFoundError(f"No '*_masked_data.h5' files found in {base_path}")

# Load the 'masked_data' dataset from each subject's HDF5 file
all_data = []
for file_name in file_names:
    with h5py.File(os.path.join(base_path, file_name), 'r') as h5f:
        all_data.append(h5f['masked_data'][:])

# Concatenate all subject data along the first axis
all_data = np.concatenate(all_data, axis=0)

In [None]:
# Define the VAE model with Dropout layers
class Conv3dVAE(nn.Module):
    """3-D convolutional VAE for single-channel volumes of shape (N, 1, D, H, W).

    The encoder applies three stride-2 convolutions. The decoder applies three
    stride-2 transposed convolutions that exactly double each spatial dimension,
    so the two paths are symmetric (8x down, 8x up).

    Args:
        input_shape: optional (D, H, W) of the input volumes. When given, the
            encoder's output grid (and therefore ``flatten_size``) is measured
            with a dummy forward pass, and ``decode`` resizes its output to this
            shape. When omitted, the original hard-coded grid (12, 14, 12) is used.
        latent_dim: size of the latent vectors mu / logvar / z (default 256).
    """

    def __init__(self, input_shape=None, latent_dim=256):
        super(Conv3dVAE, self).__init__()

        # Encoder. There is deliberately no MaxPool3d here: pooling after each
        # stride-2 conv made the encoder downsample 64x while the decoder only
        # upsamples 8x, so fc_mu could never match flatten_size.
        self.encoder = nn.Sequential(
            nn.Conv3d(1, 32, kernel_size=6, stride=2, padding=2),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Conv3d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Conv3d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Dropout(0.3),
        )

        # Spatial size of the encoder output grid
        if input_shape is None:
            self.input_shape = None
            self.latent_grid = (12, 14, 12)
        else:
            self.input_shape = tuple(int(s) for s in input_shape)
            with torch.no_grad():
                probe = self.encoder(torch.zeros(1, 1, *self.input_shape))
            self.latent_grid = tuple(int(s) for s in probe.shape[2:])

        gd, gh, gw = self.latent_grid
        self.flatten_size = 128 * gd * gh * gw

        # Latent vectors mu and logvar
        self.fc_mu = nn.Linear(self.flatten_size, latent_dim)
        self.fc_logvar = nn.Linear(self.flatten_size, latent_dim)

        # Decoder
        self.decoder_fc = nn.Linear(latent_dim, self.flatten_size)

        # Each layer maps n -> 2n. The last layer previously used
        # output_padding=1, which gave 2n + 1 (e.g. 97 instead of 96).
        self.decoder = nn.Sequential(
            nn.ConvTranspose3d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose3d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose3d(32, 1, kernel_size=6, stride=2, padding=2, output_padding=0),
            nn.Sigmoid()
        )

    def encode(self, x):
        """Map a (N, 1, D, H, W) batch to posterior parameters (mu, logvar)."""
        h = self.encoder(x).flatten(1)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        """Draw z = mu + eps * sigma with eps ~ N(0, I)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, output_shape=None):
        """Decode latent vectors into (N, 1, D, H, W) volumes with values in [0, 1].

        If ``output_shape`` (or, failing that, the constructor's ``input_shape``)
        is set and differs from the decoder's native 8x-grid output, the result
        is trilinearly resized to it. Trilinear interpolation keeps values in [0, 1].
        """
        x = self.decoder_fc(z)
        x = x.view(x.size(0), 128, *self.latent_grid)
        x = self.decoder(x)
        if output_shape is None:
            output_shape = self.input_shape
        if output_shape is not None and tuple(x.shape[2:]) != tuple(output_shape):
            x = nn.functional.interpolate(
                x, size=tuple(int(s) for s in output_shape),
                mode='trilinear', align_corners=False,
            )
        return x

    def forward(self, x):
        """Return (reconstruction with the same shape as x, mu, logvar)."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon_x = self.decode(z, output_shape=x.shape[2:])
        return recon_x, mu, logvar

In [None]:
# Loss function
def loss_function(recon_x, x, mu, logvar):
    """VAE loss: summed binary cross-entropy reconstruction term plus KL(q(z|x) || N(0, I)).

    Args:
        recon_x: decoder output; binary_cross_entropy requires values in [0, 1].
        x: target tensor with the same shape as recon_x, expected in [0, 1].
        mu: posterior means from the encoder.
        logvar: posterior log-variances from the encoder.

    Returns:
        (total, bce, kld) scalar tensors. The training and validation loops
        unpack all three, so the components are returned alongside the total.
    """
    # nn.functional instead of the bare `F` alias, which this cell relied on without importing.
    bce = nn.functional.binary_cross_entropy(recon_x, x, reduction='sum')
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return bce + kld, bce, kld

# Training settings
num_epochs = 50
learning_rate = 0.001
weight_decay = 1e-4

# TensorDataset needs torch tensors (a NumPy array fails on `.size(0)`),
# and Conv3d needs an explicit channel axis: (N, 1, D, H, W).
data_tensor = torch.as_tensor(all_data, dtype=torch.float32)
if data_tensor.dim() == 4:
    # assumes all_data is (N, D, H, W) — TODO confirm against the HDF5 layout
    data_tensor = data_tensor.unsqueeze(1)
if data_tensor.dim() != 5:
    raise ValueError(f"Expected data of shape (N, 1, D, H, W), got {tuple(data_tensor.shape)}")
# NOTE(review): binary_cross_entropy expects targets in [0, 1]; confirm the masked data is normalized.
dataset = TensorDataset(data_tensor)

# Leave-3-out cross-validation
num_samples = len(dataset)
indices = np.arange(num_samples)


In [None]:
if num_samples < 3:
    raise ValueError(f"Leave-3-out cross-validation needs at least 3 samples, got {num_samples}")

# One fresh model per held-out triple (i < j < k). itertools.combinations yields
# the triples in the same order as three nested index loops.
# NOTE(review): this trains C(num_samples, 3) models x num_epochs, which grows
# cubically with the number of subjects.
for i, j, k in itertools.combinations(range(num_samples), 3):
    # Create training and validation sets
    val_indices = [i, j, k]
    train_indices = [idx for idx in range(num_samples) if idx not in val_indices]

    train_subset = Subset(dataset, train_indices)
    val_subset = Subset(dataset, val_indices)

    train_loader = DataLoader(train_subset, batch_size=16, shuffle=True)
    val_loader = DataLoader(val_subset, batch_size=16, shuffle=False)

    # Initialize model and optimizer
    model = Conv3dVAE()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    train_losses = []
    train_bces = []
    train_klds = []

    # Training loop
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        train_bce = 0
        train_kld = 0

        # A single-tensor TensorDataset yields 1-element batches, so unpack
        # with `(data,)`; `(data, _)` raised "not enough values to unpack".
        for (data,) in train_loader:
            optimizer.zero_grad()
            recon_batch, mu, logvar = model(data)
            loss, bce, kld = loss_function(recon_batch, data, mu, logvar)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            train_bce += bce.item()
            train_kld += kld.item()

        # Losses are summed per batch, so normalize by the number of samples
        n_train = len(train_loader.dataset)
        avg_train_loss = train_loss / n_train
        avg_train_bce = train_bce / n_train
        avg_train_kld = train_kld / n_train

        train_losses.append(avg_train_loss)
        train_bces.append(avg_train_bce)
        train_klds.append(avg_train_kld)

        print(f"Fold {(i, j, k)}, Epoch {epoch}, Total Loss: {avg_train_loss}, Reconstruction Loss: {avg_train_bce}, KLD: {avg_train_kld}")

    # Validation loop
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for (data,) in val_loader:
            recon_batch, mu, logvar = model(data)
            loss, bce, kld = loss_function(recon_batch, data, mu, logvar)
            val_loss += loss.item()

    print(f"Fold {(i, j, k)}, Validation Loss: {val_loss / len(val_loader.dataset)}")

    # Plot this fold's training curves; close the figure so that many folds
    # do not accumulate open figures in memory.
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(train_losses, label='Total Loss')
    ax.plot(train_bces, label='Reconstruction Loss')
    ax.plot(train_klds, label='KLD')
    ax.set_xlabel('Epochs')
    ax.set_ylabel('Loss')
    ax.set_title(f'Losses for fold {(i, j, k)}')
    ax.legend()
    plt.show()
    plt.close(fig)

In [None]:
# Embed every subject with the trained encoder.
# NOTE(review): `model` is the model from the last CV fold, not one fit on all
# subjects — confirm this is intended.
model.eval()
latent_vectors = []

with torch.no_grad():
    for i in range(num_samples):
        sample = dataset[i][0].unsqueeze(0)  # Add batch dimension
        # Use the posterior mean rather than a random reparameterized sample,
        # so the embeddings (and the KMeans clustering on them) are deterministic.
        mu, _ = model.encode(sample)
        latent_vectors.append(mu.cpu().numpy())

# Stack the (1, latent_dim) rows into (num_samples, latent_dim). Unlike
# np.array(...).squeeze(), this stays 2-D even when num_samples == 1.
latent_vectors = np.concatenate(latent_vectors, axis=0)

In [None]:
# Cluster the latent vectors. `n_clusters` is used instead of `k` so that
# the CV loop's fold variable is not clobbered.
n_clusters = 5  # Choose the number of clusters
kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(latent_vectors)

# Decode the cluster centers back to image space. no_grad is required:
# otherwise the decoded tensor tracks gradients and .numpy() raises.
cluster_centers = torch.tensor(kmeans.cluster_centers_, dtype=torch.float32)
with torch.no_grad():
    decoded_images = model.decode(cluster_centers).cpu().numpy()

# Plot the middle slice along the last spatial axis of each decoded center
mid_slice = decoded_images.shape[-1] // 2
for c in range(n_clusters):
    fig, ax = plt.subplots(figsize=(5, 5))
    ax.imshow(decoded_images[c, 0, :, :, mid_slice], cmap='gray')
    ax.set_title(f'Cluster {c+1}')
    plt.show()
    plt.close(fig)

# Majority voting to create the final cluster map.
# The volumes come from `dataset`, not `data`: at this point `data` is only the
# last validation batch leaked from the CV loop, and indexing it with
# range(num_samples) raised IndexError.
spatial_shape = tuple(dataset[0][0].shape[-3:])
cluster_map = np.zeros(spatial_shape)  # Initialize the cluster map

# Assign clusters to the original images
assigned_clusters = kmeans.predict(latent_vectors)

# Accumulate votes over subjects
for idx in range(num_samples):
    sample = dataset[idx][0].numpy().reshape(spatial_shape)
    cluster = assigned_clusters[idx]
    # NOTE(review): this compares voxel intensities with an integer cluster
    # label, which is rarely true for continuous data. Confirm the intended
    # voting rule.
    cluster_map += (sample == cluster).astype(int)

# Save the final cluster map.
# NOTE(review): np.eye(4) discards the original image's voxel-to-world affine;
# consider using the affine of the source NIfTI images.
cluster_map_img = nib.Nifti1Image(cluster_map, np.eye(4))
nib.save(cluster_map_img, 'final_cluster_map_VAE.nii')

print("Final cluster map saved as 'final_cluster_map_VAE.nii'")