In [None]:
%reload_ext autoreload
%autoreload 2
import re
import os, glob, datetime as dt
import numpy as np
import torch
import torch.nn as nn
from torch.nn.modules.loss import _Loss
import torch.nn.init as init
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR
from helpers import *
import random
cuda = torch.cuda.is_available()  # True when PyTorch can use a CUDA device in this process

In [None]:
# Training configuration.
n_epoch = 50        # total number of training epochs (loop runs up to this index)
lr = 0.005          # initial learning rate passed to the optimizer
batch_size = 32
# NOTE(review): this is a fixed 1000 batches' worth of samples (32,000), not
# "all available" data; presumably read_matlab treats it as an upper bound on
# the number of patches it loads - confirm in helpers.
n_train_data = 1000 * batch_size
# NOTE(review): the data-loading cell reassigns data_dir before it is used.
data_dir = 'data/train'
# Directory where checkpoints are written and scanned when resuming
# (previously referenced by the setup cell but never defined -> NameError).
save_dir = 'saved_models'

In [None]:
class DnCNN(nn.Module):
    """Plain DnCNN-style CNN mapping an image tensor to an image tensor.

    Layout: Conv+ReLU, then (n_layers - 2) x [Conv -> BatchNorm (optional) -> ReLU],
    then a final Conv back to ``image_channels``. ``forward`` returns the network
    output directly (it does not subtract it from the input).

    Args:
        n_layers: total number of Conv2d layers.
        n_channels: number of feature channels in the hidden layers.
        image_channels: channels of the input and of the output.
        use_bnorm: if True (default), add BatchNorm2d after each hidden Conv2d;
            if False, those hidden convolutions carry a bias instead.
        kernel_size: convolution kernel size. Padding is ``kernel_size // 2``,
            so odd sizes keep height and width unchanged.
    """

    def __init__(self, n_layers = 17, n_channels=64, image_channels = 1, use_bnorm = True, kernel_size = 3):
        super().__init__()
        # "Same" padding for odd kernels (kernel 3 -> padding 1).
        padding = kernel_size // 2
        layers = [
            nn.Conv2d(in_channels=image_channels, out_channels=n_channels,
                      kernel_size=kernel_size, padding=padding, bias=True),
            nn.ReLU(inplace=True),
        ]
        for _ in range(n_layers - 2):
            # A conv bias is redundant when BatchNorm (which has its own shift) follows.
            layers.append(nn.Conv2d(in_channels=n_channels, out_channels=n_channels,
                                    kernel_size=kernel_size, padding=padding,
                                    bias=not use_bnorm))
            if use_bnorm:
                # NOTE(review): PyTorch's `momentum` weights the NEW batch statistic
                # (running = (1 - m) * running + m * batch), so 0.95 makes the running
                # stats track almost only the latest batch. If this value was ported
                # from a framework where momentum weights the old value, the PyTorch
                # equivalent is 0.05 - confirm before changing (affects eval mode).
                layers.append(nn.BatchNorm2d(n_channels, eps=0.0001, momentum=0.95))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Conv2d(in_channels=n_channels, out_channels=image_channels,
                                kernel_size=kernel_size, padding=padding, bias=False))
        self.dncnn = nn.Sequential(*layers)
        self._initialize_weights()

    def forward(self, x):
        """Apply the convolutional stack to ``x`` and return its output."""
        return self.dncnn(x)

    def _initialize_weights(self):
        """Orthogonal init for conv weights, zero biases, identity BatchNorm affine."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.orthogonal_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
        print('weights initialized')


In [None]:
def findLastCheckpoint(save_dir='saved_models'):
    """Return the highest epoch number among checkpoints saved in ``save_dir``.

    Checkpoints are files named ``model_<N>.pth`` with an integer ``<N>``
    (the training loop writes zero-padded names such as ``model_007.pth``).
    Files that match the glob but not this exact pattern are ignored.

    Args:
        save_dir: directory to scan. Defaults to ``'saved_models'``, the
            directory the training loop writes to.

    Returns:
        The largest ``<N>`` found, or 0 when there are no checkpoints
        (including when the directory does not exist).
    """
    epochs_exist = []
    for path in glob.glob(os.path.join(save_dir, "model_*.pth")):
        # Anchor on the file name only, so directory names cannot confuse the match
        # and non-numeric suffixes (e.g. model_best.pth) are skipped, not int()-crashed.
        match = re.fullmatch(r"model_(\d+)\.pth", os.path.basename(path))
        if match:
            epochs_exist.append(int(match.group(1)))
    return max(epochs_exist, default=0)

def log(e, l, t, r = 1):
    """Print one timestamped progress line for a finished epoch.

    Args:
        e: epoch number (formatted as an integer).
        l: loss value to report (formatted with two decimals).
        t: elapsed time; anything with ``total_seconds()`` (the training loop
           passes a ``datetime.timedelta``), shown rounded to whole seconds.
        r: rate value shown with one decimal (defaults to 1).
    """
    seconds = round(t.total_seconds())
    stamp = dt.datetime.now().strftime("%Y-%m-%d %H:%M:%S   ")
    message = f"epoch = {e:2d}, loss = {l:8.2f}, time = {seconds:4d} seconds, rate = {r:1.1f}"
    print(stamp, message)


In [None]:
# Build a fresh network, then replace it with the latest checkpoint if one exists.
model = DnCNN()

# findLastCheckpoint returns the largest N among saved model_NNN.pth files
# (0 when there are none). The training loop saves the model after loop index
# `epoch` as model_{epoch+1:03d}.pth, so model_N.pth holds the weights after N
# completed epochs and training must resume at loop index N from that file.
initial_epoch = findLastCheckpoint(save_dir=save_dir)
if initial_epoch > 0:
    print('resuming from epoch %03d\n' % initial_epoch)
    # NOTE(review): this unpickles a whole nn.Module - only load trusted files.
    # Recent PyTorch releases default torch.load to weights_only=True, which
    # rejects full-module pickles; saving/loading state_dict() is the portable
    # alternative.
    model = torch.load(os.path.join(save_dir, 'model_%03d.pth' % initial_epoch))
    if initial_epoch >= n_epoch:
        print("done")
model.train()
criterion = nn.MSELoss()
# Use the GPU when one is available (`cuda` flag from the imports cell) instead
# of requiring it unconditionally.
model = model.to(torch.device('cuda' if cuda else 'cpu'))
optimizer = optim.Adam(model.parameters(), lr=lr)
# NOTE(review): optimizer state (Adam moment estimates) is not checkpointed, so
# a resumed run restarts it from scratch.
scheduler = MultiStepLR(optimizer, milestones=[20, 30, 40], gamma=0.2)


In [None]:
from helpers import read_matlab

# Dataset location and patch size for this run (data_dir here replaces the
# value set in the config cell).
data_dir = "./data/landmass1"
patch_size = 99

patches_raw = read_matlab(data_dir, patch_size, n_train_data)
print(patches_raw.shape)
# Move axis 3 to position 1 for Conv2d's (N, C, H, W) layout - assumes
# read_matlab returns (N, H, W, C); TODO confirm in helpers.
xs = torch.from_numpy(np.transpose(patches_raw, (0, 3, 1, 2)))
print(xs.shape)

In [None]:
# Train for the remaining epochs. Each epoch draws one rate `r` and builds a
# fresh DownsampleDataset/DataLoader for it.
device = next(model.parameters()).device  # send batches wherever the model lives
if initial_epoch > 0:
    # A resumed run recreates the scheduler from scratch; fast-forward it so the
    # first resumed epoch trains at the LR an uninterrupted run would have used.
    scheduler.step(initial_epoch - 1)

for epoch in range(initial_epoch, n_epoch):
    r = random.choice([0.6, 0.7, 0.8, 2, 3, 4, 5])
    DDataset = DownsampleDataset(xs, r)
    DLoader = DataLoader(dataset=DDataset, num_workers=4, drop_last=True,
                         batch_size=batch_size, shuffle=True)
    epoch_loss = 0.0  # sum (not mean) of per-batch losses; this is what log() prints
    start_time = dt.datetime.now()

    for batch_yx in DLoader:
        # presumably (network input, target) pairs from DownsampleDataset - confirm in helpers
        batch_y = batch_yx[0].to(device)
        batch_x = batch_yx[1].to(device)
        optimizer.zero_grad()
        loss = criterion(model(batch_y), batch_x)
        epoch_loss += loss.item()
        loss.backward()
        optimizer.step()

    # Explicit-epoch form uses MultiStepLR's closed-form LR; PyTorch warns it is deprecated.
    scheduler.step(epoch)
    elapsed_time = dt.datetime.now() - start_time
    log(epoch + 1, epoch_loss, elapsed_time, r)
    # Same directory findLastCheckpoint scans; create it on the first save.
    os.makedirs("saved_models", exist_ok=True)
    torch.save(model, "saved_models/model_{:03d}.pth".format(epoch + 1))
