# In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from transformers import BertTokenizer, BertModel, ViTModel, ViTFeatureExtractor
from PIL import Image
import torchvision.transforms as transforms

# Custom Dataset definition
class CrisisDataset(Dataset):
    """Tweet-text / image pairs with binary "informative" targets.

    Each item is a dict holding the tokenized tweet (``input_ids``,
    ``attention_mask``), the preprocessed image (``image``) and two float
    targets (``label_image``, ``label``) meant for a BCE-style loss.

    Args:
        dataframe: Table with the columns ``tweet_text``, ``image`` (a path
            readable by ``PIL.Image.open``), ``label_image`` and ``label``.
            Both label columns may only contain ``'informative'`` or
            ``'not_informative'``. The frame is copied, never modified.
        tokenizer: Callable invoked as ``tokenizer(text, padding=...,
            truncation=..., max_length=..., return_tensors="pt")`` whose
            result exposes ``input_ids`` and ``attention_mask``.
        feature_extractor: Fallback image preprocessor, used only when
            ``transform`` is ``None``; invoked as
            ``feature_extractor(images=..., return_tensors="pt")``.
        transform: Optional callable applied to the RGB PIL image; its
            return value becomes the ``image`` entry of the item.

    Raises:
        ValueError: If a label column contains a value that is missing or
            not one of the two known label strings.
    """

    # String label -> binary target (informative: 1, not_informative: 0).
    _LABEL_MAP = {'informative': 1, 'not_informative': 0}
    # Every tweet is padded / truncated to this many tokens.
    _MAX_LENGTH = 128

    def __init__(self, dataframe, tokenizer, feature_extractor, transform=None):
        self.df = dataframe.copy()
        self.tokenizer = tokenizer
        self.feature_extractor = feature_extractor
        self.transform = transform
        for source, target in (('label_image', 'label_image_bin'),
                               ('label', 'label_bin')):
            mapped = self.df[source].map(self._LABEL_MAP)
            unmapped = mapped.isna()
            if unmapped.any():
                # Fail loudly: an unmapped label would otherwise become a
                # NaN target and silently poison the loss.
                unknown = sorted({str(v) for v in self.df.loc[unmapped, source]})
                raise ValueError(
                    "Column {!r} contains unsupported label(s): {}".format(
                        source, ", ".join(unknown)))
            self.df[target] = mapped.astype('int64')

    def __len__(self):
        """Return the number of rows (samples) in the dataset."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load, preprocess and return the sample at positional index ``idx``."""
        row = self.df.iloc[idx]

        tweet_text = row['tweet_text']
        if not isinstance(tweet_text, str):
            # pandas parses empty / "NA"-like fields as NaN; the tokenizer
            # needs a string, so treat a missing tweet as empty text.
            tweet_text = '' if pd.isna(tweet_text) else str(tweet_text)

        # Context manager closes the underlying file as soon as the pixels
        # have been copied into the converted RGB image.
        with Image.open(row['image']) as raw_image:
            image = raw_image.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        else:
            # The extractor returns a batch of one; drop the batch axis.
            image = self.feature_extractor(
                images=image, return_tensors="pt")['pixel_values'].squeeze(0)

        # Tokenize tweet text (fixed length so samples can be stacked).
        encoding = self.tokenizer(tweet_text, padding="max_length", truncation=True,
                                  max_length=self._MAX_LENGTH, return_tensors="pt")

        return {
            # The tokenizer also returns a batch of one; drop the batch axis.
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'image': image,
            # Float targets, as expected by BCE losses.
            'label_image': torch.tensor(row['label_image_bin'], dtype=torch.float),
            'label': torch.tensor(row['label_bin'], dtype=torch.float),
        }

