In [None]:
# Show the GPU attached to this runtime (model, driver/CUDA version, memory in use).
!nvidia-smi

Wed Sep 15 15:04:55 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.63.01    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla K80           Off  | 00000000:00:04.0 Off |                    0 |
| N/A   36C    P8    28W / 149W |      0MiB / 11441MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
import os
import pickle

import imageio
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.model_zoo as model_zoo
from torch.utils.data import Dataset, Sampler, DataLoader
from torch.utils.data.sampler import BatchSampler
from torch.utils.tensorboard import SummaryWriter

In [None]:
# init params:
# semihard margin 0.25
# batchsize 64 (4,4,4)
# embedding size 128
# optimizer: SGD, momentum 0.9, lr 0.001, weight_decay 0.0005; divide lr by 10 every 2 epochs

# embedding normalization (L2)

# aug:
# randomresizecrop
# randomhorizontalflip
# normalization

# throw away categories where the image count is < 4

In [None]:
# Mount Google Drive at /content/gdrive; the dataset archives are copied from there below.
from google.colab import drive
drive.mount('/content/gdrive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/gdrive


In [None]:
# Copy the image archive and the train/valid split archive from Drive into the working directory.
!cp "/content/gdrive/My Drive/SOP/train.zip" "train.zip"
!cp "/content/gdrive/My Drive/SOP/sop-splitfile.zip" "sop-splitfile.zip"

In [None]:
# Extract both archives into ./sop-train and ./sop-splitfile respectively.
!unzip train.zip -d './sop-train'
!unzip sop-splitfile.zip -d './sop-splitfile'

In [None]:
# Load the pickled split description. The context manager closes the file
# handle as soon as unpickling finishes instead of leaving it to the GC.
# NOTE: unpickling can execute arbitrary code -- only open split files from a trusted source.
with open("/content/sop-splitfile/SOP_train_valid_split.pickle", "rb") as split_fh:
    split_file = pickle.load(split_fh)

In [None]:
# Which fields does the split file store for one category's training split?
split_file['cabinet_final']['train'].keys()

dict_keys(['paths', 'product_labels', 'category_labels'])

In [None]:
# Peek at a slice of the stored image paths for this category.
split_file['cabinet_final']['train']['paths'][500:530]


In [None]:
# Product labels for the same slice of entries.
split_file['cabinet_final']['train']['product_labels'][500:530]

In [None]:
# Distinct category labels stored in this category's validation split.
np.unique(split_file['cabinet_final']['valid']['category_labels'])

array([0])



---



> **Dataloader**

---





In [None]:
from torchvision.transforms import (
    RandomResizedCrop,
    RandomHorizontalFlip,
    ColorJitter,
    ToTensor,
    Resize,
    CenterCrop,
    Compose,
    ToPILImage,
)

In [None]:
def HWC_to_CHW(img):
    """Reorder a 3-D array so that its last axis becomes the first.

    Args:
        img: array-like with exactly three axes, e.g. (height, width, channels).

    Returns:
        The input with axes permuted to (2, 0, 1), e.g. (channels, height, width).
    """
    channel_first_axes = (2, 0, 1)
    return np.transpose(img, axes=channel_first_axes)


class TorchvisionNormalize:
    """Per-channel standardisation of an image, returned channel-first as float32.

    Each of the first three channels is divided by 255, shifted by the matching
    ``mean`` entry and divided by the matching ``std`` entry.

    Args:
        mean: per-channel means (indexed 0..2).
        std: per-channel standard deviations (indexed 0..2).
    """

    def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        self.mean, self.std = mean, std

    def __call__(self, img):
        """Normalise ``img`` (anything ``np.asarray`` accepts, channels last) and return it as CHW."""
        pixels = np.asarray(img)
        # Output buffer has the input's shape; only channels 0..2 are written below,
        # so any additional channel would be left uninitialised.
        normalized = np.empty_like(pixels, np.float32)

        for channel in range(3):
            scaled = pixels[..., channel] / 255.0
            normalized[..., channel] = (scaled - self.mean[channel]) / self.std[channel]

        return HWC_to_CHW(normalized)

In [None]:
# Training pipeline: random scale/crop to 224x224 plus horizontal flip, then normalisation.
train_transforms = Compose(
    transforms=[
        ToPILImage(),
        RandomResizedCrop(size=224, scale=(0.25, 1.0)),
        RandomHorizontalFlip(),
        TorchvisionNormalize(),
    ]
)

# Validation pipeline: deterministic resize and centre crop to 224x224, then normalisation.
val_transforms = Compose(
    transforms=[
        ToPILImage(),
        Resize(size=256),
        CenterCrop(size=224),
        TorchvisionNormalize(),
    ]
)

In [None]:
def rebuild_path(path):
    """Return ``<dir>/<dir>_<file>`` for a path that splits into ``<dir>`` and ``<file>``.

    The directory part is repeated as a prefix of the file name, e.g.
    ``"bicycle_final/1_0.JPG"`` becomes ``"bicycle_final/bicycle_final_1_0.JPG"``.
    """
    directory, filename = os.path.split(path)
    prefixed_name = "_".join((directory, filename))
    return "/".join((directory, prefixed_name))


def load_image_paths(split_file, mode="train"):
    """Collect the rebuilt image paths of every category into one flat array.

    Args:
        split_file: mapping ``category -> mode -> {"paths": [...], ...}``.
        mode: which split of each category to read.

    Returns:
        1-D NumPy array with one rebuilt path per image, categories concatenated
        in the mapping's iteration order.
    """
    per_category = []
    for category in split_file:
        raw_paths = split_file[category][mode]["paths"]
        # The first 25 characters of every stored path are dropped before rebuilding
        # (presumably a fixed prefix from where the split was created -- TODO confirm).
        per_category.append([rebuild_path(raw[25:]) for raw in raw_paths])
    return np.hstack(per_category).ravel()


def load_image_labels(split_file, mode="train", label_key="category_labels"):
    """Concatenate one label list per category into a single flat array.

    Args:
        split_file: mapping ``category -> mode -> {label_key: [...], ...}``.
        mode: which split of each category to read.
        label_key: name of the label list to collect.

    Returns:
        1-D NumPy array of labels, categories concatenated in the mapping's
        iteration order.
    """
    per_category = []
    for category in split_file:
        per_category.append(split_file[category][mode][label_key])
    flat = np.hstack(per_category)
    return flat.ravel()


class SOPImageDataset(Dataset):
    """Map-style dataset over the images listed in a pickled split file.

    Every item is a dict with the keys ``"img"`` (the image after ``transforms``),
    ``"cat_label"`` (0-d torch tensor) and ``"prod_label"`` (0-d NumPy array).

    Args:
        split_file_path: path of the pickled split dictionary.
        sop_root: string prepended verbatim to each rebuilt image path
            (no path separator is inserted).
        transforms: callable applied to the loaded image array; skipped when falsy.
        mode: which split of each category to expose.
    """

    def __init__(self, split_file_path, sop_root, transforms, mode="train"):
        # Close the file deterministically instead of leaking the handle.
        # NOTE: unpickling can execute arbitrary code -- only load trusted split files.
        with open(split_file_path, "rb") as split_fh:
            self.split_file = pickle.load(split_fh)
        self.img_paths_list = load_image_paths(self.split_file, mode)
        self.img_cat_labels_list = load_image_labels(
            self.split_file, mode, "category_labels"
        )
        self.img_prod_labels_list = load_image_labels(
            self.split_file, mode, "product_labels"
        )
        self.sop_root = sop_root

        self.transforms = transforms

    def __len__(self):
        """Number of images in the selected split."""
        return len(self.img_paths_list)

    def __getitem__(self, idx):
        """Load, optionally transform and label the image at position ``idx``."""
        path = self.img_paths_list[idx]
        cat_label = np.array(self.img_cat_labels_list[idx])
        prod_label = np.array(self.img_prod_labels_list[idx])

        img_path = self.sop_root + path
        img = np.asarray(imageio.imread(img_path))

        # Single-channel (2-D) images are expanded to RGB so that every sample
        # has a channel axis for the transforms.
        if len(img.shape) == 2:
            img = np.asarray(Image.fromarray(img).convert("RGB"))

        if self.transforms:
            img = self.transforms(img)
        sample = {
            "img": img,
            "cat_label": torch.from_numpy(cat_label),
            "prod_label": prod_label,
        }
        return sample


class ImageRetrievalSampler(BatchSampler):
    """Batch sampler yielding ``num_cats * num_prods * num_samples`` dataset indices.

    For every batch, ``num_prods`` distinct products are drawn from each selected
    category and ``num_samples`` distinct images are drawn from each product.

    Args:
        data_source: dataset exposing ``img_cat_labels_list``,
            ``img_prod_labels_list``, ``split_file`` and ``__len__``.
        num_samples: images drawn per product.
        num_cats: categories per batch.
        num_prods: products drawn per category.

    Raises:
        ValueError: (while iterating) when a selected category has fewer than
            ``num_prods`` products with at least ``num_samples`` images.
    """

    def __init__(self, data_source, num_samples=4, num_cats=4, num_prods=4):
        self.data_source = data_source
        self.num_samples = num_samples
        self.num_cats = num_cats
        self.num_prods = num_prods
        self.cat_unique_labels = np.unique(data_source.img_cat_labels_list)
        self.prod_unique_labels = np.unique(data_source.img_prod_labels_list)
        self.sf = data_source.split_file

    def _eligible_products(self, cat_label, valid_prod):
        """Unique product labels of one category that also occur in ``valid_prod``."""
        # The category label is used as a position into the split file's key order,
        # and the "train" split is read regardless of the data source's mode.
        folder = list(self.sf.keys())[cat_label]
        prod_labels = np.asarray(self.sf[folder]["train"]["product_labels"])
        # intersect1d de-duplicates: the raw list has one entry per image, which
        # would let the same product be drawn more than once for a batch.
        return np.intersect1d(prod_labels, valid_prod)

    def __iter__(self):
        all_prod_labels = np.asarray(self.data_source.img_prod_labels_list)
        # NOTE(review): categories are drawn once per pass, so every batch of the
        # pass uses the same categories; move the draw into the loop below if
        # per-batch variety is intended.
        cats = np.random.choice(self.cat_unique_labels, self.num_cats, replace=False)

        prod_values, prod_counts = np.unique(all_prod_labels, return_counts=True)
        # A product is usable only if it has enough images to draw `num_samples`
        # of them without replacement.
        valid_prod = prod_values[prod_counts >= self.num_samples]

        # Loop-invariant: computed once per pass instead of once per batch.
        candidates = []
        for cat in cats:
            eligible = self._eligible_products(cat, valid_prod)
            if len(eligible) < self.num_prods:
                raise ValueError(
                    "category {} has only {} products with at least {} images; "
                    "{} are required".format(
                        cat, len(eligible), self.num_samples, self.num_prods
                    )
                )
            candidates.append(eligible)

        for _ in range(len(self)):
            out = []
            for eligible in candidates:
                prods = np.random.choice(eligible, self.num_prods, replace=False)
                for prod in prods:
                    idxs = np.flatnonzero(all_prod_labels == prod)
                    idxs = np.random.choice(idxs, self.num_samples, replace=False)
                    out.extend(idxs.tolist())
            yield out

    def __len__(self):
        """Number of batches per pass: dataset size floor-divided by the batch size."""
        return len(self.data_source) // (
            self.num_samples * self.num_cats * self.num_prods
        )

