<a href="https://colab.research.google.com/github/RArunn/Intent-Identification-Detection/blob/main/swighoi.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Cell 1: Imports and Setup
import os
import json
import zipfile
import torch
import gc
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import re
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
import urllib.request
import requests
import warnings
# Silence every Python warning for the rest of the session.
warnings.filterwarnings("ignore")

# Start from a clean slate: run the garbage collector, then ask PyTorch to
# release the CUDA memory held by its caching allocator.
gc.collect()
torch.cuda.empty_cache()

# Backend flags set once for the whole notebook: turn on the cuDNN benchmark
# setting and allow TF32 for CUDA matmuls and for cuDNN operations.
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

In [None]:
# Cell 2: Dataset Download
def download_dataset():
    """Make sure the ``images_512`` directory is available in the working directory.

    If the directory already exists it is returned immediately. Otherwise the
    archive is downloaded (``aria2c`` first, ``wget`` as fallback) and unpacked
    with ``unzip``; the archive is deleted after a successful extraction.

    The download is written under a temporary ``.part`` name and renamed only
    on success, so an interrupted or failed transfer is never mistaken for a
    complete archive on the next call. Likewise, a partially extracted image
    directory is removed when ``unzip`` fails, so it cannot be mistaken for a
    finished dataset.

    Returns:
        ``"images_512"`` on success, ``None`` if the download or the
        extraction failed.
    """
    import shutil  # local import: only needed to clean up a failed extraction

    img_dir = "images_512"
    archive = "images_512.zip"
    partial = archive + ".part"
    aria2_control = partial + ".aria2"
    url = "https://swig-data-weights.s3.us-east-2.amazonaws.com/images_512.zip"

    if os.path.exists(img_dir):
        return img_dir

    if not os.path.exists(archive):
        # Drop leftovers from an earlier aborted attempt so the downloader
        # starts from a known state.
        for stale in (partial, aria2_control):
            if os.path.exists(stale):
                os.remove(stale)

        result = os.system(
            f'aria2c -x 16 -s 16 -o {partial} "{url}" || wget -O {partial} "{url}"'
        )
        if result != 0 or not os.path.exists(partial):
            return None

        # Only a fully downloaded file ever gets the final archive name.
        os.replace(partial, archive)
        if os.path.exists(aria2_control):
            os.remove(aria2_control)

    result = os.system(f"unzip -q {archive}")

    if result == 0 and os.path.isdir(img_dir):
        os.remove(archive)
        return img_dir

    # img_dir did not exist when this call started, so anything there now is a
    # partial extraction produced by the failed unzip above.
    shutil.rmtree(img_dir, ignore_errors=True)
    return None

In [None]:
# Cell 3: Subset Extraction
def extract_image_subset():
    """Return the test image names up to and including ``skiing_169.jpg``.

    The result is cached in ``subset_cache.json``; when that file exists its
    contents are returned without reading ``test.json``.

    Returns:
        The list of image file names, or ``None`` when ``test.json`` is
        missing or does not contain the target image.
    """
    cache_path = "subset_cache.json"
    last_image = "skiing_169.jpg"

    # Fast path: a previous run already computed the subset.
    if os.path.exists(cache_path):
        with open(cache_path, 'r') as cached:
            return json.load(cached)

    if not os.path.exists("test.json"):
        return None

    with open("test.json", 'r') as source:
        image_names = list(json.load(source).keys())

    if last_image not in image_names:
        return None

    # Keep everything from the start through the target image (inclusive).
    chosen = image_names[:image_names.index(last_image) + 1]

    with open(cache_path, 'w') as out:
        json.dump(chosen, out)

    return chosen

