# NOTE
Try a multimodal approach as follows:
1. Train a simple model to predict sgRNA sequences
2. Next, fine tuning: Using the Gene ID and its respective sequences, try to fine tune the model to predict the sgRNA sequences based on the gene ID

In [1]:
# Importing the required libraries
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import tiktoken

In [2]:
class CausalSelfAttention(nn.Module):
    """
    Implements a causal self-attention mechanism which is a fundamental component of transformer models
    designed for sequence processing tasks where the model should not have future insight. This module 
    ensures that the predictions for a particular position are dependent only on the known outputs at 
    previous positions.

    Attributes:
        c_attn (nn.Linear): Linear layer that projects input embeddings into queries, keys, and values.
        c_proj (nn.Linear): Linear layer that projects the output of the attention mechanism back to
                            the dimension of embeddings.
        bias (torch.Tensor): Buffer that applies a triangular mask to ensure attention is only applied
                             to preceding positions, preserving causality.
    """

    def __init__(self, config):
        """
        Initializes the CausalSelfAttention layer with specific configuration.

        Args:
            config: A configuration object containing attributes like `n_embd` (embedding size),
                    `n_head` (number of attention heads), and `block_size` (sequence length).
        """
        super().__init__()
        # Ensuring the embedding size is divisible by the number of heads for even split.
        assert config.n_embd % config.n_head == 0

        # Linear transformation that outputs triple the embedding dimension to split into
        # queries, keys, and values.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)

        # Linear transformation for the output of the attention computation.
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)

        # Store the number of attention heads and the embedding dimension per head.
        self.n_head = config.n_head
        self.n_embd = config.n_embd

        # Register a buffer for the triangular mask that prevents attending to future positions.
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                         .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        """
        Defines the forward pass of the causal self-attention mechanism.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, sequence_length, embedding_dim).

        Returns:
            torch.Tensor: The output tensor after processing with causal self-attention.
        """
        # Unpack the dimensions of the input tensor.
        B, T, C = x.size()

        # Pass the input through the attention projection layer to get combined query, key, value tensors.
        qkv = self.c_attn(x).split(self.n_embd, dim=2)

        # Split and reshape the combined QKV tensor into individual Q, K, V tensors and transpose
        # for multi-head attention computation.
        q, k, v = [tensor.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) for tensor in qkv]

        # Compute the attention scores, apply scaling for stability, and use the mask to enforce causality.
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))

        # Apply softmax to convert scores to probabilities and compute the weighted sum of values.
        att = F.softmax(att, dim=-1)
        y = (att @ v).transpose(1, 2).contiguous().view(B, T, C)

        # Project the output back to the embedding dimension and return.
        return self.c_proj(y)

In [3]:
class MLP(nn.Module):
    """
    A multilayer perceptron (MLP) module used within transformer blocks as a position-wise
    feed-forward network. This module is a simple neural network for transforming the 
    representation at every position independently in the sequence.

    Attributes:
        c_fc (nn.Linear): The first linear layer that expands the input dimension.
        gelu (nn.GELU): Gaussian Error Linear Unit (GELU) activation function, which
                        allows the model to include non-linearity and helps in learning
                        more complex patterns. This version uses the 'tanh' approximation
                        for faster computation.
        c_proj (nn.Linear): The second linear layer that projects the output back to 
                            the original embedding dimension.
    """

    def __init__(self, config):
        """
        Initializes the MLP module with specified configurations.

        Args:
            config: A configuration object containing `n_embd`, the size of the input
                    and output embeddings.
        """
        super().__init__()
        # First linear layer that increases dimensionality 4x to allow more complex interactions.
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
        
        # GELU activation function with 'tanh' approximation.
        self.gelu = nn.GELU(approximate='tanh')
        
        # Second linear layer that reduces dimensionality back to the original size.
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)

    def forward(self, x):
        """
        Defines the forward pass of the MLP module.

        Args:
            x (torch.Tensor): The input tensor to the MLP with shape (batch_size, sequence_length, n_embd).

        Returns:
            torch.Tensor: The output tensor after processing through two linear layers
                          and a GELU activation function, with the same shape as input.
        """
        # Pass the input through the first linear layer and then apply the GELU activation function.
        x = self.c_fc(x)
        x = self.gelu(x)
        
        # Finally, pass the activated output through the second linear layer to match the original embedding size.
        x = self.c_proj(x)
        return x

