In [None]:
pip install transformers datasets torch torchvision pillow




In [None]:
from datasets import load_dataset
from datasets import DatasetDict
from transformers import BlipProcessor, BlipForConditionalGeneration, TrainingArguments, Trainer
from torchvision import transforms
from PIL import Image
from torch.nn import CrossEntropyLoss
import torch

# Load the dataset
dataset = load_dataset("axiong/pmc_oa_demo")
print(dataset)
print(dataset['train'][0])  # Check the first example

# Load the BLIP processor and model
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

# Define the image transformation.
# Inference (generate_caption) preprocesses images through `processor`, so training must
# use the same normalization statistics; take them from the processor instead of
# hard-coding ImageNet values, which differ from the ones this checkpoint was trained with.
image_transform = transforms.Compose([
    # Force 3 channels: grayscale / RGBA / palette images would otherwise break
    # the 3-channel Normalize step and the later torch.stack.
    transforms.Lambda(lambda image: image.convert("RGB")),
    # Resize to match BLIP model input size (bicubic, as the BLIP image processor does)
    transforms.Resize((384, 384), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),  # Convert to a float tensor in [0, 1]
    transforms.Normalize(
        mean=processor.image_processor.image_mean,
        std=processor.image_processor.image_std,
    ),
])

# Preprocessing function
def preprocess_function(examples, max_length=77):
    """Turn a batch of raw dataset rows into model-ready tensors.

    Args:
        examples: Batch dict (as passed by ``Dataset.map(..., batched=True)``)
            with an ``'image'`` list of PIL images and a ``'caption'`` list of strings.
        max_length: Number of tokens every caption is padded / truncated to.

    Returns:
        Dict with ``'pixel_values'`` (stacked image tensors), ``'input_ids'``
        (tokenized captions) and ``'attention_mask'`` (1 for real tokens, 0 for padding).
    """
    # Process the images. Convert to RGB first so every tensor has 3 channels;
    # images of different modes could not be stacked into one batch otherwise.
    images = [image_transform(image.convert("RGB")) for image in examples['image']]

    # Tokenize the captions using the BLIP processor
    captions = examples['caption']
    encoding = processor.tokenizer(
        captions,
        padding='max_length',
        truncation=True,
        max_length=max_length,
        return_tensors='pt',
    )

    return {
        'pixel_values': torch.stack(images),  # Stack processed images into a tensor
        'input_ids': encoding['input_ids'],  # Tokenized captions
        # Captions are padded to a fixed length, so keep the mask that tells
        # padding apart from real tokens.
        'attention_mask': encoding['attention_mask'],
    }

# Apply the preprocessing function to the dataset
processed_dataset = dataset.map(preprocess_function, batched=True)

# Split the dataset into train and validation sets.
# A fixed seed keeps the split identical across kernel restarts.
split_dataset = processed_dataset["train"].train_test_split(test_size=0.1, seed=42)

# Define training arguments.
# With ~90 training rows and batch size 8 a full run is only a few dozen optimizer steps,
# so step-based schedules of 500 (warm-up / eval / save) would never trigger: the learning
# rate would stay in warm-up for the whole run and nothing would be evaluated or saved.
# Epoch-based evaluation/saving and a proportional warm-up scale with the dataset size.
training_args = TrainingArguments(
    output_dir="./blip_model",
    num_train_epochs=3,
    per_device_train_batch_size=8,
    logging_dir="./logs",
    logging_steps=10,
    eval_strategy="epoch",  # current name of the deprecated `evaluation_strategy`
    save_strategy="epoch",
    save_total_limit=1,  # keep only the latest checkpoint to bound disk usage
    learning_rate=5e-5,
    warmup_ratio=0.1,
    # The dataset has no `labels` column (the loss is derived from `input_ids`), so tell
    # the Trainer which key carries the targets; otherwise evaluation reports no loss.
    label_names=["input_ids"],
    # Only the validation loss is needed; skip accumulating full-vocabulary logits.
    prediction_loss_only=True,
    report_to="none"  # Disable W&B
)

