In [1]:
using Statistics
using Profile

include("ConvolutionModule.jl")
include("PoolingModule.jl")
include("FlattenModule.jl")
include("DenseModule.jl")

include("MNISTDataLoader.jl")
include("LossAndAccuracy.jl")
include("NetworkHandlers.jl")

using .ConvolutionModule, .PoolingModule, .MNISTDataLoader, .FlattenModule, .DenseModule

# Load and preprocess the data
train_features, train_labels = MNISTDataLoader.load_data(:train)
train_x, train_y = MNISTDataLoader.preprocess_data(train_features, train_labels; one_hot=true)

# for i in eachindex(axes(train_x, 4))
#     input = train_x[:, :, :, i]
#     (height, width, channels) = size(input)
#     if height != 28 || width != 28 || channels != 1
#         println("Input $i: height=$height, width=$width, channels=$channels")
#     end
# end

# Load and preprocess test data
test_features, test_labels = MNISTDataLoader.load_data(:test)
test_x, test_y = MNISTDataLoader.preprocess_data(test_features, test_labels; one_hot=true)

# Create batches
batch_size = 100
train_data = MNISTDataLoader.batch_data((train_x, train_y), batch_size; shuffle=true)
# input_image = Float64.(input_image)

sample_input = train_x[:, :, :, 1]

# Initialize layers
conv_layer1 = ConvolutionModule.init_conv_layer(3, 3, 1, 6, 1, 0, 3697631579, 28, 28, 1)
pool_layer1 = PoolingModule.init_pool_layer(2, 2, 2, 26, 26, 6)
conv_layer2 = ConvolutionModule.init_conv_layer(3, 3, 6, 16, 1, 0, 3731614026, 13, 13, 6)
pool_layer2 = PoolingModule.init_pool_layer(2, 2, 2, 11, 11, 16)
flatten_layer = FlattenModule.FlattenLayer()
dense_layer1 = DenseModule.init_dense_layer(400, 84, DenseModule.relu, DenseModule.relu_grad, 4172219205)
dense_layer2 = DenseModule.init_dense_layer(84, 10, DenseModule.identity, DenseModule.identity_grad, 3762133366)

# Workaround because of namespaces...
function backward_pass_master(network, grad_loss)
    for layer in reverse(network)
        if isa(layer, ConvolutionModule.ConvLayer)
            grad_loss = ConvolutionModule.backward_pass(layer, grad_loss)

        elseif isa(layer, PoolingModule.MaxPoolLayer)
            grad_loss = PoolingModule.backward_pass(layer, grad_loss)

        elseif isa(layer, DenseModule.DenseLayer)
            grad_loss = DenseModule.backward_pass(layer, grad_loss)

        elseif isa(layer, FlattenModule.FlattenLayer)
            grad_loss = FlattenModule.backward_pass(layer, grad_loss)
        else
            println("No backward pass defined for layer type $(typeof(layer))")
        end
    end
    return grad_loss
end

function update_weights(network, learning_rate)
    for layer in reverse(network)
        if isa(layer, DenseModule.DenseLayer) || isa(layer, ConvolutionModule.ConvLayer)
            # println("Type of layer ", typeof(layer))

            # println("Previous weights grad ", layer.grad_weights[1:2, :])

            layer.grad_weights ./= batch_size
            layer.grad_biases ./= batch_size
            # println("Normalized weights grad ", layer.grad_weights[1:2, :])

            # println("Previous weights ", layer.weights[1:2, :])

            layer.weights .-= learning_rate * layer.grad_weights
            layer.biases .-= learning_rate * layer.grad_biases



            # println("Updated weights ", layer.weights[1:2, :])

            fill!(layer.grad_weights, 0)
            fill!(layer.grad_biases, 0)
        end
    end
end

# Function to evaluate the model on test data
function evaluate_model(network, test_x, test_y)
    total_loss = 0.0
    total_accuracy = 0.0
    num_samples = size(test_x, 4)

    for i in 1:num_samples
        input = test_x[:, :, :, i]
        target = test_y[:, i]

        # Forward pass
        output = NetworkHandlers.forward_pass_master(network, input)

        # Calculate loss and accuracy
        loss, accuracy, _ = LossAndAccuracy.loss_and_accuracy(output, target)
        total_loss += loss
        total_accuracy += accuracy
    end

    # Calculate average loss and accuracy
    avg_loss = total_loss / num_samples
    avg_accuracy = total_accuracy / num_samples
    return avg_loss, avg_accuracy
end

# Assemble the network
network = (conv_layer1, pool_layer1, conv_layer2, pool_layer2, flatten_layer, dense_layer1, dense_layer2)

using .NetworkHandlers, .LossAndAccuracy
epochs = 3
training_step = 0.5

println("Starting training...")