In [4]:
class Block(nn.Module):
    """
    Represents a single Transformer block, which is a fundamental component of the Transformer architecture.
    Each block sequentially applies layer normalization, a causal self-attention mechanism, another layer normalization,
    and a multilayer perceptron (MLP). The architecture follows a typical pattern used in JXT models,
    implementing a residual connection around each of the two main sub-layers (self-attention and MLP).

    Attributes:
        ln_1 (nn.LayerNorm): Layer normalization applied before the self-attention mechanism.
        attn (CausalSelfAttention): The causal self-attention module, ensuring that the predictions
                                    for a position are dependent only on the known outputs at previous positions.
        ln_2 (nn.LayerNorm): Layer normalization applied before the MLP.
        mlp (MLP): The multilayer perceptron module that processes the output of the attention mechanism.
    """

    def __init__(self, config):
        """
        Initializes the Transformer block with specified configurations.

        Args:
            config: A configuration object containing necessary parameters like `n_embd`, which is used
                    to set the dimensionality of the layer normalization and to configure the attention and MLP modules.
        """
        super().__init__()
        # Layer normalization that normalizes the embeddings before the self-attention layer.
        self.ln_1 = nn.LayerNorm(config.n_embd)
        
        # The self-attention mechanism defined in the CausalSelfAttention class.
        self.attn = CausalSelfAttention(config)
        
        # Layer normalization that normalizes the output of the attention mechanism before passing it to the MLP.
        self.ln_2 = nn.LayerNorm(config.n_embd)
        
        # The MLP that further processes the output from the attention mechanism.
        self.mlp = MLP(config)

    def forward(self, x):
        """
        Defines the forward pass through the Transformer block.

        Args:
            x (torch.Tensor): Input tensor to the block with shape (batch_size, sequence_length, n_embd).

        Returns:
            torch.Tensor: The output tensor from the block, which has the same shape as the input.
                          This output can be fed into subsequent blocks in a Transformer model.
        """
        # Apply layer normalization, then self-attention, and add the result to the input (residual connection).
        x = x + self.attn(self.ln_1(x))
        
        # Apply another layer normalization, then process through the MLP, and add the result to the output
        # of the previous self-attention layer (residual connection).
        x = x + self.mlp(self.ln_2(x))
        
        return x

In [91]:
@dataclass
class GPTConfig:
    block_size: int = 1024 # max sequence length
    vocab_size: int = 50257 # number of tokens: 50,000 BPE merges + 256 bytes tokens + 1 <|endoftext|> token
    n_layer: int = 12 # number of layers
    n_head: int = 12 # number of heads
    n_embd: int = 768 # embedding dimension
    sos_token_id: int = 1
    eos_token_id: int = 2

class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict({
            'wte': nn.Embedding(config.vocab_size, config.n_embd),
            'wpe': nn.Embedding(config.block_size, config.n_embd),
            'h': nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            'ln_f': nn.LayerNorm(config.n_embd),
        })
        self.ffn_head = nn.Sequential(
            nn.Linear(config.n_embd, config.n_embd),  # Adjust dimensions as needed
            nn.ReLU(),
            nn.Linear(config.n_embd, config.n_embd)
        )
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

    def forward(self, idx, targets=None):
        B, T = idx.size()
        #print(f"Input idx shape: {idx.shape}")

        # Assuming 'sos_token_id' is defined in config
        sos_positions = (idx == self.config.sos_token_id).nonzero(as_tuple=True)
        gene_id_end_idx = sos_positions[1].min()  # Index of the first SOS token
        gene_id_idx = idx[:, :gene_id_end_idx]
        sequence_idx = idx[:, gene_id_end_idx+1:]  # Exclude SOS token from sequence

        # Process gene_id through FFN
        gene_id_emb = self.transformer.wte(gene_id_idx)
        gene_id_context = self.ffn_head(gene_id_emb.mean(dim=1))  # Using mean to collapse gene_id embeddings

        # Process sequence
        pos = torch.arange(0, sequence_idx.size(1), dtype=torch.long, device=idx.device)
        pos_emb = self.transformer.wpe(pos)  # Position embeddings
        tok_emb = self.transformer.wte(sequence_idx)  # Token embeddings
        x = tok_emb + pos_emb.unsqueeze(0)  # Broadcast position embeddings
        x += gene_id_context.unsqueeze(1)  # Add gene ID context to each token position in the sequence

        #print(f"Combined input shape after context addition: {x.shape}")

        # Forward the blocks of the transformer
        for block in self.transformer.h:
            x = block(x)

        # Final layer norm and classifier
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)
        logits = logits.view(-1, self.config.vocab_size)
        #print(f"Logits shape after view: {logits.shape}")

        loss = None
        if targets is not None:
            # Assuming the targets are sliced in the data preparation or dataloader to match the sequence indices
            targets = targets[:, gene_id_end_idx+1:].contiguous().view(-1)  # Adjust target to only include sequence part
            #print(f"Targets shape after view: {targets.shape}")
            if logits.size(0) != targets.size(0):
                raise ValueError(f"Logits and targets size mismatch: {logits.size(0)} vs {targets.size(0)}")
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    @classmethod
    def from_pretrained(cls, model_type):
        """Loads pretrained GPT-2 model weights from huggingface"""
        assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
        from transformers import GPT2LMHeadModel
        print("loading weights from pretrained gpt: %s" % model_type)

        # n_layer, n_head and n_embd are determined from model_type
        config_args = {
            'gpt2':         dict(n_layer=12, n_head=12, n_embd=768),  # 124M params
            'gpt2-medium':  dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
            'gpt2-large':   dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
            'gpt2-xl':      dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
        }[model_type]
        config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints
        config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints
        # create a from-scratch initialized minGPT model
        config = GPTConfig(**config_args)
        model = GPT(config)
        sd = model.state_dict()
        sd_keys = sd.keys()
        sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param

        # init a huggingface/transformers model
        model_hf = GPT2LMHeadModel.from_pretrained(model_type)
        sd_hf = model_hf.state_dict()

        # copy while ensuring all of the parameters are aligned and match in names and shapes
        sd_keys_hf = sd_hf.keys()
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer)
        transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
        # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear
        # this means that we have to transpose these weights when we import them
        assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
        for k in sd_keys_hf:
            if any(k.endswith(w) for w in transposed):
                # special treatment for the Conv1D weights we need to transpose
                assert sd_hf[k].shape[::-1] == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k].t())
            else:
                # vanilla copy over the other parameters
                assert sd_hf[k].shape == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k])

        return model

