In [1]:
import numpy as np
import sys
# Make the project-local graphIO package importable. Forward slashes are used
# throughout: the previous "UNCG\Academics" contained the invalid escape "\A",
# which newer Python versions warn about (and which mixed path separators).
sys.path.append('C:/Users/mosta/OneDrive - UNCG/Academics/CSC 699 - Thesis/repos/brain_connectome/graphIO')
from graphIO import read_ad_adj_data, read_ad_curv_data, analyze_matrices
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import time

In [2]:
# --- Data locations (absolute paths on the author's machine; adjust when running elsewhere) ---
AD_ADJ_DIR = "C:/Users/mosta/OneDrive - UNCG/Academics/CSC 699 - Thesis/data/ad_adjacencies/"
AD_CURV_DIR = "C:/Users/mosta/OneDrive - UNCG/Academics/CSC 699 - Thesis/data/curvatures/"

# --- Reproducibility ---
# Seed torch's global RNG so weight initialisation, DataLoader shuffling and the
# reparameterization noise are repeatable across "Restart & Run All".
SEED = 42
torch.manual_seed(SEED)

# --- Model / training hyperparameters ---
ATLAS = 160  # side length of each square matrix (inputs are reshaped to ATLAS x ATLAS)
HIDDEN_DIM = 96  # width of the fully connected layer between the conv stack and the latent space
LATENT_DIM = 80  # dimensionality of the latent vector
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EPOCHS = 100
LR = 0.001
BATCH_SIZE = 32

In [3]:
# Load the control / patient matrices for both representations.
control_adj_matrices, patient_adj_matrices = read_ad_adj_data(AD_ADJ_DIR)
control_curv_matrices, patient_curv_matrices = read_ad_curv_data(AD_CURV_DIR)

# Print summary statistics for each set, in the order:
# control adjacency, control curvature, patient adjacency, patient curvature.
for matrix_set in (
    control_adj_matrices,
    control_curv_matrices,
    patient_adj_matrices,
    patient_curv_matrices,
):
    analyze_matrices(matrix_set)

Reading adjacency matrices: 100%|██████████| 50/50 [00:00<00:00, 322.52it/s]
Reading adjacency matrices: 100%|██████████| 50/50 [00:00<00:00, 357.90it/s]
Reading adjacency matrices: 100%|██████████| 50/50 [00:00<00:00, 311.54it/s]
Reading adjacency matrices: 100%|██████████| 50/50 [00:00<00:00, 308.31it/s]


Statistics for the entire set of matrices:
Mean: 2.6867397195928787e-17
Standard Deviation: 1.0000000000000002
Maximum Value: 45.895410203419495
Minimum Value: -0.1663877963580864
----------------------------------------
Statistics for the entire set of matrices:
Mean: 1.2079226507921702e-17
Standard Deviation: 1.0
Maximum Value: 5.524465098710106
Minimum Value: -0.4229439417607101
----------------------------------------
Statistics for the entire set of matrices:
Mean: 1.1793899190593038e-16
Standard Deviation: 0.9999999999999999
Maximum Value: 43.91243325526468
Minimum Value: -0.1657311204532757
----------------------------------------
Statistics for the entire set of matrices:
Mean: 9.952039192739903e-17
Standard Deviation: 1.0000000000000004
Maximum Value: 5.716892641685452
Minimum Value: -0.4064465658939247
----------------------------------------


In [4]:
# Convert every matrix set to a float32 tensor of shape (N, 1, ATLAS, ATLAS):
# reshape to square matrices, then insert a channel axis for the conv layers.
tensors = {
    set_name: torch.tensor(
        matrix_set.reshape((-1, ATLAS, ATLAS)), dtype=torch.float32
    ).unsqueeze(1)
    for set_name, matrix_set in (
        ('control_adj', control_adj_matrices),
        ('patient_adj', patient_adj_matrices),
        ('control_curv', control_curv_matrices),
        ('patient_curv', patient_curv_matrices),
    )
}
# Dict insertion order matches the original print order.
print(*(batch.shape for batch in tensors.values()))

