In [1]:
import torch
import torch.nn as nn
import torchvision.models as models
import matplotlib.pyplot as plt
import sys
import numpy as np
import random
from utils import GroundTruthProcess
from utils import NCHW_to_NHWC_np
from utils import NHWC_to_NCHW_np
from DME_deformable import DMENet
from config import DefaultConfig
import torch.cuda as torch_cudab
from data_process.DatasetConstructor import DatasetConstructor
import metrics
MAE = 10240000
MSE = 10240000
%matplotlib inline

# obtain the gpu device
assert torch.cuda.is_available()
cuda_device = torch.device("cuda")  # device object representing GPU
opt = DefaultConfig()

# data_load
img_dir = "/home/zzn/part_B_final/train_data/images"
gt_dir = "/home/zzn/part_B_final/train_data/gt_map"

# model construct
net = DMENet().to(cuda_device)
gt_map_process_model = GroundTruthProcess(1, 1, 8).to(cuda_device) # to keep the same resolution with the prediction

# set optimizer and estimator
criterion = metrics.DMELoss().to(cuda_device)
optimizer = torch.optim.Adam(net.parameters(), lr=opt.lr, weight_decay=opt.weight_decay)
ae_batch = metrics.AEBatch().to(cuda_device)
se_batch = metrics.SEBatch().to(cuda_device)

In [None]:
# train
# Validation data is built once, so every validation pass scores the same images and the
# best-model checkpoints below are compared on equal footing. Previously it was rebuilt
# for every training batch.
# NOTE(review): this reads the same img_dir/gt_dir as training; presumably the `False`
# flag selects held-out images -- confirm in DatasetConstructor, otherwise the model is
# being validated on its own training data.
validate_dataset = DatasetConstructor(img_dir, gt_dir, 20, False)
validate_loader = torch.utils.data.DataLoader(dataset=validate_dataset, batch_size=1)

# Validate after every N training batches. With one optimizer step per crop, N = 1
# reproduces the previous cadence (the old `step % 10 == 0` test was always true).
VALIDATE_EVERY_BATCHES = 1


def evaluate(loader):
    """Score the network on every sample in ``loader`` without tracking gradients.

    Puts ``net`` in eval mode for the pass and always restores train mode afterwards.

    Returns:
        (mean_loss, mae, rmse): mean criterion loss, mean absolute error, and the square
        root of the mean squared error, all over the samples yielded by ``loader``.
    """
    net.eval()
    losses, abs_errors, sq_errors = [], [], []
    try:
        # no_grad: validation needs no autograd graph, which would only waste GPU memory.
        with torch.no_grad():
            for _, validate_img, validate_gt in loader:
                validate_x = validate_img.view(1, 3, 768, 1024).to(cuda_device)
                validate_gt = validate_gt.view(1, 1, 768, 1024).to(cuda_device)
                validate_predict_map = net(validate_x)
                validate_gt_map = gt_map_process_model(validate_gt)
                # numpy cannot read CUDA memory, so each result is copied to the CPU first.
                losses.append(criterion(validate_predict_map, validate_gt_map).cpu().numpy())
                abs_errors.append(ae_batch(validate_predict_map, validate_gt_map).cpu().numpy())
                sq_errors.append(se_batch(validate_predict_map, validate_gt_map).cpu().numpy())
    finally:
        net.train()

    mean_loss = np.mean(np.reshape(losses, [-1]))
    mae = np.mean(np.reshape(abs_errors, [-1]))
    rmse = np.sqrt(np.mean(np.reshape(sq_errors, [-1])))
    return mean_loss, mae, rmse


for i in range(opt.max_epoch):
    # Rebuilt every epoch as before (presumably so the dataset can re-sample its crops -- TODO confirm).
    batch_dataset = DatasetConstructor(img_dir, gt_dir, 400, True)
    train_loader = torch.utils.data.DataLoader(dataset=batch_dataset, batch_size=1)
    step = 0
    for batch_number, (_, img, gt_batch) in enumerate(train_loader, start=1):
        # Each loader item packs several 600x800 crops; train on them one at a time.
        # The crop count is taken from the data instead of being hard-coded to 10.
        img_crops = img.view(-1, 3, 600, 800)
        gt_crops = gt_batch.view(-1, 1, 600, 800)
        for j in range(img_crops.shape[0]):
            # The net lives on cuda_device, so its inputs must always be moved there.
            x = img_crops[j:j + 1].to(cuda_device)
            gt = gt_crops[j:j + 1].to(cuda_device)
            optimizer.zero_grad()  # clear gradients left over from the previous step
            # The network output has 1/8 resolution, so the ground truth is processed to match.
            gt_map = gt_map_process_model(gt)
            estimated_density_map = net(x)
            loss = criterion(estimated_density_map, gt_map)
            loss.backward()
            optimizer.step()
            step += 1

        # validate
        if batch_number % VALIDATE_EVERY_BATCHES == 0:
            validate_loss, validate_MAE, validate_RMSE = evaluate(validate_loader)

            # The value printed as "MSE" is the root-mean-squared error returned by evaluate().
            sys.stdout.write('In step {}, epoch {}, with loss {}, MAE = {}, MSE = {}\n'.format(step, i + 1, validate_loss, validate_MAE, validate_RMSE))
            sys.stdout.flush()

            # Keep separate checkpoints for the best MAE and the best RMSE seen so far.
            # NOTE(review): torch.save(net) pickles the whole module; saving net.state_dict()
            # would be more portable, but any loader of these files would need updating too.
            if validate_MAE < MAE:
                MAE = validate_MAE
                torch.save(net, opt.mae_model_b)
            if validate_RMSE < MSE:
                MSE = validate_RMSE
                torch.save(net, opt.mse_model_b)
        
            
                
                

In step 10, epoch 1, with loss 17.50596809387207, MAE = 138.51156616210938, MSE = 163.13900756835938
In step 20, epoch 1, with loss 14.114282608032227, MAE = 82.8393783569336, MSE = 98.06463623046875
In step 30, epoch 1, with loss 19.614627838134766, MAE = 70.83032989501953, MSE = 88.1166763305664
In step 40, epoch 1, with loss 18.057106018066406, MAE = 99.68263244628906, MSE = 133.88198852539062
In step 50, epoch 1, with loss 37.12691116333008, MAE = 157.36154174804688, MSE = 205.28555297851562
In step 60, epoch 1, with loss 16.60133171081543, MAE = 82.88035583496094, MSE = 124.52913665771484
In step 70, epoch 1, with loss 12.197083473205566, MAE = 50.269996643066406, MSE = 70.79744720458984
In step 80, epoch 1, with loss 17.501253128051758, MAE = 62.25359344482422, MSE = 78.38209533691406
In step 90, epoch 1, with loss 18.316333770751953, MAE = 68.48020935058594, MSE = 101.18235778808594
In step 100, epoch 1, with loss 16.755359649658203, MAE = 59.722068786621094, MSE = 122.351638793

In step 820, epoch 1, with loss 12.646014213562012, MAE = 149.41177368164062, MSE = 158.5875701904297
In step 830, epoch 1, with loss 9.769606590270996, MAE = 144.1625518798828, MSE = 151.6800079345703
In step 840, epoch 1, with loss 13.206263542175293, MAE = 59.40740203857422, MSE = 69.63935089111328
In step 850, epoch 1, with loss 23.48668670654297, MAE = 77.38081359863281, MSE = 93.73677825927734
In step 860, epoch 1, with loss 14.49183177947998, MAE = 119.41397857666016, MSE = 129.78094482421875
In step 870, epoch 1, with loss 12.217822074890137, MAE = 50.041282653808594, MSE = 64.64569091796875
In step 880, epoch 1, with loss 16.54207992553711, MAE = 100.11296081542969, MSE = 127.63117218017578
In step 890, epoch 1, with loss 15.417829513549805, MAE = 75.50968933105469, MSE = 109.95507049560547
In step 900, epoch 1, with loss 25.67705726623535, MAE = 177.3982391357422, MSE = 200.03465270996094
In step 910, epoch 1, with loss 24.32595443725586, MAE = 142.3128662109375, MSE = 187.40

In step 1630, epoch 1, with loss 12.567106246948242, MAE = 113.91053771972656, MSE = 140.8381805419922
In step 1640, epoch 1, with loss 9.929710388183594, MAE = 45.14259719848633, MSE = 57.00809097290039
In step 1650, epoch 1, with loss 13.810028076171875, MAE = 69.59037780761719, MSE = 95.0931167602539
In step 1660, epoch 1, with loss 14.26587200164795, MAE = 65.97000122070312, MSE = 96.01671600341797
In step 1670, epoch 1, with loss 19.761831283569336, MAE = 83.79293060302734, MSE = 109.49166107177734
In step 1680, epoch 1, with loss 9.825750350952148, MAE = 49.333560943603516, MSE = 74.65872192382812
In step 1690, epoch 1, with loss 14.290170669555664, MAE = 55.21669387817383, MSE = 70.89410400390625
In step 1700, epoch 1, with loss 11.873880386352539, MAE = 74.02790832519531, MSE = 89.42569732666016
In step 1710, epoch 1, with loss 21.507577896118164, MAE = 128.2425994873047, MSE = 170.5183563232422
In step 1720, epoch 1, with loss 14.96397876739502, MAE = 245.3037109375, MSE = 250

