In [1]:
import torch.nn as nn
import torch.nn.functional as F

# Inportar los datos

In [2]:
#cd 'data/'

In [3]:
#!git clone https://github.com/aitorzip/PyTorch-CycleGAN.git

In [4]:
#cd './PyTorch-CycleGAN/'

In [5]:
#%%sh 
#sh ./download_dataset summer2winter_yosemite

In [6]:
#!mv datasets/summer2winter_yosemite /content/drive/MyDrive/Colab\ Notebooks/IA/data/

# Crear  la red

## Red Residual

---



In [7]:
class ResidualBlock(nn.Module):
  def __init__(self,in_features):
    super(ResidualBlock,self).__init__()
  
    conv_block = [
                nn.ReflectionPad2d(1), #Mejor padding #Hacer padding es rellenar la red, El Reflection padding hace un mejor padding que un cero padding, porque conserva mas la distribucion de la imagen.
                nn.Conv2d(in_features, in_features, 3),
                nn.InstanceNorm2d(in_features), #Para las GAN es mejor instance norm porque batch norm no maneja bien el contraste, pero instance norm no es tan buen regulizador como batch norm
                nn.ReLU(True),
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_features,in_features,3),
                nn.InstanceNorm2d(in_features)
  ]

    self.conv_block = nn.Sequential(*conv_block)
  
  def forward(self,x):
    return self.conv_block(x) + x 

## Red Generativa

---



In [8]:
class Generator(nn.Module):
  def __init__(self,input_channels,out_channels,n_residual_blocks=9):
    super(Generator,self).__init__()

    model = [nn.ReflectionPad2d(3),
             nn.Conv2d(input_channels,64, 7), # ((I - 7 + 6)/1 )+ 1 = I El size de la imagen no se ve alterado
             nn.InstanceNorm2d(64),
             nn.ReLU(True)
    ]

    in_features = 64
    out_features = in_features * 2

    #Encoding (Compresion)

    for _ in range(2):
      model += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1), # I/2
                nn.InstanceNorm2d(out_features),
                nn.ReLU(True)
      ]
      in_features = out_features
      out_features = in_features * 2

    # Transformaciones residuales 
    for _ in range(n_residual_blocks):
      model += [ResidualBlock(in_features)]

    # Decoding (Aumentar)
    out_features = in_features // 2
    for _ in range(2):
      model += [nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1), # 2*I  #Capas de deconvolucion
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True)
      ] 

      in_features = out_features
      out_features = in_features //2
    
    #Salida
    model += [ nn.ReflectionPad2d(3),
              nn.Conv2d(64, out_channels, 7), #I
              nn.Tanh()
              ]

    self.model = nn.Sequential(*model)
  
  def forward(self,x):
    return self.model(x)



## Red Discriminativa

---



In [9]:
class Discriminator(nn.Module):
  " PatchGAN : Discrimina estilo o textura"
  def __init__(self,input_channels):
    super(Discriminator,self).__init__()

    model = [nn.Conv2d(input_channels, 64, 4, stride=2, padding=1), #I/2
             nn.LeakyReLU(0.2, inplace=True),
             ]

    model += [nn.Conv2d(64, 128, 4, stride=2, padding=1), #I/2
              nn.InstanceNorm2d(128),
             nn.LeakyReLU(0.2, inplace=True),
             ]
    model += [nn.Conv2d(128, 256, 4, stride=2, padding=1), #I/2
              nn.InstanceNorm2d(256),
             nn.LeakyReLU(0.2, inplace=True),
             ]
    model += [nn.Conv2d(256, 512, 4, padding=1), #I-1
              nn.InstanceNorm2d(512),
             nn.LeakyReLU(0.2, inplace=True),
             ]

    #Flatten 
    model += [nn.Conv2d(512,1, 4, padding=1)] #I-1

    self.model = nn.Sequential(*model)

  def forward(self,x):
    x = self.model(x)
    return F.avg_pool2d(x, x.size()[2:]).view(x.size()[0], -1) #-1

# Preparar el entrenamiento

---



In [10]:
import sys
sys.path.append('data/')

