In [62]:
import json
import os
import pandas as pd
from datetime import datetime
import pickle 
from matplotlib import pyplot as plt
from collections import Counter
import numpy as np
import torch
import logging
from tqdm import tqdm
import ast

from src.utils.tree_utils import *
from src.utils.utils import *
from src.dataset import *

In [2]:
out_dir = './data/aliceysu'
journalist = 'aliceysu'
with open(os.path.join(out_dir, f'{journalist}_dict.pkl'), 'rb') as f:
    data = pickle.load(f)

with open(os.path.join(out_dir, f'{journalist}_ids.pkl'), 'rb') as f:
    map_id = pickle.load(f)
    
with open(os.path.join(out_dir, f'{journalist}_lan.pkl'), 'rb') as f:
    map_lan = pickle.load(f)

with open(os.path.join(out_dir, f'{journalist}_type.pkl'), 'rb') as f:
    map_type = pickle.load(f)

with open(os.path.join(out_dir, f'{journalist}_reply.pkl'), 'rb') as f:
    map_reply = pickle.load(f)
    
with open(os.path.join(out_dir, f'{journalist}_edgeprob.pkl'), 'rb') as f:
    edgeprob = pickle.load(f)


In [None]:
def load_data(data_dir, journalist, classes, batch_size, collate):
    ### Data (normalize input inter-event times, then padding to create dataloaders)
    num_classes, num_sequences = classes, 0
    seq_dataset = []
    arr = []
    dp = []
    rel = []
    
    split = [64, 128]
    val = 0
    journal_sort = pd.read_csv((os.path.join(data_dir, f'{journalist}_context.csv')))
    ids = list(set(journal_sort['conversation_id']))
    id_pair = {}
    id_conv = {}
    for idx in ids:
        id_pair[idx], id_conv[idx] = create_conversation_list(journal_sort[journal_sort['conversation_id']==idx], idx)
    id_data, data, label = create_data(journal_sort, ids)
    prob = pkl.load(open(os.path.join(data_dir, f'{journalist}_edgeprob.pkl'), 'rb'))
    
    with open(os.path.join(data_dir, f'{journalist}_global_path.txt'), "r") as f:
        for line in tqdm(f, total=get_number_of_lines(f)):
            dp.append(json.loads(line.strip()))

    with open(os.path.join(data_dir, f'{journalist}_local_path.txt'), "r") as f:
        for line in tqdm(f, total=get_number_of_lines(f)):
            rel.append(json.loads(line.strip()))
    
    global_input = convert_global(dp, id_data)
    local_data = convert_local(rel)
    local_mat = generate_local_mat(local_data, id_data)
    local_input = create_mat(local_mat, mat_type='concat')
    logging.info(f'loaded split {journalist}...')
    # data - dict: dim_process, devtest, args, train, dev, test, index (train/dev/test given as)
    # data[split] - list dicts {'time_since_start': at, 'time_since_last_event': dt, 'type_event': mark} or
    # data[split] - dict {'arrival_times', 'delta_times', 'marks'}
    # data['dim_process'] = Number of accounts = 119,298
    # num_sequences: number of conversations of a journalist
    num_classes = classes
    num_sequences = len(set(journal_sort['conversation_id']))
    
    X_train, X_dev, X_test = data[:split[0]], data[split[0]:split[1]], data[split[1]:]
    prob_train, prob_dev, prob_test = prob[:split[0]], prob[split[0]:split[1]], prob[split[1]:]
    global_train, global_dev, global_test = global_input[:split[0]], global_input[split[0]:split[1]], global_input[split[1]:]
    local_train, local_dev, local_test = local_input[:split[0]], local_input[split[0]:split[1]], local_input[split[1]:]
    label_train, label_dev, label_test = label[:split[0]], label[split[0]:split[1]], label[split[1]:]

    d_train = TreeDataset(X_train, prob_train, global_train, local_train, label_train)
    d_val = TreeDataset(X_dev, prob_dev, global_dev, local_dev, label_dev)  
    d_test  = TreeDataset(X_test, prob_test, global_test, local_test, label_test)   

    # for padding input sequences to maxlen of batch for running on gpu, and arranging them by length efficient
    collate = collate  
    dl_train = torch.utils.data.DataLoader(d_train, batch_size=batch_size, shuffle=False, collate_fn=collate)
    dl_val = torch.utils.data.DataLoader(d_val, batch_size=batch_size, shuffle=False, collate_fn=collate)
    dl_test = torch.utils.data.DataLoader(d_test, batch_size=batch_size, shuffle=False, collate_fn=collate)
    return dl_train, dl_val, dl_test





In [400]:
train, val, test = load_data(out_dir, journalist, 3, 8, collate)


100%|██████████| 419/419 [00:00<00:00, 469.94it/s]
100%|██████████| 419/419 [00:00<00:00, 2197.12it/s]
  return np.array(result)


