# Schemas

## V8DetectionLoss - call function

```python
# 145 = 80 + 4 * 16 + 1
feats = num_classes + box_coord * dfl_channels + distance
```

**Training**
```python
preds = list[
    tensor[1, 145, 80, 80], # [bs, feats, h/1, w/1]
    tensor[1, 145, 40, 40], # [bs, feats, h/2, w/2]
    tensor[1, 145, 20, 20]  # [bs, feats, h/4, w/4]
    ]
```

**Validation**
```python
preds = tuple[
    tensor[2, 84, 6174], # [bs, 84 (presumably 4 box coords + 80 classes), anchors]; anchors = 56*84 + 28*42 + 14*21 = 6174, so it depends on the input image
    list[
        tensor[2, 145, 56, 84], # [bs, feats, h/1, w/1] height and width depends on input image
        tensor[2, 145, 28, 42], # [bs, feats, h/2, w/2]
        tensor[2, 145, 14, 21]  # [bs, feats, h/4, w/4]
    ]

]
```

In [26]:
import torch
from ultralytics.utils.tal import make_anchors
# Reproduces the first steps of the V8DetectionLoss call on random data:
# flatten the three prediction maps, split them into DFL box logits, class
# scores and the extra distance channel, then build the anchor grid.
no = 145  # channels per anchor: 80 classes + 4 * 16 DFL bins + 1 distance
use_dist = True
# Use the Apple MPS backend when it is present and fall back to the CPU
# otherwise, so the cell also runs on machines without MPS.
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
stride = torch.tensor([8, 16, 32], device=device)

og_w = 640
og_h = 640

# Size of the first (stride 8) feature map.
input_w = og_w//8
input_h = og_h//8

# One random prediction map per detection level. They are created on `device`
# so that they live on the same device as `stride` and `imgsz`.
preds = [
    torch.randn(1, no, input_h, input_w, device=device),
    torch.randn(1, no, input_h//2, input_w//2, device=device),
    torch.randn(1, no, input_h//4, input_w//4, device=device),
]
feats = preds  # keep the un-flattened maps; they are needed for the anchors below
print("preds.shape", [p.shape for p in preds])

# Flatten the spatial dimensions of every level and concatenate the levels
# along the anchor axis: 80*80 + 40*40 + 20*20 = 8400 anchors.
preds = [xi.view(preds[0].shape[0], no, -1) for xi in preds]
preds = torch.cat(preds, 2)
# Split the channel axis into 4 * 16 DFL logits, 80 class scores and 1 distance.
# At this point pred_dist is [1, 1, 8400].
pred_distri, pred_scores, pred_dist = torch.split(preds, [64, 80, 1], dim=1)

# Move the channel axis last: [bs, anchors, channels].
pred_scores = pred_scores.permute(0, 2, 1).contiguous()
pred_distri = pred_distri.permute(0, 2, 1).contiguous()
pred_dist = pred_dist.permute(0, 2, 1).contiguous()
print("pred_scores.shape", pred_scores.shape) # [1, 8400, 80]
print("pred_distri.shape", pred_distri.shape) # [1, 8400, (4*16)]
print("pred_dist.shape", pred_dist.shape) # [1, 8400, 1]
print("---")
dtype = pred_scores.dtype  # torch.float32
batch_size = pred_scores.shape[0]  # 1
# First-level feature-map size times its stride gives [640, 640].
imgsz = torch.tensor(feats[0].shape[2:], device=device, dtype=dtype) * stride[0]  # image size (h,w)
# anchor_points.shape is [8400, 2] and stride_tensor.shape is [8400, 1].
anchor_points, stride_tensor = make_anchors(feats, stride, 0.5)



preds.shape [torch.Size([1, 145, 80, 80]), torch.Size([1, 145, 40, 40]), torch.Size([1, 145, 20, 20])]
pred_scores.shape torch.Size([1, 8400, 80])
pred_distri.shape torch.Size([1, 8400, 64])
pred_dist.shape torch.Size([1, 8400, 1])
---


In [32]:
# Ground-truth distances of 5 annotated objects, shaped [bs, n_objects, 1].
gt_distances = torch.tensor([[[0.8], [0.7], [0.6], [0.5], [0.4]]])

# Index of the ground-truth object assigned to each anchor, shaped [bs, n_anchors].
# Every anchor starts at object 0; the first four anchors are assigned to
# object 1 and the last four to object 2.
target_gt_idx = torch.zeros(1, 8400).long()
target_gt_idx[0, :4] = 1
target_gt_idx[0, 8396:] = 2

# Flatten the distances to [n_objects, 1], then pick one row per anchor,
# which yields a [bs, n_anchors, 1] tensor.
flat_distances = gt_distances.view(-1, gt_distances.shape[-1])
new_tensor = flat_distances[target_gt_idx]
print("new_tensor.shape", new_tensor.shape)
# Show the first five and the last four anchors.
for anchor_idx in (0, 1, 2, 3, 4, 8396, 8397, 8398, 8399):
    print(new_tensor[0][anchor_idx][0])

new_tensor.shape torch.Size([1, 8400, 1])
tensor(0.7000)
tensor(0.7000)
tensor(0.7000)
tensor(0.7000)
tensor(0.8000)
tensor(0.6000)
tensor(0.6000)
tensor(0.6000)
tensor(0.6000)