In step 2440, epoch 1, with loss 14.903696060180664, MAE = 77.7323989868164, MSE = 88.69219207763672
In step 2450, epoch 1, with loss 14.879781723022461, MAE = 120.3985366821289, MSE = 151.27413940429688
In step 2460, epoch 1, with loss 16.41741943359375, MAE = 99.41801452636719, MSE = 130.52114868164062
In step 2470, epoch 1, with loss 11.10163688659668, MAE = 114.34062194824219, MSE = 121.65118408203125
In step 2480, epoch 1, with loss 13.053537368774414, MAE = 88.93855285644531, MSE = 94.50704956054688
In step 2490, epoch 1, with loss 15.790102005004883, MAE = 120.27873229980469, MSE = 147.6620330810547
In step 2500, epoch 1, with loss 12.16539192199707, MAE = 77.94971466064453, MSE = 98.31157684326172
In step 2510, epoch 1, with loss 19.28584098815918, MAE = 70.1462173461914, MSE = 104.93867492675781
In step 2520, epoch 1, with loss 12.197997093200684, MAE = 60.67144775390625, MSE = 81.06652069091797
In step 2530, epoch 1, with loss 18.2185001373291, MAE = 278.8465881347656, MSE = 

In step 3250, epoch 1, with loss 8.481176376342773, MAE = 35.11491012573242, MSE = 46.956871032714844
In step 3260, epoch 1, with loss 11.764204025268555, MAE = 49.29331588745117, MSE = 67.99735260009766
In step 3270, epoch 1, with loss 10.031827926635742, MAE = 37.76125717163086, MSE = 52.84515380859375
In step 3280, epoch 1, with loss 11.734368324279785, MAE = 42.65998077392578, MSE = 59.1632080078125
In step 3290, epoch 1, with loss 12.643261909484863, MAE = 70.39505767822266, MSE = 94.71732330322266
In step 3300, epoch 1, with loss 13.406598091125488, MAE = 79.97520446777344, MSE = 106.02922058105469
In step 3310, epoch 1, with loss 13.776516914367676, MAE = 56.87888717651367, MSE = 85.20492553710938
In step 3320, epoch 1, with loss 20.72515869140625, MAE = 78.76876068115234, MSE = 128.50840759277344
In step 3330, epoch 1, with loss 12.544950485229492, MAE = 48.31053161621094, MSE = 68.42796325683594
In step 3340, epoch 1, with loss 14.702618598937988, MAE = 51.11442947387695, MSE 

In step 60, epoch 2, with loss 18.55630111694336, MAE = 110.50407409667969, MSE = 135.8972625732422
In step 70, epoch 2, with loss 15.402214050292969, MAE = 63.97197341918945, MSE = 76.72301483154297
In step 80, epoch 2, with loss 13.529609680175781, MAE = 125.3201675415039, MSE = 133.89561462402344
In step 90, epoch 2, with loss 13.251187324523926, MAE = 91.33439636230469, MSE = 103.13758087158203
In step 100, epoch 2, with loss 12.555174827575684, MAE = 67.26587677001953, MSE = 95.07179260253906
In step 110, epoch 2, with loss 9.968599319458008, MAE = 66.75556182861328, MSE = 95.47431182861328
In step 120, epoch 2, with loss 13.187914848327637, MAE = 68.47998809814453, MSE = 95.48897552490234
In step 130, epoch 2, with loss 9.628179550170898, MAE = 50.676673889160156, MSE = 76.13250732421875
In step 140, epoch 2, with loss 13.369196891784668, MAE = 81.10590362548828, MSE = 114.26021575927734
In step 150, epoch 2, with loss 11.103438377380371, MAE = 180.44381713867188, MSE = 184.89265

In step 880, epoch 2, with loss 14.476966857910156, MAE = 191.1620330810547, MSE = 203.29331970214844
In step 890, epoch 2, with loss 11.083454132080078, MAE = 106.01611328125, MSE = 113.26895141601562
In step 900, epoch 2, with loss 21.07208251953125, MAE = 94.37266540527344, MSE = 137.42340087890625
In step 910, epoch 2, with loss 15.051427841186523, MAE = 62.12592697143555, MSE = 85.05244445800781
In step 920, epoch 2, with loss 10.512017250061035, MAE = 48.53813934326172, MSE = 76.92719268798828
In step 930, epoch 2, with loss 8.84062385559082, MAE = 43.13273620605469, MSE = 64.39632415771484
In step 940, epoch 2, with loss 7.3033599853515625, MAE = 35.36592483520508, MSE = 57.4183235168457
In step 950, epoch 2, with loss 14.711645126342773, MAE = 75.96871948242188, MSE = 84.33680725097656
In step 960, epoch 2, with loss 10.905877113342285, MAE = 49.40631866455078, MSE = 68.97693634033203
In step 970, epoch 2, with loss 13.131810188293457, MAE = 56.598121643066406, MSE = 85.8797607

In step 1690, epoch 2, with loss 15.572848320007324, MAE = 78.65010070800781, MSE = 108.63579559326172
In step 1700, epoch 2, with loss 18.028844833374023, MAE = 181.9066925048828, MSE = 191.66195678710938
In step 1710, epoch 2, with loss 10.630827903747559, MAE = 133.05419921875, MSE = 146.50137329101562
In step 1720, epoch 2, with loss 9.893348693847656, MAE = 49.962032318115234, MSE = 62.54106903076172
In step 1730, epoch 2, with loss 15.960474014282227, MAE = 67.79080963134766, MSE = 97.2018814086914
In step 1740, epoch 2, with loss 17.41958236694336, MAE = 87.4618148803711, MSE = 118.7352523803711
In step 1750, epoch 2, with loss 16.841785430908203, MAE = 65.38151550292969, MSE = 96.28248596191406
In step 1760, epoch 2, with loss 14.027763366699219, MAE = 43.195762634277344, MSE = 61.90345764160156
In step 1770, epoch 2, with loss 12.602350234985352, MAE = 62.635887145996094, MSE = 90.31967163085938
In step 1780, epoch 2, with loss 12.404674530029297, MAE = 97.55426025390625, MSE 

In step 2500, epoch 2, with loss 15.839640617370605, MAE = 61.2857666015625, MSE = 98.6010971069336
In step 2510, epoch 2, with loss 14.857681274414062, MAE = 84.94647216796875, MSE = 108.71542358398438
In step 2520, epoch 2, with loss 12.258134841918945, MAE = 50.52830123901367, MSE = 77.40666198730469
In step 2530, epoch 2, with loss 5.854182243347168, MAE = 39.202205657958984, MSE = 46.20677947998047
In step 2540, epoch 2, with loss 9.31149959564209, MAE = 63.75762176513672, MSE = 88.3708724975586
In step 2550, epoch 2, with loss 13.841473579406738, MAE = 59.72796630859375, MSE = 79.62007904052734
In step 2560, epoch 2, with loss 13.176974296569824, MAE = 54.236480712890625, MSE = 69.36402893066406
In step 2570, epoch 2, with loss 11.21806526184082, MAE = 35.817169189453125, MSE = 46.55872344970703
In step 2580, epoch 2, with loss 10.296380996704102, MAE = 66.88615417480469, MSE = 84.02214813232422
In step 2590, epoch 2, with loss 8.64130687713623, MAE = 52.2014045715332, MSE = 70.3

In step 3310, epoch 2, with loss 14.775070190429688, MAE = 56.954620361328125, MSE = 67.13542175292969
In step 3320, epoch 2, with loss 7.4390106201171875, MAE = 35.50847244262695, MSE = 47.83729934692383
In step 3330, epoch 2, with loss 10.808343887329102, MAE = 49.81432342529297, MSE = 67.05859375
In step 3340, epoch 2, with loss 11.34970760345459, MAE = 58.44376754760742, MSE = 74.38689422607422
In step 3350, epoch 2, with loss 13.652958869934082, MAE = 47.34597396850586, MSE = 65.77641296386719
In step 3360, epoch 2, with loss 10.369527816772461, MAE = 40.41164779663086, MSE = 70.88679504394531
In step 3370, epoch 2, with loss 9.080144882202148, MAE = 59.31148147583008, MSE = 79.60221862792969
In step 3380, epoch 2, with loss 12.698355674743652, MAE = 56.3835563659668, MSE = 82.15149688720703
In step 3390, epoch 2, with loss 7.424337863922119, MAE = 67.52195739746094, MSE = 74.09019470214844
In step 3400, epoch 2, with loss 10.002398490905762, MAE = 64.60459899902344, MSE = 86.7342

In step 120, epoch 3, with loss 12.865720748901367, MAE = 48.99142074584961, MSE = 68.5334701538086
In step 130, epoch 3, with loss 9.20300006866455, MAE = 48.58873748779297, MSE = 53.972190856933594
In step 140, epoch 3, with loss 9.128015518188477, MAE = 131.3324432373047, MSE = 134.4019012451172
In step 150, epoch 3, with loss 10.978878021240234, MAE = 49.529632568359375, MSE = 62.4994010925293
In step 160, epoch 3, with loss 12.660911560058594, MAE = 57.76203155517578, MSE = 91.924560546875
In step 170, epoch 3, with loss 10.472081184387207, MAE = 34.06911849975586, MSE = 43.82548141479492
In step 180, epoch 3, with loss 11.702310562133789, MAE = 78.16432189941406, MSE = 98.62311553955078
In step 190, epoch 3, with loss 6.991569519042969, MAE = 29.8900203704834, MSE = 41.99716567993164
In step 200, epoch 3, with loss 10.806561470031738, MAE = 43.44935607910156, MSE = 74.23375701904297
In step 210, epoch 3, with loss 11.686738967895508, MAE = 46.48448944091797, MSE = 66.612548828125

