In [1]:
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import utils
from torch.utils.data import DataLoader
import tqdm as tqdm
import matplotlib.pyplot as plt
import numpy as np
from sklearn.preprocessing import StandardScaler
from utils import load_csv, drop_cols, remove_strings, groupedAvg, subsample, normalize
import create_dataset
# from generators import UnetGenerator

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# load only a small part of the data and drop the unnecessary columns
path = "/home/johann/Desktop/Uni/Masterarbeit/Cycle_GAN/LeRntVAD_csv_exports_not_all_the_data/constant_speed_interventions/"
df = utils.load_csv(path)

df = utils.drop_cols(df)
df = df.dropna()

# select only rows where 'Phasenzuordnung' is 1
df = df.loc[df['Phasenzuordnung'] == 1]

print(df.shape)  

(720947, 13)


In [3]:
class AnimalDatasetEmbedding(torch.utils.data.Dataset):
    def __init__(self, df, feature_names, target_name, test,
                 window_length=256):
        self.df = df
        self.feature_names = feature_names
        self.target_name = target_name
        self.window_length = window_length
        self.test = test
        
        self.num_animals = len(np.unique(df["animal"]))
        self.animal_dfs = [group[1] for group in df.groupby("animal")]
        # get statistics for test dataset
        self.animal_lens = [len(an_df) // self.window_length for an_df in self.animal_dfs]
        self.animal_cumsum = np.cumsum(self.animal_lens)
        self.num_windows = sum(self.animal_lens)

        
    def __len__(self):
        # if self.test:
        return self.num_windows
        # else:
        #     return self.num_animals
    
    def __getitem__(self, idx):
        if self.test:
            # look up which test animal the idx corresponds to
            animal_idx = int(np.where(self.animal_cumsum >= idx)[0][0])
            animal_df = self.animal_dfs[animal_idx]
            # look up which part of the test animal the idx corresponds to 
            if animal_idx > 0:
                start_idx = idx - self.animal_cumsum[animal_idx - 1]
            else:
                start_idx = idx
            start_idx *= self.window_length
        else:
            # animal_df = self.animal_dfs[idx]
            animal_idx = int(np.where(self.animal_cumsum >= idx)[0][0])
            animal_df = self.animal_dfs[animal_idx]
            
            # take window
            start_idx = np.random.randint(0, len(animal_df) - self.window_length - 1)
        end_idx = start_idx + self.window_length
        animal_df = animal_df.iloc[start_idx: end_idx]
        
        # extract features
        input_df = animal_df[self.feature_names]
        target_df = animal_df[self.target_name]
        phase_df = animal_df["Phasenzuordnung"]
        
        # to torch
        inputs = torch.tensor(input_df.to_numpy()).permute(1, 0)
        targets = torch.tensor(target_df.to_numpy()).unsqueeze(0)
        phase = torch.tensor(phase_df.to_numpy())
        
        return inputs, targets, phase

In [4]:
# Verify that the data is loaded properly

SIG_A = "VadQ"
SIG_B = "LVP"
SIG_C = "AoQ"
SIG_D = "AoP"
feature_names = [SIG_A, SIG_B, SIG_C, SIG_D]
target = "LVtot_kalibriert" 

train_dataset = AnimalDatasetEmbedding(df, feature_names, target_name = target, test = False, window_length = 256)
test_dataset = AnimalDatasetEmbedding(df, feature_names, target_name = target, test = True, window_length = 256)

# Data loader
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=10, pin_memory=True,)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=True, num_workers=10, pin_memory=True,)

# length of train and test dataset
print(len(train_dataset), len(test_dataset))

2814 2814


In [7]:
def double_conv_pad(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1, padding_mode='zeros'),
        nn.LeakyReLU(inplace=True),
        nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1, padding_mode='zeros'),
        nn.LeakyReLU(inplace=True)
    )

class SkipConnectionsMultiChannelUnetGenerator(nn.Module):
    def __init__(self, INPUTCHANNELS, OUTPUTCHANNELS):
        super(SkipConnectionsMultiChannelUnetGenerator, self).__init__()
        self.maxpool = nn.MaxPool1d((2))  

        self.down_conv1 = double_conv_pad(INPUTCHANNELS, 32) 
        self.down_conv2 = double_conv_pad(32, 64) 
        self.down_conv3 = double_conv_pad(64, 128)
        self.down_conv4 = double_conv_pad(128, 256)

        self.embedding = torch.nn.Embedding(num_embeddings=5, embedding_dim=32)

        self.up_trans1 = nn.ConvTranspose1d(256, 128, kernel_size=(2), stride=2, padding=0)
        self.up_conv1 = double_conv_pad(256, 128)
        self.up_trans2 = nn.ConvTranspose1d(128, 64, kernel_size=(2), stride=2, padding=0)
        self.up_conv2 = double_conv_pad(128, 64)
        self.up_trans3 = nn.ConvTranspose1d(64, 32, kernel_size=(2), stride=2, padding=0)
        self.up_conv3 = double_conv_pad(64, 32)

        self.out = nn.Conv1d(32, OUTPUTCHANNELS, kernel_size=1) # kernel_size must be == 1

    def forward(self, input, phase):
        # [Batch size, Channels in, Height, Width]
        
        # downsampling
        x1 = self.down_conv1(input)   
        x2 = self.maxpool(x1) 
        x3 = self.down_conv2(x2)  
        x4 = self.maxpool(x3) 
        x5 = self.down_conv3(x4)  
        x6 = self.maxpool(x5) 
        x7 = self.down_conv4(x6)

        # upsampling
        e = self.embedding(phase)
        x7 = x7 + e
        x = self.up_trans1(x7)
        x = self.up_conv1(torch.cat([x, x5], 1))  # skip connection
        x = self.up_trans2(x)
        x = self.up_conv2(torch.cat([x, x3], 1))  # skip connection
        x = self.up_trans3(x)
        x = self.up_conv3(torch.cat([x, x1], 1))  # skip connection
        x = self.out(x)
        return x
    
CHANNELS = 4 # Channel anpassen!!

dummy = torch.LongTensor([[1]])
phase = torch.ones_like(dummy)
x1 = torch.rand(512, 1, 256)
x2 = torch.rand(512, 1, 256)
x3 = torch.rand(512, 1, 256)
x4 = torch.rand(512, 1, 256)
input = torch.cat([x1, x2, x3, x4], 1)
# print('Input shpe: ', input.size())
model = SkipConnectionsMultiChannelUnetGenerator(INPUTCHANNELS = CHANNELS, OUTPUTCHANNELS= 1)
print('\nOutput for {} chanels: '.format(CHANNELS), model(input, phase).shape)


Output for 4 chanels:  torch.Size([512, 1, 256])


In [None]:
CHANNELS = 4 # Channel anpassen!!

def double_conv_pad(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1, padding_mode='zeros'),
        nn.LeakyReLU(inplace=True),
        nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1, padding_mode='zeros'),
        nn.LeakyReLU(inplace=True)
    )

