In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import RobertaModel, ViTModel, RobertaTokenizer
from PIL import Image
import torchvision.transforms as transforms
import pandas as pd

In [19]:
class CustomDataset(Dataset):
    def __init__(self, csv_file='rationale.csv', csv_file2='train.csv', transform=None):
        self.data = pd.read_csv(csv_file)
        self.data2 = pd.read_csv(csv_file2)
        self.data = self.data[:10]
        self.data2 = self.data2[:10]
        self.transform = transform
        self.tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_name = self.data['image_path'][idx]
        image = Image.open(img_name).convert('RGB')
        if self.transform:
            image = self.transform(image)
        
        ocr = self.data2['ocr'][idx]
        ocr_tokens = self.tokenizer(ocr, padding='max_length', truncation=True, max_length=512, return_tensors='pt')
        
        llm_rationale = self.data['rationale'][idx]
        rationale_tokens = self.tokenizer(llm_rationale, padding='max_length', truncation=True, max_length=512, return_tensors='pt')
        
        return image, ocr_tokens, rationale_tokens

In [20]:
class MultimodalModel(nn.Module):
    def __init__(self, text_hidden_size=768, image_hidden_size=768, fusion_output_size=512, vocab_size=50265):
        super(MultimodalModel, self).__init__()
        self.text_encoder = RobertaModel.from_pretrained('roberta-base')
        self.image_encoder = ViTModel.from_pretrained('google/vit-base-patch16-224')
        
        self.fusion = nn.Linear(text_hidden_size + image_hidden_size, fusion_output_size)
        self.rationale_generator = nn.Linear(fusion_output_size, vocab_size)
        
    def forward(self, text_tokens, image):
        text_features = self.text_encoder(**text_tokens).last_hidden_state[:, 0, :]
        image_features = self.image_encoder(image).last_hidden_state[:, 0, :]
        
        fused_features = torch.cat((text_features, image_features), dim=1)
        fused_features = self.fusion(fused_features)
        return fused_features
    
    def generate_next_token(self, initial_features, current_output):
        text_features = self.text_encoder(input_ids=current_output).last_hidden_state[:, -1, :]
        fused_features = torch.cat((text_features, initial_features), dim=1)
        fused_features = self.fusion(fused_features)
        logits = self.rationale_generator(fused_features)
        return logits

def knowledge_distillation_loss(student_logits, teacher_logits, temperature=1.0):
    return nn.KLDivLoss(reduction='batchmean')(F.log_softmax(student_logits / temperature, dim=1),
                                               F.softmax(teacher_logits / temperature, dim=1))

In [4]:
# def knowledge_distillation_loss(student_logits, teacher_logits, temperature=1.0):
#     return nn.KLDivLoss()(F.log_softmax(student_logits / temperature, dim=1),
#                           F.softmax(teacher_logits / temperature, dim=1))

In [5]:
# batch_size = 32
# learning_rate = 1e-4
# num_epochs = 5

In [6]:
# transform = transforms.Compose([
#     transforms.Resize((224, 224)),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
# ])

In [7]:
# dataset = CustomDataset('rationale.csv', 'train.csv', transform=transform)
# dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [8]:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = MultimodalModel().to(device)
# optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# criterion = nn.CrossEntropyLoss()  # For simplicity, using MSE loss for rationale regression

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [21]:
def train_model(model, train_loader, num_epochs, device):
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    kd_loss_fn = knowledge_distillation_loss
    ce_loss_fn = nn.CrossEntropyLoss()

    model.train()
    for epoch in range(num_epochs):
        total_loss = 0
        for batch in train_loader:
            images, ocr_tokens, llm_rationale_tokens = batch
            images = images.to(device)
            ocr_tokens = {k: v.squeeze(1).to(device) for k, v in ocr_tokens.items()}
            llm_rationale_tokens = llm_rationale_tokens['input_ids'].squeeze(1).to(device)
            
            optimizer.zero_grad()
            
            initial_features = model(ocr_tokens, images)
            
            # Generate rationale token by token
            generated = torch.full((llm_rationale_tokens.size(0), 1), model.text_encoder.config.bos_token_id, dtype=torch.long).to(device)
            
            kd_loss = 0
            ce_loss = 0
            for i in range(llm_rationale_tokens.size(1) - 1):
                next_token_logits = model.generate_next_token(initial_features, generated)
                kd_loss += kd_loss_fn(next_token_logits, model.text_encoder(input_ids=llm_rationale_tokens[:, :i+1]).last_hidden_state[:, -1, :])
                ce_loss += ce_loss_fn(next_token_logits, llm_rationale_tokens[:, i+1])
                
                generated = torch.cat([generated, llm_rationale_tokens[:, i+1].unsqueeze(1)], dim=1)
            
            loss = kd_loss + ce_loss
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        avg_loss = total_loss / len(train_loader)
        print(f"Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}")

