In [None]:
import os
from glob import glob
import model.aotgan 
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid
from tqdm import tqdm
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
from PIL import Image
from torch.utils.data import Dataset
import random
import numpy as np
from torch.utils.data import DataLoader
from torchinfo import summary
import matplotlib.pyplot as plt
from loss1 import loss as loss_module
# Number CUDA devices by PCI bus id so the index used below follows the PCI bus order.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

# Use GPU index 3 when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:3")
else:
    device = torch.device("cpu")

In [None]:
# Run configuration. The whole dict is handed to the generator constructor and
# the dataset; the optimizer and DataLoader cells read individual keys.
params = dict(
    image_size=512,          # images and masks are resized to image_size x image_size
    rates=[1, 2, 4, 8],
    block_num=8,
    model='aotgan',
    gan_type="smgan",
    lrg=1e-4,                # Adam learning rate for the generator
    lrd=1e-4,                # Adam learning rate for the discriminator
    beta1=0.5,               # Adam betas=(beta1, beta2)
    beta2=0.999,
    batch_size=8,
    epochs=500,
    data_path='../../data/dataset/colon/',
    num_workers=4,
    rec_loss='1*L1+250*Style+0.1*Perceptual',
)

# Turn the "weight*name+weight*name" spec into {name: float(weight)}.
# `losses` keeps the individual "weight*name" terms.
losses = params['rec_loss'].split("+")
params['rec_loss'] = {
    name: float(weight)
    for weight, name in (term.split("*") for term in losses)
}

In [None]:
class CustomDataset(Dataset):
    """Paired image/mask dataset backed by PNG files on disk.

    Images are read from ``args['data_path'] + dataset + '/image/*.png'``; the
    mask for each image is looked up by replacing ``'/image'`` with ``'/mask'``
    in the image path.

    Args:
        args: configuration dict; reads ``'image_size'`` and ``'data_path'``.
        dataset: name of the split sub-directory (e.g. ``'test'``).

    Each item is ``(image, mask, filename)``:
        image: float tensor from ``to_tensor`` rescaled to [-1, 1] (RGB).
        mask: float tensor from ``to_tensor`` in [0, 1] (single channel, mode "L").
        filename: base name of the image file.
    """

    def __init__(self, args,dataset):
        # Zero-argument super() resolves to the real parent (Dataset); the
        # previous super(Dataset, self) skipped Dataset in the MRO.
        super().__init__()
        self.w = self.h = args['image_size']

        # image and mask
        # glob() returns files in arbitrary order; sort so that an index always
        # refers to the same file from run to run.
        self.image_path = sorted(glob(args['data_path']+dataset+'/image/*.png'))
        # NOTE(review): replace() rewrites every '/image' occurrence in the path,
        # including a file name that itself starts with 'image' - confirm this
        # matches the on-disk mask naming.
        self.mask_path = [i.replace('/image','/mask') for i in self.image_path]
        # Nearest-neighbour resize is applied to both images and masks.
        self.trans_1 = transforms.Compose(
            [
                transforms.Resize((args['image_size'],args['image_size']), interpolation=transforms.InterpolationMode.NEAREST)
            ]
        )

    def trans(self,image_t,a):
        """Resize ``image_t``, rotate it by ``a`` degrees and convert it to a tensor."""
        image_t=F.to_tensor(F.rotate(self.trans_1(image_t),a))
        return image_t

    def __len__(self):
        """Number of image files found for this split."""
        return len(self.image_path)

    def __getitem__(self, index):
        """Load, resize and tensorize the image/mask pair at ``index``."""
        image_file = self.image_path[index]
        filename = os.path.basename(image_file)

        # Image.open is lazy and keeps the file handle open; convert() forces the
        # pixel data to be read, so the handle can be released right away.
        with Image.open(image_file) as raw_image:
            image = raw_image.convert("RGB")
        with Image.open(self.mask_path[index]) as raw_mask:
            mask = raw_mask.convert("L")

        # No augmentation is applied here (the unused random angle was removed).
        image = F.to_tensor(self.trans_1(image)) * 2.0 - 1.0  # [0, 1] -> [-1, 1]
        mask =F.to_tensor(self.trans_1(mask))

        return image, mask, filename
    
test_dataset=CustomDataset(params,'test')
# Inference only: iterate the test files in a fixed order so that the saved
# results and the displayed samples are reproducible between runs.
# NOTE(review): drop_last=True means up to batch_size-1 test images are never
# processed. Only switch it to False once every consumer of this loader copes
# with a final batch that is shorter than batch_size.
test_dataloader = DataLoader(
        test_dataset,
        batch_size=params['batch_size'],
        shuffle=False,
        drop_last=True)

