# Global Config Variables

In [31]:
# Notebook-wide configuration constants. The paths below are machine-specific and must be
# edited to point at your own dataset, config and checkpoint locations.
SHOT = 1                                                                                          # number of shots
AUG_TYPE = 0                                                                                      # augmentation variant; the current best setting uses get_aug_data0()
DATA_ROOT = '/Users/nigel/Documents/Research-Git/Dataset/VOCdevkit/VOC2012'                       # absolute or relative path to the dataset directory
COLOR_MODE = 'R'                                                                                  # one of 'R', 'G', 'B' (red, green, blue); must be a capital letter
CONFIG_PATH = '/Users/nigel/Documents/Research-Git/Vis/Config/pascal_aug.yaml'                    # absolute or relative path to the base config .yaml file
SEARCH_MODE = 1                                                                                   # 0: base vs. IDA (no LCCA); 1: base vs. LCCA (no IDA); 2: base vs. IDA+LCCA
META_AUG = 2                                                                                      # values <= 1 mean no additional augmentation is applied during data loading
RESUME_PATH = '/Users/nigel/Documents/Research-Git/Vis/pretrained/pascal/split0/pspnet_resnet50/'  # directory holding the PSPNet weights to resume from; the weight file must be named 'best.pth'
CLASS_LIST_PATH = '/Users/nigel/Documents/Research-Git/Vis/lists/json'                            # directory holding the JSON list files (name the class files 'data-list.json' and 'sub-class.json')
ATT_TYPE = 3                                                                                      # images used for attention (LCCA) -- 0: original image only; 1: augmented image only; 2: no tensor_slice; 3: adaptive
META_MODEL_PATH = '/Users/nigel/Documents/Research-Git/Vis/pretrained/mmn-all/'                   # directory holding the MMN weights to resume from; the exact sub-folder name (e.g. 'f11e_pm10') is not required because it is inferred from the config

# Data

## Tool Packages

### Transform.py

In [32]:
# encoding:utf-8

import random
import math
import numpy as np
import numbers
import collections
import cv2
import torch
import PIL
import PIL.ImageOps
import PIL.ImageEnhance
import PIL.ImageDraw
from PIL import Image
from torchvision import transforms
import torchvision.transforms.functional as F


# ==================================================================================================
# Transforms have been borrowed from https://github.com/hszhao/semseg/blob/master/util/py
# ==================================================================================================
# Upper end of the augmentation parameter scale: _float_parameter() divides by this value.
PARAMETER_MAX = 10


class Compose(object):
    """Chain several segmentation transforms into one callable.

    Args:
        segtransform: iterable of callables, each invoked as ``t(image, label)``.
    """

    def __init__(self, segtransform):
        self.segtransform = segtransform

    def __call__(self, image, label=None):
        """Apply every transform in order.

        Without a label each transform is called as ``t(image, None)`` and its
        whole return value becomes the next image; only the image is returned.
        With a label each transform must return an ``(image, label)`` pair and
        the final pair is returned.
        """
        if label is None:
            # Image-only pipeline.
            for step in self.segtransform:
                image = step(image, None)
            return image

        # Paired pipeline: image and label travel together.
        for step in self.segtransform:
            image, label = step(image, label)
        return image, label


class ToTensorPIL(object):
    """Convert the image with ``torchvision.transforms.ToTensor`` and the label to a LongTensor.

    The image is handed to torchvision unchanged; the label, when given, must
    be a 2-D ``np.ndarray``.
    """

    def __call__(self, image, label):
        """Return ``image`` alone when ``label`` is None, otherwise ``(image, label)``.

        Raises:
            RuntimeError: if the label is not a 2-D ``np.ndarray``.
        """
        image = transforms.ToTensor()(image)
        if label is None:
            return image

        if not isinstance(label, np.ndarray):
            raise RuntimeError("segtransform.ToTensor() only handle np.ndarray"
                               "[eg: data readed by cv2.imread()].\n")
        if len(label.shape) != 2:
            raise RuntimeError(
                "segtransform.ToTensor() only handle np.ndarray labellabel with 2 dims.\n")

        label_tensor = torch.from_numpy(label)
        # Class indices are expected as 64-bit integers downstream.
        if not isinstance(label_tensor, torch.LongTensor):
            label_tensor = label_tensor.long()
        return image, label_tensor


class ToTensor(object):
    """Convert an ``np.ndarray`` image (H x W x C or H x W) to a (C x H x W) tensor.

    Images that do not already arrive as a ``torch.FloatTensor`` after
    ``torch.from_numpy`` are cast to float and divided by 255. A label, when
    given, must be a 2-D ``np.ndarray`` and is returned as a LongTensor.
    """

    def __call__(self, image, label):
        """Return ``image`` alone when ``label`` is None, otherwise ``(image, label)``.

        Raises:
            RuntimeError: if the image is not a 2-D/3-D ``np.ndarray`` or the
                label is not a 2-D ``np.ndarray``.
        """
        if not isinstance(image, np.ndarray):
            raise RuntimeError("segtransform.ToTensor() only handle np.ndarray"
                               "[eg: data readed by cv2.imread()].\n")
        rank = len(image.shape)
        if rank < 2 or rank > 3:
            raise RuntimeError(
                "segtransform.ToTensor() only handle np.ndarray with 3 dims or 2 dims.\n")
        if rank == 2:
            # Give grayscale images an explicit channel axis.
            image = np.expand_dims(image, axis=2)

        # HWC -> CHW
        image_tensor = torch.from_numpy(image.transpose((2, 0, 1)))
        if not isinstance(image_tensor, torch.FloatTensor):
            image_tensor = image_tensor.float().div(255)

        if label is None:
            return image_tensor

        if not isinstance(label, np.ndarray):
            raise RuntimeError("segtransform.ToTensor() only handle np.ndarray"
                               "[eg: data readed by cv2.imread()].\n")
        if len(label.shape) != 2:
            raise RuntimeError(
                "segtransform.ToTensor() only handle np.ndarray labellabel with 2 dims.\n")
        label_tensor = torch.from_numpy(label)
        if not isinstance(label_tensor, torch.LongTensor):
            label_tensor = label_tensor.long()
        return image_tensor, label_tensor


class Normalize(object):
    """Normalize a tensor channel by channel: ``channel = (channel - mean) / std``.

    When ``std`` is None only the mean is subtracted. The image is modified
    in place (``sub_`` / ``div_``).
    """

    def __init__(self, mean, std=None):
        if std is not None:
            assert len(mean) == len(std)
        else:
            assert len(mean) > 0
        self.mean = mean
        self.std = std

    def __call__(self, image, label):
        """Return ``image`` alone when ``label`` is None, otherwise ``(image, label)``."""
        if self.std is not None:
            for channel, mean_c, std_c in zip(image, self.mean, self.std):
                channel.sub_(mean_c).div_(std_c)
        else:
            for channel, mean_c in zip(image, self.mean):
                channel.sub_(mean_c)

        if label is None:
            return image
        return image, label


class Resize(object):
    """Aspect-preserving resize followed by padding to a square canvas.

    Args:
        size: side length of the square output. It is used as a single number
            (``np.zeros((size, size, 3))``), not as an ``(h, w)`` pair.
        padding: optional per-channel fill values (indices 0..2) for the image
            canvas; the canvas is zero-filled when this is falsy.
    """

    def __init__(self, size, padding=None):
        self.size = size
        self.padding = padding

    def __call__(self, image, label):
        """Resize and pad ``image`` (and ``label`` when given).

        Returns:
            ``(image, label)`` when a label is given, otherwise
            ``(image, new_h, new_w)`` where ``new_h``/``new_w`` is the size of
            the resized content inside the padded canvas.
        """

        def find_new_hw(ori_h, ori_w, test_size):
            # The longer side becomes ``test_size`` (the target size); the
            # other side is scaled by the same ratio.
            if ori_h >= ori_w:
                ratio = test_size * 1.0 / ori_h
                new_h, new_w = test_size, int(ori_w * ratio)
            else:
                ratio = test_size * 1.0 / ori_w
                new_h, new_w = int(ori_h * ratio), test_size

            # Both sides are rounded down to a multiple of 8. (The original
            # author left an open question here: why a multiple of 8?)
            def floor_to_8(side):
                return side if side % 8 == 0 else (int(side / 8)) * 8

            return floor_to_8(new_h), floor_to_8(new_w)

        target = self.size

        # Step 1: resize while keeping the h/w ratio.
        new_h, new_w = find_new_hw(image.shape[0], image.shape[1], target)
        resized = cv2.resize(image, dsize=(int(new_w), int(new_h)),
                             interpolation=cv2.INTER_LINEAR)

        # Step 2: paste into the top-left corner of a (target, target, 3) canvas.
        canvas = np.zeros((target, target, 3))
        if self.padding:
            for channel in range(3):
                canvas[:, :, channel] = self.padding[channel]
        canvas[:new_h, :new_w, :] = resized
        image = canvas

        if label is None:
            return image, new_h, new_w

        # Step 3: same treatment for the label; its padding value is 255.
        mask_h, mask_w = find_new_hw(label.shape[0], label.shape[1], target)
        mask = cv2.resize(label.astype(np.float32), dsize=(int(mask_w), int(mask_h)),
                          interpolation=cv2.INTER_NEAREST)
        mask_canvas = np.ones((target, target)) * 255
        mask_canvas[:mask_h, :mask_w] = mask
        return image, mask_canvas


class Resize_np(object):
    """Resize an ndarray image (bilinear) and its label (nearest neighbour).

    Args:
        size: an int (expanded to ``(size, size)``) or a 2-element sequence.
            The value is passed unchanged as ``dsize`` to ``cv2.resize``, which
            OpenCV interprets as ``(width, height)``.
    """

    def __init__(self, size):
        if isinstance(size, int):
            self.size = (size, size)
        else:
            self.size = size

    def __call__(self, image, label):
        """Return ``(image, label)``, or ``image`` alone when ``label`` is None.

        The resized image is cast to the platform integer dtype.
        """
        image = cv2.resize(image, dsize=self.size,
                           interpolation=cv2.INTER_LINEAR)
        # ``np.int`` was only an alias of the builtin ``int`` and was removed
        # in NumPy 1.24; the builtin yields the same dtype.
        image = image.astype(int)

        # Image-only call (e.g. Compose without a label): mirror the other
        # transforms instead of failing on ``None.astype``.
        if label is None:
            return image

        # Nearest-neighbour keeps label values discrete.
        label = cv2.resize(label.astype(np.float32),
                           dsize=self.size, interpolation=cv2.INTER_NEAREST)

        return image, label


class RandScale(object):
    """Randomly rescale image and label by a factor drawn from ``[scale[0], scale[1])``.

    Args:
        scale: 2-element sequence ``(scale_min, scale_max)`` with
            ``0 < scale_min < scale_max``, e.g. ``(0.5, 1.5)``.
        aspect_ratio: optional 2-element sequence ``(min, max)`` with
            ``0 < min < max``; a ratio is drawn from it and its square root is
            applied to x and, inverted, to y.
        fixed_size: optional positive side length; when set the scaled image
            and label are pasted into the top-left corner of a square canvas.
        padding: optional per-channel fill values (indices 0..2) for the image
            canvas; the label canvas is filled with 255.

    Raises:
        AssertionError: if ``scale`` is not a 2-element iterable.
        RuntimeError: if ``scale`` or ``aspect_ratio`` is malformed.
    """

    def __init__(self, scale, aspect_ratio=None, fixed_size=None, padding=None):
        # ``collections.Iterable`` was removed in Python 3.10; the ABC lives in
        # ``collections.abc``. Imported locally so the block is self-contained.
        from collections.abc import Iterable

        assert (isinstance(scale, Iterable) and len(scale) == 2)
        if isinstance(scale[0], numbers.Number) \
                and isinstance(scale[1], numbers.Number) \
                and 0 < scale[0] < scale[1]:
            self.scale = scale
        else:
            raise RuntimeError("segRandScale() scale param error.\n")

        if aspect_ratio is None:
            self.aspect_ratio = aspect_ratio
        elif isinstance(aspect_ratio, Iterable) \
                and len(aspect_ratio) == 2 \
                and isinstance(aspect_ratio[0], numbers.Number) \
                and isinstance(aspect_ratio[1], numbers.Number) \
                and 0 < aspect_ratio[0] < aspect_ratio[1]:
            self.aspect_ratio = aspect_ratio
        else:
            raise RuntimeError("segRandScale() aspect_ratio param error.\n")

        self.fixed_size, self.padding = fixed_size, padding

    def __call__(self, image, label):
        """Return the rescaled (and optionally padded) ``(image, label)`` pair."""
        # Draw a scale uniformly between scale[0] and scale[1].
        temp_scale = self.scale[0] + \
            (self.scale[1] - self.scale[0]) * random.random()
        temp_aspect_ratio = 1.0
        if self.aspect_ratio is not None:
            temp_aspect_ratio = self.aspect_ratio[0] + (
                self.aspect_ratio[1] - self.aspect_ratio[0]) * random.random()
            # Square root so that x is multiplied and y divided by the same
            # amount, leaving the area scale equal to temp_scale ** 2.
            temp_aspect_ratio = math.sqrt(temp_aspect_ratio)
        scale_factor_x = temp_scale * temp_aspect_ratio
        scale_factor_y = temp_scale / temp_aspect_ratio
        image = cv2.resize(image, None, fx=scale_factor_x, fy=scale_factor_y,
                           interpolation=cv2.INTER_LINEAR)
        label = cv2.resize(label, None, fx=scale_factor_x, fy=scale_factor_y,
                           interpolation=cv2.INTER_NEAREST)

        if self.fixed_size is not None and self.fixed_size > 0:
            new_h, new_w, _ = image.shape

            # Paste the scaled image into the top-left corner of the canvas.
            # NOTE(review): this assumes the scaled image is no larger than
            # fixed_size in either dimension -- confirm against the callers.
            back_crop = np.zeros((self.fixed_size, self.fixed_size, 3))
            if self.padding:
                back_crop[:, :, 0] = self.padding[0]
                back_crop[:, :, 1] = self.padding[1]
                back_crop[:, :, 2] = self.padding[2]
            back_crop[:new_h, :new_w, :] = image
            image = back_crop

            back_crop_mask = np.ones((self.fixed_size, self.fixed_size)) * 255
            back_crop_mask[:new_h, :new_w] = label
            label = back_crop_mask

        return image, label


