In [None]:
#!nvidia-smi


In [None]:
#import torch
#print(torch.__version__)
#print(torch.cuda.is_available())
#print(torch.version.cuda)


In [None]:
import gc
import datetime

import os
import pickle
import random
import csv 
import pandas as pd
import numpy as np
import nibabel as nib

from tqdm import tqdm
from scipy import ndimage
from glob import glob
import matplotlib.pyplot as plt

import torch
import torchio as tio
from torchio import SubjectsDataset, SubjectsLoader

import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

from datetime import datetime
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary

import mlflow
import mlflow.pytorch

import argparse

# Choose the compute device once: 'cuda' when PyTorch can see a GPU, otherwise 'cpu'.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'


In [None]:
import warnings

warnings.filterwarnings("ignore", message="Using TorchIO images without a torchio.SubjectsLoader")

pd.set_option('future.no_silent_downcasting', True)

In [None]:
seed = 42  # Fixed for reproducibility
print("Seed value:", seed)

# Seed every random number generator this notebook imports. Seeding torch alone
# leaves Python's `random` and NumPy's global generator unseeded, so any code
# drawing from them would still differ from run to run.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)  # For multi-GPU

In [None]:
# Index of the fold this run works on: it names the training-log CSV, is passed
# to the criterion by the batch functions, and is the only fold visited by the
# training loop (range(global_fold, global_fold+1)).
global_fold = 0

In [None]:
# Experiment configuration. `losses_used` is a sub-directory (with leading slash)
# appended to `task_path`; every log, plot and checkpoint is written beneath it.
losses_used = "/exp_OURS"
task_path = "PATH_TO_TASK_DIR/Task_FreeSurfer_1-5T_longit"
task_name_exp = "Longit-OURS"
print("-------", task_name_exp, "-------")
# Number of time-point slots iterated per subject by the batch functions.
max_tp = 13

# Create the experiment directory and its sub-directories. Each one is created
# independently with exist_ok=True, so a run that finds the top-level directory
# already present still gets any sub-directory that is missing.
_exp_dir = task_path + losses_used
_exp_dir_existed = os.path.exists(_exp_dir)
for _sub_dir in ("", "/models", "/vol_loss_plots"):
    os.makedirs(_exp_dir + _sub_dir, exist_ok=True)

if _exp_dir_existed:
    print(f"Directory already exists: {_exp_dir}")
else:
    print(f"Directory created: {_exp_dir}")


In [None]:
# Per-fold log of the epochs at which a checkpoint was written.
csv_filename = f"{task_path}{losses_used}/fold_{global_fold}_training_log.csv"

# Mode "w" truncates any log left by a previous run, so the file always starts
# with just the header row.
with open(csv_filename, mode="w", newline="") as file:
    writer = csv.writer(file)
    header_row = ["Epoch"]
    writer.writerow(header_row)


def log_epoch(epoch):
    """Append ``epoch`` as a new one-column row to the CSV at ``csv_filename``."""
    with open(csv_filename, mode="a", newline="") as log_file:
        csv.writer(log_file).writerow([epoch])

In [None]:
# Wall-clock start of the run; `now` is reused further down as part of the
# MLflow run name.
now = datetime.now()
print(f"Started running the code at: {now}")
print("############################################################################\n\n")

In [None]:
from utils.dataset import *
from utils.model import *
from utils.criterion import *
from utils.helpers import *

In [None]:
# Header of the per-iteration loss CSVs: one column per value logged for a batch.
temp_fields = [
    'Dice Loss',
    'BCE Loss',
    'hlfc Loss',
    'loss1',
    'Smoothness Loss OR loss2',
    'Age Contraint Loss OR loss3',
    'Mean Avg Volume loss',
    'SDF loss',
    'Total Loss',
]
# Header of the per-epoch loss CSVs: the same columns followed by the epoch
# index and the learning rate.
fields = temp_fields + ['Epoch', 'Learning Rate']

In [None]:
# Print a banner naming the loss configuration, so the run's output records it.
print("LOSS USED --- SDF-VL-AC-SC on CROP-REG DS Experiment") 

