In [11]:
import torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler
from torch.optim import Adam
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from torch.utils.data import DataLoader, Dataset
import pandas as pd
import os
from open_clip import create_model_and_transforms
from PIL import Image

In [2]:
class ImageCaptioningModel(nn.Module):
    def __init__(self, encoder_model_name='hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224', 
                 gpt2_model_name='gpt2', projection_dim=512):
        super(ImageCaptioningModel, self).__init__()
        
        # Load Encoder model and its tokenizer
        self.encoder_model, self.preprocess_train, self.preprocess_val = create_model_and_transforms(
            encoder_model_name, pretrained_hf=True
        )
        
        # Get the encoder output size by passing a dummy input through the encoder
        dummy_input = torch.rand(1, 3, 224, 224)  # Assuming the model input size is 224x224
        with torch.no_grad():
            dummy_output = self.encoder_model.encode_image(dummy_input)
        self.encoder_output_dim = dummy_output.shape[1]  # Use the embedding dimension from the dummy output
        
        # GPT-2 Decoder model
        self.gpt2_model = GPT2LMHeadModel.from_pretrained(gpt2_model_name)
        self.gpt2_tokenizer = GPT2Tokenizer.from_pretrained(gpt2_model_name)

        # Projection layer to match embedding sizes
        self.image_embeddings_projected = nn.Linear(self.encoder_output_dim, projection_dim)
    
    def forward(self, images, captions):
        # Extract image embeddings
        image_embeddings = self.encoder_model.encode_image(images)
        image_embeddings_projected = self.image_embeddings_projected(image_embeddings)
        
        # Get GPT-2 inputs and generate logits
        gpt2_inputs = self.gpt2_tokenizer(captions, return_tensors="pt", padding=True, truncation=True)
        logits = self.gpt2_model(**gpt2_inputs).logits
        
        return logits

    def preprocess_image(self, image_path):
        image = Image.open(image_path).convert('RGB')
        image_tensor = self.preprocess_val(image).unsqueeze(0)  # Preprocess image
        return image_tensor

    def tokenize_caption(self, caption):
        return self.gpt2_tokenizer.encode(caption, return_tensors='pt')

    def decode_caption(self, caption_tokens):
        return self.gpt2_tokenizer.decode(caption_tokens, skip_special_tokens=True)


In [3]:
class ROCODataset(Dataset):
    def __init__(self, caption_file, image_folder, model):
        self.captions = pd.read_csv(caption_file)
        self.image_folder = image_folder
        self.preprocess = model.preprocess_train
        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

        # Assign a padding token
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token  # or you can set a new token with add_special_tokens

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        caption = self.captions.iloc[idx]['Caption']
        image_id = self.captions.iloc[idx]['ID']
        image_path = os.path.join(self.image_folder, image_id + '.jpg')

        # Load and preprocess the image
        image = Image.open(image_path).convert("RGB")
        image_tensor = self.preprocess(image)

        # Tokenize the caption
        caption_tokens = self.tokenizer(caption, return_tensors='pt', padding=True, truncation=True)

        return image_tensor, caption_tokens['input_ids'].squeeze(0), caption_tokens['attention_mask'].squeeze(0)


In [4]:
def custom_collate_fn(batch):
    images, captions, attention_masks = zip(*batch)
    
    # Stack images (they are all the same size)
    images = torch.stack(images)
    
    # Pad captions and attention masks
    captions = torch.nn.utils.rnn.pad_sequence(captions, batch_first=True, padding_value=0)
    attention_masks = torch.nn.utils.rnn.pad_sequence(attention_masks, batch_first=True, padding_value=0)
    
    return images, captions, attention_masks


In [15]:
def train_model_mixed_precision(model, dataloader, epochs=5, lr=1e-4):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    optimizer = Adam(model.parameters(), lr=lr)
    scaler = torch.amp.GradScaler('cuda')

    for epoch in range(epochs):
        running_loss = 0.0

        for image_tensor, inputs, attention_mask in dataloader:
            image_tensor = image_tensor.to(device)
            inputs = inputs.to(device)
            attention_mask = attention_mask.to(device)

            optimizer.zero_grad()

            with torch.amp.autocast(device_type='cuda'):
                outputs = model(image_tensor, inputs)
                loss = outputs.loss

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(dataloader)}")


In [6]:
# File paths
train_caption_file = "../Datasets/ROCO2/train_captions.csv"
train_image_folder = "../Datasets/ROCO2/train_images/train/"
test_caption_file = "../Datasets/ROCO2/test_captions.csv"
test_image_folder = "../Datasets/ROCO2/test_images/test/"

# Initialize the model (use the BiomedCLIP model as encoder)
model = ImageCaptioningModel(encoder_model_name='hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')

# Create datasets for train and test
train_dataset = ROCODataset(train_caption_file, train_image_folder, model)
test_dataset = ROCODataset(test_caption_file, test_image_folder, model)

# Create DataLoaders for train and test
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=custom_collate_fn)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=True, collate_fn=custom_collate_fn)

  checkpoint = torch.load(checkpoint_path, map_location=map_location)


In [16]:
# Train the model with the ROCO dataset
train_model_mixed_precision(model, train_loader, epochs=5)


TypeError: ImageCaptioningModel.forward() got an unexpected keyword argument 'input_ids'