In [80]:
# rl_bedrock_roster_strict.py
import boto3
import json
import tarfile
import tempfile
import os
import time
import random
import re
import io
import xgboost as xgb
import numpy as np
from botocore.config import Config

# -----------------------------
# AWS clients & config
# -----------------------------
config = Config(read_timeout=120, connect_timeout=30)  # increase timeouts (applied to the Bedrock client only)
bedrock_client = boto3.client("bedrock-runtime", region_name="us-east-1", config=config)
# NOTE: the S3 client is created without the custom Config or an explicit region,
# so it uses whatever the ambient AWS configuration provides.
s3_client = boto3.client("s3")

# Default model id used by query_bedrock(), and the bucket used for every S3 read/write below.
BEDROCK_MODEL_ID = "amazon.titan-text-premier-v1:0"
S3_BUCKET = "hospital-roster-data"

# ORIGINAL model path (used for loading initial model)
ORIG_MODEL_S3_KEY = "training/xgboost/output/sagemaker-xgboost-2025-09-20-01-07-22-211/output/model.tar.gz"

# NEW canonical save path (store updated models here)
# NOTE: this differs from ORIG_MODEL_S3_KEY, so retraining never overwrites the original artifact.
NEW_MODEL_S3_KEY = "training/xgboost/output/model.tar.gz"

# backup prefix (timestamped backups will go under this prefix)
MODEL_BACKUP_PREFIX = "training/xgboost/backup/"

# where to store episode rosters
# (raw and invalid model outputs are also written under this prefix by run_episodes)
EPISODE_ROSTER_PREFIX = "roster_history/rl_episode_rosters/"

# -----------------------------
# Helpers: load JSON from S3
# -----------------------------
def load_json_from_s3(bucket, key):
    """Fetch the object at ``bucket``/``key`` and return its UTF-8 body parsed as JSON.

    Errors from the S3 call or from JSON decoding propagate to the caller.
    """
    response = s3_client.get_object(Bucket=bucket, Key=key)
    raw_text = response["Body"].read().decode("utf-8")
    return json.loads(raw_text)

# -----------------------------
# Load S3 data (URIs as requested)
# -----------------------------
# These loads run at module import time and have no error handling here,
# so a failing S3 read or invalid JSON stops the script before anything else runs.
nurses = load_json_from_s3(S3_BUCKET, "raw_data/nurse_data/nurse.json")  # iterated as a list of nurse dicts further down
rules = load_json_from_s3(S3_BUCKET, "raw_data/rules.json")  # read via 'constraints', 'unavailability', 'preferences' in the prompt builder
demand = load_json_from_s3(S3_BUCKET, "raw_data/demand_data/demand.json")  # traversed as department -> day -> shifts when scoring
shift_def = load_json_from_s3(S3_BUCKET, "raw_data/shift.json")  # only embedded verbatim in the prompt
example_roster = load_json_from_s3(S3_BUCKET, "roster_history/roster_15092025.json")  # format example embedded in the prompt

print("Loaded S3 data: nurses, rules, demand, shift_def, example_roster")

# -----------------------------
# Load XGBoost model from original S3 URI (if available)
# -----------------------------
def load_model_tar_from_s3(bucket, key):
    """Download the object at ``bucket``/``key`` and return its bytes in an ``io.BytesIO``.

    The whole object is read into memory; S3 errors propagate to the caller.
    """
    payload = s3_client.get_object(Bucket=bucket, Key=key)["Body"].read()
    return io.BytesIO(payload)

