## Import libraries and classes, set parameters

In [1]:
import re
from itertools import product
from collections import Counter

from tqdm.auto import tqdm

import pandas as pd
import numpy as np

import torch
from torch import nn
from torch.utils.data import DataLoader

from sklearn.preprocessing import MinMaxScaler

import transformers

from models.TweetDataset import TweetDataset
from models.Wd_Xlm_T import Wd_Xlm_T

from utils.TargetEncoder import TargetEncoder

In [2]:
DATA_PATH = "./data/"
CHECKPOINT_DIR = "./checkpoints/"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print(DEVICE)

cuda


In [3]:
orig_features = [
    'text_tokens',    ###############
    'hashtags',       #Tweet Features
    'tweet_id',       #
    'media',          #
    'links',          #
    'domains',        #
    'tweet_type',     #
    'language',       #
    'timestamp',      ###############
    'engaged_with_user_id',              ###########################
    'engaged_with_user_follower_count',  #Engaged With User Features
    'engaged_with_user_following_count', #
    'engaged_with_user_is_verified',     #
    'engaged_with_user_account_creation', ###########################
    'engaging_user_id',                  #######################
    'engaging_user_follower_count',      #Engaging User Features
    'engaging_user_following_count',     #
    'engaging_user_is_verified',         #
    'engaging_user_account_creation',    #######################
    'engagee_follows_engager',    #################### Engagement Features
    'reply',          #Target Reply
    'retweet',        #Target Retweet    
    'retweet_comment',#Target Retweet with comment
    'like',           #Target Like
                      ####################
]

target_features = orig_features[-4:]
numerical_features = ['engaged_with_user_follower_count', 'engaged_with_user_following_count', 
                      'engaging_user_follower_count', 'engaging_user_following_count', 'url_cnt',
                      'char_cnt', 'hashtag_cnt', 'Photo_cnt', 'Video_cnt', 'GIF_cnt']
categorical_features = ['language', 'engaged_with_user_id', 'engaging_user_id', 'tweet_type']


cat_target_prod = product(categorical_features, target_features)
features = []
for (cat, target) in cat_target_prod:
    features.append(cat+"_"+target+"_TE")

m = 20
MAX_LEN = 100

## Load and preprocess data

In [4]:
orig_df = pd.read_csv(DATA_PATH+"dataset_filtered_small.csv")
orig_df.sort_values(by=['timestamp'], inplace=True)
orig_df.drop("Unnamed: 0", axis=1, inplace=True)
orig_df.head()

Unnamed: 0,text_tokens,hashtags,tweet_id,media,links,domains,tweet_type,language,timestamp,engaged_with_user_id,...,engaging_user_id,engaging_user_follower_count,engaging_user_following_count,engaging_user_is_verified,engaging_user_account_creation,engagee_follows_engager,reply,retweet,retweet_comment,like
1120596,101\t12148\t11675\t14707\t117\t17924\t16266\t2...,E8EC1049F02FE3900B1E45D1BDD52BEF\tF0F29CEE3668...,20C5F3A6F47B4E7A85A1443CC3D12B0C,Video,,,TopLevel,B0FA488F2911701DD8EC5B1EA5E322D8,1612396800,E400C001A195BD92CB74AA0B2E2BB522,...,7327C6D707CE1C0F8DC376B8F6AA6B3F,49,237,False,1590030909,False,,,,
6744778,101\t14120\t131\t120\t120\t188\t119\t11170\t12...,,ABC98352C6238B3129A1772122532156,GIF,,,TopLevel,313ECD3A1E5BB07406E4249475C2D6D6,1612396800,80C3FD8645A74F589C103A1F9A3C40E5,...,75CC5A829BFDBECDFCDD0868EFAC04FF,79,163,False,1404873502,False,,,,1612414000.0
1536740,101\t14120\t131\t120\t120\t188\t119\t11170\t12...,,120E4BA71617DDFA622E0263783D04B2,Photo\tPhoto,,,TopLevel,313ECD3A1E5BB07406E4249475C2D6D6,1612396800,3413E19B696E35FC1B6567C260ED2E0B,...,85351873930358043D54DEC159B30E59,252,728,False,1470684663,False,,1612397000.0,,
5782609,101\t14200\t10182\t24248\t10125\t45411\t119\t1...,,AC3770889BCE6879E01ACC3D675CCD5C,Photo,,,TopLevel,B0FA488F2911701DD8EC5B1EA5E322D8,1612396800,C530D3D968FBD0D08537C5EDEEDEE542,...,943B164BB3B9EE48C8F1487BCF731994,138,1108,False,1260577260,False,,,,1612399000.0
2289432,101\t20452\t10142\t14908\t10841\t17565\t80677\...,,AC97A848CC9CE41F42D566B2C5EAAB45,Video,,,TopLevel,488B32D24BD4BB44172EB981C1BCA6FA,1612396800,E7A2CEC020385D583CA1F15C53E671DD,...,50E8C399099EEAB5363E786D020917A7,101,294,False,1532466558,False,,,,1612465000.0


