In [1]:
from tqdm import trange
import torch
import torch.nn as nn
from transformers import AdamW
import pandas as pd
import numpy as np
from keras.preprocessing.sequence import pad_sequences
from sklearn.metrics import classification_report
from sklearn.metrics import accuracy_score
from transformers import GPT2Tokenizer

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
special_tokens = ["_start_", "_delimiter_", "_classify_"]
tokenizer = GPT2Tokenizer.from_pretrained('gpt2', do_lower_case=True)
tokenizer.add_tokens(special_tokens)
special_tokens_ids = tokenizer.convert_tokens_to_ids(special_tokens)
model = torch.load("../models/Gpt2_Classifier_Large.pt")
model.resize_token_embeddings(len(tokenizer))



Embedding(50260, 768)

In [4]:
model = model.cuda()

We need to create our training dataset

In [5]:
directory = "../data/bot_detection/"

In [6]:
train = pd.read_csv(directory + "train.csv", header=None)

train = pd.DataFrame({
    'id':range(len(train)),
    'label':train[0],
    'mark':['a']*train.shape[0],
    'text': train[1].replace(r'\n', ' ', regex=True)
})

train.columns = ["index", "label", "mark", "tweet"]

In [7]:
test = pd.read_csv(directory + "test.csv", header=None)

test = pd.DataFrame({
    'id':range(len(test)),
    'label':test[0],
    'mark':['a']*test.shape[0],
    'text': test[1].replace(r'\n', ' ', regex=True)
})

test.columns =  ["index", "label", "mark", "tweet"]

## Preprocessing

In [8]:
train_sentences = train.tweet.values

In [9]:
test_sentences = test.tweet.values

We might only want those tweets with one hashtag, url or user

In [10]:
url_positions = []
hashtag_positions = []
user_positions = []
for train_sentence in train_sentences:
    train_sentence_split = train_sentence.split()
    urls = [i for i in train_sentence_split if (i.startswith('https:') or i.startswith('http:'))]
    if(len(urls)==1):
        start = train_sentence.find(urls[0])
        end = start+len(urls[0])
        url_positions.append([start,end])
    else:
        url_positions.append([-1,-1])

    hashtags = [i for i in train_sentence_split if i.startswith('#')]
    if(len(hashtags)==1):
        start = train_sentence.find(hashtags[0])
        end = start+len(hashtags[0])
        train_sentences[0][start:end]
        hashtag_positions.append([start,end])
    else:
        hashtag_positions.append([-1,-1])

    users = [i for i in train_sentence_split if i.startswith('@')]
    if(len(users)==1):
        start = train_sentence.find(users[0])
        end = start+len(users[0])
        train_sentences[0][start:end]
        user_positions.append([start,end])
    else:
        user_positions.append([-1,-1])

In [11]:
url_mask = np.all(np.array(url_positions)!=np.array([-1,-1]),axis=1)
np.array(url_positions)[url_mask]
train_sentences[url_mask]

array(['25 Non-Profit WordPress Themes for Charitable Organizations https://t.co/4yHghl3llR #creative #art',
       'when you wake up thinking its friday 🙄 https://t.co/bKImYhaDAr',
       "One of the most dynamic players in the game.  Breaking down @run__cmc's game through True View highlights 💪 (via… https://t.co/Uv8XZ9FHLb",
       ..., '“WHAT’S IN THE BOX?!” https://t.co/TyVZYmPz64',
       "Philadelphia NAACP Takes Issue With DA's Ruling Against Appeal For Convicted Cop Killer Mumia Abu-Jamal.… https://t.co/VydMJm2xjp",
       'Federal Conservative leader Andrew Scheer speaks at the Energy Relaunch conference in Calgary #cdnpoli #ableg https://t.co/0cqmfyf5Gy'],
      dtype=object)

In [12]:
hashtag_mask = np.all(np.array(hashtag_positions)!=np.array([-1,-1]),axis=1)
np.array(hashtag_positions)[hashtag_mask]
train_sentences[hashtag_mask]

array(['@CODWorldLeague @eUnitedgg Welcome to the 2019 CWL Pro League! #ASTROfamily https://t.co/vYhl5Zwbkc',
       'https://t.co/JZ65RNKdYZ       Saturdays are looking a bit different this Fall. You can now catch @ESPNCFB live on @Hulu.  #ad',
       'Learn to appreciate what you have before time makes you appreciate what you had. #gratitude',
       ...,
       "#BTS' Official Site has been updated with new pages: LOVE YOURSELF 結 Answer Concept Photos L and F.… https://t.co/e2WOSBUnK4",
       'Congratulations @LeoDiCaprio #Oscars https://t.co/5WLesgfnbe',
       'Hey #BTSArmy! It’s BTS Week on @ellentube! Seriously. I’m not BTSing you. @BTS_twt https://t.co/RThONr22ow  ⚡️   https://t.co/gbA7D4dMyB'],
      dtype=object)

