In [1]:
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [2]:
import ipdb

from DehazeNet import DehazeNet
from torch.utils.data import DataLoader
import torch.nn as nn
from torch import optim
from torch.autograd import Variable
from torchnet import meter
import config
import torchvision.utils
import torch
from config import Config
from DehazingSet import DehazingSet

In [3]:
def train(opt):
    #step1: model
    model = DehazeNet(opt.kernel_size, opt.rate_num, opt.conv, opt.ranking).cuda()
    if opt.load_model_path:
        model.load(opt.load_model_path)
    
    #step2: dataset
    train_set = DehazingSet(opt.train_data_root)
    val_set = DehazingSet(opt.val_data_root)
    train_dataloader = DataLoader(train_set, opt.batch_size, shuffle = True, num_workers = opt.num_workers)
    val_dataloader = DataLoader(val_set, opt.val_batch_size, shuffle = True, num_workers = opt.num_workers)
    
    #step3: Loss function and Optimizer
    criterion = nn.MSELoss().cuda()
    lr = opt.lr #current learning rate
    optimizer = optim.Adam(model.parameters(), lr = lr, weight_decay = opt.weight_decay)
    
    # metrics
    total_loss = 0
    previous_loss = 1e100
    
    model.train()  #train mode
    
    
    #step5: train
    for epoch in range(opt.max_epoch):
        total_loss = 0
        
        for iteration, (hazy_img, gt_img) in enumerate(train_dataloader):

            input_data = hazy_img.cuda()
            target_data = gt_img.cuda()
            
            print("iteration {} (before back): {}".format(iteration, torch.cuda.memory_allocated()/10e6))
            output_result = model(input_data)
            loss = criterion(output_result, target_data)
            
            optimizer.zero_grad()
            loss.backward()
            print("iteration {} (after back): {}".format(iteration, torch.cuda.memory_allocated()/10e6))
            ipdb.set_trace()
            optimizer.step()
            
            total_loss += loss.detach()
            
            if (iteration + 1) % opt.display_iter == 0:
                print("Loss at iteration {}: {}".format(iteration, loss))
            if (iteration + 1) == len(train_dataloader):
                torchvision.utils.save_image(torch.cat((input_data.data, target_data.data, output_result.data), dim = 0))
                
        print("Training Set Loss at Epoch {}: {}".format(epoch, total_loss))
        model.save(strftime('%m%d_%H:%M:%S') + '_Epoch:' + epoch)
        
        
        val_loss = val(model, val_dataloader)
        print("Val Set Loss at Epoch {}: {}".format(epoch, val_loss))
        
        #if loss does not decrease, decrease learning rate
        if loss_meter.value()[0] > previous_loss:
            lr = lr * opt.lr_decay
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
                
       
                  

In [4]:
def val(model, dataloader):
    model.eval() #evaluation mode
    
    loss_meter = meter.AverageValueMeter()
    for iteration, (hazy_img, gt_img) in enumerate(dataloader):
        input_data = hazy_img.cuda()
        target_data = gt_img.cuda()
        
        output_result = model(input_data)
        
        #TODO: SSIM and PSNR test
        loss = nn.MSELoss()(input_data, target_data)
        loss_meter.add(loss.data[0])
    
    model.train() #back to train mode
    
    return loss_meter.value()[0]

In [5]:
if __name__ == '__main__':
    config = Config()
    train(config)

