In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import os,sys
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import helper
import simulation_dfm2
import math
from os.path import splitext
from os import listdir
from glob import glob

import torch
# Release cached GPU memory left over from earlier runs in this kernel.
torch.cuda.empty_cache()

# Dataset locations. Keep the trailing slashes: dataset code may build file paths
# by plain string concatenation onto these directory names.
img_dir_train = 'train/img/'   # was 'trian/img/' (typo)
obj_dir_train = 'train/gt/'
img_dir_val = 'val/img/'
# Validation ground truth. Previously pointed at 'val/img/', which made the
# validation loss compare predictions against the inputs themselves.
# NOTE(review): assumed to mirror train/gt -- confirm the directory layout.
obj_dir_val = 'val/gt/'

In [None]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets, models

class SimDataset(Dataset):
    """Paired (input image, ground-truth image) dataset read from two directories.

    Each directory is listed (hidden files and sub-directories skipped) and
    sorted by file name; the i-th input is paired with the i-th ground-truth
    file. Both directories must therefore hold the same number of files whose
    sorted orders correspond (e.g. identical basenames).

    Args:
        img_dir: directory of input images.
        obj_dir: directory of ground-truth images (may equal ``img_dir``).
        transform: optional callable applied to both images (e.g. ToTensor).

    Raises:
        ValueError: if the two directories hold different numbers of files.
    """

    def __init__(self, img_dir, obj_dir, transform=None):
        self.img_dir = img_dir
        self.obj_dir = obj_dir
        self.transform = transform
        # os.listdir order is arbitrary and may differ between directories,
        # so sort to keep inputs and ground truths aligned by index.
        self.img_paths = self._list_files(img_dir)
        self.obj_paths = self._list_files(obj_dir)
        if len(self.img_paths) != len(self.obj_paths):
            raise ValueError(
                "img_dir {!r} has {} files but obj_dir {!r} has {}; "
                "images cannot be paired by index".format(
                    img_dir, len(self.img_paths), obj_dir, len(self.obj_paths)))
        # File stems, kept for callers that inspect them.
        self.img_ids = [splitext(os.path.basename(p))[0] for p in self.img_paths]
        self.obj_ids = [splitext(os.path.basename(p))[0] for p in self.obj_paths]

    @staticmethod
    def _list_files(directory):
        """Return sorted full paths of the non-hidden regular files in ``directory``."""
        return [os.path.join(directory, name)
                for name in sorted(listdir(directory))
                if not name.startswith('.')
                and os.path.isfile(os.path.join(directory, name))]

    @staticmethod
    def _load(path):
        """Open an image and return an in-memory copy so the file handle is released."""
        with Image.open(path) as im:
            return im.copy()

    def __len__(self):
        return len(self.img_ids)

    def __getitem__(self, idx):
        """Return ``[img, obj]`` for index ``idx``, both passed through ``transform`` if set."""
        obj = self._load(self.obj_paths[idx])
        img = self._load(self.img_paths[idx])

        if self.transform:
            img = self.transform(img)
            obj = self.transform(obj)

        return [img, obj]

# use same transform for train/val for this example
trans = transforms.Compose([
    transforms.ToTensor(),
])

train_set = SimDataset(img_dir_train, obj_dir_train, transform = trans)
val_set = SimDataset(img_dir_val, obj_dir_val, transform = trans)

image_datasets = {
    'train': train_set, 'val': val_set
}

batch_size = 15

dataloaders = {
    'train': DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0),
    # Validation needs no random order; a fixed order keeps the evaluation
    # (and any per-batch loss reported from it) reproducible across runs.
    'val': DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=0)
}

# Number of samples per split.
dataset_sizes = {
    name: len(dataset) for name, dataset in image_datasets.items()
}

dataset_sizes

In [None]:
import torchvision.utils

