# Multi-Layer Perceptron Binary Classifier using PyTorch

In [None]:
from math import isclose
from pathlib import Path
from secrets import randbits
from warnings import filterwarnings

import matplotlib.pyplot as plt
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.utilities.model_summary import summarize
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning

from shipsnet.data import ShipsDataModule
from shipsnet.models import MLPClassifier
from shipsnet.viz import array_to_rgb_image

# Load the jupyter_black notebook extension (formats code cells with black).
%load_ext jupyter_black

# Hide every warning of Lightning's PossibleUserWarning category so cell
# output stays readable; other warning categories are unaffected.
filterwarnings("ignore", category=PossibleUserWarning)

# Sanity checks

## Check the datamodule loads the images correctly

In [None]:
# Run the datamodule's data hooks by hand so a batch can be drawn without
# going through a Trainer.
datamodule = ShipsDataModule()
datamodule.prepare_data()
datamodule.setup()

first_batch = next(iter(datamodule.train_dataloader()))
inputs, labels = first_batch

fig, axes = plt.subplots(3, 4)

# One image per subplot; zip stops as soon as either the 12 axes or the
# batch runs out.
# NOTE(review): the +0.5 offset presumably undoes a centring applied by the
# datamodule before conversion to RGB -- confirm against ShipsDataModule.
shifted = inputs + 0.5
for ax, image in zip(axes.flatten(), shifted):
    ax.imshow(array_to_rgb_image(image))
    ax.set_axis_off()

fig.tight_layout()
plt.show()

## Check reproducibility

In [None]:
def train_and_eval():
    """Fit a small MLP for a few epochs and return its validation metrics.

    A fresh ``ShipsDataModule`` and a one-hidden-layer (10 units, ReLU)
    ``MLPClassifier`` are built on every call. Logging, checkpointing, the
    model summary and the progress bar are all switched off to keep the run
    quick and quiet.

    Returns:
        The single metrics dict produced by ``Trainer.validate``.
    """
    dm = ShipsDataModule()
    net = MLPClassifier([10], "relu")
    quiet_options = dict(
        logger=False,
        enable_checkpointing=False,
        enable_model_summary=False,
        enable_progress_bar=False,
    )
    trainer = Trainer(max_epochs=5, **quiet_options)
    trainer.fit(net, dm)
    # validate() returns one dict per validation dataloader; exactly one is expected.
    [val_metrics] = trainer.validate(net, dm, verbose=False)
    return val_metrics


# Seed once and keep the seed so the exact same run can be replayed below.
seed = seed_everything()

metrics_1 = train_and_eval()

# Reset the seed and retrain - should get same results
seed_everything(seed)
metrics_2 = train_and_eval()
assert all(isclose(metrics_1[k], metrics_2[k]) for k in metrics_1), (
    f"Same seed produced different metrics: {metrics_1} vs {metrics_2}"
)

# Don't reset the seed - should get different results
metrics_3 = train_and_eval()
assert not all(isclose(metrics_1[k], metrics_3[k]) for k in metrics_1), (
    f"Run without reseeding reproduced the seeded metrics: {metrics_1}"
)

print("Reproducibility check passed!")

# Train an ensemble of classifiers

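Use the two cells below to train an ensemble of MLP classifiers with different hidden shapes and activation functions: edit the `hidden_shape` and `activation` arguments in the first cell, then re-run both cells once for each ensemble member.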

In [None]:
# The fixed random_split_seed keeps the train/val split identical for every
# ensemble member, so their validation metrics are comparable.
datamodule = ShipsDataModule(
    batch_size=32,
    train_frac=0.75,
    random_split_seed=12345,
)

# Draw a fresh seed for each ensemble member. A bare seed_everything() is not
# enough: once any call has run in this kernel it stores the seed in the
# PL_GLOBAL_SEED environment variable, and later argument-less calls reuse it,
# so every re-run of this cell would train with the same seed. `secrets` is
# used rather than `random` because seed_everything has already seeded the
# `random` module deterministically.
seed = seed_everything(randbits(32))

model = MLPClassifier(
    hidden_shape=[20],
    activation="relu",
)

# Store the seed and model class with the model's hyperparameters so each
# run records how it was produced.
model.save_hyperparameters({"seed": seed, "class": model.__class__.__name__})
summarize(model, max_depth=2)

In [None]:
# Stop after 5 epochs without val/loss improvement, and keep the checkpoint
# chosen by val/loss (files named after the epoch number).
early_stopping = EarlyStopping(monitor="val/loss", patience=5, verbose=True)
checkpoints = ModelCheckpoint(monitor="val/loss", filename="{epoch:d}")

tb_logger = TensorBoardLogger(".", default_hp_metric=False)
trainer = Trainer(
    logger=tb_logger,
    callbacks=[early_stopping, checkpoints],
    enable_model_summary=False,
)
trainer.fit(model, datamodule)

# Record the absolute path of the best checkpoint as TensorBoard text, so the
# file to load for this run is easy to find later.
best_ckpt_path = Path(checkpoints.best_model_path).resolve()
model.logger.experiment.add_text("checkpoint_path", str(best_ckpt_path))

# Evaluate the best model

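Use TensorBoard to compare the trained models. Once you have found the best one, copy its checkpoint path (logged as the `checkpoint_path` text entry) into the cell below, then load that model and evaluate it on the test set.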

In [None]:
# Launch TensorBoard inline in the notebook, reading the lightning_logs
# directory.
# NOTE(review): this presumes the training cell's TensorBoardLogger(".") keeps
# its default experiment name ("lightning_logs") -- update --logdir if it changes.
%load_ext tensorboard
%tensorboard --logdir lightning_logs

In [None]:
# Set this to the checkpoint path of the best model (for example the
# `checkpoint_path` text that the training cell logs to TensorBoard).
best_model_path = None

# With no explicit path, fall back to the best checkpoint (by val/loss) of the
# model trained most recently in this kernel -- not necessarily the best model
# across all runs.
checkpoint_to_load = best_model_path or checkpoints.best_model_path
if not checkpoint_to_load:
    raise ValueError(
        "No checkpoint to load: set best_model_path, or train a model that "
        "saved a checkpoint first."
    )
best_model = MLPClassifier.load_from_checkpoint(checkpoint_to_load)

# Create a dummy trainer just to evaluate the model
(test_metrics,) = Trainer(logger=False).test(best_model, datamodule)