In [None]:
# ================================================================
# VIDEO ENVIRONMENT PREDICTION (PLACES365) EVERY ~3 SECONDS
# ---------------------------------------------------------------
# - Loads Places365 labels for readability
# - Loads your fine-tuned multi-head model from LOCAL paths
# - Samples frames from a video about every 3 seconds
# - Prints human-readable Places365 predictions over time
# ================================================================

import os
from pathlib import Path
import time

import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from PIL import Image


# ----------------------------
# CONFIG: EDIT THESE PATHS
# ----------------------------

# Where your fine-tuned multi-head weights (.pth) live on local disk.
# Example alternative name: /content/resnet_places365_mit_multihead_best.pth
FINETUNED_MODEL_PATH = Path("/content/resnet_places365_best.pth")

# Places365 category list; it is fetched automatically when absent (see STEP 1).
PLACES_LABELS_PATH = Path("categories_places365.txt")

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print("Using device:", DEVICE)


# ----------------------------
# STEP 1: Ensure Places365 label file
# ----------------------------

def download_places365_labels_if_needed(
    labels_path: Path,
    url: str = "https://raw.githubusercontent.com/CSAILVision/places365/master/categories_places365.txt",
):
    """
    Download the official Places365 categories file if it isn't already present.

    The download uses the standard library (no `wget` or shell needed). It is
    written to a temporary "<name>.part" file next to `labels_path` and moved
    into place only after it completes. A failed download therefore never
    leaves behind an empty or partial labels file that later runs would
    mistake for a valid one. An existing but empty file is treated as missing.

    Args:
        labels_path: Destination path for the categories file (str or Path).
        url: Source URL. Defaults to the official CSAILVision copy.

    Raises:
        FileNotFoundError: if the file is missing and the download fails.
    """
    import urllib.request

    labels_path = Path(labels_path)
    if labels_path.is_file() and labels_path.stat().st_size > 0:
        print("Found Places365 labels at:", labels_path)
        return

    print("Downloading Places365 label file...")
    labels_path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = labels_path.with_name(labels_path.name + ".part")
    try:
        urllib.request.urlretrieve(url, str(tmp_path))
        tmp_path.replace(labels_path)
    except (OSError, ValueError) as err:
        # URLError/HTTPError are OSError subclasses; ValueError covers bad URLs.
        raise FileNotFoundError("Failed to download categories_places365.txt") from err
    finally:
        # No-op after a successful replace; removes partial data otherwise.
        tmp_path.unlink(missing_ok=True)


def load_places365_labels(labels_path: Path, num_classes: int = 365):
    """
    Load Places365 class names in index order from a file whose lines look like:
      /a/abbey 0
      /k/kitchen 203
      /l/library/indoor <id>
    The first token is the category path and the last token is the integer ID.

    Names are built by dropping only the leading single-letter alphabetical
    bucket while keeping any sub-category. This gives 'abbey', 'kitchen' and
    'library/indoor'. Sub-categories matter: keeping just the last path
    component would collapse every '.../indoor' entry into the ambiguous name
    'indoor', so a name like 'library/indoor' could never be matched.
    Unprefixed names ('abbey') are kept as-is.

    Args:
        labels_path: Path to the categories file.
        num_classes: Size of the returned list (365 for Places365). IDs outside
            [0, num_classes) are ignored.

    Returns:
        A list of length `num_classes` where element i is the name of class i.
        IDs that never appear in the file get the placeholder name "class_<i>".
    """
    classes = [None] * num_classes

    with open(labels_path, "r", encoding="utf-8") as f:
        for raw_line in f:
            parts = raw_line.split()
            if len(parts) < 2:
                continue  # blank or malformed line

            try:
                cls_id = int(parts[-1])  # the last token is the ID
            except ValueError:
                continue  # line doesn't end with an int: skip it

            if not 0 <= cls_id < num_classes:
                continue

            category_full = parts[0]  # e.g. '/a/abbey', '/l/library/indoor', 'abbey'
            segments = [s for s in category_full.split("/") if s]
            # Drop the alphabetical bucket ('/a/abbey' -> 'abbey') but keep
            # sub-categories ('/l/library/indoor' -> 'library/indoor').
            if len(segments) > 1 and len(segments[0]) == 1:
                segments = segments[1:]
            classes[cls_id] = "/".join(segments) or category_full

    # Fill any IDs that were missing from the file with a fallback name.
    return [
        name if name is not None else f"class_{i}"
        for i, name in enumerate(classes)
    ]



