# Assignment 3 — Part 4: Video Inference + Tracking

Abhinav Kumar
11/2/2025

In [None]:
import cv2, numpy as np, torch, torch.nn as nn, torch.nn.functional as F, timm
from torchvision import transforms as T

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Side length, in pixels, of the square image the network is fed.
IMG_SIZE = 224
# Cut-off applied to the sigmoid probabilities to obtain a binary mask.
THR = 0.5

class ConvBNReLU(nn.Sequential):
    """Conv2d (no bias) -> BatchNorm2d -> in-place ReLU as a single module.

    Args:
        in_c: number of input channels of the convolution.
        out_c: number of output channels (also the BatchNorm width).
        k: convolution kernel size.
        s: convolution stride.
        p: convolution padding.
    """

    def __init__(self, in_c, out_c, k=3, s=1, p=1):
        # bias=False: the BatchNorm2d that follows has its own learnable
        # shift with PyTorch's default affine=True.
        conv = nn.Conv2d(in_c, out_c, kernel_size=k, stride=s, padding=p, bias=False)
        norm = nn.BatchNorm2d(out_c)
        act = nn.ReLU(inplace=True)
        # Positional registration keeps the sub-module names "0", "1", "2".
        super().__init__(conv, norm, act)
def make_backbone(name="vit_base_patch16_224.dino"):
    """Create a pretrained timm feature extractor, with a resnet50 fallback.

    ``name`` is tried first with ``features_only=True``. If building it (or
    reading its channel list) raises, a pretrained ``resnet50`` feature
    extractor is returned instead and a ``RuntimeWarning`` reports the reason,
    so the change of architecture is not silent.

    Args:
        name: timm model identifier to try first.

    Returns:
        Tuple ``(model, channels)`` where ``channels`` is
        ``model.feature_info.channels()`` for the model actually built.
    """
    import warnings

    try:
        m = timm.create_model(name, features_only=True, pretrained=True)
        ch = m.feature_info.channels()
    except Exception as err:
        # Best-effort fallback is kept, but surfaced: a different backbone
        # changes the number and width of the feature maps downstream.
        warnings.warn(
            f"Could not build backbone {name!r} ({type(err).__name__}: {err}); "
            "falling back to 'resnet50'.",
            RuntimeWarning,
            stacklevel=2,
        )
        m = timm.create_model("resnet50", features_only=True, pretrained=True)
        ch = m.feature_info.channels()
    return m, ch
class FPNDecoder(nn.Module):
    """Top-down, FPN-style decoder over at most the last four feature levels.

    Each used level is projected to ``out_ch`` channels by a 1x1 lateral conv.
    Starting from the last level, the running map is bilinearly resized to the
    next level's size, added to that level's lateral output and passed through
    a ConvBNReLU. All per-level outputs are then resized to the first used
    level's size and concatenated on the channel axis.

    Args:
        feat_channels: channel count of every feature map the backbone
            returns, in the order the backbone returns them.
        out_ch: channel width of each decoded level.
    """

    # forward() consumes at most this many trailing feature maps.
    _MAX_LEVELS = 4

    def __init__(self, feat_channels, out_ch=128):
        super().__init__()
        # Build the per-level convs for exactly the levels forward() uses.
        # Building them for *all* levels misaligns lat[i] with feats[i] as
        # soon as the backbone returns more than four maps (channel mismatch).
        used_channels = list(feat_channels)[-self._MAX_LEVELS:]
        self.lat = nn.ModuleList([nn.Conv2d(c, out_ch, 1) for c in used_channels])
        self.smooth = nn.ModuleList([ConvBNReLU(out_ch, out_ch) for _ in used_channels])

    def forward(self, feats):
        """Decode backbone features.

        Args:
            feats: sequence of feature maps; only the last four are used.

        Returns:
            Tensor with ``out_ch * (number of used levels)`` channels at the
            spatial size of the first used level.
        """
        feats = feats[-self._MAX_LEVELS:]
        x = None
        outs = []
        # Walk from the last level back to the first, carrying the merged map.
        for i in reversed(range(len(feats))):
            f = self.lat[i](feats[i])
            if x is not None:
                f = f + F.interpolate(x, size=f.shape[-2:], mode="bilinear", align_corners=False)
            x = self.smooth[i](f)
            outs.append(x)
        outs.reverse()  # back to first-level-first order
        size = outs[0].shape[-2:]
        up = [F.interpolate(o, size=size, mode="bilinear", align_corners=False) for o in outs]
        return torch.cat(up, dim=1)
