In [1]:
import torch
import torch.nn as nn
from transformers import ViTModel, BertModel

class CrossAttentionBlock(nn.Module):
    def __init__(self, dim_q, dim_kv, num_heads=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=dim_q, kdim=dim_kv, vdim=dim_kv, num_heads=num_heads, batch_first=True)
    
    def forward(self, query, key_value):
        out, _ = self.attn(query, key_value, key_value)
        return out

class MemeCrossAttentionClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224')
        self.bert = BertModel.from_pretrained('bert-base-uncased')

        self.cross_attn_img_to_txt = CrossAttentionBlock(dim_q=768, dim_kv=768)
        self.cross_attn_txt_to_img = CrossAttentionBlock(dim_q=768, dim_kv=768)

        self.classifier_input_dim = 768 * 2 
        self.sentiment_classifier = nn.Sequential(
            nn.Linear(self.classifier_input_dim, 256),
            nn.Dropout(0.3),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.Dropout(0.3),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 3)  # Output: 3 sentiment classes
        )
        
        self.humor_classifier = nn.Sequential(
            nn.Linear(self.classifier_input_dim, 256),
            nn.Dropout(0.3),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.Dropout(0.3),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 4)  # Output: 4 humor classes
        )
        
        self.sarcasm_classifier = nn.Sequential(
            nn.Linear(self.classifier_input_dim, 256),
            nn.Dropout(0.3),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.Dropout(0.3),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 4)  # Output: 4 sarcasm classes
        )
        
        self.offense_classifier = nn.Sequential(
            nn.Linear(self.classifier_input_dim, 256),
            nn.Dropout(0.3),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.Dropout(0.3),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 4)  # Output: 4 offense classes
        )

    def forward(self, image, text):
        vit_out = self.vit(pixel_values=image).last_hidden_state      
        bert_out = self.bert(**text).last_hidden_state                

        img_cls = vit_out[:, 0:1, :]      
        text_cls = bert_out[:, 0:1, :]    

        img_attn = self.cross_attn_img_to_txt(img_cls, bert_out)  
        text_attn = self.cross_attn_txt_to_img(text_cls, vit_out)  

        combined = torch.cat([img_attn.squeeze(1), text_attn.squeeze(1)], dim=-1)  

        return {
            'sentiment_logits': self.sentiment_classifier(combined),
            'humor_logits': self.humor_classifier(combined),
            'sarcasm_logits': self.sarcasm_classifier(combined),
            'offense_logits': self.offense_classifier(combined),
        }


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(device)

mps


In [3]:
model = MemeCrossAttentionClassifier()
model.load_state_dict(torch.load("/Users/sahilpandey/Projects/Sentiment_Analysis/models/meme_classifier_final.pth", map_location=device))
model.to(device)
model.eval()

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


MemeCrossAttentionClassifier(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
            (inter

In [4]:
import easyocr
from transformers import BertTokenizer, ViTFeatureExtractor
from PIL import Image
import torchvision.transforms as transforms
import torch

In [5]:
reader = easyocr.Reader(['en'])  # OCR reader
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model = MemeCrossAttentionClassifier()
model.load_state_dict(torch.load("path_to_model.pth", map_location=device))
model.to(device)
model.eval()

Downloading detection model, please wait. This may take several minutes depending upon your network connection.


Progress: |██████████████████████████████████████████████████| 100.0% Complete

Downloading recognition model, please wait. This may take several minutes depending upon your network connection.


Progress: |██████████████████████████████████████████████████| 100.0% Complete

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


FileNotFoundError: [Errno 2] No such file or directory: 'path_to_model.pth'