In [None]:
import torch

In [None]:
# Lower-case twin of DEVICE (defined in a later cell).
# NOTE(review): no visible cell reads `device`; confirm it is unused before removing it.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
import argparse
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from tqdm import trange
from tqdm.auto import tqdm


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Generator(nn.Module):
    """
    MLP generator: latent vector -> flattened image in [-1, 1].

    Widths are latent_dim -> 256 -> 512 -> 1024 -> g_output_dim, with LeakyReLU(0.2)
    after every hidden layer and tanh on the output layer.

    Args:
        g_output_dim: size of the flattened output image (784 for 28x28 MNIST).
        latent_dim: size of the input latent vector.
        num_classes: stored as ``self.num_classes``; not used by forward().
    """

    def __init__(self, g_output_dim, latent_dim=100, num_classes=10):
        super().__init__()
        base_width = 256
        # Attribute names fc1..fc4 are the state_dict keys of saved checkpoints.
        self.fc1 = nn.Linear(latent_dim, base_width)
        self.fc2 = nn.Linear(base_width, 2 * base_width)
        self.fc3 = nn.Linear(2 * base_width, 4 * base_width)
        self.fc4 = nn.Linear(4 * base_width, g_output_dim)
        self.num_classes = num_classes

    def forward(self, x):
        """Map a (batch, latent_dim) latent batch to (batch, g_output_dim) values in [-1, 1]."""
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.leaky_relu(layer(hidden), 0.2)
        return torch.tanh(self.fc4(hidden))

