In [None]:
import torch
from torch.utils.data import DataLoader, Dataset
import pandas as pd
import json
from transformers import AutoTokenizer

def load_jsonl(file_path):
    """Read a JSON Lines file into a DataFrame, one row per record.

    Args:
        file_path: Path to a UTF-8 encoded file holding one JSON document
            per line.

    Returns:
        A ``pandas.DataFrame`` built from the decoded records, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        # Blank or whitespace-only lines (e.g. a trailing empty line) carry
        # no record; skip them instead of letting json.loads fail on them.
        records = [json.loads(line) for line in file if line.strip()]
    return pd.DataFrame(records)

class MemeDataset(Dataset):
    """Text/label dataset backed by a JSON Lines annotation file.

    Indexing yields a ``(features, label)`` pair. ``features`` maps every key
    of the tokenizer output for the record's ``'text'`` field to its tensor
    with the leading axis squeezed away; ``label`` is a ``torch.long`` tensor
    built from the record's ``'label'`` field.
    """

    def __init__(self, annotations_file, tokenizer, max_length=512):
        """Load the annotations and remember the tokenizer settings.

        Args:
            annotations_file: Path handed to ``load_jsonl``.
            tokenizer: Callable invoked on each text with the keyword
                arguments used in ``__getitem__``.
            max_length: Length every text is padded/truncated to.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.img_labels = load_jsonl(annotations_file)

    def __len__(self):
        """Return the number of annotation records."""
        return self.img_labels.shape[0]

    def __getitem__(self, idx):
        """Return the tokenized text and label tensor for record ``idx``."""
        record = self.img_labels.iloc[idx]
        caption = record['text']
        target = torch.tensor(record['label'], dtype=torch.long)

        tokenized = self.tokenizer(
            caption,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )

        # return_tensors='pt' is requested for a single text, so each tensor
        # is squeezed on axis 0 to hand back per-sample tensors.
        features = {}
        for name, tensor in tokenized.items():
            features[name] = tensor.squeeze(0)
        return features, target

# Hub identifier of the pretrained checkpoint the tokenizer is loaded from.
checkpoint = "Hate-speech-CNERG/bert-base-uncased-hatexplain"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# Training split, served in shuffled batches of 16.
train_dataset = MemeDataset(
    '/kaggle/input/memedata/hateful_memes/train.jsonl',
    tokenizer,
)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=16)

# Validation split, served in file order in batches of 16.
val_dataset = MemeDataset(
    '/kaggle/input/memedata/hateful_memes/dev_unseen.jsonl',
    tokenizer,
)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=16)


In [None]:
import matplotlib.pyplot as plt
from sklearn.metrics import f1_score, roc_auc_score, accuracy_score
from tqdm import tqdm


In [None]:
from transformers import AutoModelForSequenceClassification
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

# Load the checkpoint configured for two output labels;
# ignore_mismatched_sizes=True is passed so loading is not aborted when a
# stored weight's shape differs from the requested configuration.
model = AutoModelForSequenceClassification.from_pretrained(
    checkpoint,
    num_labels=2,
    ignore_mismatched_sizes=True,
)

# Replace the classification head with a fresh dropout + linear layer sized
# from the model config.
# NOTE(review): the stock BERT sequence-classification forward pass appears
# to apply its own dropout before calling ``classifier`` — confirm that
# stacking a second Dropout here is intended.
model.classifier = nn.Sequential(
    nn.Dropout(model.config.hidden_dropout_prob),
    nn.Linear(model.config.hidden_size, 2),
)

# Prefer the GPU when one is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model.to(device)

optimizer = optim.AdamW(params=model.parameters(), lr=5e-5)

from sklearn.exceptions import UndefinedMetricWarning
import warnings

# Silence scikit-learn's UndefinedMetricWarning for the rest of the session.
warnings.filterwarnings(action="ignore", category=UndefinedMetricWarning)

def _run_training_epoch(model, train_loader, optimizer, device):
    """Run one optimisation pass over ``train_loader``; return the mean batch loss."""
    model.train()
    total_loss = 0.0
    for inputs, labels in tqdm(train_loader):
        inputs = {key: value.to(device) for key, value in inputs.items()}
        labels = labels.to(device)
        optimizer.zero_grad()
        # The loss is read from the model output, so labels go in with the inputs.
        loss = model(**inputs, labels=labels).loss
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(train_loader)


