# Twitch-PT

In [1]:
import pandas as pd
import networkx as nx
from methods import find_groups
from ablation_metrics import compare_subgroups, evaluate_pattern_subgroups
import matplotlib.pyplot as plt
from torch_geometric.datasets import Twitch

In [2]:
# Load the Twitch dataset named 'PT', with ./data/Twitch passed as the
# loader's root directory.
dataset = Twitch(root='./data/Twitch', name='PT')

# Take the first item of the dataset.
# NOTE(review): presumably this dataset holds a single graph — confirm.
data = dataset[0]

In [3]:
class CustomGraphDataset:
    """Expose each item of a wrapped dataset as a ``(fields, labels)`` pair.

    The wrapped ``dataset`` must support ``len()`` and integer indexing, and
    each item must provide the attributes ``edge_index``, ``num_nodes``,
    ``x`` and ``y``.
    """

    def __init__(self, dataset):
        # Only a reference is stored; the wrapped dataset is not copied.
        self.dataset = dataset

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(fields, labels)`` for the item at position ``idx``.

        ``fields`` is a dict with the keys ``'edge_index'``, ``'num_nodes'``
        and ``'node_feat'`` (the item's ``x`` attribute); ``labels`` is the
        item's ``y`` attribute. The values are passed through unchanged.
        """
        sample = self.dataset[idx]
        fields = dict(
            edge_index=sample.edge_index,
            num_nodes=sample.num_nodes,
            node_feat=sample.x,
        )
        return fields, sample.y

# Wrap the loaded dataset and unpack its first item into the field dict
# (`graph`) and the label object (`label`).
custom_dataset = CustomGraphDataset(dataset)
graph, label = custom_dataset[0]

In [4]:
# Build an undirected NetworkX graph from the loaded graph: nodes are the
# integers 0 .. num_nodes - 1, edges come from the columns of edge_index.
num_nodes = data.num_nodes
edge_index = data.edge_index

# edge_index.tolist() gives [sources, targets]; zipping the rows pairs them
# into (source, target) tuples.
# NOTE(review): assumes edge_index has exactly two rows — confirm.
edges = list(zip(*edge_index.tolist()))

G = nx.Graph()
# Add every node up front so nodes without any edge are still present.
G.add_nodes_from(range(num_nodes))
G.add_edges_from(edges)

In [5]:
# Display the labels returned for the first dataset item.
label

tensor([0, 0, 0,  ..., 1, 0, 1])

In [None]:
# Build the table passed to find_groups: the node features as columns plus a
# boolean 'target' column.
# NOTE(review): presumably one row per node — confirm the shape of node_feat.
attributes = graph['node_feat']

# Convert the tensors to NumPy arrays explicitly instead of relying on pandas
# to coerce a foreign array type implicitly. detach()/cpu() keep the
# conversion valid even if a tensor tracks gradients or lives on another device.
lu = pd.DataFrame(attributes.detach().cpu().numpy())

# True where the label equals 1.
lu['target'] = (label == 1).detach().cpu().numpy()
lu.head()

In [None]:
# Run find_groups on the graph and the node table with ablation_mode off.
# NOTE(review): the meaning of the positional argument 20 is not visible in
# this file — confirm against methods.find_groups.
result_normal = find_groups(G, 20, lu, ablation_mode=False)
result_normal

In [None]:
# Write the normal-mode result to 'subgroups_twitch.csv' (path is relative to
# the current working directory).
# NOTE(review): only the normal-mode result is saved in the visible cells;
# confirm whether the ablation result should be persisted as well.
result_normal.to_csv('subgroups_twitch.csv')

In [None]:
# Same find_groups call as the normal run, but with ablation_mode on.
# NOTE(review): what ablation_mode disables is not visible in this file —
# confirm against methods.find_groups.
result_ablation = find_groups(G, 20, lu, ablation_mode=True)
result_ablation

In [None]:
# Compare the normal and ablation results with k=10.
# NOTE(review): presumably k restricts the comparison to the top-10 subgroups
# of each run — confirm against ablation_metrics.compare_subgroups.
comparison = compare_subgroups(result_normal, result_ablation, k=10)
comparison

In [None]:
# Evaluate both result sets against the node table with k=10. Each call
# returns a pair, unpacked here as a mean ratio and a second value.
# NOTE(review): presumably the second value holds the per-subgroup ratios —
# confirm against ablation_metrics.evaluate_pattern_subgroups.
mean_ratio_normal, dist_ratios_normal = evaluate_pattern_subgroups(result_normal, lu, k=10)
mean_ratio_ablation, dist_ratios_ablation = evaluate_pattern_subgroups(result_ablation, lu, k=10)

# Report the mean ratio of each run; the messages hard-code "top-10", which
# matches the k=10 passed above.
print("Avg fraction of target=1 in top-10 normal subgroups: ", mean_ratio_normal)
print("Avg fraction of target=1 in top-10 ablation subgroups:", mean_ratio_ablation)

In [None]:
# Overlay the 'q' column of both runs as semi-transparent histograms on a
# single figure.
plt.figure(figsize=(8, 5))

# Draw the normal run first, then the ablation run, with identical settings.
for frame, run_name in ((result_normal, 'Normal'), (result_ablation, 'Ablation')):
    plt.hist(frame['q'], bins=30, alpha=0.5, label=run_name)

plt.title('Distribution of Q in Normal vs. Ablation')
plt.xlabel('Q Score')
plt.ylabel('Frequency')
plt.legend()
plt.show()