In [1]:
import torch
import json
import torch
import os
import numpy as np
from PIL import Image
from tqdm.notebook import tqdm
import random
import math
from torch.utils.data import Dataset
from torchvision import datasets
import torchvision
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import clip

In [2]:
dataset_root = '/local/vondrick/ruoshi/objaverse/views_whole_sphere'

# dataset

In [3]:
class objaverse_sfm(Dataset):
    def __init__(self, root_dir, total_view, train=True, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        with open(os.path.join(root_dir, 'valid_paths.json')) as f:
            self.paths = json.load(f)import clip
        random.shuffle(self.paths)
        self.total_view = total_view
        self.train = train
        total_objects = len(self.paths)
        if train:
            self.paths = self.paths[:math.floor(total_objects / 100. * 99.)] # used first 99% as training|
        else:
            self.paths = self.paths[math.floor(total_objects / 100. * 99.):] # used last 1% as validation
        
    def __len__(self):
        return len(self.img_labels)

    def __len__(self):
        return len(self.paths)
        
    def cartesian_to_spherical(self, xyz):
        ptsnew = np.hstack((xyz, np.zeros(xyz.shape)))
        xy = xyz[:,0]**2 + xyz[:,1]**2
        z = np.sqrt(xy + xyz[:,2]**2)
        theta = np.arctan2(np.sqrt(xy), xyz[:,2]) # for elevation angle defined from Z-axis down
        #ptsnew[:,4] = np.arctan2(xyz[:,2], np.sqrt(xy)) # for elevation angle defined from XY-plane up
        azimuth = np.arctan2(xyz[:,1], xyz[:,0])
        return np.array([theta, azimuth, z])

    def get_T(self, target_RT, cond_RT):
        R, T = target_RT[:3, :3], target_RT[:, -1]
        T_target = -R.T @ T

        R, T = cond_RT[:3, :3], cond_RT[:, -1]
        T_cond = -R.T @ T

        theta_cond, azimuth_cond, z_cond = self.cartesian_to_spherical(T_cond[None, :])
        theta_target, azimuth_target, z_target = self.cartesian_to_spherical(T_target[None, :])
        
        d_theta = theta_target - theta_cond
        d_azimuth = (azimuth_target - azimuth_cond) % (2 * math.pi)
        d_z = z_target - z_cond
        
        d_T = torch.tensor([d_theta.item(), math.sin(d_azimuth.item()), math.cos(d_azimuth.item()), d_z.item()])
#         d_T = torch.tensor([d_theta.item(), d_azimuth.item(), d_z.item()])
        return d_T

    def load_im(self, path):
        '''
        replace background pixel with white in rendering
        '''
        img = plt.imread(path)
        img[img[:, :, -1] == 0.] = [1., 1., 1., 1.]
        img = Image.fromarray(np.uint8(img[:, :, :3] * 255.))
        return img

    def __getitem__(self, index):
        data = {}
        index_target, index_cond = random.sample(range(self.total_view), 2) # without replacement
        filename = os.path.join(self.root_dir, self.paths[index])

        # print(self.paths[index])

#         if self.return_paths:
#             data["path"] = str(filename)
            
        target_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_target)))
        target_RT = np.load(os.path.join(filename, '%03d.npy' % index_target))
        cond_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_cond)))
        cond_RT = np.load(os.path.join(filename, '%03d.npy' % index_cond))

        data["image_target"] = target_im
        data["image_cond"] = cond_im
        data["T"] = self.get_T(target_RT, cond_RT)
        return data
    
    def process_im(self, im):
        im = im.convert("RGB")
        im = self.transform(im)
        im = torchvision.transforms.functional.resize(im, 224)
        im = im * 2. - 1.
        return im

SyntaxError: invalid syntax (<ipython-input-3-9a652f559f71>, line 6)

In [4]:
dataset = objaverse_sfm(dataset_root, 4, train=True, transform = ToTensor())

In [5]:
dataset[0]['image_target'].shape

torch.Size([3, 224, 224])

In [30]:
train_dataloader = DataLoader(objaverse_sfm(dataset_root, 4, train=True, transform = ToTensor()),\
                              batch_size=16, shuffle=True, num_workers=4)
test_dataloader = DataLoader(objaverse_sfm(dataset_root, 4, train=False, transform = ToTensor()),\
                             batch_size=16, shuffle=False, num_workers=4)

In [31]:
batch = next(iter(train_dataloader))

In [8]:
batch['image_target'].shape

torch.Size([16, 3, 224, 224])

In [76]:
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=False)

Using cache found in /home/rliu/.cache/torch/hub/pytorch_vision_v0.10.0


In [77]:
model.layer3 = torch.nn.Sequential()
model.layer4 = torch.nn.Sequential()
model.fc = torch.nn.Sequential()
model.avgpool = torch.nn.Sequential()

In [78]:
feature = model(image_cond.cpu())

In [79]:
feature.shape

torch.Size([64, 100352])

In [48]:
proj_layer(feature.reshape([64, 128, 28, 28])).shape

torch.Size([64, 64, 14, 14])

In [38]:
proj_layer = torch.nn.Conv2d(256, 32, 1)

# model

In [None]:
class sfm(torch.nn.Module):
    def __init__(self, resnet=True):
        super(sfm, self).__init__()
        
        if resnet:
            print('initializing resnet weights')
            self.cnn = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
            self.cnn.conv1 = torch.nn.Conv2d(6, 64, kernel_size=7, stride=2, padding=3, bias=False)
            self.cnn.layer3 = torch.nn.Sequential()
            self.cnn.layer4 = torch.nn.Sequential()
            self.cnn.avgpool = torch.nn.Sequential()
            self.cnn.fc = torch.nn.Sequential()
            self.proj_layer = torch.nn.Conv2d(128, 16, 1, bias=False)
            self.linear1 = torch.nn.Linear(12544, 64)
            self.activation = torch.nn.ReLU()
            self.linear2 = torch.nn.Linear(64, 4)
    
        else:
            print('initializing vit weights')
            self.cond_net = torchvision.models.vit_b_16(torchvision.models.ViT_B_16_Weights.DEFAULT)
            self.cond_net.heads.head = torch.nn.Sequential()
            self.target_net = torchvision.models.vit_b_16(torchvision.models.ViT_B_16_Weights.DEFAULT)
            self.target_net.heads.head = torch.nn.Sequential()
            self.linear1 = torch.nn.Linear(768 * 2, 768 * 2)
            self.activation = torch.nn.ReLU()
            self.linear2 = torch.nn.Linear(768 * 2, 768 * 2)
            self.activation = torch.nn.ReLU()
            self.linear3 = torch.nn.Linear(768 * 2, 4)

    def forward(self, cond, target):
        B = cond.shape[0]
        x = torch.cat([cond, target], dim=1)
        x = self.cnn(x)
        x = self.proj_layer(x.reshape([B, 128, 28, 28]))
        x = self.linear1(x.reshape(B, -1))
        x = self.activation(x)
        x = self.linear2(x)
        return x

In [94]:
model = sfm(True).to(device)

initializing resnet weights


Using cache found in /home/rliu/.cache/torch/hub/pytorch_vision_v0.10.0


In [95]:
image_target, image_cond, T = batch['image_target'].to(device), batch['image_cond'].to(device), batch['T'].to(device)

In [96]:
image_target.shape, image_cond.shape, T.shape

(torch.Size([64, 3, 224, 224]),
 torch.Size([64, 3, 224, 224]),
 torch.Size([64, 4]))

In [97]:
pred = model(image_cond, image_target)

In [98]:
loss = ((pred - T) ** 2).mean()

In [99]:
pred.shape

torch.Size([64, 4])

In [16]:
# train_dataloader = DataLoader(objaverse_sfm(dataset_root, 4, train=True, transform = ToTensor()),\
#                               batch_size=64, shuffle=True, num_workers=8)
# test_dataloader = DataLoader(objaverse_sfm(dataset_root, 4, train=False, transform = ToTensor()),\
#                              batch_size=64, shuffle=False, num_workers=8)
# azimuth = []
# sin = []
# cos = []
# for i, batch in tqdm(enumerate(train_dataloader, 0), total=50):
#     sin.append(batch['T'][:, 1])
#     cos.append(batch['T'][:, 2])
#     if i == 50:
#         break

In [17]:
# plt.hist(torch.cat(azimuth).numpy(), bins=40)

# training

## unzip files

In [18]:
# from sh import gunzip
# import json
# from tqdm.notebook import tqdm
# import os

# path = '/home/rliu/Desktop/cvfiler04/datasets/objaverse/hf-objaverse-v1/views_whole_sphere'
# with open('/home/rliu/Desktop/cvfiler04/datasets/objaverse/hf-objaverse-v1/views_whole_sphere/valid_paths.json') as f:
#     paths = json.load(f)
# total_view = 4
# gz = []
# for i, id in tqdm(enumerate(paths), total = len(paths)):
#     for filename in os.listdir(os.path.join(path, id)):
        
#         if 'gz' in filename:
# #             print(filename[:-3])
#             gz.append(id)
# #             gunzip(os.path.join(path, id, filename[:-3]))

In [100]:
device = 6

In [104]:
train_set = objaverse_sfm(dataset_root, 4, train=True, transform = ToTensor())
train_dataloader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=64, drop_last=False)

test_set = objaverse_sfm(dataset_root, 4, train=False, transform = ToTensor())
test_dataloader = DataLoader(test_set, batch_size=128, shuffle=False, num_workers=64, drop_last=False)

model = sfm(resnet=True).to(device)

print('\n================== total trainable parameters: %d ==================\n' % sum(p.numel() for p in model.parameters() if p.requires_grad))
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, param.shape)

# model = torch.nn.DataParallel(model, device_ids=[6, 7])
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.1)

