# In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import LxmertTokenizer, LxmertForPreTraining, AdamW
from torchvision import transforms
from PIL import Image
import json
import os
from tqdm import tqdm
import numpy as np
from sklearn.metrics import accuracy_score, f1_score

# NLVR2 Dataset
class NLVR2Dataset(Dataset):
    """Map-style dataset yielding one sentence, a left/right image pair and a label.

    Each record must provide the keys ``'sentence'``, ``'label'`` and
    ``'identifier'``. Images are looked up in ``img_dir`` as
    ``<identifier>-img0.png`` (left) and ``<identifier>-img1.png`` (right).

    Args:
        json_file: Path to the annotation file. Either a single JSON array of
            records or one JSON object per line is accepted.
        img_dir: Directory containing the image files.
        tokenizer: Callable invoked as
            ``tokenizer(text, padding='max_length', truncation=True,
            return_tensors='pt')`` whose result is indexed with
            ``'input_ids'`` and ``'attention_mask'``.
        transform: Optional callable applied to each RGB ``PIL.Image``. When
            omitted, the images are returned as ``PIL.Image`` objects.
    """

    def __init__(self, json_file, img_dir, tokenizer, transform=None):
        with open(json_file, 'r', encoding='utf-8') as f:
            raw_text = f.read()
        self.data = self._parse_records(raw_text)
        self.img_dir = img_dir
        self.tokenizer = tokenizer
        self.transform = transform

    @staticmethod
    def _parse_records(raw_text):
        """Parse annotation text given as a JSON array or as JSON Lines."""
        try:
            records = json.loads(raw_text)
        except json.JSONDecodeError:
            # Not a single JSON document: fall back to one object per line.
            # NOTE(review): the public NLVR2 files appear to use this layout
            # despite the ``.json`` extension -- confirm against your copy.
            return [json.loads(line) for line in raw_text.splitlines() if line.strip()]
        if isinstance(records, dict):
            # A one-line JSON Lines file parses as a single object.
            records = [records]
        return records

    @staticmethod
    def _parse_label(raw_label):
        """Convert a raw label to ``0``/``1``.

        Accepts everything ``int()`` accepts plus the strings ``"True"`` and
        ``"False"`` (case-insensitive), which ``int()`` would reject.
        """
        if isinstance(raw_label, str):
            lowered = raw_label.strip().lower()
            if lowered in ('true', 'false'):
                return int(lowered == 'true')
        return int(raw_label)

    def _image_path(self, identifier, side):
        """Return the path of image ``side`` (0 = left, 1 = right) for ``identifier``.

        ``<identifier>-img<side>.png`` is preferred. If that file does not
        exist, the identifier without its last ``-<suffix>`` component is
        tried. NOTE(review): NLVR2 identifiers appear to end in a sentence
        index that the image file names omit -- confirm against your data.
        When neither file exists the preferred path is returned, so the
        resulting error names the expected location.
        """
        suffix = f'-img{side}.png'
        preferred = os.path.join(self.img_dir, identifier + suffix)
        if os.path.exists(preferred):
            return preferred
        pair_id, sep, _ = identifier.rpartition('-')
        if sep:
            fallback = os.path.join(self.img_dir, pair_id + suffix)
            if os.path.exists(fallback):
                return fallback
        return preferred

    @staticmethod
    def _load_rgb(path):
        """Open ``path``, return an RGB copy and close the underlying file."""
        with Image.open(path) as img:
            # convert() loads the pixel data, so the result outlives the file handle.
            return img.convert('RGB')

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return a dict with ``input_ids``, ``attention_mask``, ``img_left``,
        ``img_right`` and ``label`` (a scalar ``torch.long`` tensor)."""
        item = self.data[idx]
        text = item['sentence']
        label = self._parse_label(item['label'])
        img_left = self._load_rgb(self._image_path(item['identifier'], 0))
        img_right = self._load_rgb(self._image_path(item['identifier'], 1))

        if self.transform:
            img_left = self.transform(img_left)
            img_right = self.transform(img_right)

        inputs = self.tokenizer(text, padding='max_length', truncation=True, return_tensors='pt')

        return {
            # squeeze(0) removes only the batch axis added by return_tensors='pt';
            # a bare squeeze() would also collapse any other size-1 axis.
            'input_ids': inputs['input_ids'].squeeze(0),
            'attention_mask': inputs['attention_mask'].squeeze(0),
            'img_left': img_left,
            'img_right': img_right,
            'label': torch.tensor(label, dtype=torch.long)
        }

# Model definition
class LXMERTFORNLVR2(torch.nn.Module):
    """Pretrained LXMERT encoder with a linear classification head.

    The checkpoint is loaded through ``LxmertForPreTraining``; only its
    wrapped base encoder (``self.lxmert.lxmert``) is used in ``forward``,
    because the pre-training wrapper's output carries pre-training heads and
    no ``pooled_output`` field.

    ``forward`` accepts visual input in two layouts:

    * ready-made region features ``(batch, num_feats, visual_feat_dim)``
      with boxes ``(batch, num_feats, 4)``, passed to the encoder unchanged;
    * a 4-D batch of channel-stacked RGB images ``(batch, 3 * k, H, W)``,
      which is cut into ``patch_size`` x ``patch_size`` patches, each
      linearly projected to ``visual_feat_dim``.

    NOTE(review): the released LXMERT weights were presumably trained on
    detector region features; the patch projection is a freshly initialised
    layer learned during fine-tuning -- confirm this trade-off is acceptable.

    Args:
        num_labels: Number of output classes of the classifier head.
        patch_size: Side length in pixels of the square patches used when raw
            images are supplied.
    """

    def __init__(self, num_labels=2, patch_size=16):
        super().__init__()
        self.lxmert = LxmertForPreTraining.from_pretrained("unc-nlp/lxmert-base-uncased")
        config = self.lxmert.config
        self.classifier = torch.nn.Linear(config.hidden_size, num_labels)
        self.patch_size = patch_size
        # Maps a flattened RGB patch onto the feature width the encoder expects.
        self.patch_proj = torch.nn.Linear(3 * patch_size * patch_size, config.visual_feat_dim)

    @staticmethod
    def _patchify(images, patch_size):
        """Cut ``(B, 3*k, H, W)`` images into flattened patches.

        Returns a ``(B, k * (H // p) * (W // p), 3 * p * p)`` tensor ordered
        by image, then patch row, then patch column.

        Raises:
            ValueError: If the channel count is not a multiple of 3 or the
                spatial size is not a multiple of ``patch_size``.
        """
        batch, channels, height, width = images.shape
        if channels % 3 or height % patch_size or width % patch_size:
            raise ValueError(
                f"expected (B, 3*k, H, W) with H and W divisible by {patch_size}, "
                f"got {tuple(images.shape)}"
            )
        num_images = channels // 3
        grid_h, grid_w = height // patch_size, width // patch_size
        tiles = images.reshape(batch, num_images, 3, grid_h, patch_size, grid_w, patch_size)
        # -> (B, image, patch_row, patch_col, channel, y, x)
        tiles = tiles.permute(0, 1, 3, 5, 2, 4, 6)
        return tiles.reshape(batch, num_images * grid_h * grid_w, 3 * patch_size * patch_size)

    @staticmethod
    def _patch_boxes(num_images, grid_h, grid_w, device=None, dtype=torch.float32):
        """Return normalised ``(x1, y1, x2, y2)`` boxes, one per patch.

        The images are laid out side by side on a unit-square canvas so that
        patches from different images get distinct x ranges. The row order
        matches ``_patchify``. Output shape: ``(num_images * grid_h * grid_w, 4)``.
        """
        image_idx = torch.arange(num_images, device=device, dtype=dtype).view(-1, 1, 1)
        rows = torch.arange(grid_h, device=device, dtype=dtype).view(1, -1, 1)
        cols = torch.arange(grid_w, device=device, dtype=dtype).view(1, 1, -1)
        shape = (num_images, grid_h, grid_w)
        x1 = ((image_idx + cols / grid_w) / num_images).expand(shape)
        y1 = (rows / grid_h).expand(shape)
        x2 = x1 + 1.0 / (grid_w * num_images)
        y2 = y1 + 1.0 / grid_h
        return torch.stack([x1, y1, x2, y2], dim=-1).reshape(-1, 4)

    def forward(self, input_ids, attention_mask, visual_feats, visual_pos):
        """Return classification logits of shape ``(batch, num_labels)``.

        When ``visual_feats`` is 4-D it is treated as channel-stacked raw
        images: patch features and patch boxes are derived from it and the
        supplied ``visual_pos`` is ignored. Otherwise both visual tensors are
        forwarded to the encoder as given.
        """
        if visual_feats.dim() == 4:
            patches = self._patchify(visual_feats, self.patch_size)
            boxes = self._patch_boxes(
                visual_feats.size(1) // 3,
                visual_feats.size(2) // self.patch_size,
                visual_feats.size(3) // self.patch_size,
                device=patches.device,
                dtype=patches.dtype,
            )
            visual_pos = boxes.unsqueeze(0).expand(patches.size(0), -1, -1)
            visual_feats = self.patch_proj(patches)

        # Call the base encoder: its output exposes `pooled_output`, which the
        # pre-training wrapper's output does not.
        outputs = self.lxmert.lxmert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            visual_feats=visual_feats,
            visual_pos=visual_pos,
        )
        pooled_output = outputs.pooled_output
        logits = self.classifier(pooled_output)
        return logits

# Training function
def train(model, train_loader, optimizer, device):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: Module called as ``model(input_ids, attention_mask,
            visual_feats, visual_pos)`` and returning class logits.
        train_loader: Iterable of batches; each batch is a mapping with the
            keys ``'input_ids'``, ``'attention_mask'``, ``'img_left'``,
            ``'img_right'`` and ``'label'`` holding tensors.
        optimizer: Optimizer stepped once per batch.
        device: Device the batch tensors are moved to.

    Returns:
        The sum of per-batch cross-entropy losses divided by the number of
        batches seen, or ``0.0`` when the loader yields no batches.
    """
    model.train()
    total_loss = 0.0
    # Count batches while iterating instead of calling len(train_loader): this
    # avoids ZeroDivisionError on an empty loader and works for loaders that
    # do not define __len__.
    num_batches = 0
    for batch in tqdm(train_loader, desc="Training"):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        img_left = batch['img_left'].to(device)
        img_right = batch['img_right'].to(device)
        labels = batch['label'].to(device)

        # Stack the two images along dim 1 and build one row of indices
        # 0..391 per sample, allocated directly on the target device.
        # NOTE(review): LXMERT documents visual_feats as
        # (batch, num_feats, feat_dim) and visual_pos as (batch, num_feats, 4);
        # the tensors built here do not have that layout, so the model must
        # adapt them itself -- confirm. 196 presumably means 14 x 14 patches
        # of a 224 x 224 image.
        visual_feats = torch.cat([img_left, img_right], dim=1)
        visual_pos = torch.arange(2 * 196, device=device).repeat(input_ids.size(0), 1)

        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask, visual_feats, visual_pos)
        loss = torch.nn.functional.cross_entropy(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        return 0.0
    return total_loss / num_batches

# Evaluation function
def evaluate(model, val_loader, device):
    """Score ``model`` on ``val_loader`` and return ``(accuracy, f1)``.

    The model is switched to eval mode and run without gradient tracking.
    Each batch must be a mapping with the tensor entries ``'input_ids'``,
    ``'attention_mask'``, ``'img_left'``, ``'img_right'`` and ``'label'``.
    The predicted class is the argmax over dim 1 of the model output; F1 is
    computed with ``average='binary'``.
    """
    tensor_keys = ('input_ids', 'attention_mask', 'img_left', 'img_right', 'label')
    predictions, targets = [], []

    model.eval()
    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Evaluating"):
            # Move every tensor of the batch onto the evaluation device.
            moved = {key: batch[key].to(device) for key in tensor_keys}
            n_samples = moved['input_ids'].size(0)

            # Same visual inputs as in training: both images stacked along
            # dim 1, plus one row of indices 0..391 per sample.
            stacked_images = torch.cat([moved['img_left'], moved['img_right']], dim=1)
            position_ids = torch.arange(2 * 196).repeat(n_samples, 1).to(device)

            logits = model(
                moved['input_ids'],
                moved['attention_mask'],
                stacked_images,
                position_ids,
            )

            predictions.extend(torch.argmax(logits, dim=1).cpu().numpy())
            targets.extend(moved['label'].cpu().numpy())

    return (
        accuracy_score(targets, predictions),
        f1_score(targets, predictions, average='binary'),
    )

# Main training loop
def main(
    *,
    train_json='path/to/train.json',
    train_img_dir='path/to/train/images',
    val_json='path/to/val.json',
    val_img_dir='path/to/val/images',
    batch_size=32,
    num_epochs=5,
    learning_rate=2e-5,
    output_path='lxmert_nlvr2_model.pth',
):
    """Fine-tune the model, report validation metrics per epoch and save the weights.

    All arguments are keyword-only and default to the values that used to be
    hard-coded, so ``main()`` behaves as before.

    Args:
        train_json: Annotation file of the training split.
        train_img_dir: Image directory of the training split.
        val_json: Annotation file of the validation split.
        val_img_dir: Image directory of the validation split.
        batch_size: Batch size for both data loaders.
        num_epochs: Number of passes over the training data.
        learning_rate: Learning rate passed to the optimizer.
        output_path: Destination of the saved ``state_dict``.
    """
    # Device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Tokenizer and model
    tokenizer = LxmertTokenizer.from_pretrained("unc-nlp/lxmert-base-uncased")
    model = LXMERTFORNLVR2().to(device)

    # Data preparation: resize to 224 x 224, convert to tensor, normalise per channel.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    train_dataset = NLVR2Dataset(train_json, train_img_dir, tokenizer, transform)
    val_dataset = NLVR2Dataset(val_json, val_img_dir, tokenizer, transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    # Optimizer: the AdamW shipped with `transformers` is deprecated in favour
    # of the PyTorch implementation. eps and weight_decay are pinned to the
    # deprecated class's defaults so the update rule stays as close as possible.
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=learning_rate,
        eps=1e-6,
        weight_decay=0.0,
    )

    # Training loop
    for epoch in range(num_epochs):
        train_loss = train(model, train_loader, optimizer, device)
        val_accuracy, val_f1 = evaluate(model, val_loader, device)

        print(f"Epoch {epoch+1}/{num_epochs}")
        print(f"Train Loss: {train_loss:.4f}")
        print(f"Validation Accuracy: {val_accuracy:.4f}")
        print(f"Validation F1 Score: {val_f1:.4f}")

    # Save the model
    torch.save(model.state_dict(), output_path)

# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()