In [1]:
import torch
from transformers import AutoModel, AutoTokenizer
from transformers import TextClassificationPipeline
import pandas as pd
import numpy as np
import tqdm
from tqdm import tqdm
import sys, os
from transformers import Pipeline
from torch import Tensor 
import torch.nn as nn
from torch.utils.data import DataLoader
from datasets import load_dataset
import torch.nn as nn
import sentence_transformers

# Prefer the GPU when torch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

### Load Data and Model

In [2]:
# Hugging Face Hub repositories used by this notebook.
DATASET_REPO = "shreyahavaldar/multilingual_politeness"
MODEL_REPO = "shreyahavaldar/xlm-roberta-politeness"


def load_data():
    """Fetch ``DATASET_REPO`` with ``datasets.load_dataset`` in torch format.

    ``set_format`` is applied to the loaded object itself (it is not
    reassigned), switching its output format to torch before it is returned.
    """
    dataset = load_dataset(DATASET_REPO)
    dataset.set_format(type="torch")
    return dataset

def load_model():
    """Load the ``MODEL_REPO`` encoder and attach a single-output linear head.

    Returns:
        ``nn.Sequential(backbone, nn.Linear(hidden_size, 1))`` with every
        submodule (backbone AND head) placed on ``device``.

    NOTE(review): ``AutoModel`` instantiates the base model class, so any task
    head stored in the ``MODEL_REPO`` checkpoint is not used here; the
    ``nn.Linear`` head below is newly created and untrained.
    NOTE(review): the backbone's forward presumably returns a ``ModelOutput``
    object rather than a tensor, and ``nn.Sequential`` forwards a single
    positional argument, so calling the returned module end-to-end will likely
    fail at the ``nn.Linear``. Pool the hidden states (e.g. the first token of
    ``last_hidden_state``) before the head. TODO: confirm the intended usage.
    """
    model = AutoModel.from_pretrained(MODEL_REPO)
    head = nn.Linear(model.config.hidden_size, 1)
    torch_model = nn.Sequential(model, head)
    # Move the whole Sequential, not just the backbone. Moving only `model`
    # left the new Linear head on the CPU, causing a device mismatch on GPU.
    torch_model.to(device)
    return torch_model

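### Define Alignment Metric

Each politeness-lexicon category is summarised by the mean (centroid) of its words' sentence-transformer embeddings. A group of words is scored by its highest mean cosine similarity to any single category centroid.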

In [3]:
class Metric(nn.Module):
    """Politeness-alignment metric for groups of words.

    Words are embedded with a SentenceTransformer. For each language, every
    category of its politeness lexicon is summarised by the mean embedding
    (centroid) of the category's words. A group of words is scored by the
    highest, over categories, of the group's mean cosine similarity to that
    category's centroid.
    """

    # Languages whose lexica are loaded when none are passed explicitly.
    DEFAULT_LANGUAGES = ("english", "spanish", "chinese", "japanese")
    # Directory holding ``{language}_politelex.csv`` files (columns
    # ``CATEGORY`` and ``word``); relative to the current working directory.
    DEFAULT_LEXICA_DIR = "../src/exlib/utils/politeness_lexica"

    def __init__(self, model_name: str = "distiluse-base-multilingual-cased",
                 languages=None, lexica_dir=None):
        """Load the sentence encoder and precompute per-category centroids.

        Args:
            model_name: SentenceTransformer model used to embed words.
            languages: languages whose lexica to load; defaults to
                ``DEFAULT_LANGUAGES``.
            lexica_dir: directory containing the lexicon CSVs; defaults to
                ``DEFAULT_LEXICA_DIR``.
        """
        super().__init__()
        self.model = sentence_transformers.SentenceTransformer(model_name)
        self.centroids = self.get_centroids(languages, lexica_dir)

    def get_centroids(self, languages=None, lexica_dir=None):
        """Embed every lexicon word and average the embeddings per category.

        Args:
            languages: languages to process; defaults to ``DEFAULT_LANGUAGES``.
            lexica_dir: directory of ``{language}_politelex.csv`` files;
                defaults to ``DEFAULT_LEXICA_DIR``.

        Returns:
            ``{language: {category: centroid}}``, where each centroid is the
            mean over axis 0 of ``self.model.encode`` applied to that
            category's words.
        """
        if languages is None:
            languages = self.DEFAULT_LANGUAGES
        if lexica_dir is None:
            lexica_dir = self.DEFAULT_LEXICA_DIR

        # Read every lexicon up front so a missing/bad file fails before the
        # (slow) encoding work starts.
        lexica = {
            language: pd.read_csv(os.path.join(lexica_dir, f"{language}_politelex.csv"))
            for language in languages
        }

        all_centroids = {}
        for language, lexicon in lexica.items():
            centroids = {}
            # sort=False keeps first-appearance order (as unique() did). The
            # default dropna drops rows with a missing CATEGORY; comparing
            # against NaN previously matched no rows and produced NaN centroids.
            for category, rows in lexicon.groupby("CATEGORY", sort=False):
                words = rows["word"].dropna().tolist()
                if not words:
                    # Nothing to average; a NaN centroid would poison scores.
                    continue
                centroids[category] = np.mean(self.model.encode(words), axis=0)
            all_centroids[language] = centroids
            print(f"Centroids for {language} created.")
        return all_centroids

    # input: list of words
    def calculate_single_group_alignment(self, group: list, language: str = "english"):
        """Score one group of words against ``language``'s category centroids.

        For every category, the cosine similarity between each word's
        embedding and the category centroid is averaged over the group. The
        largest such average is returned.

        Args:
            group: words (strings) to score together.
            language: key into ``self.centroids``; a language whose lexicon
                was not loaded raises ``KeyError``.

        Returns:
            The best per-category mean cosine similarity as a float, or
            ``nan`` when ``group`` is empty.
        """
        centroids = self.centroids[language]
        words = list(group)
        if not words:
            return float("nan")
        # Encode the whole group once. Previously every word was re-encoded
        # for every category. Rows are L2-normalised so a dot product with a
        # unit centroid is a cosine similarity.
        word_embs = np.asarray(self.model.encode(words))
        word_embs = word_embs / np.linalg.norm(word_embs, axis=1, keepdims=True)
        return max(
            float(np.mean(word_embs @ (np.asarray(centroid) / np.linalg.norm(centroid))))
            for centroid in centroids.values()
        )

    def calculate_group_alignment(self, groups: list, language: str = "english"):
        """Score each group in ``groups``; returns a list in the same order."""
        return [self.calculate_single_group_alignment(group, language) for group in groups]

    def forward(self, groups: list, language: str = "english"):
        """Make the module callable: delegates to ``calculate_group_alignment``."""
        return self.calculate_group_alignment(groups, language)

### Example Group Alignment Calculation

In [4]:
metric = Metric()

# A handful of small word groups spanning different themes, to compare scores.
sample_groups = [
    ["dog", "cat", "fish"],
    ["hello", "goodbye", "please"],
    ["computer", "laptop", "phone"],
    ["idiot", "stupid", "dumb"],
    ["thank you", "grateful", "thanks"],
]
alignments = metric.calculate_group_alignment(sample_groups)
for idx, group in enumerate(sample_groups):
    print(f"Group: {group}, Alignment: {alignments[idx]}")

Centroids for english created.
Centroids for spanish created.
Centroids for chinese created.
Centroids for japanese created.
Group: ['dog', 'cat', 'fish'], Alignment: 0.5292773842811584
Group: ['hello', 'goodbye', 'please'], Alignment: 0.7011184692382812
Group: ['computer', 'laptop', 'phone'], Alignment: 0.4826013147830963
Group: ['idiot', 'stupid', 'dumb'], Alignment: 0.7102837562561035
Group: ['thank you', 'grateful', 'thanks'], Alignment: 0.9256609082221985