In [5]:
print(orig_df['engaged_with_user_id'].nunique())

1549966


In [6]:
print(orig_df['engaging_user_id'].nunique())

49458


#### Convert tokens to text and extract features from text

In [7]:
def count_urls(text):
    url_cnt = 0
    
    for url in re.finditer("(https : \/(\s(\/\s)?\w+(\.)?)+)", text):
        url_cnt += 1
        
    return url_cnt


def char_count(text):
    char_count = 0
    
    for c in text:
        if(c != ' ' and c != '\t'
           and c != '\n' and c != '\r'):
            char_count += 1
            
    return char_count

In [8]:
def remove_or_replace_urls(tweet, replace=True):
    str_list = []
    prev_idx = 0
    for url in re.finditer("(https : \/(\s(\/\s)?\w+(\.)?)+)", tweet):
        str_list.append(tweet[prev_idx:url.start()])
        
        if(replace):
            str_list.append("[LINK]")
            
        prev_idx = url.end() + 1
        
    if(len(str_list)==0):
        return tweet
    return "".join(str_list)


def add_converted_text(df, tokenizer):
    text = df['text_tokens'].map(lambda x: tokenizer.decode([int(s) for s in x.split('\t')])).values
    df.drop('text_tokens', axis='columns', inplace=True)
    df.loc[:, 'text'] = text
    
    return df



def add_text_extract_feats(df, feature_extractors, tokenizer):
    df = add_converted_text(df, tokenizer)
    
    for feat_name, feat_func in feature_extractors.items():
        feature = df['text'].map(lambda txt: feat_func(txt))
        
        df.loc[:, feat_name] = feature
    
    return df

In [9]:
bert_multilingual_tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-multilingual-cased")

feature_extractors = {
    'url_cnt': count_urls,
    'char_cnt': char_count,
}

text_features_df = add_text_extract_feats(orig_df.copy(), feature_extractors, bert_multilingual_tokenizer)
text_features_df.head()

Unnamed: 0,hashtags,tweet_id,media,links,domains,tweet_type,language,timestamp,engaged_with_user_id,engaged_with_user_follower_count,...,engaging_user_is_verified,engaging_user_account_creation,engagee_follows_engager,reply,retweet,retweet_comment,like,text,url_cnt,char_cnt
1120596,E8EC1049F02FE3900B1E45D1BDD52BEF\tF0F29CEE3668...,20C5F3A6F47B4E7A85A1443CC3D12B0C,Video,,,TopLevel,B0FA488F2911701DD8EC5B1EA5E322D8,1612396800,E400C001A195BD92CB74AA0B2E2BB522,6460161,...,False,1590030909,False,,,,,"[CLS] Una vez dentro, sólo habrá una salida. ¶...",1,166
6744778,,ABC98352C6238B3129A1772122532156,GIF,,,TopLevel,313ECD3A1E5BB07406E4249475C2D6D6,1612396800,80C3FD8645A74F589C103A1F9A3C40E5,53204,...,False,1404873502,False,,,,1612414000.0,[CLS] https : / / t. co / kJysqCq6UR [SEP],1,33
1536740,,120E4BA71617DDFA622E0263783D04B2,Photo\tPhoto,,,TopLevel,313ECD3A1E5BB07406E4249475C2D6D6,1612396800,3413E19B696E35FC1B6567C260ED2E0B,96565,...,False,1470684663,False,,1612397000.0,,,[CLS] https : / / t. co / r9ltA3k5k9 [SEP],1,33
5782609,,AC3770889BCE6879E01ACC3D675CCD5C,Photo,,,TopLevel,B0FA488F2911701DD8EC5B1EA5E322D8,1612396800,C530D3D968FBD0D08537C5EDEEDEE542,6569,...,False,1260577260,False,,,,1612399000.0,[CLS] Como los trata el calor.... a mi asi! ht...,1,63
2289432,,AC97A848CC9CE41F42D566B2C5EAAB45,Video,,,TopLevel,488B32D24BD4BB44172EB981C1BCA6FA,1612396800,E7A2CEC020385D583CA1F15C53E671DD,3600646,...,False,1532466558,False,,,,1612465000.0,[CLS] Never forget when Pop Smoke dropped two ...,1,118


