# **Install Libraries**

In [None]:
!pip install segmentation-models-pytorch
!pip install pytorch-lightning==1.8.3.post0
!pip install torchsummary
!pip install vit-pytorch
!pip install timm

In [None]:
# Enable TensorFlow GPU memory growth: memory is then allocated on demand
# instead of being reserved up front when TensorFlow first touches a GPU.
import tensorflow as tf
gpus = tf.config.experimental.list_physical_devices('GPU')
try:
    # Iterating an empty list is a no-op, so no separate "any GPU?" check is needed.
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as e:
    # Report the failure and carry on rather than aborting the notebook.
    print(e)

In [None]:
import os
import sys
import random
import warnings
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import cv2
from tqdm import tqdm_notebook, tnrange
from itertools import chain
from skimage.io import imread, imshow, concatenate_images
from skimage.transform import resize
from skimage.morphology import label
from keras.models import Model, load_model
from keras.layers import Input, Lambda, Conv2D, Conv2DTranspose, MaxPooling2D, concatenate 
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras import backend as K
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
from tensorflow.keras import metrics
from tensorflow.keras.utils import register_keras_serializable


# **Define Mean IoU Metric**

In [None]:
@register_keras_serializable()
class MeanIoUMetric(metrics.Metric):
    """Serializable mean-IoU metric for probability outputs.

    Predictions are binarised at 0.5 and the running statistics are delegated
    to an inner ``metrics.MeanIoU`` instance.
    """

    def __init__(self, num_classes, name='mean_iou', **kwargs):
        """Store ``num_classes`` and create the wrapped ``MeanIoU`` accumulator."""
        super().__init__(name=name, **kwargs)
        self.num_classes = num_classes
        self.iou_metric = metrics.MeanIoU(num_classes=num_classes)

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Threshold ``y_pred`` at 0.5 and accumulate it against ``y_true``.

        ``sample_weight`` is accepted for API compatibility but is not used.
        """
        hard_pred = tf.cast(tf.greater(y_pred, 0.5), tf.int32)
        self.iou_metric.update_state(y_true, hard_pred)

    def result(self):
        """Return the current mean IoU from the wrapped accumulator."""
        return self.iou_metric.result()

    def reset_state(self):
        """Clear the wrapped accumulator."""
        self.iou_metric.reset_state()

    def get_config(self):
        """Return the base config extended with ``num_classes``."""
        return {**super().get_config(), "num_classes": self.num_classes}

    @classmethod
    def from_config(cls, config):
        """Rebuild the metric from a ``get_config`` dictionary."""
        return cls(**config)


# Module-level two-class instance of the metric.
mean_iou_metric = MeanIoUMetric(num_classes=2)


# **Load CNN Model**

In [None]:
# Load the saved Keras model from disk; MeanIoUMetric is passed as a custom
# object so the metric stored with the model can be deserialized.
# NOTE(review): safe_mode=False disables Keras' protection against unsafe
# lambda deserialization — only load model files from a trusted source.
Model_Supervised = tf.keras.models.load_model(
    '/kaggle/working/unetplusplus-save-model-keras/UNET_plusplus.keras',
    custom_objects={'MeanIoUMetric': MeanIoUMetric} ,safe_mode=False
)

# **Load Checkerboard Masks**

In [None]:
from timm.models.swin_transformer import SwinTransformer
from vit_pytorch import ViT
import cv2
import numpy as np
import random
import torch
import matplotlib.pyplot as plt
import copy
import torch.nn as nn
from segmentation_models_pytorch import UnetPlusPlus


# Collects one attention-loss entry per training step (appended by the training loop).
all_attention_loss  = []

def random_checkboard_mask_new(img, ratio_n=None):
    """Load one of the six pre-computed checkerboard masks (ck_0 .. ck_5).

    Args:
        img: Unused; kept so existing call sites keep working.
        ratio_n: Optional selector. When ``None`` (the default) a uniform
            random value in [0, 1) is drawn with ``torch.rand``. Any other
            value is used directly as the selector instead of the random draw
            (previously a non-None value raised ``UnboundLocalError``).
            NOTE(review): the intended meaning of ``ratio_n`` is not visible
            in this notebook — confirm this interpretation.

    Returns:
        The array stored in ``ck_<k>.npy`` where ``k`` is the index of the
        sixth of [0, 1) that the selector falls into (values >= 5/6 map to 5).
    """
    if ratio_n is None:
        selector = float(torch.rand(1))
    else:
        selector = float(ratio_n)

    # Count how many of the thresholds 1/6 .. 5/6 the selector has reached.
    # Half-open intervals mean an exact boundary value now maps to the next
    # mask instead of falling through to ck_5 as the old strict chain did.
    mask_index = sum(selector >= k / 6 for k in range(1, 6))

    return np.load("/kaggle/input/ck-mask/ck_mask/ck_%d.npy" % mask_index)


# **Dataset**

In [None]:
from PIL import Image
import torch.utils.data as data
import random
import cv2 as cv
from torchvision import datasets, transforms


class DatasetFromFolder(data.Dataset):
    """Paired-image dataset whose inputs are multiplied by a random checkerboard mask.

    Images are read from ``<image_dir>/<subfolder>/a`` and ``.../b``; files are
    paired by identical filename. ``direction`` decides which folder is the
    input ('AtoB': a -> b, anything else: b -> a).

    Each item is ``(img_input, img_target, mask1, img_target2)`` where
    ``img_target2`` is the target resized to 256x256 with no other processing.

    NOTE(review): images are loaded with OpenCV and handed to ``ToPILImage``
    without a colour conversion — confirm the channel order is what the
    downstream models expect.
    """

    def __init__(self, image_dir, subfolder='train', direction='AtoB', transform=None, resize_scale=None, crop_size=None, fliplr=False):
        """Resolve the input/target folders and index the input filenames.

        Args:
            image_dir: Dataset root directory.
            subfolder: Split directory under ``image_dir``.
            direction: 'AtoB' uses folder ``a`` as input; otherwise ``b``.
            transform: Optional callable applied to both PIL images.
            resize_scale: Optional side length both images are resized to.
            crop_size: Optional side length of a random square crop.
            fliplr: When true, flip both images horizontally with p=0.5.
        """
        super(DatasetFromFolder, self).__init__()
        if direction == 'AtoB':
            self.input_path = os.path.join(image_dir, subfolder, 'a')
            self.target_path = os.path.join(image_dir, subfolder, 'b')
        else:
            self.input_path = os.path.join(image_dir, subfolder, 'b')
            self.target_path = os.path.join(image_dir, subfolder, 'a')

        self.image_filenames = sorted(os.listdir(self.input_path))
        #self.image_filenames = self.image_filenames[:200]

        self.direction = direction
        self.transform = transform
        self.resize_scale = resize_scale
        self.crop_size = crop_size
        self.fliplr = fliplr

    @staticmethod
    def _read_image(path):
        """Read ``path`` with OpenCV, raising instead of returning ``None``."""
        image = cv.imread(path)
        if image is None:
            # cv.imread signals a missing/unreadable file by returning None,
            # which would otherwise surface later as an obscure resize error.
            raise FileNotFoundError("Could not read image: %s" % path)
        return image

    def __getitem__(self, index):
        """Load, mask and augment the pair at ``index``."""
        filename = self.image_filenames[index]
        img_input = self._read_image(os.path.join(self.input_path, filename))
        img_target = self._read_image(os.path.join(self.target_path, filename))

        # Plain 256x256 copy of the target, taken before any augmentation.
        # cv.resize returns a new array, so one disk read serves both targets.
        img_target2 = cv.resize(img_target, (256, 256))

        mask1 = random_checkboard_mask_new(img_input, None)

        # The mask is applied at 256x256, before the optional resize below.
        img_input = cv.resize(img_input, (256, 256))
        img_input = img_input * mask1

        if self.resize_scale:
            img_input = cv.resize(img_input, (self.resize_scale, self.resize_scale))
            img_target = cv.resize(img_target, (self.resize_scale, self.resize_scale))

        if self.crop_size:
            # Derive crop bounds from the actual image sizes so cropping also
            # works when resize_scale is None (it used to raise TypeError).
            # With resize_scale set, both images are resize_scale x resize_scale
            # and the bounds equal the previous ones.
            max_row = min(img_input.shape[0], img_target.shape[0]) - self.crop_size
            max_col = min(img_input.shape[1], img_target.shape[1]) - self.crop_size
            x = random.randint(0, max_row)
            y = random.randint(0, max_col)

            img_input = img_input[x : x + self.crop_size, y:y+self.crop_size, :]
            img_target = img_target[x : x + self.crop_size, y:y+self.crop_size, :]

        if self.fliplr and random.random() < 0.5:
            img_input = cv.flip(img_input, 1)
            img_target = cv.flip(img_target, 1)

        img_input = transforms.ToPILImage()(img_input)
        img_target = transforms.ToPILImage()(img_target)

        if self.transform is not None:
            img_input = self.transform(img_input)
            img_target = self.transform(img_target)

        return img_input, img_target, mask1 ,img_target2

    def __len__(self):
        """Number of input images found in the input folder."""
        return len(self.image_filenames)

# **Model Generator and Discriminator**

In [None]:
import torch
from torchsummary import summary


class Generator(torch.nn.Module):
    """Generator network: a UNet++ with an ImageNet-initialised ResNet-34 encoder.

    ``num_filter`` is accepted for signature compatibility but is not used.
    """

    def __init__(self, input_dim, num_filter, output_dim):
        """Build the UNet++ with ``input_dim`` input and ``output_dim`` output channels."""
        super().__init__()
        unet_kwargs = dict(
            encoder_name="resnet34",
            encoder_weights="imagenet",
            in_channels=input_dim,
            classes=output_dim,
            decoder_use_batchnorm=True,
        )
        self.model = UnetPlusPlus(**unet_kwargs)

    def forward(self, x):
        """Run ``x`` through the wrapped UNet++ and return its raw output."""
        generated = self.model(x)
        return generated
    
class SwinTDiscriminator(nn.Module):
    """Conditional discriminator built on a Swin Transformer.

    The two images passed to ``forward`` are concatenated along the channel
    axis (hence ``in_chans=6``) and scored by the transformer; a sigmoid turns
    the output into a probability.
    """

    def __init__(self, image_size=256, patch_size=4, num_classes=1):
        """Create the Swin Transformer backbone with the fixed configuration below."""
        super().__init__()
        swin_config = dict(
            img_size=image_size,
            patch_size=patch_size,
            in_chans=6,
            num_classes=num_classes,
            embed_dim=48,
            depths=[1, 1, 3, 1],
            num_heads=[8,16,32,64],
            window_size=8,
            mlp_ratio=2.,
            qkv_bias=True,
            qk_scale=None,
            drop_rate=0.,
            attn_drop_rate=0.,
            drop_path_rate=0.1,
            norm_layer=nn.LayerNorm,
            ape=False,
            patch_norm=True,
            use_checkpoint=False,
        )
        self.swin_t = SwinTransformer(**swin_config)

    def forward(self, x, label):
        """Score the pair ``(x, label)``; returns sigmoid-activated outputs."""
        pair = torch.cat([x, label], 1)
        logits = self.swin_t(pair)
        return torch.sigmoid(logits)
        


# **Utils**

In [None]:
import torch
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
import os
import imageio
import cv2 as cv
from numpy.linalg import norm


def to_np(x):
    """Return the data of tensor ``x`` as a NumPy array, moving it to the CPU first."""
    host_tensor = x.data.cpu()
    return host_tensor.numpy()


def to_var(x):
    """Wrap ``x`` in a ``Variable``, placing it on the GPU when CUDA is available."""
    placed = x.cuda() if torch.cuda.is_available() else x
    return Variable(placed)


def denorm(x):
    """Map values from [-1, 1] to [0, 1] via ``(x + 1) / 2`` and clamp to [0, 1]."""
    return ((x + 1) / 2).clamp(0, 1)


def normalization(data):
    """Min-max scale ``data`` to the range [0, 1].

    Args:
        data: Array-like of numbers.

    Returns:
        ``(data - min) / (max - min)``. When every element is equal
        (``max == min``) an all-zero array is returned instead of the NaNs a
        0/0 division would produce.
    """
    lowest = np.min(data)
    _range = np.max(data) - lowest
    shifted = data - lowest
    if _range == 0:
        # Constant input: `shifted` is already all zeros; dividing by 1 keeps
        # the result dtype the same as on the normal path.
        return shifted / 1
    return shifted / _range


import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

try:
    from StringIO import StringIO  # Python 2.7
except ImportError:
    from io import BytesIO  # Python 3.x

class Logger(object):
    """Holds a TensorFlow summary writer that logs to a directory."""

    def __init__(self, log_dir):
        """Create a summary writer logging to log_dir."""
        # tf.summary.FileWriter only exists in the TF1 API; TF2 exposes
        # tf.summary.create_file_writer instead. Use whichever is available so
        # construction does not raise AttributeError under TensorFlow 2.
        legacy_writer_cls = getattr(tf.summary, "FileWriter", None)
        if legacy_writer_cls is not None:
            self.writer = legacy_writer_cls(log_dir)
        else:
            self.writer = tf.summary.create_file_writer(log_dir)

# **Losses**

In [None]:
#MSGMS_Loss
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
import kornia

use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')

class Prewitt(nn.Module):
    """Gradient-magnitude filter using fixed 3x3 Prewitt kernels.

    Args:
        eps: Small constant added under the square root. ``sqrt`` has an
            infinite derivative at exactly 0, which combined with the zero
            derivative of ``x * x`` yields NaN gradients on perfectly flat
            regions; ``eps`` keeps the backward pass finite.
    """

    def __init__(self, eps=1e-12):
        super().__init__()
        self.eps = eps
        # One input channel, two outputs: horizontal (Gx) and vertical (Gy) response.
        self.filter = nn.Conv2d(in_channels=1, out_channels=2, kernel_size=3, stride=1, padding=0, bias=False)
        Gx = torch.tensor([[-1.0, 0.0, 1.0], [-1.0, 0.0, 1.0], [-1.0, 0.0, 1.0]]) / 3
        Gy = torch.tensor([[1.0, 1.0, 1.0], [0.0, 0.0, 0.0], [-1.0, -1.0, -1.0]]) / 3
        G = torch.cat([Gx.unsqueeze(0), Gy.unsqueeze(0)], 0)
        # The kernel is placed on the module-level `device` at construction time.
        G = G.unsqueeze(1).to(device)
        self.filter.weight = nn.Parameter(G, requires_grad=False)

    def forward(self, img):
        """Return the gradient magnitude of single-channel ``img``.

        With ``padding=0`` the output is 2 pixels smaller than the input in
        each spatial dimension.
        """
        x = self.filter(img)
        x = torch.mul(x, x)
        x = torch.sum(x, dim=1, keepdim=True)
        x = torch.sqrt(x + self.eps)
        return x


# Define the gradient magnitude similarity map:
def GMS(Ii, Ir, edge_filter, median_filter, c=0.0026):
    """Compute the gradient magnitude similarity map of two image batches.

    Both inputs are averaged over the channel axis, passed through
    ``median_filter`` and then ``edge_filter``; the map is
    ``(2*g1*g2 + c) / (g1**2 + g2**2 + c)``.
    """
    gray_a = Ii.mean(dim=1, keepdim=True)
    gray_b = Ir.mean(dim=1, keepdim=True)
    grad_a = edge_filter(median_filter(gray_a))
    grad_b = edge_filter(median_filter(gray_b))
    numerator = 2 * grad_a * grad_b + c
    denominator = grad_a**2 + grad_b**2 + c
    return numerator / denominator


class MSGMS_Loss(nn.Module):
    """Multi-scale gradient magnitude similarity loss.

    The loss is ``mean(1 - GMS)`` averaged over four scales: the input
    resolution plus three successive 2x average-pooled versions.
    """

    def __init__(self):
        super().__init__()
        # Keep the filters as real attributes so nn.Module registers them as
        # submodules; hidden inside functools.partial they were invisible to
        # .to()/.cuda()/.modules().
        self.edge_filter = Prewitt()
        self.blur = kornia.filters.MedianBlur((3, 3))
        self.GMS = partial(GMS, edge_filter=self.edge_filter, median_filter=self.blur)

    def GMS_loss(self, Ii, Ir):
        """Single-scale loss: mean of ``1 - GMS(Ii, Ir)``."""
        return torch.mean(1 - self.GMS(Ii, Ir))

    def forward(self, Ii, Ir):
        """Return the four-scale average of ``GMS_loss`` for the two batches."""
        total_loss = self.GMS_loss(Ii, Ir)

        for _ in range(3):
            # Halve the resolution of both images before the next scale.
            Ii = F.avg_pool2d(Ii, kernel_size=2, stride=2)
            Ir = F.avg_pool2d(Ir, kernel_size=2, stride=2)
            total_loss += self.GMS_loss(Ii, Ir)

        return total_loss / 4


class MSGMS_Score(nn.Module):
    """Multi-scale GMS map: per-pixel counterpart of ``MSGMS_Loss``.

    GMS maps from four scales are bilinearly resized to the input width and
    summed; the result is ``(1 - sum) / 4``.
    """

    def __init__(self):
        super().__init__()
        self.GMS = partial(GMS, edge_filter=Prewitt(), median_filter=kornia.filters.MedianBlur((3, 3)))
        # NOTE(review): this 21x21 blur is created but not used by forward().
        self.median_filter = kornia.filters.MedianBlur((21, 21))

    def GMS_Score(self, Ii, Ir):
        """Single-scale GMS map of the two batches."""
        return self.GMS(Ii, Ir)

    def forward(self, Ii, Ir):
        """Return the combined score map at the input resolution."""
        img_size = Ii.size(-1)
        # A single int `size` makes the output square with side `img_size`.
        to_input_size = partial(F.interpolate, size=img_size, mode='bilinear', align_corners=False)

        total_scores = to_input_size(self.GMS_Score(Ii, Ir))
        for _ in range(3):
            Ii = F.avg_pool2d(Ii, kernel_size=2, stride=2)
            Ir = F.avg_pool2d(Ir, kernel_size=2, stride=2)
            total_scores += to_input_size(self.GMS_Score(Ii, Ir))

        return (1 - total_scores) / 4


#PS_LOSS
import torch
import torch.nn as nn
import torchvision.models as models



class StyleLoss(nn.Module):
    """Style loss: L1 distance between Gram matrices of VGG19 feature maps.

    Four feature levels are compared: relu2_2, relu3_4, relu4_4 and relu5_2.
    """

    def __init__(self):
        super(StyleLoss, self).__init__()
        feature_net = VGG19()
        # Only move to the GPU when one exists, so the loss can also be
        # constructed on CPU-only machines (an unconditional .cuda() raises there).
        if torch.cuda.is_available():
            feature_net = feature_net.cuda()
        self.add_module('vgg', feature_net)
        self.criterion = torch.nn.L1Loss()

    def compute_gram(self, x):
        """Return the per-sample Gram matrix of ``x``, normalised by ``h * w * ch``.

        Args:
            x: Tensor of shape (b, ch, h, w).

        Returns:
            Tensor of shape (b, ch, ch).
        """
        b, ch, h, w = x.size()
        # reshape (rather than view) also accepts non-contiguous inputs.
        f = x.reshape(b, ch, w * h)
        f_T = f.transpose(1, 2)
        G = f.bmm(f_T) / (h * w * ch)

        return G

    def forward(self, x, y):
        """Return the summed Gram-matrix L1 loss between ``x`` and ``y``.

        Implemented as ``forward`` (not ``__call__``) so the standard
        ``nn.Module`` call machinery, including hooks, stays active.
        """
        # Compute features
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)

        # Compute loss
        style_loss = 0.0
        for layer in ('relu2_2', 'relu3_4', 'relu4_4', 'relu5_2'):
            style_loss += self.criterion(self.compute_gram(x_vgg[layer]), self.compute_gram(y_vgg[layer]))

        return style_loss


class PerceptualLoss(nn.Module):
    r"""
    Perceptual loss, VGG-based
    https://arxiv.org/abs/1603.08155
    https://github.com/dxyang/StyleTransfer/blob/master/utils.py

    Weighted sum of L1 distances between VGG19 features of the two inputs at
    relu1_1, relu2_1, relu3_1, relu4_1 and relu5_1.
    """

    def __init__(self, weights=(1.0, 1.0, 1.0, 1.0, 1.0)):
        """Create the loss.

        Args:
            weights: Five per-layer weights (any indexable sequence). The
                default is a tuple so the default object cannot be mutated
                and shared between instances.
        """
        super(PerceptualLoss, self).__init__()
        feature_net = VGG19()
        criterion = torch.nn.L1Loss()
        # Only move to the GPU when one exists so construction also works on
        # CPU-only machines (an unconditional .cuda() raises there).
        if torch.cuda.is_available():
            feature_net = feature_net.cuda()
            criterion = criterion.cuda()
        self.add_module('vgg', feature_net)
        self.criterion = criterion
        self.weights = weights

    def forward(self, x, y):
        """Return the weighted feature-space L1 loss between ``x`` and ``y``.

        Implemented as ``forward`` (not ``__call__``) so the standard
        ``nn.Module`` call machinery, including hooks, stays active.
        """
        # Compute features
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)

        layers = ('relu1_1', 'relu2_1', 'relu3_1', 'relu4_1', 'relu5_1')
        content_loss = 0.0
        for position, layer in enumerate(layers):
            content_loss += self.weights[position] * self.criterion(x_vgg[layer], y_vgg[layer])

        return content_loss



class VGG19(torch.nn.Module):
    """Frozen, pretrained VGG19 feature extractor exposing every ReLU stage.

    The torchvision ``features`` stack is cut into 16 consecutive slices named
    after the ReLU each one ends with. ``forward`` returns a dict mapping each
    name to that slice's output.
    """

    # (attribute name, first layer index, one-past-last layer index) into
    # ``models.vgg19(...).features``; slices are applied in this order.
    _SLICES = (
        ('relu1_1', 0, 2),
        ('relu1_2', 2, 4),
        ('relu2_1', 4, 7),
        ('relu2_2', 7, 9),
        ('relu3_1', 9, 12),
        ('relu3_2', 12, 14),
        ('relu3_3', 14, 16),
        ('relu3_4', 16, 18),
        ('relu4_1', 18, 21),
        ('relu4_2', 21, 23),
        ('relu4_3', 23, 25),
        ('relu4_4', 25, 27),
        ('relu5_1', 27, 30),
        ('relu5_2', 30, 32),
        ('relu5_3', 32, 34),
        ('relu5_4', 34, 36),
    )

    def __init__(self):
        super(VGG19, self).__init__()
        features = models.vgg19(pretrained=True).features

        for stage_name, first, last in self._SLICES:
            stage = torch.nn.Sequential()
            # Sub-layers keep their original index as their name.
            for layer_idx in range(first, last):
                stage.add_module(str(layer_idx), features[layer_idx])
            setattr(self, stage_name, stage)

        # don't need the gradients, just want the features
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Run ``x`` through all slices and collect every intermediate output."""
        out = {}
        activation = x
        for stage_name, _, _ in self._SLICES:
            activation = getattr(self, stage_name)(activation)
            out[stage_name] = activation
        return out


        
