# Load Dataset from HF
Loading from Quesmed organization

In [11]:
from datasets import load_dataset
import os

os.environ['TOKENIZERS_PARALLELISM'] = 'true'

ds = load_dataset('quesmed/comment_sentiment', token=True)
ds

DatasetDict({
    test: Dataset({
        features: ['id', 'createdAt', 'userId', 'userCreatedAt', 'classYear', 'universityId', 'country', 'universityName', 'parentId', 'questionId', 'comment', 'review', 'negative', 'neutral', 'positive', 'tone', 'sadness', 'joy', 'love', 'anger', 'fear', 'surprise', 'emotion', 'educational', 'giving feedback', 'asking a question', 'insulting', 'supporting', 'humour', 'frustration', 'theme'],
        num_rows: 15
    })
    train: Dataset({
        features: ['id', 'createdAt', 'userId', 'userCreatedAt', 'classYear', 'universityId', 'country', 'universityName', 'parentId', 'questionId', 'comment', 'review', 'negative', 'neutral', 'positive', 'tone', 'sadness', 'joy', 'love', 'anger', 'fear', 'surprise', 'emotion', 'educational', 'giving feedback', 'asking a question', 'insulting', 'supporting', 'humour', 'frustration', 'theme'],
        num_rows: 120
    })
    validate: Dataset({
        features: ['id', 'createdAt', 'userId', 'userCreatedAt', 'classY

In [12]:
from datasets import Dataset
def isolate_dataset(ds: Dataset, feature: str):
    cols = ds.column_names['train']
    col_keep = {'comment', feature}
    
    ds_filter = ds.remove_columns(col_keep.symmetric_difference(cols))
    ds_filter = ds_filter.rename_column(feature, 'label')
    ds_filter = ds_filter.class_encode_column('label')

    return ds_filter

## Setup model and trainer

In [13]:
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer, AutoConfig
import torch

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

def init_model(model_path: str):
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    config = AutoConfig.from_pretrained(model_path)
    model = AutoModelForSequenceClassification.from_pretrained(model_path)

    return (tokenizer, config, model)

In [14]:
def print_summary(result):
    print(f"Time: {result.metrics['train_runtime']:.2f}")
    print(f"Samples/second: {result.metrics['train_samples_per_second']:.2f}")

In [15]:
from transformers import Trainer, TrainingArguments, logging
from datasets import Dataset
import numpy as np
import evaluate

metric = evaluate.load('accuracy')

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

logging.set_verbosity_error()

def setup_trainer(name: str, dataset: Dataset, model, tokenizer, push_to_hub=False):
    model_name = f"../fine-tuning-chkp/{name}"

    training_args = TrainingArguments(
        output_dir=model_name,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        logging_steps=0.2,
        num_train_epochs=5,
        learning_rate=2e-5,
        weight_decay=0.01,
        metric_for_best_model="accuracy",
        load_best_model_at_end=True,
        disable_tqdm=False,
        push_to_hub=push_to_hub,
        hub_model_id=f"quesmed/{name}",
        hub_strategy="end",
        use_mps_device=True
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset['train'],
        eval_dataset=dataset['validate'],
        compute_metrics=compute_metrics,
        tokenizer=tokenizer
    )

    return (trainer, training_args)

# Fine-tuning Tone

In [16]:
tone_tokenizer, tone_config, tone_model = init_model("cardiffnlp/twitter-roberta-base-sentiment-latest")

In [17]:
ds_tone = isolate_dataset(ds, 'tone')

ds_tone = ds_tone.map(
  lambda row: tone_tokenizer(row['comment'], max_length=512, padding='max_length', truncation=True, return_tensors='pt'), 
  batched=True,
  remove_columns=['comment']
)

ds_tone['train'].features

Casting to class labels:   0%|          | 0/15 [00:00<?, ? examples/s]

Casting to class labels:   0%|          | 0/120 [00:00<?, ? examples/s]

Casting to class labels:   0%|          | 0/15 [00:00<?, ? examples/s]

Map:   0%|          | 0/15 [00:00<?, ? examples/s]

Map:   0%|          | 0/120 [00:00<?, ? examples/s]

Map:   0%|          | 0/15 [00:00<?, ? examples/s]

{'label': ClassLabel(names=['negative', 'neutral', 'positive'], id=None),
 'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None),
 'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None)}

# Hyperparameter search
Using [optuna](https://optuna.org/) to find the optimal training parameters.

In [16]:
tone_trainer, tone_args = setup_trainer('tone', dataset=ds_tone, model=tone_model, tokenizer=tone_tokenizer, push_to_hub=True)

/Users/stefan/Github/sentiment_analysis/fine-tuning-chkp/tone is already a clone of https://huggingface.co/quesmed/tone. Make sure you pull the latest changes with `repo.git_pull()`.


In [10]:
def model_init_tone():
    _, _, model = init_model("cardiffnlp/twitter-roberta-base-sentiment-latest")
    return model

trainer = Trainer(
    model_init=model_init_tone,
    args=tone_args,
    train_dataset=ds_tone["train"],
    eval_dataset=ds_tone["validate"],
    tokenizer=tone_tokenizer,
    compute_metrics=compute_metrics
)

In [11]:
best_run_tone = trainer.hyperparameter_search(n_trials=10, direction="maximize")

[I 2023-07-29 18:46:10,336] A new study created in memory with name: no-name-f81e0bf5-7cfe-4734-9113-27393ba08a5e


  0%|          | 0/90 [00:00<?, ?it/s]

{'loss': 0.6555, 'learning_rate': 3.852047770983165e-06, 'epoch': 0.6}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 0.543163001537323, 'eval_accuracy': 0.8, 'eval_runtime': 0.5164, 'eval_samples_per_second': 29.048, 'eval_steps_per_second': 3.873, 'epoch': 1.0}
{'loss': 0.4577, 'learning_rate': 2.889035828237373e-06, 'epoch': 1.2}
{'loss': 0.4159, 'learning_rate': 1.9260238854915823e-06, 'epoch': 1.8}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 0.5304719805717468, 'eval_accuracy': 0.7333333333333333, 'eval_runtime': 0.3133, 'eval_samples_per_second': 47.879, 'eval_steps_per_second': 6.384, 'epoch': 2.0}
{'loss': 0.2909, 'learning_rate': 9.630119427457912e-07, 'epoch': 2.4}
{'loss': 0.3656, 'learning_rate': 0.0, 'epoch': 3.0}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 0.5722706913948059, 'eval_accuracy': 0.7333333333333333, 'eval_runtime': 0.3123, 'eval_samples_per_second': 48.038, 'eval_steps_per_second': 6.405, 'epoch': 3.0}


[I 2023-07-29 18:46:58,619] Trial 0 finished with value: 0.7333333333333333 and parameters: {'learning_rate': 4.8150597137289554e-06, 'num_train_epochs': 3, 'seed': 11, 'per_device_train_batch_size': 4}. Best is trial 0 with value: 0.7333333333333333.


{'train_runtime': 46.8413, 'train_samples_per_second': 7.686, 'train_steps_per_second': 1.921, 'train_loss': 0.437093718846639, 'epoch': 3.0}




  0%|          | 0/45 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 0.5463020205497742, 'eval_accuracy': 0.6666666666666666, 'eval_runtime': 0.3257, 'eval_samples_per_second': 46.053, 'eval_steps_per_second': 6.14, 'epoch': 1.0}
{'loss': 0.5889, 'learning_rate': 4.563825160945706e-06, 'epoch': 1.2}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 0.5267026424407959, 'eval_accuracy': 0.6666666666666666, 'eval_runtime': 0.3374, 'eval_samples_per_second': 44.459, 'eval_steps_per_second': 5.928, 'epoch': 2.0}
{'loss': 0.384, 'learning_rate': 1.521275053648569e-06, 'epoch': 2.4}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 0.5240800976753235, 'eval_accuracy': 0.7333333333333333, 'eval_runtime': 0.3364, 'eval_samples_per_second': 44.584, 'eval_steps_per_second': 5.945, 'epoch': 3.0}


[I 2023-07-29 18:47:41,074] Trial 1 finished with value: 0.7333333333333333 and parameters: {'learning_rate': 7.6063752682428446e-06, 'num_train_epochs': 3, 'seed': 25, 'per_device_train_batch_size': 8}. Best is trial 0 with value: 0.7333333333333333.


{'train_runtime': 41.1103, 'train_samples_per_second': 8.757, 'train_steps_per_second': 1.095, 'train_loss': 0.44660082393222383, 'epoch': 3.0}




  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 0.5844404101371765, 'eval_accuracy': 0.6666666666666666, 'eval_runtime': 0.3354, 'eval_samples_per_second': 44.724, 'eval_steps_per_second': 5.963, 'epoch': 1.0}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 0.5470119118690491, 'eval_accuracy': 0.6, 'eval_runtime': 0.3413, 'eval_samples_per_second': 43.945, 'eval_steps_per_second': 5.859, 'epoch': 2.0}


