### IMPORTS

In [None]:
import os
import math
import yaml
import pandas as pd
import numpy as np

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

In [None]:
from utils.plots import plot_srcc_MAE, plot_srcc_MAE_2, plot_losses, parity_plot_2, parity_plot_3, plot_srcc_MAE_3
from utils.helpers import load_csv

from models.FFN import FFN

### HELPERS

In [None]:
def save_config_file(model_checkpoints_folder, config):
    """Write the run configuration next to the model checkpoints.

    Args:
        model_checkpoints_folder: Directory that receives the file. It is
            created (including parents) when it does not exist yet.
        config: Configuration mapping to serialise with ``yaml.dump``.

    The file is always named ``config_finetune.yaml`` and is overwritten if
    it already exists.
    """
    # exist_ok=True avoids the check-then-create race of a separate
    # os.path.exists() test and is a no-op when the folder is already there.
    os.makedirs(model_checkpoints_folder, exist_ok=True)

    config_path = os.path.join(model_checkpoints_folder, 'config_finetune.yaml')
    with open(config_path, 'w') as file:
        yaml.dump(config, file)

def load_pre_trained_weights(model, config, device):
    """Try to initialise ``model`` from a previously saved checkpoint.

    The checkpoint is expected at ``<config['fine_tune_from']>/model.pth``
    and is mapped onto ``device`` while loading.

    Args:
        model: Module whose ``load_state_dict`` receives the checkpoint.
        config: Mapping providing the ``'fine_tune_from'`` folder path.
        device: Target passed to ``torch.load`` as ``map_location``.

    Returns:
        ``(model, scratch)`` where ``scratch`` is ``False`` when the weights
        were loaded and ``True`` when the checkpoint file was not found.
    """
    weights_path = os.path.join(config['fine_tune_from'], 'model.pth')
    try:
        model.load_state_dict(torch.load(weights_path, map_location=device))
    except FileNotFoundError:
        # No checkpoint on disk: the caller trains the model from scratch.
        print("Pre-trained weights not found. Training from scratch.")
        return model, True
    else:
        print("Loaded pre-trained model with success.")
        return model, False

### TRAIN AND TEST

In [None]:
def _fixed_split_indices(task_name):
    """Return the frozen row permutation used to split the dataset of a task.

    The permutations are hard-coded so that every run uses exactly the same
    test / validation / train rows.

    Raises:
        ValueError: If ``task_name`` is neither ``'cond'`` nor ``'visc'``.
    """
    if task_name == 'cond':
        indices = [154, 996, 1114, 1031, 373, 729, 45, 521, 423, 287, 318, 236, 558, 1096, 836, 570, 412, 1046, 592, 377, 818, 1159, 789, 598, 35, 319, 834, 269, 336, 454, 726, 504, 722, 1128, 320, 1110, 806, 897, 1038, 843, 168, 90, 1092, 1100, 992, 948, 927, 355, 138, 863, 741, 1142, 297, 7, 160, 348, 766, 73, 1082, 442, 1044, 247, 658, 266, 111, 286, 572, 1086, 1177, 213, 210, 1173, 368, 1093, 257, 178, 304, 1048, 259, 1150, 899, 337, 407, 1172, 1072, 356, 452, 381, 495, 126, 898, 280, 721, 176, 367, 913, 338, 398, 877, 975, 606, 268, 613, 574, 580, 200, 997, 122, 291, 1117, 459, 1081, 315, 52, 647, 131, 151, 418, 1204, 589, 409, 30, 728, 466, 1005, 619, 14, 745, 66, 712, 175, 930, 901, 42, 196, 1074, 9, 1220, 937, 795, 59, 453, 508, 892, 511, 705, 156, 220, 157, 1113, 1019, 281, 1065, 119, 594, 225, 65, 187, 637, 826, 623, 498, 579, 939, 1155, 972, 400, 124, 223, 1069, 1024, 405, 366, 907, 496, 1196, 1201, 1095, 378, 876, 292, 851, 23, 627, 272, 322, 77, 879, 350, 793, 4, 462, 506, 980, 191, 911, 624, 903, 861, 617, 1146, 390, 277, 1148, 915, 46, 773, 567, 197, 414, 1144, 955, 500, 153, 533, 478, 239, 542, 147, 1203, 76, 86, 889, 54, 822, 720, 797, 169, 798, 258, 704, 294, 227, 1179, 141, 943, 963, 681, 735, 967, 703, 1045, 775, 219, 1011, 95, 343, 180, 179, 105, 951, 543, 365, 802, 649, 887, 886, 28, 796, 252, 538, 1, 524, 1213, 416, 974, 882, 988, 1047, 714, 841, 1033, 404, 487, 522, 12, 36, 1215, 58, 1125, 825, 1189, 1014, 284, 139, 1137, 285, 3, 1002, 536, 1186, 218, 1139, 110, 62, 49, 610, 708, 207, 778, 785, 1170, 374, 128, 684, 1211, 746, 1124, 1200, 769, 557, 945, 510, 450, 727, 477, 936, 949, 148, 401, 1194, 688, 201, 1192, 864, 464, 837, 1154, 1210, 1003, 96, 645, 181, 1104, 1066, 1145, 691, 1140, 170, 48, 87, 57, 1206, 499, 1034, 526, 747, 217, 883, 174, 249, 670, 920, 369, 758, 1166, 1090, 555, 93, 612, 419, 484, 72, 578, 635, 989, 807, 8, 490, 307, 850, 551, 26, 842, 626, 226, 1073, 31, 1182, 548, 1099, 1135, 1085, 518, 1123, 661, 425, 830, 577,
            447, 697, 631, 303, 733, 971, 134, 214, 869, 27, 305, 63, 237, 342, 644, 385, 981, 221, 479, 689, 754, 749, 917, 2, 1036, 288, 940, 1106, 1112, 264, 206, 1130, 376, 891, 756, 620, 363, 641, 819, 402, 112, 931, 654, 634, 1077, 461, 267, 1032, 957, 1165, 656, 290, 817, 422, 881, 293, 799, 212, 391, 857, 198, 755, 1187, 1097, 568, 643, 70, 585, 935, 757, 424, 1167, 75, 683, 438, 942, 301, 786, 92, 858, 777, 330, 67, 1126, 491, 328, 602, 117, 970, 17, 687, 1018, 1162, 1188, 104, 1212, 840, 282, 1098, 695, 906, 202, 360, 809, 171, 482, 334, 278, 941, 968, 870, 445, 489, 717, 1208, 1207, 306, 411, 106, 662, 260, 772, 140, 208, 960, 321, 135, 29, 420, 1111, 372, 588, 1149, 136, 384, 718, 1055, 316, 118, 132, 763, 1161, 604, 1185, 685, 1195, 1151, 1078, 607, 335, 1062, 855, 1025, 597, 1017, 982, 302, 593, 1022, 709, 1217, 1000, 701, 488, 397, 1138, 759, 748, 98, 730, 590, 79, 163, 839, 724, 601, 235, 209, 541, 916, 528, 486, 675, 275, 325, 0, 231, 203, 677, 173, 194, 15, 768, 99, 91, 1119, 289, 673, 599, 710, 120, 672, 831, 455, 333, 234, 929, 781, 1012, 990, 1197, 921, 816, 774, 64, 155, 1040, 364, 51, 1136, 182, 553, 995, 380, 639, 961, 513, 860, 186, 1028, 725, 788, 944, 739, 808, 1157, 609, 503, 966, 790, 83, 43, 633, 177, 299, 1219, 465, 723, 354, 792, 581, 520, 779, 615, 183, 671, 509, 41, 115, 101, 494, 1030, 872, 692, 314, 859, 457, 761, 25, 39, 636, 706, 271, 1216, 1027, 1132, 274, 909, 332, 6, 161, 847, 783, 448, 1010, 311, 853, 1171, 370, 1129, 862, 782, 1020, 794, 959, 137, 674, 1075, 679, 950, 1026, 327, 801, 868, 1061, 595, 539, 646, 737, 625, 433, 44, 663, 1202, 1051, 1049, 279, 1131, 481, 428, 565, 1035, 413, 353, 1122, 562, 791, 744, 1008, 682, 923, 893, 359, 273, 238, 517, 652, 800, 736, 94, 149, 1127, 375, 569, 630, 261, 13, 956, 396, 383, 382, 582, 97, 241, 985, 846, 878, 60, 107, 638, 1103, 145, 693, 999, 324, 784, 270, 719, 1083, 1101, 738, 347, 339, 300, 399, 815, 583, 561, 37, 444, 146, 821, 667, 753, 812, 1118, 653, 535, 53, 666, 100, 642, 1037,
            497, 125, 516, 71, 113, 771, 1029, 16, 977, 531, 232, 651, 946, 874, 1001, 485, 904, 924, 222, 199, 395, 326, 767, 1175, 699, 702, 1156, 731, 1054, 576, 1089, 547, 523, 427, 142, 854, 803, 265, 890, 74, 165, 852, 1052, 114, 502, 190, 1091, 68, 559, 552, 811, 1102, 463, 1060, 1108, 659, 742, 650, 1079, 32, 530, 986, 740, 586, 660, 984, 317, 549, 912, 529, 244, 1076, 922, 928, 1178, 994, 611, 827, 838, 918, 483, 965, 1160, 953, 1004, 776, 805, 253, 1152, 603, 514, 587, 216, 880, 344, 1088, 298, 1116, 648, 308, 1050, 387, 1053, 55, 655, 760, 1043, 243, 629, 84, 22, 628, 38, 991, 1070, 823, 85, 431, 449, 848, 596, 392, 162, 1169, 296, 566, 591, 1191, 47, 632, 440, 865, 446, 832, 312, 164, 1180, 678, 361, 248, 443, 386, 1143, 698, 211, 229, 50, 329, 254, 900, 532, 575, 251, 1153, 564, 1007, 295, 456, 1176, 501, 780, 1174, 1009, 694, 109, 925, 993, 1080, 525, 895, 88, 751, 764, 310, 358, 690, 545, 417, 1042, 969, 246, 458, 349, 910, 1068, 926, 1134, 434, 537, 534, 1193, 1184, 716, 973, 1164, 188, 127, 5, 608, 888, 130, 1058, 1064, 480, 224, 192, 341, 600, 408, 987, 1163, 896, 820, 34, 546, 676, 964, 571, 80, 1059, 665, 410, 732, 276, 1016, 954, 527, 976, 228, 331, 204, 492, 255, 554, 871, 403, 560, 250, 885, 240, 1168, 998, 493, 833, 938, 143, 205, 184, 713, 309, 1071, 426, 908, 1115, 1094, 357, 1039, 323, 680, 867, 711, 515, 1105, 20, 1120, 829, 283, 1199, 934, 345, 664, 844, 432, 193, 734, 1190, 544, 640, 512, 1158, 743, 144, 1109, 668, 947, 123, 129, 849, 770, 752, 824, 81, 415, 437, 696, 371, 958, 189, 394, 884, 1023, 102, 875, 1107, 657, 18, 245, 905, 765, 810, 866, 919, 388, 1121, 621, 1021, 150, 856, 406, 1209, 19, 618, 563, 622, 185, 429, 56, 215, 762, 40, 873, 167, 10, 1141, 103, 616, 605, 835, 435, 158, 352, 932, 460, 894, 700, 256, 340, 556, 952, 1057, 814, 686, 389, 550, 195, 439, 61, 233, 540, 436, 573, 584, 441, 505, 1041, 1056, 467, 902, 230, 262, 1133, 750, 172, 430, 11, 351, 1205, 346, 1063, 787, 313, 669, 362, 166, 108, 519, 33, 1198, 1013, 82, 121,
            979, 89, 116, 451, 828, 978, 393, 1006, 152, 914, 21, 983, 78, 507, 962, 1183, 1067, 69, 1221, 421, 1084, 707, 1147, 1218, 845, 159, 715, 133, 813, 1181, 614, 1214, 242, 24, 1087, 379, 804, 263, 1015, 933]
    elif task_name == 'visc':
        indices = [270, 102, 121, 16, 281, 8, 400, 179, 10, 98, 170, 221, 114, 249, 303, 388, 229, 65, 259, 456, 175, 111, 139, 357, 129, 298, 260, 215, 420, 337, 458, 379, 457, 108, 340, 241, 403, 430, 449, 243, 12, 274, 261, 4, 207, 22, 474, 31, 434, 141, 271, 192, 60, 15, 265, 122, 97, 0, 149, 284, 69, 383, 126, 324, 147, 358, 437, 127, 362, 154, 71, 292, 219, 137, 426, 231, 242, 310, 318, 256, 30, 283, 395, 196, 276, 155, 389, 142, 133, 78, 354, 72, 49, 320, 416, 210, 445, 330, 404, 460, 258, 140, 349, 187, 224, 405, 48, 132, 38, 293, 25, 470, 94, 105, 66, 398, 183, 84, 250, 40, 421, 287, 123, 419, 41, 366, 275, 29, 441, 232, 323, 370, 360, 469, 216, 350, 235, 332, 352, 81, 453, 286, 44, 312, 200, 326, 364, 115, 206, 394, 386, 191, 201, 285, 279, 80, 319, 444, 220, 74, 214, 109, 291, 296, 309, 368, 315, 347, 407, 264, 223, 353, 46, 13, 85, 331, 230, 186, 225, 401, 390, 11, 228, 106, 158, 255, 161, 20, 333, 311, 226, 425, 239, 131, 307, 443, 317, 290, 450, 348, 222, 365, 248, 177, 338, 130, 280, 378, 165, 336, 304, 156, 325, 244, 341, 466, 50, 119, 146, 189, 468, 51, 157, 34, 471, 162, 99, 47, 5, 308, 160, 410, 32, 442, 218, 475, 87, 208, 205, 116, 54, 472, 342, 53, 273, 152, 202, 321, 55, 376, 355, 237, 9, 27, 64, 253, 278, 68, 128, 19, 387, 263, 181, 90, 439, 168, 373, 135, 344, 409, 209, 43, 361, 167, 334, 297, 18, 384, 211, 174, 63, 42, 393, 294, 267, 402, 79, 172, 100, 251, 118, 58, 396, 371, 33, 335, 408, 459, 195, 14, 385, 3, 184, 277, 418, 272, 83, 234, 345, 150, 262, 289, 300, 180, 23, 417, 346, 76, 91, 125, 254, 37, 306, 21, 467, 440, 447, 327, 414, 212, 190, 163, 236, 26, 197, 88, 423, 446, 305, 73, 382, 185, 153, 104, 329, 424, 257, 406, 57, 432, 52, 107, 301, 412, 233, 103, 413, 428, 377, 148, 194, 171, 144, 363, 448, 101, 328, 45, 117, 391, 268, 7, 138, 380, 429, 198, 435, 464, 367, 247, 164, 92, 359, 143, 1, 188, 59, 203, 411, 465, 397, 35, 227, 238, 75, 56, 399, 2, 454, 436, 6, 461, 415, 112, 351, 61, 199, 269, 240, 463, 246, 28, 113, 288, 369,
            299, 24, 145, 124, 176, 433, 95, 375, 473, 70, 295, 381, 245, 266, 178, 151, 314, 313, 120, 136, 431, 452, 17, 77, 62, 193, 173, 169, 166, 356, 252, 462, 93, 438, 89, 451, 422, 217, 67, 82, 204, 159, 427, 182, 343, 96, 134, 213, 392, 455, 302, 282, 36, 374, 86, 316, 322, 372, 39, 110, 339]
    else:
        # Without this the function used to fail later with an unbound
        # `indices` variable, which hid the real cause.
        raise ValueError(
            f"No fixed split is defined for task_name={task_name!r}; expected 'cond' or 'visc'."
        )
    return indices


