In [1]:
import pandas as pd
import numpy as np
import pickle
import random

from itertools import chain
from tqdm.auto import tqdm
from ast import literal_eval
from collections import defaultdict
from datasets import Dataset
tqdm.pandas()

In [None]:
from huggingface_hub import notebook_login
notebook_login()

In [6]:
def load_df(path):
    df = pd.read_csv(path, delimiter='\t').rename(columns={
        'EC number':'EC',
        'Gene Ontology (molecular function)':'MF',
        'Gene Ontology (biological process)':'BP',
        'Gene Ontology (cellular component)':'CC',
        'InterPro':'IP',
        'Gene3D':'3D',
        'Sequence':'seqs'
    }).astype('string')
    print(len(df))
    df['combined'] = df.progress_apply(lambda x: ' '.join(str(x[col]) for col in df.columns if col != 'seqs'), axis=1)
    df = df.sort_values(by='combined', key=lambda x: x.str.len(), ascending=False)
    df = df.drop_duplicates(subset='seqs', keep='first')
    df = df.drop('combined', axis=1)
    df = df.reset_index(drop=True)
    print(len(df))
    return df


def create_dictionary(input, start=1, name='ec'):
    id2label, label2id = {}, {}
    for index, entry in enumerate(input, start=start):
        entry = entry + '_' + name
        id2label[index] = entry
        label2id[entry] = index
    return id2label, label2id


def process_descriptors(input_list,
                        start=1,
                        name='ec',
                        id2label=None,
                        label2id=None,
                        filter_func=lambda d: d.strip()):
    col_list, new_col = [], []

    if id2label == None or label2id == None:
        for item in tqdm(input_list, desc=f'{name} make dicts'):
            descriptors = str(item).split(';')
            filtered_descriptors = [filter_func(d) for d in descriptors if filter_func(d) != 'None' and filter_func(d) != 'nan']
            col_list.extend(filtered_descriptors)
        col_list = sorted(list(set(col_list)))
        col_list.pop(0)
        len_col_list = len(col_list)
        id2label, label2id = create_dictionary(col_list, start=start, name=name)
    else:
        len_col_list = len(id2label.keys())

    for item in tqdm(input_list, desc=f'{name} make new column'):
        descriptors = str(item).split(';')
        filtered_descriptors = [filter_func(d) for d in descriptors if filter_func(d) != 'None' and filter_func(d) != 'nan']
        new_entry = [label2id.get(d+'_'+name, 0) for d in filtered_descriptors] or [0]
        new_col.append(new_entry)
    return new_col, id2label, label2id, len_col_list + start


def ec_processing(input_list, start=1, name='ec', id2label=None, label2id=None):
    return process_descriptors(input_list, start=start, name=name, id2label=id2label, label2id=label2id,
                               filter_func=lambda d: d.strip() if '-' not in d and 'n' not in d else '')

def go_processing(input_list, start=1, name='go', id2label=None, label2id=None):
    return process_descriptors(input_list, start=start, name=name, id2label=id2label, label2id=label2id,
                               filter_func=lambda d: d[d.find('[GO:')+1:d.find(']')].strip())

def cofactor_processing(input_list, start=1, name='co', id2label=None, label2id=None):
    return process_descriptors(input_list, start=start, name=name, id2label=id2label, label2id=label2id,
                               filter_func=lambda d: d[d.find('Name=')+5:].strip() if 'Name' in d else '')

def domain_processing(input_list, start=1, name='ip', id2label=None, label2id=None):
    return process_descriptors(input_list, start=start, name=name, id2label=id2label, label2id=label2id)


In [3]:
# Train data
df = load_df('swiss_prot_raw.tsv')
# EC
ecs = df['EC'].tolist()
new_ec_col, id2ec, ec2id, ec_len = ec_processing(ecs, name='ec', start=1)
# MF
mfs = df['MF'].tolist()
new_mf_col, id2mf, mf2id, mf_len = go_processing(mfs, name='mf', start=ec_len+1)
# BP
bps = df['BP'].tolist()
new_bp_col, id2bp, bp2id, bp_len = go_processing(bps, name='bp', start=mf_len+1)
# CC
ccs = df['CC'].tolist()
new_cc_col, id2cc, cc2id, cc_len = go_processing(ccs, name='cc', start=bp_len+1)
# IP
ips = df['IP'].tolist()
new_ip_col, id2ip, ip2id, ip_len = domain_processing(ips, name='ip', start=cc_len+1)
# 3D
threeds = df['3D'].tolist()
new_threed_col, id2threed, threed2id, threed_len = domain_processing(threeds, name='3d', start=ip_len+1)
# cofactor
cos = df['Cofactor'].tolist()
new_co_col, id2co, co2id, co_len = cofactor_processing(cos, name='co', start=threed_len+1)

461186


  0%|          | 0/461186 [00:00<?, ?it/s]

381212


ec make dicts:   0%|          | 0/381212 [00:00<?, ?it/s]

