In [1]:
from __future__ import division
import os, time
import numpy as np
import glob
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import sys
from optparse import OptionParser
import numpy as np
from torch import optim
from PIL import Image
from torch.autograd import Function, Variable
import matplotlib.pyplot as plt
from torchvision import transforms
from torch.utils.data import Dataset
# import cv2
import pickle
from tqdm import tqdm
import rawpy
%matplotlib inline

input_dir = './Dataset/Sony/short/'
gt_dir = './Dataset/Sony/long/'

result_dir = './results_SSIM/'
model_save_path = './results_SSIM/net_weights/'
if os.path.isdir(model_save_path) is not True:
    os.makedirs(model_save_path, exist_ok=True)

In [2]:
if torch.cuda.is_available():
    deviceTag = torch.device('cuda')
else:
    deviceTag = torch.device('cpu')
print(deviceTag)

cuda


In [3]:
from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM

# LOAD DATASET

In [4]:
# get train IDs
train_fns = glob.glob(gt_dir + '0*.ARW')
train_ids = [int(os.path.basename(train_fn)[0:5]) for train_fn in train_fns]

ps = 512  # patch size for training
save_freq = 500

DEBUG = 0
if DEBUG == 1:
    save_freq = 2
    train_ids = train_ids[0:5]


# UNET MODULES

