In [1]:
import torch
import torch.nn as nn
from transformers import DeiTModel, BertModel, AutoTokenizer, AutoImageProcessor, get_linear_schedule_with_warmup

from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import regex
import os
from PIL import Image
from sklearn.model_selection import train_test_split
from torchvision.transforms import transforms
from tqdm import tqdm
from torch.optim import AdamW

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def comp_type_map(comp_type):
    type_to_label = {'Direct': 0, 'Metaphorical': 1, 'Semantic list': 2, 'Reduplication': 3, 'Single': 4}
    return type_to_label[comp_type]

def label_to_comp_type(label):
    label_to_type = {0: 'Direct', 1: 'Metaphorical', 2: 'Semantic list', 3: 'Reduplication', 4: 'Single'}
    return label_to_type[label]

In [3]:
class EmojiMultimodalClassifier(nn.Module):
    def __init__(self, image_model_name="facebook/deit-tiny-patch16-224",
                 text_model_name="bert-base-uncased", num_classes=5):
        super().__init__()
        
        self.image_encoder = DeiTModel.from_pretrained(image_model_name)
        self.image_proj = nn.Linear(self.image_encoder.config.hidden_size, 256)


        self.text_encoder = BertModel.from_pretrained(text_model_name)
        self.text_proj = nn.Linear(self.text_encoder.config.hidden_size, 256)

        self.fusion_proj = nn.Linear(512, 256)
        self.classifier = nn.Linear(256, num_classes)

        for param in self.image_encoder.parameters():
            param.requires_grad = False

        for param in self.text_encoder.parameters():
            param.requires_grad = False
            
    def forward(self, emoji_images, text_input_ids, text_attention_mask):
        """
        emoji_images: list of emoji image tensors (B, N, 3, H, W)
        text_input_ids: (B, T)
        text_attention_mask: (B, T)
        """
        batch_size, num_emojis, C, H, W = emoji_images.shape
        emoji_images = emoji_images.view(-1, C, H, W)  # Flatten to (B*N, C, H, W)

        image_outputs = self.image_encoder(pixel_values=emoji_images)
        image_embeddings = image_outputs.last_hidden_state[:, 0]  # (B*N, D)
        image_embeddings = self.image_proj(image_embeddings)
        image_embeddings = image_embeddings.view(batch_size, num_emojis, -1)
        image_repr = image_embeddings.mean(dim=1)

        text_outputs = self.text_encoder(input_ids=text_input_ids, attention_mask=text_attention_mask)
        text_repr = self.text_proj(text_outputs.pooler_output)  # (B, D)

        fused = torch.cat([image_repr, text_repr], dim=1)  # (B, 2D)
        fused = torch.relu(self.fusion_proj(fused))        # (B, D)

        logits = self.classifier(fused)  # (B, num_classes)
        return logits


In [4]:
class EMENDataset(Dataset):
    def __init__(self, df, tokenizer, path='../../images'):
        self.df = df
        self.tokenizer = tokenizer
        self.max_length = 4
        self.num_emojis = 3
        self.path = path
        self.transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
    
    def preprocess_emoji_description(self, text):
        text = text.replace('\'\'', '').lower()
        split_text = regex.findall(r'\':?(.*?):?\'', text)
        return split_text
    
    def preprocess_en(self, text):
        return text.lower().strip()
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        emoji_descriptions = self.preprocess_emoji_description(row['Description'])
        comp_type = row['Composition strategy']
        comp_type = comp_type_map(comp_type)
        
        images = [self.transform(Image.open(os.path.join(self.path, f"{description}.png"))) for description in emoji_descriptions]
        images = torch.stack(images)
        if images.shape[0] < self.num_emojis:
            images = torch.cat([images, torch.zeros((self.num_emojis - images.shape[0], 3, 224, 224))])
        elif images.shape[0] > self.num_emojis:
            images = images[:self.num_emojis]
        
        en_text = self.preprocess_en(row['EN'])
        text_tokens = self.tokenizer(
            en_text, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='pt'
        )
        return {
            'emoji_images': images,
            'text_input_ids': text_tokens['input_ids'].squeeze(0),
            'text_attention_mask': text_tokens['attention_mask'].squeeze(0),
            'comp_type': torch.tensor(comp_type, dtype=torch.long)
        }
        

