In [1]:
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tabulate
parent_dir = os.path.join(os.getcwd(), '..')
if parent_dir not in sys.path: sys.path.append(parent_dir)
from utility.data import get_loader, EmbeddingDataset
from model.gan import Generator, Discriminator
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Subset
from datasets import load_from_disk
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
latent_dim = 128          
condition_dim = 10        
gan_epochs = 20            
generation_size = 1000    
batch_size = 32                

full_dataset = load_from_disk('../data/full_dataset_new', keep_in_memory=True)
split_datasets = full_dataset.train_test_split(test_size=0.2, seed=42)
train_dataset = split_datasets['train']
test_dataset = split_datasets['test']

test_x = np.array(test_dataset['embedding'])
test_y = np.array(test_dataset['labels'])

train_x_full = np.array(train_dataset['embedding'])
train_y_full = np.array(train_dataset['labels'])

In [3]:


train_ds = EmbeddingDataset(train_dataset)
test_ds = EmbeddingDataset(test_dataset)
test_loader = get_loader(test_ds, batch_size=batch_size, shuffle=False)  

In [None]:

sample_sizes = [20, 50, 100, 200, 1000, len(train_ds)]
knn_accuracy_before = {}
knn_accuracy_after = {}

for size in sample_sizes:
    print(f"\n[Real Data Only] Training size: {size}")
    X_train = train_x_full[:size]
    y_train = train_y_full[:size]

    knn_real = KNeighborsClassifier(n_neighbors=1, n_jobs=-1)
    knn_real.fit(X_train, y_train)
    pred_before = knn_real.predict(test_x)
    acc_before = accuracy_score(test_y, pred_before)
    print(f"KNN Accuracy (Real Data) for size {size}: {acc_before:.4f}")
    knn_accuracy_before[size] = acc_before

    print("Training Conditional GAN on the same subset...")
    train_subset = Subset(train_ds, range(size))
    gan_loader = get_loader(train_subset, batch_size=batch_size, shuffle=True)

    input_dim = X_train.shape[1]  
    unique_labels = np.unique(train_y_full)
    num_classes_gan = len(unique_labels)

    generator = Generator(
        latent_dim=latent_dim,
        condition_dim=condition_dim,
        num_classes=num_classes_gan,
        start_dim=latent_dim * 2,
        n_layer=3,
        output_dim=input_dim
    ).to(device)
    discriminator = Discriminator(
        condition_dim=condition_dim,
        num_classes=num_classes_gan,
        start_dim=256,
        n_layer=3,
        input_dim=input_dim
    ).to(device)

    adversarial_loss = nn.BCELoss().to(device)
    optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

    for epoch in range(gan_epochs):
        d_loss_epoch, g_loss_epoch = 0.0, 0.0
        for embeddings, labels in gan_loader:
            embeddings = embeddings.to(device)
            labels = torch.tensor(labels, dtype=torch.long).to(device)
            b_size = embeddings.size(0)

            valid = torch.ones(b_size, 1, device=device)
            fake = torch.zeros(b_size, 1, device=device)

            # Train Generator
            optimizer_G.zero_grad()
            z = torch.randn(b_size, latent_dim, device=device)
            gen_data = generator(z, labels)
            g_loss = adversarial_loss(discriminator(gen_data, labels), valid)
            g_loss.backward()
            optimizer_G.step()

            # Train Discriminator
            optimizer_D.zero_grad()
            real_loss = adversarial_loss(discriminator(embeddings, labels), valid)
            fake_loss = adversarial_loss(discriminator(gen_data.detach(), labels), fake)
            d_loss = (real_loss + fake_loss) / 2
            d_loss.backward()
            optimizer_D.step()

            d_loss_epoch += d_loss.item()
            g_loss_epoch += g_loss.item()

        print(f"[GAN Epoch {epoch+1}/{gan_epochs}] D loss: {d_loss_epoch/len(gan_loader):.4f}, G loss: {g_loss_epoch/len(gan_loader):.4f}")


    synthetic_data_list = []
    synthetic_labels_list = []
    for lab in unique_labels:
        lab_tensor = torch.full((generation_size,), lab, dtype=torch.long, device=device)
        z = torch.randn(generation_size, latent_dim, device=device)
        synth = generator(z, lab_tensor).cpu().detach().numpy()
        synthetic_data_list.append(synth)
        synthetic_labels_list.append(np.full((generation_size,), lab))

    synthetic_x = np.concatenate(synthetic_data_list, axis=0)
    synthetic_y = np.concatenate(synthetic_labels_list, axis=0)

    # (C1) 將真實資料與合成資料合併
    train_combined_x = np.concatenate([X_train, synthetic_x], axis=0)
    train_combined_y = np.concatenate([y_train, synthetic_y], axis=0)

    # (C2) 用 KNN 訓練 (After Concatenation)
    knn_aug = KNeighborsClassifier(n_neighbors=5, n_jobs=-1)
    knn_aug.fit(train_combined_x, train_combined_y)
    pred_after = knn_aug.predict(test_x)
    acc_after = accuracy_score(test_y, pred_after)
    print(f"KNN Accuracy (After Concatenation) for size {size}: {acc_after:.4f}")
    knn_accuracy_after[size] = acc_after


