In [1]:
import numpy as np
import pandas as pd

import os
import warnings
import cv2
import timm  # used for pretrained models

import albumentations  # used for image augmentations
from albumentations.pytorch.transforms import ToTensorV2

import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.optim.optimizer import Optimizer
from torch.optim.lr_scheduler import _LRScheduler

import math

from tqdm import tqdm
from sklearn.preprocessing import LabelEncoder

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class Config:
    """Static settings shared by the cells of this notebook."""

    # --- input locations ---
    TRAIN_CSV = "../data/train.csv"  # read with pd.read_csv
    DATA_DIR = "../data/train_images"  # root directory used by ShopeeDataset
    MODEL_PATH = "model_14_3.95628162292398.pt"  # checkpoint loaded into ShopeeModel

    # --- image preprocessing ---
    IMG_SIZE = 512  # square side length used by the Resize transform
    MEAN = [0.485, 0.456, 0.406]  # per-channel mean
    STD = [0.229, 0.224, 0.225]  # per-channel std

    # --- batching / hardware ---
    BATCH_SIZE = 64
    NUM_WORKERS = 16
    DEVICE = "cuda"
    EPOCHS = 15  # Try 15 epochs

    # --- ArcFace head (defaults of ShopeeModel) ---
    CLASSES = 11014
    SCALE = 30
    MARGIN = 0.5

    # --- backbone ---
    MODEL_NAME = "eca_nfnet_l0"
    FC_DIM = 512

