<a href="https://colab.research.google.com/github/AtomZa/BadApple-EdgeDetection/blob/main/TopModel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:

# Install the third-party packages this notebook imports. Nothing is pinned:
# jax-unirep is built from the GitHub repository, the rest come from PyPI.
!pip install git+https://github.com/ElArkk/jax-unirep.git --upgrade
!pip install optuna
!pip install biopython
!pip install seaborn
     

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/ElArkk/jax-unirep.git
  Cloning https://github.com/ElArkk/jax-unirep.git to /tmp/pip-req-build-d0bdmt4n
  Running command git clone --filter=blob:none --quiet https://github.com/ElArkk/jax-unirep.git /tmp/pip-req-build-d0bdmt4n
  Resolved https://github.com/ElArkk/jax-unirep.git to commit 7763bf69cc7864f8cf466151e452c52f3adc6476
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
# feather-format supplies the `feather` module used by save_reps / read_reps.
!pip install feather-format

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:

from jax_unirep import get_reps, fit
from jax_unirep.utils import load_params
from Bio import SeqIO
import pandas as pd
import glob
import os
import numpy as np

from sklearn.decomposition import PCA, TruncatedSVD
from sklearn.manifold import TSNE
from sklearn.model_selection import train_test_split, KFold
from sklearn.linear_model import RidgeCV, LinearRegression, HuberRegressor
from sklearn.pipeline import Pipeline
from sklearn.neighbors import KNeighborsRegressor

import warnings
warnings.filterwarnings('ignore') 

from sklearn.preprocessing import normalize, StandardScaler
import matplotlib.pyplot as plt
import seaborn as sns
import feather



In [None]:
# read FASTA file:
# input: file name
# output: names and sequences in the file as an array of dim-2 arrays [name, sequence].
def read_fasta(name):
    """Read every record of ``<gdrive_path><name>.fasta.txt``.

    Relies on the notebook-level ``gdrive_path`` variable being defined
    before the function is called.

    Args:
        name: File stem; ``gdrive_path`` is prepended and ``.fasta.txt``
            appended to build the file path.

    Returns:
        A list of ``[record_id, sequence]`` pairs, one per FASTA record. The
        sequence is converted to ``str`` and stripped of surrounding
        whitespace.
    """
    # SeqIO.parse is consumed inside the ``with`` block so the file handle is
    # closed deterministically instead of being left open.
    with open(gdrive_path + name + '.fasta.txt') as handle:
        return [
            [record.id, str(record.seq).strip()]
            for record in SeqIO.parse(handle, 'fasta')
        ]

# read labeled sequence text files:
# input: file name stem (reads <name>_seqs.txt and <name>_fitness.txt)
# output: sequences and fitness values as an array of dim-2 arrays [sequence, fitness].
def read_labeled_data(name, data_dir="/content/"):
    """Load paired sequence / fitness text files.

    Args:
        name: File stem. ``<name>_seqs.txt`` is read as strings and
            ``<name>_fitness.txt`` as floats, both with ``np.loadtxt``.
        data_dir: Directory containing both files. Defaults to ``/content/``,
            the location that was previously hard-coded.

    Returns:
        A list of ``[sequence, fitness]`` pairs in file order, with the
        sequence stripped of surrounding whitespace.

    Raises:
        ValueError: If the two files do not contain the same number of rows.
    """
    # ndmin=1 keeps a one-line file as a length-1 array; without it loadtxt
    # returns a 0-d array that cannot be iterated.
    seqs = np.loadtxt(os.path.join(data_dir, name + '_seqs.txt'), dtype='str', ndmin=1)
    fitnesses = np.loadtxt(os.path.join(data_dir, name + '_fitness.txt'), ndmin=1)

    # zip() would silently drop the unmatched tail and misalign the labels.
    if len(seqs) != len(fitnesses):
        raise ValueError(
            f"{name}: {len(seqs)} sequences but {len(fitnesses)} fitness values"
        )

    return [[str(seq).strip(), fitness] for seq, fitness in zip(seqs, fitnesses)]

