In [None]:
using Pkg
Pkg.activate("..")

include("../src/MyDiffMLP.jl")
using .MyDiffMLP
const AD = MyDiffMLP.MyAD
const NN = MyDiffMLP.MyNN



In [None]:
using JLD2, Statistics
# Read the four train/test arrays from the prepared IMDB JLD2 file in a single
# `jldopen` session, instead of re-opening and re-parsing the file once per variable.
# NOTE(review): judging by how the training loop indexes them (`size(X, 2)` as the
# sample count, `X[:, cols]` for batches), samples are stored column-wise — confirm.
X_train, y_train, X_test, y_test =
    jldopen("../data/imdb_dataset_prepared.jld2", "r") do file
        (file["X_train"], file["y_train"], file["X_test"], file["y_test"])
    end
nothing  # suppress notebook cell output

In [None]:
# Two dense layers chained together: the input width comes from the feature
# dimension (rows) of `X_train`, a 32-unit ReLU hidden layer follows, and a
# single sigmoid output unit produces the prediction.
model = let n_in = size(X_train, 1), n_hidden = 32
    NN.Chain(
        NN.Dense(n_in, n_hidden, AD.relu),
        NN.Dense(n_hidden, 1, AD.sigmoid),
    )
end

"""
    loss(ŷ, y; ϵ = 1e-7)

Return the mean binary cross-entropy between predictions `ŷ` and targets `y`.
The value is computed element-wise and then averaged over all entries.

Predictions are clamped to `[ϵ, 1 - ϵ]` before taking logarithms. This keeps the
loss finite when a prediction is exactly `0` or `1`, where `log` would otherwise
return `-Inf`.
"""
function loss(ŷ, y; ϵ = 1e-7)
    p = clamp.(ŷ, ϵ, 1 - ϵ)
    return -mean(y .* log.(p) .+ (1 .- y) .* log.(1 .- p))
end

"""
    accuracy(ŷ, y)

Return the fraction of entries where `ŷ` and `y` fall on the same side of the
0.5 threshold, i.e. the binary classification accuracy.
"""
function accuracy(ŷ, y)
    agrees = (ŷ .> 0.5) .== (y .> 0.5)
    return count(agrees) / length(agrees)
end

# Training-loop settings used by `train!`.
η, epochs = 0.1, 10  # step size handed to NN.update!, number of passes over X_train
batch_size = 64      # training columns per mini-batch

In [None]:
using Random, Printf
# Wrap `X` in an AD input variable (with a zero gradient buffer), apply `model`,
# run the forward pass, and return the output node plus the sorted node list.
function _forward_pass(model, X)
    input = AD.Variable(X, zeros(size(X)))
    out = model(input)
    graph = AD.topological_sort(out)
    AD.forward!(graph)
    return out, graph
end

"""
    train!(model, X_train, y_train, X_test, y_test;
           nepochs = epochs, batchsize = batch_size, lr = η,
           rng = Random.default_rng())

Train `model` in place with mini-batch updates and print one progress line per epoch.

Samples are columns: `size(X_train, 2)` is the number of training examples. Each
epoch draws a fresh random permutation of the columns and walks it in chunks of
`batchsize`. For each chunk, `train!` runs the model forward through the `AD`
graph and accumulates the batch loss and accuracy. It then back-propagates with
`AD.backward!` and updates the parameters via `NN.update!` with step `lr`. After
each epoch the full test set is evaluated.

The reported train loss and accuracy are the unweighted mean of the per-batch
values, so a smaller final batch counts as much as a full one. The printed time
covers only the training batches, not the test evaluation.

# Keywords
- `nepochs`: number of passes over the training data (defaults to the global `epochs`).
- `batchsize`: columns per mini-batch (defaults to the global `batch_size`).
- `lr`: learning rate passed to `NN.update!` (defaults to the global `η`).
- `rng`: random number generator used to shuffle the training columns.

# Throws
- `ArgumentError` if `X_train` has no columns or `batchsize < 1`.
"""
function train!(model, X_train, y_train, X_test, y_test;
                nepochs = epochs, batchsize = batch_size, lr = η,
                rng::AbstractRNG = Random.default_rng())
    n = size(X_train, 2)
    n > 0 || throw(ArgumentError("X_train has no samples (columns)"))
    batchsize > 0 || throw(ArgumentError("batchsize must be positive, got $batchsize"))

    for epoch in 1:nepochs
        perm = randperm(rng, n)
        total_loss = 0.0
        total_acc = 0.0
        num_batches = 0

        t = @elapsed for start in 1:batchsize:n
            stop = min(start + batchsize - 1, n)
            cols = perm[start:stop]
            x_batch = X_train[:, cols]
            y_batch = y_train[:, cols]

            output, nodes = _forward_pass(model, x_batch)
            ŷ = output.output
            total_loss += loss(ŷ, y_batch)
            total_acc += accuracy(ŷ, y_batch)
            num_batches += 1

            # (ŷ - y) / batch is the gradient of the mean BCE with respect to the
            # pre-sigmoid logits. NOTE(review): if AD.backward! propagates this
            # seed back through the sigmoid node (multiplying by ŷ(1 - ŷ)), the
            # resulting update is not the BCE gradient. Confirm how MyAD handles
            # the sigmoid/BCE pairing.
            grad_seed = (ŷ .- y_batch) ./ size(y_batch, 2)
            AD.backward!(nodes, grad_seed)
            NN.update!(NN.parameters(model), lr)
        end

        test_output, _ = _forward_pass(model, X_test)
        ŷ_test = test_output.output

        @printf("Epoch: %d (%.2fs) \tTrain: (l: %.4f, a: %.4f) \tTest: (l: %.4f, a: %.4f)\n",
                epoch, t, total_loss / num_batches, total_acc / num_batches,
                loss(ŷ_test, y_test), accuracy(ŷ_test, y_test))
    end
    return nothing
end

# Train with the hyperparameters defined earlier in the notebook. Prints one line
# per epoch: epoch number, time spent on training batches, and train/test
# loss and accuracy.
train!(model, X_train, y_train, X_test, y_test)