In [None]:
%load_ext autoreload
%autoreload 2
from pathlib import Path
import sys

import torch
from hydra import compose, initialize
from yolo import Config, PostProcess, create_converter, create_model
from yolo.utils.model_utils import get_device, get_mask_preds
from yolo.tools.loss_functions import create_loss_function

from ifc_dl.conf.augmentations import get_transform_fn
from ifc_dl.data.mock_coco_dataset import MockCocoDataModule
from ifc_dl.utils import plot_instance_segmentation_data

# Make the parent of the current working directory importable.
# The membership check keeps this cell idempotent: re-running it no longer
# appends a duplicate entry to sys.path each time.
# NOTE(review): this runs *after* the `ifc_dl` imports above, so on a fresh
# kernel it cannot be what makes them resolvable — confirm `ifc_dl` is installed
# or move this before those imports.
project_root = Path().resolve().parent
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))

In [None]:
# (height, width): index 0 feeds the resize "height", index 1 the "width";
# the same tuple is later handed to the YOLO converter.
image_size = (512, 512)

# Load the datamodule 

In [None]:
# Augmentation spec: resize every dataset split to `image_size`, and randomly
# flip horizontally with probability 0.5.
# NOTE(review): interpolation=3 is presumably an OpenCV/albumentations flag — TODO confirm.
augs = dict(
    resize=dict(
        params=dict(height=image_size[0], width=image_size[1], interpolation=3),
        all_datasets=True,
    ),
    horizontal_flip=dict(params=dict(p=0.5)),
)
transforms, val_transform = get_transform_fn(augs)

# Small mock COCO datamodule, only used to pull a single batch for inspection.
datamodule = MockCocoDataModule(
    batch_size=2,
    transforms=transforms,
    val_transform=val_transform,
)
datamodule.setup()
train_dl = datamodule.train_dataloader()

# First training batch: `x` arrives as a sequence of image tensors and is
# stacked into one batch tensor; `y` holds the per-image targets.
x, y = next(iter(train_dl))
x = torch.stack(x)

In [None]:
def convert_y_for_yolo(y):
    """Pad per-image annotations into dense, batch-shaped target tensors.

    Args:
        y: Non-empty sequence with one mapping per image. Each mapping provides
            ``"labels"``, ``"boxes"`` (4 values per instance) and ``"masks"``
            (one ``H x W`` mask per instance). The mask size is taken from the
            first image and used for the whole batch.

    Returns:
        A tuple ``(y_yolo, y_masks)``:

        * ``y_yolo``: float tensor of shape ``(B, max_instances, 5)`` whose
          rows are ``[label, *box]``.
        * ``y_masks``: float tensor of shape ``(B, max_instances, H, W)``
          holding the binarised masks (``mask > 0``).

        Slots beyond an image's own instance count are filled with ``-1``.

    Raises:
        ValueError: if ``y`` is empty.
    """
    if len(y) == 0:
        raise ValueError("convert_y_for_yolo() needs at least one sample in `y`")

    # The batch size comes from the targets themselves. It used to be read from
    # the notebook-global `x`, which silently tied the result to whichever
    # image batch happened to be bound at call time.
    batch_size = len(y)
    max_annotations = max(len(y_["labels"]) for y_ in y)
    mask_hw = y[0]["masks"].shape[-2:]

    # -1 marks padding entries.
    y_yolo = torch.full((batch_size, max_annotations, 5), -1.0)
    y_masks = torch.full((batch_size, max_annotations, *mask_hw), -1.0)

    for i_y, y_ in enumerate(y):
        for i_instance, (label, box, mask) in enumerate(zip(y_["labels"], y_["boxes"], y_["masks"])):
            y_yolo[i_y, i_instance, 0] = label
            y_yolo[i_y, i_instance, 1:] = box
            # Any strictly positive mask value counts as foreground.
            y_masks[i_y, i_instance] = (mask > 0).to(y_masks.dtype)

    return y_yolo, y_masks

# Convert the targets of the batch fetched earlier into the dense
# (boxes+labels, masks) tensors that are later passed to the loss.
y_yolo, y_masks = convert_y_for_yolo(y)

# Load the YOLO model

