In [1]:
import pickle

# Open the pickle file in binary read mode ('rb')
with open('my_dict.pkl', 'rb') as file:
    batch = pickle.load(file)

# Now 'data' contains the Python object that was saved
print(batch.keys())

dict_keys(['batch_idx', 'bboxes', 'cls', 'im_file', 'img', 'ori_shape', 'resized_shape'])


In [2]:
# Open the pickle file in binary read mode ('rb')
with open('t_list.pkl', 'rb') as file:
    t_list = pickle.load(file)

# Now 'data' contains the Python object that was saved
print(len(t_list))

2


In [62]:
t_list[0].shape, t_list[1].shape, t_list[2].shape,

(torch.Size([4, 144, 80, 80]),
 torch.Size([4, 144, 40, 40]),
 torch.Size([4, 144, 20, 20]))

In [63]:
import torch
import torch.nn.functional as F

def make_anchors(feats, strides, device=None):
    """
    Create anchor grid (cell centers) for each feature level.
    feats: list of feature tensors [B, C, H, W]
    strides: list of strides per level (len = len(feats))
    returns:
        anchor_points: [N, 2] (x, y) absolute coords
        stride_tensor: [N]
    """
    anchor_points = []
    stride_tensor = []
    for f, s in zip(feats, strides):
        _, _, h, w = f.shape
        # grid of centers in feature coordinates
        yv, xv = torch.meshgrid(
            torch.arange(h, device=f.device),
            torch.arange(w, device=f.device),
            indexing="ij",
        )
        # center of each cell: (x + 0.5, y + 0.5) * stride
        xy = torch.stack((xv + 0.5, yv + 0.5), dim=-1).view(-1, 2) * s
        anchor_points.append(xy)
        stride_tensor.append(torch.full((h * w,), s, device=f.device, dtype=torch.float32))
    anchor_points = torch.cat(anchor_points, dim=0)         # [N, 2]
    stride_tensor = torch.cat(stride_tensor, dim=0)         # [N]
    return anchor_points, stride_tensor


def dfl_to_distances(dfl_logits, num_bins=16):
    """
    Convert DFL logits to expected distances.
    dfl_logits: [B, 4*num_bins, N]  (flattened H*W over all levels)
    returns:
        distances: [B, 4, N] in bin units (no stride applied)
    """
    B, C, N = dfl_logits.shape
    assert C == 4 * num_bins, f"Expected {4*num_bins} channels for DFL, got {C}"
    
    # [B, 4, num_bins, N]
    dfl_logits = dfl_logits.view(B, 4, num_bins, N)
    # softmax over bins
    prob = F.softmax(dfl_logits, dim=2)
    # bin indices [0..num_bins-1]
    proj = torch.arange(num_bins, device=dfl_logits.device, dtype=torch.float32)
    # expectation over bins: sum(prob * index)
    distances = (prob * proj.view(1, 1, num_bins, 1)).sum(dim=2)  # [B, 4, N]
    return distances