def load_xgb_model_from_tarbytes(tar_bytes_io):
    """Extract a gzipped model tarball and load the first usable XGBoost booster.

    Args:
        tar_bytes_io: Binary file-like object holding the ``model.tar.gz`` bytes.

    Returns:
        An ``xgb.Booster`` loaded from the first top-level file that is named
        ``xgboost-model`` or ends in ``.model``/``.bin``/``.json`` and loads
        successfully, or ``None`` when nothing could be loaded. Errors are
        printed, never raised.
    """
    try:
        with tempfile.TemporaryDirectory() as tmpdir:
            tar_path = os.path.join(tmpdir, "model.tar.gz")
            with open(tar_path, "wb") as f:
                f.write(tar_bytes_io.read())

            with tarfile.open(tar_path, "r:gz") as tar:
                if hasattr(tarfile, "data_filter"):
                    # Extraction filters reject members that would land outside
                    # tmpdir (absolute paths, "..", escaping links, devices) and
                    # avoid the unfiltered-extractall DeprecationWarning.
                    tar.extractall(path=tmpdir, filter="data")
                else:
                    # Older interpreters without extraction filters: validate
                    # every member by hand before extracting anything.
                    root = os.path.realpath(tmpdir)
                    for member in tar.getmembers():
                        target = os.path.realpath(os.path.join(root, member.name))
                        if os.path.commonpath([root, target]) != root:
                            raise ValueError(f"Unsafe path in model archive: {member.name!r}")
                        if not (member.isfile() or member.isdir()):
                            raise ValueError(f"Unsupported member type in model archive: {member.name!r}")
                    tar.extractall(path=tmpdir)

            # Sorted so the candidate order does not depend on the filesystem.
            for fname in sorted(os.listdir(tmpdir)):
                cand = os.path.join(tmpdir, fname)
                if fname == "xgboost-model" or fname.endswith((".model", ".bin", ".json")):
                    booster = xgb.Booster()
                    try:
                        booster.load_model(cand)
                        print("Loaded XGBoost booster from", cand)
                        return booster
                    except Exception as e:
                        # Try other files if one fails
                        print("Failed to load candidate model file:", cand, e)
                        continue
    except Exception as e:
        print("Error loading model tar:", e)
    return None

xgb_model = None
try:
    # Download the archive stored at ORIG_MODEL_S3_KEY and try to load a Booster from it.
    tar_io = load_model_tar_from_s3(S3_BUCKET, ORIG_MODEL_S3_KEY)
    xgb_model = load_xgb_model_from_tarbytes(tar_io)
except Exception as e:
    # Best-effort: a failed download/load leaves xgb_model as None instead of stopping the script.
    print("Could not fetch/load initial XGBoost model:", e)
if xgb_model is None:
    # Without a booster, compute_compatibility_scores() assigns the neutral score 0.5 to every slot.
    print("Proceeding with fallback scoring (no booster loaded).")

# -----------------------------
# Bedrock call with backoff (no prompt changes)
# -----------------------------
def query_bedrock(prompt_text, model_id=BEDROCK_MODEL_ID, retries=6):
    """Invoke the Bedrock text model and return its generated text.

    Throttling errors are retried with exponential backoff plus jitter; any
    other error is printed and ends the attempt loop immediately.

    Args:
        prompt_text: Full prompt sent as ``inputText``.
        model_id: Bedrock model identifier to invoke.
        retries: Maximum number of invocation attempts.

    Returns:
        The ``outputText`` of the first result, or ``""`` when the call failed,
        every attempt was throttled, or the response carried no text.
    """
    request_body = json.dumps({
        "inputText": prompt_text,
        "textGenerationConfig": {
            "maxTokenCount": 3072,
            "temperature": 0.7,
            "topP": 0.9,
        }
    })
    for attempt in range(retries):
        try:
            resp = bedrock_client.invoke_model(
                modelId=model_id,
                contentType="application/json",
                accept="application/json",
                body=request_body
            )
            result_json = json.loads(resp["body"].read())
            # Tolerate a missing, null or empty "results" list.
            results = result_json.get("results") or [{}]
            # Normalise a null outputText to "" so callers can always slice the result.
            return results[0].get("outputText", "") or ""
        except bedrock_client.exceptions.ThrottlingException:
            if attempt == retries - 1:
                # No further attempt follows, so sleeping here would only delay the failure.
                print(f"Throttled; giving up after {retries} attempts")
                break
            wait = (2 ** attempt) + random.random()
            print(f"Throttled; backing off {wait:.2f}s")
            time.sleep(wait)
        except Exception as e:
            print("Bedrock invoke error:", e)
            break
    return ""

