In [None]:
pip install torch torchvision scikit-learn matplotlib seaborn tensorboard

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
from torchvision import models, transforms
from torchvision.models.detection import MaskRCNN
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from torchvision.transforms import v2 as T
from torchvision.io import read_image
from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks
from torchvision.ops import nms, box_iou
import torchvision.ops as ops
from torchvision.models.detection.anchor_utils import AnchorGenerator
from torchvision.ops import MultiScaleRoIAlign
from torchvision.ops.boxes import masks_to_boxes
import torchvision.transforms.functional as TF
from torchvision.transforms.functional import InterpolationMode, resize
import torchvision
from torchvision import models
from sklearn.metrics import f1_score, recall_score, precision_score, classification_report, confusion_matrix
import seaborn as sns
from torch.nn import functional as F
from collections import Counter, defaultdict

import os
import numpy as np
from scipy import ndimage
import matplotlib.pyplot as plt
from tqdm import tqdm
import time
from PIL import Image
from transformers import AutoProcessor, AutoModel
import random

In [None]:
# Single seed shared by every random number generator the notebook imports.
SEED = 42

# Seed Python's and NumPy's generators as well as torch's, so that any code using
# `random` or `np.random` is reproducible too (previously only torch was seeded).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    # Speed-oriented settings: cuDNN autotuning and TF32 matmul/conv kernels.
    # NOTE(review): these trade bit-exact reproducibility for throughput.
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.cuda.empty_cache()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

In [None]:
import os
import torch
import numpy as np
from torch.utils.data import Dataset
from torchvision.io import read_image
import torchvision.transforms.v2 as T