class MultiChannelUnetGenerator(nn.Module):
    def __init__(self):
        super(MultiChannelUnetGenerator, self).__init__()
        self.maxpool = nn.MaxPool2d((1, CHANNELS))

        self.down_conv1 = double_conv_pad(CHANNELS, 32) 
        self.down_conv2 = double_conv_pad(32, 64) 
        self.down_conv3 = double_conv_pad(64, 128)
        self.down_conv4 = double_conv_pad(128, 256)

        self.up_trans1 = nn.ConvTranspose1d(256, 128, kernel_size=(CHANNELS), stride=2, padding=0)
        self.up_conv1 = double_conv_pad(128, 128)
        self.up_trans2 = nn.ConvTranspose1d(128, 64, kernel_size=(CHANNELS), stride=2, padding=0)
        self.up_conv2 = double_conv_pad(64, 64)
        self.up_trans3 = nn.ConvTranspose1d(64, 32, kernel_size=(CHANNELS), stride=2, padding=0)
        self.up_conv3 = double_conv_pad(32, 32)

        self.out = nn.Conv1d(32, 1, kernel_size=CHANNELS)

    def forward(self, input):
        # [Batch size, Channels in, Height, Width]
        print("Input sizes: ", input.size())
        x1 = self.down_conv1(input)
        print('x1', x1.size())  
        x2 = self.maxpool(x1) 
        print('x2', x2.size())
        x3 = self.down_conv2(x2)  #
        print('x3', x3.size())
        x4 = self.maxpool(x3) 
        print('x4', x4.size()) 
        x5 = self.down_conv3(x4)  #
        print('x5', x5.size()) 
        x6 = self.maxpool(x5)
        print('x6', x6.size())  
        x7 = self.down_conv4(x6)
        print('x7', x7.size())

        # # decoder
        print("Upsampling")
        x = self.up_trans1(x7)
        print(x.size())
        x = self.up_conv1(x)
        x = self.up_trans2(x)
        print(x.size())
        x = self.up_conv2(x)
        x = self.up_trans3(x)
        print(x.size())
        x = self.up_conv3(x)
        x = self.out(x)
        return x
    
# x1 = torch.rand(512, 1, 1, 256)
# x2 = torch.rand(512, 1, 1, 256)

x1 = torch.rand(512, 1, 256)
x2 = torch.rand(512, 1, 256)
x3 = torch.rand(512, 1, 256)
x4 = torch.rand(512, 1, 256)
input = torch.cat([x1, x2, x3, x4], 1)
# print('Input shpe: ', input.size())
model = MultiChannelUnetGenerator()
print(model(input).shape)

Input sizes:  torch.Size([512, 2, 256])
x1 torch.Size([512, 32, 256])
x2 torch.Size([512, 32, 128])
x3 torch.Size([512, 64, 128])
x4 torch.Size([512, 64, 64])
x5 torch.Size([512, 128, 64])
x6 torch.Size([512, 128, 32])
x7 torch.Size([512, 256, 32])
Upsampling
torch.Size([512, 128, 64])
torch.Size([512, 64, 128])
torch.Size([512, 32, 256])
torch.Size([512, 1, 255])

In [None]:
# config

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 512  #1024 didsn't work so well
LEARNING_RATE = 1e-5  # 1e-5 was too small for 'LVtot_kalibriert' and 'LVtot'
# LAMBDA_IDENTITY = 0.1
LAMBDA_CYCLE = 10.0 # try out different values
NUM_WORKERS = 10
NUM_EPOCHS = 2000
LR_DECAY_AFTER_EPOCH = 800
GENERATION_AFTER_EPOCH = NUM_EPOCHS # number of epochs after which the model generates a sample
LOAD_MODEL = False
SAVE_MODEL = False
SIG_A = "AoP" # "VADcurrent" # "VadQ" # "AoP" # "LVP"  
SIG_B = "VadQ" # "LVtot" # "LVtot_kalibriert" # "LVtot" # "LVtot_kalibriert" 
TARGET = "LVtot_kalibriert" 
CHECKPOINT_GEN_A2B = "Checkpoints/Generated_data/gen_{}.pth.tar".format(SIG_B)
CHECKPOINT_GEN_B2A = "Checkpoints/Generated_data/gen_{}.pth.tar".format(SIG_A)
CHECKPOINT_DISC_A =  "Checkpoints/Generated_data/disc{}.pth.tar".format(SIG_A)
CHECKPOINT_DISC_B =  "Checkpoints/Generated_data/disc{}.pth.tar".format(SIG_B)

## Dataset
### Load data from csv

In [None]:
# paths of only a small part of the data
path = "/home/johann/Desktop/Uni/Masterarbeit/Cycle_GAN/LeRntVAD_csv_exports_not_all_the_data/constant_speed_interventions/" # time: ~ 19 sec

# all the data
path_1 = "/home/johann/Desktop/Uni/Masterarbeit/Cycle_GAN/csv_export_files_alle_Daten/csv_export_files/Data_split_1"  
path_2 = "/home/johann/Desktop/Uni/Masterarbeit/Cycle_GAN/csv_export_files_alle_Daten/csv_export_files/Data_split_2"
path_3 = "/home/johann/Desktop/Uni/Masterarbeit/Cycle_GAN/csv_export_files_alle_Daten/csv_export_files/Data_split_3"
path_4 = "/home/johann/Desktop/Uni/Masterarbeit/Cycle_GAN/csv_export_files_alle_Daten/csv_export_files/Data_split_4"

In [None]:
# load only a small part of the data and drop the unnecessary columns
path = "/home/johann/Desktop/Uni/Masterarbeit/Cycle_GAN/LeRntVAD_csv_exports_not_all_the_data/constant_speed_interventions/"
df = utils.load_csv(path)

df = utils.drop_cols(df)
df = df.dropna()

# select only rows where 'Phasenzuordnung' is 1
df = df.loc[df['Phasenzuordnung'] == 1]

print(df.shape)  

In [None]:
# Load all the data and drop unnecessary columns
# We load the data separately, to avoid a Runtime error

df_1 = utils.load_csv(path_1)
df_1 = utils.drop_cols(df_1)

df_2 = utils.load_csv(path_2)
df_2 = utils.drop_cols(df_2)

df_3 = utils.load_csv(path_3)
df_3 = utils.drop_cols(df_3)

df_4 = utils.load_csv(path_4)
df_4 = utils.drop_cols(df_4)

# concatenate the separate dataframes
df = pd.concat([df_1, df_2, df_3, df_4], axis=0, ignore_index=True)
df = df.dropna()

print('Size of the whole dataset',df.shape)
# select only rows where 'Phasenzuordnung' is 1
df = df.loc[df['Phasenzuordnung'] == 1]
print('Size of dataset with only the first phase',df.shape)

## Preprocesssing
### Removing strings from the data

In [None]:
df = utils.remove_strings(df)
# utils.visualize(df, 'AoP', 'LVP', 'VadQ', 'VADcurrent', 512)

### Subsample the data by the factor 10

In [None]:
df = utils.subsample(df, 10)
# utils.visualize(df, 'AoP', 'LVP', 'VadQ', 'VADcurrent', 512)

### Normalize the data

In [None]:
df = utils.normalize(df)
# utils.visualize(df, 'AoP', 'LVP', 'VadQ', 'VADcurrent', 512)

### Modulo 256

In [None]:
# I apply this currently in the dataframe loader

### Split data into train and test split

In [None]:
# hom many different intervention ids are there?
print('\nDifferent interventions: \n',df['intervention'].unique())

