In [9]:
import numpy as np
import pandas as pd
import os
from toolkit_dslr.logistic_regression import LogisticRegressionScratch
from sklearn.preprocessing import StandardScaler
from sklearn.impute import SimpleImputer
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

In [10]:
def _plot_cost_history(cost_history, house: str) -> None:
    """Plot one one-vs-rest model's per-iteration cost curve, titled with its house."""
    _, ax = plt.subplots()
    ax.plot(cost_history)
    ax.set_title(f"Cost Function Convergence ({house})")
    ax.set_xlabel("Iterations")
    ax.set_ylabel("Cost")
    ax.grid(True)
    plt.show()


def LogisticRegression(
    file: str,
    *,
    learning_rate: float = 0.1,
    iterations: int = 1000,
    test_size: float = 0.2,
    random_state: int = 42,
    show_plots: bool = True,
):
    """Train one-vs-rest logistic regression models per Hogwarts house and report accuracy.

    The CSV at ``file`` is split into train/test sets first; missing values are
    then mean-imputed and features standardized, with both the imputer and the
    scaler fitted on the training split only (no test-set leakage). One binary
    ``LogisticRegressionScratch`` model is trained per house, and each test row
    is assigned the house whose model gives the highest ``predict`` output.

    Parameters
    ----------
    file : str
        Path to a CSV with a ``Hogwarts House`` label column; the identity /
        metadata columns listed below are dropped, the rest are used as features.
    learning_rate : float
        Passed to ``LogisticRegressionScratch``.
    iterations : int
        Passed to ``LogisticRegressionScratch``.
    test_size : float
        Passed to ``train_test_split``.
    random_state : int
        Passed to ``train_test_split`` for a reproducible split.
    show_plots : bool
        When True, plot each model's ``cost_history``.

    Returns
    -------
    float
        Fraction of test rows whose predicted house equals the true house
        (also printed).
    """
    df = pd.read_csv(file)
    houses = ["Gryffindor", "Slytherin", "Ravenclaw", "Hufflepuff"]

    # Identity/metadata columns plus two courses excluded from the feature set.
    features = df.drop(['Index', 'Hogwarts House', 'First Name', 'Last Name',
                        'Birthday', 'Best Hand', 'Arithmancy',
                        'Care of Magical Creatures'], axis=1)
    labels = df["Hogwarts House"].values

    # Split BEFORE imputing/scaling so test-set statistics never leak into preprocessing.
    X_train_raw, X_test_raw, y_train_global, y_test_global = train_test_split(
        features, labels, test_size=test_size, random_state=random_state
    )

    # Fit preprocessing once on the training split; every house model shares it.
    imputer = SimpleImputer(strategy="mean")
    scaler = StandardScaler()
    X_train = scaler.fit_transform(imputer.fit_transform(X_train_raw))
    X_test = scaler.transform(imputer.transform(X_test_raw))

    models = {}
    for house in houses:
        # One-vs-rest target: 1 for this house, 0 for every other house.
        y_train = (y_train_global == house).astype(int)

        model = LogisticRegressionScratch(learning_rate=learning_rate, iterations=iterations)
        model.fit(X_train, y_train)
        models[house] = model

        if show_plots:
            _plot_cost_history(model.cost_history, house)

    # Score every test row with every house model -> array of shape (n_test, n_houses).
    # NOTE(review): assumes predict() accepts a 2-D batch and returns one value per row
    # (the original called it on single (1, n_features) rows and took element [0]).
    # If predict() returns hard 0/1 labels rather than probabilities, ties resolve to
    # the first house in `houses`; a probability output would give proper OvR decisions.
    scores = np.column_stack(
        [np.asarray(models[house].predict(X_test)).reshape(-1) for house in houses]
    )
    house_preds = np.asarray(houses)[np.argmax(scores, axis=1)]

    accuracy = float(np.mean(house_preds == y_test_global))
    print(f"Model Accuracy: {accuracy:.2f}")
    return accuracy

In [None]:
def main(dataset_file: str) -> None:
    """Train and evaluate the house classifier on ``dataset_file``.

    Prints a one-line error and returns when ``dataset_file`` is not a regular
    file. AssertionErrors raised during training are printed rather than
    propagated, as before.
    """
    # Explicit check instead of `assert`: asserts are stripped under `python -O`,
    # and isfile() also rejects directories, which exists() would accept.
    if not os.path.isfile(dataset_file):
        print(f"{FileNotFoundError.__name__}: The file does not exist: {dataset_file}")
        return
    try:
        LogisticRegression(dataset_file)
    except AssertionError as error:
        print(AssertionError.__name__ + ":", error)


DATASET_FILE = "datasets/dataset_train.csv"

# In a Jupyter kernel __name__ is "__main__", so this still runs in the notebook,
# but importing this code elsewhere no longer triggers training as a side effect.
if __name__ == "__main__":
    main(DATASET_FILE)