In [None]:
# Mount Google Drive at /content/gdrive; the metadata CSV and audio paths used
# later in this notebook live under /content/gdrive/MyDrive.
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
%%capture
!pip install pytorch-lightning
!pip install torchdata
!pip install boto3
!pip install transformers

In [None]:
import copy
import math
import os
import pytorch_lightning
import re
import shutil
import subprocess
import torch
import torchaudio
import transformers
import warnings
import multiprocessing as mp
import pandas as pd
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from imblearn.over_sampling import SMOTENC
from pytorch_lightning import LightningModule, LightningDataModule, Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.utils.class_weight import compute_class_weight
from torchmetrics.functional import mean_squared_error, mean_absolute_error
from torchmetrics.functional.classification import multiclass_accuracy, multiclass_f1_score
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import (
    AdamW,
    AutoConfig,
    AutoModel,
    UniSpeechModel,
    UniSpeechPreTrainedModel
)
# Silence every Python warning for the rest of the session.
# NOTE(review): this also hides deprecation notices from torch/transformers.
warnings.filterwarnings('ignore')
# Enable tqdm's pandas integration (progress_apply and friends).
tqdm.pandas()

# Use the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Record library versions and the selected device in the cell output.
print(torch.__version__)
print(torchaudio.__version__)
print(DEVICE)

2.0.0+cu118
2.0.1+cu118
cuda


# Preprocess data


As a first step, the dataset will be subsampled and balanced to ensure an approximate proportional distribution of classes. Steps include:

- Load main data and apply some reformatting
- Proportionately subsample the data
- Convert audio of the subsampled set (mp3 -> wav) and resample it to 16 kHz
- Remove low-count classes (100 samples or fewer)
- Create separate class-based datasets in dataframe form (e.g. age, accent)


In [None]:
# metadata = pd.read_csv('/content/gdrive/MyDrive/EDA/metadata.csv')
# metadata['audio'] = metadata["audio"].apply(lambda x: os.path.split(x)[1])
# metadata['utt_len'] = metadata["utterance"].apply(lambda x: len(x))
# metadata = metadata[["gender", "age", "accent", "utterance", "utt_len", "audio", "duration"]]
# metadata.insert(0, 'id', metadata["audio"].apply(lambda x: x.split('_')[-1][:-4]))
# metadata['audio'] = metadata["audio"].apply(lambda x: f"/content/gdrive/MyDrive/model/audio_processed_original/{x[:-3]}wav")
# metadata

In [None]:
def get_counts_perc(df, column):
    """Tabulate how often each distinct value of `column` occurs in `df`.

    Args:
        df: DataFrame holding the column to summarise.
        column: Name of the column whose values are counted.

    Returns:
        DataFrame indexed by the distinct values with two columns:
        'count' (absolute frequency) and 'percentage' (share of rows,
        rounded to one decimal and rendered as a string ending in '%').
    """
    series = df[column]
    absolute = series.value_counts()
    share = series.value_counts(normalize=True)
    share_text = share.mul(100).round(1).astype(str) + '%'
    return pd.concat([absolute, share_text], axis=1, keys=['count', 'percentage'])

In [None]:
# def balance_features(metadata):

#     female_n = 875
#     male_n = 375
#     age_n = 156
#     accent_n = 1250
#     feats = {}
#     f_list = []

#     for index, row in metadata.iterrows():
#         if row['gender'] == 'other':
#             continue

#         if row['accent'] in feats:
#             feats[row['accent']]['count'] += 1
#         else:
#             feats[row['accent']] = {'count': 0}

#         if feats[row['accent']]['count'] <= accent_n:
#             if row['age'] in feats[row['accent']]:
#                 feats[row['accent']][row['age']] += 1
#             else:
#                 feats[row['accent']][row['age']] = 1

#             if feats[row['accent']][row['age']] <= age_n: 
#                 if row['gender'] in feats[row['accent']]:
#                     feats[row['accent']][row['gender']] += 1
#                 else:
#                     feats[row['accent']][row['gender']] = 1

#                 if row['gender'] == 'male' and feats[row['accent']][row['gender']] <= male_n or \
#                  row['gender'] == 'female' and feats[row['accent']][row['gender']] <= female_n :
#                     f_list.append({
#                         "gender": row['gender'],
#                         "age": row['age'], 
#                         "accent": row['accent'],
#                         "utterance": row['utterance'],
#                         "utt_len": row['utt_len'],
#                         "audio": row['audio'],
#                         "duration": row['duration'],
#                     })
#                 else:
#                     feats[row['accent']]['count'] -= 1
#                     feats[row['accent']][row['age']] -= 1
#                     continue

