# Using a VAE to expand the training data

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import sklearn
from sklearn. model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score
from sklearn.feature_selection import SelectKBest, f_classif

import matplotlib.pyplot as plt # For data viz
import pandas as pd
import numpy as np
import sys

# Report interpreter and library versions so results can be tied to an environment
for label, version in (
    ('System Version:', sys.version),
    ('PyTorch version', torch.__version__),
    ('Numpy version', np.__version__),
    ('Pandas version', pd.__version__),
    ('Sklearn version', sklearn.__version__),
):
    print(label, version)

System Version: 3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:45:18) [GCC 12.3.0]
PyTorch version 2.5.1
Numpy version 1.26.4
Pandas version 2.2.3
Sklearn version 1.6.0


In [2]:
# Confirm device setup: prefer the GPU whenever PyTorch can see one, else use the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using {device} device")

Using cuda device


In [3]:
# Load the data: training volumes (X) and their class labels (y)
X = np.load('data/X_TRAIN_RAW.npy')
y = np.load('data/y_TRAIN_RAW.npy')
# Every volume needs exactly one label; fail fast if the two files are out of sync
assert len(X) == len(y), f"X has {len(X)} samples but y has {len(y)} labels"

In [4]:
# Define the dataset
class COPEDataset(Dataset):
    """Map-style dataset pairing each volume with a one-hot binary label.

    Args:
        data: indexable collection of volumes; item ``i`` is min-max scaled to
            [0, 1] each time it is accessed.
        target: per-sample labels; 0 becomes ``[1, 0]`` and any other value
            becomes ``[0, 1]``.

    Each item is ``(volume, label)``, where ``volume`` is a float32 tensor with a
    leading channel axis added and ``label`` is a float32 one-hot tensor.
    """

    def __init__(self, data, target):
        self.data = data
        self.target = target

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        cope_data = self.data[index]

        # Per-volume min-max normalization to [0, 1]. A constant volume has zero
        # range; map it to all zeros instead of dividing 0 by 0 (which gives NaN).
        vmin = np.min(cope_data)
        vrange = np.max(cope_data) - vmin
        if vrange == 0:
            cope_data = np.zeros(np.shape(cope_data))
        else:
            cope_data = (cope_data - vmin) / vrange

        label = self.target[index]
        volume = torch.tensor(cope_data, dtype=torch.float32).unsqueeze(0)  # add channel axis, e.g. (1, 91, 109, 91)
        label = torch.tensor([1.0, 0.0] if label == 0 else [0.0, 1.0], dtype=torch.float32)
        return volume, label

In [5]:
# Initiate the dataset and data loader.
# shuffle=True so each training epoch visits the volumes in a fresh random order
# rather than always presenting the same fixed batches.
dataset = COPEDataset(X, y)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

Creating a conditional variational autoencoder (CVAE) to generate more labeled training data

