# Synthetic vs Real Data for Autonomous Vehicle Training

In [None]:
# Import all dependencies

import os
import numpy as np 
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt 
import torchvision.utils as vutils
from PIL import Image
from tqdm import tqdm
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
from torchvision.utils import save_image
from torch.autograd import Variable


from CGAN.CDCGAN import Generator, Discriminator
from Classifier.ResNet import ResidualNetwork


### Training our Conditional Deep Convolutional Generative Adversarial Network

In [None]:
# Directory that receives the periodic generator sample grids
os.makedirs("sample_images", exist_ok=True)

# --- Hyperparameters -------------------------------------------------------
# Model / data shape
latent_dim = 100
img_size = 32
channels = 1
n_classes = 10
embedding_dim = 50

# Training schedule
batch_size = 64
n_epochs = 200
sample_interval = 400  # save a sample grid every this many batches

# Adam settings
lr = 0.0002
b1 = 0.5
b2 = 0.999

img_shape = (channels, img_size, img_size)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Binary cross-entropy between the discriminator output and real/fake targets
adversarial_loss = nn.BCELoss()

# Networks (constructed with their own defaults, then moved to the device)
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# One Adam optimizer per network, sharing learning rate and betas
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

# Preprocessing: resize to img_size x img_size, convert to a tensor, then
# shift/scale with mean 0.5 and std 0.5.
# NOTE(review): ImageFolder's default loader presumably yields 3-channel RGB
# images while `channels` is 1 and no grayscale transform is applied — confirm
# this matches what Generator/Discriminator expect.
transform = transforms.Compose(
    [
        transforms.Resize(size=(img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# Real training images, one sub-folder per class
dataset = datasets.ImageFolder(root="Dataset/Train", transform=transform)

# Shuffled mini-batches for GAN training
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Sample generator output
def sample_image(n_row, batches_done):
    """Save an ``n_row`` x ``n_row`` grid of generated images.

    Each grid row uses a single class label: row ``r`` is conditioned on
    label ``r``. The grid is written to ``sample_images/{batches_done}.png``.

    Reads the module-level ``generator``, ``latent_dim`` and ``device``.

    Args:
        n_row: Number of rows (and columns) in the grid; also the number of
            distinct labels sampled (``0 .. n_row - 1``).
        batches_done: Global batch counter, used only for the file name.
    """
    # Pure inference: no autograd graph is needed for a preview image.
    with torch.no_grad():
        z = torch.randn(n_row ** 2, latent_dim, device=device)
        labels = torch.tensor(
            [label for label in range(n_row) for _ in range(n_row)],
            dtype=torch.long,
            device=device,
        )
        gen_imgs = generator(z, labels)
    save_image(gen_imgs, f"sample_images/{batches_done}.png", nrow=n_row, normalize=True)


# Training
# Each batch performs one generator step followed by one discriminator step.
# The mean per-batch loss of every epoch is stored for plotting afterwards.
g_losses = []
d_losses = []
for epoch in range(n_epochs):
    pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{n_epochs}")
    epoch_g_loss = 0
    epoch_d_loss = 0

    for i, (imgs, labels) in enumerate(pbar):
        real_imgs = imgs.to(device)
        labels = labels.to(device)
        # Size of *this* batch (the final batch of an epoch may be smaller).
        # Deliberately not called `batch_size`: rebinding that name here would
        # overwrite the hyperparameter that later DataLoaders are built with.
        cur_batch_size = real_imgs.size(0)

        # Discriminator targets: 1 for real, 0 for fake
        valid = torch.ones(cur_batch_size, 1, device=device)
        fake = torch.zeros(cur_batch_size, 1, device=device)

        # Train Generator: it is rewarded when the discriminator labels its
        # output as real for the randomly drawn class labels.
        optimizer_G.zero_grad()
        z = torch.randn(cur_batch_size, latent_dim, device=device)
        gen_labels = torch.randint(0, n_classes, (cur_batch_size,), device=device)
        gen_imgs = generator(z, gen_labels)
        g_loss = adversarial_loss(discriminator(gen_imgs, gen_labels), valid)
        g_loss.backward()
        optimizer_G.step()

        # Train Discriminator: gradients accumulated in the discriminator by
        # the generator step are cleared first; the fake batch is detached so
        # this step does not back-propagate into the generator.
        optimizer_D.zero_grad()
        real_loss = adversarial_loss(discriminator(real_imgs, labels), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach(), gen_labels), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        epoch_g_loss += g_loss.item()
        epoch_d_loss += d_loss.item()

        pbar.set_postfix(D_loss=f"{d_loss.item():.4f}", G_loss=f"{g_loss.item():.4f}")

        # Periodically dump a grid of samples to monitor progress
        batches_done = epoch * len(dataloader) + i
        if batches_done % sample_interval == 0:
            sample_image(n_row=10, batches_done=batches_done)

    g_losses.append(epoch_g_loss / len(dataloader))
    d_losses.append(epoch_d_loss / len(dataloader))


# Plot the per-epoch generator loss and save it as a PNG.
plt.figure()
plt.plot(g_losses, linestyle='-')  # keyword was misspelled "linestye", which matplotlib rejects
plt.title("Generator Loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.savefig("Generator_Training_Loss.png")
plt.close()

# Plot the per-epoch discriminator loss and save it as a PNG.
plt.figure()
plt.plot(d_losses, linestyle='-')
plt.title("Discriminator Loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.savefig("Discriminator_Training_Loss.png")
plt.close()


### Generating our Synthetic Training Dataset

In [None]:
def generate_images_per_class(generator, latent_dim, n_classes, total_per_class=500, save_dir="Generated_Dataset/Train", batch_size=64):
    """Generate ``total_per_class`` images for every class label and save them.

    Images for label ``k`` are written to ``<save_dir>/<k>/<index>.png`` where
    ``index`` runs from ``0`` to ``total_per_class - 1``. The generator is
    switched to evaluation mode and left in that mode.

    NOTE(review): folders are named with the bare integer label. Loaders that
    derive label indices from sorted folder names (ImageFolder presumably does)
    will order "10" before "2", so with more than 10 classes the reloaded
    indices may not match the generator's labels — confirm.

    Args:
        generator: Conditional generator called as ``generator(z, labels)``.
        latent_dim: Size of the noise vector fed to the generator.
        n_classes: Labels ``0 .. n_classes - 1`` are generated.
        total_per_class: Number of images to write per class.
        save_dir: Root output directory; created if missing.
        batch_size: Maximum number of images generated per forward pass.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")

    generator.eval()  # Evaluation mode

    # Create inputs on the generator's own device instead of relying on a
    # module-level `device` variable.
    try:
        gen_device = next(generator.parameters()).device
    except StopIteration:
        gen_device = torch.device("cpu")

    os.makedirs(save_dir, exist_ok=True)

    with torch.no_grad():
        for class_label in range(n_classes):
            class_dir = os.path.join(save_dir, f"{class_label}")
            os.makedirs(class_dir, exist_ok=True)

            num_generated = 0
            while num_generated < total_per_class:
                # The last batch is shrunk so exactly total_per_class are made
                cur_batch = min(batch_size, total_per_class - num_generated)
                z = torch.randn(cur_batch, latent_dim, device=gen_device)
                labels = torch.full((cur_batch,), class_label, dtype=torch.long, device=gen_device)

                gen_imgs = generator(z, labels).cpu()

                for offset, img_tensor in enumerate(gen_imgs):
                    # Use the running index across batches: naming files by the
                    # in-batch position made every batch overwrite the last.
                    img_path = os.path.join(class_dir, f"{num_generated + offset}.png")
                    vutils.save_image(img_tensor, img_path, normalize=True)

                num_generated += cur_batch

# Build the synthetic training set from the trained generator.
# NOTE(review): n_classes is 15 here while the GAN training hyperparameter
# n_classes is 10 — confirm which value matches the generator's label range.
generate_images_per_class(
    generator,
    latent_dim=latent_dim,
    n_classes=15,
    total_per_class=800,
    save_dir="Generated_Dataset/Train",
)


### Training our Residual Network on Fake Data

In [None]:
# Training Loop with Fake Dataset
# Trains a fresh ResidualNetwork on the generated images only and records the
# mean training loss of each epoch in `training_losses`.

# Dataset path
data_dir = "Generated_Dataset/Train"

# Transforms
transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

# Load full dataset
dataset = datasets.ImageFolder(
    root=data_dir,
    transform=transform
)

# DataLoaders
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Model: one output per class folder found in the generated dataset
num_classes = len(dataset.classes)
resnet_fake = ResidualNetwork(
    in_channel=3,
    channel_1=64,
    channel_2=64,
    channel_3=128,
    channel_4=256,
    channel_5=512,
    number_classes=num_classes
).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(resnet_fake.parameters(), lr=1e-4)

# Training loop
epochs = 10
training_losses = []
for epoch in range(epochs):
    resnet_fake.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for inputs, labels in tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}"):
        inputs, labels = inputs.to(device), labels.to(device)

        # Forward
        outputs = resnet_fake(inputs)
        loss = criterion(outputs, labels)

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Metrics: the criterion returns a batch mean, so weight it by the
        # batch size to obtain a correct per-sample average at epoch end.
        running_loss += loss.item() * inputs.size(0)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    epoch_loss = running_loss / len(dataset)
    # Record this epoch's loss (the call previously had no argument and raised
    # a TypeError at the end of the first epoch).
    training_losses.append(epoch_loss)
    epoch_acc = correct / total

    print(f"Epoch {epoch+1} Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}")


# Plot the per-epoch training loss of the ResNet trained on synthetic data.
plt.figure()
plt.plot(training_losses, linestyle='-')  # keyword was misspelled "linestye"
plt.title("Training Loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
# Saved under its own name: the real-data plot is written to
# "ResNet_Loss_Real.png" and would otherwise overwrite this figure.
plt.savefig("ResNet_Loss_Fake.png")
plt.close()


### Training our Residual Network on Real Data

In [None]:
# Training Loop with Real Dataset
# Trains a fresh ResidualNetwork on 80% of the real images; the remaining 20%
# are kept in `test_loader` for evaluation. The mean training loss of each
# epoch is recorded in `training_losses`.

# Dataset path
data_dir = "Dataset/Train"

# Transforms
transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

# Load full dataset
dataset = datasets.ImageFolder(
    root=data_dir,
    transform=transform
)

# Split into train/test using fixed seed so the split is reproducible
n_total = len(dataset)
indices = np.random.RandomState(seed=42).permutation(n_total)
split = int(n_total * 0.8)
train_indices = indices[:split]
test_indices = indices[split:]

train_dataset = Subset(dataset, train_indices)
test_dataset = Subset(dataset, test_indices)

# DataLoaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Model: one output per class folder found in the real dataset
num_classes = len(dataset.classes)
resnet_real = ResidualNetwork(
    in_channel=3,
    channel_1=64,
    channel_2=64,
    channel_3=128,
    channel_4=256,
    channel_5=512,
    number_classes=num_classes
).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(resnet_real.parameters(), lr=1e-4)

# Training loop
epochs = 10
training_losses = []
for epoch in range(epochs):
    resnet_real.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
        inputs, labels = inputs.to(device), labels.to(device)

        # Forward
        outputs = resnet_real(inputs)
        loss = criterion(outputs, labels)

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Metrics: the criterion returns a batch mean, so weight it by the
        # batch size to obtain a correct per-sample average at epoch end.
        running_loss += loss.item() * inputs.size(0)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    # Average over the samples actually trained on (the 80% split), not the
    # full dataset, which would under-report the loss.
    epoch_loss = running_loss / len(train_dataset)
    # Record this epoch's loss (the call previously had no argument and raised
    # a TypeError at the end of the first epoch).
    training_losses.append(epoch_loss)
    epoch_acc = correct / total

    print(f"Epoch {epoch+1} Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}")


# Plot the per-epoch training loss of the ResNet trained on real data.
plt.figure()
plt.plot(training_losses, linestyle='-')  # keyword was misspelled "linestye"
plt.title("Training Loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.savefig("ResNet_Loss_Real.png")
plt.close()


### Testing our Trained Models

In [None]:
def test_accuracy(model, dataloader, device):
    model.eval(); 
    correct = 0; 
    total = 0; 

    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device); 

            outputs = model(images); 
            _, predicted = torch.max(outputs, 1); 

            correct += (predicted == labels).sum().item(); 
            total += labels.size(0); 

    return correct / total; 

# Score both classifiers on the same real-data test split so the two
# accuracies are directly comparable.
acc_fake = test_accuracy(model=resnet_fake, dataloader=test_loader, device=device)
print("ResNet Test Accuracy (Trained on Synthetic Data): ", acc_fake)

acc_real = test_accuracy(model=resnet_real, dataloader=test_loader, device=device)
print("ResNet Test Accuracy (Trained on Real Data): ", acc_real)
