In [19]:
# Libraries
import torch
from torch import nn, einsum
import torch.nn.functional as f
import torch.optim as optim
import cv2
import numpy as np
import matplotlib.pyplot as plt
import einops
from einops import rearrange
import tqdm
import wandb

In [20]:
# Custom Libraries 
from backbone import Backbone
from neck_bak import swin_t_neck
from neck import Neck
from head import Head
from model import Model

In [None]:
# Start a Weights & Biases run; the training loop below logs its loss via wandb.log.
# The project name can be set with the WANDB_PROJECT environment variable instead of
# editing this placeholder.
import os

wandb.init(project=os.environ.get("WANDB_PROJECT", "your_project_name"))

In [21]:
# Generating ground truth

def calculate_offset(image_size, bbox):
    """Return the centre of ``bbox`` normalised by the image size.

    Args:
        image_size: (height, width) of the image; x is divided by the width
            (index 1) and y by the height (index 0).
        bbox: box as [x, y, w, h] with (x, y) the top-left corner. This is the
            layout calculate_width_height reads (width/height at indices 2 and 3).
            The previous formula (bbox[0] + bbox[2]) / 2 assumed [x0, y0, x1, y1]
            and so placed the centre incorrectly for these boxes.

    Returns:
        (offset_x, offset_y): centre x / width and centre y / height.
    """
    x_center = bbox[0] + bbox[2] / 2
    y_center = bbox[1] + bbox[3] / 2
    offset_x = x_center / image_size[1]
    offset_y = y_center / image_size[0]
    return offset_x, offset_y

def calculate_width_height(bbox):
    """Extract the box size from ``bbox``.

    Reads width from index 2 and height from index 3 (presumably an
    [x, y, w, h] box) and returns them as a ``(width, height)`` tuple.
    """
    return tuple(bbox[idx] for idx in (2, 3))

def generate_ground_truth(image_size, bboxs, verbose=False):
    """Rasterise bounding boxes into dense ground-truth maps.

    Args:
        image_size: (height, width) of the maps. Arrays are indexed
            [row, col], i.e. [y, x].
        bboxs: iterable of boxes [x, y, w, h] with (x, y) the top-left corner.
            This is the layout calculate_width_height reads (width/height at
            indices 2 and 3). Lists, arrays or tensor rows all work.
        verbose: if True, print each rasterised box (debug aid; off by default
            so notebook/training output is not flooded).

    Returns:
        heatmap, widthmap, heightmap: float64 arrays of shape (H, W) holding 1,
            the box width and the box height at each box-centre pixel, 0 elsewhere.
        offsetmap: float64 array of shape (2, H, W); channels 0 / 1 hold the box
            centre x / y normalised by image width / height at the centre pixel.

    Boxes whose centre lies outside the image are skipped. Otherwise negative
    indices would silently wrap around and large ones would raise IndexError.
    """
    img_h, img_w = int(image_size[0]), int(image_size[1])
    heatmap = np.zeros((img_h, img_w))
    widthmap = np.zeros((img_h, img_w))
    heightmap = np.zeros((img_h, img_w))
    offsetmap = np.zeros((2, img_h, img_w))
    for bbox in bboxs:
        x, y, width, height = (float(v) for v in bbox[:4])
        # Centre of an [x, y, w, h] box (top-left corner plus half the size).
        x_center = x + width / 2
        y_center = y + height / 2
        if not (0 <= x_center < img_w and 0 <= y_center < img_h):
            continue
        col, row = int(x_center), int(y_center)
        heatmap[row, col] = 1  # Set the center of the bounding box to 1
        widthmap[row, col] = width
        heightmap[row, col] = height
        # NOTE(review): this "offset" is the normalised centre position, not a
        # sub-pixel remainder -- confirm it matches what the head is meant to predict.
        offset = (x_center / img_w, y_center / img_h)
        offsetmap[0, row, col] = offset[0]
        offsetmap[1, row, col] = offset[1]
        if verbose:
            print(f"X_center: {col}, Y_center: {row}, Width: {width}, Height: {height}, Offset: {offset}")
    return heatmap, widthmap, heightmap, offsetmap

