<a href="https://colab.research.google.com/github/greyhound101/shopee/blob/main/distil_swe.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import copy
import math
import pandas as pd
import numpy as np
from tqdm.autonotebook import tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split

import transformers
from transformers import (BertTokenizer, BertModel,
                          DistilBertTokenizer, DistilBertModel)

In [None]:


# Build the training subset: pick a set of label groups from a saved index
# array, then keep every training row that belongs to one of those groups.
le=LabelEncoder()
# `a` is used below as a row selector into the de-duplicated frame.
# NOTE(review): presumably a pre-computed stratified fold of group indices — confirm.
a=np.load('../input/stratify-counts/train_0.npy')
df=pd.read_csv('../input/shopee-product-matching/train.csv')
df['label_group']=le.fit_transform(df['label_group'])
# One row per label group, re-indexed 0..n-1 so `a` can address rows via .loc.
df=df.drop_duplicates('label_group').reset_index(drop=True)
# Encoded label-group ids of the selected rows.
labels=list(df.loc[a]['label_group'].values)
# Re-read the full training table; the frame above was reduced to one row per group.
train = pd.read_csv("../input/shopee-product-matching/train.csv")

# A second encoder fitted on the same column of the same file, so its codes
# line up with the ones stored in `labels`.
labelencoder= LabelEncoder()
train['label_group'] = labelencoder.fit_transform(train['label_group'])
# Keep only rows whose (encoded) label group was selected.
train=train.loc[train['label_group'].isin(labels)]
# Bare expression: shown as the cell output when run in a notebook.
train.shape    



In [None]:
# Word count per title (split on single spaces), used to eyeball a sensible
# tokenizer max_length.
words_per_title = train['title'].apply(lambda title: len(title.split(" ")))
title_lengths = words_per_title.to_numpy()
shortest, longest = title_lengths.min(), title_lengths.max()
print(f"MIN words: {shortest}, MAX words: {longest}")
plt.hist(title_lengths);

In [None]:
class CFG:
    """Static configuration shared by the cells in this notebook."""

    # --- backbone ---------------------------------------------------------
    DistilBERT = True  # if set to False, BERT model will be used
    bert_hidden_size = 768
    max_length = 30  # passed to the tokenizer as max_length
    dropout = 0.5  # not referenced by the code visible in this file

    # --- data loading -----------------------------------------------------
    batch_size = 64
    num_workers = 4

    # --- optimisation -----------------------------------------------------
    epochs = 30
    learning_rate = 1e-5  # 3e-5
    scheduler = "ReduceLROnPlateau"
    step = 'epoch'  # forwarded to one_epoch(step=...)
    patience = 2
    factor = 0.8

    # --- output -----------------------------------------------------------
    model_path = "/kaggle/working"
    model_save_name = "model.pt"

    # --- runtime ----------------------------------------------------------
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Choose the checkpoint name and the matching tokenizer/encoder classes,
# then load both from that checkpoint.
if CFG.DistilBERT:
    model_name = 'cahya/distilbert-base-indonesian'
    tokenizer_cls, encoder_cls = DistilBertTokenizer, DistilBertModel
else:
    model_name = 'cahya/bert-base-indonesian-522M'
    tokenizer_cls, encoder_cls = BertTokenizer, BertModel

tokenizer = tokenizer_cls.from_pretrained(model_name)
bert_model = encoder_cls.from_pretrained(model_name)

In [None]:
# Sanity check: tokenize one random title, decode it back, and run it through
# the encoder to inspect the output shape.
# randint's upper bound is exclusive, so len(train) (not len(train) - 1) is
# needed for every row, including the last one, to be eligible.
sample_idx = np.random.randint(len(train))
text = train['title'].values[sample_idx]
print(f"Text of the title: {text}")
encoded_input = tokenizer(text, return_tensors='pt')
print(f"Input tokens: {encoded_input['input_ids']}")
decoded_input = tokenizer.decode(encoded_input['input_ids'][0])
print(f"Decoded tokens: {decoded_input}")
# Inspection only: no gradients are needed for this forward pass.
with torch.no_grad():
    output = bert_model(**encoded_input)
