In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pandas as pd
import numpy as np
from evo.dataset import FastaDataset

fasta_fpath = "/homefs/home/lux70/storage/data/pfam/Pfam-A.fasta"
ds = FastaDataset(fasta_fpath)

100%|██████████| 10.6G/10.6G [01:40<00:00, 114MB/s] 
100%|██████████| 10.6G/10.6G [01:15<00:00, 150MB/s]


In [63]:
from torch.utils.data import DataLoader

dataloader = DataLoader(ds, batch_size=1024, shuffle=True)
batch = next(iter(dataloader))

In [64]:
batch[0][:10]

('A0A699YNH5_HAELA/23-130 A0A699YNH5.1 PF08433.14;KTI12;',
 'A0A1T4SZU4_9HYPH/127-242 A0A1T4SZU4.1 PF00717.27;Peptidase_S24;',
 'A0A135LN36_PENPA/24-169 A0A135LN36.1 PF19327.3;Ap4A_phos_N;',
 'F6SKP2_MONDO/254-308 F6SKP2.2 PF07443.17;HARP;',
 'A0A7J6R6W1_PEROL/19-296 A0A7J6R6W1.1 PF07690.20;MFS_1;',
 'A0A392MEI1_9FABA/32-120 A0A392MEI1.1 PF13966.10;zf-RVT;',
 'A0A8H6CR98_9LECA/9-265 A0A8H6CR98.1 PF01267.21;F-actin_cap_A;',
 'A0A4R6UVI0_9ACTN/3-302 A0A4R6UVI0.1 PF01156.23;IU_nuc_hydro;',
 'A0A8B9BKV1_9AVES/20-97 A0A8B9BKV1.1 PF00238.23;Ribosomal_L14;',
 'A0A4Q3ZC85_9RHOB/16-109 A0A4Q3ZC85.1 PF01243.24;Putative_PNPOx;')

In [65]:
def header_to_accession(header):
    subid = header.split(" ")[-1]
    return subid.split(".")[0]

accessions_batch = list(map(lambda x: header_to_accession(x), batch[0][:10]))
accessions_batch

['PF08433',
 'PF00717',
 'PF19327',
 'PF07443',
 'PF07690',
 'PF13966',
 'PF01267',
 'PF01156',
 'PF00238',
 'PF01243']

In [66]:
fam_to_clan_df = pd.read_csv(
    "/homefs/home/lux70/storage/data/pfam/Pfam-A.clans.tsv",
    sep="\t",
    header=None
)
header = ["accession", "clan", "short_name", "gene_name", "description"]
fam_to_clan_df.columns = header

print(fam_to_clan_df.shape)
fam_to_clan_df.head(n=20)

(20795, 5)


Unnamed: 0,accession,clan,short_name,gene_name,description
0,PF00001,CL0192,GPCR_A,7tm_1,7 transmembrane receptor (rhodopsin family)
1,PF00002,CL0192,GPCR_A,7tm_2,7 transmembrane receptor (Secretin family)
2,PF00003,CL0192,GPCR_A,7tm_3,7 transmembrane sweet-taste receptor of 3 GCPR
3,PF00004,CL0023,P-loop_NTPase,AAA,ATPase family associated with various cellular...
4,PF00005,CL0023,P-loop_NTPase,ABC_tran,ABC transporter
5,PF00006,CL0023,P-loop_NTPase,ATP-synt_ab,"ATP synthase alpha/beta family, nucleotide-bin..."
6,PF00007,CL0079,Cystine-knot,Cys_knot,Cystine-knot domain
7,PF00008,CL0001,EGF,EGF,EGF-like domain
8,PF00009,CL0023,P-loop_NTPase,GTP_EFTU,Elongation factor Tu GTP binding domain
9,PF00010,,,HLH,Helix-loop-helix DNA-binding domain


In [67]:
print(fam_to_clan_df.clan.isnull().sum())
print(len(fam_to_clan_df.clan.unique()))

11868
660


In [68]:
print(fam_to_clan_df.description.isnull().sum())
print(len(fam_to_clan_df.description.unique()))

0
20086


In [69]:
print(len(fam_to_clan_df.accession.unique()))
assert len(fam_to_clan_df.accession.unique()) == fam_to_clan_df.shape[0]

20795


In [75]:
clans = fam_to_clan_df.clan.dropna().unique()
print(len(clans))
clans.sort()
print(clans[:10])
print(clans[-10:])

clans_to_idx = dict(zip(clans, np.arange(len(clans))))
clans_to_idx

659
['CL0001' 'CL0003' 'CL0004' 'CL0005' 'CL0007' 'CL0010' 'CL0012' 'CL0013'
 'CL0014' 'CL0015']
