In [None]:
# Add dependencies
using Pkg
Pkg.add(["CSV", "Optimisers", "Lux", "LuxCUDA", "DataFrames", "Zygote", "ComponentArrays", "Plots"])

In [1]:
# Load dependencies
using CSV, Lux, LuxCUDA, Optimisers, Random, DataFrames, Zygote, ComponentArrays

In [2]:
# Define enums
# Target label of a sample; integer codes are spelled out explicitly.
@enum PoisonClass begin
    Poisonous = 0
    Edible = 1
end

# Cap shape categories with their explicit integer codes.
@enum CapShape begin
    Bell = 0
    Conical = 1
    Convex = 2
    Flat = 3
    Knobbed = 4
    Sunken = 5
end

# Cap surface categories with their explicit integer codes.
@enum CapSurface begin
    Fibrous = 0
    Grooves = 1
    Scaly = 2
    Smooth = 3
end

# Bruising flag with explicit integer codes.
@enum Bruised begin
    Bruises = 0
    NoBruises = 1
end

# Odor categories with their explicit integer codes.
@enum Odor begin
    Almond = 0
    Anise = 1
    Creosote = 2
    Fishy = 3
    Foul = 4
    Musty = 5
    None = 6
    Pungent = 7
    Spicy = 8
end

# Gill attachment categories with their explicit integer codes.
@enum GillAttachment begin
    Attached = 0
    Descending = 1
    Free = 2
    Notched = 3
end

# Gill spacing categories with their explicit integer codes.
@enum GillSpacing begin
    Close = 0
    Crowded = 1
    Distant = 2
end

# Gill size categories with their explicit integer codes.
@enum GillSize begin
    Broad = 0
    Narrow = 1
end

# Colour categories with their explicit integer codes.
@enum MColors begin
    Black = 0
    Brown = 1
    Buff = 2
    Chocolate = 3
    Cinnamon = 4
    Gray = 5
    Green = 6
    Orange = 7
    Pink = 8
    Purple = 9
    Red = 10
    White = 11
    Yellow = 12
end

# Stalk shape categories with their explicit integer codes.
@enum StalkShape begin
    Enlarging = 0
    Tapering = 1
end

# Stalk root categories with their explicit integer codes.
@enum StalkRoot begin
    Bulbous = 0
    Club = 1
    Cup = 2
    Equal = 3
    Rooted = 4
end

In [None]:
# Read the raw CSV from the working directory and materialise it as a DataFrame
data = CSV.File("mushrooms.csv") |> DataFrame

"""
    @Enumify data column label₁ code₁ label₂ code₂ ...

Replace column `column` of `data` in place (`data[!, column] = ...`): every
element equal (`==`) to `codeᵢ` becomes `labelᵢ`. Elements that match none of
the codes are passed through unchanged. If the same code is listed more than
once, the last listed pair wins.

`data`, `column`, the labels and the codes are all evaluated in the caller's
scope; `data` and `column` are evaluated exactly once.

# Throws
- `ArgumentError` (at macro-expansion time) when the trailing arguments do not
  form complete `label code` pairs.
"""
macro Enumify(data, column, args...)
    iseven(length(args)) || throw(ArgumentError(
        "@Enumify expects `label code` pairs, got $(length(args)) trailing argument(s)",
    ))

    # Fold the pairs into one nested conditional whose innermost fallback is the
    # element itself. `x` is deliberately left unescaped so that hygiene renames
    # it consistently with the lambda argument below.
    expr = :x
    for i in 1:2:length(args)
        label = esc(args[i])
        code = esc(args[i + 1])
        expr = Expr(:if, :(x == $code), label, expr)
    end

    return quote
        # Bind once so side-effecting `data` / `column` expressions run one time.
        let table = $(esc(data)), col = $(esc(column))
            table[!, col] = map(x -> $expr, table[!, col])
        end
    end
end

# Swap the single-letter codes of each listed column for the matching enum value
@Enumify(data, :class,
    Poisonous, "p", Edible, "e")
@Enumify(data, "cap-shape",
    Bell, "b", Conical, "c", Convex, "x", Flat, "f", Knobbed, "k", Sunken, "s")
@Enumify(data, "cap-surface",
    Fibrous, "f", Grooves, "g", Scaly, "y", Smooth, "s")
@Enumify(data, "cap-color",
    Brown, "n", Buff, "b", Cinnamon, "c", Gray, "g", Green, "r",
    Pink, "p", Purple, "u", Red, "e", White, "w", Yellow, "y")
@Enumify(data, :bruises,
    Bruises, "t", NoBruises, "f")
@Enumify(data, :odor,
    Almond, "a", Anise, "l", Creosote, "c", Fishy, "y", Foul, "f",
    Musty, "m", None, "n", Pungent, "p", Spicy, "s")
