# In [None]:
# train_tritask_with_dino.py
"""
Tri-Task training script (BLIP-2 + LoRA) + partial Grounding-DINO-L fine-tune (last 2 blocks)
- Dataset: expects per-image JSON annotations in train_annotations/ and validation_annotations/
- Uses Hungarian matching for grounding supervision (DETR-style)
- Designed for 2x A100-40GB with bf16 + DeepSpeed/ZeRO
"""

import os
import json
import math
import random
from pathlib import Path
from typing import List, Tuple

import numpy as np
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import torchvision.ops as tvops
from scipy.optimize import linear_sum_assignment

from transformers import Blip2Processor, Blip2ForConditionalGeneration, get_cosine_schedule_with_warmup
from peft import LoraConfig, get_peft_model

# === USER SETTINGS ===
# Dataset locations (placeholder paths; edit before running).
# Each annotation directory is scanned for one *.json file per image.
TRAIN_IMAGES_DIR = "/path/to/train_images"
VAL_IMAGES_DIR   = "/path/to/validation_images"
TRAIN_ANNO_DIR   = "/path/to/train_annotations"
VAL_ANNO_DIR     = "/path/to/validation_annotations"
# Checkpoints are written here; the directory is created as soon as this module runs.
OUTPUT_DIR = "./checkpoints_tritask"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Model & training hyperparams
# NOTE(review): the epoch split below is the intended schedule; the visible training
# loop trains every task in every epoch -- confirm whether staging is done elsewhere.
EPOCHS = 6                      # 4 LoRA warmup + 2 partial DINO fine-tune
MICRO_BATCH = 4                 # per GPU
GRAD_ACCUM = 2
# NOTE(review): GLOBAL_BATCH is informational only in the visible script (never read).
GLOBAL_BATCH = MICRO_BATCH * GRAD_ACCUM * 2  # 2 GPUs by default (adjust)
LR_LORA = 2e-5
LR_DINO = 5e-6                  # small lr for DINO last-blocks
WEIGHT_DECAY = 0.01
LORA_R = 16
LORA_ALPHA = 32
LORA_DROPOUT = 0.05
NUM_DINO_UNFREEZE_BLOCKS = 2    # last N blocks of DINO to unfreeze
# NOTE(review): NUM_QUERIES is not referenced in the visible script -- the number of
# queries comes from whatever the DINO model outputs.
NUM_QUERIES = 100
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Images are resized to IMG_SIZE x IMG_SIZE by the dataset before the BLIP-2 processor runs.
IMG_SIZE = 640

# DeepSpeed: point to config file if you plan to use deepspeed
# NOTE(review): neither flag is read anywhere in the visible script -- confirm that a
# launcher consumes them, otherwise training runs as plain single-process PyTorch.
DEEPSPEED = True
DEEPSPEED_CONFIG = "deepspeed_config.json"

# ======================================================
#  Utility: Hungarian matcher (DETR-style)
# ======================================================
def box_cxcywh_to_xyxy(x):
    """Convert boxes from (cx, cy, w, h) to corner format (x0, y0, x1, y1).

    The last axis of ``x`` holds the four box values; all leading axes are kept.
    """
    center_x, center_y = x[..., 0], x[..., 1]
    half_w, half_h = 0.5 * x[..., 2], 0.5 * x[..., 3]
    corners = (
        center_x - half_w,
        center_y - half_h,
        center_x + half_w,
        center_y + half_h,
    )
    return torch.stack(corners, dim=-1)

def box_area(boxes):
    """Return the area of xyxy boxes; an inverted side counts as zero extent."""
    widths = (boxes[..., 2] - boxes[..., 0]).clamp(min=0)
    heights = (boxes[..., 3] - boxes[..., 1]).clamp(min=0)
    return widths * heights

def generalized_box_iou(boxes1, boxes2):
    """Pairwise generalized IoU between two sets of xyxy boxes.

    Thin wrapper that delegates to ``torchvision.ops.generalized_box_iou``
    so the rest of the file has a single local entry point for GIoU.
    """
    pairwise_giou = tvops.generalized_box_iou(boxes1, boxes2)
    return pairwise_giou

