In [1]:
import os
import shutil

import pandas as pd
import torch
from datasets import Dataset, DatasetDict, concatenate_datasets, load_from_disk
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration, Trainer, TrainingArguments

In [2]:
class ROCODataLoader:
    """
    Turns a pair of caption CSV files (training and validation) into Hugging Face datasets
    whose rows also carry the on-disk path of the matching image.
    """

    def __init__(self, image_dir_train, image_dir_val, annotation_file_train, annotation_file_val):
        """
        Remember where the images and annotation files of both splits live.
        :param image_dir_train: Folder holding the training images.
        :param image_dir_val: Folder holding the validation images.
        :param annotation_file_train: CSV file with the training annotations.
        :param annotation_file_val: CSV file with the validation annotations.
        """
        self.image_dir_train, self.image_dir_val = image_dir_train, image_dir_val
        self.annotation_file_train, self.annotation_file_val = annotation_file_train, annotation_file_val

    def load_data(self):
        """
        Read both annotation CSVs, attach an 'image_path' column, and wrap the result.
        Each CSV must provide an 'ID' column; the image file name is that ID plus ".jpg".
        :return: A DatasetDict with the keys 'train' and 'validation'.
        """
        # Both CSV files are read up front, training first.
        frames = {
            'train': (pd.read_csv(self.annotation_file_train), self.image_dir_train),
            'validation': (pd.read_csv(self.annotation_file_val), self.image_dir_val),
        }

        # First pass: derive the full image path of every row from its ID.
        for frame, image_root in frames.values():
            def to_image_path(image_id, root=image_root):
                return os.path.join(root, image_id + ".jpg")

            frame['image_path'] = frame['ID'].apply(to_image_path)

        # Second pass: convert each pandas frame into a Hugging Face Dataset.
        return DatasetDict({
            split_name: Dataset.from_pandas(frame)
            for split_name, (frame, _) in frames.items()
        })

In [3]:
def preprocess_data(examples, processor, max_length=512):
    """
    Preprocess a batch by loading its images and running them, together with their captions,
    through the BLIP processor.

    Rows whose image cannot be opened are dropped *together with their caption*, so images and
    captions always stay aligned. Because of that the returned batch can contain fewer rows
    than the input batch; a `Dataset.map` call using this function should pass
    `remove_columns` so the original columns do not clash with the shorter output.

    :param examples: Batch mapping with an 'image_path' list and a 'Caption' list.
    :param processor: Callable BLIP processor accepting `images=` and `text=`.
    :param max_length: Token length every caption is padded/truncated to.
    :return: Dict with 'input_ids', 'attention_mask', 'pixel_values' and 'labels'.
             All four are empty lists when no image of the batch could be loaded.
    """
    images = []
    captions = []
    for path, caption in zip(examples['image_path'], examples['Caption']):
        try:
            # The context manager releases the file handle; convert() returns an independent copy.
            with Image.open(path) as raw_image:
                images.append(raw_image.convert("RGB"))
        except Exception as e:
            # Best effort: report the unreadable image and skip the whole row.
            print(f"Error loading image {path}: {e}")
            continue
        # Only reached when the image was appended, keeping both lists the same length.
        captions.append(caption)

    if not images:
        # Nothing usable in this batch; the processor cannot be called with zero images.
        return {'input_ids': [], 'attention_mask': [], 'pixel_values': [], 'labels': []}

    # Pad to a fixed length so rows produced by different map batches can later be stacked
    # into one tensor; padding=True would only pad to the longest caption of *this* batch.
    inputs = processor(images=images, text=captions, return_tensors="pt",
                       padding="max_length", truncation=True, max_length=max_length)

    # Captioning targets are the caption tokens themselves. Padded positions are set to -100,
    # the default ignore_index of torch's cross-entropy, so they do not contribute to the loss.
    labels = inputs['input_ids'].masked_fill(inputs['attention_mask'] == 0, -100)

    # Return inputs compatible with the model's forward method
    return {
        'input_ids': inputs['input_ids'],
        'attention_mask': inputs['attention_mask'],
        'pixel_values': inputs['pixel_values'],
        'labels': labels
    }

