In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
from sklearn.preprocessing import PolynomialFeatures
import MultiTaskDataset

In [2]:
class IAModel(nn.Module):
    def __init__(self, input_dimension=7, output_dimension=20):
        super(IAModel, self).__init__()
        
        self.input_dimension = input_dimension
        self.output_dimension = output_dimension
        
        self.network = nn.Sequential(
            nn.Linear(input_dimension, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, output_dimension)
        )

    def forward(self, x):
        x = self.network(x)
        return x

In [3]:
input_data = np.load('combined_inputs_constant_bolplanck.npy').astype(np.float32)
output_data = np.load('combined_outputs_constant_bolplanck.npy').astype(np.float32)

nan_indices = np.argwhere(np.isnan(output_data))
row_indices_with_nan = set(nan_indices[:, 0])

input_cleaned = np.delete(input_data, list(row_indices_with_nan), axis=0)
output_cleaned = np.delete(output_data, list(row_indices_with_nan), axis=0)

train_int, dev_int, train_out, dev_out = train_test_split(input_cleaned, output_cleaned, 
                                                          test_size=0.2, random_state=42)

train_out1 = train_out[:, 0, :]
train_out2 = train_out[:, 1, :]
train_out3 = train_out[:, 2, :]

dev_out1 = dev_out[:, 0, :]
dev_out2 = dev_out[:, 1, :]
dev_out3 = dev_out[:, 2, :]

In [4]:
train_dataset = MultiTaskDataset.MultiTaskDataset(train_int, train_out1, train_out2, train_out3)
train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True)

dev_dataset = MultiTaskDataset.MultiTaskDataset(dev_int, dev_out1, dev_out2, dev_out3)
dev_dataloader = DataLoader(dev_dataset)

In [10]:
def AAPD(predictions, target):
    absolute_percentage_difference = torch.abs(predictions - target) / (torch.abs(target) + 1e-14)
    aapd = torch.mean(absolute_percentage_difference)
    return aapd

In [17]:
def ASPD(predictions, target):
    absolute_square_percent_diff = (predictions - target)**2 / (torch.abs(target) + 1e-14)
    aspd = torch.mean(absolute_square_percent_diff)
    return aspd

In [26]:
def train_model(model, model_name, train_loader, dev_loader, target_idx=1, num_epochs=1000, learning_rate=0.1, weight_decay=1e-4):
    #criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[250, 500, 750, 1000], gamma=0.1)
    
    best_model = None
    best_aapd = None

    for epoch in range(num_epochs):
        model.train()
        for batch in train_loader:
            input_data = batch['input_data']
            target1 = batch['out1']
            target2 = batch['out2']
            target3 = batch['out3']
            
            optimizer.zero_grad()
            outputs = model(input_data)
            if target_idx == 1:
                train_loss = AAPD(outputs, target1)
            elif target_idx == 2:
                train_loss = AAPD(outputs, target2)
            elif target_idx == 3:
                train_loss = AAPD(outputs, target3)
            else:
                print("Invalid target_idx")
            
            train_loss.backward()
            optimizer.step()
        scheduler.step()
        
        model.eval()
        
        total_aapd = 0
        num_samples = 0
        
        with torch.no_grad():
            for batch in dev_loader:
                dev_int = batch['input_data']
                dev_out1 = batch['out1']
                dev_out2 = batch['out2']
                dev_out3 = batch['out3']
                
                predictions = model(dev_int)
                
                if target_idx == 1:
                    batch_aapd = AAPD(predictions, dev_out1)
                elif target_idx == 2:
                    batch_aapd = AAPD(predictions, dev_out2)
                elif target_idx == 3:
                    batch_aapd = AAPD(predictions, dev_out3)
                else:
                    print("Invalid target_idx")
                    
                total_aapd += batch_aapd * len(dev_int)
                num_samples += len(dev_int)
            
            aapd = total_aapd / num_samples
            
            if best_aapd == None or aapd < best_aapd:
                best_aapd = aapd
                best_model = model
                torch.save(model.state_dict(), model_name)
        
        print(f'Epoch [{epoch + 1}/{num_epochs}] - Loss: {train_loss.item():.4f}; AAPD: {aapd}')

    return best_model

In [21]:
model = IAModel(input_dimension=train_int.shape[1], output_dimension=train_out1.shape[1])

In [16]:
best_model_AAPD_1 = train_model(model, 'best_model_AAPD_1.pth', train_dataloader, dev_dataloader, target_idx=1)

Epoch [1/1000] - Loss: 0.6439; AAPD: 0.5957490801811218
Epoch [2/1000] - Loss: 0.7581; AAPD: 0.7242905497550964
Epoch [3/1000] - Loss: 0.7563; AAPD: 0.7147077918052673
Epoch [4/1000] - Loss: 0.7126; AAPD: 0.7122555375099182
Epoch [5/1000] - Loss: 0.7259; AAPD: 0.713603675365448
Epoch [6/1000] - Loss: 0.6902; AAPD: 0.715836226940155
Epoch [7/1000] - Loss: 0.6884; AAPD: 0.711166501045227
Epoch [8/1000] - Loss: 0.6879; AAPD: 0.713876485824585
Epoch [9/1000] - Loss: 0.7798; AAPD: 0.7130299210548401
Epoch [10/1000] - Loss: 0.7062; AAPD: 0.7130486369132996
Epoch [11/1000] - Loss: 0.7306; AAPD: 0.7181627750396729
Epoch [12/1000] - Loss: 0.6964; AAPD: 0.712512195110321
Epoch [13/1000] - Loss: 0.6961; AAPD: 0.7129423022270203
Epoch [14/1000] - Loss: 0.6845; AAPD: 0.7126301527023315
Epoch [15/1000] - Loss: 0.6926; AAPD: 0.7173328399658203
Epoch [16/1000] - Loss: 0.8295; AAPD: 0.7132474780082703
Epoch [17/1000] - Loss: 0.7593; AAPD: 0.7127993106842041
Epoch [18/1000] - Loss: 0.6044; AAPD: 0.71299

Epoch [145/1000] - Loss: 0.7353; AAPD: 0.7124621272087097
Epoch [146/1000] - Loss: 0.6867; AAPD: 0.7171221971511841
Epoch [147/1000] - Loss: 0.7292; AAPD: 0.7149987816810608
Epoch [148/1000] - Loss: 0.7568; AAPD: 0.7141122221946716
Epoch [149/1000] - Loss: 0.6963; AAPD: 0.7129796743392944
Epoch [150/1000] - Loss: 0.7611; AAPD: 0.7125725746154785
Epoch [151/1000] - Loss: 0.6200; AAPD: 0.7145295143127441
Epoch [152/1000] - Loss: 0.6924; AAPD: 0.7134584784507751
Epoch [153/1000] - Loss: 0.6948; AAPD: 0.7128730416297913
Epoch [154/1000] - Loss: 0.7646; AAPD: 0.7119861245155334
Epoch [155/1000] - Loss: 0.6019; AAPD: 0.7132304906845093
Epoch [156/1000] - Loss: 0.7329; AAPD: 0.7142159342765808
Epoch [157/1000] - Loss: 0.6364; AAPD: 0.7137237787246704
Epoch [158/1000] - Loss: 0.8298; AAPD: 0.7161356210708618
Epoch [159/1000] - Loss: 0.7877; AAPD: 0.7135691046714783
Epoch [160/1000] - Loss: 0.6672; AAPD: 0.7134929895401001
Epoch [161/1000] - Loss: 0.6431; AAPD: 0.7136338949203491
Epoch [162/100

Epoch [287/1000] - Loss: 0.7271; AAPD: 0.6944892406463623
Epoch [288/1000] - Loss: 0.7186; AAPD: 0.6886288523674011
Epoch [289/1000] - Loss: 0.4954; AAPD: 0.5610947608947754
Epoch [290/1000] - Loss: 0.4652; AAPD: 0.5209718346595764
Epoch [291/1000] - Loss: 0.5995; AAPD: 0.525632917881012
Epoch [292/1000] - Loss: 0.4499; AAPD: 0.5095353722572327
Epoch [293/1000] - Loss: 0.4680; AAPD: 0.535639762878418
Epoch [294/1000] - Loss: 0.3243; AAPD: 0.442961186170578
Epoch [295/1000] - Loss: 0.3917; AAPD: 0.4326959550380707
Epoch [296/1000] - Loss: 0.3802; AAPD: 0.44109681248664856
Epoch [297/1000] - Loss: 0.4209; AAPD: 0.4287852346897125
Epoch [298/1000] - Loss: 0.3767; AAPD: 0.4067024290561676
Epoch [299/1000] - Loss: 0.3224; AAPD: 0.40758371353149414
Epoch [300/1000] - Loss: 0.3891; AAPD: 0.3898580074310303
Epoch [301/1000] - Loss: 0.4192; AAPD: 0.3751072287559509
Epoch [302/1000] - Loss: 0.3638; AAPD: 0.37198731303215027
Epoch [303/1000] - Loss: 0.3636; AAPD: 0.37861648201942444
Epoch [304/10

Epoch [428/1000] - Loss: 0.3191; AAPD: 0.340279221534729
Epoch [429/1000] - Loss: 0.3047; AAPD: 0.3272031247615814
Epoch [430/1000] - Loss: 0.2991; AAPD: 0.35929638147354126
Epoch [431/1000] - Loss: 0.3762; AAPD: 0.35973626375198364
Epoch [432/1000] - Loss: 0.5149; AAPD: 0.3653222620487213
Epoch [433/1000] - Loss: 0.2698; AAPD: 0.33466094732284546
Epoch [434/1000] - Loss: 0.2356; AAPD: 0.3386039733886719
Epoch [435/1000] - Loss: 0.3738; AAPD: 0.37257319688796997
Epoch [436/1000] - Loss: 0.3015; AAPD: 0.33928412199020386
Epoch [437/1000] - Loss: 0.2971; AAPD: 0.3503713011741638
Epoch [438/1000] - Loss: 0.2853; AAPD: 0.3315029442310333
Epoch [439/1000] - Loss: 0.3995; AAPD: 0.3525925278663635
Epoch [440/1000] - Loss: 0.4051; AAPD: 0.3611142337322235
Epoch [441/1000] - Loss: 0.2839; AAPD: 0.34945055842399597
Epoch [442/1000] - Loss: 0.2446; AAPD: 0.35655128955841064
Epoch [443/1000] - Loss: 0.3635; AAPD: 0.33041855692863464
Epoch [444/1000] - Loss: 0.3766; AAPD: 0.3745799958705902
Epoch [

Epoch [569/1000] - Loss: 0.2442; AAPD: 0.2691028118133545
Epoch [570/1000] - Loss: 0.2900; AAPD: 0.27200567722320557
Epoch [571/1000] - Loss: 0.2567; AAPD: 0.27080363035202026
Epoch [572/1000] - Loss: 0.2563; AAPD: 0.26958298683166504
Epoch [573/1000] - Loss: 0.2691; AAPD: 0.27239784598350525
Epoch [574/1000] - Loss: 0.2640; AAPD: 0.2671128511428833
Epoch [575/1000] - Loss: 0.2735; AAPD: 0.27083638310432434
Epoch [576/1000] - Loss: 0.2895; AAPD: 0.27308833599090576
Epoch [577/1000] - Loss: 0.2683; AAPD: 0.27396160364151
Epoch [578/1000] - Loss: 0.2777; AAPD: 0.2747892737388611
Epoch [579/1000] - Loss: 0.2528; AAPD: 0.2753540277481079
Epoch [580/1000] - Loss: 0.2745; AAPD: 0.27340927720069885
Epoch [581/1000] - Loss: 0.3248; AAPD: 0.27043595910072327
Epoch [582/1000] - Loss: 0.2711; AAPD: 0.26622793078422546
Epoch [583/1000] - Loss: 0.2697; AAPD: 0.26977434754371643
Epoch [584/1000] - Loss: 0.2117; AAPD: 0.2751997411251068
Epoch [585/1000] - Loss: 0.2748; AAPD: 0.2709735035896301
Epoch 

Epoch [710/1000] - Loss: 0.2036; AAPD: 0.25291526317596436
Epoch [711/1000] - Loss: 0.2538; AAPD: 0.24914973974227905
Epoch [712/1000] - Loss: 0.2142; AAPD: 0.2558137774467468
Epoch [713/1000] - Loss: 0.2502; AAPD: 0.25460106134414673
Epoch [714/1000] - Loss: 0.3116; AAPD: 0.24460434913635254
Epoch [715/1000] - Loss: 0.2519; AAPD: 0.25427666306495667
Epoch [716/1000] - Loss: 0.3014; AAPD: 0.2569279074668884
Epoch [717/1000] - Loss: 0.2975; AAPD: 0.2519901394844055
Epoch [718/1000] - Loss: 0.3186; AAPD: 0.2482420802116394
Epoch [719/1000] - Loss: 0.2556; AAPD: 0.26263466477394104
Epoch [720/1000] - Loss: 0.3969; AAPD: 0.26049306988716125
Epoch [721/1000] - Loss: 0.1980; AAPD: 0.24984680116176605
Epoch [722/1000] - Loss: 0.2605; AAPD: 0.24979093670845032
Epoch [723/1000] - Loss: 0.1864; AAPD: 0.25469815731048584
Epoch [724/1000] - Loss: 0.3620; AAPD: 0.2645724415779114
Epoch [725/1000] - Loss: 0.1695; AAPD: 0.2473517507314682
Epoch [726/1000] - Loss: 0.2902; AAPD: 0.27446603775024414
Epo

Epoch [850/1000] - Loss: 0.2686; AAPD: 0.22733527421951294
Epoch [851/1000] - Loss: 0.2503; AAPD: 0.22453050315380096
Epoch [852/1000] - Loss: 0.2358; AAPD: 0.2251737415790558
Epoch [853/1000] - Loss: 0.1820; AAPD: 0.22180120646953583
Epoch [854/1000] - Loss: 0.1870; AAPD: 0.22135886549949646
Epoch [855/1000] - Loss: 0.1955; AAPD: 0.22085334360599518
Epoch [856/1000] - Loss: 0.1788; AAPD: 0.22232715785503387
Epoch [857/1000] - Loss: 0.2621; AAPD: 0.22664344310760498
Epoch [858/1000] - Loss: 0.2187; AAPD: 0.22001786530017853
Epoch [859/1000] - Loss: 0.2467; AAPD: 0.21983006596565247
Epoch [860/1000] - Loss: 0.1486; AAPD: 0.22075743973255157
Epoch [861/1000] - Loss: 0.1782; AAPD: 0.22030945122241974
Epoch [862/1000] - Loss: 0.2224; AAPD: 0.2200455367565155
Epoch [863/1000] - Loss: 0.2340; AAPD: 0.21916760504245758
Epoch [864/1000] - Loss: 0.1814; AAPD: 0.2275509536266327
Epoch [865/1000] - Loss: 0.1478; AAPD: 0.2189544290304184
Epoch [866/1000] - Loss: 0.2318; AAPD: 0.21879512071609497
E

Epoch [990/1000] - Loss: 0.2146; AAPD: 0.20429757237434387
Epoch [991/1000] - Loss: 0.1469; AAPD: 0.2060972899198532
Epoch [992/1000] - Loss: 0.1134; AAPD: 0.20442664623260498
Epoch [993/1000] - Loss: 0.2014; AAPD: 0.20406098663806915
Epoch [994/1000] - Loss: 0.1828; AAPD: 0.20577938854694366
Epoch [995/1000] - Loss: 0.2272; AAPD: 0.2042919546365738
Epoch [996/1000] - Loss: 0.2429; AAPD: 0.20403438806533813
Epoch [997/1000] - Loss: 0.1469; AAPD: 0.2043539583683014
Epoch [998/1000] - Loss: 0.2148; AAPD: 0.2042607069015503
Epoch [999/1000] - Loss: 0.1901; AAPD: 0.20400193333625793
Epoch [1000/1000] - Loss: 0.2228; AAPD: 0.20434916019439697


In [20]:
best_model_ASPD_1 = train_model(model, 'best_model_ASPD_1.pth', train_dataloader, dev_dataloader, target_idx=1)

Epoch [1/1000] - Loss: 573.8908; AAPD: 0.9449365139007568
Epoch [2/1000] - Loss: 417.5602; AAPD: 0.9206342101097107
Epoch [3/1000] - Loss: 1125.0707; AAPD: 1.0001150369644165
Epoch [4/1000] - Loss: 1063.8912; AAPD: 0.9649357795715332
Epoch [5/1000] - Loss: 2363.7888; AAPD: 0.9730366468429565
Epoch [6/1000] - Loss: 1190.8456; AAPD: 1.0032057762145996
Epoch [7/1000] - Loss: 646.9072; AAPD: 0.9706142544746399
Epoch [8/1000] - Loss: 1641.1843; AAPD: 0.9419130682945251
Epoch [9/1000] - Loss: 642.6738; AAPD: 0.9929583668708801
Epoch [10/1000] - Loss: 525.2616; AAPD: 0.7864018082618713
Epoch [11/1000] - Loss: 1158.2382; AAPD: 0.8987759351730347
Epoch [12/1000] - Loss: 447.8716; AAPD: 0.9245490431785583
Epoch [13/1000] - Loss: 621.5920; AAPD: 0.9395466446876526
Epoch [14/1000] - Loss: 2186.9612; AAPD: 0.9301175475120544
Epoch [15/1000] - Loss: 801.0953; AAPD: 0.937135636806488
Epoch [16/1000] - Loss: 1183.4219; AAPD: 0.9398585557937622
Epoch [17/1000] - Loss: 1215.4321; AAPD: 0.936351716518402

Epoch [139/1000] - Loss: 1252.7141; AAPD: 0.9425013661384583
Epoch [140/1000] - Loss: 1139.0380; AAPD: 0.941634476184845
Epoch [141/1000] - Loss: 1177.0667; AAPD: 0.9327710270881653
Epoch [142/1000] - Loss: 1559.9203; AAPD: 0.9354181289672852
Epoch [143/1000] - Loss: 1081.7380; AAPD: 0.9240027070045471
Epoch [144/1000] - Loss: 2506.8394; AAPD: 0.9401810169219971
Epoch [145/1000] - Loss: 1050.3181; AAPD: 0.9426705837249756
Epoch [146/1000] - Loss: 1609.6752; AAPD: 0.9875584244728088
Epoch [147/1000] - Loss: 1302.2727; AAPD: 0.9387228488922119
Epoch [148/1000] - Loss: 976.0226; AAPD: 2.5785810947418213
Epoch [149/1000] - Loss: 1492.8691; AAPD: 1.298581600189209
Epoch [150/1000] - Loss: 1500.4004; AAPD: 0.9888836741447449
Epoch [151/1000] - Loss: 501.2108; AAPD: 0.9391704201698303
Epoch [152/1000] - Loss: 1497.8118; AAPD: 2.6828017234802246
Epoch [153/1000] - Loss: 530.2473; AAPD: 2.075925827026367
Epoch [154/1000] - Loss: 1573.8691; AAPD: 1.4659210443496704
Epoch [155/1000] - Loss: 853.2

Epoch [275/1000] - Loss: 418.3195; AAPD: 0.4089542031288147
Epoch [276/1000] - Loss: 414.4045; AAPD: 0.4780025780200958
Epoch [277/1000] - Loss: 153.6325; AAPD: 0.4031071364879608
Epoch [278/1000] - Loss: 282.3789; AAPD: 0.43281933665275574
Epoch [279/1000] - Loss: 527.9802; AAPD: 0.621480405330658
Epoch [280/1000] - Loss: 171.3996; AAPD: 0.45738866925239563
Epoch [281/1000] - Loss: 416.2550; AAPD: 0.5450253486633301
Epoch [282/1000] - Loss: 292.3776; AAPD: 0.45281270146369934
Epoch [283/1000] - Loss: 143.9226; AAPD: 0.4154397249221802
Epoch [284/1000] - Loss: 541.7142; AAPD: 0.43265581130981445
Epoch [285/1000] - Loss: 90.7693; AAPD: 0.4404200613498688
Epoch [286/1000] - Loss: 253.9487; AAPD: 0.47221848368644714
Epoch [287/1000] - Loss: 233.3679; AAPD: 0.4062744677066803
Epoch [288/1000] - Loss: 321.9010; AAPD: 0.4414040148258209
Epoch [289/1000] - Loss: 324.5577; AAPD: 0.40070098638534546
Epoch [290/1000] - Loss: 123.3899; AAPD: 0.4573631286621094
Epoch [291/1000] - Loss: 809.4070; A

Epoch [412/1000] - Loss: 36.3557; AAPD: 0.25128844380378723
Epoch [413/1000] - Loss: 47.8699; AAPD: 0.22807809710502625
Epoch [414/1000] - Loss: 22.8512; AAPD: 0.2632335424423218
Epoch [415/1000] - Loss: 16.8518; AAPD: 0.29613637924194336
Epoch [416/1000] - Loss: 77.7866; AAPD: 0.25579673051834106
Epoch [417/1000] - Loss: 67.3843; AAPD: 0.21854020655155182
Epoch [418/1000] - Loss: 29.7685; AAPD: 0.4453505873680115
Epoch [419/1000] - Loss: 25.2855; AAPD: 0.21748583018779755
Epoch [420/1000] - Loss: 95.9432; AAPD: 0.2771376371383667
Epoch [421/1000] - Loss: 27.3143; AAPD: 0.256200909614563
Epoch [422/1000] - Loss: 36.2803; AAPD: 0.2800484001636505
Epoch [423/1000] - Loss: 23.8753; AAPD: 0.22333109378814697
Epoch [424/1000] - Loss: 54.4244; AAPD: 0.2979651689529419
Epoch [425/1000] - Loss: 19.2792; AAPD: 0.28185874223709106
Epoch [426/1000] - Loss: 104.7493; AAPD: 0.2200988531112671
Epoch [427/1000] - Loss: 10.3818; AAPD: 0.21352124214172363
Epoch [428/1000] - Loss: 52.4915; AAPD: 0.24868

Epoch [550/1000] - Loss: 334.6929; AAPD: 0.15041792392730713
Epoch [551/1000] - Loss: 12.1714; AAPD: 0.14362160861492157
Epoch [552/1000] - Loss: 119.8698; AAPD: 0.15214556455612183
Epoch [553/1000] - Loss: 238.7825; AAPD: 0.140186607837677
Epoch [554/1000] - Loss: 5.3717; AAPD: 0.14621825516223907
Epoch [555/1000] - Loss: 75.3325; AAPD: 0.14755606651306152
Epoch [556/1000] - Loss: 31.1982; AAPD: 0.14358322322368622
Epoch [557/1000] - Loss: 28.1740; AAPD: 0.14944927394390106
Epoch [558/1000] - Loss: 54.2109; AAPD: 0.17156805098056793
Epoch [559/1000] - Loss: 8.0990; AAPD: 0.15405115485191345
Epoch [560/1000] - Loss: 15.5104; AAPD: 0.14105722308158875
Epoch [561/1000] - Loss: 3.4602; AAPD: 0.14685603976249695
Epoch [562/1000] - Loss: 93.8721; AAPD: 0.1475042849779129
Epoch [563/1000] - Loss: 8.4096; AAPD: 0.15658298134803772
Epoch [564/1000] - Loss: 14.6867; AAPD: 0.14091627299785614
Epoch [565/1000] - Loss: 24.7481; AAPD: 0.14415021240711212
Epoch [566/1000] - Loss: 69.1679; AAPD: 0.21

Epoch [688/1000] - Loss: 221.6914; AAPD: 0.15619274973869324
Epoch [689/1000] - Loss: 53.8018; AAPD: 0.12389467656612396
Epoch [690/1000] - Loss: 24.8951; AAPD: 0.1369030773639679
Epoch [691/1000] - Loss: 15.8022; AAPD: 0.12493988871574402
Epoch [692/1000] - Loss: 232.2460; AAPD: 0.14010074734687805
Epoch [693/1000] - Loss: 44.8828; AAPD: 0.13037489354610443
Epoch [694/1000] - Loss: 237.9876; AAPD: 0.1262294501066208
Epoch [695/1000] - Loss: 351.0884; AAPD: 0.13839878141880035
Epoch [696/1000] - Loss: 129.9367; AAPD: 0.13439354300498962
Epoch [697/1000] - Loss: 12.7785; AAPD: 0.15140780806541443
Epoch [698/1000] - Loss: 165.0451; AAPD: 0.13127131760120392
Epoch [699/1000] - Loss: 11.0787; AAPD: 0.12702488899230957
Epoch [700/1000] - Loss: 57.2960; AAPD: 0.13613395392894745
Epoch [701/1000] - Loss: 19.7288; AAPD: 0.12356620281934738
Epoch [702/1000] - Loss: 5.1859; AAPD: 0.12287580966949463
Epoch [703/1000] - Loss: 39.8060; AAPD: 0.13043293356895447
Epoch [704/1000] - Loss: 71.2343; AAP

Epoch [826/1000] - Loss: 25.6476; AAPD: 0.11071790009737015
Epoch [827/1000] - Loss: 25.1741; AAPD: 0.11611944437026978
Epoch [828/1000] - Loss: 11.9985; AAPD: 0.10883044451475143
Epoch [829/1000] - Loss: 53.1634; AAPD: 0.10569462180137634
Epoch [830/1000] - Loss: 5.6684; AAPD: 0.1069478988647461
Epoch [831/1000] - Loss: 9.4628; AAPD: 0.12704575061798096
Epoch [832/1000] - Loss: 358.8984; AAPD: 0.10540913790464401
Epoch [833/1000] - Loss: 138.9644; AAPD: 0.10644551366567612
Epoch [834/1000] - Loss: 39.8028; AAPD: 0.10574047267436981
Epoch [835/1000] - Loss: 2.4120; AAPD: 0.10514923185110092
Epoch [836/1000] - Loss: 59.3978; AAPD: 0.11142526566982269
Epoch [837/1000] - Loss: 22.8803; AAPD: 0.1079995334148407
Epoch [838/1000] - Loss: 49.3207; AAPD: 0.10581357032060623
Epoch [839/1000] - Loss: 408.5001; AAPD: 0.10658161342144012
Epoch [840/1000] - Loss: 213.8301; AAPD: 0.10745976120233536
Epoch [841/1000] - Loss: 42.2371; AAPD: 0.10510118305683136
Epoch [842/1000] - Loss: 10.7965; AAPD: 0

Epoch [964/1000] - Loss: 56.5964; AAPD: 0.0997353121638298
Epoch [965/1000] - Loss: 4.2908; AAPD: 0.1023179143667221
Epoch [966/1000] - Loss: 64.9747; AAPD: 0.10077362507581711
Epoch [967/1000] - Loss: 185.1535; AAPD: 0.10069463402032852
Epoch [968/1000] - Loss: 30.2905; AAPD: 0.10040771216154099
Epoch [969/1000] - Loss: 189.2021; AAPD: 0.10273412615060806
Epoch [970/1000] - Loss: 434.7903; AAPD: 0.10200148820877075
Epoch [971/1000] - Loss: 163.4500; AAPD: 0.09955786913633347
Epoch [972/1000] - Loss: 31.5542; AAPD: 0.10235442966222763
Epoch [973/1000] - Loss: 10.8485; AAPD: 0.09998512268066406
Epoch [974/1000] - Loss: 8.1609; AAPD: 0.10241872817277908
Epoch [975/1000] - Loss: 203.8310; AAPD: 0.09878663718700409
Epoch [976/1000] - Loss: 6.6153; AAPD: 0.0985875278711319
Epoch [977/1000] - Loss: 23.0666; AAPD: 0.09933599084615707
Epoch [978/1000] - Loss: 42.6909; AAPD: 0.10373964160680771
Epoch [979/1000] - Loss: 7.1787; AAPD: 0.10016288608312607
Epoch [980/1000] - Loss: 2.0419; AAPD: 0.0

In [22]:
best_model_ASPD_2 = train_model(model, 'best_model_ASPD_2.pth', train_dataloader, dev_dataloader, target_idx=2)

Epoch [1/1000] - Loss: 0.0269; AAPD: 1.8445143699645996
Epoch [2/1000] - Loss: 0.0384; AAPD: 1.0259267091751099
Epoch [3/1000] - Loss: 0.0215; AAPD: 1.0231976509094238
Epoch [4/1000] - Loss: 0.0247; AAPD: 1.0220545530319214
Epoch [5/1000] - Loss: 0.0310; AAPD: 1.055550217628479
Epoch [6/1000] - Loss: 0.0273; AAPD: 1.0122548341751099
Epoch [7/1000] - Loss: 0.0232; AAPD: 1.0286167860031128
Epoch [8/1000] - Loss: 0.0326; AAPD: 1.0622284412384033
Epoch [9/1000] - Loss: 0.0259; AAPD: 1.047387957572937
Epoch [10/1000] - Loss: 0.0302; AAPD: 1.0389622449874878
Epoch [11/1000] - Loss: 0.0246; AAPD: 1.046159267425537
Epoch [12/1000] - Loss: 47.9113; AAPD: 147.63943481445312
Epoch [13/1000] - Loss: 0.0295; AAPD: 1.636544942855835
Epoch [14/1000] - Loss: 0.0327; AAPD: 1.1026802062988281
Epoch [15/1000] - Loss: 0.0221; AAPD: 1.3774909973144531
Epoch [16/1000] - Loss: 0.0267; AAPD: 1.0730409622192383
Epoch [17/1000] - Loss: 0.0245; AAPD: 1.1216706037521362
Epoch [18/1000] - Loss: 0.0239; AAPD: 1.065

Epoch [144/1000] - Loss: 2.9848; AAPD: 28.39226531982422
Epoch [145/1000] - Loss: 98.4780; AAPD: 214.51206970214844
Epoch [146/1000] - Loss: 41.2745; AAPD: 52.92497634887695
Epoch [147/1000] - Loss: 22.5647; AAPD: 44.2236328125
Epoch [148/1000] - Loss: 19.2943; AAPD: 44.17055892944336
Epoch [149/1000] - Loss: 12.9329; AAPD: 34.85716247558594
Epoch [150/1000] - Loss: 7.8103; AAPD: 28.255016326904297
Epoch [151/1000] - Loss: 6.6604; AAPD: 21.815073013305664
Epoch [152/1000] - Loss: 3.1703; AAPD: 13.669536590576172
Epoch [153/1000] - Loss: 2.4004; AAPD: 8.411958694458008
Epoch [154/1000] - Loss: 0.5544; AAPD: 29.923187255859375
Epoch [155/1000] - Loss: 0.0380; AAPD: 1.6427544355392456
Epoch [156/1000] - Loss: 0.0341; AAPD: 1.0556460618972778
Epoch [157/1000] - Loss: 12.6046; AAPD: 77.06525421142578
Epoch [158/1000] - Loss: 0.0305; AAPD: 1.3441489934921265
Epoch [159/1000] - Loss: 0.0388; AAPD: 1.3058295249938965
Epoch [160/1000] - Loss: 0.0805; AAPD: 1.4441066980361938
Epoch [161/1000] - 

Epoch [285/1000] - Loss: 3.2739; AAPD: 4.2267608642578125
Epoch [286/1000] - Loss: 0.0734; AAPD: 3.2224199771881104
Epoch [287/1000] - Loss: 0.1403; AAPD: 3.5689361095428467
Epoch [288/1000] - Loss: 0.7401; AAPD: 3.9112601280212402
Epoch [289/1000] - Loss: 0.0861; AAPD: 3.2891623973846436
Epoch [290/1000] - Loss: 0.0615; AAPD: 2.9823319911956787
Epoch [291/1000] - Loss: 0.0976; AAPD: 3.7989871501922607
Epoch [292/1000] - Loss: 0.5834; AAPD: 6.27363920211792
Epoch [293/1000] - Loss: 0.1027; AAPD: 6.895804405212402
Epoch [294/1000] - Loss: 0.2755; AAPD: 4.354248046875
Epoch [295/1000] - Loss: 0.2055; AAPD: 3.8108644485473633
Epoch [296/1000] - Loss: 0.0894; AAPD: 3.626831531524658
Epoch [297/1000] - Loss: 0.2281; AAPD: 4.204985618591309
Epoch [298/1000] - Loss: 5.1030; AAPD: 44.15367889404297
Epoch [299/1000] - Loss: 0.1152; AAPD: 3.949613571166992
Epoch [300/1000] - Loss: 0.3427; AAPD: 4.173369407653809
Epoch [301/1000] - Loss: 0.2363; AAPD: 3.1962335109710693
Epoch [302/1000] - Loss: 0

Epoch [428/1000] - Loss: 0.0585; AAPD: 1.4891749620437622
Epoch [429/1000] - Loss: 0.0338; AAPD: 1.9391980171203613
Epoch [430/1000] - Loss: 0.0268; AAPD: 2.4237873554229736
Epoch [431/1000] - Loss: 0.0513; AAPD: 2.9016213417053223
Epoch [432/1000] - Loss: 0.0245; AAPD: 1.731361985206604
Epoch [433/1000] - Loss: 0.0249; AAPD: 1.2544125318527222
Epoch [434/1000] - Loss: 0.0297; AAPD: 1.222542405128479
Epoch [435/1000] - Loss: 0.0369; AAPD: 1.2470403909683228
Epoch [436/1000] - Loss: 0.0474; AAPD: 1.3641103506088257
Epoch [437/1000] - Loss: 0.0242; AAPD: 1.314665675163269
Epoch [438/1000] - Loss: 0.0290; AAPD: 1.4412113428115845
Epoch [439/1000] - Loss: 0.0347; AAPD: 5.185628414154053
Epoch [440/1000] - Loss: 0.0258; AAPD: 6.761176109313965
Epoch [441/1000] - Loss: 0.0417; AAPD: 3.7189409732818604
Epoch [442/1000] - Loss: 0.0338; AAPD: 2.4778406620025635
Epoch [443/1000] - Loss: 0.0239; AAPD: 1.177046537399292
Epoch [444/1000] - Loss: 0.0221; AAPD: 1.2453793287277222
Epoch [445/1000] - L

Epoch [571/1000] - Loss: 0.0219; AAPD: 3.138737678527832
Epoch [572/1000] - Loss: 0.0385; AAPD: 1.7945754528045654
Epoch [573/1000] - Loss: 0.0242; AAPD: 1.8845638036727905
Epoch [574/1000] - Loss: 0.0211; AAPD: 1.8625136613845825
Epoch [575/1000] - Loss: 0.0253; AAPD: 1.6804085969924927
Epoch [576/1000] - Loss: 0.0393; AAPD: 2.2337071895599365
Epoch [577/1000] - Loss: 0.0396; AAPD: 6.112628936767578
Epoch [578/1000] - Loss: 0.0943; AAPD: 4.635527610778809
Epoch [579/1000] - Loss: 0.0261; AAPD: 2.2860336303710938
Epoch [580/1000] - Loss: 0.0215; AAPD: 2.8999264240264893
Epoch [581/1000] - Loss: 0.0289; AAPD: 2.573925256729126
Epoch [582/1000] - Loss: 0.0415; AAPD: 2.768343210220337
Epoch [583/1000] - Loss: 0.0340; AAPD: 1.6146458387374878
Epoch [584/1000] - Loss: 0.0359; AAPD: 2.7595653533935547
Epoch [585/1000] - Loss: 0.0592; AAPD: 4.896533012390137
Epoch [586/1000] - Loss: 0.0458; AAPD: 1.7931604385375977
Epoch [587/1000] - Loss: 0.0401; AAPD: 2.3361852169036865
Epoch [588/1000] - L

Epoch [714/1000] - Loss: 0.0330; AAPD: 2.39611554145813
Epoch [715/1000] - Loss: 0.0324; AAPD: 1.2475428581237793
Epoch [716/1000] - Loss: 0.0185; AAPD: 1.3320984840393066
Epoch [717/1000] - Loss: 0.0425; AAPD: 3.8233659267425537
Epoch [718/1000] - Loss: 0.0318; AAPD: 1.9638057947158813
Epoch [719/1000] - Loss: 0.0308; AAPD: 3.3936259746551514
Epoch [720/1000] - Loss: 0.1031; AAPD: 4.453928470611572
Epoch [721/1000] - Loss: 0.0259; AAPD: 2.0932748317718506
Epoch [722/1000] - Loss: 0.0332; AAPD: 1.720862865447998
Epoch [723/1000] - Loss: 0.0458; AAPD: 2.2208383083343506
Epoch [724/1000] - Loss: 0.0659; AAPD: 3.6695847511291504
Epoch [725/1000] - Loss: 0.0258; AAPD: 1.7053277492523193
Epoch [726/1000] - Loss: 0.0272; AAPD: 3.137467384338379
Epoch [727/1000] - Loss: 0.0230; AAPD: 1.7384706735610962
Epoch [728/1000] - Loss: 0.0236; AAPD: 2.072011709213257
Epoch [729/1000] - Loss: 0.0228; AAPD: 1.5637273788452148
Epoch [730/1000] - Loss: 0.0359; AAPD: 2.076664924621582
Epoch [731/1000] - Lo

Epoch [857/1000] - Loss: 0.0255; AAPD: 1.7193093299865723
Epoch [858/1000] - Loss: 0.0279; AAPD: 1.4252674579620361
Epoch [859/1000] - Loss: 0.0452; AAPD: 3.078214406967163
Epoch [860/1000] - Loss: 0.0307; AAPD: 1.955788254737854
Epoch [861/1000] - Loss: 0.0259; AAPD: 3.2293224334716797
Epoch [862/1000] - Loss: 0.0495; AAPD: 2.8064510822296143
Epoch [863/1000] - Loss: 0.0318; AAPD: 1.9632402658462524
Epoch [864/1000] - Loss: 0.0320; AAPD: 2.112816095352173
Epoch [865/1000] - Loss: 0.0304; AAPD: 1.4587979316711426
Epoch [866/1000] - Loss: 0.0223; AAPD: 1.5116599798202515
Epoch [867/1000] - Loss: 0.0286; AAPD: 1.549451470375061
Epoch [868/1000] - Loss: 0.0260; AAPD: 1.8870625495910645
Epoch [869/1000] - Loss: 0.0390; AAPD: 2.5679397583007812
Epoch [870/1000] - Loss: 0.0564; AAPD: 2.0088818073272705
Epoch [871/1000] - Loss: 0.0396; AAPD: 2.2432003021240234
Epoch [872/1000] - Loss: 0.0246; AAPD: 1.7344051599502563
Epoch [873/1000] - Loss: 0.0273; AAPD: 1.752648949623108
Epoch [874/1000] - 

Epoch [1000/1000] - Loss: 0.0286; AAPD: 1.6798810958862305


In [23]:
model = IAModel(input_dimension=train_int.shape[1], output_dimension=train_out1.shape[1])
best_model_ASPD_3 = train_model(model, 'best_model_ASPD_3.pth', train_dataloader, dev_dataloader, 
                                target_idx=3, num_epochs=2000)

Epoch [1/2000] - Loss: 0.1550; AAPD: 12.919614791870117
Epoch [2/2000] - Loss: 0.0117; AAPD: 1.5585758686065674
Epoch [3/2000] - Loss: 0.0091; AAPD: 1.2400150299072266
Epoch [4/2000] - Loss: 0.0107; AAPD: 1.2638429403305054
Epoch [5/2000] - Loss: 0.0104; AAPD: 1.5373367071151733
Epoch [6/2000] - Loss: 0.0105; AAPD: 1.2402822971343994
Epoch [7/2000] - Loss: 0.0151; AAPD: 1.3904705047607422
Epoch [8/2000] - Loss: 0.0546; AAPD: 9.232166290283203
Epoch [9/2000] - Loss: 0.0204; AAPD: 3.3536508083343506
Epoch [10/2000] - Loss: 0.0196; AAPD: 2.7790374755859375
Epoch [11/2000] - Loss: 0.0090; AAPD: 2.7801499366760254
Epoch [12/2000] - Loss: 0.0186; AAPD: 2.5182607173919678
Epoch [13/2000] - Loss: 0.2035; AAPD: 43.70640563964844
Epoch [14/2000] - Loss: 0.0338; AAPD: 10.254897117614746
Epoch [15/2000] - Loss: 0.0216; AAPD: 8.76844596862793
Epoch [16/2000] - Loss: 3.9099; AAPD: 93.6883773803711
Epoch [17/2000] - Loss: 212.6996; AAPD: 397.6949768066406
Epoch [18/2000] - Loss: 0.0274; AAPD: 8.51513

Epoch [145/2000] - Loss: 98.7446; AAPD: 322.7739562988281
Epoch [146/2000] - Loss: 1373.1656; AAPD: 1335.1898193359375
Epoch [147/2000] - Loss: 16.3962; AAPD: 311.0936279296875
Epoch [148/2000] - Loss: 0.3734; AAPD: 53.70566940307617
Epoch [149/2000] - Loss: 6564.0342; AAPD: 32521.12890625
Epoch [150/2000] - Loss: 2179.3689; AAPD: 9752.443359375
Epoch [151/2000] - Loss: 954.1208; AAPD: 4070.073486328125
Epoch [152/2000] - Loss: 6445.3354; AAPD: 10135.3486328125
Epoch [153/2000] - Loss: 920.9121; AAPD: 2303.745849609375
Epoch [154/2000] - Loss: 257.2039; AAPD: 706.9910278320312
Epoch [155/2000] - Loss: 36.0848; AAPD: 199.3832550048828
Epoch [156/2000] - Loss: 1.0281; AAPD: 44.94121170043945
Epoch [157/2000] - Loss: 0.0311; AAPD: 6.799141883850098
Epoch [158/2000] - Loss: 0.0103; AAPD: 1.539739727973938
Epoch [159/2000] - Loss: 0.0095; AAPD: 1.2066624164581299
Epoch [160/2000] - Loss: 0.0100; AAPD: 1.3711317777633667
Epoch [161/2000] - Loss: 0.0128; AAPD: 1.3197221755981445
Epoch [162/20

Epoch [287/2000] - Loss: 0.0199; AAPD: 1.280790090560913
Epoch [288/2000] - Loss: 0.0157; AAPD: 1.2843183279037476
Epoch [289/2000] - Loss: 0.0078; AAPD: 2.2508656978607178
Epoch [290/2000] - Loss: 0.0155; AAPD: 1.312294602394104
Epoch [291/2000] - Loss: 0.0054; AAPD: 1.241316795349121
Epoch [292/2000] - Loss: 0.0352; AAPD: 12.061505317687988
Epoch [293/2000] - Loss: 0.0144; AAPD: 1.405802607536316
Epoch [294/2000] - Loss: 0.0225; AAPD: 11.534296989440918
Epoch [295/2000] - Loss: 0.0108; AAPD: 1.3429019451141357
Epoch [296/2000] - Loss: 0.0096; AAPD: 1.1280663013458252
Epoch [297/2000] - Loss: 0.0090; AAPD: 1.1629445552825928
Epoch [298/2000] - Loss: 0.0233; AAPD: 1.3161540031433105
Epoch [299/2000] - Loss: 0.0049; AAPD: 1.2702628374099731
Epoch [300/2000] - Loss: 0.0213; AAPD: 1.2899549007415771
Epoch [301/2000] - Loss: 0.0175; AAPD: 1.5418227910995483
Epoch [302/2000] - Loss: 0.0098; AAPD: 1.5516115427017212
Epoch [303/2000] - Loss: 452.4201; AAPD: 1006.532470703125
Epoch [304/2000] 

Epoch [430/2000] - Loss: 0.0735; AAPD: 21.001684188842773
Epoch [431/2000] - Loss: 0.0221; AAPD: 2.74613881111145
Epoch [432/2000] - Loss: 0.0764; AAPD: 20.158994674682617
Epoch [433/2000] - Loss: 0.1301; AAPD: 10.417202949523926
Epoch [434/2000] - Loss: 0.0231; AAPD: 13.7650728225708
Epoch [435/2000] - Loss: 0.0111; AAPD: 1.942638874053955
Epoch [436/2000] - Loss: 0.0259; AAPD: 4.833407878875732
Epoch [437/2000] - Loss: 0.0246; AAPD: 5.344541072845459
Epoch [438/2000] - Loss: 0.0205; AAPD: 5.374595642089844
Epoch [439/2000] - Loss: 0.0654; AAPD: 18.999420166015625
Epoch [440/2000] - Loss: 0.4405; AAPD: 44.985984802246094
Epoch [441/2000] - Loss: 0.1774; AAPD: 16.093772888183594
Epoch [442/2000] - Loss: 0.0249; AAPD: 3.372286796569824
Epoch [443/2000] - Loss: 0.0066; AAPD: 2.0286924839019775
Epoch [444/2000] - Loss: 0.0301; AAPD: 7.138736248016357
Epoch [445/2000] - Loss: 0.0126; AAPD: 3.441401958465576
Epoch [446/2000] - Loss: 1.1062; AAPD: 109.96312713623047
Epoch [447/2000] - Loss: 

Epoch [573/2000] - Loss: 0.0115; AAPD: 1.6486648321151733
Epoch [574/2000] - Loss: 0.0081; AAPD: 4.852036952972412
Epoch [575/2000] - Loss: 0.0162; AAPD: 2.7214365005493164
Epoch [576/2000] - Loss: 0.0152; AAPD: 2.1691951751708984
Epoch [577/2000] - Loss: 0.0063; AAPD: 1.5837448835372925
Epoch [578/2000] - Loss: 0.0086; AAPD: 2.160737991333008
Epoch [579/2000] - Loss: 0.0113; AAPD: 1.4224398136138916
Epoch [580/2000] - Loss: 0.0096; AAPD: 1.7611396312713623
Epoch [581/2000] - Loss: 0.0223; AAPD: 2.1428844928741455
Epoch [582/2000] - Loss: 0.0189; AAPD: 1.832668423652649
Epoch [583/2000] - Loss: 0.0139; AAPD: 4.691717624664307
Epoch [584/2000] - Loss: 0.0177; AAPD: 3.3257617950439453
Epoch [585/2000] - Loss: 0.0107; AAPD: 5.67841100692749
Epoch [586/2000] - Loss: 0.0141; AAPD: 1.9825973510742188
Epoch [587/2000] - Loss: 0.0097; AAPD: 2.2911527156829834
Epoch [588/2000] - Loss: 0.0121; AAPD: 1.7817962169647217
Epoch [589/2000] - Loss: 0.0123; AAPD: 2.5352888107299805
Epoch [590/2000] - L

Epoch [716/2000] - Loss: 0.0169; AAPD: 1.8120381832122803
Epoch [717/2000] - Loss: 0.0084; AAPD: 2.1090407371520996
Epoch [718/2000] - Loss: 0.0178; AAPD: 3.4156274795532227
Epoch [719/2000] - Loss: 0.0200; AAPD: 2.465719699859619
Epoch [720/2000] - Loss: 0.0104; AAPD: 1.6728407144546509
Epoch [721/2000] - Loss: 0.0156; AAPD: 2.924417018890381
Epoch [722/2000] - Loss: 0.0113; AAPD: 2.1004831790924072
Epoch [723/2000] - Loss: 0.0139; AAPD: 1.5901421308517456
Epoch [724/2000] - Loss: 0.0113; AAPD: 2.0470423698425293
Epoch [725/2000] - Loss: 0.0071; AAPD: 2.06413197517395
Epoch [726/2000] - Loss: 0.0214; AAPD: 2.702296495437622
Epoch [727/2000] - Loss: 0.0162; AAPD: 2.099700450897217
Epoch [728/2000] - Loss: 0.0104; AAPD: 2.12583589553833
Epoch [729/2000] - Loss: 0.0095; AAPD: 1.6941883563995361
Epoch [730/2000] - Loss: 0.0081; AAPD: 2.366363048553467
Epoch [731/2000] - Loss: 0.0106; AAPD: 2.4738924503326416
Epoch [732/2000] - Loss: 0.0150; AAPD: 1.4044528007507324
Epoch [733/2000] - Loss

Epoch [859/2000] - Loss: 0.0110; AAPD: 1.3554556369781494
Epoch [860/2000] - Loss: 0.0189; AAPD: 1.4076812267303467
Epoch [861/2000] - Loss: 0.0133; AAPD: 2.3519203662872314
Epoch [862/2000] - Loss: 0.0094; AAPD: 1.3400131464004517
Epoch [863/2000] - Loss: 0.0143; AAPD: 3.478755235671997
Epoch [864/2000] - Loss: 0.0076; AAPD: 1.6367460489273071
Epoch [865/2000] - Loss: 0.0077; AAPD: 2.20757794380188
Epoch [866/2000] - Loss: 0.0144; AAPD: 1.3244044780731201
Epoch [867/2000] - Loss: 0.0138; AAPD: 1.2514692544937134
Epoch [868/2000] - Loss: 0.0096; AAPD: 1.9212934970855713
Epoch [869/2000] - Loss: 0.0160; AAPD: 1.2425042390823364
Epoch [870/2000] - Loss: 0.0153; AAPD: 1.7386469841003418
Epoch [871/2000] - Loss: 0.0229; AAPD: 2.258364200592041
Epoch [872/2000] - Loss: 0.0098; AAPD: 1.3232496976852417
Epoch [873/2000] - Loss: 0.0197; AAPD: 1.2486652135849
Epoch [874/2000] - Loss: 0.0083; AAPD: 1.5859062671661377
Epoch [875/2000] - Loss: 0.0095; AAPD: 1.6595673561096191
Epoch [876/2000] - Lo

Epoch [1001/2000] - Loss: 0.0185; AAPD: 1.2778948545455933
Epoch [1002/2000] - Loss: 0.0074; AAPD: 1.8593127727508545
Epoch [1003/2000] - Loss: 0.0161; AAPD: 1.4747393131256104
Epoch [1004/2000] - Loss: 0.0102; AAPD: 1.3122000694274902
Epoch [1005/2000] - Loss: 0.0119; AAPD: 1.3344192504882812
Epoch [1006/2000] - Loss: 0.0114; AAPD: 1.4195102453231812
Epoch [1007/2000] - Loss: 0.0138; AAPD: 1.2956773042678833
Epoch [1008/2000] - Loss: 0.0057; AAPD: 2.2349629402160645
Epoch [1009/2000] - Loss: 0.0143; AAPD: 1.379286766052246
Epoch [1010/2000] - Loss: 0.0250; AAPD: 1.3270750045776367
Epoch [1011/2000] - Loss: 0.0149; AAPD: 1.8500632047653198
Epoch [1012/2000] - Loss: 0.0105; AAPD: 1.412840485572815
Epoch [1013/2000] - Loss: 0.0138; AAPD: 1.3219428062438965
Epoch [1014/2000] - Loss: 0.0063; AAPD: 1.3111757040023804
Epoch [1015/2000] - Loss: 0.0151; AAPD: 1.3673243522644043
Epoch [1016/2000] - Loss: 0.0130; AAPD: 1.5150811672210693
Epoch [1017/2000] - Loss: 0.0222; AAPD: 1.627027153968811


Epoch [1141/2000] - Loss: 0.0153; AAPD: 1.5223850011825562
Epoch [1142/2000] - Loss: 0.0277; AAPD: 1.987241506576538
Epoch [1143/2000] - Loss: 0.0062; AAPD: 1.419325351715088
Epoch [1144/2000] - Loss: 0.0169; AAPD: 1.446317434310913
Epoch [1145/2000] - Loss: 0.0094; AAPD: 1.5549815893173218
Epoch [1146/2000] - Loss: 0.0114; AAPD: 2.318188190460205
Epoch [1147/2000] - Loss: 0.0137; AAPD: 2.0082614421844482
Epoch [1148/2000] - Loss: 0.0219; AAPD: 1.6556042432785034
Epoch [1149/2000] - Loss: 0.0111; AAPD: 1.265869140625
Epoch [1150/2000] - Loss: 0.0076; AAPD: 1.372597098350525
Epoch [1151/2000] - Loss: 0.0144; AAPD: 1.4938631057739258
Epoch [1152/2000] - Loss: 0.0091; AAPD: 1.4649379253387451
Epoch [1153/2000] - Loss: 0.0053; AAPD: 1.61734139919281
Epoch [1154/2000] - Loss: 0.0171; AAPD: 1.3239010572433472
Epoch [1155/2000] - Loss: 0.0246; AAPD: 2.142533540725708
Epoch [1156/2000] - Loss: 0.0079; AAPD: 1.6046442985534668
Epoch [1157/2000] - Loss: 0.0310; AAPD: 1.246899962425232
Epoch [115

Epoch [1281/2000] - Loss: 0.0242; AAPD: 2.0148122310638428
Epoch [1282/2000] - Loss: 0.0190; AAPD: 1.4886908531188965
Epoch [1283/2000] - Loss: 0.0072; AAPD: 1.5660797357559204
Epoch [1284/2000] - Loss: 0.0110; AAPD: 1.4163236618041992
Epoch [1285/2000] - Loss: 0.0077; AAPD: 1.3965067863464355
Epoch [1286/2000] - Loss: 0.0113; AAPD: 2.116507053375244
Epoch [1287/2000] - Loss: 0.0164; AAPD: 2.4157729148864746
Epoch [1288/2000] - Loss: 0.0152; AAPD: 2.077042818069458
Epoch [1289/2000] - Loss: 0.0110; AAPD: 1.3023542165756226
Epoch [1290/2000] - Loss: 0.0113; AAPD: 1.373868465423584
Epoch [1291/2000] - Loss: 0.0199; AAPD: 1.4792367219924927
Epoch [1292/2000] - Loss: 0.0126; AAPD: 1.8160123825073242
Epoch [1293/2000] - Loss: 0.0201; AAPD: 1.2483081817626953
Epoch [1294/2000] - Loss: 0.0090; AAPD: 1.4143712520599365
Epoch [1295/2000] - Loss: 0.0124; AAPD: 1.2829068899154663
Epoch [1296/2000] - Loss: 0.0127; AAPD: 1.3348076343536377
Epoch [1297/2000] - Loss: 0.0114; AAPD: 1.3679755926132202


Epoch [1421/2000] - Loss: 0.0135; AAPD: 1.87450110912323
Epoch [1422/2000] - Loss: 0.0154; AAPD: 1.2918717861175537
Epoch [1423/2000] - Loss: 0.0095; AAPD: 1.2373766899108887
Epoch [1424/2000] - Loss: 0.0143; AAPD: 1.5353269577026367
Epoch [1425/2000] - Loss: 0.0132; AAPD: 1.3893229961395264
Epoch [1426/2000] - Loss: 0.0044; AAPD: 1.6053072214126587
Epoch [1427/2000] - Loss: 0.0085; AAPD: 1.3217291831970215
Epoch [1428/2000] - Loss: 0.0221; AAPD: 1.4420992136001587
Epoch [1429/2000] - Loss: 0.0154; AAPD: 1.243918776512146
Epoch [1430/2000] - Loss: 0.0118; AAPD: 1.4498193264007568
Epoch [1431/2000] - Loss: 0.0133; AAPD: 1.3570494651794434
Epoch [1432/2000] - Loss: 0.0084; AAPD: 1.3251956701278687
Epoch [1433/2000] - Loss: 0.0209; AAPD: 1.402575969696045
Epoch [1434/2000] - Loss: 0.0132; AAPD: 1.2184685468673706
Epoch [1435/2000] - Loss: 0.0106; AAPD: 1.3931729793548584
Epoch [1436/2000] - Loss: 0.0135; AAPD: 1.2925031185150146
Epoch [1437/2000] - Loss: 0.0108; AAPD: 1.8314870595932007
E

Epoch [1561/2000] - Loss: 0.0183; AAPD: 1.3114356994628906
Epoch [1562/2000] - Loss: 0.0124; AAPD: 1.7584962844848633
Epoch [1563/2000] - Loss: 0.0090; AAPD: 1.430005431175232
Epoch [1564/2000] - Loss: 0.0300; AAPD: 1.4623576402664185
Epoch [1565/2000] - Loss: 0.0156; AAPD: 1.3816543817520142
Epoch [1566/2000] - Loss: 0.0150; AAPD: 1.516201138496399
Epoch [1567/2000] - Loss: 0.0129; AAPD: 2.053506374359131
Epoch [1568/2000] - Loss: 0.0118; AAPD: 1.2741059064865112
Epoch [1569/2000] - Loss: 0.0148; AAPD: 1.2316805124282837
Epoch [1570/2000] - Loss: 0.0212; AAPD: 1.3316967487335205
Epoch [1571/2000] - Loss: 0.0077; AAPD: 1.4443656206130981
Epoch [1572/2000] - Loss: 0.0056; AAPD: 1.3266451358795166
Epoch [1573/2000] - Loss: 0.0101; AAPD: 1.2562345266342163
Epoch [1574/2000] - Loss: 0.0121; AAPD: 1.3734184503555298
Epoch [1575/2000] - Loss: 0.0095; AAPD: 1.3315094709396362
Epoch [1576/2000] - Loss: 0.0170; AAPD: 1.5123307704925537
Epoch [1577/2000] - Loss: 0.0233; AAPD: 1.5597069263458252


Epoch [1701/2000] - Loss: 0.0146; AAPD: 1.1980855464935303
Epoch [1702/2000] - Loss: 0.0061; AAPD: 1.3516591787338257
Epoch [1703/2000] - Loss: 0.0152; AAPD: 1.3568134307861328
Epoch [1704/2000] - Loss: 0.0146; AAPD: 1.4862078428268433
Epoch [1705/2000] - Loss: 0.0158; AAPD: 1.3880006074905396
Epoch [1706/2000] - Loss: 0.0130; AAPD: 3.537511110305786
Epoch [1707/2000] - Loss: 0.0096; AAPD: 1.313332438468933
Epoch [1708/2000] - Loss: 0.0107; AAPD: 2.1698286533355713
Epoch [1709/2000] - Loss: 0.0160; AAPD: 1.6500340700149536
Epoch [1710/2000] - Loss: 0.0170; AAPD: 1.2026712894439697
Epoch [1711/2000] - Loss: 0.0093; AAPD: 1.461850881576538
Epoch [1712/2000] - Loss: 0.0141; AAPD: 1.7783334255218506
Epoch [1713/2000] - Loss: 0.0068; AAPD: 1.463445782661438
Epoch [1714/2000] - Loss: 0.0079; AAPD: 1.2943190336227417
Epoch [1715/2000] - Loss: 0.0110; AAPD: 1.3840664625167847
Epoch [1716/2000] - Loss: 0.0124; AAPD: 1.2547316551208496
Epoch [1717/2000] - Loss: 0.0089; AAPD: 1.3608089685440063
E

Epoch [1841/2000] - Loss: 0.0174; AAPD: 1.3077775239944458
Epoch [1842/2000] - Loss: 0.0143; AAPD: 1.497836709022522
Epoch [1843/2000] - Loss: 0.0066; AAPD: 1.7469371557235718
Epoch [1844/2000] - Loss: 0.0127; AAPD: 1.5378828048706055
Epoch [1845/2000] - Loss: 0.0064; AAPD: 1.4392011165618896
Epoch [1846/2000] - Loss: 0.0112; AAPD: 2.530853509902954
Epoch [1847/2000] - Loss: 0.0091; AAPD: 2.062723398208618
Epoch [1848/2000] - Loss: 0.0077; AAPD: 1.2453680038452148
Epoch [1849/2000] - Loss: 0.0133; AAPD: 1.3819279670715332
Epoch [1850/2000] - Loss: 0.0081; AAPD: 1.2237268686294556
Epoch [1851/2000] - Loss: 0.0158; AAPD: 1.7035719156265259
Epoch [1852/2000] - Loss: 0.0082; AAPD: 1.3823403120040894
Epoch [1853/2000] - Loss: 0.0164; AAPD: 1.2348556518554688
Epoch [1854/2000] - Loss: 0.0187; AAPD: 1.5029362440109253
Epoch [1855/2000] - Loss: 0.0093; AAPD: 1.3600183725357056
Epoch [1856/2000] - Loss: 0.0076; AAPD: 1.2391209602355957
Epoch [1857/2000] - Loss: 0.0170; AAPD: 1.2899417877197266


Epoch [1981/2000] - Loss: 0.0187; AAPD: 4.402172088623047
Epoch [1982/2000] - Loss: 0.0130; AAPD: 1.2414767742156982
Epoch [1983/2000] - Loss: 0.0149; AAPD: 1.3179043531417847
Epoch [1984/2000] - Loss: 0.0176; AAPD: 1.498810052871704
Epoch [1985/2000] - Loss: 0.0091; AAPD: 1.4055724143981934
Epoch [1986/2000] - Loss: 0.0122; AAPD: 1.419360876083374
Epoch [1987/2000] - Loss: 0.0125; AAPD: 1.5144685506820679
Epoch [1988/2000] - Loss: 0.0150; AAPD: 1.3677313327789307
Epoch [1989/2000] - Loss: 0.0077; AAPD: 1.6457412242889404
Epoch [1990/2000] - Loss: 0.0207; AAPD: 1.4762517213821411
Epoch [1991/2000] - Loss: 0.0167; AAPD: 1.580424427986145
Epoch [1992/2000] - Loss: 0.0092; AAPD: 1.3136889934539795
Epoch [1993/2000] - Loss: 0.0085; AAPD: 1.3581583499908447
Epoch [1994/2000] - Loss: 0.0263; AAPD: 1.3689326047897339
Epoch [1995/2000] - Loss: 0.0108; AAPD: 1.1639142036437988
Epoch [1996/2000] - Loss: 0.0238; AAPD: 1.3905484676361084
Epoch [1997/2000] - Loss: 0.0167; AAPD: 1.4323029518127441
E

In [25]:
model = IAModel(input_dimension=train_int.shape[1], output_dimension=train_out1.shape[1])
best_model_AAPD_2 = train_model(model, 'best_model_AAPD_2.pth', train_dataloader, dev_dataloader, 
                                target_idx=2, num_epochs=2000)

Epoch [1/2000] - Loss: 31.3723; AAPD: 63.449798583984375
Epoch [2/2000] - Loss: 55.8277; AAPD: 203.3865203857422
Epoch [3/2000] - Loss: 121.8133; AAPD: 556.0507202148438
Epoch [4/2000] - Loss: 1448.4609; AAPD: 534.3009033203125
Epoch [5/2000] - Loss: 145.1593; AAPD: 161.5963592529297
Epoch [6/2000] - Loss: 173.6989; AAPD: 384.0987243652344
Epoch [7/2000] - Loss: 57.1925; AAPD: 287.4859313964844
Epoch [8/2000] - Loss: 520.8979; AAPD: 823.3607788085938
Epoch [9/2000] - Loss: 254.8053; AAPD: 512.3489379882812
Epoch [10/2000] - Loss: 1093.1835; AAPD: 578.7546997070312
Epoch [11/2000] - Loss: 361.2981; AAPD: 424.8931579589844
Epoch [12/2000] - Loss: 133.8467; AAPD: 533.2420043945312
Epoch [13/2000] - Loss: 180.6164; AAPD: 379.9814147949219
Epoch [14/2000] - Loss: 508.6285; AAPD: 406.76629638671875
Epoch [15/2000] - Loss: 122.4398; AAPD: 263.45318603515625
Epoch [16/2000] - Loss: 271.6106; AAPD: 521.8683471679688
Epoch [17/2000] - Loss: 79.4779; AAPD: 593.0987548828125
Epoch [18/2000] - Loss

Epoch [143/2000] - Loss: 257.0619; AAPD: 526.7969970703125
Epoch [144/2000] - Loss: 366.4296; AAPD: 511.2597961425781
Epoch [145/2000] - Loss: 45.6710; AAPD: 215.95167541503906
Epoch [146/2000] - Loss: 125.3817; AAPD: 200.74888610839844
Epoch [147/2000] - Loss: 376.4375; AAPD: 651.680908203125
Epoch [148/2000] - Loss: 315.2320; AAPD: 476.1474914550781
Epoch [149/2000] - Loss: 135.8472; AAPD: 298.5892639160156
Epoch [150/2000] - Loss: 172.1217; AAPD: 880.9335327148438
Epoch [151/2000] - Loss: 55.4588; AAPD: 273.67547607421875
Epoch [152/2000] - Loss: 233.6493; AAPD: 336.63482666015625
Epoch [153/2000] - Loss: 126.6783; AAPD: 216.8974609375
Epoch [154/2000] - Loss: 203.0330; AAPD: 284.2077331542969
Epoch [155/2000] - Loss: 113.2182; AAPD: 816.2620239257812
Epoch [156/2000] - Loss: 1625.1793; AAPD: 667.4577026367188
Epoch [157/2000] - Loss: 255.0540; AAPD: 465.8634338378906
Epoch [158/2000] - Loss: 187.3188; AAPD: 397.8564453125
Epoch [159/2000] - Loss: 191.9333; AAPD: 623.9733276367188
E

Epoch [283/2000] - Loss: 7.5357; AAPD: 19.3326416015625
Epoch [284/2000] - Loss: 10.9663; AAPD: 68.49842834472656
Epoch [285/2000] - Loss: 4.5490; AAPD: 32.728755950927734
Epoch [286/2000] - Loss: 6.5865; AAPD: 22.24757957458496
Epoch [287/2000] - Loss: 21.0379; AAPD: 40.68840026855469
Epoch [288/2000] - Loss: 17.1896; AAPD: 62.101783752441406
Epoch [289/2000] - Loss: 4.0967; AAPD: 39.71251678466797
Epoch [290/2000] - Loss: 21.6514; AAPD: 60.38949966430664
Epoch [291/2000] - Loss: 4.1350; AAPD: 23.99854278564453
Epoch [292/2000] - Loss: 7.3456; AAPD: 34.18486022949219
Epoch [293/2000] - Loss: 18.4060; AAPD: 33.76166915893555
Epoch [294/2000] - Loss: 80.8354; AAPD: 39.238040924072266
Epoch [295/2000] - Loss: 10.1127; AAPD: 49.19048309326172
Epoch [296/2000] - Loss: 5.6100; AAPD: 22.112319946289062
Epoch [297/2000] - Loss: 40.4637; AAPD: 32.205467224121094
Epoch [298/2000] - Loss: 5.9707; AAPD: 32.34433364868164
Epoch [299/2000] - Loss: 18.2431; AAPD: 28.802000045776367
Epoch [300/2000] 

Epoch [425/2000] - Loss: 68.6678; AAPD: 57.43452072143555
Epoch [426/2000] - Loss: 3.1166; AAPD: 21.49382781982422
Epoch [427/2000] - Loss: 29.1569; AAPD: 66.50119018554688
Epoch [428/2000] - Loss: 101.6511; AAPD: 17.54768943786621
Epoch [429/2000] - Loss: 6.9793; AAPD: 26.655120849609375
Epoch [430/2000] - Loss: 25.4867; AAPD: 31.558815002441406
Epoch [431/2000] - Loss: 51.5769; AAPD: 70.36360931396484
Epoch [432/2000] - Loss: 24.9511; AAPD: 25.961214065551758
Epoch [433/2000] - Loss: 24.2057; AAPD: 65.42959594726562
Epoch [434/2000] - Loss: 172.5781; AAPD: 66.23807525634766
Epoch [435/2000] - Loss: 13.2122; AAPD: 25.793865203857422
Epoch [436/2000] - Loss: 9.0964; AAPD: 26.23073387145996
Epoch [437/2000] - Loss: 4.7117; AAPD: 25.752090454101562
Epoch [438/2000] - Loss: 20.0559; AAPD: 93.62139129638672
Epoch [439/2000] - Loss: 12.5574; AAPD: 58.64510726928711
Epoch [440/2000] - Loss: 10.7205; AAPD: 19.531728744506836
Epoch [441/2000] - Loss: 34.4984; AAPD: 78.12973022460938
Epoch [442

Epoch [568/2000] - Loss: 1.4742; AAPD: 4.911922454833984
Epoch [569/2000] - Loss: 3.6705; AAPD: 9.810873985290527
Epoch [570/2000] - Loss: 2.5483; AAPD: 6.245875358581543
Epoch [571/2000] - Loss: 1.3985; AAPD: 5.334527015686035
Epoch [572/2000] - Loss: 1.7156; AAPD: 4.5329461097717285
Epoch [573/2000] - Loss: 1.3098; AAPD: 5.14412260055542
Epoch [574/2000] - Loss: 1.6868; AAPD: 5.938754081726074
Epoch [575/2000] - Loss: 2.9920; AAPD: 6.676125526428223
Epoch [576/2000] - Loss: 2.2580; AAPD: 4.586738109588623
Epoch [577/2000] - Loss: 3.3630; AAPD: 5.5300493240356445
Epoch [578/2000] - Loss: 9.6617; AAPD: 2.894108295440674
Epoch [579/2000] - Loss: 2.5158; AAPD: 5.316781997680664
Epoch [580/2000] - Loss: 2.5732; AAPD: 7.301856517791748
Epoch [581/2000] - Loss: 1.0268; AAPD: 1.7456992864608765
Epoch [582/2000] - Loss: 2.1357; AAPD: 4.426623821258545
Epoch [583/2000] - Loss: 2.2882; AAPD: 4.547918796539307
Epoch [584/2000] - Loss: 1.4198; AAPD: 2.832242012023926
Epoch [585/2000] - Loss: 1.55

Epoch [712/2000] - Loss: 3.2741; AAPD: 10.341519355773926
Epoch [713/2000] - Loss: 3.1810; AAPD: 9.650635719299316
Epoch [714/2000] - Loss: 8.3113; AAPD: 2.4360945224761963
Epoch [715/2000] - Loss: 1.7827; AAPD: 4.9249467849731445
Epoch [716/2000] - Loss: 1.5990; AAPD: 3.5578877925872803
Epoch [717/2000] - Loss: 3.4187; AAPD: 3.352945327758789
Epoch [718/2000] - Loss: 1.8075; AAPD: 4.661569118499756
Epoch [719/2000] - Loss: 5.1116; AAPD: 5.350374221801758
Epoch [720/2000] - Loss: 1.5098; AAPD: 5.1986918449401855
Epoch [721/2000] - Loss: 2.4503; AAPD: 5.352773189544678
Epoch [722/2000] - Loss: 2.2601; AAPD: 7.421180248260498
Epoch [723/2000] - Loss: 1.5496; AAPD: 3.268345355987549
Epoch [724/2000] - Loss: 1.4601; AAPD: 3.6754279136657715
Epoch [725/2000] - Loss: 2.1465; AAPD: 9.018689155578613
Epoch [726/2000] - Loss: 1.8809; AAPD: 9.130391120910645
Epoch [727/2000] - Loss: 10.9870; AAPD: 12.19847297668457
Epoch [728/2000] - Loss: 4.8012; AAPD: 13.152853012084961
Epoch [729/2000] - Loss

KeyboardInterrupt: 

In [27]:
model = IAModel(input_dimension=train_int.shape[1], output_dimension=train_out1.shape[1])
best_model_AAPD_3 = train_model(model, 'best_model_AAPD_3.pth', train_dataloader, dev_dataloader, 
                                target_idx=3, num_epochs=2000)

Epoch [1/2000] - Loss: 4085.8081; AAPD: 3516.738525390625
Epoch [2/2000] - Loss: 1005.9949; AAPD: 5765.31787109375
Epoch [3/2000] - Loss: 141.2170; AAPD: 926.9696655273438
Epoch [4/2000] - Loss: 8072.9204; AAPD: 3619.17626953125
Epoch [5/2000] - Loss: 684.8079; AAPD: 2453.5986328125
Epoch [6/2000] - Loss: 1160.7731; AAPD: 3318.04931640625
Epoch [7/2000] - Loss: 81.5703; AAPD: 328.60247802734375
Epoch [8/2000] - Loss: 2793.3088; AAPD: 3128.867431640625
Epoch [9/2000] - Loss: 1006.4730; AAPD: 6002.04150390625
Epoch [10/2000] - Loss: 203.6106; AAPD: 881.5415649414062
Epoch [11/2000] - Loss: 236.8095; AAPD: 1938.771240234375
Epoch [12/2000] - Loss: 802.7342; AAPD: 2488.275146484375
Epoch [13/2000] - Loss: 740.9330; AAPD: 3992.8720703125
Epoch [14/2000] - Loss: 3103.9644; AAPD: 6257.74609375
Epoch [15/2000] - Loss: 1344.1774; AAPD: 3036.865478515625
Epoch [16/2000] - Loss: 7526.3638; AAPD: 4205.59423828125
Epoch [17/2000] - Loss: 998.7208; AAPD: 2387.570068359375
Epoch [18/2000] - Loss: 366

Epoch [143/2000] - Loss: 6675.0015; AAPD: 2220.5068359375
Epoch [144/2000] - Loss: 2200.8833; AAPD: 2184.37060546875
Epoch [145/2000] - Loss: 1875.3004; AAPD: 3486.555419921875
Epoch [146/2000] - Loss: 2371.6182; AAPD: 3960.033935546875
Epoch [147/2000] - Loss: 5313.1538; AAPD: 6276.337890625
Epoch [148/2000] - Loss: 3661.9871; AAPD: 2840.336181640625
Epoch [149/2000] - Loss: 132041.2031; AAPD: 2892.283203125
Epoch [150/2000] - Loss: 1397.8390; AAPD: 3537.1279296875
Epoch [151/2000] - Loss: 1202.7039; AAPD: 2822.19140625
Epoch [152/2000] - Loss: 1503.7433; AAPD: 5631.92041015625
Epoch [153/2000] - Loss: 465.8116; AAPD: 2568.797119140625
Epoch [154/2000] - Loss: 264.8414; AAPD: 2243.076171875
Epoch [155/2000] - Loss: 2130.9502; AAPD: 3000.887939453125
Epoch [156/2000] - Loss: 1360.8706; AAPD: 3646.30029296875
Epoch [157/2000] - Loss: 1192.0969; AAPD: 1807.9920654296875
Epoch [158/2000] - Loss: 4222.6089; AAPD: 7626.18115234375
Epoch [159/2000] - Loss: 1132.6874; AAPD: 4776.81298828125
E

Epoch [283/2000] - Loss: 2309.9915; AAPD: 463.85394287109375
Epoch [284/2000] - Loss: 13.5472; AAPD: 77.40947723388672
Epoch [285/2000] - Loss: 259.9401; AAPD: 481.4889221191406
Epoch [286/2000] - Loss: 70.5029; AAPD: 570.4014282226562
Epoch [287/2000] - Loss: 156.2765; AAPD: 422.87872314453125
Epoch [288/2000] - Loss: 113.2608; AAPD: 330.9897155761719
Epoch [289/2000] - Loss: 521.1241; AAPD: 257.8894348144531
Epoch [290/2000] - Loss: 130.4616; AAPD: 959.5130615234375
Epoch [291/2000] - Loss: 18.1416; AAPD: 256.09527587890625
Epoch [292/2000] - Loss: 44.9776; AAPD: 165.18121337890625
Epoch [293/2000] - Loss: 400.6371; AAPD: 702.1658935546875
Epoch [294/2000] - Loss: 69.2565; AAPD: 556.887451171875
Epoch [295/2000] - Loss: 91.3740; AAPD: 521.392578125
Epoch [296/2000] - Loss: 122.3133; AAPD: 981.9932861328125
Epoch [297/2000] - Loss: 28.2403; AAPD: 301.9574890136719
Epoch [298/2000] - Loss: 94.5837; AAPD: 119.1085205078125
Epoch [299/2000] - Loss: 115.5723; AAPD: 250.12417602539062
Epoc

Epoch [423/2000] - Loss: 26.1200; AAPD: 87.53482818603516
Epoch [424/2000] - Loss: 56.2230; AAPD: 492.0018310546875
Epoch [425/2000] - Loss: 1743.5696; AAPD: 559.1296997070312
Epoch [426/2000] - Loss: 156.6899; AAPD: 521.3615112304688
Epoch [427/2000] - Loss: 5673.1846; AAPD: 358.82958984375
Epoch [428/2000] - Loss: 1893.1461; AAPD: 294.5361022949219
Epoch [429/2000] - Loss: 157.7854; AAPD: 210.26429748535156
Epoch [430/2000] - Loss: 71.4190; AAPD: 513.8185424804688
Epoch [431/2000] - Loss: 143.4102; AAPD: 212.44081115722656
Epoch [432/2000] - Loss: 273.5407; AAPD: 640.7515869140625
Epoch [433/2000] - Loss: 258.0755; AAPD: 435.0715637207031
Epoch [434/2000] - Loss: 93.7722; AAPD: 948.1473388671875
Epoch [435/2000] - Loss: 19.7817; AAPD: 182.16990661621094
Epoch [436/2000] - Loss: 424.6506; AAPD: 503.5376892089844
Epoch [437/2000] - Loss: 396.2939; AAPD: 423.335693359375
Epoch [438/2000] - Loss: 124.3996; AAPD: 269.53387451171875
Epoch [439/2000] - Loss: 100.6391; AAPD: 338.091674804687

Epoch [564/2000] - Loss: 42.9011; AAPD: 132.7105712890625
Epoch [565/2000] - Loss: 4.0921; AAPD: 19.279521942138672
Epoch [566/2000] - Loss: 11.1083; AAPD: 40.95769500732422
Epoch [567/2000] - Loss: 46.2485; AAPD: 82.18570709228516
Epoch [568/2000] - Loss: 3.5337; AAPD: 22.608871459960938
Epoch [569/2000] - Loss: 233.7673; AAPD: 89.574462890625
Epoch [570/2000] - Loss: 13.9928; AAPD: 34.74465560913086
Epoch [571/2000] - Loss: 22.0460; AAPD: 38.419639587402344
Epoch [572/2000] - Loss: 9.1415; AAPD: 26.01421546936035
Epoch [573/2000] - Loss: 10.6418; AAPD: 11.050042152404785
Epoch [574/2000] - Loss: 9.0631; AAPD: 25.228485107421875
Epoch [575/2000] - Loss: 13.2457; AAPD: 34.63775634765625
Epoch [576/2000] - Loss: 8.9157; AAPD: 18.82479476928711
Epoch [577/2000] - Loss: 12.6418; AAPD: 49.85152816772461
Epoch [578/2000] - Loss: 16.7052; AAPD: 61.86381149291992
Epoch [579/2000] - Loss: 7.7629; AAPD: 85.35032653808594
Epoch [580/2000] - Loss: 3.8143; AAPD: 13.024443626403809
Epoch [581/2000]

Epoch [706/2000] - Loss: 29.1244; AAPD: 69.37327575683594
Epoch [707/2000] - Loss: 5.7493; AAPD: 42.86913299560547
Epoch [708/2000] - Loss: 8.6886; AAPD: 20.467052459716797
Epoch [709/2000] - Loss: 54.0546; AAPD: 100.52051544189453
Epoch [710/2000] - Loss: 10.9167; AAPD: 53.563751220703125
Epoch [711/2000] - Loss: 379.9968; AAPD: 48.35240936279297
Epoch [712/2000] - Loss: 27.3835; AAPD: 111.44634246826172
Epoch [713/2000] - Loss: 12.5383; AAPD: 19.456300735473633
Epoch [714/2000] - Loss: 49.1604; AAPD: 79.49772644042969
Epoch [715/2000] - Loss: 18.1988; AAPD: 50.48940658569336
Epoch [716/2000] - Loss: 19.1746; AAPD: 21.85385513305664
Epoch [717/2000] - Loss: 19.5554; AAPD: 35.01729202270508
Epoch [718/2000] - Loss: 17.6196; AAPD: 51.96388244628906
Epoch [719/2000] - Loss: 13.1365; AAPD: 85.56275177001953
Epoch [720/2000] - Loss: 45.1841; AAPD: 86.34764862060547
Epoch [721/2000] - Loss: 21.6593; AAPD: 45.78323745727539
Epoch [722/2000] - Loss: 25.7018; AAPD: 67.95114135742188
Epoch [723

Epoch [848/2000] - Loss: 4.5337; AAPD: 11.805895805358887
Epoch [849/2000] - Loss: 7.3908; AAPD: 7.088296413421631
Epoch [850/2000] - Loss: 19.8662; AAPD: 16.265377044677734
Epoch [851/2000] - Loss: 1.5668; AAPD: 2.986332416534424
Epoch [852/2000] - Loss: 6.8403; AAPD: 9.388204574584961
Epoch [853/2000] - Loss: 12.2683; AAPD: 12.943058967590332
Epoch [854/2000] - Loss: 3.3942; AAPD: 9.069198608398438
Epoch [855/2000] - Loss: 2.7290; AAPD: 6.890663146972656
Epoch [856/2000] - Loss: 3.7946; AAPD: 11.335433006286621
Epoch [857/2000] - Loss: 5.2600; AAPD: 10.941311836242676
Epoch [858/2000] - Loss: 6.8746; AAPD: 22.050674438476562
Epoch [859/2000] - Loss: 15.8347; AAPD: 13.60901927947998
Epoch [860/2000] - Loss: 1.7991; AAPD: 6.36584997177124
Epoch [861/2000] - Loss: 4.1328; AAPD: 12.038869857788086
Epoch [862/2000] - Loss: 25.2431; AAPD: 10.57034683227539
Epoch [863/2000] - Loss: 7.1315; AAPD: 7.2623114585876465
Epoch [864/2000] - Loss: 4.9855; AAPD: 9.665288925170898
Epoch [865/2000] - L

Epoch [991/2000] - Loss: 6.4535; AAPD: 36.70518112182617
Epoch [992/2000] - Loss: 83.5078; AAPD: 9.403682708740234
Epoch [993/2000] - Loss: 72.4117; AAPD: 8.291946411132812
Epoch [994/2000] - Loss: 11.7654; AAPD: 20.674095153808594
Epoch [995/2000] - Loss: 19.4551; AAPD: 17.47737693786621
Epoch [996/2000] - Loss: 11.2127; AAPD: 15.319567680358887
Epoch [997/2000] - Loss: 78.3411; AAPD: 10.812091827392578
Epoch [998/2000] - Loss: 2.1851; AAPD: 5.049142360687256
Epoch [999/2000] - Loss: 3.9680; AAPD: 13.501252174377441
Epoch [1000/2000] - Loss: 3.9951; AAPD: 22.67093849182129
Epoch [1001/2000] - Loss: 3.6601; AAPD: 12.606352806091309
Epoch [1002/2000] - Loss: 1.5624; AAPD: 3.7687747478485107
Epoch [1003/2000] - Loss: 1.2691; AAPD: 1.8584022521972656
Epoch [1004/2000] - Loss: 1.1513; AAPD: 1.8285248279571533
Epoch [1005/2000] - Loss: 1.3880; AAPD: 1.7631466388702393
Epoch [1006/2000] - Loss: 1.3953; AAPD: 2.189723491668701
Epoch [1007/2000] - Loss: 2.0156; AAPD: 2.661851167678833
Epoch [1

Epoch [1131/2000] - Loss: 1.0462; AAPD: 1.3033825159072876
Epoch [1132/2000] - Loss: 1.3662; AAPD: 3.437825918197632
Epoch [1133/2000] - Loss: 2.4094; AAPD: 2.5083558559417725
Epoch [1134/2000] - Loss: 1.8961; AAPD: 4.457403182983398
Epoch [1135/2000] - Loss: 1.2449; AAPD: 1.6066956520080566
Epoch [1136/2000] - Loss: 1.0186; AAPD: 1.6396600008010864
Epoch [1137/2000] - Loss: 2.2901; AAPD: 3.062164306640625
Epoch [1138/2000] - Loss: 1.6456; AAPD: 1.7963756322860718
Epoch [1139/2000] - Loss: 1.2067; AAPD: 1.3964661359786987
Epoch [1140/2000] - Loss: 6.6576; AAPD: 2.3639378547668457
Epoch [1141/2000] - Loss: 1.5053; AAPD: 3.055938720703125
Epoch [1142/2000] - Loss: 1.3371; AAPD: 2.8936784267425537
Epoch [1143/2000] - Loss: 1.6980; AAPD: 1.8960305452346802
Epoch [1144/2000] - Loss: 1.2771; AAPD: 2.0457284450531006
Epoch [1145/2000] - Loss: 1.3408; AAPD: 2.762266159057617
Epoch [1146/2000] - Loss: 1.8048; AAPD: 2.3102214336395264
Epoch [1147/2000] - Loss: 1.0744; AAPD: 1.2465050220489502
Ep

Epoch [1271/2000] - Loss: 1.3519; AAPD: 2.341400146484375
Epoch [1272/2000] - Loss: 1.0093; AAPD: 1.5605671405792236
Epoch [1273/2000] - Loss: 1.3665; AAPD: 2.4185142517089844
Epoch [1274/2000] - Loss: 1.2307; AAPD: 2.0792455673217773
Epoch [1275/2000] - Loss: 1.1686; AAPD: 1.6046993732452393
Epoch [1276/2000] - Loss: 1.0379; AAPD: 1.6365337371826172
Epoch [1277/2000] - Loss: 1.0975; AAPD: 1.6134157180786133
Epoch [1278/2000] - Loss: 1.1842; AAPD: 1.6615763902664185
Epoch [1279/2000] - Loss: 25.6467; AAPD: 3.5969045162200928
Epoch [1280/2000] - Loss: 1.0139; AAPD: 1.5037684440612793
Epoch [1281/2000] - Loss: 1.8051; AAPD: 2.1311824321746826
Epoch [1282/2000] - Loss: 1.1489; AAPD: 1.5653952360153198
Epoch [1283/2000] - Loss: 4.8051; AAPD: 3.2461583614349365
Epoch [1284/2000] - Loss: 1.1462; AAPD: 1.5275505781173706
Epoch [1285/2000] - Loss: 1.1162; AAPD: 3.0249552726745605
Epoch [1286/2000] - Loss: 22.3468; AAPD: 1.8006250858306885
Epoch [1287/2000] - Loss: 1.0893; AAPD: 1.8076345920562

Epoch [1411/2000] - Loss: 1.0489; AAPD: 1.9530894756317139
Epoch [1412/2000] - Loss: 1.9533; AAPD: 2.3230626583099365
Epoch [1413/2000] - Loss: 1.1726; AAPD: 2.7278666496276855
Epoch [1414/2000] - Loss: 1.1994; AAPD: 2.7333977222442627
Epoch [1415/2000] - Loss: 1.0050; AAPD: 1.5642155408859253
Epoch [1416/2000] - Loss: 2.0345; AAPD: 2.167253017425537
Epoch [1417/2000] - Loss: 1.0263; AAPD: 1.375268578529358
Epoch [1418/2000] - Loss: 1.3692; AAPD: 2.178361415863037
Epoch [1419/2000] - Loss: 1.0243; AAPD: 1.3591684103012085
Epoch [1420/2000] - Loss: 1.4291; AAPD: 3.84053897857666
Epoch [1421/2000] - Loss: 1.3433; AAPD: 1.734045147895813
Epoch [1422/2000] - Loss: 1.2810; AAPD: 1.533601999282837
Epoch [1423/2000] - Loss: 1.1926; AAPD: 1.7563071250915527
Epoch [1424/2000] - Loss: 2.1098; AAPD: 3.6586720943450928
Epoch [1425/2000] - Loss: 1.0380; AAPD: 1.443503975868225
Epoch [1426/2000] - Loss: 1.4334; AAPD: 2.6940536499023438
Epoch [1427/2000] - Loss: 3.2879; AAPD: 4.209037780761719
Epoch 

Epoch [1551/2000] - Loss: 1.3199; AAPD: 2.110645294189453
Epoch [1552/2000] - Loss: 1.1000; AAPD: 1.588879942893982
Epoch [1553/2000] - Loss: 1.6459; AAPD: 1.7201184034347534
Epoch [1554/2000] - Loss: 1.8658; AAPD: 1.8709551095962524
Epoch [1555/2000] - Loss: 2.0302; AAPD: 2.084069013595581
Epoch [1556/2000] - Loss: 1.0373; AAPD: 1.5699442625045776
Epoch [1557/2000] - Loss: 1.0391; AAPD: 1.2347337007522583
Epoch [1558/2000] - Loss: 1.2683; AAPD: 1.6720490455627441
Epoch [1559/2000] - Loss: 1.0633; AAPD: 1.6097660064697266
Epoch [1560/2000] - Loss: 1.9843; AAPD: 2.585076093673706
Epoch [1561/2000] - Loss: 2.1096; AAPD: 1.4302785396575928
Epoch [1562/2000] - Loss: 864.1713; AAPD: 1.9928312301635742
Epoch [1563/2000] - Loss: 1.4985; AAPD: 2.166461229324341
Epoch [1564/2000] - Loss: 1.0400; AAPD: 1.9217242002487183
Epoch [1565/2000] - Loss: 0.9854; AAPD: 1.2188185453414917
Epoch [1566/2000] - Loss: 1.5281; AAPD: 1.8440706729888916
Epoch [1567/2000] - Loss: 1.3282; AAPD: 2.574732542037964
E

Epoch [1691/2000] - Loss: 1.2067; AAPD: 1.7289091348648071
Epoch [1692/2000] - Loss: 1.0614; AAPD: 1.4762203693389893
Epoch [1693/2000] - Loss: 1.0678; AAPD: 1.8242123126983643
Epoch [1694/2000] - Loss: 1.1171; AAPD: 1.6839150190353394
Epoch [1695/2000] - Loss: 2.1270; AAPD: 2.6362600326538086
Epoch [1696/2000] - Loss: 3.4235; AAPD: 2.116102933883667
Epoch [1697/2000] - Loss: 1.0810; AAPD: 1.5646884441375732
Epoch [1698/2000] - Loss: 1.4029; AAPD: 1.7237145900726318
Epoch [1699/2000] - Loss: 0.9827; AAPD: 1.2138828039169312
Epoch [1700/2000] - Loss: 1.6239; AAPD: 2.0785751342773438
Epoch [1701/2000] - Loss: 2.7034; AAPD: 4.16233491897583
Epoch [1702/2000] - Loss: 1.0756; AAPD: 1.5477513074874878
Epoch [1703/2000] - Loss: 1.0778; AAPD: 2.0131676197052
Epoch [1704/2000] - Loss: 1.5701; AAPD: 2.1933095455169678
Epoch [1705/2000] - Loss: 1.0691; AAPD: 1.2417665719985962
Epoch [1706/2000] - Loss: 2.3464; AAPD: 2.8036019802093506
Epoch [1707/2000] - Loss: 1.0008; AAPD: 1.455980658531189
Epoc

Epoch [1831/2000] - Loss: 1.2121; AAPD: 1.5636497735977173
Epoch [1832/2000] - Loss: 1.3083; AAPD: 3.210921287536621
Epoch [1833/2000] - Loss: 1.0188; AAPD: 1.772727608680725
Epoch [1834/2000] - Loss: 1.0037; AAPD: 1.2365713119506836
Epoch [1835/2000] - Loss: 2.2306; AAPD: 2.624253034591675
Epoch [1836/2000] - Loss: 1.1871; AAPD: 1.6577576398849487
Epoch [1837/2000] - Loss: 1.2786; AAPD: 2.004490613937378
Epoch [1838/2000] - Loss: 1.0240; AAPD: 1.4156748056411743
Epoch [1839/2000] - Loss: 1.3349; AAPD: 1.7307788133621216
Epoch [1840/2000] - Loss: 1.1632; AAPD: 1.3120918273925781
Epoch [1841/2000] - Loss: 2.9197; AAPD: 2.1500425338745117
Epoch [1842/2000] - Loss: 1.1650; AAPD: 1.6344795227050781
Epoch [1843/2000] - Loss: 1.8378; AAPD: 4.25306510925293
Epoch [1844/2000] - Loss: 1.3191; AAPD: 2.2365944385528564
Epoch [1845/2000] - Loss: 1.7616; AAPD: 2.730569839477539
Epoch [1846/2000] - Loss: 1.3237; AAPD: 3.2894132137298584
Epoch [1847/2000] - Loss: 1.2529; AAPD: 1.634290099143982
Epoch

Epoch [1971/2000] - Loss: 1.1735; AAPD: 1.6238669157028198
Epoch [1972/2000] - Loss: 1.0221; AAPD: 1.378380537033081
Epoch [1973/2000] - Loss: 1.5454; AAPD: 2.859626054763794
Epoch [1974/2000] - Loss: 2.8806; AAPD: 1.600611925125122
Epoch [1975/2000] - Loss: 4.9025; AAPD: 3.7580642700195312
Epoch [1976/2000] - Loss: 1.7520; AAPD: 3.2641263008117676
Epoch [1977/2000] - Loss: 1.0964; AAPD: 1.4537783861160278
Epoch [1978/2000] - Loss: 1.5471; AAPD: 1.5324980020523071
Epoch [1979/2000] - Loss: 1.1119; AAPD: 1.3251594305038452
Epoch [1980/2000] - Loss: 1.2222; AAPD: 1.902136206626892
Epoch [1981/2000] - Loss: 1.3817; AAPD: 1.813990831375122
Epoch [1982/2000] - Loss: 1.9486; AAPD: 1.674411416053772
Epoch [1983/2000] - Loss: 1.4504; AAPD: 2.4721240997314453
Epoch [1984/2000] - Loss: 1.0451; AAPD: 1.217984914779663
Epoch [1985/2000] - Loss: 2.1949; AAPD: 2.98054575920105
Epoch [1986/2000] - Loss: 1.0196; AAPD: 1.5647027492523193
Epoch [1987/2000] - Loss: 1.0919; AAPD: 1.5105974674224854
Epoch 

In [29]:
def train_model(model, model_name, train_loader, dev_loader, target_idx=1, num_epochs=1000, learning_rate=0.1, weight_decay=1e-4):
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[250, 500, 750], gamma=0.1)
    
    best_model = None
    best_aapd = None

    for epoch in range(num_epochs):
        model.train()
        for batch in train_loader:
            input_data = batch['input_data']
            target1 = batch['out1']
            target2 = batch['out2']
            target3 = batch['out3']
            
            optimizer.zero_grad()
            outputs = model(input_data)
            if target_idx == 1:
                train_loss = criterion(outputs, target1)
            elif target_idx == 2:
                train_loss = criterion(outputs, target2)
            elif target_idx == 3:
                train_loss = criterion(outputs, target3)
            else:
                print("Invalid target_idx")
            
            train_loss.backward()
            optimizer.step()
        scheduler.step()
        
        model.eval()
        
        total_aapd = 0
        num_samples = 0
        
        with torch.no_grad():
            for batch in dev_loader:
                dev_int = batch['input_data']
                dev_out1 = batch['out1']
                dev_out2 = batch['out2']
                dev_out3 = batch['out3']
                
                predictions = model(dev_int)
                
                if target_idx == 1:
                    batch_aapd = AAPD(predictions, dev_out1)
                elif target_idx == 2:
                    batch_aapd = AAPD(predictions, dev_out2)
                elif target_idx == 3:
                    batch_aapd = AAPD(predictions, dev_out3)
                else:
                    print("Invalid target_idx")
                    
                total_aapd += batch_aapd * len(dev_int)
                num_samples += len(dev_int)
            
            aapd = total_aapd / num_samples
            
            if best_aapd == None or aapd < best_aapd:
                best_aapd = aapd
                best_model = model
                torch.save(model.state_dict(), model_name)
        
        print(f'Epoch [{epoch + 1}/{num_epochs}] - Loss: {train_loss.item():.4f}; AAPD: {aapd}')

    return model

In [31]:
model = IAModel(input_dimension=train_int.shape[1], output_dimension=train_out1.shape[1])
model_MSE_1 = train_model(model, 'best_model_MSE_1.pth', train_dataloader, dev_dataloader, target_idx=1)
torch.save(model_MSE_1.state_dict(), "model_MSE_1.pth")

Epoch [1/1000] - Loss: 3981883.0000; AAPD: 36.84638977050781
Epoch [2/1000] - Loss: 6570370.0000; AAPD: 7.828675270080566
Epoch [3/1000] - Loss: 5708913.0000; AAPD: 9.459177017211914
Epoch [4/1000] - Loss: 2813490.5000; AAPD: 17.237632751464844
Epoch [5/1000] - Loss: 2869575.5000; AAPD: 44.168190002441406
Epoch [6/1000] - Loss: 1093560.7500; AAPD: 4.587279796600342
Epoch [7/1000] - Loss: 3345217.5000; AAPD: 34.196327209472656
Epoch [8/1000] - Loss: 54119872.0000; AAPD: 34.389747619628906
Epoch [9/1000] - Loss: 6883877.0000; AAPD: 33.78701400756836
Epoch [10/1000] - Loss: 2016597.1250; AAPD: 35.850494384765625
Epoch [11/1000] - Loss: 2428995.5000; AAPD: 33.4381217956543
Epoch [12/1000] - Loss: 8707390.0000; AAPD: 33.4765625
Epoch [13/1000] - Loss: 9263953.0000; AAPD: 34.694034576416016
Epoch [14/1000] - Loss: 5605779.5000; AAPD: 34.92958068847656
Epoch [15/1000] - Loss: 4632325.0000; AAPD: 33.471099853515625
Epoch [16/1000] - Loss: 7077759.0000; AAPD: 34.39447021484375
Epoch [17/1000] -

Epoch [132/1000] - Loss: 6857838.0000; AAPD: 23.14415168762207
Epoch [133/1000] - Loss: 6327085.5000; AAPD: 13.646100044250488
Epoch [134/1000] - Loss: 5136164.5000; AAPD: 15.021327018737793
Epoch [135/1000] - Loss: 4920066.0000; AAPD: 17.734106063842773
Epoch [136/1000] - Loss: 4379379.5000; AAPD: 9.24532699584961
Epoch [137/1000] - Loss: 6508842.0000; AAPD: 13.70814037322998
Epoch [138/1000] - Loss: 18910660.0000; AAPD: 22.654123306274414
Epoch [139/1000] - Loss: 768214.3750; AAPD: 7.084895610809326
Epoch [140/1000] - Loss: 69989832.0000; AAPD: 7.379437446594238
Epoch [141/1000] - Loss: 18960638.0000; AAPD: 7.437113285064697
Epoch [142/1000] - Loss: 6017515.5000; AAPD: 7.7998433113098145
Epoch [143/1000] - Loss: 8054452.5000; AAPD: 8.614048957824707
Epoch [144/1000] - Loss: 13739521.0000; AAPD: 9.15752124786377
Epoch [145/1000] - Loss: 80314544.0000; AAPD: 10.257402420043945
Epoch [146/1000] - Loss: 27957166.0000; AAPD: 11.057765007019043
Epoch [147/1000] - Loss: 31932192.0000; AAPD:

Epoch [261/1000] - Loss: 3013545.2500; AAPD: 4.453546524047852
Epoch [262/1000] - Loss: 2244564.5000; AAPD: 3.338218927383423
Epoch [263/1000] - Loss: 9808633.0000; AAPD: 4.093448638916016
Epoch [264/1000] - Loss: 1598964.1250; AAPD: 3.657832145690918
Epoch [265/1000] - Loss: 9931567.0000; AAPD: 3.72007417678833
Epoch [266/1000] - Loss: 11388438.0000; AAPD: 3.569589376449585
Epoch [267/1000] - Loss: 3057628.2500; AAPD: 4.347651958465576
Epoch [268/1000] - Loss: 1887929.5000; AAPD: 3.1412224769592285
Epoch [269/1000] - Loss: 3264657.5000; AAPD: 3.680954694747925
Epoch [270/1000] - Loss: 261557.9844; AAPD: 3.4226973056793213
Epoch [271/1000] - Loss: 72864712.0000; AAPD: 4.255995750427246
Epoch [272/1000] - Loss: 69371728.0000; AAPD: 4.006348133087158
Epoch [273/1000] - Loss: 98017168.0000; AAPD: 3.181023597717285
Epoch [274/1000] - Loss: 1755583.7500; AAPD: 4.0562005043029785
Epoch [275/1000] - Loss: 7416347.5000; AAPD: 4.151138782501221
Epoch [276/1000] - Loss: 26980778.0000; AAPD: 3.36

Epoch [391/1000] - Loss: 32126488.0000; AAPD: 9.905457496643066
Epoch [392/1000] - Loss: 13139545.0000; AAPD: 2.2551586627960205
Epoch [393/1000] - Loss: 89353.8359; AAPD: 2.5830864906311035
Epoch [394/1000] - Loss: 1068210.5000; AAPD: 2.6987245082855225
Epoch [395/1000] - Loss: 909916.9375; AAPD: 1.4276964664459229
Epoch [396/1000] - Loss: 167964.1406; AAPD: 2.7880167961120605
Epoch [397/1000] - Loss: 249183.6719; AAPD: 5.958190441131592
Epoch [398/1000] - Loss: 161101.1406; AAPD: 2.5818660259246826
Epoch [399/1000] - Loss: 585343.3750; AAPD: 2.2540979385375977
Epoch [400/1000] - Loss: 134779.3125; AAPD: 3.351257085800171
Epoch [401/1000] - Loss: 1555214.8750; AAPD: 3.795039176940918
Epoch [402/1000] - Loss: 96576.2969; AAPD: 2.3848605155944824
Epoch [403/1000] - Loss: 355881.2812; AAPD: 2.852203369140625
Epoch [404/1000] - Loss: 2156549.2500; AAPD: 3.88813853263855
Epoch [405/1000] - Loss: 167581.9062; AAPD: 4.144285202026367
Epoch [406/1000] - Loss: 1291569.1250; AAPD: 4.39464712142

Epoch [522/1000] - Loss: 139918.0312; AAPD: 1.0101635456085205
Epoch [523/1000] - Loss: 923702.3125; AAPD: 1.1941372156143188
Epoch [524/1000] - Loss: 40493572.0000; AAPD: 0.8774029016494751
Epoch [525/1000] - Loss: 7033941.5000; AAPD: 0.9048607349395752
Epoch [526/1000] - Loss: 2411124.0000; AAPD: 0.9243351221084595
Epoch [527/1000] - Loss: 243606.8281; AAPD: 0.8956783413887024
Epoch [528/1000] - Loss: 162408.5781; AAPD: 0.8803731799125671
Epoch [529/1000] - Loss: 3352364.7500; AAPD: 1.0614233016967773
Epoch [530/1000] - Loss: 4943449.5000; AAPD: 1.2791085243225098
Epoch [531/1000] - Loss: 1209416.6250; AAPD: 0.92435222864151
Epoch [532/1000] - Loss: 104498.7188; AAPD: 1.2381728887557983
Epoch [533/1000] - Loss: 558034.1250; AAPD: 1.1921625137329102
Epoch [534/1000] - Loss: 697920.8125; AAPD: 1.0267760753631592
Epoch [535/1000] - Loss: 171515.2188; AAPD: 1.0392475128173828
Epoch [536/1000] - Loss: 147653.1562; AAPD: 0.9949643611907959
Epoch [537/1000] - Loss: 9987459.0000; AAPD: 0.924

Epoch [652/1000] - Loss: 130421.1953; AAPD: 0.7347515225410461
Epoch [653/1000] - Loss: 1321531.1250; AAPD: 0.8010669946670532
Epoch [654/1000] - Loss: 567818.8125; AAPD: 0.7364276647567749
Epoch [655/1000] - Loss: 102990.5703; AAPD: 1.4561710357666016
Epoch [656/1000] - Loss: 336150.7188; AAPD: 0.8200136423110962
Epoch [657/1000] - Loss: 294901.5938; AAPD: 0.8773903250694275
Epoch [658/1000] - Loss: 709733.0625; AAPD: 0.8561293482780457
Epoch [659/1000] - Loss: 1650030.1250; AAPD: 0.8856849670410156
Epoch [660/1000] - Loss: 88823.4688; AAPD: 0.851274311542511
Epoch [661/1000] - Loss: 55156.4102; AAPD: 0.7956981658935547
Epoch [662/1000] - Loss: 447579.6250; AAPD: 0.7608660459518433
Epoch [663/1000] - Loss: 223537.9375; AAPD: 0.7801295518875122
Epoch [664/1000] - Loss: 1261471.3750; AAPD: 0.7819609045982361
Epoch [665/1000] - Loss: 665290.0625; AAPD: 1.0791726112365723
Epoch [666/1000] - Loss: 2591686.2500; AAPD: 0.9542134404182434
Epoch [667/1000] - Loss: 1255866.0000; AAPD: 0.7560257

Epoch [782/1000] - Loss: 2762882.7500; AAPD: 0.6139033436775208
Epoch [783/1000] - Loss: 435761.4062; AAPD: 0.5438340902328491
Epoch [784/1000] - Loss: 576335.3125; AAPD: 0.5582226514816284
Epoch [785/1000] - Loss: 719652.0625; AAPD: 0.5495432019233704
Epoch [786/1000] - Loss: 372974.2812; AAPD: 0.5825815796852112
Epoch [787/1000] - Loss: 782350.4375; AAPD: 0.5675207376480103
Epoch [788/1000] - Loss: 34499284.0000; AAPD: 0.5471594929695129
Epoch [789/1000] - Loss: 153498.7812; AAPD: 0.5797268152236938
Epoch [790/1000] - Loss: 210958.5156; AAPD: 0.5506628155708313
Epoch [791/1000] - Loss: 56909636.0000; AAPD: 0.5774232149124146
Epoch [792/1000] - Loss: 363422.1875; AAPD: 0.6012020707130432
Epoch [793/1000] - Loss: 4874187.5000; AAPD: 0.5961069464683533
Epoch [794/1000] - Loss: 418490.5625; AAPD: 0.565826952457428
Epoch [795/1000] - Loss: 4852039.5000; AAPD: 0.5889132022857666
Epoch [796/1000] - Loss: 2365714.7500; AAPD: 0.5797341465950012
Epoch [797/1000] - Loss: 30387172.0000; AAPD: 0.

Epoch [912/1000] - Loss: 138140.4531; AAPD: 1.018357515335083
Epoch [913/1000] - Loss: 12787237.0000; AAPD: 1.0200029611587524
Epoch [914/1000] - Loss: 1491336.6250; AAPD: 1.0159646272659302
Epoch [915/1000] - Loss: 569562.6875; AAPD: 1.0030221939086914
Epoch [916/1000] - Loss: 158425.0469; AAPD: 0.995906412601471
Epoch [917/1000] - Loss: 428729.1875; AAPD: 1.030644178390503
Epoch [918/1000] - Loss: 501992.6875; AAPD: 1.057267427444458
Epoch [919/1000] - Loss: 5986829.5000; AAPD: 1.0368319749832153
Epoch [920/1000] - Loss: 122043.6328; AAPD: 1.0548421144485474
Epoch [921/1000] - Loss: 108780.3828; AAPD: 1.0449903011322021
Epoch [922/1000] - Loss: 4887759.0000; AAPD: 1.083622694015503
Epoch [923/1000] - Loss: 144314.3438; AAPD: 1.0478848218917847
Epoch [924/1000] - Loss: 312541.7188; AAPD: 1.129390001296997
Epoch [925/1000] - Loss: 262957.0625; AAPD: 1.0663801431655884
Epoch [926/1000] - Loss: 262991.3438; AAPD: 1.0712833404541016
Epoch [927/1000] - Loss: 3568411.7500; AAPD: 1.116319179

In [32]:
model = IAModel(input_dimension=train_int.shape[1], output_dimension=train_out2.shape[1])
model_MSE_2 = train_model(model, 'best_model_MSE_2.pth', train_dataloader, dev_dataloader, target_idx=2)
torch.save(model_MSE_2.state_dict(), "model_MSE_2.pth")

Epoch [1/1000] - Loss: 0.0059; AAPD: 6.2145867347717285
Epoch [2/1000] - Loss: 0.0015; AAPD: 5.8686981201171875
Epoch [3/1000] - Loss: 0.0026; AAPD: 5.514272689819336
Epoch [4/1000] - Loss: 0.0016; AAPD: 5.5012946128845215
Epoch [5/1000] - Loss: 0.0030; AAPD: 5.905490398406982
Epoch [6/1000] - Loss: 0.0010; AAPD: 4.34203577041626
Epoch [7/1000] - Loss: 0.0022; AAPD: 4.750061988830566
Epoch [8/1000] - Loss: 0.0026; AAPD: 3.8513972759246826
Epoch [9/1000] - Loss: 0.0030; AAPD: 5.636204242706299
Epoch [10/1000] - Loss: 0.0137; AAPD: 123.87358093261719
Epoch [11/1000] - Loss: 0.0050; AAPD: 5.282081604003906
Epoch [12/1000] - Loss: 0.0025; AAPD: 5.313068866729736
Epoch [13/1000] - Loss: 0.0039; AAPD: 5.297768592834473
Epoch [14/1000] - Loss: 0.0031; AAPD: 5.310486316680908
Epoch [15/1000] - Loss: 0.0097; AAPD: 5.6554059982299805
Epoch [16/1000] - Loss: 0.0046; AAPD: 5.351686954498291
Epoch [17/1000] - Loss: 0.0032; AAPD: 5.441605091094971
Epoch [18/1000] - Loss: 0.0017; AAPD: 4.366918563842

Epoch [147/1000] - Loss: 4.0832; AAPD: 2499.336669921875
Epoch [148/1000] - Loss: 0.6891; AAPD: 1031.4530029296875
Epoch [149/1000] - Loss: 0.0634; AAPD: 306.34765625
Epoch [150/1000] - Loss: 0.0036; AAPD: 59.841670989990234
Epoch [151/1000] - Loss: 1.5848; AAPD: 1630.1160888671875
Epoch [152/1000] - Loss: 0.1306; AAPD: 445.6032409667969
Epoch [153/1000] - Loss: 0.0043; AAPD: 77.26447296142578
Epoch [154/1000] - Loss: 0.0035; AAPD: 7.264284133911133
Epoch [155/1000] - Loss: 0.0025; AAPD: 4.879648685455322
Epoch [156/1000] - Loss: 0.0021; AAPD: 5.204026699066162
Epoch [157/1000] - Loss: 0.0023; AAPD: 5.17747688293457
Epoch [158/1000] - Loss: 0.0043; AAPD: 5.030097961425781
Epoch [159/1000] - Loss: 0.0064; AAPD: 4.871355056762695
Epoch [160/1000] - Loss: 0.0009; AAPD: 5.105198860168457
Epoch [161/1000] - Loss: 8.6632; AAPD: 3240.42822265625
Epoch [162/1000] - Loss: 2.4668; AAPD: 1647.2213134765625
Epoch [163/1000] - Loss: 0.4361; AAPD: 657.4839477539062
Epoch [164/1000] - Loss: 0.0472; A

Epoch [291/1000] - Loss: 0.0011; AAPD: 4.412954330444336
Epoch [292/1000] - Loss: 0.0042; AAPD: 5.983149528503418
Epoch [293/1000] - Loss: 0.0023; AAPD: 5.29151725769043
Epoch [294/1000] - Loss: 0.0017; AAPD: 5.248634338378906
Epoch [295/1000] - Loss: 0.0022; AAPD: 4.597679138183594
Epoch [296/1000] - Loss: 0.0003; AAPD: 6.521917819976807
Epoch [297/1000] - Loss: 0.0055; AAPD: 6.075915813446045
Epoch [298/1000] - Loss: 0.0041; AAPD: 5.681215286254883
Epoch [299/1000] - Loss: 0.0021; AAPD: 8.079827308654785
Epoch [300/1000] - Loss: 0.0027; AAPD: 5.045929908752441
Epoch [301/1000] - Loss: 0.0014; AAPD: 8.48855209350586
Epoch [302/1000] - Loss: 0.0012; AAPD: 4.6255998611450195
Epoch [303/1000] - Loss: 0.0014; AAPD: 17.533369064331055
Epoch [304/1000] - Loss: 0.0016; AAPD: 11.795307159423828
Epoch [305/1000] - Loss: 0.0010; AAPD: 9.592155456542969
Epoch [306/1000] - Loss: 0.0007; AAPD: 12.565857887268066
Epoch [307/1000] - Loss: 0.0008; AAPD: 10.559599876403809
Epoch [308/1000] - Loss: 0.0

Epoch [435/1000] - Loss: 0.0020; AAPD: 7.681858062744141
Epoch [436/1000] - Loss: 0.0031; AAPD: 13.346392631530762
Epoch [437/1000] - Loss: 0.0005; AAPD: 7.954648494720459
Epoch [438/1000] - Loss: 0.0022; AAPD: 8.480018615722656
Epoch [439/1000] - Loss: 0.0017; AAPD: 8.699400901794434
Epoch [440/1000] - Loss: 0.0012; AAPD: 9.138018608093262
Epoch [441/1000] - Loss: 0.0036; AAPD: 11.077425003051758
Epoch [442/1000] - Loss: 0.0009; AAPD: 8.638158798217773
Epoch [443/1000] - Loss: 0.0012; AAPD: 19.033876419067383
Epoch [444/1000] - Loss: 0.0017; AAPD: 8.453481674194336
Epoch [445/1000] - Loss: 0.0007; AAPD: 8.428838729858398
Epoch [446/1000] - Loss: 0.0011; AAPD: 9.39182186126709
Epoch [447/1000] - Loss: 0.0029; AAPD: 4.538508892059326
Epoch [448/1000] - Loss: 0.0016; AAPD: 7.375894069671631
Epoch [449/1000] - Loss: 0.0019; AAPD: 8.104159355163574
Epoch [450/1000] - Loss: 0.0037; AAPD: 10.473816871643066
Epoch [451/1000] - Loss: 0.0010; AAPD: 4.5102105140686035
Epoch [452/1000] - Loss: 0.

Epoch [579/1000] - Loss: 0.0011; AAPD: 4.820505142211914
Epoch [580/1000] - Loss: 0.0012; AAPD: 4.544028282165527
Epoch [581/1000] - Loss: 0.0005; AAPD: 4.782678127288818
Epoch [582/1000] - Loss: 0.0009; AAPD: 5.185828685760498
Epoch [583/1000] - Loss: 0.0005; AAPD: 4.829641342163086
Epoch [584/1000] - Loss: 0.0004; AAPD: 5.016762733459473
Epoch [585/1000] - Loss: 0.0012; AAPD: 4.9118218421936035
Epoch [586/1000] - Loss: 0.0011; AAPD: 4.609354019165039
Epoch [587/1000] - Loss: 0.0024; AAPD: 5.447302341461182
Epoch [588/1000] - Loss: 0.0003; AAPD: 4.667316913604736
Epoch [589/1000] - Loss: 0.0007; AAPD: 4.82318639755249
Epoch [590/1000] - Loss: 0.0004; AAPD: 5.019527912139893
Epoch [591/1000] - Loss: 0.0003; AAPD: 4.696702480316162
Epoch [592/1000] - Loss: 0.0018; AAPD: 5.520680904388428
Epoch [593/1000] - Loss: 0.0007; AAPD: 4.7029595375061035
Epoch [594/1000] - Loss: 0.0002; AAPD: 4.857153415679932
Epoch [595/1000] - Loss: 0.0010; AAPD: 4.820481777191162
Epoch [596/1000] - Loss: 0.001

Epoch [723/1000] - Loss: 0.0005; AAPD: 4.820993900299072
Epoch [724/1000] - Loss: 0.0014; AAPD: 5.105757713317871
Epoch [725/1000] - Loss: 0.0012; AAPD: 4.902627944946289
Epoch [726/1000] - Loss: 0.0013; AAPD: 5.520339488983154
Epoch [727/1000] - Loss: 0.0002; AAPD: 5.007190227508545
Epoch [728/1000] - Loss: 0.0015; AAPD: 5.371786117553711
Epoch [729/1000] - Loss: 0.0009; AAPD: 4.819275856018066
Epoch [730/1000] - Loss: 0.0017; AAPD: 4.703530788421631
Epoch [731/1000] - Loss: 0.0008; AAPD: 5.163214206695557
Epoch [732/1000] - Loss: 0.0010; AAPD: 5.3675642013549805
Epoch [733/1000] - Loss: 0.0003; AAPD: 5.208938121795654
Epoch [734/1000] - Loss: 0.0007; AAPD: 4.794325351715088
Epoch [735/1000] - Loss: 0.0015; AAPD: 5.360342025756836
Epoch [736/1000] - Loss: 0.0009; AAPD: 5.1046342849731445
Epoch [737/1000] - Loss: 0.0031; AAPD: 5.008069038391113
Epoch [738/1000] - Loss: 0.0005; AAPD: 4.618483066558838
Epoch [739/1000] - Loss: 0.0008; AAPD: 5.30181884765625
Epoch [740/1000] - Loss: 0.000

Epoch [867/1000] - Loss: 0.0012; AAPD: 5.000239849090576
Epoch [868/1000] - Loss: 0.0002; AAPD: 4.792843818664551
Epoch [869/1000] - Loss: 0.0022; AAPD: 4.971897125244141
Epoch [870/1000] - Loss: 0.0005; AAPD: 4.913881301879883
Epoch [871/1000] - Loss: 0.0007; AAPD: 4.912372589111328
Epoch [872/1000] - Loss: 0.0008; AAPD: 4.714375972747803
Epoch [873/1000] - Loss: 0.0007; AAPD: 4.953449249267578
Epoch [874/1000] - Loss: 0.0006; AAPD: 4.81397008895874
Epoch [875/1000] - Loss: 0.0006; AAPD: 4.874979496002197
Epoch [876/1000] - Loss: 0.0004; AAPD: 4.807338714599609
Epoch [877/1000] - Loss: 0.0004; AAPD: 4.77714729309082
Epoch [878/1000] - Loss: 0.0011; AAPD: 4.892035484313965
Epoch [879/1000] - Loss: 0.0012; AAPD: 4.831558704376221
Epoch [880/1000] - Loss: 0.0008; AAPD: 4.663914680480957
Epoch [881/1000] - Loss: 0.0013; AAPD: 4.8110671043396
Epoch [882/1000] - Loss: 0.0008; AAPD: 4.904122829437256
Epoch [883/1000] - Loss: 0.0005; AAPD: 4.8405327796936035
Epoch [884/1000] - Loss: 0.0011; A

In [33]:
model = IAModel(input_dimension=train_int.shape[1], output_dimension=train_out3.shape[1])
model_MSE_3 = train_model(model, 'best_model_MSE_3.pth', train_dataloader, dev_dataloader, target_idx=3)
torch.save(model_MSE_3.state_dict(), "model_MSE_3.pth")

Epoch [1/1000] - Loss: 0.0049; AAPD: 17.715662002563477
Epoch [2/1000] - Loss: 0.0006; AAPD: 15.953250885009766
Epoch [3/1000] - Loss: 0.0011; AAPD: 14.039657592773438
Epoch [4/1000] - Loss: 0.0002; AAPD: 12.542947769165039
Epoch [5/1000] - Loss: 0.0016; AAPD: 13.063459396362305
Epoch [6/1000] - Loss: 0.0032; AAPD: 13.249251365661621
Epoch [7/1000] - Loss: 0.0045; AAPD: 14.004016876220703
Epoch [8/1000] - Loss: 0.0012; AAPD: 13.558238983154297
Epoch [9/1000] - Loss: 0.0004; AAPD: 16.228723526000977
Epoch [10/1000] - Loss: 0.0002; AAPD: 13.929708480834961
Epoch [11/1000] - Loss: 0.3137; AAPD: 6096.43017578125
Epoch [12/1000] - Loss: 0.0005; AAPD: 13.089045524597168
Epoch [13/1000] - Loss: 0.0024; AAPD: 12.423871040344238
Epoch [14/1000] - Loss: 0.0012; AAPD: 86.07958221435547
Epoch [15/1000] - Loss: 0.0010; AAPD: 13.550871849060059
Epoch [16/1000] - Loss: 0.0007; AAPD: 12.391279220581055
Epoch [17/1000] - Loss: 0.0032; AAPD: 14.869194030761719
Epoch [18/1000] - Loss: 0.0010; AAPD: 14.40

Epoch [148/1000] - Loss: 0.0194; AAPD: 1831.3726806640625
Epoch [149/1000] - Loss: 0.0050; AAPD: 264.9447021484375
Epoch [150/1000] - Loss: 0.0007; AAPD: 30.807531356811523
Epoch [151/1000] - Loss: 0.0008; AAPD: 14.058629989624023
Epoch [152/1000] - Loss: 20.3803; AAPD: 57805.01171875
Epoch [153/1000] - Loss: 3.7120; AAPD: 24753.09375
Epoch [154/1000] - Loss: 0.3633; AAPD: 7782.2880859375
Epoch [155/1000] - Loss: 0.0155; AAPD: 1602.3114013671875
Epoch [156/1000] - Loss: 0.0004; AAPD: 187.5637664794922
Epoch [157/1000] - Loss: 0.0012; AAPD: 19.892940521240234
Epoch [158/1000] - Loss: 0.0005; AAPD: 13.204229354858398
Epoch [159/1000] - Loss: 1.1745; AAPD: 11782.2626953125
Epoch [160/1000] - Loss: 0.0004; AAPD: 108.16059875488281
Epoch [161/1000] - Loss: 0.0865; AAPD: 1705.5130615234375
Epoch [162/1000] - Loss: 0.0012; AAPD: 13.51927661895752
Epoch [163/1000] - Loss: 0.0004; AAPD: 80.60726928710938
Epoch [164/1000] - Loss: 0.1756; AAPD: 5428.85986328125
Epoch [165/1000] - Loss: 0.0039; AA

Epoch [292/1000] - Loss: 0.0026; AAPD: 17.920082092285156
Epoch [293/1000] - Loss: 0.0042; AAPD: 15.069082260131836
Epoch [294/1000] - Loss: 0.0008; AAPD: 24.375606536865234
Epoch [295/1000] - Loss: 0.0009; AAPD: 17.476369857788086
Epoch [296/1000] - Loss: 0.0009; AAPD: 15.856122016906738
Epoch [297/1000] - Loss: 0.0017; AAPD: 23.10577964782715
Epoch [298/1000] - Loss: 0.0003; AAPD: 13.961724281311035
Epoch [299/1000] - Loss: 0.0006; AAPD: 24.27754783630371
Epoch [300/1000] - Loss: 0.0025; AAPD: 23.80522346496582
Epoch [301/1000] - Loss: 0.0023; AAPD: 41.92120361328125
Epoch [302/1000] - Loss: 0.0004; AAPD: 30.548381805419922
Epoch [303/1000] - Loss: 0.0026; AAPD: 45.30094528198242
Epoch [304/1000] - Loss: 0.0003; AAPD: 24.268489837646484
Epoch [305/1000] - Loss: 0.0047; AAPD: 41.84573745727539
Epoch [306/1000] - Loss: 0.0003; AAPD: 21.247032165527344
Epoch [307/1000] - Loss: 0.0010; AAPD: 27.403104782104492
Epoch [308/1000] - Loss: 0.0004; AAPD: 19.38285255432129
Epoch [309/1000] - Lo

Epoch [435/1000] - Loss: 0.0017; AAPD: 30.77313995361328
Epoch [436/1000] - Loss: 0.0021; AAPD: 29.763059616088867
Epoch [437/1000] - Loss: 0.0008; AAPD: 25.427602767944336
Epoch [438/1000] - Loss: 0.0010; AAPD: 25.323894500732422
Epoch [439/1000] - Loss: 0.0013; AAPD: 30.44930648803711
Epoch [440/1000] - Loss: 0.0021; AAPD: 27.6168212890625
Epoch [441/1000] - Loss: 0.0020; AAPD: 40.56349563598633
Epoch [442/1000] - Loss: 0.0006; AAPD: 18.95874786376953
Epoch [443/1000] - Loss: 0.0073; AAPD: 40.4547004699707
Epoch [444/1000] - Loss: 0.0040; AAPD: 63.844573974609375
Epoch [445/1000] - Loss: 0.0003; AAPD: 26.93996238708496
Epoch [446/1000] - Loss: 0.0025; AAPD: 25.650663375854492
Epoch [447/1000] - Loss: 0.0014; AAPD: 40.15703582763672
Epoch [448/1000] - Loss: 0.0008; AAPD: 23.518346786499023
Epoch [449/1000] - Loss: 0.0017; AAPD: 38.866844177246094
Epoch [450/1000] - Loss: 0.0038; AAPD: 35.6757698059082
Epoch [451/1000] - Loss: 0.0011; AAPD: 23.539737701416016
Epoch [452/1000] - Loss: 0

Epoch [578/1000] - Loss: 0.0017; AAPD: 19.638011932373047
Epoch [579/1000] - Loss: 0.0008; AAPD: 16.700782775878906
Epoch [580/1000] - Loss: 0.0019; AAPD: 16.17696189880371
Epoch [581/1000] - Loss: 0.0012; AAPD: 17.155532836914062
Epoch [582/1000] - Loss: 0.0003; AAPD: 15.319125175476074
Epoch [583/1000] - Loss: 0.0037; AAPD: 19.223979949951172
Epoch [584/1000] - Loss: 0.0004; AAPD: 17.984418869018555
Epoch [585/1000] - Loss: 0.0006; AAPD: 15.781944274902344
Epoch [586/1000] - Loss: 0.0001; AAPD: 15.022562980651855
Epoch [587/1000] - Loss: 0.0051; AAPD: 21.754812240600586
Epoch [588/1000] - Loss: 0.0003; AAPD: 17.735036849975586
Epoch [589/1000] - Loss: 0.0018; AAPD: 21.798030853271484
Epoch [590/1000] - Loss: 0.0002; AAPD: 15.194485664367676
Epoch [591/1000] - Loss: 0.0018; AAPD: 18.335649490356445
Epoch [592/1000] - Loss: 0.0005; AAPD: 17.784162521362305
Epoch [593/1000] - Loss: 0.0016; AAPD: 16.35139274597168
Epoch [594/1000] - Loss: 0.0012; AAPD: 14.301718711853027
Epoch [595/1000]

Epoch [721/1000] - Loss: 0.0005; AAPD: 15.909965515136719
Epoch [722/1000] - Loss: 0.0011; AAPD: 16.913755416870117
Epoch [723/1000] - Loss: 0.0008; AAPD: 13.85119915008545
Epoch [724/1000] - Loss: 0.0017; AAPD: 16.793134689331055
Epoch [725/1000] - Loss: 0.0013; AAPD: 16.30617332458496
Epoch [726/1000] - Loss: 0.0005; AAPD: 13.736793518066406
Epoch [727/1000] - Loss: 0.0025; AAPD: 15.133624076843262
Epoch [728/1000] - Loss: 0.0033; AAPD: 17.302627563476562
Epoch [729/1000] - Loss: 0.0008; AAPD: 15.4503173828125
Epoch [730/1000] - Loss: 0.0007; AAPD: 18.160709381103516
Epoch [731/1000] - Loss: 0.0018; AAPD: 19.385286331176758
Epoch [732/1000] - Loss: 0.0010; AAPD: 20.85465431213379
Epoch [733/1000] - Loss: 0.0006; AAPD: 19.428565979003906
Epoch [734/1000] - Loss: 0.0012; AAPD: 15.49794864654541
Epoch [735/1000] - Loss: 0.0002; AAPD: 16.931970596313477
Epoch [736/1000] - Loss: 0.0003; AAPD: 17.471017837524414
Epoch [737/1000] - Loss: 0.0008; AAPD: 16.86298179626465
Epoch [738/1000] - Lo

Epoch [863/1000] - Loss: 0.0013; AAPD: 13.302657127380371
Epoch [864/1000] - Loss: 0.0015; AAPD: 13.57000732421875
Epoch [865/1000] - Loss: 0.0004; AAPD: 14.815221786499023
Epoch [866/1000] - Loss: 0.0017; AAPD: 13.388447761535645
Epoch [867/1000] - Loss: 0.0002; AAPD: 13.974649429321289
Epoch [868/1000] - Loss: 0.0024; AAPD: 13.653751373291016
Epoch [869/1000] - Loss: 0.0005; AAPD: 12.850127220153809
Epoch [870/1000] - Loss: 0.0002; AAPD: 13.80860710144043
Epoch [871/1000] - Loss: 0.0003; AAPD: 13.99897575378418
Epoch [872/1000] - Loss: 0.0016; AAPD: 14.016119956970215
Epoch [873/1000] - Loss: 0.0003; AAPD: 12.519407272338867
Epoch [874/1000] - Loss: 0.0014; AAPD: 13.612879753112793
Epoch [875/1000] - Loss: 0.0004; AAPD: 12.965483665466309
Epoch [876/1000] - Loss: 0.0002; AAPD: 14.750652313232422
Epoch [877/1000] - Loss: 0.0010; AAPD: 13.319908142089844
Epoch [878/1000] - Loss: 0.0023; AAPD: 13.535484313964844
Epoch [879/1000] - Loss: 0.0009; AAPD: 14.027434349060059
Epoch [880/1000] 

In [37]:
class IAModelLeakyReLU(nn.Module):
    def __init__(self, input_dimension=7, output_dimension=20):
        super(IAModelLeakyReLU, self).__init__()
        
        self.input_dimension = input_dimension
        self.output_dimension = output_dimension
        
        self.network = nn.Sequential(
            nn.Linear(input_dimension, 256),
            nn.LeakyReLU(),
            nn.Linear(256, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 64),
            nn.LeakyReLU(),
            nn.Linear(64, output_dimension)
        )

    def forward(self, x):
        x = self.network(x)
        return x

In [46]:
def train_model(model, model_name, train_loader, dev_loader, target_idx=1, num_epochs=1500, learning_rate=0.1, weight_decay=1e-4):
    #criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[250, 500, 750, 1000], gamma=0.1)
    
    best_model = None
    best_aapd = None

    for epoch in range(num_epochs):
        model.train()
        for batch in train_loader:
            input_data = batch['input_data']
            target1 = batch['out1']
            target2 = batch['out2']
            target3 = batch['out3']
            
            optimizer.zero_grad()
            outputs = model(input_data)
            if target_idx == 1:
                train_loss = AAPD(outputs, target1)
            elif target_idx == 2:
                train_loss = AAPD(outputs, target2)
            elif target_idx == 3:
                train_loss = AAPD(outputs, target3)
            else:
                print("Invalid target_idx")
            
            train_loss.backward()
            optimizer.step()
        scheduler.step()
        
        model.eval()
        
        total_aapd = 0
        num_samples = 0
        
        with torch.no_grad():
            for batch in dev_loader:
                dev_int = batch['input_data']
                dev_out1 = batch['out1']
                dev_out2 = batch['out2']
                dev_out3 = batch['out3']
                
                predictions = model(dev_int)
                
                if target_idx == 1:
                    batch_aapd = AAPD(predictions, dev_out1)
                elif target_idx == 2:
                    batch_aapd = AAPD(predictions, dev_out2)
                elif target_idx == 3:
                    batch_aapd = AAPD(predictions, dev_out3)
                else:
                    print("Invalid target_idx")
                    
                total_aapd += batch_aapd * len(dev_int)
                num_samples += len(dev_int)
            
            aapd = total_aapd / num_samples
            
            if best_aapd == None or aapd < best_aapd:
                best_aapd = aapd
                best_model = model
                torch.save(model.state_dict(), model_name)
        
        print(f'Epoch [{epoch + 1}/{num_epochs}] - Loss: {train_loss.item():.4f}; AAPD: {aapd}')

    return model

In [39]:
model = IAModelLeakyReLU(input_dimension=train_int.shape[1], output_dimension=train_out1.shape[1])
best_model_Leaky_1 = train_model(model, 'best_model_Leaky_1.pth', train_dataloader, dev_dataloader, target_idx=1)

Epoch [1/1500] - Loss: 1408.9928; AAPD: 1.0124073028564453
Epoch [2/1500] - Loss: 755.0059; AAPD: 0.9566273093223572
Epoch [3/1500] - Loss: 1980.0990; AAPD: 0.8984865546226501
Epoch [4/1500] - Loss: 626.9908; AAPD: 0.9198014140129089
Epoch [5/1500] - Loss: 1026.9822; AAPD: 0.9435927867889404
Epoch [6/1500] - Loss: 765.2798; AAPD: 0.9289637804031372
Epoch [7/1500] - Loss: 1938.3007; AAPD: 9.364928245544434
Epoch [8/1500] - Loss: 1251.3826; AAPD: 8.416902542114258
Epoch [9/1500] - Loss: 31327.2598; AAPD: 87.89869689941406
Epoch [10/1500] - Loss: 2573.5122; AAPD: 3.1519908905029297
Epoch [11/1500] - Loss: 1544.2535; AAPD: 2.890791654586792
Epoch [12/1500] - Loss: 984.5840; AAPD: 2.604640483856201
Epoch [13/1500] - Loss: 775.6949; AAPD: 2.8164944648742676
Epoch [14/1500] - Loss: 672.7597; AAPD: 3.2110588550567627
Epoch [15/1500] - Loss: 744.5822; AAPD: 5.694701194763184
Epoch [16/1500] - Loss: 726.8362; AAPD: 23.96734046936035
Epoch [17/1500] - Loss: 1798.9641; AAPD: 2.680182933807373
Epoc

Epoch [138/1500] - Loss: 962.9860; AAPD: 3.743687868118286
Epoch [139/1500] - Loss: 18663.6309; AAPD: 83.80780029296875
Epoch [140/1500] - Loss: 10689.8730; AAPD: 75.25505828857422
Epoch [141/1500] - Loss: 25350.3848; AAPD: 60.88768768310547
Epoch [142/1500] - Loss: 12140.3457; AAPD: 47.98912811279297
Epoch [143/1500] - Loss: 5415.5957; AAPD: 37.23966979980469
Epoch [144/1500] - Loss: 4855.0757; AAPD: 37.81318283081055
Epoch [145/1500] - Loss: 2833.9485; AAPD: 33.70315933227539
Epoch [146/1500] - Loss: 2736.1528; AAPD: 32.02388000488281
Epoch [147/1500] - Loss: 2700.7683; AAPD: 25.47076988220215
Epoch [148/1500] - Loss: 55576.5156; AAPD: 43.707542419433594
Epoch [149/1500] - Loss: 6706203.0000; AAPD: 1442.70556640625
Epoch [150/1500] - Loss: 1526.7791; AAPD: 17.246084213256836
Epoch [151/1500] - Loss: 40997.7109; AAPD: 145.97088623046875
Epoch [152/1500] - Loss: 1924.5625; AAPD: 19.41280174255371
Epoch [153/1500] - Loss: 5049.6436; AAPD: 26.442880630493164
Epoch [154/1500] - Loss: 6313

Epoch [273/1500] - Loss: 2727.4600; AAPD: 17.322193145751953
Epoch [274/1500] - Loss: 1083.0120; AAPD: 6.4529924392700195
Epoch [275/1500] - Loss: 6092.8687; AAPD: 23.646652221679688
Epoch [276/1500] - Loss: 638.0378; AAPD: 7.082281589508057
Epoch [277/1500] - Loss: 1906.5015; AAPD: 12.390589714050293
Epoch [278/1500] - Loss: 4454.7402; AAPD: 33.341495513916016
Epoch [279/1500] - Loss: 2188.1846; AAPD: 6.662670612335205
Epoch [280/1500] - Loss: 1724.9066; AAPD: 4.687750339508057
Epoch [281/1500] - Loss: 6446.9946; AAPD: 6.974724769592285
Epoch [282/1500] - Loss: 2985.0962; AAPD: 5.683686256408691
Epoch [283/1500] - Loss: 46635.1992; AAPD: 104.23685455322266
Epoch [284/1500] - Loss: 1529.3458; AAPD: 6.523567199707031
Epoch [285/1500] - Loss: 10666.6250; AAPD: 104.13672637939453
Epoch [286/1500] - Loss: 2849.8579; AAPD: 21.401472091674805
Epoch [287/1500] - Loss: 1213.9800; AAPD: 5.027349948883057
Epoch [288/1500] - Loss: 2123.7803; AAPD: 3.063464879989624
Epoch [289/1500] - Loss: 2024.7

Epoch [409/1500] - Loss: 2104.3716; AAPD: 6.368396759033203
Epoch [410/1500] - Loss: 2483.0845; AAPD: 32.33606719970703
Epoch [411/1500] - Loss: 1131.8510; AAPD: 7.972255229949951
Epoch [412/1500] - Loss: 1531.6359; AAPD: 7.407702445983887
Epoch [413/1500] - Loss: 1718.0768; AAPD: 2.9024131298065186
Epoch [414/1500] - Loss: 1034.3879; AAPD: 3.819288969039917
Epoch [415/1500] - Loss: 1277.3956; AAPD: 3.598477363586426
Epoch [416/1500] - Loss: 1296.7590; AAPD: 3.870450735092163
Epoch [417/1500] - Loss: 971.1004; AAPD: 7.774872303009033
Epoch [418/1500] - Loss: 1285.1774; AAPD: 6.689743518829346
Epoch [419/1500] - Loss: 2676.0071; AAPD: 17.437976837158203
Epoch [420/1500] - Loss: 149.7899; AAPD: 2.8375754356384277
Epoch [421/1500] - Loss: 1595.8134; AAPD: 3.558363676071167
Epoch [422/1500] - Loss: 569.6684; AAPD: 6.000649929046631
Epoch [423/1500] - Loss: 1407.6082; AAPD: 14.43757152557373
Epoch [424/1500] - Loss: 1896.1393; AAPD: 4.109669208526611
Epoch [425/1500] - Loss: 1133.1644; AAPD

Epoch [546/1500] - Loss: 1249.0255; AAPD: 2.3280773162841797
Epoch [547/1500] - Loss: 1178.9812; AAPD: 3.905595064163208
Epoch [548/1500] - Loss: 932.0612; AAPD: 2.7952823638916016
Epoch [549/1500] - Loss: 1052.0879; AAPD: 3.2685561180114746
Epoch [550/1500] - Loss: 753.9636; AAPD: 1.5931169986724854
Epoch [551/1500] - Loss: 2607.8921; AAPD: 1.4609191417694092
Epoch [552/1500] - Loss: 758.3478; AAPD: 1.9920408725738525
Epoch [553/1500] - Loss: 707.6299; AAPD: 1.7458622455596924
Epoch [554/1500] - Loss: 607.2743; AAPD: 2.3054392337799072
Epoch [555/1500] - Loss: 1387.3640; AAPD: 2.2644472122192383
Epoch [556/1500] - Loss: 1272.1022; AAPD: 1.5373116731643677
Epoch [557/1500] - Loss: 1355.8868; AAPD: 1.773686170578003
Epoch [558/1500] - Loss: 266.5207; AAPD: 1.7613898515701294
Epoch [559/1500] - Loss: 413.3652; AAPD: 1.6248443126678467
Epoch [560/1500] - Loss: 634.1309; AAPD: 3.0570621490478516
Epoch [561/1500] - Loss: 2314.7378; AAPD: 1.9683918952941895
Epoch [562/1500] - Loss: 555.7234;

Epoch [683/1500] - Loss: 709.0383; AAPD: 1.0745521783828735
Epoch [684/1500] - Loss: 1400.3579; AAPD: 1.320153832435608
Epoch [685/1500] - Loss: 1819.1622; AAPD: 5.370154857635498
Epoch [686/1500] - Loss: 289.2074; AAPD: 1.1933690309524536
Epoch [687/1500] - Loss: 1076.1466; AAPD: 1.1263700723648071
Epoch [688/1500] - Loss: 179.2104; AAPD: 1.3733831644058228
Epoch [689/1500] - Loss: 361.4209; AAPD: 1.4614858627319336
Epoch [690/1500] - Loss: 559.0733; AAPD: 2.1956899166107178
Epoch [691/1500] - Loss: 574.6062; AAPD: 2.8083667755126953
Epoch [692/1500] - Loss: 1379.7107; AAPD: 1.7318453788757324
Epoch [693/1500] - Loss: 528.4650; AAPD: 1.065469741821289
Epoch [694/1500] - Loss: 301.4704; AAPD: 1.4347821474075317
Epoch [695/1500] - Loss: 428.5434; AAPD: 1.339286208152771
Epoch [696/1500] - Loss: 512.4085; AAPD: 1.3576103448867798
Epoch [697/1500] - Loss: 975.2743; AAPD: 1.0967828035354614
Epoch [698/1500] - Loss: 339.2681; AAPD: 1.9476450681686401
Epoch [699/1500] - Loss: 1192.8365; AAPD

Epoch [820/1500] - Loss: 1157.5096; AAPD: 1.16153085231781
Epoch [821/1500] - Loss: 532.8527; AAPD: 0.9752354621887207
Epoch [822/1500] - Loss: 153.6945; AAPD: 0.8824145793914795
Epoch [823/1500] - Loss: 376.5278; AAPD: 1.2032655477523804
Epoch [824/1500] - Loss: 324.0952; AAPD: 0.9056012034416199
Epoch [825/1500] - Loss: 322.9359; AAPD: 0.8963425755500793
Epoch [826/1500] - Loss: 98.3895; AAPD: 0.8794329166412354
Epoch [827/1500] - Loss: 194.1698; AAPD: 0.9131422638893127
Epoch [828/1500] - Loss: 1154.3989; AAPD: 0.9731797575950623
Epoch [829/1500] - Loss: 232.2430; AAPD: 0.910637378692627
Epoch [830/1500] - Loss: 973.4021; AAPD: 0.8752149939537048
Epoch [831/1500] - Loss: 1730.2615; AAPD: 0.9743326902389526
Epoch [832/1500] - Loss: 314.1222; AAPD: 1.288447618484497
Epoch [833/1500] - Loss: 2291.7488; AAPD: 0.9311824440956116
Epoch [834/1500] - Loss: 461.5597; AAPD: 0.9016733765602112
Epoch [835/1500] - Loss: 776.1939; AAPD: 0.9532893300056458
Epoch [836/1500] - Loss: 416.5143; AAPD: 

Epoch [957/1500] - Loss: 575.0059; AAPD: 0.8604675531387329
Epoch [958/1500] - Loss: 268.4527; AAPD: 0.7939595580101013
Epoch [959/1500] - Loss: 551.5633; AAPD: 0.7826202511787415
Epoch [960/1500] - Loss: 469.5785; AAPD: 0.7820615172386169
Epoch [961/1500] - Loss: 274.8845; AAPD: 0.7754691243171692
Epoch [962/1500] - Loss: 296.2029; AAPD: 0.8082372546195984
Epoch [963/1500] - Loss: 391.7242; AAPD: 0.7718145251274109
Epoch [964/1500] - Loss: 277.9066; AAPD: 0.7748505473136902
Epoch [965/1500] - Loss: 738.8115; AAPD: 0.7771443128585815
Epoch [966/1500] - Loss: 127.7799; AAPD: 0.747277557849884
Epoch [967/1500] - Loss: 575.6203; AAPD: 0.8503099679946899
Epoch [968/1500] - Loss: 690.8364; AAPD: 0.7922828197479248
Epoch [969/1500] - Loss: 410.2267; AAPD: 0.7987715005874634
Epoch [970/1500] - Loss: 188.4978; AAPD: 0.981598973274231
Epoch [971/1500] - Loss: 591.8449; AAPD: 0.9644724130630493
Epoch [972/1500] - Loss: 523.9845; AAPD: 0.8823390603065491
Epoch [973/1500] - Loss: 449.0052; AAPD: 0

Epoch [1093/1500] - Loss: 1257.0774; AAPD: 0.6962250471115112
Epoch [1094/1500] - Loss: 360.3385; AAPD: 0.6925404071807861
Epoch [1095/1500] - Loss: 96.1359; AAPD: 0.6970527172088623
Epoch [1096/1500] - Loss: 345.5099; AAPD: 0.6966604590415955
Epoch [1097/1500] - Loss: 192.8344; AAPD: 0.7001906633377075
Epoch [1098/1500] - Loss: 1212.2151; AAPD: 0.6998129487037659
Epoch [1099/1500] - Loss: 1505.6915; AAPD: 0.6976643204689026
Epoch [1100/1500] - Loss: 130.7460; AAPD: 0.694729208946228
Epoch [1101/1500] - Loss: 51.3026; AAPD: 0.6957619190216064
Epoch [1102/1500] - Loss: 166.2635; AAPD: 0.6948133111000061
Epoch [1103/1500] - Loss: 533.0460; AAPD: 0.6953946352005005
Epoch [1104/1500] - Loss: 348.2023; AAPD: 0.6942170858383179
Epoch [1105/1500] - Loss: 453.1249; AAPD: 0.6981770992279053
Epoch [1106/1500] - Loss: 1101.1625; AAPD: 0.6954160928726196
Epoch [1107/1500] - Loss: 966.9073; AAPD: 0.6975753307342529
Epoch [1108/1500] - Loss: 936.9874; AAPD: 0.6916359663009644
Epoch [1109/1500] - Los

Epoch [1228/1500] - Loss: 445.0035; AAPD: 0.6829526424407959
Epoch [1229/1500] - Loss: 136.8635; AAPD: 0.6889299750328064
Epoch [1230/1500] - Loss: 828.1485; AAPD: 0.686847448348999
Epoch [1231/1500] - Loss: 193.3255; AAPD: 0.7038301229476929
Epoch [1232/1500] - Loss: 520.8900; AAPD: 0.6875303387641907
Epoch [1233/1500] - Loss: 1515.6949; AAPD: 0.6911891102790833
Epoch [1234/1500] - Loss: 763.4940; AAPD: 0.6825008988380432
Epoch [1235/1500] - Loss: 57.5115; AAPD: 0.6886513829231262
Epoch [1236/1500] - Loss: 280.4817; AAPD: 0.6857746839523315
Epoch [1237/1500] - Loss: 1411.6646; AAPD: 0.6932902932167053
Epoch [1238/1500] - Loss: 532.9288; AAPD: 0.696721613407135
Epoch [1239/1500] - Loss: 1233.4646; AAPD: 0.6820067167282104
Epoch [1240/1500] - Loss: 313.5257; AAPD: 0.6821889877319336
Epoch [1241/1500] - Loss: 341.5179; AAPD: 0.7016140222549438
Epoch [1242/1500] - Loss: 461.2422; AAPD: 0.7129794359207153
Epoch [1243/1500] - Loss: 175.8960; AAPD: 0.7083512544631958
Epoch [1244/1500] - Loss

Epoch [1363/1500] - Loss: 148.5069; AAPD: 0.6735135316848755
Epoch [1364/1500] - Loss: 73.7722; AAPD: 0.6800652742385864
Epoch [1365/1500] - Loss: 310.8080; AAPD: 0.6795627474784851
Epoch [1366/1500] - Loss: 623.0540; AAPD: 0.6762030124664307
Epoch [1367/1500] - Loss: 761.4561; AAPD: 0.678109884262085
Epoch [1368/1500] - Loss: 1118.6655; AAPD: 0.6955458521842957
Epoch [1369/1500] - Loss: 349.1947; AAPD: 0.6722779273986816
Epoch [1370/1500] - Loss: 572.8724; AAPD: 0.6719024181365967
Epoch [1371/1500] - Loss: 634.9039; AAPD: 0.6717768311500549
Epoch [1372/1500] - Loss: 995.5770; AAPD: 0.6796193718910217
Epoch [1373/1500] - Loss: 1077.9473; AAPD: 0.6818490028381348
Epoch [1374/1500] - Loss: 1175.9530; AAPD: 0.6792941689491272
Epoch [1375/1500] - Loss: 471.3090; AAPD: 0.6749935746192932
Epoch [1376/1500] - Loss: 490.0556; AAPD: 0.6741899847984314
Epoch [1377/1500] - Loss: 294.5427; AAPD: 0.6716718673706055
Epoch [1378/1500] - Loss: 129.5813; AAPD: 0.672503650188446
Epoch [1379/1500] - Loss

Epoch [1498/1500] - Loss: 65.8440; AAPD: 0.6652002930641174
Epoch [1499/1500] - Loss: 209.9812; AAPD: 0.6695204973220825
Epoch [1500/1500] - Loss: 522.1546; AAPD: 0.6744105815887451


In [40]:
model = IAModelLeakyReLU(input_dimension=train_int.shape[1], output_dimension=train_out2.shape[1])
best_model_Leaky_2 = train_model(model, 'best_model_Leaky_2.pth', train_dataloader, dev_dataloader, target_idx=2)

Epoch [1/1500] - Loss: 3.3100; AAPD: 86.77537536621094
Epoch [2/1500] - Loss: 3.5555; AAPD: 74.29946899414062
Epoch [3/1500] - Loss: 6.7744; AAPD: 132.5123748779297
Epoch [4/1500] - Loss: 4.8761; AAPD: 94.66427612304688
Epoch [5/1500] - Loss: 27.0023; AAPD: 86.61015319824219
Epoch [6/1500] - Loss: 5496130048.0000; AAPD: 3222097.5
Epoch [7/1500] - Loss: 176403056.0000; AAPD: 374867.375
Epoch [8/1500] - Loss: 15986541.0000; AAPD: 42213.5625
Epoch [9/1500] - Loss: 6762357.5000; AAPD: 135891.140625
Epoch [10/1500] - Loss: 60177.2070; AAPD: 7481.4921875
Epoch [11/1500] - Loss: 2148179.5000; AAPD: 33579.6875
Epoch [12/1500] - Loss: 16828.6523; AAPD: 6861.2099609375
Epoch [13/1500] - Loss: 4271168000.0000; AAPD: 295797.3125
Epoch [14/1500] - Loss: 3092186.7500; AAPD: 79714.1953125
Epoch [15/1500] - Loss: 550030592.0000; AAPD: 530518.1875
Epoch [16/1500] - Loss: 20128886784.0000; AAPD: 617005.125
Epoch [17/1500] - Loss: 26859148.0000; AAPD: 143682.609375
Epoch [18/1500] - Loss: 660187.0000; AA

Epoch [143/1500] - Loss: 1214849024.0000; AAPD: 876099.3125
Epoch [144/1500] - Loss: 23497830.0000; AAPD: 56297.0625
Epoch [145/1500] - Loss: 2749851136.0000; AAPD: 1076126.25
Epoch [146/1500] - Loss: 162816608.0000; AAPD: 422492.6875
Epoch [147/1500] - Loss: 146896683008.0000; AAPD: 11580297.0
Epoch [148/1500] - Loss: 1898361.8750; AAPD: 42054.4296875
Epoch [149/1500] - Loss: 28593246.0000; AAPD: 85149.25
Epoch [150/1500] - Loss: 526672.1250; AAPD: 25979.99609375
Epoch [151/1500] - Loss: 8714222.0000; AAPD: 122514.625
Epoch [152/1500] - Loss: 542863.1875; AAPD: 27312.50390625
Epoch [153/1500] - Loss: 4253202.5000; AAPD: 48182.609375
Epoch [154/1500] - Loss: 8794527.0000; AAPD: 73780.4375
Epoch [155/1500] - Loss: 1594176.7500; AAPD: 26748.51953125
Epoch [156/1500] - Loss: 1582652.6250; AAPD: 26672.509765625
Epoch [157/1500] - Loss: 4164248.5000; AAPD: 91760.3359375
Epoch [158/1500] - Loss: 974547.6250; AAPD: 23998.865234375
Epoch [159/1500] - Loss: 893338.4375; AAPD: 68560.5546875
Epoc

Epoch [283/1500] - Loss: 131341928.0000; AAPD: 365297.1875
Epoch [284/1500] - Loss: 361009952.0000; AAPD: 911443.25
Epoch [285/1500] - Loss: 20559685632.0000; AAPD: 5462793.0
Epoch [286/1500] - Loss: 2122539520.0000; AAPD: 1115259.625
Epoch [287/1500] - Loss: 255119968.0000; AAPD: 507602.53125
Epoch [288/1500] - Loss: 22224552.0000; AAPD: 324360.28125
Epoch [289/1500] - Loss: 33364382.0000; AAPD: 326379.65625
Epoch [290/1500] - Loss: 46303676.0000; AAPD: 374809.15625
Epoch [291/1500] - Loss: 40556440.0000; AAPD: 221014.859375
Epoch [292/1500] - Loss: 315293664.0000; AAPD: 608777.0625
Epoch [293/1500] - Loss: 2491735296.0000; AAPD: 699785.875
Epoch [294/1500] - Loss: 186440480.0000; AAPD: 192116.421875
Epoch [295/1500] - Loss: 109920928.0000; AAPD: 150326.765625
Epoch [296/1500] - Loss: 12935305.0000; AAPD: 136030.0
Epoch [297/1500] - Loss: 83199776.0000; AAPD: 159488.625
Epoch [298/1500] - Loss: 392756128.0000; AAPD: 210245.234375
Epoch [299/1500] - Loss: 719478848.0000; AAPD: 252777.7

Epoch [422/1500] - Loss: 723984.5625; AAPD: 20997.478515625
Epoch [423/1500] - Loss: 6616103.5000; AAPD: 37674.48828125
Epoch [424/1500] - Loss: 9863500.0000; AAPD: 35102.19140625
Epoch [425/1500] - Loss: 184171.3594; AAPD: 16430.978515625
Epoch [426/1500] - Loss: 6378306.0000; AAPD: 133406.546875
Epoch [427/1500] - Loss: 6733145.0000; AAPD: 40877.39453125
Epoch [428/1500] - Loss: 1665658.2500; AAPD: 75829.2578125
Epoch [429/1500] - Loss: 2571171.0000; AAPD: 46516.0859375
Epoch [430/1500] - Loss: 375242.7188; AAPD: 20858.091796875
Epoch [431/1500] - Loss: 307401.6250; AAPD: 18349.376953125
Epoch [432/1500] - Loss: 11456347.0000; AAPD: 65444.4375
Epoch [433/1500] - Loss: 572444.1875; AAPD: 29988.8125
Epoch [434/1500] - Loss: 2248319232.0000; AAPD: 84703.390625
Epoch [435/1500] - Loss: 669065.7500; AAPD: 26779.828125
Epoch [436/1500] - Loss: 765827.2500; AAPD: 19430.853515625
Epoch [437/1500] - Loss: 145680.5625; AAPD: 22605.298828125
Epoch [438/1500] - Loss: 152512.8906; AAPD: 22755.257

Epoch [560/1500] - Loss: 65771.8828; AAPD: 8400.7373046875
Epoch [561/1500] - Loss: 76073.8906; AAPD: 11786.5625
Epoch [562/1500] - Loss: 62759.9258; AAPD: 8451.9375
Epoch [563/1500] - Loss: 38445.2148; AAPD: 10875.59765625
Epoch [564/1500] - Loss: 89454.9766; AAPD: 8931.482421875
Epoch [565/1500] - Loss: 1054812.3750; AAPD: 15749.2802734375
Epoch [566/1500] - Loss: 155548.5938; AAPD: 7193.57666015625
Epoch [567/1500] - Loss: 323919.4375; AAPD: 7461.01611328125
Epoch [568/1500] - Loss: 239412.4531; AAPD: 8397.0263671875
Epoch [569/1500] - Loss: 55129.8789; AAPD: 10294.4013671875
Epoch [570/1500] - Loss: 35488.0234; AAPD: 11599.0888671875
Epoch [571/1500] - Loss: 180676.0781; AAPD: 8267.076171875
Epoch [572/1500] - Loss: 89187.1328; AAPD: 8208.1318359375
Epoch [573/1500] - Loss: 520323.0938; AAPD: 11221.94921875
Epoch [574/1500] - Loss: 2207756.2500; AAPD: 6682.66650390625
Epoch [575/1500] - Loss: 269142.8125; AAPD: 17118.296875
Epoch [576/1500] - Loss: 94911.0781; AAPD: 8109.2470703125

Epoch [698/1500] - Loss: 123941.6953; AAPD: 5875.12109375
Epoch [699/1500] - Loss: 55421.3516; AAPD: 3713.516845703125
Epoch [700/1500] - Loss: 189744.3906; AAPD: 3659.453857421875
Epoch [701/1500] - Loss: 79058.9688; AAPD: 5951.38427734375
Epoch [702/1500] - Loss: 3958161.2500; AAPD: 10511.9619140625
Epoch [703/1500] - Loss: 278086.2812; AAPD: 6570.6591796875
Epoch [704/1500] - Loss: 184498.7188; AAPD: 4492.3037109375
Epoch [705/1500] - Loss: 327749.0000; AAPD: 3241.752197265625
Epoch [706/1500] - Loss: 47599.0977; AAPD: 3674.477783203125
Epoch [707/1500] - Loss: 33129.0977; AAPD: 3367.392822265625
Epoch [708/1500] - Loss: 28179.8711; AAPD: 6344.56494140625
Epoch [709/1500] - Loss: 51394.7734; AAPD: 4774.7099609375
Epoch [710/1500] - Loss: 10799.6533; AAPD: 5734.22998046875
Epoch [711/1500] - Loss: 37378.7305; AAPD: 11717.791015625
Epoch [712/1500] - Loss: 387344.9062; AAPD: 7032.35595703125
Epoch [713/1500] - Loss: 101179.2109; AAPD: 6847.98388671875
Epoch [714/1500] - Loss: 322831.5

Epoch [835/1500] - Loss: 1070.5988; AAPD: 1207.8504638671875
Epoch [836/1500] - Loss: 1407.3049; AAPD: 1182.98876953125
Epoch [837/1500] - Loss: 6654.1655; AAPD: 1709.018310546875
Epoch [838/1500] - Loss: 512.9598; AAPD: 1315.957275390625
Epoch [839/1500] - Loss: 3075.2388; AAPD: 1445.910400390625
Epoch [840/1500] - Loss: 4053.6670; AAPD: 1414.8006591796875
Epoch [841/1500] - Loss: 1642.2335; AAPD: 1708.26416015625
Epoch [842/1500] - Loss: 1187.6619; AAPD: 1235.89501953125
Epoch [843/1500] - Loss: 5464.1450; AAPD: 1224.8778076171875
Epoch [844/1500] - Loss: 14679.7881; AAPD: 1192.9398193359375
Epoch [845/1500] - Loss: 1568.7913; AAPD: 1582.2724609375
Epoch [846/1500] - Loss: 104435.0156; AAPD: 1808.465576171875
Epoch [847/1500] - Loss: 2502.9963; AAPD: 1299.2376708984375
Epoch [848/1500] - Loss: 5759.6626; AAPD: 1210.2335205078125
Epoch [849/1500] - Loss: 3319.1792; AAPD: 1655.444091796875
Epoch [850/1500] - Loss: 344.4425; AAPD: 1224.326171875
Epoch [851/1500] - Loss: 942.6127; AAPD: 

Epoch [972/1500] - Loss: 442.2787; AAPD: 932.561767578125
Epoch [973/1500] - Loss: 1729.4279; AAPD: 1029.292724609375
Epoch [974/1500] - Loss: 609.1660; AAPD: 1079.591796875
Epoch [975/1500] - Loss: 1241.2207; AAPD: 1062.127197265625
Epoch [976/1500] - Loss: 882.3080; AAPD: 963.181640625
Epoch [977/1500] - Loss: 3076.5139; AAPD: 987.4769287109375
Epoch [978/1500] - Loss: 4545.7622; AAPD: 2585.2177734375
Epoch [979/1500] - Loss: 2289.7583; AAPD: 1157.188720703125
Epoch [980/1500] - Loss: 2852.6008; AAPD: 1069.9566650390625
Epoch [981/1500] - Loss: 1937.6367; AAPD: 1024.4566650390625
Epoch [982/1500] - Loss: 2231.4766; AAPD: 1851.961669921875
Epoch [983/1500] - Loss: 5994.7334; AAPD: 2211.17578125
Epoch [984/1500] - Loss: 1186.8278; AAPD: 990.7357788085938
Epoch [985/1500] - Loss: 764.4354; AAPD: 1119.87548828125
Epoch [986/1500] - Loss: 4267.0591; AAPD: 1031.930419921875
Epoch [987/1500] - Loss: 539.3742; AAPD: 1148.1163330078125
Epoch [988/1500] - Loss: 10969.5273; AAPD: 2245.609619140

Epoch [1110/1500] - Loss: 1905.6042; AAPD: 794.7000122070312
Epoch [1111/1500] - Loss: 936.2241; AAPD: 804.8048706054688
Epoch [1112/1500] - Loss: 389.7198; AAPD: 775.6824340820312
Epoch [1113/1500] - Loss: 512.0520; AAPD: 763.2589721679688
Epoch [1114/1500] - Loss: 725.1623; AAPD: 797.5001220703125
Epoch [1115/1500] - Loss: 639.2767; AAPD: 807.6293334960938
Epoch [1116/1500] - Loss: 1605.5779; AAPD: 778.5597534179688
Epoch [1117/1500] - Loss: 293.8013; AAPD: 756.358154296875
Epoch [1118/1500] - Loss: 210.9147; AAPD: 869.2579956054688
Epoch [1119/1500] - Loss: 2245.4368; AAPD: 768.3374633789062
Epoch [1120/1500] - Loss: 307.6307; AAPD: 794.9981689453125
Epoch [1121/1500] - Loss: 992.7408; AAPD: 763.9208374023438
Epoch [1122/1500] - Loss: 3312.8765; AAPD: 792.2138061523438
Epoch [1123/1500] - Loss: 616.4456; AAPD: 782.65380859375
Epoch [1124/1500] - Loss: 2027.8097; AAPD: 773.5858764648438
Epoch [1125/1500] - Loss: 260.9162; AAPD: 822.0927734375
Epoch [1126/1500] - Loss: 2366.7454; AAPD

Epoch [1248/1500] - Loss: 620.7692; AAPD: 725.0797729492188
Epoch [1249/1500] - Loss: 336.8866; AAPD: 749.543701171875
Epoch [1250/1500] - Loss: 167.6848; AAPD: 688.0497436523438
Epoch [1251/1500] - Loss: 1798.1261; AAPD: 689.5363159179688
Epoch [1252/1500] - Loss: 510.9807; AAPD: 700.0237426757812
Epoch [1253/1500] - Loss: 868.3180; AAPD: 717.8085327148438
Epoch [1254/1500] - Loss: 302.1197; AAPD: 734.5562744140625
Epoch [1255/1500] - Loss: 206.3783; AAPD: 731.95263671875
Epoch [1256/1500] - Loss: 1207.2932; AAPD: 730.9000854492188
Epoch [1257/1500] - Loss: 415.9487; AAPD: 795.9156494140625
Epoch [1258/1500] - Loss: 861.3441; AAPD: 734.7132568359375
Epoch [1259/1500] - Loss: 133.3123; AAPD: 719.2068481445312
Epoch [1260/1500] - Loss: 639.4502; AAPD: 687.1527099609375
Epoch [1261/1500] - Loss: 396.1701; AAPD: 688.205810546875
Epoch [1262/1500] - Loss: 2211.6731; AAPD: 731.7927856445312
Epoch [1263/1500] - Loss: 1176.9985; AAPD: 724.9027099609375
Epoch [1264/1500] - Loss: 186.3479; AAPD

Epoch [1386/1500] - Loss: 362.8253; AAPD: 703.2734985351562
Epoch [1387/1500] - Loss: 279.3661; AAPD: 651.873779296875
Epoch [1388/1500] - Loss: 203.1975; AAPD: 688.5178833007812
Epoch [1389/1500] - Loss: 3375.6724; AAPD: 705.7459716796875
Epoch [1390/1500] - Loss: 227.9337; AAPD: 638.2349853515625
Epoch [1391/1500] - Loss: 232.5487; AAPD: 634.568115234375
Epoch [1392/1500] - Loss: 815.5257; AAPD: 651.9178466796875
Epoch [1393/1500] - Loss: 1217.7448; AAPD: 631.9892578125
Epoch [1394/1500] - Loss: 2895.9556; AAPD: 654.7498168945312
Epoch [1395/1500] - Loss: 557.6388; AAPD: 627.2009887695312
Epoch [1396/1500] - Loss: 1841.7362; AAPD: 650.7389526367188
Epoch [1397/1500] - Loss: 323.5625; AAPD: 638.0469970703125
Epoch [1398/1500] - Loss: 199.9101; AAPD: 673.8684692382812
Epoch [1399/1500] - Loss: 212.7675; AAPD: 628.57666015625
Epoch [1400/1500] - Loss: 308.5394; AAPD: 668.0216064453125
Epoch [1401/1500] - Loss: 318.0847; AAPD: 640.3959350585938
Epoch [1402/1500] - Loss: 189.4529; AAPD: 6

In [41]:
model = IAModelLeakyReLU(input_dimension=train_int.shape[1], output_dimension=train_out3.shape[1])
best_model_Leaky_3 = train_model(model, 'best_model_Leaky_3.pth', train_dataloader, dev_dataloader, target_idx=3)

Epoch [1/1500] - Loss: 2340831887360.0000; AAPD: 316806880.0
Epoch [2/1500] - Loss: 74783328.0000; AAPD: 3215037.25
Epoch [3/1500] - Loss: 27876108.0000; AAPD: 482528.625
Epoch [4/1500] - Loss: 857587.0625; AAPD: 59105.41796875
Epoch [5/1500] - Loss: 133112656.0000; AAPD: 411759.25
Epoch [6/1500] - Loss: 429407.3125; AAPD: 69224.375
Epoch [7/1500] - Loss: 8950131.0000; AAPD: 142237.359375
Epoch [8/1500] - Loss: 1383815.3750; AAPD: 76856.2890625
Epoch [9/1500] - Loss: 1607143.2500; AAPD: 80106.4765625
Epoch [10/1500] - Loss: 1010369.1250; AAPD: 123630.96875
Epoch [11/1500] - Loss: 180780352.0000; AAPD: 968264.4375
Epoch [12/1500] - Loss: 5955830.0000; AAPD: 434538.15625
Epoch [13/1500] - Loss: 458622631936.0000; AAPD: 57422164.0
Epoch [14/1500] - Loss: 359526432768.0000; AAPD: 36885088.0
Epoch [15/1500] - Loss: 1446524.0000; AAPD: 266622.75
Epoch [16/1500] - Loss: 6125888.5000; AAPD: 327083.125
Epoch [17/1500] - Loss: 44513924.0000; AAPD: 397539.125
Epoch [18/1500] - Loss: 12877106.0000

Epoch [144/1500] - Loss: 38779860.0000; AAPD: 730072.1875
Epoch [145/1500] - Loss: 7787677.0000; AAPD: 667867.0625
Epoch [146/1500] - Loss: 1763652224.0000; AAPD: 6758532.5
Epoch [147/1500] - Loss: 1618193024.0000; AAPD: 1228154.5
Epoch [148/1500] - Loss: 7214201.5000; AAPD: 320814.03125
Epoch [149/1500] - Loss: 34099680.0000; AAPD: 540238.375
Epoch [150/1500] - Loss: 20926716.0000; AAPD: 294937.15625
Epoch [151/1500] - Loss: 6244145.5000; AAPD: 196127.109375
Epoch [152/1500] - Loss: 810189.6250; AAPD: 233631.078125
Epoch [153/1500] - Loss: 4568291.0000; AAPD: 344136.15625
Epoch [154/1500] - Loss: 21841203200.0000; AAPD: 14691406.0
Epoch [155/1500] - Loss: 11147695.0000; AAPD: 309255.78125
Epoch [156/1500] - Loss: 17233798.0000; AAPD: 315067.4375
Epoch [157/1500] - Loss: 2839470.5000; AAPD: 226946.25
Epoch [158/1500] - Loss: 1711992.7500; AAPD: 199045.828125
Epoch [159/1500] - Loss: 5917049.5000; AAPD: 155404.828125
Epoch [160/1500] - Loss: 4977614.5000; AAPD: 243009.421875
Epoch [161/

Epoch [285/1500] - Loss: 4771997.0000; AAPD: 279621.75
Epoch [286/1500] - Loss: 51266680.0000; AAPD: 588609.25
Epoch [287/1500] - Loss: 2257007.5000; AAPD: 213622.453125
Epoch [288/1500] - Loss: 342441472.0000; AAPD: 2483787.5
Epoch [289/1500] - Loss: 16764373.0000; AAPD: 586550.8125
Epoch [290/1500] - Loss: 16802026.0000; AAPD: 336968.6875
Epoch [291/1500] - Loss: 10472999.0000; AAPD: 695012.375
Epoch [292/1500] - Loss: 48887688.0000; AAPD: 651645.3125
Epoch [293/1500] - Loss: 13423473.0000; AAPD: 261236.0
Epoch [294/1500] - Loss: 817870.2500; AAPD: 121651.3515625
Epoch [295/1500] - Loss: 7228038.5000; AAPD: 162898.28125
Epoch [296/1500] - Loss: 3896401408.0000; AAPD: 3508783.25
Epoch [297/1500] - Loss: 20350852.0000; AAPD: 649588.8125
Epoch [298/1500] - Loss: 4750589.5000; AAPD: 225075.3125
Epoch [299/1500] - Loss: 13219005.0000; AAPD: 188320.234375
Epoch [300/1500] - Loss: 24559250.0000; AAPD: 280144.3125
Epoch [301/1500] - Loss: 372976064.0000; AAPD: 622157.5
Epoch [302/1500] - Los

Epoch [428/1500] - Loss: 105235.7656; AAPD: 55545.6875
Epoch [429/1500] - Loss: 205413.1719; AAPD: 48057.16015625
Epoch [430/1500] - Loss: 147769.9844; AAPD: 35280.0078125
Epoch [431/1500] - Loss: 28913.8965; AAPD: 26739.12109375
Epoch [432/1500] - Loss: 35084.0391; AAPD: 28477.896484375
Epoch [433/1500] - Loss: 556503.8125; AAPD: 61528.8125
Epoch [434/1500] - Loss: 45208236.0000; AAPD: 1213544.875
Epoch [435/1500] - Loss: 89183.2188; AAPD: 33508.4375
Epoch [436/1500] - Loss: 291459.2812; AAPD: 118703.453125
Epoch [437/1500] - Loss: 306132.6562; AAPD: 104346.984375
Epoch [438/1500] - Loss: 417659.3125; AAPD: 36733.10546875
Epoch [439/1500] - Loss: 2045727.5000; AAPD: 136473.890625
Epoch [440/1500] - Loss: 1591053.5000; AAPD: 112247.375
Epoch [441/1500] - Loss: 157954.6562; AAPD: 27123.435546875
Epoch [442/1500] - Loss: 274431.8125; AAPD: 42098.3828125
Epoch [443/1500] - Loss: 8747270.0000; AAPD: 318287.9375
Epoch [444/1500] - Loss: 65655.1641; AAPD: 41903.44140625
Epoch [445/1500] - Lo

Epoch [570/1500] - Loss: 17887.9004; AAPD: 23400.17578125
Epoch [571/1500] - Loss: 6633.2705; AAPD: 5557.7216796875
Epoch [572/1500] - Loss: 1424.7117; AAPD: 8953.7255859375
Epoch [573/1500] - Loss: 61197.4492; AAPD: 27795.681640625
Epoch [574/1500] - Loss: 32540.2695; AAPD: 7246.90234375
Epoch [575/1500] - Loss: 161603.2812; AAPD: 25219.65234375
Epoch [576/1500] - Loss: 12365.0781; AAPD: 10916.794921875
Epoch [577/1500] - Loss: 7405.4272; AAPD: 10754.0048828125
Epoch [578/1500] - Loss: 168087.3125; AAPD: 51922.2265625
Epoch [579/1500] - Loss: 6522.5220; AAPD: 13124.943359375
Epoch [580/1500] - Loss: 2593.5317; AAPD: 6990.39599609375
Epoch [581/1500] - Loss: 3001.7310; AAPD: 5076.7763671875
Epoch [582/1500] - Loss: 18913.8203; AAPD: 9538.7890625
Epoch [583/1500] - Loss: 3074.2607; AAPD: 7212.8740234375
Epoch [584/1500] - Loss: 12980.6553; AAPD: 30978.802734375
Epoch [585/1500] - Loss: 10177.2939; AAPD: 14537.955078125
Epoch [586/1500] - Loss: 12116.5195; AAPD: 17668.369140625
Epoch [58

Epoch [712/1500] - Loss: 1470.8931; AAPD: 4828.34326171875
Epoch [713/1500] - Loss: 27473.8047; AAPD: 15194.064453125
Epoch [714/1500] - Loss: 2540.6260; AAPD: 5050.90283203125
Epoch [715/1500] - Loss: 3575.9521; AAPD: 6327.40283203125
Epoch [716/1500] - Loss: 3625.7026; AAPD: 4984.640625
Epoch [717/1500] - Loss: 1933.6412; AAPD: 4502.9814453125
Epoch [718/1500] - Loss: 2987.6001; AAPD: 4674.18505859375
Epoch [719/1500] - Loss: 3067.4287; AAPD: 4327.7939453125
Epoch [720/1500] - Loss: 16858.8848; AAPD: 8445.203125
Epoch [721/1500] - Loss: 4110.2183; AAPD: 8912.06640625
Epoch [722/1500] - Loss: 1132.4941; AAPD: 5170.47607421875
Epoch [723/1500] - Loss: 987.2841; AAPD: 4834.02099609375
Epoch [724/1500] - Loss: 3086.0391; AAPD: 7562.6064453125
Epoch [725/1500] - Loss: 2233.0159; AAPD: 5834.3818359375
Epoch [726/1500] - Loss: 26989.2305; AAPD: 9374.7314453125
Epoch [727/1500] - Loss: 5473.7329; AAPD: 8860.20703125
Epoch [728/1500] - Loss: 3430.8423; AAPD: 5078.59619140625
Epoch [729/1500] 

Epoch [854/1500] - Loss: 1363.9718; AAPD: 5865.34716796875
Epoch [855/1500] - Loss: 2218.1262; AAPD: 4728.7900390625
Epoch [856/1500] - Loss: 638.2277; AAPD: 4073.866943359375
Epoch [857/1500] - Loss: 1253.7966; AAPD: 5176.91357421875
Epoch [858/1500] - Loss: 5507.6504; AAPD: 4523.77392578125
Epoch [859/1500] - Loss: 775.8741; AAPD: 4607.22998046875
Epoch [860/1500] - Loss: 3207.6077; AAPD: 7703.06884765625
Epoch [861/1500] - Loss: 3413.6089; AAPD: 4032.750732421875
Epoch [862/1500] - Loss: 305.6385; AAPD: 3887.124267578125
Epoch [863/1500] - Loss: 2754.6147; AAPD: 5971.744140625
Epoch [864/1500] - Loss: 928.1913; AAPD: 4069.078125
Epoch [865/1500] - Loss: 2488.4231; AAPD: 3857.212158203125
Epoch [866/1500] - Loss: 1960.8802; AAPD: 4152.80029296875
Epoch [867/1500] - Loss: 1068.5237; AAPD: 4289.2939453125
Epoch [868/1500] - Loss: 295.9013; AAPD: 4363.00244140625
Epoch [869/1500] - Loss: 8947.0850; AAPD: 4857.19873046875
Epoch [870/1500] - Loss: 730.8217; AAPD: 3827.9326171875
Epoch [87

Epoch [995/1500] - Loss: 1482.1652; AAPD: 4751.341796875
Epoch [996/1500] - Loss: 668.0329; AAPD: 3818.742431640625
Epoch [997/1500] - Loss: 669.1881; AAPD: 4024.45458984375
Epoch [998/1500] - Loss: 1346.9478; AAPD: 4762.8466796875
Epoch [999/1500] - Loss: 411.0739; AAPD: 3768.593505859375
Epoch [1000/1500] - Loss: 1604.6273; AAPD: 3984.792724609375
Epoch [1001/1500] - Loss: 1080.6971; AAPD: 3583.638916015625
Epoch [1002/1500] - Loss: 3416.5781; AAPD: 3731.007568359375
Epoch [1003/1500] - Loss: 804.7249; AAPD: 3610.227783203125
Epoch [1004/1500] - Loss: 331.7203; AAPD: 3600.268798828125
Epoch [1005/1500] - Loss: 309.8102; AAPD: 3614.1171875
Epoch [1006/1500] - Loss: 483.0340; AAPD: 3615.179443359375
Epoch [1007/1500] - Loss: 448.3244; AAPD: 3638.41259765625
Epoch [1008/1500] - Loss: 655.1345; AAPD: 3691.01318359375
Epoch [1009/1500] - Loss: 12297.4219; AAPD: 3619.828125
Epoch [1010/1500] - Loss: 230.5904; AAPD: 3677.16650390625
Epoch [1011/1500] - Loss: 687.2873; AAPD: 3669.30908203125

Epoch [1134/1500] - Loss: 257.7316; AAPD: 3646.031494140625
Epoch [1135/1500] - Loss: 650.3631; AAPD: 3709.85009765625
Epoch [1136/1500] - Loss: 495739.7188; AAPD: 3639.5810546875
Epoch [1137/1500] - Loss: 2366.9497; AAPD: 3766.1865234375
Epoch [1138/1500] - Loss: 857.6416; AAPD: 3593.828125
Epoch [1139/1500] - Loss: 655.0715; AAPD: 3637.6484375
Epoch [1140/1500] - Loss: 719.2725; AAPD: 3742.191650390625
Epoch [1141/1500] - Loss: 43688.4805; AAPD: 3636.5419921875
Epoch [1142/1500] - Loss: 495.4529; AAPD: 3766.748291015625
Epoch [1143/1500] - Loss: 4495.7739; AAPD: 3632.00537109375
Epoch [1144/1500] - Loss: 341.9518; AAPD: 3710.95068359375
Epoch [1145/1500] - Loss: 278.1152; AAPD: 3566.32177734375
Epoch [1146/1500] - Loss: 378.6346; AAPD: 3615.22119140625
Epoch [1147/1500] - Loss: 1362.2014; AAPD: 3657.767578125
Epoch [1148/1500] - Loss: 689.0081; AAPD: 3701.67724609375
Epoch [1149/1500] - Loss: 585.7289; AAPD: 3612.463623046875
Epoch [1150/1500] - Loss: 205.0254; AAPD: 3677.87451171875

Epoch [1273/1500] - Loss: 259.4772; AAPD: 3680.5068359375
Epoch [1274/1500] - Loss: 191.9607; AAPD: 3695.696044921875
Epoch [1275/1500] - Loss: 609.8894; AAPD: 3611.511474609375
Epoch [1276/1500] - Loss: 225.5484; AAPD: 3620.335693359375
Epoch [1277/1500] - Loss: 1382.7743; AAPD: 3862.250244140625
Epoch [1278/1500] - Loss: 757.3429; AAPD: 3553.596923828125
Epoch [1279/1500] - Loss: 2373.3052; AAPD: 3540.1552734375
Epoch [1280/1500] - Loss: 696.4598; AAPD: 3644.960693359375
Epoch [1281/1500] - Loss: 701.6172; AAPD: 3741.96826171875
Epoch [1282/1500] - Loss: 32772.3906; AAPD: 3584.9833984375
Epoch [1283/1500] - Loss: 429.0912; AAPD: 3685.0859375
Epoch [1284/1500] - Loss: 569.2745; AAPD: 3648.56689453125
Epoch [1285/1500] - Loss: 331.0904; AAPD: 3595.881103515625
Epoch [1286/1500] - Loss: 373.7415; AAPD: 3695.24853515625
Epoch [1287/1500] - Loss: 1718.4142; AAPD: 3604.463134765625
Epoch [1288/1500] - Loss: 758.8329; AAPD: 3612.833251953125
Epoch [1289/1500] - Loss: 1509.0203; AAPD: 3655.0

Epoch [1412/1500] - Loss: 10931.6611; AAPD: 3694.460205078125
Epoch [1413/1500] - Loss: 1335.8188; AAPD: 3749.013671875
Epoch [1414/1500] - Loss: 214.7961; AAPD: 3745.755859375
Epoch [1415/1500] - Loss: 788.5624; AAPD: 3677.53759765625
Epoch [1416/1500] - Loss: 2402.4409; AAPD: 3586.550048828125
Epoch [1417/1500] - Loss: 1352.3597; AAPD: 3609.94580078125
Epoch [1418/1500] - Loss: 846.1975; AAPD: 3640.01708984375
Epoch [1419/1500] - Loss: 927.9938; AAPD: 3578.887451171875
Epoch [1420/1500] - Loss: 360.3016; AAPD: 3653.7734375
Epoch [1421/1500] - Loss: 233.6532; AAPD: 3629.438720703125
Epoch [1422/1500] - Loss: 664.7860; AAPD: 3627.72265625
Epoch [1423/1500] - Loss: 476.7872; AAPD: 3648.93603515625
Epoch [1424/1500] - Loss: 407.6866; AAPD: 3553.319580078125
Epoch [1425/1500] - Loss: 736.7434; AAPD: 3569.38623046875
Epoch [1426/1500] - Loss: 411.1096; AAPD: 3720.14208984375
Epoch [1427/1500] - Loss: 911.0769; AAPD: 3707.2216796875
Epoch [1428/1500] - Loss: 544.0079; AAPD: 3640.31274414062

In [43]:
model = IAModelLeakyReLU(input_dimension=train_int.shape[1], output_dimension=train_out1.shape[1])
model_MSE_Leaky_1 = train_model(model, 'best_model_MSE_Leaky_1.pth', train_dataloader, dev_dataloader, target_idx=1)
torch.save(model_MSE_Leaky_1.state_dict(), "model_MSE_Leaky_1.pth")

Epoch [1/1500] - Loss: 417579.9688; AAPD: 10.178823471069336
Epoch [2/1500] - Loss: 61705876.0000; AAPD: 8.679473876953125
Epoch [3/1500] - Loss: 5555274.0000; AAPD: 15.271012306213379
Epoch [4/1500] - Loss: 20962314.0000; AAPD: 38.688438415527344
Epoch [5/1500] - Loss: 4353307.0000; AAPD: 23.645965576171875
Epoch [6/1500] - Loss: 2083707.2500; AAPD: 54.898643493652344
Epoch [7/1500] - Loss: 10190556.0000; AAPD: 31.581682205200195
Epoch [8/1500] - Loss: 1374637.2500; AAPD: 15.112069129943848
Epoch [9/1500] - Loss: 70593096.0000; AAPD: 73.04328155517578
Epoch [10/1500] - Loss: 633735.1250; AAPD: 33.25265121459961
Epoch [11/1500] - Loss: 57420936.0000; AAPD: 16.849552154541016
Epoch [12/1500] - Loss: 5002865.5000; AAPD: 55.89371109008789
Epoch [13/1500] - Loss: 20046452.0000; AAPD: 30.81563949584961
Epoch [14/1500] - Loss: 1830574.3750; AAPD: 15.652100563049316
Epoch [15/1500] - Loss: 32487748.0000; AAPD: 50.39274215698242
Epoch [16/1500] - Loss: 2543277.2500; AAPD: 8.395999908447266
Epo

Epoch [132/1500] - Loss: 2167971.5000; AAPD: 40.38095474243164
Epoch [133/1500] - Loss: 2734928.5000; AAPD: 35.749290466308594
Epoch [134/1500] - Loss: 5971222.0000; AAPD: 35.356327056884766
Epoch [135/1500] - Loss: 115402640.0000; AAPD: 30.96712303161621
Epoch [136/1500] - Loss: 2353673.0000; AAPD: 23.241703033447266
Epoch [137/1500] - Loss: 2822639.2500; AAPD: 751.5911865234375
Epoch [138/1500] - Loss: 6231776.5000; AAPD: 438.4264831542969
Epoch [139/1500] - Loss: 74033128.0000; AAPD: 286.8964538574219
Epoch [140/1500] - Loss: 3102423.5000; AAPD: 255.40023803710938
Epoch [141/1500] - Loss: 1274266.7500; AAPD: 220.91172790527344
Epoch [142/1500] - Loss: 4210695.5000; AAPD: 190.12152099609375
Epoch [143/1500] - Loss: 7156138.5000; AAPD: 185.4453582763672
Epoch [144/1500] - Loss: 4406694.0000; AAPD: 165.21116638183594
Epoch [145/1500] - Loss: 4283382.5000; AAPD: 225.4469757080078
Epoch [146/1500] - Loss: 23112380.0000; AAPD: 217.88671875
Epoch [147/1500] - Loss: 4329903.0000; AAPD: 372.

Epoch [262/1500] - Loss: 3122116.2500; AAPD: 47.96610641479492
Epoch [263/1500] - Loss: 3751164.5000; AAPD: 40.74638748168945
Epoch [264/1500] - Loss: 1099241.2500; AAPD: 39.72227096557617
Epoch [265/1500] - Loss: 6424748.5000; AAPD: 61.36554718017578
Epoch [266/1500] - Loss: 1313679.7500; AAPD: 39.91238021850586
Epoch [267/1500] - Loss: 1202027.5000; AAPD: 36.25959396362305
Epoch [268/1500] - Loss: 5129908.5000; AAPD: 65.50259399414062
Epoch [269/1500] - Loss: 3755028.5000; AAPD: 46.08168411254883
Epoch [270/1500] - Loss: 2162862.7500; AAPD: 119.19261932373047
Epoch [271/1500] - Loss: 2444862.5000; AAPD: 29.917394638061523
Epoch [272/1500] - Loss: 844508.6250; AAPD: 33.19982147216797
Epoch [273/1500] - Loss: 6829043.0000; AAPD: 79.29669952392578
Epoch [274/1500] - Loss: 2412068.0000; AAPD: 52.00367736816406
Epoch [275/1500] - Loss: 1506314.7500; AAPD: 27.697397232055664
Epoch [276/1500] - Loss: 1244690.3750; AAPD: 35.14978790283203
Epoch [277/1500] - Loss: 1667881.5000; AAPD: 75.75521

Epoch [392/1500] - Loss: 544147.0000; AAPD: 15.229436874389648
Epoch [393/1500] - Loss: 344487.4062; AAPD: 22.68903923034668
Epoch [394/1500] - Loss: 482299.3125; AAPD: 38.37332534790039
Epoch [395/1500] - Loss: 1276591.1250; AAPD: 19.81571388244629
Epoch [396/1500] - Loss: 6996894.5000; AAPD: 34.002235412597656
Epoch [397/1500] - Loss: 2668140.0000; AAPD: 43.270633697509766
Epoch [398/1500] - Loss: 1611383.1250; AAPD: 25.692977905273438
Epoch [399/1500] - Loss: 12602734.0000; AAPD: 71.58194732666016
Epoch [400/1500] - Loss: 10244994.0000; AAPD: 19.085039138793945
Epoch [401/1500] - Loss: 897495.3750; AAPD: 28.782756805419922
Epoch [402/1500] - Loss: 654272.8750; AAPD: 19.586111068725586
Epoch [403/1500] - Loss: 1673878.8750; AAPD: 18.626789093017578
Epoch [404/1500] - Loss: 452318.4688; AAPD: 18.615453720092773
Epoch [405/1500] - Loss: 1476287.8750; AAPD: 17.709209442138672
Epoch [406/1500] - Loss: 3149341.2500; AAPD: 27.96866798400879
Epoch [407/1500] - Loss: 147771.8438; AAPD: 18.06

Epoch [523/1500] - Loss: 816722.8125; AAPD: 8.574336051940918
Epoch [524/1500] - Loss: 178146.6719; AAPD: 7.912341117858887
Epoch [525/1500] - Loss: 2846907.5000; AAPD: 7.624415874481201
Epoch [526/1500] - Loss: 213304.7188; AAPD: 8.847650527954102
Epoch [527/1500] - Loss: 258188.7812; AAPD: 7.780192852020264
Epoch [528/1500] - Loss: 11925268.0000; AAPD: 11.291276931762695
Epoch [529/1500] - Loss: 2153534.7500; AAPD: 6.545529365539551
Epoch [530/1500] - Loss: 23470828.0000; AAPD: 7.358023166656494
Epoch [531/1500] - Loss: 764205.5625; AAPD: 10.206907272338867
Epoch [532/1500] - Loss: 3071061.2500; AAPD: 6.845347881317139
Epoch [533/1500] - Loss: 931003.9375; AAPD: 7.021864414215088
Epoch [534/1500] - Loss: 949004.3125; AAPD: 9.447792053222656
Epoch [535/1500] - Loss: 390205.4688; AAPD: 6.6617913246154785
Epoch [536/1500] - Loss: 16228478.0000; AAPD: 6.83765983581543
Epoch [537/1500] - Loss: 522708.6250; AAPD: 11.057976722717285
Epoch [538/1500] - Loss: 1534833.6250; AAPD: 11.0187587738

Epoch [654/1500] - Loss: 5642991.5000; AAPD: 10.495326042175293
Epoch [655/1500] - Loss: 981636.0000; AAPD: 6.314314365386963
Epoch [656/1500] - Loss: 164718.6875; AAPD: 7.855029106140137
Epoch [657/1500] - Loss: 4284161.5000; AAPD: 6.34869909286499
Epoch [658/1500] - Loss: 39767508.0000; AAPD: 9.275516510009766
Epoch [659/1500] - Loss: 965525.1875; AAPD: 6.3338799476623535
Epoch [660/1500] - Loss: 1136739.7500; AAPD: 8.19510269165039
Epoch [661/1500] - Loss: 2439056.5000; AAPD: 7.6970109939575195
Epoch [662/1500] - Loss: 10538941.0000; AAPD: 5.526066780090332
Epoch [663/1500] - Loss: 2333752.7500; AAPD: 8.6959867477417
Epoch [664/1500] - Loss: 446518.9062; AAPD: 7.087174892425537
Epoch [665/1500] - Loss: 757657.7500; AAPD: 5.733604907989502
Epoch [666/1500] - Loss: 450388.1562; AAPD: 5.989292621612549
Epoch [667/1500] - Loss: 2997719.2500; AAPD: 8.274726867675781
Epoch [668/1500] - Loss: 428387.5000; AAPD: 6.032771587371826
Epoch [669/1500] - Loss: 1218519.6250; AAPD: 5.48944234848022

Epoch [786/1500] - Loss: 484128.9688; AAPD: 4.335707187652588
Epoch [787/1500] - Loss: 282303.6562; AAPD: 4.394235134124756
Epoch [788/1500] - Loss: 1222017.5000; AAPD: 4.076605796813965
Epoch [789/1500] - Loss: 537712.6875; AAPD: 4.187765598297119
Epoch [790/1500] - Loss: 276233.3125; AAPD: 3.9784929752349854
Epoch [791/1500] - Loss: 1235495.3750; AAPD: 4.381641387939453
Epoch [792/1500] - Loss: 286516.5312; AAPD: 4.255034923553467
Epoch [793/1500] - Loss: 1022324.1250; AAPD: 4.348249912261963
Epoch [794/1500] - Loss: 153515.0312; AAPD: 3.9997012615203857
Epoch [795/1500] - Loss: 5643823.5000; AAPD: 3.8608531951904297
Epoch [796/1500] - Loss: 544060.8750; AAPD: 4.031019687652588
Epoch [797/1500] - Loss: 508431.2188; AAPD: 5.133116245269775
Epoch [798/1500] - Loss: 557271.4375; AAPD: 4.201747894287109
Epoch [799/1500] - Loss: 5007807.5000; AAPD: 3.893214702606201
Epoch [800/1500] - Loss: 261879.4844; AAPD: 4.214362144470215
Epoch [801/1500] - Loss: 2158234.5000; AAPD: 4.970174312591553

Epoch [917/1500] - Loss: 539124.6875; AAPD: 4.421694755554199
Epoch [918/1500] - Loss: 4524562.5000; AAPD: 5.390832901000977
Epoch [919/1500] - Loss: 946280.7500; AAPD: 3.9634242057800293
Epoch [920/1500] - Loss: 361668.5625; AAPD: 4.224884033203125
Epoch [921/1500] - Loss: 209216.4219; AAPD: 3.8257224559783936
Epoch [922/1500] - Loss: 351482.6250; AAPD: 3.6843619346618652
Epoch [923/1500] - Loss: 168249.7031; AAPD: 4.475510120391846
Epoch [924/1500] - Loss: 668168.8125; AAPD: 4.145482063293457
Epoch [925/1500] - Loss: 228090.7500; AAPD: 4.0820631980896
Epoch [926/1500] - Loss: 2160103.5000; AAPD: 3.8072707653045654
Epoch [927/1500] - Loss: 461734.1875; AAPD: 4.226417064666748
Epoch [928/1500] - Loss: 380354.0000; AAPD: 3.9810638427734375
Epoch [929/1500] - Loss: 309782.3438; AAPD: 4.411983489990234
Epoch [930/1500] - Loss: 889111.6250; AAPD: 4.256272792816162
Epoch [931/1500] - Loss: 230680.3281; AAPD: 4.099393367767334
Epoch [932/1500] - Loss: 512195.4062; AAPD: 5.058780193328857
Epo

Epoch [1048/1500] - Loss: 3866366.5000; AAPD: 3.9277937412261963
Epoch [1049/1500] - Loss: 3534331.7500; AAPD: 3.94952654838562
Epoch [1050/1500] - Loss: 170157.8281; AAPD: 3.858051300048828
Epoch [1051/1500] - Loss: 517010.7812; AAPD: 3.8347933292388916
Epoch [1052/1500] - Loss: 4412915.0000; AAPD: 3.889968156814575
Epoch [1053/1500] - Loss: 347848.1562; AAPD: 4.1128106117248535
Epoch [1054/1500] - Loss: 12994450.0000; AAPD: 4.050858020782471
Epoch [1055/1500] - Loss: 954093.0000; AAPD: 3.948349714279175
Epoch [1056/1500] - Loss: 341407.7812; AAPD: 3.944875478744507
Epoch [1057/1500] - Loss: 775503.8750; AAPD: 4.017293930053711
Epoch [1058/1500] - Loss: 397531.5938; AAPD: 4.132177352905273
Epoch [1059/1500] - Loss: 991221.0625; AAPD: 3.9634366035461426
Epoch [1060/1500] - Loss: 201918.7812; AAPD: 3.9009461402893066
Epoch [1061/1500] - Loss: 310083.0312; AAPD: 3.977285146713257
Epoch [1062/1500] - Loss: 1483465.8750; AAPD: 3.9401679039001465
Epoch [1063/1500] - Loss: 4061576.5000; AAPD

Epoch [1177/1500] - Loss: 421506.1875; AAPD: 3.986497163772583
Epoch [1178/1500] - Loss: 1227464.3750; AAPD: 3.9434850215911865
Epoch [1179/1500] - Loss: 1820456.7500; AAPD: 4.061975002288818
Epoch [1180/1500] - Loss: 568085.6250; AAPD: 3.9447569847106934
Epoch [1181/1500] - Loss: 5516716.5000; AAPD: 3.8559210300445557
Epoch [1182/1500] - Loss: 179977.6406; AAPD: 3.874997615814209
Epoch [1183/1500] - Loss: 1694435.8750; AAPD: 3.956312656402588
Epoch [1184/1500] - Loss: 12822610.0000; AAPD: 3.9981706142425537
Epoch [1185/1500] - Loss: 5867705.5000; AAPD: 3.929990530014038
Epoch [1186/1500] - Loss: 1644417.8750; AAPD: 4.152729034423828
Epoch [1187/1500] - Loss: 5200207.5000; AAPD: 4.076344966888428
Epoch [1188/1500] - Loss: 1291520.1250; AAPD: 3.9489076137542725
Epoch [1189/1500] - Loss: 100176.8359; AAPD: 4.094366073608398
Epoch [1190/1500] - Loss: 5489873.0000; AAPD: 4.032687187194824
Epoch [1191/1500] - Loss: 1066349.5000; AAPD: 4.141897201538086
Epoch [1192/1500] - Loss: 25163284.000

Epoch [1306/1500] - Loss: 1761890.2500; AAPD: 4.006803512573242
Epoch [1307/1500] - Loss: 284282.7188; AAPD: 3.9590506553649902
Epoch [1308/1500] - Loss: 1468951.7500; AAPD: 3.9850494861602783
Epoch [1309/1500] - Loss: 3818097.5000; AAPD: 3.7101621627807617
Epoch [1310/1500] - Loss: 948222.9375; AAPD: 4.046522617340088
Epoch [1311/1500] - Loss: 258618.4688; AAPD: 3.808558940887451
Epoch [1312/1500] - Loss: 13480599.0000; AAPD: 3.9136557579040527
Epoch [1313/1500] - Loss: 1022505.3750; AAPD: 4.032165050506592
Epoch [1314/1500] - Loss: 493890.0625; AAPD: 3.9213624000549316
Epoch [1315/1500] - Loss: 993739.1250; AAPD: 4.047891139984131
Epoch [1316/1500] - Loss: 839776.8750; AAPD: 3.839411497116089
Epoch [1317/1500] - Loss: 588498.9375; AAPD: 3.987457275390625
Epoch [1318/1500] - Loss: 3447435.7500; AAPD: 4.074881553649902
Epoch [1319/1500] - Loss: 338725.9062; AAPD: 3.9275593757629395
Epoch [1320/1500] - Loss: 218877.6875; AAPD: 3.9608817100524902
Epoch [1321/1500] - Loss: 433599.5938; AA

Epoch [1435/1500] - Loss: 13615747.0000; AAPD: 3.960049629211426
Epoch [1436/1500] - Loss: 186796.4844; AAPD: 3.918022632598877
Epoch [1437/1500] - Loss: 296233.5312; AAPD: 4.001088619232178
Epoch [1438/1500] - Loss: 428691.7188; AAPD: 3.8268210887908936
Epoch [1439/1500] - Loss: 1757952.6250; AAPD: 3.966742515563965
Epoch [1440/1500] - Loss: 15551381.0000; AAPD: 3.9626545906066895
Epoch [1441/1500] - Loss: 42075288.0000; AAPD: 3.9791038036346436
Epoch [1442/1500] - Loss: 92649.8594; AAPD: 3.9353458881378174
Epoch [1443/1500] - Loss: 731737.4375; AAPD: 3.990036964416504
Epoch [1444/1500] - Loss: 142734.5469; AAPD: 3.812838554382324
Epoch [1445/1500] - Loss: 719657.0000; AAPD: 3.9003870487213135
Epoch [1446/1500] - Loss: 348611.6875; AAPD: 3.7413933277130127
Epoch [1447/1500] - Loss: 191460.9219; AAPD: 3.810758352279663
Epoch [1448/1500] - Loss: 3321711.7500; AAPD: 4.05122709274292
Epoch [1449/1500] - Loss: 687396.3750; AAPD: 3.9327917098999023
Epoch [1450/1500] - Loss: 250527.6719; AAP

In [44]:
model = IAModelLeakyReLU(input_dimension=train_int.shape[1], output_dimension=train_out2.shape[1])
model_MSE_Leaky_2 = train_model(model, 'best_model_MSE_Leaky_2.pth', train_dataloader, dev_dataloader, target_idx=2)
torch.save(model_MSE_Leaky_2.state_dict(), "model_MSE_Leaky_2.pth")

Epoch [1/1500] - Loss: 0.0038; AAPD: 32.544803619384766
Epoch [2/1500] - Loss: 0.0063; AAPD: 118.86566925048828
Epoch [3/1500] - Loss: 0.0090; AAPD: 74.2061767578125
Epoch [4/1500] - Loss: 0.0037; AAPD: 19.491539001464844
Epoch [5/1500] - Loss: 7356.0522; AAPD: 48599.0625
Epoch [6/1500] - Loss: 4333.5376; AAPD: 45713.08203125
Epoch [7/1500] - Loss: 6729.9272; AAPD: 53974.75
Epoch [8/1500] - Loss: 30240.6289; AAPD: 133160.875
Epoch [9/1500] - Loss: 173241.3281; AAPD: 265746.53125
Epoch [10/1500] - Loss: 1163.6346; AAPD: 29682.078125
Epoch [11/1500] - Loss: 84052.7344; AAPD: 384320.96875
Epoch [12/1500] - Loss: 63890.9922; AAPD: 316813.71875
Epoch [13/1500] - Loss: 258825.8281; AAPD: 653920.8125
Epoch [14/1500] - Loss: 420.8041; AAPD: 10983.158203125
Epoch [15/1500] - Loss: 587.1465; AAPD: 15237.9404296875
Epoch [16/1500] - Loss: 588.0569; AAPD: 14501.91015625
Epoch [17/1500] - Loss: 496.7965; AAPD: 14772.302734375
Epoch [18/1500] - Loss: 296.4091; AAPD: 10982.6884765625
Epoch [19/1500] 

Epoch [150/1500] - Loss: 5553.4189; AAPD: 70516.5703125
Epoch [151/1500] - Loss: 2081390.8750; AAPD: 1830995.875
Epoch [152/1500] - Loss: 42743.6016; AAPD: 150391.171875
Epoch [153/1500] - Loss: 1708114.1250; AAPD: 943099.1875
Epoch [154/1500] - Loss: 77629.7266; AAPD: 274323.09375
Epoch [155/1500] - Loss: 14544.5625; AAPD: 82038.0390625
Epoch [156/1500] - Loss: 2791.1631; AAPD: 27184.1328125
Epoch [157/1500] - Loss: 1468.2427; AAPD: 25559.978515625
Epoch [158/1500] - Loss: 2335.3840; AAPD: 26025.90234375
Epoch [159/1500] - Loss: 1766.6941; AAPD: 24353.77734375
Epoch [160/1500] - Loss: 1734.4882; AAPD: 26275.7109375
Epoch [161/1500] - Loss: 1819.1145; AAPD: 28668.22265625
Epoch [162/1500] - Loss: 1519.1458; AAPD: 30150.88671875
Epoch [163/1500] - Loss: 2687.6304; AAPD: 63833.6015625
Epoch [164/1500] - Loss: 3028187.7500; AAPD: 1946402.75
Epoch [165/1500] - Loss: 62510.1758; AAPD: 358963.875
Epoch [166/1500] - Loss: 369379.9688; AAPD: 913453.1875
Epoch [167/1500] - Loss: 9408.5332; AAPD

Epoch [297/1500] - Loss: 637.5671; AAPD: 34453.7734375
Epoch [298/1500] - Loss: 20.8344; AAPD: 4187.88525390625
Epoch [299/1500] - Loss: 320.0782; AAPD: 14068.7509765625
Epoch [300/1500] - Loss: 4248.5962; AAPD: 50778.9375
Epoch [301/1500] - Loss: 2695.6365; AAPD: 49882.8515625
Epoch [302/1500] - Loss: 823.6254; AAPD: 40414.12890625
Epoch [303/1500] - Loss: 2676.1160; AAPD: 48663.6015625
Epoch [304/1500] - Loss: 1153.9733; AAPD: 40050.9375
Epoch [305/1500] - Loss: 48.6596; AAPD: 10018.9775390625
Epoch [306/1500] - Loss: 348.5775; AAPD: 13677.8603515625
Epoch [307/1500] - Loss: 2.2600; AAPD: 1701.2193603515625
Epoch [308/1500] - Loss: 1.3765; AAPD: 1601.9847412109375
Epoch [309/1500] - Loss: 278.4616; AAPD: 12873.146484375
Epoch [310/1500] - Loss: 2761.1519; AAPD: 45682.734375
Epoch [311/1500] - Loss: 461.6495; AAPD: 13806.689453125
Epoch [312/1500] - Loss: 17.0951; AAPD: 15801.2919921875
Epoch [313/1500] - Loss: 155.6089; AAPD: 17358.55859375
Epoch [314/1500] - Loss: 2777.6931; AAPD: 5

Epoch [443/1500] - Loss: 137.7660; AAPD: 14431.259765625
Epoch [444/1500] - Loss: 89.7779; AAPD: 10107.76953125
Epoch [445/1500] - Loss: 9900.2432; AAPD: 102135.1640625
Epoch [446/1500] - Loss: 2763.0942; AAPD: 59750.13671875
Epoch [447/1500] - Loss: 94.7894; AAPD: 12754.5244140625
Epoch [448/1500] - Loss: 18.2385; AAPD: 5890.97265625
Epoch [449/1500] - Loss: 118.4049; AAPD: 6808.12939453125
Epoch [450/1500] - Loss: 107.0089; AAPD: 9674.0751953125
Epoch [451/1500] - Loss: 27.5524; AAPD: 6038.21484375
Epoch [452/1500] - Loss: 76.8041; AAPD: 9353.49609375
Epoch [453/1500] - Loss: 21.7044; AAPD: 3994.64697265625
Epoch [454/1500] - Loss: 33.7616; AAPD: 4997.05908203125
Epoch [455/1500] - Loss: 15.6559; AAPD: 3396.8857421875
Epoch [456/1500] - Loss: 231.6782; AAPD: 12217.34765625
Epoch [457/1500] - Loss: 10.7451; AAPD: 3221.490478515625
Epoch [458/1500] - Loss: 24.3625; AAPD: 6996.08740234375
Epoch [459/1500] - Loss: 252.9775; AAPD: 15382.6640625
Epoch [460/1500] - Loss: 14.8845; AAPD: 4173

Epoch [589/1500] - Loss: 0.2696; AAPD: 606.7752685546875
Epoch [590/1500] - Loss: 1.6271; AAPD: 2084.619873046875
Epoch [591/1500] - Loss: 16.3273; AAPD: 4104.62109375
Epoch [592/1500] - Loss: 0.3580; AAPD: 693.6063842773438
Epoch [593/1500] - Loss: 10.1171; AAPD: 3052.552490234375
Epoch [594/1500] - Loss: 0.2409; AAPD: 442.9262390136719
Epoch [595/1500] - Loss: 0.1759; AAPD: 545.4557495117188
Epoch [596/1500] - Loss: 2.1404; AAPD: 810.5060424804688
Epoch [597/1500] - Loss: 0.0908; AAPD: 358.85650634765625
Epoch [598/1500] - Loss: 5.3239; AAPD: 2297.4560546875
Epoch [599/1500] - Loss: 0.1539; AAPD: 453.28704833984375
Epoch [600/1500] - Loss: 0.0818; AAPD: 569.2755737304688
Epoch [601/1500] - Loss: 0.6263; AAPD: 708.27294921875
Epoch [602/1500] - Loss: 0.1383; AAPD: 440.3091125488281
Epoch [603/1500] - Loss: 0.1033; AAPD: 371.2383728027344
Epoch [604/1500] - Loss: 0.2319; AAPD: 306.0469665527344
Epoch [605/1500] - Loss: 0.7587; AAPD: 1335.43994140625
Epoch [606/1500] - Loss: 11.2569; AA

Epoch [734/1500] - Loss: 0.2122; AAPD: 490.9107666015625
Epoch [735/1500] - Loss: 0.0808; AAPD: 310.1225891113281
Epoch [736/1500] - Loss: 0.0679; AAPD: 397.9205627441406
Epoch [737/1500] - Loss: 0.6575; AAPD: 681.1031494140625
Epoch [738/1500] - Loss: 0.0428; AAPD: 164.63885498046875
Epoch [739/1500] - Loss: 10.6376; AAPD: 4381.06689453125
Epoch [740/1500] - Loss: 0.0421; AAPD: 264.13580322265625
Epoch [741/1500] - Loss: 0.0811; AAPD: 398.33673095703125
Epoch [742/1500] - Loss: 2.6664; AAPD: 2270.9697265625
Epoch [743/1500] - Loss: 5.5813; AAPD: 3614.266357421875
Epoch [744/1500] - Loss: 0.0366; AAPD: 213.20436096191406
Epoch [745/1500] - Loss: 0.0931; AAPD: 470.7882385253906
Epoch [746/1500] - Loss: 0.1019; AAPD: 350.4603271484375
Epoch [747/1500] - Loss: 0.8391; AAPD: 1132.0611572265625
Epoch [748/1500] - Loss: 8.0430; AAPD: 3680.238525390625
Epoch [749/1500] - Loss: 0.5815; AAPD: 891.3759155273438
Epoch [750/1500] - Loss: 0.1351; AAPD: 503.9610900878906
Epoch [751/1500] - Loss: 0.0

Epoch [877/1500] - Loss: 0.0281; AAPD: 235.36032104492188
Epoch [878/1500] - Loss: 0.0453; AAPD: 188.66529846191406
Epoch [879/1500] - Loss: 0.0321; AAPD: 188.28872680664062
Epoch [880/1500] - Loss: 0.0183; AAPD: 185.72586059570312
Epoch [881/1500] - Loss: 0.0377; AAPD: 241.11859130859375
Epoch [882/1500] - Loss: 0.0266; AAPD: 190.3685760498047
Epoch [883/1500] - Loss: 0.0168; AAPD: 181.65689086914062
Epoch [884/1500] - Loss: 0.0297; AAPD: 185.63279724121094
Epoch [885/1500] - Loss: 0.0386; AAPD: 187.56591796875
Epoch [886/1500] - Loss: 0.0279; AAPD: 217.59475708007812
Epoch [887/1500] - Loss: 0.0328; AAPD: 237.807861328125
Epoch [888/1500] - Loss: 0.0353; AAPD: 203.0758514404297
Epoch [889/1500] - Loss: 0.0516; AAPD: 226.2176513671875
Epoch [890/1500] - Loss: 0.0297; AAPD: 146.05796813964844
Epoch [891/1500] - Loss: 0.0214; AAPD: 193.8756561279297
Epoch [892/1500] - Loss: 0.0187; AAPD: 171.57913208007812
Epoch [893/1500] - Loss: 0.0598; AAPD: 175.1395263671875
Epoch [894/1500] - Loss:

Epoch [1020/1500] - Loss: 0.0216; AAPD: 131.35964965820312
Epoch [1021/1500] - Loss: 0.0163; AAPD: 121.12776947021484
Epoch [1022/1500] - Loss: 0.0150; AAPD: 103.21154022216797
Epoch [1023/1500] - Loss: 0.0093; AAPD: 112.48118591308594
Epoch [1024/1500] - Loss: 0.0075; AAPD: 117.25233459472656
Epoch [1025/1500] - Loss: 0.0138; AAPD: 108.86619567871094
Epoch [1026/1500] - Loss: 0.0109; AAPD: 151.5921630859375
Epoch [1027/1500] - Loss: 0.0137; AAPD: 192.38021850585938
Epoch [1028/1500] - Loss: 0.0154; AAPD: 113.367431640625
Epoch [1029/1500] - Loss: 0.0126; AAPD: 128.66806030273438
Epoch [1030/1500] - Loss: 0.0164; AAPD: 124.98925018310547
Epoch [1031/1500] - Loss: 0.0144; AAPD: 140.61508178710938
Epoch [1032/1500] - Loss: 0.0152; AAPD: 105.25028991699219
Epoch [1033/1500] - Loss: 0.0128; AAPD: 109.07317352294922
Epoch [1034/1500] - Loss: 0.0131; AAPD: 136.31185913085938
Epoch [1035/1500] - Loss: 0.0112; AAPD: 145.6522674560547
Epoch [1036/1500] - Loss: 0.0279; AAPD: 169.16285705566406
E

Epoch [1160/1500] - Loss: 0.0171; AAPD: 158.8178253173828
Epoch [1161/1500] - Loss: 0.0173; AAPD: 123.40058898925781
Epoch [1162/1500] - Loss: 0.0136; AAPD: 100.99160766601562
Epoch [1163/1500] - Loss: 0.0089; AAPD: 131.41006469726562
Epoch [1164/1500] - Loss: 0.0054; AAPD: 101.68572998046875
Epoch [1165/1500] - Loss: 0.0079; AAPD: 151.6176300048828
Epoch [1166/1500] - Loss: 0.0107; AAPD: 100.46438598632812
Epoch [1167/1500] - Loss: 0.0220; AAPD: 102.7774887084961
Epoch [1168/1500] - Loss: 0.0115; AAPD: 93.31050109863281
Epoch [1169/1500] - Loss: 0.0089; AAPD: 96.13899993896484
Epoch [1170/1500] - Loss: 0.0127; AAPD: 141.7666015625
Epoch [1171/1500] - Loss: 0.0121; AAPD: 102.42520904541016
Epoch [1172/1500] - Loss: 0.0083; AAPD: 94.71012878417969
Epoch [1173/1500] - Loss: 0.0177; AAPD: 127.20085144042969
Epoch [1174/1500] - Loss: 0.0064; AAPD: 121.57125854492188
Epoch [1175/1500] - Loss: 0.0138; AAPD: 109.63922119140625
Epoch [1176/1500] - Loss: 0.0076; AAPD: 115.9796142578125
Epoch [1

Epoch [1301/1500] - Loss: 0.0112; AAPD: 105.21820068359375
Epoch [1302/1500] - Loss: 0.0125; AAPD: 101.5316162109375
Epoch [1303/1500] - Loss: 0.0073; AAPD: 94.72274017333984
Epoch [1304/1500] - Loss: 0.0086; AAPD: 96.92962646484375
Epoch [1305/1500] - Loss: 0.0127; AAPD: 118.32318115234375
Epoch [1306/1500] - Loss: 0.0124; AAPD: 121.04914855957031
Epoch [1307/1500] - Loss: 0.0112; AAPD: 99.29537963867188
Epoch [1308/1500] - Loss: 0.0126; AAPD: 111.46559143066406
Epoch [1309/1500] - Loss: 0.0125; AAPD: 130.98831176757812
Epoch [1310/1500] - Loss: 0.0084; AAPD: 108.09992218017578
Epoch [1311/1500] - Loss: 0.0117; AAPD: 89.6982192993164
Epoch [1312/1500] - Loss: 0.0081; AAPD: 85.75808715820312
Epoch [1313/1500] - Loss: 0.0052; AAPD: 90.83409118652344
Epoch [1314/1500] - Loss: 0.0105; AAPD: 90.02073669433594
Epoch [1315/1500] - Loss: 0.0075; AAPD: 117.00071716308594
Epoch [1316/1500] - Loss: 0.0088; AAPD: 123.19351196289062
Epoch [1317/1500] - Loss: 0.0145; AAPD: 92.50948333740234
Epoch [

Epoch [1442/1500] - Loss: 0.0118; AAPD: 97.12686157226562
Epoch [1443/1500] - Loss: 0.0081; AAPD: 108.2367172241211
Epoch [1444/1500] - Loss: 0.0132; AAPD: 92.99624633789062
Epoch [1445/1500] - Loss: 0.0079; AAPD: 106.94676971435547
Epoch [1446/1500] - Loss: 0.0093; AAPD: 92.40965270996094
Epoch [1447/1500] - Loss: 0.0085; AAPD: 99.79573822021484
Epoch [1448/1500] - Loss: 0.0105; AAPD: 109.47025299072266
Epoch [1449/1500] - Loss: 0.0095; AAPD: 130.107421875
Epoch [1450/1500] - Loss: 0.0088; AAPD: 109.53996276855469
Epoch [1451/1500] - Loss: 0.0086; AAPD: 107.86792755126953
Epoch [1452/1500] - Loss: 0.0073; AAPD: 89.12548828125
Epoch [1453/1500] - Loss: 0.0121; AAPD: 90.64733123779297
Epoch [1454/1500] - Loss: 0.0093; AAPD: 98.98939514160156
Epoch [1455/1500] - Loss: 0.0068; AAPD: 101.57489013671875
Epoch [1456/1500] - Loss: 0.0082; AAPD: 92.86619567871094
Epoch [1457/1500] - Loss: 0.0084; AAPD: 112.9929428100586
Epoch [1458/1500] - Loss: 0.0104; AAPD: 94.11241912841797
Epoch [1459/1500

In [45]:
model = IAModelLeakyReLU(input_dimension=train_int.shape[1], output_dimension=train_out3.shape[1])
model_MSE_Leaky_3 = train_model(model, 'best_model_MSE_Leaky_3.pth', train_dataloader, dev_dataloader, target_idx=3)
torch.save(model_MSE_Leaky_3.state_dict(), "model_MSE_Leaky_3.pth")

Epoch [1/1500] - Loss: 0.3175; AAPD: 7710.18994140625
Epoch [2/1500] - Loss: 0.0029; AAPD: 335.44097900390625
Epoch [3/1500] - Loss: 147.4273; AAPD: 201308.546875
Epoch [4/1500] - Loss: 566.9792; AAPD: 157540.234375
Epoch [5/1500] - Loss: 125.0837; AAPD: 167626.0625
Epoch [6/1500] - Loss: 272.1854; AAPD: 123170.4140625
Epoch [7/1500] - Loss: 263.8968; AAPD: 220452.796875
Epoch [8/1500] - Loss: 5912.9541; AAPD: 967685.25
Epoch [9/1500] - Loss: 2830.1025; AAPD: 809627.75
Epoch [10/1500] - Loss: 46.5882; AAPD: 151911.703125
Epoch [11/1500] - Loss: 7099.0610; AAPD: 851669.625
Epoch [12/1500] - Loss: 12346.0186; AAPD: 874623.875
Epoch [13/1500] - Loss: 78.9000; AAPD: 109716.4375
Epoch [14/1500] - Loss: 37.1650; AAPD: 144925.15625
Epoch [15/1500] - Loss: 5812.9341; AAPD: 850084.0
Epoch [16/1500] - Loss: 335.4174; AAPD: 50149.3828125
Epoch [17/1500] - Loss: 9.3039; AAPD: 21164.46484375
Epoch [18/1500] - Loss: 0.9089; AAPD: 12456.9462890625
Epoch [19/1500] - Loss: 783.1949; AAPD: 229592.8125
E

Epoch [155/1500] - Loss: 572612.3750; AAPD: 4584479.5
Epoch [156/1500] - Loss: 36001.0469; AAPD: 1590132.125
Epoch [157/1500] - Loss: 8423.3594; AAPD: 1158035.25
Epoch [158/1500] - Loss: 21000.6582; AAPD: 1152832.25
Epoch [159/1500] - Loss: 17745.9434; AAPD: 1129252.5
Epoch [160/1500] - Loss: 12322.8037; AAPD: 1148426.625
Epoch [161/1500] - Loss: 24740.1367; AAPD: 1079744.375
Epoch [162/1500] - Loss: 15531.2305; AAPD: 1029849.0625
Epoch [163/1500] - Loss: 5446.6885; AAPD: 970990.1875
Epoch [164/1500] - Loss: 10631.7998; AAPD: 881019.1875
Epoch [165/1500] - Loss: 8232.3037; AAPD: 763872.8125
Epoch [166/1500] - Loss: 4913.8872; AAPD: 717775.25
Epoch [167/1500] - Loss: 12414.7070; AAPD: 873552.6875
Epoch [168/1500] - Loss: 86043.7578; AAPD: 1331793.0
Epoch [169/1500] - Loss: 3790.2756; AAPD: 976545.0
Epoch [170/1500] - Loss: 49466.5781; AAPD: 2479893.75
Epoch [171/1500] - Loss: 23252.8652; AAPD: 1722773.25
Epoch [172/1500] - Loss: 19043308.0000; AAPD: 44786736.0
Epoch [173/1500] - Loss: 1

Epoch [307/1500] - Loss: 3789.5183; AAPD: 443922.71875
Epoch [308/1500] - Loss: 850.1287; AAPD: 291786.625
Epoch [309/1500] - Loss: 12641.9990; AAPD: 869260.5625
Epoch [310/1500] - Loss: 99.8668; AAPD: 124283.9453125
Epoch [311/1500] - Loss: 1077.3856; AAPD: 428650.03125
Epoch [312/1500] - Loss: 726.2256; AAPD: 239064.015625
Epoch [313/1500] - Loss: 6947.7900; AAPD: 801926.0
Epoch [314/1500] - Loss: 4946.9810; AAPD: 595379.4375
Epoch [315/1500] - Loss: 277.8846; AAPD: 141806.03125
Epoch [316/1500] - Loss: 404.9545; AAPD: 209885.53125
Epoch [317/1500] - Loss: 251.9924; AAPD: 336040.125
Epoch [318/1500] - Loss: 2116.6528; AAPD: 544909.5625
Epoch [319/1500] - Loss: 1433.2426; AAPD: 449706.0
Epoch [320/1500] - Loss: 6275.6802; AAPD: 1037431.5
Epoch [321/1500] - Loss: 140.3458; AAPD: 101523.1953125
Epoch [322/1500] - Loss: 1867.1200; AAPD: 425201.0625
Epoch [323/1500] - Loss: 1015.3354; AAPD: 340520.71875
Epoch [324/1500] - Loss: 8618.5605; AAPD: 1078821.375
Epoch [325/1500] - Loss: 31.3791

Epoch [460/1500] - Loss: 7.2030; AAPD: 20310.31640625
Epoch [461/1500] - Loss: 30.7841; AAPD: 56886.453125
Epoch [462/1500] - Loss: 0.4151; AAPD: 5205.208984375
Epoch [463/1500] - Loss: 0.2738; AAPD: 6539.19873046875
Epoch [464/1500] - Loss: 0.2040; AAPD: 4784.1533203125
Epoch [465/1500] - Loss: 16.2712; AAPD: 27530.01171875
Epoch [466/1500] - Loss: 1.2011; AAPD: 21073.6953125
Epoch [467/1500] - Loss: 10.8150; AAPD: 23303.396484375
Epoch [468/1500] - Loss: 0.8219; AAPD: 9674.37890625
Epoch [469/1500] - Loss: 33.2044; AAPD: 48945.5703125
Epoch [470/1500] - Loss: 4.4909; AAPD: 22368.458984375
Epoch [471/1500] - Loss: 1.5646; AAPD: 16527.51953125
Epoch [472/1500] - Loss: 0.1982; AAPD: 3681.601806640625
Epoch [473/1500] - Loss: 0.0982; AAPD: 3346.38427734375
Epoch [474/1500] - Loss: 25.4756; AAPD: 56747.35546875
Epoch [475/1500] - Loss: 1.5382; AAPD: 13218.3369140625
Epoch [476/1500] - Loss: 2.8265; AAPD: 13014.7236328125
Epoch [477/1500] - Loss: 0.2798; AAPD: 5474.00537109375
Epoch [478/1

Epoch [607/1500] - Loss: 0.0304; AAPD: 2011.696533203125
Epoch [608/1500] - Loss: 0.1005; AAPD: 2907.4208984375
Epoch [609/1500] - Loss: 0.0377; AAPD: 2230.307861328125
Epoch [610/1500] - Loss: 0.0717; AAPD: 3304.105224609375
Epoch [611/1500] - Loss: 0.0336; AAPD: 1857.994873046875
Epoch [612/1500] - Loss: 0.2080; AAPD: 5054.875
Epoch [613/1500] - Loss: 0.1471; AAPD: 3221.671875
Epoch [614/1500] - Loss: 0.0411; AAPD: 2292.166259765625
Epoch [615/1500] - Loss: 0.0297; AAPD: 1608.410888671875
Epoch [616/1500] - Loss: 0.0534; AAPD: 1755.661865234375
Epoch [617/1500] - Loss: 0.0456; AAPD: 3193.1767578125
Epoch [618/1500] - Loss: 0.1162; AAPD: 4250.5751953125
Epoch [619/1500] - Loss: 0.0813; AAPD: 1602.99169921875
Epoch [620/1500] - Loss: 0.1514; AAPD: 3994.428955078125
Epoch [621/1500] - Loss: 0.0344; AAPD: 2489.38525390625
Epoch [622/1500] - Loss: 0.1034; AAPD: 4538.841796875
Epoch [623/1500] - Loss: 0.0294; AAPD: 1387.7841796875
Epoch [624/1500] - Loss: 0.0345; AAPD: 1481.9405517578125
E

Epoch [754/1500] - Loss: 0.0068; AAPD: 524.6096801757812
Epoch [755/1500] - Loss: 0.0025; AAPD: 491.60479736328125
Epoch [756/1500] - Loss: 0.0020; AAPD: 424.2773742675781
Epoch [757/1500] - Loss: 0.0041; AAPD: 700.4921875
Epoch [758/1500] - Loss: 0.0044; AAPD: 716.969970703125
Epoch [759/1500] - Loss: 0.0068; AAPD: 903.2111206054688
Epoch [760/1500] - Loss: 0.0039; AAPD: 534.2553100585938
Epoch [761/1500] - Loss: 0.0034; AAPD: 629.3564453125
Epoch [762/1500] - Loss: 0.0088; AAPD: 561.0694580078125
Epoch [763/1500] - Loss: 0.0075; AAPD: 631.5448608398438
Epoch [764/1500] - Loss: 0.0036; AAPD: 625.3375244140625
Epoch [765/1500] - Loss: 0.0062; AAPD: 593.0667114257812
Epoch [766/1500] - Loss: 0.0050; AAPD: 547.6062622070312
Epoch [767/1500] - Loss: 0.0085; AAPD: 969.913818359375
Epoch [768/1500] - Loss: 0.0060; AAPD: 664.8218994140625
Epoch [769/1500] - Loss: 0.0073; AAPD: 633.9390258789062
Epoch [770/1500] - Loss: 0.0040; AAPD: 570.2659912109375
Epoch [771/1500] - Loss: 0.0051; AAPD: 67

Epoch [899/1500] - Loss: 0.0019; AAPD: 453.67852783203125
Epoch [900/1500] - Loss: 0.0055; AAPD: 652.0042114257812
Epoch [901/1500] - Loss: 0.0039; AAPD: 461.427978515625
Epoch [902/1500] - Loss: 0.0036; AAPD: 639.4188232421875
Epoch [903/1500] - Loss: 0.0045; AAPD: 626.8103637695312
Epoch [904/1500] - Loss: 0.0039; AAPD: 450.9974365234375
Epoch [905/1500] - Loss: 0.0034; AAPD: 616.5390625
Epoch [906/1500] - Loss: 0.0038; AAPD: 471.9814758300781
Epoch [907/1500] - Loss: 0.0053; AAPD: 970.41650390625
Epoch [908/1500] - Loss: 0.0042; AAPD: 440.9172058105469
Epoch [909/1500] - Loss: 0.0025; AAPD: 476.70782470703125
Epoch [910/1500] - Loss: 0.0053; AAPD: 648.5313110351562
Epoch [911/1500] - Loss: 0.0016; AAPD: 372.9001159667969
Epoch [912/1500] - Loss: 0.0036; AAPD: 486.56744384765625
Epoch [913/1500] - Loss: 0.0023; AAPD: 471.1046142578125
Epoch [914/1500] - Loss: 0.0056; AAPD: 752.4470825195312
Epoch [915/1500] - Loss: 0.0037; AAPD: 423.76080322265625
Epoch [916/1500] - Loss: 0.0044; AAP

Epoch [1043/1500] - Loss: 0.0021; AAPD: 303.1966552734375
Epoch [1044/1500] - Loss: 0.0013; AAPD: 254.43251037597656
Epoch [1045/1500] - Loss: 0.0013; AAPD: 255.68994140625
Epoch [1046/1500] - Loss: 0.0023; AAPD: 337.1226806640625
Epoch [1047/1500] - Loss: 0.0011; AAPD: 260.2375793457031
Epoch [1048/1500] - Loss: 0.0031; AAPD: 315.5895690917969
Epoch [1049/1500] - Loss: 0.0017; AAPD: 287.0720520019531
Epoch [1050/1500] - Loss: 0.0049; AAPD: 289.61199951171875
Epoch [1051/1500] - Loss: 0.0011; AAPD: 266.78302001953125
Epoch [1052/1500] - Loss: 0.0025; AAPD: 262.71136474609375
Epoch [1053/1500] - Loss: 0.0014; AAPD: 274.5184020996094
Epoch [1054/1500] - Loss: 0.0013; AAPD: 278.5005187988281
Epoch [1055/1500] - Loss: 0.0015; AAPD: 257.6500244140625
Epoch [1056/1500] - Loss: 0.0035; AAPD: 307.9548645019531
Epoch [1057/1500] - Loss: 0.0009; AAPD: 270.494140625
Epoch [1058/1500] - Loss: 0.0021; AAPD: 288.93670654296875
Epoch [1059/1500] - Loss: 0.0012; AAPD: 251.38748168945312
Epoch [1060/15

Epoch [1185/1500] - Loss: 0.0015; AAPD: 329.4267883300781
Epoch [1186/1500] - Loss: 0.0015; AAPD: 246.45359802246094
Epoch [1187/1500] - Loss: 0.0014; AAPD: 253.9798583984375
Epoch [1188/1500] - Loss: 0.0032; AAPD: 361.8206787109375
Epoch [1189/1500] - Loss: 0.0013; AAPD: 258.5246276855469
Epoch [1190/1500] - Loss: 0.0015; AAPD: 294.1305847167969
Epoch [1191/1500] - Loss: 0.0017; AAPD: 234.2801971435547
Epoch [1192/1500] - Loss: 0.0009; AAPD: 230.5029296875
Epoch [1193/1500] - Loss: 0.0012; AAPD: 236.14830017089844
Epoch [1194/1500] - Loss: 0.0012; AAPD: 233.05340576171875
Epoch [1195/1500] - Loss: 0.0027; AAPD: 244.65310668945312
Epoch [1196/1500] - Loss: 0.0014; AAPD: 256.1174621582031
Epoch [1197/1500] - Loss: 0.0011; AAPD: 243.62648010253906
Epoch [1198/1500] - Loss: 0.0033; AAPD: 274.80474853515625
Epoch [1199/1500] - Loss: 0.0008; AAPD: 236.78555297851562
Epoch [1200/1500] - Loss: 0.0026; AAPD: 250.4051971435547
Epoch [1201/1500] - Loss: 0.0016; AAPD: 260.754638671875
Epoch [1202

Epoch [1326/1500] - Loss: 0.0018; AAPD: 276.7638854980469
Epoch [1327/1500] - Loss: 0.0009; AAPD: 239.8800506591797
Epoch [1328/1500] - Loss: 0.0011; AAPD: 230.89447021484375
Epoch [1329/1500] - Loss: 0.0034; AAPD: 286.7662658691406
Epoch [1330/1500] - Loss: 0.0021; AAPD: 223.17880249023438
Epoch [1331/1500] - Loss: 0.0043; AAPD: 281.4814758300781
Epoch [1332/1500] - Loss: 0.0016; AAPD: 265.0247497558594
Epoch [1333/1500] - Loss: 0.0036; AAPD: 248.96768188476562
Epoch [1334/1500] - Loss: 0.0008; AAPD: 239.31019592285156
Epoch [1335/1500] - Loss: 0.0007; AAPD: 232.145751953125
Epoch [1336/1500] - Loss: 0.0015; AAPD: 228.21066284179688
Epoch [1337/1500] - Loss: 0.0018; AAPD: 220.33456420898438
Epoch [1338/1500] - Loss: 0.0013; AAPD: 268.177978515625
Epoch [1339/1500] - Loss: 0.0009; AAPD: 229.4944305419922
Epoch [1340/1500] - Loss: 0.0009; AAPD: 206.42828369140625
Epoch [1341/1500] - Loss: 0.0014; AAPD: 216.24420166015625
Epoch [1342/1500] - Loss: 0.0020; AAPD: 239.1339569091797
Epoch [1

Epoch [1467/1500] - Loss: 0.0013; AAPD: 206.14712524414062
Epoch [1468/1500] - Loss: 0.0020; AAPD: 270.4286193847656
Epoch [1469/1500] - Loss: 0.0010; AAPD: 209.2259063720703
Epoch [1470/1500] - Loss: 0.0012; AAPD: 238.78797912597656
Epoch [1471/1500] - Loss: 0.0017; AAPD: 223.548828125
Epoch [1472/1500] - Loss: 0.0006; AAPD: 191.2397003173828
Epoch [1473/1500] - Loss: 0.0011; AAPD: 256.3915710449219
Epoch [1474/1500] - Loss: 0.0008; AAPD: 230.20697021484375
Epoch [1475/1500] - Loss: 0.0042; AAPD: 237.97393798828125
Epoch [1476/1500] - Loss: 0.0009; AAPD: 240.28155517578125
Epoch [1477/1500] - Loss: 0.0016; AAPD: 209.41900634765625
Epoch [1478/1500] - Loss: 0.0033; AAPD: 306.7059631347656
Epoch [1479/1500] - Loss: 0.0038; AAPD: 272.2531433105469
Epoch [1480/1500] - Loss: 0.0012; AAPD: 199.89236450195312
Epoch [1481/1500] - Loss: 0.0011; AAPD: 227.3121795654297
Epoch [1482/1500] - Loss: 0.0010; AAPD: 212.11279296875
Epoch [1483/1500] - Loss: 0.0011; AAPD: 204.3422393798828
Epoch [1484/1

In [47]:
model = IAModelLeakyReLU(input_dimension=train_int.shape[1], output_dimension=train_out1.shape[1])
model_AAPD_Leaky_1 = train_model(model, 'best_model_AAPD_Leaky_1.pth', train_dataloader, dev_dataloader, target_idx=1)
torch.save(model_AAPD_Leaky_1.state_dict(), "model_AAPD_Leaky_1.pth")

Epoch [1/1500] - Loss: 0.7166; AAPD: 0.5851823091506958
Epoch [2/1500] - Loss: 0.5199; AAPD: 0.666333794593811
Epoch [3/1500] - Loss: 0.6303; AAPD: 0.6103214025497437
Epoch [4/1500] - Loss: 0.5232; AAPD: 0.5869712233543396
Epoch [5/1500] - Loss: 0.6552; AAPD: 0.6138136386871338
Epoch [6/1500] - Loss: 0.6675; AAPD: 0.5608230829238892
Epoch [7/1500] - Loss: 0.5563; AAPD: 0.5924200415611267
Epoch [8/1500] - Loss: 116.3908; AAPD: 158.45579528808594
Epoch [9/1500] - Loss: 91.2725; AAPD: 80.94341278076172
Epoch [10/1500] - Loss: 90.7566; AAPD: 38.227577209472656
Epoch [11/1500] - Loss: 72.2557; AAPD: 61.4130973815918
Epoch [12/1500] - Loss: 348.3586; AAPD: 266.05218505859375
Epoch [13/1500] - Loss: 22.1524; AAPD: 41.807769775390625
Epoch [14/1500] - Loss: 20.7119; AAPD: 20.29962158203125
Epoch [15/1500] - Loss: 16.3702; AAPD: 13.755488395690918
Epoch [16/1500] - Loss: 750.9442; AAPD: 1365.4560546875
Epoch [17/1500] - Loss: 62.7072; AAPD: 63.27603530883789
Epoch [18/1500] - Loss: 580.7094; AA

Epoch [144/1500] - Loss: 74.9297; AAPD: 67.3440933227539
Epoch [145/1500] - Loss: 19.0926; AAPD: 33.11275100708008
Epoch [146/1500] - Loss: 24.2234; AAPD: 37.72980499267578
Epoch [147/1500] - Loss: 24.4428; AAPD: 23.10069465637207
Epoch [148/1500] - Loss: 80.3000; AAPD: 29.955272674560547
Epoch [149/1500] - Loss: 181.0479; AAPD: 108.8330078125
Epoch [150/1500] - Loss: 60.5815; AAPD: 64.04488372802734
Epoch [151/1500] - Loss: 21.4498; AAPD: 24.794214248657227
Epoch [152/1500] - Loss: 31.0676; AAPD: 42.43649673461914
Epoch [153/1500] - Loss: 5.7307; AAPD: 7.146366596221924
Epoch [154/1500] - Loss: 109.4848; AAPD: 89.28895568847656
Epoch [155/1500] - Loss: 19.0543; AAPD: 35.71908950805664
Epoch [156/1500] - Loss: 11.0614; AAPD: 9.394319534301758
Epoch [157/1500] - Loss: 5.6702; AAPD: 6.650601387023926
Epoch [158/1500] - Loss: 21.2990; AAPD: 31.88852882385254
Epoch [159/1500] - Loss: 13.1251; AAPD: 10.14181137084961
Epoch [160/1500] - Loss: 10.3136; AAPD: 12.066737174987793
Epoch [161/1500

Epoch [286/1500] - Loss: 0.6036; AAPD: 0.8160324692726135
Epoch [287/1500] - Loss: 0.5041; AAPD: 0.691678524017334
Epoch [288/1500] - Loss: 0.8135; AAPD: 0.6940560936927795
Epoch [289/1500] - Loss: 0.7156; AAPD: 0.6350792050361633
Epoch [290/1500] - Loss: 0.6761; AAPD: 0.6209929585456848
Epoch [291/1500] - Loss: 0.7433; AAPD: 0.5956257581710815
Epoch [292/1500] - Loss: 0.6296; AAPD: 0.622054398059845
Epoch [293/1500] - Loss: 0.7031; AAPD: 0.7054399251937866
Epoch [294/1500] - Loss: 0.6323; AAPD: 0.5869311690330505
Epoch [295/1500] - Loss: 0.7010; AAPD: 0.6269994378089905
Epoch [296/1500] - Loss: 0.6540; AAPD: 0.6453002691268921
Epoch [297/1500] - Loss: 0.5800; AAPD: 0.7170726656913757
Epoch [298/1500] - Loss: 0.6646; AAPD: 0.6086488366127014
Epoch [299/1500] - Loss: 0.5843; AAPD: 0.6047635674476624
Epoch [300/1500] - Loss: 0.4855; AAPD: 0.6980475187301636
Epoch [301/1500] - Loss: 0.6875; AAPD: 0.6463329792022705
Epoch [302/1500] - Loss: 0.8567; AAPD: 0.7223326563835144
Epoch [303/1500]

Epoch [428/1500] - Loss: 0.5700; AAPD: 0.5941886901855469
Epoch [429/1500] - Loss: 0.5597; AAPD: 0.5935763716697693
Epoch [430/1500] - Loss: 0.5824; AAPD: 0.5928866267204285
Epoch [431/1500] - Loss: 0.6048; AAPD: 0.5995749235153198
Epoch [432/1500] - Loss: 0.4898; AAPD: 0.5945943593978882
Epoch [433/1500] - Loss: 0.5911; AAPD: 0.5959749221801758
Epoch [434/1500] - Loss: 0.6262; AAPD: 0.5939836502075195
Epoch [435/1500] - Loss: 0.5472; AAPD: 0.5925996899604797
Epoch [436/1500] - Loss: 0.6412; AAPD: 0.593814492225647
Epoch [437/1500] - Loss: 0.6421; AAPD: 0.5952969193458557
Epoch [438/1500] - Loss: 0.5678; AAPD: 0.5935941338539124
Epoch [439/1500] - Loss: 0.6290; AAPD: 0.5935698747634888
Epoch [440/1500] - Loss: 0.6011; AAPD: 0.5956853032112122
Epoch [441/1500] - Loss: 0.5525; AAPD: 0.5975547432899475
Epoch [442/1500] - Loss: 0.4520; AAPD: 0.5941542983055115
Epoch [443/1500] - Loss: 0.5118; AAPD: 0.5951147675514221
Epoch [444/1500] - Loss: 0.6929; AAPD: 0.5983401536941528
Epoch [445/1500

Epoch [570/1500] - Loss: 0.3936; AAPD: 0.453373521566391
Epoch [571/1500] - Loss: 0.5915; AAPD: 0.45552462339401245
Epoch [572/1500] - Loss: 0.3992; AAPD: 0.45580893754959106
Epoch [573/1500] - Loss: 0.4107; AAPD: 0.45298391580581665
Epoch [574/1500] - Loss: 0.4805; AAPD: 0.4529327154159546
Epoch [575/1500] - Loss: 0.4878; AAPD: 0.4518577456474304
Epoch [576/1500] - Loss: 0.4828; AAPD: 0.4523601233959198
Epoch [577/1500] - Loss: 0.4858; AAPD: 0.46073397994041443
Epoch [578/1500] - Loss: 0.4919; AAPD: 0.452897846698761
Epoch [579/1500] - Loss: 0.4004; AAPD: 0.460168719291687
Epoch [580/1500] - Loss: 0.4585; AAPD: 0.4553435146808624
Epoch [581/1500] - Loss: 0.4831; AAPD: 0.4518151581287384
Epoch [582/1500] - Loss: 0.3466; AAPD: 0.44869959354400635
Epoch [583/1500] - Loss: 0.4818; AAPD: 0.44684767723083496
Epoch [584/1500] - Loss: 0.5184; AAPD: 0.4546654224395752
Epoch [585/1500] - Loss: 0.4337; AAPD: 0.456073135137558
Epoch [586/1500] - Loss: 0.5283; AAPD: 0.46658414602279663
Epoch [587/

Epoch [711/1500] - Loss: 0.4551; AAPD: 0.41345667839050293
Epoch [712/1500] - Loss: 0.4383; AAPD: 0.43039101362228394
Epoch [713/1500] - Loss: 0.3610; AAPD: 0.42482084035873413
Epoch [714/1500] - Loss: 0.3642; AAPD: 0.42000865936279297
Epoch [715/1500] - Loss: 0.3809; AAPD: 0.42090585827827454
Epoch [716/1500] - Loss: 0.4458; AAPD: 0.43007656931877136
Epoch [717/1500] - Loss: 0.3793; AAPD: 0.41846126317977905
Epoch [718/1500] - Loss: 0.4466; AAPD: 0.4186999201774597
Epoch [719/1500] - Loss: 0.3316; AAPD: 0.41493403911590576
Epoch [720/1500] - Loss: 0.3913; AAPD: 0.41953545808792114
Epoch [721/1500] - Loss: 0.4194; AAPD: 0.4113923907279968
Epoch [722/1500] - Loss: 0.4476; AAPD: 0.443976491689682
Epoch [723/1500] - Loss: 0.4643; AAPD: 0.4139038026332855
Epoch [724/1500] - Loss: 0.4600; AAPD: 0.43692857027053833
Epoch [725/1500] - Loss: 0.4175; AAPD: 0.4100925624370575
Epoch [726/1500] - Loss: 0.3435; AAPD: 0.40642374753952026
Epoch [727/1500] - Loss: 0.4367; AAPD: 0.4346056580543518
Epoc

Epoch [852/1500] - Loss: 0.4377; AAPD: 0.38180023431777954
Epoch [853/1500] - Loss: 0.3320; AAPD: 0.38204601407051086
Epoch [854/1500] - Loss: 0.3662; AAPD: 0.3807382881641388
Epoch [855/1500] - Loss: 0.3516; AAPD: 0.3818840980529785
Epoch [856/1500] - Loss: 0.4047; AAPD: 0.3823440670967102
Epoch [857/1500] - Loss: 0.4123; AAPD: 0.3813249170780182
Epoch [858/1500] - Loss: 0.4192; AAPD: 0.3792289197444916
Epoch [859/1500] - Loss: 0.4152; AAPD: 0.3796774446964264
Epoch [860/1500] - Loss: 0.3293; AAPD: 0.3797742426395416
Epoch [861/1500] - Loss: 0.3947; AAPD: 0.37958359718322754
Epoch [862/1500] - Loss: 0.4229; AAPD: 0.37889209389686584
Epoch [863/1500] - Loss: 0.3489; AAPD: 0.37929946184158325
Epoch [864/1500] - Loss: 0.3358; AAPD: 0.3788304328918457
Epoch [865/1500] - Loss: 0.2913; AAPD: 0.3789653182029724
Epoch [866/1500] - Loss: 0.3928; AAPD: 0.37941527366638184
Epoch [867/1500] - Loss: 0.2964; AAPD: 0.3789919316768646
Epoch [868/1500] - Loss: 0.3401; AAPD: 0.37921205163002014
Epoch [

Epoch [993/1500] - Loss: 0.3106; AAPD: 0.34423261880874634
Epoch [994/1500] - Loss: 0.3552; AAPD: 0.34648212790489197
Epoch [995/1500] - Loss: 0.3590; AAPD: 0.34482526779174805
Epoch [996/1500] - Loss: 0.3428; AAPD: 0.3429204225540161
Epoch [997/1500] - Loss: 0.3942; AAPD: 0.34296971559524536
Epoch [998/1500] - Loss: 0.6161; AAPD: 0.34392067790031433
Epoch [999/1500] - Loss: 0.3007; AAPD: 0.3432592749595642
Epoch [1000/1500] - Loss: 0.2801; AAPD: 0.3421720564365387
Epoch [1001/1500] - Loss: 0.2821; AAPD: 0.3411102592945099
Epoch [1002/1500] - Loss: 0.3010; AAPD: 0.34125351905822754
Epoch [1003/1500] - Loss: 0.3508; AAPD: 0.3410276472568512
Epoch [1004/1500] - Loss: 0.2912; AAPD: 0.3410745859146118
Epoch [1005/1500] - Loss: 0.3006; AAPD: 0.34099411964416504
Epoch [1006/1500] - Loss: 0.3671; AAPD: 0.3410177528858185
Epoch [1007/1500] - Loss: 0.3753; AAPD: 0.34116870164871216
Epoch [1008/1500] - Loss: 0.2964; AAPD: 0.3410249650478363
Epoch [1009/1500] - Loss: 0.3730; AAPD: 0.3410718142986

Epoch [1132/1500] - Loss: 0.2728; AAPD: 0.337746262550354
Epoch [1133/1500] - Loss: 0.2831; AAPD: 0.3372507691383362
Epoch [1134/1500] - Loss: 0.2856; AAPD: 0.33728504180908203
Epoch [1135/1500] - Loss: 0.3457; AAPD: 0.3371683359146118
Epoch [1136/1500] - Loss: 0.3666; AAPD: 0.3372194170951843
Epoch [1137/1500] - Loss: 0.3630; AAPD: 0.337156742811203
Epoch [1138/1500] - Loss: 0.3123; AAPD: 0.3372211456298828
Epoch [1139/1500] - Loss: 0.4066; AAPD: 0.33709990978240967
Epoch [1140/1500] - Loss: 0.2695; AAPD: 0.33715173602104187
Epoch [1141/1500] - Loss: 0.4292; AAPD: 0.337114155292511
Epoch [1142/1500] - Loss: 0.3537; AAPD: 0.3371366858482361
Epoch [1143/1500] - Loss: 0.4297; AAPD: 0.3370693624019623
Epoch [1144/1500] - Loss: 0.3473; AAPD: 0.3369286358356476
Epoch [1145/1500] - Loss: 0.3613; AAPD: 0.33693838119506836
Epoch [1146/1500] - Loss: 0.2993; AAPD: 0.3371535837650299
Epoch [1147/1500] - Loss: 0.3502; AAPD: 0.33694034814834595
Epoch [1148/1500] - Loss: 0.3342; AAPD: 0.336861461400

Epoch [1271/1500] - Loss: 0.2576; AAPD: 0.33300065994262695
Epoch [1272/1500] - Loss: 0.3876; AAPD: 0.33308449387550354
Epoch [1273/1500] - Loss: 0.2999; AAPD: 0.33299919962882996
Epoch [1274/1500] - Loss: 0.3106; AAPD: 0.3330918550491333
Epoch [1275/1500] - Loss: 0.3428; AAPD: 0.3327781558036804
Epoch [1276/1500] - Loss: 0.3877; AAPD: 0.3329600989818573
Epoch [1277/1500] - Loss: 0.3507; AAPD: 0.33304208517074585
Epoch [1278/1500] - Loss: 0.3424; AAPD: 0.33283132314682007
Epoch [1279/1500] - Loss: 0.3735; AAPD: 0.3328007757663727
Epoch [1280/1500] - Loss: 0.3130; AAPD: 0.33269456028938293
Epoch [1281/1500] - Loss: 0.3298; AAPD: 0.3327656090259552
Epoch [1282/1500] - Loss: 0.4249; AAPD: 0.332667738199234
Epoch [1283/1500] - Loss: 0.4203; AAPD: 0.3326597213745117
Epoch [1284/1500] - Loss: 0.3331; AAPD: 0.3325924575328827
Epoch [1285/1500] - Loss: 0.2914; AAPD: 0.33264151215553284
Epoch [1286/1500] - Loss: 0.3489; AAPD: 0.3325801193714142
Epoch [1287/1500] - Loss: 0.3246; AAPD: 0.33247080

Epoch [1409/1500] - Loss: 0.3203; AAPD: 0.32907190918922424
Epoch [1410/1500] - Loss: 0.3481; AAPD: 0.32909706234931946
Epoch [1411/1500] - Loss: 0.3763; AAPD: 0.32935410737991333
Epoch [1412/1500] - Loss: 0.3912; AAPD: 0.32924163341522217
Epoch [1413/1500] - Loss: 0.2905; AAPD: 0.3291609287261963
Epoch [1414/1500] - Loss: 0.4070; AAPD: 0.32914474606513977
Epoch [1415/1500] - Loss: 0.3336; AAPD: 0.32909390330314636
Epoch [1416/1500] - Loss: 0.3210; AAPD: 0.32894912362098694
Epoch [1417/1500] - Loss: 0.3464; AAPD: 0.3288840055465698
Epoch [1418/1500] - Loss: 0.3174; AAPD: 0.329008549451828
Epoch [1419/1500] - Loss: 0.3223; AAPD: 0.3291158080101013
Epoch [1420/1500] - Loss: 0.2686; AAPD: 0.3289637267589569
Epoch [1421/1500] - Loss: 0.3408; AAPD: 0.32875165343284607
Epoch [1422/1500] - Loss: 0.2810; AAPD: 0.3288213014602661
Epoch [1423/1500] - Loss: 0.3614; AAPD: 0.32909727096557617
Epoch [1424/1500] - Loss: 0.3057; AAPD: 0.32872700691223145
Epoch [1425/1500] - Loss: 0.3505; AAPD: 0.32875

In [48]:
model = IAModelLeakyReLU(input_dimension=train_int.shape[1], output_dimension=train_out2.shape[1])
model_AAPD_Leaky_2 = train_model(model, 'best_model_AAPD_Leaky_2.pth', train_dataloader, dev_dataloader, target_idx=2)
torch.save(model_AAPD_Leaky_2.state_dict(), "model_AAPD_Leaky_2.pth")

Epoch [1/1500] - Loss: 249340.2656; AAPD: 499381.875
Epoch [2/1500] - Loss: 757129.9375; AAPD: 2064029.75
Epoch [3/1500] - Loss: 571006.9375; AAPD: 1451642.5
Epoch [4/1500] - Loss: 228524.5312; AAPD: 1571852.125
Epoch [5/1500] - Loss: 441261.6875; AAPD: 946879.5625
Epoch [6/1500] - Loss: 554464.8750; AAPD: 864326.25
Epoch [7/1500] - Loss: 5130773.5000; AAPD: 3317725.5
Epoch [8/1500] - Loss: 566553.0625; AAPD: 751910.4375
Epoch [9/1500] - Loss: 259824.7031; AAPD: 1153751.0
Epoch [10/1500] - Loss: 1106277.8750; AAPD: 1170112.375
Epoch [11/1500] - Loss: 1572719.5000; AAPD: 942984.8125
Epoch [12/1500] - Loss: 101981.5156; AAPD: 287628.53125
Epoch [13/1500] - Loss: 306808.1250; AAPD: 475252.375
Epoch [14/1500] - Loss: 266865.1562; AAPD: 414847.1875
Epoch [15/1500] - Loss: 100091.2109; AAPD: 235618.203125
Epoch [16/1500] - Loss: 3942813.0000; AAPD: 1992526.875
Epoch [17/1500] - Loss: 162733.7344; AAPD: 738041.75
Epoch [18/1500] - Loss: 1948711.3750; AAPD: 3333468.0
Epoch [19/1500] - Loss: 42

Epoch [151/1500] - Loss: 3186860.7500; AAPD: 12710721.0
Epoch [152/1500] - Loss: 1935069.3750; AAPD: 6087433.0
Epoch [153/1500] - Loss: 1202841.6250; AAPD: 5220838.5
Epoch [154/1500] - Loss: 921379.0000; AAPD: 2876350.25
Epoch [155/1500] - Loss: 442222.4688; AAPD: 1170087.625
Epoch [156/1500] - Loss: 379012.3125; AAPD: 1348647.875
Epoch [157/1500] - Loss: 738431.6875; AAPD: 1710776.75
Epoch [158/1500] - Loss: 127354.0391; AAPD: 256297.34375
Epoch [159/1500] - Loss: 123721.5547; AAPD: 540007.125
Epoch [160/1500] - Loss: 670193.6250; AAPD: 811003.875
Epoch [161/1500] - Loss: 352218.6875; AAPD: 4053372.25
Epoch [162/1500] - Loss: 1953244.0000; AAPD: 16128501.0
Epoch [163/1500] - Loss: 194229.4062; AAPD: 942885.375
Epoch [164/1500] - Loss: 953204.2500; AAPD: 5182769.5
Epoch [165/1500] - Loss: 1610032.6250; AAPD: 4097892.75
Epoch [166/1500] - Loss: 4450613.5000; AAPD: 847042.0
Epoch [167/1500] - Loss: 300613.4375; AAPD: 1793859.625
Epoch [168/1500] - Loss: 584751.7500; AAPD: 928833.125
Epoc

Epoch [299/1500] - Loss: 42666.9609; AAPD: 40860.87109375
Epoch [300/1500] - Loss: 7748.7246; AAPD: 10008.2802734375
Epoch [301/1500] - Loss: 9589.9375; AAPD: 23403.220703125
Epoch [302/1500] - Loss: 35670.6602; AAPD: 34954.8203125
Epoch [303/1500] - Loss: 8095.1147; AAPD: 22916.216796875
Epoch [304/1500] - Loss: 26165.1387; AAPD: 22565.3046875
Epoch [305/1500] - Loss: 19340.4453; AAPD: 34481.1640625
Epoch [306/1500] - Loss: 50216.3633; AAPD: 14498.908203125
Epoch [307/1500] - Loss: 83130.2812; AAPD: 87965.796875
Epoch [308/1500] - Loss: 17339.6016; AAPD: 27107.587890625
Epoch [309/1500] - Loss: 4429.4844; AAPD: 26664.28515625
Epoch [310/1500] - Loss: 32626.3242; AAPD: 36088.58203125
Epoch [311/1500] - Loss: 281148.5938; AAPD: 32837.5078125
Epoch [312/1500] - Loss: 99833.6484; AAPD: 32060.513671875
Epoch [313/1500] - Loss: 3399.8506; AAPD: 20691.529296875
Epoch [314/1500] - Loss: 3534.1655; AAPD: 48671.53515625
Epoch [315/1500] - Loss: 15963.3496; AAPD: 28839.12890625
Epoch [316/1500] 

Epoch [442/1500] - Loss: 74376.9609; AAPD: 32803.2109375
Epoch [443/1500] - Loss: 44306.0859; AAPD: 201507.765625
Epoch [444/1500] - Loss: 416614.9062; AAPD: 747746.25
Epoch [445/1500] - Loss: 132664.6875; AAPD: 378831.75
Epoch [446/1500] - Loss: 459921.8750; AAPD: 432348.4375
Epoch [447/1500] - Loss: 46665.3906; AAPD: 138297.328125
Epoch [448/1500] - Loss: 324178.0000; AAPD: 277905.90625
Epoch [449/1500] - Loss: 206778.2656; AAPD: 571653.1875
Epoch [450/1500] - Loss: 51694.3516; AAPD: 192707.953125
Epoch [451/1500] - Loss: 15472.8262; AAPD: 96278.859375
Epoch [452/1500] - Loss: 105600.9219; AAPD: 89560.421875
Epoch [453/1500] - Loss: 47653.0820; AAPD: 49692.83203125
Epoch [454/1500] - Loss: 192957.2031; AAPD: 42202.7265625
Epoch [455/1500] - Loss: 50955.3242; AAPD: 40460.83203125
Epoch [456/1500] - Loss: 75011.5234; AAPD: 72987.59375
Epoch [457/1500] - Loss: 121751.7188; AAPD: 130572.84375
Epoch [458/1500] - Loss: 773734.4375; AAPD: 108071.7890625
Epoch [459/1500] - Loss: 202934.0312;

Epoch [586/1500] - Loss: 142059.4375; AAPD: 12724.048828125
Epoch [587/1500] - Loss: 5826.3428; AAPD: 24077.9765625
Epoch [588/1500] - Loss: 7557.9995; AAPD: 19911.326171875
Epoch [589/1500] - Loss: 9082.6719; AAPD: 23520.07421875
Epoch [590/1500] - Loss: 5990.0176; AAPD: 24626.07421875
Epoch [591/1500] - Loss: 5882.3599; AAPD: 13671.8603515625
Epoch [592/1500] - Loss: 11516.0674; AAPD: 11929.080078125
Epoch [593/1500] - Loss: 9417.5498; AAPD: 19443.921875
Epoch [594/1500] - Loss: 12260.7939; AAPD: 27176.51171875
Epoch [595/1500] - Loss: 6880.8853; AAPD: 6854.50048828125
Epoch [596/1500] - Loss: 5725.6895; AAPD: 15232.38671875
Epoch [597/1500] - Loss: 2365.1006; AAPD: 6981.2734375
Epoch [598/1500] - Loss: 73836.0859; AAPD: 17854.63671875
Epoch [599/1500] - Loss: 4300.0522; AAPD: 24396.353515625
Epoch [600/1500] - Loss: 26428.4883; AAPD: 21562.328125
Epoch [601/1500] - Loss: 36456.9727; AAPD: 8448.3544921875
Epoch [602/1500] - Loss: 6466.0845; AAPD: 29196.177734375
Epoch [603/1500] - Lo

Epoch [728/1500] - Loss: 5228.6631; AAPD: 11464.0927734375
Epoch [729/1500] - Loss: 2755.3076; AAPD: 8301.2451171875
Epoch [730/1500] - Loss: 745.4514; AAPD: 4210.39013671875
Epoch [731/1500] - Loss: 1552.1283; AAPD: 3294.635986328125
Epoch [732/1500] - Loss: 2220.3884; AAPD: 10541.8466796875
Epoch [733/1500] - Loss: 706.0333; AAPD: 6646.9150390625
Epoch [734/1500] - Loss: 10943.9844; AAPD: 6170.416015625
Epoch [735/1500] - Loss: 2772.1990; AAPD: 10535.3798828125
Epoch [736/1500] - Loss: 3058.3479; AAPD: 8434.4697265625
Epoch [737/1500] - Loss: 4140.2598; AAPD: 7608.3359375
Epoch [738/1500] - Loss: 5067.1777; AAPD: 11452.0732421875
Epoch [739/1500] - Loss: 965.4302; AAPD: 5828.50341796875
Epoch [740/1500] - Loss: 3409.0583; AAPD: 7528.1708984375
Epoch [741/1500] - Loss: 4598.2544; AAPD: 7935.66259765625
Epoch [742/1500] - Loss: 3603.6982; AAPD: 7632.33154296875
Epoch [743/1500] - Loss: 2807.2612; AAPD: 14766.8779296875
Epoch [744/1500] - Loss: 6955.1040; AAPD: 24841.537109375
Epoch [74

Epoch [867/1500] - Loss: 338.6506; AAPD: 805.2174072265625
Epoch [868/1500] - Loss: 580.2740; AAPD: 1005.24267578125
Epoch [869/1500] - Loss: 436.3480; AAPD: 955.6751098632812
Epoch [870/1500] - Loss: 461.7016; AAPD: 986.4972534179688
Epoch [871/1500] - Loss: 169.2990; AAPD: 1101.5238037109375
Epoch [872/1500] - Loss: 2221.3284; AAPD: 864.534423828125
Epoch [873/1500] - Loss: 317.4445; AAPD: 1043.7022705078125
Epoch [874/1500] - Loss: 1003.5575; AAPD: 1077.1796875
Epoch [875/1500] - Loss: 454.8426; AAPD: 955.2442016601562
Epoch [876/1500] - Loss: 1165.0266; AAPD: 1277.234619140625
Epoch [877/1500] - Loss: 186.4477; AAPD: 915.5498046875
Epoch [878/1500] - Loss: 258.8777; AAPD: 1314.0567626953125
Epoch [879/1500] - Loss: 459.9326; AAPD: 1041.3607177734375
Epoch [880/1500] - Loss: 596.9404; AAPD: 974.0618896484375
Epoch [881/1500] - Loss: 172.9627; AAPD: 942.7632446289062
Epoch [882/1500] - Loss: 580.9109; AAPD: 902.8137817382812
Epoch [883/1500] - Loss: 644.1508; AAPD: 1233.968994140625


Epoch [1007/1500] - Loss: 366.4428; AAPD: 618.3821411132812
Epoch [1008/1500] - Loss: 412.9609; AAPD: 654.4611206054688
Epoch [1009/1500] - Loss: 133.1670; AAPD: 612.2127075195312
Epoch [1010/1500] - Loss: 155.2627; AAPD: 628.3903198242188
Epoch [1011/1500] - Loss: 391.8442; AAPD: 651.81982421875
Epoch [1012/1500] - Loss: 336.4149; AAPD: 636.549072265625
Epoch [1013/1500] - Loss: 66.6488; AAPD: 634.0093383789062
Epoch [1014/1500] - Loss: 286.2495; AAPD: 613.33544921875
Epoch [1015/1500] - Loss: 181.5602; AAPD: 647.332763671875
Epoch [1016/1500] - Loss: 268.3268; AAPD: 630.5673217773438
Epoch [1017/1500] - Loss: 271.8217; AAPD: 656.8562622070312
Epoch [1018/1500] - Loss: 317.0833; AAPD: 657.6755981445312
Epoch [1019/1500] - Loss: 147.5117; AAPD: 634.3761596679688
Epoch [1020/1500] - Loss: 151.4620; AAPD: 628.0120849609375
Epoch [1021/1500] - Loss: 973.8192; AAPD: 597.4425659179688
Epoch [1022/1500] - Loss: 138.5681; AAPD: 606.4277954101562
Epoch [1023/1500] - Loss: 616.2542; AAPD: 635.6

Epoch [1145/1500] - Loss: 197.6927; AAPD: 630.38916015625
Epoch [1146/1500] - Loss: 185.6253; AAPD: 590.2152099609375
Epoch [1147/1500] - Loss: 123.6733; AAPD: 603.3587646484375
Epoch [1148/1500] - Loss: 599.9780; AAPD: 616.4967651367188
Epoch [1149/1500] - Loss: 100.4107; AAPD: 606.2229614257812
Epoch [1150/1500] - Loss: 195.7740; AAPD: 602.0103759765625
Epoch [1151/1500] - Loss: 11312.7646; AAPD: 613.4505004882812
Epoch [1152/1500] - Loss: 156.8459; AAPD: 610.1721801757812
Epoch [1153/1500] - Loss: 87.7723; AAPD: 629.8226318359375
Epoch [1154/1500] - Loss: 228.6034; AAPD: 604.6669921875
Epoch [1155/1500] - Loss: 190.4622; AAPD: 602.9487915039062
Epoch [1156/1500] - Loss: 264.6948; AAPD: 651.8753662109375
Epoch [1157/1500] - Loss: 751.4982; AAPD: 596.9613037109375
Epoch [1158/1500] - Loss: 512.6085; AAPD: 598.5233764648438
Epoch [1159/1500] - Loss: 348.1700; AAPD: 588.7741088867188
Epoch [1160/1500] - Loss: 220.6473; AAPD: 593.5535278320312
Epoch [1161/1500] - Loss: 403.4195; AAPD: 61

Epoch [1283/1500] - Loss: 188.8967; AAPD: 620.980224609375
Epoch [1284/1500] - Loss: 340.0849; AAPD: 582.9453125
Epoch [1285/1500] - Loss: 383.3004; AAPD: 649.5496215820312
Epoch [1286/1500] - Loss: 1307.2316; AAPD: 619.2786865234375
Epoch [1287/1500] - Loss: 103.5310; AAPD: 598.7727661132812
Epoch [1288/1500] - Loss: 459.0157; AAPD: 584.2252807617188
Epoch [1289/1500] - Loss: 711.2558; AAPD: 621.2511596679688
Epoch [1290/1500] - Loss: 165.3079; AAPD: 583.294677734375
Epoch [1291/1500] - Loss: 222.7500; AAPD: 591.0214233398438
Epoch [1292/1500] - Loss: 200.6507; AAPD: 576.6809692382812
Epoch [1293/1500] - Loss: 105.5380; AAPD: 595.7620849609375
Epoch [1294/1500] - Loss: 271.7878; AAPD: 587.7361450195312
Epoch [1295/1500] - Loss: 116.2768; AAPD: 585.5949096679688
Epoch [1296/1500] - Loss: 297.8413; AAPD: 606.9776000976562
Epoch [1297/1500] - Loss: 192.6299; AAPD: 626.9085083007812
Epoch [1298/1500] - Loss: 86.4582; AAPD: 576.519775390625
Epoch [1299/1500] - Loss: 736.9374; AAPD: 602.952

Epoch [1421/1500] - Loss: 179.5324; AAPD: 573.3130493164062
Epoch [1422/1500] - Loss: 386.0365; AAPD: 606.7363891601562
Epoch [1423/1500] - Loss: 306.9185; AAPD: 594.11181640625
Epoch [1424/1500] - Loss: 1901.0393; AAPD: 599.1560668945312
Epoch [1425/1500] - Loss: 339.3885; AAPD: 581.2457885742188
Epoch [1426/1500] - Loss: 135.1629; AAPD: 602.071044921875
Epoch [1427/1500] - Loss: 227.7584; AAPD: 582.7437744140625
Epoch [1428/1500] - Loss: 235.1577; AAPD: 589.798828125
Epoch [1429/1500] - Loss: 280.3556; AAPD: 621.96923828125
Epoch [1430/1500] - Loss: 149.2867; AAPD: 567.6771850585938
Epoch [1431/1500] - Loss: 157.9845; AAPD: 587.8163452148438
Epoch [1432/1500] - Loss: 797.3203; AAPD: 597.108154296875
Epoch [1433/1500] - Loss: 147.3220; AAPD: 589.003662109375
Epoch [1434/1500] - Loss: 218.8333; AAPD: 567.9087524414062
Epoch [1435/1500] - Loss: 236.8853; AAPD: 626.2579345703125
Epoch [1436/1500] - Loss: 582.0881; AAPD: 568.493896484375
Epoch [1437/1500] - Loss: 571.6728; AAPD: 608.02404

In [49]:
model = IAModelLeakyReLU(input_dimension=train_int.shape[1], output_dimension=train_out3.shape[1])
model_AAPD_Leaky_3 = train_model(model, 'best_model_AAPD_Leaky_3.pth', train_dataloader, dev_dataloader, target_idx=3)
torch.save(model_AAPD_Leaky_3.state_dict(), "model_AAPD_Leaky_3.pth")

Epoch [1/1500] - Loss: 3091874.7500; AAPD: 16888748.0
Epoch [2/1500] - Loss: 110278416.0000; AAPD: 11081436.0
Epoch [3/1500] - Loss: 13234814.0000; AAPD: 11121654.0
Epoch [4/1500] - Loss: 1726003.3750; AAPD: 9812867.0
Epoch [5/1500] - Loss: 1337561.2500; AAPD: 1720975.625
Epoch [6/1500] - Loss: 1056745.5000; AAPD: 2581251.25
Epoch [7/1500] - Loss: 3772633.2500; AAPD: 1979058.5
Epoch [8/1500] - Loss: 26867562.0000; AAPD: 58266196.0
Epoch [9/1500] - Loss: 125556528.0000; AAPD: 382148448.0
Epoch [10/1500] - Loss: 22902122.0000; AAPD: 68918248.0
Epoch [11/1500] - Loss: 27572466.0000; AAPD: 78876888.0
Epoch [12/1500] - Loss: 2735695.5000; AAPD: 18719984.0
Epoch [13/1500] - Loss: 10796688.0000; AAPD: 19334440.0
Epoch [14/1500] - Loss: 443549.8438; AAPD: 2153806.5
Epoch [15/1500] - Loss: 15484174.0000; AAPD: 8814589.0
Epoch [16/1500] - Loss: 3164761.7500; AAPD: 5260348.5
Epoch [17/1500] - Loss: 245111328.0000; AAPD: 15093438.0
Epoch [18/1500] - Loss: 1411103.5000; AAPD: 5840268.5
Epoch [19/15

Epoch [150/1500] - Loss: 10651233.0000; AAPD: 16178188.0
Epoch [151/1500] - Loss: 1392962.2500; AAPD: 7163761.0
Epoch [152/1500] - Loss: 1710171.3750; AAPD: 4779484.5
Epoch [153/1500] - Loss: 1597299.2500; AAPD: 5419132.5
Epoch [154/1500] - Loss: 5948763.5000; AAPD: 9874420.0
Epoch [155/1500] - Loss: 5292062.5000; AAPD: 27565198.0
Epoch [156/1500] - Loss: 10024813.0000; AAPD: 10188916.0
Epoch [157/1500] - Loss: 10326743.0000; AAPD: 10297076.0
Epoch [158/1500] - Loss: 7410146.5000; AAPD: 5749900.5
Epoch [159/1500] - Loss: 1351182.7500; AAPD: 1367404.875
Epoch [160/1500] - Loss: 6522402.0000; AAPD: 8941913.0
Epoch [161/1500] - Loss: 493155.6875; AAPD: 1145394.25
Epoch [162/1500] - Loss: 2085121.5000; AAPD: 7730001.0
Epoch [163/1500] - Loss: 36734612.0000; AAPD: 11717131.0
Epoch [164/1500] - Loss: 4471966.5000; AAPD: 28693084.0
Epoch [165/1500] - Loss: 2905303.5000; AAPD: 6818544.5
Epoch [166/1500] - Loss: 2795339.2500; AAPD: 5592190.5
Epoch [167/1500] - Loss: 336860.4375; AAPD: 670615.0


Epoch [297/1500] - Loss: 26225.7129; AAPD: 54979.64453125
Epoch [298/1500] - Loss: 28783.1973; AAPD: 201660.0625
Epoch [299/1500] - Loss: 34405.0898; AAPD: 152130.609375
Epoch [300/1500] - Loss: 1137863.6250; AAPD: 507903.125
Epoch [301/1500] - Loss: 13526.2754; AAPD: 48000.79296875
Epoch [302/1500] - Loss: 51505.5000; AAPD: 445850.4375
Epoch [303/1500] - Loss: 21203.1504; AAPD: 92047.0546875
Epoch [304/1500] - Loss: 80940.3672; AAPD: 73489.03125
Epoch [305/1500] - Loss: 259419.0312; AAPD: 307319.4375
Epoch [306/1500] - Loss: 115506.9531; AAPD: 304165.03125
Epoch [307/1500] - Loss: 6743.3691; AAPD: 31151.712890625
Epoch [308/1500] - Loss: 73830.8203; AAPD: 179216.234375
Epoch [309/1500] - Loss: 43905.6953; AAPD: 90625.265625
Epoch [310/1500] - Loss: 21714.5840; AAPD: 139542.171875
Epoch [311/1500] - Loss: 171085.0781; AAPD: 174359.015625
Epoch [312/1500] - Loss: 10704.7725; AAPD: 46829.1328125
Epoch [313/1500] - Loss: 6790.8862; AAPD: 31818.912109375
Epoch [314/1500] - Loss: 85419.6797

Epoch [442/1500] - Loss: 19000.2129; AAPD: 39761.41015625
Epoch [443/1500] - Loss: 10391.0400; AAPD: 61764.5859375
Epoch [444/1500] - Loss: 25084.9629; AAPD: 65998.5
Epoch [445/1500] - Loss: 1373040.7500; AAPD: 3619937.25
Epoch [446/1500] - Loss: 2908395.0000; AAPD: 3145069.0
Epoch [447/1500] - Loss: 2547722.7500; AAPD: 8880381.0
Epoch [448/1500] - Loss: 184627.8438; AAPD: 659961.1875
Epoch [449/1500] - Loss: 62050.5625; AAPD: 220718.375
Epoch [450/1500] - Loss: 246340.9688; AAPD: 665282.8125
Epoch [451/1500] - Loss: 660334.3750; AAPD: 110987.8828125
Epoch [452/1500] - Loss: 54646.6914; AAPD: 191178.296875
Epoch [453/1500] - Loss: 337170.8438; AAPD: 219784.703125
Epoch [454/1500] - Loss: 80770.1328; AAPD: 153070.203125
Epoch [455/1500] - Loss: 320528.8750; AAPD: 299588.78125
Epoch [456/1500] - Loss: 30518.1289; AAPD: 182679.53125
Epoch [457/1500] - Loss: 220375.2656; AAPD: 630414.0
Epoch [458/1500] - Loss: 169387.0469; AAPD: 934949.3125
Epoch [459/1500] - Loss: 132916.8906; AAPD: 51665

Epoch [587/1500] - Loss: 2605.3118; AAPD: 18549.626953125
Epoch [588/1500] - Loss: 14719.0410; AAPD: 46938.23828125
Epoch [589/1500] - Loss: 78001.0391; AAPD: 60392.60546875
Epoch [590/1500] - Loss: 5311.3374; AAPD: 10034.77734375
Epoch [591/1500] - Loss: 31350.2168; AAPD: 43656.23046875
Epoch [592/1500] - Loss: 122873.0156; AAPD: 11796.216796875
Epoch [593/1500] - Loss: 35006.2539; AAPD: 47682.7578125
Epoch [594/1500] - Loss: 9604.2793; AAPD: 18961.5625
Epoch [595/1500] - Loss: 9081.2236; AAPD: 35755.0625
Epoch [596/1500] - Loss: 8240.1797; AAPD: 20250.646484375
Epoch [597/1500] - Loss: 5430.1885; AAPD: 18628.0625
Epoch [598/1500] - Loss: 25932.3203; AAPD: 31498.08203125
Epoch [599/1500] - Loss: 6166.5010; AAPD: 11755.0712890625
Epoch [600/1500] - Loss: 6620.0654; AAPD: 11284.57421875
Epoch [601/1500] - Loss: 16450.1230; AAPD: 28363.451171875
Epoch [602/1500] - Loss: 9033.5117; AAPD: 26439.1796875
Epoch [603/1500] - Loss: 5765.0908; AAPD: 15928.26953125
Epoch [604/1500] - Loss: 16368.

Epoch [729/1500] - Loss: 3243.4490; AAPD: 6860.47314453125
Epoch [730/1500] - Loss: 15793.6602; AAPD: 18108.9296875
Epoch [731/1500] - Loss: 6285.7632; AAPD: 18406.51171875
Epoch [732/1500] - Loss: 5304.2832; AAPD: 14675.41796875
Epoch [733/1500] - Loss: 6895.8667; AAPD: 21302.4375
Epoch [734/1500] - Loss: 7900.8813; AAPD: 22741.04296875
Epoch [735/1500] - Loss: 3983.4480; AAPD: 28663.197265625
Epoch [736/1500] - Loss: 6675.2808; AAPD: 16622.513671875
Epoch [737/1500] - Loss: 48206.4531; AAPD: 11869.994140625
Epoch [738/1500] - Loss: 8848.5537; AAPD: 29285.607421875
Epoch [739/1500] - Loss: 5450.2988; AAPD: 10591.623046875
Epoch [740/1500] - Loss: 5053.3228; AAPD: 17156.40234375
Epoch [741/1500] - Loss: 15021.0693; AAPD: 11984.51171875
Epoch [742/1500] - Loss: 2115.9739; AAPD: 8644.1875
Epoch [743/1500] - Loss: 4430.4092; AAPD: 11471.927734375
Epoch [744/1500] - Loss: 16978.6445; AAPD: 15690.935546875
Epoch [745/1500] - Loss: 399.1040; AAPD: 4887.998046875
Epoch [746/1500] - Loss: 1958

Epoch [870/1500] - Loss: 1237.4774; AAPD: 3636.882080078125
Epoch [871/1500] - Loss: 5319.6509; AAPD: 3209.77783203125
Epoch [872/1500] - Loss: 472.8949; AAPD: 2994.21533203125
Epoch [873/1500] - Loss: 1500.4764; AAPD: 3142.82470703125
Epoch [874/1500] - Loss: 490.8877; AAPD: 3253.292724609375
Epoch [875/1500] - Loss: 989.4553; AAPD: 3078.857421875
Epoch [876/1500] - Loss: 866.7419; AAPD: 3826.218994140625
Epoch [877/1500] - Loss: 1098.1322; AAPD: 3194.80126953125
Epoch [878/1500] - Loss: 6462.3145; AAPD: 3234.5537109375
Epoch [879/1500] - Loss: 3318.3025; AAPD: 3691.117919921875
Epoch [880/1500] - Loss: 9136.7119; AAPD: 2992.326171875
Epoch [881/1500] - Loss: 277.2859; AAPD: 3145.531494140625
Epoch [882/1500] - Loss: 838.5623; AAPD: 3195.273193359375
Epoch [883/1500] - Loss: 1332.5912; AAPD: 3174.485595703125
Epoch [884/1500] - Loss: 1497.8717; AAPD: 3627.731689453125
Epoch [885/1500] - Loss: 743.9584; AAPD: 2866.8896484375
Epoch [886/1500] - Loss: 1615.0161; AAPD: 3518.9384765625
Epo

Epoch [1010/1500] - Loss: 513.6935; AAPD: 2669.45751953125
Epoch [1011/1500] - Loss: 891.5383; AAPD: 2675.390625
Epoch [1012/1500] - Loss: 1360.6969; AAPD: 2712.79638671875
Epoch [1013/1500] - Loss: 1511.2828; AAPD: 2732.501220703125
Epoch [1014/1500] - Loss: 1571.6047; AAPD: 2678.321533203125
Epoch [1015/1500] - Loss: 815.0245; AAPD: 2685.0693359375
Epoch [1016/1500] - Loss: 508.0723; AAPD: 2721.862060546875
Epoch [1017/1500] - Loss: 1529.0437; AAPD: 2722.04931640625
Epoch [1018/1500] - Loss: 1071.3566; AAPD: 2693.99560546875
Epoch [1019/1500] - Loss: 797.1111; AAPD: 2725.234375
Epoch [1020/1500] - Loss: 2439.6086; AAPD: 2661.216796875
Epoch [1021/1500] - Loss: 1578.3766; AAPD: 2710.119140625
Epoch [1022/1500] - Loss: 1584.4272; AAPD: 2748.105712890625
Epoch [1023/1500] - Loss: 895.3347; AAPD: 2720.667236328125
Epoch [1024/1500] - Loss: 468.7277; AAPD: 2696.8125
Epoch [1025/1500] - Loss: 3155.5369; AAPD: 2695.04638671875
Epoch [1026/1500] - Loss: 324.5012; AAPD: 2726.3173828125
Epoch 

Epoch [1149/1500] - Loss: 894.0548; AAPD: 2664.95654296875
Epoch [1150/1500] - Loss: 819.4215; AAPD: 2637.83447265625
Epoch [1151/1500] - Loss: 4840.1909; AAPD: 2708.9052734375
Epoch [1152/1500] - Loss: 692.0137; AAPD: 2645.014404296875
Epoch [1153/1500] - Loss: 16218.9277; AAPD: 2765.95263671875
Epoch [1154/1500] - Loss: 661.9154; AAPD: 2696.58056640625
Epoch [1155/1500] - Loss: 607.1200; AAPD: 2682.00146484375
Epoch [1156/1500] - Loss: 833.2855; AAPD: 2640.80908203125
Epoch [1157/1500] - Loss: 783.6676; AAPD: 2677.69921875
Epoch [1158/1500] - Loss: 1519.0867; AAPD: 2655.817138671875
Epoch [1159/1500] - Loss: 690.5750; AAPD: 2696.129638671875
Epoch [1160/1500] - Loss: 3446.0249; AAPD: 2698.8330078125
Epoch [1161/1500] - Loss: 2090.4600; AAPD: 2809.912353515625
Epoch [1162/1500] - Loss: 998.4890; AAPD: 2693.62109375
Epoch [1163/1500] - Loss: 574.9789; AAPD: 2643.927490234375
Epoch [1164/1500] - Loss: 327.5424; AAPD: 2640.346435546875
Epoch [1165/1500] - Loss: 3030.9253; AAPD: 2678.7656

Epoch [1287/1500] - Loss: 1509.1359; AAPD: 2613.971923828125
Epoch [1288/1500] - Loss: 2899.4661; AAPD: 2656.216552734375
Epoch [1289/1500] - Loss: 674.5389; AAPD: 2626.072265625
Epoch [1290/1500] - Loss: 1883.1523; AAPD: 2687.605712890625
Epoch [1291/1500] - Loss: 591.6311; AAPD: 2563.493408203125
Epoch [1292/1500] - Loss: 884.9545; AAPD: 2613.6484375
Epoch [1293/1500] - Loss: 1653.3511; AAPD: 2712.03173828125
Epoch [1294/1500] - Loss: 746.9840; AAPD: 2678.6982421875
Epoch [1295/1500] - Loss: 1175.0626; AAPD: 2621.706787109375
Epoch [1296/1500] - Loss: 967.7341; AAPD: 2591.1640625
Epoch [1297/1500] - Loss: 955.6528; AAPD: 2628.75830078125
Epoch [1298/1500] - Loss: 1134.3162; AAPD: 2763.937744140625
Epoch [1299/1500] - Loss: 452.1859; AAPD: 2584.714599609375
Epoch [1300/1500] - Loss: 537.3099; AAPD: 2707.3173828125
Epoch [1301/1500] - Loss: 697.8461; AAPD: 2593.137939453125
Epoch [1302/1500] - Loss: 1298.5228; AAPD: 2585.887939453125
Epoch [1303/1500] - Loss: 628.6326; AAPD: 2608.74707

Epoch [1425/1500] - Loss: 584.9481; AAPD: 2729.034423828125
Epoch [1426/1500] - Loss: 1431.0005; AAPD: 2668.056884765625
Epoch [1427/1500] - Loss: 1249.7849; AAPD: 2504.49609375
Epoch [1428/1500] - Loss: 1192.8895; AAPD: 2485.111572265625
Epoch [1429/1500] - Loss: 1247.0721; AAPD: 2527.537353515625
Epoch [1430/1500] - Loss: 605.8356; AAPD: 2540.361328125
Epoch [1431/1500] - Loss: 570.7875; AAPD: 2672.4794921875
Epoch [1432/1500] - Loss: 627.8546; AAPD: 2548.018310546875
Epoch [1433/1500] - Loss: 17917.9590; AAPD: 2494.850830078125
Epoch [1434/1500] - Loss: 1034.5928; AAPD: 2543.16845703125
Epoch [1435/1500] - Loss: 1152.4800; AAPD: 2500.501953125
Epoch [1436/1500] - Loss: 556.2073; AAPD: 2507.87255859375
Epoch [1437/1500] - Loss: 3806.8142; AAPD: 2553.17919921875
Epoch [1438/1500] - Loss: 2102.5459; AAPD: 2532.816650390625
Epoch [1439/1500] - Loss: 370.6909; AAPD: 2512.889404296875
Epoch [1440/1500] - Loss: 362.9878; AAPD: 2549.792236328125
Epoch [1441/1500] - Loss: 1687.4036; AAPD: 25

In [56]:
class linearencoderdecoder(nn.Module):
    def __init__(self, input_dim: int = 7, hidden_dim: int = 128, bottleneck_dim = 64, output_dim: int = 20):
        super(linearencoderdecoder, self).__init__()
        
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU()
        )
        
        self.decoder = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, bottleneck_dim),
            nn.BatchNorm1d(bottleneck_dim),
            nn.ReLU(),
            nn.Linear(bottleneck_dim, bottleneck_dim),
            nn.BatchNorm1d(bottleneck_dim),
            nn.ReLU(),
            nn.Linear(bottleneck_dim, output_dim)
        )
        
    def forward(self, x):
        # Encoding
        x = self.encoder(x)
        # Decoding
        x = self.decoder(x)
        return x

In [51]:
model = linearencoderdecoder()
best_model_ED_AAPD_1 = train_model(model, 'best_model_ED_AAPD_1.pth', train_dataloader, dev_dataloader, target_idx=1)

Epoch [1/1500] - Loss: 0.5244; AAPD: 0.7371818423271179
Epoch [2/1500] - Loss: 0.7407; AAPD: 0.754469633102417
Epoch [3/1500] - Loss: 0.5560; AAPD: 0.7459468245506287
Epoch [4/1500] - Loss: 0.5441; AAPD: 0.6127431392669678
Epoch [5/1500] - Loss: 0.5785; AAPD: 0.9211863875389099
Epoch [6/1500] - Loss: 0.4569; AAPD: 0.5518184900283813
Epoch [7/1500] - Loss: 0.5698; AAPD: 0.5703815221786499
Epoch [8/1500] - Loss: 0.6624; AAPD: 0.8394404053688049
Epoch [9/1500] - Loss: 0.6642; AAPD: 0.6961545944213867
Epoch [10/1500] - Loss: 0.6107; AAPD: 0.7185382843017578
Epoch [11/1500] - Loss: 0.5550; AAPD: 0.6115034818649292
Epoch [12/1500] - Loss: 0.5035; AAPD: 1.4535695314407349
Epoch [13/1500] - Loss: 0.4897; AAPD: 0.7812218070030212
Epoch [14/1500] - Loss: 0.5939; AAPD: 0.6524147391319275
Epoch [15/1500] - Loss: 0.6877; AAPD: 0.7552523016929626
Epoch [16/1500] - Loss: 0.5280; AAPD: 0.6020445227622986
Epoch [17/1500] - Loss: 0.5701; AAPD: 14.106233596801758
Epoch [18/1500] - Loss: 0.6615; AAPD: 1.2

Epoch [145/1500] - Loss: 0.6308; AAPD: 0.6939078569412231
Epoch [146/1500] - Loss: 0.7274; AAPD: 0.6911382079124451
Epoch [147/1500] - Loss: 0.5571; AAPD: 0.6950570344924927
Epoch [148/1500] - Loss: 0.7676; AAPD: 0.7004995942115784
Epoch [149/1500] - Loss: 0.7151; AAPD: 0.7077040076255798
Epoch [150/1500] - Loss: 0.6778; AAPD: 0.6950393319129944
Epoch [151/1500] - Loss: 0.5806; AAPD: 0.6975367665290833
Epoch [152/1500] - Loss: 0.5360; AAPD: 0.7091765403747559
Epoch [153/1500] - Loss: 0.6956; AAPD: 0.6944398283958435
Epoch [154/1500] - Loss: 0.7114; AAPD: 0.6921356320381165
Epoch [155/1500] - Loss: 0.6946; AAPD: 0.6916086673736572
Epoch [156/1500] - Loss: 0.7101; AAPD: 0.6996235251426697
Epoch [157/1500] - Loss: 0.7449; AAPD: 0.7036827206611633
Epoch [158/1500] - Loss: 0.7786; AAPD: 0.7105445861816406
Epoch [159/1500] - Loss: 0.5651; AAPD: 0.7003313899040222
Epoch [160/1500] - Loss: 0.7039; AAPD: 0.7042077779769897
Epoch [161/1500] - Loss: 0.8236; AAPD: 0.6953739523887634
Epoch [162/150

Epoch [287/1500] - Loss: 0.6960; AAPD: 0.6893819570541382
Epoch [288/1500] - Loss: 0.6689; AAPD: 0.6884683966636658
Epoch [289/1500] - Loss: 0.6709; AAPD: 0.6896113753318787
Epoch [290/1500] - Loss: 0.6970; AAPD: 0.6885413527488708
Epoch [291/1500] - Loss: 0.6770; AAPD: 0.6904717683792114
Epoch [292/1500] - Loss: 0.7217; AAPD: 0.6885119676589966
Epoch [293/1500] - Loss: 0.6854; AAPD: 0.690592348575592
Epoch [294/1500] - Loss: 0.7405; AAPD: 0.6888799667358398
Epoch [295/1500] - Loss: 0.6792; AAPD: 0.6889464259147644
Epoch [296/1500] - Loss: 0.6389; AAPD: 0.6898065209388733
Epoch [297/1500] - Loss: 0.6696; AAPD: 0.6885348558425903
Epoch [298/1500] - Loss: 0.7492; AAPD: 0.690246045589447
Epoch [299/1500] - Loss: 0.7437; AAPD: 0.6889954805374146
Epoch [300/1500] - Loss: 0.7511; AAPD: 0.6888957023620605
Epoch [301/1500] - Loss: 0.7515; AAPD: 0.6895553469657898
Epoch [302/1500] - Loss: 0.7037; AAPD: 0.6892096996307373
Epoch [303/1500] - Loss: 0.6453; AAPD: 0.6889619827270508
Epoch [304/1500]

Epoch [429/1500] - Loss: 0.6599; AAPD: 0.6887961626052856
Epoch [430/1500] - Loss: 0.6351; AAPD: 0.6884450912475586
Epoch [431/1500] - Loss: 0.6079; AAPD: 0.6885595321655273
Epoch [432/1500] - Loss: 0.5882; AAPD: 0.6895567178726196
Epoch [433/1500] - Loss: 0.5763; AAPD: 0.6887519359588623
Epoch [434/1500] - Loss: 0.6684; AAPD: 0.689047634601593
Epoch [435/1500] - Loss: 0.6169; AAPD: 0.6893998384475708
Epoch [436/1500] - Loss: 0.6890; AAPD: 0.6892547011375427
Epoch [437/1500] - Loss: 0.6811; AAPD: 0.6907510161399841
Epoch [438/1500] - Loss: 0.6331; AAPD: 0.688506007194519
Epoch [439/1500] - Loss: 0.8490; AAPD: 0.6885100603103638
Epoch [440/1500] - Loss: 0.6730; AAPD: 0.691831648349762
Epoch [441/1500] - Loss: 0.8485; AAPD: 0.6886940598487854
Epoch [442/1500] - Loss: 0.7449; AAPD: 0.689125120639801
Epoch [443/1500] - Loss: 0.5652; AAPD: 0.6901220083236694
Epoch [444/1500] - Loss: 0.6046; AAPD: 0.690819263458252
Epoch [445/1500] - Loss: 0.6576; AAPD: 0.6901406049728394
Epoch [446/1500] - 

Epoch [571/1500] - Loss: 0.6280; AAPD: 0.6887074708938599
Epoch [572/1500] - Loss: 0.7813; AAPD: 0.6887883543968201
Epoch [573/1500] - Loss: 0.7586; AAPD: 0.6884242296218872
Epoch [574/1500] - Loss: 0.6618; AAPD: 0.6884427666664124
Epoch [575/1500] - Loss: 0.7709; AAPD: 0.6885014176368713
Epoch [576/1500] - Loss: 0.6376; AAPD: 0.6884591579437256
Epoch [577/1500] - Loss: 0.7034; AAPD: 0.688429594039917
Epoch [578/1500] - Loss: 0.7805; AAPD: 0.6883867979049683
Epoch [579/1500] - Loss: 0.6680; AAPD: 0.6887555718421936
Epoch [580/1500] - Loss: 0.6755; AAPD: 0.6889394521713257
Epoch [581/1500] - Loss: 0.7577; AAPD: 0.6883717179298401
Epoch [582/1500] - Loss: 0.6987; AAPD: 0.6883301138877869
Epoch [583/1500] - Loss: 0.6666; AAPD: 0.6883941888809204
Epoch [584/1500] - Loss: 0.7290; AAPD: 0.6883359551429749
Epoch [585/1500] - Loss: 0.7189; AAPD: 0.6883706450462341
Epoch [586/1500] - Loss: 0.5527; AAPD: 0.6885235905647278
Epoch [587/1500] - Loss: 0.7666; AAPD: 0.6888545155525208
Epoch [588/1500

Epoch [713/1500] - Loss: 0.6692; AAPD: 0.6883066892623901
Epoch [714/1500] - Loss: 0.6192; AAPD: 0.6883327960968018
Epoch [715/1500] - Loss: 0.6024; AAPD: 0.6884617805480957
Epoch [716/1500] - Loss: 0.7111; AAPD: 0.6883388757705688
Epoch [717/1500] - Loss: 0.6151; AAPD: 0.688395082950592
Epoch [718/1500] - Loss: 0.6570; AAPD: 0.6884046196937561
Epoch [719/1500] - Loss: 0.7827; AAPD: 0.6884002089500427
Epoch [720/1500] - Loss: 0.7107; AAPD: 0.6883936524391174
Epoch [721/1500] - Loss: 0.7572; AAPD: 0.6883466839790344
Epoch [722/1500] - Loss: 0.7097; AAPD: 0.6883480548858643
Epoch [723/1500] - Loss: 0.6351; AAPD: 0.6883950233459473
Epoch [724/1500] - Loss: 0.7500; AAPD: 0.6883625388145447
Epoch [725/1500] - Loss: 0.7400; AAPD: 0.6883212924003601
Epoch [726/1500] - Loss: 0.6864; AAPD: 0.6883814930915833
Epoch [727/1500] - Loss: 0.6839; AAPD: 0.6885143518447876
Epoch [728/1500] - Loss: 0.6887; AAPD: 0.6883655786514282
Epoch [729/1500] - Loss: 0.6916; AAPD: 0.6883984804153442
Epoch [730/1500

Epoch [855/1500] - Loss: 0.6064; AAPD: 0.6883239150047302
Epoch [856/1500] - Loss: 0.7381; AAPD: 0.6883296966552734
Epoch [857/1500] - Loss: 0.7228; AAPD: 0.6883615851402283
Epoch [858/1500] - Loss: 0.5759; AAPD: 0.688373327255249
Epoch [859/1500] - Loss: 0.6992; AAPD: 0.6883286237716675
Epoch [860/1500] - Loss: 0.7985; AAPD: 0.6883437633514404
Epoch [861/1500] - Loss: 0.6922; AAPD: 0.6883372068405151
Epoch [862/1500] - Loss: 0.7538; AAPD: 0.6883385181427002
Epoch [863/1500] - Loss: 0.7008; AAPD: 0.6883325576782227
Epoch [864/1500] - Loss: 0.7073; AAPD: 0.6883230805397034
Epoch [865/1500] - Loss: 0.8195; AAPD: 0.688363790512085
Epoch [866/1500] - Loss: 0.6302; AAPD: 0.6884163022041321
Epoch [867/1500] - Loss: 0.7172; AAPD: 0.6883223652839661
Epoch [868/1500] - Loss: 0.6710; AAPD: 0.6883270740509033
Epoch [869/1500] - Loss: 0.8386; AAPD: 0.6883649826049805
Epoch [870/1500] - Loss: 0.6428; AAPD: 0.6883885860443115
Epoch [871/1500] - Loss: 0.5857; AAPD: 0.6883335113525391
Epoch [872/1500]

Epoch [997/1500] - Loss: 0.7272; AAPD: 0.6883214712142944
Epoch [998/1500] - Loss: 0.7579; AAPD: 0.6883602738380432
Epoch [999/1500] - Loss: 0.6625; AAPD: 0.6883294582366943
Epoch [1000/1500] - Loss: 0.6404; AAPD: 0.6883207559585571
Epoch [1001/1500] - Loss: 0.6434; AAPD: 0.6883262395858765
Epoch [1002/1500] - Loss: 0.4795; AAPD: 0.6883328557014465
Epoch [1003/1500] - Loss: 0.6850; AAPD: 0.6883262991905212
Epoch [1004/1500] - Loss: 0.7483; AAPD: 0.6883260011672974
Epoch [1005/1500] - Loss: 0.7269; AAPD: 0.6883234977722168
Epoch [1006/1500] - Loss: 0.7245; AAPD: 0.6883264780044556
Epoch [1007/1500] - Loss: 0.6857; AAPD: 0.6883270740509033
Epoch [1008/1500] - Loss: 0.6868; AAPD: 0.6883235573768616
Epoch [1009/1500] - Loss: 0.6326; AAPD: 0.6883273720741272
Epoch [1010/1500] - Loss: 0.6782; AAPD: 0.6883254051208496
Epoch [1011/1500] - Loss: 0.7878; AAPD: 0.6883251667022705
Epoch [1012/1500] - Loss: 0.6629; AAPD: 0.688326358795166
Epoch [1013/1500] - Loss: 0.7433; AAPD: 0.6883238554000854
E

Epoch [1137/1500] - Loss: 0.7203; AAPD: 0.6883281469345093
Epoch [1138/1500] - Loss: 0.7555; AAPD: 0.6883275508880615
Epoch [1139/1500] - Loss: 0.6544; AAPD: 0.6883268356323242
Epoch [1140/1500] - Loss: 0.6595; AAPD: 0.6883258819580078
Epoch [1141/1500] - Loss: 0.7161; AAPD: 0.6883264183998108
Epoch [1142/1500] - Loss: 0.6454; AAPD: 0.6883240342140198
Epoch [1143/1500] - Loss: 0.6592; AAPD: 0.6883304715156555
Epoch [1144/1500] - Loss: 0.7619; AAPD: 0.6883276700973511
Epoch [1145/1500] - Loss: 0.6631; AAPD: 0.6883267760276794
Epoch [1146/1500] - Loss: 0.6610; AAPD: 0.6883258819580078
Epoch [1147/1500] - Loss: 0.6174; AAPD: 0.6883240342140198
Epoch [1148/1500] - Loss: 0.6744; AAPD: 0.688326895236969
Epoch [1149/1500] - Loss: 0.6460; AAPD: 0.6883278489112854
Epoch [1150/1500] - Loss: 0.6591; AAPD: 0.6883331537246704
Epoch [1151/1500] - Loss: 0.7630; AAPD: 0.6883309483528137
Epoch [1152/1500] - Loss: 0.7044; AAPD: 0.6883279085159302
Epoch [1153/1500] - Loss: 0.6961; AAPD: 0.688327074050903

Epoch [1277/1500] - Loss: 0.6847; AAPD: 0.6883271932601929
Epoch [1278/1500] - Loss: 0.5368; AAPD: 0.6883277297019958
Epoch [1279/1500] - Loss: 0.6669; AAPD: 0.6883269548416138
Epoch [1280/1500] - Loss: 0.6801; AAPD: 0.6883291006088257
Epoch [1281/1500] - Loss: 0.6926; AAPD: 0.6883292198181152
Epoch [1282/1500] - Loss: 0.6909; AAPD: 0.6883264780044556
Epoch [1283/1500] - Loss: 0.6256; AAPD: 0.6883271932601929
Epoch [1284/1500] - Loss: 0.6884; AAPD: 0.6883238554000854
Epoch [1285/1500] - Loss: 0.8205; AAPD: 0.6883265972137451
Epoch [1286/1500] - Loss: 0.6888; AAPD: 0.6883231997489929
Epoch [1287/1500] - Loss: 0.6035; AAPD: 0.6883244514465332
Epoch [1288/1500] - Loss: 0.6605; AAPD: 0.688329815864563
Epoch [1289/1500] - Loss: 0.7266; AAPD: 0.6883244514465332
Epoch [1290/1500] - Loss: 0.6780; AAPD: 0.6883277297019958
Epoch [1291/1500] - Loss: 0.6638; AAPD: 0.6883337497711182
Epoch [1292/1500] - Loss: 0.7010; AAPD: 0.6883277893066406
Epoch [1293/1500] - Loss: 0.7214; AAPD: 0.688332259654998

Epoch [1417/1500] - Loss: 0.7223; AAPD: 0.6883316040039062
Epoch [1418/1500] - Loss: 0.5948; AAPD: 0.6883267760276794
Epoch [1419/1500] - Loss: 0.6998; AAPD: 0.6883300542831421
Epoch [1420/1500] - Loss: 0.6900; AAPD: 0.6883240938186646
Epoch [1421/1500] - Loss: 0.7189; AAPD: 0.6883283257484436
Epoch [1422/1500] - Loss: 0.6572; AAPD: 0.6883265972137451
Epoch [1423/1500] - Loss: 0.7621; AAPD: 0.6883282661437988
Epoch [1424/1500] - Loss: 0.7888; AAPD: 0.6883321404457092
Epoch [1425/1500] - Loss: 0.7561; AAPD: 0.6883276104927063
Epoch [1426/1500] - Loss: 0.5767; AAPD: 0.688329815864563
Epoch [1427/1500] - Loss: 0.7160; AAPD: 0.6883276700973511
Epoch [1428/1500] - Loss: 0.6820; AAPD: 0.6883280873298645
Epoch [1429/1500] - Loss: 0.7550; AAPD: 0.6883278489112854
Epoch [1430/1500] - Loss: 0.7973; AAPD: 0.6883270740509033
Epoch [1431/1500] - Loss: 0.6515; AAPD: 0.6883315443992615
Epoch [1432/1500] - Loss: 0.6058; AAPD: 0.6883271932601929
Epoch [1433/1500] - Loss: 0.8179; AAPD: 0.688327252864837

In [52]:
model = linearencoderdecoder()
best_model_ED_AAPD_2 = train_model(model, 'best_model_ED_AAPD_2.pth', train_dataloader, dev_dataloader, target_idx=2)

Epoch [1/1500] - Loss: 56.4950; AAPD: 193.76708984375
Epoch [2/1500] - Loss: 108.0351; AAPD: 423.1856689453125
Epoch [3/1500] - Loss: 115.9506; AAPD: 411.8384094238281
Epoch [4/1500] - Loss: 41.1612; AAPD: 302.8780822753906
Epoch [5/1500] - Loss: 115.4573; AAPD: 284.4253234863281
Epoch [6/1500] - Loss: 242.1904; AAPD: 715.0128784179688
Epoch [7/1500] - Loss: 1033.2378; AAPD: 1006.8487548828125
Epoch [8/1500] - Loss: 1663.7872; AAPD: 563.623046875
Epoch [9/1500] - Loss: 93.3787; AAPD: 359.1186828613281
Epoch [10/1500] - Loss: 33.6874; AAPD: 261.94061279296875
Epoch [11/1500] - Loss: 67.8199; AAPD: 247.15989685058594
Epoch [12/1500] - Loss: 70.6307; AAPD: 113.4803695678711
Epoch [13/1500] - Loss: 155.2022; AAPD: 416.650390625
Epoch [14/1500] - Loss: 33.4625; AAPD: 339.1183776855469
Epoch [15/1500] - Loss: 43.0517; AAPD: 229.66818237304688
Epoch [16/1500] - Loss: 208.8259; AAPD: 286.9233093261719
Epoch [17/1500] - Loss: 62.9738; AAPD: 163.0551300048828
Epoch [18/1500] - Loss: 77.6223; AAP

Epoch [143/1500] - Loss: 76.1809; AAPD: 372.0086669921875
Epoch [144/1500] - Loss: 66.4172; AAPD: 203.23361206054688
Epoch [145/1500] - Loss: 295.3354; AAPD: 470.35009765625
Epoch [146/1500] - Loss: 25.7728; AAPD: 89.38998413085938
Epoch [147/1500] - Loss: 541.1084; AAPD: 507.3101501464844
Epoch [148/1500] - Loss: 61.5385; AAPD: 234.6725616455078
Epoch [149/1500] - Loss: 148.3249; AAPD: 319.35040283203125
Epoch [150/1500] - Loss: 384.6680; AAPD: 608.12548828125
Epoch [151/1500] - Loss: 291.7607; AAPD: 588.2317504882812
Epoch [152/1500] - Loss: 486.8993; AAPD: 818.0592651367188
Epoch [153/1500] - Loss: 125.3367; AAPD: 299.4541320800781
Epoch [154/1500] - Loss: 169.9996; AAPD: 206.50050354003906
Epoch [155/1500] - Loss: 59.1572; AAPD: 519.9293823242188
Epoch [156/1500] - Loss: 156.7700; AAPD: 255.02383422851562
Epoch [157/1500] - Loss: 78.1434; AAPD: 170.83265686035156
Epoch [158/1500] - Loss: 521.9476; AAPD: 38215.484375
Epoch [159/1500] - Loss: 69.3054; AAPD: 325.110107421875
Epoch [16

Epoch [283/1500] - Loss: 287.7372; AAPD: 1120.6212158203125
Epoch [284/1500] - Loss: 753.7595; AAPD: 915.3795166015625
Epoch [285/1500] - Loss: 21390.7227; AAPD: 759.8955688476562
Epoch [286/1500] - Loss: 369.8864; AAPD: 495.2120666503906
Epoch [287/1500] - Loss: 518.2328; AAPD: 1168.5643310546875
Epoch [288/1500] - Loss: 143.9138; AAPD: 1746.418212890625
Epoch [289/1500] - Loss: 609.1158; AAPD: 458.8680419921875
Epoch [290/1500] - Loss: 767.2149; AAPD: 670.229248046875
Epoch [291/1500] - Loss: 253.1035; AAPD: 1966.6181640625
Epoch [292/1500] - Loss: 443.0345; AAPD: 1521.5042724609375
Epoch [293/1500] - Loss: 421.9452; AAPD: 1780.2027587890625
Epoch [294/1500] - Loss: 504.7869; AAPD: 1369.596923828125
Epoch [295/1500] - Loss: 89.0244; AAPD: 616.1768188476562
Epoch [296/1500] - Loss: 164.3833; AAPD: 969.7086791992188
Epoch [297/1500] - Loss: 228.4515; AAPD: 1133.8798828125
Epoch [298/1500] - Loss: 1310.8708; AAPD: 2256.50244140625
Epoch [299/1500] - Loss: 2104.0757; AAPD: 2892.630615234

Epoch [423/1500] - Loss: 1464.8556; AAPD: 1745.8436279296875
Epoch [424/1500] - Loss: 2799.6284; AAPD: 1913.04638671875
Epoch [425/1500] - Loss: 1011.6334; AAPD: 1836.3953857421875
Epoch [426/1500] - Loss: 3425.6172; AAPD: 1667.289306640625
Epoch [427/1500] - Loss: 733.8938; AAPD: 1744.889892578125
Epoch [428/1500] - Loss: 1365.8175; AAPD: 1730.9945068359375
Epoch [429/1500] - Loss: 7216.7134; AAPD: 4222.30712890625
Epoch [430/1500] - Loss: 554.3283; AAPD: 1154.6141357421875
Epoch [431/1500] - Loss: 1990.6698; AAPD: 1922.14111328125
Epoch [432/1500] - Loss: 568.3881; AAPD: 1324.3798828125
Epoch [433/1500] - Loss: 3208.9456; AAPD: 2397.1435546875
Epoch [434/1500] - Loss: 916.1741; AAPD: 2125.62939453125
Epoch [435/1500] - Loss: 3290.1343; AAPD: 3928.778076171875
Epoch [436/1500] - Loss: 630.6181; AAPD: 1657.346435546875
Epoch [437/1500] - Loss: 1928.2584; AAPD: 3461.97021484375
Epoch [438/1500] - Loss: 1752.4331; AAPD: 834.8628540039062
Epoch [439/1500] - Loss: 960.5723; AAPD: 2436.6147

Epoch [562/1500] - Loss: 229.6973; AAPD: 537.031494140625
Epoch [563/1500] - Loss: 3153.9807; AAPD: 279.7113037109375
Epoch [564/1500] - Loss: 97.3477; AAPD: 300.3899841308594
Epoch [565/1500] - Loss: 7573.5371; AAPD: 285.7254333496094
Epoch [566/1500] - Loss: 909.9455; AAPD: 424.3053283691406
Epoch [567/1500] - Loss: 771.1480; AAPD: 447.0787658691406
Epoch [568/1500] - Loss: 794.4108; AAPD: 257.1220397949219
Epoch [569/1500] - Loss: 375.0807; AAPD: 464.565185546875
Epoch [570/1500] - Loss: 2648.2290; AAPD: 396.5238952636719
Epoch [571/1500] - Loss: 186.7103; AAPD: 163.88917541503906
Epoch [572/1500] - Loss: 320.7560; AAPD: 378.8598327636719
Epoch [573/1500] - Loss: 547.5987; AAPD: 314.9198913574219
Epoch [574/1500] - Loss: 251.4603; AAPD: 207.68240356445312
Epoch [575/1500] - Loss: 192.1894; AAPD: 537.5333251953125
Epoch [576/1500] - Loss: 601.6856; AAPD: 311.2347106933594
Epoch [577/1500] - Loss: 451.7111; AAPD: 320.0398864746094
Epoch [578/1500] - Loss: 1325.2369; AAPD: 281.01239013

Epoch [701/1500] - Loss: 165.4682; AAPD: 230.53297424316406
Epoch [702/1500] - Loss: 1246.6505; AAPD: 289.5331115722656
Epoch [703/1500] - Loss: 354.8994; AAPD: 338.50640869140625
Epoch [704/1500] - Loss: 248.9784; AAPD: 350.94622802734375
Epoch [705/1500] - Loss: 103.7738; AAPD: 251.89710998535156
Epoch [706/1500] - Loss: 339.2708; AAPD: 497.1119689941406
Epoch [707/1500] - Loss: 148.8493; AAPD: 196.3545684814453
Epoch [708/1500] - Loss: 216.6665; AAPD: 218.963623046875
Epoch [709/1500] - Loss: 193.6660; AAPD: 349.4938659667969
Epoch [710/1500] - Loss: 154.6738; AAPD: 388.1219787597656
Epoch [711/1500] - Loss: 1874.0262; AAPD: 298.9150695800781
Epoch [712/1500] - Loss: 139.1264; AAPD: 295.97308349609375
Epoch [713/1500] - Loss: 319.2989; AAPD: 445.1780700683594
Epoch [714/1500] - Loss: 330.9996; AAPD: 496.0063171386719
Epoch [715/1500] - Loss: 307.9110; AAPD: 334.11962890625
Epoch [716/1500] - Loss: 235.3203; AAPD: 240.81149291992188
Epoch [717/1500] - Loss: 243.9299; AAPD: 229.923477

Epoch [840/1500] - Loss: 458.7156; AAPD: 21.991416931152344
Epoch [841/1500] - Loss: 135.2571; AAPD: 38.30425262451172
Epoch [842/1500] - Loss: 137.1674; AAPD: 42.01376724243164
Epoch [843/1500] - Loss: 1855.9431; AAPD: 37.46875762939453
Epoch [844/1500] - Loss: 91.3733; AAPD: 16.821809768676758
Epoch [845/1500] - Loss: 101.0370; AAPD: 50.120758056640625
Epoch [846/1500] - Loss: 179.2173; AAPD: 27.917238235473633
Epoch [847/1500] - Loss: 662.5820; AAPD: 23.09971809387207
Epoch [848/1500] - Loss: 500.5922; AAPD: 23.1987247467041
Epoch [849/1500] - Loss: 176.3371; AAPD: 38.217132568359375
Epoch [850/1500] - Loss: 135.4879; AAPD: 16.918678283691406
Epoch [851/1500] - Loss: 335.2718; AAPD: 65.50296783447266
Epoch [852/1500] - Loss: 175.2849; AAPD: 27.165388107299805
Epoch [853/1500] - Loss: 235.3359; AAPD: 34.71126937866211
Epoch [854/1500] - Loss: 475.6588; AAPD: 40.12583541870117
Epoch [855/1500] - Loss: 248.3504; AAPD: 35.059852600097656
Epoch [856/1500] - Loss: 265.1005; AAPD: 25.79495

Epoch [979/1500] - Loss: 114.5228; AAPD: 44.948143005371094
Epoch [980/1500] - Loss: 107.7226; AAPD: 23.459489822387695
Epoch [981/1500] - Loss: 102.9502; AAPD: 30.206192016601562
Epoch [982/1500] - Loss: 48.0715; AAPD: 34.045753479003906
Epoch [983/1500] - Loss: 121.9945; AAPD: 37.07307434082031
Epoch [984/1500] - Loss: 44.2331; AAPD: 36.81665802001953
Epoch [985/1500] - Loss: 404.3445; AAPD: 56.78474807739258
Epoch [986/1500] - Loss: 1226.7288; AAPD: 44.18658447265625
Epoch [987/1500] - Loss: 210.8988; AAPD: 41.48760223388672
Epoch [988/1500] - Loss: 28.1054; AAPD: 40.28565979003906
Epoch [989/1500] - Loss: 16.5683; AAPD: 30.641942977905273
Epoch [990/1500] - Loss: 45.3102; AAPD: 36.499000549316406
Epoch [991/1500] - Loss: 203.3570; AAPD: 17.501663208007812
Epoch [992/1500] - Loss: 271.1433; AAPD: 63.08163833618164
Epoch [993/1500] - Loss: 204.9504; AAPD: 33.64851760864258
Epoch [994/1500] - Loss: 181.5588; AAPD: 29.22194480895996
Epoch [995/1500] - Loss: 66.6873; AAPD: 36.6719818115

Epoch [1117/1500] - Loss: 282.5062; AAPD: 16.55572509765625
Epoch [1118/1500] - Loss: 39.9653; AAPD: 6.318164825439453
Epoch [1119/1500] - Loss: 260.3400; AAPD: 8.81581974029541
Epoch [1120/1500] - Loss: 1648.7284; AAPD: 8.282395362854004
Epoch [1121/1500] - Loss: 139.2925; AAPD: 11.661463737487793
Epoch [1122/1500] - Loss: 61.5826; AAPD: 5.801662921905518
Epoch [1123/1500] - Loss: 67.8840; AAPD: 9.584522247314453
Epoch [1124/1500] - Loss: 19.2766; AAPD: 8.224955558776855
Epoch [1125/1500] - Loss: 1990.9442; AAPD: 12.395003318786621
Epoch [1126/1500] - Loss: 22.7572; AAPD: 6.0059404373168945
Epoch [1127/1500] - Loss: 302.6430; AAPD: 7.821167469024658
Epoch [1128/1500] - Loss: 143.0055; AAPD: 6.934934616088867
Epoch [1129/1500] - Loss: 54.0684; AAPD: 5.834334373474121
Epoch [1130/1500] - Loss: 325.0864; AAPD: 7.225375652313232
Epoch [1131/1500] - Loss: 194.7183; AAPD: 6.48383092880249
Epoch [1132/1500] - Loss: 1758.2402; AAPD: 11.968754768371582
Epoch [1133/1500] - Loss: 28.8005; AAPD: 

Epoch [1254/1500] - Loss: 80.3749; AAPD: 11.366005897521973
Epoch [1255/1500] - Loss: 19.7076; AAPD: 4.302469253540039
Epoch [1256/1500] - Loss: 113.2990; AAPD: 6.650725841522217
Epoch [1257/1500] - Loss: 66.9772; AAPD: 10.368648529052734
Epoch [1258/1500] - Loss: 30.2722; AAPD: 6.524995803833008
Epoch [1259/1500] - Loss: 355.9697; AAPD: 10.08543872833252
Epoch [1260/1500] - Loss: 360.0180; AAPD: 12.511347770690918
Epoch [1261/1500] - Loss: 73.2257; AAPD: 7.49402379989624
Epoch [1262/1500] - Loss: 81.1273; AAPD: 9.158735275268555
Epoch [1263/1500] - Loss: 723.0611; AAPD: 9.375576972961426
Epoch [1264/1500] - Loss: 182.6335; AAPD: 10.842145919799805
Epoch [1265/1500] - Loss: 50.6884; AAPD: 11.944755554199219
Epoch [1266/1500] - Loss: 21.8080; AAPD: 5.636489391326904
Epoch [1267/1500] - Loss: 188.4189; AAPD: 7.034380912780762
Epoch [1268/1500] - Loss: 131.5248; AAPD: 9.13664436340332
Epoch [1269/1500] - Loss: 151.8818; AAPD: 5.384702205657959
Epoch [1270/1500] - Loss: 53.1088; AAPD: 6.02

Epoch [1392/1500] - Loss: 265.0417; AAPD: 8.361289978027344
Epoch [1393/1500] - Loss: 166.5915; AAPD: 10.134730339050293
Epoch [1394/1500] - Loss: 183.9469; AAPD: 4.48820161819458
Epoch [1395/1500] - Loss: 88.7817; AAPD: 9.594942092895508
Epoch [1396/1500] - Loss: 267.5976; AAPD: 8.335776329040527
Epoch [1397/1500] - Loss: 181.6111; AAPD: 8.474233627319336
Epoch [1398/1500] - Loss: 45.9169; AAPD: 4.4981608390808105
Epoch [1399/1500] - Loss: 83.0499; AAPD: 7.013831615447998
Epoch [1400/1500] - Loss: 74.1746; AAPD: 7.386194705963135
Epoch [1401/1500] - Loss: 34.6328; AAPD: 6.803243637084961
Epoch [1402/1500] - Loss: 119.8513; AAPD: 7.810332775115967
Epoch [1403/1500] - Loss: 11.3016; AAPD: 4.551613807678223
Epoch [1404/1500] - Loss: 46.6284; AAPD: 6.425481796264648
Epoch [1405/1500] - Loss: 250.1058; AAPD: 12.276851654052734
Epoch [1406/1500] - Loss: 150.6413; AAPD: 6.7975754737854
Epoch [1407/1500] - Loss: 142.1937; AAPD: 8.225762367248535
Epoch [1408/1500] - Loss: 292.9780; AAPD: 6.806

In [53]:
model = linearencoderdecoder()
best_model_ED_AAPD_3 = train_model(model, 'best_model_ED_AAPD_3.pth', train_dataloader, dev_dataloader, target_idx=3)

Epoch [1/1500] - Loss: 887.9481; AAPD: 4510.349609375
Epoch [2/1500] - Loss: 2568.7998; AAPD: 6504.5107421875
Epoch [3/1500] - Loss: 724.0027; AAPD: 1761.52783203125
Epoch [4/1500] - Loss: 4882.2378; AAPD: 2141.979248046875
Epoch [5/1500] - Loss: 1159.3770; AAPD: 1814.7484130859375
Epoch [6/1500] - Loss: 1130.0393; AAPD: 5892.16162109375
Epoch [7/1500] - Loss: 1263.4564; AAPD: 3653.507080078125
Epoch [8/1500] - Loss: 1320.9696; AAPD: 1579.8697509765625
Epoch [9/1500] - Loss: 7787.0464; AAPD: 3721.18115234375
Epoch [10/1500] - Loss: 1309.3113; AAPD: 2675.375
Epoch [11/1500] - Loss: 2326.7432; AAPD: 2841.653564453125
Epoch [12/1500] - Loss: 8452.3691; AAPD: 9547.044921875
Epoch [13/1500] - Loss: 909.8444; AAPD: 3186.166015625
Epoch [14/1500] - Loss: 636.1168; AAPD: 4533.03173828125
Epoch [15/1500] - Loss: 3718.2727; AAPD: 5800.26220703125
Epoch [16/1500] - Loss: 1943.8320; AAPD: 5024.52197265625
Epoch [17/1500] - Loss: 426.4530; AAPD: 2713.600341796875
Epoch [18/1500] - Loss: 4020.5952; 

Epoch [144/1500] - Loss: 1025.1722; AAPD: 4229.97021484375
Epoch [145/1500] - Loss: 18020.6191; AAPD: 5836.60302734375
Epoch [146/1500] - Loss: 1538.7655; AAPD: 5796.443359375
Epoch [147/1500] - Loss: 954.9426; AAPD: 1334.3621826171875
Epoch [148/1500] - Loss: 1388.3013; AAPD: 1515.5589599609375
Epoch [149/1500] - Loss: 2635.4954; AAPD: 4873.63232421875
Epoch [150/1500] - Loss: 2759.1313; AAPD: 4641.2744140625
Epoch [151/1500] - Loss: 2175.3853; AAPD: 3303.37890625
Epoch [152/1500] - Loss: 804.9722; AAPD: 6275.3955078125
Epoch [153/1500] - Loss: 955.8270; AAPD: 3495.569580078125
Epoch [154/1500] - Loss: 2752.4304; AAPD: 6027.455078125
Epoch [155/1500] - Loss: 524.2473; AAPD: 1814.52197265625
Epoch [156/1500] - Loss: 805.4154; AAPD: 3071.76025390625
Epoch [157/1500] - Loss: 5812.1904; AAPD: 3633.591552734375
Epoch [158/1500] - Loss: 2123.7666; AAPD: 1403.6143798828125
Epoch [159/1500] - Loss: 4413.6851; AAPD: 5815.7333984375
Epoch [160/1500] - Loss: 9629.2373; AAPD: 2283.4052734375
Epoc

Epoch [284/1500] - Loss: 176.7289; AAPD: 620.2538452148438
Epoch [285/1500] - Loss: 332.1717; AAPD: 872.8206176757812
Epoch [286/1500] - Loss: 68.5671; AAPD: 198.2926483154297
Epoch [287/1500] - Loss: 257.0003; AAPD: 481.32080078125
Epoch [288/1500] - Loss: 156.8687; AAPD: 218.6085662841797
Epoch [289/1500] - Loss: 563.4024; AAPD: 145.4454345703125
Epoch [290/1500] - Loss: 168.5943; AAPD: 247.1967010498047
Epoch [291/1500] - Loss: 363.0068; AAPD: 480.47998046875
Epoch [292/1500] - Loss: 1002.8699; AAPD: 612.144775390625
Epoch [293/1500] - Loss: 161.2512; AAPD: 272.9217529296875
Epoch [294/1500] - Loss: 206.9501; AAPD: 407.71417236328125
Epoch [295/1500] - Loss: 115.8490; AAPD: 219.12091064453125
Epoch [296/1500] - Loss: 397.2770; AAPD: 679.3057250976562
Epoch [297/1500] - Loss: 400.2560; AAPD: 396.30828857421875
Epoch [298/1500] - Loss: 509.9636; AAPD: 434.9571533203125
Epoch [299/1500] - Loss: 68.8429; AAPD: 357.3792419433594
Epoch [300/1500] - Loss: 48.0350; AAPD: 171.1581573486328
E

Epoch [424/1500] - Loss: 189.9756; AAPD: 545.4400634765625
Epoch [425/1500] - Loss: 302.3776; AAPD: 690.0884399414062
Epoch [426/1500] - Loss: 128.0109; AAPD: 499.1050720214844
Epoch [427/1500] - Loss: 248.2782; AAPD: 672.3325805664062
Epoch [428/1500] - Loss: 254.8743; AAPD: 361.2831726074219
Epoch [429/1500] - Loss: 84.0988; AAPD: 268.5278625488281
Epoch [430/1500] - Loss: 118.3096; AAPD: 161.25677490234375
Epoch [431/1500] - Loss: 8024.9829; AAPD: 322.6416320800781
Epoch [432/1500] - Loss: 140.5169; AAPD: 530.1998901367188
Epoch [433/1500] - Loss: 2023.3439; AAPD: 237.96812438964844
Epoch [434/1500] - Loss: 428.9315; AAPD: 312.9554443359375
Epoch [435/1500] - Loss: 163.9675; AAPD: 377.5975036621094
Epoch [436/1500] - Loss: 194.6868; AAPD: 234.52159118652344
Epoch [437/1500] - Loss: 245.9759; AAPD: 449.4423522949219
Epoch [438/1500] - Loss: 141.9245; AAPD: 453.125244140625
Epoch [439/1500] - Loss: 186.3951; AAPD: 480.00872802734375
Epoch [440/1500] - Loss: 143.8519; AAPD: 528.6135253

Epoch [564/1500] - Loss: 43.9930; AAPD: 122.1926040649414
Epoch [565/1500] - Loss: 103.6865; AAPD: 141.6409912109375
Epoch [566/1500] - Loss: 106.5315; AAPD: 78.31153106689453
Epoch [567/1500] - Loss: 91.6536; AAPD: 84.30459594726562
Epoch [568/1500] - Loss: 30.7278; AAPD: 103.67827606201172
Epoch [569/1500] - Loss: 35.6035; AAPD: 106.0560531616211
Epoch [570/1500] - Loss: 368.8494; AAPD: 70.29607391357422
Epoch [571/1500] - Loss: 67.9854; AAPD: 69.31096649169922
Epoch [572/1500] - Loss: 24.9153; AAPD: 77.24217224121094
Epoch [573/1500] - Loss: 259.7686; AAPD: 118.52835845947266
Epoch [574/1500] - Loss: 146.3560; AAPD: 25.82731819152832
Epoch [575/1500] - Loss: 62.6419; AAPD: 66.03861236572266
Epoch [576/1500] - Loss: 100.1827; AAPD: 93.67022705078125
Epoch [577/1500] - Loss: 39.7280; AAPD: 60.991485595703125
Epoch [578/1500] - Loss: 338.6212; AAPD: 67.7119369506836
Epoch [579/1500] - Loss: 57.1853; AAPD: 46.717777252197266
Epoch [580/1500] - Loss: 62.3774; AAPD: 99.44830322265625
Epoc

Epoch [704/1500] - Loss: 31.5257; AAPD: 67.89491271972656
Epoch [705/1500] - Loss: 29.9714; AAPD: 56.39451599121094
Epoch [706/1500] - Loss: 12.4371; AAPD: 75.16990661621094
Epoch [707/1500] - Loss: 9.3052; AAPD: 56.53770065307617
Epoch [708/1500] - Loss: 75.8492; AAPD: 212.87619018554688
Epoch [709/1500] - Loss: 29.7694; AAPD: 111.20075988769531
Epoch [710/1500] - Loss: 34.7514; AAPD: 25.051393508911133
Epoch [711/1500] - Loss: 49.6747; AAPD: 162.49217224121094
Epoch [712/1500] - Loss: 482.1569; AAPD: 105.42510223388672
Epoch [713/1500] - Loss: 31.3018; AAPD: 133.44166564941406
Epoch [714/1500] - Loss: 181.0465; AAPD: 121.51900482177734
Epoch [715/1500] - Loss: 97.6684; AAPD: 108.32563018798828
Epoch [716/1500] - Loss: 7.1970; AAPD: 30.97138786315918
Epoch [717/1500] - Loss: 36.5846; AAPD: 166.37255859375
Epoch [718/1500] - Loss: 79.3827; AAPD: 84.36915588378906
Epoch [719/1500] - Loss: 65.0303; AAPD: 137.76963806152344
Epoch [720/1500] - Loss: 96.3474; AAPD: 150.12330627441406
Epoch 

Epoch [845/1500] - Loss: 48.7841; AAPD: 21.43720245361328
Epoch [846/1500] - Loss: 8.5455; AAPD: 12.287375450134277
Epoch [847/1500] - Loss: 13.4555; AAPD: 10.649249076843262
Epoch [848/1500] - Loss: 15.3153; AAPD: 22.79052734375
Epoch [849/1500] - Loss: 7.1961; AAPD: 16.359560012817383
Epoch [850/1500] - Loss: 14.3918; AAPD: 13.216035842895508
Epoch [851/1500] - Loss: 10.2203; AAPD: 8.37718391418457
Epoch [852/1500] - Loss: 16.7741; AAPD: 19.860652923583984
Epoch [853/1500] - Loss: 25.5458; AAPD: 12.566027641296387
Epoch [854/1500] - Loss: 6.2375; AAPD: 17.969505310058594
Epoch [855/1500] - Loss: 23.9261; AAPD: 15.312239646911621
Epoch [856/1500] - Loss: 10.1169; AAPD: 11.255873680114746
Epoch [857/1500] - Loss: 28.8492; AAPD: 14.928256034851074
Epoch [858/1500] - Loss: 6.6739; AAPD: 12.243428230285645
Epoch [859/1500] - Loss: 9.5512; AAPD: 5.3980712890625
Epoch [860/1500] - Loss: 44.3468; AAPD: 15.085306167602539
Epoch [861/1500] - Loss: 17.8387; AAPD: 15.792460441589355
Epoch [862/1

Epoch [986/1500] - Loss: 16.7177; AAPD: 14.068719863891602
Epoch [987/1500] - Loss: 12.2198; AAPD: 7.324000835418701
Epoch [988/1500] - Loss: 61.3575; AAPD: 21.020505905151367
Epoch [989/1500] - Loss: 19.4292; AAPD: 12.887606620788574
Epoch [990/1500] - Loss: 32.5808; AAPD: 9.998461723327637
Epoch [991/1500] - Loss: 16848.3789; AAPD: 23.505109786987305
Epoch [992/1500] - Loss: 110.9888; AAPD: 18.35871124267578
Epoch [993/1500] - Loss: 14.6398; AAPD: 11.013812065124512
Epoch [994/1500] - Loss: 8.1622; AAPD: 10.168440818786621
Epoch [995/1500] - Loss: 17.5269; AAPD: 10.480363845825195
Epoch [996/1500] - Loss: 25.4019; AAPD: 21.447853088378906
Epoch [997/1500] - Loss: 7.0354; AAPD: 15.495758056640625
Epoch [998/1500] - Loss: 4.4342; AAPD: 15.513107299804688
Epoch [999/1500] - Loss: 6.2568; AAPD: 7.560385227203369
Epoch [1000/1500] - Loss: 26.1709; AAPD: 8.095666885375977
Epoch [1001/1500] - Loss: 11.9966; AAPD: 7.247937202453613
Epoch [1002/1500] - Loss: 8.2157; AAPD: 4.430881023406982
Ep

Epoch [1125/1500] - Loss: 4.9188; AAPD: 2.8436360359191895
Epoch [1126/1500] - Loss: 10.3762; AAPD: 2.8779609203338623
Epoch [1127/1500] - Loss: 9.7259; AAPD: 2.350806713104248
Epoch [1128/1500] - Loss: 8.7459; AAPD: 1.8809850215911865
Epoch [1129/1500] - Loss: 9.1191; AAPD: 2.634443759918213
Epoch [1130/1500] - Loss: 23.2273; AAPD: 2.10964298248291
Epoch [1131/1500] - Loss: 11.3073; AAPD: 2.5978729724884033
Epoch [1132/1500] - Loss: 44.2054; AAPD: 3.2071855068206787
Epoch [1133/1500] - Loss: 7.5322; AAPD: 1.7159417867660522
Epoch [1134/1500] - Loss: 2.6159; AAPD: 2.7540760040283203
Epoch [1135/1500] - Loss: 5.7256; AAPD: 2.708177089691162
Epoch [1136/1500] - Loss: 14.2091; AAPD: 1.7040060758590698
Epoch [1137/1500] - Loss: 5.4354; AAPD: 3.872962713241577
Epoch [1138/1500] - Loss: 23.7048; AAPD: 1.982733964920044
Epoch [1139/1500] - Loss: 4.7354; AAPD: 2.9345550537109375
Epoch [1140/1500] - Loss: 3.3455; AAPD: 2.7592368125915527
Epoch [1141/1500] - Loss: 17.1268; AAPD: 2.29588079452514

Epoch [1264/1500] - Loss: 4.5328; AAPD: 2.5518290996551514
Epoch [1265/1500] - Loss: 28.0503; AAPD: 2.26066255569458
Epoch [1266/1500] - Loss: 3.1235; AAPD: 1.4656314849853516
Epoch [1267/1500] - Loss: 10.2002; AAPD: 1.8971422910690308
Epoch [1268/1500] - Loss: 6.3340; AAPD: 1.329716444015503
Epoch [1269/1500] - Loss: 3.6206; AAPD: 1.5115569829940796
Epoch [1270/1500] - Loss: 48.8413; AAPD: 1.5621380805969238
Epoch [1271/1500] - Loss: 6.3192; AAPD: 2.19446063041687
Epoch [1272/1500] - Loss: 12.6147; AAPD: 1.8039350509643555
Epoch [1273/1500] - Loss: 7.2974; AAPD: 2.8077123165130615
Epoch [1274/1500] - Loss: 3.4427; AAPD: 2.8974480628967285
Epoch [1275/1500] - Loss: 10.4415; AAPD: 2.131162643432617
Epoch [1276/1500] - Loss: 2.8711; AAPD: 1.5933948755264282
Epoch [1277/1500] - Loss: 12.0367; AAPD: 2.181609630584717
Epoch [1278/1500] - Loss: 22.0234; AAPD: 1.899669885635376
Epoch [1279/1500] - Loss: 28.9247; AAPD: 4.503277778625488
Epoch [1280/1500] - Loss: 5.0924; AAPD: 1.466486096382141

Epoch [1403/1500] - Loss: 2.2366; AAPD: 2.923966407775879
Epoch [1404/1500] - Loss: 3.1142; AAPD: 2.5597970485687256
Epoch [1405/1500] - Loss: 11.6199; AAPD: 2.732459306716919
Epoch [1406/1500] - Loss: 7.1556; AAPD: 2.925442934036255
Epoch [1407/1500] - Loss: 9.4141; AAPD: 3.70184588432312
Epoch [1408/1500] - Loss: 10.9820; AAPD: 2.637958288192749
Epoch [1409/1500] - Loss: 12.3259; AAPD: 3.3747901916503906
Epoch [1410/1500] - Loss: 3.4012; AAPD: 1.7920595407485962
Epoch [1411/1500] - Loss: 1.8858; AAPD: 2.696664810180664
Epoch [1412/1500] - Loss: 38.3958; AAPD: 2.66229510307312
Epoch [1413/1500] - Loss: 18.8818; AAPD: 3.510925769805908
Epoch [1414/1500] - Loss: 5.4910; AAPD: 2.275766372680664
Epoch [1415/1500] - Loss: 25.5084; AAPD: 1.9098596572875977
Epoch [1416/1500] - Loss: 12.9372; AAPD: 2.712958812713623
Epoch [1417/1500] - Loss: 4.4895; AAPD: 2.657613515853882
Epoch [1418/1500] - Loss: 5.3323; AAPD: 2.5356760025024414
Epoch [1419/1500] - Loss: 4.7542; AAPD: 2.9540274143218994
Epo

In [55]:
def train_model(model, model_name, train_loader, dev_loader, target_idx=1, num_epochs=1500, learning_rate=0.01, weight_decay=1e-4):
    #criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[250, 500, 750, 1000], gamma=0.1)
    
    best_model = None
    best_aapd = None

    for epoch in range(num_epochs):
        model.train()
        for batch in train_loader:
            input_data = batch['input_data']
            target1 = batch['out1']
            target2 = batch['out2']
            target3 = batch['out3']
            
            optimizer.zero_grad()
            outputs = model(input_data)
            if target_idx == 1:
                train_loss = ASPD(outputs, target1)
            elif target_idx == 2:
                train_loss = ASPD(outputs, target2)
            elif target_idx == 3:
                train_loss = ASPD(outputs, target3)
            else:
                print("Invalid target_idx")
            
            train_loss.backward()
            optimizer.step()
        scheduler.step()
        
        model.eval()
        
        total_aapd = 0
        num_samples = 0
        
        with torch.no_grad():
            for batch in dev_loader:
                dev_int = batch['input_data']
                dev_out1 = batch['out1']
                dev_out2 = batch['out2']
                dev_out3 = batch['out3']
                
                predictions = model(dev_int)
                
                if target_idx == 1:
                    batch_aapd = AAPD(predictions, dev_out1)
                elif target_idx == 2:
                    batch_aapd = AAPD(predictions, dev_out2)
                elif target_idx == 3:
                    batch_aapd = AAPD(predictions, dev_out3)
                else:
                    print("Invalid target_idx")
                    
                total_aapd += batch_aapd * len(dev_int)
                num_samples += len(dev_int)
            
            aapd = total_aapd / num_samples
            
            if best_aapd == None or aapd < best_aapd:
                best_aapd = aapd
                best_model = model
                torch.save(model.state_dict(), model_name)
        
        print(f'Epoch [{epoch + 1}/{num_epochs}] - Loss: {train_loss.item():.4f}; AAPD: {aapd}')

    return model

In [57]:
model = linearencoderdecoder()
best_model_ED_ASPD_1 = train_model(model, 'best_model_ED_ASPD_1.pth', train_dataloader, dev_dataloader, target_idx=1)

Epoch [1/1500] - Loss: 596.2654; AAPD: 0.5386804938316345
Epoch [2/1500] - Loss: 500.1610; AAPD: 1.256445050239563
Epoch [3/1500] - Loss: 433.7415; AAPD: 0.36869728565216064
Epoch [4/1500] - Loss: 425.8441; AAPD: 0.3695125877857208
Epoch [5/1500] - Loss: 280.5338; AAPD: 0.5051295161247253
Epoch [6/1500] - Loss: 199.8747; AAPD: 0.33819228410720825
Epoch [7/1500] - Loss: 1339.9106; AAPD: 0.7407754063606262
Epoch [8/1500] - Loss: 146.5899; AAPD: 0.2527046799659729
Epoch [9/1500] - Loss: 65.6657; AAPD: 0.44408535957336426
Epoch [10/1500] - Loss: 832.4576; AAPD: 0.3807946741580963
Epoch [11/1500] - Loss: 391.0229; AAPD: 0.19039994478225708
Epoch [12/1500] - Loss: 587.4075; AAPD: 0.6185450553894043
Epoch [13/1500] - Loss: 1066.9235; AAPD: 0.6378443241119385
Epoch [14/1500] - Loss: 168.7259; AAPD: 0.2522839307785034
Epoch [15/1500] - Loss: 161.6247; AAPD: 0.24701310694217682
Epoch [16/1500] - Loss: 227.6484; AAPD: 0.28783971071243286
Epoch [17/1500] - Loss: 466.6808; AAPD: 0.24821436405181885

Epoch [138/1500] - Loss: 706.7862; AAPD: 0.197480246424675
Epoch [139/1500] - Loss: 260.8958; AAPD: 0.20598401129245758
Epoch [140/1500] - Loss: 2820.0815; AAPD: 0.239454984664917
Epoch [141/1500] - Loss: 154.0777; AAPD: 0.18719175457954407
Epoch [142/1500] - Loss: 243.7135; AAPD: 0.22575139999389648
Epoch [143/1500] - Loss: 233.2769; AAPD: 0.20213183760643005
Epoch [144/1500] - Loss: 154.5949; AAPD: 0.3846491575241089
Epoch [145/1500] - Loss: 1040.6475; AAPD: 0.2476469874382019
Epoch [146/1500] - Loss: 119.4873; AAPD: 0.22124435007572174
Epoch [147/1500] - Loss: 72.6346; AAPD: 0.28386053442955017
Epoch [148/1500] - Loss: 411.1798; AAPD: 0.21147498488426208
Epoch [149/1500] - Loss: 807.3282; AAPD: 0.20278480648994446
Epoch [150/1500] - Loss: 435.3360; AAPD: 0.19905689358711243
Epoch [151/1500] - Loss: 511.8851; AAPD: 0.2078048586845398
Epoch [152/1500] - Loss: 336.5947; AAPD: 0.2591659724712372
Epoch [153/1500] - Loss: 819.0004; AAPD: 0.16923962533473969
Epoch [154/1500] - Loss: 881.85

Epoch [273/1500] - Loss: 210.3541; AAPD: 0.143296480178833
Epoch [274/1500] - Loss: 3053.8494; AAPD: 0.19273772835731506
Epoch [275/1500] - Loss: 203.6298; AAPD: 0.1565312147140503
Epoch [276/1500] - Loss: 615.4091; AAPD: 0.13120245933532715
Epoch [277/1500] - Loss: 328.2173; AAPD: 0.20386841893196106
Epoch [278/1500] - Loss: 151.4833; AAPD: 0.1652039885520935
Epoch [279/1500] - Loss: 641.3461; AAPD: 0.160677969455719
Epoch [280/1500] - Loss: 166.0518; AAPD: 0.16090819239616394
Epoch [281/1500] - Loss: 930.7016; AAPD: 0.18837539851665497
Epoch [282/1500] - Loss: 158.6909; AAPD: 0.1680244356393814
Epoch [283/1500] - Loss: 1326.4777; AAPD: 0.1960832178592682
Epoch [284/1500] - Loss: 207.3920; AAPD: 0.1494183987379074
Epoch [285/1500] - Loss: 167.4330; AAPD: 0.1740964949131012
Epoch [286/1500] - Loss: 264.2109; AAPD: 0.15730425715446472
Epoch [287/1500] - Loss: 6795.5581; AAPD: 0.21047893166542053
Epoch [288/1500] - Loss: 454.3915; AAPD: 0.15985606610774994
Epoch [289/1500] - Loss: 256.75

Epoch [408/1500] - Loss: 7232.9526; AAPD: 0.20427167415618896
Epoch [409/1500] - Loss: 143.1830; AAPD: 0.15985138714313507
Epoch [410/1500] - Loss: 417.3644; AAPD: 0.13421213626861572
Epoch [411/1500] - Loss: 3153.4380; AAPD: 0.1900784820318222
Epoch [412/1500] - Loss: 668.4622; AAPD: 0.16875962913036346
Epoch [413/1500] - Loss: 480.9213; AAPD: 0.1583453267812729
Epoch [414/1500] - Loss: 363.2944; AAPD: 0.15567699074745178
Epoch [415/1500] - Loss: 255.1564; AAPD: 0.15439586341381073
Epoch [416/1500] - Loss: 239.5212; AAPD: 0.14884336292743683
Epoch [417/1500] - Loss: 1115.1008; AAPD: 0.1476956009864807
Epoch [418/1500] - Loss: 385.2097; AAPD: 0.14048810303211212
Epoch [419/1500] - Loss: 2452.2673; AAPD: 0.18060725927352905
Epoch [420/1500] - Loss: 222.1609; AAPD: 0.1450241357088089
Epoch [421/1500] - Loss: 482.9311; AAPD: 0.15000654757022858
Epoch [422/1500] - Loss: 253.2693; AAPD: 0.1646515130996704
Epoch [423/1500] - Loss: 722.7308; AAPD: 0.1392725557088852
Epoch [424/1500] - Loss: 6

Epoch [543/1500] - Loss: 399.1906; AAPD: 0.1497384011745453
Epoch [544/1500] - Loss: 353.6100; AAPD: 0.1380225569009781
Epoch [545/1500] - Loss: 779.9738; AAPD: 0.1522640436887741
Epoch [546/1500] - Loss: 529.6528; AAPD: 0.13342346251010895
Epoch [547/1500] - Loss: 98.1356; AAPD: 0.14970315992832184
Epoch [548/1500] - Loss: 544.1036; AAPD: 0.15359555184841156
Epoch [549/1500] - Loss: 405.4396; AAPD: 0.14498969912528992
Epoch [550/1500] - Loss: 123.9302; AAPD: 0.1530420035123825
Epoch [551/1500] - Loss: 6551.6514; AAPD: 0.2263038456439972
Epoch [552/1500] - Loss: 2729.4934; AAPD: 0.15810999274253845
Epoch [553/1500] - Loss: 1597.3627; AAPD: 0.12407761812210083
Epoch [554/1500] - Loss: 113.5239; AAPD: 0.15970104932785034
Epoch [555/1500] - Loss: 1110.3727; AAPD: 0.1871928572654724
Epoch [556/1500] - Loss: 2083.0159; AAPD: 0.18581317365169525
Epoch [557/1500] - Loss: 2624.7922; AAPD: 0.1462002545595169
Epoch [558/1500] - Loss: 1237.3895; AAPD: 0.14556017518043518
Epoch [559/1500] - Loss: 

Epoch [678/1500] - Loss: 2338.8916; AAPD: 0.12971612811088562
Epoch [679/1500] - Loss: 706.5914; AAPD: 0.1758679747581482
Epoch [680/1500] - Loss: 274.8483; AAPD: 0.13773901760578156
Epoch [681/1500] - Loss: 608.6131; AAPD: 0.13064193725585938
Epoch [682/1500] - Loss: 1299.2874; AAPD: 0.1581103652715683
Epoch [683/1500] - Loss: 143.8040; AAPD: 0.18151730298995972
Epoch [684/1500] - Loss: 579.3058; AAPD: 0.143976092338562
Epoch [685/1500] - Loss: 386.9555; AAPD: 0.14356912672519684
Epoch [686/1500] - Loss: 49.4578; AAPD: 0.14523115754127502
Epoch [687/1500] - Loss: 215.2835; AAPD: 0.14967846870422363
Epoch [688/1500] - Loss: 143.6421; AAPD: 0.16018205881118774
Epoch [689/1500] - Loss: 302.3036; AAPD: 0.14193831384181976
Epoch [690/1500] - Loss: 388.4854; AAPD: 0.15666091442108154
Epoch [691/1500] - Loss: 536.8815; AAPD: 0.14106279611587524
Epoch [692/1500] - Loss: 2608.6606; AAPD: 0.17732883989810944
Epoch [693/1500] - Loss: 193.6406; AAPD: 0.14321692287921906
Epoch [694/1500] - Loss: 2

Epoch [813/1500] - Loss: 77.2784; AAPD: 0.1441551297903061
Epoch [814/1500] - Loss: 1782.9563; AAPD: 0.14033745229244232
Epoch [815/1500] - Loss: 591.7233; AAPD: 0.16126270592212677
Epoch [816/1500] - Loss: 68.7224; AAPD: 0.13347582519054413
Epoch [817/1500] - Loss: 111.8317; AAPD: 0.1542060822248459
Epoch [818/1500] - Loss: 1082.1105; AAPD: 0.1563112735748291
Epoch [819/1500] - Loss: 796.8956; AAPD: 0.13819962739944458
Epoch [820/1500] - Loss: 1417.9371; AAPD: 0.14728045463562012
Epoch [821/1500] - Loss: 710.3545; AAPD: 0.14098019897937775
Epoch [822/1500] - Loss: 203.2600; AAPD: 0.13659656047821045
Epoch [823/1500] - Loss: 248.3566; AAPD: 0.14342236518859863
Epoch [824/1500] - Loss: 189.9637; AAPD: 0.11834574490785599
Epoch [825/1500] - Loss: 147.1043; AAPD: 0.14140473306179047
Epoch [826/1500] - Loss: 580.4343; AAPD: 0.11930301785469055
Epoch [827/1500] - Loss: 98.4511; AAPD: 0.14174766838550568
Epoch [828/1500] - Loss: 1187.8915; AAPD: 0.1222296878695488
Epoch [829/1500] - Loss: 26

Epoch [948/1500] - Loss: 238.2526; AAPD: 0.13751232624053955
Epoch [949/1500] - Loss: 204.2000; AAPD: 0.13246653974056244
Epoch [950/1500] - Loss: 2511.2217; AAPD: 0.13198506832122803
Epoch [951/1500] - Loss: 354.0511; AAPD: 0.12911662459373474
Epoch [952/1500] - Loss: 268.4155; AAPD: 0.153200164437294
Epoch [953/1500] - Loss: 1313.9872; AAPD: 0.1637871414422989
Epoch [954/1500] - Loss: 343.2464; AAPD: 0.15308216214179993
Epoch [955/1500] - Loss: 1898.7372; AAPD: 0.1549367606639862
Epoch [956/1500] - Loss: 1415.3601; AAPD: 0.17223164439201355
Epoch [957/1500] - Loss: 440.7243; AAPD: 0.179634690284729
Epoch [958/1500] - Loss: 302.1120; AAPD: 0.16192865371704102
Epoch [959/1500] - Loss: 1369.0880; AAPD: 0.11950927972793579
Epoch [960/1500] - Loss: 80.8565; AAPD: 0.16076135635375977
Epoch [961/1500] - Loss: 788.4895; AAPD: 0.1444198191165924
Epoch [962/1500] - Loss: 6999.1074; AAPD: 0.17116829752922058
Epoch [963/1500] - Loss: 398.0781; AAPD: 0.15460258722305298
Epoch [964/1500] - Loss: 5

Epoch [1082/1500] - Loss: 395.6467; AAPD: 0.16605530679225922
Epoch [1083/1500] - Loss: 1631.9010; AAPD: 0.1313019096851349
Epoch [1084/1500] - Loss: 195.6326; AAPD: 0.16126081347465515
Epoch [1085/1500] - Loss: 1409.4291; AAPD: 0.1639094203710556
Epoch [1086/1500] - Loss: 277.2882; AAPD: 0.14489296078681946
Epoch [1087/1500] - Loss: 94.8590; AAPD: 0.1606762409210205
Epoch [1088/1500] - Loss: 766.0837; AAPD: 0.12725739181041718
Epoch [1089/1500] - Loss: 438.9628; AAPD: 0.166036918759346
Epoch [1090/1500] - Loss: 215.8743; AAPD: 0.1500471979379654
Epoch [1091/1500] - Loss: 437.5298; AAPD: 0.15850816667079926
Epoch [1092/1500] - Loss: 147.8030; AAPD: 0.1418704241514206
Epoch [1093/1500] - Loss: 5449.2471; AAPD: 0.1568722277879715
Epoch [1094/1500] - Loss: 29019.4355; AAPD: 0.17073488235473633
Epoch [1095/1500] - Loss: 884.7322; AAPD: 0.15315480530261993
Epoch [1096/1500] - Loss: 151.8455; AAPD: 0.15961547195911407
Epoch [1097/1500] - Loss: 333.3513; AAPD: 0.14043352007865906
Epoch [1098/

Epoch [1215/1500] - Loss: 2861.8704; AAPD: 0.14600269496440887
Epoch [1216/1500] - Loss: 236.2832; AAPD: 0.1462155133485794
Epoch [1217/1500] - Loss: 1000.4637; AAPD: 0.14425356686115265
Epoch [1218/1500] - Loss: 809.9077; AAPD: 0.14155183732509613
Epoch [1219/1500] - Loss: 583.0250; AAPD: 0.14792956411838531
Epoch [1220/1500] - Loss: 296.5510; AAPD: 0.15035273134708405
Epoch [1221/1500] - Loss: 878.4714; AAPD: 0.15639621019363403
Epoch [1222/1500] - Loss: 412.0507; AAPD: 0.1721811145544052
Epoch [1223/1500] - Loss: 170.8475; AAPD: 0.13866649568080902
Epoch [1224/1500] - Loss: 330.5550; AAPD: 0.13356173038482666
Epoch [1225/1500] - Loss: 680.9393; AAPD: 0.13175398111343384
Epoch [1226/1500] - Loss: 290.6658; AAPD: 0.15355980396270752
Epoch [1227/1500] - Loss: 1187.6517; AAPD: 0.153285950422287
Epoch [1228/1500] - Loss: 64.2395; AAPD: 0.14394211769104004
Epoch [1229/1500] - Loss: 209.5438; AAPD: 0.14206798374652863
Epoch [1230/1500] - Loss: 288.2553; AAPD: 0.15379869937896729
Epoch [123

Epoch [1348/1500] - Loss: 903.4107; AAPD: 0.13495159149169922
Epoch [1349/1500] - Loss: 995.7773; AAPD: 0.1177704930305481
Epoch [1350/1500] - Loss: 1659.6948; AAPD: 0.13822925090789795
Epoch [1351/1500] - Loss: 1524.7938; AAPD: 0.18724371492862701
Epoch [1352/1500] - Loss: 319.9734; AAPD: 0.14795994758605957
Epoch [1353/1500] - Loss: 1154.9849; AAPD: 0.12264589220285416
Epoch [1354/1500] - Loss: 280.0825; AAPD: 0.13212774693965912
Epoch [1355/1500] - Loss: 751.8549; AAPD: 0.16585540771484375
Epoch [1356/1500] - Loss: 893.9387; AAPD: 0.13177461922168732
Epoch [1357/1500] - Loss: 312.9217; AAPD: 0.1497543901205063
Epoch [1358/1500] - Loss: 784.7699; AAPD: 0.1625233143568039
Epoch [1359/1500] - Loss: 461.0396; AAPD: 0.13698288798332214
Epoch [1360/1500] - Loss: 120.5821; AAPD: 0.14868147671222687
Epoch [1361/1500] - Loss: 2084.2893; AAPD: 0.17635120451450348
Epoch [1362/1500] - Loss: 551.4675; AAPD: 0.1469019204378128
Epoch [1363/1500] - Loss: 2500.0847; AAPD: 0.10837933421134949
Epoch [

Epoch [1481/1500] - Loss: 136.2370; AAPD: 0.14263546466827393
Epoch [1482/1500] - Loss: 706.0154; AAPD: 0.15163758397102356
Epoch [1483/1500] - Loss: 307.3669; AAPD: 0.14973589777946472
Epoch [1484/1500] - Loss: 1013.3204; AAPD: 0.12439648807048798
Epoch [1485/1500] - Loss: 97.0939; AAPD: 0.14950816333293915
Epoch [1486/1500] - Loss: 275.5343; AAPD: 0.13041342794895172
Epoch [1487/1500] - Loss: 143.9417; AAPD: 0.13331536948680878
Epoch [1488/1500] - Loss: 203.9851; AAPD: 0.12960346043109894
Epoch [1489/1500] - Loss: 2392.6160; AAPD: 0.15914344787597656
Epoch [1490/1500] - Loss: 2117.2605; AAPD: 0.16198422014713287
Epoch [1491/1500] - Loss: 483.3325; AAPD: 0.15781958401203156
Epoch [1492/1500] - Loss: 82.9316; AAPD: 0.1340673714876175
Epoch [1493/1500] - Loss: 2811.1487; AAPD: 0.13803082704544067
Epoch [1494/1500] - Loss: 305.7365; AAPD: 0.16804732382297516
Epoch [1495/1500] - Loss: 398.9767; AAPD: 0.13439379632472992
Epoch [1496/1500] - Loss: 1618.6080; AAPD: 0.16566164791584015
Epoch 

In [58]:
model = linearencoderdecoder()
best_model_ED_ASPD_2 = train_model(model, 'best_model_ED_ASPD_2.pth', train_dataloader, dev_dataloader, target_idx=2)

Epoch [1/1500] - Loss: 0.4126; AAPD: 6.367913722991943
Epoch [2/1500] - Loss: 0.0536; AAPD: 3.350201368331909
Epoch [3/1500] - Loss: 0.0438; AAPD: 1.3615750074386597
Epoch [4/1500] - Loss: 0.0314; AAPD: 1.106931447982788
Epoch [5/1500] - Loss: 0.0390; AAPD: 1.4187865257263184
Epoch [6/1500] - Loss: 0.0259; AAPD: 1.3315900564193726
Epoch [7/1500] - Loss: 0.0645; AAPD: 5.8044633865356445
Epoch [8/1500] - Loss: 0.0306; AAPD: 1.172752857208252
Epoch [9/1500] - Loss: 0.0314; AAPD: 1.265318751335144
Epoch [10/1500] - Loss: 0.1515; AAPD: 3.5374603271484375
Epoch [11/1500] - Loss: 0.8856; AAPD: 23.152063369750977
Epoch [12/1500] - Loss: 0.2954; AAPD: 3.8170783519744873
Epoch [13/1500] - Loss: 0.0289; AAPD: 1.0582002401351929
Epoch [14/1500] - Loss: 0.0411; AAPD: 1.3602888584136963
Epoch [15/1500] - Loss: 0.0543; AAPD: 2.6397244930267334
Epoch [16/1500] - Loss: 0.0887; AAPD: 2.313335657119751
Epoch [17/1500] - Loss: 0.0285; AAPD: 1.2454005479812622
Epoch [18/1500] - Loss: 0.0971; AAPD: 2.272480

Epoch [145/1500] - Loss: 81.1744; AAPD: 54.840553283691406
Epoch [146/1500] - Loss: 16.8000; AAPD: 62.70425796508789
Epoch [147/1500] - Loss: 14.2172; AAPD: 55.888145446777344
Epoch [148/1500] - Loss: 12.6559; AAPD: 55.563751220703125
Epoch [149/1500] - Loss: 24.2664; AAPD: 49.576805114746094
Epoch [150/1500] - Loss: 5.5481; AAPD: 20.357330322265625
Epoch [151/1500] - Loss: 324.4672; AAPD: 55.70211410522461
Epoch [152/1500] - Loss: 56.7487; AAPD: 39.956703186035156
Epoch [153/1500] - Loss: 30.8277; AAPD: 59.8074836730957
Epoch [154/1500] - Loss: 144.8779; AAPD: 76.71735382080078
Epoch [155/1500] - Loss: 121.6942; AAPD: 63.759666442871094
Epoch [156/1500] - Loss: 6.1078; AAPD: 43.21078109741211
Epoch [157/1500] - Loss: 11.4122; AAPD: 55.64858627319336
Epoch [158/1500] - Loss: 2.0839; AAPD: 8.915485382080078
Epoch [159/1500] - Loss: 9.4420; AAPD: 32.16053009033203
Epoch [160/1500] - Loss: 23.4599; AAPD: 60.907527923583984
Epoch [161/1500] - Loss: 4.5537; AAPD: 25.216148376464844
Epoch [1

Epoch [285/1500] - Loss: 46.4808; AAPD: 36.944705963134766
Epoch [286/1500] - Loss: 118.0944; AAPD: 63.765907287597656
Epoch [287/1500] - Loss: 135.8870; AAPD: 46.61631774902344
Epoch [288/1500] - Loss: 14.3241; AAPD: 74.16746520996094
Epoch [289/1500] - Loss: 16.0552; AAPD: 42.14085006713867
Epoch [290/1500] - Loss: 60.9237; AAPD: 36.41907501220703
Epoch [291/1500] - Loss: 27.3488; AAPD: 92.62619018554688
Epoch [292/1500] - Loss: 57.0354; AAPD: 59.529151916503906
Epoch [293/1500] - Loss: 6.6219; AAPD: 47.68790054321289
Epoch [294/1500] - Loss: 37.9271; AAPD: 52.30434799194336
Epoch [295/1500] - Loss: 42.7211; AAPD: 68.93289184570312
Epoch [296/1500] - Loss: 18.5160; AAPD: 49.91401672363281
Epoch [297/1500] - Loss: 71.5338; AAPD: 29.77106475830078
Epoch [298/1500] - Loss: 25.5436; AAPD: 24.910139083862305
Epoch [299/1500] - Loss: 133.9310; AAPD: 22.457429885864258
Epoch [300/1500] - Loss: 3.9642; AAPD: 18.70842170715332
Epoch [301/1500] - Loss: 9.7822; AAPD: 57.84178924560547
Epoch [30

Epoch [425/1500] - Loss: 241.9458; AAPD: 86.50463104248047
Epoch [426/1500] - Loss: 144.9837; AAPD: 88.29631805419922
Epoch [427/1500] - Loss: 15.8825; AAPD: 120.8396987915039
Epoch [428/1500] - Loss: 241.3332; AAPD: 67.63121032714844
Epoch [429/1500] - Loss: 19.3270; AAPD: 97.55362701416016
Epoch [430/1500] - Loss: 262.6596; AAPD: 119.95838928222656
Epoch [431/1500] - Loss: 60.6301; AAPD: 81.44357299804688
Epoch [432/1500] - Loss: 58.1121; AAPD: 82.0243911743164
Epoch [433/1500] - Loss: 5148.1753; AAPD: 67.54527282714844
Epoch [434/1500] - Loss: 59.9542; AAPD: 167.82843017578125
Epoch [435/1500] - Loss: 46.8002; AAPD: 89.95660400390625
Epoch [436/1500] - Loss: 316.5043; AAPD: 88.48155212402344
Epoch [437/1500] - Loss: 1027.3416; AAPD: 81.98114013671875
Epoch [438/1500] - Loss: 142.8594; AAPD: 80.66508483886719
Epoch [439/1500] - Loss: 76.5884; AAPD: 80.72522735595703
Epoch [440/1500] - Loss: 129.9474; AAPD: 82.05108642578125
Epoch [441/1500] - Loss: 33.0888; AAPD: 70.239990234375
Epoc

Epoch [565/1500] - Loss: 12.8229; AAPD: 14.571561813354492
Epoch [566/1500] - Loss: 4.9087; AAPD: 7.709666728973389
Epoch [567/1500] - Loss: 36.9209; AAPD: 10.738747596740723
Epoch [568/1500] - Loss: 11.0389; AAPD: 10.43659496307373
Epoch [569/1500] - Loss: 24.7101; AAPD: 7.609890460968018
Epoch [570/1500] - Loss: 29.8419; AAPD: 13.81041145324707
Epoch [571/1500] - Loss: 45.9317; AAPD: 16.71315574645996
Epoch [572/1500] - Loss: 45.8891; AAPD: 10.589396476745605
Epoch [573/1500] - Loss: 87.0281; AAPD: 6.737598419189453
Epoch [574/1500] - Loss: 55.2213; AAPD: 4.1477952003479
Epoch [575/1500] - Loss: 85.1029; AAPD: 9.54167366027832
Epoch [576/1500] - Loss: 1217.9607; AAPD: 12.566997528076172
Epoch [577/1500] - Loss: 23.3673; AAPD: 7.481442928314209
Epoch [578/1500] - Loss: 101.1883; AAPD: 15.487713813781738
Epoch [579/1500] - Loss: 508.6459; AAPD: 14.208734512329102
Epoch [580/1500] - Loss: 11.2551; AAPD: 6.702250003814697
Epoch [581/1500] - Loss: 9.8555; AAPD: 3.5302727222442627
Epoch [5

Epoch [706/1500] - Loss: 18.8469; AAPD: 12.260220527648926
Epoch [707/1500] - Loss: 54.1117; AAPD: 9.984631538391113
Epoch [708/1500] - Loss: 5.0470; AAPD: 10.363276481628418
Epoch [709/1500] - Loss: 49.1733; AAPD: 4.72584867477417
Epoch [710/1500] - Loss: 56.2084; AAPD: 8.576607704162598
Epoch [711/1500] - Loss: 4.0872; AAPD: 4.888283729553223
Epoch [712/1500] - Loss: 15.5026; AAPD: 8.272167205810547
Epoch [713/1500] - Loss: 12.6311; AAPD: 19.55485725402832
Epoch [714/1500] - Loss: 13.5495; AAPD: 13.31922721862793
Epoch [715/1500] - Loss: 38.3018; AAPD: 11.872063636779785
Epoch [716/1500] - Loss: 11.9703; AAPD: 14.57944393157959
Epoch [717/1500] - Loss: 16.3769; AAPD: 12.38001537322998
Epoch [718/1500] - Loss: 56.4978; AAPD: 15.33255386352539
Epoch [719/1500] - Loss: 6.1249; AAPD: 17.302961349487305
Epoch [720/1500] - Loss: 6.0099; AAPD: 3.286986827850342
Epoch [721/1500] - Loss: 14.5657; AAPD: 7.484124660491943
Epoch [722/1500] - Loss: 27.9533; AAPD: 11.079216003417969
Epoch [723/150

Epoch [847/1500] - Loss: 105.0615; AAPD: 6.030706882476807
Epoch [848/1500] - Loss: 11.1293; AAPD: 5.029484748840332
Epoch [849/1500] - Loss: 2.5869; AAPD: 6.292631149291992
Epoch [850/1500] - Loss: 608.6655; AAPD: 4.331545352935791
Epoch [851/1500] - Loss: 8.7294; AAPD: 4.099618911743164
Epoch [852/1500] - Loss: 2.0139; AAPD: 2.412015199661255
Epoch [853/1500] - Loss: 53.9567; AAPD: 5.398192405700684
Epoch [854/1500] - Loss: 4.2351; AAPD: 3.453798532485962
Epoch [855/1500] - Loss: 62.0885; AAPD: 4.808921813964844
Epoch [856/1500] - Loss: 4.3964; AAPD: 5.8822526931762695
Epoch [857/1500] - Loss: 48.8081; AAPD: 3.0124685764312744
Epoch [858/1500] - Loss: 3.9561; AAPD: 4.674258232116699
Epoch [859/1500] - Loss: 2.4975; AAPD: 3.0389230251312256
Epoch [860/1500] - Loss: 7.4211; AAPD: 4.966707706451416
Epoch [861/1500] - Loss: 11.5022; AAPD: 3.453662633895874
Epoch [862/1500] - Loss: 8.2951; AAPD: 3.397614002227783
Epoch [863/1500] - Loss: 10.9120; AAPD: 6.236787796020508
Epoch [864/1500] -

Epoch [989/1500] - Loss: 20.2010; AAPD: 3.4449543952941895
Epoch [990/1500] - Loss: 194.5746; AAPD: 6.664566993713379
Epoch [991/1500] - Loss: 32.9336; AAPD: 5.687608242034912
Epoch [992/1500] - Loss: 6.4777; AAPD: 5.265049457550049
Epoch [993/1500] - Loss: 16.6959; AAPD: 4.729199409484863
Epoch [994/1500] - Loss: 207.7069; AAPD: 3.9359652996063232
Epoch [995/1500] - Loss: 7.9782; AAPD: 7.160226345062256
Epoch [996/1500] - Loss: 88.8340; AAPD: 4.066473007202148
Epoch [997/1500] - Loss: 124.1218; AAPD: 2.4432733058929443
Epoch [998/1500] - Loss: 19.6838; AAPD: 3.0840156078338623
Epoch [999/1500] - Loss: 38.1523; AAPD: 4.474414348602295
Epoch [1000/1500] - Loss: 73.1855; AAPD: 4.125481605529785
Epoch [1001/1500] - Loss: 5.9364; AAPD: 3.53027606010437
Epoch [1002/1500] - Loss: 24.4920; AAPD: 3.303467273712158
Epoch [1003/1500] - Loss: 405.7378; AAPD: 2.61574387550354
Epoch [1004/1500] - Loss: 40.2318; AAPD: 3.373279094696045
Epoch [1005/1500] - Loss: 10.0604; AAPD: 2.4297783374786377
Epoc

Epoch [1128/1500] - Loss: 11.3664; AAPD: 4.436501979827881
Epoch [1129/1500] - Loss: 49.9142; AAPD: 3.054760694503784
Epoch [1130/1500] - Loss: 5.1391; AAPD: 2.4367058277130127
Epoch [1131/1500] - Loss: 775.2390; AAPD: 4.957183837890625
Epoch [1132/1500] - Loss: 47.6561; AAPD: 5.369926929473877
Epoch [1133/1500] - Loss: 1.6339; AAPD: 3.0493173599243164
Epoch [1134/1500] - Loss: 14.0440; AAPD: 2.7327752113342285
Epoch [1135/1500] - Loss: 261.7444; AAPD: 3.8096535205841064
Epoch [1136/1500] - Loss: 287.0346; AAPD: 4.9356279373168945
Epoch [1137/1500] - Loss: 80.2679; AAPD: 2.7166948318481445
Epoch [1138/1500] - Loss: 2.8077; AAPD: 3.208786725997925
Epoch [1139/1500] - Loss: 7.3541; AAPD: 2.6433732509613037
Epoch [1140/1500] - Loss: 45.8866; AAPD: 3.816518783569336
Epoch [1141/1500] - Loss: 4.1758; AAPD: 2.24818754196167
Epoch [1142/1500] - Loss: 22.5310; AAPD: 3.886686325073242
Epoch [1143/1500] - Loss: 21.9138; AAPD: 2.4118268489837646
Epoch [1144/1500] - Loss: 12.2176; AAPD: 3.29828453

Epoch [1267/1500] - Loss: 5.9524; AAPD: 2.59671688079834
Epoch [1268/1500] - Loss: 60.8294; AAPD: 2.532566547393799
Epoch [1269/1500] - Loss: 6.9343; AAPD: 4.583749294281006
Epoch [1270/1500] - Loss: 5.7208; AAPD: 3.363035202026367
Epoch [1271/1500] - Loss: 4.5656; AAPD: 4.0066914558410645
Epoch [1272/1500] - Loss: 56.7245; AAPD: 5.591863632202148
Epoch [1273/1500] - Loss: 3.7175; AAPD: 2.2540111541748047
Epoch [1274/1500] - Loss: 20.4173; AAPD: 3.6741933822631836
Epoch [1275/1500] - Loss: 34.7432; AAPD: 4.192574501037598
Epoch [1276/1500] - Loss: 10.4570; AAPD: 2.7218844890594482
Epoch [1277/1500] - Loss: 9.1116; AAPD: 3.1942696571350098
Epoch [1278/1500] - Loss: 3.7512; AAPD: 2.0793612003326416
Epoch [1279/1500] - Loss: 44.3080; AAPD: 2.593320369720459
Epoch [1280/1500] - Loss: 16.8518; AAPD: 5.328705310821533
Epoch [1281/1500] - Loss: 37.8589; AAPD: 2.696136713027954
Epoch [1282/1500] - Loss: 5.7666; AAPD: 3.2740378379821777
Epoch [1283/1500] - Loss: 9.3665; AAPD: 2.7569692134857178

Epoch [1406/1500] - Loss: 17.9116; AAPD: 4.32476282119751
Epoch [1407/1500] - Loss: 93.3556; AAPD: 4.110407829284668
Epoch [1408/1500] - Loss: 27.5085; AAPD: 3.8611044883728027
Epoch [1409/1500] - Loss: 4.5822; AAPD: 3.573965549468994
Epoch [1410/1500] - Loss: 5.1707; AAPD: 2.6419975757598877
Epoch [1411/1500] - Loss: 5.1195; AAPD: 2.290092945098877
Epoch [1412/1500] - Loss: 22.8473; AAPD: 3.878018856048584
Epoch [1413/1500] - Loss: 31.6346; AAPD: 2.604053258895874
Epoch [1414/1500] - Loss: 13.8643; AAPD: 2.0866215229034424
Epoch [1415/1500] - Loss: 14.4485; AAPD: 3.5861005783081055
Epoch [1416/1500] - Loss: 23.9602; AAPD: 3.0544211864471436
Epoch [1417/1500] - Loss: 9.1827; AAPD: 2.6275229454040527
Epoch [1418/1500] - Loss: 34.4582; AAPD: 2.8243770599365234
Epoch [1419/1500] - Loss: 2.6157; AAPD: 2.588656187057495
Epoch [1420/1500] - Loss: 9.7535; AAPD: 3.4495441913604736
Epoch [1421/1500] - Loss: 38.0975; AAPD: 3.239553213119507
Epoch [1422/1500] - Loss: 28.5839; AAPD: 2.903494119644

In [59]:
model = linearencoderdecoder()
best_model_ED_ASPD_3 = train_model(model, 'best_model_ED_ASPD_3.pth', train_dataloader, dev_dataloader, target_idx=3)

Epoch [1/1500] - Loss: 0.1942; AAPD: 16.968778610229492
Epoch [2/1500] - Loss: 17.2310; AAPD: 83.87564086914062
Epoch [3/1500] - Loss: 0.1388; AAPD: 40.44755172729492
Epoch [4/1500] - Loss: 0.1921; AAPD: 15.825462341308594
Epoch [5/1500] - Loss: 0.0247; AAPD: 7.316107749938965
Epoch [6/1500] - Loss: 0.2734; AAPD: 6.361852645874023
Epoch [7/1500] - Loss: 0.0247; AAPD: 2.4616520404815674
Epoch [8/1500] - Loss: 0.0221; AAPD: 7.143275737762451
Epoch [9/1500] - Loss: 0.2028; AAPD: 26.57213020324707
Epoch [10/1500] - Loss: 0.1096; AAPD: 21.392871856689453
Epoch [11/1500] - Loss: 0.0289; AAPD: 6.982656955718994
Epoch [12/1500] - Loss: 0.0144; AAPD: 4.793193817138672
Epoch [13/1500] - Loss: 0.0124; AAPD: 5.532114505767822
Epoch [14/1500] - Loss: 0.1117; AAPD: 6.229603290557861
Epoch [15/1500] - Loss: 0.1516; AAPD: 11.435063362121582
Epoch [16/1500] - Loss: 0.1582; AAPD: 26.948833465576172
Epoch [17/1500] - Loss: 0.0282; AAPD: 2.7964818477630615
Epoch [18/1500] - Loss: 0.0250; AAPD: 1.866129517

Epoch [146/1500] - Loss: 0.0682; AAPD: 6.71884822845459
Epoch [147/1500] - Loss: 0.0483; AAPD: 19.11236572265625
Epoch [148/1500] - Loss: 0.0203; AAPD: 43.53752136230469
Epoch [149/1500] - Loss: 0.0153; AAPD: 2.341411828994751
Epoch [150/1500] - Loss: 0.0285; AAPD: 1.694683313369751
Epoch [151/1500] - Loss: 0.0197; AAPD: 11.086040496826172
Epoch [152/1500] - Loss: 0.5238; AAPD: 46.034385681152344
Epoch [153/1500] - Loss: 0.0582; AAPD: 5.120475769042969
Epoch [154/1500] - Loss: 0.0091; AAPD: 11.106524467468262
Epoch [155/1500] - Loss: 0.0236; AAPD: 7.691174030303955
Epoch [156/1500] - Loss: 0.0894; AAPD: 22.691904067993164
Epoch [157/1500] - Loss: 0.0358; AAPD: 17.956981658935547
Epoch [158/1500] - Loss: 0.0099; AAPD: 9.590811729431152
Epoch [159/1500] - Loss: 0.0761; AAPD: 13.96623420715332
Epoch [160/1500] - Loss: 13.0222; AAPD: 113.28205108642578
Epoch [161/1500] - Loss: 0.0193; AAPD: 3.098022222518921
Epoch [162/1500] - Loss: 0.0103; AAPD: 2.273348331451416
Epoch [163/1500] - Loss: 

Epoch [289/1500] - Loss: 0.0238; AAPD: 2.0330498218536377
Epoch [290/1500] - Loss: 0.0112; AAPD: 2.174600601196289
Epoch [291/1500] - Loss: 0.0133; AAPD: 6.329648971557617
Epoch [292/1500] - Loss: 0.0132; AAPD: 1.830428123474121
Epoch [293/1500] - Loss: 0.0202; AAPD: 1.856498122215271
Epoch [294/1500] - Loss: 0.0256; AAPD: 1.5564346313476562
Epoch [295/1500] - Loss: 0.0178; AAPD: 1.5870786905288696
Epoch [296/1500] - Loss: 0.0193; AAPD: 2.167928457260132
Epoch [297/1500] - Loss: 0.0149; AAPD: 2.4650862216949463
Epoch [298/1500] - Loss: 0.0260; AAPD: 2.222097635269165
Epoch [299/1500] - Loss: 0.0133; AAPD: 2.541186809539795
Epoch [300/1500] - Loss: 0.0105; AAPD: 1.586824893951416
Epoch [301/1500] - Loss: 0.0160; AAPD: 2.896944761276245
Epoch [302/1500] - Loss: 0.0107; AAPD: 1.8621214628219604
Epoch [303/1500] - Loss: 0.0214; AAPD: 1.8692197799682617
Epoch [304/1500] - Loss: 0.0105; AAPD: 2.203279495239258
Epoch [305/1500] - Loss: 0.0154; AAPD: 2.450840950012207
Epoch [306/1500] - Loss: 

Epoch [432/1500] - Loss: 0.0153; AAPD: 1.772873878479004
Epoch [433/1500] - Loss: 0.0115; AAPD: 2.819514036178589
Epoch [434/1500] - Loss: 0.0196; AAPD: 1.7202191352844238
Epoch [435/1500] - Loss: 0.0385; AAPD: 9.842467308044434
Epoch [436/1500] - Loss: 0.0136; AAPD: 4.082458972930908
Epoch [437/1500] - Loss: 0.0185; AAPD: 2.083893299102783
Epoch [438/1500] - Loss: 0.0138; AAPD: 1.7458868026733398
Epoch [439/1500] - Loss: 0.0170; AAPD: 2.247345447540283
Epoch [440/1500] - Loss: 0.0121; AAPD: 4.396670341491699
Epoch [441/1500] - Loss: 0.0188; AAPD: 2.4756381511688232
Epoch [442/1500] - Loss: 0.0086; AAPD: 2.2288382053375244
Epoch [443/1500] - Loss: 0.0383; AAPD: 11.448956489562988
Epoch [444/1500] - Loss: 0.0140; AAPD: 1.679158091545105
Epoch [445/1500] - Loss: 0.0123; AAPD: 3.4074859619140625
Epoch [446/1500] - Loss: 0.0227; AAPD: 3.6644725799560547
Epoch [447/1500] - Loss: 0.0163; AAPD: 3.108161687850952
Epoch [448/1500] - Loss: 0.0119; AAPD: 2.4535884857177734
Epoch [449/1500] - Loss

Epoch [575/1500] - Loss: 0.0104; AAPD: 1.3409017324447632
Epoch [576/1500] - Loss: 0.0102; AAPD: 1.3063043355941772
Epoch [577/1500] - Loss: 0.0127; AAPD: 1.542597770690918
Epoch [578/1500] - Loss: 0.0143; AAPD: 1.5647317171096802
Epoch [579/1500] - Loss: 0.0082; AAPD: 1.345697283744812
Epoch [580/1500] - Loss: 0.0075; AAPD: 1.4402213096618652
Epoch [581/1500] - Loss: 0.0133; AAPD: 1.384312629699707
Epoch [582/1500] - Loss: 0.0106; AAPD: 1.437468409538269
Epoch [583/1500] - Loss: 0.0130; AAPD: 1.3778940439224243
Epoch [584/1500] - Loss: 0.0093; AAPD: 1.5406526327133179
Epoch [585/1500] - Loss: 0.0194; AAPD: 2.008378267288208
Epoch [586/1500] - Loss: 0.0195; AAPD: 1.3317946195602417
Epoch [587/1500] - Loss: 0.0102; AAPD: 1.2877403497695923
Epoch [588/1500] - Loss: 0.0311; AAPD: 1.2946010828018188
Epoch [589/1500] - Loss: 0.0141; AAPD: 1.3952600955963135
Epoch [590/1500] - Loss: 0.0120; AAPD: 1.3763381242752075
Epoch [591/1500] - Loss: 0.0049; AAPD: 1.4509481191635132
Epoch [592/1500] - 

Epoch [717/1500] - Loss: 0.0178; AAPD: 1.3512094020843506
Epoch [718/1500] - Loss: 0.0132; AAPD: 1.186737298965454
Epoch [719/1500] - Loss: 0.0134; AAPD: 1.3284974098205566
Epoch [720/1500] - Loss: 0.0169; AAPD: 1.3459599018096924
Epoch [721/1500] - Loss: 0.0213; AAPD: 1.5837759971618652
Epoch [722/1500] - Loss: 0.0194; AAPD: 2.5358166694641113
Epoch [723/1500] - Loss: 0.0123; AAPD: 1.6170231103897095
Epoch [724/1500] - Loss: 0.0111; AAPD: 1.5199302434921265
Epoch [725/1500] - Loss: 0.0167; AAPD: 2.3317830562591553
Epoch [726/1500] - Loss: 0.0229; AAPD: 1.2466710805892944
Epoch [727/1500] - Loss: 0.0089; AAPD: 1.3125180006027222
Epoch [728/1500] - Loss: 0.0113; AAPD: 1.3292007446289062
Epoch [729/1500] - Loss: 0.0094; AAPD: 1.2625153064727783
Epoch [730/1500] - Loss: 0.0088; AAPD: 1.350690245628357
Epoch [731/1500] - Loss: 0.0146; AAPD: 1.7934497594833374
Epoch [732/1500] - Loss: 0.0383; AAPD: 1.4741469621658325
Epoch [733/1500] - Loss: 0.0072; AAPD: 1.6090930700302124
Epoch [734/1500]

Epoch [859/1500] - Loss: 0.0147; AAPD: 1.2074931859970093
Epoch [860/1500] - Loss: 0.0167; AAPD: 1.1808209419250488
Epoch [861/1500] - Loss: 0.0135; AAPD: 1.1818029880523682
Epoch [862/1500] - Loss: 0.0129; AAPD: 1.2484519481658936
Epoch [863/1500] - Loss: 0.0097; AAPD: 1.1794377565383911
Epoch [864/1500] - Loss: 0.0154; AAPD: 1.2078510522842407
Epoch [865/1500] - Loss: 0.0244; AAPD: 1.2659534215927124
Epoch [866/1500] - Loss: 0.0061; AAPD: 1.2412464618682861
Epoch [867/1500] - Loss: 0.0069; AAPD: 1.3655272722244263
Epoch [868/1500] - Loss: 0.0090; AAPD: 1.2338188886642456
Epoch [869/1500] - Loss: 0.0169; AAPD: 1.2117310762405396
Epoch [870/1500] - Loss: 0.0199; AAPD: 1.2342219352722168
Epoch [871/1500] - Loss: 0.0207; AAPD: 1.2555646896362305
Epoch [872/1500] - Loss: 0.0153; AAPD: 1.1938188076019287
Epoch [873/1500] - Loss: 0.0156; AAPD: 1.1949892044067383
Epoch [874/1500] - Loss: 0.0187; AAPD: 1.2152658700942993
Epoch [875/1500] - Loss: 0.0114; AAPD: 1.2243084907531738
Epoch [876/150

Epoch [1001/1500] - Loss: 0.0228; AAPD: 1.1825844049453735
Epoch [1002/1500] - Loss: 0.0221; AAPD: 1.1868499517440796
Epoch [1003/1500] - Loss: 0.0057; AAPD: 1.1882294416427612
Epoch [1004/1500] - Loss: 0.0188; AAPD: 1.1809126138687134
Epoch [1005/1500] - Loss: 0.0139; AAPD: 1.1759767532348633
Epoch [1006/1500] - Loss: 0.0045; AAPD: 1.194832682609558
Epoch [1007/1500] - Loss: 0.0124; AAPD: 1.1875282526016235
Epoch [1008/1500] - Loss: 0.0124; AAPD: 1.1885008811950684
Epoch [1009/1500] - Loss: 0.0153; AAPD: 1.190146565437317
Epoch [1010/1500] - Loss: 0.0135; AAPD: 1.1919838190078735
Epoch [1011/1500] - Loss: 0.0057; AAPD: 1.1881529092788696
Epoch [1012/1500] - Loss: 0.0129; AAPD: 1.183530569076538
Epoch [1013/1500] - Loss: 0.0216; AAPD: 1.1913329362869263
Epoch [1014/1500] - Loss: 0.0085; AAPD: 1.1911578178405762
Epoch [1015/1500] - Loss: 0.0148; AAPD: 1.19089937210083
Epoch [1016/1500] - Loss: 0.0158; AAPD: 1.1848969459533691
Epoch [1017/1500] - Loss: 0.0259; AAPD: 1.1835952997207642
Ep

Epoch [1141/1500] - Loss: 0.0151; AAPD: 1.1921882629394531
Epoch [1142/1500] - Loss: 0.0069; AAPD: 1.177344799041748
Epoch [1143/1500] - Loss: 0.0152; AAPD: 1.1858947277069092
Epoch [1144/1500] - Loss: 0.0099; AAPD: 1.1817712783813477
Epoch [1145/1500] - Loss: 0.0134; AAPD: 1.1833117008209229
Epoch [1146/1500] - Loss: 0.0112; AAPD: 1.1840534210205078
Epoch [1147/1500] - Loss: 0.0083; AAPD: 1.2021336555480957
Epoch [1148/1500] - Loss: 0.0082; AAPD: 1.1834602355957031
Epoch [1149/1500] - Loss: 0.0073; AAPD: 1.1895118951797485
Epoch [1150/1500] - Loss: 0.0093; AAPD: 1.1920456886291504
Epoch [1151/1500] - Loss: 0.0118; AAPD: 1.1805962324142456
Epoch [1152/1500] - Loss: 0.0142; AAPD: 1.1882840394973755
Epoch [1153/1500] - Loss: 0.0229; AAPD: 1.1900579929351807
Epoch [1154/1500] - Loss: 0.0147; AAPD: 1.2026840448379517
Epoch [1155/1500] - Loss: 0.0199; AAPD: 1.1840778589248657
Epoch [1156/1500] - Loss: 0.0139; AAPD: 1.1860085725784302
Epoch [1157/1500] - Loss: 0.0105; AAPD: 1.191304206848144

Epoch [1281/1500] - Loss: 0.0136; AAPD: 1.1900955438613892
Epoch [1282/1500] - Loss: 0.0080; AAPD: 1.1859625577926636
Epoch [1283/1500] - Loss: 0.0093; AAPD: 1.1935040950775146
Epoch [1284/1500] - Loss: 0.0122; AAPD: 1.1778424978256226
Epoch [1285/1500] - Loss: 0.0114; AAPD: 1.1914814710617065
Epoch [1286/1500] - Loss: 0.0059; AAPD: 1.188796877861023
Epoch [1287/1500] - Loss: 0.0154; AAPD: 1.1864328384399414
Epoch [1288/1500] - Loss: 0.0100; AAPD: 1.1979917287826538
Epoch [1289/1500] - Loss: 0.0112; AAPD: 1.1810956001281738
Epoch [1290/1500] - Loss: 0.0049; AAPD: 1.1823397874832153
Epoch [1291/1500] - Loss: 0.0085; AAPD: 1.1794617176055908
Epoch [1292/1500] - Loss: 0.0185; AAPD: 1.196362853050232
Epoch [1293/1500] - Loss: 0.0158; AAPD: 1.1780529022216797
Epoch [1294/1500] - Loss: 0.0166; AAPD: 1.1790355443954468
Epoch [1295/1500] - Loss: 0.0109; AAPD: 1.1740942001342773
Epoch [1296/1500] - Loss: 0.0055; AAPD: 1.1830410957336426
Epoch [1297/1500] - Loss: 0.0140; AAPD: 1.1785269975662231

Epoch [1421/1500] - Loss: 0.0134; AAPD: 1.1856282949447632
Epoch [1422/1500] - Loss: 0.0148; AAPD: 1.18213951587677
Epoch [1423/1500] - Loss: 0.0065; AAPD: 1.1742138862609863
Epoch [1424/1500] - Loss: 0.0108; AAPD: 1.1927666664123535
Epoch [1425/1500] - Loss: 0.0053; AAPD: 1.1747397184371948
Epoch [1426/1500] - Loss: 0.0135; AAPD: 1.1828242540359497
Epoch [1427/1500] - Loss: 0.0206; AAPD: 1.188299536705017
Epoch [1428/1500] - Loss: 0.0139; AAPD: 1.1967922449111938
Epoch [1429/1500] - Loss: 0.0197; AAPD: 1.207012414932251
Epoch [1430/1500] - Loss: 0.0065; AAPD: 1.197006344795227
Epoch [1431/1500] - Loss: 0.0146; AAPD: 1.2064478397369385
Epoch [1432/1500] - Loss: 0.0160; AAPD: 1.183963656425476
Epoch [1433/1500] - Loss: 0.0200; AAPD: 1.1868950128555298
Epoch [1434/1500] - Loss: 0.0132; AAPD: 1.1908317804336548
Epoch [1435/1500] - Loss: 0.0100; AAPD: 1.1933616399765015
Epoch [1436/1500] - Loss: 0.0086; AAPD: 1.1761728525161743
Epoch [1437/1500] - Loss: 0.0127; AAPD: 1.1770647764205933
Epo