[I 2023-07-29 18:48:09,303] Trial 2 finished with value: 0.6 and parameters: {'learning_rate': 5.952961213419925e-06, 'num_train_epochs': 2, 'seed': 20, 'per_device_train_batch_size': 16}. Best is trial 0 with value: 0.7333333333333333.


{'train_runtime': 26.8874, 'train_samples_per_second': 8.926, 'train_steps_per_second': 0.595, 'train_loss': 0.519575297832489, 'epoch': 2.0}




  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 0.6356574892997742, 'eval_accuracy': 0.8, 'eval_runtime': 0.3922, 'eval_samples_per_second': 38.249, 'eval_steps_per_second': 5.1, 'epoch': 1.0}


[I 2023-07-29 18:49:14,149] Trial 3 finished with value: 0.8 and parameters: {'learning_rate': 1.0741683524249098e-06, 'num_train_epochs': 1, 'seed': 5, 'per_device_train_batch_size': 64}. Best is trial 3 with value: 0.8.


{'train_runtime': 63.2862, 'train_samples_per_second': 1.896, 'train_steps_per_second': 0.032, 'train_loss': 0.5893210172653198, 'epoch': 1.0}




  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 0.5288072228431702, 'eval_accuracy': 0.7333333333333333, 'eval_runtime': 0.3265, 'eval_samples_per_second': 45.941, 'eval_steps_per_second': 6.125, 'epoch': 1.0}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 0.557681143283844, 'eval_accuracy': 0.8666666666666667, 'eval_runtime': 0.3282, 'eval_samples_per_second': 45.706, 'eval_steps_per_second': 6.094, 'epoch': 2.0}


[I 2023-07-29 18:49:46,381] Trial 4 finished with value: 0.8666666666666667 and parameters: {'learning_rate': 6.426351850922471e-05, 'num_train_epochs': 2, 'seed': 31, 'per_device_train_batch_size': 16}. Best is trial 4 with value: 0.8666666666666667.


{'train_runtime': 30.646, 'train_samples_per_second': 7.831, 'train_steps_per_second': 0.522, 'train_loss': 0.39152830839157104, 'epoch': 2.0}




  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 0.6156949996948242, 'eval_accuracy': 0.8, 'eval_runtime': 0.3909, 'eval_samples_per_second': 38.372, 'eval_steps_per_second': 5.116, 'epoch': 1.0}


[I 2023-07-29 18:50:32,385] Trial 5 finished with value: 0.8 and parameters: {'learning_rate': 2.012115221496673e-06, 'num_train_epochs': 1, 'seed': 9, 'per_device_train_batch_size': 32}. Best is trial 4 with value: 0.8666666666666667.


