In [None]:
# Print the pip command that installs the Unsloth extra matching this
# environment's torch version, CUDA build and GPU generation.
try:
    import torch
except ImportError as err:
    # Only a genuinely missing package should be reported as "install torch";
    # any other failure while importing torch propagates with its own message.
    raise ImportError('Install torch via `pip install torch`') from err
from packaging.version import Version as V

v = V(torch.__version__)
cuda = str(torch.version.cuda)

# Validate the CUDA build before touching the device: this way an unsupported
# build is reported as such instead of failing inside the capability query.
if cuda not in ("11.8", "12.1", "12.4"):
    raise RuntimeError(f"CUDA = {cuda} not supported!")

# Compute capability major version >= 8 selects the "-ampere" variant below.
is_ampere = torch.cuda.get_device_capability()[0] >= 8

# Map the installed torch version onto the extra's torch tag. The two `{}`
# slots are filled with the CUDA digits and the optional "-ampere" suffix.
if   v <= V('2.1.0'): raise RuntimeError(f"Torch = {v} too old!")
elif v <= V('2.1.1'): x = 'cu{}{}-torch211'
elif v <= V('2.1.2'): x = 'cu{}{}-torch212'
elif v  < V('2.3.0'): x = 'cu{}{}-torch220'
elif v  < V('2.4.0'): x = 'cu{}{}-torch230'
elif v  < V('2.5.0'): x = 'cu{}{}-torch240'
elif v  < V('2.6.0'): x = 'cu{}{}-torch250'
else: raise RuntimeError(f"Torch = {v} too new!")
x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")
print(f'pip install --upgrade pip && pip install "unsloth[{x}] @ git+https://github.com/unslothai/unsloth.git"')

In [None]:
pip install --upgrade pip && pip install "unsloth[cu124-torch250] @ git+https://github.com/unslothai/unsloth.git"

In [None]:
import os
import pandas as pd
from PIL import Image
import torch
import matplotlib.pyplot as plt
from unsloth import FastVisionModel
from datasets import Dataset, Image as HFImage
from transformers import TextStreamer, BitsAndBytesConfig
from unsloth import is_bf16_supported
from unsloth.trainer import UnslothVisionDataCollator
from trl import SFTTrainer, SFTConfig

In [None]:
# Training split: caption CSV and the folder holding the images.
TRAIN_CSV_PATH = "/kaggle/input/train-captions/train_captions.csv" # CSV with 'ID' and 'Caption' columns (the names create_hf_dataset reads)
# Images are looked up in this folder as "<ID><extension>".
TRAIN_IMAGE_DIR = "/kaggle/input/image-clef2025/train/home/damm/clef/data/2025/splits/train/"

In [None]:
# Validation split: caption CSV and the folder holding the images.
VAL_CSV_PATH = "/kaggle/input/image-clef2025/valid_captions.csv" # CSV with 'ID' and 'Caption' columns (the names create_hf_dataset reads)
# Images are looked up in this folder as "<ID><extension>".
VAL_IMAGE_DIR = "/kaggle/input/image-clef2025/valid/home/damm/clef/data/2025/splits/valid/"

In [None]:
# Test split has no caption CSV; the inference cell uses each file's name
# (without extension) as the submission ID.
TEST_IMAGE_DIR = "/kaggle/input/image-clef2025/test/test/" # Folder with test images named by ID

In [None]:
# Output artifacts written by this notebook.
OUTPUT_MODEL_DIR = "/kaggle/working/finetuned_radiology_model"  # trainer output_dir; also where the model is saved and later reloaded from
SUBMISSION_FILE = "/kaggle/working/submission.csv"  # CSV of predicted captions for the test images
LOSS_PLOT_FILE = "/kaggle/working/loss_plot.png"  # saved training/validation loss figure

