In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import random

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Data, Batch
from tqdm import tqdm
import torch.nn.functional as F


In [3]:
if torch.backends.mps.is_available():
    device = torch.device('mps')
    print("Apple GPU")
elif torch.cuda.is_available():
    device = torch.device('cuda')
    print("CUDA GPU")
else:
    device = torch.device('cpu')

Apple GPU


In [4]:
def getData(path):
    train_file = np.load(path+"/train.npz")
    train_data = train_file['data']
    test_file = np.load(path+"/test_input.npz")
    test_data = test_file['data']
    print(f"Training Data's shape is {train_data.shape} and Test Data's is {test_data.shape}")
    return train_data, test_data
trainData, testData = getData("./data/")

Training Data's shape is (10000, 50, 110, 6) and Test Data's is (2100, 50, 50, 6)


In [5]:
class WindowedNormalizedDataset(Dataset):
    def __init__(self, data):
        self.data = data
        
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        scene = self.data[idx]
        X = scene[:, :50, :].copy()
        Y = scene[0, 50:, :2].copy()
        origin = X[0, 49, :2].copy()
        return torch.tensor(X, dtype=torch.float32), torch.tensor(Y, dtype=torch.float32), torch.tensor(origin, dtype=torch.float32)

In [6]:
trainData[0,0,0,:]

array([ 338.59322192, -672.21574762,   -5.32538052,    1.61518358,
          2.84662927,    0.        ])

In [None]:
class EncoderDecoderModel(nn.Module):
    def __init__(self, infeatures = 6, outfeatures = 2):
        super().__init__()
        self.input_upsample = nn.ConvTranspose1d(
            in_channels=infeatures,
            out_channels=infeatures,
            kernel_size=13,
            stride=1,
            padding=1,
            output_padding=0
        )
        self.layer1 = nn.Linear(in_features = infeatures, out_features = 16)
        self.layer2 = nn.Linear(in_features = 16, out_features = 32)
        self.layer3 = nn.Linear(in_features = 32, out_features = 64)
        self.encoderlstm = nn.LSTM(input_size = 64, hidden_size = 128, num_layers = 2, batch_first = True, dropout = 0.3)
        self.relu = nn.ReLU()
        self.decoderlstm = nn.LSTM(input_size = 128, hidden_size = 64, num_layers = 2, batch_first = True, dropout = 0.3)
        self.layer10 = nn.Linear(in_features = 64, out_features = 32)
        self.layer11 = nn.Linear(in_features = 32, out_features = 16)
        self.layer12 = nn.Linear(in_features = 16, out_features = outfeatures)

        # layer 1 and layer 12 skip
        # layer 2 and layer 11 skip
        # layer 3 and layer 10 skip
        # encoder lstm to decoder lstm skip
        self.skip1 = nn.Linear(in_features = 16, out_features = 16)
        self.skip2 = nn.Linear(in_features = 32, out_features = 32)
        self.skip3 = nn.Linear(in_features = 64, out_features = 64)
        self.skip4 = nn.Linear(in_features = 128, out_features = 128)
        

    def forward(self, x):
        # x = data.x
        # print(x.shape)
        batch_size, channels, height, width = x.shape
        x = x[:, 0, :, :]
        # x = F.interpolate(x.permute(0, 2, 1), size=60, mode='linear').permute(0, 2, 1)  # [128, 60, 6]
        # batch_size, channels, height, width = x.shape
        # x = x[:, 0, :, :]  # Shape: (batch, 50, 6)
        
        # Learnable upsampling from 50 to 60 time steps
        x = x.permute(0, 2, 1)  # (batch, 6, 50)
        x = self.input_upsample(x)  # (batch, 6, 60)
        x = x.permute(0, 2, 1)  # (batch, 60, 6)


        out1 = self.layer1(x)
        out1 = nn.ReLU()(out1)
        
        out2 = self.layer2(out1)
        out2 = nn.ReLU()(out2)
        
        out3 = self.layer3(out2)
        out3 = nn.ReLU()(out3)
        # print("out 3: ", out3.shape)
        
        out4, _ = self.encoderlstm(out3)
        # print("lstm Last: ", out4[:, -1, :].shape)
        # print("Out4 View : ", out4.view(batch_size, 60, batch_size))
        # print("out 4: ", out4.shape)

         
        out5, _ = self.decoderlstm(out4)
        # print("out 5: ", out5.shape)
        
        # print(out3_reduced.shape, out5.shape)
        mlpskip1 = out3 + self.skip3(out5)
        # print("MLP Skip 1: ", mlpskip1.shape)

        out6 = self.layer10(mlpskip1)
        out6 = nn.ReLU()(out6)
        # print("Out6: ", out6.shape)

        # out2_pooled = F.adaptive_avg_pool2d(out2.permute(0, 3, 1, 2), (60, 1))  # [128, 64, 20, 1]
        # out2_reduced = out2_pooled.squeeze(-1).permute(0, 2, 1)  # [128, 20, 64]

        mlpskip2 = out2 + self.skip2(out6)
        # print("MLP Skip 2: ", mlpskip2.shape)
        out7 = self.layer11(mlpskip2)
        out7 = nn.ReLU()(out7)
        # print("Out 7: ", out7.shape)

        
        mlpskip3 = out1 + self.skip1(out7)
        # print("MLP Skip 3: ",mlpskip3.shape)
        out8 = self.layer12(mlpskip3)
        # print("Out 8: ", out7.shape)
        # out9 = out.view(-1, 60, 2)
        # print("Out 9: ", out9.shape)

        return out8

# model = EncoderDecoderModel(6, 2)
model = EncoderDecoderModel()
# model.apply(xavier_init_weights)

test = torch.randn(128, 50, 50, 6)
out = model(test)
print(out.shape)
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params}")

torch.Size([128, 60, 2])
Total parameters: 342252


In [None]:
class SimpleModel(nn.Module):
    def __init__(self, input_dim=6, hidden_dim=128, output_dim=60 * 2):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        # self.fc2 = nn.Linear(hidden_dim, 2*hidden_dim)
        self.lstm1 = nn.LSTM(hidden_dim, hidden_dim, num_layers = 2, batch_first=True)
        # self.lstm2 = nn.LSTM(2*hidden_dim, 2*hidden_dim, num_layers = 2, batch_first=True)
        # self.fc3 = nn.Linear(2*hidden_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, output_dim)
        

    def forward(self, x):
        # x = data.x
        # print(x.shape)
        x= x.reshape(-1, 50, 50, 6)  # (batch_size, num_agents, seq_len, input_dim)
        x = x[:, 0, :, :] # Only Consider ego agent index 0

        x = nn.ReLU()(self.fc1(x))
        # x = nn.ReLU()(self.fc2(x))
        x, _ = self.lstm1(x)
        # x, _ = self.lstm2(x)
        # x = nn.ReLU()(self.fc3(x))
        
        # print("LSTM Out: ", lstm_out.shape)
        # lstm_out is of shape (batch_size, seq_len, hidden_dim) and we want the last time step output
        # print("LSTM last: ", lstm_out[:, -1, :].shape)
        out = self.fc4(x[:, -1, :])
        return out.view(-1, 60, 2)

