# In [18]:
import cv2
import os
import math
import numbers
import random
import logging
import numpy as np

import torch
from torch.utils.data import Dataset
from torch.nn import functional as F
from torchvision import transforms
import PIL.Image as Image
import pickle

###########
# from   utils import CONFIG
# Module-level configuration handle. Nothing is assigned to it here: the name
# is bound later by DataGenerator.__init__, which declares ``global cfg`` and
# stores the config object passed to it. At module scope the ``global``
# statement below has no effect on its own.
global cfg


class ToTensor(object):
    """
    Turn the numpy arrays of a sample into torch tensors.

    The image is converted from BGR/HWC to RGB/CHW, scaled to [0, 1],
    normalized with the stored per-channel mean/std and extended with the
    trimap (divided by 255) as a fourth channel.

    In the "train" phase ``sample['alpha']`` becomes the channel-wise
    concatenation ``(alpha, unknown-mask, fg, bg, un-normalized image)`` and
    the keys 'image_name', 'trimap', 'fg' and 'bg' are removed. In any other
    phase it becomes ``(alpha, trimap)`` and only 'trimap' is removed.
    """
    def __init__(self, phase="test"):
        self.mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        self.std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
        self.phase = phase

    @staticmethod
    def _bgr_to_rgb_chw(array):
        # numpy image: H x W x C (BGR)  ->  torch image: C x H x W (RGB) in [0, 1]
        rgb = array[:, :, ::-1]
        chw = rgb.transpose((2, 0, 1))
        return torch.from_numpy(chw.astype('float32')) / 255.

    def __call__(self, sample):
        raw_image, raw_alpha, raw_trimap = sample['image'], sample['alpha'], sample['trimap']
        training = self.phase == "train"

        image = self._bgr_to_rgb_chw(raw_image)
        alpha = torch.from_numpy(raw_alpha.astype('float32')).unsqueeze(dim=0)

        # binary mask of the unknown region (trimap value 128)
        unknown = np.equal(raw_trimap, 128).astype('float32')
        unknown = torch.from_numpy(unknown).unsqueeze(dim=0)
        trimap = torch.from_numpy(raw_trimap.astype('float32')).unsqueeze(dim=0)

        if training:
            fg = self._bgr_to_rgb_chw(sample['fg'])
            bg = self._bgr_to_rgb_chw(sample['bg'])
            target = torch.cat((alpha, unknown, fg, bg, image), dim=0)
        else:
            target = torch.cat((alpha, trimap), dim=0)

        # normalize the image only after the raw copy was packed into the target
        normalized = (image - self.mean) / self.std
        sample['image'] = torch.cat((normalized, trimap / 255.), dim=0)
        sample['alpha'] = target

        stale_keys = ('image_name', 'trimap', 'fg', 'bg') if training else ('trimap',)
        for key in stale_keys:
            del sample[key]

        return sample


