In [1]:
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import PIL
import torchvision
import numpy
import pandas
import torch 
import torch.optim as optim
import gc
from torch.optim.lr_scheduler import StepLR
import cv2
import os
import json
import numpy as np
from transformers import BertModel, BertTokenizer
import torch
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModel
from transformers import T5EncoderModel
from transformers import GPT2Tokenizer, GPT2Model
from transformers import ViTImageProcessor, ViTModel
from PIL import Image
import requests
from tqdm import tqdm
import re 
import string 

In [None]:
# --- Input files (all located under PATH_DATASETS) ---
PATH_DATASETS = "../datasets"
PATH_JSON_TRAIN = os.path.join(PATH_DATASETS, "data/subtask1/train.json")
PATH_JSON_VAL = os.path.join(PATH_DATASETS, "data/subtask1/validation.json")
PATH_JSON_DEV = os.path.join(PATH_DATASETS, "dev_gold_labels/dev_subtask1_en.json")
PATH_JSON_TEST = os.path.join(PATH_DATASETS, "test_data/english/en_subtask1_test_unlabeled.json")

# --- Output folders, created up front so later saves find them ---
PATH_SAVE_MODEL = "subtask1_models"
PATH_SAVE_SUBMISSION = "subtask1_submissions"
for _out_dir in (PATH_SAVE_MODEL, PATH_SAVE_SUBMISSION):
    os.makedirs(_out_dir, exist_ok=True)

# --- Model ---
# Hugging Face checkpoint id used for both the tokenizer and the encoder
BERT_MODEL = 'limjiayi/bert-hateful-memes-expanded'
# Width of the classifier output (one unit per label)
NUM_CLASSES = 20

# --- Optimisation ---
BATCH_SIZE = 8

# Stage 1: every parameter is trainable
EPOCHS_FULL = 3
LR_FULL = 1e-5

# Stage 2: the text encoder is frozen, only the linear head keeps training
EPOCHS_FC = 3
LR_FC = 3e-6

# When True the validation file is added to the training set as well
TRAIN_ALL = True

In [None]:
# Peek at the first training record to see the schema of the raw JSON.
# The context manager closes the file as soon as parsing finishes; a bare
# json.load(open(...)) leaves the handle open until garbage collection.
with open(PATH_JSON_TRAIN, "r", encoding='utf-8') as fin:
    data = json.load(fin)

print(data[0])

{'id': '65635', 'text': 'THIS IS WHY YOU NEED\\n\\nA SHARPIE WITH YOU AT ALL TIMES', 'labels': ['Black-and-white Fallacy/Dictatorship'], 'link': 'https://www.facebook.com/photo/?fbid=4023552137722493&set=g.633131750534436'}


In [None]:
def preprocess(text):
    """Return ``text`` unchanged.

    Identity hook: the body applies no transformation. MyDataset calls it on
    each record's 'text' field before tokenisation.
    """
    return text

In [None]:
# Load the tokenizer for the BERT_MODEL checkpoint; MyDataset reads this
# module-level `tokenizer` while it builds its examples.
tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL)
# NOTE(review): no later cell visible here references `text_model` (Model
# loads its own encoder) - confirm this standalone copy is still needed.
text_model = AutoModel.from_pretrained(BERT_MODEL)