class Crop(object):
    """Crops the given ndarray image (H*W*C or H*W).

    Inputs smaller than the crop are first padded symmetrically (the image with
    ``padding``, the label with ``ignore_label``), then cropped.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is made.
        crop_type (str): ``'center'`` for a centred crop or ``'rand'`` for a
            uniformly random offset.
        padding (list or None): three numbers used as the image border value
            when padding is required.
        ignore_label (int): border value used for the label when padding.

    Raises:
        RuntimeError: if any constructor argument is malformed, or at call time
            if padding is needed but ``padding`` is None.
    """

    def __init__(self, size, crop_type='center', padding=None, ignore_label=255):
        # ``collections.Iterable`` was removed in Python 3.10; the ABC lives in
        # ``collections.abc``. Imported locally so the block is self-contained.
        from collections.abc import Iterable

        if isinstance(size, int):
            self.crop_h = size
            self.crop_w = size
        elif isinstance(size, Iterable) and len(size) == 2 \
                and isinstance(size[0], int) and isinstance(size[1], int) \
                and size[0] > 0 and size[1] > 0:
            self.crop_h = size[0]
            self.crop_w = size[1]
        else:
            raise RuntimeError("crop size error.\n")
        if crop_type == 'center' or crop_type == 'rand':
            self.crop_type = crop_type
        else:
            raise RuntimeError("crop type error: rand | center\n")
        if padding is None:
            self.padding = padding
        elif isinstance(padding, list):
            if all(isinstance(i, numbers.Number) for i in padding):
                self.padding = padding
            else:
                raise RuntimeError("padding in Crop() should be a number list\n")
            if len(padding) != 3:
                raise RuntimeError("padding channel is not equal with 3\n")
        else:
            raise RuntimeError("padding in Crop() should be a number list\n")
        if isinstance(ignore_label, int):
            self.ignore_label = ignore_label
        else:
            raise RuntimeError("ignore_label should be an integer number\n")

    def __call__(self, image, label):
        """Return the cropped ``(image, label)``, or ``image`` alone when ``label`` is None.

        The cropped image is cast to the platform integer dtype.
        """
        h, w = image.shape[:2]
        pad_h = max(self.crop_h - h, 0)
        pad_w = max(self.crop_w - w, 0)
        pad_h_half = int(pad_h / 2)
        pad_w_half = int(pad_w / 2)
        if pad_h > 0 or pad_w > 0:
            if self.padding is None:
                raise RuntimeError(
                    "segtransform.Crop() need padding while padding argument is None\n")
            # Split the padding between both sides; the odd pixel goes to the
            # bottom/right.
            image = cv2.copyMakeBorder(image, pad_h_half, pad_h - pad_h_half, pad_w_half,
                                       pad_w - pad_w_half, cv2.BORDER_CONSTANT, value=self.padding)
            if label is not None:
                label = cv2.copyMakeBorder(label, pad_h_half, pad_h - pad_h_half, pad_w_half,
                                           pad_w - pad_w_half, cv2.BORDER_CONSTANT, value=self.ignore_label)
        h, w = image.shape[:2]
        if self.crop_type == 'rand':
            h_off = random.randint(0, h - self.crop_h)
            w_off = random.randint(0, w - self.crop_w)
        else:
            h_off = int((h - self.crop_h) / 2)
            w_off = int((w - self.crop_w) / 2)
        image = image[h_off:h_off+self.crop_h, w_off:w_off+self.crop_w]
        # ``np.int`` was only an alias of the builtin ``int`` and was removed
        # in NumPy 1.24; the builtin yields the same dtype.
        image = image.astype(int)
        if label is not None:
            label = label[h_off:h_off+self.crop_h, w_off:w_off+self.crop_w]
            return image, label
        else:
            return image


class FitCrop(object):
    """Crop image and label around the largest connected foreground component.

    The label (with 255 treated as background) is split into connected
    components; the crop box is derived from the bounding box of the largest
    one by trimming ``1/k`` of the margin between the bounding box and each
    image border (see ``get_coord``).

    Args:
        k: divisor applied to the margins, e.g. 2 trims half of each margin and
            3 trims a third.
        multi: when True and a second component covers at least 30% as many
            pixels as the largest, a second crop around it is returned too.
    """

    def __init__(self, k=2, multi=False):
        self.k = k  # margin divisor used by get_coord
        self.multi = multi  # whether a second crop may be returned

    def __call__(self, image, label):
        """Return ``(image, label)`` or ``(image, label, image2, label2)``.

        NOTE(review): ``cv2.connectedComponents`` is given a copy of ``label``
        directly, so the label is presumably an 8-bit single-channel array --
        confirm against the callers.
        """
        h, w, _ = image.shape

        label_binary = label.copy()
        label_binary[label_binary == 255] = 0
        # ``labels`` holds the connected-component index of every pixel.
        _, labels = cv2.connectedComponents(label_binary)

        freq = np.bincount(labels.flatten())
        freq[0] = 0  # component 0 is the background
        obj_idx = np.argmax(freq)      # index of the largest component
        pxl_cnt = freq[obj_idx]
        h0, h1, w0, w1 = self.get_coord(labels, obj_idx, h, w)
        crop_image = image[h0:h1, w0:w1]
        crop_label = label[h0:h1, w0:w1]

        if self.multi and len(freq) >= 3:
            freq[obj_idx] = 0
            obj_idx2 = np.argmax(freq)
            pxl_cnt2 = freq[obj_idx2]

            if pxl_cnt2 / pxl_cnt >= 0.3:
                h0, h1, w0, w1 = self.get_coord(labels, obj_idx2, h, w)
                # The coordinates are expressed in the full image frame, so the
                # second crop must be taken from the original image/label, not
                # from the first crop.
                image2 = image[h0:h1, w0:w1]
                label2 = label[h0:h1, w0:w1]

                return crop_image, crop_label, image2, label2

        return crop_image, crop_label

    def get_coord(self, labels, obj_idx, h, w):
        """Compute the crop box ``(h0, h1, w0, w1)`` for component ``obj_idx``.

        Each margin between the component's bounding box and the image border
        is reduced by ``margin // k``. If the resulting box is too flat
        (h/w <= 0.7) or too narrow (h/w >= 1.5), the trim is undone on the side
        where less was removed (ties go to the top/left side).
        """
        mask_pos = np.where(labels == obj_idx)
        min_h, max_h, min_w, max_w = np.min(mask_pos[0]), np.max(
            mask_pos[0]), np.min(mask_pos[1]), np.max(mask_pos[1])

        h0, h1 = min_h // self.k, h - (h - max_h) // self.k
        w0, w1 = min_w // self.k, w - (w - max_w) // self.k

        if (h1 - h0) / (w1 - w0) <= 0.7:  # height too small
            if h0 <= h - h1:
                h0 = 0
            else:
                h1 = h
        elif (h1 - h0) / (w1 - w0) >= 1.5:  # width too small
            if w0 <= w - w1:
                w0 = 0
            else:
                w1 = w
        return h0, h1, w0, w1


class RandRotate(object):
    """Randomly rotate image and label by an angle drawn from ``[rotate[0], rotate[1])``.

    Args:
        rotate: 2-element sequence ``(rotate_min, rotate_max)`` with
            ``rotate_min < rotate_max``; passed as the angle to
            ``cv2.getRotationMatrix2D``.
        padding: list of three numbers used as the image border value.
        ignore_label: int used as the label border value.
        p: probability of applying the rotation.

    Raises:
        AssertionError: if ``rotate``/``padding``/``ignore_label`` have the
            wrong container type or length.
        RuntimeError: if their elements are not numbers or are out of order.
    """

    def __init__(self, rotate, padding, ignore_label=255, p=0.5):
        # ``collections.Iterable`` was removed in Python 3.10; the ABC lives in
        # ``collections.abc``. Imported locally so the block is self-contained.
        from collections.abc import Iterable

        assert (isinstance(rotate, Iterable) and len(rotate) == 2)
        if isinstance(rotate[0], numbers.Number) and isinstance(rotate[1], numbers.Number) \
                and rotate[0] < rotate[1]:
            self.rotate = rotate
        else:
            raise RuntimeError("segtransform.RandRotate() scale param error.\n")
        assert padding is not None
        assert isinstance(padding, list) and len(padding) == 3
        if all(isinstance(i, numbers.Number) for i in padding):
            self.padding = padding
        else:
            raise RuntimeError("padding in RandRotate() should be a number list\n")
        assert isinstance(ignore_label, int)
        self.ignore_label = ignore_label
        self.p = p

    def __call__(self, image, label):
        """Return ``(image, label)``, rotated about the centre with probability ``p``."""
        if random.random() < self.p:
            angle = self.rotate[0] + \
                (self.rotate[1] - self.rotate[0]) * random.random()
            h, w = label.shape
            matrix = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1)
            # Bilinear for the image, nearest-neighbour for the label so that
            # label values stay discrete.
            image = cv2.warpAffine(image, matrix, (w, h), flags=cv2.INTER_LINEAR,
                                   borderMode=cv2.BORDER_CONSTANT, borderValue=self.padding)
            label = cv2.warpAffine(label, matrix, (w, h), flags=cv2.INTER_NEAREST,
                                   borderMode=cv2.BORDER_CONSTANT, borderValue=self.ignore_label)
        return image, label


class RandomHorizontalFlip(object):
    """Flip image and label around the vertical axis (``cv2.flip`` code 1) with probability ``p``."""

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, image, label):
        """Return ``(image, label)``, both flipped or both untouched."""
        do_flip = random.random() < self.p
        if not do_flip:
            return image, label
        flipped_image = cv2.flip(image, 1)
        flipped_label = cv2.flip(label, 1)
        return flipped_image, flipped_label


class RandomVerticalFlip(object):
    """Flip image and label around the horizontal axis (``cv2.flip`` code 0) with probability ``p``."""

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, image, label):
        """Return ``(image, label)``, both flipped or both untouched."""
        do_flip = random.random() < self.p
        if not do_flip:
            return image, label
        flipped_image = cv2.flip(image, 0)
        flipped_label = cv2.flip(label, 0)
        return flipped_image, flipped_label


class RandomGaussianBlur(object):
    """Blur the image with a square Gaussian kernel with probability ``p``.

    Args:
        radius: kernel side length passed to ``cv2.GaussianBlur`` as
            ``(radius, radius)``; OpenCV requires it to be positive and odd.
        p: probability of applying the blur. Defaults to 0.5, the value that
            used to be hard-coded, so existing callers behave as before.
    """

    def __init__(self, radius=5, p=0.5):
        self.radius = radius
        self.p = p

    def __call__(self, image, label):
        """Return ``(image, label)``; the label is never modified."""
        if random.random() < self.p:
            # sigma=0 lets OpenCV derive the standard deviation from the kernel size.
            image = cv2.GaussianBlur(image, (self.radius, self.radius), 0)
        return image, label


class ColorJitter(object):
    """Apply one kind of colour jitter to a BGR image.

    Args:
        cj_type: ``'b'`` brightness, ``'s'`` saturation or ``'c'`` contrast.
            Any other value leaves the image untouched.
    """

    def __init__(self, cj_type='b'):
        self.cj_type = cj_type

    def __call__(self, img, label):
        """Return ``(img, label)`` with the jitter applied to ``img`` only.

        * ``'b'``: shifts the HSV value channel by +35 when its mean is <= 125,
          otherwise by -35, saturating at 255 / 0.
        * ``'s'``: multiplies the HSV saturation channel by a factor chosen
          from {0.5, 0.75, 1.25, 1.5}.
        * ``'c'``: linear contrast stretch with a random strength in [40, 100]
          plus a brightness offset of 10; the result is a float array clipped
          to [0, 255].
        """
        if self.cj_type == "b":
            hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
            # Hue, Saturation, and Value (Brightness)
            h, s, v = cv2.split(hsv)
            value = 35 if np.mean(v) <= 125 else -35
            # Plain Python ints are used for the in-place arithmetic: a typed
            # NumPy scalar (as returned by np.absolute) can make
            # ``uint8 -= int64`` fail the same-kind casting check.
            shift = abs(value)
            if value >= 0:
                lim = 255 - shift
                v[v > lim] = 255
                v[v <= lim] += shift
            else:
                v[v < shift] = 0
                v[v >= shift] -= shift

            final_hsv = cv2.merge((h, s, v))
            img = cv2.cvtColor(final_hsv, cv2.COLOR_HSV2BGR)

        elif self.cj_type == "s":
            value = np.random.choice(np.array([0.5, 0.75, 1.25, 1.5]))
            hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
            h, s, v = cv2.split(hsv)
            # ``s *= value`` raised a casting error for uint8 images because a
            # float result cannot be written back in place. Scale in floating
            # point, then restore the original dtype (saturating for uint8) so
            # the three channels can be merged again.
            scaled = s.astype(np.float32) * value
            if s.dtype == np.uint8:
                s = np.clip(scaled, 0, 255).astype(np.uint8)
            else:
                s = scaled.astype(s.dtype)

            final_hsv = cv2.merge((h, s, v))
            img = cv2.cvtColor(final_hsv, cv2.COLOR_HSV2BGR)

        elif self.cj_type == "c":
            brightness = 10
            contrast = random.randint(40, 100)
            dummy = np.int16(img)
            dummy = dummy * (contrast / 127 + 1) - contrast + brightness
            img = np.clip(dummy, 0, 255)

        return img, label


class ColorAug(object):
    """Wrap ``torchvision.transforms.ColorJitter`` for ndarray images.

    Args (ranges as noted by the original author):
        brightness: factor range [max(0, 1 - brightness), 1 + brightness]
        contrast: factor range [max(0, 1 - contrast), 1 + contrast]
        saturation: factor range [max(0, 1 - saturation), 1 + saturation]
        hue: shift range [-hue, hue]
    """

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation
        self.hue = hue
        self.gitter = transforms.ColorJitter(
            brightness, contrast, saturation, hue)

    def __call__(self, image, label):
        """Return ``(image, label)``; the image comes back as an ndarray built from an RGB PIL image."""
        pil_image = Image.fromarray(np.uint8(image)).convert('RGB')
        jittered = self.gitter(pil_image)
        return np.array(jittered), label


