In [1]:
import os
import torch
from torch import nn

from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

import pandas as pd

from typing import List
import numpy as np

# 1. Data
# 2. Model
# 3. Prediction
# 4. Filter data 
# 5. Training loop
# 6. Metrics
# 7. Save the model

In [2]:
genre_mapping = {
    0: "Blues",
    1: "Classical",
    2: "Country",
    3: "Disco",
    4: "Hiphop",
    5: "Jazz",
    6: "Metal",
    7: "Pop",
    8: "Reggae",
    9: "Rock",
}

In [3]:
genre_mapping_inverse = {
    "Blues": 0,
    "Classical": 1,
    "Country": 2,
    "Disco": 3,
    "Hiphop": 4,
    "Jazz": 5,
    "Metal": 6,
    "Pop": 7,
    "Reggae": 8,
    "Rock": 9,
}

In [4]:
column_names = [
    "chroma_stft_mean",
    "chroma_stft_var",
    "rms_mean",
    "rms_var",
    "spectral_centroid_mean",
    "spectral_centroid_var",
    "spectral_bandwidth_mean",
    "spectral_bandwidth_var",
    "rolloff_mean",
    "rolloff_var",
    "zero_crossing_rate_mean",
    "zero_crossing_rate_var",
    "harmony_mean",
    "harmony_var",
    "tempo",
    "mfcc1_mean",
    "mfcc1_var",
    "mfcc2_mean",
    "mfcc2_var",
    "mfcc3_mean",
    "mfcc3_var",
    "mfcc4_mean",
    "mfcc4_var",
    "mfcc5_mean",
    "mfcc5_var",
    "mfcc6_mean",
    "mfcc6_var",
    "mfcc7_mean",
    "mfcc7_var",
    "mfcc8_mean",
    "mfcc8_var",
    "mfcc9_mean",
    "mfcc9_var",
    "mfcc10_mean",
    "mfcc10_var",
    "mfcc11_mean",
    "mfcc11_var",
    "mfcc12_mean",
    "mfcc12_var",
    "mfcc13_mean",
    "mfcc13_var",
    "mfcc14_mean",
    "mfcc14_var",
    "mfcc15_mean",
    "mfcc15_var",
    "mfcc16_mean",
    "mfcc16_var",
    "mfcc17_mean",
    "mfcc17_var",
    "mfcc18_mean",
    "mfcc18_var",
    "mfcc19_mean",
    "mfcc19_var",
    "mfcc20_mean",
    "mfcc20_var",
]

In [5]:
class MusicClassifier(nn.Module):
    def __init__(self, input_features, output_features):
        super().__init__()
        self.linear_layer_stack = nn.Sequential(
            nn.Linear(
                in_features=input_features, out_features=256, dtype=torch.float32
            ),
            nn.GELU(),
            nn.Dropout(p=0.2),
            nn.Linear(in_features=256, out_features=128, dtype=torch.float32),
            nn.GELU(),
            nn.Dropout(p=0.2),
            nn.Linear(
                in_features=128, out_features=output_features, dtype=torch.float32
            ),
        )

    def forward(self, x):
        return self.linear_layer_stack(x)

