In [None]:
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tabulate
parent_dir = os.path.join(os.getcwd(), '..', '..')
if parent_dir not in sys.path: sys.path.append(parent_dir)
from utility.data import get_loader, EmbeddingDataset
from model.gan import Generator, Discriminator
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, Subset
from datasets import load_from_disk
from sklearn.metrics import accuracy_score
from model.lstm import LSTMClassifier

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_epochs = 10
learning_rate = 0.001
batch_size = 128
latent_dim = 128
condition_dim = 10
gan_epochs = 50
generation_size = 1000

full_dataset = load_from_disk('../../data/full_dataset_new', keep_in_memory=True)
split_datasets = full_dataset.train_test_split(test_size=0.2, seed=42)
train_dataset = split_datasets['train']
test_dataset = split_datasets['test']

train_x_full = np.array(train_dataset['embedding'])
train_y_full = np.array(train_dataset['labels'])

test_x = np.array(test_dataset['embedding'])
test_y = np.array(test_dataset['labels'])

In [3]:


def train_and_evaluate_lstm(model, train_loader, test_loader, num_epochs, lr, device):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    for epoch in range(1, num_epochs+1):
        model.train()
        running_loss, correct, total = 0.0, 0, 0
        for batch_x, batch_y in train_loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            optimizer.zero_grad()
            outputs = model(batch_x)
            loss = criterion(outputs, batch_y)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * batch_y.size(0)
            _, preds = torch.max(outputs, 1)
            correct += (preds == batch_y).sum().item()
            total += batch_y.size(0)
        train_loss = running_loss / total
        train_acc = correct / total
        print(f"Epoch {epoch}: Train Loss={train_loss:.4f}, Train Acc={train_acc:.4f}")

    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for batch_x, batch_y in test_loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            outputs = model(batch_x)
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(batch_y.cpu().numpy())
    test_acc = accuracy_score(all_labels, all_preds)
    return test_acc, all_preds, all_labels



train_ds = EmbeddingDataset(train_dataset)
test_ds = EmbeddingDataset(test_dataset)

In [None]:
sample_sizes = [20, 50, 70, 100, 300, 1000, len(train_ds)]
lstm_accuracy_before = {}
lstm_accuracy_after = {}