class Discriminator(nn.Module):
    """
    MLP classifier over flattened images.

    Widths are d_input_dim -> 1024 -> 512 -> 256 -> K, with LeakyReLU(0.2) after every
    hidden layer. forward() returns raw logits (no softmax), which is what the
    nn.CrossEntropyLoss criterion used in this notebook expects.

    Args:
        d_input_dim: size of the flattened input image (784 for 28x28 MNIST).
        K: number of output logits; also stored as ``self.K``.
    """

    def __init__(self, d_input_dim, K=11):
        super().__init__()
        top_width = 1024
        # Attribute names fc1..fc4 are the state_dict keys of saved checkpoints.
        self.fc1 = nn.Linear(d_input_dim, top_width)
        self.fc2 = nn.Linear(top_width, top_width // 2)
        self.fc3 = nn.Linear(top_width // 2, top_width // 4)
        self.fc4 = nn.Linear(top_width // 4, K)
        self.K = K

    def forward(self, x):
        """Return (batch, K) logits for a (batch, d_input_dim) input batch."""
        for fc in (self.fc1, self.fc2, self.fc3):
            x = F.leaky_relu(fc(x), 0.2)
        return self.fc4(x)


class GaussianM(nn.Module):
    """
    Fixed (non-trainable) mixture of K isotropic Gaussians over the latent space.

    Component k has mean ``self.means[k]`` and covariance ``sigma**2 * I_d``; forward()
    returns ``means[k] + sigma * z`` for standard-normal noise z.

    The means must differ between components: with identical means (e.g. all zeros)
    the latent carries no information about the selected component k. They are drawn
    uniformly from ``[-mean_scale, mean_scale]^d`` with a private, fixed-seed RNG, so
    every instance built with the same arguments gets the same means without having
    to checkpoint them.

    Args:
        K: number of mixture components.
        d: latent dimension.
        sigma: per-dimension standard deviation of every component.
        mean_scale: half-width of the hypercube the means are drawn from;
            ``0.0`` gives all-zero (identical) means.
        seed: seed of the private RNG used to draw the means.
    """

    def __init__(self, K, d, sigma=0.4, mean_scale=1.0, seed=0):
        super(GaussianM, self).__init__()
        self.K = K
        self.d = d
        self.sigma = sigma  # standard deviation; the covariance is sigma**2 * I_d

        # Private generator: drawing the means neither consumes nor reseeds the global RNG.
        rng = torch.Generator().manual_seed(seed)
        means = (2.0 * torch.rand(K, d, generator=rng) - 1.0) * mean_scale
        # Buffers (not plain tensor attributes) follow module.to(device), are replicated
        # by DataParallel and are included in state_dict().
        self.register_buffer('means', means)
        # Kept for backward compatibility; forward() uses the equivalent sigma * z.
        self.register_buffer('sigma_matrix', sigma * torch.eye(d))

    def forward(self, k, z):
        """
        Map component selectors and noise to latent vectors.

        Args:
            k: one-hot selector of shape (batch, K) (its argmax picks the component),
               or a 1-D tensor of integer component indices.
            z: noise of shape (batch, d), assumed standard normal.

        Returns:
            Tensor of shape (batch, d): ``means[k] + sigma * z``.
        """
        idx = k if k.dim() == 1 else k.argmax(dim=1)
        return self.means[idx] + self.sigma * z


In [None]:
# def sample_from_gmm(batch_size, K, latent_dim, means, covariances, device):
#     """
#     Sample batch_size latent vectors z from a Gaussian mixture model.
#     """
#     gaussian_indices = torch.randint(0, K, (batch_size,)).to(device)
#     z = torch.empty(batch_size, latent_dim).to(device)
#     for i, idx in enumerate(gaussian_indices):
#         mean, covariance = means[idx], covariances[idx]
#         z[i] = torch.normal(mean, torch.sqrt(covariance.diag()))
#     return z, gaussian_indices


In [None]:
import time
import torch
import torchvision
import os
import numpy as np

DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
d = 100  # latent-space dimension (generator input size)
K = 11   # discriminator outputs / mixture components: 10 digit classes + 1 "fake" class

# def D_train(x, y, G, D, GaussianM, D_optimizer, criterion):
#     #=======================Train the discriminator=======================#
#     G.train()
#     D.train()

#     D.zero_grad()

#     # Train discriminator on real samples
#     x_real, y_real = x, y
#     x_real, y_real = x_real.to(DEVICE), y_real.to(DEVICE)

#     D_output_real = D(x_real)
#     D_real_loss = criterion(D_output_real, y_real)

#     # Sample from the Gaussian Mixture Model (fixed)
#     k_values = torch.randint(0, GaussianM.K, (x.shape[0],)).to(DEVICE)  # Randomly select Gaussian components
#     # Move the identity matrix to the correct device
#     y_one_hot = torch.eye(GaussianM.K, device=DEVICE)[k_values]  # One-hot encoded labels

#     # Generate standard normal noise
#     N = torch.distributions.MultivariateNormal(torch.zeros(d).to(DEVICE), torch.eye(d).to(DEVICE))
#     z = N.sample((x.shape[0],)).to(DEVICE).to(torch.float32)

#     # The vector of latent space sampled from the Gaussian Mixture
#     z_tilde = GaussianM(y_one_hot, z)

#     # Generate fake sample x_fake
#     x_fake = G(z_tilde)

#     D_output_fake = D(x_fake)
#     target_fake = torch.full((x.shape[0],), 10, dtype=torch.long).to(DEVICE)

#     D_fake_loss = criterion(D_output_fake, target_fake)

#     # Gradient backpropagation and optimization of D's parameters
#     D_loss = D_real_loss + D_fake_loss
#     D_loss.backward()
#     D_optimizer.step()

#     return D_loss.data.item()

def D_train(x, y, G, D, GaussianM, D_optimizer, criterion):
    """
    Run one optimisation step of the discriminator.

    D is trained as a K-way classifier: real images must be classified as their
    digit label ``y``, generated images as the extra "fake" class ``K - 1``.

    Args:
        x: batch of real images, flattened to (batch, 784) as the training loop does.
        y: integer digit labels of ``x``.
        G: generator; only run forward here, never updated.
        D: discriminator being updated.
        GaussianM: mixture module mapping (one-hot component, noise) to latents
            (the name shadows the class of the same name inside this function).
        D_optimizer: optimizer over D's parameters.
        criterion: classification loss over D's logits.

    Returns:
        float: D_real_loss + D_fake_loss for this batch.

    Relies on the notebook globals DEVICE, K and d.
    """
    G.train()
    D.train()
    D.zero_grad()

    batch_size = x.shape[0]
    fake_label = K - 1  # last discriminator output is the "fake" class

    # Real samples: classify as their true digit.
    x_real, y_real = x.to(DEVICE), y.to(DEVICE)
    D_real_loss = criterion(D(x_real), y_real)

    # Fake samples: one of the K - 1 real-class components per sample, plus N(0, I_d)
    # noise (torch.randn has the same distribution as MultivariateNormal(0, I_d)).
    k_values = torch.randint(0, fake_label, (batch_size,), device=DEVICE)
    k_one_hot = torch.eye(K, device=DEVICE)[k_values]
    z = torch.randn(batch_size, d, device=DEVICE)
    z_tilde = GaussianM(k_one_hot, z)

    # D's loss must not backpropagate into G; G is updated separately in G_train.
    with torch.no_grad():
        x_fake = G(z_tilde)

    target_fake = torch.full((batch_size,), fake_label, dtype=torch.long, device=DEVICE)
    D_fake_loss = criterion(D(x_fake), target_fake)

    # gradient backpropagation and optimization of D's parameters
    D_loss = D_real_loss + D_fake_loss
    D_loss.backward()
    D_optimizer.step()

    return D_loss.item()




def G_train(x, y, G, D, GaussianM, G_optimizer, criterion):
    """
    Run one optimisation step of the generator.

    For each sample a real-class mixture component k in [0, K - 1) is drawn, a latent
    is sampled from that component, and G is trained so that D classifies G's output
    as class k.

    Args:
        x: real batch; only its batch size is used.
        y: unused; kept for signature symmetry with D_train.
        G: generator being updated.
        D: discriminator used as the (fixed for this step) classifier.
        GaussianM: mixture module mapping (one-hot component, noise) to latents
            (the name shadows the class of the same name inside this function).
        G_optimizer: optimizer over G's parameters.
        criterion: classification loss over D's logits.

    Returns:
        float: the generator loss for this batch.

    Relies on the notebook globals DEVICE, K and d.
    """
    G.train()
    D.train()
    G.zero_grad()

    batch_size = x.shape[0]

    # One of the K - 1 real-class components per sample, plus N(0, I_d) noise.
    k_values = torch.randint(0, K - 1, (batch_size,), device=DEVICE)
    k_one_hot = torch.eye(K, device=DEVICE)[k_values]
    z = torch.randn(batch_size, d, device=DEVICE)
    z_tilde = GaussianM(k_one_hot, z)

    D_output = D(G(z_tilde))
    # Target is the sampled component index: the "true" class of each generated sample.
    G_loss = criterion(D_output, k_values)

    # Only G's parameters are stepped. D also receives gradients here, which
    # D_train clears with D.zero_grad() before its own step.
    G_loss.backward()
    G_optimizer.step()

    return G_loss.item()


In [None]:
def save_models(G, D, folder):
    """
    Save generator and discriminator weights to ``folder/G.pth`` and ``folder/D.pth``.

    Models may be wrapped in nn.DataParallel / DistributedDataParallel or not; the
    wrapper is stripped so saved keys have no ``module.`` prefix and can be loaded
    into a bare model. ``folder`` (and any missing parents) is created if needed.
    """
    os.makedirs(folder, exist_ok=True)
    wrappers = (nn.DataParallel, nn.parallel.DistributedDataParallel)
    for net, filename in ((G, 'G.pth'), (D, 'D.pth')):
        bare = net.module if isinstance(net, wrappers) else net
        torch.save(bare.state_dict(), os.path.join(folder, filename))

def load_model(model, folder, filename='G.pth'):
    """
    Load a state_dict saved by save_models into ``model`` and return it.

    Args:
        model: bare (un-wrapped) module whose keys match the checkpoint.
        folder: checkpoint directory.
        filename: file inside ``folder``; 'G.pth' (default) or 'D.pth'.

    Tensors are loaded to CPU first (so CUDA checkpoints also load on CPU-only
    machines); load_state_dict then copies them onto the model's own device.
    ``weights_only=True`` restricts unpickling to tensors/plain containers.
    """
    path = os.path.join(folder, filename)
    state_dict = torch.load(path, map_location='cpu', weights_only=True)
    model.load_state_dict(state_dict)
    return model


In [None]:
from torch.utils.data import DataLoader

# ToTensor gives pixels in [0, 1]; Normalize(0.5, 0.5) maps them to [-1, 1], the same
# range as the generator's tanh output.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,))]
)

train_dataset = datasets.MNIST(
    root='data/MNIST/',
    train=True,
    download=True,
    transform=transform,
)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)


In [None]:
from tqdm.notebook import trange
mnist_dim = 784  # 28 * 28 flattened MNIST image
K = 11           # discriminator logits: 10 digit classes + 1 "fake" class
d = 100          # latent dimension fed to the generator

G = torch.nn.DataParallel(Generator(g_output_dim=mnist_dim)).to(DEVICE)
D = torch.nn.DataParallel(Discriminator(mnist_dim, K)).to(DEVICE)
# NOTE(review): the training loop builds its own un-wrapped GaussianM instance and does
# not reference GM; confirm GM is unused before removing it.
GM = torch.nn.DataParallel(GaussianM(K, d)).to(DEVICE)

print('Model loaded.')

# Multi-class loss over the K discriminator logits (used for both D and G).
criterion = nn.CrossEntropyLoss()

# One Adam optimizer per network, same learning rate.
G_optimizer, D_optimizer = (optim.Adam(net.parameters(), lr=0.0005) for net in (G, D))



Model loaded.


In [None]:
# Gaussian mixture that turns (one-hot component, N(0, I) noise) into generator latents.
GaussianM_instance = GaussianM(K, d).to(DEVICE)

# No-ops when G and D were already moved to DEVICE at construction; kept so this cell
# is also safe on freshly built models.
G = G.to(DEVICE)
D = D.to(DEVICE)

epochs = 100
log_interval = 10  # save a checkpoint every `log_interval` epochs
G_l1 = []  # generator loss of every batch, across all epochs
D_l1 = []  # discriminator loss of every batch, across all epochs

for epoch in range(1, epochs + 1):
    G_loss_total = 0.0
    D_loss_total = 0.0

    for batch_idx, (x, y) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}")):
        x = x.view(-1, 784).to(DEVICE)  # flatten 1x28x28 MNIST images
        y = y.to(DEVICE)

        # D_train / G_train zero their own network's gradients before stepping.
        D_loss = D_train(x, y, G, D, GaussianM_instance, D_optimizer, criterion)
        D_loss_total += D_loss
        D_l1.append(D_loss)

        G_loss = G_train(x, y, G, D, GaussianM_instance, G_optimizer, criterion)
        G_loss_total += G_loss
        G_l1.append(G_loss)

    # Average losses for the epoch
    D_loss_avg = D_loss_total / len(train_loader)
    G_loss_avg = G_loss_total / len(train_loader)

    print(f"Epoch [{epoch}/{epochs}], D Loss: {D_loss_avg:.4f}, G Loss: {G_loss_avg:.4f}")

    if epoch % log_interval == 0:
        save_models(G, D, f'checkpoints/epoch_{epoch}')
        print(f"Models saved at epoch {epoch}")

