In [1]:
import os
import argparse
from utils import util
from models.cut_seg import CUT_SEG_model
import torch
from utils.create_dataset import EczemaDataset
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import os
import pickle

  from .autonotebook import tqdm as notebook_tqdm


<h1> pick 100 images from source and target folder for evaluation </h1>

In [None]:
import os
from glob import glob
from PIL import Image
import random

In [None]:
source_path = 'evaluation/source'
source_img_path = []

for ext in ('*.jpg', '*.png', '*.JPG', '*.PNG'):
    source_img_path.extend(glob(os.path.join(source_path, ext)))

In [None]:
source_img_path = random.sample(source_img_path, 100)

In [None]:
import shutil

saved_source_path = 'evaluation/clean_source'
for i, path in enumerate(source_img_path):
    shutil.copy(path, f'{saved_source_path}/source={i + 1}.png')


In [None]:
target_path = 'evaluation/target'
target_img_path = []

for ext in ('*.jpg', '*.png', '*.JPG', '*.PNG'):
    target_img_path.extend(glob(os.path.join(target_path, ext)))

In [None]:
target_img_path = random.sample(target_img_path, 100)

In [None]:
saved_target_path = 'evaluation/clean_target'
for i, path in enumerate(target_img_path):
    shutil.copy(path, f'{saved_target_path}/target={i + 1}.png')

<h1> plot the loss curve </h1></br>
<h4> plot netG loss </h4>

In [None]:
loss_path = 'output_HPC/victor_demo_v4/loss/loss_400'

In [None]:
loss = []
with (open(loss_path, "rb")) as openfile:
    while True:
        try:
            loss.append(pickle.load(openfile))
        except EOFError:
            break
loss = loss[0]

In [None]:
loss_G = [l['G'] for l in loss]
loss_GAN = [l['G_GAN'] for l in loss]

In [None]:
plt.plot(loss_G)
plt.xlabel('epoch')
plt.ylabel('netG_loss')

In [None]:
plt.plot(loss_GAN)
plt.xlabel('epoch')
plt.ylabel('netG_loss')

<h1>Load the network</h1>

In [2]:
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else: device = torch.device('cpu')

In [3]:
device

device(type='cuda')

In [4]:
def ArgParse():
    parser = argparse.ArgumentParser(description='CUT inference usage.')
    # Evaluation
    # model parameters
    """GAN parameters"""
    parser.add_argument('--CUT_mode', type=str, default="CUT", choices=['CUT', 'cut', 'FastCUT', 'fastcut'], help='')
    parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')
    parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')
    parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')
    parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')
    parser.add_argument('--netD', type=str, default='basic', choices=['basic', 'n_layers'], help='specify discriminator architecture. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')
    parser.add_argument('--netG', type=str, default='resnet_9blocks', choices=['resnet_9blocks', 'resnet_6blocks'], help='specify generator architecture')
    parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')
    parser.add_argument('--normG', type=str, default='instance', choices=['instance', 'batch', 'none'], help='instance normalization or batch normalization for G')
    parser.add_argument('--normD', type=str, default='instance', choices=['instance', 'batch', 'none'], help='instance normalization or batch normalization for D')
    parser.add_argument('--init_type', type=str, default='xavier', choices=['normal', 'xavier', 'kaiming', 'orthogonal'], help='network initialization')
    parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
    parser.add_argument('--no_dropout', type=util.str2bool, nargs='?', const=True, default=True,
                        help='no dropout for the generator')
    parser.add_argument('--antialias', action='store_true', help='if specified, use antialiased-downsampling')
    parser.add_argument('--antialias_up', action='store_true', help='if specified, use [upconv(learned filter)]')
    parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss：GAN(G(X))')
    """netF paramters"""
    parser.add_argument('--lambda_NCE', type=float, default=1.0, help='weight for NCE loss: NCE(G(X), X)')
    parser.add_argument('--nce_idt', type=util.str2bool, nargs='?', const=True, default=True, help='use NCE loss for identity mapping: NCE(G(Y), Y))')
    parser.add_argument('--nce_layers', type=str, default='0,3,5,7,11', help='compute NCE loss on which layers')
    parser.add_argument('--netF', type=str, default='mlp_sample', choices=['sample', 'reshape', 'mlp_sample'], help='how to downsample the feature map')
    parser.add_argument('--netF_nc', type=int, default=256)
    parser.add_argument('--nce_T', type=float, default=0.07, help='temperature for NCE loss')
    parser.add_argument('--num_patches', type=int, default=256, help='number of patches per layer')
    """netS parameters"""
    parser.add_argument('--netS', type=str, default='resnet', choices=['resnet', 'unet_256', 'smp'], help='how to segment the input image')
    parser.add_argument('--smp_arch', type=str, default='Unet', help='the segmentor architectur')
    parser.add_argument('--smp_encoder', type=str, default='efficientnet-b3', help='the encoder name')
    parser.add_argument('--normS', type=str, default='instance', choices=['instance', 'batch', 'none'], help='instance normalization or batch normalization for S')
    parser.add_argument('--num_class', type=int, default=2, help='# of output image channels for segmented mask')
    parser.add_argument('--netS_lambda', type=int, default=10, help='lambda for SEG loss')
    parser.add_argument('--netS_Loss', type=str, help='semantic segmentation loss function', choices=['dice', 'bce', 'DICE', 'BCE'], default='bce')
    parser.add_argument('--flip_equivariance',
                        type=bool, nargs='?', default=False,
                        help="Enforce flip-equivariance as additional regularization. It's used by FastCUT, but not CUT")

    parser.add_argument('--src_dir', help='source dataset folder', type=str, default='evaluation/clean_source')
    parser.add_argument('--out_dir', help='output folder', type=str, default='evaluation/v8/')
    parser.add_argument('--name', type=str, default='demo_v4', help='name of the experiment. It decides where to store samples and models')
    parser.add_argument('--easy_label', type=str, default='demo_v4', help='Interpretable name')
    parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')

    
    parser.add_argument('--isTrain', type=util.str2bool, default=False)


    opt, _ = parser.parse_known_args()

    # Set default parameters for CUT and FastCUT
    if opt.CUT_mode.lower() == "cut":
        parser.set_defaults(nce_idt=True, lambda_NCE=1.0)
    elif opt.CUT_mode.lower() == "fastcut":
        parser.set_defaults(
            nce_idt=False, lambda_NCE=10.0, flip_equivariance=True,
            n_epochs=150, n_epochs_decay=50
        )
    else:
        raise ValueError(opt.CUT_mode)

    return opt


