In [None]:
import scipy
import torch
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import scipy.ndimage
# Version shim: define the three FFT helpers used by the DCT code with
# whichever FFT API the installed PyTorch offers. Both branches define the
# same three names so the rest of the file does not care which one ran.
# Every helper transforms along dim 1 and represents complex values as a
# real tensor with a trailing dimension of size 2 (real, imaginary).
try:
    # PyTorch 1.7.0 and newer versions
    import torch.fft

    # One-sided real FFT (not called by dct/idct in the visible part of
    # this file).
    def dct1_rfft_impl(x):
        return torch.view_as_real(torch.fft.rfft(x, dim=1))

    # Full (two-sided) FFT; this is what dct() calls.
    def dct_fft_impl(v):
        return torch.view_as_real(torch.fft.fft(v, dim=1))

    # Inverse real FFT; V carries (real, imag) in its last dimension and
    # the output length is pinned to V.shape[1]. This is what idct() calls.
    def idct_irfft_impl(V):
        return torch.fft.irfft(torch.view_as_complex(V), n=V.shape[1], dim=1)
except ImportError:
    # PyTorch 1.6.0 and older versions
    # Fallback built on the legacy torch.rfft / torch.irfft functions with
    # one signal dimension.
    def dct1_rfft_impl(x):
        return torch.rfft(x, 1)

    def dct_fft_impl(v):
        return torch.rfft(v, 1, onesided=False)

    def idct_irfft_impl(V):
        return torch.irfft(V, 1, onesided=False)
        
def dct(x, norm=None):
    """
    Type-II Discrete Cosine Transform, computed through a single FFT.

    The transform is applied along the last dimension; every leading
    dimension is treated as a batch dimension.
    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
    :param x: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the DCT-II of the signal over the last dimension, same shape
        as `x`
    """
    original_shape = x.shape
    n = original_shape[-1]
    flat = x.contiguous().view(-1, n)

    # Reorder each row: even-indexed samples first, then the odd-indexed
    # samples in reverse order, so one length-n FFT is enough.
    even_part = flat[:, ::2]
    odd_part_reversed = flat[:, 1::2].flip([1])
    reordered = torch.cat([even_part, odd_part_reversed], dim=1)

    # The FFT helper returns (real, imag) in the trailing dimension.
    spectrum = dct_fft_impl(reordered)
    real_part = spectrum[:, :, 0]
    imag_part = spectrum[:, :, 1]

    # Phase rotation by exp(-i * pi * k / (2n)); only the real part is kept.
    angles = - torch.arange(n, dtype=flat.dtype, device=flat.device)[None, :] * np.pi / (2 * n)
    coeffs = real_part * torch.cos(angles) - imag_part * torch.sin(angles)

    if norm == 'ortho':
        # The DC term and the remaining terms use different scale factors.
        coeffs[:, 0] /= np.sqrt(n) * 2
        coeffs[:, 1:] /= np.sqrt(n / 2) * 2

    return 2 * coeffs.view(*original_shape)


