In [1]:
import os
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
from torchvision import models, transforms
from PIL import Image
import cv2

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class MELDDataset(Dataset):
    def __init__(self, csv_file, video_folder, tokenizer, transform=None):
        self.data = pd.read_csv(csv_file)
        self.video_folder = video_folder
        self.tokenizer = tokenizer
        self.transform = transform
        self.num_classes = 7  # Adjust based on your specific task

        # Define the emotion mapping here
        self.emotion_map = {
            "neutral": 0,
            "joy": 1,
            "surprise": 2,
            "anger": 3,
            "sadness": 4,
            "disgust": 5,
            "fear": 6
        }
        
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]

        # Load text data
        text = row['Utterance']
        text_inputs = self.tokenizer(text, return_tensors="pt", padding='max_length', truncation=True, max_length=128)

        # Load video data (single frame example, extend to multiple frames as needed)
        video_file = os.path.join(self.video_folder, f"{row['Dialogue_ID']}_{row['Utterance_ID']}.mp4")
        cap = cv2.VideoCapture(video_file)
        ret, frame = cap.read()
        cap.release()

        if ret:
            image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            if self.transform:
                image = self.transform(image)
        else:
            # If video is not available or failed to load
            image = torch.zeros((3, 224, 224))  # Dummy image

        # Load label
        label_str  = row['Emotion']
        print(f'{label_str} here is label')
        label = self.emotion_map[label_str]
        label = torch.tensor(label, dtype=torch.long)

        return text_inputs, image, label

In [3]:
class MultiModalModel(nn.Module):
    def __init__(self, num_classes):
        super(MultiModalModel, self).__init__()
        # Text model (BERT)
        self.text_model = BertModel.from_pretrained('bert-base-uncased')
        self.text_fc = nn.Linear(768, 512)

        # Video model (ResNet)
        self.video_model = models.resnet18(pretrained=True)
        self.video_model.fc = nn.Linear(self.video_model.fc.in_features, 512)

        # Final classification layer
        self.fc = nn.Linear(512 + 512, num_classes)

    def forward(self, text_input, video_input):
        # Process text
        text_output = self.text_model(**text_input).last_hidden_state[:, 0, :]
        text_output = self.text_fc(text_output)

        # Process video
        video_output = self.video_model(video_input)

        # Concatenate text and video features
        combined_output = torch.cat((text_output, video_output), dim=1)

        # Classification
        output = self.fc(combined_output)
        return output

In [4]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

train_dataset = MELDDataset(csv_file='../MELD.Raw/train_sent_emo.csv', video_folder='train_splits', tokenizer=tokenizer, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

dev_dataset = MELDDataset(csv_file='../MELD.Raw/dev_sent_emo.csv', video_folder='dev_splits_complete', tokenizer=tokenizer, transform=transform)
dev_loader = DataLoader(dev_dataset, batch_size=16, shuffle=False)



In [5]:
emotion_map = {
    "neutral": 0,
    "joy": 1,
    "surprise": 2,
    "anger": 3,
    "sadness": 4,
    "disgust": 5,
    "fear": 6
}

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MultiModalModel(num_classes=7).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

for epoch in range(5):  # Adjust the number of epochs
    model.train()
    for text_inputs, images, labels in train_loader:
        text_inputs = {key: val.squeeze(1).to(device) for key, val in text_inputs.items()}
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(text_inputs, images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch+1}, Loss: {loss.item()}")

    # Validation loop can be added here



joy here is label
joy here is label
neutral here is label
surprise here is label
anger here is label
surprise here is label
sadness here is label
neutral here is label
joy here is label
neutral here is label
neutral here is label
anger here is label
joy here is label
neutral here is label
neutral here is label
neutral here is label
neutral here is label
neutral here is label
neutral here is label
neutral here is label
fear here is label
joy here is label
anger here is label
surprise here is label
neutral here is label
neutral here is label
neutral here is label
neutral here is label
disgust here is label
neutral here is label
neutral here is label
joy here is label
surprise here is label
neutral here is label
anger here is label
neutral here is label
neutral here is label
surprise here is label
sadness here is label
sadness here is label
disgust here is label
joy here is label
anger here is label
anger here is label
neutral here is label
anger here is label
joy here is label
sadness he

In [7]:
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for text_inputs, images, labels in dev_loader:
        text_inputs = {key: val.squeeze(1).to(device) for key, val in text_inputs.items()}
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(text_inputs, images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print(f'Validation Accuracy: {100 * correct / total}%')

sadness here is label
surprise here is label
neutral here is label
joy here is label
sadness here is label
neutral here is label
neutral here is label
joy here is label
neutral here is label
surprise here is label
neutral here is label
neutral here is label
surprise here is label
anger here is label
neutral here is label
joy here is label
neutral here is label
neutral here is label
neutral here is label
surprise here is label
surprise here is label
surprise here is label
joy here is label
neutral here is label
neutral here is label
joy here is label
surprise here is label
neutral here is label
neutral here is label
surprise here is label
neutral here is label
neutral here is label
neutral here is label
anger here is label
neutral here is label
neutral here is label
surprise here is label
neutral here is label
surprise here is label
surprise here is label
surprise here is label
neutral here is label
anger here is label
surprise here is label
anger here is label
anger here is label
anger