In step 940, epoch 3, with loss 16.466861724853516, MAE = 90.20048522949219, MSE = 116.81493377685547
In step 950, epoch 3, with loss 11.883268356323242, MAE = 57.43768310546875, MSE = 66.64131164550781
In step 960, epoch 3, with loss 10.785022735595703, MAE = 46.37262725830078, MSE = 57.01252746582031
In step 970, epoch 3, with loss 8.249350547790527, MAE = 75.27095794677734, MSE = 84.06720733642578
In step 980, epoch 3, with loss 9.316625595092773, MAE = 71.49861145019531, MSE = 83.26594543457031
In step 990, epoch 3, with loss 15.764719009399414, MAE = 100.47801208496094, MSE = 114.41063690185547
In step 1000, epoch 3, with loss 12.882989883422852, MAE = 66.05570220947266, MSE = 99.07705688476562
In step 1010, epoch 3, with loss 27.004602432250977, MAE = 86.85977172851562, MSE = 116.63685607910156
In step 1020, epoch 3, with loss 9.925111770629883, MAE = 45.06098175048828, MSE = 60.36542892456055
In step 1030, epoch 3, with loss 30.740066528320312, MAE = 543.950439453125, MSE = 563.

In step 1750, epoch 3, with loss 7.454123020172119, MAE = 142.0130157470703, MSE = 144.46168518066406
In step 1760, epoch 3, with loss 11.32589340209961, MAE = 54.654991149902344, MSE = 73.38976287841797
In step 1770, epoch 3, with loss 9.530555725097656, MAE = 121.0634994506836, MSE = 130.51486206054688
In step 1780, epoch 3, with loss 8.77691650390625, MAE = 32.157615661621094, MSE = 36.57905960083008
In step 1790, epoch 3, with loss 15.441566467285156, MAE = 59.964561462402344, MSE = 79.90861511230469
In step 1800, epoch 3, with loss 14.38269329071045, MAE = 68.70675659179688, MSE = 101.00526428222656
In step 1810, epoch 3, with loss 11.106576919555664, MAE = 35.381893157958984, MSE = 57.559608459472656
In step 1820, epoch 3, with loss 15.964869499206543, MAE = 46.54377365112305, MSE = 70.69246673583984
In step 1830, epoch 3, with loss 8.60432243347168, MAE = 43.6088981628418, MSE = 67.33252716064453
In step 1840, epoch 3, with loss 9.66582202911377, MAE = 53.1046257019043, MSE = 58

In step 2560, epoch 3, with loss 12.04675006866455, MAE = 82.1168441772461, MSE = 102.50056457519531
In step 2570, epoch 3, with loss 8.044631958007812, MAE = 41.88928985595703, MSE = 63.320556640625
In step 2580, epoch 3, with loss 8.85676383972168, MAE = 41.95064163208008, MSE = 65.18589782714844
In step 2590, epoch 3, with loss 7.473306179046631, MAE = 29.890201568603516, MSE = 43.49530029296875
In step 2600, epoch 3, with loss 9.782003402709961, MAE = 37.2418327331543, MSE = 46.12418746948242
In step 2610, epoch 3, with loss 15.143301010131836, MAE = 47.29994201660156, MSE = 53.06658172607422
In step 2620, epoch 3, with loss 14.979101181030273, MAE = 39.23833465576172, MSE = 61.504791259765625
In step 2630, epoch 3, with loss 11.300616264343262, MAE = 61.98565673828125, MSE = 84.0284194946289
In step 2640, epoch 3, with loss 10.194784164428711, MAE = 63.52460861206055, MSE = 82.207275390625
In step 2650, epoch 3, with loss 14.489730834960938, MAE = 56.137916564941406, MSE = 94.7917

In step 3370, epoch 3, with loss 11.179634094238281, MAE = 59.37595748901367, MSE = 75.33336639404297
In step 3380, epoch 3, with loss 14.299585342407227, MAE = 45.731910705566406, MSE = 53.344913482666016
In step 3390, epoch 3, with loss 7.417179107666016, MAE = 50.616539001464844, MSE = 59.94264221191406
In step 3400, epoch 3, with loss 11.388381958007812, MAE = 50.0119743347168, MSE = 77.84334564208984
In step 3410, epoch 3, with loss 19.354061126708984, MAE = 95.76853942871094, MSE = 122.87201690673828
In step 3420, epoch 3, with loss 11.371981620788574, MAE = 105.20501708984375, MSE = 110.9030990600586
In step 3430, epoch 3, with loss 13.41950511932373, MAE = 68.5135498046875, MSE = 86.1751937866211
In step 3440, epoch 3, with loss 10.581014633178711, MAE = 37.34366226196289, MSE = 48.00309753417969
In step 3450, epoch 3, with loss 13.856321334838867, MAE = 50.103885650634766, MSE = 70.77238464355469
In step 3460, epoch 3, with loss 9.050437927246094, MAE = 32.28844451904297, MSE 

In step 190, epoch 4, with loss 10.37904167175293, MAE = 82.67131805419922, MSE = 92.07493591308594
In step 200, epoch 4, with loss 8.019418716430664, MAE = 36.900123596191406, MSE = 42.27691650390625
In step 210, epoch 4, with loss 9.955469131469727, MAE = 35.25199508666992, MSE = 47.078983306884766
In step 220, epoch 4, with loss 11.84699821472168, MAE = 40.39348602294922, MSE = 64.57630157470703
In step 230, epoch 4, with loss 7.671513557434082, MAE = 33.586692810058594, MSE = 49.72355270385742
In step 240, epoch 4, with loss 11.062685012817383, MAE = 56.49226760864258, MSE = 62.78672409057617
In step 250, epoch 4, with loss 8.206181526184082, MAE = 35.16775894165039, MSE = 40.12656021118164
In step 260, epoch 4, with loss 7.431092262268066, MAE = 20.834392547607422, MSE = 24.786197662353516
In step 270, epoch 4, with loss 6.0901007652282715, MAE = 33.78396224975586, MSE = 42.957950592041016
In step 280, epoch 4, with loss 11.942405700683594, MAE = 70.52002716064453, MSE = 91.893241

In step 1010, epoch 4, with loss 9.518396377563477, MAE = 27.298954010009766, MSE = 33.02999496459961
In step 1020, epoch 4, with loss 12.655542373657227, MAE = 123.46159362792969, MSE = 138.89366149902344
In step 1030, epoch 4, with loss 8.67955207824707, MAE = 22.51114845275879, MSE = 27.49967384338379
In step 1040, epoch 4, with loss 14.426048278808594, MAE = 65.7197494506836, MSE = 74.68460845947266
In step 1050, epoch 4, with loss 11.418909072875977, MAE = 72.60379791259766, MSE = 97.80518341064453
In step 1060, epoch 4, with loss 13.694559097290039, MAE = 62.8116340637207, MSE = 94.0273666381836
In step 1070, epoch 4, with loss 18.298175811767578, MAE = 118.4963607788086, MSE = 155.76182556152344
In step 1080, epoch 4, with loss 14.815272331237793, MAE = 66.41374969482422, MSE = 106.46028900146484
In step 1090, epoch 4, with loss 8.628641128540039, MAE = 46.199859619140625, MSE = 51.510040283203125
In step 1100, epoch 4, with loss 10.739545822143555, MAE = 45.8194465637207, MSE =

In step 1820, epoch 4, with loss 9.465348243713379, MAE = 45.366390228271484, MSE = 63.951595306396484
In step 1830, epoch 4, with loss 10.836359977722168, MAE = 61.52006149291992, MSE = 89.19589233398438
In step 1840, epoch 4, with loss 9.073712348937988, MAE = 130.51576232910156, MSE = 134.6241455078125
In step 1850, epoch 4, with loss 6.3992743492126465, MAE = 73.05350494384766, MSE = 84.45596313476562
In step 1860, epoch 4, with loss 11.65609359741211, MAE = 89.08343505859375, MSE = 105.44483184814453
In step 1870, epoch 4, with loss 7.473379611968994, MAE = 50.94190979003906, MSE = 55.21915054321289
In step 1880, epoch 4, with loss 11.795356750488281, MAE = 52.3555908203125, MSE = 69.72150421142578
In step 1890, epoch 4, with loss 8.335647583007812, MAE = 51.3351936340332, MSE = 62.553855895996094
In step 1900, epoch 4, with loss 16.90988540649414, MAE = 120.28226470947266, MSE = 136.76951599121094
In step 1910, epoch 4, with loss 11.513287544250488, MAE = 48.97899627685547, MSE =

