# Dependencies

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
import torchvision.transforms as transforms
from transformers import BlipProcessor, BlipForConditionalGeneration
from PIL import Image
import json
from tqdm import tqdm

# Check GPU Availability

In [None]:
# Pick the compute device once: CUDA when PyTorch reports it as available,
# otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(f"Using device: {device}")

# Load COCO

In [None]:
# Locations of the COCO 2017 caption annotations and the matching training
# images. Both are relative paths, so they resolve against the current
# working directory.
annotations_file: str = "coco-dataset/annotations/captions_train2017.json"
image_dir: str = "coco-dataset/train2017"

# Load Annotations

In [None]:
# Parse the whole annotations JSON into memory; the file handle is closed as
# soon as parsing finishes.
with open(annotations_file, mode="r", encoding="utf-8") as annotations_fp:
    coco_data = json.load(annotations_fp)

# Report how many caption annotations were read as a quick sanity check.
num_annotations = len(coco_data["annotations"])
print(f"Total annotations loaded: {num_annotations}")

# Initialize BLIP Processor and Model

In [None]:
# Processor and model are loaded from the same pretrained checkpoint so the
# two stay in sync.
blip_checkpoint = "Salesforce/blip-image-captioning-base"
processor = BlipProcessor.from_pretrained(blip_checkpoint)

# Load the captioning model, then move its weights onto the selected device.
model = BlipForConditionalGeneration.from_pretrained(blip_checkpoint)
model = model.to(device)


# Define COCO Dataset

In [None]:
class COCODataset(Dataset):
    """Caption-level dataset over a COCO captions annotation dictionary.

    Every entry of ``coco_data["annotations"]`` is one sample: the image it
    references is loaded from ``image_dir`` and passed, together with the
    caption text, through ``processor``.

    Args:
        coco_data: Parsed annotations dictionary. Must provide an
            ``"annotations"`` list (entries with ``"image_id"`` and
            ``"caption"``) and an ``"images"`` list (entries with ``"id"``
            and ``"file_name"``).
        image_dir: Directory that contains the image files.
        processor: Callable invoked as
            ``processor(images=..., text=..., return_tensors="pt", ...)``
            whose result exposes ``"pixel_values"``, ``"input_ids"`` and
            ``"attention_mask"``.
        max_length: Length every caption is padded or truncated to.
            Defaults to 64.
    """

    def __init__(self, coco_data, image_dir, processor, max_length=64):
        self.coco_data = coco_data["annotations"]
        self.image_dir = image_dir
        self.processor = processor
        self.max_length = max_length
        # Annotations only carry the image id, so build the id -> file name
        # lookup once instead of scanning the image list for every sample.
        self.image_id_to_filename = {img["id"]: img["file_name"] for img in coco_data["images"]}

    def __len__(self):
        """Return the number of caption annotations."""
        return len(self.coco_data)

    def __getitem__(self, idx):
        """Return the processed image/caption pair for annotation ``idx``.

        Returns:
            A dict with ``"pixel_values"``, ``"input_ids"`` and
            ``"attention_mask"``, each with the processor's leading
            dimension removed via ``squeeze(0)``.

        Raises:
            KeyError: If the annotation references an image id that is not
                listed in ``coco_data["images"]``.
        """
        annotation = self.coco_data[idx]
        image_path = os.path.join(self.image_dir, self.image_id_to_filename[annotation["image_id"]])
        # Close the underlying file deterministically; ``convert`` returns an
        # independent, fully loaded copy that remains valid afterwards.
        with Image.open(image_path) as image_file:
            image = image_file.convert("RGB")
        encoding = self.processor(
            images=image,
            text=annotation["caption"],
            return_tensors="pt",
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
        )

        # The processor is called on a single sample, so drop its leading
        # dimension and let the DataLoader add the batch dimension back.
        return {
            "pixel_values": encoding["pixel_values"].squeeze(0),
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
        }