def yolo_head_to_unified(preds, strides=(8, 16, 32), num_classes=80, num_bins=16):
    """
    preds: list of [B, 144, H, W] = [B, 64 + num_classes, H, W]
    strides: stride per feature map (same length as preds)
    returns:
        out: [B, 4 + num_classes, N_total] = [B, 84, 8400]
        where coords are [x, y, w, h] in absolute image coords (assuming stride scale).
    """
    # check basic assumptions
    B = preds[0].shape[0]
    device = preds[0].device

    # concat feature maps spatially
    # each level: [B, C, H, W] -> [B, C, H*W]
    flattened = []
    for p in preds:
        b, c, h, w = p.shape
        assert c == 4 * num_bins + num_classes, f"Expected {4*num_bins + num_classes} channels, got {c}"
        flattened.append(p.view(b, c, -1))
    # [B, C, N_total]
    pred_all = torch.cat(flattened, dim=2)
    pred_all.shape
    # split into DFL logits and class logits
    dfl_logits = pred_all[:, :4*num_bins, :]        # [B, 64, N_total]
    cls_logits = pred_all[:, 4*num_bins:, :]        # [B, num_classes, N_total]

    # convert DFL to distances (in feature-cell units)
    distances = dfl_to_distances(dfl_logits, num_bins=num_bins)  # [B, 4, N_total] (l, t, r, b) in bins

    # build anchors and strides for all levels
    anchor_points, stride_tensor = make_anchors(preds, strides, device=device)  # [N_total, 2], [N_total]
    N_total = anchor_points.shape[0]
    assert N_total == pred_all.shape[2], "Anchor count mismatch with prediction locations"

    # apply stride to distances: convert bin units to absolute pixels
    # distances: [B, 4, N]; stride_tensor: [N]
    stride_broadcast = stride_tensor.view(1, 1, N_total)
    distances_abs = distances * stride_broadcast  # [B, 4, N]

    # l, t, r, b -> x_center, y_center, w, h
    l = distances_abs[:, 0, :]   # [B, N]
    t = distances_abs[:, 1, :]
    r = distances_abs[:, 2, :]
    b = distances_abs[:, 3, :]

    # anchor_points: [N, 2] -> [1, N, 2]
    ap = anchor_points.view(1, N_total, 2)
    x_center = ap[..., 0]  # [1, N]
    y_center = ap[..., 1]  # [1, N]

    # x1 = x_center - l, y1 = y_center - t, x2 = x_center + r, y2 = y_center + b
    # then convert to center-width-height
    x1 = x_center - l
    y1 = y_center - t
    x2 = x_center + r
    y2 = y_center + b

    cx = (x1 + x2) / 2.0
    cy = (y1 + y2) / 2.0
    w  = x2 - x1
    h  = y2 - y1

    # stack coords: [B, 4, N]
    coords = torch.stack([cx, cy, w, h], dim=1)  # [B, 4, N]
    print("cls_shape",cls_logits.shape)
    # sigmoid the class logits: [B, num_classes, N]
    cls_scores = cls_logits.sigmoid()

    # final unified tensor: [B, 4 + num_classes, N]
    out = torch.cat([coords, cls_scores], dim=1)
    return out


# EXAMPLE USAGE:
if __name__ == "__main__":
    B = 4
    num_classes = 80
    num_bins = 16
    C = 4 * num_bins + num_classes  # 64 + 80 = 144

    preds = t_list

    out = yolo_head_to_unified(preds, strides=(8, 16, 32), num_classes=num_classes, num_bins=num_bins)
    print(out.shape)  # should be [4, 84, 8400]

cls_shape torch.Size([4, 80, 8400])
torch.Size([4, 84, 8400])


In [64]:
prediction = out[1].unsqueeze(0)
print(prediction.shape)

torch.Size([1, 84, 8400])


In [65]:
def xywh2xyxy(x):
    """
    Convert bounding box coordinates from (x, y, width, height) format to (x1, y1, x2, y2) format where (x1, y1) is the
    top-left corner and (x2, y2) is the bottom-right corner. Note: ops per 2 channels faster than per channel.

    Args:
        x (np.ndarray | torch.Tensor): Input bounding box coordinates in (x, y, width, height) format.

    Returns:
        (np.ndarray | torch.Tensor): Bounding box coordinates in (x1, y1, x2, y2) format.
    """
    assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
    y = empty_like(x)  # faster than clone/copy
    xy = x[..., :2]  # centers
    wh = x[..., 2:] / 2  # half width-height
    y[..., :2] = xy - wh  # top left xy
    y[..., 2:] = xy + wh  # bottom right xy
    return y

In [66]:
def empty_like(x):
    """Create empty torch.Tensor or np.ndarray with same shape as input and float32 dtype."""
    return (
        torch.empty_like(x, dtype=torch.float32) if isinstance(x, torch.Tensor) else np.empty_like(x, dtype=np.float32)
    )

