In [None]:
%pip install torch-geometric
%pip install pynvml

In [None]:
import os
import math
import time
import random

import psutil
import pynvml
import numpy as np
import networkx as nx

import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.nn.utils as nn_utils
from torch.distributions import kl, normal
import torch_geometric.utils as utils
from torch_geometric.utils import to_networkx
from torch_geometric.nn import BatchNorm, GCNConv, NNConv, global_mean_pool
from torch_geometric.datasets import TUDataset, Planetoid
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data

from sklearn.linear_model import Ridge
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, mean_absolute_error

from scipy.spatial.distance import jensenshannon
from scipy.sparse.linalg import eigsh
from scipy.sparse import csgraph


# Select the compute device once; the rest of the notebook reads this global.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    # Seed the CUDA generators and ask cuDNN for deterministic kernels.
    torch.cuda.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
print(f'using {device}')

def set_seed(random_seed):
    """Seed Python's `random`, NumPy's global RNG and PyTorch with one value.

    Args:
        random_seed: Seed passed unchanged to `random.seed`,
            `np.random.seed` and `torch.manual_seed`, in that order.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(random_seed)


# Fix the notebook-wide seed before any data is shuffled or any model is built.
set_seed(42)

# Data Preprocessing

In [None]:
# Directory that holds the .npy dataset files. Set the DATASET_PATH environment
# variable to point somewhere else instead of editing a machine-specific path
# here. os.path.join(..., '') guarantees a trailing separator, because the
# loading cell builds file names by plain string concatenation.
dataset_path = os.path.join(os.environ.get('DATASET_PATH', '/notebooks/dataset/'), '')


def data_cleansing(dataset):
    """Return a cleaned copy of `dataset` with negatives and NaNs set to 0.

    The input is left untouched. Previously the negative entries were zeroed
    in place on the caller's array, so the raw data was silently modified.

    Args:
        dataset: Array-like of numeric values.

    Returns:
        A new `np.ndarray` where every negative value and every NaN is 0.
        Infinities are handled by `np.nan_to_num` with its defaults, after
        negatives have been zeroed (so -inf becomes 0).
    """
    # Work on a private copy so re-running the cell sees the same raw input.
    cleaned = np.array(dataset, copy=True)

    # Replace negative values with 0 (NaN compares False and is handled below).
    cleaned[cleaned < 0] = 0

    # Replace NaN values with 0; copy=False because `cleaned` is already ours.
    return np.nan_to_num(cleaned, copy=False, nan=0)

def check_and_drop_invalid_graphs(graph_dataset):
    """Drop every sample that has an edgeless graph at any timepoint.

    A graph counts as edgeless when none of its adjacency entries is
    strictly positive.

    Args:
        graph_dataset: 4-D array indexed as (graph, timepoint, row, column).

    Returns:
        A new array holding only the samples whose graphs all have at least
        one positive entry, in their original order.
    """
    # Unpacking keeps the requirement that the input is exactly 4-D.
    _, _, _, _ = graph_dataset.shape

    # (graph, timepoint) -> does that adjacency matrix contain a positive entry?
    has_edges = (graph_dataset > 0).any(axis=(2, 3))

    # A sample survives only if every one of its timepoints has edges.
    kept_rows = np.flatnonzero(has_edges.all(axis=1))
    return graph_dataset[kept_rows, :, :, :]

def split_data(adj_data, features_data, train_ratio=0.7, val_ratio=0.1, test_ratio=0.2, random_seed=None):
    """Shuffle the samples and split both arrays into train/val/test tensors.

    Args:
        adj_data: Array indexed by sample along axis 0.
        features_data: Array with the same number of samples along axis 0.
        train_ratio: Fraction of samples in the training split.
        val_ratio: Fraction of samples in the validation split.
        test_ratio: Only checked by the ratio-sum assertion; the test split
            receives whatever remains after train and validation.
        random_seed: If given, reseeds NumPy's *global* RNG before shuffling.

    Returns:
        Tuple `(adj_train, adj_val, adj_test, features_train, features_val,
        features_test)` of float32 tensors.

    Raises:
        AssertionError: If the sample counts differ or the ratios do not sum to 1.
    """
    assert adj_data.shape[0] == features_data.shape[0], "Adjacency and features data must have the same number of samples"
    assert np.isclose(train_ratio + val_ratio + test_ratio, 1.0), "The sum of train, val and test ratios must be 1"

    sample_count = adj_data.shape[0]
    order = np.arange(sample_count)
    if random_seed is not None:
        np.random.seed(random_seed)
    np.random.shuffle(order)

    # Cut points into the shuffled order: [0, first_cut) train, [first_cut, second_cut) val, rest test.
    first_cut = int(train_ratio * sample_count)
    second_cut = first_cut + int(val_ratio * sample_count)
    partitions = (order[:first_cut], order[first_cut:second_cut], order[second_cut:])

    def _to_float_tensors(array):
        # One float32 tensor per partition, in train/val/test order.
        return tuple(torch.tensor(array[part], dtype=torch.float32) for part in partitions)

    return _to_float_tensors(adj_data) + _to_float_tensors(features_data)


In [None]:
# Load the four datasets from `dataset_path` and split each into
# train/val/test tensors with split_data (default 70/10/20 ratios, no explicit
# seed, so the split depends on the current state of NumPy's global RNG).
#
# NOTE(review): for OASIS, EMCI-AD and SLIM160 the adjacency array is filtered
# by check_and_drop_invalid_graphs, but the feature array is loaded as-is and
# the printed sample count comes from the features. If any graph is dropped,
# the two arrays no longer line up and split_data's sample-count assertion
# fails — confirm the feature files were generated from the filtered data.

## Simulated Dataset
# Both simulated arrays carry a trailing axis; only index 0 of it is kept.
simulated_adj = np.load(dataset_path + 'simulated_adj.npy')
simulated_adj = simulated_adj[:, :, :, :,0] # keep only index 0 of the last axis (the first domain)
simulated_features = np.load(dataset_path + 'simulated_laplacian_features.npy')
simulated_features = simulated_features[:, :, :, :,0]
simulated_num_samples, simulated_num_time, simulated_num_nodes, simulated_num_features = simulated_features.shape 
print(f'Simulated Dataset: Number of Samples= {simulated_num_samples},  Number of Times= {simulated_num_time}, Number of Nodes= {simulated_num_nodes}, Number of Features= {simulated_num_features}')
simulated_adj_train, simulated_adj_val, simulated_adj_test, simulated_features_train, simulated_features_val, simulated_features_test = split_data(simulated_adj, simulated_features)

## OASIS Dataset
# Clean (negatives/NaNs -> 0), then drop samples with an edgeless timepoint.
oasis = np.load(dataset_path + 'oasis_adj.npy')
oasis_cleaned = data_cleansing(oasis)
oasis_adj = check_and_drop_invalid_graphs(oasis_cleaned)
oasis_features = np.load(dataset_path + 'oasis_laplacian_features.npy') 
oasis_num_samples, oasis_num_time, oasis_num_nodes, oasis_num_features = oasis_features.shape 
print(f'OASIS Dataset: Number of Samples= {oasis_num_samples},  Number of Times= {oasis_num_time}, Number of Nodes= {oasis_num_nodes}, Number of Features= {oasis_num_features}')
oasis_adj_train, oasis_adj_val, oasis_adj_test, oasis_features_train, oasis_features_val, oasis_features_test = split_data(oasis_adj, oasis_features)

## EMCI-AD Dataset
# Same cleaning and filtering pipeline as OASIS.
emci = np.load(dataset_path + 'emci-ad_adj.npy')
emci_cleaned = data_cleansing(emci)
emci_adj = check_and_drop_invalid_graphs(emci_cleaned)
emci_features = np.load(dataset_path + 'emci-ad_laplacian_features.npy') 
emci_num_samples, emci_num_time, emci_num_nodes, emci_num_features = emci_features.shape 
print(f'EMCI-AD Dataset: Number of Samples= {emci_num_samples},  Number of Times= {emci_num_time}, Number of Nodes= {emci_num_nodes}, Number of Features= {emci_num_features}')
emci_adj_train, emci_adj_val, emci_adj_test, emci_features_train, emci_features_val, emci_features_test = split_data(emci_adj, emci_features)

## SLIM160 Dataset
# Same pipeline; the feature file used here is the one with the `_8` suffix.
slim160 = np.load(dataset_path + 'slim160_adj.npy')
slim160_cleaned = data_cleansing(slim160)
slim160_adj = check_and_drop_invalid_graphs(slim160_cleaned)
slim160_features = np.load(dataset_path + 'slim160_laplacian_features_8.npy') 
slim160_num_samples, slim160_num_time, slim160_num_nodes, slim160_num_features = slim160_features.shape 
print(f'SLIM160 Dataset: Number of Samples= {slim160_num_samples},  Number of Times= {slim160_num_time}, Number of Nodes = {slim160_num_nodes}, Number of Features = {slim160_num_features}')
slim160_adj_train, slim160_adj_val, slim160_adj_test, slim160_features_train, slim160_features_val, slim160_features_test = split_data(slim160_adj, slim160_features)

Simulated Dataset: Number of Samples= 100,  Number of Times= 3, Number of Nodes= 35, Number of Features= 8
EMCI-AD Dataset: Number of Samples= 67,  Number of Times= 2, Number of Nodes= 35, Number of Features= 8
SLIM160 Dataset: Number of Samples= 109,  Number of Times= 3, Number of Nodes = 160, Number of Features = 8


# RBGM

In [None]:
def eucledian_distance(x, target_size):
    """Pairwise distance matrix between the rows of `x`.

    Despite the name, the value computed is the sum of absolute coordinate
    differences (an L1 / Manhattan distance), not a Euclidean one.

    Args:
        x: 2-D tensor, one row per item.
        target_size: Number of times `x` is tiled along a new leading axis;
            the subtraction against the transposed tiling must broadcast.

    Returns:
        Tensor whose entry (i, j) is `sum_k |x[j, k] - x[i, k]|`.
    """
    tiled = x.repeat(target_size, 1, 1)
    return (tiled - tiled.transpose(0, 1)).abs().sum(dim=2)

def frobenious_distance(test_sample,predicted):
    """Frobenius distance between two tensors of the same shape.

    Fixes a bug where the squared differences were computed but the plain
    absolute differences were summed, giving sqrt(sum |d|) instead of
    sqrt(sum d^2).

    Args:
        test_sample: Reference tensor.
        predicted: Tensor to compare against `test_sample`.

    Returns:
        0-dim tensor holding `sqrt(sum((test_sample - predicted) ** 2))`.
    """
    diff = test_sample - predicted
    return torch.sqrt((diff * diff).sum())

## Model
class RNNCell(nn.Module):
    """Single recurrent cell that keeps its hidden state between calls.

    Computes `y = tanh(W x + W_h h)`, returns `y`, and stores a detached
    copy of `y` as the hidden state for the next call.

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dim: Size of the hidden state and of the output.
        hidden_state: Initial hidden state tensor; must broadcast against
            the output of `self.weight(x)`.
    """

    def __init__(self, input_dim, hidden_dim, hidden_state):
        super(RNNCell, self).__init__()
        self.weight = nn.Linear(input_dim, hidden_dim, bias=True)
        self.weight_h = nn.Linear(hidden_dim, hidden_dim, bias=True)
        # Not used in forward(); kept so parameter counts and state_dict keys
        # stay compatible with previously saved checkpoints.
        self.out = nn.Linear(hidden_dim, hidden_dim, bias=True)
        # `Tanh` was referenced as a bare, undefined name; use the nn namespace.
        self.tanh = nn.Tanh()
        # Plain attribute (not a registered buffer): it is not part of the
        # state_dict and does not follow the module through `.to(...)`.
        self.hidden_state = hidden_state
    
    def forward(self,x):
        """Advance the cell one step and return the new hidden state."""
        # Follow the layer's own device instead of relying on a notebook
        # global; this is a no-op when everything already lives together.
        param_device = self.weight.weight.device
        h = self.hidden_state.to(param_device)
        y = self.tanh(self.weight(x.to(param_device)) + self.weight_h(h))
        # Detach so gradients do not flow back through earlier calls.
        self.hidden_state = y.detach()
        return y

class RBGM(nn.Module):
    """Graph model that uses a recurrent cell as the edge network of an NNConv.

    The forward pass applies one NNConv layer to the input matrix and returns
    the pairwise `eucledian_distance` matrix of the resulting node embeddings.

    Args:
        conv_size: Input and output channel count of the NNConv layer; also
            passed to `eucledian_distance` as the tiling size.
        hidden_state_size: Hidden size of the RNNCell; the initial hidden
            state is a random square matrix of this size.
            # presumably conv_size ** 2 so the edge network emits one
            # conv_size x conv_size weight matrix per edge — TODO confirm
    """

    def __init__(self, conv_size, hidden_state_size):
        super(RBGM, self).__init__()
        shape = torch.Size((hidden_state_size, hidden_state_size))
        hidden_state = torch.rand(shape, device=device)
        self.conv_size = conv_size
        # `ReLU` was referenced as a bare, undefined name; use the nn namespace.
        self.rnn = nn.Sequential(RNNCell(1,hidden_state_size, hidden_state), nn.ReLU())
        self.gnn_conv = NNConv(self.conv_size, self.conv_size, self.rnn, aggr='mean', root_weight=True, bias = True)
        
    def forward(self, data):
        """Predict the next-timepoint matrix from `data`.

        Args:
            data: Square matrix used both as node features and as the source
                of edges via `create_edge_index_attribute`.

        Returns:
            Pairwise distance matrix of the convolved node embeddings.
        """
        edge_index, edge_attr, _, _ = create_edge_index_attribute(data)
        # Move the node features to the device too, not only the edge tensors,
        # so a CPU input does not clash with the layer on CUDA.
        x1 = F.relu(self.gnn_conv(data.to(device), edge_index.to(device), edge_attr.to(device)))
        x1 = eucledian_distance(x1, self.conv_size)
        return x1

    def count_parameters(self):
        """Return the number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)



