# Using a VAE to expand the training data

In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt # For data viz
import pandas as pd
import numpy as np
import sys

print('System Version:', sys.version)
print('PyTorch version', torch.__version__)
print('CUDA version', torch.version.cuda)
print('Numpy version', np.__version__)
print('Pandas version', pd.__version__)

System Version: 3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:45:18) [GCC 12.3.0]
PyTorch version 2.5.1
CUDA version 12.4
Numpy version 1.26.4
Pandas version 2.2.3


In [17]:
# Confirm device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device} device")
try: 
    name = torch.cuda.get_device_name(0)
    count = torch.cuda.device_count()
    print(f"Device count: {count}")
    print(f"Device name: {name}")
except RuntimeError:
    print('No GPUs detected')

Using cuda device
Device count: 1
Device name: NVIDIA L40S


In [18]:
# Load the data
X = np.load('data/X_TRAIN_RAW.npy')
y = np.load('data/y_TRAIN_RAW.npy')

In [19]:
# Define the dataset
class COPEDataset(Dataset):
    def __init__(self, data, target):
        self.data = data
        self.target = target

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        # Get the data and label
        cope_data = self.data[index]
        label = self.target[index]

        # Normalize the data thorugh per-subject z-score normalization
        cope_data = (cope_data - np.mean(cope_data)) / np.std(cope_data)

        # Convert to tensors
        volume = torch.tensor(cope_data, dtype=torch.float32).unsqueeze(0)  # (1, 91, 109, 91) expects # of chanennels first
        label = torch.tensor([1.0, 0.0] if label == 0 else [0.0, 1.0], dtype=torch.float32)

        return volume, label

In [20]:
# Initiate the dataset and data loader
dataset = COPEDataset(X, y)
dataloader = DataLoader(dataset, batch_size=1)

Creating an autoencoder to generate more data

In [None]:
# Define the model
class CVAE3D(nn.Module):
    def __init__(self, latent_dim=128, label_dim=2, label_embed_dim=16):
        super(CVAE3D, self).__init__()
        self.latent_dim = latent_dim
        self.label_dim = label_dim
        print()
        print("[INFO] instantiating pytorch model: CVAE 3D")

        # Embed the label into a dense vector
        self.label_embedding = nn.Linear(label_dim, label_embed_dim)

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv3d(1, 16, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm3d(16),
            nn.ReLU(),

            nn.Conv3d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm3d(32),
            nn.ReLU(),

            nn.Conv3d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm3d(64),
            nn.ReLU(),

            nn.Conv3d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm3d(128),
            nn.ReLU(),
        )

        # Calculate the size of the flattened layer dynamically
        # Run a dummy tensor through the conv/pool layers to find the shape
        self.flatten_dim = None
        dummy_tesnor_shape = (1, 91, 109, 91) # Input shape (C, D, H, W))
        self._get_encoder_output(dummy_tesnor_shape)

        self.encoder_fc_mu = nn.Linear(self.flatten_dim + label_embed_dim, latent_dim)
        self.encoder_fc_logvar = nn.Linear(self.flatten_dim + label_embed_dim, latent_dim)

        # Decoder
        self.decoder_fc = nn.Linear(latent_dim + label_embed_dim, self.flatten_dim)

        self.decoder = nn.Sequential(
            nn.ConvTranspose3d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm3d(64),
            nn.ReLU(),

            nn.ConvTranspose3d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm3d(32),
            nn.ReLU(),

            nn.ConvTranspose3d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm3d(16),
            nn.ReLU(),

            nn.ConvTranspose3d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()
        )

    def _get_encoder_output(self, shape):
        """
        Helper function to calculate the input size for the fully connected layers.
        
        Args:
            shape (tuple): The shape of the input tensor (C, D, H, W).
        Returns:
            None: Sets the self._to_linear attribute.
        """
        # Create a dummy input tensor with batch size 1
        dummy_input = torch.rand(1, *shape) 
        output_features = self.encoder(dummy_input)
        # Store the calculated size
        self.flatten_dim = output_features.view(output_features.size(0), -1).size(1)
        print(f"Calculated encoder feature size: {self.flatten_dim}")

    def encode(self, x, labels):
        x = self.encoder(x)
        x = x.view(x.size(0), -1)
        labels = self.label_embedding(labels)
        x = torch.cat([x, labels], dim=1)
        mu = self.encoder_fc_mu(x)
        logvar = self.encoder_fc_logvar(x)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, labels):
        labels = self.label_embedding(labels)
        z = torch.cat([z, labels], dim=1)
        x = self.decoder_fc(z)
        x = x.view(-1, 128, 6, 7, 6)
        x = self.decoder(x)
        print(x.shape)
        # Crop to original shape
        x = x[:, :, :91, :109, :91]
        return x

    def forward(self, x, labels):
        mu, logvar = self.encode(x, labels)
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z, labels)
        return recon, mu, logvar


