In [44]:
# Importing necessary libraries
import os
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertModel, AdamW, get_linear_schedule_with_warmup
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
import pandas as pd

In [45]:
# Loading data
def load_data(data_file):
    df = pd.read_csv(data_file)
    texts = df['text'].tolist()
    labels = df['author'].tolist()
    return texts, labels

data_file = "/kaggle/input/persian-authors-preprocessed/persian_authors_preprocessed.csv"  # Provide the path to your dataset file
texts, labels = load_data(data_file)

In [46]:
unique_labels = sorted(list(set(labels)))
print(unique_labels)

dict_label = {}
for l in range(len(unique_labels)) :
    dict_label[int(l)] = unique_labels[l]
    dict_label[unique_labels[l]] = int(l)
    print(unique_labels[l] ,str(l) )

['eraghi', 'ferdousi', 'jami', 'jooya', 'moulavi', 'nezami', 'rahi', 'saadi', 'saeb', 'shahriar']
eraghi 0
ferdousi 1
jami 2
jooya 3
moulavi 4
nezami 5
rahi 6
saadi 7
saeb 8
shahriar 9


In [47]:
class TextClassificationDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length , dict_label={} , overwrite_dict = True):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.unique_labels = sorted(list(set(labels)))
        self.unique_labels = list(self.unique_labels)
        self.dict_label = dict_label
        if overwrite_dict:
            self.dict_label = {}
            for l in range(len(self.unique_labels)) :
                self.dict_label[int(l)] = self.unique_labels[l]
                self.dict_label[self.unique_labels[l]] = int(l)
            # Check unique labels
                unique_labels = set(labels)
                num_classes = len(unique_labels)
                print("Unique labels:", unique_labels)
                print("Number of classes:", num_classes)
            label_counts = {label: labels.count(label) for label in unique_labels}
            print("Label counts:", label_counts)
            print(self.dict_label)
    
    def get_dict(self):
        return self.dict_label


    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
#         print("index : " , idx)
        label =self.dict_label[ self.labels[idx] ] # Convert label to integer
        encoding = self.tokenizer(text, return_tensors='pt', max_length=self.max_length, padding='max_length', truncation=True)
        return {'input_ids': encoding['input_ids'].flatten(), 'attention_mask': encoding['attention_mask'].flatten(), 'label': torch.tensor(label)}


In [48]:
# Defining the BERT classifier model
class BERTClassifier(nn.Module):
    def __init__(self, bert_model_name, num_classes):
        super(BERTClassifier, self).__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.dropout = nn.Dropout(0.1)
        self.fc = nn.Linear(self.bert.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        x = self.dropout(pooled_output)
        logits = self.fc(x)
        return logits

In [49]:
def train(model, data_loader, optimizer, scheduler, device):
    model.train()
    for batch in data_loader:
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        loss = nn.CrossEntropyLoss()(outputs, labels)
        loss.backward()
        optimizer.step()
        scheduler.step()

In [50]:
# Evaluation function
def evaluate(model, data_loader, device):
    model.eval()
    predictions = []
    actual_labels = []
    with torch.no_grad():
        for batch in data_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            _, preds = torch.max(outputs, dim=1)
            predictions.extend(preds.cpu().tolist())
            actual_labels.extend(labels.cpu().tolist())
    return accuracy_score(actual_labels, predictions), classification_report(actual_labels, predictions)


In [51]:
# Function for predicting sentiment
def predict_sentiment(text, model, tokenizer, device, max_length=128,dict_labe={}):
    model.eval()
    encoding = tokenizer(text, return_tensors='pt', max_length=max_length, padding='max_length', truncation=True)
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        _, preds = torch.max(outputs, dim=1)
    return dict_label[preds.item()]


In [52]:

# Set up parameters
bert_model_name = 'bert-base-uncased'
num_classes = len(set(labels))
print("num_classes",num_classes)
max_length = 512
batch_size = 16
num_epochs = 100
learning_rate = 2e-5


num_classes 10


In [53]:
# Splitting data into train and validation sets
train_texts, val_texts, train_labels, val_labels = train_test_split(texts, labels, test_size=0.2, random_state=42)


In [54]:

# Initializing tokenizer and datasets
tokenizer = BertTokenizer.from_pretrained(bert_model_name)
train_dataset = TextClassificationDataset(train_texts, train_labels, tokenizer, max_length)
dict_label = train_dataset.get_dict()
val_dataset = TextClassificationDataset(val_texts, val_labels, tokenizer, max_length,dict_label , overwrite_dict=False)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size)


