## Cleaning

In [None]:
import os

import pandas as pd

filename = 'twitter_cs'
save_path = f'../data/{filename}'
# The graph cells below write nodes.pt / labels.pt / edges.pt into save_path with
# torch.save, which does not create missing directories; create it up front so a
# fresh checkout can run the notebook top to bottom.
os.makedirs(save_path, exist_ok=True)

# Only the first 5,000 rows are loaded.
twcs: pd.DataFrame = pd.read_csv(f'../data/{filename}.csv', nrows=5_000)
twcs.head()

In [None]:
# read_csv is called without parse_dates, so created_at is still text; ordering text
# timestamps is only chronological for ISO-like formats, and group_conversations
# sorts each thread on this column. Parse it once here.
twcs["created_at"] = pd.to_datetime(twcs["created_at"], utc=True)

twcs["text"] = (
    twcs["text"]
    .fillna("")  # a missing text would break string concatenation when chats are built
    .str.replace(r"@\S*", "", regex=True)  # @mentions; stop at any whitespace, not only spaces
    .str.replace(r"#\S+", "", regex=True)  # hashtags
    .str.replace(r"\^\S*", "", regex=True)  # ^-prefixed tokens (presumably agent sign-off initials)
    # Links. The previous class [^\s\\n] also excluded the letter "n", which cut URLs short.
    .str.replace(r"https?://\S+", "", regex=True)
    # Collapse newlines and the gaps left by the removals above into single spaces.
    .str.replace(r"\s+", " ", regex=True)
    .str.strip()
)
twcs = twcs.rename(columns={'inbound': 'is_customer'})
# Keep only the first id of the comma-separated response list; -1 marks "no response".
twcs['first_response'] = twcs['response_tweet_id'].str.split(',').str[0]
twcs['first_response'] = twcs['first_response'].fillna(-1).astype(int)

### Make threads

In [None]:
import networkx as nx
from tqdm import tqdm

# Each edge runs from a tweet's first response to the tweet itself, so
# tweet_graph.predecessors(t) yields the first response to t.
replied_to = twcs[twcs['first_response'] != -1]
tweet_graph = nx.from_pandas_edgelist(
    replied_to,
    target="tweet_id",
    source="first_response",
    create_using=nx.DiGraph(),
)

def find_final_tweet(node, graph):
    """Walk predecessor links from `node` until a node with no predecessor is reached.

    In this notebook the edges point from a tweet's first response to the tweet,
    so the walk follows the chain of replies to its last tweet.

    Args:
        node: Starting node.
        graph: Object exposing ``predecessors(node)`` (e.g. a networkx DiGraph).

    Returns:
        The last node on the chain. If the chain loops back on itself, the walk
        stops at the node just before the repeat instead of looping forever.
    """
    current = node
    visited = {current}
    while True:
        preds = list(graph.predecessors(current))
        if not preds:
            return current
        nxt = preds[0]
        if nxt in visited:
            # Cycle in the reply data: stop rather than spin indefinitely.
            return current
        visited.add(nxt)
        current = nxt


# Map every node of the reply graph to the last tweet of its chain and use that as
# the thread id. Tweets that are not nodes of tweet_graph get NaN from .map.
final_tweets = {
    node: find_final_tweet(node, tweet_graph)
    for node in tqdm(tweet_graph.nodes(), desc="Finding roots")
}
twcs["thread_id"] = twcs["tweet_id"].map(final_tweets)

### Aggregate to chats

In [None]:
import pandas as pd
from tqdm import tqdm

def group_conversations(group):
    """Turn one thread's tweets into a list of speaker-alternating messages.

    Tweets are ordered by ``created_at``. Leading non-customer tweets are skipped
    so the result opens with a customer turn. Consecutive tweets from the same
    side are joined with a single space into one message.

    Args:
        group: Rows of one thread with ``created_at``, ``is_customer`` and
            ``text`` columns.

    Returns:
        List of message strings. Even positions are customer turns and odd
        positions are the other side.
    """
    ordered = group.sort_values(by="created_at", ascending=True)

    turns = []
    last_speaker = None
    for speaker, text in zip(ordered["is_customer"], ordered["text"]):
        started = last_speaker is not None
        if not started and not speaker:
            continue  # the conversation must open with a customer turn
        if started and speaker == last_speaker:
            turns[-1] += " " + text
            continue
        turns.append(text)
        last_speaker = speaker
    return turns


twcs = twcs.sort_values(by=["thread_id"])

tqdm.pandas(desc="Grouping conversations...")
grouped_chats = twcs.groupby("thread_id").progress_apply(group_conversations)

# One row per thread: its thread_id and the list of merged, alternating messages.
chats_df = grouped_chats.to_frame(name="chat").reset_index()
chats_df["n_messages"] = chats_df["chat"].map(len)

