# Loading and Reformatting

In [None]:
from typing import List

import pandas as pd
import torch
from torch import Tensor

# Dataset stem: the input CSV and every artifact saved by later cells are
# named after it under data/.
filename = 'gogi_chats'
csv_path = f'data/{filename}.csv'
messages = pd.read_csv(csv_path)  # per-message rows, keyed by conversation_id

In [None]:
def mean_non_na(series: pd.Series) -> float:
    """Average the numeric entries of ``series``, ignoring everything else.

    Entries that ``pd.to_numeric`` cannot parse are coerced to NaN and then
    discarded together with genuinely missing values, so a column that mixes
    ratings with placeholders still produces a mean. Returns NaN when no
    numeric entry remains.
    """
    numeric = pd.to_numeric(series, errors='coerce')
    valid = numeric[numeric.notna()]
    return valid.mean()

eval_dimensions = ['friendliness', 'helpfulness', 'clearness', 'astuteness', 'tactfulness']
eval_agg_funcs = dict.fromkeys(eval_dimensions, mean_non_na)

# Collapse the per-message rows into one row per conversation: each eval
# dimension is averaged over its numeric ratings, and the messages are
# collected (in CSV row order) into a list stored under the 'chat' column.
per_message = messages[['conversation_id', 'message', *eval_dimensions]]
chats = (
    per_message
    .groupby('conversation_id')
    .agg({**eval_agg_funcs, 'message': list})
    .reset_index()
    .rename(columns={'message': 'chat'})
)
# Keep only conversations that received a numeric score on every dimension.
chats = chats.dropna(subset=eval_dimensions)
chats.head()

## Nodes

In [None]:
# One entry per chat: the list of its message texts, in conversation order.
nodes = chats['chat'].tolist()
torch.save(nodes, f"data/{filename}_nodes.pt")

## Edges

In [None]:
from numpy import int8


def _chat_graph(num_messages: int):
    """Build the fully connected, typed edge set for one chat.

    Every ordered pair ``(ui, uj)`` of message indices, self-pairs included,
    is an edge. Pairs are enumerated row-major: ``ui`` in the outer loop and
    ``uj`` in the inner loop.

    Args:
        num_messages: number of messages in the chat.

    Returns:
        ``(edge_types, edge_idxs)``:
        - ``edge_types`` is an int8 tensor of shape ``(num_messages**2,)``.
        - ``edge_idxs`` is a long tensor of shape ``(2, num_messages**2)``;
          row 0 holds ``ui`` and row 1 holds ``uj``.

    Edge types are bit flags:
        8 -- self-loop (``ui == uj``); no other bit is set
        4 -- ``ui > uj``
        2 -- ``ui`` is a human message
        1 -- ``uj`` is a human message
    """
    # NOTE(review): "human" means an even index. This assumes the chat strictly
    # alternates speakers, starting with the human at index 0; confirm against
    # the CSV's message ordering.
    def is_human(idx: int) -> bool:
        return idx % 2 == 0

    types = []
    pairs = []
    for ui in range(num_messages):
        for uj in range(num_messages):
            if ui == uj:
                edge_type = 8
            else:
                edge_type = 4 * (ui > uj) + 2 * is_human(ui) + is_human(uj)
            types.append(edge_type)
            pairs.append((ui, uj))

    edge_types = torch.tensor(types, dtype=torch.int8)
    # reshape(-1, 2) keeps a well-formed (2, 0) result for an empty chat.
    edge_idxs = torch.tensor(pairs, dtype=torch.long).reshape(-1, 2).T
    return edge_types, edge_idxs


# One graph per chat, in the same row order as the saved nodes and labels.
edges = []
edge_idxs = []
for chat in chats['chat']:
    chat_edge_types, chat_edge_idxs = _chat_graph(len(chat))
    edges.append(chat_edge_types)
    edge_idxs.append(chat_edge_idxs)

torch.save(edges, f'data/{filename}_edges.pt')
torch.save(edge_idxs, f'data/{filename}_edge_idxs.pt')

## Labels

In [None]:
# One row per chat (same order as the saved nodes) and one column per eval
# dimension, in eval_dimensions order. Tensor() casts the values to torch's
# default float dtype.
label_matrix = chats[eval_dimensions].values
torch.save(Tensor(label_matrix), f'data/{filename}_labels.pt')