In [405]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SequenceClassifier(nn.Module):
    def __init__(self, original_input_dim, input_dim, hidden_dim, num_classes, num_layers, dropout=0.1):
        super(SequenceClassifier, self).__init__()
        self.original_input_dim = original_input_dim
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes
        self.num_layers = num_layers

        # Linear layer to adjust input dimension
        self.input_projection = nn.Linear(self.original_input_dim, self.input_dim)

        # Transformer Encoder
        encoder_layers = nn.TransformerEncoderLayer(d_model=self.input_dim, nhead=8, dim_feedforward=self.hidden_dim, dropout=dropout)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers=self.num_layers)

        # Classification Head
        self.classifier = nn.Linear(self.input_dim, self.num_classes)

    def forward(self, x, mask=None):
        # Project input to the desired dimension
        x = self.input_projection(x)

        # Transformer Encoder
        x = x.permute(1, 0, 2)  # Reshape x to [seq_length, batch_size, input_dim]
        if mask is not None:
            mask = mask.permute(1, 0)  # Reshape mask to [seq_length, batch_size]
            mask = mask.unsqueeze(1)  # Add dimension for heads
        x = self.transformer_encoder(x, src_key_padding_mask=mask)

        # Classification for each element in the sequence
        x = x.permute(1, 0, 2)  # Reshape back to [batch_size, seq_length, input_dim]
        logits = self.classifier(x)

        return logits

# Example usage
input_dim = 11  # Feature dimension of each element in the sequence
hidden_dim = 512  # Hidden dimension for transformer encoder
num_classes = 3  # Number of classes
num_layers = 4  # Number of transformer layers

model = SequenceClassifier(input_dim, hidden_dim, num_classes, num_layers)

# Example input and mask
example_input = torch.randn(8, 10, 11)
example_mask = torch.tensor([[1,1,1,0,0], [1,1,1,1,1]], dtype=torch.bool)
logits = model(example_input, example_mask)  # Output of shape [8, 10, 3] (batch_size, seq_length, num_classes)


AssertionError: embed_dim must be divisible by num_heads

In [406]:
example_input = torch.randn(8, 10, 11)

In [407]:
example_input.permute(1, 0, 2).size()

torch.Size([10, 8, 11])

In [21]:
num_classes, num_sequences = 0, 0
seq_dataset = []
arr = []
    
split = [64, 128]
val = 0
#data = pkl.load(open(os.path.join(out_dir, f'{journalist}_dict.pkl'), 'rb'))
#logging.info(f'loaded split aliceysu...')
num_classes = 3
num_sequences = len(set(data['conversation_id']))
journal = pd.DataFrame.from_dict(data)
#journal_sort = journal.sort_values(by=['created_at'])
#journal_batch = journal_sort[["type", "possibly_sensitive", "lang", "reply_settings",
#                              "retweet_count", "reply_count", "like_count", "quote_count", "impression_count",
#                              "mentions", "urls", "labels"]]
journal_sort = pd.read_csv(os.path.join(out_dir, f'{journalist}_context.csv'))

In [22]:
ids = list(set(journal_sort['conversation_id']))
id_pair = {}
id_conv = {}
for idx in ids:
    id_pair[idx], id_conv[idx] = create_conversation_list(journal_sort[journal_sort['conversation_id']==idx], idx)

In [5]:
with open(os.path.join(out_dir, f'{journalist}_global_path.txt')) as f:
    global_path = f.readlines()
    
with open(os.path.join(out_dir, f'{journalist}_local_path.txt')) as f:
    local_path = f.readlines()

In [None]:
def convert_path(g_list):
    g_dict = {}
    for item in g_list:
        temp = ast.literal_eval(item)
        for item in temp:
            k = list(item.keys())[0]
            g_dict[k] = item[k][k]
    return g_dict

In [7]:
global_dict = convert_path(global_path)

In [8]:
#journal_sort['global'] = [global_dict.get(str(id), global_dict[id]) for id in journal_sort['tweet_id']]
#journal_sort.to_csv(os.path.join(out_dir, f'{journalist}_context.csv'))
global_paths = [] # input
id_clean = {}
for i in ids:
    temp = []
    id_clean[i] = []
    for j, item in enumerate(id_conv[i]):
        if str(item) not in global_dict.keys():
            continue
        temp.append(global_dict[str(item)])
        id_clean[i].append(item)
    global_paths.append(temp)

In [54]:
# with open(os.path.join(out_dir, f'{journalist}_global_path.txt'), "w") as fout:
#     num_dps = 0
#     for k in id_pair.keys():
#         tree_root = build_tree(id_pair[k])
        
#         tree_root.create_global_relation()
#         node_list = tree_root.dfs()

#         root_paths = TreeNode.extract_data(node_list,f=lambda node: clamp_and_slice_ids(
#                 node.global_relation, max_width=-1, max_depth=-1))
#         asts = separate_dps(root_paths, n_ctx)

#         """for lr, extended in asts:
#             if extended != 0:
#                 break
#             if len(lr) - extended > 1:
#                 """
#         json.dump(root_paths, fp=fout)  # each line is the json of a list [dict,dict,...]
#         num_dps += 1
#         fout.write("\n")

In [218]:
# num_dps = 0
# with open(os.path.join(out_dir, f'{journalist}_local_path.txt'), "w") as fout:
#     for k in id_pair.keys():
#         tree_root = build_tree(id_pair[k])
        
#         tree_root.create_local_relation()
#         node_list = tree_root.dfs()

#         local_relation = TreeNode.extract_data(node_list,f=lambda node: clamp_and_slice_ids(
#                 node.local_relation, max_width=-1, max_depth=-1))
#         rel = separate_dps(local_relation, n_ctx)

#         """for lr, extended in rel:
#             if extended != 0:
#                 break
#             if len(lr) - extended > 1:"""
#         json.dump(local_relation, fp=fout)  # each line is the json of a list [dict,dict,...]
#         num_dps += 1
#         fout.write("\n")