In [None]:
def plot_and_save_losses(train_losses, val_losses, fold):
    """Plot training vs. validation loss per epoch, save the figure and show it.

    Args:
        train_losses: Sequence of training losses, one entry per epoch.
        val_losses: Sequence of validation losses plotted on the same epoch axis.
        fold: Fold identifier, embedded in the output file name.

    The PNG is written to
    ``task_path + losses_used + '/train_val_losses_plot_<fold>.png'``.
    """
    epoch_axis = range(1, len(train_losses) + 1)

    fig, ax = plt.subplots(figsize=(5, 5))
    ax.plot(epoch_axis, train_losses, 'b', label='Training Loss')
    ax.plot(epoch_axis, val_losses, 'r', label='Validation Loss')

    ax.set_title('Training and Validation Loss')
    ax.set_xlabel('Epochs')
    ax.set_ylabel('Loss')
    ax.legend()
    ax.grid(True)

    # Persist the figure before showing it.
    out_path = task_path + losses_used + '/train_val_losses_plot_' + str(fold) + '.png'
    fig.savefig(out_path)

    plt.show()
    plt.close(fig)  # release the figure's memory

def plot_all(loss_file_loc, fold):
    """Save one train-vs-validation PNG per column of the per-epoch loss CSVs.

    Args:
        loss_file_loc: Path of the training loss CSV. The validation CSV path is
            derived from it by replacing ``"all_losses_train_"`` with
            ``"all_losses_val_"``.
        fold: Fold identifier, embedded in every output file name.

    Each plot is written to
    ``task_path + losses_used + '/plot_losses_fold_<fold>_<column>.png'``.
    """
    train_all_losses = pd.read_csv(loss_file_loc)
    val_all_losses = pd.read_csv(loss_file_loc.replace("all_losses_train_", "all_losses_val_"))

    epochs = range(1, len(train_all_losses) + 1)

    for cur_col in train_all_losses:
        # A fresh 5x5 figure per column. Creating a single figure before the
        # loop and closing it inside meant that only the first column used the
        # requested size; later ones fell back to an implicit default figure.
        fig, ax = plt.subplots(figsize=(5, 5))
        try:
            ax.plot(epochs, train_all_losses[cur_col], 'b', label='Training Loss')
            ax.plot(epochs, val_all_losses[cur_col], 'r', label='Validation Loss')

            ax.set_title('Training and Validation Loss')
            ax.set_xlabel('Epochs')
            ax.set_ylabel(cur_col)
            ax.legend()
            ax.grid(True)

            fig.savefig(task_path + losses_used + '/plot_losses_fold_' + str(fold) + '_' + cur_col + '.png')
        finally:
            # Always release the figure, even if plotting or saving failed.
            plt.close(fig)

def plot_all_iter(loss_file_loc, fold, type_loss):
    """Save one per-iteration loss PNG per column of an iteration-level loss CSV.

    Args:
        loss_file_loc: Path of the CSV holding one row per iteration.
        fold: Fold identifier, embedded in every output file name.
        type_loss: ``"train"`` (blue 'Training Loss' curve) or ``"val"`` (red
            'Validation Loss' curve). Any other value draws no curve, but the
            (empty) figure is still saved.

    Each plot is written to
    ``task_path + losses_used + '/plot_losses_iter_<type_loss>_fold_<fold>_<column>.png'``.
    """
    iter_losses = pd.read_csv(loss_file_loc)
    iterations = range(1, len(iter_losses) + 1)

    # Colour and legend label for each supported split.
    line_styles = {
        "train": ('b', 'Training Loss'),
        "val": ('r', 'Validation Loss'),
    }

    for cur_col in iter_losses:
        # A fresh 5x5 figure per column. Creating a single figure before the
        # loop and closing it inside meant that only the first column used the
        # requested size; later ones fell back to an implicit default figure.
        fig, ax = plt.subplots(figsize=(5, 5))
        try:
            if type_loss in line_styles:
                colour, label = line_styles[type_loss]
                ax.plot(iterations, iter_losses[cur_col], colour, label=label)

            ax.set_title('Loss')
            ax.set_xlabel('Iterations')
            ax.set_ylabel(cur_col)
            ax.legend()
            ax.grid(True)

            fig.savefig(task_path + losses_used + '/plot_losses_iter_' + type_loss + '_fold_' + str(fold) + '_' + cur_col + '.png')
        finally:
            # Always release the figure, even if plotting or saving failed.
            plt.close(fig)



