In [None]:
# Copyright 2021 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [1]:
import os, time
# Restrict this process to one GPU (device id 2). Both variables are set here,
# before the GPU-backed libraries below (cudf, cupy) are imported.
os.environ["CUDA_VISIBLE_DEVICES"]='2'
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 is normally a debugging aid (synchronous
# kernel launches) and is expected to slow training — confirm it is intentional.
os.environ["CUDA_LAUNCH_BLOCKING"]='1'

import glob
import pandas as pd
import numpy as np
import cudf
import cupy
import gc
from datetime import datetime

from util import compute_rce_fast

# True when more than one device id is listed in CUDA_VISIBLE_DEVICES
# (presumably "data parallel" — the flag is not read again in the visible cells).
DP = len(os.environ["CUDA_VISIBLE_DEVICES"].split(','))>1
DP

False

In [2]:
import cupy as cp
import glob
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

import nvtabular as nvt
from nvtabular.loader.torch import TorchAsyncItr, DLDataLoader
from nvtabular.framework_utils.torch.models import Model
from nvtabular.framework_utils.torch.utils import process_epoch

import torch
from torch import nn
torch.__version__

'1.7.1+cu101'

## model

In [3]:
class ConcatenatedEmbeddings(torch.nn.Module):
    """Embed several categorical columns and concatenate the results.

    One embedding table is created per entry of ``embedding_table_shapes``.
    The first table is applied to the first TWO input columns (they share a
    vocabulary); every following table is applied to the next column.

    Args:
        embedding_table_shapes: A dictionary mapping table names to
            (cardinality, embedding_size) tuples.
        dropout: A float, dropout probability applied to the concatenation.
    Inputs:
        x: An integer Tensor with shape [batch_size, num_tables + 1] (extra
            trailing columns are ignored). A 1-D tensor is treated as one row.
    Outputs:
        A Float Tensor with shape
        [batch_size, first_embedding_size + sum(all embedding sizes)].
    """

    def __init__(self, embedding_table_shapes, dropout=0.0):
        super().__init__()
        tables = []
        for cardinality, width in embedding_table_shapes.values():
            # Tables with more than 1e5 rows produce sparse gradients.
            use_sparse = cardinality > 1e5
            tables.append(torch.nn.Embedding(cardinality, width, sparse=use_sparse))
        self.embedding_layers = torch.nn.ModuleList(tables)
        self.dropout = torch.nn.Dropout(p=dropout)

    def forward(self, x):
        if x.dim() == 1:
            x = x.unsqueeze(0)

        # Column 0 goes through table 0; after that, table k handles column
        # k + 1, so table 0 is used again for column 1.
        shared_table = self.embedding_layers[0]
        pieces = [shared_table(x[:, 0])]
        for column, table in enumerate(self.embedding_layers, start=1):
            pieces.append(table(x[:, column]))

        return self.dropout(torch.cat(pieces, dim=1))

In [4]:
import torch.nn as nn

# Shared sigmoid module used by the Swish forward and backward passes.
sigmoid = nn.Sigmoid()

class Swish(torch.autograd.Function):
    """Swish activation ``x * sigmoid(x)`` with a hand-written backward pass.

    Only the input tensor is saved for backward; the sigmoid is recomputed
    there rather than stored.
    """
    @staticmethod
    def forward(ctx, i):
        result = i * sigmoid(i)
        ctx.save_for_backward(i)
        return result
    @staticmethod
    def backward(ctx, grad_output):
        # `saved_tensors` is the supported accessor; `saved_variables` is deprecated.
        (i,) = ctx.saved_tensors
        sigmoid_i = sigmoid(i)
        # d/dx [x * s(x)] = s(x) * (1 + x * (1 - s(x)))
        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))

class Swish_Module(nn.Module):
    """``nn.Module`` wrapper so ``Swish`` can be placed inside ``nn.Sequential``."""
    def forward(self, x):
        return Swish.apply(x)


from transformers import AutoTokenizer, AutoModel

# Hugging Face checkpoint name. Net loads this checkpoint and keeps only its
# word-embedding matrix.
bert_type = 'distilbert-base-multilingual-cased'

# NOTE(review): no active code in the visible part of this notebook calls
# `tokenizer` (the dataset consumes pre-tokenized ids) — confirm it is still needed.
tokenizer = AutoTokenizer.from_pretrained(bert_type)

class Net(nn.Module):
    """MLP over continuous features, categorical embeddings and a text GRU.

    The forward pass concatenates, in this order: the continuous features, the
    last GRU output over pretrained token embeddings, the concatenated
    categorical embeddings (the shared first table is looked up for the first
    two columns), and the element-wise product of those two shared-table
    embeddings. The result goes through Dropout/Linear/BatchNorm/Swish stacks
    and a final Linear producing 4 logits.

    Args:
        num_features: number of continuous input features.
        layers: hidden layer width, or a sequence of widths.
        embedding_table_shapes: dict mapping table names to
            (cardinality, embedding_size); the FIRST entry is the table shared
            by the first two categorical columns.
        dropout: dropout probability for embeddings and every hidden block.
        bert_type: name/path of the pretrained model whose word embeddings
            are reused. Required.
        gru_dim: hidden size of the GRU.
        emb_dim: expected width of the pretrained word embeddings.

    Raises:
        ValueError: if ``bert_type`` is None or ``embedding_table_shapes`` is empty.
    """
    def __init__(self, num_features, layers, embedding_table_shapes, dropout=0.2, bert_type=None, gru_dim=128, emb_dim=768):
        super(Net, self).__init__()
        if bert_type is None:
            raise ValueError("bert_type must name a pretrained model; its word embeddings are always loaded")
        if not embedding_table_shapes:
            raise ValueError("embedding_table_shapes must contain at least the shared first table")

        self.dropout = dropout
        self.initial_cat_layer = ConcatenatedEmbeddings(embedding_table_shapes, dropout=dropout)
        embedding_size = sum(emb_size for _, emb_size in embedding_table_shapes.values())
        # The shared first table contributes two extra blocks to the MLP input:
        # its second lookup inside ConcatenatedEmbeddings and the `mf` product
        # in forward(). Derive that width instead of hard-coding 128 + 128.
        shared_emb_size = next(iter(embedding_table_shapes.values()))[1]
        hidden = [layers] if isinstance(layers, int) else list(layers)
        layers = [num_features + gru_dim + embedding_size + 2 * shared_emb_size] + hidden

        self.use_bert = True
        self.embed = AutoModel.from_pretrained(bert_type).embeddings.word_embeddings
        assert emb_dim == self.embed.embedding_dim, (
            f"emb_dim={emb_dim} does not match pretrained embedding width {self.embed.embedding_dim}"
        )
        # Attribute is called `lstm` but holds a GRU; the name is kept so
        # existing state_dict keys keep loading.
        self.lstm = nn.GRU(emb_dim, gru_dim, batch_first=True, bidirectional=False)

        self.fn_layers = nn.ModuleList(
                            nn.Sequential(
                                nn.Dropout(p=dropout),
                                nn.Linear(layers[i], layers[i+1]),
                                nn.BatchNorm1d(layers[i+1]),
                                Swish_Module(),
                            )  for i in range(len(layers) -1)
                         )
        self.fn_last = nn.Linear(layers[-1],4)

    def forward(self, x_cat, x_cont, bert_tok):
        """Return logits of shape [batch, 4].

        Args:
            x_cat: integer tensor of categorical ids; columns 0 and 1 index the shared table.
            x_cont: float tensor of continuous features.
            bert_tok: integer tensor of token ids, [batch, seq_len].
        """
        # Interaction term: element-wise product of the two shared-table
        # embeddings (taken before dropout is applied).
        user_table = self.initial_cat_layer.embedding_layers[0]
        a_emb = user_table(x_cat[:,0])
        b_emb = user_table(x_cat[:,1])
        mf = a_emb * b_emb

        x_cat = self.initial_cat_layer(x_cat)
        bert_tok = self.embed(bert_tok)
        # Keep the GRU output at the final time step only.
        # NOTE(review): AllDataset right-pads short sequences with 0, so for
        # those rows this step is a padding position — confirm this is intended.
        lstm_out = self.lstm(bert_tok)[0][:,-1]
        output = torch.cat([x_cont, lstm_out, x_cat, mf],dim=1)
        for layer in self.fn_layers:
            output = layer(output)
        logit = self.fn_last(output)
        return logit

## scheduler

In [5]:
from warmup_scheduler import GradualWarmupScheduler
import torch.optim as optim
from torch.optim import lr_scheduler
import time
import warnings; warnings.simplefilter('ignore')

def set_seed(seed=0):
    """Seed Python's ``random``, NumPy and PyTorch, and request deterministic cuDNN.

    Args:
        seed: integer seed applied to every generator (CPU and all CUDA devices).
    """
    # `random` is not imported anywhere at notebook level, so the original
    # raised NameError on the first line; import it locally.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    
class GradualWarmupSchedulerV2(GradualWarmupScheduler):
    """``GradualWarmupScheduler`` with its own ``get_lr``.

    While ``last_epoch <= total_epoch`` the learning rate ramps linearly: from
    0 to ``base_lr`` when ``multiplier == 1.0``, otherwise from ``base_lr`` to
    ``base_lr * multiplier``. Afterwards it defers to ``after_scheduler`` if
    one was given, else stays at ``base_lr * multiplier``.
    """
    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        super().__init__(optimizer, multiplier, total_epoch, after_scheduler)

    def get_lr(self):
        # --- warm-up phase -------------------------------------------------
        if self.last_epoch <= self.total_epoch:
            if self.multiplier == 1.0:
                return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
            return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]

        # --- after warm-up -------------------------------------------------
        if not self.after_scheduler:
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

        if not self.finished:
            # First call after warm-up: start the follow-up scheduler from
            # the warmed-up rate.
            self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
            self.finished = True
        return self.after_scheduler.get_lr()