In [15]:
out_fp = './result'
journalist = 'aliceysu'
n_ctx = 4000

In [18]:
batch_data = []
target_data = []
conv_data = []
ref_data = []
id_data = []
for idx in ids:
    convs = journal_sort[journal_sort['conversation_id'] == idx]
    convs_batch = convs[["type", "possibly_sensitive", "lang", "reply_settings",
                     "retweet_count", "reply_count", "like_count", "quote_count", "impression_count",
                     "mentions", "urls"]]
    conv_data.append(list(convs['conversation_id']))
    ref_data.append(list(convs['reference_id']))
    id_data.append(list(convs['tweet_id']))
    batch_data.append(convs_batch.values.tolist())
    target_data.append(list(convs['labels']))
    
label_data = target_data

In [19]:
position_data = []
for i, idx in enumerate(ids):
    tree_root = build_tree(id_pair[idx])
    node_info = get_node_info(tree_root)
    temp_list = []
    #temp_list.append([node_info[idx]['level'], node_info[idx]['number_of_siblings'], node_info[idx]['sibling_order']])
    for item in id_data[i]:
        """if item not in node_info.keys():
            continue"""
        temp_list.append([node_info[item]['level'], 
                    node_info[item]['number_of_siblings'], 
                    node_info[item]['sibling_order']])
    position_data.append(temp_list)

In [20]:
X_train, X_dev, X_test = batch_data[:split[0]], batch_data[split[0]:split[1]], batch_data[split[1]:]
pos_train, pos_dev, pos_test = position_data[:split[0]], position_data[split[0]:split[1]], position_data[split[1]:]
id_train, id_dev, id_test = id_data[:split[0]], id_data[split[0]:split[1]], id_data[split[1]:]
label_train, label_dev, label_test = label_data[:split[0]], label_data[split[0]:split[1]], label_data[split[1]:]

In [16]:
dp = []
with open(os.path.join(out_fp, f'{journalist}_global_path.txt'), "r") as f:
    for line in tqdm(f, total=get_number_of_lines(f)):
        dp.append(json.loads(line.strip()))
        
rel = []
with open(os.path.join(out_fp, f'{journalist}_local_path.txt'), "r") as f:
    for line in tqdm(f, total=get_number_of_lines(f)):
        rel.append(json.loads(line.strip()))

100%|██████████| 419/419 [00:01<00:00, 395.10it/s]
100%|██████████| 419/419 [00:00<00:00, 1766.06it/s]


In [19]:
roots = convert_global(dp, id_data)
local = convert_local(rel)

In [79]:
with open(os.path.join(out_dir, f'{journalist}_global_path.pkl'),'wb') as f:
    pickle.dump(roots, f)

In [236]:
## test
max_len = max(len(dp[0]) for dp in batch_data[:8])
max_depth = 12
max_width = 16
seqs = (id_data[:8], batch_data[:8], roots[:8])
r = []
for i in range(8):
    new_root = [roots[i][str(x)] for x in id_data[i]][:] + [[] for _ in range(max_len - len(id_data[i]))]
    r.append(new_root)
    
positions = generate_positions(r[1], max_width, max_depth)
position_seqs = []
positions = generate_positions(r[1], max_width=max_width, max_depth=max_depth)
position_seqs.append(positions.unsqueeze(0))

position_seqs[0].size()

KeyError: '4112'

In [123]:
local_mat = generate_local_mat(local, id_data)
local_input = create_mat(local_mat, mat_type='concat')



## test functions

In [None]:
def create_data(journal_sort, ids):
    batch_data = []
    target_data = []
    conv_data = []
    ref_data = []
    id_data = []
    for idx in ids:
        convs = journal_sort[journal_sort['conversation_id'] == idx]
        convs_batch = convs[['type', 'possibly_sensitive', 'lang', 'reply_settings', 
                               'retweet_count', 'reply_count', 'like_count', 'quote_count',
                                'impression_count', 'mentions', 'urls']]
        #conv_data.append(list(convs['conversation_id']))
        conv_data.append(convs_batch.to_numpy().tolist())
        ref_data.append(list(convs['reference_id']))
        id_data.append(list(convs['tweet_id']))
        batch_data.append(convs_batch.values.tolist())
        target_data.append(list(convs['labels']))
    
    label_data = target_data
    return id_data, conv_data, label_data