In [27]:
def upscale_predictions(preds: list[torch.Tensor], tgt_size: tuple[int, int], intensity_thresh: float = 0.8):
    """Threshold, upsample and sparsify raw model outputs for inspection.

    Args:
        preds: (heat, w, h, o) model outputs, each a 4-D (B, C, h, w) tensor
            (bilinear interpolation below requires 4-D input).
        tgt_size: output spatial size (height, width), e.g. IMG_SIZE. A tuple or
            array of two ints.
        intensity_thresh: heatmap values <= this are zeroed before upsampling.

    Returns:
        (heat, w, h, o) as detached NumPy arrays. heat, w and h have shape
        (H, W*C*B); o has shape (B*C, H, W). w, h and o are zeroed wherever the
        upsampled heatmap is zero.

    NOTE(review): masking broadcasts the heatmap against ``o``, which only lines up
    when B == 1 and the heatmap has a single channel -- confirm for batched use.
    The outputs are detached, so use them for post-processing, not in the loss.
    """
    heat, w, h, o = preds
    size = tuple(int(s) for s in tgt_size)

    # Suppress low-confidence peaks directly on the tensor. The previous NumPy
    # round-trip failed for CUDA tensors.
    heat = heat.detach()
    heat = torch.where(heat > intensity_thresh, heat, torch.zeros_like(heat))

    def _resize(t):
        # Bilinear upsampling of a (B, C, h, w) map to `size`.
        return f.interpolate(t.detach().float(), size=size, mode='bilinear', align_corners=False)

    def _to_h_wcb(t):
        # Equivalent to einops 'b c h w -> h (w c b)'.
        return t.permute(2, 3, 1, 0).reshape(t.shape[2], -1).cpu().numpy()

    heat = _to_h_wcb(_resize(heat))
    w = _to_h_wcb(_resize(w))
    h = _to_h_wcb(_resize(h))
    o = _resize(o)
    # Equivalent to einops 'b c h w -> (b c) h w'.
    o = o.reshape(-1, o.shape[2], o.shape[3]).cpu().numpy()

    # Keep regression values only where the heatmap survived thresholding. The old
    # code used `heat == 0` as the keep-mask, which zeroed the whole heatmap and
    # kept w/h/o only at background pixels.
    keep = heat != 0
    w = np.where(keep, w, 0)
    h = np.where(keep, h, 0)
    o = np.where(keep, o, 0)

    return heat, w, h, o


In [28]:
def loss(preds, truth):
    """Detection loss: heatmap BCE plus L1 on the width, height and offset maps.

    Args:
        preds: (heat_logits, width, height, offset) predictions. The heatmap is
            treated as raw logits. Tensors keep their autograd graph; NumPy arrays
            are converted.
        truth: (heatmap, widthmap, heightmap, offsetmap) targets of the same shapes,
            e.g. the NumPy arrays returned by generate_ground_truth (heatmap 0/1).
            They are converted to the dtype/device of the matching prediction.

    Returns:
        Scalar tensor: unweighted sum of the four loss terms.

    NOTE(review): the L1 terms average over every pixel, so background zeros
    dominate. Consider masking them to object centres if training stalls.
    """
    p_heat, p_w, p_h, p_o = (torch.as_tensor(p) for p in preds)
    g_heat, g_w, g_h, g_o = (
        torch.as_tensor(g, dtype=p.dtype, device=p.device)
        for g, p in zip(truth, (p_heat, p_w, p_h, p_o))
    )
    # binary_cross_entropy_with_logits applies the sigmoid itself: pass raw logits
    # and the 0/1 target directly (sigmoid-ing either side distorts the loss).
    heatmap_loss = f.binary_cross_entropy_with_logits(p_heat, g_heat)
    width_loss = f.l1_loss(p_w, g_w)
    height_loss = f.l1_loss(p_h, g_h)
    offset_loss = f.l1_loss(p_o, g_o)
    return heatmap_loss + width_loss + height_loss + offset_loss