{'train_runtime': 44.5162, 'train_samples_per_second': 2.696, 'train_steps_per_second': 0.09, 'train_loss': 0.5687837600708008, 'epoch': 1.0}




  0%|          | 0/45 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

[I 2023-07-29 18:50:47,080] Trial 6 pruned. 


{'eval_loss': 0.6205324530601501, 'eval_accuracy': 0.6, 'eval_runtime': 0.3178, 'eval_samples_per_second': 47.199, 'eval_steps_per_second': 6.293, 'epoch': 1.0}




  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

[I 2023-07-29 18:51:50,802] Trial 7 pruned. 


{'eval_loss': 0.620151698589325, 'eval_accuracy': 0.7333333333333333, 'eval_runtime': 0.6296, 'eval_samples_per_second': 23.826, 'eval_steps_per_second': 3.177, 'epoch': 1.0}




  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

[I 2023-07-29 18:52:32,653] Trial 8 pruned. 


{'eval_loss': 0.642278790473938, 'eval_accuracy': 0.6666666666666666, 'eval_runtime': 0.8737, 'eval_samples_per_second': 17.169, 'eval_steps_per_second': 2.289, 'epoch': 1.0}




  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 0.52342289686203, 'eval_accuracy': 0.8, 'eval_runtime': 0.3292, 'eval_samples_per_second': 45.567, 'eval_steps_per_second': 6.076, 'epoch': 1.0}


[I 2023-07-29 18:52:50,661] Trial 9 finished with value: 0.8 and parameters: {'learning_rate': 2.7389763738537607e-05, 'num_train_epochs': 1, 'seed': 24, 'per_device_train_batch_size': 8}. Best is trial 4 with value: 0.8666666666666667.


{'train_runtime': 16.4052, 'train_samples_per_second': 7.315, 'train_steps_per_second': 0.914, 'train_loss': 0.6169548670450846, 'epoch': 1.0}


In [12]:
best_run_tone

BestRun(run_id='4', objective=0.8666666666666667, hyperparameters={'learning_rate': 6.426351850922471e-05, 'num_train_epochs': 2, 'seed': 31, 'per_device_train_batch_size': 16}, run_summary=None)

In [17]:
for n, v in best_run_tone.hyperparameters.items():
    setattr(tone_trainer.args, n, v)

result = tone_trainer.train()
print_summary(result)



  0%|          | 0/16 [00:00<?, ?it/s]

{'loss': 1.3096, 'learning_rate': 4.819763888191853e-05, 'epoch': 0.5}
{'loss': 0.9335, 'learning_rate': 3.2131759254612355e-05, 'epoch': 1.0}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 0.7770673036575317, 'eval_accuracy': 0.5333333333333333, 'eval_runtime': 0.3214, 'eval_samples_per_second': 46.664, 'eval_steps_per_second': 6.222, 'epoch': 1.0}
{'loss': 0.6181, 'learning_rate': 1.6065879627306178e-05, 'epoch': 1.5}
{'loss': 0.5801, 'learning_rate': 0.0, 'epoch': 2.0}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 0.6173900961875916, 'eval_accuracy': 0.6666666666666666, 'eval_runtime': 0.3122, 'eval_samples_per_second': 48.039, 'eval_steps_per_second': 6.405, 'epoch': 2.0}
{'train_runtime': 37.691, 'train_samples_per_second': 6.368, 'train_steps_per_second': 0.425, 'train_loss': 0.8602981865406036, 'epoch': 2.0}
Time: 37.69
Samples/second: 6.37


In [18]:
print(tone_trainer.state.best_model_checkpoint)
tone_trainer.save_model('../fine-tuning-final/tone')


fine-tuning-chkp/tone/checkpoint-16


Upload file pytorch_model.bin:   0%|          | 1.00/476M [00:00<?, ?B/s]

Upload file runs/Jul29_18-46-05_Stefans-MacBook-Pro.local/events.out.tfevents.1690670771.Stefans-MacBook-Pro.l…

Upload file runs/Jul29_18-46-05_Stefans-MacBook-Pro.local/events.out.tfevents.1690670819.Stefans-MacBook-Pro.l…

Upload file runs/Jul29_19-12-36_Stefans-MacBook-Pro.local/events.out.tfevents.1690672371.Stefans-MacBook-Pro.l…

Upload file runs/Jul29_18-46-05_Stefans-MacBook-Pro.local/events.out.tfevents.1690670862.Stefans-MacBook-Pro.l…

Upload file runs/Jul29_18-46-05_Stefans-MacBook-Pro.local/events.out.tfevents.1690671154.Stefans-MacBook-Pro.l…

Upload file runs/Jul29_18-46-05_Stefans-MacBook-Pro.local/events.out.tfevents.1690670890.Stefans-MacBook-Pro.l…

Upload file runs/Jul29_18-46-05_Stefans-MacBook-Pro.local/events.out.tfevents.1690670955.Stefans-MacBook-Pro.l…

Upload file runs/Jul29_18-46-05_Stefans-MacBook-Pro.local/events.out.tfevents.1690670987.Stefans-MacBook-Pro.l…

Upload file runs/Jul29_18-46-05_Stefans-MacBook-Pro.local/events.out.tfevents.1690671112.Stefans-MacBook-Pro.l…

Upload file runs/Jul29_18-46-05_Stefans-MacBook-Pro.local/events.out.tfevents.1690671048.Stefans-MacBook-Pro.l…

Upload file runs/Jul29_18-46-05_Stefans-MacBook-Pro.local/events.out.tfevents.1690671034.Stefans-MacBook-Pro.l…

Upload file training_args.bin:   0%|          | 1.00/3.93k [00:00<?, ?B/s]

To https://huggingface.co/quesmed/tone
   555cd56..d2b6060  main -> main

To https://huggingface.co/quesmed/tone
   d2b6060..a0d3fe1  main -> main



# Emotion fine-tuning

In [19]:
emotion_tokenizer, emotion_config, emotion_model = init_model("bhadresh-savani/distilbert-base-uncased-emotion")


