In [None]:
# !pip install einops
# !pip install tf-models-official

In [None]:
import os
import pandas as pd
import numpy as np
from PIL import Image
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from transformers import BertTokenizer, BertModel
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt

In [None]:
# Select the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Root folder that holds the Flickr8k split files and the images directory.
# Defaults to the Colab Google Drive location; set the FLICKR8K_DATA_PATH
# environment variable to run the notebook elsewhere without editing this cell.
DATA_PATH = os.environ.get('FLICKR8K_DATA_PATH', '/content/drive/MyDrive/data')

In [None]:
# Data Loading Class
class Flickr8kDataset(Dataset):
    """Image/caption pairs with a binary match label, read from a tab-separated file.

    ``data_file`` is parsed as three header-less, tab-separated columns named
    ``image``, ``text`` and ``label``. A label equal to the string ``'match'``
    becomes ``1.0``; every other value becomes ``0.0``.

    Args:
        data_file: Path (or file-like object) passed straight to ``pandas.read_csv``.
        images_path: Directory that each ``image`` entry is joined onto.
        tokenizer: Callable invoked as ``tokenizer(text, return_tensors='pt',
            padding='max_length', truncation=True, max_length=...)``; its result
            must expose ``input_ids`` and ``attention_mask``.
        transform: Optional callable applied to the RGB PIL image.
        max_length: Length every caption is padded/truncated to. The default of
            512 preserves the previous hard-coded value; because padding is
            always to this length, a smaller value shrinks every text batch.
    """

    def __init__(self, data_file, images_path, tokenizer, transform=None, max_length=512):
        self.data_frame = pd.read_csv(data_file, sep="\t", names=["image", "text", "label"])
        self.images_path = images_path
        self.tokenizer = tokenizer
        self.transform = transform
        self.max_length = max_length

    def __len__(self):
        """Return the number of rows (image/caption pairs) in the split file."""
        return len(self.data_frame)

    def __getitem__(self, idx):
        """Return ``(image, input_ids, attention_mask, label)`` for row ``idx``.

        ``input_ids`` and ``attention_mask`` have their leading dimension of
        size one squeezed away; ``label`` is a 0-d float tensor.
        """
        row = self.data_frame.iloc[idx]
        image_path = os.path.join(self.images_path, row['image'])
        # Open inside a context manager so the file handle is released right
        # after the pixel data has been converted, instead of whenever the
        # garbage collector gets to it (matters with many DataLoader workers).
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        if self.transform:
            image = self.transform(image)

        encoded = self.tokenizer(
            row['text'],
            return_tensors='pt',
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
        )
        label = torch.tensor(1 if row['label'] == 'match' else 0, dtype=torch.float)

        return image, encoded.input_ids.squeeze(0), encoded.attention_mask.squeeze(0), label