In [92]:
class CustomTokenizer:
    def __init__(self):
        self.token_to_id = {
            '[PAD]': 0, '[SOS]': 1, '[EOS]': 2,
            'A': 3, 'T': 4, 'G': 5, 'C': 6,
            '[ID]': 7,  # Special token for gene ID start
            '0': 8, '1': 9, '2': 10, '3': 11, '4': 12,
            '5': 13, '6': 14, '7': 15, '8': 16, '9': 17,
            '[PAD]': 18,
        }
        self.id_to_token = {v: k for k, v in self.token_to_id.items()}

    def encode(self, text):
        tokens = []
        i = 0
        while i < len(text):
            if text[i:i+4] == 'ENSG':  # Recognize the ENSG gene ID prefix
                tokens.append('[ID]')
                i += 4  # Skip the 'ENSG'
            elif text[i] == '[':  # Special tokens handling
                special_token_end = text.find(']', i)
                tokens.append(text[i:special_token_end+1])
                i = special_token_end + 1
            else:
                tokens.append(text[i])
                i += 1
        return [self.token_to_id.get(token, self.token_to_id['[PAD]']) for token in tokens]

    @property
    def vocab_size(self):
        return len(self.token_to_id)

    def decode(self, token_ids):
        """ Convert a list of token IDs back to a string. """
        return ''.join(self.id_to_token.get(token_id, '') for token_id in token_ids)
    
    def is_special_token(self, token_ids):
        """Return if a token is a special token"""
        return [token_ids > 3]

In [93]:
class DataLoaderLite:
    def __init__(self, gene_ids, sequences, B, T, tokenizer):
        self.B = B
        self.T = T
        self.tokenizer = tokenizer
        self.current_position = 0
        self.pad_token_id = tokenizer.token_to_id['[PAD]']
        
        # Prepare data
        self.batches = []
        for i in range(0, len(gene_ids), B):
            batch_gene_ids = gene_ids[i:i+B]
            batch_sequences = sequences[i:i+B]
            max_len = 0
            batch_encoded = []
            
            # Encode and find the max length in the current batch
            for gene_id, seq in zip(batch_gene_ids, batch_sequences):
                formatted_seq = f'{gene_id}[SOS]{seq}[EOS]'
                encoded = tokenizer.encode(formatted_seq)
                max_len = max(max_len, len(encoded))
                batch_encoded.append(encoded)
            
            # Pad sequences in the current batch to the max length
            padded_batch = []
            for encoded in batch_encoded:
                padded_length = max_len - len(encoded)
                padded_seq = encoded + [self.pad_token_id] * padded_length
                padded_batch.append(padded_seq)
            
            # Add the padded batch to batches
            self.batches.append(torch.tensor(padded_batch, dtype=torch.long))
    
    def next_batch(self):
        if self.current_position >= len(self.batches):
            self.current_position = 0  # Reset for next epoch
            return None, None  # No more data
        
        batch = self.batches[self.current_position]
        self.current_position += 1
        
        # Prepare targets by rolling tensors
        targets = batch.roll(-1, dims=1)
        return batch, targets

    def view_data(self):
        # Return a readable format of the data (decoding back from token IDs to text)
        readable_data = []
        for batch in self.batches:
            for sequence in batch:
                decoded_sequence = [self.tokenizer.decode([token]) for token in sequence.tolist()]
                readable_data.append(decoded_sequence)
        return readable_data

