In [9]:
import torch
from torch import nn
from transformers import LongformerModel, LongformerTokenizer
from transformers import BertTokenizer, BertModel


class LFEmbeddingModule(nn.Module):
    """Pretrained encoder (Longformer or BERT) bundled with its tokenizer.

    ``args['model']`` selects the checkpoint; a name containing
    ``'longformer'`` picks the Longformer classes, anything else the BERT
    classes. The encoder is created with ``output_hidden_states=True`` and
    ``output_attentions=True`` and moved to ``device``.
    """

    # (enable-flag key, token-budget key) for each optional metadata text, in
    # the order the texts are appended after the comment.
    _METADATA_SPECS = (
        ('add_title', 'title_token_count'),
        ('add_description', 'desc_token_count'),
        ('add_transcription', 'transcript_token_count'),
        ('add_other_comments', 'other_comments_token_count'),
    )

    def __init__(self, args, device):
        """Load the encoder and tokenizer named by ``args['model']``.

        Args:
            args: configuration dict; ``'model'`` is read here, the
                remaining keys are read by :meth:`get_embeddings`.
            device: device the encoder and the produced tensors are moved to.
        """
        super(LFEmbeddingModule, self).__init__()
        self.args = args
        self.device = device
        if 'longformer' in self.args['model']:
            model_cls, tokenizer_cls = LongformerModel, LongformerTokenizer
        else:
            model_cls, tokenizer_cls = BertModel, BertTokenizer
        self.lf_model = model_cls.from_pretrained(
            self.args['model'], output_hidden_states=True, output_attentions=True
        ).to(device)
        self.lf_tokenizer = tokenizer_cls.from_pretrained(self.args['model'])

    def _encode_segment(self, text, max_length, padding):
        """Tokenize one text and return ``(input_ids, token_type_ids)`` as lists."""
        encoded = self.lf_tokenizer.encode_plus(
            text,
            max_length=max_length,
            padding=padding,
            truncation=True,
            # RoBERTa-style tokenizers (Longformer) omit token type ids unless
            # asked, which made the lookup below raise KeyError.
            return_token_type_ids=True,
        )
        return list(encoded['input_ids']), list(encoded['token_type_ids'])

    def get_embeddings(self, comments, titles, descriptions, transcripts, other_comments):
        """Tokenize each comment together with its enabled metadata texts.

        For every sample the comment is encoded first (budget
        ``args['max_len']``), then each metadata text whose ``add_*`` flag is
        set is encoded with its own token budget and appended without its
        leading special token. The concatenation is cut to ``args['max_len']``.
        Samples shorter than the longest one in the batch are right-padded
        with the tokenizer's pad id (token type 0) so the batch is
        rectangular; a single-sample batch is therefore never padded here.

        Args:
            comments, titles, descriptions, transcripts, other_comments:
                parallel sequences of strings, one entry per sample.

        Returns:
            ``(input_ids, token_type_ids)``: two long tensors on
            ``self.device``. No attention mask is returned; when batching
            samples of different lengths, callers that need one can derive it
            from the pad id.
        """
        max_len_total = self.args['max_len']
        metadata_padding = 'max_length' if self.args['pad_metadata'] else False

        batch_ids, batch_types = [], []
        for comment, *metadata in zip(comments, titles, descriptions, transcripts, other_comments):
            ids, types = self._encode_segment(comment, max_len_total, False)
            for (flag_key, budget_key), text in zip(self._METADATA_SPECS, metadata):
                if not self.args[flag_key]:
                    continue
                seg_ids, seg_types = self._encode_segment(text, self.args[budget_key], metadata_padding)
                # Drop the segment's leading special token; the comment already supplied one.
                ids.extend(seg_ids[1:])
                types.extend(seg_types[1:])
            batch_ids.append(ids[:max_len_total])
            batch_types.append(types[:max_len_total])

        # torch.tensor() rejects ragged nested lists, so equalise lengths first.
        longest = max((len(ids) for ids in batch_ids), default=0)
        pad_id = self.lf_tokenizer.pad_token_id
        if pad_id is None:
            pad_id = 0
        for ids, types in zip(batch_ids, batch_types):
            missing = longest - len(ids)
            ids.extend([pad_id] * missing)
            types.extend([0] * missing)

        indexed_cs = torch.tensor(batch_ids, dtype=torch.long).to(self.device)
        indexed_tok = torch.tensor(batch_types, dtype=torch.long).to(self.device)
        return indexed_cs, indexed_tok
    
class CommentModel(nn.Module):
    """Linear layer followed by a sigmoid, producing one score per embedding.

    The input width is 768 when ``args['model']`` contains ``'base'`` and
    1024 otherwise.
    """

    def __init__(self, args):
        """Build the head; ``args`` must provide the ``'model'`` name."""
        super().__init__()
        self.args = args
        self.fc_size = 768 if 'base' in self.args['model'] else 1024
        projection = nn.Linear(self.fc_size, 1)
        squash = nn.Sigmoid()
        # Kept as a Sequential so parameters stay addressable as ``fc.0.*``.
        self.fc = nn.Sequential(projection, squash)

    def forward(self, text_emb):
        """Return ``sigmoid(linear(text_emb))``; last dim of the input must be ``fc_size``."""
        return self.fc(text_emb)


In [10]:
# Run configuration consumed by LFEmbeddingModule and CommentModel.
# Each ``add_*`` flag switches one metadata text on; the ``*_token_count``
# next to it is the token budget used when that text is tokenized.
args = {
    "model": "bert-large-cased",
    "max_len": 512,  # cap on the concatenated token sequence
    "add_title": True,
    "title_token_count": 40,
    "add_description": True,
    "desc_token_count": 80,
    "add_transcription": True,
    "transcript_token_count": 200,
    "add_other_comments": True,
    "other_comments_token_count": 512,
    "pad_metadata": False,  # True pads every metadata text to its full budget
}