In [None]:
# torch.save(model.state_dict(), 'model_temp1.pth')

In [15]:
def generate_rationale(model, image_path, ocr_text, device, tokenizer, max_length=100):
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    
    image = Image.open(image_path).convert('RGB')
    image = transform(image).unsqueeze(0).to(device)
    
    ocr_tokens = tokenizer(ocr_text, padding='max_length', truncation=True, max_length=512, return_tensors='pt')
    ocr_tokens = {k: v.to(device) for k, v in ocr_tokens.items()}
    
    model.eval()
    with torch.no_grad():
        initial_features = model(ocr_tokens, image)
        
        generated = torch.full((1, 1), model.text_encoder.config.bos_token_id, dtype=torch.long).to(device)
        
        for _ in range(max_length):
            next_token_logits = model.generate_next_token(initial_features, generated)
            next_token = torch.argmax(next_token_logits, dim=-1)
            
            generated = torch.cat([generated, next_token.unsqueeze(1)], dim=1)
            
            if next_token.item() == model.text_encoder.config.eos_token_id:
                break
    
    generated_rationale = tokenizer.decode(generated.squeeze(), skip_special_tokens=True)
    return generated_rationale

# Load your trained model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MultimodalModel().to(device)
model.load_state_dict(torch.load('model_temp1.pth'))

# Initialize tokenizer
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

df = pd.read_csv('val.csv',delimiter='\t')
# Example usage
image_path = './valImages/1.jpg'
ocr_text = df['ocr'][10]

generated_rationale = generate_rationale(model, image_path, ocr_text, device, tokenizer)
print("Generated Rationale:")
print(generated_rationale)

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


FileNotFoundError: [Errno 2] No such file or directory: 'model_temp1.pth'

In [22]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
# Initialize model and dataset
model = MultimodalModel().to(device)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
dataset = CustomDataset(transform=transform)
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

# Train the model
num_epochs = 10
train_model(model, train_loader, num_epochs, device)

# Save the trained model
torch.save(model.state_dict(), 'meme_rationale_model.pth')

# Generate a rationale for a single image
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
image_path = 'path_to_your_test_image.jpg'
ocr_text = 'Text extracted from the test image'

generated_rationale = generate_rationale(model, image_path, ocr_text, device, tokenizer)
print("Generated Rationale:")
print(generated_rationale)

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


RuntimeError: mat1 and mat2 shapes cannot be multiplied (10x1280 and 1536x512)

In [None]:
# # Set the model to evaluation mode
# model.eval()

# # Load validation data
# dataset = CustomDataset('val_rationale.csv', 'val.csv', transform=transform)
# dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# # Tokenizer for decoding
# tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

# # Evaluate the model on the validation set
# with torch.no_grad():
#     for batch in dataloader:
#         images, ocr_tokens, llm_rationale_tokens = batch
#         images = images.to(device)
#         ocr_tokens = {k: v.squeeze(1).to(device) for k, v in ocr_tokens.items()}
#         llm_rationale_tokens = {k: v.squeeze(1).to(device) for k, v in llm_rationale_tokens.items()}
        
#         # Forward pass
#         model_rationale_features = model(ocr_tokens, images)
        
#         # If model_rationale_features are logits, convert to token IDs
#         model_rationale_token_ids = torch.argmax(model_rationale_features, dim=-1)
        
#         # Decode token IDs to text
#         decoded_rationale = tokenizer.decode(model_rationale_token_ids.squeeze().cpu().numpy(), skip_special_tokens=False)
        
#         # Print the decoded rationale text
#         print(decoded_rationale)