#             else:
#                 feats[row['accent']]['count'] -= 1
#                 continue

#     return f_list

In [None]:
# md_shuffled = metadata.sample(frac=1).reset_index(drop=True)
# balanced_list = balance_features(md_shuffled)
# meta_balanced_pre = pd.DataFrame(balanced_list)
# meta_balanced_pre

In [None]:
# get_counts_perc(meta_balanced_pre, 'gender')

In [None]:
# get_counts_perc(meta_balanced_pre, 'age')

In [None]:
# get_counts_perc(meta_balanced_pre, 'accent')

In [None]:
# filepath = '/content/gdrive/MyDrive/model/meta_balanced_v0.csv'
# meta_balanced_pre.to_csv(filepath, index=False)

In [None]:
def load_dataframe(version='', base_dir='/content/gdrive/MyDrive/model'):
    """Load a balanced metadata CSV and point its audio column at the wav files.

    Reads ``{base_dir}/meta_balanced{version}.csv``, inserts an ``id`` column
    as the first column, and rewrites ``audio`` from a bare file name to
    ``{base_dir}/audio{version}/<stem>.wav``.

    Args:
        version (str, optional): Suffix selecting the dataset version, e.g.
            ``'_v3'``. It is appended to both the CSV name and the audio folder.
        base_dir (str, optional): Folder that contains the CSV and the audio
            sub-folder. Defaults to the Drive location used by this notebook.

    Returns:
        pandas.DataFrame: The metadata with the added ``id`` column and the
        rewritten ``audio`` paths.
    """
    import os  # local import keeps the function self-contained

    df = pd.read_csv(f'{base_dir}/meta_balanced{version}.csv', low_memory=False)

    def _stem(name):
        # Drop the extension whatever its length instead of slicing a fixed
        # number of characters (the old slices assumed a 3-letter extension).
        return os.path.splitext(name)[0]

    # The id is the text after the last underscore of the file name, minus extension.
    df.insert(0, 'id', df["audio"].apply(lambda x: _stem(x.split('_')[-1])))
    df['audio'] = df["audio"].apply(lambda x: f"{base_dir}/audio{version}/{_stem(x)}.wav")
    return df

In [None]:
# Bind the v3 balanced metadata to `mb_df_v3`: that is the name the later
# get_class_data(...) cell reads. Binding it to `df` left `mb_df_v3` undefined
# on a fresh kernel (NameError under Restart & Run All).
mb_df_v3 = load_dataframe('_v3')
mb_df_v3

Unnamed: 0,id,gender,age,accent,utterance,utt_len,audio,duration
0,23809244,male,twenties,United States English,"Married to French Stefani, with whom he later ...",63,/content/gdrive/MyDrive/model/audio_v3/common_...,5.664
1,21091923,male,fourties,"German English,Non native speaker",There was a catapult on either side of the aft...,56,/content/gdrive/MyDrive/model/audio_v3/common_...,6.888
2,18750828,female,twenties,"India and South Asia (India, Pakistan, Sri Lanka)",It remains hot until the monsoon breaks toward...,67,/content/gdrive/MyDrive/model/audio_v3/common_...,5.640
3,19057438,male,twenties,Canadian English,"Returning home, Michael does not find Grant an...",77,/content/gdrive/MyDrive/model/audio_v3/common_...,6.840
4,20969631,male,fourties,"German English,Non native speaker",Just to the northeast is the crater Demonax.,44,/content/gdrive/MyDrive/model/audio_v3/common_...,6.216
...,...,...,...,...,...,...,...,...
15145,32443222,male,teens,"United States English,Southwestern United Stat...",She was the first female in that role at Oneonta.,49,/content/gdrive/MyDrive/model/audio_v3/common_...,5.040
15146,31324740,male,teens,Ukrainian,Historically this type of boat was used by Gow...,70,/content/gdrive/MyDrive/model/audio_v3/common_...,5.616
15147,32639044,female,thirties,"Northern Irish,Norwegian,yorkshire",The episode then shows the adventures Susan im...,53,/content/gdrive/MyDrive/model/audio_v3/common_...,5.580
15148,30615149,male,teens,"Australian English,England English",He was assassinated by his political opponents...,57,/content/gdrive/MyDrive/model/audio_v3/common_...,4.680