In [13]:
user_mask = np.all(np.array(user_positions)!=np.array([-1,-1]),axis=1)
np.array(user_positions)[user_mask]
train_sentences[user_mask]

array(['@Krimlin_GG Yeah us too but if you compare us to jack links then compare Gfuel to crystal light or dxr to merax office chairs.',
       "One of the most dynamic players in the game.  Breaking down @run__cmc's game through True View highlights 💪 (via… https://t.co/Uv8XZ9FHLb",
       '@aplusk Lives were saved.', ...,
       '2-0 today as the Free Pick of the Day over 7.5 for #Rockies @ #Padres easily cashes! #freepicks #freetips https://t.co/fIVEswO44T',
       'Congratulations @LeoDiCaprio #Oscars https://t.co/5WLesgfnbe',
       '@enemybieber ok ok now we reaching reaching'], dtype=object)

In [14]:
url_positions_test = []
hashtag_positions_test = []
user_positions_test = []
for test_sentence in test_sentences:
    test_sentence_split = test_sentence.split()
    urls = [i for i in test_sentence_split if (i.startswith('https:') or i.startswith('http:'))]
    if(len(urls)==1):
        start = test_sentence.find(urls[0])
        end = start+len(urls[0])
        url_positions_test.append([start,end])
    else:
        url_positions_test.append([-1,-1])

    hashtags = [i for i in test_sentence_split if i.startswith('#')]
    if(len(hashtags)==1):
        start = test_sentence.find(hashtags[0])
        end = start+len(hashtags[0])
        hashtag_positions_test.append([start,end])
    else:
        hashtag_positions_test.append([-1,-1])

    users = [i for i in test_sentence_split if i.startswith('@')]
    if(len(users)==1):
        start = test_sentence.find(users[0])
        end = start+len(users[0])
        user_positions_test.append([start,end])
    else:
        user_positions_test.append([-1,-1])

In [15]:
url_mask_test = np.all(np.array(url_positions_test)!=np.array([-1,-1]),axis=1)
np.array(url_positions_test)[url_mask_test]
test_sentences[url_mask_test]

array(['Now Playing: ♬ Dick Curless - Evil Hearted Me ♬ https://t.co/fzgP9IRt2h',
       'Not only are you comfortably swaddled in security today, it’s ... More for Capricorn https://t.co/MVCHEli4g1',
       'These strawberry sandwich cookies are so easy to make and so tasty! Perfect for #MothersDay https://t.co/Uq7cooR2y7 via @iamthemaven',
       ..., 'Lol hey, boo. 🤗🤗 https://t.co/cmQ0125gXM',
       'When you get a fall wedding invitation in the mail: https://t.co/fwdBnOM42e',
       'CONFIRMED: Today is it for the season for one of our favorite places for black raspberry (or any) ice cream.… https://t.co/tkpt0oSn2I'],
      dtype=object)

In [16]:
hashtag_mask_test = np.all(np.array(hashtag_positions_test)!=np.array([-1,-1]),axis=1)
np.array(hashtag_positions_test)[hashtag_mask_test]
test_sentences[hashtag_mask_test]

array(['These strawberry sandwich cookies are so easy to make and so tasty! Perfect for #MothersDay https://t.co/Uq7cooR2y7 via @iamthemaven',
       '“@garthbrooks: I saw an ENTERTAINER tonight! @jtimberlake #Justinrcredible http://t.co/qqsjkEha1H -love, g” What a night. Nashville and GB.🙏',
       "New Indie Book Release: THE UNNAMED GIRL (by Mike H Mizrahi @MikeHMiz) &gt;\xa0https://t.co/DVZgymret4 &lt; 'Power of Love' Historical #MustRead",
       ...,
       "Help us raise $15,000 for disaster relief! Here's how you can donate to #TheGoodWithin: https://t.co/WykP727uN4 https://t.co/NSeHNbMYcf",
       'Who wants to make billions of dollars? AND . . . is a biological engineer? Please create palm trees that can live in Canada #DeciduousTrees',
       'Bangladesh: Landslides Threaten Rohingya Shelters - Human Rights Watch https://t.co/Kj7g24wxOu #Bangladesh'],
      dtype=object)

