In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import argparse
import os
from sklearn.metrics import accuracy_score, classification_report
from torch.utils.data import DataLoader, Subset, ConcatDataset, TensorDataset
from torch.utils.tensorboard import SummaryWriter
import datetime
import json
import numpy as np
from typing import Tuple
import torchvision.utils as vutils
from torchvision import datasets, transforms

In [2]:
# --- utils.py content ---
def save_gan_images(images: torch.Tensor, epoch: int, n_classes: int = 10, output_dir: str = "outputs/gan_images") -> None:
    """Tile a batch of images into a single grid and write it to disk as a PNG.

    Args:
        images: Batch of image tensors handed to ``vutils.make_grid``.
        epoch: Epoch number; used for the zero-padded ``epoch_NNN.png`` file name.
        n_classes: Number of images placed on each grid row.
        output_dir: Destination directory; created when it does not exist.
    """
    os.makedirs(output_dir, exist_ok=True)
    target_path = os.path.join(output_dir, f"epoch_{epoch:03d}.png")
    # normalize=True rescales pixel values into the displayable range.
    tiled = vutils.make_grid(images, nrow=n_classes, normalize=True)
    vutils.save_image(tiled, target_path)

In [3]:
# --- models.py content ---
def weights_init(m: nn.Module) -> None:
    """Initialise one layer in place; intended to be passed to ``Module.apply``.

    The decision is made on the class name of ``m``:
      * names containing ``Conv``: weights drawn from N(0, 0.02);
      * names containing ``BatchNorm``: weights drawn from N(1, 0.02), bias set to 0;
      * anything else is left untouched.
    """
    layer_kind = type(m).__name__
    if "Conv" in layer_kind:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return
    if "BatchNorm" in layer_kind:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, val=0)