In [10]:
del orig_df

#### Add additional features

In [11]:
def extract_hashtag_cnt_from_obj(hashtags):
    if(not pd.isna(hashtags)):
        return len(hashtags.split('\t'))
    
    return 0

def add_hashtag_cnt(df):
    hashtag_cnt = df['hashtags'].map(lambda x: extract_hashtag_cnt_from_obj(x)).values
    df.loc[:, 'hashtag_cnt'] = hashtag_cnt

    return df

new_df = add_hashtag_cnt(text_features_df)
new_df.head()

Unnamed: 0,hashtags,tweet_id,media,links,domains,tweet_type,language,timestamp,engaged_with_user_id,engaged_with_user_follower_count,...,engaging_user_account_creation,engagee_follows_engager,reply,retweet,retweet_comment,like,text,url_cnt,char_cnt,hashtag_cnt
1120596,E8EC1049F02FE3900B1E45D1BDD52BEF\tF0F29CEE3668...,20C5F3A6F47B4E7A85A1443CC3D12B0C,Video,,,TopLevel,B0FA488F2911701DD8EC5B1EA5E322D8,1612396800,E400C001A195BD92CB74AA0B2E2BB522,6460161,...,1590030909,False,,,,,"[CLS] Una vez dentro, sólo habrá una salida. ¶...",1,166,5
6744778,,ABC98352C6238B3129A1772122532156,GIF,,,TopLevel,313ECD3A1E5BB07406E4249475C2D6D6,1612396800,80C3FD8645A74F589C103A1F9A3C40E5,53204,...,1404873502,False,,,,1612414000.0,[CLS] https : / / t. co / kJysqCq6UR [SEP],1,33,0
1536740,,120E4BA71617DDFA622E0263783D04B2,Photo\tPhoto,,,TopLevel,313ECD3A1E5BB07406E4249475C2D6D6,1612396800,3413E19B696E35FC1B6567C260ED2E0B,96565,...,1470684663,False,,1612397000.0,,,[CLS] https : / / t. co / r9ltA3k5k9 [SEP],1,33,0
5782609,,AC3770889BCE6879E01ACC3D675CCD5C,Photo,,,TopLevel,B0FA488F2911701DD8EC5B1EA5E322D8,1612396800,C530D3D968FBD0D08537C5EDEEDEE542,6569,...,1260577260,False,,,,1612399000.0,[CLS] Como los trata el calor.... a mi asi! ht...,1,63,0
2289432,,AC97A848CC9CE41F42D566B2C5EAAB45,Video,,,TopLevel,488B32D24BD4BB44172EB981C1BCA6FA,1612396800,E7A2CEC020385D583CA1F15C53E671DD,3600646,...,1532466558,False,,,,1612465000.0,[CLS] Never forget when Pop Smoke dropped two ...,1,118,0


In [12]:
def extract_media_cnt_tuple(media):
    if(pd.isna(media)):
        return (0,0,0)
    counter = Counter(media.split('\t'))
    
    return (counter['Photo'], counter['Video'], counter['GIF'])

def add_media_cols(df):
    media_counts = df['media'].map(lambda x: extract_media_cnt_tuple(x))
    media_types = ['Photo_cnt', 'Video_cnt', 'GIF_cnt']
    
    for i, col in enumerate(media_types):
        df.loc[:, col] = media_counts.apply(lambda counts: counts[i])
        
    return df

new_df = add_media_cols(new_df)
new_df.head()

Unnamed: 0,hashtags,tweet_id,media,links,domains,tweet_type,language,timestamp,engaged_with_user_id,engaged_with_user_follower_count,...,retweet,retweet_comment,like,text,url_cnt,char_cnt,hashtag_cnt,Photo_cnt,Video_cnt,GIF_cnt
1120596,E8EC1049F02FE3900B1E45D1BDD52BEF\tF0F29CEE3668...,20C5F3A6F47B4E7A85A1443CC3D12B0C,Video,,,TopLevel,B0FA488F2911701DD8EC5B1EA5E322D8,1612396800,E400C001A195BD92CB74AA0B2E2BB522,6460161,...,,,,"[CLS] Una vez dentro, sólo habrá una salida. ¶...",1,166,5,0,1,0
6744778,,ABC98352C6238B3129A1772122532156,GIF,,,TopLevel,313ECD3A1E5BB07406E4249475C2D6D6,1612396800,80C3FD8645A74F589C103A1F9A3C40E5,53204,...,,,1612414000.0,[CLS] https : / / t. co / kJysqCq6UR [SEP],1,33,0,0,0,1
1536740,,120E4BA71617DDFA622E0263783D04B2,Photo\tPhoto,,,TopLevel,313ECD3A1E5BB07406E4249475C2D6D6,1612396800,3413E19B696E35FC1B6567C260ED2E0B,96565,...,1612397000.0,,,[CLS] https : / / t. co / r9ltA3k5k9 [SEP],1,33,0,2,0,0
5782609,,AC3770889BCE6879E01ACC3D675CCD5C,Photo,,,TopLevel,B0FA488F2911701DD8EC5B1EA5E322D8,1612396800,C530D3D968FBD0D08537C5EDEEDEE542,6569,...,,,1612399000.0,[CLS] Como los trata el calor.... a mi asi! ht...,1,63,0,1,0,0
2289432,,AC97A848CC9CE41F42D566B2C5EAAB45,Video,,,TopLevel,488B32D24BD4BB44172EB981C1BCA6FA,1612396800,E7A2CEC020385D583CA1F15C53E671DD,3600646,...,,,1612465000.0,[CLS] Never forget when Pop Smoke dropped two ...,1,118,0,0,1,0


#### Drop columns which won't be used

In [13]:
new_df.drop(['hashtags', 'media', 'links',
              'domains', 'timestamp',
              'engaging_user_account_creation' ,'engaged_with_user_account_creation'], axis=1, inplace=True)
print(new_df.columns)

Index(['tweet_id', 'tweet_type', 'language', 'engaged_with_user_id',
       'engaged_with_user_follower_count', 'engaged_with_user_following_count',
       'engaged_with_user_is_verified', 'engaging_user_id',
       'engaging_user_follower_count', 'engaging_user_following_count',
       'engaging_user_is_verified', 'engagee_follows_engager', 'reply',
       'retweet', 'retweet_comment', 'like', 'text', 'url_cnt', 'char_cnt',
       'hashtag_cnt', 'Photo_cnt', 'Video_cnt', 'GIF_cnt'],
      dtype='object')


#### Change targets to be binary $\in$ {0,1} instead of timestamps

In [14]:
def fix_target(df, col):
    df[col].fillna(0, inplace=True)
    df[col].mask(df[col]>0, 1, inplace=True)
    
for col in target_features:
    fix_target(new_df, col)

new_df.head(2)

Unnamed: 0,tweet_id,tweet_type,language,engaged_with_user_id,engaged_with_user_follower_count,engaged_with_user_following_count,engaged_with_user_is_verified,engaging_user_id,engaging_user_follower_count,engaging_user_following_count,...,retweet,retweet_comment,like,text,url_cnt,char_cnt,hashtag_cnt,Photo_cnt,Video_cnt,GIF_cnt
1120596,20C5F3A6F47B4E7A85A1443CC3D12B0C,TopLevel,B0FA488F2911701DD8EC5B1EA5E322D8,E400C001A195BD92CB74AA0B2E2BB522,6460161,254,True,7327C6D707CE1C0F8DC376B8F6AA6B3F,49,237,...,0.0,0.0,0.0,"[CLS] Una vez dentro, sólo habrá una salida. ¶...",1,166,5,0,1,0
6744778,ABC98352C6238B3129A1772122532156,TopLevel,313ECD3A1E5BB07406E4249475C2D6D6,80C3FD8645A74F589C103A1F9A3C40E5,53204,132,False,75CC5A829BFDBECDFCDD0868EFAC04FF,79,163,...,0.0,0.0,1.0,[CLS] https : / / t. co / kJysqCq6UR [SEP],1,33,0,0,0,1


