# Using a VAE to expand the FLAT training data

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import sklearn
from sklearn. model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score
from sklearn.feature_selection import SelectKBest, f_classif

import matplotlib.pyplot as plt # For data viz
import pandas as pd
import numpy as np
import sys

print('System Version:', sys.version)
print('PyTorch version', torch.__version__)
print('Numpy version', np.__version__)
print('Pandas version', pd.__version__)
print('Sklearn version', sklearn.__version__)

System Version: 3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:45:18) [GCC 12.3.0]
PyTorch version 2.5.1
Numpy version 1.26.4
Pandas version 2.2.3
Sklearn version 1.6.0


In [2]:
# Confirm device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device} device")
count = torch.cuda.device_count()
print(f"Device count: {count}")

Using cuda device
Device count: 2


In [3]:
# Load the data
X = np.load('data/X_TRAIN_RAW.npy')
y = np.load('data/y_TRAIN_RAW.npy')

In [4]:
# Define the dataset
class COPEDataset(Dataset):
    def __init__(self, data, target):
        self.data = data
        self.target = target

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        cope_data = self.data[index]
        cope_data = (cope_data - np.min(cope_data)) / (np.max(cope_data) - np.min(cope_data))  # normalize

        label = self.target[index]
        volume = torch.tensor(cope_data, dtype=torch.float32).unsqueeze(0)  # (1, 91, 109, 91)
        label = torch.tensor([1.0, 0.0] if label == 0 else [0.0, 1.0], dtype=torch.float32)
        return volume, label

In [5]:
# Initiate the dataset and data loader
dataset = COPEDataset(X, y)
dataloader = DataLoader(dataset, batch_size=4)

Creating an autoencoder to generate more data

In [6]:
# -------------------------
# VAE model
# -------------------------
class VAE(nn.Module):
    def __init__(self, input_dim=520, hidden_dim=200, latent_dim=20):
        super().__init__()
        self.input_dim = input_dim
        # Encoder
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

        # Decoder
        self.fc_dec1 = nn.Linear(latent_dim, hidden_dim)
        self.fc_dec2 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        # x: [B, input_dim]
        h = F.relu(self.fc1(x))
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        # reparameterization trick
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h = F.relu(self.fc_dec1(z))
        x_recon = torch.sigmoid(self.fc_dec2(h))  # pixel probabilities 0..1
        return x_recon

    def forward(self, x):
        # x: [B, input_dim]
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon_x = self.decode(z)
        return recon_x, mu, logvar



In [7]:
# -------------------------
# Loss
# -------------------------
def loss_function(recon_x, x, mu, logvar, reduction="sum"):
    # recon_x, x: [B, D]
    # Reconstruction loss: BCE (assuming input in [0,1])
    # Sum over pixels, then average/scale outside as needed
    BCE = F.binary_cross_entropy(recon_x, x, reduction=reduction)  # sum or mean
    # KL divergence between q(z|x) = N(mu, sigma^2) and p(z)=N(0,1)
    # KL = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD, BCE, KLD

In [8]:
# Hyperparameters & setup
# -------------------------
seed = 42
torch.manual_seed(seed)

batch_size = 128
epochs = 20
lr = 1e-3
latent_dim = 20
hidden_dim = 400
image_size = 28
nz = latent_dim  # shorthand

In [9]:
# Train the model
model = VAE(input_dim=520, hidden_dim=hidden_dim, latent_dim=latent_dim).cuda()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

for epoch in range(30):
    model.train()
    total_loss = 0
    for batch in dataloader:
        volumes, labels = batch
        volumes, labels = volumes.cuda(), labels.cuda()

        optimizer.zero_grad()
        recon, mu, logvar = model(volumes, labels)
        loss = vae_loss(recon, volumes, mu, logvar)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {total_loss / len(dataloader)}")


TypeError: VAE.forward() takes 2 positional arguments but 3 were given

In [None]:
# Generate synthetic data
n = 200
NEW_X = torch.empty((n*2,91,109,91), device=device)
NEW_y = y
model.eval()

# Generating 40 * 3 new samples (120 samples: 60 of each group)
# Group 0
print('Group 0:')
for i in range(n):
    with torch.no_grad():
        z = torch.randn(1, 128).cuda()

        group_label = torch.tensor([[1.0, 0.0]]).cuda()  # one-hot
        synthetic = model.decode(z, group_label)  # (1, 1, 91, 109, 91)
        synthetic = torch.squeeze(synthetic)
    
        NEW_X[i] = synthetic #, synthetic, axis=0)
        NEW_y = np.append(NEW_y, [0], axis=0)
        print(f'Generated sample {i+1}/200')

# Group 1
print('Group 1:')
for i in range(n):
    with torch.no_grad():
        z = torch.randn(1, 128).cuda()

        group_label = torch.tensor([[0.0, 1.0]]).cuda()  # one-hot
        synthetic = model.decode(z, group_label)  # (1, 1, 91, 109, 91)
        synthetic = torch.squeeze(synthetic)
    
    NEW_X[i+n] = synthetic #NEW_X = np.append(NEW_X, synthetic, axis=0)
    NEW_y = np.append(NEW_y, [1], axis=0)
    print(f'Generated sample {i+1}/200')

Syn_X = NEW_X.cpu().numpy()
NEW_X = np.append(X, Syn_X, axis=0)
        

In [None]:
# Confirm additions
print(NEW_X.shape)
print(NEW_y.shape)

In [None]:
# Save the new data
np.save('data/X_SYN.npy', NEW_X)
np.save('data/y_SYN.npy', NEW_y)