In [1]:
import h5py

In [2]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# reloaded before each cell runs, so edits to local .py files are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
# models: weight files of the pre-trained sub-networks and of the continuous
# model that this notebook loads and re-saves (paths relative to the working directory)
discrete_encoder_path = 'discrete_encoder.pth'
discrete_decoder_path = 'discrete_decoder.pth'
mapping_path = 'mapping.pth'
unet_path = 'con_unet_full.pth'
continuous_model_path = 'continuous_model.pth'

# NOTE(review): not referenced by any code visible in this notebook
hidden_dim_discrete = 128

# data preparation
# NOTE(review): absolute, machine-specific path -- adjust before running elsewhere
data_file = '/home/ankbzpx/datasets/ShapeNet/ShapeNetRenderingh5_v1/03001627/sdf_train_core.h5'
sample_size = 2048  # used by the training loop to tile the latent code once per sample point
batch_size = 32  # batch size handed to the DataLoader
split_ratio = 0.9  # NOTE(review): not referenced in the visible code
depth_size = 256  # NOTE(review): not referenced in the visible code; presumably the depth image side length -- confirm
num_of_workers = 12  # number of DataLoader worker processes
# training
num_epoch = 40  # exclusive upper bound of the epoch range in the training loop

In [4]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torchvision import transforms
import torch.nn as nn
import torchvision.transforms.functional as TF
from torch.autograd import Variable
from PIL import Image
import sys
import time
import matplotlib.pyplot as plt
from torchsummary import summary

In [5]:
# Select the compute device and pin the random generators for repeatable runs.
if torch.cuda.is_available():
    device = torch.device("cuda")
    # request deterministic cuDNN kernels when a GPU is used
    torch.backends.cudnn.deterministic = True
else:
    device = torch.device("cpu")

# fixed seeds for NumPy and PyTorch
np.random.seed(0)
torch.manual_seed(0)

In [6]:
# display the device selected in the previous cell
device

device(type='cuda')

In [7]:
class ChairSDFDataset(Dataset):
    """Dataset over an HDF5 file that stores one group per training sample.

    For every item the datasets ``depth_img``, ``sample_pt``, ``sample_sdf``
    and ``target_vox`` are read from the group.

    Args:
        h5_file: path of the HDF5 file.
        sample_size: number of SDF sample points every item must provide.
            Samples with fewer points are reflect-padded, samples with more
            points are truncated to the first ``sample_size`` points.
    """

    def __init__(self, h5_file, sample_size=2048):

        self.file_path = h5_file
        self.sample_size = sample_size
        # The read handle is opened lazily on first access in __getitem__
        # (presumably so that each DataLoader worker opens its own handle).
        self.dataset = None

        # Only the group names are read up front; the file is closed again.
        with h5py.File(self.file_path, 'r') as file:
            self.keys = list(file.keys())
            self.dataset_len = len(self.keys)

        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        return self.dataset_len

    @staticmethod
    def _fit_length(array, length):
        """Return ``array`` (2-D, rows = samples) with exactly ``length`` rows.

        Too few rows are filled by reflecting the existing rows, surplus rows
        are dropped from the end. An array that already fits is returned as is.
        """
        rows = array.shape[0]
        if rows == length:
            return array
        if rows > length:
            # np.pad cannot shrink an array (a negative pad width raises)
            return array[:length]
        return np.pad(array, ((0, length - rows), (0, 0)), 'reflect')

    def __getitem__(self, idx):
        """Return a dict with ``depth_img``, ``sample_pt``, ``sample_sdf`` and ``target_vox`` tensors."""

        if self.dataset is None:
            self.dataset = h5py.File(self.file_path, 'r')

        group = self.dataset[self.keys[idx]]

        depth_img = self.to_tensor(Image.fromarray(np.array(group['depth_img'])))

        # check size correctness and fix incorrect data
        sample_pt_np = self._fit_length(
            np.array(group['sample_pt']).reshape(-1, 3), self.sample_size)
        sample_sdf_np = self._fit_length(
            np.array(group['sample_sdf']).reshape(-1, 1), self.sample_size)

        sample_pt = torch.from_numpy(sample_pt_np).float()
        sample_sdf = torch.from_numpy(sample_sdf_np).float()
        # scale sdf: signed fourth root, sign(s) * |s| ** 0.25
        sample_sdf = torch.sign(sample_sdf)*torch.pow(torch.abs(sample_sdf), 0.25)

        target_vox = torch.from_numpy(np.array(group['target_vox'])).float()

        sample = { 'depth_img': depth_img,
                   'sample_pt':sample_pt,
                   'sample_sdf':sample_sdf,
                   'target_vox':target_vox,
                  }

        return sample