['CL0761' 'CL0763' 'CL0764' 'CL0765' 'CL0766' 'CL0767' 'CL0768' 'CL0769'
 'CL0770' 'CL0771']


{'CL0001': 0,
 'CL0003': 1,
 'CL0004': 2,
 'CL0005': 3,
 'CL0007': 4,
 'CL0010': 5,
 'CL0012': 6,
 'CL0013': 7,
 'CL0014': 8,
 'CL0015': 9,
 'CL0016': 10,
 'CL0018': 11,
 'CL0020': 12,
 'CL0021': 13,
 'CL0022': 14,
 'CL0023': 15,
 'CL0025': 16,
 'CL0026': 17,
 'CL0027': 18,
 'CL0028': 19,
 'CL0029': 20,
 'CL0030': 21,
 'CL0031': 22,
 'CL0032': 23,
 'CL0033': 24,
 'CL0036': 25,
 'CL0037': 26,
 'CL0039': 27,
 'CL0040': 28,
 'CL0041': 29,
 'CL0042': 30,
 'CL0043': 31,
 'CL0044': 32,
 'CL0045': 33,
 'CL0046': 34,
 'CL0047': 35,
 'CL0048': 36,
 'CL0050': 37,
 'CL0051': 38,
 'CL0052': 39,
 'CL0053': 40,
 'CL0054': 41,
 'CL0055': 42,
 'CL0056': 43,
 'CL0057': 44,
 'CL0059': 45,
 'CL0060': 46,
 'CL0061': 47,
 'CL0062': 48,
 'CL0063': 49,
 'CL0064': 50,
 'CL0065': 51,
 'CL0066': 52,
 'CL0067': 53,
 'CL0068': 54,
 'CL0069': 55,
 'CL0070': 56,
 'CL0071': 57,
 'CL0072': 58,
 'CL0073': 59,
 'CL0074': 60,
 'CL0075': 61,
 'CL0077': 62,
 'CL0079': 63,
 'CL0080': 64,
 'CL0081': 65,
 'CL0082': 66,
 'CL0

In [76]:
accession_to_clan = fam_to_clan_df.groupby("accession").first().filter(['accession','clan'], axis=1)
accession_to_clan = accession_to_clan.to_dict()['clan']

accession_to_clan

{'PF00001': 'CL0192',
 'PF00002': 'CL0192',
 'PF00003': 'CL0192',
 'PF00004': 'CL0023',
 'PF00005': 'CL0023',
 'PF00006': 'CL0023',
 'PF00007': 'CL0079',
 'PF00008': 'CL0001',
 'PF00009': 'CL0023',
 'PF00010': None,
 'PF00011': 'CL0190',
 'PF00012': 'CL0108',
 'PF00013': 'CL0007',
 'PF00014': None,
 'PF00015': None,
 'PF00016': 'CL0036',
 'PF00017': 'CL0541',
 'PF00018': 'CL0010',
 'PF00019': 'CL0079',
 'PF00020': 'CL0607',
 'PF00021': 'CL0117',
 'PF00022': 'CL0108',
 'PF00023': 'CL0465',
 'PF00024': 'CL0168',
 'PF00025': 'CL0023',
 'PF00026': 'CL0129',
 'PF00027': 'CL0029',
 'PF00028': 'CL0159',
 'PF00029': 'CL0375',
 'PF00030': 'CL0333',
 'PF00031': 'CL0320',
 'PF00032': None,
 'PF00033': 'CL0328',
 'PF00034': 'CL0318',
 'PF00035': 'CL0196',
 'PF00036': 'CL0220',
 'PF00037': 'CL0344',
 'PF00038': None,
 'PF00039': 'CL0451',
 'PF00040': 'CL0602',
 'PF00041': 'CL0159',
 'PF00042': 'CL0090',
 'PF00043': 'CL0497',
 'PF00044': 'CL0063',
 'PF00045': None,
 'PF00046': 'CL0123',
 'PF00047': 

In [80]:
def header_to_clan_idx(header):
    subid = header.split(" ")[-1]
    accession = subid.split(".")[0]
    clan_id = accession_to_clan[accession]
    if clan_id is None:
        return len(clans)  # dummy idx for unknown clan
    else:
        return clans_to_idx[clan_id]

In [81]:
batch = next(iter(dataloader))

clans_batch = list(map(lambda x: header_to_clan_idx(x), batch[0]))
clans_batch

[12,
 281,
 151,
 659,
 659,
 34,
 25,
 152,
 12,
 12,
 329,
 659,
 151,
 100,
 160,
 25,
 659,
 659,
 348,
 70,
 659,
 56,
 659,
 659,
 93,
 397,
 218,
 215,
 49,
 90,
 217,
 174,
 49,
 379,
 198,
 659,
 13,
 326,
 7,
 12,
 310,
 659,
 659,
 15,
 166,
 93,
 149,
 5,
 659,
 659,
 25,
 154,
 49,
 51,
 659,
 659,
 100,
 25,
 659,
 648,
 15,
 659,
 528,
 659,
 407,
 34,
 373,
 113,
 194,
 153,
 8,
 9,
 348,
 154,
 25,
 13,
 49,
 659,
 103,
 46,
 49,
 166,
 98,
 310,
 25,
 137,
 113,
 330,
 90,
 49,
 49,
 659,
 191,
 25,
 306,
 574,
 16,
 659,
 132,
 322,
 659,
 142,
 25,
 425,
 182,
 86,
 659,
 659,
 15,
 76,
 87,
 133,
 659,
 15,
 659,
 159,
 659,
 100,
 412,
 119,
 100,
 259,
 49,
 153,
 169,
 311,
 147,
 359,
 659,
 85,
 659,
 131,
 19,
 132,
 15,
 659,
 299,
 100,
 221,
 318,
 15,
 659,
 131,
 49,
 90,
 49,
 659,
 154,
 5,
 25,
 152,
 53,
 135,
 337,
 244,
 360,
 47,
 659,
 659,
 455,
 305,
 85,
 81,
 37,
 159,
 15,
 659,
 18,
 48,
 659,
 25,
 250,
 659,
 659,
 156,
 25,
 100,
 14,
 1

In [83]:
max(clans_batch)

659

# Check single layer decoder

In [84]:
batch[1]

('YLAGHAMYVLGNVRIREGFLDEAFDLHLRTLKCFEPTYGVSHHKTGNARLKVAWHLARVGKYEEALVHLQTALKIYQS',
 'GRRKTSAARVFLTAGSGNISINDRSIDVYFGREVARMIVRQPFEIVDMVNKFDMKITVNGGGNFGQAGAIRHGIARALLQYDENLRSPLRKAGFLTRDARQVERKKVGLHKARKRPQFSKR',
 'GAFFGDKMSPLSDTTNLASTVVNVDLFDHIRNMSWTTVPAFIISFIIFVFLSPTEAIAHFEKITLLKEGLLELNIVHWYSLIPFVILALLAILRVSAIITLSTGIISATIIGLIFNQSVTFKEALSILYFGYESTSGIEEIDALLSRGGMESMMFSISLVLLALSMGGLLFKLTILPVLLSSMKRLLINAPSLMCSAAASAIGINFLIGEQYLSILLTGNAFREHFKKAGLEAKNLSRILEDSGTVVNPLVPWSVCGIFITNILGVATVDYLPFAFFCLLSPILTLVFSITGL',
 'LLRAHKAELERHHGHRITDDMRIAMHAMLRCQTEKLGQSQWYCAHCHFDDRRPLSCGHRHCPQCLHRTTSDWLKRQKQKLLPVHYFMVTFTL',
 'LQGAGDQGLMFGYACRETDELMPLPITLAHRLAARLADVRKQGIIPYLRPDGKTQVTIEYEGDRPVRVDTVVVSSQHSADIDLKTLLMPDVKEHVVEPELRALDIDVDDYRLLVNPTGRF',
 'RAVIRGIATNSDGRTPGIASPSAEAQAAAIRSAYANAGLTDFNETRYLECHGTGTQAGDPQEVQGVASVFSQSRSPDQPLVIGSIKSNIGHSEPAAGVSGMLKAVLIVENGSIPGNPTF',
 'QLGFGVWRLEEDDAPKVVGAAIDAGYRHIDTAQGYDNEAGVGRAIAEASVGREELFITSKLRNGHQGYDSALRSLDESLARLQLDYLDLFLIHWPAPQHDRYADTWRAFVEAQKAGKVRSIGVSNFLPQHIQRIIDETGVTPAVN

In [85]:
from plaid.esmfold import batch_encode_sequences, esmfold_v1

esmfold = esmfold_v1()
esmfold.cuda()

Creating ESMFold...
ESMFold model loaded in 0.63 minutes.


In [92]:
outs = batch_encode_sequences(batch[1])

In [97]:
print(outs[0].shape)

x = outs[0][:16, :512].cuda()
mask = outs[1][:16, :512].cuda()

torch.Size([1024, 801])


In [104]:
s = esmfold.embed_for_folding_trunk(x, mask)[0]

In [107]:
from plaid.decoder import FullyConnectedNetwork

net = FullyConnectedNetwork.from_pretrained(ckpt_path="/homefs/home/lux70/storage/plaid/checkpoints/sequence_decoder/112385/last.ckpt")
net = net.cuda()
logits = net(s)

In [111]:
pred_toks = logits.argmax(-1)

In [113]:
(x == pred_toks).sum() / np.prod(x.shape)

tensor(0.9907, device='cuda:0')