In [20]:
# emotion_labels = ['sadness', 'joy', 'love', 'anger', 'fear', 'surprise']

ds_emotion = isolate_dataset(ds, 'emotion')

ds_emotion = ds_emotion.map(
  lambda row: emotion_tokenizer(row['comment'], max_length=512, padding='max_length', truncation=True, return_tensors='pt'), 
  batched=True,
  remove_columns=['comment']
)

ds_emotion['train'].features

Casting to class labels:   0%|          | 0/15 [00:00<?, ? examples/s]

Casting to class labels:   0%|          | 0/120 [00:00<?, ? examples/s]

Casting to class labels:   0%|          | 0/15 [00:00<?, ? examples/s]

Map:   0%|          | 0/15 [00:00<?, ? examples/s]

Map:   0%|          | 0/120 [00:00<?, ? examples/s]

Map:   0%|          | 0/15 [00:00<?, ? examples/s]

{'label': ClassLabel(names=['?puzzled', 'anger', 'fear', 'joy', 'sadness', 'surprise'], id=None),
 'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None),
 'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None)}

In [21]:
emotion_trainer, emotion_args = setup_trainer('emotion', dataset=ds_emotion, model=emotion_model, tokenizer=emotion_tokenizer, push_to_hub=True)

/Users/stefan/Github/sentiment_analysis/fine-tuning-chkp/emotion is already a clone of https://huggingface.co/quesmed/emotion. Make sure you pull the latest changes with `repo.git_pull()`.


In [22]:
def model_init_emotion():
    _, _, model = init_model("bhadresh-savani/distilbert-base-uncased-emotion")
    return model

trainer = Trainer(
    model_init=model_init_emotion,
    args=emotion_args,
    train_dataset=ds_emotion["train"],
    eval_dataset=ds_emotion["validate"],
    tokenizer=emotion_tokenizer,
    compute_metrics=compute_metrics
)

/Users/stefan/Github/sentiment_analysis/fine-tuning-chkp/emotion is already a clone of https://huggingface.co/quesmed/emotion. Make sure you pull the latest changes with `repo.git_pull()`.


In [23]:
best_run_emotion = trainer.hyperparameter_search(n_trials=10, direction="maximize")

[I 2023-07-29 19:18:54,683] A new study created in memory with name: no-name-47aabdb8-5eb3-4f58-8554-17c9e87b636f


  0%|          | 0/45 [00:00<?, ?it/s]

{'loss': 3.3145, 'learning_rate': 3.3508029272253483e-06, 'epoch': 0.6}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 2.20538330078125, 'eval_accuracy': 0.26666666666666666, 'eval_runtime': 0.2376, 'eval_samples_per_second': 63.134, 'eval_steps_per_second': 8.418, 'epoch': 1.0}
{'loss': 2.7061, 'learning_rate': 2.5131021954190113e-06, 'epoch': 1.2}
{'loss': 2.3476, 'learning_rate': 1.6754014636126742e-06, 'epoch': 1.8}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 2.0054733753204346, 'eval_accuracy': 0.3333333333333333, 'eval_runtime': 0.1943, 'eval_samples_per_second': 77.206, 'eval_steps_per_second': 10.294, 'epoch': 2.0}
{'loss': 2.3294, 'learning_rate': 8.377007318063371e-07, 'epoch': 2.4}
{'loss': 2.18, 'learning_rate': 0.0, 'epoch': 3.0}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 1.9239314794540405, 'eval_accuracy': 0.3333333333333333, 'eval_runtime': 0.1952, 'eval_samples_per_second': 76.829, 'eval_steps_per_second': 10.244, 'epoch': 3.0}


[I 2023-07-29 19:19:21,109] Trial 0 finished with value: 0.3333333333333333 and parameters: {'learning_rate': 4.1885036590316854e-06, 'num_train_epochs': 3, 'seed': 4, 'per_device_train_batch_size': 8}. Best is trial 0 with value: 0.3333333333333333.


{'train_runtime': 25.3907, 'train_samples_per_second': 14.178, 'train_steps_per_second': 1.772, 'train_loss': 2.5755245632595485, 'epoch': 3.0}




  0%|          | 0/150 [00:00<?, ?it/s]

{'loss': 3.4504, 'learning_rate': 1.6569957920696866e-06, 'epoch': 0.3}
{'loss': 2.8955, 'learning_rate': 1.5512301032141747e-06, 'epoch': 0.6}
{'loss': 2.9273, 'learning_rate': 1.4454644143586628e-06, 'epoch': 0.9}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 2.3858933448791504, 'eval_accuracy': 0.26666666666666666, 'eval_runtime': 0.1973, 'eval_samples_per_second': 76.036, 'eval_steps_per_second': 10.138, 'epoch': 1.0}
{'loss': 2.9371, 'learning_rate': 1.339698725503151e-06, 'epoch': 1.2}
{'loss': 2.6661, 'learning_rate': 1.233933036647639e-06, 'epoch': 1.5}
{'loss': 2.535, 'learning_rate': 1.128167347792127e-06, 'epoch': 1.8}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 2.1227197647094727, 'eval_accuracy': 0.3333333333333333, 'eval_runtime': 0.2175, 'eval_samples_per_second': 68.966, 'eval_steps_per_second': 9.195, 'epoch': 2.0}
{'loss': 2.5997, 'learning_rate': 1.0224016589366152e-06, 'epoch': 2.1}
{'loss': 2.5998, 'learning_rate': 9.166359700811033e-07, 'epoch': 2.4}
{'loss': 2.552, 'learning_rate': 8.108702812255914e-07, 'epoch': 2.7}
{'loss': 1.742, 'learning_rate': 7.051045923700795e-07, 'epoch': 3.0}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 1.9646297693252563, 'eval_accuracy': 0.3333333333333333, 'eval_runtime': 0.2299, 'eval_samples_per_second': 65.255, 'eval_steps_per_second': 8.701, 'epoch': 3.0}
{'loss': 1.8445, 'learning_rate': 5.993389035145675e-07, 'epoch': 3.3}
{'loss': 2.0217, 'learning_rate': 4.935732146590556e-07, 'epoch': 3.6}
{'loss': 2.1447, 'learning_rate': 3.878075258035437e-07, 'epoch': 3.9}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 1.8889068365097046, 'eval_accuracy': 0.3333333333333333, 'eval_runtime': 0.2044, 'eval_samples_per_second': 73.388, 'eval_steps_per_second': 9.785, 'epoch': 4.0}
{'loss': 2.4815, 'learning_rate': 2.8204183694803177e-07, 'epoch': 4.2}
{'loss': 1.8718, 'learning_rate': 1.7627614809251986e-07, 'epoch': 4.5}
{'loss': 2.0282, 'learning_rate': 7.051045923700794e-08, 'epoch': 4.8}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 1.863094449043274, 'eval_accuracy': 0.3333333333333333, 'eval_runtime': 0.1968, 'eval_samples_per_second': 76.211, 'eval_steps_per_second': 10.161, 'epoch': 5.0}