def train_rbgm(model_1, model_2, train_adj, num_epochs=200, lr=0.0001, save_path='models/RBGM/rgbm_model', tp_c = 10):
    """Train two models, one per timepoint transition (t0->t1 and t1->t2).

    Each sample is one optimisation step per model. The loss is
    `L1(out, target) + tp_c * MSE(out.sum(-1), target.sum(-1))`.

    Fixes relative to the previous version:
      * the topological loss is averaged and reported under one name (it was
        divided under `epoch_top_loss` but printed un-normalised as
        `epoch_tp_loss`, which was also undefined for an empty loop);
      * the training tensor is moved to `device` once, so targets and model
        outputs live on the same device;
      * the directory part of `save_path` is created before saving;
      * the "Memorzy" typo in the progress line is corrected.

    Args:
        model_1: Model trained to map timepoint 0 to timepoint 1.
        model_2: Model trained to map timepoint 1 to timepoint 2.
        train_adj: Tensor indexed as [sample, timepoint, ...]; timepoints
            0, 1 and 2 are read.
        num_epochs: Number of passes over `train_adj`.
        lr: Adam learning rate for both optimizers.
        save_path: Prefix for the two saved state_dict files.
        tp_c: Weight of the topological (row-sum) loss term.

    Returns:
        None. Prints progress, plots the loss curve and writes
        `<save_path>_model_1.pth` and `<save_path>_model_2.pth`.
    """
    # NOTE(review): `tqdm` and `set_seed` must already be in scope; `tqdm` is
    # not imported in the visible import cell.

    # Loss Function
    mael = torch.nn.L1Loss().to(device)
    tp = torch.nn.MSELoss().to(device)

    optimizer_1 = torch.optim.Adam(model_1.parameters(), lr = lr)
    optimizer_2 = torch.optim.Adam(model_2.parameters(), lr = lr)

    def _train_step(model, optimizer, source, target):
        """Run one optimisation step; return (total, generative, topological) losses."""
        optimizer.zero_grad()
        out = model(source)
        tp_loss = tp(out.sum(dim=-1), target.sum(dim=-1))
        gen_loss = mael(out, target)
        loss = gen_loss + (tp_c * tp_loss)
        loss.backward()
        optimizer.step()
        return loss.item(), gen_loss.item(), tp_loss.item()

    # Move the data once so model outputs and targets share a device.
    data = train_adj.to(device)
    num_samples = data.size(0)

    training_loss = []
    epoch_time = []
    cpu_usage = []
    memory_usage = []
    gpu_usage = []
    
    for epoch in range(num_epochs):
        total_loss_1, total_loss_2 = 0.0, 0.0 
        tp_loss_1, tp_loss_2, gen_loss_1, gen_loss_2 = 0.0, 0.0, 0.0, 0.0

        # Reseeding every epoch is kept from the original.
        # NOTE(review): this makes every epoch draw the same random stream — confirm intended.
        set_seed(42)
        model_1.train()
        model_2.train()

        epoch_start_time = time.time()
        
        for i in tqdm(range(num_samples), desc=f'Epoch {epoch+1}/{num_epochs}'):
            data_t0 = data[i, 0]
            data_t1 = data[i, 1]
            data_t2 = data[i, 2]
            
            #Time Point 1
            loss_1, genl_1, tpl_1 = _train_step(model_1, optimizer_1, data_t0, data_t1)
            total_loss_1 += loss_1
            gen_loss_1 += genl_1
            tp_loss_1 += tpl_1
            
            #Time Point 2
            loss_2, genl_2, tpl_2 = _train_step(model_2, optimizer_2, data_t1, data_t2)
            total_loss_2 += loss_2
            gen_loss_2 += genl_2
            tp_loss_2 += tpl_2
        
        epoch_end_time = time.time()
        epoch_time.append(epoch_end_time - epoch_start_time)
        # NOTE(review): this is CPU utilisation (fraction) multiplied by total
        # RAM in GB, not a real memory figure — confirm this is the intended metric.
        cpu_usage.append(psutil.cpu_percent(interval=None) / 100 * psutil.virtual_memory().total / (1024**3))  # CPU usage in GB
        memory_usage.append(psutil.virtual_memory().used / (1024**3))  # Memory usage in GB

        if device.type == 'cuda':
            gpu_usage.append(torch.cuda.memory_allocated(device) / (1024**3))  # GPU usage in GB
        else:
            gpu_usage.append(0)

        # Per-sample averages over the epoch (all three now normalised the same way).
        epoch_loss = (total_loss_1 + total_loss_2) / num_samples
        epoch_gen_loss = (gen_loss_1 + gen_loss_2) / num_samples
        epoch_tp_loss = (tp_loss_1 + tp_loss_2) / num_samples
        training_loss.append(epoch_loss)
        
        print(f'Epoch {epoch + 1}, Loss: {epoch_loss}, Generative Loss: {epoch_gen_loss}, Topological Loss: {epoch_tp_loss}')
        print(f'Time: {epoch_time[-1]:.2f}s, CPU: {cpu_usage[-1]:.2f}GB, Memory: {memory_usage[-1]:.2f}GB, GPU: {gpu_usage[-1]:.2f}GB\n')
        

    # Plot the training loss
    plt.plot(training_loss)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss')
    plt.show()

    # Save the trained model (create the target directory first; torch.save does not).
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    save_path_1 = save_path + '_model_1.pth'
    save_path_2 = save_path + '_model_2.pth'
    torch.save(model_1.state_dict(), save_path_1)
    torch.save(model_2.state_dict(), save_path_2)
    print(f'Model saved to {save_path}')

    print(f'Average Time per Epoch: {np.mean(epoch_time):.2f}s')
    print(f'Average CPU Usage: {np.mean(cpu_usage):.2f}GB')
    print(f'Average Memory Usage: {np.mean(memory_usage):.2f}GB')
    print(f'Average GPU Usage: {np.mean(gpu_usage):.2f}GB')

    print(f'\nTotal Training Time: {np.sum(epoch_time):.2f}s')
    print(f'Max CPU Usage: {np.max(cpu_usage):.2f}GB')
    print(f'Max Memory Usage: {np.max(memory_usage):.2f}GB')
    print(f'Max GPU Usage: {np.max(gpu_usage):.2f}GB')

