In [None]:
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import PIL
import torchvision
import numpy
import pandas
import torch 
import torch.optim as optim
import gc
from torch.optim.lr_scheduler import StepLR
import cv2
import os
import json
import numpy as np
from transformers import BertModel, BertTokenizer
import torch
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModel
from transformers import T5EncoderModel
from transformers import GPT2Tokenizer, GPT2Model
from transformers import ViTImageProcessor, ViTModel
from PIL import Image
import requests
from tqdm import tqdm
import re 
import string 

In [None]:
# ---- Paths -----------------------------------------------------------------
# All dataset paths are resolved relative to PATH_DATASETS (current directory by default).
PATH_DATASETS = "."
PATH_JSON_TRAIN = os.path.join(PATH_DATASETS, "annotations/data/subtask2a/train.json")
PATH_JSON_VAL = os.path.join(PATH_DATASETS, "annotations/data/subtask2a/validation.json")
PATH_JSON_DEV = os.path.join(PATH_DATASETS, "annotations/data/subtask2a/dev_subtask2a_en.json")
# Unlabeled Arabic test set used to build the submission.
PATH_JSON_TEST = os.path.join(PATH_DATASETS, "test_data_arabic/test_data_arabic/ar_subtask2a_test_unlabeled.json")

PATH_IMG_TRAIN = os.path.join(PATH_DATASETS, "train_images")
PATH_IMG_VAL = os.path.join(PATH_DATASETS, "validation_images")
PATH_IMG_DEV = os.path.join(PATH_DATASETS, "dev_images")
PATH_IMG_TEST = os.path.join(PATH_DATASETS, "test_images_arabic/subtask2a")

# Output directories (created up front so later cells can write into them).
PATH_SAVE_MODEL = "subtask2a_models"
PATH_SAVE_SUBMISSION = "subtask2a_submissions"

os.makedirs(PATH_SAVE_MODEL, exist_ok=True)
os.makedirs(PATH_SAVE_SUBMISSION, exist_ok=True)

# ---- Training hyper-parameters ----------------------------------------------
BATCH_SIZE = 8

# Phase 1: fine-tune the whole network (0 epochs = skipped).
EPOCHS_FULL = 0
LR_FULL = 1e-5

# Phase 2: encoders frozen, only the fusion head is trained.
EPOCHS_FC = 3
LR_FC = 3e-6

# When True, the validation and dev splits are also used for training.
TRAIN_ALL = True

# ---- Reproducibility ---------------------------------------------------------
# Seed the RNGs that drive DataLoader shuffling, dropout and weight init.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
# Peek at one training record to see the annotation schema.
# `with` closes the file handle instead of leaking it until garbage collection.
with open(PATH_JSON_TRAIN, "r", encoding='utf-8') as fin:
    data = json.load(fin)

print(data[0])

In [None]:
def preprocess(text):
    """Hook for text normalisation applied before tokenization.

    Currently the identity: whatever is passed in (including None) is handed
    back unchanged; callers handle the None case themselves.
    """
    cleaned = text
    return cleaned

In [None]:
# Identity pipeline: resizing, rescaling and normalisation are done by the ViT
# `processor` (do_resize / do_rescale / do_normalize), so no torchvision ops run here.
transform = torchvision.transforms.Compose(transforms=[])

In [None]:
class BorrowedModel(nn.Module):
    """Text (hateful-memes BERT) + ViT late-fusion multi-label classifier.

    Only used as a weight donor: its trained `image_encoder` and `fc` are copied into
    `Model`. The submodule layout (including the unused `fc2`) is kept so the state_dict
    keys presumably match the saved checkpoint.
    """

    def __init__(self):
        super(BorrowedModel, self).__init__()
        self.text_encoder = AutoModel.from_pretrained('limjiayi/bert-hateful-memes-expanded')
        self.image_encoder = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
        # 249600 = (128 text tokens + 197 ViT tokens) * 768 hidden units; presumably matches
        # max_length=128 tokenization and ViT-B/16 at 224px. 22 outputs must equal len(bin_classes).
        self.fc = nn.Linear(249600, 22)
        # Not used in forward; kept only so state_dict keys stay stable.
        self.fc2 = nn.Linear(128, 2)

    def forward(self, images, text_input):
        """Encode a batch of texts and images, fuse them, and return per-label probabilities.

        Args:
            images: dict whose 'pixel_values' is a one-element list holding a
                (batch, 3, H, W) tensor -- the layout DataLoader's default collate
                builds from MyDataset items.
            text_input: dict of tokenizer tensors shaped (batch, 1, seq_len) or
                (batch, seq_len).

        Returns:
            (batch, num_labels) tensor of sigmoid probabilities.
        """
        # Follow wherever the model lives instead of hardcoding CUDA.
        device = self.fc.weight.device

        # Encode all texts in one call instead of one sample at a time.
        text_kwargs = {
            key: text_input[key].reshape(-1, text_input[key].shape[-1]).to(device)
            for key in ('input_ids', 'token_type_ids', 'attention_mask')
            if key in text_input
        }
        text_batch = text_input['input_ids'].shape[0]
        text_outputs = self.text_encoder(**text_kwargs).last_hidden_state

        # Same for the images.
        pixel_values = images['pixel_values'][0].to(device)
        image_outputs = self.image_encoder(pixel_values=pixel_values).last_hidden_state

        # Flatten each sample's token embeddings and concatenate text + image.
        combined = torch.cat(
            (text_outputs.reshape(text_batch, -1),
             image_outputs.reshape(image_outputs.size(0), -1)),
            dim=1,
        )
        return torch.sigmoid(self.fc(torch.tanh(combined)))