model = SimpleModel()
test = torch.randn(128, 50, 50, 6)
out = model(test)
print(out.shape)
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params}")

torch.Size([128, 60, 2])
Total parameters: 280568


In [None]:
# model = EncoderDecoderModel().to(device)
model = SimpleModel().to(device)
# model = MLP(50 * 50 * 6, 60 * 2).to(device)
# model = LSTM().to(device)

optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.25) # You can try different schedulers
early_stopping_patience = 10
best_val_loss = float('inf')
no_improvement = 0
criterion = nn.MSELoss()

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
torch.manual_seed(42)
val_ratio = 0.1
N = len(trainData)
val_size = int(val_ratio * N)
train_size = N - val_size
train_dataset = WindowedNormalizedDataset(trainData[:train_size])
valid_dataset = WindowedNormalizedDataset(trainData[train_size:])
# train_dataset.__len__(), valid_dataset.__len__()
train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_dataloader = DataLoader(valid_dataset, batch_size=128, shuffle=False)
# next(iter(train_dataloader))

In [None]:
epochs = 1000
lossFn = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.25) # You can try different schedulers
best_val_loss = float('inf')
best_train_loss = float('inf')

for each_epoch in range(epochs):
    model.train()
    runningLoss = 0.0
    loop = tqdm(train_dataloader, desc=f"Epoch [{each_epoch+1}/{epochs}]")
    scale = 10
    for batchX, batchY, origin in loop:
        batchX = batchX.to(device, non_blocking=True)   
        batchY = batchY.to(device, non_blocking=True)  
        origin = origin.to(device, non_blocking=True)  
        # print(batchX.shape, batchY.shape, origin.shape)
        batchX[..., :2] = batchX[..., :2] - origin.unsqueeze(1).unsqueeze(1)
        batchY = batchY - origin.unsqueeze(1)
        batchX[..., :4] = batchX[..., :4] / 10
        batchY = batchY / 10
        pred = model(batchX)
        # pred = pred.to(device)
        loss = lossFn(pred, batchY).to(device)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        runningLoss += loss.item()        
        # print(pred.shape)
        # break
    # break
    
    model.eval()
    val_loss = 0
    val_mae = 0
    val_mse = 0
    with torch.no_grad():
        for batchX, batchY, origin in val_dataloader:
            batchX = batchX.to(device, non_blocking=True)   
            batchY = batchY.to(device, non_blocking=True)  
            origin = origin.to(device, non_blocking=True) 
            # print(batchX.shape, batchY.shape, origin.shape)
            batchX[..., :2] = batchX[..., :2] - origin.unsqueeze(1).unsqueeze(1)
            batchY = batchY - origin.unsqueeze(1)
            batchX[..., :4] = batchX[..., :4] / 10
            batchY = batchY / 10
            pred = model(batchX)
            loss = lossFn(pred, batchY)
            
            val_loss += loss.item()
            # print(pred[0, :5, :2])
            # print(pred.shape)    
            pred = pred * 10
            # print(pred[0, :5, :2])
            pred = pred + origin.unsqueeze(1)
            # print(pred[0, :5, :2])
            batchY_unnorm = batchY * 10 + origin.unsqueeze(1)
            val_mae += nn.L1Loss()(pred, batchY_unnorm).item()
            val_mse += nn.MSELoss()(pred, batchY_unnorm).item()
            
    train_loss = runningLoss/len(train_dataloader)
    val_loss /= len(val_dataloader)
    val_mae /= len(val_dataloader)
    val_mse /= len(val_dataloader)
    loop.write(f" train normalized MSE {train_loss:8.4f} | val normalized MSE {val_loss:8.4f} | val MAE {val_mae:8.4f} | val MSE {val_mse:8.4f}")
    scheduler.step()
    scheduler.step(val_loss)
    
    if train_loss < best_train_loss and val_loss < best_val_loss - 1e-3:
        best_val_loss = val_loss
        best_train_loss = train_loss
        no_improvement = 0
        torch.save(model.state_dict(), "./models/modelM/best_model.pt")
    # else:
    #     no_improvement += 1
    # break
    
    

Epoch [1/1000]: 100%|██████████| 71/71 [00:01<00:00, 68.81it/s]


 train normalized MSE   0.3378 | val normalized MSE   0.2604 | val MAE   3.0064 | val MSE  26.0410


Epoch [2/1000]: 100%|██████████| 71/71 [00:01<00:00, 69.22it/s]


 train normalized MSE   0.2211 | val normalized MSE   0.1884 | val MAE   2.4054 | val MSE  18.8430


Epoch [3/1000]: 100%|██████████| 71/71 [00:01<00:00, 69.12it/s]


 train normalized MSE   0.1910 | val normalized MSE   0.1865 | val MAE   2.4042 | val MSE  18.6540


Epoch [4/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.15it/s]


 train normalized MSE   0.1733 | val normalized MSE   0.1632 | val MAE   2.2320 | val MSE  16.3232


Epoch [5/1000]: 100%|██████████| 71/71 [00:00<00:00, 72.57it/s]


 train normalized MSE   0.1709 | val normalized MSE   0.1610 | val MAE   2.2479 | val MSE  16.1032


Epoch [6/1000]: 100%|██████████| 71/71 [00:01<00:00, 70.81it/s]


 train normalized MSE   0.1547 | val normalized MSE   0.1459 | val MAE   2.0550 | val MSE  14.5899


Epoch [7/1000]: 100%|██████████| 71/71 [00:01<00:00, 69.74it/s]


 train normalized MSE   0.1492 | val normalized MSE   0.1393 | val MAE   2.0572 | val MSE  13.9343


Epoch [8/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.50it/s]


 train normalized MSE   0.1410 | val normalized MSE   0.1260 | val MAE   1.8643 | val MSE  12.5983


Epoch [9/1000]: 100%|██████████| 71/71 [00:00<00:00, 72.05it/s]


 train normalized MSE   0.1367 | val normalized MSE   0.1309 | val MAE   1.9857 | val MSE  13.0901