# hom many different animal ids are there?
print('\nDifferent animal IDs: \n',len(df['animal'].unique()))

# remove animals with less than 10 data points
df = df.groupby('animal').filter(lambda x: len(x) > 10)
print('\nDifferent animal IDs after removing those with less than 10 data points: \n',len(df['animal'].unique()))

# length of data per animal
#print(df.groupby('animal').size())

all_animals = df['animal'].unique()
# pick 2 random animals for test data
test_animals = df['animal'].sample(n=1, random_state=1).unique()

# remove test animals from train animals
train_animals =  [x for x in all_animals if x not in test_animals]

print('\nTest animals:', test_animals)
# test data
df_test = df[df['animal'].isin(test_animals)]

# train dataframe with only animals from train_animals
df_train = df[df['animal'].isin(train_animals)]
print('\nDifferent animal IDs after removing those that are in the test dataset: \n',len(df_train['animal'].unique()))


print('\nTrain data shape:', df_train.shape)
print('\nTest data shape:', df_test.shape)

#print('Unique animals in df_train: ',df_train['animal'].unique())
#print('Unique animals in df_test: ',df_test['animal'].unique())

# lengt of df_train
print('\nThe test dataset is {} percent of the whole data: '.format((len(df_test)/(len(df_train) + len(df_test))) * 100))

In [None]:
# create gen_dataset which is a part of the test dataset
# df_gen = df_test.sample(frac=0.01, random_state=1)
# print('\nGen data shape:', df_gen.shape)

### create a combined tensor of SIG_A and SIG_B

In [None]:
# create a combined tensor of SIG_A and SIG_B
df_A = df[SIG_A]
df_B = df[SIG_B]
#print(df_A.head)
#print(df_B.head)

df_A = df_A.to_numpy()
df_B = df_B.to_numpy()

# split df_A into a 2D array 
df_A = np.split(df_A, df_A.shape[0], axis=0)
df_B = np.split(df_B, df_B.shape[0], axis=0)

# create a combined tensor of df_A and df_B
combined = np.concatenate((df_A, df_B), axis=1)
# to dataframe
combined = pd.DataFrame(combined, columns = [SIG_A, SIG_B])
print(combined.head)
print(combined.shape)

In [None]:
# plot the combined tensor in a single plot

plt.plot(combined[SIG_A][:256])
plt.plot(combined[SIG_B][:256])
plt.show()


In [None]:
print(df_A[:5])
print(df_B[:5])

### Dataset loader

In [None]:
df.head()

In [None]:
class AnimalDataset(torch.utils.data.Dataset):
    def __init__(self, df, feature_names, target_name, test,
                 window_length=256):
        self.df = df
        self.feature_names = feature_names
        self.target_name = target_name
        self.window_length = window_length
        self.test = test
        
        self.num_animals = len(np.unique(df["animal"]))
        self.animal_dfs = [group[1] for group in df.groupby("animal")]
        # get statistics for test dataset
        self.animal_lens = [len(an_df) // self.window_length for an_df in self.animal_dfs]
        self.animal_cumsum = np.cumsum(self.animal_lens)
        self.num_windows = sum(self.animal_lens)

        
    def __len__(self):
        if self.test:
            return self.num_windows
        else:
            return self.num_animals
    
    def __getitem__(self, idx):
        if self.test:
            # look up which test animal the idx corresponds to
            animal_idx = np.where(self.animal_cumsum >= idx)[0]
            animal_df = self.animal_dfs[idx]
            # look up which part of the test animal the idx corresponds to 
            if animal_idx > 0:
                start_idx = idx - self.animal_cumsum[animal_idx - 1]
            else:
                start_idx = idx
            start_idx *= self.window_length
        else:
            animal_df = self.animal_dfs[idx]

            # take window
            start_idx = np.random.randint(0, len(animal_df) - self.window_length - 1)
        end_idx = start_idx + self.window_length
        animal_df = animal_df.iloc[start_idx: end_idx]
        
        # extract features
        input_df = animal_df[self.feature_names]
        target_df = animal_df[self.target_name]
        
        # to torch
        inputs = torch.tensor(input_df.to_numpy())
        targets = torch.tensor(target_df.to_numpy())
        
        return inputs, targets

In [None]:
test = True
print(len(df_test)/256)

In [None]:
# Verify that the data is loaded properly

SIG_A = "VadQ"
SIG_B = "LVP"
SIG_C = "AoQ"
SIG_D = "AoP"
feature_names = [SIG_A, SIG_B]
target = "LVtot_kalibriert" 
test_F = False
test_T = True



train_dataset = AnimalDataset(df_train, feature_names, target_name = target, test = False, window_length = 256)
test_dataset = AnimalDataset(df_test, feature_names, target_name = target, test = True, window_length = 256)
# train_dataset = create_dataset.TestDataset(signal_A=SIG_A, signal_B=SIG_B, df=df_train)
# print(len(train_dataset))
# train_dataset = TrainDataset(signal_A=SIG_A, signal_B=SIG_B, signal_C = SIG_C, signal_D = SIG_D, target = target, window = 4, df=df_train)
# test_dataset = create_dataset.TestDataset(signal_A=SIG_A, signal_B=SIG_B, df=df_test)
# print(len(test_dataset))
# test_dataset = TrainDataset(signal_A=SIG_A, signal_B=SIG_B, signal_C = SIG_C, signal_D = SIG_D, target = target, window = 4, df=df_test)

# Data loader
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=10, pin_memory=True,)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=True, pin_memory=True,)

# length of train and test dataset
print(len(train_dataset), len(test_dataset))

