In [1]:
import kagglehub

# Download latest version
path = kagglehub.dataset_download("japeralrashid/xr-bones-dataset-for-bone-fracture-detection")

print("Path to dataset files:", path)

Path to dataset files: /home/siya/.cache/kagglehub/datasets/japeralrashid/xr-bones-dataset-for-bone-fracture-detection/versions/5


In [2]:
import os
import pandas as pd
import glob
import json

# Path to your downloaded dataset
dataset_path = "/home/siya/.cache/kagglehub/datasets/japeralrashid/xr-bones-dataset-for-bone-fracture-detection/versions/5"
print(f"Dataset location: {dataset_path}")

# List all directories and files
print("\nDirectory structure:")
for root, dirs, files in os.walk(dataset_path, topdown=True):
    level = root.replace(dataset_path, '').count(os.sep)
    indent = ' ' * 4 * level
    print(f"{indent}{os.path.basename(root)}/")
    subindent = ' ' * 4 * (level + 1)
    for f in files[:3]:  # Show only first 3 files to avoid clutter
        print(f"{subindent}{f}")
    if len(files) > 3:
        print(f"{subindent}... and {len(files)-3} more files")

# Check CSV files
csv_files = glob.glob(os.path.join(dataset_path, "*_labels.csv"))
print(f"\nFound {len(csv_files)} label CSV files")

# Examine one CSV file to understand its structure
if csv_files:
    sample_csv = csv_files[0]
    df = pd.read_csv(sample_csv)
    print(f"\nStructure of {os.path.basename(sample_csv)}:")
    print(df.head())
    print(f"\nTotal entries: {len(df)}")

Dataset location: /home/siya/.cache/kagglehub/datasets/japeralrashid/xr-bones-dataset-for-bone-fracture-detection/versions/5

Directory structure:
5/
    YOLODataSet/
        xr_bones.yaml
        xr.yaml
        images/
            train/
                XR_SHOULDER_positive_3530.png
                XR_ELBOW_positive_366.png
                XR_ELBOW_positive_597.png
                ... and 23848 more files
            val/
                XR_FINGER_negative_1710.png
                XR_ELBOW_negative_761.png
                XR_ELBOW_negative_927.png
                ... and 997 more files
        labels/
            train/
                XR_ELBOW_positive_523.txt
                XR_SHOULDER_negative_1146.txt
                XR_FINGER_negative_1022.txt
                ... and 21498 more files
            val/
                XR_FINGER_negative_2228.txt
                XR_HAND_positive_413.txt
                XR_SHOULDER_negative_3653.txt
                ... and 997 more files

Found 0 l

In [3]:
import os
import yaml
import glob
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

# Dataset paths
dataset_path = "/home/siya/.cache/kagglehub/datasets/japeralrashid/xr-bones-dataset-for-bone-fracture-detection/versions/5"
yolo_config_path = os.path.join(dataset_path, "YOLODataSet", "xr_bones.yaml")

# Load YOLO configuration
with open(yolo_config_path, 'r') as f:
    yolo_config = yaml.safe_load(f)

print("YOLO Configuration:")
print(yaml.dump(yolo_config, indent=2))

# Check dataset statistics
train_images = glob.glob(os.path.join(dataset_path, "YOLODataSet", "images", "train", "*.png"))
val_images = glob.glob(os.path.join(dataset_path, "YOLODataSet", "images", "val", "*.png"))
train_labels = glob.glob(os.path.join(dataset_path, "YOLODataSet", "labels", "train", "*.txt"))
val_labels = glob.glob(os.path.join(dataset_path, "YOLODataSet", "labels", "val", "*.txt"))

print(f"\nDataset Statistics:")
print(f"Training images: {len(train_images)}")
print(f"Validation images: {len(val_images)}")
print(f"Training labels: {len(train_labels)}")
print(f"Validation labels: {len(val_labels)}")

# Sample a few images to understand the data
print("\nSample images:")
for i, img_path in enumerate(train_images[:3]):
    print(f"Image {i+1}: {os.path.basename(img_path)}")

