In [2]:
import deeptriplet
import deeptriplet.datasets
import deeptriplet.triplet
from deeptriplet.models.deeplabv3p import DeepLabSpatialEarly
from deeptriplet.datasets import PascalMultiTriplet
from deeptriplet.triplet import MultiTripletPreselected

import torch
torch.backends.cudnn.benchmark = True
import numpy as np
import matplotlib.pyplot as plt
import PIL
from PIL import Image, ImageOps, ImageFilter
import math
import random

from torch.utils import data
from torchvision import transforms
from torch import optim

# --- Experiment setup: data, model, optimiser and loss -----------------------
# The batch size and triplet count are used both by the data pipeline and by
# the loss object, so they are named once here instead of being repeated as
# literals. NOTE(review): presumably the loss expects the same values as the
# loader/dataset -- confirm against deeptriplet.triplet.RandomTripletPreselected.
BATCH_SIZE = 8
N_TRIPLETS = 500

trainset = deeptriplet.datasets.PascalDatasetRandomTripletAugmented(
                        pascal_root="/scratch-second/yardima/datasets/VOC2012",
                        split_file="/home/yardima/Python/experiments/pascal_split/train_obj.txt",
                        n_triplets=N_TRIPLETS)

trainloader = data.DataLoader(trainset,
                                batch_size=BATCH_SIZE,
                                num_workers=4,
                                shuffle=True)

# 64 output channels are used as the embedding dimension of the network.
net = DeepLabSpatialEarly(backbone='resnet', output_stride=16, num_classes=64, sync_bn=False, freeze_bn=False, dynamic_coordinates=False, pretrained=False)

# Initialise from a semantic-segmentation checkpoint, then move the model to
# the GPU once (the original cell called .cuda() both before and after this).
d = torch.load("/scratch-second/yardima/pretrained-models/deeplab-resnet-v3-plus.pth")
net.init_from_semseg_model(d)
net = net.cuda()

# Two parameter groups: the second one trains with a 10x larger learning rate.
init_lr = 3e-4
optimizer = optim.SGD([{'params': net.get_1x_lr_params(), 'lr': init_lr},
                       {'params': net.get_10x_lr_params(), 'lr': init_lr * 10}],
                      lr=init_lr, momentum=0.9, weight_decay=5e-4)

loss_fn = deeptriplet.triplet.RandomTripletPreselected(n_batch=BATCH_SIZE, n_triplets=N_TRIPLETS)
net = net.train()

In [3]:
%%time
for ii, sample_batched in enumerate(trainloader):

    inputs, labels = sample_batched
    inputs = inputs.cuda()
    labels = labels.cuda()

    optimizer.zero_grad()

    outputs = net.forward(inputs)
    loss = loss_fn.compute_loss(outputs, labels)

    loss.backward()
    
    optimizer.step()
    
    print(loss.item())

    del loss, outputs
    
    if ii > 50:
        break

3.0755867958068848
1.9772038459777832
2.9934940338134766
4.014428615570068
4.535888671875
3.2391197681427
6.229218482971191
3.4361634254455566
3.332658529281616
2.1183698177337646
2.455063819885254
1.511649489402771
1.598556637763977
1.7728849649429321
1.3821467161178589
1.3036298751831055
1.504967212677002
1.33651864528656
1.291544795036316
1.5299073457717896
1.1572859287261963
1.241246223449707
1.1037187576293945
0.8393304347991943
1.1357799768447876
1.062286138534546
0.9933729767799377
1.1239629983901978
0.9717240333557129
0.878770112991333
1.158694863319397
0.8220223784446716
1.0184012651443481
1.1250277757644653
0.6477077603340149
0.9476750493049622
0.9854138493537903
1.024397850036621
1.0736340284347534
0.8611063361167908
0.8044372797012329
0.6669158935546875
0.6082229018211365
0.7692926526069641
0.9317765235900879
0.6247314810752869
0.7059388756752014
0.7799034118652344
0.9340869784355164
1.072808861732483
0.5462229251861572
1.0621817111968994
CPU times: user 29.8 s, sys: 15.4 s

In [3]:


