In [1]:
import sys
print(sys.executable)

/opt/conda/bin/python


In [2]:
import numpy as np
import skimage.io as io
import random
import os
import albumentations as A
import cv2
import pickle
import torch
from pycocotools.coco import COCO
from torch.utils.data import Dataset, DataLoader
from albumentations.pytorch import ToTensorV2
import glob

In [3]:
torch.cuda.empty_cache()

In [4]:
import random

def param2theta(param, w, h):
    param = np.linalg.inv(param)
    theta = np.zeros([2,3])
    theta[0,0] = param[0,0]
    theta[0,1] = param[0,1]*h/w
    theta[0,2] = param[0,2]*2/w + theta[0,0] + theta[0,1] - 1
    theta[1,0] = param[1,0]*w/h
    theta[1,1] = param[1,1]
    theta[1,2] = param[1,2]*2/h + theta[1,0] + theta[1,1] - 1
    return theta

class MSCOCODataset(Dataset):
    def __init__(self, args=None, aug_trans=None, basic_trans=None):
        self.args = args
        self.images = sorted(glob.glob(args.data_dir + '/*'))
        self.ann_file = '{}/annotations/instances_{}.json'.format(args.ann_dir, args.mode)
        self.coco = COCO(self.ann_file)
        self.aug_trans = aug_trans
        self.basic_trans = basic_trans
        self.img_size = args.input_size
        self.center = (self.img_size/2, self.img_size/2)
        
    def __len__(self):
        return len(self.images)//4

    def __getitem__(self, idx):
        annId = int(self.images[idx*4].split('_')[0].split('/')[-1])
        # print('annId:', annId)
        # file_name, i, img_type = self.imxages[idx*4].split('_')
        # print(self.images[idx*4], self.images[idx*4+1], self.images[idx*4+2], self.images[idx*4+3])
    
        bg = self.getImage(self.images[idx*4])
        fg = self.getImage(self.images[idx*4+1])
        gt = self.getImage(self.images[idx*4+2])
        mask = self.getMask(self.images[idx*4+3])
      
        ##########

        angle_factor = random.randint(-10, 10)
        scale_factor = round(random.uniform(0.95, 1.05), 2)
        translation_factor_x = round(random.uniform(0, self.img_size*0.05), 2)
        translation_factor_y = round(random.uniform(0, self.img_size*0.05), 2)
        rotate_matrix = cv2.getRotationMatrix2D(center=self.center, angle=angle_factor, scale=0.95)

        new_row = np.array([[0, 0, 1]], dtype=np.float32)
        rotate_matrix = np.concatenate((rotate_matrix, new_row), axis=0)
        translation_matrix = np.array([[1, 0, translation_factor_x],
                                       [0, 1, translation_factor_y],
                                       [0, 0, 1],], dtype=np.float32)

        matrix = np.matmul(rotate_matrix, translation_matrix)[:2][:3]
        matrix = np.concatenate((matrix, new_row), axis=0)
        matrix = param2theta(matrix, self.img_size, self.img_size)

        # matrix = np.array([[1, 0, 1],
        #                 [0, 1, 0],], dtype=np.float32)
        # new_row = np.array([[0, 0, 1]], dtype=np.float32)
        # matrix = np.concatenate((matrix, new_row), axis=0)
        # matrix = param2theta(matrix, self.img_size, self.img_size)

        tf_fg = cv2.warpAffine(src=fg, M=matrix, dsize=(self.img_size, self.img_size))
        tf_mask = cv2.warpAffine(src=mask, M=matrix, dsize=(self.img_size, self.img_size))
        
        ###########
        
        bg = self.basic_trans(image=bg)['image']
        gt = self.basic_trans(image=gt)['image']
        
        tf_item = self.aug_trans(image=tf_fg, mask=tf_mask) #only color aug applied
        tf_fg = tf_item['image']
        tf_mask = tf_item['mask'].unsqueeze(0)
    
        tf_item = self.basic_trans(image=fg, mask=mask)
        fg = tf_item['image']
        mask = tf_item['mask'].unsqueeze(0)
        
        return gt, fg, bg, mask>0.5, tf_fg, tf_mask>0.5, matrix#, cat
    
    def getClassName(self, classID, cats):
        for i in range(len(cats)):
            if cats[i]['id'] == classID:
                return cats[i]['name']
        return "None"
    
    def getImage(self, file_name):
        img = cv2.imread(file_name, cv2.IMREAD_COLOR)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        return img
    
    def getMask(self, file_name):
        mask = cv2.imread(file_name, cv2.IMREAD_GRAYSCALE)
        return mask

