In [1]:
using Random
using LinearAlgebra
"""
Construct a binary vector. By default 10000 elements long.
"""
hdv(N::Int=10000) = rand((-1,1), 1, N)


"""
Bundles bipolar hyperdimensional vectors.
"""
add(vectors::Vector{Int}...) = reduce(.+, vectors) .|> sign


"""
Binds binpolar hyperdimensional vectors.
"""
multiply(vectors::Vector{Int}...) = reduce(.*, vectors)


"""
Permutes a bipolar hyperdimensional vector by an adjustable circular shift.
"""
perm(vector::Vector, k::Int=1) = circshift(vector, (0, k))


"""
Calculates the cosine similarity between two bipolar vectors.
"""
cosine(x, y) = dot(x, y) / (norm(x) * norm(y))

cosine

In [8]:
"""
Construct a binary vector. By default 10000 elements long.
"""
bithdv(N::Int=10000) = bitrand(N)


"""
Bundles binary hyperdimensional vectors based on the element-wise majority rule.
"""
function bitadd(vectors::BitVector ...)
    v = reduce(.+, vectors)
    n = length(vectors) / 2
    x = [i > n ? 1 : i < n ? 0 : rand(0:1) for i in v]
    return convert(BitVector, x)
end


"""
Binds binary hyperdimensional vectors based on an element-wise XOR gate.
"""
bitbind(vectors::BitVector ...) =  reduce(.⊻, vectors)


"""
Permutes a binary hyperdimensional vector by an adjustable circular shift.
"""
bitperm(vector::BitVector, k::Int=1) = circshift(vector, k)


"""
Calculates the Hamming distance between two binary vectors.
"""
hamming(x::BitVector, y::BitVector) = sum(x .!= y)/length(x)

hamming

In [3]:
using DataFrames, CSV
data = CSV.read("ProtExdata/ACPs_Breast_cancer.csv", DataFrame)
unique(data.class)
class_num = [i == "very active" ? 1 : i == "mod. active" ? 2 : i == "inactive - exp" ? 3 : 4 for i in data.class]
data[!, :class_num] = class_num
data = data[data.class_num .!= 4, :]
first(data, 5)

Unnamed: 0_level_0,ID,sequence,class,class_num
Unnamed: 0_level_1,Int64,String,String31,Int64
1,1,AAWKWAWAKKWAKAKKWAKAA,mod. active,2
2,2,AIGKFLHSAKKFGKAFVGEIMNS,mod. active,2
3,3,AWKKWAKAWKWAKAKWWAKAA,mod. active,2
4,4,ESFSDWWKLLAE,mod. active,2
5,5,ETFADWWKLLAE,mod. active,2


In [4]:
using PyCall

py"""
import numpy as np
from transformers import EsmModel, EsmConfig, EsmTokenizer
import torch

embeddings = "facebook/esm2_t6_8M_UR50D"
tokenizer = EsmTokenizer.from_pretrained(embeddings)
model = EsmModel.from_pretrained(embeddings)

seq = 'ARNDCQEGHILKMFPSTWYVOUBJZX'
inputs = tokenizer(seq, return_tensors="pt")
outputs = model(**inputs)
last_hidden_states = outputs.last_hidden_state[0].detach().numpy()
print(last_hidden_states)
"""

[[ 0.1828955   0.5040589   0.39757144 ...  0.9225599   0.05342245
  -0.45949394]
 [ 0.3699865   0.28748336 -0.0116293  ...  0.12235592  0.17939556
  -0.2283884 ]
 [ 0.08383    -0.2629382   0.2560887  ... -0.10995694  0.3136453
   0.14911757]
 ...
 [ 0.2191191  -0.28030506  0.46562314 ... -0.02370527  0.45409355
  -0.10943315]
 [ 0.25577268 -0.34152842  0.41236705 ...  0.0049744   0.4067464
  -0.07702404]
 [ 0.11428512 -0.08154328  0.43711823 ...  0.24650113 -0.06088951
  -0.1885616 ]]


Some weights of the model checkpoint at facebook/esm2_t6_8M_UR50D were not used when initializing EsmModel: ['esm.contact_head.regression.weight', 'lm_head.layer_norm.weight', 'lm_head.decoder.weight', 'esm.contact_head.regression.bias', 'lm_head.layer_norm.bias', 'lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias']
- This IS expected if you are initializing EsmModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing EsmModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.weight', 'esm.pooler.dense.bias']
You should probably TRAIN this model on a

In [39]:
using StatsBase