In [67]:
conf_thres = 0.25
iou_thres = 0.75
classes=None
agnostic = False
multi_label = False
labels=()
max_det = 300
nc: int = 0
max_time_img = 0.05
max_nms = 30000
max_wh = 7680
in_place = True
rotated = False
end2end = False
return_idxs = False

In [68]:
bs = prediction.shape[0]  # batch size (BCN, i.e. 1,84,6300)
nc = nc or (prediction.shape[1] - 4)  # number of classes
extra = prediction.shape[1] - nc - 4  # number of extra info
mi = 4 + nc  # mask start index
xc = prediction[:, 4:mi].amax(1) > conf_thres  # candidates
xinds = torch.stack([torch.arange(len(i),) for i in xc])[..., None]  # to track idxs

In [69]:
print(bs, nc, extra, mi, xc, xinds, sep= "\n")

1
80
0
84
tensor([[False, False, False,  ..., False, False, False]])
tensor([[[   0],
         [   1],
         [   2],
         ...,
         [8397],
         [8398],
         [8399]]])


In [70]:
xinds.shape

torch.Size([1, 8400, 1])

In [71]:
xc[xc==True], xc.shape

(tensor([True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True]),
 torch.Size([1, 8400]))

In [72]:
# prediction = prediction.transpose(-1, -2)

In [73]:
prediction[..., :4] = xywh2xyxy(prediction[..., :4])

In [74]:
output = [torch.zeros((0, 6 + extra))] * bs
keepi = [torch.zeros((0, 1))] * bs  # to store the kept idxs

In [75]:
# for xi, (x, xk) in enumerate(zip(prediction, xinds)):  # image index, (preds, preds indices)
#     # print(xi, x.shape, xk.shape, sep= "\n")
#     # # Apply constraints
#     # # x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0  # width-height
#     filt = xc[xi]  # confidence
#     x, xk = x[filt], xk[filt]

In [76]:
# filt.shape, x.shape, xk.shape

In [77]:
# # Cat apriori labels if autolabelling
# if labels and len(labels[xi]) and not rotated:
#     lb = labels[xi]
#     v = torch.zeros((len(lb), nc + extra + 4), device=x.device)
#     v[:, :4] = xywh2xyxy(lb[:, 1:5])  # box
#     v[range(len(lb)), lb[:, 0].long() + 4] = 1.0  # cls
#     x = torch.cat((x, v), 0)

In [78]:
# x.shape

In [79]:
# box, cls, mask = x.split((4, nc, extra), 1)

# if multi_label:
#     i, j = torch.where(cls > conf_thres)
#     x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1)
#     xk = xk[i]
# else:  # best class only
#     conf, j = cls.max(1, keepdim=True)
#     filt = conf.view(-1) > conf_thres
#     x = torch.cat((box, conf, j.float(), mask), 1)[filt]
#     xk = xk[filt]

In [80]:
# box.shape

In [81]:
# if classes is not None:
#     filt = (x[:, 5:6] == classes).any(1)
#     x, xk = x[filt], xk[filt]

# n = x.shape[0]  # number of boxes
# if n > max_nms:  # excess boxes
#     filt = x[:, 4].argsort(descending=True)[:max_nms]  # sort by confidence and remove excess boxes
#     x, xk = x[filt], xk[filt]

In [82]:
# import torch
# import torchvision
# import numpy as np

# def plot_boxes_on_image(image_tensor, boxes, classes, class_names=None, colors=None):
#     """
#     Plot bounding boxes and class labels directly on an image tensor.
    
#     Args:
#         image_tensor (torch.Tensor): Image tensor in shape (C, H, W) or (H, W, C)
#         boxes (torch.Tensor): Bounding boxes in xyxy format, shape (N, 4)
#         classes (torch.Tensor): Class indices, shape (N,)
#         class_names (list): List of class names
#         colors (list): List of RGB colors for different classes
    