# SSIM:
def gaussian(window_size, sigma):
    """Return a normalised 1-D Gaussian kernel.

    Args:
        window_size: Number of taps; the kernel is centred at ``window_size // 2``.
        sigma: Standard deviation of the Gaussian.

    Returns:
        1-D float tensor of length ``window_size`` whose entries sum to 1.
    """
    # Imported locally: in this notebook `exp` is otherwise only imported in a
    # later cell, so this function used to fail if called before that cell ran.
    from math import exp

    center = window_size // 2
    gauss = torch.Tensor([exp(-(x - center)**2/float(2*sigma**2)) for x in range(window_size)])
    return gauss/gauss.sum()

def create_window(window_size, channel):
    """Build a ``(channel, 1, window_size, window_size)`` Gaussian window (sigma 1.5).

    The 2-D kernel is the outer product of the 1-D Gaussian with itself,
    repeated once per channel for use as a grouped-convolution weight.
    """
    column = gaussian(window_size, 1.5).unsqueeze(1)
    kernel_2d = column.mm(column.t()).float()
    kernel_4d = kernel_2d.unsqueeze(0).unsqueeze(0)
    per_channel = kernel_4d.expand(channel, 1, window_size, window_size).contiguous()
    return Variable(per_channel)

def _ssim(img1, img2, window, window_size, channel, size_average = True):
    mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
    mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1*mu2

    sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
    sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
    sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2

    C1 = 0.01**2
    C2 = 0.03**2

    ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)

