In [1]:
import os
import random
import numpy as np 
import pandas as pd
from tqdm import tqdm 
from PIL import Image
import matplotlib.pyplot as plt

import torch 
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms 
from transformers import ViTFeatureExtractor
from torch.utils.data import Dataset, DataLoader
from torch.nn.functional import cross_entropy
from transformers import BertTokenizer, BertModel, ViTModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [2]:
class ChestXrayDataset(Dataset):
    def __init__(self, csv_file, img_dir, transform=None):
        self.data_frame = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.data_frame)
    
    def __getitem__(self, index):
        img_name = os.path.join(self.img_dir , self.data_frame['filename'].iloc[index])
        caption = self.data_frame['impression'].iloc[index]

        image = Image.open(img_name).convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, caption

In [3]:
image_dir = "Dataset\Indiana University - Chest X-Rays\images\images"
image_caption_csv_path = "Dataset\Indiana University - Chest X-Rays\indiana_chest_xray_captions.csv"

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

dataset = ChestXrayDataset(csv_file=image_caption_csv_path, img_dir=image_dir, transform=transform)
data_loader = DataLoader(dataset, batch_size=8, shuffle=True)

In [16]:
class MaskedLanguageModeling(nn.Module):
    def __init__(self, bert_model, image_encoder, mask_token_id, pad_token_id, mask_prob=0.15):
        super(MaskedLanguageModeling, self).__init__()
        
        # Full BERT model to access embeddings directly
        self.bert_model = bert_model
        self.text_encoder_layers = nn.ModuleList(bert_model.encoder.layer[:6])  # First 6 layers for text encoding
        self.multimodal_encoder_layers = nn.ModuleList(bert_model.encoder.layer[6:])  # Last 6 layers for multimodal fusion
        
        # Shared image encoder (e.g., Vision Transformer)
        self.image_encoder = image_encoder

        # Prediction head for MLM
        hidden_size = bert_model.config.hidden_size
        self.prediction_head = nn.Linear(hidden_size, bert_model.config.vocab_size)
        
        # Masking probability and token IDs
        self.mask_prob = mask_prob
        self.mask_token_id = mask_token_id
        self.pad_token_id = pad_token_id

    def mask_tokens(self, input_ids):
        labels = input_ids.clone()
        mask = (torch.rand(input_ids.shape, device=input_ids.device) < self.mask_prob) & (input_ids != self.pad_token_id)
        input_ids[mask] = self.mask_token_id
        labels[~mask] = -100  # Ignore unmasked tokens in the loss
        return input_ids, labels

    def forward(self, images, tokenized_text):
        input_ids, attention_mask = tokenized_text['input_ids'], tokenized_text['attention_mask']
        
        # Image Encoding
        image_features = self.image_encoder(pixel_values=images).last_hidden_state  # Shape: (batch_size, num_patches, hidden_size)
        
        # Text Encoding with masking
        masked_input_ids, labels = self.mask_tokens(input_ids)
        
        # Use embeddings directly from BERT
        text_embeddings = self.bert_model.embeddings(masked_input_ids)
        
        # Pass through the first 6 layers of BERT for text encoding
        for layer in self.text_encoder_layers:
            text_embeddings = layer(text_embeddings)[0]
        
        # Multimodal Interaction
        combined_features = torch.cat((image_features, text_embeddings), dim=1)  # Concatenate along sequence dimension
        multimodal_features = combined_features
        for layer in self.multimodal_encoder_layers:
            multimodal_features = layer(multimodal_features)[0]
        
        # MLM Prediction
        text_predictions = self.prediction_head(multimodal_features[:, -text_embeddings.size(1):, :])  # Only predict on text portion
        
        return text_predictions, labels

In [17]:
bert_model = BertModel.from_pretrained('bert-base-uncased').to(device)
image_encoder = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k').to(device)

# Initialize tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
mask_token_id = tokenizer.mask_token_id
pad_token_id = tokenizer.pad_token_id

# Initialize the MLM model
mlm_model = MaskedLanguageModeling(
    bert_model=bert_model,
    image_encoder=image_encoder,
    mask_token_id=mask_token_id,
    pad_token_id=pad_token_id
).to(device)

In [18]:
images, captions = next(iter(data_loader))
images = images.to(device)
tokenized_text = tokenizer(captions, return_tensors="pt", padding=True, truncation=True).to(device)

# Forward pass through the MLM model
with torch.no_grad():
    text_predictions, labels = mlm_model(images, tokenized_text)

# Output shapes for verification
print("Text Predictions Shape:", text_predictions.shape)  # Expected: (batch_size, sequence_length, vocab_size)
print("Labels Shape:", labels.shape)

Text Predictions Shape: torch.Size([8, 15, 30522])
Labels Shape: torch.Size([8, 15])


In [26]:
import torch
import torch.optim as optim
import os