In [None]:
%%capture
CONFIG_PATH = "../YOLO/yolo/config"
CONFIG_NAME = "config-seg"
CLASS_NUM = 91

# Weights used to initialise the model. Can also be True or False.
# If you put the path to another model's weights, it will load all the weights
# for the layers the two models have in common.
# Derived from the YOLO checkout that CONFIG_PATH already points into
# ("../YOLO"), instead of an absolute path under one user's home directory.
# NOTE(review): assumes "../YOLO" is the same checkout that previously lived at
# ".../submodules/YOLO" — adjust if your layout differs.
MODEL_WEIGHTS = str(Path(CONFIG_PATH).parents[1] / "weights" / "v9-c.pt")

# Compose the Hydra config for the segmentation setup.
with initialize(config_path=CONFIG_PATH, version_base=None, job_name="notebook_job"):
    cfg: Config = compose(config_name=CONFIG_NAME)

# Build the model and move it to the device selected by the config.
device, _ = get_device(cfg.device)
model = create_model(
    cfg.model,
    class_num=CLASS_NUM,
    weight_path=MODEL_WEIGHTS,
)
model = model.to(device)

# The converter is built for the same image size and device as the model.
converter = create_converter(
    cfg.model.name, model, cfg.model.anchor, image_size, device
)

# Post-processing is only created when the task config carries an "nms" entry.
post_proccess = None
if cfg.task.get("nms"):
    post_proccess = PostProcess(converter, cfg.task.nms)

# Keep the dataset config in sync with the class count used for the model
# before the loss function is built from `cfg`.
cfg.dataset.class_num = CLASS_NUM
loss = create_loss_function(cfg, converter)

In [None]:
# load the weight from the training
# w = torch.load("/Users/simone.bonato/Desktop/ecolution/ecolution-floorplan-seg/submodules/YOLO/model_seg_w.pth", map_location=device)
# model.load_state_dict(w)

# Forward pass and segmentation masks predictions!

In [None]:
model.train()
# model.eval()

# The model was moved to `device` when it was created, so the batch has to be
# on that device too; otherwise the forward pass fails whenever `device` is
# not where the dataloader put `x`.
out = model(x.to(device))

# The output dict carries a (detection, segmentation) pair for both the main
# and the auxiliary branch.
det_logits, seg_logits = out["Main"]
det_logits_aux, seg_logits_aux = out["AUX"]

seg_preds = get_mask_preds(seg_logits, sigmoid=True)
seg_preds_aux = get_mask_preds(seg_logits_aux, sigmoid=True)

In [None]:
# NOTE: the segmentation outputs are the mask coefficients; the last entry is
# the mask prototype.
print("--- segmentation head ---")
for seg_out in seg_logits:
    print(seg_out.shape)

print("\n--- detection head ---")
# One group per resolution, each holding: class, object, bbox, mask coefficients.
for resolution_outputs in det_logits:
    for head_out in resolution_outputs:
        print(head_out.shape)
    print()

In [None]:
# Run the converter on the main and auxiliary detection logits (main first).
det_preds, det_preds_aux = converter(det_logits), converter(det_logits_aux)

# Show the shape of every tensor the converter produced for the main branch.
for converted in det_preds:
    print(converted.shape)

In [None]:
# fake_seg_logits = []

# b = 2
# c = 32
# hw = [64, 32, 16, 128]

# for i in range(len(hw)):
#     fake_seg_logits.append(torch.randn(b, c, hw[i], hw[i]))

# out = model(x)
# out["Main"] = (out["Main"], fake_seg_logits)

In [None]:
from omegaconf import OmegaConf
# Hand-written NMS settings, wrapped in an OmegaConf config object before being
# passed to PostProcess.
nms_config = OmegaConf.create(
    dict(
        min_confidence=0.1,
        min_iou=0.0,
        max_bbox=300,
    )
)
post_proccess = PostProcess(converter, nms_config)

# TODO: verify how/when the masks are upscaled by the converter.
pred = post_proccess(out)

In [None]:
# Imported once up front rather than on every iteration of the inner loop.
import matplotlib.pyplot as plt

# The second element of the post-processed output holds the predicted masks,
# grouped per image.
_, pred_seg = pred