# EvoGraph

In [None]:
class EvoGenerator(nn.Module):
    """Generator built from three NNConv + BatchNorm stages.

    The output is symmetrised and has a zero diagonal.

    Args:
        conv_size: Channel count of the first and last NNConv layers. The
            forward pass adds `x1` to its own transpose, so the node count
            must equal `conv_size`.
        hidden_size: Output width of the first edge network.
            # presumably conv_size ** 2, the size NNConv needs for a
            # conv_size -> conv_size layer — TODO confirm
    """

    def __init__(self, conv_size, hidden_size):
        super(EvoGenerator, self).__init__()
        self.conv_size = conv_size
        self.hidden_size = hidden_size

        # `Sequential`, `Linear` and `ReLU` were bare, undefined names; they
        # are now taken from the already-imported `nn` namespace.
        lin = nn.Sequential(nn.Linear(1, self.hidden_size), nn.ReLU())
        self.conv1 = NNConv(self.conv_size, self.conv_size, lin, aggr='mean', root_weight=True, bias=True)
        self.conv11 = BatchNorm(self.conv_size, eps=1e-03, momentum=0.1, affine=True, track_running_stats=True)

        lin = nn.Sequential(nn.Linear(1, self.conv_size), nn.ReLU())
        self.conv2 = NNConv(self.conv_size, 1, lin, aggr='mean', root_weight=True, bias=True)
        self.conv22 = BatchNorm(1, eps=1e-03, momentum=0.1, affine=True, track_running_stats=True)

        lin = nn.Sequential(nn.Linear(1, self.conv_size), nn.ReLU())
        self.conv3 = NNConv(1, self.conv_size, lin, aggr='mean', root_weight=True, bias=True)
        self.conv33 = BatchNorm(self.conv_size, eps=1e-03, momentum=0.1, affine=True, track_running_stats=True)

    def forward(self, data):
        """Generate a matrix from a graph.

        Args:
            data: Object exposing `x`, `edge_index` and `edge_attr`.

        Returns:
            Symmetric square tensor with a zero diagonal.
        """
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        x1 = torch.sigmoid(self.conv11(self.conv1(x, edge_index, edge_attr)))
        x1 = F.dropout(x1, training=self.training)
        
        # Symmetrise and clear self-connections before the next stage.
        x1 = (x1 + x1.T) / 2.0
        x1.fill_diagonal_(fill_value=0)
        x2 = torch.sigmoid(self.conv22(self.conv2(x1, edge_index, edge_attr)))
        x2 = F.dropout(x2, training=self.training)

        # Skip connection: concatenate the third stage with x1, then average
        # the two conv_size-wide halves.
        x3 = torch.cat([torch.sigmoid(self.conv33(self.conv3(x2, edge_index, edge_attr))), x1], dim=1)
        x4 = x3[:, 0:self.conv_size]
        x5 = x3[:, self.conv_size:self.conv_size*2]

        x6 = (x4 + x5) / 2
        x6 = (x6 + x6.T) / 2.0
        x6.fill_diagonal_(fill_value=0)
        return x6
    
    def count_parameters(self):
        """Return the number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

class EvoDiscriminator(nn.Module):
    """Discriminator made of two NNConv + BatchNorm stages.

    Each edge is described by two values: the edge attribute of the graph
    being judged and the edge attribute of the source graph.

    Args:
        conv_size: Channel count of the first NNConv layer.
        hidden_size: Output width of the first edge network. It is also the
            row count the source edge attributes are reshaped to, so it must
            equal their number of elements.
    """

    def __init__(self, conv_size, hidden_size):
        super(EvoDiscriminator, self).__init__()
        self.conv_size = conv_size
        self.hidden_size = hidden_size
        
        # `Sequential`, `Linear` and `ReLU` were bare, undefined names; they
        # are now taken from the already-imported `nn` namespace.
        lin = nn.Sequential(nn.Linear(2, self.hidden_size), nn.ReLU())
        self.conv1 = NNConv(self.conv_size, self.conv_size, lin, aggr='mean', root_weight=True, bias=True)
        self.conv11 = BatchNorm(self.conv_size, eps=1e-03, momentum=0.1, affine=True, track_running_stats=True)

        lin = nn.Sequential(nn.Linear(2, self.conv_size), nn.ReLU())
        self.conv2 = NNConv(self.conv_size, 1, lin, aggr='mean', root_weight=True, bias=True)
        self.conv22 = BatchNorm(1, eps=1e-03, momentum=0.1, affine=True, track_running_stats=True)

    def forward(self, data, data_to_translate):
        """Score each node of `data` given the source graph.

        Args:
            data: Graph being judged; exposes `x`, `edge_index`, `edge_attr`.
            data_to_translate: Source graph; only its `edge_attr` is read.

        Returns:
            Tensor of sigmoid scores with one column.
        """
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        edge_attr_data_to_translate = data_to_translate.edge_attr

        # Column vector with hidden_size rows, so it can sit beside edge_attr.
        edge_attr_data_to_translate_reshaped = edge_attr_data_to_translate.view(self.hidden_size, 1)

        # Two-channel edge features: (judged graph, source graph).
        gen_input = torch.cat((edge_attr, edge_attr_data_to_translate_reshaped), -1)
        x = F.relu(self.conv11(self.conv1(x, edge_index, gen_input)))
        x = F.dropout(x, training=self.training)
        x = F.relu(self.conv22(self.conv2(x, edge_index, gen_input)))

        return torch.sigmoid(x)
    
    def count_parameters(self):
        """Return the number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


