In [None]:
%load_ext autoreload
%autoreload 2
from pathlib import Path
import sys

import torch
from hydra import compose, initialize
from yolo import Config, PostProcess, create_converter, create_model, draw_bboxes
from yolo.utils.model_utils import get_device
from yolo.tools.loss_functions import create_loss_function

from ifc_dl.conf.augmentations import get_transform_fn
from ifc_dl.data.mock_coco_dataset import MockCocoDataModule
from ifc_dl.utils import plot_instance_segmentation_data

# Put the parent of the current working directory on the import path.
# NOTE(review): this runs after the project imports above, so it can only
# influence imports executed later in the session.
_cwd = Path(".").resolve()
project_root = _cwd.parent
sys.path.append(f"{project_root}")

In [None]:
# (height, width): index 0 feeds the resize "height", index 1 the resize "width";
# the same tuple is later handed to the YOLO converter.
image_size = (512, 512)

# Load the datamodule 

In [None]:
# Augmentation spec consumed by `get_transform_fn`: a resize to `image_size`
# (flagged "all_datasets") plus a horizontal flip with p=0.5.
# NOTE(review): interpolation=3 is an opaque enum value here — confirm which mode it selects.
resize_params = {
    "height": image_size[0],
    "width": image_size[1],
    "interpolation": 3,
}
augs = {
    "resize": {"params": resize_params, "all_datasets": True},
    "horizontal_flip": {"params": {"p": 0.5}},
}

transforms, val_transform = get_transform_fn(augs)

# Small mock dataset, two samples per batch.
datamodule = MockCocoDataModule(
    batch_size=2,
    transforms=transforms,
    val_transform=val_transform,
)
datamodule.setup()
train_dl = datamodule.train_dataloader()

# Grab a single batch: the images come as a sequence and are stacked into one tensor,
# `y` holds the matching targets.
first_batch = next(iter(train_dl))
x, y = first_batch
x = torch.stack(x)

In [None]:

def convert_y_for_yolo(y):
    """Pack per-image target dicts into dense, ``-1``-padded tensors.

    The batch size is taken from ``len(y)`` so the function does not depend on
    any notebook-level variable.

    Args:
        y: Non-empty sequence with one dict per image. Each dict provides
            ``"labels"``, ``"boxes"`` and ``"masks"``, which are iterated in
            parallel (one entry per instance). Every box must hold 4 values and
            every mask must share the trailing ``(H, W)`` of ``y[0]["masks"]``.

    Returns:
        Tuple ``(y_yolo, y_masks)`` of float tensors:
            * ``y_yolo`` with shape ``(len(y), max_instances, 5)``; column 0 is
              the label, columns 1-4 the box, unused rows are ``-1``.
            * ``y_masks`` with shape ``(len(y), max_instances, H, W)``; each mask
              is binarised with ``mask > 0``, unused slots are ``-1``.

    Raises:
        ValueError: If ``y`` is empty.
    """
    batch_size = len(y)
    if batch_size == 0:
        raise ValueError("convert_y_for_yolo expects at least one target dict")

    # Pad every image up to the largest number of instances in the batch.
    max_annotations = max(len(y_["labels"]) for y_ in y)
    mask_hw = tuple(y[0]["masks"].shape[-2:])

    y_yolo = torch.full((batch_size, max_annotations, 5), -1.0)
    y_masks = torch.full((batch_size, max_annotations, *mask_hw), -1.0)

    for i_y, y_ in enumerate(y):
        for i_instance, (label, box, mask) in enumerate(zip(y_["labels"], y_["boxes"], y_["masks"])):
            y_yolo[i_y, i_instance, 0] = label
            y_yolo[i_y, i_instance, 1:] = box

            # Any positive value counts as foreground.
            y_masks[i_y, i_instance] = (mask > 0).to(y_masks.dtype)

    return y_yolo, y_masks

# Convert the targets `y` of the batch loaded earlier into the padded YOLO tensors.
y_yolo, y_masks = convert_y_for_yolo(y)

# Load the YOLO model

In [None]:
%%capture
# Hydra config location/name for the segmentation variant and the number of classes.
CONFIG_PATH = "../YOLO/yolo/config"
CONFIG_NAME = "config-seg"
CLASS_NUM = 91

# Compose the config inside a temporary Hydra context.
with initialize(config_path=CONFIG_PATH, version_base=None, job_name="notebook_job"):
    cfg: Config = compose(config_name=CONFIG_NAME)

# Build the model and place it on the configured device.
device, _ = get_device(cfg.device)
model = create_model(cfg.model, class_num=CLASS_NUM)
model = model.to(device)

# The converter turns raw head outputs into predictions for `image_size`.
converter = create_converter(cfg.model.name, model, cfg.model.anchor, image_size, device)

# Post-processing is only built when the task config carries an "nms" entry.
post_proccess = PostProcess(converter, cfg.task.nms) if cfg.task.get("nms") else None

# The loss reads the class count from the dataset config, so set it before building the loss.
cfg.dataset.class_num = CLASS_NUM
loss = create_loss_function(cfg, converter)

# Forward pass and segmentation masks predictions!

In [None]:
# Switch to training mode before the forward pass
# (presumably needed so the output also carries the "AUX" branch — confirm).
model.train()

# The model was moved to `device` when it was created, so the batch has to live
# there too; `.to(device)` is a no-op when both are already on the same device.
out = model(x.to(device))

# Each branch yields a (detection logits, segmentation logits) pair.
det_logits, seg_logits = out["Main"]
det_logits_aux, seg_logits_aux = out["AUX"]

In [None]:
# NOTE: every entry but the last holds mask coefficients; the last one is the mask prototype
print("--- segmentation head ---")
for seg_tensor in seg_logits:
    print(seg_tensor.shape)

print("\n--- detection head ---")
# one group per resolution: class, object, bbox, mask coefficients
for level_outputs in det_logits:
    for det_tensor in level_outputs:
        print(det_tensor.shape)
    print()

In [None]:
from einops import rearrange

def get_mask_preds(seg_logits, sigmoid=False):
    """Turn segmentation-head outputs into one mask per prediction.

    All entries of ``seg_logits`` except the last are treated as 4-D
    coefficient maps ``(B, M, w, h)``; the last entry is the 4-D prototype
    tensor ``(B, M, H, W)``. Each coefficient map is flattened over its two
    trailing dims, the results are concatenated along the prediction axis and
    then contracted with the prototypes over the shared ``M`` axis.

    Args:
        seg_logits: Sequence of coefficient maps followed by the prototype.
        sigmoid: When true, return ``sigmoid`` of the mask logits.

    Returns:
        Tensor of shape ``(B, N, H, W)`` where ``N`` is the total number of
        flattened coefficient positions.
    """
    coeff_maps = seg_logits[:-1]
    proto = seg_logits[-1]

    # (B, M, w, h) -> (B, w*h, M) for every resolution, then stack along the prediction axis.
    flat_coeffs = [rearrange(coeff_map, "B M w h -> B (w h) M") for coeff_map in coeff_maps]
    pred_coeffs = torch.cat(flat_coeffs, dim=1)

    # Linear combination of the prototypes, weighted by each prediction's coefficients.
    mask_logits = torch.einsum("bnm, bmhw -> bnhw", pred_coeffs, proto)
    return torch.sigmoid(mask_logits) if sigmoid else mask_logits


# Mask logits for the main branch first, then the auxiliary branch.
seg_preds, seg_preds_aux = (
    get_mask_preds(branch_logits) for branch_logits in (seg_logits, seg_logits_aux)
)

In [None]:
# Run the converter on the main and auxiliary detection logits.
det_preds, det_preds_aux = converter(det_logits), converter(det_logits_aux)

