In [60]:
import torch
from torch.utils.data import Dataset
import json
import os
import torch.nn as nn
import torch.nn.functional as func
from torch.utils.data import DataLoader
import torch.optim as optim
from os.path import expanduser
import splitfolders
import shutil
import glob
import numpy as np
from sklearn.model_selection import train_test_split

In [61]:
class KpVelDataset(Dataset):
    def __init__(self, json_folder):
        super(KpVelDataset, self).__init__()
        self.data = []
        for json_file in sorted(os.listdir(json_folder)):
            if json_file.endswith('_combined.json'):
                with open(os.path.join(json_folder, json_file), 'r') as file:
                    data = json.load(file)
                    start_kp = data['start_kp']
                    next_kp = data['next_kp']
                    velocity = data['velocity']
                    self.data.append((start_kp, next_kp, velocity))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        start_kp, next_kp, velocity = self.data[idx]
        start_kp_flat = torch.tensor([kp for sublist in start_kp for kp in sublist[0]])
        next_kp_flat = torch.tensor([kp for sublist in next_kp for kp in sublist[0]])
        velocity = torch.tensor(velocity)
        return start_kp_flat, next_kp_flat, velocity
    

In [62]:
def train_test_split(src_dir):
#     dst_dir_img = src_dir + "images"
    dst_dir_anno = src_dir + "annotations"
    
    if os.path.exists(dst_dir_anno):
        print("folders exist")
    else:
        os.mkdir(dst_dir_anno)
        
#     for jpgfile in glob.iglob(os.path.join(src_dir, "*.jpg")):
#         shutil.copy(jpgfile, dst_dir_img)

    for jsonfile in glob.iglob(os.path.join(src_dir, "*_combined.json")):
        shutil.copy(jsonfile, dst_dir_anno)
        
    output = root_dir + "split_folder_reg"
    
    splitfolders.ratio(src_dir, # The location of dataset
                   output=output, # The output location
                   seed=42, # The number of seed
                   ratio=(0.7, 0.15, 0.15), # The ratio of split dataset
                   group_prefix=None, # If your dataset contains more than one file like ".jpg", ".pdf", etc
                   move=False # If you choose to move, turn this into True
                   )
    
#     shutil.rmtree(dst_dir_img)
    shutil.rmtree(dst_dir_anno)
    
    return output  

In [63]:
# class VelRegModel(nn.Module):
#     def __init__(self, input_size):
#         super(VelRegModel, self).__init__()
#         self.fc1 = nn.Linear(input_size * 2, 1024)  # Assuming start_kp and next_kp are concatenated
#         self.fc2 = nn.Linear(1024, 512)
#         self.fc3 = nn.Linear(512,256)
#         self.fc4 = nn.Linear(256,128)
#         self.fc5 = nn.Linear(128,64)
#         self.fc6 = nn.Linear(64,64)
#         self.fc7 = nn.Linear(64,6)  # Output size is 3 for velocity

#     def forward(self, start_kp, velocity):
#         x = torch.cat((start_kp, velocity), dim=1)
#         x = func.relu(self.fc1(x))
#         x = func.relu(self.fc2(x))
#         x = func.relu(self.fc3(x))
#         x = func.relu(self.fc4(x))
#         x = func.relu(self.fc5(x))
#         x = func.relu(self.fc6(x))
#         x = self.fc7(x)
#         return x

In [72]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class KeypointRegressionNet(nn.Module):
    def __init__(self):
        super(KeypointRegressionNet, self).__init__()
        # Define the architecture
        self.fc1 = nn.Linear(21, 128)  # 18 keypoints + 3 velocity values
        self.fc2 = nn.Linear(128, 256)
        self.fc3 = nn.Linear(256, 128)
        self.fc4 = nn.Linear(128, 18)  # Output size is 18 (6 keypoints * 3 values each)

    def forward(self, start_kp, velocity):
        # Flatten start keypoints and concatenate with velocity
        x = torch.cat((start_kp, velocity), dim=1)
        
        # Forward pass through the network
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.fc4(x)  # No activation function in the last layer
        return x

In [70]:
# Initialize dataset and data loader
# to generalize home directory. User can change their parent path without entering their home directory
num_epochs = 250
batch_size = 128
v = 3
root_dir = '/home/jc-merlab/Pictures/panda_data/panda_sim_vel/vel_reg_sim_test/'
print(root_dir)
split_folder_path = train_test_split(root_dir)
dataset = KpVelDataset(root_dir)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Initialize model
model = KeypointRegressionNet()  # Adjust input_size as necessary
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.005)

# Training loop
for epoch in range(num_epochs):
    for start_kp, next_kp, velocity in data_loader:
        optimizer.zero_grad()
        velocity = velocity.squeeze(1)
        print(start_kp.shape)
        print(velocity.shape)
        output = model(start_kp, velocity)
        loss = criterion(output, next_kp)
        loss.backward()
        optimizer.step()
        print("output", output)
    print(f'Epoch {epoch+1}, Loss: {loss.item()}')
    
# Save the trained model
model_save_path = f'/home/jc-merlab/Pictures/Data/trained_models/reg_nkp_b{batch_size}_e{num_epochs}_v{v}.pth'
torch.save(model.state_dict(), model_save_path)
print(f"Model saved to {model_save_path}")


/home/jc-merlab/Pictures/panda_data/panda_sim_vel/vel_reg_sim_test/


Copying files: 3132 files [00:00, 18637.30 files/s]


torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[-17.2105,   8.0819,  -3.7918,  ...,  -0.3458,   9.0517,   9.5128],
        [ -9.7912,   7.3680,  -5.9000,  ...,  -1.2536,   8.7140,   4.1146],
        [-14.5960,   6.1304,  -3.8637,  ...,  -0.3554,   7.0163,   7.2547],
        ...,
        [-13.0491,   5.2479,  -3.7508,  ...,   0.1268,   5.1856,   6.8825],
        [ -9.6725,   7.1833,  -5.8597,  ...,  -0.9611,   8.2416,   3.8241],
        [-14.0880,   5.4400,  -3.6499,  ...,  -0.0720,   6.3218,   7.1921]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[ 80.7178, 172.8227,  65.7700,  ...,  95.0945,  93.8538,  16.8076],
        [ 99.2576, 216.3378,  84.2275,  ..., 117.8697, 118.1428,  22.9346],
        [ 79.3023, 167.1315,  62.3937,  ...,  86.6561,  89.1970,  13.0872],
        ...,
        [ 72.5383, 155.4098,  59.1620,  ...,  84.2514,  83.3890,  13.5605],
        [ 73.9411, 158.8500,  60.8121,  ...,  87.1790,  85.3799,  14.6534],
    