Epoch [10/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.77it/s]


 train normalized MSE   0.1323 | val normalized MSE   0.1271 | val MAE   1.8824 | val MSE  12.7082


Epoch [11/1000]: 100%|██████████| 71/71 [00:00<00:00, 72.05it/s]


 train normalized MSE   0.1311 | val normalized MSE   0.1273 | val MAE   1.8635 | val MSE  12.7325


Epoch [12/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.33it/s]


 train normalized MSE   0.1335 | val normalized MSE   0.1314 | val MAE   1.9305 | val MSE  13.1359


Epoch [13/1000]: 100%|██████████| 71/71 [00:00<00:00, 72.68it/s]


 train normalized MSE   0.1246 | val normalized MSE   0.1183 | val MAE   1.8372 | val MSE  11.8298


Epoch [14/1000]: 100%|██████████| 71/71 [00:01<00:00, 68.98it/s]


 train normalized MSE   0.1231 | val normalized MSE   0.1200 | val MAE   1.8828 | val MSE  12.0042


Epoch [15/1000]: 100%|██████████| 71/71 [00:01<00:00, 67.23it/s]


 train normalized MSE   0.1190 | val normalized MSE   0.1153 | val MAE   1.8449 | val MSE  11.5277


Epoch [16/1000]: 100%|██████████| 71/71 [00:01<00:00, 65.26it/s]


 train normalized MSE   0.1179 | val normalized MSE   0.1122 | val MAE   1.7735 | val MSE  11.2237


Epoch [17/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.96it/s]


 train normalized MSE   0.1151 | val normalized MSE   0.1129 | val MAE   1.8015 | val MSE  11.2924


Epoch [18/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.33it/s]


 train normalized MSE   0.1136 | val normalized MSE   0.1100 | val MAE   1.7583 | val MSE  11.0015


Epoch [19/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.91it/s]


 train normalized MSE   0.1109 | val normalized MSE   0.1115 | val MAE   1.7518 | val MSE  11.1550


Epoch [20/1000]: 100%|██████████| 71/71 [00:00<00:00, 72.12it/s]


 train normalized MSE   0.1116 | val normalized MSE   0.1162 | val MAE   1.8385 | val MSE  11.6218


Epoch [21/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.73it/s]


 train normalized MSE   0.1109 | val normalized MSE   0.1110 | val MAE   1.7412 | val MSE  11.0987


Epoch [22/1000]: 100%|██████████| 71/71 [00:00<00:00, 72.54it/s]


 train normalized MSE   0.1088 | val normalized MSE   0.1069 | val MAE   1.6649 | val MSE  10.6851


Epoch [23/1000]: 100%|██████████| 71/71 [00:00<00:00, 72.36it/s]


 train normalized MSE   0.1082 | val normalized MSE   0.1110 | val MAE   1.7771 | val MSE  11.0987


Epoch [24/1000]: 100%|██████████| 71/71 [00:00<00:00, 72.20it/s]


 train normalized MSE   0.1067 | val normalized MSE   0.1104 | val MAE   1.7604 | val MSE  11.0356


Epoch [25/1000]: 100%|██████████| 71/71 [00:00<00:00, 72.30it/s]


 train normalized MSE   0.1048 | val normalized MSE   0.1069 | val MAE   1.6844 | val MSE  10.6858


Epoch [26/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.21it/s]


 train normalized MSE   0.1033 | val normalized MSE   0.1040 | val MAE   1.6868 | val MSE  10.3971


Epoch [27/1000]: 100%|██████████| 71/71 [00:00<00:00, 72.70it/s]


 train normalized MSE   0.1055 | val normalized MSE   0.1031 | val MAE   1.6692 | val MSE  10.3112


Epoch [28/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.44it/s]


 train normalized MSE   0.1037 | val normalized MSE   0.1096 | val MAE   1.7582 | val MSE  10.9566


Epoch [29/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.80it/s]


 train normalized MSE   0.1026 | val normalized MSE   0.1094 | val MAE   1.7352 | val MSE  10.9395


Epoch [30/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.27it/s]


 train normalized MSE   0.1013 | val normalized MSE   0.1143 | val MAE   1.8138 | val MSE  11.4334


Epoch [31/1000]: 100%|██████████| 71/71 [00:01<00:00, 70.62it/s]


 train normalized MSE   0.1004 | val normalized MSE   0.1039 | val MAE   1.6818 | val MSE  10.3950


Epoch [32/1000]: 100%|██████████| 71/71 [00:01<00:00, 67.30it/s]


 train normalized MSE   0.0991 | val normalized MSE   0.1046 | val MAE   1.6633 | val MSE  10.4591


Epoch [33/1000]: 100%|██████████| 71/71 [00:01<00:00, 67.53it/s]


 train normalized MSE   0.0995 | val normalized MSE   0.1033 | val MAE   1.6553 | val MSE  10.3259


Epoch [34/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.56it/s]


 train normalized MSE   0.0987 | val normalized MSE   0.1052 | val MAE   1.6860 | val MSE  10.5223


Epoch [35/1000]: 100%|██████████| 71/71 [00:01<00:00, 69.94it/s]


 train normalized MSE   0.0960 | val normalized MSE   0.1022 | val MAE   1.6487 | val MSE  10.2207


Epoch [36/1000]: 100%|██████████| 71/71 [00:01<00:00, 70.78it/s]


 train normalized MSE   0.0958 | val normalized MSE   0.1027 | val MAE   1.6058 | val MSE  10.2691


Epoch [37/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.10it/s]


 train normalized MSE   0.0955 | val normalized MSE   0.0985 | val MAE   1.6032 | val MSE   9.8477


Epoch [38/1000]: 100%|██████████| 71/71 [00:01<00:00, 70.21it/s]


 train normalized MSE   0.0952 | val normalized MSE   0.1032 | val MAE   1.6804 | val MSE  10.3183


Epoch [39/1000]: 100%|██████████| 71/71 [00:01<00:00, 70.89it/s]


 train normalized MSE   0.0948 | val normalized MSE   0.1053 | val MAE   1.7337 | val MSE  10.5333


Epoch [40/1000]: 100%|██████████| 71/71 [00:01<00:00, 70.95it/s]


 train normalized MSE   0.0917 | val normalized MSE   0.1005 | val MAE   1.6069 | val MSE  10.0530


Epoch [41/1000]: 100%|██████████| 71/71 [00:00<00:00, 72.05it/s]


 train normalized MSE   0.0929 | val normalized MSE   0.1026 | val MAE   1.7205 | val MSE  10.2583


