In [1]:
import pandas as pd
import numpy as np
import random
import matplotlib.pyplot as plt

from itertools import chain
from tqdm.auto import tqdm
from ast import literal_eval
from collections import defaultdict, Counter
from datasets import Dataset, concatenate_datasets, load_dataset, DatasetDict
tqdm.pandas()

In [2]:
from huggingface_hub import notebook_login
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [None]:
# Functions
def load_df(path):
    """Read a tab-separated annotation table and keep one row per unique sequence.

    Selected columns are renamed to short names (EC, MF, BP, CC, seqs) and every
    column is cast to pandas' ``string`` dtype. When a sequence occurs more than
    once, the row whose non-sequence fields form the longest space-joined text is
    the one that is kept. The row count is printed before and after deduplication.

    Args:
        path: Anything ``pd.read_csv`` accepts as a source.

    Returns:
        The deduplicated DataFrame with a fresh 0..n-1 index.
    """
    column_aliases = {
        'EC number': 'EC',
        'Gene Ontology (molecular function)': 'MF',
        'Gene Ontology (biological process)': 'BP',
        'Gene Ontology (cellular component)': 'CC',
        'Sequence': 'seqs',
    }
    table = pd.read_csv(path, delimiter='\t').rename(columns=column_aliases).astype('string')
    print(len(table))

    # Every column except the sequence contributes to the "how much annotation" score.
    annotation_cols = [col for col in table.columns if col != 'seqs']

    def join_annotations(row):
        return ' '.join(str(row[col]) for col in annotation_cols)

    # progress_apply is provided by the notebook-level tqdm.pandas() call.
    table['combined'] = table.progress_apply(join_annotations, axis=1)

    # Longest annotation text first, so keep='first' retains the richest duplicate.
    deduplicated = (
        table
        .sort_values(by='combined', key=lambda text: text.str.len(), ascending=False)
        .drop_duplicates(subset='seqs', keep='first')
        .drop('combined', axis=1)
        .reset_index(drop=True)
    )
    print(len(deduplicated))
    return deduplicated


def create_dictionary(input, start=1, name='ec'):
    """Build paired id->label and label->id lookups for a sequence of labels.

    Each label is suffixed with ``'_' + name`` and numbered consecutively
    beginning at ``start``.

    Args:
        input: Iterable of label strings.
        start: Integer id given to the first label.
        name: Suffix appended (after an underscore) to every label.

    Returns:
        ``(id2label, label2id)``. If a label repeats, ``label2id`` keeps the
        last id assigned to it.
    """
    suffix = '_' + name
    id2label = {idx: label + suffix for idx, label in enumerate(input, start=start)}
    # Inverting in insertion order means a repeated label resolves to its last id.
    label2id = {tagged: idx for idx, tagged in id2label.items()}
    return id2label, label2id


# Textual forms that missing values take after str(): None, float NaN and pandas' pd.NA.
_MISSING_TOKENS = frozenset({'none', 'nan', '<na>'})


def _clean_descriptors(item, filter_func):
    """Split one raw ';'-separated cell and return its cleaned, non-missing descriptors."""
    cleaned = []
    for raw in str(item).split(';'):
        # Discard missing-value sentinels *before* filter_func sees them: slicing
        # filters would otherwise mangle them (e.g. 'nan' -> 'na') and '<NA>' would
        # slip past the post-filter check and end up as a label of its own.
        if raw.strip().lower() in _MISSING_TOKENS:
            continue
        descriptor = filter_func(raw)
        if descriptor and descriptor.lower() not in _MISSING_TOKENS:
            cleaned.append(descriptor)
    return cleaned


