In [1]:
%matplotlib inline
import torch
import torch.nn as nn
# import pandas as pd
import numpy as np
import torchvision
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torch import autograd
from torch.autograd import Variable
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import copy

In [2]:
import onnx
import onnxruntime as ort

In [8]:
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

device(type='cuda')

In [9]:
class NN(nn.Module):
    def __init__(self):
        super(NN, self).__init__()
        self.model = nn.Sequential(
            # Pixel Coord + Light Coord + Average Pixel Color
            nn.Linear(2 + 2 + 3, 16),
            nn.LeakyReLU(),
            nn.Linear(16, 16),
            nn.LeakyReLU(),
            nn.Linear(16, 3),
            nn.Tanh()
        )
    def forward(self, pixel_coord, light_coord, average_rgb):
        d_in = torch.cat((pixel_coord, light_coord, average_rgb), -1)
        rgb = self.model(d_in)
        return rgb

In [63]:
def load_model(filepath):
    checkpoint = torch.load(filepath)
    model = checkpoint['model']
    model.load_state_dict(checkpoint['state_dict'])
    for parameter in model.parameters():
        parameter.requires_grad = False

#     model.eval().to(device)
    model.eval()
    return model

In [64]:
model = load_model('./check_pt/trained_basketball_128_200.pth')

In [65]:
loaded_avg_img = np.load('./avg/avg_basketball_128.npy')

In [21]:
im_s = 128
img_shape = (3, im_s, im_s)
cuda = True if torch.cuda.is_available() else False
FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor
CharTensor = torch.cuda.CharTensor if cuda else torch.CharTensor

In [22]:
def sample_image(n_row, model, loaded_avg_img, lr=0, lc=0):
    """Saves a grid of generated digits ranging from 0 to n_classes"""
    # Sample noise
    width = img_shape[1]
    height = img_shape[2]
    avg_img = torch.from_numpy(loaded_avg_img).to(device)
    avg_img = Variable(avg_img.type(FloatTensor))
    img_tensor = torch.FloatTensor().to(device)
    for k in range(height):
        for j in range(width):
            # Pixel Coord + Light Coord + Average Pixel Color
            avg_value = Variable(avg_img[j + k * width].repeat(1, n_row).view(n_row, 3).type(FloatTensor))
            pixel_coord = Variable(torch.FloatTensor([j/width, k/height])).repeat(1, n_row).view(n_row, 2).to(device)
            # Default sample location is the light in the middle
            light_coord = Variable(torch.FloatTensor([lr, lc])).repeat(1, n_row).view(n_row, 2).to(device)
            generate_pixel = model(pixel_coord, light_coord, avg_value)
            img_tensor = torch.cat((img_tensor, generate_pixel), 0)
    plt.imshow(img_tensor.cpu().data.detach().numpy().reshape(img_shape[1],img_shape[2],img_shape[0],order='F') * .5 + .5, cmap='gray')

In [23]:
from ipywidgets import interact

def test(lr, lc):
    sample_image(1, model, loaded_avg_img, lr, lc)
    
interact(test, lr=(-1, 1, .1), lc=(-1, 1, .1))