In [None]:
# Cell 4: Model Initialization
def initialize_model():
    """Load Qwen2.5-VL-3B-Instruct and its processor, then run one warm-up generation.

    Returns:
        A ``(model, processor)`` tuple. The model is in eval mode and wrapped
        by ``torch.compile``.
    """
    # Free whatever an earlier cell may have left behind before loading weights.
    gc.collect()
    torch.cuda.empty_cache()

    checkpoint = "Qwen/Qwen2.5-VL-3B-Instruct"

    load_options = dict(
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(checkpoint, **load_options)

    model.eval()
    model = torch.compile(model, mode="max-autotune", fullgraph=True)

    processor = AutoProcessor.from_pretrained(checkpoint, trust_remote_code=True)

    # Warm-up: push a blank image with a trivial prompt through generate() once
    # before the function returns.
    blank = Image.new('RGB', (224, 224))
    conversation = [{
        "role": "user",
        "content": [
            {"type": "image", "image": blank},
            {"type": "text", "text": "test"},
        ],
    }]
    prompt = processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
    warmup_inputs = processor(text=[prompt], images=[blank], return_tensors="pt").to(model.device)

    with torch.inference_mode():
        _ = model.generate(**warmup_inputs, max_new_tokens=10, do_sample=False)

    del warmup_inputs
    torch.cuda.empty_cache()

    return model, processor

In [None]:
# Cell 5: Prompt Definition
# Instruction text sent to the model alongside every image (see batch_inference).
# The requested line format carries no confidence value, so it corresponds to
# the no-confidence pattern in extract_hoi_predictions, which assigns such
# matches a score of 0.7.
hoi_detection_prompt = """Analyze this image for human-object interactions. For each interaction you find, provide:

1. What the person is doing
2. What object they're interacting with
3. Location of the person as [x, y, width, height]
4. Location of the object as [x, y, width, height]

Format your response as:
Person at [x,y,w,h] doing ACTION with OBJECT at [x,y,w,h]

Only include clear, visible interactions. Maximum 10 interactions per image."""

In [None]:
# Cell 6: Dataset Class
class SWiGDataset(Dataset):
    """Dataset of images read from a directory, one dict per image.

    Only files that exist under ``img_dir`` are kept. Each item is a dict with
    the RGB PIL image (downscaled so its longest side is at most 448 px), the
    image id (file name without a trailing ``.jpg``), the original and resized
    ``(width, height)`` sizes, and the file name.

    Args:
        test_image_files: Iterable of image file names relative to ``img_dir``.
        max_images: Optional cap on the number of images kept. A falsy value
            (``None`` or ``0``) means no cap.
        img_dir: Directory containing the images. Defaults to ``"images_512"``.
    """

    def __init__(self, test_image_files, max_images=None, img_dir="images_512"):
        self.img_dir = img_dir
        self.test_image_files = test_image_files

        # Keep only the files that are actually present on disk.
        self.valid_image_files = [
            img_file for img_file in test_image_files
            if os.path.exists(os.path.join(self.img_dir, img_file))
        ]

        if max_images:
            self.valid_image_files = self.valid_image_files[:max_images]

        # Record each image's original (width, height) up front.
        self.image_metadata = {}
        for img_file in self.valid_image_files:
            image_path = os.path.join(self.img_dir, img_file)
            try:
                with Image.open(image_path) as img:
                    self.image_metadata[img_file] = img.size
            except Exception:
                # Unreadable file: fall back to a nominal size. Catching
                # Exception (not a bare except) keeps Ctrl-C and SystemExit
                # working during this scan.
                self.image_metadata[img_file] = (512, 512)

    def __len__(self):
        """Return the number of images that were found on disk."""
        return len(self.valid_image_files)

    def __getitem__(self, idx):
        """Load image ``idx`` and return it with its id and size metadata.

        If the image cannot be loaded, a blank 224x224 placeholder is returned
        under the same ``image_id`` so that batches keep their shape.
        """
        img_file = self.valid_image_files[idx]
        # Strip only a trailing ".jpg"; a ".jpg" elsewhere in the name is part
        # of the id.
        image_id = img_file[:-4] if img_file.endswith('.jpg') else img_file
        image_path = os.path.join(self.img_dir, img_file)

        try:
            # convert() returns a new image, so the file can be closed right away.
            with Image.open(image_path) as handle:
                image = handle.convert('RGB')
            original_size = self.image_metadata.get(img_file, image.size)

            max_size = 448
            if max(original_size) > max_size:
                image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)

            return {
                'image': image,
                'image_id': image_id,
                'original_size': original_size,
                'resized_size': image.size,
                'filename': img_file
            }
        except Exception:
            return {
                'image': Image.new('RGB', (224, 224)),
                'image_id': image_id,
                'original_size': (224, 224),
                'resized_size': (224, 224),
                'filename': img_file
            }

