In [11]:
from sma_tools import flatten_mf_dict, MFDataset
from torch.utils.data import DataLoader
import torch
from transformers import BertModel, BertTokenizer
import jsonpickle
from emfdscore.load_mfds import * 
import numpy as np
# Select the compute device used by every later cell under the name
# `mps_device`. The name is kept for compatibility with the cells below, even
# though it may end up pointing at the CPU.
if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
    # Allocate a tiny tensor on the device as a smoke test and show it.
    x = torch.ones(1, device=mps_device)
    print (x)
else:
    # Without this fallback `mps_device` would stay undefined and the model
    # loading / inference cells would fail with a NameError on machines
    # that have no MPS backend.
    mps_device = torch.device("cpu")
    print ("MPS device not found.")

tensor([1.], device='mps:0')


## Run sentences in voice_dict through BERT base uncased

In [2]:
# Data loading
# `key_name` is the per-entry field of `voice_dict` that is handed to
# `flatten_mf_dict` here and looked up again when the vectors are aggregated.
key_name = 'mfd_bert_uncased'
file_path = './data/pickle_240227/voice_mbu.pickle'
# The file is decoded with jsonpickle (i.e. it is text, despite the .pickle
# suffix), so read it with an explicit encoding instead of the platform
# default. Assumes the file was written as UTF-8/ASCII JSON -- TODO confirm.
# NOTE(review): jsonpickle.decode can instantiate arbitrary Python objects;
# only load files that come from a trusted source.
with open(file_path, 'r', encoding='utf-8') as f:
    voice_dict = jsonpickle.decode(f.read())
# flatten_mf_dict returns the list of keys and a flat list of sentences;
# presumably the sentences are ordered key by key, which the aggregation
# cell relies on -- verify against sma_tools.
key_list, voice_sent_flatten = flatten_mf_dict(voice_dict, key_name)

# Model specs
model_name = 'bert-base-uncased'
model = BertModel.from_pretrained(model_name).to(mps_device)
tokenizer = BertTokenizer.from_pretrained(model_name)
batch_size = 16

In [3]:
# Number of flattened sentences that will be fed through BERT.
len(voice_sent_flatten)

47548

In [4]:
# Compute the contextual representation of mfw
# Each sentence is pushed through BERT once and its last-layer hidden states
# are stored in `sents_contrepr`, one 2-D array per sentence, in the order the
# (non-shuffling) DataLoader yields them.
voice_dataset = MFDataset(voice_sent_flatten)
data_loader = DataLoader(voice_dataset, batch_size=batch_size)

sents_contrepr = []
for sents_batch in data_loader: # ~7.5 mins for the voice dict
    # No special tokens are added: the CLS token is not needed in sma
    # construction. Sentences are padded within the batch and truncated to
    # the model's maximum input length.
    encoded = tokenizer(
        sents_batch,
        padding=True,
        truncation=True,
        add_special_tokens=False,
        return_tensors="pt",
    )
    ids_on_device = encoded['input_ids'].to(mps_device)
    mask_on_device = encoded['attention_mask'].to(mps_device)

    with torch.no_grad():
        model_output = model(ids_on_device, attention_mask=mask_on_device)
        hidden_states = model_output.last_hidden_state.cpu().numpy()

    # NOTE: padding positions are not stripped, so every stored array keeps
    # the padded length of the batch it came from.
    for per_sentence in hidden_states:
        assert per_sentence.ndim == 2, f"The dimension is unexpecuted: {per_sentence.ndim}"
        sents_contrepr.append(per_sentence)




In [10]:
# Sanity checks on the collected BERT outputs.
print(len(key_list))
print('Length of the input sentence list:', len(voice_dataset))
print('Sanity check for the length of output coolected:', len(sents_contrepr))
# Fail loudly instead of relying on someone comparing the two printed numbers:
# the aggregation step walks `sents_contrepr` positionally, so one array per
# input sentence is required.
assert len(sents_contrepr) == len(voice_dataset), (
    f"Expected one output per sentence, got {len(sents_contrepr)} outputs "
    f"for {len(voice_dataset)} sentences."
)
# Inspect the first token position rather than a hard-coded position 83, which
# raises IndexError whenever the first batch is padded to fewer than 84 tokens.
print('Sanity check context vectors for a sentence/token:', sents_contrepr[0][0].shape)

