Imports

In [None]:
from utils.preprocessing import add_metadata, concat_data, parse_fasta_with_groups, prepare_contrasted_learning_data
from embeddings.protBERT import get_protbert_embeddings, contasted_learning
from utils.visualisation import draw_pca, draw_tsna
from Bio import SeqIO
import pandas as pd
import numpy as np
import os
from typing import List, Tuple
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
from matplotlib.axes import Axes

Normal ProtBERT parameters

In [None]:
# Both values are forwarded positionally to get_protbert_embeddings.
# Presumably the maximum sequence length and the inference batch size —
# confirm against embeddings.protBERT.
NORMAL_BATCH_SIZE = 4
NORMAL_MAX_SEQ_LEN = 1_024

Contrastive learning ProtBERT parameters

In [None]:
# Forwarded to prepare_contrasted_learning_data as its third argument
# (presumably a cap on the non-human training sequences — confirm in
# utils.preprocessing).
CL_TRAIN_LIMIT_NH = 2_000

# Forwarded to contasted_learning, listed in the order they are passed.
# Meanings are inferred from the names only — confirm in embeddings.protBERT.
CL_MAX_LEN = 600
CL_BATCH_SIZE = 16
CL_LR = 2e-5
CL_EPOCHS = 10
CL_ACCUMULATION_STEPS = 2

Paths

In [None]:
# Raw FASTA inputs handed to parse_fasta_with_groups.
human_raw_path = "data/raw/human_98.fasta"
nonhuman_raw_path = "data/raw/nonhuman_98.fasta"
# Output locations; later cells check these with os.path.exists and skip
# regeneration when something is already there.
normal_protbert_path = "data/processed/protbert.pkl"
# NOTE(review): the variable name is misspelled ("contrased", "probert"); it is
# referenced by several later cells, so rename everywhere at once if at all.
contrased_learning_probert_path = "data/processed/contrasted_learning_protbert.pkl"
contrasted_learning_adapter_path = "model/contrastive_learning_adapter"

Preprocess the FASTA data for embedding generation

In [None]:
# Parse each raw FASTA file. The integer second argument differs per file
# (1 for the human file, 0 for the non-human file); presumably a class
# label — confirm in utils.preprocessing.
human_parsed = parse_fasta_with_groups(human_raw_path, 1)
pre_emb_human, metadata_human = human_parsed

nonhuman_parsed = parse_fasta_with_groups(nonhuman_raw_path, 0)
pre_emb_nonhuman, metadata_nonhuman = nonhuman_parsed

Generate normal protBERT embeddings


In [None]:
# Compute the baseline embeddings only when nothing exists at the output path.
if not os.path.exists(normal_protbert_path):
    protbert_emb_human = get_protbert_embeddings(
        pre_emb_human, NORMAL_MAX_SEQ_LEN, NORMAL_BATCH_SIZE
    )
    protbert_emb_nonhuman = get_protbert_embeddings(
        pre_emb_nonhuman, NORMAL_MAX_SEQ_LEN, NORMAL_BATCH_SIZE
    )

    # Pair each embedding set with the metadata parsed from the same FASTA
    # file, then hand both to concat_data together with the output path.
    human_labeled = add_metadata(protbert_emb_human, metadata_human)
    nonhuman_labeled = add_metadata(protbert_emb_nonhuman, metadata_nonhuman)
    concat_data(human_labeled, nonhuman_labeled, out_path=normal_protbert_path)
else:
    print("Embedding were generated previously, skipping generating them.")

Train the contrastive learning ProtBERT adapter

In [None]:
# Run the contrastive-learning step only when nothing exists at the adapter path.
if not os.path.exists(contrasted_learning_adapter_path):
    all_seqs, all_labels = prepare_contrasted_learning_data(
        pre_emb_human, pre_emb_nonhuman, CL_TRAIN_LIMIT_NH
    )
    contrasted_learning_adapter = contasted_learning(
        all_seqs,
        all_labels,
        CL_MAX_LEN,
        CL_BATCH_SIZE,
        CL_LR,
        CL_EPOCHS,
        CL_ACCUMULATION_STEPS,
        contrasted_learning_adapter_path,
    )
    # Short alias bound to the same returned object.
    cl_adapter = contrasted_learning_adapter
else:
    print("Adapter was pre-made, skipping generating it.")

Generate contrastive learning ProtBERT embeddings

In [None]:
# Compute the adapter-based embeddings only when nothing exists at the output
# path. The flow mirrors the baseline-embedding cell, with the adapter path
# passed to get_protbert_embeddings as an extra fourth argument.
if os.path.exists(contrased_learning_probert_path):
    print("CL Embedding were generated previously, skipping generating them.")
else:
    # Fail fast: without this check a missing adapter is only discovered
    # inside get_protbert_embeddings, after the embedding run has started.
    if not os.path.exists(contrasted_learning_adapter_path):
        raise FileNotFoundError(
            f"Adapter not found at {contrasted_learning_adapter_path!r}; "
            "run the adapter cell before generating CL embeddings."
        )

    # NOTE(review): this reuses NORMAL_MAX_SEQ_LEN / NORMAL_BATCH_SIZE rather
    # than the CL_* values passed to contasted_learning — confirm this is intended.
    # The cl_-prefixed names keep this cell from overwriting the variables
    # bound by the baseline-embedding cell.
    cl_emb_human = get_protbert_embeddings(
        pre_emb_human,
        NORMAL_MAX_SEQ_LEN,
        NORMAL_BATCH_SIZE,
        contrasted_learning_adapter_path,
    )
    cl_emb_nonhuman = get_protbert_embeddings(
        pre_emb_nonhuman,
        NORMAL_MAX_SEQ_LEN,
        NORMAL_BATCH_SIZE,
        contrasted_learning_adapter_path,
    )

    cl_human_labeled = add_metadata(cl_emb_human, metadata_human)
    cl_nonhuman_labeled = add_metadata(cl_emb_nonhuman, metadata_nonhuman)

    concat_data(
        cl_human_labeled,
        cl_nonhuman_labeled,
        out_path=contrased_learning_probert_path,
    )

Compare PCA (normal vs contrastive learning ProtBERT)

In [None]:
# One row, two panels: the baseline file is drawn on the first axis and the
# contrastive-learning file on the second.
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(18, 8))
pca_panels = (
    (normal_protbert_path, "Normal protbert PCA", ax1),
    (contrased_learning_probert_path, "Contrasted learning protbert PCA", ax2),
)
for embeddings_path, panel_title, panel_ax in pca_panels:
    draw_pca(embeddings_path, panel_title, panel_ax)

plt.tight_layout()
plt.show()

Compare t-SNE (normal vs contrastive learning ProtBERT)

In [None]:
# One row, two panels: the baseline file is drawn on the first axis and the
# contrastive-learning file on the second.
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(18, 8))
tsna_panels = (
    (normal_protbert_path, "Normal protbert TSNA", ax1),
    (contrased_learning_probert_path, "Contrasted learning protbert TSNA", ax2),
)
for embeddings_path, panel_title, panel_ax in tsna_panels:
    draw_tsna(embeddings_path, panel_title, panel_ax)

plt.tight_layout()
plt.show()