torch.Size([50, 1, 160, 160]) torch.Size([50, 1, 160, 160]) torch.Size([50, 1, 160, 160]) torch.Size([50, 1, 160, 160])


In [5]:
# Define the VAE model
class VAE(nn.Module):
    """
    Convolutional variational autoencoder for fixed-size, image-like inputs.

    The encoder applies three 3x3 convolutions (32, 64, 128 channels; padding=1 keeps
    the spatial size), flattens the result and projects it through one hidden layer
    to the mean and log-variance of the latent Gaussian. The decoder mirrors that
    path with transposed convolutions and finishes with a Sigmoid.

    Parameters:
    input_shape (tuple): (channels, height, width) of a single sample. The layers are
        sized from this value (previously it was accepted but ignored).
    hidden_dim (int): Width of the fully connected layers next to the latent space.
    latent_dim (int): Dimensionality of the latent vector.

    Raises:
    ValueError: If input_shape does not have exactly three entries.
    """

    def __init__(self, input_shape=(1, ATLAS, ATLAS), hidden_dim=HIDDEN_DIM, latent_dim=LATENT_DIM):
        super(VAE, self).__init__()
        if len(input_shape) != 3:
            raise ValueError(
                f"input_shape must be (channels, height, width), got {tuple(input_shape)}"
            )
        in_channels, height, width = input_shape

        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3, padding=1),  # output shape: (32, height, width)
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1),  # output shape: (64, height, width)
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, padding=1),  # output shape: (128, height, width)
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Flatten()  # output shape: (128 * height * width)
        )

        # No pooling/striding is used, so the feature map keeps the input's spatial size.
        self.flattened_size = 128 * height * width

        # Encoder head: hidden layer, then separate projections for mu (fc2) and logvar (fc3).
        self.fc1 = nn.Linear(self.flattened_size, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, latent_dim)
        self.fc3 = nn.Linear(hidden_dim, latent_dim)

        # Decoder head: latent -> hidden -> flattened feature map.
        self.fc4 = nn.Linear(latent_dim, hidden_dim)
        self.bn2 = nn.BatchNorm1d(hidden_dim)
        self.fc5 = nn.Linear(hidden_dim, self.flattened_size)

        # NOTE(review): the final Sigmoid bounds reconstructions to (0, 1), while the
        # statistics printed when loading the data show standardized values outside
        # that range (negative minima, maxima well above 1). Confirm this is intended.
        self.decoder = nn.Sequential(
            nn.Unflatten(1, (128, height, width)),
            nn.ConvTranspose2d(128, 64, 3, padding=1),  # output shape: (64, height, width)
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 3, padding=1),  # output shape: (32, height, width)
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, in_channels, 3, padding=1),  # output shape: (in_channels, height, width)
            nn.Sigmoid()
        )

    def encode(self, x):
        """Map a batch of inputs to the latent Gaussian parameters (mu, logvar)."""
        h = self.encoder(x)
        h = torch.relu(self.bn1(self.fc1(h)))
        return self.fc2(h), self.fc3(h)

    def reparameterize(self, mu, logvar):
        """Draw z = mu + eps * std with eps ~ N(0, I), keeping the sample differentiable."""
        std = torch.exp(0.5 * logvar)  # logvar is log(sigma^2), so sigma = exp(logvar / 2)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent vectors back to reconstructions with the same shape as the input."""
        h = torch.relu(self.bn2(self.fc4(z)))
        h = torch.relu(self.fc5(h))
        h = self.decoder(h)
        return h

    def forward(self, x):
        """Return (reconstruction, mu, logvar) for a batch of inputs."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

def loss_function(recon_x, x, mu, logvar):
    """
    VAE objective: summed squared reconstruction error plus KL divergence.

    Parameters:
    recon_x (Tensor): Reconstruction produced by the decoder.
    x (Tensor): Original input, same shape as recon_x.
    mu (Tensor): Latent means.
    logvar (Tensor): Latent log-variances.

    Returns:
    Tensor: Scalar loss (sum over all elements, not averaged).
    """
    reconstruction_term = nn.functional.mse_loss(recon_x, x, reduction='sum')
    # Closed-form KL(N(mu, sigma^2) || N(0, I)), summed over batch and latent dims.
    kl_summand = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_term = -0.5 * torch.sum(kl_summand)
    return reconstruction_term + kl_term

def calculate_mse(dataloader, model, device=None):
    """
    Compute the per-sample reconstruction error of a model over a dataloader.

    Parameters:
    dataloader (DataLoader): Yields (data, target) pairs; only `data` is used.
    model (nn.Module): Model whose forward pass returns (reconstruction, mu, logvar).
    device (torch.device or str, optional): Device to move each batch to. When omitted,
        the device of the model's first parameter is used, so the batches always
        land where the model lives; a model without parameters leaves the
        batches where they are.

    Returns:
    float: Sum of squared errors over the whole dataset divided by the number of
        samples in `dataloader.dataset`.

    Note: the model is switched to eval mode and left in that mode.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        if first_param is not None:
            device = first_param.device

    model.eval()
    mse_loss = 0
    with torch.no_grad():
        for data, _ in dataloader:
            if device is not None:
                data = data.to(device)
            recon, _, _ = model(data)
            mse_loss += nn.functional.mse_loss(recon, data, reduction='sum').item()
    return mse_loss / len(dataloader.dataset)

