In [None]:
import torch, torchvision, os, PIL, pdb
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import make_grid
from tqdm.auto import tqdm
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
def show(tensor, name=' ',num=25):
  """Plot the first `num` images of a batch as a grid with 5 images per row.

  Args:
    tensor: batch of images; it is passed to `make_grid`, so it is presumably
      shaped (batch, channels, height, width) -- verify against callers.
    name: title drawn above the grid. A blank/whitespace name (the default)
      leaves the figure untitled.
    num: maximum number of images taken from the start of the batch.
  """
  # Drop the autograd history and move to host memory before plotting.
  data=tensor.detach().cpu()
  # make_grid output is permuted from channel-first to channel-last for imshow.
  grid=make_grid(data[:num],nrow=5).permute(1,2,0)
  # Values outside [0, 1] are clipped for display only.
  plt.imshow(grid.clip(0,1))
  # Previously `name` was accepted but ignored; use it to label the figure.
  if name and name.strip():
    plt.title(name)
  plt.show()

In [None]:
# --- Training configuration -------------------------------------------------
n_epochs=100
batch_size=128
lr=1e-4          # learning rate handed to both Adam optimizers
z_dim=200        # length of the generator's input noise vector
# Use the GPU when one is available, otherwise fall back to the CPU instead of
# failing on the first `.to(device)` call.
device="cuda" if torch.cuda.is_available() else "cpu"

# --- Training-loop bookkeeping ----------------------------------------------
current_step=0   # global step counter, incremented once per batch
crit_cycles=5    # critic updates performed per generator update
gen_losses=[]
crit_losses=[]
show_step=35     # every `show_step` steps: checkpoint, show samples, plot losses
save_step=35     # NOTE(review): not read by any visible cell



In [None]:
from math import tanh
class Generator(nn.Module):
  """Transposed-convolution generator mapping a noise vector to a 3-channel image.

  The noise is treated as a 1x1 feature map with `z_dim` channels. Six
  ConvTranspose2d layers grow it 1 -> 4 -> 8 -> 16 -> 32 -> 64 -> 128 pixels
  per side, using out = (n-1)*stride - 2*padding + kernel_size. The channel
  count shrinks from d_dim*32 down to d_dim*2 and finally to 3.

  Args:
    z_dim: number of noise channels expected by `forward`.
    d_dim: base channel width; hidden stages use 32x, 16x, 8x, 4x, 2x of it.
  """

  def __init__(self, z_dim=64, d_dim=16):
    super().__init__()
    self.z_dim = z_dim

    hidden_widths = [d_dim * mult for mult in (32, 16, 8, 4, 2)]

    layers = []
    in_ch = z_dim
    for stage, out_ch in enumerate(hidden_widths):
      # Stage 0 expands 1x1 -> 4x4 (stride 1, no padding); every later
      # stage doubles the spatial size (stride 2, padding 1).
      stride, padding = (1, 0) if stage == 0 else (2, 1)
      layers += [
        nn.ConvTranspose2d(in_ch, out_ch, 4, stride, padding),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(True),
      ]
      in_ch = out_ch

    # Output stage: 64x64 -> 128x128 with 3 channels, squashed to [-1, 1].
    layers += [
      nn.ConvTranspose2d(in_ch, 3, 4, 2, 1),
      nn.Tanh(),
    ]
    self.gen = nn.Sequential(*layers)

  def forward(self, noise):
    """Reshape `noise` to (batch, z_dim, 1, 1) and run it through the stack."""
    batch = len(noise)
    as_feature_map = noise.view(batch, self.z_dim, 1, 1)
    return self.gen(as_feature_map)

def gen_noise(num, z_dim, device='cuda'):
   """Return a (num, z_dim) tensor of standard-normal samples created on `device`."""
   shape = (num, z_dim)
   return torch.randn(shape, device=device)




In [None]:
class Critic(nn.Module):
  """Convolutional critic producing one unbounded score per input image.

  Five stride-2 Conv2d stages halve the spatial size each time
  (128 -> 64 -> 32 -> 16 -> 8 -> 4, using out = (n + 2*pad - ks)//stride + 1)
  while the channels grow d_dim -> 16*d_dim. A final 4x4 convolution reduces
  the 4x4 map to a single value.

  Args:
    d_dim: base channel width of the first convolution.
  """

  def __init__(self, d_dim=16):
    super().__init__()

    # Channel counts at each stage boundary: 3, d, 2d, 4d, 8d, 16d.
    channels = [3] + [d_dim * (2 ** k) for k in range(5)]

    stages = []
    for c_in, c_out in zip(channels[:-1], channels[1:]):
      stages.extend([
        nn.Conv2d(c_in, c_out, 4, 2, 1),
        nn.InstanceNorm2d(c_out),
        nn.LeakyReLU(0.2),
      ])
    # Head: 4x4 -> 1x1 with a single output channel, no activation.
    stages.append(nn.Conv2d(channels[-1], 1, 4, 1, 0))

    self.crit = nn.Sequential(*stages)

  def forward(self, image):
    """Score `image` and flatten the result to (batch, 1)."""
    scores = self.crit(image)
    batch = len(scores)
    return scores.view(batch, -1)