interactive(children=(FloatSlider(value=0.0, description='lr', max=1.0, min=-1.0), FloatSlider(value=0.0, desc…

<function __main__.test(lr, lc)>

# Inference with ONNX

In [48]:
pix = torch.FloatTensor(1, 2).to(device)
light = torch.FloatTensor(1, 2).to(device)
avg = torch.FloatTensor(1, 3).to(device)
torch.onnx.export(model, (pix, light, avg), 'example.onnx', input_names=['pix', 'light', 'avg'])

In [49]:
model = onnx.load('example.onnx')
onnx.checker.check_model(model)
onnx.helper.printable_graph(model.graph)

'graph torch-jit-export (\n  %pix[FLOAT, 1x2]\n  %light[FLOAT, 1x2]\n  %avg[FLOAT, 1x3]\n) initializers (\n  %model.0.bias[FLOAT, 16]\n  %model.0.weight[FLOAT, 16x7]\n  %model.2.bias[FLOAT, 16]\n  %model.2.weight[FLOAT, 16x16]\n  %model.4.bias[FLOAT, 3]\n  %model.4.weight[FLOAT, 3x16]\n) {\n  %9 = Concat[axis = -1](%pix, %light, %avg)\n  %10 = Gemm[alpha = 1, beta = 1, transB = 1](%9, %model.0.weight, %model.0.bias)\n  %11 = LeakyRelu[alpha = 0.00999999977648258](%10)\n  %12 = Gemm[alpha = 1, beta = 1, transB = 1](%11, %model.2.weight, %model.2.bias)\n  %13 = LeakyRelu[alpha = 0.00999999977648258](%12)\n  %14 = Gemm[alpha = 1, beta = 1, transB = 1](%13, %model.4.weight, %model.4.bias)\n  %15 = Tanh(%14)\n  return %15\n}'

In [83]:
ort_sess = ort.InferenceSession('example.onnx')
pix = torch.FloatTensor([[1, 1]]).to(device)
light = torch.FloatTensor([[1, 1]]).to(device)
avg = torch.FloatTensor([[1, 1, 1]]).to(device)
ort_sess.run(None, {'pix': pix.cpu().numpy(), 'light': light.cpu().numpy(), 'avg': avg.cpu().numpy()})[0]

array([[ 0.36055586, -0.8710557 ,  0.6011299 ]], dtype=float32)

In [74]:
def sample_onnx_image(n_row, model, loaded_avg_img, lr=0, lc=0):
    """Saves a grid of generated digits ranging from 0 to n_classes"""
    # Sample noise
    width = img_shape[1]
    height = img_shape[2]
    avg_img = torch.from_numpy(loaded_avg_img).to(device)
    avg_img = Variable(avg_img.type(FloatTensor))
    img_tensor = torch.FloatTensor().to(device)
    for k in range(height):
        for j in range(width):
            # Pixel Coord + Light Coord + Average Pixel Color
            avg_value = Variable(avg_img[j + k * width].repeat(1, n_row).view(n_row, 3).type(FloatTensor))
            pixel_coord = Variable(torch.FloatTensor([j/width, k/height])).repeat(1, n_row).view(n_row, 2).to(device)
            # Default sample location is the light in the middle
            light_coord = Variable(torch.FloatTensor([lr, lc])).repeat(1, n_row).view(n_row, 2).to(device)
            # This is inefficient because of the to_cpu then to_gpu
            generate_pixel = torch.FloatTensor(ort_sess.run(None, {'pix': pixel_coord.cpu().numpy(), 'light': light_coord.cpu().numpy(), 'avg': avg_value.cpu().numpy()})[0]).to(device)
            img_tensor = torch.cat((img_tensor, generate_pixel), 0)
    plt.imshow(img_tensor.cpu().data.detach().numpy().reshape(img_shape[1],img_shape[2],img_shape[0],order='F') * .5 + .5, cmap='gray')

In [75]:
from ipywidgets import interact

def test(lr, lc):
    sample_onnx_image(1, model, loaded_avg_img, lr, lc)
    
interact(test, lr=(-1, 1, .1), lc=(-1, 1, .1))

interactive(children=(FloatSlider(value=0.0, description='lr', max=1.0, min=-1.0), FloatSlider(value=0.0, desc…

<function __main__.test(lr, lc)>

In [116]:
def sample_onnx_image_cpu(n_row, model, loaded_avg_img, lr=0, lc=0):
    """Saves a grid of generated digits ranging from 0 to n_classes"""
    # Sample noise
    width = img_shape[1]
    height = img_shape[2]
    avg_img = loaded_avg_img
    img_tensor = np.array([]).reshape(-1,3)
    for k in range(height):
        for j in range(width):
            # Pixel Coord + Light Coord + Average Pixel Color
            avg_value = np.array(avg_img[j + k * width]).reshape(1, 3)
            pixel_coord = np.array([[j/width, k/height]]).reshape(1, 2)
            # Default sample location is the light in the middle
            light_coord = np.array([[lr, lc]]).reshape(1, 2)
            # This is inefficient because of the to_cpu then to_gpu
            generate_pixel = ort_sess.run(None, {'pix': pixel_coord.astype(np.float32), 'light': light_coord.astype(np.float32), 'avg': avg_value.astype(np.float32)})[0]
            img_tensor = np.append(img_tensor, generate_pixel, axis=0)
    plt.imshow(img_tensor.reshape(img_shape[1],img_shape[2],img_shape[0],order='F') * .5 + .5, cmap='gray')

In [117]:
from ipywidgets import interact

def test(lr, lc):
    sample_onnx_image_cpu(1, model, loaded_avg_img, lr, lc)
    
interact(test, lr=(-1, 1, .1), lc=(-1, 1, .1))

interactive(children=(FloatSlider(value=0.0, description='lr', max=1.0, min=-1.0), FloatSlider(value=0.0, desc…

<function __main__.test(lr, lc)>