@Enumify(data, "gill-attachment",
    Attached, "a", Descending, "d", Free, "f", Notched, "n")
@Enumify(data, "gill-spacing",
    Close, "c", Crowded, "w", Distant, "d")

# Preview the relabelled table
first(data, 5)

In [None]:
"""
    split_df_percent(df, pct1, pct2; shuffled=true, rng=Random.default_rng())

Split the rows of `df` into three consecutive parts and return them as a tuple
`(df1, df2, df3)`.

With `n = size(df, 1)`, `df1` receives the first `floor(pct1 * n)` rows, `df2`
the next `floor(pct2 * n)` rows and `df3` every remaining row. `df` may be any
two-dimensional container supporting `df[rows, :]` indexing.

# Keywords
- `shuffled`: when `true` (default) the rows are randomly permuted before the
  split; when `false` the original row order is kept.
- `rng`: random number generator used for the permutation; pass a seeded
  generator to make the split reproducible.

# Throws
- `ArgumentError` if `pct1` or `pct2` lies outside `[0, 1]`, or if together they
  request more rows than `df` has.
"""
function split_df_percent(df, pct1, pct2; shuffled=true, rng=Random.default_rng())
    (0 <= pct1 <= 1 && 0 <= pct2 <= 1) || throw(ArgumentError(
        "pct1 and pct2 must lie in [0, 1], got $pct1 and $pct2",
    ))

    n = size(df, 1)
    idx1 = floor(Int, pct1 * n)
    idx2 = floor(Int, pct2 * n)
    # Checked on the row counts rather than on `pct1 + pct2` so that
    # floating-point round-off in the sum cannot reject a valid request.
    idx1 + idx2 <= n || throw(ArgumentError(
        "pct1 + pct2 must not exceed 1, got $(pct1 + pct2)",
    ))

    processed_df = shuffled ? df[shuffle(rng, axes(df, 1)), :] : df

    df1 = processed_df[1:idx1, :]
    df2 = processed_df[(idx1 + 1):(idx1 + idx2), :]
    df3 = processed_df[(idx1 + idx2 + 1):end, :]

    return df1, df2, df3
end

# First 60 % of the shuffled rows -> training, next 20 % -> testing,
# whatever is left -> validation
training, testing, validation = split_df_percent(data, 0.6, 0.2; shuffled=true)

# Preview the training split
first(training, 5)

In [None]:
# Logistic regression expressed as one dense layer with a sigmoid activation:
# one input per column of the training table except one, and a single output.
num_columns = size(training, 2)

model = Dense((num_columns - 1) => 1, sigmoid)

In [None]:
# Training configuration: optimiser, loss and AD backend
opt = Adam(0.03f0)
lossfn = MSELoss()
vjp_rule = AutoZygote()

# Device and random number generator
dev = gpu_device()
rng = Random.default_rng()

# Initialise parameters and states, then move both to the selected device
ps, st = dev(Lux.setup(rng, model))

In [None]:
# X and Y samples
# Draw them from the training split only, so the rows held out in `testing`
# and `validation` are never seen while fitting.
x_samples = training[:, 2:end]  # every column but the first: input features
y_samples = training[:, 1]      # first column: the label
# NOTE(review): these are still a DataFrame and a label vector. Presumably they
# must be converted to numeric arrays (features × samples) and moved to `dev`
# before they can be fed to the model — confirm against the Lux API.

first(x_samples, 5), y_samples

In [None]:
# Training time
"""
    train!(model, ps, st, opt, epochs;
           data=(x_samples, y_samples), backend=vjp_rule, objective=lossfn)

Build a `Training.TrainState` from `model`, `ps`, `st` and `opt`, run `epochs`
calls of `Training.single_train_step!` on `data`, and return the tuple
`(model, parameters, states)` stored in the final train state.

# Keywords
- `data`: the `(inputs, targets)` tuple handed to every training step.
- `backend`: automatic-differentiation backend passed to the training step.
- `objective`: loss function passed to the training step.

The keyword defaults are the notebook globals of the same meaning, so existing
five-argument calls keep working; passing them explicitly removes the hidden
dependency on global state.
"""
function train!(model, ps, st, opt, epochs;
                data=(x_samples, y_samples), backend=vjp_rule, objective=lossfn)
    tstate = Training.TrainState(model, ps, st, opt)
    for _ in 1:epochs
        # Gradients, loss value and statistics are not used here; only the
        # returned train state is carried into the next step.
        _, _, _, tstate = Training.single_train_step!(backend, objective, data, tstate)
    end
    return tstate.model, tstate.parameters, tstate.states
end

model, ps, st = train!(model, ps, st, opt, 100)