In [1]:
import torch
from torchmetrics import StructuralSimilarityIndexMeasure, PeakSignalNoiseRatio

import albumentations as A
from albumentations.pytorch import ToTensorV2

import os
from os import path
from PIL import Image
import numpy as np

from tqdm import tqdm

In [2]:
# Filesystem locations, relative to the working directory.
# NOTE: DATA_PATH has no trailing slash (unlike LOAD_PATH) -- build sub-paths
# with os.path.join rather than string concatenation.
PATH_NAME = './'
DATA_PATH = './maps'
LOAD_PATH = './hpix-weights/'

In [3]:
# ---- Runtime / hyper-parameters -------------------------------------------
# NOTE(review): several of these (BATCH_SIZE, NUM_EPOCHS, LAMBDA_L1, ...) look
# like training settings carried over; the evaluation code here does not read them.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
LEARNING_RATE = 2e-4
BATCH_SIZE = 1
START_EPOCH = 1
NUM_EPOCHS = 200
NUM_WORKERS = 2
IMAGE_SIZE = 256
CHANNELS_IMG = 3
LAMBDA_L1 = 100
LOAD_MODEL = True
SAVE_MODEL = True

# ---- Checkpoints evaluated by main() --------------------------------------
LOAD_GEN_GLOBAL = f'{LOAD_PATH}global_gen.pth.tar'
LOAD_GEN_LOCAL = f'{LOAD_PATH}local_gen.pth.tar'

# ---- Augmentations ---------------------------------------------------------
# Geometric transforms applied jointly: "image0" (the target) receives exactly
# the same resize / crop / flip as "image" (the input).
both_transform_train = A.Compose(
    [
        A.Resize(286, 286),  # upscale first, then random-crop back to IMAGE_SIZE
        A.RandomCrop(IMAGE_SIZE, IMAGE_SIZE),
        A.HorizontalFlip(p=0.5),
    ],
    additional_targets={"image0": "image"},
)

both_transform_test = A.Compose(
    [A.Resize(IMAGE_SIZE, IMAGE_SIZE)],
    additional_targets={"image0": "image"},
)


def _normalize_to_tensor():
    """Return a fresh pipeline mapping 0-255 pixels to [-1, 1] and converting to a tensor.

    With mean = std = 0.5 and max_pixel_value = 255 the mapping is (v / 255 - 0.5) / 0.5.
    """
    return A.Compose(
        [
            A.Normalize(mean=[0.5] * 3, std=[0.5] * 3, max_pixel_value=255.0),
            ToTensorV2(),
        ],
    )


# Inputs and targets use the same per-image normalization (separate instances).
transform_only_input = _normalize_to_tensor()
transform_only_mask = _normalize_to_tensor()

