### Tokenising Labelled Datasets

In [None]:
import os
from sklearn.utils import shuffle
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer
from transformers import BertModel, BertTokenizer#, DistilProtBert
from transformers import T5Model, T5Tokenizer
from sklearn.metrics import f1_score
from imblearn.over_sampling import RandomOverSampler
from imblearn.under_sampling import RandomUnderSampler
from transformers import TrainingArguments, Trainer
import pickle
import mgzip
import bz2
import gc
from transformers import TrainingArguments, Trainer
from transformers import EarlyStoppingCallback
from transformers import default_data_collator
from functools import reduce
import bisect


# Specify file paths and max sequence length
# Input CSVs: the first holds the proteins that get label 1 (associated with heat resistance)
# further down, the second holds the proteins that get label 0 (no association).
heat_resistance_proteins_filepath = "/content/drive/My Drive/Colab Notebooks/Dissertation Code/Data/all_fungi_heat_resistant_prots.csv"
non_resistant_proteins_filepath = "/content/drive/My Drive/Colab Notebooks/Dissertation Code/Data/all_fungi_non_heat_resistant_excludesnodata.csv"

# Directory that the pickled label lists and tokenised datasets are written to
tokenised_files_savepath = "/content"

# Length cap used in two places below: label-0 rows whose sequence is longer than this are
# removed from the dataset, and the tokeniser is called with max_length=max_len and truncation on.
max_len = 256   # Label-0 (described as the majority class) sequences over this length are removed; label-1 (minority) sequences are kept and truncated by the tokeniser


########## PREPARE DATASET ##########
# Read both source CSVs, announcing progress after each one
print("\nLoading data")
heat_resistant_proteins = pd.read_csv(heat_resistance_proteins_filepath)
print("\nHeat resistance sequences imported")
non_resistant_proteins = pd.read_csv(non_resistant_proteins_filepath)
print("\nNon-heat resistance sequences imported")


# Attach the class label to each frame, then stack the two frames into one dataset
#   resist_heat = 1 -> associated with heat resistance
#   resist_heat = 0 -> no association
heat_resistant_proteins = heat_resistant_proteins.assign(resist_heat=1)
non_resistant_proteins = non_resistant_proteins.assign(resist_heat=0)
proteins_data = pd.concat((heat_resistant_proteins, non_resistant_proteins))


# Report and remove null values (missing sequences), then drop exact duplicate rows.
# The count is printed because a bare expression in the middle of a cell/script shows nothing.
missing_sequences = proteins_data['Predicted Protein Sequence'].isna().sum()
print(f"\nRows with a missing sequence: {missing_sequences}")
proteins_data = proteins_data.dropna(subset=['Predicted Protein Sequence'], axis=0)
# NOTE(review): duplicates are dropped while the ID/annotation columns are still present, so rows
# that differ only in those columns are kept - confirm whether identical sequences should also be
# de-duplicated after the column drop below.
proteins_data = proteins_data.drop_duplicates()


# Drop unneeded columns (ID and annotation columns are not used past this point in this cell)
proteins_data = proteins_data.drop(columns=[
    'Gene ID',
    'source_id',
    'Product Description',
    'Gene Name or Symbol',
    'gene_source_id',
    'Organism',
    'Computed GO Component IDs',
    'Computed GO Function IDs',
    'Computed GO Process IDs',
    'Curated GO Component IDs',
    'Curated GO Function IDs',
    'Curated GO Process IDs',
])


# Remove sequences longer than max_len amino acids if they're also not associated with heat resistance
# (label-1 rows are kept whatever their length)
proteins_data = proteins_data[~((proteins_data['resist_heat'] == 0) & (proteins_data['Predicted Protein Sequence'].str.len() > max_len))]

# Find the longest remaining sequence length; it is only reported in the summary printout
longest_seq = proteins_data['Predicted Protein Sequence'].str.len().max()


# Pull the label column out as a plain list
y = proteins_data['resist_heat'].tolist()

# Pull the sequences out and put a space between every character of each sequence
x = list(map(" ".join, proteins_data['Predicted Protein Sequence'].tolist()))

# Dataset size (one entry per sequence)
datapoints = len(x)

# Summary of the prepared dataset
print("\nInitial dataset preparation complete")
print(f"\nLongest sequence: {longest_seq}")
print(f"Number of datapoints: {datapoints}")
print("First sequence:", x[0], sep="\n")


# Split data into train, validation and test datasets (70% / 15% / 15%)
print("\nSplitting data")
train_ratio = 0.70
val_ratio = 0.15
test_ratio = 0.15

# First split: keep 70% for training and hold out the remaining 30%
x_train, x_test, y_train, y_test = train_test_split(
    x,
    y,
    shuffle=True,
    stratify=y,
    test_size=(1 - train_ratio),
    random_state=1543,
)

# Second split: divide the held-out 30% in half into validation and test (15% of the total each).
# This must split the held-out portion, not the training set; splitting the training set would
# make validation and test subsets of the training data and discard the real hold-out.
x_val, x_test, y_val, y_test = train_test_split(
    x_test,
    y_test,
    stratify=y_test,
    test_size=(test_ratio / (test_ratio + val_ratio)),
    random_state=1435,
)