In [None]:
# Define transformations
# Per-channel statistics used for normalisation (the values commonly quoted
# for ImageNet-pretrained torchvision models).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Image preprocessing: resize to 224x224, convert to a tensor, then normalise.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Tokenizer matching the 'bert-base-uncased' checkpoint.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [None]:
# Model Class
class ImageTextMatchingModel(nn.Module):
    """Scores an image/caption pair using BERT text features and ResNet-50 image features.

    The text branch uses the ``pooler_output`` of ``bert-base-uncased``. The
    image branch is a ResNet-50 whose final ``fc`` layer is replaced by a
    512-dimensional linear projection. Both vectors are concatenated and passed
    through a small MLP that ends in a sigmoid, so ``forward`` returns values in
    (0, 1) with a trailing dimension of size 1 -- probabilities, not raw logits.
    """

    def __init__(self):
        super().__init__()
        self.text_encoder = BertModel.from_pretrained('bert-base-uncased')
        self.vision_encoder = models.resnet50(pretrained=True)
        # Swap ResNet-50's final fully connected layer for a 512-d projection.
        resnet_feature_dim = self.vision_encoder.fc.in_features
        self.vision_encoder.fc = nn.Linear(resnet_feature_dim, 512)

        # Fusion head over the concatenated [image (512) | text (hidden_size)] vector.
        text_feature_dim = self.text_encoder.config.hidden_size
        fusion_layers = [
            nn.Linear(512 + text_feature_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        ]
        self.classifier = nn.Sequential(*fusion_layers)

    def forward(self, images, input_ids, attention_mask):
        """Return the match probability for each image/caption pair in the batch."""
        encoded_text = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
        image_embedding = self.vision_encoder(images)
        # Image features come first, then text, matching the classifier's input layout.
        fused = torch.cat([image_embedding, encoded_text.pooler_output], dim=1)
        match_probability = self.classifier(fused)
        return match_probability

In [None]:
# Locations of the image folder and the three split files under DATA_PATH.
IMAGES_PATH = f"{DATA_PATH}/images"
train_data_file = f"{DATA_PATH}/flickr8k.TrainImages.txt"
dev_data_file = f"{DATA_PATH}/flickr8k.DevImages.txt"
test_data_file = f"{DATA_PATH}/flickr8k.TestImages.txt"

# One dataset per split; all three share the same tokenizer and image transform.
train_dataset, val_dataset, test_dataset = (
    Flickr8kDataset(split_file, IMAGES_PATH, tokenizer, transform)
    for split_file in (train_data_file, dev_data_file, test_data_file)
)

# Only the training loader shuffles; validation and test keep file order.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=16, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=16, num_workers=2)

# Model on the selected device, binary cross-entropy on the model's sigmoid
# outputs, and AdamW over every parameter (both encoders are fine-tuned).
model = ImageTextMatchingModel()
model = model.to(device)
criterion = nn.BCELoss()
optimizer = optim.AdamW(model.parameters(), lr=5e-5)

In [None]:
def train_model(model, train_loader, val_loader, criterion, optimizer, epochs=10):
    """Train ``model`` for ``epochs`` passes over ``train_loader``, evaluating after each.

    Args:
        model: Module called as ``model(images, input_ids, attention_mask)``.
        train_loader: Iterable with ``len()`` that yields
            ``(images, input_ids, attention_mask, labels)`` batches.
        val_loader: Loader handed to ``evaluate_model`` at the end of every epoch.
        criterion: Loss called as ``criterion(outputs, labels)`` with both
            arguments flattened to one dimension.
        optimizer: Optimizer over ``model``'s parameters.
        epochs: Number of passes over ``train_loader``.
    """
    for epoch in range(epochs):
        # evaluate_model() switches the model to eval mode, so training mode has
        # to be re-enabled at the start of *every* epoch. Calling train() only
        # once before the loop left dropout/batch-norm in eval mode from the
        # second epoch onwards.
        model.train()
        running_loss = 0.0
        progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f'Epoch {epoch+1}/{epochs}', leave=True)
        for i, (images, input_ids, attention_mask, labels) in progress_bar:
            images = images.to(device)
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            # reshape(-1) rather than squeeze(): squeeze() collapses a final
            # batch of size one to a 0-d tensor, which no longer matches the
            # (1,)-shaped labels and makes the loss raise.
            outputs = model(images, input_ids, attention_mask).reshape(-1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            # Show the mean loss over the batches seen so far in this epoch.
            progress_bar.set_postfix({'loss': running_loss/(i+1)})

        evaluate_model(model, val_loader)

def evaluate_model(model, loader, split_name='Validation'):
    """Compute and print the accuracy of ``model`` over ``loader``.

    Outputs are thresholded at 0.5 to obtain binary predictions. The model is
    put into eval mode and is left in eval mode when the function returns.

    Args:
        model: Module called as ``model(images, input_ids, attention_mask)``.
        loader: Iterable of ``(images, input_ids, attention_mask, labels)`` batches.
        split_name: Label used in the printed message. The default keeps the
            previous output text (``'Validation Accuracy: ...'``).

    Returns:
        The accuracy computed by ``accuracy_score`` over all batches.
    """
    model.eval()
    predictions, truths = [], []
    with torch.no_grad():
        progress_bar = tqdm(loader, desc='Evaluating', leave=False)
        for images, input_ids, attention_mask, labels in progress_bar:
            images = images.to(device)
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            labels = labels.to(device)
            # reshape(-1) rather than squeeze(): with a final batch of size one,
            # squeeze() yields a 0-d tensor whose NumPy array cannot be iterated
            # by list.extend().
            outputs = model(images, input_ids, attention_mask).reshape(-1)
            predicted = (outputs > 0.5).float()
            predictions.extend(predicted.cpu().numpy())
            truths.extend(labels.reshape(-1).cpu().numpy())

    accuracy = accuracy_score(truths, predictions)
    print(f'{split_name} Accuracy: {accuracy}')
    return accuracy


In [None]:
# Train for ten epochs (accuracy on the validation loader is printed after each
# epoch), then score the held-out test loader. Note that evaluate_model labels
# its printed metric "Validation Accuracy" by default, including for this test run.
train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    epochs=10,
)
evaluate_model(model=model, loader=test_loader)

In [None]:
def plot_training_history(history):
    """Plot accuracy and loss curves side by side for training and validation.

    Args:
        history: Object whose ``.history`` attribute is a mapping with the keys
            ``'binary_accuracy'``, ``'val_binary_accuracy'``, ``'loss'`` and
            ``'val_loss'``, each holding one value per epoch.
    """
    # Read every series up front so a missing key fails before any figure exists.
    metrics = history.history
    train_acc = metrics['binary_accuracy']
    val_acc = metrics['val_binary_accuracy']
    train_loss = metrics['loss']
    val_loss = metrics['val_loss']
    epochs = range(1, len(train_acc) + 1)

    # (training series, training label, validation series, validation label, title, y-axis label)
    panels = (
        (train_acc, 'Training accuracy', val_acc, 'Validation accuracy',
         'Training and Validation Accuracy', 'Accuracy'),
        (train_loss, 'Training loss', val_loss, 'Validation loss',
         'Training and Validation Loss', 'Loss'),
    )

    plt.figure(figsize=(12, 6))
    for position, (train_series, train_label, val_series, val_label, title, y_label) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(epochs, train_series, 'b-', label=train_label)
        plt.plot(epochs, val_series, 'r-', label=val_label)
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.legend()

    plt.show()

In [None]:
# `itm` is not defined anywhere in this notebook, so calling
# plot_training_history(itm.history) unconditionally raises NameError on a fresh
# kernel. Only plot when such an object actually exists.
# NOTE(review): plot_training_history reads `.history[...]` from its argument, so
# `itm.history` is presumably a Keras-style History object -- confirm.
training_run = globals().get('itm')
if training_run is None:
    print("No training history object named 'itm' is defined; skipping the history plot.")
else:
    plot_training_history(training_run.history)