# GAN mnist


In [1]:
#import
import os
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.optim as optim
import torchvision
from torch.utils.tensorboard import SummaryWriter


import numpy as np
import matplotlib.pyplot as plt

In [2]:
# Pick the compute device: use CUDA when PyTorch reports it as available,
# otherwise fall back to the CPU. `device` is kept as a plain string.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Using {device} device')

Using cpu device


In [3]:
class Discriminator(nn.Module):
    """Convolutional discriminator producing one sigmoid score per image.

    The network is three stride-2 convolutions (the last two followed by
    batch normalisation), each with a LeakyReLU(0.2), then a final 4x4
    convolution and a sigmoid. For 28x28 inputs the spatial size evolves as
    28 -> 14 -> 8 -> 4 -> 1, so the output has shape (N, 1, 1, 1).

    Args:
        nb_channels: number of channels of the input images.
    """

    def __init__(self, nb_channels):
        super().__init__()
        self.ngpu = 0
        slope = 0.2
        layers = [
            # (nb_channels) x 28 x 28 -> 64 x 14 x 14
            nn.Conv2d(nb_channels, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(slope, inplace=True),
            # 64 x 14 x 14 -> 128 x 8 x 8 (padding of 2 on this layer)
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=2, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(slope, inplace=True),
            # 128 x 8 x 8 -> 256 x 4 x 4
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(slope, inplace=True),
            # 256 x 4 x 4 -> 1 x 1 x 1
            nn.Conv2d(256, 1, kernel_size=4, stride=1, padding=0, bias=False),
            nn.Sigmoid(),
        ]
        # Unpacking keeps the sub-module indices (main.0, main.1, ...) in the
        # order listed above.
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Return the sigmoid scores computed by the layer stack for `input`."""
        scores = self.main(input)
        return scores
    
class Generator(nn.Module):
    """Transposed-convolution generator mapping latent noise to images.

    The input is expected as (N, input_dim, 1, 1). Four transposed
    convolutions grow the spatial size 1 -> 4 -> 8 -> 16 -> 28; the first
    three are followed by batch normalisation and ReLU, the last one by a
    tanh, so the output has shape (N, nb_channels, 28, 28) with values in
    [-1, 1].

    Args:
        input_dim: number of channels of the latent noise tensor.
        nb_channels: number of channels of the generated images.
    """

    def __init__(self, input_dim, nb_channels):
        super().__init__()
        self.ngpu = 0
        width = 128  # base number of feature maps
        layers = [
            # (input_dim) x 1 x 1 -> 512 x 4 x 4
            nn.ConvTranspose2d(input_dim, width * 4, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(width * 4),
            nn.ReLU(inplace=True),
            # 512 x 4 x 4 -> 256 x 8 x 8
            nn.ConvTranspose2d(width * 4, width * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(width * 2),
            nn.ReLU(inplace=True),
            # 256 x 8 x 8 -> 128 x 16 x 16
            nn.ConvTranspose2d(width * 2, width, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(width),
            nn.ReLU(inplace=True),
            # 128 x 16 x 16 -> (nb_channels) x 28 x 28; padding of 3 trims the
            # would-be 32x32 output down to 28x28.
            nn.ConvTranspose2d(width, nb_channels, kernel_size=4, stride=2, padding=3, bias=False),
            nn.Tanh(),
        ]
        # Unpacking keeps the sub-module indices (main.0, main.1, ...) in the
        # order listed above.
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Return the images generated by the layer stack from `input`."""
        images = self.main(input)
        return images

In [4]:
# Shape smoke test: push a batch of 32 latent vectors (100 channels each)
# through a fresh generator, then feed the generated images to a fresh
# discriminator, printing the tensor shape after each stage.
disc = Discriminator(1)
gen = Generator(100, 1)

noise = torch.randn(32, 100, 1, 1)

out = gen(noise)
print(out.shape)

out2 = disc(out)
print(out2.shape)

torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 1, 1])


In [5]:
# Hyper-parameters
lr = 2e-4          # Adam learning rate
momentum = 0.5     # Adam beta1 (first-moment decay), passed to both optimizers below
batch_size = 128

input_dim = 64     # number of channels of the latent noise fed to the generator
image_dim = 28 * 28
nb_channels = 1

num_epoch = 50

