In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import BertTokenizer, BertModel
from PIL import Image
import json
from torchvision.transforms.functional import to_tensor

In [None]:
# Pretrained BERT tokenizer and encoder. The encoder is put in eval mode
# because it is only used to compute text embeddings.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased').eval()

# Image preprocessing: rescale every image to exactly 224x224 (the image is
# stretched or shrunk, nothing is padded), then convert it to a tensor.
image_transform = transforms.Compose(
    [transforms.Resize((224, 224)), transforms.ToTensor()]
)

In [2]:
# MMF dataset class
class MmfDataset(Dataset):
    """Meme dataset yielding ``(image, text_embedding, label)`` triples.

    Each element of ``data`` is a mapping that provides an ``"image"`` file
    name, the meme ``"text"`` and a ``"labels"`` list whose first element is
    the harmfulness class name.

    Args:
        data: Indexable collection of annotation entries.
        image_folder: Prefix that is prepended verbatim to ``entry["image"]``;
            it therefore has to end with a path separator.
        image_transform: Callable applied to the RGB PIL image.
        tokenizer: Callable that turns the text into model inputs (it is
            called with ``return_tensors='pt'``).
        text_model: Optional encoder used to embed the tokenized text. When
            omitted, the module-level ``bert_model`` is looked up at item
            access time, which is what this class has always done.
    """

    # Harmfulness class name -> integer class index used as the target.
    LABEL_TO_INDEX = {
        "not harmful": 0,
        "somewhat harmful": 1,
        "very harmful": 2,
    }

    def __init__(self, data, image_folder, image_transform, tokenizer, text_model=None):
        self.data = data
        self.image_folder = image_folder
        self.image_transform = image_transform
        self.tokenizer = tokenizer
        self.text_model = text_model

    def __len__(self):
        """Return the number of annotation entries."""
        return len(self.data)

    @classmethod
    def encode_label(cls, label):
        """Map a harmfulness class name to its integer index.

        Raises:
            ValueError: If ``label`` is not one of the known class names.
        """
        try:
            return cls.LABEL_TO_INDEX[label]
        except KeyError:
            known = ", ".join(repr(name) for name in cls.LABEL_TO_INDEX)
            raise ValueError(
                f"Unknown label {label!r}; expected one of: {known}"
            ) from None

    def __getitem__(self, idx):
        """Return ``(image, text_embedding, label)`` for the entry at ``idx``.

        Raises:
            ValueError: If the entry's first label is not a known class name.
        """
        entry = self.data[idx]

        # Validate the label first so a bad annotation fails fast, before any
        # image decoding or BERT forward pass is spent on it.
        encoded_label = self.encode_label(entry["labels"][0])

        # Load and preprocess image. convert() returns a new, fully loaded
        # image, so the underlying file can be closed right away.
        image_path = self.image_folder + entry["image"]
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        image = self.image_transform(image)

        # Tokenize and obtain text embeddings using BERT
        text = entry["text"]
        tokens = self.tokenizer(text, return_tensors='pt', truncation=True, padding=True)
        text_model = self.text_model if self.text_model is not None else bert_model
        with torch.no_grad():
            # Mean-pool the token states over the sequence dimension; for a
            # single text this presumably has shape (1, hidden_size).
            text_embedding = text_model(**tokens).last_hidden_state.mean(dim=1)

        # Convert encoded_label to a PyTorch tensor
        encoded_label_tensor = torch.tensor(encoded_label)

        return image, text_embedding, encoded_label_tensor

In [None]:
def collate_fn(batch):
    """Merge a list of ``(image, text_embedding, label)`` samples into a batch.

    Args:
        batch: Sequence of 3-tuples of tensors, one tuple per sample.

    Returns:
        A tuple ``(images, text_embeddings, labels)`` in which every element
        is the per-sample tensors stacked along a new leading dimension.
    """
    # Transpose the list of samples into one column per field.
    columns = tuple(zip(*batch))
    image_column, embedding_column, label_column = columns

    return (
        torch.stack(image_column),
        torch.stack(embedding_column),
        torch.stack(label_column),
    )

In [None]:
dataset_path = "C:\\Users\\aysen\\Documents\\GitHub\\harmful_meme_models\\data\\datasets\\memes\\defaults\\annotations\\train.jsonl"
# Must keep its trailing separator: MmfDataset prepends it verbatim to the file name.
image_folder = "C:\\Users\\aysen\\Documents\\GitHub\\harmful_meme_models\\data\\datasets\\memes\\defaults\\images\\"

# Fraction of entries held out for validation and the seed of the split.
VAL_FRACTION = 0.2
SPLIT_SEED = 42

# Read the JSON Lines file. JSON text is UTF-8, so decode it as such;
# errors="replace" keeps the read from crashing on a stray invalid byte.
with open(dataset_path, "r", encoding="utf-8", errors="replace") as file:
    dataset_str = file.read()

