In [1]:
from tqdm import trange
from transformers import BertTokenizer
import torch
import torch.nn as nn
from transformers import AdamW
import pandas as pd
import numpy as np
from keras.preprocessing.sequence import pad_sequences
from sklearn.metrics import classification_report
from sklearn.metrics import accuracy_score

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = torch.load("../models/BERT_Classifier_Large.pt")



In [5]:
model = model.cuda()

We need to create our training dataset

In [6]:
directory = "../data/bot_detection/"

In [7]:
train = pd.read_csv(directory + "train.csv", header=None)

train = pd.DataFrame({
    'id':range(len(train)),
    'label':train[0],
    'mark':['a']*train.shape[0],
    'text': train[1].replace(r'\n', ' ', regex=True)
})

train.columns = ["index", "label", "mark", "tweet"]

In [8]:
test = pd.read_csv(directory + "test.csv", header=None)

test = pd.DataFrame({
    'id':range(len(test)),
    'label':test[0],
    'mark':['a']*test.shape[0],
    'text': test[1].replace(r'\n', ' ', regex=True)
})

test.columns =  ["index", "label", "mark", "tweet"]

## Preprocessing

In [9]:
train_sentences = train.tweet.values

In [10]:
test_sentences = test.tweet.values

We might only want those tweets with one hashtag, url or user

In [11]:
url_positions = []
hashtag_positions = []
user_positions = []
for train_sentence in train_sentences:
    train_sentence_split = train_sentence.split()
    urls = [i for i in train_sentence_split if (i.startswith('https:') or i.startswith('http:'))]
    if(len(urls)==1 and urls[0]==train_sentence_split[-1]): # If contains only a token, and this token is at the end of the tweet
        start = train_sentence.find(urls[0])
        end = start+len(urls[0])
        url_positions.append([start,end])
    else:
        url_positions.append([-1,-1])

    hashtags = [i for i in train_sentence_split if i.startswith('#')]
    if(len(hashtags)==1 and hashtags[0]==train_sentence_split[-1]):
        start = train_sentence.find(hashtags[0])
        end = start+len(hashtags[0])
        train_sentences[0][start:end]
        hashtag_positions.append([start,end])
    else:
        hashtag_positions.append([-1,-1])

    users = [i for i in train_sentence_split if i.startswith('@')]
    if(len(users)==1 and users[0]==train_sentence_split[-1]):
        start = train_sentence.find(users[0])
        end = start+len(users[0])
        train_sentences[0][start:end]
        user_positions.append([start,end])
    else:
        user_positions.append([-1,-1])

In [12]:
url_mask = np.all(np.array(url_positions)!=np.array([-1,-1]),axis=1)
np.array(url_positions)[url_mask]
train_sentences[url_mask]

array(['when you wake up thinking its friday 🙄 https://t.co/bKImYhaDAr',
       "One of the most dynamic players in the game.  Breaking down @run__cmc's game through True View highlights 💪 (via… https://t.co/Uv8XZ9FHLb",
       'News: A new restaurant and wine bar is coming to the North Shore, opening in the space where a BBQ spot had been. https://t.co/oCSHHs4YxQ',
       ..., '“WHAT’S IN THE BOX?!” https://t.co/TyVZYmPz64',
       "Philadelphia NAACP Takes Issue With DA's Ruling Against Appeal For Convicted Cop Killer Mumia Abu-Jamal.… https://t.co/VydMJm2xjp",
       'Federal Conservative leader Andrew Scheer speaks at the Energy Relaunch conference in Calgary #cdnpoli #ableg https://t.co/0cqmfyf5Gy'],
      dtype=object)

In [13]:
hashtag_mask = np.all(np.array(hashtag_positions)!=np.array([-1,-1]),axis=1)
np.array(hashtag_positions)[hashtag_mask]
train_sentences[hashtag_mask]

array(['https://t.co/JZ65RNKdYZ       Saturdays are looking a bit different this Fall. You can now catch @ESPNCFB live on @Hulu.  #ad',
       'Learn to appreciate what you have before time makes you appreciate what you had. #gratitude',
       "I am completely in love with the new song @KeshaRose it's lovely and moving. #KeshaIsBack",
       ...,
       "Oh and let's make this extra special. Tweet us what song you'd want to see us cover. #E3CoverRequest",
       'Here is a Free service on Fiverr. Claim it now! https://t.co/rZBjtmK5wz #WUVIP',
       'NewsInSG: Get HR policy right, and the opportunities are endless: Alvin Yapp: The BusAds Pte Ltd d... https://t.co/g5b8Nu2vIL #Singapore'],
      dtype=object)

In [14]:
user_mask = np.all(np.array(user_positions)!=np.array([-1,-1]),axis=1)
np.array(user_positions)[user_mask]
train_sentences[user_mask]

array(['Happy birthday family 🙏🏼 @jumpman_9',
       'Manageable Freedom?: https://t.co/52ZKAna7kU via @YouTube',
       'I just got... SO EXCITED for @VidCon!!!!!', ...,
       '#tradewar may snip 5% off China GDP  https://t.co/IJmlRLWNld  @praveenasharma3',
       'Want to win a $100 Amazon GC or PayPal Cash? https://t.co/bmE47mvATE via @DonnaChaffins',
       'Do You Feel Like You Live in a Battlefield? https://t.co/XytH4nGTS4 via @biblegateway'],
      dtype=object)

In [15]:
url_positions_test = []
hashtag_positions_test = []
user_positions_test = []
for test_sentence in test_sentences:
    test_sentence_split = test_sentence.split()
    urls = [i for i in test_sentence_split if (i.startswith('https:') or i.startswith('http:'))]
    if(len(urls)==1 and urls[0]==test_sentence_split[-1]):
        start = test_sentence.find(urls[0])
        end = start+len(urls[0])
        url_positions_test.append([start,end])
    else:
        url_positions_test.append([-1,-1])

    hashtags = [i for i in test_sentence_split if i.startswith('#')]
    if(len(hashtags)==1 and hashtags[0]==test_sentence_split[-1]):
        start = test_sentence.find(hashtags[0])
        end = start+len(hashtags[0])
        hashtag_positions_test.append([start,end])
    else:
        hashtag_positions_test.append([-1,-1])

    users = [i for i in test_sentence_split if i.startswith('@')]
    if(len(users)==1 and users[0]==test_sentence_split[-1]):
        start = test_sentence.find(users[0])
        end = start+len(users[0])
        user_positions_test.append([start,end])
    else:
        user_positions_test.append([-1,-1])

In [16]:
url_mask_test = np.all(np.array(url_positions_test)!=np.array([-1,-1]),axis=1)
np.array(url_positions_test)[url_mask_test]
test_sentences[url_mask_test]

array(['Now Playing: ♬ Dick Curless - Evil Hearted Me ♬ https://t.co/fzgP9IRt2h',
       'Not only are you comfortably swaddled in security today, it’s ... More for Capricorn https://t.co/MVCHEli4g1',
       'Do These Two Lines Match Up On Your Hands Here s What It Means https://t.co/pC5IFtVBxk',
       ..., 'Lol hey, boo. 🤗🤗 https://t.co/cmQ0125gXM',
       'When you get a fall wedding invitation in the mail: https://t.co/fwdBnOM42e',
       'CONFIRMED: Today is it for the season for one of our favorite places for black raspberry (or any) ice cream.… https://t.co/tkpt0oSn2I'],
      dtype=object)

In [17]:
hashtag_mask_test = np.all(np.array(hashtag_positions_test)!=np.array([-1,-1]),axis=1)
np.array(hashtag_positions_test)[hashtag_mask_test]
test_sentences[hashtag_mask_test]

array(["New Indie Book Release: THE UNNAMED GIRL (by Mike H Mizrahi @MikeHMiz) &gt;\xa0https://t.co/DVZgymret4 &lt; 'Power of Love' Historical #MustRead",
       'Man shot to death by EPS had rifle in stolen vehicle: ASIRT https://t.co/cZSaiP8V6i #yeg',
       'How to Delete Specific Safari History on Mac https://t.co/Jhz9QeAydF #AppleTips',
       ...,
       "Barack Obama's final words got me in tears 😭😭 https://t.co/njYcpJ6PxW #ObamaFarewell",
       'Who wants to make billions of dollars? AND . . . is a biological engineer? Please create palm trees that can live in Canada #DeciduousTrees',
       'Bangladesh: Landslides Threaten Rohingya Shelters - Human Rights Watch https://t.co/Kj7g24wxOu #Bangladesh'],
      dtype=object)

In [18]:
user_mask_test = np.all(np.array(user_positions_test)!=np.array([-1,-1]),axis=1)
np.array(user_positions_test)[user_mask_test]
test_sentences[user_mask_test]

array(['These strawberry sandwich cookies are so easy to make and so tasty! Perfect for #MothersDay https://t.co/Uq7cooR2y7 via @iamthemaven',
       'http://t.co/fVhH7zdpWU      WALMART Underground Tunnels &amp; INSIDER INFO PROMPT CLOSINGS https://t.co/s5BfH7Wlva via @YouTube',
       'LETS GOOOO!!!! Whole HEAP of world premiers on the show tonight. WHERE U AT??? &gt;&gt;&gt;&gt;&gt; @1xtra',
       ...,
       'People adopt made-up social rules to be part of a group https://t.co/om4gsHvVhD by @j_timmer',
       'The Power of Grace in Daily Life https://t.co/uwsSjcbc9r via @bay_art',
       'Writing About Family: Advice, Second Thoughts, Xanax, and a Note to My Mom https://t.co/QMm7pm1nHM @glimmertrain'],
      dtype=object)

In [20]:
url_positions_start = np.array(url_positions)[url_mask][:,0]
url_positions_end = np.array(url_positions)[url_mask][:,1]

hashtag_positions_start = np.array(hashtag_positions)[hashtag_mask][:,0]
hashtag_positions_end = np.array(hashtag_positions)[hashtag_mask][:,1]

user_positions_start = np.array(user_positions)[user_mask][:,0]
user_positions_end = np.array(user_positions)[user_mask][:,1]

In [21]:
len(user_positions_start)

6352

In [22]:
url_positions_start_test = np.array(url_positions_test)[url_mask_test][:,0]
url_positions_end_test = np.array(url_positions_test)[url_mask_test][:,1]

hashtag_positions_start_test = np.array(hashtag_positions_test)[hashtag_mask_test][:,0]
hashtag_positions_end_test = np.array(hashtag_positions_test)[hashtag_mask_test][:,1]

user_positions_start_test = np.array(user_positions_test)[user_mask_test][:,0]
user_positions_end_test = np.array(user_positions_test)[user_mask_test][:,1]

In [23]:
train_sentences_url = ["[CLS] " + "url?" + " [SEP] " + sentence + " [SEP]" for sentence in train_sentences[url_mask]]
train_sentences_hashtag = ["[CLS] " + "hashtag?" + " [SEP] " + sentence + " [SEP]" for sentence in train_sentences[hashtag_mask]]
train_sentences_user = ["[CLS] " + "user?" + " [SEP] " + sentence + " [SEP]" for sentence in train_sentences[user_mask]]

In [24]:
test_sentences_url = ["[CLS] " + "url?" + " [SEP] " + sentence + " [SEP]" for sentence in test_sentences[url_mask_test]]
test_sentences_hashtag = ["[CLS] " + "hashtag?" + " [SEP] " + sentence + " [SEP]" for sentence in test_sentences[hashtag_mask_test]]
test_sentences_user = ["[CLS] " + "user?" + " [SEP] " + sentence + " [SEP]" for sentence in test_sentences[user_mask_test]]

In [25]:
MAX_LENGTH = 128

In [26]:
train_sentences_encoded = [tokenizer.encode(sent + " [SEP]") for sent in train_sentences[url_mask]] 
prefix = tokenizer.encode("[CLS] " + "url?" + " [SEP] " )
input_ids_url = [prefix + encoded_sent  for encoded_sent in train_sentences_encoded]
len_a =len(prefix)
len_b = MAX_LENGTH-len(prefix)
token_type_ids_url = [[0 for x in range(0,len_a)]+[1 for x in range(0,len_b)] for i in range(0,len(train_sentences[url_mask]))]

train_sentences_encoded = [tokenizer.encode(sent + " [SEP]") for sent in train_sentences[hashtag_mask]] 
prefix = tokenizer.encode("[CLS] " + "hashtag?" + " [SEP] ")
input_ids_hashtag = [prefix + encoded_sent for encoded_sent in train_sentences_encoded]
len_a =len(prefix)
len_b = MAX_LENGTH-len(prefix)
token_type_ids_hashtag = [[0 for x in range(0,len_a)]+[1 for x in range(0,len_b)] for i in range(0,len(train_sentences[hashtag_mask]))]

train_sentences_encoded = [tokenizer.encode(sent + " [SEP]") for sent in train_sentences[user_mask]] 
prefix = tokenizer.encode("[CLS] " + "user?" + " [SEP] ")
input_ids_user = [prefix + encoded_sent for encoded_sent in train_sentences_encoded]
len_a =len(prefix)
len_b = MAX_LENGTH-len(prefix)
token_type_ids_user = [[0 for x in range(0,len_a)]+[1 for x in range(0,len_b)] for i in range(0,len(train_sentences[user_mask]))]

In [27]:
test_sentences_encoded = [tokenizer.encode(sent + " [SEP]") for sent in test_sentences[url_mask_test]] 
prefix = tokenizer.encode("[CLS] " + "url?" + " [SEP] " )
input_ids_url_test = [prefix + encoded_sent  for encoded_sent in test_sentences_encoded]
len_a =len(prefix)
len_b = MAX_LENGTH-len(prefix)
token_type_ids_url_test = [[0 for x in range(0,len_a)]+[1 for x in range(0,len_b)] for i in range(0,len(test_sentences[url_mask_test]))]

test_sentences_encoded = [tokenizer.encode(sent + " [SEP]") for sent in test_sentences[hashtag_mask_test]] 
prefix = tokenizer.encode("[CLS] " + "hashtag?" + " [SEP] ")
input_ids_hashtag_test = [prefix + encoded_sent for encoded_sent in test_sentences_encoded]
len_a =len(prefix)
len_b = MAX_LENGTH-len(prefix)
token_type_ids_hashtag_test = [[0 for x in range(0,len_a)]+[1 for x in range(0,len_b)] for i in range(0,len(test_sentences[hashtag_mask_test]))]

test_sentences_encoded = [tokenizer.encode(sent + " [SEP]") for sent in test_sentences[user_mask_test]] 
prefix = tokenizer.encode("[CLS] " + "user?" + " [SEP] ")
input_ids_user_test = [prefix + encoded_sent for encoded_sent in test_sentences_encoded]
len_a =len(prefix)
len_b = MAX_LENGTH-len(prefix)
token_type_ids_user_test = [[0 for x in range(0,len_a)]+[1 for x in range(0,len_b)] for i in range(0,len(test_sentences[user_mask_test]))]

In [30]:
def is_Sublist(l, s):
    i=-1
    sub_set = False
    if s == []:
        sub_set = True
    elif s == l:
        sub_set = True
    elif len(s) > len(l):
        sub_set = False

    else:
        for i in range(len(l)):
            if l[i] == s[0]:
                n = 1
                while (n < len(s)) and (l[i+n] == s[n]):
                    n += 1

                if n == len(s):
                    sub_set = True
                    break

    return sub_set,i

start and end positions are relative to the original sentence (train_sentences_url), we need to convert it to the position in input_ids_url

In [31]:
url_positions_start_ids = []
url_positions_end_ids = []
offset = len("[CLS] " + "url?" + " [SEP] ")
for i in range(0,len(train_sentences_url)):
    if(url_positions_start[i]!=-1 and url_positions_end[i]!=-1):
        url_ids = tokenizer.encode(train_sentences_url[i][offset+url_positions_start[i]:offset+url_positions_end[i]])
        result = is_Sublist(input_ids_url[i],url_ids)
        start = result[1]
        end = start+len(url_ids)
        url_positions_start_ids.append(start)
        url_positions_end_ids.append(end)
    else:
        url_positions_start_ids.append(-1)
        url_positions_end_ids.append(-1)

In [32]:
url_positions_start_ids_test = []
url_positions_end_ids_test = []
offset = len("[CLS] " + "url?" + " [SEP] ")
for i in range(0,len(test_sentences_url)):
    if(url_positions_start_test[i]!=-1 and url_positions_end_test[i]!=-1):
        url_ids = tokenizer.encode(test_sentences_url[i][offset+url_positions_start_test[i]:offset+url_positions_end_test[i]])
        result = is_Sublist(input_ids_url_test[i],url_ids)
        start = result[1]
        end = start+len(url_ids)
        url_positions_start_ids_test.append(start)
        url_positions_end_ids_test.append(end)
    else:
        url_positions_start_ids_test.append(-1)
        url_positions_end_ids_test.append(-1)

In [33]:
hashtag_positions_start_ids = []
hashtag_positions_end_ids = []
offset = len("[CLS] " + "hashtag?" + " [SEP] ")
for i in range(0,len(train_sentences_hashtag)):
    if(hashtag_positions_start[i]!=-1 and hashtag_positions_end[i]!=-1):
        hashtag_ids = tokenizer.encode(train_sentences_hashtag[i][offset+hashtag_positions_start[i]:offset+hashtag_positions_end[i]])
        result = is_Sublist(input_ids_hashtag[i],hashtag_ids)
        start = result[1]
        end = start+len(hashtag_ids)
        hashtag_positions_start_ids.append(start)
        hashtag_positions_end_ids.append(end)
    else:
        hashtag_positions_start_ids.append(-1)
        hashtag_positions_end_ids.append(-1)

In [43]:
hashtag_positions_start_ids_test = []
hashtag_positions_end_ids_test = []
offset = len("[CLS] " + "hashtag?" + " [SEP] ")
for i in range(0,len(test_sentences_hashtag)):
    if(hashtag_positions_start_test[i]!=-1 and hashtag_positions_end_test[i]!=-1):
        hashtag_ids = tokenizer.encode(test_sentences_hashtag[i][offset+hashtag_positions_start_test[i]:offset+hashtag_positions_end_test[i]])
        result = is_Sublist(input_ids_hashtag_test[i],hashtag_ids)
        start = result[1]
        end = start+len(hashtag_ids)
        hashtag_positions_start_ids_test.append(start)
        hashtag_positions_end_ids_test.append(end)
    else:
        hashtag_positions_start_ids_test.append(-1)
        hashtag_positions_end_ids_test.append(-1)

In [34]:
user_positions_start_ids = []
user_positions_end_ids = []
offset = len("[CLS] " + "user?" + " [SEP] ")
for i in range(0,len(train_sentences_user)):
    if(user_positions_start[i]!=-1 and user_positions_end[i]!=-1):
        user_ids = tokenizer.encode(train_sentences_user[i][offset+user_positions_start[i]:offset+user_positions_end[i]])
        result = is_Sublist(input_ids_user[i],user_ids)
        start = result[1]
        end = start+len(user_ids)
        user_positions_start_ids.append(start)
        user_positions_end_ids.append(end)
    else:
        user_positions_start_ids.append(-1)
        user_positions_end_ids.append(-1)

In [35]:
user_positions_start_ids_test = []
user_positions_end_ids_test = []
offset = len("[CLS] " + "user?" + " [SEP] ")
for i in range(0,len(test_sentences_user)):
    if(user_positions_start_test[i]!=-1 and user_positions_end_test[i]!=-1):
        user_ids = tokenizer.encode(test_sentences_user[i][offset+user_positions_start_test[i]:offset+user_positions_end_test[i]])
        result = is_Sublist(input_ids_user_test[i],user_ids)
        start = result[1]
        end = start+len(user_ids)
        user_positions_start_ids_test.append(start)
        user_positions_end_ids_test.append(end)
    else:
        user_positions_start_ids_test.append(-1)
        user_positions_end_ids_test.append(-1)

In [36]:
MAX_LEN = 128

In [37]:
train_input_ids_url = pad_sequences(input_ids_url, maxlen=MAX_LEN, dtype="long", truncating="post", padding="post")
train_input_ids_hashtag = pad_sequences(input_ids_hashtag, maxlen=MAX_LEN, dtype="long", truncating="post", padding="post")
train_input_ids_user = pad_sequences(input_ids_user, maxlen=MAX_LEN, dtype="long", truncating="post", padding="post")

In [38]:
test_input_ids_url = pad_sequences(input_ids_url_test, maxlen=MAX_LEN, dtype="long", truncating="post", padding="post")
test_input_ids_hashtag = pad_sequences(input_ids_hashtag_test, maxlen=MAX_LEN, dtype="long", truncating="post", padding="post")
test_input_ids_user = pad_sequences(input_ids_user_test, maxlen=MAX_LEN, dtype="long", truncating="post", padding="post")

In [41]:
train_inputs_url = torch.tensor(train_input_ids_url)
train_inputs_hashtag = torch.tensor(train_input_ids_hashtag)
train_inputs_user = torch.tensor(train_input_ids_user)

train_url_positions_start = torch.tensor(url_positions_start_ids)
train_hashtag_positions_start = torch.tensor(hashtag_positions_start_ids)
train_user_positions_start = torch.tensor(user_positions_start_ids)

train_url_positions_end = torch.tensor(url_positions_end_ids)
train_hashtag_positions_end = torch.tensor(hashtag_positions_end_ids)
train_user_positions_end = torch.tensor(user_positions_end_ids)

train_token_type_url = torch.tensor(token_type_ids_url)
train_token_type_hashtag = torch.tensor(token_type_ids_hashtag)
train_token_type_user = torch.tensor(token_type_ids_user)

In [44]:
test_inputs_url = torch.tensor(test_input_ids_url)
test_inputs_hashtag = torch.tensor(test_input_ids_hashtag)
test_inputs_user = torch.tensor(test_input_ids_user)

test_url_positions_start = torch.tensor(url_positions_start_ids_test)
test_hashtag_positions_start = torch.tensor(hashtag_positions_start_ids_test)
test_user_positions_start = torch.tensor(user_positions_start_ids_test)

test_url_positions_end = torch.tensor(url_positions_end_ids_test)
test_hashtag_positions_end = torch.tensor(hashtag_positions_end_ids_test)
test_user_positions_end = torch.tensor(user_positions_end_ids_test)

test_token_type_url = torch.tensor(token_type_ids_url_test)
test_token_type_hashtag = torch.tensor(token_type_ids_hashtag_test)
test_token_type_user = torch.tensor(token_type_ids_user_test)

In [45]:
train_inputs = torch.cat((train_inputs_url,train_inputs_hashtag, train_inputs_user),0)
train_positions_start = torch.cat((train_url_positions_start,train_hashtag_positions_start, train_user_positions_start),0)
train_positions_end = torch.cat((train_url_positions_end,train_hashtag_positions_end, train_user_positions_end),0)
train_token_type = torch.cat((train_token_type_url,train_token_type_hashtag, train_token_type_user),0)

In [46]:
test_inputs = torch.cat((test_inputs_url,test_inputs_hashtag, test_inputs_user),0)
test_positions_start = torch.cat((test_url_positions_start,test_hashtag_positions_start, test_user_positions_start),0)
test_positions_end = torch.cat((test_url_positions_end,test_hashtag_positions_end, test_user_positions_end),0)
test_token_type = torch.cat((test_token_type_url,test_token_type_hashtag, test_token_type_user),0)

### Create the generators

In [47]:
from torch.utils.data import (DataLoader, RandomSampler, TensorDataset)
batch_size = 8

In [48]:
train_data = TensorDataset(train_inputs, 
                           train_positions_start, 
                           train_positions_end,
                           train_token_type)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)

In [49]:
test_data = TensorDataset(test_inputs, 
                           test_positions_start, 
                           test_positions_end, 
                           test_token_type)
test_sampler = RandomSampler(test_data)
test_dataloader = DataLoader(test_data, sampler=test_sampler, batch_size=batch_size)

In [50]:
linear = nn.Linear(768, 2).to(torch.device("cuda:0"))

In [51]:
param_optimizer = list(linear.named_parameters())
no_decay = ['bias', 'gamma', 'beta']
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
     'weight_decay_rate': 0.01},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
     'weight_decay_rate': 0.0}
]

optimizer = AdamW(optimizer_grouped_parameters,
                     lr=2e-5,)

## Training

In [52]:
model = model.bert

In [72]:
# Store our loss and accuracy for plotting
train_loss_set = []

# Number of training epochs (authors recommend between 2 and 4)
epochs = 2

# trange is a tqdm wrapper around the normal python range
for _ in trange(epochs, desc="Epoch"):
  
  
  # Training
  
  # Set our model to training mode (as opposed to evaluation mode)
  linear.train()
    
  for step, batch in enumerate(train_dataloader):
    # Add batch to GPU
    ##batch = tuple(t.to(device) for t in batch)
    # Unpack the inputs from our dataloader
    b_inputs, b_positions_start, b_positions_end, b_token_type = batch
        
    sequence_output = model(b_inputs.to(device), token_type_ids=b_token_type.to(device))[0]
    logits = linear(sequence_output)
    start_logits, end_logits = logits.split(1, dim=-1)
    start_logits = start_logits.squeeze(-1)
    end_logits = end_logits.squeeze(-1)
    start_positions = b_positions_start.to(device)
    end_positions = b_positions_end.to(device)
    if len(start_positions.size()) > 1:
        start_positions = start_positions.squeeze(-1)
    if len(end_positions.size()) > 1:
        end_positions = end_positions.squeeze(-1)
    ignored_index = start_logits.size(1)
    start_positions.clamp_(0, ignored_index)
    end_positions.clamp_(0, ignored_index)
    loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
    start_loss = loss_fct(start_logits, start_positions )
    end_loss = loss_fct(end_logits, end_positions)
    total_loss = (start_loss + end_loss) / 2
    total_loss.backward(retain_graph=True)
    optimizer.step()

Epoch: 100%|██████████| 2/2 [1:54:50<00:00, 3444.22s/it]


## Evaluation

In [73]:
test_data_url = TensorDataset(test_inputs_url, 
                           test_token_type_url)
test_sampler_url = torch.utils.data.SequentialSampler(test_data_url)
test_dataloader_url = DataLoader(test_data_url, sampler=test_sampler_url, batch_size=batch_size)

In [74]:
linear.eval()

preds_url = []

for batch in test_dataloader_url:
    b_inputs, b_token_type = batch
    with torch.no_grad():
        sequence_output = model(b_inputs.to(device), token_type_ids=b_token_type.to(device))[0]
        logits = linear(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)
    
    start_logits = start_logits.detach().cpu().numpy()
    end_logits = end_logits.detach().cpu().numpy()
    preds_url.append([start_logits,end_logits])

In [75]:
test_data_hashtag = TensorDataset(test_inputs_hashtag, 
                           test_token_type_hashtag)
test_sampler_hashtag = torch.utils.data.SequentialSampler(test_data_hashtag)
test_dataloader_hashtag = DataLoader(test_data_hashtag, sampler=test_sampler_hashtag, batch_size=batch_size)

In [76]:
preds_hashtag = []

for batch in test_dataloader_hashtag:
    b_inputs, b_token_type = batch
    with torch.no_grad():
        sequence_output = model(b_inputs.to(device), token_type_ids=b_token_type.to(device))[0]
        logits = linear(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)
    
    start_logits = start_logits.detach().cpu().numpy()
    end_logits = end_logits.detach().cpu().numpy()
    preds_hashtag.append([start_logits,end_logits])

In [77]:
test_data_user = TensorDataset(test_inputs_user, 
                           test_token_type_user)
test_sampler_user = torch.utils.data.SequentialSampler(test_data_user)
test_dataloader_user = DataLoader(test_data_user, sampler=test_sampler_user, batch_size=batch_size)

In [78]:
preds_user = []

for batch in test_dataloader_user:
    b_inputs, b_token_type = batch
    with torch.no_grad():
        sequence_output = model(b_inputs.to(device), token_type_ids=b_token_type.to(device))[0]
        logits = linear(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)
    
    start_logits = start_logits.detach().cpu().numpy()
    end_logits = end_logits.detach().cpu().numpy()
    preds_user.append([start_logits,end_logits])

In [79]:
predictions_url_start = []
predictions_url_end = []
for preds in preds_url:
    predictions_url_start = predictions_url_start + list(np.argmax(preds[0],axis=1))
    predictions_url_end = predictions_url_end + list(np.argmax(preds[1],axis=1))

In [80]:
predictions_hashtag_start = []
predictions_hashtag_end = []
for preds in preds_hashtag:
    predictions_hashtag_start = predictions_hashtag_start + list(np.argmax(preds[0],axis=1))
    predictions_hashtag_end = predictions_hashtag_end + list(np.argmax(preds[1],axis=1))

In [81]:
predictions_user_start = []
predictions_user_end = []
for preds in preds_user:
    predictions_user_start = predictions_user_start + list(np.argmax(preds[0],axis=1))
    predictions_user_end = predictions_user_end + list(np.argmax(preds[1],axis=1))

In [56]:
print(classification_report([0 if pos<0 else pos for pos in url_positions_start_ids_test], predictions_url_start, digits = 4))
print(classification_report([0 if pos<0 else pos for pos in url_positions_end_ids_test], predictions_url_end, digits = 4))

             precision    recall  f1-score   support

          0     0.0000    0.0000    0.0000         0
          5     0.0000    0.0000    0.0000         0
          6     0.9292    1.0000    0.9633       223
          7     0.9739    0.9982    0.9859       561
          8     0.9798    0.9986    0.9891       729
          9     0.9852    0.9813    0.9832      1015
         10     0.9154    0.9966    0.9543       890
         11     0.9710    0.9950    0.9829      1008
         12     0.9821    0.9942    0.9881      1211
         13     0.9807    0.9943    0.9875      1229
         14     0.9821    0.9852    0.9836      1280
         15     0.9945    0.9731    0.9837      1300
         16     0.9910    0.9758    0.9833      1239
         17     0.9838    0.9748    0.9793      1312
         18     0.9904    0.9842    0.9873      1262
         19     0.9769    0.9831    0.9800      1245
         20     0.9911    0.9770    0.9840      1259
         21     0.9771    0.9803    0.9787   

  'recall', 'true', average, warn_for)
  'precision', 'predicted', average, warn_for)


In [57]:
print(classification_report([0 if pos<0 else pos for pos in hashtag_positions_start_ids_test], predictions_hashtag_start, digits = 4))
print(classification_report([0 if pos<0 else pos for pos in hashtag_positions_end_ids_test], predictions_hashtag_end, digits = 4))

             precision    recall  f1-score   support

          0     0.0000    0.0000    0.0000         0
          1     0.0000    0.0000    0.0000         0
          2     0.0000    0.0000    0.0000         0
          5     0.0000    0.0000    0.0000         0
          6     0.1111    1.0000    0.2000         2
          7     0.4324    0.9412    0.5926        17
          8     0.6897    0.7692    0.7273        26
          9     0.5652    0.8125    0.6667        32
         10     0.5263    0.8824    0.6593        34
         11     0.3696    0.7727    0.5000        44
         12     0.4684    0.7551    0.5781        49
         13     0.3429    0.7500    0.4706        48
         14     0.4881    0.6212    0.5467        66
         15     0.3496    0.6418    0.4526        67
         16     0.4717    0.6329    0.5405        79
         17     0.3622    0.5542    0.4381        83
         18     0.4000    0.5641    0.4681        78
         19     0.4188    0.6203    0.5000   

  'precision', 'predicted', average, warn_for)
  'recall', 'true', average, warn_for)


In [58]:
print(classification_report([0 if pos<0 else pos for pos in user_positions_start_ids_test], predictions_user_start, digits = 4))
print(classification_report([0 if pos<0 else pos for pos in user_positions_end_ids_test], predictions_user_end, digits = 4))

             precision    recall  f1-score   support

          0     0.0000    0.0000    0.0000         0
          4     0.0000    0.0000    0.0000         0
          5     0.1250    1.0000    0.2222         1
          6     0.4000    1.0000    0.5714         6
          7     0.4783    0.8800    0.6197        25
          8     0.3438    0.9167    0.5000        12
          9     0.2439    0.7692    0.3704        13
         10     0.1864    0.7333    0.2973        15
         11     0.2540    0.7273    0.3765        22
         12     0.1458    0.5000    0.2258        14
         13     0.1714    0.8571    0.2857        14
         14     0.0526    0.5000    0.0952         8
         15     0.1268    0.7500    0.2169        12
         16     0.0909    0.4615    0.1519        13
         17     0.0312    0.2000    0.0541        10
         18     0.0317    0.1818    0.0541        11
         19     0.1071    0.3750    0.1667        16
         20     0.0769    0.3333    0.1250   

  'precision', 'predicted', average, warn_for)
  'recall', 'true', average, warn_for)


In [84]:
print(classification_report([0 if pos<0 else pos for pos in user_positions_start_ids_test], predictions_user_start, digits = 4))
print(classification_report([0 if pos<0 else pos for pos in user_positions_end_ids_test], predictions_user_end, digits = 4))

             precision    recall  f1-score   support

          0     0.0000    0.0000    0.0000         0
          4     0.0000    0.0000    0.0000         0
          5     0.1250    1.0000    0.2222         1
          6     0.4000    1.0000    0.5714         6
          7     0.4783    0.8800    0.6197        25
          8     0.3333    0.9167    0.4889        12
          9     0.2439    0.7692    0.3704        13
         10     0.1864    0.7333    0.2973        15
         11     0.2540    0.7273    0.3765        22
         12     0.1458    0.5000    0.2258        14
         13     0.1714    0.8571    0.2857        14
         14     0.0526    0.5000    0.0952         8
         15     0.1268    0.7500    0.2169        12
         16     0.0923    0.4615    0.1538        13
         17     0.0312    0.2000    0.0541        10
         18     0.0312    0.1818    0.0533        11
         19     0.1071    0.3750    0.1667        16
         20     0.0769    0.3333    0.1250   

  'precision', 'predicted', average, warn_for)
  'recall', 'true', average, warn_for)


In [134]:
[(pred,start) for (pred,start) in zip(predictions_user_start, user_positions_start_ids_test)][3]

(16, 38)

In [146]:
np.sum(np.array([tokenizer.convert_ids_to_tokens(test_inputs_user[i].numpy())[predictions_user_start[i]] for i in range(0,len(test_inputs_user))])=='@')/len(test_inputs_user)

0.18555133079847907

In [125]:
tokenizer.convert_ids_to_tokens(test_inputs_user[1][31:].item())

ValueError: only one element tensors can be converted to Python scalars

In [113]:
tokenizer.convert_ids_to_tokens(test_inputs_user[0][23].item())

'https'

In [112]:
user_positions_start_ids_test[0]

40

In [41]:
print(classification_report([('0' if pos_start<0 else str(pos_start))+","+('0' if pos_end<0 else str(pos_end)) for pos_start,pos_end in zip(url_positions_start_ids_test,url_positions_end_ids_test)], [str(pos_start)+","+str(pos_end) for pos_start,pos_end in zip(predictions_url_start,predictions_url_end)], digits = 4))
print(classification_report([('0' if pos_start<0 else str(pos_start))+","+('0' if pos_end<0 else str(pos_end)) for pos_start,pos_end in zip(hashtag_positions_start_ids_test,hashtag_positions_end_ids_test)], [str(pos_start)+","+str(pos_end) for pos_start,pos_end in zip(predictions_hashtag_start,predictions_hashtag_end)], digits = 4))
print(classification_report([('0' if pos_start<0 else str(pos_start))+","+('0' if pos_end<0 else str(pos_end)) for pos_start,pos_end in zip(user_positions_start_ids_test,user_positions_end_ids_test)], [str(pos_start)+","+str(pos_end) for pos_start,pos_end in zip(predictions_user_start,predictions_user_end)], digits = 4))

             precision    recall  f1-score   support

       0,15     0.0000    0.0000    0.0000         0
       0,16     0.0000    0.0000    0.0000         0
       0,27     0.0000    0.0000    0.0000         0
       0,30     0.0000    0.0000    0.0000         0
       0,31     0.0000    0.0000    0.0000         0
       0,32     0.0000    0.0000    0.0000         0
       0,33     0.0000    0.0000    0.0000         0
       0,34     0.0000    0.0000    0.0000         0
       0,35     0.0000    0.0000    0.0000         0
       0,36     0.0000    0.0000    0.0000         0
       0,37     0.0000    0.0000    0.0000         0
       0,38     0.0000    0.0000    0.0000         0
       0,39     0.0000    0.0000    0.0000         0
        0,4     0.0000    0.0000    0.0000         0
       0,40     0.0000    0.0000    0.0000         0
       0,41     0.0000    0.0000    0.0000         0
       0,42     0.0000    0.0000    0.0000         0
       0,43     0.0000    0.0000    0.0000   

  'precision', 'predicted', average, warn_for)
  'recall', 'true', average, warn_for)


In [83]:
print(classification_report([('0' if pos_start<0 else str(pos_start))+","+('0' if pos_end<0 else str(pos_end)) for pos_start,pos_end in zip(url_positions_start_ids_test,url_positions_end_ids_test)], [str(pos_start)+","+str(pos_end) for pos_start,pos_end in zip(predictions_url_start,predictions_url_end)], digits = 4))
print(classification_report([('0' if pos_start<0 else str(pos_start))+","+('0' if pos_end<0 else str(pos_end)) for pos_start,pos_end in zip(hashtag_positions_start_ids_test,hashtag_positions_end_ids_test)], [str(pos_start)+","+str(pos_end) for pos_start,pos_end in zip(predictions_hashtag_start,predictions_hashtag_end)], digits = 4))
print(classification_report([('0' if pos_start<0 else str(pos_start))+","+('0' if pos_end<0 else str(pos_end)) for pos_start,pos_end in zip(user_positions_start_ids_test,user_positions_end_ids_test)], [str(pos_start)+","+str(pos_end) for pos_start,pos_end in zip(predictions_user_start,predictions_user_end)], digits = 4))

             precision    recall  f1-score   support

       0,15     0.0000    0.0000    0.0000         0
       0,16     0.0000    0.0000    0.0000         0
       0,27     0.0000    0.0000    0.0000         0
       0,30     0.0000    0.0000    0.0000         0
       0,31     0.0000    0.0000    0.0000         0
       0,32     0.0000    0.0000    0.0000         0
       0,34     0.0000    0.0000    0.0000         0
       0,35     0.0000    0.0000    0.0000         0
       0,36     0.0000    0.0000    0.0000         0
       0,37     0.0000    0.0000    0.0000         0
       0,39     0.0000    0.0000    0.0000         0
        0,4     0.0000    0.0000    0.0000         0
       0,40     0.0000    0.0000    0.0000         0
       0,42     0.0000    0.0000    0.0000         0
       0,46     0.0000    0.0000    0.0000         0
       0,49     0.0000    0.0000    0.0000         0
        0,8     0.0000    0.0000    0.0000         0
      10,10     0.0000    0.0000    0.0000   

  'precision', 'predicted', average, warn_for)
  'recall', 'true', average, warn_for)
