In [21]:
using Printf
using DataFrames
using CSV
using ProgressMeter
using JLD
using ThreadSafeDicts
using Flux
using MLJ
using Random
using LinearAlgebra
using MLJBase
using MLDataPattern
using Flux.Data: DataLoader
using Flux: onehotbatch, onecold, logitcrossentropy, throttle, @epochs, params
using Base.Iterators: repeated

"""
    bithdv(N::Int=10000)

Return a random `BitVector` of length `N`, generated with `bitrand`.
"""
function bithdv(N::Int=10000)
    return bitrand(N)
end

"""
    bitadd(vectors::BitVector...)

Bundle the given bit vectors by a position-wise majority vote.

For each position the number of set bits across `vectors` is compared with half
the number of vectors: more than half gives `1`, fewer gives `0`, and an exact
tie is resolved with `rand(0:1)`. The result is returned as a `BitVector`.
"""
function bitadd(vectors::BitVector...)
    counts = reduce(.+, vectors)
    half = length(vectors) / 2
    # Decision for a single position; only an exact tie consumes a random draw.
    function vote(c)
        if c > half
            return 1
        elseif c < half
            return 0
        else
            return rand(0:1)
        end
    end
    return convert(BitVector, [vote(c) for c in counts])
end

"""
    bitbind(vectors::BitVector...)

Combine the given bit vectors by folding them together with element-wise XOR.
"""
function bitbind(vectors::BitVector...)
    return reduce((acc, v) -> acc .⊻ v, vectors)
end

"""
    bitperm(vector::BitVector, k::Int=1)

Return `vector` circularly shifted by `k` positions (via `circshift`).
"""
function bitperm(vector::BitVector, k::Int=1)
    return circshift(vector, k)
end

"""
    hamming(x::BitVector, y::BitVector)

Return the fraction of positions at which `x` and `y` differ
(number of differing bits divided by `length(x)`).
"""
function hamming(x::BitVector, y::BitVector)
    mismatches = count(x .⊻ y)
    return mismatches / length(x)
end

"""
    scaler(row, lower, upper)

Linearly rescale the values of `row` so that its minimum maps to `lower` and
its maximum maps to `upper`, returning a new array.

If all entries of `row` are equal, the range is zero and the plain formula
would divide by zero; in that case every entry is mapped to `lower`.
"""
function scaler(row, lower, upper)
    minx, maxx = extrema(row)
    span = maxx - minx
    # With a zero range every numerator `(i - minx)` is zero as well, so
    # dividing by one yields `lower` for each entry instead of NaN.
    denom = iszero(span) ? oneunit(span) : span
    return [lower + ((i - minx) * (upper - lower)) / denom for i in row]
end

"""
    mat_scaler(matrix, lower, upper, dim = 1)

Rescale `matrix` to the range `lower`..`upper` with `scaler`, one slice at a time.

- `dim == 1`: each row is scaled independently.
- `dim == 2`: each column is scaled independently.

The result has the same layout as `matrix`. Any other `dim` throws an
`ArgumentError`.
"""
function mat_scaler(matrix, lower, upper, dim = 1)
    if dim == 1
        rows = [scaler(matrix[i, :], lower, upper) for i in axes(matrix, 1)]
        # The scaled rows are stacked as columns, so transpose back afterwards.
        return permutedims(reduce(hcat, rows))
    elseif dim == 2
        cols = [scaler(matrix[:, j], lower, upper) for j in axes(matrix, 2)]
        return reduce(hcat, cols)
    else
        throw(ArgumentError("dim must be 1 (rows) or 2 (columns), got $dim"))
    end
end

"""
    nested_arrays2mat(arrays, pd = false)

Concatenate the elements of `arrays` horizontally into one matrix.

When `pd` is `false` (the default) the concatenated matrix is returned as is;
otherwise its `permutedims` is returned.
"""
function nested_arrays2mat(arrays, pd = false)
    stacked = reduce(hcat, arrays)
    return pd == false ? stacked : permutedims(stacked)
end



nested_arrays2mat (generic function with 2 methods)

In [17]:
netsurf = CSV.read("../data/netsurf.csv", DataFrame) # read dataset
select!(netsurf, Not(" cb513_mask")) # remove column
# Remove white space from the string columns. `filter` walks the characters of
# each string, so it stays correct even if a value contains multi-byte
# characters (indexing with `1:length(s)` would not).
for col in ["input", " dssp3", " dssp8"]
    netsurf[!, col] = [filter(!isspace, s) for s in netsurf[!, col]]