In [5]:
import torch.nn as nn
import torch.nn.functional as F
import math

class GenCompModel(nn.Module):
    def __init__(self, args):
        super(GenCompModel, self).__init__()
        self.args = args
        self.img_size = args.input_size
        self.batch_size = args.batch_size
        self.stn = STN(7) # fg+bg+mask channels
        self.colornet = LinearWithChannel(self.batch_size, self.img_size, self.img_size, 3) 
        self.refinenet = TransformNetwork(3, 3)
    
    def forward(self, fg, bg, mask):
        HI, AI, trans_mat = self.stn(fg, bg, mask)
        FI = self.colornet(HI) # changed to inputs as I
        AI = AI > 0.5
        R_in = torch.multiply(FI, AI) + torch.multiply(bg, ~AI)
        R_out = self.refinenet(R_in)
        return AI, HI, FI, trans_mat, R_in, R_out
    
class Discriminator(nn.Module):    
    def __init__(self):        
        super(Discriminator, self).__init__()
        self.imgdisc = ImageDiscriminator(3, norm='spec')
        self.segnet = TransformNetwork(3, 1)
        
    def forward(self, mask, R_out):
        img_out = self.imgdisc(R_out)
        fg_seg_out = self.segnet(torch.multiply(mask, R_out), last='sigmoid')
        bg_seg_out = self.segnet(torch.multiply(~mask, R_out), last='sigmoid')

        return img_out, fg_seg_out, bg_seg_out
    
class TransformNetwork(nn.Module):
    def __init__(self, in_ch, out_ch):        
        super(TransformNetwork, self).__init__()        
        
        self.layers = nn.Sequential(            
            ConvLayer(in_ch, 32, 9, 1),
            ConvLayer(32, 64, 3, 2),
            ConvLayer(64, 128, 3, 2),
            
            ResidualLayer(128, 128, 3, 1),
            ResidualLayer(128, 128, 3, 1),
            ResidualLayer(128, 128, 3, 1),
            ResidualLayer(128, 128, 3, 1),
            ResidualLayer(128, 128, 3, 1),
            
            DeconvLayer(128, 64, 3, 1),
            DeconvLayer(64, 32, 3, 1),
            ConvLayer(32, out_ch, 9, 1, activation='linear'))
        
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, x, last=None):
        x = self.layers(x)
        if last:
            x = self.sigmoid(x)
        return x
    
class LinearWithChannel(nn.Module):
    def __init__(self, batch_size, input_size, output_size, channel_size):
        super(LinearWithChannel, self).__init__()
        self.w = torch.nn.Parameter(torch.empty(channel_size))
        self.b = torch.nn.Parameter(torch.zeros(channel_size))
    
    def forward(self, x):
        return x * self.w.view(1, 3, 1, 1) + self.b.view(1, 3, 1, 1)
    