YOLO Configuration:
names:
- XR_ELBOW_positive
- XR_FINGER_positive
- XR_FOREARM_positive
- XR_HAND_positive
- XR_SHOULDER_positive
- XR_ELBOW_negative
- XR_FINGER_negative
- XR_FOREARM_negative
- XR_HAND_negative
- XR_SHOULDER_negative
nc: 10
train: /YOLODataSet/images/train
val: /YOLODataSet/images/val


Dataset Statistics:
Training images: 23851
Validation images: 1000
Training labels: 21501
Validation labels: 1000

Sample images:
Image 1: XR_SHOULDER_positive_3530.png
Image 2: XR_ELBOW_positive_366.png
Image 3: XR_ELBOW_positive_597.png


In [4]:
import json
import shutil
from pathlib import Path
from tqdm import tqdm

def yolo_to_qwen_format(dataset_path, output_dir, split='train', max_samples=None):
    """
    Convert YOLO format dataset to Qwen3-VL format
    """
    # Create output directories
    os.makedirs(os.path.join(output_dir, 'images'), exist_ok=True)
    
    # Get image and label paths
    image_dir = os.path.join(dataset_path, "YOLODataSet", "images", split)
    label_dir = os.path.join(dataset_path, "YOLODataSet", "labels", split)
    
    image_paths = glob.glob(os.path.join(image_dir, "*.png"))
    
    if max_samples:
        image_paths = image_paths[:max_samples]
    
    qwen_data = []
    
    for img_path in tqdm(image_paths, desc=f"Processing {split} images"):
        # Get corresponding label file
        img_name = os.path.basename(img_path)
        label_name = os.path.splitext(img_name)[0] + ".txt"
        label_path = os.path.join(label_dir, label_name)
        
        # Read image to get dimensions
        img = Image.open(img_path)
        img_width, img_height = img.size
        
        # Copy image to output directory
        output_img_path = os.path.join(output_dir, 'images', img_name)
        shutil.copy2(img_path, output_img_path)
        
        # Process labels
        bounding_boxes = []
        has_fracture = False
        
        if os.path.exists(label_path):
            with open(label_path, 'r') as f:
                lines = f.readlines()
            
            for line in lines:
                parts = line.strip().split()
                if len(parts) >= 5:
                    class_id = int(float(parts[0]))
                    # YOLO format: center_x, center_y, width, height (normalized)
                    center_x = float(parts[1]) * img_width
                    center_y = float(parts[2]) * img_height
                    bbox_width = float(parts[3]) * img_width
                    bbox_height = float(parts[4]) * img_height
                    
                    # Convert to [x_min, y_min, x_max, y_max]
                    x_min = center_x - (bbox_width / 2)
                    y_min = center_y - (bbox_height / 2)
                    x_max = center_x + (bbox_width / 2)
                    y_max = center_y + (bbox_height / 2)
                    
                    # Clamp to image boundaries
                    x_min = max(0, min(img_width, x_min))
                    y_min = max(0, min(img_height, y_min))
                    x_max = max(0, min(img_width, x_max))
                    y_max = max(0, min(img_height, y_max))
                    
                    bounding_boxes.append([x_min, y_min, x_max, y_max])
                    has_fracture = True
        
        # Determine anatomical region from filename
        region = "unknown"
        if "ELBOW" in img_name:
            region = "elbow"
        elif "FINGER" in img_name:
            region = "finger"
        elif "FOREARM" in img_name:
            region = "forearm"
        elif "HAND" in img_name:
            region = "hand"
        elif "SHOULDER" in img_name:
            region = "shoulder"
        
        # Create Qwen3-VL formatted messages
        if has_fracture:
            prompt = f"Identify and box all fracture regions in this {region} X-ray image. Return results in JSON format."
            response = {
                "fracture_present": True,
                "anatomical_region": region,
                "bounding_boxes": [
                    {"bbox_2d": bbox, "label": "fracture", "region": region} 
                    for bbox in bounding_boxes
                ],
                "diagnosis": "Fracture detected in the highlighted region."
            }
        else:
            prompt = f"Analyze this {region} X-ray image for any signs of fractures. Return results in JSON format."
            response = {
                "fracture_present": False,
                "anatomical_region": region,
                "bounding_boxes": [],
                "diagnosis": "No fractures detected in this X-ray image."
            }
        
        # Create conversation format
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": prompt}
                ]
            },
            {
                "role": "assistant",
                "content": json.dumps(response)
            }
        ]
        
        qwen_data.append({
            "messages": messages,
            "image": output_img_path,
            "metadata": {
                "original_image": img_path,
                "has_fracture": has_fracture,
                "region": region,
                "bounding_boxes": bounding_boxes
            }
        })
    
    return qwen_data

