In [None]:
import torch

In [None]:
# Preferred compute device: the default CUDA device when one is visible, otherwise the CPU.
# NOTE(review): later cells use DEVICE instead of this name.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import os
from torchvision import datasets, transforms
from tqdm import trange
import argparse


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Generator(nn.Module):
    def __init__(self, g_output_dim, latent_dim=100, num_classes=10):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(latent_dim, 256)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features*2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features*2)
        self.fc4 = nn.Linear(self.fc3.out_features, g_output_dim)
        self.num_classes = num_classes

    def forward(self, x):
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.leaky_relu(self.fc3(x), 0.2)
        return torch.tanh(self.fc4(x))

class Discriminator(nn.Module):
    """Fully connected classifier producing ``K`` raw logits per input.

    Hidden widths shrink 1024 -> 512 -> 256 with LeakyReLU(0.2) after each hidden layer.
    The last layer has no activation, so outputs are unnormalised logits.

    Args:
        d_input_dim: size of each flattened input.
        K: number of output logits.
    """

    def __init__(self, d_input_dim, K=11):
        super().__init__()
        top_width = 1024
        self.fc1 = nn.Linear(d_input_dim, top_width)
        self.fc2 = nn.Linear(top_width, top_width // 2)
        self.fc3 = nn.Linear(top_width // 2, top_width // 4)
        self.fc4 = nn.Linear(top_width // 4, K)
        self.K = K

    def forward(self, x):
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.leaky_relu(layer(hidden), 0.2)
        return self.fc4(hidden)


class GaussianM(nn.Module):
    """Fixed (non-learnable) Gaussian mixture over the latent space.

    Component ``k`` has mean ``means[k]`` and the shared scale matrix ``sigma * I_d``.
    Given a selector ``k`` and noise ``z``, ``forward`` returns
    ``means[argmax(k)] + sigma_matrix @ z``. For standard normal ``z`` this is a draw
    from N(means[k], sigma^2 * I_d).

    NOTE: with the default all-zero means every component is identical, so the mixture
    reduces to a single N(0, sigma^2 * I_d). Pass ``means`` to give components distinct
    centres.

    Args:
        K: number of mixture components.
        d: latent dimensionality.
        sigma: per-dimension standard deviation (scale) of every component.
        means: optional ``(K, d)`` tensor/array of component means; zeros when omitted.
        device: device on which the buffers are created. When omitted, CUDA is used if
            available, otherwise the CPU.

    Raises:
        ValueError: if ``means`` is given with a shape other than ``(K, d)``.
    """

    def __init__(self, K, d, sigma=0.4, means=None, device=None):
        super(GaussianM, self).__init__()
        self.K = K
        self.d = d
        self.sigma = sigma  # std-dev scale; the resulting covariance is sigma**2 * I_d

        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        if means is None:
            means = torch.zeros(K, d)
        else:
            means = torch.as_tensor(means, dtype=torch.get_default_dtype())
            if tuple(means.shape) != (K, d):
                raise ValueError(f"means must have shape ({K}, {d}), got {tuple(means.shape)}")

        # Registered as buffers (not plain attributes) so that .to()/.cuda(), state_dict()
        # and DataParallel replicas all carry them along with the module.
        self.register_buffer('means', means.to(device))
        # Scale matrix applied to z (not the covariance itself).
        self.register_buffer('sigma_matrix', self.sigma * torch.eye(d, device=device))

    def forward(self, k, z):
        """Map component selectors and noise to latent mixture samples.

        Args:
            k: ``(batch, K)`` selector (typically one-hot); the argmax of each row picks
                the component.
            z: ``(batch, d)`` noise, assumed to be standard normal.

        Returns:
            ``(batch, d)`` tensor ``means[argmax(k)] + sigma_matrix @ z``.
        """
        mu = self.means[k.argmax(dim=1)]
        return mu + (self.sigma_matrix @ z.unsqueeze(-1)).squeeze(-1)


In [None]:
# def sample_from_gmm(batch_size, K, latent_dim, means, covariances, device):
#     """
#     Sample batch_size latent vectors z from a Gaussian mixture model.
#     """
#     gaussian_indices = torch.randint(0, K, (batch_size,)).to(device)
#     z = torch.empty(batch_size, latent_dim).to(device)
#     for i, idx in enumerate(gaussian_indices):
#         mean, covariance = means[idx], covariances[idx]
#         z[i] = torch.normal(mean, torch.sqrt(covariance.diag()))
#     return z, gaussian_indices


In [None]:
import time
import torch
import torchvision
import os
import numpy as np

# Compute device shared by the training helpers: CUDA when available, otherwise CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
d = 100  # dimensionality of the latent space
K = 11  # number of discriminator outputs: 10 digit classes + 1 "fake" class (index 10)


def _sample_latent_codes(batch_size, gm):
    """Draw one mixture component and its latent vector for each of ``batch_size`` fakes.

    Components are drawn uniformly from the ``K - 1`` real classes. Index ``K - 1`` is the
    discriminator's "fake" class and is never sampled.

    Args:
        batch_size: number of latent vectors to draw.
        gm: latent sampler called as ``gm(k_onehot, z)``.

    Returns:
        ``(k_values, z_tilde)``: ``k_values`` is a ``(batch_size,)`` long tensor of
        component indices; ``z_tilde`` is what ``gm`` returns for the one-hot encoding of
        ``k_values`` and fresh N(0, I_d) noise.
    """
    n_real_classes = K - 1
    k_values = torch.randint(0, n_real_classes, (batch_size,), device=DEVICE)
    k_onehot = torch.eye(K, device=DEVICE)[k_values]
    # torch.randn draws N(0, I_d) directly. This avoids constructing a MultivariateNormal
    # (and factorising its d x d covariance) on every batch.
    z = torch.randn(batch_size, d, device=DEVICE)
    return k_values, gm(k_onehot, z)


def D_train(x, y, G, D, GaussianM, D_optimizer, criterion):
    """Run one discriminator update on a real batch plus a freshly generated fake batch.

    Real samples are scored against their labels ``y``; fakes against the extra class
    index ``K - 1``. The generator is run under ``torch.no_grad()``: no graph is built
    through it, and this step leaves G's gradients untouched.

    Args:
        x: batch of real inputs accepted by ``D``.
        y: integer class labels for ``x``.
        G: generator module.
        D: discriminator module.
        GaussianM: latent sampler called as ``GaussianM(k_onehot, z)``.
        D_optimizer: optimizer over D's parameters.
        criterion: classification loss over D's logits (e.g. ``nn.CrossEntropyLoss``).

    Returns:
        float: the summed real + fake discriminator loss.
    """
    G.train()
    D.train()
    D.zero_grad()

    batch_size = x.size(0)

    # Real samples: classify into their true classes.
    x_real, y_real = x.to(DEVICE), y.to(DEVICE)
    D_real_loss = criterion(D(x_real), y_real)

    # Fake samples: classify into the dedicated fake class K - 1.
    _, z_tilde = _sample_latent_codes(batch_size, GaussianM)
    with torch.no_grad():
        x_fake = G(z_tilde)
    target_fake = torch.full((batch_size,), K - 1, dtype=torch.long, device=DEVICE)
    D_fake_loss = criterion(D(x_fake), target_fake)

    D_loss = D_real_loss + D_fake_loss
    D_loss.backward()
    D_optimizer.step()
    return D_loss.item()


def G_train(x, y, G, D, GaussianM, G_optimizer, criterion):
    """Run one generator update.

    Each fake is generated from a sampled mixture component ``k``. The loss asks ``D`` to
    classify that fake as class ``k``, and only ``G_optimizer`` steps.

    Args:
        x: real batch; only its batch size is used.
        y: unused, kept for interface compatibility with ``D_train``.
        G: generator module.
        D: discriminator module.
        GaussianM: latent sampler called as ``GaussianM(k_onehot, z)``.
        G_optimizer: optimizer over G's parameters.
        criterion: classification loss over D's logits.

    Returns:
        float: the generator loss.
    """
    G.train()
    D.train()
    G.zero_grad()

    k_values, z_tilde = _sample_latent_codes(x.size(0), GaussianM)
    D_output = D(G(z_tilde))
    # The target is the sampled component index ("the true y" for this fake).
    G_loss = criterion(D_output, k_values)

    # backward() also writes gradients into D's parameters, but only G_optimizer steps;
    # D_train zeroes D's gradients before its own update. The GaussianM sampler defined in
    # this notebook has no learnable parameters, so nothing else is optimised here.
    G_loss.backward()
    G_optimizer.step()
    return G_loss.item()


In [None]:
def save_models(G, D, folder):
    """Save generator and discriminator weights to ``folder``/G.pth and ``folder``/D.pth.

    Models wrapped in ``nn.DataParallel`` / ``DistributedDataParallel`` are unwrapped
    (``.module``) first. The saved keys then carry no ``module.`` prefix and load into a
    bare model. Unwrapped models are saved as-is. ``folder`` is created if needed.
    """
    os.makedirs(folder, exist_ok=True)
    wrappers = (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)
    for net, filename in ((G, 'G.pth'), (D, 'D.pth')):
        inner = net.module if isinstance(net, wrappers) else net
        torch.save(inner.state_dict(), os.path.join(folder, filename))


def load_model(model, folder, filename='G.pth'):
    """Load weights from ``folder``/``filename`` into ``model`` and return ``model``.

    Defaults to the generator checkpoint (G.pth).

    Tensors are first mapped to the CPU and then copied into ``model``'s existing
    parameters, so a checkpoint written on a GPU also loads on a CPU-only machine.
    ``weights_only=True`` restricts unpickling to tensors and plain containers, which is
    safer than full pickle loading.
    """
    state_dict = torch.load(os.path.join(folder, filename), map_location='cpu', weights_only=True)
    model.load_state_dict(state_dict)
    return model


In [None]:
from torch.utils.data import DataLoader

# ToTensor yields pixels in [0, 1]; Normalize(mean=0.5, std=0.5) then maps them to [-1, 1],
# the same range as the generator's tanh output.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# MNIST training split (downloaded on first use), served in shuffled batches of 64.
train_dataset = datasets.MNIST(
    root='data/MNIST/',
    train=True,
    download=True,
    transform=transform,
)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to data/MNIST/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:00<00:00, 16.1MB/s]


Extracting data/MNIST/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to data/MNIST/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 488kB/s]


Extracting data/MNIST/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:00<00:00, 4.47MB/s]