def process_descriptors(input_list,
                        start=1,
                        name='ec',
                        id2label=None,
                        label2id=None,
                        filter_func=lambda d: d.strip()):
    """Encode ';'-separated descriptor strings as lists of integer label ids.

    Args:
        input_list: Sequence of raw cells; each is converted with ``str`` and split
            on ';'. It is iterated twice when the vocabulary is built here.
        start: First id to assign when the vocabulary is built here.
        name: Suffix used for dictionary keys (``'<descriptor>_<name>'``) and in the
            progress-bar descriptions.
        id2label: Existing id->label mapping. When either mapping is None, both are
            rebuilt from ``input_list``.
        label2id: Existing label->id mapping (see ``id2label``).
        filter_func: Callable applied to every raw token; falsy results are dropped.

    Returns:
        ``(new_col, id2label, label2id, next_start)`` where ``new_col`` holds one id
        list per input cell (``[0]`` when a cell yields no descriptor) and
        ``next_start`` is the number of labels plus ``start``.

    Raises:
        KeyError: If supplied mappings do not contain a descriptor found in
            ``input_list``.
    """
    if id2label is None or label2id is None:
        vocabulary = set()
        for item in tqdm(input_list, desc=f'{name} make dicts'):
            vocabulary.update(_clean_descriptors(item, filter_func))
        ordered_labels = sorted(vocabulary)
        len_col_list = len(ordered_labels)
        id2label, label2id = create_dictionary(ordered_labels, start=start, name=name)
    else:
        len_col_list = len(id2label)

    new_col = []
    for item in tqdm(input_list, desc=f'{name} make new column'):
        descriptors = _clean_descriptors(item, filter_func)
        # 0 is the placeholder id for "no descriptor in this cell".
        new_col.append([label2id[d + '_' + name] for d in descriptors] or [0])

    return new_col, id2label, label2id, len_col_list + start


def ec_processing(input_list, start=1, name='ec', id2label=None, label2id=None):
    """Encode EC-number cells via ``process_descriptors``.

    A token is kept (stripped) only when it contains neither '-' nor the letter
    'n'; every other token is discarded. Presumably this removes partial
    (``1.1.1.-``) and preliminary (``1.1.1.n1``) EC numbers — confirm against the
    data source.

    Arguments and the return value are those of ``process_descriptors``.
    """
    def keep_complete_ec(descriptor):
        # The 'n' test is case-insensitive so that the upper-case missing-value
        # text '<NA>' (what pd.NA stringifies to) is rejected as well as
        # 'nan' / 'None'.
        if '-' in descriptor or 'n' in descriptor.lower():
            return ''
        return descriptor.strip()

    return process_descriptors(input_list, start=start, name=name, id2label=id2label, label2id=label2id,
                               filter_func=keep_complete_ec)


def go_processing(input_list, start=1, name='go', id2label=None, label2id=None):
    """Encode Gene Ontology cells via ``process_descriptors``.

    Each token is reduced to the identifier inside its ``[GO:...]`` bracket
    (returned without the brackets, e.g. ``'GO:0005524'``). Tokens without such a
    bracket are discarded.

    Arguments and the return value are those of ``process_descriptors``.
    """
    def extract_go_id(descriptor):
        begin = descriptor.find('[GO:')
        # Search for the closing bracket only after the GO marker: term names may
        # themselves contain brackets ("[acyl-carrier-protein] ... [GO:0004314]").
        end = descriptor.find(']', begin) if begin != -1 else -1
        if begin == -1 or end == -1:
            # No GO id present (e.g. a missing value); never slice on a -1 index.
            return ''
        return descriptor[begin + 1:end].strip()

    return process_descriptors(input_list, start=start, name=name, id2label=id2label, label2id=label2id,
                               filter_func=extract_go_id)


def cofactor_processing(input_list, start=1, name='co', id2label=None, label2id=None):
    """Encode cofactor cells via ``process_descriptors``.

    Only tokens containing ``Name=`` are kept; the text following that marker
    (stripped) becomes the descriptor. All other tokens are discarded.

    Arguments and the return value are those of ``process_descriptors``.
    """
    marker = 'Name='

    def extract_cofactor_name(descriptor):
        position = descriptor.find(marker)
        if position == -1:
            # Guard on the same marker that is sliced on, so a token that merely
            # mentions "Name" cannot produce a slice taken from a -1 index.
            return ''
        return descriptor[position + len(marker):].strip()

    return process_descriptors(input_list, start=start, name=name, id2label=id2label, label2id=label2id,
                               filter_func=extract_cofactor_name)


def domain_processing(input_list, start=1, name='ip', id2label=None, label2id=None):
    """Encode cells via ``process_descriptors`` using its default token filter.

    Arguments and the return value are those of ``process_descriptors``; no
    ``filter_func`` is supplied, so that function's default applies.
    """
    forwarded = dict(start=start, name=name, id2label=id2label, label2id=label2id)
    return process_descriptors(input_list, **forwarded)