class SSIM(torch.nn.Module):
    """SSIM module that caches its Gaussian window between calls.

    The window is rebuilt whenever the channel count or tensor type of the
    first input differs from the cached window.
    """

    def __init__(self, window_size = 11, size_average = True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        """Return the SSIM of ``img1`` and ``img2`` (see ``_ssim``)."""
        channel = img1.size()[1]

        cache_matches = channel == self.channel and self.window.data.type() == img1.data.type()
        if not cache_matches:
            # Build a window for this channel count on img1's device/dtype and
            # remember it for later calls.
            rebuilt = create_window(self.window_size, channel)
            if img1.is_cuda:
                rebuilt = rebuilt.cuda(img1.get_device())
            self.window = rebuilt.type_as(img1)
            self.channel = channel

        window = self.window
        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)

def ssim(img1, img2, window_size = 11, size_average = True):
    """Functional SSIM: build a fresh Gaussian window matching ``img1`` and apply ``_ssim``."""
    channel = img1.size()[1]

    window = create_window(window_size, channel)
    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    # Match img1's tensor type before convolving.
    window = window.type_as(img1)

    return _ssim(img1, img2, window, window_size, channel, size_average)

#New Pore Attention Loss

def get_pore_attention_mask(Model_Supervised, image, y_, epoch):
    """Predict a binary attention mask for every image in a batch.

    Each image is resized to 512x512, converted to a single grey channel and
    fed to ``Model_Supervised.predict``; the prediction is thresholded at 0.3
    and resized to ``(1, 256, 256, 1)``.

    Args:
        Model_Supervised: Object with a Keras-style ``predict`` method.
        image: Batch tensor indexed as ``image[i, :, :, :]``.
            # assumes channels-last 3-channel images — TODO confirm
        y_: Unused.
        epoch: Unused.

    Returns:
        NumPy array with the per-image masks concatenated along axis 0.
    """
    sample_count = image.shape[0]
    collected = []

    for idx in range(sample_count):
        frame = image[idx, :, :, :].cpu().numpy()
        frame = cv2.resize(frame, (512, 512))
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        # Add the trailing channel axis and the leading batch axis.
        frame = np.expand_dims(np.expand_dims(frame, axis=-1), axis=0)

        prediction = Model_Supervised.predict(frame, verbose=0)
        binary = (prediction > 0.3).astype(np.uint8)
        collected.append(resize(binary, (1, 256, 256, 1), mode='constant', preserve_range=True))

    return np.concatenate(collected, axis=0)