In [4]:
def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore model and optimizer state from a checkpoint, then force the learning rate.

    Args:
        checkpoint_file: path to a file holding a dict with keys 'state_dict'
            (model weights) and 'optimizer' (optimizer state).
        model: module whose weights are overwritten in place.
        optimizer: optimizer whose state is overwritten in place.
        lr: learning rate written into every param group after loading, replacing
            whatever value the checkpoint carried.
    """
    print('=> Loading Checkpoint')
    # NOTE: torch.load unpickles the file -- only load checkpoints from trusted sources.
    saved = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state_dict(saved['state_dict'])
    optimizer.load_state_dict(saved['optimizer'])

    for group in optimizer.param_groups:
        group.update(lr=lr)


def pixelLevelAcc(pred, target, tolerance=5):
    """Fraction of pixels whose every channel is within ``tolerance`` of the target.

    Only the first sample of the batch is scored (``pred[0]`` / ``target[0]``).

    Args:
        pred: tensor shaped (N, C, H, W); assumed to hold 0-255 pixel values.
        target: tensor with the same shape as ``pred``.
        tolerance: largest per-channel absolute difference still counted as a match.

    Returns:
        float in [0, 1]: number of matching pixels divided by H * W.
    """
    # Compare in float64 so unsigned inputs (e.g. uint8) cannot wrap around on
    # subtraction; integer pixel values are represented exactly.
    pred_hwc = pred[0].permute(1, 2, 0).cpu().numpy().astype(np.float64)
    target_hwc = target[0].permute(1, 2, 0).cpu().numpy().astype(np.float64)

    # A pixel matches when its worst channel is within tolerance. Vectorized
    # instead of a per-pixel Python loop (~65k iterations for 256x256).
    worst_channel_diff = np.abs(pred_hwc - target_hwc).max(axis=-1)
    height, width = worst_channel_diff.shape
    return np.count_nonzero(worst_channel_diff <= tolerance) / (height * width)

def ssimAcc(pred, target):
    """Structural similarity of two image batches on a 0-255 scale.

    Both tensors are cast to CPU float32 before scoring; a fresh metric object is
    built per call, so no state accumulates between calls.
    """
    metric = StructuralSimilarityIndexMeasure(data_range=255)
    return metric(pred.type(torch.FloatTensor), target.type(torch.FloatTensor))

def psnrAcc(pred, target, data_range=255.0):
    """Peak signal-to-noise ratio (dB) of two image batches.

    Args:
        pred: predicted images, assumed on a 0-255 scale.
        target: reference images, same shape as ``pred``.
        data_range: peak-to-peak pixel range. Fixed to 255 by default so every image
            is scored against the same MAX, consistent with ``ssimAcc``; without it
            the metric infers the range from each target's own min/max.

    Returns:
        Scalar tensor with the PSNR value.
    """
    pred = pred.type(torch.FloatTensor)
    target = target.type(torch.FloatTensor)
    psnr = PeakSignalNoiseRatio(data_range=data_range)
    return psnr(pred, target)

In [5]:
def _denormalize_to_pixels(img):
    """Invert the [-1, 1] normalization and return integer pixels in [0, 255] (CPU int32)."""
    # Round instead of truncating: after the float round-trip a pixel value of e.g. 1
    # can come back as 0.99999994, which a plain int cast would floor to 0.
    # Clamp guards against values marginally outside [-1, 1].
    return ((img * 0.5 + 0.5) * 255).round().clamp(0, 255).type(torch.IntTensor)


def evaluate(global_gen, local_gen, loader, deep_supervision=True):
    """Run both generators over ``loader`` and average pixel accuracy, SSIM and PSNR.

    Args:
        global_gen: coarse generator, called as ``global_gen(x)``.
        local_gen: refinement generator, called as ``local_gen(x, coarse_pred)``.
        loader: sized iterable of ``(x, y)`` batches; values are expected to be
            normalized to [-1, 1] (the inverse mapping below assumes it).
        deep_supervision: if True, ``global_gen`` returns a list of predictions and
            element 6 is used as the coarse prediction.

    Returns:
        Tuple ``(pixel_acc, ssim, psnr)``, each the mean of the per-batch scores.
        Metrics for the first five batches are also printed.
    """
    global_gen.eval()
    local_gen.eval()
    pixel_acc = 0
    ssim_acc = 0
    psnr_acc = 0
    with torch.no_grad():
        for idx, (x, y) in enumerate(tqdm(loader)):
            x = x.to(DEVICE)
            y = y.to(DEVICE)
            global_out = global_gen(x)
            pred_g = global_out[6] if deep_supervision else global_out
            pred = local_gen(x, pred_g)

            # Back to integer 0-255 pixel space before scoring.
            pred = _denormalize_to_pixels(pred)
            y = _denormalize_to_pixels(y)

            p_acc = pixelLevelAcc(pred, y)
            s_acc = ssimAcc(pred, y)
            ps_acc = psnrAcc(pred, y)
            if idx < 5:
                print(f'Pixel Accuracy: {p_acc}')
                print(f'SSIM Accuracy: {s_acc}')
                print(f'PSNR Accuracy: {ps_acc}')
            pixel_acc += p_acc
            ssim_acc += s_acc
            psnr_acc += ps_acc
    num_batches = len(loader)
    return pixel_acc / num_batches, ssim_acc / num_batches, psnr_acc / num_batches

In [None]:
# Paired input/target dataset used for validation.
class MapDatasetLoader(torch.utils.data.Dataset):
    """Dataset of side-by-side image pairs: left half is the input, right half the target.

    Args:
        root_dir: directory containing the paired image files.

    Each item is ``(input_img, target_img)`` after ``both_transform_test`` and the
    per-image normalization pipelines.
    """

    def __init__(self, root_dir):
        self.root_dir = root_dir
        # Sorted so iteration order (and which samples get printed first) is
        # reproducible; os.listdir order is arbitrary. Hidden files (e.g.
        # .DS_Store) and sub-directories are skipped.
        self.files = sorted(
            name
            for name in os.listdir(self.root_dir)
            if not name.startswith('.')
            and os.path.isfile(os.path.join(self.root_dir, name))
        )

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.root_dir, self.files[idx])
        # The context manager closes the file handle; convert("RGB") guarantees
        # three channels for the 3-channel normalization that follows.
        with Image.open(img_path) as img:
            image = np.array(img.convert("RGB"))

        # Split at the horizontal midpoint (600 px for a 1200-px-wide pair).
        half = image.shape[1] // 2
        input_img = image[:, :half, :]
        target_img = image[:, half:, :]

        augmented = both_transform_test(image=input_img, image0=target_img)
        input_img, target_img = augmented["image"], augmented["image0"]
        input_img = transform_only_input(image=input_img)["image"]
        target_img = transform_only_mask(image=target_img)["image"]

        return input_img, target_img

In [None]:
class GlobalBlock(torch.nn.Module):
    """Resample -> InstanceNorm -> activation, optionally followed by Dropout(0.5).

    Args:
        in_channels: input channel count.
        out_channels: output channel count.
        downsample: True -> stride-2 reflect-padded Conv2d (halves H and W);
            False -> stride-2 ConvTranspose2d (doubles H and W).
        act: 'relu' for ReLU, anything else for LeakyReLU(0.2).
        use_dropout: apply Dropout(0.5) after the activation.
    """

    def __init__(self, in_channels, out_channels, downsample=True, act='relu', use_dropout=False):
        super().__init__()
        if downsample:
            resample = torch.nn.Conv2d(in_channels, out_channels, 4, 2, 1,
                                       bias=False, padding_mode='reflect')
        else:
            resample = torch.nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False)
        norm = torch.nn.InstanceNorm2d(out_channels, affine=True)
        activation = torch.nn.ReLU() if act == 'relu' else torch.nn.LeakyReLU(0.2)
        # Attribute names and layer order form the state_dict keys
        # ("conv.0.weight", "conv.1.weight", ...) used by saved checkpoints.
        self.conv = torch.nn.Sequential(resample, norm, activation)
        self.use_dropout = use_dropout
        self.dropout = torch.nn.Dropout(0.5)

    def forward(self, x):
        out = self.conv(x)
        if self.use_dropout:
            out = self.dropout(out)
        return out

In [None]:
class TransitionBlock(torch.nn.Module):
    """Size-preserving fusion unit: 3x3 Conv2d (padding 1) -> InstanceNorm -> LeakyReLU(0.2).

    Args:
        in_channels: channels of the (usually concatenated) input.
        out_channels: channels produced.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Attribute names are the state_dict keys used by checkpoints; keep them.
        self.leakyrelu = torch.nn.LeakyReLU(0.2)
        self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.inorm = torch.nn.InstanceNorm2d(out_channels, affine=True)

    def forward(self, x):
        return self.leakyrelu(self.inorm(self.conv(x)))