for epoch in range(50):  # loop over the dataset multiple times

    print('validation')
    model.eval()
    test_loss = torch.zeros(4).to(device)
    with torch.no_grad():
        for i, batch in tqdm(enumerate(test_dataloader, 0), total=len(test_dataloader)):
            # get the inputs; data is a list of [inputs, labels]
            image_target, image_cond, T_gt = \
                batch['image_target'].to(device), batch['image_cond'].to(device), batch['T'].to(device)

            # forward + backward + optimize
            T_pred = model(image_cond, image_target)
            loss = (T_pred - T_gt).abs()
            test_loss += loss.sum(dim=0)
    
    print('validation loss for epoch %d:' % epoch, test_loss.cpu() / len(test_set), test_loss.cpu().sum().item() / len(test_set))
    
    running_loss = torch.zeros(4).to(device)
    
    model.train()
    
    for i, batch in tqdm(enumerate(train_dataloader, 0), total=len(train_dataloader)):
        # get the inputs; data is a list of [inputs, labels]
        image_target, image_cond, T_gt = \
            batch['image_target'].to(device), batch['image_cond'].to(device), batch['T'].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        T_pred = model(image_cond, image_target)
        loss = ((T_pred - T_gt)**2).mean(dim=0)
        l1_loss = ((T_pred - T_gt).abs()).mean(dim=0)
        running_loss += l1_loss.detach()
        loss = loss.sum()
        loss.backward()
        optimizer.step()

        # print statistics
        if i % 200 == 199:    # print every 2000 mini-batches
            print(f'[{i-199}, {i+1:d}] loss: ', running_loss.detach().cpu() / 200, \
                  running_loss.detach().cpu().sum() / 200)
            running_loss = torch.zeros(4).to(device)

            

initializing resnet weights


Using cache found in /home/rliu/.cache/torch/hub/pytorch_vision_v0.10.0




cnn.conv1.weight torch.Size([64, 6, 7, 7])
cnn.bn1.weight torch.Size([64])
cnn.bn1.bias torch.Size([64])
cnn.layer1.0.conv1.weight torch.Size([64, 64, 3, 3])
cnn.layer1.0.bn1.weight torch.Size([64])
cnn.layer1.0.bn1.bias torch.Size([64])
cnn.layer1.0.conv2.weight torch.Size([64, 64, 3, 3])
cnn.layer1.0.bn2.weight torch.Size([64])
cnn.layer1.0.bn2.bias torch.Size([64])
cnn.layer1.1.conv1.weight torch.Size([64, 64, 3, 3])
cnn.layer1.1.bn1.weight torch.Size([64])
cnn.layer1.1.bn1.bias torch.Size([64])
cnn.layer1.1.conv2.weight torch.Size([64, 64, 3, 3])
cnn.layer1.1.bn2.weight torch.Size([64])
cnn.layer1.1.bn2.bias torch.Size([64])
cnn.layer2.0.conv1.weight torch.Size([128, 64, 3, 3])
cnn.layer2.0.bn1.weight torch.Size([128])
cnn.layer2.0.bn1.bias torch.Size([128])
cnn.layer2.0.conv2.weight torch.Size([128, 128, 3, 3])
cnn.layer2.0.bn2.weight torch.Size([128])
cnn.layer2.0.bn2.bias torch.Size([128])
cnn.layer2.0.downsample.0.weight torch.Size([128, 64, 1, 1])
cnn.layer2.0.downsample.1.w

  0%|          | 0/63 [00:00<?, ?it/s]

validation loss for epoch 0: tensor([0.7636, 0.6330, 0.6397, 0.2344]) 2.270752014297795


  0%|          | 0/6174 [00:00<?, ?it/s]

[0, 200] loss:  tensor([0.5953, 0.6398, 0.6284, 0.1781]) tensor(2.0416)
[200, 400] loss:  tensor([0.5148, 0.6355, 0.6172, 0.1466]) tensor(1.9140)
[400, 600] loss:  tensor([0.4828, 0.6350, 0.6103, 0.1360]) tensor(1.8641)
[600, 800] loss:  tensor([0.4695, 0.6313, 0.6047, 0.1338]) tensor(1.8392)
[800, 1000] loss:  tensor([0.4560, 0.6326, 0.5991, 0.1297]) tensor(1.8174)
[1000, 1200] loss:  tensor([0.4477, 0.6293, 0.5951, 0.1281]) tensor(1.8002)
[1200, 1400] loss:  tensor([0.4383, 0.6262, 0.5950, 0.1265]) tensor(1.7860)
[1400, 1600] loss:  tensor([0.4358, 0.6321, 0.5838, 0.1255]) tensor(1.7771)
[1600, 1800] loss:  tensor([0.4374, 0.6255, 0.5828, 0.1229]) tensor(1.7687)
[1800, 2000] loss:  tensor([0.4258, 0.6270, 0.5788, 0.1228]) tensor(1.7544)
[2000, 2200] loss:  tensor([0.4223, 0.6286, 0.5742, 0.1212]) tensor(1.7462)
[2200, 2400] loss:  tensor([0.4162, 0.6238, 0.5681, 0.1213]) tensor(1.7294)
[2400, 2600] loss:  tensor([0.4126, 0.6215, 0.5661, 0.1203]) tensor(1.7204)
[2600, 2800] loss:  ten

  0%|          | 0/63 [00:00<?, ?it/s]

validation loss for epoch 1: tensor([0.4195, 0.6087, 0.5526, 0.2309]) 1.811702918480644


  0%|          | 0/6174 [00:00<?, ?it/s]

[0, 200] loss:  tensor([0.3850, 0.6010, 0.5437, 0.1126]) tensor(1.6423)
[200, 400] loss:  tensor([0.3866, 0.6002, 0.5440, 0.1131]) tensor(1.6438)
[400, 600] loss:  tensor([0.3775, 0.6030, 0.5388, 0.1130]) tensor(1.6323)
[600, 800] loss:  tensor([0.3840, 0.5988, 0.5410, 0.1127]) tensor(1.6364)
[800, 1000] loss:  tensor([0.3812, 0.6003, 0.5376, 0.1117]) tensor(1.6308)
[1000, 1200] loss:  tensor([0.3830, 0.5940, 0.5420, 0.1122]) tensor(1.6313)
[1200, 1400] loss:  tensor([0.3773, 0.6004, 0.5339, 0.1112]) tensor(1.6227)
[1400, 1600] loss:  tensor([0.3777, 0.5970, 0.5390, 0.1112]) tensor(1.6248)
[1600, 1800] loss:  tensor([0.3753, 0.5954, 0.5376, 0.1107]) tensor(1.6189)
[1800, 2000] loss:  tensor([0.3797, 0.5929, 0.5349, 0.1112]) tensor(1.6187)
[2000, 2200] loss:  tensor([0.3777, 0.5925, 0.5380, 0.1116]) tensor(1.6197)
[2200, 2400] loss:  tensor([0.3718, 0.5960, 0.5295, 0.1103]) tensor(1.6075)
[2400, 2600] loss:  tensor([0.3739, 0.5967, 0.5301, 0.1109]) tensor(1.6116)
[2600, 2800] loss:  ten

  0%|          | 0/63 [00:00<?, ?it/s]

validation loss for epoch 2: tensor([0.4877, 0.6083, 0.5486, 0.5075]) 2.1520391094807065


  0%|          | 0/6174 [00:00<?, ?it/s]

[0, 200] loss:  tensor([0.3569, 0.5815, 0.5173, 0.1075]) tensor(1.5632)
[200, 400] loss:  tensor([0.3636, 0.5806, 0.5212, 0.1075]) tensor(1.5729)
[400, 600] loss:  tensor([0.3611, 0.5778, 0.5191, 0.1079]) tensor(1.5659)
[600, 800] loss:  tensor([0.3611, 0.5778, 0.5184, 0.1077]) tensor(1.5651)
[800, 1000] loss:  tensor([0.3601, 0.5769, 0.5180, 0.1089]) tensor(1.5639)
[1000, 1200] loss:  tensor([0.3602, 0.5765, 0.5193, 0.1073]) tensor(1.5632)
[1200, 1400] loss:  tensor([0.3605, 0.5797, 0.5212, 0.1077]) tensor(1.5691)
[1400, 1600] loss:  tensor([0.3590, 0.5776, 0.5180, 0.1074]) tensor(1.5619)
[1600, 1800] loss:  tensor([0.3642, 0.5786, 0.5194, 0.1082]) tensor(1.5704)
[1800, 2000] loss:  tensor([0.3600, 0.5774, 0.5192, 0.1075]) tensor(1.5642)
[2000, 2200] loss:  tensor([0.3562, 0.5806, 0.5186, 0.1075]) tensor(1.5630)
[2200, 2400] loss:  tensor([0.3575, 0.5752, 0.5201, 0.1072]) tensor(1.5599)
[2400, 2600] loss:  tensor([0.3575, 0.5794, 0.5162, 0.1076]) tensor(1.5607)
[2600, 2800] loss:  ten

  0%|          | 0/63 [00:00<?, ?it/s]

validation loss for epoch 3: tensor([0.5996, 0.6095, 0.5509, 0.5984]) 2.358358797528815


  0%|          | 0/6174 [00:00<?, ?it/s]

[0, 200] loss:  tensor([0.3499, 0.5658, 0.5133, 0.1057]) tensor(1.5347)
[200, 400] loss:  tensor([0.3494, 0.5659, 0.5086, 0.1069]) tensor(1.5308)
[400, 600] loss:  tensor([0.3466, 0.5689, 0.5102, 0.1050]) tensor(1.5306)
[600, 800] loss:  tensor([0.3520, 0.5691, 0.5068, 0.1061]) tensor(1.5340)
[800, 1000] loss:  tensor([0.3481, 0.5686, 0.5061, 0.1049]) tensor(1.5277)
[1000, 1200] loss:  tensor([0.3474, 0.5681, 0.5047, 0.1053]) tensor(1.5255)
[1200, 1400] loss:  tensor([0.3495, 0.5670, 0.5100, 0.1060]) tensor(1.5325)
[1400, 1600] loss:  tensor([0.3460, 0.5627, 0.5085, 0.1059]) tensor(1.5231)
[1600, 1800] loss:  tensor([0.3412, 0.5601, 0.5085, 0.1053]) tensor(1.5151)
[1800, 2000] loss:  tensor([0.3463, 0.5639, 0.5062, 0.1058]) tensor(1.5223)
[2000, 2200] loss:  tensor([0.3438, 0.5621, 0.5064, 0.1058]) tensor(1.5181)
[2200, 2400] loss:  tensor([0.3482, 0.5655, 0.5054, 0.1056]) tensor(1.5247)
[2400, 2600] loss:  tensor([0.3448, 0.5640, 0.5032, 0.1051]) tensor(1.5172)
[2600, 2800] loss:  ten

  0%|          | 0/63 [00:00<?, ?it/s]

