<a href="https://colab.research.google.com/github/TahaErr/ACV_Term_Project/blob/main/PPE_GNN_Preparation_Data_Split.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#Data save path from PPE_DATA FACTORY
#DATASET_SAVE_DIR = '/content/drive/Othercomputers/Dizüstü Bilgisayarım/Google drive Masaustu/0-PPE/0-Data Factory outputs'


In [None]:
!pip install ultralytics

Collecting ultralytics
  Downloading ultralytics-8.3.242-py3-none-any.whl.metadata (37 kB)
Collecting ultralytics-thop>=2.0.18 (from ultralytics)
  Downloading ultralytics_thop-2.0.18-py3-none-any.whl.metadata (14 kB)
Downloading ultralytics-8.3.242-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m38.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading ultralytics_thop-2.0.18-py3-none-any.whl (28 kB)
Installing collected packages: ultralytics-thop, ultralytics
Successfully installed ultralytics-8.3.242 ultralytics-thop-2.0.18


In [None]:
import os, json, glob
import pandas as pd
from collections import Counter
from tqdm import tqdm

# Folder that is globbed for "*.json" graph files further down in this cell.
DATA_DIR = "/content/drive/Othercomputers/Dizüstü Bilgisayarım/Google drive Masaustu/0-PPE/0-Data Factory outputs"

# Edge type ids (must be THE SAME as in the code that previously generated the JSON files)
EDGE_SELF_LOOP     = 0
EDGE_WEAR_PPE      = 1
EDGE_NEAR_MACHINE  = 2

def safe_int(x, default=-1):
    """Convert `x` with int(); hand back `default` when the conversion raises."""
    try:
        converted = int(x)
    except Exception:
        converted = default
    return converted

def _near_proximity(attr):
    """Read the proximity value from one edge_attr entry.

    Accepts the nested form ``[v]`` (first element is used) as well as the
    flat scalar form ``v``. Returns None when no numeric value can be read.
    """
    if isinstance(attr, (list, tuple)):
        if not attr:
            return None
        attr = attr[0]
    try:
        return float(attr)
    except (TypeError, ValueError, OverflowError):
        return None


def parse_one_graph(g):
    """
    Flatten one graph JSON dict into summary records.

    g: dict loaded from one JSON graph file
    returns:
      graph_row: one row per graph (ids, label, node counts, edge-type counts,
                 near-edge count and average near-edge proximity or None)
      person_rows: rows per person_state
      machine_labels: list of machine label strings
      unexpected_edge_ids: set of unexpected edge ids seen in this graph
    """
    src = str(g.get("source_video", "unknown"))
    label = safe_int(g.get("label", -1))
    frame_id = safe_int(g.get("frame_id", -1))

    has_person = bool(g.get("has_person", False))
    num_persons = safe_int(g.get("num_persons", 0))
    num_machines = safe_int(g.get("num_machines", 0))

    # Machine subclasses are read from node_types entries of the form "machine:<label>".
    # `or []` also covers a key that is present but null.
    node_types = g.get("node_types") or []
    machine_labels = [
        nt.split("machine:", 1)[1]
        for nt in node_types
        if isinstance(nt, str) and nt.startswith("machine:")
    ]

    # Edge type distribution
    et = g.get("edge_type") or []
    et_counter = Counter(et)

    allowed_edge_ids = {EDGE_SELF_LOOP, EDGE_WEAR_PPE, EDGE_NEAR_MACHINE}
    unexpected_edge_ids = {e for e in et_counter if e not in allowed_edge_ids}

    # Near edges and their proximity (taken from edge_attr). Proximities are only
    # collected when edge_type and edge_attr line up one-to-one.
    near_edge_count = et_counter.get(EDGE_NEAR_MACHINE, 0)
    edge_attr = g.get("edge_attr") or []
    near_prox_vals = []
    if et and edge_attr and len(et) == len(edge_attr):
        for t, a in zip(et, edge_attr):
            if t != EDGE_NEAR_MACHINE:
                continue
            prox = _near_proximity(a)
            if prox is not None:
                near_prox_vals.append(prox)

    avg_near_prox = (sum(near_prox_vals) / len(near_prox_vals)) if near_prox_vals else None

    # Person states
    person_rows = [
        {
            "source_video": src,
            "frame_id": frame_id,
            "label": label,
            "person_node_id": safe_int(p.get("person_node_id", -1)),
            "has_helmet": bool(p.get("has_helmet", False)),
            "has_vest": bool(p.get("has_vest", False)),
            "near_machine": bool(p.get("near_machine", False)),
        }
        for p in (g.get("person_states") or [])
    ]

    graph_row = {
        "source_video": src,
        "frame_id": frame_id,
        "label": label,
        "has_person": has_person,
        "num_persons": num_persons,
        "num_machines": num_machines,
        "num_machine_types_in_graph": len(set(machine_labels)),

        # near edge stats
        "near_edge_count": near_edge_count,
        "avg_near_proximity": avg_near_prox,

        # edge type counts (no gloves)
        "edge_self": et_counter.get(EDGE_SELF_LOOP, 0),
        "edge_wear_ppe": et_counter.get(EDGE_WEAR_PPE, 0),
        "edge_near": et_counter.get(EDGE_NEAR_MACHINE, 0),
    }

    return graph_row, person_rows, machine_labels, unexpected_edge_ids


# --- Load all JSONs ---
json_files = sorted(glob.glob(os.path.join(DATA_DIR, "*.json")))
print("Total JSON files:", len(json_files))

graph_rows = []
person_rows_all = []
machine_labels_all = []

# For every unexpected edge-type id: the number of graphs it was seen in.
unexpected_edges_global = Counter()

for fp in tqdm(json_files):
    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with open(fp, "r", encoding="utf-8") as f:
        g = json.load(f)

    gr, pr, mlabels, unexpected = parse_one_graph(g)
    graph_rows.append(gr)
    person_rows_all.extend(pr)
    machine_labels_all.extend((gr["source_video"], ml) for ml in mlabels)

    # `unexpected` is a set, so each id counts once per graph.
    unexpected_edges_global.update(unexpected)

graphs_df = pd.DataFrame(graph_rows)
persons_df = pd.DataFrame(person_rows_all)
machines_df = pd.DataFrame(machine_labels_all, columns=["source_video", "machine_label"])

print("graphs_df:", graphs_df.shape, "| persons_df:", persons_df.shape, "| machines_df:", machines_df.shape)