#### Change 'engagee_follows_engager', 'engaged_with_user_is_verified', 'engaging_user_is_verified' to be numeric 

In [15]:
def replace_boolean_numeric(df, column):
    df[column].replace({True: 1.0, False: 0.0}, inplace=True)
    
for col in ['engagee_follows_engager', 'engaged_with_user_is_verified', 'engaging_user_is_verified']:
    replace_boolean_numeric(new_df, col)
    
new_df.head(2)

Unnamed: 0,tweet_id,tweet_type,language,engaged_with_user_id,engaged_with_user_follower_count,engaged_with_user_following_count,engaged_with_user_is_verified,engaging_user_id,engaging_user_follower_count,engaging_user_following_count,...,retweet,retweet_comment,like,text,url_cnt,char_cnt,hashtag_cnt,Photo_cnt,Video_cnt,GIF_cnt
1120596,20C5F3A6F47B4E7A85A1443CC3D12B0C,TopLevel,B0FA488F2911701DD8EC5B1EA5E322D8,E400C001A195BD92CB74AA0B2E2BB522,6460161,254,1.0,7327C6D707CE1C0F8DC376B8F6AA6B3F,49,237,...,0.0,0.0,0.0,"[CLS] Una vez dentro, sólo habrá una salida. ¶...",1,166,5,0,1,0
6744778,ABC98352C6238B3129A1772122532156,TopLevel,313ECD3A1E5BB07406E4249475C2D6D6,80C3FD8645A74F589C103A1F9A3C40E5,53204,132,0.0,75CC5A829BFDBECDFCDD0868EFAC04FF,79,163,...,0.0,0.0,1.0,[CLS] https : / / t. co / kJysqCq6UR [SEP],1,33,0,0,0,1


#### Split dataframe to test, validation, train dataframes

In [16]:
def train_valid_test_split(df, valid_ratio=0.2, test_ratio=0.2):
    if(valid_ratio + test_ratio >= 1.0):
        raise ValueError("Invalid valid and test ratio")
    
    all_point_cnt = df.shape[0]
    val_cnt = int(valid_ratio * all_point_cnt)
    test_cnt = int(test_ratio * all_point_cnt)
    val_test_cnt = val_cnt + test_cnt
    
    train_points = df.iloc[:-val_test_cnt].copy()
    valid_points = df.iloc[-val_test_cnt:-test_cnt].copy()
    test_points = df.iloc[-test_cnt:].copy()
    
    return train_points, valid_points, test_points

train_df, valid_df, test_df = train_valid_test_split(new_df)
del new_df

In [17]:
train_engaging_users = set(train_df['engaging_user_id'].unique())
train_unique_tweets = set(train_df['tweet_id'].unique())

valid_engaging_users = set(train_df['engaging_user_id'].unique())
valid_unique_tweets = set(train_df['tweet_id'].unique())

test_engaging_users = set(train_df['engaging_user_id'].unique())
test_unique_tweets = set(train_df['tweet_id'].unique())

In [18]:
train_valid_user_diff = len(valid_engaging_users.difference(train_engaging_users))
train_valid_tweet_diff = len(valid_engaging_users.difference(train_engaging_users))

print("# new users in valid: {}".format(train_valid_user_diff))
print("# new tweets in valid: {}".format(train_valid_tweet_diff))

# new users in valid: 0
# new tweets in valid: 0


In [19]:
train_test_user_diff = len(test_engaging_users.difference(train_engaging_users))
train_test_tweet_diff = len(test_engaging_users.difference(train_engaging_users))

print("# new users in valid: {}".format(train_test_user_diff))
print("# new tweets in valid: {}".format(train_test_user_diff))

# new users in valid: 0
# new tweets in valid: 0