class HungarianMatcher:
    """DETR-style bipartite matcher between predicted queries and ground-truth boxes.

    For every image the pairwise cost between query ``q`` and target ``g`` is

        cost_class * (-objectness[q])
        + cost_bbox * L1(pred_cxcywh[q], tgt_cxcywh[g])
        + cost_giou * (-GIoU(pred_xyxy[q], tgt_xyxy[g]))

    and the assignment minimising the total cost is found with the Hungarian
    algorithm (``scipy.optimize.linear_sum_assignment``).
    """

    def __init__(self, cost_class=1, cost_bbox=5, cost_giou=2):
        self.cost_class = cost_class
        self.cost_bbox = cost_bbox
        self.cost_giou = cost_giou

    @torch.no_grad()
    def match(self, outputs_logits, outputs_boxes, targets):
        """Match predictions to targets independently for each image.

        Args:
            outputs_logits: objectness logits of shape (B, Q), or (B, Q, C); in the
                latter case the maximum over the last axis is used as objectness.
            outputs_boxes: predicted boxes of shape (B, Q, 4) in normalized cxcywh.
            targets: sequence of B tensors of shape (G_b, 4) in normalized xyxy.

        Returns:
            A list with one ``(query_idx, target_idx)`` pair of int64 numpy arrays
            per image. Both arrays are empty for images without targets.
        """
        indices = []
        for b in range(outputs_boxes.shape[0]):
            tgt = targets[b]
            if tgt.numel() == 0:
                indices.append((np.array([], dtype=np.int64), np.array([], dtype=np.int64)))
                continue

            # Build every cost term on CPU in float32: the inputs may live on the GPU
            # and/or be half precision, and scipy needs a plain numpy matrix anyway.
            pred_cxcywh = outputs_boxes[b].detach().float().cpu()
            tgt_xyxy = tgt.detach().float().cpu().reshape(-1, 4)

            # Express predictions and targets in BOTH formats so each cost term
            # compares like with like (L1 on cxcywh, GIoU on xyxy).
            pred_xyxy = box_cxcywh_to_xyxy(pred_cxcywh)
            x0, y0, x1, y1 = tgt_xyxy.unbind(-1)
            tgt_cxcywh = torch.stack([(x0 + x1) / 2, (y0 + y1) / 2, x1 - x0, y1 - y0], dim=-1)

            cost_bbox = torch.cdist(pred_cxcywh, tgt_cxcywh, p=1)      # (Q, G)
            cost_giou = -generalized_box_iou(pred_xyxy, tgt_xyxy)      # (Q, G)

            # Single-class objectness: confident queries are cheaper to assign.
            prob = torch.sigmoid(outputs_logits[b].detach().float().cpu())
            if prob.ndim > 1:
                prob = prob.max(dim=-1).values
            cost_class = -prob.unsqueeze(1)                            # (Q, 1), broadcast over G

            cost = (
                self.cost_class * cost_class
                + self.cost_bbox * cost_bbox
                + self.cost_giou * cost_giou
            )
            q_idx, t_idx = linear_sum_assignment(cost.numpy())
            indices.append((q_idx.astype(np.int64), t_idx.astype(np.int64)))
        return indices