# Create output directory
output_dir = "./xr_bones_qwen_format"
os.makedirs(output_dir, exist_ok=True)

# Convert training and validation data
print("Converting training data...")
train_data = yolo_to_qwen_format(dataset_path, output_dir, split='train')
print(f"Converted {len(train_data)} training samples")

print("\nConverting validation data...")
val_data = yolo_to_qwen_format(dataset_path, output_dir, split='val')
print(f"Converted {len(val_data)} validation samples")

# Save dataset metadata
dataset_info = {
    "total_train_samples": len(train_data),
    "total_val_samples": len(val_data),
    "anatomical_regions": ["elbow", "finger", "forearm", "hand", "shoulder"],
    "created_at": "2025-11-13"
}

with open(os.path.join(output_dir, "dataset_info.json"), 'w') as f:
    json.dump(dataset_info, f, indent=2)

Converting training data...


Processing train images: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 23851/23851 [00:25<00:00, 942.96it/s] 


Converted 23851 training samples

Converting validation data...


Processing val images: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1000/1000 [00:01<00:00, 930.83it/s]

Converted 1000 validation samples





In [5]:
from datasets import Dataset, DatasetDict, Features, Value
import json
from PIL import Image
import os
from tqdm import tqdm
import numpy as np

def create_hf_dataset_fixed(train_data, val_data):
    """
    Create Hugging Face dataset with a simple, robust schema that avoids all nested structure issues
    """
    # Define a super simple schema that will definitely work
    features = Features({
        "text_input": Value("string"),           # User prompt text
        "image_path": Value("string"),           # Path to image file
        "target_output": Value("string"),        # Assistant response (JSON string)
        "has_fracture": Value("bool"),           # Whether fracture is present
        "region": Value("string"),               # Anatomical region
        "original_image_path": Value("string"), # Original image path
    })
    
    # Process data to match the simple schema
    def process_data(data_list):
        processed = []
        for item in tqdm(data_list, desc="Processing data"):
            try:
                # Get user text from messages
                user_text = ""
                for content_item in item["messages"][0]["content"]:
                    if content_item["type"] == "text":
                        user_text = content_item["text"]
                
                # Get assistant response
                assistant_content = item["messages"][1]["content"]
                
                # Extract metadata
                has_fracture = item["metadata"]["has_fracture"]
                region = item["metadata"]["region"]
                original_image_path = item["metadata"]["original_image"]
                
                processed.append({
                    "text_input": user_text,
                    "image_path": item["image"],
                    "target_output": assistant_content,
                    "has_fracture": has_fracture,
                    "region": region,
                    "original_image_path": original_image_path,
                })
            
            except Exception as e:
                print(f"Error processing item: {e}")
                continue
        
        return processed
    
    # Process training and validation data
    print("Processing training data for dataset creation...")
    train_processed = process_data(train_data)
    
    print("Processing validation data for dataset creation...")
    val_processed = process_data(val_data)
    
    # Create datasets with explicit features
    print("Creating training dataset...")
    train_dataset = Dataset.from_list(train_processed, features=features)
    
    print("Creating validation dataset...")
    val_dataset = Dataset.from_list(val_processed, features=features)
    
    # Create DatasetDict
    dataset = DatasetDict({
        "train": train_dataset,
        "validation": val_dataset
    })
    
    return dataset

# Create Hugging Face dataset with fixed approach
print("\nCreating Hugging Face dataset with proper schema...")
dataset = create_hf_dataset_fixed(train_data, val_data)

print(f"Dataset created successfully:")
print(f"  Training samples: {len(dataset['train'])}")
print(f"  Validation samples: {len(dataset['validation'])}")

