In [1]:
import os
import random
import numpy as np 
import pandas as pd
from tqdm import tqdm 
from PIL import Image
import matplotlib.pyplot as plt

import torch 
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms 
from transformers import ViTFeatureExtractor
from torch.utils.data import Dataset, DataLoader
from torch.nn.functional import cross_entropy
from transformers import BertTokenizer, BertModel, ViTModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [2]:
class ChestXrayDataset(Dataset):
    def __init__(self, csv_file, img_dir, transform=None):
        self.data_frame = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.data_frame)
    
    def __getitem__(self, index):
        img_name = os.path.join(self.img_dir , self.data_frame['filename'].iloc[index])
        caption = self.data_frame['impression'].iloc[index]

        image = Image.open(img_name).convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, caption

In [3]:
image_dir = "Dataset\Indiana University - Chest X-Rays\images\images"
image_caption_csv_path = "Dataset\Indiana University - Chest X-Rays\indiana_chest_xray_captions.csv"

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

dataset = ChestXrayDataset(csv_file=image_caption_csv_path, img_dir=image_dir, transform=transform)
data_loader = DataLoader(dataset, batch_size=8, shuffle=True)

In [4]:
class ImageTextContrastive(nn.Module):
    def __init__(self, image_encoder, text_encoder, hidden_size, projection_dim=256, momentum=0.999, lookup_size=65535):
        super(ImageTextContrastive, self).__init__()
        
        # Shared image and text encoders
        self.image_encoder = image_encoder
        self.text_encoder = text_encoder

        # Projections to 256-dimensional vectors
        self.image_projection = nn.Linear(hidden_size, projection_dim)
        self.text_projection = nn.Linear(hidden_size, projection_dim)

        # Momentum encoders (for MoCo-style momentum updates)
        self.momentum_image_encoder = image_encoder  # Clone of image encoder for momentum updates
        self.momentum_text_encoder = text_encoder    # Clone of text encoder for momentum updates
        self.momentum_image_projection = nn.Linear(hidden_size, projection_dim)
        self.momentum_text_projection = nn.Linear(hidden_size, projection_dim)
        
        # Initialize the momentum encoder weights to match the base encoder weights
        self._init_momentum_encoders()
        
        # Lookup table for recent image-text representations
        self.lookup_size = lookup_size
        self.register_buffer("image_lookup", torch.randn(lookup_size, projection_dim))
        self.register_buffer("text_lookup", torch.randn(lookup_size, projection_dim))
        
        # Momentum for updating momentum encoders
        self.momentum = momentum
        self.lookup_index = 0  # Pointer for updating the lookup table

    def _init_momentum_encoders(self):
        # Initialize momentum encoders to match the base encoders
        for param_q, param_k in zip(self.image_encoder.parameters(), self.momentum_image_encoder.parameters()):
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False  # Momentum encoder is not trainable
            
        for param_q, param_k in zip(self.text_encoder.parameters(), self.momentum_text_encoder.parameters()):
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False
        
        # Initialize projection layers similarly
        self.momentum_image_projection.load_state_dict(self.image_projection.state_dict())
        self.momentum_text_projection.load_state_dict(self.text_projection.state_dict())
        
    @torch.no_grad()
    def _momentum_update_encoders(self):
        # Apply momentum update to both image and text encoders
        for param_q, param_k in zip(self.image_encoder.parameters(), self.momentum_image_encoder.parameters()):
            param_k.data = param_k.data * self.momentum + param_q.data * (1.0 - self.momentum)
        
        for param_q, param_k in zip(self.text_encoder.parameters(), self.momentum_text_encoder.parameters()):
            param_k.data = param_k.data * self.momentum + param_q.data * (1.0 - self.momentum)

        # Update projection layers similarly
        for param_q, param_k in zip(self.image_projection.parameters(), self.momentum_image_projection.parameters()):
            param_k.data = param_k.data * self.momentum + param_q.data * (1.0 - self.momentum)

        for param_q, param_k in zip(self.text_projection.parameters(), self.momentum_text_projection.parameters()):
            param_k.data = param_k.data * self.momentum + param_q.data * (1.0 - self.momentum)

    def forward(self, images, input_ids, attention_mask):
        # Image encoding with CLS token extraction
        image_features = self.image_encoder(images).last_hidden_state[:, 0, :]  # CLS token
        image_features = F.normalize(self.image_projection(image_features), dim=-1)

        # Text encoding with CLS token extraction
        text_features = self.extract_text_features(input_ids, attention_mask)
        text_features = F.normalize(self.text_projection(text_features), dim=-1)

        # Momentum encoding for contrastive loss calculation
        with torch.no_grad():
            self._momentum_update_encoders()
            momentum_image_features = self.momentum_image_encoder(images).last_hidden_state[:, 0, :]
            momentum_image_features = F.normalize(self.momentum_image_projection(momentum_image_features), dim=-1)

            momentum_text_features = self.extract_text_features(input_ids, attention_mask, momentum=True)
            momentum_text_features = F.normalize(self.momentum_text_projection(momentum_text_features), dim=-1)

            # Update lookup table
            self._update_lookup_table(momentum_image_features, momentum_text_features)
        
        # Contrastive loss calculation
        contrastive_loss = self.compute_contrastive_loss(image_features, text_features)
        
        return contrastive_loss

    def extract_text_features(self, input_ids, attention_mask, momentum=False):
        encoder = self.momentum_text_encoder if momentum else self.text_encoder
        embeddings = encoder.embeddings(input_ids=input_ids)
        
        # Ensure attention_mask has the correct shape for broadcasting
        attention_mask = attention_mask[:, None, None, :]  # Shape: (batch_size, 1, 1, sequence_length)
        
        # Apply only the first 6 layers of the encoder
        text_features = embeddings
        for layer in encoder.encoder.layer[:6]:
            text_features = layer(text_features, attention_mask=attention_mask)[0]
        
        return text_features[:, 0, :]  # CLS token

    def _update_lookup_table(self, image_features, text_features):
        # Update the lookup table with new entries using a circular index
        batch_size = image_features.size(0)
        if batch_size > self.lookup_size:
            batch_size = self.lookup_size
        
        self.image_lookup[self.lookup_index:self.lookup_index + batch_size] = image_features[:batch_size]
        self.text_lookup[self.lookup_index:self.lookup_index + batch_size] = text_features[:batch_size]
        self.lookup_index = (self.lookup_index + batch_size) % self.lookup_size

    def compute_contrastive_loss(self, image_features, text_features):
        # Compute cosine similarities between image and text features
        sim_i2t = torch.mm(image_features, self.text_lookup.T)  # Image to text similarity
        sim_t2i = torch.mm(text_features, self.image_lookup.T)  # Text to image similarity

        # Apply contrastive learning loss based on ALBEF's approach
        labels = torch.arange(image_features.size(0), device=image_features.device)
        loss_i2t = F.cross_entropy(sim_i2t, labels)
        loss_t2i = F.cross_entropy(sim_t2i, labels)

        return (loss_i2t + loss_t2i) / 2

