In [None]:
%pip install -q datasets huggingface_hub pillow matplotlib requests
%pip install fsspec==2023.9.2  


Collecting fsspec==2023.9.2
  Downloading fsspec-2023.9.2-py3-none-any.whl.metadata (6.7 kB)
Downloading fsspec-2023.9.2-py3-none-any.whl (173 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/173.4 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━[0m [32m163.8/173.4 kB[0m [31m5.9 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m173.4/173.4 kB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: fsspec
  Attempting uninstall: fsspec
    Found existing installation: fsspec 2025.3.2
    Uninstalling fsspec-2025.3.2:
      Successfully uninstalled fsspec-2025.3.2
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
gcsfs 2025.3.2 requires fsspec==2025.3.2, but you have fsspec 2023.9.2 which is incompatible.
tor

In [None]:
from datasets import load_dataset
from itertools import islice
import matplotlib.pyplot as plt
from PIL import Image
import requests
from io import BytesIO
import numpy as np
from typing import Dict, List, Tuple, Optional, Any


In [None]:
# Notebook-wide configuration.
# NOTE(review): `num_samples` is not read by any cell visible here; the loader
# and filter functions take their own `num_samples` / `max_samples` arguments.
num_samples = 1000
# Number of rows the dataset-loading cell pulls from the stream.
load_num_dataset_streaming = 20

# Seconds slept after each streamed row in load_dataset_streaming (0 disables it).
SAMPLE_DELAY = 0.2
# Base of the exponential retry backoff in load_dataset_streaming:
# RETRY_BASE_DELAY ** attempt seconds, capped at 60.
RETRY_BASE_DELAY = 2

In [None]:

def load_dataset_streaming(dataset_name: str, num_samples: int = 5000, _attempt: int = 0, _fallback_idx: int = 0, max_retries: int = 3) -> List[Dict]:
    """Stream up to ``num_samples`` rows from a dataset's train split, with fallbacks.

    ``dataset_name`` is tried first, then a fixed list of fallback datasets.
    A load that is rate limited (HTTP 429) is retried up to ``max_retries``
    times with exponential backoff (``RETRY_BASE_DELAY ** attempt`` seconds,
    capped at 60). Any other error moves on to the next candidate. A name that
    appears twice in the candidate list is only tried once.

    Args:
        dataset_name: Hub dataset id to try first.
        num_samples: Maximum number of rows to collect.
        _attempt: Retries already spent on the first candidate tried. Kept for
            backward compatibility with the former recursive implementation.
        _fallback_idx: Index into the candidate list to start from (0 means
            ``dataset_name``). Also kept for backward compatibility.
        max_retries: Retries allowed per candidate after a 429 response.

    Returns:
        The collected rows, or ``[]`` if every candidate failed.
    """
    import time
    from requests.exceptions import HTTPError

    candidates = [
        dataset_name,
        "conceptual_captions",
        "laion/laion400m",
        "ChristophSchuhmann/MS_COCO_2017_URL_TEXT",
    ]

    def _is_rate_limited(exc: Exception) -> bool:
        """Return True when ``exc`` looks like an HTTP 429 response."""
        status = getattr(getattr(exc, "response", None), "status_code", None)
        # Fall back to the message text for requests-style errors without a response.
        return status == 429 or (isinstance(exc, HTTPError) and "429" in str(exc))

    tried = set()
    for idx in range(max(_fallback_idx, 0), len(candidates)):
        name = candidates[idx]
        if name in tried:
            # dataset_name may coincide with a fallback; don't hit it twice.
            continue
        tried.add(name)
        if idx > 0:
            print(f"Trying fallback dataset {idx + 1}/{len(candidates)}...")

        # Only the first candidate visited inherits the caller's retry count.
        attempt = _attempt if len(tried) == 1 else 0
        while True:
            try:
                print(f"Loading {name} in streaming mode...")
                if attempt > 0:
                    delay = min(RETRY_BASE_DELAY ** attempt, 60)
                    print(f"  Waiting {delay} seconds before retry...")
                    time.sleep(delay)

                dataset_stream = load_dataset(name, split="train", streaming=True)

                # NOTE: a failure mid-stream discards the rows gathered so far;
                # a retry restarts the stream from the beginning.
                samples = []
                for count, sample in enumerate(islice(dataset_stream, num_samples), start=1):
                    samples.append(sample)
                    if count % 500 == 0:
                        print(f"  Collected {count} samples...")
                    # Small delay to avoid hitting rate limits.
                    if SAMPLE_DELAY > 0:
                        time.sleep(SAMPLE_DELAY)

                print(f"✅ Successfully collected {len(samples)} samples from {name}")
                return samples
            except Exception as e:
                if not _is_rate_limited(e):
                    print(f"❌ Error loading {name}: {type(e).__name__}: {str(e)[:100]}")
                    break
                print(f"⚠️ Rate limited (429 error) on {name}")
                if attempt >= max_retries:
                    print(f"  Max retries reached for {name}")
                    break
                attempt += 1
                print(f"  Retrying with backoff (attempt {attempt}/{max_retries})...")

    print("❌ No dataset could be loaded; returning an empty list.")
    return []


def download_image(url: str) -> Optional[Image.Image]:
    """Fetch an image over HTTP and return it as a fully decoded PIL image.

    This is best effort: a missing URL, a non-200 response, a network error or
    an undecodable payload is reported with a printed message, and ``None``
    is returned instead of raising.

    Args:
        url: Image URL. ``None``/empty means "no image". An object that is
            already a ``PIL.Image.Image`` is returned unchanged.

    Returns:
        The decoded image, or ``None`` on failure.
    """
    if isinstance(url, Image.Image):
        return url
    if not url:
        return None
    try:
        response = requests.get(url, timeout=10, headers={
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
        })
        if response.status_code != 200:
            print(f"Error downloading image: HTTP {response.status_code} for {url}")
            return None
        image = Image.open(BytesIO(response.content))
        # Image.open is lazy; decode now so corrupt payloads fail inside this
        # try block instead of later when the image is plotted.
        image.load()
        return image
    except Exception as e:
        print(f"Error downloading image: {e}")
    return None


def get_sample(samples: List[Dict], index: int) -> Dict[str, Any]:
    """Normalise one raw dataset row into a fixed-schema dict for display.

    Args:
        samples: Raw rows (e.g. from ``load_dataset_streaming`` or the dummy
            fallback rows).
        index: Position in ``samples``; negative indices are rejected.

    Returns:
        Dict with keys ``index``, ``prompt``, ``caption``, ``input_image`` and
        ``output_image``. Image entries are ``None`` when unavailable.

    Raises:
        IndexError: If ``index`` is outside ``[0, len(samples))``.
    """
    if not 0 <= index < len(samples):
        raise IndexError(f"Index {index} out of range. Dataset has {len(samples)} samples.")

    sample = samples[index]

    # `or` chains (instead of nested .get defaults) also skip keys that are
    # present but None/empty; 'text' is the key used by the dummy rows.
    output = {
        'index': index,
        'prompt': sample.get('prompt') or sample.get('caption') or sample.get('text') or 'No prompt',
        'caption': sample.get('target_caption') or sample.get('caption') or 'No caption',
        'input_image': None,
        'output_image': None,
    }

    if 'target_image' in sample:
        output['output_image'] = sample['target_image']
        source_images = sample.get('source_images')
        if source_images is not None and len(source_images) > 0:
            output['input_image'] = source_images[0]
    elif 'image_url' in sample or 'image' in sample:
        image_ref = sample.get('image_url') or sample.get('image')
        if isinstance(image_ref, str):
            output['output_image'] = download_image(image_ref)
        elif image_ref is not None:
            # Presumably an already-decoded image object; use it as-is
            # rather than trying to fetch it as a URL.
            output['output_image'] = image_ref

    return output


def get_batch(samples: List[Dict], indices: List[int]) -> Dict[str, List]:
    """Collect several samples into column-oriented lists.

    Each index is resolved through ``get_sample``. Indices outside the
    dataset are reported and skipped rather than aborting the whole batch.

    Returns:
        Dict with parallel lists under ``prompts``, ``input_images``,
        ``output_images`` and ``captions``.
    """
    prompts, input_images, output_images, captions = [], [], [], []

    for position in indices:
        try:
            record = get_sample(samples, position)
        except IndexError as err:
            print(f"Skipping index {position}: {err}")
            continue
        prompts.append(record['prompt'])
        input_images.append(record['input_image'])
        output_images.append(record['output_image'])
        captions.append(record['caption'])

    return {
        'prompts': prompts,
        'input_images': input_images,
        'output_images': output_images,
        'captions': captions,
    }


In [None]:
def visualize_sample(sample: Dict[str, Any]):
    """Plot a sample's input and output images side by side.

    A placeholder label is drawn in any panel whose image is ``None``. The
    prompt is the figure title, cut to 80 characters plus "..." when longer.
    The caption is printed after the figure is shown.
    """
    fig, (input_ax, output_ax) = plt.subplots(1, 2, figsize=(12, 5))

    panels = (
        (input_ax, 'input_image', "Input Image", 'No Input Image'),
        (output_ax, 'output_image', "Output Image", 'No Output Image'),
    )
    for ax, key, title, placeholder in panels:
        image = sample[key]
        if image is None:
            ax.text(0.5, 0.5, placeholder, ha='center', va='center')
            ax.set_title(f"{title} (Missing)")
        else:
            ax.imshow(image)
            ax.set_title(title)
        ax.axis('off')

    prompt = sample['prompt']
    if len(prompt) > 80:
        plt.suptitle(f"Prompt: {prompt[:80]}...")
    else:
        plt.suptitle(f"Prompt: {prompt}")
    plt.tight_layout()
    plt.show()

    print(f"Caption: {sample['caption']}")


def visualize_batch(batch: Dict[str, List], max_show: int = 4):
    """Visualize a batch of samples in a grid.

    Draws up to ``max_show`` rows of (input, output) panels. A placeholder
    label is used when an image is ``None``. The prompt is written under the
    output panel, cut to 40 characters plus "..." when longer.
    """
    n_samples = min(len(batch['prompts']), max_show)
    if n_samples == 0:
        print("No samples to visualize")
        return

    # squeeze=False keeps a 2-D axes grid even when there is a single row.
    fig, axes = plt.subplots(n_samples, 2, figsize=(10, 4*n_samples), squeeze=False)

    for row in range(n_samples):
        in_ax, out_ax = axes[row]
        number = row + 1

        in_img = batch['input_images'][row]
        if in_img is None:
            in_ax.text(0.5, 0.5, 'No Input', ha='center', va='center')
        else:
            in_ax.imshow(in_img)
        in_ax.set_title(f"Input {number}")
        in_ax.axis('off')

        out_img = batch['output_images'][row]
        if out_img is None:
            out_ax.text(0.5, 0.5, 'No Output', ha='center', va='center')
        else:
            out_ax.imshow(out_img)
        out_ax.set_title(f"Output {number}")
        out_ax.axis('off')

        prompt = batch['prompts'][row]
        label = prompt if len(prompt) <= 40 else prompt[:40] + "..."
        out_ax.text(0.5, -0.1, label, ha='center', transform=out_ax.transAxes)

    plt.tight_layout()
    plt.show()


In [None]:
# Stream a small slice of the MetaQuery instruct dataset
# (load_dataset_streaming moves on to fallback datasets if this one fails).
samples = load_dataset_streaming(
    "xcpan/MetaQuery_Instruct_2.4M_512res",
    num_samples=load_num_dataset_streaming,
)


Loading xcpan/MetaQuery_Instruct_2.4M_512res in streaming mode...


Downloading readme:   0%|          | 0.00/566 [00:00<?, ?B/s]

Resolving data files:   0%|          | 0/4258 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Preview the first sample. This cell runs before the dummy-data fallback
# cell, so guard against an empty load instead of raising IndexError.
if samples:
    sample = get_sample(samples, 0)
    visualize_sample(sample)
else:
    print("No samples loaded yet - skipping preview.")
print(len(samples))

In [None]:
# Fall back to a tiny text-only dataset when streaming produced nothing
# (load_dataset_streaming returns [] when every candidate fails). These rows
# carry no images, so filter_valid_samples will reject all of them.
if not samples:
    print("No samples loaded from dataset. Using a minimal dummy dataset for testing.")
    samples = [
        {"text": "A cat sitting on a mat", "image": None},
        {"text": "A dog playing in the park", "image": None},
        {"text": "Birds flying in the sky", "image": None},
        {"text": "Fish swimming in the ocean", "image": None},
        {"text": "A tree in the forest", "image": None}
    ]
    print(f"{len(samples)} dummy samples created.")
else:
    print(f"loaded {len(samples)} samples from dataset.")


In [None]:
# Gather the first four samples (missing indices are skipped) and plot them.
batch = get_batch(samples, indices=list(range(4)))
print("Batch contains {} samples".format(len(batch['prompts'])))
visualize_batch(batch, max_show=4)


In [None]:
import random

# Show a few randomly chosen samples. Draw at most len(samples) indices so a
# small dataset (e.g. the 5-row dummy fallback) doesn't make random.sample
# raise ValueError, and use a seeded RNG so re-runs pick the same rows.
RANDOM_SEED = 42
random_indices = random.Random(RANDOM_SEED).sample(range(len(samples)), min(6, len(samples)))
random_batch = get_batch(samples, random_indices)
visualize_batch(random_batch, max_show=3)


In [None]:
def save_batch(batch: Dict[str, List], filename: str = "batch_data.json"):
    """Write a batch's text columns to a JSON file (images are not stored).

    The file holds ``prompts``, ``captions`` and ``num_samples`` (the number
    of prompts), indented by two spaces. An existing file is overwritten.
    """
    import json

    prompts = batch['prompts']
    payload = {
        'prompts': prompts,
        'captions': batch['captions'],
        'num_samples': len(prompts),
    }

    with open(filename, 'w') as handle:
        json.dump(payload, handle, indent=2)

    print(f"✅ Saved batch data to {filename}")



In [None]:
# NOTE(review): these installs repeat the first cell. There, pip already
# reported that the fsspec==2023.9.2 pin conflicts with the preinstalled
# gcsfs 2025.3.2, and `-U datasets huggingface_hub` here may move fsspec again.
# Consider pinning every dependency once, in the first cell.
%pip install fsspec==2023.9.2
%pip install -U datasets huggingface_hub
%pip install -q pillow matplotlib

# NOTE(review): dataclass, nn, F and math are not referenced in the cells visible here.
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import functional as F
import math
import time
# Allow PyTorch to use faster, reduced-precision internal math for float32
# matmuls where supported (see torch.set_float32_matmul_precision docs).
torch.set_float32_matmul_precision('high')

In [None]:

def filter_valid_samples(samples: List[Dict], max_samples: Optional[int] = None) -> List[Dict]:
    """
    Keep only samples that have a non-empty text prompt and a usable output image.

    The prompt is the first non-empty value of ``prompt``, ``caption`` or
    ``text``. The output image comes from ``target_image`` (with the first of
    ``source_images`` as input image). Failing that, it comes from
    ``image_url`` / ``image``: string values are downloaded with
    ``download_image``, and any other non-None value is presumed to be an
    already-decoded image and kept as-is.

    Args:
        samples: Raw dataset rows.
        max_samples: Stop once this many valid samples are collected
            (``None`` means no limit; ``0`` returns nothing).

    Returns:
        List of dicts with keys ``index``, ``prompt``, ``caption``,
        ``input_image``, ``output_image`` and ``raw_sample``.
    """
    filtered_samples = []
    processed = 0  # rows actually examined (the loop may stop early)

    for idx, sample in enumerate(samples):
        if max_samples is not None and len(filtered_samples) >= max_samples:
            break
        processed += 1

        prompt = sample.get('prompt') or sample.get('caption') or sample.get('text')
        if not isinstance(prompt, str) or not prompt.strip():
            continue

        output = {
            'index': idx,
            'prompt': prompt,
            'caption': sample.get('target_caption') or sample.get('caption') or prompt,
            'input_image': None,
            'output_image': None,
            'raw_sample': sample
        }

        if sample.get('target_image') is not None:
            output['output_image'] = sample['target_image']
            source_images = sample.get('source_images')
            if source_images is not None and len(source_images) > 0:
                output['input_image'] = source_images[0]
        elif 'image_url' in sample or 'image' in sample:
            image_ref = sample.get('image_url') or sample.get('image')
            if isinstance(image_ref, str):
                # download_image returns None on any failure.
                output['output_image'] = download_image(image_ref)
            elif image_ref is not None:
                output['output_image'] = image_ref

        if output['output_image'] is None:
            continue
        filtered_samples.append(output)
        if len(filtered_samples) % 10 == 0:
            print(f"✅ Filtered {len(filtered_samples)} valid samples...")

    print("\n📊 Summary:")
    print(f"  - Total samples processed: {processed}")
    print(f"  - Valid samples found: {len(filtered_samples)}")
    if processed:
        print(f"  - Rejection rate: {(1 - len(filtered_samples)/processed)*100:.1f}%")
    else:
        print("  - Rejection rate: n/a (no samples processed)")

    return filtered_samples

# Keep at most 50 rows that have both a prompt and an output image.
print("Filtering samples for valid images and prompts...")
filtered_data = filter_valid_samples(samples, 50)
print("Filtered data saved to 'filtered_data' variable")
print("📦 Contains {} valid samples with images and prompts".format(len(filtered_data)))


In [None]:
def analyze_filtered_data(filtered_data: List[Dict]) -> Dict[str, Any]:
    """Compute summary statistics over normalised samples.

    Args:
        filtered_data: Dicts with at least ``prompt``, ``input_image`` and
            ``output_image`` keys (e.g. the output of ``filter_valid_samples``).

    Returns:
        Dict with the sample count, how many samples have an input / output
        image, prompt-length mean/min/max (0 when empty), and up to five prompt
        previews (cut to 100 characters plus "..."). It also includes the
        ``.size`` of every output image that exposes one.
    """
    analysis = {
        'total_samples': len(filtered_data),
        'has_input_image': 0,
        'has_output_image': 0,
        'avg_prompt_length': 0,
        'prompt_samples': [],
        'image_sizes': []
    }

    prompt_lengths = []

    for sample in filtered_data:
        if sample['input_image'] is not None:
            analysis['has_input_image'] += 1

        # Count output images rather than assuming every row has one, and use
        # `is not None` because array-like images raise on truthiness checks.
        output_image = sample['output_image']
        if output_image is not None:
            analysis['has_output_image'] += 1
            if hasattr(output_image, 'size'):
                analysis['image_sizes'].append(output_image.size)

        prompt = sample['prompt']
        prompt_lengths.append(len(prompt))

        if len(analysis['prompt_samples']) < 5:
            analysis['prompt_samples'].append(prompt[:100] + '...' if len(prompt) > 100 else prompt)

    analysis['avg_prompt_length'] = np.mean(prompt_lengths) if prompt_lengths else 0
    analysis['min_prompt_length'] = min(prompt_lengths) if prompt_lengths else 0
    analysis['max_prompt_length'] = max(prompt_lengths) if prompt_lengths else 0

    return analysis

# Summarise the filtered rows and print a few prompt previews plus the
# distinct output-image sizes seen.
analysis = analyze_filtered_data(filtered_data)

print("\n Sample prompts:")
for number, preview in enumerate(analysis['prompt_samples'], start=1):
    print(f"  {number}. {preview}")

if len(analysis['image_sizes']) > 0:
    print(f"\n🖼️ Image sizes: {set(analysis['image_sizes'])}")