import os
import glob #Manejo de files
import random 
import itertools #Concatenaciones entre lista y iteradores
from PIL import Image

import torch

from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms

from utils import ReplayBuffer # Una forma de guardar las imagenes fake que va creando la red generativa y poder agregarla a la loss

from livelossplot import PlotLosses
from utils import Logger

In [11]:
class ImageDataset(Dataset):
  def __init__(self, base_dir, transform = None , split='train'):
    self.transform = transforms.Compose(transform)
    self.file_A = sorted(glob.glob(os.path.join(base_dir, '{}/A/*.*'.format(split))))
    self.file_B = sorted(glob.glob(os.path.join(base_dir, '{}/B/*.*'.format(split))))

  def __len__(self):
    return max(len(self.file_A), len(self.file_B))

  def __getitem__(self,idx):
    Image_A = self.transform(Image.open(self.file_A[idx]))
    Image_B = self.transform(Image.open(self.file_B[random.randint(0,len(self.file_B) - 1)]))
    return  {
        'A':Image_A,
        'B':Image_B
    }



In [12]:
epoch = 0
n_epoch = 200
batch_size = 2
lr = 0.0002
size = 256
input_channels = 3
output_channels = 3
decay_epoch = 100

cuda = torch.cuda.is_available()
n_cpu = 2

base_dir = 'data/summer2winter_yosemite'

device = torch.device('cuda' if cuda else 'cpu')

In [13]:
def weights_init_normal(m):
  if isinstance(m,nn.Conv2d) :
    torch.nn.init.normal(m.weight.data, 0.0, 0.02)
  elif isinstance(m,nn.BatchNorm2d) :
    torch.nn.init.normal(m.weight.data, 1.0, 0.02)
    torch.nn.init.constant(m.bias, 0.0)

In [14]:
class LambdaLR():
  def __init__(self,n_epoch, offset, decay_start_epoch):
    assert((n_epoch - decay_start_epoch) > 0)
    self.n_epoch = n_epoch
    self.decay_start_epoch = decay_start_epoch
    self.offset = offset

  def step(self, epoch):
    return 1 - max(0, epoch + self.offset - self.decay_start_epoch) / (self.n_epoch - self.decay_start_epoch)

In [15]:
netG_A2B = Generator(input_channels, output_channels)
netG_B2A = Generator(input_channels, output_channels)
netD_A   = Discriminator(input_channels)
netD_B   = Discriminator(input_channels)

#Inicializar pesos
netG_A2B.apply(weights_init_normal)
netG_B2A.apply(weights_init_normal)
netD_A.apply(weights_init_normal)
netD_B.apply(weights_init_normal)

if cuda:
  netG_A2B.to(device)
  netG_B2A.to(device)
  netD_A.to(device)
  netD_B.to(device)

#Perdida
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()

#Optimizadores
optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(),netG_B2A.parameters()), lr = lr, betas=(0.5,0.999))
optimizer_D_A = torch.optim.Adam(netD_A.parameters(), lr=lr, betas=(0.5,0.999))
optimizer_D_B = torch.optim.Adam(netD_B.parameters(), lr=lr, betas=(0.5,0.999))

#Schedulers (Actualizar el lr de forma dinamica durante el entrenamiento)
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=LambdaLR(n_epoch,epoch,decay_epoch).step)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(optimizer_D_A, lr_lambda=LambdaLR(n_epoch,epoch,decay_epoch).step)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(optimizer_D_B, lr_lambda=LambdaLR(n_epoch,epoch,decay_epoch).step)

  torch.nn.init.normal(m.weight.data, 0.0, 0.02)


In [16]:
# inputs y Targets 
Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor #Otra manera de llevar los tensores a gpu
target_real = Tensor(batch_size).fill_(1.0)
target_fake = Tensor(batch_size).fill_(0.0) 

fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

#Data Loader

transform = [#transforms.Resize((100,100)),
             transforms.Resize(int(size*1.12), Image.BICUBIC),
             transforms.RandomCrop(size),
             transforms.RandomHorizontalFlip(),
             transforms.ToTensor(),
             transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
             ]

