In [None]:
import os
from glob import glob
import model.aotgan 
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid
from tqdm import tqdm
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
from PIL import Image
from torch.utils.data import Dataset
import random
import numpy as np
from torch.utils.data import DataLoader
from torchinfo import summary
import matplotlib.pyplot as plt
from loss1 import loss as loss_module
# Number GPUs by PCI bus id and make only GPU "4" visible to this process,
# then run on CUDA when it is available and fall back to the CPU otherwise.
os.environ.update(
    {
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": "4",
    }
)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Run configuration: model hyper-parameters, optimiser settings and data location.
params = dict(
    image_size=512,
    rates=[1, 2, 4, 8],
    block_num=8,
    model='aotgan',
    gan_type="smgan",
    lrg=1e-4,
    lrd=1e-4,
    beta1=0.5,
    beta2=0.999,
    batch_size=8,
    epochs=500,
    data_path='../../data/dataset/adenoma/',
    num_workers=4,
    rec_loss='1*L1+250*Style+0.1*Perceptual',
)

# Expand the "weight*Name+weight*Name" spec into {Name: float(weight)},
# keeping the order in which the terms were written.
rec_loss_terms = [term.split("*") for term in params['rec_loss'].split("+")]
params['rec_loss'] = {name: float(weight) for weight, name in rec_loss_terms}

In [None]:
class CustomDataset(Dataset):
    """Paired image/mask dataset.

    Images are collected from ``<data_path><dataset>/image/*.jpg``; the mask
    for each image is expected at the same path with ``/image`` replaced by
    ``/mask``.

    Indexing returns ``(image, mask, filename)``:
      * ``image``    - RGB image after augmentation, ``to_tensor(...) * 2.0 - 1.0``
      * ``mask``     - single-channel ("L") mask after the same augmentation
      * ``filename`` - base name of the image file

    Args:
        args: dict providing at least ``'image_size'`` and ``'data_path'``.
        dataset: sub-folder name under ``data_path`` (e.g. ``'train'``).
    """

    def __init__(self, args,dataset):
        super().__init__()
        # Stored for reference; not read anywhere else in this class.
        self.w = self.h = args['image_size']

        # image and mask (sorted so that index -> file is stable between runs)
        self.image_path = sorted(glob(args['data_path']+dataset+'/image/*.jpg'))
        self.mask_path = [i.replace('/image','/mask') for i in self.image_path]

        # augmentation
        # NOTE(review): RandomResizedCrop is left at its default interpolation,
        # which may blend mask edges - confirm whether masks must stay binary.
        self.trans = transforms.Compose(
            [
                transforms.Resize(args['image_size'], interpolation=transforms.InterpolationMode.NEAREST),
                transforms.RandomHorizontalFlip(),
                transforms.RandomRotation((0, 45), interpolation=transforms.InterpolationMode.NEAREST),
                transforms.RandomResizedCrop(args['image_size']),
            ]
        )

    def __len__(self):
        """Number of images found for this split."""
        return len(self.image_path)

    def _paired_transform(self, image, mask):
        """Run ``self.trans`` on image and mask using identical random draws.

        The CPU RNG state is captured before transforming the image and
        rewound before transforming the mask, so both see the same sequence
        of random numbers. Unlike re-seeding with a constant, the global RNG
        keeps advancing, so each call yields a different augmentation.
        """
        state_before = torch.get_rng_state()
        image = self.trans(image)
        state_after = torch.get_rng_state()

        torch.set_rng_state(state_before)
        mask = self.trans(mask)
        # Continue from the post-image state regardless of how many draws the
        # mask pass consumed.
        torch.set_rng_state(state_after)
        return image, mask

    def __getitem__(self, index):
        # load image and its mask
        image = Image.open(self.image_path[index]).convert("RGB")
        filename = os.path.basename(self.image_path[index])
        mask = Image.open(self.mask_path[index]).convert("L")

        # augment both with the same random parameters
        image, mask = self._paired_transform(image, mask)
        image = F.to_tensor(image) * 2.0 - 1.0
        mask = F.to_tensor(mask)

        return image, mask, filename
    
