In [2]:
import sys
from typing import Optional

sys.path.append("..")
import numpy as np
import torch as th
from torchvision.transforms import v2 as tr

from tqdm.auto import tqdm
from matplotlib import pyplot as plt

In [4]:
from dmfn.models.networks import define_G
from dmfn.data import create_dataset, create_dataloader
from dmfn.data.util import tensor2img

In [5]:
from dataclasses import dataclass

@dataclass
class Args:
    """Options for one run of the inpainting demo driven by ``main``."""

    # Path to the generator weights; passed to ``load_model``.
    checkpoint: str

    # GPU index; ``main`` consults it together with CUDA availability
    # to pick the torch device.
    device: int = 0

    # Seed for the ``np.random.RandomState`` that picks the samples and,
    # when enabled, draws the random mask.
    seed: int = 4242
    # Side length of the coarse random binary grid that is resized
    # (nearest neighbour) to the image resolution.
    mask_size: int = 32
    # True: use a random mask; False: use the mask stored with the sample.
    use_rnd_mask: bool = False

In [6]:
def load_model(ckpt, *, device: th.device, **opt):
    """Build the generator, load its weights and switch it to eval mode.

    Args:
        ckpt: Path (or file-like object) of the saved state dict.
        device: Device the network is moved to and the weights are mapped to.
        **opt: Options forwarded as a single dict to ``define_G``.

    Returns:
        The network returned by ``define_G``, on ``device``, in eval mode.

    Raises:
        FileNotFoundError: If ``ckpt`` is a path that does not exist.
        RuntimeError: If the state dict keys do not match (``strict=True``).
    """
    model = define_G(opt).to(device)
    # map_location keeps loading working when the checkpoint was saved from
    # a device that is not available here (e.g. GPU checkpoint, CPU host).
    state_dict = th.load(ckpt, map_location=device)
    model.load_state_dict(state_dict, strict=True)
    model.eval()
    return model

In [7]:

@dataclass
class DataOpts:
    """Dataset options; ``main`` passes ``vars(DataOpts())`` to ``create_dataset``.

    The field names are the option keys seen by the dataset factory, so they
    must not be renamed. Their exact meaning is defined by ``dmfn.data``
    (not visible here); the comments below are best-effort -- TODO confirm.
    """

    name: str = 'cub'
    batch_size: int = 1
    use_shuffle: bool = True
    n_workers: int = 4
    fineSize: int = 256
    img_shape: tuple = (3, 256, 256)  # [channel, height, width]
    # Text file listing the validation images, relative to the notebook.
    image_list: str = '../../datasets/cub/images_val.txt'
    # No mask list by default, hence Optional rather than plain ``str``.
    mask_list: Optional[str] = None
    # Presumably a fixed rectangular mask placed at the image centre.
    mask_type: str = 'regular'
    mask_pos: str = 'center'
    mask_height: int = 160
    mask_width: int = 160
    vertical_margin: int = 0
    horizontal_margin: int = 0
    max_delta_height: int = 0
    max_delta_width: int = 0

In [8]:
def main(args: Args):
    """Inpaint a few validation images with complementary masks and plot them.

    For each sampled image the network is run on a batch of two inputs: the
    image with one region zeroed out, and the image with the complementary
    region zeroed out. The two outputs are then stitched together so that
    every pixel of the composite comes from a network prediction.

    Args:
        args: Run options (checkpoint path, device index, seed, mask setup).

    Shows one figure per sample with five panels: target, mask, first
    output, second output, stitched composite.
    """
    # Use the CPU when it is requested explicitly (negative index) OR when
    # CUDA is not available; only otherwise address the requested GPU.
    use_cpu = args.device < 0 or not th.cuda.is_available()
    device = th.device("cpu") if use_cpu else th.device(f"cuda:{args.device}")

    model = load_model(
        args.checkpoint,
        device=device,
        network_G=dict(which_model_G='DMFN', in_nc=4, out_nc=3, nf=64, n_res=8),
        is_train=False,
    )
    data_opt = DataOpts()
    data = create_dataset(vars(data_opt), "val")

    rnd = np.random.RandomState(args.seed)
    # Sampling without replacement cannot draw more items than exist.
    n_samples = min(16, len(data))
    idxs = rnd.choice(len(data), n_samples, replace=False)
    for idx in tqdm(idxs):
        batch = data[idx]
        target = batch["target"][None]  # add a leading batch axis

        if args.use_rnd_mask:
            # Coarse random 0/1 grid, blown up to the image resolution.
            coarse = rnd.randint(0, 2, size=(1, 1, args.mask_size, args.mask_size))
            mask = th.as_tensor(coarse.astype(np.float32))
            mask = tr.Resize(target.size()[-2:], interpolation=tr.InterpolationMode.NEAREST)(mask)
        else:
            mask = 1 - batch["mask"][None]

        # Two complementary views; the extra channel flags the zeroed region.
        im1, im2 = target * mask, target * (1 - mask)
        X1 = th.cat([im1, 1 - mask], dim=1).to(device)
        X2 = th.cat([im2, mask], dim=1).to(device)
        X = th.cat([X1, X2], dim=0)

        # Inference only: skip building the autograd graph.
        with th.no_grad():
            out = model(X).float().cpu()

        # Take from each output the region that was hidden in its input.
        res = out[0] * (1 - mask) + out[1] * mask

        out_imgs = tensor2img(out)
        arrs = [tensor2img(target)[0], tensor2img(mask)[0], out_imgs[0], out_imgs[1], tensor2img(res)[0]]
        fig, axs = plt.subplots(ncols=len(arrs), figsize=(16, 9))

        for ax, arr in zip(axs, arrs):
            ax.axis("off")
            ax.imshow(arr)

        fig.tight_layout()
        plt.show()
        # Close this specific figure so figures do not pile up across the loop.
        plt.close(fig)
    

In [9]:
# Run the demo with a random 16x16 mask grid. The checkpoint path is
# relative to the notebook's working directory and must exist, otherwise
# loading fails with FileNotFoundError.
main(
    Args(
        use_rnd_mask=True,
        mask_size=16,
        checkpoint="../outputs/cub/checkpoints/latest_G.pth",
    )
)

FileNotFoundError: [Errno 2] No such file or directory: '../outputs/cub/checkpoints/latest_G.pth'