### Loading the models trained on Ejecta DVR images and saving the predicted images

In [None]:
import os
import glob
import json
import torch
import random
import numpy as np
from PIL import Image
from PIL import ImageDraw
from model import models
from model import metrics
from data import camera_utils
from data import data_utils
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter

In [2]:
# ---- global configuration shared by the cells below ----

# Global width and height (presumably the image resolution in pixels - TODO confirm).
width, height = 512, 512

# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Volume description tuple: (shape name, center, radius).
radius = 1.0
center = torch.Tensor([0.0, 0.0, 0.0])
vol_params = ("sphere", center, radius)

# Seed every random number generator used by the notebook with the same value.
# random.seed() is kept last on purpose: it returns None, so the cell shows no output.
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

In [3]:
# Load a saved checkpoint into a freshly built network and prepare it for inference.

# Run configuration.
# NOTE(review): out_dir is an absolute, user-specific path - consider making it configurable.
model_name = 'NeRF'
input_ch = 3
points_type = 'cartesian'
viewdir_type = 'cartesian'
out_dir = '/home/goel/Thesis/Code/dvr/outputs/nerf_approach/first_run/lr1e-2/points_128/'
writer_path = os.path.join(out_dir,'logs')
model_path = os.path.join(out_dir,'models/200.pth')

# create the model
model = models.OldNeRF(input_ch=input_ch, output_ch=4, hidden_size=128).to(device)

# This notebook only predicts images, so put the network in evaluation mode.
# Without this, layers such as dropout / batch norm (if the model has any)
# would keep their training-time behaviour during prediction.
model.eval()

# map_location lets a checkpoint saved on a GPU be restored on a CPU-only machine.
checkpoint = torch.load(model_path, map_location=device)
# The checkpoint stores the input mapping next to the network weights.
input_map = checkpoint['input_map']
# Kept as the last expression so the cell still displays the key-matching result.
model.load_state_dict(checkpoint['state_dict'])

<All keys matched successfully>

In [4]:
def check_model_size(model):
    """Print the total and the trainable parameter count of ``model``, in millions.

    Args:
        model: Object exposing ``parameters()``; every yielded parameter must
            provide ``numel()`` and a ``requires_grad`` flag
            (e.g. a ``torch.nn.Module``).

    Returns:
        None. Both counts are written to stdout with three decimals.
    """
    # Walk the parameters once, remembering each size and whether it is trainable.
    sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]
    total = sum(count for count, _ in sizes)
    trainable = sum(count for count, needs_grad in sizes if needs_grad)

    print("[Network  Total number of parameters : %.3f M" % (total / 1e6))
    print(
        "[Network  Total number of trainable parameters : %.3f M"
        % (trainable / 1e6))

In [5]:
# NOTE: this cell is disabled - every statement below is commented out.
# If re-enabled it would:
#   1. collect the ids of the *.json files found under train_data_dir,
#   2. pick 4 of them at random,
#   3. run `model` on the inputs built from each camera file and keep the
#      prediction clamped to [0, 1] in pred_train_data,
#   4. load the matching color_<id>.png, scale it by 1/255 and keep its first
#      three channels in gt_train_data.
# NOTE(review): the disabled code binds `id` and `input`, which shadow Python
# builtins - rename them if this cell is revived.

# train_data_dir = '/home/goel/Thesis/Data/ejecta_27/train/'
# train_data_lst = []

# for cam_file in glob.glob(train_data_dir+"*.json"):
#     fid =cam_file[cam_file.rfind('_')+1:cam_file.rfind('.')]
#     train_data_lst.append(fid)

# n_data = len(train_data_lst)
# train_selected_data = random.sample(range(0,n_data),4)

# gt_train_data = []
# pred_train_data = []

# for data_idx in train_selected_data:
#     id = train_data_lst[data_idx]
#     cam_file = train_data_dir + "camera_" + id + ".json"
#     img_file = train_data_dir + "color_" + id + ".png"

#     points, viewdirs, valid_mask = data_utils.get_input_data(cam_file, vol_params, points_type, viewdir_type)

#     input = torch.cat((points, viewdirs),dim=0).to(device)
#     input = data_utils.input_mapping(input, input_map, True, True, "cartesian")

#     input = input.unsqueeze_(0).to(device)
#     rgb = model(input)
#     clamped_rgb = torch.clamp(rgb.detach(), min=0.0, max=1.0)
#     pred_train_data.append(clamped_rgb[0])

#     img = torch.Tensor(np.array(Image.open(img_file)) / 255.).permute(2,0,1)[:3]
#     gt_train_data.append(img)