In [None]:
# def convert_audio(infile):
#     outpath = '/content/gdrive/MyDrive/model/audio_v3'
#     outfile = f"{outpath}/{infile[:-4]}.wav"
#     inpath = '/content/audio_original/audio2'
#     infile = f"{inpath}/{infile}"
#     if not os.path.isfile(outfile):
#         subprocess.run(['sox', infile, '-r', '16000', outfile])

In [None]:
# !mkdir /content/gdrive/MyDrive/model/audio_v3

In [None]:
# with mp.Pool(mp.cpu_count()) as pool:
#     results = tqdm(
#         pool.imap_unordered(convert_audio, mb_df_v3["audio"], chunksize=5),
#         total=len(mb_df_v3["audio"]),
#     )
#     for result in results:
#         if result:
#             print(result)

In [None]:
def get_class_data(df, min_count=100):
    """Drop rare classes, separately for the age and the accent task.

    A row is kept in a view only when its class occurs strictly more than
    `min_count` times in `df`. The two views are filtered independently, so a
    row can survive in one and be removed from the other.

    Args:
        df: DataFrame with at least the columns 'age' and 'accent'.
        min_count (int, optional): Classes with this many rows or fewer are
            removed. Defaults to 100.

    Returns:
        tuple: ``(age_data, accent_data)`` in that order.

    Side effects:
        Prints the number of rows kept in the age view.
    """
    accent_sizes = df.groupby('accent').accent.transform('count')
    age_sizes = df.groupby('age').age.transform('count')
    accent_data = df[accent_sizes > min_count]
    age_data = df[age_sizes > min_count]
    print(len(age_data))
    return age_data, accent_data

In [None]:
# Build the per-task views of the metadata. get_class_data returns
# (age_data, accent_data) and prints the row count of the age view.
AGE_DATA, ACCENT_DATA = get_class_data(mb_df_v3)

15150


In [None]:
# get_counts_perc(AGE_DATA, 'age')

In [None]:
# len(ACCENT_DATA.accent.unique())

In [None]:
# get_counts_perc(ACCENT_DATA, 'accent')

In [None]:
# Total audio length of the age subset, divided by 60 twice.
# Presumably `duration` is in seconds, which makes this hours — TODO confirm.
(AGE_DATA["duration"].sum()/60)/60

23.63253399305556

## Prepare

Prepare data objects to load during training. Objects include:

- Processing object to:
    - encode labels from text to integer
    - load audio to tensor and optionally resample to correct frequency
    - Optionally normalize mean and standard deviation of audio
- Dataset object to load and apply processing to elements in a batch.
- Custom collate object to:
    - Derive attention mask for inputs
    - apply batchwise padding to input audio and mask
- Datamodule object to:
    - prepare train/validation/test data
    - sort train and validation by length so batches contain similar length inputs
    - Load dataset object
    - Load custom collate function


In [None]:
# Sample rate handed to the data module; clips with a different rate are
# resampled to it when they are loaded.
SAMPLE_RATE = 16000

# Sorted class names: a label's position in the list is its integer id.
AGE_LABELS = sorted(AGE_DATA.age.unique())
AGE_IDX = list(range(len(AGE_LABELS)))

ACCENT_LABELS = sorted(ACCENT_DATA.accent.unique())
ACCENT_IDX = list(range(len(ACCENT_LABELS)))

In [None]:
%env LC_ALL=C.UTF-8
%env LANG=C.UTF-8
%env TRANSFORMERS_CACHE=/content/cache

env: LC_ALL=C.UTF-8
env: LANG=C.UTF-8
env: TRANSFORMERS_CACHE=/content/cache


In [None]:
from torch.utils.data import DataLoader, Dataset