# --- Sanity: unexpected edge ids ---
if unexpected_edges_global:
    print("\n[WARNING] Beklenmeyen edge_type id tespit edildi! (Bu yeni pipeline’a uymuyor olabilir)")
    print(dict(unexpected_edges_global))
else:
    print("\nEdge type sanity OK: Sadece {0:self, 1:wear_ppe, 2:near_machine} kullanılmış.")

# --- 1) Label distribution per video + overall ---
def label_table(df, by="source_video"):
    """Per-group counts of the risk labels 0..3 plus a total and per-label shares.

    Returns a frame indexed by `by` with columns 0, 1, 2, 3, "total", "p0".."p3".
    Label values outside 0..3 are left out of both the counts and the total.
    """
    risk_labels = [0, 1, 2, 3]

    counts = df.groupby(by)["label"].value_counts(dropna=False)
    counts = counts.unstack(fill_value=0).sort_index()

    # The report still assumes risk labels 0..3: create zero columns for the
    # absent ones, then keep exactly these four, in order.
    absent = [lab for lab in risk_labels if lab not in counts.columns]
    for lab in absent:
        counts[lab] = 0
    counts = counts[risk_labels]
    counts["total"] = counts.sum(axis=1)

    shares = counts[risk_labels].div(counts["total"], axis=0).round(4)
    shares.columns = [f"p{lab}" for lab in shares.columns]

    return pd.concat([counts, shares], axis=1)

print("\n=== Label distribution per video ===")
per_video_labels = label_table(graphs_df, by="source_video")
print(per_video_labels)

# Overall counts are printed as a plain {label: count} dict.
print("\n=== Label distribution overall ===")
label_counts = graphs_df["label"].value_counts()
overall_labels = label_counts.to_dict()
print(overall_labels)


# --- 2) Person PPE / near statistics per video + overall ---
if len(persons_df) == 0:
    print("\nNo person_states found in JSONs.")
else:
    # "PPE complete" means the person has both a helmet and a vest.
    persons_df["ppe_complete"] = persons_df["has_helmet"] & persons_df["has_vest"]

    # One named-aggregation spec, used for the per-video and the overall summary alike.
    person_agg_spec = {
        "persons": ("person_node_id", "count"),
        "helmet_rate": ("has_helmet", "mean"),
        "vest_rate": ("has_vest", "mean"),
        "ppe_complete_rate": ("ppe_complete", "mean"),
        "near_rate": ("near_machine", "mean"),
    }

    per_video_person = persons_df.groupby("source_video").agg(**person_agg_spec).round(4)

    print("\n=== Person-level PPE/near stats per video ===")
    print(per_video_person)

    overall_person = persons_df.agg(**person_agg_spec).round(4)

    print("\n=== Person-level PPE/near stats overall ===")
    print(overall_person)


# --- 3) Machine type diversity per video + overall ---
if len(machines_df) == 0:
    print("\nNo machines found in JSONs.")
else:
    # Rows: videos, columns: machine labels, cells: number of machine nodes.
    machine_type_counts = (
        machines_df.groupby("source_video")["machine_label"]
        .value_counts()
        .unstack(fill_value=0)
        .sort_index()
    )

    # Both summary columns are derived from the pure count columns only.
    per_video_machine = machine_type_counts.assign(
        unique_machine_types=(machine_type_counts > 0).sum(axis=1),
        total_machine_nodes=machine_type_counts.sum(axis=1),
    )

    print("\n=== Machine type counts per video ===")
    print(per_video_machine)

    overall_machine_counts = machines_df["machine_label"].value_counts()
    print("\n=== Machine type counts overall ===")
    print(overall_machine_counts.to_dict())


# --- 4) Edge type distribution per video + overall (graph-level aggregation) ---
edge_cols = ["edge_self", "edge_wear_ppe", "edge_near"]

# Sum the per-graph edge counts inside each video, then append the row total.
edge_sums_by_video = graphs_df.groupby("source_video")[edge_cols].sum()
per_video_edges = edge_sums_by_video.assign(total_edges=edge_sums_by_video.sum(axis=1))

print("\n=== Edge type totals per video ===")
print(per_video_edges)

edge_totals = graphs_df[edge_cols].sum().to_dict()
overall_edges = {**edge_totals, "total_edges": sum(edge_totals.values())}
print("\n=== Edge type totals overall ===")
print(overall_edges)


# --- 5) Quick diversity checks ---
MIN_GRAPHS_PER_VIDEO = 300
MIN_LABEL_COUNT_PER_VIDEO = 30  # minimum number of samples per label, per video

print("\n=== Diversity Alerts (video bazında) ===")
alerts = []
for vid, row in per_video_labels.iterrows():
    # Too few graphs in this video overall?
    n_graphs = int(row["total"])
    if n_graphs < MIN_GRAPHS_PER_VIDEO:
        alerts.append((vid, f"Graph sayısı düşük: {n_graphs} < {MIN_GRAPHS_PER_VIDEO}"))

    # Too few graphs of any single risk label in this video?
    for lab in (0, 1, 2, 3):
        lab_count = int(row.get(lab, 0))
        if lab_count < MIN_LABEL_COUNT_PER_VIDEO:
            alerts.append((vid, f"Label {lab} az: {lab_count} < {MIN_LABEL_COUNT_PER_VIDEO}"))

if not alerts:
    print("Alarm yok: video bazında temel dağılım eşikleri sağlanıyor.")
else:
    for alert_vid, alert_msg in alerts:
        print(f"Video {alert_vid}: {alert_msg}")


# --- 6) Save summaries (optional) ---
summary_dir = os.path.join(DATA_DIR, "_summaries")
os.makedirs(summary_dir, exist_ok=True)


def _summary_path(filename):
    """Full path of `filename` inside the summaries folder."""
    return os.path.join(summary_dir, filename)


has_person_rows = len(persons_df) > 0
has_machine_rows = len(machines_df) > 0

# Row-level tables (written without the index).
graphs_df.to_csv(_summary_path("graphs_summary.csv"), index=False)
persons_df.to_csv(_summary_path("persons_summary.csv"), index=False)
if has_machine_rows:
    machines_df.to_csv(_summary_path("machines_summary.csv"), index=False)

# Per-video tables (the index holds the video id, so it is kept). The person and
# machine tables only exist when the corresponding rows were found.
per_video_labels.to_csv(_summary_path("label_dist_per_video.csv"))
if has_person_rows:
    per_video_person.to_csv(_summary_path("person_stats_per_video.csv"))
if has_machine_rows:
    per_video_machine.to_csv(_summary_path("machine_types_per_video.csv"))