In [None]:
# Build the training dataset and a batch sampler over it; with the sampler's
# defaults each batch holds 4 categories x 4 products x 4 images = 64 indices.
dataset = SOPImageDataset(
    "/content/sop-splitfile/SOP_train_valid_split.pickle",
    "/content/sop-train",
    train_transforms,
)
sampler = ImageRetrievalSampler(dataset)

In [None]:
import matplotlib.pyplot as plt

# Visual sanity check: display every image of the first sampled batch.
first_batch = next(iter(sampler), [])
for sample_idx in first_batch:
    chw_img = dataset[sample_idx]["img"]
    # Images come out of the transforms channel-first; imshow wants channels last.
    plt.imshow(chw_img.transpose(1, 2, 0))
    plt.show()

In [None]:
# Reload the split description; the context manager closes the file handle
# as soon as unpickling finishes.
# NOTE: unpickling can execute arbitrary code -- only open split files from a trusted source.
with open("/content/sop-splitfile/SOP_train_valid_split.pickle", "rb") as split_fh:
    split_file = pickle.load(split_fh)

In [None]:
# Top-level keys of the split file: one entry per product category.
split_file.keys()

dict_keys(['cabinet_final', 'bicycle_final', 'chair_final', 'sofa_final', 'mug_final', 'stapler_final', 'toaster_final', 'coffee_maker_final', 'table_final', 'fan_final', 'lamp_final', 'kettle_final'])