def replace_df(df, all_cols, col_names=None):
    """Write encoded label columns into ``df`` and keep one row per label combination.

    Args:
        df: DataFrame to update. It is modified in place: the encoded columns plus
            ``'combined'`` and ``'string_combined'`` are (over)written on it.
        all_cols: List of encoded columns, each a list with one id list per row.
        col_names: Column names to store ``all_cols`` under, in order. Defaults to
            the first ``len(all_cols)`` of EC, MF, BP, CC, IP, 3D, Cofactor, so the
            original four-column call is unchanged and the seven-column call used
            for the TrEMBL data works too.

    Returns:
        A new DataFrame with the first row of every distinct ``'combined'`` label
        list, without the helper ``'string_combined'`` column. Original index
        labels and column dtypes are preserved.

    Raises:
        ValueError: If ``col_names`` does not match ``all_cols`` in length, or no
            default name exists for every column.
    """
    default_names = ('EC', 'MF', 'BP', 'CC', 'IP', '3D', 'Cofactor')
    if col_names is None:
        if len(all_cols) > len(default_names):
            raise ValueError(
                f'{len(all_cols)} columns given but only {len(default_names)} default names exist; pass col_names')
        col_names = default_names[:len(all_cols)]
    elif len(col_names) != len(all_cols):
        raise ValueError(f'col_names has {len(col_names)} entries but all_cols has {len(all_cols)}')

    # One sorted id list per row across all columns; 0 is the "no label" placeholder.
    combined_list = [sorted(label for label in chain.from_iterable(row) if label != 0)
                     for row in zip(*all_cols)]

    for col_name, encoded in zip(col_names, all_cols):
        df[col_name] = encoded
    df['combined'] = combined_list
    # Lists are unhashable, so their text form serves as the deduplication key.
    df['string_combined'] = df['combined'].astype('string')

    # A row with no label in any column has an empty combined list, i.e. '[]'.
    n_unlabelled = int((df['string_combined'] == '[]').sum())
    if n_unlabelled:
        print(f'rows with no labels in any column: {n_unlabelled}')

    df_final = df.drop_duplicates(subset='string_combined', keep='first').drop(columns=['string_combined'])
    return df_final

In [None]:
# Train data
# Load the table, encode each annotation column into integer ids, merge the
# per-column vocabularies, then keep one row per distinct label combination.
df = load_df('doubles.tsv')
# EC
ecs = df['EC'].tolist()
new_ec_col, id2ec, ec2id, ec_len = ec_processing(ecs, name='ec', start=1)
# MF
# Each *_len returned above is (label count + start); passing it + 1 as the next
# start therefore leaves one unused id between consecutive vocabularies.
mfs = df['MF'].tolist()
new_mf_col, id2mf, mf2id, mf_len = go_processing(mfs, name='mf', start=ec_len+1)
# BP
bps = df['BP'].tolist()
new_bp_col, id2bp, bp2id, bp_len = go_processing(bps, name='bp', start=mf_len+1)
# CC
ccs = df['CC'].tolist()
new_cc_col, id2cc, cc2id, cc_len = go_processing(ccs, name='cc', start=bp_len+1)

# make full dicts and check for no duplicates
all_id = [id2ec, id2mf, id2bp, id2cc]
all_label = [ec2id, mf2id, bp2id, cc2id]

id2label, label2id = {}, {}

# Sum of the per-vocabulary sizes; compare with the merged sizes printed at the end
# of the checks — a smaller merged size would mean keys collided.
key_counts = 0
for d in all_id:
    key_counts += len(d.keys())
    id2label.update(d)
print(key_counts)

key_counts = 0
for d in all_label:
    key_counts += len(d.keys())
    label2id.update(d)
print(key_counts)

# Round-trip checks: anything printed here is an entry on which the two merged
# mappings disagree. No output means they are mutual inverses.
for k, v in id2label.items():
    if k != label2id[v]:
        print(v)

for k, v in label2id.items():
    if k != id2label[v]:
        print(v)

print(len(id2label.keys()), len(label2id.keys()))

all_cols = [new_ec_col, new_mf_col, new_bp_col, new_cc_col]

# Note: replace_df also writes the encoded columns into `df` itself.
df_final = replace_df(df, all_cols)

print(len(df_final))

df_final.to_csv('processed_doubles.csv', index=False)