per_video_edges.to_csv(_summary_path("edge_types_per_video.csv"))

print("\nSaved summaries to:", summary_dir)


Total JSON files: 6537


100%|██████████| 6537/6537 [02:05<00:00, 52.16it/s] 


graphs_df: (6537, 12) | persons_df: (11782, 7) | machines_df: (7579, 2)

Edge type sanity OK: Sadece {0:self, 1:wear_ppe, 2:near_machine} kullanılmış.

=== Label distribution per video ===
                 0     1   2    3  total      p0      p1      p2      p3
source_video                                                            
1               68   302   2   36    408  0.1667  0.7402  0.0049  0.0882
2             1030   478   5  794   2307  0.4465  0.2072  0.0022  0.3442
3               86   361   5   71    523  0.1644  0.6902  0.0096  0.1358
4               57   273   2  156    488  0.1168  0.5594  0.0041  0.3197
5              181   274   1   91    547  0.3309  0.5009  0.0018  0.1664
6              255  1028  53  188   1524  0.1673  0.6745  0.0348  0.1234
7              272   249   0  219    740  0.3676  0.3365  0.0000  0.2959

=== Label distribution overall ===
{1: 2965, 0: 1949, 3: 1555, 2: 68}

=== Person-level PPE/near stats per video ===
              persons  helmet_rate  

Bu kodun yaptığı iş, DATA_DIR içindeki frame graph JSON dosyalarınızı tarayıp, GNN eğitimi için “dosya yolu + etiket” bilgisini içeren bir manifest (indeks) CSV üretmektir. Ayrıca video bazında etiket dağılımlarını raporlayıp CSV’ye kaydeder.

In [None]:
import os, glob, json
import pandas as pd
from tqdm import tqdm

# Input folder with the graph JSON files; the manifest and the per-video
# distribution CSVs are written to the "_index_3class" subfolder.
DATA_DIR = "/content/drive/Othercomputers/Dizüstü Bilgisayarım/Google drive Masaustu/0-PPE/0-Data Factory outputs"
OUT_DIR = os.path.join(DATA_DIR, "_index_3class")
os.makedirs(OUT_DIR, exist_ok=True)

# 3-class mapping according to the new rule system (no gloves)
# raw label meanings (your risk logic):
# 0: far + (helmet & vest present)
# 1: far + (helmet or vest missing)
# 2: near + (helmet & vest present)
# 3: near + (helmet or vest missing)
#
# 3-class training:
# 0 -> Safe
# 1,2 -> Warning (intermediate level)
# 3 -> Critical
LABEL_MAP_3C = {0: 0, 1: 1, 2: 1, 3: 2}

def _to_int(value, default):
    """Best-effort int(); returns `default` for null or non-numeric values."""
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        return default


# Fixed column order, so the manifest keeps its schema even when no row qualifies.
MANIFEST_COLUMNS = [
    "path", "file", "source_video", "frame_id",
    "label_raw", "label_3c", "has_person", "num_persons", "num_machines",
]

rows = []
skipped_files = []  # files whose label is missing or not a key of LABEL_MAP_3C
json_files = sorted(glob.glob(os.path.join(DATA_DIR, "*.json")))
print("Total JSON files found:", len(json_files))

for fp in tqdm(json_files):
    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with open(fp, "r", encoding="utf-8") as f:
        g = json.load(f)

    # A null / non-numeric label becomes -1, which is not in LABEL_MAP_3C,
    # so such files are skipped instead of aborting the whole scan.
    y = _to_int(g.get("label", -1), -1)
    if y not in LABEL_MAP_3C:
        skipped_files.append(fp)
        continue  # skip unexpected labels (safety)

    rows.append({
        "path": fp,
        "file": os.path.basename(fp),
        "source_video": str(g.get("source_video", "unknown")),
        "frame_id": _to_int(g.get("frame_id", -1), -1),
        "label_raw": y,                 # {0,1,2,3}
        "label_3c": LABEL_MAP_3C[y],    # {0,1,2}
        "has_person": bool(g.get("has_person", False)),
        "num_persons": _to_int(g.get("num_persons", 0), 0),
        "num_machines": _to_int(g.get("num_machines", 0), 0),
    })

# Report skipped files so they do not vanish from the dataset unnoticed.
if skipped_files:
    print(
        f"\n[WARNING] Skipped {len(skipped_files)} file(s) with a missing/unexpected label, "
        f"e.g. {os.path.basename(skipped_files[0])}"
    )

df = pd.DataFrame(rows, columns=MANIFEST_COLUMNS)
df = df.sort_values(["source_video", "frame_id"]).reset_index(drop=True)

manifest_path = os.path.join(OUT_DIR, "manifest_3class.csv")
df.to_csv(manifest_path, index=False)

print("\nSaved:", manifest_path)
print("\nLabel raw distribution:")
print(df["label_raw"].value_counts().sort_index())

print("\nLabel 3-class distribution (train ids):")
print(df["label_3c"].value_counts().sort_index())

def per_video_label_dist(manifest, label_col, classes):
    """Count rows per (source_video, label) and return one row per video.

    manifest:  DataFrame with a "source_video" column and the column `label_col`.
    label_col: name of the label column to count.
    classes:   label values to report; absent ones are added as zero columns and
               the result keeps exactly these columns (in this order) plus "total".
    """
    table = (
        manifest.groupby("source_video")[label_col]
        .value_counts()
        .unstack(fill_value=0)
        .sort_index()
    )
    for c in classes:
        if c not in table.columns:
            table[c] = 0
    table = table[list(classes)]
    table["total"] = table.sum(axis=1)
    return table


# --- Per-video RAW label distribution (0..3) ---
per_video_raw = per_video_label_dist(df, "label_raw", [0, 1, 2, 3])
per_video_raw_csv = os.path.join(OUT_DIR, "label_dist_per_video_raw.csv")
per_video_raw.to_csv(per_video_raw_csv)

print("\nSaved:", per_video_raw_csv)

# --- Per-video 3-class label distribution (0..2) ---
per_video_3c = per_video_label_dist(df, "label_3c", [0, 1, 2])
per_video_3c_csv = os.path.join(OUT_DIR, "label_dist_per_video_3class.csv")
per_video_3c.to_csv(per_video_3c_csv)

print("\nSaved:", per_video_3c_csv)

# Rich notebook (Colab) output of both tables
display(per_video_raw)
display(per_video_3c)


Total JSON files found: 6537


100%|██████████| 6537/6537 [00:47<00:00, 137.54it/s]