print("Training complete.")


Epoch 1/100: 100%|██████████| 938/938 [00:23<00:00, 39.27it/s]


Epoch [1/100], D Loss: 0.8057, G Loss: 10.0247


Epoch 2/100: 100%|██████████| 938/938 [00:24<00:00, 38.93it/s]


Epoch [2/100], D Loss: 0.2100, G Loss: 14.9270


Epoch 3/100: 100%|██████████| 938/938 [00:24<00:00, 38.33it/s]


Epoch [3/100], D Loss: 0.1450, G Loss: 16.6012


Epoch 4/100: 100%|██████████| 938/938 [00:24<00:00, 38.47it/s]


Epoch [4/100], D Loss: 0.1135, G Loss: 17.6922


Epoch 5/100: 100%|██████████| 938/938 [00:24<00:00, 38.74it/s]


Epoch [5/100], D Loss: 0.0949, G Loss: 21.1240


Epoch 6/100: 100%|██████████| 938/938 [00:23<00:00, 39.24it/s]


Epoch [6/100], D Loss: 0.0851, G Loss: 22.3040


Epoch 7/100: 100%|██████████| 938/938 [00:23<00:00, 39.56it/s]


Epoch [7/100], D Loss: 0.0749, G Loss: 26.5313


Epoch 8/100: 100%|██████████| 938/938 [00:23<00:00, 39.36it/s]