class ClassProcessor:
    """Turns one metadata row into a (waveform, label index) pair."""

    def __init__(self, labels=None, sample_rate=0):
        """Args:
            labels (list, optional): Ordered class names. A label's position in
                this list is its integer id. Defaults to an empty list.
            sample_rate (int, optional): Sample rate the audio is resampled to
                when the file's own rate differs. Defaults to 0.
        """

        # Create the fallback list per instance: a `[]` default argument would
        # be one shared object across every processor built without labels.
        self.labels = labels if labels is not None else []
        self.target_sample_rate = sample_rate

    def _encode_label(self, label):
        """Return the position of `label` in `self.labels`, or -1 if it is absent."""
        try:
            # Single scan instead of a membership test followed by index().
            return self.labels.index(label)
        except ValueError:
            return -1

    def _resample_waveform(self, waveform):
        """Load an audio file and resample it to the target rate if needed.

        Args:
            waveform: Location of the audio file, passed straight to
                `torchaudio.load` (despite the name, this is not audio data).

        Returns:
            The tensor returned by `torchaudio.load`, resampled when the file's
            sample rate is not `self.target_sample_rate`.
        """

        speech_array, sample_rate = torchaudio.load(waveform)
        if sample_rate != self.target_sample_rate:
            resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
            speech_array = resampler(speech_array)
        return speech_array

    def _normalize_audio(self, waveform):
        """Shift to zero mean and scale to unit standard deviation.

        The 1e-10 term keeps the division finite for a constant signal.
        """

        waveform_mean = torch.mean(waveform)
        waveform_std = torch.std(waveform)
        return (waveform - waveform_mean) / (waveform_std + 1e-10)

    def __call__(self, data, column):
        """Process one row.

        Args:
            data: Mapping-like row with an 'audio' entry and a `column` entry.
            column: Name of the entry holding the class label.

        Returns:
            tuple: ``(waveform, label)`` where `waveform` is a NumPy array with
            size-1 dimensions squeezed out and `label` is the integer id
            (-1 for a label not in `self.labels`).
        """

        waveform = self._resample_waveform(data['audio'])
        # Normalisation is currently switched off; re-enable by uncommenting.
        # waveform = self._normalize_audio(waveform)
        label = self._encode_label(data[column])
        # NOTE(review): squeeze() yields a 1-D array only for single-channel
        # audio — confirm all clips are mono.
        return waveform.squeeze().numpy(), label


class ClassDataset(Dataset):
    """Dataset that loads and processes one metadata row per item."""

    def __init__(self, data, labels, sample_rate, column):
        """Args:
            data: DataFrame of metadata rows; each row needs an 'audio' entry
                and a `column` entry.
            labels (list): Ordered class names used to encode the label.
            sample_rate (int): Sample rate the audio is resampled to.
            column (str): Name of the column holding the class label.
        """

        self.processor = ClassProcessor(labels, sample_rate)
        self.column = column
        self.data = data

    def __len__(self):
        """Return the number of rows in `self.data`."""

        return len(self.data)

    def __getitem__(self, idx):
        """Process the row at position `idx`.

        Returns:
            dict with 'input_values' (waveform) and 'labels' (integer id), or
            None when the row could not be processed. None is a deliberate
            best-effort signal: the collate function filters such items out.
        """
        row = self.data.iloc[idx]
        try:
            waveform, label = self.processor(row, self.column)
        except Exception as err:
            # Report before returning; the original print sat after `return`
            # and never ran. `Exception` (not a bare except) lets
            # KeyboardInterrupt/SystemExit propagate.
            print(f"no file: {row.get('audio')!r} ({err})")
            return None
        return {'input_values': waveform, 'labels': label}

In [None]:
class CollateFn:
    """Collate callable that pads variable-length waveforms into one batch."""

    def get_pad(self, waveforms, mask=False):
        """Return a zero tensor of shape (number of waveforms, longest length).

        `mask` is not used; it is kept so existing callers keep working.
        """

        waveform_size = len(max(waveforms, key=len))
        return torch.zeros(len(waveforms), waveform_size)

    def pad_data(self, waveforms):
        """Right-pad every waveform with zeros to the longest one.

        Returns:
            tuple: ``(waveforms_padded, mask_padded)``; the mask holds 1 at
            positions that carry real samples and 0 at padded positions.
        """

        mask_padded = self.get_pad(waveforms, mask=True)
        waveforms_padded = self.get_pad(waveforms)
        for i, waveform in enumerate(waveforms):
            length = len(waveform)
            waveforms_padded[i, :length] = torch.tensor(waveform)
            mask_padded[i, :length] = 1
        return waveforms_padded, mask_padded

    def __call__(self, batch):
        """Build a padded batch from dataset items.

        Falsy items (rows the dataset could not load) are skipped.

        Returns:
            dict with 'input_values', 'attention_mask' and 'labels', or None
            when no item is usable or the batch cannot be assembled. Returning
            None is best-effort: the training and validation steps skip a
            falsy batch.
        """

        samples = [feature for feature in batch if feature]
        if not samples:
            # Every row failed to load: skip explicitly rather than relying on
            # max() raising on an empty list.
            return None

        input_features = [sample["input_values"] for sample in samples]
        label_features = [sample["labels"] for sample in samples]

        try:
            waveforms_padded, mask_padded = self.pad_data(input_features)
            return {
                "input_values": waveforms_padded,
                "attention_mask": mask_padded,
                "labels": torch.tensor(label_features),
            }
        except Exception as err:
            # Still best-effort, but no longer silent and no longer trapping
            # KeyboardInterrupt/SystemExit as the bare except did.
            print(f"skipping batch: {err}")
            return None


