In [2]:
import torch
import torch.nn as nn
from transformers import RobertaModel, ViTModel, RobertaTokenizer
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import pandas as pd
from PIL import Image
import torchvision.transforms as transforms

In [3]:
class CustomDataset(Dataset):
    """Dataset pairing an image and its OCR text with an LLM-written rationale.

    Row ``i`` of ``csv_file`` supplies the ``image_path`` and ``rationale``
    columns; row ``i`` of ``csv_file2`` supplies the ``ocr`` column. The two
    files are matched purely by row position.

    Args:
        csv_file: CSV with ``image_path`` and ``rationale`` columns.
        csv_file2: CSV with an ``ocr`` column.
        transform: optional callable applied to the loaded RGB PIL image.
        max_samples: keep only the first ``max_samples`` rows of each file
            (``None`` keeps every row). Defaults to 50, the previous fixed cap.
        max_length: pad/truncate length used for both tokenized texts.
        tokenizer_name: name passed to ``RobertaTokenizer.from_pretrained``.

    Each item is ``(image, ocr_tokens, rationale_tokens)``. Both token objects
    are the tokenizer's ``return_tensors='pt'`` output for a single text, so
    they carry a leading dimension that the training loop's ``.squeeze(1)``
    removes after batching.
    """

    def __init__(self, csv_file='rationale.csv', csv_file2='train.csv', transform=None,
                 max_samples=50, max_length=512, tokenizer_name='roberta-base'):
        data = pd.read_csv(csv_file)
        data2 = pd.read_csv(csv_file2)
        if max_samples is not None:
            data = data.iloc[:max_samples]
            data2 = data2.iloc[:max_samples]
        # Reset to a 0..n-1 index so the same position addresses matching rows
        # in both files.
        self.data = data.reset_index(drop=True)
        self.data2 = data2.reset_index(drop=True)
        self.transform = transform
        self.max_length = max_length
        self.tokenizer = RobertaTokenizer.from_pretrained(tokenizer_name)

    def __len__(self):
        # Only positions present in BOTH files form a complete sample; using
        # len(self.data) alone would index past the end of a shorter data2.
        return min(len(self.data), len(self.data2))

    @staticmethod
    def _to_text(value):
        """Return ``value`` as ``str``; missing CSV cells (NaN/None) become ``""``."""
        if pd.isna(value):
            return ""
        return str(value)

    def _tokenize(self, text):
        """Tokenize one text, padded/truncated to ``self.max_length``, as PyTorch tensors."""
        return self.tokenizer(self._to_text(text), padding='max_length', truncation=True,
                              max_length=self.max_length, return_tensors='pt')

    def __getitem__(self, idx):
        row = self.data.iloc[idx]

        # Close the file handle deterministically; convert() returns a new image.
        with Image.open(row['image_path']) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)

        ocr_tokens = self._tokenize(self.data2['ocr'].iloc[idx])
        rationale_tokens = self._tokenize(row['rationale'])

        return image, ocr_tokens, rationale_tokens

In [6]:
class MultimodalModel(nn.Module):
    """Fuse RoBERTa text features with ViT image features into one vector.

    From each encoder the last-layer hidden state at position 0 (the [CLS]
    slot) is taken. The two vectors are concatenated, projected to
    ``fusion_output_size`` by ``fusion``, then mapped to ``text_hidden_size``
    by ``rationale_generator``.

    Args:
        text_hidden_size: width of the text encoder's hidden states (default 768).
        image_hidden_size: width of the image encoder's hidden states (default 768).
        fusion_output_size: width of the fused intermediate vector (default 512).
    """

    def __init__(self, text_hidden_size=768, image_hidden_size=768, fusion_output_size=512):
        super().__init__()
        # Pretrained backbones.
        self.text_encoder = RobertaModel.from_pretrained('roberta-base')
        self.image_encoder = ViTModel.from_pretrained('google/vit-base-patch16-224')

        # Trainable head: [text ; image] -> fused -> text-sized rationale features.
        combined_size = text_hidden_size + image_hidden_size
        self.fusion = nn.Linear(combined_size, fusion_output_size)
        self.rationale_generator = nn.Linear(fusion_output_size, text_hidden_size)

    def forward(self, text_tokens, image):
        """Return rationale features of shape ``(batch, text_hidden_size)``.

        Args:
            text_tokens: mapping of keyword arguments for the text encoder
                (e.g. ``input_ids`` / ``attention_mask``).
            image: pixel tensor passed to the image encoder as ``pixel_values``.
        """
        text_out = self.text_encoder(**text_tokens)
        image_out = self.image_encoder(pixel_values=image)

        text_cls = text_out.last_hidden_state[:, 0, :]
        image_cls = image_out.last_hidden_state[:, 0, :]

        joint = torch.cat([text_cls, image_cls], dim=1)
        return self.rationale_generator(self.fusion(joint))