def read_labeled_image_list(data_dir, data_list):
    """Reads txt file containing paths to images and ground truth masks.

    Args:
      data_dir: path to the directory with images and masks. It is prepended
        to every entry by plain string concatenation (no separator is added).
      data_list: path to the file with lines of the form '/path/to/image /path/to/mask'.
        Blank lines are ignored; the two fields may be separated by any
        run of whitespace.

    Returns:
      Two lists with all file names for images and masks, respectively.

    Raises:
      ValueError: if a non-blank line does not contain exactly two fields.
    """
    images = []
    masks = []
    # The context manager guarantees the handle is closed even if a line is malformed.
    with open(data_list, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank / trailing empty lines instead of failing on unpack.
                continue
            image, mask = line.split()
            images.append(data_dir + image)
            masks.append(data_dir + mask)

    return images, masks

class PascalMultiTriplet(data.Dataset):
    """Pascal VOC segmentation dataset yielding an image plus pre-sampled triplets.

    Every item is a randomly augmented, randomly scaled and cropped image
    together with pixel indices for triplet learning: ``n_triplets`` anchors,
    and for each anchor ``samples_pos`` same-class pixels and ``samples_neg``
    different-class pixels. All indices are flat (row-major) positions inside
    the ``crop_size`` x ``crop_size`` label crop. Pixels labelled 255 are never
    selected as anchor, positive or negative (except in the degenerate
    fallbacks described in ``_generate_triplet``).
    """

    def __init__(
            self,
            *,
            pascal_root,
            split_file,
            n_triplets,
            samples_pos,
            samples_neg
    ):
        """
        Args:
          pascal_root: prefix prepended to every path listed in ``split_file``.
          split_file: text file with one '/path/to/image /path/to/mask' pair per line.
          n_triplets: number of anchor pixels drawn per image.
          samples_pos: number of positive pixels drawn per anchor.
          samples_neg: number of negative pixels drawn per anchor.
        """
        self.split_file = split_file
        self.pascal_root = pascal_root
        self.n_triplets = n_triplets

        # Not read anywhere in this class.
        self.n_classes = 21

        # Side length of the square crop, and the reference size for random scaling.
        self.crop_size = 513
        self.base_size = 513

        self.image_list, self.label_list = read_labeled_image_list(self.pascal_root, self.split_file)

        self.transforms = transforms.Compose([transforms.ToTensor(),
                                              transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                                   std=[0.229, 0.224, 0.225])])

        # Padding values: roughly the normalisation mean scaled to 0-255 for the
        # image, and the 255 "ignore" value for the label.
        self.fill_image = (124, 116, 104)
        self.fill_label = 255

        self.samples_pos = samples_pos
        self.samples_neg = samples_neg

    def __len__(self):
        """Number of (image, mask) pairs listed in the split file."""
        return len(self.label_list)

    def __getitem__(self, index):
        """Load, augment and crop sample ``index`` and draw its triplets.

        Returns:
          Tuple ``(img, ai, pi, ni)`` where ``img`` is the normalised image
          tensor produced by ``self.transforms`` and ``ai``/``pi``/``ni`` are the
          long tensors returned by ``_generate_triplet``.
        """
        im_path = self.image_list[index]
        lbl_path = self.label_list[index]

        # The resize inside _random_crop always produces new in-memory images,
        # so the file handles can be released as soon as the block ends.
        with PIL.Image.open(im_path) as img, PIL.Image.open(lbl_path) as lbl:
            img, lbl = self._augment(img, lbl)
            img, lbl = self._random_crop(img, lbl)

        img = np.array(img, dtype=np.float32) / 255.0
        # np.int64 instead of np.long: the alias is missing from NumPy 1.24-1.26
        # and platform-dependent where it exists; int64 matches torch.long.
        lbl = np.array(lbl, dtype=np.int64)

        img = self.transforms(img)

        triplets = self._generate_triplet(lbl)

        return (img, *triplets)

    def _augment(self, img, lbl):
        """Random horizontal flip (image and label) and random blur (image only).

        Each augmentation is applied independently with probability 0.5; the
        Gaussian blur radius is drawn uniformly from [0, 1).
        """
        if np.random.rand() > 0.5:
            img = img.transpose(PIL.Image.FLIP_LEFT_RIGHT)
            lbl = lbl.transpose(PIL.Image.FLIP_LEFT_RIGHT)

        if np.random.random() < 0.5:
            img = img.filter(PIL.ImageFilter.GaussianBlur(
                radius=random.random()))

        return img, lbl

    def _random_crop(self, img, mask):
        """Randomly rescale, pad if necessary, and take a random square crop.

        The short edge is resized to a random integer in
        ``[0.5 * base_size, 2.0 * base_size]``; the image uses bilinear and the
        mask nearest-neighbour interpolation. If the short edge ends up smaller
        than ``crop_size`` both are padded (image with ``fill_image``, mask with
        ``fill_label``) before a ``crop_size`` x ``crop_size`` window is cut out
        at a random position.
        """
        # random scale (short edge)
        short_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0))
        w, h = img.size
        if h > w:
            ow = short_size
            oh = int(1.0 * h * ow / w)
        else:
            oh = short_size
            ow = int(1.0 * w * oh / h)
        img = img.resize((ow, oh), Image.BILINEAR)
        mask = mask.resize((ow, oh), Image.NEAREST)
        # pad crop
        padh = self.crop_size - oh if oh < self.crop_size else 0
        padw = self.crop_size - ow if ow < self.crop_size else 0
        if short_size < self.crop_size:
            # The "+ 1" on every side guarantees the padded size reaches
            # crop_size even when the missing amount is odd.
            border = (padw//2 + 1, padh//2 + 1, padw//2 + 1, padh//2 + 1)
            img = ImageOps.expand(img, border=border, fill=self.fill_image)
            mask = ImageOps.expand(mask, border=border, fill=self.fill_label)
        # random crop crop_size
        w, h = img.size
        x1 = random.randint(0, w - self.crop_size)
        y1 = random.randint(0, h - self.crop_size)
        img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
        mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))

        return img, mask

    def _generate_triplet(self, lbl):
        """Sample anchor / positive / negative pixel indices from a label map.

        Args:
          lbl: 2-D integer label array; only its top-left
            ``crop_size`` x ``crop_size`` window is used. 255 marks pixels to ignore.

        Returns:
          ``(ai, pi, ni)`` long tensors of shape ``(n_triplets,)``,
          ``(n_triplets, samples_pos)`` and ``(n_triplets, samples_neg)`` holding
          flat row-major indices into the window. Positives share the anchor's
          label (and are never the anchor itself); negatives have a different,
          non-255 label. Sampling is with replacement.

          Fallbacks: if the window has no non-255 pixel, every anchor is index 0.
          If an anchor has no candidate positive or no candidate negative, all
          of its positives and negatives are set to the anchor index itself.
        """
        lbl_view = lbl[0:self.crop_size, 0:self.crop_size]
        lbl_flat = lbl_view.reshape(-1)
        valid = lbl_flat != 255

        # Anchors: uniform draw (with replacement) over the non-ignored pixels.
        options = np.nonzero(valid)[0]
        if options.shape[0] > 0:
            ai = options[np.random.randint(low=0,
                                           high=options.shape[0],
                                           size=(self.n_triplets,))]
        else:
            ai = np.zeros(self.n_triplets, dtype=np.int64)

        # inv_map_flat[k] is the position of pixel k's label inside `classes`.
        classes, inv_map = np.unique(lbl_view, return_inverse=True)
        inv_map_flat = inv_map.reshape(-1)

        # Per class: sorted flat indices of valid pixels inside / outside the class.
        lpos, lneg = [], []
        for c in range(len(classes)):
            same = inv_map_flat == c
            lpos.append(np.nonzero(valid & same)[0])
            lneg.append(np.nonzero(valid & ~same)[0])

        pi = np.empty((self.n_triplets, self.samples_pos), dtype=np.int64)
        ni = np.empty((self.n_triplets, self.samples_neg), dtype=np.int64)
        for i in range(self.n_triplets):
            anchor = ai[i]
            c = inv_map_flat[anchor]
            cpi = lpos[c][lpos[c] != anchor]  # a pixel is not its own positive
            cni = lneg[c]                     # can never contain the anchor

            if len(cni) == 0 or len(cpi) == 0:
                pi[i, :] = anchor
                ni[i, :] = anchor
            else:
                # One vectorised draw per anchor instead of one call per sample.
                pi[i, :] = np.random.choice(cpi, size=self.samples_pos)
                ni[i, :] = np.random.choice(cni, size=self.samples_neg)

        ai = torch.tensor(ai.reshape(self.n_triplets), dtype=torch.long)
        pi = torch.tensor(pi, dtype=torch.long)
        ni = torch.tensor(ni, dtype=torch.long)

        return ai, pi, ni
    


    

