In [None]:
# Final datasets: tremblwithsequences.csv
# swissprotwithsequences.csv was divided into two datasets, as shown in the DatasetGeneration notebook:
# only_in_swiss.csv (this becomes part of test set 2) and
# common_in_swiss_trembl.csv (this is used as the Swissprot-only dataset)

# To balance the datasets we used median = 10 for the Swissprot-only dataset and median = 250 for the Swissprot + Trembl dataset

In [None]:
#Benchmark 3(Swissprot Only Unbalanced Dataset)

import pandas as pd
import math

class DatasetSplitter:
    """Split a table into test / validation / train sets, cluster-wise per EC number.

    The input frame must contain the columns 'EC', 'UniRef50', 'UniRef90' and
    'UniRef100'. For every EC number the rows are grouped by a cluster column
    and whole clusters are assigned to one split, so rows of the same cluster
    never end up in different splits.

    The split is attempted at three cluster levels in turn ('UniRef50', then
    'UniRef90', then 'UniRef100'); EC numbers with fewer than 3 clusters at
    one level are handed on to the next level. Whatever is still unsplit after
    the last level is returned as ``extra``.

    Attributes:
        valid, test, train: accumulated splits.
        uniref50extra, uniref90extra, extra: rows left unsplit after each level.
        orphans: created empty and returned by ``process``; this class never
            adds rows to it.
    """

    def __init__(self, df):
        self.df = df
        self.valid = pd.DataFrame(columns=df.columns)
        self.test = pd.DataFrame(columns=df.columns)
        self.train = pd.DataFrame(columns=df.columns)
        self.uniref50extra = pd.DataFrame(columns=df.columns)
        self.uniref90extra = pd.DataFrame(columns=df.columns)
        self.extra = pd.DataFrame(columns=df.columns)
        self.orphans = pd.DataFrame(columns=df.columns)

    @staticmethod
    def _holdout_size(num_groups):
        """Number of clusters that go to the test set (and, equally, to validation)."""
        if num_groups >= 10:
            # Integer ceiling of 10 %. math.ceil(0.1 * n) is avoided because
            # 0.1 * 30 evaluates to 3.0000000000000004 and would round up to 4.
            return (num_groups + 9) // 10
        if num_groups >= 6:
            return 2
        return 1

    @staticmethod
    def _extend(frame, parts):
        """Return ``frame`` with ``parts`` appended, using a single concat."""
        if not parts:
            return frame
        # An empty starting frame is left out so its object-dtype columns do
        # not influence the dtypes of the result.
        return pd.concat(parts if frame.empty else [frame, *parts])

    def process_groups(self, df, group_by_column, extra_df):
        """Assign the clusters of every EC number in ``df`` to test / valid / train.

        Args:
            df: rows to split; must contain 'EC' and ``group_by_column``.
            group_by_column: cluster column used to keep related rows together.
            extra_df: frame that EC numbers with fewer than 3 clusters are
                appended to.

        Returns:
            ``extra_df`` extended with all rows of the EC numbers that could
            not be split at this level. ``self.test``, ``self.valid`` and
            ``self.train`` are extended as a side effect.
        """
        test_parts, valid_parts, train_parts, extra_parts = [], [], [], []

        # sort=False keeps the EC numbers in order of first appearance.
        for _, ec_df in df.groupby('EC', sort=False):
            clusters = ec_df.groupby(group_by_column)
            num_groups = clusters.ngroups
            if num_groups < 3:
                # Not enough clusters for a three-way split at this level.
                extra_parts.append(ec_df)
                continue

            holdout = self._holdout_size(num_groups)
            # Clusters are visited in sorted label order: the first `holdout`
            # go to test, the next `holdout` to validation, the rest to train.
            for position, (_, cluster_df) in enumerate(clusters):
                if position < holdout:
                    test_parts.append(cluster_df)
                elif position < 2 * holdout:
                    valid_parts.append(cluster_df)
                else:
                    train_parts.append(cluster_df)

        self.test = self._extend(self.test, test_parts)
        self.valid = self._extend(self.valid, valid_parts)
        self.train = self._extend(self.train, train_parts)
        return self._extend(extra_df, extra_parts)

    def process(self):
        """Run the split at the UniRef50, UniRef90 and UniRef100 levels in turn.

        Returns:
            Tuple ``(valid, test, train, extra, orphans)``.
        """
        self.uniref50extra = self.process_groups(self.df, 'UniRef50', self.uniref50extra)
        self.uniref90extra = self.process_groups(self.uniref50extra, 'UniRef90', self.uniref90extra)
        self.extra = self.process_groups(self.uniref90extra, 'UniRef100', self.extra)
        return self.valid, self.test, self.train, self.extra, self.orphans