class SegModel(nn.Module):
    """Single-class segmentation net: frozen backbone, FPN decoder, conv head.

    ``forward`` returns one-channel logits resized to the spatial size of the
    input batch.
    """

    def __init__(self):
        super().__init__()
        backbone, channels = make_backbone()
        # The backbone only extracts features; its parameters get no gradients.
        for param in backbone.parameters():
            param.requires_grad = False
        self.backbone = backbone
        self.decoder = FPNDecoder(channels)
        # The decoder concatenates one 128-channel map per level it uses,
        # and it uses at most four levels.
        n_levels = min(4, len(channels))
        self.head = nn.Sequential(
            ConvBNReLU(128 * n_levels, 256),
            ConvBNReLU(256, 128),
            nn.Conv2d(128, 1, 1),
        )

    def forward(self, x):
        logits = self.head(self.decoder(self.backbone(x)))
        # Bring the logits back to the input resolution.
        return F.interpolate(logits, size=x.shape[-2:], mode="bilinear", align_corners=False)

def iou(a, b):
    """Return the intersection-over-union of two ``(x1, y1, x2, y2)`` boxes.

    A 1e-6 term in the denominator keeps the division defined when both
    boxes have zero area.
    """
    left, top = max(a[0], b[0]), max(a[1], b[1])
    right, bottom = min(a[2], b[2]), min(a[3], b[3])
    overlap = max(0, right - left) * max(0, bottom - top)

    def _area(box):
        return max(0, box[2] - box[0]) * max(0, box[3] - box[1])

    union = _area(a) + _area(b) - overlap
    return overlap / (union + 1e-6)
class Track:
    """Mutable record for one tracked object.

    Attributes:
        id: identifier passed in at creation.
        box: box currently associated with the track (stored as given).
        miss: miss counter, 0 at creation.
    """

    def __init__(self, tid, box):
        self.id, self.box, self.miss = tid, box, 0
class SimpleDeepSort:
    """Greedy IoU-based multi-object tracker.

    Detections are associated with existing tracks purely by box overlap,
    best-overlapping pair first. Unmatched detections start new tracks and
    tracks that stay unmatched for too long are dropped.

    Args:
        iou_thresh: minimum IoU for a detection to be assigned to a track.
        max_miss: a track is removed once its miss counter exceeds this value.
    """

    def __init__(self, iou_thresh=0.3, max_miss=10):
        self.iou_thresh = iou_thresh
        self.max_miss = max_miss
        self.tracks = []
        self.next_id = 1

    def update(self, dets):
        """Advance the tracker by one frame.

        Args:
            dets: sized sequence of boxes ``(x1, y1, x2, y2)`` for this frame;
                each box must provide ``.copy()`` (e.g. a NumPy array). May be
                empty.

        Returns:
            List of ``(track_id, box_copy)`` for every track still alive after
            this update, including tracks that were not matched this frame.
        """
        matched_tracks = set()
        matched_dets = set()
        n_tracks, n_dets = len(self.tracks), len(dets)

        # Only match when both sides are non-empty: np.argmax raises
        # ValueError on an empty matrix (e.g. a frame with no detections).
        if n_tracks and n_dets:
            overlap = np.zeros((n_tracks, n_dets), dtype=np.float32)
            for i, trk in enumerate(self.tracks):
                for j, det in enumerate(dets):
                    overlap[i, j] = iou(trk.box, det)
            # At most min(n_tracks, n_dets) pairs can exist, which also bounds
            # the loop regardless of the threshold value.
            for _ in range(min(n_tracks, n_dets)):
                i, j = (int(v) for v in np.unravel_index(np.argmax(overlap), overlap.shape))
                if overlap[i, j] < self.iou_thresh:
                    break
                self.tracks[i].box = dets[j]
                self.tracks[i].miss = 0
                matched_tracks.add(i)
                matched_dets.add(j)
                # Retire this track and detection from further matching.
                overlap[i, :] = -1
                overlap[:, j] = -1

        # Age the pre-existing tracks *before* appending new ones, so a track
        # born in this frame is not counted as missed in the same frame.
        for k, trk in enumerate(self.tracks):
            if k not in matched_tracks:
                trk.miss += 1
        for j, det in enumerate(dets):
            if j not in matched_dets:
                self._start(det)

        self.tracks = [t for t in self.tracks if t.miss <= self.max_miss]
        return [(t.id, t.box.copy()) for t in self.tracks]

    def _start(self, d):
        """Create a track for detection ``d`` with the next free id."""
        self.tracks.append(Track(self.next_id, d))
        self.next_id += 1