iteration 0 (before back): 0.5720576
iteration 0 (after back): 0.9298432
> [0;32m<ipython-input-3-1dec2a178921>[0m(41)[0;36mtrain[0;34m()[0m
[0;32m     40 [0;31m            [0mipdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0m
[0m[0;32m---> 41 [0;31m            [0moptimizer[0m[0;34m.[0m[0mstep[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0m
[0m[0;32m     42 [0;31m[0;34m[0m[0m
[0m
ipdb> c
iteration 1 (before back): 1.645312
iteration 1 (after back): 1.8550784
> [0;32m<ipython-input-3-1dec2a178921>[0m(40)[0;36mtrain[0;34m()[0m
[0;32m     39 [0;31m            [0mprint[0m[0;34m([0m[0;34m"iteration {} (after back): {}"[0m[0;34m.[0m[0mformat[0m[0;34m([0m[0miteration[0m[0;34m,[0m [0mtorch[0m[0;34m.[0m[0mcuda[0m[0;34m.[0m[0mmemory_allocated[0m[0;34m([0m[0;34m)[0m[0;34m/[0m[0;36m10e6[0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0m
[0m[0;32m---> 40 [0;31m            [0mipdb[0m[0;34m.[0m[0mset_trace[0m[0;34m(

RuntimeError: CUDA out of memory. Tried to allocate 98.88 MiB (GPU 0; 5.93 GiB total capacity; 4.61 GiB already allocated; 27.25 MiB free; 519.16 MiB cached)

In [None]:
%debug

> [0;32m/home/ws/anaconda3/envs/pytorch/lib/python3.6/site-packages/torchvision/transforms/functional.py[0m(208)[0;36mnormalize[0;34m()[0m
[0;32m    206 [0;31m    [0mmean[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mtensor[0m[0;34m([0m[0mmean[0m[0;34m,[0m [0mdtype[0m[0;34m=[0m[0mtorch[0m[0;34m.[0m[0mfloat32[0m[0;34m)[0m[0;34m[0m[0m
[0m[0;32m    207 [0;31m    [0mstd[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mtensor[0m[0;34m([0m[0mstd[0m[0;34m,[0m [0mdtype[0m[0;34m=[0m[0mtorch[0m[0;34m.[0m[0mfloat32[0m[0;34m)[0m[0;34m[0m[0m
[0m[0;32m--> 208 [0;31m    [0mtensor[0m[0;34m.[0m[0msub_[0m[0;34m([0m[0mmean[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m [0;32mNone[0m[0;34m,[0m [0;32mNone[0m[0;34m][0m[0;34m)[0m[0;34m.[0m[0mdiv_[0m[0;34m([0m[0mstd[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m [0;32mNone[0m[0;34m,[0m [0;32mNone[0m[0;34m][0m[0;34m)[0m[0;34m[0m[0m
[0m[0;32m    209 [0;31m    [0;32mreturn[0m 

ipdb> u
> [0;32m/home/ws/anaconda3/envs/pytorch/lib/python3.6/site-packages/torchvision/transforms/transforms.py[0m(163)[0;36m__call__[0;34m()[0m
[0;32m    161 [0;31m            [0mTensor[0m[0;34m:[0m [0mNormalized[0m [0mTensor[0m [0mimage[0m[0;34m.[0m[0;34m[0m[0m
[0m[0;32m    162 [0;31m        """
[0m[0;32m--> 163 [0;31m        [0;32mreturn[0m [0mF[0m[0;34m.[0m[0mnormalize[0m[0;34m([0m[0mtensor[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mmean[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mstd[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0minplace[0m[0;34m)[0m[0;34m[0m[0m
[0m[0;32m    164 [0;31m[0;34m[0m[0m
[0m[0;32m    165 [0;31m    [0;32mdef[0m [0m__repr__[0m[0;34m([0m[0mself[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0m
[0m
ipdb> l 155,163
[1;32m    155 [0m    [0;32mdef[0m [0m__call__[0m[0;34m([0m[0mself[0m[0;34m,[0m [0mtensor[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0m
[1;32m    156 [0m        """
[1;32m    157 [0

ipdb> u
> [0;32m/home/ws/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py[0m(615)[0;36m<listcomp>[0;34m()[0m
[0;32m    613 [0;31m        [0;32mif[0m [0mself[0m[0;34m.[0m[0mnum_workers[0m [0;34m==[0m [0;36m0[0m[0;34m:[0m  [0;31m# same-process loading[0m[0;34m[0m[0m
[0m[0;32m    614 [0;31m            [0mindices[0m [0;34m=[0m [0mnext[0m[0;34m([0m[0mself[0m[0;34m.[0m[0msample_iter[0m[0;34m)[0m  [0;31m# may raise StopIteration[0m[0;34m[0m[0m
[0m[0;32m--> 615 [0;31m            [0mbatch[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mcollate_fn[0m[0;34m([0m[0;34m[[0m[0mself[0m[0;34m.[0m[0mdataset[0m[0;34m[[0m[0mi[0m[0;34m][0m [0;32mfor[0m [0mi[0m [0;32min[0m [0mindices[0m[0;34m][0m[0;34m)[0m[0;34m[0m[0m
[0m[0;32m    616 [0;31m            [0;32mif[0m [0mself[0m[0;34m.[0m[0mpin_memory[0m[0;34m:[0m[0;34m[0m[0m
[0m[0;32m    617 [0;31m                [0mbatch[0m 