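### Instructions
1. Clone the repository https://github.com/pzajec/DRAEM next to this notebook's folder; the code below expects it at `../DRAEM/`.
2. Download the trained DRAEM checkpoint and extract its `.pckl` files into **DRAEM/checkpoints/**.
3. DRAEM takes 256x256 RGB (color) images as input.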

In [None]:
import sys
sys.path.append('../DRAEM/') # Path to DRAEM repository
from test_DRAEM import *

import torch
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

import utils

# Converters between PIL images and float tensors, shared by the cells below.
to_tensor = transforms.ToTensor()
to_image = transforms.ToPILImage()

# This notebook only runs inference, so autograd is switched off globally.
torch.set_grad_enabled(False)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

### Load DRAEM checkpoint

In [None]:
checkpoint_path = '../DRAEM/checkpoints/'
run_name = 'DRAEM_test_0.0001_700_bs8_shavers_'


def _load_subnetwork(network, suffix):
    """Load ``checkpoint_path/run_name + suffix`` into *network* for inference.

    The stored state dict is mapped onto ``device``. The network is then moved
    to ``device`` and switched to eval mode. The same module object is returned.
    """
    weights_file = os.path.join(checkpoint_path, run_name + suffix)
    # NOTE(review): torch.load unpickles the file and can execute arbitrary
    # code -- only load checkpoints from a trusted source. These files hold
    # state dicts, so on torch >= 1.13 weights_only=True is a safer option.
    network.load_state_dict(torch.load(weights_file, map_location=device))
    network.to(device)
    network.eval()
    return network


# Reconstructive sub-network: 3 input channels -> 3 output channels.
model = _load_subnetwork(
    ReconstructiveSubNetwork(in_channels=3, out_channels=3), ".pckl")

# Discriminative sub-network: input and reconstruction stacked along the
# channel axis (3 + 3 = 6 channels) -> 2 output channels.
model_seg = _load_subnetwork(
    DiscriminativeSubNetwork(in_channels=6, out_channels=2), "_seg.pckl")

### DRAEM prediction

In [None]:
def get_draem_prediction(t_image):
    """Run DRAEM on an image batch and return visualisations plus a score.

    Uses the module-level ``model`` (reconstructive sub-network) and
    ``model_seg`` (discriminative sub-network).

    Args:
        t_image: image tensor already on ``device``. The caller in this
            notebook passes ``to_tensor(im)[None, :]``, i.e. a batch of one
            float image.

    Returns:
        ``(original, reconstruction, heatmap, image_score)``:

        - ``original`` and ``reconstruction`` are PIL images of the first
          batch element.
        - ``heatmap`` is the anomaly-probability map of the first batch
          element, min-max normalised and passed through the jet colormap.
        - ``image_score`` is the maximum of the 21x21 locally averaged
          anomaly-probability map. It is taken over the WHOLE batch, not only
          the first element.
    """
    image = t_image

    # Keep inference gradient-free without relying on the notebook-wide
    # torch.set_grad_enabled(False).
    with torch.no_grad():
        # Get reconstruction
        image_rec = model(image)

        # The segmentation network sees reconstruction and input stacked on
        # the channel axis.
        joined_in = torch.cat((image_rec, image), dim=1)
        out_mask_sm = torch.softmax(model_seg(joined_in), dim=1)

        # Softmax channel 1 is used as the anomaly probability map.
        anomaly_map = out_mask_sm[:, 1:, :, :]

        # Image-level anomaly score: max of the 21x21 box-filtered map.
        out_mask_averaged = torch.nn.functional.avg_pool2d(
            anomaly_map, 21, stride=1, padding=21 // 2).cpu().numpy()
        image_score = np.max(out_mask_averaged)

    # Tensors to images. ToPILImage multiplies floats by 255 and casts to
    # uint8 without clipping. The reconstruction is not guaranteed to stay in
    # [0, 1], so clamp it to avoid wrap-around artefacts.
    o = to_image(image[0].cpu())
    r = to_image(image_rec[0].clamp(0, 1).cpu())
    m = to_image(anomaly_map[0].cpu())

    # Gray to heatmap (min-max normalised per image, so colours are relative).
    ma = np.array(m)
    cmap = plt.cm.jet
    norm = plt.Normalize(vmin=ma.min(), vmax=ma.max())
    heatmap = cmap(norm(ma))

    return o, r, heatmap, image_score

In [None]:
# Load the dataset as raw (non-tensor) images so they can be edited first.
shavers_raw = utils.Shavers(
    '<path_to_your_dataset>',
    return_tensors=False,
    dims=(3, 256, 256),
)

# Indices of samples whose target equals 1 (the "good" images).
is_good = np.array(shavers_raw.targets) == 1
inds = np.arange(len(shavers_raw))[is_good]

# Take the first good sample and apply the double-print modification to it.
# NOTE(review): double_print is not defined in this part of the notebook;
# presumably it comes from one of the imported modules -- confirm.
im = np.array(shavers_raw[inds[0]][0])
im = Image.fromarray(double_print(im)[0])

In [None]:
# Convert to a float tensor, add the batch dimension and move it to the device.
batch = to_tensor(im).unsqueeze(0).to(device)
original, reconstructed, heatmap, anomaly_score = get_draem_prediction(batch)