# Initialisation

In [7]:
#!pip install torch transformers sentencepiece

## Imports

In [8]:
import torch
from transformers import BertTokenizer, BertModel, BertForMaskedLM, AutoModelForCausalLM , AutoTokenizer, pipeline, RobertaTokenizer, RobertaForMaskedLM, AlbertTokenizer, AlbertModel, AlbertForMaskedLM
import logging
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import re
from afinn import Afinn
from tqdm import tqdm
import unidecode
from time import sleep
# Optional: configure the root logger to emit INFO-level messages.
logging.basicConfig(level=logging.INFO)

## Device settings (Apple MPS / CUDA / CPU)

In [9]:
print(f"PyTorch version: {torch.__version__}")

# Select the compute device, preferring Apple MPS, then CUDA, then CPU.
# `device` is always a torch.device (the MPS branch used to yield a bare str),
# and the MPS backend is looked up defensively because torch builds that
# predate it do not expose `torch.backends.mps`.
_mps_backend = getattr(torch.backends, "mps", None)
if _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

PyTorch version: 1.13.1
Using device: mps


## Definitions

In [10]:
# Identity categories and pronoun labels
IDENTITIES = 'identities'
ORIENTATION = 'orientation'
OTHER = 'other'
PRONOUNS = 'pronouns'

# Mask tokens. MASKBERT_ is the regex-escaped form of MASKBERT; it is a raw
# string so that `\[` is not treated as an (invalid) string escape sequence.
MASKBERT_ = r'\[MASK\]'
MASKBERT= '[MASK]'
MASKROBERT = '<mask>'

# Placeholders used inside the template sentences
TARGET = '<target>'
NOM = '<nom>'
POSDEP = '<pos_dep>'
ACC = '<acc>'
BE = '<be>'

# Group labels
QUEER = 'queer'
NONQUEER = 'non-queer'

#TEMPLATES
TEMPLATE_NOZZA = '../src/templates/template_nozza.csv'
TEMPLATE_NOZZA_COMPLETE = '../src/templates/template_nozza_complete.csv'
TEMPLATE_TOXIC2 = '../src/templates/template_toxic2.csv'
TEMPLATE_TOXIC2_COMPLETE = '../src/templates/template_toxic2_complete.csv'
TEMPLATE_TOXIC1 = '../src/templates/template_toxic1.csv'
TEMPLATE_TOXIC1_COMPLETE = '../src/templates/template_toxic1_complete.csv'
TEMPLATE_TOXIC1_CHUNK = '../src/templates/toxic1/template_toxic1'
TEMPLATE_TOXIC2_CHUNK = '../src/templates/toxic2/template_toxic2'
PREDICTION_PATH = "../src/prediction/"

#IDENTITIES CSV
IDENTITIES_CSV = '../src/queer_identities/identities.csv'
PRONOUNS_CSV = '../src/queer_identities/pronouns.csv'

#MODELS
BERT_BASE = 'bert-base-uncased'
BERT_LARGE = 'bert-large-uncased'
ROBERTA_BASE = 'roberta-base'
ROBERTA_LARGE = 'roberta-large'
GPT2 = 'gpt2'

# Template Prediction class

