In [2]:
using JLD2
using Printf
using Statistics
using LinearAlgebra

In [3]:
include("src/AD.jl")
include("src/NN.jl")

Main.NN

In [11]:
# Read the prepared IMDB splits from the JLD2 archive and coerce every array to Float32.
file = load("data/imdb_dataset_prepared.jld2")

# Features are materialised as dense matrices; labels are flattened to plain vectors.
X_train = file["X_train"] |> Matrix .|> Float32
y_train = file["y_train"] .|> Float32 |> vec
X_test  = file["X_test"] |> Matrix .|> Float32
y_test  = file["y_test"] .|> Float32 |> vec;

In [12]:
# Wrap the training split in a loader configured with batch size 64 and shuffling enabled.
dataset = NN.DataLoader((X_train, y_train); batchsize=64, shuffle=true);

In [13]:
# Build the network: Dense(input -> 64, relu), Dense(64 -> 32, relu), Dense(32 -> 1, sigmoid),
# with a Dropout(0.3) layer after each of the two hidden Dense layers.
# The input width is taken from the first dimension of X_train.
# Weight decay is 1e-4 on the first layer and 1e-3 on the remaining two.
model = NN.Chain(
    NN.Dense(size(X_train, 1), 64, NN.relu; weight_decay=1f-4),
    NN.Dropout(0.3f0),
    NN.Dense(64, 32, NN.relu; weight_decay=1f-3),
    NN.Dropout(0.3f0),
    NN.Dense(32, 1, NN.sigmoid; weight_decay=1f-3),
);

In [14]:
"""
    accuracy(m, x, y)

Return the fraction of entries for which the thresholded output of `m(x)` agrees with the
thresholded target `y`. Both sides are binarised with a strict `> 0.5` comparison, and the
output of `m(x)` is flattened with `vec` before comparing.
"""
function accuracy(m, x, y)
    predicted = vec(m(x)) .> 0.5
    expected = y .> 0.5
    return mean(predicted .== expected)
end

accuracy (generic function with 1 method)

In [17]:
# Optimiser and training-loop settings.
# NOTE(review): patience equals epochs here, so early stopping presumably cannot trigger
# within this run — confirm against NN.train_model.
opt = NN.Adam(1f-4)  # Adam optimiser with a small learning rate (1e-4)
epochs = 5           # number of training epochs
patience = 5;        # number of epochs to wait for improvement

In [18]:
# Train on `dataset`, passing the test split along, and rebind `model` to whatever
# `NN.train_model` returns.
model = NN.train_model(
    model, dataset, X_test, y_test, opt, epochs, patience
);

Epoch: 1  (1.49s)   Train: (l: 0.5416, a: 0.8972)   Test: (l: 0.5518, a: 0.8650)
Epoch: 2  (1.33s)   Train: (l: 0.4969, a: 0.9044)   Test: (l: 0.5149, a: 0.8685)
Epoch: 3  (1.35s)   Train: (l: 0.4488, a: 0.9144)   Test: (l: 0.4769, a: 0.8705)
Epoch: 4  (1.63s)   Train: (l: 0.3995, a: 0.9177)   Test: (l: 0.4439, a: 0.8735)
Epoch: 5  (1.71s)   Train: (l: 0.3568, a: 0.9254)   Test: (l: 0.4137, a: 0.8735)