In [17]:
user_mask_test = np.all(np.array(user_positions_test)!=np.array([-1,-1]),axis=1)
np.array(user_positions_test)[user_mask_test]
test_sentences[user_mask_test]

array(['These strawberry sandwich cookies are so easy to make and so tasty! Perfect for #MothersDay https://t.co/Uq7cooR2y7 via @iamthemaven',
       '“@garthbrooks: I saw an ENTERTAINER tonight! @jtimberlake #Justinrcredible http://t.co/qqsjkEha1H -love, g” What a night. Nashville and GB.🙏',
       ...,
       'Goofballs write their own resumes. https://t.co/bioBQQEfdx @ibeebz2 https://t.co/tUR1G7SSB4',
       '@PrTwain Thanks for the recommendation! Any specific episodes that we should be sure to catch?',
       '@camila_melxo Yes! Please email your details to corporateorders@georgetowncupcake.com &amp; a Manager can help with the LevelUp credit!'],
      dtype=object)

In [18]:
url_positions_start = np.array(url_positions)[url_mask][:,0]
url_positions_end = np.array(url_positions)[url_mask][:,1]

hashtag_positions_start = np.array(hashtag_positions)[hashtag_mask][:,0]
hashtag_positions_end = np.array(hashtag_positions)[hashtag_mask][:,1]

user_positions_start = np.array(user_positions)[user_mask][:,0]
user_positions_end = np.array(user_positions)[user_mask][:,1]

In [19]:
url_positions_start_test = np.array(url_positions_test)[url_mask_test][:,0]
url_positions_end_test = np.array(url_positions_test)[url_mask_test][:,1]

hashtag_positions_start_test = np.array(hashtag_positions_test)[hashtag_mask_test][:,0]
hashtag_positions_end_test = np.array(hashtag_positions_test)[hashtag_mask_test][:,1]

user_positions_start_test = np.array(user_positions_test)[user_mask_test][:,0]
user_positions_end_test = np.array(user_positions_test)[user_mask_test][:,1]

In [20]:
train_sentences_url = ["_start_ " + "url?" + " _delimiter_ " + sentence + " _classify_" for sentence in train_sentences[url_mask]]
train_sentences_hashtag = ["_start_ " + "hashtag?" + " _delimiter_ " + sentence + " _classify_" for sentence in train_sentences[hashtag_mask]]
train_sentences_user = ["_start_ " + "user?" + " _delimiter_ " + sentence + " _classify_" for sentence in train_sentences[user_mask]]

In [21]:
test_sentences_url = ["_start_ " + "url?" + " _delimiter_ " + sentence + " _classify_" for sentence in test_sentences[url_mask_test]]
test_sentences_hashtag = ["_start_ " + "hashtag?" + " _delimiter_ " + sentence + " _classify_" for sentence in test_sentences[hashtag_mask_test]]
test_sentences_user = ["_start_ " + "user?" + " _delimiter_ " + sentence + " _classify_" for sentence in test_sentences[user_mask_test]]

In [22]:
MAX_LENGTH = 128

In [None]:
train_sentences_encoded = [tokenizer.encode(sent + " _classify_", add_prefix_space=True) for sent in train_sentences[url_mask]] 
prefix = tokenizer.encode("_start_ " + "url?" + " _delimiter_ " )
input_ids_url = [prefix + encoded_sent  for encoded_sent in train_sentences_encoded]
len_a =len(prefix)
len_b = MAX_LENGTH-len(prefix)
token_type_ids_url = [[0 for x in range(0,len_a)]+[1 for x in range(0,len_b)] for i in range(0,len(train_sentences[url_mask]))]

train_sentences_encoded = [tokenizer.encode(sent + " _classify_", add_prefix_space=True) for sent in train_sentences[hashtag_mask]] 
prefix = tokenizer.encode("_start_ " + "hashtag?" + " _delimiter_ ")
input_ids_hashtag = [prefix + encoded_sent for encoded_sent in train_sentences_encoded]
len_a =len(prefix)
len_b = MAX_LENGTH-len(prefix)
token_type_ids_hashtag = [[0 for x in range(0,len_a)]+[1 for x in range(0,len_b)] for i in range(0,len(train_sentences[hashtag_mask]))]

train_sentences_encoded = [tokenizer.encode(sent + " _classify_", add_prefix_space=True) for sent in train_sentences[user_mask]]
prefix = tokenizer.encode("_start_ " + "user?" + " _delimiter_ ")
input_ids_user = [prefix + encoded_sent for encoded_sent in train_sentences_encoded]
len_a =len(prefix)
len_b = MAX_LENGTH-len(prefix)
token_type_ids_user = [[0 for x in range(0,len_a)]+[1 for x in range(0,len_b)] for i in range(0,len(train_sentences[user_mask]))]