def knowledge_distillation_loss(student_features, teacher_features, temperature=1.0):
    """Mean-squared-error distillation loss between student and teacher features.

    The teacher features act only as a regression target. They are detached,
    so no gradient flows back into whatever computed them. In this notebook
    the teacher is produced by the same ``model.text_encoder`` the student
    uses, so an undetached target would let the optimizer move the target
    toward the prediction instead of the other way round.

    Args:
        student_features: tensor produced by the model being trained.
        teacher_features: target tensor, same shape as ``student_features``.
        temperature: accepted for signature compatibility with a softmax/KL
            style distillation loss; it is NOT used by this MSE loss.

    Returns:
        Scalar tensor: mean of the element-wise squared differences.
    """
    return F.mse_loss(student_features, teacher_features.detach())

In [4]:
# def knowledge_distillation_loss(student_logits, teacher_logits, temperature=1.0):
#     return nn.KLDivLoss()(F.log_softmax(student_logits / temperature, dim=1),
#                           F.softmax(teacher_logits / temperature, dim=1))

In [7]:
# Training hyperparameters
batch_size = 32
learning_rate = 1e-4
num_epochs = 5

# Seed torch's RNG so the randomly initialised fusion/rationale layers and the
# DataLoader's shuffling are reproducible across Restart & Run All.
seed = 42
torch.manual_seed(seed)

In [8]:
# Image preprocessing for the ViT encoder: fixed 224x224 resize, tensor
# conversion, then per-channel normalization.
# NOTE(review): these are the standard ImageNet statistics. The Hugging Face
# preprocessor config for google/vit-base-patch16-224 appears to use
# mean = std = 0.5 per channel; confirm which the backbone expects. Whatever is
# chosen must also be used by the inference pipeline in generate_rationale.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
)

In [9]:
# Training data: rationale.csv (image paths + LLM rationales) paired row-by-row
# with train.csv (OCR text), served in shuffled mini-batches.
dataset = CustomDataset(csv_file='rationale.csv', csv_file2='train.csv', transform=transform)
dataloader = DataLoader(dataset, shuffle=True, batch_size=batch_size)

In [10]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MultimodalModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# The model regresses continuous feature vectors, so the matching criterion is
# MSE (CrossEntropyLoss expects class logits/targets). The training loop itself
# computes its loss via knowledge_distillation_loss, which is also an MSE.
criterion = nn.MSELoss()

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [9]:
# Training loop
# Make sure dropout etc. are active even if a previous cell called model.eval().
model.train()
for epoch in range(num_epochs):
    running_loss = 0.0
    num_batches = 0
    for images, ocr_tokens, llm_rationale_tokens in dataloader:
        images = images.to(device)
        # Drop the per-sample leading dim added by return_tensors='pt' in the dataset.
        ocr_tokens = {k: v.squeeze(1).to(device) for k, v in ocr_tokens.items()}
        llm_rationale_tokens = {k: v.squeeze(1).to(device) for k, v in llm_rationale_tokens.items()}

        # Forward pass (student)
        model_rationale_features = model(ocr_tokens, images)

        # Teacher target: [CLS] embedding of the LLM rationale. It is a fixed
        # target, so compute it without gradients; otherwise backprop would also
        # push the shared text_encoder to move the target toward the prediction.
        # NOTE(review): the teacher shares the student's text_encoder, so the
        # target still drifts as training updates it; consider a frozen copy.
        with torch.no_grad():
            llm_rationale_features = model.text_encoder(**llm_rationale_tokens).last_hidden_state[:, 0, :]

        # Calculate losses
        total_loss = knowledge_distillation_loss(model_rationale_features, llm_rationale_features)

        # Backward pass and optimize
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        running_loss += total_loss.item()
        num_batches += 1

    # Report the mean over the epoch (not just the last batch); guard against
    # an empty loader instead of raising NameError.
    mean_loss = running_loss / max(num_batches, 1)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {mean_loss:.4f}")