validation loss for epoch 4: tensor([0.4775, 0.6060, 0.5462, 0.5592]) 2.188898165403408


  0%|          | 0/6174 [00:00<?, ?it/s]

[0, 200] loss:  tensor([0.3369, 0.5574, 0.4969, 0.1045]) tensor(1.4957)
[200, 400] loss:  tensor([0.3353, 0.5606, 0.4971, 0.1052]) tensor(1.4982)
[400, 600] loss:  tensor([0.3399, 0.5559, 0.4980, 0.1062]) tensor(1.5000)
[600, 800] loss:  tensor([0.3420, 0.5570, 0.5014, 0.1045]) tensor(1.5050)
[800, 1000] loss:  tensor([0.3414, 0.5540, 0.5033, 0.1045]) tensor(1.5032)
[1000, 1200] loss:  tensor([0.3367, 0.5559, 0.4967, 0.1037]) tensor(1.4929)
[1200, 1400] loss:  tensor([0.3420, 0.5531, 0.4979, 0.1056]) tensor(1.4986)
[1400, 1600] loss:  tensor([0.3332, 0.5548, 0.4960, 0.1035]) tensor(1.4875)
[1600, 1800] loss:  tensor([0.3358, 0.5534, 0.4987, 0.1052]) tensor(1.4931)
[1800, 2000] loss:  tensor([0.3408, 0.5569, 0.4987, 0.1041]) tensor(1.5004)
[2000, 2200] loss:  tensor([0.3371, 0.5548, 0.4991, 0.1032]) tensor(1.4943)
[2200, 2400] loss:  tensor([0.3379, 0.5556, 0.4957, 0.1035]) tensor(1.4927)
[2400, 2600] loss:  tensor([0.3398, 0.5518, 0.4917, 0.1038]) tensor(1.4871)
[2600, 2800] loss:  ten

  0%|          | 0/63 [00:00<?, ?it/s]

validation loss for epoch 5: tensor([0.5046, 0.5968, 0.5500, 0.4642]) 2.1156620951829117


  0%|          | 0/6174 [00:00<?, ?it/s]

[0, 200] loss:  tensor([0.3324, 0.5473, 0.4899, 0.1033]) tensor(1.4730)
[200, 400] loss:  tensor([0.3358, 0.5446, 0.4898, 0.1027]) tensor(1.4729)
[400, 600] loss:  tensor([0.3350, 0.5498, 0.4870, 0.1039]) tensor(1.4757)
[600, 800] loss:  tensor([0.3311, 0.5475, 0.4916, 0.1032]) tensor(1.4733)
[800, 1000] loss:  tensor([0.3306, 0.5513, 0.4911, 0.1028]) tensor(1.4758)
[1000, 1200] loss:  tensor([0.3329, 0.5499, 0.4936, 0.1036]) tensor(1.4799)
[1200, 1400] loss:  tensor([0.3331, 0.5453, 0.4911, 0.1035]) tensor(1.4729)
[1400, 1600] loss:  tensor([0.3317, 0.5494, 0.4925, 0.1036]) tensor(1.4772)
[1600, 1800] loss:  tensor([0.3324, 0.5459, 0.4930, 0.1035]) tensor(1.4748)
[1800, 2000] loss:  tensor([0.3285, 0.5456, 0.4911, 0.1038]) tensor(1.4689)
[2000, 2200] loss:  tensor([0.3273, 0.5474, 0.4876, 0.1033]) tensor(1.4656)
[2200, 2400] loss:  tensor([0.3345, 0.5444, 0.4878, 0.1020]) tensor(1.4686)
[2400, 2600] loss:  tensor([0.3318, 0.5465, 0.4874, 0.1031]) tensor(1.4688)
[2600, 2800] loss:  ten

  0%|          | 0/63 [00:00<?, ?it/s]

validation loss for epoch 6: tensor([0.5113, 0.5948, 0.5392, 0.4530]) 2.0983609605988476


  0%|          | 0/6174 [00:00<?, ?it/s]

[0, 200] loss:  tensor([0.3254, 0.5412, 0.4848, 0.1023]) tensor(1.4537)
[200, 400] loss:  tensor([0.3267, 0.5376, 0.4862, 0.1020]) tensor(1.4525)
[400, 600] loss:  tensor([0.3289, 0.5411, 0.4783, 0.1024]) tensor(1.4507)
[600, 800] loss:  tensor([0.3272, 0.5439, 0.4840, 0.1025]) tensor(1.4575)
[800, 1000] loss:  tensor([0.3244, 0.5426, 0.4852, 0.1025]) tensor(1.4547)
[1000, 1200] loss:  tensor([0.3277, 0.5404, 0.4875, 0.1018]) tensor(1.4574)
[1200, 1400] loss:  tensor([0.3214, 0.5435, 0.4869, 0.1015]) tensor(1.4533)
[1400, 1600] loss:  tensor([0.3295, 0.5430, 0.4834, 0.1033]) tensor(1.4592)
[1600, 1800] loss:  tensor([0.3226, 0.5374, 0.4848, 0.1025]) tensor(1.4473)
[1800, 2000] loss:  tensor([0.3283, 0.5392, 0.4870, 0.1022]) tensor(1.4566)
[2000, 2200] loss:  tensor([0.3275, 0.5370, 0.4842, 0.1028]) tensor(1.4516)
[2200, 2400] loss:  tensor([0.3283, 0.5404, 0.4866, 0.1034]) tensor(1.4587)
[2400, 2600] loss:  tensor([0.3241, 0.5402, 0.4844, 0.1022]) tensor(1.4509)
[2600, 2800] loss:  ten

  0%|          | 0/63 [00:00<?, ?it/s]

validation loss for epoch 7: tensor([0.5661, 0.5920, 0.5503, 0.5869]) 2.2952725174611626


  0%|          | 0/6174 [00:00<?, ?it/s]

[0, 200] loss:  tensor([0.3204, 0.5355, 0.4810, 0.1031]) tensor(1.4399)
[200, 400] loss:  tensor([0.3210, 0.5354, 0.4791, 0.1018]) tensor(1.4373)
[400, 600] loss:  tensor([0.3224, 0.5356, 0.4799, 0.1027]) tensor(1.4406)
[600, 800] loss:  tensor([0.3302, 0.5347, 0.4806, 0.1038]) tensor(1.4493)
[800, 1000] loss:  tensor([0.3256, 0.5346, 0.4787, 0.1020]) tensor(1.4409)
[1000, 1200] loss:  tensor([0.3219, 0.5320, 0.4786, 0.1014]) tensor(1.4340)
[1200, 1400] loss:  tensor([0.3198, 0.5320, 0.4767, 0.1020]) tensor(1.4304)
[1400, 1600] loss:  tensor([0.3216, 0.5328, 0.4829, 0.1020]) tensor(1.4393)
[1600, 1800] loss:  tensor([0.3258, 0.5338, 0.4841, 0.1023]) tensor(1.4460)
[1800, 2000] loss:  tensor([0.3183, 0.5345, 0.4771, 0.1012]) tensor(1.4310)
[2000, 2200] loss:  tensor([0.3241, 0.5320, 0.4783, 0.1021]) tensor(1.4364)
[2200, 2400] loss:  tensor([0.3212, 0.5326, 0.4784, 0.1017]) tensor(1.4338)
[2400, 2600] loss:  tensor([0.3242, 0.5301, 0.4805, 0.1021]) tensor(1.4368)
[2600, 2800] loss:  ten

  0%|          | 0/63 [00:00<?, ?it/s]

validation loss for epoch 8: tensor([0.3826, 0.5466, 0.4878, 0.2296]) 1.6465722191336758


  0%|          | 0/6174 [00:00<?, ?it/s]

[0, 200] loss:  tensor([0.3186, 0.5330, 0.4742, 0.1023]) tensor(1.4280)
[200, 400] loss:  tensor([0.3226, 0.5283, 0.4735, 0.1019]) tensor(1.4263)
[400, 600] loss:  tensor([0.3206, 0.5319, 0.4750, 0.1021]) tensor(1.4296)
[600, 800] loss:  tensor([0.3183, 0.5302, 0.4742, 0.1012]) tensor(1.4239)
[800, 1000] loss:  tensor([0.3184, 0.5269, 0.4712, 0.1011]) tensor(1.4176)
[1000, 1200] loss:  tensor([0.3229, 0.5268, 0.4766, 0.1023]) tensor(1.4286)
[1200, 1400] loss:  tensor([0.3154, 0.5295, 0.4763, 0.1015]) tensor(1.4227)
[1400, 1600] loss:  tensor([0.3184, 0.5321, 0.4707, 0.1012]) tensor(1.4225)
[1600, 1800] loss:  tensor([0.3152, 0.5288, 0.4747, 0.1012]) tensor(1.4199)
[1800, 2000] loss:  tensor([0.3166, 0.5310, 0.4753, 0.1009]) tensor(1.4237)
[2000, 2200] loss:  tensor([0.3167, 0.5284, 0.4782, 0.1019]) tensor(1.4252)
[2200, 2400] loss:  tensor([0.3166, 0.5248, 0.4781, 0.1011]) tensor(1.4207)
[2400, 2600] loss:  tensor([0.3194, 0.5307, 0.4764, 0.1015]) tensor(1.4278)
[2600, 2800] loss:  ten

  0%|          | 0/63 [00:00<?, ?it/s]

validation loss for epoch 9: tensor([0.4333, 0.5832, 0.5394, 0.2515]) 1.807387177790654


  0%|          | 0/6174 [00:00<?, ?it/s]

