In [3]:
import numpy as np
import torch
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
from transformers import TrainingArguments
from transformers import Trainer
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.metrics import classification_report
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=9)
import pandas as pd
from nltk.tokenize import word_tokenize
from nltk.tokenize import sent_tokenize

from nltk.corpus import stopwords
import nltk
nltk.download('punkt_tab')
nltk.download('stopwords')
# remove stopwords

df = pd.read_csv('USAirlinesTweets.csv', encoding='latin1')
df=df[df['sentiment']=='negative'][df['negativereason']!="Can't Tell"] # filtering dataframe
label_mapping = {"Flight Attendant Complaints": 0, "Bad Flight": 1, "Customer Service Issue": 2, 'Lost Luggage': 3, 'Late Flight': 4, 'Damaged Luggage': 5, 'Cancelled Flight': 6, 'longlines': 7, 'Flight Booking Problems': 8}
news_quality = df['negativereason'].map(label_mapping)
news_text = df['tweet'].tolist()
x=[]
words=stopwords.words("english")
words+=['@','!','.',',','?','#']
for i in news_text:
  tweet=i
  tweet=word_tokenize(tweet.lower())
  tweet=[j for j in tweet if j not in words]
  y=' '.join(tweet)
  for i in ['@','!','.',',','?','#']:
    y=y.replace(i,'')
  x.append(y)
news_text=x


train_news_text, test_news_text, train_news_quality, test_news_quality = train_test_split(
    news_text, news_quality, test_size=0.3, random_state=42)

encodings = tokenizer(train_news_text, truncation=True, padding=True, return_tensors="pt")
encodings_test = tokenizer(test_news_text, truncation = True, padding=True, return_tensors = "pt")
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels
    def __len__(self):
        return len(self.labels)
    def __getitem__(self, idx):
        item = {key: val[idx] for key, val in self.encodings.items()}
        item['labels'] = self.labels[idx]
        return item
train_dataset = CustomDataset(encodings, torch.tensor(train_news_quality.values))
test_dataset = CustomDataset(encodings_test, torch.tensor(test_news_quality.values))
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    probs = torch.nn.functional.softmax(torch.tensor(logits), dim=-1).numpy()
    predictions = np.argmax(probs, axis=1)  # pick class with highest probability

    report = classification_report(labels, predictions, output_dict=True)
    acc = accuracy_score(labels, predictions)

    metrics = {
        "accuracy": acc,
        "f1_weighted": report["weighted avg"]["f1-score"],
        "precision_weighted": report["weighted avg"]["precision"],
        "recall_weighted": report["weighted avg"]["recall"]
    }
    for class_label in ["0", "1", "2", "3", "4", "5", "6", "7", "8"]:
        if class_label in report: # Check if class label exists in the report
            metrics[f"precision_class_{class_label}"] = report[class_label]["precision"]
            metrics[f"recall_class_{class_label}"] = report[class_label]["recall"]
            metrics[f"f1_class_{class_label}"] = report[class_label]["f1-score"]
            metrics[f"support_class_{class_label}"] = report[class_label]["support"]
        else: # If class label doesn't exist, set metrics to 0
            metrics[f"precision_class_{class_label}"] = 0.0
            metrics[f"recall_class_{class_label}"] = 0.0
            metrics[f"f1_class_{class_label}"] = 0.0
            metrics[f"support_class_{class_label}"] = 0.0
    return metrics

training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=8,
    num_train_epochs=5,
    learning_rate=1e-5
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics)
trainer.train()
trainer.evaluate()
results = trainer.evaluate()

print(results)
preds = trainer.predict(train_dataset)
preds = np.argmax(preds[:3][0],axis=1)
GT = train_news_quality
print(classification_report(GT,preds))
preds = trainer.predict(test_dataset)
preds = np.argmax(preds[:3][0],axis=1)
GT = test_news_quality
print(classification_report(GT,preds))

sentences=['The airport lost my bags, would not recommend.', \
           'The customer service was horrendous and didn\'t listen to me, I demand my money back!', \
           'I wanted to book some seats all in a row for my family, but the website wouldn\'t let me even though they all showed as available.', \
           "Why is there still only one terminal for this section? I have to wait forever to get to my flight.", \
           "I had instruments in my luggage, but they were all damaged when I got them back. I demand to be compensated for this!", \
           "The flight attendants were very rude to me and refused to serve me snacks.", \
           ]
