In [None]:
import sys

# Show which Python interpreter backs this kernel, so installs target the same env.
print(f"{sys.executable}")

# Uncomment to install into this kernel's environment (%pip targets the running kernel):
# %pip install flash-attn==2.8.0.post2

In [None]:
import torch

# Quick GPU sanity check before any model is loaded.
print("CUDA available?", torch.cuda.is_available())
print("CUDA device count:", torch.cuda.device_count())
if torch.cuda.is_available():
    print("CUDA device name:", torch.cuda.get_device_name(0))
else:
    print("CUDA device name:", "N/A")


In [None]:
# Cell 1: Setup and Global Model Loading
print("Starting cell 1")
# Standard imports for model loading and type hinting
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch  # used below for device selection
from typing import Any, Dict

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

if device.type == 'cuda':
    print(f"CUDA device name: {torch.cuda.get_device_name(0)}")  # Prints GPU name (e.g., NVIDIA A40)


# Define the path to your model (local directory or Hub repo id)
# MODEL_NAME = "models/Meta-Llama-3-8B-Instruct"
MODEL_NAME = "hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4"

# --- Global Model Loading Logic ---
# Load at most once per kernel session; re-running this cell reuses the objects
# already in memory. A (re)load happens when:
#   * either global_tokenizer or global_model is missing (e.g. an earlier run
#     failed after the tokenizer loaded but before the model did), or
#   * MODEL_NAME differs from the name recorded at the last load.
# Restarting the kernel clears everything, so the model is loaded again.
_needs_load = (
    "global_tokenizer" not in globals()
    or "global_model" not in globals()
    # Default to MODEL_NAME so sessions loaded before this name was tracked are not forced to reload.
    or globals().get("global_model_name", MODEL_NAME) != MODEL_NAME
)
if _needs_load:
    print(f"Loading tokenizer and model '{MODEL_NAME}' for the first time...")

    # Load tokenizer
    global_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)

    # Load model, then move it to the device selected above when that is a GPU.
    global_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    if device.type == "cuda":
        print("CUDA available")
        global_model.to(device)

    # Remember what is loaded so a MODEL_NAME change triggers a reload next time.
    global_model_name = MODEL_NAME
    print("Model loaded.")
else:
    print("Model already loaded. Reusing existing instances.")

# Optional: Print basic info to confirm
print(f"Tokenizer: {type(global_tokenizer)}")
print(f"Model: {type(global_model)}")

Starting cell 1
Using device: cuda
CUDA device name: NVIDIA A40
Loading tokenizer and model 'hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4' for the first time...


  @custom_fwd
  @custom_bwd
  @custom_fwd(cast_inputs=torch.float16)