class CustomDataset(Dataset):
    """Instance-segmentation dataset built from a ``<manufacturer>/<individual>/{images,masks}`` tree.

    Every sub-directory of ``root_dir`` (except hidden ones and the names in
    ``_EXCLUDED_MANUFACTURERS``) is one class; class ids start at 1. An image
    ``images/<name>.<ext>`` is used only when ``masks/<name>_mask.jpg`` exists.

    Each sample is an ``(image, target)`` pair where ``target`` holds ``boxes``,
    ``labels``, ``masks``, ``area``, ``image_id`` and ``iscrowd``. All instances of
    one image receive the label of the manufacturer directory it was found in.

    Args:
        root_dir: Directory containing one sub-directory per manufacturer.
        transforms: Optional callable invoked as ``transforms(img, target)``.
        iou_threshold: Default IoU threshold used by ``_apply_nms``.
        min_area: Minimum pixel count of a connected component to keep it; boxes
            must additionally cover more than ``0.8 * min_area`` pixels.
    """

    # Directory names (compared lower-cased) that are never treated as a class.
    _EXCLUDED_MANUFACTURERS = ("no manufacturer data", "no accession number match", "rti")

    def __init__(self, root_dir, transforms=None, iou_threshold=0.5, min_area=100):
        self.root_dir = root_dir
        self.transforms = transforms
        self.iou_threshold = iou_threshold
        self.min_area = min_area

        self.imgs = []            # (image path, manufacturer name) per sample
        self.masks = []           # mask path per sample, parallel to self.imgs
        self.manu_to_label = {}   # manufacturer name -> class id (1-based)
        label_counter = 1

        # NOTE(review): os.listdir returns entries in arbitrary order, so both the
        # class ids and the sample indices depend on the filesystem. Sorting would
        # make them reproducible but would change the ids of existing checkpoints.
        for manu_name in os.listdir(root_dir):
            manu_path = os.path.join(root_dir, manu_name)
            if not os.path.isdir(manu_path) or manu_name.startswith('.'):
                continue

            if manu_name.lower() in self._EXCLUDED_MANUFACTURERS:
                continue

            if manu_name not in self.manu_to_label:
                self.manu_to_label[manu_name] = label_counter
                label_counter += 1

            for ind_name in os.listdir(manu_path):
                ind_path = os.path.join(manu_path, ind_name)
                if not os.path.isdir(ind_path) or ind_name.startswith('.'):
                    continue

                images_path = os.path.join(ind_path, "images")
                masks_path = os.path.join(ind_path, "masks")
                if not (os.path.exists(images_path) and os.path.exists(masks_path)):
                    continue

                for img_name in os.listdir(images_path):
                    if img_name.startswith('.'):
                        continue
                    mask_base = os.path.splitext(img_name)[0]
                    mask_path = os.path.join(masks_path, f"{mask_base}_mask.jpg")
                    # Images without a matching mask file are silently skipped.
                    if os.path.exists(mask_path):
                        self.imgs.append((os.path.join(images_path, img_name), manu_name))
                        self.masks.append(mask_path)

        print(f"Found {len(self.imgs)} images and {len(self.masks)} masks")
        print("Manufacturer to label mapping:", self.manu_to_label)

    def __len__(self):
        """Return the number of (image, mask) pairs found."""
        return len(self.imgs)

    def _process_mask(self, mask):
        """Split a ``(C, H, W)`` mask image into one boolean mask per connected component.

        A pixel is foreground when any channel is non-zero. Holes are filled, a
        single binary closing with a cross-shaped element is applied, holes are
        filled again, and the result is labelled. Components with at most
        ``self.min_area`` pixels are dropped.

        Returns:
            Bool tensor of shape ``(K, H, W)``; ``K`` may be 0.
        """
        # Test the channels individually. The previous `c0 + c1 * 256 + c2 * 65536`
        # was evaluated in uint8, where 256 and 65536 are not representable, so the
        # second and third channel were effectively ignored; it also raised
        # IndexError for single-channel masks.
        binary_mask = (mask > 0).any(dim=0)
        binary_np = binary_mask.numpy()
        filled_binary = ndimage.binary_fill_holes(binary_np)
        cross = np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]], dtype=np.bool_)
        closed = ndimage.binary_closing(filled_binary, structure=cross, iterations=1)
        final_filled = ndimage.binary_fill_holes(closed)
        labeled_array, num_features = ndimage.label(final_filled)
        labeled_mask = torch.from_numpy(labeled_array)

        masks = []
        for i in range(1, num_features + 1):
            component_mask = (labeled_mask == i)
            if torch.sum(component_mask) > self.min_area:
                masks.append(component_mask)

        if not masks:
            return torch.zeros((0, mask.shape[1], mask.shape[2]), dtype=torch.bool)
        return torch.stack(masks)

    def _get_tight_boxes(self, masks, base_padding=5):
        """Compute a padded bounding box for every mask and drop implausible ones.

        Padding is ``base_padding`` scaled by an area factor in ``[1.0, 1.5]`` and
        by an extra 1.4 when one side of the raw extent is more than twice the
        other. A box is kept when its area exceeds ``0.8 * self.min_area`` and its
        width/height ratio lies strictly between 0.1 and 10.

        Args:
            masks: Iterable of 2-D boolean masks.
            base_padding: Base padding in pixels.

        Returns:
            ``(boxes, filtered_masks)``: a float32 ``(K, 4)`` tensor of
            ``[x_min, y_min, x_max, y_max]`` and the list of masks that survived.
        """
        boxes = []
        filtered_masks = []

        for mask in masks:
            y_indices, x_indices = torch.where(mask)
            if len(y_indices) == 0 or len(x_indices) == 0:
                continue  # empty mask, nothing to box

            x_min_raw = torch.min(x_indices).item()
            y_min_raw = torch.min(y_indices).item()
            x_max_raw = torch.max(x_indices).item()
            y_max_raw = torch.max(y_indices).item()

            raw_width = x_max_raw - x_min_raw
            raw_height = y_max_raw - y_min_raw
            mask_area = torch.sum(mask).item()
            area_factor = min(1.5, max(1.0, (mask_area / 10000) * 1.2))

            # Elongated instances (either orientation) get 40% more padding.
            elongated = raw_width > raw_height * 2 or raw_height > raw_width * 2
            if elongated:
                padding = base_padding * 1.4 * area_factor
            else:
                padding = base_padding * area_factor

            # Clip the padded box to the mask extent.
            x_min = max(0, x_min_raw - padding)
            y_min = max(0, y_min_raw - padding)
            x_max = min(mask.shape[1], x_max_raw + padding)
            y_max = min(mask.shape[0], y_max_raw + padding)

            width = x_max - x_min
            height = y_max - y_min
            area = width * height
            aspect_ratio = width / height if height > 0 else 0

            if area > self.min_area * 0.8 and 0.1 < aspect_ratio < 10:
                boxes.append([x_min, y_min, x_max, y_max])
                filtered_masks.append(mask)

        if boxes:
            box_tensor = torch.tensor(boxes, dtype=torch.float32)
        else:
            box_tensor = torch.zeros((0, 4), dtype=torch.float32)
        return box_tensor, filtered_masks

    def _apply_nms(self, boxes, scores=None, iou_threshold=None):
        """Run non-maximum suppression and return the indices of the boxes to keep.

        Args:
            boxes: ``(N, 4)`` box tensor.
            scores: Optional per-box scores; all ones when omitted.
            iou_threshold: Overrides ``self.iou_threshold`` when given.

        Returns:
            int64 tensor of kept indices (empty when ``boxes`` is empty).
        """
        if len(boxes) == 0:
            return torch.tensor([], dtype=torch.int64)
        if scores is None:
            scores = torch.ones(len(boxes))
        threshold = self.iou_threshold if iou_threshold is None else iou_threshold
        from torchvision.ops import nms
        return nms(boxes, scores, threshold)

    def _resize_with_lockstep_precision(self, img, mask, target_size=2048):
        """Resize image and mask with one shared scale and centre them on a square canvas.

        The aspect ratio is preserved; the image is resized bilinearly (antialiased)
        and the mask with nearest-neighbour interpolation. Both are then pasted at
        the same offsets into zero-filled ``target_size`` x ``target_size`` tensors.

        Returns:
            ``(padded_img, padded_mask, (h_offset, w_offset, new_h, new_w))``.
        """
        c_img, h, w = img.shape
        c_mask = mask.shape[0]
        scale = min(target_size / h, target_size / w)
        new_h, new_w = int(h * scale), int(w * scale)

        resized_img = resize(img, size=(new_h, new_w), interpolation=InterpolationMode.BILINEAR, antialias=True)
        resized_mask = resize(mask, size=(new_h, new_w), interpolation=InterpolationMode.NEAREST)

        padded_img = torch.zeros((c_img, target_size, target_size), dtype=img.dtype)
        padded_mask = torch.zeros((c_mask, target_size, target_size), dtype=mask.dtype)

        h_offset = (target_size - new_h) // 2
        w_offset = (target_size - new_w) // 2

        padded_img[:, h_offset:h_offset+new_h, w_offset:w_offset+new_w] = resized_img
        padded_mask[:, h_offset:h_offset+new_h, w_offset:w_offset+new_w] = resized_mask

        return padded_img, padded_mask, (h_offset, w_offset, new_h, new_w)

    def __getitem__(self, idx):
        """Load sample ``idx`` and build its detection target.

        Returns:
            ``(img, target)`` where ``img`` is a float tensor padded to 2048 x 2048
            and ``target`` is a dict with ``boxes``, ``labels``, ``masks``, ``area``,
            ``image_id`` and ``iscrowd``. Boxes are clamped to the image bounds
            after the optional transforms.
        """
        img_path, manu = self.imgs[idx]
        mask_path = self.masks[idx]

        img = read_image(img_path)
        if img.shape[0] == 1:
            # Replicate a single channel so the image has three channels.
            img = img.repeat(3, 1, 1)
        img = img.float() / 255.0 if img.max() > 1.0 else img.float()

        mask = read_image(mask_path).byte()
        img, mask, _ = self._resize_with_lockstep_precision(img, mask, 2048)

        masks = self._process_mask(mask)
        boxes, filtered_masks = self._get_tight_boxes(masks, base_padding=2)

        empty_masks = torch.zeros((0, mask.shape[1], mask.shape[2]), dtype=torch.bool)
        if len(boxes) > 0:
            # Suppress overlapping boxes (all scores equal) and keep the matching masks.
            keep = self._apply_nms(boxes)
            boxes = boxes[keep]
            filtered_masks = [filtered_masks[i] for i in keep] if filtered_masks else []
            masks = torch.stack(filtered_masks) if filtered_masks else empty_masks
        else:
            masks = empty_masks

        labels = torch.full((boxes.shape[0],), self.manu_to_label[manu], dtype=torch.int64)
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

        target = {
            "boxes": boxes,
            "labels": labels,
            "masks": masks,
            "area": area,
            "image_id": torch.tensor([idx]),
            "iscrowd": torch.zeros((boxes.shape[0],), dtype=torch.int64)
        }

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        # Keep every coordinate inside the (possibly transformed) image.
        h, w = img.shape[1:]
        boxes = target["boxes"]
        boxes[:, 0].clamp_(0, w - 1)
        boxes[:, 1].clamp_(0, h - 1)
        boxes[:, 2].clamp_(0, w - 1)
        boxes[:, 3].clamp_(0, h - 1)
        target["boxes"] = boxes

        return img, target