# Verify dataset structure
print("\nDataset structure verification:")
print(f"Train features: {dataset['train'].features}")
print(f"Sample training example: {dataset['train'][0]}")

# Save dataset for future use
output_dir = "./xr_bones_qwen_format"
os.makedirs(os.path.join(output_dir, "hf_dataset_fixed"), exist_ok=True)
dataset.save_to_disk(os.path.join(output_dir, "hf_dataset_fixed"))
print(f"Dataset saved to {os.path.join(output_dir, 'hf_dataset_fixed')}")


Creating Hugging Face dataset with proper schema...
Processing training data for dataset creation...


Processing data: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 23851/23851 [00:00<00:00, 978580.67it/s]


Processing validation data for dataset creation...


Processing data: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1000/1000 [00:00<00:00, 675628.87it/s]

Creating training dataset...





Creating validation dataset...
Dataset created successfully:
  Training samples: 23851
  Validation samples: 1000

Dataset structure verification:
Train features: {'text_input': Value('string'), 'image_path': Value('string'), 'target_output': Value('string'), 'has_fracture': Value('bool'), 'region': Value('string'), 'original_image_path': Value('string')}
Sample training example: {'text_input': 'Analyze this shoulder X-ray image for any signs of fractures. Return results in JSON format.', 'image_path': './xr_bones_qwen_format/images/XR_SHOULDER_positive_3530.png', 'target_output': '{"fracture_present": false, "anatomical_region": "shoulder", "bounding_boxes": [], "diagnosis": "No fractures detected in this X-ray image."}', 'has_fracture': False, 'region': 'shoulder', 'original_image_path': '/home/siya/.cache/kagglehub/datasets/japeralrashid/xr-bones-dataset-for-bone-fracture-detection/versions/5/YOLODataSet/images/train/XR_SHOULDER_positive_3530.png'}


Saving the dataset (0/1 shards):   0%|          | 0/23851 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1000 [00:00<?, ? examples/s]

Dataset saved to ./xr_bones_qwen_format/hf_dataset_fixed


In [6]:
import torch
from transformers import Qwen3VLForConditionalGeneration, AutoProcessor
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import bitsandbytes as bnb
import gc
import os

# Hardware verification
print("\n" + "="*60)
print("HARDWARE SETUP")
print("="*60)
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    total_vram = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"Total VRAM: {total_vram:.1f} GB")
    print(f"Current memory usage: {torch.cuda.memory_allocated()/1024**3:.2f} GB")

# Memory optimization
def optimize_memory():
    """Optimize GPU memory usage"""
    torch.cuda.empty_cache()
    gc.collect()
    torch.cuda.set_per_process_memory_fraction(0.85)
    print(f"Memory optimized. Current usage: {torch.cuda.memory_allocated()/1024**3:.2f} GB")

optimize_memory()

# Model loading
print("\n" + "="*60)
print("MODEL LOADING")
print("="*60)
print("Loading Qwen3-VL-2B-Instruct with 4-bit quantization...")

model_name = "Qwen/Qwen3-VL-2B-Instruct"

# Load model with 4-bit quantization
model = Qwen3VLForConditionalGeneration.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

processor = AutoProcessor.from_pretrained(model_name)

print("‚úÖ Model loaded successfully!")
print(f"Current memory usage: {torch.cuda.memory_allocated()/1024**3:.2f} GB")

# Enable gradient checkpointing and prepare for training
print("\nEnabling gradient checkpointing...")
model.gradient_checkpointing_enable()
model.enable_input_require_grads()

print("Preparing model for k-bit training...")
model = prepare_model_for_kbit_training(model)

# Configure LoRA
print("\nConfiguring LoRA for medical imaging fine-tuning...")
lora_config = LoraConfig(
    r=32,
    lora_alpha=64,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

# Apply LoRA
print("Applying LoRA configuration...")
model = get_peft_model(model, lora_config)
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\n‚úÖ Model configuration complete!")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Total memory usage: {torch.cuda.memory_allocated()/1024**3:.2f} GB")


