## Cleaning

In [None]:
import pandas as pd
from tqdm import tqdm

# Base name shared by the input CSV and the tensor files written later in this notebook.
filename = 'twitter_cs'

# Only the first 1000 rows are read (nrows=1000); drop `nrows` to process the whole file.
twcs: pd.DataFrame = pd.read_csv(f'../data/{filename}.csv', nrows=1000)
twcs.head()

In [None]:
# Strip Twitter-specific noise from the tweet text.
# Token patterns end at any whitespace (\S), so a mention or signature that is
# directly followed by a newline or tab does not swallow the next word.
twcs["text"] = (
    twcs["text"]
    .str.replace(r"@\S*", "", regex=True)   # @mentions
    .str.replace(r"#\S+", "", regex=True)   # hashtags
    .str.replace(r"\^\S*", "", regex=True)  # "^XY"-style tokens
    # URLs run up to the next whitespace or backslash. The earlier class
    # [^\s\\n] also excluded the letter "n", which cut URLs short and left
    # their tails behind in the text.
    .str.replace(r"https?://[^\s\\]+", "", regex=True)
    .str.replace(r"\n+", ' ', regex=True)   # collapse line breaks into one space
    .str.strip()
)
twcs = twcs.rename(columns={'inbound': 'is_customer'})
# `response_tweet_id` is split on commas; keep only the first id, -1 when there is none.
twcs['first_response'] = twcs['response_tweet_id'].str.split(',').str[0]
twcs['first_response'] = twcs['first_response'].fillna(-1).astype(int)

### Make threads

In [None]:
import networkx as nx
from tqdm import tqdm

# Directed edge: first response -> the tweet it answers.
# Rows whose `first_response` is the -1 placeholder contribute no edge.
has_response = twcs['first_response'] != -1
tweet_graph = nx.from_pandas_edgelist(
    twcs.loc[has_response],
    source="first_response",
    target="tweet_id",
    create_using=nx.DiGraph(),
)

def find_final_tweet(node, graph):
    """Walk predecessor links from ``node`` until a node without predecessors is found.

    Parameters
    ----------
    node
        Starting node; must be present in ``graph``.
    graph
        Object exposing ``predecessors(n)`` that returns an iterable of nodes
        (e.g. a ``networkx.DiGraph``).

    Returns
    -------
    The node at the end of the predecessor chain. When a node has several
    predecessors only the first one is followed. If the chain loops back onto
    an already visited node, the walk stops and the last node reached before
    closing the loop is returned, so a cyclic graph cannot hang the notebook.
    """
    missing = object()  # sentinel, so falsy node ids such as 0 are handled correctly
    current = node
    visited = {current}
    while True:
        following = next(iter(graph.predecessors(current)), missing)
        if following is missing or following in visited:
            return current
        visited.add(following)
        current = following


subgraph_map = {}

# Map every node of the graph to the end of its predecessor chain.
final_tweets = {}
for node in tqdm(tweet_graph.nodes(), desc="Finding roots"):
    final_tweets[node] = find_final_tweet(node, tweet_graph)

# Tweets that are not nodes of the graph get NaN as thread id.
twcs["thread_id"] = twcs["tweet_id"].map(final_tweets)

### Aggregate to chats

In [None]:
import pandas as pd
from tqdm import tqdm

def group_conversations(group):
    """Collapse the tweets of one thread into a list of alternating speaker turns.

    Parameters
    ----------
    group : pd.DataFrame
        Rows of a single thread; must provide the ``created_at``,
        ``is_customer`` and ``text`` columns.

    Returns
    -------
    list of str
        Messages in chronological order, starting with the first customer
        message. Consecutive messages by the same side are joined with a
        single space into one turn. Messages that precede the first customer
        message are dropped.
    """
    # Sort by the parsed timestamp rather than the raw column: textual
    # timestamps do not generally sort chronologically as plain strings.
    # Unparseable values become NaT and go last; mergesort is stable, so
    # ties keep their incoming order.
    group = group.sort_values(
        by="created_at",
        key=lambda col: pd.to_datetime(col, errors="coerce", utc=True),
        kind="mergesort",
    )

    alternating_messages = []
    prev_is_customer = None

    for is_customer, text in zip(group["is_customer"], group["text"]):
        if prev_is_customer is None and not is_customer:
            # A chat has to open with a customer message.
            continue
        if prev_is_customer == is_customer:
            # Same speaker again: extend the current turn.
            alternating_messages[-1] += " " + text
        else:
            alternating_messages.append(text)
            prev_is_customer = is_customer

    return alternating_messages


twcs.sort_values(by=["thread_id"], inplace=True)

# One list of alternating messages per thread.
tqdm.pandas(desc="Grouping conversations...")
threads = twcs.groupby("thread_id")
grouped_chats = threads.progress_apply(group_conversations)

# Turn the per-thread result into a frame with `thread_id` and `chat` columns.
chats_df = pd.DataFrame(grouped_chats, columns=["chat"]).reset_index()
chats_df["n_messages"] = chats_df["chat"].map(len)