#     Returns:
#         torch.Tensor: Image tensor with boxes and labels drawn
#     """
#     # Ensure image is in (C, H, W) format and uint8
#     if image_tensor.dim() == 3:
#         if image_tensor.shape[0] == 3:  # (C, H, W)
#             image = image_tensor
#         else:  # (H, W, C)
#             image = image_tensor.permute(2, 0, 1)
#     else:
#         raise ValueError("Image tensor must be 3-dimensional")
    
#     # Convert to uint8 if needed
#     if image.dtype != torch.uint8:
#         if image.max() <= 1.0:
#             image = (image * 255).byte()
#         else:
#             image = image.byte()
    
#     # Make a copy to draw on
#     image_with_boxes = image.clone()
#     _, H, W = image_with_boxes.shape
    
#     # Default colors for different classes (BGR format for easy drawing)
#     if colors is None:
#         colors = [
#             (255, 0, 0),    # red
#             (0, 255, 0),    # green  
#             (0, 0, 255),    # blue
#             (255, 255, 0),  # cyan
#             (255, 0, 255),  # magenta
#             (0, 255, 255),  # yellow
#             (128, 0, 0),    # dark red
#             (0, 128, 0),    # dark green
#             (0, 0, 128),    # dark blue
#             (128, 128, 0),  # olive
#         ]
    
#     # Default class names
#     if class_names is None:
#         class_names = [f"class_{i}" for i in range(int(classes.max().item()) + 1)]
    
#     # Convert boxes to pixel coordinates
#     boxes = boxes.clone()
#     boxes[:, [0, 2]] = boxes[:, [0, 2]].clamp(0, W)  # x coordinates
#     boxes[:, [1, 3]] = boxes[:, [1, 3]].clamp(0, H)  # y coordinates
#     boxes = boxes.int()
    
#     # Draw each box
#     for i, (box, cls_idx) in enumerate(zip(boxes, classes)):
#         cls_idx = int(cls_idx.item())
#         color = colors[cls_idx % len(colors)]
        
#         x1, y1, x2, y2 = box.tolist()
        
#         # Draw bounding box
#         draw_rectangle(image_with_boxes, x1, y1, x2, y2, color, thickness=2)
        
#         # Draw class label background
#         label = class_names[cls_idx]
#         draw_text_with_background(image_with_boxes, label, x1, y1 - 10, color)
    
#     return image_with_boxes

# def draw_rectangle(image, x1, y1, x2, y2, color, thickness=2):
#     """
#     Draw a rectangle on the image tensor.
#     """
#     C, H, W = image.shape
    
#     # Draw horizontal lines
#     for t in range(thickness):
#         # Top line
#         y_top = max(0, min(H-1, y1 + t))
#         if y_top < H:
#             image[:, y_top, max(0, x1):min(W, x2+1)] = torch.tensor(color).view(3, 1)
        
#         # Bottom line  
#         y_bottom = max(0, min(H-1, y2 - t))
#         if y_bottom >= 0:
#             image[:, y_bottom, max(0, x1):min(W, x2+1)] = torch.tensor(color).view(3, 1)
    
#     # Draw vertical lines
#     for t in range(thickness):
#         # Left line
#         x_left = max(0, min(W-1, x1 + t))
#         if x_left < W:
#             image[:, max(0, y1):min(H, y2+1), x_left] = torch.tensor(color).view(3, 1)
        
#         # Right line
#         x_right = max(0, min(W-1, x2 - t))
#         if x_right >= 0:
#             image[:, max(0, y1):min(H, y2+1), x_right] = torch.tensor(color).view(3, 1)

# def draw_text_with_background(image, text, x, y, color, bg_color=(0, 0, 0)):
#     """
#     Draw simple text using rectangles (simplified character drawing).
#     This is a basic implementation - for better text, consider using a proper font rendering.
#     """
#     C, H, W = image.shape
    
#     # Adjust y position to ensure it's within image bounds
#     y = max(10, min(H - 15, y))
#     x = max(0, min(W - len(text) * 6, x))
    