### Filter

In [None]:
# Keep chats with 5-10 merged turns (both bounds inclusive).
proper_length = chats_df['n_messages'].between(5, 10)

# Drop chats that mention moving to direct/private messages.
# The leading space on ' dm' / ' pm' avoids matches inside words (e.g. "admin").
# A missing comma previously fused 'private message' and ' pm' into one keyword.
# NOTE(review): ' pm' also matches clock times such as "at 5 pm".
keywords = [' dm', 'direct message', 'direct messaging', 'dms', 'private message', ' pm', 'private messaging']


def _mentions_private_channel(message):
    """Return True if `message` contains any entry of `keywords` (case-insensitive)."""
    # Prepend a space so the space-prefixed keywords also match at the very start.
    padded = " " + message.lower()
    return any(keyword in padded for keyword in keywords)


non_dm = chats_df['chat'].apply(lambda chat: not any(_mentions_private_channel(m) for m in chat))

# .copy() makes later column assignments act on an independent frame instead of a
# slice of the pre-filter one (avoids SettingWithCopyWarning).
chats_df = chats_df[proper_length & non_dm].copy()
print(f"Found {len(chats_df)} fitting chats")

In [None]:
from langdetect import detect, DetectorFactory
from langdetect.lang_detect_exception import LangDetectException
import pandas as pd

DetectorFactory.seed = 0  # pin langdetect's seed so repeated runs label text identically


def is_english(text):
    """Return True when langdetect identifies `text` as English ('en').

    Any LangDetectException raised during detection is treated as
    "not English" and yields False.
    """
    try:
        detected = detect(text)
    except LangDetectException:
        return False
    return detected == 'en'

# Flatten each chat into one " // "-separated string for language detection and
# embedding. Series.map (unlike DataFrame.apply(axis=1)) still returns a Series when
# chats_df is empty, so the column assignment cannot fail on an empty frame.
chats_df['aug_text'] = chats_df['chat'].map(" // ".join)

tqdm.pandas(desc="Detecting language...")
chats_df['is_english'] = chats_df['aug_text'].progress_apply(is_english)

In [None]:
# Keep English chats only. .copy() detaches the result so the column assignments in
# the embedding and clustering cells do not trigger SettingWithCopyWarning.
chats_df = chats_df[chats_df['is_english']].copy()

## Augmentation

### Embedding

In [None]:
from sentence_transformers import SentenceTransformer

# Texts are embedded in chats_df row order: row i of `embeddings` belongs to
# to_embed[i] and to row i of chats_df.
to_embed = chats_df['aug_text'].tolist()

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
embeddings = model.encode(to_embed, show_progress_bar=True)
chats_df['aug_embedding'] = embeddings.tolist()

### Clustering

In [None]:
from sklearn.cluster import KMeans
from umap import UMAP
import matplotlib.pyplot as plt

def calculate_wcss(data, pot_n_clusters, random_state=None):
    """Fit one KMeans model per candidate cluster count and collect its inertia.

    Args:
        data: Samples to cluster (anything ``KMeans.fit`` accepts).
        pot_n_clusters: Iterable of candidate ``n_clusters`` values.
        random_state: Forwarded to ``KMeans`` so the elbow curve can be made
            reproducible. ``None`` keeps the previous unseeded behavior.

    Returns:
        List of within-cluster sums of squares (``inertia_``), one per
        candidate, in the order of ``pot_n_clusters``.
    """
    wcss = []
    for n in tqdm(pot_n_clusters):
        kmeans = KMeans(n_clusters=n, random_state=random_state)
        kmeans.fit(data)
        wcss.append(kmeans.inertia_)
    return wcss