In [5]:
import copy 

In [6]:
class ImageTextContrastive(nn.Module):
    def __init__(self, image_encoder, text_encoder, hidden_size, projection_dim=256, momentum=0.995):
        super(ImageTextContrastive, self).__init__()
        self.image_encoder = image_encoder  # Shared image encoder
        self.text_encoder = text_encoder  # Shared text encoder (e.g., BERT model)
        self.projection_dim = projection_dim
        self.momentum = momentum

        # Projection layers for image and text features
        self.image_projection = nn.Linear(hidden_size, projection_dim)
        self.text_projection = nn.Linear(hidden_size, projection_dim)

        # Momentum encoders (deep copy of the original encoders)
        self.momentum_image_encoder = copy.deepcopy(self.image_encoder)
        self.momentum_text_encoder = copy.deepcopy(self.text_encoder)
        self.momentum_image_projection = nn.Linear(hidden_size, projection_dim)
        self.momentum_text_projection = nn.Linear(hidden_size, projection_dim)

        # Ensure the momentum encoders are initialized with the same weights as the main encoders
        self._initialize_momentum_encoders()

    def _initialize_momentum_encoders(self):
        # Copy parameters from main encoders to momentum encoders
        for param, momentum_param in zip(self.image_encoder.parameters(), self.momentum_image_encoder.parameters()):
            momentum_param.data.copy_(param.data)
            momentum_param.requires_grad = False

        for param, momentum_param in zip(self.text_encoder.parameters(), self.momentum_text_encoder.parameters()):
            momentum_param.data.copy_(param.data)
            momentum_param.requires_grad = False

    @torch.no_grad()
    def _momentum_update(self):
        # Apply the momentum update to both image and text encoders
        for param, momentum_param in zip(self.image_encoder.parameters(), self.momentum_image_encoder.parameters()):
            momentum_param.data = self.momentum * momentum_param.data + (1 - self.momentum) * param.data

        for param, momentum_param in zip(self.text_encoder.parameters(), self.momentum_text_encoder.parameters()):
            momentum_param.data = self.momentum * momentum_param.data + (1 - self.momentum) * param.data

    def forward(self, images, input_ids, attention_mask):
        # Main encoding and projection
        image_features = self.image_encoder(pixel_values=images).last_hidden_state[:, 0, :]  # CLS token
        image_features = F.normalize(self.image_projection(image_features), dim=-1)

        text_features = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state[:, 0, :]
        text_features = F.normalize(self.text_projection(text_features), dim=-1)

        # Momentum encoding and projection for contrastive learning
        with torch.no_grad():
            self._momentum_update()  # Update momentum encoders

            momentum_image_features = self.momentum_image_encoder(pixel_values=images).last_hidden_state[:, 0, :]
            momentum_image_features = F.normalize(self.momentum_image_projection(momentum_image_features), dim=-1)

            momentum_text_features = self.momentum_text_encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state[:, 0, :]
            momentum_text_features = F.normalize(self.momentum_text_projection(momentum_text_features), dim=-1)

        # Contrastive loss between main and momentum features
        contrastive_loss = self.calculate_contrastive_loss(image_features, text_features, momentum_image_features, momentum_text_features)
        return contrastive_loss

    def calculate_contrastive_loss(self, image_features, text_features, momentum_image_features, momentum_text_features):
        # Compute similarities and contrastive loss
        batch_size = image_features.size(0)
        temperature = 0.07

        # Similarity scores
        sim_image_text = torch.mm(image_features, text_features.t()) / temperature
        sim_image_momentum_text = torch.mm(image_features, momentum_text_features.t()) / temperature
        sim_momentum_image_text = torch.mm(momentum_image_features, text_features.t()) / temperature

        # Labels for contrastive loss
        labels = torch.arange(batch_size).to(image_features.device)

        # Contrastive loss across both directions
        loss = F.cross_entropy(sim_image_text, labels) + F.cross_entropy(sim_image_momentum_text, labels) + F.cross_entropy(sim_momentum_image_text, labels)
        return loss / 3  # Average the loss