def load_data(data_dir, journalist, classes, batch_size, collate):
    ### Data (normalize input inter-event times, then padding to create dataloaders)
    num_classes, num_sequences = 0, 0
    seq_dataset = []
    arr = []
    dp = []
    rel = []
    
    split = [64, 128]
    val = 0
    journal_sort = pd.read_csv((os.path.join(data_dir, f'{journalist}_context.csv')))
    ids = list(set(journal_sort['conversation_id']))
    id_pair = {}
    id_conv = {}
    for idx in ids:
        id_pair[idx], id_conv[idx] = create_conversation_list(journal_sort[journal_sort['conversation_id']==idx], idx)
    id_data, data, label = create_data(journal_sort, ids)
    prob = pkl.load(open(os.path.join(data_dir, f'{journalist}_edgeprob.pkl'), 'rb'))
    
    with open(os.path.join(data_dir, f'{journalist}_global_path.txt'), "r") as f:
        for line in tqdm(f, total=get_number_of_lines(f)):
            dp.append(json.loads(line.strip()))

    with open(os.path.join(data_dir, f'{journalist}_local_path.txt'), "r") as f:
        for line in tqdm(f, total=get_number_of_lines(f)):
            rel.append(json.loads(line.strip()))
    
    global_input = convert_global(dp, id_data)
    local_data = convert_local(rel)
    local_mat = generate_local_mat(local_data, id_data)
    local_input = create_mat(local_mat, mat_type='concat')
    logging.info(f'loaded split {journalist}...')
    # data - dict: dim_process, devtest, args, train, dev, test, index (train/dev/test given as)
    # data[split] - list dicts {'time_since_start': at, 'time_since_last_event': dt, 'type_event': mark} or
    # data[split] - dict {'arrival_times', 'delta_times', 'marks'}
    # data['dim_process'] = Number of accounts = 119,298
    # num_sequences: number of conversations of a journalist
    num_classes = classes
    #num_sequences += len(data[split]['arrival_times'])
    num_sequences = len(set(journal_sort['conversation_id']))
    
    X_train, X_dev, X_test = data[:split[0]], data[split[0]:split[1]], data[split[1]:]
    prob_train, prob_dev, prob_test = prob[:split[0]], prob[split[0]:split[1]], prob[split[1]:]
    global_train, global_dev, global_test = global_input[:split[0]], global_input[split[0]:split[1]], global_input[split[1]:]
    local_train, local_dev, local_test = local_input[:split[0]], local_input[split[0]:split[1]], local_input[split[1]:]
    label_train, label_dev, label_test = label[:split[0]], label[split[0]:split[1]], label[split[1]:]

    d_train = TreeDataset(X_train, prob_train, global_train, local_train, label_train)
    d_val = TreeDataset(X_dev, prob_dev, global_dev, local_dev, label_dev)  
    d_test  = TreeDataset(X_test, prob_test, global_test, local_test, label_test)   

    # for padding input sequences to maxlen of batch for running on gpu, and arranging them by length efficient
    collate = collate  
    dl_train = torch.utils.data.DataLoader(d_train, batch_size=batch_size, shuffle=False, collate_fn=collate)
    dl_val = torch.utils.data.DataLoader(d_val, batch_size=batch_size, shuffle=False, collate_fn=collate)
    dl_test = torch.utils.data.DataLoader(d_test, batch_size=batch_size, shuffle=False, collate_fn=collate)
    return dl_train, dl_val, dl_test






In [None]:
class TreeNodes:
    def __init__(self, name):
        self.name = name
        self.children = []  # List of TreeNode objects
        self.level = 0  # Level of the node in the tree
        self.sibling_order = 0  # Order among siblings
        self.parent = None  # Parent of the node
        self.local_relation = dict()
        self.global_relation = dict()

    def add_child(self, child_node):
        child_node.parent = self
        child_node.level = self.level + 1 if self.level is not None else 0
        child_node.sibling_order = len(self.children)
        self.children.append(child_node)

    def num_siblings(self):
        return len(self.parent.children)-1 if self.parent else 0
    
    def extract_data(node_list, only_leaf=False, f=lambda node: node.data):
        ret = []
        #print("?")
        for node in node_list:
            if not (only_leaf and node.node_type == "type"):
                ret.append({node.name: f(node)})
        return ret

    def create_local_relation(self):

        def _dfs(node):
            for child in node.children:
                node_child_rel = [child.level, child.num_siblings(), child.sibling_order]
                node_father_rel = [node.level, node.num_siblings(), node.sibling_order]
                #node_father_rel = child.parent
                node.local_relation[child.name] = [node_child_rel, node_father_rel, 0]
                child.local_relation[node.name] = [node_child_rel, node_father_rel, 1]
                _dfs(child)

        _dfs(self)

    def create_global_relation(self):
        def g_dfs(node):
            node_rel = [node.level, node.num_siblings(), node.sibling_order]
            if not node.parent:
                node.global_relation[node.name] = [node_rel]
            else: 
                if node.parent.name not in node.parent.global_relation.keys():
                    node.global_relation[node.name] = node.parent.parent.global_relation[node.parent.parent.name] + [node_rel]
                else:
                    node.global_relation[node.name] = node.parent.global_relation[node.parent.name] + [node_rel]
            for child in node.children:
                g_dfs(child)

        g_dfs(self)
 

    def dfs(self):
        ret = []

        def _dfs(node, ret):
           #ret : List
            ret.append(node)
            for child in node.children:
                _dfs(child, ret)

        _dfs(self, ret)
        return ret
    
def build_tree(conversations):
    nodes = {}
    root = 0

    for parent, child in conversations:
        if parent not in nodes:
            nodes[parent] = TreeNodes(parent)
        if child not in nodes:
            nodes[child] = TreeNodes(child)

        nodes[parent].add_child(nodes[child])

        if not root:
            root = nodes[parent]

    return root


def separate_dps(ast, max_len):
    """
    Handles training / evaluation on long ASTs by splitting
    them into smaller ASTs of length max_len, with a sliding
    window of max_len / 2.

    Example: for an AST ast with length 1700, and max_len = 1000,
    the output will be:
    [[ast[0:1000], 0], [ast[500:1500], 1000], [ast[700:1700], 1500]]

    Input:
        ast : List[Dictionary]
            List of nodes in pre-order traversal.
        max_len : int

    Output:
        aug_asts : List[List[List, int]]
            List of (ast, beginning idx of unseen nodes)
    """
    half_len = int(max_len / 2)
    if len(ast) <= max_len:
        return [[ast, 0]]

    aug_asts = [[ast[:max_len], 0]]
    i = half_len
    while i < len(ast) - max_len:
        aug_asts.append([ast[i: i + max_len], half_len])
        i += half_len
    idx = max_len - (len(ast) - (i + half_len))
    aug_asts.append([ast[-max_len:], idx])

    return aug_asts