In [None]:
# Build the generator and discriminator together with their Adam optimizers.
netG =model.aotgan.InpaintGenerator(params).to(device)
optimG = torch.optim.Adam(netG.parameters(), lr=params['lrg'], betas=(params['beta1'], params['beta2']))

netD = model.aotgan.Discriminator().to(device)
optimD = torch.optim.Adam(netD.parameters(), lr=params['lrd'], betas=(params['beta1'], params['beta2']))

# This notebook only runs inference, so put both networks in evaluation mode
# (modules get training=False; layers whose behavior depends on that flag now
# use their inference path).
netG.eval()
netD.eval()

# Restore the saved weights onto the selected device.
# torch.load unpickles the file - only load checkpoints from a trusted source.
netG.load_state_dict(torch.load('../../model/aot-model_colon/generator_62.pt',map_location=device))
netD.load_state_dict(torch.load('../../model/aot-model_colon/discriminator_62.pt',map_location=device))

In [None]:
# Summarize the generator for one batch of inputs: a 3-channel image tensor
# followed by a 1-channel mask tensor, both image_size x image_size.
summary(
    netG,
    input_size=tuple(
        (params['batch_size'], channels, params['image_size'], params['image_size'])
        for channels in (3, 1)
    ),
)

In [None]:
save_path='../../data/dataset/colon/test/result/'
# PIL's save() does not create missing directories, so make sure the target exists.
os.makedirs(save_path, exist_ok=True)

# Run the generator over the test loader, save every prediction under its
# source file name, and show one (image, masked image, prediction) triple per batch.
with torch.no_grad():
    test=tqdm(test_dataloader)
    count=0
    for images, masks,filename in test:
        count+=1
        images, masks = images.to(device), masks.to(device)
        # Where the mask is 1 the pixel is replaced by 1.0; elsewhere the image is kept.
        images_masked = (images * (1 - masks).float()) + masks
        pred_img = netG(images_masked, masks)
        # Composite: original pixels outside the mask, generated pixels inside it.
        # It is not written to disk here (the raw prediction is what gets saved).
        comp_img = (1 - masks) * images + masks * pred_img

        pred_np = pred_img.cpu().numpy()
        # Iterate over the samples actually present in this batch rather than the
        # configured batch_size, so a shorter final batch cannot raise IndexError.
        for i in range(len(filename)):
            # x/2+0.5 undoes the [-1, 1] scaling applied to the inputs (assumes the
            # generator output uses the same range - TODO confirm). Clipping keeps
            # out-of-range values from wrapping around in the uint8 cast.
            img3=np.transpose(np.clip((pred_np[i]/2+0.5)*255, 0, 255),(1,2,0))
            pillow_pred=Image.fromarray(img3.astype(np.uint8))
            pillow_pred.save(save_path+filename[i])

        # Preview the first sample of the batch.
        fig=plt.figure(figsize=(24,8))
        panels = (
            ("image", images[0]),
            ("masked image", images_masked[0]),
            ("prediction", pred_img[0]),
        )
        for column, (title, tensor) in enumerate(panels, start=1):
            axis = fig.add_subplot(1,3,column)
            axis.imshow(np.transpose(tensor.cpu().numpy(),(1,2,0))/2+0.5)
            axis.set_title(title)
        plt.show()
        # Release the figure so a long test set does not accumulate open figures.
        plt.close(fig)

In [None]:
# Convert the first sample of the last processed batch to channel-last arrays
# scaled to 0..255. Image and prediction are first mapped back with x/2+0.5;
# the mask only needs the 255 scale.
img1 = np.transpose(
    (images[0].cpu().detach().numpy() / 2 + 0.5) * 255,
    (1, 2, 0),
)
img2 = np.transpose(
    masks[0].cpu().detach().numpy() * 255,
    (1, 2, 0),
)
img3 = np.transpose(
    (pred_img[0].cpu().detach().numpy() / 2 + 0.5) * 255,
    (1, 2, 0),
)

In [None]:
# Wrap the float arrays as 8-bit PIL images.
pillow_image = Image.fromarray(img1.astype(np.uint8))
# The mask array has a single trailing channel; take it to get a 2-D image.
pillow_mask = Image.fromarray(img2[:, :, 0].astype(np.uint8))
pillow_pred = Image.fromarray(img3.astype(np.uint8))
# Last expression of the cell: renders the prediction inline.
pillow_pred