In [7]:
image_encoder = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
text_encoder = BertModel.from_pretrained('bert-base-uncased')
itc_model = ImageTextContrastive(image_encoder=image_encoder, text_encoder=text_encoder, hidden_size=image_encoder.config.hidden_size).to(device)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Example of computing contrastive loss for one batch
for images, captions in data_loader:
    images = images.to(device)

    # Tokenize captions
    tokenized = tokenizer(captions, return_tensors="pt", padding=True, truncation=True).to(device)
    input_ids, attention_mask = tokenized['input_ids'], tokenized['attention_mask']

    # Forward pass
    contrastive_loss = itc_model(images, input_ids, attention_mask)
    print(f"Contrastive Loss: {contrastive_loss.item()}")
    break

Contrastive Loss: 2.125472068786621


In [8]:
def train_itc_model(itc_model, data_loader, tokenizer, device, num_epochs=3, learning_rate=1e-4, checkpoint_path="itc_checkpoint.pth"):
    """
    Train the Image-Text Contrastive model with checkpointing.

    Parameters:
    - itc_model: ImageTextContrastive instance
    - data_loader: DataLoader instance with training data
    - tokenizer: BertTokenizer instance
    - device: torch.device, either 'cuda' or 'cpu'
    - num_epochs: int, number of training epochs
    - learning_rate: float, learning rate for optimizer
    - checkpoint_path: str, path to save/load model checkpoint
    """
    
    optimizer = optim.Adam(itc_model.parameters(), lr=learning_rate)
    
    # Load checkpoint if it exists
    start_epoch = 0
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device)
        itc_model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        print(f"Resuming training from epoch {start_epoch + 1}")

    for epoch in range(start_epoch, num_epochs):
        itc_model.train()
        total_loss = 0
        batch_count = len(data_loader)
        
        for i, (images, captions) in enumerate(data_loader):
            images = images.to(device)
            
            # Tokenize captions and move to device
            tokenized = tokenizer(captions, return_tensors="pt", padding=True, truncation=True)
            input_ids = tokenized['input_ids'].to(device)
            attention_mask = tokenized['attention_mask'].to(device)
            
            # Forward pass through the ITC model
            optimizer.zero_grad()
            contrastive_loss = itc_model(images, input_ids, attention_mask)
            
            # Backpropagation
            contrastive_loss.backward()
            optimizer.step()
            
            # Accumulate loss for reporting
            total_loss += contrastive_loss.item()

            # Print progress every 10 batches
            if (i + 1) % 10 == 0:
                print(f"Epoch [{epoch + 1}/{num_epochs}], Batch [{i + 1}/{batch_count}], Loss: {total_loss / (i + 1):.4f}")
        
        # Average loss for the epoch
        avg_loss = total_loss / len(data_loader)
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_loss:.4f}")
        
        # Save checkpoint
        torch.save({
            'epoch': epoch,
            'model_state_dict': itc_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss,
        }, checkpoint_path)
        print(f"Checkpoint saved at epoch {epoch + 1}")

    print("Training complete.")

