In [42]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch.optim as optim
from torch_geometric.data import Data, Batch
import torch.nn.functional as F


In [43]:
if torch.backends.mps.is_available():
    device = torch.device('mps')
    print("Using Apple Silicon GPU")
elif torch.cuda.is_available():
    device = torch.device('cuda')
    print("Using CUDA GPU")
else:
    device = torch.device('cpu')

Using Apple Silicon GPU


In [44]:
def getData(path):
    train_file = np.load(path+"/train.npz")
    train_data = train_file['data']
    test_file = np.load(path+"/test_input.npz")
    test_data = test_file['data']
    print(f"Training Data's shape is {train_data.shape} and Test Data's is {test_data.shape}")
    return train_data, test_data

In [45]:
trainData, testData = getData("./data/")

Training Data's shape is (10000, 50, 110, 6) and Test Data's is (2100, 50, 50, 6)


In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class WindowedNormalizedDataset(Dataset):
    def __init__(self, data, window_size, forecast_horizon):
        self.data = data
        self.window_size = window_size
        self.forecast_horizon = forecast_horizon

        self.indices = []
        for sample in range(data.shape[0]):
            for t in range(data.shape[2] - window_size - forecast_horizon + 1):
                self.indices.append((sample, t))

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        sample_idx, t = self.indices[idx]

        origin = self.data[sample_idx, 0, 49, :2].copy()
        x = self.data[sample_idx, :, t:t+self.window_size, :]
        x[..., :2] -= origin

        x[..., :4] = x[..., :4] / 10
        y = self.data[sample_idx, 0, t+self.window_size:t+self.window_size+self.forecast_horizon, :2]
        y = y/10
        # print(x.shape, y.shape)

        return (
        torch.tensor(x, dtype=torch.float32),
        torch.tensor(y, dtype=torch.float32),
        torch.tensor(origin, dtype=torch.float32)  # normalize origin for consistency
        )

In [6]:
torch.manual_seed(42)
# np.random.seed(42)

scale = 9.0

N = trainData.shape[0]
val_size = int(0.1 * N)
train_size = N - val_size

train_dataset = WindowedNormalizedDataset(data = trainData[:train_size], window_size=50, forecast_horizon=60)
validation_dataset = WindowedNormalizedDataset(data = trainData[train_size:], window_size=50, forecast_horizon=60)


In [7]:
class EncoderDecoderModel(nn.Module):
    def __init__(self, infeatures, outfeatures=2):
        super().__init__()
        # Encoder
        self.layer1 = nn.Linear(infeatures, 32)
        self.layer2 = nn.Linear(32, 64)
        self.layer3 = nn.Linear(64, 128)
        self.encoderlstm = nn.LSTM(128, 256, num_layers=2, batch_first=True, dropout=0.3)
        
        # Changed pooling target from 20 to 60
        self.pool = nn.AdaptiveAvgPool1d(60)
        self.dropout = nn.Dropout(0.2)
        
        # Decoder
        self.decoderlstm = nn.LSTM(256, 128, num_layers=2, batch_first=True, dropout=0.3)
        self.layer10 = nn.Linear(128, 64)
        self.layer11 = nn.Linear(64, 32)
        self.layer12 = nn.Linear(32, outfeatures)
        
        # Skip connections
        self.skip1 = nn.Linear(32, 32)
        self.skip2 = nn.Linear(64, 64)
        self.skip3 = nn.Linear(128, 128)
        self.skip4 = nn.Linear(256, 256)

    def forward(self, x):
        batch_size, channels, height, width = x.shape
        
        # Encoder
        out1 = nn.ReLU()(self.layer1(x))
        out2 = nn.ReLU()(self.layer2(out1))
        out3 = nn.ReLU()(self.layer3(out2))
        
        # LSTM processing
        tempout3 = out3.view(batch_size, -1, out3.size(-1))
        out4, _ = self.encoderlstm(tempout3)
        
        # Changed pooling to 60
        tempout4 = self.pool(out4.permute(0, 2, 1))
        tempout4 = tempout4.permute(0, 2, 1)
        lstmskip = tempout4 + self.skip4(tempout4)
        
        # Decoder LSTM
        out5, _ = self.decoderlstm(lstmskip)
        
        out3_reduced = F.adaptive_avg_pool2d(out3.permute(0, 3, 1, 2), (60, 1)).squeeze(-1).permute(0, 2, 1)
        mlpskip1 = out3_reduced + self.skip3(out5)
        out6 = nn.ReLU()(self.layer10(mlpskip1))
        
        out2_reduced = F.adaptive_avg_pool2d(out2.permute(0, 3, 1, 2), (60, 1)).squeeze(-1).permute(0, 2, 1)
        mlpskip2 = out2_reduced + self.skip2(out6)
        out7 = nn.ReLU()(self.layer11(mlpskip2))
        
        out1_reduced = F.adaptive_avg_pool2d(out1.permute(0, 3, 1, 2), (60, 1)).squeeze(-1).permute(0, 2, 1)
        mlpskip3 = out1_reduced + self.skip1(out7)
        out8 = self.layer12(mlpskip3)
        
        return out8

