In [1]:
import os
import numpy as np 
import pandas as pd
from tqdm import tqdm 
from PIL import Image
import matplotlib.pyplot as plt

import torch 
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms 
from torch.utils.data import Dataset, DataLoader
from torch.nn.functional import cross_entropy
from transformers import BertTokenizer, BertModel, ViTModel

In [2]:
class ChestXrayDataset(Dataset):
    def __init__(self, csv_file, img_dir, transform=None):
        self.data_frame = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.data_frame)
    
    def __getitem__(self, index):
        img_name = os.path.join(self.img_dir , self.data_frame['filename'].iloc[index])
        caption = self.data_frame['impression'].iloc[index]

        image = Image.open(img_name).convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, caption

In [3]:
class MaskedLanguageModeling(nn.Module):
    def __init__(self, 
                 text_encoder_name='bert-base-uncased', 
                 vision_encoder_name='google/vit-base-patch16-224-in21k',
                 mask_token_id=None,  # Pass mask_token_id from tokenizer
                 pad_token_id=None,   # Pass pad_token_id from tokenizer
                 mask_prob=0.15):
        super(MaskedLanguageModeling, self).__init__()
        
        # Initialize BERT model and separate layers for text encoding and multimodal fusion
        self.bert_model = BertModel.from_pretrained(text_encoder_name)
        self.text_encoder_layers = nn.ModuleList(self.bert_model.encoder.layer[:6])  # First 6 layers for text encoding
        self.multimodal_encoder_layers = nn.ModuleList(self.bert_model.encoder.layer[6:])  # Last 6 layers for multimodal fusion
        
        # Vision transformer for image encoding
        self.vision_encoder = ViTModel.from_pretrained(vision_encoder_name)
        
        # Prediction head for MLM
        self.prediction_head = nn.Linear(self.bert_model.config.hidden_size, self.bert_model.config.vocab_size)
        
        # Masking probability and token IDs
        self.mask_prob = mask_prob
        self.mask_token_id = mask_token_id
        self.pad_token_id = pad_token_id

    def mask_tokens(self, input_ids):
        """
        Masks input tokens with a probability of `mask_prob` and returns the masked input_ids and labels.
        """
        labels = input_ids.clone()
        
        # Ensure pad_token_id is on the same device as input_ids
        pad_token_id = torch.tensor(self.pad_token_id, device=input_ids.device)
        
        # Create mask array for tokens to be masked
        mask = (torch.rand(input_ids.shape, device=input_ids.device) < self.mask_prob) & (input_ids != pad_token_id)
        
        # Replace tokens with [MASK] token where mask is True
        input_ids[mask] = self.mask_token_id
        labels[~mask] = -100  # Set label to -100 for unmasked tokens, to ignore them in the loss
        
        return input_ids, labels

    def forward(self, images, tokenized_text):
        """
        Forward pass that processes images and tokenized text and outputs MLM predictions.
        """
        input_ids, attention_mask = tokenized_text['input_ids'], tokenized_text['attention_mask']
        
        # 1. Image Encoding: Extract features from images using the ViT model
        image_features = self.vision_encoder(images).last_hidden_state  # Shape: (batch_size, num_patches, hidden_size)
        
        # 2. Text Encoding: Apply masking on input_ids
        masked_input_ids, labels = self.mask_tokens(input_ids)
        
        # Convert token IDs to embeddings using BERT's embedding layer
        text_features = self.bert_model.embeddings(masked_input_ids)
        
        # Pass masked tokens through the first 6 layers of BERT for text encoding
        for layer in self.text_encoder_layers:
            text_features = layer(text_features)[0]
        
        # 3. Multimodal Interaction: Concatenate image and text features and pass through multimodal encoder
        combined_features = torch.cat((image_features, text_features), dim=1)  # Concatenate along sequence dimension
        multimodal_features = combined_features
        for layer in self.multimodal_encoder_layers:
            multimodal_features = layer(multimodal_features)[0]
        
        # 4. MLM Prediction: Apply prediction head to the text portion of the multimodal output
        text_predictions = self.prediction_head(multimodal_features[:, -text_features.size(1):, :])  # Only text portion
        
        return text_predictions, labels

In [4]:
image_dir = "Dataset\Indiana University - Chest X-Rays\images\images"
image_caption_csv_path = "Dataset\Indiana University - Chest X-Rays\indiana_chest_xray_captions.csv"

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

dataset = ChestXrayDataset(csv_file=image_caption_csv_path, img_dir=image_dir, transform=transform)
data_loader = DataLoader(dataset, batch_size=8, shuffle=True)

In [None]:
image, caption = dataset.__getitem__(3)

image = image.permute(1, 2, 0).numpy()
plt.imshow(image)
plt.title(f"{caption}")
plt.show()

