In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import BertTokenizer, BertModel
from PIL import Image
import json
from torchvision.transforms.functional import to_tensor

In [2]:
# MMF dataset class
class MmfDataset(Dataset):
    def __init__(self, data, image_folder, image_transform, tokenizer):
        self.data = data
        self.image_folder = image_folder
        self.image_transform = image_transform
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        entry = self.data[idx]

        # Load and preprocess image
        image_path = self.image_folder + entry["image"]
        image = Image.open(image_path).convert("RGB")
        image = self.image_transform(image)

        # Tokenize and obtain text embeddings using BERT
        text = entry["text"]
        tokens = self.tokenizer(text, return_tensors='pt', truncation=True, padding=True)
        with torch.no_grad():
            text_embedding = bert_model(**tokens).last_hidden_state.mean(dim=1)

        # Label encoding
        label = entry["labels"][0]
        if label == "not harmful":
            encoded_label = 0
        elif label == "somewhat harmful":
            encoded_label = 1
        elif label == "very harmful":
            encoded_label = 2

        # Convert encoded_label to a PyTorch tensor
        encoded_label_tensor = torch.tensor(encoded_label)

        return image, text_embedding, encoded_label_tensor

In [3]:
def collate_fn(batch):
    # Unpack the batch into separate lists for images, text_embeddings, and labels
    images, text_embeddings, labels = zip(*batch)

    # Stack images and text_embeddings into tensors
    images = torch.stack(images)
    text_embeddings = torch.stack(text_embeddings)

    # Stack labels into a tensor
    labels = torch.stack(labels)

    return images, text_embeddings, labels

In [4]:
# Hyperparameters
image_feature_size = 3*224*224 # Image feature size
text_feature_size = 768  # Text feature size
num_classes = 3  # Number of classes

# Instantiate the model
class MmfClassifier(nn.Module):
    def __init__(self, image_feature_size, text_feature_size, num_classes):
        super(MmfClassifier, self).__init__()

        self.shared_layer = nn.Linear(image_feature_size + text_feature_size, 256)
        self.relu = nn.ReLU()
        self.output_layer = nn.Linear(256, num_classes)

    def forward(self, image_data, text_data):

        # Reshape to (batch_size, channels * height * width)
        flattened_image_data = image_data.view(image_data.size(0), -1)

        # Reshape to (batch_size, sequence_length * embedding_size)
        flattened_text_data = text_data.view(text_data.size(0), -1)
        
        combined_features = torch.cat((flattened_image_data, flattened_text_data), dim=1)
        shared_output = self.relu(self.shared_layer(combined_features))
        output = self.output_layer(shared_output)

        return output

In [5]:
model = MmfClassifier(image_feature_size, text_feature_size, num_classes)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Tokenizer and BERT model
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased')
bert_model.eval()

# Image transformation
image_transform = transforms.Compose([
    transforms.Resize((224, 224)), # pads or shrinks the image to 224*224
    transforms.ToTensor(),
])

In [7]:
dataset_path = "C:\\Users\\aysen\\Documents\\GitHub\\harmful_meme_models\\data\\datasets\\memes\\defaults\\annotations\\train.jsonl"
image_folder = "C:\\Users\\aysen\\Documents\\GitHub\\harmful_meme_models\\data\\datasets\\memes\\defaults\\images\\"

# Read the JSON string from the file
with open(dataset_path, "r", encoding='cp437') as file:
    dataset_str = file.read()
    file.close()

# Parse the JSON string
dataset = [json.loads(entry) for entry in dataset_str.strip().split('\n')]

# Create DataLoader
mmf_dataset = MmfDataset(data=dataset, image_folder=image_folder, image_transform=image_transform, tokenizer=tokenizer)
data_loader = DataLoader(mmf_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)

In [8]:
print("First entry in the dataset:")
print(json.dumps(dataset[0], indent=2))

First entry in the dataset:
{
  "id": "covid_memes_18",
  "image": "covid_memes_18.png",
  "labels": [
    "somewhat harmful",
    "individual"
  ],
  "text": "Bernie or Elizabeth?\nBe informed.Compare them on the issues that matter.\nIssue: Who makes the dankest memes?\n"
}


In [9]:
# Training loop
num_epochs = 5
for epoch in range(num_epochs):
    for batch in data_loader:

        images, text_embeddings, labels = batch

        # Forward pass
        outputs = model(images, text_embeddings)

        # Compute loss
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}")

Epoch 1/5, Loss: 0.6063514947891235
Epoch 2/5, Loss: 0.23800086975097656
Epoch 3/5, Loss: 1.2541553974151611
Epoch 4/5, Loss: 0.26917997002601624
Epoch 5/5, Loss: 0.5931230187416077