# Networks, moved to the selected device
disc = Discriminator(nb_channels).to(device)
gen = Generator(input_dim, nb_channels).to(device)

# Fixed noise, reused at every logging step so that the generated samples
# written to TensorBoard are comparable over the course of training
fixed_noise = torch.randn(batch_size, input_dim, 1, 1).to(device)

# TensorBoard writers for generated and real image grids
writer_fake = SummaryWriter("runs/DCGAN_MNIST/fake")
writer_real = SummaryWriter("runs/DCGAN_MNIST/real")

# ToTensor yields values in [0, 1]; Normalize with mean 0.5 / std 0.5 maps
# them to [-1, 1], the range of the generator's tanh output.
transformations = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

dataset = datasets.MNIST(root="dataset/", transform=transformations, download=True)  # load the dataset
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)  # batch creation

# `momentum` was previously declared but never used, so Adam silently ran
# with its default beta1 of 0.9. It is now passed as beta1; beta2 keeps the
# Adam default of 0.999.
opt_disc = optim.Adam(disc.parameters(), lr=lr, betas=(momentum, 0.999))
opt_gen = optim.Adam(gen.parameters(), lr=lr, betas=(momentum, 0.999))
criterion = nn.BCELoss()

step = 0


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [8]:
# Alternating GAN training: for every batch, first update the discriminator
# on real and generated samples, then update the generator against the
# freshly updated discriminator.
#
# The two updates are kept strictly separate. Previously both backward passes
# ran before either optimizer step, so the gradient of the generator loss was
# accumulated into the discriminator's parameters and applied by
# opt_disc.step(), pushing the discriminator towards labelling fakes as real.
for epoch in range(num_epoch):
    for batch_idx, (real, _) in enumerate(loader):
        # Real samples, shaped (N, 1, 28, 28) on the training device
        real = real.view(-1, 1, 28, 28).to(device)
        # Size of the current batch, stored under its own name so that the
        # global `batch_size` hyper-parameter is not overwritten
        cur_batch_size = real.shape[0]

        # Fake samples: draw one noise vector per real sample and generate
        noise = torch.randn(cur_batch_size, input_dim, 1, 1).to(device)
        fake = gen(noise)

        # ---- Discriminator update ----
        # Real samples should be classified as 1
        disc_real = disc(real).view(-1)
        loss_Dreal = criterion(disc_real, torch.ones_like(disc_real))
        # Fake samples should be classified as 0; detach() stops this loss
        # from back-propagating into the generator
        disc_fake = disc(fake.detach()).view(-1)
        loss_Dfake = criterion(disc_fake, torch.zeros_like(disc_fake))
        # Total discriminator loss
        loss_D = (loss_Dreal + loss_Dfake) / 2
        disc.zero_grad()
        loss_D.backward()
        opt_disc.step()

        # ---- Generator update ----
        # Score the same fakes again with the updated discriminator; the
        # generator wants them classified as 1. The gradients this leaves on
        # the discriminator are cleared by disc.zero_grad() on the next step.
        gen_scores = disc(fake).view(-1)
        loss_G = criterion(gen_scores, torch.ones_like(gen_scores))
        gen.zero_grad()
        loss_G.backward()
        opt_gen.step()

        if batch_idx % 10 == 0:
            print(
                f"Epoch [{epoch}/{num_epoch}] "
                f"Loss D: {loss_D.item():.4f}, Loss G: {loss_G.item():.4f}"
            )
            with torch.no_grad():
                # Generated samples always come from the same fixed noise
                fake = gen(fixed_noise).reshape(-1, 1, 28, 28)
                # Log the real batch of this step instead of building a new
                # loader iterator just to fetch one batch
                data = real.reshape(-1, 1, 28, 28)

                img_grid_real = torchvision.utils.make_grid(data, normalize=True)
                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                # One global step per batch, derived from the actual number
                # of batches in an epoch
                global_step = epoch * len(loader) + batch_idx
                writer_fake.add_image(
                    "Mnist fake images", img_grid_fake, global_step=global_step
                )
                writer_real.add_image(
                    "Mnist real images", img_grid_real, global_step=global_step
                )
                writer_real.flush()
                writer_fake.flush()

