In [1]:
import os
import torch
from transformers import BlipProcessor, BlipForConditionalGeneration, Trainer, TrainingArguments
from datasets import Dataset, DatasetDict, load_from_disk, concatenate_datasets
from PIL import Image
import pandas as pd

In [2]:
class ROCODataLoader:
    def __init__(self, image_dir_train, image_dir_val, annotation_file_train, annotation_file_val):
        """
        Initializes the data loader with directories for training and validation images, and annotation files.
        :param image_dir_train: Directory where training images are stored.
        :param image_dir_val: Directory where validation images are stored.
        :param annotation_file_train: Path to the CSV file for training annotations.
        :param annotation_file_val: Path to the CSV file for validation annotations.
        """
        self.image_dir_train = image_dir_train
        self.image_dir_val = image_dir_val
        self.annotation_file_train = annotation_file_train
        self.annotation_file_val = annotation_file_val

    def load_data(self):
        """
        Loads and prepares both the training and validation datasets.
        :return: A DatasetDict with 'train' and 'validation' datasets.
        """
        train_df = pd.read_csv(self.annotation_file_train)
        val_df = pd.read_csv(self.annotation_file_val)

        # Add full image paths for training and validation sets
        train_df['image_path'] = train_df['ID'].apply(lambda x: os.path.join(self.image_dir_train, x + ".jpg"))
        val_df['image_path'] = val_df['ID'].apply(lambda x: os.path.join(self.image_dir_val, x + ".jpg"))

        # Convert to Hugging Face dataset format
        train_dataset = Dataset.from_pandas(train_df)
        val_dataset = Dataset.from_pandas(val_df)

        # Return a DatasetDict containing both training and validation sets
        return DatasetDict({'train': train_dataset, 'validation': val_dataset})

In [3]:
def preprocess_data(examples, processor):
    images = []
    for path in examples['image_path']:
        try:
            img = Image.open(path).convert("RGB")
            images.append(img)
        except Exception as e:
            print(f"Error loading image {path}: {e}")
            continue
        
    # Truncate text that exceeds the token limit and process images
    text = [caption for caption in examples['Caption']]
    inputs = processor(images=images, text=text, return_tensors="pt", padding=True, truncation=True, max_length=512)

    return {
        'input_ids': inputs['input_ids'],
        'attention_mask': inputs['attention_mask'],
        'pixel_values': inputs['pixel_values']
    }