In [None]:
class GlobalGenerator(torch.nn.Module):
    """Coarse (global) generator built as a densely nested encoder-decoder.

    Node outputs are named ``x{i}_{j}`` after the module ``conv{i}_{j}`` that
    produces them (``i`` = depth, ``j`` = nesting level):

    * Encoder column (``j == 0``): ``conv0_0`` .. ``conv6_0`` and ``conv7_0`` are
      stride-2 convolutions, each halving H and W (e.g. 256 -> 128 -> ... -> 2 -> 1
      for a 256x256 input). ``conv7_1`` upsamples the 1x1 bottleneck back to 2x2.
    * Inner nodes (``i + j < 7``, ``TransitionBlock``): size-preserving fusion of
      ``[x{i}_0, ..., x{i}_{j-1}, up{i+1}(x{i+1}_{j-1})]``.
    * Edge nodes (``i + j == 7``: ``conv6_1``, ``conv5_2``, ..., ``conv1_6``):
      transposed ``GlobalBlock``s that fuse and upsample in one step. Their output
      already has the resolution of depth ``i - 1`` and is concatenated directly
      (without an ``up*`` module) into node ``(i - 1, j + 1)``.

    ``up1`` .. ``up6`` are shared: the same ``up{k}`` module (same weights) is
    reused at every nesting level.

    Args:
        in_channels: channels of the input image; also the number of output channels.
        deep_supervision: if True, ``forward`` returns a list of 7 predictions;
            otherwise a single tensor.
        **kwargs: accepted and ignored.
    """

    def __init__(self, in_channels=3, deep_supervision=False, **kwargs):
        super().__init__()
        # Channel width per depth 0..6.
        features = [64, 128, 256, 512, 512, 512, 512]

        # up{k}: transposed conv mapping depth k (features[k]) to depth k-1
        # (features[k-1]), doubling H and W.
        self.up1 = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(features[1], features[0], 4, 2, 1),
            torch.nn.InstanceNorm2d(features[0], affine=True),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.5)
        )
        self.up2 = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(features[2], features[1], 4, 2, 1),
            torch.nn.InstanceNorm2d(features[1], affine=True),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.5)
        )
        self.up3 = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(features[3], features[2], 4, 2, 1),
            torch.nn.InstanceNorm2d(features[2], affine=True),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.5)
        )
        self.up4 = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(features[4], features[3], 4, 2, 1),
            torch.nn.InstanceNorm2d(features[3], affine=True),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.5)
        )
        self.up5 = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(features[5], features[4], 4, 2, 1),
            torch.nn.InstanceNorm2d(features[4], affine=True),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.5)
        )
        self.up6 = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(features[6], features[5], 4, 2, 1),
            torch.nn.InstanceNorm2d(features[5], affine=True),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.5)
        )

        self.deep_supervision = deep_supervision

        # Encoder column: first layer has no normalization.
        self.conv0_0 = torch.nn.Sequential(
            torch.nn.Conv2d(
                in_channels, features[0], 4, 2, 1, padding_mode='reflect'),
            torch.nn.LeakyReLU(0.2)
        )
        self.conv1_0 = GlobalBlock(
            features[0], features[1], downsample=True, use_dropout=False, act='leaky')
        self.conv2_0 = GlobalBlock(
            features[1], features[2], downsample=True, use_dropout=False, act='leaky')
        self.conv3_0 = GlobalBlock(
            features[2], features[3], downsample=True, use_dropout=False, act='leaky')
        self.conv4_0 = GlobalBlock(
            features[3], features[4], downsample=True, use_dropout=False, act='leaky')
        self.conv5_0 = GlobalBlock(
            features[4], features[5], downsample=True, use_dropout=False, act='leaky')
        self.conv6_0 = GlobalBlock(
            features[5], features[6], downsample=True, use_dropout=False, act='leaky')

        # Bottleneck: plain conv (no norm / activation), then upsample back one level.
        self.conv7_0 = torch.nn.Sequential(
            torch.nn.Conv2d(features[6], features[6],
                            4, 2, 1, padding_mode='reflect'),
        )
        self.conv7_1 = GlobalBlock(
            features[6], features[6], downsample=False, use_dropout=True, act='relu')

        # Nesting level 1: inputs are [x{i}_0, up{i+1}(x{i+1}_0)] -> 2x channels.
        self.conv0_1 = TransitionBlock(features[0] * 2, features[0])
        self.conv1_1 = TransitionBlock(features[1] * 2, features[1])
        self.conv2_1 = TransitionBlock(features[2] * 2, features[2])
        self.conv3_1 = TransitionBlock(features[3] * 2, features[3])
        self.conv4_1 = TransitionBlock(features[4] * 2, features[4])
        self.conv5_1 = TransitionBlock(features[5] * 2, features[5])
        # Edge node: fuses [x6_0, x7_1] and upsamples to depth 5.
        self.conv6_1 = GlobalBlock(
            features[6] * 2, features[5], downsample=False, use_dropout=True, act='relu')

        # Nesting level 2 (3x channels in).
        self.conv0_2 = TransitionBlock(features[0] * 3, features[0])
        self.conv1_2 = TransitionBlock(features[1] * 3, features[1])
        self.conv2_2 = TransitionBlock(features[2] * 3, features[2])
        self.conv3_2 = TransitionBlock(features[3] * 3, features[3])
        self.conv4_2 = TransitionBlock(features[4] * 3, features[4])
        self.conv5_2 = GlobalBlock(features[5] * 3, features[4],
                                   downsample=False, use_dropout=True, act='relu')

        # Nesting level 3 (4x channels in).
        self.conv0_3 = TransitionBlock(features[0] * 4, features[0])
        self.conv1_3 = TransitionBlock(features[1] * 4, features[1])
        self.conv2_3 = TransitionBlock(features[2] * 4, features[2])
        self.conv3_3 = TransitionBlock(features[3] * 4, features[3])
        self.conv4_3 = GlobalBlock(features[4] * 4, features[3],
                                   downsample=False, use_dropout=True, act='relu')

        # Nesting level 4 (5x channels in).
        self.conv0_4 = TransitionBlock(features[0] * 5, features[0])
        self.conv1_4 = TransitionBlock(features[1] * 5, features[1])
        self.conv2_4 = TransitionBlock(features[2] * 5, features[2])
        self.conv3_4 = GlobalBlock(features[3] * 5, features[2],
                                   downsample=False, use_dropout=True, act='relu')

        # Nesting level 5 (6x channels in).
        self.conv0_5 = TransitionBlock(features[0] * 6, features[0])
        self.conv1_5 = TransitionBlock(features[1] * 6, features[1])
        self.conv2_5 = GlobalBlock(features[2] * 6, features[1],
                                   downsample=False, use_dropout=True, act='relu')

        # Nesting level 6 (7x channels in).
        self.conv0_6 = TransitionBlock(features[0] * 7, features[0])
        self.conv1_6 = GlobalBlock(features[1] * 7, features[0],
                                   downsample=False, use_dropout=True, act='relu')

        # Output heads: transposed conv back to input resolution, Tanh -> (-1, 1).
        if self.deep_supervision:
            # final1..final6 each decode one depth-0 node x0_1..x0_6.
            self.final1 = torch.nn.Sequential(
                torch.nn.ConvTranspose2d(features[0], in_channels, 4, 2, 1),
                torch.nn.Tanh()
            )
            self.final2 = torch.nn.Sequential(
                torch.nn.ConvTranspose2d(features[0], in_channels, 4, 2, 1),
                torch.nn.Tanh()
            )
            self.final3 = torch.nn.Sequential(
                torch.nn.ConvTranspose2d(features[0], in_channels, 4, 2, 1),
                torch.nn.Tanh()
            )
            self.final4 = torch.nn.Sequential(
                torch.nn.ConvTranspose2d(features[0], in_channels, 4, 2, 1),
                torch.nn.Tanh()
            )
            self.final5 = torch.nn.Sequential(
                torch.nn.ConvTranspose2d(features[0], in_channels, 4, 2, 1),
                torch.nn.Tanh()
            )
            self.final6 = torch.nn.Sequential(
                torch.nn.ConvTranspose2d(features[0], in_channels, 4, 2, 1),
                torch.nn.Tanh()
            )
            # final7 decodes the concatenation of x0_0..x0_6 and x1_6 (8 x features[0]).
            self.final7 = torch.nn.Sequential(
                torch.nn.ConvTranspose2d(
                    features[0] * 8, in_channels, 4, 2, 1),
                torch.nn.Tanh()
            )
        else:
            # Same input as final7 above.
            self.final = torch.nn.Sequential(
                torch.nn.ConvTranspose2d(
                    features[0] * 8, in_channels, 4, 2, 1),
                torch.nn.Tanh()
            )

    def forward(self, x):
        """Run the nested network on ``x``.

        ``x`` is assumed to be (N, in_channels, H, W); because of the eight stride-2
        halvings, H and W must be multiples of 256 for the skip connections to line up.

        Returns:
            ``[o1, ..., o7]`` when deep supervision is enabled, else one tensor. Every
            output has ``in_channels`` channels, the input's H x W, and passes through Tanh.
        """
        # Encoder column and bottleneck.
        x0_0 = self.conv0_0(x)
        x1_0 = self.conv1_0(x0_0)
        x2_0 = self.conv2_0(x1_0)
        x3_0 = self.conv3_0(x2_0)
        x4_0 = self.conv4_0(x3_0)
        x5_0 = self.conv5_0(x4_0)
        x6_0 = self.conv6_0(x5_0)
        x7_0 = self.conv7_0(x6_0)
        x7_1 = self.conv7_1(x7_0)

        # Nesting level 1.
        x0_1 = self.conv0_1(torch.cat([x0_0, self.up1(x1_0)], dim=1))
        x1_1 = self.conv1_1(torch.cat([x1_0, self.up2(x2_0)], dim=1))
        x2_1 = self.conv2_1(torch.cat([x2_0, self.up3(x3_0)], dim=1))
        x3_1 = self.conv3_1(torch.cat([x3_0, self.up4(x4_0)], dim=1))
        x4_1 = self.conv4_1(torch.cat([x4_0, self.up5(x5_0)], dim=1))
        x5_1 = self.conv5_1(torch.cat([x5_0, self.up6(x6_0)], dim=1))
        x6_1 = self.conv6_1(torch.cat([x6_0, x7_1], dim=1))

        # Nesting level 2 (x6_1 is already at depth-5 resolution).
        x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up1(x1_1)], dim=1))
        x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up2(x2_1)], dim=1))
        x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up3(x3_1)], dim=1))
        x3_2 = self.conv3_2(torch.cat([x3_0, x3_1, self.up4(x4_1)], dim=1))
        x4_2 = self.conv4_2(torch.cat([x4_0, x4_1, self.up5(x5_1)], dim=1))
        x5_2 = self.conv5_2(torch.cat([x5_0, x5_1, x6_1], dim=1))

        # Nesting level 3.
        x0_3 = self.conv0_3(
            torch.cat([x0_0, x0_1, x0_2, self.up1(x1_2)], dim=1))
        x1_3 = self.conv1_3(
            torch.cat([x1_0, x1_1, x1_2, self.up2(x2_2)], dim=1))
        x2_3 = self.conv2_3(
            torch.cat([x2_0, x2_1, x2_2, self.up3(x3_2)], dim=1))
        x3_3 = self.conv3_3(
            torch.cat([x3_0, x3_1, x3_2, self.up4(x4_2)], dim=1))
        x4_3 = self.conv4_3(torch.cat([x4_0, x4_1, x4_2, x5_2], dim=1))

        # Nesting level 4.
        x0_4 = self.conv0_4(
            torch.cat([x0_0, x0_1, x0_2, x0_3, self.up1(x1_3)], dim=1))
        x1_4 = self.conv1_4(
            torch.cat([x1_0, x1_1, x1_2, x1_3, self.up2(x2_3)], dim=1))
        x2_4 = self.conv2_4(
            torch.cat([x2_0, x2_1, x2_2, x2_3, self.up3(x3_3)], dim=1))
        x3_4 = self.conv3_4(torch.cat([x3_0, x3_1, x3_2, x3_3, x4_3], dim=1))

        # Nesting level 5.
        x0_5 = self.conv0_5(
            torch.cat([x0_0, x0_1, x0_2, x0_3, x0_4, self.up1(x1_4)], dim=1))
        x1_5 = self.conv1_5(
            torch.cat([x1_0, x1_1, x1_2, x1_3, x1_4, self.up2(x2_4)], dim=1))
        x2_5 = self.conv2_5(
            torch.cat([x2_0, x2_1, x2_2, x2_3, x2_4, x3_4], dim=1))

        # Nesting level 6 (x1_6 ends up at depth-0 resolution).
        x0_6 = self.conv0_6(
            torch.cat([x0_0, x0_1, x0_2, x0_3, x0_4, x0_5, self.up1(x1_5)], dim=1))
        x1_6 = self.conv1_6(
            torch.cat([x1_0, x1_1, x1_2, x1_3, x1_4, x1_5, x2_5], dim=1))

        if self.deep_supervision:
            o1 = self.final1(x0_1)
            o2 = self.final2(x0_2)
            o3 = self.final3(x0_3)
            o4 = self.final4(x0_4)
            o5 = self.final5(x0_5)
            o6 = self.final6(x0_6)
            o7 = self.final7(
                torch.cat([x0_0, x0_1, x0_2, x0_3, x0_4, x0_5, x0_6, x1_6], dim=1))

            return [o1, o2, o3, o4, o5, o6, o7]
        else:
            final = self.final(
                torch.cat([x0_0, x0_1, x0_2, x0_3, x0_4, x0_5, x0_6, x1_6], dim=1))

            return final