### Filter

In [None]:
# Keep chats with between 5 and 10 turns (inclusive).
proper_length = (chats_df['n_messages'] >= 5) & (chats_df['n_messages'] <= 10)

# Drop chats in which any message contains one of these keywords (case-insensitive).
keywords = [' dm', 'direct message', 'direct messaging', 'dms', 'private message']
non_dm = chats_df['chat'].apply(lambda c: all(not any(keyword in m.lower() for keyword in keywords) for m in c))

# .copy() gives an independent frame, so the columns assigned to `chats_df`
# in later cells do not trigger SettingWithCopyWarning.
chats_df = chats_df[proper_length & non_dm].copy()
print(f"Found {len(chats_df)} fitting chats")

## Augmentation

### Embed random question

In [None]:
import numpy as np

def random_customer_idx(x):
    """Return a random even index drawn from ``range(0, x - 2, 2)``.

    ``x`` is the number of messages in a chat. The upper bound keeps the
    last message out of the draw. Raises ``ValueError`` (from
    ``np.random.choice``) when the range is empty.
    """
    candidates = list(range(0, x - 2, 2))
    return np.random.choice(candidates)

# Draw one index per chat, then look up the message stored at that index.
chats_df['aug_idx'] = chats_df['n_messages'].apply(random_customer_idx)
chats_df['aug_text'] = [
    chat[idx] for chat, idx in zip(chats_df['chat'], chats_df['aug_idx'])
]

In [None]:
from sentence_transformers import SentenceTransformer

# Embedding model for the `aug_text` messages.
# NOTE: the name `model` is rebound to a different SentenceTransformer in the "Nodes" cell.
model = SentenceTransformer(
    "mixedbread-ai/mxbai-embed-large-v1"
)

# Encode all `aug_text` values in one call; `embeddings` is reused by the clustering cell.
to_embed = chats_df['aug_text'].tolist()
embeddings = model.encode(to_embed, show_progress_bar=True)
# Keep one embedding per row, stored as a plain Python list.
chats_df['aug_embedding'] = embeddings.tolist()

### Cluster questions

In [None]:
from sklearn.cluster import KMeans

# Target number of clusters, capped by the number of embedded chats:
# KMeans raises a ValueError when n_clusters exceeds the number of samples,
# which can happen when only few chats survive the filter.
k = min(20, len(embeddings))

kmeans = KMeans(n_clusters=k, random_state=0).fit(embeddings)
# One cluster label per chat, in row order.
chats_df['cluster'] = kmeans.labels_

### Creating augmented data

In [None]:
def substitue_sim_answer(row, cluster_samples, df):
    """Build an augmented chat by swapping in the answer of a same-cluster chat.

    Parameters
    ----------
    row : pd.Series
        One row of ``df`` with ``chat``, ``aug_idx`` and ``cluster`` entries.
    cluster_samples : dict
        Maps a cluster label to the list of ``df`` index labels in that cluster.
    df : pd.DataFrame
        Frame the donor chat is looked up in via ``df.loc``.

    Returns
    -------
    list
        Copy of ``row['chat']`` in which the message right after
        ``row['aug_idx']`` is replaced by the message right after the donor's
        own ``aug_idx``.
    """
    same_cluster_indices = cluster_samples[row['cluster']]

    # Prefer a donor other than the row itself; picking the row would give an
    # "augmented" chat identical to the original. Fall back to the full list
    # when the row is the only member of its cluster.
    own_idx = getattr(row, 'name', None)
    candidates = [idx for idx in same_cluster_indices if idx != own_idx]
    if not candidates:
        candidates = list(same_cluster_indices)

    sub_idx = np.random.choice(candidates)
    sub = df.loc[sub_idx]

    sub_chat, sub_aug_idx = sub['chat'], sub['aug_idx']
    # The reply sits one position after the selected message; taking
    # sub_chat[sub_aug_idx] would insert the donor's question instead.
    sub_answer = sub_chat[sub_aug_idx + 1]

    orig_chat, orig_aug_idx = row['chat'], row['aug_idx']
    aug_chat = orig_chat[:orig_aug_idx + 1] + [sub_answer] + orig_chat[orig_aug_idx + 2:]

    return aug_chat

def prepare_cluster_samples(df):
    """Return a dict mapping each ``cluster`` value to the index labels of its rows."""
    return {
        label: members.index.tolist()
        for label, members in df.groupby('cluster')
    }

cluster_samples = prepare_cluster_samples(chats_df)


def _augment_row(row):
    """Build the augmented chat for one row of ``chats_df``."""
    return substitue_sim_answer(row, cluster_samples, chats_df)


tqdm.pandas(desc="Augmenting chats...")
chats_df['aug_chat'] = chats_df.progress_apply(_augment_row, axis=1)

In [None]:
# One row per chat, two columns: the original chat and its augmented version.
chats = np.array(chats_df[['chat', 'aug_chat']].values)

# Random mask of the pairs whose two columns get swapped (no seed is set in this cell).
to_shuffle = np.random.rand(len(chats)) > 0.5