for size in sample_sizes:
    print(f"\n[Real Data Only] Training size: {size}")
    train_subset = Subset(train_ds, range(size))
    train_loader = get_loader(train_subset, batch_size=batch_size, shuffle=True)
    test_loader = get_loader(test_ds, batch_size=batch_size, shuffle=False)

    input_dim = train_ds[0][0].shape[0]  
    hidden_dim = 128
    num_classes = len(np.unique(train_y_full)) 

    lstm_model = LSTMClassifier(input_dim=input_dim, hidden_dim=hidden_dim, num_classes=num_classes, dropout=0.5)
    lstm_model.to(device)

    print("Training LSTM on real data...")
    acc_before, preds_before, labels_before = train_and_evaluate_lstm(
        lstm_model, train_loader, test_loader, num_epochs, learning_rate, device
    )
    print(f"LSTM Test Accuracy (Real Data) for sample size {size}: {acc_before:.4f}")
    lstm_accuracy_before[size] = acc_before

    generator = Generator(
        latent_dim=latent_dim,
        condition_dim=condition_dim,
        num_classes=num_classes,
        start_dim=latent_dim * 2,
        n_layer=3,
        output_dim=input_dim
    ).to(device)
    discriminator = Discriminator(
        condition_dim=condition_dim,
        num_classes=num_classes,
        start_dim=256,
        n_layer=3,
        input_dim=input_dim
    ).to(device)

    adversarial_loss = nn.BCELoss().to(device)
    optimizer_G = optim.Adam(generator.parameters(), lr=5e-5, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=3e-4, betas=(0.5, 0.999))

    gan_loader = get_loader(train_subset, batch_size=batch_size, shuffle=True)
    for epoch in range(gan_epochs):
        d_loss_epoch, g_loss_epoch = 0.0, 0.0
        for embeddings, labels in gan_loader:
            embeddings = embeddings.to(device)
            labels = torch.tensor(labels, dtype=torch.long).to(device)
            b_size = embeddings.size(0)

            valid = torch.ones(b_size, 1, device=device)
            fake = torch.zeros(b_size, 1, device=device)

            # Train Generator
            optimizer_G.zero_grad()
            z = torch.randn(b_size, latent_dim, device=device)
            gen_data = generator(z, labels)
            g_loss = adversarial_loss(discriminator(gen_data, labels), valid)
            g_loss.backward()
            optimizer_G.step()

            # Train Discriminator
            optimizer_D.zero_grad()
            real_loss = adversarial_loss(discriminator(embeddings, labels), valid)
            fake_loss = adversarial_loss(discriminator(gen_data.detach(), labels), fake)
            d_loss = (real_loss + fake_loss) / 2
            d_loss.backward()
            optimizer_D.step()

            d_loss_epoch += d_loss.item()
            g_loss_epoch += g_loss.item()

        print(f"[GAN Epoch {epoch+1}/{gan_epochs}] D loss: {d_loss_epoch/len(gan_loader):.4f}, G loss: {g_loss_epoch/len(gan_loader):.4f}")

    synthetic_data_list = []
    synthetic_labels_list = []
    unique_labels = np.unique(train_y_full)
    for lab in unique_labels:
        lab_tensor = torch.full((generation_size,), lab, dtype=torch.long, device=device)
        z = torch.randn(generation_size, latent_dim, device=device)
        synth = generator(z, lab_tensor).cpu().detach().numpy()
        synthetic_data_list.append(synth)
        synthetic_labels_list.append(np.full((generation_size,), lab))

    synthetic_x = np.concatenate(synthetic_data_list, axis=0)
    synthetic_y = np.concatenate(synthetic_labels_list, axis=0)

    X_train = train_x_full[:size]
    y_train = train_y_full[:size]
    train_combined_x = np.concatenate([X_train, synthetic_x], axis=0)
    train_combined_y = np.concatenate([y_train, synthetic_y], axis=0)

    train_combined_dataset = TensorDataset(
        torch.tensor(train_combined_x, dtype=torch.float),
        torch.tensor(train_combined_y, dtype=torch.long)
    )
    train_combined_loader = DataLoader(train_combined_dataset, batch_size=batch_size, shuffle=True)

    lstm_model_aug = LSTMClassifier(input_dim=input_dim, hidden_dim=hidden_dim, num_classes=num_classes, dropout=0.5)
    lstm_model_aug.to(device)

    print("Training LSTM on real + synthetic (concatenated) data...")
    acc_after, preds_after, labels_after = train_and_evaluate_lstm(
        lstm_model_aug, train_combined_loader, test_loader, num_epochs, learning_rate, device
    )
    print(f"LSTM Test Accuracy (After Concatenation) for sample size {size}: {acc_after:.4f}")
    lstm_accuracy_after[size] = acc_after

summary_df = pd.DataFrame(
    [[s, lstm_accuracy_before[s], lstm_accuracy_after[s]] for s in sample_sizes],
    columns=["Train Samples", "Real Only Accuracy", "After Concatenation Accuracy"]
)

print("Accuracy Summary:")
print(tabulate.tabulate(summary_df.values, headers=summary_df.columns, tablefmt="fancy_grid"))

plt.figure(figsize=(8, 5))
plt.plot(sample_sizes, list(lstm_accuracy_before.values()), marker='o', linestyle='-', color='b', markersize=8, label="Real Only")
plt.plot(sample_sizes, list(lstm_accuracy_after.values()), marker='s', linestyle='--', color='r', markersize=8, label="After Concatenation")
plt.xlabel("Training Size", fontsize=14, fontfamily="Times New Roman")
plt.ylabel("Accuracy", fontsize=14, fontfamily="Times New Roman")
plt.title("LSTM Accuracy vs. Training Size (Real vs. Augmented)", fontsize=16, fontfamily="Times New Roman")
plt.xscale("log")
plt.grid(True, linestyle="--", alpha=0.6)
plt.legend()
plt.show()