In [5]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
mask_token_id = tokenizer.mask_token_id
pad_token_id = tokenizer.pad_token_id

mlm_model = MaskedLanguageModeling(mask_token_id=mask_token_id, pad_token_id=pad_token_id)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mlm_model = mlm_model.to(device)

loss_fn = nn.CrossEntropyLoss()

# Get one batch from the data loader
images, captions = next(iter(data_loader))

# Tokenize captions outside the model and move to device
tokenized = tokenizer(captions, return_tensors="pt", padding=True, truncation=True)
input_ids = tokenized['input_ids'].to(device)
attention_mask = tokenized['attention_mask'].to(device)
tokenized_text = {'input_ids': input_ids, 'attention_mask': attention_mask}
images = images.to(device)

In [6]:
text_predictions, labels = mlm_model(images, tokenized_text)

In [7]:
print(text_predictions.shape, labels.shape)

torch.Size([8, 37, 30522]) torch.Size([8, 37])


In [8]:
text_predictions = text_predictions.view(-1, mlm_model.bert_model.config.vocab_size)
labels = labels.view(-1)

In [9]:
print(text_predictions.shape, labels.shape)

torch.Size([296, 30522]) torch.Size([296])


In [10]:
loss = loss_fn(text_predictions, labels)

In [11]:
print("Model Predictions Shape:", text_predictions.shape)
print("Labels Shape:", labels.shape)
print("Loss:", loss.item())

Model Predictions Shape: torch.Size([296, 30522])
Labels Shape: torch.Size([296])
Loss: 10.35284423828125


In [13]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
mask_token_id = tokenizer.mask_token_id
pad_token_id = tokenizer.pad_token_id

# Initialize the model, optimizer, and loss function
mlm_model = MaskedLanguageModeling(mask_token_id=mask_token_id, pad_token_id=pad_token_id)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mlm_model = mlm_model.to(device)

optimizer = optim.Adam(mlm_model.parameters(), lr=1e-4)
loss_fn = torch.nn.CrossEntropyLoss()