# Build the network, move it to the selected device and put it in eval mode.
# NOTE(review): no load_state_dict()/checkpoint restore is visible in this
# file, so the decoder and head appear to keep their random initialization --
# confirm where the trained weights are meant to be loaded.
model = SegModel().to(device)
model.eval()

# Resize-to-square + ToTensor pipeline.
# NOTE(review): the frame loop further down builds its input tensor by hand
# and does not call this transform -- confirm whether it is still needed.
t_img = T.Compose(
    [
        T.Resize((IMG_SIZE, IMG_SIZE)),
        T.ToTensor(),
    ]
)

def mask_to_boxes(mask, min_area=50):
    """Return bounding boxes of the connected foreground regions of ``mask``.

    Args:
        mask: single-channel array; it is cast to uint8 and non-zero pixels
            are treated as foreground.
        min_area: a component is kept only if ``(x2 - x1) * (y2 - y1)`` of its
            box is strictly greater than this value.

    Returns:
        List of float32 arrays ``[x1, y1, x2, y2]``, one per kept component,
        where the coordinates are the minimum and maximum pixel indices of
        the component (both inclusive).
    """
    # The stats matrix already holds each component's bounding rectangle, so
    # the label image does not need to be rescanned once per component.
    num, _labels, stats, _centroids = cv2.connectedComponentsWithStats(mask.astype(np.uint8))
    boxes = []
    for k in range(1, num):  # label 0 is the background
        x1 = int(stats[k, cv2.CC_STAT_LEFT])
        y1 = int(stats[k, cv2.CC_STAT_TOP])
        # Stats report width/height; convert to inclusive max coordinates.
        x2 = x1 + int(stats[k, cv2.CC_STAT_WIDTH]) - 1
        y2 = y1 + int(stats[k, cv2.CC_STAT_HEIGHT]) - 1
        if (x2 - x1) * (y2 - y1) > min_area:
            boxes.append(np.array([x1, y1, x2, y2], dtype=np.float32))
    return boxes

# Video inference script: segment every frame of ``input_video``, track the
# resulting boxes across frames and write an annotated copy to
# ``output_video``.
input_video = "input.mp4"
output_video = "tracked_output.mp4"

cap = cv2.VideoCapture(input_video)
# Without this check an unreadable input yields a 0x0 writer, an empty loop
# and a misleading "Saved" message.
if not cap.isOpened():
    cap.release()
    raise OSError(f"Could not open input video: {input_video!r}")

fps = cap.get(cv2.CAP_PROP_FPS) or 30  # a reported FPS of 0 falls back to 30
w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
writer = cv2.VideoWriter(output_video, cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
tracker = SimpleDeepSort()

# Constant layer with channel 2 (red in OpenCV's BGR order) at full value;
# built once because it does not change between frames.
color = np.zeros((h, w, 3), dtype=np.uint8)
color[..., 2] = 255

try:
    if not writer.isOpened():
        raise OSError(f"Could not open output video for writing: {output_video!r}")

    while True:
        ok, frame = cap.read()
        if not ok:
            break

        # Network input: RGB, IMG_SIZE x IMG_SIZE, CHW float in [0, 1].
        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        inp = cv2.resize(rgb, (IMG_SIZE, IMG_SIZE))
        ten = torch.from_numpy(inp).permute(2, 0, 1).float() / 255.0
        with torch.no_grad():
            prob = torch.sigmoid(model(ten[None].to(device)))[0, 0].cpu().numpy()

        # Binarize at network resolution, then scale the mask back to the
        # frame size (nearest neighbour keeps it strictly 0/1).
        mask_small = (prob > THR).astype(np.uint8)
        mask = cv2.resize(mask_small, (w, h), interpolation=cv2.INTER_NEAREST)
        boxes = mask_to_boxes(mask)
        tracks = tracker.update(boxes)

        out = frame.copy()
        # The 0.7 factor applies to the whole frame, so unmasked regions are
        # dimmed as well; the red layer is added only where mask == 1.
        overlay = (out * 0.7 + color * 0.3 * mask[..., None]).astype(np.uint8)
        for tid, bb in tracks:
            x1, y1, x2, y2 = map(int, bb)
            cv2.rectangle(overlay, (x1, y1), (x2, y2), (0, 255, 0), 2)
            cv2.putText(overlay, f"ID {tid}", (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
        cv2.putText(overlay, f"Dents: {len(tracks)}", (12, 28), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 255), 2)
        writer.write(overlay)
finally:
    # Release both handles even if a frame fails, so the output file is closed.
    cap.release()
    writer.release()

print("Saved:", output_video)