[Real Data Only] Training size: 20


  emb = torch.tensor(self.embeddings[idx], dtype=torch.float)


Training LSTM on real data...


  emb = torch.tensor(self.embeddings[idx], dtype=torch.float)


Epoch 1: Train Loss=1.1997, Train Acc=0.1000
Epoch 2: Train Loss=0.9333, Train Acc=0.9000
Epoch 3: Train Loss=0.7476, Train Acc=0.9000
Epoch 4: Train Loss=0.6572, Train Acc=0.9000
Epoch 5: Train Loss=0.5251, Train Acc=0.9500
Epoch 6: Train Loss=0.4566, Train Acc=0.9500
Epoch 7: Train Loss=0.3736, Train Acc=0.9500
Epoch 8: Train Loss=0.3260, Train Acc=0.9500
Epoch 9: Train Loss=0.2996, Train Acc=0.9500
Epoch 10: Train Loss=0.2525, Train Acc=0.9500
LSTM Test Accuracy (Real Data) for sample size 20: 0.8796


  emb = torch.tensor(self.embeddings[idx], dtype=torch.float)
  labels = torch.tensor(labels, dtype=torch.long).to(device)


[GAN Epoch 1/50] D loss: 0.6929, G loss: 0.6842
[GAN Epoch 2/50] D loss: 0.6023, G loss: 0.6936
[GAN Epoch 3/50] D loss: 0.5223, G loss: 0.7163
[GAN Epoch 4/50] D loss: 0.4384, G loss: 0.7559
[GAN Epoch 5/50] D loss: 0.3595, G loss: 0.8179
[GAN Epoch 6/50] D loss: 0.2914, G loss: 0.9151
[GAN Epoch 7/50] D loss: 0.2306, G loss: 1.0643
[GAN Epoch 8/50] D loss: 0.1796, G loss: 1.2566
[GAN Epoch 9/50] D loss: 0.1350, G loss: 1.5109
[GAN Epoch 10/50] D loss: 0.1019, G loss: 1.7788
[GAN Epoch 11/50] D loss: 0.0855, G loss: 1.9788
[GAN Epoch 12/50] D loss: 0.0756, G loss: 2.1147
[GAN Epoch 13/50] D loss: 0.0777, G loss: 2.1365
[GAN Epoch 14/50] D loss: 0.0837, G loss: 2.0640
[GAN Epoch 15/50] D loss: 0.0959, G loss: 2.0297
[GAN Epoch 16/50] D loss: 0.0933, G loss: 1.9569
[GAN Epoch 17/50] D loss: 0.0895, G loss: 2.1630
[GAN Epoch 18/50] D loss: 0.1352, G loss: 1.7062
[GAN Epoch 19/50] D loss: 0.0786, G loss: 2.2989
[GAN Epoch 20/50] D loss: 0.0761, G loss: 2.8283
[GAN Epoch 21/50] D loss: 0.1

  emb = torch.tensor(self.embeddings[idx], dtype=torch.float)


LSTM Test Accuracy (Real Data) for sample size 50: 0.8892
[GAN Epoch 1/50] D loss: 0.6981, G loss: 0.7225
[GAN Epoch 2/50] D loss: 0.6115, G loss: 0.7351
[GAN Epoch 3/50] D loss: 0.5324, G loss: 0.7601
[GAN Epoch 4/50] D loss: 0.4466, G loss: 0.8018
[GAN Epoch 5/50] D loss: 0.3614, G loss: 0.8651
[GAN Epoch 6/50] D loss: 0.2873, G loss: 0.9599


  emb = torch.tensor(self.embeddings[idx], dtype=torch.float)
  labels = torch.tensor(labels, dtype=torch.long).to(device)


