In [1]:
# Reload edited project modules automatically before each cell is executed
%load_ext autoreload
%autoreload 2

# Render matplotlib figures inline, below the cell that produces them
%matplotlib inline

# Setup

In [2]:
import os

import datasets
import pandas as pd

from src.dataset import add_audio_column, filter_df, prepare_ds, split_df
from src.train import end_training, get_model, get_trainer
from src.utils import get_csv_name, get_run_name

In [3]:
# Root folder holding every resource read or written by this notebook
RES_DIR_PATH = "res"
# Value handed to `get_trainer` through its `env` argument
NOTEBOOK_ENV = "jupyter"

# Sub-folders of the resource root: audio files, saved models, encoded datasets
AUDIOS_DIR_PATH, MODELS_DIR_PATH, DATASETS_DIR_PATH = (
    os.path.join(RES_DIR_PATH, folder) for folder in ("mp3_data", "models", "datasets")
)

# CSV that `create_or_load_df` reads when no filtered copy is cached yet
CSV_PATH = os.path.join(RES_DIR_PATH, "samples_clustered.csv")

# `top_n` values used by the full-size feature configurations below
TOP_N_GENRES = 6
TOP_N_FEATURES = 9


# Feature configurations: feature name -> {"top_n": ..., "samples": ...}
FEATURES_CONFIG_SUBSET = {"genre": dict(top_n=3, samples=1000)}
FEATURES_CONFIG_GEN = {"genre": dict(top_n=TOP_N_GENRES, samples=None)}
FEATURES_CONFIG_CAT = {"category": dict(top_n=TOP_N_FEATURES, samples=None)}
FEATURES_CONFIG_MULTI = {
    "genre": dict(top_n=TOP_N_GENRES, samples=None),
    "category": dict(top_n=TOP_N_FEATURES, samples=None),
}

# Sizes forwarded to `split_df` for the validation and test partitions
VALID_SIZE = 0.1
TEST_SIZE = 0.1

## Backbones

The two considered backbones are [Wav2Vec2](https://arxiv.org/abs/2006.11477) and [Whisper](https://cdn.openai.com/papers/whisper.pdf).

Both models are used through the [Hugging Face Transformers](https://huggingface.co/docs/transformers) library.

The implementation of the **Wav2Vec2** classifier follows the one in the [Wav2Vec2ForSequenceClassification](https://huggingface.co/docs/transformers/model_doc/wav2vec2#transformers.Wav2Vec2ForSequenceClassification) class, adding support for a custom classification head.

For **Whisper**, I took the outputs of the [WhisperEncoder](https://huggingface.co/docs/transformers/model_doc/whisper) class and passed them directly to the classifier.

## Fine-tuning

For both backbones, freezing disables gradient computation for the entire encoder.

## Classifier

The classifier is implemented as an MLP with a configurable number of layers and hidden dimensions.
Each layer is followed by an optional Dropout layer and a ReLU activation.

## Multi-task

TODO

# Training 1

In [4]:
TRAINING_CONFIG = {
    # --- Settings shared by every run in this notebook ---
    "epochs": 20,  # number of training epochs
    "learning_rate": 5e-5,  # initial learning rate
    "warmup": 0.0,  # warmup setting consumed by `get_trainer`
    "train_batch_size": 8,  # batch size used for training
    "eval_batch_size": 16,  # batch size used for evaluation
    # --- Placeholders, assigned by the cell preceding each experiment ---
    "feature_encoder": None,  # backbone name, e.g. "wav2vec2" or "whisper"
    "freeze_encoder": None,  # True to freeze the encoder, False to fine-tune it
    "classifier_layers": None,  # list of classifier layer sizes, e.g. [256]
    "classifier_dropout": None,  # classifier dropout value, e.g. 0
}

In [5]:
def create_or_load_df(features_config):
    """Return the filtered and split DataFrame for `features_config`.

    The DataFrame is cached as a CSV file whose path is computed by
    `get_csv_name(features_config, CSV_PATH)`:

    * if that file exists, it is loaded as is;
    * otherwise `CSV_PATH` is read and filtered with `filter_df` according to
      `features_config`, dropping rows that contain null values.

    When the resulting DataFrame has no "split" column, it is partitioned with
    `split_df` using `VALID_SIZE` and `TEST_SIZE`. The cache file is written
    once, and only when the DataFrame was newly created or newly split.

    Args:
        features_config: feature configuration forwarded to `get_csv_name`
            and `filter_df` (e.g. `FEATURES_CONFIG_SUBSET`).

    Returns:
        The pandas DataFrame, including the "split" column.
    """
    filtered_csv_path = get_csv_name(features_config, CSV_PATH)
    # Tracks whether the cached CSV is missing or out of date
    needs_save = False

    # If the subset is already in the filesystem, load it directly
    if os.path.exists(filtered_csv_path):
        print(f"Loading {filtered_csv_path}")
        df = pd.read_csv(filtered_csv_path)
    else:
        # Filter the dataset according to the given configuration and remove
        # rows containing null values
        df = filter_df(
            pd.read_csv(CSV_PATH),
            remove_nones=True,
            features_config=features_config,
        )
        needs_save = True

    print(f"{len(df)} examples in DataFrame")

    # If the split column is not in the dataset, split the dataset into three
    # partitions using `TEST_SIZE` and `VALID_SIZE`
    if "split" not in df.columns:
        df = split_df(df, validation_size=VALID_SIZE, test_size=TEST_SIZE)
        needs_save = True

    # Single write: avoids saving the file twice when it is both created and split
    if needs_save:
        df.to_csv(filtered_csv_path, index=False)

    print(df.value_counts("split"))
    return df

In [6]:
# Create a function for loading the dataset for the requested model

def load_and_prepare_ds(training_config, feature_config, df, clustered=True):
    """Load the pre-encoded dataset of the configured encoder and prepare it.

    The dataset is read from
    `DATASETS_DIR_PATH/ds-<feature_encoder>-full-encoded`, passed through
    `add_audio_column` with `AUDIOS_DIR_PATH`, and finally handed to
    `prepare_ds` together with `df` and `feature_config`.

    Args:
        training_config: mapping providing the "feature_encoder" key.
        feature_config: feature configuration forwarded to `prepare_ds`.
        df: DataFrame forwarded to `prepare_ds`.
        clustered: forwarded unchanged to `prepare_ds`.

    Returns:
        Whatever `prepare_ds` returns (called with `fixed_mapping=None` and
        `save=False`).
    """
    encoder_name = training_config["feature_encoder"]
    encoded_dataset_path = os.path.join(
        DATASETS_DIR_PATH, f"ds-{encoder_name}-full-encoded"
    )

    encoded_ds = datasets.load_from_disk(encoded_dataset_path)
    # Only the encoder name is forwarded, not the whole training configuration
    ds_with_audio = add_audio_column(
        encoded_ds,
        audios_dir_path=AUDIOS_DIR_PATH,
        training_config={"feature_encoder": encoder_name},
    )

    return prepare_ds(
        ds_with_audio,
        df,
        feature_config,
        clustered=clustered,
        fixed_mapping=None,
        save=False,
    )

In [7]:
# Load the cached DataFrame for the FEATURES_CONFIG_SUBSET configuration,
# or create it from the full CSV (filter + split) if it is not cached yet
df = create_or_load_df(FEATURES_CONFIG_SUBSET)

Loading res/samples_clustered_genre3s1000.csv
999 examples in DataFrame
split
train    799
test     100
valid    100
dtype: int64


## Wav2Vec2

In [8]:
# Wav2Vec2 backbone with a frozen encoder; classifier layers [256], dropout 0
TRAINING_CONFIG["feature_encoder"] = "wav2vec2"
TRAINING_CONFIG["freeze_encoder"] = True
TRAINING_CONFIG["classifier_layers"] = [256]
TRAINING_CONFIG["classifier_dropout"] = 0

In [9]:
# Load the dataset encoded for TRAINING_CONFIG["feature_encoder"] and prepare
# it for the examples in `df`
prepared_ds = load_and_prepare_ds(TRAINING_CONFIG, FEATURES_CONFIG_SUBSET, df)

# Display the prepared dataset
prepared_ds

Loading cached processed dataset at /home/alesssandros/dev/FCN_Newspaper/aii/res/datasets/ds-wav2vec2-full-encoded/cache-b5264b5b2644adf3.arrow


Removing extra columns from dataset
Mapping features clusters
Extracting train split
Extracting valid split
Extracting test split
Create `ClassLabels` for target classes
{'genre': ClassLabel(names=['Electronic', 'Rock/Blues', 'World/Ethnic'], id=None)}


Casting the dataset:   0%|          | 0/799 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/100 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/100 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['label', 'id', 'duration', 'input_values'],
        num_rows: 799
    })
    valid: Dataset({
        features: ['label', 'id', 'duration', 'input_values'],
        num_rows: 100
    })
    test: Dataset({
        features: ['label', 'id', 'duration', 'input_values'],
        num_rows: 100
    })
})

