In [1]:
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import PIL
import torchvision
import numpy
import pandas
import torch 
import torch.optim as optim
import gc
from torch.optim.lr_scheduler import StepLR
import cv2
import os
import json
import numpy as np
from transformers import BertModel, BertTokenizer
import torch
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModel
from transformers import T5EncoderModel
from transformers import GPT2Tokenizer, GPT2Model
from transformers import ViTImageProcessor, ViTModel
from PIL import Image
import requests
from tqdm import tqdm
import re 
import string 

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# --- Data locations -----------------------------------------------------------
PATH_DATASETS = "../datasets"
PATH_JSON_TRAIN = os.path.join(PATH_DATASETS, "data/subtask2a/train.json")
PATH_JSON_VAL = os.path.join(PATH_DATASETS, "data/subtask2a/validation.json")
PATH_JSON_DEV = os.path.join(PATH_DATASETS, "dev_gold_labels/dev_subtask2a_en.json")
PATH_JSON_TEST = os.path.join(PATH_DATASETS, "test_data/english/en_subtask2a_test_unlabeled.json")

PATH_IMG_TRAIN = os.path.join(PATH_DATASETS, "train_images")
PATH_IMG_VAL = os.path.join(PATH_DATASETS, "validation_images")
PATH_IMG_DEV = os.path.join(PATH_DATASETS, "dev_images")
PATH_IMG_TEST = os.path.join(PATH_DATASETS, "test_images/subtask1_2a/english")

# --- Output locations (created up front so later saves cannot fail) -----------
PATH_SAVE_MODEL = "subtask2a_models"
PATH_SAVE_SUBMISSION = "subtask2a_submissions"

os.makedirs(PATH_SAVE_MODEL, exist_ok=True)
os.makedirs(PATH_SAVE_SUBMISSION, exist_ok=True)

# --- Model / training hyper-parameters -----------------------------------------
BERT_MODEL = 'limjiayi/bert-hateful-memes-expanded'
NUM_CLASSES = 22  # size of the multi-label output layer

BATCH_SIZE = 8

# Phase 1: encoders and classification head trained together.
EPOCHS_FULL = 3
LR_FULL = 1e-5

# Phase 2: encoders frozen, only the head is trained (0 epochs = phase skipped).
EPOCHS_FC = 0
LR_FC = 3e-6

# True: train on train + validation + dev; False: train + dev only.
TRAIN_ALL = True

# --- Reproducibility -------------------------------------------------------------
# DataLoader shuffling, the randomly initialised head and dropout all draw from
# these generators; seeding them makes reruns comparable.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators

In [3]:
# Load the training annotations with a context manager so the file handle is closed.
with open(PATH_JSON_TRAIN, "r", encoding='utf-8') as f:
    data = json.load(f)

# Peek at one raw record to see the annotation schema.
print(data[0])

{'id': '63292', 'text': "This is why we're free\\n\\nThis is why we're safe\\n", 'image': 'prop_meme_556.png', 'labels': ['Causal Oversimplification', 'Transfer', 'Flag-waving'], 'link': 'https://www.facebook.com/SilentmajorityDJT/photos/2119966118152814/'}


In [4]:
def preprocess(text):
    """Hook for normalising raw meme text before tokenization.

    Currently a pass-through: the argument (including ``None``) is handed back
    unchanged. Callers deal with a ``None`` result themselves.
    """
    cleaned = text
    return cleaned

In [5]:
# Identity pre-transform applied to the raw OpenCV array before the ViT image
# processor. Resizing, rescaling and normalisation are all done by `processor`
# (do_resize / do_rescale / do_normalize), so no torchvision steps are needed here.
# Add augmentations to this list if desired.
transform = torchvision.transforms.Compose(list())

In [6]:
# Image side: ViT processor that resizes, rescales and normalises each image with
# per-channel mean = std = 0.5.
VIT_CHECKPOINT = 'google/vit-base-patch16-224-in21k'
processor = ViTImageProcessor.from_pretrained(
    VIT_CHECKPOINT,
    do_resize=True,
    do_rescale=True,
    do_normalize=True,
    image_mean=[0.5, 0.5, 0.5],
    image_std=[0.5, 0.5, 0.5],
)

# Text side: tokenizer matching the BERT checkpoint used by the model.
tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL)
# NOTE(review): `text_model` is not referenced by any later cell here (Model loads
# its own text encoder); it only costs load time and memory.
text_model = AutoModel.from_pretrained(BERT_MODEL)

  return self.fget.__get__(instance, owner)()
