In [1]:
# Modified from ninetf135246's PyTorch implementation of "Learning to See in the Dark":
# https://github.com/ninetf135246/pytorch-Learning-to-See-in-the-Dark

In [None]:
from __future__ import division
import os, scipy.io
import numpy as np
import logging
import argparse
import sys
from SID_data import SIDData, get_img
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from SID_model import Unet
from datetime import datetime
from SID_GradLoss import GradLoss
from tqdm import tqdm
import pdb
from PIL import Image
import csv
from torchvision import transforms
from torch.nn import init
import cv2
import matplotlib.pyplot as plt

In [2]:
# Restrict this process to the GPU with index 1. CUDA reads this variable when it
# initialises, so the assignment has to run before the first CUDA call.
os.environ.update(CUDA_VISIBLE_DEVICES="1")

In [3]:
def train(inputID_dir, valID_dir, batch_size, lr, wd, num_epoch, log_interval, save_freq,
          model_save_freq):
    """Train a ``Unet`` with an L1 loss, logging per-epoch losses and keeping the best weights.

    Side effects (all under ``./output``, which is created if missing):
      * ``train_val_loss.csv`` -- one row per epoch: epoch, mean train loss, mean val loss.
      * ``image_<epoch>.png``  -- every second epoch, the model output for
        ``./dark_cat.jpg`` placed side by side with ``./cat.jpg``.
      * ``best_model.pth``     -- ``state_dict`` of the best model seen so far.

    Args:
        inputID_dir: Argument forwarded to ``SIDData`` to build the training set.
        valID_dir: Argument forwarded to ``SIDData`` to build the validation set.
        batch_size: Training batch size (validation always uses batch size 1).
        lr: Learning rate passed to Adam.
        wd: Weight decay passed to Adam.
        num_epoch: Number of epochs to run.
        log_interval: Accepted for interface compatibility; not used here.
        save_freq: Accepted for interface compatibility; not used here.
        model_save_freq: Accepted for interface compatibility; not used here.

    Raises:
        ValueError: If the training or validation loader yields no batches.
        FileNotFoundError: If ``./cat.jpg`` cannot be read.
    """
    # Make sure the output directory exists before the first write.
    os.makedirs("./output", exist_ok=True)

    # newline='' is what the csv module expects; it avoids blank rows on Windows.
    with open("./output/train_val_loss.csv", "w", newline="") as f:
        writer = csv.writer(f, delimiter=',')
        writer.writerow(['epoch', 'train_loss', 'val_loss'])

    device = torch.device("cuda:0")
    trainset = SIDData(inputID_dir)
    train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=12)

    # validation
    valset = SIDData(valID_dir)
    val_loader = DataLoader(valset, batch_size=1, shuffle=False)

    # Fail early with a clear message instead of a NameError/ZeroDivisionError later on.
    if len(train_loader) == 0:
        raise ValueError("training loader is empty: {}".format(inputID_dir))
    if len(val_loader) == 0:
        raise ValueError("validation loader is empty: {}".format(valID_dir))

    logging.info("data loading okay")

    # model
    model = Unet().to(device)

    # initialization: Xavier-uniform for every parameter with at least two dimensions
    for param in model.parameters():
        if len(param.shape) >= 2:
            init.xavier_uniform_(param)

    # resume
    starting_epoch = 0

    # loss function
    color_loss = nn.L1Loss()
    # NOTE(review): constructed as in the original, but never added to the objective.
    gradient_loss = GradLoss(device)

    # optimizer
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=wd)

    # lr scheduler
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.1)

    val_loss_min = float('inf')

    # the image pair used for visualization
    input_visual = get_img('./dark_cat.jpg').to(device)
    target_visual = cv2.imread('./cat.jpg')
    if target_visual is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError("could not read './cat.jpg'")

    # Epochs are numbered 1..num_epoch inclusive, so exactly num_epoch epochs run.
    for epoch in tqdm(range(starting_epoch + 1, starting_epoch + num_epoch + 1)):
        # ---- training pass ----
        model.train()
        train_loss_sum, train_seen = 0.0, 0
        for input_img, output_img in train_loader:
            input_img, output_img = input_img.to(device), output_img.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(input_img)
            batch_loss = color_loss(outputs, output_img)
            batch_loss.backward()
            optimizer.step()

            # Weight by batch size so a smaller final batch does not skew the mean.
            n_samples = input_img.size(0)
            train_loss_sum += batch_loss.item() * n_samples
            train_seen += n_samples
        train_loss = train_loss_sum / train_seen

        # Step the LR schedule once per epoch, after the optimizer updates.
        scheduler.step()

        # ---- validation and visualization (no gradients needed) ----
        model.eval()
        val_loss_sum, val_seen = 0.0, 0
        with torch.no_grad():
            for input_img, output_img in val_loader:
                input_img, output_img = input_img.to(device), output_img.to(device)
                outputs = model(input_img)
                n_samples = input_img.size(0)
                val_loss_sum += color_loss(outputs, output_img).item() * n_samples
                val_seen += n_samples

            if epoch % 2 == 0:
                pic_output = model(input_visual.unsqueeze(dim=0)) * 255
                pic_output = pic_output.squeeze(dim=0).permute((1, 2, 0)).cpu().numpy()
                # Convert to 8-bit explicitly instead of handing cv2.imwrite a float array.
                pic_output = np.clip(np.rint(pic_output), 0, 255).astype(np.uint8)
                pic_output = np.concatenate([pic_output, target_visual], axis=1)
                save_path = './output/image_{}.png'.format(epoch)
                if not cv2.imwrite(save_path, pic_output):
                    logging.warning("failed to write %s", save_path)
        val_loss = val_loss_sum / val_seen

        # save epoch number, train_loss, val_loss to a csv file
        with open("./output/train_val_loss.csv", "a", newline="") as f:
            writer = csv.writer(f, delimiter=',')
            writer.writerow([epoch, train_loss, val_loss])

        # Save the best model so far: validation loss must improve on the previous
        # best and stay within 5% of the training loss.
        losses_close = abs(train_loss - val_loss) <= 0.05 * train_loss
        if val_loss < val_loss_min and losses_close:
            val_loss_min = val_loss
            torch.save(model.state_dict(), './output/best_model.pth')


In [None]:
# Paths handed to the dataset for the training and validation splits
# (presumably text files of sample IDs -- confirm against SID_data).
inputID_dir, valID_dir = './data/train_ID.txt', './data/val_ID.txt'

# Optimisation settings; the names match the parameters of train().
batch_size, num_epoch = 6, 500
lr, wd = 1e-4, 0

# Reporting / checkpoint intervals. train() accepts these but, as written in this
# notebook, does not read them.
log_interval = save_freq = model_save_freq = 5

# Results and checkpoints currently point at the same directory.
result_dir = checkpoint_dir = './result_grad_loss/'