class ClassDataModule(LightningDataModule):
    """Lightning data module that splits the metadata and serves the three loaders."""

    def __init__(
        self,
        data,
        labels,
        sample_rate,
        out_column,
        batch_size
    ):
        """Args:
            data: DataFrame of raw metadata (needs 'audio', 'duration' and the
                `out_column` column).
            labels (list): Ordered class names.
            sample_rate (int): Sample rate the audio is resampled to.
            out_column (str): Name of the column holding the class label.
            batch_size (int): Number of items per batch.
        """

        super().__init__()
        self.data = data
        self.labels = labels
        self.sample_rate = sample_rate
        self.column = out_column
        self.batch_size = batch_size

    def setup(self, stage=None):
        """Create stratified 80/10/10 train/validation/test splits.

        Train and validation rows are ordered by 'duration'; the test rows are
        left in split order. Results are stored on `self.train_data`,
        `self.val_data` and `self.test_data`, each reduced to the label column
        and 'audio'.
        """

        split_seed = 123
        train_part, holdout = train_test_split(
            self.data,
            test_size=0.2,
            random_state=split_seed,
            stratify=self.data[self.column],
        )
        # Halve the 20% hold-out into validation and test.
        val_part, test_part = train_test_split(
            holdout,
            test_size=0.5,
            random_state=split_seed,
            stratify=holdout[self.column],
        )

        kept_columns = [self.column, "audio"]
        self.train_data = train_part.sort_values("duration")[kept_columns]
        self.val_data = val_part.sort_values("duration")[kept_columns]
        self.test_data = test_part[kept_columns]

    def _make_loader(self, frame):
        """Wrap `frame` in a ClassDataset and an unshuffled, padded DataLoader."""

        dataset = ClassDataset(frame, self.labels, self.sample_rate, self.column)
        return DataLoader(dataset=dataset,
                          batch_size=self.batch_size,
                          drop_last=False,
                          collate_fn=CollateFn(),
                          num_workers=2,
                          shuffle=False)

    def train_dataloader(self):
        """Return the loader over the training split (kept in duration order)."""

        return self._make_loader(self.train_data)

    def val_dataloader(self):
        """Return the loader over the validation split."""

        return self._make_loader(self.val_data)

    def test_dataloader(self):
        """Return the loader over the test split."""

        return self._make_loader(self.test_data)

## Compile Model

Compile and train a classification model for speech feature classes. Includes:

- A model trainer class for loading the pretrained model and training a classification head.
- A classification head for fine-tuning a self-supervised pretrained model (UniSpeech).
- A two-stage learning rate scheduler object.
- Metrics compute function, which includes accuracy and F1 score.

In [None]:
def compute_metrics(preds, targets, num_labels):
    """Score a batch of predictions.

    Args:
        preds: Per-class scores; the class is taken as the argmax over the
            last dimension.
        targets: Integer class ids to compare against.
        num_labels (int): Number of classes passed to both metrics.

    Returns:
        dict: 'accuracy' from `multiclass_accuracy` (library-default
        averaging) and 'f1_score' from `multiclass_f1_score` with macro
        averaging.
    """

    predicted_ids = preds.argmax(-1)
    return {
        "accuracy": multiclass_accuracy(predicted_ids, targets, num_classes=num_labels),
        "f1_score": multiclass_f1_score(predicted_ids, targets, num_classes=num_labels, average='macro'),
    }

