In [None]:
# @title Load Dataset

# Fetch the Kvasir-SEG archive with the Kaggle CLI.
# NOTE(review): presumably requires Kaggle API credentials to be configured in the runtime - confirm.
!kaggle datasets download -d debeshjha1/kvasirseg

# Extract into the current working directory. The Config cell points at
# /content/Kvasir-SEG/Kvasir-SEG/, so this is expected to run from /content.
!unzip kvasirseg.zip


In [None]:
# @title Imports
import os
import cv2
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import pandas as pd
import numpy as np
import torch
import shutil

from torchvision import datasets, transforms


import torch.nn as nn

import albumentations as A
from albumentations.pytorch import ToTensorV2

from torch.utils.data import Dataset, DataLoader

from collections import Counter
import random

import torch.optim as optim

from tqdm import tqdm

import warnings
warnings.filterwarnings("ignore")

# Seed every RNG used in this notebook: PyTorch (CPU and CUDA), NumPy and the stdlib `random`.
torch.manual_seed(42)
torch.cuda.manual_seed(42)
np.random.seed(42)
random.seed(42)

# Let cuDNN benchmark and pick the fastest convolution algorithms.
# NOTE(review): with benchmark mode enabled the selected kernels may differ between runs,
# so GPU results are not guaranteed to be bit-for-bit reproducible despite the seeds above.
torch.backends.cudnn.benchmark = True





In [None]:
# @title Config
# Central place for paths, hyper-parameters and dataset listings used by the cells below.

DATASET = "Kvasir-SEG"
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
NUM_CLASSES = 1
# NOTE(review): path is under /content/drive - presumably Google Drive must be mounted first; confirm.
CHECKPOINT_FILE = "/content/drive/MyDrive/models/checkpoint.pth.tar" # pruned_checkpoint
IMG_DIR = "/content/Kvasir-SEG/Kvasir-SEG/images/"
LABEL_DIR = "/content/Kvasir-SEG/Kvasir-SEG/bbox/"
DATASET_DIR = "/content/Kvasir-SEG/Kvasir-SEG/"

# Detection / evaluation thresholds.
CONF_THRESHOLD = 0.6
MAP_IOU_THRESH = 0.5
NMS_IOU_THRESH = 0.45
# Optimisation and data-loading settings.
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-4
BATCH_SIZE = 30
# NOTE(review): effectively unbounded - presumably training is stopped manually; confirm.
NUM_EPOCHS = 100000000
NUM_WORKERS = 4
PIN_MEMORY = True
LOAD_MODEL = False
SAVE_MODEL = False

CRITERION = "L1"

# Fraction of the shuffled data held out for validation (see the split cell).
split_pct = 0.1
IMAGE_SIZE = 416

# Grid sizes of the three prediction scales (strides 32, 16 and 8).
S = [IMAGE_SIZE // 32, IMAGE_SIZE // 16, IMAGE_SIZE // 8]

img_lst = os.listdir(IMG_DIR)
label_lst = os.listdir(LABEL_DIR)

# Sorted copies: get_bounding_boxes and show_images pair image i with label i by position,
# which relies on both directories sorting into the same order.
sorted_img_lst = sorted(img_lst)
sorted_label_lst = sorted(label_lst)

# Default anchors as (width, height) pairs, three per scale.
# The K-Means cell below reassigns ANCHORS with values fitted to this dataset.
ANCHORS = [
    [(0.28, 0.22), (0.38, 0.48), (0.9, 0.78)],
    [(0.07, 0.15), (0.15, 0.11), (0.14, 0.29)],
    [(0.02, 0.03), (0.04, 0.07), (0.08, 0.06)],
]  # Note these have been rescaled to be between [0, 1]

# ANCHORS = [[(0.87, 0.91), (0.61, 0.75), (0.57, 0.5)], [(0.45, 0.63), (0.4, 0.45), (0.28, 0.47)], [(0.29, 0.32), (0.19, 0.24), (0.09, 0.11)]]

# Class names indexed by class id (single-class problem).
KVASIR_CLASSES = [
    "polyp"
]



In [None]:
# @title K-Means Clustering (Anchors)
def iou(box, clusters):
    """
    IoU between one (width, height) box and each of k cluster boxes.

    Boxes are compared as if they shared a corner, so only their sizes matter.

    Parameters:
    - box: sequence/array [width, height].
    - clusters: numpy array of shape (k, 2) holding [width, height] rows.

    Returns:
    - numpy array of shape (k,) with the IoU against every cluster.

    Raises:
    - ValueError: if any overlap width or height is exactly zero.
    """
    overlap_w = np.minimum(clusters[:, 0], box[0])
    overlap_h = np.minimum(clusters[:, 1], box[1])

    # A zero overlap side means the box or a cluster is degenerate.
    if np.any(overlap_w == 0) or np.any(overlap_h == 0):
        raise ValueError("Box has no area")

    intersection = overlap_w * overlap_h
    union = box[0] * box[1] + clusters[:, 0] * clusters[:, 1] - intersection
    return intersection / union

def kmeans_anchors(boxes, k, dist_fn=iou, max_iter=10000):
    """
    K-Means clustering with IoU-based distance for anchor box generation.

    Parameters:
    - boxes: A numpy array of shape (num_boxes, 2), where each row is [width, height].
    - k: The number of anchors (clusters).
    - dist_fn: Similarity function called as dist_fn(box, clusters); the distance used
      for assignment is 1 - dist_fn(...). Defaults to IoU.
    - max_iter: Maximum number of iterations for the K-Means algorithm.

    Returns:
    - anchors: A numpy array of shape (k, 2) containing the optimized anchor boxes (width, height).
    """
    # Initialize the clusters randomly by choosing k distinct boxes from the dataset
    indices = np.random.choice(boxes.shape[0], k, replace=False)
    clusters = boxes[indices]

    for _ in range(max_iter):
        # Assign each box to the closest cluster (highest IoU == smallest distance)
        distances = np.array([1 - dist_fn(box, clusters) for box in boxes])
        nearest_clusters = np.argmin(distances, axis=1)

        # Recalculate each centroid as the mean of its members. A cluster that lost all
        # its members keeps its previous centroid: the mean of an empty slice is NaN,
        # and a NaN centroid would make the equality check below fail forever.
        new_clusters = clusters.astype(float)
        for cluster_idx in range(k):
            members = boxes[nearest_clusters == cluster_idx]
            if len(members) > 0:
                new_clusters[cluster_idx] = members.mean(axis=0)

        # Converged once no centroid moved
        if np.all(clusters == new_clusters):
            break
        clusters = new_clusters

    return clusters

def get_bounding_boxes(img_lst, label_lst, print_examples=False):
    """
    Extract bounding box widths and heights from a list of images and their corresponding label files.

    Image i is paired with label i by position; files are read from IMG_DIR and LABEL_DIR.

    Parameters:
    - img_lst: List of image filenames.
    - label_lst: List of corresponding label filenames (CSV with class_name, xmin, ymin, xmax, ymax).
    - print_examples: If True, plot and print details while fewer than 5 boxes have been collected.

    Returns:
    - A numpy array of shape (num_boxes, 2), where each row contains [width, height]
      normalized by the image width and height.
    """
    boxes = []

    # Loop through each image and label file
    for i in range(len(label_lst)):
        img_path = os.path.join(IMG_DIR, img_lst[i])
        label_path = os.path.join(LABEL_DIR, label_lst[i])

        # Context manager so the image file is closed once its boxes are processed
        with Image.open(img_path) as img:
            image_width, image_height = img.size

            # Read the CSV label file containing bounding box data
            df = pd.read_csv(label_path)
            bounding_boxes = df[['class_name', 'xmin', 'ymin', 'xmax', 'ymax']].to_numpy()

            for _class_name, xmin, ymin, xmax, ymax in bounding_boxes:
                # Pixel size of this box and its size relative to the image
                width = xmax - xmin
                height = ymax - ymin
                normalized_width = width / image_width
                normalized_height = height / image_height

                # Optionally show the first few examples
                if print_examples and len(boxes) < 5:
                    fig, ax = plt.subplots()
                    ax.imshow(img)
                    # Draw every box of this image. Separate names are used here so the
                    # width/height printed below still belong to the current box.
                    for _, row in df.iterrows():
                        rect = patches.Rectangle(
                            (row['xmin'], row['ymin']),
                            row['xmax'] - row['xmin'],
                            row['ymax'] - row['ymin'],
                            linewidth=2,
                            edgecolor='r',
                            facecolor='none',
                        )
                        ax.add_patch(rect)

                    plt.show()
                    print(label_path)
                    print("image size: ", img.size)
                    print(f"Original Width: {width}, Original Height: {height}")
                    print(f"Normalized Width: {normalized_width}, Normalized Height: {normalized_height}")

                # Append the [width, height] to the list
                boxes.append([normalized_width, normalized_height])

    # Convert the list to a numpy array of shape (num_boxes, 2)
    return np.array(boxes)




# Collect the normalized (width, height) of every labelled box
bounding_boxes = get_bounding_boxes(sorted_img_lst, sorted_label_lst, print_examples=True)
print(len(bounding_boxes))
print(len(label_lst))

# Cluster the box sizes into 9 anchors using the IoU-based K-Means
anchors = kmeans_anchors(bounding_boxes, k=9)

print(f"Optimized Anchors (width, height):\n{anchors}")

# Order anchors from largest to smallest area
areas = np.prod(anchors, axis=1)
print(areas)
sorted_indices = np.argsort(-areas)
sorted_anchors = anchors[sorted_indices]

# Three groups of plain-float tuples: largest, medium, then the remaining (smallest) anchors
formatted_anchors = [
    [(float(w), float(h)) for w, h in group]
    for group in (sorted_anchors[:3], sorted_anchors[3:6], sorted_anchors[6:])
]
print(formatted_anchors)

# Replace the default anchors from the Config cell with the fitted ones
ANCHORS = formatted_anchors


In [None]:
# @title Transforms
# Image-only pipelines (no bbox_params): resize to a square, optional light augmentation,
# scale pixels to [0, 1] and convert to a tensor.
# NOTE(review): get_loaders uses train_transforms / test_transforms below, not these two - confirm they are still needed.
train_transform = A.Compose([A.Resize(IMAGE_SIZE,IMAGE_SIZE),
                             A.Rotate(limit=15,p=0.1),
                             A.HorizontalFlip(p=0.5),
                             A.Normalize(mean=(0,0,0),std=(1,1,1),max_pixel_value=255),
                             ToTensorV2()])

val_transform = A.Compose([A.Resize(IMAGE_SIZE,IMAGE_SIZE),
                           A.Normalize(mean=(0,0,0),std=(1,1,1),max_pixel_value=255),
                           ToTensorV2()])

# Training images are first brought to IMAGE_SIZE * scale and then randomly cropped back
# to IMAGE_SIZE, which gives a small random translation/zoom.
scale = 1.1
# Bounding-box-aware training pipeline (boxes in YOLO format travel with the image).
train_transforms = A.Compose(
    [
        A.LongestMaxSize(max_size=int(IMAGE_SIZE * scale)),
        A.PadIfNeeded(
            min_height=int(IMAGE_SIZE * scale),
            min_width=int(IMAGE_SIZE * scale),
            border_mode=cv2.BORDER_CONSTANT,
            value=0,
        ),
        A.RandomCrop(width=IMAGE_SIZE, height=IMAGE_SIZE),
        A.ColorJitter(brightness=0.6, contrast=0.6, saturation=0.6, hue=0.6, p=0.4),
        # Exactly one geometric distortion is applied to every sample (p=1.0).
        A.OneOf(
            [
                A.ShiftScaleRotate(
                    rotate_limit=20, p=0.5, border_mode=cv2.BORDER_CONSTANT
                ),
                A.Affine(shear=15, p=0.5),
            ],
            p=1.0,
        ),
        A.HorizontalFlip(p=0.5),
        A.Blur(p=0.1),
        A.CLAHE(p=0.1),
        A.Posterize(p=0.1),
        A.ToGray(p=0.1),
        A.ChannelShuffle(p=0.05),
        # mean 0 / std 1 with max_pixel_value=255 only rescales pixels to [0, 1].
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255,),
        ToTensorV2(),
    ],
    # label_fields is empty: the class label is carried as the 5th element of each box.
    bbox_params=A.BboxParams(format="yolo", min_visibility=0.4, label_fields=[],),
)

# Deterministic evaluation pipeline: letterbox to IMAGE_SIZE, rescale, convert to tensor.
test_transforms = A.Compose(
    [
        A.LongestMaxSize(max_size=IMAGE_SIZE),
        A.PadIfNeeded(
            min_height=IMAGE_SIZE, min_width=IMAGE_SIZE, border_mode=cv2.BORDER_CONSTANT, value=0
        ),
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255,),
        ToTensorV2(),
    ],
    bbox_params=A.BboxParams(format="yolo", min_visibility=0.4, label_fields=[]),
)


In [None]:
# @title Show Image

def show_images(img_lst, label_lst, loops=10):
    """
    Visual sanity check of the labels for the first `loops` image/label pairs.

    For each pair (files are read from IMG_DIR and LABEL_DIR) this prints the paths,
    image size and label table, draws the boxes from their corner coordinates, converts
    them to YOLO format (class, x_center, y_center, width, height, all normalized) and
    draws them again from the converted values so both plots can be compared.

    Parameters:
    - img_lst: List of image filenames.
    - label_lst: List of label CSV filenames, aligned with img_lst by position.
    - loops: Maximum number of samples to show; clamped to the number of available pairs.
    """
    # Clamp so that a dataset with fewer than `loops` samples does not raise IndexError
    num_samples = min(loops, len(img_lst), len(label_lst))

    for i in range(num_samples):
        img_path = os.path.join(IMG_DIR, img_lst[i])
        label_path = os.path.join(LABEL_DIR, label_lst[i])

        # Context manager so the image file is closed after the sample is shown
        with Image.open(img_path) as img:
            image_width, image_height = img.size

            print(img_path)
            print(img.size)
            print(type(img))
            plt.imshow(img)
            plt.show()

            print(label_path)

            # Read the CSV file
            df = pd.read_csv(label_path)

            # Display the dataframe
            print(df)

            bounding_boxes = df[['class_name', 'xmin', 'ymin', 'xmax', 'ymax']].to_numpy()
            print(bounding_boxes)
            print(bounding_boxes.shape)

            # First plot: boxes straight from the corner coordinates in the CSV
            fig, ax = plt.subplots()
            ax.imshow(img)
            for _, row in df.iterrows():
                rect = patches.Rectangle(
                    (row['xmin'], row['ymin']),
                    row['xmax'] - row['xmin'],
                    row['ymax'] - row['ymin'],
                    linewidth=2,
                    edgecolor='r',
                    facecolor='none',
                )
                ax.add_patch(rect)

            plt.show()

            # 'polyp' is the only class; use id 0 to match YOLODataset and the
            # index of "polyp" in KVASIR_CLASSES.
            yolo_df = df.assign(
                class_name=0,
                x_center=(df['xmin'] + df['xmax']) / 2 / image_width,
                y_center=(df['ymin'] + df['ymax']) / 2 / image_height,
                width=(df['xmax'] - df['xmin']) / image_width,
                height=(df['ymax'] - df['ymin']) / image_height,
            )

            # Prepare YOLO format data
            yolo_format = yolo_df[['class_name', 'x_center', 'y_center', 'width', 'height']].to_numpy()

            # Display the dataframe
            print(yolo_df)

            print(yolo_format)

            print(yolo_format.shape)

            # Second plot: boxes rebuilt from the normalized YOLO values
            fig, ax = plt.subplots()
            ax.imshow(img)
            for box in yolo_format:
                x_center, y_center, box_w, box_h = box[1:]
                upper_left_x = x_center - box_w / 2
                upper_left_y = y_center - box_h / 2
                rect = patches.Rectangle(
                    (upper_left_x * image_width, upper_left_y * image_height),
                    box_w * image_width,
                    box_h * image_height,
                    linewidth=2,
                    edgecolor='r',
                    facecolor="none",
                )
                # Add the patch to the Axes
                ax.add_patch(rect)

            plt.show()

        print("----------------------------------------------------")

show_images(sorted_img_lst, sorted_label_lst)

In [None]:
# @title Shuffling Data

# Shuffle the image filenames once; labels are derived from the shuffled names so
# each image stays paired with its own CSV.
permuted_train_img_lst = np.random.permutation(np.array(sorted_img_lst))
permuted_train_label_lst = [
    img_name.replace(".jpg", ".csv") for img_name in permuted_train_img_lst
]
# Preview the first few pairs
for preview in (permuted_train_img_lst, permuted_train_label_lst):
    print(preview[:5])




In [None]:
# @title Splitting into Training and Validation

# The first `split_pct` share of the shuffled data is validation, the remainder is training.
# (The image and label lists have the same length, so one cut index serves both.)
_val_count = int(split_pct * len(permuted_train_img_lst))

train_images_list = permuted_train_img_lst[_val_count:]
train_labels_list = permuted_train_label_lst[_val_count:]
print(len(train_labels_list))

val_images_list = permuted_train_img_lst[:_val_count]
val_labels_list = permuted_train_label_lst[:_val_count]
print(len(val_labels_list))


In [None]:
# @title Utils for YOLOv3