def attention_guided_l1_loss(gen_image, target_image, attention_mask, epoch):
    """L1 loss between generated and target images, restricted by an attention mask.

    Args:
        gen_image: Generated batch, shape (B, C, H, W).
        target_image: Target batch, same shape as ``gen_image``.
        attention_mask: NumPy array in channels-last layout (B, H, W, 1); it is
            permuted to (B, 1, H, W) and broadcast over the channels.
        epoch: Unused; kept for call-site compatibility.

    Returns:
        Scalar tensor: mean absolute difference of the two masked images.
    """
    # Follow gen_image's device and dtype instead of hard-coding 'cuda' and
    # letting a float64 NumPy mask silently promote the whole loss to double.
    mask = torch.from_numpy(attention_mask).to(device=gen_image.device, dtype=gen_image.dtype)
    mask = mask.permute(0, 3, 1, 2)

    masked_gen_image = gen_image * mask
    masked_target_image = target_image * mask

    # Mean-reduced L1, the same quantity as a default torch.nn.L1Loss(); calling
    # the functional form removes the dependency on a global defined in a later cell.
    return F.l1_loss(masked_gen_image, masked_target_image)



# **Train**

In [None]:
from torchvision import transforms
from torch.autograd import Variable
from math import exp
import torch.nn.functional as F
import time
from torch.utils.tensorboard import SummaryWriter 
import cv2
import numpy as np
import matplotlib.pyplot as plt
from skimage.transform import resize 
import time