# -----------------------------
# KEEP THE PROMPT EXACTLY AS BEFORE
# -----------------------------
def build_roster_prompt(instruction, nurses, rules, demand, shift_def, example_roster):
    """Assemble the roster-generation prompt sent to the text model.

    The prompt wording is fixed; only the constraint values read from ``rules``,
    the JSON dumps of the data arguments and the final ``instruction`` vary.
    Raises ``KeyError`` when ``rules`` lacks one of the expected keys.
    """
    hard = rules['constraints']
    # One list entry per output line; joining with "\n" reproduces the leading
    # and trailing newline of the prompt through the empty first/last entries.
    prompt_lines = [
        "",
        "System Instructions:",
        "You are a nurse roster generator. Output valid JSON roster ONLY. Do not include code fences, explanations, or extra text at the beginning or end of the output. Do not put ```json at the start and end. Immediately start and end with open and close curly braces.",
        "Follow all HARD constraints strictly; respect SOFT preferences if possible.",
        "",
        "Constraints:",
        f"- Daily max hours: {hard['daily_hours_cap']}",
        f"- Weekly max hours: {hard['weekly_hours_cap']}",
        f"- Rest between shifts: {hard['rest_time_hours']}h",
        f"- Weekly rest days: {hard['weekly_rest_days']}",
        f"- Max assignments per slot: {hard['max_assignments_per_slot']['enabled']}",
        f"- Department balance: {hard['department_balance']['enabled']}",
        f"- Core skill requirement: {hard['core_skill_requirement']['enabled']}",
        f"- Skill mix requirement: {hard['skill_mix_requirement']['enabled']}",
        f"- Unavailability: {rules['unavailability']['rule']}",
        "",
        f"Soft preferences: {rules['preferences']['objective']}",
        "",
        "Example roster JSON format:",
        json.dumps(example_roster, indent=2),
        "",
        "Nurse Databse:",
        json.dumps(nurses, indent=2),
        "",
        "Shift demand:",
        json.dumps(demand, indent=2),
        "",
        "Shift definitions:",
        json.dumps(shift_def, indent=2),
        "",
        "Task:",
        f"{instruction}",
        "",
    ]
    return "\n".join(prompt_lines)

# -----------------------------
# Strict hierarchical roster schema validator
# -----------------------------
def is_valid_hierarchical_roster(obj):
    """
    Return True only when ``obj`` has the strict hierarchical roster shape:

      {"departments": [
          {"name": ..., "nurses": [
              {"id": ..., "shifts": [{"day": ..., "shift": ...}, ...]},
              ...
          ]},
          ...
      ]}

    ``departments`` must be a non-empty list; ``nurses`` and ``shifts`` must be
    lists (possibly empty). Only key presence and container types are checked,
    not the values stored under "name", "id", "day" or "shift".
    """
    def has_keys(item, *required):
        # A node qualifies when it is a dict carrying every required key.
        return isinstance(item, dict) and all(k in item for k in required)

    if not isinstance(obj, dict):
        return False
    departments = obj.get("departments")
    if not (isinstance(departments, list) and departments):
        return False

    for dept in departments:
        if not has_keys(dept, "name", "nurses") or not isinstance(dept["nurses"], list):
            return False
        for nurse in dept["nurses"]:
            if not has_keys(nurse, "id", "shifts") or not isinstance(nurse["shifts"], list):
                return False
            if not all(has_keys(entry, "day", "shift") for entry in nurse["shifts"]):
                return False
    return True

# -----------------------------
# Flatten hierarchical roster to flat assignments for training/eval
# -----------------------------
def flatten_hierarchical_roster(h):
    """Flatten a hierarchical roster into a list of per-shift assignment dicts.

    Each output entry has the keys ``nurse_id``, ``department``, ``day`` and
    ``shift``. Missing keys come through as ``None``.

    This function is also used to repair slightly malformed model output, so
    it skips any node of the wrong type (non-dict department/nurse/shift,
    non-list ``departments``/``nurses``/``shifts``) instead of raising.

    Args:
        h: Parsed roster object, expected shape ``{"departments": [...]}``.

    Returns:
        List of flat assignment dicts; empty when ``h`` is not a dict or holds
        no usable entries.
    """
    def as_list(value):
        # Anything that is not a list (None, str, dict, ...) counts as empty.
        return value if isinstance(value, list) else []

    flat = []
    if not isinstance(h, dict):
        return flat
    for d in as_list(h.get("departments")):
        if not isinstance(d, dict):
            continue
        dept_name = d.get("name")
        for nurse in as_list(d.get("nurses")):
            if not isinstance(nurse, dict):
                continue
            nid = nurse.get("id")
            for s in as_list(nurse.get("shifts")):
                if not isinstance(s, dict):
                    continue
                flat.append({
                    "nurse_id": nid,
                    "department": dept_name,
                    "day": s.get("day"),
                    "shift": s.get("shift"),
                })
    return flat