class RandomAffine(object):
    """
    Random affine transformation applied jointly to ``sample['fg']`` and
    ``sample['alpha']``.

    Args:
        degrees: a non-negative number ``d`` (rotation drawn from ``[-d, d]``)
            or a ``(min, max)`` pair, in degrees.
        translate: optional ``(tx, ty)`` pair of fractions in ``[0, 1]`` of the
            image width / height bounding the random shift.
        scale: optional ``(min, max)`` pair of positive scale factors; the x
            and y factors are drawn independently from this range.
        shear: optional non-negative number ``s`` (shear drawn from
            ``[-s, s]``) or a ``(min, max)`` pair, in degrees.
        flip: optional threshold; per axis the sign is ``+1`` when a uniform
            draw is below it and ``-1`` (mirrored) otherwise. ``None`` disables
            flipping.
        resample, fillcolor: stored on the instance but not used by this class.

    Raises:
        ValueError: for a negative scalar ``degrees``/``shear``, a translate
            value outside ``[0, 1]`` or a non-positive scale value.
        AssertionError: when a range argument is not a 2-element tuple/list.
    """
    def __init__(self, degrees, translate=None, scale=None, shear=None, flip=None, resample=False, fillcolor=0):
        if isinstance(degrees, numbers.Number):
            if degrees < 0:
                raise ValueError("If degrees is a single number, it must be positive.")
            self.degrees = (-degrees, degrees)
        else:
            assert isinstance(degrees, (tuple, list)) and len(degrees) == 2, \
                "degrees should be a list or tuple and it must be of length 2."
            self.degrees = degrees

        if translate is not None:
            assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
                "translate should be a list or tuple and it must be of length 2."
            for t in translate:
                if not (0.0 <= t <= 1.0):
                    raise ValueError("translation values should be between 0 and 1")
        self.translate = translate

        if scale is not None:
            assert isinstance(scale, (tuple, list)) and len(scale) == 2, \
                "scale should be a list or tuple and it must be of length 2."
            for s in scale:
                if s <= 0:
                    raise ValueError("scale values should be positive")
        self.scale = scale

        if shear is not None:
            if isinstance(shear, numbers.Number):
                if shear < 0:
                    raise ValueError("If shear is a single number, it must be positive.")
                self.shear = (-shear, shear)
            else:
                assert isinstance(shear, (tuple, list)) and len(shear) == 2, \
                    "shear should be a list or tuple and it must be of length 2."
                self.shear = shear
        else:
            self.shear = shear

        self.resample = resample
        self.fillcolor = fillcolor
        self.flip = flip

    @staticmethod
    def get_params(degrees, translate, scale_ranges, shears, flip, img_size):
        """Get parameters for affine transformation

        Args:
            degrees: ``(min, max)`` rotation range in degrees.
            translate: ``None`` or ``(tx, ty)`` fractions of ``img_size``.
            scale_ranges: ``None`` or ``(min, max)`` scale range.
            shears: ``None`` or ``(min, max)`` shear range in degrees.
            flip: ``None`` or the flip threshold (see the class docstring).
            img_size: ``(width, height)`` of the image; only read when
                ``translate`` is given.

        Returns:
            sequence: ``(angle, translations, scale, shear, flip)`` to be
            passed to the affine transformation. ``flip`` is ``None`` when no
            threshold was given, otherwise an integer array of two values,
            each ``-1`` or ``+1``.
        """
        angle = random.uniform(degrees[0], degrees[1])
        if translate is not None:
            max_dx = translate[0] * img_size[0]
            max_dy = translate[1] * img_size[1]
            translations = (np.round(random.uniform(-max_dx, max_dx)),
                            np.round(random.uniform(-max_dy, max_dy)))
        else:
            translations = (0, 0)

        if scale_ranges is not None:
            scale = (random.uniform(scale_ranges[0], scale_ranges[1]),
                     random.uniform(scale_ranges[0], scale_ranges[1]))
        else:
            scale = (1.0, 1.0)

        if shears is not None:
            shear = random.uniform(shears[0], shears[1])
        else:
            shear = 0.0

        if flip is not None:
            # The mask must be cast to a SIGNED integer type: with an unsigned
            # type, 0 * 2 - 1 wraps around to 255 instead of yielding -1.
            flip = (np.random.rand(2) < flip).astype(np.int_) * 2 - 1

        return angle, translations, scale, shear, flip

    def __call__(self, sample):
        fg, alpha = sample['fg'], sample['alpha']
        rows, cols = fg.shape[:2]
        # Images whose longer side is below 1024 px are never rotated.
        degrees = (0, 0) if max(rows, cols) < 1024 else self.degrees
        # img_size is (width, height); ndarray.size would be the element count.
        params = self.get_params(degrees, self.translate, self.scale, self.shear, self.flip, (cols, rows))

        center = (cols * 0.5 + 0.5, rows * 0.5 + 0.5)
        M = self._get_inverse_affine_matrix(center, *params)
        M = np.array(M).reshape((2, 3))

        # M already is the inverse mapping, hence WARP_INVERSE_MAP.
        fg = cv2.warpAffine(fg, M, (cols, rows),
                            flags=cv2.INTER_NEAREST + cv2.WARP_INVERSE_MAP)
        alpha = cv2.warpAffine(alpha, M, (cols, rows),
                               flags=cv2.INTER_NEAREST + cv2.WARP_INVERSE_MAP)

        sample['fg'], sample['alpha'] = fg, alpha

        return sample

    @staticmethod
    def _get_inverse_affine_matrix(center, angle, translate, scale, shear, flip):
        """Compute the inverse affine matrix as a flat list of 6 floats.

        Args:
            center: ``(cx, cy)`` point kept fixed by the rotation/scale part.
            angle: rotation in degrees.
            translate: ``(tx, ty)`` shift in pixels.
            scale: ``(scale_x, scale_y)`` factors.
            shear: shear in degrees.
            flip: per-axis signs (``+1``/``-1``), or ``None`` for no flipping.

        Returns:
            list: the first two rows of the 3x3 inverse matrix, row-major.
        """
        # As it is explained in PIL.Image.rotate
        # We need compute INVERSE of affine transformation matrix: M = T * C * RSS * C^-1
        # where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1]
        # C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1]
        # RSS is rotation with scale and shear matrix
        # It is different from the original function in torchvision
        # The order are changed to flip -> scale -> rotation -> shear
        # x and y have different scale factors
        # RSS(shear, a, scale, f) = [ cos(a + shear)*scale_x*f -sin(a + shear)*scale_y     0]
        # [ sin(a)*scale_x*f          cos(a)*scale_y             0]
        # [     0                       0                      1]
        # Thus, the inverse is M^-1 = C * RSS^-1 * C^-1 * T^-1

        if flip is None:
            # No flip threshold configured: keep both axes un-mirrored.
            flip = (1, 1)

        angle = math.radians(angle)
        shear = math.radians(shear)
        scale_x = 1.0 / scale[0] * flip[0]
        scale_y = 1.0 / scale[1] * flip[1]

        # Inverted rotation matrix with scale and shear
        d = math.cos(angle + shear) * math.cos(angle) + math.sin(angle + shear) * math.sin(angle)
        matrix = [
            math.cos(angle) * scale_x, math.sin(angle + shear) * scale_x, 0,
            -math.sin(angle) * scale_y, math.cos(angle + shear) * scale_y, 0
        ]
        matrix = [m / d for m in matrix]

        # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
        matrix[2] += matrix[0] * (-center[0] - translate[0]) + matrix[1] * (-center[1] - translate[1])
        matrix[5] += matrix[3] * (-center[0] - translate[0]) + matrix[4] * (-center[1] - translate[1])

        # Apply center translation: C * RSS^-1 * C^-1 * T^-1
        matrix[2] += center[0]
        matrix[5] += center[1]

        return matrix

