In [1]:
include("../MyReverseDiff.jl")
include("../MyEmbedding.jl")
include("../MyMlp.jl")

using .MyReverseDiff
using .MyMlp
using JLD2
using Printf
using Random

## Przygotowanie danych IMDB

In [2]:
# Load the prepared IMDB dataset. The JLD2 file is opened a single time and all
# required entries are read from the same handle, instead of re-opening the
# file once per variable.
dataset_path = "../../dataset/imdb_dataset_prepared.jld2"
X_train, y_train, X_test, y_test, embeddings = jldopen(dataset_path, "r") do file
    # "vocab" is deliberately not read here (its load was disabled; presumably
    # it is also stored in the file — confirm before relying on it).
    map(key -> file[key], ("X_train", "y_train", "X_test", "y_test", "embeddings"))
end

input_size = size(X_train, 1) # number of features (rows of X_train)
embedding_dim = size(embeddings, 1)

50

## Trening modelu

In [3]:
# Number of columns (samples) held by the preallocated input/label buffers.
batch_size = 64

# Network definition; layer types come from the project-local modules.
model = Chain(
    Embedding(embeddings, name="embedding"),
    TransposeBlock(),
    # Uses `embedding_dim` instead of the previously hard-coded 50 (the two are
    # equal for this dataset).
    # NOTE(review): presumably this argument is the number of input channels,
    # i.e. the embedding dimension — confirm against ConvolutionBlock.
    ConvolutionBlock(3, embedding_dim, 8, name="layer1"),
    PoolingBlock(8),
    FlattenBlock(name="flatten"),
    # NOTE(review): assumes the flattened feature count equals input_size - 2
    # for the conv/pool settings above — TODO confirm if those are changed.
    Dense(input_size - 2, 1, σ, name="softnet")
)

# Create the initial Constant nodes that hold the input data and the labels.
# Their buffers are overwritten in place for every batch.
x_input_node = Constant(zeros(Float32, input_size, batch_size))
y_label_node = Constant(zeros(Float32, 1, batch_size))

# Build the training graph: returns the loss node, the model output node and
# the node ordering that is passed to forward!/backward!.
loss_node, model_output_node, order = build_graph!(
    model, binarycrossentropy, x_input_node, y_label_node; loss_name="loss"
)

optimizer_state = setup_optimizer(Adam(a=0.001f0), model)

epochs = 5

5

In [4]:
# Training loop: for every epoch the samples are visited in a fresh random
# order, mini-batches are copied into the preallocated graph input buffers, and
# one forward/backward/optimizer step is executed per batch. Time, allocation
# and GC statistics are collected for the compute step only.
for epoch in 1:epochs
    num_train_samples = size(X_train, 2)
    batches_per_epoch = cld(num_train_samples, batch_size)

    # Random visiting order for this epoch. Batches are gathered through this
    # permutation directly, so no shuffled copy of the full dataset is made.
    permutation = randperm(num_train_samples)

    epoch_loss = 0.0

    println("\nEpoch: $epoch")
    total_batch_time = 0.0
    total_batch_alloc = 0
    total_batch_gc_time = 0.0

    for i in 1:batches_per_epoch
        first_idx = (i - 1) * batch_size + 1
        last_idx = min(i * batch_size, num_train_samples)
        batch_cols = view(permutation, first_idx:last_idx)
        current_batch_size = length(batch_cols)

        # Copy the batch into the leading columns of the graph input buffers.
        # NOTE(review): when the last batch is shorter than `batch_size`, the
        # remaining buffer columns still hold the previous batch's samples and
        # presumably take part in the forward/backward pass — confirm whether
        # this is intended.
        view(x_input_node.output, :, 1:current_batch_size) .= view(X_train, :, batch_cols)
        view(y_label_node.output, :, 1:current_batch_size) .= view(y_train, :, batch_cols)

        stats = @timed begin
            forward!(order)
            backward!(order)
            step!(optimizer_state)
        end
        epoch_loss += loss_node.output

        total_batch_time += stats.time
        total_batch_alloc += stats.bytes
        total_batch_gc_time += stats.gctime
    end

    # Mean of the per-batch loss values over the epoch.
    avg_loss_epoch = epoch_loss / batches_per_epoch

    @printf(
        "Epoch: %d \tTrain: (l: %.4f) \tTotal Epoch Time: %.4fs \tTotal Alloc: %s \tGC Time: %.4fs\n",
        epoch, avg_loss_epoch, total_batch_time,
        Base.format_bytes(total_batch_alloc), total_batch_gc_time
    )
end


Epoch: 1
Epoch: 1 	Train: (l: 0.6037) 	Total Epoch Time: 49.7755s 	Total Alloc: 45.664 GiB 	GC Time: 9.0684s

Epoch: 2
Epoch: 2 	Train: (l: 0.3822) 	Total Epoch Time: 36.5714s 	Total Alloc: 44.490 GiB 	GC Time: 8.1546s

Epoch: 3
Epoch: 3 	Train: (l: 0.2812) 	Total Epoch Time: 36.9344s 	Total Alloc: 44.490 GiB 	GC Time: 8.2825s

Epoch: 4
Epoch: 4 	Train: (l: 0.2098) 	Total Epoch Time: 36.4040s 	Total Alloc: 44.490 GiB 	GC Time: 8.2458s

Epoch: 5
Epoch: 5 	Train: (l: 0.1509) 	Total Epoch Time: 37.7498s 	Total Alloc: 44.490 GiB 	GC Time: 8.4131s


In [5]:
# Evaluation on the test set: batches are pushed through the already-built
# graph (forward pass only) and the loss and number of correct predictions are
# accumulated.

# The batch size must match the width of the preallocated input buffer, so it
# is derived from the buffer instead of being hard-coded again.
batch_size = size(x_input_node.output, 2)
num_test_samples = size(X_test, 2)
num_batches = cld(num_test_samples, batch_size)
total_test_loss_sum = 0.0
# Integer count of correct predictions; avoids the floating-point round-off of
# dividing by the batch size and multiplying back.
total_correct_predictions = 0

t_test = @elapsed begin
    for i in 1:num_batches
        start_idx = (i - 1) * batch_size + 1
        end_idx = min(i * batch_size, num_test_samples)
        # Views avoid copying the batch before it is written into the buffers.
        x_batch_test = view(X_test, :, start_idx:end_idx)
        y_batch_test = view(y_test, :, start_idx:end_idx)

        current_test_batch_size = end_idx - start_idx + 1

        view(x_input_node.output, :, 1:current_test_batch_size) .= x_batch_test
        view(y_label_node.output, :, 1:current_test_batch_size) .= y_batch_test

        forward!(order)

        # Only the leading columns belong to the current batch.
        predictions = view(model_output_node.output, :, 1:current_test_batch_size)

        # NOTE(review): for a final batch shorter than `batch_size` the buffers
        # still contain columns of the previous batch, so `loss_node.output`
        # presumably includes them — confirm how the loss is reduced.
        total_test_loss_sum += loss_node.output * current_test_batch_size

        # A prediction counts as positive when the model output exceeds 0.5.
        total_correct_predictions += count((predictions .> 0.5f0) .== y_batch_test)
    end
end

avg_test_loss = total_test_loss_sum / num_test_samples
avg_test_accuracy = 100 * total_correct_predictions / num_test_samples

@printf("Test Loss (czas: %.2fs): %.4f\n", t_test, avg_test_loss)
@printf("Test Accuracy: %.2f %%\n", avg_test_accuracy)

Test Loss (czas: 2.34s): 0.3619
Test Accuracy: 86.11999999999999 %