Epoch [8/100], D Loss: 0.0665, G Loss: 27.8293


Epoch 9/100: 100%|██████████| 938/938 [00:23<00:00, 39.48it/s]


Epoch [9/100], D Loss: 0.0604, G Loss: 30.7992


Epoch 10/100: 100%|██████████| 938/938 [00:23<00:00, 39.72it/s]


Epoch [10/100], D Loss: 0.0605, G Loss: 24.4106
Models saved at epoch 10


Epoch 11/100: 100%|██████████| 938/938 [00:24<00:00, 38.61it/s]


Epoch [11/100], D Loss: 0.0511, G Loss: 20.8314


Epoch 12/100: 100%|██████████| 938/938 [00:23<00:00, 39.64it/s]


Epoch [12/100], D Loss: 0.0496, G Loss: 23.6366


Epoch 13/100: 100%|██████████| 938/938 [00:23<00:00, 39.67it/s]


Epoch [13/100], D Loss: 0.0484, G Loss: 26.4556


Epoch 14/100: 100%|██████████| 938/938 [00:23<00:00, 40.44it/s]


Epoch [14/100], D Loss: 0.0403, G Loss: 25.8721


Epoch 15/100: 100%|██████████| 938/938 [00:23<00:00, 40.18it/s]


Epoch [15/100], D Loss: 0.0396, G Loss: 25.9785


