In [1]:
from transformers import TrainingArguments, Trainer, AutoTokenizer, DataCollatorWithPadding, AutoModelForSequenceClassification
from datasets import load_dataset
from scipy.special import softmax
import evaluate
import numpy as np
import os
from huggingface_hub import login

In [2]:
hf_key = os.getenv('HuggingFace_Key')

In [3]:
login(token=hf_key)

In [4]:
ds = load_dataset('imdb')

In [5]:
id2label = {0: "NEGATIVE", 1: "POSITIVE"}
label2id = {"NEGATIVE": 0, "POSITIVE": 1}

In [6]:
model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert/distilbert-base-uncased", num_labels=2, id2label=id2label, label2id=label2id
)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert/distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
tokenizer = AutoTokenizer.from_pretrained("distilbert/distilbert-base-uncased")

In [8]:
def preprocess_text(ds):
    return tokenizer(ds["text"], truncation=True)

In [9]:
tokenized_ds = ds.map(preprocess_text, batched=True)

In [10]:
roc_auc = evaluate.load('roc_auc', 'binary')
accuracy = evaluate.load('accuracy')
f1 = evaluate.load('f1')
precision = evaluate.load('precision')

In [11]:
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    pred = np.argmax(logits, axis=1)
    probs = softmax(pred)
    return {
        'accuracy': accuracy.compute(predictions=pred, references=labels)['accuracy'],
        'f1': f1.compute(predictions=pred, references=labels)['f1'],
        'precision': precision.compute(predictions=pred, references=labels)['precision'],
        'roc_auc': roc_auc.compute(prediction_scores=probs, references=labels)['roc_auc']
    }

In [12]:
repo_name = "LHL_LLM_Project"

training_args = TrainingArguments(
    learning_rate=7.578566728652807e-06,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    num_train_epochs=11,
    weight_decay=0.01,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    output_dir=repo_name,
    push_to_hub=True,
)

In [13]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_ds['train'],
    eval_dataset=tokenized_ds['test'],
    processing_class=tokenizer,
    compute_metrics=compute_metrics,
)

In [14]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,F1,Precision,Roc Auc
1,No log,0.211971,0.91828,0.918479,0.916249,0.91828
2,0.291700,0.196347,0.92572,0.925842,0.924328,0.92572
3,0.179300,0.203592,0.9256,0.926847,0.911574,0.9256
4,0.143800,0.199853,0.92832,0.928537,0.925732,0.92832
5,0.143800,0.228076,0.92484,0.926061,0.911252,0.92484
6,0.107300,0.239606,0.92588,0.927067,0.912451,0.92588
7,0.084300,0.257814,0.9254,0.926526,0.912753,0.9254
8,0.068500,0.269769,0.928,0.928167,0.926023,0.928
9,0.055500,0.284446,0.92632,0.926783,0.920999,0.92632
10,0.055500,0.295047,0.92568,0.926387,0.917661,0.92568


No files have been modified since last commit. Skipping to prevent empty commit.


TrainOutput(global_step=4301, training_loss=0.11706939959797685, metrics={'train_runtime': 6504.6231, 'train_samples_per_second': 42.278, 'train_steps_per_second': 0.661, 'total_flos': 3.64284828853224e+16, 'train_loss': 0.11706939959797685, 'epoch': 11.0})

In [15]:
trainer.evaluate()

{'eval_loss': 0.1963466852903366,
 'eval_accuracy': 0.92572,
 'eval_f1': 0.9258416197436204,
 'eval_precision': 0.9243282034925444,
 'eval_roc_auc': 0.92572,
 'eval_runtime': 277.1905,
 'eval_samples_per_second': 90.191,
 'eval_steps_per_second': 1.411,
 'epoch': 11.0}

In [16]:
trainer.push_to_hub()

CommitInfo(commit_url='https://huggingface.co/Gur212/LHL_LLM_Project/commit/045085741b49b9ab321b66a05ed1ed9f2e55d632', commit_message='End of training', commit_description='', oid='045085741b49b9ab321b66a05ed1ed9f2e55d632', pr_url=None, repo_url=RepoUrl('https://huggingface.co/Gur212/LHL_LLM_Project', endpoint='https://huggingface.co', repo_type='model', repo_id='Gur212/LHL_LLM_Project'), pr_revision=None, pr_num=None)