class STN(nn.Module):
    def __init__(self, in_ch):
        super(STN, self).__init__()
        self.in_ch = in_ch
        
        # localization-network for STN
        self.localization = nn.Sequential(
            ConvLayer(self.in_ch, 16, 3, 1), # 256
            ConvLayer(16, 16, 3, 1),
            ConvLayer(16, 16*2, 3, 2), # 128
            ConvLayer(16*2, 16*2, 3, 1),
            ConvLayer(16*2, 16*4, 3, 2), # 64
            ConvLayer(16*4, 16*4, 3, 1),
            ConvLayer(16*4, 16*8, 3, 2), # 32
            ConvLayer(16*8, 16*8, 3, 1), 
            ConvLayer(16*8, 16*16, 3, 2), # 16
            ConvLayer(16*16, 16*16, 3, 1), 
            ConvLayer(16*16, 16*32, 3, 2), # 8
            ConvLayer(16*32, 16*32, 3, 1), 
        )

        # [3 * 2] 크기의 아핀(affine) 행렬에 대해 예측
        self.fc_loc = nn.Sequential(
            nn.Linear(16*32*8*8, 32*8*8),
            nn.ReLU(True),
            nn.Linear(32*8*8, 8*8),
            nn.ReLU(True),
            nn.Linear(8*8, 16),
            nn.ReLU(True),
            nn.Linear(16, 2*3),
        )

        # 항등 변환(identity transformation)으로 가중치/바이어스 초기화
        self.fc_loc[6].weight.data.zero_()
        self.fc_loc[6].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

    # STN의 forward 함수
    def stn(self, fg, bg, mask):
        # x = x.type(torch.cuda.FloatTensor)
        mask = mask.float()
        inputs = torch.cat([fg, bg, mask], dim=1) #[B, (3+3+1), 256, 256]
        xs = self.localization(inputs)
        # print('xs shape:', xs.shape) #torch.Size([4, 10, 60, 60]) # [1, 128, 32, 32])
        xs = xs.view(-1, 16*32*8*8) #xs = xs.view(-1, 10 * 3 * 3)
        theta = self.fc_loc(xs)
        theta = theta.view(-1, 2, 3)

        grid = F.affine_grid(theta, fg.size(), align_corners=False)
        fg = F.grid_sample(fg, grid, align_corners=False)

        grid = F.affine_grid(theta, mask.size(), align_corners=False)
        mask = F.grid_sample(mask, grid, align_corners=False)
        
        # return fg, mask>0.5, theta
        return fg, mask, theta

    def forward(self, fg, bg, mask):
        # 입력을 변환
        fg, mask, trans_mat = self.stn(fg, bg, mask)
        return fg, mask, trans_mat
    
