In [1]:
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [2]:
import torch
import numpy as np
import torch
from scipy.linalg import sqrtm
import torch
import torch.optim as optim
import os
import torch.nn as nn
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt


In [3]:
def plot_batch_images(images, labels):
    batch_size = images.shape[0]
    nrow = int(np.sqrt(batch_size))
    grid_img = torchvision.utils.make_grid(images, nrow=nrow, normalize=True)  # Arrange images in grid
    plt.figure(figsize=(10, 10))
    plt.imshow(grid_img.permute(1, 2, 0))  # Convert from (C, H, W) to (H, W, C)
    plt.axis("off")
    plt.title("MNIST Batch Samples")
    plt.show()

In [4]:

class MNISTClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1),  # 28x28 → 28x28
            nn.ReLU(),
            nn.MaxPool2d(2),  # 28x28 → 14x14

            nn.Conv2d(32, 64, 3, padding=1),  # 14x14 → 14x14
            nn.ReLU(),
            nn.MaxPool2d(2),  # 14x14 → 7x7
        )

        classifier = nn.Sequential(
                    nn.Flatten(),
                    nn.Linear(64 * 7 * 7, 128),
                    nn.ReLU(),
                    nn.Linear(128, 10)
        )
        self.classifier = classifier.to(device)  # Moved to GPU

    def forward(self, x):
        feats = self.features(x)
        logits = self.classifier(feats)
        return logits, feats.view(x.size(0), -1)  # logits, feature_vector

In [5]:


def train_mnist_classifier(n_epochs=5, save_path="mnist_classifier.pth"):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    train_ds = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
    test_ds = datasets.MNIST(root="./data", train=False, download=True, transform=transform)
    train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_ds, batch_size=64)

    model = MNISTClassifier().to(device)
    model = model.to(device)  # Moved to GPU
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    print("📚 Training MNIST classifier...")
    for epoch in range(n_epochs):
        model.train()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            out, _ = model(x)
            loss = criterion(out, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch+1}/{n_epochs} ✅ Loss: {loss.item():.4f}")

    # Evaluate
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            out, _ = model(x)
            pred = out.argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.size(0)
    acc = correct / total
    print(f" Test Accuracy: {acc*100:.2f}%")

    #torch.save(model.state_dict(), save_path)
    #print(f" Saved classifier to {save_path}")
    return model

In [6]:
classifier = train_mnist_classifier(n_epochs=5)
classifier = classifier.to(device)  # Moved to GPU

100%|██████████| 9.91M/9.91M [00:02<00:00, 4.87MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 133kB/s]
100%|██████████| 1.65M/1.65M [00:01<00:00, 1.27MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 6.11MB/s]


📚 Training MNIST classifier...
Epoch 1/5 ✅ Loss: 0.2158
Epoch 2/5 ✅ Loss: 0.0046
Epoch 3/5 ✅ Loss: 0.0157
Epoch 4/5 ✅ Loss: 0.0396
Epoch 5/5 ✅ Loss: 0.0253
 Test Accuracy: 99.08%


In [7]:
def load_MNISTdata():
    data_dir = os.path.join(os.getcwd(), "..", "data", "raw")

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    train_dataset = datasets.MNIST(root=data_dir, train=True, transform=transform, download=True)
    test_dataset = datasets.MNIST(root=data_dir, train=False, transform=transform, download=True)

    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

    return train_loader, test_loader

def loader_to_data(train_loader):
    all_imgs = []

    for imgs, _ in train_loader:
        imgs = imgs.to(device)  # Move each batch to device
        all_imgs.append(imgs)

    data = torch.cat(all_imgs, dim=0)  # Already on device
    print(f"Loaded: {data.shape}")
    return data


In [8]:
train_loader, _ = load_MNISTdata()
X_mnist = loader_to_data(train_loader)