In [None]:
def get_transform(train=False):
    """Build the preprocessing pipeline applied to every sample.

    Args:
        train: Accepted for API symmetry; it is not consulted, so the same
            pipeline is returned for training and evaluation.

    Returns:
        A ``T.Compose`` that converts the input to a C x H x W image tensor and
        then to float values scaled into [0, 1].
    """
    steps = [
        # Converts to C x H x W tensor
        T.ToImage(),
        # Converts to float in [0,1]
        T.ToDtype(torch.float, scale=True),
    ]
    return T.Compose(steps)

def load_model(num_classes, device):
    """Build a Mask R-CNN with a ResNet-18 FPN backbone and move it to ``device``.

    Args:
        num_classes: Number of output classes passed to ``MaskRCNN``.
        device: Device the model is moved to.

    Returns:
        The constructed ``MaskRCNN`` model, already on ``device``.
    """
    # One anchor size per feature level, each with the same six aspect ratios.
    sizes = ((32,), (64,), (128,), (256,), (512,))
    aspect_ratios = ((0.2, 0.3, 0.5, 1.0, 2.0, 3.0),) * len(sizes)
    anchor_generator = AnchorGenerator(sizes=sizes, aspect_ratios=aspect_ratios)
    roi_pooler = MultiScaleRoIAlign(featmap_names=['0', '1', '2', '3'], output_size=7, sampling_ratio=2)

    # The builder's parameters are keyword-only; passing the name positionally
    # only works through a deprecated compatibility shim.
    backbone = resnet_fpn_backbone(backbone_name="resnet18", weights="DEFAULT")
    model = MaskRCNN(
        backbone, num_classes=num_classes,
        rpn_anchor_generator=anchor_generator,
        box_roi_pool=roi_pooler,
        # Proposal limits must be given here: the RPN stores them at construction,
        # so assigning `model.rpn.post_nms_top_n_train/test` afterwards only creates
        # unused attributes and leaves the library defaults in effect.
        rpn_post_nms_top_n_train=200,
        rpn_post_nms_top_n_test=110,
        rpn_score_thresh=0.3,
        box_score_thresh=0.3,
        box_nms_thresh=0.25
    )
    model.to(device)
    return model

In [None]:
def visualize_predictions(image_tensor, prediction, ground_truth=None, label_map=None, score_thresh=0.3, iou_thresh=0.3):
    """
    Plot an image overlaid with the model's confident boxes, labels and masks.

    Predictions scoring at or below ``score_thresh`` are discarded. When ground
    truth is supplied, each GT box is paired with its highest-IoU prediction and,
    if that IoU exceeds ``iou_thresh``, the prediction's caption also shows the
    GT label.

    Args:
        image_tensor (Tensor): The input image tensor (C, H, W).
        prediction (Dict): Model predictions with keys: boxes, labels, scores, (optional) masks.
        ground_truth (Dict, optional): Ground truth boxes and labels.
        label_map (Dict[int, str], optional): Mapping of class indices to names.
        score_thresh (float): Score threshold for filtering predictions.
        iou_thresh (float): IoU threshold for matching GT and predictions.
    """
    def _display_name(class_id):
        # Fall back to the raw id when no map is given or the id is unmapped.
        if label_map:
            return label_map.get(class_id, class_id)
        return class_id

    # Scale to 0-255 bytes; a single-channel image is expanded to three channels.
    canvas = (image_tensor.clone() * 255).byte()
    if canvas.ndim == 3 and canvas.shape[0] == 1:
        canvas = canvas.expand(3, -1, -1)

    boxes_pred = prediction['boxes']
    labels_pred = prediction['labels']
    scores_pred = prediction['scores']
    masks_pred = prediction.get('masks')

    # Keep only sufficiently confident detections.
    confident = scores_pred > score_thresh
    boxes_pred, labels_pred, scores_pred = boxes_pred[confident], labels_pred[confident], scores_pred[confident]
    if masks_pred is not None and len(masks_pred) > 0:
        masks_pred = masks_pred[confident].squeeze(1) > 0.5

    # Default caption for every kept prediction.
    captions = []
    for cls, conf_score in zip(labels_pred, scores_pred):
        captions.append(f"Pred: {_display_name(cls.item())} ({conf_score.item():.2f})")

    # Upgrade the caption of any prediction that is the best match of a GT box.
    have_gt = ground_truth and len(boxes_pred) > 0 and len(ground_truth['boxes']) > 0
    if have_gt:
        overlap = box_iou(ground_truth['boxes'], boxes_pred)
        for row, gt_cls in enumerate(ground_truth['labels']):
            best_iou, best = overlap[row].max(0)
            if not best_iou > iou_thresh:
                continue
            best = best.item()
            gt_text = _display_name(gt_cls.item())
            pred_text = _display_name(labels_pred[best].item())
            conf_score = scores_pred[best].item()
            captions[best] = f"GT: {gt_text} | Pred: {pred_text} ({conf_score:.2f})"

    rendered = draw_bounding_boxes(canvas, boxes_pred, labels=captions, width=2)

    # Overlay masks if present
    if masks_pred is not None and len(masks_pred) > 0:
        rendered = draw_segmentation_masks(rendered, masks=masks_pred, alpha=0.3)

    plt.figure(figsize=(10, 10))
    plt.imshow(rendered.permute(1, 2, 0))
    plt.axis("off")
    plt.title("GT vs Predicted Labels")
    plt.show()

