In [48]:
import torch
from torch.utils.data import Dataset
import json
import os
import torch.nn as nn
import torch.nn.functional as func
from torch.utils.data import DataLoader
import torch.optim as optim
from os.path import expanduser
import splitfolders
import shutil
import glob
import numpy as np
from sklearn.model_selection import train_test_split

In [49]:
class KpVelDataset(Dataset):
    def __init__(self, json_folder):
        super(KpVelDataset, self).__init__()
        self.data = []
        for json_file in sorted(os.listdir(json_folder)):
            if json_file.endswith('_combined.json'):
                with open(os.path.join(json_folder, json_file), 'r') as file:
                    data = json.load(file)
                    start_kp = data['start_kp']
                    next_kp = data['next_kp']
                    velocity = data['velocity']
                    self.data.append((start_kp, next_kp, velocity))

    def __len__(self):
        return len(self.data)

#     def __getitem__(self, idx):
#         start_kp, next_kp, velocity = self.data[idx]
#         start_kp_flat = torch.tensor([kp for sublist in start_kp for kp in sublist[0]])
#         next_kp_flat = torch.tensor([kp for sublist in next_kp for kp in sublist[0]])
#         velocity = torch.tensor(velocity)
#         return start_kp_flat, next_kp_flat, velocity
  
    def __getitem__(self, idx):
        start_kp, next_kp, velocity = self.data[idx]
        # Extract and flatten the first two elements of each keypoint in start_kp
        start_kp_flat = torch.tensor([kp for sublist in start_kp for kp in sublist[0][:2]], dtype=torch.float)
        # Extract and flatten the first two elements of each keypoint in next_kp
        next_kp_flat = torch.tensor([kp for sublist in next_kp for kp in sublist[0][:2]], dtype=torch.float)
        velocity = torch.tensor(velocity)
        return start_kp_flat, next_kp_flat, velocity
    

In [50]:
def train_test_split(src_dir):
#     dst_dir_img = src_dir + "images"
    dst_dir_anno = src_dir + "annotations"
    
    if os.path.exists(dst_dir_anno):
        print("folders exist")
    else:
        os.mkdir(dst_dir_anno)
        
#     for jpgfile in glob.iglob(os.path.join(src_dir, "*.jpg")):
#         shutil.copy(jpgfile, dst_dir_img)

    for jsonfile in glob.iglob(os.path.join(src_dir, "*_combined.json")):
        shutil.copy(jsonfile, dst_dir_anno)
        
    output = root_dir + "split_folder_reg"
    
    splitfolders.ratio(src_dir, # The location of dataset
                   output=output, # The output location
                   seed=42, # The number of seed
                   ratio=(0.8, 0.1, 0.1), # The ratio of split dataset
                   group_prefix=None, # If your dataset contains more than one file like ".jpg", ".pdf", etc
                   move=False # If you choose to move, turn this into True
                   )
    
#     shutil.rmtree(dst_dir_img)
    shutil.rmtree(dst_dir_anno)
    
    return output  

In [51]:
# class VelRegModel(nn.Module):
#     def __init__(self, input_size):
#         super(VelRegModel, self).__init__()
#         self.fc1 = nn.Linear(input_size * 2, 1024)  # Assuming start_kp and next_kp are concatenated
#         self.fc2 = nn.Linear(1024, 512)
#         self.fc3 = nn.Linear(512,256)
#         self.fc4 = nn.Linear(256,128)
#         self.fc5 = nn.Linear(128,64)
#         self.fc6 = nn.Linear(64,64)
#         self.fc7 = nn.Linear(64,6)  # Output size is 3 for velocity

#     def forward(self, start_kp, velocity):
#         x = torch.cat((start_kp, velocity), dim=1)
#         x = func.relu(self.fc1(x))
#         x = func.relu(self.fc2(x))
#         x = func.relu(self.fc3(x))
#         x = func.relu(self.fc4(x))
#         x = func.relu(self.fc5(x))
#         x = func.relu(self.fc6(x))
#         x = self.fc7(x)
#         return x

In [52]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class KeypointRegressionNet(nn.Module):
    def __init__(self):
        super(KeypointRegressionNet, self).__init__()
        # Define the architecture
        self.fc1 = nn.Linear(15, 128)  # 18 keypoints + 3 velocity values
        self.fc2 = nn.Linear(128, 256)
        self.fc3 = nn.Linear(256, 256)
        
        self.fc4 = nn.Linear(128, 12)  # Output size is 18 (6 keypoints * 3 values each)

    def forward(self, start_kp, velocity):
        # Flatten start keypoints and concatenate with velocity
        x = torch.cat((start_kp, velocity), dim=1)
        
        # Forward pass through the network
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.fc4(x)  # No activation function in the last layer
        return x

In [53]:
# Initialize dataset and data loader
# to generalize home directory. User can change their parent path without entering their home directory
num_epochs = 300
batch_size = 128
v = 1
root_dir = '/home/jc-merlab/Pictures/panda_data/panda_sim_vel/vel_reg_sim_test_2/'
print(root_dir)
split_folder_path = train_test_split(root_dir)
dataset = KpVelDataset(root_dir)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Initialize model
model = KeypointRegressionNet()  # Adjust input_size as necessary
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.005)

# Training loop
for epoch in range(num_epochs):
    for start_kp, next_kp, velocity in data_loader:
        optimizer.zero_grad()
        velocity = velocity.squeeze(1)
        print(start_kp.shape)
        print(velocity.shape)
        output = model(start_kp, velocity)
        loss = criterion(output, next_kp)
        loss.backward()
        optimizer.step()
        print("output", output)
    print(f'Epoch {epoch+1}, Loss: {loss.item()}')
    
# # Save the trained model
model_save_path = f'/home/jc-merlab/Pictures/Data/trained_models/reg_nkp_b{batch_size}_e{num_epochs}_v{v}.pth'
torch.save(model.state_dict(), model_save_path)

# model_save_path = f'/home/jc-merlab/Pictures/Data/trained_models/reg_nkp_b{batch_size}_e{num_epochs}_v{v}.pth'
# torch.save({
#     'model_state_dict': model.state_dict(),
#     'model_structure': KeypointRegressionNet()
# }, model_save_path)
# print(f"Model saved to {model_save_path}")


/home/jc-merlab/Pictures/panda_data/panda_sim_vel/vel_reg_sim_test_2/


Copying files: 3137 files [00:00, 17605.58 files/s]


torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[  0.3755, -11.9200,  -8.2594,  ...,  -3.4368,  -3.8442,  -6.5940],
        [  0.4976, -11.7770,  -8.5405,  ...,  -3.3737,  -3.9315,  -6.9842],
        [ -0.3671, -12.3275,  -7.5591,  ...,  -3.4337,  -3.6178,  -5.8184],
        ...,
        [ -0.7251, -11.2040,  -7.8257,  ...,  -6.0839,  -4.7046,  -4.3595],
        [  0.0161, -12.0585,  -8.0321,  ...,  -3.8734,  -3.8627,  -7.9975],
        [ -0.4673, -11.2488,  -8.1413,  ...,  -3.6001,  -4.5715,  -5.2706]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[42.6132, 62.1723,  7.4779,  ..., 14.5058, 23.6310, 51.8031],
        [43.3016, 63.9957,  7.3664,  ..., 14.9711, 24.2766, 53.3288],
        [38.0389, 53.6992,  6.5125,  ..., 13.6899, 18.1236, 40.7598],
        ...,
        [49.6824, 73.6856,  7.3889,  ..., 16.4269, 28.5839, 60.1466],
        [38.0840, 54.2199,  6.7092,  ..., 13.8676, 18.2255, 41.8668],
        [36.7727, 49.4853,  5.2458

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[242.2252, 348.3666, 248.0428,  ..., 215.0430, 292.4395, 217.6120],
        [242.1182, 351.1022, 251.7318,  ..., 242.7362, 355.0414, 251.2371],
        [290.0117, 397.5147, 276.1425,  ..., 180.0630, 120.4133, 155.4583],
        ...,
        [302.7943, 407.2541, 279.7990,  ..., 187.0379,  99.0204, 157.6131],
        [247.0711, 357.4875, 256.0927,  ..., 246.3974, 355.1330, 254.4703],
        [237.1667, 348.5512, 251.5604,  ..., 230.8494, 363.6476, 239.4995]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[248.5436, 358.5568, 256.9223,  ..., 261.6209, 409.5516, 272.3589],
        [258.4626, 364.6862, 256.5467,  ..., 199.7201, 219.5354, 190.8939],
        [235.2348, 341.5971, 245.0789,  ..., 235.8137, 370.1100, 245.0183],
        ...,
        [282.4911, 392.4300, 274.2609,  ..., 192.8797, 146.1760, 171.7191],
        [247.7747, 350.7218, 247.1131,  ..., 192.5604, 228.6642, 185.3015],
    

output tensor([[249.1085, 349.1666, 247.2515, 273.5283, 238.8030, 224.6090, 247.9228,
         222.2246, 296.0775, 227.4424, 300.5702, 234.4463],
        [264.8458, 376.4427, 267.6223, 294.7419, 195.0018, 264.0314, 199.5981,
         244.7866, 188.3624, 182.3026, 192.2845, 173.8434],
        [251.1154, 358.4316, 252.8081, 276.9332, 299.1253, 222.1866, 313.8466,
         236.9031, 382.9241, 304.9104, 376.9926, 317.0641],
        [267.4573, 377.1215, 268.7354, 295.7281, 329.1194, 231.1313, 342.9965,
         248.2451, 438.8289, 298.6404, 439.5951, 308.5915],
        [262.1222, 371.3096, 260.8711, 287.9832, 176.8117, 284.2438, 174.7639,
         264.3107,  95.6527, 229.9826,  91.3046, 216.0879],
        [271.4735, 377.6129, 272.2938, 300.1266, 322.6065, 230.1946, 338.0417,
         244.6438, 443.9833, 254.4871, 453.3908, 262.7549],
        [253.7027, 369.1898, 259.3585, 280.1800, 229.4936, 215.2393, 247.9808,
         205.6098, 306.1152, 102.4073, 324.8099,  97.1747],
        [254.1222, 3

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[249.2639, 350.8383, 247.6404,  ..., 200.5734, 105.8165, 189.1349],
        [246.7098, 354.4009, 249.3993,  ..., 121.8178, 206.5187, 109.9985],
        [249.7047, 354.0319, 248.8202,  ..., 187.5808, 338.6522, 198.1535],
        ...,
        [258.8914, 368.0422, 259.9338,  ..., 166.5343, 269.5525, 167.6140],
        [261.1645, 371.3591, 264.5874,  ..., 186.4890, 187.3151, 179.5352],
        [249.8822, 355.9760, 252.6067,  ..., 285.4367, 408.4482, 295.9972]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[251.5036, 357.3686, 247.2783,  ..., 294.1547, 318.2641, 308.2112],
        [259.0247, 369.1543, 257.1609,  ..., 131.1836, 320.3207, 133.6708],
        [258.9350, 368.3016, 256.7776,  ..., 336.8931, 350.3697, 350.4725],
        ...,
        [261.9595, 367.4139, 254.3482,  ..., 266.7645,  60.9199, 252.9127],
        [261.4332, 371.9187, 260.7857,  ..., 341.2699, 389.0773, 353.7045],
    

output tensor([[256.6068, 363.2316, 254.2694,  ..., 310.9268, 355.6081, 322.4882],
        [262.4831, 369.0544, 259.8540,  ..., 179.3244, 424.2823, 181.5302],
        [256.0040, 359.1300, 250.5959,  ..., 201.6574,  96.1444, 187.6269],
        ...,
        [257.6031, 364.4344, 255.1505,  ..., 328.8592, 316.3397, 340.7052],
        [253.9845, 360.0057, 252.0991,  ..., 122.2091, 182.1158, 108.8460],
        [258.7167, 366.4442, 256.4585,  ..., 214.3515, 303.3702, 225.6969]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[262.4938, 370.8883, 259.4682,  ..., 332.3422, 307.9131, 343.2874],
        [261.6977, 370.5793, 259.4809,  ..., 338.6389, 310.0034, 349.7570],
        [261.2149, 371.3696, 260.7692,  ..., 337.4957, 350.3370, 348.1866],
        ...,
        [258.0191, 363.9775, 255.4055,  ..., 160.0340, 119.3142, 144.2405],
        [256.5450, 363.8502, 252.2046,  ..., 228.6064, 311.7354, 240.8479],
        [260.4936, 370.8445, 260.4380,  ..., 14

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[257.0791, 365.3646, 257.4645,  ..., 138.5535, 165.2991, 124.4852],
        [254.1292, 363.0428, 253.5671,  ...,  99.4760, 270.5221,  88.1974],
        [254.2369, 361.7661, 253.8410,  ..., 150.5046, 373.7028, 154.6413],
        ...,
        [251.9014, 361.7364, 256.4373,  ..., 306.7720,  59.3596, 295.8321],
        [256.0569, 361.5723, 250.5466,  ..., 244.1437, 313.8287, 258.5734],
        [255.0271, 363.6827, 254.0978,  ..., 106.1993, 256.0910,  95.2026]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[254.3063, 362.7182, 253.9695,  ..., 109.7622, 240.1820, 102.1883],
        [250.8235, 360.0796, 253.2804,  ..., 281.7433,  64.3199, 272.6808],
        [254.7789, 364.3188, 258.5166,  ..., 183.8709, 209.8152, 183.6755],
        ...,
        [255.9535, 363.3503, 254.9263,  ..., 303.6591, 368.9651, 320.0607],
        [255.4434, 364.4418, 254.4148,  ..., 182.2429, 349.7440, 196.4695],
    

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[259.4766, 370.1412, 260.1985,  ..., 335.6784, 342.8237, 342.3598],
        [260.1730, 372.5661, 262.3620,  ..., 330.6819, 428.4984, 343.7186],
        [258.1481, 367.7046, 258.1349,  ..., 133.1514, 221.5441, 121.1842],
        ...,
        [260.7665, 370.9895, 259.3130,  ..., 280.0144, 417.7960, 291.0957],
        [258.3730, 369.2744, 257.2909,  ..., 270.7518,  63.2497, 254.7443],
        [260.5560, 371.0549, 260.0359,  ..., 118.0059, 179.4380,  95.3047]],
       grad_fn=<AddmmBackward0>)
torch.Size([60, 12])
torch.Size([60, 3])
output tensor([[254.7923, 363.6296, 255.7080, 280.4352, 180.2256, 294.9256, 176.9213,
         274.1998, 239.8174, 204.3655, 261.8880, 205.1636],
        [254.1335, 363.5270, 255.9673, 280.5142, 328.2687, 236.6669, 340.7180,
         255.7408, 390.7382, 335.1505, 382.2364, 344.2922],
        [255.5717, 363.7478, 255.3483, 281.0152, 187.4907, 255.2921, 194.1183,
         237.0331, 150.0484, 153.9144, 158

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[259.0748, 369.0098, 258.7049,  ..., 133.6845, 179.3872, 116.1591],
        [259.0052, 369.8235, 259.0533,  ..., 235.7048, 450.8346, 242.5883],
        [258.7891, 368.7589, 258.2835,  ..., 232.6577, 433.7019, 242.8165],
        ...,
        [260.0306, 368.9183, 259.0282,  ..., 183.9244, 169.8907, 172.3336],
        [258.6262, 368.6588, 258.9576,  ..., 109.4952, 285.1801, 101.9807],
        [260.2863, 369.1838, 259.2577,  ..., 182.2937, 162.0749, 169.4553]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[254.7766, 362.8271, 255.4199,  ..., 318.7397, 359.2347, 329.2484],
        [253.9430, 362.4658, 254.5333,  ..., 150.9467, 387.7018, 145.2263],
        [255.2337, 363.6225, 255.2802,  ..., 112.4326, 266.6097,  97.0579],
        ...,
        [255.4455, 364.7559, 255.5423,  ..., 262.6375, 425.2909, 274.5038],
        [255.3362, 363.7917, 255.3796,  ..., 112.4887, 263.9888,  96.9187],
    

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[256.9600, 365.4460, 255.9917,  ..., 150.2156, 163.0087, 137.1682],
        [257.3391, 365.0958, 255.7842,  ..., 173.7440, 101.8540, 156.3984],
        [256.5040, 364.5661, 255.6709,  ..., 332.4654, 298.7607, 339.2791],
        ...,
        [256.1731, 364.8272, 255.9390,  ..., 165.3325, 284.2950, 168.6584],
        [256.4281, 364.5120, 255.3641,  ..., 323.9137, 291.2318, 330.5125],
        [256.4452, 364.1048, 254.7945,  ..., 207.6197,  83.6792, 193.0838]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[257.1503, 367.7844, 258.8640,  ..., 139.3092, 379.6095, 142.5743],
        [255.8413, 365.4936, 257.3290,  ..., 168.7652, 332.2924, 177.1920],
        [256.8453, 366.9147, 257.6855,  ..., 105.8774, 281.5856,  93.0388],
        ...,
        [256.7793, 367.1684, 258.1206,  ..., 146.2077, 400.8308, 143.9654],
        [257.2680, 368.2779, 258.3543,  ..., 241.4614, 435.1505, 255.5058],
    

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[257.6383, 367.4261, 257.4505,  ..., 262.2296, 428.1222, 275.6997],
        [257.4769, 366.7216, 257.6370,  ..., 251.8998,  82.4918, 236.4633],
        [257.5454, 365.4539, 256.8830,  ..., 209.8730,  83.4595, 192.5488],
        ...,
        [256.4825, 365.2986, 256.7034,  ..., 198.2555, 421.9082, 203.3310],
        [257.6319, 366.6152, 257.6299,  ..., 143.9749, 181.6976, 129.8094],
        [257.5052, 366.6898, 257.6417,  ..., 249.9838,  83.6004, 234.4733]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[257.1196, 366.3521, 257.0883,  ..., 231.6091, 433.5036, 241.2742],
        [255.5750, 367.0249, 257.8499,  ..., 312.0321,  60.6879, 299.4840],
        [257.0616, 366.5037, 257.5569,  ..., 147.5087, 378.4395, 149.2398],
        ...,
        [255.5738, 367.2128, 258.5874,  ..., 317.9632,  62.5726, 306.3929],
        [257.8759, 367.4035, 257.7920,  ..., 132.3995, 155.9480, 109.0659],
    

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[258.7798, 367.8105, 258.6640,  ..., 333.2585, 322.4397, 345.7595],
        [258.7471, 367.8269, 258.3906,  ..., 325.3025, 309.7793, 337.4215],
        [258.0272, 367.7198, 257.9962,  ..., 236.2992, 440.5952, 247.1254],
        ...,
        [259.0271, 368.9105, 259.2979,  ..., 175.7608, 427.5373, 183.7506],
        [259.3325, 369.6820, 258.7655,  ..., 278.2669, 347.2464, 297.7631],
        [259.3844, 369.5410, 259.2317,  ..., 219.4220, 437.6930, 233.6229]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[255.4146, 363.4434, 254.8247,  ..., 186.1520, 330.6433, 197.1009],
        [256.3434, 365.0671, 255.4141,  ..., 130.4256, 154.4663, 111.2507],
        [255.8377, 364.3008, 255.2356,  ..., 167.5777, 285.7129, 172.3387],
        ...,
        [256.4334, 364.5635, 255.0167,  ..., 180.0184, 209.0523, 175.3187],
        [255.8491, 364.4164, 255.2454,  ..., 113.5438, 251.6769, 105.4334],
    

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[256.2166, 364.8384, 256.2404,  ..., 202.8007, 384.7673, 214.3556],
        [256.7267, 365.2932, 256.2914,  ..., 316.4549, 299.8331, 324.3442],
        [257.0667, 366.3000, 257.3740,  ..., 152.5508, 407.2841, 150.8774],
        ...,
        [256.4672, 364.8547, 256.4280,  ..., 213.1028, 305.4630, 225.8791],
        [256.4590, 365.0587, 256.6815,  ..., 169.6908, 281.8936, 173.6862],
        [256.8900, 365.4120, 257.4709,  ..., 335.2949, 370.4770, 348.6248]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[257.2511, 365.6261, 257.5433,  ..., 336.8210, 361.8316, 348.8974],
        [257.4364, 366.4663, 257.4078,  ..., 104.3190, 278.2737,  90.2265],
        [257.0253, 365.9872, 257.0483,  ..., 139.4548, 384.6352, 133.4945],
        ...,
        [257.7207, 366.6738, 257.4490,  ..., 228.8762, 434.8210, 240.5735],
        [258.0740, 365.9663, 257.5179,  ..., 199.7525, 107.0399, 183.8772],
    

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[254.8678, 361.9783, 254.3732,  ..., 201.5670, 302.4786, 213.1671],
        [255.3512, 361.7755, 254.2611,  ..., 211.1870,  79.3921, 196.9922],
        [255.3827, 363.0307, 255.0207,  ..., 110.7659, 226.3662,  99.1738],
        ...,
        [254.8917, 362.1411, 254.5708,  ..., 163.2917, 281.6051, 167.1861],
        [254.8632, 362.4613, 254.3902,  ..., 271.6312,  58.8381, 259.4636],
        [254.0195, 360.9968, 252.9012,  ..., 296.6585, 436.2584, 311.5945]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[260.6274, 371.6256, 260.0847,  ..., 283.0472, 456.4949, 297.5781],
        [259.8580, 369.9868, 259.7547,  ..., 109.6819, 190.8413,  91.8381],
        [260.3904, 371.3313, 259.9892,  ..., 265.5695, 458.8594, 278.2533],
        ...,
        [258.8278, 368.4424, 258.2993,  ..., 225.7726, 376.3486, 239.5845],
        [258.7076, 369.7082, 259.3766,  ..., 295.9332,  60.5514, 286.0309],
    

output tensor([[254.4239, 361.6774, 254.0340,  ..., 316.9550, 296.6470, 321.8979],
        [253.9396, 361.4204, 254.1225,  ..., 158.5713, 388.9030, 156.8971],
        [255.6274, 364.1365, 255.8591,  ..., 129.4718, 161.1388, 110.3434],
        ...,
        [254.8374, 362.7931, 255.0522,  ..., 132.3455, 221.4985, 123.7461],
        [255.9520, 363.4746, 255.7964,  ..., 167.3494, 108.0948, 148.4871],
        [254.0999, 361.2763, 254.1466,  ..., 314.6866, 355.7349, 325.4147]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[259.5257, 368.2036, 259.1942,  ..., 213.3775, 105.1185, 197.2375],
        [259.5061, 368.2982, 259.7863,  ..., 334.3167, 413.8554, 347.2735],
        [259.1801, 368.7441, 258.9572,  ..., 250.5329, 376.0029, 265.3359],
        ...,
        [258.5719, 367.2690, 258.3839,  ..., 241.1523, 319.3150, 255.1825],
        [258.8404, 367.7602, 258.9364,  ..., 204.7598, 307.9644, 215.6915],
        [257.3318, 367.0929, 258.1016,  ..., 29

output tensor([[254.8366, 362.0232, 254.0938,  ..., 194.3166, 334.0092, 203.1536],
        [255.3942, 362.9052, 254.4975,  ..., 193.7029, 248.5466, 191.8302],
        [255.4119, 363.2346, 254.8862,  ..., 119.7258, 281.7826, 113.4098],
        ...,
        [254.8857, 362.3242, 254.2234,  ..., 122.3282, 337.0623, 116.7341],
        [255.7796, 364.1094, 255.1310,  ..., 128.8694, 164.6783, 108.0089],
        [255.3859, 363.4631, 254.5600,  ..., 233.9084, 450.4908, 237.7377]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[259.7575, 369.3863, 258.6089,  ..., 252.2100, 361.0062, 265.4130],
        [259.2276, 368.9902, 258.3273,  ..., 123.7592, 197.7174, 107.4727],
        [259.6853, 369.4268, 258.9244,  ..., 209.7219, 424.1386, 217.2253],
        ...,
        [259.9378, 369.5214, 258.1939,  ..., 289.2120, 315.8208, 300.8274],
        [259.4369, 369.5347, 258.9314,  ..., 112.1598, 187.9242,  93.0751],
        [258.9240, 368.3835, 258.1921,  ..., 13

output tensor([[253.7049, 361.9940, 253.5945,  ..., 150.5959, 385.9585, 146.1108],
        [254.0791, 362.2520, 254.2710,  ..., 151.9177, 317.8248, 155.2503],
        [255.2225, 363.3239, 254.8905,  ..., 166.2888, 112.9877, 146.8133],
        ...,
        [253.5980, 361.6365, 253.6754,  ..., 157.1539, 367.2470, 160.4843],
        [253.3631, 361.4779, 253.2556,  ..., 209.7817, 417.7320, 216.9182],
        [253.5937, 361.3380, 253.7167,  ..., 334.0062, 379.7169, 345.9906]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[260.2152, 371.5834, 261.0259,  ..., 191.8724, 426.8713, 200.5718],
        [260.3331, 371.0619, 260.7015,  ..., 182.7581, 218.0875, 179.7061],
        [260.9771, 372.7293, 261.6698,  ..., 231.1277, 343.6003, 248.8038],
        ...,
        [260.8098, 372.6114, 261.1794,  ..., 303.2551, 439.3820, 319.5067],
        [261.2220, 372.6129, 262.1038,  ..., 333.0642, 373.2875, 347.1033],
        [260.2368, 371.7197, 261.0389,  ..., 19

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[257.3112, 365.6286, 256.7334,  ..., 184.0184, 176.5731, 173.4145],
        [256.5062, 365.7457, 257.0121,  ..., 279.8177, 399.5845, 293.4849],
        [256.7957, 365.7703, 256.7768,  ..., 113.5441, 241.9801,  98.4665],
        ...,
        [256.8849, 366.4595, 257.4114,  ..., 108.7927, 199.5215,  89.6123],
        [256.3003, 365.4761, 257.1784,  ..., 193.1194, 418.9627, 194.5506],
        [257.9347, 367.1704, 258.7151,  ..., 335.1461, 413.3044, 351.5907]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[256.7526, 365.6538, 256.9121,  ..., 139.7494, 234.7639, 132.6054],
        [256.8450, 366.3232, 257.3998,  ..., 249.1881, 430.0165, 262.6016],
        [256.7617, 365.7368, 257.3257,  ..., 184.7636, 344.6316, 194.6543],
        ...,
        [256.8180, 365.9676, 257.5426,  ..., 156.7682, 386.5870, 161.3893],
        [257.1540, 366.1789, 257.8294,  ..., 175.2155, 423.7408, 177.4369],
    

torch.Size([60, 12])
torch.Size([60, 3])
output tensor([[257.1559, 364.0748, 254.8410, 281.6721, 326.7968, 250.8607, 334.7205,
         268.6246, 420.1865, 328.1317, 422.0814, 349.0610],
        [255.8177, 362.7713, 253.6582, 280.3159, 254.7410, 201.2428, 275.7005,
         200.7717, 360.5796, 241.1558, 365.7247, 259.1465],
        [255.6223, 362.9503, 253.8841, 280.3714, 253.3915, 201.4182, 274.0657,
         200.4500, 256.3258, 100.6322, 265.9734,  88.5044],
        [255.5282, 362.2697, 253.5674, 280.0076, 193.7361, 227.8904, 207.6608,
         211.3607, 298.8243, 228.1464, 310.6688, 246.1724],
        [255.4014, 362.6048, 253.8388, 280.1131, 176.2663, 295.9871, 172.0587,
         274.4293,  97.3386, 213.3357, 101.1004, 199.7793],
        [255.4405, 362.6627, 253.9239, 279.9386, 180.2034, 253.2551, 187.4821,
         232.9596, 128.3249, 164.0063, 132.6517, 151.1245],
        [255.8184, 363.7396, 254.5630, 280.7896, 180.6209, 253.1343, 188.2788,
         232.7803, 214.7336, 136.7453, 

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[253.5322, 361.6833, 254.5945,  ..., 261.3120,  65.9079, 248.0507],
        [254.0000, 361.4714, 254.6456,  ..., 130.5448, 191.7150, 119.2985],
        [253.8409, 361.2056, 254.5030,  ..., 139.9197, 380.2195, 135.7802],
        ...,
        [253.3170, 360.8271, 253.8623,  ..., 229.5854, 444.4087, 237.3597],
        [254.7124, 361.8517, 255.3068,  ..., 185.4097,  89.6964, 168.5012],
        [254.3594, 361.4174, 254.6492,  ..., 184.6422, 208.2460, 181.7208]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[262.0368, 371.7165, 261.1082,  ..., 228.3562, 439.3646, 239.6507],
        [261.5466, 371.0267, 260.6482,  ..., 242.1679, 331.3412, 257.6660],
        [261.2170, 370.7153, 260.4041,  ..., 143.0153, 392.8087, 138.5672],
        ...,
        [261.7714, 371.4426, 260.7212,  ..., 251.4254, 375.7530, 266.5231],
        [261.0939, 370.4268, 260.3383,  ..., 204.3718, 310.9123, 214.9878],
    

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[257.2449, 367.3119, 258.3323,  ..., 195.7267, 337.7178, 206.4926],
        [258.1035, 367.6197, 258.2528,  ..., 203.4169,  82.5430, 184.4630],
        [257.5854, 367.6374, 258.4546,  ..., 110.9104, 333.3206, 103.4314],
        ...,
        [258.4298, 368.1224, 258.7065,  ..., 206.5110,  99.1132, 186.6965],
        [257.1741, 367.0147, 258.3249,  ..., 154.3285, 320.6104, 158.3261],
        [257.2895, 367.4154, 258.2496,  ..., 195.5118, 302.1898, 203.1198]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[256.2458, 365.5793, 256.9276,  ..., 248.4394, 380.7943, 261.0776],
        [256.2604, 365.4671, 257.1027,  ..., 228.4820, 341.8392, 240.5543],
        [256.8201, 364.8121, 256.3008,  ..., 205.2510,  97.7728, 186.3737],
        ...,
        [256.6010, 365.6168, 256.6077,  ..., 263.3364,  70.6314, 245.1273],
        [256.4042, 365.5336, 257.3133,  ..., 225.6773, 438.5597, 233.4510],
    

output tensor([[258.9611, 367.5891, 256.7902,  ..., 239.1492,  90.5842, 224.0597],
        [259.2460, 367.3097, 257.0703,  ..., 165.0136, 130.5938, 147.3302],
        [258.7506, 367.6291, 257.0298,  ..., 281.4526, 408.7515, 292.7278],
        ...,
        [258.6273, 366.9454, 256.8791,  ..., 155.3850, 256.9294, 152.1406],
        [258.8043, 366.8181, 256.3173,  ..., 113.0631, 241.0999,  96.8875],
        [259.0371, 367.9042, 257.3365,  ..., 104.6018, 217.7451,  85.5970]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[257.1515, 365.5692, 255.6737,  ..., 182.4125, 229.4648, 178.7342],
        [256.8922, 365.2679, 255.7529,  ..., 153.0340, 373.9250, 155.6303],
        [256.9446, 365.6432, 255.8763,  ..., 163.0331, 282.3994, 164.7621],
        ...,
        [257.9867, 366.5036, 256.6776,  ..., 162.6760, 116.9251, 144.1559],
        [256.4012, 365.2454, 255.2828,  ..., 280.2603, 403.8907, 293.5996],
        [256.9208, 365.1796, 255.9885,  ..., 33

output tensor([[256.9052, 365.7686, 257.5122,  ..., 312.3168, 353.2546, 321.8093],
        [256.6697, 365.5456, 257.9643,  ..., 163.9994, 266.8573, 163.9454],
        [256.6040, 365.5995, 257.5561,  ..., 309.0860, 429.1358, 320.2289],
        ...,
        [256.8266, 366.3385, 258.1529,  ..., 131.5874, 168.4699, 111.4333],
        [256.8631, 365.8055, 258.0301,  ..., 119.9563, 279.6255, 114.1813],
        [256.7915, 365.4617, 257.7789,  ..., 123.9087, 336.0011, 119.4157]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[256.6263, 364.0852, 256.2141,  ..., 214.0103, 302.0435, 226.0580],
        [257.4105, 365.6548, 257.0978,  ..., 150.7237, 385.6889, 155.3304],
        [257.0207, 364.6527, 256.4848,  ..., 156.3826, 148.1675, 141.3101],
        ...,
        [257.2540, 364.6559, 256.7014,  ..., 188.1617,  92.0579, 169.3575],
        [256.5728, 364.4012, 256.1992,  ..., 237.9906, 441.1163, 245.5441],
        [257.2973, 365.7638, 256.7957,  ..., 26

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[254.6931, 362.9961, 254.9069,  ..., 160.9741, 284.6526, 162.9511],
        [256.0335, 362.9987, 255.2277,  ..., 183.4134, 163.2332, 172.7822],
        [254.7059, 362.1859, 255.0731,  ..., 333.3458, 328.0706, 340.6248],
        ...,
        [254.3416, 362.1536, 253.5866,  ..., 286.5128, 317.4387, 297.0869],
        [254.0142, 361.6107, 254.1167,  ..., 215.3622, 300.5879, 226.4283],
        [254.9550, 362.6385, 254.3603,  ...,  99.5404, 276.9962,  85.4554]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[259.1067, 368.1480, 258.3583,  ..., 167.4082, 132.4784, 151.5595],
        [259.1770, 369.0314, 258.9037,  ..., 157.5372, 385.8747, 154.4136],
        [259.0915, 369.2556, 259.0204,  ..., 240.6376, 434.6153, 245.5743],
        ...,
        [259.6431, 369.2369, 259.7035,  ..., 337.3987, 380.3187, 347.2547],
        [259.6072, 369.8416, 259.4495,  ..., 147.9980, 386.6154, 142.5216],
    

output tensor([[259.7570, 369.3667, 258.1085,  ..., 230.6731, 453.0956, 232.4231],
        [257.4229, 366.4812, 256.9823,  ..., 194.5840, 250.7446, 193.6795],
        [260.9421, 372.1106, 260.9096,  ..., 334.1104, 414.4148, 345.6404],
        ...,
        [259.5830, 370.3271, 260.8181,  ..., 337.0608, 326.2950, 342.5950],
        [259.4997, 368.7066, 257.7361,  ..., 238.4514, 450.5294, 241.9592],
        [257.2795, 365.3153, 257.1077,  ..., 182.0114, 142.5962, 164.8607]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[253.3302, 361.7297, 253.4317,  ..., 164.3361, 284.8547, 166.9238],
        [255.5948, 365.1038, 256.8123,  ..., 328.5413, 405.0462, 341.4599],
        [251.1860, 358.8785, 250.5247,  ..., 101.9594, 275.7962,  90.4815],
        ...,
        [255.6748, 361.8287, 252.0043,  ..., 266.9516, 413.4332, 278.7041],
        [252.2916, 360.5740, 252.8538,  ..., 130.0670, 219.1361, 120.5422],
        [253.6109, 360.4142, 252.0312,  ..., 23

output tensor([[257.9756, 366.4898, 255.9893,  ..., 247.8040, 456.1503, 256.5081],
        [257.0526, 364.6910, 255.1069,  ..., 225.1329, 436.7914, 232.6807],
        [258.0612, 365.0537, 254.7060,  ..., 308.0062, 361.2642, 321.8170],
        ...,
        [257.6876, 366.6035, 256.2734,  ..., 134.9709, 326.7275, 137.3967],
        [258.2752, 365.7086, 255.9994,  ..., 175.8914,  99.7428, 159.3507],
        [258.2906, 366.1127, 255.5791,  ..., 312.9375, 365.3336, 328.5091]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[256.3615, 366.1904, 255.9419,  ..., 192.6846, 305.8543, 203.0256],
        [257.4810, 366.3280, 256.5971,  ..., 167.6523, 109.0857, 152.2364],
        [256.6078, 366.5624, 256.2303,  ..., 232.8553, 450.9827, 240.4557],
        ...,
        [256.7798, 367.5513, 257.4300,  ..., 336.3242, 343.7602, 349.2941],
        [256.7486, 366.2078, 256.0070,  ..., 147.9709, 392.9851, 147.1684],
        [256.7709, 367.4979, 257.1131,  ..., 33

output tensor([[254.9999, 364.1153, 255.8913,  ..., 167.3400, 269.8596, 168.2130],
        [255.0793, 363.9907, 255.7425,  ..., 128.3296, 343.3622, 124.1716],
        [255.5402, 364.8252, 255.8947,  ..., 103.7281, 271.0695,  85.8487],
        ...,
        [254.8493, 363.9046, 255.9406,  ..., 217.1374, 420.2178, 222.6871],
        [255.4957, 365.1703, 256.1693,  ..., 253.4767, 429.5428, 262.5166],
        [256.1387, 364.5935, 255.9870,  ..., 180.2245, 212.0300, 173.1849]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[258.6425, 367.1494, 257.4933,  ..., 130.3895, 350.4745, 127.3594],
        [258.8562, 367.3535, 257.8101,  ..., 214.7722, 380.5820, 224.2580],
        [259.6031, 368.0035, 257.7369,  ..., 328.7795, 325.9373, 337.0494],
        ...,
        [259.2039, 368.1944, 258.6715,  ..., 134.8586, 173.6736, 117.6020],
        [259.4275, 367.4836, 257.9881,  ..., 180.6794, 148.7215, 164.9195],
        [258.9320, 367.3815, 257.4135,  ..., 20

output tensor([[257.1918, 367.0302, 257.8899,  ..., 253.8207, 375.4134, 268.5595],
        [257.2440, 366.9941, 258.3905,  ..., 229.9262, 435.2888, 239.6958],
        [257.0498, 365.8410, 257.4521,  ..., 329.2430, 294.5257, 337.0593],
        ...,
        [257.0022, 366.7032, 257.9843,  ..., 337.1740, 413.9929, 352.8982],
        [256.4111, 365.7036, 257.5578,  ..., 179.2416, 409.8455, 180.6174],
        [257.2338, 366.9716, 258.3810,  ..., 230.3852, 435.2684, 240.2986]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[257.0910, 366.1406, 257.7003,  ..., 226.1445, 341.3416, 241.5037],
        [256.2419, 366.5495, 258.3564,  ..., 308.9140,  62.4665, 306.1647],
        [256.5809, 365.6876, 256.8225,  ..., 285.7390,  73.4894, 279.7413],
        ...,
        [256.6285, 365.0056, 257.4419,  ..., 232.9484, 315.2002, 249.1425],
        [257.5393, 365.5023, 257.3741,  ..., 184.3028, 160.1122, 176.0667],
        [256.4398, 364.7200, 256.7688,  ..., 30

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[255.0978, 362.3806, 253.5594,  ..., 252.5166, 356.3597, 266.6005],
        [255.1452, 361.8271, 253.1386,  ..., 120.4462, 333.0268, 118.9616],
        [255.1721, 362.1469, 254.0430,  ..., 145.0161, 176.2162, 135.5756],
        ...,
        [255.7182, 363.3352, 254.5521,  ..., 148.6664, 239.0490, 147.7659],
        [254.4256, 361.6839, 253.8829,  ..., 131.0832, 153.1322, 112.7980],
        [254.6868, 361.5372, 252.5851,  ..., 300.4542, 427.4091, 313.8733]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[259.2346, 367.9760, 256.7972,  ..., 261.6670, 457.2866, 271.2465],
        [259.5271, 368.4037, 257.3780,  ..., 109.7389, 286.7121, 106.3505],
        [258.1158, 366.8358, 257.8994,  ..., 311.7119,  63.1432, 307.3589],
        ...,
        [260.0966, 367.9505, 257.2961,  ..., 240.1439, 436.7364, 250.7377],
        [259.7726, 367.2339, 257.1505,  ..., 231.1517, 434.7926, 241.9624],
    

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[258.0858, 365.9610, 257.3480,  ..., 162.8732, 123.7132, 142.9628],
        [257.6336, 365.6719, 256.6020,  ..., 190.4964, 247.3964, 188.6345],
        [257.6849, 366.2880, 257.0150,  ..., 186.0646, 348.7840, 194.9522],
        ...,
        [257.8497, 366.7266, 257.0494,  ..., 175.4124, 395.7708, 181.7101],
        [257.7030, 366.3678, 257.0284,  ..., 184.2048, 348.8963, 192.4714],
        [258.0072, 365.4412, 256.6158,  ..., 185.7894, 214.7209, 181.1919]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[257.6553, 365.9271, 256.3668,  ..., 113.9309, 285.5490, 106.7678],
        [257.1661, 365.0560, 256.3047,  ..., 167.5417, 272.8209, 169.2212],
        [257.2054, 365.2470, 256.3377,  ..., 169.9689, 281.3812, 172.1291],
        ...,
        [257.6289, 365.6901, 256.2385,  ..., 114.0858, 259.0572, 105.2436],
        [257.6824, 365.6147, 256.8718,  ..., 218.6888,  89.7977, 198.1442],
    

output tensor([[257.3504, 365.3595, 256.3981,  ..., 184.4676, 209.8459, 181.2732],
        [256.6060, 365.6642, 257.1459,  ..., 210.4440,  92.1749, 190.9946],
        [256.9595, 365.5734, 256.2939,  ..., 289.0909, 312.5450, 302.4634],
        ...,
        [256.5264, 365.1028, 256.4920,  ..., 193.2788, 336.2611, 205.8201],
        [256.2596, 364.6818, 256.2449,  ..., 186.2968, 413.5130, 191.1239],
        [257.5760, 365.8424, 256.6759,  ..., 179.6205, 214.2621, 174.5943]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[257.4753, 365.9466, 256.5844,  ..., 329.9264, 300.8241, 337.7102],
        [257.6700, 366.1909, 256.9857,  ..., 144.6625, 173.9145, 135.6785],
        [257.0211, 365.7422, 256.5355,  ..., 143.3620, 237.7646, 143.5293],
        ...,
        [257.3857, 366.0067, 256.5696,  ..., 141.5102, 381.2310, 149.0858],
        [257.1351, 365.5864, 256.5022,  ..., 248.5378, 361.4850, 267.1231],
        [256.8379, 366.1698, 256.6872,  ..., 28

output tensor([[261.3672, 372.5948, 261.6655,  ..., 182.6276, 401.2433, 191.2433],
        [260.4420, 371.2681, 261.2033,  ..., 205.4553, 313.1832, 215.1914],
        [260.1801, 370.0771, 259.8810,  ..., 111.2904, 288.8945, 103.1771],
        ...,
        [260.3728, 370.8636, 261.2997,  ..., 216.9195, 307.7039, 230.0566],
        [261.6130, 371.7850, 261.4552,  ..., 330.8719, 300.5618, 338.2495],
        [260.5336, 371.5150, 261.3232,  ..., 138.3868, 332.2318, 140.2661]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[253.5956, 359.6631, 253.1756,  ..., 333.1200, 341.1625, 342.5756],
        [253.1831, 359.0306, 252.1527,  ..., 237.2801, 361.9952, 250.6912],
        [252.7198, 358.8761, 251.9506,  ..., 207.1648,  78.4658, 190.2793],
        ...,
        [253.2112, 359.8922, 251.9319,  ..., 259.2473,  69.2075, 245.1165],
        [253.7268, 359.1567, 252.2820,  ..., 182.4053, 207.8439, 179.9109],
        [253.3421, 359.9854, 252.6781,  ..., 13

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[252.5383, 358.5872, 251.5985,  ..., 203.3941,  76.6306, 189.9352],
        [252.8367, 359.5380, 251.9639,  ..., 184.9488, 343.3149, 195.5950],
        [253.4299, 359.6866, 252.9594,  ..., 162.2426, 103.9814, 144.8529],
        ...,
        [253.9610, 359.7767, 252.4741,  ..., 177.5793, 154.9938, 166.9184],
        [254.1229, 360.3142, 253.2115,  ..., 174.5548, 141.5184, 159.8085],
        [252.8985, 359.4073, 252.4682,  ..., 304.2648,  58.1689, 302.2963]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[263.4321, 372.9602, 260.9380,  ..., 316.4341, 358.8177, 330.7020],
        [263.9849, 374.1721, 261.9225,  ..., 226.7500, 347.3802, 241.6378],
        [262.9680, 373.0011, 261.5023,  ..., 140.2836, 236.2641, 133.0465],
        ...,
        [262.3757, 372.5435, 260.2247,  ..., 308.3132,  72.3761, 303.4375],
        [263.0838, 372.5346, 261.4085,  ..., 224.8285, 438.1120, 234.3767],
    

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[259.0242, 369.5836, 260.4616,  ..., 153.4305, 258.1411, 153.4605],
        [258.6866, 369.9259, 260.9197,  ..., 109.6097, 202.8053,  89.4550],
        [257.9564, 368.4550, 259.3981,  ..., 292.4239, 451.7503, 303.7999],
        ...,
        [259.4397, 369.7757, 259.8411,  ..., 112.3851, 249.3719,  99.4237],
        [257.8083, 369.0645, 259.8204,  ..., 309.8062,  65.9989, 303.1900],
        [258.7883, 370.6580, 260.1543,  ..., 252.5126, 460.4678, 256.3760]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[258.2862, 366.3951, 257.4765,  ..., 140.3237, 378.8882, 135.0869],
        [258.8119, 366.9404, 257.1206,  ..., 111.0588, 249.6787, 101.5349],
        [256.5337, 363.9260, 255.2240,  ..., 311.9016, 346.4605, 324.8347],
        ...,
        [258.0503, 366.8422, 258.0517,  ..., 216.1971,  84.2136, 199.0311],
        [259.5689, 367.1660, 257.6628,  ..., 183.3169, 207.0866, 181.4590],
    

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[253.1394, 361.7076, 253.9136,  ..., 308.6904, 353.2048, 323.2721],
        [254.1946, 363.4218, 256.2933,  ..., 136.2531, 161.6778, 122.6561],
        [253.3344, 362.1201, 254.2676,  ..., 143.9052, 372.9358, 148.1283],
        ...,
        [253.5324, 362.2207, 254.4820,  ..., 306.3098, 299.0166, 317.4850],
        [253.5797, 362.9726, 254.9904,  ..., 184.8954, 292.8824, 192.1429],
        [253.2382, 362.2067, 253.7895,  ..., 244.5008, 425.7629, 257.9832]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[260.3011, 370.8226, 259.6284,  ..., 290.0556, 338.9966, 305.5895],
        [258.6776, 368.6325, 259.6191,  ..., 159.9111, 331.9297, 169.1763],
        [259.1122, 369.0749, 259.7405,  ..., 239.1921, 322.3934, 255.2891],
        ...,
        [259.1182, 369.3273, 259.5962,  ..., 190.4101, 347.8394, 202.2467],
        [259.1657, 369.4274, 259.5379,  ..., 145.4794, 394.7566, 143.9150],
    

output tensor([[265.6385, 376.4217, 265.4479,  ..., 206.6796, 431.0206, 218.5106],
        [260.4075, 369.8623, 260.3352,  ..., 272.0038,  64.4588, 262.9663],
        [263.1982, 373.6174, 263.0535,  ..., 161.4432, 271.6288, 167.6602],
        ...,
        [265.7315, 377.1286, 264.0204,  ..., 254.8647, 387.3109, 271.2767],
        [262.7673, 373.6429, 262.5889,  ..., 107.1879, 227.8861,  92.5160],
        [261.7763, 371.3351, 261.9716,  ..., 191.8154, 127.1853, 176.2861]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[253.0036, 359.6662, 251.4736,  ..., 169.1604, 386.5793, 173.8377],
        [252.6472, 358.4679, 251.3523,  ..., 141.7847, 359.9224, 147.7009],
        [252.6091, 358.1662, 250.9292,  ..., 315.9177, 339.7317, 328.8820],
        ...,
        [251.4482, 357.3153, 250.0638,  ..., 319.6935, 284.9309, 322.4120],
        [252.8985, 358.4405, 250.6877,  ..., 327.0008, 403.9719, 342.8717],
        [250.5994, 356.6824, 249.3146,  ..., 23

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[259.1861, 368.1743, 257.2197,  ..., 116.0425, 242.9124, 102.1294],
        [259.1056, 368.8445, 258.5860,  ..., 136.0819, 203.3556, 118.6281],
        [258.6507, 367.7616, 258.4205,  ..., 340.2325, 348.3152, 349.2054],
        ...,
        [258.6415, 367.8096, 258.2119,  ..., 337.8569, 384.5052, 349.6224],
        [258.5994, 368.4753, 258.0469,  ..., 169.5217, 285.5366, 166.4704],
        [258.5268, 367.9042, 257.9814,  ..., 338.3709, 409.6912, 351.0908]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[254.6812, 362.6308, 254.3603,  ..., 318.4756, 421.4400, 336.6884],
        [255.7154, 363.9774, 255.1545,  ...,  97.9846, 283.6129,  83.5788],
        [256.2244, 364.4968, 254.6614,  ..., 119.0414, 202.6921, 106.3351],
        ...,
        [257.1474, 364.6881, 255.6185,  ..., 177.2134, 155.8900, 164.4369],
        [255.3371, 363.3189, 255.2287,  ..., 333.0127, 345.3201, 345.0516],
    

output tensor([[252.3381, 360.4269, 253.2760,  ..., 166.8114, 279.5201, 167.5394],
        [252.8926, 360.2459, 252.3708,  ..., 100.9553, 266.3212,  81.7453],
        [252.1081, 360.0697, 253.3970,  ..., 215.2533, 293.2967, 225.2635],
        ...,
        [252.3956, 360.5004, 253.3980,  ..., 202.9676, 257.6056, 203.6640],
        [253.3425, 362.1851, 254.8405,  ..., 117.1802, 292.4556, 105.3469],
        [252.8312, 359.5056, 252.4582,  ..., 184.4388, 211.0369, 182.2731]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[260.7559, 370.5388, 259.7360,  ..., 116.4563, 335.5833, 118.9145],
        [261.2409, 372.4373, 261.1198,  ..., 237.4186, 334.9923, 254.3537],
        [261.8308, 372.3315, 261.2237,  ..., 160.5393, 135.8338, 148.3311],
        ...,
        [261.3173, 372.8434, 261.0418,  ..., 223.7471, 341.0508, 237.5113],
        [261.4917, 371.8776, 260.1208,  ..., 100.5742, 277.2255,  88.3250],
        [260.9818, 371.4645, 260.5403,  ..., 14

output tensor([[256.5083, 365.1731, 256.3611,  ..., 126.0820, 159.1304, 111.1964],
        [256.2227, 365.0967, 256.2458,  ..., 110.0376, 192.1284,  94.8654],
        [257.4891, 365.7895, 254.7976,  ..., 249.5445, 374.5016, 267.5325],
        ...,
        [256.9783, 365.6519, 255.9392,  ..., 106.7764, 278.8282, 100.5680],
        [256.9398, 364.5529, 255.1040,  ..., 276.3717, 400.1924, 296.1436],
        [257.2569, 365.4659, 255.6435,  ..., 103.8361, 323.2660,  96.5486]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[258.3449, 367.5362, 256.9269,  ..., 291.5885, 324.7897, 308.5259],
        [258.1409, 366.9414, 257.1681,  ..., 338.1768, 373.0326, 350.5934],
        [257.6976, 366.9341, 256.8166,  ..., 112.1614, 263.7086, 110.2808],
        ...,
        [258.3080, 367.5813, 256.7518,  ..., 104.8190, 280.5534,  94.2006],
        [257.5601, 366.8231, 257.6497,  ..., 204.7105, 420.4874, 217.6716],
        [258.0001, 367.8585, 256.9266,  ..., 20

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[258.4339, 365.2708, 255.7630,  ..., 337.4539, 412.4884, 352.2788],
        [257.1371, 365.1534, 256.7023,  ..., 221.1264,  86.4975, 200.8289],
        [257.4363, 366.0867, 256.1439,  ..., 298.2440,  73.7758, 286.5496],
        ...,
        [257.9518, 364.8857, 256.4393,  ..., 188.3081, 416.2044, 197.3350],
        [258.1350, 365.1819, 255.6193,  ..., 282.7764, 404.1486, 297.9554],
        [257.3253, 365.1644, 255.0809,  ..., 266.2518,  76.5636, 251.3497]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[256.7740, 366.2330, 256.7356,  ..., 309.8556,  63.1633, 305.5959],
        [257.4785, 365.7680, 255.1307,  ..., 218.6461, 106.1376, 201.7921],
        [257.2819, 366.2382, 255.8660,  ..., 157.9399, 279.9263, 153.9243],
        ...,
        [258.1244, 366.0110, 254.8532,  ..., 276.0560, 417.0081, 291.4802],
        [256.9612, 365.1311, 256.2937,  ..., 163.5530, 109.4450, 145.2386],
    

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[247.3689, 351.7173, 246.4963,  ..., 173.8825, 340.7001, 189.6864],
        [247.9632, 353.0779, 247.6448,  ..., 197.0208, 300.1826, 210.4746],
        [247.7018, 352.8115, 247.1147,  ..., 173.1863, 295.4503, 181.1095],
        ...,
        [246.3720, 348.9977, 245.7658,  ..., 199.4133, 371.3625, 220.3252],
        [245.0674, 348.9686, 244.9983,  ..., 310.1193, 288.0017, 320.4670],
        [245.1901, 348.5127, 244.2204,  ..., 292.8506, 349.3058, 311.1771]],
       grad_fn=<AddmmBackward0>)
torch.Size([60, 12])
torch.Size([60, 3])
output tensor([[262.4791, 370.9308, 263.2822, 288.2845, 306.1934, 222.3419, 320.7396,
         234.0673, 408.5138, 202.8689, 420.2837, 216.4865],
        [262.5125, 372.8856, 263.4235, 288.1146, 211.6656, 219.7922, 229.3084,
         206.8352, 326.1972, 201.1841, 336.5992, 218.7324],
        [264.0356, 374.7922, 264.1316, 289.3225, 336.7475, 260.6067, 343.6810,
         280.0542, 426.4346, 332.9126, 422

output tensor([[258.2268, 366.7993, 256.8100,  ..., 144.6278, 380.3924, 142.2388],
        [258.0965, 366.8556, 256.9025,  ..., 187.9969, 348.0413, 194.7209],
        [258.4203, 365.8065, 256.5484,  ..., 180.0736, 206.5385, 179.9287],
        ...,
        [257.4717, 365.4905, 256.4089,  ..., 327.5711, 320.1178, 336.3225],
        [256.3848, 365.5493, 256.2312,  ..., 303.1936,  72.3505, 294.0773],
        [256.8373, 365.0378, 256.9622,  ..., 222.4882, 305.2437, 237.8048]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[258.0652, 368.3420, 258.6804,  ..., 138.8148, 328.3544, 132.6528],
        [257.7403, 367.4354, 258.3740,  ..., 110.7074, 329.4207,  95.3801],
        [257.2815, 367.0490, 257.8341,  ..., 165.3547, 285.6041, 165.2513],
        ...,
        [257.4568, 366.9951, 258.1509,  ..., 104.1648, 278.7916,  84.8146],
        [257.9781, 367.9596, 258.9583,  ..., 147.4507, 391.7589, 134.8972],
        [258.2032, 366.9655, 257.7529,  ..., 18

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[259.6155, 369.5675, 259.7345,  ..., 102.9524, 278.6479,  87.0367],
        [260.6299, 370.2182, 260.0973,  ..., 339.1561, 379.2927, 352.0307],
        [260.0365, 370.3610, 259.3329,  ..., 242.4721, 457.0479, 245.3266],
        ...,
        [260.3801, 369.7131, 259.5143,  ..., 145.3343, 174.5678, 139.6184],
        [260.3896, 370.2805, 259.5980,  ..., 311.3216, 361.4769, 323.9274],
        [260.0118, 369.4911, 260.5005,  ..., 177.6781, 333.5464, 194.8130]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[253.6870, 362.0086, 254.4648,  ..., 305.5266,  63.2369, 298.7470],
        [254.6211, 361.9079, 253.5714,  ..., 179.3799, 410.8806, 187.1060],
        [254.6142, 361.9474, 253.3468,  ..., 138.5863, 325.4848, 138.5804],
        ...,
        [253.6456, 361.5907, 254.0618,  ..., 116.2385, 172.8015,  95.9386],
        [254.0760, 361.6819, 254.5737,  ..., 135.2948, 159.3019, 116.2867],
    

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[253.2038, 360.0838, 251.7202,  ..., 158.7406, 131.9024, 147.0781],
        [251.6716, 358.5963, 251.5346,  ..., 311.0627, 355.1571, 326.2901],
        [252.3359, 360.1277, 252.3911,  ..., 109.5648, 275.9875, 101.5131],
        ...,
        [252.9151, 360.0034, 251.8785,  ..., 160.8522, 122.1095, 145.3423],
        [252.4683, 360.1263, 252.4632,  ..., 140.1303, 320.5354, 143.2578],
        [251.2913, 358.2900, 251.3926,  ..., 310.8024, 350.3732, 323.8752]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[262.0206, 372.9088, 261.3155,  ..., 110.7706, 288.0236, 102.4385],
        [261.3782, 372.4139, 261.6598,  ..., 109.7887, 201.0005,  90.8733],
        [261.6490, 372.6412, 261.0856,  ..., 162.8349, 288.6217, 164.0372],
        ...,
        [262.6957, 373.3755, 261.2288,  ..., 214.5831, 109.6462, 197.3542],
        [262.6231, 373.6255, 261.8447,  ..., 144.0663, 385.8970, 142.4641],
    

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[255.0367, 363.6201, 254.4518,  ..., 151.4420, 383.2073, 144.2003],
        [254.9429, 363.6917, 255.0677,  ..., 103.3165, 269.0067,  80.9685],
        [255.0921, 363.5793, 253.8671,  ..., 194.9707, 394.8248, 197.6017],
        ...,
        [255.2867, 363.9094, 256.1422,  ..., 184.3676,  90.0435, 166.6372],
        [255.6315, 363.9405, 254.4288,  ..., 104.7948, 279.4066,  88.3290],
        [255.3195, 364.6455, 255.5889,  ..., 148.8864, 391.8114, 132.7740]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[258.6428, 367.9790, 257.9460,  ..., 114.1214, 339.7011, 104.8311],
        [258.4251, 368.4068, 259.3579,  ..., 137.6919, 173.9489, 116.4973],
        [259.3393, 369.5262, 258.5203,  ..., 288.2884, 459.7073, 290.9920],
        ...,
        [258.3048, 368.2305, 258.5790,  ..., 106.5943, 307.9914,  86.5456],
        [258.4485, 368.6693, 258.7649,  ..., 308.9607, 308.2435, 312.4939],
    

output tensor([[255.8529, 364.7324, 254.6675,  ..., 189.5482, 393.8670, 196.1087],
        [254.6362, 363.3651, 255.8240,  ..., 209.1753,  89.7104, 192.5008],
        [256.2917, 364.9260, 255.4792,  ..., 171.7418, 392.3242, 175.7225],
        ...,
        [255.0205, 363.4040, 255.0974,  ..., 123.1714, 159.6396, 107.7425],
        [254.5511, 362.3423, 252.6317,  ..., 244.3309,  80.8300, 231.3112],
        [254.5903, 363.0742, 255.8424,  ..., 218.8487, 417.7220, 240.0528]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[258.0625, 368.1747, 258.2397,  ..., 330.6211, 295.8601, 329.9342],
        [260.2537, 369.6764, 258.6776,  ..., 154.7451, 153.9998, 144.0544],
        [260.4117, 370.0026, 259.0857,  ..., 146.4674, 175.3123, 139.1221],
        ...,
        [259.7241, 370.0772, 259.6952,  ..., 316.4396, 365.9882, 331.3933],
        [260.7840, 371.8144, 259.8932,  ..., 280.7945, 423.0993, 294.3882],
        [259.8572, 369.8338, 258.0788,  ..., 13

output tensor([[255.1748, 363.2616, 255.6352,  ..., 162.5179, 114.4744, 142.1313],
        [254.4065, 362.5463, 254.2986,  ..., 299.8176, 307.6512, 305.4492],
        [255.4587, 363.6453, 254.6832,  ..., 259.2994,  78.4972, 241.7612],
        ...,
        [255.0861, 363.1989, 254.7691,  ..., 110.8873, 331.9892,  95.2307],
        [254.7783, 362.7583, 255.2721,  ..., 161.3221, 270.5493, 171.3990],
        [254.7394, 363.6489, 255.3209,  ..., 123.3039, 303.2509, 106.4361]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[259.7350, 369.1933, 259.2899,  ..., 344.2975, 350.3757, 349.2204],
        [259.4855, 368.7633, 259.7559,  ..., 204.5808, 382.3511, 222.9557],
        [260.0158, 369.8356, 258.5870,  ..., 231.1599, 442.2216, 237.7763],
        ...,
        [259.1450, 368.6371, 259.2155,  ..., 168.3947, 276.2690, 179.3377],
        [259.5978, 368.4696, 258.1076,  ..., 148.1824, 395.0973, 148.4380],
        [258.8453, 368.3175, 259.4534,  ..., 23

output tensor([[255.4082, 362.7469, 255.0782,  ..., 335.0743, 344.2396, 345.3160],
        [256.0968, 363.6741, 255.0239,  ..., 249.3638, 370.7377, 267.8045],
        [256.1419, 363.3949, 254.8712,  ..., 113.3832, 336.8929, 113.9116],
        ...,
        [256.2754, 364.6470, 256.0067,  ..., 137.9860, 235.1866, 135.9324],
        [256.1494, 364.2281, 256.1876,  ..., 143.2983, 243.6069, 151.5714],
        [255.6325, 363.8909, 255.8259,  ..., 104.8725, 308.8892,  86.0188]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[259.6654, 368.1966, 258.0279,  ..., 181.8913, 215.0446, 186.8602],
        [258.9967, 367.6085, 257.9657,  ..., 120.1436, 342.6979, 127.4211],
        [259.3896, 368.5309, 259.5407,  ..., 183.1442, 420.6254, 196.5486],
        ...,
        [258.9042, 368.0314, 258.1266,  ..., 332.5205, 387.6756, 346.6496],
        [259.1310, 368.7041, 258.2615,  ..., 154.2203, 386.5280, 147.1994],
        [259.3864, 368.7744, 259.3034,  ..., 17

output tensor([[253.5258, 361.4513, 254.4778,  ..., 264.8519,  57.1617, 256.4119],
        [254.5508, 361.5573, 253.9496,  ..., 176.4797, 178.1372, 175.4452],
        [254.5006, 362.5315, 254.6501,  ..., 320.5322, 414.8471, 338.0816],
        ...,
        [255.2385, 362.5993, 253.1835,  ..., 206.9348,  97.7948, 193.7546],
        [255.3980, 363.5212, 255.2296,  ..., 332.5317, 299.0435, 332.3874],
        [255.2819, 364.0353, 254.5299,  ..., 209.1073, 344.3012, 212.5603]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[260.4662, 370.6068, 259.6500,  ..., 134.0308, 178.8414, 120.9402],
        [260.6657, 371.4311, 261.6512,  ..., 336.8316, 384.5936, 349.0078],
        [259.5664, 371.1413, 261.9306,  ..., 125.2613, 303.9918, 110.4080],
        ...,
        [260.3149, 370.9865, 262.4858,  ..., 151.0392, 323.5270, 167.5255],
        [259.8035, 370.8582, 260.7442,  ..., 200.2572, 306.5300, 199.7999],
        [260.2978, 370.2611, 259.4511,  ..., 23

output tensor([[255.7206, 364.6211, 256.7323,  ..., 195.4002, 110.6794, 172.2804],
        [256.4120, 364.4041, 253.8402,  ..., 245.6656,  89.7014, 225.5498],
        [256.3613, 364.6613, 255.6759,  ..., 181.6511, 423.0599, 185.8086],
        ...,
        [256.0970, 364.3838, 255.9029,  ..., 173.0942, 407.1480, 185.5591],
        [255.6260, 364.0996, 255.9550,  ..., 165.7336, 106.4123, 146.9842],
        [255.4389, 363.8352, 254.3632,  ..., 335.6163, 299.6697, 332.1962]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[258.2460, 367.5900, 258.3819,  ..., 289.8108, 442.7573, 310.7649],
        [258.9249, 367.8560, 256.4793,  ..., 117.1604, 211.8267, 109.7831],
        [257.1698, 366.4607, 257.9708,  ..., 210.1678,  78.4635, 201.3928],
        ...,
        [258.1548, 367.3610, 259.0573,  ..., 184.7128, 333.4899, 205.7261],
        [259.6286, 368.8842, 257.4661,  ..., 257.6125, 437.0617, 269.1895],
        [257.6198, 367.6811, 258.1422,  ..., 11

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[251.2241, 357.4084, 252.1840,  ..., 135.4119, 360.5161, 149.2765],
        [251.1721, 358.0541, 252.0072,  ..., 324.5262, 293.8271, 328.6225],
        [251.5021, 358.5612, 253.3784,  ..., 188.0547, 327.6898, 206.1079],
        ...,
        [250.9258, 357.7341, 251.8189,  ..., 271.8591, 397.6072, 290.4694],
        [250.1284, 357.3627, 252.1654,  ..., 231.2439, 433.0887, 246.2542],
        [251.9218, 359.3590, 252.9411,  ..., 167.8070, 405.4365, 180.2380]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[261.5505, 372.8552, 262.8197,  ..., 106.9102, 267.3183,  85.5589],
        [262.7601, 373.2852, 261.8476,  ..., 208.1956,  94.8357, 191.5356],
        [262.8660, 373.1838, 261.0538,  ..., 255.2228,  91.3405, 236.4080],
        ...,
        [264.0563, 375.5285, 263.6172,  ..., 293.6566, 458.7243, 304.8552],
        [261.7106, 372.6323, 262.6736,  ..., 146.8329, 388.0693, 132.2545],
    

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[251.1857, 358.1028, 250.9795,  ..., 295.0875, 370.3584, 309.9453],
        [251.2464, 358.3252, 252.3678,  ..., 228.4137, 417.7588, 248.5913],
        [253.6667, 360.7368, 252.2901,  ..., 104.9294, 240.9247, 107.0430],
        ...,
        [254.3245, 361.5904, 251.7312,  ..., 122.0026, 193.3839, 112.8722],
        [251.1405, 358.0113, 250.9953,  ..., 298.2606, 365.1567, 312.8768],
        [253.3523, 361.1314, 252.7899,  ..., 105.7411, 221.0841,  90.5062]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[261.2483, 372.1185, 261.9703,  ..., 338.3366, 355.8598, 348.9926],
        [260.4421, 371.6288, 260.9825,  ..., 310.5952, 367.8626, 320.3850],
        [262.3256, 373.3732, 261.7716,  ..., 337.7192, 419.8352, 353.9513],
        ...,
        [259.8842, 370.5457, 260.0240,  ..., 204.1552, 265.1997, 207.0467],
        [260.1230, 370.8889, 261.5835,  ..., 243.3771, 319.8904, 262.1838],
    

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[254.9731, 363.7226, 255.4377,  ..., 309.6115, 363.7932, 321.6905],
        [254.8190, 363.5927, 256.8275,  ..., 217.9257, 380.5027, 239.7786],
        [256.2450, 365.1499, 257.5107,  ..., 229.7218, 439.1394, 242.8717],
        ...,
        [257.5028, 367.0306, 256.8090,  ..., 304.4336,  81.5019, 292.3521],
        [256.6260, 366.3447, 258.2621,  ..., 134.8931, 187.2404, 112.9471],
        [257.9354, 369.2089, 259.4431,  ..., 155.3211, 400.8546, 135.9194]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[256.7288, 364.5013, 255.7471,  ..., 114.3770, 330.2242, 103.6170],
        [256.2836, 364.2570, 255.9238,  ..., 315.6451, 346.8801, 325.7472],
        [256.6929, 364.8356, 255.7876,  ..., 201.2523, 384.4662, 210.1044],
        ...,
        [257.7534, 365.2372, 255.0971,  ..., 151.2444, 155.3335, 143.2423],
        [255.9462, 364.7126, 256.8957,  ..., 171.4932,  95.8099, 154.9969],
    

output tensor([[251.9933, 359.6722, 252.7402,  ..., 105.4519, 206.6191,  80.6455],
        [249.8081, 355.7869, 249.2730,  ..., 246.0798, 357.1984, 264.5462],
        [250.1072, 355.5439, 249.4358,  ..., 227.5277, 424.3071, 238.7299],
        ...,
        [251.4832, 358.0796, 249.2541,  ..., 208.7074, 342.9315, 208.0417],
        [252.0147, 359.2386, 250.7119,  ..., 302.3172,  71.7901, 284.6167],
        [250.7511, 356.0913, 249.2771,  ..., 226.8222, 430.7481, 231.3656]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[263.9479, 375.4739, 264.8225,  ..., 149.4646, 329.7570, 167.3857],
        [263.5320, 374.7367, 263.1301,  ..., 141.5927, 187.0849, 126.0204],
        [264.2544, 375.4703, 262.1033,  ..., 130.9041, 195.0496, 121.2944],
        ...,
        [262.8082, 374.8145, 264.6705,  ..., 219.2344, 304.6063, 237.0077],
        [262.4588, 373.9914, 263.0576,  ..., 329.3431, 314.6690, 338.5711],
        [261.9259, 373.6402, 262.4180,  ..., 33

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[258.8010, 367.5215, 258.4468,  ..., 113.4091, 342.8170, 112.3792],
        [259.2601, 368.3744, 257.9801,  ..., 262.9291,  84.5172, 244.8394],
        [258.9047, 368.1400, 258.9779,  ..., 333.7118, 394.3957, 352.8345],
        ...,
        [257.7939, 367.1765, 259.7142,  ..., 217.2310, 375.2859, 240.2615],
        [258.5961, 367.7783, 259.4708,  ..., 141.8570, 323.6674, 159.9605],
        [259.1463, 368.5360, 258.8337,  ..., 289.6243, 328.3909, 310.9106]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[256.5267, 364.7289, 254.6844,  ..., 129.8294, 177.0089, 116.0520],
        [256.1221, 364.1389, 255.4938,  ..., 138.2910, 177.1781, 123.2495],
        [256.3937, 364.0092, 255.5532,  ..., 178.9097, 216.3208, 174.0448],
        ...,
        [255.3350, 363.3381, 255.7206,  ..., 197.5980, 381.9540, 216.6879],
        [255.2222, 363.4731, 254.8290,  ..., 195.6739, 302.3036, 190.5664],
    

output tensor([[253.6750, 363.1942, 253.2623,  ..., 276.9564, 314.0148, 302.8263],
        [253.3294, 365.1730, 255.2824,  ..., 118.7490, 154.1597, 112.7663],
        [253.9005, 367.2047, 257.0028,  ..., 100.5477, 267.9542,  91.8888],
        ...,
        [256.0331, 366.7202, 255.7840,  ..., 228.6134, 331.9129, 254.3207],
        [253.7332, 366.9143, 256.2254,  ..., 323.5127, 346.4645, 344.5659],
        [253.1827, 365.6290, 255.7749,  ...,  92.3760, 265.3914,  85.3268]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[259.0689, 368.5700, 257.0359,  ..., 210.7287, 304.9567, 211.4839],
        [257.4992, 366.3015, 254.9685,  ..., 233.0414,  99.3168, 217.3651],
        [260.1466, 368.7475, 256.4115,  ..., 268.2563, 424.0594, 283.2551],
        ...,
        [259.6230, 367.6362, 256.0648,  ..., 242.9032, 336.5029, 258.7179],
        [256.7772, 365.6843, 255.0536,  ..., 317.6032, 300.4757, 323.8796],
        [255.4943, 363.4416, 253.9978,  ..., 23

output tensor([[255.9895, 364.3115, 254.7060,  ..., 289.9771, 452.3931, 296.8703],
        [255.4324, 364.3068, 255.8072,  ..., 305.4317,  73.0822, 293.0271],
        [254.9948, 363.3198, 254.7688,  ..., 305.7631, 360.6459, 315.8102],
        ...,
        [255.5420, 364.9228, 256.7036,  ..., 308.5474,  60.6161, 305.4491],
        [255.4537, 363.5940, 255.3606,  ..., 336.1135, 372.8331, 345.1223],
        [257.2341, 365.8569, 255.2776,  ..., 263.3018, 372.0956, 274.6171]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[258.0763, 367.9908, 259.4732,  ..., 239.6746, 366.3412, 257.8758],
        [258.1540, 367.6102, 258.2036,  ..., 131.2715, 166.0070, 109.8429],
        [257.7747, 367.6691, 259.1873,  ..., 223.9779, 373.4447, 242.5578],
        ...,
        [258.8640, 368.1330, 258.5139,  ..., 163.9705, 121.9560, 143.8192],
        [257.8669, 367.6064, 258.7882,  ..., 157.7455, 267.2549, 168.7282],
        [257.6520, 367.0108, 258.4639,  ..., 10

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[258.3526, 368.0873, 258.7987,  ..., 309.6436, 308.1550, 314.1909],
        [259.8113, 369.0829, 258.8240,  ..., 147.5865, 165.6310, 141.8215],
        [258.4071, 368.3616, 258.5927,  ..., 175.0744, 295.9469, 168.6753],
        ...,
        [259.2715, 368.6081, 258.3925,  ..., 104.1416, 289.0950,  99.1052],
        [259.0194, 368.5856, 258.3857,  ..., 252.4194,  89.2173, 235.1361],
        [260.3617, 370.2928, 258.6885,  ..., 277.5680, 429.2069, 292.1766]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[254.3918, 362.8638, 255.1577,  ..., 335.6158, 297.1793, 327.6627],
        [256.0071, 364.1441, 255.3797,  ..., 144.9342, 167.6047, 136.4650],
        [254.8776, 363.1954, 254.2594,  ..., 265.3538, 457.3194, 264.3268],
        ...,
        [255.6327, 364.3172, 255.3091,  ..., 250.8863, 371.1569, 266.3866],
        [255.3342, 363.8234, 255.2998,  ..., 108.2189, 264.7904, 111.2643],
    

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[255.3876, 364.5436, 256.2325,  ..., 139.2853, 232.3969, 126.3608],
        [253.5109, 361.9511, 254.4707,  ..., 288.0455, 307.3196, 296.3012],
        [254.8398, 363.1503, 255.0437,  ..., 216.5257,  99.6707, 193.3743],
        ...,
        [255.6526, 364.0402, 256.2094,  ..., 149.2984, 374.9962, 154.5923],
        [255.3392, 364.0753, 255.9262,  ..., 202.8408, 259.0317, 203.6878],
        [254.8768, 363.2121, 254.9794,  ..., 263.2090,  69.7471, 241.6863]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[259.7523, 368.9456, 259.1722,  ..., 179.1436, 218.1310, 173.2157],
        [260.6732, 370.4594, 260.2385,  ..., 197.8787, 392.9356, 214.7048],
        [259.4340, 368.8818, 259.5825,  ..., 328.0980, 301.9811, 337.5412],
        ...,
        [260.3292, 369.7659, 259.9721,  ..., 149.3161, 394.4258, 163.3656],
        [261.3241, 371.1198, 259.4596,  ..., 277.9604, 428.0571, 291.8740],
    

output tensor([[254.5220, 361.9720, 254.7670,  ..., 323.9339, 298.9888, 335.7376],
        [256.1631, 363.8925, 255.6050,  ..., 180.0510, 157.0892, 167.2904],
        [254.0834, 361.2298, 253.5120,  ..., 309.1745, 354.2890, 320.3622],
        ...,
        [254.7590, 362.5306, 255.2912,  ..., 201.6148, 375.0701, 226.4029],
        [254.7405, 362.0788, 255.4976,  ..., 186.8236, 415.8564, 209.1211],
        [255.7444, 363.0220, 255.0888,  ..., 119.4145, 348.6907, 136.3080]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[258.9057, 368.7780, 259.2491,  ..., 142.8014, 324.7415, 125.7119],
        [259.5981, 368.9446, 258.2020,  ..., 293.1282, 459.3554, 298.3944],
        [259.3566, 368.6743, 257.8909,  ..., 267.0780,  84.6922, 247.7513],
        ...,
        [258.5097, 368.1393, 258.1415,  ..., 249.8228, 459.4292, 242.2056],
        [259.1110, 368.9452, 259.4235,  ..., 312.9786,  61.7539, 313.3306],
        [259.5514, 369.0146, 258.7666,  ..., 20

torch.Size([60, 12])
torch.Size([60, 3])
output tensor([[257.6092, 366.5565, 257.2097, 283.3602, 257.6546, 203.9598, 278.4097,
         203.5325, 376.8643, 183.8546, 397.4843, 179.8593],
        [258.3199, 367.4484, 257.0975, 283.8332, 235.0272, 206.9205, 255.0964,
         200.6827, 215.9931, 109.3934, 236.4606, 106.3637],
        [257.3596, 366.0160, 256.2789, 282.7635, 179.3732, 298.0477, 175.4092,
         277.0858,  92.7706, 224.9271, 107.5041, 208.6115],
        [257.4855, 366.0145, 256.9636, 283.1448, 178.9732, 297.7686, 175.2173,
         277.1262, 189.1373, 178.6592, 211.4296, 185.3101],
        [257.3167, 366.3483, 257.5763, 283.5478, 257.3815, 203.3994, 277.7734,
         203.4637, 286.5732, 104.8626, 290.0860,  84.3025],
        [257.3700, 366.5868, 257.2812, 283.2490, 305.2392, 219.0849, 321.8129,
         231.2957, 388.7681, 159.1029, 404.3787, 144.7524],
        [257.0237, 366.1112, 257.7017, 283.3636, 183.7677, 254.4674, 191.2334,
         235.4234, 166.6400, 139.1131, 

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[253.0507, 361.5818, 254.4292,  ..., 133.8268, 215.1902, 114.6192],
        [253.4278, 361.3412, 254.2138,  ..., 179.3342, 142.5559, 160.4097],
        [255.1280, 362.3526, 255.6724,  ..., 307.4734,  59.4900, 307.8285],
        ...,
        [254.4662, 362.9989, 254.5337,  ..., 249.0909, 374.4020, 266.6267],
        [253.3339, 361.4666, 254.8822,  ..., 215.0813, 297.2057, 237.4427],
        [254.1343, 362.3096, 254.7308,  ..., 188.4341, 342.6784, 201.1056]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[261.4383, 372.0020, 260.6532,  ..., 212.1230, 352.5867, 213.1913],
        [262.8393, 373.5472, 261.6323,  ..., 288.6611, 353.7496, 300.6698],
        [260.6077, 371.1050, 261.1223,  ..., 153.0119, 261.0857, 163.8291],
        ...,
        [263.4175, 374.6003, 263.2997,  ..., 342.1165, 399.9055, 354.9112],
        [264.0410, 375.2213, 263.6145,  ..., 341.4573, 413.6570, 355.5262],
    

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[256.1020, 364.2172, 255.5858,  ..., 178.4497, 189.1727, 178.0155],
        [255.6726, 364.6355, 255.7319,  ..., 127.7215, 152.6310, 110.5543],
        [256.8981, 365.1537, 255.2739,  ..., 155.0111, 152.9333, 145.7503],
        ...,
        [255.8764, 364.2847, 255.1389,  ..., 215.4737, 112.8545, 198.1418],
        [256.3568, 364.8002, 255.4984,  ..., 163.8300, 119.4020, 144.5794],
        [255.1997, 364.3753, 255.9783,  ..., 223.8201, 433.1987, 241.0945]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[258.5345, 368.1496, 257.8032,  ..., 250.0465,  85.5893, 231.7384],
        [257.4699, 367.5352, 258.9535,  ..., 220.9499, 415.1017, 242.1993],
        [257.8740, 367.5510, 258.2458,  ..., 202.1976, 261.4747, 204.9364],
        ...,
        [258.7568, 368.3771, 258.6659,  ..., 264.2213,  68.7706, 245.3237],
        [258.5071, 368.2071, 257.6917,  ..., 277.0454, 420.7948, 293.6886],
    

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[255.1918, 363.8854, 255.9963,  ..., 213.2284, 288.9664, 232.7549],
        [256.0943, 364.2748, 255.0061,  ..., 330.4361, 415.1689, 348.6321],
        [255.9212, 364.5663, 254.9353,  ..., 277.3748, 355.3747, 291.8357],
        ...,
        [255.4386, 363.6765, 255.6603,  ..., 197.0281, 380.9237, 217.6033],
        [255.7317, 364.3922, 255.8800,  ..., 181.7481, 160.3409, 174.5707],
        [255.5707, 363.7795, 256.2858,  ..., 333.7258, 307.8535, 345.4581]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[256.8183, 365.8425, 256.8536,  ..., 181.2973, 396.2548, 175.6827],
        [257.0092, 366.4396, 257.3529,  ..., 107.7013, 269.9896, 116.4307],
        [256.0132, 364.3531, 256.2416,  ..., 146.2147, 390.0913, 158.6423],
        ...,
        [257.7765, 367.3979, 257.4470,  ..., 112.6647, 222.5008, 107.0120],
        [256.5839, 366.1766, 257.4966,  ..., 123.9676, 156.1449, 106.6799],
    

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[258.8199, 368.3319, 258.6945,  ..., 292.0562, 336.2742, 311.8725],
        [257.6044, 367.3849, 258.1937,  ..., 122.2255, 294.1628, 102.0508],
        [257.9174, 367.0749, 257.7176,  ..., 141.2395, 180.0142, 124.5919],
        ...,
        [257.1174, 366.3968, 257.2835,  ..., 197.4538,  81.6082, 187.8568],
        [257.2432, 366.8604, 257.9200,  ..., 331.6176, 294.3746, 324.8103],
        [257.3488, 366.5761, 256.9965,  ..., 218.7930, 115.3599, 202.9300]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[255.7103, 364.2220, 255.2399,  ..., 205.0686,  78.9355, 192.0493],
        [256.6200, 365.6202, 255.7761,  ..., 238.5984, 336.1823, 252.7081],
        [256.7445, 365.4261, 256.8894,  ..., 187.5746, 328.8623, 208.7859],
        ...,
        [256.8347, 365.3737, 256.8326,  ..., 143.7680, 368.2596, 162.9134],
        [257.1840, 365.8414, 256.5763,  ..., 100.7979, 281.2376, 102.9048],
    

output tensor([[255.5242, 364.2427, 255.3899,  ..., 154.7027, 396.1759, 136.3910],
        [255.5133, 363.7878, 255.6743,  ..., 196.7158, 379.5838, 215.4962],
        [254.7339, 362.8109, 254.9471,  ..., 333.6192, 344.9083, 343.9739],
        ...,
        [254.7047, 363.1602, 254.9577,  ..., 322.2706, 291.3721, 317.5748],
        [256.3154, 364.7089, 255.3931,  ..., 153.2079, 149.8986, 142.9269],
        [255.4616, 364.0438, 255.5251,  ..., 200.4285, 258.8345, 202.0114]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[258.6834, 367.9885, 258.4951,  ..., 162.9333, 406.6813, 177.7681],
        [258.5547, 368.0019, 257.9020,  ..., 243.0557, 339.0602, 257.5245],
        [259.3088, 369.1233, 258.4306,  ..., 251.0118,  91.1857, 232.2372],
        ...,
        [259.1933, 368.8908, 259.6289,  ..., 184.7865, 134.2476, 164.6387],
        [259.4594, 369.2707, 258.8988,  ..., 285.9020,  85.0166, 267.2056],
        [258.4769, 367.7653, 258.7107,  ..., 23

output tensor([[255.4035, 364.1200, 256.1226,  ..., 105.8181, 235.5753, 109.9532],
        [254.9146, 363.0364, 255.7959,  ..., 166.4544, 388.5323, 161.4523],
        [256.0953, 364.1512, 255.6732,  ..., 328.1746, 420.2509, 349.1961],
        ...,
        [254.4314, 363.0470, 255.8897,  ..., 209.9650,  85.2966, 196.3128],
        [255.0930, 363.6437, 255.9900,  ..., 110.8084, 283.1976, 122.6926],
        [254.8741, 362.7870, 255.7401,  ..., 147.3602, 378.1851, 152.2139]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[260.1681, 370.0824, 259.8172,  ..., 212.6104, 349.3023, 217.7025],
        [259.4029, 369.6831, 260.0530,  ..., 233.1879, 443.3669, 243.6980],
        [257.8612, 366.6387, 258.2093,  ..., 327.0893, 291.4198, 333.6830],
        ...,
        [258.5438, 367.4203, 258.1692,  ..., 309.6687, 361.2628, 333.1562],
        [259.0196, 368.3706, 259.6211,  ..., 221.6449, 302.2254, 244.1246],
        [257.9173, 366.6784, 257.5306,  ..., 29

output tensor([[258.3965, 367.6116, 258.0625,  ..., 307.7800,  77.8454, 288.0409],
        [258.0710, 367.5056, 257.9958,  ..., 190.8275,  82.5818, 177.6809],
        [259.1666, 368.8304, 259.0688,  ..., 180.1522, 423.3480, 167.5363],
        ...,
        [258.8943, 369.0186, 259.1602,  ..., 337.2805, 299.3236, 328.5246],
        [258.5424, 368.4240, 258.7640,  ..., 112.3217, 280.3756,  98.5903],
        [258.4891, 368.4563, 258.6379,  ..., 139.8497, 232.8599, 124.3390]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[254.6219, 362.3298, 254.4887,  ..., 338.9724, 358.9942, 347.2170],
        [254.8911, 362.9148, 254.8321,  ..., 141.8663, 318.2237, 162.4917],
        [254.7177, 362.5578, 254.6965,  ..., 106.8912, 314.8358,  89.2844],
        ...,
        [254.2479, 362.5094, 254.8348,  ..., 222.8923, 418.9749, 246.0049],
        [254.4371, 362.3889, 254.4066,  ..., 293.6317, 431.2194, 316.7007],
        [254.9954, 363.2660, 254.9071,  ..., 20

output tensor([[256.6692, 365.4614, 256.8433,  ..., 335.5306, 351.6966, 344.8811],
        [257.2348, 366.3973, 257.0974,  ..., 335.4910, 400.7520, 351.9500],
        [257.0679, 366.3169, 257.4564,  ..., 230.4386, 441.4194, 236.6921],
        ...,
        [256.5497, 366.3172, 257.1347,  ..., 235.4877, 445.4031, 243.9116],
        [256.8941, 365.7953, 257.3704,  ..., 326.7157, 322.2963, 339.0952],
        [256.2610, 365.1169, 256.6622,  ..., 143.0033, 319.5976, 158.8378]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[256.6750, 365.8825, 256.0511,  ..., 129.8177, 181.3068, 118.7372],
        [255.8782, 365.0069, 256.5091,  ..., 111.4672, 187.3253,  91.0381],
        [256.8978, 365.5915, 256.4139,  ..., 330.9306, 382.7047, 348.5786],
        ...,
        [256.4347, 365.1750, 256.5889,  ..., 103.5247, 267.6095,  82.8832],
        [257.5596, 366.6169, 257.4430,  ..., 338.2343, 356.1534, 346.6830],
        [256.6263, 365.3962, 256.5705,  ..., 23

output tensor([[254.6293, 363.0826, 254.1025,  ..., 238.7114, 334.5135, 248.8146],
        [253.6062, 361.1863, 253.5504,  ..., 307.1369, 350.8777, 321.5808],
        [255.3913, 363.7874, 255.3009,  ..., 187.3693, 331.4049, 206.0197],
        ...,
        [255.3823, 364.8655, 256.4363,  ..., 110.3707, 180.9719,  88.9783],
        [256.1989, 365.2309, 256.4709,  ..., 269.4474,  50.4879, 265.0608],
        [256.4601, 365.6183, 256.2182,  ..., 163.5850, 107.7021, 143.7003]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[256.3931, 365.1332, 256.8741,  ..., 326.1156, 305.1847, 322.6819],
        [258.7747, 369.0854, 258.8553,  ..., 124.9157, 206.8447, 115.6766],
        [258.7962, 370.1483, 259.6476,  ..., 207.8900, 432.8409, 223.3503],
        ...,
        [257.6052, 366.8636, 258.1478,  ..., 186.9661, 334.9510, 206.4583],
        [259.1654, 369.2047, 258.7500,  ..., 208.8103, 107.6224, 189.3297],
        [257.7346, 367.1187, 258.3499,  ..., 11

output tensor([[250.6940, 357.6238, 250.9306,  ..., 109.6158, 266.4381,  91.1645],
        [250.9888, 357.8788, 250.9040,  ..., 137.8759, 229.4485, 131.0061],
        [249.7291, 355.7364, 249.7218,  ..., 285.0872, 304.9497, 304.3841],
        ...,
        [250.0141, 356.0179, 249.9675,  ..., 285.6697, 311.7557, 307.1552],
        [250.1649, 356.5342, 249.9834,  ..., 140.8941, 317.5724, 135.2494],
        [252.0470, 359.0960, 251.2968,  ..., 229.1366,  89.1215, 211.3300]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[264.0271, 375.2945, 264.3681,  ..., 115.1946, 343.5768, 125.5385],
        [264.6782, 376.6551, 265.3166,  ..., 338.9576, 361.3745, 350.4524],
        [264.6578, 376.7166, 265.6937,  ..., 312.8794,  68.1889, 316.4474],
        ...,
        [263.4940, 375.4760, 264.4762,  ..., 114.8909, 194.9612,  91.2944],
        [263.7155, 374.9969, 264.1874,  ..., 114.9060, 336.2899,  94.1825],
        [264.2326, 376.7377, 265.4678,  ..., 22

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[257.4065, 366.6269, 257.3885,  ..., 209.0180,  81.9437, 196.9142],
        [257.4459, 366.5428, 257.6491,  ..., 182.1572, 149.6636, 165.7736],
        [257.0864, 365.6851, 256.7250,  ..., 138.8992, 172.9782, 122.3482],
        ...,
        [256.9268, 366.0408, 256.8848,  ..., 139.3044, 239.8420, 139.2932],
        [257.3142, 366.5207, 257.7405,  ..., 228.0079, 442.5947, 232.6683],
        [256.6082, 365.2352, 256.5821,  ..., 103.2943, 282.1071,  93.0637]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[256.7555, 365.8196, 256.3758,  ..., 288.3560, 460.7228, 292.3297],
        [255.6035, 364.4506, 255.6335,  ..., 169.8051, 284.6804, 157.1022],
        [256.6967, 366.0631, 256.0530,  ..., 269.7081, 371.5730, 282.5540],
        ...,
        [256.2162, 365.1182, 256.2673,  ..., 110.4901, 279.0483, 119.5166],
        [256.0564, 364.9651, 256.2351,  ..., 215.2930, 293.0238, 233.0324],
    

output tensor([[258.3279, 367.8929, 257.9900, 284.2832, 258.7298, 201.1766, 280.3341,
         201.1938, 361.7134, 257.0458, 382.4993, 267.8982],
        [257.6082, 366.7838, 257.9213, 283.7160, 302.7213, 215.2156, 320.9239,
         227.0628, 316.3128, 326.7919, 299.1340, 334.5902],
        [257.7228, 366.4788, 257.4899, 283.7843, 258.4942, 201.4229, 279.1630,
         201.4276, 319.8922, 113.5349, 338.3863, 127.1483],
        [257.4288, 366.3122, 257.4737, 283.3506, 179.4038, 295.1816, 176.2754,
         274.4528, 176.4418, 174.5920, 198.4193, 179.3822],
        [258.0790, 367.1575, 258.0365, 284.1032, 212.1110, 215.0756, 230.1735,
         202.9666, 323.1732, 176.5673, 328.0432, 197.4875],
        [258.1461, 367.4028, 257.9175, 284.2267, 251.4939, 200.8698, 272.7528,
         198.8754, 251.6097, 102.2878, 273.0855, 103.8188],
        [258.4005, 367.3256, 258.0213, 284.1942, 257.8524, 201.1773, 278.9566,
         200.7187, 366.8264, 160.2957, 389.1577, 150.0303],
        [258.3901, 3

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[254.8311, 363.1932, 254.3768,  ..., 215.2334,  99.0050, 198.2767],
        [253.2903, 361.0155, 252.9788,  ..., 239.0220, 332.3825, 258.5494],
        [253.3763, 360.7632, 253.5940,  ..., 298.5309, 305.0402, 306.1768],
        ...,
        [253.8835, 361.5220, 253.7529,  ..., 187.9445, 328.3894, 212.0024],
        [253.3685, 361.5255, 253.5731,  ..., 103.7939, 214.6213,  85.8554],
        [253.6864, 361.3136, 253.6766,  ..., 210.8077, 295.8908, 229.6028]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[259.8407, 369.9275, 260.2174,  ..., 200.3080, 259.2984, 200.0591],
        [259.6756, 369.5354, 260.0663,  ..., 332.1957, 296.9105, 332.9372],
        [259.9339, 369.9579, 260.3588,  ..., 292.7177, 322.6790, 308.8412],
        ...,
        [260.0919, 370.2795, 260.2156,  ..., 340.3284, 378.0470, 347.3428],
        [259.6179, 370.1263, 260.4233,  ..., 232.1703, 427.1729, 249.6187],
    

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[257.3775, 366.6989, 257.6589,  ..., 140.2693, 228.9874, 120.7431],
        [257.2380, 365.8313, 257.3536,  ..., 113.7277, 338.3386, 118.9078],
        [257.3134, 366.7274, 257.7442,  ..., 165.9949, 280.1156, 155.4820],
        ...,
        [257.6948, 366.6629, 257.6960,  ..., 191.5834, 341.5742, 204.5052],
        [257.9318, 367.0463, 258.1128,  ..., 104.1105, 275.4090, 103.9188],
        [257.4577, 366.5181, 257.5994,  ..., 336.1162, 300.7931, 330.6002]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[257.0832, 366.2443, 257.5803,  ..., 210.2354,  85.0928, 195.9406],
        [257.4337, 366.6399, 257.0169,  ..., 276.1143, 359.9010, 294.3805],
        [257.3782, 366.6691, 257.5892,  ..., 106.6404, 268.8391, 116.2358],
        ...,
        [257.0536, 365.9099, 257.1588,  ..., 222.7594, 301.7897, 244.6264],
        [257.1995, 366.5095, 257.5288,  ..., 111.1319, 271.4702,  91.3564],
    

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[257.4106, 365.9466, 257.1009,  ..., 176.8363, 211.5247, 173.5583],
        [253.8112, 361.3036, 253.4669,  ..., 309.9064, 346.5966, 319.5500],
        [256.8017, 365.5693, 256.6141,  ..., 210.4352, 287.4801, 232.6168],
        ...,
        [256.7551, 366.0492, 256.9393,  ..., 161.9188, 273.5843, 150.0457],
        [256.8040, 365.9655, 256.8698,  ..., 136.8717, 226.0084, 121.8642],
        [255.7586, 363.7530, 255.3619,  ..., 143.5884, 377.3487, 145.2993]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[258.4155, 367.6299, 258.3575,  ..., 136.3043, 190.1552, 112.8703],
        [258.7418, 367.7397, 258.2146,  ..., 290.2510, 337.5319, 310.7120],
        [259.1251, 369.1501, 258.9576,  ..., 238.1341, 456.2343, 230.6202],
        ...,
        [258.5150, 367.2634, 258.3638,  ..., 146.6284, 373.3324, 165.1368],
        [259.1937, 369.3002, 259.0168,  ..., 235.7376, 455.2069, 228.6034],
    

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[255.7633, 364.4490, 255.6415,  ..., 187.5615, 297.7537, 182.3243],
        [257.2859, 366.2018, 256.3296,  ..., 325.2451, 423.7914, 348.6147],
        [256.6230, 365.3145, 256.3044,  ..., 186.5685, 344.7451, 197.3459],
        ...,
        [256.6406, 365.5499, 256.1179,  ..., 192.7901, 396.9706, 197.8607],
        [256.4512, 365.2189, 256.5107,  ..., 100.6330, 269.4904,  83.1045],
        [256.2319, 364.8229, 256.3520,  ..., 337.3384, 339.6953, 343.8466]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[257.1385, 366.9113, 257.5434,  ..., 282.2004, 395.8320, 293.7764],
        [257.5435, 367.0138, 257.6646,  ..., 310.8913, 428.0464, 325.0769],
        [258.0382, 367.4323, 258.5484,  ..., 347.6211, 342.8214, 345.6114],
        ...,
        [257.2581, 366.7811, 257.5645,  ..., 228.3743, 369.7042, 243.6465],
        [256.6332, 365.9857, 257.3919,  ..., 239.2887, 437.6880, 252.9521],
    

output tensor([[260.0581, 370.3870, 260.1404,  ..., 138.2584, 220.0142, 117.6483],
        [260.5800, 370.9832, 260.6320,  ..., 141.9423, 331.0679, 136.0928],
        [260.3021, 370.6792, 260.6135,  ..., 187.7160, 124.3012, 167.6513],
        ...,
        [260.6049, 371.0775, 260.9082,  ..., 241.4622, 333.7846, 262.8110],
        [261.5240, 372.3788, 261.2680,  ..., 333.3743, 418.3860, 355.1760],
        [260.2913, 370.7797, 260.4064,  ..., 175.3559, 292.6167, 166.5644]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[252.7029, 359.9464, 252.5486,  ..., 332.8751, 373.6147, 339.4590],
        [253.2833, 360.9987, 253.0269,  ..., 240.4803, 336.5292, 250.9410],
        [252.7206, 360.5596, 252.8335,  ..., 234.4072, 443.2837, 239.0188],
        ...,
        [253.9232, 361.2919, 253.1862,  ..., 148.1762, 159.8363, 140.0113],
        [253.5696, 361.2036, 253.0332,  ..., 160.6356, 140.2543, 147.4756],
        [253.6579, 361.2828, 253.5122,  ..., 19

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[257.8737, 366.9703, 257.4970,  ..., 136.9425, 233.7318, 121.8908],
        [257.7890, 366.6648, 257.7447,  ..., 175.0067, 197.5108, 179.3947],
        [258.6804, 368.2960, 258.3766,  ..., 241.2162, 459.5918, 236.5230],
        ...,
        [258.4633, 368.0589, 257.9609,  ..., 122.0165, 200.4466, 115.4485],
        [258.2657, 367.5325, 258.0776,  ..., 250.3414, 462.9939, 247.7677],
        [257.7994, 366.5311, 257.4933,  ..., 213.3051, 297.0387, 233.8212]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[255.1140, 364.0863, 255.4506,  ..., 248.2808, 379.6090, 261.3602],
        [254.3064, 362.5395, 254.7598,  ..., 211.6657, 304.8610, 215.0726],
        [255.2831, 363.7036, 255.4945,  ..., 143.9077, 379.1326, 124.5861],
        ...,
        [254.8515, 363.3354, 255.3213,  ..., 307.9714, 354.4037, 315.5670],
        [255.9329, 364.4521, 256.0123,  ..., 246.3526,  85.2290, 227.4897],
    

output tensor([[256.6074, 365.3447, 256.5294,  ..., 240.8385, 318.6235, 260.6727],
        [257.2746, 365.3449, 257.1745,  ..., 114.6471, 338.8132, 110.1365],
        [257.6358, 366.1316, 257.0934,  ..., 143.8179, 326.0709, 152.0556],
        ...,
        [258.4645, 367.6788, 258.0468,  ..., 181.5039, 426.5170, 172.6857],
        [257.0156, 365.7209, 257.5052,  ..., 300.7367, 309.7812, 305.1184],
        [258.0610, 366.4311, 257.6545,  ..., 142.3784, 364.8664, 157.8852]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[257.1198, 365.9607, 257.3016,  ..., 291.3672, 325.0421, 307.5588],
        [256.7117, 365.0457, 256.8806,  ..., 108.0125, 264.2378,  81.8038],
        [256.9058, 365.4612, 257.0695,  ..., 149.5617, 391.0047, 138.0605],
        ...,
        [256.2171, 365.0850, 256.2560,  ..., 183.8751, 165.0676, 176.9113],
        [257.2947, 366.4681, 257.0370,  ..., 319.4213, 357.1961, 331.4708],
        [256.5307, 365.4495, 256.1180,  ..., 30

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[256.5826, 364.9253, 256.1886,  ..., 142.3411, 317.5958, 161.3855],
        [256.8030, 365.3271, 256.1789,  ..., 333.2726, 405.8193, 354.9231],
        [255.7572, 364.4188, 255.6935,  ..., 308.2045,  65.1463, 300.4591],
        ...,
        [255.7480, 364.3205, 255.5358,  ..., 220.1554, 370.1667, 240.7113],
        [256.0502, 364.2739, 255.8698,  ..., 333.0138, 347.5518, 347.6546],
        [256.2292, 365.1600, 256.2361,  ..., 138.4537, 235.5394, 129.5059]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[258.3600, 367.8233, 258.3186,  ..., 159.7092, 266.0960, 171.9763],
        [258.4206, 367.7175, 258.0918,  ..., 198.4212, 398.7426, 196.2262],
        [258.4584, 368.1632, 258.7716,  ..., 233.6222, 433.9029, 244.3359],
        ...,
        [257.6599, 366.6146, 257.7537,  ..., 147.1956, 384.3752, 128.1779],
        [258.1464, 367.2103, 257.8815,  ..., 256.1689,  87.3186, 234.8836],
    

output tensor([[251.9174, 358.4849, 251.6256,  ..., 144.6873, 368.5414, 160.3213],
        [251.7827, 359.0805, 251.3882,  ..., 214.2090,  96.6476, 194.4839],
        [251.8133, 359.3718, 251.4987,  ..., 264.8353, 370.9564, 278.8524],
        ...,
        [251.1974, 358.5769, 251.3595,  ..., 248.0380, 370.3624, 267.5444],
        [250.9007, 357.7763, 250.6869,  ..., 276.4125,  49.8239, 273.3557],
        [251.7467, 359.1270, 251.4029,  ..., 219.2967, 347.0623, 221.2657]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[261.0493, 371.5698, 261.0526,  ..., 136.2123, 192.7060, 114.7519],
        [262.1030, 372.9551, 262.1644,  ..., 134.8086, 318.9298, 120.6583],
        [261.8913, 372.6511, 261.8792,  ..., 110.7827, 239.2353, 107.5004],
        ...,
        [263.0207, 373.5945, 262.8128,  ..., 140.4205, 367.8775, 159.9355],
        [263.5237, 374.6824, 263.8188,  ..., 345.8719, 343.7756, 349.7129],
        [263.5240, 375.4724, 264.0841,  ..., 23

output tensor([[258.2175, 367.2030, 257.9600,  ..., 344.0119, 346.2578, 345.2395],
        [257.9158, 366.7739, 257.7097,  ..., 336.5490, 300.3923, 331.6254],
        [257.4061, 366.3533, 257.3487,  ..., 178.1700, 203.5919, 182.2714],
        ...,
        [257.5139, 366.5013, 257.6532,  ..., 294.4639, 307.9846, 301.1747],
        [257.5073, 366.6398, 257.2531,  ..., 111.6718, 227.8492, 105.6101],
        [257.1494, 366.1709, 256.8711,  ..., 174.2294, 285.8684, 161.1818]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[257.2762, 366.3908, 257.1168,  ..., 112.7097, 221.5596, 107.6112],
        [257.1006, 366.0189, 257.2198,  ..., 213.6280,  82.4302, 199.7440],
        [256.9280, 365.3289, 256.8063,  ..., 113.2526, 338.7341, 119.6777],
        ...,
        [257.8751, 366.9008, 257.5351,  ..., 155.3882, 384.2910, 144.6408],
        [257.8839, 366.7612, 257.7375,  ..., 342.9978, 348.2855, 345.0408],
        [257.9478, 366.8270, 257.6776,  ..., 34

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[258.3539, 367.5713, 258.2566,  ..., 191.0047, 343.5992, 204.6365],
        [258.1227, 367.4508, 258.5611,  ..., 204.2276, 417.6682, 226.2193],
        [257.7298, 366.6453, 257.5560,  ..., 266.4378,  54.1132, 261.8215],
        ...,
        [258.4447, 367.6325, 258.2715,  ..., 147.4025, 381.4950, 146.3165],
        [258.1604, 367.9467, 258.4548,  ..., 132.7210, 313.6570, 116.7790],
        [257.7147, 366.6207, 257.6923,  ..., 179.2498, 214.5731, 183.9299]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[255.8584, 364.6010, 256.1298,  ..., 297.5733,  54.0232, 299.7109],
        [256.2432, 365.0882, 256.2565,  ..., 136.5753, 200.6364, 113.1694],
        [256.1001, 364.4832, 256.2542,  ..., 113.0211, 330.2574,  95.6984],
        ...,
        [256.2983, 364.5936, 256.3239,  ..., 339.5677, 343.2135, 345.5534],
        [256.4605, 364.6476, 256.4937,  ..., 328.7279, 296.4442, 333.8659],
    

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[257.9930, 367.7464, 258.0962,  ..., 193.3853, 354.5153, 190.2360],
        [258.0681, 367.7245, 258.5377,  ..., 211.9621, 311.9109, 217.8659],
        [258.0229, 367.6808, 258.5092,  ..., 180.6612, 424.0119, 193.2173],
        ...,
        [256.5154, 365.8484, 256.8783,  ..., 312.8574, 354.1133, 318.9247],
        [257.9663, 367.5835, 258.4526,  ..., 211.2246, 313.4054, 214.7655],
        [257.6786, 367.2163, 258.1532,  ..., 235.2189, 315.4044, 255.4016]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[256.1784, 364.1492, 255.8538,  ..., 181.9767, 220.1137, 175.0306],
        [256.9465, 365.9683, 256.9954,  ..., 280.2482,  53.2158, 275.0340],
        [255.9026, 363.8012, 255.8218,  ..., 336.8423, 296.9795, 342.2082],
        ...,
        [256.0721, 364.5643, 256.2664,  ..., 205.9333, 260.5370, 203.1356],
        [255.3886, 363.1778, 255.2643,  ..., 337.8586, 318.7789, 345.7719],
    

output tensor([[253.8431, 361.6573, 254.0917,  ..., 113.4017, 336.5063,  98.6053],
        [254.6052, 363.5398, 254.5108,  ..., 275.6623, 428.8289, 284.4505],
        [253.7587, 361.9812, 254.0078,  ..., 313.9898, 354.5218, 318.4802],
        ...,
        [254.5167, 363.1125, 254.7604,  ..., 182.7830, 422.2980, 191.8918],
        [255.3993, 364.2450, 255.6827,  ..., 186.5022, 123.7199, 164.4298],
        [254.8905, 363.4186, 254.8070,  ..., 290.0740, 338.3185, 306.6487]],
       grad_fn=<AddmmBackward0>)
torch.Size([60, 12])
torch.Size([60, 3])
output tensor([[259.0783, 369.1670, 258.8251, 285.3248, 302.2837, 217.7442, 320.3538,
         229.3293, 405.2393, 277.8886, 420.4431, 294.8460],
        [259.4518, 369.2372, 259.2839, 285.6198, 179.3568, 299.8005, 175.5071,
         278.5616,  79.1647, 248.2542,  90.8517, 231.0184],
        [259.7265, 369.9475, 259.8698, 286.1978, 178.4492, 298.9526, 175.2103,
         277.5899, 124.3915, 188.3496, 114.9591, 170.9287],
        [258.7998, 368.34

output tensor([[252.3111, 359.0753, 251.7399,  ..., 259.8458,  79.2492, 238.6306],
        [251.4562, 358.8963, 251.2037,  ..., 168.4391,  95.6042, 152.3535],
        [252.0556, 358.6936, 251.7402,  ..., 211.2082,  90.2020, 190.1893],
        ...,
        [252.0465, 358.8683, 251.4799,  ..., 204.4158,  81.1724, 186.5362],
        [250.6745, 358.1444, 251.0025,  ..., 115.7782, 167.5992,  94.2171],
        [251.6755, 358.5094, 251.6068,  ..., 302.7981, 302.4438, 305.0320]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[260.1181, 370.6208, 260.8082,  ..., 179.3801, 196.8062, 183.1210],
        [260.2668, 370.7146, 260.7398,  ..., 182.3844, 217.4962, 189.3546],
        [260.3288, 370.5377, 260.7300,  ..., 316.1258, 307.4244, 318.9583],
        ...,
        [260.4944, 370.6809, 260.4825,  ..., 266.3484,  86.5387, 247.1081],
        [260.3152, 370.1348, 260.3764,  ..., 150.1042, 382.5829, 161.2137],
        [260.3336, 370.5398, 260.6038,  ..., 18

output tensor([[256.4391, 365.2875, 256.6691,  ..., 218.1075, 111.4775, 203.7354],
        [256.5693, 365.1152, 256.4768,  ..., 217.8072, 298.4630, 239.2191],
        [256.6890, 365.4464, 256.6796,  ..., 332.8607, 331.0612, 345.7467],
        ...,
        [256.1673, 365.1225, 256.0943,  ..., 190.8256,  80.8653, 182.1511],
        [256.3318, 365.0435, 256.2463,  ..., 176.6265, 291.2515, 165.5990],
        [256.9668, 365.9530, 256.7895,  ..., 154.4362, 151.7299, 145.7897]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[258.2821, 367.3605, 258.1172,  ..., 215.9151,  94.5261, 195.0787],
        [258.5338, 367.7856, 258.2497,  ..., 185.9164, 399.2446, 178.8911],
        [257.9248, 367.2523, 257.8168,  ..., 140.5376, 179.6689, 123.1124],
        ...,
        [258.6896, 368.5438, 258.3152,  ..., 335.7493, 400.1798, 353.0339],
        [258.0374, 367.3993, 257.9800,  ..., 143.5781, 328.7790, 132.5076],
        [257.4798, 367.1837, 257.3784,  ..., 13

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[253.6690, 361.2207, 253.6766,  ..., 321.5204, 325.2641, 338.5302],
        [254.0648, 361.7183, 253.5842,  ..., 187.6680, 346.8432, 182.7523],
        [253.5030, 360.8466, 253.5085,  ..., 301.4165, 304.9619, 306.0516],
        ...,
        [253.9658, 361.2128, 253.4803,  ..., 146.4103, 376.9837, 153.4908],
        [254.9008, 363.0446, 254.7509,  ..., 140.6206, 178.2000, 138.0368],
        [254.1270, 361.7816, 253.8583,  ..., 211.8438, 294.0876, 228.2342]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[260.2751, 370.4841, 259.7874,  ..., 334.5892, 407.3964, 350.8550],
        [260.0041, 369.7788, 259.8535,  ..., 230.9360, 308.5860, 252.7107],
        [260.6373, 371.0535, 260.6774,  ..., 107.9579, 234.7819, 106.7455],
        ...,
        [260.9009, 371.4120, 260.7836,  ..., 112.6892, 227.5299, 108.8308],
        [259.9341, 370.1339, 259.9064,  ..., 313.1949, 352.9323, 321.1968],
    

output tensor([[259.2650, 369.1787, 259.2657,  ..., 315.7159, 351.2934, 327.9315],
        [258.2354, 368.4707, 258.2509,  ..., 175.1476, 408.5114, 195.8833],
        [256.6858, 366.3257, 257.1751,  ..., 107.8214, 198.4933,  89.4517],
        ...,
        [258.2617, 367.6722, 258.0552,  ..., 149.0340, 161.8063, 145.6241],
        [256.7117, 365.7498, 256.5335,  ..., 103.6742, 269.9942,  84.6348],
        [256.7833, 365.7802, 256.7448,  ..., 104.5214, 273.2850,  90.5325]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[258.9356, 369.8383, 259.3653,  ..., 177.5381, 419.4595, 189.4357],
        [256.9554, 366.3703, 257.3382,  ..., 104.1461, 219.8081,  79.9468],
        [259.2111, 368.9781, 259.0954,  ..., 332.1483, 391.4828, 342.0685],
        ...,
        [259.7163, 369.6166, 260.2661,  ..., 239.7717, 323.7777, 257.6333],
        [260.1003, 370.5502, 260.3304,  ..., 230.4734, 354.9486, 234.3461],
        [257.7120, 367.3598, 258.0204,  ..., 18

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[256.1258, 364.7170, 255.7897,  ..., 342.9878, 349.2583, 345.1714],
        [256.1856, 364.9996, 256.8051,  ..., 113.0943, 332.6854,  98.1657],
        [256.4814, 365.1406, 256.0976,  ..., 118.9639, 343.5049, 129.5225],
        ...,
        [256.5793, 365.9087, 256.3492,  ..., 140.7879, 226.7447, 117.7903],
        [256.9120, 366.1175, 257.1047,  ..., 208.4973, 101.5010, 186.0454],
        [256.5501, 365.8546, 256.7110,  ..., 202.2887,  80.1947, 191.4306]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[256.6427, 365.2295, 256.3354,  ..., 111.9025, 338.3182, 124.7267],
        [256.7574, 365.9759, 256.6410,  ..., 115.0098, 287.8409,  96.2627],
        [256.9582, 365.4044, 256.6743,  ..., 333.1107, 308.3849, 343.4647],
        ...,
        [256.1953, 364.9269, 256.3751,  ..., 203.7460,  96.9507, 185.3617],
        [255.9914, 364.4816, 256.0004,  ..., 194.1614, 255.3268, 196.5985],
    

output tensor([[255.9914, 364.5224, 255.8944,  ..., 192.1359, 347.4983, 199.5757],
        [255.5853, 364.1970, 255.6759,  ..., 213.6825, 309.2064, 217.3114],
        [256.1006, 364.6053, 255.9934,  ..., 143.4283, 321.4460, 157.7056],
        ...,
        [255.8546, 364.4867, 256.0857,  ..., 230.9933, 310.2427, 249.5390],
        [255.9746, 364.4644, 255.9769,  ..., 342.1734, 354.8044, 344.3333],
        [256.5037, 365.3376, 256.2683,  ..., 172.5936, 420.5244, 157.2336]],
       grad_fn=<AddmmBackward0>)
torch.Size([60, 12])
torch.Size([60, 3])
output tensor([[258.5090, 367.6036, 258.3082, 284.2046, 257.0885, 203.6035, 277.6943,
         203.0996, 357.6233, 147.3530, 379.3155, 148.4352],
        [258.3357, 367.4080, 258.1854, 284.3110, 178.9844, 297.2825, 175.3578,
         276.4117, 116.9082, 194.3253, 104.6724, 175.8336],
        [257.7126, 366.7481, 257.7943, 283.6777, 256.5552, 203.3480, 277.2299,
         202.7787, 265.6124, 104.2124, 280.5789,  91.6122],
        [257.7780, 366.79

output tensor([[254.7487, 363.1357, 254.6830,  ..., 197.8541,  73.9213, 193.1298],
        [254.6018, 362.2694, 254.5094,  ..., 213.8233, 413.3119, 238.2250],
        [254.3084, 362.3350, 254.1388,  ..., 303.7794, 420.4681, 325.0331],
        ...,
        [255.0223, 362.9514, 254.7402,  ..., 102.9702, 264.9731,  82.0233],
        [254.9984, 363.2868, 254.7383,  ..., 329.0849, 417.7530, 352.4029],
        [255.4303, 364.0006, 255.6941,  ..., 120.9361, 293.8964, 104.1992]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[258.2289, 367.4945, 258.3282,  ..., 220.2156, 113.7397, 203.4163],
        [260.0583, 369.7439, 259.7787,  ..., 337.4467, 354.0010, 348.0887],
        [258.7680, 367.8471, 258.4801,  ..., 214.0212,  84.8277, 198.5435],
        ...,
        [259.0397, 368.2971, 259.4820,  ..., 112.3138, 322.2229,  89.3700],
        [259.1619, 368.7888, 259.1115,  ..., 110.5705, 227.7941,  92.1294],
        [258.6470, 367.9236, 258.3537,  ..., 21

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[259.7336, 369.9976, 259.5921,  ..., 332.5995, 420.4139, 353.9489],
        [258.9557, 368.7571, 259.0419,  ..., 204.1086, 307.8188, 201.0480],
        [259.4284, 369.1828, 259.3274,  ..., 169.8325, 417.8817, 157.9729],
        ...,
        [259.7720, 369.6220, 259.5876,  ..., 337.3405, 301.2009, 333.5296],
        [259.1335, 369.9661, 260.1560,  ..., 311.1766,  60.7357, 307.9198],
        [259.6141, 369.3900, 259.5681,  ..., 148.3454, 380.4276, 153.0834]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[254.3287, 362.3140, 254.3755,  ..., 254.3799, 459.3761, 251.3798],
        [255.4497, 363.8922, 255.3618,  ..., 332.2512, 327.5474, 343.4700],
        [254.9760, 362.9799, 254.6096,  ..., 153.9124, 324.7658, 171.9179],
        ...,
        [254.6293, 362.2799, 254.3744,  ..., 179.7730, 421.9321, 190.3336],
        [254.6272, 363.7897, 255.4060,  ..., 309.8047,  60.5333, 309.8628],
    

output tensor([[252.8891, 360.8654, 252.7973,  ..., 127.7783, 172.8051, 115.8475],
        [253.7279, 361.6118, 253.6999,  ..., 307.6717,  55.0968, 311.7201],
        [252.8583, 360.2519, 252.9677,  ..., 227.4723, 412.5479, 248.0602],
        ...,
        [252.9885, 360.1605, 252.6408,  ..., 204.5143,  85.4529, 185.6520],
        [253.6088, 361.1920, 253.4713,  ..., 109.8642, 276.7235, 119.6717],
        [252.7163, 360.2591, 252.6130,  ..., 231.3290, 354.4248, 250.5453]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[260.6939, 371.7699, 261.3194,  ..., 305.2287,  62.6658, 305.2383],
        [263.2462, 375.4206, 263.7684,  ..., 240.4226, 462.9019, 235.8045],
        [262.7413, 374.1273, 262.9514,  ..., 185.8587, 437.9519, 176.1906],
        ...,
        [262.0608, 373.0490, 262.3753,  ..., 218.2387, 303.9500, 236.1713],
        [262.2446, 373.8848, 262.9780,  ..., 240.1469, 453.1378, 250.7779],
        [261.4973, 372.7759, 262.2636,  ..., 24

output tensor([[255.3947, 364.4378, 255.7022,  ..., 184.9406,  85.4729, 171.5866],
        [256.2151, 365.6939, 256.4814,  ..., 168.2751, 414.6115, 157.1754],
        [255.4471, 364.6138, 255.7585,  ..., 211.1709, 350.0832, 213.2708],
        ...,
        [256.0321, 364.9441, 256.6542,  ..., 232.8692, 437.0427, 244.8562],
        [256.4302, 365.0010, 256.5979,  ..., 107.8026, 229.8638,  98.0850],
        [256.0953, 365.3758, 256.6368,  ..., 249.1680, 440.4436, 255.1927]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[256.3478, 365.3371, 256.7388,  ..., 309.9638, 351.7368, 319.6130],
        [255.6334, 363.8302, 255.4497,  ..., 273.7213,  51.0129, 273.1720],
        [257.6056, 366.4558, 257.3338,  ..., 135.1435, 231.4984, 124.0722],
        ...,
        [256.2543, 365.1042, 256.5910,  ..., 330.3905, 382.6744, 345.3454],
        [256.9293, 365.6280, 257.0787,  ..., 338.5873, 354.2791, 343.8289],
        [256.8151, 366.1196, 257.1372,  ..., 27

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[257.2283, 366.0720, 257.0290,  ..., 203.9657,  92.8636, 187.0798],
        [256.9442, 365.9648, 256.4945,  ..., 204.0570, 349.9350, 206.5555],
        [256.9125, 365.5647, 256.7852,  ..., 218.4132, 302.6742, 240.7672],
        ...,
        [256.9706, 366.3040, 256.8788,  ..., 231.4046, 346.0829, 238.4790],
        [256.7146, 365.3723, 256.7492,  ..., 226.0286, 309.0543, 247.7322],
        [257.1081, 365.8412, 257.1799,  ..., 143.3606, 390.4030, 134.7910]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[257.6262, 366.8651, 257.6050,  ..., 110.9677, 239.7198, 104.8728],
        [257.4706, 366.7411, 257.4806,  ..., 109.6621, 229.8444,  95.9535],
        [257.4075, 366.6795, 257.5128,  ..., 205.2798,  75.7971, 196.3103],
        ...,
        [257.4806, 366.7361, 257.5082,  ..., 109.4310, 234.0546, 105.2577],
        [257.6399, 366.5898, 257.7006,  ..., 217.2259, 109.8018, 196.7433],
    

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[254.6539, 362.4521, 254.4835,  ..., 148.3204, 379.5941, 142.2649],
        [255.6113, 363.6776, 255.3754,  ..., 179.4699, 213.8983, 176.2420],
        [254.1048, 362.4684, 254.7428,  ..., 250.5313, 355.2422, 269.9185],
        ...,
        [255.3032, 364.1782, 255.1301,  ..., 267.0277, 429.1824, 279.3506],
        [255.1766, 363.1709, 254.9011,  ..., 158.9918, 324.2562, 180.6014],
        [255.4801, 363.9829, 255.5597,  ..., 338.3919, 372.0104, 346.3828]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[259.4001, 369.9050, 259.3112,  ..., 274.3510, 363.9952, 289.0641],
        [258.9002, 368.3349, 259.5956,  ..., 230.7732, 429.6778, 246.4141],
        [257.9540, 367.0855, 258.2515,  ..., 145.1152, 391.8675, 140.4655],
        ...,
        [258.2042, 367.8389, 258.2153,  ..., 106.0122, 222.2390,  86.5995],
        [259.0176, 369.1874, 259.0221,  ..., 208.7776, 349.0478, 209.6582],
    

output tensor([[263.2952, 374.4591, 263.6883,  ..., 212.9450, 422.4074, 233.2628],
        [262.1568, 373.2278, 262.2768,  ..., 110.0920, 245.2472, 106.2709],
        [264.0439, 375.5528, 264.4709,  ..., 292.9766, 464.7359, 297.1871],
        ...,
        [261.6667, 372.6378, 262.2057,  ..., 179.2033, 204.9115, 183.2659],
        [260.9546, 371.4693, 261.2527,  ..., 208.0291, 100.4576, 187.9197],
        [263.2141, 374.2675, 263.3397,  ..., 194.1741, 345.8894, 207.3793]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[251.9243, 358.6213, 251.6625,  ..., 177.5568, 416.4183, 188.2984],
        [250.0320, 356.6440, 250.3780,  ..., 212.3976, 108.6914, 198.3376],
        [250.5673, 357.7919, 250.5624,  ..., 159.6992, 106.2848, 141.9781],
        ...,
        [251.8356, 359.1675, 251.8850,  ..., 198.5229, 255.8888, 200.5310],
        [250.5401, 357.6896, 250.2091,  ..., 125.2506, 149.6276, 107.8674],
        [252.1535, 360.1593, 251.9208,  ..., 18

output tensor([[256.6110, 365.2725, 256.5147,  ..., 291.8915,  74.1544, 265.9746],
        [256.1478, 365.7221, 256.1465,  ..., 259.3162, 378.6515, 265.4053],
        [256.3818, 364.8542, 256.6211,  ..., 330.4287, 296.7935, 334.0132],
        ...,
        [256.2359, 365.0673, 256.3467,  ..., 198.7518, 390.3866, 207.2945],
        [256.9532, 366.3720, 257.0359,  ..., 291.1537, 333.3196, 303.9238],
        [255.3977, 364.2146, 255.9427,  ..., 335.5995, 374.9215, 338.5113]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[256.4923, 365.6882, 256.9195,  ..., 297.5052, 433.4160, 321.0825],
        [257.3267, 367.1805, 257.5155,  ..., 146.6961, 394.0581, 161.9690],
        [257.7422, 366.5547, 258.0948,  ..., 335.8165, 305.6740, 345.1834],
        ...,
        [257.6916, 367.1154, 258.0895,  ..., 278.0533, 464.6223, 285.4920],
        [257.2031, 366.3152, 257.3434,  ..., 106.9729, 227.7298,  94.4738],
        [257.8951, 367.4442, 257.9378,  ..., 32

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[255.5809, 363.5245, 255.1765,  ..., 213.9186,  81.6762, 199.2673],
        [255.3137, 363.4010, 254.9604,  ..., 339.4950, 349.0045, 342.8599],
        [255.1633, 363.3430, 254.7837,  ..., 183.9417, 297.1427, 176.1513],
        ...,
        [255.9301, 364.5096, 255.6431,  ..., 145.9520, 257.2776, 159.2596],
        [255.1898, 363.3557, 255.1784,  ..., 226.0119, 308.1145, 246.3480],
        [254.4633, 362.3975, 254.0713,  ..., 323.5633, 413.6283, 340.4494]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[257.1782, 366.9234, 257.7350,  ..., 279.9427, 395.2497, 295.1109],
        [257.9018, 367.3548, 257.9435,  ..., 234.8460, 435.9444, 252.1854],
        [257.8786, 367.3069, 258.0901,  ..., 208.2306,  74.7504, 200.7576],
        ...,
        [258.3156, 367.4724, 258.0258,  ..., 150.5122, 320.8744, 167.9816],
        [258.6154, 367.9265, 258.5955,  ..., 272.3279, 462.1042, 271.2096],
    

output tensor([[255.5778, 363.9007, 254.9789,  ..., 190.7887, 398.1793, 186.9815],
        [255.9285, 364.5077, 256.0056,  ..., 287.2538, 323.8339, 305.8686],
        [254.4545, 362.5498, 254.5727,  ..., 287.4760, 430.4738, 305.7429],
        ...,
        [255.5777, 363.8983, 255.4339,  ..., 110.0241, 279.4159,  93.7612],
        [254.7914, 363.2982, 254.6713,  ..., 227.6759, 433.1568, 244.0770],
        [254.4518, 362.5411, 254.5634,  ..., 287.5917, 431.4152, 305.7121]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[259.1039, 369.3399, 259.3604,  ..., 294.5483,  56.0598, 296.9284],
        [258.4020, 368.0843, 258.3241,  ..., 108.0718, 214.3552,  87.1151],
        [257.8071, 367.2274, 257.9822,  ..., 292.2069, 455.6403, 304.6466],
        ...,
        [258.4040, 367.9144, 258.4247,  ..., 140.1778, 181.5787, 128.5938],
        [258.0791, 367.5960, 258.7457,  ..., 291.2480, 315.7288, 308.3371],
        [257.9456, 366.9824, 257.9471,  ..., 11

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[253.8981, 361.1461, 253.6596,  ..., 192.8612, 413.3678, 214.8932],
        [253.5829, 361.3815, 253.0831,  ..., 196.2662, 371.4530, 219.9948],
        [254.0368, 361.9194, 253.5448,  ..., 169.9897, 416.7676, 161.6164],
        ...,
        [254.1570, 361.6706, 253.9483,  ..., 268.0038,  79.1274, 251.1269],
        [253.5702, 361.7733, 253.8584,  ..., 177.2185, 211.7789, 185.3488],
        [254.5437, 362.8870, 254.5151,  ..., 104.8057, 221.2891,  90.0958]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[260.0629, 370.2429, 260.3448,  ..., 114.2922, 340.4463, 124.2668],
        [260.4658, 370.7390, 259.9580,  ..., 334.3818, 424.9907, 351.3101],
        [260.2219, 370.4804, 260.3277,  ..., 140.8336, 177.7646, 120.3186],
        ...,
        [259.7439, 370.0859, 260.3111,  ..., 109.4334, 309.7955,  86.2503],
        [260.0022, 370.5847, 259.8811,  ..., 131.7882, 164.7538, 109.8400],
    

output tensor([[259.9539, 369.7644, 260.1299,  ..., 240.0972, 314.4593, 257.4697],
        [258.8166, 368.9668, 258.9164,  ..., 169.0707, 100.9945, 152.8491],
        [260.5290, 370.3882, 260.3571,  ..., 331.9902, 298.4104, 335.9330],
        ...,
        [260.6375, 370.6363, 260.5764,  ..., 344.1888, 342.9901, 346.3588],
        [259.9011, 369.6195, 260.0553,  ..., 258.1525, 460.9694, 254.3671],
        [260.8714, 370.9517, 260.7665,  ..., 345.3267, 341.3616, 345.4592]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[254.4075, 362.1576, 254.2880,  ..., 288.7456, 433.7871, 309.7314],
        [254.7914, 362.5388, 254.3614,  ..., 125.2568, 352.0157, 142.1279],
        [254.9942, 363.5419, 255.6866,  ..., 242.7804, 438.1740, 246.8319],
        ...,
        [254.8484, 362.8334, 254.4020,  ..., 165.2785, 276.3571, 151.7396],
        [254.7392, 362.8058, 254.6727,  ..., 236.9919, 358.2044, 255.7106],
        [255.4191, 363.1987, 254.8407,  ..., 17

output tensor([[257.5166, 366.3351, 257.5293,  ..., 145.6011, 169.2115, 138.0603],
        [259.2556, 370.2039, 258.8763,  ..., 178.9136, 421.2467, 192.7161],
        [258.6503, 367.8051, 258.5013,  ..., 285.1898, 463.8837, 295.4250],
        ...,
        [257.7887, 367.2044, 258.2151,  ..., 118.6186, 298.1844,  97.4453],
        [258.8941, 369.5533, 258.5426,  ..., 163.5830, 411.2452, 175.9626],
        [257.7973, 367.0545, 257.9476,  ..., 135.6456, 241.8277, 141.0029]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[256.3071, 365.1661, 256.1248,  ..., 109.7698, 203.6632,  86.5462],
        [256.8027, 365.2563, 255.6786,  ..., 267.5191,  49.0317, 264.8571],
        [256.4980, 365.2644, 255.7219,  ..., 332.5470, 415.0475, 354.1928],
        ...,
        [256.2913, 365.2576, 255.7654,  ..., 333.6118, 409.0391, 355.4281],
        [255.7397, 363.9047, 255.3600,  ..., 276.9239, 453.7182, 279.3893],
        [256.1446, 364.4823, 255.9167,  ..., 25

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[257.8045, 367.8591, 258.6127,  ..., 181.5487, 161.8234, 176.4692],
        [257.2448, 365.2228, 256.0615,  ..., 263.7412,  58.9207, 258.8845],
        [257.6775, 367.9463, 258.6491,  ..., 178.3714, 179.6575, 179.1049],
        ...,
        [257.7932, 367.3296, 257.7494,  ..., 329.4226, 421.6354, 353.2736],
        [257.6689, 367.3924, 257.9604,  ..., 142.6259, 249.9211, 155.6249],
        [257.3847, 365.8893, 257.3711,  ..., 212.0565, 103.9594, 197.4903]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[256.7846, 365.4401, 256.3152,  ..., 231.4240, 363.5029, 249.6054],
        [256.1688, 365.0179, 256.2588,  ..., 313.1460, 351.4152, 319.3502],
        [256.5925, 365.1312, 256.5537,  ..., 343.4413, 353.3237, 344.1769],
        ...,
        [256.5941, 364.8809, 256.1849,  ..., 202.9341,  96.9529, 182.9437],
        [256.2209, 364.5587, 256.0656,  ..., 330.3513, 291.9108, 327.6973],
    

output tensor([[248.0786, 352.6725, 246.9709,  ..., 151.3955, 378.3404, 158.5346],
        [250.3787, 354.1695, 247.7110,  ..., 275.3895,  70.7997, 259.3665],
        [248.7997, 353.8651, 247.9281,  ..., 190.7496, 229.2850, 188.1626],
        ...,
        [250.0996, 354.8123, 248.5883,  ..., 201.4273, 123.8721, 176.0533],
        [250.1073, 354.3244, 248.2760,  ..., 212.0373, 107.3367, 188.0693],
        [248.6761, 352.8215, 247.5729,  ..., 116.4846, 269.5492, 114.1652]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[261.1069, 372.3965, 257.1525,  ..., 177.0248, 159.9041, 180.1487],
        [261.3216, 371.9122, 256.5879,  ..., 206.8746, 301.8021, 218.9451],
        [265.2415, 376.3806, 258.5277,  ..., 325.9515, 412.3725, 350.5399],
        ...,
        [260.0781, 370.3349, 255.3934,  ..., 132.7288, 236.1991, 140.8723],
        [261.3923, 372.1875, 256.1272,  ..., 163.6315, 413.5567, 159.6878],
        [261.3570, 371.6849, 255.6298,  ..., 26

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[261.1314, 370.5820, 260.4559,  ..., 144.6148, 332.8876, 145.5770],
        [259.6899, 370.4303, 260.5570,  ..., 151.9699, 166.3934, 142.1175],
        [261.2594, 371.6624, 261.1976,  ..., 185.5471, 135.2717, 163.6314],
        ...,
        [260.5220, 370.8320, 260.5408,  ..., 140.0851, 242.1920, 134.9718],
        [260.3137, 370.3795, 260.5765,  ..., 302.8888, 312.9885, 306.3908],
        [261.5888, 372.5952, 262.5747,  ..., 266.2038, 438.3451, 275.9596]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[253.3583, 360.6404, 253.0394,  ..., 203.4002, 257.6661, 203.0807],
        [252.8896, 359.6715, 252.8511,  ..., 279.0417, 395.5754, 296.2224],
        [254.1193, 361.1538, 252.7977,  ..., 240.9565, 334.6316, 248.4389],
        ...,
        [253.6971, 360.8110, 253.3409,  ..., 317.4357, 346.3785, 333.5081],
        [253.8287, 360.5247, 253.1498,  ..., 210.2510, 301.9066, 205.4433],
    

output tensor([[256.6439, 367.1974, 257.8130,  ..., 310.7049,  61.3534, 302.6210],
        [257.8035, 368.3954, 258.2188,  ..., 313.0842,  59.9160, 313.2999],
        [256.6700, 365.0081, 256.3702,  ..., 106.4784, 222.1933,  88.7605],
        ...,
        [256.2848, 365.8298, 257.6691,  ..., 229.7032, 435.9985, 239.0640],
        [255.6119, 364.3787, 256.4746,  ..., 201.0946, 417.6513, 218.9525],
        [257.1728, 367.3004, 257.6664,  ..., 306.1046,  56.5704, 305.7131]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[256.3216, 364.8857, 256.3616,  ..., 104.1280, 286.4905, 101.2900],
        [257.5031, 366.6586, 257.3465,  ..., 335.8701, 303.1558, 334.1326],
        [257.5954, 366.8045, 257.7834,  ..., 342.6963, 342.4509, 348.5931],
        ...,
        [257.2746, 366.4165, 257.3188,  ..., 112.2339, 220.9274, 110.4757],
        [256.0204, 365.6530, 256.6208,  ..., 144.6213, 388.5227, 130.0434],
        [256.7490, 366.5265, 257.7217,  ..., 18

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[256.1028, 364.6145, 255.9846,  ..., 333.7890, 340.2455, 346.5041],
        [255.6357, 364.3461, 255.9805,  ..., 181.3640, 423.3578, 188.8855],
        [255.7542, 364.9659, 256.2283,  ..., 163.3460, 135.1867, 147.9868],
        ...,
        [255.6915, 364.1820, 255.6918,  ..., 249.6253, 377.9933, 266.3687],
        [256.6400, 364.7955, 255.5757,  ..., 264.7963,  54.8233, 261.1613],
        [256.2679, 365.0994, 256.2883,  ..., 111.3167, 223.7164, 106.3281]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[259.0201, 368.1283, 258.6657,  ..., 125.9762, 303.8711, 106.5191],
        [258.7567, 367.6128, 258.3331,  ..., 261.5360,  85.6315, 241.4496],
        [258.8458, 368.4405, 258.9324,  ..., 324.3390, 332.2660, 339.0988],
        ...,
        [259.2641, 369.1801, 259.2167,  ..., 335.0928, 398.6579, 353.0846],
        [258.3822, 367.1157, 258.1924,  ..., 296.3312,  80.0459, 276.7136],
    

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[256.7972, 365.7772, 256.2751,  ..., 238.6275, 455.1660, 232.6935],
        [256.5612, 365.7411, 256.9078,  ..., 201.8134, 102.4290, 183.2137],
        [256.5198, 365.8016, 256.7109,  ..., 138.8730, 245.2312, 148.7437],
        ...,
        [256.2961, 365.3892, 256.3223,  ..., 152.0002, 267.3976, 166.7874],
        [255.9356, 364.9165, 256.2587,  ..., 177.2795, 219.6986, 184.4509],
        [255.4491, 363.4464, 255.1955,  ..., 233.9699, 438.3206, 249.7601]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[258.3738, 369.4525, 259.5834,  ..., 127.7228, 153.9142, 109.4566],
        [259.6353, 370.0839, 259.9881,  ..., 151.3863, 377.1793, 140.1227],
        [257.8210, 367.4148, 258.2208,  ..., 328.6167, 299.1933, 336.1348],
        ...,
        [260.3435, 370.2983, 260.4826,  ..., 108.1461, 264.9195, 105.6782],
        [259.0534, 368.7053, 259.6432,  ..., 294.0647, 444.3636, 306.5814],
    

output tensor([[254.6639, 363.0162, 255.1662,  ..., 307.7627,  79.1597, 286.7378],
        [255.2961, 363.2914, 254.9509,  ..., 211.7291, 347.6507, 211.4190],
        [254.9315, 362.8494, 255.0023,  ..., 302.5923, 305.2560, 304.8345],
        ...,
        [255.7875, 363.3040, 255.4061,  ..., 284.3056, 458.8752, 286.0613],
        [255.2583, 364.1643, 255.4753,  ..., 180.8313, 425.8214, 175.7197],
        [254.9891, 363.3702, 255.4035,  ..., 310.9931, 353.7108, 321.1726]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[259.0506, 368.9635, 259.6184,  ..., 213.1726, 310.8981, 218.4229],
        [258.9526, 368.9189, 259.4573,  ..., 279.1084, 420.6795, 299.1891],
        [259.2861, 368.9567, 259.4341,  ..., 142.7086, 322.4897, 160.1661],
        ...,
        [259.2725, 369.3021, 259.4757,  ..., 278.0483, 427.4976, 294.6492],
        [259.2591, 368.5271, 259.0705,  ..., 168.3569, 282.8282, 155.4174],
        [259.5822, 368.9149, 259.6262,  ..., 26

output tensor([[257.1900, 366.2501, 257.3750,  ..., 290.1382, 327.5217, 308.8946],
        [257.0756, 366.4229, 257.5079,  ..., 311.3200,  68.8246, 297.8364],
        [257.1416, 366.1113, 257.0846,  ..., 325.6204, 415.3788, 340.2056],
        ...,
        [257.3893, 365.8344, 257.2850,  ..., 291.6462, 443.5652, 309.1646],
        [256.6430, 365.9973, 257.3819,  ..., 241.7899, 319.2498, 260.4140],
        [257.2666, 365.9467, 257.3541,  ..., 290.8035, 432.0573, 309.2923]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[256.8681, 365.5638, 256.5191,  ..., 108.8813, 207.0420,  85.1143],
        [256.9614, 365.7817, 256.6670,  ..., 336.2054, 298.6956, 335.6303],
        [256.8697, 365.5430, 256.4261,  ..., 336.2004, 298.5723, 331.9008],
        ...,
        [257.0569, 366.0598, 256.5929,  ..., 196.8427, 397.0625, 196.0422],
        [256.8960, 365.2580, 256.8127,  ..., 105.7343, 283.7197,  99.6771],
        [256.6157, 365.2005, 256.1739,  ..., 32

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[259.1320, 369.1812, 259.3537,  ..., 204.3395, 266.1608, 206.8520],
        [258.3703, 367.5884, 258.6219,  ..., 106.4856, 293.8138,  83.1306],
        [258.6516, 368.2577, 258.5243,  ..., 337.0974, 303.9916, 333.4396],
        ...,
        [259.2157, 368.9064, 259.3983,  ..., 181.8319, 217.2974, 178.8522],
        [258.4430, 368.7108, 259.2006,  ..., 218.6213, 299.8381, 238.1806],
        [258.1876, 367.8926, 258.6130,  ..., 335.1317, 381.9558, 345.6912]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[255.2886, 363.6116, 255.5034,  ..., 311.1787, 357.0718, 325.4313],
        [255.9829, 364.7164, 256.1945,  ..., 201.5887, 260.4965, 202.6623],
        [255.2028, 364.0278, 255.3989,  ..., 197.1377, 398.9907, 195.2205],
        ...,
        [255.8565, 364.1131, 255.8864,  ..., 163.1927, 280.0760, 159.1253],
        [255.2428, 363.4010, 255.5795,  ..., 287.5697, 308.2498, 296.4303],
    

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[256.9827, 366.5610, 257.3040,  ..., 309.4312,  59.4531, 311.7715],
        [257.1746, 366.7262, 257.4830,  ..., 213.8954, 294.9284, 232.4429],
        [256.9603, 365.7606, 257.0518,  ..., 194.3651, 417.9355, 213.8588],
        ...,
        [257.3000, 366.4146, 256.8675,  ..., 328.8687, 426.7500, 347.4329],
        [256.8883, 365.6621, 256.3356,  ..., 327.7883, 294.9059, 331.4544],
        [257.1576, 366.9615, 258.1690,  ..., 237.9214, 442.6714, 238.7539]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[255.9092, 365.0175, 256.7214,  ..., 308.9622,  74.0672, 292.1613],
        [257.2567, 366.3648, 257.1157,  ..., 334.2015, 420.7129, 353.8253],
        [256.9626, 365.2801, 256.2860,  ..., 264.9215,  60.6200, 252.8251],
        ...,
        [256.6506, 365.2173, 256.8395,  ..., 107.0739, 292.1072,  83.0428],
        [256.9101, 365.4420, 256.6872,  ..., 146.7973, 390.9949, 149.8205],
    

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[257.8477, 368.0013, 258.3155,  ..., 129.5364, 190.0060, 121.6654],
        [258.7570, 368.5469, 258.8542,  ..., 321.4792, 343.2832, 340.6029],
        [258.3434, 368.5315, 258.6678,  ..., 182.2938, 432.4871, 181.9565],
        ...,
        [258.8823, 368.3979, 258.8271,  ..., 256.4979, 464.5757, 254.6242],
        [258.8389, 368.1049, 258.8124,  ..., 269.1652, 466.6668, 270.5173],
        [258.3672, 367.9339, 258.6215,  ..., 209.4728, 312.4443, 211.1434]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[254.7016, 363.4418, 255.1662,  ..., 227.6061, 301.4760, 245.6226],
        [254.7748, 363.0857, 254.3974,  ..., 327.2155, 420.8376, 343.8900],
        [255.6993, 364.1584, 255.2301,  ..., 107.4474, 213.5265,  83.8976],
        ...,
        [255.7005, 364.4622, 255.6987,  ..., 335.4472, 297.8654, 339.5461],
        [255.3073, 364.3582, 255.0397,  ..., 129.0920, 158.8747, 106.1586],
    

output tensor([[254.2475, 362.4811, 254.5249,  ..., 105.9848, 304.9820,  85.3491],
        [255.4374, 363.3936, 255.0418,  ..., 141.7359, 324.2687, 139.4979],
        [254.7792, 362.7401, 255.0085,  ..., 228.6694, 415.1071, 250.1240],
        ...,
        [255.0446, 363.3951, 255.0974,  ..., 181.5671, 157.9758, 172.5376],
        [254.7441, 362.9877, 254.6764,  ..., 184.4393,  82.0297, 174.6454],
        [254.4051, 363.6532, 255.1550,  ..., 161.0686, 405.5425, 147.5901]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[259.1056, 369.1298, 259.2469,  ..., 202.7853, 264.5129, 204.8716],
        [259.1137, 368.5924, 259.2671,  ..., 341.6466, 344.8138, 347.2612],
        [258.3746, 368.1782, 258.9403,  ..., 225.4683, 307.4236, 246.4698],
        ...,
        [259.5734, 369.1380, 260.0323,  ..., 230.8474, 431.1384, 248.7275],
        [259.2825, 368.9152, 259.2031,  ..., 109.7568, 241.6863, 105.8946],
        [259.5217, 370.0311, 260.5129,  ..., 25

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[258.3487, 368.0525, 258.7555,  ..., 290.5294, 309.4055, 305.0757],
        [257.7865, 367.5445, 258.2435,  ..., 277.6775, 400.4610, 301.1531],
        [258.7856, 368.7054, 258.9816,  ..., 291.0192, 324.5838, 312.5122],
        ...,
        [257.4271, 367.3757, 257.6982,  ..., 172.3573,  93.9793, 162.0245],
        [258.0847, 367.8361, 257.5284,  ..., 195.5829, 397.2004, 196.8449],
        [258.6938, 368.7933, 258.5852,  ..., 240.9271, 452.3887, 239.1166]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[255.3960, 363.9319, 255.4684,  ..., 202.3046, 261.8785, 202.3593],
        [255.5644, 364.0352, 255.6802,  ..., 138.4993, 184.6045, 128.2425],
        [255.4259, 363.7647, 255.4504,  ..., 177.7074, 211.9153, 181.6844],
        ...,
        [255.3456, 363.5691, 255.5529,  ..., 109.8866, 330.8097,  93.1711],
        [255.7521, 364.7303, 256.0031,  ..., 126.8735, 187.6934, 115.8170],
    

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[255.2255, 363.3443, 255.0719,  ..., 263.2774,  82.4118, 241.5634],
        [255.4966, 364.8748, 256.0269,  ..., 111.2367, 190.5171,  88.5946],
        [256.1732, 365.2419, 256.5747,  ..., 299.3992, 378.6446, 308.9393],
        ...,
        [256.6536, 365.6385, 256.9141,  ..., 328.4763, 308.1892, 336.5752],
        [256.9818, 366.6701, 256.6656,  ..., 250.7376, 390.0615, 261.0833],
        [256.9627, 366.4832, 257.1062,  ..., 155.2338, 387.4634, 143.3194]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[257.5205, 367.1994, 257.6900,  ..., 164.4729, 275.5118, 179.8689],
        [257.6482, 367.4254, 258.0968,  ..., 213.3151, 294.2517, 229.4098],
        [258.3925, 367.9368, 258.3206,  ..., 141.4471, 322.2718, 126.5120],
        ...,
        [258.9995, 369.3473, 259.9451,  ..., 334.8596, 334.3858, 346.7832],
        [256.8231, 366.5212, 256.9283,  ..., 137.8090, 159.9648, 116.7612],
    

output tensor([[256.1604, 364.9650, 256.5275,  ..., 241.3833, 312.3102, 259.2894],
        [256.5952, 366.0470, 256.7501,  ..., 310.9257,  63.8171, 304.0003],
        [256.9549, 366.1872, 256.8322,  ..., 166.6395, 277.2182, 181.2454],
        ...,
        [256.7338, 366.1289, 256.8118,  ..., 311.2433,  62.9085, 306.6748],
        [255.3044, 364.1084, 255.5980,  ..., 283.4550, 387.3460, 299.6137],
        [255.8571, 364.1539, 256.1287,  ..., 214.4010, 418.4031, 235.0396]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[256.7744, 366.3122, 257.0713,  ..., 109.6416, 190.6181,  88.7691],
        [256.7654, 365.5149, 256.5165,  ..., 330.2137, 416.7837, 351.1630],
        [257.8012, 366.3862, 256.9204,  ..., 263.5685,  60.8009, 251.2178],
        ...,
        [256.9454, 365.8504, 257.0718,  ..., 333.2856, 338.7767, 346.0605],
        [257.5955, 366.6412, 257.5867,  ..., 181.4373, 158.8407, 172.1493],
        [257.1774, 366.1818, 257.0887,  ..., 10

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[257.1360, 365.6317, 256.6212,  ..., 231.6065, 337.3359, 238.6565],
        [257.0865, 365.9488, 256.9639,  ..., 156.6771, 268.2198, 144.5241],
        [257.0337, 366.0655, 257.0814,  ..., 158.0394, 273.6595, 156.0659],
        ...,
        [256.6118, 365.0604, 256.4151,  ..., 170.3441, 321.3578, 188.8315],
        [256.4284, 365.8417, 256.3261,  ..., 133.5412, 157.0879, 111.8789],
        [256.3518, 364.7062, 255.8755,  ..., 111.8815, 336.0407, 125.8597]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[253.2534, 360.0029, 252.8783,  ..., 312.1855, 427.0092, 328.7870],
        [254.7730, 363.3534, 255.3531,  ..., 218.0616, 107.1379, 199.0712],
        [251.8648, 359.8753, 251.7846,  ..., 150.8517, 391.3824, 160.3248],
        ...,
        [253.6295, 361.7092, 253.9823,  ..., 143.1098, 174.8376, 122.3292],
        [252.9174, 360.7493, 253.3155,  ..., 119.4793, 217.5323, 111.7165],
    

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[257.9044, 366.9742, 257.4128,  ..., 326.7534, 426.7650, 345.7299],
        [257.3088, 366.3413, 257.6312,  ..., 333.2850, 301.5085, 339.6501],
        [257.1583, 366.3593, 256.8163,  ..., 135.7777, 173.3100, 112.9268],
        ...,
        [257.4633, 367.6266, 257.4174,  ..., 309.0928,  60.8391, 313.3719],
        [256.2235, 365.5656, 257.0686,  ..., 275.2776, 418.1900, 294.6452],
        [256.8943, 365.8987, 257.5256,  ..., 226.3437, 426.7997, 245.6343]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[257.1437, 366.0056, 257.1267,  ..., 166.4532, 284.9041, 166.9783],
        [258.2440, 367.5521, 258.3918,  ..., 187.2027, 124.5322, 169.2959],
        [258.5110, 367.9262, 258.5051,  ..., 264.0541,  87.9045, 245.4342],
        ...,
        [256.7028, 365.7286, 257.0224,  ..., 108.2857, 306.3328,  86.4232],
        [257.1649, 366.2252, 257.3147,  ..., 182.8762, 423.2693, 195.2964],
    

output tensor([[256.5658, 365.0689, 256.3509,  ..., 260.3477,  86.3350, 240.2390],
        [256.2374, 365.3163, 256.2738,  ..., 303.3353,  58.5855, 308.9174],
        [255.8182, 364.5081, 255.8452,  ..., 334.7890, 349.3818, 341.9271],
        ...,
        [256.1271, 364.9835, 256.3974,  ..., 104.4950, 237.3385, 101.9360],
        [256.0424, 364.2510, 256.1560,  ..., 217.5570, 420.9101, 236.8038],
        [255.7217, 364.9663, 255.3491,  ..., 247.1117, 388.8461, 258.3318]],
       grad_fn=<AddmmBackward0>)
torch.Size([60, 12])
torch.Size([60, 3])
output tensor([[258.6409, 368.1431, 258.7242, 284.3956, 321.6655, 235.4341, 333.5971,
         251.2164, 371.4175, 346.0173, 351.5601, 347.9738],
        [258.0959, 366.7887, 257.4610, 283.7414, 211.7467, 217.1380, 229.0614,
         204.9125, 327.1252, 200.7336, 347.1797, 199.3145],
        [258.4732, 367.2079, 257.9830, 284.3743, 179.5629, 298.3015, 176.0745,
         277.4743,  97.0591, 219.5523,  87.5698, 199.7587],
        [258.2589, 367.28

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[256.1870, 365.1332, 256.1479,  ..., 220.2246, 367.0863, 241.7401],
        [257.0221, 365.6819, 256.8143,  ..., 182.6845, 130.4901, 164.2370],
        [256.4167, 365.0839, 256.1388,  ..., 198.1138, 375.1300, 221.8746],
        ...,
        [257.6058, 366.7888, 257.3387,  ..., 334.3064, 408.8252, 353.3755],
        [256.6006, 366.0598, 256.8463,  ..., 278.4911, 421.7421, 295.9656],
        [256.4167, 364.2844, 256.2102,  ..., 105.3802, 283.5997, 103.6160]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[256.9630, 365.3170, 256.3818,  ..., 192.5345, 352.1422, 189.2904],
        [257.0126, 365.5869, 256.9744,  ..., 287.4696, 445.0829, 306.7257],
        [256.6208, 365.3160, 256.7019,  ..., 105.4025, 233.8825,  94.5909],
        ...,
        [256.7138, 365.5656, 256.7112,  ..., 136.8457, 234.8136, 122.1174],
        [256.7883, 365.4411, 256.4559,  ..., 136.1941, 213.2608, 110.1376],
    

output tensor([[255.8933, 365.3362, 256.1675,  ..., 197.5254, 394.3009, 197.9727],
        [252.5260, 359.9106, 252.5137,  ..., 328.8073, 286.7531, 331.3197],
        [255.2199, 363.5777, 255.1578,  ..., 214.4137, 112.0637, 199.8628],
        ...,
        [253.8191, 361.2866, 253.8322,  ..., 288.5536, 313.4552, 306.5701],
        [256.5562, 364.7244, 255.9873,  ..., 189.9311, 343.6218, 196.1929],
        [253.1806, 361.4681, 253.9238,  ..., 337.6359, 337.7633, 343.9690]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[259.1476, 370.1023, 259.6445,  ..., 251.4249, 359.4659, 268.8082],
        [259.8835, 370.8152, 260.3336,  ..., 180.7285, 218.0904, 184.7917],
        [260.4030, 371.2899, 260.5719,  ..., 111.4056, 218.4065,  85.8331],
        ...,
        [259.8292, 371.2341, 261.5315,  ..., 250.2468, 439.8544, 255.2730],
        [259.2755, 369.2550, 258.9406,  ..., 333.2692, 422.7104, 350.7714],
        [261.2469, 371.8559, 260.8236,  ..., 13

output tensor([[255.7547, 364.3575, 255.7346,  ..., 338.0553, 343.0373, 344.9004],
        [256.0251, 364.5546, 255.7935,  ..., 331.1649, 395.0373, 351.3733],
        [255.8220, 364.0913, 255.6387,  ..., 191.7877, 301.4048, 185.2225],
        ...,
        [255.5766, 364.1858, 255.3106,  ..., 229.1321, 364.0110, 248.9375],
        [255.5784, 364.5399, 255.6252,  ..., 142.9263, 383.5182, 125.0541],
        [256.2547, 364.9445, 256.1321,  ..., 106.6155, 214.1872,  85.1029]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[258.0935, 368.3481, 258.5174,  ..., 177.2322, 425.1716, 164.4196],
        [258.4310, 368.6686, 258.7491,  ..., 239.2606, 452.4440, 240.4737],
        [259.1494, 368.5424, 258.9709,  ..., 185.1558, 139.3225, 163.5775],
        ...,
        [258.5984, 367.7228, 259.0497,  ..., 233.4187, 429.8453, 250.2857],
        [258.2236, 367.2503, 257.9882,  ..., 148.2441, 369.2537, 166.5296],
        [258.1364, 368.0121, 259.3773,  ..., 24

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[256.9808, 365.2734, 256.6927,  ..., 103.7940, 279.7759, 106.2523],
        [256.6754, 365.6039, 256.8612,  ..., 145.9900, 319.5485, 165.0479],
        [256.8607, 366.0739, 257.0958,  ..., 276.2036, 425.3760, 290.7286],
        ...,
        [256.5553, 365.2426, 256.4332,  ..., 105.1126, 299.7747,  84.7732],
        [257.1092, 365.7546, 256.5570,  ..., 264.3530,  63.8310, 249.7786],
        [256.5369, 366.1230, 256.3241,  ..., 251.5419, 384.4283, 265.7027]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[257.3538, 366.7538, 257.3526,  ..., 198.8217, 395.2179, 208.6849],
        [257.2047, 367.0590, 257.6825,  ..., 175.2427, 422.8628, 161.9407],
        [256.8253, 365.8307, 256.7570,  ..., 186.0071,  84.5715, 173.3274],
        ...,
        [256.7006, 365.6569, 256.8544,  ..., 177.1922, 193.7646, 178.0167],
        [257.4587, 366.9527, 257.5980,  ..., 334.5258, 407.0527, 346.0301],
    

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[251.7698, 358.2038, 251.2827,  ..., 330.6699, 295.1880, 327.7491],
        [253.3795, 360.7985, 253.6613,  ..., 214.3491, 308.5008, 214.6403],
        [252.8511, 362.0455, 253.1449,  ..., 197.3540, 399.1424, 195.2828],
        ...,
        [253.5804, 362.3203, 253.6235,  ..., 181.1369, 424.9914, 169.7435],
        [252.3591, 359.0874, 251.9133,  ..., 108.1144, 266.9766,  84.0210],
        [253.7104, 360.8979, 253.5973,  ..., 214.4067, 104.9515, 194.6191]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[261.2609, 373.0666, 260.9720,  ..., 256.2231, 384.2411, 269.9681],
        [260.1012, 368.9294, 259.2791,  ..., 329.7062, 295.3952, 326.6356],
        [261.3976, 371.0049, 260.6072,  ..., 265.4026,  84.8393, 243.1221],
        ...,
        [261.3050, 372.4856, 261.3987,  ..., 197.3185, 401.2083, 199.2707],
        [260.9062, 370.6350, 260.2452,  ..., 183.5141, 156.2256, 165.5253],
    

torch.Size([60, 12])
torch.Size([60, 3])
output tensor([[256.0916, 364.5632, 255.8359, 282.0316, 211.6497, 216.3538, 228.5454,
         204.6266, 323.7710, 222.5423, 344.4930, 229.4850],
        [256.3734, 364.6024, 256.0819, 282.3361, 178.1233, 296.8621, 174.6292,
         276.6363, 132.5910, 185.2805, 122.0944, 168.1518],
        [255.3564, 363.7696, 255.3199, 281.6695, 300.8408, 216.5350, 317.6035,
         228.4770, 392.6259, 291.2717, 375.7991, 308.4297],
        [255.7087, 364.2820, 255.8127, 281.8667, 210.3062, 216.2543, 227.6373,
         204.5457, 322.7918, 179.4647, 324.4616, 202.9132],
        [255.6504, 364.6540, 255.5424, 281.9675, 256.7563, 202.4030, 277.9467,
         202.5637, 373.1628, 225.2648, 362.8244, 248.7444],
        [255.3708, 363.8545, 255.1094, 281.2648, 303.4198, 217.2251, 319.9056,
         230.0066, 369.0015, 144.4927, 388.2387, 142.7962],
        [255.9190, 364.3436, 255.8561, 281.8946, 181.8343, 253.8652, 189.6359,
         234.0432, 266.7392, 171.9191, 

output tensor([[255.1765, 364.1634, 255.2116,  ..., 137.6281, 163.3825, 117.0515],
        [255.1420, 363.4548, 255.1090,  ..., 138.9348, 178.2496, 125.9303],
        [254.5140, 362.9586, 254.7997,  ..., 174.2103, 419.1259, 164.8527],
        ...,
        [254.9715, 363.1465, 254.9894,  ..., 180.9587, 213.5876, 184.5309],
        [254.9781, 362.8281, 254.7829,  ..., 236.9201,  98.6299, 218.7562],
        [255.0222, 362.9836, 254.8734,  ..., 183.6526, 159.5669, 175.1247]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[258.6248, 368.7605, 258.9634,  ..., 335.7182, 379.9727, 341.5885],
        [259.3724, 369.1557, 259.4046,  ..., 108.3249, 313.7285,  85.2016],
        [259.4056, 369.4991, 259.3527,  ..., 343.3279, 343.0970, 343.9646],
        ...,
        [259.3954, 369.5455, 259.3750,  ..., 241.5116, 362.0388, 255.7658],
        [259.5463, 369.1171, 258.9383,  ..., 212.9515, 351.8309, 213.7229],
        [259.8283, 370.5540, 260.4752,  ..., 16

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[256.9682, 365.5258, 256.7127,  ..., 101.0415, 265.9543,  82.2821],
        [256.8266, 365.3531, 255.8952,  ..., 173.8617,  95.6523, 163.5460],
        [256.6550, 365.2211, 255.6479,  ..., 168.8463, 100.5195, 158.1530],
        ...,
        [260.1734, 368.7356, 258.8073,  ..., 334.2238, 348.1926, 345.8392],
        [259.2059, 368.8581, 257.5561,  ..., 306.0509,  60.8924, 304.4665],
        [259.0034, 366.3531, 257.4024,  ..., 316.9969, 420.2073, 332.3334]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[258.8931, 369.4814, 260.1871,  ..., 197.1958, 254.4297, 198.5634],
        [260.1198, 369.1312, 259.2950,  ..., 142.9234, 387.2861, 145.9168],
        [259.2481, 368.6812, 258.1701,  ..., 240.9242, 336.4847, 253.9991],
        ...,
        [259.5041, 369.0446, 259.1401,  ..., 217.7762,  95.3004, 200.7590],
        [257.9594, 367.7373, 257.3846,  ..., 137.9292, 226.2313, 122.0283],
    

output tensor([[258.3518, 367.6347, 258.0508,  ..., 332.1024, 411.1247, 354.2463],
        [257.5594, 366.6852, 257.4454,  ..., 179.0323, 166.8808, 178.7729],
        [258.1674, 366.7142, 257.9131,  ..., 145.7371, 368.9975, 164.4429],
        ...,
        [257.2905, 367.0334, 258.2106,  ..., 111.9796, 274.9616,  94.5496],
        [258.1609, 367.2696, 256.9226,  ..., 187.8536, 346.8312, 188.7540],
        [257.8295, 368.2189, 257.3512,  ..., 136.8577, 160.6951, 116.1577]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[256.9777, 365.5172, 256.8448,  ..., 331.1268, 392.5140, 350.2299],
        [256.6545, 365.1334, 257.0193,  ..., 111.2947, 285.0742, 121.6030],
        [256.5738, 365.2647, 256.2097,  ..., 214.9688,  93.1308, 196.0000],
        ...,
        [255.9890, 364.3674, 255.8108,  ..., 318.8530, 300.1263, 319.3158],
        [255.9186, 364.7579, 256.4559,  ..., 105.0518, 267.1174,  82.3170],
        [257.1627, 365.8427, 256.3677,  ..., 32

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[256.3420, 365.5780, 257.1170,  ..., 141.5345, 244.7262, 151.6320],
        [256.0002, 363.9229, 255.7781,  ..., 294.6442, 376.4771, 307.0148],
        [256.1177, 364.6999, 256.3980,  ..., 319.8971, 345.5263, 336.6554],
        ...,
        [256.7228, 365.7918, 256.8188,  ..., 168.3223, 405.2825, 182.5271],
        [256.6381, 365.3238, 256.7346,  ..., 120.1301, 290.7092,  98.0927],
        [256.9355, 365.5461, 256.5152,  ..., 333.6260, 411.0200, 353.8501]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[257.3954, 366.1397, 257.1776,  ..., 336.7292, 299.7202, 330.9520],
        [257.7822, 366.0999, 257.9191,  ..., 105.2430, 282.7420,  92.6959],
        [257.8137, 367.4629, 258.5991,  ..., 156.3490, 269.2588, 172.2421],
        ...,
        [257.8672, 367.6792, 257.7679,  ..., 235.8290, 451.5447, 235.2067],
        [257.6622, 366.8697, 258.0751,  ..., 179.3237, 215.6357, 186.5216],
    

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[255.8119, 364.2419, 255.9098,  ..., 143.7979, 384.4912, 125.0412],
        [255.9397, 364.9357, 256.6592,  ..., 213.1294, 428.4059, 231.4685],
        [256.0803, 364.3470, 255.9910,  ..., 104.0045, 264.6788,  80.5632],
        ...,
        [256.4189, 365.1692, 256.5464,  ..., 323.0105, 327.3426, 338.4390],
        [256.5029, 364.8010, 256.0435,  ..., 146.8271, 375.1027, 157.9791],
        [256.5297, 365.6191, 256.6786,  ..., 171.5896, 408.1071, 187.6733]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[258.9066, 368.3842, 258.4956,  ..., 329.8519, 427.1912, 350.8133],
        [258.1690, 367.7374, 258.2443,  ..., 235.8726, 452.1535, 239.3587],
        [258.0878, 367.1357, 257.7112,  ..., 209.7028, 309.1376, 206.4612],
        ...,
        [258.3580, 368.1227, 258.8428,  ..., 342.8625, 340.2900, 345.0698],
        [258.0956, 367.9144, 258.8030,  ..., 336.8713, 347.9884, 346.1805],
    

output tensor([[257.1949, 365.6807, 257.4982,  ..., 206.2482,  86.9468, 187.7280],
        [256.6454, 365.0913, 257.0353,  ..., 181.7224, 416.6373, 200.0144],
        [256.0451, 364.4461, 256.1670,  ..., 326.5174, 293.8621, 322.0739],
        ...,
        [256.5793, 365.5510, 257.2071,  ..., 213.5981, 296.8257, 231.5851],
        [257.3053, 365.8384, 256.9938,  ..., 181.3986, 163.0598, 171.5372],
        [257.1886, 366.8211, 256.9065,  ..., 137.2996, 159.1682, 113.9501]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[257.2124, 367.4346, 257.8839,  ..., 311.1020,  72.0667, 292.9178],
        [257.2123, 366.3669, 257.0590,  ..., 173.1165,  94.9831, 158.6541],
        [257.1636, 365.8974, 256.9563,  ..., 184.5646, 291.1185, 172.5064],
        ...,
        [258.1835, 368.2159, 258.6578,  ..., 336.5560, 395.6992, 354.8838],
        [257.2246, 366.2645, 257.1145,  ..., 116.5951, 216.9066, 109.7782],
        [257.1850, 366.8367, 257.8077,  ..., 34

output tensor([[256.8196, 365.5752, 256.3492,  ..., 176.1677,  88.9554, 163.9058],
        [256.2383, 363.9185, 256.1290,  ..., 215.1699, 416.3716, 236.7419],
        [256.0110, 364.8074, 256.0941,  ..., 151.8438, 393.4504, 134.6812],
        ...,
        [256.6011, 365.7554, 256.9240,  ..., 140.2888, 235.7602, 134.0592],
        [256.1492, 364.9221, 256.3653,  ..., 340.3822, 365.4355, 347.3851],
        [257.0094, 365.8280, 256.8521,  ..., 181.2588, 211.7457, 186.9512]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[257.3764, 367.4025, 257.8604,  ..., 308.4465,  63.3333, 300.2110],
        [257.9866, 366.7109, 257.9098,  ..., 104.6073, 276.6239,  87.5312],
        [257.1839, 367.0897, 256.9826,  ..., 195.7949, 399.8232, 197.2220],
        ...,
        [258.0394, 367.5811, 258.3299,  ..., 262.8385,  83.6476, 243.0373],
        [257.5757, 367.2740, 258.2718,  ..., 334.3748, 350.9517, 347.5271],
        [256.8364, 365.5745, 257.3029,  ..., 23

output tensor([[257.8856, 366.6030, 257.2303,  ..., 339.5443, 304.8707, 344.6467],
        [257.6124, 365.6554, 256.6619,  ..., 335.2393, 408.6879, 349.2357],
        [257.4346, 366.0189, 257.2564,  ..., 132.8884, 360.2340, 151.1408],
        ...,
        [256.9033, 366.3987, 257.7342,  ..., 108.8995, 232.4446,  98.8419],
        [257.4462, 365.4893, 256.5719,  ..., 334.9750, 411.7087, 351.1805],
        [256.6308, 365.4109, 257.0968,  ..., 307.5082, 364.0593, 315.8873]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[255.2652, 363.2710, 255.1297,  ..., 265.2894, 460.6373, 267.1636],
        [255.8138, 364.5020, 256.2151,  ..., 143.0138, 171.2749, 139.2089],
        [255.6890, 363.4560, 255.6875,  ..., 232.6517, 435.8209, 253.8308],
        ...,
        [256.4524, 365.4277, 257.0796,  ..., 330.4916, 393.9176, 353.2209],
        [255.6404, 363.8630, 255.5271,  ..., 140.1365, 328.1481, 136.6427],
        [255.9064, 363.7829, 255.1391,  ..., 32

output tensor([[256.8121, 365.7995, 256.4280,  ..., 151.4458, 392.3003, 163.4984],
        [257.2385, 366.1140, 257.2209,  ..., 166.2326, 138.7037, 149.9868],
        [257.3745, 366.3745, 257.2159,  ..., 110.4818, 238.1296, 105.4772],
        ...,
        [257.3005, 365.6969, 256.9513,  ..., 106.4952, 267.4440, 104.5512],
        [257.2280, 366.5179, 257.3310,  ..., 335.7138, 398.4576, 353.7637],
        [256.8236, 365.0647, 256.9737,  ..., 202.6078, 418.2978, 222.0068]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[256.6671, 366.0732, 257.0899,  ..., 137.6815, 238.2749, 139.2724],
        [257.5996, 365.9596, 257.0114,  ..., 288.2642, 316.7506, 303.9218],
        [257.3120, 366.6055, 256.8802,  ..., 152.2392, 394.9846, 166.2884],
        ...,
        [256.8379, 365.8557, 256.8689,  ..., 277.0187, 412.3674, 295.6818],
        [257.1606, 365.7147, 256.9345,  ..., 279.1821, 389.0033, 295.0756],
        [256.8311, 366.2573, 256.9854,  ..., 27

output tensor([[256.8857, 366.6058, 257.5665,  ..., 265.9156,  74.7909, 244.2948],
        [256.4882, 365.3182, 256.2921,  ..., 247.1254, 455.3861, 241.1904],
        [257.8866, 367.7670, 257.7866,  ..., 181.1011, 425.6824, 171.8695],
        ...,
        [258.1635, 366.8153, 257.0627,  ..., 146.7915, 172.9047, 139.6491],
        [256.2002, 364.4872, 256.2942,  ..., 308.8968, 359.5685, 317.9550],
        [256.5833, 365.1089, 256.8682,  ..., 313.7400, 360.1096, 328.4959]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[256.7840, 365.6386, 256.4178,  ..., 166.3054, 280.5188, 182.2681],
        [255.6017, 363.8810, 255.8969,  ..., 231.0648, 435.4175, 244.4821],
        [254.3958, 362.0671, 254.3263,  ..., 266.8609, 459.2018, 270.6960],
        ...,
        [256.9507, 365.7031, 256.5663,  ..., 166.1870, 285.0707, 171.5941],
        [256.0798, 364.6615, 256.8269,  ..., 114.4398, 275.3093,  93.5757],
        [256.9333, 365.4641, 256.4695,  ..., 17

output tensor([[257.4187, 366.4783, 257.3994,  ..., 334.3203, 352.2175, 348.6633],
        [256.8022, 365.9601, 257.0663,  ..., 341.5791, 342.4717, 346.5880],
        [257.0355, 365.8314, 256.9522,  ..., 191.4652, 342.6178, 205.6948],
        ...,
        [257.5753, 366.7289, 257.2390,  ..., 332.6862, 418.2373, 353.2808],
        [257.1115, 365.3996, 257.1691,  ..., 209.7233,  78.6340, 196.8342],
        [257.1928, 365.2842, 257.0962,  ..., 207.6553,  96.3216, 187.5170]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[256.4695, 365.5914, 256.1885,  ..., 237.9911, 342.5035, 246.2613],
        [256.2408, 365.6545, 257.1700,  ..., 173.6055, 327.3590, 194.2382],
        [256.5929, 366.3141, 257.1516,  ..., 160.3748, 271.5279, 177.2580],
        ...,
        [256.9269, 366.7060, 257.0038,  ..., 327.8416, 426.0735, 349.6355],
        [257.1547, 366.6341, 257.0962,  ..., 137.9164, 169.6005, 117.4710],
        [256.5833, 366.0258, 256.9379,  ..., 33

torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[257.0991, 366.2686, 257.2187,  ..., 259.1420,  85.7044, 239.0456],
        [257.1878, 365.7696, 257.0373,  ..., 174.0055, 285.2707, 158.3846],
        [256.2415, 366.2301, 256.6752,  ..., 118.6716, 169.5182,  97.0496],
        ...,
        [257.5593, 367.3506, 257.4276,  ..., 166.7807, 404.2133, 179.6207],
        [257.4592, 366.1657, 256.8505,  ..., 190.4445, 348.4646, 192.6750],
        [257.3581, 366.0845, 257.6296,  ..., 230.5232, 428.6717, 245.8172]],
       grad_fn=<AddmmBackward0>)
torch.Size([128, 12])
torch.Size([128, 3])
output tensor([[256.9741, 365.5286, 256.7393,  ..., 142.5862, 325.5501, 128.0860],
        [256.6971, 365.6763, 257.2019,  ..., 176.4835, 329.4633, 197.0022],
        [257.1752, 365.8211, 256.3951,  ..., 202.4978, 350.8584, 203.5629],
        ...,
        [257.5577, 365.7144, 257.0054,  ..., 298.9969, 308.2711, 304.9214],
        [256.9872, 365.5489, 256.7859,  ..., 139.7727, 323.0129, 126.4521],
    

In [76]:
def test_model(model_path, test_data_dir):
    # Load the test dataset
    test_dataset = KpVelDataset(test_data_dir)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    # Initialize the model and load the saved state
    model = KeypointRegressionNet()
    model.load_state_dict(torch.load(model_path))
    model.eval()

    # Criterion for evaluation
    criterion = nn.MSELoss()
    total_loss = 0

    # No gradient needed for evaluation
    with torch.no_grad():
        for start_kp, next_kp, velocity in test_loader:
            output = model(start_kp, velocity)
            for i in range(start_kp.size(0)):
                individual_start_kp = start_kp[i]
                individual_next_kp = next_kp[i]
                individual_velocity = velocity[i]
                predicted_next_kp = output[i]

                print("Start KP:", individual_start_kp)
                print("Next KP:", individual_next_kp)
                print("Actual Velocity:", individual_velocity)
                print("Predicted Velocity:", predicted_next_kp)
                print("-----------------------------------------")
            loss = criterion(output, next_kp)
            total_loss += loss.item()
            
    # Calculate the average loss
    avg_loss = total_loss / len(test_loader)
    print(f'Average Test Loss: {avg_loss}')

# Usage
model_path = '/home/jc-merlab/Pictures/Data/trained_models/reg_nkp_b128_e250_v0.pth'  # Update with your model path
test_data_dir = '/home/jc-merlab/Pictures/panda_data/panda_sim_vel/vel_reg_sim_test/split_folder_reg/test/annotations/'  # Update with your test data path
test_model(model_path, test_data_dir)

Start KP: tensor([257.9522, 366.9199,   1.0000, 257.9597, 283.0131,   1.0000, 179.8257,
        298.1946,   1.0000, 175.8687, 277.7923,   1.0000, 240.1268, 202.5661,
          1.0000, 262.0790, 205.0822,   1.0000])
Next KP: tensor([257.9522, 366.9199,   1.0000, 257.9597, 283.0131,   1.0000, 179.8260,
        298.1959,   1.0000, 175.8685, 277.7937,   1.0000, 240.1280, 202.5687,
          1.0000, 262.0802, 205.0853,   1.0000])
Actual Velocity: tensor([0.0000, 0.0000, 0.3000])
Predicted Velocity: tensor([257.7654, 366.7654,   0.9178, 257.9145, 283.0347,   1.1408, 179.7099,
        298.4261,   1.0625, 175.4932, 277.9028,   0.8915, 240.3632, 202.5495,
          1.0117, 261.5081, 204.5572,   1.1235])
-----------------------------------------
Start KP: tensor([257.9522, 366.9199,   1.0000, 257.9597, 283.0131,   1.0000, 179.8213,
        298.1725,   1.0000, 175.8698, 277.7691,   1.0000, 240.1633, 202.5730,
          1.0000, 262.1147, 205.0952,   1.0000])
Next KP: tensor([257.9522, 366.9199,   

Next KP: tensor([257.9522, 366.9199,   1.0000, 257.9597, 283.0131,   1.0000, 303.7598,
        217.8956,   1.0000, 320.7623, 229.8525,   1.0000, 419.9565, 231.2956,
          1.0000, 442.5465, 231.8152,   1.0000])
Actual Velocity: tensor([0.0000, 0.0893, 0.0000])
Predicted Velocity: tensor([258.3449, 367.1544,   0.7209, 258.5476, 283.7431,   1.3077, 303.9985,
        219.0668,   1.1029, 320.3009, 230.5019,   0.9057, 420.4909, 232.1425,
          1.0414, 442.1325, 232.9416,   1.1690])
-----------------------------------------
Start KP: tensor([257.9522, 366.9199,   1.0000, 257.9597, 283.0131,   1.0000, 303.7598,
        217.8956,   1.0000, 320.7623, 229.8525,   1.0000, 419.9565, 231.2956,
          1.0000, 442.5465, 231.8152,   1.0000])
Next KP: tensor([257.9522, 366.9199,   1.0000, 257.9597, 283.0131,   1.0000, 303.7659,
        217.8999,   1.0000, 320.7673, 229.8584,   1.0000, 419.9698, 230.4328,
          1.0000, 442.5634, 230.7548,   1.0000])
Actual Velocity: tensor([0.0000, 0.0893,