In [None]:
# Test data
# Same encoding pipeline as the training cell, with three extra columns (IP, 3D,
# Cofactor). No id2label/label2id are passed, so the vocabularies are rebuilt from
# this file: ids here are independent of the training cell's ids, and the
# id2ec/ec2id/... variables from that cell are overwritten.
df = load_df('trembl_all_aspects_raw.tsv')
# EC
ecs = df['EC'].tolist()
new_ec_col, id2ec, ec2id, ec_len = ec_processing(ecs, name='ec', start=1)
# MF
mfs = df['MF'].tolist()
new_mf_col, id2mf, mf2id, mf_len = go_processing(mfs, name='mf', start=ec_len+1)
# BP
bps = df['BP'].tolist()
new_bp_col, id2bp, bp2id, bp_len = go_processing(bps, name='bp', start=mf_len+1)
# CC
ccs = df['CC'].tolist()
new_cc_col, id2cc, cc2id, cc_len = go_processing(ccs, name='cc', start=bp_len+1)
# IP
ips = df['IP'].tolist()
new_ip_col, id2ip, ip2id, ip_len = domain_processing(ips, name='ip', start=cc_len+1)
# 3D
threeds = df['3D'].tolist()
new_threed_col, id2threed, threed2id, threed_len = domain_processing(threeds, name='3d', start=ip_len+1)
# cofactor
cos = df['Cofactor'].tolist()
new_co_col, id2co, co2id, co_len = cofactor_processing(cos, name='co', start=threed_len+1)


# NOTE(review): seven encoded columns are passed here — verify that replace_df
# accepts more than the four columns used for the training data.
all_cols = [new_ec_col, new_mf_col, new_bp_col, new_cc_col, new_ip_col, new_threed_col, new_co_col]
df = replace_df(df, all_cols)

# Drop incompletely annotated rows: for every column in turn, discard rows whose
# value is [0], an empty list, or the string '[]'. Other values pass the mask.
for col in df.columns:
    mask = df[col].apply(lambda x: x != [0]
                        and (not isinstance(x, list) or len(x) > 0)
                        and (not isinstance(x, str) or x.strip() != '[]'))
    df = df.loc[mask]

# Exclude sequences already present in the training table. `df_final` comes from
# the training cell, so that cell must have been run in this kernel first.
df_filtered = df[~df['seqs'].isin(df_final['seqs'])]

print(len(df_filtered))
df_filtered.to_csv('processed_trembl.csv', index=False)

In [None]:
# Reload the processed training table. The list-valued columns were written to the
# CSV as text, so each is parsed back into a Python list with literal_eval.
df = pd.read_csv(
    'processed_doubles.csv',
    converters=dict.fromkeys(('EC', 'MF', 'BP', 'CC', 'combined'), literal_eval),
)

# Label sets per aspect, one set per row.
single_aspect_dict = {
    aspect: [set(labels) for labels in df[aspect]]
    for aspect in ('EC', 'MF', 'BP', 'CC')
}
ec_aspect = single_aspect_dict['EC']
mf_aspect = single_aspect_dict['MF']
bp_aspect = single_aspect_dict['BP']
cc_aspect = single_aspect_dict['CC']

# The combined labels become sets so that overlaps can be computed later.
df['combined'] = df['combined'].apply(set)

print(df.head(1))

sets = df['combined'].tolist()

In [None]:
def retrieve_pairs(seqs, indices):
    """Look up both members of every index pair.

    Args:
        seqs: Indexable collection of items.
        indices: Iterable of ``(i, j)`` positions into ``seqs``.

    Returns:
        Two lists ``(A, B)`` of equal length, where ``A[k]`` and ``B[k]`` are the
        items at the first and second position of the k-th pair.
    """
    resolved = [(seqs[left], seqs[right]) for left, right in indices]
    A = [first for first, _ in resolved]
    B = [second for _, second in resolved]
    return A, B


def flatten_and_count(tuples_list):
    """Return the most frequent element across all tuples, with its count.

    Args:
        tuples_list: Iterable of iterables of hashable elements.

    Returns:
        ``(element, count)`` for the most common element.

    Raises:
        IndexError: If there are no elements at all.
    """
    tally = Counter(chain.from_iterable(tuples_list))
    top_entry = tally.most_common(1)[0]
    return top_entry[0], top_entry[1]


def make_pairs_nonredundant(pairs, max_count=10):
    """Greedily thin a list of index pairs.

    Pairs are visited in the order given. A pair is kept unless it (or its
    reverse) was already kept, or either of its indices already appears in
    ``max_count`` kept pairs.

    Args:
        pairs: Iterable of ``(i, j)`` tuples of hashable indices.
        max_count: Maximum number of kept pairs any single index may take part in.

    Returns:
        List of the kept pairs, in their original order.
    """
    index_count = defaultdict(int)
    seen_pairs = set()
    nonredundant_pairs = []
    # The previous pre-sort keyed on index_count while it was still empty, so every
    # key was 0 and the stable sort returned the input order unchanged; iterating
    # `pairs` directly gives the same result without copying the list.
    for pair in tqdm(pairs, desc='Making non redundant'):
        i, j = pair
        if (i, j) in seen_pairs or (j, i) in seen_pairs:
            continue
        if index_count[i] < max_count and index_count[j] < max_count:
            nonredundant_pairs.append(pair)
            seen_pairs.add((i, j))
            index_count[i] += 1
            index_count[j] += 1
    return nonredundant_pairs