class RandomJitter(object):
    """
    Randomly perturb hue, saturation and value of the foreground image.

    The sample is returned untouched when its alpha is zero everywhere.
    """

    @staticmethod
    def _shift_and_reflect(plane, offset):
        # Add the offset, mirror negatives at 0 and mirror values above 1 at 1.
        shifted = np.abs(plane + offset)
        overshoot = shifted > 1
        shifted[overshoot] = 2 - shifted[overshoot]
        return shifted

    def __call__(self, sample):
        fg, alpha = sample['fg'], sample['alpha']
        # nothing visible to jitter
        if np.all(alpha==0):
            return sample

        # float32 keeps precision during the colour-space round trip
        hsv = cv2.cvtColor(fg.astype(np.float32)/255.0, cv2.COLOR_BGR2HSV)

        # hue: integer shift, wrapped into [0, 360)
        hue_jitter = np.random.randint(-40, 40)
        hsv[:, :, 0] = np.remainder(hsv[:, :, 0].astype(np.float32) + hue_jitter, 360)

        # saturation (channel 1) then value (channel 2): the jitter magnitude
        # is derived from the channel mean over pixels with alpha > 0
        for channel in (1, 2):
            plane = hsv[:, :, channel]
            level = plane[alpha > 0].mean()
            offset = np.random.rand()*(1.1 - level)/5 - (1.1 - level) / 10
            hsv[:, :, channel] = self._shift_and_reflect(plane, offset)

        # back to BGR, rescaled by 255
        bgr = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
        sample['fg'] = bgr*255

        return sample


class RandomHorizontalFlip(object):
    """
    Mirror foreground and alpha around the vertical axis with probability
    ``prob``; both are always flipped together.
    """
    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, sample):
        fg = sample['fg']
        alpha = sample['alpha']
        should_flip = np.random.uniform(0, 1) < self.prob
        if should_flip:
            # flipCode 1 == horizontal flip
            fg, alpha = cv2.flip(fg, 1), cv2.flip(alpha, 1)
        sample.update(fg=fg, alpha=alpha)

        return sample