[0, 200] loss:  tensor([0.3147, 0.5271, 0.4748, 0.0994]) tensor(1.4160)
[200, 400] loss:  tensor([0.3165, 0.5269, 0.4701, 0.1012]) tensor(1.4148)
[400, 600] loss:  tensor([0.3155, 0.5294, 0.4701, 0.1019]) tensor(1.4167)
[600, 800] loss:  tensor([0.3158, 0.5281, 0.4725, 0.1015]) tensor(1.4179)
[800, 1000] loss:  tensor([0.3142, 0.5283, 0.4723, 0.1010]) tensor(1.4157)
[1000, 1200] loss:  tensor([0.3151, 0.5256, 0.4680, 0.1000]) tensor(1.4088)
[1200, 1400] loss:  tensor([0.3150, 0.5255, 0.4690, 0.1007]) tensor(1.4102)
[1400, 1600] loss:  tensor([0.3153, 0.5259, 0.4763, 0.1018]) tensor(1.4193)
[1600, 1800] loss:  tensor([0.3151, 0.5229, 0.4716, 0.1004]) tensor(1.4101)
[1800, 2000] loss:  tensor([0.3130, 0.5246, 0.4717, 0.1016]) tensor(1.4110)
[2000, 2200] loss:  tensor([0.3179, 0.5245, 0.4685, 0.1016]) tensor(1.4126)
[2200, 2400] loss:  tensor([0.3131, 0.5272, 0.4696, 0.1009]) tensor(1.4109)
[2400, 2600] loss:  tensor([0.3177, 0.5267, 0.4687, 0.1011]) tensor(1.4142)
[2600, 2800] loss:  ten

  0%|          | 0/63 [00:00<?, ?it/s]

validation loss for epoch 10: tensor([0.6304, 0.6046, 0.5287, 0.7560]) 2.5196724368109495


  0%|          | 0/6174 [00:00<?, ?it/s]

[0, 200] loss:  tensor([0.3163, 0.5250, 0.4689, 0.1009]) tensor(1.4110)
[200, 400] loss:  tensor([0.3135, 0.5227, 0.4698, 0.1002]) tensor(1.4062)
[400, 600] loss:  tensor([0.3121, 0.5198, 0.4655, 0.1008]) tensor(1.3982)
[600, 800] loss:  tensor([0.3139, 0.5216, 0.4640, 0.1015]) tensor(1.4010)
[800, 1000] loss:  tensor([0.3152, 0.5247, 0.4640, 0.0999]) tensor(1.4038)
[1000, 1200] loss:  tensor([0.3087, 0.5230, 0.4703, 0.0994]) tensor(1.4014)
[1200, 1400] loss:  tensor([0.3104, 0.5225, 0.4670, 0.0999]) tensor(1.3999)
[1400, 1600] loss:  tensor([0.3087, 0.5235, 0.4652, 0.1001]) tensor(1.3975)
[1600, 1800] loss:  tensor([0.3134, 0.5231, 0.4708, 0.1000]) tensor(1.4073)
[1800, 2000] loss:  tensor([0.3105, 0.5236, 0.4692, 0.1025]) tensor(1.4057)
[2000, 2200] loss:  tensor([0.3098, 0.5260, 0.4621, 0.1000]) tensor(1.3979)
[2200, 2400] loss:  tensor([0.3121, 0.5186, 0.4718, 0.1004]) tensor(1.4029)
[2400, 2600] loss:  tensor([0.3105, 0.5246, 0.4636, 0.1017]) tensor(1.4004)
[2600, 2800] loss:  ten

  0%|          | 0/63 [00:00<?, ?it/s]

validation loss for epoch 11: tensor([0.4056, 0.5520, 0.4983, 0.3084]) 1.7643272803260461


  0%|          | 0/6174 [00:00<?, ?it/s]

[0, 200] loss:  tensor([0.3093, 0.5203, 0.4651, 0.1007]) tensor(1.3954)
[200, 400] loss:  tensor([0.3108, 0.5202, 0.4631, 0.1009]) tensor(1.3951)
[400, 600] loss:  tensor([0.3105, 0.5191, 0.4672, 0.1009]) tensor(1.3977)
[600, 800] loss:  tensor([0.3083, 0.5148, 0.4672, 0.1000]) tensor(1.3903)
[800, 1000] loss:  tensor([0.3094, 0.5199, 0.4653, 0.1007]) tensor(1.3953)
[1000, 1200] loss:  tensor([0.3095, 0.5200, 0.4619, 0.1007]) tensor(1.3921)
[1200, 1400] loss:  tensor([0.3106, 0.5162, 0.4635, 0.1004]) tensor(1.3906)
[1400, 1600] loss:  tensor([0.3087, 0.5172, 0.4629, 0.0999]) tensor(1.3887)
[1600, 1800] loss:  tensor([0.3092, 0.5201, 0.4652, 0.0993]) tensor(1.3939)
[1800, 2000] loss:  tensor([0.3096, 0.5200, 0.4621, 0.1017]) tensor(1.3933)
[2000, 2200] loss:  tensor([0.3065, 0.5149, 0.4641, 0.1007]) tensor(1.3862)
[2200, 2400] loss:  tensor([0.3117, 0.5186, 0.4670, 0.1014]) tensor(1.3987)
[2400, 2600] loss:  tensor([0.3091, 0.5165, 0.4660, 0.1004]) tensor(1.3920)
[2600, 2800] loss:  ten

  0%|          | 0/63 [00:00<?, ?it/s]

validation loss for epoch 12: tensor([0.5441, 0.5888, 0.5254, 0.3955]) 2.0536979688674517


  0%|          | 0/6174 [00:00<?, ?it/s]

[0, 200] loss:  tensor([0.3059, 0.5148, 0.4650, 0.1011]) tensor(1.3869)
[200, 400] loss:  tensor([0.3062, 0.5153, 0.4608, 0.1002]) tensor(1.3825)
[400, 600] loss:  tensor([0.3121, 0.5163, 0.4620, 0.0996]) tensor(1.3900)
[600, 800] loss:  tensor([0.3095, 0.5131, 0.4629, 0.1003]) tensor(1.3858)
[800, 1000] loss:  tensor([0.3064, 0.5182, 0.4614, 0.0997]) tensor(1.3858)
[1000, 1200] loss:  tensor([0.3063, 0.5171, 0.4625, 0.1000]) tensor(1.3860)
[1200, 1400] loss:  tensor([0.3101, 0.5150, 0.4592, 0.1013]) tensor(1.3855)
[1400, 1600] loss:  tensor([0.3073, 0.5172, 0.4599, 0.1015]) tensor(1.3858)
[1600, 1800] loss:  tensor([0.3040, 0.5151, 0.4610, 0.0997]) tensor(1.3798)
[1800, 2000] loss:  tensor([0.3081, 0.5169, 0.4626, 0.0995]) tensor(1.3871)
[2000, 2200] loss:  tensor([0.3071, 0.5114, 0.4589, 0.0997]) tensor(1.3771)
[2200, 2400] loss:  tensor([0.3084, 0.5160, 0.4630, 0.0992]) tensor(1.3866)
[2400, 2600] loss:  tensor([0.3103, 0.5111, 0.4618, 0.0997]) tensor(1.3829)
[2600, 2800] loss:  ten

  0%|          | 0/63 [00:00<?, ?it/s]

validation loss for epoch 13: tensor([0.3557, 0.5352, 0.4810, 0.2737]) 1.6456472864726885


  0%|          | 0/6174 [00:00<?, ?it/s]

[0, 200] loss:  tensor([0.3036, 0.5110, 0.4573, 0.1002]) tensor(1.3721)
[200, 400] loss:  tensor([0.3041, 0.5189, 0.4585, 0.1003]) tensor(1.3818)
[400, 600] loss:  tensor([0.3076, 0.5114, 0.4591, 0.1000]) tensor(1.3782)
[600, 800] loss:  tensor([0.3044, 0.5126, 0.4587, 0.0991]) tensor(1.3748)
[800, 1000] loss:  tensor([0.3019, 0.5128, 0.4610, 0.0995]) tensor(1.3753)
[1000, 1200] loss:  tensor([0.3028, 0.5112, 0.4597, 0.0998]) tensor(1.3735)
[1200, 1400] loss:  tensor([0.3052, 0.5081, 0.4605, 0.0996]) tensor(1.3734)
[1400, 1600] loss:  tensor([0.3069, 0.5150, 0.4570, 0.1004]) tensor(1.3792)
[1600, 1800] loss:  tensor([0.3033, 0.5125, 0.4589, 0.0996]) tensor(1.3743)
[1800, 2000] loss:  tensor([0.3055, 0.5112, 0.4615, 0.0999]) tensor(1.3780)
[2000, 2200] loss:  tensor([0.3033, 0.5137, 0.4586, 0.1000]) tensor(1.3756)
[2200, 2400] loss:  tensor([0.3073, 0.5147, 0.4608, 0.0991]) tensor(1.3820)
[2400, 2600] loss:  tensor([0.3073, 0.5105, 0.4594, 0.0989]) tensor(1.3762)
[2600, 2800] loss:  ten

  0%|          | 0/63 [00:00<?, ?it/s]

validation loss for epoch 14: tensor([0.3489, 0.5267, 0.4702, 0.1402]) 1.4859478749060386


  0%|          | 0/6174 [00:00<?, ?it/s]

[0, 200] loss:  tensor([0.3038, 0.5066, 0.4582, 0.1004]) tensor(1.3689)
[200, 400] loss:  tensor([0.3032, 0.5083, 0.4570, 0.0989]) tensor(1.3673)
[400, 600] loss:  tensor([0.3035, 0.5080, 0.4553, 0.0990]) tensor(1.3658)
[600, 800] loss:  tensor([0.3087, 0.5103, 0.4580, 0.1007]) tensor(1.3776)
[800, 1000] loss:  tensor([0.2992, 0.5071, 0.4559, 0.0987]) tensor(1.3609)
[1000, 1200] loss:  tensor([0.2993, 0.5108, 0.4599, 0.0985]) tensor(1.3685)
[1200, 1400] loss:  tensor([0.3069, 0.5059, 0.4603, 0.0996]) tensor(1.3727)
[1400, 1600] loss:  tensor([0.3021, 0.5046, 0.4581, 0.1001]) tensor(1.3648)
[1600, 1800] loss:  tensor([0.3010, 0.5111, 0.4592, 0.0994]) tensor(1.3707)
[1800, 2000] loss:  tensor([0.3050, 0.5052, 0.4588, 0.0998]) tensor(1.3689)
[2000, 2200] loss:  tensor([0.3043, 0.5113, 0.4571, 0.0995]) tensor(1.3722)
[2200, 2400] loss:  tensor([0.3096, 0.5097, 0.4585, 0.0998]) tensor(1.3775)
[2400, 2600] loss:  tensor([0.3047, 0.5053, 0.4583, 0.0994]) tensor(1.3677)
[2600, 2800] loss:  ten

  0%|          | 0/63 [00:00<?, ?it/s]