def calculate_similar_sets(sets):
    """Find all ordered index pairs whose sets reach given Jaccard similarities.

    Similarity is ``len(a & b) / len(a | b)``. Every ordered pair ``(i, j)`` with
    ``i != j`` is evaluated, so a qualifying pair appears as both ``(i, j)`` and
    ``(j, i)``. Pairs are emitted in ascending ``i``, then ascending ``j``.

    Args:
        sets: Indexable sequence of ``set`` objects.

    Returns:
        ``(pairs, counts)`` — two 4-tuples aligned with the thresholds
        0.1, 0.3, 0.5 and 0.7. ``pairs[k]`` is the list of ``(i, j)`` tuples with
        similarity >= threshold k, and ``counts[k]`` is its length.
    """
    thresholds = (0.1, 0.3, 0.5, 0.7)
    pairs = tuple([] for _ in thresholds)
    counts = [0] * len(thresholds)
    n_sets = len(sets)
    for i in tqdm(range(n_sets), desc='Measuring set similarity'):
        set_i = sets[i]
        # Plain ranges replace the per-iteration `set_range - {i}` copy and make
        # the visiting order explicit instead of relying on set iteration order.
        for j in range(n_sets):
            if j == i:
                continue
            set_j = sets[j]
            union_size = len(set_i.union(set_j))
            if union_size == 0:
                # Two empty sets share nothing; skip them instead of dividing by zero.
                continue
            similarity = len(set_i & set_j) / union_size
            for slot, threshold in enumerate(thresholds):
                # Thresholds ascend, so the first miss rules out all later ones.
                if similarity < threshold:
                    break
                counts[slot] += 1
                pairs[slot].append((i, j))
    return pairs, tuple(counts)

In [None]:
# Score every ordered pair of rows by label-set overlap; the work grows with len(sets) squared.
pairs, counts = calculate_similar_sets(sets)

In [None]:
# Share of pairs reaching each similarity threshold. `counts` tallies ordered pairs
# (both (i, j) and (j, i)); the denominator len(sets) ** 2 also includes the i == j
# pairs, which are never scored.
print('Percents')
for count in counts:
    print(f'{count * 100 / (len(sets) ** 2)}%')
# Raw pair-list sizes, in millions, before thinning.
print('Redundant pair length')
for pair in pairs:
    print(len(pair) / 1e6)

In [None]:
# Thin each threshold's pair list: repeated/reversed pairs are dropped and each row
# index is used in at most max_count (default 10) kept pairs.
non_10 = make_pairs_nonredundant(pairs[0])
non_30 = make_pairs_nonredundant(pairs[1])
non_50 = make_pairs_nonredundant(pairs[2])
non_70 = make_pairs_nonredundant(pairs[3])
len(non_10), len(non_30), len(non_50), len(non_70)

In [None]:
# Sequences in row order of `df`, the same frame `sets` was taken from.
seqs = df['seqs'].tolist()
seqs[0]

In [None]:
# Turn index pairs into sequence pairs for each similarity threshold.
a10, b10 = retrieve_pairs(seqs, non_10)
a30, b30 = retrieve_pairs(seqs, non_30)
a50, b50 = retrieve_pairs(seqs, non_50)
a70, b70 = retrieve_pairs(seqs, non_70)

# One split per threshold; columns 'a' and 'b' hold the two sequences of a pair.
data = DatasetDict({
    '10': Dataset.from_dict({
        'a': a10,
        'b': b10
    }),
    '30': Dataset.from_dict({
        'a': a30,
        'b': b30
    }),
    '50': Dataset.from_dict({
        'a': a50,
        'b': b50
    }),
    '70': Dataset.from_dict({
        'a': a70,
        'b': b70
    })
})
data

