In [None]:
%%capture
# Environment setup: installs Unsloth and its dependencies into the running
# kernel. The %%capture cell magic above silences this cell's output.
import os, re
# Colab is detected by looking for "COLAB_" anywhere in the concatenated
# environment variable names.
if "COLAB_" not in "".join(os.environ.keys()):
    %pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    # Take the leading run of digits and dots from the installed torch version
    # string; anything after it (such as a local build suffix) is ignored.
    import torch; v = re.match(r"[0-9\.]{3,}", str(torch.__version__)).group(0)
    # Choose the xformers pin from that torch version: 0.0.32.post2 for torch
    # 2.8.0, otherwise 0.0.29.post3.
    xformers = "xformers==" + ("0.0.32.post2" if v == "2.8.0" else "0.0.29.post3")
    # --no-deps: install exactly these packages without pulling in their
    # declared dependencies.
    %pip install --no-deps bitsandbytes accelerate {xformers} peft trl triton cut_cross_entropy unsloth_zoo
    %pip install sentencepiece protobuf "datasets>=3.4.1,<4.0.0" "huggingface_hub>=0.34.0" hf_transfer
    %pip install --no-deps unsloth
# The following installs run in both the Colab and non-Colab branches.
%pip install transformers==4.55.4
%pip install timm
# Raise torch._dynamo's recompile limit to 64.
import torch; torch._dynamo.config.recompile_limit = 64;

In [None]:
from unsloth import FastModel
import torch

# Load the Gemma 3n E4B instruction-tuned checkpoint through Unsloth.
# `model` and `tokenizer` are used by the inference helper defined later.
model, tokenizer = FastModel.from_pretrained(
    model_name="unsloth/gemma-3n-E4B-it",
    max_seq_length=1024,    # sequence length passed to the loader; raise for longer context
    dtype=None,             # None lets the loader auto-detect the dtype
    load_in_4bit=False,     # 4-bit quantization is disabled here; True would reduce memory
    full_finetuning=False,  # full finetuning is disabled here
    # token="hf_...",       # supply a token when using gated models
)

In [None]:
from transformers import TextStreamer
# Helper function for inference
def do_gemma_3n_inference(messages, max_new_tokens = 128):
    """
    Generate a reply for a chat conversation and return it as plain text.

    Relies on the notebook-level ``model`` and ``tokenizer`` objects.

    Args:
        messages: Conversation handed unchanged to
            ``tokenizer.apply_chat_template``.
        max_new_tokens: Upper bound on generated tokens (default: 128)

    Returns:
        The decoded text of the tokens generated after the prompt, with
        special tokens removed.
    """
    # Render the conversation through the chat template as "pt" tensors and
    # move the result to the GPU.
    template_options = dict(
        add_generation_prompt = True,
        tokenize = True,
        return_dict = True,
        return_tensors = "pt",
    )
    model_inputs = tokenizer.apply_chat_template(messages, **template_options).to("cuda")

    # The streamer echoes tokens into the notebook output as they are
    # produced (prompt omitted), so progress stays visible while generating.
    live_streamer = TextStreamer(tokenizer, skip_prompt = True)
    generated = model.generate(
        **model_inputs,
        max_new_tokens = max_new_tokens,
        temperature = 0.1,
        do_sample = True,
        streamer = live_streamer,
    )

    # The first sequence starts with the prompt tokens; slice them off so only
    # the newly generated answer is decoded.
    prompt_len = model_inputs.input_ids.shape[1]
    return tokenizer.decode(generated[0][prompt_len:], skip_special_tokens=True)

In [None]:
!pip install opencv-python

In [None]:
import os
import cv2
import csv
from PIL import Image
from pathlib import Path

