In [6]:
# Install required packages
!pip install torch torchvision scipy numpy pillow matplotlib opencv-python torchaudio
!pip install davis2017-evaluation

# # Download and extract DAVIS 2017 dataset
!wget -O DAVIS-2017-trainval-480p.zip https://data.vision.ee.ethz.ch/csergi/share/davis/DAVIS-2017-trainval-480p.zip
!unzip -q DAVIS-2017-trainval-480p.zip -d /content/DAVIS

import cv2
import numpy as np
from PIL import Image
import torch
import torchvision
import os
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.ops import MultiScaleRoIAlign
import torch.nn as nn
import torch.optim as optim
from torchvision.transforms import functional as F
from IPython.display import Image as IPImage, display
from google.colab import files
import shutil
import matplotlib.pyplot as plt
import random
from sklearn.metrics import accuracy_score, f1_score
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
import warnings
import time


# Define the DAVIS Dataset class
class DAVISDataset(Dataset):
    """DAVIS 2017 frames exposed as object-detection samples.

    Every frame that has an annotation PNG becomes one sample. Each non-zero
    object id in the annotation yields one ``(x1, y1, x2, y2)`` box, and every
    box in a frame gets the class of its sequence from ``class_map``
    (sequences missing from the map fall back to label 1).

    Args:
        davis_root: DAVIS root directory containing ``JPEGImages/480p``,
            ``Annotations/480p`` and ``ImageSets/2017``.
        subset: Split name; sequences are read from
            ``ImageSets/2017/<subset>.txt``.
        transform: Optional callable taking and returning ``(image, target)``.
    """

    def __init__(self, davis_root, subset='val', transform=None):
        self.davis_root = davis_root
        self.subset = subset
        self.transform = transform
        self.img_dir = os.path.join(davis_root, 'JPEGImages', '480p')
        self.anno_dir = os.path.join(davis_root, 'Annotations', '480p')
        with open(os.path.join(davis_root, 'ImageSets', '2017', f'{subset}.txt'), 'r') as f:
            self.sequences = f.read().splitlines()
        # Sequence name -> class id. Ids start at 1; no sequence uses 0.
        self.class_map = {
            'bear': 1, 'bmx-bumps': 2, 'boat': 3, 'boxing-fisheye': 4, 'breakdance-flare': 5,
            'bus': 6, 'car-turn': 7, 'cat-girl': 8, 'classic-car': 9, 'color-run': 10,
            'crossing': 11, 'dance-jump': 12, 'dancing': 13, 'disc-jockey': 14, 'dog-agility': 15,
            'dog-gooses': 16, 'dogs-scale': 17, 'drift-turn': 18, 'drone': 19, 'elephant': 20,
            'flamingo': 21, 'hike': 22, 'hockey': 23, 'horsejump-low': 24, 'kid-football': 25,
            'kite-walk': 26, 'koala': 27, 'lady-running': 28, 'lindy-hop': 29, 'longboard': 30,
            'lucia': 31, 'mallard-fly': 32, 'mallard-water': 33, 'miami-surf': 34, 'motocross-bumps': 35,
            'motorbike': 36, 'night-race': 37, 'paragliding': 38, 'planes-water': 39, 'rallye': 40,
            'rhino': 41, 'rollerblade': 42, 'schoolgirls': 43, 'scooter-board': 44, 'scooter-gray': 45,
            'sheep': 46, 'skate-park': 47, 'snowboard': 48, 'soccerball': 49, 'stroller': 50,
            'stunt': 51, 'surf': 52, 'swing': 53, 'tennis': 54, 'tractor-sand': 55,
            'train': 56, 'tuk-tuk': 57, 'upside-down': 58, 'varanus-cage': 59, 'walking': 60,
            'bike-packing': 61, 'blackswan': 62, 'bmx-trees': 63, 'breakdance': 64, 'camel': 65,
            'car-roundabout': 66, 'car-shadow': 67, 'cows': 68, 'dance-twirl': 69, 'dog': 70,
            'dogs-jump': 71, 'drift-chicane': 72, 'drift-straight': 73, 'goat': 74, 'gold-fish': 75,
            'horsejump-high': 76, 'india': 77, 'judo': 78, 'kite-surf': 79, 'lab-coat': 80,
            'libby': 81, 'loading': 82, 'mbike-trick': 83, 'motocross-jump': 84, 'paragliding-launch': 85,
            'parkour': 86, 'pigs': 87, 'scooter-black': 88, 'shooting': 89, 'soapbox': 90
        }
        self.class_names = {v: k for k, v in self.class_map.items()}
        # (image path, annotation path, sequence name) for every annotated frame.
        self.samples = []
        for seq in self.sequences:
            img_seq_dir = os.path.join(self.img_dir, seq)
            anno_seq_dir = os.path.join(self.anno_dir, seq)
            for img_file in sorted(os.listdir(img_seq_dir)):
                if not img_file.endswith('.jpg'):
                    continue
                frame_num = img_file.split('.')[0]
                anno_path = os.path.join(anno_seq_dir, f"{frame_num}.png")
                if os.path.exists(anno_path):
                    self.samples.append((os.path.join(img_seq_dir, img_file), anno_path, seq))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for sample ``idx``.

        Before ``transform`` runs, ``image`` is an RGB PIL image. ``target``
        holds ``boxes`` (float32, shape ``[N, 4]``), ``labels`` (int64,
        shape ``[N]``), ``image_id``, ``area`` and ``iscrowd``; ``N`` may be 0.
        """
        img_path, mask_path, seq_name = self.samples[idx]
        image = Image.open(img_path).convert('RGB')
        with Image.open(mask_path) as mask:
            # For palette ('P') PNGs the pixel value is the object index.
            # Converting to 'L' would instead map indices through the
            # luminance of their palette colours, which can merge objects.
            mask_np = np.array(mask if mask.mode == 'P' else mask.convert('L'))
        obj_ids = np.unique(mask_np)
        # Drop background explicitly instead of assuming it is the first
        # unique value: a frame fully covered by objects has no 0 pixels.
        obj_ids = obj_ids[obj_ids != 0]

        class_label = self.class_map.get(seq_name, 1)
        boxes, labels = [], []
        for obj_id in obj_ids:
            ys, xs = np.nonzero(mask_np == obj_id)
            xmin, xmax = xs.min(), xs.max()
            ymin, ymax = ys.min(), ys.max()
            # Skip boxes with zero width or height.
            if xmax > xmin and ymax > ymin:
                boxes.append([xmin, ymin, xmax, ymax])
                labels.append(class_label)

        if boxes:
            boxes = torch.as_tensor(boxes, dtype=torch.float32)
            labels = torch.as_tensor(labels, dtype=torch.int64)
        else:
            boxes = torch.zeros((0, 4), dtype=torch.float32)
            labels = torch.zeros(0, dtype=torch.int64)
        target = {
            'boxes': boxes,
            'labels': labels,
            'image_id': torch.tensor([idx]),
            'area': (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]),
            'iscrowd': torch.zeros((len(boxes),), dtype=torch.int64),
        }
        if self.transform:
            image, target = self.transform(image, target)
        return image, target

# Define transformation classes
class Compose:
    """Chain paired ``(image, target)`` transforms, applied in list order."""

    def __init__(self, transforms):
        # Ordered callables, each taking and returning (image, target).
        self.transforms = transforms

    def __call__(self, image, target):
        current_image, current_target = image, target
        for step in self.transforms:
            current_image, current_target = step(current_image, current_target)
        return current_image, current_target

class ToTensor:
    """Convert the image with ``F.to_tensor``; the target passes through untouched."""

    def __call__(self, image, target):
        return F.to_tensor(image), target

class Normalize:
    """Normalize the image channel-wise with ``F.normalize``; target unchanged."""

    def __init__(self, mean, std):
        # Per-channel statistics forwarded to F.normalize.
        self.mean, self.std = mean, std

    def __call__(self, image, target):
        return F.normalize(image, mean=self.mean, std=self.std), target

class RandomHorizontalFlip:
    """Mirror an image tensor left-right with probability ``prob``, updating boxes.

    Boxes are ``(x1, y1, x2, y2)``; after the flip ``x1' = W - x2`` and
    ``x2' = W - x1``, so each box stays well-ordered. The ``boxes`` tensor in
    ``target`` is modified in place.
    """

    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, image, target):
        if not random.random() < self.prob:
            return image, target
        _, width = image.shape[-2:]
        flipped = F.hflip(image)
        boxes = target["boxes"]
        # The right-hand side is a copy (advanced indexing), so the swap is safe.
        boxes[:, [0, 2]] = width - boxes[:, [2, 0]]
        target["boxes"] = boxes
        return flipped, target

# Define the Faster R-CNN model
def create_fasterrcnn_model(num_classes, pretrained_backbone=True):
    """Build a single-feature-map Faster R-CNN on a ResNet-50 trunk.

    Args:
        num_classes: Number of output classes passed to ``FasterRCNN``.
        pretrained_backbone: Load torchvision's default ResNet-50 weights when
            True; otherwise the trunk is randomly initialised.

    Returns:
        An untrained ``FasterRCNN`` with 15 anchors per location (5 sizes x 3
        aspect ratios) and 7x7 RoIAlign pooling; images are resized with
        ``min_size=800`` / ``max_size=1333``.
    """
    import torchvision

    resnet = torchvision.models.resnet50(weights='DEFAULT' if pretrained_backbone else None)
    # Drop the last two children (average pool and fc classifier), keeping the conv trunk.
    backbone = nn.Sequential(*list(resnet.children())[:-2])
    # FasterRCNN reads `out_channels` to size its heads; 2048 is the trunk's output width.
    backbone.out_channels = 2048

    anchor_sizes = ((32, 64, 128, 256, 512),)
    anchor_ratios = ((0.5, 1.0, 2.0),)
    return FasterRCNN(
        backbone=backbone,
        num_classes=num_classes,
        rpn_anchor_generator=AnchorGenerator(sizes=anchor_sizes, aspect_ratios=anchor_ratios),
        # featmap_names must match the key of the single backbone feature map.
        box_roi_pool=MultiScaleRoIAlign(featmap_names=['0'], output_size=7, sampling_ratio=2),
        min_size=800,
        max_size=1333,
    )

def compute_iou(box1, box2):
    """Return the intersection-over-union of two ``(x1, y1, x2, y2)`` tensor boxes.

    Widths/heights are clamped at 0, so degenerate boxes have zero area.
    Returns the integer 0 when the union is zero.
    """
    def _area(box):
        return max(0, box[2] - box[0]) * max(0, box[3] - box[1])

    left = torch.max(box1[0], box2[0])
    top = torch.max(box1[1], box2[1])
    right = torch.min(box1[2], box2[2])
    bottom = torch.min(box1[3], box2[3])

    overlap = max(0, right - left) * max(0, bottom - top)
    union = _area(box1) + _area(box2) - overlap
    if union == 0:
        return 0
    return overlap / union

# Enhanced training function with F1, accuracy, and loss tracking
def train_model(model, data_loader, optimizer, num_epochs, device, *,
                compute_metrics=True, iou_threshold=0.5, return_history=False):
    """Train a torchvision-style detector, printing loss/accuracy/F1 per epoch.

    In train mode ``model(images, targets)`` must return a dict of loss
    tensors; in eval mode ``model(images)`` must return one dict per image with
    ``boxes`` and ``labels``. When ``compute_metrics`` is True, an extra no-grad
    eval pass runs after every optimizer step. Each target box is paired with
    its highest-IoU prediction, kept only if that IoU exceeds
    ``iou_threshold``, and accuracy and weighted F1 are computed over the
    paired labels.

    Args:
        model: Detector to train (moved to ``device``).
        data_loader: Iterable of ``(images, targets)`` batches with ``len()``.
        optimizer: Optimizer over the model's parameters.
        num_epochs: Number of passes over ``data_loader``.
        device: Device to run on.
        compute_metrics: Set False to skip the extra per-batch inference pass
            (saves time and GPU memory; accuracy/F1 are then reported as 0.0).
        iou_threshold: Minimum IoU for a prediction to count as a match.
        return_history: When True, also return the per-epoch histories.

    Returns:
        ``model`` by default. With ``return_history=True``, returns
        ``(model, loss_history, acc_history, f1_history)``.
    """
    def match_labels(output, target):
        """Return (pred_labels, target_labels) for targets whose best-IoU prediction passes the threshold."""
        pred_boxes = output['boxes'].detach().cpu().float()
        tgt_boxes = target['boxes'].cpu().float()
        # Pairwise IoU matrix of shape (num_targets, num_predictions).
        top_left = torch.maximum(tgt_boxes[:, None, :2], pred_boxes[None, :, :2])
        bottom_right = torch.minimum(tgt_boxes[:, None, 2:], pred_boxes[None, :, 2:])
        inter = (bottom_right - top_left).clamp(min=0).prod(dim=2)
        tgt_area = (tgt_boxes[:, 2:] - tgt_boxes[:, :2]).clamp(min=0).prod(dim=1)
        pred_area = (pred_boxes[:, 2:] - pred_boxes[:, :2]).clamp(min=0).prod(dim=1)
        union = tgt_area[:, None] + pred_area[None, :] - inter
        iou = torch.where(union > 0, inter / union, torch.zeros_like(inter))
        best_iou, best_idx = iou.max(dim=1)
        keep = best_iou > iou_threshold
        pred_labels = output['labels'].cpu()[best_idx[keep]]
        tgt_labels = target['labels'].cpu()[keep]
        return pred_labels.tolist(), tgt_labels.tolist()

    model.to(device)
    loss_history, acc_history, f1_history = [], [], []

    for epoch in range(num_epochs):
        model.train()
        epoch_running_loss = 0.0
        all_preds, all_targets = [], []

        for images, targets in data_loader:
            images = [img.to(device) for img in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            loss_dict = model(images, targets)
            total_loss = sum(loss for loss in loss_dict.values())

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            epoch_running_loss += total_loss.item()

            if not compute_metrics:
                continue

            # Inference pass for the label metrics; restore train mode afterwards.
            model.eval()
            with torch.no_grad():
                outputs = model(images)
            model.train()

            for output, target in zip(outputs, targets):
                if len(output['boxes']) == 0 or len(target['boxes']) == 0:
                    warnings.warn("Empty prediction or target, skipping this sample.")
                    continue
                matched_preds, matched_targets = match_labels(output, target)
                all_preds.extend(matched_preds)
                all_targets.extend(matched_targets)

        num_batches = len(data_loader)
        epoch_loss = epoch_running_loss / num_batches if num_batches else 0.0
        epoch_acc = accuracy_score(all_targets, all_preds) if all_preds else 0.0
        epoch_f1 = f1_score(all_targets, all_preds, average='weighted') if all_preds else 0.0

        loss_history.append(epoch_loss)
        acc_history.append(epoch_acc)
        f1_history.append(epoch_f1)

        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Acc: {epoch_acc:.4f}, F1: {epoch_f1:.4f}")

    if return_history:
        return model, loss_history, acc_history, f1_history
    return model

# Add this validation function after the train_model function
def validate_model(model, data_loader, device, iou_threshold=0.5):
    """Evaluate a torchvision-style detector on ``data_loader``.

    Loss: torchvision detection models only return their loss dict in train
    mode (eval mode returns detections). So, under ``torch.no_grad()``, the
    model is briefly put in train mode with every BatchNorm layer kept in eval
    mode, which stops validation batches from updating running statistics.

    Metrics: each target box is paired with its highest-IoU prediction. The
    pair is kept only if that IoU exceeds ``iou_threshold``. Accuracy and
    weighted F1 are computed over the paired class labels.

    Args:
        model: Detector to evaluate (moved to ``device``; left in eval mode).
        data_loader: Iterable of ``(images, targets)`` batches with ``len()``.
        device: Device to run on.
        iou_threshold: Minimum IoU for a prediction to count as a match.

    Returns:
        ``(val_loss, val_acc, val_f1)``. The loss is the mean per batch;
        accuracy and F1 are 0.0 when nothing matched.
    """
    def match_labels(output, target):
        """Return (pred_labels, target_labels) for targets whose best-IoU prediction passes the threshold."""
        pred_boxes = output['boxes'].detach().cpu().float()
        tgt_boxes = target['boxes'].cpu().float()
        # Pairwise IoU matrix of shape (num_targets, num_predictions).
        top_left = torch.maximum(tgt_boxes[:, None, :2], pred_boxes[None, :, :2])
        bottom_right = torch.minimum(tgt_boxes[:, None, 2:], pred_boxes[None, :, 2:])
        inter = (bottom_right - top_left).clamp(min=0).prod(dim=2)
        tgt_area = (tgt_boxes[:, 2:] - tgt_boxes[:, :2]).clamp(min=0).prod(dim=1)
        pred_area = (pred_boxes[:, 2:] - pred_boxes[:, :2]).clamp(min=0).prod(dim=1)
        union = tgt_area[:, None] + pred_area[None, :] - inter
        iou = torch.where(union > 0, inter / union, torch.zeros_like(inter))
        best_iou, best_idx = iou.max(dim=1)
        keep = best_iou > iou_threshold
        pred_labels = output['labels'].cpu()[best_idx[keep]]
        tgt_labels = target['labels'].cpu()[keep]
        return pred_labels.tolist(), tgt_labels.tolist()

    model.to(device)
    model.eval()
    all_preds, all_targets = [], []
    val_loss = 0.0

    with torch.no_grad():
        for images, targets in data_loader:
            images = [img.to(device) for img in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            # Loss pass: train mode so the model returns losses, but BatchNorm
            # stays in eval mode so its running statistics are untouched.
            model.train()
            for module in model.modules():
                if isinstance(module, nn.modules.batchnorm._BatchNorm):
                    module.eval()
            loss_dict = model(images, targets)
            val_loss += sum(loss_dict.values()).item()

            # Prediction pass for the label metrics.
            model.eval()
            outputs = model(images)

            for output, target in zip(outputs, targets):
                if len(output['boxes']) == 0 or len(target['boxes']) == 0:
                    warnings.warn("Empty prediction or target in validation.")
                    continue
                # Extend once per image (not once per target box).
                matched_preds, matched_targets = match_labels(output, target)
                all_preds.extend(matched_preds)
                all_targets.extend(matched_targets)

    model.eval()
    num_batches = len(data_loader)
    val_loss = val_loss / num_batches if num_batches else 0.0

    val_acc = accuracy_score(all_targets, all_preds) if all_preds else 0.0
    val_f1 = f1_score(all_targets, all_preds, average='weighted') if all_preds else 0.0

    print(f"Validation Loss: {val_loss:.4f}, Accuracy: {val_acc:.4f}, F1 Score: {val_f1:.4f}")

    return val_loss, val_acc, val_f1


def plot_validation_metrics(val_losses, val_accuracies, val_f1_scores):
    """Plot validation loss, accuracy and F1 side by side, one point per epoch.

    All three sequences are plotted against epochs ``1..len(val_losses)``.
    """
    epochs = list(range(1, len(val_losses) + 1))
    panels = [
        (val_losses, 'r-o', "Validation Loss", "Loss"),
        (val_accuracies, 'g-o', "Validation Accuracy", "Accuracy"),
        (val_f1_scores, 'b-o', "Validation F1 Score", "F1 Score"),
    ]

    plt.figure(figsize=(12, 4))
    for position, (values, fmt, title, ylabel) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.plot(epochs, values, fmt)
        plt.title(title)
        plt.xlabel("Epoch")
        plt.ylabel(ylabel)

    plt.tight_layout()
    plt.show()

def plot_training_curves(loss_history, acc_history, f1_history):
    """Overlay training loss, accuracy and F1 per epoch on a single axis."""
    epochs = range(1, len(loss_history) + 1)
    series = (
        (loss_history, 'r-', 'Loss'),
        (acc_history, 'b--', 'Accuracy'),
        (f1_history, 'g-.', 'F1 Score'),
    )

    plt.figure(figsize=(10, 6))
    for values, fmt, name in series:
        plt.plot(epochs, values, fmt, label=name)
    plt.xlabel("Epochs")
    plt.ylabel("Metric")
    plt.title("Training Metrics Over Epochs")
    plt.legend()
    plt.grid(True)
    plt.show()

# Function to extract frames
def extract_frames(video_path, output_dir):
    """Decode every frame of ``video_path`` and save it as a JPEG in ``output_dir``.

    Frames are written in decode order as ``frame_000001.jpg``,
    ``frame_000002.jpg``, ... and ``output_dir`` is created if needed.

    Returns:
        ``(frame_paths, width, height, fps)`` on success, where width, height
        and fps are read from the capture properties (fps truncated to an
        int). Returns ``None`` if the video cannot be opened or a frame cannot
        be read or written.
    """
    os.makedirs(output_dir, exist_ok=True)
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            print(f"Error: Could not open video {video_path}")
            return None
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = int(cap.get(cv2.CAP_PROP_FPS))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_paths = []
        frame_count = 0
        try:
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                frame_count += 1
                frame_path = os.path.join(output_dir, f"frame_{frame_count:06d}.jpg")
                # cv2.imwrite reports failure through its return value.
                if not cv2.imwrite(frame_path, frame):
                    raise IOError(f"could not write {frame_path}")
                frame_paths.append(frame_path)
                print(f"Extracting frame {frame_count}/{total_frames}", end='\r')
        except Exception as e:
            print(f"Error during frame extraction: {str(e)}")
            return None
    finally:
        # Release the capture on every exit path, including interrupts.
        cap.release()
    print(f"\nExtracted {frame_count} frames to {output_dir}")
    return frame_paths, width, height, fps

# Function to visualize detections
def visualize_detections(image, prediction, class_names, threshold=0.3,
                         mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Draw detections scoring at least ``threshold`` on a de-normalized copy of ``image``.

    Args:
        image: Channels-first image tensor normalized with ``mean``/``std``
            (it is permuted to HWC before drawing).
        prediction: Dict with ``boxes`` as ``(x1, y1, x2, y2)`` rows,
            ``labels`` and ``scores``.
        class_names: Mapping from label id to display name; unknown ids are
            shown as ``'Unknown'``.
        threshold: Minimum score for a detection to be drawn.
        mean, std: Per-channel statistics used to undo the normalization.

    Returns:
        An HWC uint8 array with a box and a ``name: score`` caption per kept
        detection, drawn in colour ``(255, 0, 0)``.
    """
    frame = image.permute(1, 2, 0).cpu().numpy()
    frame = (frame * np.asarray(std) + np.asarray(mean)) * 255
    # Round, then clip, before the uint8 cast. Plain truncation loses up to one
    # level, and out-of-range values would overflow.
    frame = np.ascontiguousarray(np.clip(np.rint(frame), 0, 255).astype(np.uint8))

    for box, label, score in zip(prediction['boxes'].cpu().numpy(),
                                 prediction['labels'].cpu().numpy(),
                                 prediction['scores'].cpu().numpy()):
        if score < threshold:
            continue
        x1, y1, x2, y2 = box.astype(int)
        class_name = class_names.get(label, 'Unknown')
        cv2.rectangle(frame, (x1, y1), (x2, y2), (255, 0, 0), 2)
        # Keep the caption inside the image for boxes touching the top edge.
        text_y = max(y1 - 10, 10)
        cv2.putText(frame, f'{class_name}: {score:.2f}', (x1, text_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)
    return frame

# Function to process video frames
def process_video_frames(model, video_path, class_names, device, frames_dir='/content/frames', output_frames_dir='/content/output_frames', threshold=0.3):
    """Run ``model`` on every frame of ``video_path`` and save annotated copies.

    ``frames_dir`` and ``output_frames_dir`` are wiped and recreated. Frames
    are extracted to ``frames_dir`` and normalized with the same mean/std as
    the training transform. Each frame is passed through ``model`` on its own,
    drawn with ``visualize_detections`` and written to ``output_frames_dir``
    as ``annotated_frame_XXXXXX.jpg``. Up to five sampled annotated frames are
    displayed inline, and ``frames_dir`` is removed after a successful run.

    Returns:
        ``output_frames_dir`` on success; ``None`` if extraction or inference fails.
    """
    for directory in (frames_dir, output_frames_dir):
        if os.path.exists(directory):
            shutil.rmtree(directory)
        os.makedirs(directory)

    # extract_frames returns None on failure, so check before unpacking.
    extraction = extract_frames(video_path, frames_dir)
    if extraction is None or extraction[0] is None:
        print("Error: Frame extraction failed.")
        return None
    frame_paths = extraction[0]
    total_frames = len(frame_paths)
    if total_frames == 0:
        print("Error: No frames extracted from video.")
        return None

    transform = Compose([
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    sample_every = total_frames // 5 + 1
    model.eval()
    sample_frames = []
    frame_count = 0
    try:
        with torch.no_grad():
            for frame_count, frame_path in enumerate(frame_paths, start=1):
                print(f"Processing frame {frame_count}/{total_frames}", end='\r')
                frame_rgb = cv2.cvtColor(cv2.imread(frame_path), cv2.COLOR_BGR2RGB)
                image_tensor, _ = transform(Image.fromarray(frame_rgb), {})
                image_tensor = image_tensor.unsqueeze(0).to(device)
                outputs = model(image_tensor)[0]
                annotated_frame = visualize_detections(image_tensor[0], outputs, class_names, threshold)
                output_frame_path = os.path.join(output_frames_dir, f"annotated_frame_{frame_count:06d}.jpg")
                cv2.imwrite(output_frame_path, cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR))
                if frame_count % sample_every == 0 and len(sample_frames) < 5:
                    sample_frames.append(annotated_frame)
                del image_tensor, outputs
                if device.type == 'cuda':
                    torch.cuda.empty_cache()
    except Exception as e:
        print(f"Error during frame processing: {str(e)}")
        return None
    print(f"\nProcessed {frame_count} frames. Annotated frames saved to {output_frames_dir}")

    for i, frame in enumerate(sample_frames):
        sample_path = f'/content/sample_frame_{i}.png'
        cv2.imwrite(sample_path, cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
        display(IPImage(sample_path))
    shutil.rmtree(frames_dir)
    print("Cleaned up temporary frames directory")
    return output_frames_dir

# Main function
def main():
    """Train and validate the DAVIS Faster R-CNN, then annotate an uploaded video.

    Pipeline:
      1. Seed the RNGs and select CUDA when available.
      2. Build the DAVIS train split (with random flips) and the val split
         (no random augmentation).
      3. Alternate one training epoch and one validation pass for
         ``num_epochs`` epochs, save the weights and plot the validation curves.
      4. Free the training model, then rebuild it from the saved weights
         (or from an uploaded copy).
      5. Ask for a video upload, annotate every frame and zip the results.
    """
    torch.manual_seed(42)
    random.seed(42)
    np.random.seed(42)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    davis_root = '/content/DAVIS/DAVIS'
    weights_path = '/content/fasterrcnn_resnet50_davis.pth'
    num_classes = 91  # class ids 1-90 from DAVISDataset.class_map, plus id 0
    num_epochs = 3

    normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_transform = Compose([ToTensor(), normalize, RandomHorizontalFlip(prob=0.5)])
    # Validation should be deterministic, so it gets no random flip.
    eval_transform = Compose([ToTensor(), normalize])

    train_dataset = DAVISDataset(davis_root, subset='train', transform=train_transform)
    val_dataset = DAVISDataset(davis_root, subset='val', transform=eval_transform)
    class_names = val_dataset.class_names
    # Targets hold a variable number of boxes, so batches stay as tuples
    # instead of being stacked.
    train_loader = DataLoader(
        train_dataset,
        batch_size=2,
        shuffle=True,
        num_workers=2,
        collate_fn=lambda x: tuple(zip(*x))
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=2,
        shuffle=False,
        num_workers=2,
        collate_fn=lambda x: tuple(zip(*x))
    )

    model = create_fasterrcnn_model(num_classes=num_classes)
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)

    val_losses, val_accuracies, val_f1_scores = [], [], []
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch + 1}/{num_epochs}")
        model = train_model(model, train_loader, optimizer, 1, device)
        val_loss, val_acc, val_f1 = validate_model(model, val_loader, device)
        val_losses.append(val_loss)
        val_accuracies.append(val_acc)
        val_f1_scores.append(val_f1)

    torch.save(model.state_dict(), weights_path)
    print("Training completed!")
    plot_validation_metrics(val_losses, val_accuracies, val_f1_scores)

    # Drop the training model and optimizer state before building the
    # inference model, so two copies never sit on the GPU at once.
    del model, optimizer, params
    if device.type == 'cuda':
        torch.cuda.empty_cache()

    # Load the model
    model = create_fasterrcnn_model(num_classes=num_classes)
    try:
        model.load_state_dict(torch.load(weights_path, map_location=device))
    except FileNotFoundError:
        print("Error: Model weights file 'fasterrcnn_resnet50_davis.pth' not found.")
        print("Please upload the trained model or uncomment the training section to train the model.")
        print("To proceed without training, upload 'fasterrcnn_resnet50_davis.pth' now:")
        uploaded = files.upload()
        if 'fasterrcnn_resnet50_davis.pth' in uploaded:
            model.load_state_dict(torch.load(weights_path, map_location=device))
        else:
            print("Error: Model weights not uploaded. Cannot proceed.")
            return
    model.to(device)

    # Upload video
    print("Please upload your video file (e.g., MP4 format):")
    uploaded = files.upload()
    if not uploaded:
        print("Error: No video file uploaded.")
        return
    video_path = list(uploaded.keys())[0]

    # Process video frames
    output_frames_dir = process_video_frames(
        model, video_path, class_names, device,
        frames_dir='/content/frames',
        output_frames_dir='/content/output_frames',
        threshold=0.3
    )

    if output_frames_dir:
        print(f"Processed frames saved to {output_frames_dir}")
        print("Zipping the output_frames directory for download...")
        # Plain-Python replacement for the `!zip` shell escape; writes
        # <output_frames_dir>.zip with entries under the directory's name.
        archive_path = shutil.make_archive(
            output_frames_dir, 'zip',
            root_dir=os.path.dirname(output_frames_dir),
            base_dir=os.path.basename(output_frames_dir),
        )
        print(f"Download {archive_path} from the Colab file explorer.")
    else:
        print("Frame processing failed. Please check error messages.")

if __name__ == "__main__":
    main()

Using device: cuda


OutOfMemoryError: CUDA out of memory. Tried to allocate 392.00 MiB. GPU 0 has a total capacity of 14.74 GiB of which 186.12 MiB is free. Process 9082 has 14.56 GiB memory in use. Of the allocated memory 14.38 GiB is allocated by PyTorch, and 42.48 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
# This is the updated version of the original code with validation funtions.
# Install required packages
!pip install torch torchvision scipy numpy pillow matplotlib opencv-python
!pip install davis2017-evaluation
!pip install --upgrade torch torchvision torchaudio

# # Download and extract DAVIS 2017 dataset
!wget -O DAVIS-2017-trainval-480p.zip https://data.vision.ee.ethz.ch/csergi/share/davis/DAVIS-2017-trainval-480p.zip
!unzip -q DAVIS-2017-trainval-480p.zip -d /content/DAVIS

import os
import cv2
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.ops import MultiScaleRoIAlign
import torch.nn as nn
import torch.optim as optim
from torchvision.transforms import functional as F
from IPython.display import Image as IPImage, display
from google.colab import files
import shutil
import matplotlib.pyplot as plt
import random
from sklearn.metrics import accuracy_score, f1_score


# Define the DAVIS Dataset class
class DAVISDataset(Dataset):
    """DAVIS 2017 frames exposed as object-detection samples.

    Every frame that has an annotation PNG becomes one sample. Each non-zero
    object id in the annotation yields one ``(x1, y1, x2, y2)`` box, and every
    box in a frame gets the class of its sequence from ``class_map``
    (sequences missing from the map fall back to label 1).

    Args:
        davis_root: DAVIS root directory containing ``JPEGImages/480p``,
            ``Annotations/480p`` and ``ImageSets/2017``.
        subset: Split name; sequences are read from
            ``ImageSets/2017/<subset>.txt``.
        transform: Optional callable taking and returning ``(image, target)``.
    """

    def __init__(self, davis_root, subset='val', transform=None):
        self.davis_root = davis_root
        self.subset = subset
        self.transform = transform
        self.img_dir = os.path.join(davis_root, 'JPEGImages', '480p')
        self.anno_dir = os.path.join(davis_root, 'Annotations', '480p')
        with open(os.path.join(davis_root, 'ImageSets', '2017', f'{subset}.txt'), 'r') as f:
            self.sequences = f.read().splitlines()
        # Sequence name -> class id. Ids start at 1; no sequence uses 0.
        self.class_map = {
            'bear': 1, 'bmx-bumps': 2, 'boat': 3, 'boxing-fisheye': 4, 'breakdance-flare': 5,
            'bus': 6, 'car-turn': 7, 'cat-girl': 8, 'classic-car': 9, 'color-run': 10,
            'crossing': 11, 'dance-jump': 12, 'dancing': 13, 'disc-jockey': 14, 'dog-agility': 15,
            'dog-gooses': 16, 'dogs-scale': 17, 'drift-turn': 18, 'drone': 19, 'elephant': 20,
            'flamingo': 21, 'hike': 22, 'hockey': 23, 'horsejump-low': 24, 'kid-football': 25,
            'kite-walk': 26, 'koala': 27, 'lady-running': 28, 'lindy-hop': 29, 'longboard': 30,
            'lucia': 31, 'mallard-fly': 32, 'mallard-water': 33, 'miami-surf': 34, 'motocross-bumps': 35,
            'motorbike': 36, 'night-race': 37, 'paragliding': 38, 'planes-water': 39, 'rallye': 40,
            'rhino': 41, 'rollerblade': 42, 'schoolgirls': 43, 'scooter-board': 44, 'scooter-gray': 45,
            'sheep': 46, 'skate-park': 47, 'snowboard': 48, 'soccerball': 49, 'stroller': 50,
            'stunt': 51, 'surf': 52, 'swing': 53, 'tennis': 54, 'tractor-sand': 55,
            'train': 56, 'tuk-tuk': 57, 'upside-down': 58, 'varanus-cage': 59, 'walking': 60,
            'bike-packing': 61, 'blackswan': 62, 'bmx-trees': 63, 'breakdance': 64, 'camel': 65,
            'car-roundabout': 66, 'car-shadow': 67, 'cows': 68, 'dance-twirl': 69, 'dog': 70,
            'dogs-jump': 71, 'drift-chicane': 72, 'drift-straight': 73, 'goat': 74, 'gold-fish': 75,
            'horsejump-high': 76, 'india': 77, 'judo': 78, 'kite-surf': 79, 'lab-coat': 80,
            'libby': 81, 'loading': 82, 'mbike-trick': 83, 'motocross-jump': 84, 'paragliding-launch': 85,
            'parkour': 86, 'pigs': 87, 'scooter-black': 88, 'shooting': 89, 'soapbox': 90
        }
        self.class_names = {v: k for k, v in self.class_map.items()}
        # (image path, annotation path, sequence name) for every annotated frame.
        self.samples = []
        for seq in self.sequences:
            img_seq_dir = os.path.join(self.img_dir, seq)
            anno_seq_dir = os.path.join(self.anno_dir, seq)
            for img_file in sorted(os.listdir(img_seq_dir)):
                if not img_file.endswith('.jpg'):
                    continue
                frame_num = img_file.split('.')[0]
                anno_path = os.path.join(anno_seq_dir, f"{frame_num}.png")
                if os.path.exists(anno_path):
                    self.samples.append((os.path.join(img_seq_dir, img_file), anno_path, seq))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for sample ``idx``.

        Before ``transform`` runs, ``image`` is an RGB PIL image. ``target``
        holds ``boxes`` (float32, shape ``[N, 4]``), ``labels`` (int64,
        shape ``[N]``), ``image_id``, ``area`` and ``iscrowd``; ``N`` may be 0.
        """
        img_path, mask_path, seq_name = self.samples[idx]
        image = Image.open(img_path).convert('RGB')
        with Image.open(mask_path) as mask:
            # For palette ('P') PNGs the pixel value is the object index.
            # Converting to 'L' would instead map indices through the
            # luminance of their palette colours, which can merge objects.
            mask_np = np.array(mask if mask.mode == 'P' else mask.convert('L'))
        obj_ids = np.unique(mask_np)
        # Drop background explicitly instead of assuming it is the first
        # unique value: a frame fully covered by objects has no 0 pixels.
        obj_ids = obj_ids[obj_ids != 0]

        class_label = self.class_map.get(seq_name, 1)
        boxes, labels = [], []
        for obj_id in obj_ids:
            ys, xs = np.nonzero(mask_np == obj_id)
            xmin, xmax = xs.min(), xs.max()
            ymin, ymax = ys.min(), ys.max()
            # Skip boxes with zero width or height.
            if xmax > xmin and ymax > ymin:
                boxes.append([xmin, ymin, xmax, ymax])
                labels.append(class_label)

        if boxes:
            boxes = torch.as_tensor(boxes, dtype=torch.float32)
            labels = torch.as_tensor(labels, dtype=torch.int64)
        else:
            boxes = torch.zeros((0, 4), dtype=torch.float32)
            labels = torch.zeros(0, dtype=torch.int64)
        target = {
            'boxes': boxes,
            'labels': labels,
            'image_id': torch.tensor([idx]),
            'area': (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]),
            'iscrowd': torch.zeros((len(boxes),), dtype=torch.int64),
        }
        if self.transform:
            image, target = self.transform(image, target)
        return image, target

# Define transformation classes
class Compose:
    """Chain paired ``(image, target)`` transforms, applied in list order."""

    def __init__(self, transforms):
        # Ordered callables, each taking and returning (image, target).
        self.transforms = transforms

    def __call__(self, image, target):
        current_image, current_target = image, target
        for step in self.transforms:
            current_image, current_target = step(current_image, current_target)
        return current_image, current_target

class ToTensor:
    """Convert the image with ``F.to_tensor``; the target passes through untouched."""

    def __call__(self, image, target):
        return F.to_tensor(image), target

class Normalize:
    """Normalize the image channel-wise with ``F.normalize``; target unchanged."""

    def __init__(self, mean, std):
        # Per-channel statistics forwarded to F.normalize.
        self.mean, self.std = mean, std

    def __call__(self, image, target):
        return F.normalize(image, mean=self.mean, std=self.std), target

class RandomHorizontalFlip:
    """Mirror an image tensor left-right with probability ``prob``, updating boxes.

    Boxes are ``(x1, y1, x2, y2)``; after the flip ``x1' = W - x2`` and
    ``x2' = W - x1``, so each box stays well-ordered. The ``boxes`` tensor in
    ``target`` is modified in place.
    """

    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, image, target):
        if not random.random() < self.prob:
            return image, target
        _, width = image.shape[-2:]
        flipped = F.hflip(image)
        boxes = target["boxes"]
        # The right-hand side is a copy (advanced indexing), so the swap is safe.
        boxes[:, [0, 2]] = width - boxes[:, [2, 0]]
        target["boxes"] = boxes
        return flipped, target

# Define the Faster R-CNN model
def create_fasterrcnn_model(num_classes, pretrained_backbone=True):
    """Build a single-feature-map Faster R-CNN on a ResNet-50 trunk.

    Args:
        num_classes: Number of output classes passed to ``FasterRCNN``.
        pretrained_backbone: Load torchvision's default ResNet-50 weights when
            True; otherwise the trunk is randomly initialised.

    Returns:
        An untrained ``FasterRCNN`` with 15 anchors per location (5 sizes x 3
        aspect ratios) and 7x7 RoIAlign pooling; images are resized with
        ``min_size=800`` / ``max_size=1333``.
    """
    import torchvision

    resnet = torchvision.models.resnet50(weights='DEFAULT' if pretrained_backbone else None)
    # Drop the last two children (average pool and fc classifier), keeping the conv trunk.
    backbone = nn.Sequential(*list(resnet.children())[:-2])
    # FasterRCNN reads `out_channels` to size its heads; 2048 is the trunk's output width.
    backbone.out_channels = 2048

    anchor_sizes = ((32, 64, 128, 256, 512),)
    anchor_ratios = ((0.5, 1.0, 2.0),)
    return FasterRCNN(
        backbone=backbone,
        num_classes=num_classes,
        rpn_anchor_generator=AnchorGenerator(sizes=anchor_sizes, aspect_ratios=anchor_ratios),
        # featmap_names must match the key of the single backbone feature map.
        box_roi_pool=MultiScaleRoIAlign(featmap_names=['0'], output_size=7, sampling_ratio=2),
        min_size=800,
        max_size=1333,
    )

# Enhanced training function with F1, accuracy, and loss tracking
def train_model(model, data_loader, optimizer, device, num_epochs=4):
    """Train a torchvision detection model and plot per-epoch loss/accuracy/F1.

    After each optimisation step the model is briefly put in eval mode to get
    detections for the same batch. The label metrics compare, per image, the
    first ``min(len(pred), len(true))`` predicted and ground-truth labels
    (a crude proxy: no box matching is done).

    Args:
        model: detection model returning a loss dict in train mode and a list
            of detection dicts in eval mode.
        data_loader: yields ``(images, targets)`` batches.
        optimizer: optimizer over the model's parameters.
        device: device to run on.
        num_epochs: number of passes over ``data_loader``.

    Returns:
        The trained model. The metrics figure is saved to
        /content/training_metrics.png and displayed.
    """
    model.to(device)
    model.train()
    losses_history = []
    accuracy_history = []
    f1_history = []
    log_every = 10
    num_batches = len(data_loader)
    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")
        epoch_loss_sum = 0.0  # whole-epoch total, used for the epoch average
        window_loss = 0.0     # reset every `log_every` batches for progress logs
        all_preds = []
        all_targets = []
        for i, (images, targets) in enumerate(data_loader):
            images = list(image.to(device) for image in images)
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            optimizer.zero_grad()
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())
            losses.backward()
            optimizer.step()
            loss_value = losses.item()
            epoch_loss_sum += loss_value
            window_loss += loss_value

            # Collect predictions and targets for metrics
            model.eval()
            with torch.no_grad():
                outputs = model(images)
            model.train()

            for output, target in zip(outputs, targets):
                pred_labels = output['labels'].cpu().numpy()
                true_labels = target['labels'].cpu().numpy()
                # sklearn metrics need equal-length inputs for each image.
                min_len = min(len(pred_labels), len(true_labels))
                all_preds.extend(pred_labels[:min_len])
                all_targets.extend(true_labels[:min_len])

            if (i + 1) % log_every == 0:
                print(f"Batch {i+1}/{len(data_loader)}, Loss: {window_loss/log_every:.4f}")
                window_loss = 0.0

        epoch_loss = epoch_loss_sum / num_batches if num_batches else 0.0
        losses_history.append(epoch_loss)

        if all_targets:
            acc = accuracy_score(all_targets, all_preds)
            f1 = f1_score(all_targets, all_preds, average='macro', zero_division=0)
        else:
            acc, f1 = 0.0, 0.0  # nothing was compared this epoch
        accuracy_history.append(acc)
        f1_history.append(f1)
        print(f"Epoch {epoch+1} Loss: {epoch_loss:.4f} Accuracy: {acc:.4f} F1 Score: {f1:.4f}")

    # Plotting metrics
    plt.figure(figsize=(12, 6))
    plt.subplot(1, 3, 1)
    plt.plot(range(1, num_epochs+1), losses_history, label='Loss')
    plt.title('Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.grid(True)

    plt.subplot(1, 3, 2)
    plt.plot(range(1, num_epochs+1), accuracy_history, label='Accuracy', color='green')
    plt.title('Training Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.grid(True)

    plt.subplot(1, 3, 3)
    plt.plot(range(1, num_epochs+1), f1_history, label='F1 Score', color='red')
    plt.title('Training F1 Score')
    plt.xlabel('Epoch')
    plt.ylabel('F1 Score')
    plt.grid(True)

    plt.tight_layout()
    plt.savefig('/content/training_metrics.png')
    plt.close()
    display(IPImage('/content/training_metrics.png'))
    print("Training metrics saved to /content/training_metrics.png")
    return model

# Validation function: evaluate a trained model on a held-out data loader
def validate_model(model, data_loader, device):
    """Compute validation loss and crude label accuracy / macro-F1.

    torchvision detection models return a loss dict only in training mode; in
    eval mode ``model(images, targets)`` returns detections. The loss pass
    therefore runs in training mode under ``torch.no_grad()`` with every
    BatchNorm layer kept in eval mode, so validation does not update running
    statistics. Detections are then produced in eval mode.

    Label metrics compare, per image, the first ``min(len(pred), len(true))``
    predicted and ground-truth labels (no box matching).

    Args:
        model: detection model.
        data_loader: yields ``(images, targets)`` batches.
        device: device to run on.

    Returns:
        ``(avg_loss, accuracy, f1)``. Accuracy and F1 are 0.0 when no labels
        were compared. The model is left in training mode.
    """
    print("\nStarting validation...")
    all_preds = []
    all_targets = []
    total_loss = 0.0

    def _loss_mode():
        # Training-mode forward (so losses are returned) with frozen BatchNorm.
        model.train()
        for module in model.modules():
            if isinstance(module, nn.modules.batchnorm._BatchNorm):
                module.eval()

    with torch.no_grad():
        for images, targets in data_loader:
            images = list(img.to(device) for img in images)
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            _loss_mode()
            loss_dict = model(images, targets)
            total_loss += float(sum(loss for loss in loss_dict.values()))

            model.eval()
            outputs = model(images)

            for output, target in zip(outputs, targets):
                pred_labels = output['labels'].cpu().numpy()
                true_labels = target['labels'].cpu().numpy()
                min_len = min(len(pred_labels), len(true_labels))
                all_preds.extend(pred_labels[:min_len])
                all_targets.extend(true_labels[:min_len])

    num_batches = len(data_loader)
    avg_loss = total_loss / num_batches if num_batches else 0.0
    if all_targets:
        acc = accuracy_score(all_targets, all_preds)
        f1 = f1_score(all_targets, all_preds, average='macro', zero_division=0)
    else:
        acc, f1 = 0.0, 0.0

    print(f"Validation Loss: {avg_loss:.4f}")
    print(f"Validation Accuracy: {acc:.4f}")
    print(f"Validation F1 Score: {f1:.4f}\n")
    model.train()
    return avg_loss, acc, f1

# Function to extract frames
def extract_frames(video_path, output_dir):
    """Write every frame of a video to ``output_dir`` as numbered JPEGs.

    Args:
        video_path: path of a video readable by OpenCV.
        output_dir: destination directory; created if it does not exist.

    Returns:
        ``(frame_paths, width, height, fps)`` on success, where ``fps`` is the
        container FPS truncated to ``int``. Returns ``None`` if the video cannot
        be opened or an exception occurs while reading/writing frames.
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            print(f"Error: Could not open video {video_path}")
            return None
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = int(cap.get(cv2.CAP_PROP_FPS))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_paths = []
        frame_count = 0
        try:
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                frame_count += 1
                frame_path = os.path.join(output_dir, f"frame_{frame_count:06d}.jpg")
                cv2.imwrite(frame_path, frame)
                frame_paths.append(frame_path)
                print(f"Extracting frame {frame_count}/{total_frames}", end='\r')
        except Exception as e:
            print(f"Error during frame extraction: {e}")
            return None
    finally:
        # Release on every exit path, including KeyboardInterrupt in a notebook.
        cap.release()
    print(f"\nExtracted {frame_count} frames to {output_dir}")
    return frame_paths, width, height, fps

# Function to visualize detections
def visualize_detections(image, prediction, class_names, threshold=0.3,
                         mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Draw detections on a de-normalized copy of a channel-first image tensor.

    Args:
        image: normalized image tensor laid out (C, H, W); it is permuted to
            (H, W, C) and ``mean``/``std`` are used to undo the normalization.
        prediction: dict with 'boxes' (x1, y1, x2, y2), 'labels' and 'scores'.
        class_names: mapping from label id to display name; ids not present
            are drawn as 'Unknown'.
        threshold: minimum score for a detection to be drawn.
        mean, std: per-channel statistics the image was normalized with.

    Returns:
        A uint8 (H, W, C) ndarray with boxes and captions drawn in (255, 0, 0).
    """
    image_np = image.detach().permute(1, 2, 0).cpu().numpy()
    image_np = (image_np * np.asarray(std) + np.asarray(mean)) * 255
    # Clip before casting: out-of-range floats would otherwise wrap in uint8.
    # .copy() yields a C-contiguous buffer that OpenCV can draw on.
    image_np = np.clip(image_np, 0, 255).astype(np.uint8).copy()
    boxes = prediction['boxes'].detach().cpu().numpy()
    labels = prediction['labels'].detach().cpu().numpy()
    scores = prediction['scores'].detach().cpu().numpy()
    for box, label, score in zip(boxes, labels, scores):
        if score < threshold:
            continue
        x1, y1, x2, y2 = (int(v) for v in box)
        class_name = class_names.get(int(label), 'Unknown')
        cv2.rectangle(image_np, (x1, y1), (x2, y2), (255, 0, 0), 2)
        cv2.putText(image_np, f'{class_name}: {score:.2f}', (x1, y1 - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)
    return image_np

import os
import random
import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.transforms import Compose
from torchvision.transforms import ToTensor, Normalize
from torchvision.transforms import functional as F
import cv2
from PIL import Image
from google.colab import files

# Function to process video frames
def process_video_frames(model, video_path, class_names, device, frames_dir, output_video_path,
                         threshold=0.3, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Run the detector on every frame of a video and write an annotated copy.

    Args:
        model: detection model returning dicts with 'boxes', 'scores', 'labels'.
        video_path: input video readable by OpenCV.
        class_names: dict (label id -> name) or sequence indexed by label id;
            unknown ids are drawn as their number.
        device: device to run inference on.
        frames_dir: directory created for frame output (no frames are written
            to it here).
        output_video_path: destination of the annotated mp4.
        threshold: minimum score for a detection to be drawn.
        mean, std: per-channel normalization applied after ``F.to_tensor``.
            The defaults match the Normalize step ``main()`` trains with.
            Pass ``mean=None`` to feed un-normalized tensors.

    Returns:
        ``output_video_path`` on success, ``None`` if the input video or the
        output writer cannot be opened.
    """
    os.makedirs(frames_dir, exist_ok=True)

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print("Error: Could not open video.")
        return None

    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    if not fps or fps <= 0:
        fps = 30.0  # some containers report 0 FPS; VideoWriter needs a positive rate

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out_video = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width, frame_height))
    if not out_video.isOpened():
        cap.release()
        print("Error: Could not open video writer.")
        return None

    def label_to_name(label_id):
        # Plain-int lookup: torch tensors hash by identity, so they never match dict keys.
        if isinstance(class_names, dict):
            return class_names.get(label_id, str(label_id))
        if 0 <= label_id < len(class_names):
            return class_names[label_id]
        return str(label_id)

    model.eval()
    try:
        with torch.no_grad():
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                # Convert frame to RGB and tensor
                pil_img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                img_tensor = F.to_tensor(pil_img)
                if mean is not None and std is not None:
                    img_tensor = F.normalize(img_tensor, mean=list(mean), std=list(std))

                prediction = model([img_tensor.to(device)])[0]
                boxes = prediction["boxes"].cpu().tolist()
                scores = prediction["scores"].cpu().tolist()
                labels = prediction["labels"].cpu().tolist()

                for box, score, label in zip(boxes, scores, labels):
                    if score < threshold:
                        continue
                    x1, y1, x2, y2 = map(int, box)
                    label_name = label_to_name(label)
                    cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
                    cv2.putText(frame, f"{label_name}: {score:.2f}", (x1, y1 - 10),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)

                out_video.write(frame)
    finally:
        cap.release()
        out_video.release()
    print(f"Video processing complete. Output saved to {output_video_path}")
    return output_video_path