In step 2630, epoch 4, with loss 15.210779190063477, MAE = 73.49771881103516, MSE = 101.09644317626953
In step 2640, epoch 4, with loss 12.861668586730957, MAE = 61.29426193237305, MSE = 69.25818634033203
In step 2650, epoch 4, with loss 8.767781257629395, MAE = 83.69996643066406, MSE = 95.41471099853516
In step 2660, epoch 4, with loss 9.219704627990723, MAE = 90.89971923828125, MSE = 107.49818420410156
In step 2670, epoch 4, with loss 13.822669982910156, MAE = 49.26057434082031, MSE = 68.52479553222656
In step 2680, epoch 4, with loss 17.776222229003906, MAE = 53.40528106689453, MSE = 82.72529602050781
In step 2690, epoch 4, with loss 6.921104431152344, MAE = 31.218997955322266, MSE = 35.733802795410156
In step 2700, epoch 4, with loss 8.130603790283203, MAE = 89.5127182006836, MSE = 99.89561462402344
In step 2710, epoch 4, with loss 8.954548835754395, MAE = 73.63652038574219, MSE = 80.93936157226562
In step 2720, epoch 4, with loss 12.11190128326416, MAE = 37.19504165649414, MSE = 5

In step 3440, epoch 4, with loss 10.045292854309082, MAE = 42.5037956237793, MSE = 53.157135009765625
In step 3450, epoch 4, with loss 7.087973117828369, MAE = 32.179237365722656, MSE = 44.08381271362305
In step 3460, epoch 4, with loss 8.008474349975586, MAE = 47.18602752685547, MSE = 64.24676513671875
In step 3470, epoch 4, with loss 9.462244033813477, MAE = 45.17023468017578, MSE = 59.609066009521484
In step 3480, epoch 4, with loss 10.220154762268066, MAE = 47.455718994140625, MSE = 68.39163970947266
In step 3490, epoch 4, with loss 16.337772369384766, MAE = 142.13275146484375, MSE = 149.82574462890625
In step 3500, epoch 4, with loss 5.078225135803223, MAE = 69.03387451171875, MSE = 75.08219909667969
In step 3510, epoch 4, with loss 8.73658561706543, MAE = 39.527740478515625, MSE = 65.53888702392578
In step 3520, epoch 4, with loss 9.567251205444336, MAE = 54.22053909301758, MSE = 79.86062622070312
In step 3530, epoch 4, with loss 10.253475189208984, MAE = 47.07114028930664, MSE =

In step 250, epoch 5, with loss 11.585526466369629, MAE = 53.8515510559082, MSE = 88.64086151123047
In step 260, epoch 5, with loss 10.927385330200195, MAE = 41.61089324951172, MSE = 50.366451263427734
In step 270, epoch 5, with loss 8.248745918273926, MAE = 48.441009521484375, MSE = 60.94964599609375
In step 280, epoch 5, with loss 8.461115837097168, MAE = 52.27644729614258, MSE = 73.43861389160156
In step 290, epoch 5, with loss 11.113689422607422, MAE = 172.881103515625, MSE = 179.77938842773438
In step 300, epoch 5, with loss 7.871546268463135, MAE = 24.61240005493164, MSE = 30.094127655029297
In step 310, epoch 5, with loss 9.619756698608398, MAE = 70.06135559082031, MSE = 78.47663116455078
In step 320, epoch 5, with loss 8.300500869750977, MAE = 31.764896392822266, MSE = 38.279327392578125
In step 330, epoch 5, with loss 14.578413009643555, MAE = 39.01966857910156, MSE = 56.02253341674805
In step 340, epoch 5, with loss 6.901281833648682, MAE = 53.15751266479492, MSE = 57.1691856

In step 1070, epoch 5, with loss 8.390893936157227, MAE = 30.100515365600586, MSE = 44.20340347290039
In step 1080, epoch 5, with loss 10.483372688293457, MAE = 80.83551788330078, MSE = 87.72526550292969
In step 1090, epoch 5, with loss 9.874127388000488, MAE = 69.00274658203125, MSE = 76.6238784790039
In step 1100, epoch 5, with loss 8.850740432739258, MAE = 31.951345443725586, MSE = 48.2629508972168
In step 1110, epoch 5, with loss 10.035210609436035, MAE = 43.209930419921875, MSE = 54.201725006103516
In step 1120, epoch 5, with loss 10.276289939880371, MAE = 28.898967742919922, MSE = 36.595367431640625
In step 1130, epoch 5, with loss 9.04599666595459, MAE = 44.80168533325195, MSE = 59.334617614746094
In step 1140, epoch 5, with loss 8.812047958374023, MAE = 71.44076538085938, MSE = 77.86101531982422
In step 1150, epoch 5, with loss 21.52139663696289, MAE = 135.96261596679688, MSE = 171.49261474609375
In step 1160, epoch 5, with loss 10.724023818969727, MAE = 47.71892547607422, MSE 

In step 1880, epoch 5, with loss 9.298282623291016, MAE = 50.57698059082031, MSE = 60.14552307128906
In step 1890, epoch 5, with loss 13.174324035644531, MAE = 83.98974609375, MSE = 102.02459716796875
In step 1900, epoch 5, with loss 10.899518966674805, MAE = 46.88393020629883, MSE = 67.22198486328125
In step 1910, epoch 5, with loss 7.259839057922363, MAE = 16.492382049560547, MSE = 22.59957504272461
In step 1920, epoch 5, with loss 11.373899459838867, MAE = 30.8070011138916, MSE = 43.987876892089844
In step 1930, epoch 5, with loss 6.615933895111084, MAE = 65.18904876708984, MSE = 71.88631439208984
In step 1940, epoch 5, with loss 6.6581130027771, MAE = 55.7421989440918, MSE = 62.171653747558594
In step 1950, epoch 5, with loss 12.466097831726074, MAE = 57.730918884277344, MSE = 72.73159790039062
In step 1960, epoch 5, with loss 8.044351577758789, MAE = 38.00480270385742, MSE = 46.67412567138672
In step 1970, epoch 5, with loss 9.681386947631836, MAE = 26.618968963623047, MSE = 35.68

In step 2690, epoch 5, with loss 5.649073600769043, MAE = 31.582204818725586, MSE = 44.647178649902344
In step 2700, epoch 5, with loss 10.527705192565918, MAE = 48.0108528137207, MSE = 71.98912048339844
In step 2710, epoch 5, with loss 8.912301063537598, MAE = 36.98262023925781, MSE = 55.73017883300781
In step 2720, epoch 5, with loss 9.161958694458008, MAE = 47.64812088012695, MSE = 66.237548828125
In step 2730, epoch 5, with loss 7.694025993347168, MAE = 43.91191482543945, MSE = 56.61378860473633
In step 2740, epoch 5, with loss 8.086564064025879, MAE = 44.467308044433594, MSE = 49.218017578125
In step 2750, epoch 5, with loss 10.001581192016602, MAE = 56.22740936279297, MSE = 68.57828521728516
In step 2760, epoch 5, with loss 5.896082401275635, MAE = 20.190509796142578, MSE = 32.8674201965332
In step 2770, epoch 5, with loss 8.123724937438965, MAE = 72.20985412597656, MSE = 78.23033905029297
In step 2780, epoch 5, with loss 5.883731365203857, MAE = 72.33744812011719, MSE = 74.74196

In step 3500, epoch 5, with loss 6.577510833740234, MAE = 28.62264633178711, MSE = 34.887630462646484
In step 3510, epoch 5, with loss 10.95762825012207, MAE = 39.75525665283203, MSE = 61.76693344116211
In step 3520, epoch 5, with loss 5.737114906311035, MAE = 17.73263168334961, MSE = 19.34369468688965
In step 3530, epoch 5, with loss 7.154791355133057, MAE = 57.251487731933594, MSE = 61.324241638183594
In step 3540, epoch 5, with loss 5.598930358886719, MAE = 41.42318344116211, MSE = 45.98674392700195
In step 3550, epoch 5, with loss 8.576021194458008, MAE = 25.649005889892578, MSE = 32.60110092163086
In step 3560, epoch 5, with loss 6.577439785003662, MAE = 73.43214416503906, MSE = 79.11656188964844
In step 3570, epoch 5, with loss 10.372610092163086, MAE = 31.029727935791016, MSE = 36.0967903137207
In step 3580, epoch 5, with loss 9.812483787536621, MAE = 38.723838806152344, MSE = 52.288612365722656
In step 3590, epoch 5, with loss 10.832620620727539, MAE = 34.91267395019531, MSE = 

In step 320, epoch 6, with loss 14.877542495727539, MAE = 48.68649673461914, MSE = 62.56789016723633
In step 330, epoch 6, with loss 10.548837661743164, MAE = 37.92462921142578, MSE = 67.28529357910156
In step 340, epoch 6, with loss 11.399613380432129, MAE = 60.21272659301758, MSE = 83.24295043945312
In step 350, epoch 6, with loss 21.954774856567383, MAE = 84.07759094238281, MSE = 139.1324462890625
In step 360, epoch 6, with loss 12.989649772644043, MAE = 63.81184768676758, MSE = 88.74979400634766
In step 370, epoch 6, with loss 17.235509872436523, MAE = 130.40687561035156, MSE = 137.47857666015625
In step 380, epoch 6, with loss 8.807883262634277, MAE = 81.09452056884766, MSE = 88.22238159179688
In step 390, epoch 6, with loss 16.85511016845703, MAE = 78.23136901855469, MSE = 86.80943298339844
In step 400, epoch 6, with loss 13.611615180969238, MAE = 105.491943359375, MSE = 114.84552764892578
In step 410, epoch 6, with loss 6.145216941833496, MAE = 16.077342987060547, MSE = 20.93749