# Make sure the label file is on disk (it is downloaded on first run), then
# parse it into a list whose position i holds the name of Places365 class i.
download_places365_labels_if_needed(PLACES_LABELS_PATH)
places_class_names = load_places365_labels(PLACES_LABELS_PATH)
print(f"Loaded {len(places_class_names)} Places365 class names.")


# ----------------------------
# STEP 2: Model definitions
# ----------------------------


class PlacesMITMultiHead(nn.Module):
    """
    Two-headed classifier on a shared ResNet50 trunk.

    - `backbone`: torchvision ResNet50 whose final `fc` is replaced by
      `nn.Identity`, so it returns the feature vector `fc` would have consumed.
    - `places_head`: linear layer producing 365 Places365 logits.
    - `mit_head`: linear layer producing `num_mit_classes` MIT indoor logits.

    Weights here are randomly initialized. The architecture exists only so
    the fine-tuned state_dict (loaded from FINETUNED_MODEL_PATH) has matching
    parameter names and shapes. No pretrained ResNet is passed in.
    """

    def __init__(self, num_mit_classes: int):
        super().__init__()
        # Stock ResNet50 with a 365-way fc; only its input width is kept.
        trunk = models.resnet50(num_classes=365)
        feat_dim = trunk.fc.in_features
        # Strip the classifier so the trunk emits features, not logits.
        trunk.fc = nn.Identity()
        self.backbone = trunk

        self.places_head = nn.Linear(feat_dim, 365)
        self.mit_head = nn.Linear(feat_dim, num_mit_classes)

    def forward(self, x):
        """Return `(places_logits, mit_logits)` computed from shared features."""
        shared = self.backbone(x)
        return self.places_head(shared), self.mit_head(shared)



# ----------------------------
# STEP 3: Image transform (no augmentation)
# ----------------------------

