In [80]:
import os
import cv2
import sys
sys.path.append("/home/data_normal/abiz/wuzhiqiang/wzq/shopee/code_v3/pytorch-image-models")
import timm
import math
import torch
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
import albumentations as A
import torch.utils.data as data
from transformers import AutoTokenizer

In [52]:
# Name handed to AutoTokenizer.from_pretrained when the dataset is built.
bert_model_arch = "bert-base-multilingual-uncased"

# Test metadata CSV and the directory that holds the images it references.
root = "/home/data_normal/abiz/wuzhiqiang/wzq/data/shopee-product-matching/test.csv"
test_path = "/home/data_normal/abiz/wuzhiqiang/wzq/data/shopee-product-matching/test_images"

# Checkpoint whose 'state_dict' entry is restored into the network.
initial_checkpoint = "/home/data_normal/abiz/wuzhiqiang/wzq/shopee/code_v3/result/fold-0/checkpoint/00054500_model.pth"

# Use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

df = pd.read_csv(root)

In [17]:
# Preview the first rows of the test metadata held in df.
df.head()

Unnamed: 0,posting_id,image,image_phash,title
0,test_2255846744,0006c8e5462ae52167402bac1c2e916e.jpg,ecc292392dc7687a,Edufuntoys - CHARACTER PHONE ada lampu dan mus...
1,test_3588702337,0007585c4d0f932859339129f709bfdc.jpg,e9968f60d2699e2c,(Beli 1 Free Spatula) Masker Komedo | Blackhea...
2,test_4015706929,0008377d3662e83ef44e1881af38b879.jpg,ba81c17e3581cabe,READY Lemonilo Mie instant sehat kuah dan goreng


In [110]:
# Similarity cut-off expressed in percent (find_matches divides it by 100).
threshold = 40
batch_size = 4
# Side lengths every image is resized to before it is fed to the network.
width, height = 640, 640
# Per-channel statistics used by collate_fn as (pixel / 255 - mean) / std.
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
# albumentations' Resize takes (height, width) in that order; passing the
# values by keyword keeps the transform correct if the two ever differ.
null_augment = A.Compose(
    [A.Resize(height=height, width=width)]
)

In [62]:
def collate_fn(batch):
    """Merge a list of dataset records into a single batch dictionary.

    Args:
        batch: sequence of records. Each record provides an ``'index'``
            entry, an ``'image'`` array laid out height x width x channel
            (the layout the ``transpose(0, 3, 1, 2)`` below relies on), and
            an ``'items'`` dict with ``input_ids``, ``token_type_ids`` and
            ``attention_mask`` tensors that can be stacked across records.

    Returns:
        dict with:
            ``'index'``: list of the record indices, in batch order.
            ``'image'``: float32 tensor, batch x channel x height x width,
                normalised with the module-level ``mean`` and ``std``.
            ``'input_ids'``, ``'token_type_ids'``, ``'attention_mask'``:
                the per-record tensors stacked along a new first dimension.
    """
    index = [record['index'] for record in batch]

    # Divide by 255, then standardise each channel; mean/std broadcast over
    # the trailing channel axis.
    image = np.stack([(record['image'] / 255.0 - mean) / std for record in batch])
    # Channels-last -> channels-first, as expected by the convolutional net.
    image = image.transpose(0, 3, 1, 2)
    image = torch.from_numpy(image).contiguous().float()

    input_ids = torch.stack([record['items']['input_ids'] for record in batch])
    token_type_ids = torch.stack([record['items']['token_type_ids'] for record in batch])
    attention_mask = torch.stack([record['items']['attention_mask'] for record in batch])

    return {
        'index': index,
        'image': image,
        'input_ids': input_ids,
        'token_type_ids': token_type_ids,
        'attention_mask': attention_mask,
    }

