In [None]:
# Import libraries
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision import transforms
from torchvision.datasets import CocoCaptions
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.transforms.functional import to_pil_image
from tqdm.notebook import tqdm
from transformers import BertTokenizer, BertModel

In [None]:
# Data loading and preprocessing (MSCOCO captions)

# WordPiece tokenizer; uses the same checkpoint name as the BERT text encoder loaded later.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Image pipeline: fixed 224x224 resize -> float tensor -> per-channel normalization.
# These are presumably the ImageNet channel statistics expected by the pretrained ResNet-50.
_imagenet_mean = [0.485, 0.456, 0.406]
_imagenet_std = [0.229, 0.224, 0.225]
image_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_imagenet_mean, std=_imagenet_std),
    ]
)

# Custom Dataset Class
class ImageTextDataset(torch.utils.data.Dataset):
    """Pair each COCO image with one tokenized caption.

    Args:
        root: Directory containing the COCO image files.
        annFile: Path to the COCO captions annotation JSON.
        transform: Transform applied to each image by ``CocoCaptions``.
        tokenizer: Hugging Face-style tokenizer called with ``return_tensors='pt'``.
        caption_index: Which of an image's captions to use. It wraps around for
            images with fewer captions. Defaults to the first caption.
        max_length: Length every caption is padded/truncated to, in tokens.

    Each item is ``(image, input_ids, attention_mask)``. The two text tensors
    keep the tokenizer's leading batch dimension of 1; the training and
    evaluation loops remove it with ``squeeze(1)`` after batching.
    """

    def __init__(self, root, annFile, transform, tokenizer, caption_index=0, max_length=32):
        self.dataset = CocoCaptions(root=root, annFile=annFile, transform=transform)
        self.tokenizer = tokenizer
        self.caption_index = caption_index
        self.max_length = max_length

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image, captions = self.dataset[idx]
        if not captions:
            # Fail loudly instead of with a bare IndexError (or training on an empty caption).
            raise ValueError(f"Image at index {idx} has no captions")
        text = captions[self.caption_index % len(captions)]
        text_encoded = self.tokenizer(
            text,
            return_tensors='pt',
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
        )
        return image, text_encoded['input_ids'], text_encoded['attention_mask']

# Example usage
# The dataset location is configurable via the COCO_DATA_DIR environment variable;
# the default is the original author's local checkout.
# NOTE(review): val2017 is used both for training and for the Recall@K cell below,
# so that metric is not a held-out score.
DATA_DIR = Path(os.environ.get('COCO_DATA_DIR', 'C:/Users/ADMIN/Downloads/TextBasedImageQuery/dataset'))
SEED = 42
BATCH_SIZE = 16

dataset = ImageTextDataset(
    root=str(DATA_DIR / 'images' / 'val2017'),
    annFile=str(DATA_DIR / 'annotations' / 'captions_val2017.json'),
    transform=image_transform,
    tokenizer=tokenizer,
)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,  # Drop incomplete batches
    generator=torch.Generator().manual_seed(SEED),  # reproducible shuffle order
)

In [None]:
# Model Definition
class TextImageModel(torch.nn.Module):
    """Dual encoder that projects text and images into a shared embedding space.

    Args:
        text_encoder: Module called as ``text_encoder(input_ids=..., attention_mask=...)``.
            Its output must expose ``pooler_output`` with ``text_dim`` features per row.
        image_encoder: Module mapping a batch of images to ``image_dim`` features
            per image. Trailing singleton dims such as ``(B, image_dim, 1, 1)`` are fine.
        text_dim: Width of the text encoder output (default matches the original 768).
        image_dim: Width of the flattened image encoder output (default 2048).
        embed_dim: Size of the shared embedding space (default 512).
    """

    def __init__(self, text_encoder, image_encoder, text_dim=768, image_dim=2048, embed_dim=512):
        super().__init__()
        self.text_encoder = text_encoder
        self.image_encoder = image_encoder
        self.text_projection = torch.nn.Linear(text_dim, embed_dim)
        self.image_projection = torch.nn.Linear(image_dim, embed_dim)

    def forward(self, input_ids, attention_mask, images):
        """Return ``(text_features, image_features)``, each of shape ``(B, embed_dim)``."""
        text_features = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask).pooler_output
        text_features = self.text_projection(text_features)

        # flatten(1) rather than squeeze(): squeeze() would also remove the batch
        # dimension when B == 1, producing a 1-D tensor.
        image_features = torch.flatten(self.image_encoder(images), start_dim=1)
        image_features = self.image_projection(image_features)

        return text_features, image_features

