# BERT

In [None]:
# BERT

import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertForSequenceClassification, AdamW
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from datetime import datetime
import matplotlib.pyplot as plt
import seaborn as sns

# Fraction of all rows held out as the test set (first train_test_split below).
TEST_SIZE = 0.15
# Fraction of the remaining train+validation rows used for validation
# (second train_test_split); 0.1765 * 0.85 is roughly 0.15 of the full data.
VALIDATE_SIZE = 0.1765
# Seed passed as random_state to both splits.
RANDOM_STATE_INT = 14988828
# Token length every excerpt is padded or truncated to by the tokenizer.
MAX_LEN = 256
# Batch size for the train, validation and test DataLoaders.
BATCH_SIZE = 64
# Upper bound on training epochs; early stopping may end training sooner.
EPOCHS = 20
# Patience handed to EarlyStopping (non-improving epochs tolerated).
PATIENCE = 3

# The code below reads the 'plutchik_emotion' and 'excerpt_value_cleaned' columns.
# NOTE(review): the path looks like a Google Drive mount in Colab - confirm.
df = pd.read_csv('/content/drive/MyDrive/english_only.csv')

class EmotionDataset(Dataset):
    """Map-style dataset that tokenizes one excerpt per lookup.

    Holds parallel, integer-indexable collections of excerpts and labels,
    together with the tokenizer and the fixed sequence length used to encode
    each excerpt.
    """

    def __init__(self, excerpts, labels, tokenizer, max_len):
        self.excerpts, self.labels = excerpts, labels
        self.tokenizer, self.max_len = tokenizer, max_len

    def __len__(self):
        return len(self.excerpts)

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and ``labels`` for item ``idx``."""
        text = str(self.excerpts[idx])
        target = self.labels[idx]

        # Every excerpt is padded or truncated to exactly max_len tokens.
        encode_options = dict(
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt',
            truncation=True,
        )
        encoded = self.tokenizer.encode_plus(text, **encode_options)

        # flatten() collapses the tensors returned by the tokenizer to 1-D.
        sample = {key: encoded[key].flatten() for key in ('input_ids', 'attention_mask')}
        sample['labels'] = torch.tensor(target, dtype=torch.long)
        return sample



# Tokenizer for the same checkpoint that is fine-tuned below.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Integer class ids, assigned in order of first appearance in the CSV.
label_dict = {label: idx for idx, label in enumerate(df['plutchik_emotion'].unique())}
# Series.map builds the integer column directly. Series.replace with a
# str -> int mapping depends on implicit downcasting of the object column,
# which pandas 2.2 deprecated (FutureWarning).
df['emotion_labels'] = df['plutchik_emotion'].map(label_dict)

# First split: hold out TEST_SIZE of all rows as the test set, stratified by class.
train_val_texts, test_texts, train_val_labels, test_labels = train_test_split(
    df['excerpt_value_cleaned'],
    df['emotion_labels'],
    test_size=TEST_SIZE,
    random_state=RANDOM_STATE_INT,
    stratify=df['emotion_labels']
)

# Second split: carve the validation set out of the remaining rows.
train_texts, val_texts, train_labels, val_labels = train_test_split(
    train_val_texts,
    train_val_labels,
    test_size=VALIDATE_SIZE,
    random_state=RANDOM_STATE_INT,
    stratify=train_val_labels
)

# train_test_split keeps the original row labels, while EmotionDataset looks
# items up with self.excerpts[idx]; renumber every Series from 0 so that the
# lookup works for idx in range(len(dataset)).
train_texts = train_texts.reset_index(drop=True)
val_texts = val_texts.reset_index(drop=True)
test_texts = test_texts.reset_index(drop=True)
train_labels = train_labels.reset_index(drop=True)
val_labels = val_labels.reset_index(drop=True)
test_labels = test_labels.reset_index(drop=True)

train_dataset = EmotionDataset(train_texts, train_labels, tokenizer, MAX_LEN)
val_dataset = EmotionDataset(val_texts, val_labels, tokenizer, MAX_LEN)
test_dataset = EmotionDataset(test_texts, test_labels, tokenizer, MAX_LEN)
# Only the training loader shuffles; validation and test keep a fixed order.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

# One output logit per distinct emotion label.
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=len(label_dict))
# NOTE(review): this AdamW comes from transformers, where it is deprecated in
# favour of torch.optim.AdamW - confirm against the installed version.
optimizer = AdamW(model.parameters(), lr=2e-5)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

class EarlyStopping:
    """Signal when the validation loss has stopped improving.

    Call the instance once per epoch with that epoch's validation loss, then
    read ``early_stop``.

    Args:
        patience: Number of non-improving calls tolerated before
            ``early_stop`` becomes True.
        min_delta: A loss only counts as an improvement when it is at or
            below ``best_loss - min_delta``.

    Attributes set on the instance: ``counter`` (consecutive non-improving
    calls), ``best_loss`` (best loss accepted so far, ``None`` before the
    first valid call) and ``early_stop``.
    """

    def __init__(self, patience=3, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        """Record one validation loss and update ``counter`` / ``early_stop``."""
        import math

        # Every comparison with NaN is False, so a NaN loss used to fall into
        # the "improved" branch, reset the counter and overwrite best_loss with
        # NaN, after which stopping could never trigger. A NaN epoch is now
        # counted as non-improving and never stored as the best loss.
        is_nan = math.isnan(val_loss)

        if self.best_loss is None and not is_nan:
            # First usable loss becomes the baseline.
            self.best_loss = val_loss
            return

        if is_nan or val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.counter = 0

early_stopping = EarlyStopping(patience=PATIENCE, min_delta=0.01)

# Weights with the lowest validation loss seen so far. Without this snapshot
# the model evaluated after training is the one from the last epoch, i.e. the
# epoch at which patience ran out rather than the best one.
best_val_loss = float('inf')
best_model_state = None

print('Starting training')
for epoch in range(EPOCHS):
    # ---- training pass ----
    model.train()
    total_loss = 0
    for batch in train_loader:
        batch_input_ids = batch['input_ids'].to(device)
        batch_attention_mask = batch['attention_mask'].to(device)
        batch_labels = batch['labels'].to(device)

        model.zero_grad()
        # Passing labels makes the model return the loss alongside the logits.
        outputs = model(input_ids=batch_input_ids, attention_mask=batch_attention_mask, labels=batch_labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_train_loss = total_loss / len(train_loader)

    # ---- validation pass ----
    model.eval()
    val_loss = 0
    all_predictions, all_true_labels = [], []
    with torch.no_grad():
        for batch in val_loader:
            batch_input_ids = batch['input_ids'].to(device)
            batch_attention_mask = batch['attention_mask'].to(device)
            batch_labels = batch['labels'].to(device)

            outputs = model(input_ids=batch_input_ids, attention_mask=batch_attention_mask, labels=batch_labels)
            loss = outputs.loss
            val_loss += loss.item()

            logits = outputs.logits
            preds = torch.argmax(logits, dim=-1)

            all_predictions.extend(preds.cpu().numpy())
            all_true_labels.extend(batch_labels.cpu().numpy())

    avg_val_loss = val_loss / len(val_loader)

    print(f'Epoch {epoch + 1}/{EPOCHS}, Training Loss: {avg_train_loss:.4f}, Validation Loss: {avg_val_loss:.4f}, Time: {datetime.now().strftime("%H:%M:%S")}')

    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        # state_dict() hands back references to the live tensors, so each one
        # is cloned to freeze the weights as they are right now.
        best_model_state = {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}

    early_stopping(avg_val_loss)
    if early_stopping.early_stop:
        print("Early stopping")
        break

# Evaluate the best checkpoint, whether training stopped early or ran all epochs.
if best_model_state is not None:
    model.load_state_dict(best_model_state)

def evaluate(loader, set_name):
    """Score the global ``model`` on ``loader``; plot and save the results.

    Shows a confusion-matrix heatmap, prints a classification report and
    writes that report to the results folder.

    Args:
        loader: DataLoader yielding batches with 'input_ids',
            'attention_mask' and 'labels'.
        set_name: Split name used in the plot title, the printed header and
            the report file name.
    """
    model.eval()
    all_predictions = []
    all_true_labels = []
    with torch.no_grad():
        for batch in loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            preds = torch.argmax(outputs.logits, dim=-1)

            all_predictions.extend(preds.cpu().numpy())
            # Labels are only compared on the host, so they skip the device round trip.
            all_true_labels.extend(batch['labels'].cpu().numpy())

    # Class names ordered by their integer id; ids are 0..n-1 by construction
    # of label_dict (enumerate).
    class_names = [label for label, _ in sorted(label_dict.items(), key=lambda item: item[1])]
    class_ids = list(range(len(class_names)))

    # Passing labels= pins rows/columns to every known class. Otherwise a class
    # absent from both targets and predictions shrinks the matrix (misaligning
    # the tick labels) and makes classification_report reject target_names.
    cm = confusion_matrix(all_true_labels, all_predictions, labels=class_ids)
    plt.figure(figsize=(10,7))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
    plt.title(f'{set_name} Set Confusion Matrix')
    plt.xlabel('Predicted Labels')
    plt.ylabel('True Labels')
    plt.show()

    print(f"{set_name} Set Classification Report:")
    report = classification_report(all_true_labels, all_predictions, labels=class_ids, target_names=class_names)
    print(report)
    with open(f'/content/drive/MyDrive/results/bert_{set_name}_classification_report.txt', 'w') as file:
        file.write(report)

evaluate(val_loader, "Validation")
evaluate(test_loader, "Test")


# MacBERTh

In [None]:
# MacBERTh

import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertForSequenceClassification, AdamW
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from datetime import datetime
import matplotlib.pyplot as plt
import seaborn as sns

# Share of all rows reserved for the test set (first split).
TEST_SIZE = 0.15
# Share of the remaining train+validation rows reserved for validation
# (second split); 0.1765 of the remaining 85% is about 15% of all rows.
VALIDATE_SIZE = 0.1765
# random_state used by both train_test_split calls.
RANDOM_STATE_INT = 14988828
# Fixed token length produced by the tokenizer (padding and truncation).
MAX_LEN = 256
# Batch size shared by all three DataLoaders.
BATCH_SIZE = 64
# Maximum number of epochs; the loop may break earlier.
EPOCHS = 20
# Consecutive non-improving epochs allowed before training stops.
PATIENCE = 3

# Source data; the 'plutchik_emotion' and 'excerpt_value_cleaned' columns are used below.
df = pd.read_csv('/content/drive/MyDrive/english_only.csv')

class EmotionDataset(Dataset):
    """Pairs each excerpt with its label and tokenizes on access.

    ``excerpts`` and ``labels`` are indexed with the same integer, so they
    must be aligned and support ``obj[idx]``.
    """

    def __init__(self, excerpts, labels, tokenizer, max_len):
        self.max_len = max_len
        self.tokenizer = tokenizer
        self.labels = labels
        self.excerpts = excerpts

    def __len__(self):
        return len(self.excerpts)

    def __getitem__(self, idx):
        """Tokenize excerpt ``idx`` and return it with its label as tensors."""
        excerpt_text = str(self.excerpts[idx])
        class_id = self.labels[idx]

        tokens = self.tokenizer.encode_plus(
            excerpt_text,
            truncation=True,
            padding='max_length',
            max_length=self.max_len,
            add_special_tokens=True,
            return_attention_mask=True,
            return_token_type_ids=False,
            return_tensors='pt',
        )

        # Build the sample step by step; tensors are flattened to 1-D.
        sample = {}
        sample['input_ids'] = tokens['input_ids'].flatten()
        sample['attention_mask'] = tokens['attention_mask'].flatten()
        sample['labels'] = torch.tensor(class_id, dtype=torch.long)
        return sample

# Tokenizer for the same checkpoint that is fine-tuned below.
tokenizer = BertTokenizer.from_pretrained('emanjavacas/MacBERTh')

# Integer class ids, assigned in order of first appearance in the CSV.
label_dict = {label: idx for idx, label in enumerate(df['plutchik_emotion'].unique())}
# Series.map produces the integer column directly; Series.replace with a
# str -> int mapping leans on object downcasting that pandas 2.2 deprecated.
df['emotion_labels'] = df['plutchik_emotion'].map(label_dict)

# Stratified hold-out of the test set.
train_val_texts, test_texts, train_val_labels, test_labels = train_test_split(
    df['excerpt_value_cleaned'],
    df['emotion_labels'],
    test_size=TEST_SIZE,
    random_state=RANDOM_STATE_INT,
    stratify=df['emotion_labels']
)

# Stratified split of the remainder into training and validation sets.
train_texts, val_texts, train_labels, val_labels = train_test_split(
    train_val_texts,
    train_val_labels,
    test_size=VALIDATE_SIZE,
    random_state=RANDOM_STATE_INT,
    stratify=train_val_labels
)

# EmotionDataset looks items up with self.excerpts[idx], so every Series is
# renumbered from 0 after the shuffle performed by train_test_split.
train_texts = train_texts.reset_index(drop=True)
val_texts = val_texts.reset_index(drop=True)
test_texts = test_texts.reset_index(drop=True)
train_labels = train_labels.reset_index(drop=True)
val_labels = val_labels.reset_index(drop=True)
test_labels = test_labels.reset_index(drop=True)

train_dataset = EmotionDataset(train_texts, train_labels, tokenizer, MAX_LEN)
val_dataset = EmotionDataset(val_texts, val_labels, tokenizer, MAX_LEN)
test_dataset = EmotionDataset(test_texts, test_labels, tokenizer, MAX_LEN)
# Only the training loader shuffles.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

# One output logit per distinct emotion label.
model = BertForSequenceClassification.from_pretrained('emanjavacas/MacBERTh', num_labels=len(label_dict))
# NOTE(review): transformers' AdamW is deprecated in favour of
# torch.optim.AdamW - confirm against the installed version.
optimizer = AdamW(model.parameters(), lr=2e-5)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Early-stopping bookkeeping: best validation loss so far, the number of
# consecutive epochs without improvement, and a frozen copy of the best weights.
best_val_loss = float('inf')
patience_counter = 0
best_model_state = None

print('Starting training')
for epoch in range(EPOCHS):
    # ---- training pass ----
    model.train()
    total_loss = 0
    for batch in train_loader:
        batch_input_ids = batch['input_ids'].to(device)
        batch_attention_mask = batch['attention_mask'].to(device)
        batch_labels = batch['labels'].to(device)

        model.zero_grad()
        # Passing labels makes the model return the loss alongside the logits.
        outputs = model(input_ids=batch_input_ids, attention_mask=batch_attention_mask, labels=batch_labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_train_loss = total_loss / len(train_loader)

    # ---- validation pass ----
    model.eval()
    val_loss = 0
    all_predictions, all_true_labels = [], []
    with torch.no_grad():
        for batch in val_loader:
            batch_input_ids = batch['input_ids'].to(device)
            batch_attention_mask = batch['attention_mask'].to(device)
            batch_labels = batch['labels'].to(device)

            outputs = model(input_ids=batch_input_ids, attention_mask=batch_attention_mask, labels=batch_labels)
            loss = outputs.loss
            val_loss += loss.item()

            logits = outputs.logits
            preds = torch.argmax(logits, dim=-1)

            all_predictions.extend(preds.cpu().numpy())
            all_true_labels.extend(batch_labels.cpu().numpy())

    avg_val_loss = val_loss / len(val_loader)

    print(f'Epoch {epoch + 1}/{EPOCHS}, Training Loss: {avg_train_loss:.4f}, Validation Loss: {avg_val_loss:.4f}, Time: {datetime.now().strftime("%H:%M:%S")}')

    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        patience_counter = 0
        # state_dict() returns references to the live parameter tensors, which
        # later optimizer steps overwrite in place. Cloning each tensor is what
        # makes this an actual snapshot; storing the dict itself would make the
        # later restore a no-op.
        best_model_state = {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}
    else:
        patience_counter += 1
        if patience_counter >= PATIENCE:
            print("Early stopping")
            break

# Restore the best checkpoint whether the loop broke early or ran every epoch.
# The None guard covers a run in which no epoch ever produced a finite
# improvement (e.g. NaN losses).
if best_model_state is not None:
    model.load_state_dict(best_model_state)

def evaluate(loader, set_name):
    """Score the global ``model`` on ``loader``; plot and save the results.

    Shows a confusion-matrix heatmap, prints a classification report and
    writes that report to the results folder.

    Args:
        loader: DataLoader yielding batches with 'input_ids',
            'attention_mask' and 'labels'.
        set_name: Split name used in the plot title, the printed header and
            the report file name.
    """
    model.eval()
    all_predictions = []
    all_true_labels = []
    with torch.no_grad():
        for batch in loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            preds = torch.argmax(outputs.logits, dim=-1)

            all_predictions.extend(preds.cpu().numpy())
            # Labels are only compared on the host, so they skip the device round trip.
            all_true_labels.extend(batch['labels'].cpu().numpy())

    # Class names ordered by their integer id; ids are 0..n-1 by construction
    # of label_dict (enumerate).
    class_names = [label for label, _ in sorted(label_dict.items(), key=lambda item: item[1])]
    class_ids = list(range(len(class_names)))

    # labels= keeps one row/column per known class even when a class is absent
    # from both targets and predictions, so the tick labels stay aligned and
    # classification_report accepts target_names.
    cm = confusion_matrix(all_true_labels, all_predictions, labels=class_ids)
    plt.figure(figsize=(10,7))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
    plt.title(f'{set_name} Set Confusion Matrix')
    plt.xlabel('Predicted Labels')
    plt.ylabel('True Labels')
    plt.show()

    print(f"{set_name} Set Classification Report:")
    report = classification_report(all_true_labels, all_predictions, labels=class_ids, target_names=class_names)
    print(report)
    # The file used to be named bert_*, which overwrote the reports written by
    # the BERT cell; this model gets its own prefix.
    with open(f'/content/drive/MyDrive/results/macberth_{set_name}_classification_report.txt', 'w') as file:
        file.write(report)

evaluate(val_loader, "Validation")
evaluate(test_loader, "Test")


# RoBERTa

In [None]:
# RoBERTa


import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import RobertaTokenizer, RobertaForSequenceClassification, AdamW
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from datetime import datetime
import matplotlib.pyplot as plt
import seaborn as sns

# Test-set fraction for the first train_test_split.
TEST_SIZE = 0.15
# Validation fraction for the second split, taken from the rows left after
# the test hold-out (0.1765 * 0.85 is about 0.15 of the full data).
VALIDATE_SIZE = 0.1765
# Seed for both splits.
RANDOM_STATE_INT = 14988828
# Tokenizer pads/truncates each excerpt to this many tokens.
MAX_LEN = 256
# DataLoader batch size.
BATCH_SIZE = 64
# Maximum number of training epochs.
EPOCHS = 20
# Non-improving epochs tolerated before the training loop breaks.
PATIENCE = 3
# Learning rate passed to the optimizer.
LR_NUM = 1e-5

# Source data; the 'plutchik_emotion' and 'excerpt_value_cleaned' columns are used below.
df = pd.read_csv('/content/drive/MyDrive/english_only.csv')

class EmotionDataset(Dataset):
    """Dataset of (excerpt, label) pairs encoded lazily in ``__getitem__``."""

    def __init__(self, excerpts, labels, tokenizer, max_len):
        # Parallel collections indexed by the same integer position.
        self.excerpts = excerpts
        self.labels = labels
        # Encoding configuration.
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.excerpts)

    def __getitem__(self, idx):
        """Encode excerpt ``idx`` to fixed-length tensors and attach its label."""
        text = str(self.excerpts[idx])
        label = self.labels[idx]

        encoded = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_token_type_ids=False,
            return_tensors='pt',
        )

        # Flatten each returned tensor to 1-D before handing it to the loader.
        ids = encoded['input_ids'].flatten()
        mask = encoded['attention_mask'].flatten()
        target = torch.tensor(label, dtype=torch.long)
        return dict(input_ids=ids, attention_mask=mask, labels=target)

# Tokenizer and classifier come from the same checkpoint; the classifier gets
# one output logit per distinct emotion label.
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaForSequenceClassification.from_pretrained('roberta-base', num_labels=len(df['plutchik_emotion'].unique()))

# Integer class ids, assigned in order of first appearance in the CSV.
label_dict = {label: idx for idx, label in enumerate(df['plutchik_emotion'].unique())}
# Series.map produces the integer column directly; Series.replace with a
# str -> int mapping leans on object downcasting that pandas 2.2 deprecated.
df['emotion_labels'] = df['plutchik_emotion'].map(label_dict)

# Stratified hold-out of the test set.
train_val_texts, test_texts, train_val_labels, test_labels = train_test_split(
    df['excerpt_value_cleaned'],
    df['emotion_labels'],
    test_size=TEST_SIZE,
    random_state=RANDOM_STATE_INT,
    stratify=df['emotion_labels']
)

# Stratified split of the remainder into training and validation sets.
train_texts, val_texts, train_labels, val_labels = train_test_split(
    train_val_texts,
    train_val_labels,
    test_size=VALIDATE_SIZE,
    random_state=RANDOM_STATE_INT,
    stratify=train_val_labels
)

# EmotionDataset looks items up with self.excerpts[idx], so every Series is
# renumbered from 0 after the shuffle performed by train_test_split.
train_texts = train_texts.reset_index(drop=True)
val_texts = val_texts.reset_index(drop=True)
test_texts = test_texts.reset_index(drop=True)
train_labels = train_labels.reset_index(drop=True)
val_labels = val_labels.reset_index(drop=True)
test_labels = test_labels.reset_index(drop=True)

train_dataset = EmotionDataset(train_texts, train_labels, tokenizer, MAX_LEN)
val_dataset = EmotionDataset(val_texts, val_labels, tokenizer, MAX_LEN)
test_dataset = EmotionDataset(test_texts, test_labels, tokenizer, MAX_LEN)
# Only the training loader shuffles.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)


# NOTE(review): transformers' AdamW is deprecated in favour of
# torch.optim.AdamW - confirm against the installed version.
optimizer = AdamW(model.parameters(), lr=LR_NUM)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Early-stopping bookkeeping: best validation loss so far, the number of
# consecutive epochs without improvement, and a frozen copy of the best weights.
best_val_loss = float('inf')
patience_counter = 0
best_model_state = None

print('Starting training')
for epoch in range(EPOCHS):
    # ---- training pass ----
    model.train()
    total_loss = 0
    for batch in train_loader:
        batch_input_ids = batch['input_ids'].to(device)
        batch_attention_mask = batch['attention_mask'].to(device)
        batch_labels = batch['labels'].to(device)
        model.zero_grad()
        # Passing labels makes the model return the loss alongside the logits.
        outputs = model(input_ids=batch_input_ids, attention_mask=batch_attention_mask, labels=batch_labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_train_loss = total_loss / len(train_loader)

    # ---- validation pass ----
    model.eval()
    val_loss = 0
    all_predictions, all_true_labels = [], []
    with torch.no_grad():
        for batch in val_loader:
            batch_input_ids = batch['input_ids'].to(device)
            batch_attention_mask = batch['attention_mask'].to(device)
            batch_labels = batch['labels'].to(device)

            outputs = model(input_ids=batch_input_ids, attention_mask=batch_attention_mask, labels=batch_labels)
            loss = outputs.loss
            val_loss += loss.item()

            logits = outputs.logits
            preds = torch.argmax(logits, dim=-1)

            all_predictions.extend(preds.cpu().numpy())
            all_true_labels.extend(batch_labels.cpu().numpy())

    avg_val_loss = val_loss / len(val_loader)
    print(f'Epoch {epoch + 1}/{EPOCHS}, Training Loss: {avg_train_loss:.4f}, Validation Loss: {avg_val_loss:.4f}, Time: {datetime.now().strftime("%H:%M:%S")}')

    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        patience_counter = 0
        # state_dict() returns references to the live parameter tensors, which
        # later optimizer steps overwrite in place. Cloning each tensor is what
        # makes this an actual snapshot; storing the dict itself would make the
        # later restore a no-op.
        best_model_state = {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}
    else:
        patience_counter += 1
        if patience_counter >= PATIENCE:
            print("Early stopping")
            break

# Restore the best checkpoint whether the loop broke early or ran every epoch.
# The None guard covers a run in which no epoch ever produced a finite
# improvement (e.g. NaN losses).
if best_model_state is not None:
    model.load_state_dict(best_model_state)

def evaluate(loader, set_name):
    """Score the global ``model`` on ``loader``; plot and save the results.

    Shows a confusion-matrix heatmap, prints a classification report and
    writes that report to the results folder.

    Args:
        loader: DataLoader yielding batches with 'input_ids',
            'attention_mask' and 'labels'.
        set_name: Split name used in the plot title, the printed header and
            the report file name.
    """
    model.eval()
    all_predictions = []
    all_true_labels = []
    with torch.no_grad():
        for batch in loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            preds = torch.argmax(outputs.logits, dim=-1)

            all_predictions.extend(preds.cpu().numpy())
            # Labels are only compared on the host, so they skip the device round trip.
            all_true_labels.extend(batch['labels'].cpu().numpy())

    # Every known class id (0..n-1 by construction of label_dict) and its
    # name, in id order. labels_expected was previously computed but never
    # handed to sklearn.
    labels_expected = list(range(len(label_dict)))
    target_names = [label for label, index in sorted(label_dict.items(), key=lambda x: x[1])]

    # labels= keeps one row/column per known class even when a class is absent
    # from both targets and predictions, so the tick labels stay aligned and
    # classification_report accepts target_names.
    cm = confusion_matrix(all_true_labels, all_predictions, labels=labels_expected)
    plt.figure(figsize=(10,7))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=target_names, yticklabels=target_names)
    plt.title(f'{set_name} Set Confusion Matrix')
    plt.xlabel('Predicted Labels')
    plt.ylabel('True Labels')
    plt.show()

    # NOTE(review): zero_division=1 reports 1.0 for a metric whose denominator
    # is zero (e.g. precision of a class that is never predicted), whereas the
    # BERT and MacBERTh cells use sklearn's default - confirm the two settings
    # are meant to differ before comparing averages across models.
    report = classification_report(
        all_true_labels,
        all_predictions,
        labels=labels_expected,
        target_names=target_names,
        zero_division=1,
    )
    print(f"{set_name} Set Classification Report:")
    print(report)
    with open(f'/content/drive/MyDrive/results/roberta_{set_name}_classification_report.txt', 'w') as file:
        file.write(report)

evaluate(val_loader, "Validation")
evaluate(test_loader, "Test")


# Weighted-Pooling RoBERTa


In [1]:
# Weighted-Pooling RoBERTa

import pandas as pd
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import RobertaTokenizer, RobertaModel, AdamW, RobertaPreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime


# Fraction of all rows held out for testing (first split).
TEST_SIZE = 0.15
# Fraction of the remaining rows held out for validation (second split);
# 0.1765 * 0.85 is about 0.15 of the full data.
VALIDATE_SIZE = 0.1765
# random_state for both train_test_split calls.
RANDOM_STATE_INT = 14988828
# Fixed tokenized length per excerpt.
MAX_LEN = 256
# DataLoader batch size for this cell.
BATCH_SIZE = 32
# Maximum number of training epochs.
EPOCHS = 20
# Non-improving epochs tolerated before the training loop breaks.
PATIENCE = 2

# Source data; the 'plutchik_emotion' and 'excerpt_value_cleaned' columns are used below.
df = pd.read_csv('/content/drive/MyDrive/english_only.csv')

class EmotionDataset(Dataset):
    """Serve tokenized excerpts with their emotion labels.

    Each lookup stringifies one excerpt, encodes it to ``max_len`` tokens and
    returns the encoded tensors along with the label as a ``torch.long`` tensor.
    """

    def __init__(self, excerpts, labels, tokenizer, max_len):
        self.excerpts = excerpts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.excerpts)

    def _encode(self, text):
        """Tokenize ``text`` with fixed-length padding and truncation."""
        return self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt',
            truncation=True,
        )

    def __getitem__(self, idx):
        excerpt = str(self.excerpts[idx])
        label = self.labels[idx]
        encoding = self._encode(excerpt)
        input_ids, attention_mask = (
            encoding[name].flatten() for name in ('input_ids', 'attention_mask')
        )
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': torch.tensor(label, dtype=torch.long),
        }

class WeightedPoolingRoBERTa(RobertaPreTrainedModel):
    """RoBERTa classifier pooled over the last few encoder layers.

    The sentence representation is a learned weighted sum of the position-0
    hidden states taken from the last ``len(pooling_weights)`` entries of the
    encoder's ``hidden_states``. It goes through dropout and a linear
    classification head.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.roberta = RobertaModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.num_labels)
        # One learnable scalar per pooled layer, starting at 1 (a plain sum).
        self.pooling_weights = nn.Parameter(torch.ones(4))

        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, labels=None):
        """Return a SequenceClassifierOutput with ``logits`` and, if ``labels`` is given, ``loss``.

        Args:
            input_ids: Token ids forwarded to the RoBERTa encoder.
            attention_mask: Attention mask forwarded to the encoder.
            labels: Optional class ids; when present a cross-entropy loss is computed.
        """
        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True
        )

        # Number of pooled layers follows the weight vector, so the two cannot drift apart.
        num_pooled = self.pooling_weights.numel()

        # Only position 0 feeds the classifier, so it is sliced out of each
        # layer before mixing. The previous version stacked and weighted the
        # full sequence of all pooled layers and then discarded every other
        # position, which cost activation memory proportional to the sequence
        # length for no effect on the result.
        first_token_states = torch.stack(
            [layer_state[:, 0, :] for layer_state in outputs.hidden_states[-num_pooled:]],
            dim=-1,
        )
        weighted_sum = (first_token_states * self.pooling_weights).sum(dim=-1)

        pooled_output = self.dropout(weighted_sum)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits
        )

