In [None]:
from utils.general import get_logging_dir
import torch
import utils
import models
from tqdm.notebook import tqdm

from torch.utils.tensorboard import SummaryWriter
from torchinfo import summary

from utils.training import get_loss_fn, get_optimizer, train_one_epoch

from utils.general import get_logging_dir, make_values_scalar
import data.adult


In [2]:
# Reload edited project modules (utils, models, data) automatically before each
# cell executes, so changes to their .py files are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [3]:
# Global run configuration. The fixed seed makes torch's RNG-driven steps
# (weight init, dropout masks, shuffled training order) reproducible on
# Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

BATCH_SIZE = 128
DEVICE = torch.device("cpu")

In [None]:
# Adult train/test splits. Only the training loader reshuffles every epoch;
# the test loader iterates in a fixed order.
dataset = data.adult.get_dataset("train")
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

test_dataset = data.adult.get_dataset("test")
test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

In [5]:
# Training hyperparameters; these are also logged to TensorBoard via
# writer.add_hparams once training finishes.
hparams = dict(
    optimizer="adam",
    learning_rate=1e-3,
    loss_fn="bce_with_logits",
    batch_norm=False,
    layer_norm=True,
    dropout_rate=0.3,
    epochs=300,
)

# Layer widths handed to DenseNetwork.Config between the input size and the
# single output unit: three layers of 512, then 256 and 128.
architecture: list[int] = [512] * 3 + [256, 128]

# Number of input features per example fed to the network. It is hard-coded, so
# it must match the feature width produced by `data.adult.get_dataset`.
# NOTE(review): confirm this if the dataset preprocessing changes.
INPUT_DIM = 107
# One output unit: a single logit consumed by the "bce_with_logits" loss.
OUTPUT_DIM = 1

model_config = models.DenseNetwork.Config(
    INPUT_DIM,
    architecture,
    OUTPUT_DIM,
    torch.nn.ReLU,
    use_batch_norm=hparams["batch_norm"],
    use_layer_norm=hparams["layer_norm"],
    dropout_rate=hparams["dropout_rate"],
)
model = models.DenseNetwork(model_config)
model.to(DEVICE)

writer = SummaryWriter(get_logging_dir("basicNN", "adult"))
# Pass `device` explicitly. Without it, some torchinfo versions choose CUDA when
# it is available and move the model there, leaving it off DEVICE for training.
writer.add_text(
    "Model Summary",
    str(summary(model, input_size=(1, INPUT_DIM), device=DEVICE)),
)

# Example input used only to trace the graph; it must live on the model's device.
dummy_input = torch.randn(1, INPUT_DIM, device=DEVICE)
writer.add_graph(model, dummy_input)

optimizer = get_optimizer(
    hparams["optimizer"], hparams["learning_rate"], model.parameters()
)
loss_fn = get_loss_fn(hparams["loss_fn"])


# Train for hparams["epochs"] epochs and log train/test loss per epoch. Then log
# the final test metrics alongside the hyperparameters. The writer is closed in
# `finally`, so interrupting the kernel mid-run still flushes logged events.
try:
    for epoch in tqdm(range(hparams["epochs"])):
        # Set the mode explicitly each epoch: the model uses dropout, and the
        # evaluation step below switches it to eval mode.
        model.train()
        epoch_loss = train_one_epoch(model, dataloader, optimizer, loss_fn, DEVICE)
        writer.add_scalar("Loss/train", epoch_loss, epoch)

        model.eval()
        y_hats, ys = utils.evaluation.evaluate(
            model, test_dataloader, from_logits=True, return_outputs_only=True
        )
        # NOTE(review): loss_fn is "bce_with_logits", so this assumes `evaluate`
        # returns raw logits (not probabilities) when return_outputs_only=True.
        # Confirm this.
        with torch.no_grad():
            test_loss = loss_fn(y_hats.to(torch.float), ys.to(torch.float)).item()
        writer.add_scalar("Loss/test", test_loss, epoch)

    model.eval()
    metrics = utils.evaluation.evaluate(model, test_dataloader, from_logits=True)

    # run_name="." writes the hparams entry into this writer's own log dir
    # rather than a new timestamped sub-run.
    writer.add_hparams(
        {**make_values_scalar(dict(model_config)), **hparams}, dict(metrics), run_name="."
    )
finally:
    writer.flush()
    writer.close()

  0%|          | 0/300 [00:00<?, ?it/s]