[I 2023-07-29 19:20:10,678] Trial 1 finished with value: 0.3333333333333333 and parameters: {'learning_rate': 1.7627614809251985e-06, 'num_train_epochs': 5, 'seed': 21, 'per_device_train_batch_size': 4}. Best is trial 0 with value: 0.3333333333333333.


{'train_runtime': 48.7329, 'train_samples_per_second': 12.312, 'train_steps_per_second': 3.078, 'train_loss': 2.4454948043823244, 'epoch': 5.0}




  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 2.451944589614868, 'eval_accuracy': 0.26666666666666666, 'eval_runtime': 0.7849, 'eval_samples_per_second': 19.11, 'eval_steps_per_second': 2.548, 'epoch': 1.0}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 2.3482654094696045, 'eval_accuracy': 0.26666666666666666, 'eval_runtime': 0.6644, 'eval_samples_per_second': 22.577, 'eval_steps_per_second': 3.01, 'epoch': 2.0}


[I 2023-07-29 19:21:08,582] Trial 2 finished with value: 0.26666666666666666 and parameters: {'learning_rate': 9.2735783731369e-06, 'num_train_epochs': 2, 'seed': 25, 'per_device_train_batch_size': 64}. Best is trial 0 with value: 0.3333333333333333.


{'train_runtime': 57.0563, 'train_samples_per_second': 4.206, 'train_steps_per_second': 0.07, 'train_loss': 3.074446201324463, 'epoch': 2.0}




  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 1.5448811054229736, 'eval_accuracy': 0.3333333333333333, 'eval_runtime': 0.196, 'eval_samples_per_second': 76.537, 'eval_steps_per_second': 10.205, 'epoch': 1.0}
{'loss': 2.4963, 'learning_rate': 1.4196231173783549e-05, 'epoch': 1.12}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 1.4423075914382935, 'eval_accuracy': 0.4, 'eval_runtime': 0.1965, 'eval_samples_per_second': 76.323, 'eval_steps_per_second': 10.176, 'epoch': 2.0}
{'loss': 1.5055, 'learning_rate': 8.641184192737813e-06, 'epoch': 2.25}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 1.4126588106155396, 'eval_accuracy': 0.4, 'eval_runtime': 0.1955, 'eval_samples_per_second': 76.73, 'eval_steps_per_second': 10.231, 'epoch': 3.0}
{'loss': 1.4409, 'learning_rate': 3.086137211692076e-06, 'epoch': 3.38}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 1.4021178483963013, 'eval_accuracy': 0.4, 'eval_runtime': 0.194, 'eval_samples_per_second': 77.307, 'eval_steps_per_second': 10.308, 'epoch': 4.0}


[I 2023-07-29 19:21:42,881] Trial 3 finished with value: 0.4 and parameters: {'learning_rate': 1.9751278154829286e-05, 'num_train_epochs': 4, 'seed': 4, 'per_device_train_batch_size': 16}. Best is trial 3 with value: 0.4.


{'train_runtime': 33.3643, 'train_samples_per_second': 14.387, 'train_steps_per_second': 0.959, 'train_loss': 1.7218625098466873, 'epoch': 4.0}




  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 2.3613953590393066, 'eval_accuracy': 0.26666666666666666, 'eval_runtime': 0.7145, 'eval_samples_per_second': 20.993, 'eval_steps_per_second': 2.799, 'epoch': 1.0}


[I 2023-07-29 19:22:10,732] Trial 4 finished with value: 0.26666666666666666 and parameters: {'learning_rate': 1.582718593899664e-05, 'num_train_epochs': 1, 'seed': 9, 'per_device_train_batch_size': 64}. Best is trial 3 with value: 0.4.


{'train_runtime': 27.0788, 'train_samples_per_second': 4.432, 'train_steps_per_second': 0.074, 'train_loss': 3.227383613586426, 'epoch': 1.0}




  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

[I 2023-07-29 19:22:17,780] Trial 5 pruned. 


{'eval_loss': 2.295771598815918, 'eval_accuracy': 0.26666666666666666, 'eval_runtime': 0.1952, 'eval_samples_per_second': 76.853, 'eval_steps_per_second': 10.247, 'epoch': 1.0}




  0%|          | 0/90 [00:00<?, ?it/s]