class Contrast(object):
    """Adjust contrast with ``PIL.ImageEnhance.Contrast``.

    The enhancement factor is ``_float_parameter(v, max_v) + bias`` and is
    fixed at construction time.
    """

    def __init__(self, v=0.9, max_v=0.05, bias=0):
        scaled = _float_parameter(v, max_v)
        self.v = scaled + bias

    def __call__(self, image, label):
        """Return ``(enhanced PIL image, label)``."""
        pil_image = Image.fromarray(np.uint8(image)).convert('RGB')
        enhanced = PIL.ImageEnhance.Contrast(pil_image).enhance(self.v)
        return enhanced, label


class Brightness(object):
    """Adjust brightness with ``PIL.ImageEnhance.Brightness``.

    The enhancement factor is ``_float_parameter(v, max_v) + bias`` and is
    fixed at construction time.
    """

    def __init__(self, v=1.8, max_v=0.1, bias=0):
        scaled = _float_parameter(v, max_v)
        self.v = scaled + bias

    def __call__(self, image, label):
        """Return ``(enhanced PIL image, label)``."""
        pil_image = Image.fromarray(np.uint8(image)).convert('RGB')
        enhanced = PIL.ImageEnhance.Brightness(pil_image).enhance(self.v)
        return enhanced, label


class Sharpness(object):
    """Adjust sharpness with ``PIL.ImageEnhance.Sharpness``.

    The enhancement factor is ``_float_parameter(v, max_v) + bias`` and is
    fixed at construction time.
    """

    def __init__(self, v=0.9, max_v=0.05, bias=0):
        scaled = _float_parameter(v, max_v)
        self.v = scaled + bias

    def __call__(self, image, label):
        """Return ``(enhanced PIL image, label)``."""
        pil_image = Image.fromarray(np.uint8(image)).convert('RGB')
        enhanced = PIL.ImageEnhance.Sharpness(pil_image).enhance(self.v)
        return enhanced, label


class AutoContrast(object):
    """Apply ``PIL.ImageOps.autocontrast`` to the image; the label passes through."""

    def __call__(self, image, label):
        """Return ``(auto-contrasted PIL image, label)``."""
        pil_image = Image.fromarray(np.uint8(image)).convert('RGB')
        adjusted = PIL.ImageOps.autocontrast(pil_image)
        return adjusted, label


def _float_parameter(v, max_v):
    """Map ``v`` from the ``[0, PARAMETER_MAX]`` scale to ``[0, max_v]``."""
    scaled = float(v) * max_v
    return scaled / PARAMETER_MAX


class RGB2BGR(object):
    """Reorder image channels from RGB to BGR (for models initialised from Caffe)."""

    def __call__(self, image, label):
        """Return ``(converted image, label)``; the label is untouched."""
        converted = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        return converted, label


class BGR2RGB(object):
    """Reorder image channels from BGR to RGB (for models initialised from PyTorch)."""

    def __call__(self, image, label):
        """Return ``(converted image, label)``; the label is untouched."""
        converted = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        return converted, label


### Utils.py

In [33]:
# encoding:utf-8

import os
import cv2
import numpy as np
from tqdm import tqdm
from functools import partial
from multiprocessing import Pool
from collections import defaultdict
from typing import Callable, Dict, Iterable, List, Tuple, TypeVar


# Lower-case file extensions that is_image_file() accepts as images.
IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm']

# Generic type variables used in the signature of mmap_().
A = TypeVar("A")
B = TypeVar("B")


def mmap_(fn: Callable[[A], B], iter: Iterable[A]) -> List[B]:
    """Map ``fn`` over ``iter`` in a process pool and return the results as a list.

    The pool uses the default number of workers. It is used as a context
    manager so the worker processes are released when the call finishes;
    previously the pool was never closed. ``Pool.map`` blocks until every
    result is available, so no work is lost when the pool is torn down.

    ``fn`` and the items must be picklable, as required by ``multiprocessing``.
    """
    with Pool() as pool:
        return pool.map(fn, iter)


def is_image_file(filename: str) -> bool:
    """Return True when ``filename`` ends, case-insensitively, with an entry of IMG_EXTENSIONS."""
    lowered = filename.lower()
    for extension in IMG_EXTENSIONS:
        if lowered.endswith(extension):
            return True
    return False


def make_dataset(
        data_root: str,
        data_list: str,
        class_list: List[int]
) -> Tuple[List[Tuple[str, str]], Dict[int, List[Tuple[str, str]]]]:
    '''
        Recovers all tuples (img_path, label_path) relevant to the current experiments (class_list
        is used as filter). Every line of the list file is handled by process_image in a
        process pool (mmap_).

        input:
            data_root : Path to the data directory
            data_list : Path to the .txt file that contains the train/test split of images
            class_list: List of classes to keep
        returns:
            image_label_list: List of (img_path, label_path) that contain at least 1 object of a class
                              in class_list
            class_file_dict: Dict mapping class id -> list of (img_path, label_path) that contain at
                              least 1 object of that class.
        raises:
            RuntimeError: if data_list is not an existing file
    '''
    if not os.path.isfile(data_list):
        raise RuntimeError("Image list file do not exist: " + data_list + "\n")

    image_label_list: List[Tuple[str, str]] = []
    # Read the whole list up front and close the file before the pool starts.
    with open(data_list) as list_file:
        list_read = list_file.readlines()

    print(f"Processing data for {class_list}")
    class_file_dict: Dict[int, List[Tuple[str, str]]] = defaultdict(list)

    process_partial = partial(
        process_image, data_root=data_root, class_list=class_list)

    # Merge the per-image results into the global list and the per-class dict.
    for sublist, subdict in mmap_(process_partial, tqdm(list_read)):
        image_label_list += sublist

        for (k, v) in subdict.items():
            class_file_dict[k] += v

    return image_label_list, class_file_dict


def process_image(
        line: str,
        data_root: str,
        class_list: List
) -> Tuple[List, Dict]:
    '''
        Reads and parses a line corresponding to 1 file

        input:
            line : A line corresponding to 1 file, in the format path_to_image.jpg path_to_image.png
            data_root : Path to the data directory
            class_list: List of classes to keep
        returns:
            image_label_list: [(img_path, label_path)] if the image contains at least one kept
                              class, otherwise an empty list
            class_file_dict: Dict mapping each kept class id -> [(img_path, label_path)]
            Both results describe this single image only.
        raises:
            RuntimeError: if the label file cannot be read
            AssertionError: if the label contains a class id outside 1..80 (besides 0 and 255)
    '''
    line = line.strip()
    line_split = line.split(' ')
    image_name = os.path.join(data_root, line_split[0])   # image file
    label_name = os.path.join(data_root, line_split[1])   # label file
    item: Tuple[str, str] = (image_name, label_name)
    label = cv2.imread(label_name, cv2.IMREAD_GRAYSCALE)
    if label is None:
        # cv2.imread signals failure by returning None instead of raising; fail
        # here with the offending path rather than in the class-id assertion.
        raise RuntimeError("Label file could not be read: " + label_name + "\n")

    # All categories present in this image, without background (0) and ignore (255).
    present_classes = [c for c in np.unique(label).tolist() if c not in (0, 255)]
    valid_ids = range(1, 81)
    for label_class_ in present_classes:
        assert label_class_ in valid_ids, label_class_

    # Keep only the requested classes that cover at least 2 * 32 * 32 pixels.
    min_pixels = 2 * 32 * 32
    c: int
    label_class = [
        c for c in present_classes
        if c in class_list and int((label == c).sum()) >= min_pixels
    ]

    image_label_list: List[Tuple[str, str]] = []
    class_file_dict: Dict[int, List[Tuple[str, str]]] = defaultdict(list)

    if len(label_class) > 0:
        image_label_list.append(item)  # item holds the image and label file names

        for c in label_class:
            class_file_dict[c].append(item)

    return image_label_list, class_file_dict


### Classes.py

In [34]:
# encoding:utf-8

import argparse
from collections import defaultdict
from typing import Dict, List, Any

# Lookup table: dataset name ('coco' or 'pascal') -> {class id -> class name}.
# Class ids start at 1 (80 COCO classes, 20 Pascal classes).
# NOTE(review): the numbers in the trailing comments of the 'pascal' entries are not
# explained anywhere in this file -- presumably per-class statistics; confirm.
classId2className = {'coco': {
    1: 'person',
    2: 'bicycle',
    3: 'car',
    4: 'motorcycle',
    5: 'airplane',
    6: 'bus',
    7: 'train',
    8: 'truck',
    9: 'boat',
    10: 'traffic light',
    11: 'fire hydrant',
    12: 'stop sign',
    13: 'parking meter',
    14: 'bench',
    15: 'bird',
    16: 'cat',
    17: 'dog',
    18: 'horse',
    19: 'sheep',
    20: 'cow',
    21: 'elephant',
    22: 'bear',
    23: 'zebra',
    24: 'giraffe',
    25: 'backpack',
    26: 'umbrella',
    27: 'handbag',
    28: 'tie',
    29: 'suitcase',
    30: 'frisbee',
    31: 'skis',
    32: 'snowboard',
    33: 'sports ball',
    34: 'kite',
    35: 'baseball bat',
    36: 'baseball glove',
    37: 'skateboard',
    38: 'surfboard',
    39: 'tennis racket',
    40: 'bottle',
    41: 'wine glass',
    42: 'cup',
    43: 'fork',
    44: 'knife',
    45: 'spoon',
    46: 'bowl',
    47: 'banana',
    48: 'apple',
    49: 'sandwich',
    50: 'orange',
    51: 'broccoli',
    52: 'carrot',
    53: 'hot dog',
    54: 'pizza',
    55: 'donut',
    56: 'cake',
    57: 'chair',
    58: 'sofa',
    59: 'pottedplant',
    60: 'bed',
    61: 'diningtable',
    62: 'toilet',
    63: 'tv',
    64: 'laptop',
    65: 'mouse',
    66: 'remote',
    67: 'keyboard',
    68: 'cell phone',
    69: 'microwave',
    70: 'oven',
    71: 'toaster',
    72: 'sink',
    73: 'refrigerator',
    74: 'book',
    75: 'clock',
    76: 'vase',
    77: 'scissors',
    78: 'teddy bear',
    79: 'hair drier',
    80: 'toothbrush'},

    'pascal': {
    1: 'airplane',   # 0.14
    2: 'bicycle',    # 0.07
    3: 'bird',       # 0.13
    4: 'boat',       # 0.12
    5: 'bottle',     # 0.15
    6: 'bus',        # 0.35
    7: 'cat',        # 0.20
    8: 'car',        # 0.26
    9: 'chair',      # 0.10
    10: 'cow',       # 0.24
    11: 'diningtable',  # 0.22
    12: 'dog',         # 0.23
    13: 'horse',       # 0.21
    14: 'motorcycle',  # 0.22
    15: 'person',      # 0.20
    16: 'pottedplant',  # 0.11
    17: 'sheep',       # 0.19
    18: 'sofa',        # 0.23
    19: 'train',       # 0.27
    20: 'tv'           # 0.14
}
}

# Reverse lookup: dataset name -> {class name -> class id}, built by inverting
# classId2className. The inner loop variables avoid the name ``id`` so that the
# builtin ``id()`` is not shadowed at module scope for the rest of the notebook.
className2classId = defaultdict(dict)
for dataset in classId2className:
    for class_id, class_name in classId2className[dataset].items():
        className2classId[dataset][class_name] = class_id


def get_split_classes(args: argparse.Namespace) -> Dict[str, Any]:
    """
    Returns the split of classes for Pascal-5i and Coco-20i
    inputs:
        args : namespace; only ``args.use_split_coco`` is read. When true the COCO
               validation folds are interleaved (every 4th class id), otherwise
               they are four consecutive blocks of 20 ids.

    returns :
         split_classes : Dict.
                         split_classes['coco'][0]['train'] = training classes in fold 0 of Coco-20i
                         Split -1 only has a 'val' entry and holds every class of the dataset.
    """
    split_classes = {'coco': defaultdict(dict), 'pascal': defaultdict(dict)}

    def register_folds(name, all_classes, folds):
        # 'val' is the fold itself, 'train' is every other class of the dataset.
        for fold_idx, val_classes in enumerate(folds):
            fold = split_classes[name][fold_idx]
            fold['val'] = val_classes
            fold['train'] = list(set(all_classes) - set(val_classes))

    # =============== COCO ===================
    coco_classes = list(range(1, 81))
    split_classes['coco'][-1]['val'] = coco_classes
    if args.use_split_coco:
        # Interleaved folds: ids 1, 5, 9, ... / 2, 6, 10, ... and so on.
        coco_folds = [list(range(start, 77 + start, 4)) for start in range(1, 5)]
    else:
        # 80 classes in 4 consecutive folds of 20.
        coco_folds = [list(range(first, first + 20)) for first in range(1, 81, 20)]
    register_folds('coco', coco_classes, coco_folds)

    # =============== Pascal ===================
    pascal_classes = list(range(1, 21))
    pascal_folds = [list(range(first, first + 5)) for first in range(1, 21, 5)]
    split_classes['pascal'][-1]['val'] = pascal_classes
    register_folds('pascal', pascal_classes, pascal_folds)

    return split_classes


def filter_classes(train_name: str,
                   train_split: int,
                   test_name: str,
                   test_split: int,
                   split_classes: Dict) -> List[int]:
    """ Drop from the test classes every class whose name was already seen
        while training on `train_name` (used for domain-shift experiments).

    inputs:
        train_name : 'coco' or 'pascal'
        test_name : 'coco' or 'pascal'
        train_split : In {0, 1, 2, 3}
        test_split : In {0, 1, 2, 3, -1}. -1 represents "all classes" (the one used in our experiments)
        split_classes: Dict of all classes used for each dataset and each split

    returns :
        kept_classes_id : class ids of the test split whose names do not occur
                          among the training class names, in their original order
    """
    print(f'INFO: {train_name} -> {test_name}')
    print(f'INFO: {train_split} -> {test_split}')
    print(">> Start Filtering classes ")

    # Names of all meta-train classes.
    train_ids = split_classes[train_name][train_split]['train']
    seen_names = [classId2className[train_name][cid] for cid in train_ids]

    # Class ids of the meta-test data, each paired with its class name.
    candidate_ids = split_classes[test_name][test_split]['val']
    named_candidates = [(cid, classId2className[test_name][cid])
                        for cid in candidate_ids]

    kept_classes_id = [cid for cid, cname in named_candidates
                       if cname not in seen_names]
    kept_names = [cname for _, cname in named_candidates
                  if cname not in seen_names]
    dropped_names = [cname for _, cname in named_candidates
                     if cname in seen_names]

    print(">> Removed classes = {} ".format(dropped_names))
    print(">> Kept classes = {} ".format(kept_names))
    return kept_classes_id


