In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import torch
import torchvision
from torchvision import transforms, datasets
from torch.utils.data.sampler import SubsetRandomSampler
import argparse
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image

In [2]:
class Net(nn.Module):
    """Conditioning network mapping low-resolution images to per-pixel logits.

    Input ``(N, 3, H, W)`` -> output ``(N, 3*256, 4H, 4W)``. Colour channel
    ``c`` owns output channels ``[c*256, (c+1)*256)``, one logit per
    intensity level.

    Layout:
    - a 1x1 stem conv;
    - three stages of residual blocks;
    - a stride-2 transposed conv (2x upsample) plus ReLU after each of the
      first two stages;
    - a final 1x1 conv that produces the logits.

    Residual-block layers are registered under flat attribute names of the
    form ``<prefix><stage><block>`` (e.g. ``Resconv100``). This keeps
    ``state_dict`` keys such as ``Resconv100.weight`` compatible with
    existing checkpoints.
    """

    # Stage/block counts. They are encoded as single digits in the parameter
    # names (and thus in checkpoint keys), and the stage count must equal the
    # number of Transconv layers + 1.
    _NUM_STAGES = 3
    _BLOCKS_PER_STAGE = 6

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 1, stride=1, padding=0)
        for stage in range(self._NUM_STAGES):
            for block in range(self._BLOCKS_PER_STAGE):
                # setattr on an nn.Module registers the submodule exactly as
                # ``self.<name> = module`` would, without needing exec().
                setattr(self, self._res_name("Resconv1", stage, block),
                        nn.Conv2d(32, 32, 3, stride=1, padding=1))
                setattr(self, self._res_name("Batch_norm1", stage, block),
                        nn.BatchNorm2d(32, track_running_stats=False))
                setattr(self, self._res_name("Resconv2", stage, block),
                        nn.Conv2d(32, 32, 3, stride=1, padding=1))
                setattr(self, self._res_name("Batch_norm2", stage, block),
                        nn.BatchNorm2d(32, track_running_stats=False))
        self.Transconv1 = nn.ConvTranspose2d(32, 32, 3, stride=2, padding=1, output_padding=1)
        self.Transconv2 = nn.ConvTranspose2d(32, 32, 3, stride=2, padding=1, output_padding=1)
        self.conv2 = nn.Conv2d(32, 3 * 256, 1, stride=1, padding=0)

    @staticmethod
    def _res_name(prefix, stage, block):
        """Attribute name of one residual-block layer, e.g. ('Resconv1', 0, 5) -> 'Resconv105'."""
        return "{}{}{}".format(prefix, stage, block)

    def conditioning_network(self, lr_images):
        """Run stem, residual stages, upsampling and logit head on ``lr_images``.

        ``forward()`` subtracts 0.5 from its input before calling this; direct
        callers must apply the same centring themselves.
        """
        inputs = self.conv1(lr_images)
        upsamplers = (self.Transconv1, self.Transconv2)
        for stage, upsample in enumerate(upsamplers):
            for block in range(self._BLOCKS_PER_STAGE):
                inputs = self.resnet_block(inputs, stage, block)
            inputs = F.relu(upsample(inputs))
        # Last stage runs at the final resolution, with no upsampling after it.
        last_stage = len(upsamplers)
        for block in range(self._BLOCKS_PER_STAGE):
            inputs = self.resnet_block(inputs, last_stage, block)
        return self.conv2(inputs)

    def resnet_block(self, inputs, i, j):
        """Residual block ``j`` of stage ``i``: conv-BN-ReLU-conv-BN plus identity skip.

        No activation is applied after the residual addition.
        """
        conv1 = getattr(self, self._res_name("Resconv1", i, j))
        bn1 = getattr(self, self._res_name("Batch_norm1", i, j))
        conv2 = getattr(self, self._res_name("Resconv2", i, j))
        bn2 = getattr(self, self._res_name("Batch_norm2", i, j))
        hidden = F.relu(bn1(conv1(inputs)))
        return inputs + bn2(conv2(hidden))

    def forward(self, lr_images):
        """Centre the input by subtracting 0.5, then return the conditioning logits."""
        lr_images = lr_images - 0.5
        conditioning_logits = self.conditioning_network(lr_images)
        return conditioning_logits