# Text Classifier using BERT
class TextClassifier(nn.Module):
    """Binary text classifier: a pretrained BERT encoder plus a linear head.

    Args:
        bert_model_name: Name or path handed to ``BertModel.from_pretrained``.
    """

    def __init__(self, bert_model_name: str = 'bert-base-uncased') -> None:
        super().__init__()
        encoder = BertModel.from_pretrained(bert_model_name)
        self.bert = encoder
        # One output unit: a single logit per sample.
        self.classifier = nn.Linear(encoder.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        """Return one raw logit per sample, shaped ``[batch_size, 1]``.

        The encoder's ``pooler_output`` (``[batch_size, hidden_size]``) is
        fed straight into the linear head; no sigmoid is applied here.
        """
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        return self.classifier(encoded.pooler_output)

# Image Classifier using ViT
class ImageClassifier(nn.Module):
    """Binary image classifier: a pretrained ViT encoder plus a linear head.

    Args:
        vit_model_name: Name or path handed to ``ViTModel.from_pretrained``.
    """

    def __init__(self, vit_model_name: str = 'google/vit-base-patch16-224-in21k') -> None:
        super().__init__()
        backbone = ViTModel.from_pretrained(vit_model_name)
        self.vit = backbone
        # One output unit: a single logit per sample.
        self.classifier = nn.Linear(backbone.config.hidden_size, 1)

    def forward(self, images):
        """Return one raw logit per image, shaped ``[batch_size, 1]``.

        Only the first token of ``last_hidden_state`` is used as the image
        representation (``[batch_size, hidden_size]``); no sigmoid is
        applied here.
        """
        hidden_states = self.vit(pixel_values=images).last_hidden_state
        first_token = hidden_states[:, 0]
        return self.classifier(first_token)

# Multimodal Classifier combining BERT and ViT
class MultiModalClassifier(nn.Module):
    """Binary classifier over concatenated BERT text and ViT image features.

    Args:
        bert_model_name: Name or path handed to ``BertModel.from_pretrained``.
        vit_model_name: Name or path handed to ``ViTModel.from_pretrained``.
    """

    def __init__(self, bert_model_name: str = 'bert-base-uncased',
                 vit_model_name: str = 'google/vit-base-patch16-224-in21k') -> None:
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.vit = ViTModel.from_pretrained(vit_model_name)
        # The head sees the text vector followed by the image vector.
        text_dim = self.bert.config.hidden_size
        image_dim = self.vit.config.hidden_size
        self.classifier = nn.Linear(text_dim + image_dim, 1)

    def forward(self, input_ids, attention_mask, images):
        """Return one raw logit per (tweet, image) pair, shaped ``[batch_size, 1]``.

        The text side contributes BERT's ``pooler_output``; the image side
        contributes the first token of ViT's ``last_hidden_state``. The two
        vectors are joined along the feature axis, text first.
        """
        text_vec = self.bert(input_ids=input_ids,
                             attention_mask=attention_mask).pooler_output
        image_hidden = self.vit(pixel_values=images).last_hidden_state
        image_vec = image_hidden[:, 0]
        fused = torch.cat([text_vec, image_vec], dim=1)
        return self.classifier(fused)

def calculate_accuracy(logits, labels):
    """Return the fraction of thresholded predictions that equal ``labels``.

    A prediction is 1 where ``sigmoid(logits) >= 0.5`` and 0 otherwise.
    ``labels`` is reshaped to the shape of ``logits`` before comparing, so a
    ``[batch]`` label vector paired with ``[batch, 1]`` logits is compared
    element by element instead of being broadcast to ``[batch, batch]``
    (which would count spurious matches and can yield values above 1).

    Args:
        logits: Tensor of raw (pre-sigmoid) scores.
        labels: Tensor of 0/1 targets with the same number of elements as
            ``logits``.

    Returns:
        A Python float in ``[0, 1]``; ``0.0`` for an empty batch.

    Raises:
        RuntimeError: If ``labels`` cannot be reshaped to ``logits.shape``
            (element counts differ).
    """
    total = logits.numel()
    if total == 0:
        # Nothing to score; avoid dividing by zero.
        return 0.0
    # Apply sigmoid to logits and threshold at 0.5
    preds = (torch.sigmoid(logits) >= 0.5).float()
    aligned_labels = labels.reshape(preds.shape)
    correct = (preds == aligned_labels).sum().item()
    return correct / total

if __name__ == '__main__':
    # Demonstration script: builds the training dataset, instantiates the
    # three classifiers and runs a single optimization step for each of them
    # on the first training batch, printing loss and accuracy.

    # File paths for your datasets
    train_file_path = 'crisismmd_datasplit_all/crisismmd_datasplit_all/task_informative_text_img_train.tsv'
    dev_file_path = 'crisismmd_datasplit_all/crisismmd_datasplit_all/task_informative_text_img_dev.tsv'
    test_file_path = 'crisismmd_datasplit_all/crisismmd_datasplit_all/task_informative_text_img_test.tsv'

    # Load datasets (dev/test are loaded but not used by this demonstration)
    train_data = pd.read_csv(train_file_path, sep='\t')
    dev_data = pd.read_csv(dev_file_path, sep='\t')
    test_data = pd.read_csv(test_file_path, sep='\t')

    # Run on the GPU when one is available, otherwise on the CPU. Models and
    # every batch tensor must live on the same device.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Initialize tokenizer and feature extractor
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')

    # Define image transform for ViT requirements; normalization statistics
    # are taken from the feature extractor so they match the checkpoint.
    image_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
    ])

    # Create dataset and dataloader
    train_dataset = CrisisDataset(train_data, tokenizer, feature_extractor, transform=image_transform)
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)

    # Instantiate models
    text_model = TextClassifier().to(device)
    image_model = ImageClassifier().to(device)
    multimodal_model = MultiModalClassifier().to(device)

    # Define loss function (expects raw logits and float 0/1 targets)
    criterion = nn.BCEWithLogitsLoss()

    # Define optimizers
    text_optimizer = optim.Adam(text_model.parameters(), lr=2e-5)
    image_optimizer = optim.Adam(image_model.parameters(), lr=2e-5)
    multimodal_optimizer = optim.Adam(multimodal_model.parameters(), lr=2e-5)

    # Set models to training mode
    text_model.train()
    image_model.train()
    multimodal_model.train()

    # Sample training loop for one epoch (here, we loop over one batch for demonstration)
    for batch in train_loader:
        input_ids = batch['input_ids'].to(device)            # [batch_size, seq_len]
        attention_mask = batch['attention_mask'].to(device)
        images = batch['image'].to(device)                   # [batch_size, 3, 224, 224]
        # Add a trailing axis so targets match the [batch_size, 1] logits.
        label_image = batch['label_image'].unsqueeze(1).to(device)  # [batch_size, 1]
        label = batch['label'].unsqueeze(1).to(device)              # [batch_size, 1]

        # Each model runs its full forward/backward/step before the next one
        # starts, so only one autograd graph is alive at a time. The models
        # are independent, so the results are unchanged.

        # Text model (predicts label_image)
        # NOTE(review): the text model is trained against the image label,
        # as in the original code -- confirm this is the intended target.
        text_optimizer.zero_grad()
        text_logits = text_model(input_ids, attention_mask)
        text_loss = criterion(text_logits, label_image)
        text_acc = calculate_accuracy(text_logits, label_image)
        text_loss.backward()
        text_optimizer.step()

        # Image model (predicts label_image)
        image_optimizer.zero_grad()
        image_logits = image_model(images)
        image_loss = criterion(image_logits, label_image)
        image_acc = calculate_accuracy(image_logits, label_image)
        image_loss.backward()
        image_optimizer.step()

        # Multimodal model (predicts label)
        multimodal_optimizer.zero_grad()
        multimodal_logits = multimodal_model(input_ids, attention_mask, images)
        multimodal_loss = criterion(multimodal_logits, label)
        multimodal_acc = calculate_accuracy(multimodal_logits, label)
        multimodal_loss.backward()
        multimodal_optimizer.step()

        # Print loss and accuracy for each model
        print("Batch Metrics:")
        print("  Text Model    -> Loss: {:.4f} | Accuracy: {:.4f}".format(text_loss.item(), text_acc))
        print("  Image Model   -> Loss: {:.4f} | Accuracy: {:.4f}".format(image_loss.item(), image_acc))
        print("  Multimodal    -> Loss: {:.4f} | Accuracy: {:.4f}".format(multimodal_loss.item(), multimodal_acc))

        # For demonstration, we process only one batch. Remove the break to run through all batches.
        break




# RuntimeError: CUDA error: an illegal memory access was encountered
# CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
# For debugging consider passing CUDA_LAUNCH_BLOCKING=1
# Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