# Pull a single training batch and report its shapes, then min/max/mean/std
# of the inputs followed by the masks.
inputs, masks = next(iter(dataloaders['train']))
print(inputs.shape, masks.shape)
for tensor in (inputs, masks):
    values = tensor.numpy()
    print(values.min(), values.max(), values.mean(), values.std())

In [None]:
from torchsummary import summary
import torch
import torch.nn as nn
import pytorch_unet

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# One output class/channel, matching the num_class used for training.
model = pytorch_unet.UNet(1)
model = model.to(device)
# torchsummary.summary defaults to device="cuda" and builds CUDA dummy inputs,
# which fails on CPU-only machines; pass the device actually in use.
summary(model, input_size=(1, 256, 256), device=device.type)

In [None]:
from pytorch_ssim import ssim
from torch import optim
import pytorch_msssim
from torchvision.transforms.functional import to_tensor

metric = 'MSSSIM' # MSSSIM or SSIM

def calc_loss(pred, target, metric='MSSSIM', ssim_weight=0.8, l1_weight=0.2):
    """Blend a structural-similarity loss with an L1 loss.

    loss = ssim_weight * (1 - S(pred, target)) + l1_weight * L1(pred, target)

    where S is multi-scale SSIM (``'MSSSIM'``) or single-scale SSIM (``'SSIM'``).

    Args:
        pred: model output tensor.
        target: ground-truth tensor, same shape as ``pred``.
        metric: ``'MSSSIM'`` or ``'SSIM'``.
        ssim_weight: weight of the similarity term (default 0.8).
        l1_weight: weight of the L1 term (default 0.2).

    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: if ``metric`` is neither ``'MSSSIM'`` nor ``'SSIM'``.
    """
    if metric == 'MSSSIM':
        similarity = pytorch_msssim.MSSSIM()
    elif metric == 'SSIM':
        # pytorch_ssim.ssim is a function of (img1, img2); the previous code
        # called ssim() with no arguments, which raised TypeError.
        similarity = ssim
    else:
        raise ValueError("metric must be 'MSSSIM' or 'SSIM', got {!r}".format(metric))
    loss_l1 = nn.L1Loss()
    return ssim_weight * (1 - similarity(pred, target)) + l1_weight * loss_l1(pred, target)

def _run_phase(model, optimizer, phase, metric):
    """Run one pass over ``dataloaders[phase]`` and return the per-sample mean loss.

    Gradients are tracked and the optimizer is stepped only when
    ``phase == 'train'``. Reads the notebook-level ``dataloaders`` and ``device``.

    Raises:
        RuntimeError: if the loader yields no samples (the mean is undefined).
    """
    running_loss = 0.0
    epoch_samples = 0

    for inputs, labels in dataloaders[phase]:
        inputs = inputs.to(device)
        labels = labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward; track history only in train
        with torch.set_grad_enabled(phase == 'train'):
            outputs = model(inputs)
            loss = calc_loss(outputs, labels, metric)
            if phase == 'train':
                loss.backward()
                optimizer.step()

        n_batch = inputs.size(0)
        # Weight by batch size so a short final batch does not skew the mean.
        running_loss += loss.item() * n_batch
        epoch_samples += n_batch

    if epoch_samples == 0:
        raise RuntimeError("dataloader for phase {!r} yielded no samples".format(phase))
    return running_loss / epoch_samples