function training_loop()
    plot_loss = Float64[]
    batch_loss = 0.0
    for epoch in 1:epochs
        accumulated_accuracy_epoch = 0.0
        accumulated_accuracy_batch = 0.0
        
        @time begin
            for i in eachindex(axes(train_x, 4))
                input = train_x[:, :, :, i]
                target = train_y[:, i]

                output = NetworkHandlers.forward_pass_master(network, input)
                
                loss, accuracy, grad_loss = LossAndAccuracy.loss_and_accuracy(output, target)
                accumulated_accuracy_epoch += accuracy
                accumulated_accuracy_batch += accuracy
                batch_loss += loss
                # if i % 100 == 0
                #     println("Loss: ", loss)
                #     println("Accuracy: ", round(accumulated_accuracy_batch / 100, digits=2))
                #     accumulated_accuracy_batch = 0.0
                # end

                # if i % 10000 == 0
                #     println("i ", i)
                # end

                backward_pass_master(network, grad_loss)
                
                if i % batch_size == 0
                    plot_loss = push!(plot_loss, batch_loss / batch_size)
                    batch_loss = 0.0
                    update_weights(network, training_step)
                end
            end
        end
        # println("Epoch $(epoch) done")
        # println("Accuracy: ", round(accumulated_accuracy_epoch / size(train_x, 4), digits=2))
        # accumulated_accuracy_epoch = 0.0

        test_loss, test_accuracy = evaluate_model(network, test_x, test_y)
        println("Epoch $(epoch) done. Training Accuracy: $(round(accumulated_accuracy_epoch / size(train_x, 4), digits=2)), Test Loss: $test_loss, Test Accuracy: $test_accuracy")

        # Update weights at the end of each epoch
        update_weights(network, training_step)
    end
end


# @code_warntype training_loop()
@profile training_loop()

Profile.print()
ProfileView.view()

# TODO: throw out all .= .* etc

Starting training...
278.993854 seconds (2.62 G allocations: 365.666 GiB, 8.42% gc time, 2.57% compilation time)
Epoch 1 done. Training Accuracy: 89.81, Test Loss: 0.1880244781484015, Test Accuracy: 93.89
263.106043 seconds (2.61 G allocations: 365.078 GiB, 8.68% gc time)
Epoch 2 done. Training Accuracy: 95.81, Test Loss: 0.14050660120819483, Test Accuracy: 95.29
258.087746 seconds (2.61 G allocations: 365.078 GiB, 8.89% gc time)
Epoch 3 done. Training Accuracy: 96.52, Test Loss: 0.12438330327535273, Test Accuracy: 96.05


│ before your program finished. To profile for longer runs, call
│ `Profile.init()` with a larger buffer and/or larger delay.
└ @ Profile C:\Users\mikul\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Profile\src\Profile.jl:609


Overhead ╎ [+additional indent] Count File:Line; Function
     ╎11     @Base\asyncevent.jl:306; (::Base.var"#726#727"{VSCodeServer.IJulia…
     ╎ 11     …liaCore\src\stdio.jl:243; #4
     ╎  11     …iaCore\src\stdio.jl:120; send_stdio(name::String, send_callback…
     ╎   11     …iaCore\src\stdio.jl:159; send_stream(name::String, send_callba…
   10╎    11     …c\serve_notebook.jl:72; io_send_callback(name::String, data::…
     ╎     1      …SONRPC\src\core.jl:241; send_notification(x::VSCodeServer.JS…
     ╎    ╎ 1      @JSON\src\Writer.jl:354; json
     ╎    ╎  1      …ase\strings\io.jl:107; sprint
     ╎    ╎   1      …ase\strings\io.jl:114; sprint(f::Function, args::Dict{Str…
     ╎    ╎    1      …SON\src\Writer.jl:349; print(io::IOBuffer, obj::Dict{Str…
     ╎    ╎     1      …ON\src\Writer.jl:323; show_json
     ╎    ╎    ╎ 1      …ON\src\Writer.jl:325; #show_json#3
     ╎    ╎    ╎  1      …N\src\Writer.jl:271; show_json(io::VSCodeServer.JSON.…
    1╎    ╎    ╎   1      …N\src\W

Excessive output truncated after 524348 bytes.


     ╎    ╎    ╎    ╎    ╎    ╎    ╎    ╎ +28 1      …ssair\ir.jl:348; copy(ir:…
     ╎    ╎    ╎    ╎    ╎    ╎    ╎    ╎ +29 1      …ssair\ir.jl:213; copy
    1╎    ╎    ╎    ╎    ╎    ╎    ╎    ╎ +30 1      …se\array.jl:411; copy
     ╎    ╎    ╎    ╎    ╎    ╎    ╎    ╎ +10 2      …retation.jl:103; abstract…
     ╎    ╎    ╎    ╎    ╎    ╎    ╎    ╎ +11 2      …retation.jl:788; abstract…
     ╎    ╎    ╎    ╎    ╎    ╎    ╎    ╎ +12 2      …retation.jl:818; abstract…
     ╎    ╎    ╎    ╎    ╎    ╎    ╎    ╎ +13 2      …retation.jl:1207; const_p…
     ╎    ╎    ╎    ╎    ╎    ╎    ╎    ╎ +14 2      …ypeinfer.jl:216; typeinf(…
     ╎    ╎    ╎    ╎    ╎    ╎    ╎    ╎ +15 2      …ypeinfer.jl:247; _typeinf…
     ╎    ╎    ╎    ╎    ╎    ╎    ╎    ╎ +16 2      …retation.jl:3186; typeinf…
     ╎    ╎    ╎    ╎    ╎    ╎    ╎    ╎ +17 2      …retation.jl:3098; typeinf…
     ╎    ╎    ╎    ╎    ╎    ╎    ╎    ╎ +18 2      …retation.jl:2913; abstrac…
     ╎    ╎    ╎    ╎    ╎    ╎    ╎ 

UndefVarError: UndefVarError: `ProfileView` not defined