In [None]:
class BiStageLRScheduler(torch.optim.lr_scheduler._LRScheduler):
    """Learning-rate schedule with a linear warm-up followed by an exponential decay."""

    def __init__(self, optimizer, warmup_updates, decay_updates):
        """Args:
            optimizer: Optimizer whose learning rates are scheduled.
            warmup_updates (int): Number of steps in the warm-up stage.
            decay_updates (int): Number of steps in the decay stage.
        """

        # Multipliers of the base rate at the start of warm-up and at the end of decay.
        self.init_lr_scale = 0.01
        self.final_lr_scale = 0.05
        self.warmup_updates = warmup_updates
        self.decay_updates = decay_updates
        super().__init__(optimizer, last_epoch=-1)

    def get_lr(self):
        """Return one learning rate per base rate for the current step.

        The step is read from the inherited `_step_count` counter:
        - up to `warmup_updates`: the multiplier rises linearly from
          `init_lr_scale` towards 1;
        - for the next `decay_updates` steps: it decays exponentially from 1
          to `final_lr_scale`;
        - afterwards: it stays at `final_lr_scale`.
        """

        step = self._step_count
        if step <= self.warmup_updates:
            factor = self.init_lr_scale + step / self.warmup_updates * (1 - self.init_lr_scale)
        elif step <= (self.warmup_updates + self.decay_updates):
            factor = math.exp(math.log(self.final_lr_scale) * (step - self.warmup_updates) / self.decay_updates)
        else:
            factor = self.final_lr_scale
        return [base_lr * factor for base_lr in self.base_lrs]