In [4]:

import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim


class MultiTripletPreselected:
    """Margin-based triplet loss over pre-selected pixel indices.

    For every anchor the squared Euclidean distance to each of its positive
    and negative candidates is computed; the smallest positive distance and
    the smallest negative distance are then fed to a margin ranking loss that
    asks the negative distance to exceed the positive one by ``margin``. The
    mean squared norm of the anchor embeddings, scaled by ``l2_penalty``, is
    added as a regulariser.

    Index tensors are built on the fly on the device of the inputs, so the
    object works on CPU as well as on GPU and with any batch shape.
    """

    def __init__(self, n_batch, margin=1, l2_penalty=1e-3):
        """
        Args:
          n_batch: nominal batch size. Kept for interface compatibility; the
            actual batch size is read from the tensors passed to ``compute_loss``.
          margin: margin of the ranking loss.
          l2_penalty: weight of the anchor-norm regulariser.
        """
        self.margin = margin
        self.l2_penalty = l2_penalty
        self.n_batch = n_batch

        self.loss_fun = nn.MarginRankingLoss(margin=margin)

    def compute_loss(self, output, a, p, n):
        """Compute the loss for one batch.

        Args:
          output: embeddings of shape ``(B, D, ...)``; all dimensions after the
            second are flattened into one pixel axis.
          a: ``(B, T)`` long tensor of anchor indices into the flattened pixel axis.
          p: ``(B, T, P)`` long tensor of positive indices.
          n: ``(B, T, N)`` long tensor of negative indices.

        Returns:
          Scalar loss tensor.
        """
        n_batch, n_dim = output.shape[0], output.shape[1]

        # (B, D, ...) -> (B, HW, D) so that a pixel index selects a D-vector.
        out = output.reshape(n_batch, n_dim, -1).transpose(1, 2)

        # A broadcast batch index replaces the cached per-element index vectors
        # of the original implementation (which were only rebuilt when the
        # total element count changed and were pinned to CUDA).
        batch_idx = torch.arange(n_batch, device=a.device)
        v_anch = out[batch_idx.view(-1, 1), a]        # (B, T, D)
        v_pos = out[batch_idx.view(-1, 1, 1), p]      # (B, T, P, D)
        v_neg = out[batch_idx.view(-1, 1, 1), n]      # (B, T, N, D)

        l2_norm = v_anch.pow(2).sum(dim=-1).mean()

        # Squared distances from each anchor to every candidate.
        delta_p = (v_anch.unsqueeze(2) - v_pos).pow(2).sum(dim=-1)  # (B, T, P)
        delta_n = (v_anch.unsqueeze(2) - v_neg).pow(2).sum(dim=-1)  # (B, T, N)

        # Keep only the closest positive and the closest negative per anchor.
        delta_p, _ = delta_p.min(dim=2)
        delta_n, _ = delta_n.min(dim=2)
        delta_p = delta_p.reshape(-1)
        delta_n = delta_n.reshape(-1)

        # target = +1 asks for delta_n to be ranked above delta_p.
        target = torch.ones_like(delta_n)

        loss = self.loss_fun(delta_n, delta_p, target) + l2_norm * self.l2_penalty

        return loss

