In [1]:
from transformers import TrainingArguments, Trainer, AutoTokenizer, DataCollatorWithPadding, AutoModelForSequenceClassification
from datasets import load_dataset
from scipy.special import softmax
import evaluate
import numpy as np
import os
from huggingface_hub import login

In [2]:
hf_key = os.getenv('HuggingFace_Key')

In [3]:
login(token=hf_key)

In [4]:
ds = load_dataset('imdb')

In [5]:
id2label = {0: "NEGATIVE", 1: "POSITIVE"}
label2id = {"NEGATIVE": 0, "POSITIVE": 1}

In [6]:
model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert/distilbert-base-uncased", num_labels=2, id2label=id2label, label2id=label2id
)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert/distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
tokenizer = AutoTokenizer.from_pretrained("distilbert/distilbert-base-uncased")

In [8]:
def preprocess_text(ds):
    return tokenizer(ds["text"], truncation=True)

In [9]:
tokenized_ds = ds.map(preprocess_text, batched=True)

In [10]:
roc_auc = evaluate.load('roc_auc', 'binary')
accuracy = evaluate.load('accuracy')
f1 = evaluate.load('f1')
precision = evaluate.load('precision')

In [11]:
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    pred = np.argmax(logits, axis=1)
    probs = softmax(logits).max(axis=1)
    return {
        'accuracy': accuracy.compute(predictions=pred, references=labels)['accuracy'],
        'f1': f1.compute(predictions=pred, references=labels)['f1'],
        'precision': precision.compute(predictions=pred, references=labels)['precision'],
        'roc_auc': roc_auc.compute(prediction_scores=probs, references=labels)['roc_auc']
    }

In [12]:
repo_name = "LHL_LLM_Project"

training_args = TrainingArguments(
    learning_rate=7.578566728652807e-06,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    num_train_epochs=11,
    weight_decay=0.01,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    output_dir=repo_name,
    push_to_hub=True,
)

In [13]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_ds['train'],
    eval_dataset=tokenized_ds['test'],
    processing_class=tokenizer,
    compute_metrics=compute_metrics,
)

In [14]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,F1,Precision,Roc Auc
1,No log,0.214715,0.91604,0.916037,0.916073,0.694555
2,0.295600,0.195826,0.925,0.924714,0.928255,0.667926
3,0.182400,0.202159,0.925,0.926236,0.911216,0.773378
4,0.146100,0.202633,0.92768,0.92802,0.92368,0.757221
5,0.146100,0.233143,0.9232,0.92514,0.902343,0.816395
6,0.109700,0.242889,0.926,0.927085,0.913689,0.813913
7,0.084800,0.264023,0.9234,0.925198,0.903977,0.849364
8,0.070700,0.262245,0.92764,0.928097,0.922269,0.823711
9,0.055800,0.286813,0.9256,0.926331,0.91732,0.84423
10,0.055800,0.292641,0.92628,0.92714,0.916452,0.851365


No files have been modified since last commit. Skipping to prevent empty commit.


TrainOutput(global_step=4301, training_loss=0.1188958251623186, metrics={'train_runtime': 16338.2016, 'train_samples_per_second': 16.832, 'train_steps_per_second': 0.263, 'total_flos': 3.64284828853224e+16, 'train_loss': 0.1188958251623186, 'epoch': 11.0})

In [17]:
trainer.evaluate()

{'eval_loss': 0.19582612812519073,
 'eval_accuracy': 0.925,
 'eval_f1': 0.9247139128689018,
 'eval_precision': 0.928254735993551,
 'eval_roc_auc': 0.6679258527999999,
 'eval_runtime': 94.1169,
 'eval_samples_per_second': 265.627,
 'eval_steps_per_second': 4.154,
 'epoch': 11.0}

In [16]:
trainer.push_to_hub()

CommitInfo(commit_url='https://huggingface.co/Gur212/LHL_LLM_Project/commit/d1813eee47bb8aad97d51fa6f4fb35fe4f94cd74', commit_message='End of training', commit_description='', oid='d1813eee47bb8aad97d51fa6f4fb35fe4f94cd74', pr_url=None, repo_url=RepoUrl('https://huggingface.co/Gur212/LHL_LLM_Project', endpoint='https://huggingface.co', repo_type='model', repo_id='Gur212/LHL_LLM_Project'), pr_revision=None, pr_num=None)