# FCN score for CycleGAN models

In [99]:
import os
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from evaluate.networks import fcn_8s
from skimage.io import imread
import torchvision.transforms as tvt
import PIL
import numpy as np
import matplotlib.pyplot as plt

In [3]:
from FCN.networks import FCN8s, VGGNet
from FCN.citydataset import city, classes_city, train_transforms

from evaluate.metrics import FCNScore
from evaluate import losses

## Dataset

### Folder

In [72]:
# Dataset location and split used for the ground-truth evaluation.
# NOTE(review): `folder` is not passed to `city()` below -- confirm it is still needed.
folder = "./datasets/cityscapes/"
mode = "test"

### Ground Truth FCN score

In [73]:
# Ground-truth split, scored one image per batch.
batch_size = 1
city_dataset = city(classes=classes_city, mode=mode)
city_loader = DataLoader(city_dataset, batch_size=batch_size)

### Model

In [74]:
class FCNScore(torch.nn.Module):
    """Score images by segmenting them with a frozen FCN.

    This notebook-local definition shadows ``evaluate.metrics.FCNScore``
    imported earlier.

    Args:
        model: segmentation network. Its output is reduced with
            ``argmax(dim=1)``, so it presumably returns (N, C, H, W) class
            scores -- TODO confirm.
        num_classes: number of classes passed on to ``losses.label_score``.
    """

    def __init__(self, model=None, num_classes=21):
        super(FCNScore, self).__init__()
        self.model = model
        self.name = 'fcn'
        self.num_classes = num_classes

    def forward(self, input, target):
        """Predict labels for ``input`` and score them against ``target``.

        The model is put in eval mode and run without gradients. ``input`` is
        moved to the device of the model's parameters rather than always to
        CUDA, so a CPU model works too. A parameter-less model receives the
        input unchanged.

        Returns:
            Whatever ``losses.label_score(labels, target, num_classes)`` returns.
        """
        self.model.eval()
        first_param = next(self.model.parameters(), None)
        images = input if first_param is None else input.to(first_param.device)
        with torch.no_grad():
            output = self.model(images)
            # softmax is monotonic along the class axis, so the argmax of the
            # raw scores is the same class; computing it would be wasted work.
            labels = output.argmax(dim=1)
        return losses.label_score(labels, target, self.num_classes)

In [75]:
# FCN-8s segmenter used as the FCN-score judge, with frozen VGG features.
# Available checkpoints:
#   CityScapes: FCN/checkpoints/EXP_11/checkpoint.pth.tar
#   Maps:       FCN/checkpoints/EXP_16/checkpoint.pth.tar
checkpoint_path = r"FCN/checkpoints/EXP_11/checkpoint.pth.tar"
fcn = FCN8s(pretrained_net=VGGNet(requires_grad=False), n_class=len(classes_city)).cuda()
fcn.load_state_dict(torch.load(checkpoint_path))

In [76]:
# Scorer wrapping the loaded FCN; one class per entry of classes_city.
fcn_score = FCNScore(num_classes=len(classes_city), model=fcn)

In [77]:
# Score every ground-truth batch: one row of (ppa, pca, iou_acc) per batch.
# Rows are stored unweighted because they are summarised with
# metrics.mean(axis=0) / metrics.std(axis=0) afterwards. Scaling them by
# batch_size would inflate both statistics whenever batch_size > 1.
metrics = []
for batch in city_loader:
    ppa, pca, iou_acc = fcn_score(batch['image'], batch['target'])
    metrics.append(np.array([ppa, pca, iou_acc]))
metrics = np.array(metrics)

In [79]:
# Column-wise mean of the ground-truth (ppa, pca, iou_acc) rows.
np.mean(metrics, axis=0)

array([0.84556149, 0.37381518, 0.29587395])

In [80]:
# Column-wise standard deviation of the ground-truth (ppa, pca, iou_acc) rows.
np.std(metrics, axis=0)

array([0.07882992, 0.05186049, 0.05365637])

### CycleGAN FCN score

In [102]:
# Per-model score tables, keyed by the model (results folder) name.
DF = {}

In [124]:
# Model whose generated images are evaluated, and the results-folder
# template its name is substituted into ("{}" placeholder).
model = "cityscapes_cyclegan_sn"
folder = "results/{}/test_latest/images"

