# Loading and Reformatting

In [None]:
from typing import List

import pandas as pd
import torch
from torch import Tensor

# Base name shared by every artifact this notebook reads from / writes to data/.
filename = 'gogi_chats'
# One row per message; columns used below include conversation_id, message and
# the rating dimensions.
csv_path = f'data/{filename}.csv'
messages = pd.read_csv(csv_path)

In [None]:
def mean_non_na(series: pd.Series) -> float:
    """Average the numeric entries of ``series``.

    Entries that cannot be parsed as numbers are coerced to NaN and, together
    with values that were already missing, are excluded from the average.
    Returns NaN when no numeric entry remains.
    """
    numeric = pd.to_numeric(series, errors='coerce')
    # Keep only the entries that survived numeric coercion.
    return numeric[numeric.notna()].mean()

eval_dimensions = ['friendliness', 'helpfulness', 'clearness', 'astuteness', 'tactfulness']
# Every rating dimension is averaged over its numeric entries.
eval_agg_funcs = dict.fromkeys(eval_dimensions, mean_non_na)

# Collapse the per-message table into one row per conversation (groupby sorts
# by conversation_id): ratings are averaged and message texts are gathered into
# a list in the order their rows appear within the group.
# NOTE(review): assumes CSV rows are already chronological within each
# conversation — confirm against the data export.
_chat_aggregations = dict(eval_agg_funcs, message=lambda texts: list(texts))
_per_message = messages[['conversation_id', 'message', *eval_dimensions]]
_grouped = _per_message.groupby('conversation_id')
chats = (
    _grouped.agg(_chat_aggregations)
    .reset_index()
    .rename(columns={'message': 'chat'})
)
# Drop conversations where any dimension had no numeric rating at all
# (mean_non_na yields NaN in that case).
chats = chats.dropna(subset=eval_dimensions)
chats.head()

# Node Encodings

In [None]:
from transformers import BertTokenizer

# Pretrained vocabulary used to turn each utterance into BERT input ids.
BERT_CHECKPOINT = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(BERT_CHECKPOINT)

def tokenize_chat(chat: List[str]) -> List[Tensor]:
    """Tokenize every utterance of a chat independently.

    Args:
        chat: the chat's messages, in order.

    Returns:
        One tensor of input ids per utterance, with the batch dimension added
        by ``return_tensors="pt"`` squeezed away. Long utterances are truncated
        to the tokenizer's maximum length.
    """
    token_ids: List[Tensor] = []
    for utterance in chat:
        # padding=True has nothing to pad against for a single utterance.
        encoded = tokenizer(utterance, padding=True, truncation=True, return_tensors="pt")
        token_ids.append(encoded['input_ids'].squeeze())
    return token_ids

# One list of per-utterance token-id tensors per chat, in `chats` row order.
node_encodings = [tokenize_chat(utterances) for utterances in chats['chat']]

torch.save(node_encodings, f'data/{filename}_node_encodings.pt')

## Edges

In [None]:
from typing import Tuple


def _chat_graph(num_utterances: int) -> Tuple[Tensor, Tensor]:
    """Build the fully connected directed utterance graph of one chat.

    Every ordered pair of distinct utterance positions ``(ui, uj)`` becomes one
    edge, enumerated with ``ui`` as the outer loop and ``uj`` as the inner one.

    Args:
        num_utterances: number of messages in the chat.

    Returns:
        ``(edge_types, edge_index)``. ``edge_types`` is a bool tensor of shape
        ``(E, 3)`` with columns ``[ui > uj, ui is human, uj is human]``.
        ``edge_index`` is a long tensor of shape ``(2, E)``; its first row
        holds ``ui`` and its second row ``uj``. ``E = n * (n - 1)``, so chats
        with fewer than two messages still yield correctly shaped empty tensors.
    """
    pairs = [
        (ui, uj)
        for ui in range(num_utterances)
        for uj in range(num_utterances)
        if ui != uj
    ]
    # Even positions are treated as human turns.
    # NOTE(review): assumes every chat strictly alternates and starts with a
    # human message — confirm against the data.
    edge_types = [(ui > uj, ui % 2 == 0, uj % 2 == 0) for ui, uj in pairs]
    # reshape keeps the trailing dimension even when there are no edges.
    type_tensor = torch.tensor(edge_types, dtype=torch.bool).reshape(-1, 3)
    index_tensor = torch.tensor(pairs, dtype=torch.long).reshape(-1, 2).T
    return type_tensor, index_tensor


# Build a graph for every chat so that edges[i] / edge_idxs[i] line up with
# node_encodings[i] and row i of the labels (all follow `chats` row order).
edges = []
edge_idxs = []
for chat in chats['chat']:
    chat_edge_types, chat_edge_index = _chat_graph(len(chat))
    edges.append(chat_edge_types)
    edge_idxs.append(chat_edge_index)

torch.save(edges, f'data/{filename}_edges.pt')
torch.save(edge_idxs, f'data/{filename}_edge_idxs.pt')

## Labels

In [None]:
# One row per kept conversation (in `chats` row order) and one column per
# rating dimension, in `eval_dimensions` order, using torch's default float type.
labels = Tensor(chats[eval_dimensions].values)
labels_path = f'data/{filename}_labels.pt'
torch.save(labels, labels_path)