Unique labels: {'rahi', 'ferdousi', 'jooya', 'jami', 'moulavi', 'eraghi', 'saadi', 'nezami', 'shahriar', 'saeb'}
Number of classes: 10
Unique labels: {'rahi', 'ferdousi', 'jooya', 'jami', 'moulavi', 'eraghi', 'saadi', 'nezami', 'shahriar', 'saeb'}
Number of classes: 10
Unique labels: {'rahi', 'ferdousi', 'jooya', 'jami', 'moulavi', 'eraghi', 'saadi', 'nezami', 'shahriar', 'saeb'}
Number of classes: 10
Unique labels: {'rahi', 'ferdousi', 'jooya', 'jami', 'moulavi', 'eraghi', 'saadi', 'nezami', 'shahriar', 'saeb'}
Number of classes: 10
Unique labels: {'rahi', 'ferdousi', 'jooya', 'jami', 'moulavi', 'eraghi', 'saadi', 'nezami', 'shahriar', 'saeb'}
Number of classes: 10
Unique labels: {'rahi', 'ferdousi', 'jooya', 'jami', 'moulavi', 'eraghi', 'saadi', 'nezami', 'shahriar', 'saeb'}
Number of classes: 10
Unique labels: {'rahi', 'ferdousi', 'jooya', 'jami', 'moulavi', 'eraghi', 'saadi', 'nezami', 'shahriar', 'saeb'}
Number of classes: 10
Unique labels: {'rahi', 'ferdousi', 'jooya', 'jami', 'm

In [55]:
# Setting up device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BERTClassifier(bert_model_name, num_classes).to(device)


In [56]:
# Initializing optimizer and scheduler
optimizer = AdamW(model.parameters(), lr=learning_rate,no_deprecation_warning=True)
total_steps = len(train_dataloader) * num_epochs
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)


In [57]:
# Training the model
for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")
    train(model, train_dataloader, optimizer, scheduler, device)
    accuracy, report = evaluate(model, val_dataloader, device)
    print(f"Validation Accuracy: {accuracy:.4f}")
    print(report)

Epoch 1/100


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Validation Accuracy: 0.0484
              precision    recall  f1-score   support

           0       0.00      0.00      0.00         7
           1       0.00      0.00      0.00         8
           2       0.00      0.00      0.00         9
           3       0.05      1.00      0.09         3
           4       0.00      0.00      0.00         5
           5       0.00      0.00      0.00         3
           6       0.00      0.00      0.00         4
           7       0.00      0.00      0.00         9
           8       0.00      0.00      0.00         7
           9       0.00      0.00      0.00         7

    accuracy                           0.05        62
   macro avg       0.00      0.10      0.01        62
weighted avg       0.00      0.05      0.00        62

Epoch 2/100


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Validation Accuracy: 0.0484
              precision    recall  f1-score   support

           0       0.00      0.00      0.00         7
           1       0.00      0.00      0.00         8
           2       0.00      0.00      0.00         9
           3       0.05      1.00      0.09         3
           4       0.00      0.00      0.00         5
           5       0.00      0.00      0.00         3
           6       0.00      0.00      0.00         4
           7       0.00      0.00      0.00         9
           8       0.00      0.00      0.00         7
           9       0.00      0.00      0.00         7

    accuracy                           0.05        62
   macro avg       0.00      0.10      0.01        62
weighted avg       0.00      0.05      0.00        62

Epoch 3/100


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Validation Accuracy: 0.0484
              precision    recall  f1-score   support

           0       0.00      0.00      0.00         7
           1       0.00      0.00      0.00         8
           2       0.00      0.00      0.00         9
           3       0.00      0.00      0.00         3
           4       0.00      0.00      0.00         5
           5       0.05      1.00      0.09         3
           6       0.00      0.00      0.00         4
           7       0.00      0.00      0.00         9
           8       0.00      0.00      0.00         7
           9       0.00      0.00      0.00         7

    accuracy                           0.05        62
   macro avg       0.00      0.10      0.01        62
weighted avg       0.00      0.05      0.00        62

