## Historical Forest Mapping by DeepLabV3

Run inference and display the semantic segmentation results.

__Step 1.__ Import the necessary packages

In [None]:
import torch
from torchvision import transforms
import numpy as np
from PIL import Image
import os
import glob

# local import
import custom_model
from iou import iou
from accuracy import count_for_user_accuracy, count_for_producer_accuracy, count_for_overall_accuracy

__Step 2.__ Select the compute device (the first GPU if available, otherwise the CPU)

In [None]:
# Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

__Step 3.__ Define file paths needed in this notebook

In [None]:
# define the path to the weights file
weights_dir = 'weights/weights.pt'
# path to the test images
img_path = "./image.tif"
# path to the test labels
label_path = "./target.tif"

__Step 4.__ Initialize the model with trained weights, and set the model in evaluation mode

In [None]:
# Build the network and move it to the selected device.
# NOTE(review): the leading ``3`` is presumably the number of output classes and
# ``keep_feature_extract=True`` a backbone-freezing flag in custom_model — confirm.
model = custom_model.initialize_model(3, keep_feature_extract=True).to(device)

# Load the trained weights; map_location remaps tensors saved on another
# device (e.g. a GPU) onto ``device`` so loading also works on CPU-only hosts.
state_dict = torch.load(weights_dir, map_location=device)
model.load_state_dict(state_dict)

# Evaluation mode: dropout disabled, batch-norm uses its running statistics.
model.eval()

__Step 5.__ Load one image from the test set and apply the transforms required by the model

In [None]:
# load the image and force 3-channel RGB: grayscale, palette or RGBA scans
# would otherwise not match the 3-channel mean/std used by Normalize below.
# The context manager closes the underlying file once the converted copy
# (which convert() fully loads into memory) has been created.
with Image.open(img_path) as source_image:
    image = source_image.convert("RGB")

# define the transforms: resize to 512x512 (NEAREST, as before), convert to a
# float tensor of shape (3, 512, 512) with values in [0, 1], then normalise
# with the ImageNet channel mean/std
image_transforms = transforms.Compose([
    transforms.Resize(size=(512, 512), interpolation=transforms.InterpolationMode.NEAREST),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

# apply transforms; ``image`` is now a (3, 512, 512) tensor without a batch dimension
image = image_transforms(image)

__Step 6.__ Run inference and generate the prediction

In [None]:
# do the inference. The model takes a batch, so add a leading batch dimension
# to the (C, H, W) image and put it on the same device as the model.
# no_grad() skips building the autograd graph, which inference never needs.
with torch.no_grad():
    outputs = model(image.unsqueeze(0).to(device))["out"]

# keep only the scores of the first two classes (channels 0 and 1), so the
# prediction is restricted to class ids 0 and 1
outputs = outputs[:, :2, :, :]
# per-pixel class id = index of the highest remaining score; shape (1, H, W)
preds = outputs.argmax(dim=1)

__Step 7.__ Resize the prediction back to the label's original size, then display the image, the prediction and the ground truth

In [None]:
# load the label. PIL reports size as (width, height). Reading inside a
# context manager closes the file once the pixels are copied into numpy.
with Image.open(label_path) as label_image:
    label_width, label_height = label_image.size
    label = np.squeeze(np.array(label_image))

# resize the prediction to the label size (Resize takes (height, width));
# NEAREST keeps the values integral class ids instead of blending them
preds = transforms.Resize(
    size=(label_height, label_width),
    interpolation=transforms.InterpolationMode.NEAREST,
)(preds)

# display colour for each class id: 0 -> white, 1 -> green, 255 -> black
color_pair_dict = {0: (255, 255, 255), 1: (0, 255, 0), 255: (0, 0, 0)}


def colorize_mask(mask, color_map):
    """Map an array of class ids to an RGB image.

    Args:
        mask: array-like of class ids, shape (H, W).
        color_map: dict mapping each class id to an (R, G, B) tuple.

    Returns:
        np.ndarray of shape (H, W, 3) and dtype uint8.

    Raises:
        ValueError: if ``mask`` contains a value with no entry in ``color_map``.
    """
    mask = np.asarray(mask)
    # Fail loudly rather than silently painting unexpected ids some colour.
    unknown = np.setdiff1d(np.unique(mask), list(color_map))
    if unknown.size:
        raise ValueError(f"no colour defined for mask values {unknown.tolist()}")
    rgb = np.zeros(mask.shape + (3,), dtype=np.uint8)
    for value, color in color_map.items():
        rgb[mask == value] = color
    return rgb


# convert the prediction to a numpy array of class ids (drop the batch dim)
preds = np.squeeze(preds.cpu().numpy())

# convert the prediction and label to (H, W, 3) RGB images. np.vectorize is
# not usable here: for a function returning a tuple it returns a tuple of
# three arrays (which has no .astype), laid out channel-first.
preds = colorize_mask(preds, color_pair_dict)
label = colorize_mask(label, color_pair_dict)

# display the original image, prediction and label