Saved: /content/drive/Othercomputers/Dizüstü Bilgisayarım/Google drive Masaustu/0-PPE/0-Data Factory outputs/_index_3class/manifest_3class.csv

Label raw distribution:
label_raw
0    1949
1    2965
2      68
3    1555
Name: count, dtype: int64

Label 3-class distribution (train ids):
label_3c
0    1949
1    3033
2    1555
Name: count, dtype: int64

Saved: /content/drive/Othercomputers/Dizüstü Bilgisayarım/Google drive Masaustu/0-PPE/0-Data Factory outputs/_index_3class/label_dist_per_video_raw.csv

Saved: /content/drive/Othercomputers/Dizüstü Bilgisayarım/Google drive Masaustu/0-PPE/0-Data Factory outputs/_index_3class/label_dist_per_video_3class.csv


label_raw,0,1,2,3,total
source_video,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
1,68,302,2,36,408
2,1030,478,5,794,2307
3,86,361,5,71,523
4,57,273,2,156,488
5,181,274,1,91,547
6,255,1028,53,188,1524
7,272,249,0,219,740


label_3c,0,1,2,total
source_video,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
1,68,304,36,408
2,1030,483,794,2307
3,86,366,71,523
4,57,275,156,488
5,181,275,91,547
6,255,1081,188,1524
7,272,249,219,740


In [None]:
# Write one text file per video listing its JSON paths (makes splitting easier).
lists_dir = os.path.join(OUT_DIR, "by_video_lists")
os.makedirs(lists_dir, exist_ok=True)

for vid, video_paths in df.groupby("source_video")["path"]:
    # One path per line, no header and no index column.
    video_paths.to_csv(
        os.path.join(lists_dir, f"video_{vid}_files.txt"),
        index=False,
        header=False,
    )

print("Saved per-video file lists to:", lists_dir)


Saved per-video file lists to: /content/drive/Othercomputers/Dizüstü Bilgisayarım/Google drive Masaustu/0-PPE/0-Data Factory outputs/_index_3class/by_video_lists


In [None]:
import os
import numpy as np
import pandas as pd

# --- Paths (NEW SYSTEM) ---
# The manifest is read from the "_index_3class" folder; split files are written
# to its "splits_chunked" subfolder.
DATA_DIR = "/content/drive/Othercomputers/Dizüstü Bilgisayarım/Google drive Masaustu/0-PPE/0-Data Factory outputs"
INDEX_DIR = os.path.join(DATA_DIR, "_index_3class")
manifest_path = os.path.join(INDEX_DIR, "manifest_3class.csv")

OUT_SPLIT_DIR = os.path.join(INDEX_DIR, "splits_chunked")
os.makedirs(OUT_SPLIT_DIR, exist_ok=True)

# --- Split params ---
# Ratios are applied to the number of chunks of each video, not to single graphs.
train_ratio = 0.80
val_ratio   = 0.10
test_ratio  = 0.10

chunk_size_graphs = 60    # ~60 seconds since you sample ~1 graph/sec
gap_graphs = 3            # default buffer: graphs dropped at both ends of every chunk
seed = 1337

# --- Guarantees / thresholds ---
# Minimum chunks a video needs before it is guaranteed both a val and a test chunk.
MIN_CHUNKS_FOR_BOTH_EVAL = 3
# If a video's val or test size falls below this after gap trimming, that video is kept untrimmed.
MIN_EVAL_SAMPLES_AFTER_GAP = 30
# When True, a shorter-than-full chunk in val/test is swapped with a full train chunk.
PROTECT_SHORT_CHUNKS_FROM_EVAL = True

# Optional (usually helpful): drop graphs that contain no person
FILTER_NO_PERSON_GRAPHS = False  # if set to True, rows with has_person==False are left out of the split

assert abs((train_ratio + val_ratio + test_ratio) - 1.0) < 1e-9, "Ratios must sum to 1.0"

df = pd.read_csv(manifest_path)

# New 3-class schema: label_raw is the 0..3 risk label and label_3c its
# 0..2 training id (0->0, 1/2->1, 3->2). Anything else in label_3c is an error.
assert set(df["label_3c"].unique()).issubset({0, 1, 2}), df["label_3c"].unique()

# (Optional) drop graphs that contain no person
if FILTER_NO_PERSON_GRAPHS:
    df = df.loc[df["has_person"] == True].copy()

# Chronological order inside every video is what makes the chunks contiguous.
df = df.sort_values(["source_video", "frame_id"]).reset_index(drop=True)

# Running position of each row inside its video (0..N-1) ...
position_in_video = df.groupby("source_video").cumcount()
df["sample_idx"] = position_in_video