validation loss for epoch 15: tensor([0.6663, 0.5980, 0.6511, 0.3520]) 2.267470950263092


  0%|          | 0/6174 [00:00<?, ?it/s]

[0, 200] loss:  tensor([0.2987, 0.5059, 0.4527, 0.0987]) tensor(1.3559)
[200, 400] loss:  tensor([0.3010, 0.5046, 0.4552, 0.0988]) tensor(1.3596)
[400, 600] loss:  tensor([0.3052, 0.5052, 0.4552, 0.0988]) tensor(1.3644)
[600, 800] loss:  tensor([0.3045, 0.5096, 0.4542, 0.0993]) tensor(1.3676)
[800, 1000] loss:  tensor([0.3031, 0.5050, 0.4544, 0.0988]) tensor(1.3612)
[1000, 1200] loss:  tensor([0.3031, 0.5070, 0.4545, 0.0992]) tensor(1.3638)
[1200, 1400] loss:  tensor([0.2990, 0.5067, 0.4559, 0.0997]) tensor(1.3613)
[1400, 1600] loss:  tensor([0.3044, 0.5077, 0.4547, 0.0985]) tensor(1.3654)
[1600, 1800] loss:  tensor([0.2999, 0.5042, 0.4592, 0.0987]) tensor(1.3620)
[1800, 2000] loss:  tensor([0.3008, 0.5063, 0.4569, 0.1004]) tensor(1.3643)
[2000, 2200] loss:  tensor([0.3015, 0.5025, 0.4554, 0.0991]) tensor(1.3586)
[2200, 2400] loss:  tensor([0.3059, 0.5052, 0.4535, 0.0993]) tensor(1.3639)
[2400, 2600] loss:  tensor([0.3015, 0.5063, 0.4513, 0.0993]) tensor(1.3585)
[2600, 2800] loss:  ten

  0%|          | 0/63 [00:00<?, ?it/s]

validation loss for epoch 16: tensor([0.5558, 0.5902, 0.5463, 0.6332]) 2.325467898474693


  0%|          | 0/6174 [00:00<?, ?it/s]

[0, 200] loss:  tensor([0.2997, 0.5059, 0.4558, 0.0991]) tensor(1.3606)
[200, 400] loss:  tensor([0.2986, 0.5061, 0.4510, 0.0993]) tensor(1.3551)
[400, 600] loss:  tensor([0.3007, 0.5017, 0.4551, 0.0988]) tensor(1.3564)
[600, 800] loss:  tensor([0.2997, 0.5044, 0.4533, 0.0993]) tensor(1.3567)
[800, 1000] loss:  tensor([0.3006, 0.5067, 0.4489, 0.0995]) tensor(1.3556)
[1000, 1200] loss:  tensor([0.2984, 0.5008, 0.4511, 0.0985]) tensor(1.3487)
[1200, 1400] loss:  tensor([0.3004, 0.5043, 0.4511, 0.0981]) tensor(1.3539)
[1400, 1600] loss:  tensor([0.3018, 0.5055, 0.4478, 0.1002]) tensor(1.3553)
[1600, 1800] loss:  tensor([0.3003, 0.5032, 0.4513, 0.0988]) tensor(1.3537)
[1800, 2000] loss:  tensor([0.2980, 0.5063, 0.4525, 0.0987]) tensor(1.3555)
[2000, 2200] loss:  tensor([0.3038, 0.5048, 0.4515, 0.0984]) tensor(1.3585)
[2200, 2400] loss:  tensor([0.3011, 0.5064, 0.4538, 0.0994]) tensor(1.3606)
[2400, 2600] loss:  tensor([0.3014, 0.5032, 0.4539, 0.0990]) tensor(1.3575)
[2600, 2800] loss:  ten

  0%|          | 0/63 [00:00<?, ?it/s]

validation loss for epoch 17: tensor([0.6325, 0.6203, 0.5477, 0.8216]) 2.6221610929278376


  0%|          | 0/6174 [00:00<?, ?it/s]

[0, 200] loss:  tensor([0.2955, 0.5003, 0.4489, 0.0980]) tensor(1.3428)
[200, 400] loss:  tensor([0.2988, 0.5009, 0.4484, 0.0989]) tensor(1.3470)
[400, 600] loss:  tensor([0.2992, 0.5044, 0.4493, 0.0984]) tensor(1.3512)
[600, 800] loss:  tensor([0.3062, 0.5020, 0.4489, 0.0995]) tensor(1.3566)
[800, 1000] loss:  tensor([0.3002, 0.5013, 0.4499, 0.0987]) tensor(1.3501)
[1000, 1200] loss:  tensor([0.2998, 0.5045, 0.4508, 0.0982]) tensor(1.3532)
[1200, 1400] loss:  tensor([0.2996, 0.5017, 0.4546, 0.0981]) tensor(1.3540)
[1400, 1600] loss:  tensor([0.2960, 0.5037, 0.4482, 0.0982]) tensor(1.3460)
[1600, 1800] loss:  tensor([0.3019, 0.5006, 0.4520, 0.1006]) tensor(1.3551)
[1800, 2000] loss:  tensor([0.3009, 0.5040, 0.4487, 0.0995]) tensor(1.3532)
[2000, 2200] loss:  tensor([0.2989, 0.5008, 0.4505, 0.0989]) tensor(1.3491)
[2200, 2400] loss:  tensor([0.2995, 0.5037, 0.4521, 0.0988]) tensor(1.3540)
[2400, 2600] loss:  tensor([0.3006, 0.5047, 0.4481, 0.0985]) tensor(1.3519)
[2600, 2800] loss:  ten

  0%|          | 0/63 [00:00<?, ?it/s]

validation loss for epoch 18: tensor([0.6363, 0.6190, 0.5943, 0.2533]) 2.102947207388499


  0%|          | 0/6174 [00:00<?, ?it/s]

[0, 200] loss:  tensor([0.2957, 0.4979, 0.4478, 0.0988]) tensor(1.3403)
[200, 400] loss:  tensor([0.2964, 0.5006, 0.4531, 0.0981]) tensor(1.3483)
[400, 600] loss:  tensor([0.2946, 0.5006, 0.4473, 0.0988]) tensor(1.3413)
[600, 800] loss:  tensor([0.2930, 0.4974, 0.4469, 0.0983]) tensor(1.3357)
[800, 1000] loss:  tensor([0.2982, 0.5011, 0.4485, 0.0984]) tensor(1.3462)
[1000, 1200] loss:  tensor([0.2999, 0.4991, 0.4523, 0.0994]) tensor(1.3507)
[1200, 1400] loss:  tensor([0.2977, 0.5020, 0.4494, 0.0991]) tensor(1.3482)
[1400, 1600] loss:  tensor([0.2964, 0.5019, 0.4455, 0.0984]) tensor(1.3422)
[1600, 1800] loss:  tensor([0.2954, 0.5007, 0.4497, 0.0980]) tensor(1.3439)
[1800, 2000] loss:  tensor([0.2970, 0.4972, 0.4470, 0.0982]) tensor(1.3394)
[2000, 2200] loss:  tensor([0.3005, 0.5021, 0.4483, 0.0988]) tensor(1.3497)
[2200, 2400] loss:  tensor([0.2966, 0.4996, 0.4489, 0.0978]) tensor(1.3429)
[2400, 2600] loss:  tensor([0.2960, 0.5002, 0.4479, 0.0980]) tensor(1.3421)
[2600, 2800] loss:  ten

  0%|          | 0/63 [00:00<?, ?it/s]

validation loss for epoch 19: tensor([0.3870, 0.5302, 0.4730, 0.3558]) 1.7459207533356302


  0%|          | 0/6174 [00:00<?, ?it/s]

[0, 200] loss:  tensor([0.2927, 0.4968, 0.4431, 0.0986]) tensor(1.3313)
[200, 400] loss:  tensor([0.2961, 0.4976, 0.4460, 0.0989]) tensor(1.3386)
[400, 600] loss:  tensor([0.2923, 0.4971, 0.4458, 0.0978]) tensor(1.3330)
[600, 800] loss:  tensor([0.2948, 0.4982, 0.4458, 0.0992]) tensor(1.3380)
[800, 1000] loss:  tensor([0.2978, 0.4959, 0.4475, 0.0986]) tensor(1.3398)
[1000, 1200] loss:  tensor([0.2969, 0.4990, 0.4456, 0.0992]) tensor(1.3408)
[1200, 1400] loss:  tensor([0.2946, 0.4993, 0.4462, 0.0985]) tensor(1.3387)
[1400, 1600] loss:  tensor([0.2951, 0.4965, 0.4491, 0.0977]) tensor(1.3384)
[1600, 1800] loss:  tensor([0.2960, 0.4970, 0.4454, 0.0989]) tensor(1.3373)
[1800, 2000] loss:  tensor([0.2977, 0.4975, 0.4503, 0.0976]) tensor(1.3431)
[2000, 2200] loss:  tensor([0.2969, 0.4994, 0.4470, 0.0990]) tensor(1.3422)
[2200, 2400] loss:  tensor([0.2936, 0.4956, 0.4433, 0.0986]) tensor(1.3312)
[2400, 2600] loss:  tensor([0.2976, 0.5040, 0.4442, 0.0989]) tensor(1.3447)
[2600, 2800] loss:  ten

  0%|          | 0/63 [00:00<?, ?it/s]

