In [1]:
!pip install -q datasets transformers

In [17]:
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from sklearn.metrics import accuracy_score, f1_score

In [21]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
hf_data_id = 'alxxtexxr/emotion-no-love'
hf_model_id = 'google-bert/bert-base-uncased'
project_name = 'BERT-Base-emotion-no-love-v0.1'

In [9]:
datasets = load_dataset(hf_data_id)
datasets

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 10638
    })
    validation: Dataset({
        features: ['text', 'label'],
        num_rows: 1296
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 1305
    })
})

In [11]:
tokenizer = AutoTokenizer.from_pretrained(hf_model_id)

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]



In [13]:
def tokenize(batch):
    return tokenizer(batch['text'], padding=True, truncation=True)

tokenized_datasetes = datasets.map(tokenize, batched=True, batch_size=None)

Map:   0%|          | 0/10638 [00:00<?, ? examples/s]

In [26]:
model = AutoModelForSequenceClassification.from_pretrained(hf_model_id, num_labels=6).to(device)

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google-bert/bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [28]:
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    f1 = f1_score(labels, preds, average='weighted')
    acc = accuracy_score(labels, preds)
    return {'accuracy': acc, 'f1': f1}

batch_size = 64
logging_steps = len(tokenized_datasetes['train']) // batch_size
train_args = TrainingArguments(
    output_dir=project_name,
    num_train_epochs=8,
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    load_best_model_at_end=True,
    metric_for_best_model='f1',
    weight_decay=0.01,
    eval_strategy='epoch',
    save_strategy='epoch',
    disable_tqdm=False,
)

trainer = Trainer(
    model=model, 
    args=train_args,
    compute_metrics=compute_metrics,
    train_dataset=tokenized_datasetes['train'],
    eval_dataset=tokenized_datasetes['validation'],
)

In [29]:
trainer.train()

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33malimtegar[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,2.120402,0.422068,0.424652
2,No log,3.013069,0.419753,0.423959
3,0.393000,3.44292,0.417438,0.415903
4,0.393000,3.751264,0.42284,0.414037
5,0.393000,3.934924,0.429784,0.4251
6,0.060200,4.041559,0.421296,0.418872
7,0.060200,4.124743,0.422068,0.419453
8,0.060200,4.170523,0.425926,0.421894


TrainOutput(global_step=1336, training_loss=0.17772742731128624, metrics={'train_runtime': 1287.9505, 'train_samples_per_second': 66.077, 'train_steps_per_second': 1.037, 'total_flos': 3804993842891328.0, 'train_loss': 0.17772742731128624, 'epoch': 8.0})

In [30]:
trainer.push_to_hub()

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

Upload 3 LFS files:   0%|          | 0/3 [00:00<?, ?it/s]

events.out.tfevents.1730049748.3e4667c18132.2376.0:   0%|          | 0.00/8.96k [00:00<?, ?B/s]

training_args.bin:   0%|          | 0.00/5.24k [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/alxxtexxr/BERT-Base-emotion-no-love-v0.1/commit/ac9ad6b27ab6a323dbc602545a0190d2a484dd27', commit_message='End of training', commit_description='', oid='ac9ad6b27ab6a323dbc602545a0190d2a484dd27', pr_url=None, pr_revision=None, pr_num=None)