# Read the dataset
# (common_in_swiss_trembl.csv is the file the notebook header describes as the Swissprot-only dataset)
df = pd.read_csv('common_in_swiss_trembl.csv')

# Create an instance of the DatasetSplitter and process the data
# process() returns the frames in the order (valid, test, train, extra, orphans);
# `extra` holds the rows of EC numbers that had fewer than 3 clusters at every UniRef level.
splitter = DatasetSplitter(df)
valid, test, train, extra, orphans = splitter.process()



In [None]:
#Split the extra EC numbers
def split_extra_by_ec(extra, test=None, valid=None, train=None, orphans=None,
                      output_dir='UnbalancedSwissprot'):
    """Distribute the leftover ("extra") rows over the splits, one EC number at a time.

    For every EC number in ``extra``:
      * 3 or more rows: first row -> test, second row -> valid, the rest -> train
      * exactly 2 rows: first row -> train, second row -> test
      * a single row:   -> orphans

    The rows are appended to the given ``test`` / ``valid`` / ``train`` /
    ``orphans`` frames, and the four resulting frames are written to
    ``<output_dir>/{valid,test,train,orphans}.csv``.

    Args:
        extra: DataFrame with an 'EC' column.
        test, valid, train, orphans: frames to append to. When one is omitted,
            the notebook-level DataFrame of the same name is used if it exists
            (so ``split_extra_by_ec(extra)`` extends the splits produced by the
            previous cell); otherwise an empty frame with the columns of
            ``extra`` is used.
        output_dir: directory for the CSV files; created if it is missing.

    Returns:
        Tuple ``(test, train, valid, orphans)`` of the extended frames. The
        input frames are not modified.
    """
    import os  # local import: the notebook's import cells do not import os

    def _starting_frame(given, name):
        # An explicit argument wins; otherwise fall back to the notebook global.
        if given is not None:
            return given
        existing = globals().get(name)
        if isinstance(existing, pd.DataFrame):
            return existing
        return pd.DataFrame(columns=extra.columns)

    additions = {'test': [], 'valid': [], 'train': [], 'orphans': []}
    for _, group in extra.groupby('EC'):
        count = len(group)
        if count >= 3:
            additions['test'].append(group.iloc[0:1])
            additions['valid'].append(group.iloc[1:2])
            additions['train'].append(group.iloc[2:])
        elif count == 2:
            # Too few rows for a validation row: one for train, one for test.
            additions['train'].append(group.iloc[:1])
            additions['test'].append(group.iloc[1:])
        else:
            additions['orphans'].append(group)

    results = {}
    for name, given in (('test', test), ('valid', valid),
                        ('train', train), ('orphans', orphans)):
        base = _starting_frame(given, name)
        # One concat per split; empty frames are skipped so they cannot
        # change the dtypes of the result.
        frames = [frame for frame in (base, *additions[name]) if not frame.empty]
        results[name] = pd.concat(frames) if frames else base

    # Save to CSV files
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    for name in ('valid', 'test', 'train', 'orphans'):
        results[name].to_csv(os.path.join(output_dir, name + '.csv'), index=False)

    return results['test'], results['train'], results['valid'], results['orphans']

# Usage: route the leftover rows into the splits and write the CSV files.
# The function returns the frames in the order (test, train, valid, orphans).
test, train, valid, orphans = split_extra_by_ec(extra)

In [None]:
# The test set generated above is test set 1 (the in-distribution test set).

# The out-of-distribution test set (test set 2), used for all benchmarks, is built below:

# Read the orphan rows from the directory split_extra_by_ec writes them to
# ('UnbalancedSwissprot/')
df1 = pd.read_csv('UnbalancedSwissprot/orphans.csv')