In step 1140, epoch 6, with loss 12.7203950881958, MAE = 53.06473922729492, MSE = 81.42864990234375
In step 1150, epoch 6, with loss 7.817763328552246, MAE = 41.79841995239258, MSE = 59.02577209472656
In step 1160, epoch 6, with loss 9.105463027954102, MAE = 49.758541107177734, MSE = 63.971221923828125
In step 1170, epoch 6, with loss 10.673140525817871, MAE = 90.16326141357422, MSE = 95.11373901367188
In step 1180, epoch 6, with loss 10.266361236572266, MAE = 88.55208587646484, MSE = 100.5421371459961
In step 1190, epoch 6, with loss 7.112758636474609, MAE = 22.936767578125, MSE = 35.640506744384766
In step 1200, epoch 6, with loss 12.561586380004883, MAE = 46.811500549316406, MSE = 61.4371337890625
In step 1210, epoch 6, with loss 7.185776710510254, MAE = 40.41803741455078, MSE = 43.28352737426758
In step 1220, epoch 6, with loss 12.328699111938477, MAE = 36.68035125732422, MSE = 49.49480438232422
In step 1230, epoch 6, with loss 8.035405158996582, MAE = 81.62349700927734, MSE = 87.1

In step 1950, epoch 6, with loss 7.361525535583496, MAE = 33.511539459228516, MSE = 48.17201232910156
In step 1960, epoch 6, with loss 11.687917709350586, MAE = 62.965599060058594, MSE = 67.68440246582031
In step 1970, epoch 6, with loss 7.47598934173584, MAE = 22.812204360961914, MSE = 26.86837387084961
In step 1980, epoch 6, with loss 12.490801811218262, MAE = 42.623695373535156, MSE = 78.21170043945312
In step 1990, epoch 6, with loss 7.824715614318848, MAE = 18.854175567626953, MSE = 28.686105728149414
In step 2000, epoch 6, with loss 10.778882026672363, MAE = 42.31011199951172, MSE = 61.706119537353516
In step 2010, epoch 6, with loss 8.67817497253418, MAE = 33.293216705322266, MSE = 38.326942443847656
In step 2020, epoch 6, with loss 7.0132269859313965, MAE = 47.85546112060547, MSE = 51.07352828979492
In step 2030, epoch 6, with loss 10.513279914855957, MAE = 36.51232147216797, MSE = 47.378990173339844
In step 2040, epoch 6, with loss 8.214883804321289, MAE = 56.18621063232422, M

In step 2760, epoch 6, with loss 4.779693603515625, MAE = 28.876140594482422, MSE = 33.49728775024414
In step 2770, epoch 6, with loss 11.3673095703125, MAE = 122.3712158203125, MSE = 125.52719116210938
In step 2780, epoch 6, with loss 6.823849678039551, MAE = 77.96919250488281, MSE = 80.4233627319336
In step 2790, epoch 6, with loss 7.443669319152832, MAE = 32.509037017822266, MSE = 46.39541244506836
In step 2800, epoch 6, with loss 8.878043174743652, MAE = 31.609928131103516, MSE = 45.96533966064453
In step 2810, epoch 6, with loss 6.16872501373291, MAE = 20.945735931396484, MSE = 25.628503799438477
In step 2820, epoch 6, with loss 12.60114860534668, MAE = 50.577877044677734, MSE = 70.0048599243164
In step 2830, epoch 6, with loss 8.825587272644043, MAE = 33.975013732910156, MSE = 35.814884185791016
In step 2840, epoch 6, with loss 7.429758548736572, MAE = 26.81563377380371, MSE = 31.362215042114258
In step 2850, epoch 6, with loss 8.148340225219727, MAE = 53.392242431640625, MSE = 6

In step 3570, epoch 6, with loss 11.764787673950195, MAE = 39.493438720703125, MSE = 58.7606315612793
In step 3580, epoch 6, with loss 10.239937782287598, MAE = 54.5096321105957, MSE = 71.4146728515625
In step 3590, epoch 6, with loss 6.207376956939697, MAE = 26.95685386657715, MSE = 37.52703094482422
In step 3600, epoch 6, with loss 10.370866775512695, MAE = 34.89337921142578, MSE = 61.96242904663086
In step 3610, epoch 6, with loss 7.144239902496338, MAE = 33.80731964111328, MSE = 37.78641128540039
In step 3620, epoch 6, with loss 6.924895286560059, MAE = 20.205711364746094, MSE = 25.17491340637207
In step 3630, epoch 6, with loss 8.132278442382812, MAE = 20.069133758544922, MSE = 29.18880271911621
In step 3640, epoch 6, with loss 8.73204231262207, MAE = 27.54276466369629, MSE = 35.71829605102539
In step 3650, epoch 6, with loss 6.499366760253906, MAE = 43.368282318115234, MSE = 49.65464401245117
In step 3660, epoch 6, with loss 10.193901062011719, MAE = 57.96392822265625, MSE = 73.2

In step 380, epoch 7, with loss 5.735886573791504, MAE = 47.43824768066406, MSE = 51.370662689208984
In step 390, epoch 7, with loss 5.371476173400879, MAE = 20.5054988861084, MSE = 26.28899574279785
In step 400, epoch 7, with loss 6.748950004577637, MAE = 36.60907745361328, MSE = 48.556915283203125
In step 410, epoch 7, with loss 6.068674564361572, MAE = 42.20014953613281, MSE = 47.96805953979492
In step 420, epoch 7, with loss 6.439833164215088, MAE = 21.067584991455078, MSE = 38.97791290283203
In step 430, epoch 7, with loss 7.57867956161499, MAE = 64.25265502929688, MSE = 71.07337951660156
In step 440, epoch 7, with loss 5.9650492668151855, MAE = 61.68120574951172, MSE = 66.36399841308594
In step 450, epoch 7, with loss 5.134411811828613, MAE = 21.24610710144043, MSE = 25.23804473876953
In step 460, epoch 7, with loss 6.923887729644775, MAE = 23.20590591430664, MSE = 33.700008392333984
In step 470, epoch 7, with loss 5.440001487731934, MAE = 19.525943756103516, MSE = 27.05724334716

In step 1200, epoch 7, with loss 9.645472526550293, MAE = 30.257726669311523, MSE = 48.32984161376953
In step 1210, epoch 7, with loss 7.067492485046387, MAE = 23.14047622680664, MSE = 29.614826202392578
In step 1220, epoch 7, with loss 7.730578422546387, MAE = 75.95552062988281, MSE = 80.39163970947266
In step 1230, epoch 7, with loss 13.150564193725586, MAE = 78.41571044921875, MSE = 97.8713607788086
In step 1240, epoch 7, with loss 4.570278644561768, MAE = 40.26225280761719, MSE = 43.87152862548828
In step 1250, epoch 7, with loss 9.653614044189453, MAE = 27.42019271850586, MSE = 48.1459846496582
In step 1260, epoch 7, with loss 6.696630001068115, MAE = 33.85710144042969, MSE = 37.78769302368164
In step 1270, epoch 7, with loss 8.808351516723633, MAE = 130.532470703125, MSE = 136.0768585205078
In step 1280, epoch 7, with loss 6.295652866363525, MAE = 71.74532318115234, MSE = 73.25628662109375
In step 1290, epoch 7, with loss 9.127031326293945, MAE = 32.85719299316406, MSE = 48.27085

In step 2010, epoch 7, with loss 9.169739723205566, MAE = 29.648366928100586, MSE = 41.30279541015625
In step 2020, epoch 7, with loss 8.826131820678711, MAE = 32.56475830078125, MSE = 53.824188232421875
In step 2030, epoch 7, with loss 6.350044250488281, MAE = 22.4005184173584, MSE = 30.849008560180664
In step 2040, epoch 7, with loss 7.701112270355225, MAE = 35.843666076660156, MSE = 51.682125091552734
In step 2050, epoch 7, with loss 6.242371559143066, MAE = 21.35474395751953, MSE = 26.138715744018555
In step 2060, epoch 7, with loss 5.237995624542236, MAE = 26.472925186157227, MSE = 31.816904067993164
In step 2070, epoch 7, with loss 9.344375610351562, MAE = 101.77588653564453, MSE = 107.62086486816406
In step 2080, epoch 7, with loss 11.862565040588379, MAE = 30.377843856811523, MSE = 44.16606903076172
In step 2090, epoch 7, with loss 5.145717620849609, MAE = 12.179546356201172, MSE = 17.25403594970703
In step 2100, epoch 7, with loss 7.028662204742432, MAE = 33.33783721923828, MS

In step 2820, epoch 7, with loss 8.994091033935547, MAE = 44.31819534301758, MSE = 79.5088119506836
In step 2830, epoch 7, with loss 2.785722017288208, MAE = 25.901355743408203, MSE = 28.481887817382812
In step 2840, epoch 7, with loss 9.92878532409668, MAE = 41.36835479736328, MSE = 58.571998596191406
In step 2850, epoch 7, with loss 5.146565914154053, MAE = 28.83316421508789, MSE = 38.802852630615234
In step 2860, epoch 7, with loss 9.753828048706055, MAE = 86.74856567382812, MSE = 94.98493194580078
In step 2870, epoch 7, with loss 11.175431251525879, MAE = 98.15956115722656, MSE = 110.08429718017578
In step 2880, epoch 7, with loss 8.516960144042969, MAE = 40.839088439941406, MSE = 60.98095703125
In step 2890, epoch 7, with loss 13.87951374053955, MAE = 57.38984298706055, MSE = 97.32494354248047
In step 2900, epoch 7, with loss 10.475231170654297, MAE = 60.8182373046875, MSE = 79.24955749511719
In step 2910, epoch 7, with loss 7.039504051208496, MAE = 53.869110107421875, MSE = 61.61

