In [2]:
from datasets import Dataset, DatasetDict
from sklearn.preprocessing import LabelEncoder
import numpy as np
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification, TrainingArguments, Trainer
from sklearn.metrics import accuracy_score, f1_score

In [3]:
# Load the labelled splits; later cells read their 'comment' and 'class_label' columns.
train_df, test_df = (pd.read_csv(path) for path in ('train_data.csv', 'test.csv'))

In [4]:
class_encoding = {'CAG': 0, 'NAG': 1, 'OAG': 2}
# Map string classes to integer ids. Fail loudly on any value missing from
# class_encoding: .map() would otherwise yield NaN for it and silently turn the
# whole 'label' column into floats.
for split_name, split_df in (('train', train_df), ('test', test_df)):
    encoded = split_df['class_label'].map(class_encoding)
    unmapped = split_df.loc[encoded.isna(), 'class_label'].unique()
    if len(unmapped):
        raise ValueError(f"{split_name}: unmapped class_label values {list(unmapped)}")
    split_df['label'] = encoded.astype('int64')

In [5]:
# Keep only the text and the integer target for the modelling steps.
keep_cols = ["comment", "label"]
train_df, test_df = train_df[keep_cols], test_df[keep_cols]

In [6]:
# Wrap each pandas split as a Hugging Face Dataset.
train_dataset, test_dataset = map(Dataset.from_pandas, (train_df, test_df))

In [7]:
# Bundle both splits so a single .map() call processes train and test together.
dataset = DatasetDict(train=train_dataset, test=test_dataset)

In [8]:
# Display the DatasetDict to confirm split names, columns and row counts.
dataset

DatasetDict({
    train: Dataset({
        features: ['comment', 'label'],
        num_rows: 75
    })
    test: Dataset({
        features: ['comment', 'label'],
        num_rows: 90
    })
})

In [9]:
model_ckpt = 'ai4bharat/indic-bert'
# keep_accents=True turns off the ALBERT tokenizer's accent stripping (its
# default is False), so diacritics / combining marks in the text are preserved.
tokenizer = AutoTokenizer.from_pretrained(
    model_ckpt,
    keep_accents=True,
)

In [10]:
def tokenize(batch):
    """Tokenize the 'comment' texts of a batch with the global tokenizer.

    Every sequence is truncated or padded to exactly 512 tokens, so all
    examples share one fixed length.
    """
    texts = batch['comment']
    encoding = tokenizer(
        texts,
        max_length=512,
        truncation=True,
        padding='max_length',
    )
    return encoding

In [11]:
# batch_size=None hands each split to `tokenize` as one single batch.
dataset_encoded = dataset.map(
    function=tokenize,
    batched=True,
    batch_size=None,
)

Map: 100%|██████████| 75/75 [00:00<00:00, 3655.70 examples/s]
Map: 100%|██████████| 90/90 [00:00<00:00, 6116.23 examples/s]


### Our classifier: IndicBERT features with classical models, then fine-tuning

In [12]:
import torch 
from transformers import AutoModel
model_ckpt = 'ai4bharat/indic-bert'
# Pick the accelerator that is actually present: CUDA, then Apple-silicon MPS,
# then CPU. Each backend is checked with its own availability test.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
model = AutoModel.from_pretrained(model_ckpt).to(device)

In [13]:
# Rich display via the last expression instead of print(); renders the same
# AlbertConfig JSON shown below.
model.config

AlbertConfig {
  "_name_or_path": "ai4bharat/indic-bert",
  "attention_probs_dropout_prob": 0,
  "bos_token_id": 2,
  "classifier_dropout_prob": 0.1,
  "down_scale_factor": 1,
  "embedding_size": 128,
  "eos_token_id": 3,
  "gap_size": 0,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "inner_group_num": 1,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "albert",
  "net_structure_type": 0,
  "num_attention_heads": 12,
  "num_hidden_groups": 1,
  "num_hidden_layers": 12,
  "num_memory_blocks": 0,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.40.0",
  "type_vocab_size": 2,
  "vocab_size": 200000
}



In [14]:
def extract_hidden_states(batch):
    """Return the final-layer embedding at position 0 for each example in a batch.

    Used with ``Dataset.map(..., batched=True)`` on the torch-formatted dataset
    below. Only keys listed in ``tokenizer.model_input_names`` are moved to the
    global ``device`` and fed to the global ``model``; with the format set
    below, that means input_ids and attention_mask (label is filtered out).

    Returns:
        {'hidden_state': ndarray of shape (batch, hidden_size)} taken from
        position 0 of ``last_hidden_state`` (the [CLS] token the ALBERT
        tokenizer prepends).
    """
    inputs = {k: v.to(device) for k, v in batch.items()
              if k in tokenizer.model_input_names}
    # Pure feature extraction: no gradients are needed.
    with torch.no_grad():
        last_hidden_state = model(**inputs).last_hidden_state
    return {'hidden_state': last_hidden_state[:, 0].cpu().numpy()}