100%|██████████| 9.91M/9.91M [00:01<00:00, 5.07MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 132kB/s]
100%|██████████| 1.65M/1.65M [00:01<00:00, 1.28MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 6.59MB/s]


Loaded: torch.Size([60000, 1, 28, 28])


In [9]:
class Discriminator(nn.Module):
    def __init__(self, img_channels=3):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.utils.spectral_norm(nn.Conv2d(img_channels, 64, 4, stride=2, padding=1)),  # 28 → 14
            nn.LeakyReLU(0.2),

            nn.utils.spectral_norm(nn.Conv2d(64, 128, 4, stride=2, padding=1)),  # 14 → 7
            nn.LeakyReLU(0.2),

            nn.utils.spectral_norm(nn.Conv2d(128, 256, 3, stride=2, padding=1)),  # 7 → 4
            nn.LeakyReLU(0.2),

            nn.Flatten(),
            nn.utils.spectral_norm(nn.Linear(256 * 4 * 4, 1))
            # No Sigmoid!
        )

    def forward(self, x):
        return self.net(x)




class Generator(nn.Module):
    def __init__(self, z_dim=100, img_channels=3):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(z_dim, 256 * 7 * 7),
            nn.BatchNorm1d(256 * 7 * 7),
            nn.ReLU(True),

            nn.Unflatten(1, (256, 7, 7)),

            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),  # 7x7 → 14x14
            nn.BatchNorm2d(128),
            nn.ReLU(True),

            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),  # 14x14 → 28x28
            nn.BatchNorm2d(64),
            nn.ReLU(True),

            nn.Conv2d(64, img_channels, kernel_size=3, stride=1, padding=1),
            nn.Tanh()  # Output in [-1, 1]
        )

    def forward(self, z):
        return self.net(z)

Basic

In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, WeightedRandomSampler, TensorDataset
import numpy as np
from collections import defaultdict
import matplotlib.pyplot as plt

# Set global device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

def train_BasicGAN(generator, discriminator, dataloader, lr=0.0001, criterion=None,
                   latent_dim=100, n_epochs=20, plotit=False):

    generator = generator.to(device)
    discriminator = discriminator.to(device)

    if criterion is None:
        criterion = nn.BCEWithLogitsLoss()

    d_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
    g_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))

    d_losses, g_losses = [], []

    for epoch in range(n_epochs):
        for imgs, _ in dataloader:
            imgs = imgs.to(device)
            batch_size = imgs.size(0)

            real_labels = torch.ones(batch_size, 1, device=device)
            fake_labels = torch.zeros(batch_size, 1, device=device)

            # Train Discriminator
            z = torch.randn(batch_size, latent_dim, device=device)
            fake_imgs = generator(z)

            d_optimizer.zero_grad()
            real_loss = criterion(discriminator(imgs), real_labels)
            fake_loss = criterion(discriminator(fake_imgs.detach()), fake_labels)
            d_loss = real_loss + fake_loss
            d_loss.backward()
            d_optimizer.step()

            # Train Generator
            z = torch.randn(batch_size, latent_dim, device=device)
            fake_imgs = generator(z)
            g_optimizer.zero_grad()
            g_loss = criterion(discriminator(fake_imgs), real_labels)
            g_loss.backward()
            g_optimizer.step()

        d_losses.append(d_loss.item())
        g_losses.append(g_loss.item())
        print(f"Epoch [{epoch+1}/{n_epochs}] | D Loss: {d_loss.item():.4f} | G Loss: {g_loss.item():.4f}")

    if plotit:
        plt.figure(figsize=(10, 5))
        plt.plot(d_losses, label="Discriminator Loss")
        plt.plot(g_losses, label="Generator Loss")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.legend()
        plt.show()

    return generator

def generate_BasicGAN(generator, latent_dim=100, num_samples=16, plotit=False):
    generator.eval()
    with torch.no_grad():
        z = torch.randn(num_samples, latent_dim, device=device)
        generated_images = generator(z)

    if plotit:
        plot_batch_images(generated_images, None)

    return generated_images