# ======================================================
# Dataset class
# ======================================================
class TriTaskDataset(Dataset):
    """Per-image dataset feeding the caption, VQA and grounding tasks.

    Every ``*.json`` file in ``anno_dir`` describes one image. The keys read are:
      - ``image``: file name relative to ``image_dir`` (required)
      - ``caption``: caption string (optional, defaults to ``""``)
      - ``qa_pairs``: list of ``{"question": ..., "answer": ...}`` dicts (optional)
      - ``objects``: list of dicts whose ``obj_coord`` is ``[x1, y1, x2, y2]``
        (treated as normalized xyxy; missing coords default to the full image)

    Each item is a dict with ``pixel_values``, ``caption``, ``question``,
    ``answer`` and ``boxes`` (float32 tensor of shape (G, 4) in cxcywh).
    """

    def __init__(self, image_dir, anno_dir, processor, img_size=IMG_SIZE, mode="train"):
        self.image_dir = Path(image_dir)
        self.anno_dir = Path(anno_dir)
        self.processor = processor
        self.img_size = img_size
        self.files = sorted(self.anno_dir.glob("*.json"))
        self.mode = mode

    def __len__(self):
        return len(self.files)

    @staticmethod
    def _xyxy_to_cxcywh(coords):
        """Convert a list of [x1, y1, x2, y2] boxes into a (G, 4) cxcywh float tensor."""
        if len(coords) == 0:
            return torch.zeros((0, 4), dtype=torch.float32)
        arr = np.asarray(coords, dtype=np.float32).reshape(-1, 4)
        x1, y1, x2, y2 = arr[:, 0], arr[:, 1], arr[:, 2], arr[:, 3]
        cxcywh = np.stack([(x1 + x2) / 2.0, (y1 + y2) / 2.0, x2 - x1, y2 - y1], axis=1)
        return torch.tensor(cxcywh, dtype=torch.float32)

    def __getitem__(self, idx):
        with open(self.files[idx], "r", encoding="utf-8") as f:
            ann = json.load(f)

        # Close the underlying file as soon as the pixels are decoded; with many
        # DataLoader workers, lingering handles otherwise pile up until GC runs.
        with Image.open(self.image_dir / ann["image"]) as raw:
            img = raw.convert("RGB")
        img = img.resize((self.img_size, self.img_size))
        pixel_values = self.processor(images=img, return_tensors="pt").pixel_values.squeeze(0)

        qa_pairs = ann.get("qa_pairs") or []
        if not qa_pairs:
            qa = {"question": "", "answer": ""}
        elif self.mode == "train":
            # Random pair per visit acts as light augmentation during training.
            qa = random.choice(qa_pairs)
        else:
            # Keep evaluation deterministic: always use the first pair.
            qa = qa_pairs[0]

        objs = ann.get("objects") or []
        coords = [o.get("obj_coord", [0, 0, 1, 1]) for o in objs]

        return {
            "pixel_values": pixel_values,
            "caption": ann.get("caption", ""),
            "question": qa.get("question", ""),
            "answer": qa.get("answer", ""),
            "boxes": self._xyxy_to_cxcywh(coords),
        }

# ======================================================
# Load models: Grounding-DINO + BLIP-2 with LoRA
# ======================================================
def load_grounding_dino(model_config_path: str, checkpoint_path: str):
    """Build a GroundingDINO model from its config file and load checkpoint weights.

    Requires the GroundingDINO repository to be installed as the ``groundingdino``
    package; adjust the import paths if your copy is laid out differently.

    Args:
        model_config_path: path to the GroundingDINO ``.py`` config file.
        checkpoint_path: path to the ``.pth`` checkpoint.

    Returns:
        The model with weights loaded on CPU (the caller moves it to a device).
    """
    # Import inside function to avoid module error if repo not installed
    from groundingdino.util.slconfig import SLConfig
    from groundingdino.models import build_model

    cfg = SLConfig.fromfile(model_config_path)
    model = build_model(cfg)

    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    # Accept both {"model": state_dict} checkpoints and bare state dicts.
    if isinstance(checkpoint, dict) and "model" in checkpoint:
        state_dict = checkpoint["model"]
    else:
        state_dict = checkpoint
    # Checkpoints saved from a DataParallel/DDP wrapper prefix every key with
    # "module."; without stripping it, strict=False would silently load nothing.
    prefix = "module."
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

    result = model.load_state_dict(state_dict, strict=False)
    # strict=False hides mismatches, so surface them for the user to judge.
    if result.missing_keys or result.unexpected_keys:
        print(
            f"Grounding-DINO checkpoint loaded with {len(result.missing_keys)} missing "
            f"and {len(result.unexpected_keys)} unexpected keys."
        )
    return model

print("Loading BLIP-2 processor & model (this will download weights if not cached)...")
processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
blip2 = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl")
# inject LoRA adapters
# The language model of this checkpoint is Flan-T5, whose attention projections are
# named "q", "k", "v" and "o". PEFT matches target modules by the final name component,
# so LLaMA/OPT-style names such as "q_proj" match nothing here and get_peft_model fails.
# task_type is left unset on purpose: the generic PEFT wrapper passes keyword arguments
# (pixel_values, input_ids, labels, ...) straight through to BLIP-2, whereas the
# seq2seq wrapper injects extra arguments that BLIP-2's forward may not accept.
lora_cfg = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    target_modules=["q", "k", "v", "o"],
    lora_dropout=LORA_DROPOUT,
    bias="none",
)
blip2 = get_peft_model(blip2, lora_cfg)
# The training loop moves every batch to DEVICE, so the model has to live there too.
blip2.to(DEVICE)

# Grounding-DINO config / checkpoint paths (edit these)
DINO_CONFIG = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"  # example path in repo
DINO_CHECKPOINT = "/path/to/groundingdino_swint_ogc.pth"