dataset_encoded.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])
# Bounded batches: with the default batch_size (1000), every example of a split,
# padded to 512 tokens, would go through the encoder in one forward pass.
dataset_hidden_states = dataset_encoded.map(extract_hidden_states, batched=True, batch_size=16)

Map: 100%|██████████| 75/75 [00:20<00:00,  3.67 examples/s]
Map: 100%|██████████| 90/90 [00:23<00:00,  3.84 examples/s]


In [15]:
import torch
# Sanity check: allocate a tiny tensor on the MPS backend if it is present.
if not torch.backends.mps.is_available():
    print ("MPS device not found.")
else:
    mps_device = torch.device("mps")
    x = torch.ones(1, device=mps_device)
    print (x)

tensor([1.], device='mps:0')


In [16]:
import torch 
from transformers import AutoModel
# Prefer CUDA, then Apple-silicon MPS, then CPU, so the same notebook uses
# whichever accelerator the machine actually has.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
model = AutoModel.from_pretrained(model_ckpt).to(device)

In [17]:
# Inspect the splits after feature extraction; 'hidden_state' is the column added by extract_hidden_states.
dataset_hidden_states

DatasetDict({
    train: Dataset({
        features: ['comment', 'label', 'input_ids', 'token_type_ids', 'attention_mask', 'hidden_state'],
        num_rows: 75
    })
    test: Dataset({
        features: ['comment', 'label', 'input_ids', 'token_type_ids', 'attention_mask', 'hidden_state'],
        num_rows: 90
    })
})

In [18]:
import numpy as np
# Frozen-encoder features: the hidden states are the inputs, the labels the targets.
train_split, valid_split = dataset_hidden_states['train'], dataset_hidden_states['test']
X_train, y_train = np.array(train_split['hidden_state']), np.array(train_split['label'])
X_valid, y_valid = np.array(valid_split['hidden_state']), np.array(valid_split['label'])
X_train.shape, X_valid.shape

((75, 768), (90, 768))

In [19]:
from sklearn.linear_model import LogisticRegression
# Baseline linear probe on the embeddings; fit() returns the estimator itself.
lr_clf = LogisticRegression(max_iter=3000).fit(X_train, y_train)
round(lr_clf.score(X_valid, y_valid), 3)

0.4

In [20]:
from sklearn.svm import SVC

# RBF-kernel SVM with default hyperparameters on the same embeddings.
svm_clf = SVC().fit(X_train, y_train)
svm_score = svm_clf.score(X_valid, y_valid)
print(f"SVM accuracy score: {round(svm_score, 3)}")

SVM accuracy score: 0.389


In [21]:
from sklearn.ensemble import RandomForestClassifier

# Fixed random_state so the bootstrap samples and feature subsets (and hence
# the reported score) are reproducible across runs.
rf_clf = RandomForestClassifier(random_state=42)
rf_clf.fit(X_train, y_train)

rf_score = rf_clf.score(X_valid, y_valid)
print(f"Random Forest accuracy score: {round(rf_score, 3)}")

Random Forest accuracy score: 0.478


In [22]:
from transformers import AlbertForSequenceClassification
# Derive the label count and names from the mapping used to encode the labels,
# so the classification head and the saved config stay in sync with it.
num_labels = len(class_encoding)
id2label = {idx: name for name, idx in class_encoding.items()}
model = (AlbertForSequenceClassification
         .from_pretrained(model_ckpt,
                          num_labels=num_labels,
                          id2label=id2label,
                          label2id=class_encoding)
         .to(device))

Some weights of AlbertForSequenceClassification were not initialized from the model checkpoint at ai4bharat/indic-bert and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [23]:
from sklearn.metrics import accuracy_score, f1_score
def perf_metrics(pred):
    """Compute accuracy and weighted F1 for a Trainer evaluation.

    Args:
        pred: object with ``label_ids`` and ``predictions`` (a logits array,
            or a tuple whose first element is the logits).

    Returns:
        dict with keys 'accuracy' and 'f1 score'; the Trainer reports them
        with an 'eval_' / 'test_' prefix.
    """
    labels = pred.label_ids
    logits = pred.predictions
    # When the model returns extra outputs, predictions arrive as a tuple
    # with the logits first.
    if isinstance(logits, tuple):
        logits = logits[0]
    preds = logits.argmax(-1)
    # zero_division=0 gives the same value as sklearn's default but avoids the
    # UndefinedMetricWarning raised when some class is never predicted.
    f1 = f1_score(labels, preds, average='weighted', zero_division=0)
    acc = accuracy_score(labels, preds)
    return {'accuracy': acc, 'f1 score': f1}