In [30]:
class ShopeeDataset(data.Dataset):
    """Dataset yielding one image plus its tokenised title per dataframe row.

    Args:
        df: dataframe with at least the ``title`` and ``image`` columns.
        augment: callable invoked as ``augment(image=...)`` that returns a
            dict with an ``'image'`` entry; a falsy value skips this step.
        bert_model_arch: name passed to ``AutoTokenizer.from_pretrained``.
    """

    def __init__(self, df, augment=null_augment, bert_model_arch=bert_model_arch):
        self.df = df
        self.augment = augment
        self.tokenizer = AutoTokenizer.from_pretrained(bert_model_arch)
        # All titles as a plain list; missing titles become the literal "NaN".
        texts = df['title'].fillna("NaN").tolist()
        # Tokenise everything once up front so __getitem__ only indexes.
        self.encodings = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=128,
        )

    def __getitem__(self, index):
        """Load the image and token tensors for the row at position ``index``.

        Returns:
            dict with ``'index'`` (the position passed in), ``'image'``
            (RGB float32 array, after ``augment`` when one is set) and
            ``'items'`` (one tensor per key of the tokenizer output).

        Raises:
            FileNotFoundError: if the image file cannot be read.
        """
        posting_info = self.df.iloc[index]
        image_path = os.path.join(test_path, posting_info.image)

        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread reports a missing or undecodable file by returning
            # None; fail here with the path instead of inside cvtColor.
            raise FileNotFoundError("Could not read image: %s" % image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = image.astype(np.float32)

        items = {k: torch.tensor(v[index]) for k, v in self.encodings.items()}

        if self.augment:
            augmented = self.augment(image=image)
            image = augmented['image']

        r = {
            'index': index,
            'image': image,
            'items': items
        }
        return r

    def __len__(self):
        """Number of rows in the wrapped dataframe."""
        return self.df.shape[0]

    def __str__(self):
        """Short summary containing the dataset length."""
        string = ''
        string += '\tlen     = %d\n' % len(self)
        return string

In [39]:
class ArcMarginProduct(nn.Module):
    """Additive angular margin head producing ``s * cos(theta + m)`` logits.

    The margin is applied to the entry of each sample's own label; every
    other entry is the plain scaled cosine between the normalised input and
    the normalised class weight.

    Args:
        in_features: size of each input sample.
        out_features: number of classes (rows of the weight matrix).
        s: factor the final logits are multiplied by.
        m: angular margin added to theta, in radians.
        easy_margin: if True, the margin is only applied where
            ``cos(theta) > 0``; otherwise where ``cos(theta) > cos(pi - m)``,
            with ``cos(theta) - sin(pi - m) * m`` used elsewhere.
        smoothing: label-smoothing amount blended into the one-hot mask.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        s: float = 30.0,
        m: float = 0.5,
        easy_margin: bool = False,
        smoothing: float = 0.0,
    ):
        super(ArcMarginProduct, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.smoothing = smoothing
        # torch.empty replaces the legacy torch.FloatTensor(...) constructor;
        # the values are overwritten by the initialiser on the next line.
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, inputs, labels):
        """Return the margin-adjusted, scaled logits.

        Args:
            inputs: 2-D tensor, one row per sample with ``in_features`` columns.
            labels: class index per sample; reshaped to a column internally.

        Returns:
            float tensor with one row per sample and ``out_features`` columns.
        """
        # --------------------------- cos(theta) & phi(theta) ---------------------------
        cosine = F.linear(F.normalize(inputs), F.normalize(self.weight)).float()
        # Rounding can push |cosine| marginally above 1, which would make the
        # operand negative and the square root NaN; clamp it at zero.
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(min=0.0))
        # cos(theta + m) = cos(theta) cos(m) - sin(theta) sin(m)
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)

        # --------------------------- convert label to one-hot ---------------------------
        # Build the mask next to the logits so labels left on another device
        # do not cause a device mismatch in the blend below.
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, labels.view(-1, 1).long().to(cosine.device), 1)
        if self.smoothing > 0:
            one_hot = (
                1 - self.smoothing
            ) * one_hot + self.smoothing / self.out_features

        # phi where the mask is set, plain cosine everywhere else.
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s

        return output

In [44]:
class ShopeeImgNet(nn.Module):
    """Image embedding model: a timm backbone followed by a small neck.

    The backbone's own pooling and classifier are replaced with identities,
    then the output goes through BatchNorm2d -> dropout -> global average
    pooling -> linear -> BatchNorm1d. When labels are supplied the embedding
    is additionally passed through the ArcMarginProduct head.

    Args:
        arch: model name handed to ``timm.create_model``.
        dim: width of the embedding produced by the linear layer.
        num_classes: number of outputs of the margin head.
        dropout: dropout probability applied before pooling.
        pretrained: forwarded to ``timm.create_model``.
    """

    def __init__(self, arch="efficientnet_b3",
                 dim=2048, num_classes=11014, dropout=0.0, pretrained=True):
        super().__init__()
        backbone = timm.create_model(arch, pretrained=pretrained)
        n_channels = backbone.classifier.in_features
        # Neutralise the stock head; pooling is done by self.pooling instead.
        backbone.classifier = nn.Identity()
        backbone.global_pool = nn.Identity()
        self.backbone = backbone

        self.bn1 = nn.BatchNorm2d(n_channels)
        self.dropout = nn.Dropout(p=dropout)
        self.pooling = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(n_channels, dim)
        self.bn2 = nn.BatchNorm1d(dim)

        self.margin = ArcMarginProduct(in_features=dim,
                                       out_features=num_classes)
        self._init_params()

    def _init_params(self):
        """Reset the neck: Xavier weights / zero bias for fc, identity BN affine."""
        nn.init.xavier_uniform_(self.fc.weight)
        nn.init.zeros_(self.fc.bias)
        for norm in (self.bn1, self.bn2):
            nn.init.ones_(norm.weight)
            nn.init.zeros_(norm.bias)

    def forward(self, x, labels=None):
        """Return the embedding, or the margin head's output when labels are given."""
        fmap = self.dropout(self.bn1(self.backbone(x)))
        # Collapse the pooled map to one row per sample.
        pooled = self.pooling(fmap).view(x.size(0), -1)
        embedding = self.bn2(self.fc(pooled))
        if labels is None:
            return embedding
        return self.margin(embedding, labels)

In [84]:
def extract_feat(net, valid_loader):
    """Run ``net`` over every batch of ``valid_loader`` and collect its outputs.

    Args:
        net: module called as ``net(image)``; it is switched to eval mode.
        valid_loader: iterable of batch dicts with ``'index'`` and
            ``'image'`` entries that also exposes a sized ``dataset``.

    Returns:
        numpy array with the per-batch outputs concatenated along axis 0,
        in loader order.

    Raises:
        AssertionError: if the number of indices seen differs from
            ``len(valid_loader.dataset)``.
    """
    # Follow the model's own device rather than assuming CUDA, so the
    # function also runs when the model was placed on the CPU. A model
    # without parameters leaves the inputs where they are.
    first_param = next(net.parameters(), None)
    target_device = first_param.device if first_param is not None else None

    features = []
    valid_num = 0
    net.eval()
    with torch.no_grad():
        for batch in valid_loader:
            image = batch['image']
            if target_device is not None:
                image = image.to(target_device)
            feat = net(image)
            # Park results on the CPU so accelerator memory is not accumulated.
            features.append(feat.detach().cpu())
            valid_num += len(batch['index'])
    assert (valid_num == len(valid_loader.dataset))
    return torch.cat(features).numpy()

In [88]:
def find_matches(posting_ids, threshold, features, n_batches, min_indices=1):
    """List, for every row, the posting ids whose similarity exceeds a threshold.

    Similarity is the dot product between feature rows, computed chunk by
    chunk so the full row-by-row matrix is never held at once.

    Args:
        posting_ids: ids aligned row-for-row with ``features`` (a pandas
            Series, list or array).
        threshold: cut-off in percent; it is divided by 100 before use.
        features: 2-D array with one feature row per posting.
        n_batches: number of row chunks; the last chunk takes the remainder.
        min_indices: when fewer than this many rows pass the threshold, the
            ``min_indices`` highest-scoring rows are returned instead.

    Returns:
        list with one list of matching posting ids per input row.
    """
    assert len(posting_ids) == len(features)
    # Index positionally: an integer-array lookup on a pandas Series is
    # label-based and fails for a non-default index, and a list cannot be
    # indexed by an array at all.
    ids = np.asarray(posting_ids)
    sim_threshold = threshold / 100
    y_pred = []
    n_rows = features.shape[0]
    bs = n_rows // n_batches
    for i in range(n_batches):
        left = bs * i
        right = n_rows if i == n_batches - 1 else bs * (i + 1)
        dot_product = features[left:right, :] @ features.T
        selection = dot_product > sim_threshold
        for scores, mask in zip(dot_product, selection):
            # mask marks the similar images that pass the threshold.
            if np.sum(mask) < min_indices:
                picked = np.argsort(scores)[-min_indices:]
            else:
                picked = mask
            y_pred.append(ids[picked].tolist())
    assert len(y_pred) == len(ids)
    return y_pred

In [75]:
# Wrap the test metadata in a dataset and batch it in file order, so that
# extracted features stay aligned with the rows of df.
test_dataset = ShopeeDataset(df)
test_dataloader = data.DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_fn,
)