In [None]:
# Weight donor: the checkpoint is loaded into this instance, after which `Model` copies
# its `image_encoder` and `fc` weights. It is never moved to the GPU or run here.
prevModel = BorrowedModel()

In [None]:
# Load the BorrowedModel checkpoint onto the CPU. prevModel is a CPU module that only
# donates weights, so tensors saved from a GPU should not be restored onto the GPU.
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
checkpoint = torch.load(os.path.join(PATH_SAVE_MODEL, "checkpoint_base.pt"), map_location="cpu")

In [None]:
# Copy the saved weights into the donor model (strict key matching; the result is displayed).
prevModel.load_state_dict(state_dict=checkpoint['checkpoint'])

In [None]:
# ViT preprocessing: resize, rescale and normalise each channel with mean/std 0.5.
processor = ViTImageProcessor.from_pretrained(
    'google/vit-base-patch16-224-in21k',
    do_resize=True,
    do_rescale=True,
    do_normalize=True,
    image_mean=[0.5, 0.5, 0.5],
    image_std=[0.5, 0.5, 0.5],
)

# Multilingual BERT, since the test set is Arabic while training data is English.
model_name = 'google-bert/bert-base-multilingual-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_name)
# NOTE(review): `text_model` is not referenced by later cells; Model loads its own copy.
text_model = AutoModel.from_pretrained(model_name)

In [None]:
class MyDataset(Dataset):
    """Image + text dataset that tokenizes texts and preprocesses images eagerly.

    Uses the notebook-level `tokenizer`, `processor`, `transform` and `preprocess`.

    Args:
        paths_json_img: iterable of (annotation JSON path, image directory) pairs.
        bin_classes: ordered label vocabulary. Each labelled sample becomes a multi-hot
            list in this order; samples without a 'labels' key get an empty list.
        device: device __getitem__ moves tensors to (default "cuda", as before).

    Raises:
        FileNotFoundError: if an image referenced by the annotations cannot be read.
    """

    def __init__(self, paths_json_img, bin_classes, device="cuda"):
        self.device = device
        self.filenames = []
        self.texts = []
        self.images = []
        self.ids = []
        self.labels = []

        for path_json, path_img in paths_json_img:
            # `with` closes the annotation file instead of leaking the handle.
            with open(path_json, "r", encoding='utf-8') as fin:
                data = json.load(fin)

            for x in tqdm(data):
                currentPath = os.path.join(path_img, x['image'])

                self.ids.append(x['id'])

                if 'labels' in x:
                    present = set(x['labels'])
                    self.labels.append([1 if bin_class in present else 0 for bin_class in bin_classes])
                else:
                    self.labels.append([])

                text = preprocess(x['text'])
                if text is None:
                    text = ""
                self.texts.append(tokenizer(text, return_tensors='pt', padding='max_length', max_length=128, truncation=True))
                self.filenames.append(x['image'])

                currentImage = cv2.imread(currentPath)
                # cv2.imread signals failure by returning None; fail loudly with the path.
                if currentImage is None:
                    raise FileNotFoundError(f"Could not read image: {currentPath}")
                # NOTE(review): cv2 yields BGR channel order while ViT weights are usually
                # trained on RGB. Left unchanged because the warm-start checkpoint was
                # presumably produced with this same pipeline -- confirm before switching.
                currentImage = torch.tensor(transform(currentImage)).unsqueeze(0)
                self.images.append(processor(currentImage))

    def __len__(self):
        return len(self.images)

    def _to_device(self, mapping):
        """Copy a feature mapping, moving tensor values to `self.device` and leaving others as-is."""
        return {
            key: (value.to(self.device) if isinstance(value, torch.Tensor) else value)
            for key, value in mapping.items()
        }

    def __getitem__(self, idx):
        text_tensors = self._to_device(self.texts[idx])
        image_tensors = self._to_device(self.images[idx])
        return ((image_tensors, text_tensors), torch.tensor(self.labels[idx]))