device = torch.device("cpu")
criterion = nn.BCELoss().to(device)

lf_model = LFEmbeddingModule(args, device)
comment_model = CommentModel(args).to(device)

Some weights of the model checkpoint at bert-large-cased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [11]:
import os
def load_weights(lf_model, comment_model, device,
                 checkpoint_dir='./11Jul2022_ethereal-blaze-26',
                 run_name='ethereal-blaze-26'):
    """Load saved state dicts into the encoder and the comment head.

    Two files are read from ``checkpoint_dir``:
    ``lf_model_<run_name>.pth.tar`` and ``comment_model_<run_name>.pth.tar``.
    Each must hold a dict with a ``'state_dict'`` entry. The defaults
    reproduce the paths that were previously hard-coded.

    Args:
        lf_model: wrapper whose ``lf_model`` attribute is the encoder module.
        comment_model: module receiving the comment-head weights.
        device: passed to ``torch.load`` as ``map_location``.
        checkpoint_dir: directory containing both checkpoint files.
        run_name: run identifier embedded in both file names.

    Returns:
        The same ``(lf_model, comment_model)`` objects, updated in place.

    Raises:
        FileNotFoundError: if either checkpoint file is missing.
        KeyError: if a checkpoint has no ``'state_dict'`` entry.
    """
    lf_path = os.path.join(checkpoint_dir, f'lf_model_{run_name}.pth.tar')
    comment_path = os.path.join(checkpoint_dir, f'comment_model_{run_name}.pth.tar')

    # NOTE(security): torch.load unpickles the file; only point this at
    # checkpoints from a trusted source.
    lf_checkpoint = torch.load(lf_path, map_location=device)
    comment_checkpoint = torch.load(comment_path, map_location=device)

    lf_model.lf_model.load_state_dict(lf_checkpoint['state_dict'])
    comment_model.load_state_dict(comment_checkpoint['state_dict'])
    return lf_model, comment_model

In [12]:
# Overwrite the weights of both models with the state dicts read by load_weights.
lf_model, comment_model = load_weights(lf_model, comment_model, device)

In [13]:
max_len_total=512
comments = ['"Minister Freeland" = kyke controlling your bank account.']
titles = ["Canada tells people how they can get their bank accounts unfrozen if they donated to the Trucker c.."]
descriptions = ["trucker convoy liberal shutdown anti vaccine bank accounts canada tells people movements look reacted agree unfrozen donated machine machine does antifa versus mandate protests http ones raging blm react does agree support liberal think look people versus donated trucker mandate does tells reacted tells people bank canada reacted blm antifa anti vaccine mandate protests shutdown raging machine machine protests http think ones accounts unfrozen look reacted blm support movements react anti shutdown liberal versus react anti support donated trucker convoy trucker donated agree support movements convoy mandate protests vaccine does agree convoy liberal think liberal think ones"]
transcripts = ["banks emergencies act constituents donated according rcmp rcmp unfrozen stop blockade occupation person certain account frozen mistaken use powers turn minister freeland leaders organizers protests including small action vet connection swap behle avec 50 needs speak police people trucks institutions names mp stated illegal action given financial government going concerned information financial institutions unfreeze going donor including frozen participation way occupation way account conservative use person certain wondering case mistaken ll organization communicate police organization avec person needs 50 bank accounts turn names granted frozen illegal emergencies donated small connection according rcmp given financial participation illegal blockades concerned accounts wondering granted banks small amounts communicate financial institutions blockade swap institutions concerned accounts donors ll turn unfreeze freeland rcmp act person certain account stop emergencies act government frozen wondering case trucks occupations blockades way account unfrozen speak given unfreeze accounts participation illegal going powers granted government going donor accounts frozen participation person needs including small donors people connection occupation blockade leaders bank illegal action vet occupations case certain blockades information according behle avec person minister account protests unfrozen stop blockade stated constituents act government organizers mistaken use bank accounts frozen swap protests people trucks powers granted banks stated institutions unfreeze accounts occupation person certain including granted banks emergencies vet small donors ll bank accounts institutions concerned donors names leaders organizers police frozen participation illegal frozen illegal action occupations blockades information way account occupation blockade swap behle donated financial institutions names account frozen illegal minister freeland rcmp participation freeland 
needs donor including small conservative mp institutions unfreeze small donated small amounts stop blockade occupation frozen account unfrozen stop banks emergencies mistaken financial institutions concerned organization communicate financial person government mp stated constituents blockades organization ll turn minister action illegal amounts 50 bank communicate accounts frozen wondering institutions names leaders use powers donors ll police organization communicate amounts 50 vet connection occupation wondering case"]
other_comments = [""]

input_ids, token_type_ids = lf_model.get_embeddings(comments, titles, descriptions, transcripts, other_comments)
# The attentions are only visualised, so run the forward pass without autograd
# bookkeeping. The encoder is created with output_attentions=True and the last
# element of its output is read as the attentions.
with torch.no_grad():
    attention = lf_model.lf_model(input_ids)[-1]
# sentence_b_start = token_type_ids[0].tolist().index(1)
input_id_list = input_ids[0].tolist()  # Batch index 0
tokens = lf_model.lf_tokenizer.convert_ids_to_tokens(input_id_list)

In [14]:
# Hand the attentions and their token labels to bertviz's head_view for display.
from bertviz import head_view
head_view(attention, tokens)