In [3]:
def softmax_loss(logits, labels):
    """Mean cross-entropy between per-pixel intensity logits and level labels.

    Args:
        logits: Expected layout ``(N, C*256, H, W)``, where colour channel ``c``
            owns logits ``[c*256, (c+1)*256)``.
        labels: Expected layout ``(N, C, H, W)`` holding intensity levels in
            ``0..255``; cast to int64 here.

    Returns:
        Scalar tensor: ``F.cross_entropy`` over all flattened predictions.
    """
    # Channels-last, then flatten: each run of 256 logits lines up with the
    # single (n, h, w, c) label that it predicts.
    flat_logits = logits.permute(0, 2, 3, 1).reshape(-1, 256)
    flat_labels = labels.to(torch.int64).permute(0, 2, 3, 1).reshape(-1)
    return F.cross_entropy(flat_logits, flat_labels)

def test(args, model, device, test_loader, epoch):
    """Evaluate ``model`` on ``test_loader``, print the mean loss and save samples.

    Args:
        args: Unused; kept for call-site compatibility.
        model: Network called as ``model(lr_images=...)`` returning logits.
        device: Device every batch is moved to before the forward pass.
        test_loader: Iterable of ``(lr_images, hr_images)`` batches; HR values
            are presumably in [0, 1] (they are scaled by 255 and floored into
            intensity labels).
        epoch: Used as the ``step`` tag in the saved sample filenames.

    Returns:
        Mean per-prediction cross-entropy over the whole test set (float).

    Raises:
        ValueError: If ``test_loader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    num_samples = 0
    data = target = None
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            conditioning_logits = model(lr_images=data)
            # softmax_loss returns the batch *mean*. Weight it by batch size so
            # the result is the true dataset mean, even with a short last batch.
            batch_loss = softmax_loss(conditioning_logits, torch.floor(target * 255))
            batch_size = data.size(0)
            total_loss += batch_loss.item() * batch_size
            num_samples += batch_size
    if num_samples == 0:
        raise ValueError("test_loader yielded no batches")
    test_loss = total_loss / num_samples
    print("test_loss : ", test_loss)
    # Visualise the last evaluated batch.
    sample(model, data, target, data.size(0), mu=1.1, step=epoch)
    return test_loss

def logits_2_pixel_value(logits, mu=1.1):
    """Decode intensity-level logits into expected pixel values in [0, 1].

    The logits are multiplied by ``mu`` (``mu > 1`` sharpens the
    distribution) and turned into probabilities over the last dimension. The
    result is the expected level ``0..L-1`` divided by ``L - 1``.

    Args:
        logits: Tensor whose last dimension holds the ``L`` level logits
            (256 in this notebook); leading dimensions are preserved.
        mu: Scale applied to the logits before the softmax.

    Returns:
        Tensor of shape ``logits.shape[:-1]``.

    Raises:
        ValueError: If the last dimension has fewer than 2 levels.
    """
    num_levels = logits.size(-1)
    if num_levels < 2:
        raise ValueError("need at least 2 intensity levels, got %d" % num_levels)
    probs = torch.softmax(logits * mu, dim=-1)
    # Build the level grid on the logits' device so CUDA inputs work too.
    levels = torch.arange(0, num_levels, dtype=torch.float32, device=logits.device)
    # Reduce over the same (last) axis the softmax used; the previous dim=1
    # only coincided with it for 2-D input.
    pixels = torch.sum(probs * levels, dim=-1)
    return pixels / (num_levels - 1)

def softmax(x):
    """Numerically stable softmax over the last dimension of ``x``.

    The per-row maximum is subtracted before exponentiating so large logits
    do not overflow.
    """
    row_max = x.max(dim=-1, keepdim=True)[0]
    exps = torch.exp(x - row_max)
    normaliser = exps.sum(dim=-1, keepdim=True)
    return exps / normaliser

def sample(model, data, target, batch_size, mu=1.1, step=None,
           samples_dir="/home/eee/ug/15084005/DIH/samples_ip/"):
    """Generate HR predictions for ``data`` and save LR / HR / generated images.

    Each generated pixel is the expected intensity under the model's per-pixel
    softmax, scaled by ``mu`` (see ``logits_2_pixel_value``). Only the first
    image of each batch is written (by ``save_samples``).

    Args:
        model: A ``Net``. It is called through ``model(...)`` so the same input
            centring used in ``test()`` is applied.
        data: Low-resolution input batch.
        target: Ground-truth high-resolution batch, saved for comparison.
        batch_size: Unused; kept for backward compatibility. The output batch
            size is taken from the model's logits.
        mu: Logit scale passed to ``logits_2_pixel_value``.
        step: Tag appended to the output filenames.
        samples_dir: Directory the ``.jpg`` files are written into.
    """
    with torch.no_grad():
        np_lr_imgs = data
        np_hr_imgs = target
        # Go through forward() rather than model.conditioning_network:
        # forward() subtracts 0.5 from the input (as test() does via
        # model(...)); calling conditioning_network directly skipped that.
        np_c_logits = model(np_lr_imgs)
        n, _, height, width = np_c_logits.shape
        # Channel c's logits occupy [c*256, (c+1)*256). Rearrange to rows of
        # 256 logits ordered (n, c, h, w) and decode every pixel in one call,
        # instead of a 3*H*W Python loop hard-coded to 32x32.
        per_pixel = np_c_logits.reshape(n, 3, 256, height, width).permute(0, 1, 3, 4, 2)
        pixels = logits_2_pixel_value(per_pixel.reshape(-1, 256), mu=mu)
        gen_hr_imgs = pixels.reshape(n, 3, height, width).to(torch.float32)
        print("sample")
        # NOTE: str(mu*10) gives e.g. '11.000000000000002' for mu=1.1; kept
        # unchanged so existing output filenames stay the same.
        save_samples(np_lr_imgs, samples_dir + '/lr_' + str(mu*10) + '_' + str(step))
        save_samples(np_hr_imgs, samples_dir + '/hr_' + str(mu*10) + '_' + str(step))
        save_samples(gen_hr_imgs, samples_dir + '/generate_' + str(mu*10) + '_' + str(step))

def save_samples(np_imgs, img_path):
    """Write the first image of a batch to ``<img_path>.jpg``.

    Args:
        np_imgs: 4-D image batch. Despite the name, the callers in this
            notebook pass torch tensors. Only index 0 is saved.
        img_path: Output path without the file extension.
    """
    print("save")
    first_img = np_imgs[0, :, :, :]
    out_file = img_path + ".jpg"
    torchvision.utils.save_image(first_img, out_file)

def load_image(infilename, mode=None):
    """Read an image file into a float32 NumPy array without rescaling values.

    Args:
        infilename: Path to the image file.
        mode: Optional PIL mode (e.g. ``"RGB"``) to convert to first, for
            example to drop an alpha channel. ``None`` (the default) keeps
            the file's native mode, as before.

    Returns:
        float32 ``np.ndarray``; its shape (``(H, W)`` or ``(H, W, C)``)
        depends on the image mode.
    """
    # The context manager closes the file handle deterministically; PIL only
    # auto-closes it after load() for single-frame images.
    with Image.open(infilename) as img:
        img.load()
        if mode is not None:
            img = img.convert(mode)
        # Materialise the array before the file is closed.
        return np.asarray(img, dtype="float32")

In [5]:
# Trained weights and target device; adjust these for your machine.
CHECKPOINT_PATH = "/home/eee/ug/15084005/DIH/models/30.pt"
DEVICE = "cpu"

model = Net().to(DEVICE)
# map_location lets a checkpoint saved from a GPU run load on a CPU-only machine.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=DEVICE))

In [6]:
# High-resolution ground truth as a (1, C, H, W) tensor in [0, 1].
# load_image returns float32, and ToTensor does not rescale non-uint8 arrays,
# so the explicit / 255 is what normalises the values.
target = (
    torchvision.transforms.ToTensor()(load_image("d3_32*32.png"))
    .unsqueeze(0)
    / 255
)

In [7]:
# Low-resolution input as a (1, C, H, W) tensor in [0, 1].
# load_image returns float32, and ToTensor does not rescale non-uint8 arrays,
# so the explicit / 255 is what normalises the values.
data = (
    torchvision.transforms.ToTensor()(load_image("d3_8*8.png"))
    .unsqueeze(0)
    / 255
)

In [15]:
# sample() already runs its body under torch.no_grad(), so no outer context is needed.
sample(model, data, target, batch_size=1, mu=1.1, step=None)

sample
save
save
save
