In [1]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import pandas as pd
import os
import torch
from torchvision import transforms
import open_clip

In [2]:
# Report which accelerator this kernel can see. The CUDA device queries raise on
# machines without CUDA, so they only run when a GPU is actually available.
cuda_available = torch.cuda.is_available()
print(cuda_available)
if cuda_available:
    current_device = torch.cuda.current_device()
    print(current_device)
    print(torch.cuda.get_device_name(current_device))
else:
    print("CUDA not available - running on CPU")

True
0
NVIDIA GeForce RTX 3070


In [3]:
class MedicalImageDataset(Dataset):
    """Image/caption pairs listed in a CSV file.

    The CSV's first column holds the image id (the file name without extension) and
    its second column holds the caption. Images are loaded from ``image_folder``.
    """

    def __init__(self, csv_file, image_folder, transform=None, image_ext='.jpg'):
        """
        Args:
            csv_file: path of the CSV with (image id, caption) in its first two columns.
            image_folder: directory containing the image files.
            transform: optional callable applied to each RGB PIL image.
            image_ext: extension appended to the image id to form the file name.
        """
        self.data = pd.read_csv(csv_file)
        self.image_folder = image_folder
        self.transform = transform
        self.image_ext = image_ext

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, caption)`` for row ``idx``; image is RGB, then ``transform``-ed."""
        image_id = self.data.iloc[idx, 0]   # first column: image id
        caption = self.data.iloc[idx, 1]    # second column: caption
        # Formatting (instead of `+`) also works when pandas parsed the ids as integers.
        img_name = os.path.join(self.image_folder, f"{image_id}{self.image_ext}")

        # The context manager closes the file handle; convert() returns a new image
        # that stays usable after the file is closed.
        with Image.open(img_name) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, caption

def preprocess_images(images, preprocess_func):
    """Apply ``preprocess_func`` to each image and stack the outputs into one batch tensor.

    Tensor inputs are turned back into PIL images (torchvision ``ToPILImage``) first,
    so ``preprocess_func`` always receives a PIL image for them; other inputs are passed
    through unchanged.
    """
    def _prepare(item):
        if isinstance(item, torch.Tensor):
            item = transforms.ToPILImage()(item)
        return preprocess_func(item)

    return torch.stack([_prepare(item) for item in images])

def tokenize_captions(captions, tokenizer, context_length):
    """Tokenize captions into padded, truncated PyTorch tensors.

    Args:
        captions: a caption string or a sequence of caption strings (a DataLoader may
            yield a tuple or list of strings for a text column).
        tokenizer: Hugging Face tokenizer. If it has no pad token, one is assigned
            (this mutates the tokenizer).
        context_length: maximum number of tokens kept per caption.

    Returns:
        The tokenizer's output for ``return_tensors='pt'``.
    """
    if tokenizer.pad_token is None:
        if tokenizer.eos_token is not None:
            # Reuse EOS: adding a brand-new "[PAD]" token would get an id outside the
            # model's embedding table unless the model is resized. Padded positions
            # are identified through the attention mask anyway.
            tokenizer.pad_token = tokenizer.eos_token
        else:
            # Last resort: the caller must resize the model's token embeddings.
            tokenizer.add_special_tokens({'pad_token': '[PAD]'})

    if not isinstance(captions, str):
        captions = list(captions)

    return tokenizer(captions, padding=True, truncation=True, max_length=context_length, return_tensors='pt')

def _caption_loss(model_manager, images, captions, prompt_ids, context_length):
    """Teacher-forced GPT-2 loss for one batch; also returns the projected image prefixes."""
    device = model_manager.device
    tokenizer = model_manager.gpt2_tokenizer
    gpt2 = model_manager.gpt2_model

    # BiomedCLIP is frozen, so its forward pass needs no autograd graph ...
    with torch.no_grad():
        pixels = preprocess_images(images, model_manager.preprocess_val).to(device)
        image_features = model_manager.biomedclip_model.encode_image(pixels)
    # ... but the projection is what we train, so it runs with gradients enabled.
    prefix = model_manager.image_embeddings_projected(image_features)  # one row per image
    batch_size = prefix.size(0)

    # Append EOS so the model learns where a caption ends.
    eos = tokenizer.eos_token or ""
    texts = [f"{caption}{eos}" for caption in captions]
    tokens = tokenize_captions(texts, tokenizer, context_length)
    caption_ids = tokens.input_ids.to(device)
    caption_mask = tokens.attention_mask.to(device)

    # Sequence layout: [image prefix][prompt tokens][caption tokens (+ padding)].
    wte = gpt2.transformer.wte
    batch_prompt_ids = prompt_ids.expand(batch_size, -1)
    inputs_embeds = torch.cat((prefix.unsqueeze(1), wte(batch_prompt_ids), wte(caption_ids)), dim=1)

    # Prefix and prompt are always attended to but never scored; padding is neither.
    num_context = 1 + batch_prompt_ids.size(1)
    context_mask = torch.ones(batch_size, num_context, dtype=caption_mask.dtype, device=device)
    attention_mask = torch.cat((context_mask, caption_mask), dim=1)
    # -100 is the ignore index of GPT2LMHeadModel's language-modelling loss.
    ignore = torch.full((batch_size, num_context), -100, dtype=caption_ids.dtype, device=device)
    labels = torch.cat((ignore, caption_ids.masked_fill(caption_mask == 0, -100)), dim=1)

    loss = gpt2(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels).loss
    return loss, prefix


def train_model(model_manager, dataloader, num_epochs, learning_rate=1e-4, train_gpt2=False,
                context_length=256, prompt="Medical image description: ", num_logged_captions=4):
    """Train the BiomedCLIP -> GPT-2 captioner with teacher forcing, logging to TensorBoard.

    Each image is encoded by the frozen BiomedCLIP model and mapped by
    ``model_manager.image_embeddings_projected`` to one GPT-2 input embedding (the
    "image prefix"). GPT-2 sees ``[prefix][prompt][caption]`` and is trained to predict
    the caption tokens; prefix, prompt and padding are excluded from the loss. This is
    the same ``[prefix][prompt]`` layout ``model_manager.generate_captions`` builds.

    Args:
        model_manager: object exposing ``biomedclip_model``, ``preprocess_val``,
            ``image_embeddings_projected``, ``gpt2_model``, ``gpt2_tokenizer``,
            ``device`` and ``generate_captions``.
        dataloader: yields ``(images, captions)`` batches.
        num_epochs: number of passes over ``dataloader``.
        learning_rate: AdamW learning rate.
        train_gpt2: also fine-tune GPT-2. Otherwise only the projection layer is
            trained and GPT-2's parameters are set to ``requires_grad=False``.
        context_length: maximum caption length in tokens.
        prompt: text placed between the image prefix and the caption.
        num_logged_captions: captions generated from the last batch of each epoch
            and logged as TensorBoard text (0 disables).

    Returns:
        List with the mean training loss of each epoch (NaN for an empty dataloader).
    """
    import time
    from torch.utils.tensorboard import SummaryWriter
    from tqdm.auto import tqdm

    tokenizer = model_manager.gpt2_tokenizer
    gpt2 = model_manager.gpt2_model
    projection = model_manager.image_embeddings_projected
    device = model_manager.device

    # GPT-2 has no pad token; reuse EOS (already in the embedding table) so padded
    # batches never index past the vocabulary. Padding is masked out of the loss.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    gpt2.requires_grad_(train_gpt2)
    trainable = list(projection.parameters())
    if train_gpt2:
        trainable += list(gpt2.parameters())
    optimizer = torch.optim.AdamW(trainable, lr=learning_rate)

    prompt_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device)

    writer = SummaryWriter()
    epoch_losses = []
    global_step = 0
    try:
        for epoch in range(num_epochs):
            epoch_start_time = time.time()
            # Reset modes every epoch: generate_captions() switches GPT-2 to eval mode.
            model_manager.biomedclip_model.eval()
            gpt2.train(train_gpt2)
            projection.train()

            loss_sum, num_batches, last_prefix = 0.0, 0, None
            for images, captions in tqdm(dataloader, desc=f"Epoch {epoch+1}"):
                loss, last_prefix = _caption_loss(model_manager, images, captions, prompt_ids, context_length)
                optimizer.zero_grad(set_to_none=True)
                loss.backward()
                optimizer.step()

                loss_value = loss.item()
                writer.add_scalar('Loss/train', loss_value, global_step)
                loss_sum += loss_value
                num_batches += 1
                global_step += 1

            epoch_duration = time.time() - epoch_start_time
            mean_loss = loss_sum / num_batches if num_batches else float('nan')
            epoch_losses.append(mean_loss)
            print(f"Epoch {epoch+1} took {epoch_duration:.2f} seconds, mean loss {mean_loss:.4f}")
            writer.add_scalar('Epoch Duration', epoch_duration, epoch)
            writer.add_scalar('Loss/epoch', mean_loss, epoch)

            # Qualitative check: caption a few images of the last batch with the current weights.
            if num_logged_captions and last_prefix is not None:
                sample = last_prefix[:num_logged_captions].detach()
                generated = model_manager.generate_captions(sample, [prompt] * sample.size(0))
                for i, caption in enumerate(generated):
                    writer.add_text(f'Generated Caption/{i}', caption, epoch)
    finally:
        writer.close()

    return epoch_losses


In [4]:
class ModelManager:
    """Bundles a frozen BiomedCLIP image encoder, a learnable projection layer and GPT-2.

    ``image_embeddings_projected`` maps BiomedCLIP image features to a single GPT-2
    input embedding, which ``generate_captions`` places in front of a text prompt.
    """

    def __init__(self, biomedclip_model_name: str, gpt2_model_name: str, image_embeddings_dim: int = 512):
        """Load both models and create the projection layer on the selected device.

        Args:
            biomedclip_model_name: open_clip model name (e.g. an ``hf-hub:`` identifier).
            gpt2_model_name: Hugging Face model id of the GPT-2 checkpoint.
            image_embeddings_dim: size of BiomedCLIP's image features (512 was the
                feature size observed for the checkpoint used in this notebook).

        Raises:
            RuntimeError: if the BiomedCLIP checkpoint cannot be loaded.
        """
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

        # Load BiomedCLIP model and preprocess functions
        try:
            self.biomedclip_model, self.preprocess_train, self.preprocess_val = open_clip.create_model_and_transforms(biomedclip_model_name)
        except RuntimeError as err:
            if 'state_dict' not in str(err):
                raise
            raise RuntimeError(
                f"Could not load {biomedclip_model_name!r}: its state_dict does not match the model "
                "built by the installed open_clip/transformers versions. Install the package "
                "versions recommended on the model card and retry."
            ) from err
        self.biomedclip_model.to(self.device)
        # Only used as a frozen feature extractor, so run it in inference mode.
        self.biomedclip_model.eval()

        # Load GPT-2 model and tokenizer
        self.gpt2_model = GPT2LMHeadModel.from_pretrained(gpt2_model_name).to(self.device)
        self.gpt2_tokenizer = GPT2Tokenizer.from_pretrained(gpt2_model_name)
        # Without a pad token, pad_token_id is None; reuse EOS, which is a valid vocabulary id.
        if self.gpt2_tokenizer.pad_token is None:
            self.gpt2_tokenizer.pad_token = self.gpt2_tokenizer.eos_token

        # Output dimension of BiomedCLIP image embeddings
        self.image_embeddings_dim = image_embeddings_dim

        # Learnable projection from BiomedCLIP feature space to GPT-2's embedding size
        self.image_embeddings_projected = torch.nn.Linear(self.image_embeddings_dim, self.gpt2_model.config.n_embd).to(self.device)

    def extract_image_embeddings(self, images):
        """Encode ``images`` with BiomedCLIP and project them into GPT-2's embedding space.

        Images go through ``preprocess_images`` with ``preprocess_val``. Only the frozen
        encoder runs without gradients; the projection keeps its autograd graph so it can
        be trained from this output. Presumably returns one row per image — assumes
        ``encode_image`` yields (batch, image_embeddings_dim); confirm for other checkpoints.
        """
        with torch.no_grad():
            preprocessed_images = preprocess_images(images, self.preprocess_val)
            image_embeddings = self.biomedclip_model.encode_image(preprocessed_images.to(self.device))
        return self.image_embeddings_projected(image_embeddings)

    def generate_captions(self, image_embeddings_gpt2, prompts, num_return_sequences=1, max_length=50, **generate_kwargs):
        """Generate captions with GPT-2, conditioned on each image embedding and a prompt.

        For every embedding, GPT-2 receives ``[image embedding][prompt tokens]`` as
        ``inputs_embeds`` and continues the sequence.

        Args:
            image_embeddings_gpt2: projected embeddings, one 1-D tensor of size ``n_embd``
                per image (e.g. the rows returned by ``extract_image_embeddings``).
            prompts: one prompt per embedding (indexed by position), or a single string
                used for every embedding.
            num_return_sequences: captions per image. Hugging Face ``generate`` needs
                sampling or beam search for values > 1 (e.g. ``do_sample=True``).
            max_length: forwarded to ``generate``.
            **generate_kwargs: extra keyword arguments for ``generate``.

        Returns:
            Flat list of decoded captions: ``num_return_sequences`` consecutive entries per image.
        """
        if isinstance(prompts, str):
            prompts = [prompts] * len(image_embeddings_gpt2)

        pad_token_id = self.gpt2_tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.gpt2_tokenizer.eos_token_id
        wte = self.gpt2_model.transformer.wte

        # Restore the caller's train/eval mode afterwards so training loops are unaffected.
        was_training = self.gpt2_model.training
        self.gpt2_model.eval()
        generated_captions = []
        try:
            with torch.no_grad():
                for i, embedding in enumerate(image_embeddings_gpt2):
                    input_ids = self.gpt2_tokenizer(prompts[i], return_tensors='pt').input_ids.to(self.device)
                    # (n_embd,) -> (1, 1, n_embd): a single-token prefix in front of the prompt
                    prefix = embedding.reshape(1, 1, -1).to(self.device)
                    combined_embeds = torch.cat((prefix, wte(input_ids)), dim=1)
                    attention_mask = torch.ones(combined_embeds.shape[:-1], dtype=torch.long, device=self.device)

                    generated_ids = self.gpt2_model.generate(
                        inputs_embeds=combined_embeds,
                        attention_mask=attention_mask,
                        max_length=max_length,
                        pad_token_id=pad_token_id,
                        num_return_sequences=num_return_sequences,
                        **generate_kwargs,
                    )
                    # Decode every returned sequence, not just the first one.
                    generated_captions.extend(
                        self.gpt2_tokenizer.decode(ids, skip_special_tokens=True) for ids in generated_ids
                    )
        finally:
            self.gpt2_model.train(was_training)

        return generated_captions


In [5]:
# Define file paths
train_caption_file = "../Datasets/ROCO2/train_captions.csv"
train_image_folder = "../Datasets/ROCO2/train_images/train/"
test_caption_file = "../Datasets/ROCO2/test_captions.csv"
test_image_folder = "../Datasets/ROCO2/test_images/test/"

# Run configuration
SEED = 42        # fixes the projection-layer initialisation and the shuffle order
BATCH_SIZE = 32

torch.manual_seed(SEED)

# Define image transformation
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

# Create dataset and dataloader; the shuffle uses its own seeded generator so it is
# reproducible regardless of how much of the global RNG other cells consume.
dataset = MedicalImageDataset(csv_file=train_caption_file, image_folder=train_image_folder, transform=transform)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0,
                        generator=torch.Generator().manual_seed(SEED))

# Initialize model manager.
# NOTE(review): the recorded run failed here with "Missing key(s) in state_dict ...
# position_ids", which looks like an open_clip/transformers version mismatch with the
# checkpoint — pin the versions recommended on the BiomedCLIP model card.
model_manager = ModelManager('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224', 'gpt2')



RuntimeError: Error(s) in loading state_dict for CustomTextCLIP:
	Missing key(s) in state_dict: "text.transformer.embeddings.position_ids". 

In [None]:
# Train the captioning model for a fixed number of passes over the training set
NUM_EPOCHS = 5
train_model(model_manager, dataloader, num_epochs=NUM_EPOCHS)