#     # Draw background rectangle for text
#     bg_height = 10
#     bg_width = len(text) * 6
#     for i in range(bg_height):
#         for j in range(bg_width):
#             if y + i < H and x + j < W:
#                 image[:, y + i, x + j] = torch.tensor(bg_color)
    
#     # Draw simple text (using colored pixels)
#     for char_idx, char in enumerate(text):
#         char_x = x + char_idx * 6
#         # Simple character patterns (very basic)
#         if char.isalpha() or char.isdigit():
#             # Draw a simple pattern for each character
#             for i in range(3, 8):  # vertical
#                 for j in range(2, 5):  # horizontal
#                     if y + i < H and char_x + j < W:
#                         image[:, y + i, char_x + j] = torch.tensor(color)

# # Example usage with your NMS output:
# def visualize_nms_results(image_tensor, nms_output, class_names=None):
#     """
#     Visualize NMS results on an image.
    
#     Args:
#         image_tensor: Input image tensor
#         nms_output: Output from non_max_suppression function
#         class_names: List of class names
    
#     Returns:
#         Image tensor with detections drawn
#     """
#     if len(nms_output) == 0 or nms_output[0].shape[0] == 0:
#         return image_tensor
    
#     # Extract boxes and classes from NMS output
#     # Assuming nms_output[0] has shape (N, 6) where columns are: x1, y1, x2, y2, conf, cls
#     detections = nms_output[0]
#     boxes = detections[:, :4]  # x1, y1, x2, y2
#     classes = detections[:, 5]  # class indices
    
#     return plot_boxes_on_image(image_tensor, boxes, classes, class_names)

# # Alternative: Direct visualization from the point where you have box, cls, mask
# def visualize_detections_before_nms(image_tensor, box, cls, conf_thres=0.25, class_names=None):
#     """
#     Visualize detections before NMS is applied.
    
#     Args:
#         image_tensor: Input image tensor
#         box: Bounding boxes from x.split((4, nc, extra), 1)
#         cls: Class predictions from x.split((4, nc, extra), 1)
#         conf_thres: Confidence threshold
#         class_names: List of class names
#     """
#     # Apply confidence threshold and get best class
#     conf, j = cls.max(1, keepdim=True)
#     mask = conf.view(-1) > conf_thres
    
#     boxes_filtered = box[mask]
#     classes_filtered = j[mask].squeeze()
    
#     return plot_boxes_on_image(image_tensor, boxes_filtered, classes_filtered, class_names)

# # Usage example:
# # Assuming you're inside the NMS function after: box, cls, mask = x.split((4, nc, extra), 1)

# # You can visualize like this:
# plot_boxes_on_image(batch["img"][0], box, cls.argmax(1))
# # Or if you want to see only confident detections:
# # conf, j = cls.max(1, keepdim=True)
# # mask = conf.view(-1) > conf_thres
# # image_with_boxes = plot_boxes_on_image(your_image_tensor, box[mask], j[mask].squeeze())

In [83]:
import torchvision

