In [1]:
import os
import numpy as np
import pandas as pd
import torch
import random
from streaming.base import MDSWriter

In [2]:
# Load the precomputed text-description embeddings. np.load on an .npz
# returns a lazy NpzFile that keeps the archive open, so use it as a context
# manager to release the file handle as soon as the array has been read.
with np.load('pairwise_representation.npz') as data:
    #convert to torch tensor (torch.tensor copies the array data)
    text_embeddings = torch.tensor(data['description_repr_array'])

#load text strings from text_sequence.txt, one entry per line
with open('text_sequence.txt', 'r') as f:
    #strip newline characters / surrounding whitespace
    text_strings = [line.strip() for line in f]

#load protein strings from protein_sequence.txt, one entry per line
with open('protein_sequence.txt', 'r') as f:
    #strip newline characters, then remove the spaces inside each sequence
    protein_strings = [line.strip().replace(' ', '') for line in f]

In [3]:
#every text line needs exactly one embedding row; zip would otherwise stop
#silently at the shorter input and hide a mismatch between the two files
if len(text_strings) != len(text_embeddings):
    raise ValueError(
        'text_sequence.txt has {} lines but description_repr_array has {} rows'.format(
            len(text_strings), len(text_embeddings)
        )
    )

#create a dictionary to map text strings to their embeddings
#repeated text strings collapse into one key (the last embedding seen is kept)
text_dict = dict(zip(text_strings, text_embeddings))

In [4]:
#number of keys in text_dict, i.e. distinct text strings that have an embedding
len(text_dict)

102234

In [14]:
# The values of text_dict are row views into the full text_embeddings tensor,
# and torch.save serializes the whole storage behind a view. Re-pack only the
# rows that are still referenced so the saved file does not also carry the
# embeddings of the duplicate text strings that were dropped from the dict.
if text_dict:
    packed_embeddings = torch.stack(list(text_dict.values()))
    compact_text_dict = dict(zip(text_dict.keys(), packed_embeddings))
else:
    #torch.stack rejects an empty list; an empty mapping is saved unchanged
    compact_text_dict = {}

#same mapping (text string -> embedding tensor) as text_dict, smaller on disk
torch.save(compact_text_dict, '../text2encoding.pt')

### Selected common ECs

In [5]:
#one row per entry: the protein sequence next to its text description
df = pd.DataFrame(
    dict(
        Sequence=protein_strings,
        Text=text_strings,
    )
)
df

Unnamed: 0,Sequence,Text
0,MVRLFYNPIKYLFYRRSCKKRLRKALKKLNFYHPPKECCQIYRLLE...,"Plays a role in virus cell tropism, and may be..."
1,MVRLFHNPIKCLFYRGSRKTREKKLRKSLKKLNFYHPPGDCCQIYR...,"Plays a role in virus cell tropism, and may be..."
2,MVRLFRNPIKCIFYRRSRKIQEKKLRKSLKKLNFYHPPEDCCQIYR...,"Plays a role in virus cell tropism, and may be..."
3,MVRLFRNPIKCIFYRRSRKIQEKKLRKSLKKLNFYHPPEDCCQIYR...,"Plays a role in virus cell tropism, and may be..."
4,MGNKESKYLEMCSEEAWLNIPNIFKCIFIRKLFYNKWLKYQEKKLK...,"Plays a role in virus cell tropism, and may be..."
...,...,...
441008,MTDPYSNFFTDWFKSNPFHHYPNSSTNPSPHPLPPVTPPSSFFFFP...,Transcriptional regulator required for normal ...
441009,MNSYETKGLSFESPSFIEWLKPQSSTTSSKSVLYRGKTRDAISRSN...,Probable transcriptional regulator. Belongs to...
441010,MLFSTVLSHRTLYILTCPNTLIHSYTHPHIHAYLAFTGFLTQLHHL...,Probable transcriptional regulator. Belongs to...
441011,MSNPACSNLFNNGCDHNSFNYSTSLSYIYNSHGSYYYSNTTNPNYI...,Probable transcriptional regulator. Belongs to...


In [6]:
#how many rows of df share each text description, most frequent first
df['Text'].value_counts()