In [None]:
class Model(nn.Module):
    """Multilingual BERT + ViT late-fusion multi-label classifier.

    The ViT image encoder and the fusion layer `fc` are warm-started from a trained
    BorrowedModel (`source_model`, which defaults to the global `prevModel` for backward
    compatibility). The text encoder is a fresh multilingual BERT.
    """

    def __init__(self, source_model=None):
        super(Model, self).__init__()
        self.text_encoder = AutoModel.from_pretrained('google-bert/bert-base-multilingual-uncased')

        self.image_encoder = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
        if source_model is None:
            source_model = prevModel
        self.image_encoder.load_state_dict(source_model.image_encoder.state_dict())

        # 249600 = (128 text tokens + 197 ViT tokens) * 768 hidden units; presumably matches
        # max_length=128 tokenization and ViT-B/16 at 224px. 22 outputs must equal len(bin_classes).
        self.fc = nn.Linear(249600, 22)
        self.fc.load_state_dict(source_model.fc.state_dict())
        # Not used in forward; kept only so state_dict keys stay stable.
        self.fc2 = nn.Linear(128, 2)

    def forward(self, images, text_input):
        """Encode a batch of texts and images, fuse them, and return per-label probabilities.

        Args:
            images: dict whose 'pixel_values' is a one-element list holding a
                (batch, 3, H, W) tensor -- the layout DataLoader's default collate
                builds from MyDataset items.
            text_input: dict of tokenizer tensors shaped (batch, 1, seq_len) or
                (batch, seq_len).

        Returns:
            (batch, num_labels) tensor of sigmoid probabilities.
        """
        # Follow wherever the model lives instead of hardcoding CUDA.
        device = self.fc.weight.device

        # Encode all texts in one call instead of one sample at a time.
        text_kwargs = {
            key: text_input[key].reshape(-1, text_input[key].shape[-1]).to(device)
            for key in ('input_ids', 'token_type_ids', 'attention_mask')
            if key in text_input
        }
        text_batch = text_input['input_ids'].shape[0]
        text_outputs = self.text_encoder(**text_kwargs).last_hidden_state

        # Same for the images.
        pixel_values = images['pixel_values'][0].to(device)
        image_outputs = self.image_encoder(pixel_values=pixel_values).last_hidden_state

        # Flatten each sample's token embeddings and concatenate text + image.
        combined = torch.cat(
            (text_outputs.reshape(text_batch, -1),
             image_outputs.reshape(image_outputs.size(0), -1)),
            dim=1,
        )
        return torch.sigmoid(self.fc(torch.tanh(combined)))

In [None]:
# Build the label vocabulary from the training annotations.
# `with` closes the file handle instead of leaking it.
with open(PATH_JSON_TRAIN, "r", encoding='utf-8') as fin:
    data = json.load(fin)

# Unique labels in first-seen order (dict preserves insertion order). This order fixes
# the column order of every multi-hot target and of the classifier outputs.
bin_classes = list(dict.fromkeys(label for x in data for label in x['labels']))

print(len(bin_classes))
print(bin_classes)

In [None]:
# With TRAIN_ALL the validation and dev splits are folded into training, so the
# validation metrics printed during training are no longer held-out numbers.
train_sources = [(PATH_JSON_TRAIN, PATH_IMG_TRAIN)]
if TRAIN_ALL:
    train_sources += [(PATH_JSON_VAL, PATH_IMG_VAL), (PATH_JSON_DEV, PATH_IMG_DEV)]

train_data = MyDataset(train_sources, bin_classes)
valid_data = MyDataset([(PATH_JSON_VAL, PATH_IMG_VAL)], bin_classes)
test_data = MyDataset([(PATH_JSON_TEST, PATH_IMG_TEST)], bin_classes)

# Test batches hold one sample each so batch index == position in test_data.ids.
train_dataloader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_dataloader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE, shuffle=False)
test_dataloader = DataLoader(dataset=test_data, batch_size=1, shuffle=False)

In [None]:
# m = nn.Sigmoid()
# loss = nn.BCELoss()
# input = torch.randn(3, requires_grad=True)
# target = torch.empty(3).random_(2)

# print(input)
# print(target)
# output = loss(m(input), target)