# Verify the output shape
def xavier_init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)
        
model = EncoderDecoderModel(6, 2)
model.apply(xavier_init_weights)

# test = torch.randn(128, 50, 50, 6)
# out = model(test)
# print(out.shape)
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params}")
model.to(device)

Total parameters: 1359906


EncoderDecoderModel(
  (layer1): Linear(in_features=6, out_features=32, bias=True)
  (layer2): Linear(in_features=32, out_features=64, bias=True)
  (layer3): Linear(in_features=64, out_features=128, bias=True)
  (encoderlstm): LSTM(128, 256, num_layers=2, batch_first=True, dropout=0.3)
  (pool): AdaptiveAvgPool1d(output_size=60)
  (dropout): Dropout(p=0.2, inplace=False)
  (decoderlstm): LSTM(256, 128, num_layers=2, batch_first=True, dropout=0.3)
  (layer10): Linear(in_features=128, out_features=64, bias=True)
  (layer11): Linear(in_features=64, out_features=32, bias=True)
  (layer12): Linear(in_features=32, out_features=2, bias=True)
  (skip1): Linear(in_features=32, out_features=32, bias=True)
  (skip2): Linear(in_features=64, out_features=64, bias=True)
  (skip3): Linear(in_features=128, out_features=128, bias=True)
  (skip4): Linear(in_features=256, out_features=256, bias=True)
)

In [33]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class EncoderDecoderModel(nn.Module):
    def __init__(self, infeatures = 6, outfeatures=2, agents = 50, timestamp = 50, windowSize = 60, ):
        super().__init__()
        
        # batch_size, x_dim, y_dim, features = x.shape

        # x = x.view(batch_size, x_dim * y_dim, features)
        self.project_layer = nn.Linear(agents * timestamp, 60)  
        # MLP Encoder
        self.layer1 = nn.Linear(infeatures, 32)
        self.layer2 = nn.Linear(32, 64)
        self.layer3 = nn.Linear(64, 128)
        
        # LSTM Encoder (2 layers)
        self.encoder_lstm = nn.LSTM(128, 256, num_layers=2, batch_first=True)
        
        # Decoder LSTM (1 layer)
        self.decoder_lstm = nn.LSTM(256, 128, num_layers=1, batch_first=True)
        
        # MLP Decoder
        self.layer10 = nn.Linear(128, 64)
        self.layer11 = nn.Linear(64, 32)
        self.layer12 = nn.Linear(32, outfeatures)
        

    def forward(self, x):
        # x: [batch_size, seq_len, infeatures]
        batch_size, x_dim, y_dim, features = x.shape

        x = x.view(batch_size, x_dim * y_dim, features)
        # project = self.project_layer(x_dim * y_dim, 60)  # maps x*y → 60
        x = x.permute(0, 2, 1)  # (batch, features, x*y)
        x = self.project_layer(x)          # (batch, features, 60)
        x = x.permute(0, 2, 1)

        x = F.relu(self.layer1(x))
        
        
        x = F.relu(self.layer2(x))
        x = F.relu(self.layer3(x))
        
        # # LSTM Encoder
        # print(out.shape)
        x, _ = self.encoder_lstm(x)
        # print(x.shape)
        
        # # LSTM Decoder
        x, _ = self.decoder_lstm(x)
        # print(x.shape)
        
        # # Decoder MLP
        x = F.relu(self.layer10(x))
        x = F.relu(self.layer11(x))
        x = self.layer12(x)
        
        return x

# Weight initialization
def xavier_init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)

model = EncoderDecoderModel(6, 2)
model.apply(xavier_init_weights)

test = torch.randn(128, 50, 50, 6)
out = model(test)
print(out.shape)
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params}")
model.to(device)


torch.Size([128, 60, 2])
Total parameters: 1290350