In [None]:
# Model and training hyperparameters, gathered in one place for tuning.
MODEL_NAME = "unsloth/Qwen2-VL-2B-Instruct-bnb-4bit"  # base checkpoint passed to FastVisionModel.from_pretrained
MAX_SEQ_LENGTH = 2048  # passed to both model loading and SFTConfig
LORA_R = 16  # LoRA rank (`r`)
LORA_ALPHA = 16  # LoRA scaling (`lora_alpha`)
TRAIN_BATCH_SIZE = 2  # per-device train batch size
GRAD_ACCUMULATION_STEPS = 4  # gradient accumulation steps
MAX_STEPS = 60  # total optimizer steps (`max_steps`)
EVAL_STEPS = 10  # intended evaluation interval in steps — NOTE(review): confirm it is wired into the trainer config
LEARNING_RATE = 2e-4
SEED = 3407  # shared seed for dataset shuffling, LoRA init and the trainer

In [None]:
# Text prompt paired with every image, both in the training conversations and
# in the inference prompt.
INSTRUCTION = "You are an expert radiographer. Describe accurately what you see in this image."

In [None]:
def create_hf_dataset(csv_path, image_folder_path, dataset_type="train"):
    """Build a Hugging Face ``Dataset`` with ``image``, ``Caption`` and ``ID`` columns.

    The CSV must provide ``ID`` and ``Caption`` columns. For each row the image
    is looked up in ``image_folder_path`` as ``<ID><ext>`` for a fixed list of
    extensions; the first existing file wins. Rows without an image file or
    without a caption are skipped with a warning.

    Args:
        csv_path: Path to the caption CSV.
        image_folder_path: Directory containing the image files.
        dataset_type: Label used only in log messages (e.g. "train").

    Returns:
        The dataset, with the path column cast to a decoded image feature and
        renamed to ``image``; or ``None`` when the CSV/folder is missing, the
        CSV cannot be read or lacks the required columns, or no usable sample
        was found.
    """
    if not os.path.exists(csv_path):
        print(f"Warning: {dataset_type} CSV {csv_path} not found. Skipping {dataset_type} dataset.")
        return None
    if not os.path.exists(image_folder_path):
        print(f"Warning: {dataset_type} image folder {image_folder_path} not found. Skipping {dataset_type} dataset.")
        return None

    try:
        df = pd.read_csv(csv_path)
    except Exception as e:
        print(f"Error reading {dataset_type} CSV {csv_path}: {e}. Skipping {dataset_type} dataset.")
        return None

    # Report a schema problem the same way as every other load failure instead
    # of letting a KeyError escape from the row loop.
    missing_columns = [name for name in ("ID", "Caption") if name not in df.columns]
    if missing_columns:
        print(f"Error: {dataset_type} CSV {csv_path} is missing required column(s) {missing_columns}. Skipping {dataset_type} dataset.")
        return None

    extensions_to_check = ['.jpg', '.png', '.jpeg', '.bmp', '.tiff', '.dcm']
    image_paths, captions, ids = [], [], []
    print(f"Loading {dataset_type} dataset from {csv_path} and {image_folder_path}...")

    # Iterate the two columns directly: unlike iterrows() this keeps each
    # column's own dtype, so integer IDs are not upcast to floats ("12" -> "12.0").
    for raw_id, raw_caption in zip(df["ID"], df["Caption"]):
        image_id = str(raw_id)

        # A missing caption would otherwise become the literal text "nan".
        if pd.isna(raw_caption):
            print(f"Warning: Missing caption for ID {image_id}; skipping sample.")
            continue

        candidates = (os.path.join(image_folder_path, image_id + ext) for ext in extensions_to_check)
        image_path = next((path for path in candidates if os.path.exists(path)), None)
        if image_path is None:
            print(f"Warning: Image for ID {image_id} not found in {image_folder_path} with extensions {extensions_to_check}")
            continue

        image_paths.append(image_path)
        captions.append(str(raw_caption))
        ids.append(image_id)

    if not image_paths:
        print(f"Warning: No images found for {dataset_type} dataset. Skipping.")
        return None

    # Start from plain path strings, then cast so the column decodes to images
    # on access, and finally expose it under the name 'image'.
    hf_dataset = Dataset.from_dict({"image_path_str": image_paths, "Caption": captions, "ID": ids})
    hf_dataset = hf_dataset.cast_column("image_path_str", HFImage(decode=True))
    hf_dataset = hf_dataset.rename_column("image_path_str", "image")

    print(f"Loaded and cast {len(hf_dataset)} samples for {dataset_type} dataset. Image column is now 'image'.")
    return hf_dataset