In [8]:
# Wrap the training HDF5 file in a dataset and serve it through a shuffling,
# multi-worker loader.
train_sdf_dataset = ChairSDFDataset(data_file)

train_sdf_dataloader = DataLoader(
    dataset=train_sdf_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_of_workers,
)

In [9]:
# number of batches the loader yields per epoch
len(train_sdf_dataloader)

2634

In [10]:
start_epoch = 0
latent_dim = 256

from models import Discrete_encoder, Mapping, Discrete_decoder, Conditional_UNET


def _load_frozen(module, weights_path=None):
    """Switch a pre-trained sub-network to inference mode and freeze it.

    Args:
        module: network that has already been moved to ``device``.
        weights_path: optional state-dict file to load into ``module`` first.

    Returns:
        The same ``module``, in eval mode, with ``requires_grad`` disabled on
        every parameter.
    """
    if weights_path is not None:
        # map_location remaps the stored tensors onto the selected device, so
        # weights saved from a GPU can also be restored on a CPU-only machine.
        module.load_state_dict(torch.load(weights_path, map_location=device))
    module.eval()
    # module.parameters() also covers parameters registered directly on the
    # module, which a walk over children() alone would leave trainable.
    for param in module.parameters():
        param.requires_grad = False
    return module


####################
# Discrete Encoder #
####################

discrete_encoder = _load_frozen(Discrete_encoder(256).to(device), discrete_encoder_path)

####################
# Mapping #
####################

mapping = _load_frozen(Mapping().to(device), mapping_path)

####################
# Discrete Decoder #
####################

discrete_decoder = _load_frozen(Discrete_decoder(256).to(device), discrete_decoder_path)

########
# UNET #
########

# pre-trained model is loaded within the model
unet = _load_frozen(Conditional_UNET(unet_path).to(device))

In [11]:
class BottleNect1D(nn.Module):
    """Residual block on feature vectors.

    The input is widened to ``expand * input_dim`` features (Linear +
    BatchNorm + ReLU), projected back to ``input_dim`` (Linear + BatchNorm)
    and added to the unchanged input.
    """

    def __init__(self, input_dim, expand = 5):
        super().__init__()

        hidden_dim = expand*input_dim
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.BatchNorm1d(input_dim),
        ]
        # attribute name and layer order define the state-dict keys
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.block(x)
        return x + residual

class Continuous(nn.Module):
    """MLP that maps (point, local feature, latent code) rows to one value each.

    The three inputs are concatenated along dim 1, passed through three
    Linear/BatchNorm/BottleNect1D stages (with a skip connection around the
    second stage) and reduced to a single Tanh-bounded output per row.
    """

    def __init__(self, pt_dim = 3, con_dim = 32, latent_dim = 256):
        super().__init__()

        joint_dim = pt_dim + con_dim + latent_dim

        # entry stage: concatenated input -> latent_dim features
        self.de_pt = nn.Sequential(
            nn.Linear(joint_dim, latent_dim),
            nn.BatchNorm1d(latent_dim),
            BottleNect1D(latent_dim),
        )

        # middle stage, wrapped by a skip connection in forward()
        self.de_1 = nn.Sequential(
            nn.Linear(latent_dim, latent_dim),
            nn.BatchNorm1d(latent_dim),
            BottleNect1D(latent_dim),
        )

        # output stage: latent_dim features -> one Tanh-bounded value
        self.de_2 = nn.Sequential(
            nn.Linear(latent_dim, latent_dim),
            nn.BatchNorm1d(latent_dim),
            BottleNect1D(latent_dim),
            nn.Linear(latent_dim, 1),
            nn.Tanh(),
        )

    def forward(self, pt, con, z):
        # rows are [pt | con | z]
        joint = torch.cat((pt, con, z), 1)
        fea = self.de_pt(joint)
        skipped = self.de_1(fea) + fea
        return self.de_2(skipped)

In [12]:
continuous = Continuous().to(device)
# map_location remaps the stored tensors onto the selected device, so weights
# saved from a GPU can also be restored when `device` is the CPU.
continuous.load_state_dict(torch.load(continuous_model_path, map_location=device))

In [13]:
# torchsummary pushes a random batch through the network to record the layer
# shapes. Run that pass in eval mode so the BatchNorm running statistics of the
# freshly loaded weights are not updated by random data, then restore the mode.
was_training = continuous.training
continuous.eval()
try:
    summary(continuous, [(3, ), (32, ), (256, )])
