# CRISPRGenie

This is the main notebook for the project *CRISPRGenie*. The current aim of the project is as follows:
1. The input dataset contains **sgRNA** sequences for a total of 19k different genes from the human genome. 
2. The model will be trained on all the sequences to predict different possible sgRNA sequences based on the input target gene.
3. The model will be autoregressive, i.e., it will predict new sequences based on its training data. Furthermore, the model will ideally contain a semi-supervised regression task which will allow the model to predict the metrics of the newly predicted sgRNA sequence/s for the given target.
4. The metrics include the log2-fold changes and the effect (ranging from -9 to 9)

## Data Preparation

The model training will take place in two steps:

#### Part 1: Generating sgRNA Sequences

**Model Design:**

* Input: Gene symbol (e.g., ENSG00000148584)
* Output: Set of sgRNA sequences

**Approach:**

* Data Preparation: For training, map each gene symbol to its corresponding sgRNA sequences. This could involve aggregating all sgRNA sequences that target a specific gene into a single training example.
* Model Type: Use a generative model like GPT, which is adept at producing sequences. Train the model to generate sgRNA sequences when provided with a gene symbol.

**Training:**

* Input: Gene symbol.
* Output: A sequence of sgRNAs or a concatenated string of multiple sgRNAs.
Train the model to maximize the likelihood of generating correct sgRNA sequences given a gene symbol.

#### Part 2: Predicting Effect Metrics

**Model Design:**

* Input: sgRNA sequence
* Output: Effect metrics (quantized effect as an integer from -9 to 9)

**Approach:**

* Data Preparation: Use sgRNA sequences and their corresponding effect metrics from your dataset.
* Model Type: A classification model (like BERT used for classification tasks) that can predict a class (effect metric) for each sgRNA sequence.

**Training:**

* Input: sgRNA sequence.
* Output: Effect class.
This model can be trained using a cross-entropy loss where each class corresponds to a different quantile of sgRNA efficacy.

In [1]:
# Importing the necessary libraries
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
import numpy as np
from transformers import GPT2Config, GPT2LMHeadModel, AdamW

In [2]:
df = pd.read_csv('GenomeCRISPR_full.csv')
df.head()

Unnamed: 0,start,end,chr,strand,pubmed,cellline,condition,sequence,symbol,ensg,log2fc,rc_initial,rc_final,effect,cas,screentype
0,50844073,50844096,10,+,26472758.0,Jiyoye,viability,GCAGCATCCCAACCAGGTGGAGG,A1CF,ENSG00000148584,0.315907,[260],[244],2.0,hSpCas9,negative selection
1,50814011,50814034,10,-,26472758.0,Jiyoye,viability,GCGGGAGTGAGAGGACTGGGCGG,A1CF,ENSG00000148584,2.144141,[17],[59],9.0,hSpCas9,negative selection
2,50836111,50836134,10,+,26472758.0,Jiyoye,viability,ATGACTCTCATACTCCACGAAGG,A1CF,ENSG00000148584,1.426034,[75],[153],8.0,hSpCas9,negative selection
3,50836095,50836118,10,-,26472758.0,Jiyoye,viability,GAGTCATCGAGCAGCTGCCATGG,A1CF,ENSG00000148584,1.550133,[47],[105],8.0,hSpCas9,negative selection
4,50816234,50816257,10,-,26472758.0,Jiyoye,viability,AGTCACCCTAGCAAAACCAGTGG,A1CF,ENSG00000148584,0.382513,[58],[57],3.0,hSpCas9,negative selection


In [3]:
# Extract only the necessary columns
data_relevant = df[['ensg', 'sequence']]

# Drop any rows with missing values in these columns to ensure data integrity
data_relevant = data_relevant.dropna()

# take the first 1k instances as a test
data_relevant = data_relevant[:1000]

data_relevant.head()


Unnamed: 0,ensg,sequence
0,ENSG00000148584,GCAGCATCCCAACCAGGTGGAGG
1,ENSG00000148584,GCGGGAGTGAGAGGACTGGGCGG
2,ENSG00000148584,ATGACTCTCATACTCCACGAAGG
3,ENSG00000148584,GAGTCATCGAGCAGCTGCCATGG
4,ENSG00000148584,AGTCACCCTAGCAAAACCAGTGG


### Tokenization 

Since the current dataset has a much smaller vocabulary, I will be going with a custom tokenizer which will be lightweight compared to the pretrained tokenizer of GPT2