print("Loading Grounding-DINO model (may take time)...")
grounding_dino = load_grounding_dino(DINO_CONFIG, DINO_CHECKPOINT)
grounding_dino.to(DEVICE)
grounding_dino.eval()  # we'll set train mode later conditionally

# ======================================================
# Partial unfreeze: unfreeze last N transformer blocks and head
# ======================================================
def unfreeze_last_blocks_dino(model, n_blocks=2):
    """Freeze ``model`` and re-enable gradients on its last backbone blocks and heads.

    Steps:
      1. Every parameter is frozen.
      2. The first candidate attribute path that resolves to a sequence of modules
         is taken as the backbone block list; its last ``n_blocks`` entries are
         unfrozen (clamped to the number of blocks available).
      3. Prediction-head modules are unfrozen. A module counts as a head when one
         dotted component of its name is exactly ``out``, ``bbox_embed`` or
         ``class_embed``, or contains ``head``/``heads``/``predictor`` as an
         underscore-separated token (``box_head`` matches, ``out_proj`` and
         ``multihead_attn`` do not).

    Args:
        model: the Grounding-DINO ``nn.Module`` (modified in place).
        n_blocks: number of trailing backbone blocks to unfreeze.
    """
    # freeze all first
    for p in model.parameters():
        p.requires_grad = False

    # Likely locations of the backbone block list (depends on repo version).
    # Numeric components are used as indices, so "-1" selects the last stage.
    candidate_paths = [
        "backbone.model.blocks",
        "backbone.blocks",
        "backbone.stage4.blocks",
        "backbone.model.stage4.blocks",
        "backbone.0.layers.-1.blocks",  # Swin backbone wrapped in a Sequential-style joiner
    ]

    def _resolve(root, dotted_path):
        """Follow a dotted path, indexing for numeric parts and getattr otherwise."""
        obj = root
        for part in dotted_path.split("."):
            if part.lstrip("-").isdigit():
                obj = obj[int(part)]
            else:
                obj = getattr(obj, part)
        return obj

    found = False
    unfrozen_blocks = 0
    for path in candidate_paths:
        try:
            blocks = list(_resolve(model, path))
        except (AttributeError, IndexError, KeyError, TypeError):
            continue
        if not all(isinstance(block, nn.Module) for block in blocks):
            continue
        # Clamp so that an oversized n_blocks unfreezes everything instead of
        # raising half-way through and leaving the model partially unfrozen.
        unfrozen_blocks = min(max(n_blocks, 0), len(blocks))
        for block in blocks[len(blocks) - unfrozen_blocks:]:
            for p in block.parameters():
                p.requires_grad = True
        found = True
        break

    # Always unfreeze the prediction heads. Matching is done on whole name
    # components/tokens: a plain substring test for "out" would also unfreeze
    # every attention "out_proj"/"output" layer and defeat the partial fine-tune.
    head_tokens = {"head", "heads", "predictor"}
    head_names = {"out", "bbox_embed", "class_embed"}
    for name, module in model.named_modules():
        components = name.lower().split(".")
        is_head = any(
            part in head_names or head_tokens.intersection(part.split("_"))
            for part in components
        )
        if is_head:
            for p in module.parameters():
                p.requires_grad = True

    if not found:
        print("Warning: didn't find canonical backbone block path to unfreeze; you must manually set blocks to unfreeze.")
    else:
        print(f"Unfroze last {unfrozen_blocks} DINO blocks (if path resolution succeeded).")

# Freeze DINO and re-enable gradients on its last blocks / head modules right away:
# the optimizer set-up further down only collects parameters with requires_grad=True.
unfreeze_last_blocks_dino(grounding_dino, n_blocks=NUM_DINO_UNFREEZE_BLOCKS)

# ======================================================
# Build head & wrapper utilities (we use DINO's own head for proposals during training)
# ======================================================

# For grounding training we will use DINO's internal forward to get predictions (logits + boxes).
# The exact forward signature depends on repo; below we use the common pattern found in their inference:
#  - backbone -> features
#  - transformer -> outputs
#  - head -> pred logits + boxes
# So we'll call model with images and training flag to get raw outputs.

