In [1]:
import torch
torch.cuda.empty_cache()

In [2]:
import pandas as pd

# Load the datasets
ds_train = pd.read_pickle('pickles/ds_train.pkl')
ds_test = pd.read_pickle('pickles/ds_test.pkl')


# Rename the columns
ds_train = ds_train.drop(columns=['text'])
ds_test = ds_test.drop(columns=['text'])
ds_train = ds_train.rename(columns={'label':'og_label', 'simple_topic':'label', 'no_stopword':'text'})
ds_test = ds_test.rename(columns={'label':'og_label', 'simple_topic':'label', 'no_stopword':'text'})

# Create new datasets
from datasets import Dataset, DatasetDict
new_train = Dataset.from_pandas(ds_train[['label','text']])
new_test = Dataset.from_pandas(ds_test[['label','text']])

# Create a DatasetDict
new_ds = DatasetDict({
    'train': new_train,
    'test': new_test
})

# Save the new datasets to disk
# new_ds.save_to_disk('data')


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
ds_train

Unnamed: 0,og_label,label_text,label,preprocess,text
0,7,rec.autos,3,i was wondering if anyone out there could enli...,wondering anyone could enlighten car saw day 2...
1,4,comp.sys.mac.hardware,1,a fair number of brave souls who upgraded thei...,fair number brave souls upgraded si clock osci...
2,4,comp.sys.mac.hardware,1,well folks my mac plus finally gave up the gho...,well folks mac plus finally gave ghost weekend...
3,1,comp.graphics,1,do you have weiteks addressphone number id l...,weiteks addressphone number id like get inform...
4,14,sci.space,4,from article by tom a baker my understandi...,article tom baker understanding expected error...
...,...,...,...,...,...
11309,13,sci.med,4,dn from david nye dn a neurology dn consultat...,dn david nye dn neurology dn consultation chea...
11310,4,comp.sys.mac.hardware,1,i have a very old mac 512k and a mac plus both...,old mac 512k mac plus problem screens blank so...
11311,3,comp.sys.ibm.pc.hardware,1,i just installed a dx266 cpu in a clone mother...,installed dx266 cpu clone motherboard tried mo...
11312,1,comp.graphics,1,wouldnt this require a hypersphere in 3space...,wouldnt require hypersphere 3space 4 points sp...


In [4]:
new_train[0]

{'label': 3,
 'text': 'wondering anyone could enlighten car saw day 2door sports car looked late 60s early 70s called bricklin doors really small addition front bumper separate rest body know anyone tellme model name engine specs years production car made history whatever info funky looking car please email'}

In [5]:

# # Load model and tokenizer

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert/distilbert-base-uncased")
# model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased") 

def preprocess_function(examples):
    return tokenizer(examples["text"], truncation=True)

tokenized_train = new_train.map(preprocess_function, batched=True)
tokenized_test = new_test.map(preprocess_function, batched=True)

Map: 100%|██████████| 11314/11314 [00:01<00:00, 7174.96 examples/s]
Map: 100%|██████████| 7532/7532 [00:00<00:00, 9086.42 examples/s]


In [6]:
from transformers import DataCollatorWithPadding

data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="pt")

In [7]:
new_test[0]

{'label': 3,
 'text': 'little confused models 8889 bonnevilles heard le se lse sse ssei could someone tell differences far features performance also curious know book value prefereably 89 model much less book value usually get words much demand time year heard midspring early summer best time buy'}

In [8]:
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
# label1 = 0, etc
id2label = {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6}
label2id = {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6}
model = AutoModelForSequenceClassification.from_pretrained("distilbert/distilbert-base-uncased", num_labels=7, id2label=id2label, label2id=label2id)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert/distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [28]:
import numpy as np
import evaluate

accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)

    return accuracy.compute(predictions=predictions, references=labels)

evaluate.load("accuracy")

