In [None]:
%%capture
# Install Unsloth and the packages it needs. The %%capture cell magic above must stay
# on the first line of the cell; it silences pip's output.
import os
# Colab is detected by looking for any environment variable whose name contains "COLAB_".
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    # The --no-deps installs skip dependency resolution; xformers is pinned to an exact version.
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets>=3.4.1,<4.0.0" "huggingface_hub>=0.34.0" hf_transfer
    !pip install --no-deps unsloth

In [None]:
%%capture
# Upgrade timm to its latest release without touching its dependencies (--no-deps).
# Note: this cell does not install or upgrade transformers.
!pip install --no-deps --upgrade timm # Only for Gemma 3N

In [None]:
from unsloth import FastModel
import torch

# Reference list of pre-quantised checkpoints. It is wrapped in a string literal, so
# nothing below is executed and `fourbit_models` is never defined.
"""
fourbit_models = [
    # 4bit dynamic quants for superior accuracy and low memory use
    "unsloth/gemma-3n-E4B-it-unsloth-bnb-4bit",
    "unsloth/gemma-3n-E2B-it-unsloth-bnb-4bit",
    # Pretrained models
    "unsloth/gemma-3n-E4B-unsloth-bnb-4bit",
    "unsloth/gemma-3n-E2B-unsloth-bnb-4bit",

    # Other Gemma 3 quants
    "unsloth/gemma-3-1b-it-unsloth-bnb-4bit",
    "unsloth/gemma-3-4b-it-unsloth-bnb-4bit",
    "unsloth/gemma-3-12b-it-unsloth-bnb-4bit",
    "unsloth/gemma-3-27b-it-unsloth-bnb-4bit",
] # More models at https://huggingface.co/unsloth
"""


In [None]:

import json
import csv
import os
from datetime import timedelta
from typing import List, Tuple


In [None]:

def generate_chunks(audio_length_sec: int, window_size=15, overlap=5) -> List[Tuple[int, int, str]]:
    """Plan overlapping analysis windows over an audio track of known length.

    Args:
        audio_length_sec: Total length of the audio in whole seconds.
        window_size: Length of each window in whole seconds. Must be positive.
        overlap: Seconds shared by two consecutive windows. Must be smaller
            than ``window_size`` so that every window starts later than the
            previous one.

    Returns:
        A list of ``(start, end, filename)`` tuples in chronological order,
        where ``filename`` is ``chunk_<start>_<end>.mp3`` with both values
        zero-padded to four digits. The last window is clipped to
        ``audio_length_sec``. An empty list is returned for a non-positive length.

    Raises:
        ValueError: If ``window_size`` is not positive or ``overlap`` is not
            smaller than ``window_size``.
    """
    if window_size <= 0:
        raise ValueError(f"window_size must be positive, got {window_size}")
    if overlap >= window_size:
        # A zero or negative step would never advance `start` and the loop
        # below would append chunks forever.
        raise ValueError(
            f"overlap ({overlap}) must be smaller than window_size ({window_size})"
        )

    step = window_size - overlap
    chunks = []
    start = 0
    while start < audio_length_sec:
        end = min(start + window_size, audio_length_sec)
        chunks.append((start, end, f"chunk_{start:04d}_{end:04d}.mp3"))
        start += step
    return chunks


In [None]:

def prune_summary(summary_list, max_entries=50):
    """Cap the running summary at ``max_entries`` items.

    When the list is too long, the oldest entries are folded into a single
    ``"Earlier: a; b; c"`` note that takes the first slot, followed by the
    most recent ``max_entries - 1`` entries unchanged.

    Args:
        summary_list: Summary strings, oldest first.
        max_entries: Maximum number of items in the returned list (at least 1).

    Returns:
        ``summary_list`` itself when it already fits, otherwise a new list of
        exactly ``max_entries`` items.

    Raises:
        ValueError: If ``max_entries`` is smaller than 1.
    """
    if max_entries < 1:
        raise ValueError(f"max_entries must be at least 1, got {max_entries}")
    if len(summary_list) <= max_entries:
        return summary_list

    prefix = "Earlier: "
    # Slice from the front with an explicit index: a negative slice such as
    # [-0:] would return the whole list when max_entries == 1.
    split_at = len(summary_list) - (max_entries - 1)
    older = list(summary_list[:split_at])
    recent = list(summary_list[split_at:])

    # The first entry may already be a condensed note from a previous call;
    # drop its prefix so repeated pruning does not produce "Earlier: Earlier: ...".
    if older and older[0].startswith(prefix):
        older[0] = older[0][len(prefix):]

    return [prefix + "; ".join(older)] + recent