# Main function
def main():
    """Train a Faster R-CNN on DAVIS 2017, then annotate a user-uploaded video.

    Steps:
        1. Seed the RNGs.
        2. Train for one epoch on the 'train' split and save the weights.
        3. Reload the weights, prompting for an upload if the file is missing.
        4. Ask for a video and write an annotated copy to
           /content/output_video.mp4.
    """
    torch.manual_seed(42)
    random.seed(42)
    np.random.seed(42)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    davis_root = '/content/DAVIS/DAVIS'
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    # DAVISDataset calls transform(image, target). In this cell the names
    # Compose/ToTensor/Normalize are re-imported from torchvision.transforms,
    # whose Compose accepts a single argument (TypeError on the pair). Build
    # the pair-wise transform explicitly instead of relying on those names.
    def transform(image, target):
        image = F.to_tensor(image)
        image = F.normalize(image, mean=mean, std=std)
        return image, target

    # === Training section (optional) ===
    train_dataset = DAVISDataset(davis_root, subset='train', transform=transform)
    train_loader = DataLoader(
        train_dataset,
        batch_size=2,
        shuffle=True,
        num_workers=2,
        collate_fn=lambda x: tuple(zip(*x))
    )
    num_classes = 91  # 90 DAVIS sequence classes + background
    model = create_fasterrcnn_model(num_classes=num_classes)
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)

    # Train model
    model = train_model(model, train_loader, optimizer, device, num_epochs=1)
    torch.save(model.state_dict(), '/content/fasterrcnn_resnet50_davis.pth')
    print("Training completed!")

    # === Load for inference ===
    dataset = DAVISDataset(davis_root, subset='val', transform=transform)
    class_names = dataset.class_names

    model = create_fasterrcnn_model(num_classes=num_classes)
    try:
        model.load_state_dict(torch.load('/content/fasterrcnn_resnet50_davis.pth', map_location=device))
    except FileNotFoundError:
        print("Error: Model weights not found. Please upload:")
        uploaded = files.upload()
        if 'fasterrcnn_resnet50_davis.pth' in uploaded:
            model.load_state_dict(torch.load('/content/fasterrcnn_resnet50_davis.pth', map_location=device))
        else:
            print("Model weights not uploaded. Exiting.")
            return
    model.to(device)

    # === Video Upload & Processing ===
    print("Upload your video file (e.g., .mp4):")
    uploaded = files.upload()
    if not uploaded:
        print("No video uploaded. Exiting.")
        return
    video_path = list(uploaded.keys())[0]

    output_video_path = '/content/output_video.mp4'
    result = process_video_frames(
        model=model,
        video_path=video_path,
        class_names=class_names,
        device=device,
        frames_dir='/content/frames',
        output_video_path=output_video_path,
        threshold=0.3
    )

    if result:
        print(f"Processed video saved at: {result}")
        print("Download using this cell:")
        print("from google.colab import files; files.download('/content/output_video.mp4')")
    else:
        print("Video processing failed.")