def iou_width_height(boxes1, boxes2):
    """
    IoU computed from widths and heights only (boxes treated as sharing a corner).

    Parameters:
        boxes1 (tensor): width and height of the first bounding boxes (last dim = [w, h])
        boxes2 (tensor): width and height of the second bounding boxes (last dim = [w, h])
    Returns:
        tensor: Intersection over union of the corresponding (broadcast) boxes
    """
    w1, h1 = boxes1[..., 0], boxes1[..., 1]
    w2, h2 = boxes2[..., 0], boxes2[..., 1]

    intersection = torch.min(w1, w2) * torch.min(h1, h2)
    union = w1 * h1 + w2 * h2 - intersection
    return intersection / union

def intersection_over_union(boxes_preds, boxes_labels, box_format="midpoint"):
    """
    Video explanation of this function:
    https://youtu.be/XXYG5ZWtjj0

    This function calculates intersection over union (iou) given pred boxes
    and target boxes.

    Parameters:
        boxes_preds (tensor): Predictions of Bounding Boxes (BATCH_SIZE, 4)
        boxes_labels (tensor): Correct labels of Bounding Boxes (BATCH_SIZE, 4)
        box_format (str): midpoint/corners, if boxes (x,y,w,h) or (x1,y1,x2,y2)

    Returns:
        tensor: Intersection over union for all examples (last dim kept with size 1)

    Raises:
        ValueError: if box_format is neither "midpoint" nor "corners"
    """

    if box_format == "midpoint":
        # Convert (x_center, y_center, w, h) to corner coordinates
        box1_x1 = boxes_preds[..., 0:1] - boxes_preds[..., 2:3] / 2
        box1_y1 = boxes_preds[..., 1:2] - boxes_preds[..., 3:4] / 2
        box1_x2 = boxes_preds[..., 0:1] + boxes_preds[..., 2:3] / 2
        box1_y2 = boxes_preds[..., 1:2] + boxes_preds[..., 3:4] / 2
        box2_x1 = boxes_labels[..., 0:1] - boxes_labels[..., 2:3] / 2
        box2_y1 = boxes_labels[..., 1:2] - boxes_labels[..., 3:4] / 2
        box2_x2 = boxes_labels[..., 0:1] + boxes_labels[..., 2:3] / 2
        box2_y2 = boxes_labels[..., 1:2] + boxes_labels[..., 3:4] / 2
    elif box_format == "corners":
        box1_x1 = boxes_preds[..., 0:1]
        box1_y1 = boxes_preds[..., 1:2]
        box1_x2 = boxes_preds[..., 2:3]
        box1_y2 = boxes_preds[..., 3:4]
        box2_x1 = boxes_labels[..., 0:1]
        box2_y1 = boxes_labels[..., 1:2]
        box2_x2 = boxes_labels[..., 2:3]
        box2_y2 = boxes_labels[..., 3:4]
    else:
        # Previously an unknown format fell through and failed later with an
        # unhelpful unbound-variable error.
        raise ValueError(
            f"box_format must be 'midpoint' or 'corners', got {box_format!r}"
        )

    # Corners of the overlapping rectangle
    x1 = torch.max(box1_x1, box2_x1)
    y1 = torch.max(box1_y1, box2_y1)
    x2 = torch.min(box1_x2, box2_x2)
    y2 = torch.min(box1_y2, box2_y2)

    # clamp(0) handles boxes that do not overlap
    intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
    box1_area = abs((box1_x2 - box1_x1) * (box1_y2 - box1_y1))
    box2_area = abs((box2_x2 - box2_x1) * (box2_y2 - box2_y1))

    # Small epsilon avoids division by zero for degenerate boxes
    return intersection / (box1_area + box2_area - intersection + 1e-6)


def cells_to_bboxes(predictions, anchors, S, is_preds=True):
    """
    Scales the predictions coming from the model to be relative to the entire
    image so that they can, for example, be plotted or evaluated later.

    The input tensor is not modified.

    INPUT:
    predictions: tensor of size (N, num_anchors, S, S, num_classes+5) laid out as
                 [objectness, x, y, w, h, class scores / class label]
    anchors: the anchors used for the predictions (tensor reshapeable to
             (1, num_anchors, 1, 1, 2) when is_preds is True)
    S: the number of cells the image is divided in on the width (and height)
    is_preds: whether the input is predictions or the true bounding boxes
    OUTPUT:
    converted_bboxes: nested list of shape (N, num_anchors * S * S, 6) where each box is
                      [class index, object score, x, y, w, h] relative to the image
    """
    batch_size = predictions.shape[0]
    num_anchors = len(anchors)
    box_predictions = predictions[..., 1:5]
    if is_preds:
        anchors = anchors.reshape(1, len(anchors), 1, 1, 2)
        # Build a new tensor instead of writing through the slice: the slice is a view,
        # so in-place writes would silently overwrite the caller's raw model output.
        xy = torch.sigmoid(box_predictions[..., 0:2])
        wh = (torch.exp(box_predictions[..., 2:]) * anchors).to(box_predictions.dtype)
        box_predictions = torch.cat((xy, wh), dim=-1)
        scores = torch.sigmoid(predictions[..., 0:1])
        best_class = torch.argmax(predictions[..., 5:], dim=-1).unsqueeze(-1)
    else:
        scores = predictions[..., 0:1]
        best_class = predictions[..., 5:6]

    # Column index of every cell; permuting the two grid axes gives the row index.
    cell_indices = (
        torch.arange(S)
        .repeat(predictions.shape[0], predictions.shape[1], S, 1)
        .unsqueeze(-1)
        .to(predictions.device)
    )
    # Cell-relative offsets -> image-relative coordinates in [0, 1]
    x = 1 / S * (box_predictions[..., 0:1] + cell_indices)
    y = 1 / S * (box_predictions[..., 1:2] + cell_indices.permute(0, 1, 3, 2, 4))
    w_h = 1 / S * box_predictions[..., 2:4]
    converted_bboxes = torch.cat((best_class, scores, x, y, w_h), dim=-1).reshape(
        batch_size, num_anchors * S * S, 6
    )
    return converted_bboxes.tolist()


def non_max_suppression(bboxes, iou_threshold, threshold, box_format="corners"):
    """
    Video explanation of this function:
    https://youtu.be/YDkjWEN8jNA

    Greedy per-class Non Max Suppression.

    Parameters:
        bboxes (list): list of boxes, each given as
        [class_pred, prob_score, 4 box coordinates in `box_format`]
        iou_threshold (float): same-class boxes overlapping a kept box by this IoU
        or more are discarded
        threshold (float): boxes whose score is not above this value are dropped
        before suppression (independent of IoU)
        box_format (str): "midpoint" or "corners" used to specify bboxes

    Returns:
        list: bboxes that survive NMS, ordered by descending score
    """

    assert type(bboxes) == list

    # Keep confident boxes only, best score first
    candidates = sorted(
        (box for box in bboxes if box[1] > threshold),
        key=lambda box: box[1],
        reverse=True,
    )
    kept = []

    while candidates:
        best, remaining = candidates[0], candidates[1:]

        survivors = []
        for other in remaining:
            # Boxes of a different class never suppress each other
            if other[0] != best[0]:
                survivors.append(other)
                continue
            overlap = intersection_over_union(
                torch.tensor(best[2:]),
                torch.tensor(other[2:]),
                box_format=box_format,
            )
            if overlap < iou_threshold:
                survivors.append(other)

        candidates = survivors
        kept.append(best)

    return kept

def plot_image(image, boxes, target_boxes=None):
    """
    Plots predicted bounding boxes on the image.

    When target_boxes is given, two panels are drawn side by side: targets on the left
    and predictions on the right; otherwise a single panel shows the predictions.

    Parameters:
        image: array-like of shape (height, width, channels)
        boxes (list): predicted boxes as [class pred, confidence, x, y, width, height]
        with midpoint coordinates normalized to the image size
        target_boxes (list, optional): ground-truth boxes in the same layout
    """
    cmap = plt.get_cmap("tab20b")
    class_labels = KVASIR_CLASSES
    colors = [cmap(i) for i in np.linspace(0, 1, len(class_labels)+1)]
    im = np.array(image)
    height, width, _ = im.shape

    def _draw_boxes(axis, box_list, as_target):
        """Draw each box and its caption on `axis`; targets get a thicker outline."""
        for entry in box_list:
            assert len(entry) == 6, "box should contain class pred, confidence, x, y, width, height"
            class_pred = entry[0]
            x_mid, y_mid, box_w, box_h = entry[2:]
            # Midpoint format -> upper-left corner
            upper_left_x = x_mid - box_w / 2
            upper_left_y = y_mid - box_h / 2
            rect = patches.Rectangle(
                (upper_left_x * width, upper_left_y * height),
                box_w * width,
                box_h * height,
                linewidth=5 if as_target else 2,
                edgecolor=colors[1] if as_target else colors[int(class_pred)],
                facecolor="none",
            )
            axis.add_patch(rect)
            axis.text(
                upper_left_x * width,
                upper_left_y * height,
                s='target' if as_target else class_labels[int(class_pred)],
                color="white",
                verticalalignment="top",
                bbox={"color": colors[int(class_pred)], "pad": 0},
            )

    if target_boxes is None:
        # Single panel with the predictions
        fig, ax = plt.subplots(1)
        ax.imshow(im)
        _draw_boxes(ax, boxes, as_target=False)
    else:
        # Left panel: targets, right panel: predictions
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
        ax1.imshow(im)
        ax2.imshow(im)
        _draw_boxes(ax1, target_boxes, as_target=True)
        _draw_boxes(ax2, boxes, as_target=False)

    plt.show()