In [None]:
class TrainDataset(Dataset):
    def __init__(self, signal_A, signal_B, signal_C, signal_D, target,  window, df):
        self.df = df
        self.signal_A = self.df[signal_A]
        self.target = target
        self.window = window
        if window >= 2:
            self.signal_B = self.df[signal_B]
        if window >= 3:
            self.signal_B = self.df[signal_C]
        if window >= 4:
            self.signal_B = self.df[signal_D]

        if window > 4:
            print("The window determines, how many signals from the input are taken into consideration.")
            print("The window should be between 1 and 4")
            

        # only data from a single animal per batch 
        for animal in self.df['animal'].unique():
            df_single_animal = self.df[self.df['animal'] == animal]
            
            # length should be modulo 256 = 0
            df_single_animal = df_single_animal.iloc[:-(len(df_single_animal) % 256), :]


            if window >= 1:
                tensor_A = torch.tensor(df_single_animal[signal_A].values) # creating tensor from df 
                tensor_A = tensor_A.split(256)  # tensor shape (256, 1)
                stack_A = torch.stack(tensor_A).unsqueeze(1)

            if window >= 2:
                tensor_B = torch.tensor(df_single_animal[signal_B].values) # creating tensor from df 
                tensor_B = tensor_B.split(256)  # tensor shape (256, 1)
                stack_B = torch.stack(tensor_B).unsqueeze(1)

            if window >= 3:
                tensor_C = torch.tensor(df_single_animal[signal_C].values) # creating tensor from df 
                tensor_C = tensor_C.split(256)  # tensor shape (256, 1)
                stack_C = torch.stack(tensor_C).unsqueeze(1)

            if window == 4:
                tensor_D = torch.tensor(df_single_animal[signal_D].values) # creating tensor from df 
                tensor_D = tensor_D.split(256)  # tensor shape (256, 1)
                stack_D = torch.stack(tensor_D).unsqueeze(1)

            target_tensor = torch.tensor(df_single_animal[target].values) # creating tensor from df 
            target_tensor = target_tensor.split(256)  # tensor shape (256, 1)
            target_stack = torch.stack(target_tensor).unsqueeze(1)
                
            # The tensor of the first animal is added to self.tensor
            if animal == self.df['animal'].unique()[0]:
                self.tensor_A = stack_A
                self.target = target_stack
                if window >= 2:
                    self.tensor_B = stack_B
                if window >= 3:
                    self.tensor_C = stack_C
                if window >= 4:
                    self.tensor_D = stack_D
                
            else: # The tensor of each following animal is concatenated to self.tensor
                self.tensor_A = torch.cat((self.tensor_A, stack_A), 0)
                self.target = torch.cat((self.target, target_stack), 0)
                if window >= 2:
                    self.tensor_B = torch.cat((self.tensor_B, stack_B), 0)
                if window >= 3:
                    self.tensor_C = torch.cat((self.tensor_C, stack_C), 0)
                if window == 4:
                    self.tensor_D = torch.cat((self.tensor_D, stack_D), 0)
          

    def __len__(self):
        # all signals should have the same length
        return len(self.tensor_A)

    def __getitem__(self, index):
        #return the signal at the given index 
        if self.window == 1: 
            return self.tensor_A[index], self.target[index] # , self.tensor_B[index], self.tensor_C[index], self.tensor_D[index]
        if self.window == 2: 
            return self.tensor_A[index] , self.tensor_B[index], self.target[index]#, self.tensor_C[index], self.tensor_D[index], 
        if self.window == 3: 
            return self.tensor_A[index] , self.tensor_B[index], self.tensor_C[index], self.target[index] # self.tensor_D[index],
        if self.window == 4: 
            return self.tensor_A[index] , self.tensor_B[index], self.tensor_C[index], self.tensor_D[index], self.target[index]

In [None]:
# Data should be of shape (Batch, channels = 2, Hight = 1, Width = 256)
# Test Dataset includes data from a single animal

# class TrainDataset(Dataset):
#     def __init__(self, signal_A, signal_B, target,  df):
#         self.df = df
#         self.signal_A = self.df[signal_A]
#         self.signal_B = self.df[signal_B]
#         self.target = self.df[target]

#         # length should be modulo 256 = 0
#         self.df = self.df.iloc[:-(len(self.df) % 256), :]

#         # creating tensor from df 
#         tensor_A = torch.tensor(self.df[signal_A].values)
#         tensor_B = torch.tensor(self.df[signal_B].values)
#         tensor_target = torch.tensor(self.df[target].values)

#         # split tensor into tensors of size 256
#         tensor_A = tensor_A.split(256)  # tensor shape (256, 1) 
#         tensor_B = tensor_B.split(256)   
#         tensor_target = tensor_target.split(256)    

#         # stack tensors
#         self.tensor_A = torch.stack(tensor_A).unsqueeze(1) 
#         self.tensor_B = torch.stack(tensor_B).unsqueeze(1) 
#         self.tensor_target = torch.stack(tensor_target).unsqueeze(1)
#         # print(self.tensor_A.shape, self.tensor_A.shape)

#         self.tensor_A = self.tensor_A.unsqueeze(1)
#         self.tensor_B = self.tensor_B.unsqueeze(1)
#         self.tensor_target =self.tensor_target.unsqueeze(1)


#     def __len__(self):
#         # signal_A and signal_B should have the same length
#         return len(self.tensor_A)

#     def __getitem__(self, index):
#         # return the signal at the given index  # add data augmentation?
#         return self.tensor_A[index], self.tensor_B[index], self.tensor_target[index]

In [None]:
# class TestDataset(Dataset):
#     def __init__(self, signal_A, signal_B, target,  df):
#         self.df = df
#         self.signal_A = self.df[signal_A]
#         self.signal_B = self.df[signal_B]
#         self.target = self.df[target]

#         # length should be modulo 256 = 0
#         #self.df = self.df.iloc[:-(len(self.df) % 256), :]
#         #print(self.df.shape)
#         # creating tensor from df 
#         tensor_A = torch.tensor(self.df[signal_A].values)
#         tensor_B = torch.tensor(self.df[signal_B].values)
#         tensor_target = torch.tensor(self.df[target].values)

#         # split tensor into tensors of size 256
#         tensor_A = tensor_A.split(256)  # tensor shape (256, 1) 
#         tensor_B = tensor_B.split(256)   
#         tensor_target = tensor_target.split(256)    

#         # stack tensors
#         self.tensor_A = torch.stack(tensor_A).unsqueeze(1) 
#         self.tensor_B = torch.stack(tensor_B).unsqueeze(1) 
#         self.tensor_target = torch.stack(tensor_target).unsqueeze(1)
#         # print(self.tensor_A.shape, self.tensor_A.shape)

#         self.tensor_A = self.tensor_A.unsqueeze(1)
#         self.tensor_B = self.tensor_B.unsqueeze(1)
#         self.tensor_target =self.tensor_target.unsqueeze(1)


#     def __len__(self):
#         # signal_A and signal_B should have the same length
#         return len(self.tensor_A)

#     def __getitem__(self, index):
#         # return the signal at the given index  # add data augmentation?
#         return self.tensor_A[index], self.tensor_B[index], self.tensor_target[index]

In [None]:
print(len(df_test))
print(len(df_train))

## Discriminator

In [None]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        
        self.conv1 = nn.Conv2d(in_channels= 2, out_channels= 32, kernel_size = 3, stride = 2, padding =1)
        self.conv2 = nn.Conv2d(in_channels= 32, out_channels= 64, kernel_size = 3, stride = 2, padding =1)
        self.conv3 = nn.Conv2d(in_channels= 64, out_channels= 128, kernel_size = 3, stride = 2, padding =1)
        self.conv4 = nn.Conv2d(in_channels= 128, out_channels= 1, kernel_size = 3, stride = 2, padding =1)
        
        self.out = nn.Sigmoid()


    def forward(self, input):
        x = self.conv1(input)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        #x = self.conv5(x)
        #x = self.conv6(x)
        x = self.out(x)
        return x
    
x1 = torch.rand(512, 1, 1, 256)
x2 = torch.rand(512, 1, 1, 256)
combined_input = torch.cat((x1, x2), dim=1)
print(combined_input.shape)
discriminator = Discriminator()
print(discriminator(combined_input).shape)

## Generator

In [None]:
# config

CHANNELS = 2

In [None]:
def double_conv_pad(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1, padding_mode='zeros'),
        nn.LeakyReLU(inplace=True),
        nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1, padding_mode='zeros'),
        nn.LeakyReLU(inplace=True)
    )