def _prepare_split(frame):
    """Clean one split and log-transform its first five label columns.

    Steps, in order: cast bool columns to int, coerce every column after the
    first to numeric (non-numeric entries become NaN), drop rows containing
    NaN, reset the index, then apply the natural log to columns ``-7:-2``.
    """
    frame = frame.astype({col: 'int' for col in frame.select_dtypes(include=['bool']).columns})
    frame.iloc[:, 1:] = frame.iloc[:, 1:].apply(pd.to_numeric, errors='coerce')
    frame = frame.dropna()
    frame = frame.reset_index(drop=True)
    frame.iloc[:, -7:-2] = np.log(frame.iloc[:, -7:-2])
    return frame


def _frame_to_dataset(frame):
    """Build a ``TensorDataset`` of (features, temp, labels) float32 tensors.

    Column layout used: ``1:-12`` -> features, ``-12:-7`` -> the second model
    input (named ``temp`` in the training loop), ``-7:`` -> labels.
    """
    return TensorDataset(
        torch.tensor(frame.iloc[:, 1:-12].values, dtype=torch.float32),
        torch.tensor(frame.iloc[:, -12:-7].values, dtype=torch.float32),
        torch.tensor(frame.iloc[:, -7:].values, dtype=torch.float32),
    )


def _weighted_l1(criterion, outputs, ln_A, B, labels):
    """Combine the three L1 terms used for both training and validation.

    ``outputs`` is compared with every label column except the last two,
    ``ln_A`` with the second-to-last and ``B`` with the last one. The ``B``
    term currently carries zero weight but is still evaluated.
    """
    return 0.85* criterion(outputs, labels[:, :-2]) + 0.15 * criterion(ln_A, labels[:, -2]) + 0 * criterion(B, labels[:, -1])