In [6]:
# Define the model
class CVAE3D(nn.Module):
    """Conditional 3D variational autoencoder for single-channel volumes.

    Encoder: four Conv3d(kernel=3, stride=2, padding=1) + BatchNorm3d + ReLU
    stages (1 -> 16 -> 32 -> 64 -> 128 channels). The flattened feature map is
    concatenated with an embedded label and projected to the mean and
    log-variance of a ``latent_dim``-dimensional Gaussian posterior.

    Decoder: the latent vector, concatenated with the same label embedding, is
    projected back to the encoder's feature-map size. Four ConvTranspose3d stages
    then upsample it, each doubling every spatial dimension. A final Sigmoid keeps
    outputs in [0, 1], and the result is cropped back to ``input_shape``.

    Args:
        latent_dim: size of the latent vector z.
        label_dim: length of the label vector passed to ``encode``/``decode``
            (one-hot in this notebook).
        label_embed_dim: size of the learned dense label embedding.
        input_shape: spatial (D, H, W) size of the input volumes. Defaults to
            (91, 109, 91), the volume size this notebook was written for.
    """

    def __init__(self, latent_dim=128, label_dim=2, label_embed_dim=16,
                 input_shape=(91, 109, 91)):
        super(CVAE3D, self).__init__()
        self.latent_dim = latent_dim
        self.label_dim = label_dim
        self.input_shape = tuple(input_shape)

        # Embed the label into a dense vector (shared by encoder and decoder)
        self.label_embedding = nn.Linear(label_dim, label_embed_dim)

        # Spatial size after the encoder. Each stride-2 conv maps n -> (n - 1) // 2 + 1,
        # e.g. (91, 109, 91) -> (46, 55, 46) -> (23, 28, 23) -> (12, 14, 12) -> (6, 7, 6).
        encoded_shape = self.input_shape
        for _ in range(4):
            encoded_shape = tuple((n - 1) // 2 + 1 for n in encoded_shape)
        self.encoded_shape = encoded_shape

        # Encoder
        self.encoder_conv = nn.Sequential(
            nn.Conv3d(1, 16, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm3d(16),
            nn.ReLU(),

            nn.Conv3d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm3d(32),
            nn.ReLU(),

            nn.Conv3d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm3d(64),
            nn.ReLU(),

            nn.Conv3d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm3d(128),
            nn.ReLU()
        )

        self.flatten_dim = 128 * encoded_shape[0] * encoded_shape[1] * encoded_shape[2]
        self.encoder_fc_mu = nn.Linear(self.flatten_dim + label_embed_dim, latent_dim)
        self.encoder_fc_logvar = nn.Linear(self.flatten_dim + label_embed_dim, latent_dim)

        # Decoder
        self.decoder_fc = nn.Linear(latent_dim + label_embed_dim, self.flatten_dim)

        # Each ConvTranspose3d(kernel=3, stride=2, padding=1, output_padding=1) maps n -> 2n
        self.decoder_conv = nn.Sequential(
            nn.ConvTranspose3d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm3d(64),
            nn.ReLU(),

            nn.ConvTranspose3d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm3d(32),
            nn.ReLU(),

            nn.ConvTranspose3d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm3d(16),
            nn.ReLU(),

            nn.ConvTranspose3d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()
        )

    def encode(self, x, labels):
        """Return (mu, logvar), each of shape (B, latent_dim), for volumes x and labels."""
        x = self.encoder_conv(x)
        x = x.view(x.size(0), -1)
        labels = self.label_embedding(labels)
        x = torch.cat([x, labels], dim=1)
        mu = self.encoder_fc_mu(x)
        logvar = self.encoder_fc_logvar(x)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample z = mu + eps * sigma with eps ~ N(0, I), keeping gradients w.r.t. mu/logvar."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, labels):
        """Map latent vectors plus labels back to volumes of shape (B, 1, *input_shape)."""
        labels = self.label_embedding(labels)
        z = torch.cat([z, labels], dim=1)
        x = self.decoder_fc(z)
        x = x.view(-1, 128, *self.encoded_shape)
        x = self.decoder_conv(x)
        # The decoder yields 16x the encoded size, which is >= input_shape; crop back
        depth, height, width = self.input_shape
        x = x[:, :, :depth, :height, :width]
        return x

    def forward(self, x, labels):
        """Encode, sample z, and decode; returns (reconstruction, mu, logvar)."""
        mu, logvar = self.encode(x, labels)
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z, labels)
        return recon, mu, logvar


In [7]:
# Define the loss
def vae_loss(recon_x, x, mu, logvar, beta=1.0):
    """Negative ELBO for a VAE with a standard-normal prior, summed over the batch.

    Args:
        recon_x: decoder output, same shape as ``x``.
        x: target volumes.
        mu, logvar: posterior mean and log-variance from the encoder.
        beta: weight on the KL term (beta-VAE). 1.0 gives the standard VAE loss.

    Returns:
        Scalar tensor: summed squared reconstruction error + beta * KL divergence.
    """
    recon_loss = F.mse_loss(recon_x, x, reduction='sum')
    # Closed-form KL(N(mu, exp(logvar)) || N(0, I)), summed over batch and latent dims
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + beta * kld

In [8]:
# Train the model.
# Move the model and batches to the `device` chosen in the setup cell instead of
# hard-coding `.cuda()`, so the notebook also runs on CPU-only machines.
N_EPOCHS = 30

model = CVAE3D(latent_dim=128).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

for epoch in range(N_EPOCHS):
    model.train()
    total_loss = 0.0
    for volumes, labels in dataloader:
        volumes, labels = volumes.to(device), labels.to(device)

        optimizer.zero_grad()
        recon, mu, logvar = model(volumes, labels)
        loss = vae_loss(recon, volumes, mu, logvar)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    # vae_loss sums over the batch, so this is the mean per-batch (summed) loss
    print(f"Epoch {epoch+1}, Loss: {total_loss / len(dataloader)}")


Epoch 1, Loss: 133936.89765625
Epoch 2, Loss: 120481.63359375
Epoch 3, Loss: 109592.3484375
Epoch 4, Loss: 102097.8734375
Epoch 5, Loss: 95371.459375
Epoch 6, Loss: 89257.2359375
Epoch 7, Loss: 84095.89453125
Epoch 8, Loss: 78951.27265625
Epoch 9, Loss: 74174.776953125
Epoch 10, Loss: 69750.843359375
Epoch 11, Loss: 65801.34453125
Epoch 12, Loss: 61995.145703125
Epoch 13, Loss: 58861.404296875
Epoch 14, Loss: 56309.698046875
Epoch 15, Loss: 55212.195703125
Epoch 16, Loss: 52680.5125
Epoch 17, Loss: 49727.79765625
Epoch 18, Loss: 47429.1859375
Epoch 19, Loss: 45393.740625
Epoch 20, Loss: 43628.162890625
Epoch 21, Loss: 42137.19453125
Epoch 22, Loss: 40661.204296875
Epoch 23, Loss: 39679.85703125
Epoch 24, Loss: 38884.97578125
Epoch 25, Loss: 37911.194921875
Epoch 26, Loss: 36609.539453125
Epoch 27, Loss: 35695.41796875
Epoch 28, Loss: 34560.112890625
Epoch 29, Loss: 33563.124609375
Epoch 30, Loss: 33029.05078125


In [9]:
# Generate synthetic data.
# Latent vectors are drawn from the N(0, I) prior and decoded conditioned on a
# one-hot class label. Volumes are decoded in mini-batches, and everything is
# concatenated once at the end. Calling np.append inside a loop copies the whole
# growing array on every call, which is quadratic for hundreds of 3D volumes.
N_SYNTHETIC_PER_CLASS = 200  # 200 per group -> 400 new samples in total
GEN_BATCH_SIZE = 20          # volumes decoded per forward pass
SYNTHETIC_SEED = 42          # makes the synthetic set reproducible across re-runs


def generate_synthetic(model, class_index, n_samples, batch_size=GEN_BATCH_SIZE):
    """Decode ``n_samples`` prior samples conditioned on a single class.

    Args:
        model: the CVAE3D; its ``latent_dim``, ``label_dim`` and ``decode`` are used.
        class_index: position of the 1 in the one-hot conditioning label.
        n_samples: number of volumes to generate; 0 yields an empty list.
        batch_size: maximum number of volumes decoded per forward pass.

    Returns:
        A list of numpy arrays, each of shape (<= batch_size, D, H, W). The
        decoder's channel axis is squeezed out.
    """
    model_device = next(model.parameters()).device
    label = torch.zeros(1, model.label_dim, device=model_device)
    label[0, class_index] = 1.0

    batches = []
    with torch.no_grad():
        for start in range(0, n_samples, batch_size):
            count = min(batch_size, n_samples - start)
            z = torch.randn(count, model.latent_dim, device=model_device)
            decoded = model.decode(z, label.expand(count, -1))  # (count, 1, D, H, W)
            batches.append(decoded.squeeze(1).cpu().numpy())
    return batches


model.eval()
torch.manual_seed(SYNTHETIC_SEED)

synthetic_X, synthetic_y = [], []
for class_index in (0, 1):  # group 0 first, then group 1
    synthetic_X.extend(generate_synthetic(model, class_index, N_SYNTHETIC_PER_CLASS))
    synthetic_y.append(np.full(N_SYNTHETIC_PER_CLASS, class_index))

# NOTE(review): synthetic volumes are Sigmoid decoder outputs in [0, 1], whereas X
# is loaded from X_TRAIN_RAW.npy and only normalized per volume inside COPEDataset.
# This matters only if X_SYN.npy is consumed without that normalization — confirm.
NEW_X = np.concatenate([X, *synthetic_X], axis=0)
NEW_y = np.concatenate([y, *synthetic_y], axis=0)

        

In [10]:
# Confirm additions: print the shapes, then check the invariants they should satisfy
print(NEW_X.shape)
print(NEW_y.shape)
assert NEW_X.shape[0] == NEW_y.shape[0], "volume/label count mismatch"
assert NEW_X.shape[1:] == X.shape[1:], "synthetic volume shape differs from X"

(419, 91, 109, 91)
(419,)


In [11]:
# Persist the augmented (real + synthetic) arrays alongside the raw training data
for stem, array in (('X_SYN', NEW_X), ('y_SYN', NEW_y)):
    np.save(f'data/{stem}.npy', array)