EncoderDecoderModel(
  (project_layer): Linear(in_features=2500, out_features=60, bias=True)
  (layer1): Linear(in_features=6, out_features=32, bias=True)
  (layer2): Linear(in_features=32, out_features=64, bias=True)
  (layer3): Linear(in_features=64, out_features=128, bias=True)
  (encoder_lstm): LSTM(128, 256, num_layers=2, batch_first=True)
  (decoder_lstm): LSTM(256, 128, batch_first=True)
  (layer10): Linear(in_features=128, out_features=64, bias=True)
  (layer11): Linear(in_features=64, out_features=32, bias=True)
  (layer12): Linear(in_features=32, out_features=2, bias=True)
)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class EncoderDecoderModel(nn.Module):
    def __init__(self, infeatures=6, outfeatures=2):
        super().__init__()

        self.conv1d = nn.Conv1d(in_channels=infeatures, out_channels=64, kernel_size=3, stride=42, padding=1)

        self.layer1 = nn.Linear(64, 128)
        self.layer2 = nn.Linear(128, 256)

        self.encoder_lstm = nn.LSTM(256, 256, num_layers=2, batch_first=True)

        self.decoder_lstm = nn.LSTM(256, 256, num_layers=2, batch_first=True)

        self.layer10 = nn.Linear(256, 128)
        self.layer11 = nn.Linear(128, 64)
        self.layer12 = nn.Linear(64, 32)
        self.layer13 = nn.Linear(32, 16)
        self.layer14 = nn.Linear(16, outfeatures)

        self.dropout = nn.Dropout(p=0.3)

    def forward(self, x):
        batch_size, x_dim, y_dim, features = x.shape

        x = x.permute(0, 3, 1, 2).reshape(batch_size, features, -1)  # (batch, infeatures, 2500)

        x = self.conv1d(x)  # (batch, 64, 60)
        x = x.permute(0, 2, 1)  # (batch, 60, 64)

        x = F.relu(self.layer1(x))
        x = self.dropout(x)
        x = F.relu(self.layer2(x))
        x = self.dropout(x)

        x, _ = self.encoder_lstm(x)
        x, _ = self.decoder_lstm(x)

        x = F.relu(self.layer10(x))
        x = F.relu(self.layer11(x))
        x = F.relu(self.layer12(x))
        x = F.relu(self.layer13(x))
        x = self.layer14(x)

        return x


model = EncoderDecoderModel(6, 2)
model.apply(xavier_init_weights)

test = torch.randn(128, 50, 50, 6)
out = model(test)
print(out.shape)
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params}")
model.to(device)


torch.Size([128, 60, 2])
Total parameters: 2191698


EncoderDecoderModel(
  (conv1d): Conv1d(6, 64, kernel_size=(3,), stride=(42,), padding=(1,))
  (layer1): Linear(in_features=64, out_features=128, bias=True)
  (layer2): Linear(in_features=128, out_features=256, bias=True)
  (encoder_lstm): LSTM(256, 256, num_layers=2, batch_first=True)
  (decoder_lstm): LSTM(256, 256, num_layers=2, batch_first=True)
  (layer10): Linear(in_features=256, out_features=128, bias=True)
  (layer11): Linear(in_features=128, out_features=64, bias=True)
  (layer12): Linear(in_features=64, out_features=32, bias=True)
  (layer13): Linear(in_features=32, out_features=16, bias=True)
  (layer14): Linear(in_features=16, out_features=2, bias=True)
  (dropout): Dropout(p=0.3, inplace=False)
)

In [41]:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# torch.cuda.empty_cache()
# print(device)
# model.load_state_dict(torch.load("./models/modelF/medium_model_0.0062230645.pth"))  

trainDataLoader = DataLoader(train_dataset, batch_size=128, shuffle=True)
testDataLoader = DataLoader(validation_dataset, batch_size=128)
model.to(device)
# Training setup
epochs = 1000
lossFn = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)

tLoss = 10000
vLoss = 10000
# Iused 0.0008, 0.0001

