# Using a VAE to expand the training data

In [153]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import sklearn
from sklearn. model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score
from sklearn.feature_selection import SelectKBest, f_classif

import matplotlib.pyplot as plt # For data viz
import pandas as pd
import numpy as np
import sys

print('System Version:', sys.version)
print('PyTorch version', torch.__version__)
print('CUDA version', torch.version.cuda)
print('Numpy version', np.__version__)
print('Pandas version', pd.__version__)
print('Sklearn version', sklearn.__version__)

System Version: 3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:45:18) [GCC 12.3.0]
PyTorch version 2.5.1
CUDA version 12.4
Numpy version 1.26.4
Pandas version 2.2.3
Sklearn version 1.6.0


In [154]:
# Confirm device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device} device")
try: 
    name = torch.cuda.get_device_name(0)
    count = torch.cuda.device_count()
    print(f"Device count: {count}")
    print(f"Device name: {name}")
except RuntimeError:
    print('No GPUs detected')

Using cuda device
Device count: 1
Device name: NVIDIA L40S


In [155]:
# Load the data
X = np.load('data/X_TRAIN_RAW.npy')
y = np.load('data/y_TRAIN_RAW.npy')

In [156]:
# Define the dataset
class COPEDataset(Dataset):
    def __init__(self, data, target):
        self.data = data
        self.target = target

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        cope_data = self.data[index]
        cope_data = (cope_data - np.min(cope_data)) / (np.max(cope_data) - np.min(cope_data))  # normalize

        label = self.target[index]
        volume = torch.tensor(cope_data, dtype=torch.float32).unsqueeze(0)  # (1, 91, 109, 91)
        label = torch.tensor([1.0, 0.0] if label == 0 else [0.0, 1.0], dtype=torch.float32)
        return volume, label

In [157]:
# Initiate the dataset and data loader
dataset = COPEDataset(X, y)
dataloader = DataLoader(dataset, batch_size=4)

Creating an autoencoder to generate more data

In [None]:
# Define the model
class CVAE3D(nn.Module):
    def __init__(self, latent_dim=128, label_dim=2, label_embed_dim=16):
        super(CVAE3D, self).__init__()
        self.latent_dim = latent_dim
        self.label_dim = label_dim

        # Embed the label into a dense vector
        self.label_embedding = nn.Linear(label_dim, label_embed_dim)

        # Encoder
        self.encoder_conv = nn.Sequential(
            nn.Conv3d(1, 16, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm3d(16),
            nn.ReLU(),

        )

        # Calculate the size of the flattened layer dynamically
        # We run a dummy tensor through the conv/pool layers to find the shape
        self.flatten_dim = None
        self.min_dim = None
        self._get_conv_output((1, 23, 19, 15)) # Dummy run to calculate self._to_linear

        print(self.flatten_dim)
        print(self.min_dim)

        self.encoder_fc_mu = nn.Linear(self.flatten_dim + label_embed_dim, latent_dim)
        self.encoder_fc_logvar = nn.Linear(self.flatten_dim + label_embed_dim, latent_dim)

        # Decoder
        self.decoder_fc = nn.Linear(latent_dim + label_embed_dim, self.flatten_dim)

        self.decoder_conv = nn.Sequential(
            nn.ConvTranspose3d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()
        )

    def _get_conv_output(self, shape):
        """Helper function to calculate the input size for the fully connected layers."""
        # Create a dummy input tensor with batch size 1
        dummy_input = torch.rand(1, *shape) 
        output_features = self.encoder_conv(dummy_input)
        # Store the calculated size
        self.flatten_dim = output_features.view(output_features.size(0), -1).size(1)
        self.min_dim = output_features.shape

    
    def encode(self, x, labels):
        x = self.encoder_conv(x)
        x = x.view(x.size(0), -1)
        labels = self.label_embedding(labels)
        x = torch.cat([x, labels], dim=1)
        mu = self.encoder_fc_mu(x)
        logvar = self.encoder_fc_logvar(x)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, labels):
        labels = self.label_embedding(labels)
        z = torch.cat([z, labels], dim=1)
        x = self.decoder_fc(z)
        x = x.view(-1, 16, 5, 4, 4)
        x = self.decoder_conv(x)
        # Crop to original shape
        x = x[:, :, :23, :19, :15]
        return x

    def forward(self, x, labels):
        mu, logvar = self.encode(x, labels)
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z, labels)
        return recon, mu, logvar


In [159]:
# Define the loss
def vae_loss(recon_x, x, mu, logvar):
    recon_loss = F.mse_loss(recon_x, x, reduction='sum')
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kld

In [160]:
# Train the model
model = CVAE3D(latent_dim=128).cuda()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

for epoch in range(30):
    model.train()
    total_loss = 0
    for batch in dataloader:
        volumes, labels = batch
        volumes, labels = volumes.cuda(), labels.cuda()

        optimizer.zero_grad()
        recon, mu, logvar = model(volumes, labels)
        loss = vae_loss(recon, volumes, mu, logvar)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {total_loss / len(dataloader)}")