In [None]:
# Build the raw train and validation datasets; either may be None when its
# CSV/folder is missing or no image could be matched.
train_dataset_raw = create_hf_dataset(
    csv_path=TRAIN_CSV_PATH,
    image_folder_path=TRAIN_IMAGE_DIR,
    dataset_type="train",
)
val_dataset_raw = create_hf_dataset(
    csv_path=VAL_CSV_PATH,
    image_folder_path=VAL_IMAGE_DIR,
    dataset_type="validation",
)

In [None]:
# Sanity check: inspect up to five training samples and report the type of
# the decoded 'image' value for each.
if train_dataset_raw:
    print("\nDebug: Checking train_dataset_raw immediately after cast_column (first 5 samples):")
    preview_count = min(5, len(train_dataset_raw))
    for idx in range(preview_count):
        sample_check = train_dataset_raw[idx]
        sample_id = sample_check.get('ID', 'N/A')
        sample_image = sample_check.get('image')
        print(f"  ID: {sample_id}, Image Type: {type(sample_image)}, Image is None: {sample_image is None}")
        # Something is present but it is not a PIL image.
        if sample_image is not None and not isinstance(sample_image, Image.Image):
            print(f"  WARNING: Non-PIL Image object for ID {sample_id}: {sample_image}")

In [None]:
# Keep a seeded random half of the training data.
# Note: this reassigns train_dataset_raw, so re-running the cell halves it again.
if train_dataset_raw:
    num_train_samples = len(train_dataset_raw)
    num_samples_to_keep = int(0.5 * num_train_samples)
    kept_indices = range(num_samples_to_keep)
    train_dataset_raw = train_dataset_raw.shuffle(seed=SEED).select(kept_indices)
    print(f"Subsampled training dataset to {len(train_dataset_raw)} samples (50%).")

In [None]:
def is_image_valid(example):
    """Return True when ``example['image']`` is a PIL image, logging rejects.

    Intended as a ``Dataset.filter`` predicate: samples whose 'image' value is
    missing/None or not a ``PIL.Image.Image`` are reported and rejected.
    """
    candidate = example.get('image')
    if candidate is None or not isinstance(candidate, Image.Image):
        print(f"FILTERING OUT: ID: {example.get('ID', 'Unknown ID')}, Image Type: {type(candidate)}, Image is None: {candidate is None}")
        return False
    return True

# Drop training samples whose image is not a PIL image. Training cannot
# continue without data, so an empty result stops the run with an error
# (raising keeps the kernel alive for inspection, unlike exit()).
if train_dataset_raw:
    initial_count_train = len(train_dataset_raw)
    print(f"Train dataset before filtering: {initial_count_train} samples.")
    train_dataset_raw = train_dataset_raw.filter(is_image_valid, load_from_cache_file=False)
    filtered_count_train = len(train_dataset_raw)
    print(f"Train dataset after filtering: {filtered_count_train} samples. Filtered out: {initial_count_train - filtered_count_train}")
    if filtered_count_train == 0:
        raise RuntimeError("Training dataset is empty after filtering invalid images.")

# Re-inspect the first few surviving samples to confirm the filter worked.
if train_dataset_raw:
    print("\nDebug: Checking train_dataset_raw after filter (first 5 samples):")
    for i in range(min(5, len(train_dataset_raw))):
        sample_check = train_dataset_raw[i]
        print(f"  ID: {sample_check.get('ID', 'N/A')}, Image Type: {type(sample_check.get('image'))}, Image is None: {sample_check.get('image') is None}")
        if sample_check.get('image') is None or not isinstance(sample_check.get('image'), Image.Image):
             print(f"  CRITICAL WARNING: Invalid image for ID {sample_check.get('ID', 'N/A')} RETAINED after filter.")