summary_df = pd.DataFrame(
    [[s, knn_accuracy_before[s], knn_accuracy_after[s]] for s in sample_sizes],
    columns=["Train Samples", "Real Only Accuracy", "After Concatenation Accuracy"]
)
print("\nAccuracy Summary:")
print(summary_df)
print(tabulate.tabulate(summary_df.values, headers=summary_df.columns, tablefmt="fancy_grid"))

plt.figure(figsize=(8, 5))
plt.plot(sample_sizes, [knn_accuracy_before[s] for s in sample_sizes], marker='o', linestyle='-', color='b', markersize=8, label="Real Only")
plt.plot(sample_sizes, [knn_accuracy_after[s] for s in sample_sizes], marker='s', linestyle='--', color='r', markersize=8, label="After Concatenation")
plt.xlabel("Training Size", fontsize=14, fontfamily="Times New Roman")
plt.ylabel("Accuracy", fontsize=14, fontfamily="Times New Roman")
plt.title("KNN Accuracy vs. Training Size (Real vs. Augmented)", fontsize=16, fontfamily="Times New Roman")
plt.xscale("log")
plt.grid(True, linestyle="--", alpha=0.6)
plt.legend()
plt.show()


[Real Data Only] Training size: 20
KNN Accuracy (Real Data) for size 20: 0.8460
Training Conditional GAN on the same subset...


  emb = torch.tensor(self.embeddings[idx], dtype=torch.float)
  labels = torch.tensor(labels, dtype=torch.long).to(device)