Epoch [42/1000]: 100%|██████████| 71/71 [00:00<00:00, 72.23it/s]


 train normalized MSE   0.0935 | val normalized MSE   0.1046 | val MAE   1.7041 | val MSE  10.4600


Epoch [43/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.92it/s]


 train normalized MSE   0.0931 | val normalized MSE   0.1015 | val MAE   1.6767 | val MSE  10.1497


Epoch [44/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.05it/s]


 train normalized MSE   0.0921 | val normalized MSE   0.0989 | val MAE   1.5914 | val MSE   9.8851


Epoch [45/1000]: 100%|██████████| 71/71 [00:01<00:00, 70.57it/s]


 train normalized MSE   0.0919 | val normalized MSE   0.1078 | val MAE   1.7571 | val MSE  10.7831


Epoch [46/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.83it/s]


 train normalized MSE   0.0904 | val normalized MSE   0.1050 | val MAE   1.6783 | val MSE  10.5040


Epoch [47/1000]: 100%|██████████| 71/71 [00:00<00:00, 72.19it/s]


 train normalized MSE   0.0893 | val normalized MSE   0.1007 | val MAE   1.6177 | val MSE  10.0660


Epoch [48/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.41it/s]


 train normalized MSE   0.0885 | val normalized MSE   0.0957 | val MAE   1.5800 | val MSE   9.5686


Epoch [49/1000]: 100%|██████████| 71/71 [00:01<00:00, 65.55it/s]


 train normalized MSE   0.0903 | val normalized MSE   0.1026 | val MAE   1.7148 | val MSE  10.2589


Epoch [50/1000]: 100%|██████████| 71/71 [00:01<00:00, 59.06it/s]


 train normalized MSE   0.0892 | val normalized MSE   0.0968 | val MAE   1.6189 | val MSE   9.6759


Epoch [51/1000]: 100%|██████████| 71/71 [00:01<00:00, 70.65it/s]


 train normalized MSE   0.0890 | val normalized MSE   0.0982 | val MAE   1.6424 | val MSE   9.8165


Epoch [52/1000]: 100%|██████████| 71/71 [00:01<00:00, 70.92it/s]


 train normalized MSE   0.0873 | val normalized MSE   0.0971 | val MAE   1.5821 | val MSE   9.7117


Epoch [53/1000]: 100%|██████████| 71/71 [00:01<00:00, 70.65it/s]


 train normalized MSE   0.0863 | val normalized MSE   0.0994 | val MAE   1.6089 | val MSE   9.9408


Epoch [54/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.62it/s]


 train normalized MSE   0.0868 | val normalized MSE   0.1014 | val MAE   1.6486 | val MSE  10.1439


Epoch [55/1000]: 100%|██████████| 71/71 [00:01<00:00, 70.57it/s]


 train normalized MSE   0.0863 | val normalized MSE   0.1009 | val MAE   1.6424 | val MSE  10.0901


Epoch [56/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.28it/s]


 train normalized MSE   0.0852 | val normalized MSE   0.1008 | val MAE   1.6668 | val MSE  10.0804


Epoch [57/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.79it/s]


 train normalized MSE   0.0852 | val normalized MSE   0.1000 | val MAE   1.6744 | val MSE  10.0035


Epoch [58/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.55it/s]


 train normalized MSE   0.0862 | val normalized MSE   0.1050 | val MAE   1.7254 | val MSE  10.4973


Epoch [59/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.11it/s]


 train normalized MSE   0.0852 | val normalized MSE   0.0960 | val MAE   1.5912 | val MSE   9.6000


Epoch [60/1000]: 100%|██████████| 71/71 [00:01<00:00, 70.04it/s]


 train normalized MSE   0.0854 | val normalized MSE   0.0980 | val MAE   1.6128 | val MSE   9.8012


Epoch [61/1000]: 100%|██████████| 71/71 [00:00<00:00, 72.04it/s]


 train normalized MSE   0.0834 | val normalized MSE   0.0941 | val MAE   1.5764 | val MSE   9.4110


Epoch [62/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.50it/s]


 train normalized MSE   0.0816 | val normalized MSE   0.0980 | val MAE   1.5844 | val MSE   9.7951


Epoch [63/1000]: 100%|██████████| 71/71 [00:00<00:00, 72.10it/s]


 train normalized MSE   0.0818 | val normalized MSE   0.1048 | val MAE   1.6749 | val MSE  10.4812


Epoch [64/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.49it/s]


 train normalized MSE   0.0832 | val normalized MSE   0.1008 | val MAE   1.6144 | val MSE  10.0818


Epoch [65/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.18it/s]


 train normalized MSE   0.0830 | val normalized MSE   0.0991 | val MAE   1.6180 | val MSE   9.9097


Epoch [66/1000]: 100%|██████████| 71/71 [00:01<00:00, 68.89it/s]


 train normalized MSE   0.0834 | val normalized MSE   0.1062 | val MAE   1.7309 | val MSE  10.6225


Epoch [67/1000]: 100%|██████████| 71/71 [00:01<00:00, 69.28it/s]


 train normalized MSE   0.0827 | val normalized MSE   0.1006 | val MAE   1.6415 | val MSE  10.0578


Epoch [68/1000]: 100%|██████████| 71/71 [00:01<00:00, 70.98it/s]


 train normalized MSE   0.0825 | val normalized MSE   0.1063 | val MAE   1.7319 | val MSE  10.6262


Epoch [69/1000]: 100%|██████████| 71/71 [00:01<00:00, 70.02it/s]


 train normalized MSE   0.0800 | val normalized MSE   0.1003 | val MAE   1.6112 | val MSE  10.0334


Epoch [70/1000]: 100%|██████████| 71/71 [00:01<00:00, 70.46it/s]


 train normalized MSE   0.0812 | val normalized MSE   0.0995 | val MAE   1.6147 | val MSE   9.9468


Epoch [71/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.47it/s]


 train normalized MSE   0.0809 | val normalized MSE   0.0996 | val MAE   1.6134 | val MSE   9.9629


Epoch [72/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.79it/s]


 train normalized MSE   0.0792 | val normalized MSE   0.0970 | val MAE   1.5662 | val MSE   9.7018


Epoch [73/1000]: 100%|██████████| 71/71 [00:00<00:00, 72.29it/s]


 train normalized MSE   0.0805 | val normalized MSE   0.0989 | val MAE   1.6093 | val MSE   9.8922


Epoch [74/1000]: 100%|██████████| 71/71 [00:01<00:00, 70.92it/s]


 train normalized MSE   0.0794 | val normalized MSE   0.1058 | val MAE   1.6295 | val MSE  10.5780