# Entry point: run the DAVIS training + video-inference pipeline defined in main().
# In an IPython/Jupyter kernel __name__ is "__main__", so this runs when the cell runs.
if __name__ == "__main__":
    main()

In [None]:
# This cell contains the COCO-dataset-based version of this project's code.

# Install required packages
!pip install torch torchvision scipy numpy pillow matplotlib opencv-python
!pip install davis2017-evaluation
!pip install --upgrade torch torchvision torchaudio
!pip install pycocotools

import os
import torch
import zipfile
import random
import numpy as np
from torchvision import transforms, models
from torch.utils.data import DataLoader
from pycocotools.coco import COCO
import torch.optim as optim
from sklearn.metrics import f1_score, accuracy_score
import matplotlib.pyplot as plt
from torchvision.transforms import Compose, ToTensor, Normalize, RandomHorizontalFlip
from PIL import Image
import requests
from io import BytesIO
import cv2
from torch import nn
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.models.detection.roi_heads import MultiScaleRoIAlign


# Function to download and extract COCO dataset
def download_and_extract_coco(root='/content/coco', splits=('train2017',), chunk_size=1 << 20):
    """Download and extract COCO 2017 image archives and annotations.

    Archives are streamed to disk in ``chunk_size`` pieces instead of being
    buffered in memory (train2017.zip alone is many gigabytes).

    Args:
        root: directory that receives ``images/<split>/`` and
            ``annotations/instances_*.json``.
        splits: image splits to fetch, e.g. ``('train2017', 'val2017')``.
            The default fetches train only.
        chunk_size: bytes written per streamed chunk.

    Raises:
        requests.HTTPError: if a download returns an error status.
    """
    def _download(url, dest):
        with requests.get(url, stream=True, timeout=60) as resp:
            resp.raise_for_status()
            with open(dest, 'wb') as f:
                for chunk in resp.iter_content(chunk_size=chunk_size):
                    f.write(chunk)

    def _extract(archive, target_dir):
        with zipfile.ZipFile(archive, 'r') as zip_ref:
            zip_ref.extractall(target_dir)

    # Download images
    print("Downloading COCO images...")
    for split in splits:
        archive = f'{split}.zip'
        _download(f'http://images.cocodataset.org/zips/{split}.zip', archive)
        # The archive holds a top-level <split>/ folder -> root/images/<split>/.
        _extract(archive, os.path.join(root, 'images'))

    # Download annotations
    print("Downloading COCO annotations...")
    ann_archive = 'annotations_trainval2017.zip'
    _download(f'http://images.cocodataset.org/annotations/{ann_archive}', ann_archive)
    # The archive already contains a top-level annotations/ folder, so extract
    # into root to get root/annotations/instances_*.json (not annotations/annotations/).
    _extract(ann_archive, root)

    print("COCO dataset downloaded and extracted!")