# ======================================================
# Loss functions & criterion helpers
# ======================================================
class TriTaskCriterion:
    """Loss helpers for the three tasks (caption, VQA, grounding).

    Caption and VQA losses are produced by BLIP-2 itself and passed through
    unchanged. The grounding loss is DETR-style: predictions are matched to
    targets with the Hungarian matcher, then objectness (BCE), box L1 and
    GIoU terms are combined.
    """

    def __init__(self, weight_dict=None, bbox_weight=5.0, giou_weight=2.0):
        """
        Args:
            weight_dict: per-task weights; stored for callers, not applied here.
            bbox_weight: weight of the L1 box term (also used as matcher cost).
            giou_weight: weight of the GIoU term (also used as matcher cost).
        """
        # weights for losses
        self.weight_dict = weight_dict or {"caption": 1.0, "ground": 1.0, "vqa": 1.0}
        self.bbox_weight = bbox_weight
        self.giou_weight = giou_weight
        self.matcher = HungarianMatcher(cost_bbox=bbox_weight, cost_giou=giou_weight)
        self.l1 = nn.L1Loss(reduction="none")

    def loss_caption(self, blip_outputs_loss):
        """Return the captioning loss computed by BLIP-2 unchanged."""
        return blip_outputs_loss

    def loss_vqa(self, vqa_loss):
        """Return the VQA loss computed by BLIP-2 unchanged."""
        return vqa_loss

    def loss_grounding(self, pred_logits, pred_boxes, targets_boxes):
        """Compute the grounding loss for one batch.

        Args:
            pred_logits: (B, Q) objectness logits; a (B, Q, C) tensor is reduced
                to objectness by taking the maximum over the last axis.
            pred_boxes: (B, Q, 4) predicted boxes in normalized cxcywh.
            targets_boxes: sequence of B tensors of shape (G_b, 4) in normalized
                cxcywh (the format produced by the dataset); G_b may be 0.

        Returns:
            Scalar tensor: ``obj + bbox_weight * l1 + giou_weight * giou``, each
            term averaged over the batch.
        """
        device = pred_boxes.device
        # Compute the loss in float32 even when the forward pass ran in half precision.
        pred_boxes = pred_boxes.float()
        pred_logits = pred_logits.float()
        if pred_logits.ndim == 3:
            pred_logits = pred_logits.max(dim=-1).values

        targets = [
            tb.to(device=device, dtype=torch.float32).reshape(-1, 4)
            for tb in targets_boxes
        ]
        # The matcher expects targets in xyxy.
        targets_xyxy = [box_cxcywh_to_xyxy(tb) for tb in targets]
        indices = self.matcher.match(pred_logits, pred_boxes, targets_xyxy)

        num_queries = pred_boxes.shape[1]
        bce = nn.functional.binary_cross_entropy_with_logits
        zero = pred_boxes.new_zeros(())
        obj_losses, bbox_losses, giou_losses = [], [], []

        for b, (q_idx, t_idx) in enumerate(indices):
            obj_targets = torch.zeros(num_queries, device=device)
            if len(t_idx) == 0:
                # No targets: every query is a negative. Using the same BCE as the
                # matched case keeps the objectness term on one scale.
                obj_losses.append(bce(pred_logits[b], obj_targets))
                bbox_losses.append(zero)
                giou_losses.append(zero)
                continue

            q_idx = torch.as_tensor(q_idx, dtype=torch.long, device=device)
            t_idx = torch.as_tensor(t_idx, dtype=torch.long, device=device)

            src_boxes = pred_boxes[b][q_idx]   # (M, 4) cxcywh
            # Reorder targets by the assignment so row i of src pairs with row i
            # of tgt; the matcher does not return targets in their original order.
            tgt_boxes = targets[b][t_idx]      # (M, 4) cxcywh

            bbox_losses.append(self.l1(src_boxes, tgt_boxes).mean())

            giou = tvops.generalized_box_iou(
                box_cxcywh_to_xyxy(src_boxes), box_cxcywh_to_xyxy(tgt_boxes)
            )
            giou_losses.append((1 - torch.diag(giou)).mean())

            # objectness loss: positive for matched preds, negative for others
            obj_targets[q_idx] = 1.0
            obj_losses.append(bce(pred_logits[b], obj_targets))

        loss_bbox = torch.stack(bbox_losses).mean()
        loss_obj = torch.stack(obj_losses).mean()
        loss_giou = torch.stack(giou_losses).mean()
        return loss_obj + self.bbox_weight * loss_bbox + self.giou_weight * loss_giou