In [None]:
import re
import json

def _extract_json_text(decoded):
    """Return the JSON object text from the model's reply in a decoded sequence."""
    # The decoded sequence echoes the prompt, which itself contains a JSON
    # template, so only the text after the last model-turn marker is considered.
    reply = decoded.rsplit("<start_of_turn>model", 1)[-1]
    reply = reply.split("<end_of_turn>", 1)[0]
    # Taking the outermost braces works with or without a ```json fence.
    first_brace = reply.find("{")
    last_brace = reply.rfind("}")
    if first_brace == -1 or last_brace < first_brace:
        return reply.strip()
    return reply[first_brace:last_brace + 1]


def _fallback_result(chunk_start, chunk_end):
    """Build the placeholder result used when the model reply cannot be parsed."""
    return {
        "chunk_start": chunk_start,
        "chunk_end": chunk_end,
        "transcript": "",
        "tone": {"label": "", "confidence": 0.0},
        "risk": {"level": "LOW", "reason": "Parsing failed"},
        "updated_summary": "[Parsing failed]"
    }


def _normalise_result(parsed, chunk_start, chunk_end):
    """Fill in every key the downstream steps read, keeping the model's values."""
    result = dict(parsed)
    # The caller knows the real timestamps; the model only echoes them and may
    # get them wrong, which would break event merging and plotting.
    result["chunk_start"] = chunk_start
    result["chunk_end"] = chunk_end
    result.setdefault("transcript", "")

    tone = dict(result["tone"]) if isinstance(result.get("tone"), dict) else {}
    tone.setdefault("label", "")
    tone.setdefault("confidence", 0.0)
    result["tone"] = tone

    risk = dict(result["risk"]) if isinstance(result.get("risk"), dict) else {}
    # Risk levels are compared against upper-case literals elsewhere in the notebook.
    risk["level"] = str(risk.get("level", "LOW")).strip().upper()
    risk.setdefault("reason", "")
    result["risk"] = risk

    result.setdefault("updated_summary", "")
    return result


def inference_function(model, tokenizer, chunk_path, summary_so_far, chunk_start, chunk_end):
    """Analyse one audio chunk with the model and return its parsed JSON verdict.

    Args:
        model: Object exposing ``generate(**inputs, ...)``.
        tokenizer: Object exposing ``apply_chat_template`` and ``batch_decode``.
        chunk_path: Path of the audio file for this chunk.
        summary_so_far: List of summary strings from earlier chunks (may be empty).
        chunk_start: Chunk start time as a string, inserted into the prompt.
        chunk_end: Chunk end time as a string, inserted into the prompt.

    Returns:
        A dict that always has the keys ``chunk_start``, ``chunk_end``,
        ``transcript``, ``tone`` (``label``, ``confidence``), ``risk``
        (``level``, ``reason``) and ``updated_summary``. ``chunk_start`` and
        ``chunk_end`` are the values passed in, not the model's echo. When the
        reply is not a JSON object, a LOW-risk "Parsing failed" placeholder is
        returned instead of raising.
    """
    summary_text = "\n".join(summary_so_far) if summary_so_far else "None so far."
    prompt = f"""
You are analysing a 15-second cockpit audio segment in the context of all important events so far.

Previous important events & tone trends:
{summary_text}

Current audio segment start time: {chunk_start}
Current audio segment end time: {chunk_end}

Your tasks:
1. Transcribe the spoken content of the current audio accurately.
2. Detect tone/emotion in this segment (e.g., calm, panicked, urgent, frustrated, angry, distressed).
3. Assess risk level using both:
   - The current 15-second audio segment.
   - The summary of previous important events and tone trends.
4. Explain the reasoning for your risk assessment, referring to both current and past context where relevant.
5. Keep the reasoning short but precise.
6. Do not include any text outside of the JSON structure.

Respond in **valid JSON** exactly in this format:
{{
  "chunk_start": "HH:MM:SS",
  "chunk_end": "HH:MM:SS",
  "transcript": "<full transcription>",
  "tone": {{
    "label": "<tone label>",
    "confidence": <0.0 to 1.0>
  }},
  "risk": {{
    "level": "<HIGH | MEDIUM | LOW>",
    "reason": "<short explanation>"
  }},
  "updated_summary": "<brief important events from this segment to append to summary>"
}}
"""

    messages = [{
        "role": "user",
        "content": [
            {"type": "audio", "audio": chunk_path},
            {"type": "text", "text": prompt}
        ]
    }]

    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt"
    ).to("cuda")

    response = model.generate(
        **inputs,
        max_new_tokens=512,
        temperature=1.0,
        top_p=0.95,
        top_k=64
    )

    decoded = tokenizer.batch_decode(response)[0]
    json_str = _extract_json_text(decoded)

    print("=== Extracted JSON ===")
    print(json_str)
    print("=====================")

    try:
        parsed = json.loads(json_str)
    except json.JSONDecodeError as e:
        print("⚠️ JSON parsing failed:", e)
        return _fallback_result(chunk_start, chunk_end)

    if not isinstance(parsed, dict):
        # Valid JSON that is not an object (e.g. a list) cannot be used downstream.
        print("⚠️ JSON parsing failed:", "model reply is not a JSON object")
        return _fallback_result(chunk_start, chunk_end)

    return _normalise_result(parsed, chunk_start, chunk_end)