# save represented dataframe of features as feather
def save_reps(df, path):
  """Write ``df`` to ``path + '.feather'`` and print a confirmation line."""
  target = path + '.feather'
  feather.write_dataframe(df, target)
  print(target, 'saved!')


# read represented dataframe of features as feather
def read_reps(path):
  """Return the dataframe stored at ``path + '.feather'``."""
  source = path + '.feather'
  return feather.read_dataframe(source)


# Integer code for every supported residue symbol, plus the 'start' and 'stop'
# markers. The ambiguous symbols X, Z, B and J all share code 23, so the table
# cannot be inverted uniquely for them.
aa_to_int = {
  'M':1,
  'R':2,
  'H':3,
  'K':4,
  'D':5,
  'E':6,
  'S':7,
  'T':8,
  'N':9,
  'Q':10,
  'C':11,
  'U':12,
  'G':13,
  'P':14,
  'A':15,
  'V':16,
  'I':17,
  'F':18,
  'Y':19,
  'W':20,
  'L':21,
  'O':22, # Pyrrolysine
  'X':23, # Unknown
  'Z':23, # Glutamic acid or glutamine
  'B':23, # Asparagine or aspartic acid
  'J':23, # Leucine or isoleucine
  'start':24,
  'stop':25,
}


def get_int_to_aa():
  """Return the inverse of ``aa_to_int`` (integer code -> symbol).

  When several symbols share one code, the symbol that appears last in
  ``aa_to_int`` is the one kept.
  """
  inverse = {}
  for symbol, code in aa_to_int.items():
    inverse[code] = symbol
  return inverse


def _one_hot(x, k, dtype=np.float32):
  """Compare ``x[:, None]`` against ``np.arange(k)`` and return the result.

  The result is the boolean comparison array itself; the ``dtype`` argument
  is accepted but not applied.
  """
  classes = np.arange(k)
  hits = x[:, None] == classes
  return np.array(hits)


def aa_seq_to_int(s):
  """Translate a string of amino acids into its list of integer codes.

  Raises:
    ValueError: If ``s`` contains a symbol that is not a key of
      ``aa_to_int``; the message lists the offending symbols.
  """
  # Reject anything outside the supported alphabet before translating.
  unsupported = set(s).difference(set(aa_to_int.keys()))
  if unsupported:
    raise ValueError(
      "Unsupported character(s) in sequence found:"
      f" {unsupported}"
    )

  return list(map(aa_to_int.__getitem__, s))


def aa_seq_to_onehot(seq):
  """Return a flattened 0/1 encoding of ``seq`` with 21 columns per residue.

  Each integer code from ``aa_seq_to_int`` is compared against
  ``np.arange(21)`` (columns 0..20), so a code outside that range yields an
  all-zero row. The per-residue rows are flattened into one 1-D array.
  """
  codes = np.array(aa_seq_to_int(seq))
  matches = np.equal(codes[:, None], np.arange(21))
  return 1 * matches.flatten()
  

def multi_onehot(seqs):
  """Stack the ``aa_seq_to_onehot`` encoding of every sequence into one array.

  ``seqs`` must provide ``.tolist()``. ``np.stack`` requires every encoding
  to have the same shape.
  """
  encoded = []
  for seq in seqs.tolist():
    encoded.append(aa_seq_to_onehot(seq))
  return np.stack(encoded)


def distance_matrix(N):
  """Return the N x N matrix whose (i, j) entry is ``1 - |i - j| / N``."""
  positions = np.arange(N)
  # Broadcasting a column against a row gives every |i - j| at once.
  gaps = np.abs(positions[:, None] - positions[None, :])
  return 1 - (gaps / N)


