# Cycle GAN

In [2]:
#import os
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import config
import utils
from torch.utils.data import DataLoader
import tqdm as tqdm


  from .autonotebook import tqdm as notebook_tqdm


### Discriminator

In [3]:
class Discriminator(nn.Module):
    def __init__(self, in_channels, out_channels=1):
        super().__init__()
        self.disc = nn.Sequential(
            nn.Conv1d(in_channels, out_channels=1, kernel_size=4, stride=2, padding=0),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv1d(1, 1, 4, 2, 0),
            nn.InstanceNorm1d(1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv1d(1, 1, 4, 2, 0),
            nn.InstanceNorm1d(1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv1d(1, 1, 4, 2, 0),
            nn.InstanceNorm1d(1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv1d(1, 1, 4, 2, 0),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.disc(x)

# pass 1d tensor through the discriminator and return the output
#LVP_tiny = LVP[0:100]
#LVP_tiny = torch.tensor(LVP_tiny).double()
#LVP_tiny = LVP_tiny.reshape(1, 1, 100)
#LVP_tiny = LVP_tiny.float()
#print(LVP_tiny)
disc = Discriminator(1)
#y = disc(LVP_tiny)
print(disc)


Discriminator(
  (disc): Sequential(
    (0): Conv1d(1, 1, kernel_size=(4,), stride=(2,))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv1d(1, 1, kernel_size=(4,), stride=(2,))
    (3): InstanceNorm1d(1, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv1d(1, 1, kernel_size=(4,), stride=(2,))
    (6): InstanceNorm1d(1, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv1d(1, 1, kernel_size=(4,), stride=(2,))
    (9): InstanceNorm1d(1, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv1d(1, 1, kernel_size=(4,), stride=(2,))
    (12): LeakyReLU(negative_slope=0.2, inplace=True)
    (13): Sigmoid()
  )
)


### Generator

In [4]:
#residual block
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv1d(channels, channels, 3, 1, 1),
            nn.InstanceNorm1d(channels),
            nn.ReLU(inplace=True),
            nn.Conv1d(channels, channels, 3, 1, 1),
            nn.InstanceNorm1d(channels),
                )

    def forward(self, x):
        return x + self.block(x)

# Generator class with 2 Residual blocks
class Generator(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.gen = nn.Sequential(
            nn.Conv1d(in_channels, out_channels=64, kernel_size=7, stride=1, padding=3),
            nn.InstanceNorm1d(64),
            nn.ReLU(inplace=True),
            nn.Conv1d(64, 128, 3, 2, 1),
            nn.InstanceNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Conv1d(128, 256, 3, 2, 1),
            nn.InstanceNorm1d(256),
            nn.ReLU(inplace=True),
            ResidualBlock(256),
            ResidualBlock(256),
            nn.ConvTranspose1d(256, 128, 3, 2, 1, 1),
            nn.InstanceNorm1d(128),
            nn.ReLU(inplace=True),
            nn.ConvTranspose1d(128, 64, 3, 2, 1, 1),
            nn.InstanceNorm1d(64),
            nn.ReLU(inplace=True),
            nn.Conv1d(64, out_channels, 7, 1, 3),
            nn.Tanh(),
        )

    def forward(self, x):
        return self.gen(x)

# pass 1d tensor through the block and return the output
gen = Generator(1, 1)
#y = gen( LVP_tiny)
#print("Output of Generator: ", y.shape, '\n', y, '\n')
print(gen)

Generator(
  (gen): Sequential(
    (0): Conv1d(1, 64, kernel_size=(7,), stride=(1,), padding=(3,))
    (1): InstanceNorm1d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (2): ReLU(inplace=True)
    (3): Conv1d(64, 128, kernel_size=(3,), stride=(2,), padding=(1,))
    (4): InstanceNorm1d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (5): ReLU(inplace=True)
    (6): Conv1d(128, 256, kernel_size=(3,), stride=(2,), padding=(1,))
    (7): InstanceNorm1d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (8): ReLU(inplace=True)
    (9): ResidualBlock(
      (block): Sequential(
        (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
        (1): InstanceNorm1d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (2): ReLU(inplace=True)
        (3): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
        (4): InstanceNorm1d(256, eps=1e-05, momentum=0.1, a

### Dataset

In [6]:
# read csv file semi-colon separated
df_1 = pd.read_csv("/home/johann/Desktop/Uni/Masterarbeit/Cycle_GAN/LeRntVAD_csv_exports/constant_speed_interventions/intervention_11_1_1_1_1_1_1_2.csv", sep=";")
#df_2 = pd.read_csv("LeRntVAD_csv_exports/constant_speed_interventions/intervention_11_1_1_1_1_1_3_2.csv", sep=";")
#df_3 = pd.read_csv("LeRntVAD_csv_exports/constant_speed_interventions/intervention_11_1_1_1_1_1_10_2.csv", sep=";")
#df_4 = pd.read_csv("LeRntVAD_csv_exports/constant_speed_interventions/intervention_11_1_2_1_1_1_1_2.csv", sep=";")
#df_5 = pd.read_csv("LeRntVAD_csv_exports/constant_speed_interventions/intervention_11_1_2_1_1_1_3_2.csv", sep=";")
#df_6 = pd.read_csv("LeRntVAD_csv_exports/constant_speed_interventions/intervention_11_1_2_1_1_1_10_2.csv", sep=";")
#df_7 = pd.read_csv("LeRntVAD_csv_exports/constant_speed_interventions/intervention_11_1_3_1_1_1_1_2.csv", sep=";")
#df_8 = pd.read_csv("LeRntVAD_csv_exports/constant_speed_interventions/intervention_11_1_3_1_1_1_3_2.csv", sep=";")
#df_9 = pd.read_csv("LeRntVAD_csv_exports/constant_speed_interventions/intervention_11_1_3_1_1_1_10_2.csv", sep=";")
#df_10 = pd.read_csv("LeRntVAD_csv_exports/constant_speed_interventions/intervention_11_1_4_1_1_1_1_2.csv", sep=";")
#df_11 = pd.read_csv("LeRntVAD_csv_exports/constant_speed_interventions/intervention_11_1_4_1_1_1_3_2.csv", sep=";")
#df_12 = pd.read_csv("LeRntVAD_csv_exports/constant_speed_interventions/intervention_13_1_1_1_1_1_1_2.csv", sep=";")
#df_13 = pd.read_csv("LeRntVAD_csv_exports/constant_speed_interventions/intervention_13_1_1_1_1_1_3_2.csv", sep=";")
#df_14 = pd.read_csv("LeRntVAD_csv_exports/constant_speed_interventions/intervention_13_1_1_1_1_1_10_2.csv", sep=";")
#df_15 = pd.read_csv("LeRntVAD_csv_exports/constant_speed_interventions/intervention_13_1_2_1_1_1_1_2.csv", sep=";")
#df_16 = pd.read_csv("LeRntVAD_csv_exports/constant_speed_interventions/intervention_13_1_2_1_1_1_3_2.csv", sep=";")

# concatenate all dataframes
df = df_1
#df = pd.concat([df_1, df_2, df_3, df_4, df_5, df_6, df_7, df_8, df_9, df_10, df_11, df_12, df_13, df_14, df_15, df_16], ignore_index=True)
print(df.shape)

# access columns by name (e.g. df['LVP']) or by index (e.g. df.iloc[:, 0])

#df_57 = pd.read_csv("LeRntVAD_csv_exports/constant_speed_interventions/intervention_20_1_3_1_1_1_1_2.csv", sep=";")
df_58 = pd.read_csv("/home/johann/Desktop/Uni/Masterarbeit/Cycle_GAN/LeRntVAD_csv_exports/constant_speed_interventions/intervention_20_1_3_1_1_1_3_2.csv", sep=";")

df_test = df_58 #pd.concat([df_57, df_58], ignore_index=True)
print(df_test.shape)

(116300, 39)
(65800, 39)


In [8]:
# subsample data by a factor of 10
#df = df.sample(frac=0.1, random_state=1)
#df_test = df_test.sample(frac=0.1, random_state=1)

print(df.shape)
print(df_test.shape)

(116300, 39)
(65800, 39)


In [6]:
# small dataframe to make it easier to test the code
# df = pd.read_csv("LeRntVAD_csv_exports/constant_speed_interventions/intervention_11_1_1_1_1_1_1_2.csv", sep=";")

In [7]:
class SignalDataset(Dataset):
    def __init__(self, signal_A, signal_B, df):
        self.df = df
        self.signal_A = self.df[signal_A]
        self.signal_B = self.df[signal_B]

        # creating tensor from df 
        self.tensor_A = torch.tensor(self.df[signal_A].values).double()
        self.tensor_B = torch.tensor(self.df[signal_B].values).double()

        # split tensor into tensors of size 100
        self.tensor_A = self.tensor_A.split(100)
        self.tensor_B = self.tensor_B.split(100)

        for el in self.tensor_A:
            el = el.reshape(1, 1, 100)
            el = el.float()
            
        for el in self.tensor_B:
            el = el.reshape(1, 1, 100)
            el = el.float()
   

    def __len__(self):
        # signal_A and signal_B should have the same length
        return len(self.tensor_A)

    def __getitem__(self, index):
        # return the signal at the given index
        return self.tensor_A[index], self.tensor_B[index]

### Train model

In [8]:
# load checkpoint - only during train phase
#model_A2B = utils.load_checkpoint(config.CHECKPOINT_GEN_A2B, gen_A2B, opt_gen, config.LEARNING_RATE)
#model_B2A = utils.load_checkpoint(config.CHECKPOINT_GEN_B2A, gen_B2A, opt_gen, config.LEARNING_RATE)

In [9]:
def main():
    # initialize generator and discriminator
    gen_A2B = Generator(in_channels=1, out_channels=1).to("cuda")
    gen_B2A = Generator(in_channels=1, out_channels=1).to("cuda")
    disc_A = Discriminator(in_channels=1).to("cuda")
    disc_B = Discriminator(in_channels=1).to("cuda")

    # optimizers for discriminator and generator 
    opt_disc = torch.optim.Adam(
        list(disc_A.parameters()) + list(disc_B.parameters()), 
        lr=0.0002, 
        betas=(0.5, 0.999) 
    )
    opt_gen = torch.optim.Adam(
        list(gen_A2B.parameters()) + list(gen_B2A.parameters()),
        lr=0.0002,
        betas=(0.5, 0.999)
    )

    l1 = nn.L1Loss() # L1 loss for cycle consistency and identity loss
    mse = nn.MSELoss() # MSE loss for adversarial loss

    # load checkpoint if required
    load_checkpoint = False

    if load_checkpoint:
        utils.load_checkpoint(
            config.CHECKPOINT_GEN_B2A, gen_A2B, opt_gen, config.LEARNING_RATE,
        )
        utils.load_checkpoint(
            config.CHECKPOINT_GEN_A2B, gen_B2A, opt_gen, config.LEARNING_RATE,
        )
        utils.load_checkpoint(
            config.CHECKPOINT_DISC_A, disc_A, opt_disc, config.LEARNING_RATE,
        )
        utils.load_checkpoint(
            config.CHECKPOINT_DISC_B, disc_B, opt_disc, config.LEARNING_RATE,
        )

    # train data
    print('Shape of train df: ', df.shape, '\n') # df is a global variable that contains the data
    df_train = df

    # test data
    print('Shape of test dataframe: ', df_test.shape)

    # create datasets with class SignalDataset
    dataset = SignalDataset(signal_A='LVP', signal_B='AoP', df=df_train)
    test_dataset = SignalDataset(signal_A='LVP', signal_B='AoP', df=df_test)
    
    # Data loader
    batch_size = 1  # best batch size according to cycle GAN paper

    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=config.NUM_WORKERS, pin_memory=True,)

    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, pin_memory=True,)

    # run in float16
    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()

    # training loop

    # tqdm for progress bar
    # loop = tqdm(loader, leave=True) # does not work
    NUM_EPOCHS = 5
    for epoch in range(NUM_EPOCHS):

        for sig_A, sig_B in loader:
            # convert to float16
            sig_A = sig_A.float()
            sig_B = sig_B.float()

            # move to GPU
            sig_A = sig_A.to(config.DEVICE)
            sig_B = sig_B.to(config.DEVICE)

            # train discriminators
            with torch.cuda.amp.autocast(): #necessary for float16

                fake_B = gen_A2B(sig_A)
                d_B_real = disc_B(sig_B)
                d_B_fake = disc_B(fake_B.detach())
                d_B_real_loss = mse(d_B_real, torch.ones_like(d_B_real)) # why ones? Mentioned in cycle GAN paper
                d_B_fake_loss = mse(d_B_fake, torch.zeros_like(d_B_fake)) # why zeros? Mentioned in cycle GAN paper
                d_B_loss = d_B_real_loss + d_B_fake_loss

                fake_A = gen_B2A(sig_B)
                d_A_real = disc_A(sig_A)
                d_A_fake = disc_B(fake_A.detach()) 
                d_A_real_loss = mse(d_A_real, torch.ones_like(d_A_real)) # why ones? Mentioned in cycle GAN paper
                d_A_fake_loss = mse(d_A_fake, torch.zeros_like(d_A_fake)) # why zeros? Mentioned in cycle GAN paper
                d_A_loss = d_A_real_loss + d_A_fake_loss

                # put it all together
                d_loss = (d_A_loss + d_B_loss) / 2 # why average? Mentioned in cycle GAN paper

            # exit amp.auto_cast() context manager and backpropagate 
            opt_disc.zero_grad()   
            d_scaler.scale(d_loss).backward() 
            d_scaler.step(opt_disc)
            d_scaler.update()

            # train generators
            with torch.cuda.amp.autocast():

                # adversarial loss for both generators
                d_A_fake = disc_A(fake_A) # d_A_fake is the output of the discriminator for the fake_A signal
                d_B_fake = disc_B(fake_B)
                g_A_loss = mse(d_A_fake, torch.ones_like(d_A_fake))
                g_B_loss = mse(d_B_fake, torch.ones_like(d_B_fake))

                # cycle consistency loss
                cycle_B = gen_A2B(fake_A)
                cycle_A = gen_B2A(fake_B)
                cycle_B_loss = l1(sig_B, cycle_B)
                cycle_A_loss = l1(sig_A, cycle_A)

                # identity loss
                id_B = gen_A2B(sig_B) # id_B is the output of the generator for the sig_B signal
                id_A = gen_B2A(sig_A)
                id_B_loss = l1(sig_B, id_B)
                id_A_loss = l1(sig_A, id_A)

                # put it all together
                g_loss = (
                    g_A_loss +
                    g_B_loss +
                    cycle_B_loss * config.LAMBDA_CYCLE +
                    cycle_A_loss * config.LAMBDA_CYCLE +
                    id_B_loss * config.LAMBDA_IDENTITY +  # config.LAMBDA_IDENTITY = 0.0 -> no identity loss 
                    id_A_loss * config.LAMBDA_IDENTITY    # we could remove it to increase training speed
                )

            opt_gen.zero_grad()
            g_scaler.scale(g_loss).backward()
            g_scaler.step(opt_gen)
            g_scaler.update()
        
        # print(f"Loss D: {d_loss.item():.4f}, loss G: {g_loss.item():.4f}")
        print('\nEpoch [{}/{}], Loss D: {:.4f}, loss G: {:.4f}'.format(epoch+1, NUM_EPOCHS, d_loss.item(), g_loss.item()))
        
        #  validation
        disc_A.eval()  # set discriminator to evaluation mode
        disc_B.eval()  # turns off Dropouts Layers, BatchNorm Layers etc
        gen_A2B.eval()
        gen_B2A.eval()

        with torch.no_grad():
            mse_G_A2B = 0
            mse_G_B2A = 0
            for sig_A, sig_B in test_loader:
                # convert to float16
                sig_A = sig_A.float()
                sig_B = sig_B.float()

                # move to GPU
                sig_A = sig_A.to(config.DEVICE)
                sig_B = sig_B.to(config.DEVICE)

                fake_B = gen_A2B(sig_A)
                fake_A = gen_B2A(sig_B)
                
                # calculate mse loss of fake signals and real signals
                mse_A2B = mse(sig_B, fake_B)
                mse_B2A = mse(sig_A, fake_A)
                mse_G_A2B += mse_A2B
                mse_G_B2A += mse_B2A

            # print('\nTest:\nMSE loss Generator A2B: {:.4f}, MSE loss Generator B2A: {:.4f}'.format(mse_G_A2B/len(test_loader), mse_G_B2A/len(test_loader)))

    if config.SAVE_MODEL:
        utils.save_checkpoint(gen_A2B, opt_gen, path=config.CHECKPOINT_GEN_A2B)
        utils.save_checkpoint(gen_B2A, opt_gen, path=config.CHECKPOINT_GEN_B2A)
        utils.save_checkpoint(disc_A, opt_disc, path=config.CHECKPOINT_DISC_A)
        utils.save_checkpoint(disc_B, opt_disc, path=config.CHECKPOINT_DISC_B)

        # print progress
        print(f"Epoch [{epoch+1}/{NUM_EPOCHS}]")


if __name__ == '__main__':
    main()

Shape of train df:  (116300, 39) 

Shape of test dataframe:  (175800, 39)

Epoch [1/5], Loss D: 0.4390, loss G: 845.9126

Epoch [2/5], Loss D: 0.4278, loss G: 493.8284

Epoch [3/5], Loss D: 0.3816, loss G: 483.7133

Epoch [4/5], Loss D: 0.3756, loss G: 629.4152

Epoch [5/5], Loss D: 0.3648, loss G: 457.9908
=> Saving checkpoint at location:  Checkpoints/LVtot_kalibriert_to_RVtot_kalibriert/gen_LVtot.pth.tar
=> Saving checkpoint at location:  Checkpoints/LVtot_kalibriert_to_RVtot_kalibriert/gen_RVtot.pth.tar
=> Saving checkpoint at location:  Checkpoints/LVtot_kalibriert_to_RVtot_kalibriert/disc_LVtot.pth.tar
=> Saving checkpoint at location:  Checkpoints/LVtot_kalibriert_to_RVtot_kalibriert/disc_RVtot.pth.tar
Epoch [5/5]


### Some results

##### LVP to AoP

Epoch [1/5], Loss D: 0.3012, loss G: 619.1838

Epoch [2/5], Loss D: 0.3729, loss G: 1133.2758

Epoch [3/5], Loss D: 0.2997, loss G: 818.2902

Epoch [4/5], Loss D: 0.3086, loss G: 794.9103

Epoch [5/5]

#### AoP to AoQ

Epoch [1/10], Loss D: 0.3331, loss G: 763.7483

Epoch [2/10], Loss D: 0.3324, loss G: 528.8701

Epoch [3/10], Loss D: 0.3328, loss G: 416.0652

Epoch [4/10], Loss D: 0.3331, loss G: 541.8179

Epoch [5/10], Loss D: 0.3312, loss G: 500.3121

Epoch [6/10], Loss D: 0.3270, loss G: 490.1998

Epoch [7/10], Loss D: 0.3597, loss G: 390.4794

Epoch [8/10], Loss D: 0.3440, loss G: 403.8621

Epoch [9/10], Loss D: 0.3366, loss G: 566.6377

Epoch [10/10], Loss D: 0.3217, loss G: 452.6429

In [10]:
#To load a saved version of the model:

#saved_model = GarmentClassifier()
#saved_model.load_state_dict(torch.load(PATH))

#### To do: Use batch size > 1, mit Anton den code durchgehen

In [11]:
def test_fn(gen_A2B, gen_B2A, test_loader):

    for sig_A, sig_B in test_loader:
        # convert to float16
        sig_A = sig_A.float()
        sig_B = sig_B.float()

        # move to GPU
        sig_A = sig_A.to(config.DEVICE)
        sig_B = sig_B.to(config.DEVICE)

        # generate fake signals
        fake_B = gen_A2B(sig_A)
        fake_A = gen_B2A(sig_B)

        # store fake signals in dataframe
        fake_B = fake_B.cpu().detach().numpy()
        fake_A = fake_A.cpu().detach().numpy()
        fake_B = pd.DataFrame(fake_B)
        fake_A = pd.DataFrame(fake_A)
        # append both dataframes
        fake_B = fake_B.append(fake_A) 
        fake_B.to_csv('fake_signals.csv')

        #calculate mse loss between fake and real signals
        mse_loss = mse(sig_B, fake_B)
        print(f"MSE loss: {mse_loss.item():.4f}")

#### The code below works, but is split into a main() and train() function

In [12]:
def train_fn(disc_A, disc_B, gen_A2B, gen_B2A, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler):

   #loop = tqdm(loader) # progress bar for training loop 

   for sig_A, sig_B in loader:
        # convert to float16
        sig_A = sig_A.float()
        sig_B = sig_B.float()

        # move to GPU
        sig_A = sig_A.to(config.DEVICE)
        sig_B = sig_B.to(config.DEVICE)

        # train discriminators
        with torch.cuda.amp.autocast(): #necessary for float16

            fake_B = gen_A2B(sig_A)
            d_B_real = disc_B(sig_B)
            d_B_fake = disc_B(fake_B.detach())
            d_B_real_loss = mse(d_B_real, torch.ones_like(d_B_real)) # why ones? Mentioned in cycle GAN paper
            d_B_fake_loss = mse(d_B_fake, torch.zeros_like(d_B_fake)) # why zeros? Mentioned in cycle GAN paper
            d_B_loss = d_B_real_loss + d_B_fake_loss

            fake_A = gen_B2A(sig_B)
            d_A_real = disc_A(sig_A)
            d_A_fake = disc_B(fake_A.detach()) 
            d_A_real_loss = mse(d_A_real, torch.ones_like(d_A_real)) # why ones? Mentioned in cycle GAN paper
            d_A_fake_loss = mse(d_A_fake, torch.zeros_like(d_A_fake)) # why zeros? Mentioned in cycle GAN paper
            d_A_loss = d_A_real_loss + d_A_fake_loss

            # put it together
            d_loss = (d_A_loss + d_B_loss) / 2 # why average? Mentioned in cycle GAN paper

        # exit amp.auto_cast() context manager and backpropagate 
        opt_disc.zero_grad()   
        d_scaler.scale(d_loss).backward() 
        d_scaler.step(opt_disc)
        d_scaler.update()

        # train generators
        with torch.cuda.amp.autocast():

            # adversarial loss for both generators
            d_A_fake = disc_A(fake_A) # d_A_fake is the output of the discriminator for the fake_A signal
            d_B_fake = disc_B(fake_B)
            g_A_loss = mse(d_A_fake, torch.ones_like(d_A_fake))
            g_B_loss = mse(d_B_fake, torch.ones_like(d_B_fake))

            # cycle consistency loss
            cycle_B = gen_A2B(fake_A)
            cycle_A = gen_B2A(fake_B)
            cycle_B_loss = l1(sig_B, cycle_B)
            cycle_A_loss = l1(sig_A, cycle_A)

            # identity loss
            id_B = gen_A2B(sig_B) # id_B is the output of the generator for the sig_B signal
            id_A = gen_B2A(sig_A)
            id_B_loss = l1(sig_B, id_B)
            id_A_loss = l1(sig_A, id_A)

            # put it together
            g_loss = (
                g_A_loss +
                g_B_loss +
                cycle_B_loss * config.LAMBDA_CYCLE +
                cycle_A_loss * config.LAMBDA_CYCLE +
                id_B_loss * config.LAMBDA_IDENTITY +  # LAMBDA_IDENTITY = 0.0 -> no identity loss 
                id_A_loss * config.LAMBDA_IDENTITY    # we could remove it t increase training speed
            )

        opt_gen.zero_grad()
        g_scaler.scale(g_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()
        
        # print(f"Loss D: {d_loss.item():.4f}, loss G: {g_loss.item():.4f}")
        
        #  validation
        

def main():
    disc_A = Discriminator(in_channels=1).to("cuda")
    disc_B = Discriminator(in_channels=1).to("cuda")
    gen_A2B = Generator(in_channels=1, out_channels=1).to("cuda")
    gen_B2A = Generator(in_channels=1, out_channels=1).to("cuda")

    # optimizers for discriminator and generator 
    opt_disc = torch.optim.Adam(
        list(disc_A.parameters()) + list(disc_B.parameters()), 
        lr=0.0002, 
        betas=(0.5, 0.999) 
    )
    opt_gen = torch.optim.Adam(
        list(gen_A2B.parameters()) + list(gen_B2A.parameters()),
        lr=0.0002,
        betas=(0.5, 0.999)
    )

    L1_loss = nn.L1Loss() # L1 loss for cycle consistency and identity loss
    MSE_loss = nn.MSELoss() # MSE loss for adversarial loss

    if config.LOAD_MODEL:
        load_checkpoint(
            config.CHECKPOINT_GEN_B2A, gen_A2B, opt_gen, config.LEARNING_RATE,
        )
        load_checkpoint(
            config.CHECKPOINT_GEN_A2B, gen_B2A, opt_gen, config.LEARNING_RATE,
        )
        load_checkpoint(
            config.CHECKPOINT_DISC_A, disc_A, opt_disc, config.LEARNING_RATE,
        )
        load_checkpoint(
            config.CHECKPOINT_DISC_B, disc_B, opt_disc, config.LEARNING_RATE,
        )

    # dataset
    # df is a global variable that contains the data
    print('Shape of df: ', df.shape, '\n')
    df_train = df
    # test data
    print('Shape of df_test: ', df_test.shape)

    # create dataset
    dataset = SignalDataset(signal_A='LVP', signal_B='AoP', df=df_train)
    test_dataset = SignalDataset(signal_A='LVP', signal_B='AoP', df=df_test)
    
    # Data loader

    loader = DataLoader(
        dataset, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=config.NUM_WORKERS, pin_memory=True,
    )

    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, pin_memory=True,)

    # run in float16
    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()

    # training loop
    NUM_EPOCHS = 1
    for epoch in range(NUM_EPOCHS):
        train_fn(
            disc_A,
            disc_B,
            gen_A2B,
            gen_B2A,
            loader,
            opt_disc,
            opt_gen,
            L1_loss,
            MSE_loss,
            d_scaler,
            g_scaler,
        )

        if config.SAVE_MODEL:
            utils.save_checkpoint(gen_A2B, opt_gen, filename=config.CHECKPOINT_GEN_A2B)
            utils.save_checkpoint(gen_B2A, opt_gen, filename=config.CHECKPOINT_GEN_B2A)
            utils.save_checkpoint(disc_A, opt_disc, filename=config.CHECKPOINT_DISC_A)
            utils.save_checkpoint(disc_B, opt_disc, filename=config.CHECKPOINT_DISC_B)

        # print progress
        print(f"Epoch [{epoch+1}/{NUM_EPOCHS}]")


if __name__ == '__main__':
    main()

Shape of df:  (116300, 39) 

Shape of df_test:  (175800, 39)


TypeError: save_checkpoint() got an unexpected keyword argument 'filename'