In step 3630, epoch 7, with loss 8.544453620910645, MAE = 42.0706787109375, MSE = 49.83629608154297
In step 3640, epoch 7, with loss 10.948627471923828, MAE = 78.6176528930664, MSE = 90.90787506103516
In step 3650, epoch 7, with loss 7.030712127685547, MAE = 31.120840072631836, MSE = 38.54007339477539
In step 3660, epoch 7, with loss 9.667961120605469, MAE = 39.796058654785156, MSE = 57.06657791137695
In step 3670, epoch 7, with loss 16.572919845581055, MAE = 49.457000732421875, MSE = 60.68821334838867
In step 3680, epoch 7, with loss 17.953121185302734, MAE = 113.42182922363281, MSE = 151.8192596435547
In step 3690, epoch 7, with loss 6.024279594421387, MAE = 36.49825668334961, MSE = 39.527835845947266
In step 3700, epoch 7, with loss 8.852968215942383, MAE = 231.8308563232422, MSE = 234.56979370117188
In step 3710, epoch 7, with loss 12.758810043334961, MAE = 139.56492614746094, MSE = 147.4146728515625
In step 3720, epoch 7, with loss 11.840052604675293, MAE = 38.26115417480469, MSE 

In step 450, epoch 8, with loss 5.67299747467041, MAE = 19.394454956054688, MSE = 27.359779357910156
In step 460, epoch 8, with loss 5.19304895401001, MAE = 23.786584854125977, MSE = 28.162128448486328
In step 470, epoch 8, with loss 6.558756351470947, MAE = 24.46737289428711, MSE = 38.26500701904297
In step 480, epoch 8, with loss 9.105114936828613, MAE = 44.86912536621094, MSE = 73.12356567382812
In step 490, epoch 8, with loss 8.478595733642578, MAE = 25.559053421020508, MSE = 34.47636795043945
In step 500, epoch 8, with loss 8.383687973022461, MAE = 21.322708129882812, MSE = 24.746997833251953
In step 510, epoch 8, with loss 8.883747100830078, MAE = 66.4826889038086, MSE = 74.25504302978516
In step 520, epoch 8, with loss 9.162805557250977, MAE = 112.17237854003906, MSE = 115.92578887939453
In step 530, epoch 8, with loss 10.483924865722656, MAE = 82.06230163574219, MSE = 97.29839324951172
In step 540, epoch 8, with loss 5.717984199523926, MAE = 16.183349609375, MSE = 26.7168464660

In step 1270, epoch 8, with loss 6.544881343841553, MAE = 51.406715393066406, MSE = 57.482723236083984
In step 1280, epoch 8, with loss 6.7742462158203125, MAE = 19.291284561157227, MSE = 27.790714263916016
In step 1290, epoch 8, with loss 7.7052001953125, MAE = 54.89772415161133, MSE = 58.44612121582031
In step 1300, epoch 8, with loss 7.850285530090332, MAE = 24.330514907836914, MSE = 33.75797653198242
In step 1310, epoch 8, with loss 6.551308631896973, MAE = 21.654634475708008, MSE = 28.20132064819336
In step 1320, epoch 8, with loss 6.1133928298950195, MAE = 17.172603607177734, MSE = 24.87477684020996
In step 1330, epoch 8, with loss 10.069607734680176, MAE = 22.396976470947266, MSE = 27.17531967163086
In step 1340, epoch 8, with loss 9.54173469543457, MAE = 30.265213012695312, MSE = 44.81431579589844
In step 1350, epoch 8, with loss 10.573694229125977, MAE = 36.67170715332031, MSE = 44.3609619140625
In step 1360, epoch 8, with loss 6.153468608856201, MAE = 18.13481330871582, MSE =

In step 2080, epoch 8, with loss 11.347166061401367, MAE = 21.844585418701172, MSE = 36.95124435424805
In step 2090, epoch 8, with loss 7.1895904541015625, MAE = 18.27601432800293, MSE = 23.43660545349121
In step 2100, epoch 8, with loss 7.342504978179932, MAE = 18.774559020996094, MSE = 21.379932403564453
In step 2110, epoch 8, with loss 10.22486686706543, MAE = 50.50400924682617, MSE = 57.96192932128906
In step 2120, epoch 8, with loss 9.426228523254395, MAE = 39.0152587890625, MSE = 52.28374481201172
In step 2130, epoch 8, with loss 5.626099586486816, MAE = 21.647953033447266, MSE = 29.000837326049805
In step 2140, epoch 8, with loss 8.190834045410156, MAE = 27.09454345703125, MSE = 39.53414535522461
In step 2150, epoch 8, with loss 8.817405700683594, MAE = 21.361125946044922, MSE = 26.58915138244629
In step 2160, epoch 8, with loss 7.903780460357666, MAE = 16.163734436035156, MSE = 19.780488967895508
In step 2170, epoch 8, with loss 6.310243606567383, MAE = 22.527626037597656, MSE 

In step 2890, epoch 8, with loss 8.980389595031738, MAE = 43.97999954223633, MSE = 49.09783935546875
In step 2900, epoch 8, with loss 5.848944664001465, MAE = 27.677989959716797, MSE = 32.91490936279297
In step 2910, epoch 8, with loss 4.9951324462890625, MAE = 13.547845840454102, MSE = 16.715301513671875
In step 2920, epoch 8, with loss 7.582684516906738, MAE = 15.456342697143555, MSE = 25.285236358642578
In step 2930, epoch 8, with loss 8.325878143310547, MAE = 36.521400451660156, MSE = 49.9738655090332
In step 2940, epoch 8, with loss 14.402026176452637, MAE = 95.23512268066406, MSE = 112.60706329345703
In step 2950, epoch 8, with loss 7.581561088562012, MAE = 35.718536376953125, MSE = 48.89945602416992
In step 2960, epoch 8, with loss 11.61211109161377, MAE = 41.49050521850586, MSE = 55.6982307434082
In step 2970, epoch 8, with loss 10.947200775146484, MAE = 122.96241760253906, MSE = 129.9219512939453
In step 2980, epoch 8, with loss 8.987776756286621, MAE = 38.95005798339844, MSE 

In step 3700, epoch 8, with loss 6.648951053619385, MAE = 32.652313232421875, MSE = 35.151241302490234
In step 3710, epoch 8, with loss 5.0533246994018555, MAE = 17.134449005126953, MSE = 28.79950714111328
In step 3720, epoch 8, with loss 9.561243057250977, MAE = 36.389427185058594, MSE = 53.01535415649414
In step 3730, epoch 8, with loss 4.958138465881348, MAE = 30.483478546142578, MSE = 32.90446853637695
In step 3740, epoch 8, with loss 8.818951606750488, MAE = 23.99936294555664, MSE = 39.024078369140625
In step 3750, epoch 8, with loss 8.954676628112793, MAE = 30.859424591064453, MSE = 48.99160385131836
In step 3760, epoch 8, with loss 9.072789192199707, MAE = 36.55316925048828, MSE = 59.29472351074219
In step 3770, epoch 8, with loss 7.095428466796875, MAE = 30.058944702148438, MSE = 37.12400436401367
In step 3780, epoch 8, with loss 7.622443199157715, MAE = 21.517030715942383, MSE = 39.128692626953125
In step 3790, epoch 8, with loss 9.111077308654785, MAE = 63.909767150878906, MS

In step 510, epoch 9, with loss 9.229809761047363, MAE = 26.24587631225586, MSE = 40.80982971191406
In step 520, epoch 9, with loss 8.493852615356445, MAE = 23.708118438720703, MSE = 42.47865295410156
In step 530, epoch 9, with loss 9.814289093017578, MAE = 45.152626037597656, MSE = 57.375370025634766
In step 540, epoch 9, with loss 7.906069755554199, MAE = 29.24935531616211, MSE = 37.84724426269531
In step 550, epoch 9, with loss 7.474604606628418, MAE = 86.53754425048828, MSE = 92.2582015991211
In step 560, epoch 9, with loss 8.102230072021484, MAE = 40.781837463378906, MSE = 53.916263580322266
In step 570, epoch 9, with loss 8.047569274902344, MAE = 23.656158447265625, MSE = 31.03850555419922
In step 580, epoch 9, with loss 7.2661027908325195, MAE = 33.38925552368164, MSE = 39.15380859375
In step 590, epoch 9, with loss 10.13234806060791, MAE = 45.069732666015625, MSE = 54.437171936035156
In step 600, epoch 9, with loss 13.832735061645508, MAE = 47.775146484375, MSE = 86.12884521484

In step 1330, epoch 9, with loss 8.345486640930176, MAE = 28.078195571899414, MSE = 43.29718780517578
In step 1340, epoch 9, with loss 7.037271976470947, MAE = 23.128089904785156, MSE = 29.70230484008789
In step 1350, epoch 9, with loss 4.675936222076416, MAE = 8.45695686340332, MSE = 11.407790184020996
In step 1360, epoch 9, with loss 7.972740173339844, MAE = 20.869884490966797, MSE = 41.879207611083984
In step 1370, epoch 9, with loss 7.224862575531006, MAE = 34.94202423095703, MSE = 48.80376052856445
In step 1380, epoch 9, with loss 7.436440467834473, MAE = 58.64060592651367, MSE = 71.88040924072266
In step 1390, epoch 9, with loss 7.198511600494385, MAE = 71.2875747680664, MSE = 78.94938659667969
In step 1400, epoch 9, with loss 9.157630920410156, MAE = 36.818809509277344, MSE = 58.96591567993164
In step 1410, epoch 9, with loss 5.7794904708862305, MAE = 20.570274353027344, MSE = 30.70667839050293
In step 1420, epoch 9, with loss 7.052319526672363, MAE = 19.683671951293945, MSE = 2