Text
Component of the ubiquinol-cytochrome c reductase complex (complex III or cytochrome b-c1 complex) that is part of the mitochondrial respiratory chain. The b-c1 complex mediates electron transfer from ubiquinol to cytochrome c. Contributes to the generation of a proton gradient across the mitochondrial membrane that is then used for ATP synthesis. Binds 2 heme b groups non-covalently. The cytochrome bc1 complex contains 11 subunits: 3 respiratory subunits (MT-CYB, CYC1 and UQCRFS1), 2 core proteins (UQCRC1 and UQCRC2) and 6 low-molecular weight proteins (UQCRH/QCR6, UQCRB/QCR7, UQCRQ/QCR8, UQCR10/QCR9, UQCR11/QCR10 and a cleavage product of UQCRFS1). This cytochrome bc1 complex then forms a dimer. Heme 1 (or BL or b562) is low-potential and absorbs at about 562 nm, and heme 2 (or BH or b566) is high-potential and absorbs at about 566 nm. Belongs to the cytochrome b family. The full-length protein contains only eight transmembrane helices, not nine as predicted by bioinformatics to

In [7]:
#occurrences of every text description in df
text_counts = df["Text"].value_counts()
#keep only the descriptions that appear more than 500 times
frequent_texts = text_counts[text_counts > 500]
#draw 24 of them; the fixed seed makes the draw reproducible
selected = frequent_texts.sample(24, random_state=42)
selected

Text
carbamoyl phosphate + L-aspartate = H(+) + N-carbamoyl-L-aspartate + phosphate Pyrimidine metabolism; UMP biosynthesis via de novo pathway; (S)-dihydroorotate from bicarbonate: step 2/3. Belongs to the aspartate/ornithine carbamoyltransferase superfamily. ATCase family.                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    615
ATP + L-cysteine + tRNA(Cys) = AMP + diphosph

In [8]:
#take one of each entry in selected from df: for every selected text, draw a
#single (seeded) row among the rows of df that carry exactly that text
sampled_rows = [
    df[df["Text"] == text].sample(1, random_state=42)
    for text in selected.index
]

#concatenate once instead of growing a DataFrame inside the loop; this avoids
#the repeated copying and the concat-with-an-empty-frame starting point
df_selected = pd.concat(sampled_rows) if sampled_rows else pd.DataFrame()

#TODO: manually modify the sentences


### Prompts for inference in ProteinDT

In [None]:
#write it to a protein.txt and a text.txt file in the original format:
#for every selected row, one line with its index label, then one line with
#the content
os.makedirs("modified_24", exist_ok=True)

#the index labels of df_selected are integers, so they have to be formatted
#as text before writing (int + "\n" raises a TypeError)
with open("modified_24/text_sequence.txt", "w") as file:
    for index, row in df_selected.iterrows():
        file.write(f"{index}\n")
        file.write(row["Text"] + "\n")

with open("modified_24/protein_sequence.txt", "w") as file:
    for index, row in df_selected.iterrows():
        file.write(f"{index}\n")
        file.write(row["Sequence"] + "\n")


In [None]:
#build the prompt file used for generation with ProteinDT
#(../../examples/downstream_Text2Protein/): one text prompt per line,
#each selected text repeated on 200 consecutive lines
with open("step_01_text_retrieval.txt", "w") as file:
    for text in selected.index:
        prompt_line = text + "\n"
        file.write(prompt_line * 200)

### Save sharded dataset

In [15]:
#shuffle the df: frac=1 samples every row, i.e. returns all rows in random
#order; the fixed random_state makes the order reproducible
shuffled_df = df.sample(frac=1, random_state=42)

In [20]:
name = "swissprot-text"

#root folder of this dataset; create it if it is missing
dataset_root = f'../sharded_datasets/{name}'
os.makedirs(dataset_root, exist_ok=True)

#schema handed to MDSWriter: both fields are declared as 'str'
columns = {'sequence': 'str', 'text': 'str'}

for split in ['train']:
    #keep a CSV copy of the shuffled table next to the shard folder
    shuffled_df.to_csv(f'{dataset_root}/{split}.csv')

    sequences = shuffled_df['Sequence'].tolist()
    texts = shuffled_df['Text'].tolist()

    #the shards of each split go into their own sub-folder
    output_dir = f'{dataset_root}/{split}'
    os.makedirs(output_dir, exist_ok=True)

    #write one sample per (sequence, text) pair, in shuffled order
    with MDSWriter(out=output_dir, columns=columns) as out:
        for seq, text in zip(sequences, texts):
            sample = {'sequence': seq, 'text': text}
            out.write(sample)