In [2]:
from collections import Counter
from itertools import combinations
from typing import Sequence, Tuple
import numpy as np
import tqdm
from examples.cmv.reddit_conversation_parser import CMVConversationReader
from stance_classification.utils import iter_trees_from_lines

# Machine-specific absolute location of the conversation-trees file.
path = r"C:\Users\ronp\Documents\stance-classification\trees_2.0.txt"

# One full pass over the file purely to count the trees, so that the progress
# bar created below can display a total.
total_trees = sum(map(lambda _tree: 1, iter_trees_from_lines(path)))

# A second, fresh iterator over the same file, wrapped in a progress bar.
# The cell that parses conversations consumes it.
trees = tqdm.tqdm(iterable=iter_trees_from_lines(path), total=total_trees)


  0%|          | 0/16306 [00:00<?, ?it/s]

In [3]:
# Reader whose `parse` method is applied to every tree below.
cmv_reader = CMVConversationReader()
# map() is lazy: nothing is parsed until `conversations` is iterated, and the
# resulting iterator can be consumed only once. Re-create it (and `trees`)
# before re-running any cell that loops over it, otherwise that loop sees no
# items.
conversations = map(cmv_reader.parse, trees)


In [None]:
# Users excluded from the index and from the pair statistics (per the original
# note: anyone who participated in more than 1k conversations).
IGNORE_USERS = {"[deleted]", "DeltaBot", "Ansuz07", "cdb03b", "PreacherJudge", "Iswallowedafly"}

users_indices = {}        # author -> integer id, assigned in first-seen order
users_counts = Counter()  # author -> occurrences across participants (ignored users included)
pair_counts = Counter()   # (id, id) tuple in ascending order -> times the pair was seen

for conversation in conversations:
    authors = conversation.participants
    users_counts.update(authors)

    # Drop ignored users, map the rest to their ids (allocating a new id on
    # first sight), and sort so that every pair has one canonical orientation.
    sorted_authors = sorted(
        users_indices.setdefault(author, len(users_indices))
        for author in authors
        if author not in IGNORE_USERS
    )

    pairs = combinations(sorted_authors, 2)
    pair_counts.update(pairs)

print("total number of users:", len(users_indices))
print("total number of pairs:", len(pair_counts))

  9%|▉         | 1468/16306 [00:18<04:57, 49.83it/s] 

In [None]:
from matplotlib import pyplot as plt

def calculate_ticks(max_rank: int) -> Tuple[Sequence[int], Sequence[str]]:
    """Compute evenly spaced x-axis ticks and compact labels for a rank axis.

    Ticks are multiples of a power of ten, starting at 0 and not exceeding
    ``max_rank``. Each label is the tick's multiplier followed by the shared
    exponent, e.g. ``"2e+4"`` for the tick at 20000.

    Args:
        max_rank: The largest rank to cover (typically the number of plotted
            items). Values below 1 yield a single tick at 0.

    Returns:
        A ``(ticks, ticks_labels)`` pair: a list of integer tick positions and
        a list of the matching label strings, both of the same length.
    """
    max_rank = int(max_rank)
    if max_rank < 1:
        # log10 is undefined here; a lone tick at the origin is all we can offer.
        return [0], ["0e+0"]

    # Integer arithmetic throughout: the exponent must be an int so that the
    # interval, the tick count passed to range(), and the label suffix are all
    # integral (no "e+4.0" labels, no float passed to range()).
    num_decimal_digits = len(str(max_rank)) - 1
    ticks_interval = 10 ** num_decimal_digits

    # If the coarse interval would leave only one or two ticks, step down one
    # order of magnitude. The comparison is against max_rank (the function's
    # own input), and the interval is never reduced below 1 to avoid a zero
    # divisor for single-digit ranks.
    if ticks_interval * 2 >= max_rank and num_decimal_digits > 0:
        ticks_interval //= 10
        num_decimal_digits -= 1

    num_ticks = 1 + (max_rank // ticks_interval)
    ticks = [i * ticks_interval for i in range(num_ticks)]

    label_suffix = f"e+{num_decimal_digits}"
    ticks_labels = [f"{i}{label_suffix}" for i in range(num_ticks)]

    return ticks, ticks_labels

def plot_count_bins(counts: Sequence[int], pretty_xticks: bool = True):
    """Plot log2 of the counts against their rank, largest count first.

    Args:
        counts: Positive counts to plot (log2 is taken of each value).
        pretty_xticks: When True, replace the default x ticks with the
            power-of-ten ticks and labels produced by ``calculate_ticks``.

    The figure is drawn on the current pyplot figure and shown immediately;
    nothing is returned.
    """
    # ndarray.sort() sorts ascending only and has no `reverse` keyword, so
    # sort a copy ascending and flip it to get descending order.
    log_counts = np.sort(np.log2(counts))[::-1]
    ranks = np.arange(len(log_counts))
    plt.plot(ranks, log_counts)

    plt.xlabel("rank")
    plt.ylabel("log counts")

    if pretty_xticks:
        ticks, ticks_labels = calculate_ticks(len(counts))
        plt.xticks(ticks, ticks_labels)

    plt.show()


In [None]:
# One entry per distinct author pair: the number of times that pair was counted.
counts = list(pair_counts.values())
plot_count_bins(counts)



