# Pix2Pix Cycle GAN
Adapted from https://learnopencv.com/paired-image-to-image-translation-pix2pix/

In [None]:
import numpy as np

from torch.utils.data import DataLoader
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

from tqdm.notebook import tqdm

In [None]:
# Choose the compute device once; `Tensor` is the matching legacy float-tensor constructor.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
Tensor = torch.cuda.FloatTensor if device == 'cuda' else torch.Tensor



In [None]:
# Each sample is a waveform of `wave_shape` steps: control channels in, signal channel out.
wave_shape = 1000
control_channels, signal_channels = 5, 1
control_shape = (control_channels, wave_shape)  # (channels, length)
signal_shape = (signal_channels, wave_shape)    # (channels, length)

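# Fake Dataset

Control and signal are sampled independently (uniform and Gaussian noise), so there is no input–target relationship to learn: the best a model can do under MSE is predict the mean, giving a loss of about 1 (the variance of the Gaussian signal).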

In [None]:
n_data = 10000

# Training split: control ~ U[0, 1), signal ~ N(0, 1), drawn independently of each other.
control_data = torch.rand(n_data, *control_shape)
signal_data = torch.randn(n_data, *signal_shape)

# Validation split: same distributions and same size, drawn after the training split.
val_control_data = torch.rand(n_data, *control_shape)
val_signal_data = torch.randn(n_data, *signal_shape)

In [None]:
# Pack each split into one tensor of shape (n_data, signal_channels + control_channels, wave_shape):
# channel 0 holds the signal, channels 1.. hold the control inputs.
fake_dataset, val_fake_dataset = (
    torch.cat((signal, control), dim=1)
    for signal, control in ((signal_data, control_data), (val_signal_data, val_control_data))
)
batch_size = 128

In [None]:
# Training batches, reshuffled on every pass.
dataloader = DataLoader(
    dataset=fake_dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [None]:
# Validation batches in a fixed order.
val_dataloader = DataLoader(dataset=val_fake_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Peek at one batch to check the split: A = control channels (model input), B = signal channel (target).
for i, batch in enumerate(tqdm(dataloader)):
    real_A, real_B = batch[:, 1:], batch[:, :1]
    print(real_A.shape, real_B.shape, sep='\n')
    break

  0%|          | 0/79 [00:00<?, ?it/s]

torch.Size([128, 5, 1000])
torch.Size([128, 1, 1000])


In [None]:
# Direction A -> B: control channels (A) to signal (B)

# Model

In [None]:
def init_weights(net, init_type='normal', scaling=0.02):
    """Initialise conv and batch-norm weights of every sub-module of ``net`` in place.

    Convolution weights (any module whose class name contains ``'Conv'``) are
    initialised with the scheme named by ``init_type``. Batch-norm layers of any
    dimensionality (class name contains ``'BatchNorm'``) get their scale drawn
    from N(1, scaling) and their shift set to 0. Layers whose ``weight`` is
    ``None`` (e.g. ``BatchNorm1d(affine=False)``) are skipped.

    Args:
        net: module whose sub-modules are visited with ``net.apply``.
        init_type: ``'normal'`` (N(0, scaling)), ``'xavier'`` (Xavier-normal,
            gain ``scaling``), ``'kaiming'`` (Kaiming-normal, fan-in; ``scaling``
            unused) or ``'orthogonal'`` (gain ``scaling``).
        scaling: std / gain for the conv init and std of the batch-norm scale.

    Raises:
        NotImplementedError: if ``init_type`` is not one of the names above.
    """
    conv_inits = {
        'normal': lambda w: torch.nn.init.normal_(w, 0.0, scaling),
        'xavier': lambda w: torch.nn.init.xavier_normal_(w, gain=scaling),
        'kaiming': lambda w: torch.nn.init.kaiming_normal_(w, a=0, mode='fan_in'),
        'orthogonal': lambda w: torch.nn.init.orthogonal_(w, gain=scaling),
    }
    if init_type not in conv_inits:
        # Previously any value was silently treated as 'normal'.
        raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
    init_conv = conv_inits[init_type]

    def init_func(m):  # applied to every sub-module by net.apply
        classname = m.__class__.__name__
        if classname.find('Conv') != -1 and getattr(m, 'weight', None) is not None:
            init_conv(m.weight.data)
        elif classname.find('BatchNorm') != -1 and getattr(m, 'weight', None) is not None:
            # Match BatchNorm1d/2d/3d: the 1D models in this notebook use BatchNorm1d,
            # which a 'BatchNorm2d'-only check never reaches.
            torch.nn.init.normal_(m.weight.data, 1.0, scaling)
            torch.nn.init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)

## Generator

### Convolutional UNet

In [None]:
class Decoder(nn.Module):
    """Stack of stride-1 1D convolutions mapping ``in_channels`` to a single channel.

    Three conv + LeakyReLU(0.2) stages widen the features to 64, 128 and 256
    channels. A final conv with no activation then projects to 1 channel. Every
    conv pads by ``kernel_size // 2``, so the sequence length is preserved when
    ``kernel_size`` is odd.
    """

    def __init__(self, in_channels, kernel_size=5):
        super().__init__()
        pad = kernel_size // 2

        def conv(c_in, c_out):
            return nn.Conv1d(c_in, c_out, kernel_size=kernel_size, stride=1, padding=pad)

        # Keep this registration order: it fixes the state_dict layout and the
        # order in which seeded initialisation consumes random numbers.
        self.conv1 = conv(in_channels, 64)
        self.relu1 = nn.LeakyReLU(0.2, inplace=True)
        self.conv2 = conv(64, 128)
        self.relu2 = nn.LeakyReLU(0.2, inplace=True)
        self.conv3 = conv(128, 256)
        self.relu3 = nn.LeakyReLU(0.2, inplace=True)
        self.conv4 = conv(256, 1)

    def _forward_features(self, x):
        stages = (
            (self.conv1, self.relu1),
            (self.conv2, self.relu2),
            (self.conv3, self.relu3),
        )
        for conv, act in stages:
            x = act(conv(x))
        return self.conv4(x)

    def forward(self, x):
        return self._forward_features(x)

In [None]:
generator = Decoder(in_channels=5).to(device)  # 5 control channels in

In [None]:
# Smoke test: push a random batch of 8 control samples through the model.
test = generator(torch.randn(8, *control_shape).to(device))
test.shape

torch.Size([8, 1, 1000])

In [None]:
del generator, test  # drop the smoke-test model and its output

In [None]:
class ResidualBlock(nn.Module):
    """Fully connected residual block computing ``x + Linear(act(Linear(x)))``.

    The inner layers map ``in_features -> out_features -> in_features``, so the
    skip connection always adds tensors of the same width.

    Args:
        in_features: width of the input (and of the output).
        out_features: width of the hidden layer inside the block.
        activation: module applied between the two linear layers. ``None`` (the
            default) creates a fresh ``nn.ReLU()`` per block. A module-instance
            default would be created once at definition time and shared by every
            block.
    """

    def __init__(self, in_features, out_features, activation=None):
        super(ResidualBlock, self).__init__()
        if activation is None:
            activation = nn.ReLU()

        self.block = nn.Sequential(
            nn.Linear(in_features, out_features),
            activation,
            nn.Linear(out_features, in_features),
        )

    def forward(self, x):
        return x + self.block(x)

class ResMLP(nn.Module):
    """Residual MLP mapping a flattened multichannel input to a 1-channel output.

    Args:
        input_size: number of elements per sample after flattening all non-batch
            dimensions (e.g. ``channels * length``).
        hidden_size: width of the hidden representation.
        output_size: number of output values per sample.
        num_blocks: number of ``ResidualBlock(hidden_size, hidden_size)`` layers.

    ``forward`` returns a tensor of shape ``(batch, 1, output_size)``.
    """

    def __init__(self, input_size, hidden_size, output_size, num_blocks):
        super(ResMLP, self).__init__()

        self.input_layer = nn.Linear(input_size, hidden_size)
        self.hidden_layers = nn.ModuleList([
            ResidualBlock(hidden_size, hidden_size) for _ in range(num_blocks)
        ])
        self.output_layer = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # reshape (unlike view) also accepts non-contiguous inputs, e.g. transposed
        # batches, and still returns a view whenever one is possible.
        x = self.input_layer(x.reshape(x.shape[0], -1))
        for block in self.hidden_layers:
            x = block(x)
        return self.output_layer(x).unsqueeze(1)


In [None]:
# Residual-MLP hyper-parameters. Sizes are derived from the waveform constants so
# they stay in sync with control_shape / signal_shape.
input_size = control_channels * wave_shape  # flattened control input (5 * 1000)
hidden_size = 128
output_size = wave_shape  # ResMLP emits (batch, 1, output_size): one signal channel
num_blocks = 3
generator = ResMLP(input_size, hidden_size, output_size, num_blocks).to(device)

In [None]:
# Smoke test: random batch of 8 control samples through the ResMLP.
test = generator(torch.randn(8, *control_shape).to(device))
test.shape

torch.Size([8, 1, 1000])

In [None]:
del generator, test  # drop the smoke-test model and its output

In [None]:
class conbr_block(nn.Module):
    """Conv1d -> BatchNorm1d -> ReLU.

    Padding is ``dilation * (kernel_size - 1) // 2``, i.e. "same" padding for odd
    kernel sizes. With ``stride=1`` the sequence length is preserved, and with
    ``stride=s`` it becomes ``ceil(length / s)``. A fixed padding of 3 would only
    fit ``kernel_size=7, dilation=1``, and gives the same result there.

    Args:
        in_layer: input channels.
        out_layer: output channels.
        kernel_size: convolution kernel size (odd for exact "same" padding).
        stride: convolution stride.
        dilation: convolution dilation.
    """

    def __init__(self, in_layer, out_layer, kernel_size, stride, dilation):
        super(conbr_block, self).__init__()

        padding = dilation * (kernel_size - 1) // 2
        self.conv1 = nn.Conv1d(in_layer, out_layer, kernel_size=kernel_size, stride=stride,
                               dilation=dilation, padding=padding, bias=True)
        self.bn = nn.BatchNorm1d(out_layer)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv1(x)))

class se_block(nn.Module):
    """Squeeze-and-excitation gate for 1D feature maps.

    1. Squeeze: global-average-pool each channel.
    2. Excitation: pass the pooled values through a 1x1-conv bottleneck of width
       ``out_layer // 8`` with ReLU, then a sigmoid.
    3. Rescale every channel of the input by its gate value.

    Args:
        in_layer: channels of the input (and of the output).
        out_layer: sizes the bottleneck (``out_layer // 8``) and the unused ``fc``/``fc2``.
    """

    def __init__(self, in_layer, out_layer):
        super(se_block, self).__init__()

        self.conv1 = nn.Conv1d(in_layer, out_layer//8, kernel_size=1, padding=0)
        self.conv2 = nn.Conv1d(out_layer//8, in_layer, kernel_size=1, padding=0)
        # fc / fc2 are not used by forward(). They are kept so the state_dict keys
        # and the seeded initialisation order stay unchanged.
        self.fc = nn.Linear(1,out_layer//8)
        self.fc2 = nn.Linear(out_layer//8,out_layer)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Squeeze: (batch, channels, length) -> (batch, channels, 1).
        x_se = nn.functional.adaptive_avg_pool1d(x, 1)
        # Excitation: one gate in (0, 1) per channel.
        x_se = self.sigmoid(self.conv2(self.relu(self.conv1(x_se))))
        # Scale each channel by its gate (broadcast over the length axis).
        # Adding the gate instead would only shift every channel by a constant in
        # (0, 1), which is not squeeze-and-excitation.
        return x * x_se

class re_block(nn.Module):
    """Residual unit: two stride-1 conbr_blocks and an se_block, added back onto the input.

    The skip connection adds the input to the branch output, which has
    ``out_layer`` channels. The input is therefore expected to already have
    ``out_layer`` channels.
    """

    def __init__(self, in_layer, out_layer, kernel_size, dilation):
        super().__init__()
        # Keep this registration order (cbr1, cbr2, seblock): it fixes the
        # state_dict layout and the seeded initialisation order.
        self.cbr1 = conbr_block(in_layer, out_layer, kernel_size, stride=1, dilation=dilation)
        self.cbr2 = conbr_block(out_layer, out_layer, kernel_size, stride=1, dilation=dilation)
        self.seblock = se_block(out_layer, out_layer)

    def forward(self, x):
        branch = self.seblock(self.cbr2(self.cbr1(x)))
        return x + branch

class UNET_1D(nn.Module):
    """1D U-Net built from conbr_block / re_block stages.

    Encoder: ``layer1`` (stride 1) is followed by ``layer2``, ``layer3`` and
    ``layer4`` (stride 5 each). Copies of the input average-pooled with strides 5
    and 25 are concatenated onto the inputs of ``layer3`` and ``layer4``.

    Decoder: three x5 nearest-neighbour upsamplings. Each is concatenated with the
    matching encoder output and fused by a ``conbr_block``. ``outcov`` then
    projects to ``out_channels``.

    The x5 skips require the sequence length to be a multiple of 125 (e.g.
    1000), and ``kernel_size`` should be odd so the convolutions preserve length.

    Args:
        input_dim: number of input channels.
        layer_n: base width; encoder stages use 1x, 2x, 3x and 4x this width.
        kernel_size: convolution kernel size.
        depth: stored on the module. NOTE(review): every stage is built with 2
            residual blocks regardless, so this value does not currently change
            the architecture.
        out_channels: number of output channels.
    """

    def __init__(self, input_dim, layer_n, kernel_size, depth, out_channels):
        super(UNET_1D, self).__init__()
        self.input_dim = input_dim
        self.layer_n = layer_n
        self.kernel_size = kernel_size
        self.depth = depth

        # Pooled copies of the raw input for the encoder skip inputs. The window is a
        # fixed 5 samples rather than the channel count, so the pooled lengths do not
        # depend on input_dim and line up with out_1 / out_2.
        pool_kernel = 5
        self.AvgPool1D1 = nn.AvgPool1d(pool_kernel, stride=5)
        self.AvgPool1D2 = nn.AvgPool1d(pool_kernel, stride=25)
        self.AvgPool1D3 = nn.AvgPool1d(pool_kernel, stride=125)  # not used by forward()

        self.layer1 = self.down_layer(self.input_dim, self.layer_n, self.kernel_size, 1, 2)
        self.layer2 = self.down_layer(self.layer_n, int(self.layer_n*2), self.kernel_size, 5, 2)
        self.layer3 = self.down_layer(int(self.layer_n*2)+int(self.input_dim), int(self.layer_n*3), self.kernel_size, 5, 2)
        self.layer4 = self.down_layer(int(self.layer_n*3)+int(self.input_dim), int(self.layer_n*4), self.kernel_size, 5, 2)
        # layer5 is not used by forward(); kept so the parameter layout (state_dict) is unchanged.
        self.layer5 = self.down_layer(int(self.layer_n*4)+int(self.input_dim), int(self.layer_n*5), self.kernel_size, 4, 2)

        # Decoder fusions: input channels = upsampled features + encoder skip.
        self.cbr_up1 = conbr_block(int(self.layer_n*7), int(self.layer_n*3), self.kernel_size, 1, 1)  # 4n + 3n
        self.cbr_up2 = conbr_block(int(self.layer_n*5), int(self.layer_n*2), self.kernel_size, 1, 1)  # 3n + 2n
        self.cbr_up3 = conbr_block(int(self.layer_n*3), self.layer_n, self.kernel_size, 1, 1)         # 2n + n
        self.upsample = nn.Upsample(scale_factor=5, mode='nearest')
        self.upsample1 = nn.Upsample(scale_factor=5, mode='nearest')

        # "Same" padding for odd kernel sizes. A hard-coded 3 only preserved the
        # length for kernel_size=7.
        self.outcov = nn.Conv1d(self.layer_n, out_channels, kernel_size=self.kernel_size,
                                stride=1, padding=self.kernel_size // 2)

    def down_layer(self, input_layer, out_layer, kernel, stride, depth):
        """Build one encoder stage: a (possibly strided) conbr_block followed by ``depth`` re_blocks."""
        block = [conbr_block(input_layer, out_layer, kernel, stride, 1)]
        block.extend(re_block(out_layer, out_layer, kernel, 1) for _ in range(depth))
        return nn.Sequential(*block)

    def forward(self, x):
        pool_x1 = self.AvgPool1D1(x)  # skip input for layer3
        pool_x2 = self.AvgPool1D2(x)  # skip input for layer4

        # Encoder
        out_0 = self.layer1(x)
        out_1 = self.layer2(out_0)

        x = torch.cat([out_1, pool_x1], 1)
        out_2 = self.layer3(x)

        x = torch.cat([out_2, pool_x2], 1)
        x = self.layer4(x)

        # Decoder
        up = self.upsample1(x)
        up = torch.cat([up, out_2], 1)
        up = self.cbr_up1(up)

        up = self.upsample(up)
        up = torch.cat([up, out_1], 1)
        up = self.cbr_up2(up)

        up = self.upsample(up)
        up = torch.cat([up, out_0], 1)
        up = self.cbr_up3(up)

        return self.outcov(up)

In [None]:
# Only works for kernel_size=7.

generator = UNET_1D(input_dim=5, layer_n=32, kernel_size=7, depth=1, out_channels=1).to(device)
init_weights(generator, init_type='normal', scaling=0.02)
print(generator)

initialize network with normal
UNET_1D(
  (AvgPool1D1): AvgPool1d(kernel_size=(5,), stride=(5,), padding=(0,))
  (AvgPool1D2): AvgPool1d(kernel_size=(5,), stride=(25,), padding=(0,))
  (AvgPool1D3): AvgPool1d(kernel_size=(5,), stride=(125,), padding=(0,))
  (layer1): Sequential(
    (0): conbr_block(
      (conv1): Conv1d(5, 32, kernel_size=(7,), stride=(1,), padding=(3,))
      (bn): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
    )
    (1): re_block(
      (cbr1): conbr_block(
        (conv1): Conv1d(32, 32, kernel_size=(7,), stride=(1,), padding=(3,))
        (bn): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU()
      )
      (cbr2): conbr_block(
        (conv1): Conv1d(32, 32, kernel_size=(7,), stride=(1,), padding=(3,))
        (bn): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU()
      )
      (seblock): se_block(
 

In [None]:
test = generator(torch.randn(8, *control_shape).to(device))  # smoke test on 8 random control samples

In [None]:
test.size()

torch.Size([8, 1, 1000])

In [None]:
del generator, test  # drop the smoke-test model and its output

# Training

In [None]:
generator = Decoder(in_channels=5).to(device)  # model trained below

Loss function

In [None]:
# NOTE: despite the name, this criterion is mean-squared (L2) error, not L1.
l1_loss = nn.MSELoss(reduction='mean')

In [None]:
lr = 1e-3  # Adam learning rate

Optimizer

In [None]:
G_optimizer = torch.optim.Adam(params=generator.parameters(), lr=lr)  # Adam over all generator weights

Training function

## Decoder

In [None]:
def train_decoder(num_epochs, generator, criterion, optimizer, dataloader, val_loader, device=None):
    """Train ``generator`` to map control channels to the signal channel.

    Each batch is split into ``batch[:, 1:]`` (model input: control channels) and
    ``batch[:, 0]`` with a channel axis re-added (target: signal). After every
    epoch the model is evaluated on ``val_loader`` without gradients, and the mean
    losses are printed.

    Args:
        num_epochs: number of passes over ``dataloader``.
        generator: model to train. It is put in ``train()`` mode for the training
            pass and ``eval()`` mode for validation, so BatchNorm/Dropout layers
            behave correctly. Its original mode is restored on return.
        criterion: loss called as ``criterion(prediction, target)``.
        optimizer: optimizer over ``generator``'s parameters.
        dataloader: training batches.
        val_loader: validation batches.
        device: where batches are moved. Defaults to the device of the generator's
            first parameter (CPU if it has none).

    Returns:
        ``(loss_plot, val_loss_plot)``: per-epoch mean training and validation losses.
    """
    if device is None:
        first_param = next(generator.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    was_training = generator.training
    loss_plot = []
    val_loss_plot = []
    for epoch in range(1, num_epochs + 1):
        generator.train()
        epoch_loss = []
        for batch in tqdm(dataloader):
            target = batch[:, 0].unsqueeze(1).to(device)
            inputs = batch[:, 1:].to(device)
            optimizer.zero_grad()
            loss = criterion(generator(inputs), target)
            loss.backward()
            optimizer.step()
            epoch_loss.append(loss.item())

        # Validation: eval mode (BatchNorm uses, and does not update, its running
        # statistics) and no autograd graph.
        generator.eval()
        val_loss = []
        with torch.no_grad():
            for batch in tqdm(val_loader):
                target = batch[:, 0].unsqueeze(1).to(device)
                inputs = batch[:, 1:].to(device)
                val_loss.append(criterion(generator(inputs), target).item())

        loss_plot.append(np.mean(epoch_loss))
        val_loss_plot.append(np.mean(val_loss))
        print('epoch: {} loss: {} val loss: {}'.format(epoch, loss_plot[-1], val_loss_plot[-1]))

    generator.train(was_training)
    return loss_plot, val_loss_plot


In [None]:
# history = (per-epoch mean training losses, per-epoch mean validation losses)
history = train_decoder(num_epochs=100, generator=generator, criterion=l1_loss,
                        optimizer=G_optimizer, dataloader=dataloader, val_loader=val_dataloader)

  0%|          | 0/79 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

epoch: 1 loss: 1.0007201433181763 val loss: 1.0002129077911377


  0%|          | 0/79 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

epoch: 2 loss: 0.9997130036354065 val loss: 1.0002306699752808


  0%|          | 0/79 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

epoch: 3 loss: 0.999698281288147 val loss: 1.000222086906433


  0%|          | 0/79 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

epoch: 4 loss: 0.9998498558998108 val loss: 1.0002238750457764


  0%|          | 0/79 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

epoch: 5 loss: 0.9997504353523254 val loss: 1.0002238750457764


  0%|          | 0/79 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

epoch: 6 loss: 0.9998071193695068 val loss: 1.000233769416809


  0%|          | 0/79 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

epoch: 7 loss: 0.9997954368591309 val loss: 1.0002423524856567


  0%|          | 0/79 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

epoch: 8 loss: 0.9996923804283142 val loss: 1.0002758502960205


  0%|          | 0/79 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

epoch: 9 loss: 0.9997217655181885 val loss: 1.0002830028533936


  0%|          | 0/79 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

epoch: 10 loss: 0.9996635317802429 val loss: 1.000285029411316


  0%|          | 0/79 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

epoch: 11 loss: 0.9997549653053284 val loss: 1.0003257989883423


  0%|          | 0/79 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

epoch: 12 loss: 0.9996540546417236 val loss: 1.0003024339675903


  0%|          | 0/79 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

epoch: 13 loss: 0.9993191361427307 val loss: 1.000385046005249


  0%|          | 0/79 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

epoch: 14 loss: 0.9997027516365051 val loss: 1.0004234313964844


  0%|          | 0/79 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

epoch: 15 loss: 0.999388575553894 val loss: 1.0004386901855469


  0%|          | 0/79 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

epoch: 16 loss: 0.9993555545806885 val loss: 1.000512957572937


  0%|          | 0/79 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

epoch: 17 loss: 0.9992579221725464 val loss: 1.000674843788147


  0%|          | 0/79 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

epoch: 18 loss: 0.9991976618766785 val loss: 1.0007456541061401


  0%|          | 0/79 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

epoch: 19 loss: 0.9994896054267883 val loss: 1.0007290840148926


  0%|          | 0/79 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

epoch: 20 loss: 0.999148964881897 val loss: 1.0009150505065918


  0%|          | 0/79 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

epoch: 21 loss: 0.9991765022277832 val loss: 1.0009245872497559


  0%|          | 0/79 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

epoch: 22 loss: 0.9987685680389404 val loss: 1.0011365413665771


  0%|          | 0/79 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

epoch: 23 loss: 0.9985387325286865 val loss: 1.0013097524642944


  0%|          | 0/79 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

epoch: 24 loss: 0.998165488243103 val loss: 1.0013834238052368


  0%|          | 0/79 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

epoch: 25 loss: 0.998210072517395 val loss: 1.001717448234558


  0%|          | 0/79 [00:00<?, ?it/s]

KeyboardInterrupt: ignored