Some weights of BertModel were not initialized from the model checkpoint at limjiayi/bert-hateful-memes-expanded and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
class MyDataset(Dataset):
    """In-memory multimodal meme dataset.

    All preprocessing happens eagerly in ``__init__``: every text is tokenized to a
    fixed length of 128 tokens and every image is read from disk and passed through
    the module-level ``transform`` and ViT ``processor``.

    Args:
        paths_json_img: iterable of ``(json_path, image_dir)`` pairs. Each JSON file
            holds a list of records with ``id``, ``text``, ``image`` and, for
            labelled splits, ``labels``.
        bin_classes: ordered list of label names; position ``i`` of every multi-hot
            label vector corresponds to ``bin_classes[i]``.

    Raises:
        FileNotFoundError: if an image referenced by a record cannot be read.
    """

    def __init__(self, paths_json_img, bin_classes):
        self.filenames = []
        self.texts = []
        self.images = []
        self.ids = []
        self.labels = []

        for path_json, path_img in paths_json_img:
            print(path_json)
            with open(path_json, "r", encoding='utf-8') as f:
                data = json.load(f)

            for x in tqdm(data):
                self.ids.append(x['id'])

                if 'labels' in x:
                    gold = set(x['labels'])
                    self.labels.append([1 if bin_class in gold else 0 for bin_class in bin_classes])
                else:
                    # Unlabelled split (e.g. test): keep an empty placeholder so indices line up.
                    self.labels.append([])

                text = preprocess(x['text'])
                if text is None:
                    text = ""
                self.texts.append(tokenizer(text, return_tensors='pt', padding='max_length', max_length=128, truncation=True))
                self.filenames.append(x['image'])

                self.images.append(self._load_image_features(os.path.join(path_img, x['image'])))

    @staticmethod
    def _load_image_features(path):
        """Read one image from ``path`` and return the ViT processor's features for it."""
        image = cv2.imread(path)
        if image is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(f"Could not read image: {path}")
        # OpenCV decodes to BGR channel order; the pretrained ViT expects RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = torch.tensor(transform(image)).unsqueeze(0)
        return processor(image)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``((image_features, text_features), labels)`` for item ``idx``.

        Tensor values of both feature dicts are moved to CUDA; non-tensor values are
        passed through unchanged. ``labels`` is the multi-hot vector (empty for
        unlabelled splits).
        """
        text_tensors = {}
        for key, value in self.texts[idx].items():
            text_tensors[key] = value.cuda() if isinstance(value, torch.Tensor) else value

        image_tensors = {}
        for key, value in self.images[idx].items():
            image_tensors[key] = value.cuda() if isinstance(value, torch.Tensor) else value

        return ((image_tensors, text_tensors), torch.tensor(self.labels[idx]))

In [8]:
class Model(nn.Module):
    """Late-fusion multi-label classifier over BERT text tokens and ViT image tokens.

    The full ``last_hidden_state`` sequences of both encoders are flattened,
    concatenated, squashed with ``tanh`` and mapped by one linear layer to
    ``num_classes`` logits; ``sigmoid`` turns them into independent per-label
    probabilities (suitable for ``BCELoss``).

    Args:
        text_seq_len: tokens per text sample; must equal the tokenizer's
            ``max_length`` (128 in ``MyDataset``).
        num_classes: number of output labels; defaults to ``NUM_CLASSES``.
    """

    def __init__(self, text_seq_len=128, num_classes=None):
        super().__init__()
        if num_classes is None:
            num_classes = NUM_CLASSES

        self.text_encoder = AutoModel.from_pretrained(BERT_MODEL)
        self.image_encoder = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')

        text_cfg = self.text_encoder.config
        image_cfg = self.image_encoder.config
        # ViT emits one token per image patch plus the [CLS] token
        # (224/16 -> 14*14 + 1 = 197 for this checkpoint).
        num_image_tokens = (image_cfg.image_size // image_cfg.patch_size) ** 2 + 1
        # Equals 128*768 + 197*768 = 249600 for the checkpoints used here.
        in_features = text_seq_len * text_cfg.hidden_size + num_image_tokens * image_cfg.hidden_size
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, images, text_input):
        """Score a batch.

        Args:
            images: dict whose ``'pixel_values'`` is either a batch tensor of shape
                (B, 3, H, W) or a one-element list wrapping it (what the default
                DataLoader collation produces from ``MyDataset``).
            text_input: dict with ``'input_ids'``, ``'attention_mask'`` and
                optionally ``'token_type_ids'``, each (B, 1, L) or (B, L).

        Returns:
            Tensor (B, num_classes) of probabilities in (0, 1).
        """
        device = self.fc.weight.device

        # Collapse any singleton dimension from per-item tokenization into (B, L)
        # and run BERT once over the whole batch.
        text_kwargs = {
            key: text_input[key].reshape(-1, text_input[key].shape[-1]).to(device)
            for key in ('input_ids', 'token_type_ids', 'attention_mask')
            if key in text_input
        }
        text_outputs = self.text_encoder(**text_kwargs).last_hidden_state

        pixel_values = images['pixel_values']
        if isinstance(pixel_values, (list, tuple)):
            pixel_values = pixel_values[0]
        image_outputs = self.image_encoder(pixel_values=pixel_values.to(device)).last_hidden_state

        # Flatten each token sequence to (B, tokens * hidden) and concatenate.
        combined = torch.cat((text_outputs.flatten(1), image_outputs.flatten(1)), dim=1)
        return torch.sigmoid(self.fc(torch.tanh(combined)))

In [9]:
with open(PATH_JSON_TRAIN, "r", encoding='utf-8') as f:
    data = json.load(f)

# Label vocabulary in first-appearance order over the training set. This order
# fixes which model output column corresponds to which label.
bin_classes = list(dict.fromkeys(label for x in data for label in x['labels']))

print(len(bin_classes), NUM_CLASSES)
print(bin_classes)

# The classifier head is sized by NUM_CLASSES; a mismatch would misalign labels.
assert len(bin_classes) == NUM_CLASSES, (
    f"found {len(bin_classes)} labels in training data but NUM_CLASSES = {NUM_CLASSES}"
)

22 22
['Causal Oversimplification', 'Transfer', 'Flag-waving', 'Black-and-white Fallacy/Dictatorship', 'Smears', 'Loaded Language', 'Glittering generalities (Virtue)', 'Thought-terminating cliché', 'Whataboutism', 'Slogans', 'Doubt', 'Name calling/Labeling', 'Repetition', 'Appeal to authority', 'Appeal to (Strong) Emotions', 'Reductio ad hitlerum', 'Appeal to fear/prejudice', 'Exaggeration/Minimisation', "Misrepresentation of Someone's Position (Straw Man)", 'Obfuscation, Intentional vagueness, Confusion', 'Bandwagon', 'Presenting Irrelevant Data (Red Herring)']


In [10]:
# Training sources: the official train split and the dev gold labels, plus the
# validation split when TRAIN_ALL is set.
# NOTE: with TRAIN_ALL the validation split below is also part of training, so
# validation loss and any thresholds tuned on it are optimistic.
train_sources = [(PATH_JSON_TRAIN, PATH_IMG_TRAIN)]
if TRAIN_ALL:
    train_sources.append((PATH_JSON_VAL, PATH_IMG_VAL))
train_sources.append((PATH_JSON_DEV, PATH_IMG_DEV))

train_data = MyDataset(train_sources, bin_classes)
valid_data = MyDataset([(PATH_JSON_VAL, PATH_IMG_VAL)], bin_classes)
test_data = MyDataset([(PATH_JSON_TEST, PATH_IMG_TEST)], bin_classes)

# Only training is shuffled; the test loader keeps file order with one item per
# batch so predictions can be matched to test_data.ids by position.
train_dataloader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_dataloader = DataLoader(valid_data, batch_size=BATCH_SIZE, shuffle=False)
test_dataloader = DataLoader(test_data, batch_size=1, shuffle=False)

../datasets\data/subtask2a/train.json


  0%|          | 0/7000 [00:00<?, ?it/s]

100%|██████████| 7000/7000 [01:46<00:00, 65.79it/s]


../datasets\data/subtask2a/validation.json


100%|██████████| 500/500 [00:07<00:00, 63.93it/s]


../datasets\dev_gold_labels/dev_subtask2a_en.json


100%|██████████| 1000/1000 [00:15<00:00, 64.43it/s]


../datasets\data/subtask2a/validation.json


100%|██████████| 500/500 [00:07<00:00, 63.58it/s]


../datasets\test_data/english/en_subtask2a_test_unlabeled.json


100%|██████████| 1500/1500 [00:22<00:00, 65.24it/s]


In [11]:
# m = nn.Sigmoid()
# loss = nn.BCELoss()
# input = torch.randn(3, requires_grad=True)
# target = torch.empty(3).random_(2)

# print(input)
# print(target)
# output = loss(m(input), target)

In [12]:
# Sanity check on the cached features: number of image arrays per item,
# dataset size, and the tokenized text shape.
print(len(train_data.images[0]['pixel_values']))
print(len(train_data))
print(train_data.texts[0]['input_ids'].shape)

DEVICE = 'cuda'  # MyDataset.__getitem__ already places the inputs on CUDA.


def train_one_epoch(model, dataloader, criterion, optimizer):
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss.

    ``criterion`` (BCELoss, default reduction='mean') already averages within a
    batch, so the epoch loss is averaged over batches, not samples.
    """
    model.train()
    running_loss = 0.0
    for (images_batch, texts_batch), labels_batch in tqdm(dataloader):
        optimizer.zero_grad()
        labels_batch = labels_batch.to(DEVICE, torch.float32)
        loss = criterion(model(images_batch, texts_batch), labels_batch)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / max(len(dataloader), 1)