# ... the contiguous block that position falls into ...
df["chunk_id"] = (position_in_video // chunk_size_graphs).astype(int)

# ... and the offset inside that block (needed later for gap trimming).
df["pos_in_chunk"] = (position_in_video % chunk_size_graphs).astype(int)

rng = np.random.default_rng(seed)

# Filled in per video below; None marks "not assigned yet".
df["split"] = None

# Per video: shuffle its chunk ids with the shared `rng`, cut the shuffled order into
# train/val/test by the configured ratios, optionally move a short chunk out of val/test,
# and write the split name of every row back into df["split"].
# NOTE: videos are visited in first-appearance order (sort=False) and each one consumes
# one rng.shuffle call, so the result depends on the row order of df as well as on `seed`.
for vid, sub in df.groupby("source_video", sort=False):
    chunk_ids = np.array(sorted(sub["chunk_id"].unique()))
    n_chunks = len(chunk_ids)

    # Chunk lengths (detect tail/short chunk): any chunk with fewer than
    # chunk_size_graphs rows counts as "short".
    chunk_len_map = sub.groupby("chunk_id")["sample_idx"].count().to_dict()
    short_chunks = {cid for cid, clen in chunk_len_map.items() if clen < chunk_size_graphs}

    # Shuffle chunk order (still keeps within-chunk contiguity)
    rng.shuffle(chunk_ids)

    # Base rounding: train and val are rounded, test takes whatever is left.
    n_train = int(round(n_chunks * train_ratio))
    n_val   = int(round(n_chunks * val_ratio))
    n_test  = n_chunks - n_train - n_val

    # Fix rare rounding issues: if train+val rounded up past n_chunks, test gets
    # nothing and val is reduced to the remainder.
    if n_test < 0:
        n_test = 0
        n_val  = max(0, n_chunks - n_train - n_test)

    # Guarantee #1: ensure >=1 val and >=1 test chunk if enough chunks
    if n_chunks >= MIN_CHUNKS_FOR_BOTH_EVAL:
        if n_val == 0:
            n_val = 1
        if n_test == 0:
            n_test = 1
        # Train gives up the chunks that val/test were just granted.
        n_train = n_chunks - n_val - n_test

        # Ensure at least 1 train chunk as well: take chunks back from the larger
        # eval split, never reducing val or test below one chunk.
        if n_train < 1:
            while n_train < 1 and (n_val > 1 or n_test > 1):
                if n_val >= n_test and n_val > 1:
                    n_val -= 1
                elif n_test > 1:
                    n_test -= 1
                n_train = n_chunks - n_val - n_test

    # Cut the shuffled order: first n_train ids -> train, next n_val -> val, rest -> test.
    train_chunks = set(chunk_ids[:n_train])
    val_chunks   = set(chunk_ids[n_train:n_train + n_val])
    test_chunks  = set(chunk_ids[n_train + n_val:])

    # Protect short/tail chunk from val/test (keeps eval meaningful)
    if PROTECT_SHORT_CHUNKS_FROM_EVAL and n_chunks >= MIN_CHUNKS_FOR_BOTH_EVAL and len(short_chunks) > 0:
        def swap_short_from_eval(eval_set):
            """Swap one short chunk found in `eval_set` with a full-length train chunk.

            Mutates `eval_set` and the enclosing `train_chunks` in place. Does nothing
            when `eval_set` holds no short chunk or train has no full-length chunk.
            At most one short chunk is swapped per call.
            """
            shorts_in_eval = eval_set & short_chunks
            if not shorts_in_eval:
                return
            short_cid = next(iter(shorts_in_eval))

            # find a good candidate from train (prefer full chunks)
            candidates = list(train_chunks - short_chunks)
            if not candidates:
                return
            # All candidates are non-short, so this effectively picks the first of the
            # longest ones in the set's iteration order.
            candidate = max(candidates, key=lambda c: chunk_len_map.get(c, 0))

            # swap
            eval_set.remove(short_cid)
            eval_set.add(candidate)
            train_chunks.remove(candidate)
            train_chunks.add(short_cid)

        swap_short_from_eval(val_chunks)
        swap_short_from_eval(test_chunks)

    # Assign splits back to df (sub.index holds this video's row labels in df)
    df.loc[sub.index[sub["chunk_id"].isin(train_chunks)], "split"] = "train"
    df.loc[sub.index[sub["chunk_id"].isin(val_chunks)],   "split"] = "val"
    df.loc[sub.index[sub["chunk_id"].isin(test_chunks)],  "split"] = "test"

# Final sanity: no null splits
assert df["split"].isna().sum() == 0, "Some rows have no split."

# --- Guarantee #2: adaptive gap per video (if eval drops too much after gap) ---
# Untrimmed copy, used to restore videos whose val/test would become too small.
df_pre_gap = df.copy()

if gap_graphs > 0:
    # Length of the chunk each row belongs to, aligned row-by-row with df.
    chunk_len = df.groupby(["source_video", "chunk_id"])["sample_idx"].transform("count")

    # Keep only the middle part of each chunk: drop `gap_graphs` rows at both ends.
    keep = (df["pos_in_chunk"] >= gap_graphs) & (df["pos_in_chunk"] < (chunk_len - gap_graphs))
    df_gap = df[keep].copy()

    # Per-video split sizes after the gap. Reindex over ALL videos so that a video
    # whose rows were all trimmed shows up with zeros (and is therefore restored
    # below) instead of silently disappearing from every split.
    all_videos = pd.Index(df_pre_gap["source_video"].unique(), name="source_video")
    counts_gap = df_gap.groupby(["source_video", "split"]).size().unstack(fill_value=0)
    counts_gap = counts_gap.reindex(index=all_videos, fill_value=0)
    for c in ["train", "val", "test"]:
        if c not in counts_gap.columns:
            counts_gap[c] = 0

    bad_videos = counts_gap.index[
        (counts_gap["val"] < MIN_EVAL_SAMPLES_AFTER_GAP) | (counts_gap["test"] < MIN_EVAL_SAMPLES_AFTER_GAP)
    ].tolist()

    # For those videos: use gap=0 (revert to pre-gap rows for that video)
    if len(bad_videos) > 0:
        df_good = df_gap[~df_gap["source_video"].isin(bad_videos)].copy()
        df_bad  = df_pre_gap[df_pre_gap["source_video"].isin(bad_videos)].copy()
        df = pd.concat([df_good, df_bad], ignore_index=True)
    else:
        df = df_gap

# Keep deterministic order at the end (optional but nice)
df = df.sort_values(["source_video", "frame_id"]).reset_index(drop=True)

# --- Save split CSVs ---
train_df, val_df, test_df = (
    df[df["split"] == split_key].copy() for split_key in ("train", "val", "test")
)

train_csv = os.path.join(OUT_SPLIT_DIR, "train.csv")
val_csv   = os.path.join(OUT_SPLIT_DIR, "val.csv")
test_csv  = os.path.join(OUT_SPLIT_DIR, "test.csv")

# Full rows of every split, without the index column.
for split_part, split_csv_path in ((train_df, train_csv), (val_df, val_csv), (test_df, test_csv)):
    split_part.to_csv(split_csv_path, index=False)

# Also export path lists
train_txt = os.path.join(OUT_SPLIT_DIR, "train_paths.txt")
val_txt   = os.path.join(OUT_SPLIT_DIR, "val_paths.txt")
test_txt  = os.path.join(OUT_SPLIT_DIR, "test_paths.txt")

# One JSON path per line, no header.
for split_part, split_txt_path in ((train_df, train_txt), (val_df, val_txt), (test_df, test_txt)):
    split_part["path"].to_csv(split_txt_path, index=False, header=False)

print("Saved splits to:", OUT_SPLIT_DIR)
print("Train/Val/Test sizes:", len(train_df), len(val_df), len(test_df))

# Quick distribution checks
def dist_report(name, part):
    """Print the absolute and the percentage label_3c distribution of one split.

    name: label used in the printed headers.
    part: DataFrame with a "label_3c" column.
    """
    labels = part["label_3c"]
    absolute = labels.value_counts(normalize=False).sort_index()
    percent = (labels.value_counts(normalize=True).sort_index() * 100).round(2)

    print(f"\n=== {name} label_3c distribution ===")
    print(absolute)
    print(f"=== {name} label_3c percentages ===")
    print(percent)

for report_name, report_part in (("TRAIN", train_df), ("VAL", val_df), ("TEST", test_df)):
    dist_report(report_name, report_part)

print("\n=== Split counts per video ===")
split_counts_per_video = df.groupby(["source_video", "split"]).size().unstack(fill_value=0).sort_index()
print(split_counts_per_video)

# No JSON file name may appear in more than one split.
train_files, val_files, test_files = (
    set(split_part["file"]) for split_part in (train_df, val_df, test_df)
)
assert train_files.isdisjoint(val_files)
assert train_files.isdisjoint(test_files)
assert val_files.isdisjoint(test_files)
print("\nNo file overlap across splits: OK")


Saved splits to: /content/drive/Othercomputers/Dizüstü Bilgisayarım/Google drive Masaustu/0-PPE/0-Data Factory outputs/_index_3class/splits_chunked
Train/Val/Test sizes: 4563 648 648

=== TRAIN label_3c distribution ===
label_3c
0    1323
1    2133
2    1107
Name: count, dtype: int64
=== TRAIN label_3c percentages ===
label_3c
0    28.99
1    46.75
2    24.26
Name: proportion, dtype: float64

=== VAL label_3c distribution ===
label_3c
0    157
1    331
2    160
Name: count, dtype: int64
=== VAL label_3c percentages ===
label_3c
0    24.23
1    51.08
2    24.69
Name: proportion, dtype: float64

=== TEST label_3c distribution ===
label_3c
0    268
1    262
2    118
Name: count, dtype: int64
=== TEST label_3c percentages ===
label_3c
0    41.36
1    40.43
2    18.21
Name: proportion, dtype: float64

=== Split counts per video ===
split         test  train  val
source_video                  
1               54    258   54
2              216   1641  216
3               54    361   54
4   

In [None]:
# Environment check: confirm PyTorch Geometric imports and print its version.
# NOTE(review): no cell shown above installs torch_geometric — presumably it is
# installed elsewhere; confirm for a fresh runtime.
import torch_geometric
print("torch_geometric:", torch_geometric.__version__)

# Confirm the Data class imports cleanly as well.
from torch_geometric.data import Data
print("PyG Data import OK")


torch_geometric: 2.7.0
PyG Data import OK


In [None]:
import os
import json
import hashlib
from dataclasses import dataclass
from collections import Counter

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.nn import GINEConv, GlobalAttention, LayerNorm


In [None]:
# IMPORTANT: mapping/caches changed -> delete old cache once before running
# WARNING: this recursively deletes everything under CACHE_ROOT and recreates it empty.
# NOTE(review): CACHE_ROOT is not assigned in any cell shown above this one —
# presumably it is defined elsewhere in the notebook; confirm it is set before this
# cell runs on a fresh kernel.
import shutil
if os.path.exists(CACHE_ROOT):
    shutil.rmtree(CACHE_ROOT)
os.makedirs(CACHE_ROOT, exist_ok=True)


In [None]:
import os
import json
import glob
import torch
import pandas as pd
from torch_geometric.data import Data, InMemoryDataset

# -----------------------------
# New system constants (NO GLOVES)
# edge_type ids must be:
#   0: self-loop
#   1: wear_ppe (person->helmet/vest)
#   2: near_machine (person->machine)  edge_attr = proximity in [0,1]
# -----------------------------
# These repeat the ids defined in the JSON analysis cell earlier in the notebook;
# keep both definitions in sync.
EDGE_SELF_LOOP    = 0
EDGE_WEAR_PPE     = 1
EDGE_NEAR_MACHINE = 2
# Set form of the three ids above.
ALLOWED_EDGE_IDS  = {EDGE_SELF_LOOP, EDGE_WEAR_PPE, EDGE_NEAR_MACHINE}


def coerce_edge_attr(edge_attr, E: int, jpath: str):
    """
    Normalise edge_attr to a list of E single-float lists ([[v], [v], ...]).

    Accepted inputs (the outer container must be a list of length E):
      - nested entries (lists/tuples): the first element of each is kept,
        an empty entry becomes [0.0]
      - flat numbers [v, v, ...]: each is wrapped as [v]
    With E == 0 the result is [] regardless of edge_attr.

    Raises ValueError (mentioning jpath) when edge_attr is None with E > 0, or
    when it matches neither accepted layout.
    """
    if E == 0:
        return []

    if edge_attr is None:
        raise ValueError(f"edge_attr is None but E>0 in {jpath}")

    length_ok = isinstance(edge_attr, list) and len(edge_attr) == E
    if length_ok:
        # Layout A: every entry is itself a list/tuple.
        every_entry_nested = all(isinstance(entry, (list, tuple)) for entry in edge_attr)
        if every_entry_nested:
            return [
                [float(entry[0])] if len(entry) >= 1 else [0.0]
                for entry in edge_attr
            ]

        # Layout B: every entry is a bare number.
        every_entry_number = all(isinstance(entry, (int, float)) for entry in edge_attr)
        if every_entry_number:
            return [[float(entry)] for entry in edge_attr]

    raise ValueError(
        f"edge_attr format unexpected in {jpath}. "
        f"Expected len==E ({E}) and either list-of-lists [[v],..] or flat [v,..]. "
        f"Got type={type(edge_attr).__name__}, len={len(edge_attr) if isinstance(edge_attr, list) else 'NA'}"
    )


def load_or_build_video2id(split_csv: str, processed_dir: str) -> dict:
    """
    Build a consistent mapping source_video(str) -> small int id [0..V-1]
    shared across train/val/test by scanning CSVs in the same folder.

    Args:
      split_csv: path of any split CSV; only its folder is used for the scan.
      processed_dir: folder holding the cached mapping (created if missing).

    Returns:
      dict mapping str(source_video) -> int id. Ids follow the sorted order of the
      video names, so the mapping is deterministic.

    Saved to: processed_dir/video2id.json
    If that file already exists it is returned as-is, without re-scanning the
    CSVs; delete it to force a rebuild after the splits change.
    """
    os.makedirs(processed_dir, exist_ok=True)
    map_path = os.path.join(processed_dir, "video2id.json")

    if os.path.exists(map_path):
        with open(map_path, "r", encoding="utf-8") as f:
            cached = json.load(f)
        # json keys are str, values may come as int already; ensure int
        return {str(k): int(v) for k, v in cached.items()}

    split_dir = os.path.dirname(split_csv)

    # Prefer the canonical split files if present; otherwise fall back to all csv in folder
    preferred = [os.path.join(split_dir, n) for n in ["train.csv", "val.csv", "test.csv"]]
    csvs = [p for p in preferred if os.path.exists(p)]
    if not csvs:
        csvs = sorted(glob.glob(os.path.join(split_dir, "*.csv")))

    videos = set()
    for cp in csvs:
        try:
            tmp = pd.read_csv(cp, usecols=["source_video"])
        except Exception as exc:
            # Still best-effort (a CSV without the column or an unreadable file is
            # skipped), but reported so an incomplete mapping is not built silently.
            print(
                f"[WARN] Skipping {cp} while building video2id: "
                f"{type(exc).__name__}: {str(exc)[:200]}"
            )
            continue
        videos.update(tmp["source_video"].astype(str).tolist())

    if not videos:
        print(f"[WARN] No source_video values found under {split_dir!r}; using the single id 'unknown'.")
        videos = {"unknown"}

    video2id = {v: i for i, v in enumerate(sorted(videos))}  # 0..V-1 deterministic

    with open(map_path, "w", encoding="utf-8") as f:
        json.dump(video2id, f, indent=2)

    return video2id


class GraphJsonInMemoryDataset(InMemoryDataset):
    """
    PyG InMemoryDataset built from the JSON graph files listed in a split CSV.

    Required CSV columns:
      - path      : full path of one JSON graph file
      - label_3c  : graph label, one of 0/1/2
    Optional CSV columns (the value stored in the JSON file is used when absent):
      - source_video
      - frame_id

    Every row becomes one `Data` object carrying x, edge_index, edge_type,
    edge_attr and y, plus machine_subclass_id, source_video / source_video_id
    (both hold the same integer video id) and frame_id.

    The collated dataset is cached at `self.processed_paths[0]`
    (file name `data_<split_name>.pt`). The cache name depends on `split_name`
    only, so delete the cache after changing the CSV or the JSON files.
    """

    def __init__(
        self,
        root: str,
        split_csv: str,
        split_name: str,
        transform=None,
        pre_transform=None,
    ):
        # Set before calling the base constructor: processed_file_names and
        # process() read these attributes (the base class is expected to use
        # them during initialisation).
        self.split_csv = split_csv
        self.split_name = split_name
        super().__init__(root, transform, pre_transform)

        self.data, self.slices = self._safe_load_or_rebuild(self.processed_paths[0])

    def _safe_load_or_rebuild(self, path: str):
        """
        Load the cached (data, slices) tuple from `path`, rebuilding it if needed.

        PyTorch 2.6 changed the torch.load default to weights_only=True, and
        cached PyG Data objects may require weights_only=False.

        Strategy:
          1) Try torch.load with its default settings (weights_only=True on
             torch >= 2.6)
          2) If that fails, retry with weights_only=False (ONLY acceptable
             because this is a trusted, locally written cache)
          3) If that fails too, delete the cache and rebuild it via process()

        Returns:
            Whatever was saved by process(): the (data, slices) tuple.
        """
        if not os.path.exists(path):
            # Defensive: if processed file missing, build it now
            self.process()
            return torch.load(path, weights_only=False)

        try:
            return torch.load(path)  # weights_only defaults to True in torch 2.6
        except Exception as e_safe:
            try:
                return torch.load(path, weights_only=False)
            except Exception as e_full:
                print(f"[WARN] Cache load failed for {path}. Rebuilding cache...")
                print(f"  - safe load error: {type(e_safe).__name__}: {str(e_safe)[:200]}")
                print(f"  - full load error: {type(e_full).__name__}: {str(e_full)[:200]}")

                try:
                    os.remove(path)
                except OSError:
                    # Nothing to clean up (or not removable); process() will
                    # overwrite the file anyway.
                    pass

                self.process()
                return torch.load(path, weights_only=False)

    @property
    def processed_file_names(self):
        """One cache file per split, named after `split_name`."""
        return [f"data_{self.split_name}.pt"]

    @property
    def raw_file_names(self):
        """No raw files are tracked; the inputs are referenced by the split CSV."""
        return []

    def download(self):
        """Nothing to download; the JSON graphs already exist on disk."""
        pass

    def process(self):
        """
        Read the split CSV, validate and tensorise every JSON graph, then
        collate and save the result to `self.processed_paths[0]`.

        Raises:
            AssertionError: the CSV lacks the 'path' or 'label_3c' column.
            ValueError: the CSV has no rows, contains labels outside {0,1,2},
                or a graph file fails one of the structural checks below.
        """
        df = pd.read_csv(self.split_csv)

        assert "path" in df.columns, "split CSV must contain 'path' column"
        assert "label_3c" in df.columns, "split CSV must contain 'label_3c' column"

        # An empty split would otherwise only fail later, inside collate(),
        # with an error that does not mention the CSV.
        if len(df) == 0:
            raise ValueError(f"Split CSV has no rows: {self.split_csv}")

        # Sanity: label_3c should be {0,1,2}
        bad_labels = set(df["label_3c"].unique()) - {0, 1, 2}
        if bad_labels:
            raise ValueError(f"Unexpected label_3c values in {self.split_csv}: {sorted(bad_labels)}")

        has_source_video_col = ("source_video" in df.columns)
        has_frame_id_col = ("frame_id" in df.columns)

        # ---- Fix B: global, small, consistent ids 0..V-1 ----
        video2id = load_or_build_video2id(self.split_csv, self.processed_dir)
        # (If needed for the model config: num_videos = len(video2id))

        # Id given to videos that are missing from the mapping. Without an
        # "unknown" entry this is 0, i.e. it coincides with a real video's id,
        # so every use of the fallback is collected and reported below.
        fallback_video_id = video2id.get("unknown", 0)
        unmapped_videos = set()

        data_list = []

        for _, row in df.iterrows():
            jpath = str(row["path"])
            y = int(row["label_3c"])

            # JSON is UTF-8 by specification; do not depend on the locale default.
            with open(jpath, "r", encoding="utf-8") as f:
                g = json.load(f)

            node_features = g.get("node_features", [])
            edge_index = g.get("edge_index", [[], []])
            edge_type = g.get("edge_type", [])
            edge_attr = g.get("edge_attr", [])

            # Basic sanity
            if not isinstance(edge_index, list) or len(edge_index) != 2:
                raise ValueError(f"Invalid edge_index format in: {jpath}")

            if len(node_features) == 0:
                raise ValueError(f"Empty node_features in {jpath} (graph has no nodes).")

            src_list = edge_index[0]
            dst_list = edge_index[1]
            if len(src_list) != len(dst_list):
                raise ValueError(
                    f"edge_index src/dst len mismatch in {jpath}: {len(src_list)} vs {len(dst_list)}"
                )

            E = len(src_list)

            if len(edge_type) != E:
                raise ValueError(f"edge_type len mismatch in {jpath}: {len(edge_type)} vs E={E}")

            extra_edge_ids = set(edge_type) - ALLOWED_EDGE_IDS
            if extra_edge_ids:
                raise ValueError(
                    f"Unexpected edge_type ids {sorted(extra_edge_ids)} in {jpath} "
                    f"(expected {sorted(ALLOWED_EDGE_IDS)})"
                )

            edge_attr = coerce_edge_attr(edge_attr, E, jpath)

            # Build tensors
            x = torch.tensor(node_features, dtype=torch.float32)          # [N, F]
            edge_index_t = torch.tensor(edge_index, dtype=torch.long)     # [2, E]
            edge_type_t = torch.tensor(edge_type, dtype=torch.long)       # [E]
            edge_attr_t = torch.tensor(edge_attr, dtype=torch.float32)    # [E, 1]
            y_t = torch.tensor([y], dtype=torch.long)                     # [1]

            # Column 9 is read below, so at least 10 feature columns are required.
            if x.dim() != 2 or x.size(1) < 10:
                raise ValueError(f"node_features expected [N,>=10], got {list(x.shape)} in {jpath}")

            # The shape check is only applied when the graph has edges.
            if E > 0:
                if edge_attr_t.dim() != 2 or edge_attr_t.size(0) != E or edge_attr_t.size(1) != 1:
                    raise ValueError(f"edge_attr expected [E,1], got {list(edge_attr_t.shape)} in {jpath}")

            N = x.size(0)
            if E > 0:
                if edge_index_t.min().item() < 0 or edge_index_t.max().item() >= N:
                    raise ValueError(
                        f"edge_index has out-of-range node ids in {jpath}. N={N}, max={edge_index_t.max().item()}"
                    )

            # machine_subclass_id: [-1 for non-machine, 0..9 for machines]
            machine_subclass_id = x[:, 9].to(torch.long)

            # Metadata: CSV columns win; otherwise use what the JSON file carries.
            source_video = (row["source_video"] if has_source_video_col else g.get("source_video", "unknown"))
            frame_id = int(row["frame_id"]) if has_frame_id_col else int(g.get("frame_id", -1))

            # ---- Fix B: map to 0..V-1 ----
            video_key = str(source_video)
            if video_key in video2id:
                source_video_id = video2id[video_key]
            else:
                source_video_id = fallback_video_id
                unmapped_videos.add(video_key)

            data = Data(
                x=x,
                edge_index=edge_index_t,
                edge_type=edge_type_t,
                edge_attr=edge_attr_t,
                y=y_t,
            )

            data.machine_subclass_id = machine_subclass_id
            # keep old attribute name for compatibility with your prints
            data.source_video = torch.tensor([source_video_id], dtype=torch.long)
            data.source_video_id = torch.tensor([source_video_id], dtype=torch.long)
            data.frame_id = torch.tensor([frame_id], dtype=torch.long)

            data_list.append(data)

        if unmapped_videos:
            # Usually a stale video2id.json; the ids are still written so the
            # build completes, but the user should know they are not unique.
            print(
                f"[WARN] {len(unmapped_videos)} source_video value(s) from {self.split_csv} "
                f"are missing from video2id.json and were assigned fallback id {fallback_video_id}. "
                f"Delete the cached mapping and dataset caches to rebuild. "
                f"Examples: {sorted(unmapped_videos)[:5]}"
            )

        if self.pre_transform is not None:
            data_list = [self.pre_transform(d) for d in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])