Epoch 4/100


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Validation Accuracy: 0.0323
              precision    recall  f1-score   support

           0       0.00      0.00      0.00         7
           1       0.00      0.00      0.00         8
           2       0.00      0.00      0.00         9
           3       0.05      0.33      0.09         3
           4       0.00      0.00      0.00         5
           5       0.02      0.33      0.04         3
           6       0.00      0.00      0.00         4
           7       0.00      0.00      0.00         9
           8       0.00      0.00      0.00         7
           9       0.00      0.00      0.00         7

    accuracy                           0.03        62
   macro avg       0.01      0.07      0.01        62
weighted avg       0.00      0.03      0.01        62

Epoch 5/100


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Validation Accuracy: 0.2258
              precision    recall  f1-score   support

           0       0.00      0.00      0.00         7
           1       0.00      0.00      0.00         8
           2       0.00      0.00      0.00         9
           3       0.50      0.33      0.40         3
           4       0.67      0.40      0.50         5
           5       0.00      0.00      0.00         3
           6       0.19      1.00      0.32         4
           7       0.33      0.11      0.17         9
           8       0.00      0.00      0.00         7
           9       0.19      0.86      0.31         7

    accuracy                           0.23        62
   macro avg       0.19      0.27      0.17        62
weighted avg       0.16      0.23      0.14        62

Epoch 6/100


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Validation Accuracy: 0.1290
              precision    recall  f1-score   support

           0       0.00      0.00      0.00         7
           1       0.00      0.00      0.00         8
           2       0.00      0.00      0.00         9
           3       0.00      0.00      0.00         3
           4       0.00      0.00      0.00         5
           5       0.00      0.00      0.00         3
           6       0.13      1.00      0.24         4
           7       1.00      0.44      0.62         9
           8       0.00      0.00      0.00         7
           9       0.00      0.00      0.00         7

    accuracy                           0.13        62
   macro avg       0.11      0.14      0.09        62
weighted avg       0.15      0.13      0.10        62

Epoch 7/100


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Validation Accuracy: 0.1935
              precision    recall  f1-score   support

           0       0.22      1.00      0.36         7
           1       0.00      0.00      0.00         8
           2       0.00      0.00      0.00         9
           3       0.00      0.00      0.00         3
           4       0.00      0.00      0.00         5
           5       0.20      0.67      0.31         3
           6       0.00      0.00      0.00         4
           7       1.00      0.22      0.36         9
           8       0.20      0.14      0.17         7
           9       0.00      0.00      0.00         7

    accuracy                           0.19        62
   macro avg       0.16      0.20      0.12        62
weighted avg       0.20      0.19      0.13        62

Epoch 8/100


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Validation Accuracy: 0.3387
              precision    recall  f1-score   support

           0       0.50      0.86      0.63         7
           1       0.67      0.25      0.36         8
           2       0.00      0.00      0.00         9
           3       0.21      1.00      0.35         3
           4       0.00      0.00      0.00         5
           5       0.12      0.67      0.20         3
           6       0.00      0.00      0.00         4
           7       0.60      0.67      0.63         9
           8       0.25      0.14      0.18         7
           9       0.50      0.14      0.22         7

    accuracy                           0.34        62
   macro avg       0.28      0.37      0.26        62
weighted avg       0.33      0.34      0.28        62

Epoch 9/100


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Validation Accuracy: 0.3548
              precision    recall  f1-score   support

           0       0.43      0.86      0.57         7
           1       0.53      1.00      0.70         8
           2       1.00      0.11      0.20         9
           3       0.14      0.67      0.24         3
           4       0.00      0.00      0.00         5
           5       0.00      0.00      0.00         3
           6       0.00      0.00      0.00         4
           7       0.67      0.44      0.53         9
           8       0.25      0.14      0.18         7
           9       0.00      0.00      0.00         7

    accuracy                           0.35        62
   macro avg       0.30      0.32      0.24        62
weighted avg       0.39      0.35      0.29        62

Epoch 10/100


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Validation Accuracy: 0.4194
              precision    recall  f1-score   support

           0       0.60      0.86      0.71         7
           1       0.38      1.00      0.55         8
           2       0.00      0.00      0.00         9
           3       0.18      0.67      0.29         3
           4       0.00      0.00      0.00         5
           5       0.00      0.00      0.00         3
           6       0.00      0.00      0.00         4
           7       0.50      0.67      0.57         9
           8       0.57      0.57      0.57         7
           9       0.00      0.00      0.00         7

    accuracy                           0.42        62
   macro avg       0.22      0.38      0.27        62