In [None]:
# TensorBoard visualization function
def visualize_tb(writer, image, output, sample_idx, prefix="prediction", label_map=None):
    """
    Draw the model's boxes and masks on an image and log it to TensorBoard.

    The image is min-max normalised to uint8 (unless it already is uint8) and
    cut to its first three channels. Boxes are filtered with NMS (IoU 0.3) before
    drawing. Any exception is caught and reported with a printed message, so a
    failed visualisation never interrupts the caller.

    Args:
        writer (SummaryWriter): TensorBoard SummaryWriter to log the image.
        image (Tensor): The input image tensor.
        output (Dict): The model's output containing boxes, scores, labels, and
            optionally masks (shape (N, 1, H, W) or (N, H, W); may be None).
        sample_idx (int): The sample index, used in the tag and as global step.
        prefix (str): Prefix for the TensorBoard tag.
        label_map (Dict): Mapping from class id to manufacturer name.
    Returns:
        None
    """
    try:
        # Normalize the image to the range [0, 255] and convert to uint8
        if image.dtype != torch.uint8:
            lowest = image.min()
            value_range = image.max() - lowest
            if value_range > 0:
                image = (255.0 * (image - lowest) / value_range).to(torch.uint8)
            else:
                # Constant image: avoid 0/0 (NaN) and log a black frame instead.
                image = torch.zeros_like(image, dtype=torch.uint8)
        image = image[:3, ...]  # Keep at most the first three channels

        if "boxes" in output and len(output["boxes"]) > 0:
            # Apply Non-Maximum Suppression (NMS) to filter boxes
            keep = nms(output["boxes"], output["scores"], iou_threshold=0.3)
            kept_labels = output["labels"][keep]
            kept_scores = output["scores"][keep]

            # Create labels for the predicted boxes; `.item()` keeps the text free
            # of tensor reprs such as "tensor(3)".
            if label_map:
                pred_labels = [f"{label_map.get(label.item(), f'manu: {label.item()}')}|{score.item():.3f}%"
                               for label, score in zip(kept_labels, kept_scores)]
            else:
                pred_labels = [f"manu: {label.item()}|{score.item():.3f}%"
                               for label, score in zip(kept_labels, kept_scores)]

            # Get the filtered boxes
            pred_boxes = output["boxes"][keep].long()
            # Draw the bounding boxes on the image
            output_image = draw_bounding_boxes(image, pred_boxes, pred_labels, colors="red")

            raw_masks = output.get("masks")
            if raw_masks is not None and len(raw_masks) > 0:
                # Get the filtered masks
                masks = raw_masks[keep] > 0.7
                if masks.ndim == 4:
                    # Drop the channel axis only when it is actually present.
                    masks = masks.squeeze(1)
                # Draw the segmentation masks on the image
                output_image = draw_segmentation_masks(output_image, masks, alpha=0.5, colors="blue")
        else:
            output_image = image

        # Log the image with bounding boxes and masks to TensorBoard
        writer.add_image(f"{prefix}_From_Sample_{sample_idx}", output_image, sample_idx)
    except Exception as e:
        print(f"Visualization error: {str(e)}")