class BLIPFineTuner:
    """
    Fine-tunes a BLIP image-captioning model with the Hugging Face Trainer, persists the
    result, and generates captions for single images.
    """

    def __init__(self, model_name="Salesforce/blip-image-captioning-large", output_dir="./results"):
        """
        Initializes the BLIP processor; the model itself is loaded lazily.
        :param model_name: The pre-trained model to fine-tune.
        :param output_dir: Directory to save the fine-tuned model.
        """
        self.processor = BlipProcessor.from_pretrained(model_name)
        self.output_dir = output_dir
        self.model_name = model_name
        # Populated by fine_tune() or load_model().
        self.model = None

    @staticmethod
    def _preferred_device():
        """Return "cuda" when torch can see a GPU, otherwise "cpu"."""
        return "cuda" if torch.cuda.is_available() else "cpu"

    def process_in_chunks(self, dataset, chunk_size, dataset_type):
        """
        Process the dataset in chunks to avoid memory exhaustion.

        Every processed chunk is written to disk and released before the next one is
        processed; the chunks are then re-opened from disk and concatenated, so the
        processed data never has to be held in RAM all at once. The combined dataset
        is stored in `./processed_<dataset_type>_chunks/` and re-used on later calls.

        :param dataset: The full dataset to process.
        :param chunk_size: Number of samples to process at once (must be positive).
        :param dataset_type: Type of dataset being processed ("train" or "validation").
        :return: The processed dataset.
        :raises ValueError: If `chunk_size` is not positive or `dataset` is empty.
        """
        # Check if processed chunks already exist
        processed_chunk_dir = f"./processed_{dataset_type}_chunks/"
        if os.path.exists(processed_chunk_dir):
            print(f"Loading preprocessed {dataset_type} chunks from disk...")
            return load_from_disk(processed_chunk_dir)

        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

        print(f"No preprocessed {dataset_type} data found. Processing now...")

        # Scratch area for the individual chunks; wipe leftovers of an interrupted run.
        parts_dir = f"./processed_{dataset_type}_chunk_parts/"
        shutil.rmtree(parts_dir, ignore_errors=True)

        part_paths = []
        total = len(dataset)
        for start_idx in range(0, total, chunk_size):
            end_idx = min(start_idx + chunk_size, total)
            chunk = dataset.select(range(start_idx, end_idx))
            print(f"Processing {dataset_type} chunk from {start_idx} to {end_idx}...")
            # remove_columns lets preprocess_data return fewer rows than it received
            # (e.g. when an image is skipped) without clashing with the input columns.
            processed_chunk = chunk.map(lambda examples: preprocess_data(examples, self.processor),
                                        batched=True, batch_size=1, load_from_cache_file=False,
                                        remove_columns=chunk.column_names)

            # Persist the chunk and drop the in-memory copy before moving on.
            print(f"Saving {dataset_type} processed chunk from {start_idx} to {end_idx}...")
            part_path = os.path.join(parts_dir, f"{start_idx}_{end_idx}")
            processed_chunk.save_to_disk(part_path)
            part_paths.append(part_path)
            del chunk, processed_chunk

        if not part_paths:
            raise ValueError(f"Cannot preprocess an empty {dataset_type} dataset")

        # Re-open the chunks from disk, concatenate them into one dataset, and save it.
        processed_dataset = concatenate_datasets([load_from_disk(path) for path in part_paths])
        processed_dataset.save_to_disk(processed_chunk_dir)

        # The per-chunk copies are now redundant; serve the data from the combined copy.
        del processed_dataset
        shutil.rmtree(parts_dir, ignore_errors=True)
        return load_from_disk(processed_chunk_dir)

    def fine_tune(self, dataset, num_train_epochs=3, batch_size=4, chunk_size=1000):
        """
        Fine-tunes the BLIP model on the provided dataset and saves it to `output_dir`.
        :param dataset: The dataset containing 'train' and 'validation' splits.
        :param num_train_epochs: Number of epochs for fine-tuning.
        :param batch_size: Batch size for training and evaluation.
        :param chunk_size: Size of dataset chunks to process.
        """
        # Load the model if it's not already loaded
        if self.model is None:
            print(f"Loading model {self.model_name}...")
            self.model = BlipForConditionalGeneration.from_pretrained(self.model_name)

        # Move model to GPU if available
        device = self._preferred_device()
        print(f"Using device: {device}")
        self.model.to(device)

        # Process or load the dataset in chunks to prevent memory overflow
        print("Processing or loading training set in chunks...")
        processed_train_dataset = self.process_in_chunks(dataset["train"], chunk_size, dataset_type="train")
        print("Processing or loading validation set in chunks...")
        processed_validation_dataset = self.process_in_chunks(dataset["validation"], chunk_size, dataset_type="validation")

        # Define training arguments
        # NOTE(review): newer transformers releases spell `evaluation_strategy` as
        # `eval_strategy` - confirm against the installed version.
        training_args = TrainingArguments(
            output_dir=self.output_dir,
            evaluation_strategy="steps",
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            num_train_epochs=num_train_epochs,
            save_steps=1000,  # Checkpoint every 1000 steps (a multiple of eval_steps)
            save_total_limit=2,
            logging_dir="./logs",
            learning_rate=5e-5,
            logging_steps=100,  # Log every 100 steps
            eval_steps=500,  # Evaluate the model every 500 steps
            remove_unused_columns=False,  # Do not let the Trainer drop dataset columns
            load_best_model_at_end=True  # Load the best model checkpoint at the end of training
        )

        # Define the Trainer
        print("Initializing Trainer...")
        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=processed_train_dataset,
            eval_dataset=processed_validation_dataset,
        )

        # Start training
        print("Starting training...")
        trainer.train()

        # Save the fine-tuned model
        self.save_model()
        print(f"Model saved at {self.output_dir}")

    def save_model(self):
        """
        Saves the fine-tuned model and its processor to `output_dir`.
        Only prints a notice when no model has been loaded yet.
        """
        if self.model is not None:
            print("Saving model...")
            self.model.save_pretrained(self.output_dir)
            self.processor.save_pretrained(self.output_dir)
        else:
            print("Model is not loaded, cannot save.")

    def load_model(self):
        """
        Loads the saved fine-tuned model and processor from `output_dir` and places the
        model on the GPU when one is available.
        :raises FileNotFoundError: If `output_dir` does not exist.
        """
        if not os.path.exists(self.output_dir):
            raise FileNotFoundError(f"No model found at {self.output_dir}")

        print(f"Loading model from {self.output_dir}...")
        self.model = BlipForConditionalGeneration.from_pretrained(self.output_dir)
        self.processor = BlipProcessor.from_pretrained(self.output_dir)
        # Same placement policy as fine_tune(), so inference can use the GPU too.
        self.model.to(self._preferred_device())

    def generate_caption(self, image_path):
        """
        Generates a caption for a given image using the fine-tuned model.
        :param image_path: Path to the medical image.
        :return: Generated caption.
        :raises ValueError: If no model has been loaded or trained yet.
        """
        if self.model is None:
            raise ValueError("The model is not loaded. Call `load_model()` or train the model first.")

        print(f"Generating caption for image: {image_path}")
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")

        # Put the inputs on the device the model actually lives on; sending them to CUDA
        # unconditionally fails when the model is still on the CPU.
        inputs = self.processor(images=image, return_tensors="pt").to(self.model.device)

        # Inference mode: disables dropout that may still be active after training.
        self.model.eval()
        out = self.model.generate(**inputs)
        generated_caption = self.processor.decode(out[0], skip_special_tokens=True)
        return generated_caption