# Load pretrained encoders: BERT for captions, ResNet-50 (default pretrained weights) for images.
text_encoder = BertModel.from_pretrained('bert-base-uncased')

# Keep every ResNet child module except the last one (the classification head), so
# the encoder emits backbone features rather than class scores.
image_encoder = torch.nn.Sequential(
    *list(resnet50(weights=ResNet50_Weights.DEFAULT).children())[:-1]
)

# Combine both encoders with their projection heads
model = TextImageModel(text_encoder, image_encoder)

In [None]:
# Training Loop
# CosineEmbeddingLoss with all-ones targets only pulls matched pairs together and
# never pushes mismatched pairs apart, so mapping every input to the same vector
# minimizes it perfectly. A symmetric in-batch contrastive (InfoNCE) loss instead
# uses the other captions/images of the batch as negatives, which retrieval needs.
NUM_EPOCHS = 5
LEARNING_RATE = 1e-4
TEMPERATURE = 0.07  # softmax temperature for the contrastive logits

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)


def contrastive_loss(text_features, image_features, temperature=TEMPERATURE):
    """Symmetric InfoNCE loss where row i of each tensor is the matched pair."""
    text_features = torch.nn.functional.normalize(text_features, dim=-1)
    image_features = torch.nn.functional.normalize(image_features, dim=-1)
    logits = text_features @ image_features.T / temperature
    targets = torch.arange(logits.size(0), device=logits.device)
    text_to_image = torch.nn.functional.cross_entropy(logits, targets)
    image_to_text = torch.nn.functional.cross_entropy(logits.T, targets)
    return (text_to_image + image_to_text) / 2


num_batches = max(len(dataloader), 1)  # avoid division by zero on an empty loader
for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0.0
    for images, input_ids, attention_mask in tqdm(dataloader, desc=f"epoch {epoch + 1}"):
        images = images.to(device)
        # The dataset returns text tensors as (1, L); after batching they are (B, 1, L).
        input_ids = input_ids.squeeze(1).to(device)
        attention_mask = attention_mask.squeeze(1).to(device)
        text_features, image_features = model(input_ids, attention_mask, images)

        loss = contrastive_loss(text_features, image_features)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {total_loss/num_batches}")

In [None]:
# Evaluation

# Function to compute Recall@K
def recall_at_k(model, dataloader, k=5, device=None):
    """In-batch text-to-image Recall@K.

    Within each batch, every caption is ranked against the images of that same
    batch by cosine similarity. A hit is counted when the caption's own image
    (same row index) is among the top ``k``. The candidate pool is the batch, not
    the whole dataset, so the value depends on batch size and, with a shuffling
    loader, on batch composition.

    Args:
        model: Module returning ``(text_features, image_features)`` when called
            with ``(input_ids, attention_mask, images)``.
        dataloader: Yields ``(images, input_ids, attention_mask)``. The text
            tensors are ``(B, 1, L)``, as produced by ``ImageTextDataset``.
        k: Number of top-ranked images considered; clamped to the batch size.
        device: Device to run on; defaults to the device of the model's parameters.

    Returns:
        Fraction of captions whose paired image is in the top ``k``. Returns 0.0
        for an empty loader.
    """
    model.eval()
    if device is None:
        device = next(model.parameters()).device
    correct = 0
    total = 0
    with torch.no_grad():
        for images, input_ids, attention_mask in dataloader:
            images = images.to(device)
            input_ids = input_ids.squeeze(1).to(device)
            attention_mask = attention_mask.squeeze(1).to(device)
            text_features, image_features = model(input_ids, attention_mask, images)
            # Rank by cosine similarity, matching the cosine-based training objective.
            text_features = torch.nn.functional.normalize(text_features, dim=-1)
            image_features = torch.nn.functional.normalize(image_features, dim=-1)
            similarities = text_features @ image_features.T
            batch_k = min(k, similarities.size(1))  # topk fails if k > number of images
            top_indices = similarities.topk(batch_k, dim=1).indices
            own_index = torch.arange(similarities.size(0), device=similarities.device).unsqueeze(1)
            correct += (top_indices == own_index).any(dim=1).sum().item()
            total += similarities.size(0)
    return correct / total if total else 0.0

# Compute metrics: in-batch Recall@5 over the same loader that was used for training
r5 = recall_at_k(model=model, dataloader=dataloader, k=5)
print("Recall@5: {:.4f}".format(r5))