def get_classifier_features(classifier, data, batch_size=128):
    classifier.eval()
    classifier = classifier.to(device)
    features = []
    with torch.no_grad():
        for i in range(0, len(data), batch_size):
            batch = torch.tensor(data[i:i+batch_size], dtype=torch.float32, device=device)
            if batch.ndim == 3:
                batch = batch.unsqueeze(1)
            _, feats = classifier(batch)
            features.append(feats.cpu())
    return torch.cat(features, dim=0).numpy()

def compute_fid(real_feats, fake_feats):
    from scipy.linalg import sqrtm
    mu1, mu2 = real_feats.mean(0), fake_feats.mean(0)
    sigma1 = np.cov(real_feats, rowvar=False)
    sigma2 = np.cov(fake_feats, rowvar=False)
    covmean = sqrtm(sigma1 @ sigma2)
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    fid = np.sum((mu1 - mu2) ** 2) + np.trace(sigma1 + sigma2 - 2 * covmean)
    return fid

def compute_precision_recall(real_images, generated_images, k=3, eps=0.1):
    from sklearn.neighbors import NearestNeighbors
    from sklearn.preprocessing import normalize

    def preprocess(images):
        if isinstance(images, torch.Tensor):
            images = images.to(device)
            images = images.detach().cpu().numpy()
        images = images.astype(np.float32)
        if images.ndim == 4:
            images = images.reshape(images.shape[0], -1)
        elif images.ndim == 3:
            images = images.reshape(images.shape[0], -1)
        return normalize(images)

    real_vecs = preprocess(real_images)
    gen_vecs = preprocess(generated_images)

    nn_real = NearestNeighbors(n_neighbors=k).fit(real_vecs)
    dists_real_to_gen, _ = nn_real.kneighbors(gen_vecs)
    precision = np.mean(np.min(dists_real_to_gen, axis=1) < eps)

    nn_gen = NearestNeighbors(n_neighbors=k).fit(gen_vecs)
    dists_gen_to_real, _ = nn_gen.kneighbors(real_vecs)
    recall = np.mean(np.min(dists_gen_to_real, axis=1) < eps)

    return precision, recall

def evaluate_gan(gen_images, classifier, real_data):
    classifier = classifier.to(device)

    fake = gen_images
    if fake.ndim == 3:
        fake = fake[:, None, :, :]
    elif fake.shape[1] != 1:
        fake = fake[:, :1, :, :]

    real_feats = get_classifier_features(classifier, real_data)
    fake_feats = get_classifier_features(classifier, fake)

    prec, rec = compute_precision_recall(real_data, fake, eps=0.4)
    fid = compute_fid(real_feats, fake_feats)

    return {"Precision": prec, "Recall": rec, "FID": fid}


Using device: cuda


In [None]:
G = Generator(z_dim=100, img_channels=1)
D = Discriminator(img_channels=1)

trained_G = train_BasicGAN(
    generator=G,
    discriminator=D,
    dataloader=train_loader,
    lr=0.0002,
    latent_dim=100,
    n_epochs=50
)

