# Cas-Detector

## 1. Pre-processing

In [None]:
import random
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
import torch.optim as optim
import os
import re
import math
import random
import time
from datetime import datetime
import copy
import glob
import gc
import pickle
import h5py

In [None]:
# Predefined variables
project_path = '/home/fengchr/staff/Cas-detector'

# Seed the torch random number generators.
# NOTE(review): Python's `random` and NumPy are not seeded anywhere in the
# visible code, so `random.shuffle` / `random.randint` results (and therefore
# the fold contents) are presumably not reproducible between runs -- confirm.
torch.manual_seed(201)
torch.cuda.manual_seed(88)

# Model hyper-parameters
noncas_weight = 0.25 # loss weight of the nonCas class relative to the Cas classes
window_size = 64  # length of each sequence window
window_step = 25  # default stride between consecutive windows
lr = 0.015
momentum = 0.8
batch_size = 256
lr_fixed_decay = 0.9  # lr is multiplied by this factor every lr_fixed_decay_every iterations
lr_fixed_decay_every = 2500
inspect_loss_decay = 0.75  # lr is adapted each time the loss drops to inspect_loss_decay x the reference loss
lr_adapt_decay = 0.75  # factor applied to lr on each adaptive update
inspect_num = 5  # number of most recent batches averaged when judging convergence