Extracting data/MNIST/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 5.25MB/s]


Extracting data/MNIST/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/MNIST/raw



In [None]:
from tqdm.notebook import trange

# Flattened 28x28 MNIST size, discriminator output count and latent size.
mnist_dim = 784
K = 11
d = 100

# latent_dim is tied to d so the generator's input width always matches the latent size
# used for sampling, instead of relying on Generator's default happening to equal d.
G = torch.nn.DataParallel(Generator(g_output_dim=mnist_dim, latent_dim=d)).to(DEVICE)
D = torch.nn.DataParallel(Discriminator(mnist_dim, K)).to(DEVICE)
# Gaussian mixture sampler; its means and sigma are fixed inside GaussianM.__init__.
# NOTE(review): the training loop below builds its own GaussianM instance and does not use GM.
GM = torch.nn.DataParallel(GaussianM(K, d)).to(DEVICE)

print('Model loaded.')

# define loss: D outputs K logits (10 real classes + 1 fake class)
criterion = nn.CrossEntropyLoss()

# define optimizers
G_optimizer = optim.Adam(G.parameters(), lr=0.0005)
D_optimizer = optim.Adam(D.parameters(), lr=0.0005)



Model loaded.


In [None]:
from tqdm import tqdm

In [None]:
# Instantiate your Gaussian Mixture Model
GaussianM_instance = GaussianM(K, d).to(DEVICE)