---



> **Model**

---





In [None]:
# Download location of the pretrained ResNet-50 state dict loaded by resnet50().
resnet50_url = "https://download.pytorch.org/models/resnet50-19c8e357.pth"


class FixedBatchNorm(nn.BatchNorm2d):
    """BatchNorm2d that always normalises with its stored running statistics.

    ``training=False`` is passed unconditionally, so the running mean/variance
    are used (and not updated) even when the module is in training mode.
    """

    def forward(self, input):
        frozen_stats = (self.running_mean, self.running_var)
        return F.batch_norm(
            input,
            *frozen_stats,
            weight=self.weight,
            bias=self.bias,
            training=False,
            eps=self.eps,
        )


class Bottleneck(nn.Module):
    """Residual block made of 1x1 -> 3x3 -> 1x1 convolutions plus a shortcut.

    The block outputs ``out_channels * 4`` channels.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: width of the two inner convolutions.
        stride: stride of the 3x3 convolution.
        downsample: optional module applied to the input to form the shortcut;
            when ``None`` the input itself is used.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        expanded_channels = out_channels * 4

        # Sub-modules are registered in the same order as before so that
        # parameter order and state-dict layout are unchanged.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn1 = FixedBatchNorm(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = FixedBatchNorm(out_channels)
        self.conv3 = nn.Conv2d(out_channels, expanded_channels, kernel_size=1, bias=False)
        self.bn3 = FixedBatchNorm(expanded_channels)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)

        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.relu(self.bn2(self.conv2(residual)))
        residual = self.bn3(self.conv3(residual))

        return self.relu(residual + shortcut)


class Net(nn.Module):
    """ResNet-style feature extractor without a classification head.

    A 7x7 stem convolution and max-pool are followed by four stages of
    ``Bottleneck`` blocks and a global average pool; ``forward`` returns the
    pooled features flattened to ``(batch, 2048)``.

    Args:
        layers: number of bottleneck blocks in each of the four stages.
    """

    def __init__(self, layers=(3, 4, 6, 3)):
        super().__init__()
        self.in_channels = 64

        self.conv1 = nn.Conv2d(
            3, self.in_channels, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn1 = FixedBatchNorm(self.in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # (inner width, number of blocks, stride of the first block) per stage;
        # registered as layer1..layer4 in this order.
        stage_specs = (
            (64, layers[0], 1),
            (128, layers[1], 2),
            (256, layers[2], 2),
            (512, layers[3], 2),
        )
        for number, (width, depth, stride) in enumerate(stage_specs, start=1):
            setattr(
                self,
                "layer{}".format(number),
                self._get_block(width, depth, stride=stride),
            )

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    def _get_block(self, out_channels, blocks, stride=1):
        """Build one stage of ``blocks`` bottlenecks and advance ``self.in_channels``."""
        expanded = out_channels * 4
        # The first block needs a projection shortcut whenever it changes the
        # spatial size or the channel count.
        needs_projection = stride != 1 or self.in_channels != expanded
        if needs_projection:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.in_channels, expanded, kernel_size=1, stride=stride, bias=False
                ),
                FixedBatchNorm(expanded),
            )
        else:
            downsample = None

        stage = [
            Bottleneck(
                self.in_channels, out_channels, stride=stride, downsample=downsample
            )
        ]
        self.in_channels = expanded
        stage.extend(
            Bottleneck(self.in_channels, out_channels) for _ in range(1, blocks)
        )
        return nn.Sequential(*stage)

    def forward(self, x):
        features = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)

        pooled = self.avgpool(features)
        return pooled.view(pooled.size(0), -1)


def resnet50(pretrained=True):
    """Build a ``Net`` with the (3, 4, 6, 3) stage layout.

    Args:
        pretrained: when true, download the state dict from ``resnet50_url``
            and load it after discarding the ``fc`` entries, which ``Net``
            has no module for.

    Returns:
        The constructed ``Net``.
    """
    model = Net(layers=(3, 4, 6, 3))
    if not pretrained:
        return model

    weights = model_zoo.load_url(resnet50_url)
    for head_key in ("fc.weight", "fc.bias"):
        weights.pop(head_key)
    model.load_state_dict(weights)
    return model


class RetrievalNet(nn.Module):
    """Embedding network: pretrained backbone, linear projection, L2 normalisation.

    The class composes a backbone instead of inheriting from ``Net``. Inheriting
    and calling ``Net.__init__`` built a second, never-used ResNet inside the
    model, which wasted memory and added dead parameters to the optimizer and
    to every checkpoint.

    NOTE: state dicts saved with the previous definition also contain the unused
    top-level ``conv1``/``bn1``/``layer*`` entries; load them with ``strict=False``.

    Args:
        embedding_size: dimensionality of the returned embeddings.
    """

    def __init__(self, embedding_size=128):
        super().__init__()

        self.embedding_size = embedding_size
        self.backbone = resnet50(True)
        # The backbone emits 2048 features per image (512 * 4 channels after pooling).
        self.fc = nn.Linear(2048, embedding_size)

        # The projection is the only linear layer in the model.
        torch.nn.init.xavier_uniform_(self.fc.weight)

    def forward(self, x):
        """Return unit-length (L2-normalised) embeddings of shape ``(batch, embedding_size)``."""
        out = self.backbone(x)
        out = self.fc(out)
        out = F.normalize(out, dim=1, p=2)
        return out

    def trainable_parameters(self):
        """Return ``(backbone_parameters, head_parameters)`` as two lists."""
        return (list(self.backbone.parameters()), list(self.fc.parameters()))

In [None]:
# Instantiate the embedding network (this fetches the pretrained backbone weights via resnet50).
model = RetrievalNet()



---



> **Loss**

---





In [None]:
from itertools import combinations, product

# np.random.seed(666)
class TripletLoss(nn.Module):
    """Triplet margin loss with online mining of margin-violating negatives.

    For every unordered pair of same-label samples (first index = anchor, second =
    positive) one negative is drawn at random among those that violate the
    margin; the loss is the mean of ``relu(d(a, p) - d(a, n) + margin)`` over
    the mined triplets.

    Mining runs on detached CPU copies, so the module works for CUDA inputs and
    mining adds nothing to the autograd graph. When no triplet violates the
    margin the loss is a graph-attached zero instead of NaN.

    Args:
        margin: required gap between anchor-negative and anchor-positive distances.
    """

    def __init__(self, margin):
        super().__init__()
        self.margin = margin

    def forward(self, embeddings, labels):
        """Return the scalar loss for ``embeddings`` (N, D) and ``labels`` (N,)."""
        triplets = self._get_triplets(embeddings, labels)
        if len(triplets[0]) == 0:
            # mean() over zero triplets would be NaN; return a zero that still
            # depends on `embeddings` so backward() yields zero gradients.
            return embeddings.sum() * 0.0

        anchors = embeddings[triplets[0]]
        ap_dists = F.pairwise_distance(anchors, embeddings[triplets[1]])
        an_dists = F.pairwise_distance(anchors, embeddings[triplets[2]])
        loss = F.relu(ap_dists - an_dists + self.margin)
        return loss.mean()

    def _get_triplets(self, embeddings, labels):
        """Mine triplets; returns three parallel lists of int indices (anchor, positive, negative)."""
        # Index selection needs no gradients, and doing it on the host avoids a
        # device synchronisation for every candidate pair.
        detached = embeddings.detach()
        distance_matrix = torch.norm(detached.unsqueeze(1) - detached, p=2, dim=2).cpu()
        host_labels = labels.detach().cpu()

        triplets_idxs = [[] for _ in range(3)]
        for label in torch.unique(host_labels):
            is_positive = host_labels == label
            pos_indices = torch.where(is_positive)[0].tolist()
            negative_indices = torch.where(torch.logical_not(is_positive))[0]
            if len(pos_indices) < 2:
                continue

            for anchor, positive in combinations(pos_indices, 2):
                ap_dist = distance_matrix[anchor, positive]
                an_dist = distance_matrix[anchor, negative_indices]
                sh_idx = self._get_semihard(ap_dist, an_dist, self.margin)

                if sh_idx is not None:
                    triplets_idxs[0].append(anchor)
                    triplets_idxs[1].append(positive)
                    triplets_idxs[2].append(negative_indices[sh_idx].item())
        return triplets_idxs

    def _get_semihard(self, ap, an, margin=0.25):
        """Pick a random position in ``an`` with ``ap + margin - an > 0``, or ``None``.

        Despite the name, the criterion also admits negatives that are closer to
        the anchor than the positive, not only strictly semi-hard ones.
        """
        violating = torch.where(ap + margin - an > 0)[0]
        if violating.nelement() == 0:
            return None
        # Draw the position with NumPy's global RNG (seedable via np.random.seed)
        # and index the tensor, rather than handing a tensor to np.random.choice,
        # which cannot convert CUDA tensors.
        pick = np.random.randint(violating.nelement())
        return violating[pick].item()

In [None]:
# Random smoke-test inputs: 64 five-dimensional embeddings with labels drawn from four classes.
embeddings = torch.Tensor(np.random.rand(64, 5))
labels = torch.Tensor(np.random.choice([1, 2, 3, 4], 64))

In [None]:
# Evaluate the loss on the random inputs (the value varies between runs; no seed is set).
TripletLoss(0.25)(embeddings, labels)

tensor(0.1432)



---

# Training

---



In [None]:
# --- Training setup ---------------------------------------------------------
train_dataset = SOPImageDataset(
    "/content/sop-splitfile/SOP_train_valid_split.pickle",
    "/content/sop-train",
    train_transforms,
)
# Sample from the dataset the loader actually serves, not from the `dataset`
# object left over from an earlier cell.
sampler = ImageRetrievalSampler(train_dataset)

dataloader_train = DataLoader(train_dataset, batch_sampler=sampler)

device = torch.device("cuda:0")
model = RetrievalNet().to(device)
# One optimisation step per sampled batch, so take the count from the loader
# instead of hard-coding the batch size.
max_step = len(dataloader_train)  # * config.epoch_num

optimizer = optim.SGD(
    model.parameters(), lr=1e-3, momentum=0.9, weight_decay=5e-4, nesterov=True
)

# To resume from a checkpoint, uncomment:
# checkpoint = torch.load("/content/learning_state_without_tricks.pth")
# model.load_state_dict(checkpoint["model_state_dict"])
# optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
# epoch = checkpoint["epoch"]
# loss = checkpoint["train loss"]
# print("restored loss: ", loss)

writer = SummaryWriter("/content/gdrive/My Drive/SOP/runs")
criterion = TripletLoss(0.25)


def train(config):
    """Train the module-level ``model`` with ``criterion`` and save one checkpoint.

    Relies on the module-level ``model``, ``optimizer``, ``criterion``,
    ``dataloader_train``, ``writer`` and ``max_step``.

    Args:
        config: object exposing ``epoch_num``; it is only used in the progress
            message. NOTE(review): the loop itself always runs 10 epochs --
            confirm which of the two is intended.
    """
    model.train()
    num_epochs = 10
    steps_per_epoch = len(dataloader_train)

    for ep in range(num_epochs):
        print("Epoch{}/{}".format(ep + 1, config.epoch_num))
        # Single learning-rate drop, applied at the start of the last epoch.
        if ep == num_epochs - 1:
            optimizer.param_groups[0]['lr'] /= 10

        train_loss = []
        for step, batch in enumerate(dataloader_train):
            img = batch["img"].cuda()
            # The dataset emits "cat_label" and "prod_label" (there is no "label"
            # key). Triplets are mined on product identity, matching the sampler,
            # which draws several images per product.
            # NOTE(review): switch to "cat_label" if category-level retrieval is intended.
            label = batch["prod_label"].cuda()
            embeddings = model(img)
            loss = criterion(embeddings, label)
            train_loss.append(loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step % 50 == 0:
                print("step:%5d/%5d" % (step, max_step))  # *(ep+1)
                print(
                    "loss:%.4f" % (np.mean(train_loss)),
                    "lr: %.4f" % (optimizer.param_groups[0]["lr"]),
                )
            # Use a step index that keeps growing across epochs; the per-epoch
            # `step` alone would overwrite earlier points in TensorBoard.
            global_step = ep * steps_per_epoch + step
            writer.add_scalar('loss/train', np.mean(train_loss), global_step)
            writer.add_scalar('lr/train', optimizer.param_groups[0]['lr'], global_step)

    # Saved once, after the final epoch.
    torch.save({
        'epoch': ep,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train loss': np.mean(train_loss)
    }, '/content/gdrive/My Drive/SOP/runs/learning_state.pth')

    # torch.save(model.state_dict(), 'resnet50_SOP.pth')
    torch.cuda.empty_cache()