
# CELL 1: Install Dependencies


In [None]:

# Install the third-party packages used by this notebook. The lines starting
# with `!` are IPython shell escapes, so this cell is meant to be run inside a
# notebook kernel (e.g. Colab), not as a plain Python script.
# NOTE(review): several of these packages (wandb, the OpenAI CLIP git package,
# pycocotools, supervision, timm) are not imported in the cells visible here --
# confirm they are still needed before keeping them.
print('[1/6] Installing required packages...')
!pip install -q torch torchvision transformers datasets accelerate ftfy wandb
!pip install -q git+https://github.com/openai/CLIP.git
!pip install -q pillow tqdm scikit-learn
!pip install -q pycocotools supervision
!pip install -q timm
print('Packages installed.')


# CELL 2: Imports and Setup


In [None]:

print('\n [2/6] Setting up environment...')
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import os, glob, json, random, math, re
from pathlib import Path
from PIL import Image, ImageDraw, ImageFont
import xml.etree.ElementTree as ET
from sklearn.model_selection import train_test_split
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

from functools import partial
from transformers import (
    CLIPProcessor, CLIPModel,
    AutoProcessor, AutoModelForZeroShotObjectDetection,
    AutoTokenizer, get_linear_schedule_with_warmup
)
from torchvision.ops import nms, box_iou
from tqdm import tqdm
import numpy as np
import warnings
warnings.filterwarnings('ignore')

# Device every model and batch is moved to: the GPU when one is visible,
# otherwise the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {DEVICE}')


# Layout of the project folder on Google Drive. These are plain strings, so
# building them before the drive is mounted does not touch the filesystem.
DRIVE_ROOT = "/content/drive/MyDrive/Model_Training"
IMAGES_DIR = os.path.join(DRIVE_ROOT, "images")            # source images
ANNOTATIONS_DIR = os.path.join(DRIVE_ROOT, "annotations")  # Pascal-VOC *.xml files
SENTENCES_DIR = os.path.join(DRIVE_ROOT, "sentences")      # sentence/phrase text files
OUT_DIR = os.path.join(DRIVE_ROOT, "training_output")      # checkpoints and dumps


# Colab-only helper; this import fails outside a Colab runtime.
from google.colab import drive

# Best-effort unmount first; any failure is reported and ignored so the
# subsequent mount still runs.
print('Attempting to unmount drive to ensure a clean connection...')
try:
    drive.flush_and_unmount()
    print('Drive unmounted successfully.')
except Exception as e:
    print(f"No drive to unmount or error during unmount: {e}")

print('Mounting Google Drive...')
drive.mount('/content/drive', force_remount=True)
print('Drive mounted.')

# Must happen after the mount: OUT_DIR lives under the mounted drive path.
print(f"Ensuring output directory exists: {OUT_DIR}")
os.makedirs(OUT_DIR, exist_ok=True)


# CELL 3: Annotation Parsing Functions


In [None]:

print('\n [3/6] Setting up annotation parsers...')