## Dataset

In [35]:
import cv2
import torch
import random
import argparse
import numpy as np
from typing import List
from torch.utils.data import Dataset
from torchvision import transforms as T
from torch.utils.data.distributed import DistributedSampler


### Standard Data

In [36]:
class StandardData(Dataset):
    """
    Non-episodic segmentation dataset.

    Each item is an image together with a label map in which every class of
    `class_list` is re-indexed to (position in class_list) + 1, 0 is background
    and any other annotated class is set to 255.
    """

    def __init__(self, args: argparse.Namespace,
                 transform: "Compose",
                 data_list_path: str,
                 class_list: List[int],
                 return_paths: bool):
        # path to dataset directory
        self.data_root = args.data_root
        # class ids kept by this dataset; their order defines the new label indices
        self.class_list = class_list
        # make_dataset returns a tuple with: 1, a List of all (image_path, label_path) Tuples;
        # 2, a Dict mapping (image_path, label_path) Tuples for different classes (unused here)
        self.data_list, _ = make_dataset(
            args.data_root, data_list_path, class_list)
        # Composed transformation applied to (image, label); may be None
        self.transform = transform
        # whether to return image&label path in __getitem__
        self.return_paths = return_paths

    # NOTE: __len__ and __getitem__ must be defined at class level. They were
    # previously nested inside __init__, which made them unused local
    # functions, so len(dataset) and dataset[i] did not work.
    def __len__(self):
        """Number of (image, label) pairs in the dataset."""
        return len(self.data_list)

    def __getitem__(self, index):
        """
        Load one sample.

        returns:
            (image, new_label), or (image, new_label, image_path, label_path)
            when `return_paths` is set.
        raises:
            RuntimeError if image and label differ in height/width.
            AssertionError if the label contains no class of `class_list`.
        """
        # read image and labels
        image_path, label_path = self.data_list[index]
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = np.float32(image)
        label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)

        # double check to make sure the shape matches
        if image.shape[0] != label.shape[0] or image.shape[1] != label.shape[1]:
            raise (RuntimeError("Query Image & label shape mismatch: " +
                   image_path + " " + label_path + "\n"))

        # remove unwanted label classes
        label_class = np.unique(label).tolist()
        if 0 in label_class:
            label_class.remove(0)
        if 255 in label_class:
            label_class.remove(255)

        new_label_class = []
        undesired_class = []
        for c in label_class:
            if c in self.class_list:
                new_label_class.append(c)
            else:
                undesired_class.append(c)
        label_class = new_label_class
        assert len(label_class) > 0

        # background
        # NOTE(review): pixels that are 255 in the source label are neither in
        # label_class nor in undesired_class, so they end up as 0 (background)
        # here, whereas EpisodicData keeps them as 255 - confirm this is intended.
        new_label = np.zeros_like(label)
        for lab in label_class:
            indexes = np.where(label == lab)
            new_label[indexes[0], indexes[1]] = self.class_list.index(
                lab) + 1       # Add 1 because class 0 is for bg
        for lab in undesired_class:
            indexes = np.where(label == lab)
            new_label[indexes[0], indexes[1]] = 255

        if self.transform is not None:
            image, new_label = self.transform(image, new_label)
        if self.return_paths:
            return image, new_label, image_path, label_path
        else:
            return image, new_label


### Episodic Data

In [37]:
class EpisodicData(Dataset):
    """
    Episodic (few-shot) segmentation dataset.

    Each item is one episode: a query image with a binary mask for one class
    picked at random among the classes present in it, plus `shot` support
    images of that class. When `meta_aug > 1`, every support image is stacked
    with augmented copies produced by the `get_aug_data*` method selected
    through `aug_type`.
    """

    def __init__(self, mode_train: bool, dt_transform: "Compose", class_list: List[int], args: argparse.Namespace):
        self.args = args
        self.shot = args.shot
        self.class_list = class_list
        self.transform = dt_transform
        self.data_root = args.data_root
        self.random_shot = args.random_shot
        self.aug_type = args.get('aug_type', 0)
        self.meta_aug = args.get('meta_aug', 0)
        self.aug_th = args.get('aug_th', [0.15, 0.30])
        self.padding = [v*255 for v in args.mean] if args.get('padding') == 'avg' else None

        # Training episodes come from args.train_list, validation ones from args.val_list.
        self.data_list, self.sub_class_file_list = make_dataset(args.data_root, args.train_list, self.class_list) if mode_train \
                                                    else make_dataset(args.data_root, args.val_list, self.class_list)

        print(f"Transformations applied: {self.transform.segtransform}")

    def __len__(self):
        return len(self.data_list)

    @staticmethod
    def _read_pair(image_path, label_path):
        """Read an RGB float32 image and its grayscale label map from disk."""
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = np.float32(image)
        label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
        return image, label

    def _sample_support_paths(self, image_path, label_path, file_list, shot):
        """
        Draw `shot` distinct support (image, label) paths from `file_list`,
        never returning the query pair itself.

        raises:
            RuntimeError if `file_list` holds fewer than `shot` entries besides
            the query; the rejection loop below could otherwise never finish.
        """
        num_file = len(file_list)
        usable = sum(1 for img_p, lab_p in file_list
                     if not (img_p == image_path and lab_p == label_path))
        if usable < shot:
            raise RuntimeError(
                f"Only {usable} support candidate(s) besides the query are available, "
                f"but {shot} shot(s) were requested: {image_path}, {label_path}")

        support_image_path_list = []
        support_label_path_list = []
        support_idx_list = []
        for _ in range(shot):
            # This first draw is always replaced inside the loop; it is kept so
            # the sequence of random numbers consumed per episode is unchanged.
            support_idx = random.randint(1, num_file) - 1

            # init with the query paths to ensure going into the loop
            support_image_path = image_path
            support_label_path = label_path

            # exclude the query image and make sure support images are not repeated
            while (
                (support_image_path == image_path and support_label_path == label_path)
                or
                (support_idx in support_idx_list)
            ):
                support_idx = random.randint(1, num_file) - 1
                support_image_path, support_label_path = file_list[support_idx]
            support_idx_list.append(support_idx)
            support_image_path_list.append(support_image_path)
            support_label_path_list.append(support_label_path)
        return support_image_path_list, support_label_path_list

    def _augment_support(self, fg_ratio, support_image, support_label):
        """
        Dispatch to the `get_aug_data<aug_type>` method.

        returns:
            (new_img, new_label), each with a leading stack dimension, or
            (None, None) when the selected method produces no augmentation.
        raises:
            ValueError for an `aug_type` without a matching method (previously
            this surfaced as an unbound or stale `new_img` in __getitem__).
        """
        # aug with ColorJitter, needs the full config for the jitter strengths
        if self.aug_type == 4:
            return self.get_aug_data4(fg_ratio, support_image, support_label, self.args)

        aug_fns = {
            0: self.get_aug_data0,
            1: self.get_aug_data1,
            2: self.get_aug_data2,
            3: self.get_aug_data3,
            10: self.get_aug_data10,
        }
        if self.aug_type not in aug_fns:
            raise ValueError(
                f"Unsupported aug_type {self.aug_type!r}; expected one of [0, 1, 2, 3, 4, 10]")
        return aug_fns[self.aug_type](fg_ratio, support_image, support_label)

    def __getitem__(self, index):
        """
        Build one episode.

        returns:
            qry_img, target : transformed query image and its binary mask
            spprt_imgs, spprt_labels : support images/masks concatenated on dim 0
            subcls_list : [index of the chosen class within class_list + 1]
            [support paths, untransformed support masks, support images without normalization]
            [query image path, untransformed query mask, query image without normalization]
        """
        # ============================== Build Query ==============================
        # ====== Read query image + Get suitable classes ======
        image_path, label_path = self.data_list[index]
        image, label = self._read_pair(image_path, label_path)

        assert image.shape[0]==label.shape[0] and image.shape[1] == label.shape[1], f"Query & label shape mismatch: {image_path}, {label_path}\n"

        # ====== Retain only suitable class labels =====
        # Drop background (0) and ignore (255), then keep classes of this split.
        label_class = [c for c in np.unique(label).tolist()
                       if c not in (0, 255) and c in self.class_list]

        assert len(label_class)>0, f"No suitable class labels for at 'image path': {image_path}, 'label path': {label_path}"

        # ===== From classes in query image, chose one randomly =====
        class_chosen = np.random.choice(label_class)
        new_label = np.zeros_like(label)
        ignore_pix = np.where(label == 255)
        target_pix = np.where(label == class_chosen)
        new_label[ignore_pix] = 255
        new_label[target_pix] = 1
        label = new_label

        # (image, label) paths available for the chosen class in the current split
        file_list_class_chosen = self.sub_class_file_list[class_chosen]

        # ============================== Build Support ==============================
        # First randomly choose the support images
        shot = random.randint(1, self.shot) if self.random_shot else self.shot
        support_image_path_list, support_label_path_list = self._sample_support_paths(
            image_path, label_path, file_list_class_chosen, shot)

        support_image_list = []
        support_label_list = []
        subcls_list = [self.class_list.index(class_chosen) + 1]

        # Second, read support images and masks
        for support_image_path, support_label_path in zip(support_image_path_list, support_label_path_list):
            support_image, support_label = self._read_pair(support_image_path, support_label_path)

            # Binarize the support mask: 1 = chosen class, 255 = ignore, 0 = rest
            target_pix = np.where(support_label == class_chosen)
            ignore_pix = np.where(support_label == 255)
            support_label[:, :] = 0
            support_label[target_pix[0], target_pix[1]] = 1
            support_label[ignore_pix[0], ignore_pix[1]] = 255

            assert support_image.shape[0] == support_label.shape[0] and support_image.shape[1] == support_label.shape[1],\
                f"Support & label shape mismatch: support image path: {support_image_path}; support label path: {support_label_path}\n"

            support_image_list.append(support_image)
            support_label_list.append(support_label)
        assert len(support_label_list) == shot and len(support_image_list) == shot,\
            f"Length of support image/label list is more than {shot} shot(s)\n"

        # Original support images and labels (the lists below get overwritten with tensors)
        support_images = support_image_list.copy()
        support_labels = support_label_list.copy()

        support_images_without_norm = []

        # ============================== Data Transform & Augmentations ==============================
        if self.transform is not None:

            qry_img, target = self.transform(image, label)    # transform query img
            qry_img_without_norm, _ = \
                Compose(self.transform.segtransform[:-1])(image, label) # query img transformed without normalization

            for k in range(shot):                             # transform support img
                if self.meta_aug > 1:
                    org_img, org_label = self.transform(support_image_list[k], support_label_list[k])  # flip and resize
                    # Fraction of pixels that belong to the chosen class
                    label_freq = np.bincount(support_label_list[k].flatten())
                    fg_ratio = label_freq[1] / np.sum(label_freq)

                    new_img, new_label = self._augment_support(
                        fg_ratio, support_image_list[k], support_label_list[k])

                    if new_img is not None:
                        support_image_list[k] = torch.cat(
                            [org_img.unsqueeze(0), new_img], dim=0)
                        support_label_list[k] = torch.cat(
                            [org_label.unsqueeze(0), new_label], dim=0)
                    else:
                        support_image_list[k], support_label_list[k] = org_img.unsqueeze(
                            0), org_label.unsqueeze(0)

                else:
                    support_image_list[k], support_label_list[k] = self.transform(
                        support_image_list[k], support_label_list[k])
                    support_image_list[k] = support_image_list[k].unsqueeze(0)
                    support_label_list[k] = support_label_list[k].unsqueeze(0)

                spt_img_without_norm, spt_label_without_norm = Compose(self.transform.segtransform[:-1])(support_images[k], support_labels[k])
                support_images_without_norm.append(spt_img_without_norm)

        # Reshape properly
        spprt_imgs = torch.cat(support_image_list, 0)
        spprt_labels = torch.cat(support_label_list, 0)
        # NOTE(review): these images are concatenated without a leading stack
        # dimension, unlike spprt_imgs - confirm this is what consumers expect for shot > 1.
        spprt_imgs_without_norm = torch.cat(support_images_without_norm, 0)

        # subcls_list holds the (index + 1) of the chosen class within class_list
        return qry_img, target, spprt_imgs, spprt_labels, subcls_list, \
            [support_image_path_list, support_labels, spprt_imgs_without_norm],\
            [image_path, label, qry_img_without_norm]

    def get_aug_data0(self, fg_ratio, support_image, support_label):
        """One augmented copy: FitCrop for small foreground, brightness-type
        ColorJitter for medium foreground, RandScale otherwise."""
        if fg_ratio <= self.aug_th[0]:
            k = 2 if fg_ratio <= 0.03 else 3  # whether to crop at 1/2 or 1/3
            meta_trans = Compose(
                [FitCrop(k=k)] + self.transform.segtransform[-3:])
        elif self.aug_th[0] < fg_ratio < self.aug_th[1]:
            meta_trans = Compose(
                [ColorJitter(cj_type='b')] + self.transform.segtransform[-3:])
        else:
            scale = 473 / max(support_label.shape) * 0.8
            meta_trans = Compose([RandScale(scale=(
                scale, scale + 0.1), fixed_size=473, padding=self.padding)] + self.transform.segtransform[-2:])
        new_img, new_label = meta_trans(support_image, support_label)
        return new_img.unsqueeze(0), new_label.unsqueeze(0)

    # only size augmentation, no color augmentation
    def get_aug_data10(self, fg_ratio, support_image, support_label):
        """One size-augmented copy, or (None, None) when fg_ratio lies strictly
        between the two thresholds."""
        if fg_ratio <= self.aug_th[0] or fg_ratio >= self.aug_th[1]:
            if fg_ratio <= self.aug_th[0]:
                k = 2 if fg_ratio <= 0.03 else 3  # whether to crop at 1/2 or 1/3
                meta_trans = Compose(
                    [FitCrop(k=k)] + self.transform.segtransform[-3:])
            else:
                scale = 473 / max(support_label.shape) * 0.7
                meta_trans = Compose([RandScale(scale=(
                    scale, scale + 0.1), fixed_size=473, padding=self.padding)] + self.transform.segtransform[-2:])
            new_img, new_label = meta_trans(support_image, support_label)
            return new_img.unsqueeze(0), new_label.unsqueeze(0)
        else:
            return None, None

    def get_aug_data1(self, fg_ratio, support_image, support_label):  # create two augmented data
        """Two augmented copies, stacked on dim 0."""
        scale = 473 / max(support_label.shape)

        if fg_ratio <= self.aug_th[0]:  # 0.15
            meta_trans1 = Compose(
                [FitCrop(k=2)] + self.transform.segtransform[-3:])
            meta_trans2 = Compose(
                [FitCrop(k=3)] + self.transform.segtransform[-3:])
        elif self.aug_th[0] < fg_ratio < self.aug_th[1]:
            meta_trans1 = Compose(
                [FitCrop(k=3)] + self.transform.segtransform[-3:])
            meta_trans2 = Compose([RandScale(scale=(scale * 0.85, scale * 0.85 + 0.1),
                                  fixed_size=473, padding=self.padding)] + self.transform.segtransform[-2:])
        else:
            meta_trans1 = Compose([RandScale(scale=(scale * 0.85, scale * 0.85 + 0.1),
                                  fixed_size=473, padding=self.padding)] + self.transform.segtransform[-2:])
            meta_trans2 = Compose([RandScale(scale=(scale * 0.85, scale * 0.85 + 0.1),
                                  fixed_size=473, padding=self.padding)] + self.transform.segtransform[-2:])
        new_img1, new_label1 = meta_trans1(support_image, support_label)
        new_img2, new_label2 = meta_trans2(support_image, support_label)

        new_imgs = torch.cat(
            [new_img1.unsqueeze(0), new_img2.unsqueeze(0)], dim=0)
        new_labels = torch.cat(
            [new_label1.unsqueeze(0), new_label2.unsqueeze(0)], dim=0)
        return new_imgs, new_labels

    def get_aug_data2(self, fg_ratio, support_image, support_label):   # the initial data augmentation
        """One augmented copy: FitCrop for fg_ratio <= 0.15, otherwise an
        always-applied horizontal flip (fixed thresholds, aug_th is not used)."""
        if fg_ratio <= 0.15:
            k = 2 if fg_ratio <= 0.05 else 3
            meta_trans = Compose(
                [FitCrop(k=k)] + self.transform.segtransform[-3:])
        else:
            meta_trans = Compose([RandomHorizontalFlip(
                p=1.0)] + self.transform.segtransform[-3:])
        new_img, new_label = meta_trans(support_image, support_label)
        return new_img.unsqueeze(0), new_label.unsqueeze(0)

    # base data augmentation: resize (with padding)
    def get_aug_data3(self, fg_ratio, support_image, support_label):
        """Like get_aug_data0, but small foreground uses FitCrop(multi=True),
        which may yield one or two crops."""
        if fg_ratio <= self.aug_th[0]:
            k = 2 if fg_ratio <= 0.03 else 3  # whether to crop at 1/2 or 1/3
            trans_crop = FitCrop(k=k, multi=True)
            crop_out = trans_crop(support_image, support_label)
            meta_trans = Compose(self.transform.segtransform[-3:])

            new_img, new_label = meta_trans(crop_out[0], crop_out[1])
            if len(crop_out) == 2:
                return new_img.unsqueeze(0), new_label.unsqueeze(0)
            elif len(crop_out) == 4:
                new_img2, new_label2 = meta_trans(crop_out[2], crop_out[3])
                return torch.cat([new_img.unsqueeze(0), new_img2.unsqueeze(0)], dim=0), torch.cat([new_label.unsqueeze(0), new_label2.unsqueeze(0)], dim=0)
            # Any other crop_out length falls through to the final two lines,
            # which apply meta_trans to the uncropped support image.

        elif self.aug_th[0] < fg_ratio < self.aug_th[1]:
            meta_trans = Compose(
                [ColorJitter(cj_type='b')] + self.transform.segtransform[-3:])
        else:
            scale = 473 / max(support_label.shape) * 0.7
            meta_trans = Compose([RandScale(scale=(
                scale, scale + 0.1), fixed_size=473, padding=self.padding)] + self.transform.segtransform[-2:])
        new_img, new_label = meta_trans(support_image, support_label)
        return new_img.unsqueeze(0), new_label.unsqueeze(0)

    def get_aug_data4(self, fg_ratio, support_image, support_label, args):
        """Like get_aug_data0, but medium foreground uses ColorAug configured
        from args ('brightness', 'contrast', 'saturation', 'hue', default 0)."""
        if fg_ratio <= self.aug_th[0]:
            k = 2 if fg_ratio <= 0.03 else 3  # whether to crop at 1/2 or 1/3
            meta_trans = Compose(
                [FitCrop(k=k)] + self.transform.segtransform[-3:])
        elif self.aug_th[0] < fg_ratio < self.aug_th[1]:
            meta_trans = Compose([ColorAug(args.get('brightness', 0), args.get('contrast', 0), args.get(
                'saturation', 0), args.get('hue', 0))] + self.transform.segtransform[-3:])
        else:
            scale = 473 / max(support_label.shape) * 0.8
            meta_trans = Compose([RandScale(scale=(
                scale, scale + 0.1), fixed_size=473, padding=self.padding)] + self.transform.segtransform[-2:])
        new_img, new_label = meta_trans(support_image, support_label)
        return new_img.unsqueeze(0), new_label.unsqueeze(0)