In step 2140, epoch 9, with loss 7.033916473388672, MAE = 22.042421340942383, MSE = 30.240880966186523
In step 2150, epoch 9, with loss 6.168244361877441, MAE = 50.376861572265625, MSE = 56.74872589111328
In step 2160, epoch 9, with loss 8.884557723999023, MAE = 40.06194305419922, MSE = 54.33042907714844
In step 2170, epoch 9, with loss 8.475493431091309, MAE = 66.44285583496094, MSE = 75.83332061767578
In step 2180, epoch 9, with loss 7.411938667297363, MAE = 20.71309471130371, MSE = 27.66046905517578
In step 2190, epoch 9, with loss 7.341611385345459, MAE = 32.93360137939453, MSE = 39.315086364746094
In step 2200, epoch 9, with loss 6.665280342102051, MAE = 22.235149383544922, MSE = 27.124759674072266
In step 2210, epoch 9, with loss 6.489195823669434, MAE = 17.661113739013672, MSE = 25.059181213378906
In step 2220, epoch 9, with loss 6.993141174316406, MAE = 23.194866180419922, MSE = 29.520221710205078
In step 2230, epoch 9, with loss 6.609609127044678, MAE = 27.807018280029297, MSE

In step 2950, epoch 9, with loss 5.294398784637451, MAE = 46.93126678466797, MSE = 51.45530700683594
In step 2960, epoch 9, with loss 15.854202270507812, MAE = 67.33943176269531, MSE = 86.74382019042969
In step 2970, epoch 9, with loss 7.09640645980835, MAE = 57.039947509765625, MSE = 60.2845573425293
In step 2980, epoch 9, with loss 9.65339469909668, MAE = 70.97529602050781, MSE = 81.20819854736328
In step 2990, epoch 9, with loss 5.820860385894775, MAE = 39.320858001708984, MSE = 41.96913146972656
In step 3000, epoch 9, with loss 8.261487007141113, MAE = 25.93147850036621, MSE = 33.98160934448242
In step 3010, epoch 9, with loss 8.241479873657227, MAE = 24.955585479736328, MSE = 40.2740592956543
In step 3020, epoch 9, with loss 6.270589351654053, MAE = 22.989315032958984, MSE = 28.32823944091797
In step 3030, epoch 9, with loss 8.345521926879883, MAE = 35.93608474731445, MSE = 39.958065032958984
In step 3040, epoch 9, with loss 7.645684719085693, MAE = 34.901031494140625, MSE = 43.40

In step 3760, epoch 9, with loss 5.344038009643555, MAE = 21.12909698486328, MSE = 27.726831436157227
In step 3770, epoch 9, with loss 5.32896614074707, MAE = 17.98611068725586, MSE = 24.980867385864258
In step 3780, epoch 9, with loss 5.141870498657227, MAE = 39.39516830444336, MSE = 42.25033950805664
In step 3790, epoch 9, with loss 8.778999328613281, MAE = 35.41191101074219, MSE = 58.78614807128906
In step 3800, epoch 9, with loss 10.395002365112305, MAE = 147.33544921875, MSE = 153.23574829101562
In step 3810, epoch 9, with loss 8.1153564453125, MAE = 111.5645751953125, MSE = 116.03277587890625
In step 3820, epoch 9, with loss 10.33903694152832, MAE = 52.565284729003906, MSE = 55.40726852416992
In step 3830, epoch 9, with loss 5.479589462280273, MAE = 14.858563423156738, MSE = 19.741504669189453
In step 3840, epoch 9, with loss 7.3767876625061035, MAE = 28.387134552001953, MSE = 43.228214263916016
In step 3850, epoch 9, with loss 8.11279582977295, MAE = 14.611584663391113, MSE = 30

In step 570, epoch 10, with loss 5.593844890594482, MAE = 25.212921142578125, MSE = 33.69072723388672
In step 580, epoch 10, with loss 5.505239486694336, MAE = 14.487353324890137, MSE = 26.470897674560547
In step 590, epoch 10, with loss 6.248238563537598, MAE = 19.790485382080078, MSE = 30.627216339111328
In step 600, epoch 10, with loss 4.325000762939453, MAE = 35.981754302978516, MSE = 44.520774841308594
In step 610, epoch 10, with loss 5.552661895751953, MAE = 23.662378311157227, MSE = 36.55812072753906
In step 620, epoch 10, with loss 7.96309757232666, MAE = 62.01024627685547, MSE = 66.22618103027344
In step 630, epoch 10, with loss 7.2463274002075195, MAE = 86.73392486572266, MSE = 89.72321319580078
In step 640, epoch 10, with loss 7.948086738586426, MAE = 84.83710479736328, MSE = 89.8170394897461
In step 650, epoch 10, with loss 5.042662620544434, MAE = 18.083799362182617, MSE = 27.956022262573242
In step 660, epoch 10, with loss 7.000892639160156, MAE = 57.764793395996094, MSE 

In step 1380, epoch 10, with loss 6.8072710037231445, MAE = 19.827510833740234, MSE = 27.36899185180664
In step 1390, epoch 10, with loss 6.759226322174072, MAE = 23.305160522460938, MSE = 29.98597526550293
In step 1400, epoch 10, with loss 7.447314262390137, MAE = 34.1992301940918, MSE = 39.479637145996094
In step 1410, epoch 10, with loss 5.361700534820557, MAE = 25.945343017578125, MSE = 29.41046142578125
In step 1420, epoch 10, with loss 8.375663757324219, MAE = 83.25898742675781, MSE = 88.37796783447266
In step 1430, epoch 10, with loss 10.610294342041016, MAE = 69.71965026855469, MSE = 86.10528564453125
In step 1440, epoch 10, with loss 8.223291397094727, MAE = 33.79938888549805, MSE = 41.40420150756836
In step 1450, epoch 10, with loss 6.624229431152344, MAE = 44.346153259277344, MSE = 53.92533493041992
In step 1460, epoch 10, with loss 9.382485389709473, MAE = 34.04477310180664, MSE = 64.36954498291016
In step 1470, epoch 10, with loss 6.06172513961792, MAE = 22.809980392456055

In step 2180, epoch 10, with loss 8.380006790161133, MAE = 25.701711654663086, MSE = 39.08209228515625
In step 2190, epoch 10, with loss 4.838849067687988, MAE = 33.960506439208984, MSE = 35.33925247192383
In step 2200, epoch 10, with loss 8.679372787475586, MAE = 76.35223388671875, MSE = 90.00940704345703
In step 2210, epoch 10, with loss 10.324174880981445, MAE = 36.580322265625, MSE = 45.90366744995117
In step 2220, epoch 10, with loss 6.567008972167969, MAE = 45.400699615478516, MSE = 49.368377685546875
In step 2230, epoch 10, with loss 6.232418537139893, MAE = 19.206016540527344, MSE = 27.43120574951172
In step 2240, epoch 10, with loss 5.5318989753723145, MAE = 14.753007888793945, MSE = 35.37977981567383
In step 2250, epoch 10, with loss 4.098970413208008, MAE = 16.10744857788086, MSE = 18.840158462524414
In step 2260, epoch 10, with loss 7.328513145446777, MAE = 57.419700622558594, MSE = 64.71046447753906
In step 2270, epoch 10, with loss 7.171913146972656, MAE = 27.860408782958

In step 2980, epoch 10, with loss 7.391202449798584, MAE = 10.536023139953613, MSE = 15.690321922302246
In step 2990, epoch 10, with loss 4.011660575866699, MAE = 11.20662784576416, MSE = 17.88660430908203
In step 3000, epoch 10, with loss 4.222071647644043, MAE = 28.61284828186035, MSE = 32.025062561035156
In step 3010, epoch 10, with loss 6.137884140014648, MAE = 22.326486587524414, MSE = 23.919464111328125
In step 3020, epoch 10, with loss 5.624701023101807, MAE = 19.11991310119629, MSE = 25.844852447509766
In step 3030, epoch 10, with loss 5.811936855316162, MAE = 77.17790985107422, MSE = 80.21056365966797
In step 3040, epoch 10, with loss 9.658990859985352, MAE = 43.20991134643555, MSE = 49.43014907836914
In step 3050, epoch 10, with loss 7.682486534118652, MAE = 40.63908386230469, MSE = 44.459754943847656
In step 3060, epoch 10, with loss 8.864519119262695, MAE = 32.56089401245117, MSE = 43.20405578613281
In step 3070, epoch 10, with loss 5.305100440979004, MAE = 42.0652732849121

