In [1]:
import torch
print(torch.version.cuda)
print(torch.__version__)
print(torch.cuda.is_available())
print(torch.cuda.current_device())
print(torch.cuda.get_device_name(0))

12.1
2.4.1+cu121
True
0
NVIDIA GeForce RTX 3070


In [2]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import pandas as pd
import os
from torchvision import transforms
import open_clip
from torch import autocast
from torch.cuda.amp import GradScaler
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import time

In [3]:
class MedicalImageDataset(Dataset):
    def __init__(self, csv_file, image_folder, transform=None):
        self.data = pd.read_csv(csv_file)
        self.image_folder = image_folder
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_name = os.path.join(self.image_folder, self.data.iloc[idx, 0] + '.jpg')  # Pick first column and add .jpg
        image = Image.open(img_name).convert('RGB')
        caption = self.data.iloc[idx, 1]  # Assuming 'Caption' is the second column
        
        if self.transform:
            image = self.transform(image)
        
        # Convert caption to a tensor (e.g., a single element tensor of string type)
        return image, caption


def preprocess_images(images, preprocess_func):
    """Preprocess images using the given preprocessing function."""
    preprocessed_images = []
    for img in images:
        if isinstance(img, torch.Tensor):
            # Convert tensor to PIL image for preprocessing
            img = transforms.ToPILImage()(img)
        preprocessed_images.append(preprocess_func(img))
    return torch.stack(preprocessed_images)

def set_padding_token(tokenizer):
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
    return tokenizer

def tokenize_captions(captions, tokenizer, context_length):
    """Tokenize captions using the provided tokenizer."""
    tokenizer = set_padding_token(tokenizer)
    
    # Tokenize captions
    tokenized = tokenizer(captions, padding='max_length', truncation=True, max_length=context_length, return_tensors='pt')
    return tokenized

def train_model(model_manager, dataloader, num_epochs):
    """Train the model."""
    biomedclip_model = model_manager.biomedclip_model
    gpt2_model = model_manager.gpt2_model
    optimizer = model_manager.optimizer
    criterion = model_manager.criterion

    biomedclip_model.train()
    gpt2_model.train()
    
    scaler = torch.amp.GradScaler()  # For mixed precision training

    for epoch in range(num_epochs):
        running_loss = 0.0
        for batch in dataloader:
            images, captions = batch

            # Convert images to tensors if necessary
            if isinstance(images, list):
                images = torch.stack([torch.tensor(img) for img in images])
            images = images.to(model_manager.device)

            # Tokenize captions
            tokenized_captions = tokenize_captions(captions, model_manager.gpt2_tokenizer, context_length=256)
            tokenized_captions_ids = tokenized_captions['input_ids'].to(model_manager.device)
            
            # Extract image embeddings
            image_embeddings = model_manager.extract_image_embeddings(images)
            
            # Generate captions
            generated_logits = model_manager.generate_captions(image_embeddings, ["Medical image description: "] * len(images))
            
            # Flatten logits and targets for CrossEntropyLoss
            logits_flattened = generated_logits.view(-1, generated_logits.size(-1))
            targets_flattened = tokenized_captions_ids.view(-1)
            
            # Compute the loss
            loss = criterion(logits_flattened, targets_flattened)
            
            # Backward pass with scaled loss
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            
            optimizer.zero_grad()
            running_loss += loss.item()
        
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(dataloader)}")


In [4]:
class ModelManager:
    def __init__(self, biomedclip_model_name: str, gpt2_model_name: str):
        # Load BiomedCLIP model and preprocess functions
        self.biomedclip_model, self.preprocess_train, self.preprocess_val = open_clip.create_model_and_transforms(biomedclip_model_name)
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.biomedclip_model.to(self.device)
        
        # Load GPT-2 model and tokenizer
        self.gpt2_model = GPT2LMHeadModel.from_pretrained(gpt2_model_name).to(self.device)
        self.gpt2_tokenizer = GPT2Tokenizer.from_pretrained(gpt2_model_name)
        
        # Set the correct output dimension for image embeddings
        self.image_embeddings_dim = 512  # Based on observed feature shape
        
        # Initialize projection layer
        self.image_embeddings_projected = torch.nn.Linear(self.image_embeddings_dim, self.gpt2_model.config.n_embd).to(self.device)
        
        # Initialize optimizer and criterion
        self.optimizer = torch.optim.Adam(
            list(self.biomedclip_model.parameters()) + list(self.gpt2_model.parameters()),
            lr=1e-4
        )
        self.criterion = torch.nn.CrossEntropyLoss()

    def extract_image_embeddings(self, images):
        """Extract image embeddings from BiomedCLIP."""
        with torch.no_grad():
            preprocessed_images = preprocess_images(images, self.preprocess_val)  # Use preprocess_val for validation images
            image_embeddings = self.biomedclip_model.encode_image(preprocessed_images.to(self.device))
            return self.image_embeddings_projected(image_embeddings)
    
    def generate_captions(self, image_embeddings_gpt2, prompts, num_return_sequences=1):
        """Generate logits for captions using GPT-2."""
        self.gpt2_model.eval()
        generated_logits = []
        
        for i, embedding in enumerate(image_embeddings_gpt2):
            prompt = prompts[i]
            
            # Prepare inputs
            input_ids = self.gpt2_tokenizer(prompt, return_tensors='pt').input_ids.to(self.device)
            input_embeds = self.gpt2_model.transformer.wte(input_ids)
            
            # Expand embedding dimensions to match input_embeds
            embedding_expanded = embedding.unsqueeze(0).unsqueeze(1)
            
            combined_embeds = torch.cat((embedding_expanded.expand(input_embeds.size(0), -1, -1), input_embeds), dim=1)
            
            # Create attention mask
            attention_mask = torch.ones(combined_embeds.shape[:-1]).to(self.device)
            
            # Generate logits
            logits = self.gpt2_model(
                inputs_embeds=combined_embeds,
                attention_mask=attention_mask
            ).logits
            
            generated_logits.append(logits)
            
        return torch.stack(generated_logits)  # Shape: (batch_size, seq_length, vocab_size)

In [5]:
# Define file paths
train_caption_file = "../Datasets/ROCO2/train_captions.csv"
train_image_folder = "../Datasets/ROCO2/train_images/train/"
test_caption_file = "../Datasets/ROCO2/test_captions.csv"
test_image_folder = "../Datasets/ROCO2/test_images/test/"

# Define image transformation
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

# Create dataset and dataloader
dataset = MedicalImageDataset(csv_file=train_caption_file, image_folder=train_image_folder, transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=0, pin_memory=True)

# Initialize model manager
model_manager = ModelManager('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224', 'gpt2')

  checkpoint = torch.load(checkpoint_path, map_location=map_location)


In [6]:
# Train model
train_model(model_manager, dataloader, num_epochs=5)

  x = F.scaled_dot_product_attention(


ValueError: Expected input batch_size (192) to match target batch_size (8192).