# Tokenizer and encoder weights come from the same checkpoint; the custom
# head gets one output logit per distinct emotion label.
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = WeightedPoolingRoBERTa.from_pretrained('roberta-base', num_labels=len(df['plutchik_emotion'].unique()))

# Integer class ids, assigned in order of first appearance in the CSV.
label_dict = {label: idx for idx, label in enumerate(df['plutchik_emotion'].unique())}
# Series.map produces the integer column directly; Series.replace with a
# str -> int mapping leans on object downcasting that pandas 2.2 deprecated.
df['emotion_labels'] = df['plutchik_emotion'].map(label_dict)

# Stratified hold-out of the test set.
train_val_texts, test_texts, train_val_labels, test_labels = train_test_split(
    df['excerpt_value_cleaned'],
    df['emotion_labels'],
    test_size=TEST_SIZE,
    random_state=RANDOM_STATE_INT,
    stratify=df['emotion_labels']
)
# Stratified split of the remainder into training and validation sets.
train_texts, val_texts, train_labels, val_labels = train_test_split(
    train_val_texts,
    train_val_labels,
    test_size=VALIDATE_SIZE,
    random_state=RANDOM_STATE_INT,
    stratify=train_val_labels
)

# EmotionDataset looks items up with self.excerpts[idx], so every Series is
# renumbered from 0 after the shuffle performed by train_test_split.
train_texts = train_texts.reset_index(drop=True)
val_texts = val_texts.reset_index(drop=True)
test_texts = test_texts.reset_index(drop=True)
train_labels = train_labels.reset_index(drop=True)
val_labels = val_labels.reset_index(drop=True)
test_labels = test_labels.reset_index(drop=True)

