In [1]:
import torch
import pandas as pd
import numpy as np
import re
from transformers import PreTrainedTokenizerFast
import unidecode
import yaml
from typing import List

from src.architectures.Classifier import *
from src.datamodules.LSTMDataModule import *
from src.datamodules.LSTMDataModule import *

In [2]:
# Load the tokenizer and register its padding / masking special tokens.
# NOTE(review): the saved checkpoint is a GPT2Tokenizer, so transformers warns about
# loading it through PreTrainedTokenizerFast (see the cell output).
TOKENIZER_DIR = './tokenizers/gpt2_2k'
SPECIAL_TOKENS = {'pad_token': '[PAD]', 'mask_token': '[MASK]'}
tokenizer = PreTrainedTokenizerFast.from_pretrained(TOKENIZER_DIR)
# add_special_tokens returns the number of tokens newly added to the vocabulary;
# a result of 0 means '[PAD]' and '[MASK]' were already part of it.
tokenizer.add_special_tokens(SPECIAL_TOKENS)

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'GPT2Tokenizer'. 
The class this function is called from is 'PreTrainedTokenizerFast'.


0

In [3]:
# Load the model config and point it at the tokenizer's special-token ids.
with open('./configs/lstm.yaml', 'r', encoding='utf-8') as in_file:
    # safe_load only builds plain YAML types (dicts, lists, scalars) instead of arbitrary
    # Python objects. Assumes lstm.yaml uses no python/* tags — TODO confirm.
    cfg = yaml.safe_load(in_file)
arch_cfg = cfg['model']['architecture']
# Same ids as tokenizer.vocab['[PAD]'] / ['[MASK]'] after add_special_tokens registered them,
# without rebuilding the whole vocab dict.
arch_cfg['pad_id'] = tokenizer.pad_token_id
arch_cfg['mask_id'] = tokenizer.mask_token_id

In [4]:
# Clean & normalize tweet text before tokenization.
# NOTE(review): the model was presumably trained on text passed through this same
# function; any change to its output shifts the model's inputs — keep it in sync with training.

# Literal two-character escapes ("\\n", "\\t") and real tab characters, all mapped to a space.
_CLEAN_WHITESPACE_TOKENS = ('\\n', '\\t', '\t')
# HTML-escaped sequences that are deleted outright (order matters: applied left to right).
_CLEAN_HTML_ENTITIES = ("&quot;", ':&lt;', ':&gt;', '&amp;', '-&lt;', '-&gt;', '=&lt;', '=&gt;', 's&lt;', 's&gt;')
_CLEAN_URL_RE = re.compile(r'https?:\/\/[a-z.\/A-Z\d]*')
_CLEAN_DOTCOM_RE = re.compile(r"\ [A-Za-z]*\.com")
_CLEAN_MENTION_RE = re.compile(r"@\S+")
_CLEAN_LETTER_RUN_RE = re.compile(r"([A-Za-z])\1{1,}")
_CLEAN_PUNCT_RUN_RE = re.compile(r"([\s.,\/#!$%^&*?;:{}=_`()+-])\1{1,}")
_CLEAN_SPACE_RUN_RE = re.compile(r" {2,}")


def clean(text: str) -> str:
    """Lower-case and normalize a raw tweet for tokenization.

    Steps, in order:
      1. lower-case; map literal ``\\n`` / ``\\t`` escapes and real tabs to spaces;
         rejoin ``". com"`` into ``".com"``;
      2. replace ``http(s)://`` URLs and `` <word>.com`` domains with a space, drop
         ``@mentions`` and any stray ``@``;
      3. transliterate to ASCII with ``unidecode``; turn ``#`` into ``_``;
      4. delete a fixed set of HTML-escaped sequences (``&quot;``, ``&amp;``, ...);
      5. squeeze runs of a repeated letter to two (``soooo`` -> ``soo``) and runs of a
         repeated whitespace/punctuation character to one (``!!!`` -> ``!``);
      6. lower-case again and trim surrounding whitespace.

    Args:
        text: raw tweet text.

    Returns:
        The normalized string (``""`` for empty or all-whitespace input).
    """
    t = text.lower()
    for token in _CLEAN_WHITESPACE_TOKENS:
        t = t.replace(token, ' ')
    t = t.replace('. com', '.com')
    t = _CLEAN_URL_RE.sub(' ', t)
    t = _CLEAN_DOTCOM_RE.sub(' ', t)
    t = _CLEAN_MENTION_RE.sub('', t)
    t = t.replace('@', '')
    t = unidecode.unidecode(t)
    t = t.replace('#', '_')
    for entity in _CLEAN_HTML_ENTITIES:
        t = t.replace(entity, '')
    t = _CLEAN_LETTER_RUN_RE.sub(r"\1\1", t)
    t = _CLEAN_PUNCT_RUN_RE.sub(r'\1', t)
    # Collapse any leftover run of spaces to a single space — never to '', which would
    # glue the neighbouring words together.
    t = _CLEAN_SPACE_RUN_RE.sub(' ', t)
    # Second lower-case: presumably to re-normalize uppercase ASCII that unidecode may emit.
    t = t.lower()
    return t.strip()