In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load pretrained models for image and text encoding
image_encoder = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k').to(device)
bert_model = BertModel.from_pretrained('bert-base-uncased').to(device)

# Initialize tokenizer for text processing
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

hidden_size = bert_model.config.hidden_size

# Initialize the ITC model with shared encoders and hidden size
itc_model = ImageTextContrastive(
    image_encoder=image_encoder,
    text_encoder=bert_model,
    hidden_size=hidden_size,
    projection_dim=256
).to(device)

# Prepare data loader
# Assuming `data_loader` is already defined, containing pairs of images and captions

# Define training parameters
num_epochs = 3
learning_rate = 1e-4
checkpoint_path = "itc_checkpoint.pth"

# Train the model
train_itc_model(
    itc_model=itc_model,
    data_loader=data_loader,
    tokenizer=tokenizer,
    device=device,
    num_epochs=num_epochs,
    learning_rate=learning_rate,
    checkpoint_path=checkpoint_path
)

Epoch [1/3], Batch [10/927], Loss: 2.1313
Epoch [1/3], Batch [20/927], Loss: 2.1242
Epoch [1/3], Batch [30/927], Loss: 2.1147
Epoch [1/3], Batch [40/927], Loss: 2.1103
Epoch [1/3], Batch [50/927], Loss: 2.1058
Epoch [1/3], Batch [60/927], Loss: 2.1028
Epoch [1/3], Batch [70/927], Loss: 2.0999
Epoch [1/3], Batch [80/927], Loss: 2.0999
Epoch [1/3], Batch [90/927], Loss: 2.0990
Epoch [1/3], Batch [100/927], Loss: 2.0978
Epoch [1/3], Batch [110/927], Loss: 2.0971
Epoch [1/3], Batch [120/927], Loss: 2.0963
Epoch [1/3], Batch [130/927], Loss: 2.0957
Epoch [1/3], Batch [140/927], Loss: 2.0955
Epoch [1/3], Batch [150/927], Loss: 2.0951
Epoch [1/3], Batch [160/927], Loss: 2.0945
Epoch [1/3], Batch [170/927], Loss: 2.0940
Epoch [1/3], Batch [180/927], Loss: 2.0935
Epoch [1/3], Batch [190/927], Loss: 2.0930
Epoch [1/3], Batch [200/927], Loss: 2.0925
Epoch [1/3], Batch [210/927], Loss: 2.0921
Epoch [1/3], Batch [220/927], Loss: 2.0916
Epoch [1/3], Batch [230/927], Loss: 2.0913
Epoch [1/3], Batch [

KeyboardInterrupt: 

In [10]:
import torch
import torch.nn.functional as F

def test_itc_model(itc_model, data_loader, tokenizer, device):
    """
    Test the trained Image-Text Contrastive model on a batch.
    
    Parameters:
    - itc_model: ImageTextContrastive instance (trained model)
    - data_loader: DataLoader instance containing test data
    - tokenizer: BertTokenizer instance for text processing
    - device: torch.device, either 'cuda' or 'cpu'
    """
    itc_model.eval()
    
    with torch.no_grad():
        # Get a single batch of images and captions
        images, captions = next(iter(data_loader))
        
        # Move images to the appropriate device
        images = images.to(device)
        
        # Tokenize captions and move to device
        tokenized = tokenizer(captions, return_tensors="pt", padding=True, truncation=True).to(device)
        input_ids, attention_mask = tokenized['input_ids'], tokenized['attention_mask']
        
        # Compute contrastive loss on the batch
        contrastive_loss = itc_model(images, input_ids, attention_mask)
        print(f"Contrastive Loss on Test Batch: {contrastive_loss.item()}")
        
        # Extract image and text features for analysis
        image_features = itc_model.image_projection(
            F.normalize(itc_model.image_encoder(images).last_hidden_state[:, 0, :], dim=-1)
        )
        text_features = itc_model.text_projection(
            F.normalize(itc_model.extract_text_features(input_ids, attention_mask), dim=-1)
        )
        
        # Print shapes and example features for verification
        print(f"Image Features Shape: {image_features.shape}")   # Expected shape: (batch_size, projection_dim)
        print(f"Text Features Shape: {text_features.shape}")     # Expected shape: (batch_size, projection_dim)
        
        # Display the first image and text feature vectors to understand alignment
        print("First Image Feature Vector:", image_features[0].cpu().numpy())
        print("First Text Feature Vector:", text_features[0].cpu().numpy())

# Call the test function
test_itc_model(itc_model, data_loader, tokenizer, device)


Contrastive Loss on Test Batch: 2.079648733139038


AttributeError: 'ImageTextContrastive' object has no attribute 'extract_text_features'