AA_list = ['A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H', 'I', 'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V', 'O', 'U', 'B', 'J', 'Z', 'X']
embeddings_esm = PyArray(py"last_hidden_states"o)[2:27, :]

dt = fit(UnitRangeTransform, embeddings_esm, dims = 1)
embeddings_esm = StatsBase.transform(dt, embeddings_esm)

random_hdv = permutedims(hcat([bithdv() for i in 1:320]...))

embeddings_esm_hdv = embeddings_esm*random_hdv
dt = fit(UnitRangeTransform, embeddings_esm_hdv, dims = 1)
embeddings_esm_hdv = round.(StatsBase.transform(dt, embeddings_esm_hdv))

hdvs = [convert(BitVector, embeddings_esm_hdv[i, :]) for i in 1:26]
AA_dict_esm = Dict(zip(AA_list, hdvs))

Dict{Char, BitVector} with 26 entries:
  'E' => [0, 1, 0, 1, 1, 0, 1, 1, 0, 1  …  1, 1, 0, 1, 0, 1, 1, 1, 0, 1]
  'Z' => [0, 1, 0, 1, 1, 0, 0, 1, 0, 0  …  1, 0, 0, 0, 1, 0, 1, 0, 1, 1]
  'X' => [0, 1, 1, 1, 1, 0, 0, 1, 0, 0  …  1, 0, 0, 0, 1, 0, 1, 0, 1, 1]
  'C' => [1, 1, 1, 0, 1, 1, 1, 1, 0, 1  …  0, 0, 1, 1, 0, 1, 0, 0, 0, 0]
  'B' => [0, 0, 1, 1, 1, 1, 0, 1, 0, 0  …  1, 0, 0, 0, 1, 0, 1, 0, 1, 1]
  'D' => [0, 0, 0, 0, 0, 1, 0, 0, 0, 0  …  1, 1, 1, 1, 1, 0, 1, 0, 0, 1]
  'A' => [0, 1, 0, 0, 1, 1, 1, 0, 1, 1  …  1, 1, 1, 0, 0, 1, 0, 1, 1, 0]
  'R' => [0, 0, 0, 1, 1, 1, 0, 1, 1, 0  …  1, 0, 1, 0, 0, 0, 0, 1, 0, 1]
  'G' => [0, 1, 0, 0, 1, 1, 0, 1, 1, 1  …  1, 1, 0, 0, 0, 0, 1, 0, 0, 1]
  'N' => [1, 0, 0, 1, 0, 0, 1, 0, 0, 1  …  1, 1, 1, 0, 0, 1, 1, 0, 0, 1]
  'Q' => [1, 0, 0, 0, 0, 1, 0, 0, 1, 0  …  0, 0, 1, 1, 0, 1, 1, 1, 0, 1]
  'M' => [1, 0, 0, 0, 1, 0, 1, 1, 0, 0  …  1, 1, 1, 1, 1, 0, 1, 0, 0, 0]
  'K' => [0, 0, 0, 1, 0, 0, 0, 1, 0, 0  …  0, 0, 1, 0, 1, 1, 0, 0, 1, 1]
  'F' => [0,

In [63]:
println(AA_list[7], AA_list[25])
println(cosine(embeddings_esm[7,:], embeddings_esm[25,:]))
println(hamming(AA_dict_esm['E'], AA_dict_esm['Z']))

println(AA_list[1], AA_list[8])
println(cosine(embeddings_esm[1,:], embeddings_esm[8,:]))
println(hamming(AA_dict_esm['A'], AA_dict_esm['G']))

println(AA_list[1], AA_list[15])
println(cosine(embeddings_esm[1,:], embeddings_esm[15,:]))
println(hamming(AA_dict_esm['A'], AA_dict_esm['P']))

println(AA_list[9], AA_list[12])
println(cosine(embeddings_esm[9,:], embeddings_esm[12,:]))
println(hamming(AA_dict_esm['H'], AA_dict_esm['K']))

EZ
0.9116949
0.4626
AG
0.9950109
0.5103
AP
0.8230575
0.5444
HK
0.985411
0.4482


In [47]:
using PyCall
using StatsBase
py"""
import pickle
infile = open("/home/mfat/Downloads/data(2).pkl",'rb')
embeddings_h = pickle.load(infile)
infile.close()
"""

embeddings_trans = PyArray(py"embeddings_h "o)
dt = fit(UnitRangeTransform, embeddings_trans, dims = 1)
embeddings_esm = StatsBase.transform(dt, embeddings_trans)

random_hdv = permutedims(hcat([bithdv() for i in 1:1024]...))

embeddings_trans_hdv = embeddings_trans*random_hdv
dt = fit(UnitRangeTransform, embeddings_trans_hdv, dims = 1)
embeddings_trans_hdv = round.(StatsBase.transform(dt, embeddings_trans_hdv))

hdvs = [convert(BitVector, embeddings_trans_hdv[i, :]) for i in 1:26]
AA_dict_trans = Dict(zip(AA_list, hdvs))

Dict{Char, BitVector} with 26 entries:
  'E' => [0, 1, 1, 0, 0, 0, 1, 1, 1, 0  …  0, 0, 1, 1, 1, 0, 0, 0, 0, 1]
  'Z' => [1, 1, 1, 0, 0, 1, 1, 1, 0, 1  …  0, 1, 1, 0, 1, 1, 1, 1, 0, 1]
  'X' => [1, 1, 1, 0, 0, 1, 1, 1, 0, 1  …  0, 1, 1, 0, 1, 1, 1, 1, 0, 1]
  'C' => [0, 0, 1, 1, 0, 0, 0, 0, 1, 0  …  0, 0, 1, 1, 0, 1, 0, 0, 1, 1]
  'B' => [1, 1, 1, 0, 0, 1, 1, 1, 0, 1  …  0, 1, 1, 0, 1, 1, 1, 1, 0, 1]
  'D' => [0, 1, 1, 0, 0, 1, 1, 1, 1, 1  …  0, 1, 1, 0, 1, 0, 1, 1, 0, 1]
  'A' => [0, 1, 1, 0, 0, 1, 1, 1, 1, 1  …  0, 1, 1, 0, 1, 0, 1, 1, 0, 0]
  'R' => [0, 1, 1, 0, 0, 0, 1, 1, 1, 1  …  0, 1, 1, 1, 1, 1, 1, 1, 0, 1]
  'G' => [0, 1, 1, 0, 0, 1, 1, 1, 1, 1  …  0, 1, 1, 0, 1, 0, 1, 1, 0, 1]
  'N' => [0, 1, 1, 0, 0, 1, 1, 1, 1, 1  …  0, 1, 1, 0, 1, 0, 1, 1, 0, 1]
  'Q' => [1, 1, 1, 0, 0, 1, 1, 1, 1, 1  …  0, 1, 1, 0, 1, 0, 1, 1, 0, 1]
  'M' => [1, 1, 1, 0, 0, 1, 1, 1, 1, 1  …  0, 1, 1, 0, 1, 1, 1, 1, 0, 1]
  'K' => [1, 1, 1, 0, 0, 1, 1, 1, 1, 1  …  0, 1, 1, 0, 1, 0, 1, 1, 0, 1]
  'F' => [0,

In [64]:
println(AA_list[7], AA_list[25])
println(cosine(embeddings_trans[7,:], embeddings_trans[25,:]))
println(hamming(AA_dict_trans['E'], AA_dict_trans['Z']))

println(AA_list[1], AA_list[8])
println(cosine(embeddings_trans[1,:], embeddings_trans[8,:]))
println(hamming(AA_dict_trans['A'], AA_dict_trans['G']))

println(AA_list[1], AA_list[15])
println(cosine(embeddings_trans[1,:], embeddings_trans[15,:]))
println(hamming(AA_dict_trans['A'], AA_dict_trans['P']))

println(AA_list[9], AA_list[12])
println(cosine(embeddings_trans[9,:], embeddings_trans[12,:]))
println(hamming(AA_dict_trans['H'], AA_dict_trans['K']))

EZ
0.99999964
0.3136
AG
0.9999999
0.0238
AP
0.99999905
0.3048
HK
1.0
0.0579


In [None]:
function convolved_embedding(sequence, tokens, k=3)
    """
    Simple 2-layer convolved embedding in hyperdimensional space
    """
    # layer 1
    kmer_hdvs = []
    for i in 1:length(sequence)-k+1
        kmer = sequence[i:i+k-1]
        aa_hdvs = [circshift(tokens[aa], k-l) for (l, aa) in enumerate(kmer)]
        push!(kmer_hdvs, bitbind(aa_hdvs))
    end
    
    # layer 2
    conv_kmer_hdvs = []
    for i in 1:length(kmer_hdvs)-k+1
        convolved_kmers = kmer_hdvs[i:i+k-1]
        conv_hdvs = [circshift(convolved_kmers[l], k-l) for (l, km) in enumerate(convolved_kmers)]
        push!(conv_kmer_hdvs, bitbind(conv_hdvs))
    end
    
    return bitadd(conv_kmer_hdvs)
end