def train_mlm_model(mlm_model, bert_model, image_encoder, data_loader, tokenizer, device, num_epochs=3, learning_rate=1e-4, checkpoint_path="mlm_checkpoint.pth", encoders_checkpoint_path="encoders_checkpoint.pth"):
    """
    Train the Masked Language Modeling model with shared encoders and checkpointing.

    Parameters:
    - mlm_model: MaskedLanguageModeling instance
    - bert_model: Shared BERT model (used for text and multimodal encoding)
    - image_encoder: Shared image encoder (Vision Transformer)
    - data_loader: DataLoader instance with training data
    - tokenizer: BertTokenizer instance
    - device: torch.device, either 'cuda' or 'cpu'
    - num_epochs: int, number of training epochs
    - learning_rate: float, learning rate for optimizer
    - checkpoint_path: str, path to save/load model checkpoint
    - encoders_checkpoint_path: str, path to save/load encoder checkpoints
    """

    optimizer = optim.Adam(mlm_model.parameters(), lr=learning_rate)
    loss_fn = torch.nn.CrossEntropyLoss()

    # Load checkpoint if it exists
    start_epoch = 0
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device)
        mlm_model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        print(f"Resuming training from epoch {start_epoch + 1}")

    for epoch in range(start_epoch, num_epochs):
        mlm_model.train()
        total_loss = 0
        batch_count = len(data_loader)
        
        for i, (images, captions) in enumerate(data_loader):
            # Move images to the correct device
            images = images.to(device)
            
            # Tokenize captions and move to device
            tokenized = tokenizer(captions, return_tensors="pt", padding=True, truncation=True)
            input_ids = tokenized['input_ids'].to(device)
            attention_mask = tokenized['attention_mask'].to(device)
            tokenized_text = {'input_ids': input_ids, 'attention_mask': attention_mask}
            
            # Forward pass
            optimizer.zero_grad()
            text_predictions, labels = mlm_model(images, tokenized_text)
            
            # Reshape predictions and labels for calculating loss
            vocab_size = bert_model.config.vocab_size
            text_predictions = text_predictions.view(-1, vocab_size)  # Shape: (batch_size * seq_len, vocab_size)
            labels = labels.view(-1)  # Shape: (batch_size * seq_len)
            
            # Calculate loss and backpropagate
            loss = loss_fn(text_predictions, labels)
            loss.backward()
            optimizer.step()
            
            # Accumulate loss for reporting
            total_loss += loss.item()

            # Print progress every 10 batches
            if (i + 1) % 10 == 0:
                print(f"Epoch [{epoch + 1}/{num_epochs}], Batch [{i + 1}/{batch_count}], Loss: {total_loss / (i + 1):.4f}")
        
        # Average loss for the epoch
        avg_loss = total_loss / len(data_loader)
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_loss:.4f}")
        
        # Save MLM model checkpoint
        torch.save({
            'epoch': epoch,
            'model_state_dict': mlm_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss,
        }, checkpoint_path)
        print(f"MLM model checkpoint saved at epoch {epoch + 1}")

        # Save the shared encoders' states to enable reuse
        torch.save({
            'bert_model_state_dict': bert_model.state_dict(),
            'image_encoder_state_dict': image_encoder.state_dict(),
        }, encoders_checkpoint_path)
        print(f"Shared encoders' checkpoints saved at epoch {epoch + 1}")

    print("Training complete.")


In [27]:
train_mlm_model(
    mlm_model=mlm_model,
    bert_model=bert_model,
    image_encoder=image_encoder,
    data_loader=data_loader,
    tokenizer=tokenizer,
    device=device,
    num_epochs=3,
    learning_rate=1e-4,
    checkpoint_path="mlm_checkpoint.pth",
    encoders_checkpoint_path="encoders_checkpoint.pth"
)

Epoch [1/3], Batch [10/927], Loss: 9.8781
Epoch [1/3], Batch [20/927], Loss: 9.1913
Epoch [1/3], Batch [30/927], Loss: 8.8698
Epoch [1/3], Batch [40/927], Loss: 8.5506
Epoch [1/3], Batch [50/927], Loss: 8.2074
Epoch [1/3], Batch [60/927], Loss: 7.8312
Epoch [1/3], Batch [70/927], Loss: 7.5788
Epoch [1/3], Batch [80/927], Loss: 7.3073
Epoch [1/3], Batch [90/927], Loss: 6.9983
Epoch [1/3], Batch [100/927], Loss: 6.7616
Epoch [1/3], Batch [110/927], Loss: 6.5130
Epoch [1/3], Batch [120/927], Loss: 6.3386
Epoch [1/3], Batch [130/927], Loss: 6.1512
Epoch [1/3], Batch [140/927], Loss: 5.9870
Epoch [1/3], Batch [150/927], Loss: 5.8531
Epoch [1/3], Batch [160/927], Loss: 5.7231
Epoch [1/3], Batch [170/927], Loss: 5.6108
Epoch [1/3], Batch [180/927], Loss: 5.4668
Epoch [1/3], Batch [190/927], Loss: 5.3388
Epoch [1/3], Batch [200/927], Loss: 5.2250
Epoch [1/3], Batch [210/927], Loss: 5.1146
Epoch [1/3], Batch [220/927], Loss: 5.0460
Epoch [1/3], Batch [230/927], Loss: 4.9454
Epoch [1/3], Batch [

KeyboardInterrupt: 