print("Training finished!")

In [11]:
# Persist only the learned weights (the state_dict), not the module object.
checkpoint_path = 'model_temp.pth'
torch.save(model.state_dict(), checkpoint_path)

In [16]:


def generate_rationale(model, image_path, ocr_text, device, tokenizer, transform=None, max_length=512):
    """Run ``model`` on one image + OCR text pair and decode its output to a string.

    NOTE(review): ``MultimodalModel`` returns a ``text_hidden_size``-wide
    feature vector (768 by default), not per-token vocabulary logits. Taking
    ``argmax`` over it yields a single index in ``[0, 768)``, which is then
    decoded as if it were a RoBERTa token id. The "rationale" is therefore one
    arbitrary token (the recorded output is " taken"). Producing real text
    needs a decoder / generation head; this function keeps the original
    decoding behavior.

    Args:
        model: a ``MultimodalModel`` (anything with ``forward(text_tokens, image)``).
        image_path: path to the input image file.
        ocr_text: OCR text associated with the image.
        device: torch device to run on.
        tokenizer: tokenizer used to encode ``ocr_text`` and decode the output.
        transform: optional image preprocessing. Defaults to a 224x224 resize,
            ``ToTensor`` and ImageNet normalization, matching the ``transform``
            built for training.
        max_length: pad/truncate length for ``ocr_text`` (default 512).

    Returns:
        The decoded string, with special tokens skipped.
    """
    if transform is None:
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    # Image preprocessing; the context manager closes the file handle.
    with Image.open(image_path) as img:
        image = transform(img.convert('RGB')).unsqueeze(0).to(device)

    # OCR text preprocessing
    ocr_tokens = tokenizer(ocr_text, padding='max_length', truncation=True,
                           max_length=max_length, return_tensors='pt')
    ocr_tokens = {k: v.to(device) for k, v in ocr_tokens.items()}

    # Inference in eval mode without gradients, then restore the caller's mode
    # so a later training cell is not silently left with dropout disabled.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            rationale_features = model(ocr_tokens, image)
    finally:
        model.train(was_training)

    # Single image -> one id. Decode once; the previous loop re-decoded the
    # whole tensor for every element and concatenated the copies.
    predicted_token_ids = torch.argmax(rationale_features, dim=-1)
    return tokenizer.decode(predicted_token_ids.tolist(), skip_special_tokens=True)

# Load your trained model.
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MultimodalModel().to(device)
model.load_state_dict(torch.load('model_temp.pth', map_location=device))

# Initialize tokenizer
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

# val.csv is read as tab-separated.
df = pd.read_csv('val.csv', delimiter='\t')
# Example usage: row 2 of val.csv with ./valImages/2.jpg
# NOTE(review): assumes image file names follow val.csv row labels; confirm.
image_path = './valImages/2.jpg'
ocr_text = df['ocr'][2]

generated_rationale = generate_rationale(model, image_path, ocr_text, device, tokenizer)
print("Generated Rationale:")
print(generated_rationale)

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Generated Rationale:
 taken


In [None]:
# # Set the model to evaluation mode
# model.eval()

# # Load validation data
# dataset = CustomDataset('val_rationale.csv', 'val.csv', transform=transform)
# dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# # Tokenizer for decoding
# tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

# # Evaluate the model on the validation set
# with torch.no_grad():
#     for batch in dataloader:
#         images, ocr_tokens, llm_rationale_tokens = batch
#         images = images.to(device)
#         ocr_tokens = {k: v.squeeze(1).to(device) for k, v in ocr_tokens.items()}
#         llm_rationale_tokens = {k: v.squeeze(1).to(device) for k, v in llm_rationale_tokens.items()}
        
#         # Forward pass
#         model_rationale_features = model(ocr_tokens, images)
        
#         # If model_rationale_features are logits, convert to token IDs
#         model_rationale_token_ids = torch.argmax(model_rationale_features, dim=-1)
        
#         # Decode token IDs to text
#         decoded_rationale = tokenizer.decode(model_rationale_token_ids.squeeze().cpu().numpy(), skip_special_tokens=False)
        
#         # Print the decoded rationale text
#         print(decoded_rationale)
