<a href="https://colab.research.google.com/github/facebookresearch/vissl/blob/v0.1.6/tutorials/Feature_Extraction_V0_1_6.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

# Feature Extraction

In this tutorial, we look at a simple example of how to use VISSL to extract features after finished training the vissl moddels.

**EXAMPLE 1**: Download the pre-trained [jigsaw-retouch](https://drive.google.com/file/d/159SgjqklmLHWpEQNq14i_gJk0NDhyAHE/view?usp=sharing) to the `root` directory and rename it to `checkpoints`.

**EXAMPLE 2**: Download the pre-trained [jigsaw-cityscapes](https://drive.google.com/file/d/1Af710oLe_n1h4RMMnhdbxWQWiDCJx68j/view?usp=sharing) to the `root` directory and rename it to `checkpoints`.

VISSL should be successfuly installed by now and all the dependencies should be available.


In [2]:
import vissl
import tensorboard
import apex
import torch

## Using the custom data in VISSL

The original data is saved in the `data` directory. The transferred images are saved in such a way, that they are stored in the `data/generated_images/#epoch` directory (`#epoch` is the number of CycleGAN epoch).

**EXAMPLE 1**: download the retouch data set from [retouch-dataset](https://drive.google.com/file/d/1r8pQCoVzEAHdy9wLW_MUkyfgBBFePMPv/view?usp=sharing) and insert it into the `data/real_images` directory. Download the transferred images from [transferred-retouch-images](https://drive.google.com/file/d/1nMcyF-z2yvPBDY70qBsT2Ydg7NUITpmR/view?usp=sharing) and isert the subfolders with the epoch number into the `data/generated_images` directory.

**EXAMPLE 2**: download the truncated retouch GTAV data set from [gta5-truncated-dataset](https://drive.google.com/file/d/1R9zmrwAKf03KOq9MSfhdPd6xOVRGEtrY/view?usp=sharing) and insert it into the `data/real_images` directory. Download the transferred images from [transferred-gta5-images](https://drive.google.com/file/d/1SLdGNHDi3LZTHXXNMNFDTmAQibAEjj-x/view?usp=sharing) and isert the subfolders with the epoch number into the `data/generated_images` directory. Note, it also works with the whole data set, one only has to change the `splits/gta5.txt` to the whole dataset. The truncated version is used due to memory and time efficiency.

In [3]:
import skimage.io as io
import matplotlib.pylab as plt
import numpy as np

from scipy import ndimage
from scipy.ndimage.interpolation import zoom
from torch.utils.data import Dataset
# from medpy.io import load
import random

import torch
import os
from PIL import Image  # using pillow-simd for increased speed

import cv2
from torchvision import transforms


def pil_loader(path):
    # open path as file to avoid ResourceWarning
    # (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        with Image.open(f) as img:
            return img.convert("L")

class Retouch_dataset(Dataset):
    def __init__(self,
                 base_dir,
                 list_dir,
                 size=(512, 512),
                 is_train=False,
                 transform=None,
                 ext='.png'):
        self.transform = transform  # using transform in torch!
        self.sample_list = open(list_dir).readlines()

        self.data_dir = base_dir
        self.loader = pil_loader
        self.to_tensor = transforms.Compose(
            [
                transforms.Resize(size),
                transforms.ToTensor(),
                #transforms.Normalize(mean=(0.5,), std=(0.5,)),
            ])

        self.is_train = is_train
        self.transform = transform
        self.ext = ext

    def augment(self, data, label):
        data_label = torch.cat((data, label), dim=0)
        data_label_aug = self.transform(data_label)
        data_aug = data_label_aug[0, :, :].unsqueeze(0)
        label_aug = data_label_aug[1, :, :].unsqueeze(0)
        return data_aug, label_aug

    def __len__(self):
        return len(self.sample_list)

    def __getitem__(self, idx):
        sample_name = self.sample_list[idx].strip('\n')

        vendor = sample_name.split(' ')[0]
        slice_name = sample_name.split(' ')[1]
        slice_idx = sample_name.split(' ')[2].zfill(3)

        data_path = os.path.join(self.data_dir,
                                 vendor,
                                 slice_name,
                                 'images',
                                 slice_idx + self.ext)
        '''

        label_path = os.path.join(self.data_dir,
                                  vendor,
                                  slice_name,
                                  'labels',
                                  slice_idx + '.npy')

        label = torch.from_numpy(np.load(label_path))
        label_idx = torch.argmax(label, dim=0, keepdim=True)
        
        '''
        
        data = self.to_tensor(self.loader(data_path))
        transform_avaliable = self.transform is not None and self.is_train
        do_aug = transform_avaliable and random.random() > 0.5

        if do_aug:
            # data, label_idx = self.augment(data, label_idx)
            data = self.augment(data)

        # label_idx = label_idx.squeeze(0).long()

        sample = {'image': data.repeat(3, 1, 1),
                  'case_name': sample_name}
        # print((label_idx==0).sum()/512**2)
        return sample


  from scipy.ndimage.interpolation import zoom


In [4]:
# gta5 dataset (source)

import os
import os.path as osp
import sys

import numpy as np
import random
import matplotlib.pyplot as plt
import collections
import torch
import torchvision
from torch.utils import data
from PIL import Image, ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True

'''
labels = [
    #       name                     id    trainId   category            catId     hasInstances   ignoreInEval   color
    Label(  'unlabeled'            ,  0 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'ego vehicle'          ,  1 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'rectification border' ,  2 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'out of roi'           ,  3 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'static'               ,  4 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'dynamic'              ,  5 ,      255 , 'void'            , 0       , False        , True         , (111, 74,  0) ),
    Label(  'ground'               ,  6 ,      255 , 'void'            , 0       , False        , True         , ( 81,  0, 81) ),
    Label(  'road'                 ,  7 ,        0 , 'flat'            , 1       , False        , False        , (128, 64,128) ),
    Label(  'sidewalk'             ,  8 ,        1 , 'flat'            , 1       , False        , False        , (244, 35,232) ),
    Label(  'parking'              ,  9 ,      255 , 'flat'            , 1       , False        , True         , (250,170,160) ),
    Label(  'rail track'           , 10 ,      255 , 'flat'            , 1       , False        , True         , (230,150,140) ),
    Label(  'building'             , 11 ,        2 , 'construction'    , 2       , False        , False        , ( 70, 70, 70) ),
    Label(  'wall'                 , 12 ,        3 , 'construction'    , 2       , False        , False        , (102,102,156) ),
    Label(  'fence'                , 13 ,        4 , 'construction'    , 2       , False        , False        , (190,153,153) ),
    Label(  'guard rail'           , 14 ,      255 , 'construction'    , 2       , False        , True         , (180,165,180) ),
    Label(  'bridge'               , 15 ,      255 , 'construction'    , 2       , False        , True         , (150,100,100) ),
    Label(  'tunnel'               , 16 ,      255 , 'construction'    , 2       , False        , True         , (150,120, 90) ),
    Label(  'pole'                 , 17 ,        5 , 'object'          , 3       , False        , False        , (153,153,153) ),
    Label(  'polegroup'            , 18 ,      255 , 'object'          , 3       , False        , True         , (153,153,153) ),
    Label(  'traffic light'        , 19 ,        6 , 'object'          , 3       , False        , False        , (250,170, 30) ),
    Label(  'traffic sign'         , 20 ,        7 , 'object'          , 3       , False        , False        , (220,220,  0) ),
    Label(  'vegetation'           , 21 ,        8 , 'nature'          , 4       , False        , False        , (107,142, 35) ),
    Label(  'terrain'              , 22 ,        9 , 'nature'          , 4       , False        , False        , (152,251,152) ),
    Label(  'sky'                  , 23 ,       10 , 'sky'             , 5       , False        , False        , ( 70,130,180) ),
    Label(  'person'               , 24 ,       11 , 'human'           , 6       , True         , False        , (220, 20, 60) ),
    Label(  'rider'                , 25 ,       12 , 'human'           , 6       , True         , False        , (255,  0,  0) ),
    Label(  'car'                  , 26 ,       13 , 'vehicle'         , 7       , True         , False        , (  0,  0,142) ),
    Label(  'truck'                , 27 ,       14 , 'vehicle'         , 7       , True         , False        , (  0,  0, 70) ),
    Label(  'bus'                  , 28 ,       15 , 'vehicle'         , 7       , True         , False        , (  0, 60,100) ),
    Label(  'caravan'              , 29 ,      255 , 'vehicle'         , 7       , True         , True         , (  0,  0, 90) ),
    Label(  'trailer'              , 30 ,      255 , 'vehicle'         , 7       , True         , True         , (  0,  0,110) ),
    Label(  'train'                , 31 ,       16 , 'vehicle'         , 7       , True         , False        , (  0, 80,100) ),
    Label(  'motorcycle'           , 32 ,       17 , 'vehicle'         , 7       , True         , False        , (  0,  0,230) ),
    Label(  'bicycle'              , 33 ,       18 , 'vehicle'         , 7       , True         , False        , (119, 11, 32) ),
    Label(  'license plate'        , -1 ,       -1 , 'vehicle'         , 7       , False        , True         , (  0,  0,142) ),
]
'''

class GTA5Dataset(data.Dataset):
    def __init__(self, root, list_path, max_iters=None, crop_size=(256, 256), mean=(128, 128, 128), ignore_label=255):
        self.root = root
        self.list_path = list_path
        self.crop_size = crop_size
        self.ignore_label = ignore_label
        self.mean = mean
        self.img_ids = [i_id.strip() for i_id in open(list_path)]
        if not max_iters==None:
            self.img_ids = self.img_ids * int(np.ceil(float(max_iters) / len(self.img_ids)))
        self.files = []

        self.id_to_trainid = {7: 0, 8: 1, 11: 2, 12: 3, 13: 4, 17: 5,
                              19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11, 25: 12,
                              26: 13, 27: 14, 28: 15, 31: 16, 32: 17, 33: 18}

        self.id2label = {-1: ignore_label, 0: ignore_label, 1: ignore_label, 2: ignore_label,
            3: ignore_label, 4: ignore_label, 5: ignore_label, 6: ignore_label,
            7: 0, 8: 1, 9: ignore_label, 10: ignore_label, 11: 2, 12: 3, 13: 4,
            14: ignore_label, 15: ignore_label, 16: ignore_label, 17: 5,
            18: ignore_label, 19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11, 25: 12, 26: 13, 27: 14,
            28: 15, 29: ignore_label, 30: ignore_label, 31: 16, 32: 17, 33: 18}

        #self.id_to_trainid = {7: 1, 24: 2, 26: 3} #Road/car/people
        self.id_to_trainid = {11: 1, 24: 2, 21: 3} #Building/car/vegetation
        #self.ignore_label = 0

    def __len__(self):
        return len(self.img_ids)

    def __getitem__(self, index):
        name = self.img_ids[index]
        
        # name = name[:-4] + ".jpg"
        
        image = Image.open(osp.join(self.root, "images/%s" % name)).convert('RGB')
        label = Image.open(osp.join(self.root, "labels/%s" % name))
        # resize
        image = image.resize(self.crop_size, Image.BICUBIC)
        label = label.resize(self.crop_size, Image.NEAREST)

        image = np.asarray(image, np.float32)
        label = np.asarray(label, np.int8)

        label_copy = self.ignore_label * np.ones(label.shape, dtype=np.float32)
        #for k, v in self.id_to_trainid.items():
        for k, v in self.id2label.items():
            label_copy[label == k] = v
        size = image.shape
        #image = image[:, :, ::-1]  # change to BGR
        #image -= self.mean
        image = image.transpose((2, 0, 1))
        sample = {'image': image.copy(),
                  'label': label_copy.copy()}

        return sample


In [5]:
# Cityscapes dataset (target)

import os
import os.path as osp
import sys

import numpy as np
import random
import matplotlib.pyplot as plt
import collections
import torch
import torchvision
from torch.utils import data
from PIL import Image, ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True

palette = [128, 64, 128, 244, 35, 232, 70, 70, 70, 102, 102, 156, 190, 153, 153, 153, 153, 153, 250, 170, 30,
            220, 220, 0, 107, 142, 35, 152, 251, 152, 70, 130, 180, 220, 20, 60, 255, 0, 0, 0, 0, 142, 0, 0, 70,
            0, 60, 100, 0, 80, 100, 0, 0, 230, 119, 11, 32]
classes = ['road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'traffic light', 'traffic sign',
        'vegetation', 'terrain', 'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
        'bicycle']

    
class CityscapesDataset(data.Dataset):
    def __init__(self, root, list_path, max_iters=None, crop_size=(256, 256), mean=(128, 128, 128), ignore_label=255):
        self.root = root
        self.list_path = list_path
        self.crop_size = crop_size
        self.ignore_label = ignore_label
        self.mean = mean
        self.img_ids = [i_id.strip() for i_id in open(list_path)]
        if not max_iters==None:
            self.img_ids = self.img_ids * int(np.ceil(float(max_iters) / len(self.img_ids)))
        self.files = []
        
        self.id_to_trainid = {7: 0, 8: 1, 11: 2, 12: 3, 13: 4, 17: 5,
                              19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11, 25: 12,
                              26: 13, 27: 14, 28: 15, 31: 16, 32: 17, 33: 18}

        self.id2label = {-1: ignore_label, 0: ignore_label, 1: ignore_label, 2: ignore_label,
            3: ignore_label, 4: ignore_label, 5: ignore_label, 6: ignore_label,
            7: 0, 8: 1, 9: ignore_label, 10: ignore_label, 11: 2, 12: 3, 13: 4,
            14: ignore_label, 15: ignore_label, 16: ignore_label, 17: 5,
            18: ignore_label, 19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11, 25: 12, 26: 13, 27: 14,
            28: 15, 29: ignore_label, 30: ignore_label, 31: 16, 32: 17, 33: 18}

        #self.id_to_trainid = {7: 1, 24: 2, 26: 3} #Road/car/people
        self.id_to_trainid = {11: 1, 24: 2, 21: 3} #Building/car/vegetation
        #self.ignore_label = 0

    def __len__(self):
        return len(self.img_ids)

    def __getitem__(self, index):
        name = self.img_ids[index]
        image_root = osp.join(self.root, 'cityscapes')
        label_root = osp.join(self.root, 'gtFine')
        image = Image.open(osp.join(image_root, "%s" % name)).convert('RGB')
        label = Image.open(osp.join(label_root, "%s" % name.replace("leftImg8bit", "gtFine_labelIds")))
        # resize
        image = image.resize(self.crop_size, Image.BICUBIC)
        label = label.resize(self.crop_size, Image.NEAREST)

        image = np.asarray(image, np.float32)
        label = np.asarray(label, np.int8)

        label_copy = self.ignore_label * np.ones(label.shape, dtype=np.float32)
        #for k, v in self.id_to_trainid.items():
        for k, v in self.id2label.items():
            label_copy[label == k] = v
        size = image.shape
        #image = image[:, :, ::-1]  # change to BGR
        #image -= self.mean
        image = image.transpose((2, 0, 1))
        sample = {'image': image.copy(),
                  'label': label_copy.copy()}

        return sample
    

In [17]:
from omegaconf import OmegaConf
from vissl.utils.hydra_config import AttrDict
from vissl.models import build_model
from classy_vision.generic.util import load_checkpoint
from vissl.utils.checkpoint import init_model_from_consolidated_weights
from PIL import Image
import torchvision.transforms as transforms
import argparse

from datasets import create_dataset

import sys, os

from tqdm import tqdm

import numpy as np

import csv
import pandas as pd

from scipy import linalg
import torch

from sklearn.metrics.pairwise import cosine_similarity


def calculate_fid(feat1, feat2):
    """ Calculate FID between images1 and images2
    Args:
        images1: np.array, shape: (N, H, W, 3), dtype: np.float32 between 0-1 or np.uint8
        images2: np.array, shape: (N, H, W, 3), dtype: np.float32 between 0-1 or np.uint8
        use_multiprocessing: If multiprocessing should be used to pre-process the images
        batch size: batch size used for inception network
    Returns:
        FID (scalar)
    """
    mu1, sigma1 = calculate_activation_statistics(feat1)
    mu2, sigma2 = calculate_activation_statistics(feat2)
    fid = calculate_frechet_distance(mu1, sigma1, mu2, sigma2)

    return fid


def calculate_activation_statistics(feat):
    """Calculates the statistics used by FID
    Args:
        feat: torch.tensor, shape: (N, 2048), dtype: torch.float32 in range 0 - 1
    Returns:
        mu:     mean over all activations from the last pool layer of the inception model
        sigma:  covariance matrix over all activations from the last pool layer
                of the inception model.
    """

    feat_np = feat.cpu().detach().numpy()
    mu = np.mean(feat_np, axis=0) # (2048, 0)
    sigma = np.cov(feat_np, rowvar=False) # (2048, 2048)
    return mu, sigma


# Modified from: https://github.com/bioinf-jku/TTUR/blob/master/fid.py
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Numpy implementation of the Frechet Distance.
    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
    and X_2 ~ N(mu_2, C_2) is
            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).

    Stable version by Dougal J. Sutherland.
    Params:
    -- mu1 : Numpy array containing the activations of the pool_3 layer of the
             inception net ( like returned by the function 'get_predictions')
             for generated samples.
    -- mu2   : The sample mean over activations of the pool_3 layer, precalcualted
               on an representive data set.
    -- sigma1: The covariance matrix over activations of the pool_3 layer for
               generated samples.
    -- sigma2: The covariance matrix over activations of the pool_3 layer,
               precalcualted on an representive data set.
    Returns:
    --   : The Frechet Distance.
    """

    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)

    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)

    assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths"
    assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions"

    diff = mu1 - mu2
    # product might be almost singular
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        msg = "fid calculation produces singular product; adding %s to diagonal of cov estimates" % eps
        print(msg)
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # numerical error might give slight imaginary component
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            raise ValueError("Imaginary component {}".format(m))
        covmean = covmean.real

    tr_covmean = np.trace(covmean)

    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean


def create_dataset(dataset_mode, folder_name, sample_list, size):
    if dataset_mode == "OCT":
        dataset = Retouch_dataset(base_dir=folder_name, list_dir=sample_list, size = size)
    elif dataset_mode == "gta5":
        dataset = GTA5Dataset(root=folder_name, list_path=sample_list, crop_size=size, ignore_label=19)
    elif dataset_mode == "cityscapes":
        dataset = CityscapesDataset(root=folder_name, list_path=sample_list, crop_size=size, ignore_label=19)
    else:
        print("Unrecognized dataset!")
        sys.exit()
        
    return dataset


def run_eval(opt):
    real_dataset = create_dataset(opt.dataset_mode, opt.real_dir, opt.real_list, opt.crop_size)
    fake_dataset = create_dataset(opt.dataset_mode, opt.fake_dir, opt.fake_list, opt.crop_size)

    real_loader = torch.utils.data.DataLoader(real_dataset,
                                         batch_size=opt.batch_size,
                                         shuffle=opt.shuffle,
                                         num_workers=opt.num_threads,
                                         pin_memory=True,
                                         drop_last=False)

    fake_loader = torch.utils.data.DataLoader(fake_dataset,
                                         batch_size=opt.batch_size,
                                         shuffle=opt.shuffle,
                                         num_workers=opt.num_threads,
                                         pin_memory=True,
                                         drop_last=False)
    
    model.cuda()
    # model.eval()
    feature_fake = 0
    for idx, input in enumerate(tqdm(fake_loader)):
        features = model(input['image'].cuda())
        feat = torch.flatten(features[0], start_dim=1)
        if idx == 0:
            feature_fake = torch.zeros((len(fake_loader), feat.shape[1]))
            feature_fake[idx, :] = feat
        else:
            feature_fake[idx, :] = feat

    feature_real = 0
    if not os.path.exists("real.pt"):
        for idx, input in enumerate(tqdm(real_loader)):
            features = model(input['image'].cuda())
            feat = torch.flatten(features[0], start_dim=1)
            if idx == 0:
                feature_real = torch.zeros((len(real_loader), feat.shape[1]))
                feature_real[idx, :] = feat
                # feature_target = feat
            else:
                feature_real[idx, :] = feat
                # feature_target = torch.cat((feature_target, feat), 0)
            
        torch.save(feature_real, 'real.pt')
    else:
        feature_real = torch.load('real.pt')

    fid = calculate_fid(feature_fake, feature_real)
    print("Epoch {}:".format(opt.load_epoch), "score", fid)

    csv_path = os.path.join(os.getcwd(), "results", "self_supervised_results_{}.csv".format(opt.dataset_mode))
   
    if os.path.isfile(csv_path):
        x = []
        value = []
        with open(csv_path, 'r') as csvfile:
            lines = csv.reader(csvfile, delimiter=',')
            for idx, row in enumerate(lines):
                if idx != 0:
                    x.append(row[0])
                    value.append(row[1])
        x.append(opt.load_epoch)
        value.append(fid)
        x_np = np.asarray(x).astype(int)
        value_np = np.asarray(value).astype(float)

    to_write = []
    to_write.append(["epoch", "jigsaw"])

    if os.path.isfile(csv_path):
        for epoch in range(len(x_np)):
            result = [x_np[epoch], value_np[epoch]]
            to_write.append(result)
    else:
        result = [opt.load_epoch, fid]
        to_write.append(result)

    with open(csv_path, "w") as csvfile:
        writer = csv.writer(csvfile)
        writer.writerows(to_write)
        x = []
        value = []
        with open(csv_path, 'r') as csvfile:
            lines = csv.reader(csvfile, delimiter=',')
            for idx, row in enumerate(lines):
                if idx != 0:
                    x.append(row[0])
                    value.append(row[1])
        x.append(opt.load_epoch)
        value.append(fid)
        x_np = np.asarray(x).astype(int)
        value_np = np.asarray(value).astype(float)

    to_write = []
    to_write.append(["epoch", "jigsaw"])

    if os.path.isfile(csv_path):
        for epoch in range(len(x_np)):
            result = [x_np[epoch], value_np[epoch]]
            to_write.append(result)
    else:
        result = [opt.load_epoch, fid]
        to_write.append(result)

    with open(csv_path, "w") as csvfile:
        writer = csv.writer(csvfile)
        writer.writerows(to_write)


Specify the correct config file:
```
jigsaw_custom_retouch
jigsaw_custom_cityscapes
jigsaw_custom_mnist
jigsaw_custom_synthia
rotnet_custom_retouch
rotnet_custom_cityscapes
rotnet_custom_mnist
rotnet_custom_synthia
```


In [18]:
if __name__ == '__main__':
    config = OmegaConf.load("configs/config/jigsaw_custom_retouch.yaml")

    default_config = OmegaConf.load("configs/config/defaults.yaml")
    cfg = OmegaConf.merge(default_config, config)

    cfg = AttrDict(cfg)
    cfg.config.MODEL._MODEL_INIT_SEED = 0
    cfg.config.MODEL.WEIGHTS_INIT.PARAMS_FILE = "./checkpoints/model_phase100.torch"
    cfg.config.MODEL.FEATURE_EVAL_SETTINGS.EVAL_MODE_ON = True
    cfg.config.MODEL.FEATURE_EVAL_SETTINGS.FREEZE_TRUNK_ONLY = True
    cfg.config.MODEL.FEATURE_EVAL_SETTINGS.EXTRACT_TRUNK_FEATURES_ONLY = True
    cfg.config.MODEL.FEATURE_EVAL_SETTINGS.SHOULD_FLATTEN_FEATS = False
    cfg.config.MODEL.FEATURE_EVAL_SETTINGS.LINEAR_EVAL_FEAT_POOL_OPS_MAP = [["res5avg", ["Identity", []]]]

    model = build_model(cfg.config.MODEL, cfg.config.OPTIMIZER)
    weights = load_checkpoint(checkpoint_path=cfg.config.MODEL.WEIGHTS_INIT.PARAMS_FILE)
    model.cuda()

    init_model_from_consolidated_weights(
        config=cfg.config,
        model=model,
        state_dict=weights,
        state_dict_key_name="classy_state_dict",
        skip_layers=[],  # Use this if you do not want to load all layers
    )
    
    opt = argparse.ArgumentParser()
    opt.dataset_mode = "OCT"
    opt.real_dir = os.path.join(os.getcwd(), "data/real_images/retouch-dataset")
    opt.fake_dir = os.path.join(os.getcwd(), "data/generated_images/OCT_new")
    opt.load_epoch = 0
    
    opt.real_list = os.path.join(os.getcwd(), "splits/spectralis_samples.txt")
    opt.fake_list = os.path.join(os.getcwd(), "splits/cirrus_samples.txt")
    
    opt.crop_size= (512, 512)
    
    opt.num_threads = 0  
    opt.batch_size = 1 
    opt.shuffle = True  
    opt.no_flip = True  
    opt.display_id = -1  

    head = "results/"

    if not os.path.exists(head):
        os.makedirs(head)

    transferred_images_dir = opt.fake_dir
    epochs = [int(f) for f in os.listdir(transferred_images_dir) if os.path.isdir(os.path.join(transferred_images_dir, f))]
    epochs.sort()

    if os.path.exists("real.pt"):
        os.remove("real.pt")

    for epoch in epochs:
        print("run eval epoch {}".format(epoch))
        opt.fake_dir = os.path.join(opt.fake_dir, "{}".format(epoch))
        opt.load_epoch = int(epoch)
        run_eval(opt)


run eval epoch 0


100%|███████████████████████████████████████| 3072/3072 [01:06<00:00, 45.91it/s]
100%|███████████████████████████████████████| 1176/1176 [00:24<00:00, 47.08it/s]


Epoch 0: score 2.420683817463056
run eval epoch 3


  0%|                                                  | 0/3072 [00:00<?, ?it/s]


FileNotFoundError: [Errno 2] No such file or directory: '/home/zeju/Documents/Zeiss_Self_Supervised/data/generated_images/OCT_new/0/3/Cirrus_part3/d40ef678ef3dff06eb3aa3a4f4ae99e6/images/006.png'