In [None]:
import torch,pdb
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

In [None]:
# Report the PyTorch build and which CUDA hardware (if any) it can see.
print(torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
  # get_device_name(0) raises when no CUDA device is usable, so only query it
  # after availability has been confirmed.
  print(torch.cuda.get_device_name(0))
print(torch.cuda.get_arch_list())

In [None]:
#visualization function
def show(tensor,ch=1,size=(28,28),num=16):
  """Plot the first `num` images of a batch as a grid with 4 images per row.

  tensor: batch of images; it is reshaped to (-1, ch, *size), so flattened rows
          work (e.g. 128 x 784 -> 128 x 1 x 28 x 28 with the defaults).
  ch:     channels per image.
  size:   (height, width) of each image.
  num:    how many images, counted from the start of the batch, are drawn.
  """
  images = tensor.detach().cpu()            # drop autograd history, move to host memory
  images = images.view(-1, ch, *size)
  tiled = make_grid(images[:num], nrow=4)   # channels-first grid image
  plt.imshow(tiled.permute(1, 2, 0))        # reorder to channels-last for imshow
  plt.show()



In [None]:
#setup the main parameters and hyperparameters
ecpochs = 500        # number of passes over the dataset (spelling kept: the training loop reads this name)
current_step = 0     # global count of processed batches
info_step = 300      # report / visualize every `info_step` batches
mean_gen_loss = 0    # running loss averages, reset after each report
mean_disc_loss = 0

z_dim = 64           # length of the generator's input noise vector
lr= 0.00001
loss_function = nn.BCEWithLogitsLoss()   # expects raw logits as predictions

batch_size = 128
# Fall back to the CPU so the notebook still runs on machines without a usable GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# MNIST is downloaded into the current directory on first use
# (spelling of `dataloder` kept: later cells read this name).
dataloder = DataLoader(MNIST('.',download=True,transform=transforms.ToTensor()),shuffle=True,batch_size=batch_size)

In [None]:
from torch.nn.modules.activation import Sigmoid
#declare our model

#Generator
def gen_block(input,output):
  """Build one generator stage: Linear(input -> output), BatchNorm1d(output), in-place ReLU."""
  stage = [
      nn.Linear(input, output),
      nn.BatchNorm1d(output),
      nn.ReLU(inplace=True),
  ]
  return nn.Sequential(*stage)


class Generator(nn.Module):
  """Fully connected generator: noise vector -> flattened image with values in (0, 1).

  z_dim: length of the input noise vector.
  i_dim: length of the output vector (784 = 28x28 by default).
  h_dim: width of the first hidden layer; every following hidden layer doubles it.
  """
  def __init__(self,z_dim=64,i_dim=784,h_dim=128):
    super().__init__()

    # z_dim -> h -> 2h -> 4h -> 8h   (64 -> 128 -> 256 -> 512 -> 1024 with the defaults)
    widths = [z_dim, h_dim, h_dim*2, h_dim*4, h_dim*8]
    hidden = [gen_block(n_in, n_out) for n_in, n_out in zip(widths, widths[1:])]

    self.gen = nn.Sequential(
      *hidden,
      nn.Linear(widths[-1], i_dim),   # 1024 -> 784 (28x28)
      nn.Sigmoid(),                   # squash every output into (0, 1)
    )

  def forward(self,noise):
    """Map a batch of noise vectors to a batch of generated, flattened images."""
    generated = self.gen(noise)
    return generated


def gen_noise(number,z_dim,target_device=None):
  """Sample a (number, z_dim) batch of standard-normal noise.

  number:        how many noise vectors to draw.
  z_dim:         length of each noise vector.
  target_device: device to allocate the tensor on; when None (the default) the
                 notebook-level `device` is used, as before.
  """
  if target_device is None:
    target_device = device
  # Allocate directly on the target device instead of sampling on the CPU and
  # copying the result over. Note: the values then come from that device's
  # random generator.
  return torch.randn(number, z_dim, device=target_device)


#Discriminator
def discBlock(input,output):
  """Build one discriminator stage: Linear(input -> output) followed by LeakyReLU(0.2)."""
  linear = nn.Linear(input, output)
  activation = nn.LeakyReLU(0.2)
  return nn.Sequential(linear, activation)


class Discriminator(nn.Module):
  """Fully connected discriminator: flattened image -> one raw score per image.

  i_dim: length of the flattened input image (784 = 28x28 by default).
  h_dim: width of the last hidden layer; the earlier hidden layers are 4x and 2x this.
  """
  def __init__(self,i_dim=784,h_dim=256):
    super().__init__()
    # i_dim -> 4h -> 2h -> h   (784 -> 1024 -> 512 -> 256 with the defaults)
    widths = [i_dim, h_dim*4, h_dim*2, h_dim]
    stages = [discBlock(n_in, n_out) for n_in, n_out in zip(widths, widths[1:])]
    stages.append(nn.Linear(h_dim, 1))   # 256 -> 1, no activation afterwards
    self.disc = nn.Sequential(*stages)

  def forward(self,image):
    """Score a batch of flattened images; returns the output of the final linear layer."""
    scores = self.disc(image)
    return scores




In [None]:

#Instantiate the Generator (noise size z_dim) on the selected device, with its own Adam optimizer
gen = Generator(z_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(),lr=lr)

#Instantiate the Discriminator (default layer sizes) on the same device, with its own Adam optimizer
disc = Discriminator().to(device)
disc_opt = torch.optim.Adam(disc.parameters(),lr=lr)

In [None]:
# Display the generator's layer structure
gen

In [None]:
# Display the discriminator's layer structure
disc

In [None]:
import torch, sys
# Report which PyTorch build and which Python interpreter this kernel is using.
print("torch.__version__:", torch.__version__)
print("torch.version.cuda:", torch.version.cuda)
print("torch file:", torch.__file__)
print("python:", sys.executable)

# minimal CUDA test (must work whenever a GPU is available)
if torch.cuda.is_available():
  a = torch.ones(1, device="cuda")
  b = a + 1
  print("cuda add ok:", b)
else:
  # Allocating on "cuda" raises without a usable GPU; report it instead of
  # aborting the cell so the rest of the notebook can still run.
  print("cuda add skipped: CUDA is not available")


In [None]:
# Smoke test: pull one batch from the loader and push one noise batch through
# the generator and then the discriminator.
x,y = next(iter(dataloder))

noise = gen_noise(batch_size,z_dim)
fake = gen(noise)

di = disc(fake)  # not used below; the call only checks that the discriminator accepts generator output

print(x.shape)   # shape of the first element of the batch
print(y[:10])    # first ten entries of the second element (presumably the MNIST labels)
show(fake)       # grid of generated images



In [None]:
#calculating the loss

#generator loss
def calc_gen_loss(loss_fn,gen,disc,number,z_dim):
  """Generator loss: how far the discriminator is from scoring fresh fakes as real.

  loss_fn: callable taking (predictions, targets).
  gen:     generator, called on a noise batch obtained from gen_noise(number, z_dim).
  disc:    discriminator, called on the generator output.
  number:  how many fakes to generate.
  z_dim:   length of each noise vector.
  Returns loss_fn applied to the discriminator's predictions and all-ones targets.
  """
  fake_scores = disc(gen(gen_noise(number, z_dim)))
  # Targets of one: the loss is small when the fakes are scored as real.
  return loss_fn(fake_scores, torch.ones_like(fake_scores))

#discriminator loss
def calc_disc_loss(loss_fn,gen,disc,number,real,z_dim):
  """Discriminator loss: mean of the loss on generated images and on real images.

  loss_fn: callable taking (predictions, targets).
  gen:     generator, used only to produce fakes (it receives no gradient here).
  disc:    discriminator being trained.
  number:  how many fakes to generate.
  real:    batch of real images in the layout the discriminator expects.
  z_dim:   length of each noise vector.
  Returns (loss on fakes with all-zeros targets + loss on reals with all-ones targets) / 2.
  """
  # The generator is not updated by this loss, so skip recording its autograd
  # graph instead of building it and cutting it off with detach() afterwards.
  with torch.no_grad():
    fake = gen(gen_noise(number, z_dim))

  disc_fake = disc(fake)
  disc_fake_loss = loss_fn(disc_fake, torch.zeros_like(disc_fake))

  disc_real = disc(real)
  disc_real_loss = loss_fn(disc_real, torch.ones_like(disc_real))

  return (disc_fake_loss + disc_real_loss) / 2


In [None]:
# def save_checkpoint(epoch, current_step, gen, disc, gen_opt, disc_opt, filename="gan_checkpoint.pth"):
#     torch.save({
#         "epoch": epoch,
#         "current_step": current_step,
#         "gen_state_dict": gen.state_dict(),
#         "disc_state_dict": disc.state_dict(),
#         "gen_opt_state_dict": gen_opt.state_dict(),
#         "disc_opt_state_dict": disc_opt.state_dict(),
#     }, filename)


In [None]:
#### 60000/128 = 468.75 -> 469 steps in each epoch
#### each step processes 128 images = one batch (the last batch is smaller)
# Alternating GAN training: one discriminator update, then one generator update, per batch.
for epoch in range(ecpochs):
  for real,_ in tqdm(dataloder):
    ###discriminator
    disc_opt.zero_grad()

    current_batchsize = len(real) #real: 128 x 1 x 28 x 28
    real = real.view(current_batchsize,-1) # 128 x 784
    real = real.to(device)

    disc_loss = calc_disc_loss(loss_function,gen,disc,current_batchsize,real,z_dim)

    # Each loss comes from its own forward pass and is back-propagated exactly
    # once, so the graph does not need to be retained (retaining it only holds memory).
    disc_loss.backward()
    disc_opt.step()

    ###generator
    gen_opt.zero_grad()
    gen_loss = calc_gen_loss(loss_function,gen,disc,current_batchsize,z_dim)

    gen_loss.backward()
    gen_opt.step()


    #visualization & stats
    mean_disc_loss += disc_loss.item() / info_step
    mean_gen_loss += gen_loss.item() / info_step

    # Count the step before checking, so every report averages exactly
    # info_step batches (checking first made the first window one batch too long).
    current_step += 1
    if current_step % info_step == 0:
      # Preview only: no gradients are needed for these samples.
      with torch.no_grad():
        fake_noise = gen_noise(current_batchsize,z_dim)
        fake = gen(fake_noise)
      show(fake)
      show(real)
      print(f"{epoch}: ste: {current_step} / Gen Loss: {mean_gen_loss} / disc_loss: {mean_disc_loss}")
      mean_gen_loss,mean_disc_loss = 0,0