# Create Dataset

In [None]:
dataset = COCODataset(coco_data, image_dir, processor)

# Split by image rather than by caption: one image can have several caption
# annotations, and a caption-level random split would put captions of the same
# image into both subsets, leaking training images into validation.
# A fixed seed keeps the split reproducible across runs.
split_generator = torch.Generator().manual_seed(42)
image_ids = sorted({annotation["image_id"] for annotation in dataset.coco_data})
shuffled_positions = torch.randperm(len(image_ids), generator=split_generator).tolist()
num_train_images = int(0.8 * len(image_ids))
train_image_ids = {image_ids[pos] for pos in shuffled_positions[:num_train_images]}

# Map the image-level split back to annotation indices, which is what the
# dataset is indexed by.
train_indices = []
val_indices = []
for sample_idx, annotation in enumerate(dataset.coco_data):
    if annotation["image_id"] in train_image_ids:
        train_indices.append(sample_idx)
    else:
        val_indices.append(sample_idx)

# Sizes are counted in samples (caption annotations), as before.
train_size = len(train_indices)
val_size = len(val_indices)
train_dataset = torch.utils.data.Subset(dataset, train_indices)
val_dataset = torch.utils.data.Subset(dataset, val_indices)

# Dataloaders

In [None]:
# Number of samples per mini-batch, shared by both loaders.
batch_size = 8

# Training data is reshuffled on every pass over the loader.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

# Validation data keeps its order (shuffle=False is the DataLoader default).
val_dataloader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
)

# Training Setup

In [None]:
# AdamW over every model parameter with a fixed initial learning rate.
optimizer = optim.AdamW(params=model.parameters(), lr=5e-5)

# Scale the learning rate by 0.7 after every 3 calls to scheduler.step().
scheduler = optim.lr_scheduler.StepLR(
    optimizer,
    step_size=3,
    gamma=0.7,
)

# NOTE(review): `criterion` is not referenced anywhere in the visible cells
# (the loops read the loss from the model output) — confirm whether it is
# still needed.
criterion = nn.CrossEntropyLoss()


# Training Loop

In [None]:
# Fine-tune the model for a fixed number of epochs, reporting the mean batch
# loss and writing a checkpoint after each one.
epochs = 10
for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    for batch in tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{epochs}"):
        pixel_values = batch["pixel_values"].to(device)
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        # Do not train on padding: positions where the attention mask is 0
        # are replaced by -100, the default `ignore_index` of
        # torch.nn.CrossEntropyLoss, so they contribute nothing to the loss.
        # Passing `input_ids` directly as labels would reward the model for
        # predicting pad tokens and dilute the caption loss.
        labels = input_ids.masked_fill(attention_mask == 0, -100)

        optimizer.zero_grad()
        outputs = model(pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # The scheduler is stepped once per epoch, so its step_size counts epochs.
    scheduler.step()
    print(f"Epoch {epoch + 1}/{epochs}, Loss: {running_loss / len(train_dataloader)}")
    # Same file name every epoch: only the most recent weights are kept.
    torch.save(model.state_dict(), "coco_checkpoint.pth")


# Evaluation

In [None]:
# Compute the mean per-batch loss over the validation loader with gradients
# disabled and the model in eval mode.
total_loss = 0.0
model.eval()
with torch.no_grad():
    for batch in tqdm(val_dataloader, desc="Evaluating"):
        pixel_values = batch["pixel_values"].to(device)
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        # Exclude padding from the metric: positions where the attention mask
        # is 0 get label -100, the default `ignore_index` of
        # torch.nn.CrossEntropyLoss. Otherwise the reported loss is averaged
        # over pad tokens as well as caption tokens.
        labels = input_ids.masked_fill(attention_mask == 0, -100)

        outputs = model(pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        total_loss += loss.item()

print(f"Validation Loss: {total_loss / len(val_dataloader):.4f}")