# Dataset Class
class CocoDataset(torch.utils.data.Dataset):
    """COCO detection dataset yielding ``(image, target)`` for torchvision detectors.

    ``target['boxes']`` is a float tensor of shape (N, 4) in (x1, y1, x2, y2)
    pixel coordinates (converted from COCO's [x, y, width, height]).
    ``target['labels']`` holds the COCO ``category_id`` values. Crowd
    annotations and boxes with non-positive width/height are skipped.

    Args:
        root: directory containing the image files named in the annotations.
        annFile: path to a COCO ``instances_*.json`` file.
        transform: optional callable applied to the PIL image only.
    """

    def __init__(self, root, annFile, transform=None):
        self.coco = COCO(annFile)
        self.img_ids = list(self.coco.imgs.keys())
        self.root = root
        self.transform = transform

    def __getitem__(self, index):
        img_id = self.img_ids[index]
        img_info = self.coco.imgs[img_id]
        img = Image.open(os.path.join(self.root, img_info['file_name'])).convert("RGB")

        anns = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id))

        boxes = []
        labels = []
        for ann in anns:
            if ann.get('iscrowd', 0):
                continue  # crowd regions are not individual instances
            x, y, w, h = ann['bbox']
            if w <= 0 or h <= 0:
                continue  # torchvision detectors reject degenerate boxes in training
            boxes.append([x, y, x + w, y + h])
            labels.append(ann['category_id'])

        target = {
            # reshape keeps shape (0, 4) for images without usable annotations.
            'boxes': torch.tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            'labels': torch.tensor(labels, dtype=torch.int64),
            'image_id': torch.tensor([img_id])
        }

        if self.transform:
            img = self.transform(img)

        return img, target

    def __len__(self):
        return len(self.img_ids)