In [None]:
def evaluate_model(model, dataloader, device, dataset, iou_threshold=0.3, visualize_samples=10):
    """
    Evaluate per-instance classification and log metrics and images to TensorBoard.

    Every ground-truth box is paired with its highest-IoU prediction (after
    dropping predictions with score <= 0.3). If that IoU exceeds
    ``iou_threshold`` the prediction's label is recorded, otherwise label 0 is
    recorded as "missed". Precision/recall/F1, a confusion matrix and the overall
    accuracy are printed and written to ``runs/mask_rcnn_evaluation_latest_v1``.

    Args:
        model: Detection model; called with a list of image tensors and expected
            to return one dict per image with ``boxes``, ``scores``, ``labels``
            and optionally ``masks``.
        dataloader: Iterable yielding ``(images, targets)`` batches.
        device: Device the images are moved to before inference.
        dataset: Object exposing ``dataset.manu_to_label`` (name -> class id).
        iou_threshold: Minimum IoU for a prediction to count as a match.
        visualize_samples: Number of side-by-side GT/prediction figures saved as
            ``sample_evaluation_<i>.png`` and logged to TensorBoard.

    Returns:
        Per-instance accuracy in percent (0.0 when there is no ground truth).
    """
    writer = SummaryWriter('runs/mask_rcnn_evaluation_latest_v1')

    model.eval()
    y_true = []
    y_pred = []

    label_map = dataset.dataset.manu_to_label
    inv_label_map = {v: k for k, v in label_map.items()}
    # Report on the classes the dataset actually defines rather than a fixed 1..6.
    class_ids = sorted(inv_label_map) or list(range(1, 7))
    class_names = [inv_label_map.get(i, f"Class_{i}") for i in class_ids]

    tb_sample_counter = 0   # images sent to visualize_tb (capped at 50)
    batch_counter = 0       # counts processed images; used as TensorBoard step
    running_correct = 0
    running_total = 0
    sample_counter = 0      # side-by-side figures saved so far

    with torch.no_grad():
        for images, targets in tqdm(dataloader, desc="Evaluating and Visualizing"):
            images = [img.to(device) for img in images]
            predictions = model(images)

            for image, pred, gt in zip(images, predictions, targets):
                gt_boxes = gt['boxes'].cpu()
                gt_labels = gt['labels'].cpu().tolist()

                pred_boxes = pred['boxes'].cpu()
                pred_scores = pred['scores'].cpu()
                pred_labels = pred['labels'].cpu()
                pred_masks = pred['masks'].cpu().squeeze(1) if 'masks' in pred else None

                # Simple threshold - no adaptive correction
                keep = pred_scores > 0.3
                pred_boxes = pred_boxes[keep]
                pred_labels = pred_labels[keep]
                pred_scores = pred_scores[keep]
                pred_masks = pred_masks[keep] if pred_masks is not None else None

                # Predicted label per ground-truth instance of THIS image (0 = missed).
                if len(pred_boxes) == 0:
                    image_pred = [0] * len(gt_labels)
                else:
                    image_pred = []
                    ious = box_iou(gt_boxes, pred_boxes)
                    for i in range(len(gt_labels)):
                        max_iou, max_idx = ious[i].max(0)
                        if max_iou > iou_threshold:
                            image_pred.append(pred_labels[max_idx].item())
                        else:
                            image_pred.append(0)

                y_true.extend(gt_labels)
                y_pred.extend(image_pred)

                # Update running accuracy from this image's own lists. Slicing the
                # global lists with [-len(gt_labels):] would return *everything*
                # when an image has no ground truth (because -0 == 0).
                running_correct += sum(t == p for t, p in zip(gt_labels, image_pred))
                running_total += len(gt_labels)

                # Log running accuracy every 10 processed images
                if batch_counter % 10 == 0 and running_total > 0:
                    current_acc = 100.0 * running_correct / running_total
                    writer.add_scalar('Accuracy/per_screw_running', current_acc, batch_counter)

                batch_counter += 1

                # Add TensorBoard visualization for detections
                if tb_sample_counter < 50:
                    pred_for_tb = {
                        'boxes': pred_boxes,
                        'labels': pred_labels,
                        'scores': pred_scores,
                    }
                    if pred_masks is not None:
                        # Only pass masks that exist; a None entry would make the
                        # visualiser fail on len(None).
                        pred_for_tb['masks'] = pred_masks
                    visualize_tb(writer, image.cpu(), pred_for_tb, tb_sample_counter,
                                 "predictions", inv_label_map)
                    tb_sample_counter += 1

                if sample_counter < visualize_samples:
                    img_cpu = image.cpu()

                    # Convert image to proper format for visualization
                    if img_cpu.dtype != torch.uint8:
                        img_cpu = (img_cpu * 255).byte()

                    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 8))

                    # Ground truth visualization
                    gt_img = img_cpu.clone()
                    if len(gt_boxes) > 0:
                        gt_labels_text = [f"{inv_label_map.get(label, f'Class {label}')}"
                                          for label in gt_labels]
                        gt_img = draw_bounding_boxes(gt_img, gt_boxes,
                                                     labels=gt_labels_text,
                                                     colors="green", width=2)

                    ax1.imshow(gt_img.permute(1, 2, 0))
                    ax1.set_title("Ground Truth")
                    ax1.axis('off')

                    # Predictions visualization
                    pred_img = img_cpu.clone()
                    if len(pred_boxes) > 0:
                        pred_labels_text = [f"{inv_label_map.get(label.item(), f'Class {label.item()}')} ({score:.2f})"
                                            for label, score in zip(pred_labels, pred_scores)]
                        pred_img = draw_bounding_boxes(pred_img, pred_boxes,
                                                       labels=pred_labels_text,
                                                       colors="red", width=2)

                        if pred_masks is not None and len(pred_masks) > 0:
                            pred_img = draw_segmentation_masks(pred_img,
                                                               masks=pred_masks > 0.5,
                                                               alpha=0.4)

                    ax2.imshow(pred_img.permute(1, 2, 0))
                    ax2.set_title("Predictions")
                    ax2.axis('off')

                    fig.tight_layout()

                    # Save through the figure object *before* handing it to
                    # TensorBoard: add_figure closes the figure by default, after
                    # which plt.savefig would write a fresh, empty figure.
                    fig.savefig(f'sample_evaluation_{sample_counter}.png',
                                bbox_inches='tight', dpi=150)
                    writer.add_figure(f'Sample_Images/sample_{sample_counter}', fig,
                                      global_step=0, close=False)
                    plt.close(fig)

                    sample_counter += 1

    print("\n--- Per-Screw Classification Report ---")

    report_dict = classification_report(y_true, y_pred, labels=class_ids,
                                        zero_division=0, output_dict=True)
    print(classification_report(y_true, y_pred, labels=class_ids, zero_division=0))

    # Log per-class metrics to TensorBoard and collect them for the bar chart.
    metrics_data = []
    labels_list = []
    for label, label_name in zip(class_ids, class_names):
        class_report = report_dict.get(str(label))
        if class_report is None:
            continue
        writer.add_scalar(f'Metrics/{label_name}/Precision', class_report['precision'], 0)
        writer.add_scalar(f'Metrics/{label_name}/Recall', class_report['recall'], 0)
        writer.add_scalar(f'Metrics/{label_name}/F1-Score', class_report['f1-score'], 0)
        metrics_data.append([class_report['precision'],
                             class_report['recall'],
                             class_report['f1-score']])
        labels_list.append(label_name)

    # Create and log confusion matrix
    cm = confusion_matrix(y_true, y_pred, labels=class_ids)
    print("Confusion Matrix:")
    print(cm)

    cm_fig, cm_ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names, ax=cm_ax)
    cm_ax.set_title('Confusion Matrix - Per Screw Classification')
    cm_ax.set_ylabel('True Label')
    cm_ax.set_xlabel('Predicted Label')
    writer.add_figure('Confusion_Matrix/Per_Screw', cm_fig, 0)
    plt.close(cm_fig)

    # Create metrics bar chart (a single figure; no stray pre-allocated one).
    metrics_data = np.array(metrics_data, dtype=float).reshape(-1, 3)
    x = np.arange(len(labels_list))
    width = 0.25

    bar_fig, ax = plt.subplots(figsize=(12, 6))
    ax.bar(x - width, metrics_data[:, 0], width, label='Precision')
    ax.bar(x, metrics_data[:, 1], width, label='Recall')
    ax.bar(x + width, metrics_data[:, 2], width, label='F1-Score')

    ax.set_xlabel('Manufacturer')
    ax.set_ylabel('Score')
    ax.set_title('Classification Metrics by Manufacturer')
    ax.set_xticks(x)
    ax.set_xticklabels(labels_list, rotation=45)
    ax.legend()
    ax.grid(axis='y', alpha=0.3)

    bar_fig.tight_layout()
    writer.add_figure('Metrics/Per_Class_Comparison', bar_fig, 0)
    plt.close(bar_fig)

    # Calculate final accuracy (guarding against an evaluation set without ground truth)
    n_correct = sum(t == p for t, p in zip(y_true, y_pred))
    acc = 100.0 * n_correct / len(y_true) if y_true else 0.0
    print(f"\nPer-screw Accuracy: {acc:.2f}%")

    # Add final accuracy as the last point
    writer.add_scalar('Accuracy/per_screw_running', acc, batch_counter)
    writer.add_scalar('Accuracy/per_screw_final', acc, 0)

    # Create accuracy summary figure
    acc_fig, acc_ax = plt.subplots(figsize=(8, 6))
    acc_ax.bar(['Per-screw Accuracy'], [acc], color='blue')
    acc_ax.set_ylim(0, 100)
    acc_ax.set_ylabel('Accuracy (%)')
    acc_ax.set_title('Model Accuracy Summary')
    for i, v in enumerate([acc]):
        acc_ax.text(i, v + 1, f'{v:.2f}%', ha='center', va='bottom')
    acc_ax.grid(axis='y', alpha=0.3)
    writer.add_figure('Accuracy/Summary', acc_fig, 0)
    plt.close(acc_fig)

    # Add overall metrics to TensorBoard
    writer.add_scalar('Metrics/Overall/Precision', report_dict['weighted avg']['precision'], 0)
    writer.add_scalar('Metrics/Overall/Recall', report_dict['weighted avg']['recall'], 0)
    writer.add_scalar('Metrics/Overall/F1-Score', report_dict['weighted avg']['f1-score'], 0)

    # Add histogram of predictions
    writer.add_histogram('Predictions/Distribution', torch.tensor(y_pred), 0)
    writer.add_histogram('GroundTruth/Distribution', torch.tensor(y_true), 0)

    writer.close()

    print(f"\nTensorBoard logs saved. Run 'tensorboard --logdir runs' to view.")
    print(f"Sample images saved: {sample_counter} images")

    return acc