print(f"last layer's output shape: {output.last_hidden_state.shape}")

In [None]:
# Re-encode the surviving label groups as contiguous class ids and record how
# many distinct classes the classification head must cover.
lbl_encoder = LabelEncoder()
lbl_encoder.fit(train['label_group'])
train['label_code'] = lbl_encoder.transform(train['label_group'])
NUM_CLASSES = len(lbl_encoder.classes_)

In [None]:
class TextDataset(torch.utils.data.Dataset):
    """Map-style dataset over the ``title`` column of a dataframe.

    All titles are converted to ``str`` and tokenized once in the constructor;
    ``__getitem__`` only slices the stored encodings.

    Args:
        dataframe: frame with a ``title`` column and, unless ``mode`` is
            ``"test"``, a ``label_code`` column.
        tokenizer: callable invoked as
            ``tokenizer(texts, padding=True, truncation=True, max_length=...)``
            whose result exposes ``.items()``.
        mode: ``"test"`` disables label handling; any other value enables it.
        max_length: forwarded unchanged to the tokenizer.
    """

    def __init__(self, dataframe, tokenizer, mode="train", max_length=None):
        self.dataframe = dataframe
        self.mode = mode
        if self.mode != "test":
            self.targets = dataframe['label_code'].values
        titles = [str(title) for title in dataframe['title']]
        self.encodings = tokenizer(
            titles,
            padding=True,
            truncation=True,
            max_length=max_length,
        )

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        # One tensor per encoding field for row `idx`.
        sample = {}
        for field, rows in self.encodings.items():
            sample[field] = torch.tensor(rows[idx])
        if self.mode == "test":
            return sample
        sample['labels'] = torch.tensor(self.targets[idx]).long()
        return sample
# Smoke test: build a loader over a small random subset and print the shapes
# of one batch.
# DataFrame.sample raises when asked for more rows than exist (without
# replacement), so cap the subset at the size of the filtered frame.
sample_size = min(1000, len(train))
dataset = TextDataset(train.sample(sample_size), tokenizer, max_length=CFG.max_length)
dataloader = torch.utils.data.DataLoader(dataset, 
                                         batch_size=CFG.batch_size, 
                                         num_workers=CFG.num_workers, 
                                         shuffle=True)
batch = next(iter(dataloader))
print(batch['input_ids'].shape, batch['labels'].shape)