class BLIPFineTuner:
    def __init__(self, model_name="Salesforce/blip-image-captioning-large", output_dir="./results"):
        self.processor = BlipProcessor.from_pretrained(model_name)
        self.output_dir = output_dir
        self.model_name = model_name
        self.model = None

    def process_in_chunks(self, dataset, chunk_size, dataset_type):
        processed_chunk_dir = f"./processed_{dataset_type}_chunks/"
        if os.path.exists(processed_chunk_dir):
            print(f"Loading preprocessed {dataset_type} chunks from disk...")

            # Gather all chunk folders and load each of them
            chunk_folders = [os.path.join(processed_chunk_dir, folder) for folder in os.listdir(processed_chunk_dir)]
            chunk_datasets = []

            for folder in chunk_folders:
                try:
                    chunk_datasets.append(load_from_disk(folder))
                except Exception as e:
                    print(f"Error loading chunk from {folder}: {e}")
            
            # Concatenate all chunk datasets into a single dataset
            print(f"Concatenating {len(chunk_datasets)} chunks into a single dataset...")
            if len(chunk_datasets) > 0:
                full_dataset = concatenate_datasets(chunk_datasets)
                print(f"Successfully loaded and concatenated {dataset_type} dataset.")
                return full_dataset
            else:
                raise FileNotFoundError(f"No valid chunks found in {processed_chunk_dir}")
        else:
            print(f"No preprocessed {dataset_type} data found. Processing now...")

            # Create a directory for processed chunks if it doesn't exist
            os.makedirs(processed_chunk_dir, exist_ok=True)

            # Process and save each chunk directly to disk to avoid memory exhaustion
            for start_idx in range(0, len(dataset), chunk_size):
                end_idx = min(start_idx + chunk_size, len(dataset))
                chunk = dataset.select(range(start_idx, end_idx))
                print(f"Processing {dataset_type} chunk from {start_idx} to {end_idx}...")
                processed_chunk = chunk.map(lambda examples: preprocess_data(examples, self.processor), 
                                            batched=True, batch_size=1, load_from_cache_file=False)

                # Save the chunk immediately to disk using save_to_disk()
                chunk_path = f"{processed_chunk_dir}/chunk_{start_idx}_{end_idx}"
                print(f"Saving {dataset_type} processed chunk from {start_idx} to {end_idx} at {chunk_path}...")
                processed_chunk.save_to_disk(chunk_path)

            # Reload the chunks after processing
            return self.process_in_chunks(dataset, chunk_size, dataset_type)

    def fine_tune(self, dataset, num_train_epochs=3, batch_size=4, chunk_size=1000):
        if self.model is None:
            print(f"Loading model {self.model_name}...")
            self.model = BlipForConditionalGeneration.from_pretrained(self.model_name)

        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {device}")
        self.model.to(device)

        # Process or load the dataset in chunks to prevent memory overflow
        print("Processing or loading training set in chunks...")
        processed_train_dataset = self.process_in_chunks(dataset["train"], chunk_size, dataset_type="train")
        print("Processing or loading validation set in chunks...")
        processed_validation_dataset = self.process_in_chunks(dataset["validation"], chunk_size, dataset_type="validation")

        # Define training arguments with caching
        training_args = TrainingArguments(
            output_dir=self.output_dir,
            evaluation_strategy="steps",
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            num_train_epochs=num_train_epochs,
            save_steps=1000,
            save_total_limit=2,
            logging_dir="./logs",
            learning_rate=5e-5,
            logging_steps=100,
            eval_steps=500,
            remove_unused_columns=False,
            load_best_model_at_end=True
        )

        # Define the Trainer
        print("Initializing Trainer...")
        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=processed_train_dataset,
            eval_dataset=processed_validation_dataset,
        )

        # Start training
        print("Starting training...")
        trainer.train()

        # Save the fine-tuned model
        self.save_model()
        print(f"Model saved at {self.output_dir}")

    def save_model(self):
        if self.model:
            print("Saving model...")
            self.model.save_pretrained(self.output_dir)
            self.processor.save_pretrained(self.output_dir)
        else:
            print("Model is not loaded, cannot save.")

    def load_model(self):
        if os.path.exists(self.output_dir):
            print(f"Loading model from {self.output_dir}...")
            self.model = BlipForConditionalGeneration.from_pretrained(self.output_dir)
            self.processor = BlipProcessor.from_pretrained(self.output_dir)
        else:
            raise FileNotFoundError(f"No model found at {self.output_dir}")

    def generate_caption(self, image_path):
        if self.model is None:
            raise ValueError("The model is not loaded. Call `load_model()` or train the model first.")

        print(f"Generating caption for image: {image_path}")
        image = Image.open(image_path).convert("RGB")
        inputs = self.processor(images=image, return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu")
        out = self.model.generate(**inputs)
        generated_caption = self.processor.decode(out[0], skip_special_tokens=True)
        return generated_caption

In [None]:
def main(load_existing_model=False):
    # Paths to your data
    image_dir_train = "../Datasets/ROCO2/train_images/train"   
    image_dir_val = "../Datasets/ROCO2/valid_images/valid"
    annotation_file_train = "../Datasets/ROCO2/train_captions.csv"  
    annotation_file_val =  "../Datasets/ROCO2/valid_captions.csv"  

    # Load the dataset
    data_loader = ROCODataLoader(image_dir_train, image_dir_val, annotation_file_train, annotation_file_val)
    dataset = data_loader.load_data()

    # Initialize the BLIP fine-tuner
    fine_tuner = BLIPFineTuner()

    if load_existing_model:
        # Load the saved model
        print("Loading the existing model...")
        fine_tuner.load_model()
    else:
        # Fine-tune the model
        print("Starting fine-tuning...")
        fine_tuner.fine_tune(dataset, num_train_epochs=3, batch_size=6, chunk_size=3000)

    # Perform inference on a new image
    test_image_path = "sample1.jpg"  # Replace with the path to a new medical image
    caption = fine_tuner.generate_caption(test_image_path)
    print(f"Generated Caption: {caption}")


if __name__ == "__main__":
    main(load_existing_model=False)



Starting fine-tuning...
Loading model Salesforce/blip-image-captioning-large...
Using device: cuda
Processing or loading training set in chunks...
Loading preprocessed train chunks from disk...
Concatenating 21 chunks into a single dataset...
Successfully loaded and concatenated train dataset.
Processing or loading validation set in chunks...
No preprocessed validation data found. Processing now...
Processing validation chunk from 0 to 3000...


Map:   0%|          | 0/3000 [00:00<?, ? examples/s]