Epoch 16/100: 100%|██████████| 938/938 [00:23<00:00, 39.53it/s]


Epoch [16/100], D Loss: 0.0356, G Loss: 29.3819


Epoch 17/100: 100%|██████████| 938/938 [00:23<00:00, 39.67it/s]


Epoch [17/100], D Loss: 0.0385, G Loss: 28.6219


Epoch 18/100: 100%|██████████| 938/938 [00:23<00:00, 39.78it/s]


Epoch [18/100], D Loss: 0.0353, G Loss: 30.9732


Epoch 19/100: 100%|██████████| 938/938 [00:24<00:00, 38.40it/s]


Epoch [19/100], D Loss: 0.0341, G Loss: 31.5645


Epoch 20/100: 100%|██████████| 938/938 [00:23<00:00, 39.94it/s]


Epoch [20/100], D Loss: 0.0293, G Loss: 35.0524
Models saved at epoch 20


Epoch 21/100: 100%|██████████| 938/938 [00:23<00:00, 40.22it/s]


Epoch [21/100], D Loss: 0.0343, G Loss: 45.1287


Epoch 22/100: 100%|██████████| 938/938 [00:23<00:00, 39.76it/s]


Epoch [22/100], D Loss: 0.0318, G Loss: 38.6805


Epoch 23/100: 100%|██████████| 938/938 [00:23<00:00, 40.02it/s]


Epoch [23/100], D Loss: 0.0266, G Loss: 33.7753


Epoch 24/100: 100%|██████████| 938/938 [00:23<00:00, 40.18it/s]


Epoch [24/100], D Loss: 0.0273, G Loss: 43.4412


Epoch 25/100: 100%|██████████| 938/938 [00:23<00:00, 40.29it/s]


Epoch [25/100], D Loss: 0.0302, G Loss: 87.4698


Epoch 26/100: 100%|██████████| 938/938 [00:23<00:00, 40.45it/s]


Epoch [26/100], D Loss: 0.0267, G Loss: 82.4250


Epoch 27/100: 100%|██████████| 938/938 [00:24<00:00, 38.77it/s]