Epoch [0/50Loss D: 0.5660, Loss G: 0.3901
Epoch [0/50Loss D: 0.5516, Loss G: 0.4035
Epoch [0/50Loss D: 0.5701, Loss G: 0.3858
Epoch [0/50Loss D: 0.7946, Loss G: 0.2287
Epoch [0/50Loss D: 0.4930, Loss G: 0.4683
Epoch [0/50Loss D: 0.5228, Loss G: 0.4338
Epoch [0/50Loss D: 0.6191, Loss G: 0.3429
Epoch [0/50Loss D: 0.4957, Loss G: 0.4649
Epoch [0/50Loss D: 0.6589, Loss G: 0.3121
Epoch [0/50Loss D: 0.6032, Loss G: 0.3564
Epoch [0/50Loss D: 0.5004, Loss G: 0.4592
Epoch [0/50Loss D: 0.6037, Loss G: 0.3560
Epoch [0/50Loss D: 0.5006, Loss G: 0.4590
Epoch [0/50Loss D: 0.5669, Loss G: 0.3888
Epoch [0/50Loss D: 0.5343, Loss G: 0.4217
Epoch [0/50Loss D: 0.5721, Loss G: 0.3840
Epoch [0/50Loss D: 0.5522, Loss G: 0.4030
Epoch [0/50Loss D: 0.5462, Loss G: 0.4091
Epoch [0/50Loss D: 0.5547, Loss G: 0.4005
Epoch [0/50Loss D: 0.5510, Loss G: 0.4041
Epoch [0/50Loss D: 0.5597, Loss G: 0.3957
Epoch [0/50Loss D: 0.5604, Loss G: 0.3948
Epoch [0/50Loss D: 0.6075, Loss G: 0.3526
Epoch [0/50Loss D: 0.5036, Loss G:

KeyboardInterrupt: 

In [None]:
def output(Loss_D, Loss_G, changes, max_changes):
    """Print the current losses and log image grids to TensorBoard.

    Relies on the module-level `gen`, `fixed_noise`, `loader`, `writer_fake`
    and `writer_real`.

    Args:
        Loss_D: discriminator loss to report (formatted with 4 decimals).
        Loss_G: generator loss to report (formatted with 4 decimals).
        changes: current progress counter; printed and used as the
            TensorBoard global step.
        max_changes: total the counter is printed against.
    """
    # The closing "] " was missing, which glued the counter to "Loss D".
    print(
        f"Epoch [{changes}/{max_changes}] "
        f"Loss D: {Loss_D:.4f}, Loss G: {Loss_G:.4f}"
    )
    with torch.no_grad():
        # Generated samples from the fixed noise
        fake = gen(fixed_noise).reshape(-1, 1, 28, 28)
        # One batch of real images, taken from a fresh pass over the loader
        real, _ = next(iter(loader))
        data = real.reshape(-1, 1, 28, 28)

        img_grid_real = torchvision.utils.make_grid(data, normalize=True)
        img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
        writer_fake.add_image(
            "Mnist fake images", img_grid_fake, global_step=changes
        )
        writer_real.add_image(
            "Mnist real images", img_grid_real, global_step=changes
        )
        writer_real.flush()
        writer_fake.flush()
        
def train_disc(change_loss):
    """Train only the discriminator for at most one pass over `loader`.

    For each batch the discriminator is updated on real samples (target 1)
    and generated samples (target 0). The pass stops early as soon as the
    discriminator loss of a batch drops below `change_loss`. The generator is
    only used to produce samples; it receives no gradients here.

    Relies on the module-level `loader`, `gen`, `disc`, `criterion`,
    `opt_disc`, `input_dim` and `device`.

    Args:
        change_loss: discriminator-loss threshold that ends the pass early.

    Returns:
        (loss_G, loss_D) of the last processed batch as Python floats.
        loss_G is computed for reporting only, from the discriminator scores
        obtained before that batch's optimizer step.
    """
    for real, _ in loader:
        # Current batch size
        batch_size = real.shape[0]

        # Fake samples: generate one per real sample. no_grad avoids building
        # and back-propagating through the generator graph, which this phase
        # never uses.
        noise = torch.randn(batch_size, input_dim, 1, 1).to(device)
        with torch.no_grad():
            fake = gen(noise)
        # Discriminator scores for the fakes, which should be classified as 0
        disc_fake = disc(fake).view(-1)
        loss_Dfake = criterion(disc_fake, torch.zeros_like(disc_fake))

        # Real samples, which should be classified as 1
        real = real.view(-1, 1, 28, 28).to(device)
        disc_real = disc(real).view(-1)
        loss_Dreal = criterion(disc_real, torch.ones_like(disc_real))
        # Total discriminator loss
        loss_D = (loss_Dreal + loss_Dfake) / 2

        # Update the discriminator weights
        disc.zero_grad()
        loss_D.backward()
        opt_disc.step()

        # .item() works for CPU and CUDA tensors alike; .numpy() raises on
        # CUDA tensors.
        if loss_D.item() < change_loss:
            break

    # Generator loss on the last batch, for reporting only
    with torch.no_grad():
        loss_G = criterion(disc_fake, torch.ones_like(disc_fake))

    return loss_G.item(), loss_D.item()

def train_gen(change_loss):
    """Train only the generator for at most one pass over `loader`.

    For each batch, noise is turned into fake samples and the generator is
    updated so that the discriminator scores them as real (target 1). The
    pass stops early as soon as the generator loss of a batch drops below
    `change_loss`. The discriminator optimizer is never stepped here.

    Relies on the module-level `loader`, `gen`, `disc`, `criterion`,
    `opt_gen`, `input_dim` and `device`.

    Args:
        change_loss: generator-loss threshold that ends the pass early.

    Returns:
        (loss_G, loss_D) of the last processed batch as Python floats.
        loss_D is computed for reporting only.
    """
    for real, _ in loader:
        # Current batch size; the real images are otherwise only used for the
        # reporting loss after the loop
        batch_size = real.shape[0]

        # Fake samples: generate one per real sample
        noise = torch.randn(batch_size, input_dim, 1, 1).to(device)
        fake = gen(noise)
        # Discriminator scores for the fakes; the generator wants them to be 1
        disc_fake = disc(fake).view(-1)
        loss_G = criterion(disc_fake, torch.ones_like(disc_fake))

        # Update the generator weights
        gen.zero_grad()
        loss_G.backward()
        opt_gen.step()

        # .item() works for CPU and CUDA tensors alike; .numpy() raises on
        # CUDA tensors.
        if loss_G.item() < change_loss:
            break

    # Discriminator loss on the last batch, for reporting only: no graph is
    # needed, so it is computed without gradient tracking.
    with torch.no_grad():
        loss_Dfake = criterion(disc_fake, torch.zeros_like(disc_fake))
        # Real samples
        real = real.view(-1, 1, 28, 28).to(device)
        disc_real = disc(real).view(-1)
        loss_Dreal = criterion(disc_real, torch.ones_like(disc_real))
        # Total discriminator loss
        loss_D = (loss_Dreal + loss_Dfake) / 2

    return loss_G.item(), loss_D.item()
    

# TensorBoard writers for this experiment: generated and real image grids are
# logged to separate sub-directories of runs/GAN_MNIST.
writer_fake = SummaryWriter("runs/GAN_MNIST/fake")
writer_real = SummaryWriter("runs/GAN_MNIST/real")
        
def train(max_changes=10, change_loss=0.01):
    """Alternate discriminator-only and generator-only training phases.

    Training starts with the discriminator. A phase is repeated until the
    loss of the network being trained falls below `change_loss`; then the
    change counter is incremented and the other network takes over. The loop
    ends once `max_changes` switches have happened.

    Args:
        max_changes: number of phase switches to perform before stopping.
        change_loss: loss threshold that triggers a switch.
    """
    # NOTE(review): progress is reported only when the counter is a multiple
    # of 25, so with the default max_changes=10 `output` is never called.
    phase_fn = {'disc': train_disc, 'gen': train_gen}
    other = {'disc': 'gen', 'gen': 'disc'}

    changes = 0
    state = 'disc'
    while changes < max_changes:
        Loss_G, Loss_D = phase_fn[state](change_loss)
        # The loss that decides the switch is the one of the active network.
        watched = Loss_D if state == 'disc' else Loss_G
        if watched < change_loss:
            changes += 1
            if changes % 25 == 0:
                output(Loss_D, Loss_G, changes, max_changes)
            state = other[state]
            
# Run the alternating training with the default arguments of train().
train()

KeyboardInterrupt: 

In [9]:
# Persist the parameters (state_dict) of both networks, one file per network.
for module, path in ((disc, 'weights_disc'), (gen, 'weights_gen')):
    torch.save(module.state_dict(), path)