# Training loop
num_epochs = 3
for epoch in range(num_epochs):
    mlm_model.train()
    total_loss = 0
    batch_count = len(data_loader)
    for i, (images, captions) in enumerate(data_loader):
        # Move images to the correct device
        images = images.to(device)
        
        # Tokenize captions outside the model and move to device
        tokenized = tokenizer(captions, return_tensors="pt", padding=True, truncation=True)
        input_ids = tokenized['input_ids'].to(device)
        attention_mask = tokenized['attention_mask'].to(device)
        tokenized_text = {'input_ids': input_ids, 'attention_mask': attention_mask}
        
        # Forward pass
        optimizer.zero_grad()
        text_predictions, labels = mlm_model(images, tokenized_text)
        
        # Reshape predictions and labels for calculating loss
        text_predictions = text_predictions.view(-1, mlm_model.bert_model.config.vocab_size)
        labels = labels.view(-1)
        
        # Calculate loss and perform backpropagation
        loss = loss_fn(text_predictions, labels)
        loss.backward()
        optimizer.step()
        
        # Accumulate loss for reporting
        total_loss += loss.item()

        # Print progress every 10 batches
        if (i + 1) % 10 == 0:
            print(f"Batch [{i + 1}/{batch_count}], Loss: {total_loss / (i + 1)}")
    
    avg_loss = total_loss / len(data_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")

print("Training complete.")

Batch [10/927], Loss: 9.709396171569825
Batch [20/927], Loss: 9.182132697105407
Batch [30/927], Loss: 8.739301300048828
Batch [40/927], Loss: 8.268716651201249
Batch [50/927], Loss: 7.971540026664734
Batch [60/927], Loss: 7.609890321890513
Batch [70/927], Loss: 7.3763805832181655
Batch [80/927], Loss: 7.140872257947922
Batch [90/927], Loss: 6.8168478833304516
Batch [100/927], Loss: 6.57410493016243
Batch [110/927], Loss: 6.3626073501326825
Batch [120/927], Loss: 6.202050961057345
Batch [130/927], Loss: 6.013554648252634
Batch [140/927], Loss: 5.853852467026029
Batch [150/927], Loss: 5.683526790936788
Batch [160/927], Loss: 5.511126710660756
Batch [170/927], Loss: 5.3891183251843735
Batch [180/927], Loss: 5.28337623162402
Batch [190/927], Loss: 5.173957423316805
Batch [200/927], Loss: 5.10396466627717
Batch [210/927], Loss: 5.01624285195555
Batch [220/927], Loss: 4.921440846947107
Batch [230/927], Loss: 4.827430165461872
Batch [240/927], Loss: 4.756799542034666
Batch [250/927], Loss: 4.

In [21]:
def test_sample(model, tokenizer, image, caption, device):
    model.eval()  # Set the model to evaluation mode
    
    # Tokenize the caption
    tokenized = tokenizer(caption, return_tensors="pt", padding=True, truncation=True)
    input_ids = tokenized['input_ids'].to(device)
    attention_mask = tokenized['attention_mask'].to(device)
    tokenized_text = {'input_ids': input_ids, 'attention_mask': attention_mask}
    
    # Move image to the device
    image = image.to(device).unsqueeze(0)  # Add batch dimension
    
    # Mask tokens in the caption (simulating the MLM task)
    masked_input_ids, labels = model.mask_tokens(input_ids)
    
    # Pass the image and masked caption through the model
    with torch.no_grad():
        text_predictions, _ = model(image, {'input_ids': masked_input_ids, 'attention_mask': attention_mask})
    
    # Identify masked positions in the labels
    masked_positions = (labels != -100).squeeze()
    
    # Get the predicted token IDs at the masked positions
    predicted_token_ids = text_predictions.argmax(dim=-1).squeeze()[masked_positions]
    
    # Convert token IDs back to words
    predicted_tokens = tokenizer.convert_ids_to_tokens(predicted_token_ids.tolist())
    
    # Print original and predicted tokens
    print("Original Caption:", caption)
    print("Masked Caption:", tokenizer.decode(masked_input_ids.squeeze()))
    print("Predicted Tokens at Masked Positions:", predicted_tokens)

# Example usage:
# Load a sample from the dataset (assuming `dataset` is an instance of ChestXrayDataset)
sample_image, sample_caption = dataset[3]

# Move model to device if not already done
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mlm_model = mlm_model.to(device)

# Test the model on a sample
test_sample(mlm_model, tokenizer, sample_image, sample_caption, device)


Original Caption: No acute pulmonary findings.
Masked Caption: [MASK] no acute pulmonary [MASK]. [SEP]
Predicted Tokens at Masked Positions: ['[CLS]', 'disease']


In [None]:
def train_mlm_model(model, data_loader, tokenizer, device, num_epochs=3, learning_rate=1e-4, checkpoint_path="mlm_checkpoint.pth"):
    """
    Train the Masked Language Modeling model with checkpointing.

    Parameters:
    - model: MaskedLanguageModeling instance
    - data_loader: DataLoader instance with training data
    - tokenizer: BertTokenizer instance
    - device: torch.device, either 'cuda' or 'cpu'
    - num_epochs: int, number of training epochs
    - learning_rate: float, learning rate for optimizer
    - checkpoint_path: str, path to save/load model checkpoint
    """

    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    loss_fn = torch.nn.CrossEntropyLoss()

    # Load checkpoint if exists
    start_epoch = 0
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        print(f"Resuming training from epoch {start_epoch + 1}")

    for epoch in range(start_epoch, num_epochs):
        model.train()
        total_loss = 0
        batch_count = len(data_loader)
        
        for i, (images, captions) in enumerate(data_loader):
            # Move images to the correct device
            images = images.to(device)
            
            # Tokenize captions outside the model and move to device
            tokenized = tokenizer(captions, return_tensors="pt", padding=True, truncation=True)
            input_ids = tokenized['input_ids'].to(device)
            attention_mask = tokenized['attention_mask'].to(device)
            tokenized_text = {'input_ids': input_ids, 'attention_mask': attention_mask}
            
            # Forward pass
            optimizer.zero_grad()
            text_predictions, labels = model(images, tokenized_text)
            
            # Reshape predictions and labels for calculating loss
            text_predictions = text_predictions.view(-1, model.bert_model.config.vocab_size)
            labels = labels.view(-1)
            
            # Calculate loss and perform backpropagation
            loss = loss_fn(text_predictions, labels)
            loss.backward()
            optimizer.step()
            
            # Accumulate loss for reporting
            total_loss += loss.item()

            # Print progress every 10 batches
            if (i + 1) % 10 == 0:
                print(f"Epoch [{epoch + 1}/{num_epochs}], Batch [{i + 1}/{batch_count}], Loss: {total_loss / (i + 1):.4f}")
        
        # Average loss for the epoch
        avg_loss = total_loss / len(data_loader)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")
        
        # Save checkpoint
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss,
        }, checkpoint_path)
        print(f"Checkpoint saved at epoch {epoch + 1}")

    print("Training complete.")

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize tokenizer, model, and load checkpoint if available
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
mask_token_id = tokenizer.mask_token_id
pad_token_id = tokenizer.pad_token_id

mlm_model = MaskedLanguageModeling(mask_token_id=mask_token_id, pad_token_id=pad_token_id)
mlm_model = mlm_model.to(device)

# Assuming data_loader is already initialized with ChestXrayDataset
train_mlm_model(mlm_model, data_loader, tokenizer, device, num_epochs=3, learning_rate=1e-4, checkpoint_path="models/mlm_checkpoint.pth")