# Ensure your models are on the correct device
G = G.to(DEVICE)
D = D.to(DEVICE)

epochs = 100
log_interval = 10
G_l1=[]
D_l1=[]
# Loop over the number of epochs
for epoch in range(1, epochs + 1):
    G_loss_total = 0.0
    D_loss_total = 0.0

    # Loop over the training dataset
    for batch_idx, (x, y) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}")):
        x = x.view(-1, 784).to(DEVICE)  # Flatten the images (for MNIST)
        y = y.to(DEVICE)

        # Train the Discriminator
        D.zero_grad()  # Reset gradients
        D_loss = D_train(x, y, G, D, GaussianM_instance, D_optimizer, criterion)
        D_loss_total += D_loss
        D_l1.append(D_loss)

        # Train the Generator
        G.zero_grad()  # Reset gradients
        G_loss = G_train(x, y, G, D, GaussianM_instance, G_optimizer, criterion)
        G_loss_total += G_loss
        G_l1.append(G_loss)

    # Average losses for the epoch
    D_loss_avg = D_loss_total / len(train_loader)
    G_loss_avg = G_loss_total / len(train_loader)

    # Logging the progress
    print(f"Epoch [{epoch}/{epochs}], D Loss: {D_loss_avg:.4f}, G Loss: {G_loss_avg:.4f}")

    # Save the model at specified intervals
    if epoch % log_interval == 0:
        save_models(G, D, f'checkpoints/epoch_{epoch}')
        print(f"Models saved at epoch {epoch}")