validation loss for epoch 20: tensor([0.3814, 0.5458, 0.4992, 0.2088]) 1.6351506465719745


  0%|          | 0/6174 [00:00<?, ?it/s]

[0, 200] loss:  tensor([0.2967, 0.4943, 0.4464, 0.0991]) tensor(1.3365)
[200, 400] loss:  tensor([0.2936, 0.4918, 0.4462, 0.0979]) tensor(1.3294)
[400, 600] loss:  tensor([0.2954, 0.4889, 0.4441, 0.0987]) tensor(1.3271)
[600, 800] loss:  tensor([0.2929, 0.4962, 0.4400, 0.0982]) tensor(1.3273)
[800, 1000] loss:  tensor([0.2931, 0.4965, 0.4457, 0.0973]) tensor(1.3325)
[1000, 1200] loss:  tensor([0.2983, 0.4990, 0.4463, 0.0994]) tensor(1.3430)
[1200, 1400] loss:  tensor([0.2940, 0.4979, 0.4454, 0.0992]) tensor(1.3365)
[1400, 1600] loss:  tensor([0.2927, 0.4979, 0.4446, 0.0974]) tensor(1.3325)
[1600, 1800] loss:  tensor([0.2953, 0.4955, 0.4451, 0.0981]) tensor(1.3341)
[1800, 2000] loss:  tensor([0.2967, 0.4977, 0.4437, 0.0965]) tensor(1.3346)
[2000, 2200] loss:  tensor([0.2939, 0.4953, 0.4456, 0.0986]) tensor(1.3333)
[2200, 2400] loss:  tensor([0.2929, 0.4954, 0.4393, 0.0990]) tensor(1.3265)
[2400, 2600] loss:  tensor([0.2930, 0.4937, 0.4459, 0.0983]) tensor(1.3309)
[2600, 2800] loss:  ten

  0%|          | 0/63 [00:00<?, ?it/s]

validation loss for epoch 21: tensor([0.3641, 0.5249, 0.4662, 0.2204]) 1.5755831480205462


  0%|          | 0/6174 [00:00<?, ?it/s]

[0, 200] loss:  tensor([0.2886, 0.4948, 0.4413, 0.0980]) tensor(1.3228)
[200, 400] loss:  tensor([0.2925, 0.4953, 0.4435, 0.0968]) tensor(1.3281)
[400, 600] loss:  tensor([0.2959, 0.4914, 0.4455, 0.0983]) tensor(1.3312)
[600, 800] loss:  tensor([0.2928, 0.4937, 0.4446, 0.0983]) tensor(1.3294)
[800, 1000] loss:  tensor([0.2948, 0.4964, 0.4408, 0.0980]) tensor(1.3300)
[1000, 1200] loss:  tensor([0.2940, 0.4935, 0.4458, 0.0994]) tensor(1.3327)
[1200, 1400] loss:  tensor([0.2929, 0.4928, 0.4426, 0.0979]) tensor(1.3262)
[1400, 1600] loss:  tensor([0.2930, 0.4920, 0.4422, 0.0986]) tensor(1.3259)
[1600, 1800] loss:  tensor([0.2950, 0.4922, 0.4442, 0.0979]) tensor(1.3293)


KeyboardInterrupt: 

In [22]:
train_set = objaverse_sfm(dataset_root, 4, train=True, transform = ToTensor())
train_dataloader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=32, drop_last=False)

test_set = objaverse_sfm(dataset_root, 4, train=False, transform = ToTensor())
test_dataloader = DataLoader(test_set, batch_size=64, shuffle=False, num_workers=16, drop_last=False)

model = sfm(resnet=True).to(device)
model.cnn.layer1.requires_grad_(False)
model.cnn.layer2.requires_grad_(False)
model.cnn.layer3.requires_grad_(False)

print('\n================== total trainable parameters: %d ==================\n' % sum(p.numel() for p in model.parameters() if p.requires_grad))
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, param.shape)

model = torch.nn.DataParallel(model, device_ids=[6, 7])
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.3)

for epoch in range(50):  # loop over the dataset multiple times

    print('validation')
    model.eval()
    test_loss = torch.zeros(4).to(device)
    with torch.no_grad():
        for i, batch in tqdm(enumerate(test_dataloader, 0), total=len(test_dataloader)):
            # get the inputs; data is a list of [inputs, labels]
            image_target, image_cond, T_gt = \
                batch['image_target'].to(device), batch['image_cond'].to(device), batch['T'].to(device)

            # forward + backward + optimize
            T_pred = model(image_cond, image_target)
            loss = (T_pred - T_gt).abs()
            test_loss += loss.sum(dim=0)
    
    print('validation loss for epoch %d:' % epoch, test_loss.cpu() / len(test_set), test_loss.cpu().sum().item() / len(test_set))
    
    running_loss = torch.zeros(4).to(device)
    
    model.train()
    
    for i, batch in tqdm(enumerate(train_dataloader, 0), total=len(train_dataloader)):
        # get the inputs; data is a list of [inputs, labels]
        image_target, image_cond, T_gt = \
            batch['image_target'].to(device), batch['image_cond'].to(device), batch['T'].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        T_pred = model(image_cond, image_target)
        loss = ((T_pred - T_gt)**2).mean(dim=0)
        l1_loss = ((T_pred - T_gt).abs()).mean(dim=0)
        running_loss += l1_loss.detach()
        loss = loss.sum()
        loss.backward()
        optimizer.step()

        # print statistics
        if i % 200 == 199:    # print every 2000 mini-batches
            print(f'[{i-199}, {i+1:d}] loss: ', running_loss.detach().cpu() / 200, \
                  running_loss.detach().cpu().sum() / 200)
            running_loss = torch.zeros(4).to(device)

            

initializing resnet weights


cnn.conv1.weight torch.Size([64, 6, 7, 7])
cnn.bn1.weight torch.Size([64])
cnn.bn1.bias torch.Size([64])
cnn.layer4.0.conv1.weight torch.Size([512, 256, 3, 3])
cnn.layer4.0.bn1.weight torch.Size([512])
cnn.layer4.0.bn1.bias torch.Size([512])
cnn.layer4.0.conv2.weight torch.Size([512, 512, 3, 3])
cnn.layer4.0.bn2.weight torch.Size([512])
cnn.layer4.0.bn2.bias torch.Size([512])
cnn.layer4.0.downsample.0.weight torch.Size([512, 256, 1, 1])
cnn.layer4.0.downsample.1.weight torch.Size([512])
cnn.layer4.0.downsample.1.bias torch.Size([512])
cnn.layer4.1.conv1.weight torch.Size([512, 512, 3, 3])
cnn.layer4.1.bn1.weight torch.Size([512])
cnn.layer4.1.bn1.bias torch.Size([512])
cnn.layer4.1.conv2.weight torch.Size([512, 512, 3, 3])
cnn.layer4.1.bn2.weight torch.Size([512])
cnn.layer4.1.bn2.bias torch.Size([512])
linear1.weight torch.Size([512, 512])
linear1.bias torch.Size([512])
linear2.weight torch.Size([4, 512])
linear2.bias torch.Size([4])
validation


  0%|          | 0/125 [00:00<?, ?it/s]

validation loss for epoch 0: tensor([0.7673, 0.6362, 0.6355, 0.2494]) 2.288342618547983


  0%|          | 0/12347 [00:00<?, ?it/s]

[0, 200] loss:  tensor([0.6100, 0.6377, 0.6343, 0.1830]) tensor(2.0650)
[200, 400] loss:  tensor([0.5451, 0.6400, 0.6271, 0.1669]) tensor(1.9791)
[400, 600] loss:  tensor([0.5230, 0.6376, 0.6221, 0.1579]) tensor(1.9406)
[600, 800] loss:  tensor([0.5085, 0.6405, 0.6128, 0.1560]) tensor(1.9178)
[800, 1000] loss:  tensor([0.4991, 0.6318, 0.6167, 0.1528]) tensor(1.9004)
[1000, 1200] loss:  tensor([0.4901, 0.6396, 0.6111, 0.1493]) tensor(1.8900)
[1200, 1400] loss:  tensor([0.4919, 0.6355, 0.6108, 0.1493]) tensor(1.8875)
[1400, 1600] loss:  tensor([0.4757, 0.6398, 0.6017, 0.1456]) tensor(1.8629)
[1600, 1800] loss:  tensor([0.4697, 0.6348, 0.6084, 0.1428]) tensor(1.8557)
[1800, 2000] loss:  tensor([0.4670, 0.6384, 0.6060, 0.1430]) tensor(1.8545)
[2000, 2200] loss:  tensor([0.4684, 0.6418, 0.5965, 0.1406]) tensor(1.8472)
[2200, 2400] loss:  tensor([0.4615, 0.6384, 0.6001, 0.1400]) tensor(1.8401)
[2400, 2600] loss:  tensor([0.4582, 0.6339, 0.5986, 0.1406]) tensor(1.8313)
[2600, 2800] loss:  ten

  0%|          | 0/125 [00:00<?, ?it/s]

validation loss for epoch 1: tensor([0.4159, 0.6201, 0.5756, 0.1390]) 1.750625308310887


  0%|          | 0/12347 [00:00<?, ?it/s]

[0, 200] loss:  tensor([0.4070, 0.6233, 0.5595, 0.1251]) tensor(1.7149)
[200, 400] loss:  tensor([0.4047, 0.6182, 0.5646, 0.1238]) tensor(1.7113)
[400, 600] loss:  tensor([0.4099, 0.6173, 0.5659, 0.1230]) tensor(1.7162)
[600, 800] loss:  tensor([0.4038, 0.6190, 0.5599, 0.1262]) tensor(1.7089)
[800, 1000] loss:  tensor([0.4029, 0.6196, 0.5558, 0.1244]) tensor(1.7027)
[1000, 1200] loss:  tensor([0.3992, 0.6150, 0.5660, 0.1243]) tensor(1.7046)
[1200, 1400] loss:  tensor([0.4004, 0.6131, 0.5665, 0.1238]) tensor(1.7039)
[1400, 1600] loss:  tensor([0.4006, 0.6150, 0.5596, 0.1245]) tensor(1.6998)
[1600, 1800] loss:  tensor([0.4000, 0.6148, 0.5629, 0.1234]) tensor(1.7011)
[1800, 2000] loss:  tensor([0.4035, 0.6136, 0.5624, 0.1241]) tensor(1.7036)
[2000, 2200] loss:  tensor([0.4081, 0.6118, 0.5648, 0.1229]) tensor(1.7076)
[2200, 2400] loss:  tensor([0.4048, 0.6183, 0.5536, 0.1231]) tensor(1.6999)
[2400, 2600] loss:  tensor([0.4013, 0.6170, 0.5534, 0.1225]) tensor(1.6942)
[2600, 2800] loss:  ten

  0%|          | 0/125 [00:00<?, ?it/s]