In [None]:
def _find_label_map(dataset):
    """Return the ``manu_to_label`` mapping, unwrapping ``.dataset`` wrappers (e.g. ``Subset``)."""
    base = dataset
    while not hasattr(base, 'manu_to_label') and hasattr(base, 'dataset'):
        base = base.dataset
    return base.manu_to_label


def _vote_from_scores(scores):
    """Return the 1-based class with the highest cumulative score, or 0 if no class got a vote."""
    best = max(scores)
    if best == 0:
        return 0
    # list.index returns the first match, so ties resolve to the lowest class id.
    return scores.index(best) + 1


def _collect_screw_predictions(model, data_loader, device, writer, n_classes,
                               iou_threshold, score_threshold, visualize_samples):
    """Run the model over ``data_loader`` and match every ground-truth box to a prediction.

    Returns:
        tuple: ``(y_true_screw, y_pred_screw, indivs, xray_predictions)`` where
        ``indivs[img_id]`` is ``[true_label, per_class_score_sums]`` and
        ``xray_predictions`` holds the raw data of at most ``visualize_samples``
        images for later plotting.
    """
    y_true_screw, y_pred_screw = [], []
    indivs = defaultdict(lambda: [None, [0.0] * n_classes])
    xray_predictions = {}

    batch_counter = 0
    running_correct = 0
    running_total = 0

    with torch.no_grad():
        for images, targets in tqdm(data_loader, desc="Evaluating"):
            image = images[0].to(device)
            target = targets[0]
            img_id = target['image_id'].item()
            true_label = target['labels'][0].item() if len(target['labels']) > 0 else 0

            prediction = model([image])[0]
            boxes_pred = prediction['boxes'].cpu()
            scores_pred = prediction['scores'].cpu()
            labels_pred = prediction['labels'].cpu()

            boxes_gt = target['boxes'].cpu()
            labels_gt = target['labels'].cpu()

            # Drop low-confidence detections before matching.
            keep = scores_pred > score_threshold
            boxes_pred = boxes_pred[keep]
            labels_pred = labels_pred[keep]
            scores_pred = scores_pred[keep]

            # Only retain as many images as will actually be plotted; keeping
            # every image of the dataset would grow memory without bound.
            if len(xray_predictions) < visualize_samples:
                xray_predictions[img_id] = {
                    'image': image.cpu(),
                    'gt_boxes': boxes_gt,
                    'gt_labels': labels_gt,
                    'pred_boxes': boxes_pred,
                    'pred_labels': labels_pred,
                    'pred_scores': scores_pred,
                    'true_xray_label': true_label
                }

            if len(boxes_pred) == 0:
                # Nothing detected: every ground-truth screw counts as a miss (class 0).
                for gt_label in labels_gt:
                    y_true_screw.append(gt_label.item())
                    y_pred_screw.append(0)
                    indivs[img_id][0] = gt_label.item()
                # The running-accuracy update below is skipped here; these screws
                # are folded in on the next image that has predictions.
                continue

            ious = box_iou(boxes_gt, boxes_pred)
            for idx, gt_label in enumerate(labels_gt):
                max_iou, max_idx = ious[idx].max(0)
                if max_iou > iou_threshold:
                    pred_label = labels_pred[max_idx].item()
                    pred_score = scores_pred[max_idx].item()
                else:
                    pred_label = 0
                    pred_score = 0.0

                y_true_screw.append(gt_label.item())
                y_pred_screw.append(pred_label)

                # Voting update: accumulate the confidence of the matched class.
                indivs[img_id][0] = gt_label.item()
                if pred_label > 0:
                    indivs[img_id][1][pred_label - 1] += pred_score

            # Update running accuracy with everything appended since the last update.
            if len(y_true_screw) > running_total:
                new_pred = y_pred_screw[running_total:]
                new_true = y_true_screw[running_total:]
                running_correct += sum(t == p for t, p in zip(new_true, new_pred))
                running_total = len(y_true_screw)

                if batch_counter % 10 == 0 and running_total > 0:
                    current_acc = 100.0 * running_correct / running_total
                    writer.add_scalar('Accuracy/per_screw_voting_running', current_acc, batch_counter)

                batch_counter += 1

    return y_true_screw, y_pred_screw, indivs, xray_predictions