# -----------------------------
# Coerce possible flat roster to hierarchical (if model returned flat list)
# -----------------------------
def coerce_flat_to_hierarchical(flat_list):
    """
    Regroup a flat assignment list into the strict hierarchical roster format.

    Each usable entry is a dict naming a nurse ("nurse_id" / "id" / "nurse"),
    a department ("department" / "dept" / "department_name"), a day
    ("day" / "date") and a shift ("shift" / "time"). Entries that are not
    dicts, or that lack any of the four values, are dropped.

    Returns ``{"departments": [...]}`` with departments and nurses in order of
    first appearance, or ``None`` when ``flat_list`` is not a list.
    """
    if not isinstance(flat_list, list):
        return None

    def pick(record, *aliases):
        # First truthy value among the accepted key spellings, else None.
        for alias in aliases:
            value = record.get(alias)
            if value:
                return value
        return None

    # department -> nurse id -> list of {"day", "shift"} entries
    grouped = {}
    for entry in flat_list:
        if not isinstance(entry, dict):
            continue
        nurse_id = pick(entry, "nurse_id", "id", "nurse")
        department = pick(entry, "department", "dept", "department_name")
        day = pick(entry, "day", "date")
        shift = pick(entry, "shift", "time")
        if not all((nurse_id, department, day, shift)):
            # incomplete entry
            continue
        by_nurse = grouped.setdefault(department, {})
        by_nurse.setdefault(nurse_id, []).append({"day": day, "shift": shift})

    return {
        "departments": [
            {
                "name": department,
                "nurses": [{"id": nurse_id, "shifts": shifts} for nurse_id, shifts in by_nurse.items()],
            }
            for department, by_nurse in grouped.items()
        ]
    }

# -----------------------------
# Safe JSON extraction from model output (best-effort)
# -----------------------------
def safe_extract_json(text):
    """Best-effort JSON recovery from raw model output.

    Steps, stopping at the first that yields a value:
      1. remove markdown code-fence markers and parse the remaining text;
      2. parse the widest ``{...}`` or ``[...]`` span found in the text;
      3. parse that span again with trailing commas before ``]``/``}`` removed.

    Returns the decoded value, or ``None`` when ``text`` is empty or nothing parses.
    """
    if not text:
        return None

    def attempt(payload):
        # (parsed_ok, value) so a legitimately decoded null is not mistaken for failure.
        try:
            return True, json.loads(payload)
        except Exception:
            return False, None

    # --- Strip markdown code fences like ```json ... ``` ---
    stripped = re.sub(r"```(?:json)?", "", text, flags=re.IGNORECASE).strip("` \n")

    parsed_ok, value = attempt(stripped)
    if parsed_ok:
        return value

    # --- Locate the first JSON-looking object or array ---
    found = re.search(r"(\{.*\}|\[.*\])", stripped, re.DOTALL)
    if found is None:
        return None
    snippet = found.group(0)

    parsed_ok, value = attempt(snippet)
    if parsed_ok:
        return value

    # --- Last resort: drop trailing commas and retry ---
    parsed_ok, value = attempt(re.sub(r",\s*(\]|\})", r"\1", snippet))
    return value if parsed_ok else None