## train loop

In [6]:
# Loss shared by train_epoch and valid_epoch: binary cross-entropy applied to
# raw logits (the sigmoid is part of the loss).
criterion = nn.BCEWithLogitsLoss()

def train_epoch(model, loader, optimizer, scaler, optimizer2):
    """Train ``model`` for one pass over ``loader`` and return the mean batch loss.

    Reads the module-level ``use_torch_amp``, ``use_amp``, ``amp``, ``criterion``
    and ``tqdm``. Exactly one of three update paths runs per batch:
    ``torch.cuda.amp`` (``use_torch_amp``), apex-style ``amp.scale_loss``
    (``use_amp``), or plain full precision. Both optimizers are stepped on
    every path.

    Args:
        model: module called as ``model(x_cat, x_cont, text_tok)``.
        loader: iterable of ``(x_cat, x_cont, text_tok, targets)`` batches.
        optimizer: first optimizer.
        scaler: gradient scaler; only used when ``use_torch_amp`` is true.
        optimizer2: second optimizer, owning the remaining parameters.

    Returns:
        ``np.mean`` of the per-batch loss values.
    """
    model.train()
    train_loss = []
    bar = tqdm(loader)
    for x_cat, x_cont, text_tok, targets in bar:
        x_cat = x_cat.cuda()
        x_cont = x_cont.cuda()
        text_tok = text_tok.cuda()
        targets = targets.cuda()

        optimizer.zero_grad()
        optimizer2.zero_grad()

        if use_torch_amp:
            with amp.autocast():
                logits = model(x_cat, x_cont, text_tok)
            loss = criterion(logits, targets)

            scaler.scale(loss).backward()

            # Explicit unscale so the gradients owned by each optimizer are
            # in true units before its step.
            scaler.unscale_(optimizer)
            scaler.unscale_(optimizer2)

            scaler.step(optimizer)
            scaler.step(optimizer2)

            scaler.update()

        elif use_amp:
            # This path expects `amp` to be apex's amp module (scale_loss API).
            logits = model(x_cat, x_cont, text_tok)
            loss = criterion(logits, targets)
            # Scale for BOTH optimizers: the backward pass creates gradients
            # for parameters owned by each of them.
            with amp.scale_loss(loss, [optimizer, optimizer2]) as scaled_loss:
                scaled_loss.backward()
            optimizer.step()
            # Previously missing: without it optimizer2's parameters never updated.
            optimizer2.step()
        else:
            logits = model(x_cat, x_cont, text_tok)
            loss = criterion(logits, targets)
            loss.backward()
            optimizer.step()
            # Previously missing: without it optimizer2's parameters never updated.
            optimizer2.step()

        loss_np = loss.item()
        train_loss.append(loss_np)
        # Running mean over (at most) the last 50 batches, for the progress bar.
        smooth_loss = sum(train_loss[-50:]) / min(len(train_loss), 50)
        bar.set_description('loss: %.4f, smth: %.4f' % (loss_np, smooth_loss))

    return np.mean(train_loss)

