In [2]:
using Pkg 
# Pkg.add("Lux")
# Pkg.add("MLUtils")
# Pkg.add("Optimisers")
# Pkg.add("Zygote")
# Pkg.add("OneHotArrays")
# Pkg.add("Random") 
# Pkg.add("Statistics")
# Pkg.add("Printf")
# Pkg.add("Reactant")
# Pkg.add("MLDatasets")
# Pkg.add("SimpleChains")

using Lux, MLUtils, Optimisers, Zygote, OneHotArrays, Random, Statistics, Printf, Reactant
using MLDatasets: MNIST
using SimpleChains: SimpleChains

# Make the CPU Reactant's default backend for compiled code.
Reactant.set_default_backend("cpu")

```julia
DataLoader(collect.((x_train, y_train)); batchsize, shuffle = true, partial = false)
```

En este caso se usan los siguientes argumentos:
### DataLoader
Divide los datos en mini-batches de forma eficiente y permite iterar sobre ellos.

### `shuffle = true`
Mezcla los datos antes de dividirlos en mini-batches. En cada época se vuelven a mezclar, por lo que los mini-batches serán distintos en cada una.

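### `partial = true`
Si el número total de datos no es divisible entre el tamaño del mini-batch, se genera un último mini-batch más pequeño con los datos restantes. Con `partial = false`, como en este código, ese último mini-batch incompleto se descarta.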


In [3]:
# Same call as in the setup cell; presumably repeated so the backend is set even if that cell is skipped.
Reactant.set_default_backend("cpu")

In [None]:
"""
    loadmnist(batchsize, train_split)

Load the MNIST training split and return `(train_loader, test_loader)`.

Images are converted to `Float32` with a singleton channel dimension inserted before the
sample dimension; labels are one-hot encoded over the digits `0:9`. The first
`train_split` fraction of the samples forms the training set (reshuffled every epoch) and
the remainder the test set (never shuffled). Both loaders use `batchsize` and drop an
incomplete final batch. When `ENV["CI"]` parses as `true`, only the first 1500 samples
are used.
"""
function loadmnist(batchsize, train_split)
    on_ci = parse(Bool, get(ENV, "CI", "false"))
    dataset = MNIST(; split=:train)

    imgs, labels_raw = dataset.features, dataset.targets
    if on_ci
        # A small subset keeps CI runs short.
        imgs, labels_raw = dataset.features[:, :, 1:1500], dataset.targets[1:1500]
    end

    # (d1, d2, N) -> (d1, d2, 1, N): add the channel dimension consumed by the first `Conv`.
    x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))
    y_data = onehotbatch(labels_raw, 0:9)
    train_data, test_data = splitobs((x_data, y_data); at=train_split)

    # Materialise the split views and wrap them in a fixed-batch-size loader.
    make_loader(data, reshuffle) =
        DataLoader(collect.(data); batchsize, shuffle=reshuffle, partial=false)
    return (make_loader(train_data, true), make_loader(test_data, false))
end

In [45]:
# Same CI check as in `loadmnist`: 1500 when ENV["CI"] parses as `true`, otherwise
# `nothing` (which `loadmnist` treats as "use the full dataset").
parse(Bool, get(ENV, "CI", "false")) ? 1500 : nothing

In [None]:
 # Build the image tensor and one-hot labels, then split them into train/test sets.
 # MNIST is loaded here so this cell does not depend on `imgs`/`labels_raw`, which are
 # otherwise only defined by a later cell (running top-to-bottom used to fail here).
 train_split = 0.9
 dataset = MNIST(; split=:train)
 imgs = dataset.features
 labels_raw = dataset.targets
 # Process images into (H, W, C, BS) batches: insert a singleton channel dimension.
 x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))
 y_data = onehotbatch(labels_raw, 0:9)
 (x_train, y_train), (x_test, y_test) = splitobs((x_data, y_data); at=train_split)

In [None]:
# Inspect one of the prepared arrays (uncomment the one you want to display).
#x_train
#y_train
x_data # NOTE(review): original note is cut off mid-sentence ("Me interesa que" = "I'm interested in…")
#y_data

# TODO: what input does `Dense` expect?
# TODO: write a `my_leaky_relu` function

In [None]:
# Batch size shared by the three loaders below; all of them drop an incomplete last batch.
my_batchsize = 128
# Shuffled loader over the training split.
A1 = DataLoader(collect.((x_train, y_train)); batchsize = my_batchsize, shuffle=true, partial=false)
# Unshuffled loader over the test split.
A2 = DataLoader(collect.((x_test, y_test)); batchsize = my_batchsize, shuffle=false, partial=false)
# A second shuffled loader over the same training data as A1, built for comparison with A1.
A3 = DataLoader(collect.((x_train, y_train)); batchsize = my_batchsize, shuffle=true, partial=false)

In [None]:
# Compares the loader objects themselves, not the batches they yield.
# NOTE(review): unless MLUtils defines `==` for `DataLoader`, this falls back to `===`,
# which compares the wrapped arrays by identity — expected `false` here; confirm.
A1 == A2

In [None]:
# Both loaders wrap the same training data with shuffle=true, but each wraps its own
# freshly `collect`ed arrays. To compare what they yield, compare batches such as
# `first(A1)` and `first(A3)` instead.
# NOTE(review): without a `DataLoader`-specific `==`, this is an identity-based `===` check — confirm.
A1 == A3

In [None]:
# Load the MNIST training split and keep its raw images and digit labels as globals.
dataset = MNIST(; split=:train)
imgs, labels_raw = dataset.features, dataset.targets
labels_raw  # last expression, so the cell still displays the labels

In [None]:
# Display the raw image array (samples along the last dimension, as indexed in `loadmnist`).
imgs

In [None]:
# Display the raw digit labels.
dataset.targets

In [None]:
# One-hot encode the labels over the classes 0:9 (the same encoding `loadmnist` uses for `y_data`).
onehotbatch(dataset.targets, 0:9)

In [None]:
# LeNet-like CNN: two (5×5 conv + 2×2 max-pool) stages, flatten, then a 3-layer dense
# head producing 10 logits. For a 28×28×1 input the spatial size goes 28→24→12→8→4,
# so the flattened feature vector has 16 × 4 × 4 = 256 entries.
lux_model = let
    feature_layers = (
        Conv((5, 5), 1 => 6, relu), MaxPool((2, 2)),
        Conv((5, 5), 6 => 16, relu), MaxPool((2, 2)),
    )
    dense_head = Chain(Dense(256 => 128, relu), Dense(128 => 84, relu), Dense(84 => 10))
    Chain(feature_layers..., FlattenLayer(3), dense_head)
end

In [None]:
# Convert the Lux model into a SimpleChains-backed model; the adaptor is given the
# per-sample input size (28, 28, 1).
adaptor = ToSimpleChainsAdaptor((28, 28, 1))
simple_chains_model = adaptor(lux_model)
     

In [None]:
# `const` tells Julia this global binding never changes, so code that uses it can be
# specialised on its type: https://docs.julialang.org/en/v1/base/base/#const
const lossfn = CrossEntropyLoss(; logits=Val(true))

"""
    accuracy(model, ps, st, dataloader)

Return the fraction of samples in `dataloader` whose predicted class (`onecold` of the
model output) matches the target class (`onecold` of the one-hot labels). The model runs
with `Lux.testmode(st)`. Returns `NaN` (0/0) if `dataloader` yields no samples.
"""
function accuracy(model, ps, st, dataloader)
    test_st = Lux.testmode(st)
    n_correct, n_seen = 0, 0
    for batch in dataloader
        x, y = batch
        truth = onecold(Array(y))
        logits = first(model(x, ps, test_st))
        guess = onecold(Array(logits))
        n_correct += count(truth .== guess)
        n_seen += length(truth)
    end
    return n_correct / n_seen
end
    

### Función para entrenar

In [None]:
# List the field names of the `Adam` optimiser rule type.
fieldnames(Adam)

In [None]:
"""
    train(model, dev=cpu_device(); rng=Random.default_rng(), nepochs=10, batchsize=128,
          train_split=0.9, learning_rate=3.0f-4, kwargs...)

Train `model` on MNIST (loaded with `loadmnist`) using Adam and the global `lossfn`.
After each epoch, print the training and test accuracy.

# Arguments
- `model`: Lux model (or SimpleChains-adapted model) whose output is fed to `lossfn`.
- `dev`: device that data, parameters and states are moved to. With a `ReactantDevice`,
  the forward pass used for evaluation is compiled and gradients use Enzyme; otherwise
  gradients use Zygote.

# Keywords
- `rng`: RNG used by `Lux.setup` to initialise parameters and states.
- `nepochs`: number of full passes over the training data.
- `batchsize`, `train_split`: forwarded to `loadmnist`.
- `learning_rate`: Adam learning rate.
- Any other keyword is accepted and ignored (kept for backward compatibility).

# Returns
`(parameters, states, train_accuracy, test_accuracy)`. The accuracies are percentages
after the last epoch (`0.0` if `nepochs == 0`).
"""
function train(
    model, dev=cpu_device(); rng=Random.default_rng(), nepochs=10, batchsize=128,
    train_split=0.9, learning_rate=3.0f-4, kwargs...
)
    train_dataloader, test_dataloader = dev(loadmnist(batchsize, train_split))
    # Randomly initialise the model's parameters and states, then move them to `dev`.
    ps, st = dev(Lux.setup(rng, model))

    # With Reactant the model is compiled before training, and gradients come from Enzyme:
    # https://lux.csail.mit.edu/stable/manual/compiling_lux_models#reactant-compilation
    vjp = dev isa ReactantDevice ? AutoEnzyme() : AutoZygote()

    train_state = Training.TrainState(model, ps, st, Adam(learning_rate))

    if dev isa ReactantDevice
        # Compile the test-mode forward pass once, using one test batch to fix the input
        # shape (every batch has the same shape because `loadmnist` uses partial=false).
        x_ra = first(test_dataloader)[1]
        model_compiled = @compile model(x_ra, ps, Lux.testmode(st))
    else
        model_compiled = model
    end

    tr_acc, te_acc = 0.0, 0.0
    for epoch in 1:nepochs  # one epoch = one full pass over the training data
        stime = time()
        for (x, y) in train_dataloader
            _, _, _, train_state = Training.single_train_step!(
                vjp, lossfn, (x, y), train_state
            )
        end
        ttime = time() - stime

        tr_acc =
            accuracy(
                model_compiled, train_state.parameters, train_state.states, train_dataloader
            ) * 100
        te_acc =
            accuracy(
                model_compiled, train_state.parameters, train_state.states, test_dataloader
            ) * 100

        @printf "[%2d/%2d] \t Time %.2fs \t Training Accuracy: %.2f%% \t Test Accuracy: \
                 %.2f%%\n" epoch nepochs ttime tr_acc te_acc
    end

    # Unlike the original Lux tutorial, also return the trained parameters and states.
    return train_state.parameters, train_state.states, tr_acc, te_acc
end

# Show the documentation of `Training.single_train_step!`.
@doc Training.single_train_step!


In [None]:
# Show the documentation of `Training.single_train_step!` (same as at the end of the `train` cell).
@doc Training.single_train_step!

In [None]:
ENV["DATADEPS_ALWAYS_ACCEPT"] = "true" # Needed on the first run to download the MNIST data
# Train the Lux model with Reactant (this one takes longer). `train` returns
# (parameters, states, train accuracy, test accuracy), so all four values are unpacked;
# unpacking only two would bind the parameters and states to `tr_acc`/`te_acc`.
lux_ps, lux_st, tr_acc, te_acc = train(lux_model, reactant_device());

In [None]:

ENV["DATADEPS_ALWAYS_ACCEPT"] = "true" # Needed on the first run to download the MNIST data
ps, st, tr_acc, te_acc = train(simple_chains_model); # Train the SimpleChains version of the model on `train`'s default device (CPU)

### Analicemos cómo se están formateando los datos

In [None]:
# Reload the MNIST training split to inspect how it is stored.
dataset = MNIST(; split=:train)

In [None]:
# Rebuild the train/test loaders (batch size 128, 90%/10% split) and move them to the CPU.
dev=cpu_device();
train_dataloader, test_dataloader = dev(loadmnist(128, 0.9))

In [None]:
# Same as `fieldnames(Adam)`, written with the pipe operator.
Adam|>fieldnames

In [None]:
# Features of the first training batch. `first` pulls a single batch, whereas `collect`
# would materialise every batch of the epoch just to index the first one.
first(train_dataloader)[1]

In [None]:
# Rebuild the loaders (reshuffling the training data); requires `dev` from an earlier cell.
train_dataloader, test_dataloader = dev(loadmnist(128, 0.9))

In [None]:
# One-hot label of the 10th sample in the first test batch (element [2] of a batch is `y`).
first(test_dataloader)[2][:,10]

In [None]:
# Recreate the CPU device and the train/test loaders before making predictions.
dev = cpu_device()
train_dataloader, test_dataloader = dev(loadmnist(128, 0.9))

### Hagamos predicciones con el modelo ya entrenado 

In [23]:
using Plots
using Plots: plot, heatmap

In [None]:
# Get the first batch from the test dataloader
x_batch = first(test_dataloader)[1]

# Extract the first sample and reshape it to 28×28×1×1
single_sample = x_batch[:, :, :, 25]

println("Shape of single sample: ", size(single_sample))

heatmap(single_sample[:, :, 1], c=:grays, title="MNIST Image", xlabel="Width", ylabel="Height")