<a href="https://colab.research.google.com/github/RArunn/Intent-Identification-Detection/blob/main/vcoco.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Complete Fast HOI Detection Code for Colab Pro A100
import os
import json
import torch
import gc
import zipfile
import urllib.request
from PIL import Image
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from tqdm import tqdm
import re
from concurrent.futures import ThreadPoolExecutor
import numpy as np
from torch.utils.data import Dataset, DataLoader
import warnings
warnings.filterwarnings("ignore")

# A100 optimization
# Request the "expandable_segments" mode of PyTorch's CUDA caching allocator.
# NOTE(review): torch is already imported at this point; presumably the
# allocator reads this variable lazily on first CUDA use - confirm.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# Permit TF32 in CUDA matmul kernels and in cuDNN operations.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

In [None]:
# Cell 1: Dataset Download
def download_coco_val2014(url="http://images.cocodataset.org/zips/val2014.zip"):
    """Download and unpack the COCO val2014 images unless already present.

    The archive is fetched to ``val2014.zip`` in the current working
    directory, extracted there, and the zip is removed afterwards (also when
    the download or the extraction fails).

    Args:
        url: Location of the ``val2014.zip`` archive. Defaults to the URL the
            notebook has always used.

    Returns:
        The relative directory name ``"val2014"``.

    Raises:
        urllib.error.URLError: If the archive cannot be downloaded.
        zipfile.BadZipFile: If the downloaded file is not a valid zip.
        FileNotFoundError: If the archive did not produce ``val2014/``.
    """
    val_dir = "val2014"
    if os.path.isdir(val_dir):
        return val_dir

    zip_path = "val2014.zip"
    try:
        # urlretrieve raises on HTTP/network errors and on short reads, unlike
        # os.system("wget ..."), whose exit status was never checked.
        urllib.request.urlretrieve(url, zip_path)

        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall()
    finally:
        # Never leave a (possibly truncated) multi-gigabyte archive behind.
        if os.path.exists(zip_path):
            os.remove(zip_path)

    if not os.path.isdir(val_dir):
        raise FileNotFoundError(
            f"Archive from {url} did not contain a '{val_dir}' directory"
        )
    return val_dir

# Download dataset (the download is skipped when the val2014 directory
# already exists); runs as soon as this cell is executed.
val_dir = download_coco_val2014()

In [None]:
# Cell 2: Fast Model Initialization
def init_model_fast():
    """Load the Qwen2.5-VL-3B-Instruct model and its processor.

    Frees cached Python/CUDA memory first, loads the weights in bfloat16 with
    ``device_map="auto"``, wraps the model with ``torch.compile``, and
    configures the processor's tokenizer for batched generation.

    Returns:
        tuple: ``(model, processor)``.
    """
    gc.collect()
    torch.cuda.empty_cache()

    model_path = "Qwen/Qwen2.5-VL-3B-Instruct"

    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True
    )

    # Compile for speed
    # NOTE(review): generate() is looked up on the wrapped module, so the
    # compiled forward may not be the one used during generation - confirm
    # that this actually yields a speed-up.
    model = torch.compile(model, mode="reduce-overhead")

    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)

    # Prompts are padded to a common length and generated as a batch. A
    # decoder-only model continues from the last position of each row, so the
    # padding has to be on the left; right padding would make shorter rows
    # continue from pad tokens.
    processor.tokenizer.padding_side = "left"
    return model, processor

In [None]:
# Cell 3: Prompt
# Instruction text sent together with every image by fast_batch_inference.
# The line format requested below ("Person at [x,y,w,h] doing ACTION with
# OBJECT at [x,y,w,h]") is what the regexes in extract_interactions_fast
# parse, so keep the two in sync when editing either one.
# NOTE(review): the prompt never asks for a confidence value, although the
# parser also accepts a trailing "confidence <number>"; with this prompt the
# parser's fallback confidence is presumably what gets used - confirm.
fast_prompt = """Analyze this image for human-object interactions. For each interaction you find, provide:

1. What the person is doing
2. What object they're interacting with
3. Location of the person as [x, y, width, height]
4. Location of the object as [x, y, width, height]

Format your response as:
Person at [x,y,w,h] doing ACTION with OBJECT at [x,y,w,h]

Only include clear, visible interactions. Maximum 10 interactions per image."""