end

seq_list = netsurf.input # all sequences
AA_set = mapreduce(Set, union, seq_list) # set of every character occurring in the sequences

"""
    loader()

Read the stored parts from `../data/k250.jld` and `../data/k251.jld` ..
`../data/k253.jld` and return them concatenated with `vcat`, in that order.
"""
function loader()
    head = JLD.load("../data/k250.jld")["k25[ranges[i]:ranges[i + 1]]"]
    # The remaining files were saved under a slightly different variable key.
    tail = [
        JLD.load(@sprintf("../data/k25%s.jld", i))["k25[ranges[i] + 1:ranges[i + 1]]"]
        for i in 1:3
    ]
    return vcat(head, tail...)
end
# Load the stored data through `loader()`, printing a message before and after
# so the duration of the load is visible in the cell output.
println("start loading")
k25 = loader()
println("loaded")

# Number of top-level entries returned by `loader()`.
n_seq = length(k25)


start loading


loaded


801

In [41]:
"""
    _hcat_columns(columns)

Place each element of `columns` into one column of a freshly allocated matrix
created with `similar` from the first element, so the container kind of the
elements is kept. This avoids splatting a very large collection into `hcat`.
"""
function _hcat_columns(columns)
    proto = first(columns)
    out = similar(proto, length(proto), length(columns))
    for (j, col) in enumerate(columns)
        out[:, j] = col
    end
    return out
end

"""
    getdata(args)

Build the training and test `DataLoader`s from the globals `k25` and `netsurf`.

The entries of `k25` are flattened into one observation each, and the characters
of the `" dssp8"` column are flattened into labels and truncated to the number
of observations. The data is split 70/30 with `stratifiedobs`, the labels are
one-hot encoded, and both parts are wrapped in `DataLoader`s with batch size
`args.batchsize` (only the training loader shuffles).

Returns `(train_data, test_data)`.
"""
function getdata(args)
    x_data = [x for i in k25 for x in i] # set x and y arrays
    xl = length(x_data)
    y_data = first([x for i in netsurf[!, " dssp8"] for x in i], xl)

    (X_train, y_train), (X_test, y_test) = stratifiedobs((x_data, y_data), p = 0.7, shuffle = true)
    X_train = _hcat_columns(X_train)
    X_test = _hcat_columns(X_test)

    # Use a sorted vector of labels rather than a `Set`: a set has no defined
    # order, so the meaning of each one-hot row would otherwise be arbitrary.
    struct_labels = sort!(collect(mapreduce(Set, union, netsurf[!, " dssp8"])))
    y_train = onehotbatch(y_train, struct_labels)
    y_test = onehotbatch(y_test, struct_labels)

    train_data = DataLoader((X_train, y_train), batchsize = args.batchsize, shuffle = true)
    test_data = DataLoader((X_test, y_test), batchsize = args.batchsize)

    return train_data, test_data
end


getdata (generic function with 1 method)

In [19]:
"""
    loss_all(dataloader, model)

Return the mean `logitcrossentropy` of `model` over all observations yielded by
`dataloader`.

The loss is summed per batch and divided by the total number of observations,
so a smaller final batch is not over-weighted. Observations are counted along
the second dimension of each input batch, as in `accuracy`.
"""
function loss_all(dataloader, model)
    total = 0f0
    nobs = 0
    for (x, y) in dataloader
        total += logitcrossentropy(model(x), y; agg = sum)
        nobs += size(x, 2)
    end
    return total / nobs
end

"""
    accuracy(data_loader, model)

Return the fraction of observations in `data_loader` for which the class
predicted by `model` (via `onecold`) equals the class encoded in the target.

Correct predictions and observations are counted across all batches and divided
once at the end, so a smaller final batch is not over-weighted. Observations
are counted along the second dimension of each input batch.
"""
function accuracy(data_loader, model)
    correct = 0
    total = 0
    for (x, y) in data_loader
        correct += sum(onecold(cpu(model(x))) .== onecold(cpu(y)))
        total += size(x, 2)
    end
    return correct / total
