In [1]:
import os
import sys

import torch
import torch.nn.functional as F

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

from yaml import load, dump
try:
    from yaml import CLoader as Loader, CDumper as Dumper
except ImportError:
    from yaml import Loader, Dumper

In [2]:
from Model.FastSCNN import *
from Dataset.dataset import *

In [3]:
# Read the run configuration from params.yml.
# The file is decoded as UTF-8 explicitly so the result does not depend on the
# platform's default locale encoding (e.g. non-ASCII characters in dataset_path).
# NOTE(review): `Loader` here is PyYAML's CLoader/Loader, i.e. the full (unsafe)
# loader that can construct arbitrary Python objects from tagged YAML. That is
# acceptable for a trusted, local params.yml; use SafeLoader if this file could
# ever come from an untrusted source.
with open("params.yml", encoding="utf-8") as file:
    params = load(file, Loader=Loader)

# Fail early with a message naming every missing key, instead of a bare
# KeyError on the first lookup below.
_required_params = ("dataset_path", "crop_height", "crop_width", "num_classes")
_missing_params = [key for key in _required_params if key not in params]
if _missing_params:
    raise KeyError(f"params.yml is missing required keys: {_missing_params}")

dataset_path = params["dataset_path"]   # Path at which the dataset is located
crop_height  = params["crop_height"]    # Height of cropped/resized input image
crop_width   = params["crop_width"]     # Width of cropped/resized input image
num_classes  = params["num_classes"]    # Number of classes

In [4]:
# Build the test split and an iterator over it. This notebook only draws a
# single batch from `dataloader` (the example input passed to torch.jit.trace),
# so:
#   - shuffle=False makes the chosen example the same on every run
#     (reproducible Restart & Run All);
#   - num_workers=0 loads in the main process instead of spawning worker
#     processes just to fetch one sample.
test_dataset = Dataset(dataset_path, crop_height, crop_width, mode="test")
test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size  = 1,
    num_workers = 0,
    shuffle     = False,
)
dataloader = iter(test_dataloader)

In [5]:
# Pull one batch from the test iterator; `image` becomes the example input
# that torch.jit.trace is called with further down (`label` is not used there).
example_batch = next(dataloader)
image, label = example_batch

In [6]:
# Instantiate FastSCNN with the crop size and class count read from params.yml,
# restore the best checkpoint onto the CPU, and switch to eval mode so layers
# such as BatchNorm use their running statistics during tracing.
model = FastSCNN(
    image_height=crop_height,
    image_width=crop_width,
    image_channels=3,
    num_classes=num_classes,
)
# NOTE(review): torch.load unpickles the checkpoint; only load files you trust.
state_dict = torch.load("./checkpoints/best_model.pt", map_location=torch.device('cpu'))
model.load_state_dict(state_dict)
model.eval()

FastSCNN(
  (learning_to_downsample): Sequential(
    (0): Sequential(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (1): Sequential(
      (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=32, bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(32, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (4): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (2): Sequential(
      (0): Conv2d(48, 48, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=48, bias=False)
      (1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(48, 64, kernel_size=(1, 1), 

In [7]:
# Record the eval-mode model's forward pass on the example batch `image` into a
# torch.jit.ScriptModule (torch.jit.trace). This is an inference export, so
# tracing runs under torch.no_grad() to avoid building an autograd graph for
# the example forward pass.
#
# NOTE(review): tracing turns Python-side values into constants. The warnings
# printed by this cell point at F.interpolate calls that convert `size` with
# int(...), so the upsampling sizes are presumably baked in and the traced
# module only supports inputs with the same spatial size as `image`
# (crop_height x crop_width) — confirm before feeding other resolutions.
with torch.no_grad():
    traced_script_module = torch.jit.trace(model, image)

  return F.interpolate(x, list(map(int, size)), mode='bilinear', align_corners=True)
  x = F.interpolate(x, list(map(int, size)), mode='bilinear', align_corners=True)


In [9]:
# Write the traced TorchScript module next to the original checkpoint so it can
# later be loaded with torch.jit.load.
traced_model_path = "./checkpoints/traced_fastscnn_model.pt"
traced_script_module.save(traced_model_path)