In [24]:
test_sentences_encoded = [tokenizer.encode(sent + " _classify_", add_prefix_space=True) for sent in test_sentences[url_mask_test]] 
prefix = tokenizer.encode("_start_ " + "url?" + " _delimiter_ " )
input_ids_url_test = [prefix + encoded_sent  for encoded_sent in test_sentences_encoded]
len_a =len(prefix)
len_b = MAX_LENGTH-len(prefix)
token_type_ids_url_test = [[0 for x in range(0,len_a)]+[1 for x in range(0,len_b)] for i in range(0,len(test_sentences[url_mask_test]))]

test_sentences_encoded = [tokenizer.encode(sent + " _classify_", add_prefix_space=True) for sent in test_sentences[hashtag_mask_test]] 
prefix = tokenizer.encode("_start_ " + "hashtag?" + " _delimiter_ ")
input_ids_hashtag_test = [prefix + encoded_sent for encoded_sent in test_sentences_encoded]
len_a =len(prefix)
len_b = MAX_LENGTH-len(prefix)
token_type_ids_hashtag_test = [[0 for x in range(0,len_a)]+[1 for x in range(0,len_b)] for i in range(0,len(test_sentences[hashtag_mask_test]))]

test_sentences_encoded = [tokenizer.encode(sent + " _classify_", add_prefix_space=True) for sent in test_sentences[user_mask_test]]
prefix = tokenizer.encode("_start_ " + "user?" + " _delimiter_ ")
input_ids_user_test = [prefix + encoded_sent for encoded_sent in test_sentences_encoded]
len_a =len(prefix)
len_b = MAX_LENGTH-len(prefix)
token_type_ids_user_test = [[0 for x in range(0,len_a)]+[1 for x in range(0,len_b)] for i in range(0,len(test_sentences[user_mask_test]))]

In [25]:
def is_Sublist(l, s):
    i=-1
    sub_set = False
    if s == []:
        sub_set = True
    elif s == l:
        sub_set = True
    elif len(s) > len(l):
        sub_set = False

    else:
        for i in range(len(l)):
            if l[i] == s[0]:
                n = 1
                while (n < len(s)) and (l[i+n] == s[n]):
                    n += 1

                if n == len(s):
                    sub_set = True
                    break

    return sub_set,i

start and end positions are relative to the original sentence (train_sentences_url), we need to convert it to the position in input_ids_url

In [26]:
url_positions_start_ids = []
url_positions_end_ids = []
offset = len("_start_ " + "url?" + " _delimiter_ ")
for i in range(0,len(train_sentences_url)):
    if(url_positions_start[i]!=-1 and url_positions_end[i]!=-1):
        url_ids = tokenizer.encode(train_sentences_url[i][offset+url_positions_start[i]:offset+url_positions_end[i]], add_prefix_space=True )
        result = is_Sublist(input_ids_url[i],url_ids)
        start = result[1]
        end = start+len(url_ids)
        url_positions_start_ids.append(start)
        url_positions_end_ids.append(end)
    else:
        url_positions_start_ids.append(-1)
        url_positions_end_ids.append(-1)

In [27]:
url_positions_start_ids_test = []
url_positions_end_ids_test = []
offset = len("_start_ " + "url?" + " _delimiter_ ")
for i in range(0,len(test_sentences_url)):
    if(url_positions_start_test[i]!=-1 and url_positions_end_test[i]!=-1):
        url_ids = tokenizer.encode(test_sentences_url[i][offset+url_positions_start_test[i]:offset+url_positions_end_test[i]], add_prefix_space=True)
        result = is_Sublist(input_ids_url_test[i],url_ids)
        start = result[1]
        end = start+len(url_ids)
        url_positions_start_ids_test.append(start)
        url_positions_end_ids_test.append(end)
    else:
        url_positions_start_ids_test.append(-1)
        url_positions_end_ids_test.append(-1)

In [28]:
hashtag_positions_start_ids = []
hashtag_positions_end_ids = []
offset = len("_start_ " + "hashtag?" + " _delimiter_ ")
for i in range(0,len(train_sentences_hashtag)):
    if(hashtag_positions_start[i]!=-1 and hashtag_positions_end[i]!=-1):
        hashtag_ids = tokenizer.encode(train_sentences_hashtag[i][offset+hashtag_positions_start[i]:offset+hashtag_positions_end[i]], add_prefix_space=True)
        result = is_Sublist(input_ids_hashtag[i],hashtag_ids)
        start = result[1]
        end = start+len(hashtag_ids)
        hashtag_positions_start_ids.append(start)
        hashtag_positions_end_ids.append(end)
    else:
        hashtag_positions_start_ids.append(-1)
        hashtag_positions_end_ids.append(-1)

In [29]:
hashtag_positions_start_ids_test = []
hashtag_positions_end_ids_test = []
offset = len("_start_ " + "hashtag?" + " _delimiter_ ")
for i in range(0,len(test_sentences_hashtag)):
    if(hashtag_positions_start_test[i]!=-1 and hashtag_positions_end_test[i]!=-1):
        hashtag_ids = tokenizer.encode(test_sentences_hashtag[i][offset+hashtag_positions_start_test[i]:offset+hashtag_positions_end_test[i]], add_prefix_space=True)
        result = is_Sublist(input_ids_hashtag_test[i],hashtag_ids)
        start = result[1]
        end = start+len(hashtag_ids)
        hashtag_positions_start_ids_test.append(start)
        hashtag_positions_end_ids_test.append(end)
    else:
        hashtag_positions_start_ids_test.append(-1)
        hashtag_positions_end_ids_test.append(-1)

In [30]:
user_positions_start_ids = []
user_positions_end_ids = []
offset = len("_start_ " + "user?" + " _delimiter_ ")
for i in range(0,len(train_sentences_user)):
    if(user_positions_start[i]!=-1 and user_positions_end[i]!=-1):
        user_ids = tokenizer.encode(train_sentences_user[i][offset+user_positions_start[i]:offset+user_positions_end[i]], add_prefix_space=True)
        result = is_Sublist(input_ids_user[i],user_ids)
        start = result[1]
        end = start+len(user_ids)
        user_positions_start_ids.append(start)
        user_positions_end_ids.append(end)
    else:
        user_positions_start_ids.append(-1)
        user_positions_end_ids.append(-1)

In [31]:
user_positions_start_ids_test = []
user_positions_end_ids_test = []
offset = len("_start_ " + "user?" + " _delimiter_ ")
for i in range(0,len(test_sentences_user)):
    if(user_positions_start_test[i]!=-1 and user_positions_end_test[i]!=-1):
        user_ids = tokenizer.encode(test_sentences_user[i][offset+user_positions_start_test[i]:offset+user_positions_end_test[i]], add_prefix_space=True)
        result = is_Sublist(input_ids_user_test[i],user_ids)
        start = result[1]
        end = start+len(user_ids)
        user_positions_start_ids_test.append(start)
        user_positions_end_ids_test.append(end)
    else:
        user_positions_start_ids_test.append(-1)
        user_positions_end_ids_test.append(-1)

In [32]:
MAX_LEN = 128

In [33]:
train_input_ids_url = pad_sequences(input_ids_url, maxlen=MAX_LEN, dtype="long", truncating="post", padding="post")
train_input_ids_hashtag = pad_sequences(input_ids_hashtag, maxlen=MAX_LEN, dtype="long", truncating="post", padding="post")
train_input_ids_user = pad_sequences(input_ids_user, maxlen=MAX_LEN, dtype="long", truncating="post", padding="post")

In [34]:
test_input_ids_url = pad_sequences(input_ids_url_test, maxlen=MAX_LEN, dtype="long", truncating="post", padding="post")
test_input_ids_hashtag = pad_sequences(input_ids_hashtag_test, maxlen=MAX_LEN, dtype="long", truncating="post", padding="post")
test_input_ids_user = pad_sequences(input_ids_user_test, maxlen=MAX_LEN, dtype="long", truncating="post", padding="post")

In [35]:
train_inputs_url = torch.tensor(train_input_ids_url)
train_inputs_hashtag = torch.tensor(train_input_ids_hashtag)
train_inputs_user = torch.tensor(train_input_ids_user)

train_url_positions_start = torch.tensor(url_positions_start_ids)
train_hashtag_positions_start = torch.tensor(hashtag_positions_start_ids)
train_user_positions_start = torch.tensor(user_positions_start_ids)

train_url_positions_end = torch.tensor(url_positions_end_ids)
train_hashtag_positions_end = torch.tensor(hashtag_positions_end_ids)
train_user_positions_end = torch.tensor(user_positions_end_ids)

train_token_type_url = torch.tensor(token_type_ids_url)
train_token_type_hashtag = torch.tensor(token_type_ids_hashtag)
train_token_type_user = torch.tensor(token_type_ids_user)