In [None]:
def get_average_from_csv(file_path):
    """Return the column-wise mean of a numeric CSV file whose first row is a header.

    Args:
        file_path: Path of the CSV file. The first non-blank row is treated as
            the header and skipped; every remaining row must be numeric.

    Returns:
        NumPy array with one mean per column. A file that holds only the header
        yields NaNs (NumPy emits a "mean of empty slice" warning).

    Blank rows are ignored. They show up when a CSV was written through a file
    handle opened without ``newline=''`` on platforms that translate line
    endings, and would otherwise make the rows ragged and break the float
    conversion.
    """
    # newline='' lets the csv module do its own line-ending handling.
    with open(file_path, mode='r', newline='') as file:
        rows = [row for row in csv.reader(file) if row]

    data = np.array(rows)
    # Drop the header row, then convert the remaining text cells to float.
    data = data[1:].astype(float)
    return np.mean(data, axis=0)

def get_average_from_csv_without_header(file_path):
    """Return the column-wise mean of a CSV, leaving its header row out of the data.

    The first line of the file is consumed as the column header (``header=0``);
    all remaining values are converted to float64 and averaged per column.
    """
    values = pd.read_csv(file_path, header=0).to_numpy(dtype=np.float64)
    return values.mean(axis=0)

In [None]:

def clear_gpu_memory():
    """Call ``torch.cuda.empty_cache()`` so PyTorch releases its unused cached GPU memory."""
    torch.cuda.empty_cache()


def train_batch(model, data, optimizer, criterion, loss_file_loc):
    """Run the forward pass for one training batch and return its combined loss.

    No backward pass and no optimizer step happen here: the caller is expected
    to call ``backward()`` on the returned loss and to step the optimizer.

    Args:
        model: Network called as ``model(x)``; must return ``(prediction, features)``.
        data: Batch tuple ``(images, targets, total_tp, all_IDs, feat_target)``.
            ``images`` is unpacked as ``[B, T, 1, D, H, W]`` and ``total_tp[b]``
            is the number of time points used for sample ``b``.
        optimizer: Unused; kept in the signature for existing callers.
        criterion: Callable selected by ``loss_name``; returns a 4-tuple whose
            first item is the loss tensor and whose last item is logged.
        loss_file_loc: CSV file to which one row of loss values is appended.

    Returns:
        ``(loss, dice1, dice2)``: the combined loss as a tensor (still attached
        to the autograd graph) and the second and third items returned by the
        ``"loss1"`` criterion call, each converted with ``.item()``.
    """
    cur_fold = global_fold
    print("################################## Train Loop: is fold", cur_fold)

    # No ages are unpacked from the batch; the criterion receives an empty list.
    ages = []

    model.train()

    images, targets, total_tp, all_IDs, feat_target = data
    idx = max(total_tp)  # largest number of time points of any sample in the batch

    # Assumes: images shape is [B, T, 1, D, H, W]
    B, T, _, D, H, W = images.shape
    all_preds = torch.zeros((B, idx, 3, D, H, W), device=images.device)
    # NOTE(review): the feature-map size (256, 12, 12, 12) is hard-coded;
    # presumably it matches the model's second output - confirm against the model.
    all_feat_pred = torch.zeros((B, idx, 256, 12, 12, 12), device=images.device)

    # Per-sample, per-time-point forward pass with gradients. Slots beyond
    # total_tp[b] stay zero.
    for bbb in range(B):
        for ttt in range(total_tp[bbb]):
            pred, cur_feat = model(images[bbb, ttt].unsqueeze(0))
            all_preds[bbb, ttt] = pred
            all_feat_pred[bbb, ttt] = cur_feat

    loss_1, dice1, dice2, all_of_them = criterion(all_preds, all_feat_pred, targets[:, :idx, :, :, :, :], feat_target[:, :idx, :, :, :, :], None, total_tp, None, None, None, loss_name="loss1")

    # Voxel count per class (0, 1, 2) for every sample and time point.
    # These counts go through argmax and NumPy, so no gradient can flow through
    # them: run the passes under no_grad to avoid building an autograd graph
    # that is thrown away immediately.
    # The shape is read from the tensor directly (no host copy of the batch),
    # and the loop is capped at T so a batch with fewer than max_tp time
    # points cannot index out of range.
    get_volume = np.zeros((B, T, 3))
    with torch.no_grad():
        for cur in range(min(max_tp, T)):
            preds_cur, _ = model(images[:, cur])
            probs_cur = torch.softmax(preds_cur, dim=1)
            labels_cur = torch.argmax(probs_cur, dim=1).cpu().numpy()
            for b in range(B):
                # Only time points that exist for this sample are recorded.
                if cur < total_tp[b]:
                    for label in range(3):
                        get_volume[b][cur][label] = (labels_cur[b] == label).sum()

    loss_2, _, _, temp = criterion(None, None, None, None, get_volume, total_tp, ages, all_IDs, None, loss_name="loss2")
    all_of_them.append(temp)

    loss_3, _, _, temp = criterion(None, None, None, None, get_volume, total_tp, ages, all_IDs, None, loss_name="loss3")
    all_of_them.append(temp)

    loss_4, _, _, temp = criterion(None, None, None, None, get_volume, total_tp, ages, all_IDs, cur_fold, loss_name="vol_loss")
    all_of_them.append(temp)

    loss_5, _, _, temp = criterion(all_preds, None, None, None, None, total_tp, None, None, None, loss_name="sdf_loss")
    all_of_them.append(temp)

    # Drop the local reference before asking the allocator to release cached memory.
    del all_preds
    clear_gpu_memory()

    loss = loss_1 + (0.5*loss_2) + (0.1*loss_3) + loss_4 + loss_5
    print("Train: Loss 1", loss_1.item(), ", Loss 2", loss_2.item(), ", Loss 3", loss_3.item(), ", Loss 4", loss_4.item(), ", Loss 5", loss_5.item())

    all_of_them.append(loss.item())
    # newline='' as required by the csv module, so no blank rows are produced
    # on platforms that translate line endings.
    with open(loss_file_loc, 'a', newline='') as fl:
        csv.writer(fl).writerow(all_of_them)

    gc.collect()
    torch.cuda.empty_cache()

    return loss, dice1.item(), dice2.item()