In [None]:
# Sanity-check the cached inputs before training.
print(len(train_data.images[0]['pixel_values']))
print(len(train_data))
print(train_data.texts[0]['input_ids'].shape)

model = Model()
model.cuda()

criterion = torch.nn.BCELoss()


def _train_one_epoch(model, dataloader, criterion, optimizer):
    """Run one optimisation pass over `dataloader`; return the mean per-batch loss."""
    model.train()
    running_loss = 0.0
    for (images_batch, texts_batch), labels_batch in tqdm(dataloader):
        optimizer.zero_grad()
        labels_batch = labels_batch.to(device='cuda', dtype=torch.float32)
        loss = criterion(model(images_batch, texts_batch), labels_batch)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # BCELoss already averages within a batch, so average over batches, not samples.
    return running_loss / max(len(dataloader), 1)


def _evaluate(model, dataloader, criterion, num_classes):
    """Return (mean per-batch loss, per-label accuracy at threshold 0.5) on `dataloader`."""
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    # No autograd graph is needed for evaluation; this saves a lot of GPU memory.
    with torch.no_grad():
        for (images_batch, texts_batch), labels_batch in tqdm(dataloader):
            labels_batch = labels_batch.to(device='cuda', dtype=torch.float32)
            labels_predictions = model(images_batch, texts_batch)
            running_loss += criterion(labels_predictions, labels_batch).item()
            correct += ((labels_predictions > 0.5) == labels_batch).sum().item()
            total += labels_batch.size(0)
    accuracy = correct / (total * num_classes) if total and num_classes else 0.0
    return running_loss / max(len(dataloader), 1), accuracy


def _run_phase(model, optimizer, epochs, best_loss=1e9):
    """Train for `epochs` epochs (validating after each) and return the best validation loss.

    Uses the notebook-level train/valid dataloaders, `criterion` and `bin_classes`.
    """
    for epoch in range(epochs):
        train_loss = _train_one_epoch(model, train_dataloader, criterion, optimizer)
        validation_loss, accuracy = _evaluate(model, valid_dataloader, criterion, len(bin_classes))
        print(f'Epoch: {epoch} Train Loss: {train_loss} Validation Loss: {validation_loss} Validation Accuracy: {accuracy * 100:.2f}%')
        # Only a real improvement counts as a new best.
        if validation_loss < best_loss:
            print(f'Checkpoint reached! Validation loss modified from {best_loss} to {validation_loss}')
            best_loss = validation_loss
        torch.cuda.empty_cache()
    return best_loss


# Phase 1: fine-tune the whole network (skipped when EPOCHS_FULL == 0).
optimizer = torch.optim.Adam(model.parameters(), lr=LR_FULL)
best_loss = _run_phase(model, optimizer, EPOCHS_FULL)

# Phase 2: freeze both encoders and train only the remaining (fusion) parameters.
for encoder in (model.text_encoder, model.image_encoder):
    for param in encoder.parameters():
        param.requires_grad = False

optimizer = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=LR_FC)
best_loss = _run_phase(model, optimizer, EPOCHS_FC)

torch.cuda.empty_cache()

# Persist the final model (the one used for inference in the next cell).
checkpoint = {'checkpoint': model.state_dict()}
torch.save(checkpoint, os.path.join(PATH_SAVE_MODEL, 'checkpoint.pt'))
                    


In [None]:
# Decision threshold on the sigmoid outputs (earlier note: use 0.25 for the official submission).
TEST_THRESHOLD = 0.15

predictions = {}
ids = []

model.eval()  # set once; disables dropout for every test batch
sample_offset = 0  # position in test_data.ids (the test loader is not shuffled)
# Inference only: no autograd graph, which keeps memory flat across the test set.
with torch.no_grad():
    for (images_batch, texts_batch), _labels_batch in tqdm(test_dataloader):
        labels_predictions = model(images_batch, texts_batch)
        # Iterate over rows so this also works if the test batch size is ever > 1.
        for predicted in (labels_predictions > TEST_THRESHOLD):
            curr_id = test_data.ids[sample_offset]
            sample_offset += 1
            predictions.setdefault(curr_id, []).extend(
                bin_class for bin_class, is_on in zip(bin_classes, predicted) if is_on
            )

In [None]:
# One {"id", "labels"} record per test sample, in prediction order.
output_json = [{"id": sample_id, "labels": labels} for sample_id, labels in predictions.items()]

submission_path = os.path.join(PATH_SAVE_SUBMISSION, "submission_arabic_2a.txt")
with open(submission_path, "w") as fout:
    json.dump(output_json, fout)