In [None]:
class ArcMarginProduct(nn.Module):
    r"""Additive angular margin (ArcFace-style) classification head.

    For every sample the logit of its target class is ``s * cos(theta + m)``
    and the logit of every other class is ``s * cos(theta)``, where ``theta``
    is the angle between the L2-normalised input row and the L2-normalised
    class weight row.

    Args:
        in_features: size of each input sample
        out_features: number of classes (rows of the weight matrix)
        s: scale factor applied to the returned logits
        m: angular margin, in radians, added to the target-class angle
        easy_margin: if True, the margin is only applied where
            ``cos(theta) > 0`` and the plain cosine is used elsewhere; if
            False, the margin is applied where ``cos(theta) > cos(pi - m)``
            and ``cos(theta) - sin(pi - m) * m`` is used elsewhere.
    """

    def __init__(self, in_features, out_features, s=30.0, m=0.50, easy_margin=False):
        super(ArcMarginProduct, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        # torch.empty replaces the legacy torch.FloatTensor(...) constructor;
        # the values are overwritten by the Xavier initialisation right after.
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        # Cosine threshold: below it, theta + m would go past pi.
        self.th = math.cos(math.pi - m)
        # Penalty subtracted from the cosine in that fallback region.
        self.mm = math.sin(math.pi - m) * m

    def forward(self, input, label):
        """Return margin-adjusted, scaled logits.

        Args:
            input: feature matrix; rows are normalised before use.
            label: target class index per row (cast to long internally).

        Returns:
            Tensor with one row per input row and ``out_features`` columns.
        """
        # --------------------------- cos(theta) & phi(theta) ---------------------------
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        # Clamp guards against tiny negative values from rounding before sqrt.
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
        # cos(theta + m) = cos(theta) cos(m) - sin(theta) sin(m)
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        # --------------------------- convert label to one-hot ---------------------------
        # Allocate the mask next to the logits rather than on a globally
        # configured device, so the layer works wherever its inputs live.
        one_hot = torch.zeros(cosine.size(), device=cosine.device)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        # Margin-adjusted value for the target class, plain cosine elsewhere.
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        return output * self.s
class Model(nn.Module):
    """Transformer encoder followed by an ``ArcMarginProduct`` head.

    Args:
        bert_model: encoder called with ``input_ids`` and ``attention_mask``
            whose output exposes ``last_hidden_state``.
        num_classes: number of output classes of the margin head.
        last_hidden_size: feature size fed into the margin head.
    """

    def __init__(self, 
                 bert_model, 
                 num_classes=NUM_CLASSES, 
                 last_hidden_size=CFG.bert_hidden_size):
        super().__init__()
        self.bert_model = bert_model
        self.arc_margin = ArcMarginProduct(
            in_features=last_hidden_size,
            out_features=num_classes,
            s=30.0,
            m=0.50,
            easy_margin=False,
        )

    def get_bert_features(self, batch):
        """Return the hidden state of the first token for every sequence."""
        encoder_out = self.bert_model(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
        )
        # Index 0 along the sequence axis selects the first (CLS) token.
        return encoder_out.last_hidden_state[:, 0, :]

    def forward(self, batch):
        """Compute margin logits; ``batch`` must also carry ``'labels'``."""
        features = self.get_bert_features(batch)
        return self.arc_margin(features, batch['labels'])
class AvgMeter:
    """Running weighted average of a scalar metric.

    Attributes are ``sum`` (weighted total), ``count`` (total weight) and
    ``avg`` (``sum / count`` after the latest update).
    """

    def __init__(self, name="Metric"):
        self.name = name
        self.reset()

    def reset(self):
        """Zero all accumulators."""
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, count=1):
        """Add ``val`` with weight ``count`` and refresh the average."""
        self.count += count
        self.sum += val * count
        self.avg = self.sum / self.count

    def __repr__(self):
        return "{}: {:.4f}".format(self.name, self.avg)
def set_value(optimizer,lr):
    """Set ``'lr'`` to ``lr`` in every param group and return the optimizer."""
    for group in optimizer.param_groups:
        group.update(lr=lr)
    return optimizer
  
def weight_update(model,model_count,swa_weights):
    """Fold the model's current weights into a running average.

    Args:
        model: module whose ``state_dict()`` is averaged.
        model_count: number of snapshots already folded into ``swa_weights``.
        swa_weights: running average from previous calls; ignored (may be
            ``None``) when ``model_count`` is 0.

    Returns:
        The updated running average. For ``model_count == 0`` this is an
        independent copy of the current state dict; otherwise it is
        ``swa_weights``, updated in place.
    """
    print('updating')
    weights = model.state_dict()
    if model_count == 0:
        # state_dict() hands back tensors that share storage with the live
        # parameters. Without a copy, the "average" would silently follow the
        # model as training continues.
        return copy.deepcopy(weights)
    for key, current in weights.items():
        if current.is_floating_point() or current.is_complex():
            swa_weights[key] = (swa_weights[key] * model_count + current) / (model_count + 1)
        else:
            # Integer/bool buffers (counters, index tables) cannot be averaged
            # meaningfully and true division would change their dtype; keep
            # the most recent value instead.
            swa_weights[key] = current.clone()
    return swa_weights

def _t_cycle(clr_iterations,iter_per_cycle):
        return (((clr_iterations - 1) % iter_per_cycle) + 1) / iter_per_cycle
  
  