In [7]:
# NOTE: this cell is disabled - every statement below is commented out.
# If re-enabled it would stack the collected ground-truth / predicted images
# into batches and log the validation pair to TensorBoard through
# SummaryWriter(writer_path).
# NOTE(review): gt_val_data and pred_val_data are not defined in any cell
# visible in this notebook, and the execution counter skips from 5 to 7 -
# presumably the cell that built them was removed; confirm before re-enabling.

# gt_train = torch.stack(gt_train_data,dim=0)
# pred_train = torch.stack(pred_train_data,dim=0)
# gt_val = torch.stack(gt_val_data,dim=0)
# pred_val = torch.stack(pred_val_data,dim=0)

# writer = SummaryWriter(writer_path)
# # writer.add_images('vis_gt_train', gt_train, 100)
# # writer.add_images('vis_pred_train', pred_train, 100)
# writer.add_images('vis_gt_val', gt_val, 101)
# writer.add_images('vis_pred_val', pred_val, 101)

### Creating zoomed-in views of ground truth and network predictions

In [14]:
# Build a figure from one image: crop it, outline a small region of interest,
# paste a magnified copy of that region into the bottom-left corner and join
# the two boxes with lines. The result is written to `created_img`.

# NOTE(review): absolute, user-specific paths - consider making them configurable.
img_file = '/home/parika/Downloads/ConvNet/gt.png'
created_img = '/home/parika/Downloads/ConvNet/gt_with_zoom.png'
gt_img_orig = Image.open(img_file)
gt_img_orig.load()

# Composite the image onto an opaque background, using its alpha band as mask.
# convert("RGBA") guarantees that a fourth band exists, so an input without an
# alpha channel no longer raises IndexError on split()[3].
# NOTE(review): the variable is called "white_bkg" but the fill colour (0, 0, 0)
# is black - confirm which one is intended.
gt_img_white_bkg = Image.new("RGB", gt_img_orig.size, (0, 0, 0))
gt_img_white_bkg.paste(gt_img_orig, mask=gt_img_orig.convert("RGBA").split()[3])

# (left, upper, right, lower) box that trims the outer margin of the image.
borders = (80,80,480,480)
gt_img = gt_img_white_bkg.crop(borders)
print("Image Size: ", gt_img.size)

# Region of interest, in coordinates of the trimmed image.
# Alternatives tried before:
# borders_crop = (210,210,280,280)
# borders_crop = (220,70,290,140)
# borders_crop = (170,70,240,140)
borders_crop = (180, 90, 250, 160)
gt_crop = gt_img.crop(borders_crop)
print("Size of the crop taken: ", gt_crop.size)

# Pillow's Image.size is (width, height). Cell-local names are used here so the
# global `width` / `height` defined in the configuration cell are not overwritten.
zoom_factor = 3
crop_width, crop_height = gt_crop.size
rescaled_size = (crop_width*zoom_factor, crop_height*zoom_factor)
gt_crop_scaled = gt_crop.resize(rescaled_size)
print("Size of the scaled crop: ",gt_crop_scaled.size)

# Draw a box to specify where the crop is taken
# Line width in pixels shared by every outline and connector drawn below.
width_of_rect = 3
draw = ImageDraw.Draw(gt_img)
draw.rectangle(borders_crop, outline='black', width=width_of_rect)
print("Position where the crop is taken: ",borders_crop)

# Paste the magnified crop into the bottom-left corner of the image.
# Alternative (top-left corner):
# pos_paste_crop = (0,0,rescaled_size[0],rescaled_size[1])
pos_paste_crop = (0,gt_img.size[1]-rescaled_size[1],rescaled_size[0],gt_img.size[1])
gt_img.paste(gt_crop_scaled,pos_paste_crop)

# Outline the pasted crop after pasting so the frame is not covered by it.
draw.rectangle(pos_paste_crop, outline='black', width=width_of_rect)
print("Position where scaled crop is pasted",pos_paste_crop)

# Draw the lines to connect both the boxes
# (top-left to top-left and bottom-right to bottom-right)
draw.line((pos_paste_crop[0],pos_paste_crop[1])+(borders_crop[0],borders_crop[1]),fill='black',width=width_of_rect)
draw.line((pos_paste_crop[2],pos_paste_crop[3])+(borders_crop[2],borders_crop[3]),fill='black',width=width_of_rect)

# gt_img.show()
gt_img.save(created_img)

Image Size:  (400, 400)
Size of the crop taken:  (70, 70)
Size of the scaled crop:  (210, 210)
Position where the crop is taken:  (180, 90, 250, 160)
Position where scaled crop is pasted (0, 190, 210, 400)