# Label -1 marks a swapped pair (augmented chat first), 1 an untouched pair (original first).
labels = np.where(to_shuffle, -1, 1)

# Repeat the mask for both columns so np.where can choose element-wise
# between the column-reversed pair and the pair as is.
to_shuffle = np.column_stack((to_shuffle, to_shuffle))
shuffled_chats = np.where(to_shuffle, chats[:, ::-1], chats)
del chats

## Constructing graph

### Nodes

In [None]:
import torch
from torch import Tensor
import torch.nn.functional as F


# Tokenizer source for the node features (rebinds `model` from the embedding cell).
model = SentenceTransformer("paraphrase-MiniLM-L6-v2")

# Tokenize both chats of every pair and track the largest number of messages
# (`max_nodes`) and the longest token sequence (`max_sen_len`) seen.
max_sen_len = 0
max_nodes = 0
tokenized = []
for c1, c2 in shuffled_chats:
    t1, t2 = model.tokenize(c1)['input_ids'], model.tokenize(c2)['input_ids']
    max_sen_len = max(max_sen_len, t1.size(1), t2.size(1))
    max_nodes = max(max_nodes, t1.size(0), t2.size(0))
    tokenized.append([
        t1, t2
    ])

# Allocate the zero-padded result as int64 right away. This avoids the float32
# intermediate and later cast, which is only exact for ids below 2**24 and
# needs a second full-size allocation.
node_tensor = torch.zeros(len(tokenized), 2, max_nodes, max_sen_len, dtype=torch.long)
for i, (c1, c2) in enumerate(tokenized):
    # Copy each token matrix into the top-left corner; the rest stays 0.
    node_tensor[i, 0, :c1.size(0), :c1.size(1)] = c1
    node_tensor[i, 1, :c2.size(0), :c2.size(1)] = c2

torch.save(node_tensor, f"../data/{filename}_nodes.pt")

### Labels

In [None]:
# Convert the label array to a tensor and write it next to the other outputs.
label_tensor = Tensor(labels)
torch.save(label_tensor, f"../data/{filename}_labels.pt")

### Edges

In [None]:
from torch import Tensor

def create_chat_graph(chat, max_edges):
    """Enumerate all ordered message pairs of ``chat`` as typed edges.

    Every pair ``(ui, uj)`` gets an integer edge type:

    * ``8`` when ``ui == uj``;
    * otherwise the sum of ``4`` if ``ui > uj``, ``2`` if ``ui`` is an even
      position and ``1`` if ``uj`` is an even position.

    Parameters
    ----------
    chat : sequence
        Messages of one chat; only its length is used.
    max_edges : int
        Length both result lists are padded to. Lists that are already longer
        are returned unpadded.

    Returns
    -------
    tuple of (list of int, list of tuple)
        Edge types padded with ``0`` and the matching ``(ui, uj)`` pairs
        padded with ``(0, 0)``.
    """
    n_messages = len(chat)

    def edge_code(src, dst):
        if src == dst:
            return 8
        code = 0
        if src > dst:
            code += 4
        if src % 2 == 0:  # even positions are the ones the original listed as human
            code += 2
        if dst % 2 == 0:
            code += 1
        return code

    chat_edges_idxs = [(src, dst) for src in range(n_messages) for dst in range(n_messages)]
    chat_edges = [edge_code(src, dst) for src, dst in chat_edges_idxs]

    missing = max_edges - len(chat_edges)
    chat_edges_pad = chat_edges + [0] * missing
    chat_edges_idxs_pad = chat_edges_idxs + [(0, 0)] * missing

    return chat_edges_pad, chat_edges_idxs_pad


# A chat with `max_nodes` messages has max_nodes**2 ordered pairs.
max_edges = max_nodes**2

edges = torch.zeros(len(shuffled_chats), 2, max_edges, dtype=torch.int32)
edge_idxs = torch.zeros(len(shuffled_chats), 2, 2, max_edges, dtype=torch.int64)

for i, (c1, c2) in enumerate(shuffled_chats):
    for slot, chat in enumerate((c1, c2)):
        chat_edges, chat_edge_idxs = create_chat_graph(chat, max_edges)
        edges[i, slot] = torch.tensor(chat_edges, dtype=torch.int32)
        # (max_edges, 2) pairs -> (2, max_edges): sources in row 0, targets in row 1.
        edge_idxs[i, slot] = torch.tensor(chat_edge_idxs, dtype=torch.int64).transpose(0, 1)

torch.save(edges, f"../data/{filename}_edges.pt")
torch.save(edge_idxs, f"../data/{filename}_edge_idxs.pt")

In [None]:
# Print ten randomly drawn chat pairs together with their label.
for i in range(10):
    random_idx = np.random.randint(0, len(shuffled_chats))
    first_chat, second_chat = shuffled_chats[random_idx]
    print("\n".join(first_chat))
    print()
    print("\n".join(second_chat))
    print()
    print(labels[random_idx])
    print('-'*50)