print("Training complete.")


Epoch 1/100: 100%|██████████| 938/938 [00:24<00:00, 38.33it/s]


Epoch [1/100], D Loss: 0.8333, G Loss: 14.1446


Epoch 2/100: 100%|██████████| 938/938 [00:24<00:00, 39.08it/s]


Epoch [2/100], D Loss: 0.2221, G Loss: 17.4846


Epoch 3/100: 100%|██████████| 938/938 [00:24<00:00, 38.86it/s]


Epoch [3/100], D Loss: 0.1485, G Loss: 18.2082


Epoch 4/100: 100%|██████████| 938/938 [00:24<00:00, 38.75it/s]


Epoch [4/100], D Loss: 0.1214, G Loss: 18.4050


Epoch 5/100: 100%|██████████| 938/938 [00:23<00:00, 39.90it/s]


Epoch [5/100], D Loss: 0.1012, G Loss: 18.5369


Epoch 6/100: 100%|██████████| 938/938 [00:23<00:00, 39.91it/s]


Epoch [6/100], D Loss: 0.1197, G Loss: 17.3231


Epoch 7/100: 100%|██████████| 938/938 [00:23<00:00, 39.61it/s]


Epoch [7/100], D Loss: 0.0767, G Loss: 17.2471


Epoch 8/100: 100%|██████████| 938/938 [00:23<00:00, 39.27it/s]


Epoch [8/100], D Loss: 0.0693, G Loss: 22.4836


Epoch 9/100:  69%|██████▉   | 645/938 [00:16<00:07, 38.30it/s]


KeyboardInterrupt: 

In [None]:
import torch
import torchvision
import os

# Parameters
batch_size = 64  # Set your desired batch size
n_samples_target = 600  # Total number of samples to generate

print('Model Loading...')
# Model Pipeline: build on DEVICE (CUDA when available, else CPU) rather than hard-coding
# .cuda(), with the latent width tied to d, the size used during training.
mnist_dim = 784
model = Generator(g_output_dim=mnist_dim, latent_dim=d).to(DEVICE)
model = load_model(model, 'checkpoints/epoch_80')
model = torch.nn.DataParallel(model).to(DEVICE)
model.eval()

# The generator was trained on latents produced by GaussianM (mu_k + sigma * z, with
# sigma = 0.4 by default), not on raw N(0, I) noise. Sample the same way here so the
# inference inputs follow the training distribution.
latent_sampler = GaussianM(K, d).to(DEVICE)

print('Model loaded.')
print('Start Generating')
os.makedirs('samples_new', exist_ok=True)

n_samples = 0
with torch.no_grad():
    while n_samples < n_samples_target:
        # Same component choice as training: one of the K - 1 real classes.
        k_values = torch.randint(0, K - 1, (batch_size,), device=DEVICE)
        k_onehot = torch.eye(K, device=DEVICE)[k_values]
        z = torch.randn(batch_size, d, device=DEVICE)
        x = model(latent_sampler(k_onehot, z))
        # NOTE(review): x is in the tanh range [-1, 1], while save_image expects [0, 1] and
        # clamps, so every value below 0 is written as black. `x * 0.5 + 0.5` would undo
        # the training Normalize(0.5, 0.5). Left unchanged in case downstream evaluation
        # relies on the current output format; confirm.
        x = x.reshape(batch_size, 1, 28, 28)  # Reshape for saving
        for k in range(x.shape[0]):
            if n_samples < n_samples_target:
                torchvision.utils.save_image(x[k], os.path.join('samples_new', f'{n_samples}.png'))
                n_samples += 1

print('Generated samples done.')


Model Loading...
Model loaded.
Start Generating


  model.load_state_dict(torch.load(os.path.join(folder, 'G.pth')))


Generated samples done.