for each_epoch in range(epochs):
    model.train()
    runningLoss = 0.0
    loop = tqdm(trainDataLoader, desc=f"Epoch [{each_epoch+1}/{epochs}]")
    totalSamples = 0
    for batchX, batchY, _ in loop:
        batchX, batchY = batchX.to(device, non_blocking=True), batchY.to(device, non_blocking=True)
        output = model(batchX)
        # print("Pred Training Shape: ", output.shape,"True Value", batchY.shape)
        # break
        loss = lossFn(output, batchY)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        runningLoss += loss.item()
        totalSamples += batchX.size(0)

    avgLoss = runningLoss/len(trainDataLoader)
    # print(avgLoss)#, runningLoss, len(trainDataLoader.dataset))    
    # break

    model.eval()
    with torch.no_grad():
        testloss = 0.0
        unnorm_loss = 0.0
        for testX, testY, origin in testDataLoader:
            testX, testY, origin = testX.to(device), testY.to(device), origin.to(device)
    
            pred = model(testX)
            # print(pred.shape, testY.shape, origin.shape)
            origin = origin.unsqueeze(1).expand(-1, 60, -1)
            # print(pred+origin)
            # break
            loss = lossFn(pred, testY)
            pred_absolute = (pred * 10) + origin 
            true_unnorm = (testY * 10) + origin
    
            testloss += loss.item()

            unnorm_loss += nn.MSELoss()(pred_absolute, true_unnorm).item()
        
        avgtestloss = testloss / len(testDataLoader)
        avgUnnormLoss = unnorm_loss /len(testDataLoader)

    # break
    print(f"Epoch {each_epoch + 1}, Training Loss: {avgLoss:},  Validation Loss:{avgtestloss:} , Absolute MSE:{avgUnnormLoss}")
    # break
    if(avgLoss < tLoss and avgtestloss < vLoss):
        tLoss = avgLoss
        vLoss = avgtestloss
        torch.save(model.state_dict(), f'./models/modelL/medium_model_{avgLoss:.10f}.pth')

    with open("./models/modelL/log_loss.txt", 'a') as f:
        f.write(f"{each_epoch + 1},{avgLoss:.10f},{avgtestloss:.10f},\n")
        
    torch.cuda.empty_cache()
    scheduler.step()


Epoch [1/1000]:   0%|          | 0/71 [00:00<?, ?it/s]

Epoch [1/1000]: 100%|██████████| 71/71 [00:20<00:00,  3.53it/s]


Epoch 1, Training Loss: 113979.82394366198,  Validation Loss:107880.9609375 , Absolute MSE:10788096.375


Epoch [2/1000]: 100%|██████████| 71/71 [00:18<00:00,  3.94it/s]


Epoch 2, Training Loss: 108308.75704225352,  Validation Loss:100067.1298828125 , Absolute MSE:10006713.25


Epoch [3/1000]: 100%|██████████| 71/71 [00:18<00:00,  3.89it/s]


Epoch 3, Training Loss: 100618.72722271127,  Validation Loss:93249.8642578125 , Absolute MSE:9324986.375


Epoch [4/1000]: 100%|██████████| 71/71 [00:17<00:00,  4.03it/s]


Epoch 4, Training Loss: 94988.40713028169,  Validation Loss:87671.3408203125 , Absolute MSE:8767133.9375


Epoch [5/1000]: 100%|██████████| 71/71 [00:18<00:00,  3.90it/s]


Epoch 5, Training Loss: 89961.58263644367,  Validation Loss:83286.0673828125 , Absolute MSE:8328607.0625


Epoch [6/1000]: 100%|██████████| 71/71 [00:17<00:00,  3.95it/s]


Epoch 6, Training Loss: 85756.82873019367,  Validation Loss:78690.6435546875 , Absolute MSE:7869063.8125


Epoch [7/1000]: 100%|██████████| 71/71 [00:18<00:00,  3.81it/s]


Epoch 7, Training Loss: 80907.42143485915,  Validation Loss:73530.22119140625 , Absolute MSE:7353022.0


Epoch [8/1000]: 100%|██████████| 71/71 [00:17<00:00,  4.10it/s]


Epoch 8, Training Loss: 76255.72001540494,  Validation Loss:69428.7470703125 , Absolute MSE:6942874.5625


Epoch [9/1000]: 100%|██████████| 71/71 [00:18<00:00,  3.84it/s]


Epoch 9, Training Loss: 73356.95004401408,  Validation Loss:67583.72998046875 , Absolute MSE:6758373.0


Epoch [10/1000]: 100%|██████████| 71/71 [00:18<00:00,  3.84it/s]


Epoch 10, Training Loss: 72644.20290492958,  Validation Loss:67145.24853515625 , Absolute MSE:6714524.6875


Epoch [11/1000]: 100%|██████████| 71/71 [00:17<00:00,  3.99it/s]


Epoch 11, Training Loss: 72206.41428257042,  Validation Loss:67061.84033203125 , Absolute MSE:6706183.9375


Epoch [12/1000]: 100%|██████████| 71/71 [00:17<00:00,  4.12it/s]


Epoch 12, Training Loss: 72206.51430457746,  Validation Loss:67075.79541015625 , Absolute MSE:6707579.6875


Epoch [13/1000]: 100%|██████████| 71/71 [00:17<00:00,  4.01it/s]