# ---- Training configuration ----
direction = 'AtoB'
batch_size= 16
ngf = 64
ndf = 64
input_size = 256
resize_scale = None
crop_size = None
fliplr = False
num_epochs = 100
# NOTE(review): lrG / lrD only seed the optimizers; the training loop calls
# adjust_learning_rate1/2 at the start of every epoch, which overwrite them.
lrG = 0.0008
lrD = 0.0002
lamb = 100.0
beta1 = 0.5
beta2= 0.999
writer = SummaryWriter('./path/log1')
data_dir = '/kaggle/input/path dataset'
model_dir = './saved-model/'


# makedirs(exist_ok=True) also copes with nested paths and with the directory
# being created between the check and the call.
os.makedirs(model_dir, exist_ok=True)

# Resize, convert to tensor and scale every channel from [0, 1] to [-1, 1].
transform = transforms.Compose([transforms.Resize(input_size),
                                transforms.ToTensor(),
                                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])


train_data = DatasetFromFolder(data_dir, subfolder='train', direction=direction, 
                resize_scale=resize_scale,  transform=transform, crop_size=crop_size, fliplr=fliplr)


# Never start more loader processes than the machine has CPUs: every worker
# prefetches `prefetch_factor` batches, so an oversized pool mostly costs memory.
train_num_workers = min(72, os.cpu_count() or 1)

train_data_loader = torch.utils.data.DataLoader(dataset=train_data,
                                                batch_size=batch_size,
                                                shuffle=True, pin_memory=True, num_workers=train_num_workers, prefetch_factor=20, persistent_workers=True)

test_data = DatasetFromFolder(data_dir, subfolder='validation', direction=direction, transform=transform)

test_data_loader = torch.utils.data.DataLoader(dataset=test_data,
                                               batch_size=batch_size,
                                               shuffle=True)


# ---- Models ----
G = Generator(3, ngf, 3)
D = SwinTDiscriminator()
if torch.cuda.device_count() > 1:  # If we have multiple GPUs
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    D = nn.DataParallel(D)
    G = nn.DataParallel(G)
D.cuda()
G.cuda()


# ---- Loss functions ----
BCE_loss = torch.nn.BCELoss().cuda()
L1_loss = torch.nn.L1Loss().cuda()
L2_loss = torch.nn.MSELoss().cuda()


perceptual_loss = PerceptualLoss().cuda()
style_loss = StyleLoss().cuda()
msgms_loss = MSGMS_Loss().cuda()

# ---- Optimizers ----
G_optimizer = torch.optim.Adam(G.parameters(), lr=lrG, betas=(beta1, beta2))
D_optimizer = torch.optim.Adam(D.parameters(), lr=lrD, betas=(beta1, beta2))

def _set_decayed_lr(optimizer, epoch, base_lr):
    """Set every param group's lr to ``base_lr * 0.99**epoch`` and print the value."""
    lr = base_lr*(0.99**(epoch))
    print("lr is {}".format(lr))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def adjust_learning_rate1(optimizer, epoch):
    """Exponential decay (factor 0.99 per epoch) starting from 0.0001."""
    _set_decayed_lr(optimizer, epoch, 0.0001)


def adjust_learning_rate2(optimizer, epoch):
    """Exponential decay (factor 0.99 per epoch) starting from 0.0004."""
    _set_decayed_lr(optimizer, epoch, 0.0004)


# Per-epoch mean losses, appended at the end of every epoch.
D_avg_losses = []
G_avg_losses = []

# Global count of optimisation steps across all epochs.
step = 0