In [94]:
num_return_sequences = 5
max_length = 70

In [95]:
import pandas as pd

df = pd.read_csv('../data/GenomeCRISPR.csv')
df.head()

Unnamed: 0,start,end,chr,strand,pubmed,cellline,condition,sequence,symbol,ensg,log2fc,rc_initial,rc_final,effect,cas,screentype
0,50844073,50844096,10,+,26472758,Jiyoye,viability,GCAGCATCCCAACCAGGTGGAGG,A1CF,ENSG00000148584,0.315907,{260},{244},2,hSpCas9,negative selection
1,50814011,50814034,10,-,26472758,Jiyoye,viability,GCGGGAGTGAGAGGACTGGGCGG,A1CF,ENSG00000148584,2.144141,{17},{59},9,hSpCas9,negative selection
2,50836111,50836134,10,+,26472758,Jiyoye,viability,ATGACTCTCATACTCCACGAAGG,A1CF,ENSG00000148584,1.426034,{75},{153},8,hSpCas9,negative selection
3,50836095,50836118,10,-,26472758,Jiyoye,viability,GAGTCATCGAGCAGCTGCCATGG,A1CF,ENSG00000148584,1.550133,{47},{105},8,hSpCas9,negative selection
4,50816234,50816257,10,-,26472758,Jiyoye,viability,AGTCACCCTAGCAAAACCAGTGG,A1CF,ENSG00000148584,0.382513,{58},{57},3,hSpCas9,negative selection


In [140]:
tokenizer = CustomTokenizer()
train_loader = DataLoaderLite(df['ensg'],df['sequence'], B=4, T=32, tokenizer=tokenizer)

In [141]:
model = GPT(GPTConfig()).to('cpu')

In [142]:
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-6)
for i in range(100):
    x, y = train_loader.next_batch()
    x, y = x.to('cpu'), y.to('cpu')
    # print("Input shape:", x.shape)  # Should match (B, T) for input
    # print("Target shape:", y.shape)  # Should be (B, T) initially, reshaped later
    optimizer.zero_grad()
    logits, loss = model(x, y)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()        
    print(f"step {i}, loss: {loss.item()}")

step 0, loss: 10.994247436523438
step 1, loss: 10.946769714355469
step 2, loss: 10.846275329589844
step 3, loss: 10.668484687805176
step 4, loss: 10.578746795654297
step 5, loss: 10.516585350036621
step 6, loss: 10.400824546813965
step 7, loss: 10.247157096862793
step 8, loss: 10.16037368774414
step 9, loss: 10.146445274353027
step 10, loss: 9.870000839233398
step 11, loss: 10.06450366973877
step 12, loss: 9.885464668273926
step 13, loss: 9.647636413574219
step 14, loss: 9.538131713867188
step 15, loss: 9.528976440429688
step 16, loss: 9.46561336517334
step 17, loss: 9.39891242980957
step 18, loss: 9.37364387512207
step 19, loss: 9.336405754089355
step 20, loss: 9.190179824829102
step 21, loss: 8.909993171691895
step 22, loss: 9.193188667297363
step 23, loss: 9.118147850036621
step 24, loss: 8.874943733215332
step 25, loss: 8.61152458190918
step 26, loss: 8.61601734161377
step 27, loss: 8.696346282958984
step 28, loss: 8.51191520690918
step 29, loss: 8.504636764526367
step 30, loss: 8.

In [150]:
import torch
from torch.nn.functional import softmax
from Levenshtein import distance as levenshtein_distance