# Function to get the latent space representation
def get_latent_space(model, data_vector, use_mean=True):
    """
    Encode a single ATLAS x ATLAS matrix and return its latent representation.

    Parameters:
    model (nn.Module): VAE exposing `encode` and `reparameterize`.
    data_vector: Array with a `reshape` method holding ATLAS * ATLAS values.
    use_mean (bool): If True, return the latent mean; otherwise return a sample
        drawn with the reparameterization trick.

    Returns:
    numpy.ndarray: Latent representation for the single input (the batch axis of
        size 1 is kept).

    Note: the model is switched to eval mode and left in that mode.
    """
    # Follow the model's own placement rather than a notebook-level global, so the
    # input always lands on the same device as the weights.
    first_param = next(model.parameters(), None)

    model.eval()  # Set the model to evaluation mode
    with torch.no_grad():  # No need to compute gradients
        # Shape the input as a batch of one single-channel matrix.
        data_tensor = torch.tensor(data_vector.reshape(1, 1, ATLAS, ATLAS), dtype=torch.float32)
        if first_param is not None:
            data_tensor = data_tensor.to(first_param.device)
        # Pass through the encoder to get mu and logvar
        mu, logvar = model.encode(data_tensor)
        if use_mean:
            return mu.cpu().numpy()  # Return the mean as the latent representation
        else:
            # Sample from the distribution using reparameterization trick
            z = model.reparameterize(mu, logvar)
            return z.cpu().numpy()  # Return the sampled latent representation

In [6]:
def train_vae(model, dataloader, epochs=100, learning_rate=1e-3, device='cuda'):
    """
    Fit a VAE with Adam, minimising `loss_function`, and log progress to stdout.

    Parameters:
    model (nn.Module): The VAE to train; it is moved to `device` and put in train mode.
    dataloader (DataLoader): Yields (data, target) pairs; only `data` is used.
    epochs (int): Number of passes over the dataloader.
    learning_rate (float): Learning rate handed to Adam.
    device (str or torch.device): Device used for the model and every batch.

    Returns:
    None

    Per batch the raw loss is printed; per epoch the loss summed over batches and
    divided by the number of samples in `dataloader.dataset` is printed together
    with the epoch duration and a remaining-time estimate extrapolated from that
    duration.
    """
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    model.train()

    for epoch_number in range(1, epochs + 1):
        epoch_start = time.time()
        running_loss = 0

        for batch_number, (batch, _) in enumerate(dataloader, start=1):
            batch = batch.to(device)

            optimizer.zero_grad()
            reconstruction, mu, logvar = model(batch)
            batch_loss = loss_function(reconstruction, batch, mu, logvar)
            batch_loss.backward()
            running_loss += batch_loss.item()
            optimizer.step()

            # Per-batch progress line
            print(f"Epoch {epoch_number} [{batch_number}/{len(dataloader)}], Batch Loss: {batch_loss.item():.4f}")

        epoch_time = time.time() - epoch_start

        # Per-epoch summary: loss normalised by dataset size, plus wall-clock time
        avg_loss = running_loss / len(dataloader.dataset)
        print(f"Epoch {epoch_number}, Average Loss: {avg_loss:.4f}, Time: {epoch_time:.2f}s")

        # Assume the remaining epochs take as long as the one just finished
        epochs_left = epochs - epoch_number
        print(f"Estimated remaining time: {epoch_time * epochs_left / 60:.2f} minutes")