end
"""
    loss(x, y)

Return `logitcrossentropy(m(x), y)`.

NOTE(review): `m` is looked up as a global when this is called; no top-level
`m` is defined in the visible source (`train` builds its own local model and
its own local `loss`), so confirm this definition is still needed.
"""
function loss(x, y)
    return logitcrossentropy(m(x), y)
end

"""
    build_model(; size=10000, nclasses=8)

Return a two-layer network: a `Dense` layer from `prod(size)` inputs to 32
units with `relu`, followed by a `Dense` layer from 32 units to `nclasses`
outputs.
"""
function build_model(; size=10000, nclasses=8)
    n_inputs = prod(size)
    hidden = Dense(n_inputs, 32, relu)
    output = Dense(32, nclasses)
    return Chain(hidden, output)
end

build_model (generic function with 1 method)

In [23]:
# Hyperparameters for a training run. `Base.@kwdef` generates a keyword
# constructor, so individual fields can be overridden by name
# (e.g. `Args(epochs = 5)`); the remaining fields keep the defaults below.
Base.@kwdef mutable struct Args
    rate::Float64 = 3e-4    # learning rate passed to the optimiser
    batchsize::Int = 1024   # number of observations per batch
    epochs::Int = 10      # number of passes over the training data
end

Args

In [34]:
# Inspection cell: display the current value of the global `y_data`.
y_data

801-element Vector{Char}:
 'C': ASCII/Unicode U+0043 (category Lu: Letter, uppercase)
 'C': ASCII/Unicode U+0043 (category Lu: Letter, uppercase)
 'C': ASCII/Unicode U+0043 (category Lu: Letter, uppercase)
 'C': ASCII/Unicode U+0043 (category Lu: Letter, uppercase)
 'H': ASCII/Unicode U+0048 (category Lu: Letter, uppercase)
 'H': ASCII/Unicode U+0048 (category Lu: Letter, uppercase)
 'H': ASCII/Unicode U+0048 (category Lu: Letter, uppercase)
 'H': ASCII/Unicode U+0048 (category Lu: Letter, uppercase)
 'H': ASCII/Unicode U+0048 (category Lu: Letter, uppercase)
 'H': ASCII/Unicode U+0048 (category Lu: Letter, uppercase)
 ⋮
 'S': ASCII/Unicode U+0053 (category Lu: Letter, uppercase)
 'T': ASCII/Unicode U+0054 (category Lu: Letter, uppercase)
 'T': ASCII/Unicode U+0054 (category Lu: Letter, uppercase)
 'G': ASCII/Unicode U+0047 (category Lu: Letter, uppercase)
 'G': ASCII/Unicode U+0047 (category Lu: Letter, uppercase)
 'G': ASCII/Unicode U+0047 (category Lu: Letter, uppercase)
 'S': ASCII

In [40]:
# Scratch cell: build the flattened inputs and labels at top level
# (the same three statements also appear inside `getdata`).
x_data = [x for i in k25 for x in i] # one entry per element of each item in k25
xl = length(x_data)
# Flatten the " dssp8" strings into characters and keep only the first `xl`.
y_data = first([x for i in netsurf[!, " dssp8"] for x in i], xl)

196781-element Vector{Char}:
 'C': ASCII/Unicode U+0043 (category Lu: Letter, uppercase)
 'C': ASCII/Unicode U+0043 (category Lu: Letter, uppercase)
 'C': ASCII/Unicode U+0043 (category Lu: Letter, uppercase)
 'C': ASCII/Unicode U+0043 (category Lu: Letter, uppercase)
 'H': ASCII/Unicode U+0048 (category Lu: Letter, uppercase)
 'H': ASCII/Unicode U+0048 (category Lu: Letter, uppercase)
 'H': ASCII/Unicode U+0048 (category Lu: Letter, uppercase)
 'H': ASCII/Unicode U+0048 (category Lu: Letter, uppercase)
 'H': ASCII/Unicode U+0048 (category Lu: Letter, uppercase)
 'H': ASCII/Unicode U+0048 (category Lu: Letter, uppercase)
 ⋮
 'C': ASCII/Unicode U+0043 (category Lu: Letter, uppercase)
 'C': ASCII/Unicode U+0043 (category Lu: Letter, uppercase)
 'C': ASCII/Unicode U+0043 (category Lu: Letter, uppercase)
 'C': ASCII/Unicode U+0043 (category Lu: Letter, uppercase)
 'C': ASCII/Unicode U+0043 (category Lu: Letter, uppercase)
 'C': ASCII/Unicode U+0043 (category Lu: Letter, uppercase)
 'C': AS