def generate_sequence_and_evaluate(model, tokenizer, gene_id, actual_sequence, max_length=100, top_k=5):
    model.eval()
    with torch.no_grad():
        # Prepare the initial input tokens
        input_tokens = tokenizer.encode(f'{gene_id}[SOS]')
        input_tokens = torch.tensor(input_tokens, dtype=torch.long).unsqueeze(0)
        
        x = input_tokens
        # Process the input through the model's layers
        x = model.transformer.wte(x)  # Word token embeddings
        pos = torch.arange(0, x.size(1), dtype=torch.long, device=x.device)
        x = x + model.transformer.wpe(pos)  # Positional embeddings
        
        for block in model.transformer.h:
            x = block(x)
        
        x = model.transformer.ln_f(x)
        logits = model.lm_head(x)  # Logits from the final output layer

        generated_sequence = []
        top_5_predictions_each_step = []
        for idx in range(x.size(1)):
            logits_slice = logits[:, idx, :]
            probs = softmax(logits_slice, dim=-1)
            top_probs, top_indices = torch.topk(probs, top_k, dim=1)
            top_tokens = [tokenizer.decode([idx]) for idx in top_indices[0].tolist()]  # Decode each token id to its corresponding token

            if tokenizer.token_to_id['[EOS]'] in top_indices[0]:
                break  # Stop if EOS token is among top predictions
            
            generated_sequence.append(top_indices[0][0].item())  # Append the most probable token to the sequence
            top_5_predictions_each_step.append((top_probs[0].tolist(), top_tokens))  # Store probabilities and tokens

        # Decode the generated sequence to a string using the tokenizer
        predicted_sequence = tokenizer.decode(generated_sequence)

    # Compute Levenshtein distance and accuracy
    levenshtein_dist = levenshtein_distance(predicted_sequence, actual_sequence)
    accuracy = 1 - (levenshtein_dist / max(len(predicted_sequence), len(actual_sequence)))

    return {
        'gene_id': gene_id,
        'predicted_sequence': predicted_sequence,
        'actual_sequence': actual_sequence,
        'levenshtein_distance': levenshtein_dist,
        'accuracy': accuracy,
        'top_5_predictions_each_step': top_5_predictions_each_step  # Include top 5 predictions for each step in the output
    }

# Example usage, provide the actual_sequence that corresponds to the gene_id
results = generate_sequence_and_evaluate(model, tokenizer, "ENSG000101148584", "GGATACGATACATGGA", top_k=5)
print("Generated Sequence:", results['predicted_sequence'])
print("Actual Sequence:", results['actual_sequence'])
print("Levenshtein Distance:", results['levenshtein_distance'])
print("Accuracy:", results['accuracy'])
for step, (probs, tokens) in enumerate(results['top_5_predictions_each_step']):
    print(f"Step {step}:")
    for prob, token in zip(probs, tokens):
        print(f"  {token} (Prob: {prob:.4f})")

Generated Sequence: GGGGGGGGGAGGGG
Actual Sequence: GGATACGATACATGGA
Levenshtein Distance: 10
Accuracy: 0.375
Step 0:
  G (Prob: 0.0008)
  C (Prob: 0.0004)
   (Prob: 0.0002)
   (Prob: 0.0002)
   (Prob: 0.0002)
Step 1:
  G (Prob: 0.0021)
  A (Prob: 0.0013)
  C (Prob: 0.0003)
   (Prob: 0.0002)
   (Prob: 0.0002)
Step 2:
  G (Prob: 0.0020)
  A (Prob: 0.0008)
  C (Prob: 0.0006)
   (Prob: 0.0003)
   (Prob: 0.0002)
Step 3:
  G (Prob: 0.0005)
  A (Prob: 0.0005)
  C (Prob: 0.0005)
   (Prob: 0.0002)
   (Prob: 0.0002)
Step 4:
  G (Prob: 0.0019)
  A (Prob: 0.0004)
  C (Prob: 0.0003)
   (Prob: 0.0002)
   (Prob: 0.0002)
Step 5:
  G (Prob: 0.0012)
  A (Prob: 0.0008)
  C (Prob: 0.0005)
  T (Prob: 0.0002)
   (Prob: 0.0002)
Step 6:
  G (Prob: 0.0010)
  A (Prob: 0.0007)
  T (Prob: 0.0003)
  C (Prob: 0.0003)
   (Prob: 0.0002)
Step 7:
  G (Prob: 0.0015)
  C (Prob: 0.0008)
  A (Prob: 0.0004)
  T (Prob: 0.0003)
   (Prob: 0.0002)
Step 8:
  G (Prob: 0.0047)
  A (Prob: 0.0007)
  C (Prob: 0.0003)
  T (Prob: 0.00