In [5]:
class CustomTokenizer:
    def __init__(self):
        # Include all nucleotide bases, special tokens, and necessary characters for gene IDs
        self.token_to_id = {
            '[PAD]': 0, '[ID]': 1, '[SOS]': 2, '[EOS]': 3,
            'A': 4, 'T': 5, 'G': 6, 'C': 7, 
            'E': 8, 'N': 9, 'S': 10, 
            '0': 11, '1': 12, '2': 13, '3': 14, '4': 15,
            '5': 16, '6': 17, '7': 18, '8': 19, '9': 20
        }
        self.id_to_token = {v: k for k, v in self.token_to_id.items()}

    @property
    def vocab_size(self):
        return len(self.token_to_id)

    def encode(self, text):
        """ Convert text to a list of token IDs, treating each character as a token unless enclosed in []. """
        tokens = []
        i = 0
        while i < len(text):
            if text[i] == '[':  # Start of a special token
                special_token_end = text.find(']', i)
                if special_token_end != -1:
                    tokens.append(text[i:special_token_end+1])
                    i = special_token_end + 1
                else:
                    tokens.append(text[i])  # Fallback if ']' is missing
                    i += 1
            else:
                tokens.append(text[i])
                i += 1
        return [self.token_to_id[token] for token in tokens if token in self.token_to_id]

    def decode(self, token_ids):
        """ Convert a list of token IDs back to a string. """
        return ''.join(self.id_to_token.get(token_id, '') for token_id in token_ids)

In [6]:
tokenizer = CustomTokenizer()
# Example encoding
encoded = tokenizer.encode("[ID]ENSG00000148584[SOS]GCAGCATCCCAACCAGGTGGAGG[EOS]")
decoded = tokenizer.decode(encoded)

print("Encoded:", encoded)
print("Decoded:", decoded)

Encoded: [1, 8, 9, 10, 6, 11, 11, 11, 11, 11, 12, 15, 19, 16, 19, 15, 2, 6, 7, 4, 6, 7, 4, 5, 7, 7, 7, 4, 4, 7, 7, 4, 6, 6, 5, 6, 6, 4, 6, 6, 3]
Decoded: [ID]ENSG00000148584[SOS]GCAGCATCCCAACCAGGTGGAGG[EOS]


### Dataloader

In [7]:
class GeneSequenceDataset(Dataset):
    def __init__(self, gene_ids, sequences, tokenizer):
        self.gene_ids = gene_ids
        self.sequences = sequences
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.gene_ids)

    def __getitem__(self, idx):
        # Form the full sequence with special tokens
        full_sequence = f"[ID]{self.gene_ids[idx]}[SOS]{self.sequences[idx]}[EOS]"
        tokenized_sequence = self.tokenizer.encode(full_sequence)
        return torch.tensor(tokenized_sequence, dtype=torch.long)

def collate_fn(batch):
    batch_padded = torch.nn.utils.rnn.pad_sequence(batch, batch_first=True, padding_value=0)
    return batch_padded

In [8]:
# Example usage:
gene_ids = ['ENSG00000148584', 'ENSG00000155657']
sequences = ['GCAGCATCCCAACCAGGTGGAGG', 'TTGCCGTCAGCTTGGGAGG']
tokenizer = CustomTokenizer()  # Make sure your tokenizer is properly defined

dataset = GeneSequenceDataset(gene_ids, sequences, tokenizer)
dataloader = DataLoader(dataset, batch_size=20, collate_fn=collate_fn)

# Quick test to see a batch from DataLoader
for batch in dataloader:
    print(batch)
    break

tensor([[ 1,  8,  9, 10,  6, 11, 11, 11, 11, 11, 12, 15, 19, 16, 19, 15,  2,  6,
          7,  4,  6,  7,  4,  5,  7,  7,  7,  4,  4,  7,  7,  4,  6,  6,  5,  6,
          6,  4,  6,  6,  3],
        [ 1,  8,  9, 10,  6, 11, 11, 11, 11, 11, 12, 16, 16, 17, 16, 18,  2,  5,
          5,  6,  7,  7,  6,  5,  7,  4,  6,  7,  5,  5,  6,  6,  6,  4,  6,  6,
          3,  0,  0,  0,  0]])


## Model Training

In [32]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using device: {device}')

Using device: cpu


In [18]:
# Loading the pretrained model
vocab_size = len(tokenizer.token_to_id)
config = GPT2Config.from_pretrained('gpt2', vocab_size=vocab_size)
model = GPT2LMHeadModel(config)
model.resize_token_embeddings(vocab_size)
# Moving the model to the current device
model.to(device)  



GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(21, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=21, bias=False)
)

In [19]:
# Setup the dataloader for the entire dataset
gene_ids = data_relevant['ensg']
sequences = data_relevant['sequence']
dataset = GeneSequenceDataset(gene_ids, sequences, tokenizer)
# Using the custom collate function
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)

In [20]:
optimizer = AdamW(model.parameters(), lr=5e-5)



In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F

def custom_loss(model_output, labels, tokenizer):
    
    outputs = model_output.logits

    # Find indices of [SOS] and [EOS] tokens
    sos_id = tokenizer.token_to_id['[SOS]']
    eos_id = tokenizer.token_to_id['[EOS]']

    loss = 0
    batch_size = outputs.size(0)
    for i in range(batch_size):
        # Extract the sequence between [SOS] and [EOS]
        start = (labels[i] == sos_id).nonzero(as_tuple=True)[0]
        end = (labels[i] == eos_id).nonzero(as_tuple=True)[0]
        if start.nelement() == 0 or end.nelement() == 0:
            continue  # Skip if [SOS] or [EOS] not found
        if start.item() >= end.item():
            continue  # Ensure valid range
        # Calculate loss only within the [SOS] and [EOS] range
        relevant_outputs = outputs[i, start:end, :]
        relevant_labels = labels[i, start:end]
        loss += F.cross_entropy(relevant_outputs, relevant_labels, reduction='sum')
    return loss / batch_size

In [22]:
def train(model, dataloader, tokenizer, optimizer, epochs):
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for batch in dataloader:
            inputs, labels = batch[:, :-1], batch[:, 1:]  # Shifted for predicting the next token
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = custom_loss(outputs, labels, tokenizer)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f'Epoch {epoch + 1}, Loss: {total_loss / len(dataloader)}')

In [23]:
train(model, dataloader, tokenizer, optimizer, epochs=2)

Epoch 1, Loss: 28.78814559173584
Epoch 2, Loss: 26.91771062850952


In [35]:
import torch.nn.functional as F

def generate_sgRNA_sequence(model, tokenizer, gene_id, max_length=20, num_sequences=5, device='cpu'):
    model.eval()  # Set model to evaluation mode
    with torch.no_grad():  # No need to track gradients
        # Prepare the input with the gene ID and start token
        input_tokens = f"[ID]{gene_id}[SOS]"
        input_ids = tokenizer.encode(input_tokens)
        
        input_tensor = torch.tensor([input_ids], dtype=torch.long).to(device)

        # Beam search parameters
        beam_size = num_sequences
        sequences = [[input_ids[:], 0]]  # List of [sequence, score]

        for _ in range(max_length):  # Limit maximum generation length
            all_candidates = []
            for seq, score in sequences:
                input_tensor = torch.tensor([seq], dtype=torch.long).to(device)
                output = model(input_tensor)
                predictions = output.logits[0, -1, :]  # Get logits for the last token
                probs = F.softmax(predictions, dim=-1)
                top_k_probs, top_k_ids = probs.topk(beam_size)  # Get top k probabilities and token IDs
                
                for i in range(beam_size):
                    candidate = [seq + [top_k_ids[i].item()], score - torch.log(top_k_probs[i]).item()]
                    all_candidates.append(candidate)

            # Order all candidates by score
            ordered = sorted(all_candidates, key=lambda x: x[1])
            sequences = ordered[:beam_size]

            # Stop if all sequences end with [EOS]
            if all(tokenizer.token_to_id['[EOS]'] in seq for seq, score in sequences):
                break

        # Decode the top sequences
        top_sequences = []
        for seq, score in sequences:
            if tokenizer.token_to_id['[EOS]'] in seq:
                end_index = seq.index(tokenizer.token_to_id['[EOS]'])
            else:
                end_index = len(seq)
            trimmed_seq = seq[len(input_ids):end_index]
            generated_sequence = tokenizer.decode(trimmed_seq)
            top_sequences.append(generated_sequence)

        return top_sequences

In [36]:
# Example usage
gene_id = "ENSG00000148584"
generated_sgRNA = generate_sgRNA_sequence(model, tokenizer, gene_id)
print("Generated sgRNA Sequence:", generated_sgRNA)

Generated sgRNA Sequence: ['GGAGGAGAGAGAGAAAGCAG', 'GGAGGAGAGAGAGAAAGAGG', 'GAAGGAGAGAGAGAAAGCAG', 'GGAGGAGAGGAAGAAAGCAG', 'GGAGGAGAGAGAGAAAGGGG']