def valid_epoch(model, loader):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: module called as ``model(x_cat, x_cont, text_tok)``.
        loader: iterable of ``(x_cat, x_cont, text_tok, targets)`` batches.

    Returns:
        A tuple ``(val_loss, rce, mean_rce)``: the mean batch loss, a dict
        mapping each of the first four ``label_names`` to the value returned by
        ``compute_rce_fast`` for that output column, and the mean of those values.
    """
    model.eval()
    batch_losses, all_logits, all_targets = [], [], []

    with torch.no_grad():
        for x_cat, x_cont, text_tok, targets in tqdm(loader):
            x_cat, x_cont = x_cat.cuda(), x_cont.cuda()
            text_tok, targets = text_tok.cuda(), targets.cuda()

            logits = model(x_cat, x_cont, text_tok)
            batch_losses.append(criterion(logits, targets).item())
            # Accumulate on the CPU.
            all_logits.append(logits.cpu())
            all_targets.append(targets.cpu())

    all_logits = torch.cat(all_logits)
    all_targets = torch.cat(all_targets)

    # One score per output column, computed on sigmoid probabilities.
    rce = {
        label_names[i]: compute_rce_fast(
            cp.asarray(all_logits[:, i].sigmoid()),
            cp.asarray(all_targets[:, i]),
        ).get()
        for i in range(4)
    }
    mean_rce = np.mean(list(rce.values()))

    return np.mean(batch_losses), rce, mean_rce

## Data loading (NVT loader)

In [7]:
# Target columns. sorted() puts them in alphabetical order
# (like, reply, retweet, retweet_comment); AllDataset builds its label matrix
# with this list and valid_epoch names the output columns with it.
label_names = sorted(['reply', 'retweet', 'retweet_comment', 'like'])
# Categorical inputs. a_user_id and b_user_id must stay first: the model looks
# both of them up in one shared embedding table.
CAT_COLUMNS = ['a_user_id','b_user_id','language','media','tweet_type']
# Continuous inputs. The *_rate, a_age/b_age and ab_age_* entries are derived
# inside read_norm_merge; the remaining columns are read from the parquet parts.
NUMERIC_COLUMNS = ['a_follower_count',
                     'a_following_count',
                     'a_is_verified',
                     'b_follower_count',
                     'b_following_count',
                     'b_is_verified',
                     'b_follows_a',
                     'tw_len_media',
                     'tw_len_photo',
                     'tw_len_video',
                     'tw_len_gif',
                     'tw_len_quest',
                     'tw_len_token',
                     'tw_count_capital_words',
                     'tw_count_excl_quest_marks',
                     'tw_count_special1',
                     'tw_count_hash',
                     'tw_last_quest',
                     'tw_len_retweet',
                     'tw_len_rt',
                     'tw_count_at',
                     'tw_count_words',
                     'tw_count_char',
                     'tw_rt_count_words',
                     'tw_rt_count_char',
                     'len_hashtags',
                     'len_links',
                     'len_domains',
                     'a_ff_rate',
                     'b_ff_rate',
                     'ab_fing_rate',
                     'ab_fer_rate',
                     'a_age',
                     'b_age',
                     'ab_age_dff',
                     'ab_age_rate']
len(NUMERIC_COLUMNS)

36

In [8]:
def read_norm_merge(path, split='train',
                    categories_dir='/raid/recsys_pre_TE_w_tok/workflow_232parts_joint_thr10/categories',
                    valid_start='2021-02-18'):
    """Load one parquet part, build and normalise features, map categoricals to ids.

    Steps: add a follower-count bucket column, filter rows by date according to
    ``split``, derive ratio and age features, normalise every column in the
    global ``NUMERIC_COLUMNS``, replace each column in the global
    ``CAT_COLUMNS`` by its row index in the matching ``unique.<name>.parquet``
    lookup table (unmapped values become 0), and drop helper/raw columns.

    Args:
        path: parquet file to read.
        split: 'train' keeps rows dated before ``valid_start``; 'valid' keeps
            rows on or after it; any other value keeps every row.
        categories_dir: directory containing the ``unique.<name>.parquet``
            lookup tables.
        valid_start: first date of the validation period (anything
            ``pd.to_datetime`` accepts).

    Returns:
        The processed DataFrame.
    """
    ddf = pd.read_parquet(path)

    # Bucket a_follower_count: the number of thresholds it exceeds (0..4).
    ddf['quantile'] = 0
    quantiles = [92, 216, 442, 1064]
    for quant in quantiles:
        ddf['quantile'] = (ddf['quantile']+(ddf['a_follower_count']>quant).astype('int8')).astype('int8')

    ddf['date'] = pd.to_datetime(ddf['timestamp'], unit='s')

    valid_start_ts = pd.to_datetime(valid_start)
    if split=='train':
        ddf = ddf[ddf['date']<valid_start_ts].reset_index(drop=True)
    elif split=='valid':
        ddf = ddf[ddf['date']>=valid_start_ts].reset_index(drop=True)
    # Any other split value keeps all rows.

    # NOTE(review): the first three ratios divide by raw counts, so a zero
    # denominator yields inf/NaN — confirm the inputs never contain zeros there.
    ddf['a_ff_rate'] = (ddf['a_following_count'] / ddf['a_follower_count']).astype('float32')
    ddf['b_ff_rate'] = (ddf['b_follower_count']  / ddf['b_following_count']).astype('float32')
    ddf['ab_fing_rate'] = (ddf['a_following_count'] / ddf['b_following_count']).astype('float32')
    ddf['ab_fer_rate'] = (ddf['a_follower_count'] / (1+ddf['b_follower_count'])).astype('float32')
    ddf['a_age'] = ddf['a_account_creation'].astype('int16') + 128
    ddf['b_age'] = ddf['b_account_creation'].astype('int16') + 128
    ddf['ab_age_dff'] = ddf['b_age'] - ddf['a_age']
    ddf['ab_age_rate'] = ddf['a_age']/(1+ddf['b_age'])

    ## Normalize
    for col in NUMERIC_COLUMNS:
        if col == 'tw_len_quest':
            ddf[col] = np.clip(ddf[col].values,0,None)
        if ddf[col].dtype == 'uint16':
            # The original called astype() without assigning the result, so
            # the widening never happened.
            ddf[col] = ddf[col].astype('int32')

        if col == 'ab_age_dff':
            # Signed difference: scale linearly instead of taking a log.
            ddf[col] = ddf[col] / 256.
        elif 'int' in str(ddf[col].dtype) or 'float' in str(ddf[col].dtype):
            ddf[col] = np.log1p(ddf[col])

        if ddf[col].dtype == 'float64':
            ddf[col] = ddf[col].astype('float32')

    ## get categorical embedding id
    # a_user_id and b_user_id share one lookup table; cache tables per call so
    # that (large) file is read once instead of twice.
    lookup_cache = {}
    for col in CAT_COLUMNS:
        ddf[col] = ddf[col].astype('float')
        if col in ['a_user_id','b_user_id']:
            mapping_col = 'a_user_id_b_user_id'
        else:
            mapping_col = col
        if mapping_col not in lookup_cache:
            lookup_cache[mapping_col] = pd.read_parquet(f'{categories_dir}/unique.{mapping_col}.parquet').reset_index()
        # Work on a copy so renaming the columns never touches the cached table.
        mapping = lookup_cache[mapping_col].copy()
        mapping.columns = ['index',col]
        # Left merge keeps every row; the table's row index becomes the id.
        ddf = ddf.merge(mapping, how='left', on=col).drop(columns=[col]).rename(columns={'index':col})
        ddf[col] = ddf[col].fillna(0).astype('int')
    del lookup_cache

    DONT_USE = ['timestamp','a_account_creation','b_account_creation','engage_time',
                'fold', 'dt_dow', 'a_account_creation',
                'b_account_creation', 'elapsed_time', 'links','domains','hashtags','id', 'date', 'is_train',
                'tw_hash0', 'tw_hash1', 'tw_hash2', 'tw_http0', 'tw_uhash', 'tw_hash', 'tw_word0',
                'tw_word1', 'tw_word2', 'tw_word3', 'tw_word4', 'dt_minute', 'dt_second',
               'dt_day', 'group', 'text', 'tweet_id', 'tw_original_user0', 'tw_original_user1', 'tw_original_user2',
                'tw_rt_user0', 'tw_original_http0', 'tw_tweet',]
    # Only drop the columns that are actually present in this part.
    DONT_USE = [c for c in ddf.columns if c in DONT_USE]
    gc.collect(); gc.collect()

    return ddf.drop(columns=DONT_USE)

In [9]:
# All parquet parts under the training directory, sorted so that a numeric
# index (see train_parts_order) always refers to the same file.
PATHS = sorted(glob.glob('/raid/recsys/train_proc3/*.parquet'))
len(PATHS)

232

In [10]:
# for col in NUMERIC_COLUMNS:
#     print(col)
#     plt.hist(train[col].values, bins=50)
#     plt.title(col)
# #     print(ddf[col].describe())
#     plt.show()

In [11]:
import torch
from torch.utils.data import Dataset,DataLoader

class AllDataset(Dataset):
    """Map-style dataset over a preprocessed DataFrame.

    Each item is ``(categorical ids, float32 continuous features,
    token-id tensor of length max_len_txt, float32 labels)``. The label
    columns are taken from the module-level ``label_names``.
    """

    def __init__(self, df, max_len_txt, NUMERIC_COLUMNS, CAT_COLUMNS):
        self.max_len_txt = max_len_txt
        # Tab-separated token-id strings, one per row.
        self.text_tokens = df.text_tokens.values
        self.labels = df[label_names].values
        self.X_cat = df[CAT_COLUMNS].values
        self.X = df[NUMERIC_COLUMNS].values

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        limit = self.max_len_txt
        # Parse every id, then truncate to `limit` tokens.
        token_ids = list(map(int, self.text_tokens[index].split('\t')))[:limit]
        # Right-pad with 0 up to `limit` (a no-op when already long enough).
        token_ids.extend([0] * (limit - len(token_ids)))
        return (
            self.X_cat[index],
            self.X[index].astype(np.float32),
            torch.tensor(token_ids),
            self.labels[index].astype(np.float32),
        )

In [13]:
# NOTE(review): gru_dim and emb_dim are not passed to the Net(...) call in this
# notebook; that call relies on Net's defaults, which happen to be the same values.
gru_dim=128
# Number of token ids kept per row by AllDataset (longer rows are truncated,
# shorter ones padded with 0).
max_len_txt=48
emb_dim=768
# Base learning rate of the first optimizer; the second one uses lr/30.
lr = 3e-3
# Number of epochs; the training loop consumes 5 parts of train_parts_order per epoch.
ep = 46   
BATCH_SIZE = 1024
num_workers = 16
# Selects the torch.cuda.amp (autocast + GradScaler) path in train_epoch.
use_torch_amp = True
import torch.cuda.amp as amp
# Selects the amp.scale_loss path in train_epoch; only consulted when use_torch_amp is False.
use_amp = False

# Prefix of the checkpoint file names written by the training loop.
model_name = 'two_opt_lr3_load_len48_joint_thr10_3e-3_1e-4'

In [14]:
# Number of continuous features; this is the num_features argument given to Net.
len(NUMERIC_COLUMNS)

36

In [15]:
%%time
# Validation set: the 'valid' date slice of the first 10 parts. The training
# loop reads those same parts with split='train', so their rows do not overlap.
# (train_lst is just a scratch list here, despite its name.)
train_lst = []
for path in PATHS[:10]:
    train_lst.append(read_norm_merge(path, 'valid'))
valid = pd.concat(train_lst)
gc.collect()

# Fixed order (shuffle=False); every batch is kept.
valid_dataset = AllDataset(valid, max_len_txt, NUMERIC_COLUMNS, CAT_COLUMNS)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers) 
valid.shape, len(valid_loader)

CPU times: user 7min 41s, sys: 1min 3s, total: 8min 45s
Wall time: 8min 30s


((10324907, 47), 10083)

In [16]:
# Fixed visiting order of the 232 training parts (indices into PATHS), hard-coded
# so that every run sees the parts in the same sequence. The call that generates
# such a permutation is kept for reference:
# train_parts_order = np.concatenate([np.random.permutation(232)])
train_parts_order = np.array([ 46, 111, 208, 230,   3,  22, 227, 153,  78,  52,  20, 185,   6,
        130, 177,  83,  97, 194,  24, 187,  93,  59, 217, 180, 129,  62,
          1,  43, 229, 102, 196,  50,   4,  12, 114,  70,  18,  91,  71,
        190, 174,  23,  63,  89, 188,  16, 104,  67,  39, 225, 176,  28,
        198,   2,  76, 166, 216, 116, 199, 113, 107, 201,  64, 115,   8,
        171,  44, 218, 158, 181,  79,  47, 155, 159, 164, 109,  56, 106,
        122, 203, 144,  14, 163, 124, 110, 126,  80,  77,  94, 135,  33,
        134, 224, 145, 172, 191,  60, 148, 215, 212, 219,  35, 167,  37,
        132, 182, 228,  75,  87, 156, 137,  74,  29,  95, 118,  90, 222,
         19,  57, 162, 105, 223, 210, 140,  10,  72, 152, 183, 170,  51,
         82, 117,  13, 211, 120,  81, 160,  27, 200, 128, 169, 213, 179,
         42,  11, 143,  15, 209, 151,  48, 207, 112, 119, 231, 175,   0,
        146, 154,  68, 197,  21, 206, 125, 192,  31,  86, 138,  36, 108,
        103,  58, 142,  54,  98,  99, 127, 214,   7,  92, 121, 202, 141,
        150,  88,  53,  38, 139, 147, 131,  66,  40,  26, 123,  73, 100,
        165, 186, 149, 205,   5, 189,  25,  32, 133, 101, 204, 178, 193,
        136,  84, 161,  30, 221,  65,  85,  41,  17,  61,  45, 173, 195,
          9, 184,  55,  49, 168,  69,  34,  96, 157, 226, 220])
train_parts_order, train_parts_order.shape

(array([ 46, 111, 208, 230,   3,  22, 227, 153,  78,  52,  20, 185,   6,
        130, 177,  83,  97, 194,  24, 187,  93,  59, 217, 180, 129,  62,
          1,  43, 229, 102, 196,  50,   4,  12, 114,  70,  18,  91,  71,
        190, 174,  23,  63,  89, 188,  16, 104,  67,  39, 225, 176,  28,
        198,   2,  76, 166, 216, 116, 199, 113, 107, 201,  64, 115,   8,
        171,  44, 218, 158, 181,  79,  47, 155, 159, 164, 109,  56, 106,
        122, 203, 144,  14, 163, 124, 110, 126,  80,  77,  94, 135,  33,
        134, 224, 145, 172, 191,  60, 148, 215, 212, 219,  35, 167,  37,
        132, 182, 228,  75,  87, 156, 137,  74,  29,  95, 118,  90, 222,
         19,  57, 162, 105, 223, 210, 140,  10,  72, 152, 183, 170,  51,
         82, 117,  13, 211, 120,  81, 160,  27, 200, 128, 169, 213, 179,
         42,  11, 143,  15, 209, 151,  48, 207, 112, 119, 231, 175,   0,
        146, 154,  68, 197,  21, 206, 125, 192,  31,  86, 138,  36, 108,
        103,  58, 142,  54,  98,  99, 127, 214,   7

In [17]:
# Build the network on the GPU. The first embedding table is the one shared by
# a_user_id and b_user_id; the (cardinality, embedding_size) pairs are hard-coded.
# NOTE(review): the cardinalities presumably match the unique.*.parquet lookup
# tables used in read_norm_merge — confirm when those tables change.
model = Net(len(NUMERIC_COLUMNS), layers=[1024,256,64], 
            embedding_table_shapes={'a_user_id_b_user_id': (15453524, 128), 'language': (67, 16), 'media': (15, 16), 'tweet_type': (4, 16)},
            bert_type=bert_type).cuda()

# Freeze the pretrained word embeddings: they receive no gradient updates.
for param in model.embed.parameters():
    param.requires_grad = False

model    

Net(
  (initial_cat_layer): ConcatenatedEmbeddings(
    (embedding_layers): ModuleList(
      (0): Embedding(15453524, 128, sparse=True)
      (1): Embedding(67, 16)
      (2): Embedding(15, 16)
      (3): Embedding(4, 16)
    )
    (dropout): Dropout(p=0.2, inplace=False)
  )
  (embed): Embedding(119547, 768, padding_idx=0)
  (lstm): GRU(768, 128, batch_first=True)
  (fn_layers): ModuleList(
    (0): Sequential(
      (0): Dropout(p=0.2, inplace=False)
      (1): Linear(in_features=596, out_features=1024, bias=True)
      (2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): Swish_Module()
    )
    (1): Sequential(
      (0): Dropout(p=0.2, inplace=False)
      (1): Linear(in_features=1024, out_features=256, bias=True)
      (2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): Swish_Module()
    )
    (2): Sequential(
      (0): Dropout(p=0.2, inplace=False)
      (1): Linear(in_features=256, out_

In [18]:
# Warm start from an earlier checkpoint.
sd = torch.load(f'../models/MF_len48_joint_thr25_3weeks_best.pth')
# Strip a leading 'module.' from parameter names (presumably left by a
# DataParallel-wrapped model — confirm).
sd = {k[7:] if k.startswith('module.') else k: sd[k] for k in sd.keys()}
# Drop the shared user embedding table so it is NOT loaded and keeps its fresh
# initialisation (presumably the checkpoint's table has a different size — confirm).
del sd['initial_cat_layer.embedding_layers.0.weight']
# strict=False tolerates the key removed above; it is reported as missing.
model.load_state_dict(sd, strict=False)

_IncompatibleKeys(missing_keys=['initial_cat_layer.embedding_layers.0.weight'], unexpected_keys=[])

In [19]:
# Two optimizers: SparseAdam for the first registered parameter (the shared user
# embedding table, the only one built with sparse=True) and AdamW at lr/30 for
# everything else.
# NOTE(review): the [:1] / [1:] split depends on parameter registration order.
optimizer = optim.SparseAdam(list(model.parameters())[:1], lr=lr)
optimizer2 = optim.AdamW(list(model.parameters())[1:], lr=lr/30)
scaler = amp.GradScaler() if use_torch_amp else None

# Only `optimizer` is scheduled: warm-up from lr to 10*lr, then cosine annealing
# over ep-1 steps. optimizer2 has no scheduler, so its rate stays at lr/30.
scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, ep-1)
scheduler_warmup = GradualWarmupSchedulerV2(optimizer, multiplier=10, total_epoch=1, after_scheduler=scheduler_cosine)

# Best validation mean RCE seen so far; the training loop saves a "best"
# checkpoint whenever this is exceeded.
rce_best = 0

## start training

In [20]:
# Record the run name (also the checkpoint file prefix) in the notebook output.
print(model_name)

two_opt_lr3_load_len48_joint_thr10_3e-3_1e-4


In [21]:
# Main training loop: each epoch loads 5 fresh parts, trains on them, validates
# on the fixed validation loader and writes checkpoints.
for epoch in range(1, ep+1):
    print(time.ctime(), 'Epoch:', epoch)
    # Set this epoch's learning rate (0-based index). Only `optimizer` is
    # attached to the scheduler.
    scheduler_warmup.step(epoch-1) 

    # 5 parts per epoch
    idx_this_ep = train_parts_order[(epoch*5-5):epoch*5]

    train_lst = []
    for idx in tqdm(idx_this_ep):
        # Parts 0-9 also feed the validation set, so only their 'train' date
        # slice is used; any other part is read in full ('both' hits the
        # keep-everything branch of read_norm_merge).
        train_lst.append(read_norm_merge(PATHS[idx], 'train' if idx<10 else 'both'))
    train = pd.concat(train_lst)

    gc.collect();gc.collect();

    train_dataset = AllDataset(train, max_len_txt, NUMERIC_COLUMNS, CAT_COLUMNS)
    # drop_last=True: the final incomplete batch is skipped.
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers, drop_last=True) 

    train_loss = train_epoch(model, train_loader, optimizer, scaler, optimizer2)
    valid_loss,rce,mean_rce = valid_epoch(model, valid_loader)

    # One log line per epoch; the lr shown is that of `optimizer` only.
    content = time.ctime() + ' ' + f'Epoch {epoch}, lr: {optimizer.param_groups[0]["lr"]:.7f}, train loss: {train_loss:.4f}, valid loss: {valid_loss:.4f}, mean_rce: {mean_rce:.2f}'
    for col in ['retweet', 'reply',  'like', 'retweet_comment']:
        content += f', {col}: {rce[col]:.2f}'

    print(content)

    # Keep the weights with the highest validation mean RCE.
    if mean_rce > rce_best:
        print('rce_best increased ({:.6f} --> {:.6f}).  Saving model ...'.format(rce_best, mean_rce))
        rce_best = mean_rce

        torch.save(model.state_dict(), f'../models/{model_name}_best.pth')

    # Checkpoint overwritten every epoch: model, both optimizers, scaler.
    # NOTE(review): the scheduler state is not stored here.
    torch.save(
        {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scaler_state_dict': scaler.state_dict() if scaler else None,
            'optimizer_state_dict2': optimizer2.state_dict(),
            'rce_best': rce_best,
        },
        f'../models/{model_name}_last.pth'
    )            

# Weights after the last epoch, whether or not they are the best ones.
torch.save(model.state_dict(), f'../models/{model_name}_final.pth')

  0%|          | 0/5 [00:00<?, ?it/s]

Wed Jun  9 10:19:56 2021 Epoch: 1


100%|██████████| 5/5 [05:11<00:00, 62.23s/it]
loss: 0.2592, smth: 0.2507: 100%|██████████| 12298/12298 [08:48<00:00, 23.27it/s]
100%|██████████| 10083/10083 [02:04<00:00, 81.31it/s]


Wed Jun  9 10:36:06 2021 Epoch 1, lr: 0.0030000, train loss: 0.2578, valid loss: 0.2497, mean_rce: 10.36, retweet: 12.19, reply: 13.88, like: 12.39, retweet_comment: 2.98
rce_best increased (0.000000 --> 10.359921).  Saving model ...


  0%|          | 0/5 [00:00<?, ?it/s]

Wed Jun  9 10:40:18 2021 Epoch: 2


100%|██████████| 5/5 [05:31<00:00, 66.32s/it]
loss: 0.2257, smth: 0.2325: 100%|██████████| 13575/13575 [09:47<00:00, 23.09it/s]
100%|██████████| 10083/10083 [02:12<00:00, 76.02it/s]


Wed Jun  9 10:58:01 2021 Epoch 2, lr: 0.0300000, train loss: 0.2401, valid loss: 0.2310, mean_rce: 15.23, retweet: 19.71, reply: 16.13, like: 19.65, retweet_comment: 5.44
rce_best increased (10.359921 --> 15.233335).  Saving model ...
Wed Jun  9 11:02:16 2021 Epoch: 3


100%|██████████| 5/5 [05:38<00:00, 67.74s/it]
loss: 0.2099, smth: 0.2267: 100%|██████████| 13671/13671 [09:56<00:00, 22.92it/s]
100%|██████████| 10083/10083 [02:08<00:00, 78.75it/s]


Wed Jun  9 11:20:10 2021 Epoch 3, lr: 0.0300000, train loss: 0.2290, valid loss: 0.2246, mean_rce: 17.30, retweet: 22.51, reply: 17.88, like: 21.81, retweet_comment: 7.02
rce_best increased (15.233335 --> 17.301279).  Saving model ...


  0%|          | 0/5 [00:00<?, ?it/s]

Wed Jun  9 11:24:31 2021 Epoch: 4


100%|██████████| 5/5 [05:48<00:00, 69.76s/it]
loss: 0.2273, smth: 0.2226: 100%|██████████| 14636/14636 [10:37<00:00, 22.96it/s]
100%|██████████| 10083/10083 [02:09<00:00, 77.93it/s]


Wed Jun  9 11:43:14 2021 Epoch 4, lr: 0.0298540, train loss: 0.2236, valid loss: 0.2208, mean_rce: 18.76, retweet: 24.18, reply: 19.38, like: 22.93, retweet_comment: 8.55
rce_best increased (17.301279 --> 18.760059).  Saving model ...


  0%|          | 0/5 [00:00<?, ?it/s]

Wed Jun  9 11:47:33 2021 Epoch: 5


100%|██████████| 5/5 [05:43<00:00, 68.74s/it]
loss: 0.2280, smth: 0.2174: 100%|██████████| 14331/14331 [10:24<00:00, 22.95it/s]
100%|██████████| 10083/10083 [02:07<00:00, 79.12it/s]


Wed Jun  9 12:06:01 2021 Epoch 5, lr: 0.0296722, train loss: 0.2200, valid loss: 0.2175, mean_rce: 19.97, retweet: 25.45, reply: 20.68, like: 24.07, retweet_comment: 9.68
rce_best increased (18.760059 --> 19.968056).  Saving model ...


  0%|          | 0/5 [00:00<?, ?it/s]

Wed Jun  9 12:10:16 2021 Epoch: 6


100%|██████████| 5/5 [05:20<00:00, 64.14s/it]
loss: 0.2079, smth: 0.2164: 100%|██████████| 12543/12543 [08:58<00:00, 23.28it/s]
100%|██████████| 10083/10083 [02:08<00:00, 78.29it/s]


Wed Jun  9 12:26:53 2021 Epoch 6, lr: 0.0294189, train loss: 0.2173, valid loss: 0.2150, mean_rce: 20.87, retweet: 26.24, reply: 21.69, like: 24.95, retweet_comment: 10.59
rce_best increased (19.968056 --> 20.870073).  Saving model ...
Wed Jun  9 12:31:12 2021 Epoch: 7


100%|██████████| 5/5 [06:07<00:00, 73.41s/it]
loss: 0.2122, smth: 0.2136: 100%|██████████| 13903/13903 [09:58<00:00, 23.24it/s]
100%|██████████| 10083/10083 [02:08<00:00, 78.69it/s]


Wed Jun  9 12:49:35 2021 Epoch 7, lr: 0.0290954, train loss: 0.2154, valid loss: 0.2129, mean_rce: 21.66, retweet: 27.00, reply: 22.49, like: 25.67, retweet_comment: 11.46
rce_best increased (20.870073 --> 21.657585).  Saving model ...


  0%|          | 0/5 [00:00<?, ?it/s]

Wed Jun  9 12:53:46 2021 Epoch: 8


100%|██████████| 5/5 [05:43<00:00, 68.68s/it]
loss: 0.2307, smth: 0.2131: 100%|██████████| 14638/14638 [10:28<00:00, 23.30it/s]
100%|██████████| 10083/10083 [02:08<00:00, 78.74it/s]


Wed Jun  9 13:12:19 2021 Epoch 8, lr: 0.0287032, train loss: 0.2134, valid loss: 0.2109, mean_rce: 22.41, retweet: 27.82, reply: 23.24, like: 26.27, retweet_comment: 12.30
rce_best increased (21.657585 --> 22.409573).  Saving model ...
Wed Jun  9 13:16:34 2021 Epoch: 9


100%|██████████| 5/5 [05:43<00:00, 68.62s/it]
loss: 0.2094, smth: 0.2113: 100%|██████████| 14635/14635 [10:30<00:00, 23.22it/s]
100%|██████████| 10083/10083 [02:11<00:00, 76.56it/s]


Wed Jun  9 13:35:10 2021 Epoch 9, lr: 0.0282442, train loss: 0.2119, valid loss: 0.2091, mean_rce: 23.04, retweet: 28.34, reply: 24.00, like: 26.93, retweet_comment: 12.90
rce_best increased (22.409573 --> 23.043938).  Saving model ...


  0%|          | 0/5 [00:00<?, ?it/s]

Wed Jun  9 13:39:28 2021 Epoch: 10


100%|██████████| 5/5 [05:31<00:00, 66.31s/it]
loss: 0.2159, smth: 0.2092: 100%|██████████| 13865/13865 [10:02<00:00, 23.00it/s]
100%|██████████| 10083/10083 [02:10<00:00, 77.53it/s]


Wed Jun  9 13:57:22 2021 Epoch 10, lr: 0.0277207, train loss: 0.2104, valid loss: 0.2075, mean_rce: 23.62, retweet: 28.95, reply: 24.52, like: 27.47, retweet_comment: 13.54
rce_best increased (23.043938 --> 23.618578).  Saving model ...
Wed Jun  9 14:01:36 2021 Epoch: 11


100%|██████████| 5/5 [05:33<00:00, 66.65s/it]
loss: 0.2008, smth: 0.2083: 100%|██████████| 13673/13673 [09:52<00:00, 23.06it/s]
100%|██████████| 10083/10083 [02:09<00:00, 77.97it/s]


Wed Jun  9 14:19:21 2021 Epoch 11, lr: 0.0271353, train loss: 0.2091, valid loss: 0.2064, mean_rce: 24.06, retweet: 29.31, reply: 24.96, like: 27.89, retweet_comment: 14.08
rce_best increased (23.618578 --> 24.061878).  Saving model ...


  0%|          | 0/5 [00:00<?, ?it/s]

Wed Jun  9 14:23:42 2021 Epoch: 12


100%|██████████| 5/5 [06:10<00:00, 74.13s/it]
loss: 0.2144, smth: 0.2080: 100%|██████████| 14359/14359 [10:15<00:00, 23.33it/s]
100%|██████████| 10083/10083 [02:05<00:00, 80.30it/s]


Wed Jun  9 14:42:22 2021 Epoch 12, lr: 0.0264907, train loss: 0.2083, valid loss: 0.2052, mean_rce: 24.53, retweet: 29.80, reply: 25.50, like: 28.25, retweet_comment: 14.58
rce_best increased (24.061878 --> 24.532425).  Saving model ...
Wed Jun  9 14:46:38 2021 Epoch: 13


100%|██████████| 5/5 [05:25<00:00, 65.02s/it]
loss: 0.2034, smth: 0.2076: 100%|██████████| 13714/13714 [09:49<00:00, 23.28it/s]
100%|██████████| 10083/10083 [02:06<00:00, 79.99it/s]


Wed Jun  9 15:04:08 2021 Epoch 13, lr: 0.0257901, train loss: 0.2073, valid loss: 0.2035, mean_rce: 24.99, retweet: 30.18, reply: 25.84, like: 28.98, retweet_comment: 14.96
rce_best increased (24.532425 --> 24.989592).  Saving model ...


  0%|          | 0/5 [00:00<?, ?it/s]

Wed Jun  9 15:08:21 2021 Epoch: 14


100%|██████████| 5/5 [05:31<00:00, 66.39s/it]
loss: 0.1909, smth: 0.2038: 100%|██████████| 14312/14312 [10:17<00:00, 23.17it/s]
100%|██████████| 10083/10083 [02:05<00:00, 80.34it/s]


Wed Jun  9 15:26:23 2021 Epoch 14, lr: 0.0250370, train loss: 0.2064, valid loss: 0.2030, mean_rce: 25.31, retweet: 30.45, reply: 26.32, like: 29.06, retweet_comment: 15.41
rce_best increased (24.989592 --> 25.310909).  Saving model ...


  0%|          | 0/5 [00:00<?, ?it/s]

Wed Jun  9 15:30:42 2021 Epoch: 15


100%|██████████| 5/5 [05:32<00:00, 66.45s/it]
loss: 0.1976, smth: 0.2053: 100%|██████████| 14636/14636 [10:24<00:00, 23.44it/s]
100%|██████████| 10083/10083 [02:06<00:00, 79.56it/s]


Wed Jun  9 15:48:53 2021 Epoch 15, lr: 0.0242349, train loss: 0.2055, valid loss: 0.2017, mean_rce: 25.69, retweet: 30.82, reply: 26.67, like: 29.55, retweet_comment: 15.72
rce_best increased (25.310909 --> 25.691736).  Saving model ...


  0%|          | 0/5 [00:00<?, ?it/s]

Wed Jun  9 15:53:08 2021 Epoch: 16


100%|██████████| 5/5 [05:26<00:00, 65.39s/it]
loss: 0.2233, smth: 0.2078: 100%|██████████| 14522/14522 [10:13<00:00, 23.66it/s]
100%|██████████| 10083/10083 [02:06<00:00, 79.54it/s]


Wed Jun  9 16:11:04 2021 Epoch 16, lr: 0.0233879, train loss: 0.2048, valid loss: 0.2009, mean_rce: 26.03, retweet: 31.05, reply: 27.06, like: 29.84, retweet_comment: 16.17
rce_best increased (25.691736 --> 26.028601).  Saving model ...


  0%|          | 0/5 [00:00<?, ?it/s]

Wed Jun  9 16:15:18 2021 Epoch: 17


100%|██████████| 5/5 [05:34<00:00, 66.82s/it]
loss: 0.1966, smth: 0.2041: 100%|██████████| 14856/14856 [10:38<00:00, 23.26it/s]
100%|██████████| 10083/10083 [02:06<00:00, 79.55it/s]


Wed Jun  9 16:33:45 2021 Epoch 17, lr: 0.0225000, train loss: 0.2040, valid loss: 0.1997, mean_rce: 26.42, retweet: 31.49, reply: 27.33, like: 30.32, retweet_comment: 16.53
rce_best increased (26.028601 --> 26.415928).  Saving model ...
Wed Jun  9 16:38:04 2021 Epoch: 18


100%|██████████| 5/5 [05:27<00:00, 65.60s/it]
loss: 0.1999, smth: 0.2032: 100%|██████████| 14637/14637 [10:22<00:00, 23.50it/s]
100%|██████████| 10083/10083 [02:06<00:00, 79.67it/s]


Wed Jun  9 16:56:09 2021 Epoch 18, lr: 0.0215756, train loss: 0.2034, valid loss: 0.1986, mean_rce: 26.75, retweet: 31.83, reply: 27.63, like: 30.74, retweet_comment: 16.82
rce_best increased (26.415928 --> 26.753586).  Saving model ...
Wed Jun  9 17:00:36 2021 Epoch: 19


100%|██████████| 5/5 [05:20<00:00, 64.05s/it]
loss: 0.2108, smth: 0.2028: 100%|██████████| 13821/13821 [09:50<00:00, 23.42it/s]
100%|██████████| 10083/10083 [02:06<00:00, 79.45it/s]


Wed Jun  9 17:18:01 2021 Epoch 19, lr: 0.0206191, train loss: 0.2020, valid loss: 0.1978, mean_rce: 27.05, retweet: 32.08, reply: 27.98, like: 31.01, retweet_comment: 17.13
rce_best increased (26.753586 --> 27.051285).  Saving model ...


  0%|          | 0/5 [00:00<?, ?it/s]

Wed Jun  9 17:22:19 2021 Epoch: 20


100%|██████████| 5/5 [05:26<00:00, 65.24s/it]
loss: 0.1941, smth: 0.2009: 100%|██████████| 14182/14182 [10:22<00:00, 22.77it/s]
100%|██████████| 10083/10083 [02:12<00:00, 75.88it/s]


Wed Jun  9 17:40:28 2021 Epoch 20, lr: 0.0196353, train loss: 0.2023, valid loss: 0.1975, mean_rce: 27.18, retweet: 32.24, reply: 28.07, like: 31.09, retweet_comment: 17.33
rce_best increased (27.051285 --> 27.183308).  Saving model ...


  0%|          | 0/5 [00:00<?, ?it/s]

Wed Jun  9 17:46:40 2021 Epoch: 21


100%|██████████| 5/5 [05:53<00:00, 70.70s/it]
loss: 0.1907, smth: 0.2012: 100%|██████████| 14296/14296 [10:30<00:00, 22.66it/s]
100%|██████████| 10083/10083 [02:10<00:00, 77.21it/s]


Wed Jun  9 18:05:23 2021 Epoch 21, lr: 0.0186288, train loss: 0.2018, valid loss: 0.1970, mean_rce: 27.42, retweet: 32.50, reply: 28.36, like: 31.21, retweet_comment: 17.61
rce_best increased (27.183308 --> 27.420492).  Saving model ...


  0%|          | 0/5 [00:00<?, ?it/s]

Wed Jun  9 18:09:41 2021 Epoch: 22


100%|██████████| 5/5 [05:31<00:00, 66.28s/it]
loss: 0.1963, smth: 0.2001: 100%|██████████| 13550/13550 [09:56<00:00, 22.73it/s]
100%|██████████| 10083/10083 [02:14<00:00, 75.05it/s]


Wed Jun  9 18:27:33 2021 Epoch 22, lr: 0.0176047, train loss: 0.2013, valid loss: 0.1960, mean_rce: 27.72, retweet: 32.68, reply: 28.69, like: 31.65, retweet_comment: 17.88
rce_best increased (27.420492 --> 27.723713).  Saving model ...
Wed Jun  9 18:32:56 2021 Epoch: 23


100%|██████████| 5/5 [05:53<00:00, 70.77s/it]
loss: 0.2037, smth: 0.1998: 100%|██████████| 14635/14635 [10:35<00:00, 23.02it/s]
100%|██████████| 10083/10083 [02:12<00:00, 76.13it/s]


Wed Jun  9 18:51:49 2021 Epoch 23, lr: 0.0165679, train loss: 0.2009, valid loss: 0.1955, mean_rce: 27.88, retweet: 32.88, reply: 28.78, like: 31.79, retweet_comment: 18.09
rce_best increased (27.723713 --> 27.884253).  Saving model ...


  0%|          | 0/5 [00:00<?, ?it/s]

Wed Jun  9 18:57:07 2021 Epoch: 24


100%|██████████| 5/5 [05:44<00:00, 68.99s/it]
loss: 0.2228, smth: 0.2022: 100%|██████████| 14180/14180 [10:24<00:00, 22.72it/s]
100%|██████████| 10083/10083 [02:10<00:00, 77.46it/s]


Wed Jun  9 19:15:34 2021 Epoch 24, lr: 0.0155235, train loss: 0.2004, valid loss: 0.1951, mean_rce: 28.06, retweet: 33.09, reply: 28.94, like: 31.92, retweet_comment: 18.30
rce_best increased (27.884253 --> 28.062990).  Saving model ...
Wed Jun  9 19:20:25 2021 Epoch: 25


100%|██████████| 5/5 [05:33<00:00, 66.76s/it]
loss: 0.2099, smth: 0.1988: 100%|██████████| 13862/13862 [09:50<00:00, 23.48it/s]
100%|██████████| 10083/10083 [02:04<00:00, 81.13it/s]


Wed Jun  9 19:38:03 2021 Epoch 25, lr: 0.0144765, train loss: 0.2001, valid loss: 0.1947, mean_rce: 28.16, retweet: 33.13, reply: 29.13, like: 32.08, retweet_comment: 18.29
rce_best increased (28.062990 --> 28.158478).  Saving model ...


  0%|          | 0/5 [00:00<?, ?it/s]

Wed Jun  9 19:42:18 2021 Epoch: 26


100%|██████████| 5/5 [05:30<00:00, 66.19s/it]
loss: 0.2164, smth: 0.2014: 100%|██████████| 14634/14634 [10:38<00:00, 22.91it/s]
100%|██████████| 10083/10083 [02:12<00:00, 76.24it/s]


Wed Jun  9 20:00:49 2021 Epoch 26, lr: 0.0134321, train loss: 0.1998, valid loss: 0.1941, mean_rce: 28.43, retweet: 33.43, reply: 29.34, like: 32.26, retweet_comment: 18.69
rce_best increased (28.158478 --> 28.429626).  Saving model ...
Wed Jun  9 20:05:14 2021 Epoch: 27


100%|██████████| 5/5 [05:47<00:00, 69.44s/it]
loss: 0.2084, smth: 0.2003: 100%|██████████| 14671/14671 [10:44<00:00, 22.77it/s]
100%|██████████| 10083/10083 [02:12<00:00, 76.21it/s]


Wed Jun  9 20:24:07 2021 Epoch 27, lr: 0.0123953, train loss: 0.1994, valid loss: 0.1935, mean_rce: 28.58, retweet: 33.61, reply: 29.36, like: 32.55, retweet_comment: 18.80
rce_best increased (28.429626 --> 28.576244).  Saving model ...


  0%|          | 0/5 [00:00<?, ?it/s]

Wed Jun  9 20:29:28 2021 Epoch: 28


100%|██████████| 5/5 [06:16<00:00, 75.24s/it]
loss: 0.2103, smth: 0.1991: 100%|██████████| 14626/14626 [10:44<00:00, 22.70it/s]
100%|██████████| 10083/10083 [02:10<00:00, 77.26it/s]


Wed Jun  9 20:48:51 2021 Epoch 28, lr: 0.0113712, train loss: 0.1990, valid loss: 0.1934, mean_rce: 28.71, retweet: 33.73, reply: 29.66, like: 32.49, retweet_comment: 18.96
rce_best increased (28.576244 --> 28.706108).  Saving model ...


  0%|          | 0/5 [00:00<?, ?it/s]

Wed Jun  9 20:53:11 2021 Epoch: 29


100%|██████████| 5/5 [06:09<00:00, 73.93s/it]
loss: 0.1955, smth: 0.1996: 100%|██████████| 14668/14668 [10:44<00:00, 22.77it/s]
100%|██████████| 10083/10083 [02:12<00:00, 76.22it/s]


Wed Jun  9 21:12:25 2021 Epoch 29, lr: 0.0103647, train loss: 0.1989, valid loss: 0.1928, mean_rce: 28.86, retweet: 33.79, reply: 29.76, like: 32.76, retweet_comment: 19.12
rce_best increased (28.706108 --> 28.857880).  Saving model ...


  0%|          | 0/5 [00:00<?, ?it/s]

Wed Jun  9 21:16:56 2021 Epoch: 30


100%|██████████| 5/5 [05:53<00:00, 70.71s/it]
loss: 0.2075, smth: 0.1985: 100%|██████████| 14645/14645 [10:50<00:00, 22.50it/s]
100%|██████████| 10083/10083 [02:09<00:00, 78.01it/s]


Wed Jun  9 21:35:59 2021 Epoch 30, lr: 0.0093809, train loss: 0.1986, valid loss: 0.1922, mean_rce: 29.04, retweet: 33.88, reply: 30.02, like: 33.02, retweet_comment: 19.23
rce_best increased (28.857880 --> 29.039631).  Saving model ...
Wed Jun  9 21:41:23 2021 Epoch: 31


100%|██████████| 5/5 [05:08<00:00, 61.76s/it]
loss: 0.2006, smth: 0.1987: 100%|██████████| 13018/13018 [09:10<00:00, 23.63it/s]
100%|██████████| 10083/10083 [02:07<00:00, 78.91it/s]


Wed Jun  9 21:57:58 2021 Epoch 31, lr: 0.0084244, train loss: 0.1983, valid loss: 0.1923, mean_rce: 29.05, retweet: 33.99, reply: 30.02, like: 32.88, retweet_comment: 19.29
rce_best increased (29.039631 --> 29.046562).  Saving model ...


  0%|          | 0/5 [00:00<?, ?it/s]

Wed Jun  9 22:02:17 2021 Epoch: 32


100%|██████████| 5/5 [05:46<00:00, 69.32s/it]
loss: 0.1971, smth: 0.1980: 100%|██████████| 13865/13865 [10:15<00:00, 22.52it/s]
100%|██████████| 10083/10083 [02:09<00:00, 77.64it/s]


Wed Jun  9 22:20:37 2021 Epoch 32, lr: 0.0075000, train loss: 0.1971, valid loss: 0.1916, mean_rce: 29.27, retweet: 34.11, reply: 30.18, like: 33.23, retweet_comment: 19.56
rce_best increased (29.046562 --> 29.265709).  Saving model ...


  0%|          | 0/5 [00:00<?, ?it/s]

Wed Jun  9 22:24:53 2021 Epoch: 33


100%|██████████| 5/5 [05:47<00:00, 69.44s/it]
loss: 0.1975, smth: 0.1982: 100%|██████████| 14470/14470 [10:15<00:00, 23.50it/s]
100%|██████████| 10083/10083 [02:06<00:00, 79.94it/s]


Wed Jun  9 22:43:10 2021 Epoch 33, lr: 0.0066121, train loss: 0.1979, valid loss: 0.1917, mean_rce: 29.23, retweet: 34.11, reply: 30.15, like: 33.17, retweet_comment: 19.49


  0%|          | 0/5 [00:00<?, ?it/s]

Wed Jun  9 22:46:18 2021 Epoch: 34


100%|██████████| 5/5 [05:34<00:00, 66.81s/it]
loss: 0.1877, smth: 0.1977: 100%|██████████| 14637/14637 [10:31<00:00, 23.17it/s]
100%|██████████| 10083/10083 [02:06<00:00, 79.96it/s]


Wed Jun  9 23:04:38 2021 Epoch 34, lr: 0.0057651, train loss: 0.1979, valid loss: 0.1914, mean_rce: 29.31, retweet: 34.25, reply: 30.14, like: 33.30, retweet_comment: 19.57
rce_best increased (29.265709 --> 29.314360).  Saving model ...


  0%|          | 0/5 [00:00<?, ?it/s]

Wed Jun  9 23:08:50 2021 Epoch: 35


100%|██████████| 5/5 [05:36<00:00, 67.32s/it]
loss: 0.2132, smth: 0.1977: 100%|██████████| 14635/14635 [10:25<00:00, 23.40it/s]
100%|██████████| 10083/10083 [02:08<00:00, 78.49it/s]


Wed Jun  9 23:27:08 2021 Epoch 35, lr: 0.0049630, train loss: 0.1976, valid loss: 0.1913, mean_rce: 29.39, retweet: 34.26, reply: 30.32, like: 33.31, retweet_comment: 19.67
rce_best increased (29.314360 --> 29.390415).  Saving model ...


  0%|          | 0/5 [00:00<?, ?it/s]

Wed Jun  9 23:31:37 2021 Epoch: 36


100%|██████████| 5/5 [05:30<00:00, 66.16s/it]
loss: 0.1883, smth: 0.1965: 100%|██████████| 13586/13586 [09:43<00:00, 23.30it/s]
100%|██████████| 10083/10083 [02:06<00:00, 79.82it/s]


Wed Jun  9 23:49:05 2021 Epoch 36, lr: 0.0042099, train loss: 0.1975, valid loss: 0.1915, mean_rce: 29.40, retweet: 34.30, reply: 30.38, like: 33.18, retweet_comment: 19.72
rce_best increased (29.390415 --> 29.396254).  Saving model ...
Wed Jun  9 23:53:27 2021 Epoch: 37


100%|██████████| 5/5 [05:35<00:00, 67.19s/it]
loss: 0.2002, smth: 0.1991: 100%|██████████| 14526/14526 [10:30<00:00, 23.05it/s]
100%|██████████| 10083/10083 [02:12<00:00, 76.38it/s]


Thu Jun 10 00:11:53 2021 Epoch 37, lr: 0.0035093, train loss: 0.1974, valid loss: 0.1910, mean_rce: 29.43, retweet: 34.32, reply: 30.19, like: 33.48, retweet_comment: 19.75
rce_best increased (29.396254 --> 29.433987).  Saving model ...


  0%|          | 0/5 [00:00<?, ?it/s]

Thu Jun 10 00:16:14 2021 Epoch: 38


100%|██████████| 5/5 [06:01<00:00, 72.33s/it]
loss: 0.2041, smth: 0.1969: 100%|██████████| 14636/14636 [10:51<00:00, 22.48it/s]
100%|██████████| 10083/10083 [02:14<00:00, 74.95it/s]


Thu Jun 10 00:35:36 2021 Epoch 38, lr: 0.0028647, train loss: 0.1974, valid loss: 0.1908, mean_rce: 29.54, retweet: 34.46, reply: 30.43, like: 33.52, retweet_comment: 19.75
rce_best increased (29.433987 --> 29.540037).  Saving model ...
Thu Jun 10 00:41:03 2021 Epoch: 39


100%|██████████| 5/5 [05:46<00:00, 69.21s/it]
loss: 0.2044, smth: 0.1990: 100%|██████████| 14637/14637 [10:22<00:00, 23.52it/s]
100%|██████████| 10083/10083 [02:05<00:00, 80.10it/s]


Thu Jun 10 00:59:25 2021 Epoch 39, lr: 0.0022793, train loss: 0.1975, valid loss: 0.1908, mean_rce: 29.58, retweet: 34.46, reply: 30.52, like: 33.47, retweet_comment: 19.88
rce_best increased (29.540037 --> 29.581142).  Saving model ...
Thu Jun 10 01:03:51 2021 Epoch: 40


100%|██████████| 5/5 [05:22<00:00, 64.51s/it]
loss: 0.1781, smth: 0.1948: 100%|██████████| 13508/13508 [09:31<00:00, 23.66it/s]
100%|██████████| 10083/10083 [02:06<00:00, 79.76it/s]


Thu Jun 10 01:20:59 2021 Epoch 40, lr: 0.0017558, train loss: 0.1963, valid loss: 0.1902, mean_rce: 29.75, retweet: 34.62, reply: 30.63, like: 33.72, retweet_comment: 20.02
rce_best increased (29.581142 --> 29.746868).  Saving model ...


  0%|          | 0/5 [00:00<?, ?it/s]

Thu Jun 10 01:25:18 2021 Epoch: 41


100%|██████████| 5/5 [05:35<00:00, 67.18s/it]
loss: 0.1955, smth: 0.1976: 100%|██████████| 14637/14637 [10:19<00:00, 23.64it/s]
100%|██████████| 10083/10083 [02:06<00:00, 79.79it/s]


Thu Jun 10 01:43:27 2021 Epoch 41, lr: 0.0012968, train loss: 0.1973, valid loss: 0.1909, mean_rce: 29.54, retweet: 34.48, reply: 30.41, like: 33.40, retweet_comment: 19.89
Thu Jun 10 01:46:41 2021 Epoch: 42


100%|██████████| 5/5 [05:30<00:00, 66.03s/it]
loss: 0.1919, smth: 0.1975: 100%|██████████| 14473/14473 [10:12<00:00, 23.63it/s]
100%|██████████| 10083/10083 [02:05<00:00, 80.18it/s]


Thu Jun 10 02:04:37 2021 Epoch 42, lr: 0.0009046, train loss: 0.1972, valid loss: 0.1906, mean_rce: 29.61, retweet: 34.41, reply: 30.50, like: 33.59, retweet_comment: 19.93


  0%|          | 0/5 [00:00<?, ?it/s]

Thu Jun 10 02:07:41 2021 Epoch: 43


100%|██████████| 5/5 [05:28<00:00, 65.63s/it]
loss: 0.2087, smth: 0.1948: 100%|██████████| 14187/14187 [09:59<00:00, 23.68it/s]
100%|██████████| 10083/10083 [02:06<00:00, 79.88it/s]


Thu Jun 10 02:25:22 2021 Epoch 43, lr: 0.0005811, train loss: 0.1971, valid loss: 0.1905, mean_rce: 29.63, retweet: 34.48, reply: 30.49, like: 33.64, retweet_comment: 19.93


  0%|          | 0/5 [00:00<?, ?it/s]

Thu Jun 10 02:28:33 2021 Epoch: 44


100%|██████████| 5/5 [05:33<00:00, 66.78s/it]
loss: 0.1962, smth: 0.1953: 100%|██████████| 14784/14784 [10:23<00:00, 23.72it/s]
100%|██████████| 10083/10083 [02:06<00:00, 79.55it/s]


Thu Jun 10 02:46:45 2021 Epoch 44, lr: 0.0003278, train loss: 0.1971, valid loss: 0.1906, mean_rce: 29.65, retweet: 34.48, reply: 30.57, like: 33.58, retweet_comment: 19.97


  0%|          | 0/5 [00:00<?, ?it/s]

Thu Jun 10 02:49:57 2021 Epoch: 45


100%|██████████| 5/5 [05:26<00:00, 65.32s/it]
loss: 0.2121, smth: 0.1985: 100%|██████████| 13823/13823 [09:44<00:00, 23.64it/s]
100%|██████████| 10083/10083 [02:05<00:00, 80.11it/s]


Thu Jun 10 03:07:22 2021 Epoch 45, lr: 0.0001460, train loss: 0.1970, valid loss: 0.1904, mean_rce: 29.67, retweet: 34.57, reply: 30.48, like: 33.68, retweet_comment: 19.94


  0%|          | 0/5 [00:00<?, ?it/s]

Thu Jun 10 03:10:32 2021 Epoch: 46


100%|██████████| 5/5 [05:29<00:00, 65.88s/it]
loss: 0.1866, smth: 0.1955: 100%|██████████| 14635/14635 [10:16<00:00, 23.73it/s]
100%|██████████| 10083/10083 [02:05<00:00, 80.44it/s]


Thu Jun 10 03:28:32 2021 Epoch 46, lr: 0.0000365, train loss: 0.1971, valid loss: 0.1907, mean_rce: 29.61, retweet: 34.41, reply: 30.53, like: 33.56, retweet_comment: 19.94


In [None]:
# 20 parts, 5 epochs
# mean_rce: 16.51, retweet: 20.27, reply: 18.26, like: 20.17, retweet_comment: 7.35
#
# Epoch 20, lr: 0.0000794, train loss: 0.2142, valid loss: 0.2233,
# mean_rce: 18.66, retweet: 23.00, reply: 20.36, like: 21.72, retweet_comment: 9.56
#
# Epoch 35, lr: 0.0001654, train loss: 0.2105, valid loss: 0.2220,
# mean_rce: 19.25, retweet: 23.58, reply: 21.08, like: 22.09, retweet_comment: 10.27

# xgb feat NN
# mean_rce: 20.25, retweet: 23.39, reply: 19.07, like: 13.02, retweet_comment: 25.54

## Load the best epoch and run inference

In [27]:
# Restore the best checkpoint written during training.
# map_location='cpu' deserialises the tensors on the host instead of on the
# device they were saved from; load_state_dict() then copies each tensor into
# the model's existing parameters, wherever those live. This avoids holding a
# second full copy of the weights on the GPU and does not depend on the device
# index used at save time.
sd = torch.load(f'../models/{model_name}_best.pth', map_location='cpu')
# Keys may carry a 'module.' prefix (presumably from an nn.DataParallel wrapper
# at save time -- TODO confirm against the training cell); strip it so the keys
# match the bare model.
dp_prefix = 'module.'
sd = {(k[len(dp_prefix):] if k.startswith(dp_prefix) else k): v for k, v in sd.items()}
# strict=True: fail loudly on any missing or unexpected key.
model.load_state_dict(sd, strict=True)

<All keys matched successfully>

In [28]:
# Put the target names into alphabetical order (a fresh list, the original
# object is left untouched).
# NOTE(review): the cells below use label_names[i] to name column i of the
# logits/targets -- presumably the loader emits targets in this sorted order;
# verify against the data-loader setup.
label_names = list(label_names)
label_names.sort()
label_names

['like', 'reply', 'retweet', 'retweet_comment']

In [29]:
# Run the model over the whole validation loader and compute the RCE per label.
# Results left in the namespace:
#   val_loss - list of per-batch loss values
#   LOGITS   - all raw model outputs, concatenated on the CPU
#   TARGETS  - all targets, concatenated on the CPU
#   rce      - dict: label name -> RCE of sigmoid(logit) against the target
#   mean_rce - unweighted mean of the per-label RCE values (displayed)
model.eval()
val_loss = []
logit_chunks = []   # per-batch outputs; kept separate from the final tensors
target_chunks = []
with torch.no_grad():
    for batch in tqdm(valid_loader):
        # Each batch unpacks into exactly four tensors; move them all to the GPU.
        x_cat, x_cont, text_tok, targets = (t.cuda() for t in batch)
        logits = model(x_cat, x_cont, text_tok)
        loss = criterion(logits, targets)
        val_loss.append(loss.item())
        # Move results off the GPU right away so only one batch stays resident.
        logit_chunks.append(logits.cpu())
        target_chunks.append(targets.cpu())

LOGITS = torch.cat(logit_chunks)
TARGETS = torch.cat(target_chunks)

# Column i of the logits is reported under label_names[i], so the counts must agree.
assert LOGITS.shape[1] == len(label_names), (
    f"model produced {LOGITS.shape[1]} outputs but there are {len(label_names)} label names"
)

rce = {}
for i, label in enumerate(label_names):
    # compute_rce_fast is fed CuPy arrays; .get() brings the result back to the host.
    rce[label] = compute_rce_fast(cp.asarray(LOGITS[:, i].sigmoid()), cp.asarray(TARGETS[:, i])).get()
mean_rce = np.mean(list(rce.values()))
mean_rce

100%|██████████| 10083/10083 [01:43<00:00, 97.81it/s] 


20.13601

In [30]:
# Quantile group of every validation row, taken from `valid` with a fresh 0..N-1 index.
# Alternative kept for reference -- rebuild the column from the parquet files instead:
# df_quantile = pd.concat([pd.read_parquet(path)[['quantile']] for path in VALID_PATHS]).reset_index(drop=True)
# df_quantile = df_quantile.apply(np.expm1).round().astype(int)
df_quantile = valid[['quantile']].copy().reset_index(drop=True)

# The next cell indexes the predictions with boolean masks built from
# df_quantile, so there must be exactly one quantile row per prediction row.
# NOTE(review): this only checks the count; it assumes valid_loader yields rows
# in the same order as `valid` -- confirm the loader does not shuffle.
assert len(df_quantile) == len(LOGITS), (
    f"{len(df_quantile)} quantile rows vs {len(LOGITS)} prediction rows"
)

# GPU copies used by the per-quantile scoring.
yquantile = cupy.asarray(df_quantile.values)
oof = cupy.asarray(LOGITS.sigmoid())   # predicted probabilities
yvalid = cupy.asarray(TARGETS)

In [31]:
from util import compute_prauc, average_precision_score,display_score

# Score every label separately inside each quantile group.
# Produces:
#   rce_output[label] - list with one RCE value per quantile group
#   ap_output[label]  - list with one average-precision value per quantile group
N_QUANTILES = 5  # groups are matched against the values 0..4 of df_quantile['quantile']

# A group's boolean row mask does not depend on the label, so build each one
# once instead of once per (label, group) pair.
quantile_masks = [(df_quantile == j)['quantile'].values for j in range(N_QUANTILES)]

rce_output = {}
ap_output = {}
for i, label in enumerate(label_names):
    prauc_out = []
    rce_out = []
    ap_out = []
    for mask in quantile_masks:
        yvalid_tmp = yvalid[mask][:, i]
        oof_tmp = oof[mask][:, i]
        # The scalars carry a _q suffix so they do not overwrite the `rce` dict
        # created by the validation cell.
        prauc_q = compute_prauc(oof_tmp, yvalid_tmp)
        rce_q = compute_rce_fast(oof_tmp, yvalid_tmp).item()
        # average_precision_score is given host (NumPy) arrays.
        ap_q = average_precision_score(cupy.asnumpy(yvalid_tmp), cupy.asnumpy(oof_tmp))
        prauc_out.append(prauc_q)
        rce_out.append(rce_q)
        ap_out.append(ap_q)
    # NOTE(review): prauc_out is collected but never stored or displayed.
    rce_output[label] = rce_out
    ap_output[label] = ap_out

In [32]:
# Print the run identifier first so the score table below can be attributed to a model.
print(model_name)
# Render the per-quantile AP / RCE table from the dicts filled in by the previous cell.
display_score(rce_output, ap_output)

gru_cat5_cont36_frzemb768_gru128_len64_thr50_3weeks
Quantile Group|AP Retweet|RCE Retweet|  AP Reply|  RCE Reply|   AP Like|   RCE Like|AP RT comment|RCE RT comment
        0          0.4827     25.3136     0.2257     18.1147     0.7328     17.5110     0.0545      9.7572
        1          0.4643     24.6773     0.2055     18.0764     0.7295     17.6294     0.0554     10.1372
        2          0.4435     23.6253     0.2155     18.7955     0.7326     18.1288     0.0489      8.8695
        3          0.4352     23.1426     0.2309     19.5610     0.7339     18.9700     0.0433      8.8682
        4          0.4320     24.3085     0.2189     21.0982     0.7610     27.1258     0.0494     11.0005
     Average       0.4515     24.2134     0.2193     19.1291     0.7379     19.8730     0.0503      9.7265


In [28]:
# XGB
%%time
display_score(rce_output, ap_output)

Quantile Group|AP Retweet|RCE Retweet|  AP Reply|  RCE Reply|   AP Like|   RCE Like|AP RT comment|RCE RT comment
        0          0.4767     23.8403     0.2466     19.7694     0.7496     18.7935     0.0803     11.3031
        1          0.4608     24.3840     0.2391     20.9897     0.7444     20.0392     0.0680     10.6829
        2          0.4501     24.8955     0.2513     22.2341     0.7393     20.2754     0.0682     11.2699
        3          0.4384     24.5673     0.2677     23.6996     0.7334     20.9961     0.0662     11.5356
        4          0.4124     24.7017     0.2411     24.6334     0.7059     20.6678     0.0710     15.1545
     Average       0.4477     24.4778     0.2492     22.2652     0.7345     20.1544     0.0707     11.9892
CPU times: user 0 ns, sys: 1 ms, total: 1 ms
Wall time: 950 µs