In [84]:
for xi, (x, xk) in enumerate(zip(prediction, xinds)):  # image index, (preds, preds indices)
    # Apply constraints
    # x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0  # width-height
    filt = xc[xi]  # confidence
    x, xk = x[filt], xk[filt]

    # Cat apriori labels if autolabelling
    if labels and len(labels[xi]) and not rotated:
        lb = labels[xi]
        v = torch.zeros((len(lb), nc + extra + 4), device=x.device)
        v[:, :4] = xywh2xyxy(lb[:, 1:5])  # box
        v[range(len(lb)), lb[:, 0].long() + 4] = 1.0  # cls
        x = torch.cat((x, v), 0)

    # If none remain process next image
    if not x.shape[0]:
        continue

    # Detections matrix nx6 (xyxy, conf, cls)
    box, cls, mask = x.split((4, nc, extra), 1)

    if multi_label:
        i, j = torch.where(cls > conf_thres)
        x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1)
        xk = xk[i]
    else:  # best class only
        conf, j = cls.max(1, keepdim=True)
        filt = conf.view(-1) > conf_thres
        x = torch.cat((box, conf, j.float(), mask), 1)[filt]
        xk = xk[filt]

    # Filter by class
    if classes is not None:
        filt = (x[:, 5:6] == classes).any(1)
        x, xk = x[filt], xk[filt]

    # Check shape
    n = x.shape[0]  # number of boxes
    if not n:  # no boxes
        continue
    if n > max_nms:  # excess boxes
        filt = x[:, 4].argsort(descending=True)[:max_nms]  # sort by confidence and remove excess boxes
        x, xk = x[filt], xk[filt]

    # Batched NMS
    c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
    scores = x[:, 4]  # scores
    if rotated:
        boxes = torch.cat((x[:, :2] + c, x[:, 2:4], x[:, -1:]), dim=-1)  # xywhr
        i = nms_rotated(boxes, scores, iou_thres)
    else:
        boxes = x[:, :4] + c  # boxes (offset by class)
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
    i = i[:max_det]  # limit detections

    output[xi], keepi[xi] = x[i], xk[i].reshape(-1)

IndexError: The shape of the mask [8400] at index 0 does not match the shape of the indexed tensor [84, 8400] at index 0

In [None]:
output, keepi

In [None]:
import torch
import torchvision

def draw_bounding_boxes_on_tensor(image_tensor, boxes_tensor, labels_tensor=None, colors=None):
    """
    Draw bounding boxes on image tensor without matplotlib
    
    Args:
        image_tensor: Tensor of shape (C, H, W) or (H, W, C)
        boxes_tensor: Tensor of shape (N, 4) with [x1, y1, x2, y2] format
        labels_tensor: Optional tensor of shape (N,) with class labels
        colors: Optional list of colors for boxes
    
    Returns:
        Tensor with bounding boxes drawn
    """
    # Ensure image is in (C, H, W) format
    if image_tensor.dim() == 3 and image_tensor.shape[-1] == 3:
        image_tensor = image_tensor.permute(2, 0, 1)
    
    # Clone the image to avoid modifying original
    image_with_boxes = image_tensor.clone()
    
    # Define colors if not provided
    if colors is None:
        colors = [
            (255, 0, 0),    # Red
            (0, 255, 0),    # Green  
            (0, 0, 255),    # Blue
            (255, 255, 0),  # Yellow
            (255, 0, 255),  # Magenta
            (0, 255, 255),  # Cyan
        ]
    
    # Convert boxes to integer coordinates
    boxes = boxes_tensor.int()
    
    # Draw each bounding box
    for i, box in enumerate(boxes):
        x1, y1, x2, y2 = box
        color = colors[i % len(colors)]
        
        # Ensure coordinates are within image bounds
        x1 = max(0, min(x1, image_with_boxes.shape[2] - 1))
        y1 = max(0, min(y1, image_with_boxes.shape[1] - 1))
        x2 = max(0, min(x2, image_with_boxes.shape[2] - 1))
        y2 = max(0, min(y2, image_with_boxes.shape[1] - 1))
        
        # Draw top and bottom horizontal lines
        image_with_boxes[:, y1, x1:x2+1] = torch.tensor(color).view(3, 1)
        image_with_boxes[:, y2, x1:x2+1] = torch.tensor(color).view(3, 1)
        
        # Draw left and right vertical lines
        image_with_boxes[:, y1:y2+1, x1] = torch.tensor(color).view(3, 1)
        image_with_boxes[:, y1:y2+1, x2] = torch.tensor(color).view(3, 1)
    
    return image_with_boxes

# Example usage with your data
def process_detection_data(detection_data, image_tensor):
    """
    Process your detection data and draw bounding boxes
    
    Args:
        detection_data: Your tuple of ([boxes_tensor], [labels_tensor])
        image_tensor: Your input image tensor
    """
    boxes_list, labels_list = detection_data
    boxes_tensor = boxes_list[0]  # Extract the boxes tensor
    labels_tensor = labels_list[0]  # Extract the labels tensor
    
    # Extract only the bounding box coordinates (first 4 columns)
    bbox_coords = boxes_tensor[:, :4]
    
    # Draw bounding boxes
    result_image = draw_bounding_boxes_on_tensor(image_tensor, bbox_coords, labels_tensor)
    
    return result_image