Epoch 13, Training Loss: 72200.87918133802,  Validation Loss:67042.1845703125 , Absolute MSE:6704218.625


Epoch [14/1000]: 100%|██████████| 71/71 [00:18<00:00,  3.92it/s]


Epoch 14, Training Loss: 72271.32482394367,  Validation Loss:67053.44189453125 , Absolute MSE:6705344.0


Epoch [15/1000]: 100%|██████████| 71/71 [00:17<00:00,  4.03it/s]


Epoch 15, Training Loss: 72147.2745378521,  Validation Loss:67061.30419921875 , Absolute MSE:6706130.125


Epoch [16/1000]: 100%|██████████| 71/71 [00:17<00:00,  3.96it/s]


Epoch 16, Training Loss: 72203.83351672535,  Validation Loss:67111.72802734375 , Absolute MSE:6711173.0


Epoch [17/1000]: 100%|██████████| 71/71 [00:18<00:00,  3.82it/s]


Epoch 17, Training Loss: 72091.45163952465,  Validation Loss:67062.56689453125 , Absolute MSE:6706256.875


Epoch [18/1000]: 100%|██████████| 71/71 [00:18<00:00,  3.86it/s]


Epoch 18, Training Loss: 72071.25990316902,  Validation Loss:67065.90234375 , Absolute MSE:6706590.125


Epoch [19/1000]: 100%|██████████| 71/71 [00:17<00:00,  4.05it/s]


Epoch 19, Training Loss: 72273.93788512323,  Validation Loss:67055.3466796875 , Absolute MSE:6705534.6875


Epoch [20/1000]: 100%|██████████| 71/71 [00:17<00:00,  4.00it/s]


Epoch 20, Training Loss: 72486.12758582746,  Validation Loss:67055.03564453125 , Absolute MSE:6705503.625


Epoch [21/1000]: 100%|██████████| 71/71 [00:17<00:00,  3.98it/s]


Epoch 21, Training Loss: 72307.88083186619,  Validation Loss:67064.70654296875 , Absolute MSE:6706470.625


Epoch [22/1000]: 100%|██████████| 71/71 [00:17<00:00,  4.02it/s]


Epoch 22, Training Loss: 72398.61014524648,  Validation Loss:67064.24462890625 , Absolute MSE:6706424.5


Epoch [23/1000]: 100%|██████████| 71/71 [00:17<00:00,  4.04it/s]


Epoch 23, Training Loss: 72065.56629621479,  Validation Loss:67063.3623046875 , Absolute MSE:6706336.25


Epoch [24/1000]: 100%|██████████| 71/71 [00:16<00:00,  4.18it/s]


Epoch 24, Training Loss: 72180.45092429577,  Validation Loss:67042.35546875 , Absolute MSE:6704235.5


Epoch [25/1000]: 100%|██████████| 71/71 [00:17<00:00,  4.14it/s]


Epoch 25, Training Loss: 72417.39012984154,  Validation Loss:67055.72998046875 , Absolute MSE:6705573.125


Epoch [26/1000]: 100%|██████████| 71/71 [00:17<00:00,  3.95it/s]


Epoch 26, Training Loss: 72365.27695862677,  Validation Loss:67055.68408203125 , Absolute MSE:6705568.375


Epoch [27/1000]: 100%|██████████| 71/71 [00:18<00:00,  3.91it/s]


Epoch 27, Training Loss: 72159.5303147007,  Validation Loss:67055.0595703125 , Absolute MSE:6705506.4375


Epoch [28/1000]: 100%|██████████| 71/71 [00:18<00:00,  3.87it/s]


Epoch 28, Training Loss: 72060.25357614437,  Validation Loss:67056.87744140625 , Absolute MSE:6705687.625


Epoch [29/1000]: 100%|██████████| 71/71 [00:17<00:00,  4.07it/s]


Epoch 29, Training Loss: 72387.36944322183,  Validation Loss:67097.7421875 , Absolute MSE:6709774.0


Epoch [30/1000]: 100%|██████████| 71/71 [00:17<00:00,  4.17it/s]


Epoch 30, Training Loss: 71980.1785871479,  Validation Loss:67063.005859375 , Absolute MSE:6706300.5625


Epoch [31/1000]: 100%|██████████| 71/71 [00:17<00:00,  4.16it/s]


Epoch 31, Training Loss: 72206.13099691902,  Validation Loss:67081.72265625 , Absolute MSE:6708171.875


Epoch [32/1000]: 100%|██████████| 71/71 [00:17<00:00,  4.09it/s]