def separate_lrs(lrs, max_len):
    def reformat(lrs, left):  # [left,right)
        new_lrs = []
        for idx, lr in enumerate(lrs):
            # lr -> dict: {idx:[],idx:[]}
            temp_lr = dict()
            for key, val in lr.items():
                if left <= key < left + max_len:
                    temp_lr[key - left] = val
            new_lrs.append(temp_lr)
        return new_lrs

    half_len = int(max_len / 2)
    if len(lrs) <= max_len:
        return [[reformat(lrs, 0), 0]]

    aug_asts = [[reformat(lrs[:max_len], 0), 0]]
    i = half_len
    while i < len(lrs) - max_len:
        aug_asts.append([reformat(lrs[i: i + max_len], i), half_len])
        i += half_len
    idx = max_len - (len(lrs) - (i + half_len))
    aug_asts.append([reformat(lrs[len(lrs) - max_len:], len(lrs) - max_len), idx])
    return aug_asts




In [None]:
def convert_global(root_paths, id_data):
    roots = []
    global_new = []
    for i in range(len(root_paths)):
        new_dict = {}
        for item in root_paths[i]:
            name = list(item.keys())[0]
            new_dict[name] = np.array(list(list(item.values())[0].values())).squeeze().tolist()
        roots.append(new_dict)
    for i in range(len(id_data)):
        global_new.append([roots[i][str(k)] for k in id_data[i]])
    return global_new


def convert_local(local_rel):
    rel = []
    for i in range(len(local_rel)):
        new_dict = {}
        for item in local_rel[i]:
            name = list(item.keys())[0]
            new_dict[name] = item[name]
        rel.append(new_dict)
    return rel

def indexing(ls):
    dic = {}
    for i in range(len(ls)):
        dic[ls[i]] = i
        i += 1
    return dic

def generate_local_mat(local, idx):
    mat = []
    for ids, item in enumerate(local):
        #print(ids)
        temp = []
        ind = indexing(idx[ids])
        for i in idx[ids]:
            if str(i) not in list(local[ids].keys()):
                continue
            for k in list(local[ids][str(i)].keys()):
                if k == list(local[ids].keys())[0]:
                    temp_l = local[ids][str(i)][k]
                    temp_ind = ind[i]
                    temp.append([temp_ind, temp_ind, temp_l[temp_l[2]]])
                elif int(k) not in idx[ids]:
                    continue
                else:
                    temp_l = local[ids][str(i)][k]
                    temp_ind1 = ind[i]
                    temp_ind2 = ind[int(k)]
                    temp.append([temp_ind1, temp_ind2, temp_l[temp_l[2]]])
        if not temp:
            for i in idx[ids]:
                temp_ind = ind[i]
                temp.append([temp_ind, temp_ind, [0, 0, 0]])
        mat.append(temp)
    return mat
def create_mat(local_mat, mat_type):
    result = []
    for ind, item in enumerate(local_mat):
        max_row = max(i[0] for i in item)+1
        max_col = max(i[1] for i in item)+1
        if mat_type == 'sum':
            row = np.array(item)[:,0]
            col = np.array(item)[:,1]

            # taking data 
            data = np.array([sum(np.array(i)[2]) for i in item])

            # creating sparse matrix 
            sparseMatrix = csr_matrix((data, (row, col)), shape = (dim, dim)).toarray() 
            result.append(sparseMatrix)
        else:
            matrix = np.zeros((max_row, max_col, 3), dtype=float)
            for x in item:
                row, col, value = x
                matrix[row, col] = [i + 0.05 for i in value]
            result.append(matrix)
    return np.array(result)


In [None]:
def pad_sequences(sequences, max_dim=None, pad_token=0):
    # Determine the maximum sequence length
    max_length = max(len(seq) for seq in sequences)
    if max_dim is not None and max_length > max_dim:
        max_length = max_dim

    # Pad each sequence to the maximum length
    padded_sequences = np.array([np.pad(seq, ((0, max_length - len(seq)), (0, 0)), 
                                        mode='constant', constant_values=pad_token) 
                                 for seq in sequences])

    # Create attention masks
    attention_masks = np.array([[1 if token.any() else 0 for token in seq] 
                                for seq in padded_sequences])
    
    return padded_sequences, attention_masks

def pad_labels(labels, max_dim, pad_token=0):
    """Pad the label sequence to the maximum length."""
    max_length = max(len(seq) for seq in labels)
    if max_dim is not None and max_length > max_dim:
        max_length = max_dim
        
    return np.array([np.pad(seq, (0, max_length - len(seq)), mode='constant', constant_values=pad_token) 
                                 for seq in labels])

def pad_matrix(path, max_dim=None, pad_token=0):
    """Pad a 2D matrix to the specified max_length."""
    max_length = max(matrix.shape[0] for matrix in path)
    if max_dim is not None and max_length > max_dim:
        max_length = max_dim
        
    padded_matrices = []
    for matrix in path:
        truncated_matrix = matrix[:max_length, :max_length, :]
        padding = ((0, max(0, max_length - truncated_matrix.shape[0])), 
                   (0, max(0, max_length - truncated_matrix.shape[1])), 
                   (0, 0))
        adjusted_matrix = np.pad(truncated_matrix, pad_width=padding, mode='constant', constant_values=pad_token)
        padded_matrices.append(adjusted_matrix)

    return padded_matrices