def train_evograph(generator_1, discriminator_1, generator_2, discriminator_2, train_adj, 
                    num_epochs=500, lr_g=0.01, lr_d=0.0002, save_path='models/EvoGraph/evograph_model', 
                    tp_c=0.00, g_c=2.0, i_c=2.0,kl_c=0.001):
    """Adversarially train two generator/discriminator pairs (t0->t1, t1->t2).

    For each sample, each pair gets one discriminator step followed by one
    generator step. The generator loss is
    `i_c * identity + g_c * adversarial + kl_c * KL + tp_c * topology`.

    Changes relative to the previous version:
      * `tp_c`, `g_c`, `i_c` and `kl_c` are now honoured. They used to be
        shadowed by hard-coded constants equal to the defaults, so calls
        that rely on the defaults behave as before.
      * The duplicated per-pair code lives in one local helper.
      * Stage 1 no longer runs an extra, unused generator forward pass
        (`swapped_data`); this matches stage 2 but changes the random stream
        consumed by dropout, so runs are not bit-identical to older ones.
      * `train_adj` is moved to `device` once instead of on every iteration.
      * Reporting-only MSE/MAE passes run under `torch.no_grad()`.
      * The directory part of `save_path` is created before saving.
      * Accumulators that were never read have been removed.

    Args:
        generator_1, discriminator_1: Pair trained on timepoint 0 -> 1.
        generator_2, discriminator_2: Pair trained on timepoint 1 -> 2.
        train_adj: Tensor indexed as [sample, timepoint, node, ...];
            timepoints 0, 1 and 2 are read and axis 2 gives the node count.
        num_epochs: Number of passes over `train_adj`.
        lr_g: AdamW learning rate for both generators.
        lr_d: AdamW learning rate for both discriminators.
        save_path: Prefix for the four saved state_dict files.
        tp_c: Weight of the topology (row-sum MSE) loss.
        g_c: Weight of the adversarial loss.
        i_c: Weight of the identity (L1) loss.
        kl_c: Weight of the KL-divergence loss.

    Returns:
        None. Prints progress, plots the loss curves and writes four `.pth` files.
    """
    # NOTE(review): `tqdm` and `create_edge_index_attribute` must already be
    # in scope; neither is defined in the visible cells.

    adversarial_loss= torch.nn.BCELoss().to(device)
    identity_loss   = torch.nn.L1Loss().to(device)  # Will be used in training
    msel            = torch.nn.MSELoss().to(device)
    mael            = torch.nn.L1Loss().to(device)  # Not to be used in training (Measure generator success)
    tp              = torch.nn.MSELoss().to(device) # Used for node strength

    # Loop-invariant: move the whole training tensor to the device once.
    data = train_adj.to(device)
    total_step = data.shape[0]
    num_nodes = data.shape[2]

    optimizer_G1 = torch.optim.AdamW(generator_1.parameters(), lr=lr_g, betas=(0.5, 0.999), weight_decay=0.0)
    optimizer_D1 = torch.optim.AdamW(discriminator_1.parameters(), lr=lr_d, betas=(0.5, 0.999), weight_decay=0.0)
    optimizer_G2 = torch.optim.AdamW(generator_2.parameters(), lr=lr_g, betas=(0.5, 0.999), weight_decay=0.0)
    optimizer_D2 = torch.optim.AdamW(discriminator_2.parameters(), lr=lr_d, betas=(0.5, 0.999), weight_decay=0.0)

    real_label = torch.ones(num_nodes, 1).to(device)
    fake_label = torch.zeros(num_nodes, 1).to(device)

    def _as_graph(adjacency):
        """Wrap a matrix in a `Data` object with edges derived from it."""
        edge_index, edge_attr, _, _ = create_edge_index_attribute(adjacency)
        return Data(x=adjacency, edge_attr=edge_attr, edge_index=edge_index).to(device)

    def _train_stage(generator, discriminator, optimizer_G, optimizer_D, source, target, metric_source):
        """One discriminator step and one generator step for a single pair.

        Returns (loss_D, loss_G, real_loss, fake_loss, mse, mae) as floats.
        """
        # ---- Train the discriminator ----
        optimizer_D.zero_grad()
        fake_data = _as_graph(generator(source).detach())

        # data      : Real source and real target
        # fake_data : Real source and generated target
        real_loss = adversarial_loss(discriminator(target, source), real_label)
        fake_loss = adversarial_loss(discriminator(fake_data, source), fake_label)
        loss_D = torch.mean(real_loss + fake_loss) / 2
        loss_D.backward(retain_graph=True)
        optimizer_D.step()

        # ---- Train the generator ----
        optimizer_G.zero_grad()

        # Adversarial loss: only the node features are regenerated with
        # gradients; edge_index / edge_attr still come from the detached sample.
        fake_data.x = generator(source)
        gan_loss = torch.mean(adversarial_loss(discriminator(fake_data, source), real_label))

        # KL loss between per-row normal fits of the generated and real target.
        kl_loss = kl.kl_divergence(normal.Normal(fake_data.x.mean(dim=1), fake_data.x.std(dim=1)),
                                   normal.Normal(target.x.mean(dim=1), target.x.std(dim=1))).sum()

        # Topology loss on row sums.
        # NOTE(review): this compares against the *source* row sums, whereas
        # train_rbgm compares against the target — confirm which is intended.
        tp_loss = tp(fake_data.x.sum(dim=-1), source.x.sum(dim=-1))

        # Identity loss: feeding the target through the generator should return it.
        loss_G = (i_c * identity_loss(generator(target), target.x)) + (g_c * gan_loss) + (kl_c * kl_loss) + (tp_c * tp_loss)
        loss_G.backward(retain_graph=True)
        optimizer_G.step()

        # Reporting only; no gradients are needed for these forward passes.
        with torch.no_grad():
            mse = msel(generator(metric_source), target.x).item()
            mae = mael(generator(metric_source), target.x).item()

        return (loss_D.item(), loss_G.item(), real_loss.item(), fake_loss.item(), mse, mae)

    epoch_time = []
    cpu_usage = []
    memory_usage = []
    gpu_usage = []

    d_loss = []
    g_loss = []

    for epoch in range(num_epochs):
        # Running sums of (D loss, G loss, real loss, fake loss, MSE, MAE) per stage.
        sums_1 = [0.0] * 6
        sums_2 = [0.0] * 6

        # Train
        generator_1.train()
        discriminator_1.train()
        generator_2.train()
        discriminator_2.train()

        epoch_start_time = time.time()

        for i in tqdm(range(total_step), desc=f'Epoch {epoch+1}/{num_epochs}'):
            data_t0 = _as_graph(data[i, 0])
            data_t1 = _as_graph(data[i, 1])
            data_t2 = _as_graph(data[i, 2])

            # 1st part: t0 -> t1
            stage_1 = _train_stage(generator_1, discriminator_1, optimizer_G1, optimizer_D1,
                                   data_t0, data_t1, data_t0)
            # 2nd part: t1 -> t2
            # NOTE(review): the reported MSE/MAE for stage 2 feed data_t0 to
            # generator_2 (kept from the original), although it is trained on
            # data_t1 — confirm this is the intended metric.
            stage_2 = _train_stage(generator_2, discriminator_2, optimizer_G2, optimizer_D2,
                                   data_t1, data_t2, data_t0)

            sums_1 = [total + value for total, value in zip(sums_1, stage_1)]
            sums_2 = [total + value for total, value in zip(sums_2, stage_2)]

        epoch_end_time = time.time()
        epoch_time.append(epoch_end_time - epoch_start_time)
        # NOTE(review): this is CPU utilisation (fraction) multiplied by total
        # RAM in GB, not a real memory figure — confirm this is the intended metric.
        cpu_usage.append(psutil.cpu_percent(interval=None) / 100 * psutil.virtual_memory().total / (1024**3))  # CPU usage in GB
        memory_usage.append(psutil.virtual_memory().used / (1024**3))  # Memory usage in GB

        if device.type == 'cuda':
            gpu_usage.append(torch.cuda.memory_allocated(device) / (1024**3))  # GPU usage in GB
        else:
            gpu_usage.append(0)

        # Per-sample averages for the epoch.
        d1, g1, r1, f1, mse_l1, mae_l1 = (total / total_step for total in sums_1)
        d2, g2, r2, f2, mse_l2, mae_l2 = (total / total_step for total in sums_2)

        d_loss.append(d1 + d2) 
        g_loss.append(g1 + g2)

        print(f'Epoch [{epoch + 1}/{num_epochs}]')
        print(f'D1 Loss: {d1:.5f}, G1 Loss: {g1:.5f}, R1 Loss: {r1:.5f}, F1 Loss: {f1:.5f}, MSE: {mse_l1:.5f}, MAE: {mae_l1:.5f}')
        print(f'D2 Loss: {d2:.5f}, G2 Loss: {g2:.5f}, R2 Loss: {r2:.5f}, F2 Loss: {f2:.5f}, MSE: {mse_l2:.5f}, MAE: {mae_l2:.5f}')
        print(f'Time: {epoch_time[-1]:.2f}s, CPU: {cpu_usage[-1]:.2f}GB, Memory: {memory_usage[-1]:.2f}GB, GPU: {gpu_usage[-1]:.2f}GB\n')
        
    # Plot the training losses
    epochs = range(1, num_epochs + 1)
    plt.figure(figsize=(10, 5))
    plt.plot(epochs, d_loss, label='Discriminator Loss (D)', marker='o', linestyle='-', color='b')
    plt.plot(epochs, g_loss, label='Generator Loss (G)', marker='x', linestyle='-', color='r')
    plt.xlabel('Epoch')
    plt.ylabel('Loss Value')
    plt.title('Line Chart of Discriminator Loss (D) and Generator Loss (G) per Epoch')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()

    # Save the trained models (create the target directory first; torch.save does not).
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    save_path_generator1 = save_path + '_model_generator1.pth'
    torch.save(generator_1.state_dict(), save_path_generator1)
    save_path_discriminator1 = save_path + '_model_discriminator1.pth'
    torch.save(discriminator_1.state_dict(), save_path_discriminator1)

    save_path_generator2 = save_path + '_model_generator2.pth'
    torch.save(generator_2.state_dict(), save_path_generator2)
    save_path_discriminator2 = save_path + '_model_discriminator2.pth'
    torch.save(discriminator_2.state_dict(), save_path_discriminator2)
    
    
    print(f'Model saved to {save_path}')

    print(f'Average Time per Epoch: {np.mean(epoch_time):.2f}s')
    print(f'Average CPU Usage: {np.mean(cpu_usage):.2f}GB')
    print(f'Average Memory Usage: {np.mean(memory_usage):.2f}GB')
    print(f'Average GPU Usage: {np.mean(gpu_usage):.2f}GB')

    print(f'\nTotal Training Time: {np.sum(epoch_time):.2f}s')
    print(f'Max CPU Usage: {np.max(cpu_usage):.2f}GB')
    print(f'Max Memory Usage: {np.max(memory_usage):.2f}GB')
    print(f'Max GPU Usage: {np.max(gpu_usage):.2f}GB')  