print(sentences)
x=CustomDataset(tokenizer(sentences, truncation=True, padding=True, return_tensors="pt"), torch.tensor(np.array([1]*len(sentences))))
preds2 = trainer.predict(x)
print(preds2)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
  df=df[df['sentiment']=='negative'][df['negativereason']!="Can't Tell"] # filtering dataframe


Step,Training Loss
500,1.5605
1000,1.0734
1500,0.9227
2000,0.8527
2500,0.7404
3000,0.71




Step,Training Loss
500,1.5605
1000,1.0734
1500,0.9227
2000,0.8527
2500,0.7404
3000,0.71




  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{'eval_loss': 1.032793641090393, 'eval_accuracy': 0.6896120150187734, 'eval_f1_weighted': 0.6754757305107031, 'eval_precision_weighted': 0.6646066836033765, 'eval_recall_weighted': 0.6896120150187734, 'eval_precision_class_0': 0.48717948717948717, 'eval_recall_class_0': 0.40425531914893614, 'eval_f1_class_0': 0.4418604651162791, 'eval_support_class_0': 141.0, 'eval_precision_class_1': 0.5094339622641509, 'eval_recall_class_1': 0.5192307692307693, 'eval_f1_class_1': 0.5142857142857142, 'eval_support_class_1': 156.0, 'eval_precision_class_2': 0.7330595482546202, 'eval_recall_class_2': 0.796875, 'eval_f1_class_2': 0.7636363636363637, 'eval_support_class_2': 896.0, 'eval_precision_class_3': 0.7149122807017544, 'eval_recall_class_3': 0.7309417040358744, 'eval_f1_class_3': 0.7228381374722838, 'eval_support_class_3': 223.0, 'eval_precision_class_4': 0.6962699822380106, 'eval_recall_class_4': 0.7701375245579568, 'eval_f1_class_4': 0.7313432835820896, 'eval_support_class_4': 509.0, 'eval_precis

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


              precision    recall  f1-score   support

           0       0.71      0.75      0.73       340
           1       0.77      0.79      0.78       424
           2       0.88      0.91      0.89      2014
           3       0.84      0.88      0.86       501
           4       0.84      0.92      0.87      1156
           5       0.00      0.00      0.00        54
           6       0.86      0.82      0.84       592
           7       1.00      0.04      0.07       130
           8       0.76      0.71      0.73       380

    accuracy                           0.84      5591
   macro avg       0.74      0.65      0.64      5591
weighted avg       0.83      0.84      0.82      5591



  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


              precision    recall  f1-score   support

           0       0.49      0.40      0.44       141
           1       0.51      0.52      0.51       156
           2       0.73      0.80      0.76       896
           3       0.71      0.73      0.72       223
           4       0.70      0.77      0.73       509
           5       0.00      0.00      0.00        20
           6       0.78      0.74      0.76       255
           7       0.00      0.00      0.00        48
           8       0.50      0.39      0.44       149

    accuracy                           0.69      2397
   macro avg       0.49      0.48      0.49      2397
weighted avg       0.66      0.69      0.68      2397

['The airport lost my bags, would not recommend.', "The customer service was horrendous and didn't listen to me, I demand my money back!", "I wanted to book some seats all in a row for my family, but the website wouldn't let me even though they all showed as available.", 'Why is there still onl

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


PredictionOutput(predictions=array([[-0.39649984, -1.2392728 , -0.41867024,  3.775128  , -0.76235855,
         0.15429991, -0.30775326, -0.35913384, -0.87179023],
       [-0.31218076, -1.1314658 ,  4.722972  , -0.698938  , -0.57097685,
        -1.643983  , -0.60682726, -0.701203  , -0.11647189],
       [-0.48680884,  0.46946236,  1.3060435 , -0.8938089 , -0.77265763,
        -1.490014  , -0.18414132, -0.9677125 ,  1.9347365 ],
       [-0.519021  , -0.6651811 ,  0.60306764, -0.6721138 ,  1.7523743 ,
        -1.4529951 ,  0.8087218 , -0.03857224, -0.86727756],
       [-0.1449628 , -0.59634924,  0.09826794,  2.5973651 , -0.4798279 ,
        -0.07439104, -0.61499524, -0.52286494, -0.69575477],
       [ 2.6968806 ,  0.01893388,  0.84494245, -0.8049284 , -0.519119  ,
        -1.1598305 , -0.90577173, -0.16752988, -1.1670196 ]],
      dtype=float32), label_ids=array([1, 1, 1, 1, 1, 1]), metrics={'test_loss': 3.8463058471679688, 'test_accuracy': 0.0, 'test_f1_weighted': 0.0, 'test_precision_we