Some weights of BertModel were not initialized from the model checkpoint at limjiayi/bert-hateful-memes-expanded and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
class MyDataset(Dataset):
    """Tokenised text examples with multi-hot labels.

    Every record of every JSON file in ``paths_json`` is tokenised eagerly in
    ``__init__`` with the module-level ``tokenizer`` (padded / truncated to
    128 tokens). Records that carry a 'labels' key get a 0/1 vector aligned
    with ``bin_classes``; records without one get an empty label list.
    """

    def __init__(self, paths_json, bin_classes):
        """Load and tokenise all records.

        Args:
            paths_json: iterable of paths to JSON files, each holding a list of
                dicts with 'id' and 'text' keys and optionally 'labels'.
            bin_classes: ordered label names; position ``i`` of each label
                vector is 1 when ``bin_classes[i]`` is among the record's
                labels and 0 otherwise.
        """
        self.texts = []
        self.ids = []
        self.labels = []

        for path_json in paths_json:
            # Context manager so the file handle is closed right after parsing
            # instead of being left for the garbage collector.
            with open(path_json, "r", encoding='utf-8') as fin:
                data = json.load(fin)

            for x in tqdm(data):
                self.ids.append(x['id'])

                if 'labels' in x:
                    self.labels.append(
                        [1 if bin_class in x['labels'] else 0 for bin_class in bin_classes]
                    )
                else:
                    # Unlabeled record: keep the lists aligned with a placeholder
                    self.labels.append([])

                text = preprocess(x['text'])
                if text is None:
                    text = ""
                self.texts.append(
                    tokenizer(
                        text,
                        return_tensors='pt',
                        padding='max_length',
                        max_length=128,
                        truncation=True,
                    )
                )

    def __len__(self):
        """Number of tokenised examples."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return ``(tokenizer_fields, labels_tensor)`` for example ``idx``.

        Tensor-valued tokenizer fields are moved to the GPU; any other value is
        passed through untouched. The labels tensor stays on the CPU.
        """
        # NOTE(review): moving tensors to CUDA inside the dataset presumably
        # relies on the DataLoader loading in the main process - confirm
        # before enabling worker processes.
        text_tensors = {
            key: value.cuda() if isinstance(value, torch.Tensor) else value
            for key, value in self.texts[idx].items()
        }
        return (text_tensors, torch.tensor(self.labels[idx]))

In [None]:
#torchvision.models.efficientnet_b0(pretrained=True)
class Model(nn.Module):
    """Text-only multi-label classifier: encoder token states -> linear head.

    The encoder's ``last_hidden_state`` for every token position is passed
    through tanh, flattened per example and fed to one linear layer; a sigmoid
    turns the NUM_CLASSES logits into independent per-class probabilities.
    """

    def __init__(self, max_length=128):
        """Build the encoder and the classification head.

        Args:
            max_length: padded token length of the inputs. The default of 128
                matches the ``max_length`` MyDataset passes to the tokenizer.
        """
        super(Model, self).__init__()
        self.text_encoder = AutoModel.from_pretrained(BERT_MODEL)

        # The head sees every token's hidden state, so its input width is
        # max_length * hidden_size (the previously hard-coded 98304 is
        # 128 * 768).
        flat_features = max_length * self.text_encoder.config.hidden_size
        self.fc = nn.Linear(flat_features, NUM_CLASSES)

    def forward(self, text_input):
        """Compute per-class probabilities for a batch.

        Args:
            text_input: dict with 'input_ids', 'token_type_ids' and
                'attention_mask' tensors. Both ``(batch, 1, seq_len)`` (what
                the DataLoader produces from examples tokenised with
                ``return_tensors='pt'``) and ``(batch, seq_len)`` are accepted.

        Returns:
            Tensor of shape ``(batch, NUM_CLASSES)`` with values in (0, 1).
        """
        batch_size = text_input['input_ids'].shape[0]

        # Collapse the singleton dimension so the encoder runs once on the
        # whole batch instead of once per example in a Python loop.
        encoder_inputs = {
            key: text_input[key].reshape(batch_size, -1)
            for key in ('input_ids', 'token_type_ids', 'attention_mask')
        }
        hidden_states = self.text_encoder(**encoder_inputs).last_hidden_state

        # (batch, seq_len, hidden) -> (batch, seq_len * hidden)
        flat = hidden_states.reshape(batch_size, -1)

        return torch.sigmoid(self.fc(torch.tanh(flat)))

In [None]:
# Derive the ordered label vocabulary from the training file: every label is
# kept once, in order of first appearance. That order defines the meaning of
# each position in the label vectors.
with open(PATH_JSON_TRAIN, "r", encoding='utf-8') as fin:
    data = json.load(fin)

# dict.fromkeys de-duplicates while preserving insertion order
bin_classes = list(dict.fromkeys(label for x in data for label in x['labels']))

# The classifier head has NUM_CLASSES outputs, so the vocabulary must match it
assert len(bin_classes) == NUM_CLASSES, (
    f"found {len(bin_classes)} labels in the training data, expected {NUM_CLASSES}"
)

print(len(bin_classes))
print(bin_classes)

