In [1]:
import pandas as pd
import numpy as np
import pickle
import random

from itertools import chain
from tqdm.auto import tqdm
from ast import literal_eval
from collections import defaultdict
from datasets import Dataset
tqdm.pandas()

In [3]:
from huggingface_hub import notebook_login
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [3]:
df = pd.read_csv('swiss_prot_vec.tsv', delimiter='\t').rename(columns={
    'EC number':'EC',
    'Gene Ontology (molecular function)':'MF',
    'Gene Ontology (biological process)':'BP',
    'Gene Ontology (cellular component)':'CC',
    'InterPro':'IP',
    'Gene3D':'3D',
    'Sequence':'seqs'
}).astype('string')

print(len(df))

# Combine the columns (excluding 'seqs') into a single string
df['combined'] = df.progress_apply(lambda x: ' '.join(str(x[col]) for col in df.columns if col != 'seqs'), axis=1)

df = df.sort_values(by='combined', key=lambda x: x.str.len(), ascending=False)

# Drop duplicates based on the 'seqs' column, keeping the first occurrence (longest combined string)
df = df.drop_duplicates(subset='seqs', keep='first')

# Drop the temporary 'combined' column
#df = df.drop('combined', axis=1)

# Reset the index if needed
df = df.reset_index(drop=True)

len(df)

570830


  0%|          | 0/570830 [00:00<?, ?it/s]

482684

In [4]:
def create_dictionary(input, start=1, name='ec'):
    id2label, label2id = {}, {}
    for index, entry in enumerate(input, start=start):
        entry = entry + '_' + name
        id2label[index] = entry
        label2id[entry] = index
    return id2label, label2id


def process_descriptors(input_list, start=1, name='ec', filter_func=lambda d: d.strip()):
    col_list, new_col = [], []
    for item in tqdm(input_list, desc=f'{name} make dict'):
        descriptors = str(item).split(';')
        filtered_descriptors = [filter_func(d) for d in descriptors if filter_func(d) != 'None' and filter_func(d) != 'nan']
        col_list.extend(filtered_descriptors)
    col_list = sorted(list(set(col_list)))
    col_list.pop(0)
    id2label, label2id = create_dictionary(col_list, start=start, name=name)
    
    for item in tqdm(input_list, desc=f'{name} make new column'):
        descriptors = str(item).split(';')
        filtered_descriptors = [filter_func(d) for d in descriptors if filter_func(d) != 'None' and filter_func(d) != 'nan']
        new_entry = [label2id.get(d+'_'+name, 0) for d in filtered_descriptors] or [0]
        new_col.append(new_entry)
    return new_col, id2label, label2id, len(col_list) + start


def ec_processing(input_list, start=1, name='ec'):
    return process_descriptors(input_list, start=start, name=name,
                               filter_func=lambda d: d.strip() if '-' not in d and 'n' not in d else '')

def go_processing(input_list, start=1, name='go'):
    return process_descriptors(input_list, start=start, name=name,
                               filter_func=lambda d: d[d.find('[GO:')+1:d.find(']')].strip())

def cofactor_processing(input_list, start=1, name='co'):
    return process_descriptors(input_list, start=start, name=name,
                               filter_func=lambda d: d[d.find('Name=')+5:].strip() if 'Name' in d else '')

def domain_processing(input_list, start=1, name='ip'):
    return process_descriptors(input_list, start=start, name=name)


In [5]:
# EC
ecs = df['EC'].tolist()
new_ec_col, id2ec, ec2id, ec_len = ec_processing(ecs, name='ec', start=1)
# MF
mfs = df['MF'].tolist()
new_mf_col, id2mf, mf2id, mf_len = go_processing(mfs, name='mf', start=ec_len+1)
# BP
bps = df['BP'].tolist()
new_bp_col, id2bp, bp2id, bp_len = go_processing(bps, name='bp', start=mf_len+1)
# CC
ccs = df['CC'].tolist()
new_cc_col, id2cc, cc2id, cc_len = go_processing(ccs, name='cc', start=bp_len+1)
# IP
ips = df['IP'].tolist()
new_ip_col, id2ip, ip2id, ip_len = domain_processing(ips, name='ip', start=cc_len+1)
# 3D
threeds = df['3D'].tolist()
new_threed_col, id2threed, threed2id, threed_len = domain_processing(threeds, name='3d', start=ip_len+1)
# cofactor
cos = df['Cofactor'].tolist()
new_co_col, id2co, co2id, co_len = cofactor_processing(cos, name='co', start=threed_len+1)

