In [2]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import segmentation_models_pytorch as smp

from tqdm import tqdm
from dataset import *

**Load Model**

In [2]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# U-Net with a ResNet-34 encoder: 5 input channels, 3 output classes.
# encoder_weights=None because every weight is restored from the checkpoint below.
# 'softmax2d' applies softmax over the class dimension (dim=1) explicitly; the plain
# 'softmax' option builds nn.Softmax without a dim and relies on PyTorch's deprecated
# implicit-dim inference. The argmax taken downstream is unaffected either way.
model = smp.Unet(
    encoder_name='resnet34',
    encoder_weights=None,
    in_channels=5,
    classes=3,
    activation='softmax2d',
)
# Put the model on the same device the inference cell sends its inputs to,
# and switch off training-only behaviour (dropout / batch-norm updates).
model = model.to(device).eval()

# NOTE: torch.load unpickles the checkpoint; only load files from a trusted source
# (on PyTorch versions that support it, pass weights_only=True).
model.load_state_dict(torch.load('./87DICE.pth', map_location=device))

<All keys matched successfully>

**Load Data**

In [None]:
# Takes maximum class probability and assigns class to output
def max_classvals(out):
    no_classes, height, width = out.shape
    class_img = np.zeros((height, width), dtype=np.uint8)

    for i in range(height):
      for j in range(width):
        class_img[i, j] = np.argmax(out[:, i, j]) #take max value

    return class_img

In [7]:
# Load the test scene and crop axes 1 and 2 down to the nearest multiple of 32,
# presumably so the U-Net encoder's repeated downsampling divides evenly.
# Axis 0 is left untouched (assumed to be the channel axis — confirm against the data).
test_data = np.load('./data/raw/south_data.npy')
crop_rows = test_data.shape[1] - test_data.shape[1] % 32
crop_cols = test_data.shape[2] - test_data.shape[2] % 32
test_retile = test_data[:, :crop_rows, :crop_cols]

**Run Inference**

In [None]:
# Retile test data
test_tiles = divide_image(test_retile, (64), (64))
preds = []
bar = tqdm(range(len(test_tiles)), position=0)

# Evaluate model
model.eval()
for iter in bar:
    pred = model.forward(torch.Tensor(np.expand_dims(test_tiles[iter].astype(int), axis=0)).type(torch.cuda.FloatTensor))
    pred = max_classvals(pred.detach().cpu().numpy().reshape(3, 128, 128))
    preds.append(pred)

In [3]:
# Cache of per-tile class maps. If the inference cell just ran in this kernel,
# `preds` is a non-empty list: stack it and save it so a later session can skip
# inference. Otherwise load the previously saved copy. The original cell always
# loaded the file, silently discarding freshly computed predictions.
PREDS_PATH = './data/preds.npy'
if isinstance(globals().get('preds'), list) and len(preds) > 0:
    preds = np.stack(preds)
    np.save(PREDS_PATH, preds)
else:
    preds = np.load(PREDS_PATH)

**Generate Results**

In [None]:
# Stitch the per-tile class maps back into a single (1, height, width) array.
# Tiles are placed row by row, left to right, which assumes divide_image emitted
# them in that order — NOTE(review): confirm against dataset.divide_image.
# NOTE(review): the scene is cropped to a multiple of 32 but tiled with stride 64,
# so a trailing 32-pixel strip (if any) stays 0 here — confirm this is intended.
height, width = test_retile.shape[1], test_retile.shape[2]
stride = 64
n_tile_rows = len(range(0, height - stride + 1, stride))
n_tile_cols = len(range(0, width - stride + 1, stride))

# Fail loudly on a tiling mismatch instead of silently dropping extra tiles
# or hitting an IndexError halfway through the loop.
assert len(preds) == n_tile_rows * n_tile_cols, (
    f'expected {n_tile_rows * n_tile_cols} tiles of {stride}x{stride}, '
    f'got {len(preds)}'
)

out = np.zeros((1, height, width))
count = 0
for y in range(0, height - stride + 1, stride):
    for x in range(0, width - stride + 1, stride):
        out[0, y:y + stride, x:x + stride] = np.reshape(preds[count], (stride, stride))
        count += 1

np.save('./out128.npy', out)