# Read the Swissprot-only entries
df2 = pd.read_csv('only_in_swiss.csv')

# Vertically concatenate the DataFrames
combined_df = pd.concat([df1, df2], axis=0)

# Save the combined DataFrame next to the other Swissprot unbalanced splits
combined_df.to_csv('UnbalancedSwissprot/test2.csv', index=False)

In [None]:
#Benchmark 4(Swissprot Only Balanced Dataset)

import pandas as pd
import random

def balance_data_by_ec(input_file, output_file, target=10, cluster_column='UniRef50', seed=None):
    """Write a copy of ``input_file`` in which every EC number has ``target`` rows.

    Per EC number:
      * fewer than ``target`` rows: all rows are kept and randomly chosen
        duplicates are added until there are ``target`` rows;
      * exactly ``target`` rows: kept as is;
      * more than ``target`` rows: ``target`` rows are sampled. If there are at
        least ``target`` distinct clusters, one row is taken from each of
        ``target`` randomly chosen clusters. Otherwise one row is taken from
        every cluster and the remainder is filled with rows that have not been
        chosen yet, so no row appears twice.

    Args:
        input_file: CSV with an 'EC' column and, for EC numbers that need
            down-sampling, the ``cluster_column``.
        output_file: destination CSV; a missing parent directory is created
            when a path is given.
        target: number of rows per EC number in the output (default 10).
        cluster_column: column holding the cluster label (default 'UniRef50').
        seed: seed for the sampling; ``None`` gives a different sample per run.

    Raises:
        ValueError: if ``target`` is smaller than 1.
    """
    import os  # local import: the notebook's import cells do not import os

    if target < 1:
        raise ValueError("target must be at least 1")

    # Read CSV file
    df = pd.read_csv(input_file)
    rng = random.Random(seed)

    # Balanced rows of every EC number; concatenated once after the loop
    balanced_parts = []
    for _, group in df.groupby('EC'):
        num_records = len(group)

        if num_records == target:
            balanced_parts.append(group)
        elif num_records < target:
            # Over-sample: keep every row and pad with random duplicates
            padding = [rng.randrange(num_records) for _ in range(target - num_records)]
            balanced_parts.append(group)
            balanced_parts.append(group.iloc[padding])
        else:
            # Row positions (within `group`) of the members of every cluster
            members = {
                cluster: [int(position) for position in positions]
                for cluster, positions in group.groupby(cluster_column).indices.items()
            }
            if len(members) >= target:
                # Enough clusters: one row from each of `target` random clusters
                chosen_clusters = rng.sample(list(members), target)
                picked = [rng.choice(members[cluster]) for cluster in chosen_clusters]
            else:
                # Too few clusters: one row per cluster first ...
                picked = [rng.choice(positions) for positions in members.values()]
                # ... then fill up with rows that were not chosen yet. There
                # are always enough of them because num_records > target.
                already_picked = set(picked)
                unpicked = [p for p in range(num_records) if p not in already_picked]
                picked += rng.sample(unpicked, target - len(picked))
            balanced_parts.append(group.iloc[picked])

    if balanced_parts:
        sampled_data = pd.concat(balanced_parts)
    else:
        sampled_data = pd.DataFrame(columns=df.columns)

    # Write sampled data to CSV
    if isinstance(output_file, (str, os.PathLike)):
        parent = os.path.dirname(os.fspath(output_file))
        if parent:
            os.makedirs(parent, exist_ok=True)
    sampled_data.to_csv(output_file, index=False)

# Usage: balance the unbalanced Swissprot splits, read from the directory the
# split cell writes them to ('UnbalancedSwissprot/').
balance_data_by_ec('UnbalancedSwissprot/train.csv', 'BalancedSwissprot/train.csv')
balance_data_by_ec('UnbalancedSwissprot/valid.csv', 'BalancedSwissprot/valid.csv')


In [None]:
#Benchmark 1 (Swissprot + Trembl combined, Unbalanced)

In [None]:
#Mixing both datasets
import pandas as pd