class ConvLayer(nn.Module):    
    def __init__(self, in_ch, out_ch, kernel_size, stride, pad='reflect', activation='leaky', normalization='batch'):        
        super(ConvLayer, self).__init__()
        
        # padding
        if pad == 'reflect':            
            self.pad = nn.ReflectionPad2d(kernel_size//2)
        elif pad == 'zero':
            self.pad = nn.ZeroPad2d(kernel_size//2)
        else:
            raise NotImplementedError("Not expected pad flag !!!")
    
            
        # convolution
        self.conv_layer = nn.Conv2d(in_ch, out_ch, 
                                    kernel_size=kernel_size,
                                    stride=stride)
        if normalization == 'spec':
            self.conv_layer = nn.utils.spectral_norm(self.conv_layer)
           
        
        # activation
        if activation == 'relu':
            self.activation = nn.ReLU()     
        elif activation == 'leaky':
            self.activation = nn.LeakyReLU(0.2)
        elif activation == 'linear':
            self.activation = lambda x : x

        else:
            raise NotImplementedError("Not expected activation flag !!!")

        # normalization 
        if normalization == 'instance':            
            self.normalization = nn.InstanceNorm2d(out_ch, affine=True)
        elif normalization == 'batch':
            self.normalization = nn.BatchNorm2d(out_ch, affine=True)
        elif normalization == 'spec':
            self.normalization = None
        else:
            raise NotImplementedError("Not expected normalization flag !!!")

    def forward(self, x):
        x = self.pad(x)
        x = self.conv_layer(x)
        if self.normalization:
            x = self.normalization(x)
        x = self.activation(x)        
        return x
    
class ResidualLayer(nn.Module):    
    def __init__(self, in_ch, out_ch, kernel_size, stride, pad='reflect', normalization='batch'):        
        super(ResidualLayer, self).__init__()
        
        self.conv1 = ConvLayer(in_ch, out_ch, kernel_size, stride, pad, 
                               activation='relu', 
                               normalization=normalization)
        
        self.conv2 = ConvLayer(out_ch, out_ch, kernel_size, stride, pad, 
                               activation='linear', 
                               normalization=normalization)
        
    def forward(self, x):
        y = self.conv1(x)
        return self.conv2(y) + x
    
class DeconvLayer(nn.Module):    
    def __init__(self, in_ch, out_ch, kernel_size, stride, pad='reflect', activation='leaky', normalization='batch', upsample='nearest'):        
        super(DeconvLayer, self).__init__()
        
        # upsample
        self.upsample = upsample
        
        # pad
        if pad == 'reflect':            
            self.pad = nn.ReflectionPad2d(kernel_size//2)
        elif pad == 'zero':
            self.pad = nn.ZeroPad2d(kernel_size//2)
        else:
            raise NotImplementedError("Not expected pad flag !!!")        
        
        # conv
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size, stride)
        
        # activation
        if activation == 'relu':
            self.activation = nn.ReLU()
        elif activation == 'leaky':
            self.activation = nn.LeakyReLU(0.2)
        elif activation == 'linear':
            self.activation = lambda x : x
        else:
            raise NotImplementedError("Not expected activation flag !!!")
        
        # normalization
        if normalization == 'instance':
            self.normalization = nn.InstanceNorm2d(out_ch, affine=True)
        elif normalization == 'batch':
            self.normalization = nn.BatchNorm2d(out_ch, affine=True)
        else:
            raise NotImplementedError("Not expected normalization flag !!!")
        
    def forward(self, x):
        x = nn.functional.interpolate(x, scale_factor=2, mode=self.upsample)        
        x = self.pad(x)
        x = self.conv(x)
        x = self.normalization(x)        
        x = self.activation(x)        
        return x
    
class ImageDiscriminator(nn.Module):
    def __init__(self, in_ch, norm='spec'):
        super(ImageDiscriminator, self).__init__()
        self.layers = nn.Sequential(
            ConvLayer(in_ch, 64, 10, 4, normalization=norm, activation='leaky'), #256->64
            ConvLayer(64, 128, 10, 4, normalization=norm, activation='leaky'), #64->16
            ConvLayer(128, 256, 10, 4, normalization=norm, activation='leaky'), #16->4
            nn.Conv2d(256, 1, 5, 1), #4->1
            nn.Sigmoid()
        )
    def forward(self, input):
        return self.layers(input)

In [6]:
import argparse

args = argparse.Namespace(
    ann_dir = '/shared/data/COCOdataset2017',
    data_dir = '/shared/GCCdataset/alltypes',
    save_model_dir = '/shared/GCC-GAN-server/models/',
    mode = 'train', # or 'test'
    batch_size = 8,
    input_size = 256,
    epochs = 8,
    lr = 2e-5,
    lambda_g = 1,
    lambda_c = 1,
    lambda_a = 0.01,
    lambda_s = 0.01,
    lambda_s2 = 1,
    beta = 0.5,
    test_interval = 50,
    device_id = 1,#[0, 1, 2, 3]
    vis_id = 'HCRS3'# feature_dim = ,
)

In [7]:
import os
from torch.utils.data.distributed import DistributedSampler
import torch.distributed as dist
from apex.parallel import DistributedDataParallel as DDP

device = 'cuda:'+str(args.device_id) if torch.cuda.is_available() else 'cpu'

print('Device:', device)

aug_transform = A.Compose([
        A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
        # A.OneOf([
        #     A.ShiftScaleRotate(rotate_limit=10, p=0.5, border_mode=cv2.BORDER_CONSTANT),
        #     A.geometric.transforms.Affine(translate_percent=0.05, shear=5, scale = 0.95),
        # ], p=1.0),
        A.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],),
        ToTensorV2(),
])

basic_transform = A.Compose([
        A.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],),
        ToTensorV2(),
])


num_workers = 8 # 4 * len(args.device_ids)
print('# Workers:', num_workers)

