In [1]:
from imports import *
from utils import load_data
from image_train.data import create_dl, ImageDS
from image_train.model import EMBRes
from image_train.arcface import ArcMarginProduct
from image_train.train import get_hparams, train
# Seed every RNG the notebook draws from: NumPy drives the train/val permutation,
# torch drives weight initialisation (e.g. xavier_uniform_ in the ArcFace head).
np.random.seed(1337)
torch.manual_seed(1337)
device = torch.device('cuda')

In [2]:
import warnings
# Silences every warning category for the rest of the session.
# NOTE(review): this also hides warnings that may point at real problems
# (e.g. pandas chained-assignment warnings) — consider narrowing the filter.
warnings.filterwarnings('ignore')

In [3]:
# Load the full frame together with a train/validation split and the label sets of each side.
# train_perc=0.3 is forwarded to the project-local utils.load_data.
# NOTE(review): presumably the split is done on image_phash as in the snippet shown below — confirm in utils.
df, train_df, val_df, train_labels, val_labels = load_data(train_perc=0.3)

df = pd.read_csv('data/train.csv')
df['label_group'] = df['label_group'].astype('category').cat.codes
df['image_phash'] = df['image_phash'].astype('category').cat.codes
np.random.seed(1337)

train val split
ph = np.random.permutation(df['image_phash'].unique())

train_perc = 0.3
train_idx = int(train_perc * len(ph))

train_labels = ph[:train_idx]
val_labels = ph[train_idx:]

train_df = df[df['image_phash'].isin(train_labels)]
val_df = df[df['image_phash'].isin(val_labels)]

In [4]:
# creating dataloaders
small_images_dir = 'data/small_train_images/'

# Shuffled loader for optimisation; the remaining loaders keep frame order so that
# embeddings computed from them can be matched back to dataframe rows by position.
tr_dl = create_dl(train_df, small_images_dir, batch_size=64)
tr_test_dl = create_dl(train_df, small_images_dir, shuffle=False)
val_dl = create_dl(val_df, small_images_dir, shuffle=False)
# The full loader must cover every row of df; it was previously built from val_df,
# which made it an exact duplicate of val_dl.
full_dl = create_dl(df, small_images_dir, shuffle=False)

***Embedding normalization is done in the ArcFace metric, not in the model.***

In [5]:
vision_model = 'resnet50'
# num_classes=0 removes the classifier so the backbone returns pooled features.
model = timm.create_model(vision_model, pretrained=True, num_classes=0).to(device)
# Size the ArcFace head from the backbone's feature width instead of a hard-coded 512,
# which does not match the 2048-wide heads built for this same model further down.
metric_fc = ArcMarginProduct(model.num_features, df['label_group'].nunique(), s=30, m=0.5, easy_margin=False).to(device)

In [6]:
# Build the training configuration (epoch count, loss, parameter groups, optimizer, scheduler
# and the train/validation transforms) via the project-local get_hparams, for 5 epochs at lr=1e-4.
# NOTE(review): the optimizer is created from the metric_fc defined above; heads created later
# in the notebook are presumably not registered with it — confirm before training with them.
n_epochs, lf, params, optimizer, sched, train_transforms, val_transforms = get_hparams(tr_dl, model, metric_fc, n_epochs=5, lr=1e-4)

In [8]:
def compute_centers(dataloader, model, val_transforms, dataframe):
    """Compute one mean L2-normalised embedding ("center") per label group.

    Args:
        dataloader: iterable yielding ``(imgs, _)`` batches. It must visit the rows
            of ``dataframe`` in order (no shuffling, no dropped samples) because
            embeddings are matched to labels purely by position.
        model: module mapping an image batch to a 2-D ``(batch, dim)`` feature tensor.
            Its train/eval mode is left exactly as the caller set it.
        val_transforms: callable applied to every image batch before the model.
        dataframe: frame with a ``label_group`` column, one row per sample.
            It is only read; no columns are added or re-encoded.

    Returns:
        Tensor of shape ``(n_labels, dim)`` on the model's device. Row ``i`` is the
        mean normalised embedding of the ``i``-th label in sorted label order, which
        is the same order as ``astype('category').cat.codes`` assigns.

    Raises:
        ValueError: if the dataloader yields no batches, or if the number of
            embeddings differs from ``len(dataframe)``.
    """
    # Follow the model's device instead of hard-coding 'cuda'.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    batches = []
    with torch.no_grad():
        for imgs, _ in tqdm(dataloader):
            imgs = val_transforms(imgs).to(device)
            # Keep embeddings on the CPU so the whole dataset does not sit in GPU memory.
            batches.append(model(imgs).cpu())
    if not batches:
        raise ValueError('dataloader produced no batches; cannot compute centers')
    embs = F.normalize(torch.cat(batches, 0))

    if embs.shape[0] != len(dataframe):
        raise ValueError(
            f'got {embs.shape[0]} embeddings for {len(dataframe)} dataframe rows; '
            'the dataloader must iterate the dataframe in order without dropping samples'
        )

    # group_ids[r] is the 0-based rank of row r's label among the sorted unique labels.
    unique_labels, group_ids = np.unique(dataframe['label_group'].to_numpy(), return_inverse=True)
    group_ids = torch.from_numpy(group_ids).long().reshape(-1)
    n_groups = len(unique_labels)

    # Per-group mean = per-group sum / per-group count, computed without a Python loop.
    sums = embs.new_zeros(n_groups, embs.shape[1])
    sums.index_add_(0, group_ids, embs)
    counts = torch.bincount(group_ids, minlength=n_groups).to(embs.dtype).unsqueeze(1)
    return (sums / counts).to(device)