20
['Black-and-white Fallacy/Dictatorship', 'Loaded Language', 'Glittering generalities (Virtue)', 'Thought-terminating cliché', 'Whataboutism', 'Slogans', 'Causal Oversimplification', 'Smears', 'Name calling/Labeling', 'Appeal to authority', 'Exaggeration/Minimisation', 'Repetition', 'Flag-waving', 'Appeal to fear/prejudice', 'Reductio ad hitlerum', 'Doubt', "Misrepresentation of Someone's Position (Straw Man)", 'Obfuscation, Intentional vagueness, Confusion', 'Bandwagon', 'Presenting Irrelevant Data (Red Herring)']


In [None]:
# Training always uses the train and dev files; TRAIN_ALL additionally folds
# in the validation file (which is then also the data validation is run on).
train_paths = [PATH_JSON_TRAIN, PATH_JSON_DEV]
if TRAIN_ALL:
    train_paths.append(PATH_JSON_VAL)

train_data = MyDataset(train_paths, bin_classes)
valid_data = MyDataset([PATH_JSON_VAL], bin_classes)
test_data = MyDataset([PATH_JSON_TEST], bin_classes)

# Only the training loader shuffles; the test loader yields one example per
# batch in file order.
train_dataloader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_dataloader = DataLoader(valid_data, batch_size=BATCH_SIZE, shuffle=False)
test_dataloader = DataLoader(test_data, batch_size=1, shuffle=False)

  0%|          | 0/7000 [00:00<?, ?it/s]

100%|██████████| 7000/7000 [00:00<00:00, 7935.52it/s]
100%|██████████| 1000/1000 [00:00<00:00, 8403.47it/s]
100%|██████████| 500/500 [00:00<00:00, 8770.14it/s]
100%|██████████| 500/500 [00:00<00:00, 8773.70it/s]
100%|██████████| 1500/1500 [00:00<00:00, 9147.08it/s]


In [None]:
# m = nn.Sigmoid()
# loss = nn.BCELoss()
# input = torch.randn(3, requires_grad=True)
# target = torch.empty(3).random_(2)

# print(input)
# print(target)
# output = loss(m(input), target)

In [None]:
best_thresh_all = []

print(len(train_data))
print(train_data.texts[0]['input_ids'].shape)

# Stage 1: train the whole network (text encoder and linear head) with LR_FULL.
model = Model()
model.cuda()

criterion = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LR_FULL)

best_loss = 1e9

for epoch in range(EPOCHS_FULL):

    # ---- Training pass ----
    train_loss = 0.0
    model.train()
    for texts_batch, labels_batch in tqdm(train_dataloader):
        optimizer.zero_grad()

        # BCELoss needs float targets on the same device as the predictions
        labels_batch = labels_batch.to(torch.float32).to('cuda')

        labels_predictions = model(texts_batch)
        loss = criterion(labels_predictions, labels_batch)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    # ---- Validation pass ----
    validation_loss = 0.0
    model.eval()
    correct = 0
    total = 0

    # Per-class lists of predicted probabilities and ground-truth values
    all_val_pred = [[] for _ in range(NUM_CLASSES)]
    all_val_gt = [[] for _ in range(NUM_CLASSES)]

    # No gradients are needed for evaluation; without no_grad every forward
    # pass would build an autograd graph for nothing.
    with torch.no_grad():
        for texts_batch, labels_batch in tqdm(valid_dataloader):
            labels_batch = labels_batch.to(torch.float32).to('cuda')
            labels_predictions = model(texts_batch)

            validation_loss += criterion(labels_predictions, labels_batch).item()

            predicted = (labels_predictions > 0.5)

            total += labels_batch.size(0)
            correct += (predicted == labels_batch).sum().item()

            cpu_labels_predictions = labels_predictions.to('cpu').tolist()
            cpu_labels_batch = labels_batch.to('cpu').tolist()

            for sample_pred, sample_gt in zip(cpu_labels_predictions, cpu_labels_batch):
                for i in range(NUM_CLASSES):
                    all_val_pred[i].append(sample_pred[i])
                    all_val_gt[i].append(sample_gt[i])

    # ---- Per-class threshold search ----
    # Sweep the threshold upwards through the sorted predictions. Before the
    # sweep everything counts as predicted-positive; each visited sample moves
    # to predicted-negative and F1 is re-evaluated with that sample's score as
    # the threshold (predictions strictly above it stay positive).
    best_thresh_all = []
    print("BEST THRESHOLDS")
    for i in range(NUM_CLASSES):
        zipped_pred_gt = sorted(zip(all_val_pred[i], all_val_gt[i]))

        best_thresh = 0
        best_f1 = 0
        tp = sum(all_val_gt[i])
        fp = len(all_val_gt[i]) - tp
        fn = 0
        for pred, gt in zipped_pred_gt:
            if gt == 1:
                tp -= 1
                fn += 1
            else:
                fp -= 1

            if tp > 0:
                curr_f1 = 2*tp / (2*tp + fp + fn)
                if curr_f1 > best_f1:
                    best_f1 = curr_f1
                    best_thresh = pred
        best_thresh_all.append(best_thresh)

        print(f"{bin_classes[i]} : best_thresh={best_thresh} , best_f1={best_f1}")
    print()

    # criterion returns the mean loss of a batch, so the epoch average is the
    # accumulated sum divided by the number of batches of the matching loader.
    train_loss /= len(train_dataloader)
    validation_loss /= len(valid_dataloader)
    accuracy = (correct / total) / len(bin_classes)
    print(f'Epoch: {epoch} Train Loss: {train_loss} Validation Loss: {validation_loss} Validation Accuracy: {accuracy * 100:.2f}%')

    # Track the best validation loss; nothing is written to disk in this stage
    # (the weights are saved after the second training stage).
    if validation_loss < best_loss:
        print(f'Checkpoint reached! Validation loss modified from {best_loss} to {validation_loss}')
        best_loss = validation_loss
    torch.cuda.empty_cache()
                    


