In [None]:
import os
import json
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Patch

# Print DataFrames/Series in full instead of truncating long outputs.
pd.options.display.max_rows = None

# Length of one aggregation window, in hours.
aggregation_interval = 24

# Directory containing the input events file; the labelled CSV is written
# next to it by the script section at the bottom of this file.
data_dir = "/Users/melisadzanovic/Documents/ML - Thesis/hcs-synthetic-data-generator-main"
json_path = os.path.join(data_dir, "events_full.json")

def load_data(json_path):
    """Load the events stored in a JSON file into a DataFrame.

    Args:
        json_path: Path of the JSON file to read. Its parsed content is
            passed straight to ``pd.DataFrame``.

    Returns:
        A ``pandas.DataFrame`` built from the parsed JSON content.

    Raises:
        FileNotFoundError: If ``json_path`` is not an existing file.
    """
    if not os.path.isfile(json_path):
        raise FileNotFoundError(f"File not found: {json_path}")
    # JSON text is UTF-8 by specification; pin the encoding so the result
    # does not depend on the platform's default locale encoding.
    with open(json_path, encoding="utf-8") as f:
        events = json.load(f)
    return pd.DataFrame(events)


def create_histogram_data(json_path, aggregation_interval=24):
    """Aggregate audit events per patient, practitioner and time window.

    Events are loaded with ``load_data`` and given a ``datetime`` taken from
    ``start`` (falling back to ``timestamp`` where ``start`` is missing).
    They are then bucketed into consecutive windows of
    ``aggregation_interval`` hours, counted from the earliest event. One
    output row is produced for every (patient_id, practitioner_id, period)
    group that contains at least one event of type ``AuditEvent``.

    Args:
        json_path: Path of the JSON events file.
        aggregation_interval: Window length in hours.

    Returns:
        A DataFrame with one row per aggregated group. Columns are the group
        keys, ``period_start``/``period_end``, the ``has_*`` context flags,
        the event counters, ``avg_time_between_events`` (mean gap in seconds
        between consecutive audit events, NaN for a single event),
        ``num_unique_resources_accessed`` (None when there is no
        ``resource`` column), plus ``table_id`` and ``label``. The frame is
        empty, with the same columns, when no group qualifies.

    Raises:
        FileNotFoundError: If ``json_path`` does not exist.
        ValueError: If the data has neither a ``start`` nor a ``timestamp``
            column.
    """
    df = load_data(json_path)

    if "start" in df.columns and "timestamp" in df.columns:
        df["datetime"] = pd.to_datetime(df["start"].fillna(df["timestamp"]))
    elif "start" in df.columns:
        df["datetime"] = pd.to_datetime(df["start"])
    elif "timestamp" in df.columns:
        df["datetime"] = pd.to_datetime(df["timestamp"])
    else:
        raise ValueError("No time column found - need either 'start' or 'timestamp'")

    base_time = df["datetime"].min()
    # NOTE(review): a row with no usable time value yields NaN here and the
    # integer cast fails - confirm the input always carries a time value.
    df["period"] = ((df["datetime"] - base_time).dt.total_seconds() //
                    (aggregation_interval * 3600)).astype(int)

    output_columns = [
        "patient_id",
        "practitioner_id",
        "period",
        "period_start",
        "period_end",
        "has_appointment",
        "has_observation",
        "has_encounter",
        "has_btg_access",
        "has_care_access",
        "num_btg_events",
        "num_care_events",
        "num_total_events",
        "avg_time_between_events",
        "num_unique_resources_accessed",
    ]

    # Every group shares the columns of df, so this check is loop-invariant.
    has_payload_column = "data" in df.columns
    grouped = df.groupby(["patient_id", "practitioner_id", "period"])
    aggregates = []

    for (patient_id, practitioner_id, period), group in grouped:
        event_types = group["type"]
        audit_events = group[event_types == "AuditEvent"]
        # Groups without any audit event (or without payloads at all) are
        # not part of the output.
        if not has_payload_column or audit_events.empty:
            continue

        period_start = base_time + pd.Timedelta(hours=aggregation_interval * period)
        period_end = period_start + pd.Timedelta(hours=aggregation_interval)

        has_appt = (event_types == "Appointment").any()
        has_obs = (event_types == "Observation").any()
        has_enc = (event_types == "Encounter").any()

        num_btg = 0
        num_care = 0
        for data in audit_events["data"]:
            # A missing payload shows up as NaN/None rather than a dict;
            # treat it as "no purpose" instead of failing on .get().
            purpose = data.get("purpose") if isinstance(data, dict) else None
            if purpose in ("EMERGENCY", "BTG"):
                num_btg += 1
            elif purpose == "CAREMGT":
                num_care += 1
        has_btg = num_btg > 0

        # CAREMGT access is only reported for windows that have no
        # appointment, observation, encounter or EMERGENCY/BTG access.
        if has_appt or has_obs or has_enc or has_btg:
            num_care = 0
        has_care = num_care > 0

        num_total_events = len(audit_events)
        avg_time_between = (audit_events["datetime"].sort_values().diff().dt.total_seconds().mean())
        num_unique_resources = audit_events["resource"].nunique() if "resource" in audit_events.columns else None

        aggregates.append({
            "patient_id": patient_id,
            "practitioner_id": practitioner_id,
            "period": period,
            "period_start": period_start,
            "period_end": period_end,
            "has_appointment": has_appt,
            "has_observation": has_obs,
            "has_encounter": has_enc,
            "has_btg_access": has_btg,
            "has_care_access": has_care,
            "num_btg_events": num_btg,
            "num_care_events": num_care,
            "num_total_events": num_total_events,
            "avg_time_between_events": avg_time_between,
            "num_unique_resources_accessed": num_unique_resources
        })

    # Explicit columns keep the schema stable even when nothing qualified.
    df_out = pd.DataFrame(aggregates, columns=output_columns)
    if df_out.empty:
        # Row-wise apply on an empty frame does not produce a Series, so the
        # column assignments below would fail; add the empty columns directly.
        df_out["table_id"] = pd.Series(dtype="int64")
        df_out["label"] = pd.Series(dtype="object")
        return df_out

    df_out["table_id"] = df_out.apply(calculate_table_id, axis=1)
    df_out["label"] = df_out.apply(determine_label, axis=1)
    return df_out