In [3]:
def visualize(df, index, col, img_dir=None):
    """Plot one query image next to the images predicted for it.

    The figure is a 5x5 grid with every axis frame hidden. The query image
    sits in the top-left cell; the rest of the first row and first column
    stay empty, and the inner 4x4 cells show up to 16 predictions in order,
    filled row by row.

    Args:
        df: DataFrame with ``posting_id``, ``image`` and ``title`` columns
            plus the prediction column ``col``.
        index: Positional (``iloc``) index of the query row.
        col: Column whose cells hold a list of predicted posting_ids.
        img_dir: Directory containing the image files. Defaults to
            ``Config.DATA_DIR``, the directory ``ShopeeDataset`` reads from.

    NOTE(review): this function relies on ``plt`` (matplotlib.pyplot) and
    ``Image`` (PIL) being in scope; neither is imported in the visible
    import cell, so calling it as-is raises NameError -- confirm and add
    the imports to the import cell.
    """
    if img_dir is None:
        img_dir = Config.DATA_DIR

    def _wrap(text, width=10):
        # Break long titles into fixed-width lines so they fit above a cell.
        return "\n".join(text[k : k + width] for k in range(0, len(text), width))

    rows, cols = 5, 5
    inner_cols = cols - 1
    max_shown = (rows - 1) * inner_cols

    row = df.iloc[index]
    # Only open the images that will actually be drawn.
    shown_preds = list(row[col])[:max_shown]

    _, ax = plt.subplots(rows, cols, figsize=(20, 20))
    for axis in ax.ravel():
        axis.axis("off")

    ax[0, 0].imshow(Image.open(os.path.join(img_dir, row.image)))
    ax[0, 0].set_title(_wrap(row.title), fontsize=12)

    for k, pred in enumerate(shown_preds):
        match = df[df.posting_id == pred].iloc[0]
        # Index with the inner grid width (cols - 1): using `cols` here would
        # skip every fifth prediction.
        cell = ax[1 + k // inner_cols, 1 + k % inner_cols]
        cell.imshow(Image.open(os.path.join(img_dir, match.image)))
        cell.set_title(_wrap(match.title), fontsize=12)


def compute_f1(col):
    """Build a row-wise F1 scorer for the predictions stored under ``col``.

    The returned function takes a row exposing ``row["label"]`` (ground-truth
    ids) and ``row[col]`` (predicted ids) and returns
    ``2 * |label ∩ pred| / (len(label) + len(pred))``. The intersection is
    taken over unique values (``np.intersect1d``).

    If either side is empty the score is 0.0; previously two empty sides
    raised ZeroDivisionError.
    """

    def f1(row):
        truth, preds = row["label"], row[col]
        if len(truth) == 0 or len(preds) == 0:
            return 0.0
        n = len(np.intersect1d(truth, preds))
        return 2 * n / (len(truth) + len(preds))

    return f1


def compute_recall(col):
    """Build a row-wise recall scorer for the predictions stored under ``col``.

    The returned function takes a row exposing ``row["label"]`` (ground-truth
    ids) and ``row[col]`` (predicted ids) and returns
    ``|label ∩ pred| / len(label)``, with the intersection taken over unique
    values (``np.intersect1d``).

    If either side is empty the score is 0.0; previously an empty label list
    raised ZeroDivisionError.
    """

    def recall(row):
        truth, preds = row["label"], row[col]
        if len(truth) == 0 or len(preds) == 0:
            return 0.0
        n = len(np.intersect1d(truth, preds))
        return n / len(truth)

    return recall


def compute_precision(col):
    """Build a row-wise precision scorer for the predictions stored under ``col``.

    The returned function takes a row exposing ``row["label"]`` (ground-truth
    ids) and ``row[col]`` (predicted ids) and returns
    ``|label ∩ pred| / len(pred)``, with the intersection taken over unique
    values (``np.intersect1d``).

    If either side is empty the score is 0.0; previously an empty prediction
    list raised ZeroDivisionError.
    """

    def precision(row):
        truth, preds = row["label"], row[col]
        if len(truth) == 0 or len(preds) == 0:
            return 0.0
        n = len(np.intersect1d(truth, preds))
        return n / len(preds)

    return precision


def compute_precision_K(col, K):
    """Build a row-wise precision@K scorer for the predictions under ``col``.

    The returned function counts the unique ids shared by ``row["label"]``
    and the first ``K`` entries of ``row[col]``, then divides by ``K``
    (not by the number of predictions actually available).
    """

    def precision(row):
        head = row[col][:K]
        shared = np.intersect1d(row["label"], head)
        return len(shared) / K

    return precision


def compute_AP(col, N):
    """Build a row-wise average-precision scorer for the predictions under ``col``.

    For the first ``max_n = min(len(row[col]), N)`` predictions, the returned
    function sums precision@i over every rank ``i`` whose prediction is in
    ``row["label"]`` and divides that sum by ``max_n``.

    NOTE(review): the sum is normalised by the number of predictions
    considered, not by the number of relevant items as in the textbook
    AP definition. That behaviour is kept as-is so reported numbers stay
    comparable -- confirm it is intended.

    The hit count is accumulated in a single pass instead of recomputing an
    intersection for every rank. The score is 0.0 when there is nothing to
    score (no predictions, no hits, or ``N <= 0``).
    """

    def AP(row):
        preds = row[col]
        max_n = min(len(preds), N)
        if max_n <= 0:
            return 0.0

        relevant = set(row["label"])
        # Distinct relevant ids seen so far; a set keeps the count identical to
        # an intersection over unique values even if predictions repeat.
        hits_so_far = set()
        precision_sum = 0.0
        for rank, pred in enumerate(preds[:max_n], start=1):
            if pred in relevant:
                hits_so_far.add(pred)
                precision_sum += len(hits_so_far) / rank
        return precision_sum / max_n

    return AP


def retrieval(embs, df, chunk_size=4096, threshold=None, topK=None, device=None):
    """Retrieve, for every embedding, the posting_ids of its most similar items.

    Similarity is the dot product between embedding rows, computed in chunks
    of ``chunk_size`` query rows against the full matrix. Rows should be
    L2-normalised beforehand if cosine similarity is intended.

    Args:
        embs: 2-D array of embeddings, one row per row of ``df``.
        df: DataFrame whose ``posting_id`` column is aligned with ``embs``.
        chunk_size: Number of query rows scored per matrix multiplication.
        threshold: Keep every item whose similarity is strictly greater than
            this value, ordered by decreasing similarity.
        topK: Keep the ``topK`` most similar items (clamped to the number of
            embeddings), ordered by decreasing similarity.
        device: Torch device for the similarity computation. Defaults to CUDA
            when available, otherwise CPU.

    Exactly one of ``threshold`` and ``topK`` must be given.

    Returns:
        A list with one list of posting_ids per row of ``embs``.
    """
    assert (
        threshold is not None or topK is not None
    ), "Either threshold or topK should be provided"
    assert (
        threshold is None or topK is None
    ), "Only one of threshold or topK should be provided"
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    embs_pt = torch.as_tensor(embs, device=device)
    num_items = embs_pt.shape[0]

    num_chunks = (num_items + chunk_size - 1) // chunk_size
    posting_id = df.posting_id.to_list()
    topk_posting_id = []

    print(f"Chunk Size: {chunk_size}, {num_chunks} chunks")

    # torch.topk raises when k exceeds the number of candidates.
    k = min(topK, num_items) if topK is not None else None

    sim = None
    try:
        for chunk_idx in tqdm(range(num_chunks)):
            start = chunk_idx * chunk_size
            end = min(start + chunk_size, num_items)
            sim = embs_pt[start:end] @ embs_pt.T

            if topK is not None:
                indices = torch.topk(sim, k, dim=1).indices.cpu().numpy()
                topk_posting_id.extend(
                    [[posting_id[j] for j in row] for row in indices]
                )
            else:
                # Copy the chunk to host memory once; doing it per row forces
                # one device synchronisation for every query.
                sim_np = sim.cpu().numpy()
                for row_sim in sim_np:
                    hits = np.flatnonzero(row_sim > threshold)
                    # Stable sort keeps tie ordering deterministic.
                    order = np.argsort(-row_sim[hits], kind="stable")
                    topk_posting_id.append([posting_id[j] for j in hits[order]])
    finally:
        # clean up, also when a chunk fails part-way through
        del embs_pt
        sim = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return topk_posting_id

In [4]:
class ShopeeDataset(Dataset):
    """Dataset yielding RGB images and their ``label_group`` from a DataFrame.

    Args:
        df: DataFrame with an ``image`` column (file name relative to the
            image directory) and a ``label_group`` column.
        transform: Optional callable invoked as ``transform(image=array)``
            and returning a mapping with an ``"image"`` entry.
        root_dir: Directory containing the image files. Defaults to
            ``Config.DATA_DIR``.
    """

    def __init__(self, df, transform=None, root_dir=None):
        self.df = df
        self.root_dir = Config.DATA_DIR if root_dir is None else root_dir
        self.transform = transform
        self.length = len(df)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return ``{"image": ..., "label": LongTensor}`` for row ``idx``.

        Raises:
            FileNotFoundError: the image file does not exist.
            ValueError: the file exists but OpenCV could not decode it.
        """
        row = self.df.iloc[idx]

        img_path = os.path.join(self.root_dir, row.image)
        if not os.path.isfile(img_path):
            raise FileNotFoundError(f"Image not found: {img_path}")

        # cv2.imread signals failure by returning None instead of raising.
        image = cv2.imread(img_path)
        if image is None:
            raise ValueError(f"Could not decode image: {img_path}")
        # OpenCV loads BGR; convert to RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transform:
            image = self.transform(image=image)["image"]

        label = torch.tensor(row.label_group, dtype=torch.long)
        return {"image": image, "label": label}


def get_test_transforms():
    """Build the deterministic preprocessing pipeline used at inference time.

    Steps, in order: resize to ``Config.IMG_SIZE`` x ``Config.IMG_SIZE``,
    normalise with ``Config.MEAN`` / ``Config.STD``, convert to a tensor
    with ``ToTensorV2``.
    """
    return albumentations.Compose(
        [
            # `always_apply` is deprecated in recent albumentations releases;
            # p=1.0 expresses the same thing on every version.
            albumentations.Resize(Config.IMG_SIZE, Config.IMG_SIZE, p=1.0),
            # Pass the configured statistics explicitly instead of relying on
            # the library defaults matching them.
            albumentations.Normalize(mean=Config.MEAN, std=Config.STD),
            ToTensorV2(p=1.0),
        ]
    )

In [5]:
class Mish_func(torch.autograd.Function):
    """Autograd function for Mish, ``f(x) = x * tanh(softplus(x))``,
    with a hand-written backward pass."""

    @staticmethod
    def forward(ctx, i):
        # Only the input is stored; backward recomputes what it needs.
        ctx.save_for_backward(i)
        return i * torch.tanh(F.softplus(i))

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]

        # h = softplus(x) = log(1 + exp(x)), written out explicitly.
        one_plus_exp = 1.0 + torch.exp(x)
        softplus_x = torch.log(one_plus_exp)

        # d tanh(h) / dh = 1 / cosh(h)^2
        dtanh_dh = 1.0 / torch.cosh(softplus_x).pow(2)

        # dh / dx = exp(x) / (1 + exp(x)) = sigmoid(x)
        dh_dx = torch.sigmoid(x)

        dtanh_dx = dtanh_dh * dh_dx

        # Product rule on x * tanh(softplus(x)).
        local_grad = torch.tanh(F.softplus(x)) + x * dtanh_dx

        return grad_output * local_grad


class Mish(nn.Module):
    """``nn.Module`` wrapper around ``Mish_func``.

    Any keyword arguments passed to the constructor are accepted and ignored.
    """

    def __init__(self, **kwargs):
        super(Mish, self).__init__()
        print("Mish initialized")

    def forward(self, input_tensor):
        activated = Mish_func.apply(input_tensor)
        return activated


def replace_activations(model, existing_layer, new_layer):
    """Recursively replace every sub-module of type ``existing_layer``.

    Args:
        model: Module whose descendants are rewritten in place.
        existing_layer: Class to look for. The match is on the exact type
            (``type(module) is existing_layer``); subclasses are not replaced.
        new_layer: Module instance to install. The same instance is shared
            by every replaced position.

    Returns:
        ``model`` itself, after modification.
    """
    # Snapshot the items so entries can be reassigned while looping.
    for name, module in list(model._modules.items()):
        # `_modules` may hold None placeholders; they have no children and
        # can never match, so skip them instead of crashing.
        if module is None:
            continue

        if next(module.children(), None) is not None:
            replace_activations(module, existing_layer, new_layer)

        if type(module) is existing_layer:
            model._modules[name] = new_layer
    return model

In [6]:
class ArcMarginProduct(nn.Module):
    """
    ArcFace Layer, which can be directly integrated into the last layer of the network.

    Args:
        in_features: Size of each input embedding.
        out_features: Number of classes.
        scale: Factor the final logits are multiplied by.
        margin: Angular margin (radians) added to the target-class angle.
        easy_margin: If True, apply the margin only where ``cosine > 0`` and
            use the plain cosine elsewhere.
        ls_eps: Label-smoothing factor applied to the one-hot targets that
            blend the margin and plain logits. ``0.0`` disables smoothing.
    """

    def __init__(
        self,
        in_features,
        out_features,
        scale=30.0,
        margin=0.50,
        easy_margin=False,
        ls_eps=0.0,
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.scale = scale
        self.margin = margin
        self.ls_eps = ls_eps
        # Uninitialised storage; filled by reset_parameters() below.
        self.weight = nn.Parameter(torch.empty(out_features, in_features))

        self.easy_margin = easy_margin
        self.cos_m = math.cos(margin)
        self.sin_m = math.sin(margin)
        # Beyond this cosine, theta + margin would exceed pi.
        self.th = math.cos(math.pi - margin)
        self.mm = math.sin(math.pi - margin) * margin

        self.ce = nn.CrossEntropyLoss()

        self.reset_parameters()

    def forward(self, input, label):
        """Return ``(scaled_logits, cross_entropy_loss)`` for a batch.

        ``input`` is a batch of embeddings, one row per sample, and ``label``
        holds one class index per sample.
        """
        # --------------------------- cos(theta) & phi(theta) ---------------------------
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        # Rounding can push cosine^2 marginally above 1; clamp so sqrt never
        # sees a negative number and returns NaN.
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0.0, 1.0))
        # cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m)
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        # --------------------------- convert label to one-hot ---------------------------
        # Allocate on the device of the logits instead of assuming CUDA.
        one_hot = torch.zeros(cosine.size(), device=cosine.device)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        if self.ls_eps > 0:
            one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.out_features

        # Margin logit for the target class, plain cosine for the others.
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output = output * self.scale

        return output, self.ce(output, label)

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.weight)


class ShopeeModel(nn.Module):
    """timm backbone followed by pooling, an optional FC + BatchNorm neck and
    an ArcFace head.

    Args:
        n_classes: Number of classes for the ArcFace head.
        model_name: timm model name. Supported: ``"resnext50_32x4d"`` and any
            name containing ``"efficientnet"`` or ``"nfnet"``.
        fc_dim: Output size of the neck (used when ``use_fc`` is True).
        margin: ArcFace margin.
        scale: ArcFace scale.
        use_fc: Whether to add the dropout -> linear -> batch-norm neck.
        pretrained: Forwarded to ``timm.create_model``.

    Raises:
        ValueError: ``model_name`` is not one of the supported families.
    """

    def __init__(
        self,
        n_classes=Config.CLASSES,
        model_name=Config.MODEL_NAME,
        fc_dim=Config.FC_DIM,
        margin=Config.MARGIN,
        scale=Config.SCALE,
        use_fc=True,
        pretrained=True,
    ):
        super().__init__()
        print(f"Building Model Backbone for {model_name} model")

        self.backbone = timm.create_model(model_name, pretrained=pretrained)

        # Strip the backbone's own classifier and pooling so it returns a
        # feature map; pooling is done by self.pooling below.
        if model_name == "resnext50_32x4d":
            final_in_features = self.backbone.fc.in_features
            self.backbone.fc = nn.Identity()
            self.backbone.global_pool = nn.Identity()

        elif "efficientnet" in model_name:
            final_in_features = self.backbone.classifier.in_features
            self.backbone.classifier = nn.Identity()
            self.backbone.global_pool = nn.Identity()

        elif "nfnet" in model_name:
            final_in_features = self.backbone.head.fc.in_features
            self.backbone.head.fc = nn.Identity()
            self.backbone.head.global_pool = nn.Identity()

        else:
            # Without this branch the code below fails with an
            # UnboundLocalError on `final_in_features`.
            raise ValueError(
                f"Unsupported model_name {model_name!r}: expected "
                "'resnext50_32x4d' or a name containing 'efficientnet' or 'nfnet'"
            )

        self.pooling = nn.AdaptiveAvgPool2d(1)

        self.use_fc = use_fc

        if use_fc:
            self.dropout = nn.Dropout(p=0.1)
            self.fc = nn.Linear(final_in_features, fc_dim)
            self.bn = nn.BatchNorm1d(fc_dim)
            self.reset_parameters()
            final_in_features = fc_dim

        self.final = ArcMarginProduct(
            final_in_features, n_classes, scale=scale, margin=margin
        )

    def forward(self, image, label=None):
        """Return the embedding for ``image``.

        ``label`` is accepted for call compatibility but not used: the
        ArcFace head (``self.final``) is not applied here.
        """
        return self.extract_features(image)

    def reset_parameters(self):
        """Initialise the neck layers; only valid when ``use_fc`` is True."""
        nn.init.xavier_normal_(self.fc.weight)
        nn.init.constant_(self.fc.bias, 0)
        nn.init.constant_(self.bn.weight, 1)
        nn.init.constant_(self.bn.bias, 0)

    def extract_features(self, x):
        """Run backbone -> average pooling -> optional neck, one row per sample."""
        batch_size = x.shape[0]
        x = self.backbone(x)
        x = self.pooling(x).view(batch_size, -1)

        if self.use_fc:
            x = self.dropout(x)
            x = self.fc(x)
            x = self.bn(x)
        return x

In [7]:
df = pd.read_csv(Config.TRAIN_CSV)

# load_state_dict below is strict and overwrites every weight, so there is no
# point in downloading the timm pretrained weights first.
model = ShopeeModel(model_name=Config.MODEL_NAME, pretrained=False)
model = replace_activations(model, torch.nn.SiLU, Mish())

# Load onto the CPU regardless of where the checkpoint was saved; the model is
# moved to the GPU later.
state_dict = torch.load(Config.MODEL_PATH, map_location="cpu", weights_only=True)
load_result = model.load_state_dict(state_dict)

# Switch to inference mode only after the activation swap, so the replacement
# modules are covered too.
model.eval()

load_result

Building Model Backbone for eca_nfnet_l0 model
Mish initialized


<All keys matched successfully>

In [8]:
dataset = ShopeeDataset(df, transform=get_test_transforms())
loader = DataLoader(
    dataset,
    batch_size=Config.BATCH_SIZE,
    num_workers=Config.NUM_WORKERS,
    shuffle=False,  # keep embeddings aligned with the rows of df
    pin_memory=True,
)

# Wrap under a separate name: rebinding `model` would nest another
# DataParallel wrapper every time this cell is re-run.
device = torch.device(Config.DEVICE)
embed_model = nn.DataParallel(model).to(device)
embed_model.eval()

embs = []
with torch.no_grad():
    for batch in tqdm(loader):
        # non_blocking pairs with pin_memory=True above.
        img = batch["image"].to(device, non_blocking=True)
        label = batch["label"].to(device, non_blocking=True)
        feat = embed_model(img, label)
        embs.append(feat.cpu().numpy())

image_embeddings = np.concatenate(embs)

  0%|          | 0/536 [00:00<?, ?it/s]

100%|██████████| 536/536 [04:46<00:00,  1.87it/s]


In [10]:
# Ground truth per row: the unique posting_ids of every row sharing its
# label_group (the row's own posting_id included).
tmp = df.groupby("label_group")["posting_id"].unique().to_dict()
df["label"] = df["label_group"].map(tmp)

In [14]:
RETRIEVAL_THRESHOLD = 0.5  # minimum similarity for a posting to count as a match
AP_CUTOFF = 50  # N passed to compute_AP

# normalize embeddings
# Clamp the norms away from zero so an all-zero embedding row stays zero
# instead of turning into NaN/inf.
norms = np.linalg.norm(image_embeddings, 2, axis=1, keepdims=True)
image_embeddings /= np.maximum(norms, np.finfo(image_embeddings.dtype).eps)

df["image_retrieval"] = retrieval(image_embeddings, df, threshold=RETRIEVAL_THRESHOLD)

metric_fns = {
    "f1": compute_f1("image_retrieval"),
    "recall": compute_recall("image_retrieval"),
    "precision": compute_precision("image_retrieval"),
    "AP": compute_AP("image_retrieval", AP_CUTOFF),
}
for metric_name, metric_fn in metric_fns.items():
    df[metric_name] = df.apply(metric_fn, axis=1)

# Mean of each metric, in the order f1, recall, precision, AP.
tuple(df[metric_name].mean() for metric_name in metric_fns)

Chunk Size: 4096, 9 chunks


100%|██████████| 9/9 [00:09<00:00,  1.07s/it]


(0.8663392386832592, 0.906322577113404, 0.8843458494742749, 0.8761870995071006)