In [36]:
test_inputs_url = torch.tensor(test_input_ids_url)
test_inputs_hashtag = torch.tensor(test_input_ids_hashtag)
test_inputs_user = torch.tensor(test_input_ids_user)

test_url_positions_start = torch.tensor(url_positions_start_ids_test)
test_hashtag_positions_start = torch.tensor(hashtag_positions_start_ids_test)
test_user_positions_start = torch.tensor(user_positions_start_ids_test)

test_url_positions_end = torch.tensor(url_positions_end_ids_test)
test_hashtag_positions_end = torch.tensor(hashtag_positions_end_ids_test)
test_user_positions_end = torch.tensor(user_positions_end_ids_test)

test_token_type_url = torch.tensor(token_type_ids_url_test)
test_token_type_hashtag = torch.tensor(token_type_ids_hashtag_test)
test_token_type_user = torch.tensor(token_type_ids_user_test)

In [37]:
train_inputs = torch.cat((train_inputs_url,train_inputs_hashtag, train_inputs_user),0)
train_positions_start = torch.cat((train_url_positions_start,train_hashtag_positions_start, train_user_positions_start),0)
train_positions_end = torch.cat((train_url_positions_end,train_hashtag_positions_end, train_user_positions_end),0)
train_token_type = torch.cat((train_token_type_url,train_token_type_hashtag, train_token_type_user),0)

In [38]:
test_inputs = torch.cat((test_inputs_url,test_inputs_hashtag, test_inputs_user),0)
test_positions_start = torch.cat((test_url_positions_start,test_hashtag_positions_start, test_user_positions_start),0)
test_positions_end = torch.cat((test_url_positions_end,test_hashtag_positions_end, test_user_positions_end),0)
test_token_type = torch.cat((test_token_type_url,test_token_type_hashtag, test_token_type_user),0)

### Create the generators

In [39]:
from torch.utils.data import (DataLoader, RandomSampler, TensorDataset)
batch_size = 8

In [40]:
train_data = TensorDataset(train_inputs, 
                           train_positions_start, 
                           train_positions_end,
                           train_token_type)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)

In [41]:
test_data = TensorDataset(test_inputs, 
                           test_positions_start, 
                           test_positions_end, 
                           test_token_type)
test_sampler = RandomSampler(test_data)
test_dataloader = DataLoader(test_data, sampler=test_sampler, batch_size=batch_size)

In [42]:
linear = nn.Linear(768, 2).to(torch.device("cuda:0"))

In [43]:
param_optimizer = list(linear.named_parameters())
no_decay = ['bias', 'gamma', 'beta']
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
     'weight_decay_rate': 0.01},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
     'weight_decay_rate': 0.0}
]

optimizer = AdamW(optimizer_grouped_parameters,
                     lr=2e-5,)

## Training

In [44]:
model = model.transformer

In [262]:
# Store our loss and accuracy for plotting
train_loss_set = []

# Number of training epochs (authors recommend between 2 and 4)
epochs = 2

# trange is a tqdm wrapper around the normal python range
for _ in trange(epochs, desc="Epoch"):
  
  
  # Training
  
  # Set our model to training mode (as opposed to evaluation mode)
  linear.train()
    
  for step, batch in enumerate(train_dataloader):
    # Add batch to GPU
    ##batch = tuple(t.to(device) for t in batch)
    # Unpack the inputs from our dataloader
    b_inputs, b_positions_start, b_positions_end, b_token_type = batch
        
    sequence_output = model(b_inputs.to(device), token_type_ids=b_token_type.to(device))[0]
    logits = linear(sequence_output)
    start_logits, end_logits = logits.split(1, dim=-1)
    start_logits = start_logits.squeeze(-1)
    end_logits = end_logits.squeeze(-1)
    start_positions = b_positions_start.to(device)
    end_positions = b_positions_end.to(device)
    if len(start_positions.size()) > 1:
        start_positions = start_positions.squeeze(-1)
    if len(end_positions.size()) > 1:
        end_positions = end_positions.squeeze(-1)
    ignored_index = start_logits.size(1)
    start_positions.clamp_(0, ignored_index)
    end_positions.clamp_(0, ignored_index)
    loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
    start_loss = loss_fct(start_logits, start_positions )
    end_loss = loss_fct(end_logits, end_positions)
    total_loss = (start_loss + end_loss) / 2
    total_loss.backward(retain_graph=True)
    optimizer.step()

Epoch: 100%|██████████| 2/2 [5:07:33<00:00, 9019.41s/it]  


## Evaluation

In [30]:
test_data_url = TensorDataset(test_inputs_url, 
                           test_token_type_url)