for image_masks in pred_seg:
    for instance_mask in image_masks:
        # detach() so masks that are still attached to the autograd graph can be
        # converted to NumPy.
        plt.imshow(instance_mask.detach().cpu().numpy())
        plt.show()

# Loss computation

Make sure that the normal loss can still be computed normally, then add the masks to it.

## TODO
- [ ] Add the masks to the (TODO: this item was left unfinished)
- [ ] Allow the loss computation to optionally take the masks
- [ ] Compute the BCE with the target masks

In [None]:
# Inspect the shape of the first tensor the converter produced from the main
# detection logits.
det_preds[0].shape

What do I need to pass as input to the loss function for the segmentation part?
- coefficients for the prototypes 
- target masks [B, max_instances_in_gt, H, W]
- pred masks (both aux and main) [B, all_anchor_preds, H', W']

Then I will need to use: 
- the GT boxes (already in the main loss)

To add as well, optionally:
- the coeff diversity loss 
- the foreground and background weights

In [None]:
from copy import deepcopy

# The targets are moved to `device` so they live on the same device as the
# model outputs, mirroring what the training loop does before calling the loss.
# The box/label targets are passed as a deep copy so the notebook-level
# `y_yolo` stays untouched — presumably the loss modifies them in place (TODO confirm).
loss_value, loss_dict = loss(
    det_preds_aux,
    det_preds,
    deepcopy(y_yolo).to(device),
    y_masks.to(device),
    seg_logits_aux,
    seg_logits,
)

# Display the individual loss terms.
loss_dict

## Training loop!

In [None]:
from tqdm import tqdm
from copy import deepcopy

EPOCHS = 10
BATCH_SIZE = 3

# Fresh datamodule/dataloader with the training batch size.
datamodule = MockCocoDataModule(
    batch_size=BATCH_SIZE,
    transforms=transforms,
    val_transform=val_transform,
)
datamodule.setup()
train_dl = datamodule.train_dataloader()

optim = torch.optim.Adam(model.parameters())
# Training is forced onto the CPU here.
# NOTE(review): `converter` was built with the earlier `device`; confirm the two
# agree when that device was not the CPU.
device = "cpu"
model.to(device)

for epoch in range(EPOCHS):
    tqdm_loop = tqdm(enumerate(train_dl), total=len(train_dl), desc="Training")
    for batch_idx, (x, y) in tqdm_loop:

        # ugly fix! single-channel images are tiled to 3 channels so that the
        # whole batch can be stacked into one tensor.
        x = torch.stack(
            [img.repeat(3, 1, 1) if img.shape[0] == 1 else img for img in x]
        ).to(device)

        # Dense targets, moved to the training device.
        y_yolo, y_masks = (target.to(device) for target in convert_y_for_yolo(y))

        out = model(x)
        (det_logits, seg_logits), (det_logits_aux, seg_logits_aux) = out["Main"], out["AUX"]

        det_preds = converter(det_logits)
        det_preds_aux = converter(det_logits_aux)

        loss_value, loss_dict = loss(
            det_preds_aux,
            det_preds,
            deepcopy(y_yolo),
            y_masks,
            seg_logits_aux,
            seg_logits,
        )
        tqdm_loop.set_description(f"Epoch: {epoch+1} | Batch {batch_idx+1} / {len(train_dl)} | losses {loss_dict=}")

        # Standard optimisation step.
        optim.zero_grad()
        loss_value.backward()
        optim.step()

In [None]:
%load_ext autoreload
%autoreload 2

import pytest

# Path to the loss-function tests, relative to the notebook's working directory
# (the same "../YOLO" checkout that CONFIG_PATH points into) instead of an
# absolute path under one user's home directory.
# NOTE(review): assumes "../YOLO" is the checkout that previously lived at
# ".../submodules/YOLO" — adjust if your layout differs.
pytest_path = "../YOLO/tests/test_tools/test_loss_functions.py"

# Run only the tests matching "test_loss_function", verbosely, with short tracebacks.
pytest_args = [
    pytest_path,
    "-k",
    "test_loss_function",
    "-v",
    "--tb=short",
]
pytest.main(pytest_args)