In [None]:
import datetime
import inspect
import os

import numpy as np
import torch
from skimage import io
from torch.utils.data import DataLoader
from tqdm import tqdm

from csnet.dataset import CS_Dataset
from csnet.utils import show_image_grid

In [None]:
# Test-set locations: network inputs, ground-truth labels ('gt'), and the
# directory where predicted masks are written.
test_root = 'data/test'
input_dir = f'{test_root}/img'
ref_dir = f'{test_root}/gt'
output_dir = f'{test_root}/predicted'

batch_size = 2

# Checkpoint to evaluate. The run id appears both as the directory and as the
# file prefix, so derive both from one value to keep them in sync.
run_id = '2022-06-16-16-14-14'
checkpoint_idx = 8  # presumably the epoch the checkpoint was saved at — TODO confirm
model_path = f'model/{run_id}/{run_id}_{checkpoint_idx}.pkl'

In [None]:
ds = CS_Dataset(input_dir, ref_dir)
# shuffle=False keeps batches in dataset order; the save cell relies on that.
dl = DataLoader(ds, batch_size=batch_size, num_workers=batch_size, shuffle=False)

# The checkpoint is a whole pickled model (it is called and .eval()'d below),
# not a state_dict. torch>=2.6 defaults torch.load to weights_only=True, which
# rejects such files, so opt out explicitly where the argument exists
# (older torch does not accept it).
# NOTE: unpickling can execute arbitrary code — only load checkpoints you trust.
_load_kwargs = {'weights_only': False} if 'weights_only' in inspect.signature(torch.load).parameters else {}
# map_location avoids first materialising tensors on the GPU index they were saved from.
net = torch.load(model_path, map_location='cuda', **_load_kwargs).cuda()

In [None]:
# Run the network over the whole test set once, keeping inputs, ground truth and
# argmax-over-dim-1 predictions for the save and visualisation cells below.
imgs = []
lbls = []
predictions = []

net.eval()
with torch.no_grad():
    for batch in dl:
        image, label = batch[0], batch[1]
        logits = net(image.cuda())  # presumably (N, C, H, W) class scores — TODO confirm
        # Keep the accumulated tensors in host memory: storing every batch on the
        # GPU makes device memory grow with the size of the test set.
        imgs.append(image)
        lbls.append(label)
        predictions.append(logits.argmax(dim=1).cpu())
imgs = torch.concat(imgs)
lbls = torch.concat(lbls)
predictions = torch.concat(predictions)

In [None]:
# Write each prediction as an image under output_dir, mirroring its path
# relative to input_dir.
os.makedirs(output_dir, exist_ok=True)
# zip() would silently truncate on a length mismatch. Assuming CS_Dataset indexes
# items in ds.image_fns order (the loader uses shuffle=False), the counts must agree.
assert len(predictions) == len(ds.image_fns), (len(predictions), len(ds.image_fns))
for pred, fn in zip(tqdm(predictions), ds.image_fns):
    # Derive the output path from the path relative to input_dir instead of
    # str.replace: if fn is spelled differently (absolute/normalised), replace()
    # is a no-op and the prediction would overwrite the input image.
    rel_path = os.path.relpath(fn, input_dir)
    if rel_path == os.pardir or rel_path.startswith(os.pardir + os.sep):
        raise ValueError(f'{fn!r} is not under {input_dir!r}; refusing to guess an output path')
    out_path = os.path.join(output_dir, rel_path)
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    # NOTE(review): assumes binary predictions (class 0/1 -> 0/255); class indices
    # above 1 would overflow uint8 — confirm the number of classes.
    mask = (pred.cpu().numpy() * 255).astype(np.uint8)
    # All-background masks would otherwise emit skimage's low-contrast warning per file.
    io.imsave(out_path, mask, check_contrast=False)

In [None]:
# Visual check: inputs, ground-truth labels and predictions side by side. The
# predictions get a new axis at dim 1, presumably to match the channel layout
# of the other two — TODO confirm against show_image_grid.
show_image_grid([imgs, lbls, predictions[:, None]])