In [1]:
# Enable IPython's autoreload extension in mode 2, so that edits to the imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from config import *
from data import VOC2012SegDataset, crop_augment_preprocess_batch
from models.seg_models import evaluate, set_trainable_params
from models.vl_models import GenParams, OllamaMLLM
from models.vl_encoders import VLE_REGISTRY, VLEncoder
from prompter import FastPromptBuilder
from logger import LogManager
from path import get_mask_prs_path
from viz import get_layer_numel_str
from utils import clear_memory, get_activation

from functools import partial
from collections import OrderedDict
from torch import nn
from torch.utils.data import DataLoader
from torchvision.models import segmentation as segmodels
import torchvision.transforms.v2 as T
import torchvision.transforms.functional as TF
from torchvision.transforms._presets import SemanticSegmentation
from torchmetrics.classification import MulticlassAccuracy, MulticlassJaccardIndex
import torchmetrics as tm
from open_clip_train.scheduler import cosine_lr, const_lr, const_lr_cooldown
from open_clip.loss import SigLipLoss, ClipLoss
import asyncio
import math

from typing import Optional
from torch.nn.modules.loss import _Loss

In [3]:
# Shortcuts into the nested experiment configuration.
SEG_CONFIG = CONFIG['seg']
SEG_TRAIN_CONFIG = SEG_CONFIG['train']

VLE_CONFIG = CONFIG['vle']
VLE_TRAIN_CONFIG = VLE_CONFIG['train']

# The file / TensorBoard log directories are handed to the LogManager only when
# logging is not restricted to stdout.
log_manager = LogManager(
    exp_name=SEG_TRAIN_CONFIG['exp_name'],
    exp_desc=SEG_TRAIN_CONFIG['exp_desc'],
    **(
        {}
        if SEG_TRAIN_CONFIG['log_only_to_stdout']
        else {
            'file_logs_dir_path': SEG_TRAIN_CONFIG['file_logs_dir_path'],
            'tb_logs_dir_path': SEG_TRAIN_CONFIG['tb_logs_dir_path'],
        }
    ),
)