In [20]:
del train_engaging_users
del train_unique_tweets

del valid_engaging_users
del valid_unique_tweets

del test_engaging_users
del test_unique_tweets

#### Scale numerical features

In [21]:
train_numerical_feat = train_df.loc[:,numerical_features].values
valid_numerical_feat = valid_df.loc[:,numerical_features].values
test_numerical_feat = test_df.loc[:,numerical_features].values

In [22]:
scaler = MinMaxScaler()

train_scaled = scaler.fit_transform(train_numerical_feat)
print(train_scaled.shape)
valid_scaled = scaler.transform(valid_numerical_feat)
test_scaled = scaler.transform(test_numerical_feat)

(4178184, 10)


In [23]:
for i, feat in enumerate(numerical_features):
    train_df.loc[:,feat] = train_scaled[:,i]
    valid_df.loc[:,feat] = valid_scaled[:,i]
    test_df.loc[:,feat] = test_scaled[:,i]

train_df[numerical_features].head(2)    

Unnamed: 0,engaged_with_user_follower_count,engaged_with_user_following_count,engaging_user_follower_count,engaging_user_following_count,url_cnt,char_cnt,hashtag_cnt,Photo_cnt,Video_cnt,GIF_cnt
1120596,0.049986,6e-05,2e-05,0.001692,0.1,0.177474,0.108696,0.0,0.25,0.0
6744778,0.000412,3.1e-05,3.2e-05,0.001162,0.1,0.026166,0.0,0.0,0.0,0.5


#### Apply target encoding

In [24]:
def target_encode(train_df, valid_df, test_df, features, m):
    for target in target_features:
        te = TargetEncoder(features, m)
        
        train_df = te.fit_transform(train_df, target)
        valid_df = te.transform(valid_df, target)
        test_df = te.transform(test_df, target)
    
    return train_df, valid_df, test_df

train_df, valid_df, test_df = target_encode(train_df, valid_df, test_df, categorical_features, m)
train_df.head(2)

Unnamed: 0,tweet_id,tweet_type,language,engaged_with_user_id,engaged_with_user_follower_count,engaged_with_user_following_count,engaged_with_user_is_verified,engaging_user_id,engaging_user_follower_count,engaging_user_following_count,...,engaging_user_id_retweet_TE,tweet_type_retweet_TE,language_retweet_comment_TE,engaged_with_user_id_retweet_comment_TE,engaging_user_id_retweet_comment_TE,tweet_type_retweet_comment_TE,language_like_TE,engaged_with_user_id_like_TE,engaging_user_id_like_TE,tweet_type_like_TE
1120596,20C5F3A6F47B4E7A85A1443CC3D12B0C,TopLevel,B0FA488F2911701DD8EC5B1EA5E322D8,E400C001A195BD92CB74AA0B2E2BB522,0.049986,6e-05,1.0,7327C6D707CE1C0F8DC376B8F6AA6B3F,2e-05,0.001692,...,0.010432,0.090942,0.007492,0.011705,0.000156,0.007613,0.404614,0.127174,0.665901,0.513148
6744778,ABC98352C6238B3129A1772122532156,TopLevel,313ECD3A1E5BB07406E4249475C2D6D6,80C3FD8645A74F589C103A1F9A3C40E5,0.000412,3.1e-05,0.0,75CC5A829BFDBECDFCDD0868EFAC04FF,3.2e-05,0.001162,...,0.137603,0.090942,0.006282,0.000967,0.002331,0.007613,0.429628,0.680971,0.63683,0.513148


#### Drop categorical values

In [25]:
categorical_features.append('tweet_id')

train_df.drop(categorical_features, axis=1, inplace=True)
valid_df.drop(categorical_features, axis=1, inplace=True)
test_df.drop(categorical_features, axis=1, inplace=True)

In [26]:
def write_file(df, name):
    df.to_csv(DATA_PATH+name+".csv", index=False)
    
write_file(train_df, "train")
write_file(valid_df, "valid")
write_file(test_df, "test")

# Model training

In [4]:
## Dirty hack so we can start the notebook from here(besides imports, parameters..) 
## and don't have to split into another notebook

if ('train_df' not in locals() and 'train_df' not in globals()):
    train_df = pd.read_csv(DATA_PATH+"train.csv")