In [4]:
def main(load_existing_model=False, test_image_path="/path_to_new_image.jpg"):
    """
    Fine-tune BLIP (or load a previously saved model) and caption one test image.
    :param load_existing_model: When True, load the model saved in the fine-tuner's
                                output directory instead of training a new one.
    :param test_image_path: Image to caption afterwards. Caption generation is skipped
                            with a message when this file does not exist.
    """
    # Initialize the BLIP fine-tuner
    fine_tuner = BLIPFineTuner()

    if load_existing_model:
        # Load the saved model
        print("Loading the existing model...")
        fine_tuner.load_model()
    else:
        # Paths to your data
        image_dir_train = "../Datasets/ROCO2/train_images/train"
        image_dir_val = "../Datasets/ROCO2/valid_images/valid"
        annotation_file_train = "../Datasets/ROCO2/train_captions.csv"
        annotation_file_val = "../Datasets/ROCO2/valid_captions.csv"

        # Load the dataset (only needed when training)
        data_loader = ROCODataLoader(image_dir_train, image_dir_val, annotation_file_train, annotation_file_val)
        dataset = data_loader.load_data()

        # Fine-tune the model
        print("Starting fine-tuning...")
        fine_tuner.fine_tune(dataset, num_train_epochs=3, batch_size=4, chunk_size=1000)

    # Perform inference on a new image. Guard against a missing file so a completed
    # training run does not end in a traceback because of a placeholder path.
    if not os.path.isfile(test_image_path):
        print(f"Test image not found at {test_image_path}; skipping caption generation.")
        return

    caption = fine_tuner.generate_caption(test_image_path)
    print(f"Generated Caption: {caption}")


if __name__ == "__main__":
    main(load_existing_model=False)



Starting fine-tuning...
Loading model Salesforce/blip-image-captioning-large...
Using device: cuda
Processing or loading training set in chunks...
No preprocessed train data found. Processing now...
Processing train chunk from 0 to 3000...


Map:   0%|          | 0/3000 [00:00<?, ? examples/s]

Saving train processed chunk from 0 to 3000...
Processing train chunk from 3000 to 6000...


Map:   0%|          | 0/3000 [00:00<?, ? examples/s]

Saving train processed chunk from 3000 to 6000...
Processing train chunk from 6000 to 9000...


Map:   0%|          | 0/3000 [00:00<?, ? examples/s]

Saving train processed chunk from 6000 to 9000...
Processing train chunk from 9000 to 12000...


Map:   0%|          | 0/3000 [00:00<?, ? examples/s]

Saving train processed chunk from 9000 to 12000...
Processing train chunk from 12000 to 15000...


Map:   0%|          | 0/3000 [00:00<?, ? examples/s]

Saving train processed chunk from 12000 to 15000...
Processing train chunk from 15000 to 18000...




Map:   0%|          | 0/3000 [00:00<?, ? examples/s]

ArrowMemoryError: realloc of size 4294967296 failed