def train_model(model, optimizer, scheduler, num_epochs=25, metric='MSSSIM'):
    """Train ``model`` and return it with the lowest-validation-loss weights loaded.

    Each epoch runs a 'train' pass then a 'val' pass over the notebook-level
    ``dataloaders`` (batches moved to the global ``device``) with
    ``calc_loss(..., metric)`` as the objective.

    Args:
        model: network to train; updated in place.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: LR scheduler, stepped once per epoch after the training pass.
        num_epochs: number of epochs to run.
        metric: similarity metric forwarded to ``calc_loss`` ('MSSSIM' or 'SSIM').

    Returns:
        ``(model, loss_list_train, loss_list_val)``: the model with the best
        validation weights loaded, plus the per-epoch mean loss of each phase.
    """
    # Local imports so this cell does not depend on a later cell having run.
    import copy
    import time

    best_model_wts = copy.deepcopy(model.state_dict())
    # inf (not 1) so the first validation epoch always becomes the baseline,
    # even when the loss is >= 1.
    best_loss = float('inf')

    loss_list_train = []
    loss_list_val = []

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        since = time.time()

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                for param_group in optimizer.param_groups:
                    print("LR", param_group['lr'])
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            epoch_loss = _run_phase(model, optimizer, phase, metric)
            print(epoch_loss, phase)

            if phase == 'train':
                # Since PyTorch 1.1 scheduler.step() belongs after optimizer.step();
                # calling it first skips the initial learning-rate value.
                scheduler.step()
                loss_list_train.append(epoch_loss)
            else:
                loss_list_val.append(epoch_loss)

            # deep copy the model
            if phase == 'val' and epoch_loss < best_loss:
                print("saving best model")
                best_loss = epoch_loss
                best_model_wts = copy.deepcopy(model.state_dict())

        time_elapsed = time.time() - since
        print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val loss: {:.4f}'.format(best_loss))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model, loss_list_train, loss_list_val

In [None]:
import torch
from torch.optim import lr_scheduler
import time
import copy
import pickle

# Use the first GPU when one is present, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

num_class = 1
num_epochs = 100

# Fresh U-Net; every one of its parameters is handed to Adam.
model = pytorch_unet.UNet(num_class).to(device)
optimizer_ft = optim.Adam(model.parameters(), lr=1e-4)
# Multiply the LR by 0.1 every 20 scheduler steps (train_model steps once per epoch).
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=20, gamma=0.1)

model, loss_list_train, loss_list_val = train_model(
    model,
    optimizer_ft,
    exp_lr_scheduler,
    num_epochs=num_epochs,
)

## save model
f_save = '.../model.pt'
# Saves the whole pickled module (not just a state_dict); reloading it with a
# plain torch.load, as the reconstruction cell does, depends on this format.
torch.save(model, f_save)

## save loss: one value per line, one file per phase
for loss_path, losses in (('.../loss_train.txt', loss_list_train),
                          ('.../loss_val.txt', loss_list_val)):
    # `with` closes the file even if a write fails.
    with open(loss_path, 'w') as loss_file:
        for item in losses:
            loss_file.write(str(item) + '\n')

In [None]:
# reconstruction
f_model_load = '.../model.pt'# model address
# NOTE: torch.load unpickles arbitrary objects -- only load trusted checkpoints.
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
# NOTE(review): on PyTorch >= 2.6 loading a full pickled module may additionally
# need weights_only=False.
model_trained = torch.load(f_model_load, map_location=device)
model_trained.eval()   # Set model to evaluate mode
img_dir_pred = ''# low resolution image location
obj_dir_pred = '' # ground truth image location, set to img_dir_pred when loss calculation is not needed
pred_batch_size = 20
test_dataset = SimDataset(img_dir_pred, obj_dir_pred, transform = trans)
test_loader = DataLoader(test_dataset, batch_size=pred_batch_size, shuffle=False, num_workers=0)

# Only the first batch is reconstructed.
inputs, labels = next(iter(test_loader))
inputs = inputs.to(device)
labels = labels.to(device)

# Inference only: no_grad avoids building an autograd graph for the forward pass.
with torch.no_grad():
    pred = model_trained(inputs)
pred = pred.cpu().numpy()
print(inputs.shape, labels.shape, pred.shape)

# The dataset may hold fewer than pred_batch_size images, so iterate over what
# the batch actually contains. recon[0] is channel 0 of each prediction
# (assumes an NCHW output -- confirm against the model).
for i, recon in enumerate(pred):
    np.savetxt('.../recon/recon' + str(i) + '.txt', recon[0])