In [7]:
# Stack control and patient curvature tensors along the sample axis.
control_patient_curv_tensor = torch.cat(
    [tensors['control_curv'], tensors['patient_curv']], dim=0
)
# The input doubles as the target slot of the dataset.
control_patient_curv_dataset = TensorDataset(
    control_patient_curv_tensor, control_patient_curv_tensor
)
control_patient_curv_dataloader = DataLoader(
    dataset=control_patient_curv_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

In [8]:
# Instantiate a fresh VAE for the combined control + patient curvature data
control_patient_curv_model = VAE(
    input_shape=(1, ATLAS, ATLAS),
    hidden_dim=HIDDEN_DIM,
    latent_dim=LATENT_DIM,
)

# Train it on the curvature dataloader
train_vae(
    model=control_patient_curv_model,
    dataloader=control_patient_curv_dataloader,
    epochs=EPOCHS,
    learning_rate=LR,
    device=DEVICE,
)

# Persist the whole model object (architecture + weights), not just the state dict
torch.save(control_patient_curv_model, 'control_patient_curv_model.pth')

Epoch 1 [1/4], Batch Loss: 1278663.2500
Epoch 1 [2/4], Batch Loss: 1060325.2500
Epoch 1 [3/4], Batch Loss: 928902.3750
Epoch 1 [4/4], Batch Loss: 117643.0625
Epoch 1, Average Loss: 33855.3394, Time: 41.35s
Estimated remaining time: 68.24 minutes
Epoch 2 [1/4], Batch Loss: 786527.1250
Epoch 2 [2/4], Batch Loss: 760852.0625
Epoch 2 [3/4], Batch Loss: 698770.6250
Epoch 2 [4/4], Batch Loss: 88073.3203
Epoch 2, Average Loss: 23342.2313, Time: 38.42s
Estimated remaining time: 62.76 minutes
Epoch 3 [1/4], Batch Loss: 671608.8750
Epoch 3 [2/4], Batch Loss: 592227.6250
Epoch 3 [3/4], Batch Loss: 616674.1250
Epoch 3 [4/4], Batch Loss: 78631.7188
Epoch 3, Average Loss: 19591.4234, Time: 37.99s
Estimated remaining time: 61.42 minutes
Epoch 4 [1/4], Batch Loss: 555930.2500
Epoch 4 [2/4], Batch Loss: 581783.4375
Epoch 4 [3/4], Batch Loss: 589822.4375
Epoch 4 [4/4], Batch Loss: 72479.2969
Epoch 4, Average Loss: 18000.1542, Time: 39.78s
Estimated remaining time: 63.65 minutes
Epoch 5 [1/4], Batch Loss

In [9]:
# control_patient_adj_tensor = torch.concatenate((tensors['control_adj'], tensors['patient_adj']))
# control_patient_adj_dataset = TensorDataset(control_patient_adj_tensor, control_patient_adj_tensor)
# control_patient_adj_dataloader = DataLoader(control_patient_adj_dataset, batch_size=8, shuffle=True)

In [10]:
# # Define the VAE model
# control_patient_adj_model = VAE(input_shape=(1, ATLAS, ATLAS), hidden_dim=HIDDEN_DIM, latent_dim=LATENT_DIM)

# # Train the VAE model
# train_vae(control_patient_adj_model, control_patient_adj_dataloader, epochs=EPOCHS, learning_rate=LR, device=DEVICE)