# ---- Build the train / val / test datasets from the chunked split CSVs ----
DATA_DIR = "/content/drive/Othercomputers/Dizüstü Bilgisayarım/Google drive Masaustu/0-PPE/0-Data Factory outputs"
INDEX_DIR = os.path.join(DATA_DIR, "_index_3class")
SPLIT_DIR = os.path.join(INDEX_DIR, "splits_chunked")

# One CSV per split, all in the same folder.
train_csv, val_csv, test_csv = (
    os.path.join(SPLIT_DIR, "train.csv"),
    os.path.join(SPLIT_DIR, "val.csv"),
    os.path.join(SPLIT_DIR, "test.csv"),
)

# Processed tensors are cached next to the split files.
CACHE_ROOT = os.path.join(SPLIT_DIR, "_pyg_cache")
os.makedirs(CACHE_ROOT, exist_ok=True)

# NOTE: whenever the video-id mapping or the cache layout changes, remove the
# old cache once before running this cell (shutil.rmtree(CACHE_ROOT), then
# recreate the folder); otherwise the stale cache is loaded.

train_ds = GraphJsonInMemoryDataset(
    root=CACHE_ROOT,
    split_csv=train_csv,
    split_name="train",
)
val_ds = GraphJsonInMemoryDataset(
    root=CACHE_ROOT,
    split_csv=val_csv,
    split_name="val",
)
test_ds = GraphJsonInMemoryDataset(
    root=CACHE_ROOT,
    split_csv=test_csv,
    split_name="test",
)