In [None]:
# Single dataset holding only the 0.5-threshold pairs.
# NOTE(review): this rebinds `data`, replacing the four-split DatasetDict built in
# the previous cell. Run top to bottom, the 'ProteinDoublesAll' push that follows
# would upload only this dataset, yet the saved output further down shows that
# repo with 10/30/50/70 splits — the cells were probably run out of order.
data = Dataset.from_dict({
    'a':a50,
    'b':b50
})
data

In [None]:
# Upload whatever `data` currently refers to as a private Hub dataset.
data.push_to_hub('lhallee/ProteinDoublesAll', private=True)

In [None]:
# Hold out 5000 pairs, then halve the hold-out into validation and test sets.
# Both splits are seeded (the second one previously was not), so re-running the
# notebook reproduces the same validation/test partition.
split = data.train_test_split(test_size=5000, seed=42)
test_valid = split['test'].train_test_split(test_size=2500, seed=42)
valid_set = test_valid['train']
test_set = test_valid['test']
print(f"Validation set size: {valid_set.num_rows}")
print(f"Test set size: {test_set.num_rows}")

In [None]:
# Bundle the three partitions produced by the split cell.
double_dataset = DatasetDict({
    'train':split['train'],
    'valid':valid_set,
    'test':test_set
})
double_dataset

In [None]:
# Upload the train/valid/test bundle as a private Hub dataset.
double_dataset.push_to_hub('lhallee/ProteinDouble50', private=True)

In [3]:
# Download the per-threshold pair datasets back from the Hub.
full_data = load_dataset('lhallee/ProteinDoublesAll')
full_data

Downloading readme:   0%|          | 0.00/605 [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/96.9M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/118M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/85.6M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/39.4M [00:00<?, ?B/s]

Generating 10 split:   0%|          | 0/275452 [00:00<?, ? examples/s]

Generating 30 split:   0%|          | 0/234053 [00:00<?, ? examples/s]

Generating 50 split:   0%|          | 0/158003 [00:00<?, ? examples/s]

Generating 70 split:   0%|          | 0/61766 [00:00<?, ? examples/s]

DatasetDict({
    10: Dataset({
        features: ['a', 'b'],
        num_rows: 275452
    })
    30: Dataset({
        features: ['a', 'b'],
        num_rows: 234053
    })
    50: Dataset({
        features: ['a', 'b'],
        num_rows: 158003
    })
    70: Dataset({
        features: ['a', 'b'],
        num_rows: 61766
    })
})

In [5]:
# Download the train/valid/test bundle; its 'valid' and 'test' splits are reused below.
# NOTE(review): the saved output of this cell is a NameError for `dat`, which does
# not match this code — likely stale output from an earlier typo; re-run the cell.
data = load_dataset('lhallee/ProteinDouble50')

NameError: name 'dat' is not defined

In [7]:
# Use the 0.3-threshold pairs as the training split.
train = full_data["30"]
train

Dataset({
    features: ['a', 'b'],
    num_rows: 234053
})

In [8]:
# Reuse the validation/test splits of 'ProteinDouble50' (loaded into `data` above).
# NOTE(review): nothing here checks whether these pairs also occur in `train` — confirm.
valid = data['valid']
test = data['test']
valid, test

(Dataset({
     features: ['a', 'b'],
     num_rows: 2500
 }),
 Dataset({
     features: ['a', 'b'],
     num_rows: 2500
 }))

In [9]:
# Combine the 0.3-threshold training pairs with the reused validation/test splits.
new_data = DatasetDict({
    'train':train,
    'valid':valid,
    'test':test
})

In [10]:
# Display the assembled splits and their row counts.
new_data

DatasetDict({
    train: Dataset({
        features: ['a', 'b'],
        num_rows: 234053
    })
    valid: Dataset({
        features: ['a', 'b'],
        num_rows: 2500
    })
    test: Dataset({
        features: ['a', 'b'],
        num_rows: 2500
    })
})

In [11]:
# Upload the assembled splits to the Hub.
# NOTE(review): unlike the earlier pushes, private=True is not passed here —
# confirm the intended visibility of this dataset.
new_data.push_to_hub('lhallee/ProteinDouble30')

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/235 [00:00<?, ?ba/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/3 [00:00<?, ?ba/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/3 [00:00<?, ?ba/s]

CommitInfo(commit_url='https://huggingface.co/datasets/lhallee/ProteinDouble30/commit/f517188b2da808461aa4f75fe1d0f85381a2a2ab', commit_message='Upload dataset', commit_description='', oid='f517188b2da808461aa4f75fe1d0f85381a2a2ab', pr_url=None, pr_revision=None, pr_num=None)