ec make new column:   0%|          | 0/381212 [00:00<?, ?it/s]

mf make dicts:   0%|          | 0/381212 [00:00<?, ?it/s]

mf make new column:   0%|          | 0/381212 [00:00<?, ?it/s]

bp make dicts:   0%|          | 0/381212 [00:00<?, ?it/s]

bp make new column:   0%|          | 0/381212 [00:00<?, ?it/s]

cc make dicts:   0%|          | 0/381212 [00:00<?, ?it/s]

cc make new column:   0%|          | 0/381212 [00:00<?, ?it/s]

ip make dicts:   0%|          | 0/381212 [00:00<?, ?it/s]

ip make new column:   0%|          | 0/381212 [00:00<?, ?it/s]

3d make dicts:   0%|          | 0/381212 [00:00<?, ?it/s]

3d make new column:   0%|          | 0/381212 [00:00<?, ?it/s]

co make dicts:   0%|          | 0/381212 [00:00<?, ?it/s]

co make new column:   0%|          | 0/381212 [00:00<?, ?it/s]

In [4]:
# make full dicts and check for no duplicates
all_id = [id2ec, id2mf, id2bp, id2cc, id2ip, id2threed, id2co]
all_label = [ec2id, mf2id, bp2id, cc2id, ip2id, threed2id, co2id]

id2label, label2id = {}, {}

key_counts = 0
for d in all_id:
    key_counts += len(d.keys())
    id2label.update(d)
print(key_counts)

key_counts = 0
for d in all_label:
    key_counts += len(d.keys())
    label2id.update(d)
print(key_counts)

for k, v in id2label.items():
    if k != label2id[v]:
        print(v)

for k, v in label2id.items():
    if k != id2label[v]:
        print(v)

print(len(id2label.keys()), len(label2id.keys()))

58945
58945
58945 58945


In [5]:
all_cols = [new_ec_col, new_mf_col, new_bp_col, new_cc_col, new_ip_col, new_threed_col, new_co_col]
combined_list = [sorted([item for item in list(chain.from_iterable(element)) if item != 0]) for element in zip(*all_cols)]

df['EC'] = new_ec_col
df['Cofactor'] = new_co_col
df['MF'] = new_mf_col
df['BP'] = new_bp_col
df['CC'] = new_cc_col
df['IP'] = new_ip_col
df['3D'] = new_threed_col
df['combined'] = combined_list
df['string_combined'] = df['combined'].astype('string')

unique_values = df['string_combined'].unique()

# Create a dictionary to store the rows with the longest 'combined' string for each unique value
unique_dict = {}

# Iterate over the DataFrame and update the dictionary
for _, row in tqdm(df.iterrows(), total=len(df)):
    combined_value = row['string_combined']
    if combined_value not in unique_dict:
        unique_dict[combined_value] = row
    if combined_value == '[0]':
        print(combined_value)

# Create the final DataFrame from the dictionary values
df_final = pd.DataFrame(unique_dict.values()).drop(columns=['string_combined'])

print("Final DataFrame length: ", len(df_final))

df_final.to_csv('processed_swiss_prot.csv', index=False)

  0%|          | 0/381212 [00:00<?, ?it/s]

Final DataFrame length:  94675


In [8]:
# Test data
df = load_df('trembl_all_aspects_raw.tsv')
# EC
ecs = df['EC'].tolist()
new_ec_col, id2ec, ec2id, ec_len = ec_processing(ecs, name='ec', id2label=id2label, label2id=label2id)
# MF
mfs = df['MF'].tolist()
new_mf_col, id2mf, mf2id, mf_len = go_processing(mfs, name='mf', id2label=id2label, label2id=label2id)
# BP
bps = df['BP'].tolist()
new_bp_col, id2bp, bp2id, bp_len = go_processing(bps, name='bp', id2label=id2label, label2id=label2id)
# CC
ccs = df['CC'].tolist()
new_cc_col, id2cc, cc2id, cc_len = go_processing(ccs, name='cc', id2label=id2label, label2id=label2id)
# IP
ips = df['IP'].tolist()
new_ip_col, id2ip, ip2id, ip_len = domain_processing(ips, name='ip', id2label=id2label, label2id=label2id)
# 3D
threeds = df['3D'].tolist()
new_threed_col, id2threed, threed2id, threed_len = domain_processing(threeds, name='3d', id2label=id2label, label2id=label2id)
# cofactor
cos = df['Cofactor'].tolist()
new_co_col, id2co, co2id, co_len = cofactor_processing(cos, name='co', id2label=id2label, label2id=label2id)

all_cols = [new_ec_col, new_mf_col, new_bp_col, new_cc_col, new_ip_col, new_threed_col, new_co_col]
combined_list = [sorted([item for item in list(chain.from_iterable(element)) if item != 0]) for element in zip(*all_cols)]