weighted avg       0.26      0.42      0.31        62

Epoch 11/100
Validation Accuracy: 0.4839
              precision    recall  f1-score   support

           0       0.55      0.86      0.67         7
           1       1.00      0.62      0.77         8
         

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Validation Accuracy: 0.5323
              precision    recall  f1-score   support

           0       0.60      0.43      0.50         7
           1       0.73      1.00      0.84         8
           2       0.50      0.33      0.40         9
           3       0.30      1.00      0.46         3
           4       1.00      0.40      0.57         5
           5       0.00      0.00      0.00         3
           6       0.50      1.00      0.67         4
           7       0.55      0.67      0.60         9
           8       0.50      0.43      0.46         7
           9       0.33      0.14      0.20         7

    accuracy                           0.53        62
   macro avg       0.50      0.54      0.47        62
weighted avg       0.53      0.53      0.50        62

Epoch 13/100


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Validation Accuracy: 0.4355
              precision    recall  f1-score   support

           0       0.75      0.43      0.55         7
           1       1.00      0.62      0.77         8
           2       0.00      0.00      0.00         9
           3       0.18      1.00      0.30         3
           4       0.40      0.80      0.53         5
           5       0.00      0.00      0.00         3
           6       0.57      1.00      0.73         4
           7       0.64      0.78      0.70         9
           8       1.00      0.14      0.25         7
           9       0.00      0.00      0.00         7

    accuracy                           0.44        62
   macro avg       0.45      0.48      0.38        62
weighted avg       0.50      0.44      0.40        62

Epoch 14/100


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Validation Accuracy: 0.5161
              precision    recall  f1-score   support

           0       1.00      0.29      0.44         7
           1       0.80      1.00      0.89         8
           2       1.00      0.22      0.36         9
           3       0.30      1.00      0.46         3
           4       0.67      0.40      0.50         5
           5       0.00      0.00      0.00         3
           6       0.50      1.00      0.67         4
           7       0.50      0.67      0.57         9
           8       0.40      0.57      0.47         7
           9       0.20      0.14      0.17         7

    accuracy                           0.52        62
   macro avg       0.54      0.53      0.45        62
weighted avg       0.60      0.52      0.48        62

Epoch 15/100
Validation Accuracy: 0.4516
              precision    recall  f1-score   support

           0       0.62      0.71      0.67         7
           1       0.57      1.00      0.73         8
         

In [58]:
# Saving the trained model
torch.save(model.state_dict(), "bert_classifier.pth")


In [60]:
# Testing sentiment prediction
accuracy =0
for i in range(len(val_texts)):
    test_text = val_texts[i]
    sentiment = predict_sentiment(test_text, model, tokenizer, device,dict_labe=dict_label)
    if sentiment == val_labels[i]:accuracy +=1
    print(val_labels[i] )
    print(f"Predicted sentiment: {sentiment}")
print(f"accuracy: {accuracy/len(val_texts)} ")


eraghi
Predicted sentiment: eraghi
rahi
Predicted sentiment: eraghi
nezami
Predicted sentiment: ferdousi
jami
Predicted sentiment: jami
saadi
Predicted sentiment: saadi
shahriar
Predicted sentiment: nezami
ferdousi
Predicted sentiment: ferdousi
jami
Predicted sentiment: jami
eraghi
Predicted sentiment: eraghi
jami
Predicted sentiment: shahriar
eraghi
Predicted sentiment: rahi
jami
Predicted sentiment: moulavi
saeb
Predicted sentiment: jami
shahriar
Predicted sentiment: ferdousi
saadi
Predicted sentiment: jami
moulavi
Predicted sentiment: rahi
ferdousi
Predicted sentiment: ferdousi
moulavi
Predicted sentiment: moulavi
ferdousi
Predicted sentiment: ferdousi
saadi
Predicted sentiment: ferdousi
moulavi
Predicted sentiment: rahi
moulavi
Predicted sentiment: moulavi
saeb
Predicted sentiment: nezami
jooya
Predicted sentiment: jami
rahi
Predicted sentiment: saeb
saeb
Predicted sentiment: saeb
saadi
Predicted sentiment: jami
eraghi
Predicted sentiment: jami
jooya
Predicted sentiment: saeb
eragh