# Split sizes.
print("Datasets loaded:")
print("  train:", len(train_ds))
print("  val  :", len(val_ds))
print("  test :", len(test_ds))

# Shapes and dtypes of the first training graph as a quick sanity check.
d0 = train_ds[0]
print("\nSample[0] tensors:")
print("  x:", d0.x.shape, d0.x.dtype)
print("  edge_index:", d0.edge_index.shape, d0.edge_index.dtype)
print("  edge_type:", d0.edge_type.shape, d0.edge_type.dtype)
print("  edge_attr:", d0.edge_attr.shape, d0.edge_attr.dtype)
print("  y:", d0.y, d0.y.dtype)
print("  machine_subclass_id:", d0.machine_subclass_id.shape, d0.machine_subclass_id.dtype)
# source_video is the legacy attribute name; it stores the integer video id.
print("  source_video_id:", d0.source_video.item(), "frame_id:", d0.frame_id.item())




Processing...
Done!
Processing...
Done!
Processing...


Datasets loaded:
  train: 4563
  val  : 648
  test : 648

Sample[0] tensors:
  x: torch.Size([3, 10]) torch.float32
  edge_index: torch.Size([2, 5]) torch.int64
  edge_type: torch.Size([5]) torch.int64
  edge_attr: torch.Size([5, 1]) torch.float32
  y: tensor([1]) torch.int64
  machine_subclass_id: torch.Size([3]) torch.int64
  source_video_id: 0 frame_id: 4920


Done!
