In this project, we use two datasets, ISOT and LIAR, to investigate the detection of fake news on social media platforms. We apply two recent, powerful transformer models, ELECTRA and XLNet, to this problem.

dataset from https://www.kaggle.com/datasets/csmalarkodi/isot-fake-news-dataset

**ISOT DATASET**

https://ai.googleblog.com/2020/03/more-efficient-nlp-model-pre-training.html

https://github.com/google-research/electra

In [None]:
# Packages imported
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data.dataloader as dataloader
import torch.optim as optim
import re
import nltk
nltk.download('punkt')
from nltk.tokenize import word_tokenize

from torch.utils.data import TensorDataset
from torchvision import transforms
import torchvision

import matplotlib.pyplot as plt
import time
from IPython.display import clear_output

!pip install transformers sentencepiece
!pip install pytorch-transformers

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


[0mCollecting pytorch-transformers
  Downloading pytorch_transformers-1.2.0-py3-none-any.whl (176 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m176.4/176.4 kB[0m [31m24.4 MB/s[0m eta [36m0:00:00[0m
Collecting sacremoses
  Downloading sacremoses-0.0.53.tar.gz (880 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m880.6/880.6 kB[0m [31m73.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25ldone
Building wheels for collected packages: sacremoses
  Building wheel for sacremoses (setup.py) ... [?25ldone
[?25h  Created wheel for sacremoses: filename=sacremoses-0.0.53-py3-none-any.whl size=895242 sha256=931f057a3bc4f5b4297606a2a9f7e3718d0815e29d12f84b6cde53c27242ea3b
  Stored in directory: /root/.cache/pip/wheels/42/79/78/5ad3b042cb2d97c294535162cdbaf9b167e3b186eae55ab72d
Successfully built sacremoses
Installing collected packages: sacremoses, pytorch-transformers
Successfully installed pytorch-transformers-1.2.0 sacrem

tr -dc '[:print:]\n' < file > newfile

In [None]:
# Load datasets
fake_df = pd.read_csv("/notebooks/Fake.csv")


real_df = pd.read_csv("/notebooks/True.csv")




In [None]:
# Select the title and text columns from each DataFrame
fake_df = fake_df[['title', 'text']]
real_df = real_df[['title', 'text']]

In [None]:
# Print the first row of the fake news DataFrame
print(fake_df.head(1))

                                               title  \
0   Donald Trump Sends Out Embarrassing New Year’...   

                                                text  
0  Donald Trump just couldn t wish all Americans ...  


In [None]:
# Print the first few rows of the real news DataFrame
print(real_df.head())

                                               title  \
0  As U.S. budget fight looms, Republicans flip t...   
1  U.S. military to accept transgender recruits o...   
2  Senior U.S. Republican senator: 'Let Mr. Muell...   
3  FBI Russia probe helped by Australian diplomat...   
4  Trump wants Postal Service to charge 'much mor...   

                                                text  
0  WASHINGTON (Reuters) - The head of a conservat...  
1  WASHINGTON (Reuters) - Transgender people will...  
2  WASHINGTON (Reuters) - The special counsel inv...  
3  WASHINGTON (Reuters) - Trump campaign adviser ...  
4  SEATTLE/WASHINGTON (Reuters) - President Donal...  


In [None]:
# Label each article by its source DataFrame: 0 = fake news, 1 = real news.
fake_df['class'] = 0
real_df['class'] = 1

In [None]:
# Stack the fake and real articles into one DataFrame; ignore_index=True
# renumbers the rows from 0 instead of keeping each source's own index.
df = pd.concat([fake_df, real_df], ignore_index=True, sort=False)

In [None]:
# Shuffle the dataset
df = df.sample(frac=1, random_state=42).reset_index(drop=True)

In [None]:
df.head(10)

Unnamed: 0,title,text,class
0,Ben Stein Calls Out 9th Circuit Court: Committ...,"21st Century Wire says Ben Stein, reputable pr...",0
1,Trump drops Steve Bannon from National Securit...,WASHINGTON (Reuters) - U.S. President Donald T...,1
2,Puerto Rico expects U.S. to lift Jones Act shi...,(Reuters) - Puerto Rico Governor Ricardo Rosse...,1
3,OOPS: Trump Just Accidentally Confirmed He Le...,"On Monday, Donald Trump once again embarrassed...",0
4,Donald Trump heads for Scotland to reopen a go...,"GLASGOW, Scotland (Reuters) - Most U.S. presid...",1
5,Paul Ryan Responds To Dem’s Sit-In On Gun Con...,"On Wednesday, Democrats took a powerful stance...",0
6,AWESOME! DIAMOND AND SILK Rip Into The Press: ...,President Trump s rally in FL on Saturday was ...,0
7,STAND UP AND CHEER! UKIP Party Leader SLAMS Ge...,He s been Europe s version of the outspoken Te...,0
8,North Korea shows no sign it is serious about ...,WASHINGTON (Reuters) - The State Department sa...,1
9,Trump signals willingness to raise U.S. minimu...,(This version of the story corrects the figur...,1


In [None]:
print(df['class'].value_counts())

0    23481
1    21417
Name: class, dtype: int64


In [None]:
# Clean the text, remove special characters, URLs, numbers, and extra spaces.

import re

def clean_text(text):
  """Cleans a text string by removing URLs, special characters, numbers, and extra spaces.

  URLs are stripped first: removing punctuation beforehand would delete the
  "://" that the URL pattern relies on, so the URL would never match and its
  remains (e.g. "httpsexamplecom") would stay in the text.

  Args:
    text: The text string to clean.

  Returns:
    The cleaned, lower-cased text string.
  """

  # Remove URLs while their punctuation is still intact.
  text = re.sub(r'https?://\S+', '', text)

  # Remove special characters (anything that is not a word character or whitespace).
  text = re.sub(r'[^\w\s]', '', text)

  # Remove numbers.
  text = re.sub(r'\d+', '', text)

  # Collapse runs of spaces into a single space.
  text = re.sub(r' +', ' ', text)

  # Convert the text to lower case.
  text = text.lower()

  return text

# Tokenize the text.

from nltk.tokenize import word_tokenize

def tokenize_text(text):
  """Split a text string into word tokens.

  Args:
    text: The text string to tokenize.

  Returns:
    The tokens that NLTK's ``word_tokenize`` produces for ``text``.
  """

  tokens = word_tokenize(text)
  return tokens



In [None]:
# Clean the text in the `text` column (the raw text is overwritten in place).

df['text'] = df['text'].apply(clean_text)

# Tokenize the cleaned text into a new `tokens` column.

df['tokens'] = df['text'].apply(tokenize_text)


In [None]:
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize


nltk.download('stopwords')
nltk.download('punkt')

def remove_stopwords(tokens):
  """Return the tokens that are not English stop words.

  The NLTK stop-word list is loaded on the first call and cached on the
  function object, so applying this function to every row of a DataFrame
  does not reload and re-hash the list for each row.

  Args:
    tokens: An iterable of token strings.

  Returns:
    A list of the tokens that are not in NLTK's English stop-word list, in
    their original order.
  """
  stop_words = getattr(remove_stopwords, '_stop_words', None)
  if stop_words is None:
    stop_words = frozenset(stopwords.words('english'))
    remove_stopwords._stop_words = stop_words
  return [token for token in tokens if token not in stop_words]


# Tokenize the `text` column into the `tokens` column.
df['tokens'] = df['text'].apply(word_tokenize)


# Drop English stop words from each token list.
# NOTE(review): the model-encoding cells read `df['text']`, not these token
# columns -- confirm the token columns are still needed.
df['tokens_no_stopwords'] = df['tokens'].apply(remove_stopwords)




[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


Perform stemming/lemmatization

In [None]:
from nltk.stem import PorterStemmer, WordNetLemmatizer
from nltk.corpus import wordnet

nltk.download('punkt')
nltk.download('wordnet')
nltk.download('averaged_perceptron_tagger')
nltk.download('omw-1.4')



[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Unzipping taggers/averaged_perceptron_tagger.zip.
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...


True

In [None]:
# Function for stemming
def stemming(text):
    """Stem every whitespace-separated word of ``text`` with the Porter stemmer.

    Returns the stemmed words joined by single spaces.
    """
    porter = PorterStemmer()
    return ' '.join(map(porter.stem, text.split()))

# Function for lemmatization
def lemmatization(text):
    """Lemmatize ``text`` token by token using POS-aware WordNet lookups.

    The text is tokenized and POS-tagged with NLTK; each token is lemmatized
    with the WordNet POS that ``get_wordnet_pos`` maps its tag to, and the
    lemmas are returned joined by single spaces.
    """
    wn_lemmatizer = WordNetLemmatizer()
    tagged_tokens = nltk.pos_tag(nltk.word_tokenize(text))
    lemmas = [
        wn_lemmatizer.lemmatize(token, get_wordnet_pos(pos_tag))
        for token, pos_tag in tagged_tokens
    ]
    return ' '.join(lemmas)

# Function for mapping NLTK POS tags to WordNet POS tags
def get_wordnet_pos(treebank_tag):
    """Map a Treebank POS tag to the corresponding WordNet POS constant.

    Tags starting with 'J', 'V', 'N' or 'R' map to adjective, verb, noun and
    adverb respectively; every other tag falls back to noun.
    """
    prefix_to_pos = (
        ('J', wordnet.ADJ),
        ('V', wordnet.VERB),
        ('N', wordnet.NOUN),
        ('R', wordnet.ADV),
    )
    for prefix, pos in prefix_to_pos:
        if treebank_tag.startswith(prefix):
            return pos
    # Unrecognised tags are treated as nouns.
    return wordnet.NOUN


In [None]:
# Replace the cleaned text in the `text` column with its lemmatized form.
df['text'] = df['text'].apply(lemmatization)


Split the dataset

In [None]:
from sklearn.model_selection import train_test_split

# Hold out 20% of the rows, then halve the hold-out into validation and test
# sets, giving an 80-10-10 train/validation/test split.
train_data, temp_data = train_test_split(df, test_size=0.2, random_state=42)
valid_data, test_data = train_test_split(temp_data, test_size=0.5, random_state=42)

# Give every split a fresh 0-based index.
train_data, valid_data, test_data = (
    split.reset_index(drop=True) for split in (train_data, valid_data, test_data)
)

Text Encoding

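XLNet expects its separator and classification special tokens (`<sep>` and `<cls>`) at the end of each input sequence.
The tokenizer call in the following code adds them through `add_special_tokens=True`.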

In [None]:
from transformers import ElectraTokenizer, XLNetTokenizer

# Load the pretrained tokenizers
electra_tokenizer = ElectraTokenizer.from_pretrained('google/electra-base-discriminator')
xlnet_tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')

# Function to encode text data
def encode_text_data(tokenizer, text_data, max_length, model_name=None):
    """Tokenize a batch of texts into fixed-length PyTorch tensors.

    Args:
        tokenizer: A Hugging Face tokenizer (callable on a list of strings).
        text_data: Iterable of raw text strings to encode.
        max_length: Length that every sequence is truncated or padded to.
        model_name: Accepted for backward compatibility and otherwise unused.
            Earlier versions appended the literal text " [SEP] [CLS]" when
            this was 'xlnet', but those are BERT-style markers; XLNet's
            special tokens are ``<sep>`` and ``<cls>``, so the suffix was
            encoded as ordinary text. ``add_special_tokens=True`` already
            makes each tokenizer insert its own special tokens.

    Returns:
        The tokenizer's batch encoding, requested with an attention mask and
        PyTorch tensors (``return_tensors='pt'``).
    """
    # Calling the tokenizer directly is the supported batch-encoding entry
    # point; `batch_encode_plus` is deprecated in favour of it.
    return tokenizer(
        list(text_data),
        add_special_tokens=True,
        truncation=True,
        padding='max_length',
        max_length=max_length,
        return_attention_mask=True,
        return_tensors='pt',
    )

# Encode the train, validation and test splits (in that order) for ELECTRA.
max_length_electra = 256
encoded_train_data_electra, encoded_valid_data_electra, encoded_test_data_electra = (
    encode_text_data(electra_tokenizer, split['text'].tolist(), max_length_electra)
    for split in (train_data, valid_data, test_data)
)

# Encode the same three splits for XLNet.
max_length_xlnet = 256
encoded_train_data_xlnet, encoded_valid_data_xlnet, encoded_test_data_xlnet = (
    encode_text_data(xlnet_tokenizer, split['text'].tolist(), max_length_xlnet, model_name='xlnet')
    for split in (train_data, valid_data, test_data)
)




Downloading vocab.txt:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading tokenizer_config.json:   0%|          | 0.00/27.0 [00:00<?, ?B/s]

Downloading config.json:   0%|          | 0.00/666 [00:00<?, ?B/s]

Downloading spiece.model:   0%|          | 0.00/779k [00:00<?, ?B/s]

Downloading config.json:   0%|          | 0.00/760 [00:00<?, ?B/s]

MODEL PREPARATION

Select the pre-trained ELECTRA and XLNet models (e.g., ElectraForSequenceClassification and XLNetForSequenceClassification from the Hugging Face Transformers library).
    Configure the models: Set the number of labels to 2 (fake and real news) and specify any other hyperparameters.
    Define the training parameters: Choose the optimizer (e.g., AdamW), learning rate, batch size, and number of epochs.

In [None]:
from transformers import ElectraForSequenceClassification, XLNetForSequenceClassification, AdamW

# Training hyperparameters. They are defined before the optimizers so that
# changing `learning_rate` actually changes the rate both optimizers use.
learning_rate = 2e-5
batch_size = 16
num_epochs = 3

# Load the pretrained ELECTRA model for sequence classification
# (two labels: fake and real news; attentions/hidden states are not returned).
electra_model = ElectraForSequenceClassification.from_pretrained(
    'google/electra-base-discriminator',
    num_labels=2,
    output_attentions=False,
    output_hidden_states=False
)

# Load the pretrained XLNet model for sequence classification
xlnet_model = XLNetForSequenceClassification.from_pretrained(
    'xlnet-base-cased',
    num_labels=2,
    output_attentions=False,
    output_hidden_states=False
)

# One AdamW optimizer per model, both driven by `learning_rate`.
# NOTE(review): the AdamW class exported by `transformers` is deprecated in
# favour of `torch.optim.AdamW` -- consider switching when upgrading.
optimizer_electra = AdamW(electra_model.parameters(), lr=learning_rate, eps=1e-8)
optimizer_xlnet = AdamW(xlnet_model.parameters(), lr=learning_rate, eps=1e-8)


Downloading pytorch_model.bin:   0%|          | 0.00/420M [00:00<?, ?B/s]

Some weights of the model checkpoint at google/electra-base-discriminator were not used when initializing ElectraForSequenceClassification: ['discriminator_predictions.dense.weight', 'discriminator_predictions.dense_prediction.weight', 'discriminator_predictions.dense_prediction.bias', 'discriminator_predictions.dense.bias']
- This IS expected if you are initializing ElectraForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ElectraForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ElectraForSequenceClassification were not initialized from the model checkpoint at google/electra-base-discriminator and are newly initialized: ['classifier.o

Downloading pytorch_model.bin:   0%|          | 0.00/445M [00:00<?, ?B/s]

Some weights of the model checkpoint at xlnet-base-cased were not used when initializing XLNetForSequenceClassification: ['lm_loss.weight', 'lm_loss.bias']
- This IS expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of XLNetForSequenceClassification were not initialized from the model checkpoint at xlnet-base-cased and are newly initialized: ['logits_proj.weight', 'sequence_summary.summary.weight', 'logits_proj.bias', 'sequence_summary.summary.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions a

This code snippet imports the necessary classes from the Hugging Face Transformers library, loads the pretrained ELECTRA and XLNet models for sequence classification, and sets the number of labels to 2 for fake and real news classification. The models are configured not to output attentions or hidden states.

The training parameters are defined using the AdamW optimizer with a specified learning rate and epsilon value. You can adjust the learning rate, batch size, and number of epochs according to your requirements and hardware constraints.

Model Training

Model Training:

    Train the ELECTRA and XLNet models separately on the training dataset.
    Validate the models during training using the validation dataset to monitor their performance and avoid overfitting.

In [None]:
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from transformers import get_linear_schedule_with_warmup

# Function to create a TensorDataset from encoded data
def create_tensor_dataset(encoded_data, labels):
    """Bundle token ids, attention masks and labels into one TensorDataset.

    ``encoded_data`` must provide 'input_ids' and 'attention_mask' tensors;
    ``labels`` is converted to a ``torch.long`` tensor. Each dataset item is
    an (input_ids, attention_mask, label) tuple.
    """
    token_ids = encoded_data['input_ids']
    masks = encoded_data['attention_mask']
    label_tensor = torch.tensor(labels, dtype=torch.long)
    return TensorDataset(token_ids, masks, label_tensor)

# Labels are identical for both models, so extract them once per split.
train_labels = train_data['class'].tolist()
valid_labels = valid_data['class'].tolist()

# Pair each model's encoded inputs with the labels of the matching split.
train_dataset_electra = create_tensor_dataset(encoded_train_data_electra, train_labels)
valid_dataset_electra = create_tensor_dataset(encoded_valid_data_electra, valid_labels)

train_dataset_xlnet = create_tensor_dataset(encoded_train_data_xlnet, train_labels)
valid_dataset_xlnet = create_tensor_dataset(encoded_valid_data_xlnet, valid_labels)

# Training loaders draw examples in random order ...
train_dataloader_electra, train_dataloader_xlnet = (
    DataLoader(dataset, sampler=RandomSampler(dataset), batch_size=batch_size)
    for dataset in (train_dataset_electra, train_dataset_xlnet)
)

# ... while validation loaders keep the dataset order.
valid_dataloader_electra, valid_dataloader_xlnet = (
    DataLoader(dataset, sampler=SequentialSampler(dataset), batch_size=batch_size)
    for dataset in (valid_dataset_electra, valid_dataset_xlnet)
)

# Function to train a model
def train_model(model, optimizer, train_dataloader, valid_dataloader, num_epochs, device):
    """Fine-tune ``model`` and print the mean train/validation loss per epoch.

    Args:
        model: Model whose first output is the loss when called with
            ``labels``.
        optimizer: Optimizer over the model's parameters.
        train_dataloader: Yields (input_ids, attention_mask, labels) batches
            used for the parameter updates.
        valid_dataloader: Yields batches of the same form, used for the
            validation loss only.
        num_epochs: Number of passes over ``train_dataloader``.
        device: Device the model and every batch are moved to.
    """
    model.to(device)

    # The learning rate decays linearly over all optimizer steps, with no warm-up.
    total_steps = len(train_dataloader) * num_epochs
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)

    for epoch in range(num_epochs):
        # ---- training pass ----
        model.train()
        running_train_loss = 0
        for batch in train_dataloader:
            input_ids, attention_mask, labels = (tensor.to(device) for tensor in batch)
            model.zero_grad()

            loss = model(input_ids, attention_mask=attention_mask, labels=labels)[0]
            running_train_loss += loss.item()
            loss.backward()

            # Clip the gradient norm to 1.0 before applying the update.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
        avg_train_loss = running_train_loss / len(train_dataloader)

        # ---- validation pass (no gradients needed) ----
        model.eval()
        running_eval_loss = 0
        with torch.no_grad():
            for batch in valid_dataloader:
                input_ids, attention_mask, labels = (tensor.to(device) for tensor in batch)
                eval_loss = model(input_ids, attention_mask=attention_mask, labels=labels)[0]
                running_eval_loss += eval_loss.item()
        avg_eval_loss = running_eval_loss / len(valid_dataloader)

        print(f"Epoch: {epoch + 1}, Train Loss: {avg_train_loss}, Validation Loss: {avg_eval_loss}")

# Train the ELECTRA model
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_model(electra_model, optimizer_electra, train_dataloader_electra, valid_dataloader_electra, num_epochs, device)

# Train the XLNet model
train_model(xlnet_model, optimizer_xlnet, train_dataloader_xlnet, valid_dataloader_xlnet, num_epochs, device)


Epoch: 1, Train Loss: 0.019703529338646792, Validation Loss: 0.004950502579110994
Epoch: 2, Train Loss: 0.0026918698187586502, Validation Loss: 0.0064993213591507315
Epoch: 3, Train Loss: 0.0016503566949672241, Validation Loss: 0.0016429046165570224
Epoch: 1, Train Loss: 0.023935745771841804, Validation Loss: 0.015502227676232359
Epoch: 2, Train Loss: 0.003405549313485693, Validation Loss: 0.004790249627001757
Epoch: 3, Train Loss: 0.0011191337376334242, Validation Loss: 0.0010165631525775726


This code snippet creates TensorDatasets and DataLoaders for the ELECTRA and XLNet models, then fine-tunes each model and prints its training and validation loss after every epoch.

Model Evaluation

In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Function to evaluate a model on test data
def evaluate_model(model, test_dataloader, device):
    """Run ``model`` over ``test_dataloader`` and collect predictions and labels.

    Each batch is an (input_ids, attention_mask, labels) tuple. The predicted
    class of an example is the argmax along axis 1 of the model's first
    output (the logits).

    Returns:
        A ``(predictions, true_labels)`` pair of lists with one entry per
        example, in dataloader order.
    """
    model.to(device)
    model.eval()

    predictions = []
    true_labels = []

    with torch.no_grad():
        for input_ids, attention_mask, labels in test_dataloader:
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            labels = labels.to(device)

            logits = model(input_ids, attention_mask=attention_mask)[0]
            batch_predictions = np.argmax(logits.detach().cpu().numpy(), axis=1)

            predictions.extend(batch_predictions.flatten())
            true_labels.extend(labels.to('cpu').numpy().flatten())

    return predictions, true_labels

# Build datasets from the held-out TEST split. The validation split was already
# used to monitor training, so metrics computed on it are validation metrics,
# not test metrics.
test_labels = test_data['class'].tolist()
test_dataset_electra = create_tensor_dataset(encoded_test_data_electra, test_labels)
test_dataset_xlnet = create_tensor_dataset(encoded_test_data_xlnet, test_labels)

# Create a DataLoader for the test dataset
test_dataloader_electra = DataLoader(test_dataset_electra, sampler=SequentialSampler(test_dataset_electra), batch_size=batch_size)
test_dataloader_xlnet = DataLoader(test_dataset_xlnet, sampler=SequentialSampler(test_dataset_xlnet), batch_size=batch_size)

# Evaluate the models
predictions_electra, true_labels_electra = evaluate_model(electra_model, test_dataloader_electra, device)
predictions_xlnet, true_labels_xlnet = evaluate_model(xlnet_model, test_dataloader_xlnet, device)

# Calculate evaluation metrics. With scikit-learn's default pos_label=1,
# precision/recall/F1 treat class 1 (real news) as the positive class.
accuracy_electra = accuracy_score(true_labels_electra, predictions_electra)
precision_electra = precision_score(true_labels_electra, predictions_electra)
recall_electra = recall_score(true_labels_electra, predictions_electra)
f1_electra = f1_score(true_labels_electra, predictions_electra)

accuracy_xlnet = accuracy_score(true_labels_xlnet, predictions_xlnet)
precision_xlnet = precision_score(true_labels_xlnet, predictions_xlnet)
recall_xlnet = recall_score(true_labels_xlnet, predictions_xlnet)
f1_xlnet = f1_score(true_labels_xlnet, predictions_xlnet)

# Print the evaluation metrics
print("ELECTRA Model:")
print(f"Accuracy: {accuracy_electra}, Precision: {precision_electra}, Recall: {recall_electra}, F1-score: {f1_electra}")

print("\nXLNet Model:")
print(f"Accuracy: {accuracy_xlnet}, Precision: {precision_xlnet}, Recall: {recall_xlnet}, F1-score: {f1_xlnet}")


ELECTRA Model:
Accuracy: 0.999554565701559, Precision: 0.9995337995337995, Recall: 0.9995337995337995, F1-score: 0.9995337995337995

XLNet Model:
Accuracy: 0.9997772828507795, Precision: 1.0, Recall: 0.9995337995337995, F1-score: 0.9997668454185124


Hyperparameter Tuning

In [None]:
pip install optuna

Collecting optuna
  Downloading optuna-3.1.1-py3-none-any.whl (365 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m365.7/365.7 kB[0m [31m33.5 MB/s[0m eta [36m0:00:00[0m
Collecting cmaes>=0.9.1
  Downloading cmaes-0.9.1-py3-none-any.whl (21 kB)
Collecting alembic>=1.5.0
  Downloading alembic-1.10.4-py3-none-any.whl (212 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m212.9/212.9 kB[0m [31m48.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting colorlog
  Downloading colorlog-6.7.0-py2.py3-none-any.whl (11 kB)
Collecting Mako
  Downloading Mako-1.2.4-py3-none-any.whl (78 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m78.7/78.7 kB[0m [31m24.9 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: Mako, colorlog, cmaes, alembic, optuna
Successfully installed Mako-1.2.4 alembic-1.10.4 cmaes-0.9.1 colorlog-6.7.0 optuna-3.1.1
[0mNote: you may need to restart the kernel to use updated packages.


In [None]:
import optuna
from optuna.integration import PyTorchLightningPruningCallback
from functools import partial

# Define the objective function for hyperparameter tuning
def objective(trial, model, train_dataloader, valid_dataloader, device):
    """Optuna objective: fine-tune a copy of ``model`` and return its validation loss.

    Every trial trains its own deep copy of ``model``. Training the shared
    model in place would make each trial continue from the weights left by
    the previous one, so losses from different hyperparameters would not be
    comparable. The ``model`` passed in is left unchanged.

    Args:
        trial: Optuna trial used to sample the learning rate and the number
            of epochs, and passed on to ``train_model`` for pruning.
        model: Model that each trial starts from.
        train_dataloader: Batches used for training.
        valid_dataloader: Batches used for the returned validation loss.
        device: Device the model and batches are moved to.

    Returns:
        The mean validation loss per batch after training.

    Raises:
        optuna.exceptions.TrialPruned: Propagated from ``train_model`` when
            the trial is pruned.
    """
    import copy  # local import: only this function needs it

    # Define hyperparameters to be tuned
    lr = trial.suggest_float("learning_rate", 1e-5, 1e-3, log=True)
    num_epochs = trial.suggest_int("num_epochs", 1, 5)

    # Start every trial from the same weights.
    trial_model = copy.deepcopy(model)

    # Configure the optimizer with the suggested learning rate
    optimizer = AdamW(trial_model.parameters(), lr=lr, eps=1e-8)

    # Train the copy with the current hyperparameters
    train_model(trial_model, optimizer, train_dataloader, valid_dataloader, num_epochs, device, trial)

    # Evaluate the trained copy on the validation set
    trial_model.eval()
    total_eval_loss = 0
    with torch.no_grad():
        for batch in valid_dataloader:
            input_ids, attention_mask, labels = tuple(t.to(device) for t in batch)
            outputs = trial_model(input_ids, attention_mask=attention_mask, labels=labels)
            total_eval_loss += outputs[0].item()

    avg_eval_loss = total_eval_loss / len(valid_dataloader)

    return avg_eval_loss

def train_model(model, optimizer, train_dataloader, valid_dataloader, num_epochs, device, trial=None):
    """Fine-tune ``model``, optionally reporting to an Optuna trial for pruning.

    Args:
        model: Model whose first output is the loss when called with
            ``labels``.
        optimizer: Optimizer over the model's parameters.
        train_dataloader: Yields (input_ids, attention_mask, labels) batches
            used for the parameter updates.
        valid_dataloader: Yields batches of the same form, used for the
            validation loss only.
        num_epochs: Number of passes over ``train_dataloader``.
        device: Device the model and every batch are moved to.
        trial: Optional Optuna trial. When given, the mean validation loss of
            each epoch is reported to it and the trial may be pruned; when
            ``None``, the per-epoch losses are printed instead.

    Raises:
        optuna.exceptions.TrialPruned: If ``trial`` is given and Optuna
            decides to prune after an epoch.
    """
    model.to(device)

    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader) * num_epochs)

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        total_train_loss = 0

        for step, batch in enumerate(train_dataloader):
            input_ids, attention_mask, labels = tuple(t.to(device) for t in batch)
            model.zero_grad()

            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs[0]
            total_train_loss += loss.item()
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()

        avg_train_loss = total_train_loss / len(train_dataloader)

        # Validation phase
        model.eval()
        total_eval_loss = 0

        for batch in valid_dataloader:
            input_ids, attention_mask, labels = tuple(t.to(device) for t in batch)
            with torch.no_grad():
                outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs[0]
            total_eval_loss += loss.item()

        avg_eval_loss = total_eval_loss / len(valid_dataloader)

        if trial is None:
            print(f"Epoch: {epoch + 1}, Train Loss: {avg_train_loss}, Validation Loss: {avg_eval_loss}")
        else:
            # Report once per epoch, using the full-epoch mean. Optuna keeps
            # only the first value reported for a given step, so reporting
            # inside the batch loop would record (and prune on) the loss of
            # the first validation batch alone.
            trial.report(avg_eval_loss, epoch)
            if trial.should_prune():
                raise optuna.exceptions.TrialPruned()


# Optuna study for hyperparameter tuning
# direction="minimize" because the objective returns the mean validation loss.
study = optuna.create_study(direction="minimize")
# Bind the ELECTRA model, its dataloaders and the device so that Optuna only
# has to supply the `trial` argument.
objective_with_data = partial(objective, model=electra_model, train_dataloader=train_dataloader_electra, valid_dataloader=valid_dataloader_electra, device=device)
study.optimize(objective_with_data, n_trials=10)

# Print the best hyperparameters
print("Best hyperparameters: ", study.best_params)
