# Project Motivation

Reinforcement learning agents are notoriously difficult to train, in part because of the commonly encountered problem of reduced learning rates on data previously used to train the agent. Several approaches to this problem are being explored, one of which is episodic recall in the fashion of [Ritter et al. (2018)](https://arxiv.org/pdf/1805.09692.pdf). This method uses a stable external memory, usually implemented as a differentiable neural dictionary (DND) of key-value pairs that link previously encountered environment states to the network's cell states at the time those environment states were encountered. This allows temporally stable recall of contexts, which is useful when the agent encounters environment states that are similar but not identical to those stored in its DND. The long short-term memory (LSTM) architecture is then augmented with an additional input gate for the cell state recalled from the DND and an output path that stores the current context in it.
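
As a sketch of how the recalled state enters the cell, the update implemented further down in this notebook can be written as follows. The symbol names are chosen here for illustration: $f_t$, $i_t$, $r_t$ and $o_t$ are the forget, input, reinstatement and output gates, $\tilde{c}_t$ is the candidate cell state, and $c^{\mathrm{in}}_t$ is the cell state recalled from the DND.

$$c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t + r_t \odot \tanh\left(c^{\mathrm{in}}_t\right), \qquad h_t = o_t \odot \tanh(c_t)$$

Setting $r_t = 0$ recovers the standard LSTM update, so the reinstatement gate controls how strongly the recalled context is mixed into the current cell state.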

While episodic recall itself is a promising route, there are still technical details to settle to ensure it actually reduces both the amount of data needed for training and the training time. One such issue is the way key-value pairs are embedded in and recalled from the DND. Here I implement a DND based on the description in [Pritzel et al. (2017)](https://arxiv.org/pdf/1703.01988.pdf), including their kernel function. To recall cell states from the DND, there are various ways to measure how similar the environment states in memory are to the one currently encountered. This can be specified by a kernel function that weights the corresponding cell states, such as
$$k(h,h_{i}) = \frac{1}{||h-h_{i}||_{2}^{2} + \delta}$$

where $h$ is the current environment state, $h_{i}$ are the memorised environment states included in the recall, and $\delta$ is a distance tuning parameter that also keeps the kernel finite when $h = h_{i}$.
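
As a minimal sketch of how this kernel turns into a recalled cell state, the toy example below normalises the kernel values into weights and takes the weighted sum of the stored cell states. The helper names `inverse_distance_kernel` and `weighted_recall`, and the toy memory contents, are assumptions made for illustration and are not part of the implementation in this notebook.

```julia
# Inverse-distance kernel from the formula above; δ keeps it finite when h == hᵢ.
inverse_distance_kernel(h, hᵢ; δ=1f-3) = 1 / (sum(abs2, h .- hᵢ) + δ)

# Hypothetical helper: normalise the kernel values so the weights sum to one,
# then return the weighted sum of the stored cell states.
function weighted_recall(h, memory_keys, memory_values; δ=1f-3)
    k = [inverse_distance_kernel(h, hᵢ; δ=δ) for hᵢ in memory_keys]
    w = k ./ sum(k)
    recalled = sum(w .* memory_values)
    return recalled, w
end

# Toy memory: three stored environment states, each with one cell state.
memory_keys   = [Float32[0, 0], Float32[1, 0], Float32[3, 4]]
memory_values = [Float32[1, 0], Float32[0, 1], Float32[5, 5]]

recalled, w = weighted_recall(Float32[0.1, 0], memory_keys, memory_values)
```

Because the query is closest to the first stored state, that entry receives by far the largest weight and dominates the recalled value.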

Below is my implementation of the modified LSTM architecture and its DND. At the moment I am exploring alternatives to the p-nearest-neighbours variants used in [Pritzel et al. (2017)](https://arxiv.org/pdf/1703.01988.pdf) for extracting cell states from the DND.
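
For reference, a p-nearest-neighbours lookup can be sketched as below: compute the kernel for every stored state and keep the `p` entries with the largest kernel values. The helper name `p_nearest` and the toy keys are assumptions for illustration, and this is only a rough rendering of the idea, not the exact procedure of the paper.

```julia
# Hypothetical helper: indices of the p stored states most similar to h.
# Larger kernel values mean closer states, hence rev=true.
function p_nearest(h, memory_keys, p; δ=1f-3)
    k = [1 / (sum(abs2, h .- hᵢ) + δ) for hᵢ in memory_keys]
    p = min(p, length(k))                               # never ask for more than is stored
    idx = collect(partialsortperm(k, 1:p; rev=true))    # indices of the p largest kernel values
    return idx, k[idx]
end

memory_keys = [Float32[0, 0], Float32[1, 0], Float32[3, 4], Float32[0, 1]]
idx, k = p_nearest(Float32[0.1, 0], memory_keys, 2)
```

The returned kernel values can then be normalised and used as weights on the corresponding cell states.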

In [191]:
using LinearAlgebra, Flux, Statistics, Dates, TimeZones, PyCall, Plots
ast_stats = pyimport("astropy.stats");

# Data Importing and Cleaning

<font size="3.5">The data used to train the reinforcement learning agent below is minute-level gold futures data from COMEX GCZ2020, from September 29, 2020 at 16:43 to October 6, 2020 at 04:06.</font>

In [None]:
#
# The functions time_change, find_date, data_ret, and line_prediction are needed
# and used solely for importing and cleaning the financial data I am using.
#


function time_change(T,mins;date_time=false)
    if date_time === false
        date_time = DateTime(T,"dd-u-yyyy HH:MM") + Minute(mins)
        time = Dates.format(date_time, "dd-u-yyyy HH:MM")
    
        return date_time, time
    else
        date_time = T + Minute(mins)
        time = Dates.format(date_time, "dd-u-yyyy HH:MM")
        
        return date_time, time
    end
end



function find_date(elements, value; days=false)
    
    
    left = 1
    right = length(elements)

    first = 1
    last = length(elements)
        
    while left <= right
        middle = (left + right) ÷ 2
        middle_element = elements[middle][5]
        
        if middle != first && middle != last
            left_element = elements[middle-1][5]
            right_element = elements[middle+1][5]

            if middle_element == value
                return middle_element,middle
            end
            if left_element == value
                return left_element,middle-1
            end
            if right_element == value
                return right_element,middle+1
            end
            if middle - left <= 1 && right - middle <= 1
                times_under_consideration = Dict()
                datetime_triple = [(left_element,left_element,middle - 1),
                                   (middle_element,middle_element,middle),
                                   (right_element,right_element,middle + 1)]
                for c in 1:length(datetime_triple)
                    if value > datetime_triple[c][1]
                        times_under_consideration[value - datetime_triple[c][1]] = (datetime_triple[c][2],datetime_triple[c][3])
                    end
                end
                if length(times_under_consideration) > 0
                    to_be_returned = times_under_consideration[minimum(keys(times_under_consideration))] 
                    return to_be_returned
                end
            end
        elseif middle == first || middle == last
            if value <= middle_element
                return middle_element, middle
            end
            if middle_element <= value
                return middle_element, middle
            end
        end
                
        if middle_element < value
            left = middle + 1
        elseif middle_element > value
            right = middle - 1
        end
    end
end



function data_ret(T_date;data_location=data_array,A=0,price_grab=false)
    
    Keys = []
    Data = []
    
    @views begin
    if price_grab === true
        time, index = find_date(data_location[1],T_date)
        push!(Keys,(1,index))
        push!(Data,data_location[1][index])
        return Data[1][9]
    end
    
    for c in 1:length(data_location)
        if c !== 9
            time, index = find_date(data_location[c],T_date)
            push!(Keys,(c,index))
            push!(Data,data_location[c][index])
        else
            time, index = find_date(data_location[c],T_date;days=true)
            push!(Keys,(c,index))
            push!(Data,data_location[c][index])
        end
    end
    
    dP_dt = [Data[c][2] for c in 1:length(Data)]
    dW_dt = [Data[c][1] for c in 1:length(Data)]
    d2P_dt2 = []
    d2W_dt2 = []
    end
        
    if A !== 0
        A += 1
        y_list = [(data_location[1][Keys[1][2]-c][7]+data_location[1][Keys[1][2]-c][8])/2 for c in 1:A]
        return dP_dt, dW_dt, d2P_dt2, d2W_dt2, Keys, Data, y_list
    end

    return dP_dt, dW_dt, d2P_dt2, d2W_dt2, Keys, Data
end



function line_prediction(t,dP_dt,dW_dt;d2P_dt2=0,d2W_dt2=0,gamma=1,kappa=0,Price=Price)
    
    predictions_difference = []
    for x in eachindex(dP_dt)
        push!(predictions_difference,((dP_dt[x]*(0-(t/periods[x]))+((0-(t/periods[x]))*(dW_dt[x]-dP_dt[x])*gamma) + Price)))
    end
    
        
    total = []
    
    for c in eachindex(predictions_difference)
        if abs(t-periods[c]) !== 0
            push!(total,(predictions_difference[c]/abs(t-periods[c])))
        else
            push!(total,predictions_difference[c])
        end
    end
    
    average = mean(total)
    
    for c in eachindex(total)
        total[c] = (total[c]/average)*predictions_difference[c]
    end

    prediction = mean(total)
    
    return prediction
end


periods = [1,2,5,10,30,60,120,240,1440]


#
# The function load_files loads the data files I am using, cleans the data, and puts it in an appropriate
# format for use in training.
#

function load_files(s,v)

    data_list = [[], [], [], [], [], [], [], [], []]

    for aa in 1:9
        storage_file = s*"indicator_data_"*string(aa)*".csv"
        
        io = open(storage_file, "r")
        lines = readlines(io)
        split_lines = [split(lines[x],",") for x in 1:length(lines)]

        current_data_times = Set([split_lines[x][5] for x in 1:length(split_lines)])

        if aa != 9
            transform = Dict(1=>x->parse(Float64,x),2=>x->parse(Float64,x),3=>x->parse(Float64,x),4=>x->parse(Int8,x),5=>x->DateTime(x,"dd-u-yyyy HH:MM"),6=>x->parse(Float64,x),7=>x->parse(Float64,x),
                            8=>x->parse(Float64,x),9=>x->parse(Float64,x),10=>x->x,11=>x->x)
        else
            transform = Dict(1=>x->parse(Float64,x),2=>x->parse(Float64,x),3=>x->parse(Float64,x),4=>x->parse(Int8,x),5=>x->DateTime(x,"dd U yyyy"),6=>x->parse(Float64,x),7=>x->parse(Float64,x),
                            8=>x->parse(Float64,x),9=>x->parse(Float64,x),10=>x->x,11=>x->x)
        end

        @views begin
        current_data = [[transform[x](split_lines[c][x]) for x in 1:11] for c in 1:length(split_lines)]
        current_data = unique(current_data)
        end
        close(io)
        #end
        println(length(current_data))

        storage_file = s*"indicator_storage_aggregate_"*string(aa)*".csv"
        io = open(storage_file, "r")
        lines = readlines(io)
        split_lines = [split(lines[x],",") for x in 1:length(lines)]


        @views begin
        aggregate_data = [[transform[x](split_lines[c][x]) for x in 1:11] for c in 1:length(split_lines) if split_lines[c][5] ∉ current_data_times]
        end
        append!(current_data,aggregate_data)
        close(io)

        println(length(current_data))

        data_list[aa] = current_data
    end

    
    volume_file = v

    io = open(volume_file, "r")
    lines = readlines(io)
    split_lines = [split(lines[x],",") for x in 2:length(lines)]
    times = [split_lines[x][1] for x in 1:length(split_lines)]
    volumes = []
    
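    try
        volumes = [parse(Int, split_lines[x][6]) for x in 1:length(split_lines)]
    catch err
        # `catch ArgumentError` only names the caught exception, so filter on the type explicitly
        err isa ArgumentError || rethrow()
        volumes = [convert(Int,parse(Float64, split_lines[x][6])) for x in 1:length(split_lines)]
    end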
    close(io)

    try
        DateTime(times[1],"dd-u-yyyy HH:MM")
    catch
        times = [Dates.format(astimezone(ZonedDateTime(DateTime(replace(times[x],"Z"=>""), "y-m-dTH:M:S"), tz"UTC"), tz"America/Toronto"),"dd-u-yyyy HH:MM") for x in 1:length(times)]
    end

    times = reverse(times)
    volumes = reverse(volumes)
    times_as_datetime = [DateTime(times[x],"dd-u-yyyy HH:MM") for x in 1:length(times)]

    data_list[1] = [data_list[1][c] for c in 1:length(data_list[1]) if data_list[1][c][5] ∈ times_as_datetime]

    @views begin
    for c in 1:length(data_list)
        if c != 9 && length(data_list[c]) > 1
            data_list[c] = sort(data_list[c],by= x -> x[5])
        elseif length(data_list[c]) > 1
            data_list[c] = sort(data_list[c],by= x -> x[5])
        end
    end
    end



    data_datetimes = [data_list[1][c][5] for c in 1:length(data_list[1])]

    vols = Dict()
    data_strings = [Dates.format(data_list[1][c][5], "dd-u-yyyy HH:MM") for c in 1:length(data_list[1])]
    data_set = Set(data_strings)

    @views begin
    for c in 1:length(volumes)
        if times[c] in data_set
            vols[times[c]] = volumes[c]
        end
    end
    end


    println("data_set",length(data_set))

    println(data_list[1][1][5],data_list[1][length(data_list[1])][5])

    vols_tuple = [(times[c],volumes[c]) for c in 1:length(volumes) if times[c] ∈ data_set]
    println("vols: ",length(vols))

    times_as_datetime = nothing
    data_set = nothing
    times = nothing
    volumes = nothing

    looking = 50
    d_d = 25

    All_Prices = Dict()
    All_Predictions = Dict()

    pred_deviation = Dict()
    Price_deviation = Dict()
    difference_deviation = Dict()

    gamma = -15

    println("vols: ",length(vols),"data_list[1]",length(data_list[1]))

    ordered_vols_values = similar(data_list[1],Int64)
    ordered_vols_keys = similar(data_list[1],String)

    for i in 1:length(data_list[1])
        ordered_vols_keys[i] = vols_tuple[i][1]
        ordered_vols_values[i] = vols_tuple[i][2]
    end

    ordered_price_values = similar(data_list[1],Float64)
    ordered_price_keys = similar(data_list[1],String)
    ordered_price_keys_datetime = similar(data_list[1],DateTime)

    ordered_prediction_values = similar(data_list[1],Float64)
    ordered_prediction_keys = similar(data_list[1],String)
    ordered_prediction_keys_datetime = similar(data_list[1],DateTime)

    for i in 1:length(data_list[1])
       @views begin
        T = data_strings[i]
        T_datetime = data_datetimes[i]
        data_T = data_ret(T_datetime;data_location=data_list)
        All_Predictions_time_change = time_change(T_datetime,d_d,date_time=true)[2]

        All_Prices[T] = data_T[6][1][9]
        dW_dt_d = data_T[1]
        dP_dt_d = data_T[2]

        All_Predictions[All_Predictions_time_change] = line_prediction(d_d,dP_dt_d,dW_dt_d;Price=All_Prices[T],gamma=gamma)

        ordered_price_values[i] = data_T[6][1][9]
        ordered_price_keys[i] = T
        ordered_price_keys_datetime[i] = T_datetime

        ordered_prediction_values[i] = All_Predictions[All_Predictions_time_change]
        ordered_prediction_keys[i] = All_Predictions_time_change
        ordered_prediction_keys_datetime[i] = time_change(T_datetime,d_d,date_time=true)[1]
        end
    end

    @views begin
    for c in 1:length(data_list)
        if c === 1 && length(data_list[c]) > 1
            for d in 1:length(data_list[c])
                data_list[c][d][5] = Dates.format(data_list[c][d][5], "dd-u-yyyy HH:MM")
            end
        elseif c !== 9 && length(data_list[c]) > 1
            for d in 1:length(data_list[c])
                data_list[c][d][5] = Dates.format(data_list[c][d][5], "dd-u-yyyy HH:MM")
            end
        elseif length(data_list[c]) > 1
            for d in 1:length(data_list[c])
                data_list[c][d][5] = Dates.format(data_list[c][d][5], "dd U yyyy")
            end
        end
    end
    end

    return ordered_vols_values, ordered_vols_keys, ordered_price_values, ordered_price_keys, ordered_prediction_values, ordered_prediction_keys, data_list, ordered_prediction_values
end



loaded_data = load_files("Price_data","Volume_data");



times = Float32[]
days_of_week = Float32[]
for c=1:length(loaded_data[2])
    time = Dates.Time((DateTime(loaded_data[2][c],"dd-u-yyyy HH:MM")))
    time_string = Dates.format(time, "HH:MM")
    week_day = Dates.dayofweek((DateTime(loaded_data[2][c],"dd-u-yyyy HH:MM")))
    push!(days_of_week,week_day)
    
    push!(times,parse(Int32,time_string[1:2])*60 + parse(Int32,time_string[4:5]))
end



#
# williams_ad is an implementation of the Williams accumulation/distribution technical indicator
#

function williams_ad(data)
    WAD = zeros(Float32,length(data))
    WAD[1] = data[1][9]
    
    for c=1:length(data)
        if c > 1
            prev_value = WAD[c-1]
            prev_close = data[c-1][9]
            if data[c][9] > prev_close
                ad = data[c][9] - minimum([prev_close, data[c][8]])
            elseif data[c][9] < prev_close
                ad = data[c][9] - maximum([prev_close, data[c][7]])
            else
                ad = 0.
            end
            
            WAD[c] = ad + prev_value
        end
    end
    
    return WAD
end



#
# Compiling the data matrix together from the data
#


n = length(loaded_data[1])
X = zeros(Float32,17,n)
X[1,:] = loaded_data[3][1:n]
X[2,:] = loaded_data[1][1:n]
X[3,:] = [loaded_data[7][1][1:n][c][1] for c =1:length(loaded_data[7][1][1:n])]
X[4,:] = [loaded_data[7][1][1:n][c][2] for c =1:length(loaded_data[7][1][1:n])]
X[5,:] = [loaded_data[7][1][1:n][c][6] for c =1:length(loaded_data[7][1][1:n])]
X[6,:] = [loaded_data[7][1][1:n][c][7] for c =1:length(loaded_data[7][1][1:n])]
X[7,:] = [loaded_data[7][1][1:n][c][8] for c =1:length(loaded_data[7][1][1:n])]
X[8,:] = williams_ad(loaded_data[7][1][1:n])
difference_WAD = diff(williams_ad(loaded_data[7][1][1:n]))
difference_WAD = vcat(Float32[0.f0], difference_WAD)
X[9,:] = difference_WAD
difference_open = diff([loaded_data[7][1][1:n][c][6] for c =1:length(loaded_data[7][1][1:n])])
difference_open = vcat(Float32[0.f0], difference_open)
X[10,:] = difference_open
difference_high = diff([loaded_data[7][1][1:n][c][7] for c =1:length(loaded_data[7][1][1:n])])
difference_high = vcat(Float32[0.f0], difference_high)
X[11,:] = difference_high
difference_low = diff([loaded_data[7][1][1:n][c][8] for c =1:length(loaded_data[7][1][1:n])])
difference_low = vcat(Float32[0.f0], difference_low)
X[12,:] = difference_low
difference_close = diff(loaded_data[3][1:n])
difference_close = vcat(Float32[0.f0], difference_close)
X[13,:] = difference_close
difference_volume = diff(loaded_data[1][1:n])
difference_volume = vcat(Float32[0.f0], difference_volume)
X[14,:] = difference_volume
X[15,:] = loaded_data[8][1:n];
X[16,:] = times[1:n]
X[17,:] = days_of_week[1:n];

difference_WAD = nothing
difference_volume = nothing
difference_open = nothing
difference_high = nothing
difference_low = nothing
difference_close = nothing


prices = loaded_data[3][1:n]
sde_data = X
scaling_range = 1:85000


#
# MV_standardize is an implementation of multivariate standardization that uses the biweight
# location and biweight midcovariance matrix instead of the mean and covariance matrix respectively.
# I chose these because, in my experience with financial data, they are less sensitive to outliers
# and so give a more informative standardization than the mean and covariance.
#


function MV_standardize(sde_data,scaling_range)
    SL = transpose(sde_data)

    
    SL_T = transpose(SL)
    SL_S = similar(SL_T)
    μ = [convert(Float32,ast_stats.biweight_location(SL[scaling_range,c])) for c in 1:size(sde_data)[1]]
    Σ = ast_stats.biweight_midcovariance(SL_T[:,scaling_range])
    Σ = [convert(Float32,Σ[c,d]) for c=1:size(Σ)[1], d=1:size(Σ)[2]]

    Σ = ((Σ)^(-1))^(1/2)


    for c in 1:size(sde_data)[2]
        SL_S[:,c] = Σ*(SL_T[:,c] - μ)
    end
    
    return SL_S
end



sde_data = MV_standardize(sde_data,scaling_range)
SDE = deepcopy(sde_data)


_begin_ = 1
_end_ = 12000
delay = 30

test_length = 2000


y1 = [X[1,c] for c=_begin_+delay:_end_+delay] .- X[1,_begin_:_end_]
Mii = -20#minimum(test_std)
Mai = 20#maximum(test_std)

y1 = [c/(Mai - Mii) for c in y1]#[(c - Mii)/(Mai - Mii) for c in y1]
y1 = [Float32[y1[c]] for c=1:length(y1)];

y1 = reshape([c[1] for c in y1],1,length(y1))


seq = [[sde_data[d,c] for d=1:size(sde_data)[1]] for c=_begin_:_end_];



# Implementation of the episodic recall LSTM architecture

<font size="3.5">Here I create the struct and functions that make the episodic recall LSTM compatible with Julia's Flux neural network library for ease of training. I also create the DND and the functions that track its status and add, remove, and retrieve cell states. Special care is taken to ensure the DND is in fact differentiable, so that it can be trained with gradient descent algorithms that use Flux's exact gradients via automatic differentiation.</font>
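
Before the full implementation, the store and retrieve operations of a DND can be sketched in a few lines. Everything in this block (the `ToyDND` type, `store!`, `retrieve`, the fixed capacity, and the drop-oldest eviction policy) is a hypothetical illustration and not the code used in this notebook, which keeps its memory in global dictionaries instead.

```julia
# Hypothetical minimal memory: parallel vectors of keys (environment states)
# and values (cell states), with a fixed capacity.
struct ToyDND
    keys::Vector{Vector{Float32}}
    values::Vector{Vector{Float32}}
    capacity::Int
end
ToyDND(capacity::Integer) = ToyDND(Vector{Float32}[], Vector{Float32}[], capacity)

# Add a key-value pair; once full, drop the oldest entry (an assumed policy).
function store!(m::ToyDND, key, value)
    if length(m.keys) >= m.capacity
        popfirst!(m.keys)
        popfirst!(m.values)
    end
    push!(m.keys, key)
    push!(m.values, value)
    return m
end

# Kernel-weighted read over every stored entry; zeros when the memory is empty.
function retrieve(m::ToyDND, h, value_size; δ=1f-3)
    isempty(m.keys) && return zeros(Float32, value_size)
    k = [1 / (sum(abs2, h .- hᵢ) + δ) for hᵢ in m.keys]
    return sum((k ./ sum(k)) .* m.values)
end

dnd = ToyDND(2)
store!(dnd, Float32[0, 0], Float32[1, 1])
store!(dnd, Float32[1, 0], Float32[2, 2])
store!(dnd, Float32[5, 5], Float32[9, 9])   # memory is full, so the first pair is evicted
recalled = retrieve(dnd, Float32[1, 0], 2)
```

The read itself is built only from array arithmetic, which is the kind of operation automatic differentiation can follow back to the query and to the stored values.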

In [None]:
# Defining a base LSTM cell struct which will hold base LSTM cell info
struct LSTMMVCell{A,V,S}
    Wi::A
    Wh::A
    b::V
    state0::S
end

# Defining a constructor function that instantiates a layer of the episodic recall LSTM
function LSTMMVCell(in::Integer, out::Integer;init = Flux.glorot_uniform)
    cell = LSTMMVCell(init(out * 5, in), init(out * 5, out), zeros(Float32,out * 5), (zeros(Float32,out,1), zeros(Float32,out,1)))
    cell.b[Flux.gate(out, 2)] .= 1
    return cell
end

# Variables and dictionaries that are used for the DND
# c_dict is the DND while the rest are used to keep track of which layers of the network are storing cell
# states in the DND
num_cells = 1
c_dict = Dict(a => Dict{Array,Array}() for a=1:num_cells)
cell_memory_tracker = Dict{Int,Int}(a => 0 for a=1:num_cells)
cell_track = 0

# Kernel function used to measure similarity to other environment states
function kernel(h::Array{Float32,1},hᵢ::Dict{Int64,Array{Float32,1}};δ=0.001f0)::Tuple{Array{Float32,1},Dict{Float32,Array}}
    HH = [hᵢ[d] for d=1:length(hᵢ)]
    distances = 1.f0./(map(r::Array -> squaresum(h - r),HH).+ δ)
    Di = Dict{Float32,Array}(zip(distances,HH))
    return distances,Di
end
squaresum(x::Array)::Float32 = sum(x.^2)


# Precalculated nearest 50 environment states under the kernel function used in place of p-nearest neighbours
# at the moment, currently working on other methods of retrieval
c_history = Dict()
for c=1:7000
    distances,Di = kernel(seq[c],Dict(d => seq[d] for d=1:length(seq[1:c])))
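    # The kernel grows as states get closer, so the nearest states are those with the largest values
    sort_ = length(distances) > 50 ? partialsort(distances,1:50;rev=true) : sort(distances;rev=true)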
    c_history[c] = (sort_,[Di[sort_[d]] for d=1:length(sort_)])
end
c_history = [c_history[c] for c=1:length(c_history)];


# Function used to retrieve the cell states from the DND
function get_c(x,c_size)
    global c_dict, cell_track, c_history
    cell_track::Int64 += 1
    current_cell = (cell_track%num_cells + 1)::Int64
    if cell_track::Int64 > num_cells::Int64
        memory_retrieval = c_history[cell_track-1]
        
        nearest_cell_states = [c_dict[current_cell][memory_retrieval[2][i]] for i=1:length(memory_retrieval[2])]

        return sum((memory_retrieval[1].*nearest_cell_states)./sum(memory_retrieval[1]))
            
    else
        return zeros(Float32,c_size)
    end
end

# Function used to store the cell states in the DND
function store_c(c,x)
    global c_dict, c_keys, cell_memory_tracker
    current_cell = cell_track%num_cells + 1
    c_dict[current_cell][x] = c
    
    cell_memory_tracker[current_cell] += 1
end

# Function used on an episodic recall LSTM layer which provides the architecture
function (m::LSTMMVCell{A,V,<:NTuple{2,AbstractMatrix{T}}})((h, c), x::Union{AbstractVecOrMat{T},Flux.OneHotArray}) where {A,V,T}
    b, o = m.b, size(h, 1)
    g = m.Wi*x .+ m.Wh*h .+ b

    input = σ.(Flux.gate(g, o, 1))
    forget = σ.(Flux.gate(g, o, 2))
    cell = tanh.(Flux.gate(g, o, 3))

    # In addition to regular LSTMs, the following reinstatement gate is used to introduce the recalled
    # cell states
    reinstatement = σ.(Flux.gate(g, o, 4))
    
    output = σ.(Flux.gate(g, o, 5))
    

    # Recalling the cell states
    c_in = get_c(x,size(c))
    

    c = forget .* c .+ input .* cell .+ reinstatement .* tanh.(c_in)
    # Storing the cell states
    store_c(c,x)
    
    h′ = output .* tanh.(c)
    sz = size(x)
    return (h′, c), reshape(h′, :, sz[2:end]...)
end


# Completing the setup of the episodic recall LSTM struct and functions to be compatible with Flux
@Flux.functor LSTMMVCell

LSTMMV(a...; ka...) = Recur(LSTMMVCell(a...; ka...))
Recur(m::LSTMMVCell) = Flux.Recur(m, m.state0)

# A modified reset function that resets the LSTM's recurrent state and, by default, also clears the DND
function MV_reset!(L;mem_reset=true)
    if mem_reset
        global c_dict, cell_track, c_keys, cell_memory_tracker
        c_dict = Dict(1 => Dict{Array,Array}())
        cell_memory_tracker = Dict{Int,Int}(1 => 0)
        cell_track = 0
    end
    Flux.reset!(L)
end

# Regular and Episodic LSTM comparison

<font size="3.5">Having implemented the episodic LSTM, I made an example comparison between it and a regular LSTM. Both were trained from the same initial random weights (the regular LSTM has no reinstatement weights, so those were omitted) and for the same number of iterations over the same small dataset of 850 minutes of minute gold futures data from COMEX GCZ2020. I compare the results of training the networks to predict the change in the instrument's close price 30 minutes into the future.

The networks were trained on data from July 27, 2020 at 06:21 to July 27, 2020 at 21:35. They were then tested on 1000 minutes of data taken slightly later, from July 28, 2020 at 00:05 to July 28, 2020 at 16:46. The figures show the results of this test.

Note that both networks are shallow at 2 layers and do not have many neurons. They are nevertheless sufficient to display the differences in capability between regular and episodic LSTMs that I have also observed in larger networks trained and tested on the same data.
</font>

In [None]:
# Small 2 layer LSTM with initial layer having episodic memory
episodic_memory_network = Chain(LSTMMV(17,170),LSTM(170,1))

# Regular 2 layer LSTM
standard_network = Chain(LSTM(17,170),LSTM(170,1));

![image](LSTM_gif.gif)

![image](LSTM_plot.png)

<font size="3.5">It can be seen that the episodic LSTM produced much tighter predictions with lower variation after being trained on only 850 minutes of data.</font>