In [None]:
import sys
from google.colab import drive

# Mount Google Drive, then add the project folder on Drive to the import path.
drive.mount('/content/drive')
sys.path.append('/content/drive/MyDrive/Colab Notebooks/Cas9/On target')

Mounted at /content/drive


In [None]:
import requests
import tensorflow as tf
import pandas as pd
import numpy as np
from operator import add
from functools import reduce
import random
import tabulate

from keras import Model
from keras import regularizers
from keras.optimizers import Adam
from keras.layers import Conv2D, BatchNormalization, ReLU, Input, Flatten, Softmax
from keras.layers import Concatenate, Activation, Dense, GlobalAveragePooling2D, Dropout
from keras.layers import AveragePooling1D, Bidirectional, LSTM, GlobalAveragePooling1D, MaxPool1D, Reshape
from keras.layers import LayerNormalization, Conv1D, MultiHeadAttention, Layer
from keras.models import load_model
from keras.callbacks import EarlyStopping, ReduceLROnPlateau

!pip install cyvcf2
import cyvcf2
!pip install parasail
import parasail

import re

Collecting cyvcf2
  Downloading cyvcf2-0.30.28-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (6.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.8/6.8 MB[0m [31m24.9 MB/s[0m eta [36m0:00:00[0m
Collecting coloredlogs (from cyvcf2)
  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl (46 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m46.0/46.0 kB[0m [31m5.9 MB/s[0m eta [36m0:00:00[0m
Collecting humanfriendly>=9.1 (from coloredlogs->cyvcf2)
  Downloading humanfriendly-10.0-py2.py3-none-any.whl (86 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.8/86.8 kB[0m [31m13.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: humanfriendly, coloredlogs, cyvcf2
Successfully installed coloredlogs-15.0.1 cyvcf2-0.30.28 humanfriendly-10.0
Collecting parasail
  Downloading parasail-1.3.4-py2.py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (15.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━

### Data Encoding

In [None]:
# One-hot code per nucleotide, in A, C, G, T channel order.
ntmap = {base: tuple(int(channel == idx) for channel in range(4))
         for idx, base in enumerate('ACGT')}

def get_seqcode(seq):
    """One-hot encode a DNA string as an array of shape (1, len(seq), 4).

    Lower-case input is accepted. Any character outside ACGT raises KeyError.
    """
    one_hot_rows = [ntmap[base] for base in seq.upper()]
    flat_codes = reduce(add, one_hot_rows)
    return np.array(flat_codes).reshape((1, len(seq), -1))

### Attention model

In [None]:
class PositionalEncoding(Layer):
    """Add a fixed sinusoidal positional-encoding table to the input.

    Even feature indices get ``sin(pos / 10000**(2i/d))``. Odd indices get the
    matching ``cos``.

    Args:
        sequence_len: number of positions; must equal the input's time axis so
            the (sequence_len, embedding_dim) table broadcasts onto it.
        embedding_dim: size of the input's last (feature) axis.
        **kwargs: standard Keras layer arguments (``name``, ``trainable``,
            ``dtype``, ...), forwarded to ``Layer.__init__``.
    """

    def __init__(self, sequence_len=None, embedding_dim=None, **kwargs):
        # Forward kwargs: get_config() emits name/trainable/dtype via the base
        # class, and they were previously swallowed here, so a deserialized
        # layer silently lost its name and dtype.
        super(PositionalEncoding, self).__init__(**kwargs)
        self.sequence_len = sequence_len
        self.embedding_dim = embedding_dim

    def call(self, x):
        # angle[pos, i] = pos / 10000 ** (2 * i / embedding_dim), built with
        # broadcasting instead of a nested Python loop.
        positions = np.arange(self.sequence_len)[:, np.newaxis]
        dims = np.arange(self.embedding_dim)
        position_embedding = positions / np.power(10000, 2. * dims / self.embedding_dim)

        position_embedding[:, 0::2] = np.sin(position_embedding[:, 0::2])  # dim 2i
        position_embedding[:, 1::2] = np.cos(position_embedding[:, 1::2])  # dim 2i+1
        position_embedding = tf.cast(position_embedding, dtype=tf.float32)

        return position_embedding + x

    def get_config(self):
        config = super().get_config().copy()
        config.update({
            'sequence_len': self.sequence_len,
            'embedding_dim': self.embedding_dim,
        })
        return config

def MultiHeadAttention_model(input_shape):
    """Build the on-target regressor: Conv1D -> BiLSTM -> self-attention -> MLP.

    Args:
        input_shape: ``(sequence_length, channels)`` of one encoded guide, e.g.
            ``(23, 4)`` for the one-hot 23-mers produced by ``get_seqcode``.

    Returns:
        An uncompiled ``keras.Model`` with a single linear output unit.

    The positional-encoding length is derived from ``input_shape[0]``; it was
    previously hard-coded for 23-mers. For ``(23, 4)`` it is still 4, and layers
    are created in the same order as before, so saved weights still load.
    """
    lstm_units = 128
    inputs = Input(shape=input_shape)

    # Two stages of valid Conv1D(kernel=3) followed by AveragePooling1D(2).
    conv1 = Conv1D(256, 3, activation="relu")(inputs)
    pool1 = AveragePooling1D(2)(conv1)
    drop1 = Dropout(0.4)(pool1)

    conv2 = Conv1D(256, 3, activation="relu")(drop1)
    pool2 = AveragePooling1D(2)(conv2)
    drop2 = Dropout(0.4)(pool2)

    lstm = Bidirectional(LSTM(lstm_units,
                              dropout=0.5,
                              activation='tanh',
                              return_sequences=True,
                              kernel_regularizer=regularizers.l2(0.01)))(drop2)

    # Each conv+pool stage maps length L -> (L - 2) // 2. This equals the former
    # int(((23-3+1)/2-3+1)/2) == 4 for 23-mers, and it also fits other lengths.
    encoded_len = ((input_shape[0] - 2) // 2 - 2) // 2
    pos_embedding = PositionalEncoding(sequence_len=encoded_len,
                                       embedding_dim=2 * lstm_units)(lstm)
    atten = MultiHeadAttention(num_heads=2,
                               key_dim=64,
                               dropout=0.2,
                               kernel_regularizer=regularizers.l2(0.01))(pos_embedding, pos_embedding)

    flat = Flatten()(atten)

    dense1 = Dense(512,
                   kernel_regularizer=regularizers.l2(1e-4),
                   bias_regularizer=regularizers.l2(1e-4),
                   activation="relu")(flat)
    drop3 = Dropout(0.1)(dense1)

    dense2 = Dense(128,
                   kernel_regularizer=regularizers.l2(1e-4),
                   bias_regularizer=regularizers.l2(1e-4),
                   activation="relu")(drop3)
    drop4 = Dropout(0.1)(dense2)

    dense3 = Dense(256,
                   kernel_regularizer=regularizers.l2(1e-4),
                   bias_regularizer=regularizers.l2(1e-4),
                   activation="relu")(drop4)
    drop5 = Dropout(0.1)(dense3)

    output = Dense(1, activation="linear")(drop5)

    model = Model(inputs=[inputs], outputs=[output])
    return model

### Predict gRNAs for a single gene

In [None]:
def fetch_ensembl_transcripts(gene_symbol, timeout=30):
    """Look up all transcripts (with exons) of a human gene via the Ensembl REST API.

    Args:
        gene_symbol: gene symbol, e.g. ``'TROAP'``.
        timeout: seconds to wait for the server. Without a timeout,
            ``requests.get`` can block indefinitely on a stalled connection.

    Returns:
        The ``'Transcript'`` list of the ``lookup/symbol`` response, or None
        (with a printed reason) when the request fails or has no transcripts.

    Raises:
        requests.exceptions.RequestException on network errors or timeout.
    """
    url = f"https://rest.ensembl.org/lookup/symbol/homo_sapiens/{gene_symbol}?expand=1;content-type=application/json"
    response = requests.get(url, timeout=timeout)
    if response.status_code != 200:
        print(f"Error fetching gene data from Ensembl: {response.text}")
        return None
    gene_data = response.json()
    if 'Transcript' not in gene_data:
        print("No transcripts found for gene:", gene_symbol)
        return None
    return gene_data['Transcript']

def fetch_ensembl_sequence(transcript_id, timeout=30):
    """Fetch the sequence of an Ensembl feature by stable ID.

    Despite the parameter name, this notebook passes exon IDs (ENSE...).

    Args:
        transcript_id: Ensembl stable ID to query with ``sequence/id``.
        timeout: seconds to wait for the server. Without a timeout,
            ``requests.get`` can block indefinitely on a stalled connection.

    Returns:
        The ``'seq'`` string of the JSON response, or None (with a printed
        reason) when the request fails or contains no sequence.

    Raises:
        requests.exceptions.RequestException on network errors or timeout.
    """
    url = f"https://rest.ensembl.org/sequence/id/{transcript_id}?content-type=application/json"
    response = requests.get(url, timeout=timeout)
    if response.status_code != 200:
        print(f"Error fetching sequence data from Ensembl: {response.text}")
        return None
    sequence_data = response.json()
    if 'seq' not in sequence_data:
        print("No sequence found for transcript:", transcript_id)
        return None
    return sequence_data['seq']


In [None]:
def find_crispr_targets(sequence, chr, start, end, strand, transcript_id, exon_id, pam="NGG", target_length=20):
    """Enumerate protospacer + PAM sites in an exon sequence.

    The first PAM character is a wildcard (the N of NGG). The remaining
    characters are matched literally. The PAM length is now taken from ``pam``
    everywhere; previously only the loop bound used it and the slices assumed
    3 nt.

    Args:
        sequence: exon sequence. For ``strand == -1`` it is treated as already
            reverse-complemented, so index k maps to genomic ``end - k``.
        chr, start, end, strand: genomic location of the exon (1-based, inclusive).
        transcript_id, exon_id: copied into every output row.
        pam: PAM pattern, default SpCas9 ``"NGG"``.
        target_length: protospacer length upstream of the PAM.

    Returns:
        List of ``[target_seq, gRNA, chr, start, end, strand, transcript_id, exon_id]``.
        ``target_seq`` is protospacer + PAM (DNA), ``gRNA`` is the protospacer as
        RNA (T -> U), and ``start``/``end``/``strand`` are strings.
    """
    targets = []
    pam_len = len(pam)
    dnatorna = {'A': 'A', 'T': 'U', 'C': 'C', 'G': 'G'}

    for i in range(len(sequence) - pam_len + 1):
        # i is the PAM's wildcard position; the protospacer must fit upstream.
        if i < target_length or sequence[i + 1:i + pam_len] != pam[1:]:
            continue
        target_seq = sequence[i - target_length:i + pam_len]
        if strand == -1:
            tar_start = end - (i + pam_len - 1)
            tar_end = end - (i - target_length)
        else:
            tar_start = start + i - target_length
            tar_end = start + i + pam_len - 1
        gRNA = ''.join(dnatorna[base] for base in sequence[i - target_length:i])
        targets.append([target_seq, gRNA, chr, str(tar_start), str(tar_end), str(strand), transcript_id, exon_id])

    return targets


In [None]:
# Function to predict on-target efficiency and format output
def format_prediction_output(targets, model_path):
    """Score candidate sites with the on-target model and build output rows.

    Args:
        targets: rows from ``find_crispr_targets``:
            ``[target_seq, gRNA, chr, start, end, strand, transcript_id, exon_id]``.
        model_path: weights file for ``MultiHeadAttention_model(input_shape=(23, 4))``.

    Returns:
        ``[chr, start, end, strand, transcript_id, exon_id, target_seq, gRNA, pred_Score]``
        per target, in input order. Scores above 100 are capped at 100.
    """
    if not targets:
        return []

    model = MultiHeadAttention_model(input_shape=(23, 4))
    model.load_weights(model_path)

    # One batched predict call instead of one call per target. The per-call
    # overhead of model.predict dominated for single 23-mers.
    encoded = np.concatenate([get_seqcode(target[0]) for target in targets])
    raw_scores = model.predict(encoded, verbose=0).reshape(-1)

    formatted_data = []
    for target, raw_score in zip(targets, raw_scores):
        prediction = float(raw_score)
        if prediction > 100:
            prediction = 100
        target_seq, gRNA, chrom, start, end, strand, transcript_id, exon_id = target[:8]
        formatted_data.append([chrom, start, end, strand, transcript_id, exon_id, target_seq, gRNA, prediction])

    return formatted_data

In [None]:
def gRNADesign(gene_symbol, model_path, write_to_csv=False,
               output_dir='/content/drive/MyDrive/Colab Notebooks/Cas9/On target/design_results'):
    """Design and score gRNAs on the reference exons of every transcript of a gene.

    Args:
        gene_symbol: gene symbol looked up in Ensembl.
        model_path: weights file passed to ``format_prediction_output``.
        write_to_csv: if true, also write ``Cas9_{gene_symbol}.csv`` to ``output_dir``.
        output_dir: destination folder, created if missing. The default is the
            folder this notebook always wrote to.

    Returns:
        Rows ``[Chrom, Start, End, Strand, Transcript, Exon, Target sequence, gRNA,
        pred_Score]`` sorted by pred_Score, highest first. The rows are returned in
        both modes; previously None was returned when writing the CSV.
    """
    import os

    transcripts = fetch_ensembl_transcripts(gene_symbol)
    output = []
    # Exons are shared between transcripts; fetch each sequence only once.
    sequence_cache = {}
    for transcript in transcripts or []:
        transcript_id = transcript['id']
        for exon in transcript['Exon']:
            exon_id = exon['id']
            if exon_id not in sequence_cache:
                sequence_cache[exon_id] = fetch_ensembl_sequence(exon_id)
            exon_sequence = sequence_cache[exon_id]
            if not exon_sequence:
                continue
            targets = find_crispr_targets(exon_sequence, exon['seq_region_name'], exon['start'],
                                          exon['end'], exon['strand'], transcript_id, exon_id)
            if targets:
                output.extend(format_prediction_output(targets, model_path))

    header = ['Chrom','Start','End','Strand','Transcript','Exon','Target sequence (5\' to 3\')','gRNA','pred_Score']
    sort_output = sorted(output, key=lambda row: row[8], reverse=True)

    if write_to_csv:
        os.makedirs(output_dir, exist_ok=True)
        pd.DataFrame(data=sort_output, columns=header).to_csv(os.path.join(output_dir, f'Cas9_{gene_symbol}.csv'))
    return sort_output

In [None]:
# design: score reference-genome gRNAs for each gene and write one CSV per gene
model_path = '/content/drive/MyDrive/Colab Notebooks/Cas9/On target/saved_model/Cas9_MultiHeadAttention_weights.keras'
genes = [
    'TROAP', 'SPC24', 'RAD54L',
    'MCM2', 'COPB2', 'CKAP5',
]

for gene in genes:
    gRNADesign(gene, model_path, write_to_csv=True)



### Combine with VCF information

##### Benchmarking with labelled MDA-MB-231 mutations

In [None]:
# Open the sample's filtered SNP/indel calls. The later per-exon region queries
# (vcf_reader("chrom:start-end")) presumably need an index next to the .vcf.gz.
vcf_reader = cyvcf2.VCF(
    '/content/drive/MyDrive/Colab Notebooks/CRISPR_data/SRR25934512.filter.snps.indels.vcf.gz'
)

In [None]:
# Background mutations of the MDA-MB-231 cell line exported from DepMap
# (variable names keep the original "mdamb321" spelling).
mdamb321_mut_bg = pd.read_csv(
    '/content/drive/MyDrive/Colab Notebooks/CRISPR_data/MDAMB231 mutations.csv'
)
mdamb321_mut_bg

Unnamed: 0,Gene,Chromosome,Position,Variant Type,Variant Info,Ref Allele,Alt Allele,Allele Fraction,Ref Count,Alt Count,...,Vep Mane Select,Sift,Vep Ensp,Ensembl Gene Id,Provean Prediction,Nmd,Vep Somatic,Lof Number Of Transcripts In Gene,Vep Impact,Oncogene High Impact
0,CHD5,chr1,6131677,deletion,frameshift_variant,AC,A,0.400,21,13,...,NM_015557.3,,ENSP00000262450,ENSG00000116254,,,,1.0,HIGH,False
1,EIF4G3,chr1,20969478,SNV,missense_variant,G,A,0.678,9,21,...,NM_001391906.1,deleterious_low_confidence(0),ENSP00000473510,ENSG00000075151,Damaging,,,,MODERATE,False
2,ODF2L,chr1,86382979,SNV,missense_variant,T,G,0.656,10,21,...,NM_001366781.1,deleterious(0),ENSP00000433092,ENSG00000122417,Damaging,,,,MODERATE,False
3,GTF2B,chr1,88860139,SNV,splice_donor_variant,C,A,0.421,21,15,...,NM_001514.6,,ENSP00000359531,ENSG00000137947,,,,1.0,HIGH,False
4,LRIG2,chr1,113112563,SNV,missense_variant,A,G,0.714,13,35,...,NM_014813.3,deleterious_low_confidence(0.01),ENSP00000355396,ENSG00000198799,Damaging,,1,,MODERATE,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
271,ADGRG2,chrX,19028184,SNV,missense_variant&splice_region_variant,T,G,0.233,24,6,...,NM_001079858.3,tolerated_low_confidence(0.28),ENSP00000369198,ENSG00000173698,Neutral,,,,MODERATE,False
272,MAGEB3,chrX,30236624,SNV,missense_variant,G,A,0.315,12,5,...,NM_002365.5,tolerated(0.25),ENSP00000355198,ENSG00000198798,Neutral,,1,,MODERATE,False
273,NR0B1,chrX,30308885,SNV,missense_variant,G,T,0.767,6,23,...,NM_000475.5,tolerated_low_confidence(0.05),ENSP00000368253,ENSG00000169297,Neutral,,1&1,,MODERATE,False
274,CFAP47,chrX,36046943,SNV,missense_variant,G,T,0.414,16,12,...,NM_001304548.2,tolerated(0.15),ENSP00000367922,ENSG00000165164,,,,,MODERATE,False


In [None]:
# intersect: keep the DepMap background mutations that are also called in the
# sample VCF. Matching is on (chromosome, position). The previous version
# matched on position alone, so a VCF record at the same coordinate on any
# chromosome counted as a hit. It also scanned a list for every row.
def _normalize_chrom(name):
    """Drop a leading 'chr' so 'chr1' and '1' compare equal; map 'M'/'MT' to 'MT'."""
    name = str(name)
    if name.lower().startswith('chr'):
        name = name[3:]
    return 'MT' if name.upper() in ('M', 'MT') else name

# NOTE(review): iterating presumably consumes the cyvcf2 reader. Re-run the
# cell that opens the VCF before re-running this one.
vcf_keys = {(_normalize_chrom(variant.CHROM), int(variant.POS)) for variant in vcf_reader}

bg_keys = zip(mdamb321_mut_bg['Chromosome'].map(_normalize_chrom).tolist(),
              mdamb321_mut_bg['Position'].tolist())
in_vcf = [key in vcf_keys for key in bg_keys]

mdamb321_mut_remain = mdamb321_mut_bg.loc[in_vcf]
mdamb321_mut_remain

Unnamed: 0,Gene,Chromosome,Position,Variant Type,Variant Info,Ref Allele,Alt Allele,Allele Fraction,Ref Count,Alt Count,...,Vep Mane Select,Sift,Vep Ensp,Ensembl Gene Id,Provean Prediction,Nmd,Vep Somatic,Lof Number Of Transcripts In Gene,Vep Impact,Oncogene High Impact
2,ODF2L,chr1,86382979,SNV,missense_variant,T,G,0.656,10,21,...,NM_001366781.1,deleterious(0),ENSP00000433092,ENSG00000122417,Damaging,,,,MODERATE,False
4,LRIG2,chr1,113112563,SNV,missense_variant,A,G,0.714,13,35,...,NM_014813.3,deleterious_low_confidence(0.01),ENSP00000355396,ENSG00000198799,Damaging,,1,,MODERATE,False
9,RGS5,chr1,163147354,SNV,missense_variant,C,A,0.228,20,5,...,NM_003617.4,tolerated(1),ENSP00000319308,ENSG00000143248,Neutral,,,,MODERATE,False
10,TPR,chr1,186327474,SNV,missense_variant,C,A,0.263,14,4,...,NM_003292.3,deleterious(0.04),ENSP00000356448,ENSG00000047410,Neutral,,,,MODERATE,False
11,ASPM,chr1,197104987,SNV,missense_variant,A,G,0.427,26,19,...,NM_018136.5,deleterious(0),ENSP00000356379,ENSG00000066279,Damaging,,1,,MODERATE,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
265,ZBTB6,chr9,122911153,SNV,missense_variant,G,A,0.447,19,15,...,NM_006626.6,deleterious(0.02),ENSP00000362763,ENSG00000186130,Damaging,,1,,MODERATE,False
266,REXO4,chr9,133417769,SNV,missense_variant,G,C,0.382,25,15,...,NM_020385.4,tolerated(0.3),ENSP00000361010,ENSG00000148300,Neutral,,,,MODERATE,False
267,SEC16A,chr9,136476666,SNV,missense_variant,A,C,0.418,27,17,...,NM_014866.2,tolerated(0.08),ENSP00000508822,ENSG00000148396,Neutral,,,,MODERATE,False
268,MT-ND4,chrM,12084,SNV,missense_variant,C,T,1.000,2,4818,...,,tolerated_low_confidence(0.23),ENSP00000354961,ENSG00000198886,,,,,MODERATE,False


##### Predict cell-type-specific gRNAs

In [None]:
def fetch_ensembl_transcripts(gene_symbol, timeout=30):
    """Look up all transcripts (with exons) of a human gene via the Ensembl REST API.

    NOTE(review): this re-defines the function of the same name in the earlier
    single-gene section; keep the two copies identical.

    Args:
        gene_symbol: gene symbol, e.g. ``'TROAP'``.
        timeout: seconds to wait for the server. Without a timeout,
            ``requests.get`` can block indefinitely on a stalled connection.

    Returns:
        The ``'Transcript'`` list of the ``lookup/symbol`` response, or None
        (with a printed reason) when the request fails or has no transcripts.

    Raises:
        requests.exceptions.RequestException on network errors or timeout.
    """
    url = f"https://rest.ensembl.org/lookup/symbol/homo_sapiens/{gene_symbol}?expand=1;content-type=application/json"
    response = requests.get(url, timeout=timeout)
    if response.status_code != 200:
        print(f"Error fetching gene data from Ensembl: {response.text}")
        return None
    gene_data = response.json()
    if 'Transcript' not in gene_data:
        print("No transcripts found for gene:", gene_symbol)
        return None
    return gene_data['Transcript']

def fetch_ensembl_sequence(transcript_id, timeout=30):
    """Fetch the sequence of an Ensembl feature by stable ID.

    NOTE(review): this re-defines the function of the same name in the earlier
    single-gene section; keep the two copies identical. Despite the parameter
    name, callers pass exon IDs (ENSE...).

    Args:
        transcript_id: Ensembl stable ID to query with ``sequence/id``.
        timeout: seconds to wait for the server. Without a timeout,
            ``requests.get`` can block indefinitely on a stalled connection.

    Returns:
        The ``'seq'`` string of the JSON response, or None (with a printed
        reason) when the request fails or contains no sequence.

    Raises:
        requests.exceptions.RequestException on network errors or timeout.
    """
    url = f"https://rest.ensembl.org/sequence/id/{transcript_id}?content-type=application/json"
    response = requests.get(url, timeout=timeout)
    if response.status_code != 200:
        print(f"Error fetching sequence data from Ensembl: {response.text}")
        return None
    sequence_data = response.json()
    if 'seq' not in sequence_data:
        print("No sequence found for transcript:", transcript_id)
        return None
    return sequence_data['seq']


In [None]:
def apply_mutation(ref_sequence, offset, ref, alt):
    """
    Apply a single VCF-style variant to a sequence.

    The ``len(ref)`` bases starting at ``offset`` (0-based) are replaced by
    ``alt``. A spanning-deletion allele ``"*"`` removes them. This one rule
    covers SNVs, MNPs, insertions and deletions. The previous insertion branch
    kept ``ref_sequence[offset+1:]``, which duplicated bases whenever REF was
    longer than one base (e.g. REF=GT, ALT=GTCC).

    Raises:
        ValueError: for symbolic alleles such as ``<DEL>`` or ``<NON_REF>``,
            which carry no literal sequence. They were previously spliced in
            as text.
    """
    if alt.startswith("<"):
        raise ValueError(f"symbolic allele {alt!r} cannot be applied to a sequence")
    if alt == "*":  # spanning deletion: the reference bases are absent
        alt = ""
    return ref_sequence[:offset] + alt + ref_sequence[offset + len(ref):]

In [None]:
def construct_combinations(sequence, mutations):
    """
    Enumerate every sequence obtained by choosing one alternate allele per mutation.

    mutations is a list of (offset, ref, alts) tuples, applied in list order via
    apply_mutation. The first mutation's alleles vary slowest in the result.
    An empty mutation list yields [sequence]. Any mutation with no alts yields [].
    """
    haplotypes = [sequence]
    for offset, ref, alts in mutations:
        haplotypes = [
            apply_mutation(haplotype, offset, ref, alt)
            for haplotype in haplotypes
            for alt in alts
        ]
    return haplotypes


In [None]:
def needleman_wunsch_alignment(query_seq, ref_seq):
    """
    Estimate where a mutated target sequence starts in the reference exon.

    query_seq is aligned globally against ref_seq (Needleman-Wunsch, parasail).
    The longest exact-match ('=') run anchors the placement. The returned 0-based
    index in ref_seq is the run's reference offset minus its query offset, i.e.
    the reference position of query base 0.

    Changes from the previous version:
    * scores with the nucleotide matrix parasail.dnafull (EDNAFULL) instead of
      the protein matrix BLOSUM62, whose scores are not meaningful for DNA;
    * only CIGAR operations that consume the reference advance the reference
      coordinate, so query-only insertions no longer shift the result;
    * query bases aligned before the anchor run are subtracted, so a mutation
      near the 5' end no longer offsets the reported start.

    Returns 0 when the alignment has no exact-match run.
    """
    alignment = parasail.nw_trace(query_seq, ref_seq, 10, 1, parasail.dnafull)

    cigar_string = alignment.cigar.decode
    if isinstance(cigar_string, bytes):
        cigar_string = cigar_string.decode("utf-8")
    ops = [(int(num_str), op) for num_str, op in re.findall(r'(\d+)([MIDNSHP=X])', cigar_string)]

    ref_ops, query_ops = 'MDN=X', 'MIS=X'  # SAM convention
    # A global alignment spans all of ref_seq. If the reference-consuming ops do
    # not add up to its length, the I/D labels follow the opposite convention.
    if sum(length for length, op in ops if op in ref_ops) != len(ref_seq):
        ref_ops, query_ops = 'MIN=X', 'MDS=X'

    ref_offset = query_offset = 0
    best_run = 0
    ref_pos = 0
    for length, op in ops:
        if op == '=' and length > best_run:
            best_run = length
            ref_pos = ref_offset - query_offset
        if op in ref_ops:
            ref_offset += length
        if op in query_ops:
            query_offset += length

    return max(ref_pos, 0)


In [None]:
def find_gRNA_with_mutation(ref_sequence, exon_chr, start, end, strand, transcript_id,
                            exon_id, gene_symbol, vcf_reader, pam="NGG", target_length=20):
    """Find gRNA sites in an exon after applying the sample's variants.

    All usable variants overlapping the exon are applied, with every combination
    of their alternate alleles. The resulting sequence(s) are scanned instead of
    the reference. A site also present in ``ref_sequence`` keeps reference
    coordinates (is_mut False). Otherwise its start is estimated by alignment
    and suffixed with ``'*'`` (is_mut True).

    Fixes relative to the previous version:
    * symbolic ALT alleles (``<NON_REF>``, ``<*>``, ...) are dropped instead of
      blindly dropping the last allele (``ALT[:-1]``). That made every biallelic
      site produce no sequence, which silently removed all targets of the exon;
    * variants with no usable allele, or not fully inside the exon, are skipped;
    * for ``strand == -1`` (sequence reverse-complemented, as this function's
      coordinate mapping assumes) offsets are mirrored and alleles
      reverse-complemented;
    * variants are applied right-to-left so upstream indels do not shift the
      offsets of variants further along;
    * duplicates are removed in first-seen order (deterministic output);
    * the PAM length is taken from ``pam``.

    Args:
        ref_sequence: reference exon sequence on the exon's strand.
        exon_chr, start, end, strand: exon location (1-based, inclusive).
        transcript_id, exon_id, gene_symbol: copied into each row.
        vcf_reader: callable taking a ``"chrom:start-end"`` region and yielding
            records with ``POS``, ``REF`` and ``ALT`` (e.g. a cyvcf2 ``VCF``).
        pam, target_length: as in ``find_crispr_targets``.

    Returns:
        Unique rows ``[target_seq, gRNA, exon_chr, strand, tar_start, transcript_id,
        exon_id, gene_symbol, is_mut]`` with ``strand`` and ``tar_start`` as strings.
    """
    pam_len = len(pam)
    dnatorna = {'A': 'A', 'T': 'U', 'C': 'C', 'G': 'G'}
    complement_table = str.maketrans('ACGTNacgtn', 'TGCANtgcan')

    def to_exon_strand(allele):
        # '*' carries no bases; plus-strand alleles are already in exon orientation.
        if strand != -1 or allele == '*':
            return allele
        return allele.translate(complement_table)[::-1]

    # Collect usable variants as (offset into ref_sequence, ref, [alts]).
    mutation_list = []
    for mutation in vcf_reader(f"{exon_chr}:{start}-{end}"):
        ref = mutation.REF
        alts = [alt for alt in mutation.ALT if not alt.startswith('<')]
        if not alts:
            continue
        if strand == -1:
            # Genomic [POS, POS+len(ref)-1] maps to indices [end-POS-len(ref)+1, end-POS].
            offset = end - mutation.POS - len(ref) + 1
        else:
            offset = mutation.POS - start
        if offset < 0 or offset + len(ref) > len(ref_sequence):
            continue  # variant not fully inside this exon
        mutation_list.append((offset, to_exon_strand(ref), [to_exon_strand(alt) for alt in alts]))

    # Right-to-left, so earlier offsets remain valid after indels.
    mutation_list.sort(key=lambda item: item[0], reverse=True)
    if mutation_list:
        mutated_sequences = construct_combinations(ref_sequence, mutation_list)
    else:
        mutated_sequences = [ref_sequence]

    targets = []
    for seq in mutated_sequences:
        for i in range(len(seq) - pam_len + 1):
            if i < target_length or seq[i + 1:i + pam_len] != pam[1:]:
                continue
            target_seq = seq[i - target_length:i + pam_len]
            pos = ref_sequence.find(target_seq)
            if pos != -1:
                is_mut = False
                if strand == -1:
                    tar_start = end - pos - target_length - (pam_len - 1)
                else:
                    tar_start = start + pos
            else:
                is_mut = True
                nw_pos = needleman_wunsch_alignment(target_seq, ref_sequence)
                if strand == -1:
                    tar_start = str(end - nw_pos - target_length - (pam_len - 1)) + '*'
                else:
                    tar_start = str(start + nw_pos) + '*'
            gRNA = ''.join(dnatorna[base] for base in seq[i - target_length:i])
            targets.append([target_seq, gRNA, exon_chr, str(strand), str(tar_start),
                            transcript_id, exon_id, gene_symbol, is_mut])

    # Drop duplicates while keeping first-seen order.
    return [list(row) for row in dict.fromkeys(tuple(row) for row in targets)]

In [None]:
def format_prediction_output_with_mutation(targets, model_path):
    """Score mutation-aware sites with the on-target model and build output rows.

    Args:
        targets: rows from ``find_gRNA_with_mutation``: ``[target_seq, gRNA, exon_chr,
            strand, tar_start, transcript_id, exon_id, gene_symbol, is_mut]``.
        model_path: weights file for ``MultiHeadAttention_model(input_shape=(23, 4))``.

    Returns:
        ``[gene_symbol, exon_chr, strand, tar_start, transcript_id, exon_id,
        target_seq, gRNA, pred_Score, is_mut]`` per target, in input order.
        Scores above 100 are capped at 100.
    """
    if not targets:
        return []

    model = MultiHeadAttention_model(input_shape=(23, 4))
    model.load_weights(model_path)

    # One batched predict call instead of one call per target.
    encoded = np.concatenate([get_seqcode(target[0]) for target in targets])
    raw_scores = model.predict(encoded, verbose=0).reshape(-1)

    formatted_data = []
    for target, raw_score in zip(targets, raw_scores):
        prediction = float(raw_score)
        if prediction > 100:
            prediction = 100
        (target_seq, gRNA, exon_chr, strand, tar_start,
         transcript_id, exon_id, gene_symbol, is_mut) = target[:9]
        formatted_data.append([gene_symbol, exon_chr, strand, tar_start, transcript_id,
                               exon_id, target_seq, gRNA, prediction, is_mut])

    return formatted_data

In [None]:
def gRNADesign_mutation(gene_symbol, vcf_reader, model_path, write_to_csv=False,
                        output_dir='/content/drive/MyDrive/Colab Notebooks/Cas9/On target/design_results'):
    """Design and score gRNAs on exons after applying the sample's variants.

    Args:
        gene_symbol: gene symbol looked up in Ensembl.
        vcf_reader: region-queryable variant source passed to ``find_gRNA_with_mutation``.
        model_path: weights file passed to ``format_prediction_output_with_mutation``.
        write_to_csv: if true, also write ``Cas9_{gene_symbol}_mut.csv`` to ``output_dir``.
        output_dir: destination folder, created if missing. The default is the
            folder this notebook always wrote to.

    Returns:
        Rows ``[Gene, Chrom, Strand, Start, Transcript, Exon, Target sequence, gRNA,
        pred_Score, Is_mutation]`` sorted by pred_Score, highest first. The rows are
        returned in both modes; previously None was returned when writing the CSV.
    """
    import os

    transcripts = fetch_ensembl_transcripts(gene_symbol)
    output = []
    # Exons are shared between transcripts; fetch each sequence only once.
    sequence_cache = {}
    for transcript in transcripts or []:
        transcript_id = transcript['id']
        for exon in transcript['Exon']:
            exon_id = exon['id']
            if exon_id not in sequence_cache:
                sequence_cache[exon_id] = fetch_ensembl_sequence(exon_id)  # reference exon sequence
            gene_sequence = sequence_cache[exon_id]
            if not gene_sequence:
                continue
            targets = find_gRNA_with_mutation(gene_sequence, exon['seq_region_name'], exon['start'],
                                              exon['end'], exon['strand'], transcript_id, exon_id,
                                              gene_symbol, vcf_reader)
            if targets:
                # Predict on-target efficiency for each gRNA site
                output.extend(format_prediction_output_with_mutation(targets, model_path))

    header = ['Gene','Chrom','Strand','Start','Transcript','Exon','Target sequence (5\' to 3\')','gRNA','pred_Score','Is_mutation']
    sort_output = sorted(output, key=lambda row: row[8], reverse=True)

    if write_to_csv:
        os.makedirs(output_dir, exist_ok=True)
        pd.DataFrame(data=sort_output, columns=header).to_csv(os.path.join(output_dir, f'Cas9_{gene_symbol}_mut.csv'))
    return sort_output

In [None]:
# design: mutation-aware gRNA scoring for each gene, one CSV per gene
model_path = '/content/drive/MyDrive/Colab Notebooks/Cas9/On target/saved_model/Cas9_MultiHeadAttention_weights.keras'
genes = [
    'TROAP', 'SPC24', 'RAD54L',
    'MCM2', 'COPB2', 'CKAP5',
]

for gene in genes:
    gRNADesign_mutation(gene, vcf_reader, model_path, write_to_csv=True)

##### Find reference gRNAs affected by sample mutations

In [None]:
# example: mutation-aware design for a single gene, shown as a Markdown table.
header = ['Gene','Chrom','Strand','Start','Transcript','Exon','Target sequence (5\' to 3\')','gRNA','pred_Score','Is_mutation']

# NOTE(review): these weights are loaded into MultiHeadAttention_model by the
# design functions; confirm Transformer_withoutEpi_weights matches that
# architecture. This model_path is also reused by the comparison cell below.
model_path = '/content/drive/MyDrive/Colab Notebooks/Cas9/On target/saved_model/Transformer_withoutEpi_weights.keras'
pred_result_mut = gRNADesign_mutation('TROAP', vcf_reader, model_path, write_to_csv=False)

table = tabulate.tabulate(pred_result_mut, header, tablefmt='pipe')
print(table)

| Gene   |   Chrom |   Strand | Start     | Transcript      | Exon            | Target sequence (5' to 3')   | gRNA                 |   pred_Score | Is_mutation   |
|:-------|--------:|---------:|:----------|:----------------|:----------------|:-----------------------------|:---------------------|-------------:|:--------------|
| TROAP  |      12 |        1 | 49323611  | ENST00000549275 | ENSE00000919611 | GACCACCCGGCAAGCCACGAAGG      | GACCACCCGGCAAGCCACGA |  1           | False         |
| TROAP  |      12 |        1 | 49323730  | ENST00000549275 | ENSE00000919611 | ACCAGGAGAACCAAGATCCAAGG      | ACCAGGAGAACCAAGAUCCA |  1           | False         |
| TROAP  |      12 |        1 | 49323626  | ENST00000549275 | ENSE00000919611 | CACGAAGGATCCCCTCCTCCGGG      | CACGAAGGAUCCCCUCCUCC |  1           | False         |
| TROAP  |      12 |        1 | 49323626  | ENST00000551245 | ENSE00000919611 | CACGAAGGATCCCCTCCTCCGGG      | CACGAAGGAUCCCCUCCUCC |  1           | False         |
| TROAP  |

In [None]:
# compare with no mutation: list reference-genome targets whose 23-mer
# (protospacer + PAM) does not appear anywhere in the mutation-aware design
# above. Uses the model_path set in the previous cell.
pred_result = gRNADesign('TROAP', model_path, write_to_csv=False)

# Build the lookup set once. Previously the list of mutant target sequences was
# rebuilt for every reference target (quadratic in the number of targets).
mutant_target_seqs = {item[6] for item in pred_result_mut}
output_exclude = [tar for tar in pred_result if tar[6] not in mutant_target_seqs]

header = ['Chrom','Start','End','Strand','Transcript','Exon','Target sequence (5\' to 3\')','gRNA','pred_Score']
print(tabulate.tabulate(output_exclude, header, tablefmt='pipe'))

|   Chrom |    Start |      End |   Strand | Transcript      | Exon            | Target sequence (5' to 3')   | gRNA                 |   pred_Score |
|--------:|---------:|---------:|---------:|:----------------|:----------------|:-----------------------------|:---------------------|-------------:|
|      12 | 49328997 | 49329019 |        1 | ENST00000551245 | ENSE00003688930 | CATGTCCATCACCCTTTGGACGG      | CAUGUCCAUCACCCUUUGGA |  1           |
|      12 | 49330176 | 49330198 |        1 | ENST00000551245 | ENSE00002377270 | AGGAAGTAGAGGGGCTGGTAGGG      | AGGAAGUAGAGGGGCUGGUA |  1           |
|      12 | 49328997 | 49329019 |        1 | ENST00000549891 | ENSE00003653778 | CATGTCCATCACCCTTTGGACGG      | CAUGUCCAUCACCCUUUGGA |  1           |
|      12 | 49328997 | 49329019 |        1 | ENST00000257909 | ENSE00003688930 | CATGTCCATCACCCTTTGGACGG      | CAUGUCCAUCACCCUUUGGA |  1           |
|      12 | 49330176 | 49330198 |        1 | ENST00000257909 | ENSE00000919622 | AGGAAGTAGAGGGGCTGGT

### Evaluate performance against DepMap screen data

In [None]:
# DepMap sgRNA-level screen results for breast-cancer cell lines
sgRNA_breast_cancer_DepMap = pd.read_excel(
    '/content/drive/MyDrive/Colab Notebooks/CRISPR_data/sgRNA_breast_cancer_DepMap.xlsx'
)
sgRNA_breast_cancer_DepMap

Unnamed: 0,sgRNA_data_ID,Construct Barcode,sgRNA_annotation_breast_sgRNA,genome_alignment,guide_gene_map_gene,n_alignments,guide_efficacy_sgRNA,efficacy,MDAMB231-311Cas9_RepA_p5_batch2,MDAMB231-311Cas9_RepB_p5_batch2,...,MDAMB436-311cas9 Rep B p5_batch2,MDAMB436-311cas9 Rep C p5_batch2,MDAMB453-311Cas9_RepA_p5_batch2,MDAMB453-311Cas9_RepB_p5_batch2,MDAMB453-311Cas9_RepC_p5_batch2,MDAMB468-311cas9_RepB_p6_batch2,SKBR3-311Cas9_RepA_p6_batch3,SKBR3-311Cas9_RepB_p6_batch3,ZR-75-1-311Cas9_RepA_p5_batch2,ZR-75-1-311Cas9_RepB_p5_batch2
0,60396,TATTGGATACAAAGCAAAAG,TATTGGATACAAAGCAAAAG,chr10_135118908_-,,69.0,TATTGGATACAAAGCAAAAG,0.686422,-3.225385,-3.071423,...,-1.594764,-3.158211,-2.187555,-1.904303,-2.228527,-3.255427,-2.096329,-1.728541,-3.003000,-2.834542
1,60396,TATTGGATACAAAGCAAAAG,TATTGGATACAAAGCAAAAG,chr10_134957906_-,,69.0,TATTGGATACAAAGCAAAAG,0.686422,-3.225385,-3.071423,...,-1.594764,-3.158211,-2.187555,-1.904303,-2.228527,-3.255427,-2.096329,-1.728541,-3.003000,-2.834542
2,60396,TATTGGATACAAAGCAAAAG,TATTGGATACAAAGCAAAAG,chr3_195749617_-,,69.0,TATTGGATACAAAGCAAAAG,0.686422,-3.225385,-3.071423,...,-1.594764,-3.158211,-2.187555,-1.904303,-2.228527,-3.255427,-2.096329,-1.728541,-3.003000,-2.834542
3,60396,TATTGGATACAAAGCAAAAG,TATTGGATACAAAGCAAAAG,chr2_153612203_-,,69.0,TATTGGATACAAAGCAAAAG,0.686422,-3.225385,-3.071423,...,-1.594764,-3.158211,-2.187555,-1.904303,-2.228527,-3.255427,-2.096329,-1.728541,-3.003000,-2.834542
4,60396,TATTGGATACAAAGCAAAAG,TATTGGATACAAAGCAAAAG,chr19_42304731_-,,69.0,TATTGGATACAAAGCAAAAG,0.686422,-3.225385,-3.071423,...,-1.594764,-3.158211,-2.187555,-1.904303,-2.228527,-3.255427,-2.096329,-1.728541,-3.003000,-2.834542
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
81428,73821,TTTGTGTATTACCTCAGGAA,,,,,,,-0.069900,0.437790,...,0.572026,0.995102,-0.003004,-0.335238,0.006941,0.613603,-0.134402,0.126074,-0.172093,0.557233
81429,73839,TTTTACCTTGTTCACATGGA,,,,,TTTTACCTTGTTCACATGGA,,1.115169,1.066007,...,0.944632,0.983323,0.481378,0.640064,0.978568,0.845937,0.306001,0.489620,0.972101,1.319258
81430,73840,TTTTGACTCTAATCACCGGT,,,,,TTTTGACTCTAATCACCGGT,,0.757864,0.658191,...,0.351637,0.611574,0.331574,0.615665,0.917588,0.438363,0.237336,0.501745,0.469061,0.826546
81431,73841,TTTTTAATACAAGGTAATCT,,,,,TTTTTAATACAAGGTAATCT,,1.460283,1.095231,...,1.205029,0.976688,0.598391,0.797775,1.323785,0.799345,0.330981,0.095441,1.104013,0.911635


In [None]:
# Keep guides with a measured efficacy and the columns used below, one row per guide sequence.
sgRNA_breast_cancer = (
    sgRNA_breast_cancer_DepMap
    .loc[sgRNA_breast_cancer_DepMap['efficacy'].notnull(),
         ['guide_efficacy_sgRNA',
          'guide_gene_map_gene',
          'efficacy',
          'MDAMB231-311Cas9_RepA_p5_batch2',
          'MDAMB231-311Cas9_RepB_p5_batch2']]
    .drop_duplicates(['guide_efficacy_sgRNA'])
)
sgRNA_breast_cancer

Unnamed: 0,guide_efficacy_sgRNA,guide_gene_map_gene,efficacy,MDAMB231-311Cas9_RepA_p5_batch2,MDAMB231-311Cas9_RepB_p5_batch2
0,TATTGGATACAAAGCAAAAG,,0.686422,-3.225385,-3.071423
69,GCTTTCACAGAATTATTCCA,,0.996387,-3.721890,-2.981994
113,GATCCTCTGAGAGTCCCAGG,,0.988974,-2.572748,-2.918988
151,GGCCATAGAATTCTCTCTGG,ZNF506 (440515),0.797474,-3.100816,-2.803211
187,GTTTCTTTACTCAGCCCCTG,SPDYE3 (441272),0.880902,-2.627522,-3.037610
...,...,...,...,...,...
77492,TTTGTTGGAGAGATGTACGA,LIPH (200879),0.999487,0.017084,-0.582310
77493,TTTGTTGGCACAAATACGGG,RPL10L (140801),0.999599,0.147163,0.192003
77494,TTTGTTGGCCACATCTACGG,C1orf137 (388667),0.999890,-0.130266,-0.052341
77495,TTTGTTTCCTCTTCTCGAGG,CRISPLD1 (83690),0.882474,0.725285,0.723811


In [None]:
# Extend each 20-nt spacer to the 23-mer the model expects (spacer + N + GG).
# The genomic N is not available here, so it is drawn at random from a seeded
# generator, which makes the predictions below reproducible. Slicing to the
# first 20 nt makes the cell idempotent: re-running it no longer appends a
# second PAM to already-extended sequences.
pam_rng = random.Random(0)
sgRNA_breast_cancer['guide_efficacy_sgRNA'] = [item[:20] + pam_rng.choice(['A', 'C', 'G', 'T']) + 'GG' for item in sgRNA_breast_cancer['guide_efficacy_sgRNA']]
sgRNA_breast_cancer

Unnamed: 0,guide_efficacy_sgRNA,guide_gene_map_gene,efficacy,MDAMB231-311Cas9_RepA_p5_batch2,MDAMB231-311Cas9_RepB_p5_batch2
0,TATTGGATACAAAGCAAAAGAGG,,0.686422,-3.225385,-3.071423
69,GCTTTCACAGAATTATTCCAAGG,,0.996387,-3.721890,-2.981994
113,GATCCTCTGAGAGTCCCAGGCGG,,0.988974,-2.572748,-2.918988
151,GGCCATAGAATTCTCTCTGGTGG,ZNF506 (440515),0.797474,-3.100816,-2.803211
187,GTTTCTTTACTCAGCCCCTGAGG,SPDYE3 (441272),0.880902,-2.627522,-3.037610
...,...,...,...,...,...
77492,TTTGTTGGAGAGATGTACGACGG,LIPH (200879),0.999487,0.017084,-0.582310
77493,TTTGTTGGCACAAATACGGGAGG,RPL10L (140801),0.999599,0.147163,0.192003
77494,TTTGTTGGCCACATCTACGGTGG,C1orf137 (388667),0.999890,-0.130266,-0.052341
77495,TTTGTTTCCTCTTCTCGAGGTGG,CRISPLD1 (83690),0.882474,0.725285,0.723811


In [None]:
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score, roc_auc_score, average_precision_score
from scipy.stats import spearmanr

# Rebuild the network and load the trained weights.
model = MultiHeadAttention_model(input_shape=(23, 4))
model.load_weights('/content/drive/MyDrive/Colab Notebooks/Cas9/On target/saved_model/Cas9_MultiHeadAttention_weights.keras')

# Stack the one-hot encodings (each (1, 23, 4)) into one batch and score it.
guide_sequences = sgRNA_breast_cancer['guide_efficacy_sgRNA'].to_list()
x_test = np.concatenate([get_seqcode(guide) for guide in guide_sequences])
pred_score = model.predict(x_test)



In [None]:
# for efficacy
# For each percentile p, a guide is labelled "active" when its measured efficacy lies
# strictly above the p-th percentile, i.e. it is in the top (100 - p)% of guides.
# Model scores are binarised with the same percentile cut-off on the predictions.
# NOTE(review): at high cut-offs most guides are negative (about 90% at p=90), so accuracy
# should be read against that majority-class baseline; ROC AUC is the more informative metric.

y_test = np.array(sgRNA_breast_cancer['efficacy']).reshape(-1,1,1,1)
efficacy_flat = y_test.flatten()
pred_flat = pred_score.flatten()

for pct in (60, 70, 80, 90):
    # Each cut-off is computed once per percentile, not once per element.
    true_type = (efficacy_flat > np.percentile(efficacy_flat, pct)).astype(int)
    pre_type = (pred_flat > np.percentile(pred_flat, pct)).astype(int)

    score = np.round(accuracy_score(true_type, pre_type), 4)
    print(f'Accuracy of active gRNA ranked top {100 - pct}% (above {pct}th percentile) is {score}')
    # ROC AUC uses the continuous model scores; only the true labels change with pct.
    score = np.round(roc_auc_score(true_type, pred_flat), 4)
    print(f'ROC AUC of active gRNA ranked top {100 - pct}% (above {pct}th percentile) is {score}\n')

Accuracy of active gRNA ranked top 60% is 0.5424
ROC AUC of active gRNA ranked top 60% is 0.551

Accuracy of active gRNA ranked top 70% is 0.5922
ROC AUC of active gRNA ranked top 70% is 0.55

Accuracy of active gRNA ranked top 80% is 0.689
ROC AUC of active gRNA ranked top 80% is 0.5557

Accuracy of active gRNA ranked top 90% is 0.8224
ROC AUC of active gRNA ranked top 90% is 0.5508



In [None]:
# for log2FC MDA-MB-231 repA
# Ground truth: a guide is "active" when |log2FC| >= 1 in replicate A.
# NOTE(review): abs() counts both depleted (log2FC <= -1) and enriched (log2FC >= 1)
# guides as active; confirm that enrichment should be treated as activity.
y_test = np.array(abs(sgRNA_breast_cancer['MDAMB231-311Cas9_RepA_p5_batch2'])).reshape(-1,1,1,1)
pred_flat = pred_score.flatten()

# The true labels do not depend on the prediction cut-off, so they are computed once.
true_type = (y_test.flatten() >= 1).astype(int)

for pct in (60, 70, 80, 90):
    # Predicted "active" = model score strictly above the pct-th percentile of all scores.
    pre_type = (pred_flat > np.percentile(pred_flat, pct)).astype(int)
    score = np.round(accuracy_score(true_type, pre_type), 4)
    print(f'Accuracy of active gRNA ranked top {100 - pct}% (above {pct}th percentile) is {score}')

# ROC AUC depends only on the fixed true labels and the continuous scores,
# so it is the same for every cut-off above and is reported once.
score = np.round(roc_auc_score(true_type, pred_flat), 4)
print(f'ROC AUC of active gRNA (|log2FC| >= 1) is {score}\n')

Accuracy of active gRNA ranked top 60% is 0.5758
ROC AUC of active gRNA ranked top 60% is 0.5151

Accuracy of active gRNA ranked top 70% is 0.6429
ROC AUC of active gRNA ranked top 70% is 0.5151

Accuracy of active gRNA ranked top 80% is 0.7091
ROC AUC of active gRNA ranked top 80% is 0.5151

Accuracy of active gRNA ranked top 90% is 0.7744
ROC AUC of active gRNA ranked top 90% is 0.5151



In [None]:
# for log2FC MDA-MB-231 repB
# Ground truth: a guide is "active" when |log2FC| >= 1 in replicate B.
# NOTE(review): abs() counts both depleted (log2FC <= -1) and enriched (log2FC >= 1)
# guides as active; confirm that enrichment should be treated as activity.
y_test = np.array(abs(sgRNA_breast_cancer['MDAMB231-311Cas9_RepB_p5_batch2'])).reshape(-1,1,1,1)
pred_flat = pred_score.flatten()

# The true labels do not depend on the prediction cut-off, so they are computed once.
true_type = (y_test.flatten() >= 1).astype(int)

for pct in (60, 70, 80, 90):
    # Predicted "active" = model score strictly above the pct-th percentile of all scores.
    pre_type = (pred_flat > np.percentile(pred_flat, pct)).astype(int)
    score = np.round(accuracy_score(true_type, pre_type), 4)
    print(f'Accuracy of active gRNA ranked top {100 - pct}% (above {pct}th percentile) is {score}')

# ROC AUC depends only on the fixed true labels and the continuous scores,
# so it is the same for every cut-off above and is reported once.
score = np.round(roc_auc_score(true_type, pred_flat), 4)
print(f'ROC AUC of active gRNA (|log2FC| >= 1) is {score}\n')

Accuracy of active gRNA ranked top 60% is 0.5723
ROC AUC of active gRNA ranked top 60% is 0.5108

Accuracy of active gRNA ranked top 70% is 0.6394
ROC AUC of active gRNA ranked top 70% is 0.5108

Accuracy of active gRNA ranked top 80% is 0.706
ROC AUC of active gRNA ranked top 80% is 0.5108

Accuracy of active gRNA ranked top 90% is 0.77
ROC AUC of active gRNA ranked top 90% is 0.5108