finally:
    continuous.train(was_training)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                  [-1, 256]          74,752
       BatchNorm1d-2                  [-1, 256]             512
            Linear-3                 [-1, 1280]         328,960
       BatchNorm1d-4                 [-1, 1280]           2,560
              ReLU-5                 [-1, 1280]               0
            Linear-6                  [-1, 256]         327,936
       BatchNorm1d-7                  [-1, 256]             512
      BottleNect1D-8                  [-1, 256]               0
            Linear-9                  [-1, 256]          65,792
      BatchNorm1d-10                  [-1, 256]             512
           Linear-11                 [-1, 1280]         328,960
      BatchNorm1d-12                 [-1, 1280]           2,560
             ReLU-13                 [-1, 1280]               0
           Linear-14                  [

In [14]:
# L1 criterion; the training loop applies it to predicted vs. sampled SDF values
l1loss = nn.L1Loss()

In [15]:
# learning rate passed to Adam
lr = 1e-3

# alias: the training loop and the checkpoint cell refer to the trainable network as `model`
model = continuous

# only the parameters of the continuous network are handed to the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [16]:
def save_checkpoint(state, filename='continuous_model_checkpoint.pth.tar'):
    """Serialize a training checkpoint with ``torch.save``.

    Args:
        state: object to store; the training loop passes a dict with the keys
            ``epoch``, ``continuous_state_dict`` and ``optimizer``.
        filename: destination path. The default is the file name the resume
            cell of this notebook loads.
    """
    torch.save(state, filename)

In [14]:
# Optional resume: restore epoch counter, model and optimizer state from the
# last checkpoint. On a fresh run no checkpoint exists yet; training then
# starts from the current start_epoch with the model/optimizer created above.
checkpoint_path = 'continuous_model_checkpoint.pth.tar'
try:
    # map_location remaps stored tensors onto the selected device
    checkpoint = torch.load(checkpoint_path, map_location=device)
except FileNotFoundError:
    print("No checkpoint found at '%s', starting from epoch %d" % (checkpoint_path, start_epoch))
else:
    start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['continuous_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])

In [17]:
# advanced indexing 2x2x2 context from voxel
def getContext(sample_pt_query, vox):
    """Gather the 2x2x2 voxel neighbourhood of every query point.

    Args:
        sample_pt_query: b x m x 3 tensor holding, per point, the grid index
            of the lower corner of its neighbourhood. Each index plus one must
            still be a valid index into the corresponding grid axis.
        vox: b x c x dim x dim x dim feature grid.

    Returns:
        b x c x m x 2 x 2 x 2 tensor whose element ``[b, c, j, i, k, l]`` is
        ``vox[b, c, q[b, j, 0] + i, q[b, j, 1] + k, q[b, j, 2] + l]``.
        The result is a CPU tensor of the default dtype, wherever ``vox``
        lives, so it can be combined with CPU-side interpolation weights.
    """
    batch_size, sample_size, _ = sample_pt_query.shape
    channel_size = vox.shape[1]
    grid_device = vox.device

    # Build the (small) index tensors on the device that holds the grid so the
    # whole gather is a single indexing operation instead of one per point.
    query = sample_pt_query.to(grid_device)
    offset = torch.arange(2, dtype=torch.int32, device=grid_device)

    # Index tensors broadcast to b x m x c x 2 x 2 x 2.
    batch_idx = torch.arange(batch_size, device=grid_device)[:, None, None, None, None, None]
    channel_idx = torch.arange(channel_size, device=grid_device)[None, None, :, None, None, None]
    x_idx = (offset[None, None, None, :, None, None] + query[:, :, 0][:, :, None, None, None, None]).long()
    y_idx = (offset[None, None, None, None, :, None] + query[:, :, 1][:, :, None, None, None, None]).long()
    z_idx = (offset[None, None, None, None, None, :] + query[:, :, 2][:, :, None, None, None, None]).long()

    # b x m x c x 2 x 2 x 2
    context = vox[batch_idx, channel_idx, x_idx, y_idx, z_idx]
    context = context.to(device='cpu', dtype=torch.get_default_dtype())

    # b x c x m x 2 x 2 x 2
    return context.transpose(1, 2)

In [18]:
def trilinearInterpolation(context, dx, dy, dz):
    """Blend the eight corners of each 2x2x2 neighbourhood.

    ``context`` is indexed as ``[:, :, :, i, j, k]`` with corner offsets
    ``i, j, k`` in {0, 1}. A corner is weighted by ``dx``/``dy``/``dz`` where
    its offset is 1 and by ``1 - dx``/``1 - dy``/``1 - dz`` where it is 0.
    The weights must broadcast against ``context[:, :, :, i, j, k]``.

    Returns the weighted sum of the eight corners (the last three axes of
    ``context`` are reduced away).
    """
    # per-axis weights, indexed by the corner offset (0 or 1)
    weight_x = (1-dx, dx)
    weight_y = (1-dy, dy)
    weight_z = (1-dz, dz)

    blended = None
    # x offset varies fastest, then y, then z
    for k in (0, 1):
        for j in (0, 1):
            for i in (0, 1):
                corner = context[:, :, :, i, j, k]*weight_x[i]*weight_y[j]*weight_z[k]
                blended = corner if blended is None else blended + corner

    return blended

In [19]:
#########################################
# Train the continuous model on features produced by the frozen encoder / U-Net.

# grid resolution used to turn sample points into voxel indices
# (assumes the U-Net feature grid is vox_size^3 -- TODO confirm)
vox_size = 32
latent_dim = 256
con_dim = 32

batch_len = len(train_sdf_dataloader)

print("Starting Training Loop...")

start_time = time.time()

# BatchNorm layers of the trainable network must use batch statistics
model.train()

for epoch in range(start_epoch, num_epoch):

    loss_list = []   # every batch loss of this epoch
    loss_batch = []  # batch losses since the last progress print

    print("Epoch: ", epoch)

    for count, data in enumerate(train_sdf_dataloader):

        ####################
        # Data preparation #
        ####################

        depth_img = data['depth_img'].to(device)
        target_vox = data['target_vox'].to(device)
        # b x n x 3
        # DO NOT scale by np.sqrt(3)
        sample_pt = data['sample_pt']
        # b x n x 1
        sample_sdf = data['sample_sdf']

        # encoder and U-Net are frozen feature extractors: no graph is needed
        with torch.no_grad():
            z = discrete_encoder(depth_img)
            vox_feature = unet(target_vox, z)
            #vox_feature = unet(torch.sigmoid(discrete_decoder(mapping(z))))

        ####################
        # indexing context #
        ####################

        # stay with cpu for v-ram efficiency
        # shift by 0.5 so the coordinates are non-negative before scaling
        sample_pt_normalized = sample_pt + torch.tensor([0.5, 0.5, 0.5])
        # (0, vox_size-1)
        sample_pt_scale = torch.clamp(sample_pt_normalized* (vox_size-1), 0, (vox_size-1)-1e-5)
        # (0, vox_size-2): lower corner of the enclosing 2x2x2 cell
        sample_pt_query = torch.clamp((sample_pt_scale).int(), 0, (vox_size-2))
        # fractional position inside the cell
        sample_pt_distance = sample_pt_scale - sample_pt_query

        context = getContext(sample_pt_query, vox_feature)

        dx = sample_pt_distance[:, :, 0].unsqueeze(1)
        dy = sample_pt_distance[:, :, 1].unsqueeze(1)
        dz = sample_pt_distance[:, :, 2].unsqueeze(1)
        # local feature, b x c x n
        con = trilinearInterpolation(context, dx, dy, dz)

        ################################
        # Reshape input & forward pass #
        ################################

        # flatten everything to one row per (shape, sample point) pair
        num_samples = sample_pt.shape[1]
        sample_pt = sample_pt.reshape(-1, 3).to(device)
        sample_sdf = sample_sdf.reshape(-1, 1).to(device)
        con = con.to(device).transpose(-1, -2).reshape(-1, con_dim)
        # Repeat each latent code once per sample point of its shape.
        # NOTE(review): flattening z first keeps the tiling correct for any
        # number of trailing singleton dims in the encoder output; the former
        # squeeze/repeat chain gave this row order only for a 5-D output.
        z = z.reshape(z.shape[0], -1).unsqueeze(1).expand(-1, num_samples, -1).reshape(-1, latent_dim)

        optimizer.zero_grad()

        pred_sdf = model(sample_pt, con, z)

        loss_l1 = l1loss(pred_sdf, sample_sdf)

        loss = loss_l1

        loss.backward()

        loss_list.append(loss_l1.item())
        loss_batch.append(loss_l1.item())

        optimizer.step()

        #scheduler.step()

        # progress print every 10 batches, intermediate weight save every 500
        if count != 0 and count % 10 == 0:
            loss_batch_avg = np.average(loss_batch)

            print("Batch: ", count, ", l1 Loss: ", loss_batch_avg, ", Time: %s s" % (time.time() - start_time))

            if count % 500 == 0:
                torch.save(model.state_dict(), continuous_model_path)

            loss_batch.clear()

    print("Epoch: ", epoch, ', l1 loss: ', np.average(loss_list))

    loss_list.clear()

    torch.save(model.state_dict(), continuous_model_path)

    # store the next epoch index so a resumed run continues after this epoch
    save_checkpoint({
        'epoch': epoch + 1,
        'continuous_state_dict': model.state_dict(),
        'optimizer' : optimizer.state_dict(),
    })

print("Training finished")

Starting Training Loop...
Epoch:  0
Batch:  10 , l1 Loss:  0.6356682235544379 , Time: 12.870596170425415 s
Batch:  20 , l1 Loss:  0.26166956722736356 , Time: 24.250309228897095 s
Batch:  30 , l1 Loss:  0.14672429114580154 , Time: 35.689682960510254 s
Batch:  40 , l1 Loss:  0.12958474606275558 , Time: 47.236165285110474 s
Batch:  50 , l1 Loss:  0.12115571945905686 , Time: 58.75198841094971 s
Batch:  60 , l1 Loss:  0.1182081013917923 , Time: 70.26559329032898 s
Batch:  70 , l1 Loss:  0.11695915758609772 , Time: 81.7999517917633 s
Batch:  80 , l1 Loss:  0.11439906805753708 , Time: 93.38381791114807 s
Batch:  90 , l1 Loss:  0.11247901022434234 , Time: 104.96345615386963 s
Batch:  100 , l1 Loss:  0.11350428611040116 , Time: 116.56965374946594 s
Batch:  110 , l1 Loss:  0.11099378541111946 , Time: 128.1645200252533 s
Batch:  120 , l1 Loss:  0.11063363701105118 , Time: 139.77182364463806 s
Batch:  130 , l1 Loss:  0.11276050209999085 , Time: 151.28039693832397 s
Batch:  140 , l1 Loss:  0.111937

Batch:  1150 , l1 Loss:  0.10257095545530319 , Time: 1347.9632287025452 s
Batch:  1160 , l1 Loss:  0.10255348831415176 , Time: 1359.7036740779877 s
Batch:  1170 , l1 Loss:  0.10365713611245156 , Time: 1371.412005186081 s
Batch:  1180 , l1 Loss:  0.10360780507326126 , Time: 1383.0767753124237 s
Batch:  1190 , l1 Loss:  0.10288463160395622 , Time: 1394.755003452301 s
Batch:  1200 , l1 Loss:  0.1023711919784546 , Time: 1406.4381167888641 s
Batch:  1210 , l1 Loss:  0.10126769915223122 , Time: 1418.1188027858734 s
Batch:  1220 , l1 Loss:  0.10211240127682686 , Time: 1429.8157942295074 s
Batch:  1230 , l1 Loss:  0.1044051043689251 , Time: 1441.5042808055878 s
Batch:  1240 , l1 Loss:  0.10284226909279823 , Time: 1453.1773653030396 s
Batch:  1250 , l1 Loss:  0.1030073568224907 , Time: 1464.8562614917755 s
Batch:  1260 , l1 Loss:  0.10184578374028205 , Time: 1476.5433089733124 s
Batch:  1270 , l1 Loss:  0.10219207108020782 , Time: 1488.2690043449402 s
Batch:  1280 , l1 Loss:  0.1012992545962333

Batch:  2270 , l1 Loss:  0.10122964456677437 , Time: 2657.5434470176697 s
Batch:  2280 , l1 Loss:  0.101145289093256 , Time: 2669.196286916733 s
Batch:  2290 , l1 Loss:  0.10093993917107583 , Time: 2680.892108440399 s
Batch:  2300 , l1 Loss:  0.10091946497559548 , Time: 2692.5340909957886 s
Batch:  2310 , l1 Loss:  0.10063840821385384 , Time: 2704.1932871341705 s
Batch:  2320 , l1 Loss:  0.10019042864441871 , Time: 2715.864919900894 s
Batch:  2330 , l1 Loss:  0.10168906971812249 , Time: 2727.51092672348 s
Batch:  2340 , l1 Loss:  0.10176944360136986 , Time: 2739.1514236927032 s
Batch:  2350 , l1 Loss:  0.10123281255364418 , Time: 2750.8056411743164 s
Batch:  2360 , l1 Loss:  0.10145051926374435 , Time: 2762.4651079177856 s
Batch:  2370 , l1 Loss:  0.10287504866719246 , Time: 2774.173999071121 s
Batch:  2380 , l1 Loss:  0.10318695977330208 , Time: 2785.852584838867 s
Batch:  2390 , l1 Loss:  0.1008954480290413 , Time: 2797.538878917694 s
Batch:  2400 , l1 Loss:  0.10261690616607666 , Ti

Batch:  770 , l1 Loss:  0.09956611394882202 , Time: 4010.910325527191 s
Batch:  780 , l1 Loss:  0.09986158534884453 , Time: 4023.1513047218323 s
Batch:  790 , l1 Loss:  0.10026874542236328 , Time: 4035.336446046829 s
Batch:  800 , l1 Loss:  0.10093663781881332 , Time: 4047.5101838111877 s
Batch:  810 , l1 Loss:  0.10096770823001862 , Time: 4059.648444414139 s
Batch:  820 , l1 Loss:  0.09967626556754113 , Time: 4071.841958761215 s
Batch:  830 , l1 Loss:  0.09932728260755538 , Time: 4083.9532120227814 s
Batch:  840 , l1 Loss:  0.09944009184837341 , Time: 4095.9055304527283 s
Batch:  850 , l1 Loss:  0.10089885145425796 , Time: 4107.824432134628 s
Batch:  860 , l1 Loss:  0.10061604753136635 , Time: 4119.724838256836 s
Batch:  870 , l1 Loss:  0.10002379789948464 , Time: 4131.540773391724 s
Batch:  880 , l1 Loss:  0.10108595341444016 , Time: 4143.474110841751 s
Batch:  890 , l1 Loss:  0.09997587725520134 , Time: 4155.40248966217 s
Batch:  900 , l1 Loss:  0.10156461521983147 , Time: 4167.3304

Batch:  1900 , l1 Loss:  0.09957264587283135 , Time: 5366.2438180446625 s
Batch:  1910 , l1 Loss:  0.10050475522875786 , Time: 5378.329997301102 s
Batch:  1920 , l1 Loss:  0.10036210119724273 , Time: 5390.439720630646 s
Batch:  1930 , l1 Loss:  0.09999601617455482 , Time: 5402.623251438141 s
Batch:  1940 , l1 Loss:  0.09928654134273529 , Time: 5414.893968820572 s
Batch:  1950 , l1 Loss:  0.10039583966135979 , Time: 5427.18314409256 s
Batch:  1960 , l1 Loss:  0.10121290311217308 , Time: 5439.585829496384 s
Batch:  1970 , l1 Loss:  0.09936719760298729 , Time: 5451.811643362045 s
Batch:  1980 , l1 Loss:  0.10084240660071372 , Time: 5464.061973810196 s
Batch:  1990 , l1 Loss:  0.10088977739214897 , Time: 5476.353795051575 s
Batch:  2000 , l1 Loss:  0.09937947914004326 , Time: 5488.524446964264 s
Batch:  2010 , l1 Loss:  0.09952156394720077 , Time: 5500.785048246384 s
Batch:  2020 , l1 Loss:  0.1008484996855259 , Time: 5513.130594730377 s
Batch:  2030 , l1 Loss:  0.10073452964425086 , Time:

Batch:  400 , l1 Loss:  0.09857395738363266 , Time: 6731.731207609177 s
Batch:  410 , l1 Loss:  0.09985497742891311 , Time: 6743.604747772217 s
Batch:  420 , l1 Loss:  0.09956611916422844 , Time: 6755.483067989349 s
Batch:  430 , l1 Loss:  0.1002014510333538 , Time: 6767.431567430496 s
Batch:  440 , l1 Loss:  0.0992838516831398 , Time: 6779.333445072174 s
Batch:  450 , l1 Loss:  0.09856216087937356 , Time: 6791.199228048325 s
Batch:  460 , l1 Loss:  0.09912454336881638 , Time: 6803.089158296585 s
Batch:  470 , l1 Loss:  0.10022744908928871 , Time: 6814.9557819366455 s
Batch:  480 , l1 Loss:  0.09863841459155083 , Time: 6826.815460443497 s
Batch:  490 , l1 Loss:  0.09987311959266662 , Time: 6838.567478179932 s
Batch:  500 , l1 Loss:  0.09993821531534194 , Time: 6850.337361812592 s
Batch:  510 , l1 Loss:  0.09772551953792571 , Time: 6862.1086275577545 s
Batch:  520 , l1 Loss:  0.09980422481894494 , Time: 6873.940417051315 s
Batch:  530 , l1 Loss:  0.09820955023169517 , Time: 6885.7929909

Batch:  1540 , l1 Loss:  0.0997310921549797 , Time: 8090.241664648056 s
Batch:  1550 , l1 Loss:  0.09986224174499511 , Time: 8102.133373498917 s
Batch:  1560 , l1 Loss:  0.09982054308056831 , Time: 8114.026735067368 s
Batch:  1570 , l1 Loss:  0.10010933354496956 , Time: 8125.903761386871 s
Batch:  1580 , l1 Loss:  0.0992144450545311 , Time: 8137.773485898972 s
Batch:  1590 , l1 Loss:  0.09896508604288101 , Time: 8149.643887281418 s
Batch:  1600 , l1 Loss:  0.09966051205992699 , Time: 8161.520415067673 s
Batch:  1610 , l1 Loss:  0.10041618272662163 , Time: 8173.407382726669 s
Batch:  1620 , l1 Loss:  0.09931881353259087 , Time: 8185.280114650726 s
Batch:  1630 , l1 Loss:  0.1002419576048851 , Time: 8197.135855674744 s
Batch:  1640 , l1 Loss:  0.09978322982788086 , Time: 8209.048544883728 s
Batch:  1650 , l1 Loss:  0.09841089025139808 , Time: 8220.910330295563 s
Batch:  1660 , l1 Loss:  0.09873040989041329 , Time: 8232.747718572617 s
Batch:  1670 , l1 Loss:  0.09847219735383987 , Time: 8

Batch:  40 , l1 Loss:  0.09871150851249695 , Time: 9441.85818362236 s
Batch:  50 , l1 Loss:  0.09887349084019662 , Time: 9453.691535711288 s
Batch:  60 , l1 Loss:  0.09985054060816764 , Time: 9465.516493797302 s
Batch:  70 , l1 Loss:  0.09965771287679673 , Time: 9477.724455595016 s
Batch:  80 , l1 Loss:  0.09971566200256347 , Time: 9489.955788612366 s
Batch:  90 , l1 Loss:  0.09877900034189224 , Time: 9502.134913921356 s
Batch:  100 , l1 Loss:  0.09867423549294471 , Time: 9514.338024377823 s
Batch:  110 , l1 Loss:  0.09978807494044303 , Time: 9526.24976015091 s
Batch:  120 , l1 Loss:  0.09892596676945686 , Time: 9537.988258123398 s
Batch:  130 , l1 Loss:  0.0989707313477993 , Time: 9549.740361690521 s
Batch:  140 , l1 Loss:  0.09947738125920295 , Time: 9561.82054233551 s
Batch:  150 , l1 Loss:  0.09883525520563126 , Time: 9574.02545928955 s
Batch:  160 , l1 Loss:  0.09808789044618607 , Time: 9586.202939987183 s
Batch:  170 , l1 Loss:  0.09864782691001892 , Time: 9598.352454662323 s
Bat

Batch:  1180 , l1 Loss:  0.09789760708808899 , Time: 10812.03312087059 s
Batch:  1190 , l1 Loss:  0.09823333099484444 , Time: 10824.139057159424 s
Batch:  1200 , l1 Loss:  0.09769949689507484 , Time: 10836.229548215866 s
Batch:  1210 , l1 Loss:  0.09922126531600953 , Time: 10848.379180908203 s
Batch:  1220 , l1 Loss:  0.09796006456017495 , Time: 10860.350600719452 s
Batch:  1230 , l1 Loss:  0.09897981211543083 , Time: 10872.093367099762 s
Batch:  1240 , l1 Loss:  0.09829192087054253 , Time: 10883.837620735168 s
Batch:  1250 , l1 Loss:  0.09819738939404488 , Time: 10895.587580442429 s
Batch:  1260 , l1 Loss:  0.09759861454367638 , Time: 10907.338278770447 s
Batch:  1270 , l1 Loss:  0.097577565908432 , Time: 10919.168296813965 s
Batch:  1280 , l1 Loss:  0.09819739982485771 , Time: 10930.950378894806 s
Batch:  1290 , l1 Loss:  0.09732119217514992 , Time: 10942.74713563919 s
Batch:  1300 , l1 Loss:  0.09879650101065636 , Time: 10954.563338756561 s
Batch:  1310 , l1 Loss:  0.097094298899173

Batch:  2300 , l1 Loss:  0.09856067821383477 , Time: 12148.551580190659 s
Batch:  2310 , l1 Loss:  0.09846526831388473 , Time: 12160.525727510452 s
Batch:  2320 , l1 Loss:  0.09758670702576637 , Time: 12172.472474575043 s
Batch:  2330 , l1 Loss:  0.09854372516274452 , Time: 12184.374491930008 s
Batch:  2340 , l1 Loss:  0.09913661554455758 , Time: 12196.30502486229 s
Batch:  2350 , l1 Loss:  0.09844968467950821 , Time: 12208.21138381958 s
Batch:  2360 , l1 Loss:  0.09831609502434731 , Time: 12220.16304898262 s
Batch:  2370 , l1 Loss:  0.09917841255664825 , Time: 12232.08529829979 s
Batch:  2380 , l1 Loss:  0.09924081116914749 , Time: 12244.026213169098 s
Batch:  2390 , l1 Loss:  0.0991397425532341 , Time: 12255.96628832817 s
Batch:  2400 , l1 Loss:  0.09796633720397949 , Time: 12267.872453212738 s
Batch:  2410 , l1 Loss:  0.09789662137627601 , Time: 12279.825081825256 s
Batch:  2420 , l1 Loss:  0.09792011082172394 , Time: 12291.795195817947 s
Batch:  2430 , l1 Loss:  0.09775489196181297

Batch:  790 , l1 Loss:  0.09883271604776382 , Time: 13489.263092041016 s
Batch:  800 , l1 Loss:  0.0975725919008255 , Time: 13501.032339572906 s
Batch:  810 , l1 Loss:  0.09833245128393173 , Time: 13512.805270671844 s
Batch:  820 , l1 Loss:  0.09795412868261337 , Time: 13524.571369409561 s
Batch:  830 , l1 Loss:  0.09727412834763527 , Time: 13536.39691567421 s
Batch:  840 , l1 Loss:  0.09787286669015885 , Time: 13548.234284162521 s
Batch:  850 , l1 Loss:  0.097813730686903 , Time: 13560.12653017044 s
Batch:  860 , l1 Loss:  0.09710708633065224 , Time: 13572.053849220276 s
Batch:  870 , l1 Loss:  0.09699229747056962 , Time: 13583.990653276443 s
Batch:  880 , l1 Loss:  0.09822850301861763 , Time: 13595.8990213871 s
Batch:  890 , l1 Loss:  0.09822874665260314 , Time: 13607.831624507904 s
Batch:  900 , l1 Loss:  0.09749066010117531 , Time: 13619.743229866028 s
Batch:  910 , l1 Loss:  0.09811577200889587 , Time: 13631.670124053955 s
Batch:  920 , l1 Loss:  0.09783555418252946 , Time: 13643.

Batch:  1910 , l1 Loss:  0.09669868648052216 , Time: 14821.291342020035 s
Batch:  1920 , l1 Loss:  0.09828607514500617 , Time: 14833.203172206879 s
Batch:  1930 , l1 Loss:  0.0979889489710331 , Time: 14845.055113077164 s
Batch:  1940 , l1 Loss:  0.09672467336058617 , Time: 14856.957318782806 s
Batch:  1950 , l1 Loss:  0.0968216374516487 , Time: 14868.831486701965 s
Batch:  1960 , l1 Loss:  0.09732453525066376 , Time: 14880.747684240341 s
Batch:  1970 , l1 Loss:  0.09680458754301072 , Time: 14892.625725507736 s
Batch:  1980 , l1 Loss:  0.09801888391375542 , Time: 14904.532778978348 s
Batch:  1990 , l1 Loss:  0.09778067097067833 , Time: 14916.413763284683 s
Batch:  2000 , l1 Loss:  0.09692433550953865 , Time: 14928.34209227562 s
Batch:  2010 , l1 Loss:  0.09695871323347091 , Time: 14940.585692882538 s
Batch:  2020 , l1 Loss:  0.09815188571810722 , Time: 14952.805101394653 s
Batch:  2030 , l1 Loss:  0.09714054465293884 , Time: 14964.988294839859 s
Batch:  2040 , l1 Loss:  0.09653128683567

Batch:  400 , l1 Loss:  0.09726045280694962 , Time: 16168.936507701874 s
Batch:  410 , l1 Loss:  0.09758597984910011 , Time: 16180.813372135162 s
Batch:  420 , l1 Loss:  0.0976669579744339 , Time: 16192.72688984871 s
Batch:  430 , l1 Loss:  0.09675105884671212 , Time: 16204.674311637878 s
Batch:  440 , l1 Loss:  0.09660834968090057 , Time: 16216.549914360046 s
Batch:  450 , l1 Loss:  0.09608095288276672 , Time: 16228.439908027649 s
Batch:  460 , l1 Loss:  0.09758634939789772 , Time: 16240.46865606308 s
Batch:  470 , l1 Loss:  0.09664669707417488 , Time: 16252.347853422165 s
Batch:  480 , l1 Loss:  0.09641780480742454 , Time: 16264.212922096252 s
Batch:  490 , l1 Loss:  0.09723957255482674 , Time: 16276.081236362457 s
Batch:  500 , l1 Loss:  0.09670391827821731 , Time: 16288.035796165466 s
Batch:  510 , l1 Loss:  0.09788425788283348 , Time: 16300.1062874794 s
Batch:  520 , l1 Loss:  0.0968261905014515 , Time: 16312.050725460052 s
Batch:  530 , l1 Loss:  0.09667059630155564 , Time: 16324

KeyboardInterrupt: 