def parse_pascal_voc_xml(xml_folder, images_dir, class_map=None, keep_scene_without_box=False):
    """Parse a folder of Pascal-VOC XML files into a flat list of object records.

    Args:
        xml_folder: Directory that is globbed for ``*.xml`` files.
        images_dir: Directory the ``<filename>`` of each XML is joined onto.
        class_map: Optional mapping from the raw ``<name>`` value to a phrase.
            Names missing from the map become ``class_<name>``. Without a map,
            purely numeric names become ``class_<name>`` and anything else is
            used as-is.
        keep_scene_without_box: When True, objects flagged with a truthy
            ``<nobndbox>`` are kept with ``bbox=None`` instead of being dropped.

    Returns:
        A list of dicts with keys ``image_path``, ``phrase`` and ``bbox``, where
        ``bbox`` is ``[x, y, w, h]`` in integer pixels (or ``None`` for box-less
        objects that were kept).

    Files that fail to parse and objects with missing, malformed, non-finite or
    non-positive boxes are skipped rather than aborting the whole run.
    """
    def _text(node):
        # Elements such as <name/> have text None; treat them like a missing tag.
        if node is None or node.text is None:
            return ''
        return node.text.strip()

    def _phrase_for(name):
        if class_map:
            return class_map.get(name, f'class_{name}')
        return f'class_{name}' if name.isdigit() else name

    xml_files = glob.glob(os.path.join(xml_folder, '*.xml'))
    items = []
    for xf in tqdm(xml_files, desc="Parsing XMLs"):
        try:
            root = ET.parse(xf).getroot()
        except Exception as e:
            print(f'Failed parse {xf}: {e}')
            continue

        # Fall back to "<xml stem>.jpg" when the tag is absent or empty.
        filename = _text(root.find('filename')) or os.path.splitext(os.path.basename(xf))[0] + '.jpg'
        image_path = os.path.join(images_dir, filename)

        for obj in root.findall('object'):
            name = _text(obj.find('name')) or 'unknown'
            bnd = obj.find('bndbox')

            if _text(obj.find('nobndbox')) in ('1', 'true', 'True'):
                if keep_scene_without_box:
                    items.append({'image_path': image_path, 'phrase': _phrase_for(name), 'bbox': None})
                continue

            if bnd is None:
                continue

            try:
                xmin, ymin, xmax, ymax = (
                    float(bnd.find(tag).text) for tag in ('xmin', 'ymin', 'xmax', 'ymax')
                )
                x = int(round(xmin))
                y = int(round(ymin))
                w = int(round(xmax - xmin))
                h = int(round(ymax - ymin))
            except (AttributeError, TypeError, ValueError, OverflowError):
                # Missing coordinate tag, empty tag, non-numeric text, or NaN/inf.
                continue

            if w <= 0 or h <= 0:
                continue

            items.append({'image_path': image_path, 'phrase': _phrase_for(name), 'bbox': [x, y, w, h]})

    print(f'Parsed {len(items)} objects from {len(xml_files)} xml files')
    return items

def extract_sentence_phrases_from_text(text):
    """Group the bracketed entity phrases found in ``text`` by entity id.

    A tag is recognised when it looks like ``[/EN#<digits>/<label> <phrase>]``.
    The result maps each id (kept as a string) to the list of its phrases,
    whitespace-stripped, in order of appearance. Text without tags gives ``{}``.
    """
    grouped = {}
    for hit in re.finditer(r'\[/EN#(\d+)\/[^\s\]]+\s([^\]]+)\]', text):
        entity_id, raw_phrase = hit.groups()
        if entity_id not in grouped:
            grouped[entity_id] = []
        grouped[entity_id].append(raw_phrase.strip())
    return grouped

def parse_and_load_annotations():
    """Build the cleaned list of annotated objects used for training.

    Steps:
      1. Parse every Pascal-VOC XML in ``ANNOTATIONS_DIR`` (if any exist).
      2. Read every file in ``SENTENCES_DIR`` and collect phrases per entity id.
      3. For items whose phrase is still a ``class_<id>`` placeholder, substitute
         the most frequent phrase seen for that id (first-seen wins on ties).
      4. Drop items whose image file does not exist or that have no bounding
         box, and coerce the remaining boxes to integer ``[x, y, w, h]``.

    Returns:
        A list of item dicts (``image_path``, ``phrase``, ``bbox``). The list is
        empty when nothing usable was found.
    """
    from collections import Counter  # local import keeps this block self-contained

    all_items = []

    # 1) Bounding boxes.
    if os.path.exists(ANNOTATIONS_DIR) and glob.glob(os.path.join(ANNOTATIONS_DIR, '*.xml')):
        all_items.extend(parse_pascal_voc_xml(ANNOTATIONS_DIR, IMAGES_DIR))

    # 2) Phrases per entity id, accumulated over all sentence files.
    sentence_map = {}
    if os.path.exists(SENTENCES_DIR):
        text_files = glob.glob(os.path.join(SENTENCES_DIR, '*'))
        for tf in tqdm(text_files, desc="Parsing Sentences"):
            try:
                with open(tf, 'r', encoding='utf-8') as f:
                    txt = f.read()
            except (OSError, UnicodeDecodeError) as e:
                # Unreadable entries (directories, binary files, ...) are skipped.
                print(f'Could not read {tf}: {e}')
                continue
            for entity_id, phrases in extract_sentence_phrases_from_text(txt).items():
                sentence_map.setdefault(entity_id, []).extend(phrases)

    # 3) Replace "class_<id>" placeholders with the dominant phrase for that id.
    if sentence_map:
        id_to_phrase = {
            entity_id: Counter(phrases).most_common(1)[0][0]
            for entity_id, phrases in sentence_map.items()
        }
        for it in all_items:
            m = re.match(r'class_(\d+)', str(it['phrase']))
            if m and m.group(1) in id_to_phrase:
                it['phrase'] = id_to_phrase[m.group(1)]

    # 4) Keep only items with an existing image and a box. Existence is cached
    #    per path so a missing image is stat'ed and reported once, not once per
    #    object it contains.
    cleaned = []
    image_exists = {}
    for it in all_items:
        path = it['image_path']
        if path not in image_exists:
            image_exists[path] = os.path.exists(path)
            if not image_exists[path]:
                print(f"Warning: Image not found, skipping item: {path}")
        if not image_exists[path]:
            continue

        if it['bbox'] is None:
            continue
        x, y, w, h = it['bbox']
        it['bbox'] = [int(round(x)), int(round(y)), int(round(w)), int(round(h))]
        cleaned.append(it)

    return cleaned

print(' Annotation parsers ready.')

# CELL 4: Dataset Classes

In [None]:

print('\n [4/6] Setting up dataset classes...')

class CLIPReRankDataset(Dataset):
    """Dataset for training CLIP re-ranker.

    Each sample pairs one phrase with ``1 + num_negatives`` image crops: the
    padded ground-truth box first, followed by random negative crops taken
    either from the same image or from another randomly chosen item's image.
    """
    def __init__(self, items, processor, num_negatives=3, crop_size=224):
        """
        Args:
            items: List of dicts with ``image_path``, ``phrase`` and ``bbox``
                (``[x, y, w, h]`` in pixels).
            processor: Callable applied to the phrase and the crops (a CLIP
                processor in this file).
            num_negatives: Number of negative crops per sample.
            crop_size: Side length, in pixels, every crop is resized to.
        """
        self.items = items
        self.processor = processor
        self.num_neg = num_negatives
        self.crop_size = crop_size

    def __len__(self):
        return len(self.items)

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` as an RGB image and close the file handle right away."""
        with Image.open(path) as im:
            return im.convert('RGB')

    @staticmethod
    def _covered_fraction(bbox, crop_box):
        """Return the fraction of ``bbox`` ([x, y, w, h]) inside ``crop_box`` (x1, y1, x2, y2)."""
        x, y, w, h = bbox
        if w <= 0 or h <= 0:
            return 0.0
        inter_w = min(x + w, crop_box[2]) - max(x, crop_box[0])
        inter_h = min(y + h, crop_box[3]) - max(y, crop_box[1])
        if inter_w <= 0 or inter_h <= 0:
            return 0.0
        return (inter_w * inter_h) / (w * h)

    def _crop_from_bbox(self, image, bbox, pad=0.15):
        """Crop ``bbox`` plus ``pad`` (fraction of its size) and resize to the crop size."""
        x, y, w, h = bbox
        W, H = image.size
        pad_w = int(w * pad); pad_h = int(h * pad)
        x1 = max(0, x - pad_w); y1 = max(0, y - pad_h)
        x2 = min(W, x + w + pad_w); y2 = min(H, y + h + pad_h)
        size = (self.crop_size, self.crop_size)
        if x2 <= x1 or y2 <= y1:
            # The box lies outside the image; an empty crop cannot be resized,
            # so fall back to the whole image.
            return image.resize(size)
        return image.crop((x1, y1, x2, y2)).resize(size)

    def _random_neg_crop(self, image, avoid_bbox=None, max_overlap=0.3, attempts=10):
        """Return a random crop of ``image`` resized to the crop size.

        Args:
            image: Source image (needs ``size``, ``crop`` and ``resize``).
            avoid_bbox: Optional ``[x, y, w, h]`` box the crop should not cover.
                When given, up to ``attempts`` crops are sampled until one
                contains at most ``max_overlap`` of that box; the last sample
                is used if none qualifies.
            max_overlap: Largest tolerated covered fraction of ``avoid_bbox``.
            attempts: Number of samples to try (at least one is always drawn).

        Images too small for the 40-pixel minimum crop side are returned whole
        (resized) instead of raising from ``random.randint``.
        """
        W, H = image.size
        size = (self.crop_size, self.crop_size)
        max_w = min(W // 2, 200)
        max_h = min(H // 2, 200)
        if max_w < 40 or max_h < 40:
            return image.resize(size)

        box = None
        for _ in range(max(1, attempts)):
            w = random.randint(40, max_w)
            h = random.randint(40, max_h)
            x = random.randint(0, max(0, W - w))
            y = random.randint(0, max(0, H - h))
            box = (x, y, x + w, y + h)
            if avoid_bbox is None or self._covered_fraction(avoid_bbox, box) <= max_overlap:
                break
        return image.crop(box).resize(size)

    def __getitem__(self, idx):
        """Return the processed (phrase, crops) sample, or ``None`` if the image cannot be opened.

        The processor output is returned as a dict; any tensor whose first
        dimension has size 1 has that dimension squeezed away.
        """
        item = self.items[idx]
        try:
            img = self._load_rgb(item['image_path'])
        except Exception as e:
            print(f"Could not open {item['image_path']}, returning None. Error: {e}")
            return None

        pos = self._crop_from_bbox(img, item['bbox'])
        negs = []
        for _ in range(self.num_neg):
            if random.random() < 0.5:
                negs.append(self._random_neg_crop(img, avoid_bbox=item['bbox']))
                continue

            other = random.choice(self.items)
            try:
                other_img = self._load_rgb(other['image_path'])
            except Exception as e:
                # A broken "other" image must not kill the sample; use the
                # current image for this negative instead.
                print(f"Could not open {other['image_path']} for a negative crop: {e}")
                negs.append(self._random_neg_crop(img, avoid_bbox=item['bbox']))
                continue
            # The random pick may be the very same image; keep avoiding the positive then.
            avoid = item['bbox'] if other['image_path'] == item['image_path'] else None
            negs.append(self._random_neg_crop(other_img, avoid_bbox=avoid))

        proc = self.processor(
            text=item['phrase'],
            images=[pos] + negs,
            return_tensors='pt',
            padding="max_length",
            max_length=77,
            truncation=True
        )
        return {k: v.squeeze(0) if v.dim() > 0 else v for k, v in proc.items()}

class GroundingDINODataset(Dataset):
    """Dataset for training Grounding DINO.

    Items are grouped by image, so one sample is one image together with all
    of its annotated boxes and phrases.
    """
    def __init__(self, items):
        """
        Args:
            items: List of dicts with ``image_path``, ``phrase`` and ``bbox``
                (``[x, y, w, h]`` in pixels).
        """
        self.items = items
        self.image_groups = {}
        for item in items:
            self.image_groups.setdefault(item['image_path'], []).append(item)
        self.image_list = list(self.image_groups.keys())

    def __len__(self):
        return len(self.image_list)

    @staticmethod
    def _to_normalized_cxcywh(bbox, W, H, min_side=1e-4):
        """Convert a pixel ``[x, y, w, h]`` box to normalized ``[cx, cy, w, h]``.

        The box corners are clipped to the image first, so the returned box
        always lies inside ``[0, 1]`` on both axes. Clamping the centre and the
        size independently would leave boxes that still stick out of the image.

        Returns ``None`` for missing boxes, non-positive sizes, or boxes whose
        visible part is narrower or shorter than ``min_side`` of the image.
        """
        if bbox is None or W <= 0 or H <= 0:
            return None
        x, y, w, h = bbox
        if w <= 0 or h <= 0:
            return None

        x1 = min(max(x, 0), W)
        y1 = min(max(y, 0), H)
        x2 = min(max(x + w, 0), W)
        y2 = min(max(y + h, 0), H)

        norm_w = (x2 - x1) / W
        norm_h = (y2 - y1) / H
        if norm_w < min_side or norm_h < min_side:
            return None

        return [(x1 + x2) / (2 * W), (y1 + y2) / (2 * H), norm_w, norm_h]

    def __getitem__(self, idx):
        """Return one image sample, or ``None`` when it is unusable.

        The returned dict holds the RGB ``image``, the ``text`` prompt (sorted
        unique phrases joined by ``". "`` with a trailing dot), the ``boxes``
        tensor of normalized ``[cx, cy, w, h]`` rows, the per-box ``phrases``
        and the sorted ``unique_phrases``. ``None`` is returned when the image
        cannot be opened or no valid box remains.
        """
        img_path = self.image_list[idx]
        items_for_image = self.image_groups[img_path]

        try:
            # Context manager closes the file handle once the pixels are converted.
            with Image.open(img_path) as raw:
                image = raw.convert('RGB')
        except Exception as e:
            print(f"Could not open {img_path}, returning None. Error: {e}")
            return None

        W, H = image.size

        boxes = []
        phrases_for_valid_boxes = []
        for item in items_for_image:
            box = self._to_normalized_cxcywh(item['bbox'], W, H)
            if box is None:
                continue
            boxes.append(box)
            phrases_for_valid_boxes.append(item['phrase'])

        if not boxes:
            return None

        unique_phrases = sorted(set(phrases_for_valid_boxes))
        text_prompt = ". ".join(unique_phrases) + "."

        return {
            "image": image,
            "text": text_prompt,
            "boxes": torch.tensor(boxes, dtype=torch.float),
            "phrases": phrases_for_valid_boxes,
            "unique_phrases": unique_phrases
        }



def collate_fn(batch):
    """Collate a batch with PyTorch's default collation, ignoring ``None`` samples.

    Returns ``None`` when no sample is left after the ``None`` entries are removed.
    """
    kept = list(filter(lambda sample: sample is not None, batch))
    if len(kept) == 0:
        return None
    return torch.utils.data.dataloader.default_collate(kept)

def grounding_dino_collate_fn(batch, processor):
    """Turn a list of Grounding DINO samples into one processor-encoded batch.

    ``None`` samples are dropped; ``None`` is returned if nothing remains.
    Images and prompts go through ``processor``. The per-sample boxes, phrases,
    prompts and unique phrases are attached to its output unchanged, as lists,
    under ``target_boxes``, ``target_phrases``, ``text`` and ``unique_phrases``.
    """
    samples = [entry for entry in batch if entry is not None]
    if len(samples) == 0:
        return None

    def column(key):
        # One value per remaining sample, in batch order.
        return [entry[key] for entry in samples]

    pictures = column('image')
    prompts = column('text')
    boxes_per_sample = column('boxes')
    phrases_per_sample = column('phrases')
    vocab_per_sample = column('unique_phrases')

    encoded = processor(
        images=pictures,
        text=prompts,
        return_tensors="pt",
        padding=True,
        truncation=True
    )

    encoded['target_boxes'] = boxes_per_sample
    encoded['target_phrases'] = phrases_per_sample
    encoded['text'] = prompts
    encoded['unique_phrases'] = vocab_per_sample
    return encoded

print(' Dataset classes ready.')

# CELL 5: Model Classes and Training Functions

In [None]:

print('\n [5/6] Setting up models and training functions...')

class ReRanker(nn.Module):
    """CLIP-based re-ranker: CLIP features followed by small projection heads.

    Both the image and the text features are passed through their own
    two-layer MLP and L2-normalised, so downstream code can compare them with
    a dot product.
    """

    @staticmethod
    def _projection_head(in_dim, out_dim):
        # Linear -> ReLU -> Linear; layer positions (0, 1, 2) define the
        # state-dict keys, so the layout must stay as is.
        return nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.ReLU(),
            nn.Linear(out_dim, out_dim)
        )

    def __init__(self, clip_model, proj_dim=256, train_image_encoder=False):
        """
        Args:
            clip_model: Module exposing ``visual_projection``/``text_projection``
                layers and ``get_image_features``/``get_text_features``.
            proj_dim: Output width of both projection heads.
            train_image_encoder: When False, every parameter of ``clip_model``
                is frozen and only the heads are trainable.
        """
        super().__init__()
        self.clip = clip_model
        self.train_image_encoder = train_image_encoder
        # Image head is built before the text head (keeps parameter order and
        # random initialisation sequence stable).
        self.img_proj = self._projection_head(self.clip.visual_projection.out_features, proj_dim)
        self.txt_proj = self._projection_head(self.clip.text_projection.out_features, proj_dim)
        if not train_image_encoder:
            self.clip.requires_grad_(False)

    @staticmethod
    def _unit_norm(vectors):
        # Scale every row to unit L2 length along the last dimension.
        return vectors / vectors.norm(dim=-1, keepdim=True)

    def forward(self, input_ids, attention_mask, pixel_values):
        """Return ``(text_feats, img_feats)``, both projected and L2-normalised."""
        text_raw = self.clip.get_text_features(input_ids=input_ids, attention_mask=attention_mask)
        text_feats = self._unit_norm(self.txt_proj(text_raw))

        image_raw = self.clip.get_image_features(pixel_values=pixel_values)
        img_feats = self._unit_norm(self.img_proj(image_raw))

        return text_feats, img_feats

def compute_contrastive_loss(text_feats, img_feats, K=3, temp=0.07):
    """Cross-entropy over each text's ``1 + K`` candidate crops; index 0 is the positive.

    Args:
        text_feats: One feature row per sample (``B`` rows).
        img_feats: ``B * (1 + K)`` feature rows, grouped per sample with the
            positive crop first; reshaped here to ``(B, 1 + K, d)``.
        K: Number of negatives per sample.
        temp: Temperature the similarities are divided by.

    Returns:
        ``(loss, acc)`` where ``loss`` is a scalar tensor and ``acc`` is the
        fraction of samples (a Python float) whose best-scoring crop is the
        positive one.
    """
    batch_size = text_feats.shape[0]
    feat_dim = img_feats.shape[-1]
    candidates = img_feats.view(batch_size, 1 + K, feat_dim)

    # Dot product between every text and each of its own candidates.
    logits = torch.einsum('bd,bnd->bn', text_feats, candidates) / temp

    # The positive crop always sits in slot 0.
    target = logits.new_zeros(batch_size, dtype=torch.long)
    loss = nn.functional.cross_entropy(logits, target)

    hits = logits.argmax(dim=1) == target
    acc = hits.float().mean().item()
    return loss, acc

def _run_clip_reranker_epoch(reranker, loader, desc, optimizer=None):
    """Run one pass over ``loader``; train when ``optimizer`` is given, else evaluate.

    Returns the mean loss and mean accuracy over the batches that were
    processed (``None`` batches are skipped; both means are 0 if none ran).
    """
    training = optimizer is not None
    reranker.train(training)
    total_loss, total_acc, steps = 0.0, 0.0, 0

    with torch.set_grad_enabled(training):
        for batch in tqdm(loader, desc=desc):
            if batch is None:
                continue
            input_ids = batch['input_ids'].to(DEVICE)
            attention_mask = batch['attention_mask'].to(DEVICE)
            pv = batch['pixel_values'].to(DEVICE)

            # The loop expects (batch, crops per sample, C, H, W); the image
            # encoder wants the first two dimensions flattened together.
            B, n_images, c, h, w = pv.shape
            txt_feats, img_feats = reranker(input_ids, attention_mask, pv.reshape(B * n_images, c, h, w))

            # Derive K from the batch itself so it can never disagree with the dataset.
            loss, acc = compute_contrastive_loss(txt_feats, img_feats, K=n_images - 1)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            total_loss += loss.item()
            total_acc += acc
            steps += 1

    return total_loss / max(1, steps), total_acc / max(1, steps)


def train_clip_reranker(train_items, val_items, epochs=10):
    """Train the projection heads of a CLIP-based re-ranker.

    Args:
        train_items: Annotated items used for training.
        val_items: Annotated items used for validation.
        epochs: Number of passes over the training set.

    Returns:
        ``(reranker, processor)`` as they are after the final epoch.

    Side effects:
        Whenever the validation loss improves, the model state dict is written
        to ``OUT_DIR/best_clip_reranker.pt`` and the processor is pickled to
        ``OUT_DIR/clip_processor.pkl``.
    """
    import pickle  # only needed for the processor dump below

    print('\n🔥 Training CLIP Re-ranker...')
    processor = CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32')
    clip = CLIPModel.from_pretrained('openai/clip-vit-base-patch32').to(DEVICE)

    num_negatives = 3
    # The dataset class is CLIPReRankDataset; the previous spelling
    # (CLIPReRankerDataset) raised NameError before training could start.
    train_ds = CLIPReRankDataset(train_items, processor, num_negatives=num_negatives)
    val_ds = CLIPReRankDataset(val_items, processor, num_negatives=num_negatives)
    train_loader = DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=2, collate_fn=collate_fn)
    val_loader = DataLoader(val_ds, batch_size=16, shuffle=False, num_workers=2, collate_fn=collate_fn)

    reranker = ReRanker(clip, proj_dim=256, train_image_encoder=False).to(DEVICE)
    params = [p for p in reranker.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(params, lr=5e-5, weight_decay=1e-4)
    best_val_loss = float('inf')

    for epoch in range(epochs):
        tloss, tacc = _run_clip_reranker_epoch(
            reranker, train_loader, f'CLIP Train {epoch+1}/{epochs}', optimizer=optimizer
        )
        vloss, vacc = _run_clip_reranker_epoch(
            reranker, val_loader, f'CLIP Val {epoch+1}/{epochs}'
        )
        print(f'Epoch {epoch+1}: train_loss={tloss:.4f} train_acc={tacc:.3f} | val_loss={vloss:.4f} val_acc={vacc:.3f}')

        if vloss < best_val_loss:
            best_val_loss = vloss
            torch.save(reranker.state_dict(), os.path.join(OUT_DIR, 'best_clip_reranker.pt'))
            # NOTE: a pickle file must only ever be loaded from a trusted source.
            with open(os.path.join(OUT_DIR, 'clip_processor.pkl'), 'wb') as f:
                pickle.dump(processor, f)
            print(f"   -> New best model saved with validation loss: {vloss:.4f}")

    print('✅ CLIP Re-ranker training completed!')
    return reranker, processor

def _dino_model_inputs(batch):
    """Build the keyword arguments for the Grounding DINO forward pass.

    Class labels are the index of each target phrase inside that sample's
    ``unique_phrases`` list. ``pixel_mask`` and ``token_type_ids`` are
    forwarded when the processor produced them, so padding added while
    batching images of different sizes is not treated as image content.
    """
    labels = []
    for boxes, phrases, unique_phrases in zip(
        batch['target_boxes'], batch['target_phrases'], batch['unique_phrases']
    ):
        phrase_to_index = {phrase: idx for idx, phrase in enumerate(unique_phrases)}
        labels.append({
            "boxes": boxes.to(DEVICE),
            "class_labels": torch.tensor(
                [phrase_to_index[phrase] for phrase in phrases], dtype=torch.long, device=DEVICE
            ),
        })

    inputs = {
        'pixel_values': batch['pixel_values'].to(DEVICE),
        'input_ids': batch['input_ids'].to(DEVICE),
        'attention_mask': batch['attention_mask'].to(DEVICE),
        'labels': labels,
    }
    for optional_key in ('pixel_mask', 'token_type_ids'):
        if optional_key in batch:
            inputs[optional_key] = batch[optional_key].to(DEVICE)
    return inputs


def _describe_dino_batch(batch_index, batch):
    """Print a summary of a failing batch using only the host-side copies.

    The tensors in ``target_boxes`` are the ones produced by the dataset, not
    the copies moved to ``DEVICE``, so they stay readable after a device error.
    """
    print(f"RUNTIME ERROR CAUGHT in Batch {batch_index}. Contents of the failing batch:")
    for j, (boxes, phrases, unique_phrases) in enumerate(zip(
        batch['target_boxes'], batch['target_phrases'], batch['unique_phrases']
    )):
        lo, hi = boxes.min().item(), boxes.max().item()
        print(f"  Item {j}:")
        print(f"    - Unique Phrases: {unique_phrases}")
        print(f"    - Num Classes: {len(unique_phrases)}")
        print(f"    - Box tensor shape: {tuple(boxes.shape)}")
        print(f"    - Box values min/max: {lo:.4f} / {hi:.4f}")
        print(f"    - Num target phrases: {len(phrases)}")
        if lo < 0.0 or hi > 1.0:
            print("    -  ERROR: Box coordinates are out of [0, 1] bounds!")


def _run_dino_epoch(model, loader, desc, optimizer=None):
    """Run one pass over ``loader``; train when ``optimizer`` is given, else evaluate.

    Returns ``(mean_loss, steps)`` over the batches that produced a loss.
    """
    training = optimizer is not None
    model.train(training)
    total_loss, steps = 0.0, 0

    with torch.set_grad_enabled(training):
        for i, batch in enumerate(tqdm(loader, desc=desc)):
            if batch is None:
                continue
            inputs = _dino_model_inputs(batch)
            try:
                outputs = model(**inputs)
            except RuntimeError:
                # Only dump diagnostics for the batch that actually failed,
                # then re-raise with the original traceback.
                _describe_dino_batch(i, batch)
                raise

            loss = outputs.loss
            if loss is None:
                continue
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
            steps += 1

    return total_loss / max(1, steps), steps


def train_grounding_dino(train_items, val_items, epochs=5):
    """Fine-tune Grounding DINO model.

    Only parameters whose name contains ``bbox_embed`` or ``class_embed`` are
    trained; everything else is frozen.

    Args:
        train_items: Annotated items used for training.
        val_items: Annotated items used for validation.
        epochs: Number of passes over the training set.

    Returns:
        ``(model, processor)`` as they are after the final epoch.

    Side effects:
        After each epoch the validation loss is computed (the training loss is
        used instead when no validation batch produced a loss). Whenever it
        improves, the model and processor are saved with ``save_pretrained``
        to ``OUT_DIR/best_grounding_dino``.
    """
    print('\n🔥 Training Grounding DINO...')
    model_id = "IDEA-Research/grounding-dino-base"
    processor = AutoProcessor.from_pretrained(model_id)
    model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(DEVICE)

    train_ds = GroundingDINODataset(train_items)
    val_ds = GroundingDINODataset(val_items)

    collate_with_processor = partial(grounding_dino_collate_fn, processor=processor)

    train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=0, collate_fn=collate_with_processor)
    val_loader = DataLoader(val_ds, batch_size=2, shuffle=False, num_workers=0, collate_fn=collate_with_processor)

    for name, param in model.named_parameters():
        param.requires_grad = 'bbox_embed' in name or 'class_embed' in name

    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=2e-5, weight_decay=1e-4)
    best_val_loss = float('inf')
    save_dir = os.path.join(OUT_DIR, 'best_grounding_dino')

    for epoch in range(epochs):
        tloss, tsteps = _run_dino_epoch(
            model, train_loader, f'DINO Train {epoch+1}/{epochs}', optimizer=optimizer
        )
        vloss, vsteps = _run_dino_epoch(model, val_loader, f'DINO Val {epoch+1}/{epochs}')
        print(f'Epoch {epoch+1}: train_loss={tloss:.4f} | val_loss={vloss:.4f}')

        # Without any usable validation batch, fall back to the training loss
        # so a checkpoint is still produced.
        score = vloss if vsteps else tloss
        if (vsteps or tsteps) and score < best_val_loss:
            best_val_loss = score
            model.save_pretrained(save_dir)
            processor.save_pretrained(save_dir)
            print(f"   -> New best model saved with loss: {score:.4f}")

    print(' Grounding DINO training completed!')
    return model, processor

print('Training functions ready.')

# CELL 6: Data Preparation

In [None]:

print('\n Starting Data Preparation...')

# Step 1: collect every usable annotated object.
print('\n Parsing annotations...')
all_items = parse_and_load_annotations()

if not all_items:
    raise RuntimeError('No usable annotations found. Check your DRIVE_ROOT and folder structure!')

print(f'Found {len(all_items)} annotated object instances.')

# Step 2: keep a JSON snapshot of what was parsed next to the checkpoints.
with open(os.path.join(OUT_DIR, 'parsed_annotations.json'), 'w') as f:
    json.dump(all_items, f, indent=2)

# Step 3: bucket items per image so the split below is made on whole images
# and all boxes of one image end up on the same side.
image_to_items = {}
for it in all_items:
    if it['image_path'] not in image_to_items:
        image_to_items[it['image_path']] = []
    image_to_items[it['image_path']].append(it)

image_paths = list(image_to_items)
train_imgs, val_imgs = train_test_split(image_paths, test_size=0.15, random_state=42)

# Step 4: expand the image-level split back into item lists.
train_items = []
for im in train_imgs:
    train_items.extend(image_to_items[im])

val_items = []
for im in val_imgs:
    val_items.extend(image_to_items[im])

print(f'Dataset split: {len(train_items)} training items, {len(val_items)} validation items.')
print('Data preparation complete.')




# CELL 7: Train CLIP Re-ranker

In [None]:

# Train the CLIP re-ranker for 8 epochs on the split built in the data
# preparation cell; the function itself writes its best checkpoint to OUT_DIR.
# NOTE(review): the returned (reranker, processor) pair is not bound to a name.
train_clip_reranker(train_items, val_items, epochs=8)

# CELL 8: Train Grounding DINO

In [None]:

# Fine-tune Grounding DINO for 5 epochs on the same train/validation split.
# NOTE(review): the returned (model, processor) pair is not bound to a name.
train_grounding_dino(train_items, val_items, epochs=5)