# Cycle Gan
adapted from https://www.kaggle.com/code/songseungwon/cyclegan-tutorial-from-scratch-monet-to-photo

In [1]:
import numpy as np

from torch.utils.data import DataLoader
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

from tqdm.notebook import tqdm

In [2]:
if torch.cuda.is_available():
    device = 'cuda'
    Tensor = torch.cuda.FloatTensor
else:
    device = 'cpu'
    Tensor = torch.Tensor



In [3]:
img_height = 256
img_width = 256
control_channels = 5
signal_channels = 1
control_shape = (control_channels, img_height, img_width)
signal_shape = (signal_channels, img_height, img_width)

# Fake Dataset

In [4]:
n_data=1000
control_data = torch.randn( (n_data, *control_shape)) #Gaussian
signal_data = torch.rand((n_data, *signal_shape))  #Uniform

val_control_data = torch.randn( (n_data, *control_shape)) #Gaussian
val_signal_data = torch.rand((n_data, *signal_shape))  #Uniform

In [5]:
fake_dataset =torch.cat([signal_data,control_data],dim=1) #one block, signal data at channel 0, control data at channels >0
val_fake_dataset =torch.cat([val_signal_data,val_control_data],dim=1)
batch_size = 4

In [6]:
dataloader = DataLoader(
    fake_dataset,
    batch_size=batch_size, # 1
    shuffle=True,
)

In [7]:
val_dataloader = DataLoader(
    val_fake_dataset,
    batch_size=batch_size, # 1
    shuffle=False,
)

In [8]:
for i, batch in enumerate(tqdm(dataloader)):
    # Set model input
    real_B = batch[:,0].unsqueeze(1)
    real_A = batch[:,1:]
    print(real_A.shape)
    print(real_B.shape)
    break

  0%|          | 0/250 [00:00<?, ?it/s]

torch.Size([4, 5, 256, 256])
torch.Size([4, 1, 256, 256])


In [9]:
#A->B control to signal

# Model

In [10]:
#Tunare inizializzazioni
def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02) # reset Conv2d's weight(tensor) with Gaussian Distribution
        if hasattr(m, 'bias') and m.bias is not None:
            torch.nn.init.constant_(m.bias.data, 0.0) # reset Conv2d's bias(tensor) with Constant(0)
        elif classname.find('BatchNorm2d') != -1:
            torch.nn.init.normal_(m.weight.data, 1.0, 0.02) # reset BatchNorm2d's weight(tensor) with Gaussian Distribution
            torch.nn.init.constant_(m.bias.data, 0.0) # reset BatchNorm2d's bias(tensor) with Constant(0)

In [11]:
class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()

        self.block = nn.Sequential(
            nn.ReflectionPad2d(1), # Pads the input tensor using the reflection of the input boundary
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features)
        )

    def forward(self, x):
        return x + self.block(x)