In [5]:
elco_df = pd.read_csv('../../data/ELCo.csv')
train_df, validate_df = train_test_split(elco_df, test_size=0.1, random_state=42)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
train_dataset = EMENDataset(train_df, tokenizer)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
validate_dataset = EMENDataset(validate_df, tokenizer)
validate_loader = DataLoader(validate_dataset, batch_size=64, shuffle=False)

In [8]:
num_epochs = 15
learning_rate = 2e-5
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = EmojiMultimodalClassifier(num_classes=5)
model.to(device)

optimizer = AdamW(model.parameters(), lr=learning_rate)
total_steps = len(train_loader) * num_epochs
# scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)
criterion = torch.nn.CrossEntropyLoss()

You are using a model of type vit to instantiate a model of type deit. This is not supported for all configurations of models and can yield errors.
Some weights of DeiTModel were not initialized from the model checkpoint at facebook/deit-tiny-patch16-224 and are newly initialized: ['embeddings.cls_token', 'embeddings.distillation_token', 'embeddings.patch_embeddings.projection.bias', 'embeddings.patch_embeddings.projection.weight', 'embeddings.position_embeddings', 'encoder.layer.0.attention.attention.key.bias', 'encoder.layer.0.attention.attention.key.weight', 'encoder.layer.0.attention.attention.query.bias', 'encoder.layer.0.attention.attention.query.weight', 'encoder.layer.0.attention.attention.value.bias', 'encoder.layer.0.attention.attention.value.weight', 'encoder.layer.0.attention.output.dense.bias', 'encoder.layer.0.attention.output.dense.weight', 'encoder.layer.0.intermediate.dense.bias', 'encoder.layer.0.intermediate.dense.weight', 'encoder.layer.0.layernorm_after.bias', 'enc

In [9]:
for epoch in tqdm(range(num_epochs), desc="Training Epochs"):
    print(f"\nEpoch {epoch + 1}/{num_epochs}")
    model.train()
    total_loss = 0

    for i, batch in enumerate(train_loader):
        emoji_images = batch["emoji_images"].to(device)       # (B, N, 3, 224, 224)
        input_ids = batch["text_input_ids"].to(device)             # (B, T)
        attention_mask = batch["text_attention_mask"].to(device)   # (B, T)
        labels = batch["comp_type"].to(device)                    # (B,)

        optimizer.zero_grad()
        outputs = model(emoji_images, input_ids, attention_mask)  # (B, num_classes)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # scheduler.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    print(f"Train Loss: {avg_loss:.4f}")

    model.eval()
    correct, total = 0, 0

    with torch.no_grad():
        for batch in validate_loader:
            emoji_images = batch["emoji_images"].to(device)
            input_ids = batch["text_input_ids"].to(device)
            attention_mask = batch["text_attention_mask"].to(device)
            labels = batch["comp_type"].to(device)

            outputs = model(emoji_images, input_ids, attention_mask)
            preds = torch.argmax(outputs, dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    acc = correct / total
    print(f"Validation Accuracy: {acc:.4f}")

Training Epochs:   0%|          | 0/15 [00:00<?, ?it/s]


Epoch 1/15
Train Loss: 1.5106


Training Epochs:   7%|▋         | 1/15 [01:05<15:15, 65.38s/it]

Validation Accuracy: 0.4398

Epoch 2/15
Train Loss: 1.3353


Training Epochs:  13%|█▎        | 2/15 [02:12<14:23, 66.42s/it]

Validation Accuracy: 0.4398

Epoch 3/15
Train Loss: 1.2318


Training Epochs:  20%|██        | 3/15 [03:19<13:20, 66.68s/it]

Validation Accuracy: 0.4398

Epoch 4/15


Training Epochs:  20%|██        | 3/15 [03:41<14:44, 73.68s/it]


KeyboardInterrupt: 