In [24]:
# Backbone: hidden width 96, four stages of 2 layers each, heads 3/6/12/24 per stage.
backbone = Backbone(
    hid_dim=96,
    layers=[2] * 4,
    heads=[3, 6, 12, 24],
)

In [22]:
# Neck: same hid_dim and per-stage layer counts as the backbone, with the backbone's
# head counts in reverse order. channels=768 equals 96 * 2**3 -- presumably the
# backbone's final-stage width; confirm against Neck.
neck_t = Neck(
    hid_dim=96,
    layers=[2] * 4,
    heads=[24, 12, 6, 3],
    channels=768,
)

In [23]:
# Prediction head over 96 input channels with a single object class.
head = Head(
    in_channels=96,
    num_classes=1,
)

In [26]:
# Training configuration.
num_epochs = 5
IMG_SIZE = (2160, 3840)  # (height, width) used for ground-truth maps and upsampled predictions
THRESHOLD = 0.8          # heatmap intensity threshold passed to upscale_predictions

# Assemble the detector from its three parts and optimise all parameters with plain SGD.
model = Model(backbone, neck_t, head)
optimizer = optim.SGD(model.parameters(), lr=3e-4)

In [None]:
# This is how the input is defined:
# img = cv2.cvtColor(cv2.imread("../assets/sample_run.jpeg"), cv2.COLOR_BGR2RGB)
# img = cv2.resize(img, (1600, 896))
# data = torch.tensor(img).unsqueeze(0).float()
# data = data.permute(0, 3, 1, 2)
# print(data.shape)
# plt.imshow(img)
# plt.axis("off")
# plt.show()

# # These are the bounding boxes of all the small objects.
# bbox_values = torch.tensor([[2035,1003,9,17], [795,1169,9,17], [2715,1524,9,17], [263,209,9,17], [931,844,9,17], [1621,1398,9,17]]).float()
# #Other bbox values are [795,1169,9,17], [2715,1524,9,17], [263,209,9,17], [931,844,9,17], [1621,1398,9,17]


In [None]:
# Training loop: one optimisation step per epoch on the single input `data`.
# NOTE(review): `data` (the model input tensor) and `bboxes` (boxes in the layout
# generate_ground_truth expects) are not defined by any active cell -- the input
# cell above is commented out -- so define them before running this cell.


def _upsample(t):
    """Bilinearly resize a 4-D (B, C, h, w) map to IMG_SIZE, keeping autograd intact."""
    return f.interpolate(t.float(), size=IMG_SIZE, mode='bilinear', align_corners=False)


# The targets are identical every epoch, so build them once. They are float32
# tensors because the torch loss functions cannot consume the NumPy arrays
# that generate_ground_truth returns.
targets = [torch.as_tensor(t, dtype=torch.float32) for t in generate_ground_truth(IMG_SIZE, bboxes)]

for epoch in range(num_epochs):
    model.train()

    with tqdm.tqdm(total=1, desc=f'Epoch {epoch + 1}/{num_epochs}', unit='batch') as pbar:
        p_heat, p_w, p_h, p_o = model(data)
        # Upsample with differentiable ops and flatten to the same layouts that
        # upscale_predictions produces. That helper is not used here: it thresholds
        # and detaches to NumPy, which cuts the autograd graph so backward() fails.
        preds = [
            rearrange(_upsample(p_heat), 'b c h w -> h (w c b)'),
            rearrange(_upsample(p_w), 'b c h w -> h (w c b)'),
            rearrange(_upsample(p_h), 'b c h w -> h (w c b)'),
            rearrange(_upsample(p_o), 'b c h w -> (b c) h w'),
        ]
        # A distinct name keeps the `loss` function callable in later epochs.
        batch_loss = loss(preds, targets)
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        # Update tqdm
        loss_value = batch_loss.item()
        pbar.update(1)
        pbar.set_postfix({'loss': loss_value})

        # Log loss to WandB
        wandb.log({"loss": loss_value})

print("Training finished!")

        
        