# Build the train / test splits and their loaders.
train_dataset = CustomDataset(params, 'train')
test_dataset = CustomDataset(params, 'test')

# Training: reshuffle every epoch and drop the ragged last batch.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=params['batch_size'],
    shuffle=True,
    drop_last=True,
    num_workers=params['num_workers'],
)
# Evaluation: keep a fixed order and evaluate every sample, including the
# final partial batch.
test_dataloader = DataLoader(
    test_dataset,
    batch_size=params['batch_size'],
    shuffle=False,
    drop_last=False,
    num_workers=params['num_workers'],
)

In [None]:
# Generator and discriminator, each with its own Adam optimiser configured
# from `params` (separate learning rates, shared betas).
adam_betas = (params['beta1'], params['beta2'])

netG = model.aotgan.InpaintGenerator(params).to(device)
optimG = torch.optim.Adam(netG.parameters(), lr=params['lrg'], betas=adam_betas)

netD = model.aotgan.Discriminator().to(device)
optimD = torch.optim.Adam(netD.parameters(), lr=params['lrd'], betas=adam_betas)

# One loss object per reconstruction term, looked up by name in loss_module.
rec_loss_func = {name: getattr(loss_module, name)() for name in params['rec_loss']}
# Adversarial loss is selected by the configured GAN type instead of a literal.
adv_loss = getattr(loss_module, params['gan_type'])()

In [None]:

# Create the checkpoint folder up front so torch.save cannot fail on a
# missing directory after a full epoch of work.
checkpoint_dir = '../../model/aot-model_endoscopy/'
os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(params['epochs']):
    # ------------------------------ training ------------------------------
    netG.train()
    netD.train()
    train=tqdm(train_dataloader)
    count=0
    train_L1_loss = 0.0
    train_Style_loss = 0.0
    train_Perceptual_loss = 0.0
    train_advg_loss = 0.0
    train_advd_loss = 0.0
    for images, masks,filename in train:
        count+=1
        images, masks = images.to(device), masks.to(device)
        # Masked region (mask == 1) is filled with 1; the rest keeps the image.
        images_masked = (images * (1 - masks).float()) + masks
        pred_img = netG(images_masked, masks)
        # Composite: original pixels outside the mask, prediction inside it.
        comp_img = (1 - masks) * images + masks * pred_img

        # Weighted reconstruction terms plus the scaled generator adversarial term.
        losses = {}
        for name, weight in params['rec_loss'].items():
            losses[name] = weight * rec_loss_func[name](pred_img, images)
        dis_loss, gen_loss = adv_loss(netD, comp_img, images, masks)
        losses["advg"] = gen_loss * 0.01

        # Backward: generator losses first, then the discriminator loss; both
        # optimisers step afterwards.
        # NOTE(review): if gen_loss is computed through netD, its backward also
        # accumulates into netD's gradients before optimD.step() - confirm
        # this is intended.
        optimG.zero_grad()
        optimD.zero_grad()
        sum(losses.values()).backward()
        losses["advd"] = dis_loss
        dis_loss.backward()
        optimG.step()
        optimD.step()

        train_L1_loss+=losses['L1'].item()
        train_Style_loss+=losses['Style'].item()
        train_Perceptual_loss+=losses['Perceptual'].item()
        train_advg_loss+=losses['advg'].item()
        train_advd_loss+=losses['advd'].item()
        # `count` is already the 1-based step index.
        train.set_description(f"epoch: {epoch+1}/{params['epochs']} Step: {count} L1 loss : {train_L1_loss/count:.4f} Style loss: {train_Style_loss/count:.4f} Perceptual loss: {train_Perceptual_loss/count:.4f} advg loss: {train_advg_loss/count:.4f} advd loss: {train_advd_loss/count:.4f}")

    # ----------------------------- validation -----------------------------
    netG.eval()
    netD.eval()
    test=tqdm(test_dataloader)
    test_count=0
    test_L1_loss = 0.0
    test_Style_loss = 0.0
    test_Perceptual_loss = 0.0
    test_advg_loss = 0.0
    test_advd_loss = 0.0
    with torch.no_grad():
        for images, masks,filename in test:
            test_count+=1
            images, masks = images.to(device), masks.to(device)
            images_masked = (images * (1 - masks).float()) + masks
            pred_img = netG(images_masked, masks)
            comp_img = (1 - masks) * images + masks * pred_img
            test_losses = {}
            for name, weight in params['rec_loss'].items():
                test_losses[name] = weight * rec_loss_func[name](pred_img, images)
            dis_loss, gen_loss = adv_loss(netD, comp_img, images, masks)
            test_losses["advg"] = gen_loss * 0.01
            test_losses["advd"] = dis_loss
            test_L1_loss+=test_losses['L1'].item()
            test_Style_loss+=test_losses['Style'].item()
            test_Perceptual_loss+=test_losses['Perceptual'].item()
            test_advg_loss+=test_losses['advg'].item()
            test_advd_loss+=test_losses['advd'].item()
            # Average over validation steps (test_count), not the training counter.
            test.set_description(f"val_epoch: {epoch+1}/{params['epochs']} Step: {test_count} L1 loss : {test_L1_loss/test_count:.4f} Style loss: {test_Style_loss/test_count:.4f} Perceptual loss: {test_Perceptual_loss/test_count:.4f} advg loss: {test_advg_loss/test_count:.4f} advd loss: {test_advd_loss/test_count:.4f}")

    # Every 10th epoch, preview the first sample of the most recent batch:
    # original, masked input and generator output.
    if epoch % 10 ==0:
        fig=plt.figure(figsize=(24,8))
        for position, batch in enumerate((images, images_masked, pred_img), start=1):
            fig.add_subplot(1,3,position)
            # CHW -> HWC, undo the `* 2.0 - 1.0` scaling, and clip so imshow
            # receives values inside [0, 1].
            preview = np.transpose(batch[0].detach().cpu().numpy(),(1,2,0))/2+0.5
            plt.imshow(np.clip(preview, 0.0, 1.0))
        plt.show()

    # Checkpoint both networks after every epoch.
    torch.save(netG.state_dict(), checkpoint_dir+'generator_'+str(epoch)+'.pt')
    torch.save(netD.state_dict(), checkpoint_dir+'discriminator_'+str(epoch)+'.pt')

In [None]:
# Side-by-side preview of the first sample in the most recent batch:
# original image, masked input, generator output.
# Each tensor is moved to the CPU, converted CHW -> HWC and mapped with
# /2 + 0.5 (the inverse of the `* 2.0 - 1.0` scaling used for the images).
original_hwc, masked_hwc, predicted_hwc = [
    np.transpose(batch[0].cpu().detach().numpy(), (1, 2, 0)) / 2 + 0.5
    for batch in (images, images_masked, pred_img)
]

fig = plt.figure(figsize=(24, 8))
fig.add_subplot(1, 3, 1)
plt.imshow(original_hwc)
fig.add_subplot(1, 3, 2)
plt.imshow(masked_hwc)
fig.add_subplot(1, 3, 3)
plt.imshow(predicted_hwc)
        

In [None]:
# Sanity check: smallest value of the first image after mapping back with /2 + 0.5.
# Convert to a NumPy array before np.transpose (as the plotting cells do)
# instead of handing NumPy a torch.Tensor.
(np.transpose(images[0].detach().cpu().numpy(), (1, 2, 0)) / 2 + 0.5).min()

In [None]:
plt.imshow(np.transpose(train_dataset[0][0])