In [None]:
# Cell 7: HOI Prediction Extraction
# One interaction line: "Person at [x,y,w,h] (doing) VERB (with) OBJECT at [x,y,w,h]"
# optionally followed by "confidence <number>". The confidence group is empty
# when absent, so lines with and without a confidence are parsed in one pass.
_HOI_LINE_PATTERN = re.compile(
    r'person\s+at\s+\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)\]\s+(?:doing\s+)?(\w+)\s+(?:with\s+)?(\w+)'
    r'\s+at\s+\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)\]'
    r'(?:\s+confidence\s+([0-9]*\.?[0-9]+))?',
    re.IGNORECASE,
)

# Placeholder nouns that do not name a real object.
_INVALID_OBJECTS = frozenset({'object', 'objects', 'thing', 'item', 'stuff', 'it'})

# Score given to an interaction whose line has no explicit confidence.
_DEFAULT_SCORE = 0.7


def _scaled_box(x, y, w, h, scale_x, scale_y, width, height):
    """Turn an [x, y, w, h] box into [x1, y1, x2, y2] scaled and clamped to the image."""
    return [
        max(0, min(width, int(x * scale_x))),
        max(0, min(height, int(y * scale_y))),
        max(0, min(width, int((x + w) * scale_x))),
        max(0, min(height, int((y + h) * scale_y))),
    ]


def extract_hoi_predictions(text, image_data, max_detections=10):
    """Parse human-object interactions out of a model response.

    Args:
        text: Generated response text. Empty or ``None`` yields no predictions.
        image_data: Dict with ``'original_size'`` and ``'resized_size'``, each
            a ``(width, height)`` pair. Boxes in ``text`` are read as
            ``[x, y, w, h]`` in resized-image pixels and mapped back to
            original-image pixels.
        max_detections: Only the first ``max_detections`` matched lines are
            considered (the cap is applied before validation).

    Returns:
        A list of dicts with ``subject_box`` and ``object_box`` (both
        ``[x1, y1, x2, y2]``), ``subject_category`` (always ``"person"``),
        ``object_category``, ``verb`` and ``score`` (clamped to [0.1, 1.0]).
        Lines with a placeholder object or a degenerate box are skipped.
    """
    if not text:
        return []

    width, height = image_data['original_size']
    resized_w, resized_h = image_data['resized_size']
    if not resized_w or not resized_h:
        # A zero-sized resized image gives no meaningful scale factor.
        return []

    scale_x = width / resized_w
    scale_y = height / resized_h

    results = []
    for match in _HOI_LINE_PATTERN.findall(text)[:max_detections]:
        try:
            px, py, pw, ph = (int(v) for v in match[0:4])
            ox, oy, ow, oh = (int(v) for v in match[6:10])
            score = float(match[10]) if match[10] else _DEFAULT_SCORE
        except ValueError:
            continue

        verb, obj = match[4], match[5]

        # x/y come from \d+ and so cannot be negative; only sizes need checking.
        if (len(obj) < 2 or obj.lower() in _INVALID_OBJECTS or
                pw <= 0 or ph <= 0 or ow <= 0 or oh <= 0):
            continue

        subject_box = _scaled_box(px, py, pw, ph, scale_x, scale_y, width, height)
        object_box = _scaled_box(ox, oy, ow, oh, scale_x, scale_y, width, height)

        # Clamping can collapse a box that lies outside the image; drop those.
        if (subject_box[2] <= subject_box[0] or subject_box[3] <= subject_box[1] or
                object_box[2] <= object_box[0] or object_box[3] <= object_box[1]):
            continue

        results.append({
            "subject_box": subject_box,
            "object_box": object_box,
            "subject_category": "person",
            "object_category": obj.lower(),
            "verb": verb.lower(),
            "score": round(max(0.1, min(1.0, score)), 4)
        })

    return results