# Define a custom trainer to handle the custom loss function
class CustomTrainer(Trainer):
    """Trainer with a padding-aware next-token cross-entropy loss for caption training.

    The captions in ``inputs["input_ids"]`` serve as both decoder input and target.
    Padding positions are excluded from the loss so that fixed-length padding does
    not dominate the training signal.
    """

    # Label value skipped by the loss; -100 is CrossEntropyLoss's default ignore_index.
    IGNORE_INDEX = -100

    @staticmethod
    def _resolve_pad_token_id(model):
        """Return the text pad-token id from the model config, or None if it is not set."""
        # Unwrap DataParallel / DistributedDataParallel, which expose the model as `.module`.
        config = getattr(getattr(model, "module", model), "config", None)
        # Composite configs keep text settings under `text_config`; check it before
        # the top-level config, where `pad_token_id` may be left unset.
        for candidate in (getattr(config, "text_config", None), config):
            pad_token_id = getattr(candidate, "pad_token_id", None)
            if pad_token_id is not None:
                return pad_token_id
        return None

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the shifted cross-entropy loss over non-padding caption tokens.

        Args:
            model: The model being trained.
            inputs: Batch mapping with ``pixel_values`` and ``input_ids``; an optional
                ``attention_mask`` is forwarded to the model and used to mask padding.
            return_outputs: If True, return ``(loss, outputs)`` instead of just ``loss``.
            num_items_in_batch: Accepted for Trainer API compatibility; unused.

        Returns:
            The scalar loss, or ``(loss, outputs)`` when ``return_outputs`` is True.
        """
        input_ids = inputs["input_ids"]
        attention_mask = inputs.get("attention_mask")

        # Forward pass without `labels`: the loss is computed below, so asking the
        # model for its own loss as well would only duplicate work.
        outputs = model(
            pixel_values=inputs["pixel_values"],
            input_ids=input_ids,
            attention_mask=attention_mask,
        )

        # Extract logits
        logits = outputs.logits

        # Build targets from the input ids and blank out padding. Padding is detected by
        # token id as well as by mask, because a collator may supply an all-ones mask
        # for sequences that were already padded during preprocessing.
        labels = input_ids.clone()
        pad_token_id = self._resolve_pad_token_id(model)
        if pad_token_id is not None:
            labels[input_ids == pad_token_id] = self.IGNORE_INDEX
        if attention_mask is not None:
            labels[attention_mask == 0] = self.IGNORE_INDEX

        # Shift logits and labels for causal language modeling:
        # the logits at position t predict the token at position t + 1.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        # Compute the loss
        loss_fct = CrossEntropyLoss(ignore_index=self.IGNORE_INDEX)
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        return (loss, outputs) if return_outputs else loss

# Initialize the custom trainer
trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=split_dataset["train"],
    eval_dataset=split_dataset["test"],
    # `tokenizer=` is deprecated on Trainer (it emits a FutureWarning);
    # `processing_class=` is its replacement and takes the same object.
    processing_class=processor.tokenizer,
)

# Train the model
trainer.train()


DatasetDict({
    train: Dataset({
        features: ['image', 'caption'],
        num_rows: 100
    })
})
{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=336x343 at 0x79586F39DC30>, 'caption': 'Electron micrographs of ultra-thin sections of potato root inoculated with Brevundimonas sp. TN37, 30\xa0days post inoculation. RS Rhizosphere, RC root cell and B Bacterium. Brevundimonas sp. TN37 colonized over the root surface of potato. Bacterial cells form sheets in the grooves, formed by the root cells (A,B) and also closely attached to the cell wall of plant cells (C,D) and in control uninoculated plants no bacterial cells were observed (E).'}


Map:   0%|          | 0/100 [00:00<?, ? examples/s]

  trainer = CustomTrainer(
We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.


Step,Training Loss,Validation Loss


TrainOutput(global_step=36, training_loss=8.957001474168566, metrics={'train_runtime': 3067.0304, 'train_samples_per_second': 0.088, 'train_steps_per_second': 0.012, 'total_flos': 1.6022438610075648e+17, 'train_loss': 8.957001474168566, 'epoch': 3.0})

In [None]:
from PIL import Image
import torch

def generate_caption(image_path, model, processor, **generate_kwargs):
    """
    Generate a caption for an input image using the trained BLIP model.

    Args:
        image_path: Path of the image file to caption, or an already-loaded image
            object exposing ``convert`` (e.g. a ``PIL.Image.Image`` taken from a dataset).
        model: Captioning model providing ``device``, ``eval()`` and ``generate()``.
        processor: Processor that converts the image to ``pixel_values`` and whose
            ``tokenizer`` decodes the generated token ids.
        **generate_kwargs: Optional keyword arguments forwarded unchanged to
            ``model.generate`` (e.g. ``max_new_tokens``, ``num_beams``). With none
            given, the model's default generation settings are used.

    Returns:
        The decoded caption string with special tokens removed.

    Note:
        The model is switched to evaluation mode and left in that mode.
    """
    # Open and preprocess the image
    if hasattr(image_path, "convert"):
        # Already an image object: no file to open.
        raw_image = image_path.convert("RGB")
    else:
        # Close the file handle promptly; convert() returns an independent, fully loaded copy.
        with Image.open(image_path) as opened_image:
            raw_image = opened_image.convert("RGB")
    inputs = processor(images=raw_image, return_tensors="pt").to(model.device)

    # Generate caption
    model.eval()
    with torch.no_grad():
        generated_ids = model.generate(pixel_values=inputs['pixel_values'], **generate_kwargs)

    # Decode the generated caption
    generated_caption = processor.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    return generated_caption

# Smoke test: caption one held-out image with the fine-tuned model.
test_image_path = "/content/xray.jpg"  # Replace with your test image path
caption = generate_caption(
    image_path=test_image_path,
    model=model,
    processor=processor,
)
print(f"Generated Caption: {caption}")


Generated Caption: the chest is a large, flat, flat, and flat area