# Load the TrEMBL and Swiss-Prot tables and tag every row with a 'DBtype' label
trembl = pd.read_csv("./Trembl.csv").assign(DBtype='Uniprot')
swiss = pd.read_csv("./Swissprot.csv").assign(DBtype='Swissprot')

# Stack the two tagged tables, TrEMBL rows first
swiss_and_uni = pd.concat([trembl, swiss])


In [None]:
import pandas as pd
import math

class DatasetSplitter:
    """Split a table into validation / train sets, cluster-wise per EC number.

    The input frame must contain the columns 'EC', 'UniRef50', 'UniRef90' and
    'UniRef100'. For every EC number the rows are grouped by a cluster column
    and whole clusters are assigned to one split, so rows of the same cluster
    never end up in different splits.

    The split is attempted at three cluster levels in turn ('UniRef50', then
    'UniRef90', then 'UniRef100'); EC numbers with fewer than 2 clusters at
    one level are handed on to the next level. Whatever is still unsplit after
    the last level is returned as ``extra``.

    Attributes:
        valid, train: accumulated splits.
        uniref50extra, uniref90extra, extra: rows left unsplit after each level.
    """

    def __init__(self, df):
        self.df = df
        self.valid = pd.DataFrame(columns=df.columns)
        self.train = pd.DataFrame(columns=df.columns)
        self.uniref50extra = pd.DataFrame(columns=df.columns)
        self.uniref90extra = pd.DataFrame(columns=df.columns)
        self.extra = pd.DataFrame(columns=df.columns)

    @staticmethod
    def _holdout_size(num_groups):
        """Number of clusters that go to the validation set."""
        if num_groups >= 10:
            # Integer ceiling of 10 %. math.ceil(0.1 * n) is avoided because
            # 0.1 * 30 evaluates to 3.0000000000000004 and would round up to 4.
            return (num_groups + 9) // 10
        if num_groups >= 6:
            return 2
        return 1

    @staticmethod
    def _extend(frame, parts):
        """Return ``frame`` with ``parts`` appended, using a single concat."""
        if not parts:
            return frame
        # An empty starting frame is left out so its object-dtype columns do
        # not influence the dtypes of the result.
        return pd.concat(parts if frame.empty else [frame, *parts])

    def process_groups(self, df, group_by_column, extra_df):
        """Assign the clusters of every EC number in ``df`` to valid / train.

        Args:
            df: rows to split; must contain 'EC' and ``group_by_column``.
            group_by_column: cluster column used to keep related rows together.
            extra_df: frame that EC numbers with fewer than 2 clusters are
                appended to.

        Returns:
            ``extra_df`` extended with all rows of the EC numbers that could
            not be split at this level. ``self.valid`` and ``self.train`` are
            extended as a side effect.
        """
        valid_parts, train_parts, extra_parts = [], [], []

        # sort=False keeps the EC numbers in order of first appearance.
        for _, ec_df in df.groupby('EC', sort=False):
            clusters = ec_df.groupby(group_by_column)
            num_groups = clusters.ngroups
            if num_groups < 2:
                # A single cluster cannot be divided between two splits.
                extra_parts.append(ec_df)
                continue

            holdout = self._holdout_size(num_groups)
            # Clusters are visited in sorted label order: the first `holdout`
            # go to validation, the rest to train.
            for position, (_, cluster_df) in enumerate(clusters):
                if position < holdout:
                    valid_parts.append(cluster_df)
                else:
                    train_parts.append(cluster_df)

        self.valid = self._extend(self.valid, valid_parts)
        self.train = self._extend(self.train, train_parts)
        return self._extend(extra_df, extra_parts)

    def process(self):
        """Run the split at the UniRef50, UniRef90 and UniRef100 levels in turn.

        Returns:
            Tuple ``(valid, train, extra)``.
        """
        self.uniref50extra = self.process_groups(self.df, 'UniRef50', self.uniref50extra)
        self.uniref90extra = self.process_groups(self.uniref50extra, 'UniRef90', self.uniref90extra)
        self.extra = self.process_groups(self.uniref90extra, 'UniRef100', self.extra)
        return self.valid, self.train, self.extra