In [None]:
#tools
def read_fasta(input):
    """Parse a FASTA file into a ``{header: sequence}`` dict.

    Header lines start with ``>``; the text after ``>`` (stripped) is the key.
    All following sequence lines are concatenated until the next header.
    Blank lines are ignored. A header that is never followed by sequence
    data does not appear in the result. Records sharing a header are
    concatenated.

    Args:
        input: Path of the FASTA file to read.

    Returns:
        dict mapping header text to the full sequence string, in the order
        in which each header first received sequence data.

    Raises:
        ValueError: If sequence data appears before the first header line.
    """
    chunks = {}
    header = None
    with open(input, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                # Blank/whitespace-only lines are common at the end of FASTA
                # files; indexing line[0] on them would raise IndexError.
                continue
            if line.startswith('>'):
                header = line[1:]
            elif header is None:
                raise ValueError(
                    "FASTA sequence data found before any '>' header line"
                )
            else:
                # Collect pieces and join once to avoid quadratic string growth.
                chunks.setdefault(header, []).append(line)
    return {name: ''.join(parts) for name, parts in chunks.items()}
def split_seq(seq,window_size,step=window_step):
    """Cut ``seq`` into windows of ``window_size`` with a random phase.

    The first window starts at a random offset and subsequent windows follow
    every ``step`` positions. The appropriate ``step`` has to be chosen by the
    caller.

    Args:
        seq: Sequence (string) to split.
        window_size: Length of every returned window.
        step: Stride between window starts (defaults to ``window_step``).

    Returns:
        List of windows; empty if ``seq`` is shorter than ``window_size``.
    """
    last_start = len(seq) - window_size
    if last_start < 0:
        return []
    # Clamp the random offset to the last valid start position. Drawing from
    # the full [0, step-1] range could overshoot it, so a sequence that is
    # long enough for one window would still yield no windows at all.
    first_start = random.randint(0, min(step - 1, last_start))
    return [seq[i:i + window_size] for i in range(first_start, last_start + 1, step)]
def split_seq_rand(seq, window_size, num):
    """Draw ``num`` windows of ``window_size`` at uniformly random positions.

    The appropriate ``num`` has to be chosen by the caller. Windows may
    overlap or repeat. Returns an empty list when ``seq`` is shorter than
    ``window_size``.
    """
    last_start = len(seq) - window_size
    if last_start < 0:
        return []
    # One randint draw per window, consumed lazily in order.
    starts = (random.randint(0, last_start) for _ in range(num))
    return [seq[begin:begin + window_size] for begin in starts]
def one_hot_encode(lst):
    """Build a ``{symbol: one-hot vector}`` table for the symbols in ``lst``.

    Each vector is a fresh NumPy array of length ``len(lst)`` with a single 1
    at the symbol's position in ``lst``. If a symbol occurs more than once,
    its last position wins.
    """
    size = len(lst)

    def unit_vector(position):
        vec = np.zeros(size)
        vec[position] = 1
        return vec

    return {symbol: unit_vector(position) for position, symbol in enumerate(lst)}
# One-hot table over 20 amino-acid letters plus "-"; encode_aa_lst falls back
# to the "-" vector for any character that is not a key of this table.
encode_dict=one_hot_encode(["A","R","N","D","C","Q","E","G","H","I","L","K","M","F","P","S","T","W","Y","V","-"])
print("encode_dict:")
print(encode_dict)
def encode_aa_lst(aa_lst):
    """One-hot encode every sequence in ``aa_lst``.

    Each character is looked up in ``encode_dict``; characters without an
    entry use the ``"-"`` vector. Returns one float32 tensor per sequence.
    """
    encoded = []
    for aa in aa_lst:
        rows = [encode_dict.get(residue, encode_dict["-"]) for residue in aa]
        encoded.append(torch.tensor(np.array(rows), dtype=torch.float32))
    return encoded

def save_model(run_header, epoch, model):
    """Write ``model``'s state dict to a per-epoch checkpoint file.

    The file is ``<project_path>/<run_header>.epoch_<epoch>.model.pkl`` and
    holds a dict with the single key ``'state_dict'``.
    """
    checkpoint = {'state_dict': model.state_dict()}
    target = "".join([project_path, "/", run_header, ".epoch_", str(epoch), ".model.pkl"])
    torch.save(checkpoint, target)
def load_model(run_header, epoch):
    """Load a checkpoint written by ``save_model`` into a fresh ``MyModel``.

    Args:
        run_header: Run identifier used in the checkpoint file name.
        epoch: Epoch number used in the checkpoint file name.

    Returns:
        A ``MyModel`` on ``device`` with the stored weights, in eval mode.
    """
    checkpoint_path = project_path + "/" + run_header + ".epoch_" + str(epoch) + ".model.pkl"
    # map_location remaps tensors saved on a GPU onto the current `device`;
    # without it, loading a CUDA checkpoint on a CPU-only machine fails.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    state = torch.load(checkpoint_path, map_location=device)
    tmp = MyModel()
    tmp.to(device)
    tmp.load_state_dict(state['state_dict'])
    tmp.eval()
    return tmp

In [None]:
from transformers import T5Tokenizer, T5EncoderModel

# Embedding model setup: tokenizer and encoder are both loaded from the local
# directory <project_path>/embedding/prot_t5_xl_half_uniref50-enc.
# NOTE: `device` is hard-coded to the first GPU here and is re-assigned (with
# a CPU fallback) in the training-setup cell further down.
device = torch.device('cuda:0')

# Load the tokenizer
tokenizer = T5Tokenizer.from_pretrained(project_path + '/embedding/prot_t5_xl_half_uniref50-enc', do_lower_case=False)

# Load the model
model = T5EncoderModel.from_pretrained(project_path + '/embedding/prot_t5_xl_half_uniref50-enc').to(device)

# only GPUs support half-precision currently; if you want to run on CPU use full-precision (not recommended, much slower)
model.half()
# Inference only: switch the encoder to eval mode.
model = model.eval()

In [None]:
# Load the class-name -> encoding table.
# NOTE: pickle.load executes arbitrary code from the file; only use trusted data.
with open(project_path + '/data/target_encode_dict.pkl', 'rb') as file:
    target_encode_dict = pickle.load(file)

all_type_num = len(target_encode_dict)
# NOTE(review): this takes only the LAST key, not a list of keys. The weight
# vector below treats the last class as the nonCas class, so `[:-1]` may have
# been intended here -- confirm. `cas_types` is not used in the visible code.
cas_types = list(target_encode_dict.keys())[-1]
# Per-class loss weights: 1 for every class except the last, which gets noncas_weight.
weight = torch.tensor([1]*(all_type_num - 1)+[noncas_weight])
print(target_encode_dict)
print("All types num: " + str(all_type_num))

In [None]:
# Read the Cas samples and the first nonCas shard fully into memory.
with h5py.File(project_path + '/data/data_cas_12_26.h5', 'r') as f:
    cas_emb = np.array(f['cas_emb'])
    cas_labels = np.array(f['cas_labels'])
with h5py.File(project_path + '/data/data_noncas.1.h5', 'r') as f:
    noncas_emb = np.array(f['noncas_emb'])
    noncas_labels = np.array(f['noncas_labels'])
# Pair each embedding with its label, then shuffle so that the index-based
# fold split in MyDataset operates on a random order.
cas_set1 = [(cas_emb[i],cas_labels[i]) for i in range(len(cas_labels))]
noncas_set1 = [(noncas_emb[i],noncas_labels[i]) for i in range(len(noncas_labels))]
random.shuffle(cas_set1)
random.shuffle(noncas_set1)
# Drop the names of the raw arrays; the paired lists are used from here on.
del cas_emb,cas_labels,noncas_emb,noncas_labels

# Number of Cas and nonCas samples loaded.
print(str(len(cas_set1)))
print(str(len(noncas_set1)))

In [None]:
# Dataset wrapper with index-based 10-fold splitting
class MyDataset(Dataset):
    """View of ``dataset`` restricted to one side of a 10-fold split.

    ``dataset`` is a sequence of ``(input, target)`` pairs. Fold ``k % 10``
    covers the contiguous index range
    ``[int(fold * n / 10), int((fold + 1) * n / 10))``.

    Args:
        dataset: Sequence of ``(input, target)`` pairs.
        k: Cross-validation fold index (only ``k % 10`` is used; ignored
            when ``state='all'``).
        state: ``'train'`` keeps everything outside the fold, ``'test'``
            keeps only the fold, ``'all'`` keeps every sample.

    Raises:
        ValueError: If ``state`` is not one of the three supported values
            (previously this silently produced an unusable object).
    """

    def __init__(self, dataset, k, state='train'):
        overall_size = len(dataset)
        if state == 'train':
            fold_start, fold_end = self._fold_bounds(k, overall_size)
            selected = list(range(fold_start)) + list(range(fold_end, overall_size))
            samples = [dataset[idx] for idx in selected]
        elif state == 'test':
            fold_start, fold_end = self._fold_bounds(k, overall_size)
            samples = [dataset[idx] for idx in range(fold_start, fold_end)]
        elif state == 'all':
            samples = list(dataset)
        else:
            raise ValueError(
                "state must be 'train', 'test' or 'all', got %r" % (state,)
            )
        self.inputset = [sample[0] for sample in samples]
        self.targetset = [sample[1] for sample in samples]

    @staticmethod
    def _fold_bounds(k, overall_size):
        """Return the ``(start, end)`` index range of fold ``k % 10``."""
        fold = k % 10
        return int(fold * overall_size / 10), int((fold + 1) * overall_size / 10)

    def __getitem__(self, idx):
        input = self.inputset[idx]
        label = self.targetset[idx]
        return input, label

    def __len__(self):
        return len(self.targetset)

In [None]:
# Model definition
class MyModel(nn.Module):
    """CNN classifier: four conv layers followed by three linear layers.

    The output has ``all_type_num`` logits per sample (no softmax applied).
    The shape comments in ``forward`` assume an input of shape
    (batch, 1, 64, 1024); ``fc1`` requires the flattened conv output to have
    64 * 16 * 24 features.
    """

    def __init__(self):
        super().__init__()
        self.pool = nn.MaxPool2d(2)
        self.conv1 = nn.Conv2d(1, 128, kernel_size=(11, 31), stride=(1, 3), padding=(5, 5))
        self.conv2 = nn.Conv2d(128, 128, kernel_size=(11, 31), stride=(1, 3), padding=(5, 5))
        self.conv3 = nn.Conv2d(128, 64, kernel_size=(7, 13), stride=(1, 1), padding=(3, 6))
        self.conv4 = nn.Conv2d(64, 64, kernel_size=(5, 7), stride=(1, 1), padding=(2, 3))
        self.fc1 = nn.Linear(64 * 16 * 24, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, all_type_num)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Two strided conv + ReLU + 2x2 max-pool stages.
        stage1 = self.pool(F.relu(self.conv1(x)))       # [batch_size,128,32,167]
        stage2 = self.pool(F.relu(self.conv2(stage1)))  # [batch_size,128,16,24]
        # Two size-preserving conv + ReLU stages.
        stage3 = F.relu(self.conv3(stage2))             # [batch_size,64,16,24]
        stage4 = F.relu(self.conv4(stage3))             # [batch_size,64,16,24]
        # Flatten everything except the batch dimension.
        features = torch.flatten(stage4, 1)
        hidden = F.relu(self.fc1(features))
        hidden = self.sigmoid(self.fc2(hidden))
        logits = self.fc3(hidden)
        return logits
    

In [None]:
# Training setup
# Re-assigns `device`, this time falling back to the CPU when CUDA is unavailable.
cuda_avail = torch.cuda.is_available()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(cuda_avail)

# TensorBoard writer; the run is identified by its start time (month-day-hour-minute-second).
current_time = datetime.now()
formatted_time = current_time.strftime("%m-%d-%H-%M-%S")
run_header=formatted_time
writer = SummaryWriter(project_path + "/model/" + run_header)
j = 0  # cross-validation fold index passed to MyDataset
writer_title=formatted_time+" fold="+str(j)

# Training data: each training batch is later assembled from one Cas batch
# (batch_size/4 samples) and one nonCas batch (3*batch_size/4 samples).
my_cas_data = MyDataset(dataset=cas_set1, k=j, state='train')
my_cas_loader = torch.utils.data.DataLoader(my_cas_data, batch_size=int(batch_size/4), shuffle=True, num_workers=0, drop_last=True)
my_noncas_data = MyDataset(dataset=noncas_set1, k=j, state='train')
my_noncas_loader = torch.utils.data.DataLoader(my_noncas_data, batch_size=int(3*batch_size/4), shuffle=True, num_workers=0, drop_last=True)

# Validation data: the held-out fold is consumed in val_batch_num batches.
# NOTE(review): int(len(...)/val_batch_num) is 0 when the fold holds fewer
# than val_batch_num samples -- presumably the data is always larger; confirm.
val_batch_num = 128*4
my_val_cas_data = MyDataset(dataset=cas_set1, k=j, state='test')
my_val_cas_loader = torch.utils.data.DataLoader(my_val_cas_data, batch_size=int(len(my_val_cas_data)/val_batch_num), shuffle=False, num_workers=0, drop_last=True)
# tmp=enumerate(my_val_cas_loader, 0)
# _, val_cas_batch=next(tmp)
# val_cas_inputs, val_cas_labels = val_cas_batch
# val_cas_inputs = val_cas_inputs.unsqueeze(1).to(device)
# val_cas_labels = val_cas_labels.to(device)

my_val_noncas_data = MyDataset(dataset=noncas_set1, k=j, state='test')
my_val_noncas_loader = torch.utils.data.DataLoader(my_val_noncas_data, batch_size=int(len(my_val_noncas_data)/val_batch_num), shuffle=False, num_workers=0, drop_last=True)
# tmp=enumerate(my_val_noncas_loader, 0)
# _, val_noncas_batch=next(tmp)
# val_noncas_inputs, val_noncas_labels = val_noncas_batch
# val_noncas_inputs = val_noncas_inputs.unsqueeze(1).to(device)
# val_noncas_labels = val_noncas_labels.to(device)

# Model
my_model = MyModel()
my_model.to(device)

# Loss: class-weighted cross entropy
weight = weight.to(device)
criterion = nn.CrossEntropyLoss(weight = weight)

# Optimizer
optimizer = optim.SGD(my_model.parameters(), lr=lr, momentum=momentum)

# Loop bookkeeping used by the training cell
epoch_num = 1        # epochs to run per execution of the training cell
write_every = 1      # log the training loss every write_every iterations
test_every = 2500    # run validation every test_every iterations
running_loss = []    # training loss of every iteration so far
val_loss=[]          # (cas, noncas) mean validation loss per validation run
val_accuracy=[]      # (cas, noncas) mean validation accuracy per validation run
epoch = -1           # incremented at the start of each epoch
inspect = False      # True once the reference loss for adaptive lr decay is set
initialize = True    # True until the reference loss has been initialised
converge_level = 0   # number of adaptive lr updates applied so far
inspect_loss = 0     # reference loss for the adaptive lr schedule
jj = -1              # global iteration counter across epochs
noncas_enum = enumerate(my_noncas_loader, 0)
def generate_numbers(n):
    """Yield 2, then endlessly apply ``x -> x % n + 1`` to the previous value.

    For n >= 2 this cycles through 2, 3, ..., n, 1, 2, ...
    """
    current = 2
    yield current
    while True:
        current = current % n + 1
        yield current
noncas_generater=generate_numbers(15)


In [None]:
# Manually driven training: re-run this cell to train for another epoch_num epochs.
for ss in range(epoch_num):
    epoch += 1
    my_model.train()
    for ii, batch_cas in enumerate(my_cas_loader, 0):
        jj+=1
        # Fetch the next nonCas batch; (None, None) signals the shard is exhausted.
        tmp, batch_noncas = next(noncas_enum , (None, None))
        if tmp is None:
            # Current nonCas shard is used up: release it and load the next one.
            noncas_iter = next(noncas_generater)
            del noncas_set1, my_noncas_data, my_noncas_loader
            gc.collect()
            print("Loading noncas data "+str(noncas_iter)+"...")
            with h5py.File(project_path + '/data/data_noncas.'+str(noncas_iter)+'.h5', 'r') as f:
                noncas_emb = np.array(f['noncas_emb'])
                noncas_labels = np.array(f['noncas_labels'])
            noncas_set1 = [(noncas_emb[i],noncas_labels[i]) for i in range(len(noncas_labels))]
            random.shuffle(noncas_set1)
            del noncas_emb,noncas_labels
            my_noncas_data = MyDataset(dataset=noncas_set1, k=j, state='train')
            my_noncas_loader = torch.utils.data.DataLoader(my_noncas_data, batch_size=int(3*batch_size/4), shuffle=True, num_workers=0, drop_last=True)
            noncas_enum = enumerate(my_noncas_loader, 0)
            tmp, batch_noncas = next(noncas_enum , (None, None))
        # Build the mixed batch: Cas samples first, nonCas samples after.
        inputs_cas, labels_cas = batch_cas
        inputs_noncas, labels_noncas = batch_noncas
        inputs = torch.cat((inputs_cas, inputs_noncas), dim=0)
        inputs = inputs.unsqueeze(1).to(device,dtype=torch.float32)  # add the channel dimension
        labels = torch.cat((labels_cas, labels_noncas), dim=0)
        labels = labels.to(device)
        # Optimisation step
        optimizer.zero_grad()
        outputs = my_model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # Record the training loss
        running_loss.append(loss.item())

        # Once, early in the first epoch, fix the reference loss for the
        # adaptive lr schedule as the mean of the last inspect_num losses.
        if initialize and epoch == 0 and ii == 3 * inspect_num + 10:
            inspect_loss = sum(running_loss[-inspect_num:]) / inspect_num
            inspect = True
            initialize = False
            print("inspect loss initialized as: " + str(inspect_loss))
            print("first inspect loss cutoff: " + str(inspect_loss_decay * inspect_loss))

        if jj % write_every == write_every - 1:
            writer.add_scalars(writer_title, {'loss': sum(running_loss[-write_every:]) / write_every,
                                               '100*lr': optimizer.param_groups[0]['lr'] * 100}, jj)
            print("[%s]-[%d]-[%d] lr=[%r] loss=[%.3f]" % (
            j, epoch + 1, ii + 1, optimizer.param_groups[0]['lr'], sum(running_loss[-write_every:]) / write_every))

        # Adaptive decay: whenever the recent mean loss falls below
        # inspect_loss_decay x the reference loss, scale lr by lr_adapt_decay
        # and lower the reference loss by the same inspect_loss_decay factor.
        if inspect:
            if sum(running_loss[-inspect_num:]) / inspect_num < inspect_loss_decay * inspect_loss:
                lr_0 = lr
                lr *= lr_adapt_decay
                for param_group in optimizer.param_groups:
                    param_group["lr"] = lr
                inspect_loss *= inspect_loss_decay
                converge_level += 1
                print("lr adapted to " + str(lr_adapt_decay) + "*" + str(lr_0) + "=" + str(lr))
                print("new inspect loss cutoff: " + str(inspect_loss_decay * inspect_loss))
        # Periodic validation on the held-out fold (Cas and nonCas separately).
        if jj % test_every == test_every - 1:
            my_model.eval()
            with torch.no_grad():
                epoch_cas_accuracy=[]
                epoch_noncas_accuracy=[]
                epoch_val_cas_losses=[]
                epoch_val_noncas_losses=[]
                tmp1 = enumerate(my_val_cas_loader)
                tmp2 = enumerate(my_val_noncas_loader)
                for t in range(val_batch_num):
                    _, val_cas_batch = next(tmp1)
                    _, val_noncas_batch = next(tmp2)

                    val_cas_inputs, val_cas_labels = val_cas_batch
                    val_cas_inputs = val_cas_inputs.unsqueeze(1).to(device,dtype=torch.float32)
                    val_cas_labels = val_cas_labels.to(device)
                    val_cas_outputs = my_model(val_cas_inputs)
                    loss_cas = criterion(val_cas_outputs, val_cas_labels)
                    val_cas_max_indices = torch.argmax(val_cas_outputs, dim=1)
                    val_cas_acc = torch.sum(val_cas_max_indices == val_cas_labels).item()/val_cas_labels.size()[0]

                    val_noncas_inputs, val_noncas_labels = val_noncas_batch
                    val_noncas_inputs = val_noncas_inputs.unsqueeze(1).to(device,dtype=torch.float32)
                    val_noncas_labels = val_noncas_labels.to(device)
                    val_noncas_outputs = my_model(val_noncas_inputs)
                    loss_noncas = criterion(val_noncas_outputs, val_noncas_labels)
                    val_noncas_max_indices = torch.argmax(val_noncas_outputs, dim=1)
                    val_noncas_acc = torch.sum(val_noncas_max_indices == val_noncas_labels).item()/val_noncas_labels.size()[0]

                    epoch_val_cas_losses.append(loss_cas.item())
                    epoch_val_noncas_losses.append(loss_noncas.item())
                    epoch_cas_accuracy.append(val_cas_acc)
                    epoch_noncas_accuracy.append(val_noncas_acc)
                val_loss.append((sum(epoch_val_cas_losses)/val_batch_num,sum(epoch_val_noncas_losses)/val_batch_num))
                val_accuracy.append((sum(epoch_cas_accuracy)/val_batch_num,sum(epoch_noncas_accuracy)/val_batch_num))
            writer.add_scalars(writer_title,{'val_cas_loss': val_loss[-1][0],'val_noncas_loss': val_loss[-1][1],'val_average_loss': (val_loss[-1][0]+val_loss[-1][1])/2,'val_cas_accuracy': val_accuracy[-1][0],'val_noncas_accuracy': val_accuracy[-1][1],'val_average_accuracy': (val_accuracy[-1][0]+val_accuracy[-1][1])/2},jj)
            print(writer_title)
            print('val_cas_loss:',str(val_loss[-1][0]))
            print('val_noncas_loss:',str(val_loss[-1][1]))
            print('val_cas_accuracy:',str(val_accuracy[-1][0]))
            print('val_noncas_accuracy:',str(val_accuracy[-1][1]))
            # Restore training mode: validation switched the model to eval()
            # and the rest of the epoch would otherwise keep training in it.
            my_model.train()
        # Fixed-schedule decay, applied in addition to the adaptive decay above.
        if jj % lr_fixed_decay_every == lr_fixed_decay_every - 1:
            lr *= lr_fixed_decay
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr
    # Save a checkpoint at the end of every epoch.
    save_model(run_header=run_header, epoch=epoch, model=my_model)

In [None]:
# Load a trained model.
# NOTE: this import rebinds the name `MyModel` to the class defined in the
# module model_12_27_10_28_52, which load_model() then instantiates.
import importlib
from model_12_27_10_28_52 import MyModel

# Checkpoint of run '12-27-10-28-52', epoch 2.
my_model = load_model('12-27-10-28-52',2)

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import plotly.io as pio
from collections import Counter
def split_seq_fixed(seq,window_size,step=window_step):
    """Cut ``seq`` into windows of ``window_size`` starting at index 0.

    Window starts are 0, step, 2*step, ... up to the last position where a
    full window still fits. Returns an empty list when ``seq`` is shorter
    than ``window_size``.
    """
    last_start = len(seq) - window_size
    if last_start < 0:
        return []
    return [seq[begin:begin + window_size] for begin in range(0, last_start + 1, step)]
def predict_seq(mymodel, seq, step=16, title='sequence'):
    """Classify sliding windows of a protein sequence and plot the scores.

    The sequence is cut into windows of the global ``window_size`` every
    ``step`` residues, each window is embedded with the global ``tokenizer``
    and ``model``, and ``mymodel`` scores every window. The per-class
    prediction counts are printed and two plotly figures are shown: a line
    plot of the softmax score of each class along the sequence and a pie
    chart of the scores summed over all windows.

    Args:
        mymodel: Classifier applied to the window embeddings.
        seq: Amino-acid sequence; newline characters are removed.
        step: Stride between consecutive windows.
        title: Title used for both figures and in the "too short" message.

    Returns:
        None. Prints a message and returns early when ``seq`` is shorter
        than ``window_size``.
    """
    mymodel.eval()
    with torch.no_grad():
        seq = seq.replace("\n","")
        split_lst = split_seq_fixed(seq,window_size,step=step)
        if len(split_lst) == 0:
            print("\""+title+"\" is too short!")
            return
        # Replace the residues U, Z, O and B with X and separate residues by spaces.
        seqs_batch = [" ".join(list(re.sub(r"[UZOB]", "X", seq))) for seq in split_lst]
        ids = tokenizer(seqs_batch, add_special_tokens=True, padding="longest")
        input_ids = torch.tensor(ids['input_ids']).to(device)
        attention_mask = torch.tensor(ids['attention_mask']).to(device)
        # All windows are embedded in a single batch.
        embedding_repr = model(input_ids=input_ids, attention_mask=attention_mask)
        # Keep the first window_size token positions and add a channel dimension:
        # (batch, 1, window_size, hidden); hidden is presumably 1024 -- TODO confirm.
        emb = embedding_repr.last_hidden_state[:,:window_size].detach().unsqueeze(1).to(device,dtype=torch.float32) # shape (batch x window_size x 1024)
        outputs = mymodel(emb)
        outputs = torch.nn.Softmax(dim=1)(outputs)

        # One colour per class, indexed by the class position in target_encode_dict.
        mycolors = ["#91F876","#F64A4A","#81FBFB","#FFA07A","#008000","#C0C0C0","#0280CE","#800080","#F0F01F","#FF3EEE","#000000"]
        # Pie chart: softmax scores summed over all windows; the last class is pulled out.
        type_score = np.sum(np.array(outputs.cpu()), axis=0)
        all_types = list(target_encode_dict.keys())
        plotdata = pd.DataFrame({"type":all_types,"type_score":type_score})
        fig2 = go.Figure(data=[go.Pie(labels=all_types, values=type_score, marker=dict(colors=mycolors), pull=[0]*(len(all_types)-1)+[0.2])])
        fig2.update_layout(title=title, title_x=0.5)

        # Line plot: one row per (window, class) pair, in the same order as
        # the flattened `outputs`.
        values = np.array(outputs.view(-1).cpu())
        types = all_types*len(outputs)
        # x position of each window: its start index plus (window_size+1)/2.
        indices = [i+(window_size+1)/2 for i in list(range(0,len(seq)-window_size+1,step)) for _ in range(len(all_types))]
        # Hover text: "[first, last]: class(score)" with 1-based residue positions.
        annotations = ["["+str(int(indices[i]-(window_size-1)/2))+", "+str(int(indices[i]+(window_size-1)/2))+"]: "+str(types[i])+"("+str(round(values[i],3))+")" for i in range(len(types))]
        plotdata = pd.DataFrame({"indices":indices,"types":types,"values":values,"annotations":annotations})
        # Earlier seaborn version of the line plot, kept for reference:
#         plt.figure(figsize=(15,8))
#         custom_palette = sns.color_palette(mycolors)
#         plot = sns.lineplot(data=plotdata,x="indices",y="values",hue="types",palette=custom_palette)
#         plot.set_title(title)
#         plot.set_xlabel('pos')
#         plot.set_ylabel('score')
#         plt.axhline(y=0.5, color='red', linestyle='--', linewidth=1, label='Threshold')
        grouped_data = plotdata.groupby('types')
        # Traces are added in descending order of the summed score.
        sorted_types = [all_types[i] for i in sorted(range(len(all_types)), key=lambda i: type_score[i], reverse=True)]
        fig1 = go.Figure()
        fig1.add_shape(
            type='line',
            x0=0,  # x coordinate where the horizontal line starts
            y0=0.5,  # y coordinate of the horizontal line
            x1=len(seq),  # x coordinate where the horizontal line ends
            y1=0.5,  # y coordinate of the horizontal line
            line=dict(color='#F4D03F', width=2, dash='dash')  # line colour, width and dashed style
        )
        fig1.add_shape(
            type='line',
            x0=0,  # x coordinate where the horizontal line starts
            y0=0.8,  # y coordinate of the horizontal line
            x1=len(seq),  # x coordinate where the horizontal line ends
            y1=0.8,  # y coordinate of the horizontal line
            line=dict(color='red', width=2, dash='dash')  # line colour, width and dashed style
        )
        for group in sorted_types:
            data = grouped_data.get_group(group)
            fig1.add_trace(go.Scatter(x=data['indices'], y=data['values'], mode='lines+markers',line=dict(color=mycolors[all_types.index(group)]), name=group, text=data['annotations'], textposition='top center', hoverinfo='text'))
        fig1.update_layout(title=title, title_x=0.5, xaxis_title='pos', yaxis_title='score', xaxis=dict(range=[0, len(seq)]), yaxis=dict(range=[0-0.05, 1+0.05]))

        # Bar chart (not implemented)


        # Prediction summary: number of windows assigned to each class (argmax).
        max_indices = torch.argmax(outputs, dim=1)
        preds = [all_types[i] for i in np.array(max_indices.cpu())]
        counter = Counter(preds)
        for element, count in counter.most_common():
            print(f"{element}: {count}")

        fig1.show()
        fig2.show()
        return

In [None]:
predict_seq(my_model, "MMLPPRLQEHLATFLPAIRVGIDFGESAGGIAVVQGNQILHAETYIDFHNSDLGQRRQLRRGRRTRRAKKMRLARLRSWVLRQKLPDGTRLPDPYVVMRYPKFHVQPGVFKTKTPGRDSATVPSWIDLAKQGKVDASGFVRALTLIFQKRGFKWDAIELAKMTDEKMKDFLQTARVPSDNLASDIREEIQRRRQDSDSSGRGKKKVSPDELLALLEQARERQPQPRVAEHRSVKEADLRSAVEGFGNSINLTKATIQRWQRELSGLLNKVLRPARFENRLRTGCAWCGKPTPRKSKVREVAYEAAVRNLRVREGRSIRPLRPEEFAMFTQWWRLRGQAGESQQGGSEQKRSRKDRSQAIPKLKGIQSYLKKLGAQEQMARQIFDLLWNEKPQGRASLCQQHLSEAAQGRTMKDVVGEWHKVKVRKAPNPCREQHDTRVLHRLEQILFRPGKNGPDAWRYGPVQFITLEVPKPQTEQARRGEQKLRKPESFMERLRKETDGVCMYCDSSPPRPAEDKDHIFPQSRGGPDVWDNLVPVCRSCNMEKGNRTPFEWIGADGERWRRFTERVEGLAARGVRIEREDGKEETVRISERKRALLISQDTEYPDNPTSLAHVGARPRQFVVALRKLFEDRGVTAPSVNFESGVPFVQRIDGRTTFQLRKSWLKKADGSDNFPTKNEWDLLNHAQDAALIAACPPHTWRDLIFTHSAERPVWDGSWKQIPGLAIPALAPDWAEYLERRTWPLVKVLGRYPVSWKRKFADLTFSQNPDSVDDKRLVQYLLIADMPHSGKGPNDKRHPAETEIVNPALDKKFRAVATALGLKKKQTIPEKSLQEEFPGIRHVKVRKQRGGRLVRVKPKDGPPRKVEIKGASEAAVFWVANDDPLRDLRISVRWPVILGAMNVPRYEPSIPADARILATWSRYQLVRLGPEVTQKVGFYRVKEFDDSSVTVLPENAVPDAVAKRLNLKRQETEQETPEGSSPSQEIKLGKTLLMRYFQSLKGGKHHDPGTGS")

In [None]:
# Load the five FASTA files under <project_path>/data/cas9d/, one dict each.
cas9d_fa1 = read_fasta(project_path + "/data/cas9d/"+"CN116096877A.FASTA")
cas9d_fa2 = read_fasta(project_path + "/data/cas9d/"+"WO2021202559A1.FASTA")
cas9d_fa3 = read_fasta(project_path + "/data/cas9d/"+"WO2023081855A1.FASTA")
cas9d_fa4 = read_fasta(project_path + "/data/cas9d/"+"WO2023097262A1.FASTA")
cas9d_fa5 = read_fasta(project_path + "/data/cas9d/"+"WO2023097282A1.FASTA")

In [None]:
# Predict selected records of cas9d_fa5 (looked up by header) with a stride of 4.
for i in ['SEQ ID NO. 1294 in WO2023097282A1']:
    predict_seq(my_model, cas9d_fa5[i],title=i,step=4)
    print("--------------------------------------------------------------------------------------------------------------------------------")

In [None]:
# Predict every record of cas9d_fa2 with the default stride.
for i in cas9d_fa2:
    predict_seq(my_model, cas9d_fa2[i], title=i)
    print("--------------------------------------------------------------------------------------------------------------------------------")

In [None]:
# Predict every record of cas9d_fa3 with a stride of 4.
for i in cas9d_fa3:
    predict_seq(my_model, cas9d_fa3[i], title=i, step=4)
    print("--------------------------------------------------------------------------------------------------------------------------------")

In [None]:
# Predict every record of cas9d_fa4 with a stride of 4.
for i in cas9d_fa4:
    predict_seq(my_model, cas9d_fa4[i], title=i, step=4)
    print("--------------------------------------------------------------------------------------------------------------------------------")

In [None]:
# Predict selected records of cas9d_fa5 (looked up by header) with a stride of 4.
for i in ['SEQ ID NO. 978 in WO2023097282A1']:
    predict_seq(my_model, cas9d_fa5[i],title=i,step=4)
    print("--------------------------------------------------------------------------------------------------------------------------------")

In [None]:
# Predict every record of an ad-hoc test FASTA file with a stride of 4.
test = read_fasta('/home/fengchr/staff/test_12_6/test.fa')
for i in test:
    predict_seq(my_model, test[i], title=i, step=4)
    print("--------------------------------------------------------------------------------------------------------------------------------")

In [None]:
predict_seq(my_model, "MSNDLNEHIDLPVISLQKTLEKLFGDEQFERTQHINFVVKLLSSQQADNFLGGLNLDYFNVDVEFQSNLPKPSLISFSKKVKISNLPITSYINSVAQLSASQTHAQHWNILVLKAAIYLIALPELKPELFKQAHTEHFNTVKRLFQRFRTANKNLDTEKKYQNTKEYQRLWSTYLQDPTQSLERFIQHLITLDTAELPEFDRNLLNDIRITFNYVLKNKAKIARASIDTQLEHQFLDEEQFIEESIEIKKGASPKALSIETLIDEPLHRQIVVNPTQVTPLAAHSETSQNYILPLVAKHIQRKEHLLTSSSFFPNPSSMNHLLKRLHVDYSEHQNKSALILMLAFLTGNSVNEWLYIQKSKRAKKLNNRQKLIYKNDQFFLNSHFNVFENRNFEYSDNLLNQTIYLDIPIPNLFIEDLRKMESVSFNEIQQYLRRLRQELFIPKLSVVKVSSLLHHTILSKTGNKQLADLMTGIDANQSSSVSYCHQNIPRLHAQYVDILKSLCADVASTYESCVPSLPDSIIHFGSRKAPKPQVITEIFAVLKFNIFSQAEDDLIAIYNHYNIWMWHILLLFTAARPVAEFPGFLKNFNLKRQILMVSDKEVGGRNGFGRLIPLCSFLVEEIKKFLKFLEYFSIQIAMNYPLLNDVIQQIETSKLPLLGIIQNNKWTPLRPSTVKDLYPELGLAHENWHRHTARAFLTHKFSEPEILALFGHELMQQEAAHPFSSLSLSQFSKIADVLEQMKTYFKITGVEAHVITQ",title="YTGE-72")

In [None]:
predict_seq(my_model, "MSNDLNEHFDLPMISIKRTIEKLFGDEQFERAQHINFVIELLSSQQTDNCLDGLNLDCFDIDVEFQSNLPKPSVISFNKKVKISDLPITSYVNSISQLSESQTHAKNWNILVLKAAIYLIALPELKPDLFKQAHAEHFNTVKRLFQRFRTANKNLDTEKKYHNTEEYKRLWNVYLKDLTLSLEQFIQHLITLDTDELPEFDRNLLNDIRITFNYILKNKAKIARASIDTQLQHQFLDEEQFIEESIEIKKGAESKALNIETLIDEPINRQIVVNPTDVTPLAAHSETSQSYVLPLVAKHIQRKEHLLTSSSFFPNPSSVNHLLKRLHVDYSEHQNKSALILMLAFLTGNSVNEWLYIQSKRAKNLNNRQKLIHKNDQFFLNSKFNVFENRDFEYSTSLLNQTIYLDIPIPNLFIEDLRKMDSVSFDDIQQYLRKLRQELLIPKLSVIKVSSLLHHTVLAKTGNKQLADLITGIDANQSSSISYCHQNIPRLHAQYVDILKSLCADIASTYESCVPSPPDSIIHFGSRKAPKPQIITEIFAVLKFNIFSQAEDDLIAIYNHYNIWMWHILLLFTAARPVAEFPGFLKNFNLKRQILMVSDKEAGGRNGFGRLIPLCSFLVEEIKKFLKFLEYFSTQIMMSHPALSDAIQQIEASKLPFLGIIQNDEWKPLSHSTVKNFHPELGLAHENWHRHTARAFLTHKFSEPEILALFGHELMQQEAAHPFSSLSLSQFSKIADVLEQMKTYFKITGVEAHVITQ",title="SRR6231193|NODE_1928_length_10001_cov_10.189624_2")

In [None]:
predict_seq(my_model, "MSIPAQLSAAVQTFLPTLRLGLDLGERAVGIAVVRGNEVLHAETVIDFHEATLKERRRLRRGRRTRRAKKSRIARLRSWILRQIVNGKRLPDPYILMRQKRFQCQPGEYRQKVQLAKSALPSWVEAVKQGRETSDEAFVIALTHLFQKRGYRWGGSDVQAMDDNTLADELRKIRLTPAVAEQVRREVERRKNDPNAPKGFTGKINGIEQLIEQALNRRRQPRVAEHRSIVEDEVRAVVTSFGRHHGIAEDTMTRWRAELVCLLNKPVRAARFENRALTGCTWCGAHTPKKSRPEVQELAYWAAVANVRVAAGRQPRPLTQSERAQFVEWWNADAQRRPTQPAIKRYLTSIGAQEEMARQFADLLNRRNLNGRTNLCLAHLREQAEGAFFCPQHQGVCRSAPNGQHRAVESARSRESSASRVWNPARAWHDRRVVARIERMLFMRDGTPRYGGIPSLITIEVPKPDTAHRYECPHCHEALAVNLRVRYRITKLELKPTKVRQNEAAFTCPQCRKPFEINGKRKIGTPNGLKPINVKLGLTHAVVWWAGGGKKARHVADTNGQCIYCGTNVDVGSVKLDHIFPQSMAGPGIYMNMVAACERCNNEKYNRTPWQWKGHDQAWWQAFEARLDRLFLPMRKREMLLSREASYPENPTALARVGGRAREFMRELQVMFARHAIPSERIVTGYRHDADIVMQMIEGWMTDRLRRSWMSGVDGRENFLPKDRADLRNHAQDAVLVAACPPHTWRERIFCYGPDRDMALPDLAPNWRNYESGMRDRHPLVVPLGRYRIRWRKQFLDQTFWKQPLGRKPVVYRQLAELKKSDASSIKDERIRCAFLSVCQAYNIGSEKTLTEDAQADLQNRLVSLGMVTPVRRVQCFSQKGGLPISVRPHDGPVRITQVKPTSDGVVLWLPAGVRLETARTRDLKISVIRPKPVVGWPSPDVPGQAVSELDPPVPPEAQRIATWYRYQCVRFSPHDGWYRLKEFSEKKLTVMPAIRLPKKLRNDAGGHDGEETGNDEREFGKEALLAAVRANAMATFCDPFD",title="Cas9c2|MG33-1")

In [None]:
predict_seq(my_model, "HAETYTDYHATTLEERRKMRRGRRTRHAKKLRLARLRSWLLRQTLPDPYDLMRRKDFQHPPLQLPQHLPVHRQESRTTSPWCHAVIKGHLNDPSAFVLALTHIFQKRGYKYDARDLSSLSPSELEEFLNSCCLLEQAGSLKDSLKGLVERVESNKLRNAYEKALTREPEPRKALPRQLKEDELKQLVKAFGSAQGLSPKQIVTWEKQLVGLLNKTIREPRFENRVIAGCSWCGKNTPRKSRKGVRELEFKAAVRNIRKDRVLPLDEQEALSFLHLWNDNAFRNKGLKARKQQFDKLLNSLRAQMDMASQLAELCGNDKGRGRANLCVDHLSLAADGAFKCNRHPGWLCKLGGAHEGHTQVEFVGGTGLTERLARNPCRETHEERVLRRLEEVLFDKSGNPRWGIPSLISIEFPKPGTAQTYDCPSCKEKLAIDLKVNRKIRKLSLGNAKRSDPQTSFACPFCNTGLFIKGTVKKPIGDRWEDRPVYLTHDEVYLRKAKGGMKEKKRAEYLRETNSTCVYCGQKIEDRLTMEPDHIFPRARGGPDVDSNLVASCHDCNHPNTGKGDRTPYEWIQRDGASPGGNDWSAFQKRVKSLPLPQRKRDVLLSDKEYYPDNPTALARATARKRAFIARIAKMLTDHGVPVDQIALNYETDKQVVIQAVDGWTTSRLRLSWRFHENGKPNFPQKKDWDLRNHAQDAALIAASPPHTWREAIFVETRPVDDPNKPIPGLAPRSLAPDWKGYLETIKDQKPLVSVLGRYDASWKRGFLDSTFWSFRSRNGHPTQRKQIDKVTAKQGDRIVDPKIKDAFRTLCESYGLVDRAGGFKDKPLPPEALEELRKQFPGIRRVRIVTQPGGKPILIQPQVGPPRQTQAKPGSEGIVLWLRLDTTKKRPREQTSEQFGLSLIRPAPLESFSVPRFEPPIPDDATAQLHLYRHDFIYLGENGIHPEGWYRLKEFSDDSVIALPEEAIPAELRKRMGLEQPGRKGSPSAAPAAPQERRLGKEELKGLLRSFRQRKLRIVKGGQ",title="Cas9c2|MG33-9")

In [None]:
predict_seq(my_model, "MERELVLGIDYGGKYTGLAVVDRRHNQVLYANRLKMRDDVAGILKDRRKQRGIRRTAQTKKKRLRELKNYLKSIGYNESTATFETVYSLAHKRGYDYADMPEEKTSEEIEAMDVEERKQWEKEKQEWEETKRNSRHRKEVVKDVHKAMIEGRATEEQIKRVERIFNKQYRPKRFNNRILTKCKVEDCGVNTPLRKNVRDLLIENIVRFFPIEQSEKDNLKDAVLDKNRREEVKSFFRKHKTDEHIRKQVYDIADNKLSGRTVFCKEHILEHTEHSKEERKVFRLAPSLKTKIENVLAVIKDEILPKFTVNKVVMESNNFDIAAKTQGKKRLAKEEYGKGPREGKETRKEALLRETDGRCIYCGKSIDISNAHDDHIFPRKAGGLNIFANLVACCAVCNENKKGRTPLESGISPKPEIIAFMKNDLKKKILEDARNINTVDFNKYMSHASIGWRYMRDRLRESAGNKKLPIERQSGIYTAYFRRWWGFKKERGNTLHHALDAVILASRKGYSDDGLVDMTLKPKYNKGGEFDPEKHLPEPIEFKMDKGSRGSALHDRNPLSYKKGIITRRFMVTEIECGKEDDVISETYREKLKEAFKRFDTKKGKCLTDKEAKEAGFCIKKNELVMSLKCSIKGTGPGQMIRINNNVFKTNVHNVGVDVYLDEKGKKKAYERKNPRLSKHFIEPPPQPNGRVSFTLKRRDMVTVEGEDAIYRIKKLGTSPTIEAVVGSDGKTRTVSATKLTKANSAE",title="Cas9d|MG34-1")

In [None]:
predict_seq(my_model, "MDMVYVLNKDGKPLMATTRGGRVRYLLKEKKARVVSSTPFTIQLNYDTPDITQDLILGIDPGRTNIGVAVVKEDGQCVFSAHLETRNKEVPLLMKKRAAFRRQHRTQDRRRKRQRRAIAAGTTVESNTIERLLPGYEKPIVCHHIRNKEARFNNRSRPAGWLTPTANHLLQTHINLIAKIAKVLPITKVVVELNRFAFMAMDNPNIRRWEYQQGSLYGLGSVEDAVYAQQDGHCLFCKKPIDHYHHVVPRHKGGSETLANRCGLCEKHHALVHKDKAWAEKLVTRKGGMNKKYHALSVLNQIIPFLMEYLGEETPYDVYATDGKSTKGFRIAKNVPKEHYTDAYCIACSILDADTKVSAPAEPFKLKQFRRHDRQSCIRQMVDRKYLLDGKVVATNRHKAIEQKSDSLEEFREAYGDAAVSQLTVRPHLPQYKDMTRIMPGAVMAFNGSVGVMQSSIVGSFYKSTKGHKATPRRCVLLAQNAGMVFNPA",title="HEARO|MG35-287")

In [None]:
import requests
from PIL import Image
import io
from IPython.display import display
import base64
seq='MERELVLGIDYGGKYTGLAVVDRRHNQVLYANRLKMRDDVAGILKDRRKQRGIRRTAQTKKKRLRELKNYLKSIGYNESTATFETVYSLAHKRGYDYADMPEEKTSEEIEAMDVEERKQWEKEKQEWEETKRNSRHRKEVVKDVHKAMIEGRATEEQIKRVERIFNKQYRPKRFNNRILTKCKVEDCGVNTPLRKNVRDLLIENIVRFFPIEQSEKDNLKDAVLDKNRREEVKSFFRKHKTDEHIRKQVYDIADNKLSGRTVFCKEHILEHTEHSKEERKVFRLAPSLKTKIENVLAVIKDEILPKFTVNKVVMESNNFDIAAKTQGKKRLAKEEYGKGPREGKETRKEALLRETDGRCIYCGKSIDISNAHDDHIFPRKAGGLNIFANLVACCAVCNENKKGRTPLESGISPKPEIIAFMKNDLKKKILEDARNINTVDFNKYMSHASIGWRYMRDRLRESAGNKKLPIERQSGIYTAYFRRWWGFKKERGNTLHHALDAVILASRKGYSDDGLVDMTLKPKYNKGGEFDPEKHLPEPIEFKMDKGSRGSALHDRNPLSYKKGIITRRFMVTEIECGKEDDVISETYREKLKEAFKRFDTKKGKCLTDKEAKEAGFCIKKNELVMSLKCSIKGTGPGQMIRINNNVFKTNVHNVGVDVYLDEKGKKKAYERKNPRLSKHFIEPPPQPNGRVSFTLKRRDMVTVEGEDAIYRIKKLGTSPTIEAVVGSDGKTRTVSATKLTKANSAE'
title='Cas9d|MG34-1'
step=10
# Query the local prediction service. The query string is passed via `params`
# so that requests URL-encodes the values (the title contains '|', and any
# '&', '#' or space in a value would corrupt a hand-concatenated URL).
response = requests.get(
    'http://127.0.0.1:8040/predict/',
    params={'seq': seq, 'title': title, 'step': step},
)
print(response.status_code)
print(type(response))
response=response.json()
plot = response["img"]  # base64-encoded image
stats = response["statistics"]

print(stats)

# Decode the base64 payload and display it as an image.
img_io = io.BytesIO(base64.b64decode(plot))
image = Image.open(img_io)
display(image)

In [None]:
import requests
from PIL import Image
import io
from IPython.display import display_html
seq='MERELVLGIDYGGKYTGLAVVDRRHNQVLYANRLKMRDDVAGILKDRRKQRGIRRTAQTKKKRLRELKNYLKSIGYNESTATFETVYSLAHKRGYDYADMPEEKTSEEIEAMDVEERKQWEKEKQEWEETKRNSRHRKEVVKDVHKAMIEGRATEEQIKRVERIFNKQYRPKRFNNRILTKCKVEDCGVNTPLRKNVRDLLIENIVRFFPIEQSEKDNLKDAVLDKNRREEVKSFFRKHKTDEHIRKQVYDIADNKLSGRTVFCKEHILEHTEHSKEERKVFRLAPSLKTKIENVLAVIKDEILPKFTVNKVVMESNNFDIAAKTQGKKRLAKEEYGKGPREGKETRKEALLRETDGRCIYCGKSIDISNAHDDHIFPRKAGGLNIFANLVACCAVCNENKKGRTPLESGISPKPEIIAFMKNDLKKKILEDARNINTVDFNKYMSHASIGWRYMRDRLRESAGNKKLPIERQSGIYTAYFRRWWGFKKERGNTLHHALDAVILASRKGYSDDGLVDMTLKPKYNKGGEFDPEKHLPEPIEFKMDKGSRGSALHDRNPLSYKKGIITRRFMVTEIECGKEDDVISETYREKLKEAFKRFDTKKGKCLTDKEAKEAGFCIKKNELVMSLKCSIKGTGPGQMIRINNNVFKTNVHNVGVDVYLDEKGKKKAYERKNPRLSKHFIEPPPQPNGRVSFTLKRRDMVTVEGEDAIYRIKKLGTSPTIEAVVGSDGKTRTVSATKLTKANSAE'
title='Cas9d|MG34-1'
step=8
# Query the local prediction service. The query string is passed via `params`
# so that requests URL-encodes the values (the title contains '|', and any
# '&', '#' or space in a value would corrupt a hand-concatenated URL).
response = requests.get(
    'http://127.0.0.1:8040/predict/',
    params={'seq': seq, 'title': title, 'step': step},
)
print(response.status_code)
print(type(response))
response=response.json()
fig1_html = response["fig1_html"]
fig2_html = response["fig2_html"]
stats = response["statistics"]

print(stats)
# Render the two returned HTML fragments inline.
display_html(fig1_html, raw=True)
display_html(fig2_html, raw=True)

In [None]:
import requests
from PIL import Image
import io
from IPython.display import display_html
seq='MERELVLGIDYGGKYTGLAVVDRRHNQVLYANRLKMRDDVAGILKDRRKQRGIRRTAQTKKKRLRELKNYLKSIGYNESTATFETVYSLAHKRGYDYADMPEEKTSEEIEAMDVEERKQWEKEKQEWEETKRNSRHRKEVVKDVHKAMIEGRATEEQIKRVERIFNKQYRPKRFNNRILTKCKVEDCGVNTPLRKNVRDLLIENIVRFFPIEQSEKDNLKDAVLDKNRREEVKSFFRKHKTDEHIRKQVYDIADNKLSGRTVFCKEHILEHTEHSKEERKVFRLAPSLKTKIENVLAVIKDEILPKFTVNKVVMESNNFDIAAKTQGKKRLAKEEYGKGPREGKETRKEALLRETDGRCIYCGKSIDISNAHDDHIFPRKAGGLNIFANLVACCAVCNENKKGRTPLESGISPKPEIIAFMKNDLKKKILEDARNINTVDFNKYMSHASIGWRYMRDRLRESAGNKKLPIERQSGIYTAYFRRWWGFKKERGNTLHHALDAVILASRKGYSDDGLVDMTLKPKYNKGGEFDPEKHLPEPIEFKMDKGSRGSALHDRNPLSYKKGIITRRFMVTEIECGKEDDVISETYREKLKEAFKRFDTKKGKCLTDKEAKEAGFCIKKNELVMSLKCSIKGTGPGQMIRINNNVFKTNVHNVGVDVYLDEKGKKKAYERKNPRLSKHFIEPPPQPNGRVSFTLKRRDMVTVEGEDAIYRIKKLGTSPTIEAVVGSDGKTRTVSATKLTKANSAE'
title='Cas9d|MG34-1'
step=8
# Query the local prediction service for the raw result. The query string is
# passed via `params` so that requests URL-encodes the values (the title
# contains '|', and any '&', '#' or space in a value would corrupt a
# hand-concatenated URL).
response = requests.get(
    'http://127.0.0.1:8040/predict/',
    params={'seq': seq, 'title': title, 'step': step, 'return_type': 'raw'},
)
print(response.status_code)
print(type(response))
response=response.json()
print(type(response))
print(response)