In [None]:
datasetfile="dataset"
# Inherit from the fully-qualified torch base class: this class reuses the
# name `Dataset`, so `class Dataset(Dataset)` would inherit from its own
# previous definition whenever the cell is re-run.
class Dataset(torch.utils.data.Dataset):
  """Image-folder dataset yielding (image tensor, file name) pairs.

  Args:
    path: directory whose entries are treated as image files.
    size: images are resized to `size` x `size`.
    lim: keep at most this many directory entries.
  """
  def __init__(self,path,size=128,lim=7000):
    self.sizes=[size,size]
    # List the directory passed by the caller (previously the module-level
    # `datasetfile` was listed regardless of `path`). os.listdir returns
    # entries in arbitrary order, so sort to make the `lim` subset stable.
    names=sorted(os.listdir(path))[:lim]
    self.labels=names
    self.items=[os.path.join(path,fname) for fname in names]
  def __len__(self):
    """Number of image files kept."""
    return len(self.items)
  def __getitem__(self,idx):
    """Return (float32 tensor of shape 3 x size x size scaled to [0, 1], file name)."""
    # convert() returns a fully loaded copy, so the file can be closed at once.
    with PIL.Image.open(self.items[idx]) as img:
      data=img.convert("RGB")
    data=np.asarray(torchvision.transforms.Resize(self.sizes)(data))
    # HWC uint8 -> CHW float32
    data = np.transpose(data,(2,0,1)).astype(np.float32, copy=False)
    data=torch.from_numpy(data).div(255)
    return data, self.labels[idx]


In [None]:
# Build the dataset from the `datasetfile` directory: images are resized to
# 128x128 and at most 7000 directory entries are kept.
ds=Dataset(datasetfile,size=128,lim=7000)

In [None]:
# Shuffled mini-batches of (image, file name) pairs.
dataloader=DataLoader(ds,batch_size=batch_size,shuffle=True)
# Models are created generator-first; the critic uses its default width.
gen=Generator(z_dim).to(device)
critic=Critic().to(device)

# One Adam optimizer per network, sharing the learning rate and betas.
gen_opt=torch.optim.Adam(gen.parameters(),lr=lr,betas=(0.5,0.9))
critic_opt=torch.optim.Adam(critic.parameters(),lr=lr,betas=(0.5,0.9))

# Sanity check: pull one batch and display the first images.
x,y=next(iter(dataloader))
show(x)

In [None]:
def get_gp(fake,real,crit,alpha,gamma=10):
  """Gradient penalty: gamma * mean((||d crit / d mix||_2 - 1)^2).

  Args:
    fake: generated images.
    real: real images, same shape as `fake`.
    crit: callable scoring a batch of images.
    alpha: per-sample mixing weights, broadcast against the images.
    gamma: penalty coefficient.

  Returns:
    Scalar tensor; the graph is kept so it can be back-propagated through.
  """
  # Blend real and fake images with weight `alpha` on the real ones.
  interpolated = alpha * real + (1 - alpha) * fake
  scores = crit(interpolated)

  # Gradient of the scores with respect to the blended images.
  (grads,) = torch.autograd.grad(
      outputs=scores,
      inputs=interpolated,
      grad_outputs=torch.ones_like(scores),
      create_graph=True,
      retain_graph=True,
  )

  # One L2 norm per sample, taken over all remaining dimensions.
  per_sample_norm = grads.view(len(grads), -1).norm(2, dim=1)
  penalty = (per_sample_norm - 1).pow(2).mean()
  return gamma * penalty

