### Loading a model trained on ShapeNet data and logging its predicted images to TensorBoard

In [None]:
import os
import glob
import json
import torch
import random
import numpy as np
from PIL import Image
from model import models
from model import metrics
from data import camera_utils
from data import data_utils
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter

In [2]:
# --- global parameters -------------------------------------------------------
# Prefer the GPU when CUDA is present, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# --- volume parameters -------------------------------------------------------
# Volume description as a (shape name, center, radius) tuple:
# a sphere of radius 1.0 centered at the origin.
radius = 1.0
center = torch.zeros(3)
vol_params = ("sphere", center, radius)

In [3]:
# --- load a saved model ------------------------------------------------------
# Model / encoding settings (learning_rate is not used by the loading code).
learning_rate = 1e-2
input_ch = 6
points_type = 'cartesian'
viewdir_type = 'cartesian'

# Experiment output directory and the validation split to predict on.
out_dir = '/home/goel/Thesis/Code/dvr/outputs/shapenet/chair/nomap/'
data_dir = '/home/goel/Thesis/Data/shapenet/chair/val'
writer_path, model_path = (
    os.path.join(out_dir, 'logs'),
    os.path.join(out_dir, 'models/2000.pth'),
)

# Build the network (3 output channels) and place it on the chosen device.
model = models.ConvNet(input_ch=input_ch, output_ch=3)
model = model.to(device)

# Restore the checkpoint directly onto `device`. Besides the weights it
# carries an 'input_map' entry (presumably the input mapping used during
# training; confirm). torch.load unpickles the file, so only load trusted
# checkpoints.
checkpoint = torch.load(model_path, map_location=device)
input_map = checkpoint['input_map']
model.load_state_dict(checkpoint['state_dict'])

<All keys matched successfully>

In [4]:
# Run the model over every camera file of the validation split and collect
# the clamped RGB prediction for each view.
#
# model.eval() disables training-only behaviour (dropout, batch-norm statistic
# updates) should ConvNet contain such layers; torch.no_grad() avoids building
# an autograd graph that is never back-propagated through.
model.eval()

cam_files = sorted(glob.glob(os.path.join(data_dir, "*.json")))
if not cam_files:
    # Fail here with a clear message instead of letting torch.stack fail
    # later on an empty list.
    raise FileNotFoundError(f"no camera .json files found in {data_dir!r}")

pred_data = []
with torch.no_grad():
    for cam_file in cam_files:
        # valid_mask is returned by get_input_data but not used here.
        points, viewdirs, valid_mask = data_utils.get_input_data(
            cam_file, vol_params, points_type, viewdir_type
        )

        # Concatenate points and view directions along dim 0 (presumably the
        # channel axis, giving input_ch == 6 channels; confirm), then apply
        # the input mapping stored in the checkpoint.
        net_input = torch.cat((points, viewdirs), dim=0).to(device)
        net_input = data_utils.input_mapping(
            net_input, input_map, True, True, "cartesian", "ConvNet"
        )

        # Add a leading batch dimension of size 1. The extra .to(device) is
        # kept in case input_mapping returns a tensor on another device; it is
        # a no-op otherwise.
        net_input = net_input.unsqueeze(0).to(device)
        rgb = model(net_input)

        # Clamp to the displayable [0, 1] range and drop the batch dimension.
        clamped_rgb = torch.clamp(rgb, min=0.0, max=1.0)
        pred_data.append(clamped_rgb[0])

In [5]:
# Stack the per-view predictions into one (N, C, H, W) batch, which is
# SummaryWriter.add_images' default 'NCHW' layout, and log it to TensorBoard.
pred = torch.stack(pred_data, dim=0)
print(pred.shape)

writer = SummaryWriter(writer_path)
# Global step 2000 mirrors the checkpoint that was loaded ('models/2000.pth');
# keep the two in sync when switching checkpoints.
writer.add_images('val_pred', pred, 2000)
# SummaryWriter buffers events and only writes them periodically (every
# flush_secs, 120 s by default) or on close. Flush so the images are on disk
# immediately while the kernel keeps running; the writer stays open for
# further logging.
writer.flush()

torch.Size([25, 3, 512, 512])


In [6]:
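# Disabled: log the ground-truth .jpg images from data_dir next to the predictions.
# NOTE(review): data_dir is the val split and the predictions are logged as
# 'val_pred' at step 2000, so the 'train_gt' tag and step 1 below look
# inconsistent; presumably 'val_gt' / 2000 was intended. Confirm before re-enabling.
# gt_data = []
# for img_file in sorted(glob.glob(data_dir + "/*.jpg")):
#     gt = Image.open(img_file)
#     gt = np.array(gt).astype(np.float32) / 255. # shape: (H, W, C)
#     gt = torch.Tensor(gt).permute(2,0,1) # shape: (C, H, W)
#     gt_data.append(gt)
# gt = torch.stack(gt_data,dim=0)
# print(gt.shape)
# writer.add_images('train_gt', gt, 1)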