[GAN Epoch 1/20] D loss: 0.7049, G loss: 0.6626
[GAN Epoch 2/20] D loss: 0.6479, G loss: 0.6508
[GAN Epoch 3/20] D loss: 0.6038, G loss: 0.6409
[GAN Epoch 4/20] D loss: 0.5628, G loss: 0.6337
[GAN Epoch 5/20] D loss: 0.5284, G loss: 0.6207
[GAN Epoch 6/20] D loss: 0.5010, G loss: 0.6083
[GAN Epoch 7/20] D loss: 0.4827, G loss: 0.6010
[GAN Epoch 8/20] D loss: 0.4607, G loss: 0.6209
[GAN Epoch 9/20] D loss: 0.4341, G loss: 0.6757
[GAN Epoch 10/20] D loss: 0.4095, G loss: 0.7573
[GAN Epoch 11/20] D loss: 0.4015, G loss: 0.8203
[GAN Epoch 12/20] D loss: 0.4180, G loss: 0.8080
[GAN Epoch 13/20] D loss: 0.4574, G loss: 0.7507
[GAN Epoch 14/20] D loss: 0.4949, G loss: 0.7162
[GAN Epoch 15/20] D loss: 0.4814, G loss: 0.7399
[GAN Epoch 16/20] D loss: 0.4234, G loss: 0.8254
[GAN Epoch 17/20] D loss: 0.3430, G loss: 0.9918
[GAN Epoch 18/20] D loss: 0.2763, G loss: 1.1296
[GAN Epoch 19/20] D loss: 0.2425, G loss: 1.2242
[GAN Epoch 20/20] D loss: 0.2301, G loss: 1.2716
KNN Accuracy (After Concatena

  emb = torch.tensor(self.embeddings[idx], dtype=torch.float)
  labels = torch.tensor(labels, dtype=torch.long).to(device)


[GAN Epoch 4/20] D loss: 0.4745, G loss: 0.6224
[GAN Epoch 5/20] D loss: 0.4230, G loss: 0.7250
[GAN Epoch 6/20] D loss: 0.4106, G loss: 0.8234
[GAN Epoch 7/20] D loss: 0.4945, G loss: 0.7386
[GAN Epoch 8/20] D loss: 0.5396, G loss: 0.7087
[GAN Epoch 9/20] D loss: 0.3819, G loss: 0.9494
[GAN Epoch 10/20] D loss: 0.2468, G loss: 1.2681
[GAN Epoch 11/20] D loss: 0.2037, G loss: 1.3760
[GAN Epoch 12/20] D loss: 0.2105, G loss: 1.3703
[GAN Epoch 13/20] D loss: 0.1841, G loss: 1.5216
[GAN Epoch 14/20] D loss: 0.1437, G loss: 1.8420
[GAN Epoch 15/20] D loss: 0.1817, G loss: 2.2223
[GAN Epoch 16/20] D loss: 0.3318, G loss: 2.2624
[GAN Epoch 17/20] D loss: 0.3581, G loss: 1.8496
[GAN Epoch 18/20] D loss: 0.1813, G loss: 2.4451
[GAN Epoch 19/20] D loss: 0.3146, G loss: 2.0084
[GAN Epoch 20/20] D loss: 0.1902, G loss: 1.4204
KNN Accuracy (After Concatenation) for size 50: 0.8943

[Real Data Only] Training size: 100
KNN Accuracy (Real Data) for size 100: 0.8425
Training Conditional GAN on the sam

  emb = torch.tensor(self.embeddings[idx], dtype=torch.float)
  labels = torch.tensor(labels, dtype=torch.long).to(device)


[GAN Epoch 1/20] D loss: 0.6115, G loss: 0.6617
[GAN Epoch 2/20] D loss: 0.4904, G loss: 0.6076
[GAN Epoch 3/20] D loss: 0.4134, G loss: 0.7509
[GAN Epoch 4/20] D loss: 0.4996, G loss: 0.7213
[GAN Epoch 5/20] D loss: 0.3145, G loss: 1.0206
[GAN Epoch 6/20] D loss: 0.2361, G loss: 1.2285
[GAN Epoch 7/20] D loss: 0.2385, G loss: 1.3603
[GAN Epoch 8/20] D loss: 0.1866, G loss: 1.9101
[GAN Epoch 9/20] D loss: 0.3001, G loss: 1.7873
[GAN Epoch 10/20] D loss: 0.2250, G loss: 2.5563
[GAN Epoch 11/20] D loss: 0.1442, G loss: 2.1605
[GAN Epoch 12/20] D loss: 0.1090, G loss: 2.2973
[GAN Epoch 13/20] D loss: 0.0801, G loss: 2.6134
[GAN Epoch 14/20] D loss: 0.0830, G loss: 2.9463
[GAN Epoch 15/20] D loss: 0.0621, G loss: 2.9888
[GAN Epoch 16/20] D loss: 0.0436, G loss: 3.3102
[GAN Epoch 17/20] D loss: 0.0833, G loss: 2.9875
[GAN Epoch 18/20] D loss: 0.6591, G loss: 3.9739
[GAN Epoch 19/20] D loss: 0.1033, G loss: 4.6429
[GAN Epoch 20/20] D loss: 0.2830, G loss: 3.0302
KNN Accuracy (After Concatena

  emb = torch.tensor(self.embeddings[idx], dtype=torch.float)
  labels = torch.tensor(labels, dtype=torch.long).to(device)


[GAN Epoch 1/20] D loss: 0.5802, G loss: 0.6108
[GAN Epoch 2/20] D loss: 0.4588, G loss: 0.7364
[GAN Epoch 3/20] D loss: 0.3419, G loss: 1.0125
[GAN Epoch 4/20] D loss: 0.2353, G loss: 1.3724
[GAN Epoch 5/20] D loss: 0.2408, G loss: 2.0605
[GAN Epoch 6/20] D loss: 0.1906, G loss: 2.0881
[GAN Epoch 7/20] D loss: 0.1022, G loss: 2.6905
[GAN Epoch 8/20] D loss: 0.1020, G loss: 3.3919
[GAN Epoch 9/20] D loss: 0.0717, G loss: 4.3892
[GAN Epoch 10/20] D loss: 0.2886, G loss: 3.1978
[GAN Epoch 11/20] D loss: 0.2148, G loss: 2.0739
[GAN Epoch 12/20] D loss: 0.6493, G loss: 2.6939
[GAN Epoch 13/20] D loss: 0.2876, G loss: 1.8260
[GAN Epoch 14/20] D loss: 0.2365, G loss: 2.2668
[GAN Epoch 15/20] D loss: 0.2189, G loss: 2.9451
[GAN Epoch 16/20] D loss: 0.2060, G loss: 3.2247
[GAN Epoch 17/20] D loss: 0.1338, G loss: 2.3465
[GAN Epoch 18/20] D loss: 0.1694, G loss: 3.0697
[GAN Epoch 19/20] D loss: 0.1796, G loss: 3.5073
[GAN Epoch 20/20] D loss: 0.1035, G loss: 2.7889
KNN Accuracy (After Concatena

  emb = torch.tensor(self.embeddings[idx], dtype=torch.float)
  labels = torch.tensor(labels, dtype=torch.long).to(device)


[GAN Epoch 1/20] D loss: 0.3720, G loss: 1.0824
[GAN Epoch 2/20] D loss: 0.1025, G loss: 2.7712
[GAN Epoch 3/20] D loss: 0.3210, G loss: 3.2070
[GAN Epoch 4/20] D loss: 0.2578, G loss: 2.8544
[GAN Epoch 5/20] D loss: 0.2750, G loss: 3.1456
[GAN Epoch 6/20] D loss: 0.4444, G loss: 2.1637
[GAN Epoch 7/20] D loss: 0.3701, G loss: 1.9414
[GAN Epoch 8/20] D loss: 0.4562, G loss: 1.8792
[GAN Epoch 9/20] D loss: 0.4307, G loss: 2.2090
[GAN Epoch 10/20] D loss: 0.4253, G loss: 1.9954
[GAN Epoch 11/20] D loss: 0.3923, G loss: 2.1227
[GAN Epoch 12/20] D loss: 0.4281, G loss: 1.9954
[GAN Epoch 13/20] D loss: 0.4310, G loss: 2.2953
[GAN Epoch 14/20] D loss: 0.4184, G loss: 1.8427
[GAN Epoch 15/20] D loss: 0.4101, G loss: 1.8701
[GAN Epoch 16/20] D loss: 0.3974, G loss: 2.1974
[GAN Epoch 17/20] D loss: 0.4261, G loss: 1.9438
[GAN Epoch 18/20] D loss: 0.4137, G loss: 1.9791
[GAN Epoch 19/20] D loss: 0.4446, G loss: 1.7040
[GAN Epoch 20/20] D loss: 0.4422, G loss: 1.8165
KNN Accuracy (After Concatena

  emb = torch.tensor(self.embeddings[idx], dtype=torch.float)
  labels = torch.tensor(labels, dtype=torch.long).to(device)


[GAN Epoch 1/20] D loss: 0.2990, G loss: 2.4653
[GAN Epoch 2/20] D loss: 0.4041, G loss: 2.3616
[GAN Epoch 3/20] D loss: 0.4308, G loss: 1.9872
[GAN Epoch 4/20] D loss: 0.4827, G loss: 1.5499
[GAN Epoch 5/20] D loss: 0.4710, G loss: 1.5354
[GAN Epoch 6/20] D loss: 0.5135, G loss: 1.3799