EvaluationModule(name: "accuracy", module_type: "metric", features: {'predictions': Value(dtype='int32', id=None), 'references': Value(dtype='int32', id=None)}, usage: """
Args:
    predictions (`list` of `int`): Predicted labels.
    references (`list` of `int`): Ground truth labels.
    normalize (`boolean`): If set to False, returns the number of correctly classified samples. Otherwise, returns the fraction of correctly classified samples. Defaults to True.
    sample_weight (`list` of `float`): Sample weights Defaults to None.

Returns:
    accuracy (`float` or `int`): Accuracy score. Minimum possible value is 0. Maximum possible value is 1.0, or the number of examples input, if `normalize` is set to `True`.. A higher score means higher accuracy.

Examples:

    Example 1-A simple example
        >>> accuracy_metric = evaluate.load("accuracy")
        >>> results = accuracy_metric.compute(references=[0, 1, 2, 0, 1, 2], predictions=[0, 1, 1, 2, 1, 0])
        >>> print(results)
    

In [10]:
training_args = TrainingArguments(
    output_dir="my_awesome_model",
    learning_rate=2e-5,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    num_train_epochs=2,
    weight_decay=0.01,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_test,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

  4%|▍         | 500/11314 [00:47<18:23,  9.80it/s]

{'loss': 1.2045, 'grad_norm': 5.472842693328857, 'learning_rate': 1.911613929644688e-05, 'epoch': 0.09}


  9%|▉         | 1002/11314 [01:33<15:45, 10.90it/s]

{'loss': 0.9352, 'grad_norm': 0.47719475626945496, 'learning_rate': 1.823227859289376e-05, 'epoch': 0.18}


 13%|█▎        | 1502/11314 [02:18<15:06, 10.82it/s]

{'loss': 0.9292, 'grad_norm': 0.13571925461292267, 'learning_rate': 1.7348417889340642e-05, 'epoch': 0.27}


 18%|█▊        | 2002/11314 [03:03<14:07, 10.98it/s]

{'loss': 0.8531, 'grad_norm': 21.148666381835938, 'learning_rate': 1.6464557185787523e-05, 'epoch': 0.35}


 22%|██▏       | 2502/11314 [03:49<12:57, 11.33it/s]

{'loss': 0.7801, 'grad_norm': 0.0958327129483223, 'learning_rate': 1.55806964822344e-05, 'epoch': 0.44}


 27%|██▋       | 3001/11314 [04:36<12:06, 11.44it/s]

{'loss': 0.8329, 'grad_norm': 9.007207870483398, 'learning_rate': 1.4696835778681282e-05, 'epoch': 0.53}


 31%|███       | 3501/11314 [05:23<10:28, 12.44it/s]

{'loss': 0.7411, 'grad_norm': 22.806806564331055, 'learning_rate': 1.3812975075128161e-05, 'epoch': 0.62}


 35%|███▌      | 4002/11314 [06:08<12:11, 10.00it/s]

{'loss': 0.8646, 'grad_norm': 22.182212829589844, 'learning_rate': 1.292911437157504e-05, 'epoch': 0.71}


 40%|███▉      | 4501/11314 [06:54<10:57, 10.37it/s]

{'loss': 0.7167, 'grad_norm': 22.938961029052734, 'learning_rate': 1.2045253668021922e-05, 'epoch': 0.8}


 44%|████▍     | 5001/11314 [07:41<12:26,  8.46it/s]

{'loss': 0.7798, 'grad_norm': 9.006744384765625, 'learning_rate': 1.11613929644688e-05, 'epoch': 0.88}


 49%|████▊     | 5500/11314 [08:26<08:14, 11.75it/s]

{'loss': 0.7583, 'grad_norm': 52.29185104370117, 'learning_rate': 1.027753226091568e-05, 'epoch': 0.97}


                                                    
 50%|█████     | 5657/11314 [09:53<09:24, 10.03it/s]

{'eval_loss': 0.8091230392456055, 'eval_accuracy': 0.7971322357939459, 'eval_runtime': 72.8392, 'eval_samples_per_second': 103.406, 'eval_steps_per_second': 51.703, 'epoch': 1.0}


 53%|█████▎    | 6000/11314 [10:25<07:27, 11.88it/s]   

{'loss': 0.5593, 'grad_norm': 0.042891595512628555, 'learning_rate': 9.39367155736256e-06, 'epoch': 1.06}


 57%|█████▋    | 6503/11314 [11:11<06:42, 11.95it/s]

{'loss': 0.4982, 'grad_norm': 12.954984664916992, 'learning_rate': 8.50981085380944e-06, 'epoch': 1.15}


 62%|██████▏   | 7002/11314 [11:57<06:56, 10.34it/s]

{'loss': 0.5896, 'grad_norm': 123.51172637939453, 'learning_rate': 7.625950150256321e-06, 'epoch': 1.24}


 66%|██████▋   | 7501/11314 [12:46<08:11,  7.76it/s]

{'loss': 0.5254, 'grad_norm': 94.30730438232422, 'learning_rate': 6.7420894467032e-06, 'epoch': 1.33}


 71%|███████   | 8002/11314 [13:33<06:01,  9.16it/s]

{'loss': 0.5198, 'grad_norm': 89.36083221435547, 'learning_rate': 5.85822874315008e-06, 'epoch': 1.41}


 75%|███████▌  | 8502/11314 [14:19<03:35, 13.07it/s]

{'loss': 0.4941, 'grad_norm': 0.025787746533751488, 'learning_rate': 4.97436803959696e-06, 'epoch': 1.5}


 80%|███████▉  | 9001/11314 [15:04<03:17, 11.73it/s]

{'loss': 0.5378, 'grad_norm': 1.4332557916641235, 'learning_rate': 4.0905073360438394e-06, 'epoch': 1.59}


 84%|████████▍ | 9503/11314 [15:51<02:19, 12.99it/s]

{'loss': 0.4753, 'grad_norm': 0.04786647856235504, 'learning_rate': 3.2066466324907197e-06, 'epoch': 1.68}


 88%|████████▊ | 10001/11314 [16:37<01:49, 11.99it/s]

{'loss': 0.5677, 'grad_norm': 0.027879901230335236, 'learning_rate': 2.3227859289375996e-06, 'epoch': 1.77}


 93%|█████████▎| 10502/11314 [17:24<01:04, 12.61it/s]

{'loss': 0.487, 'grad_norm': 0.01717495732009411, 'learning_rate': 1.4389252253844795e-06, 'epoch': 1.86}


 97%|█████████▋| 11002/11314 [18:12<00:32,  9.52it/s]

{'loss': 0.5201, 'grad_norm': 0.023554306477308273, 'learning_rate': 5.550645218313594e-07, 'epoch': 1.94}


                                                     
100%|██████████| 11314/11314 [19:58<00:00, 11.93it/s]

{'eval_loss': 0.9071014523506165, 'eval_accuracy': 0.8088157195963888, 'eval_runtime': 74.2881, 'eval_samples_per_second': 101.389, 'eval_steps_per_second': 50.695, 'epoch': 2.0}


100%|██████████| 11314/11314 [20:01<00:00,  9.42it/s]

{'train_runtime': 1201.0206, 'train_samples_per_second': 18.841, 'train_steps_per_second': 9.42, 'train_loss': 0.6891144524409574, 'epoch': 2.0}





TrainOutput(global_step=11314, training_loss=0.6891144524409574, metrics={'train_runtime': 1201.0206, 'train_samples_per_second': 18.841, 'train_steps_per_second': 9.42, 'total_flos': 876784157343564.0, 'train_loss': 0.6891144524409574, 'epoch': 2.0})

In [30]:
from sklearn.metrics import classification_report
true_labels = ds_test['label'].values

predictions = trainer.predict(tokenized_test)
predictions_tensor = torch.tensor(predictions.predictions)


predictions = torch.argmax(predictions_tensor, dim=1).cpu().numpy()

report = classification_report(true_labels, predictions)

100%|██████████| 3766/3766 [01:12<00:00, 51.63it/s]


In [32]:
print(report)

              precision    recall  f1-score   support

           0       0.48      0.44      0.46       319
           1       0.89      0.86      0.88      1955
           2       0.86      0.73      0.79       390
           3       0.78      0.90      0.84      1590
           4       0.77      0.80      0.78      1579
           5       0.69      0.75      0.71       398
           6       0.81      0.69      0.75      1301

    accuracy                           0.80      7532
   macro avg       0.75      0.74      0.74      7532
weighted avg       0.80      0.80      0.80      7532