{'loss': 3.3843, 'learning_rate': 4.810239399919614e-06, 'epoch': 0.3}
{'loss': 2.4298, 'learning_rate': 4.275768355484101e-06, 'epoch': 0.6}
{'loss': 2.5708, 'learning_rate': 3.741297311048588e-06, 'epoch': 0.9}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 1.8383318185806274, 'eval_accuracy': 0.3333333333333333, 'eval_runtime': 0.1934, 'eval_samples_per_second': 77.563, 'eval_steps_per_second': 10.342, 'epoch': 1.0}
{'loss': 1.619, 'learning_rate': 3.206826266613076e-06, 'epoch': 1.2}
{'loss': 2.2071, 'learning_rate': 2.672355222177563e-06, 'epoch': 1.5}
{'loss': 2.0178, 'learning_rate': 2.1378841777420504e-06, 'epoch': 1.8}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 1.5806950330734253, 'eval_accuracy': 0.4, 'eval_runtime': 0.194, 'eval_samples_per_second': 77.3, 'eval_steps_per_second': 10.307, 'epoch': 2.0}
{'loss': 1.5003, 'learning_rate': 1.603413133306538e-06, 'epoch': 2.1}
{'loss': 1.5562, 'learning_rate': 1.0689420888710252e-06, 'epoch': 2.4}
{'loss': 1.7619, 'learning_rate': 5.344710444355126e-07, 'epoch': 2.7}
{'loss': 1.4427, 'learning_rate': 0.0, 'epoch': 3.0}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 1.5423920154571533, 'eval_accuracy': 0.4, 'eval_runtime': 0.1957, 'eval_samples_per_second': 76.658, 'eval_steps_per_second': 10.221, 'epoch': 3.0}


[I 2023-07-29 19:22:47,412] Trial 6 finished with value: 0.4 and parameters: {'learning_rate': 5.344710444355126e-06, 'num_train_epochs': 3, 'seed': 26, 'per_device_train_batch_size': 4}. Best is trial 3 with value: 0.4.


{'train_runtime': 28.9469, 'train_samples_per_second': 12.437, 'train_steps_per_second': 3.109, 'train_loss': 2.0489992671542696, 'epoch': 3.0}




  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

[I 2023-07-29 19:22:53,875] Trial 7 pruned. 


{'eval_loss': 2.4755399227142334, 'eval_accuracy': 0.26666666666666666, 'eval_runtime': 0.1981, 'eval_samples_per_second': 75.738, 'eval_steps_per_second': 10.098, 'epoch': 1.0}




  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

[I 2023-07-29 19:23:00,367] Trial 8 pruned. 


{'eval_loss': 2.6507599353790283, 'eval_accuracy': 0.2, 'eval_runtime': 0.1986, 'eval_samples_per_second': 75.51, 'eval_steps_per_second': 10.068, 'epoch': 1.0}




  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 2.2886674404144287, 'eval_accuracy': 0.3333333333333333, 'eval_runtime': 0.2016, 'eval_samples_per_second': 74.389, 'eval_steps_per_second': 9.919, 'epoch': 1.0}
{'loss': 2.9965, 'learning_rate': 3.2370061192709505e-06, 'epoch': 1.12}


  0%|          | 0/2 [00:00<?, ?it/s]

[I 2023-07-29 19:23:14,933] Trial 9 pruned. 


{'eval_loss': 2.0955123901367188, 'eval_accuracy': 0.3333333333333333, 'eval_runtime': 0.2015, 'eval_samples_per_second': 74.426, 'eval_steps_per_second': 9.923, 'epoch': 2.0}


In [24]:
best_run_emotion

BestRun(run_id='3', objective=0.4, hyperparameters={'learning_rate': 1.9751278154829286e-05, 'num_train_epochs': 4, 'seed': 4, 'per_device_train_batch_size': 16}, run_summary=None)

In [25]:
for n, v in best_run_emotion.hyperparameters.items():
    setattr(emotion_trainer.args, n, v)

result = emotion_trainer.train()
print_summary(result)



  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 1.56971275806427, 'eval_accuracy': 0.3333333333333333, 'eval_runtime': 0.2145, 'eval_samples_per_second': 69.926, 'eval_steps_per_second': 9.323, 'epoch': 1.0}
{'loss': 2.4154, 'learning_rate': 1.4196231173783549e-05, 'epoch': 1.12}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 1.4366040229797363, 'eval_accuracy': 0.4, 'eval_runtime': 0.1996, 'eval_samples_per_second': 75.159, 'eval_steps_per_second': 10.021, 'epoch': 2.0}
{'loss': 1.5146, 'learning_rate': 8.641184192737813e-06, 'epoch': 2.25}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 1.4025484323501587, 'eval_accuracy': 0.4, 'eval_runtime': 0.1954, 'eval_samples_per_second': 76.776, 'eval_steps_per_second': 10.237, 'epoch': 3.0}
{'loss': 1.3946, 'learning_rate': 3.086137211692076e-06, 'epoch': 3.38}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 1.3901256322860718, 'eval_accuracy': 0.4, 'eval_runtime': 0.2006, 'eval_samples_per_second': 74.772, 'eval_steps_per_second': 9.97, 'epoch': 4.0}
{'train_runtime': 32.7718, 'train_samples_per_second': 14.647, 'train_steps_per_second': 0.976, 'train_loss': 1.7019539773464203, 'epoch': 4.0}
Time: 32.77
Samples/second: 14.65


In [26]:
result = emotion_trainer.train()
print_summary(result)

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 1.3971366882324219, 'eval_accuracy': 0.4, 'eval_runtime': 0.1951, 'eval_samples_per_second': 76.894, 'eval_steps_per_second': 10.253, 'epoch': 1.0}
{'loss': 1.3064, 'learning_rate': 1.4196231173783549e-05, 'epoch': 1.12}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 1.3285859823226929, 'eval_accuracy': 0.5333333333333333, 'eval_runtime': 0.1944, 'eval_samples_per_second': 77.147, 'eval_steps_per_second': 10.286, 'epoch': 2.0}
{'loss': 1.153, 'learning_rate': 8.641184192737813e-06, 'epoch': 2.25}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 1.324013352394104, 'eval_accuracy': 0.5333333333333333, 'eval_runtime': 0.1955, 'eval_samples_per_second': 76.737, 'eval_steps_per_second': 10.232, 'epoch': 3.0}
{'loss': 0.9528, 'learning_rate': 3.086137211692076e-06, 'epoch': 3.38}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 1.336905598640442, 'eval_accuracy': 0.5333333333333333, 'eval_runtime': 0.21, 'eval_samples_per_second': 71.443, 'eval_steps_per_second': 9.526, 'epoch': 4.0}
{'train_runtime': 29.6485, 'train_samples_per_second': 16.19, 'train_steps_per_second': 1.079, 'train_loss': 1.095855861902237, 'epoch': 4.0}
Time: 29.65
Samples/second: 16.19


