In [5]:
from mysklearn.mypytable import MyPyTable
from mysklearn.myclassifiers import MyDecisionTreeClassifier
from mysklearn.myutils import cross_val_predict, print_confusion_matrix
from collections import Counter
import random

# Load dataset
# Reads "cleaned_tracks.csv" (a relative path, so it is resolved against the
# kernel's working directory) through MyPyTable. Later cells use
# `table.column_names` and `table.data`, so load_from_file presumably returns
# the populated table object -- verify against mysklearn.mypytable.
table = MyPyTable().load_from_file("cleaned_tracks.csv")

# Categorize popularity into Low/Medium/High
def categorize_popularity(p, low_max=33, medium_max=66):
    """Bin a popularity score into "Low", "Medium" or "High".

    Args:
        p: Popularity score as an int, a float, or a numeric string
            (e.g. 45, 45.7, "45", "45.0"). Fractional values are truncated
            toward zero before binning, exactly as int() does for floats.
        low_max: Largest (truncated) score still labelled "Low".
        medium_max: Largest (truncated) score still labelled "Medium".

    Returns:
        "Low" when score <= low_max, "Medium" when score <= medium_max,
        otherwise "High".

    Raises:
        ValueError: if p is a string that is not numeric (or is "nan").
        OverflowError: if p is infinite.
        TypeError: if p is not a number or string (e.g. None).
    """
    try:
        score = int(p)
    except ValueError:
        # int() rejects float-formatted strings such as "45.0", which show up
        # when a CSV column was written as floats; parse via float() and
        # truncate so the result matches int() applied to the float itself.
        score = int(float(p))

    if score <= low_max:
        return "Low"
    if score <= medium_max:
        return "Medium"
    return "High"

# Target vector: one Low/Medium/High label per table row, derived from the
# "popularity" column.
pop_idx = table.column_names.index("popularity")
y = list(map(categorize_popularity, (record[pop_idx] for record in table.data)))

# Feature columns, in the order they appear in each X row.
# NOTE(review): all but "explicit" look like continuous measurements and are
# passed to the classifier unbinned -- confirm MyDecisionTreeClassifier is
# meant to receive raw values like these.
feature_names = (
    "explicit",
    "duration_ms",
    "danceability",
    "energy",
    "tempo",
    "loudness",
)
# Positions are looked up by header name; .index() raises ValueError if a
# column is missing from the table.
feature_idxs = [table.column_names.index(name) for name in feature_names]
(
    explicit_idx,
    duration_idx,
    dance_idx,
    energy_idx,
    tempo_idx,
    loudness_idx,
) = feature_idxs

# Build X matrix: one six-element list per table row.
X = [[record[idx] for idx in feature_idxs] for record in table.data]

In [6]:
# Shuffle X and y together
# Seed the `random` module first so the shuffle order -- and therefore the
# fold membership and the metrics printed below -- is the same on every
# Restart & Run All. If cross_val_predict also draws from `random` (not
# visible here), this seed makes that reproducible as well.
SHUFFLE_SEED = 42
random.seed(SHUFFLE_SEED)

combined = list(zip(X, y))
random.shuffle(combined)
# Unzip with comprehensions rather than zip(*combined): same lists for
# non-empty data, but does not fail to unpack when the table is empty.
X = [features for features, _ in combined]
y = [label for _, label in combined]

# --- 10-FOLD CV using Decision Tree ---
# cross_val_predict is given the classifier class (not an instance) and
# returns the overall accuracy and error rate plus the true and predicted
# labels that the confusion-matrix cell consumes.
accuracy, error_rate, true_labels, predictions = cross_val_predict(
    X, y, MyDecisionTreeClassifier, k=10
)

print(f"10-fold CV Accuracy: {accuracy:.2f}")
print(f"10-fold CV Error Rate: {error_rate:.2f}")

10-fold CV Accuracy: 0.39
10-fold CV Error Rate: 0.61


In [7]:
# --- Confusion Matrix ---
# Rows are the true class, columns the predicted class, both in `labels` order.
labels = ["Low", "Medium", "High"]
label_to_idx = dict(zip(labels, range(len(labels))))
matrix = [[0] * len(labels) for _ in labels]

# Tally each (true, predicted) pair; a label outside `labels` raises KeyError.
for actual, predicted in zip(true_labels, predictions):
    matrix[label_to_idx[actual]][label_to_idx[predicted]] += 1

# NOTE(review): in the captured output the header row shows the whole label
# list as a single cell -- confirm the argument order/format that
# print_confusion_matrix expects.
print_confusion_matrix(labels, matrix, "Decision Tree Confusion Matrix")

Decision Tree Confusion Matrix
                                   ['Low', 'Medium', 'High']
------  ----  ----  ----  -----  ---------------------------
Low     9144  1076  1147  11367                           80
Medium  8172  1400  1795  11367                           12
High    6995  1588  2784  11367                           24
