train dataset questions #57

Closed
luoww1992 opened this issue Jan 28, 2021 · 61 comments

@luoww1992

luoww1992 commented Jan 28, 2021

I am making the training dataset; it needs 3 folders -- original, trimap, matte.
So must the image size be 512?
And do the images need other operations, like color changes and so on?
What else should I pay attention to when making the dataset?

@ZHKKKe
Owner

ZHKKKe commented Jan 28, 2021

Hi, thanks for your attention.

For your questions:
Q1: it needs 3 folders -- original, trimap, matte?
Yes. And the trimap can be generated from the matte.

Q2: so the size of the image must be 512?
No, you can use any size to train the model.

Q3: and do the images need other operations, like color changes and so on?
You can use the most common data augmentations to process the training data, e.g., flipping and normalization.
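
As a minimal sketch of the trimap-from-matte step mentioned in Q1 (this mirrors the dilation/erosion approach in the dataloader code shared later in this thread; the thresholds and kernel size are assumptions to tune, not values from the repo):

import numpy as np
from scipy.ndimage import grey_dilation, grey_erosion

def matte_to_trimap(matte, size=15):
    # matte: HxW float array in [0, 1]; `size` is an assumed dilation/erosion kernel
    # size -- tune it to the width of the unknown band you want
    fg = (matte >= 0.95).astype(np.float32)      # confident foreground
    not_bg = (matte > 0.05).astype(np.float32)   # everything that is not pure background
    trimap = fg.copy()
    unknown = grey_dilation(not_bg, size=(size, size)) - grey_erosion(fg, size=(size, size))
    trimap[unknown > 0] = 0.5                    # unknown band around the boundary
    return trimap                                # values in {0, 0.5, 1}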

@antithing

antithing commented Jan 28, 2021

@luoww1992 would you be able to share your successful training approach? Did you write a dataloader and a trimap creator?
Any tips or code would be greatly appreciated! Thank you!

@luoww1992
Author

luoww1992 commented Jan 29, 2021

@luoww1992 Would you be able to share your successful training approach? Did you write a dataloader and a trimap creator?
Any tips or code would be greatly appreciated! Thank you!

I am making the trimap dataset; it will take some time.

.......

I am organizing the code.

@luoww1992
Author

luoww1992 commented Jan 30, 2021

@ZHKKKe, I am training a dataset for video matting, but the loss is very large, and the detail_loss is 0 with no change.
My train.py:

import math
import os

os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from functools import reduce

import cv2
import numpy as np
import scipy
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.ndimage import grey_dilation, grey_erosion
from scipy.ndimage import morphology
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import time
from math import *

__all__ = [
    'supervised_training_iter',
    'soc_adaptation_iter',
]



def _make_divisible(v, divisor, min_value=None):
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


def conv_bn(inp, oup, stride):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True)
    )


def conv_1x1_bn(inp, oup):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True)
    )


class InvertedResidual(nn.Module):
    def __init__(self, inp, oup, stride, expansion, dilation=1):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = round(inp * expansion)
        self.use_res_connect = self.stride == 1 and inp == oup

        if expansion == 1:
            self.conv = nn.Sequential(
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, dilation=dilation, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )
        else:
            self.conv = nn.Sequential(
                # pw
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, dilation=dilation, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )

    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)


class MobileNetV2(nn.Module):
    def __init__(self, in_channels, alpha=1.0, expansion=6, num_classes=1000):
        super(MobileNetV2, self).__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        input_channel = 32
        last_channel = 1280
        interverted_residual_setting = [
            # t, c, n, s
            [1, 16, 1, 1],
            [expansion, 24, 2, 2],
            [expansion, 32, 3, 2],
            [expansion, 64, 4, 2],
            [expansion, 96, 3, 1],
            [expansion, 160, 3, 2],
            [expansion, 320, 1, 1],
        ]

        # building first layer
        input_channel = _make_divisible(input_channel * alpha, 8)
        self.last_channel = _make_divisible(last_channel * alpha, 8) if alpha > 1.0 else last_channel
        self.features = [conv_bn(self.in_channels, input_channel, 2)]

        # building inverted residual blocks
        for t, c, n, s in interverted_residual_setting:
            output_channel = _make_divisible(int(c * alpha), 8)
            for i in range(n):
                if i == 0:
                    self.features.append(InvertedResidual(input_channel, output_channel, s, expansion=t))
                else:
                    self.features.append(InvertedResidual(input_channel, output_channel, 1, expansion=t))
                input_channel = output_channel

        # building last several layers
        self.features.append(conv_1x1_bn(input_channel, self.last_channel))

        # make it nn.Sequential
        self.features = nn.Sequential(*self.features)

        # building classifier
        if self.num_classes is not None:
            self.classifier = nn.Sequential(
                nn.Dropout(0.2),
                nn.Linear(self.last_channel, num_classes),
            )

        # Initialize weights
        self._init_weights()

    def forward(self, x, feature_names=None):
        # Stage1
        x = reduce(lambda x, n: self.features[n](x), list(range(0, 2)), x)
        # Stage2
        x = reduce(lambda x, n: self.features[n](x), list(range(2, 4)), x)
        # Stage3
        x = reduce(lambda x, n: self.features[n](x), list(range(4, 7)), x)
        # Stage4
        x = reduce(lambda x, n: self.features[n](x), list(range(7, 14)), x)
        # Stage5
        x = reduce(lambda x, n: self.features[n](x), list(range(14, 19)), x)

        # Classification
        if self.num_classes is not None:
            x = x.mean(dim=(2, 3))
            x = self.classifier(x)

        # Output
        return x

    def _load_pretrained_model(self, pretrained_file):
        pretrain_dict = torch.load(pretrained_file, map_location='cpu')
        model_dict = {}
        state_dict = self.state_dict()
        print("[MobileNetV2] Loading pretrained model...")
        for k, v in pretrain_dict.items():
            if k in state_dict:
                model_dict[k] = v
            else:
                print(k, "is ignored")
        state_dict.update(model_dict)
        self.load_state_dict(state_dict)

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                n = m.weight.size(1)
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()


class IBNorm(nn.Module):
    """ Combine Instance Norm and Batch Norm into One Layer
    """

    def __init__(self, in_channels):
        super(IBNorm, self).__init__()
        in_channels = in_channels
        self.bnorm_channels = int(in_channels / 2)
        self.inorm_channels = in_channels - self.bnorm_channels

        self.bnorm = nn.BatchNorm2d(self.bnorm_channels, affine=True)
        self.inorm = nn.InstanceNorm2d(self.inorm_channels, affine=False)

    def forward(self, x):
        bn_x = self.bnorm(x[:, :self.bnorm_channels, ...].contiguous())
        in_x = self.inorm(x[:, self.bnorm_channels:, ...].contiguous())

        return torch.cat((bn_x, in_x), 1)


class Conv2dIBNormRelu(nn.Module):
    """ Convolution + IBNorm + ReLu
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, groups=1, bias=True,
                 with_ibn=True, with_relu=True):
        super(Conv2dIBNormRelu, self).__init__()

        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size,
                      stride=stride, padding=padding, dilation=dilation,
                      groups=groups, bias=bias)
        ]

        if with_ibn:
            layers.append(IBNorm(out_channels))
        if with_relu:
            layers.append(nn.ReLU(inplace=True))

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)


class BaseBackbone(nn.Module):
    """ Superclass of Replaceable Backbone Model for Semantic Estimation
    """

    def __init__(self, in_channels):
        super(BaseBackbone, self).__init__()
        self.in_channels = in_channels

        self.model = None
        self.enc_channels = []

    def forward(self, x):
        raise NotImplementedError

    def load_pretrained_ckpt(self):
        raise NotImplementedError


class MobileNetV2Backbone(BaseBackbone):
    """ MobileNetV2 Backbone
    """

    def __init__(self, in_channels):
        super(MobileNetV2Backbone, self).__init__(in_channels)
        self.model = MobileNetV2(self.in_channels, alpha=1.0, expansion=6, num_classes=None)
        self.enc_channels = [16, 24, 32, 96, 1280]

    def forward(self, x):
        x = reduce(lambda x, n: self.model.features[n](x), list(range(0, 2)), x)
        enc2x = x
        x = reduce(lambda x, n: self.model.features[n](x), list(range(2, 4)), x)
        enc4x = x
        x = reduce(lambda x, n: self.model.features[n](x), list(range(4, 7)), x)
        enc8x = x
        x = reduce(lambda x, n: self.model.features[n](x), list(range(7, 14)), x)
        enc16x = x
        x = reduce(lambda x, n: self.model.features[n](x), list(range(14, 19)), x)
        enc32x = x
        return [enc2x, enc4x, enc8x, enc16x, enc32x]

    def load_pretrained_ckpt(self):
        # the pre-trained model is provided by https://github.com/thuyngch/Human-Segmentation-PyTorch
        ckpt_path = './mobilenetv2_human_seg.ckpt'
        if not os.path.exists(ckpt_path):
            print('cannot find the pretrained mobilenetv2 backbone')
            exit()

        ckpt = torch.load(ckpt_path)
        self.model.load_state_dict(ckpt)


class SEBlock(nn.Module):
    """ SE Block Proposed in https://arxiv.org/pdf/1709.01507.pdf
    """

    def __init__(self, in_channels, out_channels, reduction=1):
        super(SEBlock, self).__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, int(in_channels // reduction), bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(int(in_channels // reduction), out_channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        w = self.pool(x).view(b, c)
        w = self.fc(w).view(b, c, 1, 1)

        return x * w.expand_as(x)


class LRBranch(nn.Module):
    """ Low Resolution Branch of MODNet
    """

    def __init__(self, backbone):
        super(LRBranch, self).__init__()

        enc_channels = backbone.enc_channels
        self.backbone = backbone
        self.se_block = SEBlock(enc_channels[4], enc_channels[4], reduction=4)
        self.conv_lr16x = Conv2dIBNormRelu(enc_channels[4], enc_channels[3], 5, stride=1, padding=2)
        self.conv_lr8x = Conv2dIBNormRelu(enc_channels[3], enc_channels[2], 5, stride=1, padding=2)
        self.conv_lr = Conv2dIBNormRelu(enc_channels[2], 1, kernel_size=3, stride=2, padding=1, with_ibn=False,
                                        with_relu=False)

    def forward(self, img, inference):
        enc_features = self.backbone.forward(img)
        enc2x, enc4x, enc32x = enc_features[0], enc_features[1], enc_features[4]

        enc32x = self.se_block(enc32x)
        lr16x = F.interpolate(enc32x, scale_factor=2, mode='bilinear', align_corners=False)
        lr16x = self.conv_lr16x(lr16x)

        lr8x = F.interpolate(lr16x, scale_factor=2, mode='bilinear', align_corners=False)
        lr8x = self.conv_lr8x(lr8x)
        pred_semantic = None
        if not inference:
            lr = self.conv_lr(lr8x)
            pred_semantic = torch.sigmoid(lr)
        return pred_semantic, lr8x, [enc2x, enc4x]


class HRBranch(nn.Module):
    """ High Resolution Branch of MODNet
    """

    def __init__(self, hr_channels, enc_channels):
        super(HRBranch, self).__init__()

        self.tohr_enc2x = Conv2dIBNormRelu(enc_channels[0], hr_channels, 1, stride=1, padding=0)
        self.conv_enc2x = Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=2, padding=1)

        self.tohr_enc4x = Conv2dIBNormRelu(enc_channels[1], hr_channels, 1, stride=1, padding=0)
        self.conv_enc4x = Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1)

        self.conv_hr4x = nn.Sequential(
            Conv2dIBNormRelu(3 * hr_channels + 3, 2 * hr_channels, 3, stride=1, padding=1),
            Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1),
            Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1),
        )

        self.conv_hr2x = nn.Sequential(
            Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1),
            Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1),
            Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1),
            Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1),
        )

        self.conv_hr = nn.Sequential(
            Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=1, padding=1),
            Conv2dIBNormRelu(hr_channels, 1, kernel_size=1, stride=1, padding=0, with_ibn=False, with_relu=False),
        )

    def forward(self, img, enc2x, enc4x, lr8x, inference):
        img2x = F.interpolate(img, scale_factor=1 / 2, mode='bilinear', align_corners=False)
        img4x = F.interpolate(img, scale_factor=1 / 4, mode='bilinear', align_corners=False)

        enc2x = self.tohr_enc2x(enc2x)
        hr4x = self.conv_enc2x(torch.cat((img2x, enc2x), dim=1))

        enc4x = self.tohr_enc4x(enc4x)
        hr4x = self.conv_enc4x(torch.cat((hr4x, enc4x), dim=1))
        lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False)
        hr4x = self.conv_hr4x(torch.cat((hr4x, lr4x, img4x), dim=1))

        hr2x = F.interpolate(hr4x, scale_factor=2, mode='bilinear', align_corners=False)
        hr2x = self.conv_hr2x(torch.cat((hr2x, enc2x), dim=1))

        pred_detail = None
        if not inference:
            hr = F.interpolate(hr2x, scale_factor=2, mode='bilinear', align_corners=False)
            hr = self.conv_hr(torch.cat((hr, img), dim=1))
            pred_detail = torch.sigmoid(hr)

        return pred_detail, hr2x


class FusionBranch(nn.Module):
    """ Fusion Branch of MODNet
    """

    def __init__(self, hr_channels, enc_channels):
        super(FusionBranch, self).__init__()
        self.conv_lr4x = Conv2dIBNormRelu(enc_channels[2], hr_channels, 5, stride=1, padding=2)

        self.conv_f2x = Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1)
        self.conv_f = nn.Sequential(
            Conv2dIBNormRelu(hr_channels + 3, int(hr_channels / 2), 3, stride=1, padding=1),
            Conv2dIBNormRelu(int(hr_channels / 2), 1, 1, stride=1, padding=0, with_ibn=False, with_relu=False),
        )

    def forward(self, img, lr8x, hr2x):
        lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False)
        lr4x = self.conv_lr4x(lr4x)
        lr2x = F.interpolate(lr4x, scale_factor=2, mode='bilinear', align_corners=False)

        f2x = self.conv_f2x(torch.cat((lr2x, hr2x), dim=1))
        f = F.interpolate(f2x, scale_factor=2, mode='bilinear', align_corners=False)
        f = self.conv_f(torch.cat((f, img), dim=1))
        pred_matte = torch.sigmoid(f)

        return pred_matte


class MODNet(nn.Module):
    """ Architecture of MODNet
    """

    def __init__(self, in_channels=3, hr_channels=32, backbone_arch='mobilenetv2', backbone_pretrained=True):
        super(MODNet, self).__init__()

        self.in_channels = in_channels
        self.hr_channels = hr_channels
        self.backbone_arch = backbone_arch
        self.backbone_pretrained = backbone_pretrained

        self.backbone = MobileNetV2Backbone(self.in_channels)

        self.lr_branch = LRBranch(self.backbone)
        self.hr_branch = HRBranch(self.hr_channels, self.backbone.enc_channels)
        self.f_branch = FusionBranch(self.hr_channels, self.backbone.enc_channels)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                self._init_conv(m)
            elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.InstanceNorm2d):
                self._init_norm(m)

        if self.backbone_pretrained:
            self.backbone.load_pretrained_ckpt()

    def forward(self, img, inference):
        pred_semantic, lr8x, [enc2x, enc4x] = self.lr_branch(img, inference)

        pred_detail, hr2x = self.hr_branch(img, enc2x, enc4x, lr8x, inference)
        pred_matte = self.f_branch(img, lr8x, hr2x)

        return pred_semantic, pred_detail, pred_matte

    def freeze_norm(self):
        norm_types = [nn.BatchNorm2d, nn.InstanceNorm2d]
        for m in self.modules():
            for n in norm_types:
                if isinstance(m, n):
                    m.eval()
                    continue

    def _init_conv(self, conv):
        nn.init.kaiming_uniform_(
            conv.weight, a=0, mode='fan_in', nonlinearity='relu')
        if conv.bias is not None:
            nn.init.constant_(conv.bias, 0)

    def _init_norm(self, norm):
        if norm.weight is not None:
            nn.init.constant_(norm.weight, 1)
            nn.init.constant_(norm.bias, 0)


class GaussianBlurLayer(nn.Module):
    """ Add Gaussian Blur to a 4D tensors
    This layer takes a 4D tensor of {N, C, H, W} as input.
    The Gaussian blur will be performed in given channel number (C) splitly.
    """

    def __init__(self, channels, kernel_size):
        """ 
        Arguments:
            channels (int): Channel for input tensor
            kernel_size (int): Size of the kernel used in blurring
        """

        super(GaussianBlurLayer, self).__init__()
        self.channels = channels
        self.kernel_size = kernel_size
        assert self.kernel_size % 2 != 0

        self.op = nn.Sequential(
            nn.ReflectionPad2d(math.floor(self.kernel_size / 2)),
            nn.Conv2d(channels, channels, self.kernel_size,
                      stride=1, padding=0, bias=None, groups=channels)
        )

        self._init_kernel()

    def forward(self, x):
        """
        Arguments:
            x (torch.Tensor): input 4D tensor
        Returns:
            torch.Tensor: Blurred version of the input 
        """
        if not len(list(x.shape)) == 4:
            print('\'GaussianBlurLayer\' requires a 4D tensor as input\n')
            exit()
        elif not x.shape[1] == self.channels:
            print('In \'GaussianBlurLayer\', the required channel ({0}) is'
                  'not the same as input ({1})\n'.format(self.channels, x.shape[1]))
            exit()

        return self.op(x)

    def _init_kernel(self):
        sigma = 0.3 * ((self.kernel_size - 1) * 0.5 - 1) + 0.8

        n = np.zeros((self.kernel_size, self.kernel_size))
        i = math.floor(self.kernel_size / 2)
        n[i, i] = 1
        kernel = scipy.ndimage.gaussian_filter(n, sigma)

        for name, param in self.named_parameters():
            param.data.copy_(torch.from_numpy(kernel))


class ImagesDataset(Dataset):
    def __init__(self, root, transforms=None, w=960, h=544):
        self.root = root
        self.transforms = transforms
        self.w = w
        self.h = h
        self.imgs = sorted(os.listdir(os.path.join(self.root, 'image')))
        self.alphas = sorted(os.listdir(os.path.join(self.root, 'alpha')))
        assert len(self.imgs) == len(self.alphas), 'the number of images and alphas differs, please check it.'

    def get_trimap(self, alpha):
        # alpha \in [0, 1] should be taken into account
        # be careful when dealing with regions of alpha=0 and alpha=1
        fg = np.array(np.equal(alpha, 255).astype(np.float32))
        unknown = np.array(np.not_equal(alpha, 0).astype(np.float32)) # unknown = alpha > 0
        unknown = unknown - fg
        # image dilation implemented by Euclidean distance transform
        unknown = morphology.distance_transform_edt(unknown==0) <= np.random.randint(1, 20)
        trimap = fg * 255
        trimap[unknown] = 128
        return trimap.astype(np.uint8)


    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        img = cv2.imread(os.path.join(self.root, 'image', self.imgs[idx]))
        alpha = cv2.imread(os.path.join(self.root, 'alpha', self.alphas[idx]))
        trimap = self.get_trimap(alpha)
        # cv2.imshow('trimap', trimap)
        # cv2.waitKey(0)
        h, w, c = img.shape
        if not (w == self.w and h == self.h):
            img = cv2.resize(img, (self.w, self.h))
            trimap = cv2.resize(trimap, (self.w, self.h))
            alpha = cv2.resize(alpha, (self.w, self.h))
        if self.transforms:
            img = self.transforms(img)
            trimap = self.transforms(trimap)
            alpha = self.transforms(alpha)
        return img, trimap, alpha





blurer = GaussianBlurLayer(3, 3).cuda()


def supervised_training_iter(
        modnet, optimizer, image, trimap, gt_matte,
        semantic_scale=10.0, detail_scale=10.0, matte_scale=1.0):
    """ Supervised training iteration of MODNet
    This function trains MODNet for one iteration in a labeled dataset.

    Arguments:
        modnet (torch.nn.Module): instance of MODNet
        optimizer (torch.optim.Optimizer): optimizer for supervised training 
        image (torch.autograd.Variable): input RGB image
        trimap (torch.autograd.Variable): trimap used to calculate the losses
                                          NOTE: foreground=1, background=0, unknown=0.5
        gt_matte (torch.autograd.Variable): ground truth alpha matte
        semantic_scale (float): scale of the semantic loss
                                NOTE: please adjust according to your dataset
        detail_scale (float): scale of the detail loss
                              NOTE: please adjust according to your dataset
        matte_scale (float): scale of the matte loss
                             NOTE: please adjust according to your dataset
    
    Returns:
        semantic_loss (torch.Tensor): loss of the semantic estimation [Low-Resolution (LR) Branch]
        detail_loss (torch.Tensor): loss of the detail prediction [High-Resolution (HR) Branch]
        matte_loss (torch.Tensor): loss of the semantic-detail fusion [Fusion Branch]

    Example:
        import torch
        from src.models.modnet import MODNet
        from src.trainer import supervised_training_iter

        bs = 16         # batch size
        lr = 0.01       # learn rate
        epochs = 40     # total epochs

        modnet = torch.nn.DataParallel(MODNet()).cuda()
        optimizer = torch.optim.SGD(modnet.parameters(), lr=lr, momentum=0.9)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(0.25 * epochs), gamma=0.1)

        dataloader = CREATE_YOUR_DATALOADER(bs)     # NOTE: please finish this function

        for epoch in range(0, epochs):
            for idx, (image, trimap, gt_matte) in enumerate(dataloader):
                semantic_loss, detail_loss, matte_loss = \
                    supervised_training_iter(modnet, optimizer, image, trimap, gt_matte)
            lr_scheduler.step()
    """

    global blurer

    # set the model to train mode and clear the optimizer
    modnet.train()
    optimizer.zero_grad()

    # forward the model
    pred_semantic, pred_detail, pred_matte = modnet(image, False)

    # calculate the boundary mask from the trimap
    boundaries = (trimap < 0.5) + (trimap > 0.5)

    # calculate the semantic loss
    gt_semantic = F.interpolate(gt_matte, scale_factor=1 / 16, mode='bilinear')
    gt_semantic = blurer(gt_semantic)
    semantic_loss = torch.mean(F.mse_loss(pred_semantic, gt_semantic))
    semantic_loss = semantic_scale * semantic_loss

    # calculate the detail loss
    pred_boundary_detail = torch.where(boundaries, trimap, pred_detail)
    gt_detail = torch.where(boundaries, trimap, gt_matte)
    detail_loss = torch.mean(F.l1_loss(pred_boundary_detail, gt_detail))
    detail_loss = detail_scale * detail_loss

    # calculate the matte loss
    pred_boundary_matte = torch.where(boundaries, trimap, pred_matte)
    matte_l1_loss = F.l1_loss(pred_matte, gt_matte) + 4.0 * F.l1_loss(pred_boundary_matte, gt_matte)
    matte_compositional_loss = F.l1_loss(image * pred_matte, image * gt_matte) \
                               + 4.0 * F.l1_loss(image * pred_boundary_matte, image * gt_matte)
    matte_loss = torch.mean(matte_l1_loss + matte_compositional_loss)
    matte_loss = matte_scale * matte_loss

    # calculate the final loss, backward the loss, and update the model 
    loss = semantic_loss + detail_loss + matte_loss
    loss.backward()
    optimizer.step()

    # for test
    return semantic_loss, detail_loss, matte_loss


def soc_adaptation_iter(
        modnet, backup_modnet, optimizer, image,
        soc_semantic_scale=100.0, soc_detail_scale=1.0):
    """ Self-Supervised sub-objective consistency (SOC) adaptation iteration of MODNet
    This function fine-tunes MODNet for one iteration in an unlabeled dataset.
    Note that SOC can only fine-tune a converged MODNet, i.e., MODNet that has been 
    trained in a labeled dataset.

    Arguments:
        modnet (torch.nn.Module): instance of MODNet
        backup_modnet (torch.nn.Module): backup of the trained MODNet
        optimizer (torch.optim.Optimizer): optimizer for self-supervised SOC 
        image (torch.autograd.Variable): input RGB image
        soc_semantic_scale (float): scale of the SOC semantic loss 
                                    NOTE: please adjust according to your dataset
        soc_detail_scale (float): scale of the SOC detail loss
                                  NOTE: please adjust according to your dataset
    
    Returns:
        soc_semantic_loss (torch.Tensor): loss of the semantic SOC
        soc_detail_loss (torch.Tensor): loss of the detail SOC

    Example:
        import copy
        import torch
        from src.models.modnet import MODNet
        from src.trainer import soc_adaptation_iter

        bs = 1          # batch size
        lr = 0.00001    # learn rate
        epochs = 10     # total epochs

        modnet = torch.nn.DataParallel(MODNet()).cuda()
        modnet = LOAD_TRAINED_CKPT()    # NOTE: please finish this function

        optimizer = torch.optim.Adam(modnet.parameters(), lr=lr, betas=(0.9, 0.99))
        dataloader = CREATE_YOUR_DATALOADER(bs)     # NOTE: please finish this function

        for epoch in range(0, epochs):
            backup_modnet = copy.deepcopy(modnet)
            for idx, (image) in enumerate(dataloader):
                soc_semantic_loss, soc_detail_loss = \
                    soc_adaptation_iter(modnet, backup_modnet, optimizer, image)
    """

    global blurer

    # set the backup model to eval mode
    backup_modnet.eval()

    # set the main model to train mode and freeze its norm layers
    modnet.train()
    modnet.module.freeze_norm()

    # clear the optimizer
    optimizer.zero_grad()

    # forward the main model
    pred_semantic, pred_detail, pred_matte = modnet(image, False)

    # forward the backup model
    with torch.no_grad():
        _, pred_backup_detail, pred_backup_matte = backup_modnet(image, False)

    # calculate the boundary mask from `pred_matte` and `pred_semantic`
    pred_matte_fg = (pred_matte.detach() > 0.1).float()
    pred_semantic_fg = (pred_semantic.detach() > 0.1).float()
    pred_semantic_fg = F.interpolate(pred_semantic_fg, scale_factor=16, mode='bilinear')
    pred_fg = pred_matte_fg * pred_semantic_fg

    n, c, h, w = pred_matte.shape
    np_pred_fg = pred_fg.data.cpu().numpy()
    np_boundaries = np.zeros([n, c, h, w])
    for sdx in range(0, n):
        sample_np_boundaries = np_boundaries[sdx, 0, ...]
        sample_np_pred_fg = np_pred_fg[sdx, 0, ...]

        side = int((h + w) / 2 * 0.05)
        dilated = grey_dilation(sample_np_pred_fg, size=(side, side))
        eroded = grey_erosion(sample_np_pred_fg, size=(side, side))

        sample_np_boundaries[np.where(dilated - eroded != 0)] = 1
        np_boundaries[sdx, 0, ...] = sample_np_boundaries

    boundaries = torch.tensor(np_boundaries).float().cuda()

    # sub-objectives consistency between `pred_semantic` and `pred_matte`
    # generate pseudo ground truth for `pred_semantic`
    downsampled_pred_matte = blurer(F.interpolate(pred_matte, scale_factor=1 / 16, mode='bilinear'))
    pseudo_gt_semantic = downsampled_pred_matte.detach()
    pseudo_gt_semantic = pseudo_gt_semantic * (pseudo_gt_semantic > 0.01).float()

    # generate pseudo ground truth for `pred_matte`
    pseudo_gt_matte = pred_semantic.detach()
    pseudo_gt_matte = pseudo_gt_matte * (pseudo_gt_matte > 0.01).float()

    # calculate the SOC semantic loss
    soc_semantic_loss = F.mse_loss(pred_semantic, pseudo_gt_semantic) + F.mse_loss(downsampled_pred_matte,
                                                                                   pseudo_gt_matte)
    soc_semantic_loss = soc_semantic_scale * torch.mean(soc_semantic_loss)

    # NOTE: using the formulas in our paper to calculate the following losses has similar results
    # sub-objectives consistency between `pred_detail` and `pred_backup_detail` (on boundaries only)
    backup_detail_loss = boundaries * F.l1_loss(pred_detail, pred_backup_detail)
    backup_detail_loss = torch.sum(backup_detail_loss, dim=(1, 2, 3)) / torch.sum(boundaries, dim=(1, 2, 3))
    backup_detail_loss = torch.mean(backup_detail_loss)

    # sub-objectives consistency between pred_matte` and `pred_backup_matte` (on boundaries only)
    backup_matte_loss = boundaries * F.l1_loss(pred_matte, pred_backup_matte)
    backup_matte_loss = torch.sum(backup_matte_loss, dim=(1, 2, 3)) / torch.sum(boundaries, dim=(1, 2, 3))
    backup_matte_loss = torch.mean(backup_matte_loss)

    soc_detail_loss = soc_detail_scale * (backup_detail_loss + backup_matte_loss)

    # calculate the final loss, backward the loss, and update the model 
    loss = soc_semantic_loss + soc_detail_loss

    loss.backward()
    optimizer.step()

    return soc_semantic_loss, soc_detail_loss


# ----------------------------------------------------------------------------------


def main(root):
    modnet = MODNet(backbone_pretrained=False)
    modnet = nn.DataParallel(modnet)
    GPU = True if torch.cuda.device_count() > 0 else False
    if GPU:
        print('Use GPU...')
        modnet = modnet.cuda()
        modnet.load_state_dict(torch.load(pretrained_ckpt))
    else:
        print('Use CPU...')
        modnet.load_state_dict(torch.load(pretrained_ckpt, map_location=torch.device('cpu')))
    modnet.eval()
    bs = 1  # batch size
    lr = 0.01  # learn rate
    epochs = 40  # total epochs
    num_workers = 8
    optimizer = torch.optim.SGD(modnet.parameters(), lr=lr, momentum=0.9)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(0.25 * epochs), gamma=0.1)

    # dataloader = CREATE_YOUR_DATALOADER(bs)  # NOTE: please finish this function
    dataset = ImagesDataset(root)
    dataloader = DataLoader(dataset, batch_size=bs, num_workers=num_workers, pin_memory=True)

    for epoch in range(epochs):
        for idx, (image, trimap, gt_matte) in enumerate(dataloader):
            image = np.transpose(image, (0, 3, 1, 2)).float().cuda()
            trimap = np.transpose(trimap, (0, 3, 1, 2)).float().cuda()
            gt_matte = np.transpose(gt_matte, (0, 3, 1, 2)).float().cuda()
            semantic_loss, detail_loss, matte_loss = supervised_training_iter(modnet, optimizer, image, trimap, gt_matte)
            print(f"epoch: {epoch+1}/{epochs} semantic_loss: {semantic_loss}, detail_loss: {detail_loss}, matte_loss: {matte_loss}")
        lr_scheduler.step()


if __name__ == '__main__':
    path = 'MODNet/dataset'
    main(path)

The loss output:
epoch: 21/40 semantic_loss: 267474.09375, detail_loss: 0.0, matte_loss: 11912.3251953125
epoch: 21/40 semantic_loss: 152139.375, detail_loss: 0.0, matte_loss: 9059.408203125
epoch: 21/40 semantic_loss: 314421.375, detail_loss: 0.0, matte_loss: 10369.087890625

My GPU: 2080Ti, 11 GB.

What is going wrong?

@luoww1992 luoww1992 reopened this Feb 1, 2021
@luoww1992
Author

luoww1992 commented Feb 1, 2021

@ZHKKKe,
Is the size of my video matting training images too large? Or am I missing some steps when making the dataset?

@ZHKKKe
Owner

ZHKKKe commented Feb 1, 2021

@luoww1992 Hi, thanks for your attention.

For your questions:

Q1: the loss is large
You need to normalize the ground truth to [0, 1].
Please add transforms to ImagesDataset.

Q2: the detail_loss is 0
The pixel values in the loaded trimap should be 0=background, 0.5=unknown, or 1=foreground. Please pre-process it in your dataset.

You can refer to the latest comments for more information:

        image (torch.autograd.Variable): input RGB image
                                         its pixel values should be normalized
        trimap (torch.autograd.Variable): trimap used to calculate the losses
                                          its pixel values can be 0, 0.5, or 1
                                          (foreground=1, background=0, unknown=0.5)
        gt_matte (torch.autograd.Variable): ground truth alpha matte
                                            its pixel values are between [0, 1]
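
Based on the reply above, a minimal sketch of the missing preprocessing, to be applied at the end of ImagesDataset.__getitem__ from the train.py posted earlier (the [-1, 1] image normalization is an assumption; the trimap values {0, 128, 255} follow the get_trimap in that code):

import numpy as np
import torch

def to_training_tensors(img, trimap, alpha):
    # img: HxWx3 uint8, trimap: HxW uint8 in {0, 128, 255}, alpha: HxW uint8 in [0, 255]
    img = torch.from_numpy(img.astype(np.float32).transpose(2, 0, 1) / 255.0)  # 3xHxW in [0, 1]
    img = (img - 0.5) / 0.5                                                     # optional: scale to [-1, 1]

    trimap = trimap.astype(np.float32) / 255.0
    trimap[(trimap > 0) & (trimap < 1)] = 0.5               # force the unknown region to exactly 0.5
    trimap = torch.from_numpy(trimap)[None, ...]            # 1xHxW with values in {0, 0.5, 1}

    alpha = torch.from_numpy(alpha.astype(np.float32) / 255.0)[None, ...]       # 1xHxW in [0, 1]
    return img, trimap, alpha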

@luoww1992
Author

@ZHKKKe
I am training it, and got some warnings:

UserWarning: Using a target size (torch.Size([4, 3, 36, 64])) that is different to the input size (torch.Size([4, 1, 36, 64])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
  semantic_loss = torch.mean(F.mse_loss(pred_semantic, gt_semantic))
UserWarning: Using a target size (torch.Size([4, 3, 576, 1024])) that is different to the input size (torch.Size([4, 1, 576, 1024])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
  matte_l1_loss = F.l1_loss(pred_matte, gt_matte) + 4.0 * F.l1_loss(pred_boundary_matte, gt_matte)

My input image size is 1024*576.
Will it cause errors in training?

@ZHKKKe
Copy link
Owner

ZHKKKe commented Feb 2, 2021

Please change the number of channels of gt_matte to 1.
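
A minimal sketch of that change in ImagesDataset.__getitem__, assuming the alpha is loaded with cv2 as in the train.py posted above:

# read the matte as a single-channel image instead of 3 channels
alpha = cv2.imread(os.path.join(self.root, 'alpha', self.alphas[idx]), cv2.IMREAD_GRAYSCALE)
# or, if it has already been loaded with 3 channels:
# alpha = alpha[:, :, 0:1]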

@Johnson-yue

@luoww1992 Did you train it successfully?

@Shk-aftab

@luoww1992 were you able to train it?

@luoww1992
Author

I have made the training dataset with 80k images and I am training it. I have run it for 4 days, but it is too slow.

@Shk-aftab

@luoww1992 that's great ... can you share the dataset?

@luoww1992
Author

@luoww1992 that's great ... can you share the dataset?

Sorry, I cannot share it before I get authorization;
there are some personal images.
You can try to make it from alpha images.

@czHappy

czHappy commented Feb 23, 2021

@luoww1992 that's great ... can you share the dataset?

Sorry, I cannot share it before I get authorization;
there are some personal images.
You can try to make it from alpha images.

Could you share the complete code that can be trained correctly? And I would appreciate it if you could send me a few training samples(10-20 is enough). My email is wyking9@163.com, thanks a lot!

@Shk-aftab

Hey @luoww1992 just share the complete trainable code.

@luoww1992 that's great ... can you share the dataset?

Sorry, I cannot share it before I get authorization;
there are some personal images.
You can try to make it from alpha images.

Could you share the complete code that can be trained correctly? And I would appreciate it if you could send me a few training samples(10-20 is enough). My email is wyking9@163.com, thanks a lot!

@luoww1992 for me too shaikhaftab139@gmail.com, thanks

@luoww1992
Author

luoww1992 commented Feb 24, 2021

trainDefault.txt

I have updated it; please pay attention to the notes in the functions.

@andy910389

@luoww1992 that's great ... can you share the dataset?

Sorry, I cannot share it before I get authorization;
there are some personal images.
You can try to make it from alpha images.

Please share with me the complete code that can be trained correctly, many thanks!
mail: andy910389@gmail.com

@luoww1992
Author

luoww1992 commented Feb 25, 2021 via email

@czHappy

czHappy commented Feb 25, 2021

I have added it in trainDefault.txt.

Thank you very much! I will read the code carefully, hope it will work!

@luoww1992
Author

Now I am running two types of models:
1. default training -- I use the same input size and the model the author provided.
2. my own training -- I changed the size to train a new model from scratch.
It is very slow to train.
I am waiting to finish training and then will start the SOC step.

@ZHKKKe ZHKKKe closed this as completed Feb 25, 2021
@ZHKKKe ZHKKKe reopened this Feb 25, 2021
@ZHKKKe
Owner

ZHKKKe commented Feb 25, 2021

@luoww1992
FYI, I trained the model on a single GPU with batch-size=8 and input-size=512. The total training time is about 2~3 days on a dataset that contains 100k samples.

@luoww1992
Author

@luoww1992
FYI, I trained the model on a single GPU with batch-size=8 and input-size=512. The total training time is about 2~3 days on a dataset that contains 100k samples.

I am using batch-size=4, an 11 GB GPU, and 80k images; it will take more time.
When you finish training, how far does the matte_loss go down?

@ZHKKKe
Owner

ZHKKKe commented Feb 26, 2021

@luoww1992
I got the average training matte_loss=0.0364 in the last training epoch. However, the loss value should be different depending on the dataset.

@luoww1992
Author

luoww1992 commented Mar 1, 2021

@ZHKKKe
I will run the SOC step. I notice we only need images to make the dataset, without alpha images.
Don't we need a metric to show us the model is better when there are no alpha images to compare against?
How many images did you use to run SOC?

@FraPochetti

@ZHKKKe thanks a lot!

@twin-92

twin-92 commented Apr 13, 2021

@luoww1992: can you share the trainDefault.txt code with SOC? Thanks a lot!

@dzyjjpy

dzyjjpy commented Apr 13, 2021

@ZHKKKe
I am training it, and got some warnings:

UserWarning: Using a target size (torch.Size([4, 3, 36, 64])) that is different to the input size (torch.Size([4, 1, 36, 64])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
  semantic_loss = torch.mean(F.mse_loss(pred_semantic, gt_semantic))
UserWarning: Using a target size (torch.Size([4, 3, 576, 1024])) that is different to the input size (torch.Size([4, 1, 576, 1024])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
  matte_l1_loss = F.l1_loss(pred_matte, gt_matte) + 4.0 * F.l1_loss(pred_boundary_matte, gt_matte)

My input image size is 1024*576.

Will it cause errors in training?

Can it be used for other tasks, such as sky segmentation? I want to segment the trees more accurately.
I trained MODNet successfully and all three losses on my dataset are smaller than 0.02, as the picture shows. However, the inference result looks really bad.
image
image

@luoww1992 @ZHKKKe Could you pls give me some advice?

@ZHKKKe
Owner

ZHKKKe commented Apr 14, 2021

@dzyjjpy
For your questions:
Q1: will it cause some error in training ?
It seems that you use a ground truth with three channels for training. However, the outputs of the model have only one channel.
Please check your ground truth.

Q2: However, the inference result looks really bad.
Can you share your ground truth for training?

@dzyjjpy

dzyjjpy commented Apr 14, 2021

@dzyjjpy
For your questions:
Q1: will it cause some error in training ?
It seems that you use a ground truth with three channels for training. However, the outputs of the model have only one channel.
Please check your ground truth.
The ground truth has only one channel

Q2: However, the inference result looks really bad.
Can you share your ground truth for training?
As the png file I attach.
0007

@twin-92

twin-92 commented Apr 14, 2021

@ZHKKKe can you explain 2 questions for me:

  1. Why do you use an image input where one side is 512 and the larger side is divisible by 32, rather than a fixed 512x512 input? (I don't see a resize to 512x512 in your inference.py.)
  2. You create the alpha matting dataset with Photoshop. Do you create a binary mask of the person in the image (only 2 values: 0 and 255), save this mask, and use it as the labeled alpha matte of this image?
    Thank you very much!

@Boya-Na

Boya-Na commented Apr 15, 2021

@ZHKKKe

Hi, thanks for your sharing. I have a question: my detail loss is still 0, although I have made the trimap values 0, 0.5, or 1 and the gt_matte values between [0, 1]. Which point should I check? Thanks a lot!

@Boya-Na

Boya-Na commented Apr 15, 2021

@dzyjjpy
I guess the ground truth may need to be inverted if you want the sky to be the background; I mean the value of the sky should be zero and everything else 1. Another point is that the area outside the fisheye circle may affect feature extraction. In addition, the backbone is set up for human segmentation, and I also suggest looking at Background Matting V2 for your problem (for example, use a single color sampled from the sky as the background image). These are just my opinions and may not work for your problem, but I hope they are useful. Thanks.

@Boya-Na

Boya-Na commented Apr 15, 2021

@ZHKKKe Thanks, I have solved it. It was my error: I generated the trimap with 128/255, which is not equal to 0.5.

@upperblacksmith

upperblacksmith commented Apr 22, 2021

@luoww1992

I have finished all steps;
it works well on many images when testing,
but there is flicker and jitter at the matting edges in some images.
Can we do something to reduce it?
Such as:
use high-resolution images to train MODNet;
change the color space;
change some args in training or SOC.

Bro, when I train with SOC, the loss is always NaN. Could I take a look at your SOC code?

@upperblacksmith

I have added it in trainDefault.txt.

Thank you very much! I will read the code carefully, hope it will work!

Could you share the complete code that can be trained correctly? And I would appreciate it. My email is 755976168@qq.com, thanks a lot!

@luoww1992
Author

Look up; there is a link to trainDefault.txt.
Or you can search for it on the current page with 'Ctrl+F'.

@upperblacksmith

@luoww1992

Look up; there is a link to trainDefault.txt.
Or you can search for it on the current page with 'Ctrl+F'.
That's right, I found it earlier.
So, would you mind telling me the image shape in your training dataset when you tried the SOC step?

@upperblacksmith

@luoww1992
When you use the SOC step, which model's parameters do you update: modnet or backup_modnet?
666

@luoww1992
Author

@upperblacksmith, I saved two models at first and then tested them; I can't find a difference between the models.

@luoww1992
Author

@upperblacksmith ,
My dataset size is 2048*1152; when training, the size is halved.
As for the loss,
maybe you are missing some steps:
Q1: You need to normalize the ground truth to [0, 1].
Please add transforms to ImagesDataset.

Q2: the detail_loss is 0
The pixel values in the loaded trimap should be 0=background, 0.5=unknown, or 1=foreground. Please pre-process it in your dataset.

You can refer to the latest comments for more information:

    image (torch.autograd.Variable): input RGB image
                                     its pixel values should be normalized
    trimap (torch.autograd.Variable): trimap used to calculate the losses
                                      its pixel values can be 0, 0.5, or 1
                                      (foreground=1, background=0, unknown=0.5)
    gt_matte (torch.autograd.Variable): ground truth alpha matte
                                        its pixel values are between [0, 1]

@luoww1992
Author

luoww1992 commented Apr 26, 2021

@upperblacksmith ,
add this code to the training code:

run.txt

@upperblacksmith

@ZHKKKe
What do you think of the code in the picture for the SOC step? I can't understand how to initialize modnet and backup_modnet. I would appreciate your advice.
11111

@upperblacksmith

@ZHKKKe
I have finished all steps;
it works well on many images when testing,
but there is flicker and jitter at the matting edges in some images.
Can we do something to reduce it?
Such as:
use high-resolution images to train MODNet;
change the color space;
change some args in training or SOC.
@luoww1992 do you have any ideas to solve it right now?

@ZHKKKe
Owner

ZHKKKe commented Apr 29, 2021

@upperblacksmith
For your SOC problem, the initialization of SOC is the trained model. backup_modnet means you copy the trained model and fix its weights.
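
A minimal sketch of that initialization, following the soc_adaptation_iter docstring quoted in the code above (the checkpoint path and the unlabeled-image dataloader are placeholders, not repo functions):

import copy
import torch

modnet = torch.nn.DataParallel(MODNet()).cuda()
modnet.load_state_dict(torch.load('your_trained_modnet.ckpt'))   # the converged supervised model
optimizer = torch.optim.Adam(modnet.parameters(), lr=0.00001, betas=(0.9, 0.99))

epochs = 10
for epoch in range(epochs):
    backup_modnet = copy.deepcopy(modnet)      # frozen snapshot; only modnet gets updated
    for idx, image in enumerate(dataloader):   # dataloader yields unlabeled RGB images
        soc_semantic_loss, soc_detail_loss = \
            soc_adaptation_iter(modnet, backup_modnet, optimizer, image.cuda())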

@huang5656151

@dzyjjpy
For your questions:
Q1: will it cause some error in training ?
It seems that you use a ground truth with three channels for training. However, the outputs of the model have only one channel.
Please check your ground truth.
The ground truth has only one channel

Q2: However, the inference result looks really bad.
Can you share your ground truth for training?
As the png file I attach.
0007

@dzyjjpy May I ask whether the segmentation of the sky and tree branches has improved?

@ntquyen11

@luoww1992 Does the alpha image in your dataset mean the right or the left figure?

image

@luoww1992
Author

luoww1992 commented May 31, 2021 via email

@ZHKKKe ZHKKKe closed this as completed Jun 21, 2021
@tengshaofeng

@FraPochetti Hi, for your questions:
1. when you say 100k samples, do you mean 100k distinct images + related mattes? The 100k samples are generated from 3k hand-annotated foregrounds by compositing each foreground with about 30 different backgrounds.
2. I understand from the paper that you have 3k hand-annotated images. Do you get to 100k by pasting the extracted foregrounds onto a randomly chosen set of backgrounds (as most other papers do)? Yes, you are correct.
3. by input-size=512, do you mean that you rescale the entire image to 512x512, or do you take random 512x512 patches of it (applicable only if your image has some resolution > 512, of course)? We composited the training set with images of size 512x512 directly. During composition, we first rescale each foreground to a size between 384~768 randomly (if the rescaled size > 512, cropping is required). We then composite the rescaled foreground with the backgrounds.
4. do you use any specific augmentation, other than hflip, color jittering, etc? Nope. In our case, adding Gaussian noise or Gaussian blur decreases the performance.

@ZHKKKe For question 3: if I rescale the foreground to 512x640, I random crop from that and then composite it with a center crop of one background image; if the foreground is rescaled to 512x400, I paste it at the center of the center crop of one background image. Is that right? Thanks.

@ZHKKKe
Owner

ZHKKKe commented Apr 27, 2022

@tengshaofeng
Not at the center of the center crop of one background. I think it is pasted _anywhere_ within the center crop of one background.
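
A rough sketch of how I read that compositing recipe (rescale the foreground to 384~768, crop if needed, then paste it anywhere inside a 512x512 center crop of the background); all names here are illustrative, not from the repo:

import random
import cv2
import numpy as np

def composite_sample(fg, alpha, bg, out=512):
    # fg: HxWx3 foreground, alpha: HxW matte in [0, 1], bg: background image (both sides >= out)
    # 1. rescale the foreground (and its matte) so the longer side is a random size in [384, 768]
    s = random.randint(384, 768)
    h, w = fg.shape[:2]
    scale = s / max(h, w)
    fg = cv2.resize(fg, (int(w * scale), int(h * scale)))
    alpha = cv2.resize(alpha, (int(w * scale), int(h * scale)))
    # 2. if the rescaled foreground is larger than the output size, crop it
    fg, alpha = fg[:out, :out], alpha[:out, :out]
    # 3. take an out x out center crop of the background
    bh, bw = bg.shape[:2]
    y0, x0 = (bh - out) // 2, (bw - out) // 2
    canvas = bg[y0:y0 + out, x0:x0 + out].astype(np.float32).copy()
    # 4. paste the foreground at a random position inside the crop and build the ground-truth matte
    fh, fw = fg.shape[:2]
    py, px = random.randint(0, out - fh), random.randint(0, out - fw)
    a = alpha[..., None].astype(np.float32)
    canvas[py:py + fh, px:px + fw] = a * fg + (1.0 - a) * canvas[py:py + fh, px:px + fw]
    gt = np.zeros((out, out), dtype=np.float32)
    gt[py:py + fh, px:px + fw] = alpha
    return canvas.astype(np.uint8), gt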

@tengshaofeng

@tengshaofeng Not at the center of the center crop of one background. I think it is pasted _anywhere_ within the center crop of one background.

Thanks for your reply.

@shizidushu

@tengshaofeng Could you share your dataloader code?

@tengshaofeng

tengshaofeng commented Jul 13, 2022

@tengshaofeng Could you share your dataloader code?

# imports assumed by this snippet (they were not included in the original comment)
import glob
import os
import random
from random import randint

import cv2
import numpy as np
from PIL import Image
from scipy.ndimage import grey_dilation, grey_erosion
from scipy.ndimage import morphology
from torch.utils.data import Dataset
from torchvision import transforms


class ImagesDataset(Dataset):

    def __init__(self, root, transform=None, ref_size=512):
        self.root = root
        self.transform = transform
        self.tensor = transforms.Compose([transforms.ToTensor()])
        self.ref_size = ref_size
        self.alphas = []
        self.alphas += sorted(glob.glob(self.root + '/alphas_v1/*'))
        print('total imgs:', len(self.alphas))

    def getTrimap(self, alpha):
        fg = np.array(np.equal(alpha, 255).astype(np.float32))
        unknown = np.array(np.not_equal(alpha, 0).astype(np.float32))  # unknown = alpha > 0
        unknown = unknown - fg
        unknown = morphology.distance_transform_edt(unknown == 0) <= np.random.randint(1, 20)
        trimap = fg
        trimap[unknown] = 0.5
        # print(trimap[:, :, :1].shape)
        return trimap  # [:, :, :1]

    def get_trimap(self, alpha):
        foreground = alpha > 0
        # the next few lines fix a crash when alpha is all zeros, i.e. when there is no foreground
        res = None
        res = Image.fromarray(foreground).getbbox()
        if res is None:
            left, upper, right, ylower = 0, 0, alpha.shape[1], alpha.shape[0]
        else:
            left, upper, right, ylower = res

        bbox_size = ((right - left) + (ylower - upper)) // 2
        d_size = bbox_size // 256 * random.randint(10, 20)  # dilate kernel size
        e_size = bbox_size // 256 * random.randint(10, 20)  # erode kernel size
        alpha = alpha / 255.0  # numpy array of your matte (with values between [0, 1])
        trimap = (alpha >= 0.9).astype('float32')
        not_bg = (alpha > 0).astype('float32')
        trimap[np.where(
            (grey_dilation(not_bg, size=(d_size, d_size)) - grey_erosion(trimap, size=(e_size, e_size))) != 0)] = 0.5
        return trimap

    def __len__(self):
        return len(self.alphas)

    def __getitem__(self, idx):
        try:
            item = self.load_item(idx)
        except Exception as e:
            print('loading error: ', self.alphas[idx], e)
            item = self.load_item(0)  # fall back to sample 0 so one bad sample does not break training
        return item

    def load_item(self, idx):
        alpha = np.array(Image.open(self.alphas[idx]))  # instead of cv2.imread(self.alphas[idx], -1); PIL avoids "libpng warning: iCCP"
        alpha = alpha[..., -1] if len(alpha.shape) > 2 else alpha
        img_f = os.path.splitext(self.alphas[idx].replace('alphas', 'images'))[0] + '.jpg'
        if not os.path.exists(img_f):
            img_f = img_f.replace('.jpg', '.png')
        # print(img_f)
        img = Image.open(img_f)  # rgb; instead of cv2.imread(img_f)
        img = np.asarray(img)
        if len(img.shape) == 2:
            img = img[:, :, None]
        if img.shape[2] == 1:
            img = np.repeat(img, 3, axis=2)
        elif img.shape[2] == 4:
            img = img[:, :, 0:3]

        im_h, im_w, im_c = img.shape
        # for images that are not already ref_size x ref_size: resize the short side to a
        # random size in [512, 1200], then take a random crop
        if not (im_h == self.ref_size and im_w == self.ref_size):
            random_size = np.random.randint(512, 1201)
            if im_w >= im_h:
                im_rh = random_size
                im_rw = int(im_w / im_h * random_size)
            else:
                im_rw = random_size
                im_rh = int(im_h / im_w * random_size)

            img = cv2.resize(img, (im_rw, im_rh), interpolation=cv2.INTER_CUBIC)
            alpha = cv2.resize(alpha, (im_rw, im_rh), interpolation=cv2.INTER_CUBIC)
            # center crop
            # x0 = (im_rw - self.ref_size) // 2
            # y0 = (im_rh - self.ref_size) // 2
            # img = img[y0:y0+self.ref_size, x0:x0+self.ref_size, ...]
            # alpha = alpha[y0:y0+self.ref_size, x0:x0+self.ref_size, ...]
            # random crop
            x0 = randint(0, im_rw - self.ref_size + 1)
            y0 = randint(0, im_rh - self.ref_size + 1)
            img = img[y0:y0 + self.ref_size, x0:x0 + self.ref_size, ...]
            alpha = alpha[y0:y0 + self.ref_size, x0:x0 + self.ref_size, ...]

        trimap = self.get_trimap(alpha)
        # print(trimap.shape)

        # random horizontal flip augmentation
        if np.random.binomial(1, 0.5) > 0:
            img = img[:, ::-1, ...].copy()
            alpha = alpha[:, ::-1, ...].copy()
            trimap = trimap[:, ::-1, ...].copy()

        if self.transform:
            img = self.transform(img)
        alpha = self.tensor(alpha)
        trimap = self.tensor(trimap)
        return self.alphas[idx], img, trimap, alpha

@shizidushu

@tengshaofeng Thank you very much for the dataloader code. In x0 = randint(0, im_rw - self.ref_size + 1), the +1 should be removed, or the crop might end up only 511 pixels wide (when ref_size=512), since random.randint includes its upper bound.
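
For reference, a minimal sketch of the corrected crop-origin sampling (assuming Python's random.randint, which includes both endpoints):

# random.randint(a, b) returns a value in [a, b], so the upper bound must leave
# room for a full ref_size crop
x0 = randint(0, im_rw - self.ref_size)
y0 = randint(0, im_rh - self.ref_size)
img = img[y0:y0 + self.ref_size, x0:x0 + self.ref_size, ...]
alpha = alpha[y0:y0 + self.ref_size, x0:x0 + self.ref_size, ...]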