if ('valid_df' not in locals() and 'valid_df' not in globals()):
    valid_df = pd.read_csv(DATA_PATH+"valid.csv")
    
# if ('test_df' not in locals() and 'test_df' not in globals()):
#     test_df = pd.read_csv(DATA_PATH+"test.csv")
    
xlm_t_tokenizer = transformers.XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")

In [5]:
def create_dataset(df, tokenizer, numerical_features=numerical_features,
                   features=features, targets=target_features, max_len=MAX_LEN):
    all_features = numerical_features + features
    
    text = df['text'].values.tolist()
    feats = df.loc[:,all_features].values
    target_values = df.loc[:, targets].values
    
    return TweetDataset(text, feats, target_values, tokenizer, max_len)
    

train_dataset = create_dataset(train_df, xlm_t_tokenizer)
valid_dataset = create_dataset(valid_df, xlm_t_tokenizer)
#test_dataset = create_dataset(test_df, xlm_t_tokenizer)

In [6]:
config = transformers.XLMRobertaConfig().from_pretrained("xlm-roberta-base")
model = Wd_Xlm_T.from_pretrained(CHECKPOINT_DIR+"epoch_0_end", config = config,
                                dim_features=len(numerical_features + features),
                                dim_hidden=[768,512,256,128,64,32])

model = model.to(DEVICE)

In [7]:
EPOCH_COUNT = 3
BATCH_SIZE = 24

optimizer = torch.optim.Adam(model.parameters(), lr=7e-5)
loss_fn = nn.BCEWithLogitsLoss()

In [8]:
def calc_valid_loss(model, valid_loss, batch_size, loss_fn):
    model.eval()
    loss_list = []
    dataloader = DataLoader(valid_dataset, batch_size=batch_size, 
                            shuffle=False, drop_last=False)
    
    for data in tqdm(dataloader):
        input_ids = data['input_ids'].to(DEVICE)
        attention_mask = data['attention_mask'].to(DEVICE)
        features = data['features'].to(DEVICE)

        labels = data['labels'].to(DEVICE)
        
        logits = model(input_ids, attention_mask, features)
        loss = loss_fn(logits, labels)
        
        loss_list.append(loss.detach().cpu())
        
    
    
    model.train()
    return np.mean(loss_list)


def train_model(model, train_dataset, valid_datset, optimizer, loss_fn, 
                batch_size=BATCH_SIZE, epochs=EPOCH_COUNT):
    model.train()
    
    for epoch in range(3):
        dataloader = DataLoader(train_dataset, batch_size=batch_size,
                                shuffle=True, drop_last=False)
        avg_loss = []
        best_valid_loss = None
        for step, data in enumerate(tqdm(dataloader)):
            input_ids = data['input_ids'].to(DEVICE)
            attention_mask = data['attention_mask'].to(DEVICE)
            features = data['features'].to(DEVICE)

            labels = data['labels'].to(DEVICE)

            logits = model(input_ids, attention_mask, features)
            loss = loss_fn(logits, labels)

            loss.backward()
            optimizer.step()
            
            avg_loss.append(loss.detach().cpu())
            
            if(step > 0 and step % 75000 == 0):
                #train_loss = np.mean(avg_loss)
                #valid_loss = calc_valid_loss(model, valid_dataset, batch_size, loss_fn)
                #print("Epoch: {}, step: {}, train_loss: {}, valid_loss: {}"
                #      .format(epoch, step, train_loss, valid_loss))
                model.save_pretrained((CHECKPOINT_DIR+"epoch_{}_step_{}").format(epoch, step))
                
                #if(best_valid_loss is None or valid_loss < best_valid_loss):
                 #   model.save_pretrained(CHECKPOINT_DIR+"best_valid_loss")
                  #  best_valid_loss = valid_loss 

            
            optimizer.zero_grad()
            
        model.save_pretrained((CHECKPOINT_DIR+"epoch_{}_end").format(epoch))
train_model(model, train_dataset, valid_dataset, optimizer, loss_fn)

  0%|          | 0/174091 [00:00<?, ?it/s]

  0%|          | 0/174091 [00:00<?, ?it/s]

KeyboardInterrupt: 