# Reconstruction-loss selection. The training loop only checks `loss_L1` and
# `loss_L1_SSIM_GMS_Style`; the other two flags are not consulted there.
loss_L1 = False
loss_L1_Style = False
loss_L1_SSIM_GMS = False
loss_L1_SSIM_GMS_Style = True

# Lowest validation loss seen so far (starts at a very large sentinel).
best_val_loss = 10000000000

from torch.utils.tensorboard import SummaryWriter
# `writer` was already created in the configuration section of this cell with
# the same log directory; opening a second SummaryWriter there would leave the
# first one unclosed and add a second event file, so it is not re-created.

# Initialize SSIM function
ssim_func = SSIM().cuda()

def _reconstruction_term(gen_image, target):
    """Return the un-weighted reconstruction loss selected by the loss_* flags.

    Raises:
        ValueError: if neither supported flag is enabled (previously this
            surfaced later as a NameError on ``l1_loss``).
    """
    if loss_L1:
        # using pure L1 loss
        return L1_loss(gen_image, target)
    if loss_L1_SSIM_GMS_Style:
        loss_MSGMS = msgms_loss(gen_image, target)
        loss_SSIM = 1 - ssim(gen_image, target)
        gen_style_loss = style_loss(gen_image, target) * 10
        return gen_style_loss + loss_MSGMS + loss_SSIM + L1_loss(gen_image, target)
    raise ValueError("Enable loss_L1 or loss_L1_SSIM_GMS_Style; no other loss mode is implemented.")


for epoch in range(num_epochs):
    time1 = time.time()
    D_losses = []
    G_losses = []
    adjust_learning_rate1(G_optimizer, epoch)
    adjust_learning_rate2(D_optimizer, epoch)
    # Re-created every epoch: after training it holds the final epoch only.
    train_ssim_values = []
    for i, (input, target, mask,img_target2) in enumerate(train_data_loader):

        x_ = input.cuda()
        y_ = target.cuda()

        # ---- Discriminator update ----
        D_real_decision = D(x_, y_).squeeze()
        real_ = torch.ones_like(D_real_decision)
        D_real_loss = BCE_loss(D_real_decision, real_)

        # The discriminator step must not backpropagate into the generator, so
        # its fake sample is produced without recording a graph. The gradients
        # that used to reach G here were discarded by G_optimizer.zero_grad().
        with torch.no_grad():
            fake_for_D = G(x_)
        D_fake_decision = D(x_, fake_for_D).squeeze()
        fake_ = torch.zeros_like(D_fake_decision)
        D_fake_loss = BCE_loss(D_fake_decision, fake_)

        D_loss = (D_real_loss + D_fake_loss) * 0.5
        D_optimizer.zero_grad()
        D_loss.backward()
        D_optimizer.step()

        # ---- Generator update ----
        gen_image = G(x_)

        with torch.no_grad():
            ssim_value = ssim_func(gen_image, y_)
            train_ssim_values.append(ssim_value.item())

        attention_mask = get_pore_attention_mask(Model_Supervised,img_target2, y_ , epoch)
        attention_loss = attention_guided_l1_loss(gen_image, y_, attention_mask,epoch)

        D_fake_decision = D(x_, gen_image).squeeze()
        G_fake_loss = BCE_loss(D_fake_decision, real_) # fool the discriminator into classifying the generated image as real.

        l1_loss = lamb * _reconstruction_term(gen_image, y_)

        G_loss = G_fake_loss + l1_loss + attention_loss
        # Store a detached copy: keeping the live tensor would keep every
        # step's autograd graph reachable for the whole run.
        all_attention_loss.append(attention_loss.detach())
        G_optimizer.zero_grad()
        G_loss.backward()
        G_optimizer.step()

        # loss values
        D_losses.append(D_loss.item())
        G_losses.append(G_loss.item())

        print('Epoch [%d/%d], Step [%d/%d], D_loss: %.4f, G_loss: %.4f '
              % (epoch+1, num_epochs, i+1, len(train_data_loader), D_loss.item(), G_loss.item()))
        step += 1

    D_avg_loss = torch.mean(torch.FloatTensor(D_losses))
    G_avg_loss = torch.mean(torch.FloatTensor(G_losses))

    D_avg_losses.append(D_avg_loss)
    G_avg_losses.append(G_avg_loss)

    avg_train_ssim = sum(train_ssim_values) / len(train_ssim_values)

    print(f'Epoch [{epoch+1}/{num_epochs}], Train SSIM: {avg_train_ssim:.4f}')

    writer.add_scalar('G_loss_mean', G_avg_loss, epoch)
    writer.add_scalar('D_loss_mean', D_avg_loss, epoch)

    time2 = time.time()
    print("Time for Each Epoch is {}".format(time2 - time1))

    # ---- Validation every second epoch ----
    if (epoch+1) % 2 == 0:
        val_losses = 0.00
        # NOTE(review): G stays in training mode here, so normalisation layers
        # use batch statistics during validation — confirm this is intended.
        with torch.no_grad():
            for i, (input, target, mask,img_target2) in enumerate(test_data_loader):
                x_ = input.cuda()
                y_ = target.cuda()

                gen_image = G(x_)
                term = _reconstruction_term(gen_image, y_)
                # As before: pure L1 is not scaled by lamb during validation,
                # the composite loss is.
                l1_loss = term if loss_L1 else lamb * term

                # Accumulate a Python float so no tensors are kept alive.
                val_losses += l1_loss.item()

        if val_losses < best_val_loss:
            best_val_loss = val_losses
            print("best_val_loss is {}".format(best_val_loss))
            torch.save(G.state_dict(), model_dir + 'best_G_param.pkl')
            torch.save(D.state_dict(), model_dir + 'best_D_param.pkl')
            print("the best model is epoch_{}".format(epoch + 1))


    # ---- Periodic checkpoint every fifth epoch ----
    if (epoch+1) % 5 == 0:
        torch.save(G.state_dict(), model_dir + '%d'%(epoch +1) +'generator_param.pkl')
        torch.save(D.state_dict(), model_dir + '%d'%(epoch +1) + 'discriminator_param.pkl')
    

In [None]:
import matplotlib.pyplot as plt

# `train_ssim_values` is re-created at the start of every epoch by the training
# loop, so at this point it holds one SSIM value per batch of the final epoch
# only. The x-axis therefore counts iterations, not epochs.
epochs_ssim = range(len(train_ssim_values))
# Plot SSIM 
plt.figure(figsize=(10, 5))
plt.plot(epochs_ssim, train_ssim_values, label='SSIM', color='red')
plt.xlabel('Iteration (final epoch)')
plt.ylabel('SSIM Value')
plt.title('Structural Similarity Index (SSIM) During Training')
plt.legend()
#plt.grid(True)
plt.savefig("Structural Similarity Index.png", dpi=600) 
plt.show() 