### Train Loader

In [38]:
def get_train_loader(args, episodic=True, return_path=False):
    """
        Build the train loader.

        inputs:
            args : config object; read via attributes and `args.get(key, default)`
            episodic : wrap the data in EpisodicData when True, StandardData otherwise
            return_path : forwarded to StandardData only (`return_paths`)

        returns:
            (train_loader, train_sampler); train_sampler is None unless args.distributed
    """
    assert args.train_split in [0, 1, 2, 3]
    # NOTE(review): an average-colour `padding` list used to be computed here
    # but was never passed to any transform, so it has been dropped. The val
    # loader does hand such a padding to Resize - confirm whether the training
    # Resize was meant to receive it as well.
    aug_dic = {
        'randscale': RandScale([args.scale_min, args.scale_max]),
        'randrotate': RandRotate(
            [args.rot_min, args.rot_max],
            padding=[0 for x in args.mean],
            ignore_label=255
        ),
        'hor_flip': RandomHorizontalFlip(),
        'vert_flip': RandomVerticalFlip(),
        'crop': Crop(
            [args.image_size, args.image_size], crop_type='rand',
            padding=[0 for x in args.mean], ignore_label=255
        ),
        'resize': Resize(args.image_size),
        'resize_np': Resize_np(size=(args.image_size, args.image_size)),
        'color_aug': ColorAug(args.get('brightness', 0), args.get('contrast', 0), args.get('saturation', 0), args.get('hue', 0))
    }

    # Selected augmentations in config order, then tensor conversion and
    # normalization (Normalize stays last).
    train_transform = [aug_dic[name] for name in args.augmentations]
    train_transform += [ToTensor(), Normalize(mean=args.mean, std=args.std)]
    train_transform = Compose(train_transform)

    # Only args.use_split_coco is used here; returns all 4 splits of both coco
    # and pascal as a dict of dicts
    split_classes = get_split_classes(args)
    # list of all meta train class labels
    class_list = split_classes[args.train_name][args.train_split]['train']

    # ====== Build loader ======
    if episodic:
        train_data = EpisodicData(
            mode_train=True, dt_transform=train_transform, class_list=class_list, args=args
        )
    else:
        train_data = StandardData(transform=train_transform, class_list=class_list,
                                  return_paths=return_path,  data_list_path=args.train_list,
                                  args=args)

    world_size = torch.distributed.get_world_size() if args.distributed else 1
    train_sampler = DistributedSampler(
        train_data) if args.distributed else None
    # In distributed mode args.batch_size is divided across the processes.
    batch_size = int(args.batch_size /
                     world_size) if args.distributed else args.batch_size

    train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=batch_size,
        shuffle=(train_sampler is None),   # shuffle and sampler are mutually exclusive
        num_workers=args.workers,
        pin_memory=True,
        sampler=train_sampler,
        drop_last=True)
    return train_loader, train_sampler


### Val Loader

In [39]:
def get_val_loader(args, episodic=True, return_path=False):
    """
        Build the validation loader.

        inputs:
            args : config object; read via attributes and `args.get(key, default)`
            episodic : wrap the data in EpisodicData (batch size 1) when True,
                       StandardData (batch size args.batch_size) otherwise
            return_path : forwarded to StandardData only (`return_paths`)

        returns:
            (val_loader, val_transform)
    """
    assert args.test_split in [0, 1, 2, 3, -1, 'default']

    val_trans = [ToTensor(), Normalize(mean=args.mean, std=args.std)]
    if 'resize_np' in args.augmentations:                                                     # base augmentation is resize only
        val_trans = [Resize_np(size=(args.image_size, args.image_size))] + val_trans
    else:
        padding = [v*255 for v in args.mean] if args.get('padding')=='avg' else None
        val_trans = [Resize(args.image_size, padding=padding)] + val_trans
    val_transform = Compose(val_trans)
    val_sampler = None
    split_classes = get_split_classes(args)     # all 4 splits of both coco and pascal, dict of dict

    # ====== Filter out classes seen during training ======
    if args.test_name == 'default':
        test_name = args.train_name    # 'pascal'
        test_split = args.train_split  # split 0
    else:
        test_name = args.test_name
        # 'default' is accepted by the assert above but is not a key of
        # split_classes, so resolve it to the training split instead of
        # failing with a KeyError inside filter_classes.
        test_split = args.train_split if args.test_split == 'default' else args.test_split
    class_list = filter_classes(args.train_name, args.train_split, test_name, test_split, split_classes)  # only matters for cross-domain evaluation

    # ====== Build loader ======
    if episodic:
        val_data = EpisodicData(mode_train=False, dt_transform=val_transform, class_list=class_list, args=args)

        val_loader = torch.utils.data.DataLoader(
            val_data,
            batch_size=1,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True,
            sampler=val_sampler)
    else:
        # The standard loader is built on the *training* classes of the train
        # split, replacing the filtered list computed above.
        class_list = split_classes[args.train_name][args.train_split]['train']
        val_data = StandardData(args, val_transform, class_list=class_list, return_paths=return_path, data_list_path=args.val_list)
        val_loader = torch.utils.data.DataLoader(val_data, batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=args.workers,
                                                 pin_memory=True,
                                                 sampler=val_sampler)

    return val_loader, val_transform

# Config

### Src.util

In [40]:
from utils import *

### Load Cfg

In [41]:
# Load the base YAML config, then override it with the notebook-level
# constants defined in the first cell.
cfg = load_cfg_from_cfg_file(CONFIG_PATH)
cfg.data_root = DATA_ROOT
cfg.distributed = False

# ===== Overrides taken from the global config variables =====
cfg.shot = SHOT
cfg.meta_aug = META_AUG          # no additional augmentations will be applied during data loading if set to 1 (EpisodicData only augments when meta_aug > 1)
cfg.aug_type = AUG_TYPE          # default to aug type 0
cfg.att_type = ATT_TYPE          # original, augmented, adaptive
cfg.resume_weights = RESUME_PATH # path to resume_weight for backbone PSPNet

# ===== PROTR MMN Related Configs =====
cfg.norm_s = False
cfg.norm_q = True
cfg.inner_loss = 'wce'
cfg.meta_loss = 'wdc'
cfg.encoder_dim = 512

# SEARCH_MODE 1 is "base VS LCCA (no IDA)": the values set above are replaced
# by meta_aug = 1 (no extra support augmentation) and att_type = 2.
# NOTE(review): SEARCH_MODE 0 and 2 leave META_AUG / ATT_TYPE untouched -
# confirm that is the intended setup for those modes.
if (SEARCH_MODE==1):
    cfg.meta_aug = 1
    cfg.att_type = 2
print(cfg)

FB_param_noise: 0
adapt_iter: 100
agg: cat
all_lr: l
arch: resnet
att_drop: 0.5
att_type: 2
att_wt: 0.3
aug_th: [0.12, 0.25]
aug_type: 0
augmentations: ['resize']
aux: False
backbone_dim: 2048
batch_size: 1
batch_size_val: 1
bins: [1, 2, 3, 6]
bottleneck_dim: 512
ckpt_path: checkpoints/
ckpt_used: best
cls_lr: 0.1
cls_type: oooo
conv4d: red
data_root: /Users/nigel/Documents/Research-Git/Dataset/VOCdevkit/VOC2012
dist: dot
distributed: False
dropout: 0.1
encoder_dim: 512
episodic: True
epochs: 8
exp_name: dot1_wt3_aug0_1125_att3
gamma: 0.1
gpus: [0]
image_size: 473
inner_loss: wce
inner_loss_type: wt_ce
layers: 50
log_freq: 50
log_iter: 1190
loss_shot: avg
loss_type: wt_dc
lr_stepsize: 30
m_scale: False
main_optim: SGD
manual_seed: 2021
mean: [0.485, 0.456, 0.406]
meta_aug: 1
meta_loss: wdc
milestones: [40, 70]
mixup: False
model_dir: model_ckpt
momentum: 0.9
n_runs: 1
nesterov: True
norm_feat: True
norm_q: True
norm_s: False
num_classes_tr: 2
num_classes_val: 5
padding: avg
padding_lab

# Model

## PSPNet

In [42]:
from model.pspnet import get_model, PSPNet