dataloader = DataLoader(ImageDataset(base_dir, transform), 
                        batch_size=batch_size, shuffle=True, 
                        num_workers=n_cpu, drop_last=True)  #Drop_last evitar que el dataloader se quede sin imagenes y de un error


## Funciones de Perdidas

In [17]:
def Gen_GAN_loss(G, D, real, loss, target_real):
  fake = G(real)
  pred_fake = D(fake)
  L = loss(pred_fake,target_real)
  return L, fake

def Dis_GAN_loss(D2, fake2, real2, fake_2_buffer, loss, target_real, target_fake):
  pred_real = D2(real2)
  loss_D2_real = loss(pred_real, target_real)

  fake2 = fake_2_buffer.push_and_pop(fake2)
  pred_fake = D2(fake2.detach())
  loss_D2_fake = loss(pred_fake, target_fake)
  loss_D2 = (loss_D2_real + loss_D2_fake) * 0.5
  return loss_D2

def cycle_loss(G1, G2, real, loss):
  recovered = G2(G1(real))
  L = loss(recovered, real)
  return L

def identity_loss(G, real, loss):
  same = G(real)
  L = loss(same,real)
  return L

# Loop

In [18]:
from utils import Logger
logger = Logger(n_epoch, len(dataloader))
liveloss = PlotLosses()

In [19]:
for epoch in range(epoch, n_epoch):
  for i, batch in enumerate(dataloader):
    real_A = batch['A'].to(device)
    real_B = batch['B'].to(device)

    #Generativas 
    optimizer_G.zero_grad()

    loss_GAN_A2B, fake_B = Gen_GAN_loss(netG_A2B, netD_B, real_A, criterion_GAN, target_real)
    loss_GAN_B2A, fake_A = Gen_GAN_loss(netG_B2A, netD_A, real_B, criterion_GAN, target_real)

    loss_cycle_ABA = cycle_loss(netG_A2B, netG_B2A, real_A, criterion_cycle)
    loss_cycle_BAB = cycle_loss(netG_B2A, netG_A2B, real_B, criterion_cycle)

    loss_identity_B = identity_loss(netG_A2B, real_B, criterion_identity)
    loss_identity_A = identity_loss(netG_B2A, real_A, criterion_identity)

    loss_G = (loss_GAN_A2B + loss_GAN_B2A) + 10.0*(loss_cycle_ABA+loss_cycle_BAB) + 5.0*(loss_identity_A + loss_identity_B)
    loss_G.backward()
    optimizer_G.step()

    #Discriminativas
    optimizer_D_A.zero_grad()
    loss_D_A = Dis_GAN_loss(netD_A, fake_A, real_A, fake_A_buffer, criterion_GAN, target_real, target_fake)
    loss_D_A.backward()
    optimizer_D_A.step()

    optimizer_D_B.zero_grad()
    loss_D_B = Dis_GAN_loss(netD_B, fake_B, real_B, fake_B_buffer, criterion_GAN, target_real, target_fake)
    loss_D_B.backward()
    optimizer_D_B.step()


    log_values = {
        'loss_G':          loss_G,
        'loss_G_identity':(loss_identity_A + loss_identity_B),
        'loss_G_cycle':   (loss_cycle_ABA  + loss_cycle_BAB),
        'loss_G_GAN' :    (loss_GAN_A2B    + loss_GAN_B2A),
        'loss_D' :        (loss_D_A        + loss_D_B) 
    }

    logger.log(log_values, images={'real_A':real_A, 'real_B':real_B, 'fake_A':fake_A, 'fake_B':fake_B})
  
  liveloss.update(log_values)
  liveloss.draw()

  lr_scheduler_G.step()
  lr_scheduler_D_A.step()
  lr_scheduler_D_B.step()
    

  return F.mse_loss(input, target, reduction=self.reduction)


Epoch 001/200 [0003/0615] -- loss_G: 18.4073 | loss_G_identity: 1.1292 | loss_G_cycle: 1.1687 | loss_G_GAN: 1.0740 | loss_D: 0.7979 -- ETA: 59 days, 23:35:11.271396

KeyboardInterrupt: 