ec make dict:   0%|          | 0/482684 [00:00<?, ?it/s]

ec make new column:   0%|          | 0/482684 [00:00<?, ?it/s]

mf make dict:   0%|          | 0/482684 [00:00<?, ?it/s]

mf make new column:   0%|          | 0/482684 [00:00<?, ?it/s]

bp make dict:   0%|          | 0/482684 [00:00<?, ?it/s]

bp make new column:   0%|          | 0/482684 [00:00<?, ?it/s]

cc make dict:   0%|          | 0/482684 [00:00<?, ?it/s]

cc make new column:   0%|          | 0/482684 [00:00<?, ?it/s]

ip make dict:   0%|          | 0/482684 [00:00<?, ?it/s]

ip make new column:   0%|          | 0/482684 [00:00<?, ?it/s]

3d make dict:   0%|          | 0/482684 [00:00<?, ?it/s]

3d make new column:   0%|          | 0/482684 [00:00<?, ?it/s]

co make dict:   0%|          | 0/482684 [00:00<?, ?it/s]

co make new column:   0%|          | 0/482684 [00:00<?, ?it/s]

In [6]:
# make full dicts and check for no duplicates
all_id = [id2ec, id2mf, id2bp, id2cc, id2ip, id2threed, id2co]
all_label = [ec2id, mf2id, bp2id, cc2id, ip2id, threed2id, co2id]

id2label, label2id = {}, {}

key_counts = 0
for d in all_id:
    key_counts += len(d.keys())
    id2label.update(d)
print(key_counts)

key_counts = 0
for d in all_label:
    key_counts += len(d.keys())
    label2id.update(d)
print(key_counts)

for k, v in id2label.items():
    if k != label2id[v]:
        print(v)

for k, v in label2id.items():
    if k != id2label[v]:
        print(v)

len(id2label.keys()), len(label2id.keys())

73461
73461


(73461, 73461)

In [7]:
all_cols = [new_ec_col, new_mf_col, new_bp_col, new_cc_col, new_ip_col, new_threed_col, new_co_col]
combined_list = [sorted([item for item in list(chain.from_iterable(element)) if item != 0]) for element in zip(*all_cols)]

In [8]:
df['EC'] = new_ec_col
df['Cofactor'] = new_co_col
df['MF'] = new_mf_col
df['BP'] = new_bp_col
df['CC'] = new_cc_col
df['IP'] = new_ip_col
df['3D'] = new_threed_col
df['combined'] = combined_list
df['string_combined'] = df['combined'].astype('string')
len(df)