# Same filtering for validation; an empty validation set is not fatal, the
# run simply continues without one.
if val_dataset_raw:
    initial_count_val = len(val_dataset_raw)
    print(f"Validation dataset before filtering: {initial_count_val} samples.")
    val_dataset_raw = val_dataset_raw.filter(is_image_valid, load_from_cache_file=False)
    filtered_count_val = len(val_dataset_raw)
    print(f"Validation dataset after filtering: {filtered_count_val} samples. Filtered out: {initial_count_val - filtered_count_val}")
    if filtered_count_val == 0:
        print("Validation dataset is empty after filtering invalid images. Will proceed without validation.")
        val_dataset_raw = None

In [None]:
def convert_to_conversation(sample):
    """Turn one dataset row into a two-turn chat record.

    The user turn carries the shared INSTRUCTION text followed by the row's
    image; the assistant turn carries the row's caption. An image that is
    missing or not a PIL image is reported but the record is still built.

    Returns:
        ``{"messages": [user_turn, assistant_turn]}``.
    """
    image = sample.get('image')
    if image is None or not isinstance(image, Image.Image):
        print(f"CRITICAL in convert_to_conversation: ID {sample.get('ID', 'Unknown ID')} has invalid image type: {type(image)}")

    user_turn = {
        "role": "user",
        "content": [
            {"type": "text", "text": INSTRUCTION},
            # sample["image"] must be a PIL.Image object
            {"type": "image", "image": sample["image"]},
        ],
    }
    assistant_turn = {
        "role": "assistant",
        "content": [{"type": "text", "text": sample["Caption"]}],
    }
    return {"messages": [user_turn, assistant_turn]}

In [None]:
# Training cannot proceed without data. Raise instead of calling exit():
# in a notebook exit() shuts the kernel down, discarding all state, whereas an
# exception stops "Run All" and leaves the session available for debugging.
if not train_dataset_raw:
    raise RuntimeError("Training dataset could not be loaded or was empty after filtering.")

# Replace the raw columns with a single 'messages' column in chat format.
converted_train_dataset = train_dataset_raw.map(
    convert_to_conversation,
    batched=False,
    remove_columns=train_dataset_raw.column_names,
    load_from_cache_file=False
)
print(f"Training dataset converted to conversation format. Num samples: {len(converted_train_dataset)}")


In [None]:
# Convert the validation split to chat format when one is available;
# otherwise leave converted_val_dataset as None.
if not val_dataset_raw:
    converted_val_dataset = None
    print("No valid validation dataset available.")
else:
    raw_val_columns = val_dataset_raw.column_names
    converted_val_dataset = val_dataset_raw.map(
        convert_to_conversation,
        batched=False,
        remove_columns=raw_val_columns,
        load_from_cache_file=False,
    )
    print(f"Validation dataset converted to conversation format. Num samples: {len(converted_val_dataset)}")


In [None]:
# Load the base checkpoint named by MODEL_NAME through Unsloth, in 4-bit.
# The second return value is bound to `tokenizer`; later cells call it with
# both `images=` and `text=`, so it handles image as well as text inputs.
model, tokenizer = FastVisionModel.from_pretrained(
    MODEL_NAME,
    load_in_4bit=True,
    use_gradient_checkpointing="unsloth",
    max_seq_length=MAX_SEQ_LENGTH,
)

In [None]:
# Wrap the loaded model with PEFT (LoRA) adapters using rank LORA_R and
# scaling LORA_ALPHA, seeded with SEED. Per the flag names, adapters are
# requested for the language layers (attention and MLP modules) and not for
# the vision layers — NOTE(review): confirm against the Unsloth docs.
model = FastVisionModel.get_peft_model(
    model,
    finetune_vision_layers=False,
    finetune_language_layers=True,
    finetune_attention_modules=True,
    finetune_mlp_modules=True,
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=0,
    bias="none",
    random_state=SEED,
    use_rslora=False,
    loftq_config=None,
)
print("Model and tokenizer loaded with PEFT.")


In [None]:
# Put the model into Unsloth's training mode before building the trainer.
FastVisionModel.for_training(model)