Epoch [27/100], D Loss: 0.0263, G Loss: 37.8208


Epoch 28/100: 100%|██████████| 938/938 [00:23<00:00, 39.72it/s]


Epoch [28/100], D Loss: 0.0266, G Loss: 33.3006


Epoch 29/100: 100%|██████████| 938/938 [00:23<00:00, 39.76it/s]


Epoch [29/100], D Loss: 0.0223, G Loss: 39.6538


Epoch 30/100: 100%|██████████| 938/938 [00:23<00:00, 39.93it/s]


Epoch [30/100], D Loss: 0.0252, G Loss: 50.7815
Models saved at epoch 30


Epoch 31/100: 100%|██████████| 938/938 [00:23<00:00, 39.87it/s]


Epoch [31/100], D Loss: 0.0255, G Loss: 89.3243


Epoch 32/100: 100%|██████████| 938/938 [00:23<00:00, 40.15it/s]


Epoch [32/100], D Loss: 0.0212, G Loss: 65.9016


Epoch 33/100: 100%|██████████| 938/938 [00:23<00:00, 39.97it/s]


Epoch [33/100], D Loss: 0.0249, G Loss: 55.3868


Epoch 34/100: 100%|██████████| 938/938 [00:23<00:00, 39.71it/s]


Epoch [34/100], D Loss: 0.0242, G Loss: 51.9406


Epoch 35/100: 100%|██████████| 938/938 [00:24<00:00, 38.36it/s]


Epoch [35/100], D Loss: 0.0217, G Loss: 57.7562


Epoch 36/100: 100%|██████████| 938/938 [00:23<00:00, 39.83it/s]


Epoch [36/100], D Loss: 0.0209, G Loss: 57.9094


Epoch 37/100: 100%|██████████| 938/938 [00:23<00:00, 40.29it/s]


Epoch [37/100], D Loss: 0.0235, G Loss: 55.2926


Epoch 38/100: 100%|██████████| 938/938 [00:23<00:00, 40.09it/s]


Epoch [38/100], D Loss: 0.0201, G Loss: 60.0606


Epoch 39/100: 100%|██████████| 938/938 [00:23<00:00, 39.53it/s]


Epoch [39/100], D Loss: 0.0226, G Loss: 52.1216


Epoch 40/100: 100%|██████████| 938/938 [00:23<00:00, 39.55it/s]


Epoch [40/100], D Loss: 0.0210, G Loss: 65.7538
Models saved at epoch 40


Epoch 41/100: 100%|██████████| 938/938 [00:23<00:00, 39.47it/s]


Epoch [41/100], D Loss: 0.0180, G Loss: 72.8326


Epoch 42/100: 100%|██████████| 938/938 [00:23<00:00, 39.78it/s]


Epoch [42/100], D Loss: 0.0239, G Loss: 64.7504


Epoch 43/100: 100%|██████████| 938/938 [00:24<00:00, 38.32it/s]


Epoch [43/100], D Loss: 0.0198, G Loss: 60.9828


Epoch 44/100: 100%|██████████| 938/938 [00:23<00:00, 39.68it/s]


Epoch [44/100], D Loss: 0.0228, G Loss: 57.5689


Epoch 45/100: 100%|██████████| 938/938 [00:23<00:00, 39.43it/s]


Epoch [45/100], D Loss: 0.0190, G Loss: 61.8253


Epoch 46/100: 100%|██████████| 938/938 [00:23<00:00, 39.44it/s]


Epoch [46/100], D Loss: 0.0224, G Loss: 59.5285


Epoch 47/100: 100%|██████████| 938/938 [00:23<00:00, 39.58it/s]


Epoch [47/100], D Loss: 0.0183, G Loss: 53.8752


Epoch 48/100: 100%|██████████| 938/938 [00:23<00:00, 39.14it/s]


Epoch [48/100], D Loss: 0.0145, G Loss: 51.9659


Epoch 49/100: 100%|██████████| 938/938 [00:24<00:00, 38.71it/s]


Epoch [49/100], D Loss: 0.0260, G Loss: 52.4113