@torch.no_grad()
def validate_batch(model, data, criterion, loss_file_loc):

    cur_fold = global_fold
    print("################################## Val Loop: is fold", cur_fold)
    
    #loss_file_loc = loss_file_loc.replace("_train_", "_val_")
    cur_pred = []
    all_preds = []
    ages = []
    
    model.eval()
    images, targets, total_tp, all_IDs, feat_target  = data
    #print("VAL: Feature shape out of the dataloader", feat_target.shape)
    
    #print(ages, ICVs, diagnosis)
    #images, total_tp, ages = data
    
    #, targets
    #print("image device", images.device)
    #print("Val total time points in this sample", total_tp)
    #selected = [random.randint(2, value) for value in total_tp]
    #print("selected time points in list", selected)
    idx = max(total_tp)
    #print("selected time point", idx)
    #print("all images shape in validation: ", images[:, :, :, :, :, :].cpu().numpy().shape)
 
    # Assumes: images shape is [B, T, 1, D, H, W]
    B, _, _, D, H, W = images.shape
    all_preds = torch.zeros((B, idx, 3, D, H, W), device=images.device)
    all_feat_pred = torch.zeros((B, idx, 256, 12, 12, 12), device=images.device)
    
    for bbb in range(B): 
        TPPP = total_tp[bbb]
        for ttt in range(TPPP):  
            img = images[bbb, ttt, :, :, :, :] 
            pred, cur_feat = model(img.unsqueeze(0)) 
            all_preds[bbb, ttt, :, :, :, :] = pred 
            all_feat_pred[bbb, ttt, :, :, :, :] = cur_feat

    #print("VAL: ALL IMAGE PREDICTION SIZE", all_preds.shape)
    #print("VAL: ALL FEATURE PREDICTION SIZE", all_feat_pred.shape)

    #loss_1, dice1, dice2, all_of_them = criterion(all_preds[:2, :, :, :, :, :], all_f[:2, :, :, :, :, :], targets[:2, :idx, :, :, :, :], None, None, None, None, None, loss_name="loss1")
    loss_1, dice1, dice2, all_of_them = criterion(all_preds, all_feat_pred, targets[:, :idx, :, :, :, :], feat_target[:, :idx, :, :, :, :], None, total_tp, None, None, None, loss_name="loss1")
    
    del cur_pred

    sh = (images[:, :, :, :, :, :].cpu().numpy().shape[0], images[:, :, :, :, :, :].cpu().numpy().shape[1], 3)
    get_volume = np.zeros(sh)
    #print("volume record shape", get_volume.shape)
    for cur in range(0, max_tp, 1):
        preds_cur, _ = model(images[:, cur, :, :, :, :])
        probs_cur = torch.softmax(preds_cur, dim=1)
        preds_cur = torch.argmax(probs_cur, dim=1).cpu().numpy()
        #print("current pred dimension", preds_cur.shape)
        for b in range(0, get_volume.shape[0], 1):
            #print("time point in cur", cur, "actual total tp", total_tp[b])
            if (cur<total_tp[b]):
                get_volume[b][cur][0] = (preds_cur[b, :, :, :] == 0).sum()
                get_volume[b][cur][1] = (preds_cur[b, :, :, :] == 1).sum()
                get_volume[b][cur][2] = (preds_cur[b, :, :, :] == 2).sum()

    del preds_cur, _, probs_cur
    #get_volume = (get_volume/2097152)
    #print(get_volume)

    loss_2, _, _, temp = criterion(None, None, None, None, get_volume, total_tp, ages, all_IDs, None, loss_name="loss2")
    all_of_them.append(temp)
    del temp
    
    loss_3, _, _, temp = criterion(None, None, None, None, get_volume, total_tp, ages, all_IDs, None, loss_name="loss3")
    all_of_them.append(temp)
    del temp

    ICVs = []
    diagnosis = []
    
    loss_4, _, _, temp = criterion(None, None, None, None, get_volume, total_tp, ages, all_IDs, cur_fold, loss_name="vol_loss")
    all_of_them.append(temp)
    del temp

    loss_5, _, _, temp = criterion(all_preds, None, None, None, None, total_tp, None, None, None, loss_name="sdf_loss")
    all_of_them.append(temp)
    del temp
    del all_preds
    
    clear_gpu_memory()

    loss = loss_1 + (0.5*loss_2) + (0.1*loss_3) + loss_4 + loss_5
    print("Val: Loss 1", loss_1.item(), ", Loss 2", loss_2.item(), ", Loss 3", loss_3.item(), ", Loss 4", loss_4.item(), ", Loss 5", loss_5.item())
    
    all_of_them.append(loss.item())
    loss_row = all_of_them
    with open(loss_file_loc, 'a') as fl:
        writer = csv.writer(fl)
        writer.writerow(loss_row)

    #for obj in gc.get_objects():
    #    if torch.is_tensor(obj) and obj.is_cuda:
    #        print(f"Tensor: {obj.shape}, dtype: {obj.dtype}, device: {obj.device}")
        
    gc.collect()
    torch.cuda.empty_cache()

    return loss.item(), dice1.item(), dice2.item()