# Deterministic evaluation-time preprocessing (no augmentation), applied to
# each sampled frame after it has been converted to a PIL RGB image.
# NOTE(review): mean/std are the standard ImageNet statistics; presumably they
# match what the model was fine-tuned with. Confirm against the training code.
video_tf = transforms.Compose(
    [
        transforms.Resize(256),      # shorter side -> 256 px
        transforms.CenterCrop(224),  # central 224x224 patch
        transforms.ToTensor(),       # PIL image -> float CHW tensor in [0, 1]
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)


# ----------------------------
# STEP 4: Build model & load your weights
# ----------------------------

def _extract_state_dict(checkpoint):
    """Return the parameter state_dict held in a loaded checkpoint object.

    Accepts a raw state_dict, or a dict wrapping one under "state_dict" or
    "model_state_dict". If every key carries the "module." prefix added when
    an nn.DataParallel-wrapped model is saved, that prefix is stripped so the
    keys match a bare model. A raw state_dict is returned unchanged.
    """
    state = checkpoint
    if isinstance(state, dict):
        for wrapper_key in ("state_dict", "model_state_dict"):
            inner = state.get(wrapper_key)
            if isinstance(inner, dict):
                state = inner
                break
    if (
        isinstance(state, dict)
        and state
        and all(isinstance(k, str) and k.startswith("module.") for k in state)
    ):
        state = {k[len("module."):]: v for k, v in state.items()}
    return state


def build_finetuned_model(finetuned_path: Path, num_mit_classes: int):
    """
    Build the multi-head model and load your fine-tuned weights from a local .pth.
    NO original Places365 .pth.tar is needed at inference.

    Args:
        finetuned_path: Path (or str) to the fine-tuned checkpoint.
        num_mit_classes: Output size of the MIT head. It must match the head
            size the checkpoint was trained with.

    Returns:
        A PlacesMITMultiHead on DEVICE, in eval mode.

    Raises:
        FileNotFoundError: if `finetuned_path` does not exist. This is checked
            before the ResNet50 is built, so a bad path fails fast.
        RuntimeError: raised by `load_state_dict` when parameter names or
            shapes don't match, e.g. a wrong `num_mit_classes`.
    """
    finetuned_path = Path(finetuned_path)
    if not finetuned_path.exists():
        raise FileNotFoundError(f"Fine-tuned model not found at {finetuned_path}")

    print("Building multi-head model architecture...")
    model = PlacesMITMultiHead(num_mit_classes=num_mit_classes)

    print("Loading fine-tuned weights from:", finetuned_path)
    # NOTE(review): torch.load unpickles the file. Only load checkpoints you trust.
    checkpoint = torch.load(str(finetuned_path), map_location=DEVICE)
    model.load_state_dict(_extract_state_dict(checkpoint))

    model.to(DEVICE)
    model.eval()  # inference mode: BatchNorm uses running stats, dropout disabled
    print("Model ready for video inference.")
    return model



# ----------------------------
# STEP 5: Class lists + video classification function
# ----------------------------
# --------------------------------------------------
# First, build the list of Places365 indices we care about
# (Places365 predictions are restricted to this subset during video inference)
# --------------------------------------------------

# MIT indoor class names (update if your list is different).
# ORDER MATTERS: MIT_CLASS_NAMES[i] is used as the name of mit_head output i,
# so only rename entries. Never reorder, add or remove them without retraining.
MIT_CLASS_NAMES = [
    "bathroom",
    "bedroom",
    "classroom",
    "colloquium",
    "common_area",
    "computer_lab",
    "hallway",
    "kitchen",
    "library",
    "living_room",
    "office",
]

NUM_MIT_CLASSES = len(MIT_CLASS_NAMES)

# Places365 classes the Places head is allowed to predict (order does not matter).
# Each entry must exactly equal a name produced by load_places365_labels.
# Entries that match nothing are reported below instead of being silently ignored.
ALLOWED_PLACES_CLASSES = [
    "office",
    "corridor",
    "classroom",
    "kitchen",
    "bathroom",
    "library",
    "living_room",
    "dining_room",
    "computer_room",
    "cafeteria",
    "lobby",
    "auditorium",
    "banquet_hall",
    "library/indoor",
    "bedroom",
    "church/indoor",
    "conference_room",
    "dining_hall",
    "garage/indoor",
]

# Ascending indices (into places_class_names) of the allowed classes.
allowed_places_indices = [
    i for i, name in enumerate(places_class_names) if name in ALLOWED_PLACES_CLASSES
]

print("Allowed Places365 classes:")
for idx in allowed_places_indices:
    print(f"  {idx:3d} -> {places_class_names[idx]}")

# Report every entry that matched nothing (typos, non-existent names), not
# only the case where the whole list failed to match.
_matched_places_names = {places_class_names[i] for i in allowed_places_indices}
_unmatched_places_names = [
    n for n in ALLOWED_PLACES_CLASSES if n not in _matched_places_names
]
if _unmatched_places_names:
    print("⚠️ Warning: these ALLOWED_PLACES_CLASSES entries matched no Places365 class:")
    for missing in _unmatched_places_names:
        print(f"  {missing}")

if not allowed_places_indices:
    print("⚠️ Warning: no allowed classes matched! Check spelling.")







def _predict_both_heads(frame_bgr, model: nn.Module):
    """Run one OpenCV BGR frame through `model`; return `(places_probs, mit_probs)`.

    Both results are 1-D softmax probability tensors for that single frame.
    """
    # OpenCV decodes frames as BGR; the PIL-based transform expects RGB.
    frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
    x = video_tf(Image.fromarray(frame_rgb)).unsqueeze(0).to(DEVICE)  # batch of one
    with torch.no_grad():
        places_logits, mit_logits = model(x)
        return F.softmax(places_logits, dim=1)[0], F.softmax(mit_logits, dim=1)[0]


def _top_places(places_probs, k: int):
    """Return `(global class indices, probs)` of the top-`k` Places365 classes.

    When `allowed_places_indices` is non-empty, only those classes compete.
    Their probabilities come from the softmax over ALL classes and are not
    renormalized over the subset.
    """
    if allowed_places_indices:
        subset_probs = places_probs[allowed_places_indices]
        top_probs, local_idx = torch.topk(
            subset_probs, k=min(k, len(allowed_places_indices))
        )
        global_idx = [allowed_places_indices[i] for i in local_idx.tolist()]
    else:
        top_probs, top_idx = torch.topk(places_probs, k=min(k, places_probs.numel()))
        global_idx = top_idx.tolist()
    return global_idx, top_probs.tolist()


def _top_mit(mit_probs, k: int):
    """Return `(class names, probs)` of the top-`k` MIT-head predictions."""
    top_probs, top_idx = torch.topk(mit_probs, k=min(k, len(MIT_CLASS_NAMES)))
    return [MIT_CLASS_NAMES[i] for i in top_idx.tolist()], top_probs.tolist()


def classify_video_both_heads_fast(
    video_path: Path,
    model: nn.Module,
    step_seconds: float = 3.0,
    topk_places: int = 3,
    topk_mit: int = 3,
    mit_priority_threshold: float = 0.6,
    show_timestamps: bool = True,
    places_override_classes: tuple = ("cafeteria",),
):
    """
    Classify a video using BOTH the Places365 head and the MIT indoor head.

    - Seeks to one frame every `step_seconds` (frame step = int(fps * step_seconds),
      at least 1).
    - Runs both heads on each sampled frame.
    - Chooses an "overall best" label per sample:
        * If the Places365 top-1 is in `places_override_classes` -> Places365 label
        * elif MIT top-1 prob >= `mit_priority_threshold`         -> MIT label
        * else                                                     -> Places365 label
    - Prints the overall best label, then the top MIT and top Places predictions.

    Uses module-level `video_tf`, `DEVICE`, `MIT_CLASS_NAMES`,
    `places_class_names` and `allowed_places_indices`. When the last one is
    non-empty, Places365 predictions are restricted to that subset.

    Args:
        video_path: Path (or str) of the video file.
        model: Network returning `(places_logits, mit_logits)`; expected in eval mode.
        step_seconds: Approximate spacing between sampled frames.
        topk_places / topk_mit: Number of predictions to print per head (>= 1).
        mit_priority_threshold: Minimum MIT top-1 probability to prefer the MIT label.
        show_timestamps: Print the sample number and approximate time.
        places_override_classes: Places365 names that always win when top-1.

    Raises:
        ValueError: if `topk_places` or `topk_mit` is < 1.
        FileNotFoundError: if `video_path` does not exist.
        RuntimeError: if OpenCV cannot open the video.
    """
    if topk_places < 1 or topk_mit < 1:
        raise ValueError("topk_places and topk_mit must both be >= 1")

    video_path = Path(video_path)
    if not video_path.exists():
        raise FileNotFoundError(f"Video not found at {video_path}")

    cap = cv2.VideoCapture(str(video_path))
    try:  # release the capture even if decoding or inference raises
        if not cap.isOpened():
            raise RuntimeError(f"Failed to open video: {video_path}")

        fps = cap.get(cv2.CAP_PROP_FPS)
        if fps <= 0:
            print("Warning: could not read FPS, defaulting to 30.")
            fps = 30.0

        # int() truncates (59.94 fps * 2 s -> 119 frames); never step by < 1 frame.
        frame_interval = max(1, int(fps * step_seconds))

        # NOTE(review): the frame count may be unavailable (0/negative) for some
        # streams/backends. In that case keep sampling until a read fails,
        # instead of silently processing nothing.
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        print(f"\nVideo: {video_path}")
        print(
            f"FPS: {fps:.2f}, total_frames: {total_frames}, "
            f"sampling every {frame_interval} frames (~{step_seconds}s)\n"
        )

        sampled_idx = 0
        frame_idx = 0
        while total_frames <= 0 or frame_idx < total_frames:
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
            ret, frame_bgr = cap.read()
            if not ret:
                break
            sampled_idx += 1

            places_probs, mit_probs = _predict_both_heads(frame_bgr, model)
            p_global_idx, p_topk_probs = _top_places(places_probs, topk_places)
            m_topk_names, m_topk_probs = _top_mit(mit_probs, topk_mit)

            places_best_name = places_class_names[p_global_idx[0]]
            places_best_prob = p_topk_probs[0]
            mit_best_name = m_topk_names[0]
            mit_best_prob = m_topk_probs[0]

            # Override classes always trust Places365; otherwise prefer a
            # confident MIT prediction and fall back to Places365.
            if places_best_name in places_override_classes:
                overall_source, overall_name, overall_prob = (
                    "Places365", places_best_name, places_best_prob)
            elif mit_best_prob >= mit_priority_threshold:
                overall_source, overall_name, overall_prob = (
                    "MIT", mit_best_name, mit_best_prob)
            else:
                overall_source, overall_name, overall_prob = (
                    "Places365", places_best_name, places_best_prob)

            t_sec = frame_idx / fps  # fps is guaranteed > 0 above

            print("---------------------------------------------------")
            if show_timestamps:
                print(f"Sample #{sampled_idx} at ~{t_sec:.1f}s:")

            print(f"OVERALL BEST: [{overall_source}] {overall_name}  prob={overall_prob:.3f}")

            print("\nMIT head (fine-tuned) top predictions:")
            for name, p in zip(m_topk_names, m_topk_probs):
                print(f"  {name:20s} prob={p:.3f}")

            print("\nPlaces365 head (original) top predictions:")
            for cls_idx, p in zip(p_global_idx, p_topk_probs):
                cls_idx = int(cls_idx)
                cls_name = (
                    places_class_names[cls_idx]
                    if 0 <= cls_idx < len(places_class_names)
                    else f"idx_{cls_idx}"
                )
                print(f"  {cls_name:30s} prob={p:.3f}")

            frame_idx += frame_interval
    finally:
        cap.release()

    print("\nDone processing video.")








Using device: cuda
Found Places365 labels at: categories_places365.txt
Loaded 365 Places365 class names.
Allowed Places365 classes:
   27 -> auditorium
   38 -> banquet_hall
   45 -> bathroom
   52 -> bedroom
   75 -> cafeteria
   92 -> classroom
  100 -> computer_room
  102 -> conference_room
  106 -> corridor
  120 -> dining_hall
  121 -> dining_room
  203 -> kitchen
  215 -> living_room
  217 -> lobby
  244 -> office


In [None]:
# ----------------------------
# STEP 6: EXAMPLE USAGE
# ----------------------------

# 1) Number of MIT classes the model was trained with. It is derived from
#    MIT_CLASS_NAMES so the head size and the label list can never drift apart.
NUM_MIT_CLASSES = len(MIT_CLASS_NAMES)

# 2) Build the model once
video_model = build_finetuned_model(FINETUNED_MODEL_PATH, num_mit_classes=NUM_MIT_CLASSES)

# 3) Run on a local video file (edit this path!)
example_video_path = Path("/content/Auditorium.MOV")  # e.g. /content/example_video.mp4

classify_video_both_heads_fast(
    example_video_path,
    video_model,
    step_seconds=2.0,            # sample one frame every X seconds
    topk_places=2,               # also display lower-confidence Places365 predictions
    topk_mit=2,                  # also display lower-confidence MIT predictions
    mit_priority_threshold=0.6,  # MIT top-1 prob needed to prefer the MIT label over Places365
)

Building multi-head model architecture...
Loading fine-tuned weights from: /content/resnet_places365_best.pth
Model ready for video inference.

Video: /content/Auditorium.MOV
FPS: 59.94, total_frames: 321, sampling every 119 frames (~2.0s)

---------------------------------------------------
Sample #1 at ~0.0s:
OVERALL BEST: [MIT] colloquium  prob=0.998

MIT head (fine-tuned) top predictions:
  colloquium           prob=0.998
  library              prob=0.001

Places365 head (original) top predictions:
  auditorium                     prob=0.050
  conference_room                prob=0.001
---------------------------------------------------
Sample #2 at ~2.0s:
OVERALL BEST: [MIT] colloquium  prob=0.998

MIT head (fine-tuned) top predictions:
  colloquium           prob=0.998
  classroom            prob=0.001

Places365 head (original) top predictions:
  auditorium                     prob=0.052
  conference_room                prob=0.001
-------------------------------------------------