In [None]:
# Copyright 2021 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [1]:
import os, time
# Pin the process to one physical GPU. These variables are set before the
# CUDA-backed imports below (cudf, cupy) and before torch in the next cell.
os.environ["CUDA_VISIBLE_DEVICES"]='3'
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 forces synchronous kernel launches; it is
# normally a debugging aid and slows training -- confirm it is still wanted.
os.environ["CUDA_LAUNCH_BLOCKING"]='1'

import glob
import pandas as pd
import numpy as np
import cudf
import cupy
import gc
from datetime import datetime

# Project-local helper; used in valid_epoch to score predictions per label.
from util import compute_rce_fast

# True when more than one device id is listed in CUDA_VISIBLE_DEVICES
# (presumably "data parallel"; the flag is not read in the cells shown).
DP = len(os.environ["CUDA_VISIBLE_DEVICES"].split(','))>1
DP

False

In [2]:
import cupy as cp
import glob
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

import nvtabular as nvt
from nvtabular.loader.torch import TorchAsyncItr, DLDataLoader
from nvtabular.framework_utils.torch.models import Model
from nvtabular.framework_utils.torch.utils import process_epoch

import torch
from torch import nn
torch.__version__

'1.7.1+cu101'

## model

In [3]:
class ConcatenatedEmbeddings(torch.nn.Module):
    """Embed categorical columns and concatenate the results.

    One ``torch.nn.Embedding`` is created per entry of ``embedding_table_shapes``
    (in dict order). The first table is shared by the first two input columns,
    so the input must have one more column than there are tables.

    Args:
        embedding_table_shapes: A dictionary mapping column names to
            (cardinality, embedding_size) tuples. Tables with a cardinality
            above 1e5 are created with ``sparse=True``.
        dropout: A float, the dropout probability applied to the output.
    Inputs:
        x: An integer Tensor with shape [batch_size, num_tables + 1]; a 1-D
            tensor is treated as a batch of one row.
    Outputs:
        A Float Tensor with shape [batch_size, D], where D is the sum of all
        embedding sizes plus the first table's size once more.
    """

    def __init__(self, embedding_table_shapes, dropout=0.0):
        super().__init__()
        tables = []
        for cardinality, width in embedding_table_shapes.values():
            # Large tables use sparse gradients so only the touched rows are updated.
            tables.append(torch.nn.Embedding(cardinality, width, sparse=(cardinality > 1e5)))
        self.embedding_layers = torch.nn.ModuleList(tables)
        self.dropout = torch.nn.Dropout(p=dropout)

    def forward(self, x):
        if x.dim() == 1:
            x = x.unsqueeze(0)

        # Column 0 goes through table 0; columns 1..N go through tables 0..N-1,
        # which means columns 0 and 1 both read the first (shared) table.
        shared_table = self.embedding_layers[0]
        pieces = [shared_table(x[:, 0])]
        for column, table in enumerate(self.embedding_layers, start=1):
            pieces.append(table(x[:, column]))

        return self.dropout(torch.cat(pieces, dim=1))

In [4]:
import torch.nn as nn

# Shared sigmoid module used by the Swish forward and backward passes.
sigmoid = nn.Sigmoid()

class Swish(torch.autograd.Function):
    """Swish activation ``x * sigmoid(x)`` with a hand-written backward pass.

    Only the input is saved for backward; the sigmoid is recomputed there
    instead of being stored from the forward pass.
    """

    @staticmethod
    def forward(ctx, i):
        """Return ``i * sigmoid(i)`` and stash ``i`` for the backward pass."""
        ctx.save_for_backward(i)
        return i * sigmoid(i)

    @staticmethod
    def backward(ctx, grad_output):
        """Return dL/di = grad_output * sigmoid(i) * (1 + i * (1 - sigmoid(i)))."""
        # `saved_tensors` is the supported accessor; `saved_variables` is deprecated.
        (i,) = ctx.saved_tensors
        sigmoid_i = sigmoid(i)
        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))

class Swish_Module(nn.Module):
    """``nn.Module`` wrapper so Swish can be placed inside ``nn.Sequential``."""

    def forward(self, x):
        return Swish.apply(x)


from transformers import AutoTokenizer, AutoModel

# Hugging Face model id. It is used here for the tokenizer and later passed to
# Net, which loads the matching pretrained word-embedding matrix.
bert_type = 'distilbert-base-multilingual-cased'

# NOTE(review): in the cells shown, `tokenizer` is referenced only from
# commented-out code in AllDataset.__getitem__; the active code path reads
# pre-tokenized ids from the `text_tokens` column instead.
tokenizer = AutoTokenizer.from_pretrained(bert_type)

class Net(nn.Module):
    """MLP over continuous features, categorical embeddings and a GRU text encoder.

    The input to the MLP is the concatenation of:
      * the continuous features (``num_features`` wide),
      * the GRU output at the last sequence position (``gru_dim`` wide),
      * the output of ``ConcatenatedEmbeddings`` (all embedding sizes, with the
        first/shared table contributing twice),
      * the element-wise product of the two embeddings looked up in the shared
        first table (one more copy of that table's width).

    Args:
        num_features: Width of the continuous input ``x_cont``.
        layers: Hidden layer width, or a sequence of widths, for the MLP.
        embedding_table_shapes: Dict of name -> (cardinality, embedding_size).
            The first entry is the table shared by categorical columns 0 and 1.
        dropout: Dropout probability for the embeddings and every MLP layer.
        bert_type: Model id passed to ``AutoModel.from_pretrained``; only its
            word-embedding matrix is kept. Despite the ``None`` default, a
            value is required by this constructor.
        gru_dim: Hidden size of the GRU.
        emb_dim: Expected token-embedding width; asserted against the loaded model.
    """

    def __init__(self, num_features, layers, embedding_table_shapes, dropout=0.2, bert_type=None, gru_dim=128, emb_dim=768):
        super(Net, self).__init__()
        self.dropout = dropout
        self.initial_cat_layer = ConcatenatedEmbeddings(embedding_table_shapes, dropout=dropout)
        embedding_size = sum(emb_size for _, emb_size in embedding_table_shapes.values())
        # The shared (first) table is used one extra time inside
        # ConcatenatedEmbeddings and once more for the a*b interaction term, so
        # it adds 2 * its width on top of `embedding_size`. This was previously
        # hard-coded as 128 + 128 and broke for any other embedding size.
        shared_emb_size = next(iter(embedding_table_shapes.values()))[1]
        layers = [layers] if isinstance(layers, int) else list(layers)
        layers = [num_features + gru_dim + embedding_size + 2 * shared_emb_size] + layers
        self.use_bert = True
        # Keep only the pretrained token-embedding lookup, not the transformer stack.
        self.embed = AutoModel.from_pretrained(bert_type).embeddings.word_embeddings
        assert emb_dim == self.embed.embedding_dim
        # Named `lstm` for checkpoint compatibility, but this is a unidirectional GRU.
        self.lstm = nn.GRU(emb_dim, gru_dim, batch_first=True, bidirectional=False)

        self.fn_layers = nn.ModuleList(
                            nn.Sequential(
                                nn.Dropout(p=dropout),
                                nn.Linear(layers[i], layers[i+1]),
                                nn.BatchNorm1d(layers[i+1]),
                                Swish_Module(),
                            )  for i in range(len(layers) -1)
                         )
        # One logit per engagement label.
        self.fn_last = nn.Linear(layers[-1],4)

    def forward(self, x_cat, x_cont, bert_tok):
        """Return logits of shape [batch, 4].

        Args:
            x_cat: Integer tensor [batch, n_cat]; columns 0 and 1 index the
                shared first embedding table.
            x_cont: Float tensor [batch, num_features].
            bert_tok: Integer tensor [batch, seq_len] of token ids.
        """
        # Interaction term between the two ids that share the first table.
        shared_table = self.initial_cat_layer.embedding_layers[0]
        a_emb = shared_table(x_cat[:,0])
        b_emb = shared_table(x_cat[:,1])
        mf = a_emb * b_emb

        x_cat = self.initial_cat_layer(x_cat)
        bert_tok = self.embed(bert_tok)
        # GRU output at the final sequence position (batch_first layout).
        lstm_out = self.lstm(bert_tok)[0][:,-1]
        output = torch.cat([x_cont, lstm_out, x_cat, mf],dim=1)
        for layer in self.fn_layers:
            output = layer(output)
        logit = self.fn_last(output)
        return logit

## scheduler

In [5]:
from warmup_scheduler import GradualWarmupScheduler
import torch.optim as optim
from torch.optim import lr_scheduler
import time
import warnings; warnings.simplefilter('ignore')

def set_seed(seed=0):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Also switches cuDNN to deterministic algorithms.

    Args:
        seed: Integer seed applied to every generator.
    """
    # Imported locally because the notebook never imports `random` at the top
    # level, which made this function raise NameError when called.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    
class GradualWarmupSchedulerV2(GradualWarmupScheduler):
    """GradualWarmupScheduler with an overridden ``get_lr``.

    While ``last_epoch <= total_epoch`` the learning rate is a linear function
    of ``last_epoch``; afterwards it is ``base_lr * multiplier`` or, when an
    ``after_scheduler`` is given, whatever that scheduler returns.
    """

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        super().__init__(optimizer, multiplier, total_epoch, after_scheduler)

    def get_lr(self):
        # Warm-up phase: scale every base lr by the same epoch-dependent factor.
        if self.last_epoch <= self.total_epoch:
            if self.multiplier == 1.0:
                # Ramp from 0 up to base_lr.
                scale = float(self.last_epoch) / self.total_epoch
            else:
                # Ramp from base_lr up to base_lr * multiplier.
                scale = (self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.
            return [base_lr * scale for base_lr in self.base_lrs]

        # Past warm-up.
        peak_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
        if not self.after_scheduler:
            return peak_lrs
        if not self.finished:
            # Hand over once: the follow-up scheduler starts from the peak lr.
            self.after_scheduler.base_lrs = peak_lrs
            self.finished = True
        return self.after_scheduler.get_lr()

## train loop

In [6]:
# Multi-label loss applied to the raw logits; shared by train_epoch and valid_epoch.
criterion = nn.BCEWithLogitsLoss()

def train_epoch(model, loader, optimizer, scaler, optimizer2):
    """Run one training pass over ``loader`` and return the mean batch loss.

    The model's parameters are split across two optimizers; both are zeroed
    and stepped on every batch, in every precision mode.

    Reads the notebook globals ``use_torch_amp``, ``use_amp``, ``amp`` and
    ``criterion``.

    Args:
        model: Module called as ``model(x_cat, x_cont, text_tok)``.
        loader: Iterable yielding ``(x_cat, x_cont, text_tok, targets)`` batches.
        optimizer: First optimizer.
        scaler: Gradient scaler used only when ``use_torch_amp`` is true.
        optimizer2: Second optimizer.

    Returns:
        Mean of the per-batch loss values.
    """
    model.train()
    train_loss = []
    bar = tqdm(loader)
    for batch in bar:
        x_cat, x_cont, text_tok, targets = batch

        x_cat = x_cat.cuda()
        x_cont = x_cont.cuda()
        text_tok = text_tok.cuda()
        targets = targets.cuda()

        optimizer.zero_grad()
        optimizer2.zero_grad()

        if use_torch_amp:
            # Native torch AMP: only the forward pass runs under autocast.
            with amp.autocast():
                logits = model(x_cat, x_cont, text_tok)
            loss = criterion(logits, targets)

            scaler.scale(loss).backward()

            # Explicit unscaling is where gradient inspection or clipping would
            # go; scaler.step() then skips a step whose grads contain inf/NaN.
            scaler.unscale_(optimizer)
            scaler.unscale_(optimizer2)

            scaler.step(optimizer)
            scaler.step(optimizer2)

            scaler.update()

        elif use_amp:
            # apex-style AMP. NOTE(review): this branch needs `amp` to be
            # apex.amp; the notebook binds `amp` to torch.cuda.amp, so confirm
            # before enabling it.
            logits = model(x_cat, x_cont, text_tok)
            loss = criterion(logits, targets)
            # Scale for both optimizers, since both own parameters that receive gradients.
            with amp.scale_loss(loss, [optimizer, optimizer2]) as scaled_loss:
                scaled_loss.backward()
            optimizer.step()
            # Previously missing: without it optimizer2's parameters never trained here.
            optimizer2.step()
        else:
            logits = model(x_cat, x_cont, text_tok)
            loss = criterion(logits, targets)
            loss.backward()
            optimizer.step()
            # Previously missing: without it optimizer2's parameters never trained here.
            optimizer2.step()

        loss_np = loss.item()
        train_loss.append(loss_np)
        # Running mean over the last (up to) 50 batches, for the progress bar.
        smooth_loss = sum(train_loss[-50:]) / min(len(train_loss), 50)
        bar.set_description('loss: %.4f, smth: %.4f' % (loss_np, smooth_loss))

    return np.mean(train_loss)

def valid_epoch(model, loader):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Reads the notebook globals ``criterion``, ``label_names`` and
    ``compute_rce_fast``.

    Args:
        model: Module called as ``model(x_cat, x_cont, text_tok)``.
        loader: Iterable yielding ``(x_cat, x_cont, text_tok, targets)`` batches.

    Returns:
        A tuple ``(val_loss, rce, mean_rce)``: the mean batch loss, a dict
        mapping each of the four label names to its ``compute_rce_fast`` score,
        and the mean of those four scores.
    """
    model.eval()
    batch_losses, logit_chunks, target_chunks = [], [], []

    with torch.no_grad():
        for x_cat, x_cont, text_tok, targets in tqdm(loader):
            x_cat, x_cont = x_cat.cuda(), x_cont.cuda()
            text_tok, targets = text_tok.cuda(), targets.cuda()

            logits = model(x_cat, x_cont, text_tok)
            batch_losses.append(criterion(logits, targets).item())
            # Keep predictions on the host so the whole set can be scored at once.
            logit_chunks.append(logits.cpu())
            target_chunks.append(targets.cpu())

    all_logits = torch.cat(logit_chunks)
    all_targets = torch.cat(target_chunks)

    # Column i of logits/targets corresponds to label_names[i].
    rce = {
        label_names[i]: compute_rce_fast(cp.asarray(all_logits[:,i].sigmoid()),cp.asarray(all_targets[:,i])).get()
        for i in range(4)
    }
    mean_rce = np.mean(list(rce.values()))

    return np.mean(batch_losses), rce, mean_rce

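## Data loader (pandas preprocessing + torch DataLoader; the imported NVT loader is not used in the cells below)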

In [7]:
# Target columns. sorted() makes the order alphabetical
# (like, reply, retweet, retweet_comment); AllDataset builds the label matrix
# in this order and valid_epoch maps logit column i to label_names[i].
label_names = sorted(['reply', 'retweet', 'retweet_comment', 'like'])
# Categorical inputs. Order matters: the model treats columns 0 and 1 as the
# two user ids that share one embedding table.
CAT_COLUMNS = ['a_user_id','b_user_id','language','media','tweet_type']
# Continuous inputs fed to the model. The last eight (a_ff_rate ... ab_age_rate)
# are not read from the parquet files; read_norm_merge derives them.
NUMERIC_COLUMNS = ['a_follower_count',
                     'a_following_count',
                     'a_is_verified',
                     'b_follower_count',
                     'b_following_count',
                     'b_is_verified',
                     'b_follows_a',
                     'tw_len_media',
                     'tw_len_photo',
                     'tw_len_video',
                     'tw_len_gif',
                     'tw_len_quest',
                     'tw_len_token',
                     'tw_count_capital_words',
                     'tw_count_excl_quest_marks',
                     'tw_count_special1',
                     'tw_count_hash',
                     'tw_last_quest',
                     'tw_len_retweet',
                     'tw_len_rt',
                     'tw_count_at',
                     'tw_count_words',
                     'tw_count_char',
                     'tw_rt_count_words',
                     'tw_rt_count_char',
                     'len_hashtags',
                     'len_links',
                     'len_domains',
                     'a_ff_rate',
                     'b_ff_rate',
                     'ab_fing_rate',
                     'ab_fer_rate',
                     'a_age',
                     'b_age',
                     'ab_age_dff',
                     'ab_age_rate']
len(NUMERIC_COLUMNS)

36

In [8]:
def read_norm_merge(path, split='train',
                    categories_dir='/raid/recsys_pre_TE_w_tok/workflow_232parts_joint_thr3_pos/categories'):
    """Load one parquet part and turn it into model-ready features.

    Steps, in order:
      1. Add a ``quantile`` bucket (0-4) counting how many fixed thresholds
         ``a_follower_count`` exceeds.
      2. Keep rows before (``split='train'``) or on/after (``split='valid'``)
         the 2021-02-18 cutoff; any other ``split`` value keeps every row.
      3. Derive the ratio and age features listed at the end of NUMERIC_COLUMNS.
      4. Normalise every NUMERIC_COLUMNS entry (``ab_age_dff`` is divided by
         256, other int/float columns get ``log1p``), storing float64 as float32.
      5. Replace each CAT_COLUMNS value with its row index in the matching
         ``unique.<name>.parquet`` table; values not found become 0.
      6. Drop helper and raw columns that the model does not consume.

    Reads the notebook globals ``NUMERIC_COLUMNS`` and ``CAT_COLUMNS``.

    Args:
        path: Parquet file to read.
        split: ``'train'``, ``'valid'``, or anything else for "no date filter".
        categories_dir: Directory holding the ``unique.<name>.parquet`` files.

    Returns:
        A pandas DataFrame with the unused columns removed.
    """
    ddf = pd.read_parquet(path)

    # Bucket index = number of thresholds that a_follower_count exceeds.
    ddf['quantile'] = 0
    for quant in [92, 216, 442, 1064]:
        ddf['quantile'] = (ddf['quantile']+(ddf['a_follower_count']>quant).astype('int8')).astype('int8')

    ddf['date'] = pd.to_datetime(ddf['timestamp'], unit='s')

    # Time-based split: rows from this date onwards are validation rows.
    VALID_DOW = '2021-02-18'
    if split=='train':
        ddf = ddf[ddf['date']<pd.to_datetime(VALID_DOW)].reset_index(drop=True)
    elif split=='valid':
        ddf = ddf[ddf['date']>=pd.to_datetime(VALID_DOW)].reset_index(drop=True)

    # NOTE(review): the unguarded ratios below yield inf/NaN when a denominator
    # is 0; confirm the upstream data rules that out.
    ddf['a_ff_rate'] = (ddf['a_following_count'] / ddf['a_follower_count']).astype('float32')
    ddf['b_ff_rate'] = (ddf['b_follower_count']  / ddf['b_following_count']).astype('float32')
    ddf['ab_fing_rate'] = (ddf['a_following_count'] / ddf['b_following_count']).astype('float32')
    ddf['ab_fer_rate'] = (ddf['a_follower_count'] / (1+ddf['b_follower_count'])).astype('float32')
    ddf['a_age'] = ddf['a_account_creation'].astype('int16') + 128
    ddf['b_age'] = ddf['b_account_creation'].astype('int16') + 128
    ddf['ab_age_dff'] = ddf['b_age'] - ddf['a_age']
    ddf['ab_age_rate'] = ddf['a_age']/(1+ddf['b_age'])

    ## Normalize
    for col in NUMERIC_COLUMNS:
        if col == 'tw_len_quest':
            # Negative values are clamped to 0 before the log transform.
            ddf[col] = np.clip(ddf[col].values,0,None)
        if ddf[col].dtype == 'uint16':
            # The original call discarded the result of astype(), so the
            # widening never happened; assign it back.
            ddf[col] = ddf[col].astype('int32')

        if col == 'ab_age_dff':
            ddf[col] = ddf[col] / 256.
        elif 'int' in str(ddf[col].dtype) or 'float' in str(ddf[col].dtype):
            ddf[col] = np.log1p(ddf[col])

        if ddf[col].dtype == 'float64':
            ddf[col] = ddf[col].astype('float32')

    ## get categorical embedding id
    # Both user-id columns use the same mapping file, so each file is read at
    # most once per call and reused.
    mapping_cache = {}
    for col in CAT_COLUMNS:
        ddf[col] = ddf[col].astype('float')
        mapping_col = 'a_user_id_b_user_id' if col in ['a_user_id','b_user_id'] else col
        if mapping_col not in mapping_cache:
            mapping_cache[mapping_col] = pd.read_parquet(f'{categories_dir}/unique.{mapping_col}.parquet').reset_index()
        # Shallow copy so relabelling the columns leaves the cached frame untouched.
        mapping = mapping_cache[mapping_col].copy(deep=False)
        mapping.columns = ['index',col]
        # Left join: the row position in the mapping table becomes the embedding id.
        ddf = ddf.merge(mapping, how='left', on=col).drop(columns=[col]).rename(columns={'index':col})
        # Values absent from the mapping fall back to id 0.
        ddf[col] = ddf[col].fillna(0).astype('int')

    DONT_USE = {'timestamp','a_account_creation','b_account_creation','engage_time',
                'fold', 'dt_dow', 'elapsed_time', 'links','domains','hashtags','id', 'date', 'is_train',
                'tw_hash0', 'tw_hash1', 'tw_hash2', 'tw_http0', 'tw_uhash', 'tw_hash', 'tw_word0',
                'tw_word1', 'tw_word2', 'tw_word3', 'tw_word4', 'dt_minute', 'dt_second',
                'dt_day', 'group', 'text', 'tweet_id', 'tw_original_user0', 'tw_original_user1', 'tw_original_user2',
                'tw_rt_user0', 'tw_original_http0', 'tw_tweet'}
    # Only drop the columns that are actually present in this part.
    drop_cols = [c for c in ddf.columns if c in DONT_USE]
    gc.collect(); gc.collect()

    return ddf.drop(columns=drop_cols)

In [9]:
# All preprocessed training parts, sorted so that an index into PATHS is stable
# across runs (train_parts_order and the PATHS[:10] validation slice index into it).
PATHS = sorted(glob.glob('/raid/recsys/train_proc3/*.parquet'))
len(PATHS)

232

In [10]:
# for col in NUMERIC_COLUMNS:
#     print(col)
#     plt.hist(train[col].values, bins=50)
#     plt.title(col)
# #     print(ddf[col].describe())
#     plt.show()

In [8]:
import torch
from torch.utils.data import Dataset,DataLoader

class AllDataset(Dataset):
    """Map-style dataset over a preprocessed DataFrame.

    Each item is ``(x_cat, x_cont, token_ids, labels)``:
      * ``x_cat``: the row of ``CAT_COLUMNS`` values,
      * ``x_cont``: the row of ``NUMERIC_COLUMNS`` values as float32,
      * ``token_ids``: a tensor of exactly ``max_len_txt`` ids parsed from the
        tab-separated ``text_tokens`` string, truncated or right-padded with 0,
      * ``labels``: the row of ``label_names`` values as float32.

    ``label_names`` is read from the notebook's global scope.
    """

    def __init__(self, df, max_len_txt, NUMERIC_COLUMNS, CAT_COLUMNS):
        self.max_len_txt = max_len_txt
        self.text_tokens = df.text_tokens.values
        self.labels = df[label_names].values
        self.X_cat = df[CAT_COLUMNS].values
        self.X = df[NUMERIC_COLUMNS].values

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        limit = self.max_len_txt
        # The ids are already tokenized upstream, so no tokenizer call is needed.
        token_ids = list(map(int, self.text_tokens[index].split('\t')))[:limit]
        # Right-pad with 0 up to the fixed length (a no-op when already full).
        token_ids.extend([0] * (limit - len(token_ids)))

        x_cat = self.X_cat[index]
        x_cont = self.X[index].astype(np.float32)
        labels = self.labels[index].astype(np.float32)
        return x_cat, x_cont, torch.tensor(token_ids), labels

In [9]:
# Run configuration.
# NOTE(review): gru_dim and emb_dim are not passed to Net when the model is
# built below; Net's own defaults happen to match these values.
gru_dim=128
max_len_txt=48  # token ids per example after truncation/padding in AllDataset
emb_dim=768
lr = 1e-2   # base lr for the SparseAdam optimizer
lr2= 1e-4   # base lr for the AdamW optimizer
# The warm-up schedulers use multiplier=10, so the logged lr peaks at 10x lr / lr2.
ep = 46   # epochs; each consumes 5 parts, i.e. 230 of the 232 parts in total
BATCH_SIZE = 1024
num_workers = 16
# Selects the torch.cuda.amp branch in train_epoch.
use_torch_amp = True
import torch.cuda.amp as amp
# Selects the apex-style branch in train_epoch; only consulted when use_torch_amp is False.
use_amp = False

# Prefix for the checkpoint files written by the training loop.
model_name = 'load_thr3_pos_1e-2_1e-4'

In [13]:
# Sanity check: this count is passed to Net as num_features when the model is built.
len(NUMERIC_COLUMNS)

36

In [14]:
%%time
# Build the fixed validation set: the rows on/after the cutoff date from the
# first 10 parts. The training loop reads the same 10 parts with split='train',
# so the two sets do not share rows.
# NOTE: the list is named train_lst but holds validation frames here.
train_lst = []
for path in PATHS[:10]:
    train_lst.append(read_norm_merge(path, 'valid'))
valid = pd.concat(train_lst)
gc.collect()

valid_dataset = AllDataset(valid, max_len_txt, NUMERIC_COLUMNS, CAT_COLUMNS)
# shuffle=False keeps the evaluation order identical between epochs.
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers) 
valid.shape, len(valid_loader)

CPU times: user 8min 46s, sys: 1min 8s, total: 9min 55s
Wall time: 9min 38s


((10324907, 47), 10083)

In [15]:
# Order in which the 232 parts are consumed, 5 per epoch. The array was
# presumably produced once by the commented-out permutation call below and then
# pasted in as a literal, so the order stays the same across kernel restarts.
# train_parts_order = np.concatenate([np.random.permutation(232)])
train_parts_order = np.array([ 46, 111, 208, 230,   3,  22, 227, 153,  78,  52,  20, 185,   6,
        130, 177,  83,  97, 194,  24, 187,  93,  59, 217, 180, 129,  62,
          1,  43, 229, 102, 196,  50,   4,  12, 114,  70,  18,  91,  71,
        190, 174,  23,  63,  89, 188,  16, 104,  67,  39, 225, 176,  28,
        198,   2,  76, 166, 216, 116, 199, 113, 107, 201,  64, 115,   8,
        171,  44, 218, 158, 181,  79,  47, 155, 159, 164, 109,  56, 106,
        122, 203, 144,  14, 163, 124, 110, 126,  80,  77,  94, 135,  33,
        134, 224, 145, 172, 191,  60, 148, 215, 212, 219,  35, 167,  37,
        132, 182, 228,  75,  87, 156, 137,  74,  29,  95, 118,  90, 222,
         19,  57, 162, 105, 223, 210, 140,  10,  72, 152, 183, 170,  51,
         82, 117,  13, 211, 120,  81, 160,  27, 200, 128, 169, 213, 179,
         42,  11, 143,  15, 209, 151,  48, 207, 112, 119, 231, 175,   0,
        146, 154,  68, 197,  21, 206, 125, 192,  31,  86, 138,  36, 108,
        103,  58, 142,  54,  98,  99, 127, 214,   7,  92, 121, 202, 141,
        150,  88,  53,  38, 139, 147, 131,  66,  40,  26, 123,  73, 100,
        165, 186, 149, 205,   5, 189,  25,  32, 133, 101, 204, 178, 193,
        136,  84, 161,  30, 221,  65,  85,  41,  17,  61,  45, 173, 195,
          9, 184,  55,  49, 168,  69,  34,  96, 157, 226, 220])
train_parts_order, train_parts_order.shape

(array([ 46, 111, 208, 230,   3,  22, 227, 153,  78,  52,  20, 185,   6,
        130, 177,  83,  97, 194,  24, 187,  93,  59, 217, 180, 129,  62,
          1,  43, 229, 102, 196,  50,   4,  12, 114,  70,  18,  91,  71,
        190, 174,  23,  63,  89, 188,  16, 104,  67,  39, 225, 176,  28,
        198,   2,  76, 166, 216, 116, 199, 113, 107, 201,  64, 115,   8,
        171,  44, 218, 158, 181,  79,  47, 155, 159, 164, 109,  56, 106,
        122, 203, 144,  14, 163, 124, 110, 126,  80,  77,  94, 135,  33,
        134, 224, 145, 172, 191,  60, 148, 215, 212, 219,  35, 167,  37,
        132, 182, 228,  75,  87, 156, 137,  74,  29,  95, 118,  90, 222,
         19,  57, 162, 105, 223, 210, 140,  10,  72, 152, 183, 170,  51,
         82, 117,  13, 211, 120,  81, 160,  27, 200, 128, 169, 213, 179,
         42,  11, 143,  15, 209, 151,  48, 207, 112, 119, 231, 175,   0,
        146, 154,  68, 197,  21, 206, 125, 192,  31,  86, 138,  36, 108,
        103,  58, 142,  54,  98,  99, 127, 214,   7

In [10]:
# Build the network on the GPU. The first embedding table is the one shared by
# a_user_id and b_user_id; the remaining three follow the rest of CAT_COLUMNS.
# NOTE(review): the cardinalities are hard-coded; confirm they match the
# unique.*.parquet mapping tables used by read_norm_merge.
model = Net(len(NUMERIC_COLUMNS), layers=[1024,256,64], 
            embedding_table_shapes={'a_user_id_b_user_id': (19688213, 128), 'language': (67, 16), 'media': (15, 16), 'tweet_type': (4, 16)},
            bert_type=bert_type).cuda()

# Freeze the pretrained token-embedding matrix; only the GRU on top of it is trained.
for param in model.embed.parameters():
    param.requires_grad = False

model    

Net(
  (initial_cat_layer): ConcatenatedEmbeddings(
    (embedding_layers): ModuleList(
      (0): Embedding(19688213, 128, sparse=True)
      (1): Embedding(67, 16)
      (2): Embedding(15, 16)
      (3): Embedding(4, 16)
    )
    (dropout): Dropout(p=0.2, inplace=False)
  )
  (embed): Embedding(119547, 768, padding_idx=0)
  (lstm): GRU(768, 128, batch_first=True)
  (fn_layers): ModuleList(
    (0): Sequential(
      (0): Dropout(p=0.2, inplace=False)
      (1): Linear(in_features=596, out_features=1024, bias=True)
      (2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): Swish_Module()
    )
    (1): Sequential(
      (0): Dropout(p=0.2, inplace=False)
      (1): Linear(in_features=1024, out_features=256, bias=True)
      (2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): Swish_Module()
    )
    (2): Sequential(
      (0): Dropout(p=0.2, inplace=False)
      (1): Linear(in_features=256, out_

In [17]:
# Warm-start from an earlier run's best checkpoint, loaded onto the CPU first.
sd = torch.load(f'../models/two_opt_lr3_load_len48_joint_thr10_3e-3_1e-4_best.pth',map_location='cpu')
# Strip a leading 'module.' from parameter names (presumably left by a
# DataParallel wrapper in the earlier run).
sd = {k[7:] if k.startswith('module.') else k: sd[k] for k in sd.keys()}
# Drop the shared user-id embedding table so it keeps its fresh initialisation
# (presumably because the checkpoint's table has a different cardinality -- confirm).
del sd['initial_cat_layer.embedding_layers.0.weight']
# strict=False tolerates the key removed above; it is reported under missing_keys.
model.load_state_dict(sd, strict=False)

_IncompatibleKeys(missing_keys=['initial_cat_layer.embedding_layers.0.weight'], unexpected_keys=[])

In [18]:
# Two optimizers: SparseAdam for the first parameter and AdamW for the rest.
# NOTE(review): the [:1] / [1:] split relies on the shared user-id embedding
# (the only table created with sparse=True in the printed model) being the first
# entry of model.parameters(); selecting it by name would be more robust.
optimizer = optim.SparseAdam(list(model.parameters())[:1], lr=lr)
optimizer2 = optim.AdamW(list(model.parameters())[1:], lr=lr2)
scaler = amp.GradScaler() if use_torch_amp else None

# Per optimizer: one warm-up epoch up to 10x the base lr, then cosine annealing
# over the remaining ep-1 epochs.
scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, ep-1)
scheduler_warmup = GradualWarmupSchedulerV2(optimizer, multiplier=10, total_epoch=1, after_scheduler=scheduler_cosine)

scheduler_cosine2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer2, ep-1)
scheduler_warmup2 = GradualWarmupSchedulerV2(optimizer2, multiplier=10, total_epoch=1, after_scheduler=scheduler_cosine2)

# Best validation mean RCE so far; a "best" checkpoint is written only above this value.
rce_best = 0

## start training

In [19]:
# Echo the run name; the training loop writes checkpoints to ../models/{model_name}_*.pth.
print(model_name)

load_thr3_pos_1e-2_1e-4


In [20]:
# Main training loop. Each epoch advances both LR schedulers, loads the next
# five parquet parts, trains on them, evaluates on the fixed validation loader
# and writes checkpoints.
for epoch in range(1, ep+1):
    print(time.ctime(), 'Epoch:', epoch)
    # The schedulers are stepped with an explicit 0-based epoch index before training.
    scheduler_warmup.step(epoch-1) 
    scheduler_warmup2.step(epoch-1) 
    
    # 5 parts per epoch
    idx_this_ep = train_parts_order[(epoch*5-5):epoch*5]
    
    train_lst = []
    for idx in tqdm(idx_this_ep):
        # Parts 0-9 also supply the validation rows (on/after the cutoff date),
        # so only their earlier rows are trained on; other parts are used whole.
        train_lst.append(read_norm_merge(PATHS[idx], 'train' if idx<10 else 'both'))
    train = pd.concat(train_lst)
 
    gc.collect();gc.collect();
    
    # NOTE(review): the previous epoch's train/train_dataset/train_loader stay
    # referenced until they are rebound here, so peak host memory covers two
    # epochs of data while the new parts are being read.
    train_dataset = AllDataset(train, max_len_txt, NUMERIC_COLUMNS, CAT_COLUMNS)
    # drop_last=True discards the final partial batch.
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers, drop_last=True) 
    
    train_loss = train_epoch(model, train_loader, optimizer, scaler, optimizer2)
    valid_loss,rce,mean_rce = valid_epoch(model, valid_loader)
   
    content = time.ctime() + ' ' + f'Epoch {epoch}, lr: {optimizer.param_groups[0]["lr"]:.7f}, {optimizer2.param_groups[0]["lr"]:.7f}, train loss: {train_loss:.4f}, valid loss: {valid_loss:.4f}, mean_rce: {mean_rce:.2f}'
    for col in ['retweet', 'reply',  'like', 'retweet_comment']:
        content += f', {col}: {rce[col]:.2f}'
        
    print(content)
    
    # Keep the weights with the highest validation mean RCE seen so far.
    if mean_rce > rce_best:
        print('rce_best increased ({:.6f} --> {:.6f}).  Saving model ...'.format(rce_best, mean_rce))
        rce_best = mean_rce
                
        torch.save(model.state_dict(), f'../models/{model_name}_best.pth')
        
    # Checkpoint written every epoch, overwriting the previous one.
    # NOTE(review): scheduler state is not stored; a resume would have to
    # rebuild the schedulers from the saved epoch number.
    torch.save(
        {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scaler_state_dict': scaler.state_dict() if scaler else None,
            'optimizer_state_dict2': optimizer2.state_dict(),
            'rce_best': rce_best,
        },
        f'../models/{model_name}_last.pth'
    )            
        
# Weights after the last epoch, whether or not they are the best ones.
torch.save(model.state_dict(), f'../models/{model_name}_final.pth')

  0%|          | 0/5 [00:00<?, ?it/s]

Thu Jun 10 11:11:07 2021 Epoch: 1


100%|██████████| 5/5 [05:43<00:00, 68.80s/it]
loss: 0.2394, smth: 0.2333: 100%|██████████| 12298/12298 [08:56<00:00, 22.92it/s]
100%|██████████| 10083/10083 [02:04<00:00, 81.31it/s]


Thu Jun 10 11:27:58 2021 Epoch 1, lr: 0.0100000, 0.0001000, train loss: 0.2404, valid loss: 0.2342, mean_rce: 14.77, retweet: 17.89, reply: 16.44, like: 18.48, retweet_comment: 6.26
rce_best increased (0.000000 --> 14.766256).  Saving model ...


  0%|          | 0/5 [00:00<?, ?it/s]

Thu Jun 10 11:35:03 2021 Epoch: 2


100%|██████████| 5/5 [06:06<00:00, 73.32s/it]
loss: 0.2220, smth: 0.2249: 100%|██████████| 13575/13575 [09:53<00:00, 22.88it/s]
100%|██████████| 10083/10083 [02:06<00:00, 79.90it/s]


Thu Jun 10 11:53:21 2021 Epoch 2, lr: 0.1000000, 0.0010000, train loss: 0.2268, valid loss: 0.2222, mean_rce: 18.61, retweet: 22.99, reply: 19.35, like: 22.57, retweet_comment: 9.52
rce_best increased (14.766256 --> 18.606373).  Saving model ...
Thu Jun 10 12:00:29 2021 Epoch: 3


100%|██████████| 5/5 [06:10<00:00, 74.15s/it]
loss: 0.2233, smth: 0.2198: 100%|██████████| 13671/13671 [09:52<00:00, 23.07it/s]
100%|██████████| 10083/10083 [02:06<00:00, 79.87it/s]


Thu Jun 10 12:18:47 2021 Epoch 3, lr: 0.1000000, 0.0010000, train loss: 0.2196, valid loss: 0.2188, mean_rce: 19.81, retweet: 24.78, reply: 20.21, like: 23.62, retweet_comment: 10.63
rce_best increased (18.606373 --> 19.811167).  Saving model ...
Thu Jun 10 12:26:01 2021 Epoch: 4


100%|██████████| 5/5 [06:28<00:00, 77.66s/it]
loss: 0.2101, smth: 0.2141: 100%|██████████| 14636/14636 [10:34<00:00, 23.08it/s]
100%|██████████| 10083/10083 [02:06<00:00, 79.71it/s]


Thu Jun 10 12:45:22 2021 Epoch 4, lr: 0.0995134, 0.0009951, train loss: 0.2162, valid loss: 0.2140, mean_rce: 21.43, retweet: 26.11, reply: 22.21, like: 25.37, retweet_comment: 12.02
rce_best increased (19.811167 --> 21.429255).  Saving model ...


  0%|          | 0/5 [00:00<?, ?it/s]

Thu Jun 10 12:52:23 2021 Epoch: 5


100%|██████████| 5/5 [06:22<00:00, 76.41s/it]
loss: 0.2123, smth: 0.2121: 100%|██████████| 14331/14331 [10:20<00:00, 23.09it/s]
100%|██████████| 10083/10083 [02:06<00:00, 79.49it/s]


Thu Jun 10 13:11:32 2021 Epoch 5, lr: 0.0989074, 0.0009891, train loss: 0.2136, valid loss: 0.2111, mean_rce: 22.45, retweet: 27.24, reply: 23.11, like: 26.42, retweet_comment: 13.04
rce_best increased (21.429255 --> 22.453396).  Saving model ...
Thu Jun 10 13:18:29 2021 Epoch: 6


100%|██████████| 5/5 [05:52<00:00, 70.58s/it]
loss: 0.2047, smth: 0.2106: 100%|██████████| 12543/12543 [09:07<00:00, 22.89it/s]
100%|██████████| 10083/10083 [02:04<00:00, 81.11it/s]


Thu Jun 10 13:35:44 2021 Epoch 6, lr: 0.0980631, 0.0009806, train loss: 0.2115, valid loss: 0.2089, mean_rce: 23.20, retweet: 28.01, reply: 23.84, like: 27.16, retweet_comment: 13.80
rce_best increased (22.453396 --> 23.202980).  Saving model ...


  0%|          | 0/5 [00:00<?, ?it/s]

Thu Jun 10 13:42:53 2021 Epoch: 7


100%|██████████| 5/5 [06:22<00:00, 76.57s/it]
loss: 0.2038, smth: 0.2116: 100%|██████████| 13903/13903 [10:05<00:00, 22.94it/s]
100%|██████████| 10083/10083 [02:07<00:00, 79.11it/s]


Thu Jun 10 14:01:45 2021 Epoch 7, lr: 0.0969846, 0.0009698, train loss: 0.2100, valid loss: 0.2074, mean_rce: 23.75, retweet: 28.50, reply: 24.45, like: 27.71, retweet_comment: 14.34
rce_best increased (23.202980 --> 23.750500).  Saving model ...


  0%|          | 0/5 [00:00<?, ?it/s]

Thu Jun 10 14:08:29 2021 Epoch: 8


100%|██████████| 5/5 [06:27<00:00, 77.45s/it]
loss: 0.1998, smth: 0.2077: 100%|██████████| 14638/14638 [10:41<00:00, 22.84it/s]
100%|██████████| 10083/10083 [02:07<00:00, 79.31it/s]


Thu Jun 10 14:27:56 2021 Epoch 8, lr: 0.0956773, 0.0009568, train loss: 0.2086, valid loss: 0.2059, mean_rce: 24.26, retweet: 29.17, reply: 24.86, like: 28.21, retweet_comment: 14.80
rce_best increased (23.750500 --> 24.257999).  Saving model ...


  0%|          | 0/5 [00:00<?, ?it/s]

Thu Jun 10 14:35:18 2021 Epoch: 9


100%|██████████| 5/5 [06:37<00:00, 79.52s/it]
loss: 0.2072, smth: 0.2061: 100%|██████████| 14635/14635 [10:48<00:00, 22.56it/s]
100%|██████████| 10083/10083 [02:05<00:00, 80.13it/s]


Thu Jun 10 14:55:04 2021 Epoch 9, lr: 0.0941474, 0.0009415, train loss: 0.2074, valid loss: 0.2043, mean_rce: 24.85, retweet: 29.74, reply: 25.42, like: 28.74, retweet_comment: 15.50
rce_best increased (24.257999 --> 24.851948).  Saving model ...
Thu Jun 10 15:02:17 2021 Epoch: 10


100%|██████████| 5/5 [06:10<00:00, 74.20s/it]
loss: 0.2023, smth: 0.2057: 100%|██████████| 13865/13865 [09:58<00:00, 23.18it/s]
100%|██████████| 10083/10083 [02:07<00:00, 79.34it/s]


Thu Jun 10 15:20:43 2021 Epoch 10, lr: 0.0924024, 0.0009240, train loss: 0.2063, valid loss: 0.2029, mean_rce: 25.29, retweet: 30.22, reply: 25.85, like: 29.26, retweet_comment: 15.83
rce_best increased (24.851948 --> 25.290001).  Saving model ...


  0%|          | 0/5 [00:00<?, ?it/s]

Thu Jun 10 15:27:52 2021 Epoch: 11


100%|██████████| 5/5 [06:11<00:00, 74.20s/it]
loss: 0.1890, smth: 0.2032: 100%|██████████| 13673/13673 [09:51<00:00, 23.11it/s]
100%|██████████| 10083/10083 [02:08<00:00, 78.75it/s]


Thu Jun 10 15:46:12 2021 Epoch 11, lr: 0.0904508, 0.0009045, train loss: 0.2050, valid loss: 0.2020, mean_rce: 25.67, retweet: 30.61, reply: 26.30, like: 29.54, retweet_comment: 16.24
rce_best increased (25.290001 --> 25.671610).  Saving model ...
Thu Jun 10 15:53:32 2021 Epoch: 12


100%|██████████| 5/5 [06:25<00:00, 77.10s/it]
loss: 0.2079, smth: 0.2037: 100%|██████████| 14359/14359 [10:28<00:00, 22.83it/s]
100%|██████████| 10083/10083 [02:07<00:00, 79.21it/s]


Thu Jun 10 16:12:55 2021 Epoch 12, lr: 0.0883022, 0.0008830, train loss: 0.2045, valid loss: 0.2006, mean_rce: 26.17, retweet: 31.14, reply: 26.93, like: 29.99, retweet_comment: 16.63
rce_best increased (25.671610 --> 26.171452).  Saving model ...


  0%|          | 0/5 [00:00<?, ?it/s]

Thu Jun 10 16:19:59 2021 Epoch: 13


100%|██████████| 5/5 [06:15<00:00, 75.00s/it]
loss: 0.1947, smth: 0.2033: 100%|██████████| 13714/13714 [10:00<00:00, 22.84it/s]
100%|██████████| 10083/10083 [02:08<00:00, 78.70it/s]


Thu Jun 10 16:38:56 2021 Epoch 13, lr: 0.0859670, 0.0008597, train loss: 0.2036, valid loss: 0.1995, mean_rce: 26.53, retweet: 31.46, reply: 27.04, like: 30.43, retweet_comment: 17.18
rce_best increased (26.171452 --> 26.527067).  Saving model ...


  0%|          | 0/5 [00:00<?, ?it/s]

Thu Jun 10 16:45:25 2021 Epoch: 14


100%|██████████| 5/5 [06:17<00:00, 75.40s/it]
loss: 0.1864, smth: 0.2016: 100%|██████████| 14312/14312 [10:20<00:00, 23.05it/s]
100%|██████████| 10083/10083 [02:07<00:00, 79.22it/s]


Thu Jun 10 17:04:42 2021 Epoch 14, lr: 0.0834565, 0.0008346, train loss: 0.2029, valid loss: 0.1989, mean_rce: 26.83, retweet: 31.67, reply: 27.67, like: 30.58, retweet_comment: 17.40
rce_best increased (26.527067 --> 26.829527).  Saving model ...


  0%|          | 0/5 [00:00<?, ?it/s]

Thu Jun 10 17:10:59 2021 Epoch: 15


100%|██████████| 5/5 [06:30<00:00, 78.05s/it]
loss: 0.2114, smth: 0.2002: 100%|██████████| 14636/14636 [10:42<00:00, 22.77it/s]
100%|██████████| 10083/10083 [02:10<00:00, 77.19it/s]


Thu Jun 10 17:31:09 2021 Epoch 15, lr: 0.0807831, 0.0008078, train loss: 0.2021, valid loss: 0.1973, mean_rce: 27.27, retweet: 32.13, reply: 28.04, like: 31.26, retweet_comment: 17.66
rce_best increased (26.829527 --> 27.269592).  Saving model ...
Thu Jun 10 17:37:23 2021 Epoch: 16


100%|██████████| 5/5 [06:17<00:00, 75.46s/it]
loss: 0.1965, smth: 0.2020: 100%|██████████| 14522/14522 [10:32<00:00, 22.95it/s]
100%|██████████| 10083/10083 [02:09<00:00, 77.99it/s]


Thu Jun 10 17:56:36 2021 Epoch 16, lr: 0.0779596, 0.0007796, train loss: 0.2015, valid loss: 0.1967, mean_rce: 27.50, retweet: 32.42, reply: 28.18, like: 31.41, retweet_comment: 17.98
rce_best increased (27.269592 --> 27.497391).  Saving model ...
Thu Jun 10 18:03:23 2021 Epoch: 17


100%|██████████| 5/5 [06:28<00:00, 77.73s/it]
loss: 0.2037, smth: 0.2008: 100%|██████████| 14856/14856 [10:53<00:00, 22.72it/s]
100%|██████████| 10083/10083 [02:06<00:00, 79.52it/s]


Thu Jun 10 18:23:05 2021 Epoch 17, lr: 0.0750000, 0.0007500, train loss: 0.2008, valid loss: 0.1958, mean_rce: 27.82, retweet: 32.56, reply: 28.60, like: 31.77, retweet_comment: 18.35
rce_best increased (27.497391 --> 27.820614).  Saving model ...


  0%|          | 0/5 [00:00<?, ?it/s]

Thu Jun 10 18:29:27 2021 Epoch: 18


100%|██████████| 5/5 [06:42<00:00, 80.41s/it]
loss: 0.1868, smth: 0.1990: 100%|██████████| 14637/14637 [10:40<00:00, 22.84it/s]
100%|██████████| 10083/10083 [02:09<00:00, 78.03it/s]


Thu Jun 10 18:49:33 2021 Epoch 18, lr: 0.0719186, 0.0007192, train loss: 0.2002, valid loss: 0.1952, mean_rce: 28.15, retweet: 33.01, reply: 28.92, like: 31.87, retweet_comment: 18.80
rce_best increased (27.820614 --> 28.147221).  Saving model ...
Thu Jun 10 18:55:18 2021 Epoch: 19


100%|██████████| 5/5 [06:16<00:00, 75.24s/it]
loss: 0.2204, smth: 0.1977: 100%|██████████| 13821/13821 [10:01<00:00, 22.98it/s]
100%|██████████| 10083/10083 [02:08<00:00, 78.50it/s]


Thu Jun 10 19:14:19 2021 Epoch 19, lr: 0.0687303, 0.0006873, train loss: 0.1986, valid loss: 0.1937, mean_rce: 28.51, retweet: 33.26, reply: 29.15, like: 32.61, retweet_comment: 19.03
rce_best increased (28.147221 --> 28.511463).  Saving model ...


  0%|          | 0/5 [00:00<?, ?it/s]

Thu Jun 10 19:19:49 2021 Epoch: 20


100%|██████████| 5/5 [06:17<00:00, 75.47s/it]
loss: 0.2101, smth: 0.1990: 100%|██████████| 14182/14182 [10:16<00:00, 22.99it/s]
100%|██████████| 10083/10083 [02:08<00:00, 78.52it/s]


Thu Jun 10 19:39:12 2021 Epoch 20, lr: 0.0654508, 0.0006545, train loss: 0.1992, valid loss: 0.1932, mean_rce: 28.68, retweet: 33.43, reply: 29.39, like: 32.76, retweet_comment: 19.16
rce_best increased (28.511463 --> 28.684906).  Saving model ...
Thu Jun 10 19:44:52 2021 Epoch: 21


100%|██████████| 5/5 [06:54<00:00, 82.92s/it]
loss: 0.1962, smth: 0.2007: 100%|██████████| 14296/14296 [10:19<00:00, 23.07it/s]
100%|██████████| 10083/10083 [02:07<00:00, 78.90it/s]


Thu Jun 10 20:04:41 2021 Epoch 21, lr: 0.0620961, 0.0006210, train loss: 0.1987, valid loss: 0.1929, mean_rce: 28.83, retweet: 33.60, reply: 29.50, like: 32.80, retweet_comment: 19.43
rce_best increased (28.684906 --> 28.832020).  Saving model ...


  0%|          | 0/5 [00:00<?, ?it/s]

Thu Jun 10 20:10:10 2021 Epoch: 22


100%|██████████| 5/5 [06:16<00:00, 75.31s/it]
loss: 0.1888, smth: 0.1975: 100%|██████████| 13550/13550 [09:45<00:00, 23.15it/s]
100%|██████████| 10083/10083 [02:08<00:00, 78.56it/s]


Thu Jun 10 20:28:49 2021 Epoch 22, lr: 0.0586824, 0.0005868, train loss: 0.1983, valid loss: 0.1921, mean_rce: 29.15, retweet: 33.91, reply: 29.92, like: 33.06, retweet_comment: 19.71
rce_best increased (28.832020 --> 29.149475).  Saving model ...


  0%|          | 0/5 [00:00<?, ?it/s]

Thu Jun 10 20:34:29 2021 Epoch: 23


100%|██████████| 5/5 [06:47<00:00, 81.45s/it]
loss: 0.1969, smth: 0.1982: 100%|██████████| 14635/14635 [10:36<00:00, 22.99it/s]
100%|██████████| 10083/10083 [02:08<00:00, 78.64it/s]


Thu Jun 10 20:54:33 2021 Epoch 23, lr: 0.0552264, 0.0005523, train loss: 0.1978, valid loss: 0.1918, mean_rce: 29.28, retweet: 33.91, reply: 30.13, like: 33.16, retweet_comment: 19.90
rce_best increased (29.149475 --> 29.275816).  Saving model ...


  0%|          | 0/5 [00:00<?, ?it/s]

Thu Jun 10 21:00:01 2021 Epoch: 24


100%|██████████| 5/5 [07:36<00:00, 91.33s/it]
loss: 0.2044, smth: 0.1965: 100%|██████████| 14180/14180 [10:11<00:00, 23.19it/s]
100%|██████████| 10083/10083 [02:08<00:00, 78.68it/s]


Thu Jun 10 21:20:24 2021 Epoch 24, lr: 0.0517450, 0.0005174, train loss: 0.1973, valid loss: 0.1910, mean_rce: 29.55, retweet: 34.42, reply: 30.37, like: 33.40, retweet_comment: 19.99
rce_best increased (29.275816 --> 29.547853).  Saving model ...


  0%|          | 0/5 [00:00<?, ?it/s]

Thu Jun 10 21:25:54 2021 Epoch: 25


100%|██████████| 5/5 [06:07<00:00, 73.48s/it]
loss: 0.2042, smth: 0.1971: 100%|██████████| 13862/13862 [10:02<00:00, 23.02it/s]
100%|██████████| 10083/10083 [02:08<00:00, 78.29it/s]


Thu Jun 10 21:44:25 2021 Epoch 25, lr: 0.0482550, 0.0004826, train loss: 0.1970, valid loss: 0.1903, mean_rce: 29.74, retweet: 34.35, reply: 30.53, like: 33.77, retweet_comment: 20.31
rce_best increased (29.547853 --> 29.740671).  Saving model ...


  0%|          | 0/5 [00:00<?, ?it/s]

Thu Jun 10 21:49:50 2021 Epoch: 26


100%|██████████| 5/5 [06:16<00:00, 75.26s/it]
loss: 0.1845, smth: 0.1971: 100%|██████████| 14634/14634 [10:34<00:00, 23.05it/s]
100%|██████████| 10083/10083 [02:08<00:00, 78.28it/s]


Thu Jun 10 22:09:03 2021 Epoch 26, lr: 0.0447736, 0.0004477, train loss: 0.1967, valid loss: 0.1898, mean_rce: 29.89, retweet: 34.62, reply: 30.50, like: 33.96, retweet_comment: 20.46
rce_best increased (29.740671 --> 29.885460).  Saving model ...


  0%|          | 0/5 [00:00<?, ?it/s]

Thu Jun 10 22:14:37 2021 Epoch: 27


100%|██████████| 5/5 [06:26<00:00, 77.30s/it]
loss: 0.1801, smth: 0.1958: 100%|██████████| 14671/14671 [10:32<00:00, 23.18it/s]
100%|██████████| 10083/10083 [02:08<00:00, 78.38it/s]


Thu Jun 10 22:34:00 2021 Epoch 27, lr: 0.0413176, 0.0004132, train loss: 0.1963, valid loss: 0.1897, mean_rce: 29.98, retweet: 34.77, reply: 30.52, like: 33.96, retweet_comment: 20.69
rce_best increased (29.885460 --> 29.984455).  Saving model ...


  0%|          | 0/5 [00:00<?, ?it/s]

Thu Jun 10 22:39:28 2021 Epoch: 28


100%|██████████| 5/5 [06:20<00:00, 76.02s/it]
loss: 0.1929, smth: 0.1942: 100%|██████████| 14626/14626 [10:31<00:00, 23.16it/s]
100%|██████████| 10083/10083 [02:09<00:00, 78.11it/s]


Thu Jun 10 22:58:37 2021 Epoch 28, lr: 0.0379039, 0.0003790, train loss: 0.1958, valid loss: 0.1887, mean_rce: 30.30, retweet: 34.94, reply: 31.09, like: 34.35, retweet_comment: 20.82
rce_best increased (29.984455 --> 30.300800).  Saving model ...
Thu Jun 10 23:04:07 2021 Epoch: 29


100%|██████████| 5/5 [06:20<00:00, 76.04s/it]
loss: 0.1886, smth: 0.1968: 100%|██████████| 14668/14668 [10:35<00:00, 23.09it/s]
100%|██████████| 10083/10083 [02:08<00:00, 78.59it/s]


Thu Jun 10 23:23:19 2021 Epoch 29, lr: 0.0345492, 0.0003455, train loss: 0.1956, valid loss: 0.1885, mean_rce: 30.38, retweet: 35.05, reply: 31.12, like: 34.42, retweet_comment: 20.94
rce_best increased (30.300800 --> 30.380356).  Saving model ...
Thu Jun 10 23:28:48 2021 Epoch: 30


100%|██████████| 5/5 [06:19<00:00, 75.94s/it]
loss: 0.1890, smth: 0.1957: 100%|██████████| 14645/14645 [10:36<00:00, 23.01it/s]
100%|██████████| 10083/10083 [02:08<00:00, 78.65it/s]


Thu Jun 10 23:48:01 2021 Epoch 30, lr: 0.0312697, 0.0003127, train loss: 0.1953, valid loss: 0.1880, mean_rce: 30.55, retweet: 35.23, reply: 31.24, like: 34.61, retweet_comment: 21.12
rce_best increased (30.380356 --> 30.551567).  Saving model ...


  0%|          | 0/5 [00:00<?, ?it/s]

Thu Jun 10 23:53:29 2021 Epoch: 31


100%|██████████| 5/5 [05:58<00:00, 71.76s/it]
loss: 0.1879, smth: 0.1957: 100%|██████████| 13018/13018 [09:22<00:00, 23.14it/s]
100%|██████████| 10083/10083 [02:08<00:00, 78.69it/s]


Fri Jun 11 00:11:06 2021 Epoch 31, lr: 0.0280814, 0.0002808, train loss: 0.1949, valid loss: 0.1877, mean_rce: 30.65, retweet: 35.29, reply: 31.31, like: 34.72, retweet_comment: 21.28
rce_best increased (30.551567 --> 30.645824).  Saving model ...
Fri Jun 11 00:16:34 2021 Epoch: 32


100%|██████████| 5/5 [06:26<00:00, 77.24s/it]
loss: 0.1955, smth: 0.1931: 100%|██████████| 13865/13865 [09:58<00:00, 23.17it/s]
100%|██████████| 10083/10083 [02:09<00:00, 78.06it/s]


Fri Jun 11 00:35:22 2021 Epoch 32, lr: 0.0250000, 0.0002500, train loss: 0.1934, valid loss: 0.1869, mean_rce: 31.00, retweet: 35.65, reply: 31.86, like: 34.92, retweet_comment: 21.55
rce_best increased (30.645824 --> 30.996300).  Saving model ...


  0%|          | 0/5 [00:00<?, ?it/s]

Fri Jun 11 00:40:50 2021 Epoch: 33


100%|██████████| 5/5 [06:49<00:00, 81.81s/it]
loss: 0.1985, smth: 0.1948: 100%|██████████| 14470/14470 [10:22<00:00, 23.23it/s]
100%|██████████| 10083/10083 [02:08<00:00, 78.25it/s]


Fri Jun 11 01:00:23 2021 Epoch 33, lr: 0.0220404, 0.0002204, train loss: 0.1944, valid loss: 0.1866, mean_rce: 31.01, retweet: 35.59, reply: 31.72, like: 35.13, retweet_comment: 21.61
rce_best increased (30.996300 --> 31.011581).  Saving model ...


  0%|          | 0/5 [00:00<?, ?it/s]

Fri Jun 11 01:05:53 2021 Epoch: 34


100%|██████████| 5/5 [06:27<00:00, 77.46s/it]
loss: 0.2125, smth: 0.1938: 100%|██████████| 14637/14637 [10:32<00:00, 23.15it/s]
100%|██████████| 10083/10083 [02:08<00:00, 78.26it/s]


Fri Jun 11 01:25:09 2021 Epoch 34, lr: 0.0192169, 0.0001922, train loss: 0.1943, valid loss: 0.1864, mean_rce: 31.10, retweet: 35.71, reply: 31.83, like: 35.17, retweet_comment: 21.70
rce_best increased (31.011581 --> 31.102623).  Saving model ...
Fri Jun 11 01:30:37 2021 Epoch: 35


100%|██████████| 5/5 [06:21<00:00, 76.37s/it]
loss: 0.1897, smth: 0.1944: 100%|██████████| 14635/14635 [10:32<00:00, 23.15it/s]
100%|██████████| 10083/10083 [02:09<00:00, 77.79it/s]


Fri Jun 11 01:49:49 2021 Epoch 35, lr: 0.0165435, 0.0001654, train loss: 0.1940, valid loss: 0.1865, mean_rce: 31.11, retweet: 35.75, reply: 31.76, like: 35.11, retweet_comment: 21.81
rce_best increased (31.102623 --> 31.105156).  Saving model ...
Fri Jun 11 01:55:12 2021 Epoch: 36


100%|██████████| 5/5 [06:05<00:00, 73.12s/it]
loss: 0.1920, smth: 0.1931: 100%|██████████| 13586/13586 [09:46<00:00, 23.15it/s]
100%|██████████| 10083/10083 [02:07<00:00, 78.82it/s]


Fri Jun 11 02:13:26 2021 Epoch 36, lr: 0.0140330, 0.0001403, train loss: 0.1937, valid loss: 0.1862, mean_rce: 31.20, retweet: 35.84, reply: 31.88, like: 35.22, retweet_comment: 21.87
rce_best increased (31.105156 --> 31.201717).  Saving model ...


  0%|          | 0/5 [00:00<?, ?it/s]

Fri Jun 11 02:18:52 2021 Epoch: 37


100%|██████████| 5/5 [06:50<00:00, 82.17s/it]
loss: 0.1896, smth: 0.1930: 100%|██████████| 14526/14526 [10:34<00:00, 22.89it/s]
100%|██████████| 10083/10083 [02:08<00:00, 78.35it/s]


Fri Jun 11 02:38:34 2021 Epoch 37, lr: 0.0116978, 0.0001170, train loss: 0.1936, valid loss: 0.1860, mean_rce: 31.26, retweet: 35.90, reply: 31.95, like: 35.27, retweet_comment: 21.92
rce_best increased (31.201717 --> 31.259350).  Saving model ...


  0%|          | 0/5 [00:00<?, ?it/s]

Fri Jun 11 02:44:04 2021 Epoch: 38


100%|██████████| 5/5 [06:18<00:00, 75.74s/it]
loss: 0.1923, smth: 0.1928: 100%|██████████| 14636/14636 [10:29<00:00, 23.27it/s]
100%|██████████| 10083/10083 [02:08<00:00, 78.26it/s]


Fri Jun 11 03:03:12 2021 Epoch 38, lr: 0.0095492, 0.0000955, train loss: 0.1935, valid loss: 0.1856, mean_rce: 31.38, retweet: 35.94, reply: 32.10, like: 35.47, retweet_comment: 22.00
rce_best increased (31.259350 --> 31.378683).  Saving model ...


  0%|          | 0/5 [00:00<?, ?it/s]

Fri Jun 11 03:08:45 2021 Epoch: 39


100%|██████████| 5/5 [06:18<00:00, 75.63s/it]
loss: 0.1938, smth: 0.1924: 100%|██████████| 14637/14637 [10:30<00:00, 23.22it/s]
100%|██████████| 10083/10083 [02:09<00:00, 78.01it/s]


Fri Jun 11 03:27:51 2021 Epoch 39, lr: 0.0075976, 0.0000760, train loss: 0.1935, valid loss: 0.1857, mean_rce: 31.38, retweet: 35.98, reply: 32.07, like: 35.42, retweet_comment: 22.05
rce_best increased (31.378683 --> 31.382784).  Saving model ...
Fri Jun 11 03:33:14 2021 Epoch: 40


100%|██████████| 5/5 [06:06<00:00, 73.27s/it]
loss: 0.1852, smth: 0.1908: 100%|██████████| 13508/13508 [09:43<00:00, 23.15it/s]
100%|██████████| 10083/10083 [02:09<00:00, 77.91it/s]


Fri Jun 11 03:51:22 2021 Epoch 40, lr: 0.0058526, 0.0000585, train loss: 0.1920, valid loss: 0.1848, mean_rce: 31.62, retweet: 36.11, reply: 32.34, like: 35.81, retweet_comment: 22.21
rce_best increased (31.382784 --> 31.618502).  Saving model ...


  0%|          | 0/5 [00:00<?, ?it/s]

Fri Jun 11 03:56:46 2021 Epoch: 41


100%|██████████| 5/5 [06:49<00:00, 81.84s/it]
loss: 0.1909, smth: 0.1929: 100%|██████████| 14637/14637 [10:30<00:00, 23.20it/s]
100%|██████████| 10083/10083 [02:11<00:00, 76.79it/s]


Fri Jun 11 04:16:30 2021 Epoch 41, lr: 0.0043227, 0.0000432, train loss: 0.1932, valid loss: 0.1852, mean_rce: 31.53, retweet: 36.09, reply: 32.25, like: 35.60, retweet_comment: 22.16
Fri Jun 11 04:20:32 2021 Epoch: 42


100%|██████████| 5/5 [06:38<00:00, 79.70s/it]
loss: 0.1925, smth: 0.1939: 100%|██████████| 14473/14473 [10:23<00:00, 23.21it/s]
100%|██████████| 10083/10083 [02:07<00:00, 78.99it/s]


Fri Jun 11 04:39:55 2021 Epoch 42, lr: 0.0030154, 0.0000302, train loss: 0.1931, valid loss: 0.1850, mean_rce: 31.59, retweet: 36.14, reply: 32.29, like: 35.72, retweet_comment: 22.20


  0%|          | 0/5 [00:00<?, ?it/s]

Fri Jun 11 04:44:03 2021 Epoch: 43


100%|██████████| 5/5 [06:50<00:00, 82.01s/it]
loss: 0.1824, smth: 0.1912: 100%|██████████| 14187/14187 [10:08<00:00, 23.31it/s]
100%|██████████| 10083/10083 [02:08<00:00, 78.20it/s]


Fri Jun 11 05:03:23 2021 Epoch 43, lr: 0.0019369, 0.0000194, train loss: 0.1930, valid loss: 0.1851, mean_rce: 31.55, retweet: 36.08, reply: 32.24, like: 35.65, retweet_comment: 22.22


  0%|          | 0/5 [00:00<?, ?it/s]

Fri Jun 11 05:07:18 2021 Epoch: 44


100%|██████████| 5/5 [07:08<00:00, 85.72s/it]
loss: 0.1821, smth: 0.1947: 100%|██████████| 14784/14784 [10:46<00:00, 22.86it/s]
100%|██████████| 10083/10083 [02:09<00:00, 77.76it/s]


Fri Jun 11 05:27:52 2021 Epoch 44, lr: 0.0010926, 0.0000109, train loss: 0.1930, valid loss: 0.1851, mean_rce: 31.58, retweet: 36.13, reply: 32.36, like: 35.66, retweet_comment: 22.19


  0%|          | 0/5 [00:00<?, ?it/s]

Fri Jun 11 05:31:57 2021 Epoch: 45


100%|██████████| 5/5 [06:53<00:00, 82.63s/it]
loss: 0.1905, smth: 0.1939: 100%|██████████| 13823/13823 [09:59<00:00, 23.05it/s]
100%|██████████| 10083/10083 [02:09<00:00, 77.83it/s]


Fri Jun 11 05:51:34 2021 Epoch 45, lr: 0.0004866, 0.0000049, train loss: 0.1928, valid loss: 0.1851, mean_rce: 31.57, retweet: 36.12, reply: 32.31, like: 35.63, retweet_comment: 22.22
Fri Jun 11 05:55:37 2021 Epoch: 46


100%|██████████| 5/5 [06:18<00:00, 75.67s/it]
loss: 0.2006, smth: 0.1915: 100%|██████████| 14635/14635 [10:31<00:00, 23.16it/s]
100%|██████████| 10083/10083 [02:07<00:00, 78.79it/s]


Fri Jun 11 06:14:47 2021 Epoch 46, lr: 0.0001218, 0.0000012, train loss: 0.1929, valid loss: 0.1851, mean_rce: 31.59, retweet: 36.14, reply: 32.32, like: 35.65, retweet_comment: 22.22


## Load the best epoch and run inference on the leaderboard (LB) validation set

In [11]:
def read_norm_merge(ddf, categories_dir='/raid/recsys_pre_TE_w_tok/workflow_232parts_joint_thr3_pos/categories'):
    """Add derived features, normalise numeric columns and map categoricals to ids.

    The frame is processed in four steps:

    1. Derived columns are added (follower-count ``quantile`` bucket, ``date``,
       follower/following ratios, account-age features).
    2. Every column listed in the module-level ``NUMERIC_COLUMNS`` is
       normalised: ``ab_age_dff`` is divided by 256, every other int/float
       column goes through ``log1p``, and float64 results are cast to float32.
    3. Every column listed in the module-level ``CAT_COLUMNS`` is replaced by
       the row index of its value in the matching
       ``unique.<column>.parquet`` mapping file; values missing from the
       mapping become 0.  ``a_user_id`` and ``b_user_id`` share the joint
       ``a_user_id_b_user_id`` mapping.
    4. Columns that are not model inputs are dropped.

    Args:
        ddf: A cudf DataFrame (the code relies on ``cudf.to_datetime`` and on
            ``.values.get()``).  It is modified in place by steps 1-2.
        categories_dir: Directory holding the ``unique.<column>.parquet``
            mapping files.  Defaults to the location used by the original
            notebook.

    Returns:
        A new DataFrame without the unused columns.  The original
        ``b_user_id`` values are kept in ``b_user_id_hash``.

    NOTE(review): cudf merges are not guaranteed to preserve row order, so the
    returned rows may be ordered differently from the input -- confirm if any
    caller relies on the input order.
    """
    # Follower-count bucket: the number of thresholds that a_follower_count
    # exceeds, i.e. a value in 0..4.
    ddf['quantile'] = 0
    quantiles = [ 240,  588, 1331, 3996]
    for quant in quantiles:
        ddf['quantile'] = (ddf['quantile']+(ddf['a_follower_count']>quant).astype('int8')).astype('int8')

    ddf['date'] = cudf.to_datetime(ddf['timestamp'], unit='s')

    ddf['a_ff_rate'] = (ddf['a_following_count'] / ddf['a_follower_count']).astype('float32')
    # NOTE(review): this ratio is follower/following, the inverse of a_ff_rate
    # above.  Left unchanged because the trained model presumably saw the same
    # definition -- confirm against the training preprocessing.
    ddf['b_ff_rate'] = (ddf['b_follower_count']  / ddf['b_following_count']).astype('float32')
    ddf['ab_fing_rate'] = (ddf['a_following_count'] / ddf['b_following_count']).astype('float32')
    ddf['ab_fer_rate'] = (ddf['a_follower_count'] / (1+ddf['b_follower_count'])).astype('float32')
    ddf['a_age'] = ddf['a_account_creation'].astype('int16') + 128
    ddf['b_age'] = ddf['b_account_creation'].astype('int16') + 128
    ddf['ab_age_dff'] = ddf['b_age'] - ddf['a_age']
    ddf['ab_age_rate'] = ddf['a_age']/(1+ddf['b_age'])

    ## Normalize
    for col in NUMERIC_COLUMNS:
        if col == 'tw_len_quest':
            # Negative values are clipped to 0 before the log1p below.
            ddf[col] = np.clip(ddf[col].values.get(),0,None)
        if ddf[col].dtype == 'uint16':
            # astype returns a new column; the original code discarded the
            # result, which made this widening cast a no-op.
            ddf[col] = ddf[col].astype('int32')

        if col == 'ab_age_dff':
            ddf[col] = ddf[col] / 256.
        elif 'int' in str(ddf[col].dtype) or 'float' in str(ddf[col].dtype):
            ddf[col] = np.log1p(ddf[col])

        if ddf[col].dtype == 'float64':
            ddf[col] = ddf[col].astype('float32')

    # Keep the raw id: b_user_id itself is replaced by an embedding index below.
    ddf['b_user_id_hash'] = ddf['b_user_id'].copy()

    ## get categorical embedding id
    for col in CAT_COLUMNS:
        ddf[col] = ddf[col].astype('float')
        if col in ['a_user_id','b_user_id']:
            mapping_col = 'a_user_id_b_user_id'
        else:
            mapping_col = col
        mapping = cudf.read_parquet(f'{categories_dir}/unique.{mapping_col}.parquet').reset_index()
        mapping.columns = ['index',col]
        # Left join on the raw value, then swap the value column for the
        # mapping's row index; unmatched values end up as NaN -> 0.
        ddf = ddf.merge(mapping, how='left', on=col).drop(columns=[col]).rename(columns={'index':col})
        ddf[col] = ddf[col].fillna(0).astype('int')

    dont_use = ['timestamp','a_account_creation','b_account_creation','engage_time',
                'fold', 'dt_dow', 'elapsed_time', 'links','domains','hashtags','id', 'date', 'is_train',
                'tw_hash0', 'tw_hash1', 'tw_hash2', 'tw_http0', 'tw_uhash', 'tw_hash', 'tw_word0',
                'tw_word1', 'tw_word2', 'tw_word3', 'tw_word4', 'dt_minute', 'dt_second',
                'dt_day', 'group', 'text', 'tw_original_user0', 'tw_original_user1', 'tw_original_user2',
                'tw_rt_user0', 'tw_original_http0', 'tw_tweet',]
    # Only drop the columns that are actually present in this frame.
    drop_cols = [c for c in ddf.columns if c in dont_use]
    gc.collect(); gc.collect()

    return ddf.drop(columns=drop_cols)

In [12]:
%%time
# Preprocess the validation parquet in two row ranges (the first 7,000,000
# rows, then everything after them).  Each part is converted to pandas before
# the next one is read -- presumably to limit peak GPU memory; confirm.
df = cudf.read_parquet('/raid/recsys_valid/valid_proc.parquet',num_rows=7_000_000)
df = read_norm_merge(df).to_pandas()

df2 = cudf.read_parquet('/raid/recsys_valid/valid_proc.parquet',skiprows=7_000_000)
df2 = read_norm_merge(df2).to_pandas()

# NOTE: the concatenated frame keeps each part's own index, so index labels
# repeat; later cells assign to / read from `valid` positionally.
valid = pd.concat([df,df2])
del df,df2
gc.collect()

valid.shape

CPU times: user 17.1 s, sys: 11.4 s, total: 28.5 s
Wall time: 34.6 s


(14461760, 49)

In [13]:
# Wrap the preprocessed validation frame for batched inference.
# shuffle=False keeps batches in dataset order; the submission cell later
# assigns predictions back to `valid` by position, which assumes AllDataset
# yields rows in frame order -- TODO confirm.
valid_dataset = AllDataset(valid, max_len_txt, NUMERIC_COLUMNS, CAT_COLUMNS)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers) 

# Number of batches in one pass over the validation set.
len(valid_loader)

14123

## Make an fp16 checkpoint

In [14]:
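# One-off conversion, kept commented out for reference: it casts the first
# embedding table of the best checkpoint to fp16 and saves the result as the
# `_best_fp16.pth` file that is loaded further below.
# sd = torch.load(f'../models/{model_name}_best.pth')
# sd['initial_cat_layer.embedding_layers.0.weight'] = sd['initial_cat_layer.embedding_layers.0.weight'].half()
# torch.save(sd, f'../models/{model_name}_best_fp16.pth')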

In [15]:
# List the saved "best" checkpoints for this model (sizes and timestamps).
!ls -lrth ../models/{model_name}_best*

-rw-rw-r-- 1 bo bo 9.8G Jun 11 03:52 ../models/load_thr3_pos_1e-2_1e-4_best.pth
-rw-rw-r-- 1 bo bo 5.1G Jun 11 09:35 ../models/load_thr3_pos_1e-2_1e-4_best_fp16.pth


In [16]:
# Load the fp16 variant of the best checkpoint.
sd = torch.load(f'../models/{model_name}_best_fp16.pth')
# sd = torch.load('/home/bo/kaggle/recsys/recsysChallenge2021/bo/sub/v11_len48_thr25_joint_MF/MF_len48_joint_thr25_3weeks_best.pth')

# Strip a leading 'module.' from parameter names (presumably left by a
# DataParallel wrapper at save time) so the keys match the bare model.
sd = {
    (name[len('module.'):] if name.startswith('module.') else name): tensor
    for name, tensor in sd.items()
}
model.load_state_dict(sd, strict=True)

<All keys matched successfully>

In [17]:
# Sort the label names alphabetically; later cells pair label_names[i] with
# column i of the model output.
# NOTE(review): this assumes the model's output columns follow the sorted
# order -- confirm against the training code.
label_names = sorted(label_names)
label_names

['like', 'reply', 'retweet', 'retweet_comment']

In [18]:
# Run the model over the whole validation loader without gradients, keeping
# the per-batch loss plus the raw logits and targets on the CPU.
model.eval()
val_loss, LOGITS, TARGETS = [], [], []
with torch.no_grad():
    for batch in tqdm(valid_loader):
        # Move every tensor of the batch to the GPU, in the original order.
        x_cat, x_cont, text_tok, targets = (part.cuda() for part in batch)
        logits = model(x_cat, x_cont, text_tok)
        val_loss.append(criterion(logits, targets).item())
        LOGITS.append(logits.cpu())
        TARGETS.append(targets.cpu())

# Stack the per-batch results into single tensors.
LOGITS = torch.cat(LOGITS)
TARGETS = torch.cat(TARGETS)

# RCE per label, computed on sigmoid probabilities; column i of the model
# output is paired with label_names[i].
rce = {
    label_names[i]: compute_rce_fast(
        cp.asarray(LOGITS[:, i].sigmoid()),
        cp.asarray(TARGETS[:, i]),
    ).get()
    for i in range(4)
}
mean_rce = np.mean(list(rce.values()))
mean_rce

100%|██████████| 14123/14123 [02:48<00:00, 83.70it/s]


15.035736

In [19]:
# Earlier variant that read the quantile column back from the parquet files:
# df_quantile = pd.concat([pd.read_parquet(path)[['quantile']] for path in VALID_PATHS]).reset_index(drop=True)
# df_quantile = df_quantile.apply(np.expm1).round().astype(int)
# Take the quantile bucket straight from the preprocessed validation frame,
# with a fresh 0..n-1 index.
df_quantile = valid[['quantile']].copy().reset_index(drop=True)
# NOTE: not the last expression of the cell, so this value is not displayed.
df_quantile.shape

# GPU copies of the quantile buckets, predicted probabilities and targets,
# used by the per-quantile metrics in the next cell.
yquantile = cupy.asarray(df_quantile.values)
oof = cupy.asarray(LOGITS.sigmoid())
yvalid = cupy.asarray(TARGETS)

In [20]:
from util import compute_prauc, average_precision_score, display_score

# Boolean row masks, one per quantile bucket 0..4; they do not depend on the
# label, so they are built once up front.
quantile_masks = [(df_quantile['quantile'] == bucket).values for bucket in range(5)]

# label name -> list of five per-bucket scores
rce_output, ap_output = {}, {}
for i in range(4):
    prauc_out, rce_out, ap_out = [], [], []
    for this_quantile_idx in quantile_masks:
        # Targets and predictions for label i, restricted to this bucket.
        yvalid_tmp = yvalid[this_quantile_idx][:, i]
        oof_tmp = oof[this_quantile_idx][:, i]
        prauc_out.append(compute_prauc(oof_tmp, yvalid_tmp))
        rce_out.append(compute_rce_fast(oof_tmp, yvalid_tmp).item())
        ap_out.append(average_precision_score(cupy.asnumpy(yvalid_tmp), cupy.asnumpy(oof_tmp)))
    rce_output[label_names[i]] = rce_out
    ap_output[label_names[i]] = ap_out

In [21]:
# public test Epoch 40, fp16
# Prints the model name, then the per-quantile AP / RCE table built from the
# rce_output / ap_output dictionaries computed in the previous cell.
print(model_name)
display_score(rce_output, ap_output)

load_thr3_pos_1e-2_1e-4
Quantile Group|AP Retweet|RCE Retweet|  AP Reply|  RCE Reply|   AP Like|   RCE Like|AP RT comment|RCE RT comment
        0          0.3656     18.9687     0.1869     17.8583     0.5934      5.3910     0.0427      9.5393
        1          0.3521     19.3371     0.1945     19.0993     0.5752      4.7853     0.0391      9.6176
        2          0.3492     19.6940     0.2157     20.3177     0.5668      5.4737     0.0376      9.7294
        3          0.3658     20.4278     0.2408     22.2145     0.5809      7.3677     0.0415     10.8575
        4          0.3518     20.7003     0.1502     17.2623     0.6484     10.4704     0.0409     11.5956
     Average       0.3569     19.8256     0.1976     19.3504     0.5929      6.6976     0.0404     10.2679


In [22]:
# Turn logits into probabilities WITHOUT rebinding LOGITS: the original cell
# replaced the tensor with a numpy array, so re-running this cell (or the
# earlier cell that calls LOGITS.sigmoid()) failed afterwards.
probs = LOGITS.sigmoid().numpy()

# Build the submission in its own frame instead of writing the predictions
# into `valid`: that overwrote any label columns already present there, and
# `valid_dataset` was built from the same frame.
submission = valid[['tweet_id', 'b_user_id_hash']].copy()
for i, label in enumerate(label_names):
    # Column i of the model output is paired with label_names[i]; the numpy
    # column is assigned by position, so the repeated index labels in `valid`
    # do not matter.
    submission[label] = probs[:, i]

# Fixed column order for the results file; no header and no index column.
submission[['tweet_id',
    'b_user_id_hash',
    'reply',
    'retweet',
    'retweet_comment',
    'like']].to_csv(f'../results_{model_name}.csv', header=False, index=False)

In [20]:
# public test Epoch 40
# NOTE: the saved output below comes from an earlier run
# (model_name = two_opt_lr3_load_len48_joint_thr10_3e-3_1e-4); re-executing
# this cell prints the scores of whatever is currently in rce_output / ap_output.
print(model_name)
display_score(rce_output, ap_output)

two_opt_lr3_load_len48_joint_thr10_3e-3_1e-4
Quantile Group|AP Retweet|RCE Retweet|  AP Reply|  RCE Reply|   AP Like|   RCE Like|AP RT comment|RCE RT comment
        0          0.3667     19.4164     0.1833     18.0032     0.5944      7.9084     0.0368      8.7007
        1          0.3531     19.3318     0.1900     18.8197     0.5731      6.3514     0.0338      8.8032
        2          0.3493     19.4528     0.2107     19.9319     0.5639      6.5367     0.0329      8.7347
        3          0.3646     19.9658     0.2337     21.5662     0.5770      8.0349     0.0360      9.5802
        4          0.3461     19.8541     0.1425     16.2127     0.6488     11.7476     0.0364     10.5095
     Average       0.3560     19.6042     0.1920     18.9068     0.5915      8.1158     0.0352      9.2656


In [19]:
# public test Epoch 32
# NOTE: the saved output below comes from an earlier run
# (model_name = two_opt_lr3_load_len48_joint_thr10_3e-3_1e-4); re-executing
# this cell prints the scores of whatever is currently in rce_output / ap_output.
print(model_name)
display_score(rce_output, ap_output)

two_opt_lr3_load_len48_joint_thr10_3e-3_1e-4
Quantile Group|AP Retweet|RCE Retweet|  AP Reply|  RCE Reply|   AP Like|   RCE Like|AP RT comment|RCE RT comment
        0          0.3640     19.2808     0.1825     17.8791     0.5927      7.8090     0.0359      8.4515
        1          0.3496     19.1279     0.1884     18.6515     0.5710      6.2476     0.0325      8.5330
        2          0.3451     19.1840     0.2076     19.7104     0.5618      6.4109     0.0317      8.5464
        3          0.3604     19.6577     0.2296     21.2601     0.5752      7.9033     0.0349      9.3634
        4          0.3413     19.6354     0.1397     15.7311     0.6472     11.7349     0.0351     10.2205
     Average       0.3521     19.3772     0.1896     18.6465     0.5896      8.0211     0.0340      9.0230


In [21]:
# Sum of the four "Average" RCE values from the Epoch 32 table above
# (retweet + reply + like + RT comment).
19.3772     +     18.6465     +     8.0211     +    9.0230

55.06779999999999

In [None]:
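# Reference leaderboard row, kept as a comment so that this cell does not raise a SyntaxError when run:
# ChrisDeotte	version_32	0.3384	18.6481	0.1857	18.5767	0.6244	13.2391	0.0339	9.1282	3 hours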

In [20]:
# Sum of the 2nd, 4th, 6th and 8th numeric values of the reference
# leaderboard row above (presumably its four RCE scores -- confirm).
18.6481	+	18.5767	+	13.2391	+	9.1282

59.5921

In [19]:
# public test
# NOTE: the saved output below comes from an earlier run
# (model_name = MF_len48_joint_thr25_3weeks); re-executing this cell prints
# the scores of whatever is currently in rce_output / ap_output.
print(model_name)
display_score(rce_output, ap_output)

MF_len48_joint_thr25_3weeks
Quantile Group|AP Retweet|RCE Retweet|  AP Reply|  RCE Reply|   AP Like|   RCE Like|AP RT comment|RCE RT comment
        0          0.3648     19.1185     0.1768     17.4406     0.5983      8.8159     0.0343      8.4536
        1          0.3457     18.4016     0.1811     17.8598     0.5753      6.9555     0.0309      8.3305
        2          0.3388     17.9507     0.2000     18.7393     0.5646      6.6237     0.0306      8.0775
        3          0.3504     17.6180     0.2199     19.9633     0.5772      7.1854     0.0310      8.4171
        4          0.3247     16.5028     0.1267     14.6986     0.6501     10.9945     0.0286      8.7294
     Average       0.3449     17.9183     0.1809     17.7403     0.5931      8.1150     0.0311      8.4016