output tensor([[241.8555, 333.4318,   5.4125,  ..., 266.0714, 171.3859,   1.8070],
        [289.1658, 395.1394,   8.6299,  ..., 445.3755, 287.6058,  -4.2604],
        [245.2279, 336.1860,   5.9110,  ..., 305.9469, 175.1573,  -1.2418],
        ...,
        [268.8633, 386.7881,  -7.3931,  ...,  80.6208, 211.1143,   4.0283],
        [250.3639, 340.0231,   6.5317,  ..., 381.9652, 210.6390,  -4.6008],
        [260.9583, 369.6223,   0.7599,  ..., 161.8573, 188.0839,   7.9722]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[252.1105, 353.1423,   5.9937,  ..., 263.0735, 212.4309,   1.8475],
        [242.1610, 337.2092,   4.5602,  ..., 306.2841, 178.6939,  -2.8553],
        [266.9996, 371.8618,   4.8455,  ..., 433.0881, 274.7920,  -7.3946],
        ...,
        [265.1457, 370.6563,   6.5815,  ..., 384.9132, 273.6104,  -3.8414],
        [251.5038, 351.7184,   5.6113,  ..., 241.4981, 189.1081,   2.4739],
        [247.9777, 347.2054,   5.7056,  ..., 31

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[ 2.6863e+02,  3.8331e+02, -1.9355e+00,  ...,  4.4001e+02,
          2.5193e+02, -1.9937e+00],
        [ 2.4995e+02,  3.5879e+02, -3.1307e+00,  ...,  3.9387e+02,
          2.9976e+02, -5.7590e-01],
        [ 2.5606e+02,  3.6600e+02, -2.9653e+00,  ...,  3.5946e+02,
          3.4669e+02,  1.3017e+00],
        ...,
        [ 2.5374e+02,  3.6175e+02, -6.3663e-01,  ...,  3.1540e+02,
          3.2973e+02,  1.5884e+00],
        [ 2.5520e+02,  3.6304e+02, -2.9487e-01,  ...,  3.8127e+02,
          1.4249e+02, -6.8753e-01],
        [ 2.5059e+02,  3.5887e+02, -2.7168e+00,  ...,  3.6462e+02,
          3.1934e+02,  4.7737e-01]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[ 2.5649e+02,  3.6904e+02,  1.0801e+00,  ...,  2.9155e+02,
          1.0218e+02, -2.2506e+00],
        [ 2.6290e+02,  3.7462e+02, -2.2180e+00,  ...,  1.0554e+02,
          2.0012e+02,  1.7199e+00],
        [ 2.6685e+02,  3.8382e+02, -

output tensor([[ 2.5571e+02,  3.6084e+02,  1.3360e+00,  ...,  7.8079e+01,
          1.8867e+02,  2.1922e+00],
        [ 2.5623e+02,  3.6119e+02,  1.1200e+00,  ...,  8.1132e+01,
          1.8052e+02,  1.9106e+00],
        [ 2.5901e+02,  3.6780e+02,  2.3813e-01,  ...,  4.2060e+02,
          1.9171e+02,  2.2393e+00],
        ...,
        [ 2.5813e+02,  3.7091e+02,  8.8624e-01,  ...,  2.2403e+02,
          9.3906e+01,  5.5937e-01],
        [ 2.5563e+02,  3.6526e+02,  1.2772e+00,  ...,  3.5267e+02,
          1.3370e+02,  3.6402e+00],
        [ 2.5494e+02,  3.6511e+02, -8.3479e-01,  ...,  2.2197e+02,
          1.2691e+02,  2.8720e+00]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.5947e+02, 3.7201e+02, 5.9857e-01,  ..., 3.5584e+02, 3.4654e+02,
         1.9252e+00],
        [2.5617e+02, 3.6899e+02, 2.1680e+00,  ..., 3.0618e+02, 3.3551e+02,
         2.8407e+00],
        [2.6544e+02, 3.7340e+02, 3.2950e-01,  ..., 2.1570e+02, 1.7723e+02,
         1.8181e

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[ 2.5127e+02,  3.5814e+02,  1.0870e+00,  ...,  2.9181e+02,
          1.1049e+02,  7.6704e-01],
        [ 2.5106e+02,  3.5701e+02, -2.6834e-02,  ...,  2.3675e+02,
          1.3483e+02,  2.2241e+00],
        [ 2.5184e+02,  3.5713e+02, -2.3426e-01,  ...,  4.3228e+02,
          2.5501e+02,  4.1557e-01],
        ...,
        [ 2.5596e+02,  3.5924e+02, -1.8996e-01,  ...,  9.3197e+01,
          2.1997e+02,  1.8650e+00],
        [ 2.5489e+02,  3.5875e+02,  8.3644e-01,  ...,  2.1340e+02,
          1.8164e+02,  2.4008e+00],
        [ 2.5054e+02,  3.5432e+02,  7.5037e-01,  ...,  3.0393e+02,
          2.3987e+02,  2.0620e+00]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[266.6917, 374.2115,   2.8624,  ..., 300.1430, 336.2789,   4.3936],
        [266.1143, 373.6827,   2.4657,  ..., 303.6582, 331.4314,   4.3102],
        [271.0101, 375.6594,   1.0636,  ..., 231.5595, 189.0483,   4.3194],
        ...,
 

output tensor([[ 2.5874e+02,  3.6796e+02,  3.8851e-01,  ...,  3.7534e+02,
          1.5253e+02,  1.0697e+00],
        [ 2.5784e+02,  3.6845e+02,  1.8789e+00,  ...,  2.9692e+02,
          3.2747e+02,  8.1518e-01],
        [ 2.5875e+02,  3.6608e+02,  1.1426e+00,  ...,  3.4246e+02,
          2.3878e+02, -8.8819e-02],
        ...,
        [ 2.5696e+02,  3.6721e+02,  1.0540e+00,  ...,  3.5801e+02,
          2.6424e+02,  2.9942e-01],
        [ 2.5950e+02,  3.6763e+02,  1.8935e+00,  ...,  4.1842e+02,
          1.6417e+02,  1.4460e+00],
        [ 2.5879e+02,  3.6588e+02,  2.2418e+00,  ...,  4.5595e+02,
          2.9622e+02, -5.8571e-01]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[256.9162, 364.4837,   1.4159,  ...,  75.9882, 197.3427,   1.6974],
        [259.3482, 367.6555,   2.1356,  ..., 249.0658, 194.2431,   1.3722],
        [257.2754, 364.6531,   1.4358,  ..., 112.4382, 148.9080,   1.4053],
        ...,
        [257.9616, 368.0953,   1.2636,  ...,

output tensor([[259.3025, 366.5246,   0.7763,  ..., 420.0915, 210.8154,   1.2752],
        [258.4796, 367.9073,   1.1357,  ..., 255.1131, 199.9107,   1.0246],
        [256.8072, 365.1580,   1.7597,  ..., 166.9987, 122.4467,   2.0963],
        ...,
        [258.8783, 367.1309,   1.2890,  ..., 158.6913, 176.2845,   0.6975],
        [257.9580, 367.5547,   0.5381,  ..., 400.0133, 351.9629,   1.4247],
        [257.8705, 366.5475,   1.4303,  ..., 192.8588, 176.2970,   1.3009]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.5860e+02, 3.6605e+02, 4.7862e-01,  ..., 4.1972e+02, 2.3815e+02,
         1.1365e+00],
        [2.5751e+02, 3.6609e+02, 1.0141e+00,  ..., 1.6814e+02, 1.7510e+02,
         1.2782e+00],
        [2.5779e+02, 3.6749e+02, 1.0564e+00,  ..., 3.0024e+02, 3.4555e+02,
         1.7176e+00],
        ...,
        [2.5697e+02, 3.6602e+02, 6.1835e-01,  ..., 4.1160e+02, 3.4980e+02,
         1.4956e+00],
        [2.5829e+02, 3.6610e+02, 3.1492

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[256.3282, 364.1674,   2.0459,  ..., 145.2267, 138.4464,   1.6387],
        [256.0536, 365.5984,   1.3784,  ..., 276.1525, 101.1484,   0.8676],
        [257.3296, 365.5280,   0.9461,  ..., 362.1081, 267.9171,   0.8325],
        ...,
        [256.3874, 365.3700,   0.9055,  ..., 307.2589, 214.4751,   1.3134],
        [256.7987, 366.1505,   1.8248,  ..., 146.2207, 165.8270,   0.9100],
        [257.1839, 366.1047,   1.6544,  ..., 126.7574, 171.8082,   0.6702]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[ 2.5951e+02,  3.6999e+02,  8.0288e-01,  ...,  2.4254e+02,
          9.7060e+01, -4.8981e-01],
        [ 2.5837e+02,  3.6776e+02,  6.8198e-01,  ...,  3.1726e+02,
          3.0653e+02,  5.1950e-01],
        [ 2.5762e+02,  3.6856e+02,  1.4639e+00,  ...,  2.9171e+02,
          1.0791e+02,  9.3187e-01],
        ...,
        [ 2.5823e+02,  3.6671e+02,  8.3818e-01,  ...,  2.0566e+02,
        

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.5830e+02, 3.6783e+02, 9.7890e-01,  ..., 3.6280e+02, 1.4588e+02,
         3.8648e-01],
        [2.5844e+02, 3.6785e+02, 1.0997e+00,  ..., 7.8568e+01, 2.4324e+02,
         8.6933e-01],
        [2.5856e+02, 3.6790e+02, 8.9252e-01,  ..., 3.7679e+02, 1.5785e+02,
         2.3648e-01],
        ...,
        [2.5925e+02, 3.6958e+02, 1.4580e+00,  ..., 2.8538e+02, 1.1327e+02,
         8.8665e-01],
        [2.5870e+02, 3.6878e+02, 1.2618e+00,  ..., 3.9354e+02, 1.3818e+02,
         8.3155e-01],
        [2.5887e+02, 3.6862e+02, 1.6015e+00,  ..., 1.6881e+02, 1.1182e+02,
         8.1083e-01]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.5721e+02, 3.6530e+02, 7.7575e-01,  ..., 3.5934e+02, 3.1834e+02,
         1.4083e+00],
        [2.5661e+02, 3.6495e+02, 8.1997e-01,  ..., 4.1244e+02, 3.4630e+02,
         1.3819e+00],
        [2.5552e+02, 3.6355e+02, 1.8407e-01,  ..., 7.3902e+01, 2.5088e+02,
         

output tensor([[257.4262, 366.8640,   1.2498,  ..., 110.3271, 198.2941,   0.7244],
        [257.0942, 366.1320,   1.2841,  ..., 274.1754,  85.8241,   0.7406],
        [256.4929, 364.1742,   1.3499,  ..., 425.2547, 244.1153,   1.6086],
        ...,
        [257.1940, 365.4816,   1.3745,  ..., 302.1988, 224.2452,   1.2643],
        [256.3176, 364.5029,   1.6234,  ..., 457.6174, 292.4843,   0.8680],
        [257.3883, 366.1951,   0.8632,  ..., 316.1973, 335.5130,   0.7350]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[259.3478, 368.8911,   1.2941,  ..., 301.4836, 323.7529,   0.8079],
        [259.0495, 367.9417,   1.6767,  ..., 309.0816, 245.4622,   1.2994],
        [258.4261, 368.0709,   1.8748,  ..., 326.2169, 135.6388,   1.3368],
        ...,
        [258.7669, 368.1168,   1.0270,  ...,  89.0964, 227.1856,   0.8918],
        [258.9086, 368.1252,   1.2120,  ..., 390.5671, 147.7724,   0.8821],
        [259.1444, 367.7762,   1.6034,  ..., 41

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[257.8420, 367.0706,   1.1050,  ..., 394.1466, 199.6666,   0.7332],
        [258.1332, 367.6654,   1.2409,  ..., 378.0142, 262.6999,   0.8082],
        [257.9737, 367.4402,   1.5799,  ...,  58.0558, 288.6902,   0.5533],
        ...,
        [257.3323, 366.4922,   1.3313,  ...,  81.1695, 225.5526,   1.1485],
        [258.1490, 366.8982,   1.4077,  ..., 236.9584, 132.8208,   1.5832],
        [257.8889, 366.5050,   1.1071,  ..., 427.2288, 214.6979,   1.0920]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.5800e+02, 3.6757e+02, 9.2703e-01,  ..., 2.1727e+02, 1.7347e+02,
         4.1260e-01],
        [2.5723e+02, 3.6695e+02, 8.3051e-01,  ..., 3.3408e+02, 1.1643e+02,
         8.1128e-01],
        [2.5843e+02, 3.6787e+02, 1.2725e+00,  ..., 9.1687e+01, 1.9805e+02,
         6.6587e-01],
        ...,
        [2.5812e+02, 3.6707e+02, 1.1420e+00,  ..., 2.9993e+02, 1.8712e+02,
         1.0973e+0

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[261.0862, 371.3140,   1.2344,  ..., 299.1771, 335.4663,   0.8799],
        [261.5456, 371.9842,   1.4267,  ..., 364.7499, 269.5180,   0.8355],
        [261.5224, 371.9408,   1.3949,  ..., 345.2416, 241.8883,   0.8088],
        ...,
        [261.4499, 371.2220,   1.1631,  ..., 252.0872, 146.7852,   1.5113],
        [260.6780, 371.6749,   1.0200,  ..., 141.2271, 168.1699,   0.7532],
        [261.2519, 371.2176,   1.3912,  ..., 263.0809, 106.9246,   1.0402]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[254.8915, 362.5087,   1.0207,  ..., 377.8952, 215.7033,   0.9888],
        [255.9364, 364.3471,   1.3192,  ..., 171.4739, 114.5762,   0.7860],
        [255.9710, 364.2101,   1.6868,  ..., 111.1566, 146.7214,   1.7967],
        ...,
        [255.1939, 363.2451,   1.0618,  ..., 344.1856, 218.0946,   0.7944],
        [254.5429, 361.8326,   0.9690,  ..., 420.6039, 276.5018,   1.5332],
    

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[258.4005, 366.8682,   0.9898,  ..., 420.0239, 212.1777,   1.2525],
        [257.8837, 367.2172,   1.2592,  ..., 382.4983, 343.6003,   1.2035],
        [258.4963, 367.8636,   1.0530,  ..., 344.2221, 345.4112,   0.9743],
        ...,
        [258.5799, 367.2492,   1.0707,  ...,  88.2445, 189.5333,   1.4604],
        [258.1815, 367.4009,   0.9471,  ..., 406.1094, 350.7847,   1.2281],
        [257.6741, 366.6779,   0.8393,  ..., 351.1253, 131.2600,   0.6772]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[258.5856, 367.1196,   0.9856,  ..., 178.6612, 100.8925,   0.6364],
        [257.9474, 367.1971,   1.1810,  ..., 317.7581, 259.3060,   0.9053],
        [257.7975, 366.4336,   1.1047,  ..., 409.7416, 154.7227,   0.9495],
        ...,
        [257.9799, 367.3925,   1.1196,  ..., 148.6050, 166.9655,   0.7383],
        [257.3958, 366.6135,   0.7914,  ...,  74.7917, 266.2619,   1.0987],
    

output tensor([[ 2.6138e+02,  3.7242e+02, -6.0152e-01,  ...,  2.9247e+02,
          1.1141e+02,  1.9842e+00],
        [ 2.6000e+02,  3.7025e+02, -1.4529e+00,  ...,  3.7227e+02,
          2.7484e+02,  1.4888e+00],
        [ 2.6137e+02,  3.7199e+02, -1.8339e+00,  ...,  4.2400e+02,
          2.4516e+02,  1.9069e+00],
        ...,
        [ 2.5797e+02,  3.6659e+02, -4.1706e-01,  ...,  8.4447e+01,
          1.9153e+02,  1.9011e+00],
        [ 2.5809e+02,  3.6663e+02, -7.1506e-01,  ...,  8.0902e+01,
          2.3676e+02,  5.8088e-01],
        [ 2.5998e+02,  3.7023e+02, -1.4499e+00,  ...,  3.7172e+02,
          2.7492e+02,  1.4923e+00]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[ 2.5485e+02,  3.6310e+02, -3.7767e-01,  ...,  3.0828e+02,
          3.3763e+02, -4.4880e-01],
        [ 2.5488e+02,  3.6307e+02, -3.6174e-01,  ...,  3.0777e+02,
          3.3367e+02, -5.4126e-01],
        [ 2.5479e+02,  3.6298e+02, -5.8654e-01,  ...,  3.6387e+02,
          3.

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[258.6323, 368.6615,   1.1059,  ..., 349.6877, 220.7123,   0.5998],
        [257.8846, 367.2959,   1.2489,  ..., 234.8780, 100.6621,   0.9601],
        [258.3000, 367.4481,   1.1875,  ..., 444.8029, 247.0572,   0.7375],
        ...,
        [258.0285, 367.3871,   1.2234,  ..., 309.1972, 208.9349,   1.0697],
        [258.2463, 366.8857,   0.8177,  ..., 174.4086, 134.0387,   1.2139],
        [258.6017, 367.6827,   0.8320,  ..., 132.4295, 151.2502,   1.2606]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.5586e+02, 3.6401e+02, 2.9790e-01,  ..., 7.1887e+01, 2.9229e+02,
         2.6259e-01],
        [2.5681e+02, 3.6507e+02, 5.5625e-01,  ..., 3.6842e+02, 3.4807e+02,
         1.0493e+00],
        [2.5738e+02, 3.6560e+02, 7.6881e-01,  ..., 1.1459e+02, 1.4749e+02,
         1.5073e+00],
        ...,
        [2.5705e+02, 3.6525e+02, 5.5425e-01,  ..., 8.3449e+01, 1.9224e+02,
         1.4189e+0

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.5829e+02, 3.6779e+02, 1.1531e+00,  ..., 3.4096e+02, 2.5071e+02,
         6.2455e-01],
        [2.5868e+02, 3.6780e+02, 1.2059e+00,  ..., 3.1756e+02, 3.4585e+02,
         1.1394e+00],
        [2.5832e+02, 3.6764e+02, 1.5602e+00,  ..., 4.5882e+02, 2.6288e+02,
         8.6145e-01],
        ...,
        [2.5883e+02, 3.6849e+02, 1.0266e+00,  ..., 3.8135e+02, 2.6525e+02,
         7.1791e-01],
        [2.5788e+02, 3.6712e+02, 9.2175e-01,  ..., 3.8124e+02, 1.4838e+02,
         4.0298e-01],
        [2.5817e+02, 3.6682e+02, 1.1257e+00,  ..., 4.2544e+02, 2.2737e+02,
         1.2658e+00]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[256.5284, 365.0699,   1.3006,  ..., 400.6638, 293.6173,   1.5770],
        [257.3252, 365.3889,   1.3737,  ..., 448.1155, 232.3089,   1.2741],
        [256.4820, 365.2922,   1.1009,  ..., 362.8048, 260.2606,   1.0257],
        ...,
        [256.8735, 364.9688,   1.0594

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.5287e+02, 3.5935e+02, 8.5393e-01,  ..., 6.0166e+01, 3.0754e+02,
         7.2849e-01],
        [2.5311e+02, 3.5986e+02, 3.8382e-01,  ..., 3.7782e+02, 2.1508e+02,
         9.5846e-01],
        [2.5261e+02, 3.5939e+02, 7.2872e-01,  ..., 3.8084e+02, 3.4319e+02,
         1.4966e+00],
        ...,
        [2.5301e+02, 3.5996e+02, 8.8511e-02,  ..., 1.7479e+02, 1.7303e+02,
         9.9452e-01],
        [2.5282e+02, 3.5976e+02, 8.1413e-01,  ..., 4.4153e+02, 3.0414e+02,
         1.3250e+00],
        [2.5320e+02, 3.5965e+02, 5.2789e-01,  ..., 4.1883e+02, 2.2243e+02,
         1.5441e+00]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.6305e+02, 3.7491e+02, 1.7737e+00,  ..., 3.4082e+02, 1.0935e+02,
         7.3179e-01],
        [2.6345e+02, 3.7420e+02, 1.7648e+00,  ..., 1.3317e+02, 1.5287e+02,
         1.0478e+00],
        [2.6380e+02, 3.7484e+02, 1.5269e+00,  ..., 4.0117e+02, 1.8319e+02,
         

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.5718e+02, 3.6694e+02, 1.5555e+00,  ..., 3.0194e+02, 3.3689e+02,
         8.7394e-01],
        [2.5844e+02, 3.6800e+02, 1.2211e+00,  ..., 3.8761e+02, 1.5897e+02,
         2.6856e-01],
        [2.5892e+02, 3.6927e+02, 1.3605e+00,  ..., 3.3757e+02, 1.1260e+02,
         7.1559e-01],
        ...,
        [2.5831e+02, 3.6781e+02, 1.4032e+00,  ..., 3.7076e+02, 2.7145e+02,
         7.7185e-01],
        [2.5835e+02, 3.6804e+02, 1.6428e+00,  ..., 3.2948e+02, 2.6106e+02,
         7.0513e-01],
        [2.5856e+02, 3.6790e+02, 1.3231e+00,  ..., 3.8179e+02, 2.2607e+02,
         6.6425e-01]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.5655e+02, 3.6600e+02, 7.8847e-01,  ..., 4.1384e+02, 3.4899e+02,
         1.4936e+00],
        [2.5600e+02, 3.6510e+02, 4.5683e-01,  ..., 7.3301e+01, 2.7175e+02,
         1.3007e+00],
        [2.5689e+02, 3.6661e+02, 1.8282e-01,  ..., 1.9254e+02, 1.7565e+02,
         

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.6398e+02, 3.7681e+02, 2.2155e+00,  ..., 2.9192e+02, 1.0990e+02,
         9.6024e-01],
        [2.6368e+02, 3.7470e+02, 1.6156e+00,  ..., 3.8337e+02, 3.1621e+02,
         1.2443e+00],
        [2.6730e+02, 3.7952e+02, 1.3739e+00,  ..., 4.2481e+02, 3.5308e+02,
         7.1542e-01],
        ...,
        [2.6678e+02, 3.7885e+02, 7.9833e-01,  ..., 2.0822e+02, 1.8483e+02,
         1.2282e-01],
        [2.6609e+02, 3.7808e+02, 1.7583e+00,  ..., 2.1936e+02, 1.2830e+02,
         1.0880e+00],
        [2.6637e+02, 3.7770e+02, 9.2750e-01,  ..., 4.0693e+02, 3.5656e+02,
         6.5621e-01]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.5546e+02, 3.6380e+02, 1.5971e+00,  ..., 2.5784e+02, 2.0342e+02,
         3.8989e-01],
        [2.5166e+02, 3.5833e+02, 9.5747e-01,  ..., 3.1264e+02, 3.0013e+02,
         1.1587e+00],
        [2.5596e+02, 3.6401e+02, 1.6184e+00,  ..., 2.4114e+02, 1.4431e+02,
         

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.6024e+02, 3.7034e+02, 5.4038e-01,  ..., 4.1544e+02, 2.9712e+02,
         7.1792e-01],
        [2.5962e+02, 3.6974e+02, 2.1950e-01,  ..., 3.8190e+02, 1.5368e+02,
         2.4978e-01],
        [2.5978e+02, 3.6962e+02, 8.1416e-01,  ..., 1.7072e+02, 1.2225e+02,
         1.1257e+00],
        ...,
        [2.5992e+02, 3.7009e+02, 7.7440e-01,  ..., 3.5191e+02, 2.2185e+02,
         4.2146e-01],
        [2.5949e+02, 3.6906e+02, 4.6955e-01,  ..., 1.6394e+02, 1.3795e+02,
         9.3874e-01],
        [2.5935e+02, 3.6988e+02, 4.8071e-01,  ..., 2.9838e+02, 3.3281e+02,
         6.0749e-01]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.5522e+02, 3.6421e+02, 5.1871e-01,  ..., 3.0443e+02, 3.0945e+02,
         8.9017e-01],
        [2.5472e+02, 3.6382e+02, 4.3329e-01,  ..., 2.9184e+02, 3.2353e+02,
         8.6099e-01],
        [2.5476e+02, 3.6280e+02, 4.5399e-01,  ..., 4.1992e+02, 2.3148e+02,
         

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[254.2728, 362.2202,   1.5007,  ..., 274.3150, 158.0609,   1.5739],
        [254.4860, 362.2787,   1.1497,  ..., 415.1624, 188.6371,   1.3229],
        [253.9469, 362.3694,   1.2468,  ..., 338.0195, 233.0414,   0.9057],
        ...,
        [254.1718, 362.5578,   1.2183,  ..., 341.7372, 201.3492,   0.8928],
        [253.4204, 362.2695,   1.5171,  ..., 286.1153,  85.8789,   1.2692],
        [254.3076, 362.9398,   1.6280,  ..., 446.7957, 293.8568,   1.1470]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[261.1972, 373.1641,   1.7703,  ..., 321.5427, 123.6559,   1.2168],
        [261.8117, 372.4528,   0.8835,  ..., 392.4744, 215.2090,   0.5961],
        [261.7873, 371.7158,   1.4275,  ..., 104.0163, 160.8308,   1.4543],
        ...,
        [263.5529, 374.7027,   1.3866,  ..., 460.5457, 246.7828,   0.7156],
        [262.5329, 373.3765,   1.3786,  ..., 391.4316, 133.5540,   0.9807],
    

output tensor([[ 2.6319e+02,  3.7364e+02,  1.4560e+00,  ...,  3.1434e+02,
          2.1749e+02,  4.2964e-01],
        [ 2.6446e+02,  3.7572e+02,  8.0948e-01,  ...,  8.3135e+01,
          2.5092e+02,  9.0237e-02],
        [ 2.6439e+02,  3.7545e+02,  1.4608e+00,  ...,  1.9343e+02,
          1.2405e+02,  5.9135e-01],
        ...,
        [ 2.6115e+02,  3.6996e+02,  1.3480e+00,  ...,  3.7422e+02,
          2.7379e+02,  2.9737e-01],
        [ 2.6518e+02,  3.7698e+02,  1.2969e+00,  ...,  2.8247e+02,
          9.3913e+01, -1.1309e-01],
        [ 2.6599e+02,  3.7579e+02,  1.6504e+00,  ...,  6.9197e+01,
          3.0741e+02,  5.5387e-01]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[ 2.5829e+02,  3.6741e+02,  1.5919e+00,  ...,  1.7900e+02,
          1.2155e+02,  9.4848e-01],
        [ 2.5973e+02,  3.6858e+02,  5.5142e-01,  ...,  1.7456e+02,
          1.0404e+02, -6.8079e-02],
        [ 2.5321e+02,  3.5974e+02,  1.2345e+00,  ...,  3.7107e+02,
          2.

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.5745e+02, 3.6641e+02, 9.0639e-01,  ..., 3.0729e+02, 3.0757e+02,
         1.1680e+00],
        [2.5739e+02, 3.6584e+02, 3.9790e-01,  ..., 2.8143e+02, 9.7544e+01,
         8.5301e-01],
        [2.5772e+02, 3.6664e+02, 8.4067e-01,  ..., 3.1950e+02, 3.0435e+02,
         1.0304e+00],
        ...,
        [2.5775e+02, 3.6677e+02, 7.9436e-01,  ..., 3.4431e+02, 2.0039e+02,
         9.6648e-01],
        [2.5724e+02, 3.6610e+02, 8.3909e-01,  ..., 4.0644e+02, 2.9387e+02,
         1.2977e+00],
        [2.5733e+02, 3.6582e+02, 9.3004e-01,  ..., 4.4935e+02, 2.3677e+02,
         9.3549e-01]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.5846e+02, 3.6951e+02, 7.5229e-01,  ..., 1.5298e+02, 1.7032e+02,
         5.5722e-01],
        [2.5749e+02, 3.6658e+02, 5.2373e-01,  ..., 2.9806e+02, 3.3717e+02,
         8.9604e-01],
        [2.5780e+02, 3.6637e+02, 8.9416e-01,  ..., 4.3559e+02, 3.1102e+02,
         

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[252.5221, 360.3967,   1.5179,  ..., 372.8143, 309.3342,   1.6093],
        [255.6198, 362.9707,   1.0284,  ..., 394.5927, 142.9047,   1.6229],
        [254.9846, 360.7414,   0.8825,  ..., 433.3736, 246.6939,   1.4974],
        ...,
        [254.1005, 361.7626,   0.9832,  ..., 271.4795,  84.1616,   1.1094],
        [254.1388, 361.1095,   1.1144,  ...,  66.2344, 294.0549,   1.0048],
        [255.1971, 361.4680,   0.5133,  ..., 157.2056, 110.9754,   0.6810]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[258.9278, 367.8914,   0.6543,  ..., 211.7872, 175.8756,   0.6142],
        [259.0363, 368.0614,   0.8777,  ..., 235.9297, 184.5284,   0.5989],
        [259.4496, 369.2207,   1.2616,  ..., 348.2516, 221.4593,   0.5186],
        ...,
        [259.5354, 369.3347,   1.2906,  ..., 274.9838,  87.9501,   0.7128],
        [260.0267, 369.1902,   0.6442,  ..., 359.5538, 143.7105,   0.7042],
    

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[260.9356, 371.4520,   1.3584,  ..., 398.3302, 180.2703,   0.6766],
        [260.1195, 370.4664,   1.7889,  ..., 218.1165, 120.9453,   1.6714],
        [261.0570, 370.5640,   1.4687,  ..., 337.2247, 209.4621,   1.2102],
        ...,
        [259.9947, 369.0035,   0.9076,  ..., 234.4027, 102.2129,   0.7467],
        [260.3332, 370.4982,   1.6880,  ..., 204.6122, 118.8299,   1.5843],
        [261.7513, 370.6765,   1.1074,  ..., 391.7061, 154.0158,   1.1626]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[253.9862, 361.5656,   0.9803,  ..., 375.5750, 217.8992,   1.0589],
        [252.6792, 360.9889,   0.8157,  ...,  72.8896, 254.6311,   0.8615],
        [253.2516, 362.0464,   1.2015,  ..., 100.2396, 194.9252,   0.9723],
        ...,
        [254.5236, 362.9176,   1.3483,  ..., 417.3569, 170.8626,   1.0328],
        [253.3754, 361.7482,   0.7573,  ..., 233.4622, 183.4643,   0.8751],
    

output tensor([[2.5131e+02, 3.5693e+02, 2.2687e-01,  ..., 5.3428e+01, 2.6487e+02,
         1.5660e+00],
        [2.5269e+02, 3.6144e+02, 6.0676e-01,  ..., 3.4638e+02, 3.2867e+02,
         1.3599e+00],
        [2.5197e+02, 3.5805e+02, 1.0675e-01,  ..., 5.7350e+01, 2.4760e+02,
         1.5167e+00],
        ...,
        [2.6014e+02, 3.6944e+02, 1.6722e+00,  ..., 4.0214e+02, 1.4379e+02,
         1.4898e+00],
        [2.5784e+02, 3.6775e+02, 7.3182e-01,  ..., 4.3789e+02, 2.2936e+02,
         6.5152e-01],
        [2.5561e+02, 3.6452e+02, 8.8906e-01,  ..., 1.8581e+02, 1.1661e+02,
         1.5917e+00]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[ 2.6614e+02,  3.7746e+02,  1.2525e+00,  ...,  3.2922e+02,
          1.5460e+02,  7.9344e-01],
        [ 2.6415e+02,  3.7454e+02,  1.2848e+00,  ...,  3.4210e+02,
          2.5786e+02,  7.1507e-01],
        [ 2.6525e+02,  3.7634e+02, -1.3306e-01,  ...,  2.2066e+02,
          1.8736e+02,  5.9519e-01],
        ...,

output tensor([[259.7191, 369.9308,   0.9970,  ..., 106.5498, 202.7359,   1.1118],
        [258.8364, 369.2770,   1.3226,  ..., 144.5106, 170.6797,   1.3294],
        [260.4663, 370.3339,   1.3261,  ..., 349.2419, 209.6982,   1.1913],
        ...,
        [259.2455, 368.0858,   0.9063,  ..., 141.6729, 149.8898,   1.3397],
        [260.7814, 370.4457,   1.3642,  ..., 299.5516, 338.2704,   1.7947],
        [261.4111, 371.3916,   1.1430,  ..., 341.4269, 340.2173,   1.6880]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[255.7848, 364.9229,   1.4604,  ..., 372.3740, 347.6946,   1.9989],
        [255.9268, 364.2489,   0.4668,  ..., 231.0078,  98.4108,   0.5146],
        [255.7099, 365.8561,   1.2465,  ..., 378.2245, 141.1809,   1.3917],
        ...,
        [255.2281, 363.9915,   1.4264,  ..., 281.9042, 166.0042,   1.3088],
        [255.8042, 365.1028,   1.3767,  ..., 347.7156, 191.2236,   0.8853],
        [256.4587, 364.7106,   1.2789,  ...,  8

output tensor([[2.6045e+02, 3.7076e+02, 1.6787e+00,  ..., 2.6830e+02, 1.1026e+02,
         4.7386e-01],
        [2.6017e+02, 3.7110e+02, 1.3702e+00,  ..., 3.1938e+02, 3.3787e+02,
         3.1161e-01],
        [2.6064e+02, 3.7134e+02, 1.5816e+00,  ..., 3.3806e+02, 1.1137e+02,
         2.3895e-01],
        ...,
        [2.6116e+02, 3.7031e+02, 1.1996e+00,  ..., 3.9159e+02, 1.5230e+02,
         2.1047e-01],
        [2.6171e+02, 3.7078e+02, 1.4100e+00,  ..., 6.1488e+01, 2.5892e+02,
         8.6257e-01],
        [2.6018e+02, 3.7149e+02, 2.1640e+00,  ..., 2.8550e+02, 1.0761e+02,
         4.8773e-01]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[255.3054, 362.8468,   1.2361,  ..., 344.8252, 195.9144,   1.5164],
        [253.5852, 361.9066,   1.3276,  ..., 340.0128, 345.7412,   2.0940],
        [255.1797, 363.4201,   1.2658,  ..., 431.5935, 254.0073,   1.8033],
        ...,
        [255.4071, 362.2050,   1.1397,  ..., 418.6897, 230.2401,   2.1235],
    

output tensor([[2.5900e+02, 3.6841e+02, 8.8053e-01,  ..., 1.6793e+02, 1.2072e+02,
         1.3674e+00],
        [2.5912e+02, 3.6803e+02, 7.8245e-01,  ..., 9.4257e+01, 1.6784e+02,
         1.4354e+00],
        [2.5866e+02, 3.6806e+02, 2.4734e-01,  ..., 3.9561e+02, 2.1017e+02,
         7.0264e-01],
        ...,
        [2.5916e+02, 3.6819e+02, 1.5328e-01,  ..., 3.2842e+02, 3.0755e+02,
         6.6770e-01],
        [2.5812e+02, 3.6846e+02, 3.6536e-01,  ..., 1.5550e+02, 1.6875e+02,
         7.2833e-01],
        [2.5713e+02, 3.6695e+02, 2.0633e-01,  ..., 8.2011e+01, 2.3431e+02,
         5.1342e-01]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[256.9021, 366.3788,   0.8806,  ..., 298.3259, 332.6727,   0.7450],
        [257.4751, 366.5541,   0.6908,  ...,  81.2246, 243.5454,   0.4457],
        [258.0706, 365.9382,   0.5566,  ..., 334.9171, 207.5376,   0.9202],
        ...,
        [258.4934, 366.5462,   0.9461,  ...,  84.3840, 184.2403,   1.3509],
    

output tensor([[ 2.6080e+02,  3.7054e+02,  8.4831e-01,  ...,  3.6164e+02,
          3.2797e+02,  2.1528e+00],
        [ 2.6508e+02,  3.7696e+02,  4.5071e-01,  ...,  3.3826e+02,
          1.1328e+02,  1.5719e+00],
        [ 2.6131e+02,  3.6993e+02,  7.0486e-01,  ...,  3.8024e+02,
          2.7718e+02,  1.8497e+00],
        ...,
        [ 2.6327e+02,  3.7401e+02, -1.9034e-01,  ...,  4.0266e+02,
          1.9697e+02,  1.1256e+00],
        [ 2.6082e+02,  3.7112e+02,  5.4670e-01,  ...,  3.0525e+02,
          3.4011e+02,  2.2714e+00],
        [ 2.6061e+02,  3.7070e+02,  4.8670e-01,  ...,  3.4547e+02,
          3.5333e+02,  2.3554e+00]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[ 2.5025e+02,  3.5579e+02,  1.0289e-01,  ...,  2.4994e+02,
          2.0168e+02,  4.5947e-01],
        [ 2.4836e+02,  3.5252e+02,  2.1636e-02,  ...,  3.3261e+02,
          2.9968e+02,  5.3165e-01],
        [ 2.5159e+02,  3.5874e+02,  8.6880e-02,  ...,  2.7491e+02,
          9.

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.5701e+02, 3.6575e+02, 7.6565e-01,  ..., 1.0010e+02, 1.8702e+02,
         5.9409e-01],
        [2.5729e+02, 3.6544e+02, 1.1343e+00,  ..., 3.0658e+02, 2.4248e+02,
         9.4377e-01],
        [2.5723e+02, 3.6754e+02, 8.7437e-01,  ..., 3.9651e+02, 1.9911e+02,
         2.0647e-01],
        ...,
        [2.5679e+02, 3.6522e+02, 1.1922e+00,  ..., 3.0228e+02, 2.2445e+02,
         7.6179e-01],
        [2.5729e+02, 3.6543e+02, 1.1483e+00,  ..., 3.0563e+02, 2.4142e+02,
         9.5291e-01],
        [2.5823e+02, 3.6654e+02, 8.2701e-01,  ..., 3.1291e+02, 2.9938e+02,
         7.2554e-01]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[258.5153, 367.5145,   1.0130,  ..., 358.2283, 324.2172,   1.9770],
        [258.3591, 367.6314,   1.3066,  ..., 165.9542, 119.8137,   2.2058],
        [258.2411, 367.1157,   0.9593,  ...,  87.1907, 189.8127,   1.8204],
        ...,
        [257.6190, 368.5873,   1.1554

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.5919e+02, 3.6807e+02, 1.1668e+00,  ..., 2.1289e+02, 8.9484e+01,
         1.5222e+00],
        [2.5838e+02, 3.6771e+02, 1.4472e+00,  ..., 3.1793e+02, 9.3760e+01,
         1.4610e+00],
        [2.5799e+02, 3.6622e+02, 4.2097e-01,  ..., 1.3275e+02, 1.4992e+02,
         1.6907e+00],
        ...,
        [2.5764e+02, 3.6649e+02, 4.1496e-01,  ..., 9.4989e+01, 2.1806e+02,
         1.5534e+00],
        [2.5774e+02, 3.6650e+02, 6.6724e-01,  ..., 1.0036e+02, 1.8909e+02,
         1.4119e+00],
        [2.5836e+02, 3.6605e+02, 3.2582e-01,  ..., 3.5214e+02, 1.3794e+02,
         1.7041e+00]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[ 2.5831e+02,  3.6783e+02,  9.2326e-01,  ...,  3.0856e+02,
          3.4328e+02,  1.1236e-01],
        [ 2.5824e+02,  3.6781e+02,  9.0020e-01,  ...,  3.0589e+02,
          3.4278e+02,  9.3758e-02],
        [ 2.5835e+02,  3.6727e+02,  1.1754e+00,  ...,  4.0980e+02,
     

output tensor([[2.6011e+02, 3.6964e+02, 1.2344e+00,  ..., 2.0000e+02, 9.3434e+01,
         9.8296e-01],
        [2.5918e+02, 3.6781e+02, 2.9409e-02,  ..., 1.9170e+02, 1.7970e+02,
         9.2748e-01],
        [2.5902e+02, 3.6793e+02, 5.0385e-01,  ..., 1.4197e+02, 1.4953e+02,
         9.1798e-01],
        ...,
        [2.5934e+02, 3.6976e+02, 1.0488e+00,  ..., 3.9470e+02, 1.6966e+02,
         6.2382e-01],
        [2.5989e+02, 3.6861e+02, 9.5445e-01,  ..., 3.1673e+02, 2.6209e+02,
         1.0723e+00],
        [2.5925e+02, 3.6963e+02, 7.9621e-01,  ..., 1.0589e+02, 1.9874e+02,
         7.9410e-01]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[ 2.5540e+02,  3.6275e+02, -8.0156e-02,  ...,  2.1353e+02,
          1.8293e+02,  8.0353e-01],
        [ 2.5546e+02,  3.6369e+02,  4.6716e-01,  ...,  2.5970e+02,
          2.0428e+02,  4.6750e-01],
        [ 2.5527e+02,  3.6495e+02,  9.6089e-01,  ...,  1.3075e+02,
          1.6703e+02,  1.0392e+00],
        ...,

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[ 2.6314e+02,  3.7288e+02,  1.0043e+00,  ...,  3.8151e+02,
          2.4006e+02, -5.5249e-01],
        [ 2.6371e+02,  3.7279e+02,  1.5333e+00,  ...,  4.4721e+02,
          3.1283e+02, -6.6792e-01],
        [ 2.6201e+02,  3.7202e+02,  1.0759e+00,  ...,  2.7915e+02,
          1.1643e+02,  6.0782e-02],
        ...,
        [ 2.6320e+02,  3.7368e+02,  1.7891e+00,  ...,  2.7523e+02,
          8.8831e+01, -8.0429e-01],
        [ 2.6361e+02,  3.7249e+02,  1.0524e+00,  ...,  4.2260e+02,
          2.1311e+02, -3.3711e-01],
        [ 2.6290e+02,  3.7347e+02,  1.4167e+00,  ...,  3.3054e+02,
          1.5062e+02, -3.5832e-01]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.5371e+02, 3.6077e+02, 2.3118e-01,  ..., 3.3439e+02, 1.2251e+02,
         3.3877e+00],
        [2.5464e+02, 3.6220e+02, 9.7671e-01,  ..., 2.3285e+02, 1.3523e+02,
         3.3309e+00],
        [2.5434e+02, 3.6344e+02, 1.9043e+00,  ..

output tensor([[264.4440, 375.5832,   1.0658,  ..., 259.5295, 197.3744,   0.6290],
        [265.6643, 377.8951,   1.0405,  ..., 354.7211, 352.8176,   1.0147],
        [265.3726, 377.6541,   1.0961,  ..., 314.3407, 339.6897,   0.9631],
        ...,
        [265.7044, 378.1147,   0.8762,  ..., 361.3341, 351.5065,   0.9807],
        [263.7750, 375.0727,   1.3569,  ..., 317.8260, 222.8941,   0.8698],
        [266.1623, 378.5289,   1.3210,  ..., 463.8471, 306.8444,   0.7872]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.5176e+02, 3.5941e+02, 7.9375e-01,  ..., 7.9308e+01, 1.9694e+02,
         2.1718e+00],
        [2.4926e+02, 3.5630e+02, 6.3194e-01,  ..., 3.7399e+02, 3.3705e+02,
         2.7234e+00],
        [2.4991e+02, 3.5631e+02, 5.7401e-01,  ..., 2.5508e+02, 1.0611e+02,
         2.2911e+00],
        ...,
        [2.4972e+02, 3.5745e+02, 3.4014e-01,  ..., 3.0717e+02, 3.2939e+02,
         2.4514e+00],
        [2.5039e+02, 3.5617e+02, 4.5183

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[259.6417, 369.0347,   0.6032,  ..., 367.1876, 272.1749,   1.9555],
        [258.9012, 368.6663,   0.6763,  ..., 376.6133, 151.8519,   2.0479],
        [258.9140, 367.7751,   1.0187,  ...,  65.5499, 303.9559,   1.1003],
        ...,
        [258.9373, 368.3886,   0.6520,  ..., 297.5270, 333.5760,   1.9177],
        [259.5164, 369.0693,   0.4220,  ..., 337.2015, 346.5762,   1.9442],
        [259.1372, 368.4484,   0.7173,  ..., 329.0701, 261.2086,   1.8458]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.5720e+02, 3.6468e+02, 1.1263e+00,  ..., 2.3901e+02, 1.3926e+02,
         1.0564e+00],
        [2.5787e+02, 3.6693e+02, 8.4201e-01,  ..., 3.5522e+02, 3.4675e+02,
         2.1348e-01],
        [2.5652e+02, 3.6443e+02, 1.3228e+00,  ..., 8.1051e+01, 2.9140e+02,
         1.7390e-01],
        ...,
        [2.5738e+02, 3.6647e+02, 1.6650e+00,  ..., 2.3098e+02, 1.2364e+02,
         1.0562e+0

output tensor([[ 2.6369e+02,  3.7329e+02,  7.5587e-01,  ...,  3.2322e+02,
          3.1120e+02, -5.3046e-03],
        [ 2.6201e+02,  3.7183e+02,  1.2473e+00,  ...,  1.7121e+02,
          1.1486e+02,  2.8959e-01],
        [ 2.6274e+02,  3.7299e+02,  1.2175e+00,  ...,  3.4040e+02,
          2.5997e+02,  2.5578e-01],
        ...,
        [ 2.6390e+02,  3.7308e+02,  1.8735e+00,  ...,  4.3420e+02,
          3.1787e+02,  2.5317e-01],
        [ 2.6419e+02,  3.7399e+02,  1.3666e+00,  ...,  4.2604e+02,
          3.5388e+02,  2.8098e-01],
        [ 2.6272e+02,  3.7139e+02,  1.5761e+00,  ...,  5.8940e+01,
          2.5967e+02,  3.3145e-01]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.5289e+02, 3.6002e+02, 4.9770e-01,  ..., 2.5293e+02, 1.0300e+02,
         1.7610e+00],
        [2.5290e+02, 3.6226e+02, 1.7817e+00,  ..., 3.1109e+02, 1.1045e+02,
         1.7556e+00],
        [2.5289e+02, 3.5981e+02, 3.7179e-01,  ..., 3.8892e+02, 1.4484e+02,
         1.8691e

output tensor([[2.5507e+02, 3.6416e+02, 8.7295e-01,  ..., 3.5771e+02, 3.2871e+02,
         1.4014e+00],
        [2.5396e+02, 3.6174e+02, 7.7625e-01,  ..., 3.8719e+02, 2.1643e+02,
         2.1768e+00],
        [2.5698e+02, 3.6497e+02, 1.2198e-02,  ..., 2.6187e+02, 1.0144e+02,
         9.7981e-01],
        ...,
        [2.5599e+02, 3.6486e+02, 1.1303e+00,  ..., 3.2248e+02, 1.4768e+02,
         2.3073e+00],
        [2.5650e+02, 3.6692e+02, 1.8440e+00,  ..., 9.7030e+01, 1.9636e+02,
         1.6182e+00],
        [2.5415e+02, 3.6326e+02, 3.7456e-01,  ..., 4.1584e+02, 2.9415e+02,
         1.3293e+00]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[ 2.6034e+02,  3.6892e+02,  1.6769e+00,  ...,  3.0633e+02,
          2.4432e+02,  1.1440e+00],
        [ 2.5893e+02,  3.6810e+02,  1.5081e+00,  ...,  4.6222e+02,
          2.4794e+02, -1.0411e+00],
        [ 2.6072e+02,  3.7069e+02,  1.2462e+00,  ...,  3.5804e+02,
          3.2331e+02, -2.9776e-01],
        ...,

output tensor([[ 2.5836e+02,  3.6756e+02,  6.1621e-01,  ...,  4.1283e+02,
          3.5027e+02,  7.9756e-01],
        [ 2.5739e+02,  3.6679e+02,  1.6441e+00,  ...,  3.6102e+02,
          3.1878e+02,  1.0625e+00],
        [ 2.5788e+02,  3.6807e+02,  9.1538e-01,  ...,  4.3876e+02,
          2.3260e+02,  1.0938e+00],
        ...,
        [ 2.5707e+02,  3.6697e+02,  1.3124e+00,  ...,  3.8070e+02,
          1.4450e+02,  7.5570e-01],
        [ 2.5728e+02,  3.6783e+02,  1.6104e+00,  ...,  3.8799e+02,
          1.4980e+02,  7.2241e-01],
        [ 2.5773e+02,  3.6662e+02,  8.9255e-01,  ...,  6.3339e+01,
          3.0420e+02, -5.6601e-03]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[258.2404, 367.4932,   0.9709,  ..., 428.2183, 179.8379,   1.3351],
        [258.8299, 368.0626,   1.5467,  ..., 171.0307, 101.6645,   0.7814],
        [258.1122, 367.5715,   1.2428,  ..., 420.2704, 160.9274,   1.1648],
        ...,
        [258.1673, 366.9965,   0.9974,  ...,

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[250.5038, 358.2253,   1.0216,  ..., 259.6628, 196.3902,   0.7985],
        [251.7586, 360.9823,   1.5035,  ..., 307.7194, 335.8492,   1.0364],
        [250.1741, 358.8011,   1.3719,  ..., 101.8521, 203.9671,   1.6076],
        ...,
        [253.4140, 361.7416,   0.9391,  ..., 417.8397, 233.8061,   1.3096],
        [251.1688, 359.0036,   1.3861,  ..., 333.1159, 249.5876,   1.3235],
        [251.4067, 359.2139,   1.5790,  ..., 310.6446, 249.6550,   1.1069]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[ 2.6527e+02,  3.7585e+02,  1.0734e+00,  ...,  4.5795e+02,
          2.4329e+02,  1.3225e+00],
        [ 2.5960e+02,  3.6831e+02,  3.6548e-01,  ...,  1.6737e+02,
          1.4321e+02,  1.8516e+00],
        [ 2.6292e+02,  3.7419e+02,  1.3105e+00,  ...,  3.8939e+02,
          1.3874e+02,  1.2965e+00],
        ...,
        [ 2.6077e+02,  3.7007e+02, -1.3507e-02,  ...,  2.0575e+02,
        

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.5626e+02, 3.6510e+02, 8.1084e-01,  ..., 1.7092e+02, 1.1981e+02,
         5.3910e-01],
        [2.5488e+02, 3.6368e+02, 9.5243e-01,  ..., 3.0453e+02, 2.1042e+02,
         1.3346e-01],
        [2.5573e+02, 3.6498e+02, 9.8565e-01,  ..., 4.0604e+02, 1.4637e+02,
         5.6210e-02],
        ...,
        [2.5536e+02, 3.6346e+02, 4.5542e-01,  ..., 4.1481e+02, 2.2685e+02,
         2.4785e-01],
        [2.5654e+02, 3.6485e+02, 1.9651e-01,  ..., 1.0017e+02, 1.8588e+02,
         4.1383e-01],
        [2.5533e+02, 3.6395e+02, 7.5516e-01,  ..., 3.3278e+02, 2.5440e+02,
         3.2853e-01]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.6079e+02, 3.6941e+02, 5.4144e-01,  ..., 2.8540e+02, 9.3894e+01,
         1.0009e+00],
        [2.6119e+02, 3.7049e+02, 5.1719e-01,  ..., 1.0733e+02, 1.9132e+02,
         1.6858e+00],
        [2.5957e+02, 3.6915e+02, 5.9476e-01,  ..., 3.9670e+02, 1.4218e+02,
         

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[258.4918, 367.4652,   0.4829,  ...,  82.7572, 280.1933,   0.8448],
        [259.3112, 369.0373,   1.5410,  ..., 206.2844,  88.0525,   0.6934],
        [259.4618, 368.1055,   0.6668,  ..., 368.0457, 348.5639,   0.9364],
        ...,
        [258.8396, 368.3222,   0.6219,  ..., 441.2593, 239.9169,   1.0964],
        [259.3488, 367.8824,   0.5995,  ..., 281.0946,  89.2415,   0.5671],
        [258.7651, 368.2617,   1.2281,  ..., 282.8534,  98.7910,   0.9470]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[257.8608, 366.2887,   0.9559,  ..., 390.9506, 350.3554,   2.2529],
        [257.4409, 366.0782,   1.0064,  ..., 385.4512, 218.7521,   1.9817],
        [257.5451, 365.5237,   0.9211,  ..., 323.9360, 178.4800,   1.8696],
        ...,
        [257.1740, 365.2233,   0.8665,  ..., 371.0417, 163.1283,   1.8325],
        [257.2993, 365.7586,   0.8009,  ..., 227.5500, 179.7277,   1.6115],
    

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[ 2.7009e+02,  3.8213e+02,  1.1961e+00,  ...,  4.3223e+02,
          2.3774e+02, -5.6974e-02],
        [ 2.6802e+02,  3.7933e+02,  8.5672e-01,  ...,  1.7992e+02,
          1.8846e+02,  4.0027e-01],
        [ 2.7068e+02,  3.8185e+02,  1.1556e+00,  ...,  4.2949e+02,
          2.1299e+02, -2.4671e-01],
        ...,
        [ 2.6848e+02,  3.8098e+02,  1.9542e+00,  ...,  1.8804e+02,
          1.0571e+02,  2.4857e-03],
        [ 2.6894e+02,  3.8140e+02,  1.4250e+00,  ...,  3.9399e+02,
          1.5587e+02, -6.8619e-02],
        [ 2.6818e+02,  3.7940e+02,  1.6480e+00,  ...,  3.1347e+02,
          2.3939e+02, -1.1476e-01]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[ 2.4414e+02,  3.4910e+02,  6.7659e-01,  ...,  4.1839e+01,
          2.6691e+02,  2.0110e+00],
        [ 2.4417e+02,  3.4902e+02, -3.1251e-01,  ...,  1.7913e+02,
          1.6709e+02,  2.2759e+00],
        [ 2.4356e+02,  3.4855e+02,  

output tensor([[ 2.5784e+02,  3.6707e+02,  1.3395e+00,  ...,  3.7635e+02,
          2.7852e+02, -2.3905e-01],
        [ 2.5828e+02,  3.6640e+02,  5.7961e-01,  ...,  6.1952e+01,
          3.1587e+02, -6.9071e-01],
        [ 2.5789e+02,  3.6803e+02,  1.1982e+00,  ...,  3.9595e+02,
          2.1064e+02, -3.2056e-01],
        ...,
        [ 2.5670e+02,  3.6634e+02,  1.8772e+00,  ...,  2.9671e+02,
          3.2374e+02, -9.4986e-01],
        [ 2.5812e+02,  3.6612e+02,  2.8512e-01,  ...,  2.4028e+02,
          1.0332e+02, -1.8495e-01],
        [ 2.5771e+02,  3.6580e+02,  7.5875e-01,  ...,  3.4273e+02,
          3.0444e+02, -5.6720e-01]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[258.7405, 368.4300,   1.7117,  ..., 365.1206, 320.5260,   2.6004],
        [258.9806, 368.0675,   0.8210,  ..., 283.9706, 122.2098,   2.3825],
        [259.2175, 368.6990,   1.0155,  ..., 329.2688, 194.0164,   2.0289],
        ...,
        [260.3276, 369.8585,   1.6135,  ...,

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.5633e+02, 3.6495e+02, 2.8785e-01,  ..., 1.6325e+02, 1.7823e+02,
         1.4556e+00],
        [2.5587e+02, 3.6427e+02, 4.3448e-01,  ..., 3.6937e+02, 1.6387e+02,
         8.8678e-01],
        [2.5598e+02, 3.6486e+02, 9.0829e-01,  ..., 3.3516e+02, 1.0204e+02,
         7.9444e-01],
        ...,
        [2.5584e+02, 3.6527e+02, 9.7604e-01,  ..., 3.8047e+02, 1.3981e+02,
         6.7632e-01],
        [2.5638e+02, 3.6525e+02, 7.1143e-01,  ..., 3.6910e+02, 2.7149e+02,
         1.1696e+00],
        [2.5710e+02, 3.6638e+02, 9.5931e-01,  ..., 3.5015e+02, 2.0897e+02,
         1.1553e+00]], grad_fn=<AddmmBackward0>)
torch.Size([60, 18])
torch.Size([60, 3])
output tensor([[2.6020e+02, 3.6828e+02, 7.8329e-01,  ..., 2.5786e+02, 1.6021e+02,
         1.0354e+00],
        [2.6023e+02, 3.6932e+02, 1.0350e+00,  ..., 4.4935e+02, 3.0673e+02,
         1.1298e+00],
        [2.6019e+02, 3.6987e+02, 1.1258e+00,  ..., 3.5196e+02, 2.2061e+02,
         1.

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.5578e+02, 3.6491e+02, 7.3213e-01,  ..., 4.3285e+02, 2.6741e+02,
         8.4677e-01],
        [2.5549e+02, 3.6322e+02, 1.1915e+00,  ..., 2.9138e+02, 2.3027e+02,
         5.1232e-01],
        [2.5568e+02, 3.6465e+02, 8.7759e-01,  ..., 3.9278e+02, 2.0978e+02,
         8.2514e-01],
        ...,
        [2.5588e+02, 3.6457e+02, 1.3205e+00,  ..., 4.5706e+02, 2.4970e+02,
         5.2965e-01],
        [2.5594e+02, 3.6417e+02, 1.3166e+00,  ..., 4.4962e+02, 2.3230e+02,
         4.4818e-01],
        [2.5616e+02, 3.6426e+02, 5.6226e-01,  ..., 4.0224e+02, 3.5124e+02,
         7.6770e-01]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[260.7831, 371.0134,   1.0482,  ..., 393.7233, 220.1323,   1.0111],
        [260.3522, 369.9953,   1.1740,  ...,  70.5262, 299.5533,   0.9780],
        [260.2923, 370.5544,   1.6315,  ..., 312.1283, 115.4825,   1.0248],
        ...,
        [261.0834, 371.4204,   1.3588

output tensor([[2.5146e+02, 3.5884e+02, 3.0961e-01,  ..., 1.6177e+02, 1.6991e+02,
         2.4665e+00],
        [2.5314e+02, 3.6092e+02, 6.7178e-01,  ..., 3.8527e+02, 2.0896e+02,
         2.0364e+00],
        [2.5273e+02, 3.6134e+02, 5.1280e-01,  ..., 3.5268e+02, 3.2180e+02,
         2.4091e+00],
        ...,
        [2.5177e+02, 3.6097e+02, 1.4668e+00,  ..., 2.8683e+02, 3.2133e+02,
         1.9713e+00],
        [2.5188e+02, 3.5939e+02, 2.6134e-01,  ..., 1.9172e+02, 1.7278e+02,
         2.2494e+00],
        [2.5113e+02, 3.5862e+02, 6.3246e-01,  ..., 5.3058e+01, 3.0277e+02,
         2.5536e+00]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[ 2.6312e+02,  3.7302e+02,  1.0847e+00,  ...,  1.6102e+02,
          1.1357e+02,  2.3662e-01],
        [ 2.6396e+02,  3.7514e+02,  1.6425e+00,  ...,  3.2808e+02,
          1.2612e+02,  1.9038e-01],
        [ 2.6488e+02,  3.7465e+02,  1.3080e+00,  ...,  4.6678e+02,
          2.4480e+02, -5.3919e-01],
        ...,

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.5579e+02, 3.6457e+02, 8.3177e-01,  ..., 4.1657e+02, 2.2862e+02,
         5.6331e-01],
        [2.5576e+02, 3.6409e+02, 1.0781e+00,  ..., 9.7775e+01, 1.8444e+02,
         1.4335e-01],
        [2.5580e+02, 3.6410e+02, 8.4617e-01,  ..., 3.5603e+02, 3.4479e+02,
         5.1371e-01],
        ...,
        [2.5564e+02, 3.6328e+02, 7.3446e-01,  ..., 3.4167e+02, 3.0129e+02,
         2.5397e-01],
        [2.5544e+02, 3.6420e+02, 1.2306e+00,  ..., 3.5965e+02, 3.1506e+02,
         7.1974e-01],
        [2.5575e+02, 3.6411e+02, 1.2067e+00,  ..., 7.5172e+01, 1.8814e+02,
         3.5528e-01]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.6002e+02, 3.6888e+02, 3.4870e-01,  ..., 3.9270e+02, 1.6512e+02,
         1.1148e+00],
        [2.5975e+02, 3.6877e+02, 6.1768e-01,  ..., 3.2299e+02, 1.6759e+02,
         1.1457e+00],
        [2.6051e+02, 3.7089e+02, 1.5194e+00,  ..., 1.9988e+02, 8.9870e+01,
         

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[ 2.6962e+02,  3.8156e+02,  1.9003e+00,  ...,  9.2355e+01,
          2.0070e+02, -5.8064e-01],
        [ 2.7098e+02,  3.8254e+02,  2.0593e+00,  ...,  3.4954e+02,
          2.6984e+02, -3.3170e-01],
        [ 2.7038e+02,  3.8253e+02,  2.3157e+00,  ...,  3.5675e+02,
          2.5553e+02, -1.5837e-01],
        ...,
        [ 2.7101e+02,  3.8308e+02,  1.9740e+00,  ...,  3.8992e+02,
          2.8435e+02, -5.3349e-02],
        [ 2.6835e+02,  3.7959e+02,  1.1963e+00,  ...,  2.9342e+02,
          1.3083e+02, -6.0047e-01],
        [ 2.7239e+02,  3.8687e+02,  2.8281e+00,  ...,  3.1808e+02,
          3.5177e+02, -1.5684e-01]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.4901e+02, 3.5532e+02, 9.1399e-02,  ..., 3.3146e+02, 2.9368e+02,
         3.8745e+00],
        [2.4752e+02, 3.5469e+02, 4.6368e-01,  ..., 7.0528e+01, 2.8391e+02,
         3.7536e+00],
        [2.4666e+02, 3.5136e+02, 2.0845e-01,  ..

output tensor([[2.5751e+02, 3.6735e+02, 1.1078e+00,  ..., 2.9903e+02, 3.3402e+02,
         4.4663e-01],
        [2.5836e+02, 3.7040e+02, 1.7058e+00,  ..., 1.3699e+02, 1.6791e+02,
         8.8414e-01],
        [2.5941e+02, 3.6664e+02, 1.2877e+00,  ..., 4.4970e+02, 2.4603e+02,
         8.1455e-01],
        ...,
        [2.6043e+02, 3.6938e+02, 1.2588e+00,  ..., 4.1741e+02, 3.4390e+02,
         3.8125e-01],
        [2.5949e+02, 3.6784e+02, 1.1242e+00,  ..., 3.6082e+02, 3.5038e+02,
         5.8896e-01],
        [2.5870e+02, 3.6824e+02, 5.5906e-01,  ..., 4.3946e+02, 2.6226e+02,
         4.3160e-01]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[255.8267, 364.1306,   0.4487,  ..., 214.4088, 182.4446,   0.9627],
        [255.5995, 364.7019,   1.2916,  ..., 324.3058, 126.6512,   1.3095],
        [255.8379, 364.1781,   1.2076,  ..., 298.4016, 236.7685,   0.9264],
        ...,
        [256.2663, 364.5682,   0.9669,  ..., 352.3213, 317.2612,   1.3256],
    

output tensor([[257.9849, 366.5487,   0.4874,  ..., 224.5609, 110.3638,   1.5121],
        [257.2446, 366.8941,   1.0983,  ..., 300.3002, 335.5161,   1.8861],
        [258.4567, 367.2888,   1.0337,  ..., 170.9289, 112.8908,   1.5202],
        ...,
        [258.5250, 367.7538,   0.7954,  ..., 421.8062, 251.1656,   2.1086],
        [258.0817, 366.7953,   0.8789,  ..., 418.1601, 199.9320,   1.9399],
        [258.2895, 366.5034,   0.8480,  ..., 271.3300, 175.7197,   1.7731]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[ 2.5780e+02,  3.6706e+02,  1.1847e+00,  ...,  3.0629e+02,
          8.4832e+01,  4.1226e-01],
        [ 2.5792e+02,  3.6613e+02,  6.5214e-01,  ...,  2.8473e+02,
          9.7001e+01, -1.1604e-01],
        [ 2.5706e+02,  3.6680e+02,  1.0976e+00,  ...,  4.0042e+02,
          1.9646e+02,  3.0280e-01],
        ...,
        [ 2.5730e+02,  3.6657e+02,  1.0714e+00,  ...,  3.6925e+02,
          3.1392e+02,  2.9976e-01],
        [ 2.575

output tensor([[259.2309, 368.2878,   0.9761,  ..., 399.3734, 354.7873,   0.6709],
        [258.5334, 368.1258,   1.1271,  ..., 407.8424, 298.6566,   0.8064],
        [258.7842, 368.3034,   1.4425,  ..., 293.8742, 102.6565,   0.6051],
        ...,
        [258.8994, 368.3337,   1.3163,  ..., 462.7411, 274.6668,   0.5981],
        [259.0235, 368.4331,   0.9744,  ...,  65.5027, 303.5226,   0.7526],
        [258.3855, 368.1062,   1.1712,  ..., 389.0070, 301.8225,   0.7679]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[258.1006, 366.1844,   1.1005,  ...,  60.9628, 248.3461,   1.5298],
        [257.4630, 366.7068,   0.8896,  ..., 204.4460, 113.8794,   1.4211],
        [257.1073, 365.7905,   1.2032,  ..., 303.1646,  83.2213,   1.4604],
        ...,
        [256.6600, 365.0424,   1.0706,  ..., 414.3673, 204.1249,   1.3316],
        [256.8549, 365.7827,   1.1501,  ..., 458.2034, 293.5781,   1.5242],
        [256.3000, 365.6039,   1.2427,  ..., 39

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[263.2338, 372.9232,   1.2492,  ..., 152.0181, 156.2861,   0.8642],
        [263.0072, 373.5594,   1.5215,  ..., 378.9540, 248.9181,   0.6454],
        [265.4942, 375.6119,   1.3414,  ..., 357.6752, 355.7743,   0.7259],
        ...,
        [265.9271, 377.2433,   1.2273,  ...,  90.6467, 280.9512,   1.0727],
        [264.4180, 375.5085,   1.1250,  ...,  99.2482, 231.2758,   0.7052],
        [266.3194, 377.5834,   1.5695,  ...,  62.0108, 301.2584,   0.9745]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[254.9768, 362.6511,   0.7499,  ..., 348.4691, 324.3631,   2.1742],
        [254.3625, 362.3164,   0.9697,  ..., 314.4702, 252.3671,   1.7766],
        [254.2371, 362.0778,   0.6916,  ..., 103.1091, 182.3956,   1.5660],
        ...,
        [253.0454, 360.2542,   0.3766,  ..., 197.2264, 173.6209,   1.7792],
        [254.8001, 362.1296,   0.8363,  ..., 347.8188, 337.6322,   2.2364],
    

output tensor([[2.5706e+02, 3.6517e+02, 5.2743e-01,  ..., 2.6923e+02, 1.1504e+02,
         6.0747e-01],
        [2.5683e+02, 3.6604e+02, 9.3404e-01,  ..., 1.1353e+02, 1.9808e+02,
         8.6185e-01],
        [2.5669e+02, 3.6465e+02, 4.5478e-01,  ..., 3.1752e+02, 1.5732e+02,
         5.7362e-01],
        ...,
        [2.5571e+02, 3.6440e+02, 9.0844e-01,  ..., 2.9078e+02, 3.2598e+02,
         2.8932e-01],
        [2.5579e+02, 3.6479e+02, 5.9727e-01,  ..., 4.0620e+02, 2.9762e+02,
         5.9063e-01],
        [2.5612e+02, 3.6411e+02, 4.9830e-01,  ..., 3.4484e+02, 3.0271e+02,
         5.8241e-01]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.6003e+02, 3.6924e+02, 1.3166e+00,  ..., 4.2837e+02, 1.9115e+02,
         3.0655e-01],
        [2.6013e+02, 3.6868e+02, 1.3168e+00,  ..., 4.1253e+02, 1.8943e+02,
         1.5208e-01],
        [2.5953e+02, 3.6850e+02, 1.1813e+00,  ..., 3.5270e+02, 3.0240e+02,
         7.5142e-01],
        ...,
        [2.5992e+

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[257.4088, 365.7250,   1.7875,  ..., 292.7023, 232.5752,  -1.6505],
        [257.2388, 364.9094,   1.2036,  ..., 160.1333, 143.5544,  -0.5078],
        [257.4089, 365.6132,   1.2021,  ..., 320.8117, 160.9380,  -1.2555],
        ...,
        [258.0573, 367.0422,   1.7172,  ..., 263.0429, 205.4907,  -1.4586],
        [257.9641, 367.1906,   2.2441,  ..., 391.1610, 132.4774,  -0.6641],
        [257.3384, 366.2075,   1.1686,  ..., 423.5530, 348.1905,  -2.0784]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[258.9948, 367.8203,   1.6337,  ..., 314.8590,  89.6130,   3.8667],
        [259.6917, 367.8764,   1.0202,  ..., 351.8884, 346.7923,   3.4907],
        [258.9860, 368.1353,   1.4537,  ...,  84.8905, 242.1081,   3.7156],
        ...,
        [259.0419, 368.1563,   1.0304,  ..., 368.3945, 314.4332,   3.4909],
        [259.2562, 368.2430,   1.0172,  ..., 382.5222, 141.8696,   3.9567],
    

output tensor([[ 2.6628e+02,  3.7783e+02,  1.6799e+00,  ...,  4.3484e+02,
          3.1152e+02, -6.3186e-01],
        [ 2.6648e+02,  3.7916e+02,  2.2358e+00,  ...,  2.7795e+02,
          1.0597e+02, -1.5421e+00],
        [ 2.6809e+02,  3.8136e+02,  1.6045e+00,  ...,  2.5813e+02,
          2.0592e+02, -1.5997e-01],
        ...,
        [ 2.6359e+02,  3.7535e+02,  1.3293e+00,  ...,  8.1763e+01,
          2.4013e+02,  1.0487e+00],
        [ 2.6439e+02,  3.7572e+02,  1.1515e+00,  ...,  1.9177e+02,
          1.2052e+02,  1.5406e-01],
        [ 2.6253e+02,  3.7431e+02,  5.4459e-01,  ...,  7.4767e+01,
          2.8080e+02,  1.0699e+00]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[ 2.4168e+02,  3.4586e+02, -1.0141e+00,  ...,  4.3221e+02,
          2.2701e+02,  2.0008e+00],
        [ 2.3864e+02,  3.4056e+02, -8.3058e-01,  ...,  3.4852e+02,
          1.3065e+02,  2.3317e+00],
        [ 2.4198e+02,  3.4591e+02, -1.8617e-01,  ...,  2.1244e+02,
          8.

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[ 2.5628e+02,  3.6479e+02,  3.2718e-01,  ...,  3.3476e+02,
          2.0980e+02,  8.1541e-01],
        [ 2.5555e+02,  3.6450e+02,  3.6570e-01,  ...,  2.4494e+02,
          1.5052e+02,  8.1921e-01],
        [ 2.5629e+02,  3.6478e+02,  4.2881e-01,  ...,  3.4363e+02,
          2.0372e+02,  8.1387e-01],
        ...,
        [ 2.5625e+02,  3.6482e+02,  6.3703e-01,  ...,  3.5112e+02,
          2.1754e+02,  8.5171e-01],
        [ 2.5699e+02,  3.6493e+02,  6.8804e-02,  ...,  2.1547e+02,
          1.7986e+02,  1.1680e-03],
        [ 2.5633e+02,  3.6403e+02, -2.1347e-01,  ...,  4.4825e+02,
          2.2807e+02,  1.2070e+00]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.5896e+02, 3.6873e+02, 1.1358e+00,  ..., 3.9668e+02, 1.3426e+02,
         1.4613e+00],
        [2.5930e+02, 3.7022e+02, 9.9568e-01,  ..., 4.4118e+02, 2.3939e+02,
         1.0401e+00],
        [2.5897e+02, 3.6839e+02, 7.3669e-01,  ..

output tensor([[257.2043, 366.2957,   1.3831,  ...,  96.4043, 160.2623,   2.5130],
        [257.0982, 366.1767,   1.1341,  ..., 276.8855, 116.8332,   2.1881],
        [257.4869, 366.7816,   0.8929,  ..., 420.5211, 242.7232,   1.9063],
        ...,
        [256.7821, 366.5465,   1.7330,  ..., 107.9401, 176.5773,   2.0512],
        [257.5210, 366.7428,   1.1191,  ..., 364.5976, 155.3289,   2.3103],
        [256.9577, 366.2898,   0.7648,  ..., 384.2330, 343.9777,   2.6006]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.5876e+02, 3.6917e+02, 2.4231e-01,  ..., 4.4051e+02, 2.3956e+02,
         7.4109e-01],
        [2.5851e+02, 3.6735e+02, 1.3899e+00,  ..., 1.1406e+02, 2.0345e+02,
         1.1159e+00],
        [2.5911e+02, 3.6840e+02, 7.6728e-01,  ..., 2.2501e+02, 9.1440e+01,
         4.8659e-01],
        ...,
        [2.5890e+02, 3.6793e+02, 9.3809e-01,  ..., 3.1456e+02, 3.0513e+02,
         1.2367e+00],
        [2.5849e+02, 3.6755e+02, 5.9529

output tensor([[257.3872, 366.0470,   0.7022,  ..., 342.0505, 113.0415,   1.0388],
        [258.3655, 368.6227,   1.1813,  ..., 462.4867, 284.2472,   1.0019],
        [258.4183, 367.4434,   1.3933,  ..., 339.5845, 255.5930,   0.7852],
        ...,
        [257.6502, 366.5770,   1.0368,  ..., 161.0761, 170.3075,   0.7110],
        [257.9695, 367.3370,   0.6832,  ..., 378.9736, 308.8602,   0.7224],
        [258.7916, 367.4227,   0.9331,  ..., 319.7108, 346.4739,   0.7220]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[258.2040, 366.6531,   1.1340,  ..., 350.8973, 320.1112,   1.0875],
        [257.5995, 367.0138,   0.6382,  ...,  59.9916, 307.7785,   0.8457],
        [258.1844, 366.5372,   1.0073,  ..., 417.3033, 219.2480,   0.9392],
        ...,
        [258.6592, 367.3235,   1.1246,  ..., 356.3911, 325.1501,   1.0300],
        [257.6791, 367.0189,   1.4878,  ..., 400.1054, 137.9445,   0.8667],
        [257.7144, 367.1612,   0.9327,  ..., 39

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[258.1002, 367.7035,   1.3245,  ..., 400.4601, 184.0939,   0.5539],
        [258.3243, 367.3353,   1.2044,  ..., 311.8206, 299.0063,   0.6461],
        [258.3047, 368.0116,   1.3112,  ..., 323.2852, 260.3079,   0.5389],
        ...,
        [257.8342, 367.1345,   1.1218,  ..., 168.7582, 116.8026,   1.0168],
        [258.5709, 367.9573,   1.2347,  ..., 351.2609, 218.9402,   0.8414],
        [258.2122, 367.1476,   1.0360,  ..., 313.0584, 343.9312,   0.5688]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[257.8632, 366.9838,   1.0269,  ..., 375.9188, 224.0126,   1.5916],
        [257.8865, 367.0305,   0.9927,  ..., 371.7722, 239.7020,   1.6208],
        [257.7127, 367.4651,   0.5177,  ..., 440.5021, 255.8311,   1.2673],
        ...,
        [257.4053, 366.5720,   0.6131,  ...,  56.9958, 315.5923,   1.2459],
        [257.8831, 367.1805,   0.9790,  ..., 399.3334, 186.9036,   1.3132],
    

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[258.5606, 366.3014,   1.6330,  ..., 443.5165, 247.6138,   1.0765],
        [258.4727, 367.2010,   0.7813,  ..., 283.8405,  97.6296,   0.8870],
        [258.0188, 367.7128,   1.1589,  ..., 425.8980, 347.5012,   0.9610],
        ...,
        [259.0702, 367.3150,   1.3691,  ..., 417.9296, 206.4491,   1.0368],
        [258.2394, 367.6081,   1.1924,  ..., 327.8257, 211.2924,   0.9262],
        [258.1098, 367.3275,   1.2363,  ..., 299.7605, 333.4037,   0.6839]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[257.5008, 366.3874,   0.9560,  ..., 267.6413, 114.5174,   1.3011],
        [257.7045, 366.4863,   0.7806,  ...,  89.8618, 231.0809,   1.4138],
        [257.4675, 366.6589,   1.0839,  ..., 232.3689,  98.5655,   1.0045],
        ...,
        [257.6708, 366.2625,   1.0371,  ..., 192.4895, 179.3084,   1.2926],
        [256.8040, 367.3768,   1.3275,  ..., 118.9654, 169.0565,   1.0858],
    

output tensor([[256.3882, 365.8505,   0.4776,  ..., 380.9245, 301.6612,   1.4259],
        [257.1442, 366.1919,   0.9055,  ..., 321.8187, 258.9012,   1.3542],
        [256.8385, 366.3177,   0.6832,  ..., 300.1395, 331.0783,   1.0673],
        ...,
        [256.5349, 366.1679,   0.9152,  ..., 101.2598, 193.8650,   1.2775],
        [256.8752, 365.1272,   0.5707,  ..., 350.8329, 297.3086,   1.6237],
        [256.5684, 366.4819,   0.7733,  ..., 291.2895, 325.6576,   1.1022]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[258.2740, 367.9196,   1.2118,  ..., 386.4053, 146.2926,   0.8547],
        [258.2094, 367.5553,   0.9566,  ..., 286.3637, 119.1322,   1.2121],
        [258.1158, 367.9452,   1.1383,  ..., 393.8583, 161.5697,   0.7571],
        ...,
        [257.8957, 367.7050,   1.3047,  ..., 401.0517, 201.3866,   0.9434],
        [258.3608, 367.0521,   1.0235,  ..., 432.1595, 317.6924,   0.5516],
        [258.5612, 368.6409,   1.0968,  ..., 30

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[262.7898, 371.9777,   2.3676,  ..., 150.5366, 152.8683,   1.6162],
        [264.3901, 374.9333,   2.1191,  ..., 284.6893, 103.7800,   1.7232],
        [265.7247, 375.2402,   2.8436,  ..., 437.3047, 240.6456,   1.6619],
        ...,
        [264.3388, 375.1110,   2.4590,  ..., 397.5802, 162.4446,   1.5637],
        [264.7977, 375.1743,   2.0867,  ..., 308.5506,  90.0398,   1.7785],
        [264.8508, 375.4521,   2.4748,  ..., 377.5037, 321.2444,   1.4983]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[ 2.5088e+02,  3.5834e+02, -2.2017e-01,  ...,  4.2497e+02,
          2.6690e+02,  4.1200e-02],
        [ 2.5157e+02,  3.5880e+02, -3.3572e-01,  ...,  3.5713e+02,
          1.4810e+02,  4.9724e-01],
        [ 2.5126e+02,  3.5809e+02,  1.5320e-01,  ...,  3.6380e+02,
          2.6153e+02,  4.6433e-01],
        ...,
        [ 2.5187e+02,  3.5910e+02,  1.7160e-01,  ...,  3.0251e+02,
        

output tensor([[255.0855, 362.7671,   1.7463,  ..., 379.1183, 273.9465,   1.5068],
        [255.9807, 364.1219,   1.4070,  ..., 217.7330, 170.9864,   1.0376],
        [256.2601, 363.0531,   1.6603,  ..., 427.0575, 219.5818,   1.0719],
        ...,
        [254.9069, 362.9229,   1.6659,  ..., 382.0737, 344.2979,   1.7165],
        [255.6830, 364.9334,   2.1742,  ..., 105.5617, 176.1446,   1.1773],
        [255.0257, 363.6767,   1.3994,  ..., 391.6676, 155.6733,   0.8311]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.5940e+02, 3.6950e+02, 8.7187e-01,  ..., 3.0282e+02, 1.1027e+02,
         6.2823e-01],
        [2.5942e+02, 3.6911e+02, 3.9387e-01,  ..., 2.9984e+02, 1.8370e+02,
         5.4553e-01],
        [2.5916e+02, 3.6902e+02, 7.7628e-01,  ..., 3.9326e+02, 2.1111e+02,
         6.7218e-01],
        ...,
        [2.5916e+02, 3.6871e+02, 1.2351e+00,  ..., 4.5791e+02, 2.7053e+02,
         1.1100e+00],
        [2.5947e+02, 3.6904e+02, 6.1057

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[259.0722, 367.3991,   1.1073,  ..., 438.2960, 312.1508,   1.2876],
        [259.2313, 368.3480,   0.7029,  ..., 240.6894, 106.3346,   1.1248],
        [259.4585, 368.4725,   1.3256,  ..., 462.8758, 264.1906,   1.6085],
        ...,
        [259.6394, 368.9274,   0.9694,  ...,  97.5086, 162.4907,   1.4921],
        [259.8506, 368.9192,   0.7878,  ...,  54.7672, 291.5888,   1.2342],
        [259.0182, 368.4930,   0.9227,  ..., 398.3093, 294.9582,   1.3521]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[256.3399, 364.5146,   1.2009,  ..., 316.3120, 158.0534,   1.0177],
        [255.8658, 364.1067,   1.2278,  ..., 413.6120, 197.4149,   0.5058],
        [256.8667, 365.4180,   0.8282,  ...,  76.4915, 289.6189,   0.8921],
        ...,
        [256.9356, 365.6343,   1.2468,  ..., 275.5442,  95.9557,   0.8184],
        [256.2763, 364.3853,   1.4055,  ..., 297.6312, 227.7503,   0.7296],
    

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[257.7164, 366.5531,   0.9139,  ..., 366.6034, 164.4856,   1.0033],
        [257.4915, 365.9158,   0.8606,  ..., 253.4236, 112.7119,   0.9345],
        [257.7200, 366.3783,   0.5253,  ..., 379.6767, 345.3676,   1.1127],
        ...,
        [257.3684, 366.2297,   1.1120,  ..., 392.0201, 215.9378,   1.1005],
        [257.5925, 366.6409,   0.7700,  ..., 420.4306, 354.1647,   1.1421],
        [257.3875, 366.2558,   1.1142,  ..., 391.2709, 216.5258,   1.1053]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[ 2.5867e+02,  3.6764e+02,  1.1085e+00,  ...,  3.3015e+02,
          2.0332e+02,  1.7102e-01],
        [ 2.5844e+02,  3.6797e+02,  1.1186e+00,  ...,  4.2826e+02,
          2.8512e+02,  1.4816e-01],
        [ 2.5870e+02,  3.6771e+02,  1.3564e+00,  ...,  3.7958e+02,
          1.3217e+02,  2.4889e-01],
        ...,
        [ 2.5859e+02,  3.6763e+02,  1.2596e+00,  ...,  4.0773e+02,
        

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[ 2.6059e+02,  3.7091e+02,  1.1121e+00,  ...,  3.5035e+02,
          2.0177e+02, -4.7296e-01],
        [ 2.5984e+02,  3.7049e+02,  1.2207e+00,  ...,  3.0743e+02,
          3.3857e+02, -5.6577e-01],
        [ 2.6108e+02,  3.7139e+02,  8.9473e-01,  ...,  4.3997e+02,
          2.4507e+02, -7.5490e-01],
        ...,
        [ 2.6040e+02,  3.7045e+02,  1.0120e+00,  ...,  3.2919e+02,
          1.7966e+02, -4.7123e-01],
        [ 2.6082e+02,  3.7118e+02,  8.2165e-01,  ...,  4.0946e+02,
          3.5589e+02, -1.0361e+00],
        [ 2.6032e+02,  3.7035e+02,  8.8659e-01,  ...,  3.4332e+02,
          1.1407e+02, -2.4876e-01]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[254.9204, 363.0222,   1.0132,  ..., 322.1339, 135.3540,   2.7309],
        [254.7911, 362.4993,   0.9204,  ...,  61.9099, 243.6567,   3.0897],
        [254.1069, 362.0116,   0.9083,  ...,  83.2437, 166.4988,   2.5967],
        ...,
 

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[ 2.6719e+02,  3.7572e+02,  5.2524e-01,  ...,  4.4143e+02,
          2.4797e+02, -1.4965e+00],
        [ 2.6556e+02,  3.7720e+02,  8.9612e-01,  ...,  3.1030e+02,
          3.4439e+02, -1.6406e+00],
        [ 2.6514e+02,  3.7702e+02,  9.0187e-01,  ...,  3.0919e+02,
          3.3981e+02, -1.7624e+00],
        ...,
        [ 2.6529e+02,  3.7639e+02,  2.9685e-01,  ...,  3.7701e+02,
          1.6966e+02, -1.1690e+00],
        [ 2.6438e+02,  3.7614e+02,  4.6321e-01,  ...,  1.9297e+02,
          1.1663e+02, -4.4840e-01],
        [ 2.6732e+02,  3.7797e+02, -3.6426e-02,  ...,  2.7286e+02,
          8.8381e+01, -1.8816e+00]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[249.8589, 355.2390,   1.3061,  ..., 156.5346, 139.7454,   3.2197],
        [250.4732, 357.7019,   0.9647,  ...,  79.0971, 234.2632,   3.9563],
        [251.2467, 356.3650,   1.9110,  ..., 316.1310, 258.6186,   2.2605],
        ...,
 

output tensor([[2.5605e+02, 3.6430e+02, 1.0491e+00,  ..., 2.5486e+02, 1.1099e+02,
         1.0985e+00],
        [2.5546e+02, 3.6445e+02, 1.6293e+00,  ..., 2.9850e+02, 3.1164e+02,
         5.9929e-01],
        [2.5534e+02, 3.6446e+02, 1.6933e+00,  ..., 2.9683e+02, 3.1362e+02,
         5.1054e-01],
        ...,
        [2.5560e+02, 3.6495e+02, 1.7561e+00,  ..., 2.9639e+02, 3.3157e+02,
         9.5862e-02],
        [2.5686e+02, 3.6439e+02, 1.3974e+00,  ..., 2.8424e+02, 8.1021e+01,
         6.5948e-01],
        [2.5562e+02, 3.6284e+02, 1.4595e+00,  ..., 1.7362e+02, 1.7801e+02,
         4.4379e-01]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[ 2.6101e+02,  3.7083e+02,  1.6137e+00,  ...,  2.5738e+02,
          2.0089e+02, -3.8245e-01],
        [ 2.6173e+02,  3.7119e+02,  1.2668e+00,  ...,  3.5598e+02,
          3.5046e+02, -3.3677e-01],
        [ 2.5948e+02,  3.6859e+02,  8.4149e-01,  ...,  1.0480e+02,
          2.0139e+02, -2.9092e-01],
        ...,

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[257.3942, 365.4337,   1.0416,  ..., 292.9735,  82.4933,   1.3552],
        [257.2357, 365.5780,   1.2324,  ..., 279.6178,  89.0911,   1.0364],
        [258.0171, 366.4283,   1.1103,  ...,  54.5032, 263.6535,   1.7869],
        ...,
        [256.9245, 365.4224,   1.9593,  ..., 341.2178, 253.6994,   1.4881],
        [257.4437, 365.7577,   1.1278,  ...,  74.2908, 200.7576,   2.0687],
        [257.1811, 365.9168,   1.2737,  ..., 369.3644, 313.0400,   1.8138]],
       grad_fn=<AddmmBackward0>)
torch.Size([60, 18])
torch.Size([60, 3])
output tensor([[2.5865e+02, 3.6814e+02, 6.2549e-01,  ..., 7.3390e+01, 2.9390e+02,
         9.8095e-01],
        [2.5795e+02, 3.6846e+02, 1.1110e+00,  ..., 1.0830e+02, 1.7591e+02,
         1.1759e+00],
        [2.5843e+02, 3.6724e+02, 8.7233e-01,  ..., 4.0997e+02, 1.8550e+02,
         6.2789e-01],
        ...,
        [2.5905e+02, 3.6901e+02, 1.0080e+00,  ..., 3.0834e+02, 3.4321e+02,
         2.8996e-01]

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[260.1771, 369.9785,   1.2624,  ..., 336.6647, 216.0841,   1.2561],
        [260.4054, 369.0250,   1.1002,  ..., 345.5403, 313.2347,   1.2151],
        [260.0963, 369.8420,   1.5610,  ..., 327.1247, 345.3504,   1.2685],
        ...,
        [260.0999, 369.7384,   1.3468,  ..., 356.5773, 213.4660,   1.2690],
        [259.8055, 369.1477,   1.2424,  ..., 331.4514, 183.4540,   1.3088],
        [259.8063, 369.3291,   1.5765,  ..., 288.8743, 128.0033,   1.4129]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.5528e+02, 3.6377e+02, 9.4226e-03,  ..., 5.4622e+01, 3.1140e+02,
         9.5484e-01],
        [2.5498e+02, 3.6311e+02, 3.9722e-01,  ..., 3.1640e+02, 1.5119e+02,
         1.1661e+00],
        [2.5515e+02, 3.6330e+02, 3.0287e-01,  ..., 3.1266e+02, 1.5719e+02,
         1.1418e+00],
        ...,
        [2.5451e+02, 3.6353e+02, 5.4945e-01,  ..., 2.9200e+02, 3.2936e+02,
         9.4738e-0

output tensor([[257.6114, 365.9975,   1.0972,  ...,  82.9714, 242.8070,   1.7812],
        [258.2793, 366.8579,   0.9573,  ..., 356.0738, 333.8790,   2.1195],
        [257.3740, 366.7782,   1.2572,  ..., 459.1541, 297.4028,   2.4298],
        ...,
        [257.7865, 366.6722,   1.1309,  ..., 366.6691, 248.9137,   2.1961],
        [257.8575, 366.5043,   1.1330,  ..., 381.5114, 137.7781,   2.0876],
        [257.9218, 366.8081,   0.9982,  ..., 317.7740, 345.6753,   1.6823]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[ 2.5796e+02,  3.6654e+02,  5.1113e-01,  ...,  7.9510e+01,
          2.7529e+02, -1.1046e-01],
        [ 2.5810e+02,  3.6661e+02,  1.0021e+00,  ...,  3.0192e+02,
          2.2558e+02,  3.1838e-02],
        [ 2.5859e+02,  3.6747e+02,  7.9539e-01,  ...,  4.1677e+02,
          2.4851e+02,  8.9787e-02],
        ...,
        [ 2.5808e+02,  3.6670e+02,  1.1529e+00,  ...,  3.2282e+02,
          1.5480e+02,  4.6283e-01],
        [ 2.583

output tensor([[267.0466, 378.0115,   2.3559,  ..., 150.7448, 162.1502,  -1.6708],
        [267.5132, 378.6089,   2.6192,  ..., 321.8953, 325.6655,  -2.3546],
        [270.8543, 383.9775,   2.8625,  ..., 458.9668, 263.8499,  -2.6701],
        ...,
        [269.6041, 382.0969,   2.5166,  ..., 390.7803, 176.4908,  -2.1648],
        [269.3994, 381.3904,   3.1030,  ..., 378.2128, 282.5258,  -2.2043],
        [270.7754, 383.8655,   3.5438,  ..., 478.5450, 287.3486,  -2.5047]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[ 2.4663e+02,  3.5210e+02, -1.2373e+00,  ...,  2.8943e+02,
          1.0274e+02,  2.0979e+00],
        [ 2.4584e+02,  3.5126e+02, -1.1900e+00,  ...,  2.2445e+02,
          1.2797e+02,  2.9175e+00],
        [ 2.4538e+02,  3.4908e+02, -5.9927e-01,  ...,  3.1317e+02,
          2.5650e+02,  2.0814e+00],
        ...,
        [ 2.4305e+02,  3.4823e+02, -1.0383e+00,  ...,  4.3949e+02,
          2.5662e+02,  3.9629e+00],
        [ 2.464

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[ 2.5797e+02,  3.6736e+02,  3.4708e-01,  ...,  3.9330e+02,
          1.5683e+02,  4.1930e-01],
        [ 2.5913e+02,  3.6800e+02,  4.5473e-01,  ...,  3.0656e+02,
          2.2722e+02,  2.0485e-01],
        [ 2.5863e+02,  3.6792e+02,  6.1721e-01,  ...,  4.6365e+02,
          2.9462e+02,  1.4195e-01],
        ...,
        [ 2.5859e+02,  3.6814e+02,  4.2918e-01,  ...,  4.3489e+02,
          2.8121e+02, -1.9231e-02],
        [ 2.5861e+02,  3.6843e+02,  3.8742e-01,  ...,  1.4636e+02,
          1.6456e+02,  1.9318e-01],
        [ 2.5920e+02,  3.6795e+02,  5.3142e-01,  ...,  1.7935e+02,
          1.3984e+02,  5.5608e-01]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[257.1692, 365.3922,   1.3874,  ..., 309.0884, 210.3090,   1.8729],
        [256.6891, 366.2591,   1.1351,  ..., 280.5784, 153.8159,   2.1117],
        [257.4379, 366.8190,   1.8874,  ..., 314.0946, 342.7203,   1.9054],
        ...,
 

output tensor([[258.7445, 367.1822,   1.2538,  ...,  59.5090, 256.2363,   1.6459],
        [258.5479, 368.0857,   1.0424,  ..., 443.4868, 235.2357,   1.7024],
        [257.7475, 366.5718,   1.0596,  ..., 353.1959, 194.7896,   1.5003],
        ...,
        [258.5276, 367.8632,   1.2652,  ..., 459.6112, 244.6875,   1.9760],
        [258.6136, 366.7881,   0.9576,  ..., 280.3484,  88.0870,   1.3281],
        [258.7117, 366.4825,   0.8142,  ..., 273.1251,  83.6691,   1.3033]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[257.8172, 366.6618,   0.5390,  ..., 285.1947, 106.9854,   0.8277],
        [257.8087, 366.9229,   0.5748,  ..., 346.5210, 188.3170,   0.8012],
        [257.9514, 367.3631,   1.2802,  ..., 386.6871, 128.4275,   1.1555],
        ...,
        [258.4992, 367.2480,   0.4694,  ..., 271.0985,  93.5630,   0.5067],
        [257.8912, 366.5368,   0.9974,  ..., 326.1733, 260.8882,   0.4981],
        [257.5787, 367.3525,   0.9099,  ...,  9

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[ 2.5676e+02,  3.6566e+02,  1.1721e+00,  ...,  4.3395e+02,
          2.4014e+02, -1.8325e-02],
        [ 2.5604e+02,  3.6420e+02,  1.3173e+00,  ...,  4.2508e+02,
          3.2290e+02, -1.0914e-01],
        [ 2.5635e+02,  3.6485e+02,  1.2655e+00,  ...,  3.3932e+02,
          1.1162e+02,  4.9741e-01],
        ...,
        [ 2.5697e+02,  3.6577e+02,  1.1856e+00,  ...,  1.6718e+02,
          1.0808e+02,  5.5423e-01],
        [ 2.5603e+02,  3.6444e+02,  1.6406e+00,  ...,  3.4320e+02,
          2.4796e+02, -3.9196e-02],
        [ 2.5624e+02,  3.6528e+02,  1.4687e+00,  ...,  4.1017e+02,
          1.4665e+02, -8.7459e-02]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[258.5727, 367.5695,   1.0702,  ..., 349.1385, 349.6996,   2.0653],
        [258.9174, 368.0711,   1.1068,  ..., 415.7677, 202.7244,   1.5262],
        [259.8528, 369.5957,   1.0836,  ..., 215.1855, 190.1133,   1.7991],
        ...,
 

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.6424e+02, 3.7540e+02, 1.4691e+00,  ..., 1.5723e+02, 1.4456e+02,
         7.0519e-01],
        [2.6388e+02, 3.7537e+02, 1.7729e+00,  ..., 1.2731e+02, 1.4765e+02,
         7.2920e-01],
        [2.6513e+02, 3.7709e+02, 1.1812e+00,  ..., 2.5422e+02, 1.6026e+02,
         7.2294e-01],
        ...,
        [2.6607e+02, 3.7815e+02, 7.5405e-01,  ..., 2.9682e+02, 1.7696e+02,
         2.2379e-01],
        [2.6597e+02, 3.7873e+02, 1.7933e+00,  ..., 1.1091e+02, 2.0058e+02,
         6.0723e-01],
        [2.6648e+02, 3.7768e+02, 1.6325e+00,  ..., 3.0004e+02, 2.4092e+02,
         1.4541e-01]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.4968e+02, 3.5414e+02, 3.4806e-01,  ..., 3.2109e+02, 2.9693e+02,
         4.6314e-01],
        [2.5167e+02, 3.5946e+02, 1.4400e-01,  ..., 8.4376e+01, 2.3651e+02,
         1.2092e+00],
        [2.5014e+02, 3.5604e+02, 4.8374e-01,  ..., 3.2176e+02, 1.2340e+02,
         

output tensor([[ 2.5725e+02,  3.6635e+02,  3.5548e-01,  ...,  3.5095e+02,
          2.2527e+02,  1.8791e+00],
        [ 2.5745e+02,  3.6709e+02,  1.7832e-01,  ...,  4.0237e+02,
          2.9638e+02,  1.5789e+00],
        [ 2.5697e+02,  3.6633e+02, -6.9178e-02,  ...,  3.8148e+02,
          1.5497e+02,  1.9166e+00],
        ...,
        [ 2.5749e+02,  3.6785e+02,  1.1156e+00,  ...,  4.0102e+02,
          1.3937e+02,  1.6075e+00],
        [ 2.5635e+02,  3.6549e+02,  5.4517e-02,  ...,  3.6036e+02,
          3.4802e+02,  2.0091e+00],
        [ 2.5788e+02,  3.6628e+02,  3.6594e-01,  ...,  4.0926e+02,
          1.8492e+02,  1.6044e+00]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[ 2.5902e+02,  3.6803e+02,  1.3099e+00,  ...,  2.8062e+02,
          9.9026e+01, -9.1353e-01],
        [ 2.5746e+02,  3.6605e+02,  1.5388e+00,  ...,  1.3389e+02,
          1.5130e+02,  1.9457e-01],
        [ 2.5818e+02,  3.6782e+02,  2.1328e+00,  ...,  2.5565e+02,
          2.

output tensor([[258.6802, 367.8970,   0.8486,  ...,  86.2357, 173.8291,   1.9203],
        [259.1859, 366.8514,   0.6246,  ..., 265.8445,  81.2316,   1.3480],
        [258.4240, 367.1957,   1.3474,  ..., 381.7051, 265.2338,   1.2512],
        ...,
        [258.2261, 367.2499,   0.9418,  ..., 216.3012, 180.0339,   1.2892],
        [258.6440, 367.2912,   1.3023,  ..., 407.3324, 184.1933,   1.3405],
        [257.9256, 367.0435,   1.2354,  ..., 347.7736, 344.5707,   1.2095]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[258.3472, 367.5023,   1.2595,  ..., 325.0989, 346.5612,   1.1471],
        [258.6371, 367.6551,   0.9359,  ..., 102.7876, 194.6257,   1.1515],
        [258.0397, 367.5224,   0.9692,  ..., 436.5872, 266.9435,   0.8819],
        ...,
        [257.8549, 367.0324,   1.2962,  ..., 272.3284, 118.7650,   1.3217],
        [257.5336, 367.1574,   1.5122,  ..., 297.0529, 333.6737,   1.2560],
        [258.5651, 367.0080,   1.0051,  ..., 29

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[259.8724, 368.1578,   1.7802,  ..., 334.1915, 309.5881,   1.5344],
        [259.5372, 368.0016,   1.6806,  ..., 316.5884, 303.9893,   1.8247],
        [258.8477, 367.9460,   1.6013,  ..., 160.3223, 116.7474,   1.4871],
        ...,
        [259.3245, 368.2204,   1.6859,  ..., 380.2544, 343.6670,   1.4619],
        [258.8506, 368.6496,   1.6918,  ..., 155.4573, 111.7955,   1.5227],
        [259.4038, 368.5574,   1.4865,  ..., 281.5500,  98.4045,   1.3580]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[ 2.5705e+02,  3.6536e+02,  4.7617e-01,  ...,  3.8317e+02,
          2.6402e+02,  2.2453e-01],
        [ 2.5792e+02,  3.6626e+02,  6.8435e-01,  ...,  5.9875e+01,
          2.5304e+02,  3.8293e-01],
        [ 2.5725e+02,  3.6609e+02,  3.9763e-01,  ...,  3.5417e+02,
          3.2521e+02, -3.4623e-03],
        ...,
        [ 2.5697e+02,  3.6596e+02,  8.1430e-01,  ...,  9.3525e+01,
        

output tensor([[255.4053, 363.9078,   1.3173,  ...,  64.4016, 295.2695,   2.6697],
        [255.6540, 364.1007,   1.1451,  ...,  62.0703, 298.8644,   2.6610],
        [256.0492, 364.4178,   1.8912,  ..., 376.9321, 341.5566,   3.2078],
        ...,
        [256.5656, 365.0376,   2.1632,  ..., 253.4719, 162.0021,   3.1247],
        [255.9482, 363.8894,   1.4117,  ...,  72.8541, 193.0000,   3.1084],
        [257.2520, 363.7418,   2.2472,  ..., 421.9148, 223.3899,   3.6154]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[ 2.5883e+02,  3.6781e+02,  3.2215e-01,  ...,  3.5034e+02,
          1.9763e+02, -1.9977e-01],
        [ 2.5912e+02,  3.6858e+02, -2.5014e-01,  ...,  4.0288e+02,
          3.5380e+02, -7.5018e-01],
        [ 2.5877e+02,  3.6787e+02,  2.7335e-01,  ...,  3.8322e+02,
          1.4519e+02, -1.0354e-01],
        ...,
        [ 2.5902e+02,  3.6771e+02,  1.9666e-01,  ...,  2.7550e+02,
          9.4353e+01, -6.4612e-01],
        [ 2.578

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[261.6469, 374.7230,   5.3319,  ..., 318.2705, 308.4033,  -4.4134],
        [262.1641, 377.3499,   6.5321,  ..., 309.9690, 352.8523,  -5.5514],
        [261.7925, 375.6518,   5.0949,  ..., 436.6146, 254.1945,  -4.6515],
        ...,
        [261.5545, 374.4913,   5.5155,  ..., 347.3623, 256.1786,  -3.7821],
        [260.2842, 373.0715,   4.1609,  ..., 184.6819, 123.7033,  -3.3178],
        [259.9300, 372.8300,   4.0835,  ..., 215.8733, 119.1389,  -3.4321]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[245.4928, 347.0280, -10.4146,  ..., 311.5476, 235.8467,  12.2805],
        [243.6010, 346.0844, -11.6516,  ..., 417.5638, 258.5839,  13.5900],
        [245.8599, 349.5491,  -8.9532,  ...,  90.0468, 168.1249,  11.2355],
        ...,
        [242.5216, 343.9505, -12.7170,  ..., 368.9489, 317.5745,  14.0110],
        [243.0983, 345.0386, -12.6188,  ..., 339.7552, 319.8621,  13.4135],
    

output tensor([[ 2.5610e+02,  3.6247e+02, -9.4058e-01,  ...,  5.6463e+01,
          3.1009e+02,  1.3734e-01],
        [ 2.5765e+02,  3.6031e+02, -2.0087e+00,  ...,  4.4913e+02,
          2.9184e+02,  3.7509e+00],
        [ 2.5449e+02,  3.6171e+02, -5.6679e-01,  ...,  8.6995e+01,
          2.1321e+02,  2.4185e+00],
        ...,
        [ 2.5745e+02,  3.6360e+02, -8.6496e-01,  ...,  3.9344e+02,
          3.5122e+02,  1.3050e+00],
        [ 2.5373e+02,  3.6110e+02,  5.3825e-01,  ...,  6.8292e+01,
          2.5543e+02,  1.0182e+00],
        [ 2.5684e+02,  3.6440e+02, -6.4160e-01,  ...,  3.1668e+02,
          3.0421e+02,  1.9264e+00]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[260.8743, 371.1993,   2.1282,  ..., 430.9402, 244.2429,   1.3820],
        [258.4108, 368.5902,   3.6682,  ..., 341.3862, 208.6738,   1.6777],
        [259.5588, 368.5986,   3.1887,  ..., 122.4164, 171.9921,   1.0219],
        ...,
        [259.5948, 369.8608,   2.1353,  ...,

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.5722e+02, 3.6653e+02, 1.1546e+00,  ..., 4.0024e+02, 3.5207e+02,
         1.1770e-01],
        [2.5623e+02, 3.6467e+02, 1.5992e+00,  ..., 1.4211e+02, 1.5045e+02,
         1.1045e+00],
        [2.5775e+02, 3.6611e+02, 1.8529e+00,  ..., 3.6725e+02, 2.8390e+02,
         1.8614e-01],
        ...,
        [2.5706e+02, 3.6721e+02, 1.5512e+00,  ..., 3.5125e+02, 1.3308e+02,
         1.9712e-01],
        [2.5600e+02, 3.6609e+02, 2.1919e+00,  ..., 4.0838e+02, 1.5050e+02,
         1.5265e+00],
        [2.5770e+02, 3.6693e+02, 1.4923e+00,  ..., 4.2128e+02, 3.4984e+02,
         2.3657e-01]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.5725e+02, 3.6732e+02, 2.2115e+00,  ..., 1.5229e+02, 1.6719e+02,
         8.7324e-01],
        [2.5760e+02, 3.6630e+02, 1.6815e+00,  ..., 8.0243e+01, 2.5895e+02,
         5.8289e-01],
        [2.5697e+02, 3.6692e+02, 3.1734e-01,  ..., 3.7154e+02, 2.4422e+02,
         

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.5848e+02, 3.6686e+02, 8.3837e-01,  ..., 3.0322e+02, 3.1944e+02,
         1.2933e+00],
        [2.5905e+02, 3.6863e+02, 1.0378e+00,  ..., 9.8668e+01, 2.1739e+02,
         7.3927e-02],
        [2.5969e+02, 3.6840e+02, 1.4500e+00,  ..., 8.5307e+01, 2.4129e+02,
         4.7550e-01],
        ...,
        [2.5913e+02, 3.6826e+02, 8.7213e-01,  ..., 4.0937e+02, 3.4900e+02,
         1.1174e+00],
        [2.6050e+02, 3.6796e+02, 8.1155e-01,  ..., 4.3332e+02, 2.2981e+02,
         1.3613e+00],
        [2.5857e+02, 3.6816e+02, 7.3624e-01,  ..., 1.7375e+02, 1.1746e+02,
         1.2786e+00]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.5651e+02, 3.6565e+02, 1.3215e+00,  ..., 4.3443e+02, 2.6332e+02,
         1.0910e+00],
        [2.5796e+02, 3.6622e+02, 1.5941e+00,  ..., 8.0241e+01, 1.9518e+02,
         1.9857e+00],
        [2.5670e+02, 3.6542e+02, 2.8072e-01,  ..., 2.1257e+02, 1.1652e+02,
         

output tensor([[2.5699e+02, 3.6551e+02, 1.6387e+00,  ..., 3.4853e+02, 1.8968e+02,
         1.1844e+00],
        [2.5620e+02, 3.6587e+02, 1.0273e-01,  ..., 1.8271e+02, 1.1219e+02,
         1.1574e+00],
        [2.5706e+02, 3.6726e+02, 3.3037e-01,  ..., 8.7545e+01, 1.9446e+02,
         8.9093e-01],
        ...,
        [2.5658e+02, 3.6607e+02, 1.1277e+00,  ..., 3.6189e+02, 3.2919e+02,
         7.3367e-01],
        [2.5813e+02, 3.6642e+02, 6.6438e-01,  ..., 4.1531e+02, 1.8907e+02,
         1.0529e+00],
        [2.5716e+02, 3.6582e+02, 1.6594e+00,  ..., 3.9403e+02, 1.6915e+02,
         1.0912e+00]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.5905e+02, 3.6823e+02, 2.6122e-01,  ..., 2.9414e+02, 3.3041e+02,
         1.1776e+00],
        [2.5801e+02, 3.6708e+02, 9.3233e-01,  ..., 1.2406e+02, 1.4569e+02,
         1.3213e+00],
        [2.5888e+02, 3.6833e+02, 3.0057e-01,  ..., 2.8580e+02, 1.0941e+02,
         8.9396e-01],
        ...,
        [2.5875e+

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.5800e+02, 3.6675e+02, 6.9969e-01,  ..., 1.1278e+02, 2.0033e+02,
         2.6370e-01],
        [2.5848e+02, 3.6698e+02, 1.1158e+00,  ..., 2.8250e+02, 1.6217e+02,
         1.5473e+00],
        [2.5831e+02, 3.6732e+02, 1.4270e+00,  ..., 5.8000e+01, 2.5564e+02,
         1.1310e+00],
        ...,
        [2.5834e+02, 3.6691e+02, 8.8904e-01,  ..., 4.2286e+02, 1.9331e+02,
         7.8508e-01],
        [2.5927e+02, 3.6644e+02, 1.5969e+00,  ..., 4.5249e+02, 2.3381e+02,
         1.6534e+00],
        [2.5795e+02, 3.6750e+02, 1.0507e+00,  ..., 2.9842e+02, 2.2883e+02,
         9.6516e-01]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[258.0194, 367.0938,   0.8852,  ..., 274.2370,  93.7205,   0.6175],
        [258.8489, 367.4423,   0.7914,  ..., 216.4516, 187.8200,   0.8211],
        [257.8724, 366.3669,   1.3412,  ..., 438.9881, 258.5392,   1.0618],
        ...,
        [258.2965, 367.5334,   1.0325

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[257.1987, 366.5557,   1.2531,  ..., 177.8636, 114.2261,   1.0070],
        [256.8210, 366.5732,   0.7205,  ..., 163.2953, 116.2643,   0.9330],
        [257.3843, 366.8753,   1.1262,  ..., 348.7511, 346.6933,   0.8365],
        ...,
        [256.8055, 366.5907,   0.7100,  ..., 162.6700, 116.2650,   0.9248],
        [258.1815, 366.3264,   1.4438,  ..., 321.9146, 305.4026,   1.0763],
        [257.4618, 366.8153,   0.7348,  ...,  62.2499, 308.7618,   1.2039]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[257.4326, 367.4275,   0.5479,  ..., 161.5476, 116.4715,   1.0309],
        [258.3077, 367.6190,   1.0885,  ..., 306.2910, 220.8935,   0.9452],
        [258.2555, 367.4847,   1.5532,  ..., 380.9588, 143.3350,   1.0348],
        ...,
        [258.6561, 367.0229,   1.0160,  ..., 136.7888, 153.7823,   1.0555],
        [258.0653, 367.5127,   0.5795,  ..., 372.7562, 162.1283,   0.9169],
    

output tensor([[260.5094, 370.3104,   1.5559,  ..., 281.8833, 155.9861,   1.0872],
        [262.7291, 371.5729,   1.2820,  ..., 343.6593, 312.6114,   0.5905],
        [261.0826, 371.0812,   1.5530,  ..., 312.6867, 224.2897,   0.6466],
        ...,
        [261.6033, 371.8964,   1.7402,  ..., 241.2710, 189.8026,   0.7506],
        [259.6670, 370.3357,   1.3421,  ..., 326.9715, 172.7711,   0.9359],
        [263.1112, 374.2344,   1.4468,  ..., 300.5756, 339.4756,   0.8322]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[ 2.5429e+02,  3.6098e+02, -5.4765e-02,  ...,  3.3433e+02,
          1.1580e+02,  1.2642e+00],
        [ 2.5391e+02,  3.6175e+02, -7.4188e-03,  ...,  3.7054e+02,
          1.5478e+02,  1.4020e+00],
        [ 2.5647e+02,  3.6393e+02,  6.9547e-01,  ...,  3.2415e+02,
          3.0193e+02,  1.1275e+00],
        ...,
        [ 2.5463e+02,  3.6331e+02,  5.4308e-01,  ...,  2.0228e+02,
          1.0958e+02,  1.3222e+00],
        [ 2.552

output tensor([[257.3194, 366.9419,   2.1725,  ..., 409.5244, 144.3507,   1.2653],
        [257.9641, 366.8954,   1.7684,  ..., 449.4510, 308.4444,   1.1643],
        [257.3066, 365.7753,   1.9608,  ..., 428.0555, 293.3424,   0.9564],
        ...,
        [257.8307, 366.8173,   1.8127,  ..., 351.4792, 223.2397,   1.2912],
        [258.8788, 367.4660,   1.6880,  ...,  66.5880, 248.2479,   1.1050],
        [257.9487, 367.4322,   1.4092,  ..., 171.4281, 110.4003,   1.2205]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.5804e+02, 3.6762e+02, 5.0557e-01,  ..., 3.0709e+02, 2.5246e+02,
         1.0382e+00],
        [2.5885e+02, 3.6698e+02, 6.7411e-01,  ..., 2.6745e+02, 8.1135e+01,
         8.9470e-01],
        [2.5865e+02, 3.6748e+02, 1.9169e-01,  ..., 2.1305e+02, 1.8673e+02,
         9.1127e-01],
        ...,
        [2.5845e+02, 3.6631e+02, 3.2116e-01,  ..., 2.7048e+02, 1.0411e+02,
         1.1209e+00],
        [2.5867e+02, 3.6675e+02, 3.0896

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[267.7184, 380.0410,   1.7840,  ..., 375.6429, 357.5363,   0.9978],
        [265.1411, 375.6705,   1.7054,  ..., 199.8757, 186.5127,   0.5171],
        [267.8340, 380.3541,   1.6937,  ..., 313.2306, 352.1116,   0.7126],
        ...,
        [265.7357, 375.9769,   1.4188,  ...,  84.9167, 287.1013,   1.0246],
        [267.9059, 379.8907,   1.2899,  ..., 440.0230, 331.0642,   0.6241],
        [264.0128, 375.6930,   1.7677,  ..., 189.6427,  97.8151,   0.8493]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[249.3866, 356.5018,   0.6483,  ..., 417.4839, 340.9989,   1.2987],
        [250.0231, 355.8532,   1.0019,  ..., 272.7404,  85.3731,   1.0803],
        [249.5210, 356.5403,   0.7810,  ..., 367.4507, 215.2885,   1.3053],
        ...,
        [249.9726, 356.2557,   0.9488,  ..., 288.7296, 315.2184,   1.1674],
        [249.8237, 357.0665,   0.7235,  ..., 430.7530, 225.9340,   1.1029],
    

output tensor([[ 2.5653e+02,  3.6481e+02,  5.8261e-01,  ...,  4.4348e+02,
          3.0104e+02,  7.0198e-01],
        [ 2.5670e+02,  3.6462e+02,  1.7768e-01,  ...,  2.8019e+02,
          1.1998e+02,  9.4088e-01],
        [ 2.5553e+02,  3.6399e+02,  1.2902e-01,  ...,  2.9185e+02,
          2.3302e+02,  5.5994e-01],
        ...,
        [ 2.5557e+02,  3.6359e+02,  1.9916e-02,  ...,  2.6694e+02,
          1.7199e+02,  9.3321e-01],
        [ 2.5498e+02,  3.6369e+02, -5.6102e-02,  ...,  2.9155e+02,
          2.2724e+02,  6.3172e-01],
        [ 2.5573e+02,  3.6460e+02, -2.3395e-01,  ...,  3.6376e+02,
          1.6145e+02,  7.3623e-01]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[260.4135, 370.1725,   1.7371,  ..., 353.5465, 195.5484,   1.1718],
        [260.8278, 370.3627,   1.5944,  ..., 386.9015, 268.6537,   0.7807],
        [260.8081, 370.8094,   1.3599,  ..., 369.2567, 351.7795,   0.9333],
        ...,
        [259.1609, 368.0856,   0.9438,  ...,

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[260.0881, 370.5583,   0.8930,  ..., 432.5950, 247.7948,   0.7178],
        [259.7389, 369.3556,   1.2631,  ..., 397.7481, 174.2217,   1.1481],
        [258.8819, 368.6802,   1.7650,  ..., 177.7730,  96.7113,   0.9568],
        ...,
        [260.0460, 369.1887,   1.1576,  ..., 305.8692, 317.0009,   1.2828],
        [260.4018, 369.8271,   1.5240,  ..., 382.1870, 268.4287,   0.9334],
        [259.0246, 369.2965,   0.8945,  ...,  89.6382, 198.7469,   1.1294]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.5657e+02, 3.6619e+02, 8.3334e-01,  ..., 1.6732e+02, 1.0947e+02,
         1.1710e+00],
        [2.5695e+02, 3.6511e+02, 6.3816e-01,  ..., 3.6293e+02, 3.1713e+02,
         1.2430e+00],
        [2.5660e+02, 3.6688e+02, 8.8219e-02,  ..., 4.1738e+02, 2.4207e+02,
         7.9512e-01],
        ...,
        [2.5660e+02, 3.6516e+02, 3.7484e-01,  ..., 2.6047e+02, 1.0252e+02,
         1.0159e+0

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[259.1411, 369.2442,   1.3822,  ..., 230.2513,  93.7753,   0.8372],
        [258.7602, 368.1072,   1.5221,  ..., 246.7646, 105.1782,   0.9002],
        [259.1763, 368.5977,   1.3854,  ..., 311.5146, 206.0827,   0.8447],
        ...,
        [259.5148, 368.2656,   1.8733,  ..., 382.7968, 344.8257,   1.3102],
        [259.4892, 368.8183,   1.8557,  ..., 352.8101, 216.2265,   1.1652],
        [259.2478, 368.2187,   1.5472,  ..., 270.6459, 105.1001,   0.9046]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[255.9791, 364.5197,   0.6557,  ..., 397.9320, 185.6234,   1.1608],
        [255.8706, 365.2179,   0.5237,  ..., 205.5362, 110.3955,   1.0380],
        [256.4070, 365.1872,   1.0279,  ..., 457.9102, 291.5053,   1.1521],
        ...,
        [256.9690, 364.7521,   0.7294,  ..., 426.3871, 230.9571,   1.4013],
        [256.4799, 365.1783,   0.8505,  ..., 450.6453, 302.3602,   1.0738],
    

output tensor([[2.5790e+02, 3.6756e+02, 1.3698e+00,  ..., 4.3185e+02, 3.4583e+02,
         1.3338e+00],
        [2.5839e+02, 3.6788e+02, 6.6981e-01,  ..., 2.3170e+02, 8.9941e+01,
         9.2664e-01],
        [2.5873e+02, 3.6707e+02, 1.5299e+00,  ..., 3.3704e+02, 2.5530e+02,
         1.2882e+00],
        ...,
        [2.5714e+02, 3.6676e+02, 1.1409e+00,  ..., 1.1623e+02, 1.3998e+02,
         1.0770e+00],
        [2.5853e+02, 3.6690e+02, 3.7352e-01,  ..., 3.4311e+02, 1.1549e+02,
         1.5840e+00],
        [2.5861e+02, 3.6836e+02, 9.5827e-01,  ..., 2.8700e+02, 8.9933e+01,
         7.8977e-01]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[ 2.5858e+02,  3.6841e+02,  1.9180e-01,  ...,  4.1759e+02,
          2.1202e+02,  1.9215e+00],
        [ 2.5814e+02,  3.6727e+02,  7.9882e-01,  ...,  3.1986e+02,
          9.3400e+01,  1.7534e+00],
        [ 2.5957e+02,  3.6885e+02, -3.3638e-01,  ...,  7.7357e+01,
          2.7483e+02,  2.1466e+00],
        ...,

output tensor([[2.5571e+02, 3.6485e+02, 2.3964e-01,  ..., 1.1233e+02, 2.0343e+02,
         5.9374e-01],
        [2.5725e+02, 3.6535e+02, 4.1687e-01,  ..., 2.9824e+02, 3.1843e+02,
         1.0442e+00],
        [2.5676e+02, 3.6626e+02, 4.5273e-01,  ..., 4.1572e+02, 2.4563e+02,
         6.1927e-01],
        ...,
        [2.5631e+02, 3.6514e+02, 7.9949e-01,  ..., 3.2052e+02, 1.2625e+02,
         9.8151e-01],
        [2.5700e+02, 3.6603e+02, 5.7748e-01,  ..., 3.4225e+02, 3.4515e+02,
         8.9899e-01],
        [2.5752e+02, 3.6495e+02, 1.2506e+00,  ..., 4.2785e+02, 2.3096e+02,
         1.1456e+00]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[259.2366, 368.6836,   0.9740,  ..., 302.7107, 334.1014,   0.7588],
        [258.9475, 367.5856,   1.4372,  ...,  86.6724, 247.0891,   0.9782],
        [259.5185, 369.2605,   0.7875,  ..., 431.4047, 178.5995,   1.0247],
        ...,
        [259.0392, 368.8781,   1.1477,  ..., 355.8860, 347.1628,   0.8226],
    

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[261.5135, 373.0264,   1.3297,  ..., 412.5665, 355.4240,   1.0333],
        [260.2437, 370.8724,   2.1938,  ..., 176.6747,  98.2958,   0.6924],
        [263.3354, 372.0874,   1.2065,  ..., 453.9329, 234.2384,   1.3581],
        ...,
        [260.7066, 371.4918,   1.3553,  ..., 282.3579, 155.3413,   0.7516],
        [261.2342, 372.2612,   1.6198,  ..., 427.9666, 352.3291,   1.2349],
        [259.9516, 369.9378,   0.9983,  ..., 357.4819, 141.3838,   1.0783]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[ 2.5420e+02,  3.6254e+02,  7.1911e-01,  ...,  1.6346e+02,
          1.0193e+02,  7.1019e-01],
        [ 2.5485e+02,  3.6276e+02,  5.4328e-02,  ...,  2.8663e+02,
          9.6928e+01,  5.9308e-01],
        [ 2.5432e+02,  3.6263e+02, -2.3536e-01,  ...,  2.7752e+02,
          1.5725e+02,  6.4136e-01],
        ...,
        [ 2.5509e+02,  3.6314e+02,  3.3015e-01,  ...,  3.8386e+02,
        

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[256.8599, 365.9384,   1.1675,  ..., 321.9969, 155.9340,   1.0196],
        [257.4944, 365.7191,   1.2499,  ..., 190.7691, 179.1183,   0.8243],
        [256.9290, 365.3216,   0.6734,  ..., 300.5757, 180.7268,   0.9733],
        ...,
        [257.0616, 366.0720,   1.0457,  ..., 396.1825, 207.0725,   0.8577],
        [257.2973, 366.6567,   0.9626,  ..., 351.7426, 344.8471,   0.8739],
        [257.4925, 366.3827,   1.2491,  ..., 460.7831, 281.9432,   0.8500]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.5865e+02, 3.6818e+02, 9.4003e-01,  ..., 2.0919e+02, 8.6346e+01,
         7.0034e-01],
        [2.5861e+02, 3.6802e+02, 1.1784e+00,  ..., 1.7881e+02, 9.5694e+01,
         9.2119e-01],
        [2.5855e+02, 3.6763e+02, 8.4341e-01,  ..., 3.0964e+02, 1.1300e+02,
         9.4128e-01],
        ...,
        [2.5858e+02, 3.6784e+02, 4.3325e-01,  ..., 3.5520e+02, 1.4265e+02,
         1.0417e+0

output tensor([[256.6216, 365.7236,   0.5228,  ..., 305.7770, 336.8644,   0.6073],
        [256.9490, 365.7243,   0.7520,  ..., 246.6565, 110.4391,   1.0913],
        [257.0888, 365.7408,   0.5135,  ..., 409.4209, 349.5546,   0.7139],
        ...,
        [257.0043, 365.3853,   0.8041,  ..., 283.2661, 106.0323,   1.0078],
        [256.8339, 366.3712,   0.7481,  ..., 322.8047, 345.2385,   0.5906],
        [256.7912, 366.1938,   0.6885,  ..., 343.0696, 346.3545,   0.5801]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[259.3238, 368.4927,   1.2432,  ..., 409.7501, 297.3369,   1.0761],
        [259.2025, 368.0519,   1.3651,  ..., 381.1993, 343.9960,   1.5790],
        [258.5197, 368.2343,   0.7600,  ..., 372.4439, 165.5131,   1.2964],
        ...,
        [259.1614, 367.6281,   1.1401,  ..., 454.1622, 235.6828,   1.1535],
        [259.5057, 368.4254,   1.1574,  ..., 287.3459,  82.1455,   0.9489],
        [258.8345, 369.0010,   1.1660,  ..., 42

output tensor([[2.6194e+02, 3.7160e+02, 2.3391e+00,  ..., 3.9266e+02, 2.9471e+02,
         4.3665e-01],
        [2.6093e+02, 3.7049e+02, 2.7946e+00,  ..., 2.3157e+02, 1.7612e+02,
         2.7901e-01],
        [2.6122e+02, 3.7036e+02, 2.2356e+00,  ..., 4.5872e+02, 2.3596e+02,
         2.1069e-01],
        ...,
        [2.6058e+02, 3.7092e+02, 1.9318e+00,  ..., 2.3176e+02, 9.0847e+01,
         1.2546e-01],
        [2.6059e+02, 3.7027e+02, 2.0299e+00,  ..., 2.6048e+02, 1.0905e+02,
         4.0399e-01],
        [2.6118e+02, 3.7095e+02, 2.3619e+00,  ..., 3.2071e+02, 2.5937e+02,
         2.1652e-01]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.5502e+02, 3.6301e+02, 3.2205e-01,  ..., 1.5162e+02, 1.4740e+02,
         1.7176e+00],
        [2.5482e+02, 3.6179e+02, 9.5016e-01,  ..., 4.2556e+02, 2.4820e+02,
         1.5318e+00],
        [2.5504e+02, 3.6331e+02, 5.6758e-01,  ..., 1.6035e+02, 1.1258e+02,
         1.4653e+00],
        ...,
        [2.5500e+

output tensor([[257.4494, 367.2706,   4.5173,  ..., 345.5457, 345.4402,   1.0876],
        [257.9224, 368.0279,   4.6360,  ..., 359.3201, 321.7023,   1.2824],
        [258.6117, 367.9701,   4.4189,  ..., 395.3233, 162.6186,   1.5169],
        ...,
        [255.7248, 366.5039,   4.4749,  ..., 107.6749, 194.3451,   2.0530],
        [256.2752, 365.5078,   6.0588,  ..., 284.5870, 170.9695,   1.1185],
        [254.4916, 362.5691,   4.1864,  ...,  72.2720, 294.3275,   2.4715]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[ 2.6095e+02,  3.6970e+02,  1.1945e+00,  ...,  3.1414e+02,
          2.5175e+02, -2.1532e-01],
        [ 2.5964e+02,  3.6633e+02, -2.3599e-01,  ...,  4.3139e+02,
          2.1311e+02, -5.4530e-01],
        [ 2.5845e+02,  3.6757e+02,  3.6622e-01,  ...,  4.2040e+02,
          1.8917e+02, -7.3860e-01],
        ...,
        [ 2.5791e+02,  3.6836e+02,  1.9045e-01,  ...,  4.2871e+02,
          2.3895e+02, -9.6612e-01],
        [ 2.575

output tensor([[ 2.5747e+02,  3.6386e+02, -9.7283e-01,  ...,  4.4188e+02,
          2.3576e+02,  3.7174e+00],
        [ 2.5706e+02,  3.6727e+02, -1.8858e-01,  ...,  7.5076e+01,
          2.3791e+02,  2.4065e+00],
        [ 2.5558e+02,  3.6353e+02,  2.2745e-01,  ...,  2.8995e+02,
          3.2209e+02,  3.3741e+00],
        ...,
        [ 2.5691e+02,  3.6627e+02,  9.4922e-01,  ...,  7.8267e+01,
          1.8926e+02,  2.4878e+00],
        [ 2.5593e+02,  3.6448e+02,  2.6076e-01,  ...,  3.5891e+02,
          3.1884e+02,  3.5923e+00],
        [ 2.5579e+02,  3.6348e+02,  3.7016e-01,  ...,  3.4075e+02,
          3.3610e+02,  3.3870e+00]], grad_fn=<AddmmBackward0>)
torch.Size([60, 18])
torch.Size([60, 3])
output tensor([[ 2.5842e+02,  3.6860e+02,  1.1400e+00,  ...,  1.0955e+02,
          2.1713e+02, -2.4590e+00],
        [ 2.6014e+02,  3.6898e+02,  1.8493e+00,  ...,  4.6581e+02,
          2.6740e+02, -1.7684e+00],
        [ 2.5786e+02,  3.6588e+02,  1.8663e+00,  ...,  3.9090e+02,
          1.57

output tensor([[257.4978, 367.0372,   1.7271,  ...,  90.5740, 188.6798,   1.1015],
        [257.1921, 366.4329,   2.3472,  ..., 377.3516, 344.4250,   1.5066],
        [256.8861, 366.2116,   2.4427,  ..., 351.0269, 322.4624,   1.3950],
        ...,
        [257.9494, 367.3616,   1.4926,  ..., 365.5345, 286.1786,   1.2423],
        [257.2523, 366.8983,   1.5781,  ..., 214.7478, 187.9821,   1.4267],
        [256.8382, 365.3628,   1.6061,  ..., 416.4986, 339.2883,   1.8511]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[ 2.5792e+02,  3.6850e+02, -3.9875e-02,  ...,  4.0204e+02,
          3.5452e+02, -4.4038e-01],
        [ 2.5733e+02,  3.6724e+02,  7.0175e-01,  ...,  4.1915e+02,
          2.4698e+02, -6.5189e-01],
        [ 2.5748e+02,  3.6696e+02,  5.4290e-01,  ...,  4.2056e+02,
          2.3014e+02, -8.5118e-01],
        ...,
        [ 2.5828e+02,  3.6703e+02, -1.3888e-01,  ...,  4.5386e+02,
          2.3427e+02,  1.0180e-01],
        [ 2.565

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[258.4875, 366.4797,   0.9824,  ..., 233.0215, 125.6847,   1.0527],
        [257.8218, 367.7471,   1.3763,  ..., 439.7601, 250.1936,   1.0304],
        [258.0668, 366.8783,   1.6117,  ..., 153.2881, 166.0206,   0.9835],
        ...,
        [258.4211, 366.6453,   1.1884,  ..., 246.0270, 152.1230,   1.0742],
        [258.4221, 366.7882,   0.8891,  ..., 323.9767, 306.7056,   1.4487],
        [257.7001, 367.1854,   1.2225,  ..., 349.8426, 344.7007,   1.3162]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[258.2221, 367.7740,   0.5898,  ..., 307.7478, 337.0424,   1.0597],
        [258.4926, 367.9137,   1.2412,  ..., 319.5406, 161.6080,   0.8487],
        [257.7606, 367.2551,   0.6626,  ..., 414.9314, 298.0608,   1.0669],
        ...,
        [258.0625, 367.6149,   1.1286,  ..., 297.9998, 238.0319,   0.8713],
        [258.7588, 367.8415,   0.9413,  ..., 219.3349, 175.7341,   1.0137],
    

output tensor([[261.2691, 371.2491,   1.0701,  ..., 459.8080, 303.7514,   0.9474],
        [260.3068, 370.2447,   0.6938,  ..., 172.4827, 118.9866,   0.7958],
        [260.9123, 371.0547,   0.8628,  ..., 314.2719, 255.9296,   0.9441],
        ...,
        [260.6770, 370.8312,   0.8571,  ..., 309.5928, 249.5612,   0.8758],
        [260.8046, 370.6312,   1.1301,  ..., 108.7065, 178.7428,   0.8787],
        [260.6734, 369.7768,   0.9596,  ..., 362.6174, 322.9805,   1.0257]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[254.1219, 362.5344,   0.5514,  ...,  99.7454, 212.0874,   0.7006],
        [254.3551, 362.3064,   1.0994,  ..., 361.5142, 159.7318,   0.8069],
        [254.6447, 362.4910,   1.1150,  ..., 267.9998, 175.1548,   0.9349],
        ...,
        [254.2792, 362.9744,   1.1585,  ..., 257.5845, 202.8853,   0.9192],
        [254.5841, 362.9482,   1.6852,  ...,  74.1212, 188.2257,   1.2146],
        [254.6507, 362.4427,   1.1163,  ..., 26

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[259.9666, 369.0634,   0.8192,  ..., 368.4924, 286.0922,   0.4413],
        [259.8141, 369.8102,   1.0735,  ..., 216.1196, 181.6310,   0.9791],
        [260.2198, 369.3844,   1.1311,  ...,  86.9186, 237.2049,   0.9276],
        ...,
        [259.2787, 368.8766,   1.2642,  ..., 345.7240, 347.4333,   0.6731],
        [260.1331, 369.2009,   1.0819,  ...,  81.8877, 263.8383,   0.6372],
        [259.6931, 369.6794,   0.6620,  ..., 236.8654, 107.4920,   0.9177]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[255.6257, 363.6510,   1.6263,  ...,  79.0831, 241.6889,   1.0831],
        [255.8563, 363.7893,   1.4627,  ..., 349.2921, 209.5323,   1.0429],
        [254.9191, 363.8257,   1.5044,  ..., 325.8008,  92.4109,   1.1076],
        ...,
        [255.7094, 364.8470,   1.2814,  ..., 276.3819,  86.0955,   0.8495],
        [255.5508, 363.5040,   0.9116,  ..., 133.8934, 146.5356,   1.0673],
    

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[259.7092, 368.8403,   1.0399,  ..., 365.2128, 257.6638,   0.9912],
        [259.3804, 369.2228,   0.7182,  ..., 418.7454, 250.2321,   0.7054],
        [259.3147, 369.0731,   0.5363,  ..., 376.3763, 162.4512,   0.9387],
        ...,
        [259.1773, 369.5127,   1.2494,  ..., 320.0893, 165.0175,   1.0538],
        [259.9097, 368.9336,   0.9842,  ..., 354.6868, 324.2587,   1.0700],
        [260.2535, 369.5679,   1.2537,  ..., 349.2077, 187.3442,   1.0624]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[256.2285, 365.1147,   1.0700,  ..., 301.5109, 335.3665,   0.7761],
        [256.2397, 365.1001,   1.3761,  ..., 421.8722, 290.0710,   0.8056],
        [256.6812, 365.8093,   1.2777,  ..., 437.2833, 254.1746,   0.6955],
        ...,
        [257.0215, 365.9908,   0.7319,  ..., 391.3288, 161.5128,   1.0549],
        [257.4901, 365.1610,   1.0607,  ..., 236.9761, 144.1204,   0.9823],
    

output tensor([[258.3305, 366.9675,   1.0457,  ..., 233.2510, 106.5881,   1.1250],
        [258.4691, 366.6151,   0.8098,  ..., 182.0518, 129.0036,   1.2658],
        [258.3455, 366.6986,   0.7029,  ..., 176.3380, 121.5626,   1.1239],
        ...,
        [258.6838, 367.9684,   0.5459,  ...,  86.9125, 196.4103,   1.2655],
        [258.3600, 366.8542,   0.8056,  ..., 340.2607, 308.8687,   1.0296],
        [258.0785, 366.4446,   1.2119,  ..., 400.7330, 195.0245,   1.1459]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[258.4233, 367.4919,   1.5925,  ..., 340.0311, 108.6756,   1.3470],
        [258.5950, 367.4112,   0.8909,  ..., 167.1427, 116.5152,   0.8598],
        [258.5195, 367.8832,   1.1786,  ..., 303.7855, 227.5794,   1.0424],
        ...,
        [259.1831, 368.6982,   1.4067,  ...,  75.3271, 198.7890,   1.2850],
        [258.8801, 367.9562,   1.1050,  ..., 393.2222, 160.3036,   1.3632],
        [258.9319, 367.3513,   1.1236,  ..., 28

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[260.0922, 368.6524,   1.2337,  ..., 341.3490, 250.1598,   1.4359],
        [259.3123, 367.7995,   1.4129,  ..., 420.8629, 354.1883,   1.2133],
        [261.4084, 371.6559,   1.1261,  ..., 282.7081, 101.6886,   1.3164],
        ...,
        [261.1037, 371.0897,   0.8117,  ..., 303.1803, 192.7076,   1.3074],
        [260.6132, 369.3093,   0.5959,  ..., 373.1313, 164.5154,   1.5777],
        [261.6353, 372.0504,   0.7875,  ..., 277.1357, 151.4754,   1.0587]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.5673e+02, 3.6476e+02, 7.0949e-01,  ..., 1.8368e+02, 1.3432e+02,
         1.5998e+00],
        [2.5583e+02, 3.6377e+02, 5.9038e-01,  ..., 4.5274e+02, 2.5504e+02,
         1.5062e+00],
        [2.5526e+02, 3.6337e+02, 5.8554e-01,  ..., 3.9447e+02, 2.9964e+02,
         1.8024e+00],
        ...,
        [2.5665e+02, 3.6424e+02, 1.9040e-01,  ..., 3.5732e+02, 1.6029e+02,
         1.4180e+0

output tensor([[256.9328, 364.7511,   1.3320,  ..., 154.4428, 165.3000,   0.4631],
        [256.5098, 364.7771,   0.5702,  ..., 110.3227, 195.2378,   0.4403],
        [257.0524, 364.9426,   1.2899,  ..., 271.4389,  89.8261,   0.4273],
        ...,
        [256.0722, 364.7714,   1.2460,  ..., 111.3121, 142.8486,   0.4358],
        [257.0955, 365.4275,   0.9734,  ..., 337.4292, 249.1599,   0.6471],
        [257.5313, 365.8720,   1.2049,  ...,  56.3449, 256.1955,   0.4667]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[258.9425, 368.0781,   1.2873,  ..., 280.3412,  88.9241,   1.1338],
        [259.7555, 367.8856,   0.9813,  ..., 310.7371, 310.7704,   1.6367],
        [260.2499, 368.9386,   1.1251,  ..., 323.9084, 307.3817,   1.6527],
        ...,
        [259.2654, 368.4736,   0.9451,  ..., 396.2854, 212.8046,   1.3837],
        [259.6831, 368.9855,   1.1792,  ..., 154.9432, 114.7764,   1.5388],
        [260.2785, 368.9981,   1.0938,  ..., 33

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.5922e+02, 3.6828e+02, 6.1427e-01,  ..., 3.6193e+02, 2.6101e+02,
         9.5146e-01],
        [2.5947e+02, 3.6891e+02, 4.7632e-01,  ..., 1.8918e+02, 1.1377e+02,
         7.7494e-01],
        [2.5975e+02, 3.6883e+02, 1.0779e+00,  ..., 3.5034e+02, 1.9086e+02,
         9.2647e-01],
        ...,
        [2.5949e+02, 3.6922e+02, 1.0981e+00,  ..., 3.4427e+02, 2.0461e+02,
         9.6804e-01],
        [2.5908e+02, 3.6929e+02, 3.2865e-01,  ..., 4.3004e+02, 1.7910e+02,
         9.6913e-01],
        [2.5920e+02, 3.6904e+02, 8.2348e-01,  ..., 1.6198e+02, 1.1006e+02,
         7.4781e-01]], grad_fn=<AddmmBackward0>)
torch.Size([60, 18])
torch.Size([60, 3])
output tensor([[257.3629, 366.1953,   0.5284,  ..., 330.1182, 260.8456,   1.1546],
        [257.0122, 366.4235,   0.9520,  ..., 101.2874, 152.0540,   1.0311],
        [257.6396, 366.4090,   0.5811,  ..., 419.4430, 352.4884,   1.1367],
        ...,
        [257.1243, 365.9149,   1.0424, 

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[255.7502, 364.4900,   1.2108,  ..., 308.0724, 218.6159,   0.7222],
        [256.3532, 364.3342,   1.6077,  ...,  79.7305, 257.8662,   0.6984],
        [256.4426, 364.9953,   0.8425,  ..., 167.1245, 114.6896,   0.6563],
        ...,
        [256.8822, 365.1547,   0.8287,  ..., 324.6740, 126.1046,   0.8165],
        [257.7528, 366.2492,   1.2578,  ..., 151.8803, 114.1674,   1.0547],
        [255.5036, 364.5318,   1.5611,  ..., 456.6854, 297.0300,   0.7299]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[259.0679, 368.3557,   1.1169,  ..., 237.2937, 184.4690,   1.1275],
        [258.7699, 368.2445,   1.2576,  ..., 302.4039, 333.7985,   0.7867],
        [259.2794, 369.0379,   1.4429,  ..., 327.5565, 180.4537,   1.0077],
        ...,
        [259.5046, 369.4806,   1.3852,  ..., 331.5952, 213.0647,   0.9890],
        [259.2151, 367.9816,   1.0442,  ..., 261.8057, 164.7544,   0.8887],
    

output tensor([[257.9383, 365.8342,   0.8874,  ..., 377.4128, 307.1436,   1.2293],
        [258.3308, 367.9968,   1.5617,  ..., 327.1324, 183.5544,   1.1663],
        [257.1677, 366.3398,   1.4479,  ..., 443.1990, 307.4916,   1.3123],
        ...,
        [257.2056, 365.9113,   0.9622,  ..., 401.5362, 351.3051,   1.0835],
        [257.7637, 365.4974,   0.9908,  ..., 366.6628, 314.6924,   1.2365],
        [257.4958, 366.2511,   0.9567,  ..., 417.3203, 350.7786,   1.2067]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[258.6551, 366.8326,   0.4413,  ...,  80.0086, 289.2366,   1.4575],
        [259.1705, 367.6340,   0.8504,  ..., 356.7655, 262.3838,   1.3473],
        [259.1312, 368.2571,   0.9919,  ..., 152.3588, 167.5782,   1.2928],
        ...,
        [258.1950, 367.4102,   0.6648,  ..., 321.5275, 338.3911,   1.3598],
        [259.6351, 368.9098,   0.5940,  ..., 308.6378, 256.0657,   1.2169],
        [258.4840, 367.3005,   0.7365,  ..., 43

output tensor([[256.9317, 365.4956,   0.8531,  ..., 328.7799, 142.5444,   0.8784],
        [256.8799, 364.8151,   1.3840,  ..., 293.7668,  79.5010,   0.5782],
        [257.9028, 367.3518,   1.5080,  ..., 427.5093, 289.9008,   0.8852],
        ...,
        [259.3390, 369.1371,   1.2711,  ..., 431.8033, 348.7564,   0.9412],
        [257.9001, 366.6267,   1.2072,  ..., 171.0650,  97.4371,   0.9348],
        [256.7823, 365.1727,   0.7767,  ..., 373.0257, 160.8029,   0.9279]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[256.8647, 364.6418,   1.3759,  ...,  78.4947, 279.6882,   0.9021],
        [255.0677, 362.9928,   0.5720,  ..., 325.3894, 134.2269,   0.8956],
        [256.6476, 365.4408,   1.2245,  ..., 228.1906, 178.9730,   1.1076],
        ...,
        [255.7480, 364.3961,   0.9834,  ..., 280.3192, 174.6297,   0.9153],
        [255.9509, 365.0231,   1.0306,  ..., 294.5723, 235.1841,   0.9611],
        [257.7184, 366.8890,   0.7241,  ...,  8

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.5519e+02, 3.6310e+02, 1.2465e+00,  ..., 3.0896e+02, 8.5974e+01,
         6.4634e-01],
        [2.5545e+02, 3.6408e+02, 3.8781e-01,  ..., 4.0365e+02, 2.9458e+02,
         8.2861e-01],
        [2.5536e+02, 3.6402e+02, 9.8319e-01,  ..., 8.0130e+01, 1.8051e+02,
         1.0006e+00],
        ...,
        [2.5509e+02, 3.6331e+02, 3.3149e-01,  ..., 2.9851e+02, 1.8164e+02,
         9.6632e-01],
        [2.5479e+02, 3.6301e+02, 7.1660e-01,  ..., 1.6058e+02, 1.1537e+02,
         7.6788e-01],
        [2.5612e+02, 3.6457e+02, 6.6898e-01,  ..., 4.0295e+02, 3.5105e+02,
         7.5717e-01]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[259.9470, 369.3691,   1.3356,  ...,  94.7326, 229.0427,   1.2906],
        [259.8457, 369.1440,   0.8843,  ..., 329.6880, 139.8121,   1.2203],
        [260.1956, 370.1951,   0.9391,  ..., 304.5759, 338.7460,   1.2510],
        ...,
        [260.3516, 370.3959,   0.9562

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.5977e+02, 3.6933e+02, 1.4664e+00,  ..., 3.4457e+02, 3.4776e+02,
         8.7660e-01],
        [2.5957e+02, 3.6961e+02, 9.3966e-01,  ..., 3.8359e+02, 1.4360e+02,
         9.9063e-01],
        [2.5973e+02, 3.6945e+02, 1.1425e+00,  ..., 3.7358e+02, 2.7098e+02,
         6.7905e-01],
        ...,
        [2.5880e+02, 3.6836e+02, 6.9549e-01,  ..., 2.4026e+02, 1.0596e+02,
         8.7267e-01],
        [2.5996e+02, 3.6902e+02, 1.2370e+00,  ..., 2.8937e+02, 8.2758e+01,
         3.8506e-01],
        [2.6003e+02, 3.6948e+02, 9.0774e-01,  ..., 4.0965e+02, 3.5484e+02,
         6.7063e-01]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[255.9748, 363.7379,   0.7708,  ..., 160.1572, 141.8055,   1.5524],
        [256.0090, 363.9164,   1.4864,  ..., 379.0667, 342.4416,   1.5758],
        [255.9388, 364.2793,   1.2374,  ...,  53.6807, 259.6559,   1.1696],
        ...,
        [256.0343, 364.0196,   1.4565

output tensor([[2.6238e+02, 3.7294e+02, 1.4147e+00,  ..., 3.2030e+02, 2.5790e+02,
         3.9152e-01],
        [2.6463e+02, 3.7499e+02, 1.9678e+00,  ..., 1.8448e+02, 1.3988e+02,
         5.9569e-01],
        [2.6416e+02, 3.7611e+02, 1.4094e+00,  ..., 4.1436e+02, 1.4542e+02,
         7.6093e-01],
        ...,
        [2.6472e+02, 3.7615e+02, 1.9059e+00,  ..., 2.8479e+02, 8.7729e+01,
         6.3387e-02],
        [2.6182e+02, 3.7218e+02, 1.5941e+00,  ..., 3.1118e+02, 3.3518e+02,
         4.7384e-01],
        [2.6423e+02, 3.7574e+02, 1.7268e+00,  ..., 2.4322e+02, 1.0756e+02,
         4.9780e-01]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[ 2.5183e+02,  3.5921e+02,  4.8594e-01,  ...,  3.2544e+02,
          9.9658e+01,  1.7180e+00],
        [ 2.5027e+02,  3.5771e+02,  1.5486e-01,  ...,  4.4832e+02,
          2.6823e+02,  1.8457e+00],
        [ 2.5158e+02,  3.5824e+02, -2.4940e-02,  ...,  2.9342e+02,
          3.1150e+02,  1.8720e+00],
        ...,

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[255.4810, 364.3251,   1.7401,  ..., 428.0280, 277.8970,   1.5675],
        [255.7163, 363.8767,   0.8983,  ..., 339.6138, 129.5362,   1.7985],
        [256.5650, 365.1822,   0.9712,  ..., 290.0274, 100.1228,   1.5771],
        ...,
        [256.1926, 364.4440,   1.4508,  ..., 377.4648, 340.0512,   1.7278],
        [257.1070, 363.6331,   1.1656,  ..., 424.1844, 225.5514,   1.6484],
        [256.1377, 364.5081,   1.4664,  ..., 342.5925, 237.2802,   1.6000]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.5903e+02, 3.6846e+02, 1.2840e+00,  ..., 1.7833e+02, 1.7923e+02,
         2.6079e-01],
        [2.5963e+02, 3.6904e+02, 1.3644e+00,  ..., 9.7235e+01, 2.2634e+02,
         6.3131e-01],
        [2.5908e+02, 3.6847e+02, 8.2766e-01,  ..., 2.8550e+02, 1.6677e+02,
         3.3636e-01],
        ...,
        [2.5917e+02, 3.6753e+02, 1.4368e+00,  ..., 4.4604e+02, 2.4858e+02,
         5.3517e-0

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.5763e+02, 3.6638e+02, 7.5557e-01,  ..., 2.7670e+02, 1.2107e+02,
         1.1843e+00],
        [2.5713e+02, 3.6620e+02, 9.5204e-01,  ..., 3.2361e+02, 1.5665e+02,
         1.1321e+00],
        [2.5756e+02, 3.6644e+02, 2.7130e-01,  ..., 1.6544e+02, 1.1784e+02,
         1.0128e+00],
        ...,
        [2.5757e+02, 3.6677e+02, 4.7122e-01,  ..., 2.1286e+02, 1.8833e+02,
         1.2900e+00],
        [2.5734e+02, 3.6643e+02, 1.9332e-01,  ..., 1.5820e+02, 1.1877e+02,
         9.6358e-01],
        [2.5763e+02, 3.6666e+02, 1.3142e-01,  ..., 3.9965e+02, 2.9597e+02,
         1.1328e+00]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[258.1794, 367.6305,   0.8581,  ...,  52.6035, 293.6317,   0.5835],
        [259.1037, 367.3731,   1.1089,  ..., 220.2959, 118.5286,   0.7460],
        [258.3664, 367.6754,   1.0648,  ..., 410.8346, 345.3768,   0.8517],
        ...,
        [258.4651, 367.6391,   1.1605

output tensor([[2.6499e+02, 3.7666e+02, 1.6737e+00,  ..., 9.2980e+01, 1.9893e+02,
         5.8860e-01],
        [2.6747e+02, 3.7894e+02, 1.6489e+00,  ..., 3.6296e+02, 2.1145e+02,
         5.6299e-01],
        [2.6973e+02, 3.8198e+02, 1.1119e+00,  ..., 4.3253e+02, 3.5755e+02,
         5.2246e-01],
        ...,
        [2.6508e+02, 3.7700e+02, 1.6771e+00,  ..., 8.4616e+01, 2.0411e+02,
         7.4799e-01],
        [2.6972e+02, 3.8196e+02, 1.0975e+00,  ..., 4.3021e+02, 3.5934e+02,
         5.0127e-01],
        [2.6720e+02, 3.7843e+02, 1.7420e+00,  ..., 3.4086e+02, 9.9611e+01,
         3.6275e-01]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[ 2.4640e+02,  3.5227e+02, -7.5994e-01,  ...,  2.4222e+02,
          1.9324e+02,  2.0632e+00],
        [ 2.4922e+02,  3.5537e+02,  3.1236e-01,  ...,  3.9426e+02,
          2.0066e+02,  2.1763e+00],
        [ 2.4732e+02,  3.5198e+02,  2.2918e-01,  ...,  2.6661e+02,
          1.0558e+02,  2.3171e+00],
        ...,

output tensor([[257.8702, 365.5985,   1.7666,  ..., 101.9857, 186.2620,   1.0926],
        [257.5141, 366.4574,   1.7526,  ..., 165.5630, 179.1528,   0.7888],
        [258.7871, 368.0963,   1.0791,  ..., 307.8509, 205.6627,   1.0508],
        ...,
        [257.8848, 365.5840,   1.7506,  ..., 102.8092, 186.4258,   1.1010],
        [257.9936, 365.8286,   0.7762,  ..., 428.9747, 317.2537,   1.0951],
        [258.0351, 366.4761,   1.3215,  ..., 353.5674, 322.8188,   1.1588]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[ 2.5661e+02,  3.6486e+02, -4.1611e-01,  ...,  5.6129e+01,
          2.9193e+02,  1.1270e+00],
        [ 2.5776e+02,  3.6717e+02,  7.2499e-02,  ...,  3.9041e+02,
          3.5138e+02,  1.3938e+00],
        [ 2.5828e+02,  3.6685e+02,  3.8746e-01,  ...,  4.5739e+02,
          2.9976e+02,  1.2659e+00],
        ...,
        [ 2.5807e+02,  3.6630e+02,  4.2742e-01,  ...,  3.4887e+02,
          1.9165e+02,  1.1774e+00],
        [ 2.579

output tensor([[257.7479, 366.5206,   1.0587,  ..., 180.2229, 114.3327,   0.9362],
        [258.1588, 367.4950,   1.4484,  ..., 327.2643, 337.8247,   1.0500],
        [257.8853, 367.0008,   0.8522,  ..., 301.5083, 182.3788,   1.0037],
        ...,
        [258.5660, 367.0197,   1.1030,  ..., 369.9177, 163.9100,   0.9166],
        [258.1464, 367.3088,   1.4632,  ..., 301.0747, 238.8832,   1.0396],
        [256.9271, 366.0065,   1.0721,  ..., 212.1968, 110.6575,   0.9648]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[257.3088, 366.8303,   0.8809,  ..., 185.8683, 121.6177,   0.9867],
        [257.8490, 366.5591,   0.9284,  ..., 360.3116, 320.4009,   1.1653],
        [257.8806, 367.0989,   0.7910,  ..., 392.5701, 353.2997,   0.9123],
        ...,
        [258.3302, 368.1329,   1.1259,  ..., 307.7219, 345.9752,   1.1199],
        [258.4718, 367.2198,   1.2295,  ...,  91.4673, 223.7817,   1.2922],
        [258.1214, 367.3834,   1.1826,  ..., 34

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[259.7250, 368.9232,   1.2274,  ...,  87.8185, 242.8819,   0.9472],
        [260.0490, 369.4653,   0.9146,  ..., 335.3211, 257.9612,   1.0130],
        [260.3760, 369.7719,   1.3360,  ..., 346.0431, 196.1227,   1.1699],
        ...,
        [259.4958, 368.4753,   0.7582,  ..., 369.3686, 348.6945,   0.9980],
        [258.6177, 368.8669,   0.7047,  ..., 422.1695, 162.2787,   1.1953],
        [259.0338, 368.8065,   0.7940,  ..., 182.4075, 120.6399,   0.8425]],
       grad_fn=<AddmmBackward0>)
torch.Size([60, 18])
torch.Size([60, 3])
output tensor([[256.5502, 364.0314,   1.0406,  ..., 437.0876, 253.5079,   1.0600],
        [257.2148, 365.3352,   1.1507,  ..., 276.0134,  81.3017,   0.6356],
        [256.9611, 365.7147,   1.3645,  ..., 274.1499,  93.0801,   0.8423],
        ...,
        [256.7266, 365.4350,   1.3413,  ..., 332.4629,  99.0430,   1.1193],
        [257.4608, 365.2937,   1.0467,  ..., 180.4838, 137.5700,   1.1018],
      

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[256.5881, 364.9673,   0.9906,  ..., 421.2268, 348.8078,   1.7591],
        [256.1464, 365.0768,   1.0882,  ..., 458.2016, 258.4072,   1.5919],
        [256.2358, 364.4791,   1.2681,  ..., 184.9775, 178.7201,   1.7009],
        ...,
        [256.8257, 365.5098,   0.5441,  ...,  89.7438, 195.1210,   1.6901],
        [256.8439, 364.5688,   1.0039,  ..., 183.8014, 129.2355,   1.9689],
        [255.7920, 364.5771,   0.9529,  ..., 104.6282, 209.2699,   1.3756]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[ 2.5976e+02,  3.6905e+02,  7.4973e-01,  ...,  1.1487e+02,
          1.7485e+02, -1.5469e-03],
        [ 2.5901e+02,  3.6773e+02,  1.2044e+00,  ...,  2.6094e+02,
          1.6187e+02,  3.6767e-01],
        [ 2.5952e+02,  3.6772e+02,  7.5453e-01,  ...,  2.2933e+02,
          1.2077e+02,  3.6297e-01],
        ...,
        [ 2.5938e+02,  3.6952e+02,  1.0816e+00,  ...,  4.2453e+02,
        

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.6766e+02, 3.8034e+02, 7.4495e-01,  ..., 3.0287e+02, 1.7301e+02,
         4.7882e-01],
        [2.6722e+02, 3.7901e+02, 1.2350e+00,  ..., 2.9110e+02, 8.4600e+01,
         2.6598e-01],
        [2.6712e+02, 3.7977e+02, 1.2840e+00,  ..., 3.3807e+02, 1.5137e+02,
         5.6698e-01],
        ...,
        [2.6693e+02, 3.7854e+02, 1.0371e+00,  ..., 3.0391e+02, 3.3696e+02,
         4.7672e-01],
        [2.6696e+02, 3.7826e+02, 1.0014e+00,  ..., 3.0615e+02, 3.2979e+02,
         4.5324e-01],
        [2.6684e+02, 3.7941e+02, 8.9744e-01,  ..., 2.9546e+02, 9.8068e+01,
         4.3039e-01]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[ 2.4572e+02,  3.5006e+02, -3.0267e-01,  ...,  3.8289e+02,
          1.7265e+02,  2.0917e+00],
        [ 2.4727e+02,  3.5225e+02, -3.3550e-01,  ...,  1.4108e+02,
          1.4423e+02,  1.8571e+00],
        [ 2.4827e+02,  3.5367e+02, -3.6017e-01,  ...,  6.3983e+01,
     

output tensor([[257.1334, 366.4484,   1.3571,  ..., 286.6853, 112.1539,   1.8891],
        [257.1055, 365.5251,   1.2605,  ..., 382.0499, 267.6961,   1.5704],
        [257.6784, 366.0561,   1.0926,  ..., 180.4195, 125.0454,   1.8656],
        ...,
        [257.3882, 365.9734,   0.6621,  ..., 391.5873, 297.5962,   1.8262],
        [257.1065, 366.2656,   1.3149,  ..., 221.4343, 178.1107,   2.0503],
        [257.4787, 366.2051,   1.0200,  ..., 418.1681, 214.6808,   1.3384]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[ 2.5920e+02,  3.6866e+02,  1.3548e+00,  ...,  3.9098e+02,
          2.1211e+02, -6.2421e-02],
        [ 2.5863e+02,  3.6835e+02,  1.7096e+00,  ...,  3.7388e+02,
          2.3013e+02, -2.0486e-01],
        [ 2.5840e+02,  3.6784e+02,  1.3449e+00,  ...,  1.4938e+02,
          1.6075e+02,  2.5575e-01],
        ...,
        [ 2.5831e+02,  3.6790e+02,  1.7313e+00,  ...,  3.9203e+02,
          1.3539e+02, -3.2360e-02],
        [ 2.587

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[257.4676, 366.1945,   1.6710,  ..., 433.4727, 315.1733,   2.1227],
        [257.4373, 366.4534,   1.5949,  ..., 373.3021, 224.0163,   1.7717],
        [257.2581, 365.9518,   1.6135,  ..., 344.7371, 236.3678,   1.8020],
        ...,
        [257.4670, 366.7528,   1.7397,  ..., 325.8960, 212.2968,   1.9453],
        [257.9603, 366.2636,   1.3921,  ..., 369.5094, 314.6100,   2.0158],
        [258.0496, 366.9205,   1.6004,  ..., 298.3979, 336.3125,   1.8210]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.5799e+02, 3.6679e+02, 7.3855e-01,  ..., 3.7938e+02, 3.4131e+02,
         5.6826e-01],
        [2.5859e+02, 3.6787e+02, 5.9179e-01,  ..., 3.9660e+02, 1.6631e+02,
         1.0086e+00],
        [2.5799e+02, 3.6673e+02, 1.2869e+00,  ..., 8.9115e+01, 2.3844e+02,
         2.3359e-01],
        ...,
        [2.5764e+02, 3.6708e+02, 9.2874e-01,  ..., 1.1905e+02, 1.4185e+02,
         2.1969e-0

output tensor([[258.9529, 368.8129,   0.6922,  ..., 112.2650, 171.8154,   1.3200],
        [258.7474, 368.0277,   1.0651,  ...,  80.3738, 243.0608,   1.3554],
        [258.4468, 367.6740,   0.5292,  ..., 400.5425, 354.1937,   0.7535],
        ...,
        [258.3718, 367.7086,   0.8452,  ..., 273.8788,  91.5212,   1.2072],
        [258.9882, 367.8193,   0.8861,  ..., 349.1515, 186.6794,   1.1333],
        [258.4699, 367.8652,   0.6999,  ..., 339.9968, 252.9214,   1.0607]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.5710e+02, 3.6656e+02, 1.2595e+00,  ..., 1.8526e+02, 1.2123e+02,
         4.5504e-01],
        [2.5752e+02, 3.6693e+02, 1.7402e+00,  ..., 3.1295e+02, 3.4372e+02,
         4.5155e-01],
        [2.5743e+02, 3.6591e+02, 1.5579e+00,  ..., 3.6775e+02, 2.7160e+02,
         4.0608e-01],
        ...,
        [2.5726e+02, 3.6579e+02, 1.7213e+00,  ..., 3.7546e+02, 3.4741e+02,
         2.4448e-01],
        [2.5716e+02, 3.6573e+02, 1.5874

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[254.5655, 363.3756,   1.8106,  ..., 316.4888, 337.3476,   2.4097],
        [255.1258, 362.6367,   1.3774,  ...,  61.1476, 243.1348,   1.8519],
        [253.6060, 359.8380,   1.7435,  ..., 223.1691, 114.5957,   1.7760],
        ...,
        [254.1535, 361.4098,   1.9917,  ...,  84.7553, 229.2836,   2.1051],
        [254.7287, 363.1461,   1.4254,  ..., 408.1350, 293.3948,   2.9217],
        [252.7451, 360.7625,   1.3239,  ..., 231.1144, 100.3327,   1.9489]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[ 2.6045e+02,  3.7050e+02,  5.0179e-01,  ...,  2.6256e+02,
          2.0464e+02,  2.5266e-01],
        [ 2.6149e+02,  3.7128e+02,  3.3957e-01,  ...,  6.3755e+01,
          3.0906e+02,  5.0123e-02],
        [ 2.5983e+02,  3.6897e+02,  9.2471e-01,  ...,  2.3707e+02,
          1.4861e+02,  5.5841e-01],
        ...,
        [ 2.6020e+02,  3.7111e+02,  3.4254e-01,  ...,  3.1937e+02,
        

output tensor([[2.5621e+02, 3.6550e+02, 3.6141e-01,  ..., 3.0924e+02, 3.4312e+02,
         1.0963e+00],
        [2.5735e+02, 3.6656e+02, 2.5169e-01,  ..., 9.6131e+01, 1.8259e+02,
         7.4096e-01],
        [2.5616e+02, 3.6492e+02, 3.4439e-01,  ..., 3.0558e+02, 2.1376e+02,
         1.0393e+00],
        ...,
        [2.5581e+02, 3.6528e+02, 3.4098e-01,  ..., 4.5719e+02, 2.5733e+02,
         6.5454e-01],
        [2.5587e+02, 3.6374e+02, 3.4170e-01,  ..., 3.3648e+02, 3.0670e+02,
         9.0918e-01],
        [2.5600e+02, 3.6470e+02, 6.5341e-01,  ..., 3.6567e+02, 2.4449e+02,
         7.4868e-01]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[259.4822, 369.2493,   1.3356,  ..., 202.4409, 113.4825,   0.8391],
        [259.7220, 369.5581,   1.7650,  ..., 321.4243, 161.1208,   1.2808],
        [259.7695, 369.4389,   1.3503,  ..., 306.0109, 192.4300,   1.4212],
        ...,
        [260.3251, 370.3181,   1.7101,  ..., 101.3538, 153.4090,   0.9010],
    

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.5739e+02, 3.6563e+02, 1.4344e+00,  ..., 3.4321e+02, 3.0729e+02,
         4.1351e-01],
        [2.5751e+02, 3.6607e+02, 1.2392e+00,  ..., 3.6337e+02, 2.9133e+02,
         3.9930e-01],
        [2.5707e+02, 3.6589e+02, 1.4831e+00,  ..., 2.7987e+02, 1.7763e+02,
         5.6064e-01],
        ...,
        [2.5756e+02, 3.6649e+02, 1.4387e+00,  ..., 3.9423e+02, 3.5142e+02,
         4.7896e-01],
        [2.5770e+02, 3.6672e+02, 1.6906e+00,  ..., 4.3662e+02, 3.1047e+02,
         4.8902e-01],
        [2.5758e+02, 3.6659e+02, 1.4921e+00,  ..., 3.6151e+02, 3.2872e+02,
         6.1669e-01]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.5783e+02, 3.6723e+02, 2.0455e-01,  ..., 4.2599e+02, 1.8524e+02,
         1.6965e+00],
        [2.5846e+02, 3.6649e+02, 2.3270e-01,  ..., 4.1743e+02, 2.0452e+02,
         1.5634e+00],
        [2.5815e+02, 3.6672e+02, 5.4591e-01,  ..., 2.8549e+02, 8.1565e+01,
         

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[2.5143e+02, 3.5851e+02, 5.1861e-01,  ..., 9.5526e+01, 2.1469e+02,
         1.9388e+00],
        [2.5231e+02, 3.5921e+02, 5.1772e-01,  ..., 5.8773e+01, 3.0585e+02,
         1.6575e+00],
        [2.5094e+02, 3.5754e+02, 4.3315e-01,  ..., 3.9232e+02, 1.6427e+02,
         1.8034e+00],
        ...,
        [2.5255e+02, 3.5907e+02, 6.7827e-01,  ..., 3.4474e+02, 1.8915e+02,
         1.5435e+00],
        [2.5120e+02, 3.5653e+02, 4.2284e-02,  ..., 4.2968e+02, 2.3655e+02,
         1.8921e+00],
        [2.5137e+02, 3.5799e+02, 3.9757e-01,  ..., 3.9458e+02, 1.7459e+02,
         1.7405e+00]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[264.7753, 376.5082,   0.5392,  ..., 281.8123, 155.6444,   0.4486],
        [265.4069, 376.7247,   1.3865,  ..., 365.0348, 354.0662,   0.7087],
        [264.3864, 374.9493,   1.0525,  ..., 309.0555, 321.8526,   0.9048],
        ...,
        [264.8566, 376.5643,   1.2963

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[ 2.5654e+02,  3.6468e+02,  3.1864e-01,  ...,  3.4951e+02,
          3.5013e+02,  1.6149e+00],
        [ 2.5640e+02,  3.6415e+02, -6.5350e-02,  ...,  3.0688e+02,
          2.0930e+02,  1.6595e+00],
        [ 2.5594e+02,  3.6413e+02, -1.2376e-01,  ...,  1.0120e+02,
          1.9702e+02,  1.7493e+00],
        ...,
        [ 2.5558e+02,  3.6382e+02,  5.6322e-01,  ...,  7.9361e+01,
          1.9010e+02,  1.6475e+00],
        [ 2.5633e+02,  3.6501e+02,  2.9718e-01,  ...,  3.3223e+02,
          1.0070e+02,  1.6575e+00],
        [ 2.5602e+02,  3.6451e+02, -2.6216e-01,  ...,  2.8015e+02,
          1.5541e+02,  1.2799e+00]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[258.9789, 368.2505,   1.7553,  ...,  58.8439, 300.1513,   0.7894],
        [260.2629, 369.4702,   1.7376,  ..., 185.7834, 135.0440,   0.7654],
        [259.7127, 369.4449,   1.5726,  ..., 408.6388, 298.5427,   0.6714],
        ...,
 

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[260.1415, 369.8408,   2.2685,  ..., 392.9762, 180.0739,   2.2023],
        [259.3195, 369.6888,   1.9833,  ..., 421.9723, 169.6595,   2.0301],
        [260.5998, 370.2542,   2.5422,  ..., 368.3015, 268.5684,   2.4948],
        ...,
        [260.0083, 370.1651,   2.0813,  ..., 394.3234, 134.7930,   2.0201],
        [259.5400, 369.4180,   1.6010,  ..., 256.7395, 159.5098,   2.2997],
        [261.2890, 371.1322,   1.5329,  ...,  56.3326, 257.6355,   2.5313]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[ 2.5570e+02,  3.6406e+02,  1.5819e+00,  ...,  4.5136e+02,
          3.0467e+02, -1.0002e+00],
        [ 2.5580e+02,  3.6384e+02,  6.0830e-01,  ...,  3.5539e+02,
          1.4979e+02, -9.4716e-01],
        [ 2.5592e+02,  3.6392e+02,  6.4548e-01,  ...,  3.5975e+02,
          1.5672e+02, -1.0194e+00],
        ...,
        [ 2.5515e+02,  3.6280e+02,  1.0234e+00,  ...,  3.4355e+02,
        

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[ 2.5847e+02,  3.6706e+02,  1.0106e-01,  ...,  1.9590e+02,
          1.1334e+02,  6.1823e-01],
        [ 2.5857e+02,  3.6787e+02,  6.6667e-02,  ...,  3.8114e+02,
          1.4445e+02,  1.4074e+00],
        [ 2.5828e+02,  3.6643e+02,  6.0127e-01,  ...,  2.3791e+02,
          1.3121e+02,  6.7607e-01],
        ...,
        [ 2.5879e+02,  3.6768e+02,  3.7975e-01,  ...,  1.6007e+02,
          1.0991e+02,  1.0044e+00],
        [ 2.5819e+02,  3.6762e+02,  6.0219e-02,  ...,  3.9185e+02,
          1.3030e+02,  1.1856e+00],
        [ 2.5885e+02,  3.6794e+02, -7.6429e-02,  ...,  4.1429e+02,
          2.0362e+02,  6.2536e-01]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[258.3818, 366.5862,   2.2332,  ..., 317.4594, 303.5574,   1.4788],
        [257.5349, 366.4627,   2.3172,  ..., 339.6259, 337.5202,   1.5292],
        [257.6625, 366.3988,   2.2799,  ..., 385.0356, 344.9081,   1.5591],
        ...,
 

output tensor([[ 2.4963e+02,  3.5661e+02, -9.9041e-01,  ...,  2.7098e+02,
          1.0876e+02,  1.6926e-01],
        [ 2.5214e+02,  3.5918e+02,  5.9971e-03,  ...,  3.4461e+02,
          2.1888e+02,  8.0693e-01],
        [ 2.5165e+02,  3.5931e+02, -3.6377e-01,  ...,  1.0646e+02,
          2.0635e+02,  5.2569e-01],
        ...,
        [ 2.5031e+02,  3.5722e+02, -7.1262e-01,  ...,  2.3313e+02,
          1.1179e+02,  5.2678e-01],
        [ 2.5002e+02,  3.5694e+02, -3.9348e-01,  ...,  2.7299e+02,
          1.0037e+02,  3.8051e-01],
        [ 2.5343e+02,  3.6131e+02, -2.6925e-01,  ...,  5.3602e+01,
          3.0091e+02,  4.0941e-01]], grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[266.4313, 378.5429,   2.3285,  ...,  79.5474, 201.4805,   1.1092],
        [265.4235, 377.4096,   1.5236,  ..., 286.3219, 165.8961,   0.9546],
        [264.4542, 376.0184,   2.5211,  ..., 466.4637, 266.4492,   0.7082],
        ...,
        [265.4211, 376.9419,   1.4581,  ...,

output tensor([[257.5973, 367.0148,   0.6739,  ..., 292.5225, 234.8377,   0.8915],
        [257.0556, 366.7437,   0.8122,  ..., 230.0999, 104.5219,   0.7428],
        [258.5092, 368.1083,   1.0662,  ..., 308.6748, 252.6377,   0.5520],
        ...,
        [257.7826, 367.1085,   0.9164,  ...,  93.8220, 157.7689,   0.5662],
        [256.9153, 366.7531,   0.7873,  ..., 244.4329, 103.0209,   0.7349],
        [257.4084, 366.6411,   0.7999,  ..., 167.1524, 115.2486,   0.8158]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[258.7241, 368.9793,   0.9992,  ..., 291.6711,  99.2633,   0.8553],
        [259.9311, 368.8837,   0.9639,  ...,  78.4158, 288.3798,   0.7242],
        [259.6653, 369.3020,   1.5117,  ..., 382.1628, 344.9129,   0.9130],
        ...,
        [258.6785, 368.7648,   1.4252,  ..., 284.0938,  95.8711,   0.8458],
        [259.3817, 369.4249,   1.2117,  ..., 435.4612, 247.0018,   0.9558],
        [259.8397, 369.4427,   1.2951,  ..., 38

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[257.1927, 365.4221,   0.6176,  ..., 379.5749, 345.0918,   0.9764],
        [257.3721, 365.6395,   1.0667,  ..., 400.2399, 189.5462,   1.2783],
        [257.0952, 365.6794,   0.9112,  ..., 282.5886, 172.3105,   1.0382],
        ...,
        [257.5035, 365.5558,   1.2646,  ..., 349.0352, 189.0096,   1.1143],
        [257.0039, 365.7343,   1.2593,  ..., 322.0249, 157.5295,   1.1571],
        [256.8873, 365.8492,   0.6649,  ..., 418.4243, 238.5150,   0.7659]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[258.5067, 367.2408,   0.8417,  ..., 411.2480, 350.7558,   1.0670],
        [258.7168, 368.1088,   1.0543,  ..., 300.5524, 241.0503,   0.9644],
        [258.7688, 368.0066,   0.9923,  ..., 388.6164, 128.6055,   1.1641],
        ...,
        [259.2720, 367.9824,   0.8615,  ..., 197.4364, 114.5279,   0.9317],
        [258.5329, 368.0403,   1.4321,  ..., 319.5089, 164.7003,   1.0837],
    

torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[255.8227, 363.5770,   0.8866,  ..., 179.1214, 122.6226,   0.8574],
        [256.2375, 365.1920,   0.8870,  ..., 429.3319, 245.1533,   0.5410],
        [256.2411, 364.1041,   1.1965,  ..., 288.0987,  82.5220,   0.5527],
        ...,
        [255.6701, 364.2141,   0.6960,  ..., 278.0697, 105.0202,   0.6959],
        [255.6635, 363.7432,   0.9005,  ..., 182.5868, 113.0598,   0.7685],
        [255.3340, 363.2931,   1.2051,  ..., 360.2234, 258.9237,   0.5381]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[260.2127, 369.6142,   0.9690,  ..., 353.6959, 322.6679,   1.3338],
        [261.2409, 371.0004,   1.3186,  ..., 453.7888, 241.6299,   1.3326],
        [259.7126, 369.6455,   0.8346,  ..., 223.7558, 108.3927,   1.2184],
        ...,
        [259.8606, 369.7755,   1.0421,  ..., 164.1664, 117.5684,   1.2965],
        [261.0792, 371.1751,   0.8754,  ..., 442.0934, 239.8562,   1.2118],
    

output tensor([[258.8342, 368.0084,   1.2386,  ..., 265.5524, 205.2001,   0.7043],
        [258.4431, 367.8278,   0.9590,  ..., 293.3424, 167.0519,   0.9623],
        [259.4543, 368.9725,   1.5803,  ..., 411.5758, 352.5380,   1.1906],
        ...,
        [258.8702, 368.1381,   0.9444,  ..., 334.2533, 261.1833,   1.0491],
        [258.6695, 367.8842,   1.2478,  ..., 353.6036, 224.3502,   1.4366],
        [258.1398, 367.6743,   1.5604,  ..., 346.0152, 351.2582,   1.0898]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 18])
torch.Size([128, 3])
output tensor([[ 2.5730e+02,  3.6643e+02, -1.6627e-01,  ...,  3.1948e+02,
          2.5787e+02,  6.4230e-01],
        [ 2.5722e+02,  3.6601e+02, -2.7620e-01,  ...,  3.0373e+02,
          2.0448e+02,  8.9320e-01],
        [ 2.5690e+02,  3.6544e+02, -2.4043e-01,  ...,  3.3145e+02,
          3.0509e+02,  9.0115e-01],
        ...,
        [ 2.5711e+02,  3.6636e+02, -1.4310e-01,  ...,  2.8987e+02,
          2.3118e+02,  9.1475e-01],
        [ 2.570

In [76]:
def test_model(model_path, test_data_dir):
    # Load the test dataset
    test_dataset = KpVelDataset(test_data_dir)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    # Initialize the model and load the saved state
    model = KeypointRegressionNet()
    model.load_state_dict(torch.load(model_path))
    model.eval()

    # Criterion for evaluation
    criterion = nn.MSELoss()
    total_loss = 0

    # No gradient needed for evaluation
    with torch.no_grad():
        for start_kp, next_kp, velocity in test_loader:
            output = model(start_kp, velocity)
            for i in range(start_kp.size(0)):
                individual_start_kp = start_kp[i]
                individual_next_kp = next_kp[i]
                individual_velocity = velocity[i]
                predicted_next_kp = output[i]

                print("Start KP:", individual_start_kp)
                print("Next KP:", individual_next_kp)
                print("Actual Velocity:", individual_velocity)
                print("Predicted Velocity:", predicted_next_kp)
                print("-----------------------------------------")
            loss = criterion(output, next_kp)
            total_loss += loss.item()
            
    # Calculate the average loss
    avg_loss = total_loss / len(test_loader)
    print(f'Average Test Loss: {avg_loss}')

# Usage
model_path = '/home/jc-merlab/Pictures/Data/trained_models/reg_nkp_b128_e250_v0.pth'  # Update with your model path
test_data_dir = '/home/jc-merlab/Pictures/panda_data/panda_sim_vel/vel_reg_sim_test/split_folder_reg/test/annotations/'  # Update with your test data path
test_model(model_path, test_data_dir)

Start KP: tensor([257.9522, 366.9199,   1.0000, 257.9597, 283.0131,   1.0000, 179.8257,
        298.1946,   1.0000, 175.8687, 277.7923,   1.0000, 240.1268, 202.5661,
          1.0000, 262.0790, 205.0822,   1.0000])
Next KP: tensor([257.9522, 366.9199,   1.0000, 257.9597, 283.0131,   1.0000, 179.8260,
        298.1959,   1.0000, 175.8685, 277.7937,   1.0000, 240.1280, 202.5687,
          1.0000, 262.0802, 205.0853,   1.0000])
Actual Velocity: tensor([0.0000, 0.0000, 0.3000])
Predicted Velocity: tensor([257.7654, 366.7654,   0.9178, 257.9145, 283.0347,   1.1408, 179.7099,
        298.4261,   1.0625, 175.4932, 277.9028,   0.8915, 240.3632, 202.5495,
          1.0117, 261.5081, 204.5572,   1.1235])
-----------------------------------------
Start KP: tensor([257.9522, 366.9199,   1.0000, 257.9597, 283.0131,   1.0000, 179.8213,
        298.1725,   1.0000, 175.8698, 277.7691,   1.0000, 240.1633, 202.5730,
          1.0000, 262.1147, 205.0952,   1.0000])
Next KP: tensor([257.9522, 366.9199,   

Next KP: tensor([257.9522, 366.9199,   1.0000, 257.9597, 283.0131,   1.0000, 303.7598,
        217.8956,   1.0000, 320.7623, 229.8525,   1.0000, 419.9565, 231.2956,
          1.0000, 442.5465, 231.8152,   1.0000])
Actual Velocity: tensor([0.0000, 0.0893, 0.0000])
Predicted Velocity: tensor([258.3449, 367.1544,   0.7209, 258.5476, 283.7431,   1.3077, 303.9985,
        219.0668,   1.1029, 320.3009, 230.5019,   0.9057, 420.4909, 232.1425,
          1.0414, 442.1325, 232.9416,   1.1690])
-----------------------------------------
Start KP: tensor([257.9522, 366.9199,   1.0000, 257.9597, 283.0131,   1.0000, 303.7598,
        217.8956,   1.0000, 320.7623, 229.8525,   1.0000, 419.9565, 231.2956,
          1.0000, 442.5465, 231.8152,   1.0000])
Next KP: tensor([257.9522, 366.9199,   1.0000, 257.9597, 283.0131,   1.0000, 303.7659,
        217.8999,   1.0000, 320.7673, 229.8584,   1.0000, 419.9698, 230.4328,
          1.0000, 442.5634, 230.7548,   1.0000])
Actual Velocity: tensor([0.0000, 0.0893,