In [None]:
# Useful starting lines
%matplotlib inline
import numpy as np
import itertools
import matplotlib.pyplot as plt
import csv
import helpers as helpers
import implementations as impl

In [None]:
# Load the training set from 'train.csv'. helpers.load_data returns three
# values, unpacked here as the feature data, the labels and the sample ids.
train_data, train_labels, ids = helpers.load_data('train.csv')

In [None]:
# Run the raw features through the cleaning helper, then standardize the
# cleaned result.
new_data = helpers.standardize(
    helpers.clean_data(train_data)
)

# Recode the -1 labels as 0 (in place), so the targets used for training
# are stored as 0 wherever the file had -1.
negative_mask = (train_labels == -1)
train_labels[negative_mask] = 0

In [None]:
# Shuffle labels and features together through helpers.shuffle_data and
# rebind both names to the shuffled results.
# NOTE(review): whether the shuffle is seeded is not visible from this
# notebook -- confirm in helpers.shuffle_data if reproducible splits matter.
train_labels, new_data = helpers.shuffle_data(train_labels, new_data)

In [None]:
# Slice into training and validation sets. The helper returns, in this order:
# validation labels, training labels, validation features, training features.
# NOTE(review): 0.25 is presumably the fraction held out for validation --
# confirm against helpers.slice_data.
y_validation, y_train, tx_validation, tx_train = helpers.slice_data(train_labels, new_data, 0.25)

In [None]:
# Prepend a constant column of ones (the bias term) to both feature matrices.
# The cell rebinds tx_train and tx_validation to the widened arrays, so
# running it a second time would add a second column of ones.
ones_train = np.ones((y_train.shape[0], 1))
ones_validation = np.ones((y_validation.shape[0], 1))
tx_train = np.c_[ones_train, tx_train]
tx_validation = np.c_[ones_validation, tx_validation]

In [None]:
# Initialize the weights from a Gaussian distribution (mean 0, std 0.1), one
# weight per column of tx_train. The draw comes from a locally seeded
# generator so that the starting point is identical on every
# Restart & Run All instead of changing from run to run.
# NOTE(review): this only pins the weight initialisation; whether the earlier
# shuffle is deterministic depends on helpers.shuffle_data.
SEED = 42
rng = np.random.default_rng(SEED)
initial_w = rng.normal(0., 0.1, size=tx_train.shape[1])

# Train model. The hyper-parameters are named so they can be tuned in one
# place; the values are the ones previously passed inline.
MAX_ITERS = 1000
GAMMA = 0.01
trained_weights, train_loss = impl.logistic_regression(
    y_train, tx_train, initial_w, max_iters=MAX_ITERS, gamma=GAMMA
)

In [None]:
# Show the weight vector returned by the training step for a quick sanity check.
print(trained_weights)

In [None]:
# Hold-out evaluation: predict on the validation split and on the training
# split with the trained weights (a single split, not k-fold cross validation).
predict_validation = helpers.predict_logistic(tx_validation, trained_weights)
predict_train = helpers.predict_logistic(tx_train, trained_weights)

# The reference labels had their -1 entries recoded to 0 before training, so
# any -1 in the predictions is recoded the same way (in place) before scoring.
for predicted in (predict_validation, predict_train):
    predicted[predicted == -1] = 0

# Score each set of predictions against its own labels.
train_accuracy = helpers.accuracy(predict_train, y_train)
validation_accuracy = helpers.accuracy(predict_validation, y_validation)

print(
    f"train_accuracy = {train_accuracy}",
    f"validation_accuracy = {validation_accuracy}",
    sep="\n",
)

In [None]:
# Print the validation predictions followed by the validation labels so the
# two arrays can be compared by eye.
print(predict_validation)
print(y_validation)

In [None]:
# Load the test file; with train=False the loader's three return values are
# unpacked as features, labels and ids.
tx_test, y_test, ids_test = helpers.load_data("test.csv", train=False)

# Apply the same clean -> standardize helper chain that was applied to the
# training features.
# NOTE(review): standardize is called on the test features alone, so it
# presumably uses the test set's own statistics rather than the training
# set's -- confirm this is intended.
tx_test = helpers.standardize(
    helpers.clean_data(tx_test)
)

In [None]:
# Prepend the bias column of ones to the test features. The row count is taken
# from tx_test itself rather than from y_test, so the column always matches
# the matrix it is attached to, whatever the loader returns as labels for a
# file read with train=False.
# Note: the cell rebinds tx_test, so running it twice adds a second column.
tx_test = np.c_[np.ones((tx_test.shape[0], 1)), tx_test]

# Predict the test labels with the trained weights and display them.
predict_test = helpers.predict_logistic(tx_test, trained_weights)
print(predict_test)

In [None]:
# Write the test ids and their predictions to 'Predictions_Logistics.csv'
# through the submission helper.
helpers.create_csv_submission(ids_test, predict_test, 'Predictions_Logistics.csv')