class UnetGenerator(nn.Module):
    def __init__(self):
        super(UnetGenerator, self).__init__()
        self.maxpool = nn.MaxPool1d(2)

        self.down_conv1 = double_conv_pad(CHANNELS, 32) 
        self.down_conv2 = double_conv_pad(32, 64) 
        self.down_conv3 = double_conv_pad(64, 128)
        self.down_conv4 = double_conv_pad(128, 256)

        self.up_trans1 = nn.ConvTranspose1d(256, 128, kernel_size=2, stride=2)
        self.up_conv1 = double_conv_pad(256, 128)
        self.up_trans2 = nn.ConvTranspose1d(128, 64, kernel_size=2, stride=2)
        self.up_conv2 = double_conv_pad(128, 64)
        self.up_trans3 = nn.ConvTranspose1d(64, 32, kernel_size=2, stride=2)
        self.up_conv3 = double_conv_pad(64, 32)

        self.out = nn.Conv1d(32, 1, kernel_size=1)

    def forward(self, input1, input2):
        
        # batch_size, channels, tensor_size
        # downsampling
        input = torch.cat([input1, input2], 1)
        x1 = self.down_conv1(input)  
        print(x1.size())   
        x2 = self.maxpool(x1) 
        print(x2.size()) 
        x3 = self.down_conv2(x2)  
        print(x3.size()) 
        x4 = self.maxpool(x3)
        print(x4.size())  
        x5 = self.down_conv3(x4)
        print(x5.size())    
        x6 = self.maxpool(x5)
        print(x6.size())  
        x7 = self.down_conv4(x6)
        print(x7.size()) 

        # upsampling
        x = self.up_trans1(x7)
        x = self.up_conv1(torch.cat([x, x5], 1))
        x = self.up_trans2(x)
        x = self.up_conv2(torch.cat([x, x3], 1))
        x = self.up_trans3(x)
        x = self.up_conv3(torch.cat([x, x1], 1))
        x = self.out(x)
        return x
    
x1 = torch.rand(512, 1, 256)
x2 = torch.rand(512, 1, 256)
model = UnetGenerator()
print(model(x1, x2).shape)

In [None]:
def double_conv_pad(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, padding_mode='zeros'),
        nn.LeakyReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, padding_mode='zeros'),
        nn.LeakyReLU(inplace=True)
    )

class Two_channel_UnetGenerator(nn.Module):
    def __init__(self):
        super(Two_channel_UnetGenerator, self).__init__()
        self.maxpool = nn.MaxPool2d((1,2))

        self.down_conv1 = double_conv_pad(2, 32) 
        self.down_conv2 = double_conv_pad(32, 64) 
        self.down_conv3 = double_conv_pad(64, 128)
        self.down_conv4 = double_conv_pad(128, 256)

        self.up_trans1 = nn.ConvTranspose2d(256, 128, kernel_size=(1,2), stride=2, padding=0)
        self.up_conv1 = double_conv_pad(128, 128)
        self.up_trans2 = nn.ConvTranspose2d(128, 64, kernel_size=(1,2), stride=2, padding=0)
        self.up_conv2 = double_conv_pad(64, 64)
        self.up_trans3 = nn.ConvTranspose2d(64, 32, kernel_size=(1,2), stride=2, padding=0)
        self.up_conv3 = double_conv_pad(32, 32)

        self.out = nn.Conv2d(32, 1, kernel_size=1)

    def forward(self, input_A, input_B):
        # [Batch size, Channels in, Height, Width]
        #print("Input sizes: ", input_A.size(), input_B.size())
        x1 = self.down_conv1(torch.cat([input_A, input_B], 1))
        # print(x1.size())  
        x2 = self.maxpool(x1) 
        # print(x2.size())
        x3 = self.down_conv2(x2)  #
        # print(x3.size())
        x4 = self.maxpool(x3) 
        # print(x4.size()) 
        x5 = self.down_conv3(x4)  #
        # print(x5.size()) 
        x6 = self.maxpool(x5)
        # print(x6.size())  
        x7 = self.down_conv4(x6)
        # print(x7.size())

        # # decoder
        # print("Upsampling")
        x = self.up_trans1(x7)
        # print(x.size())
        x = self.up_conv1(x)
        x = self.up_trans2(x)
        # print(x.size())
        x = self.up_conv2(x)
        x = self.up_trans3(x)
        # print(x.size())
        x = self.up_conv3(x)
        x = self.out(x)
        return x
    
# x1 = torch.rand(512, 1, 1, 256)
# x2 = torch.rand(512, 1, 1, 256)

x1 = torch.rand(512, 1, 256)
x2 = torch.rand(512, 1, 256)
model = Two_channel_UnetGenerator()
print(model(x1, x2).shape)

## Training

In [None]:
# Visualize the data we want to generate
plt.figure(figsize=(20, 5))
plt.plot(df_train[SIG_A][:1000], label= SIG_A)
plt.plot(df_train[SIG_B][:1000], label= SIG_B)
plt.plot(df_train[TARGET][:1000], label= TARGET)
plt.title('Comparison of {} and {}'.format(SIG_A, SIG_B))
plt.legend()
plt.show()

In [None]:
# initialize generator and discriminator
gen_source = Two_channel_UnetGenerator().to(DEVICE)
gen_target = Two_channel_UnetGenerator().to(DEVICE)
disc_source = Discriminator().to(DEVICE)
disc_target = Discriminator().to(DEVICE)

# optimizers for discriminator and generator 
opt_disc = torch.optim.AdamW(                                         
    list(disc_source.parameters()) + list(disc_target.parameters()), 
    lr=LEARNING_RATE, 
)
opt_gen = torch.optim.AdamW(
    list(gen_source.parameters()) + list(gen_target.parameters()),
    lr=LEARNING_RATE,
)

# scheduler
gen_scheduler = torch.optim.lr_scheduler.PolynomialLR(optimizer = opt_gen,
                                                      total_iters = NUM_EPOCHS-LR_DECAY_AFTER_EPOCH, 
                                                      power = 1,
                                                    )
disc_scheduler = torch.optim.lr_scheduler.PolynomialLR(optimizer = opt_disc,
                                                       total_iters = NUM_EPOCHS-LR_DECAY_AFTER_EPOCH, 
                                                       power = 1,
                                                    )

# run in float16
g_scaler = torch.cuda.amp.GradScaler()
d_scaler = torch.cuda.amp.GradScaler()

l1 = nn.L1Loss() # L1 loss for cycle consistency and identity loss
mse = nn.MSELoss() # MSE loss for adversarial loss

# create datasets 
dataset = TrainDataset(signal_A=SIG_A, signal_B=SIG_B, target=TARGET, df=df_train)
test_dataset = TestDataset(signal_A=SIG_A, signal_B=SIG_B, target=TARGET, df=df_test)
#gen_dataset = TestDataset(signal_A=SIG_A, signal_B=SIG_B, target=TARGET, df=df_gen)  

# Data loader
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True,)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True,)
#gen_loader = DataLoader(gen_dataset, batch_size=1, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True,)

# run in float16
g_scaler = torch.cuda.amp.GradScaler()
d_scaler = torch.cuda.amp.GradScaler()

train_losses = {
            'Discrminator source loss' : [],
            'Discrminator target loss' : [],
            'Total Discrminator loss' : [],
            'Adversaral loss source' : [],
            'Adversaral loss target' : [],
            'Cycle consistency loss source' : [],
            'Cycle consistency loss target' : [],
            'Total Generator loss' : [],
        }