Epoch 50/100: 100%|██████████| 938/938 [00:24<00:00, 38.49it/s]


Epoch [50/100], D Loss: 0.0211, G Loss: 262.8211
Models saved at epoch 50


Epoch 51/100: 100%|██████████| 938/938 [00:24<00:00, 38.19it/s]


Epoch [51/100], D Loss: 0.0173, G Loss: 391.8773


Epoch 52/100: 100%|██████████| 938/938 [00:23<00:00, 40.49it/s]


Epoch [52/100], D Loss: 0.0229, G Loss: 354.2618


Epoch 53/100: 100%|██████████| 938/938 [00:23<00:00, 40.28it/s]


Epoch [53/100], D Loss: 0.0126, G Loss: 337.9527


Epoch 54/100: 100%|██████████| 938/938 [00:23<00:00, 40.01it/s]


Epoch [54/100], D Loss: 0.0281, G Loss: 282.4945


Epoch 55/100: 100%|██████████| 938/938 [00:23<00:00, 39.71it/s]


Epoch [55/100], D Loss: 0.0241, G Loss: 371.1258


Epoch 56/100: 100%|██████████| 938/938 [00:23<00:00, 39.28it/s]


Epoch [56/100], D Loss: 0.0167, G Loss: 370.2001


Epoch 57/100: 100%|██████████| 938/938 [00:23<00:00, 39.52it/s]


Epoch [57/100], D Loss: 0.0211, G Loss: 398.1330


Epoch 58/100: 100%|██████████| 938/938 [00:23<00:00, 39.74it/s]


Epoch [58/100], D Loss: 0.0197, G Loss: 387.6213


Epoch 59/100: 100%|██████████| 938/938 [00:24<00:00, 38.36it/s]


Epoch [59/100], D Loss: 0.0189, G Loss: 410.7504


Epoch 60/100: 100%|██████████| 938/938 [00:23<00:00, 39.91it/s]


Epoch [60/100], D Loss: 0.0170, G Loss: 377.6673
Models saved at epoch 60


Epoch 61/100: 100%|██████████| 938/938 [00:23<00:00, 39.60it/s]


Epoch [61/100], D Loss: 0.0196, G Loss: 355.6962


Epoch 62/100: 100%|██████████| 938/938 [00:23<00:00, 39.72it/s]


Epoch [62/100], D Loss: 0.0210, G Loss: 321.1609


Epoch 63/100: 100%|██████████| 938/938 [00:23<00:00, 40.55it/s]


Epoch [63/100], D Loss: 0.0189, G Loss: 300.7090


Epoch 64/100: 100%|██████████| 938/938 [00:23<00:00, 39.92it/s]


Epoch [64/100], D Loss: 0.0145, G Loss: 276.7715


Epoch 65/100: 100%|██████████| 938/938 [00:23<00:00, 39.22it/s]


Epoch [65/100], D Loss: 0.0188, G Loss: 258.1354


Epoch 66/100: 100%|██████████| 938/938 [00:23<00:00, 39.45it/s]


Epoch [66/100], D Loss: 0.0195, G Loss: 245.7381


Epoch 67/100: 100%|██████████| 938/938 [00:24<00:00, 38.16it/s]


Epoch [67/100], D Loss: 0.0155, G Loss: 242.4549


Epoch 68/100: 100%|██████████| 938/938 [00:23<00:00, 39.58it/s]


Epoch [68/100], D Loss: 0.0194, G Loss: 227.5833


Epoch 69/100: 100%|██████████| 938/938 [00:23<00:00, 39.52it/s]


Epoch [69/100], D Loss: 0.0198, G Loss: 227.7100


Epoch 70/100: 100%|██████████| 938/938 [00:23<00:00, 39.38it/s]


Epoch [70/100], D Loss: 0.0178, G Loss: 254.6745
Models saved at epoch 70


Epoch 71/100: 100%|██████████| 938/938 [00:23<00:00, 39.37it/s]


Epoch [71/100], D Loss: 0.0183, G Loss: 200.1752


Epoch 72/100: 100%|██████████| 938/938 [00:23<00:00, 39.52it/s]