# Early-stopping settings. They are only referenced by the early-stopping code
# that is currently disabled (kept as comments inside the epoch loop).
patience = 20 #10  # Number of epochs to wait before stopping if no improvement
early_stopping_counter = 0  # Counter for early stopping

# Train the single fold selected by `global_fold`.
for fold in range(global_fold, global_fold+1, 1):

    exp_dir = task_path + losses_used
    fold_dir_path = exp_dir + '/models/fold-' + str(fold)
    os.makedirs(fold_dir_path, exist_ok=True)

    # All per-fold CSV locations, built once so that every reader and writer
    # below uses exactly the same path.
    loss_file_loc = exp_dir + "/all_losses_train_fold_" + str(fold) + ".csv"
    val_loss_file_loc = loss_file_loc.replace("all_losses_train_", "all_losses_val_")
    iter_train_csv = exp_dir + "/all_losses_iter_train_fold_" + str(fold) + ".csv"
    iter_val_csv = exp_dir + "/all_losses_iter_val_fold_" + str(fold) + ".csv"
    temp_train_csv = exp_dir + "/temp_train_fold_" + str(fold) + ".csv"
    temp_val_csv = exp_dir + "/temp_val_fold_" + str(fold) + ".csv"

    all_train_losses = []
    all_val_losses = []

    print("############################################################################")
    print("Loading Data for fold ", fold)

    train_dataset = mydata_nnunet(task_path + "/",'train', fold)
    val_dataset = mydata_nnunet(task_path + "/",'valid', fold)

    train_loader = DataLoader(train_dataset, batch_size = 4, shuffle = True, collate_fn=train_dataset.collate)
    val_loader = DataLoader(val_dataset, batch_size = 4, shuffle = False, collate_fn=val_dataset.collate)

    print(f"Train Loader {len(train_loader)} : Val Loader {len(val_loader)}")
    print("############################################################################")

    # ---- Model: start from the pretrained network of the same fold ----
    pretrained_model = nnUNet3D(in_channels=1, out_channels=3)
    best_model_path = 'PATH_TO_PRETRAIN/models/fold-' + str(fold) + '/model_fold_' + str(fold) + '_best_model.pth'

    # map_location='cpu': both models still live on the CPU at this point, so
    # the checkpoint does not need to be materialised on a GPU (and can be
    # loaded on a machine without one).
    pretrained_model.load_state_dict(torch.load(best_model_path, map_location='cpu', weights_only=True), strict=False)
    print("Model Loaded:", best_model_path)

    # Initialize new model with Dropout
    model = nnUNet3D_Dropout(in_channels=1, out_channels=3)
    # Transfer weights, skipping entries the dropout model does not have or
    # whose shape differs (a shape mismatch would make load_state_dict raise).
    model_dict = model.state_dict()
    pretrained_dict = {
        k: v
        for k, v in pretrained_model.state_dict().items()
        if k in model_dict and v.shape == model_dict[k].shape
    }
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)

    # Wrap the model with DataParallel to use multiple GPUs
    if torch.cuda.device_count() > 1:
        print("multiple gpus identified")
        model = nn.DataParallel(model)

    # Move the model to the device selected at the top of the notebook
    # ('cuda' when available, otherwise 'cpu').
    model = model.to(device)

    initial_lr = 0.005 #0.005
    nEpochs = 100
    best_val_loss = float('inf')
    # Gradients are accumulated over this many batches before each optimizer step.
    accum_steps = 2

    print("############################################################################")
    print("Running for total", nEpochs, "epochs")
    print("############################################################################")

    # Optimiser, loss function and scheduler.
    # (Earlier longitudinal runs used StepLR(step_size=10, gamma=0.5).)
    opt = torch.optim.Adam(model.parameters(), lr=initial_lr, weight_decay=1e-5)
    criterion = get_loss()
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min', factor=0.8, patience=3, verbose=True)

    # Start every per-fold CSV from scratch with its header row. newline='' is
    # what the csv module expects and avoids blank rows on platforms that
    # translate line endings.
    for header_path, header_row in (
        (loss_file_loc, fields),
        (val_loss_file_loc, fields),
        (iter_train_csv, temp_fields),
        (iter_val_csv, temp_fields),
    ):
        with open(header_path, 'w', newline='') as fl:
            csv.writer(fl).writerow(header_row)

    mlflow.set_experiment(task_name_exp)
    run_name = task_name_exp + "_fold_" + str(fold) + "_epochs_" + str(nEpochs) + '_time_' + str(now)

    with mlflow.start_run(run_name=run_name, log_system_metrics=True) as run:

        for epoch in tqdm(range(nEpochs)):

            # The temp CSVs collect this epoch's per-batch rows only, so they
            # are truncated back to the header at the start of every epoch.
            for temp_path in (temp_train_csv, temp_val_csv):
                with open(temp_path, 'w', newline='') as fl:
                    csv.writer(fl).writerow(temp_fields)

            train_loss = val_loss = 0
            dice_1_t = dice_2_t = dice_1_v = dice_2_v = 0

            print("Epoch ", epoch, ":")

            # ---- Training with gradient accumulation ----
            opt.zero_grad()

            for i, data in enumerate(train_loader):
                loss, dice1, dice2 = train_batch(model, data, opt, criterion, loss_file_loc=temp_train_csv)
                # Scale loss by accumulation steps to keep gradients consistent
                (loss / accum_steps).backward()
                train_loss += loss.item()
                dice_1_t += dice1
                dice_2_t += dice2

                # Perform optimizer step and zero_grad every accum_steps batches OR at the last batch
                if (i + 1) % accum_steps == 0 or (i + 1) == len(train_loader):
                    opt.step()
                    opt.zero_grad()

            d = len(train_loader)
            avg_train = f"Train : Overall Loss: {train_loss/d} | L1 Dice : {(dice_1_t/d)*100}% | L2 Dice : {(dice_2_t/d)*100}%"

            # ---- Validation ----
            for i_v, data_v in enumerate(val_loader):
                loss, dice1, dice2 = validate_batch(model, data_v, criterion, loss_file_loc=temp_val_csv)
                val_loss += loss
                dice_1_v += dice1
                dice_2_v += dice2
            d_v = len(val_loader)
            avg_val = f"Val : Overall Loss: {val_loss/d_v} | L1 Dice : {(dice_1_v/d_v)*100}% | L2 Dice : {(dice_2_v/d_v)*100}%"

            mlflow.log_metric('Train Total Loss', train_loss/d, step=epoch)
            mlflow.log_metric('Val Total Loss', val_loss/d_v, step=epoch)

            # Learning rate of the last parameter group.
            current_lr = opt.param_groups[-1]['lr']
            print(f"Learning rate: {current_lr}")

            # Per-epoch row: the mean of every per-batch column, followed by
            # the epoch index and the learning rate (same order as `fields`).
            get_train_avg = get_average_from_csv(temp_train_csv)
            get_val_avg = get_average_from_csv(temp_val_csv)
            get_train_avg = np.append(get_train_avg, [epoch, current_lr])
            get_val_avg = np.append(get_val_avg, [epoch, current_lr])

            # Append this epoch's per-batch rows (header skipped) to the
            # cumulative per-iteration CSVs.
            new_data_train = pd.read_csv(temp_train_csv, header=None, skiprows=1)
            new_data_train.to_csv(iter_train_csv, mode="a", header=False, index=False)
            new_data_val = pd.read_csv(temp_val_csv, header=None, skiprows=1)
            new_data_val.to_csv(iter_val_csv, mode="a", header=False, index=False)

            for p, field_name in enumerate(fields):
                mlflow.log_metric("Train " + field_name, get_train_avg[p], step=epoch)

            for p, field_name in enumerate(fields):
                mlflow.log_metric("Val " + field_name, get_val_avg[p], step=epoch)

            with open(loss_file_loc, 'a', newline='') as fl:
                csv.writer(fl).writerow(get_train_avg)

            with open(val_loss_file_loc, 'a', newline='') as fl:
                csv.writer(fl).writerow(get_val_avg)

            # ---- Disabled: save only on improvement, with early stopping ----
            # if best_val_loss>(val_loss/d_v) and (val_loss/d_v)>0:
            #     print(avg_train)
            #     print(avg_val)
            #     early_stopping_counter = 0
            #     file_name = task_path + losses_used + '/models/model_fold_' + str(fold) + '_best_model.pth'
            #     print("Saving Model")
            #     torch.save(model.state_dict(), file_name)
            #     torch.save(opt.state_dict(), file_name.replace('best_model.pth', 'optimizer_state.pth'))
            #     torch.save(scheduler.state_dict(), file_name.replace('best_model.pth', 'scheduler_state.pth'))
            #     best_val_loss = (val_loss/d_v)
            #     log_epoch(epoch)
            # else:
            #     early_stopping_counter += 1
            #     print(f"No improvement. Early stopping counter: {early_stopping_counter}/{patience}")
            #
            # if early_stopping_counter >= patience:
            #     print("Early stopping triggered.")
            #     break

            # ---- Active: checkpoint every epoch, no early stopping ----
            # NOTE: despite the "best_model" file name and the `best_val_loss`
            # variable, a checkpoint is written after every epoch.
            print(avg_train)
            print(avg_val)
            file_name = fold_dir_path + '/model_fold_' + str(fold) + '_best_model_epoch_' + str(epoch) + '.pth'
            print("Saving Model as ", file_name)
            torch.save(model.state_dict(), file_name)
            torch.save(opt.state_dict(), file_name.replace('best_model_', 'optimizer_state_'))
            torch.save(scheduler.state_dict(), file_name.replace('best_model_', 'scheduler_state_'))
            best_val_loss = (val_loss/d_v)
            log_epoch(epoch)

            # Step the scheduler at every epoch on the mean validation loss.
            scheduler.step(val_loss/d_v)

            all_train_losses.append((train_loss/d))
            all_val_losses.append((val_loss/d_v))

            # Plot and save current losses
            plot_all(loss_file_loc, fold)
            plot_all_iter(iter_train_csv, fold, "train")
            plot_all_iter(iter_val_csv, fold, "val")

        # Log the model while the run is still active. Calling this after the
        # `with` block has ended the run makes MLflow start a new, unnamed run
        # for the model, detached from the metrics logged above.
        mlflow.pytorch.log_model(model, task_name_exp)

    print('\n')


# `datetime` is bound to the class here (`from datetime import datetime` comes
# after `import datetime` and rebinds the name), so `now()` is called on it
# directly; `datetime.datetime.now()` would raise AttributeError.
now = datetime.now()
print("TRAINING DONE !")
print("Terminated at:", now)