# -----------------------------
# Build features for one assignment (must match training order if you want to use xgb)
# -----------------------------
def build_feature_list(nurse_obj, dept, day, shiftname):
    """Build the 15-value numeric feature vector for one candidate assignment.

    Feature order (kept fixed so it can line up with the booster's inputs):
    experience, contracted hours, seniority (Junior=0, Mid=1, Senior=2),
    morning/evening/night preference flags, ER/General/ICU/OT/Pediatrics skill
    flags, hours in week, two violation flags (always 0 here) and a flag for a
    union term mentioning "rest".

    ``dept``, ``day`` and ``shiftname`` are accepted for interface
    compatibility but are not used in the features.

    Null or missing ``preferences``/``pref``, ``skills`` and ``union_terms``
    are treated as empty, and non-string entries in the preference and
    union-term lists are ignored instead of raising.

    Args:
        nurse_obj: Nurse record (dict).
        dept: Department name of the slot (unused).
        day: Day of the slot (unused).
        shiftname: Shift name of the slot (unused).

    Returns:
        List of 15 floats.
    """
    seniority_map = {"Junior": 0, "Mid": 1, "Senior": 2}
    cand_experience = nurse_obj.get("experience_years", nurse_obj.get("experience", 0))
    cand_hours_contract = nurse_obj.get("contracted_hours", nurse_obj.get("hours_contract", 0))
    cand_seniority_num = seniority_map.get(nurse_obj.get("seniority_level", nurse_obj.get("seniority", "Junior")), 0)

    prefs = nurse_obj.get("preferences", []) or nurse_obj.get("pref", []) or []
    # Lower-case once and ignore non-string entries, which cannot name a shift.
    pref_texts = [p.lower() for p in prefs if isinstance(p, str)]
    pref_morning = 1 if any("morning" in p for p in pref_texts) else 0
    pref_evening = 1 if any("evening" in p for p in pref_texts) else 0
    pref_night = 1 if any("night" in p for p in pref_texts) else 0

    skills = nurse_obj.get("skills", []) or []
    skill_ER = 1 if "ER" in skills else 0
    skill_General = 1 if "General" in skills else 0
    skill_ICU = 1 if "ICU" in skills else 0
    skill_OT = 1 if "OT" in skills else 0
    skill_Pediatrics = 1 if "Pediatrics" in skills else 0

    hours_in_week = nurse_obj.get("hours_in_week", 0)
    # Placeholders: no violation computation is performed for these two features.
    would_violate_45 = 0
    would_violate_8_per_day = 0
    # Same null-guard as skills/preferences: "union_terms": null means no terms.
    union_terms = nurse_obj.get("union_terms") or []
    has_rest_day = 1 if any("rest" in t.lower() for t in union_terms if isinstance(t, str)) else 0

    return [
        float(cand_experience),
        float(cand_hours_contract),
        float(cand_seniority_num),
        float(pref_morning),
        float(pref_evening),
        float(pref_night),
        float(skill_ER),
        float(skill_General),
        float(skill_ICU),
        float(skill_OT),
        float(skill_Pediatrics),
        float(hours_in_week),
        float(would_violate_45),
        float(would_violate_8_per_day),
        float(has_rest_day),
    ]

# -----------------------------
# Compatibility scoring (uses xgb if available otherwise fallback)
# -----------------------------
def compute_compatibility_scores(nurses, demand, xgb_model=None):
    """Score every (nurse, department, day, shift) combination.

    All feature rows are collected first and scored with a single
    ``DMatrix``/``predict`` call, instead of building one DMatrix per row.
    If no model is given, or scoring fails, every slot gets the neutral
    score 0.5. A scoring failure is printed rather than silently ignored.

    Args:
        nurses: Iterable of nurse dicts; the id is taken from "nurse_id",
            "id" or "name" (first truthy one).
        demand: Mapping department -> day -> iterable of shift names.
        xgb_model: Optional booster exposing ``predict``.

    Returns:
        Dict mapping ``(nurse_id, department, day, shift)`` to a float score.
    """
    slot_keys = []
    feature_rows = []
    for nurse in nurses:
        nid = nurse.get("nurse_id") or nurse.get("id") or nurse.get("name")
        for dept_name, dept_data in demand.items():
            for day_name, shifts in dept_data.items():
                for shift_name in shifts:
                    slot_keys.append((nid, dept_name, day_name, shift_name))
                    feature_rows.append(build_feature_list(nurse, dept_name, day_name, shift_name))

    # Start from the neutral fallback so every slot has a score whatever happens below.
    scores = dict.fromkeys(slot_keys, 0.5)
    if xgb_model is None or not slot_keys:
        return scores

    try:
        matrix = xgb.DMatrix(np.array(feature_rows, dtype=np.float32))
        predictions = [float(p) for p in xgb_model.predict(matrix)]
    except Exception as e:
        # Best-effort scoring: report the problem (e.g. a feature-count mismatch)
        # and keep the neutral scores instead of failing the whole run.
        print("XGBoost scoring failed; falling back to neutral 0.5 scores:", e)
        return scores

    scores.update(zip(slot_keys, predictions))
    return scores