@torch.no_grad()
def evaluate(model, dataloader, criterion):
    """Score ``model`` on ``dataloader`` without building autograd graphs.

    Returns ``(mean_loss, accuracy, scores, targets)``:
    ``mean_loss`` is the mean batch loss; ``accuracy`` is the element-wise label
    accuracy at threshold 0.5; ``scores`` / ``targets`` are CPU tensors of shape
    (num_samples, num_classes).
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    all_scores, all_targets = [], []
    for (images_batch, texts_batch), labels_batch in tqdm(dataloader):
        labels_batch = labels_batch.to(DEVICE, torch.float32)
        scores = model(images_batch, texts_batch)
        running_loss += criterion(scores, labels_batch).item()
        correct += ((scores > 0.5) == labels_batch).sum().item()
        total += labels_batch.numel()
        all_scores.append(scores.cpu())
        all_targets.append(labels_batch.cpu())

    mean_loss = running_loss / max(len(dataloader), 1)
    accuracy = correct / total if total else 0.0
    if not all_scores:
        return mean_loss, accuracy, torch.empty(0, 0), torch.empty(0, 0)
    return mean_loss, accuracy, torch.cat(all_scores), torch.cat(all_targets)


def best_thresholds(scores, targets):
    """Per label, pick the threshold t maximising F1 of the rule ``score > t``.

    Candidates are swept in ascending score order, starting from "everything
    positive" (t = 0.0). Each step moves one sample to the negative side, but F1
    is only evaluated once every sample sharing that score has moved. This way
    each reported (threshold, F1) pair is actually achievable with ``>``.
    Labels with no positive target keep t = 0.0 and F1 = 0.

    Returns:
        (thresholds, f1s): two lists with one entry per label column.
    """
    thresholds, f1s = [], []
    for col in range(scores.shape[1]):
        pairs = sorted(zip(scores[:, col].tolist(), targets[:, col].tolist()))
        tp = sum(gt for _, gt in pairs)
        fp = len(pairs) - tp
        fn = 0
        best_t = 0.0
        best_f1 = 2 * tp / (2 * tp + fp) if tp > 0 else 0.0
        for k, (score, gt) in enumerate(pairs):
            if gt == 1:
                tp -= 1
                fn += 1
            else:
                fp -= 1
            end_of_tie_group = k + 1 == len(pairs) or pairs[k + 1][0] != score
            if end_of_tie_group and tp > 0:
                f1 = 2 * tp / (2 * tp + fp + fn)
                if f1 > best_f1:
                    best_t, best_f1 = score, f1
        thresholds.append(best_t)
        f1s.append(best_f1)
    return thresholds, f1s


def run_phase(model, optimizer, epochs, thresholds):
    """Train for ``epochs`` epochs, validating and re-tuning thresholds after each.

    Uses the notebook-level ``train_dataloader``, ``valid_dataloader``,
    ``criterion`` and ``bin_classes``. Returns the thresholds tuned after the last
    epoch, or ``thresholds`` unchanged when ``epochs`` is 0.
    """
    best_loss = float('inf')
    for epoch in range(epochs):
        train_loss = train_one_epoch(model, train_dataloader, criterion, optimizer)
        validation_loss, accuracy, val_scores, val_targets = evaluate(model, valid_dataloader, criterion)

        thresholds, f1s = best_thresholds(val_scores, val_targets)
        print("BEST THRESHOLDS")
        for name, thresh, f1 in zip(bin_classes, thresholds, f1s):
            print(f"{name} : best_thresh={thresh} , best_f1={f1}")
        print()

        print(f'Epoch: {epoch} Train Loss: {train_loss} Validation Loss: {validation_loss} '
              f'Validation Accuracy thresh=0.5: {accuracy * 100:.2f}%')

        if validation_loss < best_loss:
            print(f'Checkpoint reached! Validation loss modified from {best_loss} to {validation_loss}')
            best_loss = validation_loss
        torch.cuda.empty_cache()
    return thresholds


model = Model()
model.cuda()

# Model already applies sigmoid, so plain BCE on probabilities is used.
criterion = torch.nn.BCELoss()

# NOTE: with TRAIN_ALL the validation split is also in the training set, so the
# validation metrics and thresholds below are optimistic.

# Phase 1: fine-tune both encoders together with the head.
optimizer = torch.optim.Adam(model.parameters(), lr=LR_FULL)
best_thresh_all = run_phase(model, optimizer, EPOCHS_FULL, [])

# Phase 2: freeze both encoders and train only the head.
for encoder in (model.text_encoder, model.image_encoder):
    for param in encoder.parameters():
        param.requires_grad = False
optimizer = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=LR_FC)
best_thresh_all = run_phase(model, optimizer, EPOCHS_FC, best_thresh_all)

torch.save({'checkpoint': model.state_dict()}, os.path.join(PATH_SAVE_MODEL, 'checkpoint.pt'))

# To reload later:
#   model.load_state_dict(torch.load(os.path.join(PATH_SAVE_MODEL, 'checkpoint.pt'))['checkpoint'])


1
8500
torch.Size([1, 128])


Some weights of BertModel were not initialized from the model checkpoint at limjiayi/bert-hateful-memes-expanded and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
1063it [07:25,  2.39it/s]
63it [00:07,  8.72it/s]


BEST THRESHOLDS
Causal Oversimplification : best_thresh=0.13298802077770233 , best_f1=0.43636363636363634
Transfer : best_thresh=0.3453764319419861 , best_f1=0.6885245901639344
Flag-waving : best_thresh=0.2096594274044037 , best_f1=0.7394957983193278
Black-and-white Fallacy/Dictatorship : best_thresh=0.5069349408149719 , best_f1=0.6382978723404256
Smears : best_thresh=0.5679747462272644 , best_f1=0.8200371057513914
Loaded Language : best_thresh=0.3873363435268402 , best_f1=0.7035830618892508
Glittering generalities (Virtue) : best_thresh=0.11461390554904938 , best_f1=0.5693430656934306
Thought-terminating cliché : best_thresh=0.31492480635643005 , best_f1=0.5396825396825397
Whataboutism : best_thresh=0.09735649079084396 , best_f1=0.5555555555555556
Slogans : best_thresh=0.13925980031490326 , best_f1=0.5714285714285714
Doubt : best_thresh=0.2307792752981186 , best_f1=0.6037735849056604
Name calling/Labeling : best_thresh=0.3842535614967346 , best_f1=0.6842105263157895
Repetition : best_

1063it [07:17,  2.43it/s]
63it [00:07,  8.93it/s]


BEST THRESHOLDS
Causal Oversimplification : best_thresh=0.16787345707416534 , best_f1=0.8
Transfer : best_thresh=0.3084378242492676 , best_f1=0.865546218487395
Flag-waving : best_thresh=0.3643975555896759 , best_f1=0.8623853211009175
Black-and-white Fallacy/Dictatorship : best_thresh=0.37753722071647644 , best_f1=0.8785046728971962
Smears : best_thresh=0.4502951502799988 , best_f1=0.90625
Loaded Language : best_thresh=0.6506150364875793 , best_f1=0.8148148148148148
Glittering generalities (Virtue) : best_thresh=0.45759671926498413 , best_f1=0.8163265306122449
Thought-terminating cliché : best_thresh=0.22561617195606232 , best_f1=0.7945205479452054
Whataboutism : best_thresh=0.2015017420053482 , best_f1=0.88
Slogans : best_thresh=0.26299503445625305 , best_f1=0.8076923076923077
Doubt : best_thresh=0.23064018785953522 , best_f1=0.7407407407407407
Name calling/Labeling : best_thresh=0.3651345372200012 , best_f1=0.9024390243902439
Repetition : best_thresh=0.4935813248157501 , best_f1=0.863

1063it [07:15,  2.44it/s]
63it [00:07,  8.95it/s]


BEST THRESHOLDS
Causal Oversimplification : best_thresh=0.41665977239608765 , best_f1=0.9047619047619048
Transfer : best_thresh=0.4611648619174957 , best_f1=0.9626556016597511
Flag-waving : best_thresh=0.4242860972881317 , best_f1=0.957983193277311
Black-and-white Fallacy/Dictatorship : best_thresh=0.4716000258922577 , best_f1=0.9464285714285714
Smears : best_thresh=0.3237995207309723 , best_f1=0.9659090909090909
Loaded Language : best_thresh=0.5875219702720642 , best_f1=0.9191176470588235
Glittering generalities (Virtue) : best_thresh=0.35387077927589417 , best_f1=0.9696969696969697
Thought-terminating cliché : best_thresh=0.4726349115371704 , best_f1=0.9459459459459459
Whataboutism : best_thresh=0.28403353691101074 , best_f1=0.9795918367346939
Slogans : best_thresh=0.30799734592437744 , best_f1=0.9357798165137615
Doubt : best_thresh=0.4400121569633484 , best_f1=0.9333333333333333
Name calling/Labeling : best_thresh=0.19074812531471252 , best_f1=0.9590163934426229
Repetition : best_th

In [31]:
print(best_thresh_all)

# Thresholds pinned from an earlier run (values copied from a previous print-out).
# NOTE(review): with TRAIN_ALL the validation split is part of training, so the
# thresholds just tuned are fitted on seen data; presumably that is why a pinned
# list is used. Set USE_PINNED_THRESHOLDS = False to use the freshly tuned ones.
USE_PINNED_THRESHOLDS = True
PINNED_THRESHOLDS = [0.07052270323038101, 0.2133808732032776, 0.4325084984302521, 0.25856152176856995, 0.07286142557859421, 0.35122835636138916, 0.20894578099250793, 0.36188608407974243, 0.17205330729484558, 0.2361125946044922, 0.36091744899749756, 0.28508374094963074, 0.2983148694038391, 0.462566077709198, 0.07985731214284897, 0.5132802724838257, 0.07505106925964355, 0.18899203836917877, 0.12206489592790604, 0.033173780888319016, 0.10967455059289932, 0.020485831424593925]

# Lower bound for every threshold: a very low tuned threshold would make a label
# fire on weak scores. (The saved output below was produced with a 0.25 floor.)
MIN_THRESHOLD = 0.2

source_thresholds = PINNED_THRESHOLDS if USE_PINNED_THRESHOLDS else best_thresh_all
assert len(source_thresholds) == NUM_CLASSES, (
    f"expected {NUM_CLASSES} thresholds, got {len(source_thresholds)}"
)
best_thresh_all = [max(MIN_THRESHOLD, t) for t in source_thresholds]

[0.25, 0.25, 0.4325084984302521, 0.25856152176856995, 0.25, 0.35122835636138916, 0.25, 0.36188608407974243, 0.25, 0.25, 0.36091744899749756, 0.28508374094963074, 0.2983148694038391, 0.462566077709198, 0.25, 0.5132802724838257, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25]


In [32]:
# Predict a label set for every test meme. test_dataloader is not shuffled, so the
# i-th scored sample corresponds to test_data.ids[i] regardless of batch size.
assert len(best_thresh_all) == len(bin_classes), "one threshold per label is required"

model.eval()
predictions = {}
sample_index = 0

with torch.no_grad():
    for (images_batch, texts_batch), _ in tqdm(test_dataloader):
        for scores in model(images_batch, texts_batch).cpu().tolist():
            curr_id = test_data.ids[sample_index]
            sample_index += 1
            # A label is predicted when its probability exceeds that label's threshold;
            # repeated ids accumulate into the same list.
            predictions.setdefault(curr_id, []).extend(
                bin_class
                for bin_class, score, thresh in zip(bin_classes, scores, best_thresh_all)
                if score > thresh
            )

1500it [00:24, 60.78it/s]


In [33]:
# One {"id", "labels"} record per meme, in prediction order, dumped as a JSON list.
output_json = [{"id": meme_id, "labels": labels} for meme_id, labels in predictions.items()]

submission_path = os.path.join(PATH_SAVE_SUBMISSION, "submission.txt")
with open(submission_path, "w") as fout:
    json.dump(output_json, fout)