dataset = MSCOCODataset(args, aug_trans=aug_transform, basic_trans=basic_transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)

Device: cuda:1
# Workers: 8
loading annotations into memory...
Done (t=20.92s)
creating index...
index created!


In [11]:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from torch.optim import Adam
import visdom
import os
# from multiprocessing import Pool

def denorm(tensor):
    mean = torch.Tensor([0.485, 0.456, 0.406]).reshape(-1, 1, 1)
    std = torch.Tensor([0.229, 0.224, 0.225]).reshape(-1, 1, 1)
    res = torch.clamp(tensor * std + mean, 0, 1)
    return res

def tensor2im(image_tensor, imtype=np.uint8):
    image_numpy = image_tensor.numpy() #image_tensor[0].permute(1,2,0).detach().cpu().float().numpy()
    # image_numpy = image_numpy * 0.5 - 0.5
    image_numpy = image_numpy * 255
    # image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
    # image_numpy = (image_numpy + 1) / 2.0 * 255.0
    image_numpy = np.clip(image_numpy, 0, 255)

    return image_numpy.astype(imtype)

def weights_init(m):
    if isinstance(m, nn.Conv2d):
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
        
def save_checkpoint(epoch, model, optimizer, filename):
    state = {
        'Epoch': epoch,
        'State_dict': model.state_dict(),
        'optimizer': optimizer.state_dict()
    }
    torch.save(state, filename)