Epoch 32, Training Loss: 71998.43915052817,  Validation Loss:67043.86279296875 , Absolute MSE:6704386.3125


Epoch [33/1000]: 100%|██████████| 71/71 [00:17<00:00,  4.15it/s]


Epoch 33, Training Loss: 72287.98492517606,  Validation Loss:67038.86572265625 , Absolute MSE:6703886.375


Epoch [34/1000]: 100%|██████████| 71/71 [00:18<00:00,  3.82it/s]


Epoch 34, Training Loss: 72182.83901848592,  Validation Loss:67088.3876953125 , Absolute MSE:6708838.625


Epoch [35/1000]: 100%|██████████| 71/71 [00:18<00:00,  3.90it/s]


Epoch 35, Training Loss: 72107.96033230633,  Validation Loss:67045.21337890625 , Absolute MSE:6704521.375


Epoch [36/1000]: 100%|██████████| 71/71 [00:17<00:00,  4.10it/s]


Epoch 36, Training Loss: 72162.80100132042,  Validation Loss:67059.62109375 , Absolute MSE:6705962.1875


Epoch [37/1000]: 100%|██████████| 71/71 [00:17<00:00,  4.15it/s]


Epoch 37, Training Loss: 72171.25830765846,  Validation Loss:67040.26904296875 , Absolute MSE:6704027.0


Epoch [38/1000]: 100%|██████████| 71/71 [00:17<00:00,  4.02it/s]


Epoch 38, Training Loss: 72289.4793683979,  Validation Loss:67047.8291015625 , Absolute MSE:6704783.0625


Epoch [39/1000]: 100%|██████████| 71/71 [00:18<00:00,  3.86it/s]


Epoch 39, Training Loss: 72322.9776078345,  Validation Loss:67040.46142578125 , Absolute MSE:6704046.0


Epoch [40/1000]: 100%|██████████| 71/71 [00:18<00:00,  3.86it/s]


Epoch 40, Training Loss: 72335.60497359154,  Validation Loss:67064.265625 , Absolute MSE:6706426.5625


Epoch [41/1000]: 100%|██████████| 71/71 [00:18<00:00,  3.78it/s]


Epoch 41, Training Loss: 72110.57878521127,  Validation Loss:67065.1474609375 , Absolute MSE:6706515.0


Epoch [42/1000]: 100%|██████████| 71/71 [00:17<00:00,  4.00it/s]


Epoch 42, Training Loss: 71980.48228433098,  Validation Loss:67041.88671875 , Absolute MSE:6704188.9375


Epoch [43/1000]: 100%|██████████| 71/71 [00:17<00:00,  4.12it/s]


Epoch 43, Training Loss: 72600.58582746479,  Validation Loss:67043.50048828125 , Absolute MSE:6704350.1875


Epoch [44/1000]: 100%|██████████| 71/71 [00:17<00:00,  3.96it/s]


Epoch 44, Training Loss: 72402.88853433098,  Validation Loss:67043.7587890625 , Absolute MSE:6704376.125


Epoch [45/1000]: 100%|██████████| 71/71 [00:18<00:00,  3.78it/s]


Epoch 45, Training Loss: 72301.66186179577,  Validation Loss:67055.6474609375 , Absolute MSE:6705565.0625


Epoch [46/1000]: 100%|██████████| 71/71 [00:17<00:00,  3.98it/s]


Epoch 46, Training Loss: 72091.08285651408,  Validation Loss:67050.4052734375 , Absolute MSE:6705040.25


Epoch [47/1000]: 100%|██████████| 71/71 [00:17<00:00,  4.05it/s]


Epoch 47, Training Loss: 72142.53273547535,  Validation Loss:67072.0849609375 , Absolute MSE:6707208.4375


Epoch [48/1000]: 100%|██████████| 71/71 [00:17<00:00,  4.03it/s]


Epoch 48, Training Loss: 71961.58082086267,  Validation Loss:67042.861328125 , Absolute MSE:6704286.0


Epoch [49/1000]: 100%|██████████| 71/71 [00:18<00:00,  3.92it/s]


Epoch 49, Training Loss: 72346.5887433979,  Validation Loss:67078.8095703125 , Absolute MSE:6707880.8125


Epoch [50/1000]: 100%|██████████| 71/71 [00:18<00:00,  3.81it/s]


Epoch 50, Training Loss: 72025.62246919014,  Validation Loss:67069.51953125 , Absolute MSE:6706952.125


Epoch [51/1000]: 100%|██████████| 71/71 [00:18<00:00,  3.83it/s]