In [3]:
%%time
trainset = PascalMultiTriplet(
                        pascal_root="/scratch-second/yardima/datasets/VOC2012",
                        split_file="/home/yardima/Python/experiments/pascal_split/train_obj.txt",
                        n_triplets=500,
                        samples_neg=20,
                        samples_pos=10)

trainloader = data.DataLoader(trainset,
                                batch_size=8,
                                num_workers=4,
                                shuffle=True)

loss_fn = MultiTripletPreselected(n_batch=8)

for ii, sample_batched in enumerate(trainloader):

    inputs, a,p,n = sample_batched
    inputs = inputs.cuda()
    a = a.cuda()
    p = p.cuda()
    n = n.cuda()

    optimizer.zero_grad()

    outputs = net.forward(inputs)
    loss = loss_fn.compute_loss(outputs, a, p, n)

    loss.backward()
    
    optimizer.step()
    
    print(loss.item())

    del loss,inputs,a,p,n,outputs
    
    if ii > 50:
        break

1.3335983753204346
1.1757022142410278
0.7121611833572388
0.8790137767791748
1.003315806388855
0.8870450854301453
0.8245775103569031
0.8456610441207886
0.8013911843299866
0.6996610760688782
0.77012038230896
0.604715883731842
0.8514495491981506
0.5517140030860901
0.47690239548683167
0.32431161403656006
0.49633651971817017
0.4512006938457489
0.5363562107086182
0.6377290487289429
0.3827652335166931
0.33467555046081543
0.5080620050430298
0.6009445190429688
0.436949223279953
0.26253804564476013
0.30369994044303894
0.37332314252853394
0.26764845848083496
0.39763572812080383
0.2205560803413391
0.32096898555755615
0.3608282804489136
0.567032516002655
0.27135270833969116
0.3811478614807129
0.3406652808189392
0.29958972334861755
0.2797240614891052
0.3896382451057434
0.2775852382183075
0.3833032548427582
0.23668907582759857
0.5160719156265259
0.3370261788368225
0.1922895461320877
0.23706595599651337
0.195984348654747
0.22637256979942322
0.26917076110839844
0.22303369641304016
0.35435861349105835
C

In [50]:
# Fetch a single (un-batched) sample straight from the dataset to inspect it:
# the image tensor plus the anchor, positive and negative index tensors.
img, a, p, n = trainset[0]

In [53]:
# Negatives are laid out as (n_triplets, samples_neg); with the dataset
# configured as n_triplets=500, samples_neg=20 this prints [500, 20].
print(n.shape)

torch.Size([500, 20])