def extract_frames_from_path(video_path, num_frames=8):
    """
    Extract up to ``num_frames`` evenly spaced frames from a video file.

    Frames are sampled at indices ``0, step, 2*step, ...`` where
    ``step = total_frames // frames_to_take``. When the video reports fewer
    frames than requested, each available frame is taken once instead of
    repeating the last one.

    Args:
        video_path: Path to video file (string or Path object)
        num_frames: Maximum number of frames to extract (default: 8).
            Values <= 0 produce an empty list.

    Returns:
        List of PIL Image objects in RGB order. The list is empty when the
        video cannot be opened or reports no frames, and may be shorter than
        ``num_frames`` when the video is short or a read fails.
    """
    # Reject non-positive requests up front; a zero would otherwise cause a
    # division by zero when computing the sampling step.
    if num_frames <= 0:
        print(f"Error: num_frames must be positive, got {num_frames}")
        return []

    video_path = str(video_path)  # Ensure it's a string
    cap = cv2.VideoCapture(video_path)
    frames = []

    # try/finally guarantees the capture handle is released on every exit
    # path, including exceptions raised while decoding or converting frames.
    try:
        if not cap.isOpened():
            print(f"Error: Could not open video file: {video_path}")
            return []

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Treat any non-positive count as "no frames" so a negative value
        # cannot lead to seeking to negative frame indices.
        if total_frames <= 0:
            print(f"Error: Video has no frames (reported count: {total_frames}): {video_path}")
            return []

        # Never request more frames than exist; this keeps every sampled
        # index unique and strictly below total_frames.
        frames_to_take = min(num_frames, total_frames)
        step = total_frames // frames_to_take

        for i in range(frames_to_take):
            cap.set(cv2.CAP_PROP_POS_FRAMES, i * step)
            ret, frame = cap.read()
            if not ret:
                # Stop at the first failed read and return what was collected.
                break

            # Convert BGR (OpenCV) to RGB (PIL) without rotation
            frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
    finally:
        cap.release()

    if len(frames) < num_frames:
        print(f"Warning: Only extracted {len(frames)}/{num_frames} frames from {video_path}")

    return frames


In [None]:
import os
import sys

# Set up paths relative to the notebook/script location.
# When __file__ is defined (e.g. the code runs as a script) its directory is
# used; in a notebook kernel __file__ is normally undefined, so fall back to
# the current working directory. The variable itself is passed to abspath --
# passing the string "__file__" would resolve a non-existent file in the CWD.
notebook_dir = os.path.dirname(os.path.abspath(__file__)) if "__file__" in globals() else os.getcwd()
project_root = os.path.dirname(notebook_dir)  # Go up one level from src/
video_folder = os.path.join(project_root, "dataset", "videos")
output_csv = os.path.join(project_root, "outputs", "gemma3n-4b_inference_results.csv")

# Create output directory if it doesn't exist
os.makedirs(os.path.dirname(output_csv), exist_ok=True)

print(f"Project root: {project_root}")
print(f"Looking for videos in: {video_folder}")
print(f"Will save results to: {output_csv}")

# Rows of [relative video path, model response] collected for the CSV report.
results = []

# Check if video folder exists
if not os.path.exists(video_folder):
    print(f"\n⚠️  Warning: Video folder not found: {video_folder}")
    print("Please ensure videos are placed in the dataset/videos/ directory")
    print("Or update the video_folder path above to point to your videos location")
else:
    # Get list of all video files recursively. os.walk yields entries in
    # filesystem order, so the list is sorted to keep the processing order
    # (and therefore the CSV row order) reproducible between runs.
    video_extensions = {'.mp4', '.mov', '.avi', '.mkv'}
    video_files = sorted(
        Path(root) / file
        for root, dirs, files in os.walk(video_folder)
        for file in files
        if Path(file).suffix.lower() in video_extensions
    )

    print(f"\nFound {len(video_files)} videos. Starting processing...\n")

    # The instruction text is identical for every video, so build it once.
    # NOTE(review): the last sentence reads as if an "and" is missing; it is
    # left untouched so model outputs stay comparable with earlier runs.
    prompt_text = "You are an assistive navigation system for a visually impaired user. Analyze the provided video from the user's forward perspective. Identify all the immediate, high-risk obstructions. State the obstruction's location using the 12-hour clock face. Process the provided video generate a single, actionable safety alert."

    for video_path in video_files:
        relative_path = video_path.relative_to(video_folder)
        print(f"Processing: {relative_path}...")

        # 1. Extract 8 snapshot frames that stand in for the video
        video_frames = extract_frames_from_path(str(video_path), num_frames=8)

        if not video_frames:
            print(f"  ⚠️  Failed to extract frames from {relative_path}")
            continue

        # 2. Build the message: all frames as images, followed by the prompt
        messages = [
            {
                "role": "user",
                "content": [
                    *[{"type": "image", "image": img} for img in video_frames],
                    {"type": "text", "text": prompt_text}
                ],
            },
        ]

        # 3. Inference. A failure on one video is reported and skipped so the
        #    remaining videos are still processed.
        try:
            response = do_gemma_3n_inference(messages, max_new_tokens=256)
            results.append([str(relative_path), response])
            print(f"  ✓ Completed\n")
        except Exception as e:
            print(f"  ✗ Error processing {relative_path}: {e}\n")

    # Save to CSV (only when at least one video produced a response)
    if results:
        with open(output_csv, "w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            writer.writerow(["video_name", "model_output"])
            writer.writerows(results)

        print(f"\n✅ Success! Saved {len(results)} results to {output_csv}")
    else:
        print("\n⚠️  No results to save. Please check if videos were processed successfully.")