In [125]:
class CycleCity(Dataset):
    """Dataset over a results folder, pairing generated images with label maps.

    Every file in ``folder.format(model)`` whose name contains ``'fake_A'``
    is an input image. Its target is the file with ``'fake_A'`` replaced by
    ``'real_B'``, read as an RGB colour label image and converted to class
    indices by nearest colour in ``classes``.

    Args:
        model: name substituted into ``folder``.
        folder: path template with one ``{}`` placeholder.
        transforms: callable applied to the image, or None to skip.
        classes: sequence of RGB colours (each reshapeable to 3 values),
            one per class; the class index is the position in this sequence.
    """

    def __init__(self, model, folder=folder, transforms=train_transforms, classes=classes_city):
        super(CycleCity, self).__init__()
        self.data_path = folder.format(model)
        self.transforms = transforms
        # Sort so the index -> file mapping is reproducible; os.listdir order
        # is arbitrary.
        self.ims = sorted(name for name in os.listdir(self.data_path) if 'fake_A' in name)
        self.masks = [name.replace('fake_A', 'real_B') for name in self.ims]
        self.classes = classes

    def __len__(self):
        return len(self.ims)

    def make_mask(self, mask):
        """Map an (H, W, 3) colour label image to an (H, W) LongTensor of class indices.

        Each pixel takes the index of the nearest colour (Euclidean distance)
        in ``self.classes``. The image size is read from ``mask`` instead of
        being assumed to be 256x256.
        """
        height, width = mask.shape[:2]
        pixels = torch.as_tensor(np.asarray(mask, dtype=np.float32).reshape(-1, 3))
        palette = torch.as_tensor(
            np.stack([np.asarray(colour, dtype=np.float32).reshape(3) for colour in self.classes])
        )
        # Broadcast (P, 1, 3) - (1, C, 3) -> (P, C) distances in one step
        # instead of one pass per class. Runs on the CPU: no hard CUDA
        # dependency, and it is safe inside DataLoader worker processes.
        distances = torch.norm(pixels[:, None, :] - palette[None, :, :], dim=2)
        return distances.argmin(dim=1).view(height, width).long()

    def __getitem__(self, index):
        """Return {'image': transformed image, 'target': class-index mask, 'im': file name}."""
        image = imread(os.path.join(self.data_path, self.ims[index]))
        label_image = imread(os.path.join(self.data_path, self.masks[index]))
        mask = self.make_mask(label_image)
        if self.transforms is not None:
            image = self.transforms(image)
        return {'image': image, 'target': mask, 'im': self.ims[index]}

In [126]:
# Generated images for `model`, scored one per batch.
batch_size = 1
cycle_city = CycleCity(model=model)
cycle_city_loader = DataLoader(cycle_city, batch_size=batch_size)

In [127]:
# Score every generated batch and remember which files it contained.
# Each (ppa, pca, iou_acc) row is weighted by the number of images actually in
# the batch (the last batch can be shorter than batch_size), so that
# cycle_metrics.sum(axis=0) / len(cycle_city) is the per-image mean.
cycle_metrics = []
ims = []
for batch in cycle_city_loader:
    ppa, pca, iou_acc = fcn_score(batch['image'], batch['target'])
    ims.append(batch['im'])
    n_images = len(batch['im'])
    cycle_metrics.append(n_images * np.array([ppa, pca, iou_acc]))
cycle_metrics = np.array(cycle_metrics)

In [128]:
# Mean per image (row sum divided by dataset size), then the spread of the rows.
print(np.sum(cycle_metrics, axis=0) / len(cycle_city))
print(np.std(cycle_metrics, axis=0))

[0.64273987 0.26691405 0.19499927]
[0.10265645 0.04264882 0.03582837]


In [129]:
# Keep only the batches whose first file name starts with '14', then store one
# row per image, [name, metric0, metric1, metric2], sorted by file name.
data = [
    (names, scores)
    for names, scores in zip(ims, cycle_metrics)
    if names[0].startswith('14')
]
D = [[names[0], scores[0], scores[1], scores[2]] for names, scores in data]
DF[model] = pd.DataFrame(D).sort_values(by=0)

