In [None]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# reloaded before each cell executes, so edits to local source files take
# effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import json
import re
from pathlib import Path
from collections import defaultdict

import pandas as pd
import numpy as np

In [None]:
# Train/test sample files (JSON Lines) read by the cells below.
# The paths are relative, so they resolve against the kernel's working directory.
train_file = Path("out/experiments/finetune/depth_3/train_samples.jsonl")
test_file = Path("out/experiments/finetune/depth_3/test_samples.jsonl")

In [None]:
def count_nodes(file):
    """Count how often each node appears in the responses of a JSONL file.

    Every non-blank line of ``file`` is parsed as a JSON object. Its
    ``"response"`` value is split on newlines, and each non-empty response
    line is split on ``" > "`` into nodes. Every node occurrence is counted,
    so a node that appears on several response lines is counted once per
    appearance.

    Args:
        file: Path (``str`` or ``os.PathLike``) of the JSONL file to read.

    Returns:
        A ``defaultdict(int)`` mapping node text to its occurrence count.

    Raises:
        AssertionError: If a response contains no newline, or a non-empty
            response line has fewer than two ``" > "``-separated nodes.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record has no ``"response"`` key.
    """
    nodes = defaultdict(int)
    # Read as UTF-8 explicitly so non-ASCII node names do not depend on the
    # platform's default encoding.
    with open(file, "r", encoding="utf-8") as f:
        for record in f:
            # Tolerate blank lines (e.g. an extra newline at the end of the
            # file) instead of failing inside json.loads.
            if not record.strip():
                continue
            data = json.loads(record)["response"]
            response_lines = data.split("\n")
            assert len(response_lines) > 1, data
            for response_line in response_lines:
                if response_line == "":
                    continue
                chain = response_line.split(" > ")
                assert len(chain) >= 2, data
                for node in chain:
                    nodes[node] += 1
    return nodes

In [None]:
# Per-node occurrence counts for each split, as returned by count_nodes.
train_count = count_nodes(train_file)
test_count = count_nodes(test_file)

In [None]:
# Partition the nodes by which split(s) they occur in. Dict key views
# support set algebra directly and yield plain sets.
intersection = train_count.keys() & test_count.keys()
train_only = train_count.keys() - test_count.keys()
test_only = test_count.keys() - train_count.keys()


def print_top_samples(node_set, k):
    """Display the ``k`` most frequent nodes of ``node_set`` as a table.

    A node's frequency is its count in ``train_count`` plus its count in
    ``test_count`` (both module-level mappings; a node missing from one of
    them contributes 0). Rows are ordered by descending count, with ties
    broken by node name so the table is identical on every run.

    Args:
        node_set: Iterable of node names to rank. Names must be mutually
            orderable (e.g. all strings) for the tie-break.
        k: Maximum number of rows to show.

    Returns:
        None. The table is rendered with ``display``.
    """
    node_counts = {
        node: train_count.get(node, 0) + test_count.get(node, 0)
        for node in node_set
    }
    # Sort on (-count, name): iterating a set of strings has no stable order
    # across interpreter runs, so equal counts need an explicit tie-break.
    ranked = sorted(node_counts.items(), key=lambda item: (-item[1], item[0]))[:k]
    df = pd.DataFrame(
        {
            "node": [node for node, _ in ranked],
            "count": [count for _, count in ranked],
        }
    ).set_index("node")
    display(df)


# Show the 20 most frequent nodes of each group, one table per group.
for heading, group in (
    ("Intersection:", intersection),
    ("\nTrain only:", train_only),
    ("\nTest only:", test_only),
):
    print(heading)
    print_top_samples(group, 20)