In [9]:
# compute_centers reads the 'label_group' column of a per-sample frame, so it needs
# train_df (the frame tr_test_dl iterates over), not the grouped label_indxs table,
# which has no such column and is only defined in a later cell.
centers = compute_centers(tr_test_dl, model, val_transforms, train_df)

HBox(children=(FloatProgress(value=0.0, max=164.0), HTML(value='')))




In [7]:
# NOTE(review): this cell was executed as In [7] but sits below the In [9] cell that uses
# label_indxs; on Restart & Run All it would run too late — consider moving it up.
# Re-encode label_group as contiguous integer codes over the training rows only.
# NOTE(review): this mutates train_df in place; if train_df is a slice of df, pandas would
# normally emit a SettingWithCopyWarning here (warnings are silenced above) — confirm.
train_df['label_group'] = train_df['label_group'].astype('category').cat.codes
# Positional row number of every sample, used to index embeddings computed in frame order.
train_df['indx'] = range(len(train_df))
# One row per label holding the array of positional indices of its samples.
label_indxs = train_df.groupby('label_group').agg({'indx':'unique'})

In [10]:
# https://github.com/ronghuaiyang/arcface-pytorch

class ArcMarginProduct(nn.Module):
    r"""Additive angular margin (ArcFace) head producing ``s * cos(theta + m)`` logits.

    Both the input features and the class weights are L2-normalised inside
    ``forward``, so the layer works on cosine similarities.

    Args:
        in_features: size of each input sample.
        out_features: number of classes (rows of the weight matrix).
        s: scale applied to the final logits.
        m: angular margin added to the target-class angle.
        easy_margin: if True, apply the margin only where ``cos(theta) > 0``;
            otherwise use the ``cos(theta) - m * sin(m)`` fallback once
            ``theta + m`` would exceed ``pi``.
        centers: optional ``(out_features, in_features)`` tensor used as the
            initial class weights instead of Xavier-uniform random values.
            The tensor is copied, so training never modifies the caller's data,
            and the weight stays on the device ``centers`` lives on.

    Raises:
        ValueError: if ``centers`` is given with a shape other than
            ``(out_features, in_features)``.
    """
    def __init__(self, in_features, out_features, s=30.0, m=0.50, easy_margin=False, centers=None):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        if centers is None:
            print('Using random weights')
            self.weight = Parameter(torch.empty(out_features, in_features))
            nn.init.xavier_uniform_(self.weight)
        else:
            if tuple(centers.shape) != (out_features, in_features):
                raise ValueError(
                    f'centers has shape {tuple(centers.shape)}, '
                    f'expected ({out_features}, {in_features})'
                )
            print('Using center as wieghts')
            # detach + clone: the Parameter must not share storage with the caller's tensor.
            self.weight = Parameter(centers.detach().clone())

        self.easy_margin = easy_margin
        # Constants for cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m).
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        # Below this cosine, theta + m would pass pi and cos(theta + m) stops being monotonic.
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, input, label):
        """Return scaled logits with the angular margin applied to each sample's target class.

        Args:
            input: ``(batch, in_features)`` feature tensor.
            label: ``(batch,)`` tensor of target class indices.

        Returns:
            ``(batch, out_features)`` tensor of logits scaled by ``s``.
        """
        # --------------------------- cos(theta) & phi(theta) ---------------------------
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        # --------------------------- convert label to one-hot ---------------------------
        # Allocate on the same device/dtype as the logits instead of hard-coding 'cuda'.
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1, 1).long().to(cosine.device), 1)
        # Margin logit for the target class, plain cosine for every other class.
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s

        return output