In [43]:
# psp backbone
backbone = get_model(cfg)
# State dict captured here; later cells store the meta-trained classifier weight in it.
backbone_dict = backbone.state_dict()
# load pspnet weight
# ===================================================================
# NOTE(review): the RESUME_PATH comment in the first cell says the weight file
# should be named 'best.pth', but this cell looks for 'with_bias.pth' or
# 'without_bias.pth' - confirm which naming is current.
if cfg.resume_weights:
    # cfg.wt_file == 1 selects the checkpoint without bias; any other value
    # (or a missing key) selects the one with bias. The file name is appended
    # directly, so cfg.resume_weights must end with a path separator.
    if cfg.get('wt_file', 0) == 1:
        fname = cfg.resume_weights + 'without_bias.pth'
    else:
        fname = cfg.resume_weights + 'with_bias.pth'
    if os.path.isfile(fname):
        print("=> loading weight '{}'".format(fname))
        pre_weight = torch.load(fname, map_location=torch.device('cpu'))['state_dict']
        model_dict = backbone.state_dict()

        # Copy every pre-trained tensor whose shape matches, except entries
        # whose name contains 'classifier' or 'gamma', which keep the model's
        # current values.
        for index, key in enumerate(model_dict.keys()):
            if 'classifier' not in key and 'gamma' not in key:
                # NOTE(review): raises KeyError if the checkpoint lacks `key`.
                if model_dict[key].shape == pre_weight[key].shape:
                    model_dict[key] = pre_weight[key]
                else:
                    print( 'Pre-trained shape and model shape dismatch for {}'.format(key) )

        backbone.load_state_dict(model_dict, strict=True)
        print("=> loaded weight '{}'".format(fname))
    else:
        print("=> no weight found at '{}'".format(fname))
# ===================================================================

# if cfg.resume_weights:
#     lines = []
#     cfg.resume_weight = f'pretrained/{cfg.train_name}/split{cfg.train_split}/pspnet_resnet{cfg.layers}/best.pth'
#     if os.path.isfile(cfg.resume_weight):
#         lines.append(f'==> loading backbone weight from: {cfg.resume_weight}')
#         pre_dict, cur_dict = torch.load(cfg.resume_weight, map_location=torch.device('cpu'))['state_dict'], backbone.state_dict()
#         for key1, key2 in zip(pre_dict.keys(), cur_dict.keys()):
#             if pre_dict[key1].shape != cur_dict[key2].shape:
#                 lines.append(f'Pre-trained {key1} shape and model {key2} shape: {pre_dict[key1].shape}, {cur_dict[key2].shape}')
#                 continue
#             cur_dict[key2] = pre_dict[key1] 
#         msg = backbone.load_state_dict(cur_dict, strict=True)
#         lines.append(f"==> {msg}")
#     else:
#         lines.append(f"==> no weight found at '{cfg.resume_weight}'")
#     print('\n'.join(lines))

#     # freeze backbone
#     for p in backbone.parameters():
#         p.requires_grad = False

=> loading weight '/Users/nigel/Documents/Research-Git/Vis/pretrained/pascal/split0/pspnet_resnet50/without_bias.pth'
=> loaded weight '/Users/nigel/Documents/Research-Git/Vis/pretrained/pascal/split0/pspnet_resnet50/without_bias.pth'


## MMN

### MMN CWT

In [44]:
from model.mmn import MMN as MMNOLD

In [45]:
#  Trans
# Build the older (CWT) MMN module on CPU.
Trans_old = MMNOLD(cfg, agg=cfg.agg, wa=cfg.wa, red_dim=cfg.red_dim).cpu()
# Checkpoint folder name is inferred from the config:
# 't' if meta_aug > 1 else 'f', then '11e_pm', then shot and train_split.
meta_model_weight_pth = os.path.join(META_MODEL_PATH, f"{'t' if cfg.meta_aug>1 else 'f'}11e_pm{cfg.shot}{cfg.train_split}", "best.pth")
# meta_model_weight_pth = '/Users/nigel/Documents/Research-Git/Vis/pretrained/pascal_resnet50/mmn/best.pth' 
# load Trans weight
pre_dict, cur_dict = torch.load(meta_model_weight_pth, map_location=torch.device('cpu'))['state_dict'], Trans_old.state_dict()
print(f"Start Loading weight from path: {meta_model_weight_pth}")
# Pair checkpoint and model entries by position, skipping the first
# checkpoint key (the classifier weight); mismatched shapes are reported and skipped.
for key1, key2 in zip(list(pre_dict.keys())[1:], cur_dict.keys()):
    if pre_dict[key1].shape != cur_dict[key2].shape:
        print(f'Pre-trained {key1} shape and model {key2} shape: {pre_dict[key1].shape}, {cur_dict[key2].shape}')
        continue
    cur_dict[key2] = pre_dict[key1] 
# NOTE(review): cur_dict is only updated as a dictionary; no load_state_dict
# call is made in this cell (the one below is commented out), so confirm a
# later cell loads cur_dict into Trans_old.
# Trans.load_state_dict(trans_weight['state_dict'])
# The first checkpoint entry is the classifier weight; store it in the
# backbone state dict captured earlier.
classifier_key = list(pre_dict.keys())[0]
backbone_dict[classifier_key] = pre_dict[classifier_key]
# Printed unconditionally, even if some entries were skipped above.
print(f"Loaded Weight Successfully")

Start Loading weight from path: /Users/nigel/Documents/Research-Git/Vis/pretrained/mmn-all/f11e_pm10/best.pth
Loaded Weight Successfully


### MMN PROTR

In [46]:
from mmn.module.mmn import MMN

In [47]:
#  Trans
# Build the PROTR MMN module on CPU.
Trans = MMN(cfg).cpu()
# Checkpoint folder name is inferred from the config:
# 't' if meta_aug > 1 else 'f', then '11e_pm', then shot and train_split.
meta_model_weight_pth = os.path.join(META_MODEL_PATH, f"{'t' if cfg.meta_aug>1 else 'f'}11e_pm{cfg.shot}{cfg.train_split}", "best.pth")
# meta_model_weight_pth = '/Users/nigel/Documents/Research-Git/Vis/pretrained/pascal_resnet50/mmn/best.pth' 
# load Trans weight
pre_dict, cur_dict = torch.load(meta_model_weight_pth, map_location=torch.device('cpu'))['state_dict'], Trans.state_dict()
# Skip the first checkpoint key (the classifier weight).
pre_key_list, cur_key_list = list(pre_dict.keys())[1:], list(cur_dict.keys())

# Drop the model's wa_2.* entries before pairing by position; the checkpoint
# key listing printed below starts at wa_3 and has no wa_2 entries.
cur_key_list = filter(lambda x: not "wa_2" in x, cur_key_list)

print(f"Start Loading weight from path: {meta_model_weight_pth}")
# Pair the remaining entries by position; mismatched shapes are reported and skipped.
for key1, key2 in zip(pre_key_list, cur_key_list):
    if pre_dict[key1].shape != cur_dict[key2].shape:
        print(f'Pre-trained {key1} shape and model {key2} shape: {pre_dict[key1].shape}, {cur_dict[key2].shape}')
        continue
    cur_dict[key2] = pre_dict[key1] 
# NOTE(review): cur_dict is only updated as a dictionary; no load_state_dict
# call is made in this cell (the one below is commented out), so confirm a
# later cell loads cur_dict into Trans.
# Trans.load_state_dict(trans_weight['state_dict'])
# The first checkpoint entry is the classifier weight; store it in the
# backbone state dict captured earlier.
classifier_key = list(pre_dict.keys())[0]
backbone_dict[classifier_key] = pre_dict[classifier_key]
# Printed unconditionally, even if some entries were skipped above.
print(f"Loaded Weight Successfully")

Start Loading weight from path: /Users/nigel/Documents/Research-Git/Vis/pretrained/mmn-all/f11e_pm10/best.pth
Pre-trained trans.corr_net.NeighConsensus.conv.0.conv1.weight shape and model corr_net.NeighConsensus.conv.0.conv1.weight shape: torch.Size([10, 2, 3, 3]), torch.Size([10, 3, 3, 3])
Pre-trained trans.corr_net.NeighConsensus.conv.0.conv2.weight shape and model corr_net.NeighConsensus.conv.0.conv2.weight shape: torch.Size([10, 2, 3, 3]), torch.Size([10, 3, 3, 3])
Loaded Weight Successfully


In [48]:
# Inspect the parameter names stored in the loaded checkpoint.
pre_dict.keys()