In [None]:

def starter_function(chunk, summary_so_far_list, all_results, model=None, tokenizer=None):
    """Run inference on one chunk and fold its result into the running state.

    Args:
        chunk: ``(start_sec, end_sec, audio_path)`` tuple.
        summary_so_far_list: Summary strings from earlier chunks; appended to in place.
        all_results: List of per-chunk result dicts; appended to in place.
        model: Model passed to ``inference_function``; loaded here when ``None``.
        tokenizer: Tokenizer passed to ``inference_function``; loaded here when ``None``.

    Returns:
        ``(all_results, summary_so_far_list)`` where the summary list has been
        passed through ``prune_summary`` and may therefore be a new list.
    """
    if model is None or tokenizer is None:
        print("⚠️ Model/Tokenizer not passed in — loading inside starter_function()")
        model, tokenizer = load_model_and_tokenizer()

    chunk_start_sec, chunk_end_sec, audio_path = chunk
    chunk_start_str = str(timedelta(seconds=chunk_start_sec))
    chunk_end_str = str(timedelta(seconds=chunk_end_sec))

    result_json = inference_function(model, tokenizer, audio_path, summary_so_far_list, chunk_start_str, chunk_end_str)

    all_results.append(result_json)

    # The result comes from model-generated JSON, so the summary field may be
    # missing, empty, or not a string; none of those should abort the run.
    summary_entry = result_json.get("updated_summary") if isinstance(result_json, dict) else None
    if isinstance(summary_entry, (list, tuple)):
        summary_entry = "; ".join(str(item) for item in summary_entry)
    if summary_entry:
        # Entries are later combined with str.join, which requires strings.
        summary_so_far_list.append(str(summary_entry))

    summary_so_far_list = prune_summary(summary_so_far_list)
    return all_results, summary_so_far_list


In [None]:

def merge_high_risk(results):
    """Collapse runs of consecutive HIGH-risk results into single events.

    Each event records the ``chunk_start`` of the first result in the run,
    the ``chunk_end`` of the last one, and the lists of risk reasons and tone
    labels seen along the way. Any result whose risk level is not exactly
    ``"HIGH"`` ends the current run.

    Returns:
        A list of event dicts with keys ``start``, ``end``, ``reason`` and ``tone``.
    """
    events = []
    open_event = None

    for entry in results:
        if entry["risk"]["level"] != "HIGH":
            # A non-HIGH result closes whatever run was in progress.
            if open_event:
                events.append(open_event)
                open_event = None
            continue

        if open_event:
            # Extend the run that is already open.
            open_event["end"] = entry["chunk_end"]
            open_event["reason"].append(entry["risk"]["reason"])
            open_event["tone"].append(entry["tone"]["label"])
            continue

        # First HIGH result after a gap starts a new run.
        open_event = {
            "start": entry["chunk_start"],
            "end": entry["chunk_end"],
            "reason": [entry["risk"]["reason"]],
            "tone": [entry["tone"]["label"]]
        }

    # A run that reaches the end of the input still has to be emitted.
    if open_event:
        events.append(open_event)
    return events