CUDA extension not installed.
CUDA extension not installed.
`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.


In [None]:
%load_ext autoreload
%autoreload 2

import json, pprint, os, sys
from IPython.display import Image, display
from datetime import datetime
from agents.orchestrator import orchestrator
from agents.information_gatherer import information_gatherer
from agents.reflector import reflector
from agents.final_formatter import final_formatter
from graph import build_graph
from agents.BaseAgent import AgentState
print("Starting this cell!")

pp = pprint.PrettyPrinter(indent=2, width=100)

initial_state = {
    "input": """I need to generate a storyboard (images and text needed, remmeber to find and extract videos) 
    about the Singapore SG60 National Day Parade, and try to encourage some nationalism and patriotism""",   # your top-level query
    "plan": None,
    "knowledge": [],
    "draft_story": None,
    "reflection": None,
    "final_output": None,
    "status": "needs_info", 
    "candidates": None,
    "needs": []
}

# Build Graph.
# The orchestrator and final formatter need the globally loaded model/tokenizer,
# so they are wrapped into one-argument callables taking the graph state;
# information_gatherer is passed to build_graph unchanged.
# Named functions (not lambdas) so they show up with readable names and their
# parameter does not shadow the global `initial_state`.
def orch_node(state):
    """Graph node: run the orchestrator on the current state."""
    return orchestrator(global_model, global_tokenizer, state)


def final_node(state):
    """Graph node: run the final formatter on the current state."""
    return final_formatter(global_model, global_tokenizer, state)


app = build_graph(
    orchestrator_node=orch_node,
    information_gatherer_node=information_gatherer,
    final_formatter_node=final_node,
)

# Show a diagram of the graph if it can be rendered. This is best-effort:
# any rendering failure is reported and the run continues.
try:
    png_bytes = app.get_graph().draw_mermaid_png()
    display(Image(data=png_bytes))
except Exception as e:
    print("[note] couldn't render mermaid png:", e)

# Run the whole graph once.
# For per-step debugging, iterate `app.stream(initial_state, stream_mode="updates")`
# instead; each update is a dict of the state writes made at that step. Do not
# also call app.invoke afterwards, or the graph (and every LLM call) runs twice.
final_state = app.invoke(initial_state)
print("Final state:")
pp.pprint(final_state)


In [None]:
# Debug check: do the notebook and agents.PlanningAgent see the *same*
# AgentState class object?
# NOTE(review): presumably added to diagnose duplicate class objects
# (e.g. after %autoreload re-imports a module). Remove once resolved.
from agents.BaseAgent import AgentState as NB_AgentState  # the class this notebook imports
print(f"NB AgentState: {NB_AgentState} {NB_AgentState.__module__}")

import agents.PlanningAgent as PA  # the class PlanningAgent resolves
print(f"PA.AgentState: {PA.AgentState} {PA.AgentState.__module__}")

import agents.BaseAgent as SS  # canonical home of AgentState; update if it moves
print(f"PlanningAgent uses same AgentState as SS? {PA.AgentState is SS.AgentState}")


In [None]:
# To display the eventual storyboard as an output

from IPython.display import display, HTML

def _scenes_from_state(state):
    """Extract scene dicts ({'image_path', 'caption'}) from a graph state dict.

    Uses the "storyboard" list inside state["final_output"] (a JSON string or an
    already-decoded dict), taking each caption from 'frame_caption' and falling
    back to 'caption'. When that yields nothing, uses the dict entries of
    state["knowledge"] instead.
    """
    import json
    from collections.abc import Mapping

    final_output = state.get("final_output")
    if isinstance(final_output, str) and final_output.strip():
        try:
            final_output = json.loads(final_output)
        except json.JSONDecodeError:
            final_output = None  # not JSON: fall back to the knowledge list

    scenes = []
    if isinstance(final_output, Mapping):
        scenes = [
            {
                "image_path": s.get("image_path") or "",
                "caption": s.get("frame_caption") or s.get("caption") or "",
            }
            for s in final_output.get("storyboard") or []
            if isinstance(s, Mapping)
        ]
    if not scenes:
        scenes = [k for k in state.get("knowledge") or [] if isinstance(k, Mapping)]
    return scenes


def display_storyboard(storyboard):
    """
    Display each storyboard scene as an image with its caption underneath.

    Args:
        storyboard: a list of dicts with keys 'image_path' and 'caption', or a
            graph state dict such as ``final_state``. For a state dict, scenes
            come from its "final_output" storyboard or, failing that, its
            "knowledge" list.

    Missing keys render as empty strings. Paths and captions are HTML-escaped
    so model-generated text cannot break the page markup.
    """
    import html
    from collections.abc import Mapping

    # Passing the whole state used to iterate its keys (strings) and crash.
    if isinstance(storyboard, Mapping):
        storyboard = _scenes_from_state(storyboard)

    for scene in storyboard:
        src = html.escape(str(scene.get("image_path") or ""), quote=True)
        caption = html.escape(str(scene.get("caption") or ""))
        display(HTML(f"""
        <figure style="max-width:600px; margin: 20px 0;">
            <img src="{src}" 
                 style="width:100%; height:auto; border:1px solid #ccc; border-radius:4px;" />
            <figcaption style="text-align:center; font-style:italic; color:#555; margin-top:5px;">
                {caption}
            </figcaption>
        </figure>
        """))


display_storyboard(final_state)
        
        



In [None]:
import os, json
from IPython.display import HTML, display

def _to_served_src(p: str) -> str:
    # debug: print before and after
    # print(f'Before: {p}')
    # strip stray spaces before extension
    p = p.replace(".jpg", " .jpg").replace(".png", " .png")
    # remove any leading /home/jovyan/ if it exists
    p = p.replace("/home/jovyan/", "")
    # print(f'After: {p}')
    # return relative path (browser will look under Jupyter's served root)
    return p


def _collect_storyboard_items(final_state):
    """Collect ``{"image_path", "caption"}`` dicts to render from a graph state.

    Prefers the "storyboard" list in ``final_state["final_output"]`` (a JSON
    string, or an already-decoded dict), taking each caption from
    'frame_caption' and falling back to 'caption'. When that yields nothing,
    uses the dict entries of ``final_state["knowledge"]`` instead.
    """
    import json
    from collections.abc import Mapping

    fo = final_state.get("final_output")
    if isinstance(fo, str):
        if not fo.strip():
            fo = None
        else:
            try:
                fo = json.loads(fo)
            except json.JSONDecodeError as exc:
                # Best-effort: report and fall back instead of silently hiding the problem.
                print(f"[note] final_output is not valid JSON ({exc}); using knowledge instead")
                fo = None

    items = []
    if isinstance(fo, Mapping):
        for scene in fo.get("storyboard") or []:
            if isinstance(scene, Mapping):
                items.append({
                    "image_path": scene.get("image_path") or "",
                    "caption": scene.get("frame_caption") or scene.get("caption") or "",
                })
    if not items:
        for k in final_state.get("knowledge") or []:
            if isinstance(k, Mapping):
                items.append({
                    "image_path": k.get("image_path") or "",
                    "caption": k.get("caption") or "",
                })
    return items


def display_storyboard(final_state):
    """Render the graph's storyboard as numbered image + caption figures.

    Args:
        final_state: the state dict returned by ``app.invoke``. Items come from
            its "final_output" storyboard, else from its "knowledge" list.

    Each image path goes through ``_to_served_src``. A red warning is shown
    before any scene whose path is empty or does not exist relative to the
    kernel's working directory. Paths and captions are HTML-escaped so
    model-generated text cannot break the markup.
    """
    import html
    import os

    parts = []
    for i, it in enumerate(_collect_storyboard_items(final_state), 1):
        served = _to_served_src(it["image_path"])
        # NOTE(review): _to_served_src never adds a "/files" prefix, so this strip is
        # currently a no-op; the existence check is relative to the kernel's cwd.
        fs_path = served.removeprefix("/files")
        if not served:
            parts.append("<p style='color:#c00'>[Missing image path]</p>")
        elif not os.path.exists(fs_path):
            parts.append(f"<p style='color:#c00'>[Missing file] {html.escape(fs_path)}</p>")
        # Skip the <img> entirely when there is no path, rather than render a broken image.
        img = (
            f'<img src="{html.escape(served, quote=True)}" '
            'style="width:100%;height:auto;border:1px solid #ddd;border-radius:6px" />'
            if served else ""
        )
        parts.append(f"""
        <figure style="max-width:700px;margin:16px 0">
          {img}
          <figcaption style="text-align:center;color:#555;margin-top:6px">
            <strong>Scene {i}.</strong> {html.escape(str(it['caption']))}
          </figcaption>
        </figure>
        """)
    display(HTML("\n".join(parts) if parts else "<em>No storyboard items.</em>"))
display_storyboard(final_state)