(stroke-classification)=
# Stroke classification

Introduction here!

## Mridangam stroke classification

{cite}`anantapadmanabhan_mridangam_2013, mridangam_stroke`

In [None]:
#%pip install compiam

# Import compiam into the project
import compiam

# Import extras and suppress warnings to keep the tutorial clean
import random
from pprint import pprint
import warnings
warnings.filterwarnings('ignore')

In [None]:
from compiam.timbre.stroke_classification import MridangamStrokeClassification
mridangam_stroke_classification = MridangamStrokeClassification()


Let's start by loading the mridangam stroke dataset. Since `MridangamStrokeClassification` is based on the Mridangam Stroke Dataset, `compiam` includes a specific function to load the dataset and integrate it into the pipeline.

In [None]:
mridangam_stroke_classification.load_mridangam_dataset(
    data_home="../workspace/", download=True)

```{note}
This function does not return a dataloader. Instead, the dataloader lives within the tool class. We will see how this works in the following steps of this walkthrough.
```

In [None]:
# Print the list of mridangam strokes available in the dataset
mridangam_stroke_classification.list_strokes()

Let's train and evaluate a very basic model to classify mridangam strokes. We first use a utility function in `compiam` to hold out a portion of the mridangam dataset, which we will use for evaluation.

In [None]:
from compiam.utils.datasets import split_mirdata_tracks

train_split, evaluation_split = split_mirdata_tracks(mridangam_stroke_classification.dataset, split=0.1)

# Let's print out a random track from the created evaluation split
random.choice(list(evaluation_split.items()))
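Judging by how they are used in this walkthrough, both splits behave like dictionaries keyed by track ID. As a rough idea of what such a hold-out split involves, the sketch below shuffles the track IDs with a fixed seed and assigns a fraction of them to evaluation. The helper `split_tracks` and its `seed` parameter are hypothetical and not part of `compiam`; this is not a claim about how `split_mirdata_tracks` is implemented.

```python
import random
from typing import Dict, Tuple, TypeVar

T = TypeVar("T")


def split_tracks(tracks: Dict[str, T], split: float = 0.1,
                 seed: int = 0) -> Tuple[Dict[str, T], Dict[str, T]]:
    """Hypothetical helper: randomly hold out a fraction `split` of tracks for evaluation."""
    if not 0.0 < split < 1.0:
        raise ValueError("split must be strictly between 0 and 1")
    # Sort first so the result depends only on the seed, not on dict insertion order
    track_ids = sorted(tracks)
    random.Random(seed).shuffle(track_ids)
    n_eval = max(1, round(len(track_ids) * split))
    eval_ids = set(track_ids[:n_eval])
    train = {k: v for k, v in tracks.items() if k not in eval_ids}
    evaluation = {k: v for k, v in tracks.items() if k in eval_ids}
    return train, evaluation


# Toy usage: 20 placeholder tracks, 10% held out
tracks = {f"track_{i}": f"audio_{i}.wav" for i in range(20)}
train, evaluation = split_tracks(tracks, split=0.1)
print(len(train), len(evaluation))  # 18 2
```

Using a seeded `random.Random` instance keeps the split reproducible without touching the global random state used elsewhere in the notebook.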

Our class assumes that the entire loaded dataset is used for training, so we need to replace the dataset in the class with the training split.

In [None]:
mridangam_stroke_classification.mridangam_tracks = train_split
mridangam_stroke_classification.mridangam_ids = list(train_split.keys())

**Let's now train the model!** We will train a Support Vector Machine (SVM) model using `scikit-learn`. The mridangam stroke classification tool in `compiam` uses the [MusicExtractor in Essentia](https://essentia.upf.edu/streaming_extractor_music.html) to compute low-level features from the stroke recordings, which are then fed to the model.
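To make the training step less of a black box, here is a minimal `scikit-learn` sketch of the same idea: one feature vector per recording, an SVM classifier, and accuracy measured on a held-out portion. The features are synthetic stand-ins produced by `make_classification`, not Essentia descriptors. Standardising the features and using an RBF kernel are our assumptions; we do not know which preprocessing or hyperparameters `train_model()` uses internally.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

# Synthetic stand-in for per-stroke feature vectors (NOT real Essentia features):
# 200 "recordings", 20 features each, 4 "stroke" classes.
X, y = make_classification(n_samples=200, n_features=20, n_informative=8,
                           n_classes=4, random_state=0)

# Hold out 10% for evaluation, keeping class proportions similar in both parts
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.1, random_state=0, stratify=y)

# Standardise each feature, then fit an RBF-kernel SVM
model = make_pipeline(StandardScaler(), SVC(kernel="rbf", C=1.0))
model.fit(X_train, y_train)

# Mean accuracy on the held-out split
accuracy = model.score(X_test, y_test)
print(f"Held-out accuracy: {accuracy:.2f}")
```

Scaling matters for SVMs because the RBF kernel depends on distances between feature vectors, so features with large numeric ranges would otherwise dominate.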

```{note}
You can also train a different model and compare the performance. We offer other options (see [the documentation of the tool](https://mtg.github.io/compIAM/source/timbre.html#mridangam-stroke-classification)), but feel free to open a Pull Request in `compiam` to add more models to the available options.
```


In [None]:
svm_accuracy = mridangam_stroke_classification.train_model()

**The model has been trained. That is good!** The accuracy is also returned, in case we want to store it, retrain the model using different settings, and compare.
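Since an accuracy is returned, comparing settings comes down to training several times and keeping the scores. As a self-contained illustration in plain `scikit-learn` (not through `compiam`, whose training options are listed in its documentation), the sketch below compares a few SVM kernels and values of the regularisation parameter `C` using 5-fold cross-validation on synthetic features. The grid of settings is arbitrary.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

# Synthetic stand-in features again (NOT real stroke features)
X, y = make_classification(n_samples=200, n_features=20, n_informative=8,
                           n_classes=4, random_state=0)

# Candidate settings to compare: kernel type and regularisation strength C
settings = [("rbf", 0.1), ("rbf", 1.0), ("rbf", 10.0), ("linear", 1.0)]

scores = {}
for kernel, c in settings:
    model = make_pipeline(StandardScaler(), SVC(kernel=kernel, C=c))
    # Mean accuracy over 5 folds (stratified by default for classifiers)
    scores[(kernel, c)] = cross_val_score(model, X, y, cv=5).mean()

best = max(scores, key=scores.get)
for (kernel, c), score in scores.items():
    print(f"kernel={kernel:<6} C={c:<5} accuracy={score:.2f}")
print("Best setting:", best)
```

Averaging over several folds gives a steadier comparison than a single train/evaluation split, where a lucky or unlucky split can decide the ranking.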

Now we can predict the strokes for a list of recordings. First, we need to get the list of audio paths for the `mirdata` evaluation split we generated a few steps earlier.

In [None]:
# Get audio paths from the evaluation split created earlier
eval_paths = [track.audio_path for track in evaluation_split.values()]

# Compute prediction from list of paths
prediction = mridangam_stroke_classification.predict(eval_paths)

In [None]:
# Inspect three random predictions from the model output
for item in random.sample(list(prediction.items()), 3):
    pprint(item)
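Looking at a few random predictions is a quick sanity check, but a fuller evaluation compares every prediction with its reference label. The sketch below assumes that `prediction` maps each audio path to a predicted stroke name (as the printed items suggest) and that you build a matching reference dictionary yourself, for instance from each `mirdata` track's `stroke_name` attribute (an assumption to verify against the dataset loader). The helper `evaluate_predictions` is hypothetical, and the labels in the usage lines are toy values.

```python
from collections import Counter
from typing import Dict, Tuple


def evaluate_predictions(predicted: Dict[str, str],
                         reference: Dict[str, str]) -> Tuple[float, Counter]:
    """Hypothetical helper: overall accuracy plus (true, predicted) confusion counts.

    Both dicts map an audio path to a stroke label; only paths present in both are scored.
    """
    common = [path for path in predicted if path in reference]
    if not common:
        raise ValueError("no overlapping paths to evaluate")
    # Count each (reference label, predicted label) pair
    confusion = Counter((reference[p], predicted[p]) for p in common)
    correct = sum(n for (truth, guess), n in confusion.items() if truth == guess)
    return correct / len(common), confusion


# Toy usage with made-up labels
predicted = {"a.wav": "ta", "b.wav": "thi", "c.wav": "ta", "d.wav": "num"}
reference = {"a.wav": "ta", "b.wav": "thi", "c.wav": "dheem", "d.wav": "num"}
accuracy, confusion = evaluate_predictions(predicted, reference)
print(accuracy)                     # 0.75
print(confusion[("dheem", "ta")])   # 1
```

With the walkthrough's variables this could look like `evaluate_predictions(prediction, {t.audio_path: t.stroke_name for t in evaluation_split.values()})`, provided the assumptions above hold. The confusion counts reveal which strokes get mistaken for each other, something a single accuracy figure hides.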