class RandomCrop(object):
    """
    Crop a random window of ``output_size`` out of fg, alpha, trimap and bg.

    The window is preferably anchored at an unknown (trimap == 128) pixel
    found inside the trimap with a border of ``margin`` pixels removed, which
    keeps the window inside the image. If the trimap holds no unknown pixel at
    all, the full images are resized to ``output_size`` instead.

    :param output_size (tuple or int): Desired output size as (height, width).
            If int, square crop is made.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size
        # Use the normalized tuple: indexing the raw argument fails for ints.
        self.margin = self.output_size[0] // 2
        self.logger = logging.getLogger("Logger")

    def __call__(self, sample):
        fg, alpha, trimap, name = sample['fg'], sample['alpha'], sample['trimap'], sample['image_name']
        bg = sample['bg']
        out_h, out_w = self.output_size
        h, w = trimap.shape
        hbg, wbg = bg.shape[:2]

        # Enlarge the background so that it covers the foreground completely.
        ratio = max(h/hbg, w/wbg)
        if ratio > 1:
            bg = cv2.resize(bg, (math.ceil(wbg*ratio), math.ceil(hbg*ratio)), interpolation=cv2.INTER_CUBIC)
            hbg, wbg = bg.shape[:2]

        # Upscale everything until a full crop window fits.
        if h < out_h + 1 or w < out_w + 1:
            ratio = 1.1 * out_h / h if h < w else 1.1 * out_w / w
            # self.logger.warning("Size of {} is {}.".format(name, (h, w)))
            while h < out_h + 1 or w < out_w + 1:
                new_size = (int(w * ratio), int(h * ratio))
                fg = cv2.resize(fg, new_size, interpolation=cv2.INTER_NEAREST)
                alpha = cv2.resize(alpha, new_size, interpolation=cv2.INTER_NEAREST)
                trimap = cv2.resize(trimap, new_size, interpolation=cv2.INTER_NEAREST)
                bg = cv2.resize(bg, (int(wbg * ratio), int(hbg * ratio)),
                                interpolation=cv2.INTER_CUBIC)
                h, w = trimap.shape
                # Keep the background size in step with the foreground, so a
                # second pass scales the background as well.
                hbg, wbg = bg.shape[:2]

        # Coordinates are relative to the margin-stripped trimap; using them
        # directly as the window's top-left corner keeps the window in bounds.
        inner = trimap[self.margin:(h - self.margin), self.margin:(w - self.margin)]
        unknown_rows, unknown_cols = np.where(inner == 128)
        unknown_num = len(unknown_rows)

        if unknown_num < 10:
            # self.logger.warning("{} does not have enough unknown area for crop.".format(name))
            left_top = (np.random.randint(0, h - out_h + 1), np.random.randint(0, w - out_w + 1))
        else:
            idx = np.random.randint(unknown_num)
            left_top = (unknown_rows[idx], unknown_cols[idx])

        window = (slice(left_top[0], left_top[0] + out_h),
                  slice(left_top[1], left_top[1] + out_w))
        fg_crop = fg[window]
        alpha_crop = alpha[window]
        bg_crop = bg[window]
        trimap_crop = trimap[window]

        if not np.any(trimap == 128):
            self.logger.error("{} does not have enough unknown area for crop. Resized to target size."
                              "left_top: {}".format(name, left_top))
            # cv2.resize expects (width, height)
            fg_crop = cv2.resize(fg, self.output_size[::-1], interpolation=cv2.INTER_NEAREST)
            alpha_crop = cv2.resize(alpha, self.output_size[::-1], interpolation=cv2.INTER_NEAREST)
            trimap_crop = cv2.resize(trimap, self.output_size[::-1], interpolation=cv2.INTER_NEAREST)
            bg_crop = cv2.resize(bg, self.output_size[::-1], interpolation=cv2.INTER_CUBIC)

        sample['fg'], sample['alpha'], sample['trimap'] = fg_crop, alpha_crop, trimap_crop
        sample['bg'] = bg_crop

        return sample


class GenTrimap(object):
    """
    Derive a trimap from ``sample['alpha']`` and store it in
    ``sample['trimap']``.

    Pixels that survive the erosion of the foreground mask become 255, pixels
    that survive the erosion of the background mask become 0, everything else
    is marked unknown (128). The two erosion widths are drawn independently
    from 1..9 on every call.
    """
    def __init__(self):
        # Index i holds an elliptical i x i kernel; index 0 is a placeholder.
        self.erosion_kernels = [None] + [cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size)) for size in range(1,30)]

    def __call__(self, sample):
        alpha = sample['alpha']
        # Adobe 1K
        # fg_width = np.random.randint(1, 30)
        # bg_width = np.random.randint(1, 30)
        fg_width = np.random.randint(1, 10)
        bg_width = np.random.randint(1, 10)
        # Truncation towards zero: the masks are 1 only where alpha (resp.
        # 1 - alpha) reaches 1 within the 1e-5 tolerance.
        # np.int_ is used because the np.int alias no longer exists in NumPy.
        fg_mask = (alpha + 1e-5).astype(np.int_).astype(np.uint8)
        bg_mask = (1 - alpha + 1e-5).astype(np.int_).astype(np.uint8)
        fg_mask = cv2.erode(fg_mask, self.erosion_kernels[fg_width])
        bg_mask = cv2.erode(bg_mask, self.erosion_kernels[bg_width])

        trimap = np.ones_like(alpha) * 128
        trimap[fg_mask == 1] = 255
        trimap[bg_mask == 1] = 0

        sample['trimap'] = trimap

        return sample


class Composite(object):
    """
    Blend foreground over background with the alpha matte and store the
    result in ``sample['image']``.

    alpha is clamped to [0, 1] and fg/bg to [0, 255] first; the clamping
    writes into the arrays held by the sample.
    """
    def __call__(self, sample):
        fg, bg, alpha = sample['fg'], sample['bg'], sample['alpha']

        # in-place clamp: (array, upper bound), lower bound is always 0
        for array, upper in ((alpha, 1), (fg, 255), (bg, 255)):
            array[array < 0 ] = 0
            array[array > upper] = upper

        weight = alpha[:, :, None]
        sample['image'] = fg * weight + bg * (1 - weight)

        return sample
    
###############################################################

class DataGenerator(Dataset):
    """
    Matting dataset producing sample dicts for training or evaluation.

    In the "train" phase every item combines one foreground/alpha pair with a
    background and runs the training transform chain. In any other phase the
    items are the (merged image, alpha, trimap) triples listed in
    ``Cfg.DATASET.val_list`` (one tab-separated triple per line).
    """

    def __init__(self, Cfg, phase="train", test_scale="resize", crop_size = 320, augmentation=True):
        '''
        Cfg: configuration object. Reads DATASET.data_dir, DATASET.train_fg_list,
            DATASET.train_alpha_list, DATASET.train_bg_list, DATASET.val_list,
            TRAIN.crop_size, TRAIN.load_data, TRAIN.load_bg and TRAIN.random_bgidx.
        phase: string, "train", "val" or "test". Selects the data source and
            the transform chain. Any other value raises KeyError.
        test_scale: string. Only consulted for phase "test", where "origin" is
            the only implemented value (NotImplementedError otherwise).
        crop_size: int. Currently not used; the crop size is always taken from
            Cfg.TRAIN.crop_size.
        augmentation: boolean. Whether the training chain includes the random
            affine / flip / jitter steps and foreground blending.

        Attributes set for phase "train":
        fg, alpha, bg: numpy arrays of image paths (data_dir joined with every
            line of the corresponding list file).
        fg_num, bg_num: int. Number of foreground / background paths.
        fg_load, alpha_load: dict, only when Cfg.TRAIN.load_data is set. Maps the
            foreground file name (directory and last 4 characters stripped) to
            the decoded image / to the alpha matte scaled to [0, 1].
        bg_load: dict, only when Cfg.TRAIN.load_bg is set. Same mapping for
            the backgrounds.

        Attributes set for every phase:
        transform: the composed transform chain of the selected phase.
        '''

        # The module-level name is still bound for code that reads it, but the
        # instance keeps its own reference so that the dataset does not depend
        # on module state (e.g. after being pickled into a worker process).
        global cfg
        cfg = Cfg
        self.cfg = Cfg
        self.phase = phase
        self.crop_size = Cfg.TRAIN.crop_size
        self.augmentation = augmentation

        if self.phase == "train":
            data_dir = Cfg.DATASET.data_dir
            self.fg = self._read_path_list(Cfg.DATASET.train_fg_list, data_dir)
            self.alpha = self._read_path_list(Cfg.DATASET.train_alpha_list, data_dir)
            self.bg = self._read_path_list(Cfg.DATASET.train_bg_list, data_dir)

            self.bg_num = len(self.bg)
            self.fg_num = len(self.fg)

            if Cfg.TRAIN.load_data:
                # Decode every foreground/alpha pair once and keep it in memory.
                self.fg_load = dict()
                self.alpha_load = dict()
                for idx, fg_path in enumerate(self.fg):
                    key = self._cache_key(fg_path)
                    self.fg_load[key] = self._imread(fg_path)
                    # flag 0 == grayscale; stored already scaled to [0, 1]
                    self.alpha_load[key] = self._imread(self.alpha[idx], 0).astype(np.float32) / 255.

            if Cfg.TRAIN.load_bg:
                # Same in-memory cache for the backgrounds.
                self.bg_load = dict()
                for bg_path in self.bg:
                    self.bg_load[self._cache_key(bg_path)] = self._imread(bg_path, 1)

        else:
            # One line per item: merged image, alpha and trimap paths, tab-separated.
            with open(Cfg.DATASET.val_list) as handle:
                self.test_list = np.array([line.split('\t') for line in handle.read().splitlines()])

        # Build only the chain of the requested phase, so that an unsupported
        # test_scale cannot break the construction of a train/val dataset.
        if phase == "train":
            if augmentation:
                chain = [
                    RandomAffine(degrees=30, scale=[0.8, 1.25], shear=10, flip=0.5),
                    RandomHorizontalFlip(),
                    GenTrimap(),
                    RandomCrop((self.crop_size, self.crop_size)),
                    RandomJitter(),
                    Composite(),
                    ToTensor(phase="train"),
                ]
            else:
                chain = [
                    GenTrimap(),
                    RandomCrop((self.crop_size, self.crop_size)),
                    Composite(),
                    ToTensor(phase="train"),
                ]
        elif phase == "val":
            chain = [ToTensor()]
        elif phase == "test":
            if test_scale.lower() == "origin":
                # keep the original resolution, only convert to tensors
                chain = [ToTensor()]
            else:
                raise NotImplementedError("test_scale {} not implemented".format(test_scale))
        else:
            raise KeyError(phase)

        self.transform = transforms.Compose(chain)

        # Elliptical kernels of size 1..19 (index 0 is a placeholder). Not used
        # inside this class; kept as a public attribute.
        self.erosion_kernels = [None] + [cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size)) for size in range(1,20)]

    @staticmethod
    def _read_path_list(list_file, data_dir):
        # Read one relative path per line and join each with data_dir.
        with open(list_file) as handle:
            names = handle.read().splitlines()
        return np.array([os.path.join(data_dir, name) for name in names])

    @staticmethod
    def _cache_key(path):
        # File name without directory and without its last 4 characters.
        return path[path.rfind('/') + 1:-4]

    @staticmethod
    def _imread(path, flags=1):
        # cv2.imread reports failure by returning None; fail with the path
        # instead of a later, unrelated attribute error. flags: 1 colour, 0 gray.
        image = cv2.imread(path, flags)
        if image is None:
            raise IOError("Could not read image: {}".format(path))
        return image

    def _load_fg_alpha(self, idx):
        # Return (foreground, alpha in [0, 1]) for idx (wrapped by fg_num),
        # from the in-memory cache when TRAIN.load_data is set, else from disk.
        # Cached arrays are returned by reference and must not be modified.
        path_idx = idx % self.fg_num
        if not self.cfg.TRAIN.load_data:
            fg = self._imread(self.fg[path_idx])
            alpha = self._imread(self.alpha[path_idx], 0).astype(np.float32) / 255.
        else:
            key = self._cache_key(self.fg[path_idx])
            fg = self.fg_load[key]
            alpha = self.alpha_load[key]
        return fg, alpha

    def _load_bg(self, bg_idx):
        # Return the background at bg_idx from cache (TRAIN.load_bg) or disk.
        if not self.cfg.TRAIN.load_bg:
            return self._imread(self.bg[bg_idx], 1)
        return self.bg_load[self._cache_key(self.bg[bg_idx])]

    def __getitem__(self, idx):
        """Build the sample dict for ``idx`` and run the phase's transform on it."""
        if self.phase == "train":
            fg, alpha = self._load_fg_alpha(idx)

            if self.cfg.TRAIN.random_bgidx:
                # randint's upper bound is exclusive: every background is eligible
                bg_idx = np.random.randint(0, self.bg_num)
            else:
                bg_idx = idx % self.bg_num
            bg = self._load_bg(bg_idx)

            if bg.shape[2]==1:
                # expand a single-channel background to three channels
                bg = np.repeat(bg, 3, axis=2)

            alpha = np.squeeze(alpha)

            if self.augmentation:
                fg, alpha = self._composite_fg(fg, alpha, idx)

            image_name = os.path.split(self.fg[idx % self.fg_num])[-1]
            sample = {'fg': fg, 'alpha': alpha, 'bg': bg, 'image_name': image_name}

        else:
            data_dir = self.cfg.DATASET.data_dir
            entry = self.test_list[idx]
            image = self._imread(os.path.join(data_dir, entry[0]))
            alpha = self._imread(os.path.join(data_dir, entry[1]), 0).astype(np.float32) / 255.
            alpha = alpha[:, :, 0] if alpha.ndim == 3 else alpha
            trimap = self._imread(os.path.join(data_dir, entry[2]), 0)

            sample = {'image': image, 'alpha': alpha, 'trimap': trimap}

        sample = self.transform(sample)

        return sample

    def _composite_fg(self, fg, alpha, idx):
        """Randomly blend a second foreground in and randomly resize to 640x640.

        With probability 0.5 another foreground/alpha pair is resized to the
        size of ``alpha`` and composited behind the current one. Independently,
        with probability 0.25 the result is resized to 640x640.

        Returns:
            tuple: ``(fg, alpha)``.
        """
        if np.random.rand() < 0.5:
            idx2 = np.random.randint(self.fg_num) + idx
            # The alpha comes back in [0, 1] for both the disk and the cached
            # path, so no further scaling must be applied here.
            fg2, alpha2 = self._load_fg_alpha(idx2)

            # bring the second pair to the size of the first one
            alpha2 = np.squeeze(alpha2)
            h, w = alpha.shape
            fg2 = cv2.resize(fg2, (w, h), interpolation=cv2.INTER_NEAREST)
            alpha2 = cv2.resize(alpha2, (w, h), interpolation=cv2.INTER_NEAREST)

            alpha_tmp = 1 - (1 - alpha) * (1 - alpha2)
            if np.any(alpha_tmp < 1):
                fg = ((fg.astype(np.float32) * alpha[:, :, None] + fg2.astype(np.float32) * (1 - alpha[:, :, None]) * alpha2[:, :, None])) \
                     / (alpha_tmp[:, :, None] + 1e-5)
                # The overlap of two 50% transparency should be 25%
                alpha = alpha_tmp
                fg = fg.astype(np.uint8)

        if np.random.rand() < 0.25:
            fg = cv2.resize(fg, (640, 640), interpolation=cv2.INTER_NEAREST)
            alpha = cv2.resize(alpha, (640, 640), interpolation=cv2.INTER_NEAREST)

        return fg, alpha

    def __len__(self):
        # A training epoch visits every background once.
        if self.phase == "train":
            return self.bg_num
        else:
            return len(self.test_list)