def summ(paths):
    return [[list(map(sum, zip(*sub))) for sub in outer] for outer in paths]


class Batch():
    def __init__(self, data, labels, prob, global_path, local_path, masks):
        self.data = data
        self.labels = labels
        self.prob = prob
        self.global_path = global_path
        self.local_path = local_path
        self.masks = masks

def collate(batch):
    batch_li = [list(item) for item in batch]
    data_temp = [row[0] for row in batch_li]
    labels_temp = [torch.Tensor(row[1]) for row in batch_li]
    prob_temp = [torch.Tensor(row[2]) for row in batch_li]
    global_path_temp = [row[3] for row in batch_li]
    local_path_temp = [row[4] for row in batch_li]
    

    padded_data, masks = pad_sequences(data_temp, max_dim=2000, pad_token=0)
    #padded_labels = pad_labels(labels_temp, max_dim=2000, pad_token=0)
    #padded_prob, _ = pad_sequences(prob_temp, max_dim=2000, pad_token=0)
    padded_global, _ = pad_sequences(summ(global_path_temp), max_dim=2000, pad_token=0)
    padded_local = pad_matrix(local_path_temp, max_dim=2000, pad_token=0)

    data= torch.tensor(padded_data).to(torch.int64)
    #labels = torch.tensor(padded_labels).to(torch.int64)
    #prob = torch.tensor(padded_prob).to(torch.int64)
    global_path = torch.tensor(padded_global).to(torch.int64)
    local_path = torch.tensor(padded_local).to(torch.int64)
    labels = torch.nn.utils.rnn.pad_sequence(labels_temp, batch_first=True)
    prob = torch.nn.utils.rnn.pad_sequence(prob_temp, batch_first=True)
    #global_path = torch.nn.utils.rnn.pad_sequence(global_path_temp, batch_first=True)
    #local_path = torch.nn.utils.rnn.pad_sequence(local_path_temp, batch_first=True)
    #print(masks)
    
    #out_tweet_type = torch.nn.utils.rnn.pad_sequence(out_tweet_types, batch_first=True)
    #print("start")
    return Batch(data, labels, prob, global_path, local_path, torch.tensor(masks))


In [None]:
class TreeNode:
    def __init__(self, name):
        self.name = name
        self.children = []  # List of TreeNode objects
        self.level = 0  # Level of the node in the tree
        self.sibling_order = 0  # Order among siblings
        self.parent = None  # Parent of the node
        self.local_relation = dict()
        self.global_relation = dict()
        

    def add_child(self, child_node):
        child_node.parent = self
        child_node.level = self.level + 1 if self.level is not None else 0
        child_node.sibling_order = len(self.children)
        self.children.append(child_node)

    def num_siblings(self):
        return len(self.parent.children)-1 if self.parent else 0
    
    def extract_data(node_list, only_leaf=False, f=lambda node: node.data):
        ret = []
        for node in node_list:
            if not (only_leaf and node.node_type == "type"):
                ret.append(f(node))
        return ret

    def create_local_relation(self):
        def _dfs(node):
            for child in node.children:
                node_child_rel = [child.level, child.num_siblings(), child.sibling_order]
                node_father_rel = [node.level, node.num_siblings(), node.sibling_order]
                
                node.local_relation[child.name] = [node_child_rel, node_father_rel, 0]
                child.local_relation[node.name] = [node_child_rel, node_father_rel, 1]
                _dfs(child)

        _dfs(self)
    
    def create_global_relation(self):
        def g_dfs(node):
            node_rel = [node.level, node.num_siblings(), node.sibling_order]
            if not node.parent:
                node.global_relation[node.name] = [node_rel]
            else: 
                if node.parent.name not in node.parent.global_relation.keys():
                    node.global_relation[node.name] = node.parent.parent.global_relation[node.parent.parent.name] + [node_rel]
                else:
                    node.global_relation[node.name] = node.parent.global_relation[node.parent.name] + [node_rel]
            for child in node.children:
                g_dfs(child)

        g_dfs(self)
        #return 

    def dfs(self):
        ret = []

        def _dfs(node, ret):
           #ret : List
            ret.append(node)
            for child in node.children:
                _dfs(child, ret)

        _dfs(self, ret)
        return ret
    
def build_tree(conversations):
    nodes = {}
    root = 0

    for parent, child in conversations:
        if parent not in nodes:
            nodes[parent] = TreeNode(parent)
        if child not in nodes:
            nodes[child] = TreeNode(child)
        nodes[parent].add_child(nodes[child])

        if not root:
            root = nodes[parent]

    return root

def get_node_info(tree_root):
    node_info = {}

    def traverse(node):
        node_info[node.name] = {
            'level': node.level,
            'number_of_siblings': node.num_siblings(),
            'sibling_order': node.sibling_order,
        }
        for child in node.children:
            traverse(child)

    traverse(tree_root)
    return node_info


## Mukhil's code

In [2]:
# aliceysu
# bainjal
users = []
journalist = 'aliceysu'
user_id = '24709718'
data_dir = '../desktop/'
path = os.path.join(data_dir, f'tweets_in_{journalist}_started_convs.json')
# '../desktop/tweets_in_aliceysu_started_convs.json'
with open(path) as f:
    for line in f:
        #print(line)
        users.append(json.loads(line))
    #users = [json.loads(line) for line in f]