8500
torch.Size([1, 128])


Some weights of BertModel were not initialized from the model checkpoint at limjiayi/bert-hateful-memes-expanded and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
17it [00:03,  4.89it/s]


KeyboardInterrupt: 

In [None]:
# Stage 2: freeze the text encoder and keep training only the linear head.
for param in model.text_encoder.parameters():
    param.requires_grad = False

# Hand the optimizer only the parameters that are still trainable
optimizer = torch.optim.Adam(
    [param for param in model.parameters() if param.requires_grad], lr=LR_FC
)
best_loss = 1e9


for epoch in range(EPOCHS_FC):

    # ---- Training pass (`criterion` is the BCELoss created for stage 1) ----
    train_loss = 0.0
    model.train()
    for texts_batch, labels_batch in tqdm(train_dataloader):
        optimizer.zero_grad()

        # BCELoss needs float targets on the same device as the predictions
        labels_batch = labels_batch.to(torch.float32).to('cuda')

        labels_predictions = model(texts_batch)
        loss = criterion(labels_predictions, labels_batch)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    # ---- Validation pass ----
    validation_loss = 0.0
    model.eval()
    correct = 0
    total = 0

    # Per-class lists of predicted probabilities and ground-truth values
    all_val_pred = [[] for _ in range(NUM_CLASSES)]
    all_val_gt = [[] for _ in range(NUM_CLASSES)]

    # No gradients are needed for evaluation; without no_grad every forward
    # pass would build an autograd graph for nothing.
    with torch.no_grad():
        for texts_batch, labels_batch in tqdm(valid_dataloader):
            labels_batch = labels_batch.to(torch.float32).to('cuda')
            labels_predictions = model(texts_batch)

            validation_loss += criterion(labels_predictions, labels_batch).item()

            predicted = (labels_predictions > 0.5)

            total += labels_batch.size(0)
            correct += (predicted == labels_batch).sum().item()

            cpu_labels_predictions = labels_predictions.to('cpu').tolist()
            cpu_labels_batch = labels_batch.to('cpu').tolist()

            for sample_pred, sample_gt in zip(cpu_labels_predictions, cpu_labels_batch):
                for i in range(NUM_CLASSES):
                    all_val_pred[i].append(sample_pred[i])
                    all_val_gt[i].append(sample_gt[i])

    # ---- Per-class threshold search ----
    # Sweep the threshold upwards through the sorted predictions. Before the
    # sweep everything counts as predicted-positive; each visited sample moves
    # to predicted-negative and F1 is re-evaluated with that sample's score as
    # the threshold (predictions strictly above it stay positive).
    best_thresh_all = []
    print("BEST THRESHOLDS")
    for i in range(NUM_CLASSES):
        zipped_pred_gt = sorted(zip(all_val_pred[i], all_val_gt[i]))

        best_thresh = 0
        best_f1 = 0
        tp = sum(all_val_gt[i])
        fp = len(all_val_gt[i]) - tp
        fn = 0
        for pred, gt in zipped_pred_gt:
            if gt == 1:
                tp -= 1
                fn += 1
            else:
                fp -= 1

            if tp > 0:
                curr_f1 = 2*tp / (2*tp + fp + fn)
                if curr_f1 > best_f1:
                    best_f1 = curr_f1
                    best_thresh = pred
        best_thresh_all.append(best_thresh)

        print(f"{bin_classes[i]} : best_thresh={best_thresh} , best_f1={best_f1}")
    print()

    # criterion returns the mean loss of a batch, so the epoch average is the
    # accumulated sum divided by the number of batches of the matching loader.
    train_loss /= len(train_dataloader)
    validation_loss /= len(valid_dataloader)
    accuracy = (correct / total) / len(bin_classes)
    print(f'Epoch: {epoch} Train Loss: {train_loss} Validation Loss: {validation_loss} Validation Accuracy: {accuracy * 100:.2f}%')

    # Track the best validation loss seen in this stage
    if validation_loss < best_loss:
        print(f'Checkpoint reached! Validation loss modified from {best_loss} to {validation_loss}')
        best_loss = validation_loss
    torch.cuda.empty_cache()