In [11]:
def test_loss(model, optimizer, loss_func, sched, metric_fc, train_dl, val_dl, n_epochs, train_df, val_df,
          train_transforms, val_transforms, save_path, val_first=False):
    
    tr_losses = []
    tr_scores = []
    val_scores = []
    prev_best_f_val = -10
    prev_best_f_train = -10

    with torch.no_grad() :

        # TRAINING
        model.train()
        tr_loss = []
        embs = []
        pbar = tqdm(train_dl)
        for imgs, labels in pbar:

            imgs = train_transforms(imgs).to('cuda')

            optimizer.zero_grad()
            feature = model(imgs)
            labels = labels.long().to('cuda')
            out = metric_fc(feature, labels)
            loss = loss_func(out, labels)

            tr_loss.append(loss.item())
            pbar.set_description(f"Train loss: {round(np.mean(tr_loss),3)}")
       
        tr_losses.append(tr_loss)
        summary = f"Train loss {np.asarray(tr_loss).mean()} "
        print(summary)
    return (tr_losses, val_scores)

## With init centers

In [21]:
train_df.iloc[3767]

posting_id                                      train_1382500866
image                       5d075d7eaa258052ab125c75c06293d6.jpg
image_phash                                     838436c07dff19e4
title          RELIZA WALL STICKER PENGUKUR TINGGI BADAN JERA...
label_group                                                    0
indx                                                        3767
Name: 12367, dtype: object

In [12]:
# ArcFace head whose class weights start from the per-label embedding centers.
# One output class per distinct training label.
# NOTE(review): with centers given, in_features=2048 is only recorded on the module;
# presumably it matches centers.shape[1] (the backbone feature width) — confirm.
metric_fc = ArcMarginProduct(2048, train_df['label_group'].nunique(), s=30, m=0.5, easy_margin=False, centers=centers).to(device)

Using center as wieghts


In [13]:
# Forward-only pass over the training loader with the center-initialised head;
# test_loss runs under no_grad and never steps the optimizer, so weights are not trained.
# save_path and val_first are accepted by test_loss but not used by it.
test_loss(model, optimizer, lf, sched, metric_fc, tr_dl, val_dl, n_epochs, train_df, val_df, 
      train_transforms, val_transforms, save_path='data/tests_model_image/test', val_first=False)

HBox(children=(FloatProgress(value=0.0, max=164.0), HTML(value='')))


Train loss 9.41272031679386 


([[8.603670120239258,
   7.848506927490234,
   9.831758499145508,
   8.707134246826172,
   8.61577033996582,
   9.819419860839844,
   9.868816375732422,
   9.03233528137207,
   7.903177738189697,
   9.413893699645996,
   9.949143409729004,
   9.39145565032959,
   8.838394165039062,
   9.662611961364746,
   9.31142807006836,
   8.920430183410645,
   8.428309440612793,
   9.946556091308594,
   9.603825569152832,
   8.85608196258545,
   10.691878318786621,
   9.186503410339355,
   8.94955062866211,
   10.741680145263672,
   9.83850383758545,
   8.157082557678223,
   8.15455436706543,
   10.405105590820312,
   8.705581665039062,
   9.325439453125,
   9.830687522888184,
   10.278044700622559,
   8.848200798034668,
   9.725737571716309,
   8.430460929870605,
   9.535745620727539,
   9.94510555267334,
   8.894384384155273,
   9.122233390808105,
   10.573990821838379,
   8.85316276550293,
   9.455981254577637,
   9.150248527526855,
   9.93157958984375,
   11.295001029968262,
   9.8850879669189

In [14]:
# Baseline for comparison: same head configuration, but with Xavier-uniform random
# class weights (centers=None) instead of the precomputed embedding centers.
metric_fc = ArcMarginProduct(2048, train_df['label_group'].nunique(), s=30, m=0.5, easy_margin=False, centers=None).to(device)

Using random weights


In [15]:

# Same forward-only loss probe, now with the randomly initialised head.
# NOTE: the recorded run was interrupted (KeyboardInterrupt below), so no baseline
# loss is available for comparison with the center-initialised result.
test_loss(model, optimizer, lf, sched, metric_fc, tr_dl, val_dl, n_epochs, train_df, val_df, 
      train_transforms, val_transforms, save_path='data/tests_model_image/test', val_first=False)

HBox(children=(FloatProgress(value=0.0, max=164.0), HTML(value='')))

KeyboardInterrupt: 

### 