In [85]:
# Build the network and place it on the selected device.
net = ShopeeImgNet()
net.to(device)
# Deserialise onto the CPU; load_state_dict copies the values into the
# model's existing tensors, wherever those live.
state_dict = torch.load(initial_checkpoint, map_location="cpu")['state_dict']
# strict=False tolerates key mismatches, which would otherwise leave layers
# at their initial values without any signal. Report them explicitly.
load_result = net.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print("checkpoint/model key mismatch:")
    print("  missing keys   :", load_result.missing_keys)
    print("  unexpected keys:", load_result.unexpected_keys)
net.eval()
del state_dict

In [111]:
# Extract embeddings for every test image, then L2-normalise each row so
# that dot products between rows are cosine similarities.
raw_features = extract_feat(net, test_dataloader)
features = F.normalize(torch.from_numpy(raw_features), p=2, dim=1).numpy()

In [112]:
# Match every posting against all others using the configured threshold.
posting_ids = df.loc[:, 'posting_id']
y_pred = find_matches(
    posting_ids=posting_ids,
    threshold=threshold,
    features=features,
    n_batches=10,
)

In [113]:
# Store each row's matches as one space-separated string of posting ids.
df['matches'] = [" ".join(matched_ids) for matched_ids in y_pred]

In [114]:
# Write only the id and matches columns, without the dataframe index.
df.to_csv('submission.csv', columns=['posting_id', 'matches'], index=False)

In [115]:
# Read the written file back and show its first rows as a sanity check.
pd.read_csv('submission.csv').head()

Unnamed: 0,posting_id,matches
0,test_2255846744,test_2255846744
1,test_3588702337,test_3588702337
2,test_4015706929,test_4015706929