Epoch [1/50] | D Loss: 0.8522 | G Loss: 1.2559
Epoch [2/50] | D Loss: 1.1778 | G Loss: 0.7849
Epoch [3/50] | D Loss: 1.1673 | G Loss: 0.9453
Epoch [4/50] | D Loss: 1.2490 | G Loss: 0.8069
Epoch [5/50] | D Loss: 1.3686 | G Loss: 0.8480
Epoch [6/50] | D Loss: 1.3036 | G Loss: 0.7776
Epoch [7/50] | D Loss: 1.2714 | G Loss: 0.8059
Epoch [8/50] | D Loss: 1.2598 | G Loss: 0.8056
Epoch [9/50] | D Loss: 1.3428 | G Loss: 0.7431
Epoch [10/50] | D Loss: 1.3213 | G Loss: 0.7802
Epoch [11/50] | D Loss: 1.2911 | G Loss: 0.7459
Epoch [12/50] | D Loss: 1.3398 | G Loss: 0.7654
Epoch [13/50] | D Loss: 1.3498 | G Loss: 0.7475
Epoch [14/50] | D Loss: 1.3169 | G Loss: 0.7872
Epoch [15/50] | D Loss: 1.3141 | G Loss: 0.7487
Epoch [16/50] | D Loss: 1.3070 | G Loss: 0.7484
Epoch [17/50] | D Loss: 1.3725 | G Loss: 0.7481
Epoch [18/50] | D Loss: 1.3544 | G Loss: 0.7346
Epoch [19/50] | D Loss: 1.3466 | G Loss: 0.6983
Epoch [20/50] | D Loss: 1.3598 | G Loss: 0.7734
Epoch [21/50] | D Loss: 1.3661 | G Loss: 0.7204
E

In [None]:
gen_images=generate_BasicGAN(G, num_samples=1000, plotit=False)

In [None]:
results = {}
results["BasicGAN"] = evaluate_gan(gen_images, classifier, real_data=X_mnist[:1000])
results["BasicGAN"]

  batch = torch.tensor(data[i:i+batch_size], dtype=torch.float32, device=device)


{'Precision': np.float64(0.374),
 'Recall': np.float64(0.371),
 'FID': np.float64(142.83780519641408)}

DIA

In [18]:
def safe_tensor_conversion(data):
    if not isinstance(data, torch.Tensor):
        return torch.tensor(data, dtype=torch.float32)
    return data.clone().detach().float()

def train_DiaGAN(generator, discriminator, data, lr=0.0005, latent_dim=100, n_epochs=200,
                 phase1_ratio=0.9, batch_size=128, window_size=25, k=1.0,
                 min_clip=0.01, max_ratio=50):

    generator = generator.to(device)
    discriminator = discriminator.to(device)

    criterion = nn.BCEWithLogitsLoss()
    data = safe_tensor_conversion(data)
    tensor_data = data.to(device)

    is_image = (data.ndim > 2)
    conv_mode = any(isinstance(m, nn.Conv2d) for m in discriminator.modules())
    dataset_size = len(tensor_data)
    ldr_dict = defaultdict(list)

    for epoch in range(n_epochs):
        in_phase1 = epoch < int(phase1_ratio * n_epochs)

        if in_phase1:
            sampler = torch.utils.data.RandomSampler(tensor_data)
        else:
            scores = []
            for i in range(dataset_size):
                ldrs = np.array(ldr_dict[i][-window_size:])
                if len(ldrs) == 0:
                    scores.append(1.0)
                else:
                    ldrm = np.mean(ldrs)
                    ldrv = np.var(ldrs)
                    score = ldrm + k * np.sqrt(ldrv)
                    scores.append(score)

            scores = np.clip(scores, a_min=min_clip, a_max=min_clip * max_ratio)
            probs = scores / scores.sum()
            sampler = WeightedRandomSampler(probs, num_samples=dataset_size, replacement=True)

        dataloader = DataLoader(TensorDataset(tensor_data), batch_size=batch_size, sampler=sampler)

        for real_batch, in dataloader:
            imgs = real_batch.to(device)
            bs = imgs.size(0)

            real_labels = torch.ones(bs, 1, device=device)
            fake_labels = torch.zeros(bs, 1, device=device)

            z = torch.randn(bs, latent_dim, device=device)
            fake_imgs = generator(z)

            real_input = imgs if conv_mode else imgs.view(bs, -1)
            fake_input = fake_imgs if conv_mode else fake_imgs.view(bs, -1)

            d_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
            d_optimizer.zero_grad()
            real_out = discriminator(real_input)
            fake_out = discriminator(fake_input.detach())
            d_loss = criterion(real_out, real_labels) + criterion(fake_out, fake_labels)
            d_loss.backward()
            d_optimizer.step()

            g_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
            g_optimizer.zero_grad()
            fake_out = discriminator(fake_input)
            g_loss = criterion(fake_out, real_labels)
            g_loss.backward()
            g_optimizer.step()

        with torch.no_grad():
            D_x = discriminator(tensor_data if conv_mode else tensor_data.view(len(tensor_data), -1)).squeeze()
            D_x_sigmoid = torch.sigmoid(D_x)
            LDR_x = torch.log(D_x_sigmoid / (1 - D_x_sigmoid + 1e-6)).cpu().numpy()
            for i in range(dataset_size):
                ldr_dict[i].append(LDR_x[i])

        if epoch % 2 == 0:
            print(f"[Epoch {epoch}] D Loss: {d_loss.item():.4f} | G Loss: {g_loss.item():.4f}")

    return generator