In [135]:
def _imbyim_column(key):
    """Return the output column index used by imbyim for the model stored under ``key``.

    Checked in order: 'pretrained' -> 0, 'sn' -> 3, 'wgan' -> 2, otherwise 1.
    """
    if 'pretrained' in key:
        return 0
    if 'sn' in key:
        return 3
    if 'wgan' in key:
        return 2
    return 1


def imbyim(DF, reference='cityscapes_cyclegan'):
    """Build one single-row score table per image, comparing every model in ``DF``.

    Args:
        DF: dict mapping a model name to a DataFrame whose integer column 0
            holds image names and column 1 the score to compare.
        reference: key of ``DF`` whose image names are iterated (defaults to
            the previously hard-coded 'cityscapes_cyclegan').

    Returns:
        dict mapping image name -> 1-row DataFrame with columns
        ['PreTrained', 'Baseline', 'WGANGP', 'SN']. A category with no frame
        in ``DF``, or whose frame lacks the image, is NaN. If several keys map
        to the same category, the last one in ``DF`` wins.
    """
    columns = ['PreTrained', 'Baseline', 'WGANGP', 'SN']
    df = dict()
    # Iterate the image names stored in column 0; DataFrame.keys() would
    # yield the column labels (0..3) instead.
    for im in DF[reference][0]:
        values = np.full(len(columns), np.nan)
        for key, frame in DF.items():
            # Column labels are the ints 0..3, so select with 0/1, not "0".
            scores = frame.loc[frame[0] == im, 1]
            if not scores.empty:
                values[_imbyim_column(key)] = scores.iloc[0]
        # Wrap in a list: one row with four columns, not one column of four rows.
        df[im] = pd.DataFrame([values], columns=columns)
    return df

#### CycleGAN (pretrained)
Each vector below is (ppa, pca, iou_acc) as returned by `fcn_score`.

Mean [0.52215149 0.21816304 0.14388255]
STD  [0.07318946 0.0356971  0.02448053]

#### CycleGAN (baseline)
Mean [0.58749603 0.25522303 0.18075776]
STD  [0.0629457  0.03385866 0.02729342]

#### CycleGAN with spectral normalization (SN)
Mean [0.64273987 0.26691405 0.19499927]
STD  [0.10265645 0.04264882 0.03582837]


#### CycleGAN with WGAN-GP
Mean [0.66905457 0.29080951 0.19353192]
STD  [0.07726905 0.0395544  0.0304897 ]


In [165]:
# Comparison table: one row per model in DF (insertion order), one column per
# image. Column 1 of every per-model frame holds the first metric; the columns
# line up because every frame was sorted by image name when it was stored.
# np.vstack needs a list -- NumPy rejects generators passed to hstack/vstack.
# NOTE(review): the row labels assume DF was filled in the order pretrained,
# baseline, WGAN-GP, SN -- confirm before re-running cells out of order.
D = pd.DataFrame(
    np.vstack([DF[key][1].values for key in DF]),
    columns=DF['cityscapes_label2photo_pretrained'][0].values,
)
D.index = ['PreTrained', 'Baseline', 'WGANGP', 'SN']

In [168]:
def highlight_max(s, color='green'):
    '''
    Return CSS styles that highlight the maximum value(s) of a Series.

    Meant for ``DataFrame.style.apply``. Every entry equal to ``s.max()``
    gets ``'background-color: <color>'`` and every other entry gets ``''``.
    Ties are all highlighted.

    Args:
        s: pandas Series (one column of the styled frame).
        color: CSS colour for the highlight; defaults to 'green'.

    Returns:
        list of CSS strings, one per element of ``s``.
    '''
    is_max = s == s.max()
    return [f'background-color: {color}' if flag else '' for flag in is_max]

In [169]:
# Highlight the best model for each image (column-wise, Styler's default axis).
D.style.apply(highlight_max, axis=0)

Unnamed: 0,140_B_fake_B.png,141_B_fake_B.png,142_B_fake_B.png,143_B_fake_B.png,144_B_fake_B.png,145_B_fake_B.png
PreTrained,0.508453,0.424057,0.625854,0.553375,0.561066,0.485367
Baseline,0.557068,0.449554,0.609344,0.552261,0.534637,0.557755
WGANGP,0.651703,0.588394,0.668732,0.669403,0.652969,0.690491
SN,0.68428,0.54924,0.61618,0.562363,0.576981,0.588684