In [None]:
root="./data/"
def save_model(name):
  """Checkpoint the generator and critic together with their optimizers.

  Writes `{root}G-{name}.pkl` and `{root}C-{name}.pkl`. Each file holds a dict
  with the keys "epoch", "model_state_dict" and "optimizer_state_dict".

  Reads the module-level `epoch`, `gen`, `gen_opt`, `critic` and `critic_opt`,
  so all of them must be defined before calling.

  Args:
    name: suffix used in both checkpoint file names.
  """
  targets = (("G", gen, gen_opt), ("C", critic, critic_opt))
  for prefix, model, optimizer in targets:
    path = f"{root}{prefix}-{name}.pkl"
    # torch.save does not create missing directories; make sure the target
    # folder exists instead of failing on a fresh checkout.
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    torch.save(
        {
            "epoch":epoch,
            "model_state_dict":model.state_dict(),
            "optimizer_state_dict":optimizer.state_dict(),
        },path
    )
  print("model saved!")

In [None]:
def load_model(name):
  """Restore generator and critic weights and optimizer state from disk.

  Reads `{root}G-{name}.pkl` and `{root}C-{name}.pkl` and loads them into the
  module-level `gen`, `gen_opt`, `critic` and `critic_opt`, which must already
  exist.

  Args:
    name: suffix used in both checkpoint file names.
  """
  # NOTE: torch.load unpickles the file; only load checkpoints you trust.
  # map_location moves the stored tensors onto the current `device`, so a
  # checkpoint written on a GPU can be restored on a CPU-only machine.
  checkpoint=torch.load(f"{root}G-{name}.pkl",map_location=device)
  gen.load_state_dict(checkpoint["model_state_dict"])
  gen_opt.load_state_dict(checkpoint["optimizer_state_dict"])
  checkpoint=torch.load(f"{root}C-{name}.pkl",map_location=device)
  critic.load_state_dict(checkpoint["model_state_dict"])
  critic_opt.load_state_dict(checkpoint["optimizer_state_dict"])


In [None]:
# save_model reads the module-level `epoch`, so it has to exist before the call.
epoch=1
# Write a checkpoint named "test" before starting the training loop.
save_model("test")

In [None]:
# WGAN-GP training loop: for every batch the critic is updated `crit_cycles`
# times, then the generator once. Every `show_step` steps the models are
# checkpointed and samples plus loss curves are displayed.
n_epochs=15
for epoch in range(n_epochs):
  for real, _ in tqdm(dataloader):
    cur_bs=len(real)
    real=real.to(device)
    # NOTE(review): the dataset scales images to [0, 1] while the generator
    # ends in Tanh ([-1, 1]); confirm the range mismatch is intended.

    ###critic
    mean_critic_loss=0
    for _ in range(crit_cycles):
      critic_opt.zero_grad()
      # Create the noise on the configured device, not gen_noise's default.
      noise=gen_noise(cur_bs,z_dim,device=device)
      # The generator is not updated in this phase, so skip building its
      # autograd graph altogether.
      with torch.no_grad():
        fake=gen(noise)
      crit_fake_pred=critic(fake)
      crit_real_pred=critic(real)
      # alpha carries requires_grad so the blended images inside get_gp are
      # part of the autograd graph.
      alpha=torch.rand(len(real),1,1,1,requires_grad=True,device=device)
      gp=get_gp(fake,real,critic,alpha)
      critic_loss=crit_fake_pred.mean() - crit_real_pred.mean() + gp
      mean_critic_loss+=critic_loss.item() / crit_cycles
      # A fresh graph is built on every cycle, so it need not be retained.
      critic_loss.backward()
      critic_opt.step()
    crit_losses+=[mean_critic_loss]

    ##gen
    gen_opt.zero_grad()
    noise=gen_noise(cur_bs,z_dim,device=device)
    fake=gen(noise)
    critc_pred=critic(fake)
    # The generator tries to maximize the critic's score on its samples.
    gen_loss=-critc_pred.mean()
    gen_loss.backward()
    gen_opt.step()
    gen_losses+=[gen_loss.item()]

    ##stats
    if(current_step%show_step==0 and current_step>0):
      save_model("mygan")
      show(fake,name="fake")
      show(real,name="real")
      # Average over the most recent `show_step` recorded losses.
      gen_mean=sum(gen_losses[-show_step:])/show_step
      critc_mean=sum(crit_losses[-show_step:])/show_step
      print(f"Epoch:{epoch}: step:{current_step}: genloss:{gen_mean}: critcloss:{critc_mean}")
      plt.plot(
          range(len(gen_losses)),
          torch.tensor(gen_losses),
          label="generator losses"

      )
      plt.plot(
          range(len(crit_losses)),
          torch.tensor(crit_losses),
          label="critic losses"

      )
      plt.ylim(-200,200)
      plt.legend()
      plt.show()
    current_step+=1