test_losses = {
            'Discrminator source loss' : [],
            'Discrminator target loss' : [],
            'Total Discrminator loss' : [],
            'Adversaral loss source' : [],
            'Adversaral loss target' : [],
            'Cycle consistency loss source' : [],
            'Cycle consistency loss target' : [],
            'Total Generator loss' : [],
            'L1 loss between real source signal and fake source signals' : [],
            'L1 loss between real target signal and fake target signals' : [],
        }

B_reals = 0
B_fakes = 0

In [None]:
# training loop

for epoch in range(NUM_EPOCHS):

    for sig_A, sig_B, target in loader:
        # convert to float16
        sig_A = sig_A.float() # neccessary to prevent error: "Input type (torch.cuda.DoubleTensor) 
        sig_B = sig_B.float() # and weight type (torch.cuda.HalfTensor) should be the same"
        target = target.float()
        
        # move to GPU
        sig_A = sig_A.to(DEVICE)
        sig_B = sig_B.to(DEVICE)
        target = target.to(DEVICE)

        #  -------------------------------- #
        #  ----- train discriminators ----- #
        #  -------------------------------- #
        with torch.cuda.amp.autocast(): 
            fake_target = gen_target(sig_A, sig_B) # generate fake target signal
            d_target_real = disc_target(torch.cat([sig_A, sig_B], 1).to(DEVICE)) # output of discriminator target for real target signal
            
            # combine fake target signal with zeros to match the input size of the discriminator
            dummy_tensor = torch.zeros(len(sig_A), 1, 1, 256).to(DEVICE)
            combined_fake_target = torch.cat([fake_target.detach(), dummy_tensor], 1).to(DEVICE)
            d_target_fake = disc_target(combined_fake_target) # output of discriminator target for fake target signal 
                
            B_reals += d_target_real.mean().item()
            B_fakes += d_target_fake.mean().item()

            # Loss between dicriminator (with real signal) output and 1 - The discriminator should output 1 for real signals
            d_target_real_loss = mse(d_target_real, torch.ones_like(d_target_real))  
            # Loss between dicriminator (with fake signal) output and 0 - The discriminator should output 0 for fake signals
            d_target_fake_loss = mse(d_target_fake, torch.zeros_like(d_target_fake)) 
            # Total loss for discriminator B
            d_target_loss = d_target_real_loss + d_target_fake_loss

        
            fake_source = gen_source(target, dummy_tensor) # generate fake source signal
            combined_true_signals = torch.cat([sig_A, sig_B], 1).to(DEVICE)
            d_source_real = disc_source(combined_true_signals) # output of discriminator source for real source signal
            
            combined_fake_source= torch.cat([fake_source, dummy_tensor], 1).to(DEVICE)
            d_source_fake = disc_source(combined_fake_source.detach()) 

            d_source_real_loss = mse(d_source_real, torch.ones_like(d_source_real)) 
            d_source_fake_loss = mse(d_source_fake, torch.zeros_like(d_source_fake))  
            d_source_loss = d_source_real_loss + d_source_fake_loss

            # Total loss for discriminator A
            d_loss = d_source_loss + d_target_loss  # in cycle GAN paper they halve the loss

        # exit amp.auto_cast() context manager and backpropagate 
        opt_disc.zero_grad() 
        d_scaler.scale(d_loss).backward()  
        d_scaler.step(opt_disc)  
        d_scaler.update()


        # -------------------------------- #
        # ------- train generators ------- #
        # -------------------------------- # 
        with torch.cuda.amp.autocast():

            # ----- adversarial loss for both generators ----- #
            d_source_fake = disc_source(combined_fake_source) # disc_source should output 0 for fake source signal 
            d_target_fake = disc_target(combined_fake_target) # disc_target should output 0 for fake target signal
            # loss between discriminator output and 0 - The discriminator should output 0 for fake signals
            g_source_loss = mse(d_source_fake, torch.zeros_like(d_source_fake)) # was ones_like before  
            g_target_loss = mse(d_target_fake, torch.zeros_like(d_target_fake)) # was ones_like before

            # ----- cycle consistency loss ----- #
            cycle_target = gen_target(fake_source, dummy_tensor)  
            cycle_source = gen_source(fake_target, dummy_tensor) # fake_B = gen_A2B(sig_A)
            cycle_target_loss = l1(sig_B, cycle_target)  # l1 loss: Mean absolute error between each element in the input x and target y.
            cycle_source_loss = l1(sig_A, cycle_source)


            # put it all together
            g_loss = (
                g_source_loss +
                g_target_loss +
                cycle_target_loss * LAMBDA_CYCLE +
                cycle_source_loss * LAMBDA_CYCLE 
            )

        # update gradients of generator
        opt_gen.zero_grad()
        g_scaler.scale(g_loss).backward() 
        g_scaler.step(opt_gen) 
        g_scaler.update()

    # save losses
    train_losses['Discrminator source loss'].append(d_source_loss.item())
    train_losses['Discrminator target loss'].append(d_target_loss.item())
    train_losses['Total Discrminator loss'].append(d_loss.item())
    train_losses['Adversaral loss source'].append(g_source_loss.item())
    train_losses['Adversaral loss target'].append(g_target_loss.item())
    train_losses['Cycle consistency loss source'].append(cycle_source_loss.item())
    train_losses['Cycle consistency loss target'].append(cycle_target_loss.item())
    train_losses['Total Generator loss'].append(g_loss.item())  

    # ------------------------ #
    # ------ Validation ------ #
    # ------------------------ #

    #  validation every 10 epochs
    if (epoch+1) % 1 == 0:

        with torch.no_grad():
            # set models to evaluation mode
            disc_source.eval()  # set discriminator to evaluation mode
            disc_target.eval()  # turns off Dropouts Layers, BatchNorm Layers etc
            gen_target.eval()
            gen_source.eval()

            # store losses for testing
            test_Discrminator_source_loss = 0 #
            test_Discrminator_target_loss = 0 #
            test_Total_Discrminator_loss = 0 #
            test_Adversaral_loss_source = 0  #
            test_Adversaral_loss_target = 0  #
            test_Cycle_consistency_loss_source = 0 #
            test_Cycle_consistency_loss_target = 0 #
            test_Total_Generator_loss = 0 #
            test_L1_real_fake_source = 0  # L1 loss between real signal A and fake signal A
            test_L1_real_fake_target = 0  # L1 loss between real signal B and fake signal B

            for sig_A, sig_B, target in test_loader:
                # convert to float16
                sig_A = sig_A.float()
                sig_B = sig_B.float()
                target = target.float()
                # move to GPU
                sig_A = sig_A.to(DEVICE)
                sig_B = sig_B.to(DEVICE)
                target = target.to(DEVICE)

                dummy_tensor = torch.zeros(len(sig_A), 1, 1, 256).to(DEVICE)
                fake_target = gen_target(sig_A, sig_B)
                fake_source = gen_source(target, dummy_tensor)
        
                # calculate l1 loss of fake signals and real signals
                test_L1_real_fake_target = l1(target, fake_target)
                test_L1_real_fake_source = l1(torch.cat([sig_A, sig_B], 1), fake_source)  # maybe switch sig_A and sig_B

                # calculate adversarial loss
                dummy_fake_target = torch.zeros(len(fake_target), 1, 1, 256).to(DEVICE)  # torch.Size([512, 1, 1, 256])
                
                disc_target_fake_target = disc_target(torch.cat([fake_target, dummy_fake_target],1))
                disc_source_fake_source = disc_source(torch.cat([fake_source, dummy_tensor],1))
                test_Adversaral_loss_target = mse(disc_target_fake_target, torch.zeros_like(disc_target_fake_target)) #was ones_like before
                test_Adversaral_loss_source = mse(disc_source_fake_source, torch.zeros_like(disc_source_fake_source))

                # ----- cycle loss ----- #
                cycle_target = gen_target(fake_source, dummy_tensor)  
                cycle_source = gen_source(fake_target, dummy_tensor)  
                test_Cycle_consistency_loss_target = l1(target, cycle_target)
                test_Cycle_consistency_loss_source = l1(torch.cat([sig_A, sig_B], 1), cycle_source)

                # ----- discriminator loss ----- #
                source = torch.cat([sig_A, sig_B], 1) # torch.Size([512, 2, 1, 256])
                # print(disc_source(source).shape)  torch.Size([512, 1, 1, 16])
                #print(fake_source.shape)   # torch.Size([512, 1, 1, 256])
                dummy = torch.zeros(len(fake_source), 1, 1, 256).to(DEVICE)
                cat = torch.cat([fake_source, dummy], 1) 
                cat_target = torch.cat([target, dummy], 1).to(DEVICE)
                cat_fake_target = torch.cat([fake_target, dummy], 1).to(DEVICE)
                # print(disc_target(torch.cat([target, dummy], 1)).shape)   # torch.Size([512, 1, 1, 16])
                test_Discrminator_source_loss = mse(disc_source(source), torch.ones_like(disc_source(source))) + mse(disc_source(cat), torch.zeros_like(disc_source(cat)))
                
                # mse(disc_target(cat_target), torch.ones_like(disc_target(cat_target)))
                # mse(disc_target(cat_fake_target), torch.zeros_like(disc_target(cat_fake_target)))

                test_Discrminator_target_loss = mse(disc_target(cat_target), torch.ones_like(disc_target(cat_target))) + mse(disc_target(cat_fake_target), torch.zeros_like(disc_target(cat_fake_target)))
                
                # ----- total generator loss ----- #
                test_Total_Generator_loss = test_Adversaral_loss_source + test_Adversaral_loss_target + test_Cycle_consistency_loss_target + test_Cycle_consistency_loss_source 
                
                # ----- total discriminator loss ----- #
                test_Total_Discrminator_loss = test_Discrminator_source_loss + test_Discrminator_target_loss

                # save losses
                test_losses['Discrminator source loss'].append(test_Discrminator_source_loss.item())
                test_losses['Discrminator target loss'].append(test_Discrminator_target_loss.item())
                test_losses['Total Discrminator loss'].append(test_Total_Discrminator_loss.item())
                test_losses['Adversaral loss source'].append(test_Adversaral_loss_source.item())
                test_losses['Adversaral loss target'].append(test_Adversaral_loss_target.item())
                test_losses['Cycle consistency loss source'].append(test_Cycle_consistency_loss_source.item())
                test_losses['Cycle consistency loss target'].append(test_Cycle_consistency_loss_target.item())
                test_losses['Total Generator loss'].append(test_Total_Generator_loss.item())
                test_losses['L1 loss between real source signal and fake source signals'].append(test_L1_real_fake_source.item())
                test_losses['L1 loss between real target signal and fake target signals'].append(test_L1_real_fake_target.item())