In [3]:
data = users[0]
user_ids = set()
for i in range(len(data)):  
    user_ids.add(data[i]['conversation_id'])
print(len(user_ids))

420


In [28]:
data[0]

{'possibly_sensitive': False,
 'lang': 'qme',
 'conversation_id': '1603565869673873408',
 'referenced_tweets': [{'type': 'replied_to', 'id': '1603565869673873408'}],
 'edit_history_tweet_ids': ['1608261522870136833'],
 'reply_settings': 'everyone',
 'created_at': '2022-12-29T00:40:01.000Z',
 'public_metrics': {'retweet_count': 0,
  'reply_count': 0,
  'like_count': 0,
  'quote_count': 0,
  'impression_count': 3},
 'entities': {'urls': [{'start': 10,
    'end': 33,
    'url': 'https://t.co/stR6weHTUO',
    'expanded_url': 'https://rumble.com/v21y3qs-covid-19-vaccines-what-they-are-how-they-work-and-possible-causes-of-injuri.html',
    'display_url': 'rumble.com/v21y3qs-covid-…',
    'images': [{'url': 'https://pbs.twimg.com/news_img/1623388282397630467/SN1aV-kE?format=jpg&name=orig',
      'width': 1280,
      'height': 720},
     {'url': 'https://pbs.twimg.com/news_img/1623388282397630467/SN1aV-kE?format=jpg&name=150x150',
      'width': 150,
      'height': 150}],
    'status': 200,
 

In [4]:
people = set()
tweets = []
for i in range(len(data)):
    if (data[i]['conversation_id'] in user_ids):
        people.add(data[i]['author_id'])
        tweets.append(data[i]['author_id'])
print(len(people))
print(len(tweets))

6925
10685


In [5]:
######### INITIALISAING ACTIVITY TRACE DICTIONARY #########
activity_traces = {}
for element in user_ids:
    activity_traces[element] = []

In [6]:
date_format = "%Y-%m-%dT%H:%M:%S" 
datetime.strptime(data[0]['created_at'][:19], date_format)

datetime.datetime(2022, 12, 29, 0, 40, 1)

In [7]:
######### ACCUMLATING TWEETS UNDER THEIR CONVERSATION ID #########
count = 0
for i in range(len(data)):
    conversation_id = data[i]['conversation_id']
    if conversation_id in user_ids:
        if 'referenced_tweets' in data[i]:
            if data[i]['referenced_tweets'][0]['type'] == 'replied_to':
                if [data[i]['id'], datetime.strptime(data[i]['created_at'][:19], date_format), data[i]['author_id'], 1] in activity_traces[conversation_id]:
                    count += 1
                else:
                    activity_traces[conversation_id].append([data[i]['id'], datetime.strptime(data[i]['created_at'][:19], date_format), data[i]['author_id'], 1])
            elif data[i]['referenced_tweets'][0]['type'] == 'quoted':
                if [data[i]['id'], datetime.strptime(data[i]['created_at'][:19], date_format), data[i]['author_id'], 2] in activity_traces[conversation_id]:
                    count += 1
                else:
                    activity_traces[conversation_id].append([data[i]['id'], datetime.strptime(data[i]['created_at'][:19], date_format), data[i]['author_id'], 2])
        else:
            if [data[i]['id'], datetime.strptime(data[i]['created_at'][:19], date_format), data[i]['author_id'], 0] in activity_traces[conversation_id]:
                count += 1
            else:
                activity_traces[conversation_id].append([data[i]['id'], datetime.strptime(data[i]['created_at'][:19], date_format), data[i]['author_id'], 0])
            
# 0 for original tweet, 1 for reply to another tweet, 2 for quoted tweet.
# 54607 duplicated
print(count)

0


In [15]:
activity_traces['1603565869673873408']

[['1608261522870136833',
  datetime.datetime(2022, 12, 29, 0, 40, 1),
  '1519143967886954496',
  1],
 ['1603751795553361921',
  datetime.datetime(2022, 12, 16, 13, 59, 58),
  '2411054094',
  1],
 ['1603566654969245696',
  datetime.datetime(2022, 12, 16, 1, 44, 18),
  '51081835',
  1]]

In [16]:
######### SORTING WITHIN AN ACTIVITY TRACE BASED ON TIME #########
sorted_activity_trace_dirty = []
for key in activity_traces:
    sorted_activity_trace_dirty.append(sorted(activity_traces[key], key=lambda d: d[1]))

In [17]:
######### REMOVING ACTIVITY TRACES THAT ARE OF LENGTH ONE OR LESS #########
index = []
sorted_activity_trace = []
for i in range(len(sorted_activity_trace_dirty)):
    if len(sorted_activity_trace_dirty[i]) > 1:
        sorted_activity_trace.append(sorted_activity_trace_dirty[i]) 

In [27]:
len(sorted_activity_trace_dirty)

420

In [18]:
######### CREATING USER ID MAPPING AND THE REVERSE MAPPING TO OBTAIN INPUT MARKS VALUES AND DECODE THEM #########
ents = []
for actTrace in sorted_activity_trace:
    for act in actTrace:
        ents.append(act[2])

idmap = {}
other_way = {}
for idx, ent in enumerate(set(ents)):
    idmap[ent] = idx
    other_way[idx] = ent



In [19]:
numTraces = len(sorted_activity_trace)

np.random.shuffle(sorted_activity_trace)

valTraces = sorted_activity_trace[:int(0.1*numTraces)]
testTraces = sorted_activity_trace[int(0.1*numTraces):int(0.2*numTraces)]
trainTraces = sorted_activity_trace

In [21]:
split = {}
cols = ['arrival_times', 'delta_times', 'marks', 'tweet_type']

for col in cols:
    split[col] = []

tweetSeqMap = {}
lengths = []

for idx, activityTrace in enumerate(trainTraces): 
    #assert int(activityTrace[0][2]) == userId
    
    starts = [act[1] for act in activityTrace]
    iStart = starts[0]
    
    normStarts = [ (act[1] - iStart).total_seconds()/(24*3600) for act in activityTrace]
    assert normStarts[0] == 0
    deltaTimes = [1.0] 
    for i in range(1,len(normStarts)):
        deltaTimes.append(normStarts[i] - normStarts[i-1] )
        
    marks = []
    for i in range(len(normStarts)):
        marks.append(idmap[activityTrace[i][2]])
    
    tweet_type = []
    for i in range(len(normStarts)):
        tweet_type.append(activityTrace[i][3])
        
    tweetSeqMap[activityTrace[0][0]] = idx
    lengths.append(len(normStarts))
    split['arrival_times'].append(normStarts)
    split['delta_times'].append(deltaTimes)
    split['marks'].append(marks)
    split['tweet_type'].append(tweet_type)

In [22]:
split1 = {}
for col in cols:
    split1[col] = []
for idx, activityTrace in enumerate(testTraces): 
    #assert int(activityTrace[0][2]) == userId
    
    starts = [act[1] for act in activityTrace]
    iStart = starts[0]
    
    normStarts = [ (act[1] - iStart).total_seconds()/(24*3600) for act in activityTrace]
    assert normStarts[0] == 0
    deltaTimes = [1.0] 
    for i in range(1,len(normStarts)):
        deltaTimes.append( normStarts[i] - normStarts[i-1] )
        
    marks = []
    for i in range(len(normStarts)):
        marks.append(idmap[activityTrace[i][2]])
        
    tweet_type = []
    for i in range(len(normStarts)):
        tweet_type.append(activityTrace[i][3])
        
    tweetSeqMap[activityTrace[0][0]] = idx
    
    split1['arrival_times'].append(normStarts)
    split1['delta_times'].append(deltaTimes)
    split1['marks'].append(marks)
    split1['tweet_type'].append(tweet_type)

In [25]:
split2 = {}

for col in cols:
    split2[col] = []
for idx, activityTrace in enumerate(valTraces): 
    #assert int(activityTrace[0][2]) == userId
    
    starts = [act[1] for act in activityTrace]
    iStart = starts[0]
    
    normStarts = [ (act[1] - iStart).total_seconds()/(24*3600) for act in activityTrace]
    assert normStarts[0] == 0
    deltaTimes = [1.0] 
    for i in range(1,len(normStarts)):
        deltaTimes.append( normStarts[i] - normStarts[i-1] )
        
    marks = []
    for i in range(len(normStarts)):
        marks.append(idmap[activityTrace[i][2]])
        
    tweet_type = []
    for i in range(len(normStarts)):
        tweet_type.append(activityTrace[i][3])
    tweetSeqMap[activityTrace[0][0]] = idx

    split2['arrival_times'].append(normStarts)
    split2['delta_times'].append(deltaTimes)
    split2['marks'].append(marks)
    split2['tweet_type'].append(tweet_type)

In [24]:
data_result = {}
data_result["train"] = split
data_result["test"] = split1
data_result["dev"] = split2
data_result['dim_process'] = len(ents)

In [20]:
data_result['dev']

{'arrival_times': [[0.0, 2.527210648148148],
  [0.0, 0.0014467592592592592],
  [0.0, 0.03166666666666667, 30.01508101851852, 437.3130092592593],
  [0.0, 0.14180555555555555],
  [0.0,
   0.006423611111111111,
   0.01766203703703704,
   0.02746527777777778,
   0.07740740740740741,
   0.08822916666666666,
   0.19037037037037038,
   0.4257523148148148,
   0.5247685185185185,
   0.5637152777777777,
   0.66875,
   1.6597337962962964,
   1.6610763888888889,
   1.7932291666666667],
  [0.0,
   0.09506944444444444,
   0.3719675925925926,
   0.421875,
   0.8337037037037037,
   1.1365277777777778,
   3.1604282407407407,
   3.6797569444444442,
   6.233449074074074],
  [0.0, 0.22703703703703704, 1.990324074074074, 2.297083333333333],
  [0.0,
   0.0030555555555555557,
   0.3507986111111111,
   0.6315046296296296,
   2.8250462962962963],
  [0.0,
   0.0012731481481481483,
   0.004340277777777778,
   0.005613425925925926,
   0.012326388888888888,
   0.01238425925925926,
   0.014293981481481482,
   0.014

In [1]:
import pickle
with open ('kdd_data/attempt_dev.pkl', 'rb') as f:
    att = pickle.load(f)

In [4]:
att['dev'].keys()

dict_keys(['arrival_times', 'delta_times', 'marks', 'tweet_type'])

In [150]:
max_len = 0
for i in split['arrival_times']:
    if len(i) > max_len:
        max_len = len(i)

In [151]:
max_len

3252