# Read the dataset
# NOTE(review): 'path_to_your_input_file.csv' is a placeholder. Presumably the combined
# Swissprot + Trembl table is meant (an earlier cell builds `swiss_and_uni`, which is
# not used here) - confirm and point this at the real file.
input_file = 'path_to_your_input_file.csv'
df = pd.read_csv(input_file)

# Create an instance of the DatasetSplitter and process the data
# process() returns the frames in the order (valid, train, extra).
splitter = DatasetSplitter(df)
valid, train, extra = splitter.process()

# Rows of EC numbers that could not be split cluster-wise are added to the training set.
train = pd.concat([train,extra])

# Save the results
valid.to_csv('UnbalancedUniprot/valid.csv', index=False)
train.to_csv('UnbalancedUniprot/train.csv', index=False)


In [None]:
#Benchmark 2 (Swissprot + Trembl combined, Balanced)

import pandas as pd
import random

def balance_data_by_ec(input_file, output_file, target=250, cluster_column='UniRef50', seed=None):
    """Write a copy of ``input_file`` in which every EC number has ``target`` rows.

    Per EC number:
      * fewer than ``target`` rows: all rows are kept and randomly chosen
        duplicates are added until there are ``target`` rows;
      * exactly ``target`` rows: kept as is;
      * more than ``target`` rows: ``target`` rows are sampled. If there are at
        least ``target`` distinct clusters, one row is taken from each of
        ``target`` randomly chosen clusters. Otherwise one row is taken from
        every cluster and the remainder is filled with rows that have not been
        chosen yet, so no row appears twice.

    Args:
        input_file: CSV with an 'EC' column and, for EC numbers that need
            down-sampling, the ``cluster_column``.
        output_file: destination CSV; a missing parent directory is created
            when a path is given.
        target: number of rows per EC number in the output (default 250).
        cluster_column: column holding the cluster label (default 'UniRef50').
        seed: seed for the sampling; ``None`` gives a different sample per run.

    Raises:
        ValueError: if ``target`` is smaller than 1.
    """
    import os  # local import: the notebook's import cells do not import os

    if target < 1:
        raise ValueError("target must be at least 1")

    # Read CSV file
    df = pd.read_csv(input_file)
    rng = random.Random(seed)

    # Balanced rows of every EC number; concatenated once after the loop
    balanced_parts = []
    for _, group in df.groupby('EC'):
        num_records = len(group)

        if num_records == target:
            balanced_parts.append(group)
        elif num_records < target:
            # Over-sample: keep every row and pad with random duplicates
            padding = [rng.randrange(num_records) for _ in range(target - num_records)]
            balanced_parts.append(group)
            balanced_parts.append(group.iloc[padding])
        else:
            # Row positions (within `group`) of the members of every cluster
            members = {
                cluster: [int(position) for position in positions]
                for cluster, positions in group.groupby(cluster_column).indices.items()
            }
            if len(members) >= target:
                # Enough clusters: one row from each of `target` random clusters
                chosen_clusters = rng.sample(list(members), target)
                picked = [rng.choice(members[cluster]) for cluster in chosen_clusters]
            else:
                # Too few clusters: one row per cluster first ...
                picked = [rng.choice(positions) for positions in members.values()]
                # ... then fill up with rows that were not chosen yet. There
                # are always enough of them because num_records > target.
                already_picked = set(picked)
                unpicked = [p for p in range(num_records) if p not in already_picked]
                picked += rng.sample(unpicked, target - len(picked))
            balanced_parts.append(group.iloc[picked])

    if balanced_parts:
        sampled_data = pd.concat(balanced_parts)
    else:
        sampled_data = pd.DataFrame(columns=df.columns)

    # Write sampled data to CSV
    if isinstance(output_file, (str, os.PathLike)):
        parent = os.path.dirname(os.fspath(output_file))
        if parent:
            os.makedirs(parent, exist_ok=True)
    sampled_data.to_csv(output_file, index=False)

# Usage: balance the unbalanced Swissprot + Trembl splits, read from the
# directory the split cell writes them to ('UnbalancedUniprot/').
balance_data_by_ec('UnbalancedUniprot/train.csv', 'BalancedUniprot/train.csv')
balance_data_by_ec('UnbalancedUniprot/valid.csv', 'BalancedUniprot/valid.csv')
