In [1]:
from __future__ import division
import os, time
import numpy as np
import glob
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import sys
import os
from optparse import OptionParser
import numpy as np
from torch import optim
from PIL import Image
from torch.autograd import Function, Variable
import matplotlib.pyplot as plt
from torchvision import transforms
from torch.utils.data import Dataset
import cv2
import glob
import pickle
from tqdm import tqdm
import rawpy
%matplotlib inline

# Dataset locations: `short/` holds the raw inputs, `long/` the raw ground-truth references.
input_dir = './Dataset/Sony/short/'
gt_dir = './Dataset/Sony/long/'

# Output locations; checkpoints live in a sub-folder of the results directory.
result_dir = './results_SWA/'
model_save_path = result_dir + 'net_weights/'
# exist_ok makes this a no-op when the folder is already there.
os.makedirs(model_save_path, exist_ok=True)

In [2]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
deviceTag = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(deviceTag)

cuda


In [3]:
from utils.swa import SWA

# LOAD DATASET

In [4]:
# Training scenes are the ground-truth files whose name starts with '0';
# the first five characters of each file name are the numeric scene ID.
train_fns = glob.glob(gt_dir + '0*.ARW')
train_ids = [int(os.path.basename(fn)[:5]) for fn in train_fns]

DEBUG = 0  # set to 1 for a quick smoke test on only five scenes
ps = 512  # patch size for training (side length of the random crop)
save_freq = 2 if DEBUG == 1 else 500
if DEBUG == 1:
    train_ids = train_ids[:5]


# UNET MODULES

In [5]:
class conv_lrelu(nn.Module):
    """3x3 convolution (padding 1, so H and W are preserved) followed by LeakyReLU.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = [nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.LeakyReLU()]
        # Kept as `self.conv` (a Sequential) so checkpoint keys stay conv.0.weight / conv.0.bias.
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)
    
class down(nn.Module):
    """Encoder stage: two conv_lrelu layers, then 2x2 max-pooling (halves H and W).

    Args:
        in_ch: channels of the incoming feature map.
        out_ch: channels produced by both convolutions.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = conv_lrelu(in_ch, out_ch)
        self.conv2 = conv_lrelu(out_ch, out_ch)
        self.down = nn.MaxPool2d((2, 2))

    def forward(self, x):
        features = self.conv2(self.conv1(x))
        return self.down(features)
    