In [5]:
class conv_lrelu(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(conv_lrelu, self).__init__()
        self.conv = nn.Sequential(nn.Conv2d(in_ch,out_ch,3, padding = 1),nn.LeakyReLU())

    def forward(self, x):
        x = self.conv(x)
        return x
    
class down(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(down, self).__init__()
        self.conv1 = conv_lrelu(in_ch,out_ch)
        self.conv2 = conv_lrelu(out_ch,out_ch)
        self.down =  nn.MaxPool2d((2,2))
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.down(x)
        return x
    

class up(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(up, self).__init__()       
        self.up =  nn.UpsamplingBilinear2d(scale_factor = 2)
        self.conv1 = conv_lrelu(in_ch,out_ch) 
        self.conv2 = conv_lrelu(out_ch,out_ch) 

    def forward(self, x1, x2):
        x1 = self.up(x1)
        x = torch.cat([x2, x1], dim=1)
        x = self.conv1(x)
        x = self.conv2(x)
        return x

class UNet(nn.Module):
    def __init__(self, in_ch = 4, CH_PER_SCALE = [32,64,128,256,512], out_ch = 12):
        super(UNet, self).__init__()
        self.inc = conv_lrelu(in_ch, CH_PER_SCALE[0])
        self.inc2 = conv_lrelu(CH_PER_SCALE[0], CH_PER_SCALE[0])
        self.down1 = down(CH_PER_SCALE[0], CH_PER_SCALE[1]) 
        self.down2 = down(CH_PER_SCALE[1],CH_PER_SCALE[2])
        self.down3 = down(CH_PER_SCALE[2],CH_PER_SCALE[3])                
        self.down4 = down(CH_PER_SCALE[3],CH_PER_SCALE[4])                
        self.up1 = up(CH_PER_SCALE[4]+CH_PER_SCALE[3],CH_PER_SCALE[3])
        self.up2 = up(CH_PER_SCALE[3]+CH_PER_SCALE[2],CH_PER_SCALE[2])
        self.up3 = up(CH_PER_SCALE[2]+CH_PER_SCALE[1],CH_PER_SCALE[1])
        self.up4 = up(CH_PER_SCALE[1]+CH_PER_SCALE[0],CH_PER_SCALE[0])
        self.outc = nn.Conv2d(CH_PER_SCALE[0], out_ch, 1, padding = 0)

    def forward(self, x):
        x0 = self.inc(x)
        x0 = self.inc2(x0)
        x1 = self.down1(x0)
        x2 = self.down2(x1)
        x3 = self.down3(x2)
        x4 = self.down4(x3)
        x3_up = self.up1(x4,x3)
        x2_up = self.up2(x3_up,x2)
        x1_up = self.up3(x2_up,x1)
        out = self.up4(x1_up,x0)
        out = self.outc(out)
#         out = F.pixel_shuffle(out,2) ## Paper final step rearranges 12 channes to 3 RGB channels
#         out = F.hardtanh(out, min_val=0, max_val=1) #Clamp the top and bottom to 0,1 since pixels can only be in this value
        return out
    
class PRIDNet(nn.Module):
    def __init__(self, in_ch = 4, out_ch = 12):
        super(PRIDNet, self).__init__()
        self.feature_extraction = nn.Sequential(conv_lrelu(in_ch, 32), *[conv_lrelu(32, 32) for i in range(3)])
        self.unet0 = UNet(in_ch = 32, out_ch = 12)
        self.unet1 = UNet(in_ch = 32, out_ch = 12)
        self.unet2 = UNet(in_ch = 32, out_ch = 12)
        self.unet3 = UNet(in_ch = 32, out_ch = 12)
        self.unet4 = UNet(in_ch = 32, out_ch = 12)
        self.avgpool1 = nn.AvgPool2d((2,2))
        self.avgpool2 = nn.AvgPool2d((4,4))
        self.avgpool3 = nn.AvgPool2d((8,8))
        self.avgpool4 = nn.AvgPool2d((16,16))
        self.up4 =  nn.UpsamplingBilinear2d(scale_factor = 16)
        self.up3 =  nn.UpsamplingBilinear2d(scale_factor = 8)
        self.up2 =  nn.UpsamplingBilinear2d(scale_factor = 4)
        self.up1 =  nn.UpsamplingBilinear2d(scale_factor = 2)
        self.out =  nn.Conv2d(32+12*5, out_ch, 1, padding = 0)

    def forward(self, x):
        x_feat = self.feature_extraction(x)
        x0 = self.unet0(x_feat)
        x1 = self.up1(self.unet1(self.avgpool1(x_feat)))
        x2 = self.up2(self.unet2(self.avgpool2(x_feat)))
        x3 = self.up3(self.unet3(self.avgpool3(x_feat)))
        x4 = self.up4(self.unet4(self.avgpool4(x_feat)))
        x_unet_all = torch.cat([x_feat,x0,x1,x2,x3,x4], axis = 1)
        out = self.out(x_unet_all)
        
        out = F.pixel_shuffle(out,2) ## Paper final step rearranges 12 channes to 3 RGB channels
        out = F.hardtanh(out, min_val=0, max_val=1) #Clamp the top and bottom to 0,1 since pixels can only be in this value
        return out
    
    def load_my_state_dict(self, state_dict):
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                 continue
            #if isinstance(param, self.Parameter):
            else:
                # backwards compatibility for serialized parameters
                param = param.data
            own_state[name].copy_(param)

# Helper Functions for packing raw and saving images

In [6]:
def pack_raw(raw):
    # pack Bayer image to 4 channels
    im = raw.raw_image_visible.astype(np.float32)
    im = np.maximum(im - 512, 0) / (16383 - 512)  # subtract the black level

    im = np.expand_dims(im, axis=2)
    img_shape = im.shape
    H = img_shape[0]
    W = img_shape[1]

    out = np.concatenate((im[0:H:2, 0:W:2, :],
                          im[0:H:2, 1:W:2, :],
                          im[1:H:2, 1:W:2, :],
                          im[1:H:2, 0:W:2, :]), axis=2)
    return out

In [7]:
# Raw data takes long time to load. Keep them in memory after loaded.
gt_images = [None] * 6000
input_images = {}
input_images['300'] = [None] * len(train_ids)
input_images['250'] = [None] * len(train_ids)
input_images['100'] = [None] * len(train_ids)

g_loss = np.zeros((5000, 1))

allfolders = glob.glob(result_dir + '*0')

for folder in allfolders:
    lastepoch = np.maximum(epochs, int(folder[-4:]))

# Training 

In [8]:
def process_img(input_raw_img, model, ratio):
    ## Process image(s) using the given model
    # input_raw_img: numpy array, dimension: (Batch,Height,Width,Channel)
    # ratio: numpy array, dimension: (Batch,)
    model.eval();
    model.to(deviceTag)
    ratio = ratio.reshape(ratio.shape[0],1,1,1)
    input_raw_img = np.transpose(input_raw_img, [0,3,1,2]).astype('float32')*ratio
    input_tensor = torch.from_numpy(input_raw_img.copy()).float().to(deviceTag)
    with torch.no_grad():
        output_tensor = model(input_tensor)
    output_img = output_tensor.cpu().numpy()
    output_img = np.transpose(output_img, [0,2,3,1])
    
    return output_img
    
def validate(model, input_list, gt_list, block_size = None, batch_size = 8):
    assert len(input_list) == len(gt_list)
    
    model.eval();
    PSNR_list = []
    
    for i in range(len(input_list)//batch_size):
        if i%10 == 0:
            print(i)
        input_raw_img_batch = []
        gt_img_batch = []
        ratio_batch = []
        for b in range(batch_size):
            if i*batch_size+b < len(input_list):
                in_path = input_list[i*batch_size+b]
                gt_path = gt_list[i*batch_size+b]
            else:
                break
            in_fn = os.path.basename(in_path)
            gt_fn = os.path.basename(gt_path)
            in_exposure = float(in_fn[9:-5])
            gt_exposure = float(gt_fn[9:-5])
            ratio = min(gt_exposure / in_exposure, 300)
        
            raw = rawpy.imread(in_path)
            input_raw_img = pack_raw(raw)
            
            gt_raw = rawpy.imread(gt_path)
            gt_img = gt_raw.postprocess(use_camera_wb=True, half_size=False, no_auto_bright=True, output_bps=16)
            gt_img = np.float32(gt_img / 65535.0)
            
            if block_size is not None:
                i_cut, j_cut = np.random.randint(0,input_raw_img.shape[0]-block_size), np.random.randint(0,input_raw_img.shape[1]-block_size)
                gt_img = gt_img[i_cut*2:i_cut*2+block_size*2, j_cut*2:j_cut*2+block_size*2, :]
                input_raw_img = input_raw_img[i_cut:i_cut+block_size, j_cut:j_cut+block_size, :]
            
            ratio_batch.append(ratio)
            input_raw_img_batch.append(input_raw_img)
            gt_img_batch.append(gt_img)
        
        input_raw_img_batch = np.array(input_raw_img_batch)
        ratio_batch = np.array(ratio_batch)
        gt_img_batch = np.array(gt_img_batch)
        
        output_img_batch = process_img(input_raw_img_batch, model, ratio_batch)
        plt.figure()
        plt.imshow(gt_img_batch[0,:,:,:])
        plt.title("Ground Truth")
        plt.figure()
        plt.imshow(output_img_batch[0,:,:,:])
        plt.title("Predicted patch")
        MSE = np.mean((output_img_batch.reshape(output_img_batch.shape[0],-1) - gt_img_batch.reshape(gt_img_batch.shape[0],-1))**2, axis = 1)
        PSNR_batch = 10*np.log10(1/MSE)
        PSNR_list.append(list(PSNR_batch))
    
    Val_PSNR = np.mean(PSNR_list)
        
    return Val_PSNR

In [9]:
model = PRIDNet()
model = model.cuda()
model = model.train()

In [None]:
learning_rate = 1e-4 
batch_num = 4
optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2000, gamma=0.1) #Step Scheduler. 
criterion = SSIM(data_range=1, size_average=True, channel=3)
Start_epoch = 0
epochs = 4000
TrainingLossData = np.zeros(epochs)

for epoch in range(Start_epoch, Start_epoch+epochs):
    print('Starting epoch {}/{}.'.format(epoch + 1, epochs))
    epoch_loss = 0 ## Set Epoch Loss
    count = 0;
    batches_processed = 0
    ##This version has a batch size of 1. In the future conside increasing batchsize
    for ind in np.random.permutation(len(train_ids)):    
        # get the path from image id
        train_id = train_ids[ind]
        in_files = glob.glob(input_dir + '%05d_00*.ARW' % train_id)
        in_path = in_files[np.random.random_integers(0, len(in_files) - 1)]
        in_fn = os.path.basename(in_path)

        gt_files = glob.glob(gt_dir + '%05d_00*.ARW' % train_id)
        gt_path = gt_files[0]
        gt_fn = os.path.basename(gt_path)
        in_exposure = float(in_fn[9:-5])
        gt_exposure = float(gt_fn[9:-5])
        ratio = min(gt_exposure / in_exposure, 300)

        st = time.time()

        if input_images[str(ratio)[0:3]][ind] is None:
            raw = rawpy.imread(in_path)
            input_images[str(ratio)[0:3]][ind] = np.expand_dims(pack_raw(raw), axis=0) * ratio

            gt_raw = rawpy.imread(gt_path)
            im = gt_raw.postprocess(use_camera_wb=True, half_size=False, no_auto_bright=True, output_bps=16)
            gt_images[ind] = np.expand_dims(np.float32(im / 65535.0), axis=0)

        # crop
        H = input_images[str(ratio)[0:3]][ind].shape[1]
        W = input_images[str(ratio)[0:3]][ind].shape[2]

        xx = np.random.randint(0, W - ps)
        yy = np.random.randint(0, H - ps)
        input_patch = input_images[str(ratio)[0:3]][ind][:, yy:yy + ps, xx:xx + ps, :]
        gt_patch = gt_images[ind][:, yy * 2:yy * 2 + ps * 2, xx * 2:xx * 2 + ps * 2, :]

        if np.random.randint(2, size=1)[0] == 1:  # random flip
            input_patch = np.flip(input_patch, axis=1)
            gt_patch = np.flip(gt_patch, axis=1)
        if np.random.randint(2, size=1)[0] == 1:
            input_patch = np.flip(input_patch, axis=2)
            gt_patch = np.flip(gt_patch, axis=2)
        if np.random.randint(2, size=1)[0] == 1:  # random transpose
            input_patch = np.transpose(input_patch, (0, 2, 1, 3))
            gt_patch = np.transpose(gt_patch, (0, 2, 1, 3))
        #(1, 512, 512, 4)
        #(1, 1024, 1024, 3)
        input_patch = np.transpose(input_patch, (0,3,1,2))
        input_patch = torch.from_numpy(input_patch.copy()).cuda()
        gt_patch = np.transpose(gt_patch, (0,3,1,2))
        gt_patch = torch.from_numpy(gt_patch.copy()).cuda()
        ##Batch concatenation
        if count%(batch_num)==0:
            input_patch_all = input_patch
            gt_patch_all = gt_patch
        else:
            input_patch_all = torch.cat([input_patch_all, input_patch], dim=0)
            gt_patch_all = torch.cat([gt_patch_all, gt_patch], dim=0)
        ##Every N batches we ship it back 
        if count%(batch_num)==batch_num-1:
            #print(input_patch_all.shape)
            img_pred = model.forward(input_patch_all)
            loss = 1 - criterion(img_pred, gt_patch_all)
            epoch_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            batches_processed += 1
        count = count +1
    scheduler.step()
    print('Epoch finished ! Loss: {}'.format(epoch_loss / batches_processed))
    if epoch == 0:
        trainF= open(result_dir+"TrainingEpoch.txt","w+")
        trainF.write('Epoch,Train_loss\n')
        trainF.write('{},{}\n'.format(epoch, epoch_loss / batches_processed))
        trainF.close()
    else:
        trainF= open(result_dir+"TrainingEpoch.txt","a")
        trainF.write('{},{}\n'.format(epoch, epoch_loss / batches_processed))
        trainF.close()
    TrainingLossData[epoch] = epoch_loss / batches_processed ## Save for plotting
    ################################################ [TODO] ###################################################
    # Perform validation with eval_net() on the validation data
    # Save the model after every 10 epochs. This save our Memory on HPC.
    ##Save Top results after 95%
    if epoch > epochs*0.95:
        torch.save(model.state_dict(),model_save_path + 'sony{}.pth'.format(epoch + 1))
    if epoch%200 == 0:
        torch.save(model.state_dict(),model_save_path + 'sony{}.pth'.format(epoch + 1))


Starting epoch 1/4000.


  in_path = in_files[np.random.random_integers(0, len(in_files) - 1)]
  in_path = in_files[np.random.random_integers(0, len(in_files) - 1)]
  in_path = in_files[np.random.random_integers(0, len(in_files) - 1)]


Epoch finished ! Loss: 0.7763022989034652
Starting epoch 2/4000.
Epoch finished ! Loss: 0.6982577428221702
Starting epoch 3/4000.
Epoch finished ! Loss: 0.6785474762320518
Starting epoch 4/4000.
Epoch finished ! Loss: 0.5824663564562798
Starting epoch 5/4000.
Epoch finished ! Loss: 0.5399887874722481
Starting epoch 6/4000.
Epoch finished ! Loss: 0.5064837828278541
Starting epoch 7/4000.
Epoch finished ! Loss: 0.49664628356695173
Starting epoch 8/4000.
Epoch finished ! Loss: 0.4947374314069748
Starting epoch 9/4000.
Epoch finished ! Loss: 0.4861320167779922
Starting epoch 10/4000.
Epoch finished ! Loss: 0.479812479019165
Starting epoch 11/4000.
Epoch finished ! Loss: 0.46261010468006136
Starting epoch 12/4000.
Epoch finished ! Loss: 0.3993742436170578
Starting epoch 13/4000.
Epoch finished ! Loss: 0.3443560630083084
Starting epoch 14/4000.
Epoch finished ! Loss: 0.33322286456823347
Starting epoch 15/4000.
Epoch finished ! Loss: 0.3272545367479324
Starting epoch 16/4000.
Epoch finished !

Epoch finished ! Loss: 0.2541937053203583
Starting epoch 126/4000.
Epoch finished ! Loss: 0.26231185495853426
Starting epoch 127/4000.
Epoch finished ! Loss: 0.2538023337721825
Starting epoch 128/4000.
Epoch finished ! Loss: 0.23885460644960405
Starting epoch 129/4000.
Epoch finished ! Loss: 0.24495213776826857
Starting epoch 130/4000.
Epoch finished ! Loss: 0.2516783058643341
Starting epoch 131/4000.
Epoch finished ! Loss: 0.2500741809606552
Starting epoch 132/4000.
Epoch finished ! Loss: 0.2493017092347145
Starting epoch 133/4000.
Epoch finished ! Loss: 0.2501416191458702
Starting epoch 134/4000.
Epoch finished ! Loss: 0.2551036819815636
Starting epoch 135/4000.
Epoch finished ! Loss: 0.25447954684495927
Starting epoch 136/4000.
Epoch finished ! Loss: 0.25209175050258636
Starting epoch 137/4000.
Epoch finished ! Loss: 0.25378765612840654
Starting epoch 138/4000.
Epoch finished ! Loss: 0.25054737031459806
Starting epoch 139/4000.
Epoch finished ! Loss: 0.25461516678333285
Starting epo

Epoch finished ! Loss: 0.24518512338399887
Starting epoch 248/4000.
Epoch finished ! Loss: 0.23551846742630006
Starting epoch 249/4000.
Epoch finished ! Loss: 0.239611978828907
Starting epoch 250/4000.
Epoch finished ! Loss: 0.2382198691368103
Starting epoch 251/4000.
Epoch finished ! Loss: 0.2328857496380806
Starting epoch 252/4000.
Epoch finished ! Loss: 0.22934969067573546
Starting epoch 253/4000.
Epoch finished ! Loss: 0.24930980652570725
Starting epoch 254/4000.
Epoch finished ! Loss: 0.24536050707101822
Starting epoch 255/4000.
Epoch finished ! Loss: 0.23497429639101028
Starting epoch 256/4000.
Epoch finished ! Loss: 0.23290574550628662
Starting epoch 257/4000.
Epoch finished ! Loss: 0.23766859620809555
Starting epoch 258/4000.
Epoch finished ! Loss: 0.23498857021331787
Starting epoch 259/4000.
Epoch finished ! Loss: 0.23813380002975465
Starting epoch 260/4000.
Epoch finished ! Loss: 0.23252004534006118
Starting epoch 261/4000.
Epoch finished ! Loss: 0.2375602975487709
Starting e

Epoch finished ! Loss: 0.2244450718164444
Starting epoch 370/4000.
Epoch finished ! Loss: 0.22109983414411544
Starting epoch 371/4000.
Epoch finished ! Loss: 0.2246589720249176
Starting epoch 372/4000.
Epoch finished ! Loss: 0.21546713560819625
Starting epoch 373/4000.
Epoch finished ! Loss: 0.2275662437081337
Starting epoch 374/4000.
Epoch finished ! Loss: 0.21905127316713333
Starting epoch 375/4000.
Epoch finished ! Loss: 0.21993454396724701
Starting epoch 376/4000.
Epoch finished ! Loss: 0.2234477773308754
Starting epoch 377/4000.
Epoch finished ! Loss: 0.22175341546535493
Starting epoch 378/4000.
Epoch finished ! Loss: 0.21450394839048387
Starting epoch 379/4000.
Epoch finished ! Loss: 0.2183404967188835
Starting epoch 380/4000.
Epoch finished ! Loss: 0.22129897624254227
Starting epoch 381/4000.
Epoch finished ! Loss: 0.22187036871910096
Starting epoch 382/4000.
Epoch finished ! Loss: 0.22565321624279022
Starting epoch 383/4000.
Epoch finished ! Loss: 0.21684794574975969
Starting e

Epoch finished ! Loss: 0.2159949868917465
Starting epoch 492/4000.
Epoch finished ! Loss: 0.211438949406147
Starting epoch 493/4000.
Epoch finished ! Loss: 0.2111328899860382
Starting epoch 494/4000.
Epoch finished ! Loss: 0.22059471309185028
Starting epoch 495/4000.
Epoch finished ! Loss: 0.21712317764759065
Starting epoch 496/4000.
Epoch finished ! Loss: 0.2092873275279999
Starting epoch 497/4000.
Epoch finished ! Loss: 0.2097756028175354
Starting epoch 498/4000.
Epoch finished ! Loss: 0.21733547002077103
Starting epoch 499/4000.
Epoch finished ! Loss: 0.21664490252733232
Starting epoch 500/4000.
Epoch finished ! Loss: 0.21755722761154175
Starting epoch 501/4000.
Epoch finished ! Loss: 0.21979887038469315
Starting epoch 502/4000.
Epoch finished ! Loss: 0.22313294559717178
Starting epoch 503/4000.
Epoch finished ! Loss: 0.21003451198339462
Starting epoch 504/4000.
Epoch finished ! Loss: 0.2019500583410263
Starting epoch 505/4000.
Epoch finished ! Loss: 0.2140850692987442
Starting epoc

Epoch finished ! Loss: 0.2138327181339264
Starting epoch 613/4000.
Epoch finished ! Loss: 0.21398090571165085
Starting epoch 614/4000.
Epoch finished ! Loss: 0.20770560801029206
Starting epoch 615/4000.
Epoch finished ! Loss: 0.21497985273599624
Starting epoch 616/4000.
Epoch finished ! Loss: 0.21260295957326888
Starting epoch 617/4000.
Epoch finished ! Loss: 0.2136277213692665
Starting epoch 618/4000.
Epoch finished ! Loss: 0.2111247092485428
Starting epoch 619/4000.
Epoch finished ! Loss: 0.2101668521761894
Starting epoch 620/4000.
Epoch finished ! Loss: 0.20814866721630096
Starting epoch 621/4000.
Epoch finished ! Loss: 0.2097553864121437
Starting epoch 622/4000.
Epoch finished ! Loss: 0.19972966909408568
Starting epoch 623/4000.
Epoch finished ! Loss: 0.21846118122339248
Starting epoch 624/4000.
Epoch finished ! Loss: 0.21636753380298615
Starting epoch 625/4000.
Epoch finished ! Loss: 0.21251117140054704
Starting epoch 626/4000.
Epoch finished ! Loss: 0.21290062069892884
Starting e

Epoch finished ! Loss: 0.2118067115545273
Starting epoch 734/4000.
Epoch finished ! Loss: 0.1982174411416054
Starting epoch 735/4000.
Epoch finished ! Loss: 0.2074539080262184
Starting epoch 736/4000.
Epoch finished ! Loss: 0.21204494386911393
Starting epoch 737/4000.
Epoch finished ! Loss: 0.21084729582071304
Starting epoch 738/4000.
Epoch finished ! Loss: 0.20128847807645797
Starting epoch 739/4000.
Epoch finished ! Loss: 0.20639189332723618
Starting epoch 740/4000.
Epoch finished ! Loss: 0.20459033995866777
Starting epoch 741/4000.
Epoch finished ! Loss: 0.2031285509467125
Starting epoch 742/4000.
Epoch finished ! Loss: 0.2088189274072647
Starting epoch 743/4000.
Epoch finished ! Loss: 0.2109663501381874
Starting epoch 744/4000.
Epoch finished ! Loss: 0.20545107126235962
Starting epoch 745/4000.
Epoch finished ! Loss: 0.22175138741731643
Starting epoch 746/4000.
Epoch finished ! Loss: 0.2032170608639717
Starting epoch 747/4000.
Epoch finished ! Loss: 0.2072750598192215
Starting epoc

Epoch finished ! Loss: 0.20614922940731048
Starting epoch 856/4000.
Epoch finished ! Loss: 0.20587860494852067
Starting epoch 857/4000.
Epoch finished ! Loss: 0.1954764410853386
Starting epoch 858/4000.
Epoch finished ! Loss: 0.2047959804534912
Starting epoch 859/4000.
Epoch finished ! Loss: 0.20054930299520493
Starting epoch 860/4000.
Epoch finished ! Loss: 0.20691322386264802
Starting epoch 861/4000.
Epoch finished ! Loss: 0.19505432844161988
Starting epoch 862/4000.
Epoch finished ! Loss: 0.2018860474228859
Starting epoch 863/4000.
Epoch finished ! Loss: 0.20128166824579238
Starting epoch 864/4000.
Epoch finished ! Loss: 0.20170571953058242
Starting epoch 865/4000.
Epoch finished ! Loss: 0.20480316877365112
Starting epoch 866/4000.
Epoch finished ! Loss: 0.20054162293672562
Starting epoch 867/4000.
Epoch finished ! Loss: 0.20798581838607788
Starting epoch 868/4000.
Epoch finished ! Loss: 0.1956069827079773
Starting epoch 869/4000.
Epoch finished ! Loss: 0.2185867965221405
Starting e

Epoch finished ! Loss: 0.20193127691745758
Starting epoch 977/4000.
Epoch finished ! Loss: 0.20556644946336747
Starting epoch 978/4000.
Epoch finished ! Loss: 0.20072634369134904
Starting epoch 979/4000.
Epoch finished ! Loss: 0.20778914391994477
Starting epoch 980/4000.
Epoch finished ! Loss: 0.19581746459007263
Starting epoch 981/4000.
Epoch finished ! Loss: 0.1970319926738739
Starting epoch 982/4000.
Epoch finished ! Loss: 0.20257450193166732
Starting epoch 983/4000.
Epoch finished ! Loss: 0.2042734831571579
Starting epoch 984/4000.
Epoch finished ! Loss: 0.189560566842556
Starting epoch 985/4000.
Epoch finished ! Loss: 0.2093104526400566
Starting epoch 986/4000.
Epoch finished ! Loss: 0.2067848563194275
Starting epoch 987/4000.
Epoch finished ! Loss: 0.194984370470047
Starting epoch 988/4000.
Epoch finished ! Loss: 0.20312001854181289
Starting epoch 989/4000.
Epoch finished ! Loss: 0.20490215271711348
Starting epoch 990/4000.
Epoch finished ! Loss: 0.20007357597351075
Starting epoc

Epoch finished ! Loss: 0.194363072514534
Starting epoch 1097/4000.
Epoch finished ! Loss: 0.19848954528570176
Starting epoch 1098/4000.
Epoch finished ! Loss: 0.20656623542308808
Starting epoch 1099/4000.
Epoch finished ! Loss: 0.2077740877866745
Starting epoch 1100/4000.
Epoch finished ! Loss: 0.19665133506059645
Starting epoch 1101/4000.
Epoch finished ! Loss: 0.19765679985284806
Starting epoch 1102/4000.
Epoch finished ! Loss: 0.19701558947563172
Starting epoch 1103/4000.
Epoch finished ! Loss: 0.19902787804603578
Starting epoch 1104/4000.
Epoch finished ! Loss: 0.1919052854180336
Starting epoch 1105/4000.
Epoch finished ! Loss: 0.19847490638494492
Starting epoch 1106/4000.
Epoch finished ! Loss: 0.19976738691329957
Starting epoch 1107/4000.
Epoch finished ! Loss: 0.19103931486606598
Starting epoch 1108/4000.
Epoch finished ! Loss: 0.19879295974969863
Starting epoch 1109/4000.
Epoch finished ! Loss: 0.19345795214176179
Starting epoch 1110/4000.
Epoch finished ! Loss: 0.1996288612484

Epoch finished ! Loss: 0.18993605077266693
Starting epoch 1217/4000.
Epoch finished ! Loss: 0.1918968752026558
Starting epoch 1218/4000.
Epoch finished ! Loss: 0.20312466323375702
Starting epoch 1219/4000.
Epoch finished ! Loss: 0.19084702283143998
Starting epoch 1220/4000.
Epoch finished ! Loss: 0.20582564622163774
Starting epoch 1221/4000.
Epoch finished ! Loss: 0.2017810747027397
Starting epoch 1222/4000.
Epoch finished ! Loss: 0.2057544246315956
Starting epoch 1223/4000.
Epoch finished ! Loss: 0.19568010568618774
Starting epoch 1224/4000.
Epoch finished ! Loss: 0.18849794268608094
Starting epoch 1225/4000.
Epoch finished ! Loss: 0.19469232410192489
Starting epoch 1226/4000.
Epoch finished ! Loss: 0.1989203453063965
Starting epoch 1227/4000.
Epoch finished ! Loss: 0.19579901844263076
Starting epoch 1228/4000.
Epoch finished ! Loss: 0.19529208838939666
Starting epoch 1229/4000.
Epoch finished ! Loss: 0.1941780313849449
Starting epoch 1230/4000.
Epoch finished ! Loss: 0.19074708968400

Epoch finished ! Loss: 0.19362212717533112
Starting epoch 1337/4000.
Epoch finished ! Loss: 0.1886645793914795
Starting epoch 1338/4000.
Epoch finished ! Loss: 0.19246279299259186
Starting epoch 1339/4000.
Epoch finished ! Loss: 0.18908171504735946
Starting epoch 1340/4000.
Epoch finished ! Loss: 0.19775157272815705
Starting epoch 1341/4000.
Epoch finished ! Loss: 0.19370991587638856
Starting epoch 1342/4000.
Epoch finished ! Loss: 0.1949683368206024
Starting epoch 1343/4000.
Epoch finished ! Loss: 0.1864302545785904
Starting epoch 1344/4000.
Epoch finished ! Loss: 0.18898012191057206
Starting epoch 1345/4000.
Epoch finished ! Loss: 0.19622931480407715
Starting epoch 1346/4000.
Epoch finished ! Loss: 0.19847262948751448
Starting epoch 1347/4000.
Epoch finished ! Loss: 0.1991350084543228
Starting epoch 1348/4000.
Epoch finished ! Loss: 0.19432485550642015
Starting epoch 1349/4000.
Epoch finished ! Loss: 0.20154054164886476
Starting epoch 1350/4000.
Epoch finished ! Loss: 0.1883544266223

Epoch finished ! Loss: 0.1853237956762314
Starting epoch 1457/4000.
Epoch finished ! Loss: 0.18685756921768187
Starting epoch 1458/4000.
Epoch finished ! Loss: 0.18688192516565322
Starting epoch 1459/4000.
Epoch finished ! Loss: 0.19798744916915895
Starting epoch 1460/4000.
Epoch finished ! Loss: 0.19877853691577912
Starting epoch 1461/4000.
Epoch finished ! Loss: 0.19320734441280366
Starting epoch 1462/4000.
Epoch finished ! Loss: 0.18452731966972352
Starting epoch 1463/4000.
Epoch finished ! Loss: 0.19282863587141036
Starting epoch 1464/4000.
Epoch finished ! Loss: 0.1985216572880745
Starting epoch 1465/4000.
Epoch finished ! Loss: 0.19530950486660004
Starting epoch 1466/4000.
Epoch finished ! Loss: 0.1965022787451744
Starting epoch 1467/4000.
Epoch finished ! Loss: 0.1848030284047127
Starting epoch 1468/4000.
Epoch finished ! Loss: 0.1856882467865944
Starting epoch 1469/4000.
Epoch finished ! Loss: 0.18795979470014573
Starting epoch 1470/4000.
Epoch finished ! Loss: 0.19600500166416

Epoch finished ! Loss: 0.184169040620327
Starting epoch 1577/4000.
Epoch finished ! Loss: 0.185898058116436
Starting epoch 1578/4000.
Epoch finished ! Loss: 0.18185312151908875
Starting epoch 1579/4000.
Epoch finished ! Loss: 0.192366886138916
Starting epoch 1580/4000.
Epoch finished ! Loss: 0.1842030853033066
Starting epoch 1581/4000.
Epoch finished ! Loss: 0.19150419235229493
Starting epoch 1582/4000.
Epoch finished ! Loss: 0.18459993004798889
Starting epoch 1583/4000.
Epoch finished ! Loss: 0.18465090841054915
Starting epoch 1584/4000.
Epoch finished ! Loss: 0.1876888543367386
Starting epoch 1585/4000.
Epoch finished ! Loss: 0.1891316756606102
Starting epoch 1586/4000.
Epoch finished ! Loss: 0.19322687685489653
Starting epoch 1587/4000.
Epoch finished ! Loss: 0.18564959317445756
Starting epoch 1588/4000.
Epoch finished ! Loss: 0.1891741394996643
Starting epoch 1589/4000.
Epoch finished ! Loss: 0.1966513454914093
Starting epoch 1590/4000.
Epoch finished ! Loss: 0.18847012668848037
St

Epoch finished ! Loss: 0.18801751881837844
Starting epoch 1697/4000.
Epoch finished ! Loss: 0.18880511671304703
Starting epoch 1698/4000.
Epoch finished ! Loss: 0.18682534396648406
Starting epoch 1699/4000.
Epoch finished ! Loss: 0.18632492572069168
Starting epoch 1700/4000.
Epoch finished ! Loss: 0.18629974722862244
Starting epoch 1701/4000.
Epoch finished ! Loss: 0.18878764510154725
Starting epoch 1702/4000.
Epoch finished ! Loss: 0.18498342633247375
Starting epoch 1703/4000.
Epoch finished ! Loss: 0.18773774951696395
Starting epoch 1704/4000.
Epoch finished ! Loss: 0.1865402042865753
Starting epoch 1705/4000.
Epoch finished ! Loss: 0.19065366834402084
Starting epoch 1706/4000.
Epoch finished ! Loss: 0.18500795513391494
Starting epoch 1707/4000.
Epoch finished ! Loss: 0.18073839098215103
Starting epoch 1708/4000.
Epoch finished ! Loss: 0.1884375259280205
Starting epoch 1709/4000.
Epoch finished ! Loss: 0.1864461988210678
Starting epoch 1710/4000.
Epoch finished ! Loss: 0.193445059657

Epoch finished ! Loss: 0.18845455199480057
Starting epoch 1817/4000.
Epoch finished ! Loss: 0.18204805105924607
Starting epoch 1818/4000.
Epoch finished ! Loss: 0.17998969554901123
Starting epoch 1819/4000.
Epoch finished ! Loss: 0.1891898110508919
Starting epoch 1820/4000.
Epoch finished ! Loss: 0.18098429292440416
Starting epoch 1821/4000.
Epoch finished ! Loss: 0.18600234985351563
Starting epoch 1822/4000.
Epoch finished ! Loss: 0.18401769548654556
Starting epoch 1823/4000.
Epoch finished ! Loss: 0.18751427084207534
Starting epoch 1824/4000.
Epoch finished ! Loss: 0.1855562373995781
Starting epoch 1825/4000.
Epoch finished ! Loss: 0.18475028723478318
Starting epoch 1826/4000.
Epoch finished ! Loss: 0.18494517207145691
Starting epoch 1827/4000.
Epoch finished ! Loss: 0.18299180418252944
Starting epoch 1828/4000.
Epoch finished ! Loss: 0.18569518625736237
Starting epoch 1829/4000.
Epoch finished ! Loss: 0.1899841845035553
Starting epoch 1830/4000.
Epoch finished ! Loss: 0.185684433579

Epoch finished ! Loss: 0.18376102447509765
Starting epoch 1937/4000.
Epoch finished ! Loss: 0.18648373931646348
Starting epoch 1938/4000.
Epoch finished ! Loss: 0.18357032537460327
Starting epoch 1939/4000.
Epoch finished ! Loss: 0.1799742177128792
Starting epoch 1940/4000.
Epoch finished ! Loss: 0.18630622029304506
Starting epoch 1941/4000.
Epoch finished ! Loss: 0.18495163321495056
Starting epoch 1942/4000.
Epoch finished ! Loss: 0.1787436157464981
Starting epoch 1943/4000.
Epoch finished ! Loss: 0.19691064357757568
Starting epoch 1944/4000.
Epoch finished ! Loss: 0.19018793404102324
Starting epoch 1945/4000.
Epoch finished ! Loss: 0.20616901963949202
Starting epoch 1946/4000.
Epoch finished ! Loss: 0.18721343874931334
Starting epoch 1947/4000.
Epoch finished ! Loss: 0.18497568666934966
Starting epoch 1948/4000.
Epoch finished ! Loss: 0.18431828320026397
Starting epoch 1949/4000.
Epoch finished ! Loss: 0.18863056600093842
Starting epoch 1950/4000.
Epoch finished ! Loss: 0.18377104997

Epoch finished ! Loss: 0.18229379802942275
Starting epoch 2057/4000.
Epoch finished ! Loss: 0.18194645792245864
Starting epoch 2058/4000.
Epoch finished ! Loss: 0.16832907497882843
Starting epoch 2059/4000.
Epoch finished ! Loss: 0.1786426916718483
Starting epoch 2060/4000.
Epoch finished ! Loss: 0.17563137412071228
Starting epoch 2061/4000.
Epoch finished ! Loss: 0.17299506962299346
Starting epoch 2062/4000.
Epoch finished ! Loss: 0.17554180324077606
Starting epoch 2063/4000.
Epoch finished ! Loss: 0.17187435775995255
Starting epoch 2064/4000.
Epoch finished ! Loss: 0.17417063117027282
Starting epoch 2065/4000.
Epoch finished ! Loss: 0.1712791383266449
Starting epoch 2066/4000.
Epoch finished ! Loss: 0.17465739250183104
Starting epoch 2067/4000.
Epoch finished ! Loss: 0.18057911843061447
Starting epoch 2068/4000.
Epoch finished ! Loss: 0.1815079391002655
Starting epoch 2069/4000.
Epoch finished ! Loss: 0.17491934597492217
Starting epoch 2070/4000.
Epoch finished ! Loss: 0.175411455333

Epoch finished ! Loss: 0.18044889718294144
Starting epoch 2177/4000.
Epoch finished ! Loss: 0.17680803686380386
Starting epoch 2178/4000.
Epoch finished ! Loss: 0.17623939067125322
Starting epoch 2179/4000.
Epoch finished ! Loss: 0.17188296616077423
Starting epoch 2180/4000.
Epoch finished ! Loss: 0.17252619862556456
Starting epoch 2181/4000.
Epoch finished ! Loss: 0.17065289318561555
Starting epoch 2182/4000.
Epoch finished ! Loss: 0.17741223126649858
Starting epoch 2183/4000.
Epoch finished ! Loss: 0.16825465410947799
Starting epoch 2184/4000.
Epoch finished ! Loss: 0.17453915029764175
Starting epoch 2185/4000.
Epoch finished ! Loss: 0.1723368749022484
Starting epoch 2186/4000.
Epoch finished ! Loss: 0.17580861747264862
Starting epoch 2187/4000.
Epoch finished ! Loss: 0.17473426908254625
Starting epoch 2188/4000.
Epoch finished ! Loss: 0.17704480439424514
Starting epoch 2189/4000.
Epoch finished ! Loss: 0.1719101533293724
Starting epoch 2190/4000.
Epoch finished ! Loss: 0.16840359419

Epoch finished ! Loss: 0.17376600950956345
Starting epoch 2297/4000.
Epoch finished ! Loss: 0.18170257806777954
Starting epoch 2298/4000.
Epoch finished ! Loss: 0.17130912989377975
Starting epoch 2299/4000.
Epoch finished ! Loss: 0.17691943645477295
Starting epoch 2300/4000.
Epoch finished ! Loss: 0.18163011521100997
Starting epoch 2301/4000.
Epoch finished ! Loss: 0.16756696701049806
Starting epoch 2302/4000.
Epoch finished ! Loss: 0.1699868842959404
Starting epoch 2303/4000.
Epoch finished ! Loss: 0.17135332971811296
Starting epoch 2304/4000.
Epoch finished ! Loss: 0.1753877267241478
Starting epoch 2305/4000.
Epoch finished ! Loss: 0.1744704455137253
Starting epoch 2306/4000.
Epoch finished ! Loss: 0.17292397022247313
Starting epoch 2307/4000.
Epoch finished ! Loss: 0.17914474606513978
Starting epoch 2308/4000.
Epoch finished ! Loss: 0.17214266657829286
Starting epoch 2309/4000.
Epoch finished ! Loss: 0.17167214602231978
Starting epoch 2310/4000.
Epoch finished ! Loss: 0.173949742317

Epoch finished ! Loss: 0.18126543313264848
Starting epoch 2417/4000.
Epoch finished ! Loss: 0.1720343768596649
Starting epoch 2418/4000.
Epoch finished ! Loss: 0.17673512995243074
Starting epoch 2419/4000.
Epoch finished ! Loss: 0.17503172308206558
Starting epoch 2420/4000.
Epoch finished ! Loss: 0.17033858895301818
Starting epoch 2421/4000.
Epoch finished ! Loss: 0.1783323183655739
Starting epoch 2422/4000.
Epoch finished ! Loss: 0.1802534967660904
Starting epoch 2423/4000.
Epoch finished ! Loss: 0.17536055892705918
Starting epoch 2424/4000.
Epoch finished ! Loss: 0.17183175534009934
Starting epoch 2425/4000.
Epoch finished ! Loss: 0.17634406685829163
Starting epoch 2426/4000.
Epoch finished ! Loss: 0.17700812369585037
Starting epoch 2427/4000.
Epoch finished ! Loss: 0.17470148354768752
Starting epoch 2428/4000.
Epoch finished ! Loss: 0.17701360136270522
Starting epoch 2429/4000.
Epoch finished ! Loss: 0.177689266204834
Starting epoch 2430/4000.
Epoch finished ! Loss: 0.17939429730176

Epoch finished ! Loss: 0.1753830075263977
Starting epoch 2537/4000.
Epoch finished ! Loss: 0.1742592990398407
Starting epoch 2538/4000.
Epoch finished ! Loss: 0.17336484789848328
Starting epoch 2539/4000.
Epoch finished ! Loss: 0.1739479809999466
Starting epoch 2540/4000.
Epoch finished ! Loss: 0.17647797763347625
Starting epoch 2541/4000.
Epoch finished ! Loss: 0.17306195944547653
Starting epoch 2542/4000.
Epoch finished ! Loss: 0.18148289620876312
Starting epoch 2543/4000.
Epoch finished ! Loss: 0.17102064192295074
Starting epoch 2544/4000.
Epoch finished ! Loss: 0.17571442276239396
Starting epoch 2545/4000.
Epoch finished ! Loss: 0.17708104997873306
Starting epoch 2546/4000.
Epoch finished ! Loss: 0.17783386558294295
Starting epoch 2547/4000.
Epoch finished ! Loss: 0.1692910373210907
Starting epoch 2548/4000.
Epoch finished ! Loss: 0.17160663902759551
Starting epoch 2549/4000.
Epoch finished ! Loss: 0.1710606798529625
Starting epoch 2550/4000.
Epoch finished ! Loss: 0.17437984645366

Epoch finished ! Loss: 0.1753266215324402
Starting epoch 2657/4000.
Epoch finished ! Loss: 0.1704387202858925
Starting epoch 2658/4000.
Epoch finished ! Loss: 0.1729739561676979
Starting epoch 2659/4000.
Epoch finished ! Loss: 0.1681669145822525
Starting epoch 2660/4000.
Epoch finished ! Loss: 0.16987788528203965
Starting epoch 2661/4000.
Epoch finished ! Loss: 0.16921505481004714
Starting epoch 2662/4000.
Epoch finished ! Loss: 0.1737332373857498
Starting epoch 2663/4000.
Epoch finished ! Loss: 0.179964779317379
Starting epoch 2664/4000.
Epoch finished ! Loss: 0.17468404471874238
Starting epoch 2665/4000.
Epoch finished ! Loss: 0.17208441495895385
Starting epoch 2666/4000.
Epoch finished ! Loss: 0.17543412894010543
Starting epoch 2667/4000.
Epoch finished ! Loss: 0.17496222704648973
Starting epoch 2668/4000.
Epoch finished ! Loss: 0.17637979388237
Starting epoch 2669/4000.
Epoch finished ! Loss: 0.17460937052965164
Starting epoch 2670/4000.
Epoch finished ! Loss: 0.1740424230694771
St

In [None]:
if os.path.isdir(model_save_path):
            torch.save(model.state_dict(),model_save_path + 'sony{}.pth'.format(epoch + 1))
else:
    os.makedirs(model_save_path, exist_ok=True)
    torch.save(model.state_dict(),model_save_path + 'sony{}.pth'.format(epoch + 1))