# Alternative using torchvision (if available)
def draw_with_torchvision(image_tensor, boxes_tensor, labels_tensor):
    """
    Alternative method using torchvision utilities
    """
    # Extract bounding box coordinates
    bbox_coords = boxes_tensor[:, :4]
    
    # Create labels for the boxes
    labels = [f"Class {label}" for label in labels_tensor.tolist()]
    
    # Draw using torchvision (requires image in uint8 format)
    if image_tensor.dtype != torch.uint8:
        image_uint8 = (image_tensor * 255).byte()
    else:
        image_uint8 = image_tensor
    
    # Draw bounding boxes
    result = torchvision.utils.draw_bounding_boxes(
        image_uint8, 
        bbox_coords,
        labels=labels,
        colors="red",
        width=2
    )
    
    return result

# Example with dummy image data
def example_usage():
    # Create a dummy image tensor (3, 480, 640)
    dummy_image = torch.randn(3, 480, 640)
    
    # Your detection data
    detection_data = (
        [torch.tensor([[1.5998e+02, 1.5112e+01, 6.3102e+02, 4.5557e+02, 3.8524e-01, 0.0000e+00],
                      [3.2586e+02, 1.6286e+01, 5.1430e+02, 4.5842e+02, 3.2929e-01, 0.0000e+00]])],
        [torch.tensor([8131, 8153])]
    )
    
    # Process and draw bounding boxes
    result = process_detection_data(detection_data, dummy_image)
    
    print(f"Input image shape: {dummy_image.shape}")
    print(f"Output image shape: {result.shape}")
    print(f"Number of boxes: {len(detection_data[0][0])}")
    
    return result

# Run example
if __name__ == "__main__":
    result_image = example_usage()

In [None]:
from PIL import Image
import torchvision.transforms as transforms

def save_tensor_as_image(tensor, filename):
    """
    Save a tensor as an image file
    """
    # Convert tensor to PIL Image
    if tensor.dim() == 3 and tensor.shape[0] == 3:
        tensor = tensor.permute(1, 2, 0)
    
    # Convert to uint8 if needed
    if tensor.dtype != torch.uint8:
        tensor = (tensor * 255).byte()
    
    # Convert to PIL Image and save
    pil_image = transforms.ToPILImage()(tensor.permute(2, 0, 1))
    pil_image.save(filename)
    print(f"Image saved as {filename}")

# Save the result
save_tensor_as_image(batch["img"][0], "bounding_boxes_result.jpg")

In [None]:
import torch
import torchvision