class Classifier(nn.Module):
    """Small CNN classifier producing 10 logits from single-channel 32x32 images.

    Two conv/ReLU/max-pool stages (each conv keeps the spatial size, each pool
    halves it: 32 -> 16 -> 8) feed a two-layer fully connected head with
    dropout. All layers live in ``self.main`` so parameter names stay
    ``main.<index>.*``.
    """

    def __init__(self) -> None:
        super().__init__()
        layers = [
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Flatten(),
            nn.Linear(in_features=64 * 8 * 8, out_features=1024),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(in_features=1024, out_features=10),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return the class logits for a batch of images."""
        logits = self.main(input)
        return logits

class Generator(nn.Module):
    """Label-conditioned transposed-convolution generator.

    The class label is embedded into an ``nz``-dimensional vector, concatenated
    with the noise vector and reshaped to a ``(2 * nz) x 1 x 1`` map. Four
    transposed convolutions then upsample it 1 -> 4 -> 8 -> 16 -> 32 and a
    ``Tanh`` bounds the output to [-1, 1].

    Args:
        nz: Length of the noise vector (and of the label embedding).
        ngf: Base number of generator feature maps.
        nc: Number of output image channels.
        n_classes: Number of distinct labels the embedding supports.
    """

    def __init__(self, nz: int = 100, ngf: int = 64, nc: int = 1, n_classes: int = 10) -> None:
        super().__init__()
        self.label_emb = nn.Embedding(num_embeddings=n_classes, embedding_dim=nz)

        def up_block(c_in: int, c_out: int, stride: int, padding: int) -> list:
            # One upsampling stage: transposed conv -> batch norm -> ReLU.
            return [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=stride, padding=padding, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        stages = (
            up_block(nz * 2, ngf * 4, stride=1, padding=0)
            + up_block(ngf * 4, ngf * 2, stride=2, padding=1)
            + up_block(ngf * 2, ngf, stride=2, padding=1)
            + [
                nn.ConvTranspose2d(ngf, nc, kernel_size=4, stride=2, padding=1, bias=False),
                nn.Tanh(),
            ]
        )
        self.main = nn.Sequential(*stages)

    def forward(self, noise: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Generate one image per (noise row, label) pair."""
        label_vec = self.label_emb(labels)
        latent = torch.cat((noise, label_vec), dim=1)
        # Treat the concatenated vector as a 1x1 feature map for the conv stack.
        latent = latent.view(latent.size(0), -1, 1, 1)
        return self.main(latent)

class Discriminator(nn.Module):
    """Label-conditioned convolutional discriminator for 32x32 images.

    Each label is embedded into 32 * 32 values, reshaped into one extra image
    channel and stacked onto the input image. Three stride-2 convolutions
    shrink the map 32 -> 16 -> 8 -> 4, and a final 4x4 convolution plus
    ``Sigmoid`` yields one probability per sample with shape ``(B, 1, 1, 1)``.

    Args:
        nc: Number of image channels (the label plane adds one more).
        ndf: Base number of discriminator feature maps.
        n_classes: Number of distinct labels the embedding supports.
    """

    def __init__(self, nc: int = 1, ndf: int = 64, n_classes: int = 10) -> None:
        super().__init__()
        self.label_embedding = nn.Embedding(n_classes, 32 * 32)
        halve = dict(kernel_size=4, stride=2, padding=1, bias=False)
        self.main = nn.Sequential(
            nn.Conv2d(nc + 1, ndf, **halve),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, **halve),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(ndf * 2, ndf * 4, **halve),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(ndf * 4, 1, kernel_size=4, stride=1, padding=0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, img: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Score each (image, label) pair."""
        batch = labels.size(0)
        label_plane = self.label_embedding(labels).view(batch, 1, 32, 32)
        conditioned = torch.cat((img, label_plane), dim=1)
        return self.main(conditioned)

In [4]:
# --- data_loader.py content ---
class ToTensorLong:
    """Callable target transform that turns a label into a ``torch.long`` tensor."""

    def __call__(self, y):
        """Return ``y`` wrapped in a new tensor of dtype ``torch.long``."""
        label_tensor = torch.tensor(y, dtype=torch.long)
        return label_tensor

def get_dataloaders(
    low_resource_size: int, batch_size: int = 64, seed: int = 0
) -> Tuple[DataLoader, DataLoader, DataLoader, DataLoader]:
    """Build the MNIST / SVHN loaders used by the experiments.

    All images are resized to 32x32, converted to one channel and normalised
    with mean 0.5 / std 0.5. The low-resource SVHN subset is class balanced:
    ``max(1, low_resource_size // 10)`` images are drawn per digit, so the
    subset size is a multiple of 10 and may differ from ``low_resource_size``.

    The subset is drawn from a local generator seeded with ``seed`` rather
    than the global NumPy RNG. Repeated calls with the same
    ``(low_resource_size, seed)`` therefore return the same real samples,
    which keeps GAN training and every classifier scenario on identical data.

    Args:
        low_resource_size: Requested number of labelled SVHN training images;
            ``0`` (or less) yields empty SVHN training subsets.
        batch_size: Batch size for every returned loader.
        seed: Seed for the subset selection. Use a different value to obtain
            a different (but still reproducible) subset.

    Returns:
        ``(mnist_train_loader, svhn_low_resource_loader, svhn_trad_aug_loader,
        svhn_test_loader)``. The two SVHN training loaders share the same
        indices; the augmented one additionally applies a random affine
        transform.
    """
    target_transform = ToTensorLong()

    def _build_transform(augment: bool) -> transforms.Compose:
        # Shared preprocessing; the affine jitter is inserted before ToTensor.
        steps = [
            transforms.Resize(32),
            transforms.Grayscale(num_output_channels=1),
        ]
        if augment:
            steps.append(
                transforms.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1))
            )
        steps.extend(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.5,), (0.5,)),
            ]
        )
        return transforms.Compose(steps)

    transform = _build_transform(augment=False)
    transform_augmented = _build_transform(augment=True)

    mnist_train = datasets.MNIST(
        root="./data",
        train=True,
        download=True,
        transform=transform,
        target_transform=target_transform,
    )
    mnist_train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=0)

    svhn_train = datasets.SVHN(
        root="./data",
        split="train",
        download=True,
        transform=transform,
        target_transform=target_transform,
    )
    svhn_test = datasets.SVHN(
        root="./data",
        split="test",
        download=True,
        transform=transform,
        target_transform=target_transform,
    )
    svhn_test_loader = DataLoader(svhn_test, batch_size=batch_size, shuffle=False, num_workers=0)

    # Class-balanced, reproducible subset selection.
    rng = np.random.default_rng(seed)
    targets = np.asarray(svhn_train.labels)
    indices = []
    if low_resource_size > 0:
        samples_per_class = max(1, low_resource_size // 10)
        for digit in range(10):
            class_indices = np.flatnonzero(targets == digit)
            num_to_sample = min(samples_per_class, len(class_indices))
            indices.extend(rng.choice(class_indices, num_to_sample, replace=False).tolist())

    # Shuffling an empty dataset is rejected by the random sampler in some
    # PyTorch versions, so only shuffle when there is something to shuffle.
    # Callers can then inspect len(loader.dataset) == 0 themselves.
    shuffle_subset = len(indices) > 0

    svhn_low_resource_subset = Subset(svhn_train, indices)
    svhn_low_resource_loader = DataLoader(
        svhn_low_resource_subset, batch_size=batch_size, shuffle=shuffle_subset, num_workers=0
    )

    svhn_train_augmented = datasets.SVHN(
        root="./data",
        split="train",
        download=True,
        transform=transform_augmented,
        target_transform=target_transform,
    )
    # Same indices as the plain subset; only the transform differs.
    svhn_augmented_subset = Subset(svhn_train_augmented, indices)
    svhn_trad_aug_loader = DataLoader(
        svhn_augmented_subset, batch_size=batch_size, shuffle=shuffle_subset, num_workers=0
    )

    return (
        mnist_train_loader,
        svhn_low_resource_loader,
        svhn_trad_aug_loader,
        svhn_test_loader,
    )

def get_gan_augmented_loader(
    svhn_low_resource_loader: DataLoader,
    generated_images: torch.Tensor,
    generated_labels: torch.Tensor,
    batch_size: int = 64,
) -> DataLoader:
    """Combine the real low-resource dataset with generated samples.

    Args:
        svhn_low_resource_loader: Loader whose ``.dataset`` supplies the real samples.
        generated_images: Synthetic image tensor, one image per row.
        generated_labels: Labels for ``generated_images``; cast to ``torch.long``.
        batch_size: Batch size of the returned loader.

    Returns:
        A shuffling ``DataLoader`` over the concatenation of both datasets.
    """
    synthetic_part = TensorDataset(generated_images, generated_labels.to(dtype=torch.long))
    real_part = svhn_low_resource_loader.dataset
    return DataLoader(
        ConcatDataset([real_part, synthetic_part]),
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
    )

In [5]:
# --- train_gan.py content ---
def train_gan(
    n_samples: int,
    num_epochs: int = 100,
    nz: int = 100,
    lr: float = 0.0002,
    beta1: float = 0.5,
    batch_size: int = 64,
    ngf: int = 64,
    ndf: int = 64,
    nc: int = 1,
    n_classes: int = 10,
    output_dir: str = "outputs",
) -> None:
    """Train the conditional GAN on the low-resource SVHN subset.

    Writes TensorBoard logs to ``<output_dir>/tensorboard_logs``, preview
    grids to ``<output_dir>/gan_images`` (every 10th and the final epoch) and
    the generator weights to
    ``<output_dir>/checkpoints/gan_generator_n<n_samples>.pth``.

    Args:
        n_samples: Size of the low-resource subset requested from ``get_dataloaders``.
        num_epochs: Number of passes over the subset.
        nz: Length of the generator noise vector.
        lr: Adam learning rate for both networks.
        beta1: Adam beta1 for both networks.
        batch_size: Mini-batch size.
        ngf: Base number of generator feature maps.
        ndf: Base number of discriminator feature maps.
        nc: Number of image channels.
        n_classes: Number of class labels.
        output_dir: Root directory for every artefact of this run.

    Returns:
        None. When the subset is empty a warning is printed and nothing is
        written (no log directory, no checkpoint).
    """
    print(f"train_gan: output_dir={output_dir}")
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Starting GAN training on device: {device}")

    # Load the data before opening the TensorBoard writer so the early return
    # below cannot leak an open writer or leave an empty log directory behind.
    _, svhn_low_resource_loader, _, _ = get_dataloaders(
        low_resource_size=n_samples, batch_size=batch_size
    )
    if len(svhn_low_resource_loader.dataset) == 0:
        print(f"Warning: No low-resource SVHN samples available for n_samples={n_samples}. GAN training skipped.")
        return

    log_dir = os.path.join(
        output_dir, "tensorboard_logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + f"_GAN_N{n_samples}"
    )
    writer = SummaryWriter(log_dir)
    print(f"TensorBoard logs for GAN training will be saved to: {log_dir}")

    try:
        netG = Generator(nz=nz, ngf=ngf, nc=nc, n_classes=n_classes).to(device)
        netD = Discriminator(nc=nc, ndf=ndf, n_classes=n_classes).to(device)

        netG.apply(weights_init)
        netD.apply(weights_init)

        criterion = nn.BCELoss()

        # Preview batch: ten samples per class. The noise must have one row per
        # label, otherwise the generator's concatenation fails when n_classes != 10.
        fixed_labels = torch.arange(0, n_classes, device=device).repeat(10)
        fixed_noise = torch.randn(fixed_labels.size(0), nz, device=device)

        real_batch = next(iter(svhn_low_resource_loader))
        writer.add_image("GAN/Real Images", vutils.make_grid(real_batch[0][:100], nrow=10, normalize=True), 0)

        optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
        optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

        gan_images_dir = os.path.join(output_dir, "gan_images")

        print(f"Starting GAN training with {num_epochs} epochs on {n_samples} SVHN samples...")
        for epoch in range(num_epochs):
            for real_images, real_labels in svhn_low_resource_loader:
                real_images = real_images.to(device)
                real_labels = real_labels.to(device)
                b_size = real_images.size(0)

                # --- Discriminator step: real batch, then detached fake batch ---
                netD.zero_grad()
                label_real = torch.full((b_size,), 1.0, dtype=torch.float, device=device)
                output = netD(real_images, real_labels).view(-1)
                errD_real = criterion(output, label_real)
                errD_real.backward()

                noise = torch.randn(b_size, nz, device=device)
                fake_labels = torch.randint(0, n_classes, (b_size,), device=device)
                fake_images = netG(noise, fake_labels)
                label_fake = torch.full((b_size,), 0.0, dtype=torch.float, device=device)
                # detach() keeps the discriminator loss from back-propagating into G.
                output = netD(fake_images.detach(), fake_labels).view(-1)
                errD_fake = criterion(output, label_fake)
                errD_fake.backward()
                errD = errD_real + errD_fake
                optimizerD.step()

                # --- Generator step: make D label the fakes as real ---
                netG.zero_grad()
                output = netD(fake_images, fake_labels).view(-1)
                errG = criterion(output, label_real)
                errG.backward()
                optimizerG.step()

            # The reported losses are those of the last mini-batch of the epoch.
            print(
                f"Epoch [{epoch+1}/{num_epochs}] "
                f"Loss_D: {errD.item():.4f} "
                f"Loss_G: {errG.item():.4f}"
            )
            writer.add_scalar("GAN/Loss_D", errD.item(), epoch)
            writer.add_scalar("GAN/Loss_G", errG.item(), epoch)

            if (epoch + 1) % 10 == 0 or (epoch + 1) == num_epochs:
                with torch.no_grad():
                    fake = netG(fixed_noise, fixed_labels).cpu()
                # save_gan_images creates the directory itself.
                save_gan_images(fake, epoch + 1, n_classes=n_classes, output_dir=gan_images_dir)
                print(f"Saved generated images for epoch {epoch+1}.")
                img_grid = vutils.make_grid(fake, nrow=n_classes, normalize=True)
                writer.add_image("GAN/Generated Images", img_grid, epoch)

        gan_checkpoints_dir = os.path.join(output_dir, "checkpoints")
        os.makedirs(gan_checkpoints_dir, exist_ok=True)
        gan_model_path = os.path.join(gan_checkpoints_dir, f"gan_generator_n{n_samples}.pth")
        torch.save(netG.state_dict(), gan_model_path)
        print(f"Finished GAN Training. Generator model saved to {gan_model_path}")
    finally:
        # Flush and release the event file even when training raises.
        writer.close()

    print(f"TensorBoard writer closed. View logs with: tensorboard --logdir {os.path.dirname(log_dir)}")

In [6]:
# --- main_experiment.py content ---
def train_classifier(
    model: nn.Module,
    train_loader: 'DataLoader',
    optimizer: optim.Optimizer,
    criterion: nn.Module,
    device: torch.device,
    current_epoch: int,
    total_epochs: int,
) -> None:
    """Run one optimisation pass over ``train_loader``.

    Puts ``model`` in training mode, performs one optimiser step per batch and
    prints the batch loss on every 100th batch (starting with batch 0).
    ``current_epoch`` and ``total_epochs`` are used for the progress message only.
    """
    model.train()
    for step, batch in enumerate(train_loader):
        inputs, labels = batch
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        batch_loss = criterion(model(inputs), labels)
        batch_loss.backward()
        optimizer.step()

        if step % 100 != 0:
            continue
        print(f"  Epoch {current_epoch}/{total_epochs} Batch {step}/{len(train_loader)} Loss: {batch_loss.item():.4f}")

def evaluate_classifier(
    model: nn.Module, test_loader: 'DataLoader', device: torch.device
) -> tuple[float, str]:
    """Score ``model`` on every batch of ``test_loader``.

    The model is switched to eval mode and run without gradient tracking; the
    arg-max over dim 1 of the output is taken as the predicted class.

    Returns:
        ``(accuracy, report)`` where ``report`` is the text produced by
        ``classification_report`` with four decimal digits.
    """
    model.eval()
    predicted, expected = [], []
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            batch_pred = model(inputs).argmax(dim=1)
            predicted.extend(batch_pred.cpu().numpy())
            expected.extend(labels.cpu().numpy())

    return (
        accuracy_score(expected, predicted),
        classification_report(expected, predicted, digits=4),
    )

def _pretrain_classifier(
    classifier: nn.Module,
    mnist_loader: 'DataLoader',
    optimizer: optim.Optimizer,
    criterion: nn.Module,
    device: torch.device,
    epochs: int,
    output_dir: str,
) -> nn.Module:
    mnist_model_path = os.path.join(output_dir, "checkpoints", "classifier_mnist_pretrained.pth")

    if not os.path.exists(mnist_model_path):
        print(f"Pre-training classifier on MNIST for {epochs} epochs...")
        for epoch in range(epochs):
            train_classifier(classifier, mnist_loader, optimizer, criterion, device, epoch + 1, epochs)
            print(f"MNIST Pre-train Epoch {epoch+1}/{epochs} complete.")
        os.makedirs(os.path.dirname(mnist_model_path), exist_ok=True)
        torch.save(classifier.state_dict(), mnist_model_path)
        print(f"Pre-trained MNIST classifier saved to {mnist_model_path}")
    else:
        print(f"Loading pre-trained MNIST classifier from {mnist_model_path}...")
        classifier.load_state_dict(torch.load(mnist_model_path, map_location=device))
        print("Pre-trained MNIST classifier loaded.")
    return classifier

def _run_source_only_scenario(
    classifier: nn.Module, device: torch.device, output_dir: str
) -> nn.Module:
    print("Scenario: Source Only - Evaluating MNIST pre-trained model directly on SVHN.")
    model_save_path = os.path.join(output_dir, "checkpoints", "classifier_source_only.pth")
    os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
    torch.save(classifier.state_dict(), model_save_path)
    print(f"Source-only classifier model saved to {model_save_path}")
    return classifier

def _run_fine_tune_scenario(
    classifier: nn.Module,
    svhn_low_loader: 'DataLoader',
    optimizer: optim.Optimizer,
    criterion: nn.Module,
    device: torch.device,
    epochs: int,
    output_dir: str,
) -> nn.Module:
    print(f"Scenario: Fine-tuning on low-resource SVHN data for {epochs} epochs...")
    for epoch in range(epochs):
        train_classifier(classifier, svhn_low_loader, optimizer, criterion, device, epoch + 1, epochs)
        print(f"SVHN Fine-tune Epoch {epoch+1}/{epochs} complete.")
    model_save_path = os.path.join(output_dir, "checkpoints", "classifier_fine_tune.pth")
    os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
    torch.save(classifier.state_dict(), model_save_path)
    print(f"Fine-tuned classifier model saved to {model_save_path}")
    return classifier

def _run_traditional_aug_scenario(
    classifier: nn.Module,
    svhn_trad_aug_loader: 'DataLoader',
    optimizer: optim.Optimizer,
    criterion: nn.Module,
    device: torch.device,
    epochs: int,
    output_dir: str,
) -> nn.Module:
    print(f"Scenario: Fine-tuning on traditionally augmented low-resource SVHN data for {epochs} epochs...")
    for epoch in range(epochs):
        train_classifier(classifier, svhn_trad_aug_loader, optimizer, criterion, device, epoch + 1, epochs)
        print(f"SVHN Traditional Aug Epoch {epoch+1}/{epochs} complete.")
    model_save_path = os.path.join(output_dir, "checkpoints", "classifier_traditional_aug.pth")
    os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
    torch.save(classifier.state_dict(), model_save_path)
    print(f"Traditional augmented classifier model saved to {model_save_path}")
    return classifier

def _run_gan_aug_scenario(
    classifier: nn.Module,
    svhn_low_loader: 'DataLoader',
    optimizer: optim.Optimizer,
    criterion: nn.Module,
    device: torch.device,
    args: argparse.Namespace,
    output_dir: str,
    gan_model_base_dir: str = None,
) -> nn.Module:
    """Fine-tune ``classifier`` on real low-resource data plus GAN samples.

    Loads the generator checkpoint ``gan_generator_n<args.n_samples>.pth``
    from ``<gan_model_base_dir or output_dir>/checkpoints``, generates
    ``args.n_synthetic`` images with uniformly random labels, concatenates
    them with the dataset behind ``svhn_low_loader`` and trains for
    ``args.classifier_epochs_finetune`` epochs. The result is saved to
    ``<output_dir>/checkpoints/classifier_gan_aug.pth``.

    Args:
        classifier: Model to fine-tune (modified in place and returned).
        svhn_low_loader: Loader over the real low-resource SVHN samples.
        optimizer: Optimiser bound to ``classifier``'s parameters.
        criterion: Classification loss.
        device: Device used for generation and training.
        args: Namespace providing ``gan_nz``, ``gan_ngf``, ``gan_nc``,
            ``gan_n_classes``, ``n_samples``, ``n_synthetic``, ``batch_size``
            and ``classifier_epochs_finetune``.
        output_dir: Directory receiving the fine-tuned classifier checkpoint.
        gan_model_base_dir: Directory containing the generator checkpoint;
            falls back to ``output_dir`` when falsy.

    Raises:
        FileNotFoundError: If the generator checkpoint does not exist.
    """
    print(f"  _run_gan_aug_scenario: output_dir={output_dir}, gan_model_base_dir={gan_model_base_dir}")
    print(f"Scenario: Fine-tuning on GAN-augmented SVHN data for {args.classifier_epochs_finetune} epochs...")

    gan_load_dir = gan_model_base_dir if gan_model_base_dir else output_dir
    gan_model_path = os.path.join(gan_load_dir, "checkpoints", f"gan_generator_n{args.n_samples}.pth")
    print(f"  _run_gan_aug_scenario: gan_load_dir={gan_load_dir}, gan_model_path={gan_model_path}")

    # Fail fast, before any model is allocated on the device.
    if not os.path.exists(gan_model_path):
        raise FileNotFoundError(
            f"GAN generator model for n_samples={args.n_samples} not found at {gan_model_path}. "
            f"Please ensure the GAN was trained and saved to the correct output directory."
        )

    generator = Generator(nz=args.gan_nz, ngf=args.gan_ngf, nc=args.gan_nc, n_classes=args.gan_n_classes).to(device)
    generator.load_state_dict(torch.load(gan_model_path, map_location=device))
    # eval() makes BatchNorm use its running statistics, so each sample is
    # generated independently of the rest of its batch.
    generator.eval()

    print(f"Generating {args.n_synthetic} synthetic images using GAN...")
    noise = torch.randn(args.n_synthetic, args.gan_nz, device=device)
    labels = torch.randint(0, args.gan_n_classes, (args.n_synthetic,), device=device)
    # Run the generator in mini-batches: a single forward pass over all
    # n_synthetic samples can exhaust device memory for large requests.
    chunk_size = max(1, args.batch_size)
    with torch.no_grad():
        if args.n_synthetic > 0:
            chunks = [
                generator(noise[start:start + chunk_size], labels[start:start + chunk_size]).cpu()
                for start in range(0, args.n_synthetic, chunk_size)
            ]
            synthetic_images = torch.cat(chunks, dim=0)
        else:
            # Nothing to chunk; keep the single-call behaviour for an empty request.
            synthetic_images = generator(noise, labels).cpu()
    print("Synthetic images generated.")

    gan_augmented_loader = get_gan_augmented_loader(
        svhn_low_loader, synthetic_images, labels.cpu(), batch_size=args.batch_size
    )
    print("GAN-augmented data loader created.")

    for epoch in range(args.classifier_epochs_finetune):
        train_classifier(
            classifier, gan_augmented_loader, optimizer, criterion, device, epoch + 1, args.classifier_epochs_finetune
        )
        print(f"SVHN GAN Aug Epoch {epoch+1}/{args.classifier_epochs_finetune} complete.")
    model_save_path = os.path.join(output_dir, "checkpoints", "classifier_gan_aug.pth")
    os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
    torch.save(classifier.state_dict(), model_save_path)
    print(f"GAN-augmented classifier model saved to {model_save_path}")
    return classifier

def run_experiment(
    args: argparse.Namespace,
    gan_model_base_dir: str = None,
) -> None:
    """Run one scenario end to end and record its SVHN test accuracy.

    Steps: validate the scenario name, build the loaders, obtain an
    MNIST-pretrained classifier, apply the scenario-specific training,
    evaluate on the SVHN test set, then write a TensorBoard scalar and
    ``<args.output_dir>/results/<scenario>_N<n_samples>_results.json``.

    Args:
        args: Namespace with ``scenario``, ``n_samples``, ``batch_size``,
            ``classifier_lr``, ``classifier_epochs_mnist``,
            ``classifier_epochs_finetune``, the ``gan_*`` settings,
            ``n_synthetic`` and ``output_dir``.
        gan_model_base_dir: Directory shared by several scenario runs. It
            holds the GAN generator checkpoint and, when given, also the
            MNIST-pretrained classifier checkpoint, so every scenario starts
            from the same pretrained weights instead of retraining them.
            When falsy, ``args.output_dir`` is used for both.

    Raises:
        ValueError: If ``args.scenario`` is not one of the known scenarios.
    """
    valid_scenarios = ("source_only", "fine_tune", "traditional_aug", "gan_aug")
    # An unknown name would otherwise skip every branch below and silently
    # report the untouched pretrained model under that name.
    if args.scenario not in valid_scenarios:
        raise ValueError(
            f"Unknown scenario {args.scenario!r}; expected one of {valid_scenarios}."
        )

    print(f"run_experiment: args.output_dir={args.output_dir}, gan_model_base_dir={gan_model_base_dir}")
    log_dir = os.path.join(
        args.output_dir, "tensorboard_logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + f"_{args.scenario}_N{args.n_samples}"
    )
    writer = SummaryWriter(log_dir)
    print(f"TensorBoard logs will be saved to: {log_dir}")

    try:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print(f"\n===== RUNNING SCENARIO: {args.scenario.upper()} with N={args.n_samples} on {device} =====")

        mnist_loader, svhn_low_loader, svhn_trad_aug_loader, svhn_test_loader = (
            get_dataloaders(low_resource_size=args.n_samples, batch_size=args.batch_size)
        )

        classifier = Classifier().to(device)
        criterion = nn.CrossEntropyLoss()

        # Cache the pretrained weights in the shared directory when one is
        # provided, so sibling scenarios reuse a single pretraining run.
        pretrain_dir = gan_model_base_dir if gan_model_base_dir else args.output_dir
        pretrain_optimizer = optim.Adam(classifier.parameters(), lr=args.classifier_lr)
        classifier = _pretrain_classifier(
            classifier, mnist_loader, pretrain_optimizer, criterion, device, args.classifier_epochs_mnist, pretrain_dir
        )

        # Fresh optimiser for the target-domain phase: fine-tuning must not
        # depend on whether the pretrained weights were just trained (warm
        # Adam moments) or loaded from the cache (cold moments).
        optimizer = optim.Adam(classifier.parameters(), lr=args.classifier_lr)

        if args.scenario == "source_only":
            classifier = _run_source_only_scenario(classifier, device, args.output_dir)
        elif args.scenario == "fine_tune":
            classifier = _run_fine_tune_scenario(
                classifier, svhn_low_loader, optimizer, criterion, device, args.classifier_epochs_finetune, args.output_dir
            )
        elif args.scenario == "traditional_aug":
            classifier = _run_traditional_aug_scenario(
                classifier, svhn_trad_aug_loader, optimizer, criterion, device, args.classifier_epochs_finetune, args.output_dir
            )
        elif args.scenario == "gan_aug":
            print(f"run_experiment: Calling _run_gan_aug_scenario with output_dir={args.output_dir}, gan_model_base_dir={gan_model_base_dir}")
            classifier = _run_gan_aug_scenario(
                classifier, svhn_low_loader, optimizer, criterion, device, args, args.output_dir, gan_model_base_dir
            )

        print("\n--- Performing Final Evaluation on SVHN Test Set ---")
        accuracy, report = evaluate_classifier(classifier, svhn_test_loader, device)
        print(f"\n--- Results for Scenario: {args.scenario.upper()} (N={args.n_samples}) ---")
        print(f"Final Accuracy on SVHN Test Set: {accuracy * 100:.2f}%")
        print("Classification Report:")
        print(report)
        print("=" * 60)

        writer.add_scalar("Final Accuracy/SVHN", accuracy, 0)

        results = {
            "scenario": args.scenario,
            "n_samples": args.n_samples,
            "batch_size": args.batch_size,
            "classifier_lr": args.classifier_lr,
            "classifier_epochs_mnist": args.classifier_epochs_mnist,
            "classifier_epochs_finetune": args.classifier_epochs_finetune,
            "gan_nz": args.gan_nz,
            "gan_ngf": args.gan_ngf,
            "gan_ndf": args.gan_ndf,
            "gan_nc": args.gan_nc,
            "gan_n_classes": args.gan_n_classes,
            "n_synthetic": args.n_synthetic,
            "final_accuracy": accuracy,
            "classification_report": report,
            "timestamp": datetime.datetime.now().isoformat(),
            "log_dir": log_dir,
        }

        results_dir = os.path.join(args.output_dir, "results")
        os.makedirs(results_dir, exist_ok=True)
        results_filename = f"{args.scenario}_N{args.n_samples}_results.json"
        results_path = os.path.join(results_dir, results_filename)

        with open(results_path, 'w') as f:
            json.dump(results, f, indent=4)
        print(f"Experiment results saved to: {results_path}")
    finally:
        # Release the event file even when a step above raises.
        writer.close()

    print(f"TensorBoard writer closed. View logs with: tensorboard --logdir {os.path.dirname(log_dir)}")

In [7]:
# --- Experiment Configuration and Execution ---
class Args:
    """Plain attribute container holding the default experiment settings.

    Used in the notebook in place of an ``argparse.Namespace``; every setting
    is a regular instance attribute that callers overwrite as needed.
    """

    def __init__(self):
        defaults = {
            "scenario": "",
            "n_samples": 100,
            "batch_size": 64,
            "classifier_lr": 1e-3,
            "classifier_epochs_mnist": 5,
            "classifier_epochs_finetune": 20,
            "gan_nz": 100,
            "gan_ngf": 64,
            "gan_ndf": 64,
            "gan_nc": 1,
            "gan_n_classes": 10,
            "n_synthetic": 5000,
            "output_dir": "./outputs",
        }
        for key, value in defaults.items():
            setattr(self, key, value)

# Define scenarios and n_samples to iterate over
scenarios = ["source_only", "fine_tune", "traditional_aug", "gan_aug"]
n_samples_values = [10, 100, 1000]  # Example values, can be adjusted

# GAN optimisation settings that Args does not carry.
GAN_NUM_EPOCHS = 100
GAN_LR = 0.0002
GAN_BETA1 = 0.5
# Seed for the global torch / NumPy RNGs so a full re-run is repeatable.
SEED = 42


def _validate_args(args: "Args") -> None:
    """Raise ValueError when any experiment setting is out of range."""
    if args.n_samples < 0:
        raise ValueError("n_samples must be a non-negative integer.")
    if args.batch_size <= 0:
        raise ValueError("batch_size must be a positive integer.")
    if args.classifier_lr <= 0:
        raise ValueError("classifier_lr (learning rate) must be a positive float.")
    if args.classifier_epochs_mnist <= 0:
        raise ValueError("classifier_epochs_mnist must be a positive integer.")
    if args.classifier_epochs_finetune <= 0:
        raise ValueError("classifier_epochs_finetune must be a positive integer.")
    if args.gan_nz <= 0:
        raise ValueError("gan_nz (latent vector size) must be a positive integer.")
    if args.gan_ngf <= 0:
        raise ValueError("gan_ngf (generator feature map size) must be a positive integer.")
    if args.gan_ndf <= 0:
        raise ValueError("gan_ndf (discriminator feature map size) must be a positive integer.")
    if args.gan_nc <= 0:
        raise ValueError("gan_nc (number of channels) must be a positive integer.")
    if args.gan_n_classes <= 0:
        raise ValueError("gan_n_classes (number of classes) must be a positive integer.")
    if args.n_synthetic < 0:
        raise ValueError("n_synthetic must be a non-negative integer.")


torch.manual_seed(SEED)
np.random.seed(SEED)

for n_samples_val in n_samples_values:
    print(f"\n--- Preparing for experiments with N_samples = {n_samples_val} ---")
    # Create a base directory for the current n_samples_val run
    base_experiment_run_dir = os.path.join("./experiments", datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + f"_N{n_samples_val}")
    os.makedirs(base_experiment_run_dir, exist_ok=True)
    print(f"Experiment Loop: base_experiment_run_dir={base_experiment_run_dir}")

    # Validate the configuration before the (long) GAN training starts, and
    # take the GAN architecture from Args: run_experiment rebuilds the
    # generator from the same fields, so the two must not drift apart.
    gan_config = Args()
    gan_config.n_samples = n_samples_val
    _validate_args(gan_config)

    # Train GAN once for the current n_samples_val if gan_aug is in scenarios
    if "gan_aug" in scenarios:
        print(f"Experiment Loop: Calling train_gan with output_dir={base_experiment_run_dir}")
        train_gan(
            n_samples=n_samples_val,
            num_epochs=GAN_NUM_EPOCHS,
            nz=gan_config.gan_nz,
            lr=GAN_LR,
            beta1=GAN_BETA1,
            batch_size=gan_config.batch_size,
            ngf=gan_config.gan_ngf,
            ndf=gan_config.gan_ndf,
            nc=gan_config.gan_nc,
            n_classes=gan_config.gan_n_classes,
            output_dir=base_experiment_run_dir,  # Save GAN model to the base run directory
        )
        print(f"GAN training for n_samples={n_samples_val} completed.")

    for scenario_name in scenarios:
        print(f"\nRunning experiment for scenario: {scenario_name} with N={n_samples_val}")
        args = Args()
        args.scenario = scenario_name
        args.n_samples = n_samples_val
        # Create a unique output directory for each experiment run, nested under the base run directory
        args.output_dir = os.path.join(base_experiment_run_dir, f"exp_{scenario_name}_N{n_samples_val}")

        _validate_args(args)

        print(f"Experiment Loop: Calling run_experiment for scenario={scenario_name}, n_samples={n_samples_val} with args.output_dir={args.output_dir}, gan_model_base_dir={base_experiment_run_dir}")
        run_experiment(args, gan_model_base_dir=base_experiment_run_dir)

print("\nAll experiments completed.")


--- Preparing for experiments with N_samples = 10 ---
Experiment Loop: base_experiment_run_dir=./experiments\20251022-135817_N10
Experiment Loop: Calling train_gan with output_dir=./experiments\20251022-135817_N10
train_gan: output_dir=./experiments\20251022-135817_N10
Starting GAN training on device: cuda:0
TensorBoard logs for GAN training will be saved to: ./experiments\20251022-135817_N10\tensorboard_logs\20251022-135817_GAN_N10
Starting GAN training with 100 epochs on 10 SVHN samples...
Epoch [1/100] Loss_D: 1.4078 Loss_G: 2.1752
Epoch [2/100] Loss_D: 2.3618 Loss_G: 1.5000
Epoch [3/100] Loss_D: 1.5394 Loss_G: 1.6555
Epoch [4/100] Loss_D: 1.0215 Loss_G: 1.8654
Epoch [5/100] Loss_D: 1.2114 Loss_G: 1.3626
Epoch [6/100] Loss_D: 1.2608 Loss_G: 1.7763
Epoch [7/100] Loss_D: 1.0519 Loss_G: 2.3866
Epoch [8/100] Loss_D: 1.1307 Loss_G: 2.1906
Epoch [9/100] Loss_D: 1.1952 Loss_G: 1.8247
Epoch [10/100] Loss_D: 0.7428 Loss_G: 1.9096
Saved generated images for epoch 10.
Epoch [11/100] Loss_D: 0