In [None]:

def save_results(results, merged_events=None):
    """Write per-chunk results and merged high-risk events to CSV files.

    Files are written to the current working directory as UTF-8:

    * ``results.csv`` - one row per result.
    * ``high_risk_events.csv`` - one row per merged event, written whenever
      ``merged_events`` is not ``None``. An empty list yields a header-only
      file so that events left over from an earlier run are not mistaken for
      the current one.

    Args:
        results: Result dicts with ``chunk_start``, ``chunk_end``,
            ``transcript``, ``tone`` (``label``, ``confidence``) and ``risk``
            (``level``, ``reason``).
        merged_events: Optional event dicts with ``start``, ``end``, ``tone``
            and ``reason``; the last two are lists.
    """
    # Explicit UTF-8: transcripts may contain characters that the platform's
    # default encoding cannot represent.
    with open("results.csv", "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["Start", "End", "Transcript", "Tone", "Confidence", "Risk", "Reason"])
        writer.writerows(
            [
                r["chunk_start"], r["chunk_end"],
                r["transcript"], r["tone"]["label"],
                r["tone"]["confidence"], r["risk"]["level"],
                r["risk"]["reason"]
            ]
            for r in results
        )

    if merged_events is None:
        return

    with open("high_risk_events.csv", "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["Start", "End", "Tones", "Reasons"])
        for e in merged_events:
            # dict.fromkeys de-duplicates while keeping first-seen order; a set
            # would make the column order vary from run to run.
            tones = dict.fromkeys(str(t) for t in e["tone"])
            reasons = dict.fromkeys(str(t) for t in e["reason"])
            writer.writerow([
                e["start"], e["end"],
                ", ".join(tones),
                "; ".join(reasons)
            ])


In [None]:

def load_model_and_tokenizer():
    """Load the Gemma 3N E4B instruct checkpoint in 4-bit mode via Unsloth.

    Returns:
        The ``(model, tokenizer)`` pair produced by ``FastModel.from_pretrained``.
    """
    from unsloth import FastModel

    print("Loading Gemma3N model and tokenizer...")
    # Loader settings gathered in one place; they are forwarded as keyword arguments.
    load_options = dict(
        model_name="unsloth/gemma-3n-E4B-it-unsloth-bnb-4bit",
        dtype=None,
        max_seq_length=2048,
        load_in_4bit=True,
        full_finetuning=False,
    )
    model, tokenizer = FastModel.from_pretrained(**load_options)
    print("Model and tokenizer loaded.")
    return model, tokenizer


In [None]:

from pydub import AudioSegment

def split_audio(file_path, chunk_len=15*1000, overlap=5*1000):
    """Split an audio file into overlapping MP3 chunks on disk.

    Each chunk is exported to the current working directory as
    ``chunk_<start>_<end>.mp3``, where both values are whole seconds
    zero-padded to four digits.

    Args:
        file_path: Path of the audio file to load with ``AudioSegment.from_file``.
        chunk_len: Chunk length in milliseconds. Must be positive.
        overlap: Milliseconds shared by consecutive chunks. Must be smaller
            than ``chunk_len`` so that every chunk starts later than the
            previous one.

    Returns:
        A list of ``(start_sec, end_sec, chunk_path)`` tuples in chronological order.

    Raises:
        ValueError: If ``chunk_len`` is not positive or ``overlap`` is not
            smaller than ``chunk_len``.
    """
    # Validate before touching the file: a zero or negative step would never
    # advance `start`, exporting the same chunk forever.
    if chunk_len <= 0:
        raise ValueError(f"chunk_len must be positive, got {chunk_len}")
    if overlap >= chunk_len:
        raise ValueError(
            f"overlap ({overlap}) must be smaller than chunk_len ({chunk_len})"
        )

    audio = AudioSegment.from_file(file_path)
    total_len = len(audio)
    step = chunk_len - overlap

    chunks = []
    start = 0
    while start < total_len:
        end = min(start + chunk_len, total_len)
        chunk_audio = audio[start:end]
        # File names and returned times use whole seconds (floor of the ms offsets).
        chunk_path = f"chunk_{start//1000:04d}_{end//1000:04d}.mp3"
        chunk_audio.export(chunk_path, format="mp3")
        chunks.append((start//1000, end//1000, chunk_path))
        start += step
    return chunks


In [None]:

import matplotlib.pyplot as plt

def plot_risk_timeline(results, merged_events):
    """Plot the risk level of every chunk over time and shade high-risk events.

    The figure is saved as ``risk_timeline.png`` in the current working
    directory, shown, and then closed.

    Args:
        results: Result dicts providing ``chunk_start`` and ``risk["level"]``.
        merged_events: Event dicts providing ``start`` and ``end``.

    Raises:
        ValueError: If a timestamp contains a part that is not a number.
    """
    def time_to_sec(t):
        # Fold "H:MM:SS", "MM:SS" or "SS" (fractions allowed) into seconds.
        # Timestamps may originate from model output, so the number of parts
        # is not assumed to be exactly three.
        seconds = 0.0
        for part in str(t).strip().split(":"):
            seconds = seconds * 60 + float(part)
        return seconds

    # Anything that is not HIGH or MEDIUM is drawn on the LOW line.
    level_to_y = {"HIGH": 3, "MEDIUM": 2}
    times = [time_to_sec(r["chunk_start"]) for r in results]
    risk_levels = [
        level_to_y.get(str(r["risk"]["level"]).strip().upper(), 1)
        for r in results
    ]

    fig, ax = plt.subplots(figsize=(12, 4))
    ax.plot(times, risk_levels, marker='o', label="Risk Level", linewidth=1)

    for event in merged_events:
        ax.axvspan(time_to_sec(event["start"]), time_to_sec(event["end"]), color="red", alpha=0.3)

    ax.set_yticks([1, 2, 3])
    ax.set_yticklabels(["LOW", "MEDIUM", "HIGH"])
    ax.set_xlabel("Time (seconds)")
    ax.set_ylabel("Risk Level")
    ax.set_title("Risk Timeline")
    ax.grid(True)
    fig.tight_layout()
    fig.savefig("risk_timeline.png")
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


In [None]:
import time

def simulator(audio_file, realtime_delay=False, model=None,tokenizer=None):
    """Run the full pipeline on one audio file: split, analyse, merge, save, plot.

    Args:
        audio_file: Path of the audio file to analyse.
        realtime_delay: Pause between chunks. ``False``/``0`` disables it,
            ``True`` waits one second, and any positive number waits that many
            seconds (for example ``15`` for real-time pacing).
        model: Preloaded model; loaded together with the tokenizer when either is ``None``.
        tokenizer: Preloaded tokenizer; loaded together with the model when either is ``None``.
    """
    chunks = split_audio(audio_file, chunk_len=15*1000, overlap=5*1000)
    # Compare against None rather than relying on truthiness, which depends on
    # whether the objects happen to define __bool__ or __len__.
    if model is None or tokenizer is None:
        model, tokenizer = load_model_and_tokenizer()
    summary_so_far_list = []
    all_results = []

    for chunk in chunks:
        all_results, summary_so_far_list = starter_function(
            chunk, summary_so_far_list, all_results, model, tokenizer
        )
        if realtime_delay:
            # float(True) == 1.0, so the boolean form keeps its one-second pause.
            time.sleep(float(realtime_delay))

    merged_events = merge_high_risk(all_results)
    save_results(all_results, merged_events)
    plot_risk_timeline(all_results, merged_events)

    print(f"Processed {len(all_results)} chunks and {len(merged_events)} high-risk events.")

In [None]:
model, tokenizer = load_model_and_tokenizer()

In [None]:
simulator("8501_full audio.mov",model=model,tokenizer=tokenizer)

In [None]:
simulator("Alitalia final.mp3",model=model,tokenizer=tokenizer)