(482684,
     Entry      EC Cofactor                                                 MF  \
 0  Q02248  [5487]      [0]  [10222, 10224, 5770, 11962, 12390, 13122, 8651...   
 1  Q9WU82  [5487]      [0]  [10222, 10224, 5770, 11962, 12390, 13122, 8651...   
 2  Q63673     [0]      [0]        [6887, 13312, 6902, 9954, 6690, 7005, 7026]   
 3  Q62226     [0]      [0]        [6887, 13312, 6902, 9954, 6690, 7005, 7026]   
 4  P04202  [5487]      [0]  [5809, 6699, 9662, 8651, 6935, 9878, 9986, 101...   
 
                                                   BP  \
 0  [26917, 19952, 19951, 23152, 16547, 27100, 208...   
 1  [26917, 19952, 19951, 23152, 16547, 27100, 208...   
 2  [22764, 23379, 23243, 16023, 13897, 23242, 165...   
 3  [22764, 23379, 23243, 16023, 13897, 23242, 165...   
 4  [14373, 16512, 18904, 14687, 15390, 24738, 245...   
 
                                                   CC  \
 0  [31929, 33069, 33229, 32175, 32176, 32174, 323...   
 1  [31929, 33069, 33229, 32175, 32176,

In [21]:
unique_values = df['string_combined'].unique()

# Create a dictionary to store the rows with the longest 'combined' string for each unique value
unique_dict = {}

# Iterate over the DataFrame and update the dictionary
for _, row in tqdm(df.iterrows(), total=len(df)):
    combined_value = row['string_combined']
    if combined_value not in unique_dict:
        unique_dict[combined_value] = row
    if combined_value == '[0]':
        print(combined_value)

# Create the final DataFrame from the dictionary values
df_final = pd.DataFrame(unique_dict.values()).drop(columns=['string_combined'])

print("Final DataFrame length:", len(df_final))

  0%|          | 0/482684 [00:00<?, ?it/s]

Final DataFrame length: 141387
Final DataFrame head:
    Entry      EC Cofactor                                                 MF  \
0  Q02248  [5487]      [0]  [10222, 10224, 5770, 11962, 12390, 13122, 8651...   
1  Q9WU82  [5487]      [0]  [10222, 10224, 5770, 11962, 12390, 13122, 8651...   
2  Q63673     [0]      [0]        [6887, 13312, 6902, 9954, 6690, 7005, 7026]   
3  Q62226     [0]      [0]        [6887, 13312, 6902, 9954, 6690, 7005, 7026]   
4  P04202  [5487]      [0]  [5809, 6699, 9662, 8651, 6935, 9878, 9986, 101...   

                                                  BP  \
0  [26917, 19952, 19951, 23152, 16547, 27100, 208...   
1  [26917, 19952, 19951, 23152, 16547, 27100, 208...   
2  [22764, 23379, 23243, 16023, 13897, 23242, 165...   
3  [22764, 23379, 23243, 16023, 13897, 23242, 165...   
4  [14373, 16512, 18904, 14687, 15390, 24738, 245...   

                                                  CC  \
0  [31929, 33069, 33229, 32175, 32176, 32174, 323...   
1  [31929, 

In [23]:
df_final.to_csv('processed_swiss_prot.csv', index=False)

In [2]:
df = pd.read_csv('processed_swiss_prot.csv', converters={
    'EC': literal_eval,
    'Cofactor': literal_eval,
    'MF': literal_eval,
    'BP': literal_eval,
    'CC': literal_eval,
    'IP': literal_eval,
    '3D': literal_eval,
    'combined': literal_eval
})

df['combined'] = df['combined'].apply(set)
#df['EC'] = df['EC'].apply(set)
#df['Cofactor'] = df['Cofactor'].apply(set)
#df['MF'] = df['MF'].apply(set)
#df['BP'] = df['BP'].apply(set)
#df['CC'] = df['CC'].apply(set)
#df['IP'] = df['IP'].apply(set)
#df['3D'] = df['3D'].apply(set)

In [3]:
sets = df['combined'].tolist()

In [4]:
n = len(sets)
all_indices = set(range(n))
aspect_dict = defaultdict(set)

for i, s in tqdm(enumerate(sets), total=n):
    for item in s:
        aspect_dict[item].add(i)

  0%|          | 0/141387 [00:00<?, ?it/s]

In [7]:
"""overlapping = [set()] * n

for i in tqdm(range(n)):
    seen = {j for item in sets[i] for j in aspect_dict[item]}
    seen.remove(i)
    overlapping[i] = seen

non_overlapping = [all_indices - overlapping[i] - {i} for i in range(n)]"""

  0%|          | 0/141387 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [5]:
triplets = {
    'EC': [],
    'Cofactor': [],
    'MF': [],
    'BP': [],
    'CC': [],
    'IP': [],
    '3D': []
}

for i, row in tqdm(df.iterrows(), total=len(df)):
    p = row['seqs']
    
    for aspect in triplets.keys():
        item = random.choice(row[aspect])
        if item != 0:
            item_idxs = aspect_dict[item] - {i}
            if len(item_idxs) > 0:
                a_idx = random.choice(tuple(item_idxs))
                a_item = df.loc[a_idx, 'seqs']
                n_idx = random.choice(tuple(all_indices - item_idxs))
                n_item = df.loc[n_idx, 'seqs']
                triplets[aspect].append((p, a_item, n_item))
            else:
                continue
        else:
            continue

  0%|          | 0/141387 [00:00<?, ?it/s]

In [None]:
for i, (k, v) in enumerate(triplets.items()):
    print(len(v))

127204
5161
140410
140952
141335
105121
94537


In [None]:
ps, ans, ns, acs = [], [], [], []
for i, (k, v) in enumerate(triplets.items()):
    for trip in v:
        p, a, n = trip
        ps.append(p)
        ans.append(a)
        ns.append(n)
        acs.append(i)

In [None]:
data = Dataset.from_dict({
    'positives':ps,
    'anchors':ans,
    'negatives':ns,
    'aspects':acs
})

In [None]:
data.push_to_hub('lhallee/triplets', split='epoch3')

Uploading the dataset shards:   0%|          | 0/3 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/252 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/252 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/252 [00:00<?, ?ba/s]

README.md:   0%|          | 0.00/402 [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


CommitInfo(commit_url='https://huggingface.co/datasets/lhallee/triplets/commit/0d436e766f77a59708b0363ad41b27287c49658d', commit_message='Upload dataset', commit_description='', oid='0d436e766f77a59708b0363ad41b27287c49658d', pr_url=None, pr_revision=None, pr_num=None)

In [19]:
triplets['EC'][0]

('MATQADLMELDMAMEPDRKAAVSHWQQQSYLDSGIHSGATTTAPSLSGKGNPEEEDVDTSQVLYEWEQGFSQSFTQEQVADIDGQYAMTRAQRVRAAMFPETLDEGMQIPSTQFDAAHPTNVQRLAEPSQMLKHAVVNLINYQDDAELATRAIPELTKLLNDEDQVVVNKAAVMVHQLSKKEASRHAIMRSPQMVSAIVRTMQNTNDVETARCTAGTLHNLSHHREGLLAIFKSGGIPALVKMLGSPVDSVLFYAITTLHNLLLHQEGAKMAVRLAGGLQKMVALLNKTNVKFLAITTDCLQILAYGNQESKLIILASGGPQALVNIMRTYTYEKLLWTTSRVLKVLSVCSSNKPAIVEAGGMQALGLHLTDPSQRLVQNCLWTLRNLSDAATKQEGMEGLLGTLVQLLGSDDINVVTCAAGILSNLTCNNYKNKMMVCQVGGIEALVRTVLRAGDREDITEPAICALRHLTSRHQEAEMAQNAVRLHYGLPVVVKLLHPPSHWPLIKATVGLIRNLALCPANHAPLREQGAIPRLVQLLVRAHQDTQRRTSMGGTQQQFVEGVRMEEIVEGCTGALHILARDVHNRIVIRGLNTIPLFVQLLYSPIENIQRVAAGVLCELAQDKEAAEAIEAEGATAPLTELLHSRNEGVATYAAAVLFRMSEDKPQDYKKRLSVELTSSLFRTEPMAWNETADLGLDIGAQGEALGYRQDDPSYRSFHSGGYGQDALGMDPMMEHEMGGHHPGADYPVDGLPDLGHAQDLMDGLPPGDSNQLAWFDTDL',
 'MATALPRTLGELQLYRILQKANLLSYFDAFIQQGGDDVQQLCEAGEEEFLEIMALVGMASKPLHVRRLQKALRDWVTNPGLFNQPLTSLPVSSIPIYKLPEGSPTWLGISCNSYERSSSAREPHLKVPKCAATTCVQSLGQGKSEVGSLALQSVSESRLWQGHHTTESEHSLSPADLGSPASPKESSEALDAAAALSVAECVERMAPTLPKS

In [23]:
import pickle
with open('triplets_epoch_1.pkl', 'wb') as f:
    pickle.dump(triplets, f)

In [94]:
# For looking / documenting swiss prot duplicates
seqs = df['seqs'].tolist()
entry_ids = df['Entry'].tolist()
print(len(seqs), len(list(set(seqs))))

counts = {}
duplicates = {}

for i, seq in enumerate(seqs):
    if seq in counts:
        counts[seq] += 1
        duplicates[seq].append(entry_ids[i])
    else:
        counts[seq] = 1
        duplicates[seq] = [entry_ids[i]]

# Write duplicates to a text file
with open("swiss_prot_duplicates.txt", "w") as file:
    for seq, ids in duplicates.items():
        if len(ids) > 1:
            count = len(ids)
            file.write(f"{count}\t{' '.join(ids)}\n")
            file.write(f"{seq}\n")