In [10]:
# Name the run and build the model from the current TRAINING_CONFIG
# (here: Wav2Vec2 with a frozen encoder)
run_name = get_run_name(TRAINING_CONFIG)
model = get_model(TRAINING_CONFIG, prepared_ds["train"])

# Train on the "train" split and evaluate on the "valid" split
trainer = get_trainer(
    run_name=run_name,
    model=model,
    train_ds=prepared_ds["train"],
    eval_ds=prepared_ds["valid"],
    training_config=TRAINING_CONFIG,
    output_dir="out",
    debug=False,
    env=NOTEBOOK_ENV,
)

trainer.train()
# Close the run via `end_training`, passing the directory used for saved models
end_training(run_name, trainer, MODELS_DIR_PATH)

Some weights of the model checkpoint at facebook/wav2vec2-base were not used when initializing Wav2Vec2ForSequenceMultiClassification: ['quantizer.weight_proj.bias', 'project_hid.bias', 'project_hid.weight', 'project_q.weight', 'quantizer.weight_proj.weight', 'quantizer.codevectors', 'project_q.bias']
- This IS expected if you are initializing Wav2Vec2ForSequenceMultiClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForSequenceMultiClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForSequenceMultiClassification were not initialized from the model checkpoint at facebook/wav2vec2-base and are newly initialized: ['classifier.weight', 'head.lay

The following columns in the training set don't have a corresponding argument in `Wav2Vec2ForSequenceMultiClassification.forward` and have been ignored: duration, id. If duration, id are not expected by `Wav2Vec2ForSequenceMultiClassification.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 799
  Num Epochs = 20
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 2000
  Number of trainable parameters = 394499
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"


Epoch,Training Loss,Validation Loss,Accuracy
1,1.0863,1.07113,0.46
2,1.0628,1.046994,0.5
3,1.0435,1.017666,0.5
4,1.0172,0.996653,0.49
5,0.9976,0.983945,0.52
6,0.9824,0.96791,0.53
7,0.9818,0.962945,0.51
8,0.9627,0.956101,0.53
9,0.9427,0.947107,0.55
10,0.942,0.946416,0.56


The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForSequenceMultiClassification.forward` and have been ignored: duration, id. If duration, id are not expected by `Wav2Vec2ForSequenceMultiClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 100
  Batch size = 16
Saving model checkpoint to out/checkpoint-100
Configuration saved in out/checkpoint-100/config.json
Model weights saved in out/checkpoint-100/pytorch_model.bin
Feature extractor saved in out/checkpoint-100/preprocessor_config.json
Deleting older checkpoint [out/checkpoint-200] due to args.save_total_limit
The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForSequenceMultiClassification.forward` and have been ignored: duration, id. If duration, id are not expected by `Wav2Vec2ForSequenceMultiClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num ex

0,1
eval/accuracy,▁▃▃▃▄▅▄▅▆▇▆▅▆▆▆▇█▇▇▇
eval/loss,█▇▅▄▄▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁
eval/runtime,▁▁▁▁▁▁▁▁▁▁█▁▁▁▁███▁▁
eval/samples_per_second,▇█████████▁████▁▁▁██
eval/steps_per_second,▇█████████▁████▁▁▁██
train/epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇████
train/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇████
train/learning_rate,███▇▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▅▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▁▁
train/loss,███▇▇▇▆▆▅▅▅▄▄▄▄▄▄▃▃▃▃▂▃▂▃▂▂▂▂▂▂▂▁▁▂▁▂▂▁▂
train/total_flos,▁

0,1
eval/accuracy,0.56
eval/loss,0.9322
eval/runtime,2.4511
eval/samples_per_second,40.799
eval/steps_per_second,2.856
train/epoch,20.0
train/global_step,2000.0
train/learning_rate,0.0
train/loss,0.911
train/total_flos,1.245292480666616e+18


Saving model checkpoint to res/models/wav2vec2-frz-c256-d0-20230223-205423
Configuration saved in res/models/wav2vec2-frz-c256-d0-20230223-205423/config.json
Model weights saved in res/models/wav2vec2-frz-c256-d0-20230223-205423/pytorch_model.bin
Feature extractor saved in res/models/wav2vec2-frz-c256-d0-20230223-205423/preprocessor_config.json


## Fine-Tuning

In [11]:
# Unfreeze the encoder; the other TRAINING_CONFIG entries keep their current values
TRAINING_CONFIG["freeze_encoder"] = False 

In [12]:
# Name the run and build the model from the current TRAINING_CONFIG
# (here: Wav2Vec2 with the encoder unfrozen)
run_name = get_run_name(TRAINING_CONFIG)
model = get_model(TRAINING_CONFIG, prepared_ds["train"])

# Train on the "train" split and evaluate on the "valid" split
trainer = get_trainer(
    run_name=run_name,
    model=model,
    train_ds=prepared_ds["train"],
    eval_ds=prepared_ds["valid"],
    training_config=TRAINING_CONFIG,
    output_dir="out",
    debug=False,
    env=NOTEBOOK_ENV,
)

trainer.train()
# Close the run via `end_training`, passing the directory used for saved models
end_training(run_name, trainer, MODELS_DIR_PATH)

loading configuration file config.json from cache at /home/alesssandros/.cache/huggingface/hub/models--facebook--wav2vec2-base/snapshots/0b5b8e868dd84f03fd87d01f9c4ff0f080fecfe8/config.json
Model config Wav2Vec2Config {
  "_name_or_path": "facebook/wav2vec2-base",
  "activation_dropout": 0.0,
  "adapter_kernel_size": 3,
  "adapter_stride": 2,
  "add_adapter": false,
  "apply_spec_augment": true,
  "architectures": [
    "Wav2Vec2ForPreTraining"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 1,
  "classifier_proj_size": 256,
  "codevector_dim": 256,
  "contrastive_logits_temperature": 0.1,
  "conv_bias": false,
  "conv_dim": [
    512,
    512,
    512,
    512,
    512,
    512,
    512
  ],
  "conv_kernel": [
    10,
    3,
    3,
    3,
    3,
    2,
    2
  ],
  "conv_stride": [
    5,
    2,
    2,
    2,
    2,
    2,
    2
  ],
  "ctc_loss_reduction": "sum",
  "ctc_zero_infinity": false,
  "diversity_loss_weight": 0.1,
  "do_stable_layer_norm": false,
  "eos_token_id": 2,
  "

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668569513907036, max=1.0…

PyTorch: setting up devices
The following columns in the training set don't have a corresponding argument in `Wav2Vec2ForSequenceMultiClassification.forward` and have been ignored: duration, id. If duration, id are not expected by `Wav2Vec2ForSequenceMultiClassification.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 799
  Num Epochs = 20
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 2000
  Number of trainable parameters = 94766211
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"


Epoch,Training Loss,Validation Loss,Accuracy
1,1.0167,0.858208,0.66
2,0.8291,0.780925,0.7
3,0.7637,0.748313,0.67
4,0.6544,0.809231,0.63
5,0.5689,0.624179,0.77
6,0.5093,0.785466,0.74
7,0.4399,0.622021,0.82
8,0.2909,0.577539,0.86
9,0.2996,0.682577,0.83
10,0.2807,0.620278,0.83


The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForSequenceMultiClassification.forward` and have been ignored: duration, id. If duration, id are not expected by `Wav2Vec2ForSequenceMultiClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 100
  Batch size = 16
Saving model checkpoint to out/checkpoint-100
Configuration saved in out/checkpoint-100/config.json
Model weights saved in out/checkpoint-100/pytorch_model.bin
Feature extractor saved in out/checkpoint-100/preprocessor_config.json
Deleting older checkpoint [out/checkpoint-1700] due to args.save_total_limit
The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForSequenceMultiClassification.forward` and have been ignored: duration, id. If duration, id are not expected by `Wav2Vec2ForSequenceMultiClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num e

VBox(children=(Label(value='0.002 MB of 0.030 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.079292…

0,1
eval/accuracy,▂▃▂▁▅▄▇█▇▇▅▆▆▇█▇▇██▇
eval/loss,▄▃▃▄▂▄▂▁▂▂█▄▅▂▂▅▄▅▄▄
eval/runtime,█▁▁█▁▁▁▁▁▂▁█████████
eval/samples_per_second,▁██▁█████▆█▁▁▁▁▁▁▁▁▁
eval/steps_per_second,▁██▁█████▆█▁▁▁▁▁▁▁▁▁
train/epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇████
train/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇████
train/learning_rate,███▇▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▅▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▁▁
train/loss,███▇▆▆▆▆▅▅▅▄▄▄▄▃▃▂▃▂▃▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁
train/total_flos,▁

0,1
eval/accuracy,0.84
eval/loss,0.84186
eval/runtime,5.9526
eval/samples_per_second,16.799
eval/steps_per_second,1.176
train/epoch,20.0
train/global_step,2000.0
train/learning_rate,0.0
train/loss,0.0164
train/total_flos,1.245292480666616e+18


Saving model checkpoint to res/models/wav2vec2-fnt-c256-d0-20230223-210654
Configuration saved in res/models/wav2vec2-fnt-c256-d0-20230223-210654/config.json
Model weights saved in res/models/wav2vec2-fnt-c256-d0-20230223-210654/pytorch_model.bin
Feature extractor saved in res/models/wav2vec2-fnt-c256-d0-20230223-210654/preprocessor_config.json


## Classification Head

In [13]:
# Keep the encoder trainable and switch the classifier layers to [256, 256]
TRAINING_CONFIG["freeze_encoder"] = False
TRAINING_CONFIG["classifier_layers"] = [256, 256]

In [14]:
# Name the run and build the model from the current TRAINING_CONFIG
# (here: Wav2Vec2, encoder unfrozen, classifier layers [256, 256])
run_name = get_run_name(TRAINING_CONFIG)
model = get_model(TRAINING_CONFIG, prepared_ds["train"])

# Train on the "train" split and evaluate on the "valid" split
trainer = get_trainer(
    run_name=run_name,
    model=model,
    train_ds=prepared_ds["train"],
    eval_ds=prepared_ds["valid"],
    training_config=TRAINING_CONFIG,
    output_dir="out",
    debug=False,
    env=NOTEBOOK_ENV,
)

trainer.train()
# Close the run via `end_training`, passing the directory used for saved models
end_training(run_name, trainer, MODELS_DIR_PATH)

loading configuration file config.json from cache at /home/alesssandros/.cache/huggingface/hub/models--facebook--wav2vec2-base/snapshots/0b5b8e868dd84f03fd87d01f9c4ff0f080fecfe8/config.json
Model config Wav2Vec2Config {
  "_name_or_path": "facebook/wav2vec2-base",
  "activation_dropout": 0.0,
  "adapter_kernel_size": 3,
  "adapter_stride": 2,
  "add_adapter": false,
  "apply_spec_augment": true,
  "architectures": [
    "Wav2Vec2ForPreTraining"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 1,
  "classifier_proj_size": 256,
  "codevector_dim": 256,
  "contrastive_logits_temperature": 0.1,
  "conv_bias": false,
  "conv_dim": [
    512,
    512,
    512,
    512,
    512,
    512,
    512
  ],
  "conv_kernel": [
    10,
    3,
    3,
    3,
    3,
    2,
    2
  ],
  "conv_stride": [
    5,
    2,
    2,
    2,
    2,
    2,
    2
  ],
  "ctc_loss_reduction": "sum",
  "ctc_zero_infinity": false,
  "diversity_loss_weight": 0.1,
  "do_stable_layer_norm": false,
  "eos_token_id": 2,
  "

PyTorch: setting up devices
The following columns in the training set don't have a corresponding argument in `Wav2Vec2ForSequenceMultiClassification.forward` and have been ignored: duration, id. If duration, id are not expected by `Wav2Vec2ForSequenceMultiClassification.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 799
  Num Epochs = 20
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 2000
  Number of trainable parameters = 94832003
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"


Epoch,Training Loss,Validation Loss,Accuracy
1,1.0823,1.037547,0.53
2,1.062,1.112393,0.4
3,1.0758,1.069584,0.37
4,1.091,1.084402,0.32
5,1.0946,1.089886,0.31
6,1.0861,1.060857,0.36
7,1.0462,1.003772,0.53
8,1.0545,0.990443,0.55
9,1.0373,0.988915,0.55
10,1.0035,1.078376,0.41


The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForSequenceMultiClassification.forward` and have been ignored: duration, id. If duration, id are not expected by `Wav2Vec2ForSequenceMultiClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 100
  Batch size = 16
Saving model checkpoint to out/checkpoint-100
Configuration saved in out/checkpoint-100/config.json
Model weights saved in out/checkpoint-100/pytorch_model.bin
Feature extractor saved in out/checkpoint-100/preprocessor_config.json
Deleting older checkpoint [out/checkpoint-800] due to args.save_total_limit
The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForSequenceMultiClassification.forward` and have been ignored: duration, id. If duration, id are not expected by `Wav2Vec2ForSequenceMultiClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num ex

VBox(children=(Label(value='0.002 MB of 0.028 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.085927…

0,1
eval/accuracy,▇▄▃▁▁▂▇██▄▄▆█▇▅▆▆▆▆▆
eval/loss,▅█▇▇▇▆▄▄▄▇▆▄▁▄▆▅▅▄▃▃
eval/runtime,▄▁▃█▁▁▁▁██▁▄▁▁▁▁████
eval/samples_per_second,▄█▅▁████▁▁█▄████▁▁▁▁
eval/steps_per_second,▄█▅▁████▁▁█▄████▁▁▁▁
train/epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇████
train/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇████
train/learning_rate,███▇▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▅▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▁▁
train/loss,█▇▇▅▆▆▇▆█▇██▇▆▅▅▅▄▄▃▂▇▆▆▄▁▃▅▄▄▆▇▄▅▇▅▇▅▇▆
train/total_flos,▁

0,1
eval/accuracy,0.5
eval/loss,0.97249
eval/runtime,5.8737
eval/samples_per_second,17.025
eval/steps_per_second,1.192
train/epoch,20.0
train/global_step,2000.0
train/learning_rate,0.0
train/loss,1.0683
train/total_flos,1.2461570323039926e+18


Saving model checkpoint to res/models/wav2vec2-fnt-c256_256-d0-20230223-212551
Configuration saved in res/models/wav2vec2-fnt-c256_256-d0-20230223-212551/config.json
Model weights saved in res/models/wav2vec2-fnt-c256_256-d0-20230223-212551/pytorch_model.bin
Feature extractor saved in res/models/wav2vec2-fnt-c256_256-d0-20230223-212551/preprocessor_config.json


## Whisper

In [15]:
# Whisper backbone with a frozen encoder; classifier layers [256], dropout 0
TRAINING_CONFIG["feature_encoder"] = "whisper"
TRAINING_CONFIG["freeze_encoder"] = True
TRAINING_CONFIG["classifier_layers"] = [256]
TRAINING_CONFIG["classifier_dropout"] = 0

In [16]:
# Load the dataset encoded for TRAINING_CONFIG["feature_encoder"] and prepare
# it for the examples in `df`.
# NOTE: this rebinds `prepared_ds`, replacing the previously prepared dataset.
prepared_ds = load_and_prepare_ds(TRAINING_CONFIG, FEATURES_CONFIG_SUBSET, df)

# Display the prepared dataset
prepared_ds

loading configuration file preprocessor_config.json from cache at /home/alesssandros/.cache/huggingface/hub/models--openai--whisper-tiny/snapshots/302560528ac75a251232980ebcc68bad9668f664/preprocessor_config.json
Feature extractor WhisperFeatureExtractor {
  "chunk_length": 30,
  "feature_extractor_type": "WhisperFeatureExtractor",
  "feature_size": 80,
  "hop_length": 160,
  "mel_filters": [
    [
      -0.0,
      0.02486259490251541,
      0.0,
      0.0,
      0.0,
      0.0,
      0.0,
      0.0,
      0.0,
      0.0,
      0.0,
      0.0,
      0.0,
      0.0,
      0.0,
      0.0,
      0.0,
      0.0,
      0.0,
      0.0,
      0.0,
      0.0,
      0.0,
      0.0,
      0.0,
      0.0,
      0.0,
      0.0,
      0.0,
      0.0,
      0.0,
      0.0,
      0.0,
      0.0,
      0.0,
      0.0,
      0.0,
      0.0,
      0.0,
      0.0,
      0.0,
      0.0,
      0.0,
      0.0,
      0.0,
      0.0,
      0.0,
      0.0,
      0.0,
      0.0,
      0.0,
      0.0,
      0.0

Removing extra columns from dataset
Mapping features clusters


Map:   0%|          | 0/20636 [00:00<?, ? examples/s]

Extracting train split
Extracting valid split
Extracting test split
Create `ClassLabels` for target classes
{'genre': ClassLabel(names=['Electronic', 'Rock/Blues', 'World/Ethnic'], id=None)}


Casting the dataset:   0%|          | 0/799 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/100 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/100 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['label', 'id', 'duration', 'input_features'],
        num_rows: 799
    })
    valid: Dataset({
        features: ['label', 'id', 'duration', 'input_features'],
        num_rows: 100
    })
    test: Dataset({
        features: ['label', 'id', 'duration', 'input_features'],
        num_rows: 100
    })
})

## Frozen

In [17]:
# Name the run and build the model from the current TRAINING_CONFIG
# (here: Whisper with a frozen encoder)
run_name = get_run_name(TRAINING_CONFIG)
model = get_model(TRAINING_CONFIG, prepared_ds["train"])

# Train on the "train" split and evaluate on the "valid" split
trainer = get_trainer(
    run_name=run_name,
    model=model,
    train_ds=prepared_ds["train"],
    eval_ds=prepared_ds["valid"],
    training_config=TRAINING_CONFIG,
    output_dir="out",
    debug=False,
    env=NOTEBOOK_ENV,
)

trainer.train()
# Close the run via `end_training`, passing the directory used for saved models
end_training(run_name, trainer, MODELS_DIR_PATH)

loading configuration file config.json from cache at /home/alesssandros/.cache/huggingface/hub/models--openai--whisper-tiny/snapshots/302560528ac75a251232980ebcc68bad9668f664/config.json
Model config WhisperConfig {
  "_name_or_path": "openai/whisper-tiny",
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "architectures": [
    "WhisperForConditionalGeneration"
  ],
  "attention_dropout": 0.0,
  "begin_suppress_tokens": [
    220,
    50257
  ],
  "bos_token_id": 50257,
  "d_model": 384,
  "decoder_attention_heads": 6,
  "decoder_ffn_dim": 1536,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 4,
  "decoder_start_token_id": 50258,
  "dropout": 0.0,
  "encoder_attention_heads": 6,
  "encoder_ffn_dim": 1536,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 4,
  "eos_token_id": 50257,
  "forced_decoder_ids": [
    [
      1,
      50259
    ],
    [
      2,
      50359
    ],
    [
      3,
      50363
    ]
  ],
  "init_std": 0.02,
  "is_encoder_decoder": true,
  "max_lengt

PyTorch: setting up devices
The following columns in the training set don't have a corresponding argument in `WhisperForSequenceClassification.forward` and have been ignored: duration, id. If duration, id are not expected by `WhisperForSequenceClassification.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 799
  Num Epochs = 20
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 2000
  Number of trainable parameters = 99331
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"


Epoch,Training Loss,Validation Loss,Accuracy
1,1.094,1.087617,0.55
2,1.0863,1.084868,0.41
3,1.0754,1.073304,0.55
4,1.0686,1.060855,0.53
5,1.056,1.053237,0.52
6,1.0463,1.038378,0.54
7,1.0411,1.024477,0.58
8,1.0303,1.016008,0.58
9,1.0205,1.0106,0.52
10,1.0211,0.995507,0.58


The following columns in the evaluation set don't have a corresponding argument in `WhisperForSequenceClassification.forward` and have been ignored: duration, id. If duration, id are not expected by `WhisperForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 100
  Batch size = 16
Saving model checkpoint to out/checkpoint-100
Configuration saved in out/checkpoint-100/config.json
Model weights saved in out/checkpoint-100/pytorch_model.bin
Feature extractor saved in out/checkpoint-100/preprocessor_config.json
Deleting older checkpoint [out/checkpoint-1300] due to args.save_total_limit
The following columns in the evaluation set don't have a corresponding argument in `WhisperForSequenceClassification.forward` and have been ignored: duration, id. If duration, id are not expected by `WhisperForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 100
  Batch si

VBox(children=(Label(value='0.002 MB of 0.002 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
eval/accuracy,▇▁▇▆▆▆██▆█▇▆▆▆▇▇▇▇██
eval/loss,██▇▆▆▅▄▄▄▃▃▂▂▂▂▁▁▁▁▁
eval/runtime,█████████▁██▁███████
eval/samples_per_second,▁▁▁▁▁▁▁▁▁█▁▁█▁▁▁▁▁▁▁
eval/steps_per_second,▁▁▁▁▁▁▁▁▁█▁▁█▁▁▁▁▁▁▁
train/epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇████
train/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇████
train/learning_rate,███▇▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▅▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▁▁
train/loss,███▇▇▇▇▆▆▆▆▅▅▅▅▄▄▄▄▃▄▃▃▂▃▃▂▁▃▂▂▂▁▁▂▁▂▂▁▂
train/total_flos,▁

0,1
eval/accuracy,0.58
eval/loss,0.96373
eval/runtime,4.8435
eval/samples_per_second,20.646
eval/steps_per_second,1.445
train/epoch,20.0
train/global_step,2000.0
train/learning_rate,0.0
train/loss,0.9868
train/total_flos,5.9305346736e+16


Saving model checkpoint to res/models/whisper-frz-c256-d0-20230223-214404
Configuration saved in res/models/whisper-frz-c256-d0-20230223-214404/config.json
Model weights saved in res/models/whisper-frz-c256-d0-20230223-214404/pytorch_model.bin
Feature extractor saved in res/models/whisper-frz-c256-d0-20230223-214404/preprocessor_config.json


## Fine-Tuning

In [18]:
# Unfreeze the encoder; the other TRAINING_CONFIG entries keep their current values
TRAINING_CONFIG["freeze_encoder"] = False

In [19]:
# Name the run and build the model from the current TRAINING_CONFIG
# (here: Whisper with the encoder unfrozen)
run_name = get_run_name(TRAINING_CONFIG)
model = get_model(TRAINING_CONFIG, prepared_ds["train"])

# Train on the "train" split and evaluate on the "valid" split
trainer = get_trainer(
    run_name=run_name,
    model=model,
    train_ds=prepared_ds["train"],
    eval_ds=prepared_ds["valid"],
    training_config=TRAINING_CONFIG,
    output_dir="out",
    debug=False,
    env=NOTEBOOK_ENV,
)

trainer.train()
# Close the run via `end_training`, passing the directory used for saved models
end_training(run_name, trainer, MODELS_DIR_PATH)

loading configuration file config.json from cache at /home/alesssandros/.cache/huggingface/hub/models--openai--whisper-tiny/snapshots/302560528ac75a251232980ebcc68bad9668f664/config.json
Model config WhisperConfig {
  "_name_or_path": "openai/whisper-tiny",
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "architectures": [
    "WhisperForConditionalGeneration"
  ],
  "attention_dropout": 0.0,
  "begin_suppress_tokens": [
    220,
    50257
  ],
  "bos_token_id": 50257,
  "d_model": 384,
  "decoder_attention_heads": 6,
  "decoder_ffn_dim": 1536,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 4,
  "decoder_start_token_id": 50258,
  "dropout": 0.0,
  "encoder_attention_heads": 6,
  "encoder_ffn_dim": 1536,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 4,
  "eos_token_id": 50257,
  "forced_decoder_ids": [
    [
      1,
      50259
    ],
    [
      2,
      50359
    ],
    [
      3,
      50363
    ]
  ],
  "init_std": 0.02,
  "is_encoder_decoder": true,
  "max_lengt

PyTorch: setting up devices
The following columns in the training set don't have a corresponding argument in `WhisperForSequenceClassification.forward` and have been ignored: duration, id. If duration, id are not expected by `WhisperForSequenceClassification.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 799
  Num Epochs = 20
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 2000
  Number of trainable parameters = 8307715
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"


Epoch,Training Loss,Validation Loss,Accuracy
1,0.9859,0.805263,0.64
2,0.7342,0.646112,0.73
3,0.5818,0.567446,0.74
4,0.5668,0.515999,0.76
5,0.3722,0.703742,0.78
6,0.3083,0.835029,0.79
7,0.2558,0.758668,0.83
8,0.109,0.839265,0.82
9,0.0892,0.939754,0.81
10,0.0441,1.021315,0.84


The following columns in the evaluation set don't have a corresponding argument in `WhisperForSequenceClassification.forward` and have been ignored: duration, id. If duration, id are not expected by `WhisperForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 100
  Batch size = 16
Saving model checkpoint to out/checkpoint-100
Configuration saved in out/checkpoint-100/config.json
Model weights saved in out/checkpoint-100/pytorch_model.bin
Feature extractor saved in out/checkpoint-100/preprocessor_config.json
Deleting older checkpoint [out/checkpoint-700] due to args.save_total_limit
The following columns in the evaluation set don't have a corresponding argument in `WhisperForSequenceClassification.forward` and have been ignored: duration, id. If duration, id are not expected by `WhisperForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 100
  Batch siz

VBox(children=(Label(value='0.002 MB of 0.013 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.190458…

0,1
eval/accuracy,▁▄▅▅▆▆█▇▇████▇▇▇▇▇▇▇
eval/loss,▄▂▁▁▃▄▃▄▅▆▆▅▇███████
eval/runtime,██▁██▁██▁▁████▁███▁█
eval/samples_per_second,▁▁█▁▁█▁▁██▁▁▁▁█▁▁▁█▁
eval/steps_per_second,▁▁█▁▁█▁▁██▁▁▁▁█▁▁▁█▁
train/epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇████
train/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇████
train/learning_rate,███▇▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▅▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▁▁
train/loss,██▇▆▆▅▅▄▅▃▃▃▃▂▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/total_flos,▁

0,1
eval/accuracy,0.81
eval/loss,1.24314
eval/runtime,4.8442
eval/samples_per_second,20.643
eval/steps_per_second,1.445
train/epoch,20.0
train/global_step,2000.0
train/learning_rate,0.0
train/loss,0.0003
train/total_flos,5.9305346736e+16


Saving model checkpoint to res/models/whisper-fnt-c256-d0-20230223-215738
Configuration saved in res/models/whisper-fnt-c256-d0-20230223-215738/config.json
Model weights saved in res/models/whisper-fnt-c256-d0-20230223-215738/pytorch_model.bin
Feature extractor saved in res/models/whisper-fnt-c256-d0-20230223-215738/preprocessor_config.json


## Classification Head

In [20]:
# Switch the classifier layers to [256, 256]; `freeze_encoder` keeps its current value
TRAINING_CONFIG["classifier_layers"] = [256, 256]

In [21]:
# Name the run and build the model from the current TRAINING_CONFIG
# (here: Whisper with classifier layers [256, 256])
run_name = get_run_name(TRAINING_CONFIG)
model = get_model(TRAINING_CONFIG, prepared_ds["train"])

# Train on the "train" split and evaluate on the "valid" split
trainer = get_trainer(
    run_name=run_name,
    model=model,
    train_ds=prepared_ds["train"],
    eval_ds=prepared_ds["valid"],
    training_config=TRAINING_CONFIG,
    output_dir="out",
    debug=False,
    env=NOTEBOOK_ENV,
)

trainer.train()
# Close the run via `end_training`, passing the directory used for saved models
end_training(run_name, trainer, MODELS_DIR_PATH)

loading configuration file config.json from cache at /home/alesssandros/.cache/huggingface/hub/models--openai--whisper-tiny/snapshots/302560528ac75a251232980ebcc68bad9668f664/config.json
Model config WhisperConfig {
  "_name_or_path": "openai/whisper-tiny",
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "architectures": [
    "WhisperForConditionalGeneration"
  ],
  "attention_dropout": 0.0,
  "begin_suppress_tokens": [
    220,
    50257
  ],
  "bos_token_id": 50257,
  "d_model": 384,
  "decoder_attention_heads": 6,
  "decoder_ffn_dim": 1536,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 4,
  "decoder_start_token_id": 50258,
  "dropout": 0.0,
  "encoder_attention_heads": 6,
  "encoder_ffn_dim": 1536,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 4,
  "eos_token_id": 50257,
  "forced_decoder_ids": [
    [
      1,
      50259
    ],
    [
      2,
      50359
    ],
    [
      3,
      50363
    ]
  ],
  "init_std": 0.02,
  "is_encoder_decoder": true,
  "max_lengt

PyTorch: setting up devices
The following columns in the training set don't have a corresponding argument in `WhisperForSequenceClassification.forward` and have been ignored: duration, id. If duration, id are not expected by `WhisperForSequenceClassification.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 799
  Num Epochs = 20
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 2000
  Number of trainable parameters = 8373507
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"


Epoch,Training Loss,Validation Loss,Accuracy
1,1.0872,1.038829,0.65
2,0.8018,0.695724,0.76
3,0.6335,0.565698,0.75
4,0.4261,0.483982,0.8
5,0.3406,0.614469,0.78
6,0.2092,0.98445,0.74
7,0.2133,0.642029,0.85
8,0.0962,0.906091,0.83
9,0.0685,0.940986,0.78
10,0.0214,0.60179,0.89


The following columns in the evaluation set don't have a corresponding argument in `WhisperForSequenceClassification.forward` and have been ignored: duration, id. If duration, id are not expected by `WhisperForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 100
  Batch size = 16
Saving model checkpoint to out/checkpoint-100
Configuration saved in out/checkpoint-100/config.json
Model weights saved in out/checkpoint-100/pytorch_model.bin
Feature extractor saved in out/checkpoint-100/preprocessor_config.json
Deleting older checkpoint [out/checkpoint-1000] due to args.save_total_limit
The following columns in the evaluation set don't have a corresponding argument in `WhisperForSequenceClassification.forward` and have been ignored: duration, id. If duration, id are not expected by `WhisperForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 100
  Batch si

VBox(children=(Label(value='0.002 MB of 0.027 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.088952…

0,1
eval/accuracy,▁▄▄▅▅▄▇▆▅█▇▆▇▇▇▇▇▇▇▇
eval/loss,█▄▂▁▃▇▃▆▇▂▅▆▆▆▅▆▆▆▆▆
eval/runtime,██▇███▁▇█▇███▅▇▇▇▁▇█
eval/samples_per_second,▁▁▁▁▁▁█▁▁▁▁▁▁▂▁▁▁█▁▁
eval/steps_per_second,▁▁▁▁▁▁█▁▁▁▁▁▁▂▁▁▁█▁▁
train/epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇████
train/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇████
train/learning_rate,███▇▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▅▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▁▁
train/loss,███▇▆▅▅▄▄▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/total_flos,▁

0,1
eval/accuracy,0.85
eval/loss,0.85691
eval/runtime,4.8441
eval/samples_per_second,20.644
eval/steps_per_second,1.445
train/epoch,20.0
train/global_step,2000.0
train/learning_rate,0.0
train/loss,0.0002
train/total_flos,5.98099976928e+16


Saving model checkpoint to res/models/whisper-fnt-c256_256-d0-20230223-221234
Configuration saved in res/models/whisper-fnt-c256_256-d0-20230223-221234/config.json
Model weights saved in res/models/whisper-fnt-c256_256-d0-20230223-221234/pytorch_model.bin
Feature extractor saved in res/models/whisper-fnt-c256_256-d0-20230223-221234/preprocessor_config.json


# Training 2

In [22]:
# Baseline configuration for the second round of experiments (full dataset).
TRAINING_CONFIG = dict(
    # Optimisation schedule.
    epochs=20,
    learning_rate=5e-5,
    warmup=0.0,
    # Batch sizes used for training and evaluation respectively.
    train_batch_size=8,
    eval_batch_size=16,
    # Backbone: Whisper encoder, left unfrozen.
    feature_encoder="whisper",
    freeze_encoder=False,
    # Classification head: one hidden layer of 256 units, dropout disabled.
    classifier_layers=[256],
    classifier_dropout=0.0,
)

## Genre Classification

In [23]:
# Fetch the DataFrame for the genre configuration (the helper reports the CSV
# it loads and the number of rows per split), then preview the first rows.
df = create_or_load_df(FEATURES_CONFIG_GEN)

df.head(n=5)

Loading res/samples_clustered_genre6.csv
16932 examples in DataFrame
split
train    13545
test      1694
valid     1693
dtype: int64


Unnamed: 0,path,duration,id,genre,split
0,01 Hip Hop/Abandoned Brass Stabs.mp3,7.262041,01_Hip_Hop_Abandoned_Brass_Stabs,Hip Hop/RnB,test
1,01 Hip Hop/Against Time Keys.mp3,6.948571,01_Hip_Hop_Against_Time_Keys,Hip Hop/RnB,train
2,01 Hip Hop/Against Time Piano.mp3,6.948571,01_Hip_Hop_Against_Time_Piano,Hip Hop/RnB,train
3,01 Hip Hop/Against Time Sax Sample.mp3,6.948571,01_Hip_Hop_Against_Time_Sax_Sample,Hip Hop/RnB,valid
4,01 Hip Hop/Against Time Staccato Strings.mp3,6.948571,01_Hip_Hop_Against_Time_Staccato_Strings,Hip Hop/RnB,train


In [24]:
# Load the encoded dataset and prepare it for genre classification; the logged
# output shows the test/train/valid splits being extracted and the `genre`
# target being cast to a ClassLabel.
prepared_ds = load_and_prepare_ds(TRAINING_CONFIG, FEATURES_CONFIG_GEN, df)

# Display the resulting DatasetDict (splits, features and row counts).
prepared_ds

Loading cached processed dataset at /home/alesssandros/dev/FCN_Newspaper/aii/res/datasets/ds-whisper-full-encoded/cache-6571d4b477ed53ce.arrow


Removing extra columns from dataset
Mapping features clusters
Extracting test split
Extracting train split
Extracting valid split
Create `ClassLabels` for target classes
{'genre': ClassLabel(names=['Electronic', 'Hip Hop/RnB', 'House', 'Orchestral', 'Rock/Blues', 'World/Ethnic'], id=None)}


Casting the dataset:   0%|          | 0/1694 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/13545 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/1693 [00:00<?, ? examples/s]

DatasetDict({
    test: Dataset({
        features: ['label', 'id', 'duration', 'input_features'],
        num_rows: 1694
    })
    train: Dataset({
        features: ['label', 'id', 'duration', 'input_features'],
        num_rows: 13545
    })
    valid: Dataset({
        features: ['label', 'id', 'duration', 'input_features'],
        num_rows: 1693
    })
})

In [25]:
# Build a run identifier and a fresh model for the genre classification task.
run_name = get_run_name(TRAINING_CONFIG)
model = get_model(TRAINING_CONFIG, prepared_ds["train"])

trainer = get_trainer(
    run_name=run_name,
    model=model,
    train_ds=prepared_ds["train"],
    eval_ds=prepared_ds["valid"],
    training_config=TRAINING_CONFIG,
    # Keep intermediate checkpoints in a per-run directory: with a single
    # shared "out" folder, checkpoint rotation (save_total_limit) also sees
    # and deletes checkpoints left behind by previous runs.
    output_dir=os.path.join("out", run_name),
    debug=False,
    env=NOTEBOOK_ENV,
)

# Train, then finalize the run; the final model is saved under MODELS_DIR_PATH.
trainer.train()
end_training(run_name, trainer, MODELS_DIR_PATH)

loading configuration file config.json from cache at /home/alesssandros/.cache/huggingface/hub/models--openai--whisper-tiny/snapshots/302560528ac75a251232980ebcc68bad9668f664/config.json
Model config WhisperConfig {
  "_name_or_path": "openai/whisper-tiny",
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "architectures": [
    "WhisperForConditionalGeneration"
  ],
  "attention_dropout": 0.0,
  "begin_suppress_tokens": [
    220,
    50257
  ],
  "bos_token_id": 50257,
  "d_model": 384,
  "decoder_attention_heads": 6,
  "decoder_ffn_dim": 1536,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 4,
  "decoder_start_token_id": 50258,
  "dropout": 0.0,
  "encoder_attention_heads": 6,
  "encoder_ffn_dim": 1536,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 4,
  "eos_token_id": 50257,
  "forced_decoder_ids": [
    [
      1,
      50259
    ],
    [
      2,
      50359
    ],
    [
      3,
      50363
    ]
  ],
  "init_std": 0.02,
  "is_encoder_decoder": true,
  "max_lengt

PyTorch: setting up devices
The following columns in the training set don't have a corresponding argument in `WhisperForSequenceClassification.forward` and have been ignored: duration, id. If duration, id are not expected by `WhisperForSequenceClassification.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 13545
  Num Epochs = 20
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 33880
  Number of trainable parameters = 8308486
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"


Epoch,Training Loss,Validation Loss


KeyboardInterrupt: 

## Category Classification

In [None]:
# Fetch the DataFrame for the category configuration, then preview its first
# rows. NOTE(review): this cell has no saved output; presumably the helper
# loads or creates a category-specific CSV - confirm on the next run.
df = create_or_load_df(FEATURES_CONFIG_CAT)

df.head(n=5)

In [None]:
# Load the encoded dataset and prepare it for the target described by
# FEATURES_CONFIG_CAT.
# NOTE(review): this cell has no saved output; the result is presumably a
# DatasetDict with test/train/valid splits - confirm on the next run.
prepared_ds = load_and_prepare_ds(TRAINING_CONFIG, FEATURES_CONFIG_CAT, df)

# Display the prepared dataset.
prepared_ds

In [None]:
# Build a run identifier and a fresh model for the category classification task.
run_name = get_run_name(TRAINING_CONFIG)
model = get_model(TRAINING_CONFIG, prepared_ds["train"])

trainer = get_trainer(
    run_name=run_name,
    model=model,
    train_ds=prepared_ds["train"],
    eval_ds=prepared_ds["valid"],
    training_config=TRAINING_CONFIG,
    # Keep intermediate checkpoints in a per-run directory: with a single
    # shared "out" folder, checkpoint rotation (save_total_limit) also sees
    # and deletes checkpoints left behind by previous runs.
    output_dir=os.path.join("out", run_name),
    debug=False,
    env=NOTEBOOK_ENV,
)

# Train, then finalize the run; the final model is saved under MODELS_DIR_PATH.
trainer.train()
end_training(run_name, trainer, MODELS_DIR_PATH)