# ======================================================
# Data loaders
# ======================================================
# Datasets for the two splits, then one DataLoader per split. Only the training
# loader shuffles and drops the last incomplete batch.
# NOTE(review): no collate_fn is supplied, so the default collation has to stack the
# per-sample `boxes` tensors; that presumably requires every image in a batch to have
# the same number of objects -- confirm against the annotations or add a collate_fn.
_loader_common = {"batch_size": MICRO_BATCH, "num_workers": 6, "pin_memory": True}

train_dataset = TriTaskDataset(
    image_dir=TRAIN_IMAGES_DIR,
    anno_dir=TRAIN_ANNO_DIR,
    processor=processor,
    img_size=IMG_SIZE,
    mode="train",
)
val_dataset = TriTaskDataset(
    image_dir=VAL_IMAGES_DIR,
    anno_dir=VAL_ANNO_DIR,
    processor=processor,
    img_size=IMG_SIZE,
    mode="val",
)

train_loader = DataLoader(train_dataset, shuffle=True, drop_last=True, **_loader_common)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_common)

# ======================================================
# Optimizer groups: LoRA params + small heads + DINO last blocks (if any)
# ======================================================
# collect trainable params
def get_trainable_params(blip2_model, dino_model):
    """Collect every parameter with ``requires_grad=True`` from both models.

    BLIP-2 parameters come first (with PEFT applied these are the adapter
    weights), followed by the unfrozen Grounding-DINO parameters.
    """
    blip2_trainable = [param for param in blip2_model.parameters() if param.requires_grad]
    dino_trainable = [param for param in dino_model.parameters() if param.requires_grad]
    return blip2_trainable + dino_trainable

trainable_params = get_trainable_params(blip2, grounding_dino)

# Two parameter groups: unfrozen DINO weights train with the smaller LR_DINO, and
# everything else (the BLIP-2 adapters) with LR_LORA. Membership is decided by
# object identity because tensors are not reliably hashable by value.
dino_param_ids = {id(p) for p in grounding_dino.parameters() if p.requires_grad}
param_groups = [
    {"params": [p for p in trainable_params if id(p) not in dino_param_ids], "lr": LR_LORA},
    {"params": [p for p in trainable_params if id(p) in dino_param_ids], "lr": LR_DINO},
]
# Drop a group that ended up empty (e.g. nothing unfrozen in DINO).
param_groups = [group for group in param_groups if group["params"]]
optimizer = optim.AdamW(param_groups, weight_decay=WEIGHT_DECAY)

# The training loop steps the optimizer once per GRAD_ACCUM micro-batches and starts
# every epoch with zeroed gradients, so each epoch yields len(loader) // GRAD_ACCUM steps.
steps_per_epoch = len(train_loader) // GRAD_ACCUM
total_steps = max(1, steps_per_epoch * EPOCHS)
# Keep the 500-step warmup for normal runs, but never let it take more than half of a
# short schedule -- otherwise the learning rate would never leave the warmup ramp.
warmup_steps = min(500, total_steps // 2)
scheduler = get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps
)

criterion = TriTaskCriterion()

# ======================================================
# Training loop
# ======================================================
# Mixed precision (and the matching loss scaler) is only enabled on CUDA.
USE_AMP = DEVICE == "cuda"
scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)
global_step = 0

# Text prompt handed to Grounding-DINO for every training image. GroundingDINO's
# forward needs one caption per image; the inference helper in this file queries
# with "object", written here in the lower-case, period-terminated form.
DINO_TRAIN_PROMPT = "object ."


def _blip2_text_loss(pixel_values, prompts, targets):
    """Return BLIP-2's language-modeling loss for a batch.

    Args:
        pixel_values: (B, C, H, W) tensor already produced by the BLIP-2 processor
            in the dataset. It is passed to the model as-is; running it through the
            processor again would rescale/normalize the image a second time.
        prompts: B input strings (question, or "" for plain captioning).
        targets: B strings the model should generate.
    """
    text_inputs = processor(text=list(prompts), return_tensors="pt", padding=True).to(DEVICE)
    labels = processor.tokenizer(list(targets), return_tensors="pt", padding=True).input_ids.to(DEVICE)
    # Padding positions must not contribute to the loss.
    labels = labels.masked_fill(labels == processor.tokenizer.pad_token_id, -100)
    outputs = blip2(
        pixel_values=pixel_values,
        input_ids=text_inputs["input_ids"],
        attention_mask=text_inputs["attention_mask"],
        labels=labels,
    )
    return outputs.loss