# Persist the weights as they stand after the final epoch
checkpoint = {'checkpoint': model.state_dict()}
torch.save(checkpoint, os.path.join(PATH_SAVE_MODEL, 'checkpoint.pt'))

#import torch
# model.train()
# checkpoint = torch.load(os.path.join(PATH_SAVE_MODEL, f'fc_checkpoint_{4}.pt'))

# # Apply the state dictionary to the model
# model.load_state_dict(checkpoint['checkpoint'])

1000it [00:43, 23.17it/s]
63it [00:02, 24.25it/s]


BEST THRESHOLDS
Black-and-white Fallacy/Dictatorship : best_thresh=0.1474543809890747 , best_f1=0.45569620253164556
Loaded Language : best_thresh=0.28121861815452576 , best_f1=0.56875
Glittering generalities (Virtue) : best_thresh=0.4378505051136017 , best_f1=0.4918032786885246
Thought-terminating cliché : best_thresh=0.3190999925136566 , best_f1=0.39473684210526316
Whataboutism : best_thresh=0.1247732862830162 , best_f1=0.27450980392156865
Slogans : best_thresh=0.31313225626945496 , best_f1=0.47191011235955055
Causal Oversimplification : best_thresh=0.18258559703826904 , best_f1=0.22857142857142856
Smears : best_thresh=0.16772451996803284 , best_f1=0.5794871794871795
Name calling/Labeling : best_thresh=0.34608033299446106 , best_f1=0.64
Appeal to authority : best_thresh=0.1932286024093628 , best_f1=0.7692307692307693
Exaggeration/Minimisation : best_thresh=0.2341528683900833 , best_f1=0.5777777777777777
Repetition : best_thresh=0.1865023374557495 , best_f1=0.625
Flag-waving : best_thr

1000it [00:43, 23.14it/s]
63it [00:02, 24.24it/s]


BEST THRESHOLDS
Black-and-white Fallacy/Dictatorship : best_thresh=0.1661970317363739 , best_f1=0.4594594594594595
Loaded Language : best_thresh=0.23600409924983978 , best_f1=0.573208722741433
Glittering generalities (Virtue) : best_thresh=0.4518183767795563 , best_f1=0.5079365079365079
Thought-terminating cliché : best_thresh=0.3465636968612671 , best_f1=0.38961038961038963
Whataboutism : best_thresh=0.13569805026054382 , best_f1=0.2641509433962264
Slogans : best_thresh=0.06287406384944916 , best_f1=0.4782608695652174
Causal Oversimplification : best_thresh=0.2181524634361267 , best_f1=0.25
Smears : best_thresh=0.12356071174144745 , best_f1=0.5728643216080402
Name calling/Labeling : best_thresh=0.3214830756187439 , best_f1=0.6462882096069869
Appeal to authority : best_thresh=0.2365138828754425 , best_f1=0.7769784172661871
Exaggeration/Minimisation : best_thresh=0.2879684269428253 , best_f1=0.5581395348837209
Repetition : best_thresh=0.20679157972335815 , best_f1=0.6086956521739131
Fla