def _render_voting_sample(data, inv_label_map, n_classes):
    """Build the 3-panel figure (ground truth, predictions, voting scores) for one X-ray."""
    img_cpu = data['image']
    if img_cpu.dtype != torch.uint8:
        img_cpu = (img_cpu * 255).byte()

    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(20, 7))

    # Ground truth
    gt_img = img_cpu.clone()
    if len(data['gt_boxes']) > 0:
        gt_labels_text = [str(inv_label_map.get(label.item(), f'Class {label.item()}'))
                          for label in data['gt_labels']]
        gt_img = draw_bounding_boxes(gt_img, data['gt_boxes'],
                                     labels=gt_labels_text,
                                     colors="green", width=2)

    # Resolved outside the f-string: reusing the enclosing quote character
    # inside an f-string is a SyntaxError before Python 3.12.
    true_xray_label = data['true_xray_label']
    true_name = inv_label_map.get(true_xray_label, f'Class {true_xray_label}')
    ax1.imshow(gt_img.permute(1, 2, 0))
    ax1.set_title(f"Ground Truth\nX-ray Label: {true_name}")
    ax1.axis('off')

    # Predictions
    pred_img = img_cpu.clone()
    if len(data['pred_boxes']) > 0:
        pred_labels_text = [f"{inv_label_map.get(label.item(), f'Class {label.item()}')} ({score:.2f})"
                            for label, score in zip(data['pred_labels'], data['pred_scores'])]
        pred_img = draw_bounding_boxes(pred_img, data['pred_boxes'],
                                       labels=pred_labels_text,
                                       colors="red", width=2)

    ax2.imshow(pred_img.permute(1, 2, 0))
    ax2.set_title("Individual Predictions")
    ax2.axis('off')

    # Voting result: bar chart of the cumulative score per class.
    voted_label = data.get('voted_label', 0)
    voting_scores = data.get('voting_scores', [0.0] * n_classes)
    class_positions = range(1, len(voting_scores) + 1)

    ax3.bar(class_positions, voting_scores)
    ax3.set_xlabel('Class')
    ax3.set_ylabel('Cumulative Score')
    ax3.set_title(f"Voting Result\nPredicted: {inv_label_map.get(voted_label, f'Class {voted_label}')}")
    ax3.set_xticks(class_positions)
    ax3.set_xticklabels([inv_label_map.get(i, f'{i}') for i in class_positions],
                        rotation=45)
    ax3.grid(axis='y', alpha=0.3)

    if voted_label > 0:
        # Overdraw the winning class so it stands out.
        ax3.bar(voted_label, voting_scores[voted_label - 1], color='green', alpha=0.7)

    fig.tight_layout()
    return fig


def _log_confusion_matrix(writer, cm, class_names, title, tag, cmap):
    """Draw ``cm`` as an annotated heatmap and log it to TensorBoard under ``tag``."""
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap=cmap,
                xticklabels=class_names, yticklabels=class_names, ax=ax)
    ax.set_title(title)
    ax.set_ylabel('True Label')
    ax.set_xlabel('Predicted Label')
    writer.add_figure(tag, fig, 0)
    plt.close(fig)


