In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

In [2]:
def subsample(X, y, subsample_size, num_sets, is_disjoint=True, rng=None):
    """Draw ``num_sets`` random subsamples of ``subsample_size`` rows from (X, y).

    Parameters
    ----------
    X : np.ndarray
        Feature matrix; rows are sampled along axis 0.
    y : np.ndarray
        Labels aligned with the rows of ``X`` (same length along axis 0).
    subsample_size : int
        Number of rows in each returned subsample.
    num_sets : int
        Number of subsamples to return (must be >= 1).
    is_disjoint : bool, default True
        If True, row indices are drawn without replacement, so no index is
        repeated within or across the returned sets. If False, indices are
        drawn with replacement and may repeat anywhere.
    rng : np.random.Generator or None, default None
        Source of randomness. ``None`` uses the legacy global ``np.random``
        state (so ``np.random.seed`` controls it, as before); pass a
        ``Generator`` for an explicitly seeded, isolated stream.

    Returns
    -------
    list of (np.ndarray, np.ndarray)
        ``num_sets`` pairs ``(X_i, y_i)``, each with ``subsample_size`` rows.

    Raises
    ------
    ValueError
        If ``X`` and ``y`` differ in length, ``num_sets < 1``, or
        ``is_disjoint`` is True and more rows are requested than ``X`` has.
    """
    n_rows = X.shape[0]
    if y.shape[0] != n_rows:
        raise ValueError(f"X has {n_rows} rows but y has {y.shape[0]} entries")
    if num_sets < 1:
        # np.split would otherwise fail with an opaque ZeroDivisionError.
        raise ValueError(f"num_sets must be at least 1, got {num_sets}")

    total_sample = subsample_size * num_sets
    if is_disjoint and total_sample > n_rows:
        raise ValueError(
            f"cannot draw {num_sets} disjoint sets of {subsample_size} rows "
            f"({total_sample} total) from only {n_rows} rows"
        )

    # Fall back to the module-level legacy RNG so existing np.random.seed()
    # calls keep producing the same draws.
    sampler = np.random if rng is None else rng
    indices = sampler.choice(n_rows, total_sample, replace=not is_disjoint)
    random_X = X[indices]
    random_y = y[indices]

    # total_sample is an exact multiple of num_sets, so the split is always even.
    X_sets = np.split(random_X, num_sets)
    y_sets = np.split(random_y, num_sets)
    return list(zip(X_sets, y_sets))

In [3]:
# Seed NumPy's global RNG (used by `subsample` through np.random.choice) so the
# subsamples below are identical on every Restart & Run All.
np.random.seed(0)

SUBSAMPLE_SIZE = 2500

df = pd.read_csv("data/adult.csv")
updated_df = pd.get_dummies(df)
# Drop exact duplicate rows so that distinct row indices also give distinct
# (features, label) rows -- the disjointness check below relies on this.
updated_df = updated_df.drop_duplicates()
# Cast to float so the matrix is numeric regardless of pandas version
# (pandas >= 2.0 emits bool dummy columns, which would make this dtype=object).
dataset = updated_df.to_numpy(dtype=np.float64)

# One-hot encoding `income` yields income_<=50K and income_>50K. The label is
# income_>50K and both columns are removed from the features. Selecting them by
# name rather than by position keeps this correct even if `income` is not the
# last column of the CSV.
INCOME_COLS = ["income_<=50K", "income_>50K"]
Y = updated_df["income_>50K"].to_numpy(dtype=np.float64)
X = updated_df.drop(columns=INCOME_COLS).to_numpy(dtype=np.float64)


def _n_distinct_rows(X_part, y_part):
    """Count distinct (feature row, label) pairs in aligned X/y arrays."""
    rows = np.column_stack((X_part, y_part))
    return len(set(map(tuple, rows)))


# --- Disjoint sampling: every drawn row must be distinct. ---
subsamples = subsample(X, Y, SUBSAMPLE_SIZE, 2, is_disjoint=True)
assert len(subsamples) == 2
(d1_X, d1_y), (d2_X, d2_y) = subsamples

assert d1_X.shape == (SUBSAMPLE_SIZE, X.shape[1])
assert d1_y.shape == (SUBSAMPLE_SIZE, )

# Uniqueness is only guaranteed for (features, label) pairs: drop_duplicates ran
# before the label was split off, so two rows may share features yet differ in
# label. Counting X rows alone could therefore fail spuriously.
total_samples = np.concatenate((d1_X, d2_X), axis=0)
unique_values = _n_distinct_rows(total_samples, np.concatenate((d1_y, d2_y), axis=0))
assert unique_values == SUBSAMPLE_SIZE * 2

# --- Sampling with replacement: 2 x 2500 draws from only 3000 rows must repeat. ---
subsamples = subsample(X[:3000], Y[:3000], SUBSAMPLE_SIZE, 2, is_disjoint=False)
(d1_X, d1_y), (d2_X, d2_y) = subsamples

assert d1_X.shape == (SUBSAMPLE_SIZE, X.shape[1])
assert d1_y.shape == (SUBSAMPLE_SIZE, )

total_samples = np.concatenate((d1_X, d2_X), axis=0)
unique_values = _n_distinct_rows(total_samples, np.concatenate((d1_y, d2_y), axis=0))
assert unique_values < SUBSAMPLE_SIZE * 2

In [4]:
def plot_losses(losses):
    """Plot a per-epoch cross-entropy loss curve and display it.

    Parameters
    ----------
    losses : sequence of float
        One loss value per epoch; plotted against epoch indices 0..len(losses)-1.
    """
    # A dedicated figure/axes avoids drawing onto whatever figure happens to be
    # current (e.g. one left open by an earlier cell).
    fig, ax = plt.subplots(figsize=(8.25, 4.5))
    ax.plot(np.arange(len(losses)), losses)
    ax.set_xlabel('Number of Epochs', fontsize=14)
    ax.set_ylabel('Cross Entropy Loss', fontsize=14)
    ax.set_title('Plot of Loss vs Epochs', fontsize=14)
    ax.tick_params(axis='both', labelsize=14)
    plt.show()

def plot_accuracies(accuracies):
    """Plot a per-epoch accuracy curve and display it.

    Parameters
    ----------
    accuracies : sequence of float
        One accuracy value per epoch; plotted against epoch indices
        0..len(accuracies)-1.
    """
    # A dedicated figure/axes avoids drawing onto whatever figure happens to be
    # current (e.g. one left open by an earlier cell).
    fig, ax = plt.subplots(figsize=(8.25, 4.5))
    ax.plot(np.arange(len(accuracies)), accuracies)
    ax.set_xlabel('Number of Epochs', fontsize=14)
    ax.set_ylabel('Accuracy', fontsize=14)
    ax.set_title('Plot of Accuracy vs Epochs', fontsize=14)
    ax.tick_params(axis='both', labelsize=14)
    plt.show()

# Non-disjoint check from the data-loading cell: `unique_values` counts distinct
# rows among 2 x 2500 draws *with replacement* from only 3000 rows, so repeats
# are guaranteed (5000 > 3000) and the count must fall below 5000.
assert unique_values < 2500 * 2, (
    f"expected repeated rows when sampling with replacement, "
    f"got {unique_values} distinct rows out of {2500 * 2}"
)