# Scores for all (nurse, department, day, shift) combinations, computed once at import time;
# run_episodes() rebinds this global after each successful retrain.
compatibility_scores = compute_compatibility_scores(nurses, demand, xgb_model)

# -----------------------------
# Evaluate a flat roster (list of assignments)
# -----------------------------
def evaluate_flat_roster(flat_assignments, compatibility_scores, hours_per_shift=8,
                         weekly_hours_cap=45, overtime_penalty=0.1):
    """Compute the reward of a flat roster.

    The reward is the sum of the compatibility scores of all complete
    assignments (0.5 for slots without a score), minus a penalty for every
    hour a nurse is scheduled above the weekly cap.

    Args:
        flat_assignments: Iterable of dicts with "nurse_id", "department",
            "day" and "shift"; entries missing any of these are ignored.
        compatibility_scores: Dict keyed by (nurse_id, department, day, shift).
        hours_per_shift: Hours counted for each assignment (default 8).
        weekly_hours_cap: Hours per nurse allowed before penalties (default 45).
        overtime_penalty: Penalty per hour above the cap (default 0.1).

    Returns:
        Total reward minus overtime penalty.
    """
    reward = 0.0
    nurse_hours = {}

    for a in flat_assignments:
        nid = a.get("nurse_id")
        dept = a.get("department")
        day = a.get("day")
        shiftname = a.get("shift")
        if not (nid and dept and day and shiftname):
            continue
        reward += compatibility_scores.get((nid, dept, day, shiftname), 0.5)
        nurse_hours[nid] = nurse_hours.get(nid, 0) + hours_per_shift

    penalty = sum(
        (hrs - weekly_hours_cap) * overtime_penalty
        for hrs in nurse_hours.values()
        if hrs > weekly_hours_cap
    )
    return reward - penalty

# -----------------------------
# Retrain XGBoost on episode samples and upload with backup
# -----------------------------
def retrain_and_upload_xgb(samples_X, samples_y, s3_bucket, save_key=NEW_MODEL_S3_KEY, backup_prefix=MODEL_BACKUP_PREFIX, num_rounds=25):
    """Train a fresh XGBoost regressor on the samples and upload it to S3.

    The model currently stored at ``save_key`` (if any) is first copied to
    ``<backup_prefix><timestamp>/model.tar.gz``. The backup is best-effort: a
    missing previous model or a failed backup is reported and the upload still
    proceeds.

    Args:
        samples_X: Sequence of feature rows.
        samples_y: Sequence of labels, one per row.
        s3_bucket: Bucket used for backup and upload.
        save_key: Key the new ``model.tar.gz`` is uploaded to.
        backup_prefix: Prefix under which the timestamped backup is stored.
        num_rounds: Number of boosting rounds.

    Returns:
        The trained booster, or ``None`` when there were no samples.
    """
    if len(samples_X) == 0:
        print("No samples to train on; skipping retrain/upload.")
        return None
    X = np.array(samples_X, dtype=np.float32)
    y = np.array(samples_y, dtype=np.float32)
    dtrain = xgb.DMatrix(X, label=y)
    params = {"objective":"reg:squarederror", "max_depth":4, "eta":0.1, "verbosity":0}
    booster = xgb.train(params, dtrain, num_boost_round=num_rounds)

    with tempfile.TemporaryDirectory() as tmpdir:
        model_path = os.path.join(tmpdir, "xgboost-model")
        booster.save_model(model_path)

        # Package the model under the same member name the loader looks for.
        tar_path = os.path.join(tmpdir, "model.tar.gz")
        with tarfile.open(tar_path, "w:gz") as tar:
            tar.add(model_path, arcname="xgboost-model")

        # backup previous model (if exists) to backup_prefix + timestamp
        timestamp = time.strftime("%Y%m%d-%H%M%S")
        backup_key = f"{backup_prefix}{timestamp}/model.tar.gz"

        try:
            # Existence check, kept separate from the copy so the two failure
            # modes are reported accurately.
            s3_client.head_object(Bucket=s3_bucket, Key=save_key)
        except s3_client.exceptions.ClientError as e:
            error_code = str(e.response.get("Error", {}).get("Code", ""))
            if error_code in ("404", "NoSuchKey", "NotFound"):
                print("No existing saved model to backup (ok).")
            else:
                # e.g. missing permissions: not the same as "no previous model".
                print("Could not check for an existing saved model; skipping backup:", e)
        else:
            try:
                s3_client.copy_object(Bucket=s3_bucket, CopySource={"Bucket": s3_bucket, "Key": save_key}, Key=backup_key)
                print("Backed up previous model to", backup_key)
            except s3_client.exceptions.ClientError as e:
                print("Backup of previous model failed; continuing with upload:", e)

        # Upload new model to save_key (defaults to NEW_MODEL_S3_KEY)
        s3_client.upload_file(tar_path, s3_bucket, save_key)
        print("Uploaded new model to", save_key)

    return booster