In [27]:
print(emotion_trainer.state.best_model_checkpoint)
emotion_trainer.save_model('../fine-tuning-final/emotion')

fine-tuning-chkp/emotion/checkpoint-16


Upload file pytorch_model.bin:   0%|          | 1.00/255M [00:00<?, ?B/s]

Upload file runs/Jul29_19-16-56_Stefans-MacBook-Pro.local/events.out.tfevents.1690672761.Stefans-MacBook-Pro.l…

Upload file runs/Jul29_19-16-56_Stefans-MacBook-Pro.local/events.out.tfevents.1690672938.Stefans-MacBook-Pro.l…

Upload file runs/Jul29_19-16-56_Stefans-MacBook-Pro.local/events.out.tfevents.1690672811.Stefans-MacBook-Pro.l…

Upload file runs/Jul29_19-16-56_Stefans-MacBook-Pro.local/events.out.tfevents.1690673027.Stefans-MacBook-Pro.l…

Upload file runs/Jul29_19-16-56_Stefans-MacBook-Pro.local/events.out.tfevents.1690672869.Stefans-MacBook-Pro.l…

Upload file runs/Jul29_19-16-56_Stefans-MacBook-Pro.local/events.out.tfevents.1690672735.Stefans-MacBook-Pro.l…

Upload file runs/Jul29_19-16-56_Stefans-MacBook-Pro.local/events.out.tfevents.1690672995.Stefans-MacBook-Pro.l…

Upload file runs/Jul29_19-16-56_Stefans-MacBook-Pro.local/events.out.tfevents.1690672981.Stefans-MacBook-Pro.l…

Upload file runs/Jul29_19-16-56_Stefans-MacBook-Pro.local/events.out.tfevents.1690672903.Stefans-MacBook-Pro.l…

Upload file runs/Jul29_19-16-56_Stefans-MacBook-Pro.local/events.out.tfevents.1690672974.Stefans-MacBook-Pro.l…

Upload file runs/Jul29_19-16-56_Stefans-MacBook-Pro.local/events.out.tfevents.1690672968.Stefans-MacBook-Pro.l…

Upload file runs/Jul29_19-16-56_Stefans-MacBook-Pro.local/events.out.tfevents.1690672931.Stefans-MacBook-Pro.l…

Upload file training_args.bin:   0%|          | 1.00/3.93k [00:00<?, ?B/s]

To https://huggingface.co/quesmed/emotion
   238063f..18da624  main -> main

To https://huggingface.co/quesmed/emotion
   18da624..1ecb28a  main -> main



# Theme fine-tuning

In [6]:
theme_tokenizer, theme_config, theme_model = init_model("facebook/bart-large-mnli")

In [10]:
print(ds['train'][0])

{'id': 287, 'createdAt': '2021-04-05 07:59:57.863000+00:00', 'userId': 4816, 'userCreatedAt': '2021-02-24 14:39:33.720000+00:00', 'classYear': 'Year 4', 'universityId': 2620, 'country': 'United Kingdom', 'universityName': 'University College London (UCL)', 'parentId': None, 'questionId': 4444, 'comment': 'why keep him on oxygen? Surely his sats are fine now if not too high?', 'review': 0.0, 'negative': 0, 'neutral': 1, 'positive': 0, 'tone': 'neutral', 'sadness': 0, 'joy': 0, 'love': 0, 'anger': 0, 'fear': 1, 'surprise': 1, 'emotion': '?puzzled', 'educational': 0, 'giving feedback': 0, 'asking a question': 1, 'insulting': 0, 'supporting': 0, 'humour': 0, 'frustration': 0, 'theme': 'asking a question'}


In [None]:
theme_labels = ['educational', 'giving feedback', 'asking a question', 'insulting', 'supporting', 'humour', 'frustration']

template="This example is {}."

for theme in theme_labels:
    def create_theme_ds(row):
      hypothesis = template.format(theme)
      label = 2 if row[theme] == 1 else 0 # convert to 2: entailment for TRUE
      
      return {
         'text': row['text'],
         'hypothesis': hypothesis,
         'label': label
      }

In [53]:
feature = 'theme'
cols = ds.column_names['train']
col_keep = {'text', feature}

ds_theme = ds.remove_columns(col_keep.symmetric_difference(cols))
ds_theme = ds_theme.rename_column(feature, 'label')

In [60]:
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
import random


theme_tokenizer, theme_config, theme_model = init_model("facebook/bart-large-mnli")
# Linear(in_features=1024, out_features=3, bias=True)
# {0: 'contradiction', 1: 'neutral', 2: 'entailment'}

theme_labels = ['clinical update', 'community', 'question', 'education', 'advocating', 'dissuading', 'other']
num_labels = len(theme_labels)
template="This example is {}."

def create_input_sequence(sample):
    text = sample['text']
    label = sample['label'][0]
    contradiction_labels = theme_labels[:]
    label_idx = contradiction_labels.index(label)
    contradiction_labels.pop(label_idx)

    encoded_sequence = theme_tokenizer(
        text,
        [template.format(label)],
        # max_length=512,
        # padding='max_length', 
        truncation=True, 
        return_tensors='pt'
    )
    encoded_sequence['labels'] = [2]
    encoded_sequence['input_sentence'] = theme_tokenizer.batch_decode(encoded_sequence.input_ids)
    return encoded_sequence