In [4]:
async def train_loop(
        segnet: nn.Module,
        vlm: OllamaMLLM,
        vle: VLEncoder,
        train_dl: DataLoader,
        val_dl: DataLoader,
        fast_prompt_builder: FastPromptBuilder,
        seg_preprocess_fn: nn.Module,
        gen_params: GenParams,
        criterion: _Loss,
        contr_criterion: nn.Module,
        metrics_dict: dict[str, tm.Metric],
        checkpoint_dict: Optional[dict] = None
) -> None:
    """
    Train `segnet` with a multi-task loss: segmentation loss + weighted contrastive loss.

    For every training batch the segmentation logits are computed, the VLM is asked to
    produce class-splitted texts from ground truth / prediction / scene prompts, the
    texts are encoded with the VLE, and a contrastive loss is computed between the
    pooled bottleneck activation of `segnet` and the aggregated text vectors.

    Reads the module-level `SEG_TRAIN_CONFIG`, `CONFIG` and `log_manager`.

    Args:
        segnet: Segmentation network. Its output is either a tensor or an `OrderedDict`
            holding the logits under "out". It must expose
            `backbone['16'].activations['bottleneck']` and `backbone.gap`.
        vlm: Model queried through `predict_many_class_splitted`.
        vle: Text encoder used through `preprocess_texts` and `encode_and_project`;
            `vle.model.logit_scale` and `vle.model.logit_bias` are also read.
        train_dl: Yields `(scs_img, gts)` batches; its `batch_size` is used to derive
            per-image indices.
        val_dl: Validation loader forwarded to `evaluate`.
        fast_prompt_builder: Provides `image_size` and `build_cs_inference_prompts`.
        seg_preprocess_fn: Applied to the raw image batch before the forward pass.
        gen_params: Generation parameters forwarded to the VLM.
        criterion: Segmentation loss applied to `(logits, gts)`.
        contr_criterion: Contrastive loss called with `image_features`, `text_features`,
            `logit_scale`, `logit_bias` and `output_dict`.
        metrics_dict: Name -> metric mapping; the validation scores returned by
            `evaluate` are expected to contain the key 'mIoU'.
        checkpoint_dict: Optional state to resume from, with keys 'epoch',
            'global_step' and 'optimizer_state_dict'.
    """

    # --- 1. Initialization and State Restoration ---
    start_epoch = 0
    global_step = 0
    if checkpoint_dict:
        start_epoch = checkpoint_dict['epoch'] + 1
        global_step = checkpoint_dict['global_step']
        log_manager.log_line(f"Resuming training from epoch {start_epoch}, global step {global_step}.")

    grad_accum_steps = SEG_TRAIN_CONFIG['grad_accum_steps']
    num_batches_per_epoch = len(train_dl)
    # The number of optimizer steps per epoch
    num_steps_per_epoch = math.ceil(num_batches_per_epoch / grad_accum_steps)

    # --- 2. Optimizer Setup ---
    lr = SEG_TRAIN_CONFIG['lr_schedule']['base_lr']
    optimizer = torch.optim.AdamW(segnet.parameters(), lr=lr)
    if checkpoint_dict:
        optimizer.load_state_dict(checkpoint_dict['optimizer_state_dict'])

    # --- 3. Scheduler Setup ---
    total_steps = num_steps_per_epoch * SEG_TRAIN_CONFIG['num_epochs']
    sched_config = SEG_TRAIN_CONFIG['lr_schedule']

    # The open_clip schedulers are callables invoked with the current optimizer step.
    scheduler = None
    if sched_config['policy'] == 'const':
        scheduler = const_lr(optimizer, sched_config['base_lr'], sched_config['warmup_length'], total_steps)
    elif sched_config['policy'] == 'const-cooldown':
        cooldown_steps = num_steps_per_epoch * sched_config['epochs_cooldown']
        scheduler = const_lr_cooldown(optimizer, sched_config['base_lr'], sched_config['warmup_length'], total_steps, cooldown_steps, sched_config['lr_cooldown_power'], sched_config['lr_cooldown_end'])
    elif sched_config['policy'] == 'cosine':
        scheduler = cosine_lr(optimizer, sched_config['base_lr'], sched_config['warmup_length'], total_steps)

    # --- 4. AMP and Model Compilation Setup ---
    ...

    # --- 5. Initial Validation ---
    log_manager.log_title("Initial Validation")
    val_loss, val_metrics_score = evaluate(segnet, val_dl, criterion, metrics_dict)
    log_manager.log_scores(f"Before any weight update, VALIDATION", val_loss, val_metrics_score, start_epoch, "val", None, "val_")
    best_val_mIoU = val_metrics_score['mIoU']

    log_manager.log_title("Training Start")

    # --- 6. Main Training Loop ---
    train_metrics = tm.MetricCollection(metrics_dict)
    for epoch in range(start_epoch, SEG_TRAIN_CONFIG["num_epochs"]):

        train_metrics.reset() # in theory, this can be removed

        for step, (scs_img, gts) in enumerate(train_dl):

            # --- Seg --- #

            # Re-enabled at every step: presumably `evaluate` switches the model to eval mode.
            segnet.train()

            scs = seg_preprocess_fn(scs_img)

            scs = scs.to(CONFIG["device"])
            gts = gts.to(CONFIG["device"]) # shape [B, H, W]

            logits = segnet(scs)
            logits: torch.Tensor = logits["out"] if isinstance(logits, OrderedDict) else logits # shape [N, C, H, W]

            train_metrics.update(logits.detach().argmax(dim=1), gts)

            # Scaled so that the gradients accumulated over `grad_accum_steps` batches average out.
            seg_batch_loss: torch.Tensor = criterion(logits, gts) / grad_accum_steps

            # The text branch is currently always active; the `else` is kept for seg-only training.
            if True:

                # --- VLM --- #

                # assumes `scs_img` holds values in [0, 1] — TODO confirm against the collate function
                scs_img = (scs_img*255).to(torch.uint8)
                gts = gts.unsqueeze(1)
                prs = logits.argmax(dim=1, keepdim=True)
                # Both VLM and VLE receive the images in the same downsampled size.
                gts_down = TF.resize(gts, fast_prompt_builder.image_size, TF.InterpolationMode.NEAREST)
                prs_down = TF.resize(prs, fast_prompt_builder.image_size, TF.InterpolationMode.NEAREST)
                scs_down = TF.resize(scs_img, fast_prompt_builder.image_size, TF.InterpolationMode.BILINEAR)
                cs_prompts = fast_prompt_builder.build_cs_inference_prompts(gts_down, prs_down, scs_down)

                # Per-image indices derived from the batch position in the (unshuffled) loader.
                batch_idxs = [train_dl.batch_size*step + i for i in range(len(scs_down))]

                cs_answer_list = await vlm.predict_many_class_splitted(
                    cs_prompts,
                    batch_idxs,
                    gen_params=gen_params,
                    jsonl_save_path=None,
                    only_text=True,
                    splits_in_parallel=False,
                    batch_size=None,
                    use_tqdm=False
                )

                # --- VLE --- #

                global_text_tokens = list()
                for i, img_idx in enumerate(batch_idxs):
                    # cs_texts = [text for pos_, text in cs_answer_list[i]['content'].items()]
                    cs_texts = list(cs_answer_list[i]['content'].values()) # gather the text for each pos. class of this image
                    cs_texts = vle.preprocess_texts(cs_texts)
                    cs_vle_output = vle.encode_and_project(images=None, texts=cs_texts, broadcast=False)

                    global_text_token = cs_vle_output.global_text_token
                    aggr_global_text_token = torch.max(global_text_token, dim=0).values # aggregating the class-splitted text vectors with a MaxPool.
                    global_text_tokens.append(aggr_global_text_token)

                global_text_tokens = torch.stack(global_text_tokens)

                # global_text_tokens = global_text_tokens @ torch.rand(size=(512, 960), device=global_text_tokens.device) # TODO here the 'global_text_tokens' have to be processed again before the additional loss
                global_text_tokens = global_text_tokens @ torch.zeros(size=(512, 960), device=global_text_tokens.device) # TODO here the 'global_text_tokens' have to be processed again before the additional loss

                bottleneck_out: torch.Tensor = segnet.backbone['16'].activations['bottleneck']
                # `flatten(1)` instead of `squeeze()`: the batch dimension must survive even
                # when the batch contains a single sample.
                bottleneck_vec = segnet.backbone.gap(bottleneck_out).flatten(1)

                # Scaled like the segmentation loss so that both terms keep their relative
                # weight under gradient accumulation.
                contr_batch_loss = contr_criterion(
                    image_features=bottleneck_vec,
                    text_features=global_text_tokens,
                    logit_scale=vle.model.logit_scale/vle.model.logit_scale, # evaluates to 1 (placeholder)
                    logit_bias=vle.model.logit_bias*0, # evaluates to 0 (placeholder)
                    output_dict=False
                ) / grad_accum_steps

                batch_loss: torch.Tensor = seg_batch_loss + SEG_TRAIN_CONFIG['with_text']['loss_lam']*contr_batch_loss # multi-task loss
            else:
                batch_loss = seg_batch_loss

            batch_loss.backward()

            # Debug trace of the per-batch loss.
            print(f"{step=}, {batch_loss=}")

            is_last_batch = (step + 1) == num_batches_per_epoch
            is_accum_step = (step + 1) % grad_accum_steps == 0

            # --- Optimizer Step and Scheduler Update ---
            if is_accum_step or is_last_batch:

                if scheduler:
                    scheduler(global_step)

                max_grad_norm = SEG_TRAIN_CONFIG['grad_clip_norm']
                # With an infinite threshold the call only measures the gradient norm.
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    segnet.parameters(),
                    max_grad_norm if max_grad_norm else float('inf'),
                    norm_type=2.0
                )

                optimizer.step()
                optimizer.zero_grad()

                global_step += 1 # Increment global step *only* after an optimizer step

                # --- Logging ---
                if global_step % SEG_TRAIN_CONFIG['log_every'] == 0:
                    train_metrics_score = train_metrics.compute()
                    current_lr = optimizer.param_groups[0]['lr']
                    step_in_epoch = (step // grad_accum_steps) + 1
                    log_manager.log_scores(
                        f"epoch: {epoch+1}/{SEG_TRAIN_CONFIG['num_epochs']}, step: {step_in_epoch}/{num_steps_per_epoch} (global_step: {global_step})",
                        batch_loss * grad_accum_steps, train_metrics_score, global_step, "train",
                        f", lr: {current_lr:.2e}, grad_norm: {grad_norm:.2f}", "batch_"
                    )

                train_metrics.reset() # only the batch metrics are logged

            # torch.cuda.synchronize() if CONFIG['device'] == 'cuda' else None

        # --- End of Epoch Validation and Checkpointing ---
        val_loss, val_metrics_score = evaluate(segnet, val_dl, criterion, metrics_dict)
        log_manager.log_scores(f"epoch: {epoch+1}/{SEG_TRAIN_CONFIG['num_epochs']}, VALIDATION", val_loss, val_metrics_score, epoch+1, "val", None, "val_")

        # Only checkpoints that improve the validation mIoU are written.
        if val_metrics_score['mIoU'] > best_val_mIoU:
            best_val_mIoU = val_metrics_score['mIoU']

            if SEG_TRAIN_CONFIG['save_weights_root_path']:
                # Note: 'epoch' is saved, so on resume we start from 'epoch + 1'
                new_checkpoint_dict = {
                    'epoch': epoch,
                    'global_step': global_step,
                    'model_state_dict': segnet.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }

                save_dir = Path(SEG_TRAIN_CONFIG['save_weights_root_path'])
                save_dir.mkdir(parents=True, exist_ok=True)
                ckp_filename = f"lraspp_mobilenet_v3_large_{SEG_TRAIN_CONFIG['exp_name']}.pth"
                full_ckp_path = save_dir / ckp_filename
                torch.save(new_checkpoint_dict, full_ckp_path)
                log_manager.log_line(f"New best model saved to {full_ckp_path} with validation mIoU: {best_val_mIoU:.4f}")

    log_manager.log_title("Training Finished")

In [None]:
# Training / validation datasets. The dataset root is read from the shared
# configuration (as done for the prompt-builder datasets) instead of being hard-coded.
train_ds = VOC2012SegDataset(
    root_path=Path(CONFIG['datasets']['VOC2012_root_path']),
    split='train',
    resize_size=SEG_CONFIG['image_size'],
    center_crop=True,
    with_unlabelled=True,
)

val_ds = VOC2012SegDataset(
    root_path=Path(CONFIG['datasets']['VOC2012_root_path']),
    split='val',
    resize_size=SEG_CONFIG['image_size'],
    center_crop=True,
    with_unlabelled=True,
    img_idxs=slice(None, 20, None) # TODO remove to have full val. dataset
)

In [6]:
# Segmentation Model
segnet = segmodels.lraspp_mobilenet_v3_large(weights=None, weights_backbone=None).to(CONFIG["device"])
# `map_location` makes the checkpoint loadable regardless of the device it was saved from.
segnet.load_state_dict(
    torch.load(
        Path(SEG_CONFIG['pretrained_weights_root_path']) / ("lraspp_mobilenet_v3_large-full-pt" + ".pth"),
        map_location=CONFIG['device'],
    )
)
# TODO to modify with the actual segnet intermediate checkpoint
segnet.eval()

# Optionally resume from a training checkpoint (overrides the pretrained weights above).
checkpoint_dict = None
if SEG_TRAIN_CONFIG['resume_path']:
    resume_path = Path(SEG_TRAIN_CONFIG['resume_path'])
    if not resume_path.exists():
        raise FileNotFoundError(f"ERROR: Resume path '{resume_path}' not found. ")
    checkpoint_dict = torch.load(resume_path, map_location=CONFIG['device'])
    segnet.load_state_dict(checkpoint_dict['model_state_dict'])

# Parameter-free global average pooling, attached to the backbone so the training
# loop can reach it as `segnet.backbone.gap`.
bottleneck_gap = nn.AdaptiveAvgPool2d(output_size=(1, 1))
segnet.backbone.add_module('gap', bottleneck_gap)

set_trainable_params(segnet, train_decoder_only=SEG_TRAIN_CONFIG['train_decoder_only'])


# NOTE should I clone the fw hook output?
# register the forward hook to store the bottleneck output.
target_layer = segnet.backbone['16'] # [960, 32, 32] bottleneck output
target_layer.activations = dict() # to store the fw hooks output
handle = target_layer.register_forward_hook(get_activation('bottleneck'))

In [7]:
# Vision-Language Model
by_model = "LRASPP_MobileNet_V3"

model_name = "gemma3:12b-it-qat"
vlm = OllamaMLLM(model_name)

gen_params = GenParams(seed=CONFIG["seed"], temperature=SEG_TRAIN_CONFIG['with_text']['vlm_temperature'])

# Every prompt section uses its "default" variant, except for the input format.
prompt_blueprint = dict(
    context="default",
    color_map="default",
    input_format="sep_ovr_original",
    task="default",
    output_format="default",
    support_set_intro="default",
    support_set_item="default",
    query="default",
)

# NOTE in this pipeline the dataset only serves as a source of class maps and color maps;
# the actual samples are not read from it.
seg_dataset = VOC2012SegDataset(
    root_path=Path(CONFIG['datasets']['VOC2012_root_path']),
    split='train',
    resize_size=CONFIG['seg']['image_size'],
    center_crop=True,
    with_unlabelled=False,
)

# Dataset the support-set examples of the prompt are taken from.
sup_set_seg_dataset = VOC2012SegDataset(
    root_path=Path(CONFIG['datasets']['VOC2012_root_path']),
    split='prompts_split',
    resize_size=CONFIG['seg']['image_size'],
    center_crop=True,
    with_unlabelled=False,
    mask_prs_path=get_mask_prs_path(by_model=by_model),
)

fast_prompt_builder = FastPromptBuilder(
    seg_dataset=seg_dataset,
    seed=CONFIG["seed"],
    prompt_blueprint=prompt_blueprint,
    by_model=by_model,
    alpha=0.6,
    class_map=seg_dataset.get_class_map(with_unlabelled=False),
    color_map=seg_dataset.get_color_map_dict(with_unlabelled=False),
    image_size=CONFIG['vlm']['image_size'],
    sup_set_img_idxs=[16],
    sup_set_seg_dataset=sup_set_seg_dataset,
    str_formats=None,
)

In [8]:
# Vision-Language Encoder
vle: VLEncoder = VLE_REGISTRY.get("flair", version='flair-cc3m-recap.pt', device=CONFIG['device'], vision_adapter=False, text_adapter=False)
vle_weights_path = Path(SEG_TRAIN_CONFIG['with_text']['vle_weights_path'])
# Fail early, with the exception type dedicated to missing files.
if not vle_weights_path.exists():
    raise FileNotFoundError(f"ERROR: VLE weights path '{vle_weights_path}' not found.")
vle.model.load_state_dict(torch.load(vle_weights_path, map_location=CONFIG['device'])['model_state_dict'])

# NOTE(review): the vision layers are deleted right below — confirm this call is still needed.
vle.set_vision_trainable_params()

# NOTE deleting vision layers only if encoding text only.
del vle.model.visual, vle.model.visual_proj, vle.model.image_post
clear_memory()

In [9]:
# Contrastive criterion handed to `train_loop` as `contr_criterion`.
contr_criterion = SigLipLoss() #  NOTE or should I use FLAIRLoss?

In [None]:
# Placeholder selector for the text-supervision protocol: both branches are still stubs.
text_protocol = '...'

if text_protocol == 'contrastive_global':
    ...
elif text_protocol == 'contrastive_local':
    ...

In [11]:
seg_preprocess_fn = partial(SemanticSegmentation, resize_size=SEG_CONFIG['image_size'])() # same as original one, but with custom resizing

# training cropping functions
center_crop_fn = T.CenterCrop(SEG_CONFIG['image_size'])
random_crop_fn = T.RandomCrop(SEG_CONFIG['image_size'])

# augmentations
augment_fn = T.Compose([
    T.RandomHorizontalFlip(p=0.5),
    # T.RandomAffine(degrees=0, scale=(0.5, 2)), # Zooms in and out of the image.
    # T.RandomAffine(degrees=[-30, 30], translate=[0.2, 0.2], scale=(0.5, 2), shear=15), # Full affine transform.
    # T.RandomPerspective(p=0.5, distortion_scale=0.2) # Shears the image
])

train_collate_fn = partial(
    crop_augment_preprocess_batch,
    crop_fn=random_crop_fn,
    augment_fn=augment_fn,
    preprocess_fn=None
)

val_collate_fn = partial(
    crop_augment_preprocess_batch,
    crop_fn=T.CenterCrop(SEG_CONFIG['image_size']),
    augment_fn=None,
    preprocess_fn=seg_preprocess_fn
)

criterion = nn.CrossEntropyLoss(ignore_index=21)

train_dl = DataLoader(
    train_ds,
    batch_size=SEG_TRAIN_CONFIG["batch_size"],
    shuffle=False, # TODO SET TO TRUE
    generator=get_torch_gen(),
    collate_fn=train_collate_fn,
)
val_dl = DataLoader(
    val_ds,
    batch_size=SEG_TRAIN_CONFIG["batch_size"],
    shuffle=False,
    generator=get_torch_gen(),
    collate_fn=val_collate_fn,
)

metrics_dict = {
    "acc": MulticlassAccuracy(num_classes=train_ds.get_num_classes(with_unlabelled=True), top_k=1, average="micro", multidim_average="global", ignore_index=21).to(CONFIG["device"]),
    "mIoU": MulticlassJaccardIndex(num_classes=train_ds.get_num_classes(with_unlabelled=True), average="macro", ignore_index=21).to(CONFIG["device"]),
}

log_manager.log_intro(
    config=CONFIG,
    train_ds=train_ds,
    val_ds=val_ds,
    train_dl=train_dl,
    val_dl=val_dl
)

# Log trainable parameters
log_manager.log_title("Trainable Params")
[log_manager.log_line(t) for t in get_layer_numel_str(segnet, print_only_total=False, only_trainable=True).split('\n')]

try:
    await train_loop(
    segnet,
    vlm,
    vle,
    train_dl,
    val_dl,
    fast_prompt_builder,
    seg_preprocess_fn,
    gen_params,
    criterion,
    contr_criterion,
    metrics_dict,
    checkpoint_dict
)
except KeyboardInterrupt:
    log_manager.log_title("Training Interrupted", pad_symbol='~')

handle.remove()

log_manager.close_loggers()

[2025-07-31 15:26:45,029 INFO 3511189400.py line 51] ---------------------------------------------[ Config ]---------------------------------------------
[2025-07-31 15:26:45,030 INFO 3511189400.py line 51] {'seed': 42, 'device': 'cuda', 'datasets': {'COCO2017_root_path': '/home/olivieri/exp/shared_data/coco2017', 'VOC2012_root_path': '/home/olivieri/exp/data/VOCdevkit'}, 'seg': {'pretrained_weights_root_path': '/home/olivieri/exp/data/torch_weights/seg/lraspp_mobilenet_v3_large/no_text', 'image_size': 520, 'train': {'exp_name': 'test_no_text_250731_1526', 'exp_desc': '', 'train_decoder_only': False, 'batch_size': 8, 'num_epochs': 50, 'lr_schedule': {'policy': 'cosine', 'base_lr': 0.0001, 'warmup_length': 200, 'epochs_cooldown': None, 'lr_cooldown_power': None, 'lr_cooldown_end': None}, 'grad_clip_norm': 10.0, 'grad_accum_steps': 1, 'resume_path': None, 'log_only_to_stdout': False, 'file_logs_dir_path': '/home/olivieri/exp/logs/seg/with_text', 'tb_logs_dir_path': '/home/olivieri/exp/lo

step=0, batch_loss=tensor(5.8554, device='cuda:0', grad_fn=<AddBackward0>)
step=1, batch_loss=tensor(5.7571, device='cuda:0', grad_fn=<AddBackward0>)
step=2, batch_loss=tensor(5.9506, device='cuda:0', grad_fn=<AddBackward0>)


CancelledError: 