class GeneratorResNet(nn.Module):
    def __init__(self, input_shape, num_residual_block, output_shape):
        super(GeneratorResNet, self).__init__()

        channels = input_shape[0]
        target_channels = output_shape[0]
        # Initial Convolution Block
        out_features = 64
        model = [
            nn.ReflectionPad2d(channels),
            nn.Conv2d(channels, out_features, 7),
            nn.InstanceNorm2d(out_features),
            nn.ReLU(inplace=True)
        ]
        in_features = out_features

        # Downsampling
        for _ in range(2):
            out_features *= 2
            model += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features

        # Residual blocks
        for _ in range(num_residual_block):
            model += [ResidualBlock(out_features)]

        # Upsampling
        for _ in range(2):
            out_features //= 2
            model += [
                nn.Upsample(scale_factor=2), # --> width*2, heigh*2
                nn.Conv2d(in_features, out_features, 3, stride=1, padding=1),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features

        # Output Layer
        model += [nn.ReflectionPad2d(target_channels),
                  nn.Conv2d(out_features,  target_channels, 7),
                  #nn.Tanh() #tune this activation
                 ]

        # Unpacking
        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

In [12]:
G_AtoB= GeneratorResNet(control_shape,4,signal_shape).to(device)
G_BtoA= GeneratorResNet(signal_shape,4,control_shape).to(device)

In [13]:
G_AtoB.apply(weights_init_normal)
G_BtoA.apply(weights_init_normal)

GeneratorResNet(
  (model): Sequential(
    (0): ReflectionPad2d((1, 1, 1, 1))
    (1): Conv2d(1, 64, kernel_size=(7, 7), stride=(1, 1))
    (2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (3): ReLU(inplace=True)
    (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (5): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (8): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (9): ReLU(inplace=True)
    (10): ResidualBlock(
      (block): Sequential(
        (0): ReflectionPad2d((1, 1, 1, 1))
        (1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
        (2): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (3): ReLU(inplace=True)
        (4): ReflectionPad2d((1, 1, 1, 1))
        

In [14]:
test=G_AtoB(torch.randn( (8, *control_shape)).to(device) )
test.shape

torch.Size([8, 1, 256, 256])

In [15]:
del test

In [16]:
test=G_BtoA(torch.rand((8, *signal_shape)).to(device))
test.shape

torch.Size([8, 5, 256, 256])

In [17]:
del test

In [18]:
#Tunare numero di patch in output
class Discriminator(nn.Module):
    def __init__(self, input_shape):
        super(Discriminator, self).__init__()

        channels, height, width = input_shape

        # Calculate output shape of image discriminator (PatchGAN)
        self.output_shape = (1, height//2**4, width//2**4)

        def discriminator_block(in_filters, out_filters, normalize=True):
            """Returns downsampling layers of each discriminator block"""
            layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.InstanceNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *discriminator_block(channels, 64, normalize=False),
            *discriminator_block(64, 128),
            *discriminator_block(128,256),
            *discriminator_block(256,512),
            nn.ZeroPad2d((1,0,1,0)),
            nn.Conv2d(512, 1, 4, padding=1)
        )

    def forward(self, img):
        return self.model(img)

In [19]:
D_A=Discriminator(control_shape).to(device)
D_B=Discriminator(signal_shape).to(device)

In [20]:
D_A.apply(weights_init_normal)
D_B.apply(weights_init_normal)

Discriminator(
  (model): Sequential(
    (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (9): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): ZeroPad2d((1, 0, 1, 0))
    (12): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
  )
)

In [21]:
test=D_A(torch.randn( (8, *control_shape)).to(device) )
test.shape

torch.Size([8, 1, 16, 16])

In [22]:
D_A.output_shape

(1, 16, 16)

In [23]:
test=D_B(torch.rand((8, *signal_shape)).to(device))
test.shape

torch.Size([8, 1, 16, 16])

In [24]:
D_B.output_shape

(1, 16, 16)

# Training

optimizer

In [25]:


import itertools
lr = 0.0002
b1 = 0.5
b2 = 0.999

optimizer_G = torch.optim.Adam(
    itertools.chain(G_AtoB.parameters(), G_BtoA.parameters()), lr=lr, betas=(b1,b2)
)

optimizer_D_A = torch.optim.Adam(
    D_A.parameters(), lr=lr, betas=(b1,b2)
)
optimizer_D_B = torch.optim.Adam(
    D_B.parameters(), lr=lr, betas=(b1,b2)
)



Scheduler (da tunare/testare alternative)

In [26]:
class LambdaLR:
    def __init__(self, n_epochs, offset, decay_start_epoch):
        assert (n_epochs - decay_start_epoch) > 0, "Decay must start before the training session ends!"
        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        return 1.0 - max(0, epoch+self.offset - self.decay_start_epoch)/(self.n_epochs - self.decay_start_epoch)


In [27]:
n_epochs = 10
epoch = 0
decay_epoch = 5

In [28]:
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
    optimizer_G,
    lr_lambda=LambdaLR(n_epochs, epoch, decay_epoch).step
)

lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_A,
    lr_lambda=LambdaLR(n_epochs, epoch, decay_epoch).step
)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_B,
    lr_lambda=LambdaLR(n_epochs, epoch, decay_epoch).step
)

In [29]:
#Loss

In [30]:
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()

In [31]:
#Training

In [33]:


for epoch in range(epoch, n_epochs):
    for i, batch in enumerate(tqdm(dataloader)):

        # Set model input
        real_B = batch[:,0].unsqueeze(1).to(device)
        real_A = batch[:,1:].to(device)

        # Adversarial ground truths
        valid = Tensor(np.ones((real_A.size(0), *D_A.output_shape))) # requires_grad = False. Default.
        fake = Tensor(np.zeros((real_A.size(0), *D_A.output_shape))) # requires_grad = False. Default.

# -----------------
# Train Generators
# -----------------
        G_AtoB.train() # train mode
        G_BtoA.train() # train mode

        optimizer_G.zero_grad() # Integrated optimizer(G_AB, G_BA)

        # Identity Loss #non si puo' usare se input e output hanno canali diversi
        #loss_id_A = criterion_identity(G_BA(real_A), real_A) # If you put A into a generator that creates A with B,
        #loss_id_B = criterion_identity(G_AB(real_B), real_B) # then of course A must come out as it is.
                                                             # Taking this into consideration, add an identity loss that simply compares 'A and A' (or 'B and B').
        #loss_identity = (loss_id_A + loss_id_B)/2

        # GAN Loss
        fake_B = G_AtoB(real_A) # fake_B is fake-photo that generated by real monet-drawing
        loss_GAN_AB = criterion_GAN(D_B(fake_B), valid) # tricking the 'fake-B' into 'real-B'
        fake_A = G_BtoA(real_B)
        loss_GAN_BA = criterion_GAN(D_A(fake_A), valid) # tricking the 'fake-A' into 'real-A'

        loss_GAN = (loss_GAN_AB + loss_GAN_BA)/2

        # Cycle Loss
        recov_A = G_BtoA(fake_B) # recov_A is fake-monet-drawing that generated by fake-photo
        loss_cycle_A = criterion_cycle(recov_A, real_A) # Reduces the difference between the restored image and the real image
        recov_B = G_AtoB(fake_A)
        loss_cycle_B = criterion_cycle(recov_B, real_B)

        loss_cycle = (loss_cycle_A + loss_cycle_B)/2

# ------> Total Loss
        loss_G = loss_GAN + (10.0*loss_cycle) #+ (5.0*loss_identity) # multiply suggested weight(default cycle loss weight : 10, default identity loss weight : 5) #Tunare parametri

        loss_G.backward()
        optimizer_G.step()

# -----------------
# Train Discriminator A
# -----------------
        optimizer_D_A.zero_grad()

        loss_real = criterion_GAN(D_A(real_A), valid) # train to discriminate real images as real
        loss_fake = criterion_GAN(D_A(fake_A.detach()), fake) # train to discriminate fake images as fake

        loss_D_A = (loss_real + loss_fake)/2

        loss_D_A.backward()
        optimizer_D_A.step()

# -----------------
# Train Discriminator B
# -----------------
        optimizer_D_B.zero_grad()

        loss_real = criterion_GAN(D_B(real_B), valid) # train to discriminate real images as real
        loss_fake = criterion_GAN(D_B(fake_B.detach()), fake) # train to discriminate fake images as fake

        loss_D_B = (loss_real + loss_fake)/2

        loss_D_B.backward()
        optimizer_D_B.step()

# ------> Total Loss
        loss_D = (loss_D_A + loss_D_B)/2

# -----------------
# Show Progress
# -----------------
        if (i+1) % 50 == 0:
            print('[Epoch %d/%d] [Batch %d/%d] [D loss : %f] [G loss : %f - (adv : %f, cycle : %f)]'
                    %(epoch+1,n_epochs,       # [Epoch -]
                      i+1,len(dataloader),   # [Batch -]
                      loss_D.item(),       # [D loss -]
                      loss_G.item(),       # [G loss -]
                      loss_GAN.item(),     # [adv -]
                      loss_cycle.item(),   # [cycle -]
                     ))



  0%|          | 0/250 [00:00<?, ?it/s]

[Epoch 1/10] [Batch 50/250] [D loss : 0.373314] [G loss : 6.223432 - (adv : 0.893611, cycle : 0.532982)]
[Epoch 1/10] [Batch 100/250] [D loss : 0.295497] [G loss : 6.453386 - (adv : 1.115755, cycle : 0.533763)]
[Epoch 1/10] [Batch 150/250] [D loss : 0.208511] [G loss : 6.040005 - (adv : 0.703751, cycle : 0.533625)]
[Epoch 1/10] [Batch 200/250] [D loss : 0.156324] [G loss : 5.878373 - (adv : 0.587631, cycle : 0.529074)]
[Epoch 1/10] [Batch 250/250] [D loss : 0.128249] [G loss : 5.895085 - (adv : 0.639083, cycle : 0.525600)]


  0%|          | 0/250 [00:00<?, ?it/s]

[Epoch 2/10] [Batch 50/250] [D loss : 0.142328] [G loss : 6.056457 - (adv : 0.796089, cycle : 0.526037)]
[Epoch 2/10] [Batch 100/250] [D loss : 0.090149] [G loss : 6.230555 - (adv : 0.957894, cycle : 0.527266)]
[Epoch 2/10] [Batch 150/250] [D loss : 0.060314] [G loss : 6.273722 - (adv : 0.982296, cycle : 0.529143)]
[Epoch 2/10] [Batch 200/250] [D loss : 0.128777] [G loss : 6.370587 - (adv : 1.019404, cycle : 0.535118)]
[Epoch 2/10] [Batch 250/250] [D loss : 0.111018] [G loss : 5.885938 - (adv : 0.626019, cycle : 0.525992)]


  0%|          | 0/250 [00:00<?, ?it/s]

[Epoch 3/10] [Batch 50/250] [D loss : 0.091625] [G loss : 6.086421 - (adv : 0.811792, cycle : 0.527463)]
[Epoch 3/10] [Batch 100/250] [D loss : 0.471972] [G loss : 7.664742 - (adv : 2.415781, cycle : 0.524896)]
[Epoch 3/10] [Batch 150/250] [D loss : 0.248099] [G loss : 6.334882 - (adv : 1.029215, cycle : 0.530567)]
[Epoch 3/10] [Batch 200/250] [D loss : 0.143979] [G loss : 5.774282 - (adv : 0.486006, cycle : 0.528828)]
[Epoch 3/10] [Batch 250/250] [D loss : 0.162593] [G loss : 5.626543 - (adv : 0.345347, cycle : 0.528120)]


  0%|          | 0/250 [00:00<?, ?it/s]

[Epoch 4/10] [Batch 50/250] [D loss : 0.129090] [G loss : 6.043681 - (adv : 0.793050, cycle : 0.525063)]
[Epoch 4/10] [Batch 100/250] [D loss : 0.249046] [G loss : 6.655861 - (adv : 1.404075, cycle : 0.525179)]
[Epoch 4/10] [Batch 150/250] [D loss : 0.110049] [G loss : 6.129122 - (adv : 0.725287, cycle : 0.540383)]
[Epoch 4/10] [Batch 200/250] [D loss : 0.128982] [G loss : 5.741242 - (adv : 0.491140, cycle : 0.525010)]
[Epoch 4/10] [Batch 250/250] [D loss : 0.214282] [G loss : 5.598208 - (adv : 0.339012, cycle : 0.525920)]


  0%|          | 0/250 [00:00<?, ?it/s]

[Epoch 5/10] [Batch 50/250] [D loss : 0.156466] [G loss : 5.833733 - (adv : 0.544773, cycle : 0.528896)]
[Epoch 5/10] [Batch 100/250] [D loss : 0.125914] [G loss : 5.884059 - (adv : 0.633181, cycle : 0.525088)]
[Epoch 5/10] [Batch 150/250] [D loss : 0.131426] [G loss : 5.810683 - (adv : 0.555794, cycle : 0.525489)]
[Epoch 5/10] [Batch 200/250] [D loss : 0.090779] [G loss : 6.030189 - (adv : 0.642151, cycle : 0.538804)]
[Epoch 5/10] [Batch 250/250] [D loss : 0.200812] [G loss : 5.655812 - (adv : 0.281104, cycle : 0.537471)]


  0%|          | 0/250 [00:00<?, ?it/s]

[Epoch 6/10] [Batch 50/250] [D loss : 0.130704] [G loss : 6.075259 - (adv : 0.825993, cycle : 0.524927)]
[Epoch 6/10] [Batch 100/250] [D loss : 0.167869] [G loss : 5.845057 - (adv : 0.590510, cycle : 0.525455)]
[Epoch 6/10] [Batch 150/250] [D loss : 0.083368] [G loss : 6.008239 - (adv : 0.763652, cycle : 0.524459)]
[Epoch 6/10] [Batch 200/250] [D loss : 0.137497] [G loss : 5.751728 - (adv : 0.494208, cycle : 0.525752)]
[Epoch 6/10] [Batch 250/250] [D loss : 0.225003] [G loss : 5.530060 - (adv : 0.284348, cycle : 0.524571)]


  0%|          | 0/250 [00:00<?, ?it/s]

[Epoch 7/10] [Batch 50/250] [D loss : 0.143728] [G loss : 5.708829 - (adv : 0.456278, cycle : 0.525255)]
[Epoch 7/10] [Batch 100/250] [D loss : 0.218354] [G loss : 5.502083 - (adv : 0.252165, cycle : 0.524992)]
[Epoch 7/10] [Batch 150/250] [D loss : 0.138739] [G loss : 5.606451 - (adv : 0.351045, cycle : 0.525541)]
[Epoch 7/10] [Batch 200/250] [D loss : 0.159292] [G loss : 5.581925 - (adv : 0.284757, cycle : 0.529717)]
[Epoch 7/10] [Batch 250/250] [D loss : 0.110150] [G loss : 5.921391 - (adv : 0.626558, cycle : 0.529483)]


  0%|          | 0/250 [00:00<?, ?it/s]

[Epoch 8/10] [Batch 50/250] [D loss : 0.127674] [G loss : 5.808625 - (adv : 0.547172, cycle : 0.526145)]
[Epoch 8/10] [Batch 100/250] [D loss : 0.073336] [G loss : 5.905272 - (adv : 0.637895, cycle : 0.526738)]
[Epoch 8/10] [Batch 150/250] [D loss : 0.175760] [G loss : 6.007504 - (adv : 0.750491, cycle : 0.525701)]
[Epoch 8/10] [Batch 200/250] [D loss : 0.100383] [G loss : 5.927482 - (adv : 0.664845, cycle : 0.526264)]
[Epoch 8/10] [Batch 250/250] [D loss : 0.110575] [G loss : 6.140530 - (adv : 0.890576, cycle : 0.524995)]


  0%|          | 0/250 [00:00<?, ?it/s]

[Epoch 9/10] [Batch 50/250] [D loss : 0.167711] [G loss : 5.732486 - (adv : 0.460358, cycle : 0.527213)]
[Epoch 9/10] [Batch 100/250] [D loss : 0.117975] [G loss : 5.782455 - (adv : 0.533379, cycle : 0.524908)]
[Epoch 9/10] [Batch 150/250] [D loss : 0.156256] [G loss : 5.746130 - (adv : 0.485926, cycle : 0.526020)]
[Epoch 9/10] [Batch 200/250] [D loss : 0.070844] [G loss : 5.991928 - (adv : 0.725011, cycle : 0.526692)]
[Epoch 9/10] [Batch 250/250] [D loss : 0.068896] [G loss : 6.022654 - (adv : 0.733314, cycle : 0.528934)]


  0%|          | 0/250 [00:00<?, ?it/s]

[Epoch 10/10] [Batch 50/250] [D loss : 0.128581] [G loss : 5.872954 - (adv : 0.618117, cycle : 0.525484)]
[Epoch 10/10] [Batch 100/250] [D loss : 0.123084] [G loss : 5.926631 - (adv : 0.635403, cycle : 0.529123)]
[Epoch 10/10] [Batch 150/250] [D loss : 0.299900] [G loss : 5.520626 - (adv : 0.264330, cycle : 0.525630)]
[Epoch 10/10] [Batch 200/250] [D loss : 0.151180] [G loss : 6.001111 - (adv : 0.739586, cycle : 0.526152)]
[Epoch 10/10] [Batch 250/250] [D loss : 0.132248] [G loss : 5.954333 - (adv : 0.483141, cycle : 0.547119)]