Epoch [75/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.79it/s]


 train normalized MSE   0.0794 | val normalized MSE   0.0971 | val MAE   1.6070 | val MSE   9.7114


Epoch [76/1000]: 100%|██████████| 71/71 [00:00<00:00, 72.16it/s]


 train normalized MSE   0.0784 | val normalized MSE   0.1001 | val MAE   1.5949 | val MSE  10.0087


Epoch [77/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.88it/s]


 train normalized MSE   0.0796 | val normalized MSE   0.0972 | val MAE   1.6002 | val MSE   9.7202


Epoch [78/1000]: 100%|██████████| 71/71 [00:00<00:00, 72.42it/s]


 train normalized MSE   0.0778 | val normalized MSE   0.1059 | val MAE   1.6720 | val MSE  10.5898


Epoch [79/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.19it/s]


 train normalized MSE   0.0790 | val normalized MSE   0.0982 | val MAE   1.6062 | val MSE   9.8155


Epoch [80/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.88it/s]


 train normalized MSE   0.0771 | val normalized MSE   0.0968 | val MAE   1.5952 | val MSE   9.6790


Epoch [81/1000]: 100%|██████████| 71/71 [00:01<00:00, 70.87it/s]


 train normalized MSE   0.0771 | val normalized MSE   0.0998 | val MAE   1.6350 | val MSE   9.9847


Epoch [82/1000]: 100%|██████████| 71/71 [00:01<00:00, 70.25it/s]


 train normalized MSE   0.0781 | val normalized MSE   0.1013 | val MAE   1.6284 | val MSE  10.1263


Epoch [83/1000]: 100%|██████████| 71/71 [00:01<00:00, 67.22it/s]


 train normalized MSE   0.0762 | val normalized MSE   0.1014 | val MAE   1.6529 | val MSE  10.1445


Epoch [84/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.19it/s]


 train normalized MSE   0.0773 | val normalized MSE   0.1020 | val MAE   1.6184 | val MSE  10.1977


Epoch [85/1000]: 100%|██████████| 71/71 [00:01<00:00, 70.85it/s]


 train normalized MSE   0.0763 | val normalized MSE   0.1013 | val MAE   1.6906 | val MSE  10.1273


Epoch [86/1000]: 100%|██████████| 71/71 [00:01<00:00, 69.78it/s]


 train normalized MSE   0.0764 | val normalized MSE   0.0998 | val MAE   1.6261 | val MSE   9.9831


Epoch [87/1000]: 100%|██████████| 71/71 [00:01<00:00, 70.89it/s]


 train normalized MSE   0.0765 | val normalized MSE   0.1033 | val MAE   1.6284 | val MSE  10.3308


Epoch [88/1000]: 100%|██████████| 71/71 [00:01<00:00, 70.71it/s]


 train normalized MSE   0.0769 | val normalized MSE   0.0992 | val MAE   1.6088 | val MSE   9.9242


Epoch [89/1000]: 100%|██████████| 71/71 [00:00<00:00, 72.02it/s]


 train normalized MSE   0.0756 | val normalized MSE   0.0997 | val MAE   1.5766 | val MSE   9.9727


Epoch [90/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.41it/s]


 train normalized MSE   0.0739 | val normalized MSE   0.0978 | val MAE   1.6103 | val MSE   9.7848


Epoch [91/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.61it/s]


 train normalized MSE   0.0758 | val normalized MSE   0.1013 | val MAE   1.6224 | val MSE  10.1293


Epoch [92/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.80it/s]


 train normalized MSE   0.0754 | val normalized MSE   0.0997 | val MAE   1.6311 | val MSE   9.9706


Epoch [93/1000]: 100%|██████████| 71/71 [00:01<00:00, 70.69it/s]


 train normalized MSE   0.0766 | val normalized MSE   0.0990 | val MAE   1.5834 | val MSE   9.8981


Epoch [94/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.77it/s]


 train normalized MSE   0.0748 | val normalized MSE   0.1037 | val MAE   1.6658 | val MSE  10.3701


Epoch [95/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.36it/s]


 train normalized MSE   0.0758 | val normalized MSE   0.1038 | val MAE   1.6378 | val MSE  10.3849


Epoch [96/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.60it/s]


 train normalized MSE   0.0733 | val normalized MSE   0.1004 | val MAE   1.6498 | val MSE  10.0393


Epoch [97/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.30it/s]


 train normalized MSE   0.0744 | val normalized MSE   0.1001 | val MAE   1.5717 | val MSE  10.0053


Epoch [98/1000]: 100%|██████████| 71/71 [00:01<00:00, 70.57it/s]


 train normalized MSE   0.0725 | val normalized MSE   0.1033 | val MAE   1.7094 | val MSE  10.3299


Epoch [99/1000]: 100%|██████████| 71/71 [00:01<00:00, 70.88it/s]


 train normalized MSE   0.0758 | val normalized MSE   0.0996 | val MAE   1.6336 | val MSE   9.9586


Epoch [100/1000]: 100%|██████████| 71/71 [00:01<00:00, 69.66it/s]


 train normalized MSE   0.0734 | val normalized MSE   0.0992 | val MAE   1.5731 | val MSE   9.9223


Epoch [101/1000]: 100%|██████████| 71/71 [00:01<00:00, 70.21it/s]


 train normalized MSE   0.0722 | val normalized MSE   0.1053 | val MAE   1.7480 | val MSE  10.5323


Epoch [102/1000]: 100%|██████████| 71/71 [00:01<00:00, 70.64it/s]


 train normalized MSE   0.0743 | val normalized MSE   0.1013 | val MAE   1.6741 | val MSE  10.1261


Epoch [103/1000]: 100%|██████████| 71/71 [00:01<00:00, 70.00it/s]


 train normalized MSE   0.0714 | val normalized MSE   0.1018 | val MAE   1.6404 | val MSE  10.1850


Epoch [104/1000]: 100%|██████████| 71/71 [00:01<00:00, 70.87it/s]


 train normalized MSE   0.0725 | val normalized MSE   0.0997 | val MAE   1.6064 | val MSE   9.9662


Epoch [105/1000]: 100%|██████████| 71/71 [00:01<00:00, 70.73it/s]


 train normalized MSE   0.0729 | val normalized MSE   0.1014 | val MAE   1.6217 | val MSE  10.1376


Epoch [106/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.57it/s]


 train normalized MSE   0.0722 | val normalized MSE   0.0993 | val MAE   1.6262 | val MSE   9.9294


Epoch [107/1000]: 100%|██████████| 71/71 [00:01<00:00, 70.85it/s]


 train normalized MSE   0.0733 | val normalized MSE   0.1043 | val MAE   1.6726 | val MSE  10.4270


