[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Lisa-dk/LTP/blob/main/masked_model.ipynb)

In [None]:
!pip install datasets transformers

import datasets
import numpy as np
import ast
from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch

In [5]:
def labels_to_text(labels_string, label_names=None):
    """Translate a multi-hot label vector into the names of its active labels.

    Args:
        labels_string: Indexable sequence of flags (despite the name, the
            caller passes the list produced by ``ast.literal_eval``, not a
            string). Position ``i`` corresponds to ``label_names[i]``.
        label_names: Optional sequence of label names. When omitted, the
            module-level ``label_values_single`` list is used; it is looked
            up at call time, so it only has to exist before the first call.

    Returns:
        list: The names whose flag equals ``1``, in label order.

    Raises:
        IndexError: If ``labels_string`` has fewer entries than ``label_names``.
    """
    if label_names is None:
        # Fall back to the notebook-wide short label names.
        label_names = label_values_single
    return [name for i, name in enumerate(label_names) if labels_string[i] == 1]

def create_argument_1(row):
    """Render one dataset row as a two-sentence natural-language argument.

    Args:
        row: Mapping-like object (e.g. a DataFrame row) providing
            ``'Conclusion'`` and ``'Premise'`` (strings), ``'Stance'``
            (string inserted verbatim) and ``'Labels_text'`` (sequence of
            label-name strings).

    Returns:
        str: ``"The premise <premise> is <stance> the conclusion that
        <conclusion>. The human value(s) <labels> support(s) this argument."``
        Premise and conclusion are lower-cased with trailing periods removed.
    """
    conclusion = row['Conclusion'].rstrip('.').lower()
    premise = row['Premise'].rstrip('.').lower()
    stance = row['Stance']
    values = row['Labels_text']

    # Keep noun and verb in grammatical agreement with the number of labels.
    # The single-label wording matches the "[MASK]" prompt used for mask
    # prediction ("The human value [MASK] supports this argument.").
    if len(values) == 1:
        noun, verb = "value", "supports"
    else:
        noun, verb = "values", "support"

    return (
        f"The premise {premise} is {stance} the conclusion that {conclusion}. "
        f"The human {noun} {', '.join(values)} {verb} this argument."
    )

    
# Download (or load from cache) the Touche23-ValueEval dataset.
data = datasets.load_dataset('webis/Touche23-ValueEval')

# Full names of the 20 value categories, in label-vector order.
label_values = [
    "Self-direction: thought", "Self-direction: action", "Stimulation",
    "Hedonism", "Achievement", "Power: dominance", "Power: resources",
    "Face", "Security: personal", "Security: societal", "Tradition",
    "Conformity: rules", "Conformity: interpersonal", "Humility",
    "Benevolence: caring", "Benevolence: dependability",
    "Universalism: concern", "Universalism: nature",
    "Universalism: tolerance", "Universalism: objectivity",
]
# Shortened variants of the same 20 names (category prefix dropped).
label_values_single = [
    "thought", "action", "Stimulation", "Hedonism", "Achievement",
    "dominance", "resources", "Face", "personal", "societal", "Tradition",
    "rules", "interpersonal", "Humility", "caring", "dependability",
    "concern", "nature", "tolerance", "objectivity",
]

data_train = data['training'].to_pandas()

# 'Labels' stores each multi-hot vector as a string such as "[0,0,1,...]";
# parse it into a real list, then translate the active positions to names.
data_train['Labels_as_list'] = data_train['Labels'].map(ast.literal_eval)
data_train['Labels_text'] = data_train['Labels_as_list'].map(labels_to_text)

# One natural-language argument sentence per row.
data_train["Argument_1"] = data_train.apply(create_argument_1, axis=1)

data_train.head()




  0%|          | 0/3 [00:00<?, ?it/s]

Unnamed: 0,Argument ID,Conclusion,Stance,Premise,Labels,Labels_as_list,Labels_text,Argument_1
0,A01002,We should ban human cloning,in favor of,we should ban human cloning as it will only ca...,"[0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, ...",[societal],The premise we should ban human cloning as it ...
1,A01005,We should ban fast food,in favor of,fast food should be banned because it is reall...,"[0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0]","[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, ...",[personal],The premise fast food should be banned because...
2,A01006,We should end the use of economic sanctions,against,sometimes economic sanctions are the only thin...,"[0,0,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,0,0,0]","[0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, ...","[dominance, societal]",The premise sometimes economic sanctions are t...
3,A01007,We should abolish capital punishment,against,capital punishment is sometimes the only optio...,"[0,0,0,0,0,0,0,0,0,1,0,1,0,0,0,0,1,0,0,0]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, ...","[societal, rules, concern]",The premise capital punishment is sometimes th...
4,A01008,We should ban factory farming,against,factory farming allows for the production of c...,"[0,0,0,0,0,0,0,0,1,0,0,0,0,0,1,0,1,0,0,0]","[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, ...","[personal, caring, concern]",The premise factory farming allows for the pro...


# Model for mask prediction

In [11]:
# Load the pretrained DistilBERT masked-language model and its tokenizer.
model_checkpoint = "distilbert-base-uncased"
distilbert_model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)
distilbert_tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# Hand-written prompt for the first training row (A01002), with the value
# word replaced by a single [MASK] token.
text = '''The premise we should ban human cloning as it will only cause huge issues when you have a bunch of the same humans 
        running around all acting the same is in favor of the conclusion that we should ban human cloning. 
        The human value [MASK] supports this argument.'''

inputs = distilbert_tokenizer(text, return_tensors="pt")
# Inference only: disable autograd so no computation graph is built or kept
# alive for the forward pass. The logits themselves are unchanged.
with torch.no_grad():
    token_logits = distilbert_model(**inputs).logits
# Find the location of [MASK] and extract its logits.
# torch.where on the 2-D id tensor yields (row indices, column indices);
# [1] keeps the sequence positions of the mask token.
mask_token_index = torch.where(inputs["input_ids"] == distilbert_tokenizer.mask_token_id)[1]
mask_token_logits = token_logits[0, mask_token_index, :]
# Pick the [MASK] candidates with the highest logits ([0] selects the first,
# and here only, masked position).
top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()


# Expected label for this row, hard-coded for visual comparison.
print("Correct: societal")
for token in top_5_tokens:
    print(distilbert_tokenizer.decode([token]))

Correct: societal
foundation
theory
association
society
forum