In [19]:
diaG = Generator(z_dim=100, img_channels=1)
diaD = Discriminator(img_channels=1)

trained_diaG = train_DiaGAN(
    generator=diaG,
    discriminator=diaD,
    data=X_mnist,
    lr=0.0002,
    latent_dim=100,
    n_epochs=50,
    phase1_ratio=0.5,
    k=0.3,
    batch_size=64
)

[Epoch 0] D Loss: 1.3173 | G Loss: 0.6141
[Epoch 2] D Loss: 1.3033 | G Loss: 0.6896
[Epoch 4] D Loss: 1.3112 | G Loss: 0.7080
[Epoch 6] D Loss: 1.2864 | G Loss: 0.8916
[Epoch 8] D Loss: 1.3825 | G Loss: 0.7072
[Epoch 10] D Loss: 1.3041 | G Loss: 0.8523
[Epoch 12] D Loss: 1.3243 | G Loss: 0.8208
[Epoch 14] D Loss: 1.3543 | G Loss: 0.7682
[Epoch 16] D Loss: 1.3582 | G Loss: 0.8096
[Epoch 18] D Loss: 1.4081 | G Loss: 0.7545
[Epoch 20] D Loss: 1.3605 | G Loss: 0.7602
[Epoch 22] D Loss: 1.3318 | G Loss: 0.7363
[Epoch 24] D Loss: 1.3429 | G Loss: 0.7864
[Epoch 26] D Loss: 1.3103 | G Loss: 0.7300
[Epoch 28] D Loss: 1.3442 | G Loss: 0.7462
[Epoch 30] D Loss: 1.3361 | G Loss: 0.8141
[Epoch 32] D Loss: 1.3410 | G Loss: 0.7386
[Epoch 34] D Loss: 1.3032 | G Loss: 0.8371
[Epoch 36] D Loss: 1.3154 | G Loss: 0.7097
[Epoch 38] D Loss: 1.3334 | G Loss: 0.7810
[Epoch 40] D Loss: 1.3144 | G Loss: 0.6995
[Epoch 42] D Loss: 1.2913 | G Loss: 0.8007
[Epoch 44] D Loss: 1.3144 | G Loss: 0.7053
[Epoch 46] D Los

In [20]:
gen_images_dia=generate_BasicGAN(diaG, num_samples=1000)
results = {}
results["DiaGAN"] = evaluate_gan(gen_images_dia, classifier, real_data=X_mnist[:1000])
results["DiaGAN"]

  batch = torch.tensor(data[i:i+batch_size], dtype=torch.float32, device=device)


{'Precision': np.float64(0.361),
 'Recall': np.float64(0.401),
 'FID': np.float64(166.76698073824383)}

BASIC + DRS