def _clr_schedule(alpha2,alpha1,clr_iterations,iter_per_cycle,cycle_num):
    swa_cycle_start_inx=40
    

    lr_start   = 0.000001
    lr_max     = 0.000005 * 128
    lr_min     = 0.000001
    lr_ramp_ep = 5
    lr_sus_ep  = 0
    lr_decay   = 0.8
    cycle_len=2


    if cycle_num>=swa_cycle_start_inx:
          return ((1.0 - 1.0 *_t_cycle(clr_iterations,iter_per_cycle)) * alpha2) + (1.0 *_t_cycle(clr_iterations,iter_per_cycle) *alpha1)
    else:
          
          if cycle_num < lr_ramp_ep:
            lr = (lr_max - lr_start) / lr_ramp_ep * cycle_num + lr_start   
          elif cycle_num < lr_ramp_ep + lr_sus_ep:
            lr = lr_max    
          else:
            lr = (lr_max - lr_min) * lr_decay**(cycle_num - lr_ramp_ep - lr_sus_ep) + lr_min   
          return lr
def one_epoch(model, 
              criterion, 
              loader,
              optimizer=None, 
              lr_scheduler=None, 
              mode="train", 
              step="batch", epoch=0, clr_iterations=0, swa_weights=None ,model_count=0):
    """Run one pass over ``loader`` and track loss, accuracy and learning rate.

    Args:
        model: called as ``model(batch)`` on a dict of tensors.
        criterion: called as ``criterion(preds, batch['labels'])``.
        loader: iterable of dict batches containing at least ``'input_ids'``
            and ``'labels'``.
        optimizer: required when ``mode == "train"``.
        lr_scheduler: stepped after every batch when ``step == "batch"``.
        mode: ``"train"`` enables the backward pass and optimizer step.
        step: ``"batch"`` steps ``lr_scheduler`` per batch in train mode.
        epoch: zero-based epoch index; from index 40 on, the cyclic
            learning-rate schedule and weight averaging are active.
        clr_iterations: running batch counter carried across epochs.
        swa_weights: running weight average carried across epochs.
        model_count: number of snapshots already folded into ``swa_weights``.

    Returns:
        ``(loss_meter, acc_meter, clr_iterations, swa_weights, model_count,
        lrs)`` where ``lrs`` holds one learning-rate value per batch.
    """
    loss_meter = AvgMeter()
    acc_meter = AvgMeter()

    # NOTE(review): 421 looks like the number of batches per epoch for one
    # particular dataset/batch size — confirm before reusing with other data.
    iter_per_cycle = 2 * 421
    swa_cycle_start_inx = 40

    is_train = mode == "train"
    lrs = []
    # Log the optimizer's actual learning rate instead of a hard-coded
    # constant; fall back to the old constant when no optimizer is supplied.
    lr = get_lr(optimizer) if optimizer is not None else 1e-5

    tqdm_object = tqdm(loader, total=len(loader))
    for batch in tqdm_object:
        batch = {k: v.to(CFG.device) for k, v in batch.items()}

        clr_iterations += 1

        if epoch >= swa_cycle_start_inx:
            lr = _clr_schedule(1e-5, 1e-6, clr_iterations, iter_per_cycle, epoch)
            if optimizer is not None:
                set_value(optimizer, lr)
        lrs.append(lr)

        # Only build the autograd graph when a backward pass will follow.
        with torch.set_grad_enabled(is_train):
            preds = model(batch)
            loss = criterion(preds, batch['labels'])
        if is_train:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if step == "batch":
                # NOTE(review): some schedulers (e.g. ReduceLROnPlateau)
                # require a metric argument here — confirm before enabling
                # per-batch stepping with such a scheduler.
                lr_scheduler.step()

        count = batch['input_ids'].size(0)
        loss_meter.update(loss.item(), count)

        accuracy = get_accuracy(preds.detach(), batch['labels'])
        acc_meter.update(accuracy.item(), count)
        if is_train:
            tqdm_object.set_postfix(train_loss=loss_meter.avg, accuracy=acc_meter.avg, lr=get_lr(optimizer))
        else:
            tqdm_object.set_postfix(valid_loss=loss_meter.avg, accuracy=acc_meter.avg)

    # Take a weight snapshot only when this epoch ended exactly on a cycle
    # boundary and the averaging phase has started.
    cycle_finished = _t_cycle(clr_iterations, iter_per_cycle) == 1
    if cycle_finished and epoch >= swa_cycle_start_inx:
        swa_weights = weight_update(model, model_count, swa_weights)
        model_count += 1
    return loss_meter, acc_meter, clr_iterations, swa_weights, model_count, lrs