odict_keys(['classifier.classifier.weight', 'trans.wa_3.conv_theta.weight', 'trans.wa_3.conv_theta.bias', 'trans.wa_3.conv_phi.weight', 'trans.wa_3.conv_phi.bias', 'trans.wa_3.conv_g.weight', 'trans.wa_3.conv_g.bias', 'trans.wa_3.conv_back.weight', 'trans.wa_3.conv_back.bias', 'trans.wa_4.conv_theta.weight', 'trans.wa_4.conv_theta.bias', 'trans.wa_4.conv_phi.weight', 'trans.wa_4.conv_phi.bias', 'trans.wa_4.conv_g.weight', 'trans.wa_4.conv_g.bias', 'trans.wa_4.conv_back.weight', 'trans.wa_4.conv_back.bias', 'trans.corr_net.NeighConsensus.conv.0.conv1.weight', 'trans.corr_net.NeighConsensus.conv.0.conv1.bias', 'trans.corr_net.NeighConsensus.conv.0.conv2.weight', 'trans.corr_net.NeighConsensus.conv.0.conv2.bias', 'trans.corr_net.NeighConsensus.conv.2.conv1.weight', 'trans.corr_net.NeighConsensus.conv.2.conv1.bias', 'trans.corr_net.NeighConsensus.conv.2.conv2.weight', 'trans.corr_net.NeighConsensus.conv.2.conv2.bias', 'trans.corr_net.NeighConsensus.conv.4.conv1.weight', 'trans.corr_net.Nei

In [49]:
# Number of entries in the checkpoint (including the classifier weight).
len(list(pre_dict.keys()))

29

In [50]:
# Inspect the parameter names of the freshly built PROTR MMN model.
cur_dict.keys()

odict_keys(['wa_2.conv_theta.weight', 'wa_2.conv_theta.bias', 'wa_2.conv_phi.weight', 'wa_2.conv_phi.bias', 'wa_2.conv_g.weight', 'wa_2.conv_g.bias', 'wa_2.conv_back.weight', 'wa_2.conv_back.bias', 'wa_3.conv_theta.weight', 'wa_3.conv_theta.bias', 'wa_3.conv_phi.weight', 'wa_3.conv_phi.bias', 'wa_3.conv_g.weight', 'wa_3.conv_g.bias', 'wa_3.conv_back.weight', 'wa_3.conv_back.bias', 'wa_4.conv_theta.weight', 'wa_4.conv_theta.bias', 'wa_4.conv_phi.weight', 'wa_4.conv_phi.bias', 'wa_4.conv_g.weight', 'wa_4.conv_g.bias', 'wa_4.conv_back.weight', 'wa_4.conv_back.bias', 'corr_net.NeighConsensus.conv.0.conv1.weight', 'corr_net.NeighConsensus.conv.0.conv1.bias', 'corr_net.NeighConsensus.conv.0.conv2.weight', 'corr_net.NeighConsensus.conv.0.conv2.bias', 'corr_net.NeighConsensus.conv.2.conv1.weight', 'corr_net.NeighConsensus.conv.2.conv1.bias', 'corr_net.NeighConsensus.conv.2.conv2.weight', 'corr_net.NeighConsensus.conv.2.conv2.bias', 'corr_net.NeighConsensus.conv.4.conv1.weight', 'corr_net.Neigh

In [51]:
# Number of entries in the model state dict (including the wa_2.* entries).
len(list(cur_dict.keys()))

36

## Optimizer.py

In [52]:
import torch
import argparse
import torch.nn as nn
from typing import List
from torch.optim.lr_scheduler import MultiStepLR, StepLR, CosineAnnealingLR


def get_optimizer(
        args: argparse.Namespace, parameters: List[nn.Module]
) -> torch.optim.Optimizer:
    """
    Build the main optimizer selected by `args.main_optim`.

    :param args: namespace providing `main_optim`; 'SGD' additionally reads `momentum`,
                 `weight_decay` and `nesterov`, 'Adam' additionally reads `weight_decay`
    :param parameters: passed unchanged as the first argument of the torch optimizer.
                       No learning rate is given here, so it has to come from the
                       parameter groups (or from the optimizer's own default)
    :return: the configured optimizer
    :raises ValueError: if `args.main_optim` is neither 'SGD' nor 'Adam'
    """
    if args.main_optim == 'SGD':
        return torch.optim.SGD(
            parameters, momentum=args.momentum,
            weight_decay=args.weight_decay, nesterov=args.nesterov
        )
    if args.main_optim == 'Adam':
        return torch.optim.Adam(parameters, weight_decay=args.weight_decay)
    # Falling through used to return None, which only failed much later at the first
    # `optimizer.step()`; fail here instead, where the bad config value is visible.
    raise ValueError(
        f"Unsupported main_optim {args.main_optim!r}; expected 'SGD' or 'Adam'"
    )


def get_scheduler(
        args: argparse.Namespace,
        optimizer: torch.optim.Optimizer, batches: int
) -> torch.optim.lr_scheduler._LRScheduler:
    """
    Build the learning-rate scheduler selected by `args.scheduler`.

    cosine will change learning rate every iteration, others change learning rate every epoch

    Only the selected scheduler is constructed. Building all of them up front (as a plain
    dict literal does) requires every scheduler's config fields to exist even when unused,
    and runs several scheduler constructors against the same optimizer.

    :param args: namespace providing `scheduler` and the fields of the chosen scheduler:
                 'step' reads `lr_stepsize` and `gamma`; 'multi_step' reads `milestones`
                 and `gamma`; 'cosine' reads `epochs`; None reads nothing
    :param optimizer: the optimizer whose learning rate is scheduled
    :param batches: the number of iterations in each epochs
    :return: scheduler, or None when `args.scheduler` is None
    :raises KeyError: if `args.scheduler` is not one of the supported names
    """
    builders = {
        'step': lambda: StepLR(optimizer, args.lr_stepsize, args.gamma),
        'multi_step': lambda: MultiStepLR(optimizer, milestones=args.milestones, gamma=args.gamma),
        'cosine': lambda: CosineAnnealingLR(optimizer, batches * args.epochs, eta_min=1e-6),
        None: lambda: None,
    }
    return builders[args.scheduler]()

## Data Loader

In [53]:
import json
from einops import rearrange
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision.transforms.functional as F_trans

In [54]:
# Populate cfg with the JSON lists the val loader needs.
# Each pair is (attribute set on cfg, file name inside CLASS_LIST_PATH).

for _attr, _fname in (("data_list", "data-list.json"), ("sub_class_file_dict", "sub-class.json")):
    with open(os.path.join(CLASS_LIST_PATH, _fname), 'r') as f:
        setattr(cfg, _attr, json.load(f))

In [55]:
class DatasetLoader:
    """
    Thin wrapper that hands out episodes from the train and/or val loader one at a time.

    Attributes:
        cfg: configuration passed to `get_train_loader` / `get_val_loader`
        train (bool): whether the train loader is enabled
        val (bool): whether the val loader is enabled
        train_loader: iterator over the train loader (empty iterator when disabled)
        val_loader: iterator over the val loader (empty iterator when disabled)
    """

    def __init__(self, cfg, train_loader=False, val_loader=True):
        """
        :param cfg: configuration forwarded to the loader factories
        :param train_loader: build the train loader when truthy
        :param val_loader: build the val loader when truthy
        :raises AssertionError: if both loaders are disabled
        """
        assert train_loader or val_loader, "At lease one of train/val loader should be True"

        self.cfg = cfg
        self.train = bool(train_loader)
        self.val = bool(val_loader)

        # The underlying loaders are kept next to their iterators so that __len__ can ask
        # the loader itself; a disabled loader is represented by an empty list.
        if self.train:
            self._train_source, _ = get_train_loader(cfg)
        else:
            print("WARNING: Train loader disabled")
            self._train_source = []
        if self.val:
            self._val_source, _ = get_val_loader(cfg)
        else:
            print("WARNING: Val loader disabled")
            self._val_source = []

        self.train_loader = iter(self._train_source)
        self.val_loader = iter(self._val_source)

    def __len__(self):
        """
        Number of batches in the val loader, or in the train loader when the val loader
        is disabled or empty.
        """
        # `len()` on a plain list iterator raises TypeError, so measure the sources instead.
        return len(self._val_source) or len(self._train_source)

    def next(self, train=False):
        """
        Return the next episode as the 7-tuple
        (qry_img, q_label, spt_imgs, s_label, subcls, spt_info, q_info).

        :param train: draw from the train loader instead of the val loader
        :raises AssertionError: if the requested loader is disabled
        :raises StopIteration: if the requested loader is exhausted (call `reset()` to restart)
        """
        if train:
            assert self.train
            source = self.train_loader
        else:
            assert self.val
            source = self.val_loader

        # Inside a method body `next` is the builtin, not this method. The builtin works for
        # every iterator, whereas an iterator-specific `.next()` method is not guaranteed to exist.
        qry_img, q_label, spt_imgs, s_label, subcls, spt_info, q_info = next(source)

        return qry_img, q_label, spt_imgs, s_label, subcls, spt_info, q_info

    def reset(self):
        """Rebuild every enabled loader so that iteration starts from the beginning again."""
        print("Start resetting loaders.")
        # Use the cfg this instance was built with, not whatever global `cfg` currently is.
        if self.train:
            self._train_source, _ = get_train_loader(self.cfg)
            self.train_loader = iter(self._train_source)
        if self.val:
            self._val_source, _ = get_val_loader(self.cfg)
            self.val_loader = iter(self._val_source)
        print(f"Reset complete. Train: {self.train}; Val: {self.val}")

# Shared loader instance with both the train and the val loader enabled.
dataset_loader = DatasetLoader(cfg, train_loader=True, val_loader=True)

Processing data for [6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]


100%|██████████| 5953/5953 [00:01<00:00, 3367.16it/s]


Transformations applied: [<__main__.Resize object at 0x7fb7978e1350>, <__main__.ToTensor object at 0x7fb7978e1690>, <__main__.Normalize object at 0x7fb7978e1850>]
INFO: pascal -> pascal
INFO: 0 -> 0
>> Start Filtering classes 
>> Removed classes = [] 
>> Kept classes = ['airplane', 'bicycle', 'bird', 'boat', 'bottle'] 
Processing data for [1, 2, 3, 4, 5]


100%|██████████| 1449/1449 [00:00<00:00, 6310.17it/s]


Transformations applied: [<__main__.Resize object at 0x7fb79747c1d0>, <__main__.ToTensor object at 0x7fb797a68050>, <__main__.Normalize object at 0x7fb7975c6b90>]


## Model Wrapper Class

In [56]:
from model import SegLoss
import matplotlib.pyplot as plt

In [61]:
class ModelWrapper:
    """
    Bundle the backbone and the meta model and evaluate them on one episode at a time.

    Attributes:
        backbone (PSPNet): provides `extract_features`, `classifier` and `inner_loop`
        meta_model (MMN): attention model called once per support sample
        cfg (CfgNode): configuration; `get_pred` reads `loss_type`, `loss_shot`, `att_wt`,
            `num_classes_tr`, `batch_size` and optionally `att_type` and `aux`
        dataset_loader (DatasetLoader): source of episodes when none is passed to `get_pred`
        search_mode (int): 0 compares a classifier fitted on the first support sample against
            one fitted on all support samples; any other value compares the classifier
            against the attention path of the meta model
    """

    def __init__(self, backbone: PSPNet, meta_model: MMN, cfg: CfgNode, dataset_loader: DatasetLoader, search_mode:int = 0):
        # assert type(backbone)==PSPNet and type(meta_model)==MMN, "Backbone should be instance of PSPNet, meta_model should be instance of MMN"
        self.backbone, self.meta_model, self.cfg, self.dataset_loader, self.search_mode = backbone, meta_model, cfg, dataset_loader, search_mode
        # Meta-model training is disabled in this wrapper; kept for reference:
        # self.optimizer_meta = get_optimizer(self.cfg, [dict(params=self.meta_model.parameters(), lr=self.cfg.trans_lr * self.cfg.scale_lr)])
        # self.scheduler = get_scheduler(self.cfg, self.optimizer_meta, len(self.dataset_loader))

    def get_pred(self, data=None, omit_out=False):
        """
        Run one episode and return the three query predictions with their losses and IoUs.

        Args:
            data: episode 7-tuple (qry_img, q_label, spt_imgs, s_label, subcls, spt_info, q_info);
                when None the next validation episode is drawn from `self.dataset_loader`
            omit_out (bool): suppress the printed per-episode summary

        Returns:
            tuple: (pred_q0, pred_q1, pred_q2, spt_imgs, s_label, qry_img, q_label, spt_info,
            q_info, (q_loss0, q_loss1, q_loss2), (IoUf, IoUb)) where IoUf / IoUb are dicts
            keyed by prediction index 0, 1, 2. The returned spt_imgs / s_label have their
            leading dimension squeezed, and all four image/label tensors are on the GPU
            when CUDA is available.

        Raises:
            AssertionError: if `data` does not have 7 elements
        """
        # Explicit None check: truthiness would also replace an (invalid) empty episode
        # with a freshly drawn one instead of reporting it.
        if data is None:
            data = self.dataset_loader.next(train=False)
        assert len(data)==7, "Data should be in the form of (qry_img, q_label, spt_imgs, s_label, subcls, spt_info, q_info)"

        # Shape:
        #   qry_img: [1, 3, h, w]
        #   q_label: [1, h, w]
        #   spt_imgs: [1, shot & aug, 3, h, w]
        #   s_label: [1, shot & aug, h, w]

        qry_img, q_label, spt_imgs, s_label, subcls, spt_info, q_info = data

        # Per-call meters: they are updated below but not returned or stored.
        train_loss_meter0 = AverageMeter()
        train_iou_meter0 = AverageMeter()

        train_loss_meter1 = AverageMeter()
        train_iou_meter1 = AverageMeter()

        train_iou_compare = CompareMeter()

        if torch.cuda.is_available():
            spt_imgs = spt_imgs.cuda()
            s_label = s_label.cuda()
            q_label = q_label.cuda()
            qry_img = qry_img.cuda()

        # ====== Phase 1: Train a new binary classifier on support samples. ======
        spt_imgs = spt_imgs.squeeze(0)
        s_label = s_label.squeeze(0).long()

        self.backbone.eval()
        with torch.no_grad():
            f_s, fs_lst = self.backbone.extract_features(spt_imgs)

        self.backbone.classifier.train()
        if self.search_mode == 0:
            # Baseline for mode 0: fit on the first support sample only.
            self.backbone.inner_loop(f_s[0].unsqueeze(0), s_label[0].unsqueeze(0))
        else:
            self.backbone.inner_loop(f_s, s_label)

        # ====== Phase 2: Update query score using attention. ======

        self.backbone.eval()
        criterion = SegLoss(loss_type=self.cfg.loss_type)

        with torch.no_grad():
            f_q, fq_lst = self.backbone.extract_features(qry_img)
            pred_q0 = self.backbone.classifier(f_q)
            pred_q0 = F.interpolate(pred_q0, size=q_label.shape[1:], mode='bilinear', align_corners=True)

        # Per-shot loss total for loss_shot != 'avg'; stays None when nothing is accumulated.
        sum_per_shot = self.cfg.loss_shot != 'avg'
        loss_sum = None

        if self.search_mode == 0:
            # Fit again, this time on every support sample, and re-score the query.
            self.backbone.classifier.train()
            self.backbone.inner_loop(f_s, s_label)

            self.backbone.eval()
            with torch.no_grad():
                pred_q1 = self.backbone.classifier(f_q)
                # Upsample the NEW prediction; upsampling pred_q0 here would make
                # pred_q1 identical to the baseline.
                pred_q1 = F.interpolate(pred_q1, size=q_label.shape[1:], mode='bilinear', align_corners=True)

            pred_q2 = pred_q1

        else:
            # adaptive attention for feature selection
            att_type = self.cfg.get('att_type', 2)
            if att_type in (0, 1):
                fs_lst = { k: [tensor_slice(e, idx=att_type) for e in v] for k, v in fs_lst.items() }
                f_s = tensor_slice(f_s, idx=att_type)
            elif att_type == 3:
                with torch.no_grad():
                    pred_s0 = self.backbone.classifier(f_s)
                    pred_s0 = F.interpolate(pred_s0, size=s_label.shape[1:], mode='bilinear', align_corners=True)   # [B, 2, 473, 473]
                    intersection, union, _ = batch_intersectionAndUnionGPU(pred_s0.unsqueeze(0), s_label.unsqueeze(0), num_classes=2, ignore_index=255)
                    # Per-support IoU, averaged over the class axis, is the reference handed to tensor_slice.
                    iou = torch.mean( (intersection / (union + 1e-10)).squeeze(0), dim=-1 )
                    fs_lst = { k: [tensor_slice(e, ref=iou) for e in v] for k, v in fs_lst.items() }
                    f_s = tensor_slice(f_s, ref=iou)

            # meta model forward & train
            self.meta_model.train()

            att_fq = []

            for k in range(len(f_s)):
                single_fs_lst = { key: [ve[k:k + 1] for ve in value] for key, value in fs_lst.items() }
                single_f_s = f_s[k:k+1]
                _, att_fq_single = self.meta_model(fq_lst, single_fs_lst, single_f_s)
                att_fq.append(att_fq_single)

                if sum_per_shot:
                    # NOTE(review): the original never accumulated `loss_sum`, so the 'sum'
                    # option yielded the int 0 and crashed on `.item()` below. Summing the
                    # per-shot query loss is inferred from the option name -- confirm
                    # against the training script.
                    pred_single = self.backbone.classifier(att_fq_single)
                    pred_single = F.interpolate(pred_single, size=q_label.shape[-2:], mode='bilinear', align_corners=True)
                    shot_loss = criterion(pred_single, q_label.long())
                    loss_sum = shot_loss if loss_sum is None else loss_sum + shot_loss

            att_fq = torch.cat(att_fq, dim=0)
            att_fq = att_fq.mean(dim=0, keepdim=True)
            # Blend of the raw query feature and the attended feature, weighted by cfg.att_wt.
            fq = f_q * (1-self.cfg.att_wt) + att_fq * self.cfg.att_wt

            pred_q1 = self.backbone.classifier(att_fq)
            pred_q1 = F.interpolate(pred_q1, size=q_label.shape[-2:], mode='bilinear', align_corners=True)
            pred_q2 = self.backbone.classifier(fq)
            pred_q2 = F.interpolate(pred_q2, size=q_label.shape[-2:], mode='bilinear', align_corners=True)

        # Loss function: Dynamic class weights used for query image only during training
        q_loss0 = criterion(pred_q0, q_label.long())
        q_loss2 = criterion(pred_q2, q_label.long())

        if loss_sum is None:
            # 'avg', or no per-shot loss exists (search_mode 0 has no per-shot forward pass).
            q_loss1 = criterion(pred_q1, q_label.long())
        else:   # 'sum'
            q_loss1 = loss_sum

        if not self.cfg.get('aux', False):
            loss = q_loss1
        else:
            loss = q_loss1 + self.cfg.aux * q_loss2

        # `loss` is only consumed by the disabled training step below:
        # self.optimizer_meta.zero_grad()
        # loss.backward()
        # self.optimizer_meta.step()
        # if self.cfg.scheduler == 'cosine':
        #     self.scheduler.step()

        IoUb, IoUf = dict(), dict()                                               # IoU background, IoU foreground
        for (pred, idx) in [(pred_q0, 0), (pred_q1, 1), (pred_q2, 2)]:
            intersection, union, _ = intersectionAndUnionGPU(pred.argmax(1), q_label, self.cfg.num_classes_tr, 255)
            # Unpacking into two names means exactly two per-class IoU values are expected here.
            IoUb[idx], IoUf[idx] = (intersection / (union + 1e-10)).cpu().numpy()  # mean of BG and FG

        train_loss_meter0.update(q_loss0.item() / self.cfg.batch_size, 1)
        train_iou_meter0.update((IoUf[0]+IoUb[0])/2, 1)
        train_loss_meter1.update(q_loss1.item() / self.cfg.batch_size, 1)
        train_iou_meter1.update((IoUf[1] + IoUb[1]) / 2, 1)
        train_iou_compare.update(IoUf[1], IoUf[0])

        # Every space and tab is stripped from the summary, including any inside the two paths.
        msg = f"""IoUf0={'{:2f}'.format(IoUf[0])}-----IoUb0={'{:2f}'.format(IoUb[0])}
                IoUf1={'{:2f}'.format(IoUf[1])}-----IoUb1={'{:2f}'.format(IoUb[1])}
                IoUf2={'{:2f}'.format(IoUf[2])}-----IoUb2={'{:2f}'.format(IoUb[2])}
                loss0={'{:2f}'.format(q_loss0)}-----loss1={'{:2f}'.format(q_loss1)}
                difference-(loss2-loss0)={'{:2f}'.format(q_loss2-q_loss0)}
                support-path={spt_info[0][0][0]}
                query-path={q_info[0][0]}""".replace('\t', '').replace(' ', '')
                # lr {'{:2f}'.format(self.optimizer_meta.param_groups[0]['lr'])}
        if not omit_out:
            print(msg)

        return pred_q0, pred_q1, pred_q2, spt_imgs, s_label, qry_img, q_label, spt_info, q_info, (q_loss0, q_loss1, q_loss2), (IoUf, IoUb)

# Wrap the backbone and the meta model (`Trans`) for episode-wise evaluation in the configured search mode.
model = ModelWrapper(backbone, Trans, cfg, dataset_loader, SEARCH_MODE)

# Visualizer

In [62]:
class Visualizer:

    '''
        Visualizer for displaying and saving visualizations

        Args:

            cfg (CfgNode): a JS Object like configuration instance for storing relevant configs

            datasetloader (DatasetLoader): dataset loader instance for iteratively accessing episodic data

            model (ModelWrapper): a model instance that wraps backbone & meta model, providing useful accessing methods

            transparency (float): factor the original image is multiplied by before the mask is added, e.g. 1.0


        Attributes:

            cfg (CfgNode): a JS Object like configuration instance for storing relevant configs

            datasetloader (DatasetLoader): dataset loader instance for iteratively accessing episodic data

            model (ModelWrapper): a model instance that wraps backbone & meta model, providing useful accessing methods

            transparency (float): factor the original image is multiplied by before the mask is added


        Methods:

            shownext(self, train=False, save_path=None)
                plot the next batch of episodic data, default using data from validation loader

            shownextpred(self, save_path=None)
                plot the prediction of the next episodic data from the validation loader

            search(self, threshold: float, num_episodes: int, save_path: str, color_mode="R", search_mode=0)
                search for suitable images

    '''

    def __init__(self, cfg: "CfgNode", datasetloader: "DatasetLoader", model: "ModelWrapper", transparency: float):
        # isinstance (rather than an exact type match) also accepts subclasses.
        assert cfg is not None and isinstance(cfg, CfgNode), "cfg should not be None and should be an instance of class CfgNode"
        assert datasetloader is not None and isinstance(datasetloader, DatasetLoader), "datasetloader should not be None and should be an instance of class DatasetLoader"
        assert model is not None and isinstance(model, ModelWrapper), "model should not be None and should be an instance of class ModelWrapper"
        assert transparency is not None and isinstance(transparency, (float, int)), "transparency should not be None and should be either integer OR float"

        self.cfg, self.datasetloader, self.model, self.transparency = cfg, datasetloader, model, transparency


    def _plot(self, img: torch.Tensor) -> None:

        '''
            Utility method for plotting a single image

            Args:

                img (torch.Tensor): an image of type torch.Tensor, shape should be (3, h, w) OR (h, w, 3)

            Returns:

                None
        '''

        # matplotlib needs host memory; this is a no-op for tensors that are already on the CPU.
        img = img.detach().cpu().squeeze()

        if img.shape[0] == 3:
            # channel-first -> channel-last
            plt.imshow(img.permute(1, 2, 0))
        else:
            plt.imshow(img)

    def _plot_list(self, img_list: list, spt_path: str, qry_path: str, save_path=None) -> None:

        '''
            Utility method for plotting a list of images side by side in a single row

            Args:

                img_list (list): a list of dict, each instance containing a dict like { 'img': <img_tensor>, 'label': <str_label> }

                spt_path (str): support image path, printed above the figure

                qry_path (str): query image path, printed above the figure

                save_path (str): file to save the figure to; nothing is saved when falsy

            Returns:

                None
        '''

        print(f"Support Path: {spt_path}\nQuery Path: {qry_path}\n----------------------------------------------------------------------------------------------------")

        f = plt.figure(figsize=(20, 20*len(img_list)), dpi=300)

        for position, entry in enumerate(img_list, start=1):
            print(entry['label'])
            plt.subplot(1, len(img_list), position)
            plt.axis('off')
            self._plot(entry['img'])

        if save_path:
            f.savefig(save_path, bbox_inches='tight', dpi=300)


    def _get_masked_image(self, img: torch.Tensor, mask: torch.Tensor, mode="B") -> torch.Tensor:

        '''
            Utility method for getting an image with specified mask

            The mask is added to one colour channel of `img * self.transparency`.
            Mask entries equal to 255 are treated as 0. The caller's mask is not modified.

            Args:

                img (torch.Tensor): an image of type torch.Tensor, shape should be (h, w, 3)

                mask (torch.Tensor): a mask of type torch.Tensor, shape should be (h, w), or 3-D with singleton extra dimension(s)

                mode (str): 'R','G','B' red, green, blue

            Returns:

                masked_img (torch.Tensor): a new image with mask applied

            Raises:

                AssertionError: Img/Mask shape invalid

            Warns:

                RuntimeWarning: Mode should be 'R','G','B'; defaulting to use mode 'R'
        '''

        import warnings

        assert len(img.shape) == 3 and (len(mask.shape) == 2 or len(mask.shape) == 3) , "Img/Mask shape invalid"
        mask = mask.detach().squeeze()
        assert mask.dim() == 2, "Img/Mask shape invalid"

        # masked_fill is out-of-place, so the caller's label / prediction tensor keeps its 255 entries.
        mask = mask.masked_fill(mask == 255, 0)
        # Put the mask on the image's device; float32 matches the dtype of a default torch.zeros buffer.
        mask = mask.to(device=img.device, dtype=torch.float32)

        channel_of = {'R': 0, 'G': 1, 'B': 2}
        if mode not in channel_of:
            # The message promises a fallback, so warn and continue instead of raising.
            warnings.warn("Mode should be 'R','G','B'; defaulting to use mode 'R'", RuntimeWarning)
        channel = channel_of.get(mode, 0)

        # Height and width come from the mask itself, so any image size works.
        overlay = torch.zeros(mask.shape[0], mask.shape[1], 3, device=img.device)
        overlay[..., channel] = mask

        masked_img = overlay + img * self.transparency

        return masked_img

    def shownext(self, train=False, save_path=None):
        '''
            Displaying next episode images

            Args:

                train (bool): whether to use data from train loader

                save_path (str): file to save the figure to; nothing is saved when None
        '''

        # Shape:
        #   qry_img: [1, 3, h, w]
        #   q_label: [1, h, w]
        #   spt_imgs: [1, shot & aug, 3, h, w]
        #   s_label: [1, shot & aug, h, w]


        qry_img, q_label, spt_imgs, s_label, subcls, spt_info, q_info = self.datasetloader.next(train)

        # First support sample and its label, indexed the same way as in shownextpred / search.
        first_s_label = s_label.squeeze(0)[0]

        spt_imgs = spt_info[2][0].squeeze()
        spt_imgs = rearrange(spt_imgs, 'c h w -> h w c')
        qry_img = q_info[2].squeeze()
        qry_img = rearrange(qry_img, 'c h w -> h w c')

        qry_path = q_info[0][0]
        spt_path = spt_info[0][0][0]

        masked_spt_im = self._get_masked_image(spt_imgs, first_s_label)

        img_list = [
            {'img': qry_img, 'label': f"Query Image"},
            {'img': spt_imgs, 'label': f"Original Support Image"},
            {'img': masked_spt_im, 'label': f"Support Image with Mask"},
        ]

        self._plot_list(img_list, spt_path, qry_path, save_path)

    def shownextpred(self, save_path=None):
        '''
            Displaying the predictions for the next validation episode

            Args:

                save_path (str): file to save the figure to; nothing is saved when None
        '''

        # get_pred returns 11 values; the trailing losses and IoUs are not plotted here.
        pred_q0, pred_q1, pred_q2, spt_imgs, s_label, qry_img, q_label, spt_info, q_info, loss_info, IoU_info = self.model.get_pred(data=self.datasetloader.next())

        spt_imgs = spt_info[2][0].squeeze().detach()
        spt_imgs = rearrange(spt_imgs, 'c h w -> h w c')
        qry_img = q_info[2].squeeze().detach()
        qry_img = rearrange(qry_img, 'c h w -> h w c')

        qry_path = q_info[0][0]
        spt_path = spt_info[0][0][0]

        masked_spt_im = self._get_masked_image(spt_imgs, s_label[0].squeeze().detach())

        pred_q0 = torch.argmax(pred_q0.detach(), dim=1).squeeze(0)
        pred_q1 = torch.argmax(pred_q1.detach(), dim=1).squeeze(0)
        pred_q2 = torch.argmax(pred_q2.detach(), dim=1).squeeze(0)

        masked_qry_img_truth = self._get_masked_image(qry_img, q_label)
        masked_qry_img0 = self._get_masked_image(qry_img, pred_q0)
        masked_qry_img2 = self._get_masked_image(qry_img, pred_q2)

        img_list = [
            {'img': spt_imgs, 'label': f"Original Support Image"},
            {'img': masked_spt_im, 'label': f"Support Image with Mask"},
            {'img': qry_img, 'label': f"Query Image"},
            {'img': masked_qry_img_truth, 'label': f"Query Image with Mask Ground Truth"},
            {'img': masked_qry_img0, 'label': f"Query Image with Mask 0"},
            {'img': masked_qry_img2, 'label': f"Query Image with Mask 2"},
        ]

        self._plot_list(img_list, spt_path, qry_path, save_path)


    def search(self, threshold: float, num_episodes: int, save_path: str, color_mode:str="R", search_mode:int=0) -> None:

        '''
            Search from validation data and retain episodes that satisfy required decision threshold

            Args:

                threshold (float): decision threshold -- an episode is retained when the mean of
                    foreground and background IoU of prediction 2 exceeds that of prediction 0
                    by at least `threshold`

                num_episodes (int): number of episodic data to test before stop

                save_path (str): directory path to save result images as `<save_path>/<episode index>.png`;
                    when falsy, figures are not saved and the per-episode summary is suppressed

                color_mode (str): mask color

                search_mode (int): when not 1, an extra panel with the second (augmented) support
                    sample is added, provided the episode has more than one support sample

            Raises:

                AssertionError: type does not match/tensor shape does not match

                RuntimeError: Any

            Returns:

                None
        '''
        assert type(num_episodes)==int, 'Number of episodes should be of type int'

        for i in range(num_episodes):
            omit_out = not bool(save_path)
            pred_q0, pred_q1, pred_q2, spt_imgs, s_label, qry_img, q_label, spt_info, q_info, loss_info, IoU_info = self.model.get_pred(data=self.datasetloader.next(), omit_out=omit_out)

            IoUf, IoUb = IoU_info
            gain = (IoUf[2] + IoUb[2]) - (IoUf[0] + IoUb[0])
            # Written as `not (>=)` so that a NaN gain is skipped as well.
            if not (gain >= 2*threshold):
                continue

            original_spt_imgs = spt_info[2][0].squeeze().detach()
            original_spt_imgs = rearrange(original_spt_imgs, 'c h w -> h w c')
            qry_img = q_info[2].squeeze().detach()
            qry_img = rearrange(qry_img, 'c h w -> h w c')

            qry_path = q_info[0][0]
            spt_path = spt_info[0][0][0]

            masked_spt_im = self._get_masked_image(original_spt_imgs, s_label[0].squeeze().detach(), mode=color_mode)

            # The augmented panel needs a second support sample; without one it is left out.
            supports = spt_imgs.squeeze(0)
            show_augmented = search_mode != 1 and supports.dim() == 4 and supports.shape[0] > 1
            if show_augmented:
                augmented_spt_img = rearrange(supports[1].detach(), 'c h w -> h w c')
                # get_pred trains on (f_s[k], s_label[k]) pairs, so support 1 belongs with label 1.
                masked_augmented_spt_im = self._get_masked_image(augmented_spt_img, s_label[1].squeeze().detach(), mode=color_mode)

            pred_q0 = torch.argmax(pred_q0.detach(), dim=1).squeeze(0)
            pred_q1 = torch.argmax(pred_q1.detach(), dim=1).squeeze(0)
            pred_q2 = torch.argmax(pred_q2.detach(), dim=1).squeeze(0)

            masked_qry_img_truth = self._get_masked_image(qry_img, q_label, mode=color_mode)
            masked_qry_img0 = self._get_masked_image(qry_img, pred_q0, mode=color_mode)
            masked_qry_img2 = self._get_masked_image(qry_img, pred_q2, mode=color_mode)

            img_list = [
                {'img': original_spt_imgs, 'label': f"Original Support Image"},
                {'img': masked_spt_im, 'label': f"Support Image with Mask"},
            ]
            if show_augmented:
                img_list.append({'img': masked_augmented_spt_im, 'label': f"Augmented Support Image with Mask"})
            img_list += [
                {'img': qry_img, 'label': f"Query Image"},
                {'img': masked_qry_img_truth, 'label': f"Query Image with Mask Ground Truth"},
                {'img': masked_qry_img0, 'label': f"Query Image with Mask 0"},
                {'img': masked_qry_img2, 'label': f"Query Image with Mask 2"},
            ]

            # Without a target directory there is nothing to save to (previously "None/<i>.png").
            target = f'{save_path}/{i}.png' if save_path else None
            self._plot_list(img_list, spt_path, qry_path, target)


# Visualizer over the shared loader and model; 1.0 keeps the original images at full intensity under the mask.
visualizer = Visualizer(cfg, dataset_loader, model, 1.0)

In [63]:
SEARCH_MODE

1

# Search

In [64]:
from datetime import datetime
# Write this run's figures into a fresh, timestamped directory.
save_path = f"/Users/nigel/Documents/Research-Git/Vis/result_imgs/{datetime.now().strftime('%Y-%m-%d--%H-%M-%S')}"
# makedirs also creates a missing `result_imgs` parent, which os.mkdir would fail on.
os.makedirs(save_path, exist_ok=True)
# Keep episodes whose mean IoU improves by at least 0.05, checking 2 validation episodes.
visualizer.search(0.05, 2, save_path, COLOR_MODE, SEARCH_MODE)

AttributeError: 'list' object has no attribute 'size'