In [1]:
using Pkg
Pkg.activate(".")
using DataFrames, JSON3, SurrealDB, Arrow, CSV, Dates, FreqTables, ProgressMeter, Flux, CUDA

[32m[1m  Activating[22m[39m project at `/blue/raquel.dias/nicks/protein_prediction/structuralFeatures/ISMBExample`


In [2]:
# Stop early when CUDA is not usable: the training code below moves the model
# and the data loader to the GPU with `|> gpu`.
@assert CUDA.functional() 

## Helper Function Definitions

In [14]:
"""
    fetch_features(featQuery, features; connection=conn, batch_size=5000)

Run `featQuery` once per batch of `features` and return every returned row in a
single vector.

The literal placeholder `{{features}}` inside `featQuery` is replaced, for each
batch, by a bracketed comma-separated list of that batch's entries.

# Keywords
- `connection`: the `SurrealConnection` the queries are sent through. Defaults to
  the global `conn`.
- `batch_size`: maximum number of entries substituted into one query.

# Returns
An untyped vector holding the concatenated rows of all batches, in batch order.
"""
function fetch_features(featQuery::String,features::Vector{String};connection::SurrealConnection=conn,batch_size=5000)
    out = []
    chunks = collect(Iterators.partition(features, batch_size))
    total = length(chunks)
    for (i, chunk) in enumerate(chunks)
        print("\rFetching batch $i of $total")
        cmd = replace(featQuery, "{{features}}" => "[$(join(chunk, ","))]")
        # Use the connection that was passed in; the global `conn` is only the default.
        resp = query(connection, cmd)
        res = resp.result[1].result
        println(" Batch $i complete in $(resp.result[1].time)")
        # `append!` takes the collection directly instead of splatting a whole
        # batch into a varargs call.
        append!(out, res)
    end
    return out
end

"""
    fetchDatasets()

Return the `name` and `id` of every record in the `datasets` table, queried
through the global connection `conn`.
"""
function fetchDatasets()
    response = query(conn, "SELECT name,id from datasets")
    return response.result[1].result
end

"""
    getDatasetFeatureIDs(datasetID)

Return the `features` value of every `protein` record whose `datasets` field
contains `datasetID`, queried through the global connection `conn`.

`datasetID` is interpolated into the query text unquoted.
"""
function getDatasetFeatureIDs(datasetID::String)
    println("Retrieving all feature records for $datasetID")
    # Triple-quoted strings drop the common leading indentation, so the query
    # text itself starts at `select` with the clauses indented by four spaces.
    feature_query = """
        select
            value features
        from
            protein
        where
            datasets contains $datasetID"""
    response = query(conn, feature_query)
    return response.result[1].result
end
"""
    getDatasetPrecomputed(datasetID)

Return a `DataFrame` built from the expanded `precomputed_components` of the
record `datasetID`, queried through the global connection `conn`.
"""
function getDatasetPrecomputed(datasetID::String)
    response = query(conn, "select value precomputed_components.*.* from $(datasetID)")
    components = response.result[1].result[1]
    return DataFrame(components)
end

"""
    getPrecompURI(component_id)

Return the `uri` stored on the record `component_id`, or `nothing` when the
query returns no rows. Uses the global connection `conn`.
"""
function getPrecompURI(component_id::String)
    uris = query(conn, "select value uri from $component_id").result[1].result
    if isempty(uris)
        return nothing
    end
    return first(uris)
end

"""
    featureResultToDF(results, feature_col_name, label_col_name)

Build a `DataFrame` with one column per element of `results`.

For each element, the value stored under `label_col_name` becomes the column
name and the value stored under `feature_col_name` becomes the column contents.
"""
function featureResultToDF(results::Vector, feature_col_name::String,label_col_name::String)
    table = DataFrame()
    foreach(results) do record
        table[!, record[label_col_name]] = record[feature_col_name]
    end
    return table
end

"""
    getData(run_declarations)

Fetch the feature data for every run declaration and return one dictionary per
declaration: the original entries merged with the fetched data.

Each declaration must contain the keys `runName`, `datasetID`, `featureName`,
`sampleNameCol`, `featureQuery` and `modelSaveLoc`; otherwise the assertion
fails with "Missing declarations".

Feature IDs are looked up once per distinct `datasetID` within a single call and
reused for later declarations in the same call.

# Returns
A vector of dictionaries. Each adds these keys to its declaration:
- `pipeline_featureIDs`: the feature IDs retrieved for the dataset.
- `pipeline_data`: the rows returned by `fetch_features`.
- `pipeline_df`: the rows as a `DataFrame` converted to `Float32`.
- `pipeline_featureDims`: `size` of that `DataFrame`.
"""
function getData(run_declarations::Vector{Dict{String, String}})
    required_keys = ["runName","datasetID","featureName","sampleNameCol","featureQuery","modelSaveLoc"]
    # One dictionary is the whole cache: a key being present means "already fetched".
    featureIDCache = Dict{String,Any}()
    map(run_declarations) do x
        @time begin
            @assert issubset(required_keys, keys(x)) "Missing declarations"
            println("processing $(x["runName"])...")
            datasetID = x["datasetID"]
            if haskey(featureIDCache, datasetID)
                println("Feature records for $datasetID already retrieved, using cached values.")
                featureIDs = featureIDCache[datasetID]
            else
                featureIDs = getDatasetFeatureIDs(datasetID)
                println("Caching feature records for future batches.")
                featureIDCache[datasetID] = featureIDs
            end
            # NOTE(review): the `[1:end]` slice presumably materialises the query
            # result as a plain Vector to satisfy fetch_features' `Vector{String}`
            # signature -- confirm before removing it.
            data = fetch_features(x["featureQuery"], featureIDs[1:end], batch_size = 15000)
            df = convert.(Float32, featureResultToDF(data, x["featureName"], x["sampleNameCol"]))
            additions = Dict(
                "pipeline_featureIDs" => featureIDs,
                "pipeline_data" => data,
                "pipeline_df" => df,
                "pipeline_featureDims" => size(df),
            )
            merge(x, additions)
        end
    end
end

"""
    accuracy(results, truth)

Return the fraction of positions at which `results` equals `truth`, compared
element-wise with broadcasting and divided by `length(truth)`.
"""
function accuracy(results, truth)
    matches = results .== truth
    return count(matches) / length(truth)
end

"""
    trainBatches(run_declarations; debug=false)

For each run declaration, fetch its data with `getData`, build a multi-label
target matrix, train a dense network on the GPU for 5 epochs, and return the
declaration merged with the fetched data and the training results.

Reads the globals `top_terms` (the label vocabulary) and `gdf` (annotations
grouped by protein ID). The target for a protein is 1 for every term of
`top_terms` listed in that protein's group and 0 elsewhere; a protein with no
group in `gdf` gets an all-zero target.

# Keywords
- `debug`: when `true`, print the matrix sizes, the model and the data loader
  before training starts.

# Returns
A vector with one dictionary per run. Each holds the `getData` output plus:
- `result_loss`: the loss of every training step, in order.
- `result_model`: the trained model (GPU copy).
- `result_modelState`: `Flux.state` of the trained model.
- `result_binaryAccuracy`: element-wise accuracy of the thresholded predictions
  against the targets, computed on the same data the model was trained on.
- `result_testingModel`: CPU copy of the model, switched to test mode.
"""
function trainBatches(run_declarations::Vector{Dict{String, String}};debug=false)
    map(run_declarations) do batch
        println("Starting DataRetrieval for run $(batch["runName"])")
        # NOTE(review): getData creates its feature-ID cache afresh on every call,
        # so passing one declaration at a time re-queries the IDs for each run.
        rundata = getData([batch])[1]
        train_protein_ids = names(rundata["pipeline_df"])

        # Build the (labels x proteins) target matrix directly. A term => row lookup
        # replaces a linear scan of the group's terms for every label.
        n_labels = length(top_terms)
        term_rows = Dict(term => row for (row, term) in enumerate(top_terms))
        label_matrix = zeros(Float32, n_labels, length(train_protein_ids))
        println("Creating labels for run $(rundata["runName"])")
        @showprogress for col in eachindex(train_protein_ids)
            # Proteins without any annotation among `top_terms` have no group;
            # they keep an all-zero column instead of raising a KeyError.
            group = get(gdf, (train_protein_ids[col],), nothing)
            if group !== nothing
                for term in group.term
                    row = get(term_rows, term, 0)
                    if row != 0
                        label_matrix[row, col] = 1.0f0
                    end
                end
            end
        end

        INPUT_SHAPE = rundata["pipeline_featureDims"][1]
        BATCH_SIZE = 5120
        gpu_model = Chain(
            BatchNorm(INPUT_SHAPE;eps = 0.001f0, momentum=0.99f0),
            Dense(INPUT_SHAPE=>512,relu),
            Dense(512=>512,relu),
            Dense(512=>512,relu),
            Dense(512=>n_labels,sigmoid)
        ) |> gpu
        train_matrix = Matrix(rundata["pipeline_df"])
        gpu_loader = Flux.DataLoader((train_matrix, label_matrix) ,batchsize = BATCH_SIZE) |> gpu
        gpu_optim = Flux.setup(Flux.Adam(0.001,(0.9,0.999),1.0e-7), gpu_model)
        gpu_losses = []
        println("Starting training for run $(rundata["runName"])")
        if debug
            println(["trainMatrix"=>size(train_matrix), "label_matrix"=>size(label_matrix),"model"=>gpu_model,"gpuLoader" => gpu_loader])
        end
        @showprogress for epoch in 1:5
            # Distinct names keep the mini-batch apart from the run dictionary.
            for (features, targets) in gpu_loader
                loss, grads = Flux.withgradient(gpu_model) do m
                    # Evaluate model and loss inside gradient context:
                    y_hat = m(features)
                    Flux.binarycrossentropy(y_hat, targets)
                end
                Flux.update!(gpu_optim, gpu_model, grads[1])
                push!(gpu_losses, loss)  # logging, outside gradient context
            end
        end
        println("Training complete for run $(rundata["runName"])")
        cpu_model = cpu(gpu_model)
        testmode!(cpu_model)
        res = cpu_model(train_matrix)
        # Threshold with a broadcast so the result has a concrete element type;
        # `true`/`false` compare equal to the 1/0 targets.
        predicted = res .>= 0.5f0
        acc = accuracy(predicted, label_matrix)
        resultDict = Dict(
            "result_loss" => gpu_losses,
            "result_model" => gpu_model,
            "result_modelState" => Flux.state(gpu_model),
            "result_binaryAccuracy" => acc,
            "result_testingModel" => cpu_model,
            )

        merge(rundata, resultDict)
    end
end

trainBatches (generic function with 1 method)

In [1]:
# Start a SurrealDB server by writing a command line to the stdin of a `bash`
# child process; the child's output is forwarded to this notebook's stdout.
io = open(`bash`,"w",stdout)
# presumably `ml` is the cluster's environment-module loader -- TODO confirm
write(io,"ml surrealdb/1.5.0 && surreal start --auth --allow-all file:../blue.raquel.dias/nicks/protein_prediction/structuralFeatures/ismb.db")
# Fixed pause before the pipe is closed, presumably to give the server time to start.
# NOTE(review): the command string has no trailing newline, so bash may not run
# it until stdin is closed below -- verify the pause happens where intended.
sleep(7)
close(io)


 .d8888b.                                             888 8888888b.  888888b.
d88P  Y88b                                            888 888  'Y88b 888  '88b
Y88b.                                                 888 888    888 888  .88P
 'Y888b.   888  888 888d888 888d888  .d88b.   8888b.  888 888    888 8888888K.
    'Y88b. 888  888 888P'   888P'   d8P  Y8b     '88b 888 888    888 888  'Y88b
      '888 888  888 888     888     88888888 .d888888 888 888    888 888    888
Y88b  d88P Y88b 888 888     888     Y8b.     888  888 888 888  .d88P 888   d88P
 'Y8888P'   'Y88888 888     888      'Y8888  'Y888888 888 8888888P'  8888888P'




[2m2024-07-14T22:42:46.019823Z[0m [32m INFO[0m [2msurreal::env[0m[2m:[0m Running 1.5.0 for linux on x86_64
[2m2024-07-14T22:42:46.019858Z[0m [32m INFO[0m [2msurreal::dbs[0m[2m:[0m ✅🔒 Authentication is enabled 🔒✅
[2m2024-07-14T22:42:46.020679Z[0m [32m INFO[0m [2msurrealdb_core::kvs::ds[0m[2m:[0m Starting kvs store at file://../blue.raquel.dias/nicks/protein_prediction/structuralFeatures/ismb.db
[2m2024-07-14T22:42:46.419665Z[0m [32m INFO[0m [2msurrealdb_core::kvs::ds[0m[2m:[0m Started kvs store at file://../blue.raquel.dias/nicks/protein_prediction/structuralFeatures/ismb.db
[2m2024-07-14T22:42:46.444290Z[0m [32m INFO[0m [2msurrealdb::net[0m[2m:[0m Started web server on 0.0.0.0:8000


In [5]:
# Create the global `conn` that the helper functions above use for their queries.
begin
    # NOTE(review): the user name and password are hard-coded here -- confirm
    # these are example credentials before sharing the notebook.
    conn = SurrealConnection("ws://localhost:8000/rpc","diaslab","examplePW")
    conn.ns = "diaslab"
    conn.db = "ismb"
    signin(conn)
    use(conn)
end
# presumably a smoke test of the connection; the trailing `;` hides the cell output
query(conn,"info for db").result;


 .d8888b.                                             888 8888888b.  888888b.
d88P  Y88b                                            888 888  'Y88b 888  '88b
Y88b.                                                 888 888    888 888  .88P
 'Y888b.   888  888 888d888 888d888  .d88b.   8888b.  888 888    888 8888888K.
    'Y88b. 888  888 888P'   888P'   d8P  Y8b     '88b 888 888    888 888  'Y88b
      '888 888  888 888     888     88888888 .d888888 888 888    888 888    888
Y88b  d88P Y88b 888 888     888     Y8b.     888  888 888 888  .d88P 888   d88P
 'Y8888P'   'Y88888 888     888      'Y8888  'Y888888 888 8888888P'  8888888P'




[2m2024-06-29T23:18:13.287535Z[0m [32m INFO[0m [2msurreal::env[0m[2m:[0m Running 1.5.0 for linux on x86_64
[2m2024-06-29T23:18:13.287552Z[0m [32m INFO[0m [2msurreal::dbs[0m[2m:[0m ✅🔒 Authentication is enabled 🔒✅
[2m2024-06-29T23:18:13.287565Z[0m [32m INFO[0m [2msurrealdb_core::kvs::ds[0m[2m:[0m Starting kvs store at file://../blue.raquel.dias/nicks/protein_prediction/structuralFeatures/ismb.db
[2m2024-06-29T23:18:13.485613Z[0m [32m INFO[0m [2msurrealdb_core::kvs::ds[0m[2m:[0m Started kvs store at file://../blue.raquel.dias/nicks/protein_prediction/structuralFeatures/ismb.db
[2m2024-06-29T23:18:13.510143Z[0m [32m INFO[0m [2msurrealdb::net[0m[2m:[0m Started web server on 0.0.0.0:8000


## Model Training Workflow
* Declare the pipeline parameters for each run
    1. Declare the dataset of interest
    2. Define the feature query
    3. Declare the feature name column and sample name column
    4. Declare the save location
* Pipeline will run in series
    1. Fetch specified data from database
    2. Train model
    3. Return and/or save the results to the save location

### Find Available Datasets

In [6]:
# List the name and id of every record in the `datasets` table.
fetchDatasets()

1-element JSON3.Array{JSON3.Object, Base.CodeUnits{UInt8, String}, SubArray{UInt64, 1, Vector{UInt64}, Tuple{UnitRange{Int64}}, true}}:
 {
     "id": "datasets:61u6akz831fm28kc2o7v",
   "name": "CAFA5 Train"
}

## Declare the run parameters

In [7]:
# One dictionary per training run. getData requires all six keys:
#   runName       - label used in the progress messages and in the results
#   datasetID     - record id of the dataset whose feature IDs are retrieved
#   featureName   - field of each returned row that holds the feature values
#   sampleNameCol - field of each returned row used as the column name
#   featureQuery  - query template; fetch_features replaces `{{features}}`
#                   with the current batch of feature IDs
#   modelSaveLoc  - declared save location (required by getData's key check)
run_declarations = [
    # Run on the `feature_t5_embed` field alone.
    Dict(
        "runName" => "CAFAT5Embeddings_NN",
        "datasetID" => "datasets:61u6akz831fm28kc2o7v",
        "featureName" => "t5",
        "sampleNameCol" => "uniprot_id",
        "featureQuery" =>  """select feature_t5_embed as t5 ,uniprot_id from {{features}}""",
        "modelSaveLoc" => "./models"
    ),
    # Run on the `feature_esm_embed` field alone.
    Dict(
        "runName" => "CAFAESMEmbeddings_NN",
        "datasetID" => "datasets:61u6akz831fm28kc2o7v",
        "featureName" => "esm",
        "sampleNameCol" => "uniprot_id",
        "featureQuery" =>  """select feature_esm_embed as esm ,uniprot_id from {{features}}""",
        "modelSaveLoc" => "./models"
    ),
    # Run on both fields, flattened into a single array by the query.
    Dict(
        "runName" => "CAFACombinedT5ESMEmbeddings_NN",
        "datasetID" => "datasets:61u6akz831fm28kc2o7v",
        "featureName" => "combined",
        "sampleNameCol" => "uniprot_id",
        "featureQuery" =>  """select array::flatten([feature_t5_embed, feature_esm_embed]) as combined ,uniprot_id from {{features}}""",
        "modelSaveLoc" => "./models"
        )
    ]

3-element Vector{Dict{String, String}}:
 Dict("datasetID" => "datasets:61u6akz831fm28kc2o7v", "modelSaveLoc" => "./models", "runName" => "CAFAT5Embeddings_NN", "sampleNameCol" => "uniprot_id", "featureQuery" => "select feature_t5_embed as t5 ,uniprot_id from {{features}}", "featureName" => "t5")
 Dict("datasetID" => "datasets:61u6akz831fm28kc2o7v", "modelSaveLoc" => "./models", "runName" => "CAFAESMEmbeddings_NN", "sampleNameCol" => "uniprot_id", "featureQuery" => "select feature_esm_embed as esm ,uniprot_id from {{features}}", "featureName" => "esm")
 Dict("datasetID" => "datasets:61u6akz831fm28kc2o7v", "modelSaveLoc" => "./models", "runName" => "CAFACombinedT5ESMEmbeddings_NN", "sampleNameCol" => "uniprot_id", "featureQuery" => "select array::flatten([feature_t5_embed, feature_esm_embed]) as combined ,uniprot_id from {{features}}", "featureName" => "combined")

## Example of using Precomputed Resource
Start by querying which precomputed resources are available for the dataset.

In [8]:
# Show the precomputed components recorded for this dataset as a table.
getDatasetPrecomputed("datasets:61u6akz831fm28kc2o7v")

Row,component_description,component_name,id,modified,uri
Unnamed: 0_level_1,String,String,String,String,String
1,UniProt Accession IDs for CAFA5 Dataset elements with GO terms in the uniprot_meta,Protein Accessions with GO Terms,precomputed:58spmd3q53lthnopvkow,2024-06-27T22:09:49.538797390Z,/blue/raquel.dias/nicks/protein_prediction/structuralFeatures/ISMBExample/assets/DBProteinsWithTerms.arrow
2,Official train_terms.tsv file from CAFA5 competition saved in arrow format,train_terms,precomputed:xjzqzx9zird2frkwoh7a,2024-06-27T22:58:12.257422551Z,/blue/raquel.dias/nicks/protein_prediction/structuralFeatures/ISMBExample/assets/cafa_train_terms.arrow


In [9]:
# Fetch the URI of the resource using its ID.
# getPrecompURI returns `nothing` when the ID matches no record.
resource_uri = getPrecompURI("precomputed:xjzqzx9zird2frkwoh7a")
# Load the Arrow file at that URI into a DataFrame.
train_terms = DataFrame(Arrow.Table(resource_uri))

Row,EntryID,term,aspect
Unnamed: 0_level_1,String15,String15,String3
1,A0A009IHW8,GO:0008152,BPO
2,A0A009IHW8,GO:0034655,BPO
3,A0A009IHW8,GO:0072523,BPO
4,A0A009IHW8,GO:0044270,BPO
5,A0A009IHW8,GO:0006753,BPO
6,A0A009IHW8,GO:1901292,BPO
7,A0A009IHW8,GO:0044237,BPO
8,A0A009IHW8,GO:1901360,BPO
9,A0A009IHW8,GO:0008150,BPO
10,A0A009IHW8,GO:1901564,BPO


## Create items that will be reused during all training runs

In [10]:
# Count how often each term occurs, then order the counts from most to least frequent.
term_ft = freqtable(train_terms,:term) |> sort |> reverse
# Number of most-frequent terms kept as prediction targets.
num_labels = 1500
top_terms = term_ft[1:num_labels] |> names |> y->y[1]
# Keep only the annotations whose term is one of the kept terms. The membership
# test runs once per annotation row, so it uses a Set instead of scanning the
# `top_terms` vector each time.
train_terms_updated = filter(:term => in(Set(top_terms)), train_terms)
# Group the remaining annotations by protein; trainBatches looks groups up by ID.
gdf = groupby(train_terms_updated,:EntryID)

Row,EntryID,term,aspect
Unnamed: 0_level_1,String15,String15,String3
1,A0A009IHW8,GO:0008152,BPO
2,A0A009IHW8,GO:0034655,BPO
3,A0A009IHW8,GO:0044270,BPO
4,A0A009IHW8,GO:0006753,BPO
5,A0A009IHW8,GO:0044237,BPO
6,A0A009IHW8,GO:1901360,BPO
7,A0A009IHW8,GO:0008150,BPO
8,A0A009IHW8,GO:1901564,BPO
9,A0A009IHW8,GO:1901565,BPO
10,A0A009IHW8,GO:0009117,BPO

Row,EntryID,term,aspect
Unnamed: 0_level_1,String15,String15,String3
1,X5HMX4,GO:0005515,MFO
2,X5HMX4,GO:0005488,MFO
3,X5HMX4,GO:0003674,MFO


## Training Iterations

In [15]:
# Fetch data and train a model for every declared run, timing the whole pipeline.
# The trailing `;` hides the (large) result from the cell output.
@time res = trainBatches(run_declarations);

Starting DataRetrieval for run CAFAT5Embeddings_NN
processing CAFAT5Embeddings_NN...
Retrieving all feature records for datasets:61u6akz831fm28kc2o7v
Caching feature records for future batches.
Fetching batch 1 of 10 Batch 1 complete in 6.545966594s
Fetching batch 2 of 10 Batch 2 complete in 5.746395408s
Fetching batch 3 of 10 Batch 3 complete in 5.601333508s
Fetching batch 4 of 10 Batch 4 complete in 5.563351412s
Fetching batch 5 of 10 Batch 5 complete in 5.514220862s
Fetching batch 6 of 10 Batch 6 complete in 5.26713768s
Fetching batch 7 of 10 Batch 7 complete in 5.478087843s
Fetching batch 8 of 10 Batch 8 complete in 5.240745615s
Fetching batch 9 of 10 Batch 9 complete in 5.342034728s
Fetching batch 10 of 10 Batch 10 complete in 2.497972392s
 85.599866 seconds (3.89 M allocations: 7.056 GiB, 0.20% gc time)
Creating labels for run CAFAT5Embeddings_NN


[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:53[39m


Starting training for run CAFAT5Embeddings_NN


[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:49[39m


Training complete for run CAFAT5Embeddings_NN
Starting DataRetrieval for run CAFAESMEmbeddings_NN
processing CAFAESMEmbeddings_NN...
Retrieving all feature records for datasets:61u6akz831fm28kc2o7v
Caching feature records for future batches.
Fetching batch 1 of 10 Batch 1 complete in 6.841645775s
Fetching batch 2 of 10 Batch 2 complete in 6.811802079s
Fetching batch 3 of 10 Batch 3 complete in 6.506057629s
Fetching batch 4 of 10 Batch 4 complete in 6.463445093s
Fetching batch 5 of 10 Batch 5 complete in 6.480391223s
Fetching batch 6 of 10 Batch 6 complete in 6.135304518s
Fetching batch 7 of 10 Batch 7 complete in 6.1036452s
Fetching batch 8 of 10 Batch 8 complete in 6.066491464s
Fetching batch 9 of 10 Batch 9 complete in 6.260540593s
Fetching batch 10 of 10 Batch 10 complete in 2.888960014s
 98.869827 seconds (3.84 M allocations: 8.235 GiB, 0.49% gc time)
Creating labels for run CAFAESMEmbeddings_NN


[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:52[39m


Starting training for run CAFAESMEmbeddings_NN


[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:04[39m


Training complete for run CAFAESMEmbeddings_NN
Starting DataRetrieval for run CAFACombinedT5ESMEmbeddings_NN
processing CAFACombinedT5ESMEmbeddings_NN...
Retrieving all feature records for datasets:61u6akz831fm28kc2o7v
Caching feature records for future batches.
Fetching batch 1 of 10 Batch 1 complete in 10.6906098s
Fetching batch 2 of 10 Batch 2 complete in 9.909013349s
Fetching batch 3 of 10 Batch 3 complete in 9.665646859s
Fetching batch 4 of 10 Batch 4 complete in 9.698247981s
Fetching batch 5 of 10 Batch 5 complete in 9.540985515s
Fetching batch 6 of 10 Batch 6 complete in 9.279997473s
Fetching batch 7 of 10 Batch 7 complete in 9.353314598s
Fetching batch 8 of 10 Batch 8 complete in 9.366351574s
Fetching batch 9 of 10 Batch 9 complete in 9.541592526s
Fetching batch 10 of 10 Batch 10 complete in 4.278301745s
159.306861 seconds (4.02 M allocations: 14.450 GiB, 0.46% gc time)
Creating labels for run CAFACombinedT5ESMEmbeddings_NN


[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:50[39m


Starting training for run CAFACombinedT5ESMEmbeddings_NN


[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:05[39m


Training complete for run CAFACombinedT5ESMEmbeddings_NN
626.820181 seconds (3.94 G allocations: 187.388 GiB, 2.06% gc time, 0.01% compilation time)


In [19]:
# Print the binary accuracy of each run, rounded to four decimal places.
for result in res
    println("Binary Accuracy for $(result["runName"]): $(round(result["result_binaryAccuracy"],digits = 4))")
end

Binary Accuracy for CAFAT5Embeddings_NN: 0.9803
Binary Accuracy for CAFAESMEmbeddings_NN: 0.9798
Binary Accuracy for CAFACombinedT5ESMEmbeddings_NN: 0.9804


In [21]:
# Inspect the full result dictionary of the first run.
res[1]

Dict{String, Any} with 15 entries:
  "datasetID"             => "datasets:61u6akz831fm28kc2o7v"
  "modelSaveLoc"          => "./models"
  "runName"               => "CAFAT5Embeddings_NN"
  "sampleNameCol"         => "uniprot_id"
  "result_modelState"     => (layers = ((λ = (), β = Float32[-0.0129818, -0.006…
  "result_model"          => Chain(BatchNorm(1024), Dense(1024 => 512, relu), D…
  "pipeline_data"         => Any[{…
  "result_loss"           => Any[0.700808, 0.652192, 0.584157, 0.469541, 0.3273…
  "result_testingModel"   => Chain(BatchNorm(1024, active=false), Dense(1024 =>…
  "pipeline_featureIDs"   => ["features:A0A009IHW8", "features:A0A021WW32", "fe…
  "pipeline_df"           => [1m1024×142246 DataFrame[0m[0m…
  "pipeline_featureDims"  => (1024, 142246)
  "result_binaryAccuracy" => 0.980346
  "featureQuery"          => "select feature_t5_embed as t5 ,uniprot_id from {{…
  "featureName"           => "t5"