if __name__ == '__main__':
    # Visdom display initialization
    vis = visdom.Visdom(env=args.vis_id)
    vis.close(env=args.vis_id)
    plot = vis.line(Y=torch.Tensor(1).zero_(), opts=dict(title='l2norm'), env=args.vis_id) #final loss
    plot2 = vis.line(Y=torch.Tensor(1).zero_(), opts=dict(title='Losses', legend=["Loss_a", "Loss_s", "Loss_c", "Loss_g"], showlegend=True), env=args.vis_id)
    plot3 = vis.line(Y=torch.Tensor(1).zero_(), opts=dict(title='GAN losses', legend=["Discriminator_real", "Discriminator_fake", "Generator"], showlegend=True), env=args.vis_id)

    vis_inputs = vis.images(np.random.rand(1, 1,256,256), opts=dict(title="Inputs"), env=args.vis_id)
    vis_FI = vis.images(np.random.rand(1,3,256,256), opts=dict(title="F(I)"), env=args.vis_id)
    vis_HI = vis.images(np.random.rand(1,3,256,256), opts=dict(title="H(I)"), env=args.vis_id)
    vis_AI = vis.images(np.random.rand(1,1,256,256), opts=dict(title="A(I)"), env=args.vis_id)
    vis_R_in = vis.images(np.random.rand(1,3,256,256), opts=dict(title="R_in"), env=args.vis_id)
    vis_R_out = vis.images(np.random.rand(1,3,256,256), opts=dict(title="Composite"), env=args.vis_id)
    vis_seg = vis.images(np.random.rand(1,1,256,256), opts=dict(title="Segmentation"), env=args.vis_id)

    l1_loss = nn.L1Loss().to(device)
    l2_loss = nn.MSELoss().to(device)
    criterion = nn.BCEWithLogitsLoss().to(device)
    
    net = GenCompModel(args) # generator
    net.apply(weights_init)
    
    disc = Discriminator()
    disc.apply(weights_init)
    
    net.to(device)
    disc.to(device)
    
    net.train()
    disc.train()

    real_label = 1.
    fake_label = 0.

    optimizerD = Adam(disc.parameters(), lr=args.lr, betas=(args.beta, 0.999))
    optimizer = Adam(net.parameters(), lr=args.lr, betas=(args.beta, 0.999))

    SMOOTH = 1e-6
    
    for epoch in range(args.epochs):
                
        for i, items in enumerate(dataloader):
            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            ########## Train with all-real batch
            gt = items[0].to(device); fg = items[1].to(device);
            bg = items[2].to(device); mask = items[3].to(device);
            tf_fg = items[4].to(device); tf_mask = items[5].to(device);
            trans_mat = items[6].to(device) # cat = items[6].to(device);
            B, C, H, W = gt.shape

            real_value = torch.full((B,), real_label, dtype=torch.float, device=device)
            fake_value = torch.full((B,), fake_label, dtype=torch.float, device=device)
            real_map = torch.full((B, 1, H, W), real_label, dtype=torch.float, device=device)
            fake_map = torch.full((B, 1, H, W), fake_label, dtype=torch.float, device=device)
            
            gt_img_out, gt_fg_seg, gt_bg_seg = disc(mask, gt)#.view(-1) #[b,1,1,1,] -> [b]

            d_loss_real = criterion(gt_img_out.view(-1), real_value.clone()) #gt, real
            ds_loss_real = criterion(gt_bg_seg, torch.multiply(~mask, real_map.clone())) + criterion(gt_fg_seg, torch.multiply(mask, real_map.clone()))#criterion(gt_seg_out, real_map.clone()) #gt, real

            ########## Train with all-fake batch
            # Generate batch of latent vectors
            
            AI, HI, FI, pred_mat, R_in, R_out = net(tf_fg, bg, tf_mask)
            R_img_out, R_fg_seg, R_bg_seg = disc(AI, R_out) #torch.Size([8, 1, 1, 1]), torch.Size([8, 1, 256, 256])
            
            d_loss_fake = criterion(R_img_out.view(-1), fake_value.clone()) #gen, fake
            ds_loss_fake = criterion(R_bg_seg, torch.multiply(~AI, real_map.clone())) + criterion(R_fg_seg, torch.multiply(AI, fake_map.clone()))
            
            dsc_loss = d_loss_real + d_loss_fake + 0.01 * (ds_loss_real + ds_loss_fake) # +d_loss_reg
            optimizer.zero_grad()
            optimizerD.zero_grad()
            dsc_loss.backward()
            optimizerD.step()
  
            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            
            AI, HI, FI, pred_mat, R_in, R_out = net(tf_fg, bg, tf_mask)
            
            R_img_out, R_fg_seg, R_bg_seg = disc(AI, R_out)
            g_loss_fake = criterion(R_img_out.view(-1), real_value.clone())# + criterion(torch.multiply(AI, R_img_out), real_map.clone()) #criterion(R_img_out.view(-1), real_value.clone()) # err_G = loss_a
            
            gs_loss_fake = criterion(R_bg_seg, torch.multiply(~AI, real_map.clone())) + criterion(R_fg_seg, torch.multiply(AI, real_map.clone())) #  criterion(R_seg_out, real_map.clone())# + criterion(R_seg_out, fake_map.clone())

            lambda_mask = 1
            
            loss_g = l2_loss(trans_mat.float(), pred_mat.float()) # lambda_mask * torch.exp(-torch.sum(AI, (1,2,3), dtype=torch.float).mean()/args.input_size**2) + torch.linalg.matrix_norm(pred_mat.float(), ord=2).mean()+ + torch.linalg.matrix_norm(pred_mat.float(), ord=2).mean()
            loss_c = torch.linalg.matrix_norm(torch.multiply(HI-FI, AI), ord=1).mean() / torch.linalg.matrix_norm(AI.float(), ord=1).mean() + SMOOTH # / (torch.sum(AI, (1,2,3), dtype=torch.float).mean() + SMOOTH)# + torch.linalg.matrix_norm(AI.float(), ord=1)
            loss_s2 = l1_loss(gt, R_out)
            
            loss = args.lambda_g * loss_g + args.lambda_c * loss_c + args.lambda_a * g_loss_fake + args.lambda_s * gs_loss_fake + args.lambda_s2 * loss_s2
            
            optimizer.zero_grad()
            optimizerD.zero_grad()
            loss.backward()
            optimizer.step()

            if i % args.test_interval == 0:
                print('[{:d} epoch {:d}/{:d}]: c:{:f}, g:{:f}, dsc_loss:{:f}, a:{:f}, s:{:f}, total:{:f}'.format(epoch, i, len(dataloader)-1, loss_c.item(), loss_g.item(), dsc_loss.item(), g_loss_fake.item(), gs_loss_fake.item(), loss.item()), end='\r')

                vis.line(Y=torch.Tensor([torch.linalg.matrix_norm(pred_mat.float(), ord=2).mean().item()]), X=torch.Tensor([epoch*len(dataloader)+i]), win=plot, update='append', env=args.vis_id)
                vis.line(Y=torch.Tensor([args.lambda_a*g_loss_fake.item()]), X=torch.Tensor([epoch*len(dataloader)+i]), name='Loss_a', win=plot2, update='append', env=args.vis_id)
                vis.line(Y=torch.Tensor([args.lambda_s*gs_loss_fake.item()]), X=torch.Tensor([epoch*len(dataloader)+i]), name='Loss_s', win=plot2, update='append', env=args.vis_id)
                vis.line(Y=torch.Tensor([args.lambda_c*loss_c.item()]), X=torch.Tensor([epoch*len(dataloader)+i]), name='Loss_c', win=plot2, update='append', env=args.vis_id)
                vis.line(Y=torch.Tensor([args.lambda_g*loss_g.item()]), X=torch.Tensor([epoch*len(dataloader)+i]), name='Loss_g', win=plot2, update='append', env=args.vis_id)
                vis.line(Y=torch.Tensor([args.lambda_s2*loss_s2.item()]), X=torch.Tensor([epoch*len(dataloader)+i]), name='Loss_s2', win=plot2, update='append', env=args.vis_id)

                vis.line(Y=torch.Tensor([d_loss_real.item()+ds_loss_real.item()]), X=torch.Tensor([epoch*len(dataloader)+i]), name='Discriminator_real', win=plot3, update='append', env=args.vis_id)
                vis.line(Y=torch.Tensor([d_loss_fake.item()+ds_loss_fake.item()]), X=torch.Tensor([epoch*len(dataloader)+i]), name='Discriminator_fake', win=plot3, update='append', env=args.vis_id)
                vis.line(Y=torch.Tensor([loss.item()]), X=torch.Tensor([epoch*len(dataloader)+i]), name='Generator', win=plot3, update='append', env=args.vis_id)

                inputs = tensor2im(denorm((torch.multiply(tf_fg, tf_mask) + torch.multiply(bg, ~tf_mask)).detach().cpu()).float())
                FI_out = tensor2im(denorm(FI.detach().cpu()).float())
                HI_out = tensor2im(denorm(HI.detach().cpu()).float())
                AI_out = tensor2im(AI.detach().cpu())
                R_in = tensor2im(denorm(R_in.detach().cpu()).float())
                seg_out = tensor2im((R_fg_seg+R_bg_seg).detach().cpu())
                R_out = tensor2im(denorm(R_out.detach().cpu()).float())

                vis.images(inputs, win=vis_inputs, opts=dict(title="Inputs"), env=args.vis_id)
                vis.images(HI_out, win=vis_HI, opts=dict(title="H(I)"), env=args.vis_id)
                vis.images(FI_out, win=vis_FI, opts=dict(title="F(I)"), env=args.vis_id)
                vis.images(AI_out, win=vis_AI, opts=dict(title="A(I)"), env=args.vis_id)
                vis.images(R_in, win=vis_R_in, opts=dict(title="R_in"), env=args.vis_id)
                vis.images(R_out, win=vis_R_out, opts=dict(title="R_out"), env=args.vis_id)
                vis.images(seg_out, win=vis_seg, opts=dict(title="Segmap"), env=args.vis_id)
            
            if i % 10000 == 0:
                save_checkpoint(epoch, net, optimizer, args.save_model_dir + 'netHCRS-free3.pth')
                save_checkpoint(epoch, disc, optimizerD, args.save_model_dir + 'discHCRS-free3.pth')
        

Setting up a new session...


[7 epoch 84750/84835]: c:0.000010, g:0.005206, dsc_loss:1.053176, a:0.645480, s:1.064361, total:0.311907