Epoch [108/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.00it/s]


 train normalized MSE   0.0704 | val normalized MSE   0.1016 | val MAE   1.6038 | val MSE  10.1555


Epoch [109/1000]: 100%|██████████| 71/71 [00:00<00:00, 72.09it/s]


 train normalized MSE   0.0703 | val normalized MSE   0.0988 | val MAE   1.5901 | val MSE   9.8763


Epoch [110/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.97it/s]


 train normalized MSE   0.0714 | val normalized MSE   0.0983 | val MAE   1.6045 | val MSE   9.8279


Epoch [111/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.59it/s]


 train normalized MSE   0.0720 | val normalized MSE   0.1012 | val MAE   1.6497 | val MSE  10.1241


Epoch [112/1000]: 100%|██████████| 71/71 [00:01<00:00, 70.53it/s]


 train normalized MSE   0.0696 | val normalized MSE   0.1006 | val MAE   1.6142 | val MSE  10.0637


Epoch [113/1000]: 100%|██████████| 71/71 [00:01<00:00, 70.88it/s]


 train normalized MSE   0.0695 | val normalized MSE   0.0988 | val MAE   1.5656 | val MSE   9.8838


Epoch [114/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.43it/s]


 train normalized MSE   0.0691 | val normalized MSE   0.1025 | val MAE   1.6570 | val MSE  10.2473


Epoch [115/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.62it/s]


 train normalized MSE   0.0696 | val normalized MSE   0.0977 | val MAE   1.5494 | val MSE   9.7714


Epoch [116/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.06it/s]


 train normalized MSE   0.0695 | val normalized MSE   0.1017 | val MAE   1.6022 | val MSE  10.1706


Epoch [117/1000]: 100%|██████████| 71/71 [00:01<00:00, 69.20it/s]


 train normalized MSE   0.0701 | val normalized MSE   0.1059 | val MAE   1.6749 | val MSE  10.5915


Epoch [118/1000]: 100%|██████████| 71/71 [00:01<00:00, 70.57it/s]


 train normalized MSE   0.0694 | val normalized MSE   0.1003 | val MAE   1.6616 | val MSE  10.0270


Epoch [119/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.80it/s]


 train normalized MSE   0.0717 | val normalized MSE   0.1033 | val MAE   1.6719 | val MSE  10.3332


Epoch [120/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.88it/s]


 train normalized MSE   0.0705 | val normalized MSE   0.1007 | val MAE   1.6151 | val MSE  10.0701


Epoch [121/1000]: 100%|██████████| 71/71 [00:01<00:00, 70.91it/s]


 train normalized MSE   0.0692 | val normalized MSE   0.1012 | val MAE   1.6347 | val MSE  10.1222


Epoch [122/1000]: 100%|██████████| 71/71 [00:01<00:00, 70.80it/s]


 train normalized MSE   0.0686 | val normalized MSE   0.1008 | val MAE   1.6191 | val MSE  10.0825


Epoch [123/1000]: 100%|██████████| 71/71 [00:01<00:00, 70.29it/s]


 train normalized MSE   0.0679 | val normalized MSE   0.0992 | val MAE   1.6023 | val MSE   9.9196


Epoch [124/1000]: 100%|██████████| 71/71 [00:00<00:00, 72.09it/s]


 train normalized MSE   0.0668 | val normalized MSE   0.1011 | val MAE   1.6236 | val MSE  10.1070


Epoch [125/1000]: 100%|██████████| 71/71 [00:00<00:00, 72.29it/s]


 train normalized MSE   0.0672 | val normalized MSE   0.1040 | val MAE   1.6076 | val MSE  10.3988


Epoch [126/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.92it/s]


 train normalized MSE   0.0668 | val normalized MSE   0.0996 | val MAE   1.6092 | val MSE   9.9624


Epoch [127/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.86it/s]


 train normalized MSE   0.0669 | val normalized MSE   0.0980 | val MAE   1.5630 | val MSE   9.7968


Epoch [128/1000]: 100%|██████████| 71/71 [00:01<00:00, 70.85it/s]


 train normalized MSE   0.0678 | val normalized MSE   0.0983 | val MAE   1.5925 | val MSE   9.8302


Epoch [129/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.73it/s]


 train normalized MSE   0.0667 | val normalized MSE   0.0985 | val MAE   1.6170 | val MSE   9.8490


Epoch [130/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.76it/s]


 train normalized MSE   0.0663 | val normalized MSE   0.1048 | val MAE   1.6477 | val MSE  10.4753


Epoch [131/1000]: 100%|██████████| 71/71 [00:00<00:00, 72.04it/s]


 train normalized MSE   0.0692 | val normalized MSE   0.0960 | val MAE   1.5501 | val MSE   9.6005


Epoch [132/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.75it/s]


 train normalized MSE   0.0680 | val normalized MSE   0.1010 | val MAE   1.6157 | val MSE  10.1036


Epoch [133/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.19it/s]


 train normalized MSE   0.0667 | val normalized MSE   0.1001 | val MAE   1.6184 | val MSE  10.0096


Epoch [134/1000]: 100%|██████████| 71/71 [00:01<00:00, 67.82it/s]


 train normalized MSE   0.0667 | val normalized MSE   0.1011 | val MAE   1.5924 | val MSE  10.1115


Epoch [135/1000]: 100%|██████████| 71/71 [00:01<00:00, 69.99it/s]


 train normalized MSE   0.0673 | val normalized MSE   0.1047 | val MAE   1.6208 | val MSE  10.4691


Epoch [136/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.48it/s]


 train normalized MSE   0.0667 | val normalized MSE   0.0973 | val MAE   1.5476 | val MSE   9.7341


Epoch [137/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.58it/s]


 train normalized MSE   0.0645 | val normalized MSE   0.0981 | val MAE   1.6258 | val MSE   9.8081


Epoch [138/1000]: 100%|██████████| 71/71 [00:00<00:00, 72.22it/s]


 train normalized MSE   0.0648 | val normalized MSE   0.1039 | val MAE   1.5905 | val MSE  10.3882


Epoch [139/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.98it/s]


 train normalized MSE   0.0662 | val normalized MSE   0.1027 | val MAE   1.6160 | val MSE  10.2705


Epoch [140/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.96it/s]


 train normalized MSE   0.0656 | val normalized MSE   0.1024 | val MAE   1.5922 | val MSE  10.2357


Epoch [141/1000]: 100%|██████████| 71/71 [00:00<00:00, 72.67it/s]


 train normalized MSE   0.0641 | val normalized MSE   0.1046 | val MAE   1.6425 | val MSE  10.4580


