In [6]:
import torch
from torchvision import transforms
from PIL import Image
import os
import json
from torch.utils.data import Dataset, DataLoader
from transformers import BlipProcessor, BlipForConditionalGeneration, Trainer, TrainingArguments

# Define a function to preprocess a single image
def preprocess_image(image_path):
    image = Image.open(image_path).convert('RGB')  # Ensure the image is in RGB format
    return image

# Custom Dataset class to handle image and annotation loading
class ImageAnnotationDataset(Dataset):
    def __init__(self, directory_path, annotation_file):
        self.directory_path = directory_path
        with open(annotation_file, 'r') as f:
            self.annotations = json.load(f)
        self.image_filenames = [f for f in os.listdir(directory_path) if f.endswith('.jpg') or f.endswith('.png')]

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        img_name = self.image_filenames[idx]
        img_path = os.path.join(self.directory_path, img_name)
        image = preprocess_image(img_path)
        annotation = self.annotations.get(img_name, "")
        return image, annotation

# Main function to demonstrate usage
if __name__ == "__main__":
    directory_path = "/Users/kristinakuznetsova/Downloads/frames2"  
    annotation_file = "/Users/kristinakuznetsova/Downloads/annotations.json"  

    dataset = ImageAnnotationDataset(directory_path, annotation_file)
    dataloader = DataLoader(dataset, batch_size=8, shuffle=True)


In [7]:
# Load the BLIP processor and model
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

# Define a function to collate data for the DataLoader
def collate_fn(batch):
    images, annotations = zip(*batch)
    inputs = processor(text=annotations, images=list(images), return_tensors="pt", padding=True)
    inputs['labels'] = inputs.input_ids
    return input

# Define the training arguments
training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    save_steps=10_000,
    save_total_limit=2,
    report_to="wandb"
)

# Create a Trainer instance
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    train_dataset=dataset,
    eval_dataset=dataset,
)

# Start training
train_output = trainer.train()

print(train_output)

# Save the model and processor
output_dir = "/Users/kristinakuznetsova/Downloads/fine_tuned_blip"
if not os.path.exists(output_dir):
    os.makedirs(output_dir)
model.save_pretrained(output_dir)
processor.save_pretrained(output_dir)
print(f"Model and processor saved to {output_dir}")

# Load the model and processor for inference
model = BlipForConditionalGeneration.from_pretrained(output_dir)
processor = BlipProcessor.from_pretrained(output_dir)
print("Model and processor loaded from", output_dir)

# Inference example
image_path = "/Users/kristinakuznetsova/Downloads/frames1/abnormal_scene_1_scenario_3_360.jpg"
image = Image.open(image_path).convert('RGB')
inputs = processor(images=image, return_tensors="pt")
outputs = model.generate(**inputs)
caption = processor.decode(outputs[0], skip_special_tokens=True)
print("Generated Caption:", caption)


[34m[1mwandb[0m: Currently logged in as: [33mkristi54578[0m ([33mteam00000[0m). Use [1m`wandb login --relogin`[0m to force relogin


Epoch,Training Loss,Validation Loss
1,No log,4.637148
2,No log,3.860972
3,No log,3.603543


TrainOutput(global_step=66, training_loss=4.378513220584754, metrics={'train_runtime': 177.4942, 'train_samples_per_second': 2.958, 'train_steps_per_second': 0.372, 'total_flos': 3.115474174181376e+17, 'train_loss': 4.378513220584754, 'epoch': 3.0})
Model and processor saved to /Users/kristinakuznetsova/Downloads/fine_tuned_blip
Model and processor loaded from /Users/kristinakuznetsova/Downloads/fine_tuned_blip




Generated Caption: a bald man in casual clothes looks at a construction worker in a fluorescent waistcoat and a white