In [6]:
def training_loop(MusicClassifier: MusicClassifier):
    # TODO: Externalise this
    # Init le model
    torch.manual_seed(42)
    model = MusicClassifier(input_features=55, output_features=10)

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.011)

    def accuracy_fn(y_true, y_pred):
        correct = (
            torch.eq(input=y_true, other=y_pred).sum().item()
        )  # torch.eq() calculates where two tensors are equal
        acc = (correct / len(y_pred)) * 100  # Calcul simple de pourcentage
        return acc

    # Prepare data
    df = pd.read_csv("./csv/actual_dataset.csv")
    # df = pd.read_csv("/app/resources/original_dataset.csv")
    X = torch.from_numpy(df.drop(columns=["label"]).to_numpy()).type(torch.float32)
    y = torch.from_numpy(df["label"].to_numpy()).type(torch.long)

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )

    # Training loop
    torch.manual_seed(42)
    epochs = 125
    for epoch in range(epochs + 1):
        """
        Train
        """
        model.train()

        # 1. Forward pass
        y_logits = model(X_train)
        y_pred = torch.softmax(y_logits, dim=1).argmax(dim=1)

        # 2. Metrics
        loss = loss_fn(y_logits, y_train)
        acc = accuracy_fn(y_true=y_train, y_pred=y_pred)

        # 2.1 Save metrics
        # loss_history.append(loss.cpu().detach().numpy())
        # acc_history.append(acc)

        # 3. Zero Grad
        optimizer.zero_grad()

        # 4. Backpropagation
        loss.backward()

        # 5. Optimmizer step
        optimizer.step()

        """
        Test
        """
        model.eval()

        with torch.inference_mode():
            # 1. Forward pass
            y_test_logits = model(X_test)
            y_test_pred = torch.softmax(y_test_logits, dim=1).argmax(dim=1)

            # 2. Metrics
            test_loss = loss_fn(y_test_logits, y_test)
            test_acc = accuracy_fn(y_pred=y_test_pred, y_true=y_test)

            # 2.1 Save metrics
            # test_loss_history.append(test_loss.cpu().detach().numpy())
            # test_acc_history.append(test_acc)

        # Print out what's happening
        if epoch % 25 == 0:
            print(
                f"Epoch: {epoch} | Loss: {loss:.5f}, Acc: {acc:.2f}% | Test Loss: {test_loss:.5f}, Test Acc: {test_acc:.2f}%"
            )

    # if epoch == 125:
    cm = confusion_matrix(y_test, y_test_pred.numpy())
    ConfusionMatrixDisplay(cm).plot()
    # Save the model
    torch.save(obj=model.state_dict(), f="./actual_model_fast.pth")

In [7]:
def predict(
    model: MusicClassifier, df: pd.DataFrame, genre_mapping: dict[int, str], real_class: str
):
    # TODO: Rewrite
    model.eval()

    class_predictions = []
    raw_results = []
    total_rows, _ = df.shape
    for i in range(total_rows):
        # print(df.iloc[i].to_numpy().reshape(55,1))
        y_logits = model(torch.from_numpy(df.iloc[i].to_numpy().reshape(55,1).transpose()).type(torch.float32))

        y_softmax = torch.softmax(y_logits, dim=1)
        y_pred = y_softmax.argmax(dim=1)

        # print(genre_mapping[y_pred.detach().numpy()[0]])
        # print(list(torch.round(y_softmax * 1000) / 1000))

        raw_results.append(y_softmax.detach().numpy())
        class_predictions.append(genre_mapping[y_pred.detach().numpy()[0]])

    unique_values = set(class_predictions)
    actual_best = 0
    for elt in unique_values:
        if class_predictions.count(elt) > actual_best:
            actual_best = class_predictions.count(elt)
            prediction = elt
        # print(elt, class_predictions.count(elt))
    # print("Real class =>", real_class)
    # print("Results =>", genre_mapping_inverse[prediction])
    return prediction, raw_results

In [8]:
def filter_data(kept_df: pd.DataFrame, seuil: float, raw_results: list, dataframe: pd.DataFrame, real_class: int):
    for i in range(len(raw_results)):
            if raw_results[i][0][int(real_class)] > seuil:
                kept_df = pd.concat([kept_df, pd.DataFrame(dataframe.iloc[i].to_numpy().reshape(55,1).transpose(), columns=column_names)], axis=0)
                kept_df['label'] = real_class
    return kept_df

In [9]:
# 1. Data
seuil = 0.1 # ! Boucler en changeant le seuil
base_df = pd.read_csv('./resources/original_dataset.csv')
for genre in os.listdir('./extracted_drive_dataset/'):
    for extracted_features in os.listdir(f'./extracted_drive_dataset/{genre}/'):
            kept_df = pd.DataFrame(columns=column_names)
            dataframe = pd.read_csv(f"./extracted_drive_dataset/{genre}/{extracted_features}")
            # 2. Model
            my_model = MusicClassifier(input_features=55, output_features=10)
            my_model.load_state_dict(
                torch.load(
                    f="./resources/actual_model_fast.pth", map_location=torch.device("cpu")
                )
            )
            dataframe.drop(columns=['label'], inplace=True)
            
            # 3. Prediction
            result, raw_results = predict(my_model, dataframe, genre_mapping, real_class=genre)

            # 4. Filter new data
            kept_df = filter_data(kept_df, seuil, raw_results, dataframe, real_class=int(genre))
            
            # 4.1 Concat with original dataset
            base_df = pd.concat([base_df, kept_df], axis=0)