Epoch [142/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.98it/s]


 train normalized MSE   0.0639 | val normalized MSE   0.0980 | val MAE   1.5736 | val MSE   9.7981


Epoch [143/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.82it/s]


 train normalized MSE   0.0648 | val normalized MSE   0.0971 | val MAE   1.5546 | val MSE   9.7132


Epoch [144/1000]: 100%|██████████| 71/71 [00:00<00:00, 72.26it/s]


 train normalized MSE   0.0642 | val normalized MSE   0.0971 | val MAE   1.5656 | val MSE   9.7057


Epoch [145/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.98it/s]


 train normalized MSE   0.0653 | val normalized MSE   0.1017 | val MAE   1.5777 | val MSE  10.1744


Epoch [146/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.76it/s]


 train normalized MSE   0.0632 | val normalized MSE   0.0977 | val MAE   1.5724 | val MSE   9.7748


Epoch [147/1000]: 100%|██████████| 71/71 [00:00<00:00, 72.04it/s]


 train normalized MSE   0.0623 | val normalized MSE   0.1030 | val MAE   1.5892 | val MSE  10.2975


Epoch [148/1000]: 100%|██████████| 71/71 [00:00<00:00, 72.17it/s]


 train normalized MSE   0.0618 | val normalized MSE   0.1050 | val MAE   1.6856 | val MSE  10.5049


Epoch [149/1000]: 100%|██████████| 71/71 [00:00<00:00, 72.24it/s]


 train normalized MSE   0.0649 | val normalized MSE   0.1043 | val MAE   1.6803 | val MSE  10.4261


Epoch [150/1000]: 100%|██████████| 71/71 [00:00<00:00, 72.70it/s]


 train normalized MSE   0.0621 | val normalized MSE   0.1006 | val MAE   1.6092 | val MSE  10.0561


Epoch [151/1000]: 100%|██████████| 71/71 [00:01<00:00, 68.42it/s]


 train normalized MSE   0.0621 | val normalized MSE   0.1014 | val MAE   1.6155 | val MSE  10.1370


Epoch [152/1000]: 100%|██████████| 71/71 [00:01<00:00, 70.84it/s]


 train normalized MSE   0.0633 | val normalized MSE   0.0995 | val MAE   1.6382 | val MSE   9.9459


Epoch [153/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.57it/s]


 train normalized MSE   0.0626 | val normalized MSE   0.1017 | val MAE   1.5747 | val MSE  10.1742


Epoch [154/1000]: 100%|██████████| 71/71 [00:00<00:00, 72.14it/s]


 train normalized MSE   0.0596 | val normalized MSE   0.1038 | val MAE   1.6263 | val MSE  10.3837


Epoch [155/1000]: 100%|██████████| 71/71 [00:00<00:00, 72.32it/s]


 train normalized MSE   0.0609 | val normalized MSE   0.1047 | val MAE   1.6474 | val MSE  10.4740


Epoch [156/1000]: 100%|██████████| 71/71 [00:01<00:00, 67.88it/s]


 train normalized MSE   0.0620 | val normalized MSE   0.1008 | val MAE   1.6184 | val MSE  10.0842


Epoch [157/1000]: 100%|██████████| 71/71 [00:01<00:00, 66.84it/s]


 train normalized MSE   0.0621 | val normalized MSE   0.1001 | val MAE   1.6043 | val MSE  10.0099


Epoch [158/1000]: 100%|██████████| 71/71 [00:01<00:00, 69.10it/s]


 train normalized MSE   0.0597 | val normalized MSE   0.0975 | val MAE   1.5719 | val MSE   9.7524


Epoch [159/1000]: 100%|██████████| 71/71 [00:01<00:00, 69.73it/s]


 train normalized MSE   0.0587 | val normalized MSE   0.1025 | val MAE   1.5838 | val MSE  10.2469


Epoch [160/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.24it/s]


 train normalized MSE   0.0614 | val normalized MSE   0.1000 | val MAE   1.6199 | val MSE   9.9966


Epoch [161/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.70it/s]


 train normalized MSE   0.0608 | val normalized MSE   0.1011 | val MAE   1.5907 | val MSE  10.1145


Epoch [162/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.64it/s]


 train normalized MSE   0.0611 | val normalized MSE   0.1042 | val MAE   1.6705 | val MSE  10.4171


Epoch [163/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.10it/s]


 train normalized MSE   0.0611 | val normalized MSE   0.1021 | val MAE   1.5781 | val MSE  10.2081


Epoch [164/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.38it/s]


 train normalized MSE   0.0593 | val normalized MSE   0.1056 | val MAE   1.6257 | val MSE  10.5565


Epoch [165/1000]: 100%|██████████| 71/71 [00:00<00:00, 72.38it/s]


 train normalized MSE   0.0605 | val normalized MSE   0.1016 | val MAE   1.5967 | val MSE  10.1622


Epoch [166/1000]: 100%|██████████| 71/71 [00:00<00:00, 72.62it/s]


 train normalized MSE   0.0612 | val normalized MSE   0.1030 | val MAE   1.6013 | val MSE  10.3022


Epoch [167/1000]: 100%|██████████| 71/71 [00:00<00:00, 72.49it/s]


 train normalized MSE   0.0598 | val normalized MSE   0.0996 | val MAE   1.5772 | val MSE   9.9641


Epoch [168/1000]: 100%|██████████| 71/71 [00:01<00:00, 68.16it/s]


 train normalized MSE   0.0588 | val normalized MSE   0.1049 | val MAE   1.6284 | val MSE  10.4889


Epoch [169/1000]: 100%|██████████| 71/71 [00:01<00:00, 69.32it/s]


 train normalized MSE   0.0562 | val normalized MSE   0.1058 | val MAE   1.6072 | val MSE  10.5842


Epoch [170/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.04it/s]


 train normalized MSE   0.0588 | val normalized MSE   0.1031 | val MAE   1.5927 | val MSE  10.3140


Epoch [171/1000]: 100%|██████████| 71/71 [00:00<00:00, 72.19it/s]


 train normalized MSE   0.0589 | val normalized MSE   0.1014 | val MAE   1.5818 | val MSE  10.1397


Epoch [172/1000]: 100%|██████████| 71/71 [00:00<00:00, 72.28it/s]


 train normalized MSE   0.0585 | val normalized MSE   0.1023 | val MAE   1.5945 | val MSE  10.2327


Epoch [173/1000]: 100%|██████████| 71/71 [00:00<00:00, 72.10it/s]


 train normalized MSE   0.0583 | val normalized MSE   0.1059 | val MAE   1.6905 | val MSE  10.5890


