In [2]:
import random
from collections import Counter
from tqdm import tqdm

import torch
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import esm

  from .autonotebook import tqdm as notebook_tqdm


In [31]:
import scipy
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.decomposition import PCA
from sklearn.pipeline import Pipeline
from sklearn.metrics import confusion_matrix
from scipy.spatial.distance import cosine, euclidean

```
python scripts/extract.py esm2_t33_650M_UR50D data/polymerases.fasta output/pol_reprs/ --repr_layers 0 32 33 --include mean per_tok
```

In [4]:
BP_FASTA_PATH = "../data/ncbi-bacteriophage-human-meta.fasta"
BP_EMB_PATH = "../output/ncbi-bacteriophage-human-meta_reprs" # Path to directory of embeddings
POL_FASTA_PATH = "../data/polymerases.fasta"
POL_EMB_PATH = "../output/pol_reprs"
EMB_LAYER = 33

In [5]:
header = "sp|P00573|RPOL_BPT7 T7 RNA polymerase OS=Escherichia phage T7 OX=10760 GN=1 PE=1 SV=2"
fn = f'{POL_EMB_PATH}/{header}.pt'
embs = torch.load(fn)
emb_full = embs['representations'][EMB_LAYER].numpy()
emb_full.shape

(883, 1280)

In [30]:
num_pca_components = 60
pca = PCA(num_pca_components)
emb_train_pca = pca.fit_transform(emb_full)
first_3_pcs = emb_train_pca[:,:3]
concat = np.concatenate((emb_train_pca[:,1], emb_train_pca[:,2], emb_train_pca[:,3]))
concat

array([ 0.34093913,  1.5745403 ,  0.84059936, ...,  1.5515023 ,
       -1.3945863 ,  0.98084205], dtype=float32)

In [18]:
column_names = [f'PC{i+1}' for i in range(emb_train_pca.shape[1])]
df_pca = pd.DataFrame(data=emb_train_pca, columns=column_names)
df_pca

Unnamed: 0,PC1,PC2,PC3,PC4,PC5,PC6,PC7,PC8,PC9,PC10,...,PC51,PC52,PC53,PC54,PC55,PC56,PC57,PC58,PC59,PC60
0,-1.030864,0.340940,1.311982,-0.131565,-0.090340,0.945702,-0.540043,-0.052349,-0.277934,0.624522,...,0.385101,0.240569,-0.256445,-0.566644,0.266996,0.147936,0.073352,0.357315,-0.144039,-0.069774
1,-2.106039,1.574543,1.552321,0.096282,-0.483227,0.637675,-0.499804,0.370680,0.198695,0.054993,...,-0.666571,0.102991,-0.654481,-0.938381,0.062526,0.246246,-0.268729,0.433944,-0.428055,-0.218815
2,-2.521439,0.840602,1.856283,0.693052,-0.804020,0.870917,-0.337703,0.264316,0.369523,-0.010375,...,-0.088197,0.130634,-0.871493,-0.360099,-0.251361,1.083913,0.159590,0.405571,-0.058522,-0.445614
3,-2.346556,-0.184050,2.489058,0.450000,-0.980110,0.895588,-0.137936,0.117267,-0.044523,-0.041353,...,-0.113024,0.284230,-0.522657,-0.609020,-0.062415,0.387962,0.029007,0.123185,-0.470590,-0.377726
4,-2.821247,0.943085,1.620764,0.294684,-0.900602,0.973819,-0.132614,0.506490,0.040704,0.219099,...,-0.475061,0.166890,-0.443193,-0.825755,-0.162715,0.206582,-0.124140,0.693672,-0.361744,-0.287003
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
878,-0.554717,0.939257,-0.999231,-0.021875,-0.766808,-0.238518,-1.419000,-0.281441,0.861862,0.583458,...,-0.526120,0.513637,-0.380624,0.709030,-0.239462,0.285849,-0.419388,-0.343810,0.472021,0.042697
879,1.266692,-2.319155,0.919116,-1.600543,0.047708,1.371742,-0.189883,-0.085715,0.384066,0.524285,...,-0.653752,0.882788,0.259164,0.286257,-0.166981,0.591659,0.075129,0.069988,0.620107,0.078419
880,2.397354,-1.701640,-0.759282,1.551502,-1.078396,1.999728,0.341463,1.080937,0.712508,0.444391,...,-0.668354,0.624846,0.665404,0.853355,-0.238125,0.840517,0.170138,-0.492183,0.324770,-0.252375
881,2.236915,-2.755748,0.737028,-1.394591,0.232159,1.560751,0.262047,0.200804,0.680733,-0.251049,...,-0.575725,0.486384,0.463153,0.626003,-0.071460,0.506172,0.123500,0.421387,0.319613,0.194179