df['EC'] = new_ec_col
df['Cofactor'] = new_co_col
df['MF'] = new_mf_col
df['BP'] = new_bp_col
df['CC'] = new_cc_col
df['IP'] = new_ip_col
df['3D'] = new_threed_col
df['combined'] = combined_list
df['string_combined'] = df['combined'].astype('string')

unique_values = df['string_combined'].unique()

# Create a dictionary to store the rows with the longest 'combined' string for each unique value
unique_dict = {}

# Iterate over the DataFrame and update the dictionary
for _, row in tqdm(df.iterrows(), total=len(df)):
    combined_value = row['string_combined']
    if combined_value not in unique_dict:
        unique_dict[combined_value] = row
    if combined_value == '[0]':
        print(combined_value)

# Create the final DataFrame from the dictionary values
df_final = pd.DataFrame(unique_dict.values()).drop(columns=['string_combined'])

print("Final DataFrame length: ", len(df_final))

df_final.to_csv('processed_trembl.csv', index=False)

34747


  0%|          | 0/34747 [00:00<?, ?it/s]

34560


ec make new column:   0%|          | 0/34560 [00:00<?, ?it/s]

mf make new column:   0%|          | 0/34560 [00:00<?, ?it/s]

bp make new column:   0%|          | 0/34560 [00:00<?, ?it/s]

cc make new column:   0%|          | 0/34560 [00:00<?, ?it/s]

ip make new column:   0%|          | 0/34560 [00:00<?, ?it/s]

3d make new column:   0%|          | 0/34560 [00:00<?, ?it/s]

co make new column:   0%|          | 0/34560 [00:00<?, ?it/s]

  0%|          | 0/34560 [00:00<?, ?it/s]

Final DataFrame length:  5878


In [None]:
df = pd.read_csv('processed_swiss_prot.csv', converters={
    'EC': literal_eval,
    'Cofactor': literal_eval,
    'MF': literal_eval,
    'BP': literal_eval,
    'CC': literal_eval,
    'IP': literal_eval,
    '3D': literal_eval,
    'combined': literal_eval
})

df['combined'] = df['combined'].apply(set)
#df['EC'] = df['EC'].apply(set)
#df['Cofactor'] = df['Cofactor'].apply(set)
#df['MF'] = df['MF'].apply(set)
#df['BP'] = df['BP'].apply(set)
#df['CC'] = df['CC'].apply(set)
#df['IP'] = df['IP'].apply(set)
#df['3D'] = df['3D'].apply(set)
df.head(1)

In [None]:
sets = df['combined'].tolist()

n = len(sets)
all_indices = set(range(n))
aspect_dict = defaultdict(set)

for i, s in tqdm(enumerate(sets), total=n):
    for item in s:
        aspect_dict[item].add(i)

In [None]:
triplets = {
    'EC': [],
    'Cofactor': [],
    'MF': [],
    'BP': [],
    'CC': [],
    'IP': [],
    '3D': []
}

for i, row in tqdm(df.iterrows(), total=len(df)):
    p = row['seqs']

    for aspect in triplets.keys():
        item = random.choice(row[aspect])
        if item != 0:
            item_idxs = aspect_dict[item] - {i}
            if len(item_idxs) > 0:
                a_idx = random.choice(tuple(item_idxs))
                a_item = df.loc[a_idx, 'seqs']

                n_idx = random.choice(tuple(all_indices - item_idxs))
                n_item = df.loc[n_idx, 'seqs']

                triplets[aspect].append((p, a_item, n_item))
            else:
                continue
        else:
            continue

In [None]:
for i, (k, v) in enumerate(triplets.items()):
    print(len(v))

ps, ans, ns, acs = [], [], [], []
for i, (k, v) in enumerate(triplets.items()):
    for trip in v:
        p, a, n = trip
        ps.append(p)
        ans.append(a)
        ns.append(n)
        acs.append(i)

data = Dataset.from_dict({
    'positives':ps,
    'anchors':ans,
    'negatives':ns,
    'aspects':acs
})

In [None]:
data.push_to_hub('lhallee/triplets', split='epoch3')

In [None]:
import pickle
with open('triplets_epoch_1.pkl', 'wb') as f:
    pickle.dump(triplets, f)

In [None]:
# For looking / documenting swiss prot duplicates
seqs = df['seqs'].tolist()
entry_ids = df['Entry'].tolist()
print(len(seqs), len(list(set(seqs))))

counts = {}
duplicates = {}

for i, seq in enumerate(seqs):
    if seq in counts:
        counts[seq] += 1
        duplicates[seq].append(entry_ids[i])
    else:
        counts[seq] = 1
        duplicates[seq] = [entry_ids[i]]

# Write duplicates to a text file
with open("swiss_prot_duplicates.txt", "w") as file:
    for seq, ids in duplicates.items():
        if len(ids) > 1:
            count = len(ids)
            file.write(f"{count}\t{' '.join(ids)}\n")
            file.write(f"{seq}\n")