# Periodic evaluation is only requested when a validation split exists;
# asking for step-wise evaluation without an eval dataset is an error.
_has_validation = converted_val_dataset is not None

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    data_collator=UnslothVisionDataCollator(model, tokenizer), # Must use!
    train_dataset=converted_train_dataset,
    eval_dataset=converted_val_dataset,
    args=SFTConfig(
        per_device_train_batch_size=TRAIN_BATCH_SIZE,
        # Match the train batch size so evaluation has a comparable memory footprint.
        per_device_eval_batch_size=TRAIN_BATCH_SIZE,
        gradient_accumulation_steps=GRAD_ACCUMULATION_STEPS,
        warmup_steps=5,
        max_steps=MAX_STEPS,
        learning_rate=LEARNING_RATE,
        # Use bf16 where the GPU supports it, otherwise fall back to fp16.
        fp16=not is_bf16_supported(),
        bf16=is_bf16_supported(),
        logging_steps=1,
        # Evaluate every EVAL_STEPS steps so 'eval_loss' entries are logged for
        # the loss plot. Each evaluation runs over the whole validation split;
        # subsample converted_val_dataset beforehand if that is too slow.
        eval_strategy="steps" if _has_validation else "no",
        eval_steps=EVAL_STEPS if _has_validation else None,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="linear",
        seed=SEED,
        output_dir=OUTPUT_MODEL_DIR,
        report_to="none",
        # Required for vision fine-tuning with Unsloth: the dataset holds chat
        # messages with images, so the trainer must not drop the 'messages'
        # column or try to tokenize a text field itself; the data collator
        # does all of the preprocessing.
        remove_unused_columns=False,
        dataset_text_field="",
        dataset_kwargs={"skip_prepare_dataset": True},
        dataset_num_proc=4,
        max_seq_length=MAX_SEQ_LENGTH,
    ),
)

In [None]:
# Run fine-tuning. On failure, print a pointer to the earlier diagnostics and
# re-raise: continuing would let the following cells save and run inference
# with a model that was never trained.
print("Starting training...")
try:
    trainer_stats = trainer.train()
except Exception as e:
    print(f"Error during training: {e}")
    print("Inspect the debug prints above, especially any 'FILTERING OUT' or 'CRITICAL' messages.")
    raise  # the notebook shows the full traceback for the re-raised error
print("Training finished.")

In [None]:
# Save the trained model and the tokenizer to OUTPUT_MODEL_DIR; the inference
# cells later reload the model from this same path.
print(f"Saving model to {OUTPUT_MODEL_DIR}...")
trainer.save_model(OUTPUT_MODEL_DIR)
tokenizer.save_pretrained(OUTPUT_MODEL_DIR)
print("Model and tokenizer saved.")


In [None]:
# Split the trainer's log history into training and evaluation loss series.
log_history = trainer.state.log_history

train_steps, train_losses = [], []
eval_steps_log, eval_losses = [], []

for log_entry in log_history:
    if 'eval_loss' in log_entry:
        # Evaluation log
        eval_steps_log.append(log_entry['step'])
        eval_losses.append(log_entry['eval_loss'])
    elif 'loss' in log_entry:
        # Training log (has 'loss' and no 'eval_loss')
        train_steps.append(log_entry['step'])
        train_losses.append(log_entry['loss'])


In [None]:
# Plot training loss (and validation loss when any was logged) against the
# step count, save the figure to LOSS_PLOT_FILE, then display it.
fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(train_steps, train_losses, label='Training Loss', marker='o', linestyle='-')
if eval_steps_log and eval_losses:
    ax.plot(eval_steps_log, eval_losses, label='Validation Loss', marker='x', linestyle='--')
ax.set_xlabel('Steps')
ax.set_ylabel('Loss')
ax.set_title('Training and Validation Loss Over Steps')
ax.legend()
ax.grid(True)
fig.savefig(LOSS_PLOT_FILE)
print(f"Loss plot saved to {LOSS_PLOT_FILE}")
plt.show()


In [None]:
# Progress marker: everything below reloads the saved model and captions the test images.
print("Starting inference on test set...")