In [39]:
# Inspection cell: display the current value of the global `x_data`.
x_data

196781-element Vector{BitVector}:
 [1, 1, 0, 0, 0, 0, 0, 0, 1, 0  …  1, 1, 0, 1, 1, 0, 0, 0, 0, 0]
 [0, 0, 1, 0, 0, 0, 0, 1, 0, 1  …  1, 0, 0, 0, 0, 0, 1, 1, 0, 0]
 [0, 1, 0, 1, 1, 1, 0, 0, 0, 0  …  1, 0, 1, 0, 0, 0, 0, 1, 0, 1]
 [1, 1, 0, 0, 1, 0, 0, 0, 0, 0  …  0, 0, 0, 1, 0, 0, 0, 0, 0, 1]
 [1, 0, 0, 1, 1, 0, 1, 0, 1, 0  …  0, 0, 1, 0, 1, 0, 0, 0, 1, 1]
 [0, 0, 1, 0, 1, 0, 1, 1, 1, 0  …  1, 1, 1, 1, 0, 0, 1, 0, 1, 0]
 [1, 0, 1, 0, 1, 0, 0, 0, 1, 0  …  1, 0, 0, 1, 1, 1, 1, 0, 0, 1]
 [0, 0, 0, 1, 1, 1, 0, 1, 0, 1  …  1, 0, 1, 1, 1, 0, 1, 0, 0, 0]
 [0, 0, 1, 1, 0, 0, 0, 1, 1, 0  …  0, 0, 1, 1, 0, 0, 1, 1, 1, 0]
 [0, 1, 0, 0, 1, 1, 0, 1, 0, 1  …  1, 0, 0, 1, 0, 0, 1, 0, 0, 1]
 ⋮
 [1, 0, 0, 1, 0, 0, 0, 0, 0, 0  …  0, 1, 1, 1, 0, 1, 1, 0, 1, 0]
 [0, 0, 1, 0, 1, 0, 0, 0, 1, 0  …  0, 0, 0, 0, 1, 0, 1, 1, 0, 0]
 [1, 0, 1, 1, 0, 0, 1, 1, 0, 0  …  1, 0, 1, 0, 0, 0, 0, 1, 0, 1]
 [0, 0, 1, 0, 0, 1, 0, 0, 1, 1  …  1, 0, 1, 1, 0, 0, 0, 1, 1, 0]
 [1, 1, 0, 0, 1, 1, 1, 1, 1, 0  …  1, 0, 0, 1, 1, 0, 

In [42]:
"""
    train(; kws...)

Train the model returned by `build_model()` on the data from `getdata` and save
it to `../data/model25_test.jld`.

Keyword arguments are forwarded to `Args` (`rate`, `batchsize`, `epochs`).
Progress messages, the periodically evaluated training loss, and the final
training and test accuracies are printed along the way.
"""
function train(; kws...)
    # Initializing Model parameters
    args = Args(; kws...)
    println("dataloader loading")
    train_data, test_data = getdata(args)
    println("dataloader loaded")
    m = build_model()
    loss(x, y) = logitcrossentropy(m(x), y)

    ## Training
    println("training started")
    # `loss_all` makes a full pass over the training set, and `train!` invokes
    # the callback after every batch. Throttle it so the evaluation runs at
    # most once every `eval_interval` seconds instead of once per batch.
    eval_interval = 10
    evalcb = throttle(() -> @show(loss_all(train_data, m)), eval_interval)
    opt = Adam(args.rate)

    @epochs args.epochs Flux.train!(loss, params(m), train_data, opt, cb = evalcb)

    @show accuracy(train_data, m)

    @show accuracy(test_data, m)

    @save "../data/model25_test.jld" m
end

# Run training with the default `Args` values.
train()

dataloader loading


dataloader loaded


training started


┌ Info: Epoch 1
└ @ Main /home/mfat/.julia/packages/Flux/FWgS0/src/optimise/train.jl:185


loss_all(train_data, m) = 1.9785403f0
loss_all(train_data, m) = 

1.811339f0


loss_all(train_data, m) = 1.7997732f0
loss_all(train_data, m) = 

1.7791259f0


loss_all(train_data, m) = 1.7759331f0


InterruptException: InterruptException: