In [1]:
%%capture
# Install Unsloth and its pinned dependencies. `%%capture` suppresses the cell's
# (long) pip output — note it also hides install errors; remove it when debugging.
import os, re
# Treat the runtime as Colab when any environment-variable name contains "COLAB_".
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    # Extract "major.minor" from the preinstalled torch version (e.g. "2.9") so a
    # matching xformers release can be chosen below.
    import torch; v = re.match(r"[0-9]{1,}\.[0-9]{1,}", str(torch.__version__)).group(0)
    # torch 2.9 -> xformers 0.0.33.post1, torch 2.8 -> 0.0.32.post2, anything else -> 0.0.29.post3.
    xformers = "xformers==" + ("0.0.33.post1" if v=="2.9" else "0.0.32.post2" if v=="2.8" else "0.0.29.post3")
    # NOTE(review): --no-deps presumably keeps pip from replacing Colab's preinstalled
    # torch/CUDA stack — confirm before dropping the flag.
    !pip install --no-deps bitsandbytes accelerate {xformers} peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets==4.3.0" "huggingface_hub>=0.34.0" hf_transfer
    !pip install --no-deps unsloth
# These two pins run in both branches (they are outside the if/else), overriding
# whatever transformers/trl versions the installs above pulled in.
!pip install transformers==4.56.2
!pip install --no-deps trl==0.22.2

In [2]:
from unsloth import FastVisionModel
from transformers import TextStreamer
from PIL import Image
import torch

# Load the fine-tuned model
def load_model(model_path="/content/drive/MyDrive/MED-MCP/TRAIN-VLM/MED_Llama3.2_11B_VL-lora_model",
               load_in_4bit=True):
    """
    Load the fine-tuned medical vision model and switch it to inference mode.

    Args:
        model_path: Path to the saved LoRA model
        load_in_4bit: Load the weights 4-bit quantized (default). Pass False to
            load in 16-bit instead.

    Returns:
        model, tokenizer — as returned by FastVisionModel.from_pretrained.
        The model has already been passed through FastVisionModel.for_inference.
    """
    print("Loading model...")
    model, tokenizer = FastVisionModel.from_pretrained(
        model_name=model_path,
        load_in_4bit=load_in_4bit,
    )
    # Unsloth's switch from training to generation mode; must be called before generate().
    FastVisionModel.for_inference(model)
    print("Model loaded successfully!")
    return model, tokenizer


def analyze_medical_image(image_path, model, tokenizer,
                         instruction="You are an expert radiographer. Describe accurately what you see in this image.",
                         max_tokens=256,
                         temperature=1.5,
                         min_p=0.1,
                         stream_output=True,
                         device="cuda"):
    """
    Analyze a medical image using the fine-tuned vision model.

    Args:
        image_path: Path to the image file (supports PNG, JPG, JPEG, etc.)
        model: The loaded vision model
        tokenizer: The model tokenizer/processor (called with an image and text)
        instruction: Custom instruction for the model
        max_tokens: Maximum number of tokens to generate
        temperature: Sampling temperature (higher = more creative)
        min_p: Minimum probability for sampling
        stream_output: Whether to stream output token by token while generating.
            The decoded text is returned in both modes.
        device: Device the tokenized inputs are moved to (default "cuda")

    Returns:
        Generated text description of the image (str), or None if the image
        file was not found or inference raised an error.
    """
    try:
        # Load the image. The context manager closes the file handle; convert()
        # reads the pixel data before the file is closed and normalizes
        # grayscale/RGBA files to 3-channel RGB.
        print(f"Loading image from: {image_path}")
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        # Prepare the messages in the required format
        messages = [
            {"role": "user", "content": [
                {"type": "image"},
                {"type": "text", "text": instruction}
            ]}
        ]

        # Apply chat template
        input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

        # Tokenize inputs. add_special_tokens=False presumably avoids a duplicate BOS,
        # because the chat template already emits the special tokens — confirm per model.
        inputs = tokenizer(
            image,
            input_text,
            add_special_tokens=False,
            return_tensors="pt",
        ).to(device)

        print("\nGenerating analysis...\n")

        # Streaming only affects how tokens are displayed while generating; a single
        # generate() call serves both modes. skip_special_tokens hides markers such
        # as <|eot_id|> in the streamed text.
        text_streamer = (
            TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
            if stream_output else None
        )
        output = model.generate(
            **inputs,
            streamer=text_streamer,
            max_new_tokens=max_tokens,
            use_cache=True,
            # Explicit so temperature/min_p take effect regardless of the
            # checkpoint's generation_config default.
            do_sample=True,
            temperature=temperature,
            min_p=min_p,
        )

        # Decode only the newly generated tokens. Slicing off the prompt is robust,
        # unlike splitting on the word "assistant", which may also occur in the answer.
        prompt_len = inputs["input_ids"].shape[-1]
        generated_text = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True).strip()
        if not stream_output:
            # The streamer already printed the text in streaming mode.
            print(generated_text)
        return generated_text

    except FileNotFoundError:
        print(f"Error: Image file not found at {image_path}")
        return None
    except Exception as e:
        # Best-effort by design: report and return None so batch runs can continue.
        print(f"Error during inference: {str(e)}")
        return None


def batch_analyze_images(image_paths, model, tokenizer, instruction=None, **generation_kwargs):
    """
    Analyze multiple medical images one after another.

    Args:
        image_paths: Iterable of paths to image files (lists, tuples, generators)
        model: The loaded vision model
        tokenizer: The model tokenizer
        instruction: Custom instruction (optional). When omitted or empty, the
            default instruction of analyze_medical_image is used.
        **generation_kwargs: Extra keyword arguments forwarded to
            analyze_medical_image (e.g. max_tokens, temperature, min_p,
            stream_output). stream_output defaults to True.

    Returns:
        List of {"image": path, "analysis": text-or-None} dicts, in input order.
    """
    # Materialize so len() works for the progress header even with generators.
    image_paths = list(image_paths)

    call_kwargs = {"stream_output": True, **generation_kwargs}
    if instruction:
        # Only override when given, so analyze_medical_image's default stays the
        # single source of truth for the default instruction.
        call_kwargs["instruction"] = instruction

    results = []
    for i, image_path in enumerate(image_paths, 1):
        print(f"\n{'='*60}")
        print(f"Analyzing image {i}/{len(image_paths)}: {image_path}")
        print('='*60)

        result = analyze_medical_image(image_path, model, tokenizer, **call_kwargs)
        results.append({"image": image_path, "analysis": result})

    return results


# Example usage
if __name__ == "__main__":
    # Demo configuration — edit these to point at your own model and images.
    MODEL_PATH = "/content/drive/MyDrive/MED-MCP/TRAIN-VLM/MED_Llama3.2_11B_VL-lora_model"
    INSTRUCTION = "You are an expert radiographer. Describe accurately what you see in this image."
    # Set to True and fill in BATCH_IMAGE_PATHS to also run the multi-image example.
    RUN_BATCH_EXAMPLE = False
    BATCH_IMAGE_PATHS = [
        "path/to/image1.png",
        "path/to/image2.png",
        "path/to/image3.png",
    ]

    # Load the model once and reuse it for every analysis below.
    model, tokenizer = load_model(MODEL_PATH)

    # Single image analysis
    print("\n" + "="*60)
    print("SINGLE IMAGE ANALYSIS")
    print("="*60)

    image_path = "/content/drive/MyDrive/MED-MCP/TRAIN-VLM/extracted_images/ROCOv2_2023_test_000056.png"
    analysis = analyze_medical_image(
        image_path=image_path,
        model=model,
        tokenizer=tokenizer,
        instruction=INSTRUCTION,
        max_tokens=256,
        stream_output=True
    )

    # Batch analysis (multiple images), guarded by a flag instead of commented-out code.
    if RUN_BATCH_EXAMPLE:
        print("\n" + "="*60)
        print("BATCH IMAGE ANALYSIS")
        print("="*60)

        results = batch_analyze_images(BATCH_IMAGE_PATHS, model, tokenizer)

        # Print summary
        print("\n" + "="*60)
        print("ANALYSIS SUMMARY")
        print("="*60)
        for i, result in enumerate(results, 1):
            print(f"\nImage {i}: {result['image']}")
            # 'analysis' may be None when loading/inference failed; guard before slicing.
            print(f"Analysis: {(result['analysis'] or '')[:100]}...")  # First 100 chars

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
Loading model...
==((====))==  Unsloth 2025.12.8: Fast Mllama patching. Transformers: 4.56.2.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.9.0+cu126. CUDA: 7.5. CUDA Toolkit: 12.6. Triton: 3.5.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.33.post1. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


model.safetensors.index.json: 0.00B [00:00, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/2.94G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/210 [00:00<?, ?B/s]

Model loaded successfully!

SINGLE IMAGE ANALYSIS
Loading image from: /content/drive/MyDrive/MED-MCP/TRAIN-VLM/extracted_images/ROCOv2_2023_test_000056.png

Generating analysis...

CT chest. There is an irregular round mass in right lower lobe. There is consolidation in right upper lobe and lung parenchymal hemorrhage is present.<|eot_id|>