In [None]:
# Cell 8: Batch Inference
def batch_inference(batch_data, model, processor):
    """Run the HOI prompt over one batch of images and parse the responses.

    Args:
        batch_data: List of item dicts as produced by ``SWiGDataset``; each
            must carry ``'image'``, ``'image_id'``, ``'original_size'`` and
            ``'resized_size'``. Items whose ``image_id`` is ``'invalid'`` are
            skipped.
        model: The generation model returned by ``initialize_model``.
        processor: The matching processor returned by ``initialize_model``.

    Returns:
        A list of ``{'image_id', 'hoi_prediction'}`` dicts, one per valid
        item. An empty list is returned when there are no valid items or when
        inference fails; failures are logged with a traceback rather than
        silently dropped.
    """
    import logging  # local import: the notebook has no top-level logging import

    valid_batch = [item for item in batch_data if item['image_id'] != 'invalid']
    if not valid_batch:
        return []

    images = [item['image'] for item in valid_batch]

    # One single-turn conversation (image + shared prompt) per image.
    messages_list = [
        [{
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": hoi_detection_prompt}
            ]
        }]
        for image in images
    ]

    try:
        # Left padding keeps every prompt flush against the generated tokens.
        processor.tokenizer.padding_side = 'left'

        texts = [processor.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
                 for msgs in messages_list]

        inputs = processor(
            text=texts,
            images=images,
            return_tensors="pt",
            padding=True,
            max_length=2048,
            truncation=True
        ).to(model.device, non_blocking=True)

        with torch.inference_mode():
            with torch.amp.autocast('cuda', dtype=torch.bfloat16):
                output_ids = model.generate(
                    input_ids=inputs['input_ids'],
                    # The batch is padded, so the mask must be passed explicitly
                    # for the model to ignore the padding positions.
                    attention_mask=inputs.get('attention_mask', None),
                    pixel_values=inputs.get('pixel_values', None),
                    image_grid_thw=inputs.get('image_grid_thw', None),
                    max_new_tokens=150,
                    do_sample=False,
                    use_cache=True,
                    num_beams=1,
                    pad_token_id=processor.tokenizer.pad_token_id,
                    eos_token_id=processor.tokenizer.eos_token_id
                )

        # Drop the prompt tokens; only the newly generated part is decoded.
        generated_ids = output_ids[:, inputs['input_ids'].shape[1]:]
        responses = processor.batch_decode(generated_ids, skip_special_tokens=True)

        batch_results = []
        for data, response in zip(valid_batch, responses):
            hoi_predictions = extract_hoi_predictions(response, data, max_detections=10)
            batch_results.append({
                'image_id': data['image_id'],
                'hoi_prediction': hoi_predictions
            })

        del inputs, output_ids, generated_ids
        torch.cuda.empty_cache()
        return batch_results

    except Exception:
        # Best effort: losing one batch must not stop the run, but the reason
        # has to be visible.
        logging.getLogger(__name__).exception(
            "batch_inference failed for %d image(s); returning no predictions",
            len(valid_batch),
        )
        return []

In [None]:
# Cell 9: Processing Pipeline
def process_images(model, processor, test_image_files, output_file="results.json",
                   batch_size=20, max_images=None, save_interval=1000):
    """Run batched HOI inference over a list of images and save the predictions.

    Args:
        model: Generation model passed through to ``batch_inference``.
        processor: Processor passed through to ``batch_inference``.
        test_image_files: Image file names handed to ``SWiGDataset``.
        output_file: Path of the final JSON file; checkpoint files are named
            ``"{output_file}.checkpoint_{processed_count}.json"``.
        batch_size: Number of images per batch.
        max_images: Optional cap forwarded to ``SWiGDataset``.
        save_interval: Write a checkpoint each time the processed-image count
            crosses another multiple of this value. A falsy value disables
            checkpoints.

    Returns:
        The list of all per-image prediction dicts, which is also written to
        ``output_file``.
    """
    import logging  # local import: the notebook has no top-level logging import

    logger = logging.getLogger(__name__)

    dataset = SWiGDataset(test_image_files, max_images)

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=12,
        pin_memory=True,
        collate_fn=lambda x: x,  # keep each batch as a plain list of item dicts
        persistent_workers=True,
        prefetch_factor=6,
        drop_last=False
    )

    all_predictions = []
    processed_count = 0
    total_hoi_count = 0
    # Number of save_interval boundaries already checkpointed. Comparing against
    # this (instead of testing processed_count % save_interval == 0) still fires
    # when batch_size does not divide save_interval.
    checkpoints_written = 0

    pbar = tqdm(dataloader, desc="Processing", unit="batch")

    for batch_idx, batch_data in enumerate(pbar):
        try:
            batch_results = batch_inference(batch_data, model, processor)
            all_predictions.extend(batch_results)
            processed_count += len(batch_data)

            batch_hoi_count = sum(len(pred['hoi_prediction']) for pred in batch_results)
            total_hoi_count += batch_hoi_count

            pbar.set_postfix({
                'Images': processed_count,
                'HOI': total_hoi_count,
                'Avg_HOI': f"{total_hoi_count/processed_count:.1f}"
            })

            if batch_idx % 50 == 0:
                gc.collect()
                torch.cuda.empty_cache()

            if save_interval and processed_count // save_interval > checkpoints_written:
                checkpoints_written = processed_count // save_interval
                checkpoint_file = f"{output_file}.checkpoint_{processed_count}.json"
                with open(checkpoint_file, 'w') as f:
                    json.dump(all_predictions, f, indent=2)

        except Exception:
            # Keep going on a bad batch, but leave a traceback behind.
            logger.exception("Skipping batch %d after an error", batch_idx)
            continue

    with open(output_file, 'w') as f:
        json.dump(all_predictions, f, indent=2)

    return all_predictions