In [None]:
class LocalBlock(torch.nn.Module):
    """Resample -> InstanceNorm -> activation, optionally followed by Dropout(0.5).

    Args:
        in_channels: input channel count.
        out_channels: output channel count.
        down: True -> stride-2 reflect-padded Conv2d (halves H and W);
            False -> stride-2 ConvTranspose2d (doubles H and W).
        act: 'relu' for ReLU, anything else for LeakyReLU(0.2).
        use_dropout: apply Dropout(0.5) after the activation.
    """

    def __init__(self, in_channels, out_channels, down=True, act='relu', use_dropout=False):
        super().__init__()
        layers = []
        if down:
            layers.append(torch.nn.Conv2d(in_channels, out_channels, 4, 2, 1,
                                          bias=False, padding_mode='reflect'))
        else:
            layers.append(torch.nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False))
        layers.append(torch.nn.InstanceNorm2d(out_channels, affine=True))
        layers.append(torch.nn.ReLU() if act == 'relu' else torch.nn.LeakyReLU(0.2))
        # "conv" / "dropout" and the layer order are the checkpoint state_dict keys.
        self.conv = torch.nn.Sequential(*layers)
        self.use_dropout = use_dropout
        self.dropout = torch.nn.Dropout(0.5)

    def forward(self, x):
        features = self.conv(x)
        return self.dropout(features) if self.use_dropout else features