1280
torch.Size([1, 16, 5, 4, 4])
Epoch 1, Loss: 145.67093658447266
Epoch 2, Loss: 123.98710038926866
Epoch 3, Loss: 121.16567484537761
Epoch 4, Loss: 118.48729960123698
Epoch 5, Loss: 114.86731211344402
Epoch 6, Loss: 113.19648996988933
Epoch 7, Loss: 112.18529256184895
Epoch 8, Loss: 110.75551859537761
Epoch 9, Loss: 110.0108405219184
Epoch 10, Loss: 108.85653898451064
Epoch 11, Loss: 108.55058966742621
Epoch 12, Loss: 108.54929818047418
Epoch 13, Loss: 107.89935514662001
Epoch 14, Loss: 107.4641973707411
Epoch 15, Loss: 106.14699215359158
Epoch 16, Loss: 105.61040581597223
Epoch 17, Loss: 105.64522043863933
Epoch 18, Loss: 104.81863191392686
Epoch 19, Loss: 103.63469314575195
Epoch 20, Loss: 103.8680771721734
Epoch 21, Loss: 103.40898513793945
Epoch 22, Loss: 102.73618189493816
Epoch 23, Loss: 101.91496870252821
Epoch 24, Loss: 100.90865707397461
Epoch 25, Loss: 101.32738749186198
Epoch 26, Loss: 101.50817447238498
Epoch 27, Loss: 101.07666820949979
Epoch 28, Loss: 100.3480389912923

In [161]:
# Generate synthetic data
n = 500
NEW_X = torch.empty((n*2,10,7,8), device=device)
NEW_y = y
model.eval()

# Generating 40 * 3 new samples (120 samples: 60 of each group)
# Group 0
print('Group 0:')
for i in range(n):
    with torch.no_grad():
        z = torch.randn(1, 128).cuda()

        group_label = torch.tensor([[1.0, 0.0]]).cuda()  # one-hot
        synthetic = model.decode(z, group_label)  # (1, 1, 91, 109, 91)
        synthetic = torch.squeeze(synthetic)
    
        NEW_X[i] = synthetic #, synthetic, axis=0)
        NEW_y = np.append(NEW_y, [0], axis=0)
        print(f'Generated sample {i+1}/{n}')

# Group 1
print('Group 1:')
for i in range(n):
    with torch.no_grad():
        z = torch.randn(1, 128).cuda()

        group_label = torch.tensor([[0.0, 1.0]]).cuda()  # one-hot
        synthetic = model.decode(z, group_label)  # (1, 1, 91, 109, 91)
        synthetic = torch.squeeze(synthetic)
    
    NEW_X[i+n] = synthetic #NEW_X = np.append(NEW_X, synthetic, axis=0)
    NEW_y = np.append(NEW_y, [1], axis=0)
    print(f'Generated sample {i+1}/{n}')

Syn_X = NEW_X.cpu().numpy()
NEW_X = np.append(X, Syn_X, axis=0)
        

Group 0:
Generated sample 1/500
Generated sample 2/500
Generated sample 3/500
Generated sample 4/500
Generated sample 5/500
Generated sample 6/500
Generated sample 7/500
Generated sample 8/500
Generated sample 9/500
Generated sample 10/500
Generated sample 11/500
Generated sample 12/500
Generated sample 13/500
Generated sample 14/500
Generated sample 15/500
Generated sample 16/500
Generated sample 17/500
Generated sample 18/500
Generated sample 19/500
Generated sample 20/500
Generated sample 21/500
Generated sample 22/500
Generated sample 23/500
Generated sample 24/500
Generated sample 25/500
Generated sample 26/500
Generated sample 27/500
Generated sample 28/500
Generated sample 29/500
Generated sample 30/500
Generated sample 31/500
Generated sample 32/500
Generated sample 33/500
Generated sample 34/500
Generated sample 35/500
Generated sample 36/500
Generated sample 37/500
Generated sample 38/500
Generated sample 39/500
Generated sample 40/500
Generated sample 41/500
Generated sample

Generated sample 134/500
Generated sample 135/500
Generated sample 136/500
Generated sample 137/500
Generated sample 138/500
Generated sample 139/500
Generated sample 140/500
Generated sample 141/500
Generated sample 142/500
Generated sample 143/500
Generated sample 144/500
Generated sample 145/500
Generated sample 146/500
Generated sample 147/500
Generated sample 148/500
Generated sample 149/500
Generated sample 150/500
Generated sample 151/500
Generated sample 152/500
Generated sample 153/500
Generated sample 154/500
Generated sample 155/500
Generated sample 156/500
Generated sample 157/500
Generated sample 158/500
Generated sample 159/500
Generated sample 160/500
Generated sample 161/500
Generated sample 162/500
Generated sample 163/500
Generated sample 164/500
Generated sample 165/500
Generated sample 166/500
Generated sample 167/500
Generated sample 168/500
Generated sample 169/500
Generated sample 170/500
Generated sample 171/500
Generated sample 172/500
Generated sample 173/500


In [162]:
# Confirm additions
print(NEW_X.shape)
print(NEW_y.shape)

(1035, 10, 7, 8)
(1035,)


In [163]:
# Save the new data
np.save('data/X_SYN.npy', NEW_X)
np.save('data/y_SYN.npy', NEW_y)