In [1]:
from smolvla_in_isaac.dataset.loading import load_and_split_dataset, create_dataloaders
import torch
import torch.nn as nn
from torch.optim import Adam
from tqdm import tqdm

device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
device

device(type='cuda')

In [2]:
repo_id = "eternalmay33/pick_place_test"

In [3]:
def make_delta_timestamps(delta_indices: list[int] | None, fps: int) -> list[float]:
    if delta_indices is None:
        return [0]
    return [i / fps for i in delta_indices]

delta_timestamps = {
    "action" : make_delta_timestamps(
        list(range(20)),
        30
    )
}

In [4]:
full_dataset, train_dataset, val_dataset = load_and_split_dataset(
    repo_id=repo_id, delta_timestamps=delta_timestamps
)

LeRobotDatasetMetadata({
    Repository ID: 'eternalmay33/pick_place_test',
    Total episodes: '39',
    Total frames: '8159',
    Features: '['action', 'observation.state', 'observation.images.front', 'observation.images.third_person', 'observation.images.gripper', 'timestamp', 'frame_index', 'episode_index', 'index', 'task_index']',
})',



In [5]:
full_dataset.features

{'action': {'dtype': 'float32',
  'shape': (6,),
  'names': ['shoulder_pan.pos',
   'shoulder_lift.pos',
   'elbow_flex.pos',
   'wrist_flex.pos',
   'wrist_roll.pos',
   'gripper.pos']},
 'observation.state': {'dtype': 'float32',
  'shape': (6,),
  'names': ['shoulder_pan.pos',
   'shoulder_lift.pos',
   'elbow_flex.pos',
   'wrist_flex.pos',
   'wrist_roll.pos',
   'gripper.pos']},
 'observation.images.front': {'dtype': 'video',
  'shape': (480, 640, 3),
  'names': ['height', 'width', 'channels'],
  'video_info': {'video.height': 480,
   'video.width': 640,
   'video.codec': 'av1',
   'video.pix_fmt': 'yuv420p',
   'video.is_depth_map': False,
   'video.fps': 30.0,
   'video.channels': 3,
   'has_audio': False},
  'info': {'video.height': 480,
   'video.width': 640,
   'video.codec': 'av1',
   'video.pix_fmt': 'yuv420p',
   'video.is_depth_map': False,
   'video.fps': 30,
   'video.channels': 3,
   'has_audio': False}},
 'observation.images.third_person': {'dtype': 'video',
  'shape'

In [6]:
train_loader, val_loader = create_dataloaders(
    train_dataset, val_dataset
)

In [7]:
class MLP(nn.Module):
    def __init__(self, inputdim, size, actiondim):
        super(MLP, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features=inputdim, out_features=size),
            nn.ReLU(),
            nn.Linear(size, size),
            nn.ReLU()
        )

        self.decoder = nn.Linear(size, actiondim)
    
    def forward(self, batch):
        batch = self.net(batch)
        batch = self.decoder(batch)

        return batch

In [8]:
model = MLP(6, 32, 6).to(device)

model

MLP(
  (net): Sequential(
    (0): Linear(in_features=6, out_features=32, bias=True)
    (1): ReLU()
    (2): Linear(in_features=32, out_features=32, bias=True)
    (3): ReLU()
  )
  (decoder): Linear(in_features=32, out_features=6, bias=True)
)

In [9]:
lr = 1e-4
criterion = nn.MSELoss()
optimizer = Adam(params=model.parameters(), lr=lr)

In [14]:
loss_list = [] 
n_epoch = 3_000

model.train()
 
for itr in range(0, n_epoch+1):
    total_loss = 0
    b=0
    for batch in tqdm(train_loader):

        states = batch["observation.state"].to(device)

        actions = batch["action"].to(device)[:, 0, :]

        print("="*50)
        print(states[0])
        print(actions[0])

        # print(actions)

        # print(states)
        # print(actions)

        y_pred = model(states)
        loss   = criterion(y_pred, actions) 
        total_loss += loss.item() 
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        b += 1 
        
        loss_list.append(loss.item())
        
    print(f'Epoch {itr} Loss: {total_loss/b:.4f}')

  1%|          | 7/796 [00:00<00:49, 15.78it/s]

tensor([ -1.5393,  21.5249, -97.1096,  68.2947,  -1.6350,  20.5777],
       device='cuda:0')
tensor([ 3.3309,  1.6170, 23.8698, -6.2472,  2.5869,  1.8405], device='cuda:0')
tensor([  1.2953,  27.3674, -40.5030,  90.0032, -11.8048,  21.3686],
       device='cuda:0')
tensor([ 4.8853, -7.6596, 17.8119,  8.5512, -0.2352,  6.9785], device='cuda:0')
tensor([ -1.8946,  16.8034, -98.2341,  76.7019,   8.2281,  21.4283],
       device='cuda:0')
tensor([ 5.1814, -0.2553, 15.9132,  7.3992,  2.0643,  1.3804], device='cuda:0')
tensor([  9.5804, -25.1486, -38.8444,  45.5175,  18.7182,  20.8577],
       device='cuda:0')
tensor([ 7.0318, -1.3617, 21.1573,  0.9304,  1.0191,  1.8405], device='cuda:0')
tensor([ -7.8358,   9.4741, -61.8914,  95.9055,  -0.8280,  20.9605],
       device='cuda:0')
tensor([  4.8112, -11.8298,  18.6257,  25.4763,   0.9146,   1.9939],
       device='cuda:0')
tensor([13.0030, 10.1426, -4.7327, 67.6872, -2.6275, 32.5942], device='cuda:0')
tensor([ 1.7024, -6.1277, 23.3273,  2.8799

  2%|▏         | 15/796 [00:00<00:32, 24.33it/s]

tensor([ 11.9735,   2.2390, -17.8473,  81.9289,   0.4088,  20.8893],
       device='cuda:0')
tensor([ 11.8431,  -4.5957, -20.5244,  82.6318,   0.3919,   4.9847],
       device='cuda:0')
tensor([16.1963, 24.0647,  2.3340, 73.6565, 16.0294, 31.0620], device='cuda:0')
tensor([ -0.5922, -10.2979,  16.4557,  22.6407,   3.5276,  24.2331],
       device='cuda:0')
tensor([ -6.5913, -15.6493,  41.6720,  70.9859,  11.0637,  42.7840],
       device='cuda:0')
tensor([ 7.8460, -6.6383, 20.0723, 11.8299,  3.0572,  2.7607], device='cuda:0')
tensor([  2.0266,  37.2864, -56.4701,  95.5877,  -2.8785,  34.1089],
       device='cuda:0')
tensor([-0.3701, -8.4255, 21.1573, -4.6522,  0.3397,  1.7638], device='cuda:0')
tensor([ -0.2934,  53.3564, -54.2116,  97.8045, -17.2422,  38.1563],
       device='cuda:0')
tensor([ 5.6255, -7.7447, 16.6365, 10.8551,  1.2281,  4.6779], device='cuda:0')
tensor([ 1.4094, -1.4124,  8.1403, 55.2152,  0.2298, 12.4883], device='cuda:0')
tensor([ 1.7024, -6.1277, 23.3273,  2.8799

  3%|▎         | 23/796 [00:01<00:25, 30.54it/s]

tensor([  7.0216,  10.8401, -17.3887,  78.1308,   6.0113,  20.5492],
       device='cuda:0')
tensor([ 3.9230, -9.1064, 27.8481,  5.5383, -0.2874,  1.6871], device='cuda:0')
tensor([13.8072, -9.7436, 37.6008, 56.4555, 21.3590, 32.5086], device='cuda:0')
tensor([  2.0725, -18.4681,  30.4702,  18.0328,  -1.0713,   2.3773],
       device='cuda:0')
tensor([  6.9805,  58.0085, -59.9655,  84.0440,  -3.4881,  27.0396],
       device='cuda:0')
tensor([ 4.3671, -2.6383, 20.9765, -2.0824, -0.1307,  1.7638], device='cuda:0')
tensor([  7.3984,  52.6959, -59.7990,  86.0291,  -3.5271,  26.9876],
       device='cuda:0')
tensor([ 4.3671, -2.6383, 20.9765, -2.0824, -0.1307,  1.7638], device='cuda:0')
tensor([16.1665, 23.5864,  1.2165, 72.8838, 15.3010, 20.5352], device='cuda:0')
tensor([ -0.5922, -10.2979,  16.4557,  22.6407,   3.5276,  24.2331],
       device='cuda:0')
tensor([ -7.2728,  41.8506, -40.4707,  79.6455, -16.1551,  39.9460],
       device='cuda:0')
tensor([ 3.3309,  1.6170, 23.8698, -6.2472

  4%|▍         | 31/796 [00:01<00:24, 31.18it/s]

tensor([ 1.6306, 20.3272,  4.0032, 57.8287, 15.1874, 23.5649], device='cuda:0')
tensor([ 6.4397, -2.9787, 23.9602,  3.6775,  1.3849,  1.3804], device='cuda:0')
tensor([ 7.3320,  9.0329,  9.6523, 81.3400, -5.4855, 34.7340], device='cuda:0')
tensor([ 3.9230, -9.1064, 27.8481,  5.5383, -0.2874,  1.6871], device='cuda:0')
tensor([ -6.1462,  36.1457, -65.3460,  88.0683,  -8.3081,  20.5725],
       device='cuda:0')
tensor([ 3.3309,  1.6170, 23.8698, -6.2472,  2.5869,  1.8405], device='cuda:0')
tensor([-0.0574,  2.1161, -9.3589, 51.1345, -5.1716, 16.8891], device='cuda:0')
tensor([-0.3701, -8.4255, 21.1573, -4.6522,  0.3397,  1.7638], device='cuda:0')
tensor([-13.8153,  41.8645, -36.5532,  97.7745, -21.7538,  38.1822],
       device='cuda:0')
tensor([ 5.6255, -7.7447, 16.6365, 10.8551,  1.2281,  4.6779], device='cuda:0')
tensor([-12.6774,  43.6851, -41.6761,  99.2756,  -4.2817,  44.6559],
       device='cuda:0')
tensor([  4.8112, -11.8298,  18.6257,  25.4763,   0.9146,   1.9939],
       devic

  5%|▍         | 39/796 [00:01<00:25, 29.92it/s]

tensor([ -1.8454,  42.2184, -40.2730,  89.0312, -15.8617,  39.7939],
       device='cuda:0')
tensor([ 3.3309,  1.6170, 23.8698, -6.2472,  2.5869,  1.8405], device='cuda:0')
tensor([ 5.4146, 23.9415, -1.4656, 66.2577, 13.8705, 25.5512], device='cuda:0')
tensor([ 5.9956, -2.8936, 21.1573,  1.4621,  2.5346,  1.6871], device='cuda:0')
tensor([  5.0688, -32.1987, -11.6134,  59.6870,  -1.7162,  21.4277],
       device='cuda:0')
tensor([ 1.3324, 10.6383,  6.4195, 11.0323,  0.9668, 17.3313], device='cuda:0')
tensor([ 0.9113, 19.1573, -7.2696, 98.7322, -5.2136, 38.3156], device='cuda:0')
tensor([ 4.8853, -7.6596, 17.8119,  8.5512, -0.2352,  6.9785], device='cuda:0')
tensor([ 4.1426, 14.8820,  2.5536, 86.1611,  1.9898, 30.8243], device='cuda:0')
tensor([ 1.6284, 14.8936,  1.4467, 84.2269,  2.6914, 30.8282], device='cuda:0')
tensor([  7.0262,  18.6112, -78.2949,  81.5614,   2.1879,  22.1950],
       device='cuda:0')
tensor([-0.3701, -8.4255, 21.1573, -4.6522,  0.3397,  1.7638], device='cuda:0')
t

  6%|▌         | 47/796 [00:01<00:24, 30.39it/s]

tensor([ -6.6587, -16.1206,  41.7474,  71.3492,  10.3590,  42.7836],
       device='cuda:0')
tensor([ 7.8460, -6.6383, 20.0723, 11.8299,  3.0572,  2.7607], device='cuda:0')
tensor([ 11.8422,  42.0378, -50.5693,  99.1718,  21.9188,  32.1281],
       device='cuda:0')
tensor([ 6.2916,  4.2553,  0.2712, 29.0208,  0.0784, 19.4018], device='cuda:0')
tensor([  0.1177,   0.2253, -61.1300,  69.1239,  26.1748,  21.6575],
       device='cuda:0')
tensor([ 4.6632,  1.7021, 17.9928, -1.7280, -0.7055,  7.2853], device='cuda:0')
tensor([ 8.0611, -9.6339, 23.5264, 66.7072,  8.5885, 33.5757], device='cuda:0')
tensor([  1.4064, -11.2340,  35.2622,  -3.3230,   4.3115,   1.9172],
       device='cuda:0')
tensor([ 0.9258, 19.1032, -7.7329, 98.7226, -5.6176, 38.3010], device='cuda:0')
tensor([ 4.8853, -7.6596, 17.8119,  8.5512, -0.2352,  6.9785], device='cuda:0')
tensor([-0.8284,  8.7510, 13.5139, 59.1321, 15.9787, 48.0308], device='cuda:0')
tensor([ 6.4397, -2.9787, 23.9602,  3.6775,  1.3849,  1.3804], devic

  6%|▋         | 51/796 [00:01<00:23, 31.38it/s]

tensor([  0.2597,  60.2037, -61.1059,  81.1108,  38.1076,  49.6100],
       device='cuda:0')
tensor([ 4.6632,  1.7021, 17.9928, -1.7280, -0.7055,  7.2853], device='cuda:0')
tensor([  6.9574,  27.4578, -78.3074,  81.4006,   2.1582,  22.1816],
       device='cuda:0')
tensor([-0.3701, -8.4255, 21.1573, -4.6522,  0.3397,  1.7638], device='cuda:0')
tensor([-16.8719, -21.5577, -65.3323,  91.8916,   6.0675,  21.4705],
       device='cuda:0')
tensor([ 1.9985, -8.1702, 10.5787, 24.1471, -1.5939, 20.3221], device='cuda:0')
tensor([-8.8111, 22.9988, -5.8211, 75.8831,  3.4368, 44.0528], device='cuda:0')
tensor([  4.8112, -11.8298,  18.6257,  25.4763,   0.9146,   1.9939],
       device='cuda:0')
tensor([-16.9076, -14.2316, -63.8570,  95.0323,   6.0666,  21.4762],
       device='cuda:0')
tensor([ 1.9985, -8.1702, 10.5787, 24.1471, -1.5939, 20.3221], device='cuda:0')
tensor([ -7.5496,  47.6598, -38.6649,  98.4446,   5.7150,   4.9109],
       device='cuda:0')
tensor([ 1.9985, -8.1702, 10.5787, 24.1471

  7%|▋         | 59/796 [00:02<00:25, 28.93it/s]

tensor([ 3.3309,  1.6170, 23.8698, -6.2472,  2.5869,  1.8405], device='cuda:0')
tensor([ 7.8684,  6.5122,  5.9365, 65.1497, 14.6244, 35.1908], device='cuda:0')
tensor([ 4.3671, -1.1064, 21.1573, -1.0191,  3.5798,  1.5337], device='cuda:0')
tensor([-10.8199,  25.5509, -19.7325,  79.6435, -10.5291,  40.0226],
       device='cuda:0')
tensor([ 3.3309,  1.6170, 23.8698, -6.2472,  2.5869,  1.8405], device='cuda:0')
tensor([ -6.7033,  42.1168, -40.8026,  79.6666, -16.1622,  39.8762],
       device='cuda:0')
tensor([ 3.3309,  1.6170, 23.8698, -6.2472,  2.5869,  1.8405], device='cuda:0')
tensor([-12.6623,  52.0612, -58.3957,  98.5485,  -4.4138,  44.6252],
       device='cuda:0')
tensor([  4.8112, -11.8298,  18.6257,  25.4763,   0.9146,   1.9939],
       device='cuda:0')
tensor([ 2.9557, 17.6327,  4.4731, 85.8195,  8.2972, 32.2814], device='cuda:0')
tensor([ 3.9970, -3.3191, 19.7107,  6.4245,  1.4372,  1.9172], device='cuda:0')


  8%|▊         | 63/796 [00:02<00:24, 29.60it/s]

tensor([ 2.1358, -3.2772, 12.3097, 26.8029, -1.6024,  5.0014], device='cuda:0')
tensor([ 5.6255, -7.7447, 16.6365, 10.8551,  1.2281,  4.6779], device='cuda:0')
tensor([  9.5062, -24.5624, -21.2859,  65.6193,  18.6523,  21.0385],
       device='cuda:0')
tensor([ 7.0318, -1.3617, 21.1573,  0.9304,  1.0191,  1.8405], device='cuda:0')
tensor([ -2.5932,  -6.2077, -14.4841,  90.9987,   2.6775,  21.2249],
       device='cuda:0')
tensor([ -2.2946, -18.0426,  26.2206,  13.4249,   3.1095,   1.6104],
       device='cuda:0')
tensor([ 0.8234, 19.5603, -6.3658, 98.6174, -4.6969, 38.3357], device='cuda:0')
tensor([ 4.8853, -7.6596, 17.8119,  8.5512, -0.2352,  6.9785], device='cuda:0')
tensor([ -3.0705, -13.6354,   4.5142,  57.3477,  12.8822,  22.0348],
       device='cuda:0')
tensor([ 6.4397, -2.9787, 23.9602,  3.6775,  1.3849,  1.3804], device='cuda:0')
tensor([ 9.5895, -5.9224, -3.3067, 99.7789, 18.3751,  8.3804], device='cuda:0')
tensor([ 9.4004, -7.0638,  0.3617, 99.6455, 19.1011, 10.2761], devic

  9%|▉         | 71/796 [00:02<00:24, 29.41it/s]

tensor([-2.7003, -6.0849, 11.4026, 90.6102,  1.4710, 36.6540], device='cuda:0')
tensor([ 1.4804, -6.2128, 20.8861, 15.2858,  1.2804, 19.0184], device='cuda:0')
tensor([21.5544, 19.8139,  9.2666, 80.3303, 22.0982, 40.1058], device='cuda:0')
tensor([ 5.7735,  1.7872,  6.7812,  1.4621, -2.5869,  1.3037], device='cuda:0')
tensor([ 9.3336, 19.4329,  8.4783, 79.0938, 28.6279, 35.2689], device='cuda:0')
tensor([ 4.3671, -1.1064, 21.1573, -1.0191,  3.5798,  1.5337], device='cuda:0')
tensor([ -1.5307,  11.8458, -13.5558,  78.4638,  31.3228,  49.5484],
       device='cuda:0')
tensor([ 4.6632,  1.7021, 17.9928, -1.7280, -0.7055,  7.2853], device='cuda:0')
tensor([ -0.1504,  37.6516, -26.0414,  85.1074,  -1.6453,  37.6873],
       device='cuda:0')
tensor([ 2.8127, -2.5532, 14.0145, 16.9694,  2.6914, 25.0000], device='cuda:0')
tensor([ 13.7224, -36.3703,   0.8324,  67.0731,  13.0344,  21.0082],
       device='cuda:0')
tensor([ -2.0725, -10.2979,  16.0036,  40.3633,   6.0361,  24.3098],
       devic

  9%|▉         | 75/796 [00:02<00:24, 29.05it/s]

tensor([ -7.9186,  55.0801, -61.8251,  98.6993,  -4.3395,  21.5512],
       device='cuda:0')
tensor([  4.8112, -11.8298,  18.6257,  25.4763,   0.9146,   1.9939],
       device='cuda:0')
tensor([ -6.9078,  43.3314, -43.8004,  88.2931,  -8.3999,  20.7817],
       device='cuda:0')
tensor([ 3.3309,  1.6170, 23.8698, -6.2472,  2.5869,  1.8405], device='cuda:0')
tensor([  1.2589,  17.0777, -10.8260,  85.7165, -18.0431,   9.1433],
       device='cuda:0')
tensor([ 4.8853, -7.6596, 17.8119,  8.5512, -0.2352,  6.9785], device='cuda:0')
tensor([13.5198, -8.2897, 36.9488, 53.8881, 21.2517, 32.5098], device='cuda:0')
tensor([  2.0725, -18.4681,  30.4702,  18.0328,  -1.0713,   2.3773],
       device='cuda:0')
tensor([ 7.8640,  6.3593,  5.7560, 65.0307, 14.5114, 35.1897], device='cuda:0')
tensor([ 4.3671, -1.1064, 21.1573, -1.0191,  3.5798,  1.5337], device='cuda:0')


 10%|█         | 82/796 [00:03<00:27, 26.17it/s]

tensor([ 4.3707,  4.4193, 18.5161,  4.4643,  0.3853,  7.8787], device='cuda:0')
tensor([ 4.6632,  1.7021, 17.9928, -1.7280, -0.7055,  7.2853], device='cuda:0')
tensor([ -1.3536, -42.5878,  34.5400,  84.9747,   4.1120,  34.8888],
       device='cuda:0')
tensor([ 3.9230, -9.1064, 27.8481,  5.5383, -0.2874,  1.6871], device='cuda:0')
tensor([-10.7968,  28.7308, -24.3850,  79.6457, -10.5293,  40.0233],
       device='cuda:0')
tensor([ 3.3309,  1.6170, 23.8698, -6.2472,  2.5869,  1.8405], device='cuda:0')
tensor([ -6.4732,  36.6375, -60.8954,  88.0689,  -8.3561,  20.5660],
       device='cuda:0')
tensor([ 3.3309,  1.6170, 23.8698, -6.2472,  2.5869,  1.8405], device='cuda:0')
tensor([  3.2623, -11.6753,  -6.2563,  59.9154,  -2.3056,  12.5978],
       device='cuda:0')
tensor([  3.9230, -13.3617,  16.3653,   7.1334,  -2.1688,  12.4233],
       device='cuda:0')
tensor([17.9287, -9.8915, 35.5751, 56.2466, 21.3875, 32.5080], device='cuda:0')
tensor([  2.0725, -18.4681,  30.4702,  18.0328,  -1.071

 11%|█         | 89/796 [00:03<00:26, 27.10it/s]

tensor([-22.4813,  22.5706,  -0.6012,  99.0981,   1.8654,  55.2885],
       device='cuda:0')
tensor([ 1.9985, -8.1702, 10.5787, 24.1471, -1.5939, 20.3221], device='cuda:0')
tensor([10.6288, 14.4303,  4.5233, 67.9579, 11.7312, 40.2658], device='cuda:0')
tensor([ -0.5922, -10.2979,  16.4557,  22.6407,   3.5276,  24.2331],
       device='cuda:0')
tensor([ 8.4374, -2.5585, 21.2769, 78.9534, -0.6836, 24.3055], device='cuda:0')
tensor([ 7.4760,  1.1064, 13.0199,  9.1715,  0.1307,  1.6104], device='cuda:0')
tensor([  3.8621,  16.1504,  -5.6439,  79.5843, -10.6552,  39.8699],
       device='cuda:0')
tensor([ 3.3309,  1.6170, 23.8698, -6.2472,  2.5869,  1.8405], device='cuda:0')
tensor([ 2.5712, -7.9156, 20.8587, 79.2777, -1.0705, 24.3081], device='cuda:0')
tensor([ 7.4760,  1.1064, 13.0199,  9.1715,  0.1307,  1.6104], device='cuda:0')
tensor([  5.1835, -33.8084, -11.0161,  59.6135,  -1.6851,  21.3894],
       device='cuda:0')
tensor([ 1.3324, 10.6383,  6.4195, 11.0323,  0.9668, 17.3313], devic

 12%|█▏        | 97/796 [00:03<00:27, 25.75it/s]

tensor([ -2.5257, -18.4996,  13.0394,  99.3435,   3.6249,  24.2560],
       device='cuda:0')
tensor([ 7.4760,  1.1064, 13.0199,  9.1715,  0.1307,  1.6104], device='cuda:0')
tensor([  5.1158,  -7.1381, -16.1375,  60.1625,  -1.7532,  21.6396],
       device='cuda:0')
tensor([ 1.3324, 10.6383,  6.4195, 11.0323,  0.9668, 17.3313], device='cuda:0')
tensor([  0.2334,  33.7491, -25.8864,  84.5097,  -1.6364,  21.0139],
       device='cuda:0')
tensor([ 2.8127, -2.5532, 14.0145, 16.9694,  2.6914, 25.0000], device='cuda:0')
tensor([ 6.4391, 23.5193,  2.4111, 64.3152, -1.2360, 41.4813], device='cuda:0')
tensor([ 1.3324, 10.6383,  6.4195, 11.0323,  0.9668, 17.3313], device='cuda:0')
tensor([ -2.0563,  27.7726, -33.7184,  91.9861,   7.1783,  38.1810],
       device='cuda:0')
tensor([ 5.1814, -0.2553, 15.9132,  7.3992,  2.0643,  1.3804], device='cuda:0')
tensor([ 15.9105, -13.2923,  32.0358,  80.1841,  26.5674,  43.3989],
       device='cuda:0')
tensor([ 7.0318, -1.3617, 21.1573,  0.9304,  1.0191,  1

 13%|█▎        | 105/796 [00:03<00:25, 27.30it/s]

tensor([21.5462, 18.1221,  9.2084, 79.1582, 19.9867, 21.0620], device='cuda:0')
tensor([ 5.7735,  1.7872,  6.7812,  1.4621, -2.5869,  1.3037], device='cuda:0')
tensor([  5.1926, -16.8559, -16.4392,  59.8759,  -1.6986,  21.7999],
       device='cuda:0')
tensor([ 1.3324, 10.6383,  6.4195, 11.0323,  0.9668, 17.3313], device='cuda:0')
tensor([ 7.1701e-02,  7.8077e+00, -5.9655e+00,  7.3499e+01,  1.8544e+01,
         2.3374e+01], device='cuda:0')
tensor([ 4.6632,  1.7021, 17.9928, -1.7280, -0.7055,  7.2853], device='cuda:0')
tensor([14.7234, -9.4353, 23.6585, 66.6164, -2.7881, 33.5646], device='cuda:0')
tensor([  1.4064, -11.2340,  35.2622,  -3.3230,   4.3115,   1.9172],
       device='cuda:0')
tensor([ 2.5399,  3.7181, 17.2927, 57.0415, 16.7020, 28.5475], device='cuda:0')
tensor([ 6.4397, -2.9787, 23.9602,  3.6775,  1.3849,  1.3804], device='cuda:0')
tensor([ 2.1205,  4.3703,  9.7535, 20.7792,  1.3762, 12.6056], device='cuda:0')
tensor([ 4.6632,  1.7021, 17.9928, -1.7280, -0.7055,  7.2853],

 14%|█▍        | 113/796 [00:04<00:24, 28.14it/s]

tensor([ -5.7599, -14.1862,   9.7304,  81.4429,   3.5929,  42.7721],
       device='cuda:0')
tensor([ 7.8460, -6.6383, 20.0723, 11.8299,  3.0572,  2.7607], device='cuda:0')
tensor([ 7.3150,  8.1918,  8.8714, 64.6935, 16.3524, 35.8057], device='cuda:0')
tensor([ 5.9956, -2.8936, 21.1573,  1.4621,  2.5346,  1.6871], device='cuda:0')
tensor([  0.1372,  -3.5598, -61.2159,  67.6761,  26.1323,  21.6926],
       device='cuda:0')
tensor([ 4.6632,  1.7021, 17.9928, -1.7280, -0.7055,  7.2853], device='cuda:0')
tensor([15.1110, 17.5136,  2.0965, 84.7065,  0.1766, 50.9144], device='cuda:0')
tensor([ 1.7024, -6.1277, 23.3273,  2.8799,  2.1688,  6.5951], device='cuda:0')
tensor([ -9.6814,  54.6410, -61.7796,  98.5785,  -4.1148,  44.6249],
       device='cuda:0')
tensor([  4.8112, -11.8298,  18.6257,  25.4763,   0.9146,   1.9939],
       device='cuda:0')
tensor([ 6.2101,  7.5434,  6.2887, 65.4183, 13.8182, 43.8559], device='cuda:0')
tensor([ 3.5529, -1.0213, 18.3544,  7.6650, -2.4301,  7.7454], devic

 15%|█▌        | 121/796 [00:04<00:24, 27.61it/s]

tensor([ 5.9482, 18.9618, 10.3053, 79.4471, 26.9696, 36.7271], device='cuda:0')
tensor([ 4.3671, -1.1064, 21.1573, -1.0191,  3.5798,  1.5337], device='cuda:0')
tensor([-6.5137, -0.2549, 32.7830, 70.8823, 11.5947, 42.7373], device='cuda:0')
tensor([ 7.8460, -6.6383, 20.0723, 11.8299,  3.0572,  2.7607], device='cuda:0')
tensor([ -0.6460, -27.5708,  34.2884,  80.6644,  -9.6121,  34.7331],
       device='cuda:0')
tensor([ 3.9230, -9.1064, 27.8481,  5.5383, -0.2874,  1.6871], device='cuda:0')
tensor([ 2.1602,  4.1818, -5.5129, 72.8481, -7.2484, 28.0528], device='cuda:0')
tensor([  4.8112, -11.8298,  18.6257,  25.4763,   0.9146,   1.9939],
       device='cuda:0')
tensor([  5.1734, -33.8140, -11.1616,  59.6263,  -1.6882,  21.3779],
       device='cuda:0')
tensor([ 1.3324, 10.6383,  6.4195, 11.0323,  0.9668, 17.3313], device='cuda:0')
tensor([ -1.8919,  17.3041, -98.1371,  76.8779,   8.2283,  21.3161],
       device='cuda:0')
tensor([ 5.1814, -0.2553, 15.9132,  7.3992,  2.0643,  1.3804], devic

 16%|█▌        | 129/796 [00:04<00:23, 28.39it/s]

tensor([21.5475,  8.7699, 10.3962, 89.2560, 21.5271, 40.7898], device='cuda:0')
tensor([ 5.7735,  1.7872,  6.7812,  1.4621, -2.5869,  1.3037], device='cuda:0')
tensor([ -6.3660,  58.1725, -60.9962,  96.7787,  -4.3837,  38.0287],
       device='cuda:0')
tensor([ 5.6255, -7.7447, 16.6365, 10.8551,  1.2281,  4.6779], device='cuda:0')
tensor([ 2.4757, -1.3021, 21.8848, 32.0938, 10.5273, 23.8120], device='cuda:0')
tensor([ 6.4397, -2.9787, 23.9602,  3.6775,  1.3849,  1.3804], device='cuda:0')
tensor([ 0.5318, 18.7877,  4.6869, 57.7528, 15.2288, 22.0555], device='cuda:0')
tensor([ 6.4397, -2.9787, 23.9602,  3.6775,  1.3849,  1.3804], device='cuda:0')
tensor([13.6138, -9.7919, 37.5628, 56.4593, 21.3556, 32.5091], device='cuda:0')
tensor([  2.0725, -18.4681,  30.4702,  18.0328,  -1.0713,   2.3773],
       device='cuda:0')
tensor([ 1.1888, 22.2057,  3.3781, 87.1002,  8.0411, 31.6718], device='cuda:0')
tensor([ 3.9970, -3.3191, 19.7107,  6.4245,  1.4372,  1.9172], device='cuda:0')
tensor([21.001

 17%|█▋        | 137/796 [00:04<00:22, 29.88it/s]

tensor([  2.0392, -10.3904, -11.1790,  71.5210,  13.8183,  21.6405],
       device='cuda:0')
tensor([ 3.5529, -1.0213, 18.3544,  7.6650, -2.4301,  7.7454], device='cuda:0')
tensor([-10.4434,  37.4300, -36.0478,  79.6612, -11.3144,  39.9937],
       device='cuda:0')
tensor([ 3.3309,  1.6170, 23.8698, -6.2472,  2.5869,  1.8405], device='cuda:0')
tensor([  3.8418, -13.7945,   7.6080,  97.3523,   1.3675,  19.2079],
       device='cuda:0')
tensor([ 1.4804, -6.2128, 20.8861, 15.2858,  1.2804, 19.0184], device='cuda:0')
tensor([  4.1033,  10.6432, -12.5508,  98.9074,   6.2257,  39.7155],
       device='cuda:0')
tensor([ -2.2946, -18.0426,  26.2206,  13.4249,   3.1095,   1.6104],
       device='cuda:0')
tensor([ 1.1573, 22.3991,  3.3885, 86.3009,  7.2598, 22.0183], device='cuda:0')
tensor([ 3.9970, -3.3191, 19.7107,  6.4245,  1.4372,  1.9172], device='cuda:0')
tensor([ -1.8893,  17.2483, -98.1363,  76.8660,   8.2284,  21.3149],
       device='cuda:0')
tensor([ 5.1814, -0.2553, 15.9132,  7.3992

 18%|█▊        | 145/796 [00:05<00:22, 28.46it/s]

tensor([-6.1838, 22.4067, -5.5693, 74.1606, -1.8820, 28.1902], device='cuda:0')
tensor([  4.8112, -11.8298,  18.6257,  25.4763,   0.9146,   1.9939],
       device='cuda:0')
tensor([  2.7587,   8.1593, -14.5555,  54.8059, -15.6734,  17.5570],
       device='cuda:0')
tensor([ 5.6255, -7.7447, 16.6365, 10.8551,  1.2281,  4.6779], device='cuda:0')
tensor([-0.7812,  5.1306, 18.9856, 59.8732, 21.1192, 29.2093], device='cuda:0')
tensor([ 7.0318, -1.3617, 21.1573,  0.9304,  1.0191,  1.8405], device='cuda:0')
tensor([ 1.6246,  3.7598, 13.0155, 62.2156, 11.6721, 35.8000], device='cuda:0')
tensor([ 5.9956, -2.8936, 21.1573,  1.4621,  2.5346,  1.6871], device='cuda:0')
tensor([ -7.0796,  46.6800, -40.3638,  88.8872,  -8.5076,  39.7938],
       device='cuda:0')
tensor([ 3.3309,  1.6170, 23.8698, -6.2472,  2.5869,  1.8405], device='cuda:0')
tensor([14.3632, 18.3386, -8.1139, 81.7817, -1.0043, 20.6755], device='cuda:0')
tensor([ 1.7024, -6.1277, 23.3273,  2.8799,  2.1688,  6.5951], device='cuda:0')
t

 19%|█▉        | 153/796 [00:05<00:21, 29.97it/s]

tensor([ -5.2470, -22.6168, -62.6860,  84.7303,  -0.4270,  21.0451],
       device='cuda:0')
tensor([  4.8112, -11.8298,  18.6257,  25.4763,   0.9146,   1.9939],
       device='cuda:0')
tensor([ 3.5919e-02,  2.9854e+01, -6.0627e+01,  7.4869e+01,  2.9669e+01,
         2.1702e+01], device='cuda:0')
tensor([ 4.6632,  1.7021, 17.9928, -1.7280, -0.7055,  7.2853], device='cuda:0')
tensor([  1.8939,  42.1675, -50.7182,  80.9438,  -3.7868,  26.9874],
       device='cuda:0')
tensor([ 4.3671, -2.6383, 20.9765, -2.0824, -0.1307,  1.7638], device='cuda:0')
tensor([  6.1673,  -4.8704, -55.4834,  72.8581,  14.2966,  20.8302],
       device='cuda:0')
tensor([ 4.3671, -1.1064, 21.1573, -1.0191,  3.5798,  1.5337], device='cuda:0')
tensor([21.2195,  6.4844, 10.1650, 89.2181, 21.4408, 40.7875], device='cuda:0')
tensor([ 5.7735,  1.7872,  6.7812,  1.4621, -2.5869,  1.3037], device='cuda:0')
tensor([  0.8517, -27.5686,  34.3097,  80.8870,  -9.5966,  34.7334],
       device='cuda:0')
tensor([ 3.9230, -9.106

 20%|██        | 161/796 [00:05<00:21, 29.30it/s]

tensor([ 15.9114, -11.8964,  31.9150,  80.0869,  26.6155,  43.3985],
       device='cuda:0')
tensor([ 7.0318, -1.3617, 21.1573,  0.9304,  1.0191,  1.8405], device='cuda:0')
tensor([ 7.1008, 12.0362,  9.3954, 81.2401, -5.4625, 34.7359], device='cuda:0')
tensor([ 3.9230, -9.1064, 27.8481,  5.5383, -0.2874,  1.6871], device='cuda:0')
tensor([-12.4454,  32.3832, -30.5575,  97.8539, -25.0913,  38.0952],
       device='cuda:0')
tensor([ 5.6255, -7.7447, 16.6365, 10.8551,  1.2281,  4.6779], device='cuda:0')
tensor([15.4829, -9.0569, 10.3083, 99.1486, 20.9319, 29.1744], device='cuda:0')
tensor([ 5.7735,  1.7872,  6.7812,  1.4621, -2.5869,  1.3037], device='cuda:0')
tensor([-2.6778, 17.8234,  3.8385, 90.8160,  1.1143, 21.2556], device='cuda:0')
tensor([ -2.2946, -18.0426,  26.2206,  13.4249,   3.1095,   1.6104],
       device='cuda:0')
tensor([  1.6765,  40.9731, -48.7828,  80.9007,  -3.7865,  26.9873],
       device='cuda:0')
tensor([ 4.3671, -2.6383, 20.9765, -2.0824, -0.1307,  1.7638], devic

 21%|██        | 169/796 [00:06<00:20, 30.13it/s]

tensor([-3.5843e-02,  1.0657e+01, -2.8488e+01,  9.8716e+01,  4.8324e+00,
         3.9343e+01], device='cuda:0')
tensor([ -2.2946, -18.0426,  26.2206,  13.4249,   3.1095,   1.6104],
       device='cuda:0')
tensor([  5.7257,  57.8247, -59.5908,  83.7421,  -1.6605,  24.0743],
       device='cuda:0')
tensor([ 4.3671, -2.6383, 20.9765, -2.0824, -0.1307,  1.7638], device='cuda:0')
tensor([ -7.8420,   1.8588, -62.2526,  94.2216,  -0.4850,  20.9540],
       device='cuda:0')
tensor([  4.8112, -11.8298,  18.6257,  25.4763,   0.9146,   1.9939],
       device='cuda:0')
tensor([ 6.1070, 27.1122, -2.8942, 66.5184, -0.6962, 41.4810], device='cuda:0')
tensor([ 1.3324, 10.6383,  6.4195, 11.0323,  0.9668, 17.3313], device='cuda:0')
tensor([14.9656, 11.4859, 12.1454, 66.5811, -2.9852, 33.5347], device='cuda:0')
tensor([  1.4064, -11.2340,  35.2622,  -3.3230,   4.3115,   1.9172],
       device='cuda:0')
tensor([-1.0867,  6.7814, -8.3969, 78.0996,  5.7008, 31.5032], device='cuda:0')
tensor([ 7.8460, -6.638

 22%|██▏       | 177/796 [00:06<00:20, 30.84it/s]

tensor([19.7913, 15.6531, -5.4578, 74.4344, -1.7999, 28.1671], device='cuda:0')
tensor([  4.8112, -11.8298,  18.6257,  25.4763,   0.9146,   1.9939],
       device='cuda:0')
tensor([  6.0072,  -6.3834, -64.0167,  72.8420,  14.2948,  20.9610],
       device='cuda:0')
tensor([ 4.3671, -1.1064, 21.1573, -1.0191,  3.5798,  1.5337], device='cuda:0')
tensor([ 1.4225, 12.2775,  7.4304,  9.5831,  0.8536, 16.5287], device='cuda:0')
tensor([ 1.3324, 10.6383,  6.4195, 11.0323,  0.9668, 17.3313], device='cuda:0')
tensor([13.4764,  4.4790, -9.2796, 88.8880, 12.2580, 28.5160], device='cuda:0')
tensor([  3.9230, -13.3617,  16.3653,   7.1334,  -2.1688,  12.4233],
       device='cuda:0')
tensor([-3.5271,  6.1665, -9.8211, 83.7295,  5.7256, 37.0016], device='cuda:0')
tensor([ 7.8460, -6.6383, 20.0723, 11.8299,  3.0572,  2.7607], device='cuda:0')
tensor([14.0265, 14.1103, -5.9985, 74.2340, -6.0952, 28.2156], device='cuda:0')
tensor([  4.8112, -11.8298,  18.6257,  25.4763,   0.9146,   1.9939],
       devic

 23%|██▎       | 185/796 [00:06<00:20, 29.83it/s]

tensor([  0.1585,  37.2273, -26.1043,  84.5829,  -1.6853,  21.0998],
       device='cuda:0')
tensor([ 2.8127, -2.5532, 14.0145, 16.9694,  2.6914, 25.0000], device='cuda:0')
tensor([-0.8842,  3.9905, 17.1386, 60.8792, 17.7025, 46.7602], device='cuda:0')
tensor([ 6.4397, -2.9787, 23.9602,  3.6775,  1.3849,  1.3804], device='cuda:0')
tensor([ 4.9624,  8.4019,  7.0322, 64.1171, -2.0388, 41.4343], device='cuda:0')
tensor([ 1.3324, 10.6383,  6.4195, 11.0323,  0.9668, 17.3313], device='cuda:0')
tensor([ 5.2028, -4.9355, 19.0591,  8.3721, -0.2175,  7.1489], device='cuda:0')
tensor([ 4.8853, -7.6596, 17.8119,  8.5512, -0.2352,  6.9785], device='cuda:0')
tensor([ 5.1510,  0.1343, -0.1747, 55.9759,  0.8410, 23.5770], device='cuda:0')
tensor([ 7.4760,  1.1064, 13.0199,  9.1715,  0.1307,  1.6104], device='cuda:0')
tensor([ -6.8685,  46.6632, -40.4901,  88.6407,  -8.4602,  23.4892],
       device='cuda:0')
tensor([ 3.3309,  1.6170, 23.8698, -6.2472,  2.5869,  1.8405], device='cuda:0')
tensor([13.174

 24%|██▍       | 193/796 [00:06<00:20, 29.23it/s]

tensor([ 6.3004e-02,  1.2966e-01, -1.2240e+01,  7.7055e+01,  3.2700e+00,
         2.4280e+01], device='cuda:0')
tensor([ 7.4760,  1.1064, 13.0199,  9.1715,  0.1307,  1.6104], device='cuda:0')
tensor([-2.5979e+00, -4.4832e+00,  1.1694e+01,  8.8825e+01,  4.1237e-02,
         3.6652e+01], device='cuda:0')
tensor([ 1.4804, -6.2128, 20.8861, 15.2858,  1.2804, 19.0184], device='cuda:0')
tensor([  7.7034, -24.4573, -19.2153,  68.4640,   6.8934,  20.9618],
       device='cuda:0')
tensor([ 7.4760,  1.1064, 13.0199,  9.1715,  0.1307,  1.6104], device='cuda:0')
tensor([  7.0689,  64.7530, -73.1776,  81.7090,   2.7951,  35.9874],
       device='cuda:0')
tensor([-0.3701, -8.4255, 21.1573, -4.6522,  0.3397,  1.7638], device='cuda:0')
tensor([-10.0539,  16.1244,  -7.2583,  79.5745, -10.5282,  39.9718],
       device='cuda:0')
tensor([ 3.3309,  1.6170, 23.8698, -6.2472,  2.5869,  1.8405], device='cuda:0')
tensor([  1.4169,   9.6965, -21.2975,  65.1204,  -5.3341,  16.9564],
       device='cuda:0')
tens

 25%|██▌       | 201/796 [00:07<00:21, 27.34it/s]

tensor([-0.4315, -9.6893, 34.2598, 14.5525,  4.3105, 12.5403], device='cuda:0')
tensor([  1.4064, -11.2340,  35.2622,  -3.3230,   4.3115,   1.9172],
       device='cuda:0')
tensor([ -4.2825, -16.4948,  29.8565,  75.7990,  -9.7803,  34.7335],
       device='cuda:0')
tensor([ 3.9230, -9.1064, 27.8481,  5.5383, -0.2874,  1.6871], device='cuda:0')
tensor([  4.1151,  21.4608, -97.4190,  70.0406, -11.7525,  21.2866],
       device='cuda:0')
tensor([ 4.8853, -7.6596, 17.8119,  8.5512, -0.2352,  6.9785], device='cuda:0')
tensor([ 11.8364,  -8.3409, -22.0976,  82.1048,   0.3622,  20.5925],
       device='cuda:0')
tensor([ 11.8431, -14.0425, -26.5823,  82.5432,   0.3397,   4.9847],
       device='cuda:0')
tensor([-30.3421,  24.6781,  -4.4784,  98.5236,   5.7648,  21.7637],
       device='cuda:0')
tensor([ 1.9985, -8.1702, 10.5787, 24.1471, -1.5939, 20.3221], device='cuda:0')
tensor([-2.9764, 27.4591, -9.4806, 60.5014, -7.5281, 48.0980], device='cuda:0')
tensor([ 4.2931, -1.9574, 19.3490,  4.0319

 26%|██▋       | 209/796 [00:07<00:19, 30.11it/s]

tensor([18.2723, 17.7334,  5.8526, 76.2547, 28.8384, 35.4059], device='cuda:0')
tensor([ 4.3671, -1.1064, 21.1573, -1.0191,  3.5798,  1.5337], device='cuda:0')
tensor([13.8377, 14.0858,  4.8394, 68.9339, 23.1493, 35.1931], device='cuda:0')
tensor([ 4.3671, -1.1064, 21.1573, -1.0191,  3.5798,  1.5337], device='cuda:0')
tensor([ -2.3549, -42.5859,  34.3616,  79.3928,   6.7270,  34.9443],
       device='cuda:0')
tensor([ 3.9230, -9.1064, 27.8481,  5.5383, -0.2874,  1.6871], device='cuda:0')
tensor([  0.0988, -27.5700,  34.2921,  80.7467,  -9.5939,  34.7334],
       device='cuda:0')
tensor([ 3.9230, -9.1064, 27.8481,  5.5383, -0.2874,  1.6871], device='cuda:0')
tensor([  1.2206,  28.2408, -28.8071,  80.4082,  -0.1858,  26.9745],
       device='cuda:0')
tensor([ 4.3671, -2.6383, 20.9765, -2.0824, -0.1307,  1.7638], device='cuda:0')
tensor([ 2.9005, 22.8547,  0.6875, 65.8829, 16.8503, 35.6528], device='cuda:0')
tensor([ 5.9956, -2.8936, 21.1573,  1.4621,  2.5346,  1.6871], device='cuda:0')
t

 27%|██▋       | 217/796 [00:07<00:18, 31.43it/s]

tensor([-1.1155, -1.4618, 11.3439, 87.3437,  0.6764, 36.6521], device='cuda:0')
tensor([ 1.4804, -6.2128, 20.8861, 15.2858,  1.2804, 19.0184], device='cuda:0')
tensor([ 7.3352, 28.0609, -9.0720, 72.1564, 14.7335, 43.7811], device='cuda:0')
tensor([ 3.5529, -1.0213, 18.3544,  7.6650, -2.4301,  7.7454], device='cuda:0')
tensor([ 15.9732, -15.1770,  22.6696,  80.5443,  22.5891,  43.3938],
       device='cuda:0')
tensor([ 7.0318, -1.3617, 21.1573,  0.9304,  1.0191,  1.8405], device='cuda:0')
tensor([ -0.7603,  -0.4464, -15.8213,  80.5526,   3.3195,  24.3008],
       device='cuda:0')
tensor([ 7.4760,  1.1064, 13.0199,  9.1715,  0.1307,  1.6104], device='cuda:0')
tensor([ 9.5323,  1.0387, 31.5627, 68.7049, 22.5441, 43.3978], device='cuda:0')
tensor([ 7.0318, -1.3617, 21.1573,  0.9304,  1.0191,  1.8405], device='cuda:0')
tensor([ 1.9266,  1.2886,  4.2886, 72.0012, -6.8725, 32.6423], device='cuda:0')
tensor([ 1.7024, -6.1277, 23.3273,  2.8799,  2.1688,  6.5951], device='cuda:0')
tensor([ -3.40

 28%|██▊       | 225/796 [00:07<00:18, 30.82it/s]

tensor([ 2.6832, -4.1264, 14.5555, 22.8764, -0.3949,  4.9732], device='cuda:0')
tensor([ 5.6255, -7.7447, 16.6365, 10.8551,  1.2281,  4.6779], device='cuda:0')
tensor([ 4.2716,  6.8908, -5.9347, 74.2876, -7.6695, 28.1896], device='cuda:0')
tensor([  4.8112, -11.8298,  18.6257,  25.4763,   0.9146,   1.9939],
       device='cuda:0')
tensor([ -2.0630,  50.9954, -47.1789,  90.3056,   7.8814,  37.3602],
       device='cuda:0')
tensor([ 5.1814, -0.2553, 15.9132,  7.3992,  2.0643,  1.3804], device='cuda:0')
tensor([14.4190, 21.4166,  1.5465, 81.2857, -1.0357, 21.0651], device='cuda:0')
tensor([ 1.7024, -6.1277, 23.3273,  2.8799,  2.1688,  6.5951], device='cuda:0')
tensor([  4.2510,  21.1992, -97.3050,  65.2462, -11.7557,  21.2225],
       device='cuda:0')
tensor([ 4.8853, -7.6596, 17.8119,  8.5512, -0.2352,  6.9785], device='cuda:0')
tensor([ 7.2268, 10.5421,  3.7031, 65.4503, 13.8410, 43.8571], device='cuda:0')
tensor([ 3.5529, -1.0213, 18.3544,  7.6650, -2.4301,  7.7454], device='cuda:0')
t

 29%|██▉       | 229/796 [00:08<00:20, 27.44it/s]

tensor([ 3.9230, -9.1064, 27.8481,  5.5383, -0.2874,  1.6871], device='cuda:0')
tensor([-19.1848,  58.6098, -61.4896,  95.2041,  -2.2675,  37.9562],
       device='cuda:0')
tensor([ 5.6255, -7.7447, 16.6365, 10.8551,  1.2281,  4.6779], device='cuda:0')
tensor([ -1.4015,  24.4290, -97.0991,  72.7024,  -1.6374,  20.5651],
       device='cuda:0')
tensor([ 3.3309,  1.6170, 23.8698, -6.2472,  2.5869,  1.8405], device='cuda:0')
tensor([  8.2070,  58.1620, -73.1972,  92.6104,   2.5984,  37.1098],
       device='cuda:0')
tensor([-0.3701, -8.4255, 21.1573, -4.6522,  0.3397,  1.7638], device='cuda:0')
tensor([ 5.9143,  9.3271, -4.2240, 79.9265,  4.0820, 20.6911], device='cuda:0')
tensor([ 1.4804, -6.2128, 20.8861, 15.2858,  1.2804, 19.0184], device='cuda:0')
tensor([ 0.8326, 19.6716,  2.4669, 82.4226,  2.7512, 30.8254], device='cuda:0')
tensor([ 1.0363, 18.8936,  1.4467, 82.2773,  2.7959, 30.8282], device='cuda:0')


 30%|██▉       | 237/796 [00:08<00:21, 26.01it/s]

tensor([ -9.8423,  51.2355, -59.8136,  92.3429,  -3.6063,  21.4180],
       device='cuda:0')
tensor([ 5.6255, -7.7447, 16.6365, 10.8551,  1.2281,  4.6779], device='cuda:0')
tensor([  3.5704,  23.4541, -96.2826,  74.7211, -11.7504,  21.3213],
       device='cuda:0')
tensor([ 4.8853, -7.6596, 17.8119,  8.5512, -0.2352,  6.9785], device='cuda:0')
tensor([  4.0417,   7.8022,   6.8246,  50.8069, -16.4775,  27.6665],
       device='cuda:0')
tensor([ 4.2931, -1.9574, 19.3490,  4.0319,  1.6985,  1.7638], device='cuda:0')
tensor([-18.3592,   3.4708, -28.0376,  98.2355,   5.8773,  21.5782],
       device='cuda:0')
tensor([ 1.9985, -8.1702, 10.5787, 24.1471, -1.5939, 20.3221], device='cuda:0')
tensor([ 2.1060,  0.9291, -5.9062, 68.4168, -4.8007, 27.5425], device='cuda:0')
tensor([  4.8112, -11.8298,  18.6257,  25.4763,   0.9146,   1.9939],
       device='cuda:0')
tensor([ -7.6668,  47.8941, -39.1072,  98.4034,  -0.6235,   3.8666],
       device='cuda:0')
tensor([ 1.9985, -8.1702, 10.5787, 24.1471

 31%|███       | 244/796 [00:08<00:19, 28.21it/s]


tensor([-5.0931, -4.2139, -2.2265, 85.2428,  4.5602, 42.3491], device='cuda:0')
tensor([ 7.8460, -6.6383, 20.0723, 11.8299,  3.0572,  2.7607], device='cuda:0')
tensor([-1.7701, -8.6147, 17.6911, 34.6983,  5.5164, 24.2159], device='cuda:0')
tensor([ -0.5922, -10.2979,  16.4557,  22.6407,   3.5276,  24.2331],
       device='cuda:0')
tensor([  9.3496, -23.8487,  -2.6055,  65.6143,  18.6775,  21.0752],
       device='cuda:0')
tensor([ 7.0318, -1.3617, 21.1573,  0.9304,  1.0191,  1.8405], device='cuda:0')
tensor([ -6.6634, -17.4313,  41.6454,  72.0716,  10.3380,  42.7846],
       device='cuda:0')
tensor([ 7.8460, -6.6383, 20.0723, 11.8299,  3.0572,  2.7607], device='cuda:0')


KeyboardInterrupt: 

Exception in thread Thread-18 (_pin_memory_loop):
Traceback (most recent call last):
  File "/home/may33/miniconda3/envs/isaac/lib/python3.11/threading.py", line 1045, in _bootstrap_inner
    self.run()
  File "/home/may33/miniconda3/envs/isaac/lib/python3.11/threading.py", line 982, in run
    self._target(*self._args, **self._kwargs)
  File "/home/may33/miniconda3/envs/isaac/lib/python3.11/site-packages/torch/utils/data/_utils/pin_memory.py", line 61, in _pin_memory_loop
    do_one_step()
  File "/home/may33/miniconda3/envs/isaac/lib/python3.11/site-packages/torch/utils/data/_utils/pin_memory.py", line 37, in do_one_step
    r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/may33/miniconda3/envs/isaac/lib/python3.11/multiprocessing/queues.py", line 122, in get


    return _ForkingPickler.loads(res)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/may33/miniconda3/envs/isaac/lib/python3.11/site-packages/torch/multiprocessing/reductions.py", line 541, in rebuild_storage_fd
    fd = df.detach()
         ^^^^^^^^^^^
  File "/home/may33/miniconda3/envs/isaac/lib/python3.11/multiprocessing/resource_sharer.py", line 57, in detach
    with _resource_sharer.get_connection(self._id) as conn:
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/may33/miniconda3/envs/isaac/lib/python3.11/multiprocessing/resource_sharer.py", line 86, in get_connection
    c = Client(address, authkey=process.current_process().authkey)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/may33/miniconda3/envs/isaac/lib/python3.11/multiprocessing/connection.py", line 519, in Client
    c = SocketClient(address)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/home/may33/miniconda3/envs/isaac/lib/python3.11/multiprocessing/connection.py", line

In [None]:

        states = batch["observation.state"].to(device)

        actions = batch["action"].to(device)

        # print(states)
        # print(actions)

        y_pred = model(states)
        loss   = criterion(y_pred, actions) 
        total_loss += loss.item() 
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        b += 1 
        
        loss_list.append(loss.item())
        
    print(f'Epoch {itr} Loss: {total_loss/b:.4f}')

In [None]:
from pathlib import Path

modelpath = "/home/may33/projects/robotics/smolvla_in_isaac/training/checkpoints/mlpmodel.pth"

In [None]:
torch.save({
    'model_state_dict': model.state_dict(),
      'optimizer_state_dict': optimizer.state_dict(),
      'model_config': {
          'inputdim': 6,
          'size': 32,
          'actiondim': 6
      },
      'loss_list': loss_list,
      'n_epoch': n_epoch,
}, modelpath)

print(f"Model saved to {modelpath}")

Model saved to /home/may33/projects/robotics/smolvla_in_isaac/training/checkpoints/mlpmodel.pth