# Transformation pipeline, applied to the PIL image only (CocoDataset passes the
# target through unchanged). Horizontal flipping is deliberately left out: an
# image-only transform cannot flip the target boxes to match. torchvision's
# RandomHorizontalFlip also takes `p`, not `prob`.
transform = Compose([
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Set paths for COCO dataset
train_images_dir = '/content/coco/images/train2017'
train_anns_file = '/content/coco/annotations/instances_train2017.json'
val_images_dir = '/content/coco/images/val2017'
val_anns_file = '/content/coco/annotations/instances_val2017.json'

# Load datasets
train_dataset = CocoDataset(root=train_images_dir, annFile=train_anns_file, transform=transform)
val_dataset = CocoDataset(root=val_images_dir, annFile=val_anns_file, transform=transform)


def detection_collate(batch):
    """Keep images/targets as tuples; detection samples differ in size and box count."""
    return tuple(zip(*batch))


train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=4,
                          collate_fn=detection_collate)
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False, num_workers=4,
                        collate_fn=detection_collate)


# Model creation (Faster R-CNN with ResNet50 backbone)
def create_fasterrcnn_model(num_classes, pretrained_backbone=True):
    """Build a Faster R-CNN whose backbone is ResNet-50 minus its pooling/classifier.

    Args:
        num_classes: number of classes (including background) for the box head.
        pretrained_backbone: use torchvision's default ResNet-50 weights if True.

    Returns:
        A ``FasterRCNN`` model using a single feature map named '0'.
    """
    resnet = models.resnet50(weights='DEFAULT' if pretrained_backbone else None)
    trunk = nn.Sequential(*list(resnet.children())[:-2])
    trunk.out_channels = 2048  # attribute FasterRCNN reads from the backbone

    anchors = AnchorGenerator(
        sizes=((32, 64, 128, 256, 512),),
        aspect_ratios=((0.5, 1.0, 2.0),),
    )
    pooler = MultiScaleRoIAlign(featmap_names=['0'], output_size=7, sampling_ratio=2)

    detector = FasterRCNN(
        backbone=trunk,
        num_classes=num_classes,
        rpn_anchor_generator=anchors,
        box_roi_pool=pooler,
        min_size=800,
        max_size=1333,
    )
    return detector


# Training function
def train_model(model, train_loader, optimizer, device, num_epochs=5):
    """Train a torchvision detector and report per-epoch loss, accuracy and F1.

    After each optimisation step the model is switched to eval mode to obtain
    detections for the batch. Train-mode forwards require targets, so the
    eval switch is needed. Per image, the first ``min(len(true), len(pred))``
    labels are compared as a crude proxy (no box matching). Label metrics are
    reset every epoch.

    Args:
        model: detection model (already on ``device``).
        train_loader: yields ``(images, targets)`` batches.
        optimizer: optimizer over the model's parameters.
        device: device to run on.
        num_epochs: number of passes over ``train_loader``.

    Returns:
        ``(model, train_loss, accuracy, f1)``:
            - ``train_loss``: list of per-epoch average losses.
            - ``accuracy``, ``f1``: the final epoch's metrics (0.0 if nothing
              was compared or ``num_epochs`` is 0).
    """
    train_loss = []
    accuracy, f1 = 0.0, 0.0
    num_batches = len(train_loader)

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        all_train_labels = []
        all_train_preds = []
        for images, targets in train_loader:
            images = [image.to(device) for image in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            optimizer.zero_grad()
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())

            losses.backward()
            optimizer.step()
            running_loss += losses.item()

            # Collect labels and predictions for F1 score and accuracy
            model.eval()
            with torch.no_grad():
                outputs = model(images)
            model.train()

            for target, output in zip(targets, outputs):
                labels = target['labels'].cpu().numpy()
                preds = output['labels'].cpu().numpy()
                # sklearn metrics require equal-length label sequences.
                n = min(len(labels), len(preds))
                all_train_labels.extend(labels[:n])
                all_train_preds.extend(preds[:n])

        avg_loss = running_loss / num_batches if num_batches else 0.0
        if all_train_labels:
            accuracy = accuracy_score(all_train_labels, all_train_preds)
            f1 = f1_score(all_train_labels, all_train_preds, average='weighted', zero_division=0)
        else:
            accuracy, f1 = 0.0, 0.0

        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}, F1 Score: {f1:.4f}")
        train_loss.append(avg_loss)

    return model, train_loss, accuracy, f1