In [None]:
# ----------------------------------- #
# -------------- PLOT --------------- #
# ----------------------------------- #

# Plot training losses in different subplots

fig, ax = plt.subplots(4, 1, figsize=(12, 24))
ax[0].plot(train_losses['Discrminator source loss'], label= 'Discrminator source loss (Training)')
ax[0].plot(train_losses['Discrminator target loss'], label= 'Discrminator target loss (Training)')
ax[0].plot(train_losses['Total Discrminator loss'], label= 'Total Discrminator loss (Training)')
ax[0].set_xlabel('Epoch')
ax[0].set_ylabel('Loss')
ax[0].legend()
#ax[0].set_title('Training Discriminator Loss')

ax[1].plot(train_losses['Adversaral loss source'], label= 'Adversaral loss source (Training)')
ax[1].plot(train_losses['Adversaral loss target'], label= 'Adversaral loss target (Training)')
ax[1].set_xlabel('Epoch')
ax[1].set_ylabel('Loss')
ax[1].legend()
#ax[1].set_title('Training Adversarial Loss')
ax[2].plot(train_losses['Cycle consistency loss source'], label= 'Cycle consistency loss source (Training)')
ax[2].plot(train_losses['Cycle consistency loss target'], label= 'Cycle consistency loss target (Training)')
ax[2].set_xlabel('Epoch')
ax[2].set_ylabel('Loss')
ax[2].legend()
#ax[2].set_title('Training Cycle Consistency Loss')
ax[3].plot(train_losses['Total Generator loss'], label= 'Total Generator loss (Training)')
ax[3].set_xlabel('Epoch')
ax[3].set_ylabel('Loss')
ax[3].legend()

In [None]:
fig, ax = plt.subplots(5, 1, figsize=(12, 24))
ax[0].plot(test_losses['Discrminator source loss'], label= 'Discrminator source loss (Test)')
ax[0].plot(test_losses['Discrminator target loss'], label= 'Discrminator target loss (Test)')
ax[0].plot(test_losses['Total Discrminator loss'], label= 'Total Discrminator loss (Test)')
ax[0].set_xlabel('Epoch')
ax[0].set_ylabel('Loss')
ax[0].legend()
#ax[4].set_title('Test Discriminator Loss')

ax[1].plot(test_losses['Adversaral loss source'], label= 'Adversaral loss source (Test)')
ax[1].plot(test_losses['Adversaral loss target'], label= 'Adversaral loss target (Test)')
ax[1].set_xlabel('Epoch')
ax[1].set_ylabel('Loss')
ax[1].legend()
#ax[5].set_title('Test Adversarial Loss')

ax[2].plot(test_losses['Cycle consistency loss source'], label= 'Cycle consistency loss source (Test)')
ax[2].plot(test_losses['Cycle consistency loss target'], label= 'Cycle consistency loss target (Test)')
ax[2].set_xlabel('Epoch')
ax[2].set_ylabel('Loss')
ax[2].legend()
#ax[6].set_title('Test Cycle Consistency Loss')

ax[3].plot(test_losses['Total Generator loss'], label= 'Total Generator loss (Test)')
ax[3].set_xlabel('Epoch')
ax[3].set_ylabel('Loss')
ax[3].legend()
#ax[7].set_title('Test Total Generator Loss')

ax[4].plot(test_losses['L1 loss between real source signal and fake source signals'], label= 'L1 loss between real source signal and fake source signals (Test)')
ax[4].plot(test_losses['L1 loss between real target signal and fake target signals'], label= 'L1 loss between real target signal and fake target signals (Test)')
ax[4].set_xlabel('Epoch')
ax[4].set_ylabel('Loss')
ax[4].legend()