base_df


  kept_df = pd.concat([kept_df, pd.DataFrame(dataframe.iloc[i].to_numpy().reshape(55,1).transpose(), columns=column_names)], axis=0)
  kept_df = pd.concat([kept_df, pd.DataFrame(dataframe.iloc[i].to_numpy().reshape(55,1).transpose(), columns=column_names)], axis=0)
  kept_df = pd.concat([kept_df, pd.DataFrame(dataframe.iloc[i].to_numpy().reshape(55,1).transpose(), columns=column_names)], axis=0)
  kept_df = pd.concat([kept_df, pd.DataFrame(dataframe.iloc[i].to_numpy().reshape(55,1).transpose(), columns=column_names)], axis=0)
  kept_df = pd.concat([kept_df, pd.DataFrame(dataframe.iloc[i].to_numpy().reshape(55,1).transpose(), columns=column_names)], axis=0)
  kept_df = pd.concat([kept_df, pd.DataFrame(dataframe.iloc[i].to_numpy().reshape(55,1).transpose(), columns=column_names)], axis=0)
  kept_df = pd.concat([kept_df, pd.DataFrame(dataframe.iloc[i].to_numpy().reshape(55,1).transpose(), columns=column_names)], axis=0)
  kept_df = pd.concat([kept_df, pd.DataFrame(dataframe.iloc[i].to_num

Unnamed: 0,chroma_stft_mean,chroma_stft_var,rms_mean,rms_var,spectral_centroid_mean,spectral_centroid_var,spectral_bandwidth_mean,spectral_bandwidth_var,rolloff_mean,rolloff_var,...,mfcc16_var,mfcc17_mean,mfcc17_var,mfcc18_mean,mfcc18_var,mfcc19_mean,mfcc19_var,mfcc20_mean,mfcc20_var,label
0,-0.487808,0.640520,-0.006624,0.235568,-0.566828,-0.572791,-0.493983,-0.009229,-0.518590,-0.367952,...,-0.299108,0.168647,-0.425137,-0.003423,-0.376938,-0.499464,-0.513562,0.128414,-0.291781,0.0
1,-0.403142,0.131835,-0.264944,-0.342134,-0.508798,-0.749862,-0.425382,-0.519010,-0.424118,-0.642268,...,0.428544,-0.327031,-0.310040,-0.112125,-0.032083,-0.066593,1.011384,1.275780,0.056425,0.0
2,-0.361694,0.764491,0.016695,0.542195,-0.546245,-0.701852,-0.288365,-0.425734,-0.346190,-0.562723,...,0.503695,0.428053,-0.648762,0.316311,-0.177372,0.109337,-0.046244,0.653907,-0.521458,0.0
3,-0.175714,0.205477,0.024885,-0.063820,-0.723482,-0.700599,-0.517344,-0.348881,-0.607665,-0.474804,...,-0.065309,0.062981,-0.649076,0.092384,-0.464121,-0.211882,-0.099501,0.865880,-0.544744,0.0
4,-0.485895,0.337521,0.181345,-0.272072,-0.756246,-0.774827,-0.538557,-0.572962,-0.667537,-0.683920,...,-0.570609,0.856651,-0.167089,0.183265,-0.029476,-0.175130,-0.678995,0.276899,-0.606692,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
0,2.689548,-0.884458,-1.335109,-0.738694,-1.349777,0.351067,-1.229154,1.142247,-1.055350,0.646024,...,-1.054894,2.327190,-0.992745,-0.850448,-1.309762,0.833703,-1.063541,0.327168,-1.056012,4.0
0,3.724011,-6.172935,-1.381147,-0.753012,-1.209250,1.199594,-1.279223,1.802030,-1.066251,1.296038,...,-0.741933,2.175700,-1.093312,-0.633214,-1.106743,1.359090,-1.175902,0.266042,-1.088137,4.0
0,3.667702,-4.875680,-1.153692,-0.719737,-1.716599,0.134080,-1.654220,1.739555,-1.499429,0.780060,...,-0.980495,1.785358,0.018480,-0.546433,-0.457361,1.366457,-1.219533,0.743225,-0.867149,4.0
0,3.136811,-3.672270,-1.281875,-0.768764,-1.309017,0.228687,-1.195240,0.978636,-1.031411,0.562945,...,-0.781240,1.440615,0.745079,-0.779382,-0.898405,2.014491,-1.046684,1.079490,-0.598707,4.0