def calculate_table_id(row):
    """Encode the four context flags of an aggregate row as a table id.

    The flags are read as a 4-bit number, most significant bit first:
    ``has_appointment``, ``has_observation``, ``has_encounter``,
    ``has_btg_access``. The returned id is that number plus one, so boolean
    flags give ids from 1 (no flag set) to 16 (all flags set).
    """
    flags = (
        row["has_appointment"],
        row["has_observation"],
        row["has_encounter"],
        row["has_btg_access"],
    )
    code = 0
    for flag in flags:
        # Shift the bits collected so far and append the next flag.
        code = code * 2 + int(flag)
    return 1 + code

def determine_label(row):
    """Return the label for an aggregate row based on its ``table_id``.

    Rows with ``table_id`` 1 are labelled "Anomaly"; every other id is
    labelled "Normal".
    """
    return "Anomaly" if row["table_id"] == 1 else "Normal"

# Build the aggregated, labelled table from the configured events file.
result_df = create_histogram_data(json_path, aggregation_interval=aggregation_interval)

# Store the table next to the input; the file name records the window size.
csv_name = f"labeled_events_full-{aggregation_interval}h.csv"
csv_path = os.path.join(data_dir, csv_name)
result_df.to_csv(csv_path, index=False)
print(f"CSV saved: {csv_path}")

# Number of rows per label.
print("\n=== Summary ===")
label_counts = result_df["label"].value_counts()
print(label_counts)

# Number of rows per table id, in ascending id order.
print("\n=== Table ID ===")
table_id_counts = result_df["table_id"].value_counts().sort_index()
print(table_id_counts)