1000it [00:43, 23.15it/s]
63it [00:02, 24.15it/s]


BEST THRESHOLDS
Black-and-white Fallacy/Dictatorship : best_thresh=0.1665106564760208 , best_f1=0.4657534246575342
Loaded Language : best_thresh=0.19219794869422913 , best_f1=0.5680473372781065
Glittering generalities (Virtue) : best_thresh=0.4044259786605835 , best_f1=0.5079365079365079
Thought-terminating cliché : best_thresh=0.35019129514694214 , best_f1=0.4166666666666667
Whataboutism : best_thresh=0.11264052242040634 , best_f1=0.2692307692307692
Slogans : best_thresh=0.04442925006151199 , best_f1=0.48
Causal Oversimplification : best_thresh=0.1684967428445816 , best_f1=0.22857142857142856
Smears : best_thresh=0.1516617238521576 , best_f1=0.5737704918032787
Name calling/Labeling : best_thresh=0.3283158242702484 , best_f1=0.6375545851528385
Appeal to authority : best_thresh=0.23459598422050476 , best_f1=0.7769784172661871
Exaggeration/Minimisation : best_thresh=0.24717316031455994 , best_f1=0.5652173913043478
Repetition : best_thresh=0.20218044519424438 , best_f1=0.6086956521739131


In [None]:
print(best_thresh_all)

# Hard-coded per-class thresholds (one per entry of bin_classes) that replace
# whatever the training cells left in best_thresh_all.
# NOTE(review): presumably pasted from a finished training run - confirm they
# belong to the checkpoint being used.
best_thresh_all = [
    0.1665106564760208,
    0.19219794869422913,
    0.4044259786605835,
    0.35019129514694214,
    0.11264052242040634,
    0.04442925006151199,
    0.1684967428445816,
    0.1516617238521576,
    0.3283158242702484,
    0.23459598422050476,
    0.24717316031455994,
    0.20218044519424438,
    0.31305378675460815,
    0.251982182264328,
    0.0648965910077095,
    0.12918420135974884,
    0.035431552678346634,
    0.004953750409185886,
    0.18214985728263855,
    0.023360857740044594,
]

# Floor every threshold at 0.2 so that no class fires on very low scores
best_thresh_all = [thresh if thresh > 0.2 else 0.2 for thresh in best_thresh_all]

[0.1665106564760208, 0.19219794869422913, 0.4044259786605835, 0.35019129514694214, 0.11264052242040634, 0.04442925006151199, 0.1684967428445816, 0.1516617238521576, 0.3283158242702484, 0.23459598422050476, 0.24717316031455994, 0.20218044519424438, 0.31305378675460815, 0.251982182264328, 0.0648965910077095, 0.12918420135974884, 0.035431552678346634, 0.004953750409185886, 0.18214985728263855, 0.023360857740044594]


In [None]:
# Run the trained model over the test set and keep, for every example id, the
# labels whose predicted probability exceeds that class's threshold.
predictions = {}
ids = []

model.eval()

# Thresholds as a tensor on the model's device so each batch needs a single
# vectorised comparison instead of one GPU scalar comparison per class.
thresholds = torch.tensor(
    best_thresh_all,
    dtype=torch.float32,
    device=next(model.parameters()).device,
)

# test_dataloader does not shuffle, so a running index maps every row back to
# its id whatever the batch size is.
example_idx = 0

# Inference only: skip autograd bookkeeping
with torch.no_grad():
    for texts_batch, _ in tqdm(test_dataloader):
        labels_predictions = model(texts_batch)
        above_threshold = (labels_predictions > thresholds).tolist()

        for row in above_threshold:
            curr_id = test_data.ids[example_idx]
            example_idx += 1

            if curr_id not in predictions:
                predictions[curr_id] = []

            predictions[curr_id].extend(
                bin_class for bin_class, hit in zip(bin_classes, row) if hit
            )

1500it [00:10, 146.34it/s]


In [None]:
# One {"id", "labels"} record per predicted example, in insertion order
output_json = [
    {"id": example_id, "labels": labels}
    for example_id, labels in predictions.items()
]

submission_path = os.path.join(PATH_SAVE_SUBMISSION, "submission.txt")
with open(submission_path, "w") as fout:
    json.dump(output_json, fout)