def _dino_grounding_loss(pixel_values, target_boxes):
    """Run Grounding-DINO on the batch and return the Hungarian-matched grounding loss.

    NOTE(review): ``pixel_values`` were prepared by the BLIP-2 processor; Grounding-DINO's
    own preprocessing (resolution, normalization) may differ -- confirm this is intended.
    """
    captions_for_dino = [DINO_TRAIN_PROMPT] * pixel_values.size(0)
    dino_out = grounding_dino(pixel_values, captions=captions_for_dino)
    try:
        pred_logits = dino_out["pred_logits"]   # (B, Q, C) or (B, Q)
        pred_boxes = dino_out["pred_boxes"]     # (B, Q, 4) in cxcywh normalized
    except (KeyError, TypeError, IndexError):
        # Some versions return an object with attributes instead of a dict.
        pred_logits = dino_out.pred_logits
        pred_boxes = dino_out.pred_boxes

    # Reduce to one objectness logit per query.
    if pred_logits.ndim == 3 and pred_logits.shape[-1] > 1:
        pred_logits_obj = pred_logits.max(dim=-1).values
    elif pred_logits.ndim == 3:
        pred_logits_obj = pred_logits.squeeze(-1)
    else:
        pred_logits_obj = pred_logits

    # Iterating works for a stacked (B, G, 4) tensor as well as a list of (G_b, 4) tensors.
    targets_list = [tb.to(DEVICE) for tb in target_boxes]
    # Matching and the box losses are computed in float32 regardless of autocast.
    return criterion.loss_grounding(pred_logits_obj.float(), pred_boxes.float(), targets_list)


print(f"Training on device={DEVICE} -- total_steps approx {total_steps}")

for epoch in range(EPOCHS):
    blip2.train()
    # Only the unfrozen DINO parameters receive updates; train() is still needed so
    # that those blocks run in training mode.
    grounding_dino.train()
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
    running_loss = 0.0
    optimizer.zero_grad()

    for batch_idx, batch in enumerate(pbar):
        pv = batch["pixel_values"].to(DEVICE)             # (B,C,H,W)
        captions = list(batch["caption"])
        questions = list(batch["question"])
        answers = list(batch["answer"])

        # choose task by sampling: caption 40%, grounding 40%, vqa 20%
        r = random.random()

        with torch.cuda.amp.autocast(enabled=USE_AMP):
            if r < 0.4:
                # CAPTION: empty prompt in, ground-truth caption as the target.
                loss = criterion.loss_caption(_blip2_text_loss(pv, [""] * pv.size(0), captions))
            elif r < 0.8:
                # GROUNDING
                loss = _dino_grounding_loss(pv, batch["boxes"])
            else:
                # VQA: samples without a usable QA pair fall back to captioning.
                prompts, targets = [], []
                for question, answer, caption in zip(questions, answers, captions):
                    if question and answer:
                        prompts.append(question)
                        targets.append(answer)
                    else:
                        prompts.append("")
                        targets.append(caption)
                loss = criterion.loss_vqa(_blip2_text_loss(pv, prompts, targets))

        # Divide by GRAD_ACCUM so the accumulated gradient is the mean over the
        # micro-batches instead of their sum.
        scaler.scale(loss / GRAD_ACCUM).backward()
        if (batch_idx + 1) % GRAD_ACCUM == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(
                [p for group in optimizer.param_groups for p in group["params"] if p.requires_grad],
                1.0,
            )
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            scheduler.step()
            global_step += 1

        running_loss += loss.item()
        pbar.set_postfix({"loss": running_loss / (batch_idx + 1), "step": global_step})

    # Save checkpoint (model + peft adapters)
    ckpt = {
        "epoch": epoch+1,
        "global_step": global_step,
        "model_state_dict": blip2.state_dict(),
        "dino_state_dict": grounding_dino.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "scheduler_state": scheduler.state_dict()
    }
    torch.save(ckpt, os.path.join(OUTPUT_DIR, f"checkpoint_epoch_{epoch+1}.pth"))
    # Save PEFT adapters separately. This stays best-effort (the full checkpoint above
    # already contains the weights), but a failure is reported instead of hidden.
    try:
        blip2.save_pretrained(os.path.join(OUTPUT_DIR, f"blip2_peft_epoch_{epoch+1}"))
    except Exception as exc:
        print(f"Warning: could not save PEFT adapters for epoch {epoch+1}: {exc}")
    print(f"Epoch {epoch+1} checkpoint saved.")