def _run_validation_epoch(model, val_loader, device):
    """Evaluate on ``val_loader`` without gradients.

    Returns:
        ``(mean_batch_loss, labels, predictions, positive_scores)`` where
        ``predictions`` are argmax class ids and ``positive_scores`` are the
        softmax probabilities of class index 1.
    """
    model.eval()
    total_loss = 0.0
    labels_seen = []
    predictions = []
    positive_scores = []
    with torch.no_grad():
        for inputs, labels in tqdm(val_loader):
            inputs = {key: value.to(device) for key, value in inputs.items()}
            labels = labels.to(device)
            outputs = model(**inputs, labels=labels)
            total_loss += outputs.loss.item()

            logits = outputs.logits
            predictions.extend(torch.argmax(logits, dim=1).cpu().numpy())
            # Keep the continuous class-1 probability: ROC AUC is a ranking
            # metric and needs scores, not thresholded decisions.
            positive_scores.extend(torch.softmax(logits, dim=1)[:, 1].cpu().numpy())
            labels_seen.extend(labels.cpu().numpy())
    return total_loss / len(val_loader), labels_seen, predictions, positive_scores


def _plot_training_history(train_losses, val_losses, val_accuracies, val_f1_scores, val_auc_scores):
    """Draw the loss curves and the validation metric curves side by side."""
    plt.figure(figsize=(12, 6))
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.title('Loss Curve')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(val_accuracies, label='Accuracy')
    plt.plot(val_f1_scores, label='F1 Score')
    plt.plot(val_auc_scores, label='AUC Score')
    plt.title('Validation Metrics')
    plt.xlabel('Epochs')
    plt.ylabel('Metrics')
    plt.legend()
    plt.show()


def train_and_evaluate_with_checkpoint(model, train_loader, val_loader, optimizer, device, num_epochs=20, save_path='best_model.pt'):
    """Train ``model``, validate after every epoch, and checkpoint the best epoch.

    After each epoch the validation loss, accuracy, binary F1 and ROC AUC are
    printed. Whenever the mean validation loss improves on the best value seen
    so far, ``model.state_dict()`` is written to ``save_path``. Once all
    epochs have run, the loss and metric histories are plotted.

    ROC AUC is computed from the softmax probability of class index 1 rather
    than from argmax predictions, so it reflects how well the model ranks
    examples. The function therefore assumes a binary classifier whose
    positive class is index 1 (as ``average='binary'`` for F1 already does).

    Args:
        model: Model called as ``model(**inputs, labels=labels)`` whose output
            exposes ``.loss`` and ``.logits``.
        train_loader: Iterable of ``(inputs_dict, labels)`` training batches.
        val_loader: Iterable of ``(inputs_dict, labels)`` validation batches.
        optimizer: Optimizer stepping the model's parameters.
        device: Device the batches are moved to before the forward pass.
        num_epochs: Number of train/validate cycles to run.
        save_path: File the best ``state_dict`` is saved to.

    Returns:
        None.
    """
    train_losses = []
    val_losses = []
    val_accuracies = []
    val_f1_scores = []
    val_auc_scores = []

    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        avg_train_loss = _run_training_epoch(model, train_loader, optimizer, device)
        train_losses.append(avg_train_loss)

        avg_val_loss, val_labels, val_predictions, val_scores = _run_validation_epoch(
            model, val_loader, device
        )
        val_losses.append(avg_val_loss)

        if len(set(val_labels)) > 1:
            val_f1 = f1_score(val_labels, val_predictions, average='binary')
            val_auc = roc_auc_score(val_labels, val_scores)
        else:
            # With a single class present neither metric is defined; fall
            # back to neutral placeholder values.
            val_f1, val_auc = 0.0, 0.5

        val_accuracies.append(accuracy_score(val_labels, val_predictions))
        val_f1_scores.append(val_f1)
        val_auc_scores.append(val_auc)

        print(f"Epoch {epoch+1}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, Val Accuracy: {val_accuracies[-1]:.4f}, Val F1-Score: {val_f1:.4f}, Val AUC: {val_auc:.4f}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), save_path)
            print(f"Checkpoint saved to {save_path} at epoch {epoch+1}")

    _plot_training_history(train_losses, val_losses, val_accuracies, val_f1_scores, val_auc_scores)
    
# Launch fine-tuning with the default epoch count and checkpoint path.
train_and_evaluate_with_checkpoint(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    device=device,
)