In [None]:
            #  ------------------------------------- #   
            #  ------- Generate fake signals ------- #
            #  ------------------------------------- #
            
            # Generate fake signals after the last epoch
            
            if (epoch+1) % GENERATION_AFTER_EPOCH == 0:
                print('Generate fake signals')
                # generate fake signals 10 times
                #utils.save_predictions(gen_loader, gen_B, gen_A, fake_A, fake_B, DEVICE, mse)

                for sig_A, sig_B in gen_loader:
                    
                    sig_A = sig_A.float()
                    sig_B = sig_B.float()
                    sig_A = sig_A.to(DEVICE)
                    sig_B = sig_B.to(DEVICE)

                    fake_B = gen_B(sig_A)
                    fake_A = gen_A(sig_B)

                    # plot generated signals and real signals
                    #reshape to 1D
                    fake_B = fake_B.reshape(-1)
                    fake_A = fake_A.reshape(-1)
                    sig_A = sig_A.reshape(-1)
                    sig_B = sig_B.reshape(-1)

                    fig, ax = plt.subplots(2, 1, figsize=(10, 8))
                    ax[0].plot(sig_A.cpu().detach().numpy(), label= 'Real signal A')
                    ax[0].plot(fake_A.cpu().detach().numpy(), label= 'Generated signal A')
                    ax[0].set_xlabel('Epoch')
                    ax[0].set_ylabel('Loss')
                    ax[0].legend()

                    ax[1].plot(sig_B.cpu().detach().numpy(), label= 'Real signal B')
                    ax[1].plot(fake_B.cpu().detach().numpy(), label= 'Generated signal B')
                    ax[1].set_xlabel('Epoch')
                    ax[1].set_ylabel('Loss')
                    ax[1].legend()

                    # plot generated signals and real signals
                    # plt.figure(figsize=(20, 5))
                    # plt.plot(sig_A.cpu().detach().numpy(), label= 'Real signal A')
                    # plt.plot(sig_B.cpu().detach().numpy(), label= 'Real signal B')
                    # plt.plot(fake_A.cpu().detach().numpy(), label= 'Generated signal A')
                    # plt.plot(fake_B.cpu().detach().numpy(), label= 'Generated signal B')
                    # plt.title('Generated signals vs real signals')
                    # plt.legend()

                        
                    # save generated signals
                    # utils.save_predictions(sig_A, sig_B, fake_A, fake_B, epoch, mse, l1, DEVICE)

        # activate training mode again
        disc_A.train()  
        disc_B.train()
        gen_B.train()
        gen_A.train()

    # scheduler step if epoch > LR_DECAY_AFTER_EPOCH
    if (epoch+1) >= LR_DECAY_AFTER_EPOCH:
        disc_scheduler.step()
        gen_scheduler.step()
        
# ----------------------------------- #
# -------------- PLOT --------------- #
# ----------------------------------- #

# Plot training losses in different subplots

fig, ax = plt.subplots(9, 1, figsize=(12, 24))
ax[0].plot(train_losses['Discrminator A loss'], label= 'Discrminator A loss (Training)')
ax[0].plot(train_losses['Discrminator B loss'], label= 'Discrminator B loss (Training)')
ax[0].plot(train_losses['Total Discrminator loss'], label= 'Total Discrminator loss (Training)')
ax[0].set_xlabel('Epoch')
ax[0].set_ylabel('Loss')
ax[0].legend()
#ax[0].set_title('Training Discriminator Loss')

ax[1].plot(train_losses['Adversaral loss A'], label= 'Adversaral loss A (Training)')
ax[1].plot(train_losses['Adversaral loss B'], label= 'Adversaral loss B (Training)')
ax[1].set_xlabel('Epoch')
ax[1].set_ylabel('Loss')
ax[1].legend()
#ax[1].set_title('Training Adversarial Loss')
ax[2].plot(train_losses['Cycle consistency loss A'], label= 'Cycle consistency loss A (Training)')
ax[2].plot(train_losses['Cycle consistency loss B'], label= 'Cycle consistency loss B (Training)')
ax[2].set_xlabel('Epoch')
ax[2].set_ylabel('Loss')
ax[2].legend()
#ax[2].set_title('Training Cycle Consistency Loss')
ax[3].plot(train_losses['Total Generator loss'], label= 'Total Generator loss (Training)')
ax[3].set_xlabel('Epoch')
ax[3].set_ylabel('Loss')
ax[3].legend()
#ax[3].set_title('Training Total Generator Loss')

# Plot test losses in different subplots

ax[4].plot(test_losses['Discrminator A loss'], label= 'Discrminator A loss (Test)')
ax[4].plot(test_losses['Discrminator B loss'], label= 'Discrminator B loss (Test)')
ax[4].plot(test_losses['Total Discrminator loss'], label= 'Total Discrminator loss (Test)')
ax[4].set_xlabel('Epoch')
ax[4].set_ylabel('Loss')
ax[4].legend()
#ax[4].set_title('Test Discriminator Loss')

ax[5].plot(test_losses['Adversaral loss A'], label= 'Adversaral loss A (Test)')
ax[5].plot(test_losses['Adversaral loss B'], label= 'Adversaral loss B (Test)')
ax[5].set_xlabel('Epoch')
ax[5].set_ylabel('Loss')
ax[5].legend()
#ax[5].set_title('Test Adversarial Loss')

ax[6].plot(test_losses['Cycle consistency loss A'], label= 'Cycle consistency loss A (Test)')
ax[6].plot(test_losses['Cycle consistency loss B'], label= 'Cycle consistency loss B (Test)')
ax[6].set_xlabel('Epoch')
ax[6].set_ylabel('Loss')
ax[6].legend()
#ax[6].set_title('Test Cycle Consistency Loss')

ax[7].plot(test_losses['Total Generator loss'], label= 'Total Generator loss (Test)')
ax[7].set_xlabel('Epoch')
ax[7].set_ylabel('Loss')
ax[7].legend()
#ax[7].set_title('Test Total Generator Loss')

ax[8].plot(test_losses['L1 loss between real signal A and fake signals A'], label= 'L1 loss between real signal A and fake signals A (Test)')
ax[8].plot(test_losses['L1 loss between real signal B and fake signals B'], label= 'L1 loss between real signal B and fake signals B (Test)')
ax[8].set_xlabel('Epoch')
ax[8].set_ylabel('Loss')
ax[8].legend()
#ax[8].set_title('Test L1 Loss')

plt.show()

print('Selected losses for the test dataset after the last epoch:\n')
print('\nL1 loss between real signal A and fake signals A: ', test_losses['L1 loss between real signal A and fake signals A'][-1])
print('\nL1 loss between real signal B and fake signals B: ', test_losses['L1 loss between real signal B and fake signals B'][-1])
print('\nDiscrminator A loss: ', test_losses['Discrminator A loss'][-1])
print('\nDiscrminator B loss: ', test_losses['Discrminator B loss'][-1])
print('\nTotal Discriminator loss: ', test_losses['Total Discrminator loss'][-1])
print('\nAdversaral loss A: ', test_losses['Adversaral loss A'][-1])
print('\nAdversaral loss B: ', test_losses['Adversaral loss B'][-1])
print('\nCycle consistency loss A: ', test_losses['Cycle consistency loss A'][-1])
print('\nCycle consistency loss B: ', test_losses['Cycle consistency loss B'][-1])
print('\nTotal Generator loss: ', test_losses['Total Generator loss'][-1])


print('Training finished')