class up(nn.Module):
    """Decoder stage: 2x bilinear upsampling, skip concatenation, two conv_lrelu layers.

    Args:
        in_ch: total channels after concatenation (skip channels + upsampled channels).
        out_ch: channels produced by both convolutions.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.UpsamplingBilinear2d(scale_factor=2)
        self.conv1 = conv_lrelu(in_ch, out_ch)
        self.conv2 = conv_lrelu(out_ch, out_ch)

    def forward(self, x1, x2):
        """x1: coarser features to upsample; x2: skip tensor at the target resolution."""
        # The skip tensor goes first in the channel order; the trained conv weights depend on it.
        merged = torch.cat([x2, self.up(x1)], dim=1)
        return self.conv2(self.conv1(merged))

class UNet(nn.Module):
    """Four-level U-Net built from conv_lrelu / down / up stages.

    Args:
        in_ch: channels of the NCHW input tensor.
        CH_PER_SCALE: five feature widths, one for full resolution and one after
            each of the four 2x max-pool stages.
        out_ch: channels produced by the final 1x1 convolution.

    forward() returns a tensor with the input's H and W and ``out_ch`` channels,
    with no activation or pixel shuffle applied. H and W should be divisible by 16
    so that every upsampled map matches its skip tensor in the concatenations.
    """

    def __init__(self, in_ch = 4, CH_PER_SCALE = [32,64,128,256,512], out_ch = 12):
        super().__init__()
        c = CH_PER_SCALE
        # Encoder: two full-resolution convs, then four downsampling stages.
        self.inc = conv_lrelu(in_ch, c[0])
        self.inc2 = conv_lrelu(c[0], c[0])
        self.down1 = down(c[0], c[1])
        self.down2 = down(c[1], c[2])
        self.down3 = down(c[2], c[3])
        self.down4 = down(c[3], c[4])
        # Decoder: each stage sees upsampled features plus the matching encoder output.
        self.up1 = up(c[4] + c[3], c[3])
        self.up2 = up(c[3] + c[2], c[2])
        self.up3 = up(c[2] + c[1], c[1])
        self.up4 = up(c[1] + c[0], c[0])
        self.outc = nn.Conv2d(c[0], out_ch, 1, padding = 0)

    def forward(self, x):
        x0 = self.inc2(self.inc(x))
        x1 = self.down1(x0)
        x2 = self.down2(x1)
        x3 = self.down3(x2)
        x4 = self.down4(x3)
        y = self.up1(x4, x3)
        y = self.up2(y, x2)
        y = self.up3(y, x1)
        y = self.up4(y, x0)
        return self.outc(y)
    
class PRIDNet(nn.Module):
    """Multi-scale denoiser: shared features, five UNets on pooled copies, 1x1 fusion.

    The input (``in_ch`` channels, NCHW) passes through four conv_lrelu layers to
    32 shared feature channels. Copies of those features are average-pooled by
    1, 2, 4, 8 and 16, each copy goes through its own UNet (12 output channels),
    and the result is bilinearly upsampled back to full size. The shared features
    and the five UNet outputs are concatenated and fused by a 1x1 conv to
    ``out_ch`` channels. That tensor is pixel-shuffled by 2 (default 12 -> 3
    channels at twice the resolution) and clamped to [0, 1].

    H and W of the input should be divisible by 256: a 16x average pool followed
    by the UNet's four 2x max-pools.
    """

    def __init__(self, in_ch = 4, out_ch = 12):
        super().__init__()
        self.feature_extraction = nn.Sequential(
            conv_lrelu(in_ch, 32),
            conv_lrelu(32, 32),
            conv_lrelu(32, 32),
            conv_lrelu(32, 32),
        )
        # One UNet per scale; unetK runs on features average-pooled by 2**K.
        self.unet0 = UNet(in_ch = 32, out_ch = 12)
        self.unet1 = UNet(in_ch = 32, out_ch = 12)
        self.unet2 = UNet(in_ch = 32, out_ch = 12)
        self.unet3 = UNet(in_ch = 32, out_ch = 12)
        self.unet4 = UNet(in_ch = 32, out_ch = 12)
        self.avgpool1 = nn.AvgPool2d((2, 2))
        self.avgpool2 = nn.AvgPool2d((4, 4))
        self.avgpool3 = nn.AvgPool2d((8, 8))
        self.avgpool4 = nn.AvgPool2d((16, 16))
        self.up4 = nn.UpsamplingBilinear2d(scale_factor = 16)
        self.up3 = nn.UpsamplingBilinear2d(scale_factor = 8)
        self.up2 = nn.UpsamplingBilinear2d(scale_factor = 4)
        self.up1 = nn.UpsamplingBilinear2d(scale_factor = 2)
        # Fuse the 32 shared channels with the five 12-channel UNet outputs.
        self.out = nn.Conv2d(32 + 12 * 5, out_ch, 1, padding = 0)

    def forward(self, x):
        feats = self.feature_extraction(x)
        branches = [feats, self.unet0(feats)]
        for k in range(1, 5):
            pool = getattr(self, 'avgpool%d' % k)
            unet = getattr(self, 'unet%d' % k)
            upsample = getattr(self, 'up%d' % k)
            branches.append(upsample(unet(pool(feats))))
        fused = self.out(torch.cat(branches, dim = 1))
        # Rearrange the 12 channels into 3 RGB channels at 2x resolution (paper's final step).
        rgb = F.pixel_shuffle(fused, 2)
        # Pixel values are clamped to the valid [0, 1] range.
        return F.hardtanh(rgb, min_val=0, max_val=1)

    def load_my_state_dict(self, state_dict):
        """Copy the matching entries of ``state_dict`` into this model in place.

        Keys this model does not have are skipped silently. Keys of this model that
        are absent from ``state_dict`` keep their current values. Values are copied
        with ``Tensor.copy_``, so the shapes must be broadcast-compatible.
        """
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name in own_state:
                # .data: backwards compatibility for serialized parameters
                own_state[name].copy_(param.data)

# Helper function for packing raw Bayer data into 4 channels

In [6]:
def pack_raw(raw):
    """Pack a Bayer-mosaic raw frame into a 4-channel, half-resolution float32 array.

    ``raw.raw_image_visible`` is converted to float32. The hard-coded black level
    of 512 is subtracted, negatives are clipped to 0, and the result is divided by
    (16383 - 512). Each 2x2 Bayer tile becomes one output pixel with channels
    ordered (even row, even col), (even row, odd col), (odd row, odd col),
    (odd row, even col).

    Returns:
        Array of shape (H/2, W/2, 4) for an H x W sensor with even H and W.
    """
    black_level, white_level = 512, 16383
    sensor = raw.raw_image_visible.astype(np.float32)
    sensor = np.maximum(sensor - black_level, 0) / (white_level - black_level)  # subtract the black level
    # (row, col) offset of each channel inside the 2x2 Bayer tile.
    offsets = ((0, 0), (0, 1), (1, 1), (1, 0))
    planes = [sensor[r::2, c::2] for r, c in offsets]
    return np.stack(planes, axis=-1)

In [7]:
# Raw data takes a long time to decode, so images are kept in memory after first load.
# Both caches are indexed by position in train_ids (the `ind` of the training loop).
gt_images = [None] * len(train_ids)
# Packed inputs are cached per ratio key; the training loop looks them up with str(ratio)[0:3].
input_images = {key: [None] * len(train_ids) for key in ('300', '250', '100')}

g_loss = np.zeros((5000, 1))

# Resume bookkeeping: the largest 4-digit suffix among result entries whose name ends in '0'.
# The running maximum must start from 0 and be compared against itself.
allfolders = glob.glob(result_dir + '*0')
lastepoch = 0
for folder in allfolders:
    lastepoch = np.maximum(lastepoch, int(folder[-4:]))

# Training 

In [8]:
# Build the network on the device selected earlier (CUDA when available, else CPU)
# instead of hard-coding .cuda(), then switch it to training mode.
model = PRIDNet().to(deviceTag)
model = model.train()

In [9]:
def process_img(input_raw_img, model, ratio, device=None):
    """Run ``model`` on a batch of packed raw images and return the output as NHWC numpy.

    Args:
        input_raw_img: numpy array, dimension (Batch, Height, Width, Channel).
        model: network applied to the NCHW float32 tensor.
        ratio: per-image amplification, numpy array of shape (Batch,), or a scalar
            applied to every image. The input is multiplied by it before inference.
        device: torch device to run on; defaults to the global ``deviceTag``.

    Returns:
        The model output as a numpy array transposed to (Batch, H', W', C').

    The model is moved to ``device`` and evaluated in eval mode under
    ``torch.no_grad()``. Its previous train/eval mode is restored afterwards, so
    calling this mid-training does not leave the model in eval mode.
    """
    if device is None:
        device = deviceTag
    was_training = model.training
    model.eval()
    model.to(device)
    # reshape(-1, 1, 1, 1) broadcasts over (B, C, H, W) and also accepts a scalar ratio.
    ratio = np.asarray(ratio).reshape(-1, 1, 1, 1)
    batch = np.transpose(input_raw_img, [0, 3, 1, 2]).astype('float32') * ratio
    input_tensor = torch.from_numpy(np.ascontiguousarray(batch)).float().to(device)
    try:
        with torch.no_grad():
            output_tensor = model(input_tensor)
    finally:
        model.train(was_training)
    return np.transpose(output_tensor.cpu().numpy(), [0, 2, 3, 1])
    
def _load_validation_pair(in_path, gt_path, block_size=None):
    """Load one validation example as (packed_input, gt_img, ratio).

    packed_input is pack_raw() of the input file, not yet multiplied by ratio.
    gt_img is the 16-bit camera-white-balanced postprocess of the ground truth,
    scaled to float32 in [0, 1]. ratio is gt exposure / input exposure, capped at
    300, with both exposures parsed from characters [9:-5] of the file names.
    When block_size is given, a random block_size x block_size crop of the input
    and the matching 2x-sized crop of the ground truth are returned.
    """
    in_exposure = float(os.path.basename(in_path)[9:-5])
    gt_exposure = float(os.path.basename(gt_path)[9:-5])
    ratio = min(gt_exposure / in_exposure, 300)

    # `with` closes the raw file handle; both calls below return independent arrays.
    with rawpy.imread(in_path) as raw:
        input_raw_img = pack_raw(raw)
    with rawpy.imread(gt_path) as gt_raw:
        gt_img = gt_raw.postprocess(use_camera_wb=True, half_size=False, no_auto_bright=True, output_bps=16)
    gt_img = np.float32(gt_img / 65535.0)

    if block_size is not None:
        # +1 because randint's upper bound is exclusive: an input exactly block_size
        # high/wide still yields a valid (zero-offset) crop instead of raising.
        i_cut = np.random.randint(0, input_raw_img.shape[0] - block_size + 1)
        j_cut = np.random.randint(0, input_raw_img.shape[1] - block_size + 1)
        gt_img = gt_img[i_cut*2:i_cut*2+block_size*2, j_cut*2:j_cut*2+block_size*2, :]
        input_raw_img = input_raw_img[i_cut:i_cut+block_size, j_cut:j_cut+block_size, :]
    return input_raw_img, gt_img, ratio


def validate(model, input_list, gt_list, block_size = None, batch_size = 8):
    """Return the mean PSNR (dB, peak value 1) of ``model`` over paired raw files.

    Args:
        model: network passed to process_img().
        input_list, gt_list: equally long sequences of file paths; entry k of each
            forms one input/ground-truth pair.
        block_size: if given, evaluate one random crop of this size (in packed-raw
            pixels) per image instead of the full frame.
        batch_size: pairs per forward pass. The final batch may be smaller and is
            still evaluated.

    Returns:
        Mean PSNR over all evaluated images, or NaN when the lists are empty.

    Side effects: prints the batch index every 10 batches and opens two matplotlib
    figures per batch (ground truth and prediction of the batch's first image).
    The model's train/eval mode is restored before returning.
    """
    assert len(input_list) == len(gt_list)

    was_training = model.training
    model.eval()
    PSNR_list = []
    # Ceil division, so the trailing partial batch is not dropped.
    n_batches = -(-len(input_list) // batch_size)
    try:
        for i in range(n_batches):
            if i % 10 == 0:
                print(i)
            start = i * batch_size
            pairs = [_load_validation_pair(in_path, gt_path, block_size)
                     for in_path, gt_path in zip(input_list[start:start + batch_size],
                                                 gt_list[start:start + batch_size])]
            input_raw_img_batch = np.array([p[0] for p in pairs])
            gt_img_batch = np.array([p[1] for p in pairs])
            ratio_batch = np.array([p[2] for p in pairs])

            output_img_batch = process_img(input_raw_img_batch, model, ratio_batch)
            plt.figure()
            plt.imshow(gt_img_batch[0,:,:,:])
            plt.title("Ground Truth")
            plt.figure()
            plt.imshow(output_img_batch[0,:,:,:])
            plt.title("Predicted patch")

            n_imgs = output_img_batch.shape[0]
            MSE = np.mean((output_img_batch.reshape(n_imgs, -1) - gt_img_batch.reshape(n_imgs, -1))**2, axis = 1)
            PSNR_batch = 10*np.log10(1/MSE)
            # extend (not append) so batches of different sizes flatten into one list
            PSNR_list.extend(PSNR_batch.tolist())
    finally:
        model.train(was_training)

    if not PSNR_list:
        return float('nan')
    Val_PSNR = np.mean(PSNR_list)
    return Val_PSNR

In [None]:
learning_rate = 1e-4
batch_num = 4  # random patches stacked into one optimizer step
base_opt = optim.SGD(model.parameters(), lr=learning_rate)
# NOTE(review): SWA comes from the local utils.swa module; swa_start / swa_freq are presumably
# counted in optimizer steps, and swa_lr=0.05 is 500x the base lr of 1e-4 — confirm intended.
optimizer = SWA(base_opt, swa_start=1000, swa_freq=10, swa_lr=0.05)
scheduler = optim.lr_scheduler.StepLR(base_opt, step_size=2000, gamma=0.1)  # stepped once per epoch
criterion = nn.MSELoss()

Start_epoch = 0
epochs = 4000
TrainingLossData = np.zeros(epochs)  # one entry per epoch of this run, index epoch - Start_epoch


def _sample_training_patch(ind):
    """Return one random, augmented (input, ground-truth) patch pair for scene train_ids[ind].

    A random short-exposure file of the scene is chosen. On first use its packed raw
    data (multiplied by the exposure ratio) and the ground-truth image are decoded and
    cached in input_images / gt_images. A random ps x ps crop (2*ps for the ground
    truth) is taken and randomly flipped/transposed. Both patches are returned as
    NCHW tensors on deviceTag.
    """
    train_id = train_ids[ind]
    in_files = glob.glob(input_dir + '%05d_00*.ARW' % train_id)
    # randint's high is exclusive: same draw as the deprecated random_integers(0, n - 1).
    in_path = in_files[np.random.randint(0, len(in_files))]
    gt_path = glob.glob(gt_dir + '%05d_00*.ARW' % train_id)[0]
    # The exposure (seconds) is parsed from characters [9:-5] of each file name.
    in_exposure = float(os.path.basename(in_path)[9:-5])
    gt_exposure = float(os.path.basename(gt_path)[9:-5])
    ratio = min(gt_exposure / in_exposure, 300)
    ratio_key = str(ratio)[0:3]

    if input_images[ratio_key][ind] is None:
        # `with` closes the raw file handles; the arrays kept below are independent copies.
        with rawpy.imread(in_path) as raw:
            input_images[ratio_key][ind] = np.expand_dims(pack_raw(raw), axis=0) * ratio
        with rawpy.imread(gt_path) as gt_raw:
            im = gt_raw.postprocess(use_camera_wb=True, half_size=False, no_auto_bright=True, output_bps=16)
        gt_images[ind] = np.expand_dims(np.float32(im / 65535.0), axis=0)

    # Crop: the ground truth has twice the packed input's resolution.
    cached_input = input_images[ratio_key][ind]
    H = cached_input.shape[1]
    W = cached_input.shape[2]
    xx = np.random.randint(0, W - ps)
    yy = np.random.randint(0, H - ps)
    input_patch = cached_input[:, yy:yy + ps, xx:xx + ps, :]
    gt_patch = gt_images[ind][:, yy * 2:yy * 2 + ps * 2, xx * 2:xx * 2 + ps * 2, :]

    if np.random.randint(2, size=1)[0] == 1:  # random flip along axis 1 (rows)
        input_patch = np.flip(input_patch, axis=1)
        gt_patch = np.flip(gt_patch, axis=1)
    if np.random.randint(2, size=1)[0] == 1:  # random flip along axis 2 (columns)
        input_patch = np.flip(input_patch, axis=2)
        gt_patch = np.flip(gt_patch, axis=2)
    if np.random.randint(2, size=1)[0] == 1:  # random transpose
        input_patch = np.transpose(input_patch, (0, 2, 1, 3))
        gt_patch = np.transpose(gt_patch, (0, 2, 1, 3))

    # NHWC -> NCHW, e.g. (1, 512, 512, 4) -> (1, 4, 512, 512) and (1, 1024, 1024, 3) -> (1, 3, 1024, 1024)
    input_patch = np.transpose(input_patch, (0, 3, 1, 2))
    gt_patch = np.transpose(gt_patch, (0, 3, 1, 2))
    return (torch.from_numpy(input_patch.copy()).to(deviceTag),
            torch.from_numpy(gt_patch.copy()).to(deviceTag))


for epoch in range(Start_epoch, Start_epoch + epochs):
    print('Starting epoch {}/{}.'.format(epoch + 1, epochs))
    model.train()  # validate()/process_img() may have left the model in eval mode
    epoch_loss = 0
    batches_processed = 0
    input_batch, gt_batch = [], []
    for ind in np.random.permutation(len(train_ids)):
        input_patch, gt_patch = _sample_training_patch(ind)
        input_batch.append(input_patch)
        gt_batch.append(gt_patch)
        if len(input_batch) < batch_num:
            continue
        # A full mini-batch of batch_num patches is ready: take one optimizer step.
        img_pred = model(torch.cat(input_batch, dim=0))
        loss = criterion(img_pred, torch.cat(gt_batch, dim=0))
        epoch_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batches_processed += 1
        input_batch, gt_batch = [], []
    # Patches left over when len(train_ids) is not a multiple of batch_num are dropped.
    scheduler.step()

    # Guard against an epoch with no complete batch (e.g. fewer than batch_num scenes).
    mean_loss = epoch_loss / batches_processed if batches_processed else float('nan')
    print('Epoch finished ! Loss: {}'.format(mean_loss))
    if epoch == 0:
        # First epoch: start a fresh log with a header.
        with open(result_dir + "TrainingEpoch.txt", "w+") as trainF:
            trainF.write('Epoch,Train_loss\n')
            trainF.write('{},{}\n'.format(epoch, mean_loss))
    else:
        with open(result_dir + "TrainingEpoch.txt", "a") as trainF:
            trainF.write('{},{}\n'.format(epoch, mean_loss))
    TrainingLossData[epoch - Start_epoch] = mean_loss  # save for plotting

    # TODO: run validate() on held-out validation data here.
    # Checkpoints: every epoch in the last 5% of this run, plus whenever epoch % 99 == 0.
    if (epoch - Start_epoch) > epochs * 0.95 or epoch % 99 == 0:
        torch.save(model.state_dict(), model_save_path + 'sony{}.pth'.format(epoch + 1))

# Swap in the SWA-averaged weights (presumably what utils.swa's swap_swa_sgd does) and save.
# NOTE(review): this reuses the file name of the last in-loop checkpoint and overwrites it.
optimizer.swap_swa_sgd()
torch.save(model.state_dict(), model_save_path + 'sony{}.pth'.format(Start_epoch + epochs))

Starting epoch 1/4000.


  in_path = in_files[np.random.random_integers(0, len(in_files) - 1)]
  in_path = in_files[np.random.random_integers(0, len(in_files) - 1)]
  in_path = in_files[np.random.random_integers(0, len(in_files) - 1)]


Epoch finished ! Loss: 0.08446125553455204
Starting epoch 2/4000.
Epoch finished ! Loss: 0.08514510178938509
Starting epoch 3/4000.
Epoch finished ! Loss: 0.08160317400470377
Starting epoch 4/4000.
Epoch finished ! Loss: 0.08639021832495927
Starting epoch 5/4000.
Epoch finished ! Loss: 0.08197938054800033
Starting epoch 6/4000.
Epoch finished ! Loss: 0.08757441475754604
Starting epoch 7/4000.
Epoch finished ! Loss: 0.08554771190974861
Starting epoch 8/4000.
Epoch finished ! Loss: 0.08134236671030522
Starting epoch 9/4000.
Epoch finished ! Loss: 0.07823340981267393
Starting epoch 10/4000.
Epoch finished ! Loss: 0.08552727324422449
Starting epoch 11/4000.
Epoch finished ! Loss: 0.08355004640761762
Starting epoch 12/4000.
Epoch finished ! Loss: 0.07384061184711754
Starting epoch 13/4000.
Epoch finished ! Loss: 0.07670602169819177
Starting epoch 14/4000.
Epoch finished ! Loss: 0.0822586442809552
Starting epoch 15/4000.
Epoch finished ! Loss: 0.07924061589874327
Starting epoch 16/4000.
Epoc

Epoch finished ! Loss: 0.018764441751409323
Starting epoch 124/4000.
Epoch finished ! Loss: 0.019239833927713335
Starting epoch 125/4000.
Epoch finished ! Loss: 0.0191692448919639
Starting epoch 126/4000.
Epoch finished ! Loss: 0.019322043959982695
Starting epoch 127/4000.
Epoch finished ! Loss: 0.020204559434205294
Starting epoch 128/4000.
Epoch finished ! Loss: 0.01832672767341137
Starting epoch 129/4000.
Epoch finished ! Loss: 0.019791716989129782
Starting epoch 130/4000.
Epoch finished ! Loss: 0.01727230045944452
Starting epoch 131/4000.
Epoch finished ! Loss: 0.019860619597602637
Starting epoch 132/4000.
Epoch finished ! Loss: 0.019500655552837996
Starting epoch 133/4000.
Epoch finished ! Loss: 0.019693822809495033
Starting epoch 134/4000.
Epoch finished ! Loss: 0.01902453196235001
Starting epoch 135/4000.
Epoch finished ! Loss: 0.018019898142665625
Starting epoch 136/4000.
Epoch finished ! Loss: 0.01983831328107044
Starting epoch 137/4000.
Epoch finished ! Loss: 0.020125849952455

Epoch finished ! Loss: 0.018444285925943403
Starting epoch 244/4000.
Epoch finished ! Loss: 0.01871605966007337
Starting epoch 245/4000.
Epoch finished ! Loss: 0.019269630732014775
Starting epoch 246/4000.
Epoch finished ! Loss: 0.017892095958814025
Starting epoch 247/4000.
Epoch finished ! Loss: 0.01817089932737872
Starting epoch 248/4000.
Epoch finished ! Loss: 0.020552623365074397
Starting epoch 249/4000.
Epoch finished ! Loss: 0.018849559500813483
Starting epoch 250/4000.
Epoch finished ! Loss: 0.018419055338017642
Starting epoch 251/4000.
Epoch finished ! Loss: 0.018314186274074017
Starting epoch 252/4000.
Epoch finished ! Loss: 0.019524218165315688
Starting epoch 253/4000.
Epoch finished ! Loss: 0.01863446618663147
Starting epoch 254/4000.
Epoch finished ! Loss: 0.01833842492196709
Starting epoch 255/4000.
Epoch finished ! Loss: 0.019492346327751875
Starting epoch 256/4000.
Epoch finished ! Loss: 0.018941402644850315
Starting epoch 257/4000.
Epoch finished ! Loss: 0.0189248142298

Epoch finished ! Loss: 0.018064879043959082
Starting epoch 364/4000.
Epoch finished ! Loss: 0.018866487534251065
Starting epoch 365/4000.
Epoch finished ! Loss: 0.02001419060397893
Starting epoch 366/4000.
Epoch finished ! Loss: 0.019355735171120613
Starting epoch 367/4000.
Epoch finished ! Loss: 0.017772503849118947
Starting epoch 368/4000.
Epoch finished ! Loss: 0.018239059089682996
Starting epoch 369/4000.
Epoch finished ! Loss: 0.019407858909107744
Starting epoch 370/4000.
Epoch finished ! Loss: 0.018441657139919698
Starting epoch 371/4000.
Epoch finished ! Loss: 0.019100418756715955
Starting epoch 372/4000.
Epoch finished ! Loss: 0.018810968624893575
Starting epoch 373/4000.
Epoch finished ! Loss: 0.018468586762901397
Starting epoch 374/4000.
Epoch finished ! Loss: 0.017336116661317645
Starting epoch 375/4000.
Epoch finished ! Loss: 0.01958663170225918
Starting epoch 376/4000.
Epoch finished ! Loss: 0.01871903776191175
Starting epoch 377/4000.
Epoch finished ! Loss: 0.019585428468

Epoch finished ! Loss: 0.017744511947967112
Starting epoch 484/4000.
Epoch finished ! Loss: 0.016406583099160342
Starting epoch 485/4000.
Epoch finished ! Loss: 0.01682877573184669
Starting epoch 486/4000.
Epoch finished ! Loss: 0.017167076072655617
Starting epoch 487/4000.
Epoch finished ! Loss: 0.016879612777847795
Starting epoch 488/4000.
Epoch finished ! Loss: 0.01689248954644427
Starting epoch 489/4000.
Epoch finished ! Loss: 0.015540538588538766
Starting epoch 490/4000.
Epoch finished ! Loss: 0.017565648618619888
Starting epoch 491/4000.
Epoch finished ! Loss: 0.015679398784413934
Starting epoch 492/4000.
Epoch finished ! Loss: 0.017079961486160755
Starting epoch 493/4000.
Epoch finished ! Loss: 0.016912903543561696
Starting epoch 494/4000.
Epoch finished ! Loss: 0.017161218856927007
Starting epoch 495/4000.
Epoch finished ! Loss: 0.01564009819412604
Starting epoch 496/4000.
Epoch finished ! Loss: 0.016809373162686825
Starting epoch 497/4000.
Epoch finished ! Loss: 0.016878581838

Epoch finished ! Loss: 0.01626053781947121
Starting epoch 604/4000.
Epoch finished ! Loss: 0.016029326338320972
Starting epoch 605/4000.
Epoch finished ! Loss: 0.01606747875921428
Starting epoch 606/4000.
Epoch finished ! Loss: 0.0164550949819386
Starting epoch 607/4000.
Epoch finished ! Loss: 0.0162869913270697
Starting epoch 608/4000.
Epoch finished ! Loss: 0.016924833552911876
Starting epoch 609/4000.
Epoch finished ! Loss: 0.016617010894697158
Starting epoch 610/4000.
Epoch finished ! Loss: 0.01682340889237821
Starting epoch 611/4000.
Epoch finished ! Loss: 0.015584672219119966
Starting epoch 612/4000.
Epoch finished ! Loss: 0.016432605613954366
Starting epoch 613/4000.
Epoch finished ! Loss: 0.016280151158571243
Starting epoch 614/4000.
Epoch finished ! Loss: 0.014936839661095292
Starting epoch 615/4000.
Epoch finished ! Loss: 0.014912414585705847
Starting epoch 616/4000.
Epoch finished ! Loss: 0.016319797351025046
Starting epoch 617/4000.
Epoch finished ! Loss: 0.0154359977226704

Epoch finished ! Loss: 0.01659808137919754
Starting epoch 724/4000.
Epoch finished ! Loss: 0.0160536116338335
Starting epoch 725/4000.
Epoch finished ! Loss: 0.01522807136643678
Starting epoch 726/4000.
Epoch finished ! Loss: 0.015498733113054186
Starting epoch 727/4000.
Epoch finished ! Loss: 0.015508044883608818
Starting epoch 728/4000.
Epoch finished ! Loss: 0.01622874728636816
Starting epoch 729/4000.
Epoch finished ! Loss: 0.0166300430893898
Starting epoch 730/4000.
Epoch finished ! Loss: 0.017214885645080356
Starting epoch 731/4000.
Epoch finished ! Loss: 0.01582873354200274
Starting epoch 732/4000.
Epoch finished ! Loss: 0.01626449531177059
Starting epoch 733/4000.
Epoch finished ! Loss: 0.0155161133792717
Starting epoch 734/4000.
Epoch finished ! Loss: 0.015952844824641942
Starting epoch 735/4000.
Epoch finished ! Loss: 0.016194985061883927
Starting epoch 736/4000.
Epoch finished ! Loss: 0.01657738097710535
Starting epoch 737/4000.
Epoch finished ! Loss: 0.01653813720913604
Sta

Epoch finished ! Loss: 0.0164642671123147
Starting epoch 844/4000.
Epoch finished ! Loss: 0.015429770923219621
Starting epoch 845/4000.
Epoch finished ! Loss: 0.014995967678260058
Starting epoch 846/4000.
Epoch finished ! Loss: 0.015260514104738832
Starting epoch 847/4000.
Epoch finished ! Loss: 0.01587144872173667
Starting epoch 848/4000.
Epoch finished ! Loss: 0.01622283619362861
Starting epoch 849/4000.
Epoch finished ! Loss: 0.015951850940473376
Starting epoch 850/4000.
Epoch finished ! Loss: 0.015588018338894472
Starting epoch 851/4000.
Epoch finished ! Loss: 0.016099754453171043
Starting epoch 852/4000.
Epoch finished ! Loss: 0.01511522161308676
Starting epoch 853/4000.
Epoch finished ! Loss: 0.01606958293123171
Starting epoch 854/4000.
Epoch finished ! Loss: 0.015522980049718171
Starting epoch 855/4000.
Epoch finished ! Loss: 0.015557311871089042
Starting epoch 856/4000.
Epoch finished ! Loss: 0.015157794789411127
Starting epoch 857/4000.
Epoch finished ! Loss: 0.015235130733344

Epoch finished ! Loss: 0.015844947099685668
Starting epoch 964/4000.
Epoch finished ! Loss: 0.015018332877662032
Starting epoch 965/4000.
Epoch finished ! Loss: 0.015745537576731295
Starting epoch 966/4000.
Epoch finished ! Loss: 0.016076971986331047
Starting epoch 967/4000.
Epoch finished ! Loss: 0.016703550424426793
Starting epoch 968/4000.
Epoch finished ! Loss: 0.016613322030752897
Starting epoch 969/4000.
Epoch finished ! Loss: 0.01597054082667455
Starting epoch 970/4000.
Epoch finished ! Loss: 0.015579335298389197
Starting epoch 971/4000.
Epoch finished ! Loss: 0.01567715898854658
Starting epoch 972/4000.
Epoch finished ! Loss: 0.016103364911396058
Starting epoch 973/4000.
Epoch finished ! Loss: 0.01529169644927606
Starting epoch 974/4000.
Epoch finished ! Loss: 0.015578094718512148
Starting epoch 975/4000.
Epoch finished ! Loss: 0.015290629223454744
Starting epoch 976/4000.
Epoch finished ! Loss: 0.015741940814768894
Starting epoch 977/4000.
Epoch finished ! Loss: 0.015738532150

Epoch finished ! Loss: 0.015522755589336157
Starting epoch 1083/4000.
Epoch finished ! Loss: 0.01561445239931345
Starting epoch 1084/4000.
Epoch finished ! Loss: 0.01484025081153959
Starting epoch 1085/4000.
Epoch finished ! Loss: 0.016548784717451782
Starting epoch 1086/4000.
Epoch finished ! Loss: 0.016210314526688308
Starting epoch 1087/4000.
Epoch finished ! Loss: 0.015740317315794526
Starting epoch 1088/4000.
Epoch finished ! Loss: 0.015135767904575914
Starting epoch 1089/4000.
Epoch finished ! Loss: 0.01617949898354709
Starting epoch 1090/4000.
Epoch finished ! Loss: 0.015144645934924483
Starting epoch 1091/4000.
Epoch finished ! Loss: 0.015076903009321541
Starting epoch 1092/4000.
Epoch finished ! Loss: 0.015666183142457157
Starting epoch 1093/4000.
Epoch finished ! Loss: 0.015601215523201973
Starting epoch 1094/4000.
Epoch finished ! Loss: 0.015581838879734277
Starting epoch 1095/4000.
Epoch finished ! Loss: 0.014699932117946447
Starting epoch 1096/4000.
Epoch finished ! Loss: 

Epoch finished ! Loss: 0.014788850722834469
Starting epoch 1201/4000.
Epoch finished ! Loss: 0.016737536597065627
Starting epoch 1202/4000.
Epoch finished ! Loss: 0.015227133443113416
Starting epoch 1203/4000.
Epoch finished ! Loss: 0.015863930457271636
Starting epoch 1204/4000.
Epoch finished ! Loss: 0.01489402501611039
Starting epoch 1205/4000.
Epoch finished ! Loss: 0.015848221897613257
Starting epoch 1206/4000.
Epoch finished ! Loss: 0.016336016939021647
Starting epoch 1207/4000.
Epoch finished ! Loss: 0.016417104785796256
Starting epoch 1208/4000.
Epoch finished ! Loss: 0.01526115764863789
Starting epoch 1209/4000.
Epoch finished ! Loss: 0.015544041199609638
Starting epoch 1210/4000.
Epoch finished ! Loss: 0.015567249373998494
Starting epoch 1211/4000.
Epoch finished ! Loss: 0.016002179810311646
Starting epoch 1212/4000.
Epoch finished ! Loss: 0.015219422400696203
Starting epoch 1213/4000.
Epoch finished ! Loss: 0.015235605440102518
Starting epoch 1214/4000.
Epoch finished ! Loss:

Epoch finished ! Loss: 0.015315281914081425
Starting epoch 1319/4000.
Epoch finished ! Loss: 0.01621072959387675
Starting epoch 1320/4000.
Epoch finished ! Loss: 0.016176813910715283
Starting epoch 1321/4000.
Epoch finished ! Loss: 0.016036186530254782
Starting epoch 1322/4000.
Epoch finished ! Loss: 0.015472138696350158
Starting epoch 1323/4000.
Epoch finished ! Loss: 0.0153673529275693
Starting epoch 1324/4000.
Epoch finished ! Loss: 0.015895160310901702
Starting epoch 1325/4000.
Epoch finished ! Loss: 0.016550163005013017
Starting epoch 1326/4000.
Epoch finished ! Loss: 0.016961423039902
Starting epoch 1327/4000.
Epoch finished ! Loss: 0.01551501639187336
Starting epoch 1328/4000.
Epoch finished ! Loss: 0.015715998050291093
Starting epoch 1329/4000.
Epoch finished ! Loss: 0.015816631203051656
Starting epoch 1330/4000.
Epoch finished ! Loss: 0.014343694678973406
Starting epoch 1331/4000.
Epoch finished ! Loss: 0.01624572093132883
Starting epoch 1332/4000.
Epoch finished ! Loss: 0.015

Epoch finished ! Loss: 0.016673149669077247
Starting epoch 1437/4000.
Epoch finished ! Loss: 0.015777636389248072
Starting epoch 1438/4000.
Epoch finished ! Loss: 0.01503670954843983
Starting epoch 1439/4000.
Epoch finished ! Loss: 0.015817135921679437
Starting epoch 1440/4000.
Epoch finished ! Loss: 0.016063130018301307
Starting epoch 1441/4000.
Epoch finished ! Loss: 0.015118308633100241
Starting epoch 1442/4000.
Epoch finished ! Loss: 0.016081383521668612
Starting epoch 1443/4000.
Epoch finished ! Loss: 0.01545725913019851
Starting epoch 1444/4000.
Epoch finished ! Loss: 0.0154108407208696
Starting epoch 1445/4000.
Epoch finished ! Loss: 0.015566930919885636
Starting epoch 1446/4000.
Epoch finished ! Loss: 0.015310396440327168
Starting epoch 1447/4000.
Epoch finished ! Loss: 0.014941015292424708
Starting epoch 1448/4000.
Epoch finished ! Loss: 0.015899760543834417
Starting epoch 1449/4000.
Epoch finished ! Loss: 0.015251056931447238
Starting epoch 1450/4000.
Epoch finished ! Loss: 0

Epoch finished ! Loss: 0.015705291414633393
Starting epoch 1555/4000.
Epoch finished ! Loss: 0.016676161461509763
Starting epoch 1556/4000.
Epoch finished ! Loss: 0.01565493574598804
Starting epoch 1557/4000.
Epoch finished ! Loss: 0.015722527634352446
Starting epoch 1558/4000.
Epoch finished ! Loss: 0.016419575095642357
Starting epoch 1559/4000.
Epoch finished ! Loss: 0.01618338655680418
Starting epoch 1560/4000.
Epoch finished ! Loss: 0.01596946008503437
Starting epoch 1561/4000.
Epoch finished ! Loss: 0.014972300291992724
Starting epoch 1562/4000.
Epoch finished ! Loss: 0.015800324070733042
Starting epoch 1563/4000.
Epoch finished ! Loss: 0.01586168319918215
Starting epoch 1564/4000.
Epoch finished ! Loss: 0.016379034239798786
Starting epoch 1565/4000.
Epoch finished ! Loss: 0.015497188759036362
Starting epoch 1566/4000.
Epoch finished ! Loss: 0.015177406836301088
Starting epoch 1567/4000.
Epoch finished ! Loss: 0.016117013723123817
Starting epoch 1568/4000.
Epoch finished ! Loss: 0

Epoch finished ! Loss: 0.01511257210513577
Starting epoch 1673/4000.
Epoch finished ! Loss: 0.015125785709824414
Starting epoch 1674/4000.
Epoch finished ! Loss: 0.015680723823606968
Starting epoch 1675/4000.
Epoch finished ! Loss: 0.01555812667356804
Starting epoch 1676/4000.
Epoch finished ! Loss: 0.0161661930847913
Starting epoch 1677/4000.
Epoch finished ! Loss: 0.015702198399230836
Starting epoch 1678/4000.
Epoch finished ! Loss: 0.01581466117640957
Starting epoch 1679/4000.
Epoch finished ! Loss: 0.01586998593993485
Starting epoch 1680/4000.
Epoch finished ! Loss: 0.0156401643413119
Starting epoch 1681/4000.
Epoch finished ! Loss: 0.01624721838161349
Starting epoch 1682/4000.
Epoch finished ! Loss: 0.014577953785192221
Starting epoch 1683/4000.
Epoch finished ! Loss: 0.015826111868955196
Starting epoch 1684/4000.
Epoch finished ! Loss: 0.016135046258568764
Starting epoch 1685/4000.
Epoch finished ! Loss: 0.016227769700344653
Starting epoch 1686/4000.
Epoch finished ! Loss: 0.0157

Epoch finished ! Loss: 0.015440999332349747
Starting epoch 1791/4000.
Epoch finished ! Loss: 0.015822210686746985
Starting epoch 1792/4000.
Epoch finished ! Loss: 0.014901314256712794
Starting epoch 1793/4000.
Epoch finished ! Loss: 0.015650631592143326
Starting epoch 1794/4000.
Epoch finished ! Loss: 0.015607299329712988
Starting epoch 1795/4000.
Epoch finished ! Loss: 0.016369795135688037
Starting epoch 1796/4000.
Epoch finished ! Loss: 0.01570704081095755
Starting epoch 1797/4000.
Epoch finished ! Loss: 0.01554636366199702
Starting epoch 1798/4000.
Epoch finished ! Loss: 0.015849350357893855
Starting epoch 1799/4000.
Epoch finished ! Loss: 0.015011297794990242
Starting epoch 1800/4000.
Epoch finished ! Loss: 0.014581873407587409
Starting epoch 1801/4000.
Epoch finished ! Loss: 0.016153048793785273
Starting epoch 1802/4000.
Epoch finished ! Loss: 0.015181505621876567
Starting epoch 1803/4000.
Epoch finished ! Loss: 0.015438207029365002
Starting epoch 1804/4000.
Epoch finished ! Loss:

Epoch finished ! Loss: 0.015709741541650148
Starting epoch 1909/4000.
Epoch finished ! Loss: 0.015228539239615202
Starting epoch 1910/4000.
Epoch finished ! Loss: 0.014942109817638993
Starting epoch 1911/4000.
Epoch finished ! Loss: 0.015684384386986494
Starting epoch 1912/4000.
Epoch finished ! Loss: 0.015284234064165503
Starting epoch 1913/4000.
Epoch finished ! Loss: 0.01622217979747802
Starting epoch 1914/4000.
Epoch finished ! Loss: 0.015309034462552518
Starting epoch 1915/4000.
Epoch finished ! Loss: 0.014480213657952845
Starting epoch 1916/4000.
Epoch finished ! Loss: 0.014152321708388626
Starting epoch 1917/4000.
Epoch finished ! Loss: 0.01571195877622813
Starting epoch 1918/4000.
Epoch finished ! Loss: 0.014564020663965494
Starting epoch 1919/4000.
Epoch finished ! Loss: 0.015107941813766956
Starting epoch 1920/4000.
Epoch finished ! Loss: 0.015338492242153733
Starting epoch 1921/4000.
Epoch finished ! Loss: 0.016134822205640376
Starting epoch 1922/4000.
Epoch finished ! Loss:

Epoch finished ! Loss: 0.016178026096895337
Starting epoch 2027/4000.
Epoch finished ! Loss: 0.015458283689804375
Starting epoch 2028/4000.
Epoch finished ! Loss: 0.015480208734516054
Starting epoch 2029/4000.
Epoch finished ! Loss: 0.01731166880344972
Starting epoch 2030/4000.
Epoch finished ! Loss: 0.015593842160888017
Starting epoch 2031/4000.
Epoch finished ! Loss: 0.015230136213358492
Starting epoch 2032/4000.
Epoch finished ! Loss: 0.01454523524735123
Starting epoch 2033/4000.
Epoch finished ! Loss: 0.015630635689012705
Starting epoch 2034/4000.
Epoch finished ! Loss: 0.016454987728502603
Starting epoch 2035/4000.
Epoch finished ! Loss: 0.016132912843022495
Starting epoch 2036/4000.
Epoch finished ! Loss: 0.014414082549046725
Starting epoch 2037/4000.
Epoch finished ! Loss: 0.015135559183545411
Starting epoch 2038/4000.
Epoch finished ! Loss: 0.015353159967344255
Starting epoch 2039/4000.
Epoch finished ! Loss: 0.015964351571165027
Starting epoch 2040/4000.
Epoch finished ! Loss:

Epoch finished ! Loss: 0.014540417597163468
Starting epoch 2145/4000.
Epoch finished ! Loss: 0.0155421162256971
Starting epoch 2146/4000.
Epoch finished ! Loss: 0.016457944782450794
Starting epoch 2147/4000.
Epoch finished ! Loss: 0.014937617257237435
Starting epoch 2148/4000.
Epoch finished ! Loss: 0.014803740952629596
Starting epoch 2149/4000.
Epoch finished ! Loss: 0.015054112358484417
Starting epoch 2150/4000.
Epoch finished ! Loss: 0.015117837185971438
Starting epoch 2151/4000.
Epoch finished ! Loss: 0.014726059522945433
Starting epoch 2152/4000.
Epoch finished ! Loss: 0.015281615406274795
Starting epoch 2153/4000.
Epoch finished ! Loss: 0.015535867249127477
Starting epoch 2154/4000.
Epoch finished ! Loss: 0.015620606066659094
Starting epoch 2155/4000.
Epoch finished ! Loss: 0.016266977495979516
Starting epoch 2156/4000.
Epoch finished ! Loss: 0.015467704110778868
Starting epoch 2157/4000.
Epoch finished ! Loss: 0.015967977629043163
Starting epoch 2158/4000.
Epoch finished ! Loss:

Epoch finished ! Loss: 0.014819207927212119
Starting epoch 2263/4000.
Epoch finished ! Loss: 0.015091919701080769
Starting epoch 2264/4000.
Epoch finished ! Loss: 0.015475128556136041
Starting epoch 2265/4000.
Epoch finished ! Loss: 0.015068517660256475
Starting epoch 2266/4000.
Epoch finished ! Loss: 0.016506251017563044
Starting epoch 2267/4000.
Epoch finished ! Loss: 0.01477559922495857
Starting epoch 2268/4000.
Epoch finished ! Loss: 0.016343126515857877
Starting epoch 2269/4000.
Epoch finished ! Loss: 0.014539964927826077
Starting epoch 2270/4000.
Epoch finished ! Loss: 0.015763612231239676
Starting epoch 2271/4000.
Epoch finished ! Loss: 0.01606588300783187
Starting epoch 2272/4000.
Epoch finished ! Loss: 0.015278850973118097
Starting epoch 2273/4000.
Epoch finished ! Loss: 0.01486476812278852
Starting epoch 2274/4000.
Epoch finished ! Loss: 0.014939465723000468
Starting epoch 2275/4000.
Epoch finished ! Loss: 0.01519077819539234
Starting epoch 2276/4000.
Epoch finished ! Loss: 0

Epoch finished ! Loss: 0.015377101802732796
Starting epoch 2381/4000.
Epoch finished ! Loss: 0.014791631209664046
Starting epoch 2382/4000.
Epoch finished ! Loss: 0.014535409724339842
Starting epoch 2383/4000.
Epoch finished ! Loss: 0.015299486136063934
Starting epoch 2384/4000.
Epoch finished ! Loss: 0.014688832405954599
Starting epoch 2385/4000.
Epoch finished ! Loss: 0.014702156075509265
Starting epoch 2386/4000.
Epoch finished ! Loss: 0.0162797168828547
Starting epoch 2387/4000.
Epoch finished ! Loss: 0.015716116491239517
Starting epoch 2388/4000.
Epoch finished ! Loss: 0.01613048785366118
Starting epoch 2389/4000.
Epoch finished ! Loss: 0.016099852533079682
Starting epoch 2390/4000.
Epoch finished ! Loss: 0.014884499076288193
Starting epoch 2391/4000.
Epoch finished ! Loss: 0.014874507876811549
Starting epoch 2392/4000.
Epoch finished ! Loss: 0.015633704839274287
Starting epoch 2393/4000.
Epoch finished ! Loss: 0.015216003963723778
Starting epoch 2394/4000.
Epoch finished ! Loss: 

Epoch finished ! Loss: 0.015791677858214826
Starting epoch 2499/4000.
Epoch finished ! Loss: 0.015183188673108815
Starting epoch 2500/4000.
Epoch finished ! Loss: 0.015257185627706348
Starting epoch 2501/4000.
Epoch finished ! Loss: 0.015079198242165148
Starting epoch 2502/4000.
Epoch finished ! Loss: 0.015249409072566777
Starting epoch 2503/4000.
Epoch finished ! Loss: 0.014904992748051881
Starting epoch 2504/4000.
Epoch finished ! Loss: 0.01683612884953618
Starting epoch 2505/4000.
Epoch finished ! Loss: 0.014513105060905218
Starting epoch 2506/4000.
Epoch finished ! Loss: 0.01543431639438495
Starting epoch 2507/4000.
Epoch finished ! Loss: 0.01505452535348013
Starting epoch 2508/4000.
Epoch finished ! Loss: 0.014781522809062152
Starting epoch 2509/4000.
Epoch finished ! Loss: 0.0157718924456276
Starting epoch 2510/4000.
Epoch finished ! Loss: 0.015243373322300613
Starting epoch 2511/4000.
Epoch finished ! Loss: 0.015379466174636036
Starting epoch 2512/4000.
Epoch finished ! Loss: 0.

Epoch finished ! Loss: 0.01404972147429362
Starting epoch 2617/4000.
Epoch finished ! Loss: 0.015908778330776842
Starting epoch 2618/4000.
Epoch finished ! Loss: 0.014283477352000773
Starting epoch 2619/4000.
Epoch finished ! Loss: 0.01527911180164665
Starting epoch 2620/4000.
Epoch finished ! Loss: 0.01585986256832257
Starting epoch 2621/4000.
Epoch finished ! Loss: 0.014371758315246553
Starting epoch 2622/4000.
Epoch finished ! Loss: 0.01591038638725877
Starting epoch 2623/4000.
Epoch finished ! Loss: 0.014686023711692541
Starting epoch 2624/4000.
Epoch finished ! Loss: 0.015388512494973838
Starting epoch 2625/4000.
Epoch finished ! Loss: 0.014690470322966576
Starting epoch 2626/4000.
Epoch finished ! Loss: 0.014529076823964714
Starting epoch 2627/4000.
Epoch finished ! Loss: 0.014758236811030657
Starting epoch 2628/4000.
Epoch finished ! Loss: 0.014519248122815043
Starting epoch 2629/4000.
Epoch finished ! Loss: 0.015944201475940646
Starting epoch 2630/4000.
Epoch finished ! Loss: 0

In [None]:
# Save the trained network's weights to `model_save_path` as 'sony<N>.pth'.
# `model` and `epoch` come from an earlier training cell (not visible here);
# the file is named with `epoch + 1`, presumably turning a 0-based loop index
# into a 1-based epoch count (assumption — confirm against the training loop).
# os.makedirs(..., exist_ok=True) is a no-op when the directory already exists,
# so no isdir() check or duplicated save call is needed.
os.makedirs(model_save_path, exist_ok=True)
# os.path.join works whether or not model_save_path ends with a separator.
torch.save(model.state_dict(),
           os.path.join(model_save_path, 'sony{}.pth'.format(epoch + 1)))