def idct(X, norm=None):
    """
    Inverse of the DCT-II, i.e. a scaled Type-III Discrete Cosine Transform.

    Defined so that idct(dct(x)) == x. The transform runs along the last
    dimension; `X` itself is not modified.
    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
    :param X: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the inverse DCT-II of the signal over the last dimension
    """
    original_shape = X.shape
    n = original_shape[-1]

    # The division allocates a new tensor, so the in-place scaling below
    # never touches the caller's data.
    half = X.contiguous().view(-1, original_shape[-1]) / 2

    if norm == 'ortho':
        # Undo the orthonormal scaling applied by dct().
        half[:, 0] *= np.sqrt(n) * 2
        half[:, 1:] *= np.sqrt(n / 2) * 2

    angles = torch.arange(original_shape[-1], dtype=X.dtype, device=X.device)[None, :] * np.pi / (2 * n)
    cos_w = torch.cos(angles)
    sin_w = torch.sin(angles)

    # Complex spectrum before rotation: real part is the scaled input, the
    # imaginary part is its negated reversal with a zero in the first slot.
    spec_re = half
    spec_im = torch.cat([half[:, :1] * 0, -half.flip([1])[:, :-1]], dim=1)

    # Rotate by exp(+i * pi * k / (2n)).
    rot_re = spec_re * cos_w - spec_im * sin_w
    rot_im = spec_re * sin_w + spec_im * cos_w

    packed = torch.cat([rot_re.unsqueeze(2), rot_im.unsqueeze(2)], dim=2)
    time_domain = idct_irfft_impl(packed)

    # Undo the even/odd reordering performed by the forward transform.
    restored = time_domain.new_zeros(time_domain.shape)
    restored[:, ::2] += time_domain[:, :n - (n // 2)]
    restored[:, 1::2] += time_domain.flip([1])[:, :n // 2]

    return restored.view(*original_shape)



# Module-level default hyper-parameters for the Fourier attribution prior.
# The loss functions below receive these quantities as explicit arguments;
# nothing in the visible part of this file reads these globals directly.

# Smoothing amount for gradients before computing attribution prior loss;
# Smoothing window size is 1 + (2 * sigma); set to 0 for no smoothing
# NOTE(review): the smoothing call is currently commented out in both loss
# functions below, so this value has no effect on them as written.
att_prior_grad_smooth_sigma = 3

# Maximum frequency integer to consider for a Fourier attribution prior
fourier_att_prior_freq_limit = 100

# Amount to soften the Fourier attribution prior loss limit; set to None
# to not soften; softness decays like 1 / (1 + x^c) after the limit
fourier_att_prior_freq_limit_softness = None
def place_tensor(tensor, input_tensor=None):
    """
    Places `tensor` on the appropriate device and returns it.

    Arguments:
        `tensor`: the tensor to move
        `input_tensor`: optional reference tensor; when given, `tensor` is
            moved to the device this reference lives on
    When `input_tensor` is omitted, the tensor is placed on the GPU if
    PyTorch sees CUDA; otherwise it is returned unchanged. This keeps
    one-argument calls such as `place_tensor(t)` working instead of raising
    a TypeError.
    """
    if input_tensor is not None:
        return tensor.to(input_tensor.device)
    if torch.cuda.is_available():
        return tensor.cuda()
    return tensor
        
def smooth_tensor_1d(input_tensor, smooth_sigma):
    """
    Applies Gaussian smoothing along the second dimension of a 2D tensor.
    Arguments:
        `input_tensor`: a A x B tensor to smooth along the second dimension
        `smooth_sigma`: standard deviation of the Gaussian; the Gaussian is
            truncated after 1 sigma, so the smoothing window has size
            1 + (2 * sigma); a sigma of 0 means no smoothing
    Returns a tensor of the same shape as the input, with the `B` dimension
    smoothed. The edges are handled by replicating the border values.
    """
    # A sigma of 0 is mapped to a width-3 window whose Gaussian is truncated
    # at radius 0, which yields an identity kernel.
    no_smoothing = smooth_sigma == 0
    sigma = 1 if no_smoothing else smooth_sigma
    truncate = 0 if no_smoothing else 1

    # The kernel is the filter's response to a unit impulse at the centre
    impulse = np.zeros(1 + (2 * sigma))
    impulse[sigma] = 1
    weights = scipy.ndimage.gaussian_filter(
        impulse, sigma=sigma, truncate=truncate
    )
    weights = torch.tensor(
        weights, dtype=torch.float32, device=input_tensor.device
    )

    # conv1d expects (batch, channels, length) input and
    # (out_channels, in_channels, width) weights; use a single channel
    signal = input_tensor.unsqueeze(1)
    weights = weights.unsqueeze(0).unsqueeze(1).float()

    # Pad by sigma on each side so the output keeps the input length
    signal = F.pad(signal, (sigma, sigma), "replicate")
    result = F.conv1d(signal, weights)
    return result.squeeze(dim=1)


def fourier_att_prior_loss_dct(
        input_grads, freq_limit, limit_softness,
        att_prior_grad_smooth_sigma
    ):
    """
    Computes an attribution prior loss for a batch of examples from the
    DCT spectrum of the absolute input gradients.
    Arguments:
        `input_grads`: tensor of attributions (e.g. gradient times input);
            the DCT runs over the last dimension while the frequency axis
            is indexed as dimension 1, so a B x L tensor (batch by input
            length) is what lines up here -- TODO confirm with callers
        `freq_limit`: the maximum integer frequency index, k, to consider
            for the loss; components past this index get weight 0, or a
            decaying weight if `limit_softness` is given
        `limit_softness`: exponent c of the hill function 1 / (1 + x^c)
            used to weight components past the limit; None means a hard
            cut-off
        `att_prior_grad_smooth_sigma`: currently unused, since gradient
            smoothing is disabled in this function
    Returns a scalar Tensor with the mean attribution loss for the batch,
    or a 1-element zero tensor if `input_grads` has no elements.
    """
    # Gradient smoothing is disabled; the magnitudes are used directly
    grad_mags = torch.abs(input_grads)

    if not grad_mags.nelement():
        return place_tensor(torch.zeros(1), input_grads)

    # Magnitude of every DCT component, normalized per example
    spectrum = torch.abs(dct(grad_mags.float(), "ortho"))
    totals = torch.sum(spectrum, dim=1, keepdim=True)
    totals[totals == 0] = 1  # Keep 0s when the sum is 0
    spectrum = spectrum / totals

    # Drop the DC component
    spectrum = spectrum[:, 1:]

    # Per-frequency weights: 1 up to the limit, then 0 or a soft decay
    weights = place_tensor(torch.ones_like(spectrum), input_grads)
    if limit_softness is None:
        weights[:, freq_limit:] = 0
    else:
        offsets = place_tensor(
            torch.arange(1, spectrum.size(1) - freq_limit + 1), input_grads
        ).float()
        weights[:, freq_limit:] = 1 / (1 + torch.pow(offsets, limit_softness))

    # Score is the weighted share of spectral mass; loss is its complement
    scores = torch.sum(spectrum * weights, dim=1)
    return torch.mean(1 - scores)


def fourier_att_prior_loss(
        input_grads, freq_limit, limit_softness,
        att_prior_grad_smooth_sigma
    ):
    """
    Computes an attribution prior loss for a batch of examples from the
    real-FFT magnitude spectrum of the absolute input gradients.
    Arguments:
        `input_grads`: a B x L tensor, where B is the batch size and L is
            the length of the input; the FFT runs over the last dimension
            and the frequency axis is indexed as dimension 1; this should
            be *gradient times input*
        `freq_limit`: the maximum integer frequency index, k, to consider
            for the loss; components past this index get weight 0, or a
            decaying weight if `limit_softness` is given
        `limit_softness`: exponent c of the hill function 1 / (1 + x^c)
            used to weight components past the limit; None means a hard
            cut-off
        `att_prior_grad_smooth_sigma`: currently unused, since gradient
            smoothing is disabled in this function
    Returns a scalar Tensor with the mean attribution loss for the batch,
    or a 1-element zero tensor if `input_grads` has no elements.
    """
    # Gradient smoothing is disabled; the magnitudes are used directly
    abs_grads = torch.abs(input_grads)

    if not abs_grads.nelement():
        return torch.zeros(1, device=input_grads.device)

    # Magnitude of every Fourier component, normalized per example
    pos_fft = torch.view_as_real(torch.fft.rfft(abs_grads.float()))
    pos_mags = torch.norm(pos_fft, dim=2)
    pos_mag_sum = torch.sum(pos_mags, dim=1, keepdim=True)
    pos_mag_sum[pos_mag_sum == 0] = 1  # Keep 0s when the sum is 0
    pos_mags = pos_mags / pos_mag_sum

    # Cut off DC
    pos_mags = pos_mags[:, 1:]

    # Construct weight vector. The helper tensors are created directly on
    # the device of `pos_mags`; the previous one-argument place_tensor(...)
    # calls raised a TypeError.
    weights = torch.ones_like(pos_mags)
    if limit_softness is None:
        weights[:, freq_limit:] = 0
    else:
        x = torch.arange(
            1, pos_mags.size(1) - freq_limit + 1, device=pos_mags.device
        ).float()
        weights[:, freq_limit:] = 1 / (1 + torch.pow(x, limit_softness))

    # Multiply frequency magnitudes by weights
    pos_weighted_mags = pos_mags * weights

    # Add up along frequency axis to get score
    pos_score = torch.sum(pos_weighted_mags, dim=1)
    pos_loss = 1 - pos_score
    return torch.mean(pos_loss)

In [None]:

# Column names of the localization labels; used elsewhere in this file to
# select the per-protein multi-label target from the localization table.
CATEGORIES = ["Membrane","Cytoplasm","Nucleus","Extracellular","Cell membrane","Mitochondrion","Plastid","Endoplasmic reticulum","Lysosome/Vacuole","Golgi apparatus","Peroxisome"]
# Sorting-signal type names. convert_to_binary() maps names to positions in
# a vector of length len(SS_CATEGORIES) - 1, i.e. without a slot for "NULL".
SS_CATEGORIES = ["NULL", "SP", "TM", "MT", "CH", "TH", "NLS", "NES", "PTS", "GPI"] 

# Keys of EMBEDDINGS; compared against `model_attrs.model_type` in
# generate_embeddings().
FAST = "Fast"
ACCURATE = "Accurate"

# Per-model-type file locations: "embeds" is the HDF5 file the embeddings
# are written to and "source_fasta" the FASTA file they are computed from
# (see generate_embeddings()). "config" is not read in the visible part of
# this file -- presumably a model configuration file; verify.
EMBEDDINGS = {
    FAST: {
        "embeds": "data_files/embeddings/esm1b_swissprot.h5",
        "config": "swissprot_esm1b.yaml",
        "source_fasta": "data_files/deeploc_swissprot_clipped1k.fasta"
    },
    ACCURATE: {
        "embeds": "data_files/embeddings/prott5_swissprot.h5",
        "config": "swissprot_prott5.yaml",
        "source_fasta": "data_files/deeploc_swissprot_clipped4k.fasta"
    }
}

# Pickled sorting-signal annotations and the CSV of localization labels,
# both loaded by get_swissprot_df().
SIGNAL_DATA = "data_files/multisub_ninesignals.pkl"
LOCALIZATION_DATA = "./data_files/multisub_5_partitions_unique.csv"

# Maximum number of sequences per batch for the embedding dataloaders and
# the batch size of the signal-type dataloaders.
BATCH_SIZE = 128
# Not read in the visible part of this file -- presumably loss-term
# multipliers used by the training code; verify.
SUP_LOSS_MULT = 0.1
REG_LOSS_MULT = 0.1


In [None]:
import pickle
import torch
from Bio import SeqIO
import re
import pandas as pd
import time 
import os
class FastaBatchedDatasetTorch(torch.utils.data.Dataset):
    """
    Dataset of (sequence, accession) pairs backed by a DataFrame that has
    "Sequence" and "ACC" columns, with a helper that groups items into
    length-aware batches under a token budget.
    """

    def __init__(self, data_df):
        self.data_df = data_df

    def __len__(self):
        return len(self.data_df)

    def shuffle(self):
        """Randomly permute the rows and renumber the index from 0."""
        permuted = self.data_df.sample(frac=1)
        self.data_df = permuted.reset_index(drop=True)

    def __getitem__(self, idx):
        sequence = self.data_df["Sequence"][idx]
        accession = self.data_df["ACC"][idx]
        return sequence, accession

    def get_batch_indices(self, toks_per_batch, extra_toks_per_seq=0):
        """
        Group row positions into batches, longest sequences first.

        A batch is closed as soon as adding the next sequence would make
        (longest length in batch) * (number of sequences) exceed
        `toks_per_batch`; `extra_toks_per_seq` is added to every sequence
        length before that check. Returns a list of lists of positions.
        """
        by_length = sorted(
            ((len(seq), pos) for pos, seq in enumerate(self.data_df["Sequence"])),
            reverse=True,
        )

        batches = []
        current = []
        widest = 0
        for seq_len, pos in by_length:
            cost = seq_len + extra_toks_per_seq
            if max(cost, widest) * (len(current) + 1) > toks_per_batch:
                # Budget exceeded: close the running batch, if it has items
                if current:
                    batches.append(current)
                current = []
                widest = 0
            widest = max(widest, cost)
            current.append(pos)

        if current:
            batches.append(current)
        return batches

class BatchConverterProtT5(object):
    """Callable that turns a raw batch of (sequence, label) pairs into
    tokenizer output plus lengths, a mask and the labels.

    `alphabet` must provide `batch_encode_plus(...)` returning a mapping
    with an 'input_ids' entry (presumably a T5 tokenizer; verify).
    """

    def __init__(self, alphabet):
        self.alphabet = alphabet

    def __call__(self, raw_batch):
        # Kept so that an empty batch still fails fast with ValueError
        # (max() of an empty sequence).
        max(len(seq) for seq, _ in raw_batch)

        sequences = []
        labels = []
        for seq, label in raw_batch:
            sequences.append(seq)
            labels.append(label)
        lengths = [len(seq) for seq in sequences]

        # Separate residues with spaces, then map the residues U, Z, O and
        # B to X before tokenizing.
        spaced = [" ".join(list(seq)) for seq in sequences]
        cleaned = [re.sub(r"[UZOB]", "X", text) for text in spaced]
        ids = self.alphabet.batch_encode_plus(cleaned, add_special_tokens=True, padding=True)

        # NOTE(review): this comparison is true for every non-negative token
        # id, so the mask does not exclude padding -- confirm that is
        # intended.
        non_pad_mask = torch.tensor(ids['input_ids']) > -100 # B, T

        return ids, torch.tensor(lengths), non_pad_mask, labels


class BatchConverter(object):
    """Callable that turns a raw batch of (sequence, label) pairs into a
    padded token tensor plus lengths, a residue mask and the labels.

    `alphabet` must expose `prepend_bos`, `append_eos`, `padding_idx`,
    `cls_idx`, `eos_idx` and `get_idx(char)`.
    """

    def __init__(self, alphabet):
        self.alphabet = alphabet

    def __call__(self, raw_batch):
        alphabet = self.alphabet
        # 1 if the alphabet adds a leading / trailing special token, else 0
        bos = int(alphabet.prepend_bos)
        eos = int(alphabet.append_eos)

        longest = max(len(seq) for seq, _ in raw_batch)
        tokens = torch.full(
            (len(raw_batch), longest + bos + eos),
            alphabet.padding_idx,
            dtype=torch.int64,
        )

        labels = []
        lengths = []
        for row, (seq, label) in enumerate(raw_batch):
            labels.append(label)
            lengths.append(len(seq))
            if alphabet.prepend_bos:
                tokens[row, 0] = alphabet.cls_idx
            encoded = torch.tensor([alphabet.get_idx(ch) for ch in seq], dtype=torch.int64)
            tokens[row, bos:bos + len(seq)] = encoded
            if alphabet.append_eos:
                tokens[row, bos + len(seq)] = alphabet.eos_idx

        # True only at residue positions: not padding, not cls, not eos
        is_special = (
            tokens.eq(alphabet.padding_idx)
            | tokens.eq(alphabet.cls_idx)
            | tokens.eq(alphabet.eos_idx)
        )
        non_pad_mask = ~is_special  # B, T

        return tokens, torch.tensor(lengths), non_pad_mask, labels

def read_fasta(fastafile):
    """Parse a FASTA file into a dict mapping record id to sequence string.

    If an id occurs more than once, the last record with that id wins.
    """
    return {
        str(record.id): str(record.seq)
        for record in SeqIO.parse(fastafile, "fasta")
    }

# with open("/tools/src/deeploc-2.0/models/ESM1b_alphabet.pkl", "rb") as f:
#     alphabet = pickle.load(f)

###################################
#######   TRAINING STUFF  #########
###################################

import h5py
import numpy as np
import pickle5
from sklearn.model_selection import ShuffleSplit
from src.constants import *

def get_swissprot_df(clip_len):
    """
    Load the localization table and attach sorting-signal annotations.

    Reads the pickled annotation frame (SIGNAL_DATA) and the localization
    CSV (LOCALIZATION_DATA), clips sequences and annotation arrays longer
    than `clip_len` by keeping their two ends, drops the excluded
    accessions, and left-joins the annotations onto the localization rows
    by "ACC". Rows without an annotation get 0 in "TargetAnnot".
    Returns the merged DataFrame, including a "Target" column holding the
    CATEGORIES values of each row as a list.
    """
    # NOTE(review): unpickling executes arbitrary code; SIGNAL_DATA must
    # come from a trusted source.
    with open(SIGNAL_DATA, "rb") as handle:
        annot_df = pickle5.load(handle)

    excluded_from_annotations = ['Q7TPV4','P47973','P38398','P38861','Q16665','O15392','Q9Y8G3','O14746','P13350','Q06142']
    excluded_from_swissprot = ['Q04656-5','O43157','Q9UPN3-2']

    def _clip_array(values):
        # Keep the first and last parts of an over-long annotation array
        if len(values) > clip_len:
            values = np.concatenate((values[:clip_len//2], values[-clip_len//2:]), axis=0)
        return values

    def _clip_sequence(seq):
        # Same clipping for the amino-acid string
        if len(seq) > clip_len:
            seq = seq[:clip_len//2] + seq[-clip_len//2:]
        return seq

    annot_df["TargetAnnot"] = annot_df["ANNOT"].apply(_clip_array)

    data_df = pd.read_csv(LOCALIZATION_DATA)
    data_df["Sequence"] = data_df["Sequence"].apply(_clip_sequence)
    data_df["Target"] = data_df[CATEGORIES].values.tolist()

    keep_annot = ~annot_df.ACC.isin(excluded_from_annotations)
    annot_df = annot_df[keep_annot].reset_index(drop=True)
    keep_data = ~data_df.ACC.isin(excluded_from_swissprot)
    data_df = data_df[keep_data].reset_index(drop=True)

    annot_columns = annot_df[["ACC", "ANNOT", "Types", "TargetAnnot"]]
    data_df = data_df.merge(annot_columns, on="ACC", how="left")
    data_df['TargetAnnot'] = data_df['TargetAnnot'].fillna(0)

    return data_df

def convert_to_binary(x):
    """
    Convert an underscore-separated string of signal type names into a
    multi-hot vector.

    The vector has one slot for every entry of SS_CATEGORIES except the
    first one ("NULL"), in the same order. The first category therefore
    sets no bit; previously it mapped to index -1 and silently switched on
    the last slot.
    :param x: signal type names joined by "_", e.g. "SP_TM"
    :return: float numpy array of length len(SS_CATEGORIES) - 1
    :raises ValueError: if a name is not in SS_CATEGORIES
    """
    types_binary = np.zeros((len(SS_CATEGORIES)-1,))
    for c in x.split("_"):
        category_index = SS_CATEGORIES.index(c)
        if category_index == 0:
            # The first category has no slot in the output vector
            continue
        types_binary[category_index - 1] = 1
    return types_binary

def get_swissprot_ss_Xy(save_path, fold, clip_len):
    """
    Build the feature/label matrices for the signal-type classifier.

    Reads the pickled annotations (SIGNAL_DATA) plus two prediction frames
    from `save_path`: "inner_{fold}_1Layer.pkl" (training proteins) and
    "{fold}_1Layer.pkl" (test proteins); the two must not share any "ACC".
    Only annotation rows with a non-empty "Types" value are used.

    Features are the "embeds" vector of a protein concatenated with a
    localization vector. Every training protein appears twice in the
    training set: once with its true "Target" labels and once with its
    predicted "preds". The test set uses "preds" only. Labels are the
    multi-hot encoding of "Types".
    Returns (X_train, y_train, X_test, y_test) as numpy arrays.
    """
    # NOTE(review): unpickling executes arbitrary code; these files must
    # come from a trusted source.
    with open(SIGNAL_DATA, "rb") as handle:
        annot_df = pickle5.load(handle)

    def _clip_sequence(seq):
        # Keep the two ends of an over-long sequence
        if len(seq) > clip_len:
            seq = seq[:clip_len//2] + seq[-clip_len//2:]
        return seq

    def _column_matrix(frame, column):
        # Stack a column of per-row vectors into one 2D array
        return np.stack(frame[column].to_numpy())

    def _labelled_frame(pred_df):
        # Join predictions onto the annotated proteins and add targets
        frame = signal_df.merge(pred_df)
        frame["Sequence"] = frame["Sequence"].apply(_clip_sequence)
        frame["Target"] = frame[CATEGORIES].values.tolist()
        frame["TargetSignal"] = frame["Types"].apply(convert_to_binary)
        return frame

    train_pred_df = pd.read_pickle(os.path.join(save_path, f"inner_{fold}_1Layer.pkl"))
    test_pred_df = pd.read_pickle(os.path.join(save_path, f"{fold}_1Layer.pkl"))
    assert train_pred_df.merge(test_pred_df, on="ACC").empty == True

    signal_df = annot_df[annot_df["Types"]!=""].reset_index(drop=True)

    # Training part: same proteins once with true and once with predicted
    # localization
    train_frame = _labelled_frame(train_pred_df)
    X_true_train = np.concatenate(
        (_column_matrix(train_frame, "embeds"), _column_matrix(train_frame, "Target")),
        axis=1,
    )
    y_true_train = _column_matrix(train_frame, "TargetSignal")
    X_pred_train = np.concatenate(
        (_column_matrix(train_frame, "embeds"), _column_matrix(train_frame, "preds")),
        axis=1,
    )
    y_pred_train = _column_matrix(train_frame, "TargetSignal")

    # Test part: predicted localization only
    test_frame = _labelled_frame(test_pred_df)
    X_test = np.concatenate(
        (_column_matrix(test_frame, "embeds"), _column_matrix(test_frame, "preds")),
        axis=1,
    )
    y_test = _column_matrix(test_frame, "TargetSignal")

    X_train = np.concatenate((X_true_train, X_pred_train), axis=0)
    y_train = np.concatenate((y_true_train, y_pred_train), axis=0)

    return X_train, y_train, X_test, y_test


class EmbeddingsLocalizationDataset(torch.utils.data.Dataset):
    """
    Dataset pairing each protein's stored embedding with its localization
    target.

    `embedding_file` is any mapping from accession to an array-like (an
    open HDF5 file in this project); `data_df` needs the columns
    "Sequence", "ACC", "Target" and "TargetAnnot".
    """

    def __init__(self, embedding_file, data_df) -> None:
        super().__init__()
        self.data_df = data_df
        self.embeddings_file = embedding_file

    def __len__(self) -> int:
        return len(self.data_df)

    def __getitem__(self, index: int):
        """Return (sequence, embedding copy, target, target annotation,
        accession) for the row labelled `index`."""
        accession = self.data_df["ACC"][index]
        # Copy so callers never hold a reference into the backing store
        embedding = np.array(self.embeddings_file[accession]).copy()
        return (
            self.data_df["Sequence"][index],
            embedding,
            self.data_df["Target"][index],
            self.data_df["TargetAnnot"][index],
            accession,
        )

    def get_batch_indices(self, toks_per_batch, max_batch_size, extra_toks_per_seq=0):
        """
        Group row positions into batches, longest sequences first.

        A batch is closed when adding the next sequence would make
        (longest length in batch) * (number of sequences) exceed
        `toks_per_batch`, or when it already holds `max_batch_size` items.
        `extra_toks_per_seq` is added to every length before the check.
        Returns a list of lists of positions.
        """
        by_length = sorted(
            ((len(seq), pos) for pos, seq in enumerate(self.data_df["Sequence"])),
            reverse=True,
        )

        batches = []
        current = []
        widest = 0
        for seq_len, pos in by_length:
            cost = seq_len + extra_toks_per_seq
            over_budget = max(cost, widest) * (len(current) + 1) > toks_per_batch
            if over_budget or len(current) >= max_batch_size:
                if current:
                    batches.append(current)
                current = []
                widest = 0
            widest = max(widest, cost)
            current.append(pos)

        if current:
            batches.append(current)
        return batches

class TrainBatchConverter(object):
    """Callable that collates (sequence, embedding, target, target
    annotation, label) items into zero-padded batch tensors.
    """

    def __init__(self, alphabet, embed_len):
        self.alphabet = alphabet
        self.embed_len = embed_len

    def __call__(self, raw_batch):
        n_items = len(raw_batch)
        longest = max(len(item[0]) for item in raw_batch)

        # Everything is padded to the longest sequence in the batch
        embeddings = torch.zeros((n_items, longest, self.embed_len), dtype=torch.float32)
        valid = torch.zeros((n_items, longest), dtype=torch.bool)
        annotations = torch.zeros((n_items, longest), dtype=torch.int64)
        # One row of 11 target values per item
        targets = torch.zeros((n_items, 11), dtype=torch.float32)

        labels = []
        lengths = []
        for row, (seq, embedding, target, target_annot, label) in enumerate(raw_batch):
            seq_len = len(seq)
            labels.append(label)
            lengths.append(seq_len)
            targets[row] = torch.tensor(target)
            embeddings[row, :seq_len] = torch.tensor(np.array(embedding))
            annotations[row, :seq_len] = torch.tensor(target_annot)
            valid[row, :seq_len] = True

        return embeddings, torch.tensor(lengths), valid, targets, annotations, labels
    
class SignalTypeDataset(torch.utils.data.Dataset):
    """
    Dataset over a feature array `X` and a label array `y`, indexed along
    their first axis; items are returned as float tensors.
    """

    def __init__(self, X, y) -> None:
        super().__init__()
        self.X = X
        self.y = y

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, index: int):
        features = torch.tensor(self.X[index]).float()
        label = torch.tensor(self.y[index]).float()
        return features, label


class DataloaderHandler:
    """
    Builds the dataloaders used for training and evaluation from the
    SwissProt data frame and a file of precomputed embeddings.
    """

    def __init__(self, clip_len, alphabet, embedding_file, embed_len) -> None:
        self.clip_len = clip_len
        self.alphabet = alphabet
        self.embedding_file = embedding_file
        self.embed_len = embed_len

    def _embedding_dataloader(self, embedding_handle, frame):
        # Wrap `frame` in a dataset over the open embedding file and batch
        # it by length under a fixed token budget.
        dataset = EmbeddingsLocalizationDataset(embedding_handle, frame)
        batches = dataset.get_batch_indices(4096*4, BATCH_SIZE, extra_toks_per_seq=0)
        collate = TrainBatchConverter(self.alphabet, self.embed_len)
        return torch.utils.data.DataLoader(dataset, collate_fn=collate, batch_sampler=batches)

    def get_train_val_dataloaders(self, outer_i):
        """Return (train, validation) loaders over every partition except
        `outer_i`, with 2048 rows held out for validation (fixed seed)."""
        data_df = get_swissprot_df(self.clip_len)
        train_df = data_df[data_df.Partition != outer_i].reset_index(drop=True)

        accessions = np.stack(train_df["ACC"].to_numpy())
        splitter = ShuffleSplit(n_splits=1, test_size=2048, random_state=0)
        train_idx, val_idx = next(splitter.split(accessions))
        split_train_df = train_df.iloc[train_idx].reset_index(drop=True)
        split_val_df = train_df.iloc[val_idx].reset_index(drop=True)

        # Both loaders share one open handle; it is not closed here because
        # the datasets read from it lazily.
        embedding_handle = h5py.File(self.embedding_file, "r")
        train_dataloader = self._embedding_dataloader(embedding_handle, split_train_df)
        val_dataloader = self._embedding_dataloader(embedding_handle, split_val_df)
        return train_dataloader, val_dataloader

    def get_partition(self, outer_i):
        """Return the rows that belong to partition `outer_i`."""
        data_df = get_swissprot_df(self.clip_len )
        return data_df[data_df.Partition == outer_i].reset_index(drop=True)

    def get_partition_dataloader(self, outer_i):
        """Return (loader, frame) for the rows of partition `outer_i`."""
        data_df = get_swissprot_df(self.clip_len)
        test_df = data_df[data_df.Partition == outer_i].reset_index(drop=True)

        embedding_handle = h5py.File(self.embedding_file, "r")
        test_dataloader = self._embedding_dataloader(embedding_handle, test_df)
        return test_dataloader, test_df

    def get_partition_dataloader_inner(self, partition_i):
        """Return (loader, frame) for all rows NOT in `partition_i`."""
        data_df = get_swissprot_df(self.clip_len)
        test_df = data_df[data_df.Partition != partition_i].reset_index(drop=True)

        embedding_handle = h5py.File(self.embedding_file, "r")
        test_dataloader = self._embedding_dataloader(embedding_handle, test_df)
        return test_dataloader, test_df

    def get_ss_train_val_dataloader(self, save_path, outer_i):
        """Return (train, validation) loaders for the signal-type
        classifier, holding out 10% of the training rows (fixed seed)."""
        X, y, _, _ = get_swissprot_ss_Xy(save_path, outer_i, clip_len=self.clip_len)

        splitter = ShuffleSplit(n_splits=1, test_size=0.1, random_state=0)
        train_idx, val_idx = next(splitter.split(y))
        split_train_X, split_train_y = X[train_idx], y[train_idx]
        split_val_X, split_val_y = X[val_idx], y[val_idx]

        print(split_train_X.shape, split_train_y.shape, split_val_X.shape, split_val_y.shape)

        train_dataloader = torch.utils.data.DataLoader(
            SignalTypeDataset(split_train_X, split_train_y),
            shuffle=True,
            batch_size=BATCH_SIZE,
            drop_last=True)
        val_dataloader = torch.utils.data.DataLoader(
            SignalTypeDataset(split_val_X, split_val_y),
            shuffle=False,
            batch_size=BATCH_SIZE)

        return train_dataloader, val_dataloader

    def get_ss_test_dataloader(self, save_path, outer_i):
        """Return a loader that yields the whole signal-type test set as a
        single batch."""
        _, _, X, y = get_swissprot_ss_Xy(save_path, outer_i, clip_len=self.clip_len)

        print(X.shape, y.shape)
        return torch.utils.data.DataLoader(
            SignalTypeDataset(X, y),
            shuffle=False,
            batch_size=X.shape[0])

    def get_swissprot_ss_xy(self, save_path, outer_i):
        """Return the raw (X_train, y_train, X_test, y_test) arrays."""
        return get_swissprot_ss_Xy(save_path=save_path, fold=outer_i, clip_len=self.clip_len)












In [None]:
import h5py
import torch
from esm import Alphabet, FastaBatchedDataset, pretrained
from transformers import T5EncoderModel, T5Tokenizer
import tqdm
from src.data import *
from src.utils import *
import os
# Module-level choice of compute device and autocast dtype, read by
# embed_esm1b() and embed_prott5() below: float16 on CUDA, bfloat16 on CPU.
if torch.cuda.is_available():
    device = "cuda"
    dtype = torch.float16
elif torch.backends.mps.is_available():
    # NOTE(review): when MPS is available the device is still set to "cpu",
    # identical to the final branch -- confirm this is deliberate.
    device = "cpu"
    dtype=torch.bfloat16
else:
    device = "cpu"
    dtype=torch.bfloat16

def embed_esm1b(embed_dataloader, out_file):
    """
    Embed every sequence from `embed_dataloader` with ESM-1b and store the
    per-residue embeddings in an HDF5 file.

    :param embed_dataloader: yields (tokens, lengths, mask, labels) batches
    :param out_file: path of the HDF5 file to create; one float16 dataset
        per label, holding layer-33 representations without the leading
        special token
    :raises Exception: "Failed to create embeddings" if anything goes wrong
        while embedding; the original error is chained as the cause and the
        partially written file is removed
    """
    model, _ = pretrained.load_model_and_alphabet("esm1b_t33_650M_UR50S")
    model.eval().to(device)
    embed_h5 = h5py.File(out_file, "w")
    try:
        with torch.autocast(device_type=device,dtype=dtype):
            with torch.no_grad():
                for _, (toks, lengths, np_mask, labels) in tqdm.tqdm(enumerate(embed_dataloader)):
                    embed = model(toks.to(device), repr_layers=[33])["representations"][33].float().cpu().numpy()
                    for j in range(len(labels)):
                        # removing start and end tokens
                        embed_h5[labels[j]] = embed[j, 1:1+lengths[j]].astype(np.float16)
    except BaseException as err:
        # Close the handle before deleting, and delete through the OS API
        # rather than a shell command so odd paths cannot be misparsed.
        embed_h5.close()
        try:
            os.remove(out_file)
        except OSError:
            # Best-effort cleanup; the original failure matters more.
            pass
        if isinstance(err, Exception):
            raise Exception("Failed to create embeddings") from err
        # KeyboardInterrupt / SystemExit propagate unchanged after cleanup
        raise
    else:
        embed_h5.close()
    

def embed_prott5(embed_dataloader, out_file):
    """
    Embed every sequence from `embed_dataloader` with the ProtT5 encoder
    and store the per-residue embeddings in an HDF5 file.

    :param embed_dataloader: yields (tokenizer output, lengths, mask,
        labels) batches; the tokenizer output must provide 'input_ids' and
        'attention_mask'
    :param out_file: path of the HDF5 file to create; one float16 dataset
        per label, holding the first `length` rows of the last hidden state
    :raises Exception: "Failed to create embeddings" if anything goes wrong
        while embedding; the original error is chained as the cause and the
        partially written file is removed
    """
    model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50")
    model.eval().to(device)
    embed_h5 = h5py.File(out_file, "w")
    try:
        with torch.autocast(device_type=device,dtype=dtype):
            with torch.no_grad():
                for _, (toks, lengths, np_mask, labels) in tqdm.tqdm(enumerate(embed_dataloader)):
                    embed = model(input_ids=torch.tensor(toks['input_ids'], device=device),
                    attention_mask=torch.tensor(toks['attention_mask'], 
                        device=device)).last_hidden_state.float().cpu().numpy()
                    for j in range(len(labels)):
                        # removing end tokens
                        embed_h5[labels[j]] = embed[j, :lengths[j]].astype(np.float16)
    except BaseException as err:
        # Close the handle before deleting, and delete through the OS API
        # rather than a shell command so odd paths cannot be misparsed.
        embed_h5.close()
        try:
            os.remove(out_file)
        except OSError:
            # Best-effort cleanup; the original failure matters more.
            pass
        if isinstance(err, Exception):
            raise Exception("Failed to create embeddings") from err
        # KeyboardInterrupt / SystemExit propagate unchanged after cleanup
        raise
    else:
        embed_h5.close()

def generate_embeddings(model_attrs: ModelAttributes):
    """
    Compute and store embeddings for the FASTA file configured for the
    model type of `model_attrs`.

    The sequences are read from EMBEDDINGS[model_type]["source_fasta"],
    batched by length, and embedded with ESM-1b (FAST) or ProtT5
    (ACCURATE); the result is written to EMBEDDINGS[model_type]["embeds"].
    Raises an Exception for any other model type.
    """
    fasta_dict = read_fasta(EMBEDDINGS[model_attrs.model_type]["source_fasta"])
    sequences_df = pd.DataFrame(fasta_dict.items(), columns=['ACC', 'Sequence'])
    embed_dataset = FastaBatchedDatasetTorch(sequences_df)
    embed_batches = embed_dataset.get_batch_indices(8196, extra_toks_per_seq=1)

    # Pick the batch converter and embedding routine for the model type
    if model_attrs.model_type == FAST:
        converter_cls, embed_fn = BatchConverter, embed_esm1b
    elif model_attrs.model_type == ACCURATE:
        converter_cls, embed_fn = BatchConverterProtT5, embed_prott5
    else:
        raise Exception("wrong model type provided expected Fast,Accurate got", model_attrs.model_type)

    embed_dataloader = torch.utils.data.DataLoader(
        embed_dataset,
        collate_fn=converter_cls(model_attrs.alphabet),
        batch_sampler=embed_batches,
    )
    embed_fn(embed_dataloader, EMBEDDINGS[model_attrs.model_type]["embeds"])
    

In [None]:
import os
import tqdm
import pandas as pd
import numpy as np
import pickle
import torch
from src.utils import ModelAttributes
from src.data import DataloaderHandler
from src.metrics import *

# Module-level choice of compute device and autocast dtype, read by
# predict_sl_values() below: float16 on CUDA, bfloat16 on CPU.
if torch.cuda.is_available():
    device = "cuda"
    dtype = torch.float16
elif torch.backends.mps.is_available():
    # NOTE(review): when MPS is available the device is still set to "cpu",
    # identical to the final branch -- confirm this is deliberate.
    device = "cpu"
    dtype=torch.bfloat16
else:
    device = "cpu"
    dtype=torch.bfloat16

def predict_sl_values(dataloader, model):
    """Run ``model.predict`` over a dataloader and collect per-sequence outputs.

    :param dataloader: yields 6-tuples
        ``(toks, lengths, np_mask, targets, targets_seq, labels)``
    :param model: object exposing ``predict(toks, lengths, np_mask)`` that
        returns ``(y_pred, y_pool, y_attn)``
    :return: DataFrame with columns ``ACC``, ``preds`` (sigmoid of ``y_pred``),
        ``pred_annot`` (attention cut to the sequence length) and ``embeds``
        (pooled representation)
    """
    preds_by_acc = {}
    attn_by_acc = {}
    pooled_by_acc = {}

    with torch.no_grad():
        for _, batch in tqdm.tqdm(enumerate(dataloader)):
            toks, lengths, np_mask, targets, targets_seq, labels = batch
            with torch.autocast(device_type=device, dtype=dtype):
                y_pred, y_pool, y_attn = model.predict(
                    toks.to(device), lengths.to(device), np_mask.to(device))
            probs = torch.sigmoid(y_pred).float().cpu().numpy()

            is_single = len(labels) == 1
            for j, acc in enumerate(labels):
                if is_single:
                    # Single-item batch: outputs are stored without indexing
                    # the batch axis.
                    preds_by_acc[acc] = probs
                    pooled_by_acc[acc] = y_pool.float().cpu().numpy()
                    attn_by_acc[acc] = y_attn[:lengths[j]].float().cpu().numpy()
                else:
                    preds_by_acc[acc] = probs[j]
                    pooled_by_acc[acc] = y_pool[j].float().cpu().numpy()
                    attn_by_acc[acc] = y_attn[j, :lengths[j]].float().cpu().numpy()

    preds_frame = pd.DataFrame(preds_by_acc.items(), columns=['ACC', 'preds'])
    attn_frame = pd.DataFrame(attn_by_acc.items(), columns=['ACC', 'pred_annot'])
    pooled_frame = pd.DataFrame(pooled_by_acc.items(), columns=['ACC', 'embeds'])
    return preds_frame.merge(attn_frame).merge(pooled_frame)
    
def generate_sl_outputs(
        model_attrs: ModelAttributes,
        datahandler: DataloaderHandler,
        thresh_type="mcc",
        inner_i="1Layer",
        reuse=False):
    """Produce cached predictions and decision thresholds for the 5 folds.

    For each fold the inner-partition predictions are computed (or read from
    ``inner_{fold}_{inner_i}.pkl``), thresholds are derived from them, and the
    outer-partition predictions are written to ``{fold}_{inner_i}.pkl`` if that
    file does not exist yet. All thresholds are finally pickled to
    ``thresholds_sl_{thresh_type}.pkl`` in ``model_attrs.outputs_save_path``.

    The fold's checkpoint is loaded lazily and at most once per fold. It is
    loaded whenever either prediction file has to be (re)computed, so a cached
    inner file combined with a missing outer file no longer hits an undefined
    or stale ``model``.

    :param model_attrs: provides ``save_path``, ``outputs_save_path`` and
        ``class_type`` (used via ``load_from_checkpoint``)
    :param datahandler: provides the inner and outer partition dataloaders
    :param thresh_type: ``"roc"``, ``"pr"``, anything else selects MCC
    :param inner_i: suffix identifying the checkpoint / output files
    :param reuse: accepted for interface compatibility; currently unused
    """
    threshold_dict = {}

    for outer_i in range(5):
        print("Generating output for ensemble model", outer_i)
        inner_out_path = os.path.join(model_attrs.outputs_save_path, f"inner_{outer_i}_{inner_i}.pkl")
        outer_out_path = os.path.join(model_attrs.outputs_save_path, f"{outer_i}_{inner_i}.pkl")
        ckpt_path = f"{model_attrs.save_path}/{outer_i}_{inner_i}.ckpt"
        model = None  # loaded on demand, never carried over from another fold

        dataloader, data_df = datahandler.get_partition_dataloader_inner(outer_i)
        if not os.path.exists(inner_out_path):
            model = model_attrs.class_type.load_from_checkpoint(ckpt_path).to(device).eval()
            pred_df = predict_sl_values(dataloader, model)
            pred_df.to_pickle(inner_out_path)
        else:
            pred_df = pd.read_pickle(inner_out_path)

        if thresh_type == "roc":
            thresholds = get_optimal_threshold(pred_df, data_df)
        elif thresh_type == "pr":
            thresholds = get_optimal_threshold_pr(pred_df, data_df)
        else:
            thresholds = get_optimal_threshold_mcc(pred_df, data_df)
        threshold_dict[f"{outer_i}_{inner_i}"] = thresholds

        if not os.path.exists(outer_out_path):
            if model is None:
                # Inner predictions came from cache, so the checkpoint has not
                # been loaded yet for this fold.
                model = model_attrs.class_type.load_from_checkpoint(ckpt_path).to(device).eval()
            dataloader, data_df = datahandler.get_partition_dataloader(outer_i)
            output_df = predict_sl_values(dataloader, model)
            output_df.to_pickle(outer_out_path)

    with open(os.path.join(model_attrs.outputs_save_path, f"thresholds_sl_{thresh_type}.pkl"), "wb") as f:
        pickle.dump(threshold_dict, f)

def predict_ss_values(X, model):
    """Return ``sigmoid(model(X))`` as a NumPy array.

    :param X: array-like input, converted to a float tensor on ``device``
    :param model: callable mapping that tensor to logits
    """
    inputs = torch.tensor(X, device=device).float()
    logits = model(inputs)
    probabilities = torch.sigmoid(logits)
    return probabilities.detach().cpu().numpy()

def generate_ss_outputs(
        model_attrs: ModelAttributes,
        datahandler: DataloaderHandler,
        thresh_type="mcc",
        inner_i="1Layer",
        reuse=False):
    """Predict sorting-signal types for the 5 folds and store MCC thresholds.

    Per fold, the signal-type checkpoint ``signaltype/{fold}.ckpt`` is loaded,
    per-category thresholds are fitted on the training predictions, and the
    test predictions are pickled to ``ss_{fold}.pkl``. The thresholds are
    written to ``thresholds_ss_mcc.pkl``.

    NOTE(review): the threshold dict is keyed by category only, so every fold
    overwrites the previous one and the saved file holds the thresholds of the
    last fold. This matches what ``calculate_ss_metrics`` reads; confirm that
    per-fold thresholds were not intended.

    :param model_attrs: provides ``save_path`` and ``outputs_save_path``
    :param datahandler: provides ``get_swissprot_ss_xy``
    :param thresh_type: accepted for interface compatibility; MCC is always used
    :param inner_i: accepted for interface compatibility; unused
    :param reuse: accepted for interface compatibility; unused
    """
    threshold_dict = {}
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(f"{model_attrs.outputs_save_path}", exist_ok=True)

    for outer_i in range(5):
        print("Generating output for ensemble model", outer_i)
        X_train, y_train, X_test, y_test = datahandler.get_swissprot_ss_xy(model_attrs.outputs_save_path, outer_i)
        path = f"{model_attrs.save_path}/signaltype/{outer_i}.ckpt"
        model = SignalTypeMLP.load_from_checkpoint(path).to(device).eval()

        y_train_preds = predict_ss_values(X_train, model)
        # Going through a float64 array keeps the stored threshold dtype
        # identical to previous outputs.
        thresh = np.zeros((9,))
        for type_i in range(9):
            thresh[type_i] = get_best_threshold_mcc(y_train[:, type_i], y_train_preds[:, type_i])
            threshold_dict[SS_CATEGORIES[type_i+1]] = thresh[type_i]

        y_test_preds = predict_ss_values(X_test, model)
        # Context manager so the file is flushed and closed deterministically.
        with open(f"{model_attrs.outputs_save_path}/ss_{outer_i}.pkl", "wb") as f:
            pickle.dump(y_test_preds, f)

    with open(os.path.join(model_attrs.outputs_save_path, f"thresholds_ss_mcc.pkl"), "wb") as f:
        pickle.dump(threshold_dict, f)

In [None]:
from sklearn.metrics import f1_score
from sklearn.metrics import hamming_loss, matthews_corrcoef, confusion_matrix, roc_auc_score
from sklearn.metrics import jaccard_score
from sklearn.metrics import classification_report
from sklearn import metrics
import pickle
from src.constants import *
from src.utils import *
import numpy as np
import pandas as pd
import os
import json


# taken from https://www.kaggle.com/cpmpml/optimizing-probabilities-for-best-mcc
def mcc(tp, tn, fp, fn):
    """Matthews correlation coefficient from confusion-matrix counts.

    :return: ``(tp*tn - fp*fn) / sqrt((tp+fp)(tp+fn)(tn+fp)(tn+fn))``, or 0
        when that product is zero
    """
    denominator_sq = (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)
    if denominator_sq == 0:
        return 0
    numerator = tp * tn - fp * fn
    return numerator / np.sqrt(denominator_sq)
        
def get_best_threshold_mcc(y_true, y_prob):
    """Find the score threshold that maximises the MCC.

    Samples are swept in ascending score order. At each distinct score the MCC
    is evaluated for the rule "predict positive iff score >= this score". The
    last score whose MCC is >= the best seen so far (starting at 0.0) wins.

    :param y_true: 1-D array of 0/1 labels
    :param y_prob: 1-D array of scores, aligned with ``y_true``
    :return: the selected score, or -1 when the input is empty
    """
    order = np.argsort(y_prob)
    y_true_sort = y_true[order]
    n = y_true.shape[0]
    nump = 1.0 * np.sum(y_true)  # number of positives
    numn = n - nump  # number of negatives

    # Start with every sample predicted positive.
    tp, tn, fp, fn = nump, 0.0, numn, 0.0
    best_mcc = 0.0
    best_proba = -1
    prev_proba = None
    for i in range(n):
        # All items before position i are predicted negative, the rest
        # positive; only re-evaluate when the score changes. The explicit
        # i == 0 test replaces a -1 sentinel that broke on a score of -1.
        proba = y_prob[order[i]]
        if i == 0 or proba != prev_proba:
            prev_proba = proba
            new_mcc = mcc(tp, tn, fp, fn)
            if new_mcc >= best_mcc:
                best_mcc = new_mcc
                best_proba = proba
        # Move sample i to the predicted-negative side.
        if y_true_sort[i] == 1:
            tp -= 1.0
            fn += 1.0
        else:
            fp -= 1.0
            tn += 1.0

    return best_proba

def get_optimal_threshold(output_df, data_df):
    """Per-label thresholds maximising ``tpr - fpr`` on the ROC curve.

    :param output_df: frame with a ``preds`` column (merged onto ``data_df``)
    :param data_df: frame with a ``Target`` column
    :return: array of 11 thresholds, one per label column
    """
    merged = data_df.merge(output_df)
    y_score = np.stack(merged["preds"].to_numpy())
    y_true = np.stack(merged["Target"].to_numpy())

    best = np.zeros((11,))
    for col in range(11):
        fpr, tpr, cutoffs = metrics.roc_curve(y_true[:, col], y_score[:, col])
        best[col] = cutoffs[np.argmax(tpr - fpr)]
    return best

def get_optimal_threshold_pr(output_df, data_df):
    """Per-label thresholds maximising F1 along the precision-recall curve.

    :param output_df: frame with a ``preds`` column (merged onto ``data_df``)
    :param data_df: frame with a ``Target`` column
    :return: array of 11 thresholds, one per label column
    """
    test_df = data_df.merge(output_df)

    predictions = np.stack(test_df["preds"].to_numpy())
    actuals = np.stack(test_df["Target"].to_numpy())

    optimal_thresholds = np.zeros((11,))
    for i in range(11):
        pr, re, thresholds = metrics.precision_recall_curve(actuals[:, i], predictions[:, i])
        # Where precision and recall are both 0, F1 is 0/0. Left as NaN,
        # np.argmax would return that index and pick the wrong threshold,
        # so treat those points as F1 = 0.
        denom = pr + re
        fscores = np.divide(2 * pr * re, denom, out=np.zeros_like(denom), where=denom > 0)
        # The final curve point (precision 1, recall 0) has no threshold
        # attached, so it is excluded from the search.
        optimal_idx = np.argmax(fscores[:-1])
        optimal_thresholds[i] = thresholds[optimal_idx]

    return optimal_thresholds

def get_optimal_threshold_mcc(output_df, data_df):
    """Per-label thresholds chosen by ``get_best_threshold_mcc``.

    :param output_df: frame with a ``preds`` column (merged onto ``data_df``)
    :param data_df: frame with a ``Target`` column
    :return: array of 11 thresholds, one per label column
    """
    merged = data_df.merge(output_df)
    y_score = np.stack(merged["preds"].to_numpy())
    y_true = np.stack(merged["Target"].to_numpy())

    best = np.zeros((11,))
    best[:] = [
        get_best_threshold_mcc(y_true[:, col], y_score[:, col])
        for col in range(11)
    ]
    return best

def calculate_sl_metrics_fold(test_df, thresholds):
    """Compute the localisation metrics of one fold.

    Column 0 of the predictions/targets is scored as the "membrane" label, the
    remaining columns as sub-localisation labels.

    NOTE(review): predictions are binarised with a strict ``>``, while the
    threshold search in ``get_best_threshold_mcc`` evaluates ``>=``; confirm
    which one is intended.

    :param test_df: frame with ``preds`` and ``Target`` columns
    :param thresholds: per-label thresholds broadcast against the predictions
    :return: dict of metric name -> value
    """
    print("Computing fold")
    scores = np.stack(test_df["preds"].to_numpy())
    binarised = scores > thresholds
    truth = np.stack(test_df["Target"].to_numpy())

    ypred_membrane, ypred_subloc = binarised[:, 0], binarised[:, 1:]
    y_membrane, y_subloc = truth[:, 0], truth[:, 1:]

    metrics_dict = {
        "NumLabels": y_subloc.sum(1).mean(),
        "NumLabelsTest": ypred_subloc.sum(1).mean(),
        "ACC_membrane": (ypred_membrane == y_membrane).mean(),
        "MCC_membrane": matthews_corrcoef(y_membrane, ypred_membrane),
        "ACC_subloc": (np.all((ypred_subloc == y_subloc), axis=1)).mean(),
        # Stored as 1 - Hamming loss, i.e. higher is better.
        "HammLoss_subloc": 1 - hamming_loss(y_subloc, ypred_subloc),
        "Jaccard_subloc": jaccard_score(y_subloc, ypred_subloc, average="samples"),
        "MicroF1_subloc": f1_score(y_subloc, ypred_subloc, average="micro"),
        "MacroF1_subloc": f1_score(y_subloc, ypred_subloc, average="macro"),
    }
    # Per-label MCC, keyed by the category name (index 0 is skipped).
    for col in range(10):
        metrics_dict[f"{CATEGORIES[1+col]}"] = matthews_corrcoef(y_subloc[:, col], ypred_subloc[:, col])

    return metrics_dict

def calculate_sl_metrics(model_attrs: ModelAttributes, datahandler: DataloaderHandler, thresh_type="mcc", inner_i="1Layer"):
    """Print mean and standard deviation of the localisation metrics over 5 folds.

    Reads ``thresholds_sl_{thresh_type}.pkl`` and the per-fold prediction
    pickles from ``model_attrs.outputs_save_path``. The summary is printed as
    a LaTeX table, then as aligned ``name : mean + std`` lines, then as bare
    ``mean + std`` lines.
    """
    thresholds_path = os.path.join(model_attrs.outputs_save_path, f"thresholds_sl_{thresh_type}.pkl")
    with open(thresholds_path, "rb") as f:
        threshold_dict = pickle.load(f)
    print(np.array(list(threshold_dict.values())).mean(0))

    per_metric = {}
    for outer_i in range(5):
        fold_df = datahandler.get_partition(outer_i).merge(
            pd.read_pickle(os.path.join(model_attrs.outputs_save_path, f"{outer_i}_{inner_i}.pkl")))
        fold_metrics = calculate_sl_metrics_fold(fold_df, threshold_dict[f"{outer_i}_{inner_i}"])
        for name, value in fold_metrics.items():
            per_metric.setdefault(name, []).append(value)

    # Format mean/std once per metric and reuse them for every output style.
    formatted = {}
    for name, values in per_metric.items():
        arr = np.array(values)
        formatted[name] = (f"{round(arr.mean(), 2):.2f}", f"{round(arr.std(), 2):.2f}")

    output_dict = {name: [f"{mean} pm {std}"] for name, (mean, std) in formatted.items()}
    print(pd.DataFrame(output_dict).to_latex())

    for name, (mean, std) in formatted.items():
        print("{0:21s} : {1}".format(name, f"{mean} + {std}"))
    for mean, std in formatted.values():
        print("{0}".format(f"{mean} + {std}"))


def calculate_ss_metrics_fold(y_test, y_test_preds, thresh):
    """Compute the signal-type metrics of one fold.

    NOTE(review): predictions are binarised with a strict ``>``, while the
    threshold search in ``get_best_threshold_mcc`` evaluates ``>=``; confirm
    which one is intended.

    :param y_test: 0/1 targets, one column per signal category
    :param y_test_preds: predicted scores with the same layout
    :param thresh: per-category thresholds broadcast against the scores
    :return: dict with micro/macro F1, exact-match accuracy and per-category MCC
    """
    binarised = y_test_preds > thresh

    metrics_dict = {
        "microF1": f1_score(y_test, binarised, average="micro"),
        "macroF1": f1_score(y_test, binarised, average="macro"),
        "accuracy": (np.all((binarised == y_test), axis=1)).mean(),
    }
    # The first SS_CATEGORIES entry has no prediction column and is skipped.
    for col, category in enumerate(SS_CATEGORIES[1:]):
        metrics_dict[f"{category}"] = matthews_corrcoef(binarised[:, col], y_test[:, col])

    return metrics_dict

def calculate_ss_metrics(model_attrs: ModelAttributes, datahandler: DataloaderHandler, thresh_type="mcc"):
    """Print mean and standard deviation of the signal-type metrics over 5 folds.

    Reads ``thresholds_ss_{thresh_type}.pkl`` and the per-fold ``ss_{fold}.pkl``
    predictions from ``model_attrs.outputs_save_path``. The same threshold
    vector is applied to every fold, and the summary is printed as a LaTeX
    table.

    :param model_attrs: provides ``outputs_save_path``
    :param datahandler: provides ``get_swissprot_ss_xy`` (only ``y_test`` is used)
    :param thresh_type: suffix of the thresholds file to load
    """
    with open(os.path.join(model_attrs.outputs_save_path, f"thresholds_ss_{thresh_type}.pkl"), "rb") as f:
        threshold_dict = pickle.load(f)
    metrics_dict_list = {}
    # Order the thresholds like the prediction columns (first category skipped).
    thresh = np.array([threshold_dict[k] for k in SS_CATEGORIES[1:]])

    for outer_i in range(5):
        _, _, _, y_test = datahandler.get_swissprot_ss_xy(model_attrs.outputs_save_path, outer_i)
        # Context manager so each fold's file handle is closed right away.
        with open(f"{model_attrs.outputs_save_path}/ss_{outer_i}.pkl", "rb") as f:
            y_test_preds = pickle.load(f)
        metrics_dict = calculate_ss_metrics_fold(y_test, y_test_preds, thresh)
        for k in metrics_dict:
            metrics_dict_list.setdefault(k, []).append(metrics_dict[k])

    output_dict = {}
    for k in metrics_dict_list:
        output_dict[k] = [f"{round(np.array(metrics_dict_list[k]).mean(), 2):.2f} pm {round(np.array(metrics_dict_list[k]).std(), 2):.2f}"]
    print(pd.DataFrame(output_dict).to_latex())

In [None]:
from typing import Any
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from .attr_prior import *
from src.constants import *


# Per-label positive-class weights applied to the BCE term of focal_loss.
pos_weights_bce = torch.tensor([1,1,1,3,2.3,4,9.5,4.5,6.6,7.7,32])
def focal_loss(input, target, gamma=1):
    """Focal loss on logits with positive-class weighting.

    The modulating factor ``(1 - pt) ** gamma`` is derived from the unweighted
    BCE and multiplies the ``pos_weights_bce``-weighted BCE.

    :param input: logits whose last dimension matches ``pos_weights_bce``
    :param target: targets with the same shape as ``input``
    :param gamma: focusing exponent
    :return: mean of the element-wise loss
    """
    class_weights = pos_weights_bce.to(input.device)
    weighted_bce = F.binary_cross_entropy_with_logits(
        input, target, pos_weight=class_weights, reduction="none")
    plain_bce = F.binary_cross_entropy_with_logits(input, target, reduction="none")
    # exp(-BCE) is the probability assigned to the true class.
    pt = torch.exp(-plain_bce)
    modulating_factor = (1 - pt) ** gamma
    return (modulating_factor * weighted_bce).mean()

class AttentionHead(nn.Module):
    """Attention pooling over the sequence axis with smoothed scores.

    The hidden vector is split into ``n_heads`` chunks, layer-normalised, and
    scored against the rows of ``Q``. Only the scores of head index 0 are
    smoothed and used as attention logits.
    """

    def __init__(self, hidden_dim, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.hidden_dim = hidden_dim
        head_dim = hidden_dim // n_heads
        self.preattn_ln = nn.LayerNorm(head_dim)
        self.Q = nn.Linear(head_dim, n_heads, bias=False)
        torch.nn.init.normal_(self.Q.weight, mean=0.0, std=1/head_dim)

    def forward(self, x, np_mask, lengths):
        """Pool ``x`` over its second axis.

        :param x: tensor viewed as (batch, seq_len, n_heads, head_dim)
        :param np_mask: boolean mask; positions where it is False get -inf logits
        :param lengths: per-sequence lengths used to slice and re-pad the scores
        :return: ``(pooled, attns)``: the flattened attention-weighted sum and
            the attention weights with axis 2 squeezed
        """
        head_dim = self.hidden_dim // self.n_heads
        heads = x.view(x.size(0), x.size(1), self.n_heads, head_dim)
        heads = self.preattn_ln(heads)

        query = self.Q.weight.view(1, 1, self.n_heads, head_dim)
        scores = (heads * query).sum(-1)

        padded_len = scores.size(1)
        smoothed_rows = []
        for i in range(scores.size(0)):
            seq_len = lengths[i]
            # Smooth only the valid part of head 0, then pad back to full length.
            valid = scores[i, :seq_len, 0].unsqueeze(0)
            smoothed = smooth_tensor_1d(valid, 2).unsqueeze(0)
            padded = F.pad(smoothed, (0, padded_len - seq_len), "constant")
            smoothed_rows.append(padded.squeeze(0))

        logits = torch.cat(smoothed_rows, dim=0).unsqueeze(-1)
        logits = logits.masked_fill(~np_mask.unsqueeze(-1), float("-inf"))

        attns = F.softmax(logits, dim=1)
        pooled = (heads * attns.unsqueeze(-1)).sum(1)
        pooled = pooled.view(pooled.size(0), -1)
        return pooled, attns.squeeze(2)

class BaseModel(pl.LightningModule):
    """Classifier on pre-computed embeddings.

    Pipeline: LayerNorm -> Linear(embed_dim, 256) -> attention pooling ->
    Linear(256, 11). Training combines a focal loss on the 11 logits with two
    attention terms computed by ``attn_reg_loss``.
    """

    def __init__(self, embed_dim) -> None:
        super().__init__()

        self.initial_ln = nn.LayerNorm(embed_dim)
        self.lin = nn.Linear(embed_dim, 256)
        self.attn_head = AttentionHead(256, 1)
        self.clf_head = nn.Linear(256, 11)
        self.kld = nn.KLDivLoss(reduction="batchmean")
        self.lr = 1e-3

    def _encode(self, embedding, lens, non_mask):
        """Shared forward pass returning ``(logits, pooled, attention)``."""
        projected = self.lin(self.initial_ln(embedding))
        pooled, attns = self.attn_head(projected, non_mask, lens)
        return self.clf_head(pooled), pooled, attns

    def forward(self, embedding, lens, non_mask):
        """Return ``(logits, attention)``."""
        logits, _, attns = self._encode(embedding, lens, non_mask)
        return logits, attns

    def predict(self, embedding, lens, non_mask):
        """Return ``(logits, pooled representation, attention)``."""
        return self._encode(embedding, lens, non_mask)

    def attn_reg_loss(self, y_true, y_attn, y_tags, lengths, n):
        """Compute the attention regulariser and the supervised attention loss.

        :param y_true: targets; only its batch size is used
        :param y_attn: attention weights, indexed as [sample, position]
        :param y_tags: per-position integer tags; values 1..9 select a category
        :param lengths: per-sample sequence lengths
        :param n: unused
        :return: ``(reg_loss, kld_loss, kld_count)``: the batch-averaged
            Fourier prior loss, the weighted KL term divided by the number of
            (sample, category) pairs it covers, and that number
        """
        batch_size = y_true.size(0)

        reg_loss = 0
        for i in range(batch_size):
            # Replicate-pad the valid attention by 8 positions on each side.
            padded_attn = F.pad(
                y_attn[i, :lengths[i]].unsqueeze(0).unsqueeze(0),
                (8, 8), "replicate").squeeze(1)
            reg_loss += fourier_att_prior_loss_dct(padded_attn, lengths[i]//6, 0.2, 3)
        reg_loss = reg_loss / batch_size

        kld_loss = 0
        kld_count = 0
        for i in range(batch_size):
            if y_tags[i].sum() > 0:
                seq_len = lengths[i]
                for j in range(9):
                    tag = j + 1
                    if tag in y_tags[i]:
                        # Target: uniform distribution over positions with this tag.
                        pos_tar = (y_tags[i] == tag).float()
                        target = pos_tar[:seq_len].unsqueeze(0) / pos_tar[:seq_len].sum().unsqueeze(0)
                        log_attn = torch.log(y_attn[i, :seq_len].unsqueeze(0))
                        kld_count += 1
                        kld_loss += pos_weights_annot[j] * self.kld(log_attn, target)
        # 1e-5 keeps the division defined when no sample carries a tag.
        return reg_loss, kld_loss / torch.tensor(kld_count + 1e-5), kld_count

    def configure_optimizers(self):
        """AdamW plus ReduceLROnPlateau driven by the logged ``bce_loss``."""
        param_groups = [
            {"params": [param for _, param in self.named_parameters()]}
        ]
        optimizer = torch.optim.AdamW(param_groups, lr=self.lr)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=0.1, patience=1, min_lr=1e-5)
        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler,
            "monitor": "bce_loss"
        }

    def _compute_losses(self, batch):
        """Return ``(loss, bce_loss, seq_loss, reg_loss, seq_count)`` for a batch."""
        x, l, n, y, y_tags, _ = batch
        y_pred, y_attns = self.forward(x, l, n)
        reg_loss, seq_loss, seq_count = self.attn_reg_loss(y, y_attns, y_tags, l, n)
        bce_loss = focal_loss(y_pred, y)
        loss = bce_loss + SUP_LOSS_MULT * seq_loss + REG_LOSS_MULT * reg_loss
        return loss, bce_loss, seq_loss, reg_loss, seq_count

    def training_step(self, batch, batch_idx):
        loss, _, _, _, _ = self._compute_losses(batch)
        self.log('train_loss_batch', loss, on_epoch=True)
        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        loss, bce_loss, seq_loss, reg_loss, seq_count = self._compute_losses(batch)
        self.log('val_loss_batch', loss, on_epoch=True)
        self.log('bce_loss', bce_loss, on_epoch=True)
        return {'loss': loss,
                'seq_loss': seq_loss,
                'reg_loss': reg_loss,
                'bce_loss': bce_loss,
                'seq_count': seq_count}
    

class ProtT5Frozen(BaseModel):
    """BaseModel configured for 1024-dimensional input embeddings."""

    def __init__(self):
        super().__init__(embed_dim=1024)

class ESM1bFrozen(BaseModel):
    """BaseModel configured for 1280-dimensional input embeddings."""

    def __init__(self):
        super().__init__(embed_dim=1280)
        

# Positive-class weights for the 9 signal-type outputs.
pos_weights_annot = torch.tensor([0.23, 0.92, 0.98, 2.63, 5.64, 1.60, 2.37, 1.87, 2.03])
class SignalTypeMLP(pl.LightningModule):
    """Two-layer MLP (267 -> 32 -> 9) trained with weighted BCE on logits."""

    def __init__(self):
        super().__init__()
        self.ln1 = nn.Linear(267, 32)
        self.ln2 = nn.Linear(32, 9)
        self.lr = 1e-3

    def forward(self, x):
        """Return the 9 logits for input features ``x``."""
        hidden = torch.tanh(self.ln1(x))
        return self.ln2(hidden)

    def configure_optimizers(self):
        """AdamW over all parameters at ``self.lr``."""
        param_groups = [
            {"params": [param for _, param in self.named_parameters()], 'lr': self.lr},
        ]
        return torch.optim.AdamW(param_groups, lr=self.lr)

    def _weighted_bce(self, batch):
        """Mean BCE-with-logits of a ``(x, y)`` batch, weighted by ``pos_weights_annot``."""
        x, y = batch
        y_pred = self.forward(x)
        return F.binary_cross_entropy_with_logits(
            y_pred, y, pos_weight=pos_weights_annot.to(y_pred.device))

    def training_step(self, batch, batch_idx):
        loss = self._weighted_bce(batch)
        self.log('train_loss_batch', loss, on_epoch=True)
        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        loss = self._weighted_bce(batch)
        self.log('val_loss', loss, on_epoch=True)
        return {'loss': loss}
  

In [None]:
from src.model import *
from src.data import DataloaderHandler
import pickle
from transformers import T5EncoderModel, T5Tokenizer, logging
import os
class ModelAttributes:
    """Configuration bundle for one model variant.

    Constructing an instance also creates ``save_path``, its ``signaltype``
    sub-directory and ``outputs_save_path`` if they are missing.
    """

    def __init__(self,
                 model_type: str,
                 class_type: "type[pl.LightningModule]",
                 alphabet,
                 embedding_file: str,
                 save_path: str,
                 outputs_save_path: str,
                 clip_len: int,
                 embed_len: int) -> None:
        """
        :param model_type: identifier of the model variant
        :param class_type: model class (not an instance), stored as given
        :param alphabet: tokenizer/alphabet object, stored as given
        :param embedding_file: path of the embeddings file
        :param save_path: directory for model checkpoints
        :param outputs_save_path: directory for generated outputs
        :param clip_len: stored as ``clip_len``
        :param embed_len: stored as ``embed_len``
        """
        self.model_type = model_type
        self.class_type = class_type
        self.alphabet = alphabet
        self.embedding_file = embedding_file
        self.save_path = save_path
        self.ss_save_path = os.path.join(self.save_path, "signaltype")
        self.outputs_save_path = outputs_save_path

        # exist_ok=True avoids the race between an exists() check and
        # makedirs() when several processes start at the same time.
        for directory in (self.save_path, self.ss_save_path, self.outputs_save_path):
            os.makedirs(directory, exist_ok=True)

        self.clip_len = clip_len
        self.embed_len = embed_len
        

def get_train_model_attributes(model_type):
    """Build the ``ModelAttributes`` used for training the given model type.

    :param model_type: FAST or ACCURATE
    :raises Exception: for any other value
    """
    if model_type == FAST:
        # The alphabet is unpickled from a file shipped with the models.
        with open("models/ESM1b_alphabet.pkl", "rb") as f:
            esm_alphabet = pickle.load(f)
        return ModelAttributes(
            model_type,
            ESM1bFrozen,
            esm_alphabet,
            EMBEDDINGS[FAST]["embeds"],
            "models/models_esm1b",
            "outputs/esm1b/",
            1022,
            1280,
        )

    if model_type == ACCURATE:
        t5_tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50", do_lower_case=False)
        return ModelAttributes(
            model_type,
            ProtT5Frozen,
            t5_tokenizer,
            EMBEDDINGS[ACCURATE]["embeds"],
            "models/models_prott5",
            "outputs/prott5/",
            4000,
            1024,
        )

    raise Exception("wrong model type provided expected Fast,Accurate got", model_type)
    


In [None]:
from sklearn.model_selection import ShuffleSplit
import numpy as np
import pandas as pd
import argparse
from src.constants import * 

from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.progress import ProgressBar
import pytorch_lightning as pl
from src.model import *
from src.data import *
from src.utils import *
from src.eval_utils import *


def train_model(model_attrs: ModelAttributes, datahandler:DataloaderHandler, outer_i: int):
    """Train one ``SignalTypeMLP`` for fold ``outer_i``.

    The checkpoint with the best ``val_loss`` is written to
    ``model_attrs.ss_save_path`` under the fold index (weights only). Training
    stops early after 5 validation epochs without improvement and is capped
    at 500 epochs.
    """
    train_dataloader, val_dataloader = datahandler.get_ss_train_val_dataloader(
        model_attrs.outputs_save_path, outer_i)

    callbacks = [
        ModelCheckpoint(
            monitor='val_loss',
            dirpath=model_attrs.ss_save_path,
            filename=f"{outer_i}",
            save_top_k=1,
            every_n_epochs=1,
            save_last=False,
            save_weights_only=True,
        ),
        EarlyStopping(monitor='val_loss', patience=5, mode='min'),
    ]

    trainer = pl.Trainer(
        max_epochs=500,
        default_root_dir=model_attrs.ss_save_path + f"/{outer_i}",
        check_val_every_n_epoch=1,
        callbacks=callbacks,
    )

    clf = SignalTypeMLP()
    print(f"Training clf {outer_i}")
    trainer.fit(clf, train_dataloader, val_dataloader)
        
if __name__ == "__main__":
    # Script entry point: train the 5 signal-type models that are still
    # missing, then generate their outputs and report cross-validation metrics.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m", "--model",
        default="Fast",
        choices=['Accurate', 'Fast'],
        type=str,
        help="Model to use.",
    )
    args = parser.parse_args()

    model_attrs = get_train_model_attributes(model_type=args.model)
    datahandler = DataloaderHandler(
        clip_len=model_attrs.clip_len,
        alphabet=model_attrs.alphabet,
        embedding_file=model_attrs.embedding_file,
        embed_len=model_attrs.embed_len,
    )

    print("Training sorting signal type prediction models")
    for i in range(5):
        print(f"Training model {i+1} / 5")
        # Folds that already have a checkpoint are not retrained.
        if os.path.exists(os.path.join(model_attrs.save_path, f"signaltype/{i}.ckpt")):
            continue
        train_model(model_attrs, datahandler, i)

    print("Finished training sorting signal type prediction models")

    print("Using trained models to generate outputs of signal prediction")
    generate_ss_outputs(model_attrs=model_attrs, datahandler=datahandler)
    print("Generated outputs!")

    print("Computing sorting signal type prediction performance on swissprot CV dataset")
    calculate_ss_metrics(model_attrs=model_attrs, datahandler=datahandler)


    