In [1]:
using JLD2, Random, Statistics, Printf
include("NETWORK.jl")
using .NETWORK
using BenchmarkTools

# Path to the preprocessed IMDB dataset (a JLD2 archive).
datapath = "../Reference/CNN/CNN/data/imdb_dataset_prepared.jld2"

# Read every array through one file handle. Calling `load(datapath, key)` once
# per key reopens and re-parses the archive each time. The `do` form of
# `jldopen` closes the file as soon as the block returns.
X_train, y_train, X_test, y_test, embeddings, vocab = jldopen(datapath, "r") do file
    (file["X_train"], file["y_train"], file["X_test"], file["y_test"],
     file["embeddings"], file["vocab"])
end

# The first axis of `embeddings` is taken as the embedding width.
embedding_dim = size(embeddings, 1)

# Convert inputs to Float32 and flatten the targets into plain vectors.
# NOTE(review): X_* presumably hold token indices; confirm that NETWORK.Embedding
# expects them as Float32.
X_train = Float32.(X_train);
X_test = Float32.(X_test);

y_train = vec(Float32.(y_train));
y_test = vec(Float32.(y_test));

# Seed the global RNG before building the model so that any random weight
# initialisation is reproducible. Later shuffling is also covered if it draws
# from the global RNG. `Random` comes from the `using` line at the top of the file.
Random.seed!(42)

# Embedding -> Conv1D (8 filters, kernel width 3, ReLU) -> MaxPool1D(8) -> Flatten
# -> Dense with a single output unit (activation `NETWORK.sigma`, presumably sigmoid).
# NOTE(review): the Dense input size 128 is hard-coded. It must equal the number of
# values Flatten produces (8 channels x pooled sequence length), so it needs
# updating if the padded review length in the dataset changes. Confirm against
# NETWORK's Conv1D/MaxPool1D output-length rules.
model = NETWORK.Chain(
    NETWORK.Embedding(length(vocab), embedding_dim),
    NETWORK.Conv1D(embedding_dim, 8; kernel=3, activation=NETWORK.relu),
    NETWORK.MaxPool1D(8),
    NETWORK.Flatten(),
    NETWORK.Dense(128, 1, NETWORK.sigma)
)

# Initialise the embedding layer with the pretrained vectors. Check the size
# explicitly: `.=` would silently broadcast a shape-compatible but different-sized
# array instead of failing.
embeddings = Float32.(embeddings)
size(model.layers[1].weight.value) == size(embeddings) || throw(DimensionMismatch(
    "embedding layer weight has size $(size(model.layers[1].weight.value)), " *
    "but pretrained embeddings have size $(size(embeddings))"))
model.layers[1].weight.value .= embeddings

# Training hyperparameters: number of passes over the data, mini-batch size,
# and the Adam step size. The step size is a Float32 literal, like the data.
epochs, batch_size, learning_rate = 5, 64, 0.001f0

# Optimizer bound to `model`, using the step size above.
opt = NETWORK.AdamOptimizer(model; lr = learning_rate)

# One Trainer per dataset, each built for the chosen batch size.
# NOTE(review): presumably each Trainer preallocates batch buffers from its
# input; confirm in NETWORK.jl.
test_trainer = NETWORK.Trainer(model, X_test, batch_size)
train_trainer = NETWORK.Trainer(model, X_train, batch_size)

# Each epoch runs one full training pass and then one evaluation pass on the
# test set. The reported time covers both passes. The first epoch is noticeably
# slower, presumably because of JIT compilation.
for epoch in 1:epochs
    # `@elapsed begin ... end` does not open a new scope, so the losses and
    # accuracies assigned inside are still visible to `@printf` below.
    t = @elapsed begin
        train_loss, train_acc = NETWORK.train!(train_trainer, model, X_train, y_train, opt;
                                               batch_size = batch_size)
        test_loss, test_acc = NETWORK.loss(test_trainer, X_test, y_test, batch_size)
    end
    @printf("Epoch %d (%.2fs)  Train: (l %.4f, a %.4f)  Test: (l %.4f, a %.4f)\n",
        epoch, t, train_loss, train_acc, test_loss, test_acc)
end

Epoch 1 (43.04s)  Train: (l 0.5212, a 0.7291)  Test: (l 0.3793, a 0.8370)
Epoch 2 (13.40s)  Train: (l 0.3267, a 0.8632)  Test: (l 0.3223, a 0.8690)
Epoch 3 (13.34s)  Train: (l 0.2513, a 0.8988)  Test: (l 0.3117, a 0.8731)
Epoch 4 (13.37s)  Train: (l 0.1957, a 0.9260)  Test: (l 0.3190, a 0.8739)
Epoch 5 (13.35s)  Train: (l 0.1458, a 0.9497)  Test: (l 0.3359, a 0.8731)