In [24]:
from transformers import Trainer, TrainingArguments
batch_size = 4 
# Log about once per epoch. max(1, ...) keeps the value valid when the training
# split is smaller than one batch: TrainingArguments rejects logging_steps=0
# under the default 'steps' logging strategy.
logging_steps = max(1, len(dataset_encoded['train']) // batch_size)
model_name = 'Custom Text Classifier'
training_args = TrainingArguments(output_dir=model_name,
                                  num_train_epochs=10, 
                                  learning_rate=5e-5,
                                  per_device_train_batch_size=batch_size,
                                  per_device_eval_batch_size=batch_size,
                                  weight_decay=0.01,
                                  evaluation_strategy='epoch',
                                  disable_tqdm=False,
                                  logging_steps=logging_steps,
                                  log_level='error')

In [25]:
# Note: the 'test' split doubles as the per-epoch evaluation set.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset_encoded['train'],
    eval_dataset=dataset_encoded['test'],
    tokenizer=tokenizer,
    compute_metrics=perf_metrics,
)

In [26]:
# Keep a handle on the TrainOutput and display it as the cell result.
train_result = trainer.train()
train_result

  9%|▉         | 18/190 [00:22<03:21,  1.17s/it]

{'loss': 1.1015, 'grad_norm': 0.6538131237030029, 'learning_rate': 4.5263157894736846e-05, 'epoch': 0.95}


 10%|█         | 19/190 [00:23<03:17,  1.15s/it]
 10%|█         | 19/190 [00:33<03:17,  1.15s/it]

{'eval_loss': 1.0977762937545776, 'eval_accuracy': 0.37777777777777777, 'eval_f1 score': 0.24951267056530213, 'eval_runtime': 9.4132, 'eval_samples_per_second': 9.561, 'eval_steps_per_second': 2.443, 'epoch': 1.0}


 19%|█▉        | 36/190 [00:53<03:04,  1.19s/it]

{'loss': 1.0963, 'grad_norm': 0.683247983455658, 'learning_rate': 4.0526315789473684e-05, 'epoch': 1.89}


                                                
 20%|██        | 38/190 [01:05<02:50,  1.12s/it]

{'eval_loss': 1.09470796585083, 'eval_accuracy': 0.35555555555555557, 'eval_f1 score': 0.2799452009978326, 'eval_runtime': 9.6232, 'eval_samples_per_second': 9.352, 'eval_steps_per_second': 2.39, 'epoch': 2.0}


 28%|██▊       | 54/190 [01:25<02:53,  1.28s/it]

{'loss': 1.0831, 'grad_norm': 3.0179920196533203, 'learning_rate': 3.578947368421053e-05, 'epoch': 2.84}


                                                
 30%|███       | 57/190 [01:38<02:37,  1.18s/it]

{'eval_loss': 1.0703274011611938, 'eval_accuracy': 0.37777777777777777, 'eval_f1 score': 0.2697248381787684, 'eval_runtime': 9.7578, 'eval_samples_per_second': 9.223, 'eval_steps_per_second': 2.357, 'epoch': 3.0}


 38%|███▊      | 72/190 [01:57<02:29,  1.27s/it]

{'loss': 0.9609, 'grad_norm': 2.96803879737854, 'learning_rate': 3.105263157894737e-05, 'epoch': 3.79}


                                                
 40%|████      | 76/190 [02:12<02:13,  1.17s/it]

{'eval_loss': 1.0513197183609009, 'eval_accuracy': 0.4666666666666667, 'eval_f1 score': 0.41517315187527953, 'eval_runtime': 10.2654, 'eval_samples_per_second': 8.767, 'eval_steps_per_second': 2.241, 'epoch': 4.0}


 47%|████▋     | 90/190 [02:33<02:36,  1.56s/it]

{'loss': 0.9128, 'grad_norm': 5.225546836853027, 'learning_rate': 2.6315789473684212e-05, 'epoch': 4.74}


                                                
 50%|█████     | 95/190 [02:57<03:35,  2.27s/it]