test_sampler_url = torch.utils.data.SequentialSampler(test_data_url)
test_dataloader_url = DataLoader(test_data_url, sampler=test_sampler_url, batch_size=batch_size)
    

In [31]:
linear.eval()

preds_url = []

for batch in test_dataloader_url:
    b_inputs, b_token_type = batch
    with torch.no_grad():
        sequence_output = model(b_inputs.to(device), token_type_ids=b_token_type.to(device))[0]
        logits = linear(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)
    
    start_logits = start_logits.detach().cpu().numpy()
    end_logits = end_logits.detach().cpu().numpy()
    preds_url.append([start_logits,end_logits])

In [32]:
test_data_hashtag = TensorDataset(test_inputs_hashtag, 
                           test_token_type_hashtag)
test_sampler_hashtag = torch.utils.data.SequentialSampler(test_data_hashtag)
test_dataloader_hashtag = DataLoader(test_data_hashtag, sampler=test_sampler_hashtag, batch_size=batch_size)

In [33]:
preds_hashtag = []

for batch in test_dataloader_hashtag:
    b_inputs, b_token_type = batch
    with torch.no_grad():
        sequence_output = model(b_inputs.to(device), token_type_ids=b_token_type.to(device))[0]
        logits = linear(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)
    
    start_logits = start_logits.detach().cpu().numpy()
    end_logits = end_logits.detach().cpu().numpy()
    preds_hashtag.append([start_logits,end_logits])

In [34]:
test_data_user = TensorDataset(test_inputs_user, 
                           test_token_type_user)
test_sampler_user = torch.utils.data.SequentialSampler(test_data_user)
test_dataloader_user = DataLoader(test_data_user, sampler=test_sampler_user, batch_size=batch_size)

In [35]:
preds_user = []

for batch in test_dataloader_user:
    b_inputs, b_token_type = batch
    with torch.no_grad():
        sequence_output = model(b_inputs.to(device), token_type_ids=b_token_type.to(device))[0]
        logits = linear(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)
    
    start_logits = start_logits.detach().cpu().numpy()
    end_logits = end_logits.detach().cpu().numpy()
    preds_user.append([start_logits,end_logits])

In [36]:
predictions_url_start = []
predictions_url_end = []
for preds in preds_url:
    predictions_url_start = predictions_url_start + list(np.argmax(preds[0],axis=1))
    predictions_url_end = predictions_url_end + list(np.argmax(preds[1],axis=1))

In [37]:
predictions_hashtag_start = []
predictions_hashtag_end = []
for preds in preds_hashtag:
    predictions_hashtag_start = predictions_hashtag_start + list(np.argmax(preds[0],axis=1))
    predictions_hashtag_end = predictions_hashtag_end + list(np.argmax(preds[1],axis=1))

In [38]:
predictions_user_start = []
predictions_user_end = []
for preds in preds_user:
    predictions_user_start = predictions_user_start + list(np.argmax(preds[0],axis=1))
    predictions_user_end = predictions_user_end + list(np.argmax(preds[1],axis=1))

In [273]:
print(classification_report([0 if pos<0 else pos for pos in url_positions_start_ids_test], predictions_url_start, digits = 4))
print(classification_report([0 if pos<0 else pos for pos in url_positions_end_ids_test], predictions_url_end, digits = 4))

             precision    recall  f1-score   support

          4     0.0221    0.1739    0.0392        46
          5     0.4372    0.9938    0.6072       161
          6     0.6927    0.9835    0.8129       424
          7     0.7804    0.9608    0.8612       688
          8     0.8167    0.9776    0.8900       939
          9     0.7660    0.9308    0.8404       925
         10     0.7017    0.9067    0.7911      1157
         11     0.8078    0.9095    0.8556      1248
         12     0.8392    0.9055    0.8711      1481
         13     0.8196    0.7837    0.8013      1600
         14     0.8443    0.8670    0.8555      1451
         15     0.8208    0.8583    0.8392      1553
         16     0.8614    0.8321    0.8465      1531
         17     0.8261    0.8606    0.8430      1435
         18     0.8339    0.8618    0.8476      1433
         19     0.8420    0.8757    0.8585      1528
         20     0.8431    0.8685    0.8556      1491
         21     0.8291    0.8630    0.8457   

  'precision', 'predicted', average, warn_for)
  'recall', 'true', average, warn_for)