# Show the shape of every tensor in the main-branch output.
for pred in det_preds:
    print(pred.shape)

# Loss computation

Make sure that the normal loss can still be computed normally, then add the masks to it.

## TODO
- [ ] Add the masks to the loss targets
- [ ] Add to the loss computation the possibility to have the masks
- [ ] Compute the BCE with the target masks

In [None]:
# Display the shape of the first entry of the converted main-branch predictions.
det_preds[0].shape

What do I need to pass as input to the loss function for the segmentation part?
- coefficients for the prototypes 
- target masks [B, max_instances_in_gt, H, W]
- pred masks (both aux and main) [B, all_anchor_preds, H', W']

Then I will need to use: 
- the GT boxes (already in the main loss)

To add as well, optionally:
- the coeff diversity loss 
- the foreground and background weights

In [None]:
from copy import deepcopy

# The targets were built on the CPU while the predictions live on `device`;
# move them over so the loss works on any device (no-op when already there).
# `deepcopy` hands the loss its own copy of the box targets
# (presumably because the loss modifies them in place — confirm).
loss_value, loss_dict = loss(
    det_preds_aux,
    det_preds,
    deepcopy(y_yolo).to(device),
    y_masks.to(device),
    seg_logits_aux,
    seg_logits,
)

# Show the individual loss terms.
loss_dict

## Training loop!

In [None]:
from tqdm import tqdm

EPOCHS = 10
BATCH_SIZE = 3

# Rebuild the dataloader with the training batch size.
datamodule = MockCocoDataModule(batch_size=BATCH_SIZE, transforms=transforms, val_transform=val_transform)
datamodule.setup()
train_dl = datamodule.train_dataloader()

optim = torch.optim.Adam(model.parameters())

# Make sure the model is in training mode even if this cell is run on its own.
model.train()

for epoch in range(EPOCHS):
    # Build the progress bar (and the underlying iterator) once per epoch: a single
    # `tqdm(enumerate(...))` created outside this loop is exhausted after the first
    # epoch, so every later epoch would silently process zero batches.
    tqdm_loop = tqdm(enumerate(train_dl), total=len(train_dl))
    for batch_idx, (x, y) in tqdm_loop:
        # Inputs and targets must be on the same device as the model.
        x = torch.stack(x).to(device)
        y_yolo, y_masks = (target.to(device) for target in convert_y_for_yolo(y))

        out = model(x)

        det_logits, seg_logits = out["Main"]
        det_logits_aux, seg_logits_aux = out["AUX"]

        det_preds = converter(det_logits)
        det_preds_aux = converter(det_logits_aux)

        # `deepcopy` keeps `y_yolo` untouched for later cells
        # (presumably the loss modifies its targets in place — confirm).
        loss_value, loss_dict = loss(
            det_preds_aux, det_preds, deepcopy(y_yolo), y_masks, seg_logits_aux, seg_logits
        )
        tqdm_loop.set_description(f"Epoch: {epoch+1} | Batch {batch_idx+1} / {len(train_dl)} | losses {loss_dict=}")

        # Standard optimisation step: clear old gradients, backpropagate, update.
        optim.zero_grad()
        loss_value.backward()
        optim.step()
    

# TODO: add the optim and the backward pass

In [None]:
# Recompute predictions and loss from the last training batch's model output.
main_out, aux_out = out["Main"], out["AUX"]
det_logits, seg_logits = main_out
det_logits_aux, seg_logits_aux = aux_out

det_preds, det_preds_aux = converter(det_logits), converter(det_logits_aux)

# Argument order: aux detections, main detections, box targets (copied), mask targets,
# aux segmentation logits, main segmentation logits.
loss_inputs = (
    det_preds_aux,
    det_preds,
    deepcopy(y_yolo),
    y_masks,
    seg_logits_aux,
    seg_logits,
)
loss_value, loss_dict = loss(*loss_inputs)