{'eval_loss': 0.9911127686500549, 'eval_accuracy': 0.5111111111111111, 'eval_f1 score': 0.46614318053197507, 'eval_runtime': 14.1106, 'eval_samples_per_second': 6.378, 'eval_steps_per_second': 1.63, 'epoch': 5.0}


 57%|█████▋    | 108/190 [03:22<02:42,  1.98s/it]

{'loss': 0.8147, 'grad_norm': 5.446675777435303, 'learning_rate': 2.1578947368421053e-05, 'epoch': 5.68}


                                                 
 60%|██████    | 114/190 [03:49<02:25,  1.91s/it]

{'eval_loss': 1.1856497526168823, 'eval_accuracy': 0.5111111111111111, 'eval_f1 score': 0.4958559188669348, 'eval_runtime': 14.6417, 'eval_samples_per_second': 6.147, 'eval_steps_per_second': 1.571, 'epoch': 6.0}


 66%|██████▋   | 126/190 [04:13<02:10,  2.04s/it]

{'loss': 0.7838, 'grad_norm': 13.084046363830566, 'learning_rate': 1.6842105263157896e-05, 'epoch': 6.63}


                                                 
 70%|███████   | 133/190 [04:42<01:55,  2.03s/it]

{'eval_loss': 1.1436775922775269, 'eval_accuracy': 0.5666666666666667, 'eval_f1 score': 0.5748077215961012, 'eval_runtime': 14.2867, 'eval_samples_per_second': 6.3, 'eval_steps_per_second': 1.61, 'epoch': 7.0}


 76%|███████▌  | 144/190 [05:05<01:37,  2.12s/it]

{'loss': 0.7678, 'grad_norm': 2.1491448879241943, 'learning_rate': 1.2105263157894737e-05, 'epoch': 7.58}


                                                 
 80%|████████  | 152/190 [05:36<01:09,  1.83s/it]

{'eval_loss': 1.2131448984146118, 'eval_accuracy': 0.5111111111111111, 'eval_f1 score': 0.5112956260497243, 'eval_runtime': 15.3702, 'eval_samples_per_second': 5.856, 'eval_steps_per_second': 1.496, 'epoch': 8.0}


 85%|████████▌ | 162/190 [05:56<01:00,  2.17s/it]

{'loss': 0.5842, 'grad_norm': 3.872788667678833, 'learning_rate': 7.3684210526315784e-06, 'epoch': 8.53}


                                                 
 90%|█████████ | 171/190 [06:28<00:36,  1.89s/it]

{'eval_loss': 1.1523464918136597, 'eval_accuracy': 0.5111111111111111, 'eval_f1 score': 0.5112956260497243, 'eval_runtime': 15.1302, 'eval_samples_per_second': 5.948, 'eval_steps_per_second': 1.52, 'epoch': 9.0}


 95%|█████████▍| 180/190 [06:47<00:23,  2.33s/it]

{'loss': 0.5264, 'grad_norm': 46.87462615966797, 'learning_rate': 2.631578947368421e-06, 'epoch': 9.47}


                                                 
100%|██████████| 190/190 [07:24<00:00,  2.34s/it]

{'eval_loss': 1.1298226118087769, 'eval_accuracy': 0.5333333333333333, 'eval_f1 score': 0.5293763799787896, 'eval_runtime': 15.9157, 'eval_samples_per_second': 5.655, 'eval_steps_per_second': 1.445, 'epoch': 10.0}
{'train_runtime': 444.0019, 'train_samples_per_second': 1.689, 'train_steps_per_second': 0.428, 'train_loss': 0.8464475179973402, 'epoch': 10.0}





TrainOutput(global_step=190, training_loss=0.8464475179973402, metrics={'train_runtime': 444.0019, 'train_samples_per_second': 1.689, 'train_steps_per_second': 0.428, 'total_flos': 17925348096000.0, 'train_loss': 0.8464475179973402, 'epoch': 10.0})

In [27]:
# Same split as the Trainer's eval_dataset, so these metrics repeat the final
# epoch's evaluation rather than scoring an untouched hold-out set.
preds_output = trainer.predict(test_dataset=dataset_encoded['test'])
preds_output.metrics

100%|██████████| 23/23 [00:15<00:00,  1.50it/s]


{'test_loss': 1.1298226118087769,
 'test_accuracy': 0.5333333333333333,
 'test_f1 score': 0.5293763799787896,
 'test_runtime': 16.0848,
 'test_samples_per_second': 5.595,
 'test_steps_per_second': 1.43}