In [None]:
# Release the training-time model and trainer before reloading for inference.
import gc

del model
del trainer
# `del` only removes the names. Collect garbage so objects kept alive by
# reference cycles are actually destroyed; only then can empty_cache() hand
# their cached GPU memory back.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Reload model and tokenizer from OUTPUT_MODEL_DIR (where the trained model
# was saved), again in 4-bit, rebinding the `model` and `tokenizer` names.
model, tokenizer = FastVisionModel.from_pretrained(
    OUTPUT_MODEL_DIR,
    load_in_4bit=True,
    random_state=SEED,
    max_seq_length=MAX_SEQ_LENGTH,
)

In [None]:
# Switch the reloaded model into Unsloth's inference mode before generating.
FastVisionModel.for_inference(model) 
print("Fine-tuned model loaded for inference.")

In [None]:
# Caption every image in TEST_IMAGE_DIR with the fine-tuned model and write
# the predictions to SUBMISSION_FILE as a CSV with 'ID' and 'Caption' columns.
if not os.path.exists(TEST_IMAGE_DIR):
    print(f"Test image directory {TEST_IMAGE_DIR} not found. Skipping inference.")
else:
    # Common image extensions; matching is case-insensitive.
    image_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff', '.dcm')
    # Sort so the submission row order is reproducible: os.listdir returns
    # entries in arbitrary order.
    test_image_files = sorted(
        full_path
        for full_path in (os.path.join(TEST_IMAGE_DIR, f_name) for f_name in os.listdir(TEST_IMAGE_DIR))
        if os.path.isfile(full_path) and full_path.lower().endswith(image_extensions)
    )

    if not test_image_files:
        print(f"No image files found in {TEST_IMAGE_DIR}. Skipping inference.")
    else:
        print(f"Found {len(test_image_files)} images for inference.")

        # The prompt is identical for every image, so build it once. Content
        # order (text, then image) mirrors the training conversations built by
        # convert_to_conversation, so the model sees the layout it was tuned on.
        messages_inference = [
            {"role": "user", "content": [
                {"type": "text", "text": INSTRUCTION},
                {"type": "image"}, # Placeholder for the image
            ]}
        ]
        text_prompt = tokenizer.apply_chat_template(
            messages_inference,
            tokenize=False, # Get the string prompt
            add_generation_prompt=True
        )

        results = []
        for image_path in test_image_files:
            # The file name without its extension is the submission ID.
            image_id = os.path.splitext(os.path.basename(image_path))[0]
            try:
                # convert() returns a new image, so the context manager can
                # close the underlying file as soon as the RGB copy exists.
                with Image.open(image_path) as raw_image:
                    pil_image = raw_image.convert("RGB") # Ensure RGB
            except Exception as e:
                # Keep one row per test image even when the file is unreadable.
                print(f"Could not load test image {image_path}: {e}")
                results.append({"ID": image_id, "Caption": f"Error loading image: {e}"})
                continue

            try:
                inputs = tokenizer(
                    images=pil_image,
                    text=text_prompt,
                    return_tensors="pt",
                    add_special_tokens=False,
                ).to("cuda")

                text_streamer = TextStreamer(tokenizer, skip_prompt=True)
                print(f"Generating caption for {image_id}...")
                with torch.no_grad():
                    outputs = model.generate(
                        **inputs,
                        streamer=text_streamer,
                        max_new_tokens=512,
                        use_cache=True,
                        temperature=0.7,
                        min_p=0.1
                    )

                # generate() returns prompt + continuation; keep only the new tokens.
                input_ids_len = inputs.input_ids.shape[1]
                generated_ids = outputs[0][input_ids_len:]
                predicted_caption = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
            finally:
                # Release the image even if preprocessing or generation fails.
                pil_image.close()

            print(f"Predicted for {image_id}: {predicted_caption}")
            results.append({"ID": image_id, "Caption": predicted_caption})

        submission_df = pd.DataFrame(results)
        submission_df.to_csv(SUBMISSION_FILE, index=False)
        print(f"Submission file created: {SUBMISSION_FILE}")

print("Script finished.")