In [None]:
import itertools

import matplotlib.pyplot as plt
import torch

## Data

### Constants

In [None]:
N = 20  # number of coordinates in every sample
K = 11  # number of column indices drawn into the significant set A
# 1 / sqrt(N): magnitude given to each +/-1 coordinate of a sample.
FACTOR = torch.tensor(N).sqrt().reciprocal()

In [None]:
# Total number of samples generated for the dataset (both distributions combined).
NUM_SAMPLES = 2000

In [None]:
def generate_data_1(n: int = None, factor: float = None) -> torch.Tensor:
    """
    Returns one sample of data from distribution D_A^(1)

    Every coordinate is independently -factor or +factor with equal probability.

    Args:
        n (int, optional): Number of coordinates. Defaults to the notebook constant N.
        factor (float or torch.Tensor, optional): Magnitude of each coordinate.
            Defaults to FACTOR when n is omitted, otherwise to 1 / sqrt(n).

    Returns:
        torch.Tensor: Tensor of shape (n,) whose entries are exactly -factor or +factor.
    """
    # Resolve the factor first: its default depends on whether n was given explicitly.
    if factor is None:
        factor = FACTOR if n is None else n ** -0.5
    if n is None:
        n = N
    # Draw from {0, 1} and map to {-1, +1}. Taking torch.sign of a continuous
    # uniform draw yields 0 whenever the draw lands exactly on the midpoint,
    # which would produce a coordinate that is neither -factor nor +factor.
    signs = 2 * torch.randint(0, 2, (n,)) - 1
    return signs * factor

In [None]:
def generate_data_2(imp_cols: torch.Tensor, n: int = None, factor: float = None) -> torch.Tensor:
    """
    Returns one sample of data from distribution D_A^(2)

    All coordinates are first drawn independently as -factor or +factor; the
    coordinates listed in imp_cols are then overwritten with one shared random
    sign, so they always agree with each other.

    Args:
        imp_cols (torch.Tensor): Tensor of columns which are significant in the distribution.
            Any integer sequence (list, NumPy array, tensor) is accepted.
        n (int, optional): Number of coordinates. Defaults to the notebook constant N.
        factor (float or torch.Tensor, optional): Magnitude of each coordinate.
            Defaults to FACTOR when n is omitted, otherwise to 1 / sqrt(n).

    Returns:
        torch.Tensor: Tensor of shape (n,) whose entries are exactly -factor or +factor.
    """
    # Resolve the factor first: its default depends on whether n was given explicitly.
    if factor is None:
        factor = FACTOR if n is None else n ** -0.5
    if n is None:
        n = N
    # Signs come from {0, 1} mapped to {-1, +1}; torch.sign of a continuous
    # uniform draw can be 0, which would zero out a coordinate (or, for the
    # shared sign, every significant coordinate at once).
    x = (2 * torch.randint(0, 2, (n,)) - 1) * factor
    shared_sign = 2 * torch.randint(0, 2, (1,)) - 1
    # One indexed assignment instead of a Python loop over the columns.
    cols = torch.as_tensor(imp_cols, dtype=torch.long)
    x[cols] = shared_sign * factor
    return x

In [None]:
from numpy.random import choice

# Draw K distinct column indices out of the N available ones (sampling without replacement).
A = choice(range(N), size=K, replace=False)
A

In [None]:
def get_y_from_data(x: torch.Tensor, imp_cols: torch.Tensor) -> torch.Tensor:
    """
    Computes the label of a single data point from the signs of its significant columns.

    Args:
        x (torch.Tensor): Data tensor
        imp_cols (torch.Tensor): Significant columns from the data

    Returns:
        torch.Tensor: Scalar tensor equal to 1 when the product of the signs of
        x over imp_cols is positive, and 0 when that product is negative or zero.
    """
    parity = torch.tensor(1.)
    # Accumulate the product of signs in place, one significant column at a time.
    for idx in imp_cols:
        parity.mul_(torch.sign(x[idx]))
    # Negative and zero products both map to the label 0.
    return torch.tensor(0.) if parity <= 0 else parity

### Generating data

In [None]:
from sklearn.model_selection import train_test_split

In [None]:
# One uniform draw per sample, rounded to 0. or 1.; used below to pick the generator for each sample.
function_choices = torch.rand(NUM_SAMPLES).round()

In [None]:
# Sample i comes from generate_data_1 when its flag is 0 and from generate_data_2 otherwise;
# the samples become the rows of X.
X = torch.stack([
    generate_data_2(A) if function_choices[i] != 0 else generate_data_1()
    for i in range(NUM_SAMPLES)
])
X.shape

In [None]:
# Label every row of X and arrange the labels as a single column.
Y = torch.stack([get_y_from_data(row, A) for row in X]).reshape(-1, 1)
Y.shape

In [None]:
# Hold out 20% of the data for testing ...
X_training, X_test, Y_training, Y_test = train_test_split(
    X, Y, test_size=0.2
)
# ... then 25% of the remainder for validation (60/20/20 overall).
X_train, X_val, Y_train, Y_val = train_test_split(
    X_training, Y_training, test_size=0.25
)
tuple(part.shape for part in (X_train, Y_train, X_val, Y_val, X_test, Y_test))