Epoch 51, Training Loss: 72134.56024427817,  Validation Loss:67061.076171875 , Absolute MSE:6706107.5625


Epoch [52/1000]: 100%|██████████| 71/71 [00:18<00:00,  3.94it/s]


Epoch 52, Training Loss: 71993.43370378521,  Validation Loss:67059.5537109375 , Absolute MSE:6705955.5625


Epoch [53/1000]: 100%|██████████| 71/71 [00:16<00:00,  4.19it/s]


Epoch 53, Training Loss: 72239.13847931338,  Validation Loss:67056.2919921875 , Absolute MSE:6705629.1875


Epoch [54/1000]: 100%|██████████| 71/71 [00:17<00:00,  4.17it/s]


Epoch 54, Training Loss: 72613.8817671655,  Validation Loss:67057.716796875 , Absolute MSE:6705771.5625


Epoch [55/1000]: 100%|██████████| 71/71 [00:18<00:00,  3.94it/s]


Epoch 55, Training Loss: 71954.55100132042,  Validation Loss:67054.29833984375 , Absolute MSE:6705430.0


Epoch [56/1000]: 100%|██████████| 71/71 [00:18<00:00,  3.86it/s]


Epoch 56, Training Loss: 72284.08676276408,  Validation Loss:67051.541015625 , Absolute MSE:6705154.1875


Epoch [57/1000]: 100%|██████████| 71/71 [00:18<00:00,  3.87it/s]


Epoch 57, Training Loss: 71892.57994058098,  Validation Loss:67054.08740234375 , Absolute MSE:6705408.5


Epoch [58/1000]: 100%|██████████| 71/71 [00:16<00:00,  4.20it/s]


Epoch 58, Training Loss: 72217.3604753521,  Validation Loss:67051.87841796875 , Absolute MSE:6705187.6875


Epoch [59/1000]: 100%|██████████| 71/71 [00:18<00:00,  3.90it/s]


Epoch 59, Training Loss: 72395.8164612676,  Validation Loss:67050.66845703125 , Absolute MSE:6705066.6875


Epoch [60/1000]: 100%|██████████| 71/71 [00:18<00:00,  3.84it/s]


Epoch 60, Training Loss: 72082.37747579225,  Validation Loss:67051.048828125 , Absolute MSE:6705104.75


Epoch [61/1000]: 100%|██████████| 71/71 [00:17<00:00,  4.00it/s]


Epoch 61, Training Loss: 72299.26501980633,  Validation Loss:67052.52978515625 , Absolute MSE:6705252.625


Epoch [62/1000]: 100%|██████████| 71/71 [00:17<00:00,  4.11it/s]


Epoch 62, Training Loss: 72179.47089568662,  Validation Loss:67051.51611328125 , Absolute MSE:6705151.8125


Epoch [63/1000]: 100%|██████████| 71/71 [00:20<00:00,  3.54it/s]


Epoch 63, Training Loss: 72359.20626100352,  Validation Loss:67053.2431640625 , Absolute MSE:6705323.9375


Epoch [64/1000]: 100%|██████████| 71/71 [00:17<00:00,  4.05it/s]


Epoch 64, Training Loss: 72325.26749559859,  Validation Loss:67047.60205078125 , Absolute MSE:6704760.6875


Epoch [65/1000]: 100%|██████████| 71/71 [00:18<00:00,  3.80it/s]


Epoch 65, Training Loss: 72234.47810299296,  Validation Loss:67048.71484375 , Absolute MSE:6704871.25


Epoch [66/1000]: 100%|██████████| 71/71 [00:18<00:00,  3.92it/s]


Epoch 66, Training Loss: 71999.17352552817,  Validation Loss:67050.52978515625 , Absolute MSE:6705052.9375


Epoch [67/1000]: 100%|██████████| 71/71 [00:17<00:00,  4.08it/s]


Epoch 67, Training Loss: 72051.40069322183,  Validation Loss:67046.3994140625 , Absolute MSE:6704639.9375


Epoch [68/1000]: 100%|██████████| 71/71 [00:17<00:00,  4.07it/s]


Epoch 68, Training Loss: 72421.54049295775,  Validation Loss:67050.35888671875 , Absolute MSE:6705036.0625


Epoch [69/1000]: 100%|██████████| 71/71 [00:17<00:00,  4.12it/s]


Epoch 69, Training Loss: 72175.27299735915,  Validation Loss:67049.89697265625 , Absolute MSE:6704989.75


Epoch [70/1000]: 100%|██████████| 71/71 [00:17<00:00,  4.03it/s]