def confusion_matrix_loss(Y_test,Y_preds_test):
  """Fraction of ordered pairs whose ranking differs between truth and prediction.

  For every ordered pair (i, j) the indicator ``Y[i] > Y[j]`` is computed for
  the true values and for the predictions; the loss is the share of the
  N * N pairs on which the two indicators disagree.

  Args:
    Y_test: True values (list, array or pandas Series). Elements are taken
      by position, so a Series with a non-default index is handled.
    Y_preds_test: Predicted values, aligned by position with ``Y_test``.
      Entries beyond ``len(Y_test)`` are ignored.

  Returns:
    The disagreement fraction in [0, 1], or NaN for empty input.

  Raises:
    IndexError: If fewer predictions than true values are supplied.
  """
  # Converting to plain arrays makes every lookup positional; indexing a
  # pandas Series with ``[i]`` is a label lookup and breaks after a shuffle
  # or split.
  truth = np.asarray(Y_test).ravel()
  preds = np.asarray(Y_preds_test).ravel()

  N = truth.size
  if preds.size < N:
    raise IndexError(
      f"expected at least {N} predictions, got {preds.size}"
    )
  preds = preds[:N]

  if N == 0:
    # No pairs to compare: 0 / 0, reported as NaN without a runtime warning.
    return np.float64('nan')

  # Pairwise "i ranks above j" indicators, built by broadcasting.
  Y_rank_matrix = truth[:, None] > truth[None, :]
  Y_preds_rank_matrix = preds[:, None] > preds[None, :]

  confusion_matrix = Y_preds_rank_matrix != Y_rank_matrix
  loss = np.sum(confusion_matrix)/confusion_matrix.size

  return loss

# Load in sequences, get representations then save them


In [None]:
# load labeled training data: one row per sequence, built from the
# [sequence, fitness] pairs returned by read_labeled_data('TEM1')
seqs_df = pd.DataFrame(read_labeled_data('TEM1'), columns = ['sequence', 'fitness'])

## UniRep representations

In [None]:
# define key params
# Folder that holds custom weight directories. It must be set to a path string
# before any weight-folder name is added to PARAMS (None + str raises TypeError).
DIR_PATH = None

# One entry per representation to compute: None -> 'unirep' (get_reps default
# weights), 'one_hot', or the name of a weight folder under DIR_PATH.
PARAMS = [None]

# get representations of data for each params:
N_seqs = len(seqs_df)
print("N_seqs:", N_seqs)

# vary batches based on memory available (i.e. if you have less memory run more batches)
# this google colab can handle around 1000 seqs per batch for sure
N_BATCHES = 1

BATCH_LEN = int(np.ceil(N_seqs/N_BATCHES))

for param in PARAMS:
  # append path to param unless unirep (no param)
  if param == 'one_hot':
    print('getting reps for one hot')
    onehot = multi_onehot(seqs_df.sequence)
    feat_cols = [ 'feat' + str(j) for j in range(1, onehot.shape[1] + 1) ]
    this_df = pd.DataFrame(onehot, columns=feat_cols)
    this_df.insert(0, "sequence", seqs_df.sequence)
    this_df.insert(1, "fitness", seqs_df.fitness)

    save_reps(this_df, 'one_hot')

    continue

  elif param is None:
    name = 'unirep'

  else:
    name = param
    # NOTE(review): load_params appears to return a list of per-layer
    # parameters, with the mLSTM weights that get_reps expects at index 1.
    # Confirm against the jax-unirep documentation.
    param = load_params(DIR_PATH + param)[1]

  print('getting reps for', name)

  # get 1st sequence (also fixes the number of feature columns)
  reps, _, _ = get_reps(seqs_df.sequence[0], params=param)
  feat_cols = [ 'feat' + str(j) for j in range(1, reps.shape[1] + 1) ]
  this_df = pd.DataFrame(reps, columns=feat_cols)
  this_df.insert(0, "sequence", seqs_df.sequence[0])
  this_df.insert(1, "fitness", seqs_df.fitness[0])

  # get the rest in batches
  for i in range(N_BATCHES):
    start = 1 + i*BATCH_LEN
    stop = min( 1 + (i+1)*BATCH_LEN, N_seqs )
    if start >= stop:
      # Rounding BATCH_LEN up can leave the last batch(es) empty.
      continue

    # Use the same weights as for the first sequence. The batch call used to
    # pass params=None, which silently fell back to the default weights.
    this_unirep, _, _ = get_reps(seqs_df.sequence[start:stop], params=param)
    this_unirep_df = pd.DataFrame(this_unirep, columns=feat_cols)
    this_unirep_df.insert(0, "sequence", seqs_df.sequence[start:stop].reset_index(drop=True))
    this_unirep_df.insert(1, "fitness", seqs_df.fitness[start:stop].reset_index(drop=True))
    this_df = pd.concat([this_df.reset_index(drop=True), this_unirep_df.reset_index(drop=True)]).reset_index(drop=True)

  save_reps(this_df, '/content/sample_data')