print("Training complete.")


# In [None]:
def _dino_predict_xyxy(model, image_pil, device, caption="object", box_threshold=0.35, text_threshold=0.25):
    """Run a raw GroundingDINO module on a PIL image; return (xyxy boxes, phrases).

    Uses the GroundingDINO repository's own preprocessing and ``predict`` helper,
    whose boxes come back as normalized cxcywh and are converted to xyxy here.
    """
    import groundingdino.datasets.transforms as T
    from groundingdino.util.inference import predict

    transform = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    image_tensor, _ = transform(image_pil.convert("RGB"), None)
    boxes, _logits, phrases = predict(
        model=model,
        image=image_tensor,
        caption=caption,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        device=device,
    )
    return box_cxcywh_to_xyxy(boxes), phrases


def infer_image(image_pil, task="caption", question=None, blip2_model=None, grounding_dino_model=None, peft_adapter_dir=None):
    """Run one inference task on a PIL image.

    Args:
        image_pil: input image.
        task: ``"caption"``, ``"qa"``, ``"yesno"`` or ``"grounding"``.
        question: question text; required for ``"qa"`` and ``"yesno"``.
        blip2_model: BLIP-2 model for the text tasks (ignored when
            ``peft_adapter_dir`` is given).
        grounding_dino_model: Grounding-DINO model for ``"grounding"``.
        peft_adapter_dir: if set, the base BLIP-2 weights are loaded and the PEFT
            adapters from this directory are applied.

    Returns:
        ``{"caption": ...}``, ``{"answer": ...}`` or ``{"objects": [...]}`` depending
        on the task, or ``{"error": ...}`` for a missing question / unknown task.

    Raises:
        ValueError: if the model required by the chosen task was not provided.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    if task not in ("caption", "qa", "yesno", "grounding"):
        return {"error": f"Unknown task: {task}"}

    if task == "grounding":
        # Only the grounding model is needed; BLIP-2 may be left unset.
        if grounding_dino_model is None:
            raise ValueError("grounding_dino_model is required for task='grounding'")
        grounding_dino_model.to(device).eval()
        if hasattr(grounding_dino_model, "predict_with_caption"):
            # Wrapper objects that expose their own prediction helper.
            boxes, logits, phrases = grounding_dino_model.predict_with_caption(image_pil, caption="object", box_threshold=0.35, text_threshold=0.25)
        else:
            # A plain GroundingDINO nn.Module has no such method.
            boxes, phrases = _dino_predict_xyxy(grounding_dino_model, image_pil, device)
        objects = []
        for box, phrase in zip(boxes, phrases):
            x1, y1, x2, y2 = box.tolist()
            objects.append({"obj_cls": phrase, "obj_coord": [x1, y1, x2, y2]})
        return {"objects": objects}

    # Text tasks from here on. Validate cheap inputs before touching any model.
    if task in ("qa", "yesno") and question is None:
        return {"error": "Question required"}

    if peft_adapter_dir:
        from peft import PeftModel  # not imported at file level

        base_model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl")
        blip2_model = PeftModel.from_pretrained(base_model, peft_adapter_dir)
    if blip2_model is None:
        raise ValueError("blip2_model or peft_adapter_dir is required for text tasks")
    blip2_model.to(device).eval()

    if task == "caption":
        inputs = processor(images=image_pil, return_tensors="pt").to(device)
        gen = blip2_model.generate(**inputs, max_new_tokens=80)
        caption = processor.decode(gen[0], skip_special_tokens=True)
        return {"caption": caption}

    if task == "qa":
        inputs = processor(images=image_pil, text=question, return_tensors="pt").to(device)
        gen = blip2_model.generate(**inputs, max_new_tokens=40)
        ans = processor.decode(gen[0], skip_special_tokens=True)
        return {"answer": ans}

    # task == "yesno": steer the model towards a binary answer, then normalize it.
    q = question + " Answer yes or no."
    inputs = processor(images=image_pil, text=q, return_tensors="pt").to(device)
    gen = blip2_model.generate(**inputs, max_new_tokens=8)
    ans = processor.decode(gen[0], skip_special_tokens=True).lower()
    return {"answer": "Yes" if "yes" in ans else "No"}