4541
Length of the input sentence list: 47548
Sanity check for the length of output coolected: 47548
Sanity check context vectors for a sentence: (768,)


## Aggregate moral foundation token vectors under each foundation for MFD

In [18]:
# Collect every distinct vice/virtue code that appears in the values of `mfd`
# (each value of `mfd` is iterated as a collection of codes).
mdf_vv_code = {code for codes in mfd.values() for code in codes}

# One empty bucket per code; token vectors are appended to these later.
sma_aggregation = {k: [] for k in mdf_vv_code}

In [20]:
# Distribute the contextual vector of every moral-foundation token into the
# bucket(s) of its vice/virtue code(s).
#
# `sents_contrepr` is consumed positionally: `sentence_list_counter` advances
# by one for every sentence visited, so the traversal order here must match
# the order in which the sentences were fed to BERT.

# Empty the buckets in place so that re-running this cell does not append
# every vector a second time.
for bucket in sma_aggregation.values():
    bucket.clear()

sentence_list_counter = 0
skipped_token_count = 0  # tokens dropped because their index is out of range
for key_num, key in enumerate(key_list):
    mf_info = voice_dict[key][key_name]

    # A collection below indicates that the data included belong to multiple sentences.
    # NOTE(review): positions 2 and 3 of mf_info are taken to be the per-sentence
    # token indices and vice/virtue codes -- confirm against the code that builds it.
    token_collection = mf_info[2]
    vice_virtue_collection = mf_info[3]
    assert len(token_collection) == len(vice_virtue_collection), f"The lengths of token_collection {len(token_collection)} and vice_virtue_collection {len(vice_virtue_collection)} should be the same."

    # loop through each sentence with mf words (lengths were asserted equal above)
    for idx_list, vv_list in zip(token_collection, vice_virtue_collection):
        sent_vector = sents_contrepr[sentence_list_counter]

        # The below code runs for every token identified as MF words previously.
        for j, idx in enumerate(idx_list):  # idx is an int token position
            vv = vv_list[j]  # a list, as a token can be assigned multiple foundations

            try:
                target_vector = sent_vector[idx]
            except IndexError:
                # Caused by abnormally long sentences which were probably
                # truncated when running the sentence through BERT. This only
                # triggers when idx lies beyond the stored (batch-padded)
                # length. Count the drop instead of hiding it.
                skipped_token_count += 1
                continue
            for code in vv:
                sma_aggregation[code].append(target_vector)

        sentence_list_counter += 1

# Every stored sentence representation must have been visited exactly once;
# otherwise vectors were matched to the wrong sentences.
assert sentence_list_counter == len(sents_contrepr), (
    f"Visited {sentence_list_counter} sentences but {len(sents_contrepr)} "
    f"representations are stored."
)
if skipped_token_count:
    print(f"Skipped {skipped_token_count} token(s) whose index was out of range.")

        
    

In [22]:
# Report how many token vectors ended up in each vice/virtue bucket.
for k, vectors in sma_aggregation.items():
    print(k, len(vectors))

fairness.virtue 8937
authority.virtue 26582
sanctity.vice 996
fairness.vice 1694
loyalty.vice 1879
authority.vice 1756
loyalty.virtue 43781
care.vice 6776
moral 12629
sanctity.virtue 2123
care.virtue 5707


In [24]:
# Number of sentences visited by the aggregation loop; compare with len(sents_contrepr).
sentence_list_counter

47548

In [25]:
# Last index reached by the aggregation loop; compare with len(key_list) - 1.
key_num

4540

In [23]:
# Shape of one aggregated token vector (a single row taken from a sentence's hidden states).
sma_aggregation['fairness.vice'][0].shape

(768,)