In step 3780, epoch 10, with loss 7.2401885986328125, MAE = 30.810649871826172, MSE = 39.782535552978516
In step 3790, epoch 10, with loss 6.978965759277344, MAE = 21.087493896484375, MSE = 33.74000549316406
In step 3800, epoch 10, with loss 8.379776000976562, MAE = 23.266504287719727, MSE = 38.06755828857422
In step 3810, epoch 10, with loss 7.914982795715332, MAE = 23.18767738342285, MSE = 45.89099884033203
In step 3820, epoch 10, with loss 6.526531219482422, MAE = 30.7208309173584, MSE = 43.01472854614258
In step 3830, epoch 10, with loss 7.3217363357543945, MAE = 31.337631225585938, MSE = 47.70077133178711
In step 3840, epoch 10, with loss 7.177504539489746, MAE = 34.026493072509766, MSE = 49.88718032836914
In step 3850, epoch 10, with loss 6.857278347015381, MAE = 30.077505111694336, MSE = 51.79711151123047
In step 3860, epoch 10, with loss 10.847009658813477, MAE = 34.76872253417969, MSE = 55.123558044433594
In step 3870, epoch 10, with loss 6.165469169616699, MAE = 16.0498695373

In step 590, epoch 11, with loss 4.342744827270508, MAE = 21.488277435302734, MSE = 29.317995071411133
In step 600, epoch 11, with loss 4.456053733825684, MAE = 24.44908905029297, MSE = 39.48023986816406
In step 610, epoch 11, with loss 6.061315059661865, MAE = 82.00901794433594, MSE = 84.82199096679688
In step 620, epoch 11, with loss 7.336467742919922, MAE = 60.129966735839844, MSE = 64.01183319091797
In step 630, epoch 11, with loss 8.219220161437988, MAE = 29.587738037109375, MSE = 37.4382209777832
In step 640, epoch 11, with loss 5.172281265258789, MAE = 15.97442626953125, MSE = 20.392471313476562
In step 650, epoch 11, with loss 4.0011677742004395, MAE = 13.037867546081543, MSE = 22.604549407958984
In step 660, epoch 11, with loss 5.386031150817871, MAE = 18.106273651123047, MSE = 27.193397521972656
In step 670, epoch 11, with loss 6.021515846252441, MAE = 17.012088775634766, MSE = 34.785545349121094
In step 680, epoch 11, with loss 7.413495063781738, MAE = 20.388723373413086, MS

In step 1400, epoch 11, with loss 12.969629287719727, MAE = 288.80926513671875, MSE = 296.1475830078125
In step 1410, epoch 11, with loss 7.090479850769043, MAE = 101.06685638427734, MSE = 105.5124740600586
In step 1420, epoch 11, with loss 4.6913628578186035, MAE = 26.914730072021484, MSE = 34.0699577331543
In step 1430, epoch 11, with loss 5.895406246185303, MAE = 41.464080810546875, MSE = 43.154605865478516
In step 1440, epoch 11, with loss 5.509099006652832, MAE = 85.22037506103516, MSE = 86.85616302490234
In step 1450, epoch 11, with loss 7.988458156585693, MAE = 33.35088348388672, MSE = 46.185546875
In step 1460, epoch 11, with loss 12.9091157913208, MAE = 84.7268295288086, MSE = 112.07185363769531
In step 1470, epoch 11, with loss 10.55951976776123, MAE = 68.26082611083984, MSE = 75.94402313232422
In step 1480, epoch 11, with loss 7.507855415344238, MAE = 49.09523010253906, MSE = 58.583831787109375
In step 1490, epoch 11, with loss 6.137695789337158, MAE = 21.753459930419922, MS

In step 2200, epoch 11, with loss 8.444589614868164, MAE = 23.205469131469727, MSE = 35.95711135864258
In step 2210, epoch 11, with loss 5.159206390380859, MAE = 22.05218505859375, MSE = 25.888633728027344
In step 2220, epoch 11, with loss 6.896190643310547, MAE = 23.261058807373047, MSE = 33.07322692871094
In step 2230, epoch 11, with loss 8.535059928894043, MAE = 39.3070182800293, MSE = 54.510986328125
In step 2240, epoch 11, with loss 10.937973976135254, MAE = 28.812511444091797, MSE = 47.47149658203125
In step 2250, epoch 11, with loss 4.173033714294434, MAE = 20.899919509887695, MSE = 23.957138061523438
In step 2260, epoch 11, with loss 6.441204071044922, MAE = 34.42756652832031, MSE = 38.863800048828125
In step 2270, epoch 11, with loss 5.023245811462402, MAE = 11.227087020874023, MSE = 14.831421852111816
In step 2280, epoch 11, with loss 7.746511936187744, MAE = 18.725244522094727, MSE = 25.870128631591797
In step 2290, epoch 11, with loss 7.85500431060791, MAE = 23.449028015136

In step 3000, epoch 11, with loss 10.023309707641602, MAE = 18.514347076416016, MSE = 28.564327239990234
In step 3010, epoch 11, with loss 3.453831195831299, MAE = 9.361475944519043, MSE = 13.37721061706543
In step 3020, epoch 11, with loss 4.536126613616943, MAE = 18.28737449645996, MSE = 22.956876754760742
In step 3030, epoch 11, with loss 8.860072135925293, MAE = 26.931156158447266, MSE = 37.00290298461914
In step 3040, epoch 11, with loss 6.6946611404418945, MAE = 14.536436080932617, MSE = 22.677104949951172
In step 3050, epoch 11, with loss 7.509315490722656, MAE = 22.52597427368164, MSE = 45.1041259765625
In step 3060, epoch 11, with loss 4.792199611663818, MAE = 10.712736129760742, MSE = 12.446718215942383
In step 3070, epoch 11, with loss 5.200579643249512, MAE = 11.136883735656738, MSE = 15.098340034484863
In step 3080, epoch 11, with loss 6.191040992736816, MAE = 31.928714752197266, MSE = 36.30195236206055
In step 3090, epoch 11, with loss 5.383894920349121, MAE = 52.32843017

In [23]:
# Scratch check: draw two small random (2, 3) tensors and display them.
# A dedicated, seeded generator makes the printed values reproducible under
# Restart & Run All without advancing the global torch RNG used for training.
# (The original trailing bare `print` was a no-op reference to the builtin and
# has been removed.)
_scratch_gen = torch.Generator().manual_seed(0)
a = torch.randn(2, 3, generator=_scratch_gen)
b = torch.randn(2, 3, generator=_scratch_gen)
print(a, b)

tensor([[ 0.4242,  0.7217,  1.1008],
        [-0.0442, -0.6967,  0.4862]]) tensor([[ 0.0191,  1.7691,  0.8167],
        [-1.4132, -0.5933,  0.9567]])
tensor([[0.0081, 1.2768, 0.8991],
        [0.0624, 0.4133, 0.4651]])
tensor([[0.1800, 0.5209, 1.2118],
        [0.0020, 0.4853, 0.2364]])
tensor(-3.2035)
tensor(-1.6017)


In [12]:
# Scratch arithmetic: print 0.4807 squared (computed as a plain product).
_side = 0.4807
print(_side * _side)
del _side

0.23107249000000002


In [3]:
# Print how many parameter tensors the model registers.
# This counts tensors (weights, biases, ...), not individual scalar weights.
print(sum(1 for _ in net.parameters()))

71


In [16]:
# Drill into the model's module tree to grab specific layers for inspection.
# NOTE(review): the hard-coded indices depend on the order in which DMENet
# registers its submodules; confirm they still point at the intended layers.
_top_children = list(net.children())
# The second top-level child must be iterable (presumably an nn.Sequential or
# nn.ModuleList; confirm). Take its first element.
a = list(_top_children[1])[0]
_a_children = list(a.children())
# Third direct child of a's first child.
b = list(_a_children[0].children())[2]
del _top_children, _a_children

In [8]:
# Demo: reorder a random (N, C, H, W) array into (N, H, W, C) layout.
x = np.random.rand(1, 3, 4, 5)
print(x, x.shape)
# A single transpose that moves the channel axis to the end is equivalent to
# swapping axes 1<->2 and then 2<->3: (N, C, H, W) -> (N, H, W, C).
x = x.transpose(0, 2, 3, 1)
print(x, x.shape)

[[[[0.56534958 0.24370027 0.72369011 0.67889216 0.04249182]
   [0.00842371 0.48885883 0.21877356 0.75751709 0.55534385]
   [0.21609105 0.38360997 0.69587199 0.57972236 0.37004643]
   [0.77530353 0.21747802 0.84950338 0.1185008  0.06807957]]

  [[0.00871294 0.36645846 0.45530013 0.104406   0.56377999]
   [0.48795602 0.76807449 0.06316015 0.01205147 0.9723964 ]
   [0.23015443 0.01537153 0.98212864 0.04200032 0.20598123]
   [0.76175589 0.60508917 0.8725611  0.8562562  0.02366146]]

  [[0.6246715  0.16085285 0.43329321 0.18251999 0.04658025]
   [0.83781238 0.86846604 0.46351249 0.16919308 0.50824695]
   [0.35417069 0.84248133 0.80769374 0.41581033 0.23163003]
   [0.81354884 0.64415368 0.27507783 0.01011959 0.18384952]]]] (1, 3, 4, 5)
[[[[0.56534958 0.00871294 0.6246715 ]
   [0.24370027 0.36645846 0.16085285]
   [0.72369011 0.45530013 0.43329321]
   [0.67889216 0.104406   0.18251999]
   [0.04249182 0.56377999 0.04658025]]

  [[0.00842371 0.48795602 0.83781238]
   [0.48885883 0.76807449 0.86