In [5]:
opt = ArgParse()
opt

Namespace(CUT_mode='CUT', input_nc=3, output_nc=3, ngf=64, ndf=64, netD='basic', netG='resnet_9blocks', n_layers_D=3, normG='instance', normD='instance', init_type='xavier', init_gain=0.02, no_dropout=True, antialias=False, antialias_up=False, lambda_GAN=1.0, lambda_NCE=1.0, nce_idt=True, nce_layers='0,3,5,7,11', netF='mlp_sample', netF_nc=256, nce_T=0.07, num_patches=256, netS='resnet', smp_arch='Unet', smp_encoder='efficientnet-b3', normS='instance', num_class=2, netS_lambda=10, netS_Loss='bce', flip_equivariance=True, src_dir='evaluation/clean_source', out_dir='evaluation/v8/', name='demo_v4', easy_label='demo_v4', checkpoints_dir='./checkpoints', isTrain=False)

In [6]:
model = CUT_SEG_model(opt=opt)

using device: cuda


<h4>load generator</h4>

In [7]:
netG = getattr(model, 'netG')

In [8]:
if isinstance(netG, torch.nn.DataParallel):
    netG = netG.module

In [10]:
load_path = 'checkpoints_HPC/demo_v8/400_net_G.pth'

In [11]:
state_dict = torch.load(load_path, map_location=str(device))

In [12]:
if hasattr(state_dict, '_metadata'):
    del state_dict._metadata

In [13]:
netG.load_state_dict(state_dict)

<All keys matched successfully>

<h4> create the dataset </h4>

In [14]:
import torch.utils.data as data
import os
from glob import glob
import numpy as np
from torch.autograd import Variable
from PIL import Image
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF
import torch.nn.functional as F
from utils.util import img2tensor
import random



class EvaluationDS(data.Dataset):
    
    def __init__(self, src_img_path):
        super().__init__()
        
        self.src_img_path = []
        for ext in ('*.jpg', '*.png', '*.JPG', '*.PNG'):
            self.src_img_path.extend(glob(os.path.join(src_img_path, ext)))

        
    def __len__(self):
        return len(self.src_img_path)
    
    
    def transform(self, real_img):
        # Resize
        resize = T.Resize(size=(286,286), interpolation=T.InterpolationMode.NEAREST)
        real_img = resize(real_img)

        # Random crop
        i, j, h, w = T.RandomCrop.get_params(
            real_img, output_size=(256, 256))
        real_img = TF.crop(real_img, i, j, h, w)
        

        # Random horizontal flipping
        if random.random() > 0.5:
            real_img = TF.hflip(real_img)
            

        return real_img
    
    def normalize(self, real_img):

        normalize = T.Normalize(mean=(127.5,127.5,127.5), std=(127.5,127.5,127.5))
        normalize_real_img = normalize(real_img)
        normalize_real_img = normalize_real_img.squeeze(0)
        
        return normalize_real_img
        
    
    def __getitem__(self, index):
        src_img = np.asarray(Image.open(self.src_img_path[index]))
        src_img = img2tensor(src_img)
        
        t_src_img = self.transform(src_img)

        # normalize the imgs
        n_src_img = self.normalize(t_src_img)
        
        return n_src_img

In [15]:
opt.src_dir

'evaluation/clean_source'

In [16]:
source = EvaluationDS(opt.src_dir)

In [17]:
src_dataloader = DataLoader(source, batch_size=1, shuffle=True)

<h4> generate the translated image and save </h4>

In [18]:
compare_path = os.path.join(opt.out_dir, 'compare')
for idx, data in enumerate(src_dataloader):
    # only translate 100 images
    data = data.to(device)
    # source
    source_img = util.tensor2img(data)
    source_img = (source_img * 127.5 + 127.5).astype(np.uint8)

    # translated
    translated = netG(data)
    translated = util.tensor2img(translated)
    translated = (translated * 127.5 + 127.5).astype(np.uint8)

    _, ax = plt.subplots(1, 2, figsize=(20, 20))
    ax[0].imshow(source_img)
    ax[1].imshow(translated)
    ax[0].set_title("Input")
    ax[1].set_title("Translated")
    ax[0].axis("off")
    ax[1].axis("off")

    plt.savefig(f'{compare_path}/infer={idx + 1}.png')
    plt.close()


In [19]:
translated_path = os.path.join(opt.out_dir, 'translated')
for idx, data in enumerate(src_dataloader):
    data = data.to(device)

    # translated
    translated = netG(data)
    translated = util.tensor2img(translated)
    translated = (translated * 127.5 + 127.5).astype(np.uint8)

    plt.figure(figsize=(20, 20))
    plt.imshow(translated)
    plt.axis('off')

    plt.savefig(f'{translated_path}/infer={idx + 1}.png')
    plt.close()

: 