train_dataset = EmotionDataset(train_texts, train_labels, tokenizer, MAX_LEN)
val_dataset = EmotionDataset(val_texts, val_labels, tokenizer, MAX_LEN)
test_dataset = EmotionDataset(test_texts, test_labels, tokenizer, MAX_LEN)

# Only the training loader shuffles.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

# NOTE(review): transformers' AdamW is deprecated in favour of
# torch.optim.AdamW - confirm against the installed version.
optimizer = AdamW(model.parameters(), lr=2e-5)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Early-stopping bookkeeping: best validation loss so far, the number of
# consecutive epochs without improvement, and a frozen copy of the best weights.
best_val_loss = float('inf')
patience_counter = 0
best_model_state = None

print('Starting training')
for epoch in range(EPOCHS):
    # ---- training pass ----
    model.train()
    total_loss = 0
    for batch in train_loader:
        batch_input_ids = batch['input_ids'].to(device)
        batch_attention_mask = batch['attention_mask'].to(device)
        batch_labels = batch['labels'].to(device)
        model.zero_grad()
        # Passing labels makes the model return the loss alongside the logits.
        outputs = model(input_ids=batch_input_ids, attention_mask=batch_attention_mask, labels=batch_labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_train_loss = total_loss / len(train_loader)

    # ---- validation pass ----
    model.eval()
    val_loss = 0
    all_predictions, all_true_labels = [], []
    with torch.no_grad():
        for batch in val_loader:
            batch_input_ids = batch['input_ids'].to(device)
            batch_attention_mask = batch['attention_mask'].to(device)
            batch_labels = batch['labels'].to(device)

            outputs = model(input_ids=batch_input_ids, attention_mask=batch_attention_mask, labels=batch_labels)
            loss = outputs.loss
            val_loss += loss.item()

            logits = outputs.logits
            preds = torch.argmax(logits, dim=1)
            all_predictions.extend(preds.cpu().numpy())
            all_true_labels.extend(batch['labels'].cpu().numpy())

    avg_val_loss = val_loss / len(val_loader)
    print(f'Epoch {epoch + 1}/{EPOCHS}, Training Loss: {avg_train_loss:.4f}, Validation Loss: {avg_val_loss:.4f}, Time: {datetime.now().strftime("%H:%M:%S")}')

    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        patience_counter = 0
        # state_dict() returns references to the live parameter tensors, which
        # later optimizer steps overwrite in place. Cloning each tensor is what
        # makes this an actual snapshot; storing the dict itself would make the
        # later restore a no-op.
        best_model_state = {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}
    else:
        patience_counter += 1
        if patience_counter >= PATIENCE:
            print("Early stopping")
            break

# Restore the best checkpoint whether the loop broke early or ran every epoch.
# The None guard covers a run in which no epoch ever produced a finite
# improvement (e.g. NaN losses).
if best_model_state is not None:
    model.load_state_dict(best_model_state)

def evaluate(loader, set_name):
    """Score the global ``model`` on ``loader``; save and plot the results.

    Prints a classification report, writes it to the results folder and then
    shows a confusion-matrix heatmap.

    Args:
        loader: DataLoader yielding batches with 'input_ids',
            'attention_mask' and 'labels'.
        set_name: Split name used in the printed header, the report file name
            and the plot title.
    """
    model.eval()
    predictions = []
    true_labels = []
    with torch.no_grad():
        for batch in loader:
            batch_input_ids = batch['input_ids'].to(device)
            batch_attention_mask = batch['attention_mask'].to(device)

            outputs = model(input_ids=batch_input_ids, attention_mask=batch_attention_mask)
            preds = torch.argmax(outputs.logits, dim=1)

            predictions.extend(preds.cpu().numpy())
            # Labels are only compared on the host, so they skip the device round trip.
            true_labels.extend(batch['labels'].cpu().numpy())

    # Class names ordered by their integer id; ids are 0..n-1 by construction
    # of label_dict (enumerate).
    class_names = [label for label, _ in sorted(label_dict.items(), key=lambda x: x[1])]
    class_ids = list(range(len(class_names)))

    # labels= keeps one entry per known class even when a class is absent from
    # both targets and predictions; without it classification_report rejects
    # target_names and the heatmap tick labels no longer match the matrix.
    # NOTE(review): zero_division=1 reports 1.0 for a metric whose denominator
    # is zero, whereas the BERT and MacBERTh cells use sklearn's default -
    # confirm the difference is intended before comparing averages.
    report = classification_report(
        true_labels,
        predictions,
        labels=class_ids,
        target_names=class_names,
        zero_division=1,
    )
    print(f"{set_name} Set Classification Report:")
    print(report)

    with open(f'/content/drive/MyDrive/results/weight_pooled_roberta_{set_name}_classification_report.txt', 'w') as file:
        file.write(report)

    cm = confusion_matrix(true_labels, predictions, labels=class_ids)
    plt.figure(figsize=(10,7))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
    plt.title(f'{set_name} Set Confusion Matrix')
    plt.xlabel('Predicted Labels')
    plt.ylabel('True Labels')
    plt.show()

evaluate(val_loader, "Validation")
evaluate(test_loader, "Test")
