# GNN Inference Notebook
Loads the saved artifacts (model + config) and demonstrates batch toxicity prediction on a comment graph.

## 1. Load Artifacts

In [None]:
import json, torch
from pathlib import Path
from torch_geometric.nn import HeteroConv, SAGEConv, Linear

# Locate the training artifacts and fail fast with a clear message if absent.
art_dir = Path('artifacts')
if not art_dir.exists():
    # Explicit raise rather than assert: asserts are stripped under `python -O`.
    raise FileNotFoundError('artifacts/ missing. Train the model first.')

# The context manager guarantees the config file handle is closed.
with open(art_dir / 'config.json', encoding='utf-8') as config_file:
    config = json.load(config_file)

# NOTE: torch.load unpickles the checkpoint; only load files from a trusted
# source. The checkpoint is mapped to CPU so no GPU is required here.
state = torch.load(art_dir / 'model_best.pt', map_location='cpu')

class HeteroGNN(torch.nn.Module):
    """One-layer heterogeneous GraphSAGE classifier for comment nodes.

    A single ``HeteroConv`` layer holds one ``SAGEConv`` per relation
    (``user -authored-> comment`` and ``comment -replies_to-> comment``),
    combined with ``aggr='mean'``. A linear head then maps the ReLU-activated
    comment representation to class logits.

    Args:
        hidden: Output width of each ``SAGEConv`` and input width of the head.
        out_classes: Number of logits produced per comment node.
    """

    def __init__(self, hidden, out_classes):
        super().__init__()
        # (-1, -1) leaves the source/target input widths to be inferred lazily.
        relation_convs = {
            ('user', 'authored', 'comment'): SAGEConv((-1, -1), hidden),
            ('comment', 'replies_to', 'comment'): SAGEConv((-1, -1), hidden),
        }
        self.conv1 = HeteroConv(relation_convs, aggr='mean')
        self.lin = Linear(hidden, out_classes)

    def forward(self, x_dict, edge_index_dict):
        """Return per-comment class logits.

        Args:
            x_dict: Node features keyed by node type.
            edge_index_dict: Edge indices keyed by relation triple.
        """
        hidden_by_type = self.conv1(x_dict, edge_index_dict)
        comment_hidden = hidden_by_type['comment'].relu()
        return self.lin(comment_hidden)

# Rebuild the network with the sizes recorded in the config, restore the
# weights stored under 'model_state' in the checkpoint, and switch to eval mode.
model = HeteroGNN(
    hidden=config['gnn_hidden'],
    out_classes=config['gnn_out_classes'],
)
model.load_state_dict(state['model_state'])
model.eval()
print('Loaded model. Hidden:', config['gnn_hidden'])

## 2. Reconstruct Minimal Graph for Inference

In [None]:
import pandas as pd, numpy as np
from torch_geometric.data import HeteroData

csv_path = config['data_csv']
df = pd.read_csv(csv_path)

# Replicate the label engineering when the CSV does not already carry it.
if 'ToxicBinary' not in df.columns:
    # A missing ToxicLabel column must fall through to the score threshold
    # instead of raising KeyError.
    label_col = df['ToxicLabel'] if 'ToxicLabel' in df.columns else None
    # Accept both the legacy object dtype and pandas' dedicated string dtypes.
    has_text_labels = label_col is not None and (
        label_col.dtype == object or pd.api.types.is_string_dtype(label_col)
    )
    if has_text_labels:
        # na=False: missing labels count as non-toxic rather than producing
        # NaN, which would make astype(int) fail.
        df['ToxicBinary'] = (
            label_col.str.lower().str.startswith('toxic', na=False).astype(int)
        )
    else:
        # Otherwise derive the label from the score with a 0.7 cut-off.
        df['ToxicBinary'] = (df['ToxicScore'] > 0.7).astype(int)

# Create trivial embeddings placeholder (zeros) -> real pipeline would reload
# saved embeddings.
# NOTE(review): lin.in_channels is the hidden width, and the user feature
# width is hard-coded to 2; both presumably need to match the input widths
# the loaded conv1 weights were trained with -- confirm against training.
emb_dim = model.lin.in_channels if hasattr(model.lin, 'in_channels') else config['gnn_hidden']

# Stringify the identifiers once and reuse the same values for both the
# lookup tables and the edge lists. Looking up raw row values in maps keyed
# by strings raises KeyError for numeric or missing identifiers.
comment_ids = df[config['comment_id_col']].astype(str).tolist()
user_keys = df[config['user_col']].astype(str).tolist()

comments_emb = np.zeros((len(comment_ids), emb_dim), dtype='float32')

users = df[config['user_col']].astype(str).unique().tolist()
user_feat = np.zeros((len(users), 2), dtype='float32')

cid_map = {c: i for i, c in enumerate(comment_ids)}
user_map = {u: i for i, u in enumerate(users)}

hetero_inf = HeteroData()
hetero_inf['comment'].x = torch.tensor(comments_emb)
hetero_inf['comment'].y = torch.tensor(df['ToxicBinary'].values)
hetero_inf['user'].x = torch.tensor(user_feat)

# authored edges: one (user -> comment) edge per CSV row
src = [user_map[u] for u in user_keys]
dst = [cid_map[c] for c in comment_ids]
# dtype is pinned so an empty frame still yields an integer (2, 0) index.
hetero_inf['user', 'authored', 'comment'].edge_index = torch.tensor(
    [src, dst], dtype=torch.long
)
# no reply edges in minimal reconstruction

# Inference only: skip autograd bookkeeping.
with torch.no_grad():
    logits = model(hetero_inf.x_dict, hetero_inf.edge_index_dict)
    # Column 1 of the softmax is used as the toxic-class probability.
    probs = torch.softmax(logits, dim=1)[:, 1].numpy()
print('Pred shape:', probs.shape)

## 3. Attach Predictions & Export

In [None]:
import pandas as pd, time
# Keep only the identifying columns, then attach the model outputs.
id_columns = [config['comment_id_col'], config['user_col']]
pred_df = df[id_columns].copy()
pred_df = pred_df.assign(PredToxicProb=probs)
# Hard label: probability strictly above 0.5.
pred_df = pred_df.assign(
    PredToxicLabel=lambda frame: (frame['PredToxicProb'] > 0.5).astype(int)
)

# Timestamp the output file name (second resolution).
stamp = time.strftime('%Y%m%d_%H%M%S')
out_path = f'predictions_{stamp}.csv'
pred_df.to_csv(out_path, index=False)
print('Wrote', out_path)
pred_df.head()