In [None]:
from scripts.utils import make_dataloader

# Both loaders use the same batch size and shuffling settings.
train_dataloader = make_dataloader(X_train, Y_train, batch_size=32, shuffle=True)
val_dataloader = make_dataloader(X_val, Y_val, batch_size=32, shuffle=True)

## Models

### Neural network

In [None]:
# Device string handed to model.to(), train_model() and predict() in the cells below.
device = 'cpu'

In [None]:
from scripts.models import SimpleNN
from scripts.metrics import BinaryAccuracy
from scripts.train import train_model
from scripts.utils import EarlyStopping

In [None]:
# Hyperparameter grid explored by the model search.
depths = [1, 2, 3, 4, 5]   # values passed as hidden_layers
widths = [16, 32, 64]      # values passed as hidden_units
# Seven log-spaced weight decays from 1e-3 to 1e3, converted to plain Python
# floats so the optimizer receives numbers rather than 0-dim tensors.
weight_decays = torch.logspace(-3, 3, 7).tolist()
etas = [1e-4, 1e-3, 1e-2]  # learning rates


In [None]:
# Running best of the hyperparameter search: score plus the configuration that achieved it.
best_score = 0.0
best_depth = best_width = best_weight_decay = best_eta = None

In [None]:
EPOCHS = 50

# Every (depth, width, weight_decay, eta) combination, in the same order as four
# nested loops with eta varying fastest.
grid = list(itertools.product(depths, widths, weight_decays, etas))
total_count = len(grid)

# Reset the running best inside this cell so that re-running it never compares
# against a stale score left over from a previous execution.
best_score = float('-inf')
best_depth = best_width = best_weight_decay = best_eta = None

print(f'Cross-validating across {total_count} models.\n')

for count, (depth, width, weight_decay, eta) in enumerate(grid, start=1):
    # Grid entries may be 0-dim tensors; the optimizer expects a plain float.
    weight_decay = float(weight_decay)

    model = SimpleNN(input_size=N, hidden_layers=depth, hidden_units=width).to(device)
    loss_fn = torch.nn.BCELoss()
    optimizer = torch.optim.Adam(params=model.parameters(), lr=eta, weight_decay=weight_decay)
    metric = BinaryAccuracy()

    history = train_model(
        model=model,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        loss_fn=loss_fn,
        optimizer=optimizer,
        metric=metric,
        epochs=EPOCHS,  # keep in sync with the summary printed after the loop
        verbose=0,
        device=device
    )
    # Configurations are ranked by the validation score of the final epoch.
    curr_score = history['val_score'][-1]

    print(f'[{count}/{total_count}] depth={depth}, width={width}, lambda={weight_decay:.5f}, eta={eta} ===> validation score={curr_score:.6f}')
    if curr_score > best_score:
        best_score = curr_score
        best_depth = depth
        best_width = width
        best_weight_decay = weight_decay
        best_eta = eta

if best_depth is None:
    # Empty grid, or no score compared greater than -inf (e.g. all NaN).
    print('Validation complete, but no configuration produced a comparable validation score.')
else:
    print(f'Validation complete. Best validation score after {EPOCHS} epochs = {best_score:.6f}')
    print(f'Best configuration: depth={best_depth}, width={best_width}, lambda={best_weight_decay:.5f}, eta={best_eta}')

In [None]:
# New SimpleNN instance built with the depth and width selected by the search above.
best_model_nn = SimpleNN(input_size=N, hidden_layers=best_depth, hidden_units=best_width).to(device)

In [None]:
# Training setup for the selected configuration: binary cross-entropy loss and Adam
# with the selected learning rate and weight decay.
loss_fn = torch.nn.BCELoss()
optimizer = torch.optim.Adam(params=best_model_nn.parameters(), lr=best_eta, weight_decay=best_weight_decay)
# Project metric and early-stopping helper; their exact semantics live in scripts.metrics / scripts.utils.
metric = BinaryAccuracy()
early_stopper = EarlyStopping(patience=20, min_delta=1e-4)

In [None]:
# Train the selected model with epochs=500 and the early-stopping helper attached
# (presumably this ends training before 500 epochs when validation stalls -- confirm in scripts.train).
history = train_model(
    model=best_model_nn,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    loss_fn=loss_fn,
    optimizer=optimizer,
    metric=metric,
    epochs=500,
    early_stopping=early_stopper,
    device=device
)

In [None]:
from scripts.utils import plot_train_history

# Visualise the history returned by the final training run (see scripts.utils for what is plotted).
plot_train_history(history)

In [None]:
from scripts.test import predict

# Predictions of the final model on the training and validation inputs.
preds_train = predict(best_model_nn, X_train, device)
preds_val = predict(best_model_nn, X_val, device)

# Score both prediction sets against their labels with the metric defined earlier.
score_train = metric(preds_train, Y_train)
score_val = metric(preds_val, Y_val)
score_train, score_val