# Demo

In [None]:
# Re-import edited project modules (e.g. models.unet) automatically before each
# cell executes, so code changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
# Move to the parent directory so the relative paths used below ("models",
# "./logs", "./datasets", "notebooks/") resolve; this presumes the notebook
# lives one level below the repository root.
# NOTE(review): re-running this cell calls chdir('..') again and moves one more
# level up — restart the kernel (or guard the call) before re-executing it.
os.chdir('..')
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as transforms

from PIL import Image, ImageOps

# Project-local import; presumably only importable because of the chdir above.
from models.unet import build_unet

In [None]:
# Device: the first CUDA GPU when one is available, otherwise the CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(DEVICE)

# Load weights.
model_path = "./logs/hma_unet/hma_unet.pth"
model = build_unet()
# map_location remaps any CUDA tensors stored in the checkpoint onto DEVICE.
# Without it, torch.load fails on a CPU-only machine (the fallback chosen above)
# whenever the checkpoint was saved from GPU tensors.
checkpoint = torch.load(model_path, map_location=DEVICE)
# NOTE(review): the checkpoint is passed straight to load_state_dict, so it is
# assumed to be a bare state_dict rather than a wrapper dict — confirm against
# the training script.
model.load_state_dict(checkpoint)
model = model.to(DEVICE)
# Inference mode (changes the behaviour of layers such as dropout/batch-norm,
# if the model has any).
model.eval()

### Helpers

In [None]:
def read_image(image_path):
    """Open an image file with Pillow and load its pixel data.

    Args:
        image_path: Path to the image file, as a ``str`` or any
            ``os.PathLike`` (e.g. ``pathlib.Path``).

    Returns:
        The opened ``PIL.Image.Image`` in its original mode, with pixel data
        already read into memory.

    Raises:
        TypeError: If ``image_path`` is not a ``str`` or path-like object.
    """
    # Explicit check rather than ``assert`` so validation survives ``python -O``;
    # isinstance also accepts pathlib.Path and str subclasses.
    if not isinstance(image_path, (str, os.PathLike)):
        raise TypeError(f"Should be a path, got: {image_path} which is {type(image_path)}")
    img = Image.open(image_path)
    # Image.open is lazy and keeps the file open; load() reads the pixels now and,
    # for single-frame images, lets Pillow close the file it opened.
    img.load()
    return img


def predict(input_image, threshold=0.5):
    """Predict a binary mask for a PIL image using the global ``model`` on ``DEVICE``.

    The image is converted to RGB, resized to 512x256 (height x width), scaled
    to [-1, 1] per channel and passed through ``model``. The sigmoid output is
    resized back to the input's resolution with bicubic interpolation and
    thresholded.

    Args:
        input_image: A ``PIL.Image.Image`` in any mode.
        threshold: Pixels whose predicted probability is strictly greater than
            this value become 255 in the mask; all others become 0.
            Defaults to 0.5.

    Returns:
        An RGB ``PIL.Image.Image`` with the same size as ``input_image``.
    """
    preprocess = transforms.Compose([
        # An (h, w) sequence resizes to exactly that size, so no crop is needed.
        transforms.Resize((512, 256)),
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 maps ToTensor's [0, 1] range to [-1, 1].
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    # Normalize uses 3-channel statistics, so RGBA / palette / grayscale PNGs
    # would fail there; convert first (any alpha channel is dropped, not
    # composited).
    rgb_image = input_image.convert("RGB")
    input_batch = preprocess(rgb_image).unsqueeze(0).to(DEVICE)  # (1, 3, 512, 256)

    with torch.no_grad():
        probs = torch.sigmoid(model(input_batch))
        # PIL reports size as (width, height); interpolate wants (height, width).
        width, height = input_image.size
        probs = torch.nn.functional.interpolate(
            probs, size=(height, width), mode="bicubic", align_corners=True
        )

    # NOTE(review): model output is 4-D (1, C, H, W); C is presumably 1 — confirm.
    # (C, H, W) -> (H, W, C), dropping the channel axis only (not a spatial one
    # of size 1, which a bare .squeeze() would also remove).
    mask = probs[0].permute(1, 2, 0).squeeze(-1).cpu().numpy() > threshold
    mask = (mask * 255).astype(np.uint8)
    return Image.fromarray(mask).convert("RGB")
    

### Make a prediction

In [None]:
# Sample input image used for the demo prediction.
# NOTE(review): this image comes from the *train* split, so the prediction says
# little about generalisation — consider a held-out image instead.
path = "./datasets/human_artifacts/train/humans/3989.png"
image = read_image(path)

In [None]:
# Display the input image.
plt.imshow(image)

In [None]:
# Run the model and display the predicted mask (an RGB image whose pixels are
# 0 or 255).
mask = predict(image)
plt.imshow(mask)

In [None]:
# Save the mask. The path is relative to the working directory set by
# os.chdir('..') earlier; Pillow does not create missing directories, so
# "notebooks/" must already exist there.
mask.save("notebooks/prediction.png")