def draw_bounding_boxes_on_tensor(image_tensor, boxes_tensor, labels_tensor=None, colors=None):
    """
    Draw bounding boxes on image tensor without matplotlib
    
    Args:
        image_tensor: Tensor of shape (C, H, W) or (H, W, C)
        boxes_tensor: Tensor of shape (N, 4) with [x1, y1, x2, y2] format
        labels_tensor: Optional tensor of shape (N,) with class labels
        colors: Optional list of colors for boxes
    
    Returns:
        Tensor with bounding boxes drawn
    """
    # Ensure image is in (C, H, W) format
    if image_tensor.dim() == 3 and image_tensor.shape[-1] == 3:
        image_tensor = image_tensor.permute(2, 0, 1)
    
    # Clone the image to avoid modifying original
    image_with_boxes = image_tensor.clone()
    
    # Define colors if not provided
    if colors is None:
        colors = [
            (255, 0, 0),    # Red
            (0, 255, 0),    # Green  
            (0, 0, 255),    # Blue
            (255, 255, 0),  # Yellow
            (255, 0, 255),  # Magenta
            (0, 255, 255),  # Cyan
        ]
    
    # Convert boxes to integer coordinates
    boxes = boxes_tensor.int()
    
    # Draw each bounding box
    for i, box in enumerate(boxes):
        x1, y1, x2, y2 = box
        color = colors[i % len(colors)]
        
        # Ensure coordinates are within image bounds
        x1 = max(0, min(x1, image_with_boxes.shape[2] - 1))
        y1 = max(0, min(y1, image_with_boxes.shape[1] - 1))
        x2 = max(0, min(x2, image_with_boxes.shape[2] - 1))
        y2 = max(0, min(y2, image_with_boxes.shape[1] - 1))
        
        # Draw top and bottom horizontal lines
        image_with_boxes[:, y1, x1:x2+1] = torch.tensor(color).view(3, 1)
        image_with_boxes[:, y2, x1:x2+1] = torch.tensor(color).view(3, 1)
        
        # Draw left and right vertical lines
        image_with_boxes[:, y1:y2+1, x1] = torch.tensor(color).view(3, 1)
        image_with_boxes[:, y1:y2+1, x2] = torch.tensor(color).view(3, 1)
    
    return image_with_boxes

# Example usage with your data
def process_detection_data(detection_data, image_tensor):
    """
    Process your detection data and draw bounding boxes
    
    Args:
        detection_data: Your tuple of ([boxes_tensor], [labels_tensor])
        image_tensor: Your input image tensor
    """
    boxes_list, labels_list = detection_data
    boxes_tensor = boxes_list[0]  # Extract the boxes tensor
    labels_tensor = labels_list[0]  # Extract the labels tensor
    
    # Extract only the bounding box coordinates (first 4 columns)
    bbox_coords = boxes_tensor[:, :4]
    
    # Draw bounding boxes
    result_image = draw_bounding_boxes_on_tensor(image_tensor, bbox_coords, labels_tensor)
    
    return result_image

# Alternative using torchvision (if available)
def draw_with_torchvision(image_tensor, boxes_tensor, labels_tensor):
    """
    Alternative method using torchvision utilities
    """
    # Extract bounding box coordinates
    bbox_coords = boxes_tensor[:, :4]
    
    # Create labels for the boxes
    labels = [f"Class {label}" for label in labels_tensor.tolist()]
    
    # Draw using torchvision (requires image in uint8 format)
    if image_tensor.dtype != torch.uint8:
        image_uint8 = (image_tensor * 255).byte()
    else:
        image_uint8 = image_tensor
    
    # Draw bounding boxes
    result = torchvision.utils.draw_bounding_boxes(
        image_uint8, 
        bbox_coords,
        labels=labels,
        colors="red",
        width=2
    )
    
    return result

# Example with dummy image data
def example_usage():
    # Create a dummy image tensor (3, 480, 640)
    dummy_image = torch.randn(3, 480, 640)
    
    # Your detection data
    detection_data = (
        [torch.tensor([[1.5998e+02, 1.5112e+01, 6.3102e+02, 4.5557e+02, 3.8524e-01, 0.0000e+00],
                      [3.2586e+02, 1.6286e+01, 5.1430e+02, 4.5842e+02, 3.2929e-01, 0.0000e+00]])],
        [torch.tensor([8131, 8153])]
    )
    
    # Process and draw bounding boxes
    result = process_detection_data(detection_data, dummy_image)
    
    print(f"Input image shape: {dummy_image.shape}")
    print(f"Output image shape: {result.shape}")
    print(f"Number of boxes: {len(detection_data[0][0])}")
    
    return result

# Run example
if __name__ == "__main__":
    result_image = example_usage()

In [None]:
from ultralytics.utils.ops import non_max_suppression
from ultralytics.engine.results import Results

# 1) apply YOLO's NMS (same used inside AutoBackend)
nms_output = non_max_suppression(
    prediction,
    conf_thres=0.25,
    iou_thres=0.75,
    max_det=300,
    classes=None,
    agnostic=False
)

In [None]:
nms_output