Epoch 70, Training Loss: 72132.16164172535,  Validation Loss:67052.79296875 , Absolute MSE:6705279.3125


Epoch [71/1000]: 100%|██████████| 71/71 [00:18<00:00,  3.92it/s]


Epoch 71, Training Loss: 72031.14524647887,  Validation Loss:67052.5087890625 , Absolute MSE:6705250.6875


Epoch [72/1000]: 100%|██████████| 71/71 [00:18<00:00,  3.77it/s]


Epoch 72, Training Loss: 72356.21027728873,  Validation Loss:67052.66455078125 , Absolute MSE:6705266.5625


Epoch [73/1000]: 100%|██████████| 71/71 [00:18<00:00,  3.78it/s]


Epoch 73, Training Loss: 72117.6689040493,  Validation Loss:67050.43310546875 , Absolute MSE:6705043.25


Epoch [74/1000]: 100%|██████████| 71/71 [00:18<00:00,  3.77it/s]


Epoch 74, Training Loss: 72288.86020026408,  Validation Loss:67049.87158203125 , Absolute MSE:6704987.25


Epoch [75/1000]: 100%|██████████| 71/71 [00:17<00:00,  3.97it/s]


Epoch 75, Training Loss: 72050.5889084507,  Validation Loss:67048.3720703125 , Absolute MSE:6704837.1875


Epoch [76/1000]: 100%|██████████| 71/71 [00:17<00:00,  4.08it/s]


Epoch 76, Training Loss: 72363.19503741198,  Validation Loss:67048.94873046875 , Absolute MSE:6704895.1875


Epoch [77/1000]: 100%|██████████| 71/71 [00:17<00:00,  4.14it/s]


Epoch 77, Training Loss: 72013.66549295775,  Validation Loss:67049.48974609375 , Absolute MSE:6704948.625


Epoch [78/1000]: 100%|██████████| 71/71 [00:17<00:00,  4.09it/s]


Epoch 78, Training Loss: 72269.12863116198,  Validation Loss:67052.02197265625 , Absolute MSE:6705202.5625


Epoch [79/1000]: 100%|██████████| 71/71 [00:17<00:00,  4.09it/s]


Epoch 79, Training Loss: 72508.2881272007,  Validation Loss:67049.4462890625 , Absolute MSE:6704944.5625


Epoch [80/1000]: 100%|██████████| 71/71 [00:17<00:00,  4.05it/s]


Epoch 80, Training Loss: 72139.74378301056,  Validation Loss:67054.93896484375 , Absolute MSE:6705494.0


Epoch [81/1000]: 100%|██████████| 71/71 [00:17<00:00,  4.10it/s]


Epoch 81, Training Loss: 71973.74818441902,  Validation Loss:67052.9619140625 , Absolute MSE:6705296.1875


Epoch [82/1000]: 100%|██████████| 71/71 [00:16<00:00,  4.22it/s]


Epoch 82, Training Loss: 72435.99152728873,  Validation Loss:67051.49951171875 , Absolute MSE:6705150.125


Epoch [83/1000]: 100%|██████████| 71/71 [00:17<00:00,  4.16it/s]


Epoch 83, Training Loss: 72422.84000880281,  Validation Loss:67051.548828125 , Absolute MSE:6705155.125


Epoch [84/1000]: 100%|██████████| 71/71 [00:18<00:00,  3.78it/s]


Epoch 84, Training Loss: 72106.28515625,  Validation Loss:67052.15234375 , Absolute MSE:6705215.125


Epoch [85/1000]: 100%|██████████| 71/71 [00:18<00:00,  3.85it/s]


Epoch 85, Training Loss: 72466.14106514085,  Validation Loss:67049.5849609375 , Absolute MSE:6704958.375


Epoch [86/1000]: 100%|██████████| 71/71 [00:16<00:00,  4.19it/s]


Epoch 86, Training Loss: 72095.05633802817,  Validation Loss:67050.1083984375 , Absolute MSE:6705010.6875


Epoch [87/1000]: 100%|██████████| 71/71 [00:17<00:00,  3.98it/s]


Epoch 87, Training Loss: 72024.59848151408,  Validation Loss:67051.96044921875 , Absolute MSE:6705196.125


Epoch [88/1000]:  21%|██        | 15/71 [00:04<00:16,  3.45it/s]


KeyboardInterrupt: 

In [47]:
# A = torch.randn(32, 60, 2)  
# B = torch.randn(32, 2)      
# print(B)
# # Expand B to shape (32, 60, 2)
# B_expanded = B.unsqueeze(1).expand(-1, 10, -1)
# B_expanded