[GAN Epoch 7/50] D loss: 0.2252, G loss: 1.1047
[GAN Epoch 8/50] D loss: 0.1745, G loss: 1.2958
[GAN Epoch 9/50] D loss: 0.1346, G loss: 1.5210
[GAN Epoch 10/50] D loss: 0.1052, G loss: 1.7572
[GAN Epoch 11/50] D loss: 0.0895, G loss: 1.9297
[GAN Epoch 12/50] D loss: 0.0854, G loss: 1.9851
[GAN Epoch 13/50] D loss: 0.0936, G loss: 1.9177
[GAN Epoch 14/50] D loss: 0.1046, G loss: 1.8468
[GAN Epoch 15/50] D loss: 0.1161, G loss: 1.7651
[GAN Epoch 16/50] D loss: 0.1173, G loss: 1.7828
[GAN Epoch 17/50] D loss: 0.0959, G loss: 1.9864
[GAN Epoch 18/50] D loss: 0.0860, G loss: 1.9626
[GAN Epoch 19/50] D loss: 0.0568, G loss: 2.4892
[GAN Epoch 20/50] D loss: 0.0521, G loss: 2.6460
[GAN Epoch 21/50] D loss: 0.0423, G loss: 2.7720
[GAN Epoch 22/50] D loss: 0.0388, G loss: 2.8190
[GAN Epoch 23/50] D loss: 0.0383, G loss: 2.8330
[GAN Epoch 24/50] D loss: 0.0409, G loss: 2.8747
[GAN Epoch 25/50] D loss: 0.0341, G loss: 3.0358
[GAN Epoch 26/50] D loss: 0.0310, G loss: 3.1073
[GAN Epoch 27/50] D los

  emb = torch.tensor(self.embeddings[idx], dtype=torch.float)


LSTM Test Accuracy (Real Data) for sample size 70: 0.8913
[GAN Epoch 1/50] D loss: 0.7001, G loss: 0.6586
[GAN Epoch 2/50] D loss: 0.6096, G loss: 0.6775
[GAN Epoch 3/50] D loss: 0.5275, G loss: 0.7032
[GAN Epoch 4/50] D loss: 0.4451, G loss: 0.7379
[GAN Epoch 5/50] D loss: 0.3668, G loss: 0.7969


  emb = torch.tensor(self.embeddings[idx], dtype=torch.float)
  labels = torch.tensor(labels, dtype=torch.long).to(device)


[GAN Epoch 6/50] D loss: 0.3003, G loss: 0.8857
[GAN Epoch 7/50] D loss: 0.2430, G loss: 1.0168
[GAN Epoch 8/50] D loss: 0.1899, G loss: 1.2041
[GAN Epoch 9/50] D loss: 0.1453, G loss: 1.4399
[GAN Epoch 10/50] D loss: 0.1098, G loss: 1.7044
[GAN Epoch 11/50] D loss: 0.0907, G loss: 1.9094
[GAN Epoch 12/50] D loss: 0.0823, G loss: 2.0190
[GAN Epoch 13/50] D loss: 0.0860, G loss: 1.9915
[GAN Epoch 14/50] D loss: 0.0968, G loss: 1.8866
[GAN Epoch 15/50] D loss: 0.1117, G loss: 1.8036
[GAN Epoch 16/50] D loss: 0.1197, G loss: 1.7050
[GAN Epoch 17/50] D loss: 0.1167, G loss: 1.9449
[GAN Epoch 18/50] D loss: 0.1513, G loss: 1.4771
[GAN Epoch 19/50] D loss: 0.0783, G loss: 2.0609
[GAN Epoch 20/50] D loss: 0.0494, G loss: 2.6010
[GAN Epoch 21/50] D loss: 0.0511, G loss: 2.7281
[GAN Epoch 22/50] D loss: 0.0676, G loss: 2.2448
[GAN Epoch 23/50] D loss: 0.0715, G loss: 2.2792
[GAN Epoch 24/50] D loss: 0.0429, G loss: 2.7538
[GAN Epoch 25/50] D loss: 0.0408, G loss: 2.9442
[GAN Epoch 26/50] D loss