In [274]:
print(classification_report([0 if pos<0 else pos for pos in hashtag_positions_start_ids_test], predictions_hashtag_start, digits = 4))
print(classification_report([0 if pos<0 else pos for pos in hashtag_positions_end_ids_test], predictions_hashtag_end, digits = 4))

             precision    recall  f1-score   support

          5     0.9645    0.9561    0.9603      1504
          6     0.8702    0.9105    0.8899       324
          7     0.7971    0.7407    0.7679       297
          8     0.8411    0.6330    0.7224       485
          9     0.8223    0.7259    0.7711       478
         10     0.7603    0.7433    0.7517       448
         11     0.8208    0.7260    0.7705       511
         12     0.7627    0.7095    0.7351       444
         13     0.7946    0.7221    0.7566       493
         14     0.7475    0.6823    0.7135       447
         15     0.7371    0.7207    0.7288       537
         16     0.7427    0.6620    0.7000       497
         17     0.6674    0.6916    0.6793       441
         18     0.6875    0.6336    0.6595       434
         19     0.6447    0.6496    0.6471       391
         20     0.6303    0.6382    0.6342       398
         21     0.6188    0.6579    0.6378       380
         22     0.5677    0.6751    0.6168   

  'precision', 'predicted', average, warn_for)
  'recall', 'true', average, warn_for)


In [276]:
print(classification_report([0 if pos<0 else pos for pos in user_positions_start_ids_test], predictions_user_start, digits = 4))
print(classification_report([0 if pos<0 else pos for pos in user_positions_end_ids_test], predictions_user_end, digits = 4))

             precision    recall  f1-score   support

          4     0.9893    0.9584    0.9736     16405
          5     0.8554    0.7552    0.8022       572
          6     0.8496    0.7727    0.8093       607
          7     0.9029    0.7124    0.7964       692
          8     0.8582    0.6243    0.7228       543
          9     0.8000    0.6508    0.7177       461
         10     0.8548    0.7056    0.7731       676
         11     0.7795    0.5996    0.6778       507
         12     0.7661    0.6000    0.6730       475
         13     0.7221    0.6192    0.6667       428
         14     0.5995    0.5570    0.5774       395
         15     0.5808    0.5938    0.5873       357
         16     0.5846    0.5000    0.5390       304
         17     0.5877    0.5201    0.5518       348
         18     0.5878    0.5705    0.5790       305
         19     0.5426    0.5639    0.5531       305
         20     0.5390    0.5579    0.5483       285
         21     0.5114    0.5253    0.5182   

  'precision', 'predicted', average, warn_for)
  'recall', 'true', average, warn_for)


In [39]:
print(classification_report([('0' if pos_start<0 else str(pos_start))+","+('0' if pos_end<0 else str(pos_end)) for pos_start,pos_end in zip(url_positions_start_ids_test,url_positions_end_ids_test)], [str(pos_start)+","+str(pos_end) for pos_start,pos_end in zip(predictions_url_start,predictions_url_end)], digits = 4))
print(classification_report([('0' if pos_start<0 else str(pos_start))+","+('0' if pos_end<0 else str(pos_end)) for pos_start,pos_end in zip(hashtag_positions_start_ids_test,hashtag_positions_end_ids_test)], [str(pos_start)+","+str(pos_end) for pos_start,pos_end in zip(predictions_hashtag_start,predictions_hashtag_end)], digits = 4))
print(classification_report([('0' if pos_start<0 else str(pos_start))+","+('0' if pos_end<0 else str(pos_end)) for pos_start,pos_end in zip(user_positions_start_ids_test,user_positions_end_ids_test)], [str(pos_start)+","+str(pos_end) for pos_start,pos_end in zip(predictions_user_start,predictions_user_end)], digits = 4))

             precision    recall  f1-score   support

      10,10     0.0000    0.0000    0.0000         0
      10,11     0.0000    0.0000    0.0000         0
      10,12     0.0000    0.0000    0.0000         0
      10,13     0.0000    0.0000    0.0000         0
      10,14     0.0000    0.0000    0.0000         0
      10,15     0.0000    0.0000    0.0000         0
      10,16     0.0000    0.0000    0.0000         0
      10,17     0.0000    0.0000    0.0000         0
      10,18     0.0000    0.0000    0.0000         0
      10,19     0.0000    0.0000    0.0000         0
      10,20     0.0000    0.0000    0.0000         0
      10,21     0.2857    0.4444    0.3478         9
      10,22     0.8659    0.7172    0.7845        99
      10,23     0.9672    0.7361    0.8360       360
      10,24     0.9743    0.7967    0.8766       428
      10,25     0.9265    0.8591    0.8915       220
      10,26     0.6250    0.7500    0.6818        40
      10,27     0.0000    0.0000    0.0000   

  'precision', 'predicted', average, warn_for)
  'recall', 'true', average, warn_for)
