In [7]:
#= 📦 Imports and data loading =#
using JLD2, Printf, Statistics, LinearAlgebra
include("src/AD.jl")
include("src/NN.jl")

# Everything comes from one JLD2 file: integer feature matrices, labels,
# the embedding matrix and the vocabulary.
file = load("data/imdb_dataset_prepared.jld2")
X_train, y_train = Int.(file["X_train"]), vec(Float32.(file["y_train"]))
X_test, y_test = Int.(file["X_test"]), vec(Float32.(file["y_test"]))
embeddings, vocab = file["embeddings"], file["vocab"]
# The first dimension of the embedding matrix is taken as the vector size
# (assumes one column per vocabulary entry — TODO confirm against the data prep).
embedding_dim = size(embeddings, 1)



50

In [8]:
#= 🔄 Data preparation =#
"""
    samples_as_columns(X, y)

Return `X` laid out with one sample per column, using the labels `y` to decide the
orientation: `X` is returned unchanged when its column count equals `length(y)`,
otherwise it is transposed (materialized with `permutedims`) when its row count does.

Throws `DimensionMismatch` when neither dimension matches the number of labels.
"""
function samples_as_columns(X::AbstractMatrix, y::AbstractVector)
    nsamples = length(y)
    size(X, 2) == nsamples && return X
    size(X, 1) == nsamples && return permutedims(X)
    throw(DimensionMismatch(
        "feature matrix of size $(size(X)) has no dimension matching $nsamples labels"))
end

# Orientation is decided from the labels rather than `size(X, 1) < size(X, 2)`: that
# heuristic judged each matrix on its own, so a square matrix or a split with fewer
# samples than features would end up laid out differently from the others.
X_train = samples_as_columns(X_train, y_train)
X_test = samples_as_columns(X_test, y_test)
# Float32 copies are fed to the DataLoader; the model's first layer (`Int.(x)`)
# turns them back into integer ids.
X_train_f32 = Float32.(X_train)
X_test_f32 = Float32.(X_test)
dataset = NN.DataLoader((X_train_f32, y_train), batchsize=32, shuffle=true);

In [9]:
# Convolutional binary classifier: integer ids → embedding → Conv1D → max-pool →
# flatten → two dense layers, ending in a single sigmoid score.
# Hyperparameters live inside `let` so they do not leak into the notebook's globals.
model = let
    kernel_width = 3
    n_filters = 16
    pool_width = 8
    # NOTE(review): 256 must equal the flattened MaxPool1D output (presumably
    # n_filters × pooled length, i.e. it depends on the sequence length) — TODO confirm.
    flat_features = 256
    hidden_units = 64
    drop_rate = 0.3

    NN.Chain(
        x -> Int.(x),                      # inputs arrive as Float32; convert back to ids
        NN.Embedding(length(vocab), embedding_dim, embeddings),
        x -> permutedims(x, (2, 1, 3)),    # swap the first two axes before convolving
        NN.Conv1D(kernel_width, embedding_dim, n_filters, NN.relu),
        NN.MaxPool1D(pool_width),
        NN.flatten,
        NN.Dropout(drop_rate),
        NN.Dense(flat_features, hidden_units, NN.relu),
        NN.Dropout(drop_rate),
        NN.Dense(hidden_units, 1, NN.sigmoid),
    )
end;


In [10]:
#= 🎯 Accuracy metric =#
"""
    accuracy(m, x, y; threshold=0.5)

Return the fraction of samples whose thresholded prediction agrees with its label.

The model output `m(x)` is flattened with `vec`; a score counts as positive when it
is strictly greater than `threshold`. Labels in `y` count as positive when greater
than `0.5`, so 0/1 labels work directly. Raising or lowering `threshold` trades
precision for recall without touching the labels.
"""
accuracy(m, x, y; threshold=0.5) = mean((vec(m(x)) .> threshold) .== (y .> 0.5))

accuracy (generic function with 1 method)

In [12]:
#= 🚀 Model training =#
# Adam with learning rate 5e-4 (stored as Float32), trained for a fixed number of
# epochs; train_model prints per-epoch loss/accuracy for the train and test data.
opt = NN.Adam(Float32(5e-4))
epochs = 5
NN.train_model(model, dataset, X_test_f32, y_test, opt, epochs)

Epoch: 1 (77.24s) 	Train: (l: 0.5251, a: 0.7497) 	Test: (l: 0.5311, a: 0.7463)
Epoch: 2 (76.77s) 	Train: (l: 0.5269, a: 0.7481) 	Test: (l: 0.5368, a: 0.7413)
Epoch: 3 (76.72s) 	Train: (l: 0.5267, a: 0.7458) 	Test: (l: 0.5418, a: 0.7353)
Epoch: 4 (77.01s) 	Train: (l: 0.5253, a: 0.7482) 	Test: (l: 0.5415, a: 0.7400)
Epoch: 5 (80.20s) 	Train: (l: 0.5244, a: 0.7497) 	Test: (l: 0.5407, a: 0.7440)