HARDWARE SETUP
CUDA available: True
GPU: NVIDIA GeForce RTX 4080 Laptop GPU
Total VRAM: 12.0 GB
Current memory usage: 0.00 GB


`torch_dtype` is deprecated! Use `dtype` instead!


Memory optimized. Current usage: 0.00 GB

MODEL LOADING
Loading Qwen3-VL-2B-Instruct with 4-bit quantization...


The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


‚úÖ Model loaded successfully!
Current memory usage: 1.47 GB

Enabling gradient checkpointing...
Preparing model for k-bit training...

Configuring LoRA for medical imaging fine-tuning...
Applying LoRA configuration...

‚úÖ Model configuration complete!
Trainable parameters: 34,865,152
Total memory usage: 2.19 GB


In [7]:
from transformers import DataCollatorWithPadding, TrainingArguments
import torch
from typing import Dict, List, Any
from PIL import Image
import json
import numpy as np

class MedicalImagingDataCollator:
    def __init__(self, processor):
        self.processor = processor
    
    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Fixed collator with proper image_grid_thw parameter for Qwen3-VL
        """
        # Extract images and text inputs
        images = []
        texts = []
        
        for feature in features:
            # Load image from path
            image = Image.open(feature["image_path"]).convert("RGB")
            images.append(image)
            
            # Create the conversation format
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image"},
                        {"type": "text", "text": feature["text_input"]}
                    ]
                },
                {
                    "role": "assistant",
                    "content": feature["target_output"]
                }
            ]
            
            # Apply chat template
            text = self.processor.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )
            texts.append(text)
        
        # Process all inputs together
        inputs = self.processor(
            images=images,
            text=texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=2048
        )
        
        # Create labels for training (mask input tokens)
        labels = inputs["input_ids"].clone()
        
        # FIX: Add image_grid_thw parameter - CRITICAL FOR QWEN3-VL
        batch_size = len(images)
        # For standard Qwen3-VL images, use this grid size
        image_grid_thw = torch.tensor([[1, 24, 24]] * batch_size, dtype=torch.long)
        
        # Return batch with ALL required parameters
        return {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
            "pixel_values": inputs["pixel_values"],
            "image_grid_thw": image_grid_thw,  # This is the crucial fix
            "labels": labels
        }

# Create data collator
data_collator = MedicalImagingDataCollator(processor)
print("‚úÖ Data collator created successfully!")

# Training configuration
from transformers import TrainingArguments

# Calculate training parameters based on your hardware
per_device_batch_size = 1  # Conservative for 12GB VRAM
gradient_accumulation_steps = 8  # Simulate larger batch size
num_train_epochs = 2  # Start with 2 epochs
learning_rate = 2e-5  # Conservative learning rate for medical tasks

# Training arguments - FIXED PARAMETER NAMES
training_args = TrainingArguments(
    output_dir="./qwen3vl-bone-fracture",
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=per_device_batch_size,
    per_device_eval_batch_size=per_device_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    learning_rate=learning_rate,
    weight_decay=0.01,
    warmup_ratio=0.1,
    logging_steps=10,
    eval_strategy="steps",  # FIXED: evaluation_strategy -> eval_strategy
    eval_steps=50,  # Evaluate every 50 steps
    save_strategy="steps",
    save_steps=100,  # Save checkpoint every 100 steps
    save_total_limit=3,
    fp16=True,  # Use mixed precision
    bf16=False,
    dataloader_num_workers=4,
    report_to="none",
    logging_first_step=True,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    optim="paged_adamw_8bit",  # Memory-efficient optimizer
    lr_scheduler_type="cosine",
    seed=42,
    remove_unused_columns=False,  # Important for custom collator
)

print("‚úÖ Training configuration created successfully!")
print(f"Training parameters:")
print(f"  Batch size per device: {per_device_batch_size}")
print(f"  Gradient accumulation steps: {gradient_accumulation_steps}")
print(f"  Effective batch size: {per_device_batch_size * gradient_accumulation_steps}")
print(f"  Learning rate: {learning_rate}")
print(f"  Epochs: {num_train_epochs}")

‚úÖ Data collator created successfully!
‚úÖ Training configuration created successfully!
Training parameters:
  Batch size per device: 1
  Gradient accumulation steps: 8
  Effective batch size: 8
  Learning rate: 2e-05
  Epochs: 2


In [None]:
# More robust dataset formatting with explicit feature definitions
from datasets import Features, Sequence, Value, Image as ImageFeature

def format_for_unsloth_safe(sample):
    """Safely format samples with error handling for inconsistent data types"""
    try:
        # Safely get text input
        text_input = str(sample.get("text_input", "")).strip()
        
        # Safely get target output - ensure it's a valid JSON string
        target_output = sample.get("target_output", "")
        if not isinstance(target_output, str):
            target_output = json.dumps({})
        
        # Create messages format
        messages = [
            {
                "role": "user", 
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": text_input}
                ]
            },
            {
                "role": "assistant",
                "content": target_output
            }
        ]
        
        # Safely load image
        try:
            image_path = sample.get("image_path", "")
            if not image_path or not os.path.exists(image_path):
                # Use a fallback blank image if path is invalid
                from PIL import Image
                image = Image.new('RGB', (448, 448), color='white')
            else:
                image = Image.open(image_path).convert("RGB")
        except Exception as e:
            print(f"‚ö†Ô∏è Error loading image {image_path}: {e}")
            from PIL import Image
            image = Image.new('RGB', (448, 448), color='white')
        
        return {
            "messages": messages,
            "image": image
        }
    except Exception as e:
        print(f"‚ùå Error formatting sample: {e}")
        # Return a minimal valid sample as fallback
        from PIL import Image
        default_image = Image.new('RGB', (448, 448), color='white')
        return {
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "image"},
                        {"type": "text", "text": "Analyze this X-ray image for fractures."}
                    ]
                },
                {
                    "role": "assistant",
                    "content": "{}"
                }
            ],
            "image": default_image
        }

# Define explicit features to avoid Arrow serialization issues
features = Features({
    "messages": Sequence({
        "role": Value("string"),
        "content": Value("string"),
    }),
    "image": ImageFeature(),
})

# Process dataset in smaller batches with progress tracking
print("\nüîÑ Reformatting dataset with robust error handling...")
batch_size = 1000
total_train = len(dataset["train"])
total_eval = len(dataset["validation"])
formatted_train_samples = []
formatted_eval_samples = []

# Process training data in batches
for i in range(0, total_train, batch_size):
    batch = dataset["train"][i:min(i+batch_size, total_train)]
    for j in range(len(batch["text_input"])):
        sample = {key: batch[key][j] for key in batch.keys()}
        formatted = format_for_unsloth_safe(sample)
        formatted_train_samples.append(formatted)
    print(f"  Processed training samples {i}/{total_train}")

# Process validation data in batches
for i in range(0, total_eval, batch_size):
    batch = dataset["validation"][i:min(i+batch_size, total_eval)]
    for j in range(len(batch["text_input"])):
        sample = {key: batch[key][j] for key in batch.keys()}
        formatted = format_for_unsloth_safe(sample)
        formatted_eval_samples.append(formatted)
    print(f"  Processed validation samples {i}/{total_eval}")

# Create datasets with explicit features
print("\nüîß Creating dataset objects with explicit features...")
from datasets import Dataset

formatted_train = Dataset.from_list(formatted_train_samples, features=features)
formatted_eval = Dataset.from_list(formatted_eval_samples, features=features)

print(f"‚úÖ Training samples reformatted: {len(formatted_train)}")
print(f"‚úÖ Validation samples reformatted: {len(formatted_eval)}")

# Test with a small subset first before full training
print("\nüîç Testing data collator with a single sample...")
test_batch = [formatted_train[0]]
try:
    test_inputs = data_collator(test_batch)
    print("‚úÖ Data collator test passed!")
    print(f"  Input IDs shape: {test_inputs['input_ids'].shape}")
    print(f"  Pixel values shape: {test_inputs['pixel_values'].shape}")
except Exception as e:
    print(f"‚ùå Data collator test failed: {e}")
    # Fallback values if test fails
    print("‚ö†Ô∏è Using fallback configuration...")