def train(config):
    """Train an ``FFN`` on the configured dataset with early stopping.

    Config keys read here: ``dataset.file``, ``dataset.train_size``,
    ``task_name``, ``batch_size``, ``lr``, ``epochs`` and ``patience``
    (``load_pre_trained_weights`` additionally reads ``fine_tune_from``).

    The data is split with a fixed permutation: the first 20 % of the
    permuted rows form the test set, the next 10 % the validation set and the
    following ``train_size`` fraction the training set. When pre-trained
    weights are found, every parameter except ``ln_A.weight`` and
    ``B.weight`` is frozen. The state dict with the lowest validation loss is
    saved to ``<model.model_folder_path>/model.pth``.

    Returns:
        ``(train_losses, val_losses, model, test_loader)`` where the two loss
        lists hold one per-batch-averaged value per completed epoch.

    Raises:
        ValueError: If ``config['task_name']`` has no fixed split defined.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    data = load_csv(config['dataset']['file'])
    indices = _fixed_split_indices(config['task_name'])

    num_rows = data.shape[0]
    n_val = int(np.floor(0.1 * num_rows))
    n_test = int(np.floor(0.2 * num_rows))
    n_train = int(np.floor(config['dataset']['train_size'] * num_rows))
    test_idx = indices[:n_test]
    valid_idx = indices[n_test:n_test + n_val]
    train_idx = indices[n_test + n_val:n_test + n_val + n_train]

    # Local frames deliberately avoid the name `train`, which would shadow
    # this function inside its own body.
    train_df = _prepare_split(data.iloc[train_idx, :])
    val_df = _prepare_split(data.iloc[valid_idx, :])
    test_df = _prepare_split(data.iloc[test_idx, :])

    train_loader = DataLoader(_frame_to_dataset(train_df), batch_size=config['batch_size'], shuffle=True)
    val_loader = DataLoader(_frame_to_dataset(val_df), batch_size=config['batch_size'], shuffle=False)
    test_loader = DataLoader(_frame_to_dataset(test_df), batch_size=config['batch_size'], shuffle=False)

    # Input width = all columns minus the first one and the 12 trailing
    # temp/label columns.
    model = FFN(data.shape[1]-13, config)
    model = model.to(device)

    model, scratch = load_pre_trained_weights(model, config, device)

    # When fine-tuning, only these parameters stay trainable.
    param_list = ['ln_A.weight', 'B.weight']
    if not scratch:
        for name, param in model.named_parameters():
            if name not in param_list:
                print(name)
                param.requires_grad = False

    criterion = nn.L1Loss()
    optimizer = optim.Adam(model.parameters(), lr=config['lr'])

    save_config_file(model.model_folder_path, config)

    best_val_loss = math.inf
    patience_counter = 0

    train_losses = []
    val_losses = []

    epochs = config['epochs']
    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        for features, temp, labels in train_loader:
            features, temp, labels = features.to(device), temp.to(device), labels.to(device)
            optimizer.zero_grad()

            outputs, ln_A, B = model(features, temp)
            loss = _weighted_l1(criterion, outputs, ln_A, B, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for features, temp, labels in val_loader:
                features, temp, labels = features.to(device), temp.to(device), labels.to(device)
                outputs, ln_A, B = model(features, temp)
                loss = _weighted_l1(criterion, outputs, ln_A, B, labels)
                val_loss += loss.item()

        train_loss /= len(train_loader)
        # Average over the validation batches. The previous code divided by
        # the training batch count, which mis-scaled the reported value.
        # max(..., 1) keeps an empty validation split at 0.0 as before.
        val_loss /= max(len(val_loader), 1)

        train_losses.append(train_loss)
        val_losses.append(val_loss)

        print(f'Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}\n best_val_loss: {best_val_loss:.4f} patience_counter: {patience_counter}')

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            torch.save(model.state_dict(), os.path.join(model.model_folder_path, "model.pth"))
        else:
            patience_counter += 1

        if patience_counter >= config['patience']:
            print("Early stopping")
            break

    return train_losses, val_losses, model, test_loader

In [None]:
def test(model, test_loader):
    """Evaluate the best saved checkpoint on the test split.

    Reloads ``<model.model_folder_path>/model.pth`` into ``model``, switches
    the model to evaluation mode and runs it over ``test_loader`` without
    gradients.

    Args:
        model: Module exposing ``model_folder_path`` and returning
            ``(outputs, ln_A, B)`` from ``model(features, temp)``.
        test_loader: Iterable of ``(features, temp, labels)`` batches; the
            last two label columns are the ``ln_A`` and ``B`` targets.

    Returns:
        ``(predictions, label, A_pred, ln_As, B_pred, Bs)`` as NumPy arrays
        reshaped to ``(-1, 5)``. The ``ln_A`` / ``B`` values are repeated
        across the five columns so they line up with the predictions.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_path = os.path.join(model.model_folder_path, 'model.pth')
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    print("Loaded trained model with success.")

    model.to(device)
    # Evaluation must not depend on whatever mode the caller left the model
    # in: dropout / batch-norm layers behave differently in training mode.
    model.eval()

    predictions = []
    label = []
    ln_As = []
    A_pred = []
    Bs = []
    B_pred = []

    total_loss = 0.0
    num_batches = 0
    criterion = nn.L1Loss()

    with torch.no_grad():
        for features, temp, labels in test_loader:
            num_batches += 1
            features, temp, labels = features.to(device), temp.to(device), labels.to(device)

            outputs, ln_A, B = model(features, temp)
            # Same weighting as during training; the B term has zero weight.
            loss = 0.85* criterion(outputs, labels[:, :-2]) + 0.15* criterion(ln_A, labels[:, -2]) + 0 * criterion(B, labels[:, -1])
            total_loss += loss.item()

            predictions.extend(outputs.cpu().numpy())
            label.extend(labels[:, :-2].cpu().numpy())

            # Broadcast the per-sample scalars to five columns.
            ln_A = ln_A.reshape(-1, 1).expand(-1, 5)
            A_pred.extend(ln_A.cpu().numpy())
            ln_A_actual = labels[:, -2].reshape(-1, 1).expand(-1, 5)
            ln_As.extend(ln_A_actual.cpu().numpy())

            B = B.reshape(-1, 1).expand(-1, 5)
            B_pred.extend(B.cpu().numpy())
            B_actual = labels[:, -1].reshape(-1, 1).expand(-1, 5)
            Bs.extend(B_actual.cpu().numpy())

    # An empty loader yields NaN instead of a ZeroDivisionError.
    test_loss = total_loss / num_batches if num_batches else float('nan')
    print(f'Test Loss: {test_loss:.4f}')

    predictions = np.array(predictions).reshape(-1, 5)
    label = np.array(label).reshape(-1, 5)
    A_pred = np.array(A_pred).reshape(-1, 5)
    ln_As = np.array(ln_As).reshape(-1, 5)
    Bs = np.array(Bs).reshape(-1, 5)
    B_pred = np.array(B_pred).reshape(-1, 5)

    return predictions, label, A_pred, ln_As, B_pred, Bs