In [None]:
# Cell 10: JSON Output Generation
def generate_json_output(predictions, output_file="swighoi.json"):
    """Write predictions to ``output_file`` as JSON and return what was written.

    Each entry keeps its ``image_id`` (minus a trailing ``.jpg``) and at most
    the first ten ``hoi_prediction`` items. Four-integer arrays are collapsed
    onto a single line in the serialized text.

    Returns:
        The list of formatted entries, or ``None`` (nothing written) when
        ``predictions`` is empty.
    """
    if not predictions:
        return None

    def strip_extension(name):
        # Only a trailing ".jpg" is removed.
        return name[:-4] if name.endswith('.jpg') else name

    final_output = [
        {
            "image_id": strip_extension(entry['image_id']),
            "hoi_prediction": entry.get('hoi_prediction', [])[:10],
        }
        for entry in predictions
    ]

    serialized = json.dumps(final_output, indent=2, separators=(',', ': '))

    # Re-flow every [a, b, c, d] integer array that indent=2 spread over lines.
    box_pattern = re.compile(r'\[\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\]')
    serialized = box_pattern.sub(r'[\1, \2, \3, \4]', serialized)

    with open(output_file, 'w') as handle:
        handle.write(serialized)

    return final_output

In [None]:
# Cell 11: Execution Functions
def run_test_subset():
    """Smoke-test the pipeline on at most 50 images of the cached subset.

    Returns:
        The predictions list, or ``None`` when the dataset or the subset is
        unavailable. Results are written to ``test_results.json``.
    """
    if not download_dataset():
        return None

    image_names = extract_image_subset()
    if not image_names:
        return None

    model, processor = initialize_model()

    return process_images(
        model,
        processor,
        image_names,
        output_file="test_results.json",
        batch_size=20,
        max_images=50,
        save_interval=25,
    )

def run_full_subset():
    """Run the pipeline over every image of the cached subset.

    Returns:
        The predictions list, or ``None`` when the dataset or the subset is
        unavailable. Results are written to ``subset_results.json``.
    """
    if not download_dataset():
        return None

    image_names = extract_image_subset()
    if not image_names:
        return None

    model, processor = initialize_model()

    return process_images(
        model,
        processor,
        image_names,
        output_file="subset_results.json",
        batch_size=20,
        max_images=None,
        save_interval=500,
    )

def run_full_dataset():
    """Run the pipeline over every image listed in ``test.json``.

    Returns:
        The predictions list, or ``None`` when the dataset could not be
        obtained or ``test.json`` is missing or lists no images. Results are
        written to ``full_results.json``.
    """
    img_dir = download_dataset()
    if not img_dir:
        return None

    # Mirror extract_image_subset: a missing annotation file means "nothing to
    # do" (None), not an uncaught FileNotFoundError.
    if not os.path.exists("test.json"):
        return None

    with open("test.json", 'r') as f:
        test_data = json.load(f)
    all_test_images = list(test_data.keys())
    if not all_test_images:
        return None

    model, processor = initialize_model()

    predictions = process_images(
        model, processor, all_test_images,
        output_file="full_results.json",
        batch_size=20,
        max_images=None,
        save_interval=2000
    )

    return predictions

In [None]:
# Cell 12: Main Execution
if __name__ == "__main__":
    # Which slice of the data to run: "small" (up to 50 images), "subset"
    # (cached subset) or "all" (everything in test.json).
    mode = "all"

    runners = {
        "small": run_test_subset,
        "subset": run_full_subset,
        "all": run_full_dataset,
    }

    # Fail with a clear message on a mistyped mode instead of hitting a
    # NameError on an unbound `results` further down.
    if mode not in runners:
        raise ValueError(f"Unknown mode {mode!r}; expected one of {sorted(runners)}")

    results = runners[mode]()

    if results:
        final_json = generate_json_output(results)