N_seqs: model_weights.pkl
getting reps for unirep


KeyboardInterrupt: ignored

In [None]:
this_df

Unnamed: 0,sequence,fitness,feat1,feat2,feat3,feat4,feat5,feat6,feat7,feat8,...,feat1891,feat1892,feat1893,feat1894,feat1895,feat1896,feat1897,feat1898,feat1899,feat1900
0,ASIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,0.002018,0.006393,0.061515,0.082276,-0.069812,-0.008327,0.043463,-0.029137,-0.027969,...,0.213317,0.026809,-0.115965,0.0462,-0.109487,0.04542,0.145132,-0.085517,-0.133907,-0.024836
1,CSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,0.003023,0.006385,0.064631,0.074807,-0.067615,-0.028016,0.043017,-0.040086,-0.024445,...,0.21173,0.027225,-0.114463,0.046574,-0.107344,0.052469,0.135533,-0.068176,-0.126432,-0.0233
2,DSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,0.00222,0.006561,0.071035,0.076654,-0.064301,-0.06553,0.05606,-0.023186,-0.024507,...,0.24009,0.015659,-0.132623,0.047833,-0.103569,0.038674,0.132507,-0.072743,-0.127861,-0.022262
3,ESIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,0.002378,0.006579,0.069527,0.076349,-0.064894,-0.052339,0.051859,-0.025333,-0.027768,...,0.237357,0.019714,-0.129176,0.045762,-0.102923,0.038223,0.136312,-0.082839,-0.131641,-0.021245
4,FSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,0.005702,0.00656,0.068791,0.074779,-0.064711,-0.064218,0.049007,-0.03767,-0.028824,...,0.233461,0.01535,-0.121269,0.043791,-0.102301,0.04061,0.130741,-0.090163,-0.123363,-0.022385
5,GSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,0.001208,0.006485,0.066348,0.074544,-0.066141,-0.041925,0.051375,-0.03545,-0.028355,...,0.231518,0.018476,-0.130188,0.045746,-0.10422,0.040065,0.136845,-0.0908,-0.131665,-0.021119
6,HSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,0.001663,0.006666,0.069321,0.067952,-0.063681,-0.071188,0.052399,-0.041238,-0.028201,...,0.249168,0.012875,-0.128579,0.04505,-0.099268,0.034813,0.127197,-0.083517,-0.127203,-0.022171
7,ISIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,0.378032,0.006898,0.068589,0.071081,-0.062204,-0.099893,0.054344,-0.048769,-0.033369,...,0.261778,0.002846,-0.118769,0.042248,-0.098126,0.028515,0.119106,-0.090477,-0.120277,-0.021998
8,KSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,0.00264,0.007348,0.062467,0.071654,-0.060362,-0.104889,0.060924,-0.040631,-0.028453,...,0.2761,-0.004392,-0.128623,0.041715,-0.100384,0.026015,0.122306,-0.086363,-0.122576,-0.019852


## All representations: UniRep, one-hot and evotuned TEM1 weights

In [None]:

# define key params
# Folder that holds the weight directories and receives the .feather outputs.
gdrive_path = "/content/ensemble/"
# One entry per representation to compute: None -> 'unirep' (get_reps default
# weights), 'one_hot', or the name of a weight folder under gdrive_path.
PARAMS = [None, 'one_hot', 'TEM1_epoch1_weights']

# get representations of data for each params:
N_seqs = len(seqs_df)
print("N_seqs:", N_seqs)