# **Test**

In [None]:
import numpy as np
import torch
from torch.autograd import Variable
import os
import torch.nn as nn
from PIL import Image
import torch.utils.data as data
import random
import cv2 as cv
from torchvision import datasets, transforms
from segmentation_models_pytorch import UnetPlusPlus
from torchsummary import summary
from timm.models.swin_transformer import SwinTransformer
from vit_pytorch import ViT
import cv2
import matplotlib.pyplot as plt
import copy


# Dataset root, and the directory path kept for saving results.
data_dir = '/kaggle/input/path dataset'
save_error_dir = '/kaggle/working/result/'

# Filter-count constants; same values as in the training cell.
ngf = 64
ndf = 64

def random_checkboard_mask_new(img, ratio_n=None):
    """Load one of the six pre-computed checkerboard masks (ck_0 .. ck_5).

    This repeats the definition given earlier in the notebook so the test
    cell can run on its own.

    Args:
        img: Unused; kept so existing call sites keep working.
        ratio_n: Optional selector. When ``None`` (the default) a uniform
            random value in [0, 1) is drawn with ``torch.rand``. Any other
            value is used directly as the selector instead of the random draw
            (previously a non-None value raised ``UnboundLocalError``).
            NOTE(review): the intended meaning of ``ratio_n`` is not visible
            in this notebook — confirm this interpretation.

    Returns:
        The array stored in ``ck_<k>.npy`` where ``k`` is the index of the
        sixth of [0, 1) that the selector falls into (values >= 5/6 map to 5).
    """
    if ratio_n is None:
        selector = float(torch.rand(1))
    else:
        selector = float(ratio_n)

    # Count how many of the thresholds 1/6 .. 5/6 the selector has reached.
    # Half-open intervals mean an exact boundary value now maps to the next
    # mask instead of falling through to ck_5 as the old strict chain did.
    mask_index = sum(selector >= k / 6 for k in range(1, 6))

    return np.load("/kaggle/input/ck-mask/ck_mask/ck_%d.npy" % mask_index)



class Generator(torch.nn.Module):
    """UNet++ generator with an ImageNet-initialised ResNet-34 encoder (test-cell copy).

    ``num_filter`` is accepted for signature compatibility but is not used.
    """

    def __init__(self, input_dim, num_filter, output_dim):
        """Build the UNet++ with ``input_dim`` input and ``output_dim`` output channels."""
        super().__init__()
        unet_kwargs = dict(
            encoder_name="resnet34",
            encoder_weights="imagenet",
            in_channels=input_dim,
            classes=output_dim,
            decoder_use_batchnorm=True,
        )
        self.model = UnetPlusPlus(**unet_kwargs)

    def forward(self, x):
        """Run ``x`` through the wrapped UNet++ and return its raw output."""
        generated = self.model(x)
        return generated
    
    
class SwinTDiscriminator(nn.Module):
    """Swin Transformer conditional discriminator (test-cell copy).

    The two images passed to ``forward`` are concatenated along the channel
    axis (hence ``in_chans=6``) and scored by the transformer; a sigmoid turns
    the output into a probability.
    """

    def __init__(self, image_size=256, patch_size=4, num_classes=1):
        """Create the Swin Transformer backbone with the fixed configuration below."""
        super().__init__()
        swin_config = dict(
            img_size=image_size,
            patch_size=patch_size,
            in_chans=6,
            num_classes=num_classes,
            embed_dim=48,
            depths=[1, 1, 3, 1],
            num_heads=[8,16,32,64],
            window_size=8,
            mlp_ratio=2.,
            qkv_bias=True,
            qk_scale=None,
            drop_rate=0.,
            attn_drop_rate=0.,
            drop_path_rate=0.1,
            norm_layer=nn.LayerNorm,
            ape=False,
            patch_norm=True,
            use_checkpoint=False,
        )
        self.swin_t = SwinTransformer(**swin_config)

    def forward(self, x, label):
        """Score the pair ``(x, label)``; returns sigmoid-activated outputs."""
        pair = torch.cat([x, label], 1)
        logits = self.swin_t(pair)
        return torch.sigmoid(logits)
        


class DatasetFromFolder(data.Dataset):
    """Paired-image dataset that reads same-named files from two sibling folders.

    Files are listed from ``<image_dir>/<subfolder>/a`` and looked up under the
    same name in ``<image_dir>/<subfolder>/b`` (or the reverse when
    ``direction`` is anything other than ``'AtoB'``).

    Args:
        image_dir: Root directory of the dataset.
        subfolder: Split folder under ``image_dir`` (e.g. ``'train'``, ``'test'``).
        direction: ``'AtoB'`` uses folder ``a`` as input and ``b`` as target;
            any other value swaps them.
        transform: Optional callable applied to both PIL images.
        resize_scale: If truthy, both images are resized to a square of this side.
        crop_size: If truthy, the same random square crop of this side is taken
            from both images.
        fliplr: If true, both images are flipped horizontally with probability 0.5.
    """

    def __init__(self, image_dir, subfolder='train', direction='AtoB', transform=None, resize_scale=None, crop_size=None, fliplr=False):
        super().__init__()
        if direction == 'AtoB':
            input_folder, target_folder = 'a', 'b'
        else:
            input_folder, target_folder = 'b', 'a'
        self.input_path = os.path.join(image_dir, subfolder, input_folder)
        self.target_path = os.path.join(image_dir, subfolder, target_folder)

        # Sorted so that index -> filename is stable across runs.
        self.image_filenames = sorted(os.listdir(self.input_path))

        self.direction = direction
        self.transform = transform
        self.resize_scale = resize_scale
        self.crop_size = crop_size
        self.fliplr = fliplr

    @staticmethod
    def _read_image(path):
        """Read ``path`` with ``cv.imread``, failing loudly if it cannot be decoded."""
        image = cv.imread(path)
        if image is None:
            # cv.imread signals failure by returning None instead of raising.
            raise FileNotFoundError('Could not read image: {}'.format(path))
        return image

    def __getitem__(self, index):
        """Return ``(input, target, raw_target)`` for the file at ``index``.

        ``input`` and ``target`` go through resize / crop / flip / ``transform``.
        ``raw_target`` is the target image as read from disk, only resized to
        256x256.

        Raises:
            FileNotFoundError: If either image cannot be read.
            ValueError: If ``crop_size`` is larger than the images being cropped.
        """
        filename = self.image_filenames[index]
        img_input = self._read_image(os.path.join(self.input_path, filename))
        img_target = self._read_image(os.path.join(self.target_path, filename))
        # None of the steps below modify this array in place, so the
        # unprocessed target can be kept by reference instead of re-reading it.
        img_target_raw = img_target

        if self.resize_scale:
            side = (self.resize_scale, self.resize_scale)
            img_input = cv.resize(img_input, side)
            img_target = cv.resize(img_target, side)

        if self.crop_size:
            # Bound the offsets by the actual image sizes so cropping also works
            # when no resize_scale was given.
            max_top = min(img_input.shape[0], img_target.shape[0]) - self.crop_size
            max_left = min(img_input.shape[1], img_target.shape[1]) - self.crop_size
            if max_top < 0 or max_left < 0:
                raise ValueError(
                    'crop_size {} is larger than the image for {}'.format(self.crop_size, filename))
            top = random.randint(0, max_top)
            left = random.randint(0, max_left)

            img_input = img_input[top:top + self.crop_size, left:left + self.crop_size, :]
            img_target = img_target[top:top + self.crop_size, left:left + self.crop_size, :]

        if self.fliplr and random.random() < 0.5:
            img_input = cv.flip(img_input, 1)
            img_target = cv.flip(img_target, 1)

        img_input = transforms.ToPILImage()(img_input)
        img_target = transforms.ToPILImage()(img_target)

        if self.transform is not None:
            img_input = self.transform(img_input)
            img_target = self.transform(img_target)

        img_target2 = cv.resize(img_target_raw, (256, 256))
        return img_input, img_target, img_target2

    def __len__(self):
        """Number of files found in the input folder."""
        return len(self.image_filenames)
    
    