print(np.unique(y_train))

# Oversample and undersample the training data (in batches to save memory)
print("\nOversampling and undersampling")

# The samplers expect a 2D feature matrix, so the sequences become a single-column array.
# NOTE(review): an array built from a list of str is fixed-width unicode sized to the longest
# string; dtype=object may use less memory - confirm the samplers accept it before changing.
x_train = np.asarray(x_train).reshape(-1, 1)

# Number of rows resampled at a time
batch_size = 100_000

# sampling_strategy is the minority/majority ratio each sampler aims for:
# oversample the minority up to 0.1 of the majority, then undersample the majority down to 1:1
oversampler = RandomOverSampler(random_state=209, sampling_strategy=0.1)
undersampler = RandomUnderSampler(random_state=212, sampling_strategy=1)

def resample_data(x, y, batch_size, oversampler, undersampler):
    """Oversample then undersample ``x``/``y`` slice by slice and shuffle the pooled result.

    The data is walked in consecutive slices of ``batch_size`` rows, followed by
    one final slice holding whatever is left over (possibly empty). A slice whose
    labels contain more than one distinct value is passed through
    ``oversampler.fit_resample`` and then ``undersampler.fit_resample``; any other
    slice is pooled unchanged.

    Args:
        x: Sliceable collection of samples (the caller in this file passes a
            single-column array).
        y: Sliceable collection of labels aligned with ``x``.
        batch_size: Number of rows per full slice.
        oversampler: Object exposing ``fit_resample(x, y)``; applied first.
        undersampler: Object exposing ``fit_resample(x, y)``; applied second.

    Returns:
        Whatever ``shuffle`` returns for the two pooled lists with
        ``random_state=42`` (samples and labels shuffled together).
    """

    def _balance(x_part, y_part):
        # Slices with fewer than two distinct labels are returned untouched
        if len(np.unique(y_part)) <= 1:
            return x_part, y_part
        x_part, y_part = oversampler.fit_resample(x_part, y_part)
        x_part, y_part = undersampler.fit_resample(x_part, y_part)
        return x_part, y_part

    full_batches = len(x) // batch_size
    pooled_x, pooled_y = [], []

    # Full-size slices; progress is announced for these only
    for batch_no in range(full_batches):
        print(f"\nResampling batch {batch_no}")
        lo = batch_no * batch_size
        hi = lo + batch_size
        x_part, y_part = _balance(x[lo:hi], y[lo:hi])
        pooled_x.extend(x_part)
        pooled_y.extend(y_part)

    # Leftover rows that did not fill a whole slice
    tail_start = full_batches * batch_size
    x_part, y_part = _balance(x[tail_start:], y[tail_start:])
    pooled_x.extend(x_part)
    pooled_y.extend(y_part)

    return shuffle(pooled_x, pooled_y, random_state=42)


# Run the batch resampling function on the training split
x_train, y_train = resample_data(
    x=x_train,
    y=y_train,
    batch_size=batch_size,
    oversampler=oversampler,
    undersampler=undersampler,
)

# Each resampled entry is a one-element row; unwrap it so x_train is a flat list of sequences again
x_train = [row[0] for row in x_train]


# Check ratio of labels after resampling
unique_values, counts = np.unique(y_train, return_counts=True)
print("Unique Values:", unique_values)
print("Counts:", counts)


# Pickle each label list into its own compressed file (train, then test, then val)
print("\nSaving y labels")
for split_name, split_labels in (("train", y_train), ("test", y_test), ("val", y_val)):
    with mgzip.open(f"{tokenised_files_savepath}/BEAR_encoded_y_{split_name}.gz", "wb") as f:
        pickle.dump(split_labels, f)


# Report the size of every split: inputs first, then labels
for split_label, split_data in (
    ("x_train", x_train),
    ("x_val", x_val),
    ("x_test", x_test),
    ("y_train", y_train),
    ("y_val", y_val),
    ("y_test", y_test),
):
    print(f"Size of {split_label}: {len(split_data)}")


########## TOKENISE AND SAVE DATA ##########
# Load tokeniser
print("\nLoading tokeniser")
tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)


# Tokenise each split and pickle the encoded result to its own compressed file
x_datasets = {"train": x_train, "test": x_test, "val": x_val}

for dataset_name, dataset in x_datasets.items():
    # NOTE(review): dataset_size is not read anywhere in this cell; kept in case later cells use it
    dataset_size = len(dataset)
    print(f"\nTokenising {dataset_name} dataset")

    # Every sequence is padded/truncated to max_len tokens; attention masks are returned too
    tokenised_data = tokenizer(
        dataset,
        add_special_tokens=True,
        max_length=max_len,
        truncation=True,
        padding='max_length',
        return_attention_mask=True,
    )

    encoded_path = f"{tokenised_files_savepath}/BEAR_encoded_{dataset_name}.gz"
    with mgzip.open(encoded_path, "wb") as f:
        pickle.dump(tokenised_data, f)


print("Tokenisation complete")