In [1]:
import os
import torch
import numpy as np
from scipy.io import loadmat
import PIL
from PIL import Image

import data_loader
import models
from utils.metrics import EncodeColor

import pydensecrf.densecrf as dcrf

# Contents of the .mat file as loaded by scipy; test() passes the 'colors'
# entry to EncodeColor to turn class-index predictions into an image.
# NOTE(review): presumably a per-class RGB palette for the dataset's classes — confirm.
colors = loadmat('data_loader/mit_scene_color150.mat')

def test(config):
  """Run a trained fcn32s checkpoint over the testing split and save results.

  For every (image, path) pair yielded by the dataset's ``get_images()``:
    * the input and raw model output are dumped to ``results/tnsr_<i>.pkl``
      (temporary debugging aid), and
    * the per-pixel argmax over the class dimension is color-encoded with
      ``EncodeColor`` and written to ``results/<image name>_res.png``.

  Args:
    config: Mapping with the key ``"ckpoint"``, the path to a checkpoint
      file containing a ``"model_state"`` state dict.
  """
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  # Setup dataloader
  mit_data = data_loader.get_loader("mit_sceneparsing_benchmark")
  mit_data = mit_data("testing")

  # Load model checkpoint
  model = models.get_model("fcn32s", mit_data.n_classes).to(device)
  print(f"Loading model checkpoint: {config['ckpoint']}")
  # map_location lets a checkpoint saved on GPU load on a CPU-only host,
  # matching the cuda/cpu fallback chosen above.
  ckpoint = torch.load(config["ckpoint"], map_location=device)
  model.load_state_dict(ckpoint["model_state"])
  model.eval()

  # Both the tensor dumps and the PNGs go here; create it up front so the
  # first save does not fail on a fresh checkout.
  os.makedirs("results", exist_ok=True)

  # Pure inference: skip autograd bookkeeping so no graph is built per image
  # and the dumped tensors carry no grad history.
  with torch.no_grad():
    for ctr, (img, img_path) in enumerate(mit_data.get_images()):
      img = img.to(device)
      # Make this into a batch of size 1. Model needs NCHW
      img = img.unsqueeze(0) # [1, C=3, H=512, W=512]
      out = model(img) # [1, C=151, H=512, W=512]

      # Run the model softmax output through a densecrf
      # AH TODO

      # TMP dump tensors
      tnsr_ckpoint = {"img" : img, "out": out}
      torch.save(tnsr_ckpoint, f"results/tnsr_{ctr}.pkl")

      _, pred = out.max(1) # [1, H=512, W=512]
      # Drop only the batch axis; a bare squeeze() would also remove an
      # H or W dimension of size 1.
      pred_np = pred.squeeze(0).cpu().numpy() # [H=512, W=512]
      pred_col = EncodeColor(pred_np, colors['colors'])

      orig_file_name = os.path.splitext(os.path.basename(img_path))[0]
      results_file = "results/" + orig_file_name + "_res.png"
      Image.fromarray(pred_col).save(results_file)

    
# Script driver: point test() at the checkpoint file to evaluate and run it.
# The "ckpoint" key is the only config entry test() reads.
config = {"ckpoint": "checkpoints/fcn32s_10.pkl"}
test(config)


Found 18 testing images
Loading model checkpoint: checkpoints/fcn32s_10.pkl