# vary batches based on memory available (i.e. if you have less memory run more batches)
# this google colab can handle around 1000 seqs per batch for sure
N_BATCHES = 1

BATCH_LEN = int(np.ceil(N_seqs/N_BATCHES))

for param in PARAMS:
  # append path to param unless unirep (no param)
  if param == 'one_hot':
    print('getting reps for one hot')
    onehot = multi_onehot(seqs_df.sequence)
    feat_cols = [ 'feat' + str(j) for j in range(1, onehot.shape[1] + 1) ]
    this_df = pd.DataFrame(onehot, columns=feat_cols)
    this_df.insert(0, "sequence", seqs_df.sequence)
    this_df.insert(1, "fitness", seqs_df.fitness)

    save_reps(this_df, gdrive_path + 'one_hot')

    continue

  elif param is None:
    name = 'unirep'

  else:
    name = param
    # Build the folder from the loop variable so every name in PARAMS loads
    # its own weights. load_params returns a list of per-layer parameters
    # (this notebook's `param` output shows a list); get_reps needs the mLSTM
    # entry at index 1. Passing the whole list raised the AttributeError.
    param = load_params(gdrive_path + param, 1900)[1]

  print('getting reps for', name)

  # get 1st sequence (also fixes the number of feature columns)
  reps, _, _ = get_reps(seqs_df.sequence[0], params=param)
  feat_cols = [ 'feat' + str(j) for j in range(1, reps.shape[1] + 1) ]
  this_df = pd.DataFrame(reps, columns=feat_cols)
  this_df.insert(0, "sequence", seqs_df.sequence[0])
  this_df.insert(1, "fitness", seqs_df.fitness[0])

  # get the rest in batches
  for i in range(N_BATCHES):
    start = 1 + i*BATCH_LEN
    stop = min( 1 + (i+1)*BATCH_LEN, N_seqs )
    if start >= stop:
      # Rounding BATCH_LEN up can leave the last batch(es) empty.
      continue

    this_unirep, _, _ = get_reps(seqs_df.sequence[start:stop], params=param)
    this_unirep_df = pd.DataFrame(this_unirep, columns=feat_cols)
    this_unirep_df.insert(0, "sequence", seqs_df.sequence[start:stop].reset_index(drop=True))
    this_unirep_df.insert(1, "fitness", seqs_df.fitness[start:stop].reset_index(drop=True))
    this_df = pd.concat([this_df.reset_index(drop=True), this_unirep_df.reset_index(drop=True)]).reset_index(drop=True)

  save_reps(this_df, gdrive_path + name)



N_seqs: 9
getting reps for unirep
/content/ensemble/unirep.feather saved!
getting reps for one hot
/content/ensemble/one_hot.feather saved!
getting reps for TEM1_epoch1_weights


AttributeError: ignored

In [None]:
param

[Array([[-4.32526559e-01,  7.53880665e-02, -2.11843640e-01,
          6.06281981e-02,  2.77478129e-01, -8.81439075e-02,
          1.04718730e-01,  1.55644700e-01,  5.23038745e-01,
         -3.18140000e-01],
        [ 4.48811613e-02, -3.00060481e-01, -1.45759070e-02,
          1.20491721e-01, -3.09778657e-02,  5.83414081e-03,
         -1.11657895e-01, -5.27199030e-01, -4.85000171e-04,
         -1.91193491e-01],
        [-3.17000806e-01,  1.15881167e-01, -8.25776905e-03,
          5.00198007e-01, -5.91420457e-02,  1.04788348e-01,
         -2.08030924e-01, -2.43377000e-01,  1.65767953e-01,
         -2.02437952e-01],
        [ 5.75120747e-02, -2.68256575e-01,  1.71969414e-01,
          2.34683350e-01,  2.23891228e-01,  8.98892432e-02,
          2.20527779e-02, -3.40744466e-01, -1.12043381e-01,
          1.92534237e-03],
        [ 2.48875588e-01,  4.87032175e-01,  3.21024805e-01,
          9.91747975e-02,  3.95841241e-01,  2.97862347e-02,
          8.90410990e-02, -5.97915724e-02,  3.610234