# Validation function
def validate_model(model, val_loader, device):
    """Compute mean validation loss plus crude label accuracy / weighted F1.

    torchvision detectors return a loss dict only in training mode; in eval
    mode ``model(images, targets)`` returns detections. The loss pass
    therefore runs in training mode under ``torch.no_grad()`` with BatchNorm
    layers held in eval mode, so running statistics are not updated.
    Detections come from an eval-mode pass. Per image, the first
    ``min(len(true), len(pred))`` labels are compared (no box matching).

    Args:
        model: detection model (already on ``device``).
        val_loader: yields ``(images, targets)`` batches.
        device: device to run on.

    Returns:
        ``(avg_val_loss, accuracy, f1)``. Each is 0.0 when there was nothing
        to average. The model is left in eval mode.
    """
    val_loss = []
    all_val_labels = []
    all_val_preds = []

    def _loss_mode():
        # Training-mode forward so losses are returned, with frozen BatchNorm.
        model.train()
        for module in model.modules():
            if isinstance(module, nn.modules.batchnorm._BatchNorm):
                module.eval()

    with torch.no_grad():
        for images, targets in val_loader:
            images = [image.to(device) for image in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            _loss_mode()
            loss_dict = model(images, targets)
            val_loss.append(float(sum(loss for loss in loss_dict.values())))

            # Calculate accuracy and F1 score for validation
            model.eval()
            outputs = model(images)
            for target, output in zip(targets, outputs):
                labels = target['labels'].cpu().numpy()
                preds = output['labels'].cpu().numpy()
                n = min(len(labels), len(preds))
                all_val_labels.extend(labels[:n])
                all_val_preds.extend(preds[:n])

    model.eval()
    avg_val_loss = float(np.mean(val_loss)) if val_loss else 0.0
    if all_val_labels:
        accuracy = accuracy_score(all_val_labels, all_val_preds)
        f1 = f1_score(all_val_labels, all_val_preds, average='weighted', zero_division=0)
    else:
        accuracy, f1 = 0.0, 0.0

    return avg_val_loss, accuracy, f1


# Plotting function
def plot_metrics(train_loss, val_loss, train_accuracy, val_accuracy, train_f1, val_f1, num_epochs):
    """Show two side-by-side panels: losses (left), accuracy and F1 curves (right).

    Every series must contain ``num_epochs`` values, plotted against epochs
    1..num_epochs.
    """
    epochs = range(1, num_epochs + 1)
    panels = (
        ('Loss', (('Training Loss', train_loss),
                  ('Validation Loss', val_loss))),
        ('Metric', (('Training Accuracy', train_accuracy),
                    ('Validation Accuracy', val_accuracy),
                    ('Training F1 Score', train_f1),
                    ('Validation F1 Score', val_f1))),
    )

    plt.figure(figsize=(12, 6))
    for position, (y_label, series) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        for series_label, values in series:
            plt.plot(epochs, values, label=series_label)
        plt.xlabel('Epochs')
        plt.ylabel(y_label)
        plt.legend()

    plt.show()


# Video Processing and Inference Function
def process_video_frames(model, video_path, output_path, device, threshold=0.5,
                         mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Run the detector on each frame of a video and write an annotated mp4.

    Preprocessing is deterministic (``to_tensor`` + optional ``normalize``).
    The global training ``transform`` is not used because training pipelines
    may contain random augmentation. The output keeps the input's frame size
    and FPS; OpenCV's VideoWriter drops frames whose size differs from the
    size it was opened with.

    Args:
        model: detection model returning dicts with 'boxes', 'labels', 'scores'.
        video_path: input video readable by OpenCV.
        output_path: destination mp4 path.
        device: device to run inference on.
        threshold: detections with score strictly above this are drawn.
        mean, std: normalization matching the training transform. Pass
            ``mean=None`` to skip it.
    """
    model.eval()
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Error: Could not open video {video_path}")
        return

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    if not fps or fps <= 0:
        fps = 20.0  # fall back to the previous fixed rate when FPS is unknown
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            image_tensor = transforms.functional.to_tensor(image)
            if mean is not None and std is not None:
                image_tensor = transforms.functional.normalize(image_tensor, mean=list(mean), std=list(std))

            with torch.no_grad():
                prediction = model([image_tensor.to(device)])[0]

            boxes = prediction['boxes'].cpu().tolist()
            labels = prediction['labels'].cpu().tolist()
            scores = prediction['scores'].cpu().tolist()

            for box, label, score in zip(boxes, labels, scores):
                if score > threshold:
                    x1, y1, x2, y2 = box
                    cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
                    cv2.putText(frame, f'ID: {label} - {score:.2f}', (int(x1), int(y1 - 5)),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

            out.write(frame)
    finally:
        cap.release()
        out.release()


# Main function to start training and validation
if __name__ == "__main__":
    num_epochs = 5
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model = create_fasterrcnn_model(num_classes=91, pretrained_backbone=True)
    model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    # Train the model
    trained_model, train_loss, train_accuracy, train_f1 = train_model(
        model, train_loader, optimizer, device, num_epochs=num_epochs)

    # Save immediately so the weights survive a failure in the later steps
    # (e.g. a missing input video).
    torch.save(trained_model.state_dict(), 'trained_model.pth')

    # Validate the model
    val_loss, val_accuracy, val_f1 = validate_model(trained_model, val_loader, device)

    # Only train_loss is per-epoch; the other metrics are single final values
    # repeated so every curve has num_epochs points.
    plot_metrics(train_loss, [val_loss] * num_epochs, [train_accuracy] * num_epochs,
                 [val_accuracy] * num_epochs, [train_f1] * num_epochs, [val_f1] * num_epochs,
                 num_epochs)

    # Process a video for object detection
    process_video_frames(trained_model, 'input_video.mp4', 'output_video.mp4', device)
    print("Video processing completed!")

# NOTE(review): a second `if __name__ == "__main__": main()` guard was removed.
# This cell defines no main(). In the notebook the name resolved to the DAVIS
# main() from an earlier cell and re-ran that whole pipeline; as a standalone
# script it raised NameError. The COCO entry point is the
# `if __name__ == "__main__":` block directly above.

In [None]:
# This cell is an updated version of the original code, adding a validation function, metric plotting, video processing, and other improvements.

# Install required packages
!pip install torch torchvision scipy numpy pillow matplotlib opencv-python torchaudio
!pip install davis2017-evaluation

# # Download and extract DAVIS 2017 dataset
!wget -O DAVIS-2017-trainval-480p.zip https://data.vision.ee.ethz.ch/csergi/share/davis/DAVIS-2017-trainval-480p.zip
!unzip -q DAVIS-2017-trainval-480p.zip -d /content/DAVIS

import os
import cv2
import numpy as np
from PIL import Image
import torch
import torchvision
import os
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.ops import MultiScaleRoIAlign
import torch.nn as nn
import torch.optim as optim
from torchvision.transforms import functional as F
from IPython.display import Image as IPImage, display
from google.colab import files
import shutil
import matplotlib.pyplot as plt
import random
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
import warnings
import time


# Define the DAVIS Dataset class
class DAVISDataset(Dataset):
    """DAVIS 2017 frames paired with per-object bounding boxes derived from masks.

    Each sample is one annotated frame. Every object id in the frame's mask
    yields one (xmin, ymin, xmax, ymax) box. All boxes in a frame get the class
    id that ``class_map`` assigns to the frame's sequence (1 if the sequence is
    not listed).

    Args:
        davis_root: root of the extracted dataset; must contain
            ``JPEGImages/480p``, ``Annotations/480p`` and ``ImageSets/2017``.
        subset: split file name in ``ImageSets/2017`` (e.g. 'train', 'val').
        transform: optional callable ``(image, target) -> (image, target)``.
    """

    # Palette index DAVIS uses for void/unlabelled pixels (treated as background).
    VOID_ID = 255

    def __init__(self, davis_root, subset='val', transform=None):
        self.davis_root = davis_root
        self.subset = subset
        self.transform = transform
        self.img_dir = os.path.join(davis_root, 'JPEGImages', '480p')
        self.anno_dir = os.path.join(davis_root, 'Annotations', '480p')
        with open(os.path.join(davis_root, 'ImageSets', '2017', f'{subset}.txt'), 'r') as f:
            self.sequences = f.read().splitlines()
        self.class_map = {
            'bear': 1, 'bmx-bumps': 2, 'boat': 3, 'boxing-fisheye': 4, 'breakdance-flare': 5,
            'bus': 6, 'car-turn': 7, 'cat-girl': 8, 'classic-car': 9, 'color-run': 10,
            'crossing': 11, 'dance-jump': 12, 'dancing': 13, 'disc-jockey': 14, 'dog-agility': 15,
            'dog-gooses': 16, 'dogs-scale': 17, 'drift-turn': 18, 'drone': 19, 'elephant': 20,
            'flamingo': 21, 'hike': 22, 'hockey': 23, 'horsejump-low': 24, 'kid-football': 25,
            'kite-walk': 26, 'koala': 27, 'lady-running': 28, 'lindy-hop': 29, 'longboard': 30,
            'lucia': 31, 'mallard-fly': 32, 'mallard-water': 33, 'miami-surf': 34, 'motocross-bumps': 35,
            'motorbike': 36, 'night-race': 37, 'paragliding': 38, 'planes-water': 39, 'rallye': 40,
            'rhino': 41, 'rollerblade': 42, 'schoolgirls': 43, 'scooter-board': 44, 'scooter-gray': 45,
            'sheep': 46, 'skate-park': 47, 'snowboard': 48, 'soccerball': 49, 'stroller': 50,
            'stunt': 51, 'surf': 52, 'swing': 53, 'tennis': 54, 'tractor-sand': 55,
            'train': 56, 'tuk-tuk': 57, 'upside-down': 58, 'varanus-cage': 59, 'walking': 60,
            'bike-packing': 61, 'blackswan': 62, 'bmx-trees': 63, 'breakdance': 64, 'camel': 65,
            'car-roundabout': 66, 'car-shadow': 67, 'cows': 68, 'dance-twirl': 69, 'dog': 70,
            'dogs-jump': 71, 'drift-chicane': 72, 'drift-straight': 73, 'goat': 74, 'gold-fish': 75,
            'horsejump-high': 76, 'india': 77, 'judo': 78, 'kite-surf': 79, 'lab-coat': 80,
            'libby': 81, 'loading': 82, 'mbike-trick': 83, 'motocross-jump': 84, 'paragliding-launch': 85,
            'parkour': 86, 'pigs': 87, 'scooter-black': 88, 'shooting': 89, 'soapbox': 90
        }
        self.class_names = {v: k for k, v in self.class_map.items()}
        self.samples = []
        for seq in self.sequences:
            img_seq_dir = os.path.join(self.img_dir, seq)
            anno_seq_dir = os.path.join(self.anno_dir, seq)
            for img_file in sorted(os.listdir(img_seq_dir)):
                if not img_file.endswith('.jpg'):
                    continue
                frame_num = img_file.split('.')[0]
                anno_file = f"{frame_num}.png"
                # Only frames with an annotation mask become samples.
                if os.path.exists(os.path.join(anno_seq_dir, anno_file)):
                    self.samples.append((
                        os.path.join(img_seq_dir, img_file),
                        os.path.join(anno_seq_dir, anno_file),
                        seq
                    ))

    def __len__(self):
        return len(self.samples)

    def _object_ids(self, mask_path):
        """Return ``(mask_np, obj_ids)``: a 2-D id map and its non-background ids."""
        mask = Image.open(mask_path)
        if mask.mode == 'P':
            # Palette PNG: pixel values are the object indices themselves.
            # Converting to 'L' would map them to palette-colour luminance.
            mask_np = np.array(mask)
            obj_ids = np.unique(mask_np)
            obj_ids = obj_ids[(obj_ids != 0) & (obj_ids != self.VOID_ID)]
        else:
            mask_np = np.array(mask.convert('L'))
            obj_ids = np.unique(mask_np)
            obj_ids = obj_ids[obj_ids != 0]
        return mask_np, obj_ids

    def __getitem__(self, idx):
        img_path, mask_path, seq_name = self.samples[idx]
        image = Image.open(img_path).convert('RGB')
        mask_np, obj_ids = self._object_ids(mask_path)

        class_label = self.class_map.get(seq_name, 1)
        boxes = []
        for obj_id in obj_ids:
            ys, xs = np.nonzero(mask_np == obj_id)
            xmin, xmax = int(xs.min()), int(xs.max())
            ymin, ymax = int(ys.min()), int(ys.max())
            # Skip objects that collapse to a zero-width/height box.
            if xmax > xmin and ymax > ymin:
                boxes.append([xmin, ymin, xmax, ymax])
        labels = [class_label] * len(boxes)

        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        target = {
            'boxes': boxes,
            'labels': labels,
            'image_id': torch.tensor([idx]),
            'area': area,
            'iscrowd': torch.zeros((len(boxes),), dtype=torch.int64)
        }
        if self.transform:
            image, target = self.transform(image, target)
        return image, target

# Define transformation classes
class Compose:
    """Apply a list of pair-wise ``(image, target)`` transforms one after another."""

    def __init__(self, transforms):
        self.transforms = transforms  # applied in list order

    def __call__(self, image, target):
        img, tgt = image, target
        for op in self.transforms:
            img, tgt = op(img, tgt)
        return img, tgt

class ToTensor:
    """Pair-wise wrapper around ``F.to_tensor``; the target is returned as-is."""

    def __call__(self, image, target):
        tensor_image = F.to_tensor(image)
        return tensor_image, target

class Normalize:
    """Pair-wise wrapper around ``F.normalize`` with fixed per-channel mean/std."""

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image, target):
        return F.normalize(image, mean=self.mean, std=self.std), target

class RandomHorizontalFlip:
    """Randomly mirror an image tensor and reflect its (x1, y1, x2, y2) target boxes.

    ``target["boxes"]`` is modified in place when a flip happens.
    """

    def __init__(self, prob=0.5):
        self.prob = prob  # probability of flipping

    def __call__(self, image, target):
        should_flip = random.random() < self.prob
        if should_flip:
            _, img_width = image.shape[-2:]
            image = F.hflip(image)
            flipped = target["boxes"]
            flipped[:, [0, 2]] = img_width - flipped[:, [2, 0]]
            target["boxes"] = flipped
        return image, target

# Define the Faster R-CNN model
def create_fasterrcnn_model(num_classes, pretrained_backbone=True):
    """Return a Faster R-CNN built on a truncated ResNet-50 backbone.

    The ResNet-50's final pooling and fully-connected layers are removed, so
    the detector sees one feature map (named '0') with 2048 channels.

    Args:
        num_classes: number of classes (including background).
        pretrained_backbone: load torchvision's default ResNet-50 weights if True.
    """
    import torchvision

    resnet = torchvision.models.resnet50(weights='DEFAULT' if pretrained_backbone else None)
    body = nn.Sequential(*list(resnet.children())[:-2])
    body.out_channels = 2048

    rpn_anchors = AnchorGenerator(
        sizes=((32, 64, 128, 256, 512),),
        aspect_ratios=((0.5, 1.0, 2.0),),
    )
    roi_align = MultiScaleRoIAlign(featmap_names=['0'], output_size=7, sampling_ratio=2)

    return FasterRCNN(
        backbone=body,
        num_classes=num_classes,
        rpn_anchor_generator=rpn_anchors,
        box_roi_pool=roi_align,
        min_size=800,
        max_size=1333,
    )

def compute_iou(box1, box2):
    """Return the intersection-over-union of two (x1, y1, x2, y2) boxes.

    Elements may be Python numbers or 0-d tensors. Negative extents are
    clamped to zero, and 0 is returned when the union area is zero.
    """
    ax1, ay1, ax2, ay2 = box1[0], box1[1], box1[2], box1[3]
    bx1, by1, bx2, by2 = box2[0], box2[1], box2[2], box2[3]

    overlap_w = max(0, min(ax2, bx2) - max(ax1, bx1))
    overlap_h = max(0, min(ay2, by2) - max(ay1, by1))
    intersection = overlap_w * overlap_h

    area_a = max(0, ax2 - ax1) * max(0, ay2 - ay1)
    area_b = max(0, bx2 - bx1) * max(0, by2 - by1)
    union = area_a + area_b - intersection
    if union == 0:
        return 0
    return intersection / union

# Enhanced training function with F1, accuracy, and loss tracking
def train_model(model, data_loader, optimizer, num_epochs, device):
    """Train a detector and report IoU-matched label metrics for each epoch.

    After each optimisation step the model is switched to eval mode to get
    detections for the same batch. Each ground-truth box is paired with the
    predicted box of highest IoU (``compute_iou``). Pairs with IoU > 0.5
    contribute a (true label, predicted label) pair to the epoch's accuracy
    and weighted precision / recall / F1. Images with no predictions or no
    targets are skipped with a warning.

    Args:
        model: detection model (moved to ``device`` here).
        data_loader: yields ``(images, targets)`` batches.
        optimizer: optimizer over the model's parameters.
        num_epochs: number of passes over ``data_loader``.
        device: device to run on.

    Returns:
        Tuple of per-epoch lists: ``(loss, accuracy, precision, recall, f1)``.
        Metrics are 0.0 for epochs with no matched pairs.
    """
    model.to(device)
    loss_history, acc_history, precision_history, recall_history, f1_history = [], [], [], [], []
    num_batches = len(data_loader)

    for epoch in range(num_epochs):
        model.train()
        epoch_running_loss = 0.0
        all_preds, all_targets = [], []

        for images, targets in data_loader:
            images = [img.to(device) for img in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            loss_dict = model(images, targets)
            total_loss = sum(loss for loss in loss_dict.values())

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            epoch_running_loss += total_loss.item()

            # Prediction for metrics (train-mode forwards require targets).
            model.eval()
            with torch.no_grad():
                outputs = model(images)
            model.train()

            for output, target in zip(outputs, targets):
                if len(output['boxes']) == 0 or len(target['boxes']) == 0:
                    warnings.warn("Empty prediction or target, skipping this sample.")
                    continue

                # Plain Python floats keep the nested IoU loop cheap.
                pred_boxes = output['boxes'].cpu().tolist()
                pred_labels = output['labels'].cpu().tolist()
                tgt_boxes = target['boxes'].cpu().tolist()
                tgt_labels = target['labels'].cpu().tolist()

                # Simple match logic using IoU > 0.5
                for tgt_box, tgt_label in zip(tgt_boxes, tgt_labels):
                    best_iou, best_idx = 0, -1
                    for j, pred_box in enumerate(pred_boxes):
                        iou = compute_iou(tgt_box, pred_box)
                        if iou > best_iou:
                            best_iou, best_idx = iou, j
                    if best_iou > 0.5:
                        all_preds.append(pred_labels[best_idx])
                        all_targets.append(tgt_label)

        # Compute metrics
        epoch_loss = epoch_running_loss / num_batches if num_batches else 0.0
        if all_preds:
            epoch_acc = accuracy_score(all_targets, all_preds)
            epoch_precision = precision_score(all_targets, all_preds, average='weighted', zero_division=0)
            epoch_recall = recall_score(all_targets, all_preds, average='weighted', zero_division=0)
            epoch_f1 = f1_score(all_targets, all_preds, average='weighted', zero_division=0)
        else:
            epoch_acc = epoch_precision = epoch_recall = epoch_f1 = 0.0

        loss_history.append(epoch_loss)
        acc_history.append(epoch_acc)
        precision_history.append(epoch_precision)
        recall_history.append(epoch_recall)
        f1_history.append(epoch_f1)

        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Acc: {epoch_acc:.4f}, Precision: {epoch_precision:.4f}, Recall: {epoch_recall:.4f}, F1: {epoch_f1:.4f}")

    return loss_history, acc_history, precision_history, recall_history, f1_history

# Validation function: evaluate a trained model on a held-out data loader
def _validation_loss_forward(model, images, targets):
    """Return the model's loss dict for one batch without updating normalization statistics.

    torchvision detection models only return a loss dict in training mode (in eval
    mode they return detections), so the model is temporarily switched to train mode.
    BatchNorm and Dropout layers are put back into eval mode so that validation does
    not modify running statistics or inject dropout noise. The model's previous mode
    is restored on exit, even if the forward pass raises.
    """
    was_training = model.training
    model.train()
    for module in model.modules():
        if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.Dropout)):
            module.eval()
    try:
        return model(images, targets)
    finally:
        model.train(was_training)


def _match_labels_by_iou(output, target, iou_threshold=0.5):
    """Pair each ground-truth label with the label of its highest-IoU prediction.

    A pair is kept only when that best IoU (computed with ``compute_iou``) is strictly
    greater than ``iou_threshold``.

    Returns:
        ``(matched_preds, matched_targets)``: two equal-length lists of Python ints.
    """
    pred_boxes = output['boxes'].cpu()
    pred_labels = output['labels'].cpu()
    tgt_boxes = target['boxes'].cpu()
    tgt_labels = target['labels'].cpu()

    matched_preds, matched_targets = [], []
    for tgt_box, tgt_label in zip(tgt_boxes, tgt_labels):
        best_iou, best_idx = 0, -1
        for j, pred_box in enumerate(pred_boxes):
            iou = compute_iou(tgt_box, pred_box)
            if iou > best_iou:
                best_iou, best_idx = iou, j
        if best_iou > iou_threshold:
            matched_preds.append(pred_labels[best_idx].item())
            matched_targets.append(tgt_label.item())
    return matched_preds, matched_targets


def validate_model(model, data_loader, device, iou_threshold=0.5):
    """Evaluate a detection model on a validation loader.

    Each batch gets two forward passes under ``torch.no_grad()``:
      * an eval-mode pass that produces detections, and
      * a loss pass (see ``_validation_loss_forward``) whose loss terms are summed.

    For metrics, every ground-truth box is matched to the prediction with the highest
    IoU; if that IoU exceeds ``iou_threshold`` the (predicted label, true label) pair is
    recorded. Images with no predictions or no targets are skipped for the metrics
    (with a warning) but still contribute to the loss.

    Args:
        model: detection model that returns a dict of loss tensors when called as
            ``model(images, targets)`` in training mode, and a list of dicts with
            'boxes' and 'labels' when called as ``model(images)`` in eval mode.
        data_loader: iterable of ``(images, targets)`` batches; targets are dicts of tensors.
        device: device to run on.
        iou_threshold: minimum IoU for a ground-truth/prediction match (default 0.5).

    Returns:
        ``(val_loss, accuracy, precision, recall, f1)``. ``val_loss`` is the mean
        per-batch loss (0.0 for an empty loader). The metrics are weighted averages over
        matched boxes, and all are 0.0 when nothing matched. The model is left in eval mode.
    """
    model.eval()
    model.to(device)
    all_preds, all_targets = [], []
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for images, targets in data_loader:
            images = [img.to(device) for img in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            # Detections come from the eval-mode forward pass.
            outputs = model(images)

            # Losses need a separate pass, because eval mode ignores targets.
            loss_dict = _validation_loss_forward(model, images, targets)
            total_loss += float(sum(loss_dict.values()))
            num_batches += 1

            for output, target in zip(outputs, targets):
                if len(output['boxes']) == 0 or len(target['boxes']) == 0:
                    warnings.warn("Empty prediction or target in validation.")
                    continue
                matched_preds, matched_targets = _match_labels_by_iou(output, target, iou_threshold)
                all_preds.extend(matched_preds)
                all_targets.extend(matched_targets)

    # Average per batch, matching how train_model reports its epoch loss.
    val_loss = total_loss / num_batches if num_batches else 0.0

    if all_preds:
        # Imported locally so this function does not depend on a module-level import of these names.
        from sklearn.metrics import precision_score, recall_score
        epoch_acc = accuracy_score(all_targets, all_preds)
        epoch_precision = precision_score(all_targets, all_preds, average='weighted', zero_division=0)
        epoch_recall = recall_score(all_targets, all_preds, average='weighted', zero_division=0)
        epoch_f1 = f1_score(all_targets, all_preds, average='weighted', zero_division=0)
    else:
        epoch_acc = epoch_precision = epoch_recall = epoch_f1 = 0.0

    print(f"Validation - Loss: {val_loss:.4f}, Accuracy: {epoch_acc:.4f}, Precision: {epoch_precision:.4f}, Recall: {epoch_recall:.4f}, F1: {epoch_f1:.4f}")

    return val_loss, epoch_acc, epoch_precision, epoch_recall, epoch_f1


def plot_training_curves(loss_history, acc_history, precision_history, recall_history, f1_history):
    """Draw every per-epoch training metric on a single shared axis and show the figure.

    Each history holds one value per epoch; the x axis numbers epochs from 1.
    """
    x_axis = range(1, len(loss_history) + 1)
    series = (
        (loss_history, 'r-', 'Loss'),
        (acc_history, 'b--', 'Accuracy'),
        (precision_history, 'g-.', 'Precision'),
        (recall_history, 'm-', 'Recall'),
        (f1_history, 'c-', 'F1 Score'),
    )

    plt.figure(figsize=(10, 6))
    for values, line_format, legend_label in series:
        plt.plot(x_axis, values, line_format, label=legend_label)

    plt.xlabel("Epochs")
    plt.ylabel("Metric")
    plt.title("Training Metrics Over Epochs")
    plt.legend()
    plt.grid(True)
    plt.show()

def plot_validation_metrics(val_losses, val_accuracies, val_precision, val_recall, val_f1_scores):
    """Plot per-epoch validation metrics in a row of subplots and show the figure.

    Panels, left to right: loss, accuracy, precision and recall (sharing one panel),
    and F1. Each argument holds one value per epoch; epochs are numbered from 1.
    Previously ``val_precision`` and ``val_recall`` were accepted but never plotted.
    """
    epochs = list(range(1, len(val_losses) + 1))
    # (title, y label, [(values, line format, legend label or None), ...])
    panels = [
        ("Validation Loss", "Loss", [(val_losses, 'r-o', None)]),
        ("Validation Accuracy", "Accuracy", [(val_accuracies, 'g-o', None)]),
        ("Validation Precision & Recall", "Score",
         [(val_precision, 'm-o', 'Precision'), (val_recall, 'c-s', 'Recall')]),
        ("Validation F1 Score", "F1 Score", [(val_f1_scores, 'b-o', None)]),
    ]

    plt.figure(figsize=(16, 4))
    for position, (title, y_label, curves) in enumerate(panels, start=1):
        plt.subplot(1, len(panels), position)
        for values, line_format, legend_label in curves:
            plt.plot(epochs, values, line_format, label=legend_label)
        # A legend is only needed for the panel that overlays several curves.
        if any(legend_label for _, _, legend_label in curves):
            plt.legend()
        plt.title(title)
        plt.xlabel("Epoch")
        plt.ylabel(y_label)

    plt.tight_layout()
    plt.show()


# Function to extract frames
def extract_frames(video_path, output_dir):
    """Decode a video and write every frame to ``output_dir`` as a JPEG.

    Frames are saved as ``frame_000001.jpg``, ``frame_000002.jpg``, … (numbered from 1).
    The directory is created if it does not exist.

    Args:
        video_path: path to a video file readable by OpenCV.
        output_dir: directory that receives the extracted frames.

    Returns:
        ``(frame_paths, width, height, fps)`` on success. ``width``, ``height`` and
        ``fps`` are read from the capture properties (fps truncated to ``int``).
        Returns ``None`` if the video cannot be opened, a frame cannot be written, or
        decoding raises, so callers must check for ``None`` before unpacking.
    """
    os.makedirs(output_dir, exist_ok=True)
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            print(f"Error: Could not open video {video_path}")
            return None
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = int(cap.get(cv2.CAP_PROP_FPS))
        # Only used for the progress message below.
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_paths = []
        frame_count = 0
        try:
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                frame_count += 1
                frame_path = os.path.join(output_dir, f"frame_{frame_count:06d}.jpg")
                # imwrite reports failure by returning False; never return a path that was not written.
                if not cv2.imwrite(frame_path, frame):
                    print(f"Error: Could not write frame {frame_path}")
                    return None
                frame_paths.append(frame_path)
                print(f"Extracting frame {frame_count}/{total_frames}", end='\r')
        except Exception as e:
            print(f"Error during frame extraction: {str(e)}")
            return None
    finally:
        # Release the capture on every exit path, including early returns.
        cap.release()
    print(f"\nExtracted {frame_count} frames to {output_dir}")
    return frame_paths, width, height, fps

# Function to visualize detections
def visualize_detections(image, prediction, class_names, threshold=0.3):
    """Draw boxes and captions for confident detections onto a de-normalized copy of ``image``.

    Args:
        image: ``(C, H, W)`` tensor with 3 channels. It is assumed to be normalized with
            the ImageNet mean/std used by the ``Normalize`` transforms in this file, and
            this function undoes that normalization.
        prediction: dict with 'boxes' (rows of ``x1, y1, x2, y2``), 'labels' and 'scores'.
        class_names: mapping from label id to display name; unknown ids are shown as 'Unknown'.
        threshold: minimum score for a detection to be drawn.

    Returns:
        ``(H, W, 3)`` contiguous uint8 array with boxes and ``name: score`` captions drawn
        in colour ``(255, 0, 0)``.
    """
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    image_np = image.detach().permute(1, 2, 0).cpu().numpy()
    # Undo Normalize and scale to [0, 255]. Clip before the uint8 cast, because
    # out-of-range values would otherwise wrap around and produce speckle artefacts.
    image_np = np.clip((image_np * std + mean) * 255, 0, 255)
    # OpenCV drawing functions need a C-contiguous, writable array.
    image_np = np.ascontiguousarray(image_np.astype(np.uint8))

    for box, label, score in zip(prediction['boxes'].cpu().numpy(),
                                 prediction['labels'].cpu().numpy(),
                                 prediction['scores'].cpu().numpy()):
        if score < threshold:
            continue
        x1, y1, x2, y2 = box.astype(int)
        class_name = class_names.get(label, 'Unknown')
        cv2.rectangle(image_np, (x1, y1), (x2, y2), (255, 0, 0), 2)
        # Keep the caption inside the image for boxes touching the top edge.
        text_y = max(y1 - 10, 10)
        cv2.putText(image_np, f'{class_name}: {score:.2f}', (x1, text_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)
    return image_np

# Function to process video frames
def process_video_frames(model, video_path, class_names, device, frames_dir='/content/frames', output_frames_dir='/content/output_frames', threshold=0.3):
    """Run the detector on every frame of a video and save annotated frames.

    Pipeline:
      1. Delete and recreate ``frames_dir`` and ``output_frames_dir``.
      2. Extract frames with ``extract_frames``.
      3. Normalize each frame, run ``model`` on it, and draw detections with
         score >= ``threshold`` via ``visualize_detections``.
      4. Write each result to ``output_frames_dir/annotated_frame_000001.jpg``, ….
      5. Save up to five evenly spaced annotated frames to
         ``/content/sample_frame_<i>.png`` and display them inline.
      6. Remove ``frames_dir``.

    Args:
        model: detection model that takes a list of ``(C, H, W)`` tensors and returns a
            list of dicts with 'boxes', 'labels' and 'scores'.
        video_path: path to the input video.
        class_names: mapping from label id to display name.
        device: ``torch.device`` or device string (e.g. ``'cuda'``).
        frames_dir: scratch directory for raw frames (deleted afterwards).
        output_frames_dir: directory for annotated frames.
        threshold: minimum detection score to draw.

    Returns:
        ``output_frames_dir`` on success, ``None`` on any failure.
    """
    device = torch.device(device)
    if os.path.exists(frames_dir):
        shutil.rmtree(frames_dir)
    if os.path.exists(output_frames_dir):
        shutil.rmtree(output_frames_dir)
    os.makedirs(frames_dir)
    os.makedirs(output_frames_dir)

    # extract_frames returns None (not a 4-tuple) on failure, so check before unpacking.
    extracted = extract_frames(video_path, frames_dir)
    if extracted is None:
        print("Error: Frame extraction failed.")
        return None
    frame_paths = extracted[0]  # width/height/fps are not needed here
    total_frames = len(frame_paths)
    if total_frames == 0:
        print("Error: No frames extracted from video.")
        return None

    transform = Compose([
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    model.eval()
    sample_frames = []
    # Keep every sample_every-th frame as a preview, at most five of them.
    sample_every = total_frames // 5 + 1
    frame_count = 0
    try:
        with torch.no_grad():
            for frame_count, frame_path in enumerate(frame_paths, start=1):
                print(f"Processing frame {frame_count}/{total_frames}", end='\r')
                frame_bgr = cv2.imread(frame_path)
                if frame_bgr is None:
                    raise IOError(f"Could not read frame {frame_path}")
                image = Image.fromarray(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB))
                image_tensor, _ = transform(image, {})
                image_tensor = image_tensor.to(device)
                # Pass a list of images; the model returns one prediction dict per image.
                outputs = model([image_tensor])[0]
                annotated_frame = visualize_detections(image_tensor, outputs, class_names, threshold)
                output_frame_path = os.path.join(output_frames_dir, f"annotated_frame_{frame_count:06d}.jpg")
                cv2.imwrite(output_frame_path, cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR))
                if frame_count % sample_every == 0 and len(sample_frames) < 5:
                    sample_frames.append(annotated_frame)
                del image_tensor, outputs
                if device.type == 'cuda':
                    torch.cuda.empty_cache()
    except Exception as e:
        print(f"Error during frame processing: {str(e)}")
        # Do not leave the raw-frame scratch directory behind on failure.
        shutil.rmtree(frames_dir, ignore_errors=True)
        return None

    print(f"\nProcessed {frame_count} frames. Annotated frames saved to {output_frames_dir}")
    for i, frame in enumerate(sample_frames):
        sample_path = f'/content/sample_frame_{i}.png'
        cv2.imwrite(sample_path, cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
        display(IPImage(sample_path))
    shutil.rmtree(frames_dir)
    print("Cleaned up temporary frames directory")
    return output_frames_dir

# Main function
def main():
    """Train Faster R-CNN on DAVIS 2017, validate every epoch, then annotate an uploaded video.

    Steps:
      1. Seed the RNGs and pick a device.
      2. Alternate one training epoch (``train_model``) with one validation pass
         (``validate_model``) on the DAVIS ``val`` split, then plot both histories.
      3. Save the weights and reload them into a fresh model. If the file is missing,
         ask the user to upload one.
      4. Run detection on every frame of an uploaded video and zip the annotated frames.
    """
    torch.manual_seed(42)
    random.seed(42)
    np.random.seed(42)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    davis_root = '/content/DAVIS/DAVIS'
    weights_path = '/content/fasterrcnn_resnet50_davis.pth'
    # NOTE(review): 91 is the COCO class count; confirm it exceeds every id in DAVISDataset.class_map.
    num_classes = 91
    num_epochs = 5

    train_transform = Compose([
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        RandomHorizontalFlip(prob=0.5)
    ])
    # Validation uses the same preprocessing without random augmentation.
    eval_transform = Compose([
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    train_dataset = DAVISDataset(davis_root, subset='train', transform=train_transform)
    val_dataset = DAVISDataset(davis_root, subset='val', transform=eval_transform)
    # Batches are kept as tuples of images / target dicts instead of being stacked into tensors.
    train_loader = DataLoader(
        train_dataset,
        batch_size=2,
        shuffle=True,
        num_workers=2,
        collate_fn=lambda x: tuple(zip(*x))
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=2,
        shuffle=False,
        num_workers=2,
        collate_fn=lambda x: tuple(zip(*x))
    )

    # ---- Training, validating after every epoch ----
    model = create_fasterrcnn_model(num_classes=num_classes)
    model.to(device)
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)

    # Both histories are ordered (loss, accuracy, precision, recall, F1),
    # the same order train_model and validate_model return them in.
    train_history = ([], [], [], [], [])
    val_history = ([], [], [], [], [])
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch + 1}/{num_epochs}")
        # validate_model leaves the model in eval mode; training needs train mode.
        model.train()
        # train_model returns per-epoch metric lists (not the model), one entry per epoch run.
        epoch_train = train_model(model, train_loader, optimizer, device, num_epochs=1)
        for history, values in zip(train_history, epoch_train):
            history.extend(values)
        for history, value in zip(val_history, validate_model(model, val_loader, device)):
            history.append(value)

    torch.save(model.state_dict(), weights_path)
    print("Training completed!")
    plot_training_curves(*train_history)
    plot_validation_metrics(*val_history)

    # ---- Inference on an uploaded video ----
    # visualize_detections looks names up by label id, i.e. the inverse of class_map.
    class_names = getattr(val_dataset, 'class_names', None)
    if class_names is None:
        class_names = {label_id: name for name, label_id in val_dataset.class_map.items()}

    model = create_fasterrcnn_model(num_classes=num_classes)
    try:
        model.load_state_dict(torch.load(weights_path, map_location=device))
    except FileNotFoundError:
        print("Error: Model weights file 'fasterrcnn_resnet50_davis.pth' not found.")
        print("Please upload the trained model or uncomment the training section to train the model.")
        print("To proceed without training, upload 'fasterrcnn_resnet50_davis.pth' now:")
        uploaded = files.upload()
        if 'fasterrcnn_resnet50_davis.pth' in uploaded:
            model.load_state_dict(torch.load(weights_path, map_location=device))
        else:
            print("Error: Model weights not uploaded. Cannot proceed.")
            return
    model.to(device)

    print("Please upload your video file (e.g., MP4 format):")
    uploaded = files.upload()
    if not uploaded:
        print("Error: No video file uploaded.")
        return
    video_path = next(iter(uploaded))

    output_frames_dir = process_video_frames(
        model, video_path, class_names, device,
        frames_dir='/content/frames',
        output_frames_dir='/content/output_frames',
        threshold=0.3
    )

    if output_frames_dir:
        print(f"Processed frames saved to {output_frames_dir}")
        # shutil.make_archive replaces the IPython-only `!zip` magic, so this also runs as plain Python.
        output_dir = os.path.normpath(output_frames_dir)
        archive_path = shutil.make_archive(
            output_dir, 'zip',
            root_dir=os.path.dirname(output_dir),
            base_dir=os.path.basename(output_dir),
        )
        print(f"Zipped annotated frames to {archive_path}.")
        print("Download the zip file from the Colab file explorer.")
    else:
        print("Frame processing failed. Please check error messages.")

# Entry point: run the full train -> validate -> video-annotation pipeline defined in main().
if __name__ == "__main__":
    main()