def evaluate_model_with_voting(model, dataset, device, n_classes=7, iou_threshold=0.3, visualize_samples=10,
                               score_threshold=0.3, log_dir='runs/mask_rcnn_voting_evaluation_latest_v1'):
    """Evaluate a detection model per screw and per X-ray (score-weighted voting).

    Each ground-truth box is matched to the prediction with the highest IoU; the
    match counts only if that IoU exceeds ``iou_threshold``, otherwise the screw
    is recorded as class 0. Per image, the confidences of the matched classes
    are summed and the class with the largest sum becomes the X-ray prediction.
    Reports and confusion matrices are printed and logged to TensorBoard, and
    up to ``visualize_samples`` figures are written to ``voting_sample_<i>.png``.

    Args:
        model: Detection model called as ``model([image])``; each result must
            provide ``'boxes'``, ``'scores'`` and ``'labels'``.
        dataset: Dataset yielding ``(image, target)`` pairs whose targets hold
            ``'image_id'``, ``'boxes'`` and ``'labels'``. It, or a dataset it
            wraps via ``.dataset``, must expose ``manu_to_label``.
        device: Device the images are moved to before inference.
        n_classes: Size of the per-image vote vector; reports cover labels
            ``1 .. n_classes - 1`` (label 0 is used for "no match").
        iou_threshold: Minimum IoU for a prediction to count as a match.
        visualize_samples: Maximum number of X-rays to plot and save.
        score_threshold: Predictions with a score at or below this are dropped.
        log_dir: TensorBoard log directory.

    Returns:
        tuple[float, float]: ``(screw_accuracy, aggregate_accuracy)`` in percent.
        Either value is 0.0 when there was nothing to score.
    """
    # Resolve the label names first so a wrong dataset type fails before inference.
    label_map = _find_label_map(dataset)
    inv_label_map = {v: k for k, v in label_map.items()}
    class_ids = list(range(1, n_classes))
    class_names = [inv_label_map.get(i, f"Class_{i}") for i in class_ids]

    model.eval()
    data_loader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=lambda b: tuple(zip(*b)))

    saved_samples = 0
    writer = SummaryWriter(log_dir)
    try:
        y_true_screw, y_pred_screw, indivs, xray_predictions = _collect_screw_predictions(
            model, data_loader, device, writer, n_classes,
            iou_threshold, score_threshold, visualize_samples
        )

        # Aggregate voting results
        aggregate_correct = 0
        y_true_xray = []
        y_pred_xray = []

        for img_id, (true_label, scores) in indivs.items():
            if true_label is None:
                continue

            predicted_label = _vote_from_scores(scores)

            y_true_xray.append(true_label)
            y_pred_xray.append(predicted_label)

            if img_id in xray_predictions:
                xray_predictions[img_id]['voted_label'] = predicted_label
                xray_predictions[img_id]['voting_scores'] = scores

            if predicted_label == true_label:
                aggregate_correct += 1

        # Visualize sample X-rays with voting results
        for idx, data in enumerate(xray_predictions.values()):
            if idx >= visualize_samples:
                break

            fig = _render_voting_sample(data, inv_label_map, n_classes)
            # Write the file first: add_figure closes the figure by default.
            fig.savefig(f'voting_sample_{idx}.png', bbox_inches='tight', dpi=150)
            writer.add_figure(f'Voting_Samples/xray_{idx}', fig, 0)
            plt.close(fig)
            saved_samples += 1

        # Per-screw report and confusion matrix
        print("\n--- Per-Screw Classification Report ---")
        print(classification_report(
            y_true_screw, y_pred_screw, labels=class_ids, zero_division=0
        ))

        cm = confusion_matrix(y_true_screw, y_pred_screw, labels=class_ids)
        print("Confusion Matrix:")
        print(cm)
        _log_confusion_matrix(writer, cm, class_names,
                              'Confusion Matrix - Per Screw',
                              'Confusion_Matrix/Per_Screw', 'Blues')

        # Per-X-ray (voting) report and confusion matrix
        print("\n--- Per-Xray (Voting) Classification Report ---")
        print(classification_report(
            y_true_xray, y_pred_xray, labels=class_ids, zero_division=0
        ))

        cm_xray = confusion_matrix(y_true_xray, y_pred_xray, labels=class_ids)
        print("Confusion Matrix (Per-Xray):")
        print(cm_xray)

        # Guard both ratios so an empty evaluation reports 0% instead of raising.
        screw_correct = sum(p == t for p, t in zip(y_pred_screw, y_true_screw))
        screw_accuracy = 100.0 * screw_correct / len(y_true_screw) if y_true_screw else 0.0
        aggregate_accuracy = 100.0 * aggregate_correct / len(y_true_xray) if y_true_xray else 0.0

        print("\n--- Accuracy Summary ---")
        print(f"Per-screw classification accuracy: {screw_accuracy:.2f}%")
        print(f"Aggregate (per-X-ray) accuracy: {aggregate_accuracy:.2f}%")

        # Add final accuracies
        writer.add_scalar('Accuracy/per_screw_voting_final', screw_accuracy, 0)
        writer.add_scalar('Accuracy/aggregate_per_xray_final', aggregate_accuracy, 0)

        _log_confusion_matrix(writer, cm_xray, class_names,
                              'Confusion Matrix - Per X-ray (Voting)',
                              'Confusion_Matrix/Per_Xray_Voting', 'Greens')
    finally:
        # Flush and release the event file even if evaluation fails midway.
        writer.close()

    print("\nTensorBoard logs saved. Run 'tensorboard --logdir runs' to view.")
    print(f"Sample images saved: {saved_samples} images")

    return screw_accuracy, aggregate_accuracy

In [None]:
# Evaluation driver: rebuild the test split, restore the trained weights and
# run both evaluation passes (plain per-screw and per-X-ray voting).
data_root = "xray_data/valid_data"
dataset = CustomDataset(data_root, transforms=get_transform(train=False))

# 70% / 20% / remainder split; only the last (test) part is used here.
# The fixed generator seed makes the split reproducible across runs.
# NOTE(review): presumably mirrors the split used at training time - confirm.
total_size = len(dataset)
train_size, val_size = int(0.7 * total_size), int(0.2 * total_size)
test_size = total_size - train_size - val_size
split_generator = torch.Generator().manual_seed(42)
_, _, test_dataset = random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=split_generator,
)

# One output class per manufacturer label plus one extra class.
num_classes = len(dataset.manu_to_label) + 1
weights_path = "checkpoints_improved_random_sampler_new_v1/final_model_checkpoint_random_sampler.pth"
checkpoint = torch.load(weights_path, map_location=device)

# Build the architecture, then load the saved parameters into it.
model = load_model(num_classes=num_classes, device=device)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

# Samples are kept as tuples of (images, targets) instead of being stacked.
test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=lambda b: tuple(zip(*b)),
)

# Both passes log their results to TensorBoard.
evaluate_model(model, test_loader, device, test_dataset)
evaluate_model_with_voting(model, test_dataset, device, n_classes=num_classes)

print("\n Evaluation complete! To view results in TensorBoard, run:")
print("tensorboard --logdir runs")