In [15]:
def apply_DRS(generator, real_data, z_dim=100, num_gen=10000, batch_size=128, n_epochs=5):
    generator = generator.to(device)

    # Prepare real data
    if not isinstance(real_data, torch.Tensor):
        real_tensor = torch.tensor(real_data, dtype=torch.float32)
    else:
        real_tensor = real_data.clone().detach().float()

    real_tensor = real_tensor.to(device)
    real_labels = torch.ones(len(real_tensor), 1, device=device)

    # Generate fake data
    generator.eval()
    with torch.no_grad():
        z = torch.randn(num_gen, z_dim, device=device)
        fake_tensor = generator(z).detach()
    fake_labels = torch.zeros(len(fake_tensor), 1, device=device)

    # Combine real and fake
    combined_data = torch.cat([real_tensor, fake_tensor], dim=0)
    combined_labels = torch.cat([real_labels, fake_labels], dim=0)
    loader = DataLoader(TensorDataset(combined_data, combined_labels), batch_size=batch_size, shuffle=True)

    # Train auxiliary discriminator
    aux_disc = Discriminator(img_channels=1).to(device)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(aux_disc.parameters(), lr=0.0002, betas=(0.5, 0.999))

    print("Training auxiliary discriminator for DRS...")
    aux_disc.train()
    for epoch in range(n_epochs):
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            out = aux_disc(x)
            loss = criterion(out, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(f"[Epoch {epoch+1}] Aux D Loss: {loss.item():.4f}")

    aux_disc.eval()
    with torch.no_grad():
        logits = aux_disc(fake_tensor).squeeze()
        probs = torch.sigmoid(logits)
        ldr = probs / (1 - probs + 1e-6)

    # Normalize and resample with probability LDR
    ldr_np = ldr.cpu().numpy()
    acceptance_probs = ldr_np / ldr_np.max()
    accept_flags = np.random.rand(len(ldr_np)) < acceptance_probs

    filtered_samples = fake_tensor[accept_flags]
    print(f"DRS Accepted {len(filtered_samples)}/{num_gen} samples ({accept_flags.mean()*100:.2f}%)")
    return filtered_samples.cpu().numpy()


In [None]:
filtered_samples_basic = apply_DRS(
    generator=G,
    real_data=X_mnist,  # shape: (60000, 1, 28, 28)
    z_dim=100,
    num_gen=10000
)

Training auxiliary discriminator for DRS...
[Epoch 1] Aux D Loss: 0.2668
[Epoch 2] Aux D Loss: 0.3534
[Epoch 3] Aux D Loss: 0.3987
[Epoch 4] Aux D Loss: 0.2601
[Epoch 5] Aux D Loss: 0.3391
DRS Accepted 224/10000 samples (2.24%)


In [None]:

results["BasicGAN + DRS"] = evaluate_gan(filtered_samples_basic, classifier, real_data=X_mnist[:len(filtered_samples_basic)])
results["BasicGAN + DRS"]

  batch = torch.tensor(data[i:i+batch_size], dtype=torch.float32, device=device)


{'Precision': np.float64(0.15178571428571427),
 'Recall': np.float64(0.20535714285714285),
 'FID': np.float64(389.9210941864924)}

DIA + DRS

In [21]:
filtered_samples_dia = apply_DRS(
    generator=diaG,
    real_data=X_mnist,  # shape: (60000, 1, 28, 28)
    z_dim=100,
    num_gen=10000
)

Training auxiliary discriminator for DRS...
[Epoch 1] Aux D Loss: 0.3708
[Epoch 2] Aux D Loss: 0.3610
[Epoch 3] Aux D Loss: 0.3149
[Epoch 4] Aux D Loss: 0.3227
[Epoch 5] Aux D Loss: 0.3413
DRS Accepted 542/10000 samples (5.42%)


In [22]:
results["DiaGAN + DRS"] = evaluate_gan(filtered_samples_dia, classifier, real_data=X_mnist[:len(filtered_samples_dia)])
results["DiaGAN + DRS"]

  batch = torch.tensor(data[i:i+batch_size], dtype=torch.float32, device=device)


{'Precision': np.float64(0.29704797047970477),
 'Recall': np.float64(0.3118081180811808),
 'FID': np.float64(217.246650653262)}