In [None]:
# Cell 4:
class FastCOCODataset(Dataset):
    """Dataset over the ``.jpg`` files of a COCO val2014 image directory.

    Each item is a dict with the keys ``image`` (RGB PIL image, downsized so
    its longest side is at most 448 px), ``image_id`` (integer parsed from
    the file name), ``original_size`` and ``resized_size`` (``(width,
    height)`` before/after downsizing) and ``filename``.

    Files that cannot be read, or whose name does not yield an integer id,
    produce a placeholder item with ``image_id == -1`` instead of raising.
    """

    def __init__(self, image_dir, max_images=None):
        """Collect the sorted ``.jpg`` file names found in *image_dir*.

        Args:
            image_dir: Directory containing the images.
            max_images: Keep only the first ``max_images`` files (after
                sorting); ``None`` keeps all of them.
        """
        self.image_dir = image_dir
        self.image_files = sorted(
            f for f in os.listdir(image_dir) if f.endswith('.jpg')
        )

        # Compare against None so that max_images=0 means "no images" rather
        # than silently meaning "all images".
        if max_images is not None:
            self.image_files = self.image_files[:max_images]

    def __len__(self):
        """Return the number of image files in the dataset."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load, convert and downsize the image at position *idx*."""
        img_file = self.image_files[idx]
        image_path = os.path.join(self.image_dir, img_file)

        try:
            # Parsed inside the try so that an unexpected file name yields
            # the placeholder instead of an exception.
            image_id = int(img_file.replace('COCO_val2014_', '').replace('.jpg', ''))

            # convert() returns a new, fully loaded image, so the file handle
            # can be released as soon as the conversion is done.
            with Image.open(image_path) as raw_image:
                image = raw_image.convert('RGB')
            original_size = image.size

            # Smaller size for speed
            max_size = 448
            if max(original_size) > max_size:
                image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)

            return {
                'image': image,
                'image_id': image_id,
                'original_size': original_size,
                'resized_size': image.size,
                'filename': img_file
            }
        except Exception as exc:
            # Best effort: report the problem and hand back a placeholder that
            # downstream code recognises by image_id == -1.
            print(f"[FastCOCODataset] skipping {img_file}: {exc!r}")
            return {
                'image': Image.new('RGB', (224, 224)),
                'image_id': -1,
                'original_size': (224, 224),
                'resized_size': (224, 224),
                'filename': img_file
            }

In [None]:
# Cell 5: Extract Real Coordinates from Qwen Response (ROBUST + DYNAMIC CONFIDENCE)
# One interaction line: "person at [x,y,w,h] (doing) ACTION (with) OBJECT at
# [x,y,w,h]" optionally followed by "confidence <number>". Groups 1-4 are the
# person box, 5 the action, 6 the object name, 7-10 the object box and 11 the
# optional confidence. Matched against lower-cased text.
_HOI_LINE_RE = re.compile(
    r'person\s+at\s+\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)\]\s+'
    r'(?:doing\s+)?(\w+)\s+(?:with\s+)?(\w+)\s+'
    r'at\s+\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)\]'
    r'(?:\s+confidence\s+([0-9]*\.?[0-9]+))?'
)

# Object words too generic to count as a real detection.
_GENERIC_OBJECTS = frozenset({'object', 'objects', 'thing', 'item', 'stuff', 'it'})

# Confidence assigned to a line that does not state one.
_DEFAULT_CONFIDENCE = 0.7


def _xywh_to_clipped_xyxy(x, y, w, h, scale_x, scale_y, width, height):
    """Scale an [x, y, w, h] box, convert to [x1, y1, x2, y2], clip to the image."""
    corners = (
        int(x * scale_x),
        int(y * scale_y),
        int((x + w) * scale_x),
        int((y + h) * scale_y),
    )
    limits = (width, height, width, height)
    return [max(0, min(limit, value)) for value, limit in zip(corners, limits)]


def extract_interactions_fast(text, image_data):
    """Parse human-object interactions out of a model response.

    Every line of the form ``Person at [x,y,w,h] doing ACTION with OBJECT at
    [x,y,w,h]`` (case-insensitive, optionally followed by ``confidence
    <number>``) becomes one result. Confidence is decided per line: a stated
    value is clamped to ``[0.1, 1.0]``, a missing one defaults to 0.7. Lines
    with and without a confidence may therefore be mixed in one response
    without any of them being dropped.

    Args:
        text: Raw text generated by the model.
        image_data: Dict with ``image_id``, ``original_size`` and
            ``resized_size`` (both ``(width, height)``). Boxes in *text* are
            interpreted in the resized image and scaled to the original size.

    Returns:
        A list of dicts, one per accepted interaction, with the keys
        ``image_id``, ``person_box`` (``[x1, y1, x2, y2]``),
        ``"<action>_agent"`` (confidence) and ``"<action>_<object>"``
        (``[x1, y1, x2, y2, confidence]``). Interactions with a generic
        object word, a non-positive width/height, or a box that is empty
        after clipping are skipped.
    """
    matches = list(_HOI_LINE_RE.finditer(text.lower()))
    if not matches:
        return []

    width, height = image_data['original_size']
    resized_width, resized_height = image_data['resized_size']
    scale_x = width / resized_width
    scale_y = height / resized_height

    results = []
    for match in matches:
        action = match.group(5)
        obj = match.group(6)

        # Smart object validation - keep quality, don't lose real interactions
        if len(obj) < 2 or obj in _GENERIC_OBJECTS:
            continue

        px, py, pw, ph = (int(value) for value in match.group(1, 2, 3, 4))
        ox, oy, ow, oh = (int(value) for value in match.group(7, 8, 9, 10))

        # The pattern only admits non-negative integers, so only degenerate
        # (zero) extents remain to be rejected.
        if pw <= 0 or ph <= 0 or ow <= 0 or oh <= 0:
            continue

        stated_confidence = match.group(11)
        if stated_confidence is None:
            confidence = _DEFAULT_CONFIDENCE
        else:
            confidence = max(0.1, min(1.0, float(stated_confidence)))

        person_box = _xywh_to_clipped_xyxy(
            px, py, pw, ph, scale_x, scale_y, width, height)
        object_box = _xywh_to_clipped_xyxy(
            ox, oy, ow, oh, scale_x, scale_y, width, height)

        # Skip boxes that collapsed to zero area after scaling/clipping.
        if (person_box[2] <= person_box[0] or person_box[3] <= person_box[1] or
                object_box[2] <= object_box[0] or object_box[3] <= object_box[1]):
            continue

        results.append({
            'image_id': image_data['image_id'],
            'person_box': person_box,
            f'{action}_agent': confidence,
            f'{action}_{obj}': object_box + [confidence]
        })

    return results

In [None]:
# Cell 6: Fast Batch Inference
def fast_batch_inference(batch_data, model, processor):
    """Run one generation pass over a batch and parse the interactions.

    Items whose ``image_id`` is ``-1`` (the placeholder for unreadable
    images) are dropped before inference.

    Args:
        batch_data: List of item dicts carrying at least ``image`` and
            ``image_id``; valid items are also passed on to
            ``extract_interactions_fast``.
        model: Generation model exposing ``generate`` and ``device``.
        processor: Processor providing ``apply_chat_template``, the call
            interface, ``batch_decode`` and ``tokenizer``.

    Returns:
        A flat list with the interaction dicts of every image in the batch.
        An empty list is returned when the batch has no valid image or when
        inference fails; a failure is reported on stdout rather than raised,
        so a single bad batch does not abort a long run.
    """
    valid_batch = [item for item in batch_data if item['image_id'] != -1]
    if not valid_batch:
        return []

    images = [item['image'] for item in valid_batch]

    # One single-turn chat per image: the image followed by the shared prompt.
    messages_list = [
        [{
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": fast_prompt}
            ]
        }]
        for image in images
    ]

    # Pre-bind so the cleanup in `finally` is valid wherever a failure occurs.
    inputs = output_ids = generated_ids = None
    try:
        texts = [processor.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
                 for msgs in messages_list]

        inputs = processor(
            text=texts,
            images=images,
            return_tensors="pt",
            padding=True,
            max_length=1024
        ).to(model.device)

        # Fast generation
        with torch.inference_mode(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
            output_ids = model.generate(
                **inputs,
                max_new_tokens=150,  # Increased for more detailed responses
                do_sample=False,     # Greedy for speed
                use_cache=True,
                pad_token_id=processor.tokenizer.eos_token_id
            )

        # Keep only the newly generated tokens (drop the prompt part).
        generated_ids = output_ids[:, inputs.input_ids.shape[1]:]
        responses = processor.batch_decode(generated_ids, skip_special_tokens=True)

        # Extract interactions from all responses
        all_results = []
        for data, response in zip(valid_batch, responses):
            all_results.extend(extract_interactions_fast(response, data))
        return all_results

    except Exception as exc:
        # Previously swallowed silently, which made a systematically failing
        # run indistinguishable from "no interactions found".
        failed_ids = [item['image_id'] for item in valid_batch]
        print(f"[fast_batch_inference] batch failed for image_ids {failed_ids}: {exc!r}")
        return []

    finally:
        # Release the batch tensors on the success and the failure path alike.
        del inputs, output_ids, generated_ids
        torch.cuda.empty_cache()

In [None]:
# Cell 7: Fast Processing Function (Fixed to group by image)
# Matches a JSON array made up solely of non-negative numbers, including the
# newlines/indentation that json.dumps(indent=...) puts between its elements.
_NUMERIC_ARRAY_RE = re.compile(r'\[\s*(\d+(?:\.\d+)?(?:\s*,\s*\d+(?:\.\d+)?)*)\s*\]')


def _save_compact_json(data, filename):
    """Write *data* as indented JSON with every all-numeric array on one line."""
    json_str = json.dumps(data, indent=2)
    json_str = _NUMERIC_ARRAY_RE.sub(
        lambda m: '[' + re.sub(r'\s*,\s*', ', ', m.group(1)) + ']',
        json_str,
    )
    with open(filename, 'w') as f:
        f.write(json_str)


def process_coco_fast(model, processor, output_file="coco_hoi_fast.json",
                      batch_size=16, max_images=None, save_interval=1000,
                      image_dir='val2014'):
    """Run HOI detection over an image directory and save the results.

    Images are loaded through ``FastCOCODataset`` and a ``DataLoader`` and
    processed batch by batch with ``fast_batch_inference``. Detections are
    grouped per image; the final file lists, in ascending ``image_id`` order,
    all detections of each readable image, or a ``{'image_id': ...,
    'detections': []}`` entry for an image without any.

    Args:
        model: Model passed through to ``fast_batch_inference``.
        processor: Processor passed through to ``fast_batch_inference``.
        output_file: Path of the final JSON file; checkpoints are written to
            ``"<output_file>.checkpoint_<processed_count>.json"``.
        batch_size: Number of images per batch.
        max_images: Optional limit on the number of images.
        save_interval: Write a checkpoint each time at least this many more
            images have been processed; ``None`` or a non-positive value
            disables checkpoints.
        image_dir: Directory holding the images (defaults to ``'val2014'``).

    Returns:
        The list stored under ``'results'`` in the output file.
    """
    dataset = FastCOCODataset(image_dir, max_images)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=8,      # More workers for speed
        pin_memory=True,
        collate_fn=lambda x: x,
        persistent_workers=True
    )

    # Group detections by image
    detections_by_image = {}  # {image_id: [detections]}
    processed_images = []
    processed_count = 0

    # Checkpoints are scheduled by threshold. Testing
    # `processed_count % save_interval == 0` only fires when the running
    # count lands exactly on a multiple, which rarely happens when batch_size
    # does not divide save_interval (e.g. 28 and 2000).
    checkpoints_enabled = bool(save_interval) and save_interval > 0
    next_checkpoint = save_interval if checkpoints_enabled else None

    for batch_idx, batch_data in enumerate(tqdm(dataloader, desc="Fast processing")):
        try:
            # Fast batch inference
            batch_results = fast_batch_inference(batch_data, model, processor)

            # Group results by image_id
            for result in batch_results:
                detections_by_image.setdefault(result['image_id'], []).append(result)

            # Track processed images (placeholders for unreadable files excluded)
            processed_images.extend(
                {'image_id': data['image_id'], 'filename': data['filename']}
                for data in batch_data
                if data['image_id'] != -1
            )

            processed_count += len(batch_data)

            # Clean up
            for data in batch_data:
                if hasattr(data['image'], 'close'):
                    data['image'].close()

            # Save checkpoint
            if checkpoints_enabled and processed_count >= next_checkpoint:
                checkpoint_results = []
                for img_id in sorted(detections_by_image):
                    checkpoint_results.extend(detections_by_image[img_id])

                checkpoint_data = {
                    'results': checkpoint_results,
                    'processed_count': processed_count,
                    'detection_count': len(checkpoint_results)
                }
                _save_compact_json(
                    checkpoint_data,
                    f"{output_file}.checkpoint_{processed_count}.json",
                )

                # Advance past the current count so each interval triggers at
                # most one checkpoint.
                while next_checkpoint <= processed_count:
                    next_checkpoint += save_interval

            # Memory cleanup
            if batch_idx % 50 == 0:
                gc.collect()
                torch.cuda.empty_cache()

        except Exception as exc:
            # Keep going on a bad batch, but leave a trace of what went wrong.
            print(f"[process_coco_fast] batch {batch_idx} failed: {exc!r}")
            continue

    # Create final output - all images in order (with and without detections)
    final_results = []
    for img_info in sorted(processed_images, key=lambda x: x['image_id']):
        img_id = img_info['image_id']

        if img_id in detections_by_image:
            # Add all detections for this image
            final_results.extend(detections_by_image[img_id])
        else:
            # Add entry for image with no detections
            final_results.append({
                'image_id': img_id,
                'detections': []  # Empty list indicates no detections
            })

    final_output = {
        'results': final_results,  # All images in order by image_id
        'metadata': {
            'total_images_processed': processed_count,
            'total_detections': sum(1 for r in final_results if 'person_box' in r),
            'images_with_detections': len(detections_by_image),
            'images_without_detections': processed_count - len(detections_by_image),
            'detection_rate': f"{len(detections_by_image)/processed_count*100:.1f}%" if processed_count > 0 else "0%"
        }
    }

    # Save final results with arrays on one line
    _save_compact_json(final_output, output_file)

    return final_results

In [None]:
# Cell 8: Quick Test
def run_fast_test():
    """Run the pipeline on the first 50 images as a smoke test.

    Loads the model, processes 50 images in batches of 16 with a checkpoint
    interval of 25, writes ``fast_test_results.json`` and prints the first
    three result entries for a quick visual check.

    Returns:
        The result list returned by ``process_coco_fast``.
    """
    model, processor = init_model_fast()

    results = process_coco_fast(
        model,
        processor,
        output_file="fast_test_results.json",
        batch_size=16,
        max_images=50,
        save_interval=25
    )

    # The JSON preview used to be built and then discarded; print it so the
    # quick test actually shows something.
    for result in results[:3]:
        print(json.dumps(result, indent=2))

    return results

In [None]:
# Cell 9: Full Fast Processing
def run_full_fast():
    """Process the whole dataset and print a short summary.

    Loads the model, runs ``process_coco_fast`` on all images with a batch
    size of 28 and a checkpoint interval of 2000, writes
    ``coco_hoi_full_fast.json`` and prints the most frequent actions and
    objects among the detections.

    Returns:
        The result list returned by ``process_coco_fast``.
    """
    from collections import Counter  # only needed for the summary below

    model, processor = init_model_fast()

    # A100 optimized settings
    batch_size = 28  # Large batch for A100

    results = process_coco_fast(
        model,
        processor,
        output_file="coco_hoi_full_fast.json",
        batch_size=batch_size,
        max_images=None,  # All images
        save_interval=2000
    )

    # Quick analysis
    actions = Counter()
    objects = Counter()
    agent_suffix = '_agent'
    for result in results:
        for key in result:
            if not key.endswith(agent_suffix):
                continue
            action = key[:-len(agent_suffix)]
            actions[action] += 1

            # The object entry is stored as "<action>_<object>". Anchoring on
            # the known action keeps the split correct even when the action
            # itself contains an underscore.
            object_prefix = f'{action}_'
            for other_key in result:
                if other_key != key and other_key.startswith(object_prefix):
                    objects[other_key[len(object_prefix):]] += 1

    # These tallies used to be computed and dropped; report them.
    print(f"Total detections: {sum(actions.values())}")
    print(f"Top actions: {actions.most_common(10)}")
    print(f"Top objects: {objects.most_common(10)}")

    return results

# Runs the full-dataset pipeline as soon as this cell is executed; comment
# out the call below to skip it.
full_results = run_full_fast()