validation loss for epoch 2: tensor([0.5251, 0.6194, 0.5622, 0.2801]) 1.9868236246398145


  0%|          | 0/12347 [00:00<?, ?it/s]

[0, 200] loss:  tensor([0.3833, 0.6070, 0.5440, 0.1202]) tensor(1.6545)
[200, 400] loss:  tensor([0.3810, 0.6020, 0.5487, 0.1188]) tensor(1.6505)
[400, 600] loss:  tensor([0.3823, 0.6070, 0.5445, 0.1201]) tensor(1.6539)
[600, 800] loss:  tensor([0.3839, 0.6063, 0.5442, 0.1208]) tensor(1.6551)
[800, 1000] loss:  tensor([0.3831, 0.6061, 0.5433, 0.1198]) tensor(1.6522)
[1000, 1200] loss:  tensor([0.3898, 0.5992, 0.5402, 0.1214]) tensor(1.6506)
[1200, 1400] loss:  tensor([0.3827, 0.6049, 0.5435, 0.1199]) tensor(1.6511)
[1400, 1600] loss:  tensor([0.3829, 0.6077, 0.5393, 0.1194]) tensor(1.6493)
[1600, 1800] loss:  tensor([0.3911, 0.6005, 0.5452, 0.1200]) tensor(1.6568)
[1800, 2000] loss:  tensor([0.3818, 0.6014, 0.5469, 0.1205]) tensor(1.6507)
[2000, 2200] loss:  tensor([0.3815, 0.6060, 0.5394, 0.1203]) tensor(1.6473)
[2200, 2400] loss:  tensor([0.3854, 0.6012, 0.5423, 0.1193]) tensor(1.6482)
[2400, 2600] loss:  tensor([0.3844, 0.5997, 0.5503, 0.1196]) tensor(1.6540)
[2600, 2800] loss:  ten

  0%|          | 0/125 [00:00<?, ?it/s]

validation loss for epoch 3: tensor([0.3763, 0.5941, 0.5367, 0.1226]) 1.6298139319327862


  0%|          | 0/12347 [00:00<?, ?it/s]

[0, 200] loss:  tensor([0.3709, 0.5922, 0.5312, 0.1170]) tensor(1.6113)
[200, 400] loss:  tensor([0.3727, 0.5873, 0.5388, 0.1176]) tensor(1.6164)
[400, 600] loss:  tensor([0.3732, 0.5942, 0.5296, 0.1159]) tensor(1.6129)
[600, 800] loss:  tensor([0.3709, 0.5895, 0.5368, 0.1166]) tensor(1.6138)
[800, 1000] loss:  tensor([0.3626, 0.5907, 0.5330, 0.1169]) tensor(1.6033)
[1000, 1200] loss:  tensor([0.3734, 0.5914, 0.5398, 0.1164]) tensor(1.6210)
[1200, 1400] loss:  tensor([0.3726, 0.5928, 0.5326, 0.1145]) tensor(1.6125)
[1400, 1600] loss:  tensor([0.3751, 0.5916, 0.5360, 0.1177]) tensor(1.6205)
[1600, 1800] loss:  tensor([0.3726, 0.5944, 0.5351, 0.1157]) tensor(1.6178)
[1800, 2000] loss:  tensor([0.3749, 0.5964, 0.5312, 0.1173]) tensor(1.6197)
[2000, 2200] loss:  tensor([0.3679, 0.5941, 0.5310, 0.1157]) tensor(1.6086)
[2200, 2400] loss:  tensor([0.3648, 0.5951, 0.5317, 0.1180]) tensor(1.6096)
[2400, 2600] loss:  tensor([0.3719, 0.5925, 0.5315, 0.1177]) tensor(1.6136)
[2600, 2800] loss:  ten

  0%|          | 0/125 [00:00<?, ?it/s]

validation loss for epoch 4: tensor([0.3520, 0.5799, 0.5292, 0.1344]) 1.5955352668504135


  0%|          | 0/12347 [00:00<?, ?it/s]

[0, 200] loss:  tensor([0.3684, 0.5801, 0.5256, 0.1154]) tensor(1.5895)
[200, 400] loss:  tensor([0.3584, 0.5842, 0.5269, 0.1146]) tensor(1.5841)
[400, 600] loss:  tensor([0.3586, 0.5864, 0.5207, 0.1136]) tensor(1.5793)
[600, 800] loss:  tensor([0.3573, 0.5871, 0.5270, 0.1139]) tensor(1.5854)
[800, 1000] loss:  tensor([0.3630, 0.5841, 0.5296, 0.1159]) tensor(1.5926)
[1000, 1200] loss:  tensor([0.3643, 0.5840, 0.5277, 0.1154]) tensor(1.5914)
[1200, 1400] loss:  tensor([0.3616, 0.5818, 0.5255, 0.1148]) tensor(1.5837)
[1400, 1600] loss:  tensor([0.3672, 0.5824, 0.5260, 0.1151]) tensor(1.5907)
[1600, 1800] loss:  tensor([0.3601, 0.5794, 0.5280, 0.1146]) tensor(1.5820)
[1800, 2000] loss:  tensor([0.3587, 0.5805, 0.5284, 0.1160]) tensor(1.5837)
[2000, 2200] loss:  tensor([0.3570, 0.5807, 0.5233, 0.1140]) tensor(1.5750)
[2200, 2400] loss:  tensor([0.3598, 0.5853, 0.5284, 0.1152]) tensor(1.5887)
[2400, 2600] loss:  tensor([0.3583, 0.5845, 0.5284, 0.1143]) tensor(1.5855)
[2600, 2800] loss:  ten

  0%|          | 0/125 [00:00<?, ?it/s]

validation loss for epoch 5: tensor([0.3571, 0.5809, 0.5256, 0.1335]) 1.5971961082357178


  0%|          | 0/12347 [00:00<?, ?it/s]

[0, 200] loss:  tensor([0.3526, 0.5757, 0.5245, 0.1146]) tensor(1.5675)
[200, 400] loss:  tensor([0.3523, 0.5730, 0.5214, 0.1155]) tensor(1.5623)
[400, 600] loss:  tensor([0.3659, 0.5796, 0.5228, 0.1160]) tensor(1.5843)
[600, 800] loss:  tensor([0.3518, 0.5753, 0.5141, 0.1137]) tensor(1.5549)
[800, 1000] loss:  tensor([0.3541, 0.5725, 0.5221, 0.1140]) tensor(1.5627)
[1000, 1200] loss:  tensor([0.3574, 0.5775, 0.5198, 0.1140]) tensor(1.5688)
[1200, 1400] loss:  tensor([0.3537, 0.5765, 0.5228, 0.1158]) tensor(1.5688)
[1400, 1600] loss:  tensor([0.3513, 0.5749, 0.5143, 0.1134]) tensor(1.5540)
[1600, 1800] loss:  tensor([0.3579, 0.5778, 0.5166, 0.1134]) tensor(1.5657)
[1800, 2000] loss:  tensor([0.3559, 0.5720, 0.5114, 0.1140]) tensor(1.5533)
[2000, 2200] loss:  tensor([0.3513, 0.5711, 0.5233, 0.1151]) tensor(1.5609)
[2200, 2400] loss:  tensor([0.3546, 0.5721, 0.5246, 0.1150]) tensor(1.5663)
[2400, 2600] loss:  tensor([0.3545, 0.5744, 0.5218, 0.1136]) tensor(1.5643)
[2600, 2800] loss:  ten

  0%|          | 0/125 [00:00<?, ?it/s]

validation loss for epoch 6: tensor([0.3576, 0.5653, 0.5176, 0.1228]) 1.5633326840860686


  0%|          | 0/12347 [00:00<?, ?it/s]

[0, 200] loss:  tensor([0.3403, 0.5655, 0.5157, 0.1135]) tensor(1.5351)
[200, 400] loss:  tensor([0.3442, 0.5688, 0.5159, 0.1128]) tensor(1.5417)
[400, 600] loss:  tensor([0.3511, 0.5642, 0.5198, 0.1151]) tensor(1.5502)
[600, 800] loss:  tensor([0.3522, 0.5676, 0.5153, 0.1136]) tensor(1.5487)
[800, 1000] loss:  tensor([0.3472, 0.5748, 0.5148, 0.1127]) tensor(1.5496)
[1000, 1200] loss:  tensor([0.3510, 0.5677, 0.5162, 0.1132]) tensor(1.5481)
[1200, 1400] loss:  tensor([0.3514, 0.5740, 0.5179, 0.1140]) tensor(1.5573)
[1400, 1600] loss:  tensor([0.3488, 0.5628, 0.5164, 0.1119]) tensor(1.5400)
[1600, 1800] loss:  tensor([0.3455, 0.5654, 0.5144, 0.1150]) tensor(1.5403)
[1800, 2000] loss:  tensor([0.3476, 0.5665, 0.5149, 0.1128]) tensor(1.5419)
[2000, 2200] loss:  tensor([0.3526, 0.5626, 0.5201, 0.1140]) tensor(1.5493)
[2200, 2400] loss:  tensor([0.3532, 0.5715, 0.5128, 0.1141]) tensor(1.5516)
[2400, 2600] loss:  tensor([0.3515, 0.5660, 0.5178, 0.1138]) tensor(1.5490)
[2600, 2800] loss:  ten

  0%|          | 0/125 [00:00<?, ?it/s]