In [None]:
# Save model
MODEL_PATH = 'text_image_model.pt'
torch.save(model.state_dict(), MODEL_PATH)

# Load model
# map_location='cpu' lets a checkpoint written on a GPU machine load on any machine;
# load_state_dict then copies the tensors onto whatever device `model` is on.
# NOTE: torch.load unpickles the file, so only load checkpoints from trusted sources.
model.load_state_dict(torch.load(MODEL_PATH, map_location='cpu'))

In [None]:
# Inference

def batched_inference(text_query, model, dataloader, device, k=5):
    """Retrieve the ``k`` images from ``dataloader`` that best match one text query.

    Args:
        text_query: Tokenizer output for a single query string: a dict of tensors
            already on ``device``, passed to ``model.text_encoder`` as keyword
            arguments. Only the first query row is used.
        model: Module exposing ``text_encoder``, ``text_projection``,
            ``image_encoder`` and ``image_projection``.
        dataloader: Yields ``(images, _, _)`` batches; only the images are used.
        device: Device the images are moved to before encoding.
        k: Number of results; capped at the number of images seen.

    Returns:
        List of ``(image_tensor, score)`` tuples, best first. ``image_tensor`` is
        the image exactly as the loader produced it, moved to CPU. ``score`` is the
        cosine similarity to the query, as a Python float. Returns an empty list
        for an empty loader.
    """
    model.eval()
    best_scores = None  # running top-k scores (1-D, CPU)
    best_images = None  # images aligned row-for-row with best_scores (CPU)

    with torch.no_grad():
        # The query is identical for every batch, so encode it once up front.
        text_features = model.text_encoder(**text_query).pooler_output
        text_features = model.text_projection(text_features)
        query_vector = torch.nn.functional.normalize(text_features, dim=-1)[0]

        for images, _, _ in dataloader:
            images = images.to(device)
            # flatten(1) keeps the batch dimension even for a single-image batch.
            image_features = torch.flatten(model.image_encoder(images), start_dim=1)
            image_features = model.image_projection(image_features)
            image_features = torch.nn.functional.normalize(image_features, dim=-1)

            candidate_scores = (image_features @ query_vector).cpu()
            candidate_images = images.cpu()
            if best_scores is not None:
                candidate_scores = torch.cat([best_scores, candidate_scores])
                candidate_images = torch.cat([best_images, candidate_images])

            # Keep only a running top-k, so memory stays proportional to k rather
            # than to the whole dataset, and indices stay global across batches.
            best_scores, order = candidate_scores.topk(min(k, candidate_scores.numel()))
            best_images = candidate_images[order]

    if best_scores is None:
        return []
    return list(zip(best_images, best_scores.tolist()))


# Initialize tokenizer and tokenize the query
# Run on whichever device the model currently lives on (`device` was previously undefined here).
device = next(model.parameters()).device
query_text = "A dog playing with a ball"
text_query = tokenizer(query_text, return_tensors="pt", padding="max_length", truncation=True, max_length=32)
text_query = {k: v.to(device) for k, v in text_query.items()}

# Search every image exactly once in a fixed order. The training loader shuffles
# and drops the last incomplete batch, which would silently skip images.
search_loader = torch.utils.data.DataLoader(dataset, batch_size=dataloader.batch_size, shuffle=False)

# Perform batched inference to get top-k results
top_k_images = batched_inference(text_query, model, search_loader, device, k=5)

# Undo the Normalize step of `image_transform` so the images display with correct colours.
normalize = next(t for t in image_transform.transforms if isinstance(t, transforms.Normalize))
norm_mean = torch.tensor(normalize.mean).view(-1, 1, 1)
norm_std = torch.tensor(normalize.std).view(-1, 1, 1)

# Print similarity scores and visualize images inline
if top_k_images:
    fig, axes = plt.subplots(1, len(top_k_images), figsize=(4 * len(top_k_images), 4), squeeze=False)
    for idx, ((image_tensor, score), ax) in enumerate(zip(top_k_images, axes[0])):
        print(f"Rank {idx+1}, Similarity Score: {score:.4f}")
        image = to_pil_image((image_tensor * norm_std + norm_mean).clamp(0, 1))
        ax.imshow(image)
        ax.set_title(f"Rank {idx + 1} ({score:.3f})")
        ax.axis("off")
    fig.suptitle(query_text)
    plt.show()
else:
    print("No images found.")