In [None]:
class LocalGenerator(torch.nn.Module):
    def __init__(self, in_channels=3, features=64):
        super().__init__()
        self.initial_down = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels*2, features, 4, 2,
                            1, padding_mode='reflect'),
            torch.nn.LeakyReLU(0.2)
        )  # 128
        self.down1 = LocalBlock(features, features * 2, down=True,
                                use_dropout=False, act='leaky')  # 64
        self.down2 = LocalBlock(features * 2, features * 4,
                                down=True, use_dropout=False, act='leaky')  # 32
        self.down3 = LocalBlock(features * 4, features * 8,
                                down=True, use_dropout=False, act='leaky')  # 16
        self.down4 = LocalBlock(features * 8, features * 8,
                                down=True, use_dropout=False, act='leaky')  # 8
        self.down5 = LocalBlock(features * 8, features * 8,
                                down=True, use_dropout=False, act='leaky')  # 4
        self.down6 = LocalBlock(features * 8, features * 8,
                                down=True, use_dropout=False, act='leaky')  # 2

        self.bottleneck = torch.nn.Sequential(
            torch.nn.Conv2d(features * 8, features * 8, 4, 2,
                            1, padding_mode='reflect'),  # 1
        )

        self.up1 = LocalBlock(features * 8, features * 8, down=False,
                              use_dropout=True, act='relu')  # 4
        self.up2 = LocalBlock(features * 8 * 2, features * 8,
                              down=False, use_dropout=True, act='relu')  # 8
        self.up3 = LocalBlock(features * 8 * 2, features * 8,
                              down=False, use_dropout=True, act='relu')  # 16
        self.up4 = LocalBlock(features * 8 * 2, features * 8,
                              down=False, use_dropout=False, act='relu')  # 32
        self.up5 = LocalBlock(features * 8 * 2, features * 4,
                              down=False, use_dropout=False, act='relu')  # 64
        self.up6 = LocalBlock(features * 4 * 2, features * 2,
                              down=False, use_dropout=False, act='relu')  # 128
        self.up7 = LocalBlock(features * 2 * 2, features, down=False,
                              use_dropout=False, act='relu')  # 256

        self.final_up = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(features * 2, in_channels, 4, 2, 1),
            torch.nn.Tanh()
        )

    def forward(self, x, y):
        x = torch.cat([x, y], dim=1)
        d1 = self.initial_down(x)
        d2 = self.down1(d1)
        d3 = self.down2(d2)
        d4 = self.down3(d3)
        d5 = self.down4(d4)
        d6 = self.down5(d5)
        d7 = self.down6(d6)
        bottleneck = self.bottleneck(d7)
        up1 = self.up1(bottleneck)
        up2 = self.up2(torch.cat([up1, d7], 1))
        up3 = self.up3(torch.cat([up2, d6], 1))
        up4 = self.up4(torch.cat([up3, d5], 1))
        up5 = self.up5(torch.cat([up4, d4], 1))
        up6 = self.up6(torch.cat([up5, d3], 1))
        up7 = self.up7(torch.cat([up6, d2], 1))
        return self.final_up(torch.cat([up7, d1], 1))