def get_lr(optimizer):
    """Return the ``"lr"`` of the first param group, or None if there is none."""
    first_group = next(iter(optimizer.param_groups), None)
    if first_group is None:
        return None
    return first_group["lr"]

def get_accuracy(preds, targets):
    """Return the fraction of rows whose arg-max class equals the target.

    preds shape: (batch_size, num_labels)
    targets shape: (batch_size)
    """
    predicted_classes = preds.argmax(dim=1)
    hits = predicted_classes.eq(targets)
    return hits.float().mean()

In [None]:
def train_eval(epochs, model, train_loader, 
               criterion, optimizer, lr_scheduler=None):
    """Train ``model`` for ``epochs`` epochs and collect learning rates.

    Args:
        epochs: number of passes over ``train_loader``.
        model: module to train; left in eval mode after each epoch.
        train_loader: loader handed to ``one_epoch``.
        criterion: loss callable handed to ``one_epoch``.
        optimizer: optimizer handed to ``one_epoch``.
        lr_scheduler: optional scheduler handed to ``one_epoch``.

    Returns:
        A list with one entry per epoch, each being the list of per-batch
        learning rates reported by ``one_epoch``.
    """
    # State carried across epochs by one_epoch. A previous
    # copy.deepcopy(model.state_dict()) "best weights" snapshot was never
    # read and only kept a second full copy of the model alive, so it is gone.
    clr_iterations = 0
    model_count = 0
    swa_weights = None
    lrs = []
    for epoch in range(epochs):
        print("*" * 30)
        print(f"Epoch {epoch + 1}")
        model.train()
        (_train_loss, _train_acc, clr_iterations,
         swa_weights, model_count, epoch_lrs) = one_epoch(
            model,
            criterion,
            train_loader,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            mode="train",
            step=CFG.step,
            epoch=epoch,
            clr_iterations=clr_iterations,
            swa_weights=swa_weights,
            model_count=model_count,
        )
        model.eval()
        lrs.append(epoch_lrs)
        print("*" * 30)
    # NOTE(review): swa_weights is accumulated above but never loaded into the
    # model or returned — confirm whether the averaged weights should be
    # applied (e.g. via model.load_state_dict) before the model is saved.
    return lrs
# Full training run: dataset/loader over every selected row, model, loss,
# optimizer and (optionally) a scheduler, then the training loop itself.
train_dataset = TextDataset(train, tokenizer, max_length=CFG.max_length)
train_loader = torch.utils.data.DataLoader(train_dataset, 
                                           batch_size=CFG.batch_size, 
                                           num_workers=CFG.num_workers, 
                                           shuffle=True)

model = Model(bert_model).to(CFG.device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=CFG.learning_rate)

# Default to "no scheduler" so the train_eval call below does not raise a
# NameError when CFG.scheduler names anything other than ReduceLROnPlateau.
lr_scheduler = None
if CFG.scheduler == "ReduceLROnPlateau":
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 
                                                              mode="min", 
                                                              factor=CFG.factor, 
                                                              patience=CFG.patience)

# NOTE(review): the epoch count is hard-coded to 60 rather than CFG.epochs
# (30); the cyclic LR / weight-averaging phase in one_epoch only starts at
# epoch index 40, so switching to CFG.epochs would skip it — confirm intent.
lrs=train_eval(60, model, train_loader,
           criterion, optimizer, lr_scheduler=lr_scheduler)

In [None]:
# Flatten the per-epoch learning-rate lists into one per-batch series and
# plot it.
total = [batch_lr for epoch_lrs in lrs for batch_lr in epoch_lrs]

from matplotlib import pyplot as plt
plt.plot(total)



In [None]:


# Persist the tokenizer and the fine-tuned encoder.
# os.makedirs replaces the IPython-only `!mkdir` shell escape: it is plain
# Python and does not fail when the directory already exists on a re-run.
os.makedirs("tokenizer", exist_ok=True)
tokenizer.save_pretrained("./tokenizer")
bert_model.save_pretrained('abc')