#####################################



# In [20]:


if __name__ == '__main__':
    # Smoke test: build the training dataset from the YAML configuration and
    # walk through it once with a DataLoader, printing the batch index.

    from data_generator import DataGenerator
    from torch.utils.data import DataLoader
    from config import cfg

    # logging.basicConfig(level=logging.DEBUG, format='[%(asctime)s] %(levelname)s: %(message)s', datefmt='%m-%d %H:%M:%S')

    cfg.merge_from_file('/home/xiufeng/Code/config/aiml.yaml')

    dataset = DataGenerator(cfg, phase='train', test_scale='origin',
                            crop_size=cfg.TRAIN.crop_size)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0,
                            drop_last=True, sampler=None)

    # number of batches, number of foregrounds, number of backgrounds
    print(len(dataloader), dataset.fg_num, dataset.bg_num, sep='\n')

    import matplotlib.pyplot as plt

    plt.figure(figsize=(8, 8))
    for i, data in enumerate(dataloader):
        images = data['image']
        targets = data['alpha']
        print(i)
        # A visual check used to live here (currently disabled): it reshaped
        # `images` / `targets` to 512 x 512 x C and showed, in a 2 x 2 subplot
        # grid, image channel 3, target channel 0 and target channels 5:8 and
        # 8:11, then stopped after the first batch.


1000
995
1000
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273


KeyboardInterrupt: 

<Figure size 800x800 with 0 Axes>