In [None]:
class ClassifierModel(UniSpeechPreTrainedModel):
    """UniSpeech backbone with a pooled, two-layer classification head."""

    def __init__(self, config, class_weights):
        """Args:
            config: Model configuration; `num_hidden_layers`, `hidden_size`
                and `num_labels` are read from it.
            class_weights: Per-class weights passed to the cross-entropy loss.
        """
        super().__init__(config)

        self.unispeech = AutoModel.from_config(config)

        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
        # Learnable mixing weights over the layer outputs, initialised uniformly.
        self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)

        self.feature_extractor = nn.Linear(config.hidden_size, config.hidden_size//2)
        self.classifier = nn.Linear(config.hidden_size//2, config.num_labels)
        self.loss_fn = torch.nn.CrossEntropyLoss(weight=class_weights)
        # Initialize weights and apply final processing
        self.post_init()

    def freeze_feature_extractor(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
        not be updated during training.
        """

        self.unispeech.feature_extractor._freeze_parameters()

    def _get_feat_extract_output_lengths(self, input_lengths):
        """
        Computes the output length of the convolutional layers by applying
        each (kernel, stride) pair from the config in turn.
        """

        def _conv_out_length(input_length, kernel_size, stride):
            # Output length of one un-padded convolution.
            return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1

        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

        return input_lengths

    def _get_feature_vector_attention_mask(self, feature_vector_length, attention_mask):
        """Convert a sample-level mask into a boolean frame-level mask.

        Args:
            feature_vector_length (int): Number of frames per sequence.
            attention_mask: Sample-level mask; its row sums are taken as the
                number of real (non-padded) samples.

        Returns:
            Boolean tensor of shape (batch, feature_vector_length) that is True
            for frames before each sequence's computed output length.
        """

        # Effectively attention_mask.sum(-1), but not inplace to be able to run
        # on inference mode.
        non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
        output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths).to(torch.long)
        batch_size = attention_mask.shape[0]

        attention_mask = torch.zeros(
            (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
        )
        # these two operations makes sure that all values before the output lengths idxs are attended to
        attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
        attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
        return attention_mask

    def forward(
        self,
        input_values,
        attention_mask = None,
        output_hidden_states = True,
        labels = None,
    ):

        """
        Mix the backbone's layer outputs, mean-pool over frames and classify.

        Args:
            input_values: Batch of padded waveforms.
            attention_mask (optional): Sample-level mask marking real samples.
                When omitted, every frame is averaged.
            output_hidden_states (bool, optional): Forwarded to the backbone.
                The layer mix reads the hidden states from the output, so this
                should stay True.
            labels (optional): Integer class ids; when given, the loss is computed.

        Returns:
            tuple: ``(loss, logits)``; `loss` is None when `labels` is None.
        """

        outputs = self.unispeech(
            input_values,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states
        )

        # NOTE(review): index 2 is taken to be the tuple of per-layer hidden
        # states — confirm against the backbone's output layout.
        hidden_states = outputs[2]
        hidden_states = torch.stack(hidden_states, dim=1)
        # Softmax keeps the layer mixing weights positive and summing to one.
        norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
        hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)

        if attention_mask is None:
            # The signature allows a missing mask, but the masked pooling below
            # needs one; without a mask there is no padding information, so
            # average over all frames.
            pooling = hidden_states.mean(dim=1)
        else:
            padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
            # Zero padded frames so they do not contribute to the sum, then
            # divide by the number of real frames (masked mean).
            hidden_states[~padding_mask] = 0.0
            pooling = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)

        hidden_states = self.feature_extractor(pooling)
        hidden_states = torch.relu(hidden_states)
        logits = self.classifier(hidden_states)
        loss = None
        if labels is not None:
            loss = self.loss_fn(logits.view(-1, self.config.num_labels), labels.view(-1))
        return loss, logits

In [None]:
class ClassTransformer(LightningModule):
    """Lightning module that fine-tunes ClassifierModel with manual mixed-precision steps."""

    def __init__(
        self,
        model_name_or_path,
        num_labels,
        class_weights,
        learning_rate = 1e-4,
        warmup_steps = 0,
        weight_decay = 0.0005
    ):
        """Args:
            model_name_or_path (str): Pretrained checkpoint to load.
            num_labels (int): Number of output classes.
            class_weights: Per-class weights forwarded to the model's loss.
            learning_rate (float, optional): Base learning rate of the optimizer.
            warmup_steps (int, optional): Warm-up steps of the scheduler.
            weight_decay (float, optional): Weight decay for every parameter
                except biases and LayerNorm weights.
        """
        super().__init__()
        self.save_hyperparameters()
        self.num_labels = num_labels
        config = AutoConfig.from_pretrained(
            model_name_or_path,
            num_labels=num_labels,
            attention_dropout=0.01,
            hidden_dropout=0.05,
            layerdrop=0.01,
            gradient_checkpointing=True,
            use_weighted_layer_sum=True
        )
        self.model = ClassifierModel.from_pretrained(
            model_name_or_path,
            config=config,
            class_weights=class_weights
        )
        self.model.freeze_feature_extractor()

        # Per-step histories used to log running means over the whole run.
        self.train_cum_acc = []
        self.train_cum_loss = []

        # disable automatic lightning optimization (necessary for customizing train step)
        self.automatic_optimization = False
        # used for mixed precision training in order to reduce gradient underflow
        self.scaler = torch.cuda.amp.GradScaler()

        # Parameters whose name contains one of these fragments get no weight decay.
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": weight_decay,
            },
            {
                "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        self.optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)

    def forward(self, **inputs):
        """Run the wrapped classifier; returns its ``(loss, logits)`` tuple."""

        return self.model(**inputs)

    def training_step(self, batch, batch_idx):
        """Manual-optimization training step with loss scaling and gradient clipping.

        A falsy batch (the collate function returns None when nothing could be
        loaded) is skipped and the method returns None.

        Returns:
            The detached, unscaled loss of this batch, or None for a skipped batch.
        """

        if batch:
            opt = self.optimizers()

            # resets gradients before parameter updates
            opt.zero_grad()

            # apply mixed precision training
            with torch.cuda.amp.autocast(enabled=True):
                outputs = self(**batch)

            loss, logits = outputs

            self.train_cum_loss.append(loss.item())
            cum_loss = sum(self.train_cum_loss)/len(self.train_cum_loss)

            self.log("train_loss", cum_loss, on_step=True, on_epoch=False, prog_bar=True)

            # Scale the loss to protect small gradients from underflow. Keep
            # the scaled value in its own name so `loss` stays the real loss.
            scaled_loss = self.scaler.scale(loss)
            self.manual_backward(scaled_loss)

            # Unscales the gradients of optimizer's assigned params in-place
            # Calling before clipping enables you to clip unscaled gradients as usual
            self.scaler.unscale_(opt)

            # gradient clipping manipulates a set of gradients such that their global norm is <= threshold value
            self.clip_gradients(opt, gradient_clip_val=1.0, gradient_clip_algorithm="norm")

            # Runs the optimizer step through the scaler (gradients were
            # already unscaled above).
            self.scaler.step(opt)

            # stepwise scheduler call
            sch = self.lr_schedulers()
            sch.step()

            # Updates the scale for next iteration.
            self.scaler.update()

            targets = batch['labels'].to('cpu')
            # Metrics need values only: detach so the autograd graph is not
            # carried over to the CPU copy.
            logits = logits.detach().to('cpu')

            metrics = compute_metrics(logits, targets, self.num_labels)

            self.train_cum_acc.append(metrics["accuracy"])
            cum_acc = sum(self.train_cum_acc)/len(self.train_cum_acc)

            self.log("train_acc", cum_acc, on_step=True, on_epoch=False, prog_bar=True)

            # Return the true loss; previously the scaler-multiplied value was
            # returned because `loss` had been overwritten.
            return loss.detach()

    def validation_step(self, batch, batch_idx):
        """Log epoch-level 'val_loss' and 'val_acc' for one batch.

        A falsy batch is skipped and the method returns None.
        """

        if batch:
            outputs = self(**batch)
            loss, logits = outputs
            targets = batch['labels'].to('cpu')
            logits = logits.to('cpu')
            metrics = compute_metrics(logits, targets, self.num_labels)

            self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
            self.log("val_acc", metrics["accuracy"], on_step=False, on_epoch=True, prog_bar=True)
            return loss

    def configure_optimizers(self):
        """Return the optimizer with a per-step BiStageLRScheduler.

        The decay stage spans `trainer.estimated_stepping_batches` steps.
        """

        scheduler = BiStageLRScheduler(
            self.optimizer,
            warmup_updates=self.hparams.warmup_steps,
            decay_updates=self.trainer.estimated_stepping_batches,
        )
        scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
        return [self.optimizer], [scheduler]

## Train Model

Train step function that contains module calls and customized training aspects.

In [None]:
def run_train_class(
    data,
    labels,
    column,
    class_weights,
    model_name="microsoft/unispeech-large-1500h-cv",
    batch_size=4,
    max_steps=10000,
    checkpoint_dir="/content/checkpoints",
):
    """Configure and run training of one classifier.

    Args:
        data: DataFrame with the rows to train on.
        labels (list): Ordered class names.
        column (str): Name of the label column in `data`.
        class_weights: Per-class loss weights.
        model_name (str, optional): Pretrained checkpoint to fine-tune.
        batch_size (int, optional): Batch size for all data loaders.
        max_steps (int, optional): Upper bound on optimizer steps.
        checkpoint_dir (str, optional): Folder for the best-model checkpoint.

    The run validates every 500 training batches, keeps the single best
    checkpoint by 'val_acc' (weights only) and stops early after 6 validation
    checks without improvement.
    """

    seed_everything(123)
    # Named `checkpoint_callback` so it is not confused with the model name;
    # the original reused one variable for both the callback and the string.
    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_dir,
        monitor="val_acc",
        mode="max",
        save_top_k=1,
        save_weights_only=True,
        verbose=True,
    )
    early_stopping = EarlyStopping(
        monitor="val_acc",
        min_delta=0.00,
        patience=6,
        verbose=False,
        mode="max"
    )
    trainer = Trainer(
        max_steps=max_steps,
        callbacks=[checkpoint_callback, early_stopping],
        accelerator="auto",
        val_check_interval=500,
        log_every_n_steps=500,
        check_val_every_n_epoch=None
    )
    data_module = ClassDataModule(
        data=data,
        labels=labels,
        sample_rate=SAMPLE_RATE,
        out_column=column,
        batch_size=batch_size
    )
    model = ClassTransformer(
        model_name_or_path=model_name,
        num_labels=len(labels),
        class_weights=class_weights
    )
    trainer.fit(model, data_module)

In [None]:
def get_class_weights(df, column):
    """Compute 'balanced' class weights for the labels in `df[column]`.

    Args:
        df: DataFrame holding the label column.
        column (str): Name of the label column.

    Returns:
        torch.FloatTensor: One weight per distinct label, in the order given
        by `np.unique` (sorted), placed on the GPU when one is available and
        on the CPU otherwise.
    """

    # No pre-sorting needed: np.unique already returns the classes sorted and
    # the weights are emitted in the order of `classes`.
    values = df[column].tolist()
    weights = compute_class_weight(class_weight="balanced", classes=np.unique(values), y=values)
    # An unconditional .cuda() raises on CPU-only runtimes.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    return torch.FloatTensor(weights).to(device)

In [None]:
# Train the accent classifier: the accent subset, its sorted label list, and
# loss weights computed from that same subset.
data=ACCENT_DATA
labels=ACCENT_LABELS
column = 'accent'
class_weights = get_class_weights(data, column)
run_train_class(data=data, labels=labels, column=column, class_weights=class_weights)