Epoch [72/100], D Loss: 0.0171, G Loss: 189.4076


Epoch 73/100: 100%|██████████| 938/938 [00:23<00:00, 39.72it/s]


Epoch [73/100], D Loss: 0.0182, G Loss: 174.4845


Epoch 74/100: 100%|██████████| 938/938 [00:24<00:00, 37.95it/s]


Epoch [74/100], D Loss: 0.0211, G Loss: 156.7890


Epoch 75/100: 100%|██████████| 938/938 [00:23<00:00, 39.19it/s]


Epoch [75/100], D Loss: 0.0181, G Loss: 151.7086


Epoch 76/100: 100%|██████████| 938/938 [00:23<00:00, 40.25it/s]


Epoch [76/100], D Loss: 0.0146, G Loss: 143.3821


Epoch 77/100: 100%|██████████| 938/938 [00:23<00:00, 39.97it/s]


Epoch [77/100], D Loss: 0.0237, G Loss: 151.4923


Epoch 78/100: 100%|██████████| 938/938 [00:23<00:00, 39.21it/s]


Epoch [78/100], D Loss: 0.0140, G Loss: 154.6417


Epoch 79/100: 100%|██████████| 938/938 [00:23<00:00, 39.70it/s]


Epoch [79/100], D Loss: 0.0178, G Loss: 144.5089


Epoch 80/100: 100%|██████████| 938/938 [00:23<00:00, 39.32it/s]


Epoch [80/100], D Loss: 0.0172, G Loss: 141.7183
Models saved at epoch 80


Epoch 81/100: 100%|██████████| 938/938 [00:23<00:00, 39.26it/s]


Epoch [81/100], D Loss: 0.0246, G Loss: 131.1794


Epoch 82/100: 100%|██████████| 938/938 [00:24<00:00, 37.92it/s]


Epoch [82/100], D Loss: 0.0167, G Loss: 136.6809


Epoch 83/100: 100%|██████████| 938/938 [00:23<00:00, 39.51it/s]


Epoch [83/100], D Loss: 0.0213, G Loss: 151.8109


Epoch 84/100:  31%|███       | 291/938 [00:07<00:16, 39.40it/s]


KeyboardInterrupt: 

In [None]:
import torch
import torchvision
import os



# Parameters
batch_size = 64  # images generated per forward pass
n_samples_target = 600  # total number of images written to disk

print('Model Loading...')
# Build a bare generator, load its weights, then wrap it as during training.
mnist_dim = 784
model = Generator(g_output_dim=mnist_dim).to(DEVICE)
model = load_model(model, 'checkpoints/epoch_80')
model = torch.nn.DataParallel(model).to(DEVICE)
model.eval()

# Latents must come from the same mixture the generator was trained on: D_train and
# G_train feed G GaussianM(one_hot, z) outputs, not raw N(0, I) noise.
latent_sampler = GaussianM(K, d).to(DEVICE)

print('Model loaded.')
print('Start Generating')
os.makedirs('samples_new', exist_ok=True)

n_samples = 0
with torch.no_grad():
    while n_samples < n_samples_target:
        # Only the K - 1 real-digit components, as in training.
        k_values = torch.randint(0, K - 1, (batch_size,), device=DEVICE)
        k_one_hot = torch.eye(K, device=DEVICE)[k_values]
        z = torch.randn(batch_size, d, device=DEVICE)
        x = model(latent_sampler(k_one_hot, z))
        # Undo Normalize(0.5, 0.5): tanh output is in [-1, 1], save_image expects [0, 1].
        x = ((x + 1) / 2).clamp(0, 1).reshape(batch_size, 1, 28, 28)
        for i in range(x.shape[0]):
            if n_samples < n_samples_target:
                torchvision.utils.save_image(x[i], os.path.join('samples_new', f'{n_samples}.png'))
                n_samples += 1

print('Generated samples done.')


Model Loading...
Model loaded.
Start Generating


  model.load_state_dict(torch.load(os.path.join(folder, 'G.pth')))


Generated samples done.