validation loss for epoch 7: tensor([0.3512, 0.5703, 0.5146, 0.1193]) 1.5554264673640692


  0%|          | 0/12347 [00:00<?, ?it/s]

[0, 200] loss:  tensor([0.3472, 0.5666, 0.5144, 0.1138]) tensor(1.5420)
[200, 400] loss:  tensor([0.3461, 0.5617, 0.5148, 0.1117]) tensor(1.5344)
[400, 600] loss:  tensor([0.3472, 0.5563, 0.5137, 0.1141]) tensor(1.5313)
[600, 800] loss:  tensor([0.3494, 0.5608, 0.5160, 0.1134]) tensor(1.5396)
[800, 1000] loss:  tensor([0.3506, 0.5606, 0.5060, 0.1136]) tensor(1.5308)
[1000, 1200] loss:  tensor([0.3459, 0.5620, 0.5103, 0.1144]) tensor(1.5327)
[1200, 1400] loss:  tensor([0.3412, 0.5632, 0.5104, 0.1113]) tensor(1.5261)
[1400, 1600] loss:  tensor([0.3409, 0.5632, 0.5073, 0.1155]) tensor(1.5269)
[1600, 1800] loss:  tensor([0.3454, 0.5696, 0.5112, 0.1137]) tensor(1.5399)
[1800, 2000] loss:  tensor([0.3457, 0.5638, 0.5064, 0.1127]) tensor(1.5287)
[2000, 2200] loss:  tensor([0.3391, 0.5578, 0.5110, 0.1150]) tensor(1.5229)
[2200, 2400] loss:  tensor([0.3490, 0.5591, 0.5082, 0.1116]) tensor(1.5279)
[2400, 2600] loss:  tensor([0.3453, 0.5621, 0.5095, 0.1146]) tensor(1.5315)
[2600, 2800] loss:  ten

  0%|          | 0/125 [00:00<?, ?it/s]

validation loss for epoch 8: tensor([0.3781, 0.5734, 0.5256, 0.1365]) 1.6135856458281133


  0%|          | 0/12347 [00:00<?, ?it/s]

[0, 200] loss:  tensor([0.3442, 0.5624, 0.5131, 0.1133]) tensor(1.5330)
[200, 400] loss:  tensor([0.3486, 0.5572, 0.5069, 0.1137]) tensor(1.5265)
[400, 600] loss:  tensor([0.3429, 0.5552, 0.5117, 0.1146]) tensor(1.5245)
[600, 800] loss:  tensor([0.3420, 0.5562, 0.5100, 0.1131]) tensor(1.5213)
[800, 1000] loss:  tensor([0.3472, 0.5619, 0.5125, 0.1125]) tensor(1.5341)
[1000, 1200] loss:  tensor([0.3434, 0.5563, 0.5069, 0.1120]) tensor(1.5186)
[1200, 1400] loss:  tensor([0.3479, 0.5618, 0.5042, 0.1122]) tensor(1.5262)
[1400, 1600] loss:  tensor([0.3445, 0.5554, 0.5058, 0.1129]) tensor(1.5186)
[1600, 1800] loss:  tensor([0.3366, 0.5541, 0.5121, 0.1116]) tensor(1.5144)
[1800, 2000] loss:  tensor([0.3370, 0.5537, 0.5024, 0.1113]) tensor(1.5044)
[2000, 2200] loss:  tensor([0.3400, 0.5557, 0.5050, 0.1115]) tensor(1.5122)
[2200, 2400] loss:  tensor([0.3431, 0.5553, 0.5058, 0.1128]) tensor(1.5170)
[2400, 2600] loss:  tensor([0.3402, 0.5592, 0.5132, 0.1126]) tensor(1.5252)
[2600, 2800] loss:  ten

  0%|          | 0/125 [00:00<?, ?it/s]

validation loss for epoch 9: tensor([0.3406, 0.5570, 0.5057, 0.1132]) 1.5164625779096719


  0%|          | 0/12347 [00:00<?, ?it/s]

[0, 200] loss:  tensor([0.3373, 0.5529, 0.5091, 0.1134]) tensor(1.5127)
[200, 400] loss:  tensor([0.3427, 0.5564, 0.5048, 0.1126]) tensor(1.5166)
[400, 600] loss:  tensor([0.3413, 0.5555, 0.5014, 0.1129]) tensor(1.5111)
[600, 800] loss:  tensor([0.3355, 0.5551, 0.5050, 0.1119]) tensor(1.5075)
[800, 1000] loss:  tensor([0.3376, 0.5585, 0.5009, 0.1123]) tensor(1.5092)
[1000, 1200] loss:  tensor([0.3355, 0.5566, 0.5075, 0.1118]) tensor(1.5114)
[1200, 1400] loss:  tensor([0.3433, 0.5553, 0.5108, 0.1135]) tensor(1.5230)
[1400, 1600] loss:  tensor([0.3436, 0.5589, 0.5107, 0.1124]) tensor(1.5255)
[1600, 1800] loss:  tensor([0.3405, 0.5548, 0.5068, 0.1123]) tensor(1.5144)
[1800, 2000] loss:  tensor([0.3429, 0.5548, 0.5086, 0.1129]) tensor(1.5191)
[2000, 2200] loss:  tensor([0.3471, 0.5538, 0.5061, 0.1127]) tensor(1.5197)
[2200, 2400] loss:  tensor([0.3348, 0.5500, 0.5080, 0.1104]) tensor(1.5033)
[2400, 2600] loss:  tensor([0.3474, 0.5537, 0.4978, 0.1125]) tensor(1.5114)
[2600, 2800] loss:  ten

  0%|          | 0/125 [00:00<?, ?it/s]

validation loss for epoch 10: tensor([0.3420, 0.5483, 0.5015, 0.1125]) 1.5042420886447632


  0%|          | 0/12347 [00:00<?, ?it/s]

[0, 200] loss:  tensor([0.3349, 0.5512, 0.5068, 0.1116]) tensor(1.5044)
[200, 400] loss:  tensor([0.3409, 0.5545, 0.5052, 0.1137]) tensor(1.5143)
[400, 600] loss:  tensor([0.3396, 0.5493, 0.5035, 0.1100]) tensor(1.5024)
[600, 800] loss:  tensor([0.3408, 0.5524, 0.5068, 0.1121]) tensor(1.5122)
[800, 1000] loss:  tensor([0.3352, 0.5489, 0.5001, 0.1119]) tensor(1.4960)
[1000, 1200] loss:  tensor([0.3386, 0.5520, 0.5041, 0.1127]) tensor(1.5074)
[1200, 1400] loss:  tensor([0.3403, 0.5540, 0.5062, 0.1132]) tensor(1.5136)
[1400, 1600] loss:  tensor([0.3451, 0.5556, 0.5070, 0.1126]) tensor(1.5203)
[1600, 1800] loss:  tensor([0.3408, 0.5475, 0.5059, 0.1121]) tensor(1.5062)
[1800, 2000] loss:  tensor([0.3455, 0.5515, 0.5057, 0.1109]) tensor(1.5136)
[2000, 2200] loss:  tensor([0.3416, 0.5502, 0.5059, 0.1118]) tensor(1.5095)
[2200, 2400] loss:  tensor([0.3401, 0.5524, 0.5047, 0.1128]) tensor(1.5100)
[2400, 2600] loss:  tensor([0.3427, 0.5527, 0.5084, 0.1133]) tensor(1.5172)
[2600, 2800] loss:  ten

  0%|          | 0/125 [00:00<?, ?it/s]

validation loss for epoch 11: tensor([0.3486, 0.5551, 0.5076, 0.1156]) 1.5269081996993235


  0%|          | 0/12347 [00:00<?, ?it/s]

[0, 200] loss:  tensor([0.3336, 0.5526, 0.5030, 0.1112]) tensor(1.5004)
[200, 400] loss:  tensor([0.3440, 0.5522, 0.4977, 0.1113]) tensor(1.5052)
[400, 600] loss:  tensor([0.3454, 0.5556, 0.4965, 0.1113]) tensor(1.5088)
[600, 800] loss:  tensor([0.3363, 0.5506, 0.5017, 0.1121]) tensor(1.5006)
[800, 1000] loss:  tensor([0.3387, 0.5501, 0.4984, 0.1122]) tensor(1.4993)
[1000, 1200] loss:  tensor([0.3416, 0.5506, 0.5075, 0.1118]) tensor(1.5115)
[1200, 1400] loss:  tensor([0.3360, 0.5488, 0.5056, 0.1109]) tensor(1.5014)
[1400, 1600] loss:  tensor([0.3367, 0.5548, 0.4989, 0.1109]) tensor(1.5013)
[1600, 1800] loss:  tensor([0.3416, 0.5500, 0.5018, 0.1124]) tensor(1.5058)
[1800, 2000] loss:  tensor([0.3416, 0.5486, 0.5017, 0.1120]) tensor(1.5040)
[2000, 2200] loss:  tensor([0.3388, 0.5486, 0.5058, 0.1128]) tensor(1.5061)
[2200, 2400] loss:  tensor([0.3376, 0.5537, 0.5034, 0.1109]) tensor(1.5057)
[2400, 2600] loss:  tensor([0.3377, 0.5512, 0.5017, 0.1116]) tensor(1.5022)
[2600, 2800] loss:  ten

KeyboardInterrupt: 

Process Process-573:
Process Process-559:
Process Process-555:
Process Process-576:
Process Process-575:
Process Process-564:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/rliu/ruoshi/anaconda3/lib/python3.8/multiprocessing/process.py", line 318, in _bootstrap
    util._exit_function()
  File "/home/rliu/ruoshi/anaconda3/lib/python3.8/multiprocessing/process.py", line 318, in _bootstrap
    util._exit_function()
  File "/home/rliu/ruoshi/anaconda3/lib/python3.8/multiprocessing/process.py", line 318, in _bootstrap
    util._exit_function()
  File "/home/rliu/ruoshi/anaconda3/lib/python3.8/multiprocessing/process.py", line 318, in _bootstrap
    util._exit_function()
  File "/home/rliu/ruoshi/anaconda3/lib/python3.8/multiprocessing/process.py", line 318, in _bootstrap
    util._exit_function()
  File "/home/rliu/