Epoch [174/1000]: 100%|██████████| 71/71 [00:00<00:00, 72.31it/s]


 train normalized MSE   0.0584 | val normalized MSE   0.1066 | val MAE   1.6502 | val MSE  10.6643


Epoch [175/1000]: 100%|██████████| 71/71 [00:00<00:00, 72.14it/s]


 train normalized MSE   0.0573 | val normalized MSE   0.1061 | val MAE   1.6329 | val MSE  10.6055


Epoch [176/1000]: 100%|██████████| 71/71 [00:01<00:00, 70.71it/s]


 train normalized MSE   0.0582 | val normalized MSE   0.1028 | val MAE   1.6151 | val MSE  10.2750


Epoch [177/1000]: 100%|██████████| 71/71 [00:01<00:00, 70.37it/s]


 train normalized MSE   0.0560 | val normalized MSE   0.1080 | val MAE   1.6512 | val MSE  10.8049


Epoch [178/1000]: 100%|██████████| 71/71 [00:01<00:00, 70.38it/s]


 train normalized MSE   0.0566 | val normalized MSE   0.1020 | val MAE   1.6053 | val MSE  10.1994


Epoch [179/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.38it/s]


 train normalized MSE   0.0575 | val normalized MSE   0.1004 | val MAE   1.6332 | val MSE  10.0369


Epoch [180/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.65it/s]


 train normalized MSE   0.0575 | val normalized MSE   0.1034 | val MAE   1.6101 | val MSE  10.3416


Epoch [181/1000]: 100%|██████████| 71/71 [00:00<00:00, 72.10it/s]


 train normalized MSE   0.0580 | val normalized MSE   0.1046 | val MAE   1.6390 | val MSE  10.4648


Epoch [182/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.80it/s]


 train normalized MSE   0.0549 | val normalized MSE   0.1047 | val MAE   1.6218 | val MSE  10.4675


Epoch [183/1000]: 100%|██████████| 71/71 [00:00<00:00, 72.09it/s]


 train normalized MSE   0.0574 | val normalized MSE   0.1026 | val MAE   1.5632 | val MSE  10.2607


Epoch [184/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.04it/s]


 train normalized MSE   0.0565 | val normalized MSE   0.1062 | val MAE   1.6443 | val MSE  10.6210


Epoch [185/1000]: 100%|██████████| 71/71 [00:01<00:00, 70.66it/s]


 train normalized MSE   0.0559 | val normalized MSE   0.1061 | val MAE   1.6458 | val MSE  10.6068


Epoch [186/1000]: 100%|██████████| 71/71 [00:01<00:00, 69.40it/s]


 train normalized MSE   0.0552 | val normalized MSE   0.1014 | val MAE   1.6240 | val MSE  10.1362


Epoch [187/1000]: 100%|██████████| 71/71 [00:01<00:00, 70.33it/s]


 train normalized MSE   0.0542 | val normalized MSE   0.1069 | val MAE   1.6437 | val MSE  10.6880


Epoch [188/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.77it/s]


 train normalized MSE   0.0544 | val normalized MSE   0.1019 | val MAE   1.6076 | val MSE  10.1946


Epoch [189/1000]: 100%|██████████| 71/71 [00:01<00:00, 70.72it/s]


 train normalized MSE   0.0546 | val normalized MSE   0.1053 | val MAE   1.6026 | val MSE  10.5301


Epoch [190/1000]: 100%|██████████| 71/71 [00:00<00:00, 72.00it/s]


 train normalized MSE   0.0565 | val normalized MSE   0.1061 | val MAE   1.6687 | val MSE  10.6109


Epoch [191/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.32it/s]


 train normalized MSE   0.0541 | val normalized MSE   0.1067 | val MAE   1.6626 | val MSE  10.6675


Epoch [192/1000]: 100%|██████████| 71/71 [00:00<00:00, 72.07it/s]


 train normalized MSE   0.0539 | val normalized MSE   0.1033 | val MAE   1.5984 | val MSE  10.3344


Epoch [193/1000]: 100%|██████████| 71/71 [00:00<00:00, 72.13it/s]


 train normalized MSE   0.0542 | val normalized MSE   0.1057 | val MAE   1.6001 | val MSE  10.5713


Epoch [194/1000]: 100%|██████████| 71/71 [00:01<00:00, 68.07it/s]


 train normalized MSE   0.0554 | val normalized MSE   0.1028 | val MAE   1.6041 | val MSE  10.2835


Epoch [195/1000]: 100%|██████████| 71/71 [00:01<00:00, 70.39it/s]


 train normalized MSE   0.0545 | val normalized MSE   0.1009 | val MAE   1.5694 | val MSE  10.0934


Epoch [196/1000]: 100%|██████████| 71/71 [00:00<00:00, 71.84it/s]


 train normalized MSE   0.0559 | val normalized MSE   0.1041 | val MAE   1.6416 | val MSE  10.4075


Epoch [197/1000]: 100%|██████████| 71/71 [00:00<00:00, 72.02it/s]


 train normalized MSE   0.0549 | val normalized MSE   0.1036 | val MAE   1.5964 | val MSE  10.3650


Epoch [198/1000]:  55%|█████▍    | 39/71 [00:00<00:00, 70.20it/s]


KeyboardInterrupt: 

In [None]:
test_dataset = WindowedNormalizedDataset(testData)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

best_model = torch.load("./models/modelM/best_model.pt")
model = SimpleModel().to(device)

model.load_state_dict(best_model)
model.eval()

pred_list = []
with torch.no_grad():
    for batchX, batchY, origin in test_loader:
        batchX = batchX.to(device, non_blocking=True)   
        batchY = batchY.to(device, non_blocking=True)  
        origin = origin.to(device, non_blocking=True) 
            # print(batchX.shape, batchY.shape, origin.shape)
        batchX[..., :2] = batchX[..., :2] - origin.unsqueeze(1).unsqueeze(1)
        batchY = batchY - origin.unsqueeze(1)
        batchX[..., :4] = batchX[..., :4] / 10
        batchY = batchY / 10
        pred = model(batchX)
            
        pred = pred * 10
        pred = pred + origin.unsqueeze(1)
        # print(pred.shape)
        pred_list.append(pred.cpu().numpy())
        # print(len(pred))
        # break

pred_list = np.concatenate(pred_list, axis=0)  
pred_output = pred_list.reshape(-1, 2)  # (N*60, 2)
output_df = pd.DataFrame(pred_output, columns=['x', 'y'])
output_df.index.name = 'index'
output_df.to_csv('./models/modelM/test34.csv', index=True)