In [None]:
def main(config):
    """Run one train / evaluate / plot cycle for ``config``.

    Trains a model, plots its loss curves, evaluates the saved checkpoint on
    the test split and writes the parity plots into the model's folder.

    Returns:
        The ``(srcc, mae)`` pair produced by ``parity_plot_3`` for the main
        predictions.
    """
    loss_train, loss_val, model, test_loader = train(config)
    out_dir = model.model_folder_path

    plot_losses(loss_train, loss_val, out_dir)

    preds, targets, ln_a_hat, ln_a_true, b_hat, b_true = test(model, test_loader)

    # Parity plot of the main predictions also yields the reported metrics.
    metrics = parity_plot_3(preds, targets, out_dir, tag = config['task_name'])
    srcc, mae = metrics

    # Parity plots for the two auxiliary outputs.
    parity_plot_2(ln_a_hat, ln_a_true, out_dir, tag = 'ln(A)')
    parity_plot_2(b_hat, b_true, out_dir, tag = 'B')

    return srcc, mae

In [None]:
def run(task, main_folder, tune_from, runs, targets, transfer):
    """Sweep the training-set size, optionally comparing transfer vs. scratch.

    For each of ``runs`` repetitions a ``test_<k>`` sub-folder is created in
    ``main_folder`` and ``targets`` training-set fractions are evaluated,
    starting at 0.7 and decreasing in steps of 0.1. Every fraction is trained
    from scratch and, when ``transfer`` is true, also fine-tuned from
    ``tune_from``. Per-run and (for ``runs > 1``) averaged plots are written,
    and all metrics are dumped to ``<main_folder>/results.yaml``.

    Args:
        task: ``'visc'`` or ``'cond'``; selects the dataset file and task
            name. Any other value leaves the loaded config untouched.
        main_folder: Output directory of the experiment.
        tune_from: Folder holding the ``model.pth`` to fine-tune from.
        runs: Number of repetitions.
        targets: Number of training-set fractions to evaluate.
        transfer: Whether to run the fine-tuning variant as well.

    Raises:
        FileExistsError: If a ``test_<k>`` folder already exists, so earlier
            results are never overwritten.
    """
    # The handle is closed deterministically; the previous
    # yaml.load(open(...)) left it to the garbage collector.
    # NOTE(review): FullLoader is fine for this local, trusted file; switch to
    # yaml.safe_load if the config can ever come from an untrusted source.
    with open("config_finetune.yaml", "r") as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)

    if task == 'visc':
        config['dataset']['file'] = 'Datasets/viscData/all_data_254.csv'
        config['task_name'] = 'visc'
    elif task == 'cond':
        config['dataset']['file'] = 'Datasets/condData/all_data_254.csv'
        config['task_name'] = 'cond'

    # Row counts used to turn a training fraction into an x-axis sample count.
    dataset_sizes = {'visc': 477, 'cond': 1222, 'visc_hc': 182}

    s_avg = np.zeros((targets, 5))
    s_std = np.zeros((targets, 5))
    m_avg = np.zeros((targets, 5))
    m_std = np.zeros((targets, 5))

    results = {
        'scratch_srcc': np.zeros((runs, targets, 5)).tolist(),
        'scratch_mae': np.zeros((runs, targets, 5)).tolist(),
        'axis': np.zeros(targets).tolist(),
    }
    if transfer:
        results['Transfer_srcc'] = np.zeros((runs, targets, 5)).tolist()
        results['Transfer_mae'] = np.zeros((runs, targets, 5)).tolist()
        transfer_s_avg = np.zeros((targets, 5))
        transfer_s_std = np.zeros((targets, 5))
        transfer_m_avg = np.zeros((targets, 5))
        transfer_m_std = np.zeros((targets, 5))

    for j in range(runs):
        folder = os.path.join(main_folder, f'test_{j+1}')
        os.makedirs(folder, exist_ok=False)

        if transfer:
            transfer_srcc = np.zeros((targets, 5))
            transfer_mae = np.zeros((targets, 5))

        scratch_mae = np.zeros((targets, 5))
        scratch = np.zeros((targets, 5))

        axis = np.zeros(targets)

        # `a` is the row in the metric arrays; `i` counts down so the largest
        # training fraction comes first.
        for a, i in enumerate(range(6, 6-targets, -1)):
            print("start")
            train_fraction = (i+1)*0.1
            config['dataset']['train_size'] = train_fraction
            config['save_folder'] = folder
            if transfer:
                config['fine_tune_from'] = tune_from
                config['name'] = f'Transfer_{train_fraction:.1f}'

                srcc_BT, mae_BT = main(config)
                transfer_srcc[a] = srcc_BT
                transfer_mae[a] = mae_BT

            # 'None' is not an existing folder, so weight loading falls back
            # to training from scratch.
            config['fine_tune_from'] = 'None'
            config['name'] = f'Scratch_{train_fraction:.1f}'
            srcc_s, mae_s = main(config)
            scratch[a] = srcc_s
            scratch_mae[a] = mae_s

            n_rows = dataset_sizes.get(config["task_name"])
            if n_rows is not None:
                axis[a] = int(train_fraction*n_rows)

            print("done")

        if transfer:
            # Per-run plots; the std arrays are still all zeros at this point.
            for i in range(5):
                plot_srcc_MAE(scratch[:, i], transfer_srcc[:, i], s_std[:, i], transfer_s_std[:, i], axis, f'{i+1}', folder, tag2='srcc')
                plot_srcc_MAE(scratch_mae[:, i], transfer_mae[:, i], m_std[:, i], transfer_m_std[:, i], axis, f'{i+1}', folder, tag2='mae')

            results['Transfer_srcc'][j] = transfer_srcc.tolist()
            results['Transfer_mae'][j] = transfer_mae.tolist()

            transfer_s_avg += transfer_srcc
            transfer_m_avg += transfer_mae

        results['scratch_mae'][j] = scratch_mae.tolist()
        results['scratch_srcc'][j] = scratch.tolist()
        results['axis'] = axis.tolist()

        s_avg += scratch
        m_avg += scratch_mae

    s_avg /= runs
    m_avg /= runs
    if transfer:
        transfer_s_avg /= runs
        transfer_m_avg /= runs

    # Averaged plots are only drawn when there is more than one run.
    if runs > 1:
        s_std = np.std(np.array(results['scratch_srcc']), axis=0)
        m_std = np.std(np.array(results['scratch_mae']), axis=0)

        if transfer:
            transfer_s_std = np.std(np.array(results['Transfer_srcc']), axis=0)
            transfer_m_std = np.std(np.array(results['Transfer_mae']), axis=0)

            for i in range(5):
                plot_srcc_MAE(s_avg[:, i], transfer_s_avg[:, i], s_std[:, i], transfer_s_std[:, i], axis, f'{i+1}', main_folder, tag2='srcc', tag3 = 'avg')
                plot_srcc_MAE(m_avg[:, i], transfer_m_avg[:, i], m_std[:, i], transfer_m_std[:, i], axis, f'{i+1}', main_folder, tag2='mae', tag3 = 'avg')
        else:
            for i in range(5):
                plot_srcc_MAE_2(s_avg[:, i], s_std[:, i], axis, f'{i+1}', main_folder, tag2='srcc', tag3 = 'avg')
                plot_srcc_MAE_2(m_avg[:, i], m_std[:, i], axis, f'{i+1}', main_folder, tag2='mae', tag3 = 'avg')

    # Save results to YAML file
    with open(f"{main_folder}/results.yaml", "w") as file:
        yaml.dump(results, file)

In [None]:
# Experiment batch definition. Hyperparameters are changed in
# config_finetune.yaml; the parallel lists below only say which experiments to
# launch: entry k of every list describes experiment k.

task = ['cond']   # 'cond', 'visc'
tune_from = ['results/visc/ANN_scratch/test_1/Scratch_0.7']  # use the entire path to the folder containing the model. If not using one: ' '
runs = [2]
targets = [7]
main_folder = ['ANN_scratch_2'] # only need the name of the folder 
transfer = [True]


# One run() call per entry of main_folder; outputs go to results/<task>/<folder>.
for i, folder_name in enumerate(main_folder):
    run(task[i], f'results/{task[i]}/{folder_name}', tune_from[i], runs[i], targets[i], transfer[i])