def _minmax_to_unit(array):
    """Min-max rescale ``array`` to [0, 1]; a constant array maps to all zeros."""
    lowest = array.min()
    span = array.max() - lowest
    if span == 0:
        # Avoid 0/0 -> NaN for flat images (e.g. identical input and output).
        return np.zeros_like(array)
    return (array - lowest) / span


def save_error_maps_gray(input_name, input_ori, target, gen_image,
                                  epoch, save=False, save_dir='output_results/'):
    """Build the grayscale absolute-difference map between input and generated image.

    The first sample of ``input_ori`` and of ``gen_image`` is converted to
    channel-last layout, turned into grayscale with ``cv.COLOR_BGR2GRAY`` and
    min-max scaled to [0, 1]. The absolute difference of the two is min-max
    scaled again and quantised to ``uint8``.

    Args:
        input_name: File name of the input image; its stem names the output file.
        input_ori: CPU tensor indexed as ``[0]`` then transposed ``(1, 2, 0)``
            (presumably a ``(1, 3, H, W)`` batch - confirm against caller).
        target: Unused; kept for signature compatibility.
        gen_image: Generated CPU tensor with the same layout as ``input_ori``.
        epoch: Unused; kept for signature compatibility.
        save: When true, write the map to ``<save_dir>/<stem>.png``.
        save_dir: Output directory; created (including parents) if missing.

    Returns:
        The ``uint8`` error map as a 2-D numpy array.
    """
    # splitext keeps dots inside the stem ("scan.v2.png" -> "scan.v2"), so
    # different files cannot collapse to the same output name.
    name = os.path.splitext(input_name)[0]

    input_gray = _minmax_to_unit(
        cv.cvtColor(input_ori[0].numpy().transpose(1, 2, 0), cv.COLOR_BGR2GRAY))
    gen_image_gray = _minmax_to_unit(
        cv.cvtColor(gen_image[0].numpy().transpose(1, 2, 0), cv.COLOR_BGR2GRAY))

    error_map = _minmax_to_unit(np.absolute(input_gray - gen_image_gray))
    error_map = (error_map * 255).astype(np.uint8)

    if save:
        os.makedirs(save_dir, exist_ok=True)
        # os.path.join works whether or not save_dir ends with a separator.
        cv.imwrite(os.path.join(save_dir, '{:s}'.format(name) + '.png'), error_map)

    return error_map


# Preprocessing applied to every test image: resize, convert to tensor and
# map each channel from [0, 1] to [-1, 1].
test_transform = transforms.Compose([transforms.Resize(256),
                                     transforms.ToTensor(),
                                     transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])

test_data = DatasetFromFolder(data_dir, subfolder='test', direction='AtoB', transform=test_transform)
# batch_size=1 and shuffle=False keep loader index i aligned with
# test_data.image_filenames[i], which the loop below relies on.
test_data_loader = torch.utils.data.DataLoader(dataset=test_data,
                                               batch_size=1,
                                               shuffle=False)

G = Generator(3, ngf, 3)
D = SwinTDiscriminator()
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    D = nn.DataParallel(D)
    G = nn.DataParallel(G)
D.cuda()
G.cuda()


D.load_state_dict(torch.load('/kaggle/working/saved-model/100discriminator_param.pkl'))

# The generator is the model actually used below, so it needs its trained
# weights as well; otherwise the error maps come from an untrained decoder.
# NOTE(review): the file name is assumed from the discriminator checkpoint's
# naming scheme - confirm against the training cell that saves it.
generator_checkpoint = '/kaggle/working/saved-model/100generator_param.pkl'
if os.path.exists(generator_checkpoint):
    G.load_state_dict(torch.load(generator_checkpoint))
else:
    print('WARNING: generator checkpoint not found at %s; '
          'running with freshly initialised generator weights.' % generator_checkpoint)

# Put batch-norm layers into inference mode.
G.eval()

# no_grad: autograd bookkeeping is not needed for inference.
with torch.no_grad():
    for i, (input_batch, target, img_target2) in enumerate(test_data_loader):

        input_name = test_data.image_filenames[i]

        # Min-max stretch the (normalised) input back to uint8, then push it
        # through test_transform again before feeding the generator.
        sample = input_batch[0]
        input_np = (((sample - sample.min()) * 255) / (sample.max() - sample.min())
                    ).numpy().transpose(1, 2, 0).astype(np.uint8)

        input_ori = transforms.ToPILImage()(input_np)
        input_ori = test_transform(input_ori)

        input_ori = torch.unsqueeze(input_ori, 0)

        x_ = input_ori.cuda()

        gen_image = G(x_).cpu()

        save_error_maps_gray(input_name, input_ori, target, gen_image,
                             i, save=True, save_dir=save_error_dir)
        print('%d images are generated.' % (i + 1))
        