# -----------------------------
# Save roster to S3 (episode)
# -----------------------------
def save_roster_to_s3(roster_obj, ep_idx):
    """Store ``roster_obj`` as indented JSON under the episode roster prefix.

    The key combines the current Unix timestamp with ``ep_idx``. Returns the
    S3 key that was written.
    """
    stamp = int(time.time())
    key = f"{EPISODE_ROSTER_PREFIX}roster_ep_{stamp}_{ep_idx}.json"
    payload = json.dumps(roster_obj, indent=2).encode("utf-8")
    s3_client.put_object(Bucket=S3_BUCKET, Key=key, Body=payload)
    print("Saved roster to s3://{}/{}".format(S3_BUCKET, key))
    return key

# -----------------------------
# Full RL episode loop (minimal policy; focus on generation+validate+retrain+upload)
# -----------------------------
def run_episodes(num_episodes=1):
    """Run the generate -> validate -> evaluate -> retrain loop.

    For each episode: build the prompt, query Bedrock, parse the output as
    JSON, validate (or coerce) it into the strict hierarchical roster format,
    save it to S3, compute its reward, build one training sample per
    assignment, retrain and upload the XGBoost model, then refresh the
    module-level ``xgb_model`` and ``compatibility_scores``.

    Episodes whose output cannot be parsed or coerced are saved to S3 for
    inspection and skipped. Nothing is returned.

    Args:
        num_episodes: Number of episodes to run.
    """
    # Both globals are rebound after a successful retrain (end of the loop body).
    global xgb_model, compatibility_scores
    for ep in range(num_episodes):
        print("\n=== Episode", ep+1, "===")
        prompt = build_roster_prompt(
            instruction="Generate a JSON roster for the week. Make sure to refer to the example roster to get the structured JSON format. Only output the JSON roster! Do not include code fences, explanations, or extra text at the beginning or end of the output. Do not put ```json at the start and end. Immediately start and end with open and close curly braces",
            nurses=nurses,
            rules=rules,
            demand=demand,
            shift_def=shift_def,
            example_roster=example_roster
        )

        # query_bedrock returns "" on failure, which falls into the "could not parse" branch below.
        out = query_bedrock(prompt)
        print("Raw model output (first 5000 chars):\n", out[:5000])

        # try to extract JSON
        parsed = safe_extract_json(out)
        roster_hier = None
        if parsed is None:
            print("Could not parse JSON from model output. Saving raw and skipping this episode.")
            # NOTE: the raw text is stored under a ".json" key even though it did not parse as JSON.
            s3_client.put_object(Bucket=S3_BUCKET, Key=f"{EPISODE_ROSTER_PREFIX}raw_ep_{int(time.time())}_{ep}.json", Body=out.encode("utf-8"))
            continue

        # If parsed is already hierarchical, validate
        if is_valid_hierarchical_roster(parsed):
            roster_hier = parsed
        else:
            # If parsed is a dict with departments key but slightly malformed, try best-effort normalize:
            if isinstance(parsed, dict) and "departments" in parsed:
                # attempt flatten/repair then validate
                # (a flatten -> regroup round trip drops incomplete entries)
                flat_try = flatten_hierarchical_roster(parsed)
                if flat_try:
                    coerced = coerce_flat_to_hierarchical(flat_try)
                    if coerced and is_valid_hierarchical_roster(coerced):
                        roster_hier = coerced
            # If parsed is a flat list, try coercion to hierarchical
            if roster_hier is None and isinstance(parsed, list):
                coerced = coerce_flat_to_hierarchical(parsed)
                if coerced and is_valid_hierarchical_roster(coerced):
                    roster_hier = coerced

        if roster_hier is None:
            # Not valid / not fixable
            print("Model output is not in the strict hierarchical format and could not be coerced. Saving raw for inspection.")
            s3_client.put_object(Bucket=S3_BUCKET, Key=f"{EPISODE_ROSTER_PREFIX}invalid_ep_{int(time.time())}_{ep}.json", Body=json.dumps(parsed, indent=2).encode("utf-8"))
            continue

        # roster_hier is validated strict format. Save it:
        save_roster_to_s3(roster_hier, ep)

        # Flatten and evaluate
        flat_assignments = flatten_hierarchical_roster(roster_hier)
        ep_reward = evaluate_flat_roster(flat_assignments, compatibility_scores)
        print("Episode reward:", ep_reward)

        # Build training samples: each assignment -> features; label = ep_reward / num_assignments
        # (every sample of an episode therefore carries the same label)
        samples_X = []
        samples_y = []
        per_label = float(ep_reward) / max(1, len(flat_assignments))  # max(1, ...) avoids division by zero
        for a in flat_assignments:
            nid = a.get("nurse_id")
            # find nurse object: first nurse whose "nurse_id", "id" or "name" equals nid
            nurse_obj = next((n for n in nurses if n.get("nurse_id") == nid or n.get("id") == nid or n.get("name") == nid), None)
            if nurse_obj is None:
                # skip unknown nurse ids
                continue
            feat = build_feature_list(nurse_obj, a.get("department"), a.get("day"), a.get("shift"))
            samples_X.append(feat)
            samples_y.append(per_label)

        if len(samples_X) == 0:
            print("No valid samples to retrain on this episode.")
            continue

        # Retrain small xgboost on these samples and upload with backup
        new_booster = retrain_and_upload_xgb(samples_X, samples_y, S3_BUCKET, save_key=NEW_MODEL_S3_KEY, backup_prefix=MODEL_BACKUP_PREFIX, num_rounds=20)
        if new_booster:
            xgb_model = new_booster
            # recompute compatibility scores using updated model
            compatibility_scores = compute_compatibility_scores(nurses, demand, xgb_model)
            print("Updated compatibility scores from retrained model.")

# -----------------------------
# Run the loop (change episodes as needed)
# -----------------------------
if __name__ == "__main__":
    # One episode by default; each episode makes one Bedrock call and may retrain and upload the model.
    run_episodes(num_episodes=1)


Loaded S3 data: nurses, rules, demand, shift_def, example_roster
Loaded XGBoost booster from /tmp/tmpwwj7py1r/xgboost-model


  tar.extractall(path=tmpdir)



=== Episode 1 ===
Raw model output (first 5000 chars):
 {
  "departments": [
    {
      "name": "General",
      "nurses": [
        {
          "id": "N001",
          "shifts": [
            {
              "day": "Mon",
              "shift": "Full-Morning"
            },
            {
              "day": "Tue",
              "shift": "Full-Morning"
            },
            {
              "day": "Wed",
              "shift": "Full-Morning"
            },
            {
              "day": "Thu",
              "shift": "Full-Morning"
            },
            {
              "day": "Fri",
              "shift": "Full-Morning"
            },
            {
              "day": "Sat",
              "shift": "Full-Morning"
            },
            {
              "day": "Sun",
              "shift": "Full-Morning"
            }
          ]
        },
        {
          "id": "N002",
          "shifts": [
            {
              "day": "Mon",
              "shift": "Full-Eve