In [11]:
class TemplatePrediction:
    """Complete template sentences with a pretrained language model.

    Constructing an instance loads the model/tokenizer pair for `model_name`
    and immediately runs the prediction routine that matches `template_path`.

    Parameters
    ----------
    model_name : str
        One of BERT_BASE, BERT_LARGE, ROBERTA_BASE, ROBERTA_LARGE or GPT2.
    template_path : str
        TEMPLATE_NOZZA_COMPLETE or TEMPLATE_TOXIC1_COMPLETE; any other value
        loads the model but runs no prediction (a warning is logged).
    numAtt : int
        Number of top-ranked tokens to keep for every sentence.

    Raises
    ------
    ValueError
        If `model_name` is not one of the supported models.
    """

    # Number of chunk files read by template_toxic1 (indices 0..19).
    TOXIC1_NUM_CHUNKS = 20

    def __init__(self, model_name, template_path, numAtt):
        self.numAtt = numAtt
        self.template_path = template_path
        self.model_name = model_name
        self.model, self.tokenizer = self.get_tokenizer()
        self.template_prediction()

    def get_tokenizer(self):
        """Load and return the `(model, tokenizer)` pair for `self.model_name`.

        Raises ValueError for an unsupported model name instead of failing
        later with an unbound local variable.
        """
        if self.model_name in (BERT_BASE, BERT_LARGE):
            model = BertForMaskedLM.from_pretrained(self.model_name)
            tokenizer = BertTokenizer.from_pretrained(self.model_name)
        elif self.model_name in (ROBERTA_BASE, ROBERTA_LARGE):
            model = RobertaForMaskedLM.from_pretrained(self.model_name)
            tokenizer = RobertaTokenizer.from_pretrained(self.model_name)
        elif self.model_name == GPT2:
            model = AutoModelForCausalLM.from_pretrained(self.model_name)
            tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        else:
            raise ValueError("Unsupported model name: {!r}".format(self.model_name))
        return model, tokenizer

    def template_prediction(self):
        """Run the prediction routine associated with `self.template_path`."""
        if self.template_path == TEMPLATE_NOZZA_COMPLETE:
            self.template_nozza()
        elif self.template_path == TEMPLATE_TOXIC1_COMPLETE:
            self.template_toxic1()
        else:
            # Kept non-fatal on purpose: the instance stays usable for direct
            # calls to model_prediction().
            logging.warning("No prediction routine for template path %r", self.template_path)

    def _predict_rows(self, template):
        """Return one list of predicted tokens per row of `template`.

        Sentences are read from the 'new_template' column.
        """
        prediction = []
        for index, row in tqdm(template.iterrows(), total=template.shape[0], desc='Predicting mask', unit='sentences'):
            prediction.append(self.model_prediction(row.loc['new_template']))
        return prediction

    def template_nozza(self):
        """Predict every sentence of the template CSV, display and save the result.

        The frame with its new 'prediction' column is written to
        PREDICTION_PATH/<model_name>/template_nozza.csv (';'-separated).
        """
        import os  # local import: only needed to create the output directory

        template = pd.read_csv(self.template_path, sep=";")
        template.loc[:,'prediction'] = self._predict_rows(template)
        display(template)
        # Make sure the per-model output directory exists before writing.
        os.makedirs(os.path.join(PREDICTION_PATH, self.model_name), exist_ok=True)
        template.to_csv(PREDICTION_PATH+self.model_name+"/template_nozza.csv", sep=';')

    def template_toxic1(self):
        """Predict and display every chunk file of the toxic1 template.

        Predictions are collected per chunk so that their number always matches
        the rows of the chunk they are attached to.
        """
        for i in tqdm(range(self.TOXIC1_NUM_CHUNKS), total=self.TOXIC1_NUM_CHUNKS, desc='Chunk number', unit='chunk'):
            # NOTE(review): read with the default ',' separator, whereas
            # template_nozza uses ';' -- confirm the chunk files' format.
            template = pd.read_csv(TEMPLATE_TOXIC1_CHUNK+'_chunk{}.csv'.format(i))
            template.loc[:,'prediction'] = self._predict_rows(template)
            display(template)

    def model_prediction(self, text):
        """Return the top `self.numAtt` predicted tokens for `text`.

        Dispatches on `self.model_name`; raises ValueError when the model is
        not supported.
        """
        if self.model_name in (BERT_BASE, BERT_LARGE):
            return self.bert_prediction(text)
        if self.model_name in (ROBERTA_BASE, ROBERTA_LARGE):
            return self.roberta_prediction(text)
        if self.model_name == GPT2:
            return self.gpt2_prediction(text)
        raise ValueError("Unsupported model name: {!r}".format(self.model_name))

    def _masked_lm_top_tokens(self, text, mask_token):
        """Return the `self.numAtt` most probable tokens at the mask position.

        `text` must already contain the model's special tokens; a ValueError
        is raised (by list.index) when `mask_token` is absent from the
        tokenized text.
        """
        tokenized_text = self.tokenizer.tokenize(text)
        masked_index = tokenized_text.index(mask_token)
        indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokenized_text)
        tokens_tensor = torch.tensor([indexed_tokens])
        with torch.no_grad():
            predictions = self.model(tokens_tensor)[0]

        probs = torch.nn.functional.softmax(predictions[0, masked_index], dim=-1)
        _, top_k_indices = torch.topk(probs, self.numAtt, sorted=True)
        return list(self.tokenizer.convert_ids_to_tokens(top_k_indices.tolist()))

    def bert_prediction(self, text):
        """Return the top tokens BERT predicts for the [MASK] slot of `text`."""
        return self._masked_lm_top_tokens("[CLS] %s [SEP]" % text, MASKBERT)

    def roberta_prediction(self, text):
        """Return the top tokens RoBERTa predicts for the masked slot of `text`.

        The BERT-style mask in `text` is rewritten to RoBERTa's mask token, and
        the 'Ġ' marker is removed from the predicted tokens.
        """
        text = re.sub(MASKBERT_, MASKROBERT, text)
        tokens = self._masked_lm_top_tokens("<s> %s </s>" % text, MASKROBERT)
        return [re.sub('Ġ', '', token) for token in tokens]

    def gpt2_prediction(self, text):
        """Return the `self.numAtt` highest-scoring next tokens after `text`."""
        inputs = self.tokenizer.encode(text, return_tensors="pt")
        with torch.no_grad():
            predictions = self.model(inputs)[0]
        # Scores of the position following the last input token.
        next_token_scores = predictions[0, -1, :]
        topk_indexes = torch.topk(next_token_scores, self.numAtt).indices.tolist()
        return [self.tokenizer.decode([idx]).strip() for idx in topk_indexes]

In [12]:
# Run bert-base-uncased on the complete Nozza template, keeping the single
# top-ranked token per sentence.
TemplatePrediction(
    model_name=BERT_BASE,
    template_path=TEMPLATE_NOZZA_COMPLETE,
    numAtt=1,
)