In [1]:
import h5py

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
# models
MODEL_PATH = '/home/ankbzpx/Documents/'

discrete_encoder_path = MODEL_PATH + 'discrete continuous shape/discrete_encoder.pth'
discrete_decoder_path = MODEL_PATH + 'discrete continuous shape/discrete_decoder.pth'
unet_path = MODEL_PATH + 'discrete continuous shape/con_unet_full.pth'
continuous_model_path = MODEL_PATH + 'discrete continuous shape/continuous_model.pth'

hidden_dim_discrete = 128

# data preparation
data_file = '/home/ankbzpx/datasets/ShapeNet/ShapeNetRenderingh5_v1/03001627/sdf_train_core.h5'
sample_size = 2048
batch_size = 12
split_ratio = 0.9
depth_size = 256
num_of_workers = 12
# training
num_epoch = 40

In [4]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torchvision import transforms
import torch.nn as nn
import torchvision.transforms.functional as TF
from torch.autograd import Variable
from PIL import Image
import sys
import time
import matplotlib.pyplot as plt
from torchsummary import summary

In [5]:
# reproducible
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True
torch.manual_seed(0)
np.random.seed(0)

In [6]:
device

device(type='cuda')

In [7]:
class ChairSDFDataset(Dataset):
     
    def __init__(self, h5_file):
        
        self.file_path = h5_file
        self.dataset = None
        
        with h5py.File(self.file_path, 'r') as file:
            self.dataset_len = len(file)
            self.keys = list(file.keys())
            
        self.to_tensor = transforms.ToTensor()
     
    def __len__(self):
        return self.dataset_len
 
    def __getitem__(self, idx):
        
        #start_time = time.time()
        
        if self.dataset is None:
            self.dataset = h5py.File(self.file_path, 'r')
         
        group = self.dataset[self.keys[idx]]
        
        depth_img = self.to_tensor(Image.fromarray(np.array(group['depth_img'])))
        
        #print("--- depth preprocessing %s seconds ---" % (time.time() - start_time))
        
        sample_pt_np = np.array(group['sample_pt']).reshape(-1, 3)
        sample_sdf_np = np.array(group['sample_sdf']).reshape(-1, 1)
        
        # check size correctness and fix incorrect data
        if sample_pt_np.shape[0] != 2048:
            sample_pt_np = np.pad(sample_pt_np, ((0, 2048 - sample_pt_np.shape[0]), (0, 0)), 'reflect')
        if sample_sdf_np.shape[0] != 2048:
            sample_sdf_np = np.pad(sample_sdf_np, ((0, 2048 - sample_sdf_np.shape[0]), (0, 0)), 'reflect')
            
        
        sample_pt = torch.from_numpy(sample_pt_np).float()
        sample_sdf = torch.from_numpy(sample_sdf_np).float()
        # scale sdf
        sample_sdf = torch.sign(sample_sdf)*torch.pow(torch.abs(sample_sdf), 0.25)
        
        #print("--- subsampling %s seconds ---" % (time.time() - start_time))
        
        target_vox = torch.from_numpy(np.array(group['target_vox'])).float()
        
        sample = { 'depth_img': depth_img,
                   'sample_pt':sample_pt,
                   'sample_sdf':sample_sdf,
                   'target_vox':target_vox,
                  }
        
        return sample

In [8]:
train_sdf_dataset = ChairSDFDataset(data_file)

train_sdf_dataloader = DataLoader(train_sdf_dataset, batch_size=batch_size, shuffle=True, num_workers=num_of_workers)

In [9]:
start_epoch = 0
latent_dim = 256

from models import Discrete_encoder, Discrete_decoder, Conditional_UNET, Continuous

####################
# Discrete Encoder #
####################

discrete_encoder = Discrete_encoder(256).to(device)
discrete_encoder.load_state_dict(torch.load(discrete_encoder_path))
discrete_encoder.eval()

for child in discrete_encoder.children():
    for param in child.parameters():
        param.requires_grad = False

        
# ####################
# # Discrete Decoder #
# ####################
        
discrete_decoder = Discrete_decoder(256).to(device)
discrete_decoder.load_state_dict(torch.load(discrete_decoder_path))
discrete_decoder.eval()

for child in discrete_decoder.children():
    for param in child.parameters():
        param.requires_grad = False
        
########
# UNET #
########

# pre-trained model is loaded within the model
unet = Conditional_UNET(unet_path).to(device)
unet.eval()

for child in unet.children():
    for param in child.parameters():
        param.requires_grad = False

In [10]:
class BottleNect1D(nn.Module):
    def __init__(self, input_dim, expand = 5):
        super(BottleNect1D, self).__init__()
        
        self.block = nn.Sequential(
            nn.Linear(input_dim, expand*input_dim),
            nn.BatchNorm1d(expand*input_dim),
            nn.ReLU(),
            nn.Linear(expand*input_dim, input_dim),
            nn.BatchNorm1d(input_dim),
        )
    
    def forward(self, x):
        return x + self.block(x)

class Continuous(nn.Module):
    def __init__(self, pt_dim = 3, con_dim = 32, latent_dim = 256):
        super(Continuous, self).__init__()
        
        self.de_pt =  nn.Sequential(
            nn.Linear(pt_dim + con_dim, 64),
            nn.BatchNorm1d(64),
            BottleNect1D(64),
            nn.Linear(64, latent_dim),
            nn.BatchNorm1d(latent_dim),
            BottleNect1D(latent_dim),
        )
        
        self.de_1 = nn.Sequential(
            nn.Linear(2*latent_dim, 2*latent_dim),
            nn.BatchNorm1d(2*latent_dim),
            BottleNect1D(2*latent_dim),
            nn.Linear(2*latent_dim, 2*latent_dim),
            nn.BatchNorm1d(2*latent_dim),
            BottleNect1D(2*latent_dim),
        )
        
        self.de_2 = nn.Sequential(
            nn.Linear(4*latent_dim, 2*latent_dim),
            nn.BatchNorm1d(2*latent_dim),
            BottleNect1D(2*latent_dim),
            nn.Linear(2*latent_dim, 2*latent_dim),
            nn.BatchNorm1d(2*latent_dim),
            BottleNect1D(2*latent_dim),
            nn.Linear(2*latent_dim, 1),
            nn.Tanh(),
        )
        
    def forward(self, pt, con, z):
        
        cat = torch.cat((self.de_pt(torch.cat((pt, con), 1)), z), 1)
        out = self.de_1(cat)
        out = self.de_2(torch.cat((out, cat), 1))
        
        return out