In [22]:
# Define the loss
def vae_loss(recon_x, x, mu, logvar):
    recon_loss = F.mse_loss(recon_x, x, reduction='sum')
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kld

In [23]:
# Train the model
model = CVAE3D(latent_dim=128).cuda()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

for epoch in range(30):
    model.train()
    total_loss = 0
    for batch in dataloader:
        volumes, labels = batch
        volumes, labels = volumes.cuda(), labels.cuda()

        optimizer.zero_grad()
        recon, mu, logvar = model(volumes, labels)
        loss = vae_loss(recon, volumes, mu, logvar)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {total_loss / len(dataloader)}")



[INFO] instantiating pytorch model: CVAE 3D
Calculated encoder feature size: 32256
torch.Size([1, 1, 96, 112, 96])


  recon_loss = F.mse_loss(recon_x, x, reduction='sum')


RuntimeError: The size of tensor a (96) must match the size of tensor b (91) at non-singleton dimension 4

In [None]:
# Generate synthetic data
n = 200
NEW_X = torch.empty((n*2,91,109,91), device=device)
NEW_y = y
model.eval()

# Generating 40 * 3 new samples (120 samples: 60 of each group)
# Group 0
print('Group 0:')
for i in range(n):
    with torch.no_grad():
        z = torch.randn(1, 128).cuda()

        group_label = torch.tensor([[1.0, 0.0]]).cuda()  # one-hot
        synthetic = model.decode(z, group_label)  # (1, 1, 91, 109, 91)
        synthetic = torch.squeeze(synthetic)
    
        NEW_X[i] = synthetic #, synthetic, axis=0)
        NEW_y = np.append(NEW_y, [0], axis=0)
        print(f'Generated sample {i+1}/200')

# Group 1
print('Group 1:')
for i in range(n):
    with torch.no_grad():
        z = torch.randn(1, 128).cuda()

        group_label = torch.tensor([[0.0, 1.0]]).cuda()  # one-hot
        synthetic = model.decode(z, group_label)  # (1, 1, 91, 109, 91)
        synthetic = torch.squeeze(synthetic)
    
    NEW_X[i+n] = synthetic #NEW_X = np.append(NEW_X, synthetic, axis=0)
    NEW_y = np.append(NEW_y, [1], axis=0)
    print(f'Generated sample {i+1}/200')

Syn_X = NEW_X.cpu().numpy()
NEW_X = np.append(X, Syn_X, axis=0)
        

Group 0:
Generated sample 1/200
Generated sample 2/200
Generated sample 3/200
Generated sample 4/200
Generated sample 5/200
Generated sample 6/200
Generated sample 7/200
Generated sample 8/200
Generated sample 9/200
Generated sample 10/200
Generated sample 11/200
Generated sample 12/200
Generated sample 13/200
Generated sample 14/200
Generated sample 15/200
Generated sample 16/200
Generated sample 17/200
Generated sample 18/200
Generated sample 19/200
Generated sample 20/200
Generated sample 21/200
Generated sample 22/200
Generated sample 23/200
Generated sample 24/200
Generated sample 25/200
Generated sample 26/200
Generated sample 27/200
Generated sample 28/200
Generated sample 29/200
Generated sample 30/200
Generated sample 31/200
Generated sample 32/200
Generated sample 33/200
Generated sample 34/200
Generated sample 35/200
Generated sample 36/200
Generated sample 37/200
Generated sample 38/200
Generated sample 39/200
Generated sample 40/200
Generated sample 41/200
Generated sample

In [None]:
# Confirm additions
print(NEW_X.shape)
print(NEW_y.shape)

(409, 91, 109, 91)
(409,)


In [None]:
# Save the new data
np.save('data/X_SYN.npy', NEW_X)
np.save('data/y_SYN.npy', NEW_y)