# Parse one JSON object per line, skipping blank lines
dataset = [json.loads(line) for line in dataset_str.split('\n') if line.strip()]
if not dataset:
    raise ValueError(f"No annotation entries found in {dataset_path}")

# Split the dataset into training and validation sets with a seeded shuffle,
# so the split is reproducible and needs nothing beyond torch.
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
shuffled_indices = torch.randperm(len(dataset), generator=split_generator).tolist()
num_val = int(len(dataset) * VAL_FRACTION)
val_dataset = [dataset[i] for i in shuffled_indices[:num_val]]
train_dataset = [dataset[i] for i in shuffled_indices[num_val:]]

# Create DataLoader instances for both training and validation sets
mmf_dataset_train = MmfDataset(data=train_dataset, image_folder=image_folder, image_transform=image_transform, tokenizer=tokenizer)
data_loader_train = DataLoader(mmf_dataset_train, batch_size=32, shuffle=True, collate_fn=collate_fn)

mmf_dataset_val = MmfDataset(data=val_dataset, image_folder=image_folder, image_transform=image_transform, tokenizer=tokenizer)
data_loader_val = DataLoader(mmf_dataset_val, batch_size=32, shuffle=False, collate_fn=collate_fn)

In [None]:
# Model definition
class MmfClassifier(nn.Module):
    """Two-layer classifier over concatenated image and text features.

    Both inputs are flattened per sample, concatenated, passed through a
    256-unit hidden layer with a LeakyReLU activation and projected to
    ``num_classes`` logits.

    Args:
        image_feature_size: Number of values per sample in the flattened image input.
        text_feature_size: Number of values per sample in the flattened text input.
        num_classes: Number of output logits.
    """

    def __init__(self, image_feature_size, text_feature_size, num_classes):
        super().__init__()
        self.shared_layer = nn.Linear(image_feature_size + text_feature_size, 256)
        # The attribute keeps its existing name; the activation is a LeakyReLU.
        self.relu = nn.LeakyReLU()
        self.output_layer = nn.Linear(256, num_classes)

    def forward(self, image_data, text_data):
        """Return logits of shape ``(batch_size, num_classes)``.

        Args:
            image_data: Tensor whose first dimension is the batch.
            text_data: Tensor whose first dimension is the batch.
        """
        # Collapse everything after the batch dimension. reshape (unlike view)
        # also accepts non-contiguous inputs such as permuted image tensors.
        flattened_image_data = image_data.reshape(image_data.size(0), -1)
        flattened_text_data = text_data.reshape(text_data.size(0), -1)

        # Combine visual and textual features
        combined_features = torch.cat((flattened_image_data, flattened_text_data), dim=1)
        shared_output = self.relu(self.shared_layer(combined_features))
        output = self.output_layer(shared_output)

        return output

In [None]:
# Hyperparameters
num_classes = 3  # Number of target classes
text_feature_size = 768  # Length of the text embedding fed to the model
image_feature_size = 3 * 224 * 224  # Values in one flattened 3x224x224 image

# Build the classifier from the sizes above
model = MmfClassifier(
    image_feature_size=image_feature_size,
    text_feature_size=text_feature_size,
    num_classes=num_classes,
)

# Cross-entropy objective optimised with plain SGD
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=model.parameters(), lr=0.01)

In [None]:
# Show one raw annotation entry as a quick sanity check of the parsed data
print("First entry in the dataset:")
first_entry_json = json.dumps(dataset[0], indent=2)
print(first_entry_json)

In [None]:
# Training and Validation loops
num_epochs = 5

for epoch in range(num_epochs):

    # Training loop
    model.train()

    train_loss_sum = 0.0
    train_samples = 0

    for images, text_embeddings, labels in data_loader_train:
        outputs = model(images, text_embeddings)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # criterion returns the batch mean, so weight it by the batch size to
        # get a per-sample average over the epoch (the last batch may be short).
        batch_size = labels.size(0)
        train_loss_sum += loss.item() * batch_size
        train_samples += batch_size

    # Validation loop
    model.eval()

    val_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    with torch.no_grad():
        for images_val, text_embeddings_val, labels_val in data_loader_val:
            outputs_val = model(images_val, text_embeddings_val)
            loss_val = criterion(outputs_val, labels_val)

            batch_size_val = labels_val.size(0)
            val_loss += loss_val.item() * batch_size_val

            # Predicted class = index of the largest logit
            _, predicted = torch.max(outputs_val, 1)
            correct_predictions += (predicted == labels_val).sum().item()
            total_samples += batch_size_val

    # Per-sample averages; NaN (instead of a crash) when a loader is empty
    avg_train_loss = train_loss_sum / train_samples if train_samples else float("nan")
    avg_val_loss = val_loss / total_samples if total_samples else float("nan")
    accuracy = correct_predictions / total_samples if total_samples else float("nan")

    print(f"Epoch {epoch+1}/{num_epochs}, Training Loss: {avg_train_loss}, Validation Loss: {avg_val_loss}, Validation Accuracy: {accuracy}")