In [5]:
# Turn a batch of (cleaned) tweets into one padded tensor of token ids.
def collate_fn(batch, tokenizer):
    """Tokenize ``batch`` with padding and return the ids with the first two axes swapped.

    The tokenizer's ``input_ids`` (one row per text, padded to the longest one) are
    transposed so the result is indexed as ``[position, text]`` — presumably the
    sequence-first layout the LSTM encoder expects; TODO confirm.

    Args:
        batch: iterable of strings to tokenize.
        tokenizer: callable tokenizer returning a mapping with an ``'input_ids'`` tensor.

    Returns:
        The transposed ``input_ids`` tensor.
    """
    texts = list(batch)
    token_ids = tokenizer(texts, padding=True, return_tensors='pt')['input_ids']
    return token_ids.transpose(1, 0)

In [6]:
# Build the classifier from the config, load the trained weights and switch to eval mode.
model = Classifier(cfg['model']['architecture'])
# map_location='cpu' lets a checkpoint saved from GPU tensors load on a CPU-only machine;
# the model is not moved to a GPU anywhere in this notebook.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (consider weights_only=True if the installed torch supports it).
state_dict = torch.load('./models/lstm_128.pt', map_location='cpu')
model.load_state_dict(state_dict)
# eval() turns off dropout for inference; its return value (the model) is the cell output.
model.eval()

Classifier(
  (encoder): LstmEncoder(
    (embs): Embedding(2000, 128)
    (embs_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (embs_dp): Dropout(p=0.1, inplace=False)
    (lstm): LSTM(128, 128)
  )
  (fc): ClassifierHead(
    (dense1): Linear(in_features=128, out_features=128, bias=True)
    (activation): Tanh()
    (dp): Dropout(p=0.1, inplace=False)
    (dense2): Linear(in_features=128, out_features=1, bias=True)
  )
)

In [7]:
def predict(data: List[str]) -> List[tuple]:
    """Score a list of raw tweets with the global ``model``.

    Each tweet is normalized with ``clean``, batched with ``collate_fn`` using the
    global ``tokenizer``, run through the model, and the flattened outputs are passed
    through a sigmoid.

    Relies on ``model.eval()`` having been called (done when the model is loaded);
    ``torch.inference_mode`` only disables autograd, not dropout.

    Args:
        data: raw tweet strings.

    Returns:
        A list of ``(tweet, score)`` tuples in input order, pairing each ORIGINAL
        (uncleaned) tweet with its sigmoid score. Empty input returns ``[]``.
    """
    # Nothing to score; avoids calling the tokenizer/model on an empty batch.
    if len(data) == 0:
        return []
    data_clean = [clean(t) for t in data]
    with torch.inference_mode():
        # Call the module itself rather than .forward() so any registered hooks run.
        logits = model(collate_fn(data_clean, tokenizer))
        scores = torch.sigmoid(logits.flatten())
    return list(zip(data, scores.tolist()))

In [8]:
# A few sample tweets to sanity-check the whole pipeline (clean -> tokenize -> model).
data: List[str] = [
    "I just created my first LaTeX file from scratch. That didn't work out very well. (See @amandabittner , it's a great time waster)",
    "AHH YES LOL IMA TELL MY HUBBY TO GO GET ME SUM MCDONALDS =]",
    "RT @shrop: Awesome JQuery reference book for Coda! http://www.macpeeps.com/coda/ #webdesign",
]

In [9]:
# Score the sample tweets; the list of (tweet, score) pairs is the cell output.
predictions = predict(data)
predictions

[("I just created my first LaTeX file from scratch. That didn't work out very well. (See @amandabittner , it's a great time waster)",
  0.25448372960090637),
 ('AHH YES LOL IMA TELL MY HUBBY TO GO GET ME SUM MCDONALDS =]',
  0.8318528532981873),
 ('RT @shrop: Awesome JQuery reference book for Coda! http://www.macpeeps.com/coda/ #webdesign',
  0.9800198078155518)]