## Cleaning

In [None]:
import pandas as pd
from tqdm import tqdm

filename = 'twitter_cs'

# `nrows` caps the sample at the first 1,000 rows of the CSV, so some reply
# chains will be cut where a parent/reply falls outside the sample.
twcs: pd.DataFrame = pd.read_csv(f'data/{filename}.csv', nrows=1_000)
# Sort chronologically. read_csv leaves `created_at` as plain text, so sorting
# the raw strings is lexicographic. For timestamps shaped like
# "Tue Oct 31 22:10:47 +0000 2017" that orders rows by weekday name, not time.
# NOTE(review): that is the format of the public Twitter customer-support dump,
# so confirm against the file.
# The parsed value is used only as the sort key, so the column itself stays
# untouched. utc=True normalises any mixed offsets. A stable sort keeps file
# order among tweets that share a timestamp.
twcs = twcs.sort_values(
    'created_at',
    key=lambda timestamps: pd.to_datetime(timestamps, utc=True),
    kind='stable',
)
twcs.head()

In [None]:
# Normalise tweet text:
#   - strip the first leading @mention;
#   - remove URLs;
#   - fold runs of newlines into single spaces, then trim.
# Then rename `inbound` to `is_customer` for the steps below.
twcs["text"] = (
    twcs["text"]
    .str.replace(r"^\s*@[^ ]*", "", regex=True)
    # `\S+` consumes the URL up to the next whitespace character. A class
    # written as `[^\s\\n]` would also exclude the letter "n" (`\\` is a
    # literal backslash), truncating URLs at their first "n".
    .str.replace(r"https?://\S+", "", regex=True)
    .str.replace(r"\n+", ' ', regex=True)
    .str.strip()
)
twcs = twcs.rename(columns={'inbound': 'is_customer'})

### Make threads

In [None]:
def find_root(tweet_id, df):
    """Follow the reply chain forward from ``tweet_id`` and return its last tweet.

    Starting at ``tweet_id``, repeatedly step to the first row of ``df`` (in
    ``df``'s row order) whose ``in_response_to_tweet_id`` equals the current
    id. The walk stops when the current tweet has no such row. Despite the
    name, the returned id is therefore the *end* of the first-reply path, not
    its first tweet. Tweets along the same first-reply path all map to the
    same id, which is what lets the result serve as a thread key.

    Args:
        tweet_id: Id of the tweet to start from.
        df: Frame with ``tweet_id`` and ``in_response_to_tweet_id`` columns.

    Returns:
        The id of the last tweet reached. If the reply links lead back to a
        tweet already visited (e.g. a tweet listed as replying to itself),
        the walk stops at the last new tweet instead of looping forever.
    """
    current = tweet_id
    visited = {current}
    replied_to = df['in_response_to_tweet_id']

    while True:
        replies = df.loc[replied_to == current, 'tweet_id']
        if replies.empty:
            return current

        next_id = replies.values[0]
        if next_id in visited:
            # Cyclic reply links in malformed data would otherwise never end.
            return current

        visited.add(next_id)
        current = next_id

# Key every tweet by the id find_root returns for it; `progress_apply`
# (registered by tqdm.pandas) forwards `args` to Series.apply, which passes
# them to find_root after each tweet id.
tqdm.pandas(desc="Making threads...")
twcs['thread_id'] = twcs['tweet_id'].progress_apply(find_root, args=(twcs,))

### Aggregate to chats

In [None]:
def group_conversations(df):
    """Collapse consecutive same-speaker rows into speaker turns.

    Rows of ``df`` are read in order. A row whose ``is_customer`` value equals
    the previous row's is appended to the current turn, joined by a single
    space. Any other row opens a new turn.

    Args:
        df: Frame with ``is_customer`` and ``text`` columns, ordered as the
            messages should be read.

    Returns:
        List of turn strings in order; adjacent turns differ in speaker.
    """
    turns = []
    previous_speaker = None
    for speaker, message in zip(df['is_customer'], df['text']):
        if speaker == previous_speaker:
            turns[-1] += ' ' + message
            continue
        turns.append(message)
        previous_speaker = speaker
    return turns

# One chat per thread: merge each thread's tweets into speaker turns exactly
# once. groupby(sort=False) yields groups in order of first appearance in
# `twcs`, and rows inside a group keep `twcs` order. The chats therefore come
# out in the same order as `drop_duplicates('thread_id')`, indexed by each
# thread's first tweet. Grouping once avoids rebuilding the same chat for
# every tweet of the thread, which cost a full-frame filter per row.
chat_by_first_row = {}
for _, thread in tqdm(
    twcs.groupby('thread_id', sort=False), desc="Grouping conversations..."
):
    chat_by_first_row[thread.index[0]] = group_conversations(thread)

# dtype=object keeps each value a list rather than letting pandas try to
# expand the lists into extra dimensions.
chats = pd.DataFrame({'chat': pd.Series(chat_by_first_row, dtype=object)})
chats['n_messages'] = chats['chat'].apply(len)

### Filter

In [None]:
# Keep chats with at least 4 speaker turns where no turn contains " dm"
# (case-insensitive substring) — presumably to drop conversations that were
# moved to direct messages.
long_enough = chats['n_messages'].ge(4)
no_dm_mention = chats['chat'].apply(
    lambda messages: not any([' dm' in message.lower() for message in messages])
)
chats = chats.loc[long_enough & no_dm_mention]

### Nodes

In [None]:
import torch

# Persist one entry per remaining chat: its list of speaker-turn strings.
nodes = chats["chat"].tolist()
torch.save(nodes, f"data/{filename}_nodes.pt")

### Edges

In [None]:
from torch import Tensor

# For every chat, connect every ordered pair of utterances (ui, uj),
# self-pairs included, and give each pair a 4-bit edge type:
#   8 -> self-pair (ui == uj)
#   otherwise the sum of:
#     4 -> ui comes after uj (ui > uj)
#     2 -> ui is a human utterance
#     1 -> uj is a human utterance
# Human utterances are the even positions 0, 2, 4, ...
# NOTE(review): this assumes every chat opens with the customer. The chat
# lists carry no speaker labels, so confirm this upstream.
# Every chat is processed (not just the first), so `edges` and `edge_idxs`
# line up index-for-index with the `nodes` list saved from `chats["chat"]`.
edges = []
edge_idxs = []
for chat in chats["chat"]:
    n_utterances = len(chat)

    chat_edges = []
    chat_edges_idxs = []
    for ui in range(n_utterances):
        for uj in range(n_utterances):
            if ui == uj:
                edge_type = 8
            else:
                edge_type = 4 * (ui > uj) + 2 * (ui % 2 == 0) + (uj % 2 == 0)

            chat_edges_idxs.append((ui, uj))
            chat_edges.append(edge_type)

    # Build integer tensors directly rather than via float32 and a cast.
    edges.append(torch.tensor(chat_edges, dtype=torch.int8))
    # (n_pairs, 2) -> (2, n_pairs). The reshape keeps this valid for an
    # empty chat, and .contiguous() stores the transposed layout compactly.
    edge_idxs.append(
        torch.tensor(chat_edges_idxs, dtype=torch.long).reshape(-1, 2).T.contiguous()
    )

torch.save(edges, f"data/{filename}_edges.pt")
torch.save(edge_idxs, f"data/{filename}_edge_idxs.pt")