<a href="https://colab.research.google.com/github/Paul-Steve-Mithun/FSL_AUTONOMOUS_DRIVING/blob/main/FSL_TEST_VIDEO.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive. The checkpoint, support set and input video used below all live
# under /content/drive/MyDrive, and the annotated output video is written there too.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Install dependencies.
# NOTE(review): versions are unpinned, so a fresh run may resolve different package
# versions than the run recorded here; consider pinning (e.g. `%pip install pkg==X.Y.Z`).
# torch-summary (imported as `torchsummary`) and easyfsl are imported in the next cell;
# torch-lr-finder and timm are not imported anywhere visible in this notebook.
!pip install torch-summary torch-lr-finder timm easyfsl
!pip install torch torchvision torchaudio
!pip install matplotlib tqdm

Collecting torch-summary
  Downloading torch_summary-1.4.5-py3-none-any.whl (16 kB)
Collecting torch-lr-finder
  Downloading torch_lr_finder-0.2.1-py3-none-any.whl (11 kB)
Collecting timm
  Downloading timm-1.0.7-py3-none-any.whl (2.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.3/2.3 MB[0m [31m38.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting easyfsl
  Downloading easyfsl-1.5.0-py3-none-any.whl (72 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m72.8/72.8 kB[0m [31m10.3 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=0.4.1->torch-lr-finder)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch>=0.4.1->torch-lr-finder)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch>=0.4.1->torch-lr-finder)
  Using cached nvidia

In [None]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from torchvision.models import resnet34
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.transforms import functional as F
from torchsummary import summary
from tqdm import tqdm
from easyfsl.samplers import TaskSampler
from google.colab.patches import cv2_imshow
from torchvision.ops import batched_nms

# Seed PyTorch's global RNG for reproducibility.
# (Only torch is seeded here; Python's `random` and NumPy are not.)
torch.manual_seed(0)

# Custom dataset class
class CustomDataset(Dataset):
    """Image-folder dataset: one sub-directory of images per class under ``data_path``.

    Class labels are assigned by the sorted order of the class directory names.
    Hidden entries (names starting with '.') and non-directory entries at the top
    level are ignored, as are hidden or non-file entries inside class directories.

    Args:
        data_path: root directory containing one sub-directory per class.
        transform: optional callable applied to each RGB PIL image in ``__getitem__``.
    """

    def __init__(self, data_path, transform=None):
        self.data_path = data_path
        self.transform = transform
        self.images = []  # absolute/joined file paths, one per sample
        self.labels = []  # integer class index for each entry of self.images

        # Only visible directories are classes. Stray files (e.g. desktop.ini) would
        # otherwise make os.listdir() below raise, and hidden folders such as
        # .ipynb_checkpoints would silently become an extra class.
        self.class_names = sorted(
            entry for entry in os.listdir(data_path)
            if not entry.startswith('.') and os.path.isdir(os.path.join(data_path, entry))
        )
        for label, class_name in enumerate(self.class_names):
            current_folder = os.path.join(data_path, class_name)
            # os.listdir order is arbitrary; sort so sample indices are reproducible.
            for file in sorted(os.listdir(current_folder)):
                fullpath = os.path.join(current_folder, file)
                if file.startswith('.') or not os.path.isfile(fullpath):
                    continue
                self.images.append(fullpath)
                self.labels.append(label)

    def __len__(self):
        """Number of samples found on disk."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``; image is transformed if a transform was given."""
        img_path = self.images[idx]
        # Context manager closes the file handle once the pixels are decoded by convert().
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label

# Module-level 224x224 image pipeline: resize, convert to a tensor, then normalise
# with the ImageNet channel mean/std.
# NOTE(review): no visible cell references this object — load_support_set and
# preprocess_image each build their own 84x84 pipeline.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Prototypical Networks class
class PrototypicalNetworks(nn.Module):
    """Prototypical Networks few-shot classifier.

    Each class prototype is the mean backbone embedding of that class's support
    images. Queries are scored by the negative Euclidean distance to each prototype,
    so a higher score means a closer prototype.

    Args:
        backbone: module mapping an image batch to an embedding batch.
    """

    def __init__(self, backbone: nn.Module):
        super().__init__()
        self.backbone = backbone

    def forward(self, support_images: torch.Tensor, support_labels: torch.Tensor, query_images: torch.Tensor) -> torch.Tensor:
        """Score every query against every class prototype.

        Args:
            support_images: batch of labelled support images.
            support_labels: 1-D integer tensor with one label per support image.
            query_images: batch of images to classify.

        Returns:
            Tensor of shape (n_query, n_classes). Column ``i`` belongs to the ``i``-th
            smallest distinct support label. That equals label ``i`` when labels are
            ``0..n_classes-1``, as produced by ``load_support_set``.
        """
        # Call the module rather than .forward() so any registered hooks run.
        z_support = self.backbone(support_images)
        z_query = self.backbone(query_images)

        # Iterate over the labels that actually occur (sorted ascending). The old
        # range(n_way) produced NaN prototypes for non-contiguous labels.
        classes = torch.unique(support_labels)
        z_proto = torch.stack([z_support[support_labels == c].mean(dim=0) for c in classes])

        dists = torch.cdist(z_query, z_proto)
        return -dists

# Load the support set
def load_support_set(data_path, n_way, n_shot):
    """Load up to ``n_shot`` images from every class folder under ``data_path``.

    Each image is converted to RGB, resized to 84x84, turned into a tensor and
    normalised with the ImageNet mean/std.

    Args:
        data_path: directory containing one sub-directory per class.
        n_way: NOTE(review): currently unused — every class directory is loaded.
        n_shot: maximum number of images taken from each class directory. Files are
            taken in sorted filename order so the same subset is used on every run.

    Returns:
        (images, labels, class_names):
        - images: stacked tensor of shape (N, 3, 84, 84).
        - labels: 1-D tensor where label ``i`` means ``class_names[i]``.
        - class_names: the sorted class folder names.

    Raises:
        RuntimeError: if no support images were found.
    """
    transform = transforms.Compose([
        transforms.Resize((84, 84)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    support_images = []
    support_labels = []
    # Only visible directories are classes; stray files or hidden folders such as
    # .ipynb_checkpoints would otherwise crash os.listdir or add a bogus class.
    class_names = sorted(
        entry for entry in os.listdir(data_path)
        if not entry.startswith('.') and os.path.isdir(os.path.join(data_path, entry))
    )

    for label, class_name in enumerate(class_names):
        current_folder = os.path.join(data_path, class_name)
        # os.listdir order is arbitrary; sort so the same n_shot files are picked each run.
        files = sorted(
            name for name in os.listdir(current_folder)
            if not name.startswith('.') and os.path.isfile(os.path.join(current_folder, name))
        )
        for file in files[:n_shot]:
            fullpath = os.path.join(current_folder, file)
            # Context manager closes the file handle once convert() has decoded the pixels.
            with Image.open(fullpath) as img:
                image = img.convert('RGB')
            support_images.append(transform(image))
            support_labels.append(label)

    if not support_images:
        raise RuntimeError(f"No support images found under {data_path}")

    return torch.stack(support_images), torch.tensor(support_labels), class_names

# Load model checkpoint
def load_checkpoint(file_path, model, optimizer=None, map_location=None):
    """Restore model (and optionally optimizer) state from a training checkpoint.

    The checkpoint must be a dict with keys ``model_state_dict``, ``epoch``,
    ``mean_accuracy`` and ``std_accuracy``. It also needs ``optimizer_state_dict``
    when an optimizer is passed.

    Args:
        file_path: path to a file written with ``torch.save``.
        model: module whose parameters are restored in place.
        optimizer: optimizer restored in place. Pass ``None`` for inference-only use;
            the stored optimizer state is then ignored.
        map_location: forwarded to ``torch.load``. Defaults to the notebook-level
            ``device``.

    Returns:
        (epoch, mean_accuracy, std_accuracy) as stored in the checkpoint.
    """
    if map_location is None:
        map_location = device
    # SECURITY: torch.load unpickles the file; only load checkpoints from trusted sources.
    checkpoint = torch.load(file_path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    mean_accuracy = checkpoint['mean_accuracy']
    std_accuracy = checkpoint['std_accuracy']
    print(f"Checkpoint loaded from epoch {epoch} with mean accuracy: {mean_accuracy:.2f}%")
    return epoch, mean_accuracy, std_accuracy

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# COCO-pretrained Faster R-CNN, used here only to propose boxes: the main loop
# discards its class ids, and the few-shot model assigns the labels.
# `weights="DEFAULT"` selects the same COCO weights as the deprecated `pretrained=True`.
detection_model = fasterrcnn_resnet50_fpn(weights="DEFAULT")
detection_model.to(device)  # Move model to the selected device
detection_model.eval()  # inference mode

# Function to get bounding boxes with confidence threshold and NMS
def get_bounding_boxes(image, confidence_threshold=0.8, nms_threshold=0.1):
    """Detect objects in a PIL image with the notebook-level Faster R-CNN model.

    Detections whose score is not strictly above ``confidence_threshold`` are
    dropped. Class-aware NMS (``batched_nms``) with IoU threshold ``nms_threshold``
    is then applied to the remainder.

    Returns:
        (boxes, labels): NumPy arrays holding the surviving ``[x1, y1, x2, y2]``
        boxes and their detector class ids.
    """
    batch = F.to_tensor(image).unsqueeze(0).to(device)

    with torch.no_grad():
        prediction = detection_model(batch)[0]

    # Work with CPU tensors throughout; convert to NumPy only for the return value.
    all_boxes = prediction['boxes'].cpu()
    all_scores = prediction['scores'].cpu()
    all_labels = prediction['labels'].cpu()

    confident = all_scores > confidence_threshold
    kept_boxes = all_boxes[confident]
    kept_scores = all_scores[confident]
    kept_labels = all_labels[confident]

    survivors = batched_nms(kept_boxes, kept_scores, kept_labels, nms_threshold)
    return kept_boxes[survivors].numpy(), kept_labels[survivors].numpy()

# Function to filter boxes by size
def filter_boxes_by_size(boxes, size_threshold=50):
    """Keep only boxes whose width AND height both exceed ``size_threshold`` pixels.

    Args:
        boxes: array-like of ``[x1, y1, x2, y2, ...]`` rows; extra columns are kept.
        size_threshold: exclusive minimum width/height in pixels.

    Returns:
        2-D NumPy array of the surviving rows. An empty result keeps a 2-D shape,
        ``(0, 4)`` for empty input, so callers can always index columns.
    """
    box_array = np.asarray(boxes)
    if box_array.size == 0:
        # np.array([]) would be 1-D (0,); keep the (0, 4) box layout instead.
        return box_array.reshape(0, 4)

    widths = box_array[:, 2] - box_array[:, 0]
    heights = box_array[:, 3] - box_array[:, 1]
    return box_array[(widths > size_threshold) & (heights > size_threshold)]

# Preprocess image with bounding box
def preprocess_image(image, bbox, target_size=(84, 84)):
    """Crop ``bbox`` out of a PIL image and return it as a normalised tensor.

    The crop is resized to ``target_size``, converted with ``ToTensor`` and
    normalised with the ImageNet mean/std. At the default size this is the same
    pipeline that load_support_set applies to the support images.
    """
    left, top, right, bottom = bbox[0], bbox[1], bbox[2], bbox[3]
    region = image.crop((left, top, right, bottom))

    pipeline = transforms.Compose(
        [
            transforms.Resize(target_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    return pipeline(region)

# Classify objects
def classify_objects(model, support_images, support_labels, query_images, class_names):
    """Label each query image with the name of its highest-scoring class.

    Puts ``model`` into eval mode; this persists after the call. The model is then
    run without gradient tracking.

    Returns:
        List of names from ``class_names``, one per query image. ``class_names[i]``
        must name score column ``i``.
    """
    model.eval()
    with torch.no_grad():
        scores = model(support_images, support_labels, query_images)
    best = scores.max(dim=1).indices.cpu().numpy()
    return [class_names[index] for index in best]

# Display image with labels
def display_image_with_labels(image, bboxes, labels):
    """Show ``image`` with a red outline and a text tag for each (bbox, label) pair.

    Each ``bbox`` is read as ``[x1, y1, x2, y2]``. The label is drawn at the box's
    top-left corner on a semi-transparent red background.
    """
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(image)
    for box, text in zip(bboxes, labels):
        x_min, y_min, x_max, y_max = box[0], box[1], box[2], box[3]
        outline = patches.Rectangle(
            (x_min, y_min),
            x_max - x_min,
            y_max - y_min,
            linewidth=2,
            edgecolor='r',
            facecolor='none',
        )
        ax.add_patch(outline)
        ax.text(x_min, y_min, text, color='white', fontsize=12, bbox=dict(facecolor='red', alpha=0.5))
    ax.axis('off')
    plt.show()

# Initialize model and optimizer.
# Backbone: ImageNet-pretrained ResNet-34 whose final `fc` layer is replaced by
# Flatten, so it returns feature embeddings instead of class logits.
# `weights="DEFAULT"` selects the same ImageNet weights as the deprecated `pretrained=True`.
# NOTE(review): the checkpoint loaded in the next cell overwrites these weights
# (load_state_dict is strict by default).
convolutional_network = resnet34(weights="DEFAULT")
convolutional_network.fc = nn.Flatten()
model = PrototypicalNetworks(convolutional_network)
model.to(device)
# Created so load_checkpoint can restore its state; no training step appears in the visible cells.
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

Downloading: "https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth" to /root/.cache/torch/hub/checkpoints/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth
100%|██████████| 160M/160M [00:03<00:00, 50.4MB/s]
Downloading: "https://download.pytorch.org/models/resnet34-b627a593.pth" to /root/.cache/torch/hub/checkpoints/resnet34-b627a593.pth
100%|██████████| 83.3M/83.3M [00:00<00:00, 94.6MB/s]


In [None]:
# Restore the trained Prototypical Network weights and optimizer state from Drive.
# load_checkpoint returns the epoch and accuracy statistics stored in the checkpoint.
checkpoint_path = '/content/drive/MyDrive/prototypical_networks_model_epoch20.pth'
loaded_epoch, loaded_mean_accuracy, loaded_std_accuracy = load_checkpoint(checkpoint_path, model, optimizer)

Checkpoint loaded from epoch 20 with mean accuracy: 86.25%


In [None]:
# Main workflow
# Inputs and outputs (all on Google Drive) plus support-set size.
support_set_path = '/content/drive/MyDrive/datasets/Steve_Dataset'  # one sub-folder per class; labels follow sorted folder names
query_video_path = '/content/drive/MyDrive/datasets/footage_1.mp4'  # Path to the uploaded video
output_video_path = '/content/drive/MyDrive/datasets/output_video_3.mp4'  # annotated video is written here
n_way = 4  # NOTE(review): load_support_set does not use this value; every class folder is loaded
n_shot = 150  # maximum number of images taken from each class folder

In [None]:
# Load support set: up to n_shot images from each class folder, labelled by the
# position of the folder name in `class_names`. Then move the tensors to the
# same device as the model.
support_images, support_labels, class_names = load_support_set(support_set_path, n_way, n_shot)
support_images, support_labels = support_images.to(device), support_labels.to(device)


In [None]:
# Process video: run detection on every frame, classify each sufficiently large
# detection against the support set, draw the predictions and write the frame out.
import time

cap = cv2.VideoCapture(query_video_path)
if not cap.isOpened():
    raise IOError(f"Could not open video: {query_video_path}")

# Get video properties. CAP_PROP_FPS is often fractional (e.g. 29.97); truncating it
# with int() makes the output play at the wrong speed, so keep the float value.
fps = cap.get(cv2.CAP_PROP_FPS)
if fps <= 0:
    fps = 30.0  # NOTE(review): fallback for containers that report no frame rate
# CAP_PROP_FRAME_COUNT can be an estimate; it only sizes the progress bar below,
# and the loop reads until the capture is exhausted.
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

# Define the codec and create VideoWriter object
fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # or 'XVID'
out = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width, frame_height))

# Per-frame latency in seconds: detection + classification + drawing
# (excludes decoding and colour conversion).
latencies = []

try:
    with tqdm(total=frame_count, desc="Processing Video") as progress:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # OpenCV decodes frames as BGR; convert to an RGB PIL image for the models.
            frame_pil = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

            # perf_counter is monotonic and high-resolution, unlike time.time().
            start_time = time.perf_counter()

            # Get bounding boxes, then drop boxes that are too small to classify.
            bboxes, _ = get_bounding_boxes(frame_pil)
            filtered_bboxes = filter_boxes_by_size(bboxes, size_threshold=40)

            if len(filtered_bboxes) > 0:
                query_images = torch.stack([preprocess_image(frame_pil, bbox) for bbox in filtered_bboxes]).to(device)
                # NOTE(review): classify_objects re-embeds the whole support set on every
                # frame; computing the class prototypes once before the loop would cut latency.
                predicted_labels = classify_objects(model, support_images, support_labels, query_images, class_names)

                for bbox, label in zip(filtered_bboxes, predicted_labels):
                    x1, y1, x2, y2 = (int(coord) for coord in bbox[:4])
                    cv2.rectangle(frame, (x1, y1), (x2, y2), (255, 0, 0), 2)
                    # Draw the label above the box, but keep it inside the frame for
                    # boxes that touch the top edge.
                    text_y = max(y1 - 10, 20)
                    cv2.putText(frame, label, (x1, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (255, 0, 0), 2)

            latencies.append(time.perf_counter() - start_time)

            out.write(frame)
            progress.update(1)
finally:
    # Always release the handles so the output file is finalised even if a frame fails.
    cap.release()
    out.release()
cv2.destroyAllWindows()
print(f'Video saved to {output_video_path}')

# Calculate average latency (guard against a video that yielded no frames).
if latencies:
    average_latency = sum(latencies) / len(latencies)
    print(f"Average latency per frame: {average_latency:.4f} seconds")
else:
    average_latency = float('nan')
    print("No frames were read from the video; latency not measured.")

  return F.conv2d(input, weight, bias, self.stride,
Processing Video: 100%|██████████| 791/791 [03:59<00:00,  3.30it/s]


Video saved to /content/drive/MyDrive/datasets/output_video_3.mp4
Average latency per frame: 0.2708 seconds