In [None]:
def main():
    """Build both generators, optionally load their checkpoints, and report validation metrics."""
    global_gen = GlobalGenerator(
        in_channels=3, deep_supervision=True).to(DEVICE)
    local_gen = LocalGenerator(in_channels=3).to(DEVICE)

    # The optimizers only exist so load_checkpoint can restore their state;
    # nothing is trained here.
    opt_gen_global = torch.optim.Adam(
        global_gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
    opt_gen_local = torch.optim.Adam(
        local_gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))

    if LOAD_MODEL:
        load_checkpoint(LOAD_GEN_GLOBAL, global_gen, opt_gen_global, LEARNING_RATE)
        load_checkpoint(LOAD_GEN_LOCAL, local_gen, opt_gen_local, LEARNING_RATE)

    # Join path components: DATA_PATH ('./maps') has no trailing slash, so plain
    # concatenation with 'maps/val' would give './mapsmaps/val'.
    val_dataset_path = os.path.join(DATA_PATH, 'maps', 'val')
    val_dataset = MapDatasetLoader(root_dir=val_dataset_path)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=1, shuffle=False)

    pixel_level_accuracy, ssim_accuracy, psnr_accuracy = evaluate(
        global_gen, local_gen, val_loader)

    print(f'Average Pixel Level Accuracy: {pixel_level_accuracy}')
    print(f'Average SSIM Accuracy: {ssim_accuracy}')
    print(f'Average PSNR Accuracy: {psnr_accuracy}')

In [None]:
if __name__ == '__main__':
    # Entry point: build the generators, load weights (if LOAD_MODEL) and evaluate
    # on the validation split.
    main()