def get_loaders(train_images_list, train_labels_list, val_images_list, val_labels_list):
    """
    Build the three DataLoaders used for training and evaluation.

    Parameters:
        train_images_list / train_labels_list: filenames of the training images and labels
        val_images_list / val_labels_list: filenames of the validation images and labels

    Returns:
        (train_loader, test_loader, train_eval_loader):
        - train_loader: training data with train_transforms, shuffled
        - test_loader: validation data with test_transforms, not shuffled
        - train_eval_loader: training data with test_transforms (no augmentation),
          not shuffled, for measuring performance on the training set
    """

    def _build_loader(images, labels, transform, shuffle):
        """Create a YOLODataset for the given files and wrap it in a DataLoader."""
        dataset = YOLODataset(
            images,
            labels,
            transform=transform,
            S=[IMAGE_SIZE // 32, IMAGE_SIZE // 16, IMAGE_SIZE // 8],
            img_dir=IMG_DIR,
            label_dir=LABEL_DIR,
            anchors=ANCHORS,
        )
        return DataLoader(
            dataset=dataset,
            batch_size=BATCH_SIZE,
            num_workers=NUM_WORKERS,
            pin_memory=PIN_MEMORY,
            shuffle=shuffle,
            drop_last=False,
        )

    train_loader = _build_loader(
        train_images_list, train_labels_list, train_transforms, shuffle=True
    )
    test_loader = _build_loader(
        val_images_list, val_labels_list, test_transforms, shuffle=False
    )
    # Previously this was built from the validation lists, which made it an exact
    # duplicate of test_loader instead of an un-augmented view of the training set.
    train_eval_loader = _build_loader(
        train_images_list, train_labels_list, test_transforms, shuffle=False
    )

    return train_loader, test_loader, train_eval_loader

def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    """
    Serialize the model and optimizer state dicts to `filename` with torch.save.

    The saved object is a dict with the keys "state_dict" (model) and "optimizer".
    """
    print("=> Saving checkpoint")
    torch.save(
        {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        },
        filename,
    )


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """
    Restore model and optimizer state from `checkpoint_file` (tensors mapped to DEVICE).

    The checkpoint is expected to hold the keys "state_dict" and "optimizer".
    After loading, every optimizer parameter group has its learning rate set to `lr`.
    """
    print("=> Loading checkpoint")
    saved_state = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state_dict(saved_state["state_dict"])
    optimizer.load_state_dict(saved_state["optimizer"])

    # The optimizer state carries the learning rate it was saved with;
    # override it so the requested `lr` is the one actually used.
    for group in optimizer.param_groups:
        group["lr"] = lr


def check_class_accuracy(model, loader, threshold):
    """
    Print class, no-object and object accuracy of `model` over `loader`.

    Parameters:
        model: network returning one prediction tensor per scale, laid out as
               [objectness, x, y, w, h, class scores...] in the last dimension
        loader: iterable of (images, targets) where targets holds one tensor per scale
                with objectness in channel 0 and the class label in channel 5
        threshold (float): sigmoid(objectness) above this counts as "object predicted"

    The model is put in eval mode for the pass and switched to train mode afterwards.
    """
    model.eval()
    tot_class_preds, correct_class = 0, 0
    tot_noobj, correct_noobj = 0, 0
    tot_obj, correct_obj = 0, 0

    for idx, (x, y) in enumerate(tqdm(loader)):
        x = x.to(DEVICE)
        with torch.no_grad():
            out = model(x)

        for i in range(3):
            # Local variable instead of writing back into the batch container
            target = y[i].to(DEVICE)
            obj = target[..., 0] == 1  # in paper this is Iobj_i
            noobj = target[..., 0] == 0  # in paper this is Inoobj_i

            # Class accuracy is measured only on cells that contain an object
            correct_class += torch.sum(
                torch.argmax(out[i][..., 5:][obj], dim=-1) == target[..., 5][obj]
            )
            tot_class_preds += torch.sum(obj)

            obj_preds = torch.sigmoid(out[i][..., 0]) > threshold
            correct_obj += torch.sum(obj_preds[obj] == target[..., 0][obj])
            tot_obj += torch.sum(obj)
            correct_noobj += torch.sum(obj_preds[noobj] == target[..., 0][noobj])
            tot_noobj += torch.sum(noobj)

    # The epsilon guards against division by zero; ".2f" prints two decimals
    # (the previous "2f" spec only set a minimum width).
    print(f"Class accuracy is: {(correct_class/(tot_class_preds+1e-16))*100:.2f}%")
    print(f"No obj accuracy is: {(correct_noobj/(tot_noobj+1e-16))*100:.2f}%")
    print(f"Obj accuracy is: {(correct_obj/(tot_obj+1e-16))*100:.2f}%")
    model.train()


def plot_couple_examples(model, loader, thresh, iou_thresh, anchors, with_targets):
    """
    Run the model on the first batch of `loader` and plot the detections of every image.

    Parameters:
        model: network returning one prediction tensor per scale (3 scales)
        loader: DataLoader yielding (images, targets)
        thresh (float): confidence threshold passed to non-max suppression
        iou_thresh (float): IoU threshold passed to non-max suppression
        anchors: per-scale anchors; anchors[i] is handed to cells_to_bboxes for scale i
        with_targets (bool): if True, the ground-truth boxes are plotted next to the predictions

    The model is put in eval mode for the forward pass and switched to train mode afterwards.
    """
    model.eval()
    x, y = next(iter(loader))
    # Use the configured device rather than a hard-coded "cuda" so CPU-only runtimes work
    x = x.to(DEVICE)
    num_images = x.shape[0]

    with torch.no_grad():
        out = model(x)
        bboxes = [[] for _ in range(num_images)]
        target_bboxes = [[] for _ in range(num_images)]
        for i in range(3):
            anchor = anchors[i]

            # Predictions of this scale, converted to image-relative boxes
            grid_size = out[i].shape[2]
            boxes_scale_i = cells_to_bboxes(
                out[i], anchor, S=grid_size, is_preds=True
            )
            for idx, box in enumerate(boxes_scale_i):
                bboxes[idx] += box

            if with_targets:
                target_grid_size = y[i].shape[2]
                target_boxes_scale_i = cells_to_bboxes(
                    y[i], anchor, S=target_grid_size, is_preds=False
                )
                for idx, box in enumerate(target_boxes_scale_i):
                    target_bboxes[idx] += box

    model.train()

    for i in range(num_images):
        nms_boxes = non_max_suppression(
            bboxes[i], iou_threshold=iou_thresh, threshold=thresh, box_format="midpoint",
        )
        image = x[i].permute(1, 2, 0).detach().cpu()
        if with_targets:
            nms_target_boxes = non_max_suppression(
                target_bboxes[i], iou_threshold=iou_thresh, threshold=thresh, box_format="midpoint",
            )
            plot_image(image, nms_boxes, nms_target_boxes)
        else:
            plot_image(image, nms_boxes)



In [None]:
# @title YOLOv3 Dataset


class YOLODataset(Dataset):
    """
    Dataset yielding (image, targets) pairs for a three-scale YOLOv3 model.

    Labels are read from CSV files with the columns xmin, ymin, xmax, ymax (pixel corners);
    every box is assigned class id 0. Each target tensor has shape
    (anchors_per_scale, S, S, 6) with the last dimension laid out as
    [objectness, x_cell, y_cell, width_cell, height_cell, class_label].
    """

    def __init__(
        self,
        img_list,
        label_list,
        img_dir,
        label_dir,
        anchors,
        image_size=416,
        S=[13, 26, 52],
        C=1,
        transform=None,
    ):
        """
        Parameters:
            img_list: image filenames, aligned by position with label_list
            label_list: label CSV filenames
            img_dir / label_dir: directories containing the images and labels
            anchors: three lists (one per scale) of (width, height) anchors
            image_size: stored on the instance; not used by the methods of this class
            S: grid size of each scale
            C: number of classes
            transform: optional callable invoked as transform(image=..., bboxes=...)
                       returning a dict with "image" and "bboxes"
        """
        self.img_list = img_list
        self.label_list = label_list
        self.img_dir = img_dir
        self.label_dir = label_dir
        self.image_size = image_size
        self.transform = transform
        self.S = S
        self.anchors = torch.tensor(anchors[0] + anchors[1] + anchors[2])  # for all 3 scales
        self.num_anchors = self.anchors.shape[0]
        self.num_anchors_per_scale = self.num_anchors // 3
        self.C = C
        # Unassigned anchors whose IoU with a box exceeds this are marked "ignore" (-1)
        self.ignore_iou_thresh = 0.5

    def __len__(self):
        """Number of samples (one per image)."""
        return len(self.img_list)

    def __getitem__(self, index):
        """Load sample `index` and return (image, tuple of per-scale target tensors)."""
        img_path = os.path.join(self.img_dir, self.img_list[index])
        # Context manager closes the file handle as soon as the pixels are copied out
        with Image.open(img_path) as img:
            image_width, image_height = img.size
            image = np.array(img.convert("RGB"))

        label_path = os.path.join(self.label_dir, self.label_list[index])

        # Read the CSV file
        df = pd.read_csv(label_path)

        # Assuming 'polyp' is the only class, so class_id = 0
        df['class_name'] = 0

        # Convert pixel corners to normalized YOLO format
        df['x_center'] = (df['xmin'] + df['xmax']) / 2 / image_width
        df['y_center'] = (df['ymin'] + df['ymax']) / 2 / image_height
        df['width'] = (df['xmax'] - df['xmin']) / image_width
        df['height'] = (df['ymax'] - df['ymin']) / image_height
        yolo_format = df[['class_name', 'x_center', 'y_center', 'width', 'height']].to_numpy()
        # Rotate columns so the class label comes last: [x, y, w, h, class]
        bboxes = np.roll(yolo_format, 4, axis=1).tolist()

        if self.transform:
            augmentations = self.transform(image=image, bboxes=bboxes)
            image = augmentations["image"]
            bboxes = augmentations["bboxes"]

        # Below assumes 3 scale predictions (as paper) and same num of anchors per scale
        targets = [torch.zeros((self.num_anchors // 3, S, S, 6)) for S in self.S]
        for box in bboxes:
            # Rank all anchors by how well their shape matches this box
            iou_anchors = iou_width_height(torch.tensor(box[2:4]), self.anchors)
            anchor_indices = iou_anchors.argsort(descending=True, dim=0)
            x, y, width, height, class_label = box
            has_anchor = [False] * 3  # each scale should have one anchor
            for anchor_idx in anchor_indices:
                scale_idx = anchor_idx // self.num_anchors_per_scale
                anchor_on_scale = anchor_idx % self.num_anchors_per_scale
                S = self.S[scale_idx]
                i, j = int(S * y), int(S * x)  # which cell
                anchor_taken = targets[scale_idx][anchor_on_scale, i, j, 0]
                if not anchor_taken and not has_anchor[scale_idx]:
                    # Best free anchor on this scale becomes responsible for the box
                    targets[scale_idx][anchor_on_scale, i, j, 0] = 1  # confidence score
                    x_cell, y_cell = S * x - j, S * y - i  # both between [0,1]
                    width_cell, height_cell = (
                        width * S,
                        height * S,
                    )  # can be greater than 1 since it's relative to cell
                    box_coordinates = torch.tensor(
                        [x_cell, y_cell, width_cell, height_cell]
                    )
                    targets[scale_idx][anchor_on_scale, i, j, 1:5] = box_coordinates
                    targets[scale_idx][anchor_on_scale, i, j, 5] = int(class_label)
                    has_anchor[scale_idx] = True

                elif not anchor_taken and iou_anchors[anchor_idx] > self.ignore_iou_thresh:
                    targets[scale_idx][anchor_on_scale, i, j, 0] = -1  # ignore prediction

        return image, tuple(targets)

# (confidence score, x_center, y_center, width, height, class_label) for each cell of each scale

In [None]:
# @title Dataset Test
def test():
    """
    Smoke-test the dataset pipeline on the validation split.

    Builds a YOLODataset with the test transforms, then for every batch decodes
    the targets of all three scales back into bounding boxes (first image of
    the batch only), runs them through non-max suppression and plots them on
    that image. Shapes and box counts are printed along the way.
    """
    anchors = ANCHORS
    transform = test_transforms

    S = [13, 26, 52]
    dataset = YOLODataset(
        val_images_list,
        val_labels_list,
        IMG_DIR,
        LABEL_DIR,
        S=S,
        anchors=anchors,
        transform=transform,
    )
    # [3, 3, 2]: anchors expressed in grid-cell units of their scale
    # (same formulation as in main()).
    scaled_anchors = torch.tensor(anchors) * (
        torch.tensor(S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
    )
    loader = DataLoader(dataset=dataset, batch_size=10, shuffle=True)
    # y: one target tensor per scale, each [b, 3, S, S, 6];
    # 6 = [p, box_coordinates, class_label]
    for x, y in loader:
        boxes = []
        print(y[0].shape) # first scale
        print('--------')
        # Iterate over the scales themselves; the previous range(y[0].shape[1])
        # used the anchors-per-scale dimension, which only happened to be 3 too.
        for scale_idx, scale_targets in enumerate(y):
            anchor = scaled_anchors[scale_idx] # the three anchors of this scale
            print(anchor.shape)
            print(scale_targets.shape)
            # Decode once and keep only the first image's boxes from the batch.
            first_image_boxes = cells_to_bboxes(
                scale_targets, is_preds=False, S=scale_targets.shape[2], anchors=anchor
            )[0]
            print(len(first_image_boxes))
            boxes += first_image_boxes

        print('--------')
        print(len(boxes))
        print('--------')
        boxes = non_max_suppression(boxes, iou_threshold=1, threshold=0.7, box_format="midpoint")
        print(boxes)
        plot_image(x[0].permute(1, 2, 0).to("cpu"), boxes)


if __name__ == "__main__":
    test()

In [None]:
# @title YOLOv3 architecture
"""
Implementation of YOLOv3 architecture
"""
"""
Information about architecture config:
Tuple is structured by (filters, kernel_size, stride)
Every conv is a same convolution.
List is structured by "B" indicating a residual block followed by the number of repeats
"S" is for scale prediction block and computing the yolo loss
"U" is for upsampling the feature map and concatenating with a previous layer
"""
import torch
import torch.nn as nn
# Layer-by-layer description of the network, consumed in order by
# YOLOv3._create_conv_layers:
#   tuple        -> CNNBlock with (filters, kernel_size, stride)
#   ["B", n]     -> ResidualBlock with n repeats
#   "S"          -> scale-prediction branch
#   "U"          -> 2x upsample, then concatenation with a stored route
config = [
    (32, 3, 1),
    (64, 3, 2),  # stride 2
    ["B", 1],
    (128, 3, 2),  # stride 2
    ["B", 2],
    (256, 3, 2),  # stride 2
    ["B", 8],  # 8-repeat blocks are the ones YOLOv3.forward stores as route connections
    (512, 3, 2),  # stride 2
    ["B", 8],  # second stored route connection
    (1024, 3, 2),  # stride 2
    ["B", 4],  # Darknet-53 backbone ends here
    (512, 1, 1),
    (1024, 3, 1),
    "S",  # first prediction scale (grid of IMAGE_SIZE // 32)
    (256, 1, 1),
    "U",
    (256, 1, 1),
    (512, 3, 1),
    "S",  # second prediction scale (grid of IMAGE_SIZE // 16)
    (128, 1, 1),
    "U",
    (128, 1, 1),
    (256, 3, 1),
    "S",  # third prediction scale (grid of IMAGE_SIZE // 8)
]


# Batch norm and leaky relu added to it
# out_channels: Number of filters applied by the convolutional layer, which equals the number of output channels
class CNNBlock(nn.Module):
    """
    Convolution optionally followed by batch normalization and LeakyReLU(0.1).

    Parameters:
        in_channels (int): channels of the incoming feature map
        out_channels (int): number of filters, i.e. channels of the output
        bn_act (bool): when True the block is conv -> batch norm -> leaky ReLU
            and the convolution has no bias; when False the block is a plain
            convolution with bias
        **kwargs: forwarded to nn.Conv2d (kernel_size, stride, padding, ...)
    """

    def __init__(self, in_channels, out_channels, bn_act=True, **kwargs):
        super().__init__()
        self.use_bn_act = bn_act
        # The bias is dropped whenever batch norm follows the convolution.
        self.conv = nn.Conv2d(in_channels, out_channels, bias=not bn_act, **kwargs)
        # Batch norm and activation are always registered, even when bn_act is
        # False, so the module's state_dict layout does not depend on the flag.
        self.bn = nn.BatchNorm2d(out_channels)
        self.leaky = nn.LeakyReLU(0.1)

    def forward(self, x):
        features = self.conv(x)
        if not self.use_bn_act:
            return features
        return self.leaky(self.bn(features))


# A combination of two convolutional blocks with a residual connection
# The input size will therefore be maintained through the residual block
class ResidualBlock(nn.Module):
    """
    Stack of `num_repeats` bottlenecks, each a 1x1 CNNBlock that halves the
    channels followed by a 3x3 CNNBlock that restores them.

    Parameters:
        channels (int): channels of the input (and of the output)
        use_residual (bool): add the input of every bottleneck to its output
        num_repeats (int): number of bottlenecks applied in sequence
    """

    def __init__(self, channels, use_residual=True, num_repeats=1):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                nn.Sequential(
                    CNNBlock(channels, channels // 2, kernel_size=1),
                    CNNBlock(channels // 2, channels, kernel_size=3, padding=1),
                )
                for _ in range(num_repeats)
            ]
        )
        self.use_residual = use_residual
        # Read by YOLOv3.forward to decide which outputs become route connections.
        self.num_repeats = num_repeats

    def forward(self, x):
        for bottleneck in self.layers:
            transformed = bottleneck(x)
            x = x + transformed if self.use_residual else transformed
        return x

# The last two convolutional layers leading up to the prediction for each scale.
# The output is reshaped to (batch size, anchors per scale, grid size, grid size, 5 + number of classes),
# where 5 refers to the object score and the four bounding box coordinates.
class ScalePrediction(nn.Module):
    """
    Prediction head of one scale.

    A 3x3 CNNBlock doubles the channels, then a plain 1x1 convolution
    (no batch norm / activation) emits (num_classes + 5) values for each of
    the 3 anchors of the scale.

    forward returns a tensor of shape
    (batch, 3, grid_h, grid_w, num_classes + 5).
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        widened = 2 * in_channels
        values_per_cell = (num_classes + 5) * 3
        self.pred = nn.Sequential(
            CNNBlock(in_channels, widened, kernel_size=3, padding=1),
            CNNBlock(widened, values_per_cell, bn_act=False, kernel_size=1),
        )
        self.num_classes = num_classes

    def forward(self, x):
        batch_size, grid_h, grid_w = x.shape[0], x.shape[2], x.shape[3]
        raw = self.pred(x)
        # Split the channel axis into (anchor, value) ...
        per_anchor = raw.reshape(batch_size, 3, self.num_classes + 5, grid_h, grid_w)
        # ... and move the per-anchor values to the last axis.
        return per_anchor.permute(0, 1, 3, 4, 2)


class YOLOv3(nn.Module):
    """
    YOLOv3 network assembled from the module-level `config` list.

    Parameters:
        in_channels (int): channels of the input image
        num_classes (int): number of object classes predicted per anchor

    forward returns a list with one tensor per "S" entry of `config`, each of
    shape (batch, 3, grid, grid, num_classes + 5).
    """

    def __init__(self, in_channels=3, num_classes=80):
        super().__init__()
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.layers = self._create_conv_layers()

    def forward(self, x):
        scale_outputs = []  # one entry per prediction scale
        skip_stack = []  # feature maps waiting to be concatenated after an upsample
        for layer in self.layers:
            if isinstance(layer, ScalePrediction):
                # Prediction heads branch off: their output does not feed the trunk.
                scale_outputs.append(layer(x))
            else:
                x = layer(x)
                if isinstance(layer, ResidualBlock) and layer.num_repeats == 8:
                    skip_stack.append(x)
                elif isinstance(layer, nn.Upsample):
                    # Concatenate with the most recently stored route along channels.
                    x = torch.cat([x, skip_stack.pop()], dim=1)

        return scale_outputs

    def _create_conv_layers(self):
        """Translate `config` into an nn.ModuleList, tracking the channel count."""
        modules = nn.ModuleList()
        channels = self.in_channels

        for spec in config:
            if isinstance(spec, tuple):
                filters, kernel, stride = spec
                modules.append(
                    CNNBlock(
                        channels,
                        filters,
                        kernel_size=kernel,
                        stride=stride,
                        padding=1 if kernel == 3 else 0,
                    )
                )
                channels = filters

            elif isinstance(spec, list):
                modules.append(ResidualBlock(channels, num_repeats=spec[1]))

            elif spec == "S":
                modules.extend(
                    [
                        ResidualBlock(channels, use_residual=False, num_repeats=1),
                        CNNBlock(channels, channels // 2, kernel_size=1),
                        ScalePrediction(channels // 2, num_classes=self.num_classes),
                    ]
                )
                # The trunk continues from the 1x1 block, not from the head.
                channels = channels // 2

            elif spec == "U":
                modules.append(nn.Upsample(scale_factor=2))
                # After upsampling, forward() concatenates a stored route,
                # which the original accounts for by tripling the channel count.
                channels = channels * 3

        return modules


if __name__ == "__main__":
    # Shape smoke test of the architecture on a random batch.
    # NOTE(review): this cell rebinds the notebook-level IMAGE_SIZE, model and x;
    # keep IMAGE_SIZE in sync with the Config cell.
    num_classes = 1
    IMAGE_SIZE = 416
    model = YOLOv3(num_classes=num_classes)
    x = torch.randn((2, 3, IMAGE_SIZE, IMAGE_SIZE))
    # A single forward pass is enough for all checks (previously the network
    # was run four times); gradients are not needed for a shape check.
    with torch.no_grad():
        out = model(x)
    assert out[0].shape == (2, 3, IMAGE_SIZE//32, IMAGE_SIZE//32, num_classes + 5)
    assert out[1].shape == (2, 3, IMAGE_SIZE//16, IMAGE_SIZE//16, num_classes + 5)
    assert out[2].shape == (2, 3, IMAGE_SIZE//8, IMAGE_SIZE//8, num_classes + 5)
    print("Success!")
    print("input shape:", x.shape)
    print("output shape 1:", out[0].shape)
    print("output shape 2:", out[1].shape)
    print("output shape 3:", out[2].shape)

In [None]:
# @title mAP
def get_evaluation_bboxes(
    loader,
    model,
    iou_threshold,
    anchors,
    threshold,
    box_format="midpoint",
    device="cuda",
):
    """
    Collect predicted and ground-truth boxes for every image of `loader`.

    Parameters:
        loader: iterable yielding (images, labels); labels holds one target
            tensor per prediction scale
        model: network returning one prediction tensor per scale
        iou_threshold (float): IoU threshold passed to non_max_suppression
        anchors: per-scale anchors (indexed as anchors[scale]); they are
            multiplied by the grid size of their scale before decoding
        threshold (float): confidence threshold passed to non_max_suppression
            and also used to filter the decoded ground-truth boxes
        box_format (str): box format passed to non_max_suppression
        device (str): device the images are moved to; a CUDA device request
            falls back to "cpu" when CUDA is not available

    Returns:
        (all_pred_boxes, all_true_boxes): two lists whose entries are
        [image_index] + box, where image_index counts images across batches.

    The model is put in eval mode for the pass and in train mode afterwards.
    """
    # The default "cuda" would otherwise crash on CPU-only machines.
    if str(device).startswith("cuda") and not torch.cuda.is_available():
        device = "cpu"

    # make sure model is in eval before get bboxes
    model.eval()
    train_idx = 0
    all_pred_boxes = []
    all_true_boxes = []
    for x, labels in tqdm(loader):
        x = x.to(device)

        with torch.no_grad():
            predictions = model(x)

        batch_size = x.shape[0]
        bboxes = [[] for _ in range(batch_size)]
        for scale_idx, scale_predictions in enumerate(predictions):
            S = scale_predictions.shape[2]
            anchor = torch.tensor([*anchors[scale_idx]]).to(device) * S
            boxes_scale_i = cells_to_bboxes(
                scale_predictions, anchor, S=S, is_preds=True
            )
            for idx, box in enumerate(boxes_scale_i):
                bboxes[idx] += box

        # we just want one bbox for each label, not one for each scale:
        # decode the targets of the last scale only. Grid size and anchors are
        # derived explicitly instead of relying on variables leaking out of
        # the loop above.
        last_scale = len(predictions) - 1
        true_S = labels[last_scale].shape[2]
        true_anchor = torch.tensor([*anchors[last_scale]]).to(device) * true_S
        true_bboxes = cells_to_bboxes(
            labels[last_scale], true_anchor, S=true_S, is_preds=False
        )

        for idx in range(batch_size):
            nms_boxes = non_max_suppression(
                bboxes[idx],
                iou_threshold=iou_threshold,
                threshold=threshold,
                box_format=box_format,
            )

            for nms_box in nms_boxes:
                all_pred_boxes.append([train_idx] + nms_box)

            for box in true_bboxes[idx]:
                # box[1] is the objectness score of the decoded target cell
                if box[1] > threshold:
                    all_true_boxes.append([train_idx] + box)

            train_idx += 1

    model.train()
    return all_pred_boxes, all_true_boxes


def mean_average_precision(
    pred_boxes, true_boxes, iou_threshold=0.5, box_format="midpoint", num_classes=20
):
    """
    Video explanation of this function:
    https://youtu.be/FppOzcDvaDI

    This function calculates mean average precision (mAP)

    Parameters:
        pred_boxes (list): list of lists containing all bboxes with each bbox
        specified as [train_idx, class_prediction, prob_score, *box], where
        the four box values are interpreted according to box_format
        true_boxes (list): Similar as pred_boxes except all the correct ones
        iou_threshold (float): threshold where predicted bboxes is correct
        box_format (str): "midpoint" or "corners" used to specify bboxes
        num_classes (int): number of classes

    Returns:
        tensor: 0-dim tensor with the mAP across all classes that have at
        least one ground-truth box, given a specific IoU threshold; zero when
        no class has any ground-truth box
    """

    # list storing all AP for respective classes
    average_precisions = []

    # used for numerical stability later on
    epsilon = 1e-6

    for c in range(num_classes):
        # Keep only the predictions and targets of the current class c
        detections = [detection for detection in pred_boxes if detection[1] == c]
        ground_truths = [true_box for true_box in true_boxes if true_box[1] == c]

        # If none exists for this class then we can safely skip
        total_true_bboxes = len(ground_truths)
        if total_true_bboxes == 0:
            continue

        # Group the ground truths per image once, instead of rescanning the
        # whole list for every detection.
        gts_by_image = {}
        for gt in ground_truths:
            gts_by_image.setdefault(gt[0], []).append(gt)

        # One flag per ground-truth box, set once the box has been matched,
        # e.g. {0: tensor([0, 0, 0]), 1: tensor([0, 0, 0, 0, 0])}
        matched = {
            image_idx: torch.zeros(len(image_gts))
            for image_idx, image_gts in gts_by_image.items()
        }

        # sort by box probabilities which is index 2
        detections.sort(key=lambda x: x[2], reverse=True)
        TP = torch.zeros((len(detections)))
        FP = torch.zeros((len(detections)))

        for detection_idx, detection in enumerate(detections):
            # Only compare against ground truths of the detection's own image
            ground_truth_img = gts_by_image.get(detection[0], [])

            best_iou = 0
            best_gt_idx = None  # reset per detection so no stale index is reused

            for idx, gt in enumerate(ground_truth_img):
                iou = intersection_over_union(
                    torch.tensor(detection[3:]),
                    torch.tensor(gt[3:]),
                    box_format=box_format,
                )

                if iou > best_iou:
                    best_iou = iou
                    best_gt_idx = idx

            is_true_positive = (
                best_gt_idx is not None
                and best_iou > iou_threshold
                # only detect ground truth detection once
                and matched[detection[0]][best_gt_idx] == 0
            )
            if is_true_positive:
                TP[detection_idx] = 1
                matched[detection[0]][best_gt_idx] = 1
            else:
                # low IoU, or the ground truth was already claimed
                FP[detection_idx] = 1

        TP_cumsum = torch.cumsum(TP, dim=0)
        FP_cumsum = torch.cumsum(FP, dim=0)
        recalls = TP_cumsum / (total_true_bboxes + epsilon)
        precisions = TP_cumsum / (TP_cumsum + FP_cumsum + epsilon)
        # Prepend the (recall=0, precision=1) starting point of the curve
        precisions = torch.cat((torch.tensor([1.0]), precisions))
        recalls = torch.cat((torch.tensor([0.0]), recalls))
        # torch.trapz for numerical integration
        average_precisions.append(torch.trapz(precisions, recalls))

    # No class had ground truth: avoid dividing by zero and keep the tensor
    # return type so callers can still use .item().
    if not average_precisions:
        return torch.tensor(0.0)

    return sum(average_precisions) / len(average_precisions)



In [None]:
# @title F1
def mean_average_precision_with_F1(
    pred_boxes, true_boxes, iou_threshold=0.5, box_format="midpoint", num_classes=20
):
    """
    Video explanation of this function:
    https://youtu.be/FppOzcDvaDI

    This function calculates mean average precision (mAP) together with an
    F1 curve.

    Parameters:
        pred_boxes (list): list of lists containing all bboxes with each bbox
        specified as [train_idx, class_prediction, prob_score, *box], where
        the four box values are interpreted according to box_format
        true_boxes (list): Similar as pred_boxes except all the correct ones
        iou_threshold (float): threshold where predicted bboxes is correct
        box_format (str): "midpoint" or "corners" used to specify bboxes
        num_classes (int): number of classes

    Returns:
        tuple: (mAP, f1_score)
            mAP: 0-dim tensor, mean AP over the classes that have ground
            truth (zero when there are none)
            f1_score: 1-D tensor with the F1 value at every point of the
            precision/recall curve of the LAST class that has ground truth
            (empty when no class has ground truth); points where precision
            and recall are both zero get an F1 of 0
    """

    # list storing all AP for respective classes
    average_precisions = []

    # Defined up front so the return statement is valid even when every
    # class is skipped.
    f1_score = torch.zeros(0)

    # used for numerical stability later on
    epsilon = 1e-6

    for c in range(num_classes):
        # Keep only the predictions and targets of the current class c
        detections = [detection for detection in pred_boxes if detection[1] == c]
        ground_truths = [true_box for true_box in true_boxes if true_box[1] == c]

        # If none exists for this class then we can safely skip
        total_true_bboxes = len(ground_truths)
        if total_true_bboxes == 0:
            continue

        # Group the ground truths per image once, instead of rescanning the
        # whole list for every detection.
        gts_by_image = {}
        for gt in ground_truths:
            gts_by_image.setdefault(gt[0], []).append(gt)

        # One flag per ground-truth box, set once the box has been matched
        matched = {
            image_idx: torch.zeros(len(image_gts))
            for image_idx, image_gts in gts_by_image.items()
        }

        # sort by box probabilities which is index 2
        detections.sort(key=lambda x: x[2], reverse=True)
        TP = torch.zeros((len(detections)))
        FP = torch.zeros((len(detections)))

        for detection_idx, detection in enumerate(detections):
            # Only compare against ground truths of the detection's own image
            ground_truth_img = gts_by_image.get(detection[0], [])

            best_iou = 0
            best_gt_idx = None  # reset per detection so no stale index is reused

            for idx, gt in enumerate(ground_truth_img):
                iou = intersection_over_union(
                    torch.tensor(detection[3:]),
                    torch.tensor(gt[3:]),
                    box_format=box_format,
                )

                if iou > best_iou:
                    best_iou = iou
                    best_gt_idx = idx

            is_true_positive = (
                best_gt_idx is not None
                and best_iou > iou_threshold
                # only detect ground truth detection once
                and matched[detection[0]][best_gt_idx] == 0
            )
            if is_true_positive:
                TP[detection_idx] = 1
                matched[detection[0]][best_gt_idx] = 1
            else:
                # low IoU, or the ground truth was already claimed
                FP[detection_idx] = 1

        TP_cumsum = torch.cumsum(TP, dim=0)
        FP_cumsum = torch.cumsum(FP, dim=0)
        recalls = TP_cumsum / (total_true_bboxes + epsilon)
        precisions = TP_cumsum / (TP_cumsum + FP_cumsum + epsilon)
        # Prepend the (recall=0, precision=1) starting point of the curve
        precisions = torch.cat((torch.tensor([1.0]), precisions))
        recalls = torch.cat((torch.tensor([0.0]), recalls))
        # torch.trapz for numerical integration
        average_precisions.append(torch.trapz(precisions, recalls))

        # F1 Score: 2PR / (P + R). Where P + R == 0 the ratio is 0/0, so those
        # points are set to 0 instead of NaN; all other points are unchanged.
        pr_sum = precisions + recalls
        safe_pr_sum = torch.where(pr_sum > 0, pr_sum, torch.ones_like(pr_sum))
        f1_score = torch.where(
            pr_sum > 0,
            (2 * (precisions * recalls)) / safe_pr_sum,
            torch.zeros_like(pr_sum),
        )

    # No class had ground truth: avoid dividing by zero and keep the tensor
    # return type.
    if not average_precisions:
        return torch.tensor(0.0), f1_score

    return sum(average_precisions) / len(average_precisions), f1_score

In [None]:
# @title YOLOv3 Loss function
"""
Implementation of Yolo Loss Function similar to the one in Yolov3 paper,
the difference from what I can tell is I use CrossEntropy for the classes
instead of BinaryCrossEntropy.
"""

"""
YOLO (v3) architecture was optimized on a combination of four losses:
no object loss, object loss, box coordinate loss, and class loss.
"""


class YoloLoss(nn.Module):
    """
    YOLOv3 loss for the predictions of ONE scale.

    The loss is the weighted sum of four terms:
        no-object loss: BCE-with-logits on the objectness of cells whose
            target objectness is 0
        object loss: MSE between sigmoid(objectness) and the IoU of the
            decoded box with its target, on cells whose target objectness is 1
        box loss: MSE on (sigmoid(x), sigmoid(y), raw w, raw h) against
            (x, y, log(w / anchor), log(h / anchor))
        class loss: cross entropy on the class logits

    Cells whose target objectness is -1 take part in none of the terms.
    Neither `predictions` nor `target` is modified by forward().
    """

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()
        # sigmoid function and then binary crossentropy loss
        self.bce = nn.BCEWithLogitsLoss()
        self.entropy = nn.CrossEntropyLoss()
        self.sigmoid = nn.Sigmoid()

        # Constants signifying how much to pay for each respective part of the loss
        self.lambda_class = 1
        self.lambda_noobj = 5
        self.lambda_obj = 1
        self.lambda_box = 5

    def forward(self, predictions, target, anchors):
        """
        Parameters:
            predictions: raw network output of one scale; the last axis is
                [objectness, x, y, w, h, class logits...]
            target: target tensor of the same scale; the last axis is
                [objectness, x, y, w, h, class_label]
            anchors: the 3 anchors of this scale (reshaped here to
                (1, 3, 1, 1, 2) for broadcasting)

        Returns:
            0-dim tensor with the weighted loss
        """
        # Check where obj and noobj (we ignore if target == -1)
        obj = target[..., 0] == 1  # in paper this is Iobj_i
        noobj = target[..., 0] == 0  # in paper this is Inoobj_i

        # ======================= #
        #   FOR NO OBJECT LOSS    #
        # ======================= #
        # Only the objectness logit is penalised here; the target is all
        # zeros because these anchors should predict "no object".
        no_object_loss = self.bce(
            (predictions[..., 0:1][noobj]), (target[..., 0:1][noobj]),
        )

        # Without a single positive cell the three remaining terms would be
        # means over empty tensors (NaN); they contribute nothing instead.
        if not obj.any():
            return self.lambda_noobj * no_object_loss

        # ==================== #
        #   FOR OBJECT LOSS    #
        # ==================== #
        # Applied only to the anchors assigned to a target box (mask `obj`).
        anchors = anchors.reshape(1, 3, 1, 1, 2)
        xy_pred = self.sigmoid(predictions[..., 1:3])  # x,y coordinates
        wh_raw = predictions[..., 3:5]
        box_preds = torch.cat([xy_pred, torch.exp(wh_raw) * anchors], dim=-1)
        ious = intersection_over_union(box_preds[obj], target[..., 1:5][obj]).detach()
        object_loss = self.mse(self.sigmoid(predictions[..., 0:1][obj]), ious * target[..., 0:1][obj])

        # ======================== #
        #   FOR BOX COORDINATES    #
        # ======================== #
        # Built from fresh tensors: the previous in-place writes into
        # `predictions` and `target` corrupted the caller's tensors (a second
        # call on the same target took the log of an already-logged value).
        box_pred_terms = torch.cat([xy_pred, wh_raw], dim=-1)
        wh_target = torch.log(
            (1e-16 + target[..., 3:5] / anchors)
        )  # width, height coordinates
        box_target_terms = torch.cat([target[..., 1:3], wh_target], dim=-1)
        box_loss = self.mse(box_pred_terms[obj], box_target_terms[obj])

        # ================== #
        #   FOR CLASS LOSS   #
        # ================== #
        # NOTE: with a single class logit this cross entropy is always zero.
        class_loss = self.entropy(
            (predictions[..., 5:][obj]), (target[..., 5][obj].long()),
        )

        return (
            self.lambda_box * box_loss
            + self.lambda_obj * object_loss
            + self.lambda_noobj * no_object_loss
            + self.lambda_class * class_loss
        )

In [None]:
# @title Utils for Training

class EarlyStopping:
    """
    Signals when a monitored loss has stopped improving.

    Call the instance once per epoch with the current loss. A call counts as
    an improvement when the loss is more than `min_delta` below the best loss
    seen so far; after `patience` consecutive calls without improvement the
    `early_stop` flag is set.
    """

    def __init__(self, patience=5, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = None  # no loss observed yet
        self.counter = 0  # consecutive calls without improvement
        self.early_stop = False

    def __call__(self, current_loss):
        # The very first loss only initialises the reference value.
        if self.best_loss is None:
            self.best_loss = current_loss
            return

        improvement = self.best_loss - current_loss
        if improvement > self.min_delta:
            self.best_loss = current_loss
            self.counter = 0
            return

        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True

def train_fn(train_loader, model, optimizer, loss_fn, scaler, scaled_anchors):
    """
    Run one training epoch under autocast with gradient scaling.

    For every batch the loss of the three scales is summed, back-propagated
    through `scaler` and applied with `optimizer`. The progress bar shows the
    running mean of the batch losses.

    Returns:
        float: mean batch loss of the epoch
    """
    model.train()  # Ensure model is in training mode
    progress = tqdm(train_loader, leave=True)
    batch_losses = []
    for images, targets in progress:
        images = images.to(DEVICE)
        scale_targets = [
            targets[0].to(DEVICE),
            targets[1].to(DEVICE),
            targets[2].to(DEVICE),
        ]

        with torch.cuda.amp.autocast():
            out = model(images)
            # Sum the per-scale losses in scale order 0, 1, 2.
            loss = loss_fn(out[0], scale_targets[0], scaled_anchors[0])
            for scale in (1, 2):
                loss = loss + loss_fn(out[scale], scale_targets[scale], scaled_anchors[scale])

        batch_losses.append(loss.item())
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # update progress bar
        mean_loss = sum(batch_losses) / len(batch_losses)
        progress.set_postfix(loss=mean_loss)

    return mean_loss  # Return the mean loss for this epoch

def validate_fn(val_loader, model, loss_fn, scaled_anchors):
    """
    Compute the mean loss of `model` over `val_loader` without gradients.

    The model is switched to eval mode and left in it. The progress bar shows
    the running mean of the batch losses.

    Returns:
        float: mean batch loss over the loader
    """
    model.eval()  # Ensure model is in evaluation mode
    batch_losses = []
    with torch.no_grad():  # Disable gradient calculation for validation
        progress = tqdm(val_loader, leave=True)
        for images, targets in progress:
            images = images.to(DEVICE)
            scale_targets = [
                targets[0].to(DEVICE),
                targets[1].to(DEVICE),
                targets[2].to(DEVICE),
            ]

            out = model(images)
            # Sum the per-scale losses in scale order 0, 1, 2.
            loss = loss_fn(out[0], scale_targets[0], scaled_anchors[0])
            for scale in (1, 2):
                loss = loss + loss_fn(out[scale], scale_targets[scale], scaled_anchors[scale])

            batch_losses.append(loss.item())
            mean_val_loss = sum(batch_losses) / len(batch_losses)
            progress.set_postfix(val_loss=mean_val_loss)

    return mean_val_loss

In [None]:
# @title Train
"""
Training cell: trains the YOLOv3 model on the Kvasir-SEG dataset
"""

# NOTE(review): torch, torch.optim, tqdm and warnings are already imported in
# the "Imports" cell at the top of the notebook; the repeats below are
# redundant once that cell has run.
import torch
import torch.optim as optim

from tqdm import tqdm

import warnings
# Silences every Python warning for the rest of the session.
warnings.filterwarnings("ignore")

# Same flag as in the "Imports" cell. NOTE(review): cuDNN benchmark mode may
# select different algorithms between runs, which works against the
# reproducibility the seeding in that cell aims for - confirm it is wanted.
torch.backends.cudnn.benchmark = True

# Echo the anchor configuration used for this training run.
print(ANCHORS)

class EarlyStopping:
    """
    Signals when a monitored loss has stopped improving.

    Call the instance once per epoch with the current loss. A call counts as
    an improvement when the loss is more than `min_delta` below the best loss
    seen so far; after `patience` consecutive calls without improvement the
    `early_stop` flag is set.
    """

    def __init__(self, patience=5, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = None  # no loss observed yet
        self.counter = 0  # consecutive calls without improvement
        self.early_stop = False

    def __call__(self, current_loss):
        # The very first loss only initialises the reference value.
        if self.best_loss is None:
            self.best_loss = current_loss
            return

        improvement = self.best_loss - current_loss
        if improvement > self.min_delta:
            self.best_loss = current_loss
            self.counter = 0
            return

        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True

def train_fn(train_loader, model, optimizer, loss_fn, scaler, scaled_anchors):
    """
    Run one training epoch under autocast with gradient scaling.

    For every batch the loss of the three scales is summed, back-propagated
    through `scaler` and applied with `optimizer`. The progress bar shows the
    running mean of the batch losses.

    Returns:
        float: mean batch loss of the epoch
    """
    model.train()  # Ensure model is in training mode
    progress = tqdm(train_loader, leave=True)
    batch_losses = []
    for images, targets in progress:
        images = images.to(DEVICE)
        scale_targets = [
            targets[0].to(DEVICE),
            targets[1].to(DEVICE),
            targets[2].to(DEVICE),
        ]

        with torch.cuda.amp.autocast():
            out = model(images)
            # Sum the per-scale losses in scale order 0, 1, 2.
            loss = loss_fn(out[0], scale_targets[0], scaled_anchors[0])
            for scale in (1, 2):
                loss = loss + loss_fn(out[scale], scale_targets[scale], scaled_anchors[scale])

        batch_losses.append(loss.item())
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # update progress bar
        mean_loss = sum(batch_losses) / len(batch_losses)
        progress.set_postfix(loss=mean_loss)

    return mean_loss  # Return the mean loss for this epoch

def validate_fn(val_loader, model, loss_fn, scaled_anchors):
    """
    Compute the mean loss of `model` over `val_loader` without gradients.

    The model is switched to eval mode and left in it. The progress bar shows
    the running mean of the batch losses.

    Returns:
        float: mean batch loss over the loader
    """
    model.eval()  # Ensure model is in evaluation mode
    batch_losses = []
    with torch.no_grad():  # Disable gradient calculation for validation
        progress = tqdm(val_loader, leave=True)
        for images, targets in progress:
            images = images.to(DEVICE)
            scale_targets = [
                targets[0].to(DEVICE),
                targets[1].to(DEVICE),
                targets[2].to(DEVICE),
            ]

            out = model(images)
            # Sum the per-scale losses in scale order 0, 1, 2.
            loss = loss_fn(out[0], scale_targets[0], scaled_anchors[0])
            for scale in (1, 2):
                loss = loss + loss_fn(out[scale], scale_targets[scale], scaled_anchors[scale])

            batch_losses.append(loss.item())
            mean_val_loss = sum(batch_losses) / len(batch_losses)
            progress.set_postfix(val_loss=mean_val_loss)

    return mean_val_loss

def main():
    """
    Train YOLOv3 with early stopping, save the model and plot the losses.

    Steps:
        1. Build the model, Adam optimizer, YoloLoss and gradient scaler.
        2. Optionally restore a checkpoint (LOAD_MODEL / CHECKPOINT_FILE).
        3. Train for at most NUM_EPOCHS epochs; after every epoch the loss on
           `train_eval_loader` is fed to EarlyStopping(patience=5,
           min_delta=0.01), which ends training once it stops improving.
        4. Save the whole model to 'second_model_complete.pth' and plot the
           per-epoch training and validation losses.
    """
    model = YOLOv3(num_classes=NUM_CLASSES).to(DEVICE)

    # Parameter counts
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(f"Total Parameters: {total_params}")
    print(f"Trainable Parameters: {trainable_params}")

    optimizer = optim.Adam(
        model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY
    )
    loss_fn = YoloLoss()
    scaler = torch.cuda.amp.GradScaler()

    # The second loader returned by get_loaders is not needed for training.
    train_loader, _, train_eval_loader = get_loaders(
        train_images_list, train_labels_list, val_images_list, val_labels_list
    )

    if LOAD_MODEL:
        load_checkpoint(
            CHECKPOINT_FILE, model, optimizer, LEARNING_RATE
        )

    # Anchors expressed in grid-cell units of their scale: shape [3, 3, 2]
    scaled_anchors = (
        torch.tensor(ANCHORS)
        * torch.tensor(S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
    ).to(DEVICE)

    epoch_losses = []  # training loss per epoch
    val_losses = []  # validation loss per epoch
    early_stopping = EarlyStopping(patience=5, min_delta=0.01)  # Set patience and min_delta here

    # Bounded by the configured epoch budget (previously an unbounded
    # `while True` that ignored NUM_EPOCHS); early stopping usually ends the
    # loop much sooner.
    for epoch in range(NUM_EPOCHS):
        print(f"Currently epoch {epoch}")
        # training step
        mean_loss = train_fn(train_loader, model, optimizer, loss_fn, scaler, scaled_anchors)
        epoch_losses.append(mean_loss)

        # Validation step
        # NOTE(review): the loss is measured on train_eval_loader - confirm
        # against get_loaders that this loader holds held-out data.
        mean_val_loss = validate_fn(train_eval_loader, model, loss_fn, scaled_anchors)
        val_losses.append(mean_val_loss)

        # Early Stopping logic
        early_stopping(mean_val_loss)
        if early_stopping.early_stop:
            print("Early stopping triggered")
            break

    torch.save(model, 'second_model_complete.pth')

    # Plot both loss curves after training (the validation losses used to be
    # collected but never shown).
    plt.plot(epoch_losses, label="Training Loss")
    plt.plot(val_losses, label="Validation Loss")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.title("Training and Validation Loss Over Epochs")
    plt.legend()
    plt.show()

if __name__ == "__main__":
    main()

In [None]:
# @title Test using loader

# Rebuild the network and optimizer so a saved checkpoint can be restored into them.
model = YOLOv3(num_classes=NUM_CLASSES).to(DEVICE)
optimizer = optim.Adam(
    model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY
)

loss_fn = YoloLoss()
scaler = torch.cuda.amp.GradScaler()

train_loader, test_loader, train_eval_loader = get_loaders(
    train_images_list, train_labels_list, val_images_list, val_labels_list
)

# The checkpoint is always loaded in this cell.
# NOTE(review): this path differs from CHECKPOINT_FILE in the Config cell - confirm which one is intended.
load_checkpoint("/content/checkpoint.pth.tar", model, optimizer, LEARNING_RATE)

# Anchors expressed in grid-cell units of their scale: shape [3, 3, 2]
scaled_anchors = (
    torch.tensor(ANCHORS)
    * torch.tensor(S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
).to(DEVICE)

model.to(DEVICE)

# Loss on the evaluation loader, followed by a few plotted example predictions.
validate_fn(train_eval_loader, model, loss_fn, scaled_anchors)

plot_couple_examples(model, test_loader, 0.6, 0.45, scaled_anchors, with_targets=True)

print("--------------")
# NOTE(review): mAP is computed on train_loader while the plots and the class
# accuracy below use test_loader - confirm this is intended.
pred_boxes, true_boxes = get_evaluation_bboxes(
    train_loader,
    model,
    iou_threshold=NMS_IOU_THRESH,
    anchors=ANCHORS,
    threshold=CONF_THRESHOLD,
)
mapval = mean_average_precision(
    pred_boxes,
    true_boxes,
    iou_threshold=MAP_IOU_THRESH,
    box_format="midpoint",
    num_classes=NUM_CLASSES,
)
print(f"MAP: {mapval.item()}")

check_class_accuracy(model, test_loader, threshold=CONF_THRESHOLD)


In [None]:
# @title Show image two
def show_images_two(img_lst, label_lst, loops=8):
    """Display images with the bounding boxes stored in their CSV label files.

    For each of the first ``loops`` (image, label) pairs the image is read from
    ``IMG_DIR`` and the label CSV from ``LABEL_DIR``; the paths, the label
    dataframe and its box array are printed, and every row's
    ``xmin, ymin, xmax, ymax`` rectangle is drawn in red on top of the image.

    Parameters:
    - img_lst (list[str]): Image file names, index-aligned with ``label_lst``.
    - label_lst (list[str]): CSV label file names; each CSV must provide the
      columns ``class_name, xmin, ymin, xmax, ymax``.
    - loops (int): Maximum number of pairs to display.
    """
    # Never index past the end of either list when fewer than `loops` files
    # are supplied.
    num_to_show = min(loops, len(img_lst), len(label_lst))

    for i in range(num_to_show):
        img_path = os.path.join(IMG_DIR, img_lst[i])
        label_path = os.path.join(LABEL_DIR, label_lst[i])

        print(img_path)
        print(label_path)

        # Read the CSV file and show its content
        df = pd.read_csv(label_path)
        print(df)

        bounding_boxes = df[['class_name', 'xmin', 'ymin', 'xmax', 'ymax']].to_numpy()
        print(bounding_boxes)
        print(bounding_boxes.shape)

        fig, ax = plt.subplots()

        # The context manager closes the underlying file once the pixels have
        # been handed to matplotlib.
        with Image.open(img_path) as img:
            ax.imshow(img)

        # Plot each bounding box (pixel coordinates, top-left corner + size)
        for _, row in df.iterrows():
            x_min, y_min, x_max, y_max = row['xmin'], row['ymin'], row['xmax'], row['ymax']
            rect = patches.Rectangle(
                (x_min, y_min),
                x_max - x_min,
                y_max - y_min,
                linewidth=2,
                edgecolor='r',
                facecolor='none',
            )
            ax.add_patch(rect)

        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

        print("----------------------------------------------------")

In [None]:
def find_csv_files_with_more_than_2_rows(directory_path):
    """Return the names of the CSV files in a directory that have more than two data rows.

    Parameters:
    - directory_path (str): Directory that is scanned (non-recursively) for
      files whose name ends with ``.csv``.

    Returns:
    - list[str]: File names (with extension, without directory), sorted
      alphabetically, whose parsed dataframe has more than two rows.

    Files that pandas cannot read are reported on stdout and skipped.
    """
    files_with_more_than_2_rows = []

    # os.listdir returns entries in arbitrary order; sort so that the result
    # (and everything derived from it) is reproducible between runs.
    for csv_file in sorted(os.listdir(directory_path)):
        if not csv_file.endswith('.csv'):
            continue

        file_path = os.path.join(directory_path, csv_file)
        try:
            df = pd.read_csv(file_path)
        except Exception as e:
            # Best effort: one unreadable file must not abort the whole scan.
            print(f"Could not read {csv_file}: {e}")
            continue

        if len(df) > 2:  # Check if the number of rows is more than 2
            files_with_more_than_2_rows.append(csv_file)

    return files_with_more_than_2_rows

# Collect the label files that hold more than two rows, derive the matching
# image names (same stem, .jpg extension), list both and display the pairs.
directory_path = LABEL_DIR  # Replace with your directory path
files_labels = find_csv_files_with_more_than_2_rows(directory_path)
files_images = [label_name.replace(".csv", ".jpg") for label_name in files_labels]

print(len(files_labels))
for heading, names in (
    ("CSV files with more than 2 rows:", files_labels),
    ("jpg files with more than 2 rows:", files_images),
):
    print(heading)
    for file in names:
        print(file)

show_images_two(files_images, files_labels)

In [None]:
# Evaluate the pruned model stored on Google Drive: plotted predictions,
# parameter counts, mAP, F1 and the loss on train_eval_loader.
# NOTE: this cell reads from /content/drive, so the Drive-mount cell below has
# to run first; `validate_fn` comes from the "Utils for Training" cell.
loss_fn = YoloLoss()

# Anchors are stored relative to the image; multiply each scale's anchors by
# its grid size S so they are expressed in grid-cell units.
scaled_anchors = (
    torch.tensor(ANCHORS)
    * torch.tensor(S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
).to(DEVICE)

# The file holds a whole pickled model (see torch.save(model, ...)), and
# unpickling executes arbitrary code - only load files you trust.
# map_location lets a model that was saved on a GPU be loaded on a CPU-only
# runtime. NOTE(review): recent PyTorch releases default to weights_only=True,
# in which case weights_only=False must be passed to load a full model.
model = torch.load(
    '/content/drive/MyDrive/models/second_pruned_resnet.pth',
    map_location=DEVICE,
)
model = model.to(DEVICE)

train_loader, test_loader, train_eval_loader = get_loaders(
    train_images_list, train_labels_list, val_images_list, val_labels_list
)
plot_couple_examples(model, test_loader, CONF_THRESHOLD, NMS_IOU_THRESH, scaled_anchors, with_targets=True)


# Total parameters
total_params = sum(p.numel() for p in model.parameters())
# Trainable parameters
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Total Parameters: {total_params}")
print(f"Trainable Parameters: {trainable_params}")
print("--------------")
pred_boxes, true_boxes = get_evaluation_bboxes(
    test_loader,
    model,
    iou_threshold=NMS_IOU_THRESH,
    anchors=ANCHORS,
    threshold=CONF_THRESHOLD,
)
mapval = mean_average_precision(
    pred_boxes,
    true_boxes,
    iou_threshold=MAP_IOU_THRESH,
    box_format="midpoint",
    num_classes=NUM_CLASSES,
)
print(f"MAP: {mapval.item()}")

_, f1_score = mean_average_precision_with_F1(
    pred_boxes,
    true_boxes,
    iou_threshold=MAP_IOU_THRESH,
    box_format="midpoint",
    num_classes=NUM_CLASSES,
)
print(f"F1: {f1_score.item()}")

# NOTE(review): the loss is computed on train_eval_loader although the message
# calls it a validation loss - confirm which split is intended.
initial_loss = validate_fn(train_eval_loader, model, loss_fn, scaled_anchors)
print(f"Initial validation loss: {initial_loss:.4f}")

In [None]:
# Mount Google Drive at /content/drive (Google Colab only).
# NOTE: cells that read or write paths below /content/drive need this cell to
# have been executed first.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# @title Graph: LOSS PER EPOCH
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

# Path to your tensorboard log directory (change accordingly)
log_dir = '/content/lightning_logs/resnet_pruning/version_0/'

# Load TensorBoard logs
event_acc = EventAccumulator(log_dir)
event_acc.Reload()

# Scalar tags that are actually present in the log; asking Scalars() for a tag
# that was never logged raises KeyError.
available_scalars = event_acc.Tags()['scalars']

# Retrieve the training loss (use the name of the scalar you logged)
train_loss = event_acc.Scalars('train_loss')
# The validation loss is optional: fall back to an empty series when the run
# did not log it, so the training curve can still be plotted.
val_loss = event_acc.Scalars('val_loss') if 'val_loss' in available_scalars else []

# Extract the steps and loss values
train_steps = [x.step for x in train_loss]
train_losses = [x.value for x in train_loss]

val_steps = [x.step for x in val_loss]
val_losses = [x.value for x in val_loss]

# Plotting training loss
plt.plot(train_steps, train_losses, label='Train Loss')

# Plotting validation loss (if available)
if val_loss:
    plt.plot(val_steps, val_losses, label='Validation Loss')

plt.xlabel('Step')
plt.ylabel('Loss')
plt.title('Training and Validation Loss per Step')
plt.legend()
plt.show()

In [None]:
 # @title Test using loader
# Re-evaluate whatever model is currently bound to `model`: plotted
# predictions, parameter counts, mAP and class accuracy.
model.to(DEVICE)
# Use the config thresholds instead of repeating their literal values here.
plot_couple_examples(
    model, test_loader, CONF_THRESHOLD, NMS_IOU_THRESH, scaled_anchors, with_targets=True
)

# Recompute both counts from the current model; relying on a `total_params`
# left over from an earlier cell would report a stale number after pruning.
total_params = sum(p.numel() for p in model.parameters())
# Trainable parameters
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Total Parameters: {total_params}")
print(f"Trainable Parameters: {trainable_params}")

print("--------------")
# NOTE(review): the boxes below come from train_loader, so the printed mAP is a
# training-set figure - confirm that this is intended.
pred_boxes, true_boxes = get_evaluation_bboxes(
    train_loader,
    model,
    iou_threshold=NMS_IOU_THRESH,
    anchors=ANCHORS,
    threshold=CONF_THRESHOLD,
)
mapval = mean_average_precision(
    pred_boxes,
    true_boxes,
    iou_threshold=MAP_IOU_THRESH,
    box_format="midpoint",
    num_classes=NUM_CLASSES,
)
print(f"MAP: {mapval.item()}")

check_class_accuracy(model, test_loader, threshold=CONF_THRESHOLD)


In [None]:
# @title Utils for Training

class EarlyStopping:
    """Signals when a monitored loss has stopped improving.

    Call the instance once per evaluation with the current loss. A loss counts
    as an improvement when it is lower than the best loss seen so far by more
    than ``min_delta``. After ``patience`` consecutive calls without an
    improvement, ``early_stop`` becomes True.
    """

    def __init__(self, patience=5, min_delta=0):
        self.patience, self.min_delta = patience, min_delta
        self.best_loss = None   # best loss seen so far (None until first call)
        self.counter = 0        # consecutive calls without improvement
        self.early_stop = False

    def __call__(self, current_loss):
        # The very first loss only initialises the reference value.
        if self.best_loss is None:
            self.best_loss = current_loss
            return

        improved = (self.best_loss - current_loss) > self.min_delta
        if improved:
            self.best_loss = current_loss
            self.counter = 0
            return

        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True

def train_fn(train_loader, model, optimizer, loss_fn, scaler, scaled_anchors):
    """Run one training epoch with mixed precision and return the mean batch loss.

    Parameters:
    - train_loader: Iterable of ``(x, y)`` batches, where ``y`` holds one target
      per prediction scale (``y[0]``, ``y[1]``, ``y[2]``).
    - model: Network returning one output per scale (``out[0..2]``).
    - optimizer: Optimizer stepped through ``scaler``.
    - loss_fn: Callable ``loss_fn(prediction, target, anchors)``.
    - scaler: Gradient scaler used for the backward pass and optimizer step.
    - scaled_anchors: Indexable with one anchor set per scale.

    Returns:
    - float: Mean of the per-batch losses, or ``nan`` when the loader yields no
      batches.
    """
    model.train()  # Ensure model is in training mode
    loop = tqdm(train_loader, leave=True)
    losses = []
    # Defined up front so an empty loader returns nan instead of raising
    # UnboundLocalError at the return statement.
    mean_loss = float("nan")

    for x, y in loop:
        x = x.to(DEVICE)
        y0, y1, y2 = (
            y[0].to(DEVICE),
            y[1].to(DEVICE),
            y[2].to(DEVICE),
        )

        # Forward pass and loss under autocast (mixed precision)
        with torch.cuda.amp.autocast():
            out = model(x)
            loss = (
                loss_fn(out[0], y0, scaled_anchors[0])
                + loss_fn(out[1], y1, scaled_anchors[1])
                + loss_fn(out[2], y2, scaled_anchors[2])
            )

        losses.append(loss.item())
        optimizer.zero_grad()
        # Scale the loss before backward; the scaler performs the optimizer
        # step and then updates its scale factor.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # update progress bar
        mean_loss = sum(losses) / len(losses)
        loop.set_postfix(loss=mean_loss)

    return mean_loss  # Return the mean loss for this epoch

def validate_fn(val_loader, model, loss_fn, scaled_anchors):
    """Compute the mean loss of ``model`` over a loader without updating it.

    Parameters:
    - val_loader: Iterable of ``(x, y)`` batches, where ``y`` holds one target
      per prediction scale (``y[0]``, ``y[1]``, ``y[2]``).
    - model: Network returning one output per scale (``out[0..2]``); it is
      switched to evaluation mode and left in that mode.
    - loss_fn: Callable ``loss_fn(prediction, target, anchors)``.
    - scaled_anchors: Indexable with one anchor set per scale.

    Returns:
    - float: Mean of the per-batch losses, or ``nan`` when the loader yields no
      batches.
    """
    model.eval()  # Ensure model is in evaluation mode
    val_losses = []
    # Defined up front so an empty loader returns nan instead of raising
    # UnboundLocalError at the return statement.
    mean_val_loss = float("nan")

    with torch.no_grad():  # Disable gradient calculation for validation
        loop = tqdm(val_loader, leave=True)
        for x, y in loop:
            x = x.to(DEVICE)
            y0, y1, y2 = y[0].to(DEVICE), y[1].to(DEVICE), y[2].to(DEVICE)

            out = model(x)
            loss = (loss_fn(out[0], y0, scaled_anchors[0]) +
                    loss_fn(out[1], y1, scaled_anchors[1]) +
                    loss_fn(out[2], y2, scaled_anchors[2]))

            val_losses.append(loss.item())
            mean_val_loss = sum(val_losses) / len(val_losses)
            loop.set_postfix(val_loss=mean_val_loss)

    return mean_val_loss

In [None]:
# @title Taylor Pruning
import torch
import torch.nn as nn

def taylor_prune_conv_layer(layer, model, loss_fn, validloader, device, scaled_anchors, num_filters_to_prune):
    """Prune the lowest-ranked filters of ``layer`` using a gradient * activation score.

    Hooks on ``layer`` record its output and the gradient of the loss with
    respect to that output. For every batch of ``validloader`` the score of a
    filter is ``|mean over H, W of (gradient * activation)|``, L2-normalised
    across the filters of each sample and averaged over the batch; the scores
    of all batches are summed. The ``num_filters_to_prune`` filters with the
    smallest total score are removed.

    Parameters:
    - layer (nn.Conv2d): Layer to prune; it has to be used in ``model``'s
      forward pass.
    - model: Network returning one output per scale (``out[0..2]``). It is
      switched to evaluation mode and its gradients are overwritten.
    - loss_fn: Callable ``loss_fn(prediction, target, anchors)``.
    - validloader: Iterable of ``(x, y)`` batches with one target per scale.
    - device: Device the batches, the mask and the new layer are placed on.
    - scaled_anchors: Indexable with one anchor set per scale.
    - num_filters_to_prune (int): Number of filters to remove.

    Returns:
    - (nn.Conv2d, torch.Tensor): A new layer holding only the kept filters and
      a boolean mask over the original filters (True = kept).

    Raises:
    - ValueError: If no rank was computed (e.g. the hooks never fired).
    """
    filter_ranks = []
    activations = []
    grad_index = 0

    # Forward hook: remember the layer output of the current forward pass.
    def capture_activations(module, input, output):
        activations.append(output.detach())

    # Backward hook helper: turn gradient and stored activation into a rank.
    def compute_rank(grad):
        nonlocal grad_index
        act_idx = len(activations) - grad_index - 1
        # Compute rank: [B, C, H, W] -> [B, C]
        rank = torch.abs(torch.mean((grad * activations[act_idx]), dim=(2, 3)))
        # Normalize each filter rank by the L2 norm over the filters of a sample
        div = torch.sqrt(torch.sum(rank ** 2, dim=1, keepdim=True))
        # A sample whose ranks are all zero has a zero norm; clamping keeps its
        # ranks at 0 instead of turning them (and the final ranking) into NaN.
        rank = rank / div.clamp_min(torch.finfo(rank.dtype).tiny)
        # Mean over dimension 0 (batch size): [C]
        rank = torch.mean(rank, dim=0)

        if len(filter_ranks) < grad_index + 1:
            filter_ranks.append(rank)
        else:
            filter_ranks[grad_index] += rank

        grad_index += 1

    # Register hooks
    handle_forward = layer.register_forward_hook(capture_activations)
    handle_backward = layer.register_backward_hook(lambda module, grad_input, grad_output: compute_rank(grad_output[0]))

    try:
        model.eval()
        for x, y in tqdm(validloader, leave=True):
            x = x.to(device)
            y0, y1, y2 = y[0].to(device), y[1].to(device), y[2].to(device)

            # Forward pass
            out = model(x)
            loss = (loss_fn(out[0], y0, scaled_anchors[0]) +
                    loss_fn(out[1], y1, scaled_anchors[1]) +
                    loss_fn(out[2], y2, scaled_anchors[2]))

            # Backward pass: required to obtain the gradients for the ranking
            model.zero_grad()  # Clear previous gradients
            loss.backward()  # Compute gradients

            # Clear activations for next batch
            activations.clear()
            grad_index = 0
    finally:
        # Always detach the hooks, even when a batch fails; otherwise they
        # would keep firing on every later forward/backward pass of the model.
        handle_forward.remove()
        handle_backward.remove()

    # Compute final ranks
    if filter_ranks:
        final_ranks = torch.stack(filter_ranks).sum(dim=0)
    else:
        raise ValueError("No filter ranks computed. Check if the layer is part of the model's forward pass.")

    # Get indices of filters to prune (lowest rank first)
    sorted_indices = torch.argsort(final_ranks)
    filters_to_prune = sorted_indices[:num_filters_to_prune]

    # Create mask (True = keep)
    mask = torch.ones(layer.out_channels, dtype=torch.bool, device=device)
    mask[filters_to_prune] = False

    # Create new layer
    new_out_channels = mask.sum().item()
    new_layer = nn.Conv2d(layer.in_channels, new_out_channels,
                          kernel_size=layer.kernel_size, stride=layer.stride,
                          padding=layer.padding, bias=(layer.bias is not None)).to(device)

    # Copy weights and bias of kept filters
    new_layer.weight.data = layer.weight.data[mask].clone()
    if layer.bias is not None:
        new_layer.bias.data = layer.bias.data[mask].clone()

    return new_layer, mask


In [None]:
# @title Randomly pruning

def random_prune_conv_layer(layer, num_filters_to_prune):
    """Remove randomly chosen filters from a convolution layer.

    Parameters:
    - layer (nn.Conv2d): Layer to prune.
    - num_filters_to_prune (int): Number of filters to drop.

    Returns:
    - (nn.Conv2d, torch.Tensor): A new layer that holds copies of the kept
      filters (and bias entries, if the layer has a bias), and a boolean mask
      over the original filters (True = kept) on the layer's device.
    """
    weight = layer.weight
    n_filters = weight.size(0)

    # Draw the filters to drop from a random permutation of all filter indices.
    drop_idx = torch.randperm(n_filters)[:num_filters_to_prune]

    keep_mask = torch.ones(n_filters, dtype=torch.bool, device=weight.device)
    keep_mask[drop_idx] = False

    has_bias = layer.bias is not None
    pruned = nn.Conv2d(
        layer.in_channels,
        keep_mask.sum().item(),
        kernel_size=layer.kernel_size,
        stride=layer.stride,
        padding=layer.padding,
        bias=has_bias,
    )

    # Carry over the parameters of the surviving filters.
    pruned.weight.data = weight.data[keep_mask].clone()
    if has_bias:
        pruned.bias.data = layer.bias.data[keep_mask].clone()

    return pruned, keep_mask

In [None]:
from dataclasses import replace
# @title Utils for Pruning (New Version)

def update_lightning_model(lightning_model):
    """Report the parameter counts of the wrapped model and re-wrap it.

    Prints the total and the trainable number of parameters of
    ``lightning_model.model`` and returns a new ``LightningYOLOModule`` built
    from that same model and the notebook-level ``scaled_anchors``.
    """
    net = lightning_model.model

    total_params = 0
    trainable_params = 0
    for param in net.parameters():
        count = param.numel()
        total_params += count
        if param.requires_grad:
            trainable_params += count

    print(f"Total Parameters: {total_params}")
    print(f"Trainable Parameters: {trainable_params}")

    return LightningYOLOModule(net, scaled_anchors)

def prune_conv_layer(layer, num_filters_to_prune):
    """Remove the filters with the smallest L1 norm from a convolution layer.

    Parameters:
    - layer (nn.Conv2d): Layer to prune.
    - num_filters_to_prune (int): Number of filters to drop; values below zero
      are treated as zero.

    Returns:
    - (nn.Conv2d, torch.Tensor): A new layer that holds copies of the kept
      filters (and bias entries, if the layer has a bias), and a boolean mask
      over the original filters (True = kept) on the layer's device.
    """
    weight = layer.weight.data  # [out_channels, in_channels, kernel_size, kernel_size]
    l1_norm = weight.abs().sum(dim=(1, 2, 3))  # one L1 norm per filter: [out_channels]
    sorted_indices = torch.argsort(l1_norm)

    # A negative count would turn the slice below into "all but the last k",
    # pruning almost every filter; clamp it to "prune nothing" instead.
    num_filters_to_prune = max(0, int(num_filters_to_prune))
    filters_to_prune = sorted_indices[:num_filters_to_prune]

    # Build the mask on the weight's device (as random_prune_conv_layer does):
    # a CPU mask cannot be indexed with the index tensor of a GPU layer.
    mask = torch.ones(weight.size(0), dtype=torch.bool, device=weight.device)
    mask[filters_to_prune] = False

    new_out_channels = mask.sum().item()
    new_layer = nn.Conv2d(layer.in_channels, new_out_channels,
                          kernel_size=layer.kernel_size, stride=layer.stride,
                          padding=layer.padding, bias=(layer.bias is not None))

    # Copy weights and bias of the kept filters
    new_layer.weight.data = layer.weight.data[mask].clone()
    if layer.bias is not None:
        new_layer.bias.data = layer.bias.data[mask].clone()

    return new_layer, mask

def get_current_layer(model, updated_layer):
    """Return the first module of ``model`` that ``layers_equal`` considers equal to ``updated_layer``.

    Modules are visited in the order of ``model.modules()``; ``None`` is
    returned when no module matches.
    """
    matches = (module for module in model.modules() if layers_equal(module, updated_layer))
    return next(matches, None)

def replace_layer_in_model(model, old_layer, new_layer):
    """Swap ``old_layer`` for ``new_layer`` inside ``model`` (in place).

    The module is located by identity among ``model.named_modules()``; only the
    first match is replaced, and nothing happens when ``old_layer`` is not part
    of the model.
    """
    for qualified_name, candidate in model.named_modules():
        if candidate is not old_layer:
            continue

        # "a.b.c" -> owner path "a.b", attribute "c"; a top-level name has an
        # empty owner path and is set directly on the model.
        parent_path, _, attr_name = qualified_name.rpartition('.')
        owner = model
        if parent_path:
            for part in parent_path.split('.'):
                owner = getattr(owner, part)

        setattr(owner, attr_name, new_layer)
        return

def update_batchnorm_for_pruned_conv(bn_layer, filter_indices_to_keep):
    """
    Builds a BatchNorm2d layer that matches the output channels of a pruned Conv2d layer.

    Parameters:
    - bn_layer (nn.BatchNorm2d): The BatchNorm2d layer that followed the unpruned convolution.
    - filter_indices_to_keep (torch.Tensor): Boolean mask with one entry per original
      channel; True marks a channel that was kept during pruning.

    Returns:
    - new_bn (nn.BatchNorm2d): A new layer with one channel per kept filter. It carries
      over the original layer's configuration (eps, momentum, affine,
      track_running_stats), the parameters and running statistics of the kept
      channels, the batch counter and the train/eval mode.
    """
    keep = filter_indices_to_keep
    pruned_out_channels = keep.sum().item()

    # Re-create the layer with the same hyper-parameters; the constructor
    # defaults would silently change eps/momentum of a customised layer.
    new_bn = nn.BatchNorm2d(
        pruned_out_channels,
        eps=bn_layer.eps,
        momentum=bn_layer.momentum,
        affine=bn_layer.affine,
        track_running_stats=bn_layer.track_running_stats,
    )

    # Copy only the parameters corresponding to the remaining filters
    # (weight/bias exist only for affine layers).
    if bn_layer.affine:
        new_bn.weight.data = bn_layer.weight.data[keep]
        new_bn.bias.data = bn_layer.bias.data[keep]

    # Running statistics exist only when they are tracked.
    if bn_layer.track_running_stats:
        new_bn.running_mean = bn_layer.running_mean[keep]
        new_bn.running_var = bn_layer.running_var[keep]
        new_bn.num_batches_tracked = bn_layer.num_batches_tracked.clone()

    # A fresh module starts in training mode; keep the mode of the layer it
    # replaces so an evaluated model keeps using the running statistics.
    new_bn.train(bn_layer.training)

    return new_bn


def layers_are_equal(layer1, layer2):
    """
    Check if two layers have the same type and exactly equal parameters.

    :param layer1: The first layer to compare.
    :param layer2: The second layer to compare.
    :return: True if both layers are of the same type, own the same number of
        parameter tensors and every pair of tensors is equal in shape and
        values; False otherwise.
    """
    if type(layer1) != type(layer2):
        return False

    params1 = list(layer1.parameters())
    params2 = list(layer2.parameters())

    # zip() stops at the shorter sequence, so without this check a layer with
    # a bias would compare equal to the same layer without one.
    if len(params1) != len(params2):
        return False

    return all(torch.equal(p1, p2) for p1, p2 in zip(params1, params2))

# def replace_layer_by_instance(model, old_layer, new_layer, parent_name=""):
#     """
#     Replace CNNBlock, ResidualBlock, and their layers

#     :param model: The model or layer containing the layers.
#     :param old_layer: The old layer instance to replace.
#     :param new_layer: The new layer to insert.
#     """
#     replaced = False

#     for name, module in model.layers.named_children():
#         full_name = f"{parent_name}.{name}" if parent_name else name
#         if isinstance(module, CNNBlock)
#             for child_name, layer in module.named_children():
#                 child_full_name = f"{full_name}.{child_name}"
#                 if layers_are_equal(layer, old_layer):
#                     setattr(model, child_full_name, new_layer)
#                     replaced = True
#                     break
#         elif isinstance(module, ResidualBlock):
#             for sequential in module.layers:
#                 for block in sequential:
#                     for name, layer in block.named_children():
#                         if layers_are_equal(layer, old_layer):
#                             setattr(model, name, new_layer)
#                             replaced = True
#                             break

#         if replaced:
#             break

# def replace_layer_by_instance(model, old_layer, new_layer):
#     """
#     Recursively replaces a layer within a nested ModuleList or Sequential model
#     by directly comparing the layer instances.

#     :param model: The model or layer containing the layers.
#     :param old_layer: The old layer instance to replace.
#     :param new_layer: The new layer to insert.
#     """
#     for name, layer in model.named_children():
#         if layer is old_layer:
#             # Direct comparison of the layer instance
#             setattr(model, name, new_layer)
#             return True  # Early exit when the layer is replaced
#         elif isinstance(layer, (nn.ModuleList, nn.Sequential)):
#             # Recursively enter nested structures
#             if replace_layer_by_instance(layer, old_layer, new_layer):
#                 return True  # Early exit when the layer is replaced
#     return False  # If no match is found

def layers_equal(layer1, layer2):
    """Return True when two layers have the same type, (approximately) equal
    parameters and, for Conv2d layers, the same convolution configuration.

    Parameters are compared pairwise with ``torch.allclose`` after moving the
    first layer's tensor to the second one's device.
    """
    if type(layer1) != type(layer2):
        return False

    def _param_count(layer):
        return sum(p.numel() for p in layer.parameters())

    # Cheap rejection before comparing tensors.
    if _param_count(layer1) != _param_count(layer2):
        return False

    for first, second in zip(layer1.parameters(), layer2.parameters()):
        first = first.to(second.device)
        if first.shape != second.shape or not torch.allclose(first, second):
            return False

    if isinstance(layer1, nn.Conv2d):
        for attr in ('in_channels', 'out_channels', 'kernel_size', 'stride',
                     'padding', 'dilation', 'groups'):
            if getattr(layer1, attr) != getattr(layer2, attr):
                return False

        # Two existing biases were already compared as parameters above; only
        # a missing bias on either side still needs to be checked.
        bias1, bias2 = layer1.bias, layer2.bias
        if (bias1 is None or bias2 is None) and bias1 != bias2:
            return False

    return True

def update_next_layer(layer, mask, model=None, loss_fn=None, train_eval_loader=None, device=None, scaled_anchors=None, prev_model=None):
    """Adapt the layer that follows a pruned convolution to the reduced channel count.

    ``mask`` is a boolean tensor over the channels produced by the pruned layer
    (True = kept). What happens depends on the type of ``layer``:

    - ``CNNBlock``: its ``conv`` is rebuilt to accept only the kept input
      channels and swapped into ``model``; the same block is returned.
    - ``nn.Conv2d``: a new convolution with ``mask.sum()`` input channels and the
      matching slice of the weights is returned (the model is not modified).
    - ``ScalePrediction``: the first convolution (``pred[0].conv``) is rebuilt in
      place; the same module is returned.
    - ``nn.Linear``: the mask is repeated for the features belonging to each
      channel and a new linear layer with the kept input features is returned.
    - ``ResidualBlock``: for every repeat, the first convolution's inputs are
      sliced with ``mask`` and the second convolution is pruned (criterion chosen
      by the global ``CRITERION``) so that it keeps ``mask.sum()`` filters; its
      batch norm is rebuilt. The module following the block in ``model.layers``
      is then updated recursively, and for ``model.layers[6]`` / ``model.layers[8]``
      ``prune_concat_layer`` is called. The same block is returned.

    Any other layer type yields ``None``.

    ``loss_fn``, ``train_eval_loader``, ``device``, ``scaled_anchors`` and
    ``prev_model`` are only used by the ``ResidualBlock`` branch (Taylor pruning
    and the recursive call).
    """
    if isinstance(layer, CNNBlock):
        # Only the inner convolution consumes the pruned channels.
        new_layer = update_next_layer(layer.conv, mask, model)
        replace_layer_in_model(model, layer.conv, new_layer)
        return layer
    if isinstance(layer, nn.Conv2d):
        new_in_channels = mask.sum().item()
        new_layer = nn.Conv2d(new_in_channels, layer.out_channels,
                              kernel_size=layer.kernel_size, stride=layer.stride,
                              padding=layer.padding, bias=(layer.bias is not None))

        # Slice along dim 1 (input channels); the output filters are untouched.
        new_layer.weight.data = layer.weight.data[:, mask, :, :].clone()
        if layer.bias is not None:
            new_layer.bias.data = layer.bias.data.clone()

        return new_layer

    elif isinstance(layer, ScalePrediction):
        first_conv = layer.pred[0].conv
        new_in_channels = mask.sum().item()
        new_layer = nn.Conv2d(new_in_channels, first_conv.out_channels,
                              kernel_size=first_conv.kernel_size, stride=first_conv.stride,
                              padding=first_conv.padding, bias=(first_conv.bias is not None))

        new_layer.weight.data = first_conv.weight.data[:, mask, :, :].clone()
        if first_conv.bias is not None:
            new_layer.bias.data = first_conv.bias.data.clone()

        layer.pred[0].conv = new_layer
        return layer

    elif isinstance(layer, nn.Linear):
        input_dim = layer.weight.size(1) // len(mask) # nn.Linear.weight == [out_features, in_features]; layer.weight.size(1) == num of in_features
        # Each channel contributes `input_dim` consecutive input features.
        mask = mask.repeat_interleave(input_dim)
        new_layer = nn.Linear(mask.sum().item(), layer.out_features)
        new_layer.weight.data = layer.weight.data[:, mask].clone()
        if layer.bias is not None:
            new_layer.bias.data = layer.bias.data.clone()

        return new_layer

    elif isinstance(layer, ResidualBlock):
        block = layer
        # Position of this block in model.layers; used to find the same block
        # in prev_model for Taylor pruning (stays 0 if the block is not found).
        current_idx = 0
        for idx, module in enumerate(model.layers):
            if module is layer:
                current_idx = idx
                break
        # NOTE(review): the same incoming `mask` is applied to the first conv of
        # every repeat, while each repeat's second conv gets its own `new_mask`
        # - confirm this matches how ResidualBlock combines its inputs.
        for idx, sequential in enumerate(block.layers):
            # Update the first conv layer of two CNNBlocks
            sequential[0].conv = update_next_layer(sequential[0].conv, mask, model)
            last_conv = sequential[1].conv
            last_bn = sequential[1].bn
            # prune the second layer as normal pruning, down to mask.sum() filters
            num_to_prune = int(last_conv.out_channels - mask.sum().item())
            if CRITERION == "Taylor":
                print("-----Taylor Pruning-----")
                # NOTE(review): this branch needs `prev_model`; with the default
                # (None) the attribute access below fails.
                prev_last_conv = ((prev_model.layers[current_idx]).layers[idx])[1].conv
                new_layer, new_mask = taylor_prune_conv_layer(prev_last_conv, prev_model, loss_fn, train_eval_loader, device, scaled_anchors, num_to_prune)

            elif CRITERION == "L1":
                print("-----L1 Pruning-----")
                # Update the second conv layer and its bn of two CNNBlocks
                new_layer, new_mask = prune_conv_layer(last_conv, num_to_prune)
            else: # prune filters randomly
                print("-----Random Pruning-----")
                # Update the second conv layer and its bn of two CNNBlocks
                new_layer, new_mask = random_prune_conv_layer(last_conv, num_to_prune)

            sequential[1].conv = new_layer
            sequential[1].bn = update_batchnorm_for_pruned_conv(last_bn, new_mask)

        # Update its next layer and batchNorm (if there is), using the mask of
        # the last repeat.
        # NOTE(review): `new_mask` is undefined if block.layers is empty.
        next_layer = find_next_module(model, block)
        if next_layer is not None:
            updated_next_layer = update_next_layer(next_layer, new_mask, model, loss_fn, train_eval_loader, device, scaled_anchors, prev_model)
            replace_layer_in_model(model, next_layer, updated_next_layer)

        # Update its concat layer input channels if it is 8 residual block
        if block is model.layers[6] or block is model.layers[8]:
            prune_concat_layer(model, block, new_mask)

        return block

    return None


def update_model_after_pruning(model, current_block, idx):
    """Store ``current_block`` at position ``idx`` of ``model.layers`` and return the model.

    Only the entry at ``idx`` is replaced; layers that depend on the block
    (e.g. the input channels of the following block) are not touched here.
    """
    layer_container = model.layers
    layer_container[idx] = current_block
    return model


def find_next_module(model, current_module):
    """Return the first CNNBlock, ScalePrediction or ResidualBlock that comes
    after ``current_module`` in ``model.layers``.

    ``None`` is returned when ``current_module`` is not in ``model.layers`` or
    no such module follows it.
    """
    remaining = iter(model.layers)

    # Advance the iterator until just past current_module.
    for layer in remaining:
        if layer is current_module:
            break
    else:
        return None

    for layer in remaining:
        # A repeated occurrence of current_module itself never counts as "next".
        if layer is not current_module and isinstance(layer, (CNNBlock, ScalePrediction, ResidualBlock)):
            return layer
    return None
def _slice_concat_conv_inputs(model, conv_block_idx, mask, align_end):
    """Rebuild ``model.layers[conv_block_idx].conv`` with part of its input channels removed.

    ``mask`` (boolean, True = keep) covers only ``len(mask)`` of the layer's input
    channels: the last ones when ``align_end`` is True, otherwise the first
    ones. All remaining input channels are kept. The rebuilt convolution is
    swapped into ``model`` via ``replace_layer_in_model``.
    """
    layer = model.layers[conv_block_idx].conv
    new_mask = torch.ones(layer.in_channels, dtype=torch.bool)
    if align_end:
        new_mask[-len(mask):] = mask
    else:
        new_mask[:len(mask)] = mask

    new_in_channels = new_mask.sum().item()
    new_layer = nn.Conv2d(new_in_channels, layer.out_channels,
                          kernel_size=layer.kernel_size, stride=layer.stride,
                          padding=layer.padding, bias=(layer.bias is not None))

    # Slice along dim 1 (input channels); the output filters are untouched.
    new_layer.weight.data = layer.weight.data[:, new_mask, :, :].clone()
    if layer.bias is not None:
        new_layer.bias.data = layer.bias.data.clone()
    replace_layer_in_model(model, layer, new_layer)


def update_concat_layer(model, block, mask):
    """Apply ``mask`` to the FIRST ``len(mask)`` input channels of the conv that follows a concatenation.

    ``model.layers[16]`` maps to the conv of ``model.layers[18]`` and
    ``model.layers[23]`` to the conv of ``model.layers[25]``; any other ``block``
    is ignored.
    NOTE(review): presumably these are the layers whose output forms the first
    part of the concatenated tensor - confirm against the model's forward pass.
    """
    if block is model.layers[16]:
        _slice_concat_conv_inputs(model, 18, mask, align_end=False)
    elif block is model.layers[23]:
        _slice_concat_conv_inputs(model, 25, mask, align_end=False)


def prune_concat_layer(model, block, mask):
    """Apply ``mask`` to the LAST ``len(mask)`` input channels of the conv that follows a concatenation.

    ``model.layers[6]`` maps to the conv of ``model.layers[25]`` and
    ``model.layers[8]`` to the conv of ``model.layers[18]``; any other ``block``
    is ignored.
    NOTE(review): presumably these are the residual blocks whose output is
    appended as the last part of the concatenated tensor - confirm against the
    model's forward pass.
    """
    if block is model.layers[6]:
        _slice_concat_conv_inputs(model, 25, mask, align_end=True)
    elif block is model.layers[8]:
        _slice_concat_conv_inputs(model, 18, mask, align_end=True)


# New function to prune the first layer in a residual block
def prune_first_layer(block_idx, seq_idx, prune_percentages, model, loss_fn, scaler, scaled_anchors, train_loader, validation_loader, device, original_loss, post_pruned_threshold, post_retrained_threshold, scale_pred=False):
    """Repeatedly prune the first conv of a two-conv unit and retrain when the loss degrades.

    The unit is ``model.layers[block_idx].layers[seq_idx]`` or, when ``scale_pred`` is
    true, ``model.layers[block_idx].pred``. In each iteration the first conv loses
    ``max(1, int(prune_perc * out_channels))`` filters, its BatchNorm and the following
    conv are rebuilt to match, and the model is validated on ``validation_loader``.

    Args:
        block_idx: Index into ``model.layers`` of the block to prune.
        seq_idx: Index into ``block.layers``; only used for lookup when ``scale_pred`` is false.
        prune_percentages: Fractions tried in order; the next one is used after a failed retrain.
        model: Model to prune. It is modified in place and may be replaced by a reloaded copy.
        loss_fn, scaler, scaled_anchors: Forwarded to ``train_fn`` / ``validate_fn`` / the pruning helpers.
        train_loader: Loader passed to ``train_fn`` during retraining.
        validation_loader: Loader passed to ``validate_fn`` (and to the Taylor criterion).
        device: Device the rebuilt BatchNorm is moved to; also forwarded to helpers.
        original_loss: Reference loss that both thresholds are multiplied with.
        post_pruned_threshold: Retraining starts when ``pruned_loss > original_loss * post_pruned_threshold``.
        post_retrained_threshold: Pruning is accepted when ``retrained_loss <= original_loss * post_retrained_threshold``.
        scale_pred: Select the ``pred`` sequential of a ScalePrediction-style block instead of ``layers[seq_idx]``.

    Returns:
        The pruned model. This can be a different object than the ``model`` argument,
        because a failed retrain reloads the checkpoint from ``'model_to_revert.pth'``.
    """
    count = 0
    prune_perc = prune_percentages[count]
    while True:
        # Checkpoint the whole model so a failed prune/retrain can be undone below.
        torch.save(model, 'model_to_revert.pth')
        # Layers are looked up again on every iteration because `model` may have been
        # reloaded from disk and the previous layer objects replaced.
        if scale_pred:
            block = model.layers[block_idx]
            conv_layer = block.pred[0].conv
            bn_layer = block.pred[0].bn
            next_layer = block.pred[1].conv
        else:
            block = model.layers[block_idx]
            sequential = block.layers[seq_idx]
            conv_layer = sequential[0].conv
            bn_layer = sequential[0].bn
            next_layer = sequential[1].conv

        # Always try to remove at least one filter; stop once that would remove all of them.
        num_filters_to_prune = max(1, int(prune_perc * conv_layer.out_channels))
        if conv_layer.out_channels <= num_filters_to_prune:
            print("First layer pruning stopped due to maximum pruning")
            break

        print("-------Pruning Start(main)-------")
        # The filter-selection criterion comes from the module-level CRITERION setting.
        if CRITERION == "Taylor":
            print("-----Taylor Pruning-----")
            new_conv1, mask = taylor_prune_conv_layer(conv_layer, model, loss_fn, validation_loader, device, scaled_anchors, num_filters_to_prune)
        elif CRITERION == "L1":
            print("-----L1 Pruning-----")
            new_conv1, mask = prune_conv_layer(conv_layer, num_filters_to_prune)
        else: # prune filters randomly
            print("-----Random Pruning-----")
            new_conv1, mask = random_prune_conv_layer(conv_layer, num_filters_to_prune)
        # new_conv1, mask = prune_conv_layer(conv_layer, num_filters_to_prune)

        # Rebuild the dependent layers from the same mask, then swap all three into the model.
        new_bn = update_batchnorm_for_pruned_conv(bn_layer, mask).to(device)
        new_conv2 = update_next_layer(next_layer, mask, model, loss_fn, validation_loader, device, scaled_anchors)
        replace_layer_in_model(model, conv_layer, new_conv1)
        # `conv_layer` still refers to the old (unpruned) object here, so this prints the original layer.
        print("before pruning:", conv_layer)
        replace_layer_in_model(model, bn_layer, new_bn)
        replace_layer_in_model(model, next_layer, new_conv2)
        conv_layer = get_current_layer(model, new_conv1)
        print("After pruning:", conv_layer)
        bn_layer = get_current_layer(model, new_bn)
        next_layer = get_current_layer(model, new_conv2)
        # Evaluate the pruned model
        # A fresh optimizer is required because the parameter set changed with the pruning.
        optimizer = optim.Adam(
            model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY
        )
        early_stopping = EarlyStopping(patience=5, min_delta=0.01)
        pruned_loss = validate_fn(validation_loader, model, loss_fn, scaled_anchors)
        print(f"Loss before pruning : {original_loss:.4f}")
        print(f"Loss after pruning : {pruned_loss:.4f}")
        epoch = 0
        if pruned_loss > original_loss * post_pruned_threshold:
            # Retrain the model
            # Runs until early stopping fires or the best validation loss meets the retrain threshold.
            while True:
                print(f"Currently epoch {epoch}")
                mean_loss = train_fn(train_loader, model, optimizer, loss_fn, scaler, scaled_anchors)
                mean_val_loss = validate_fn(validation_loader, model, loss_fn, scaled_anchors)
                # Early Stopping logic
                early_stopping(mean_val_loss)
                if early_stopping.early_stop:
                    print("Early stopping triggered")
                    early_stopping = EarlyStopping(patience=5, min_delta=0.01)
                    break
                if early_stopping.best_loss <= original_loss * post_retrained_threshold:
                    early_stopping = EarlyStopping(patience=5, min_delta=0.01)
                    print("Retrained loss reached the threshold")
                    break
                epoch += 1
            # Evaluate after retraining
            # retrained_loss = evaluate_loss(model, validation_loader, criterion, device)
            retrained_loss = validate_fn(validation_loader, model, loss_fn, scaled_anchors)
            print(f"Loss after retraining: {retrained_loss:.4f}")

            if retrained_loss <= original_loss * post_retrained_threshold:
                # Pruning successful
                print("-----Retraining Succeed-----")
                print(f"Continuing to prune {seq_idx} with the same prune_perc: {prune_perc:.4f}")
            else:
                # Pruning unsuccessful: revert to the checkpoint, then either retry with the
                # next percentage in the list or stop when the list is exhausted.
                print("-----Retraining Failed-----")
                if count < len(prune_percentages) - 1:
                    count += 1
                    prune_perc = prune_percentages[count]
                    print(f"Prune Percentage is changed to {prune_perc}")
                    model = torch.load('model_to_revert.pth')
                else:
                    print(f"{block_idx} {seq_idx} pruning done due to hitting max number of pruning percentage")
                    model = torch.load('model_to_revert.pth')
                    break

        else:
            # Pruning successful without retraining; the loop continues with the SAME
            # prune_perc (the percentage is not reduced on this path).
            print(f"Continuing to prune first layer_block with new prune_perc: {prune_perc:.4f} (no retraining)")
            print("-------------")

    return model

def prune_residual_block(block_idx, block, prune_percentages, model, loss_fn, scaler, scaled_anchors, train_loader, validation_loader, device, original_loss, post_pruned_threshold, post_retrained_threshold):
    """Prune the first conv of every sequential inside one residual block.

    For each entry of ``block.layers`` this delegates to ``prune_first_layer`` and
    keeps the model it returns.

    Args:
        block_idx: Index of the residual block inside ``model.layers``.
        block: The residual block as seen by the caller. Only the number of inner
            sequentials is read from it, because ``prune_first_layer`` may return a
            model reloaded from disk, after which ``block`` no longer belongs to the
            current model.
        prune_percentages, loss_fn, scaler, scaled_anchors, train_loader,
        validation_loader, device, original_loss, post_pruned_threshold,
        post_retrained_threshold: Forwarded unchanged to ``prune_first_layer``.
        model: Model to prune.

    Returns:
        The model returned by the last ``prune_first_layer`` call (or ``model``
        itself when the block has no inner sequentials).
    """
    num_sequentials = len(block.layers)
    for seq_idx in range(num_sequentials):
        print(f"{seq_idx} block is currently being pruned")

        model = prune_first_layer(block_idx, seq_idx, prune_percentages, model, loss_fn, scaler, scaled_anchors, train_loader, validation_loader, device, original_loss, post_pruned_threshold, post_retrained_threshold)
        print("------------After updated-------------")
        print((model.layers[block_idx]).layers[seq_idx])
        print("--------------------------------------")

    print("----------Final ResidualBlock---------")
    # Look the block up on the current model: the `block` argument can be stale if
    # the model was reverted from the checkpoint during pruning.
    print(model.layers[block_idx])
    return model
def evaluate_accuracy(model, data_loader, device):
    """Return the top-1 classification accuracy of ``model`` in percent.

    The model is switched to eval mode and run without gradient tracking. For each
    batch the predicted class is the argmax over dimension 1 of the model output.

    Args:
        model: Module called as ``model(inputs)``; left in eval mode afterwards.
        data_loader: Iterable yielding ``(inputs, labels)`` batches.
        device: Device both tensors are moved to before the forward pass.

    Returns:
        ``100 * correct / total`` as a float, or ``0.0`` when the loader yields no
        samples (instead of raising ``ZeroDivisionError``).
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        # Nothing was evaluated; avoid dividing by zero.
        return 0.0
    return 100 * correct / total


In [None]:
# @title Main for pruning (New Version)
import copy
def main():
    """Load the trained detector, prune it layer by layer, and save the result.

    Walks ``model.layers`` in order and dispatches on the layer type:
    ``ResidualBlock`` -> ``prune_residual_block``, ``ScalePrediction`` ->
    ``prune_first_layer(..., scale_pred=True)``, and ``CNNBlock`` -> the inline
    prune / validate / retrain / revert loop below. The filter-selection criterion
    is taken from the module-level ``CRITERION``. The pruned model is written to
    ``'L1_pruned_resnet.pth'``; nothing is returned.
    """

    model = torch.load('/content/drive/MyDrive/models/second_model_complete.pth')
    model = model.to(DEVICE)
    # parameter counts
    # Total parameters
    total_params = sum(p.numel() for p in model.parameters())

    # Trainable parameters
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(f"Total Parameters: {total_params}")
    print(f"Trainable Parameters: {trainable_params}")


    # NOTE(review): this optimizer is never used; a new one is created after every prune step.
    optimizer = optim.Adam(
        model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY
    )
    loss_fn = YoloLoss()
    scaler = torch.cuda.amp.GradScaler()

    # NOTE(review): `test_loader` is unused below; every validation loss in this function
    # is computed on `train_eval_loader` — confirm that this is the intended split.
    train_loader, test_loader, train_eval_loader = get_loaders(
        train_images_list, train_labels_list, val_images_list, val_labels_list
    )

    # Multiply each scale's anchors by that scale's grid size from S.
    scaled_anchors = (
        torch.tensor(ANCHORS)
        * torch.tensor(S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
    ).to(DEVICE)

    epoch_losses = []  # List to store loss per epoch (not used in this function)
    val_losses = []  # (not used in this function)
    early_stopping = EarlyStopping(patience=5, min_delta=0.01)  # Set patience and min_delta here

    # Reference loss that both thresholds below are multiplied with.
    initial_loss = validate_fn(train_eval_loader, model, loss_fn, scaled_anchors)
    print(f"Initial validation loss: {initial_loss:.4f}")

    # Set pruning parameters
    post_pruned_threshold = 1.1  # retrain when the pruned loss exceeds 110% of the initial loss
    post_retrained_threshold = 1.1  # accept the pruning when the retrained loss is within 110% of the initial loss
    prev_model = None
    prune_percentages = [0.7, 0.5, 0.2]
    # Iterative pruning and retraining
    # NOTE: the iterator is bound to the `layers` of the model object that exists when the
    # loop starts. `model` can be rebound inside the loop (checkpoint reload), so `layer`
    # is only relied on for its type; the CNNBlock branch re-fetches it via model.layers[idx].
    for idx, layer in enumerate(model.layers):

        if isinstance(layer, ResidualBlock):
            print(f"{idx} is pruning now")
            print("---------ResidualBlock Pruning---------")
            model = prune_residual_block(idx, layer, prune_percentages, model, loss_fn, scaler, scaled_anchors, train_loader, train_eval_loader, DEVICE, initial_loss, post_pruned_threshold, post_retrained_threshold)
            print("----After pruning ResidualBlock----")
            print(model)
            print("-----------------------------------")
            print("-----------------------------------")
        elif isinstance(layer, CNNBlock):
            # only prune a Conv2d not in the block
            print("---------CNNBlock Pruning---------")
            count = 0
            prune_perc = prune_percentages[count]
            while True:
                print(f"{idx} is pruning now")
                #  save model to revert
                torch.save(model, 'model_to_revert.pth')
                # Re-fetch from the current model: it may have been reloaded in the previous iteration.
                layer = model.layers[idx]
                conv_layer = layer.conv
                if layer.use_bn_act:
                    bn_layer = layer.bn
                else:
                    bn_layer = None

                # Prune the layer
                # Always try to remove at least one filter; stop once that would remove all of them.
                num_filters_to_prune = max(1, int(prune_perc * conv_layer.out_channels))
                if conv_layer.out_channels <= num_filters_to_prune:
                    print(f"{idx} {layer} pruning done due to hitting max pruning filters")
                    break
                print("-------Pruning Start(main)-------")
                if CRITERION == "Taylor":
                    print("-----Taylor Pruning-----")
                    new_layer, mask = taylor_prune_conv_layer(conv_layer, model, loss_fn, train_eval_loader, DEVICE, scaled_anchors, num_filters_to_prune)
                    # Snapshot of the model before the pruned conv is swapped in; it is passed
                    # to update_next_layer below (None for the other criteria).
                    prev_model = copy.deepcopy(model)
                elif CRITERION == "L1":
                    print("-----L1 Pruning-----")
                    new_layer, mask = prune_conv_layer(conv_layer, num_filters_to_prune)
                else: # prune filters randomly
                    print("-----Random Pruning-----")
                    new_layer, mask = random_prune_conv_layer(conv_layer, num_filters_to_prune)
                # Replace the conv_layer in the model
                replace_layer_in_model(model, conv_layer, new_layer)
                # `conv_layer` still refers to the old (unpruned) object here.
                print("before pruning: ", conv_layer)
                # refresh the current layer
                conv_layer = get_current_layer(model, new_layer)
                print("After pruning: ", conv_layer)

                # Update its batchNorm layer
                if bn_layer is not None:
                    new_bn_layer = update_batchnorm_for_pruned_conv(bn_layer, mask).to(DEVICE)
                    replace_layer_in_model(model, bn_layer, new_bn_layer)
                    # refresh the current bn
                    bn_layer = get_current_layer(model, new_bn_layer)
                # Update its next layer if there is
                if idx+1 < len(model.layers):
                    # NOTE(review): indices 16 and 23 are hard-coded; presumably these are the
                    # CNNBlocks whose output is concatenated with a route connection — confirm
                    # against the model config.
                    if idx == 16 or idx == 23:
                        update_concat_layer(model, layer, mask)
                    else:
                        next_layer = model.layers[idx+1]
                        # When a ScalePrediction follows (except for idx 28), the layer after it is
                        # updated with the same mask as well.
                        # NOTE(review): presumably both consume this layer's output — confirm.
                        if isinstance(next_layer, ScalePrediction) and idx != 28:
                            one_more_next_layer = model.layers[idx+2]
                            one_more_updated_next_layer = update_next_layer(one_more_next_layer, mask, model, loss_fn, train_eval_loader, DEVICE, scaled_anchors, prev_model).to(DEVICE)
                            replace_layer_in_model(model, one_more_next_layer, one_more_updated_next_layer)
                        updated_next_layer = update_next_layer(next_layer, mask, model, loss_fn, train_eval_loader, DEVICE, scaled_anchors, prev_model).to(DEVICE)
                        replace_layer_in_model(model, next_layer, updated_next_layer)


                # Evaluate the pruned model
                # A fresh optimizer is required because the parameter set changed with the pruning.
                optimizer = optim.Adam(
                    model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY
                )
                early_stopping = EarlyStopping(patience=5, min_delta=0.01)
                pruned_loss = validate_fn(train_eval_loader, model, loss_fn, scaled_anchors)
                print(f"Initial Loss: {initial_loss:.4f}")
                print(f"Loss After pruning {idx}: {pruned_loss:.4f}")
                epoch = 0
                if pruned_loss > initial_loss * post_pruned_threshold:
                    # Retrain the model
                    # Runs until early stopping fires or the best validation loss meets the retrain threshold.
                    while True:
                        print(f"Currently epoch {epoch}")
                        mean_loss = train_fn(train_loader, model, optimizer, loss_fn, scaler, scaled_anchors)
                        mean_val_loss = validate_fn(train_eval_loader, model, loss_fn, scaled_anchors)
                        # Early Stopping logic
                        early_stopping(mean_val_loss)
                        if early_stopping.early_stop:
                            print("Early stopping triggered")
                            early_stopping = EarlyStopping(patience=5, min_delta=0.01)
                            break
                        if early_stopping.best_loss <= initial_loss * post_retrained_threshold:
                            early_stopping = EarlyStopping(patience=5, min_delta=0.01)
                            print("Retrained loss reached the threshold")
                            break
                        epoch += 1

                    # Evaluate after retraining
                    retrained_loss = validate_fn(train_eval_loader, model, loss_fn, scaled_anchors)
                    print(f"Loss after retraining: {retrained_loss:.4f}")

                    if retrained_loss <= initial_loss * post_retrained_threshold:
                        # Pruning successful
                        print("-----Retraining Succeed-----")
                        print(f"Continuing to prune {idx} with the same prune_perc: {prune_perc:.4f}")
                    else:
                        # Pruning unsuccessful: revert to the checkpoint, then either retry with
                        # the next percentage in the list or stop when the list is exhausted.
                        print("-----Retraining Failed-----")
                        if count < len(prune_percentages) - 1:
                            count += 1
                            prune_perc = prune_percentages[count]
                            print(f"Prune Percentage is changed to {prune_perc}")
                            model = torch.load('model_to_revert.pth')
                        else:
                            print(f"{idx} {layer} pruning done due to hitting max number of pruning percentage")
                            model = torch.load('model_to_revert.pth')
                            break

                else:
                    # Pruning successful without retraining; the loop continues with the SAME
                    # prune_perc (the percentage is not reduced on this path).
                    print(f"Continuing to prune {idx} with new prune_perc: {prune_perc:.4f} (no retraining)")

            print("----After pruning CNNBlock----")
            print(model)
            print("------------------------------")
            print("------------------------------")

        elif isinstance(layer, ScalePrediction):
            print(f"{idx} is pruning now")
            # conv_layer = layer.pred[0].conv
            # bn_layer = layer.pred[0].bn
            # next_layer = layer.pred[1].conv
            # With scale_pred=True, prune_first_layer works on `pred` and ignores seq_idx for lookup.
            scale_pred = True
            seq_idx = 0
            model = prune_first_layer(idx, seq_idx, prune_percentages, model, loss_fn, scaler, scaled_anchors, train_loader, train_eval_loader, DEVICE, initial_loss, post_pruned_threshold, post_retrained_threshold, scale_pred)
            print("----After pruning ScalePrediction----")
            print(model)
            print("-------------------------------------")
            print("-------------------------------------")




        print("\n--------\n")
        print("\n--------\n")


    # Final evaluation

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total number of parameters: {total_params}")

    # Save the pruned model
    # NOTE(review): the file name is fixed to 'L1_...' even when CRITERION is not "L1".
    torch.save(model, 'L1_pruned_resnet.pth')


In [None]:
# Run the full pruning pipeline; main() writes the pruned model to 'L1_pruned_resnet.pth'.
if __name__ == "__main__":
    main()

In [None]:
# Collect unreferenced Python objects first, then release PyTorch's cached CUDA memory.
import gc
gc.collect()
torch.cuda.empty_cache()

In [None]:
# NOTE(review): main() keeps `model` as a local and returns nothing, so this line relies on
# a global `model` defined by some other cell — confirm it exists on a fresh kernel run.
print(model)

In [None]:
# save_checkpoint(model, optimizer, filename=f"pruned_checkpoint.pth.tar")
# @title Save model on drive

def copy_model_file(source_path, destination_path):
    """Copy ``source_path`` to ``destination_path`` and report the outcome on stdout.

    Best-effort helper: any exception raised by the copy is caught and printed
    rather than propagated, and nothing is returned.
    """
    try:
        shutil.copy(source_path, destination_path)
        outcome = f'Tar file copied to: {destination_path}'
    except Exception as copy_error:
        outcome = f'An error occurred: {copy_error}'
    print(outcome)

# Source path of the pruned model (.pth) saved by main().
# NOTE(review): main() saves to the relative path 'L1_pruned_resnet.pth'; this assumes the
# notebook's working directory is /content — confirm.
source_tar_file_path = '/content/L1_pruned_resnet.pth'

# Destination folder in Google Drive (Drive must already be mounted at /content/drive).
drive_tar_folder_path = '/content/drive/MyDrive/models/'

# Failures are printed by copy_model_file rather than raised.
copy_model_file(source_tar_file_path, drive_tar_folder_path)

In [None]:
# @title VIDEO FRAME TEST
import cv2
import torch
from torchvision import transforms
import time

def process_video(video_path, model, device, conf_threshold=0.5, nms_threshold=0.4, time_limit=240):
    """Run ``model`` on each frame of a video and write the frames to ``output.mp4``.

    Every frame is converted BGR->RGB, turned into a tensor, resized to 416x416 and
    passed through ``model`` under ``torch.no_grad()``. Running averages of FPS and
    per-frame latency are drawn on every 10th frame and printed once at the end.
    Box decoding / NMS / drawing are not implemented yet, so frames are written
    without detections.

    Args:
        video_path: Path of the input video, opened with ``cv2.VideoCapture``.
        model: Network invoked as ``model(batch)`` on a batch holding one frame.
        device: Device (string or ``torch.device``) the input batch is moved to.
        conf_threshold: Currently unused; reserved for the missing post-processing.
        nms_threshold: Currently unused; reserved for the missing post-processing.
        time_limit: Stop once this many seconds of wall-clock time have elapsed.

    Raises:
        IOError: If the video cannot be opened.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise IOError(f"Could not open video: {video_path}")

    out = None
    frame_count = 0
    total_time = 0
    total_latency = 0

    try:
        # Get video properties. FPS is kept as a float (e.g. 29.97) so the written
        # video does not drift in playback speed.
        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Create VideoWriter object
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter('output.mp4', fourcc, fps, (width, height))

        # Prepare image transformation
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((416, 416)),  # Adjust size according to your model
        ])

        # CUDA kernels run asynchronously; without a sync the timer would stop
        # before the forward pass has actually finished.
        sync_cuda = torch.device(device).type == "cuda"

        start_time_total = time.perf_counter()

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            frame_count += 1
            start_time = time.perf_counter()

            # Preprocess the frame
            input_img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            input_img = transform(input_img).unsqueeze(0).to(device)

            # Run inference
            with torch.no_grad():
                predictions = model(input_img)
            if sync_cuda:
                torch.cuda.synchronize()

            # Post-process predictions
            # This part needs to be implemented based on your model's output format
            # Need to apply non-max suppression and filter by confidence

            # Draw bounding boxes on the frame
            # This part needs to be implemented based on your post-processing results

            end_time = time.perf_counter()
            latency = end_time - start_time
            total_latency += latency
            total_time = end_time - start_time_total

            # Calculate and display FPS and latency
            if frame_count % 10 == 0:  # Update every 10 frames
                avg_fps = frame_count / total_time
                avg_latency = total_latency / frame_count
                fps_text = f"FPS: {avg_fps:.2f}"
                latency_text = f"Latency: {avg_latency*1000:.2f} ms"
                cv2.putText(frame, fps_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
                cv2.putText(frame, latency_text, (10, 70), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)

            # Write the frame
            out.write(frame)

            # Check if we've reached the time limit
            if total_time >= time_limit:
                print(f"Reached time limit of {time_limit} seconds. Stopping processing.")
                break
    finally:
        # Release resources even when inference or writing fails part-way through.
        cap.release()
        if out is not None:
            out.release()
        cv2.destroyAllWindows()

    if frame_count == 0:
        # An empty stream would otherwise divide by zero in the summary below.
        print("No frames were processed.")
        return

    # Print final average FPS and latency
    avg_fps = frame_count / total_time
    avg_latency = total_latency / frame_count
    print(f"Average FPS: {avg_fps:.2f}")
    print(f"Average Latency: {avg_latency*1000:.2f} ms")
    print(f"Total frames processed: {frame_count}")
    print(f"Total processing time: {total_time:.2f} seconds")

# Set up the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_classes = 1  # Adjust based on your model
model = torch.load('/content/drive/MyDrive/models/second_pruned_resnet.pth', map_location=device).to(device)

model.eval()

# Process the video
video_path = '/content/drive/MyDrive/datasets/Kvasir_Capsule.mp4'
process_video(video_path, model, device)