In [11]:
continuous = Continuous().to(device)

In [12]:
l1loss = nn.L1Loss()

In [13]:
lr = 3e-4

model = continuous

optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [14]:
def save_checkpoint(state):
    torch.save(state, 'continuous_model_checkpoint.pth.tar')

In [14]:
checkpoint = torch.load('continuous_model_checkpoint.pth.tar')
start_epoch = checkpoint['epoch']
model.load_state_dict(checkpoint['continuous_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])

In [15]:
# advanced indexing 2x2x2 context from voxel
def getContext(sample_pt_query, vox):
    
    # sample_pt bxmx3
    # vox bxcxdimxdimxdim
    
    channel_size = vox.shape[1]
    batch_size, sample_size, _ = sample_pt_query.shape
    meshgrid_base = torch.Tensor(np.meshgrid(np.arange(0, batch_size), np.arange(0, channel_size), np.arange(0, 2), np.arange(0, 2), np.arange(0, 2))).int()
    context = torch.empty((batch_size, sample_size, channel_size, 2, 2, 2))
    
    for j in range(context.shape[1]):
        context[:, j, :, :, :, :] = vox[
                    meshgrid_base[0].long(),
                    meshgrid_base[1].long(),
                    (meshgrid_base[2] + sample_pt_query[:, j, 0].reshape(1, -1, 1, 1, 1)).long(), 
                    (meshgrid_base[3] + sample_pt_query[:, j, 1].reshape(1, -1, 1, 1, 1)).long(), 
                    (meshgrid_base[4] + sample_pt_query[:, j, 2].reshape(1, -1, 1, 1, 1)).long()
                ].transpose(0, 1)
    
    # b x c x m x 2 x 2 x 2
    return context.transpose(1, 2)

In [16]:
def trilinearInterpolation(context, dx, dy, dz):
    
    v0 = context[:, :, :, 0, 0, 0]*(1-dx)*(1-dy)*(1-dz)
    v1 = context[:, :, :, 1, 0, 0]*dx*(1-dy)*(1-dz)
    v2 = context[:, :, :, 0, 1, 0]*(1-dx)*dy*(1-dz)
    v3 = context[:, :, :, 1, 1, 0]*dx*dy*(1-dz)
    v4 = context[:, :, :, 0, 0, 1]*(1-dx)*(1-dy)*dz
    v5 = context[:, :, :, 1, 0, 1]*dx*(1-dy)*dz
    v6 = context[:, :, :, 0, 1, 1]*(1-dx)*dy*dz
    v7 = context[:, :, :, 1, 1, 1]*dx*dy*dz
    
    # b x c x m 1
    return v0 + v1 + v2 + v3 + v4 + v5 + v6 + v7

In [17]:
#########################################
#test sub module

import time

vox_size = 32
latent_dim = 256
con_dim = 32

batch_len = len(train_sdf_dataloader)

print("Starting Training Loop...")

start_time = time.time()

for epoch in range(start_epoch, num_epoch):
    
    loss_list = []
    loss_batch = []
    
    print("Epoch: ", epoch)
    
    count = 0
    
    for i, data in enumerate(train_sdf_dataloader):
        
        ####################
        # Data preparation #
        ####################
        
        # b x 1 x 256 x 256
        depth_img = data['depth_img'].to(device)
        # b x 128 x 1 x 1
        z = discrete_encoder(depth_img)
        # b x n x 3
        # DO NOT scale by np.sqrt(3)
        sample_pt = data['sample_pt']
        # b x n x 1
        sample_sdf = data['sample_sdf']
        
        # b x 16 x 64 x 64 x 64
        target_vox = data['target_vox'].to(device)
        vox_feature = unet(target_vox, z)
        #vox_feature = unet(torch.sigmoid(discrete_decoder(mapping(z))))
        
        ####################
        # indexing context #
        ####################
        
        # stay with cpu for v-ram efficiency
        sample_pt_normalized = sample_pt + torch.tensor([0.5, 0.5, 0.5])
        # (0, vox_size-1)
        sample_pt_scale = torch.clamp(sample_pt_normalized* (vox_size-1), 0, (vox_size-1)-1e-5)
        # (0, vox_size-2)
        sample_pt_query = torch.clamp((sample_pt_scale).int(), 0, (vox_size-2))
        sample_pt_distance = sample_pt_scale - sample_pt_query
        
        context = getContext(sample_pt_query, vox_feature)
        
        dx = sample_pt_distance[:, :, 0].unsqueeze(1)
        dy = sample_pt_distance[:, :, 1].unsqueeze(1)
        dz = sample_pt_distance[:, :, 2].unsqueeze(1)
        # local feature
        con = trilinearInterpolation(context, dx, dy, dz)
        
        ################################
        # Reshape input & forward pass #
        ################################
        
        sample_pt = sample_pt.transpose(-1, -2).to(device)
        con = con.to(device)
        z = z.squeeze(-1).squeeze(-1).repeat(1, 1, sample_size)
        sample_sdf = sample_sdf.transpose(-1, -2).to(device)
        
        
        sample_pt = sample_pt.transpose(-1, -2).reshape(-1, 3)
        con = con.transpose(-1, -2).reshape(-1, con_dim)
        z = z.transpose(-1, -2).reshape(-1, latent_dim)
        sample_sdf = sample_sdf.transpose(-1, -2).reshape(-1, 1)
        
        
        optimizer.zero_grad()
        
        pred_sdf = model(sample_pt, con, z)
        
        loss_l1 = l1loss(pred_sdf, sample_sdf)
        
        loss = loss_l1
        
        loss.backward()
        
        loss_list.append(loss_l1.item())
        loss_batch.append(loss_l1.item())
        
        optimizer.step()
        
        #scheduler.step()
        
        if count != 0 and count % 10 == 0:
            loss_batch_avg = np.average(loss_batch)
            
            print("Batch: ", count, ", l1 Loss: ", loss_batch_avg, ", Time: %s s" % (time.time() - start_time))
            
            if count % 500 == 0:
                torch.save(model.state_dict(), continuous_model_path)
                
            loss_batch.clear()
            
        count += 1
        
    print("Epoch: ", epoch, ', l1 loss: ', np.average(loss_list))
    
    loss_list.clear()
    
    torch.save(model.state_dict(), continuous_model_path)
    
    save_checkpoint({
        'epoch': epoch + 1,
        'continuous_state_dict': model.state_dict(),
        'optimizer' : optimizer.state_dict(),
    })
    
print("Training finished")

Starting Training Loop...
Epoch:  0
Batch:  10 , l1 Loss:  0.6222708333622325 , Time: 14.060269832611084 s
Batch:  20 , l1 Loss:  0.33097506910562513 , Time: 26.45430064201355 s
Batch:  30 , l1 Loss:  0.28993759155273435 , Time: 38.93435883522034 s
Batch:  40 , l1 Loss:  0.23616625517606735 , Time: 51.451082706451416 s
Batch:  50 , l1 Loss:  0.22275505512952803 , Time: 64.02029180526733 s
Batch:  60 , l1 Loss:  0.20075249671936035 , Time: 76.67286324501038 s
Batch:  70 , l1 Loss:  0.18069818019866943 , Time: 89.28528094291687 s
Batch:  80 , l1 Loss:  0.18161795437335967 , Time: 101.93931722640991 s
Batch:  90 , l1 Loss:  0.17924284785985947 , Time: 114.61278223991394 s
Batch:  100 , l1 Loss:  0.14637547731399536 , Time: 127.29287838935852 s
Batch:  110 , l1 Loss:  0.1511179819703102 , Time: 140.0021674633026 s
Batch:  120 , l1 Loss:  0.1369858279824257 , Time: 152.76185870170593 s
Batch:  130 , l1 Loss:  0.14808819741010665 , Time: 165.49697756767273 s
Batch:  140 , l1 Loss:  0.1552505

Batch:  1140 , l1 Loss:  0.11564047709107399 , Time: 1467.037314414978 s
Batch:  1150 , l1 Loss:  0.11183306947350502 , Time: 1480.110852241516 s
Batch:  1160 , l1 Loss:  0.10603501051664352 , Time: 1493.1291942596436 s
Batch:  1170 , l1 Loss:  0.10788829177618027 , Time: 1506.0277047157288 s
Batch:  1180 , l1 Loss:  0.10977548584342003 , Time: 1518.9243104457855 s
Batch:  1190 , l1 Loss:  0.11410747244954109 , Time: 1531.8259069919586 s
Batch:  1200 , l1 Loss:  0.11226021870970726 , Time: 1544.7649977207184 s
Batch:  1210 , l1 Loss:  0.11147228777408599 , Time: 1557.6638133525848 s
Batch:  1220 , l1 Loss:  0.11225007697939873 , Time: 1570.6371417045593 s
Batch:  1230 , l1 Loss:  0.11264226734638214 , Time: 1583.596331357956 s
Batch:  1240 , l1 Loss:  0.11193570867180824 , Time: 1596.6141147613525 s
Batch:  1250 , l1 Loss:  0.11199744269251824 , Time: 1609.555477142334 s
Batch:  1260 , l1 Loss:  0.11478796526789665 , Time: 1622.5303087234497 s
Batch:  1270 , l1 Loss:  0.112783951312303

Batch:  2260 , l1 Loss:  0.10759318098425866 , Time: 2929.718319416046 s
Batch:  2270 , l1 Loss:  0.10892292857170105 , Time: 2942.836816072464 s
Batch:  2280 , l1 Loss:  0.10942751243710518 , Time: 2955.9783675670624 s
Batch:  2290 , l1 Loss:  0.10583948418498039 , Time: 2969.0882625579834 s
Batch:  2300 , l1 Loss:  0.10686551779508591 , Time: 2982.184993505478 s
Batch:  2310 , l1 Loss:  0.11189115643501282 , Time: 2995.226074695587 s
Batch:  2320 , l1 Loss:  0.10833612233400344 , Time: 3008.2703187465668 s
Batch:  2330 , l1 Loss:  0.11462832391262054 , Time: 3021.271742105484 s
Batch:  2340 , l1 Loss:  0.11157186627388001 , Time: 3034.3496351242065 s
Batch:  2350 , l1 Loss:  0.11146168783307076 , Time: 3047.395822286606 s
Batch:  2360 , l1 Loss:  0.10777470543980598 , Time: 3060.4134299755096 s
Batch:  2370 , l1 Loss:  0.10738113522529602 , Time: 3073.4846918582916 s
Batch:  2380 , l1 Loss:  0.10889825224876404 , Time: 3086.5090782642365 s
Batch:  2390 , l1 Loss:  0.10996963456273079

Batch:  3380 , l1 Loss:  0.10498561933636666 , Time: 4393.810075998306 s
Batch:  3390 , l1 Loss:  0.10814249664545059 , Time: 4406.872135877609 s
Batch:  3400 , l1 Loss:  0.10975121781229973 , Time: 4419.924778461456 s
Batch:  3410 , l1 Loss:  0.10419900193810464 , Time: 4433.002830982208 s
Batch:  3420 , l1 Loss:  0.10834634006023407 , Time: 4446.078515052795 s
Batch:  3430 , l1 Loss:  0.10692025274038315 , Time: 4459.1771504879 s
Batch:  3440 , l1 Loss:  0.10774853900074959 , Time: 4472.474814891815 s
Batch:  3450 , l1 Loss:  0.10331388339400291 , Time: 4485.553847551346 s
Batch:  3460 , l1 Loss:  0.10835448205471039 , Time: 4498.632539510727 s
Batch:  3470 , l1 Loss:  0.10801061168313027 , Time: 4511.690908670425 s
Batch:  3480 , l1 Loss:  0.10764374509453774 , Time: 4524.768872976303 s
Batch:  3490 , l1 Loss:  0.10752873122692108 , Time: 4537.801419496536 s
Batch:  3500 , l1 Loss:  0.10892132669687271 , Time: 4550.838408708572 s
Batch:  3510 , l1 Loss:  0.10953421369194985 , Time: 

Batch:  4510 , l1 Loss:  0.1072274774312973 , Time: 5876.233028411865 s
Batch:  4520 , l1 Loss:  0.10630022436380386 , Time: 5889.348572254181 s
Batch:  4530 , l1 Loss:  0.10596808195114135 , Time: 5902.486583948135 s
Batch:  4540 , l1 Loss:  0.10552069991827011 , Time: 5915.667801856995 s
Batch:  4550 , l1 Loss:  0.10414698347449303 , Time: 5928.802874326706 s
Batch:  4560 , l1 Loss:  0.10569369941949844 , Time: 5941.8566081523895 s
Batch:  4570 , l1 Loss:  0.1065135046839714 , Time: 5954.951425075531 s
Batch:  4580 , l1 Loss:  0.1053469367325306 , Time: 5968.05441403389 s
Batch:  4590 , l1 Loss:  0.10605354905128479 , Time: 5981.195420026779 s
Batch:  4600 , l1 Loss:  0.10657467693090439 , Time: 5994.337844848633 s
Batch:  4610 , l1 Loss:  0.1058957777917385 , Time: 6007.475194692612 s
Batch:  4620 , l1 Loss:  0.10367530211806297 , Time: 6020.61604642868 s
Batch:  4630 , l1 Loss:  0.10403775870800018 , Time: 6033.755775690079 s
Batch:  4640 , l1 Loss:  0.1033033013343811 , Time: 6046

Batch:  5640 , l1 Loss:  0.1024406410753727 , Time: 7361.138870954514 s
Batch:  5650 , l1 Loss:  0.10955684334039688 , Time: 7374.250108003616 s
Batch:  5660 , l1 Loss:  0.10661673173308372 , Time: 7387.355711698532 s
Batch:  5670 , l1 Loss:  0.10515652671456337 , Time: 7400.548562526703 s
Batch:  5680 , l1 Loss:  0.10380772948265075 , Time: 7413.685817480087 s
Batch:  5690 , l1 Loss:  0.10943234711885452 , Time: 7426.845082759857 s
Batch:  5700 , l1 Loss:  0.10539023503661156 , Time: 7440.001220703125 s
Batch:  5710 , l1 Loss:  0.10632451847195626 , Time: 7453.135695934296 s
Batch:  5720 , l1 Loss:  0.10382528826594353 , Time: 7466.312672138214 s
Batch:  5730 , l1 Loss:  0.10651838630437852 , Time: 7479.527826309204 s
Batch:  5740 , l1 Loss:  0.1057260848581791 , Time: 7492.665759801865 s
Batch:  5750 , l1 Loss:  0.10255171209573746 , Time: 7505.795810222626 s
Batch:  5760 , l1 Loss:  0.10603666007518768 , Time: 7518.938846349716 s
Batch:  5770 , l1 Loss:  0.10523307621479035 , Time: 

Batch:  6770 , l1 Loss:  0.10537129342556 , Time: 8844.178131341934 s
Batch:  6780 , l1 Loss:  0.10393856465816498 , Time: 8857.294424533844 s
Batch:  6790 , l1 Loss:  0.10471213459968567 , Time: 8870.41519832611 s
Batch:  6800 , l1 Loss:  0.1028789222240448 , Time: 8883.536967515945 s
Batch:  6810 , l1 Loss:  0.10454447194933891 , Time: 8896.652315378189 s
Batch:  6820 , l1 Loss:  0.1057123564183712 , Time: 8909.815384626389 s
Batch:  6830 , l1 Loss:  0.10196040347218513 , Time: 8922.93256521225 s
Batch:  6840 , l1 Loss:  0.10286107212305069 , Time: 8936.031563282013 s
Batch:  6850 , l1 Loss:  0.10393947660923004 , Time: 8949.195801496506 s
Batch:  6860 , l1 Loss:  0.10615792125463486 , Time: 8962.330397844315 s
Batch:  6870 , l1 Loss:  0.10373349264264106 , Time: 8975.467922449112 s
Batch:  6880 , l1 Loss:  0.10336619317531585 , Time: 8988.605613708496 s
Batch:  6890 , l1 Loss:  0.10348541662096977 , Time: 9001.743015527725 s
Batch:  6900 , l1 Loss:  0.10432492569088936 , Time: 9014.

Batch:  890 , l1 Loss:  0.10234321877360344 , Time: 10337.017402887344 s
Batch:  900 , l1 Loss:  0.10431804656982421 , Time: 10350.133598566055 s
Batch:  910 , l1 Loss:  0.10298867970705032 , Time: 10363.130765676498 s
Batch:  920 , l1 Loss:  0.1082956239581108 , Time: 10376.161674261093 s
Batch:  930 , l1 Loss:  0.10528303012251854 , Time: 10389.121956586838 s
Batch:  940 , l1 Loss:  0.10320260748267174 , Time: 10402.074187994003 s
Batch:  950 , l1 Loss:  0.10965722724795342 , Time: 10415.043492555618 s
Batch:  960 , l1 Loss:  0.10305386409163475 , Time: 10428.019608020782 s
Batch:  970 , l1 Loss:  0.10790353566408158 , Time: 10440.956770420074 s
Batch:  980 , l1 Loss:  0.10456297993659973 , Time: 10453.730530023575 s
Batch:  990 , l1 Loss:  0.10593241825699806 , Time: 10466.544493198395 s
Batch:  1000 , l1 Loss:  0.1028931774199009 , Time: 10479.418792009354 s
Batch:  1010 , l1 Loss:  0.10792988017201424 , Time: 10492.934875011444 s
Batch:  1020 , l1 Loss:  0.10244600474834442 , Time

Batch:  2010 , l1 Loss:  0.10252778232097626 , Time: 11903.433548927307 s
Batch:  2020 , l1 Loss:  0.10250156596302987 , Time: 11916.52677154541 s
Batch:  2030 , l1 Loss:  0.10259456112980843 , Time: 11929.56952548027 s
Batch:  2040 , l1 Loss:  0.10104589909315109 , Time: 11942.52896118164 s
Batch:  2050 , l1 Loss:  0.10344970375299453 , Time: 11955.461243629456 s
Batch:  2060 , l1 Loss:  0.10299953371286392 , Time: 11968.43947172165 s
Batch:  2070 , l1 Loss:  0.1029331423342228 , Time: 11981.437920331955 s
Batch:  2080 , l1 Loss:  0.10368879288434982 , Time: 11994.474351167679 s
Batch:  2090 , l1 Loss:  0.1049007959663868 , Time: 12007.453614711761 s
Batch:  2100 , l1 Loss:  0.10444119200110435 , Time: 12020.453511476517 s
Batch:  2110 , l1 Loss:  0.10319137275218963 , Time: 12033.49183177948 s
Batch:  2120 , l1 Loss:  0.10318803191184997 , Time: 12046.469720602036 s
Batch:  2130 , l1 Loss:  0.10210988074541091 , Time: 12059.472611427307 s
Batch:  2140 , l1 Loss:  0.10176787748932839 

Batch:  3130 , l1 Loss:  0.10373829379677772 , Time: 13359.46435880661 s
Batch:  3140 , l1 Loss:  0.10226365625858307 , Time: 13372.386568069458 s
Batch:  3150 , l1 Loss:  0.1089572250843048 , Time: 13385.198198795319 s
Batch:  3160 , l1 Loss:  0.10833229944109916 , Time: 13398.080803632736 s
Batch:  3170 , l1 Loss:  0.10600491687655449 , Time: 13410.959999322891 s
Batch:  3180 , l1 Loss:  0.10674877092242241 , Time: 13423.895334005356 s
Batch:  3190 , l1 Loss:  0.10544400066137313 , Time: 13436.812149763107 s
Batch:  3200 , l1 Loss:  0.10369816496968269 , Time: 13449.772837400436 s
Batch:  3210 , l1 Loss:  0.104886594414711 , Time: 13462.751316785812 s
Batch:  3220 , l1 Loss:  0.10470624640583992 , Time: 13475.68816947937 s
Batch:  3230 , l1 Loss:  0.10679475665092468 , Time: 13488.624327421188 s
Batch:  3240 , l1 Loss:  0.10548666268587112 , Time: 13501.578483819962 s
Batch:  3250 , l1 Loss:  0.10634243190288543 , Time: 13514.533401727676 s
Batch:  3260 , l1 Loss:  0.1028290562331676

Batch:  4250 , l1 Loss:  0.10302729681134223 , Time: 14815.190380573273 s
Batch:  4260 , l1 Loss:  0.102212905138731 , Time: 14828.306272983551 s
Batch:  4270 , l1 Loss:  0.10838131532073021 , Time: 14841.377132892609 s
Batch:  4280 , l1 Loss:  0.1194544978439808 , Time: 14854.430921316147 s
Batch:  4290 , l1 Loss:  0.12070790752768516 , Time: 14867.488626718521 s
Batch:  4300 , l1 Loss:  0.11568413451313972 , Time: 14880.602949857712 s
Batch:  4310 , l1 Loss:  0.11367223411798477 , Time: 14893.640291929245 s
Batch:  4320 , l1 Loss:  0.1099728912115097 , Time: 14906.67838048935 s
Batch:  4330 , l1 Loss:  0.10706565305590629 , Time: 14919.807616710663 s
Batch:  4340 , l1 Loss:  0.10613504648208619 , Time: 14932.871587753296 s
Batch:  4350 , l1 Loss:  0.10445750206708908 , Time: 14945.923359394073 s
Batch:  4360 , l1 Loss:  0.10483096763491631 , Time: 14958.98577618599 s
Batch:  4370 , l1 Loss:  0.1043745070695877 , Time: 14972.025510549545 s
Batch:  4380 , l1 Loss:  0.10422563180327415 

Batch:  5370 , l1 Loss:  0.10279466584324837 , Time: 16275.797885894775 s
Batch:  5380 , l1 Loss:  0.10115106105804443 , Time: 16288.773777723312 s
Batch:  5390 , l1 Loss:  0.10264011919498443 , Time: 16301.814002513885 s
Batch:  5400 , l1 Loss:  0.10465280264616013 , Time: 16314.81283211708 s
Batch:  5410 , l1 Loss:  0.10226431488990784 , Time: 16327.78382396698 s
Batch:  5420 , l1 Loss:  0.10344723090529442 , Time: 16340.76855301857 s
Batch:  5430 , l1 Loss:  0.10581233501434326 , Time: 16353.808062076569 s
Batch:  5440 , l1 Loss:  0.10328185707330703 , Time: 16366.805819749832 s
Batch:  5450 , l1 Loss:  0.10282253473997116 , Time: 16379.803084611893 s
Batch:  5460 , l1 Loss:  0.10351699590682983 , Time: 16392.819905757904 s
Batch:  5470 , l1 Loss:  0.10315023586153985 , Time: 16405.82869553566 s
Batch:  5480 , l1 Loss:  0.10163533315062523 , Time: 16418.81312060356 s
Batch:  5490 , l1 Loss:  0.10243872702121734 , Time: 16431.70971274376 s
Batch:  5500 , l1 Loss:  0.10138046443462372

Batch:  6490 , l1 Loss:  0.1020821489393711 , Time: 17730.48830962181 s
Batch:  6500 , l1 Loss:  0.10347632691264153 , Time: 17743.56121492386 s
Batch:  6510 , l1 Loss:  0.1028152734041214 , Time: 17756.65947151184 s
Batch:  6520 , l1 Loss:  0.10131194591522216 , Time: 17769.717993974686 s
Batch:  6530 , l1 Loss:  0.09908441007137299 , Time: 17782.735707998276 s
Batch:  6540 , l1 Loss:  0.10160050168633461 , Time: 17795.811142206192 s
Batch:  6550 , l1 Loss:  0.1035899505019188 , Time: 17808.843181610107 s
Batch:  6560 , l1 Loss:  0.10203099548816681 , Time: 17821.855226516724 s
Batch:  6570 , l1 Loss:  0.10337842926383019 , Time: 17834.89602279663 s
Batch:  6580 , l1 Loss:  0.10317335650324821 , Time: 17847.93301510811 s
Batch:  6590 , l1 Loss:  0.10337360203266144 , Time: 17860.96911597252 s
Batch:  6600 , l1 Loss:  0.10119169279932975 , Time: 17873.965989112854 s
Batch:  6610 , l1 Loss:  0.1014327235519886 , Time: 17886.915684223175 s
Batch:  6620 , l1 Loss:  0.10296652168035507 , T

Batch:  590 , l1 Loss:  0.10133177787065506 , Time: 19185.093391656876 s
Batch:  600 , l1 Loss:  0.10514897033572197 , Time: 19198.035388231277 s
Batch:  610 , l1 Loss:  0.10144117027521134 , Time: 19210.936295986176 s
Batch:  620 , l1 Loss:  0.10183952227234841 , Time: 19223.84939956665 s
Batch:  630 , l1 Loss:  0.10350258946418762 , Time: 19236.800542593002 s
Batch:  640 , l1 Loss:  0.10191370695829391 , Time: 19249.681005954742 s
Batch:  650 , l1 Loss:  0.09916811808943748 , Time: 19262.599314689636 s
Batch:  660 , l1 Loss:  0.10215995535254478 , Time: 19275.52956223488 s
Batch:  670 , l1 Loss:  0.10458474680781364 , Time: 19288.463523626328 s
Batch:  680 , l1 Loss:  0.10220379382371902 , Time: 19301.384677648544 s
Batch:  690 , l1 Loss:  0.10037873163819314 , Time: 19314.277300834656 s
Batch:  700 , l1 Loss:  0.10164713189005851 , Time: 19327.194789409637 s
Batch:  710 , l1 Loss:  0.10123834758996964 , Time: 19340.112315654755 s
Batch:  720 , l1 Loss:  0.10122834667563438 , Time: 1

Batch:  1720 , l1 Loss:  0.10303688123822212 , Time: 20648.18138217926 s
Batch:  1730 , l1 Loss:  0.10309863537549972 , Time: 20661.147671461105 s
Batch:  1740 , l1 Loss:  0.10065256282687188 , Time: 20674.15418434143 s
Batch:  1750 , l1 Loss:  0.099588543176651 , Time: 20687.133922100067 s
Batch:  1760 , l1 Loss:  0.10301246419548989 , Time: 20700.129646539688 s
Batch:  1770 , l1 Loss:  0.1032145656645298 , Time: 20713.065435647964 s
Batch:  1780 , l1 Loss:  0.10258913636207581 , Time: 20726.057560920715 s
Batch:  1790 , l1 Loss:  0.1006972149014473 , Time: 20739.12262248993 s
Batch:  1800 , l1 Loss:  0.10199219584465027 , Time: 20752.204557657242 s
Batch:  1810 , l1 Loss:  0.1047582633793354 , Time: 20765.257970571518 s
Batch:  1820 , l1 Loss:  0.1026942141354084 , Time: 20778.155484199524 s
Batch:  1830 , l1 Loss:  0.10379894003272057 , Time: 20791.05418777466 s
Batch:  1840 , l1 Loss:  0.10468000024557114 , Time: 20803.987063884735 s
Batch:  1850 , l1 Loss:  0.10677537396550178 , T

Batch:  2840 , l1 Loss:  0.10545584261417389 , Time: 22098.131308555603 s
Batch:  2850 , l1 Loss:  0.10223342850804329 , Time: 22111.13097167015 s
Batch:  2860 , l1 Loss:  0.1011072002351284 , Time: 22124.089428901672 s
Batch:  2870 , l1 Loss:  0.10200414434075356 , Time: 22137.065145254135 s
Batch:  2880 , l1 Loss:  0.10208194404840469 , Time: 22150.10292863846 s
Batch:  2890 , l1 Loss:  0.09981720298528671 , Time: 22163.083998918533 s
Batch:  2900 , l1 Loss:  0.10235111489892006 , Time: 22176.12290453911 s
Batch:  2910 , l1 Loss:  0.10320847034454346 , Time: 22189.09217453003 s
Batch:  2920 , l1 Loss:  0.10144681558012962 , Time: 22202.132350206375 s
Batch:  2930 , l1 Loss:  0.10025516971945762 , Time: 22215.25816822052 s
Batch:  2940 , l1 Loss:  0.10062126740813256 , Time: 22228.369537591934 s
Batch:  2950 , l1 Loss:  0.10009031444787979 , Time: 22241.500529527664 s
Batch:  2960 , l1 Loss:  0.10247887894511223 , Time: 22254.580592870712 s
Batch:  2970 , l1 Loss:  0.10254285708069802

Batch:  3960 , l1 Loss:  0.10209522917866706 , Time: 23549.411649227142 s
Batch:  3970 , l1 Loss:  0.10054313465952873 , Time: 23562.39328813553 s
Batch:  3980 , l1 Loss:  0.10439046993851661 , Time: 23575.371765375137 s
Batch:  3990 , l1 Loss:  0.10179613381624222 , Time: 23588.290605545044 s
Batch:  4000 , l1 Loss:  0.1031246930360794 , Time: 23601.270037651062 s
Batch:  4010 , l1 Loss:  0.10229721441864967 , Time: 23614.263493537903 s
Batch:  4020 , l1 Loss:  0.10407048910856247 , Time: 23627.238141059875 s
Batch:  4030 , l1 Loss:  0.10249204710125923 , Time: 23640.23616361618 s
Batch:  4040 , l1 Loss:  0.10046897307038308 , Time: 23653.174507379532 s
Batch:  4050 , l1 Loss:  0.10072513520717621 , Time: 23666.16606760025 s
Batch:  4060 , l1 Loss:  0.10137928426265716 , Time: 23679.18185400963 s
Batch:  4070 , l1 Loss:  0.1010911799967289 , Time: 23692.161667346954 s
Batch:  4080 , l1 Loss:  0.10025181695818901 , Time: 23705.10351228714 s
Batch:  4090 , l1 Loss:  0.10125950574874878 

Batch:  5080 , l1 Loss:  0.10449554696679116 , Time: 25000.799622297287 s
Batch:  5090 , l1 Loss:  0.09967636093497276 , Time: 25013.737236738205 s
Batch:  5100 , l1 Loss:  0.10055437311530113 , Time: 25026.647030115128 s
Batch:  5110 , l1 Loss:  0.10231124013662338 , Time: 25039.58411049843 s
Batch:  5120 , l1 Loss:  0.10345015525817872 , Time: 25052.517202854156 s
Batch:  5130 , l1 Loss:  0.10485049858689308 , Time: 25065.459993839264 s
Batch:  5140 , l1 Loss:  0.10052623003721237 , Time: 25078.393759965897 s
Batch:  5150 , l1 Loss:  0.1032386414706707 , Time: 25091.28619670868 s
Batch:  5160 , l1 Loss:  0.10135585591197013 , Time: 25104.159628391266 s
Batch:  5170 , l1 Loss:  0.09907504692673683 , Time: 25117.117537498474 s
Batch:  5180 , l1 Loss:  0.10064190849661828 , Time: 25130.015751838684 s
Batch:  5190 , l1 Loss:  0.10101099535822869 , Time: 25143.009566307068 s
Batch:  5200 , l1 Loss:  0.09978479295969009 , Time: 25155.942548513412 s
Batch:  5210 , l1 Loss:  0.10016264691948

Batch:  6200 , l1 Loss:  0.09987829625606537 , Time: 26452.75675010681 s
Batch:  6210 , l1 Loss:  0.10325600430369378 , Time: 26465.694767475128 s
Batch:  6220 , l1 Loss:  0.10297117680311203 , Time: 26478.614973306656 s
Batch:  6230 , l1 Loss:  0.10075679868459701 , Time: 26491.5708861351 s
Batch:  6240 , l1 Loss:  0.10201162546873092 , Time: 26504.524520397186 s
Batch:  6250 , l1 Loss:  0.10098938941955567 , Time: 26517.500088214874 s
Batch:  6260 , l1 Loss:  0.10180521085858345 , Time: 26530.458481788635 s
Batch:  6270 , l1 Loss:  0.10145903006196022 , Time: 26543.40000796318 s
Batch:  6280 , l1 Loss:  0.10224303454160691 , Time: 26556.39503955841 s
Batch:  6290 , l1 Loss:  0.10111218914389611 , Time: 26569.371505975723 s
Batch:  6300 , l1 Loss:  0.10072181671857834 , Time: 26582.331192731857 s
Batch:  6310 , l1 Loss:  0.10042344480752945 , Time: 26595.337315797806 s
Batch:  6320 , l1 Loss:  0.10003795176744461 , Time: 26608.302629470825 s
Batch:  6330 , l1 Loss:  0.1013329610228538

Batch:  300 , l1 Loss:  0.10054391101002694 , Time: 27910.838450193405 s
Batch:  310 , l1 Loss:  0.10132896304130554 , Time: 27923.731503248215 s
Batch:  320 , l1 Loss:  0.10147906988859176 , Time: 27936.541546583176 s
Batch:  330 , l1 Loss:  0.10131789445877075 , Time: 27949.4203042984 s
Batch:  340 , l1 Loss:  0.10259396061301232 , Time: 27962.303371429443 s
Batch:  350 , l1 Loss:  0.10191823691129684 , Time: 27975.219435214996 s
Batch:  360 , l1 Loss:  0.10160309374332428 , Time: 27988.13053393364 s
Batch:  370 , l1 Loss:  0.10033189356327057 , Time: 28001.26632642746 s
Batch:  380 , l1 Loss:  0.10131421759724617 , Time: 28014.31858921051 s
Batch:  390 , l1 Loss:  0.10200956240296363 , Time: 28027.377527713776 s
Batch:  400 , l1 Loss:  0.10046070590615272 , Time: 28040.376056432724 s
Batch:  410 , l1 Loss:  0.09945999234914779 , Time: 28053.40935897827 s
Batch:  420 , l1 Loss:  0.10006348118185997 , Time: 28066.439769744873 s
Batch:  430 , l1 Loss:  0.09940270632505417 , Time: 28079

Batch:  1430 , l1 Loss:  0.10132381692528725 , Time: 29375.817677259445 s
Batch:  1440 , l1 Loss:  0.10107669606804848 , Time: 29388.731818199158 s
Batch:  1450 , l1 Loss:  0.10133690908551216 , Time: 29401.654061317444 s
Batch:  1460 , l1 Loss:  0.10746443271636963 , Time: 29414.608151435852 s
Batch:  1470 , l1 Loss:  0.10680588260293007 , Time: 29427.523127555847 s
Batch:  1480 , l1 Loss:  0.10688736736774444 , Time: 29440.464397192 s
Batch:  1490 , l1 Loss:  0.10672667324542999 , Time: 29453.374522686005 s
Batch:  1500 , l1 Loss:  0.10146920457482338 , Time: 29466.29358267784 s
Batch:  1510 , l1 Loss:  0.10123461261391639 , Time: 29479.267792224884 s
Batch:  1520 , l1 Loss:  0.1023484081029892 , Time: 29492.18756699562 s
Batch:  1530 , l1 Loss:  0.09910318329930305 , Time: 29505.128677129745 s
Batch:  1540 , l1 Loss:  0.10035663694143296 , Time: 29518.065029144287 s
Batch:  1550 , l1 Loss:  0.10117477551102638 , Time: 29531.000784873962 s
Batch:  1560 , l1 Loss:  0.10092361494898797

Batch:  2550 , l1 Loss:  0.10131178423762321 , Time: 30826.39546251297 s
Batch:  2560 , l1 Loss:  0.10187182649970054 , Time: 30839.334251880646 s
Batch:  2570 , l1 Loss:  0.10205570757389068 , Time: 30852.265276432037 s
Batch:  2580 , l1 Loss:  0.09901227653026581 , Time: 30865.18859219551 s
Batch:  2590 , l1 Loss:  0.10131625384092331 , Time: 30878.106032848358 s
Batch:  2600 , l1 Loss:  0.10074106007814407 , Time: 30891.063526153564 s
Batch:  2610 , l1 Loss:  0.10112023651599884 , Time: 30904.0222697258 s
Batch:  2620 , l1 Loss:  0.09982427433133126 , Time: 30916.942880392075 s
Batch:  2630 , l1 Loss:  0.0986778162419796 , Time: 30929.774524450302 s
Batch:  2640 , l1 Loss:  0.10179335176944733 , Time: 30942.702518701553 s
Batch:  2650 , l1 Loss:  0.10148065909743309 , Time: 30955.597086429596 s
Batch:  2660 , l1 Loss:  0.10148991867899895 , Time: 30968.49508213997 s
Batch:  2670 , l1 Loss:  0.09933649301528931 , Time: 30981.435520887375 s
Batch:  2680 , l1 Loss:  0.10082468539476394

Batch:  3670 , l1 Loss:  0.09983877763152123 , Time: 32278.2036986351 s
Batch:  3680 , l1 Loss:  0.10030801892280579 , Time: 32291.236379623413 s
Batch:  3690 , l1 Loss:  0.10166999697685242 , Time: 32304.212918758392 s
Batch:  3700 , l1 Loss:  0.10088608115911483 , Time: 32317.250661849976 s
Batch:  3710 , l1 Loss:  0.10052293315529823 , Time: 32330.25079727173 s
Batch:  3720 , l1 Loss:  0.10248799994587898 , Time: 32343.15325665474 s
Batch:  3730 , l1 Loss:  0.09996562823653221 , Time: 32355.93921971321 s
Batch:  3740 , l1 Loss:  0.09995659813284874 , Time: 32368.75483727455 s
Batch:  3750 , l1 Loss:  0.09980555400252342 , Time: 32381.69346833229 s
Batch:  3760 , l1 Loss:  0.10138349756598472 , Time: 32394.670224428177 s
Batch:  3770 , l1 Loss:  0.10158933848142623 , Time: 32407.586256742477 s
Batch:  3780 , l1 Loss:  0.10122073739767075 , Time: 32420.503121614456 s
Batch:  3790 , l1 Loss:  0.10033605322241783 , Time: 32433.43492078781 s
Batch:  3800 , l1 Loss:  0.10133186429738998 ,

Batch:  4790 , l1 Loss:  0.09880012348294258 , Time: 33732.035818099976 s
Batch:  4800 , l1 Loss:  0.1016931876540184 , Time: 33745.055603027344 s
Batch:  4810 , l1 Loss:  0.10036587417125702 , Time: 33758.09735226631 s
Batch:  4820 , l1 Loss:  0.09778079539537429 , Time: 33771.05070948601 s
Batch:  4830 , l1 Loss:  0.10132867097854614 , Time: 33784.05187869072 s
Batch:  4840 , l1 Loss:  0.10134657323360444 , Time: 33797.10373330116 s
Batch:  4850 , l1 Loss:  0.09909924194216728 , Time: 33810.1434738636 s
Batch:  4860 , l1 Loss:  0.10084285438060761 , Time: 33823.17971134186 s
Batch:  4870 , l1 Loss:  0.10101465433835984 , Time: 33836.235557317734 s
Batch:  4880 , l1 Loss:  0.10023643150925636 , Time: 33849.25014424324 s
Batch:  4890 , l1 Loss:  0.1002196952700615 , Time: 33862.20684862137 s
Batch:  4900 , l1 Loss:  0.10008579641580581 , Time: 33875.242653131485 s
Batch:  4910 , l1 Loss:  0.10067836493253708 , Time: 33888.221528053284 s
Batch:  4920 , l1 Loss:  0.09923197478055953 , Ti

Batch:  5920 , l1 Loss:  0.10152223780751228 , Time: 35199.4715590477 s
Batch:  5930 , l1 Loss:  0.1007527120411396 , Time: 35212.367154598236 s
Batch:  5940 , l1 Loss:  0.10187013521790504 , Time: 35225.28562760353 s
Batch:  5950 , l1 Loss:  0.09913881719112397 , Time: 35238.324226379395 s
Batch:  5960 , l1 Loss:  0.09940312504768371 , Time: 35251.38167977333 s
Batch:  5970 , l1 Loss:  0.09797674715518952 , Time: 35264.38931107521 s
Batch:  5980 , l1 Loss:  0.10130699798464775 , Time: 35277.43057227135 s
Batch:  5990 , l1 Loss:  0.09882613718509674 , Time: 35290.45270204544 s
Batch:  6000 , l1 Loss:  0.10026866495609284 , Time: 35303.491309165955 s
Batch:  6010 , l1 Loss:  0.10060828775167466 , Time: 35316.60376572609 s
Batch:  6020 , l1 Loss:  0.09808481708168984 , Time: 35329.639719963074 s
Batch:  6030 , l1 Loss:  0.10112260282039642 , Time: 35342.67514538765 s
Batch:  6040 , l1 Loss:  0.09871230721473694 , Time: 35355.67322182655 s
Batch:  6050 , l1 Loss:  0.09989573955535888 , Ti

KeyboardInterrupt: 