ds_theme_encoded = ds_theme.map(
    create_input_sequence, 
    batched=True, 
    batch_size=1,
    remove_columns=["label", "text"]
)

ds_theme_encoded


Map:   0%|          | 0/20 [00:00<?, ? examples/s]

Map:   0%|          | 0/3 [00:00<?, ? examples/s]

Map:   0%|          | 0/3 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels', 'input_sentence'],
        num_rows: 20
    })
    validate: Dataset({
        features: ['input_ids', 'attention_mask', 'labels', 'input_sentence'],
        num_rows: 3
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels', 'input_sentence'],
        num_rows: 3
    })
})

In [61]:
ds_theme_encoded['train'][0]

{'input_ids': [0,
  500,
  1949,
  213,
  449,
  2013,
  116,
  50118,
  100,
  56,
  10,
  21431,
  24904,
  626,
  11,
  1824,
  4,
  1308,
  284,
  8,
  38,
  439,
  7,
  5,
  213,
  449,
  2013,
  1349,
  147,
  51,
  56,
  41,
  33638,
  8,
  28445,
  9668,
  4,
  2041,
  137,
  94,
  38,
  4024,
  10,
  213,
  449,
  2013,
  8,
  38,
  2145,
  38,
  1705,
  17,
  27,
  90,
  269,
  2842,
  5,
  1123,
  26965,
  8,
  20789,
  142,
  9,
  141,
  1359,
  127,
  124,
  16,
  6,
  53,
  961,
  1493,
  198,
  162,
  115,
  2842,
  24,
  95,
  2051,
  19,
  49,
  15145,
  18822,
  4,
  85,
  938,
  17,
  27,
  90,
  14,
  38,
  21,
  765,
  6,
  24,
  21,
  14,
  38,
  1705,
  17,
  27,
  90,
  20789,
  4,
  6233,
  1268,
  1493,
  655,
  2984,
  42,
  116,
  7698,
  47,
  3068,
  213,
  449,
  7870,
  114,
  47,
  17,
  27,
  548,
  56,
  42,
  1907,
  9,
  3012,
  116,
  2,
  2,
  713,
  1246,
  16,
  864,
  4,
  2],
 'attention_mask': [1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,


In [62]:
premise = ds_theme['train'][0]['text']
template= "This example is {}."
hypothesis = template.format(ds_theme['train'][0]['label'])

# run through model pre-trained on MNLI
x = theme_tokenizer(premise, hypothesis, 
                           truncation_strategy='only_first',
        return_tensors='pt')
x
# logits = theme_model(x.to(device))[0]

# # we throw away "neutral" (dim 1) and take the probability of
# # "entailment" (2) as the probability of the label being true 
# entail_contradiction_logits = logits[:,[0,2]]
# probs = entail_contradiction_logits.softmax(dim=1)
# prob_label_is_true = probs[:,1]
# prob_label_is_true



{'input_ids': tensor([[    0,   500,  1949,   213,   449,  2013,   116, 50118,   100,    56,
            10, 21431, 24904,   626,    11,  1824,     4,  1308,   284,     8,
            38,   439,     7,     5,   213,   449,  2013,  1349,   147,    51,
            56,    41, 33638,     8, 28445,  9668,     4,  2041,   137,    94,
            38,  4024,    10,   213,   449,  2013,     8,    38,  2145,    38,
          1705,    17,    27,    90,   269,  2842,     5,  1123, 26965,     8,
         20789,   142,     9,   141,  1359,   127,   124,    16,     6,    53,
           961,  1493,   198,   162,   115,  2842,    24,    95,  2051,    19,
            49, 15145, 18822,     4,    85,   938,    17,    27,    90,    14,
            38,    21,   765,     6,    24,    21,    14,    38,  1705,    17,
            27,    90, 20789,     4,  6233,  1268,  1493,   655,  2984,    42,
           116,  7698,    47,  3068,   213,   449,  7870,   114,    47,    17,
            27,   548,    56,    42,  

In [63]:
ds_theme_encoded['train'][0]

{'input_ids': [0,
  500,
  1949,
  213,
  449,
  2013,
  116,
  50118,
  100,
  56,
  10,
  21431,
  24904,
  626,
  11,
  1824,
  4,
  1308,
  284,
  8,
  38,
  439,
  7,
  5,
  213,
  449,
  2013,
  1349,
  147,
  51,
  56,
  41,
  33638,
  8,
  28445,
  9668,
  4,
  2041,
  137,
  94,
  38,
  4024,
  10,
  213,
  449,
  2013,
  8,
  38,
  2145,
  38,
  1705,
  17,
  27,
  90,
  269,
  2842,
  5,
  1123,
  26965,
  8,
  20789,
  142,
  9,
  141,
  1359,
  127,
  124,
  16,
  6,
  53,
  961,
  1493,
  198,
  162,
  115,
  2842,
  24,
  95,
  2051,
  19,
  49,
  15145,
  18822,
  4,
  85,
  938,
  17,
  27,
  90,
  14,
  38,
  21,
  765,
  6,
  24,
  21,
  14,
  38,
  1705,
  17,
  27,
  90,
  20789,
  4,
  6233,
  1268,
  1493,
  655,
  2984,
  42,
  116,
  7698,
  47,
  3068,
  213,
  449,
  7870,
  114,
  47,
  17,
  27,
  548,
  56,
  42,
  1907,
  9,
  3012,
  116,
  2,
  2,
  713,
  1246,
  16,
  864,
  4,
  2],
 'attention_mask': [1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,


In [64]:
theme_trainer, theme_args = setup_trainer('theme', dataset=ds_theme_encoded, model=theme_model, tokenizer=theme_tokenizer)

In [65]:
result = theme_trainer.train()
print_summary(result)

theme_trainer.save_model('../fine-tuning-final/theme')



  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

ValueError: setting an array element with a sequence. The requested array has an inhomogeneous shape after 2 dimensions. The detected shape was (2, 3) + inhomogeneous part.