def plot_elbow(wcss, pot_n_clusters):
    """Plot WCSS against the number of clusters for elbow-method inspection.

    Args:
        wcss: WCSS values, one per entry of ``pot_n_clusters``.
        pot_n_clusters: Cluster counts used as x values and x ticks.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(pot_n_clusters, wcss, marker="o", linestyle="-", color="b")
    ax.set_xticks(pot_n_clusters)
    ax.set(
        xlabel="Number of Clusters",
        ylabel="Within-Cluster Sum of Squares (WCSS)",
        title="Elbow Method for Optimal K",
    )
    ax.grid(True)
    fig.tight_layout()
    plt.show()


pot_n_clusters = range(1, 30, 1)
# Seed UMAP: this 2-D projection feeds both the elbow plot and the final KMeans
# clustering, so an unseeded projection makes the cluster labels differ run to run.
umap_emb = UMAP(n_components=2, random_state=0).fit_transform(embeddings)
wcss = calculate_wcss(umap_emb, pot_n_clusters)
plot_elbow(wcss, pot_n_clusters)

In [None]:
import numpy as np

def print_cluster_values(embeddings, n_clusters=2, per_cluster=3, random_state=None):
    """Cluster `embeddings` with KMeans and print a random sample of texts per cluster.

    Sample texts are looked up by row position in the notebook-level `to_embed`
    list, so `embeddings` must be row-aligned with `to_embed`.

    Args:
        embeddings: Samples accepted by ``KMeans.fit``.
        n_clusters: Number of KMeans clusters.
        per_cluster: Maximum number of texts printed per cluster.
        random_state: Forwarded to ``KMeans``. ``None`` leaves it unseeded.
    """
    kmeans = KMeans(n_clusters=n_clusters, random_state=random_state)
    kmeans.fit(embeddings)

    for i in range(n_clusters):
        cluster = np.where(kmeans.labels_ == i)[0]
        # Sampling without replacement raises ValueError when a cluster has fewer
        # members than requested, so cap the sample size at the cluster size.
        n_samples = min(per_cluster, len(cluster))
        cluster_sample = np.random.choice(cluster, n_samples, replace=False)
        print(f"Cluster {i}, total length {len(cluster)}:")
        for idx in cluster_sample:
            print(to_embed[idx])
        print()


# Same k and random_state as the final KMeans fit below, so the printed samples
# correspond to the cluster labels that are actually stored.
print_cluster_values(umap_emb, n_clusters=10, random_state=0)

In [None]:
from sklearn.cluster import KMeans
from umap import UMAP

k = 10
# random_state=0 keeps the cluster assignment repeatable for a given umap_emb.
clusters = KMeans(n_clusters=k, random_state=0)
clusters.fit(umap_emb)
# umap_emb rows follow chats_df row order, so labels can be assigned positionally.
chats_df['cluster'] = clusters.labels_

### Creating augmented data

In [None]:
import numpy as np

def prepare_cluster_samples(df):
    """Map each value of ``df["cluster"]`` to the list of index labels in that cluster.

    Index labels keep the order in which they appear in ``df``.
    """
    return {label: members.index.tolist() for label, members in df.groupby("cluster")}


def subsitute_sim_answer(row, cluster_samples, df, example_per_cluster=10) -> pd.DataFrame:
    """Create augmented copies of one chat by swapping in an agent turn from a similar chat.

    Each augmentation replaces one odd-position (agent) message of ``row["chat"]``
    with the message at the same position of another chat from the same cluster.
    Draws that would repeat a (donor, position) pair, or would leave the chat
    unchanged (the row itself as donor, or an identical message), are retried.
    At most ``max_retries`` such failed draws are allowed per call.

    Args:
        row: Row of ``df`` with ``chat`` (list of messages) and ``cluster``. Its
            ``name`` is its index label in ``df``.
        cluster_samples: Mapping cluster label -> list of ``df`` index labels.
        df: Frame donor chats are drawn from (looked up with ``.loc``).
        example_per_cluster: Maximum number of augmentations to produce.

    Returns:
        DataFrame with columns ``chat`` (the original chat, repeated) and
        ``aug_chat`` (the augmented chat). It may have fewer than
        ``example_per_cluster`` rows when the cluster offers too few distinct
        (donor, position) pairs.
    """
    max_retries = 100
    orig_chat = row["chat"]
    # The chat itself is never a useful donor: it would reproduce the original.
    donor_indices = [idx for idx in cluster_samples[row["cluster"]] if idx != row.name]

    aug_chats = []
    used_pairs = set()
    while donor_indices and len(aug_chats) < example_per_cluster and max_retries > 0:
        sub = df.loc[np.random.choice(donor_indices)]
        sub_chat = sub["chat"]

        min_len = min(len(sub_chat), len(orig_chat))
        agent_indices = list(range(1, min_len, 2))
        if not agent_indices:
            max_retries -= 1
            continue
        splice_idx = np.random.choice(agent_indices)
        pair = (sub.name, splice_idx)

        # Repeated pairs give duplicate samples; identical messages give an
        # unchanged chat that would be a mislabeled (orig, orig) training pair.
        if pair in used_pairs or sub_chat[splice_idx] == orig_chat[splice_idx]:
            max_retries -= 1
            continue

        used_pairs.add(pair)
        aug_chats.append(
            orig_chat[:splice_idx]
            + [sub_chat[splice_idx]]
            + orig_chat[splice_idx + 1 :]
        )

    return pd.DataFrame(
        {
            "chat": [orig_chat] * len(aug_chats),
            "aug_chat": aug_chats,
        }
    )


cluster_samples = prepare_cluster_samples(chats_df)

tqdm.pandas(desc="Augmenting chats...")
# One small DataFrame of (chat, aug_chat) pairs per source chat, stacked into one frame.
per_chat_frames = chats_df.progress_apply(
    lambda row: subsitute_sim_answer(row, cluster_samples, chats_df),
    axis=1,
)
aug_chats_df = pd.concat(per_chat_frames.tolist(), ignore_index=True)

In [None]:
# (n, 2) object array: column 0 is the original chat, column 1 its augmented version.
chats = aug_chats_df[['chat', 'aug_chat']].to_numpy(copy=True)

# Randomly swap about half the pairs. Label 1 means stored as (original, augmented),
# -1 means stored as (augmented, original).
swap = np.random.rand(len(chats)) > 0.5
labels = np.where(swap, -1, 1)

to_shuffle = np.column_stack((swap, swap))
shuffled_chats = np.where(to_shuffle, chats[:, ::-1], chats)
del chats

## Constructing graph

### Nodes

In [None]:
import torch
from torch import Tensor
import torch.nn.functional as F

model = SentenceTransformer("paraphrase-MiniLM-L6-v2")

# Tokenize both chats of every pair, tracking the largest message count (nodes) and
# token length seen, so everything can be padded into one dense tensor.
tokenized = []
max_nodes = 0
max_sen_len = 0
for c1, c2 in tqdm(shuffled_chats, desc="Tokenizing"):
    pair_ids = [model.tokenize(c1)['input_ids'], model.tokenize(c2)['input_ids']]
    for ids in pair_ids:
        max_nodes = max(max_nodes, ids.size(0))
        max_sen_len = max(max_sen_len, ids.size(1))
    tokenized.append(pair_ids)

# Zero-pad each (messages x tokens) id matrix on the right and bottom to
# (max_nodes x max_sen_len).
# NOTE(review): ids pass through the default float32 tensor here; that is exact only
# while token ids stay below 2**24.
node_tensor = torch.zeros(len(tokenized), 2, max_nodes, max_sen_len)
for i, pair_ids in tqdm(enumerate(tokenized), desc="Padding"):
    padded = [
        F.pad(ids, (0, max_sen_len - ids.size(1), 0, max_nodes - ids.size(0)))
        for ids in pair_ids
    ]
    node_tensor[i] = torch.stack(padded)

node_tensor = node_tensor.long()
torch.save(node_tensor, f"{save_path}/nodes.pt")

### Labels

In [None]:
label_tensor = Tensor(labels)  # torch.Tensor(...) stores the ±1 labels as float32
torch.save(label_tensor, f"{save_path}/labels.pt")

### Edges

In [None]:
from torch import Tensor

def create_chat_graph(chat, max_edges):
    """Build the fully connected, typed edge list for one chat.

    Every ordered pair (ui, uj) of message positions, including self pairs,
    becomes an edge whose type is a 4-bit code:

    * 8: self edge (ui == uj); no other bits are set.
    * 4: ui comes after uj (ui > uj).
    * 2: ui is a human (even-position) turn.
    * 1: uj is a human (even-position) turn.

    Both lists are right-padded to ``max_edges`` with type 0 and index pair
    (0, 0). Nothing is truncated when the chat has more than ``max_edges``
    edges.

    NOTE(review): a forward edge between two non-human turns also has type 0,
    the same value used for padding, so edge type alone cannot tell them apart.

    Returns:
        Tuple ``(edge_types, edge_pairs)``, both of length
        ``max(max_edges, len(chat) ** 2)``.
    """
    n = len(chat)
    edge_types = []
    edge_pairs = []
    for ui in range(n):
        for uj in range(n):
            if ui == uj:
                code = 8
            else:
                code = 4 * (ui > uj) + 2 * (ui % 2 == 0) + (uj % 2 == 0)
            edge_types.append(code)
            edge_pairs.append((ui, uj))

    n_pad = max_edges - len(edge_types)
    return edge_types + [0] * n_pad, edge_pairs + [(0, 0)] * n_pad


max_edges = max_nodes**2

n_pairs = len(shuffled_chats)
# edges: (n_pairs, 2, max_edges) edge-type codes.
# edge_idxs: (n_pairs, 2, 2, max_edges) rows of source / target message positions.
edges = torch.zeros(n_pairs, 2, max_edges, dtype=torch.int32)
edge_idxs = torch.zeros(n_pairs, 2, 2, max_edges, dtype=torch.int64)

for i, pair in tqdm(enumerate(shuffled_chats), desc="Creating edges"):
    graphs = [create_chat_graph(chat, max_edges) for chat in pair]
    edges[i] = torch.tensor([types for types, _ in graphs], dtype=torch.int32)
    edge_idxs[i] = torch.tensor([idxs for _, idxs in graphs], dtype=torch.int64).transpose(1, 2)

torch.save(edges, f"{save_path}/edges.pt")
torch.save(edge_idxs, f"{save_path}/edge_idxs.pt")