In [1]:
import json
import math
import os
import pickle
from functools import partial

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import genome_kit as gk
from genome_kit import Genome, Interval, VariantGenome
from huggingface_hub import PyTorchModelHubMixin
from mamba_ssm.modules.mamba_simple import Mamba, Block

In [2]:
df = pd.read_csv("../data/tcga_dataset.csv")
# HGVSc values look like "c.994G>T": extract the ref base, alt base and CDS position.
# `[cg]` accepts c./g. notation; the former `[c|g]` also matched a literal "|".
df['ref'] = df['HGVSc'].str.extract(r'[cg]\.\d+([A-Z])>', expand=False)
df['alt'] = df['HGVSc'].str.extract(r'>?([A-Z])$', expand=False)
# Raw string: '\d' in a plain literal is an invalid escape sequence (SyntaxWarning on 3.12+).
df['cds_pos'] = df['HGVSc'].str.extract(r'(\d+)', expand=False)
df['var_id'] = df['Transcript_ID'] + ":" + df['HGVSc']
df

Unnamed: 0,Cancer_type,Cancer_type_count,NMF_cluster,build,chromosome,start,end,Hugo_Symbol,Transcript_ID,HGVSc,...,AF Group,LOEUF,LOEUF_bin,5UTR_length,3UTR_length,Transcript_length,ref,alt,cds_pos,var_id
0,ACC,12,1,GRCh38,chr12,98546362,98546362,TMPO,ENST00000556029,c.994G>T,...,[0],0.737,3.0,356,370,1722,G,T,994,ENST00000556029:c.994G>T
1,ACC,12,1,GRCh38,chr12,112911070,112911070,OAS1,ENST00000202917,c.489T>G,...,[0],0.818,3.0,263,715,1467,T,G,489,ENST00000202917:c.489T>G
2,ACC,12,1,GRCh38,chr14,21076448,21076448,ARHGEF40,ENST00000298694,c.1828C>T,...,[0],0.722,2.0,127,3958,5915,C,T,1828,ENST00000298694:c.1828C>T
3,ACC,12,1,GRCh38,chr17,81860411,81860411,P4HB,ENST00000331483,c.61G>T,...,[0],0.561,1.0,223,1465,1751,G,T,61,ENST00000331483:c.61G>T
4,ACC,12,2,GRCh38,chr1,19219308,19219308,EMC1,ENST00000477853,c.2977C>T,...,"[0.000001, 0.00001)",0.914,4.0,43,4,3026,C,T,2977,ENST00000477853:c.2977C>T
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4252,UCS,40,2,GRCh38,chr9,34088499,34088499,DCAF12,ENST00000361264,c.1213G>T,...,[0],0.699,2.0,342,148,1705,G,T,1213,ENST00000361264:c.1213G>T
4253,UCS,40,2,GRCh38,chr9,79576163,79576163,TLE4,ENST00000376552,c.238G>T,...,[0],0.166,0.0,1018,2083,3341,G,T,238,ENST00000376552:c.238G>T
4254,UVM,3,1,GRCh38,chr14,70342681,70342681,COX16,ENST00000389912,c.118C>T,...,"[0.00001, 0.0001)",1.374,7.0,144,202,466,C,T,118,ENST00000389912:c.118C>T
4255,UVM,3,1,GRCh38,chr17,81869230,81869230,ARHGDIA,ENST00000269321,c.358C>T,...,"[0.000001, 0.00001)",0.607,2.0,136,256,752,C,T,358,ENST00000269321:c.358C>T


In [3]:
# Reference annotation (GENCODE v29) used for every transcript lookup and as the base
# for the per-variant VariantGenome objects; the dataset's `build` column is GRCh38.
ref_genome = Genome("gencode.v29")

def find_transcript(genome, transcript_id):
    """Find a transcript in a genome by Ensembl transcript ID.

    Transcript IDs stored in the genome carry a version suffix (e.g.
    ``ENST00000263946.7``). The comparison ignores the version on both sides,
    so ``transcript_id`` may be given with or without one.

    Args:
        genome: Object exposing an iterable ``transcripts`` attribute whose items
            have an ``id`` string (e.g. a genome_kit ``Genome`` or ``VariantGenome``).
        transcript_id (str): The ID of the transcript to find.

    Returns:
        The first transcript whose unversioned ID matches.

    Raises:
        ValueError: If no transcript with the given ID is found.

    Example:
        >>> genome = Genome("gencode.v29")
        >>> find_transcript(genome, 'ENST00000263946')
        <Transcript ENST00000263946.7 of PKP1>
        >>> find_transcript(genome, 'ENST00000000000')
        Traceback (most recent call last):
        ...
        ValueError: Transcript with ID ENST00000000000 not found.
    """
    query = str(transcript_id).split('.')[0]
    # Linear scan over all transcripts, stopping at the first match.
    match = next(
        (x for x in genome.transcripts if x.id.split('.')[0] == query),
        None,
    )
    if match is None:
        # Fail loudly: callers use the result as a transcript immediately.
        raise ValueError(f"Transcript with ID {transcript_id} not found.")
    return match

def create_cds_track(t):
    """Build the codon-start track of a transcript.

    The track covers 5'UTR, CDS and 3'UTR, in that order. It is 1 at every
    third CDS position, starting with the first CDS base (the first base of
    each codon). It is 0 everywhere else, including both UTRs.

    Args:
        t (gk.Transcript): The transcript object. Only the lengths of its
            ``utr5s``, ``cdss`` and ``utr3s`` intervals are used.

    Returns:
        np.ndarray: 1D integer array of length len(5'UTR) + len(CDS) + len(3'UTR).
    """
    def total_length(intervals):
        return sum(len(iv) for iv in intervals)

    n_utr5 = total_length(t.utr5s)
    n_cds = total_length(t.cdss)
    n_utr3 = total_length(t.utr3s)

    track = np.zeros(n_utr5 + n_cds + n_utr3, dtype=int)
    # Codon starts: every third base of the CDS segment.
    track[n_utr5:n_utr5 + n_cds:3] = 1
    return track

def create_splice_track(t):
    """Build the exon-boundary track of a transcript.

    The track has length len(5'UTR) + len(CDS) + len(3'UTR). It is 1 at the
    last base of every exon (positions given by cumulative exon lengths) and 0
    elsewhere. This includes the end of the final exon. Exon ends that fall
    beyond the end of the track are ignored.

    Args:
        t (gk.Transcript): The transcript object.

    Returns:
        np.ndarray: 1D integer array.
    """
    mrna_length = sum(
        sum(len(iv) for iv in group) for group in (t.utr3s, t.utr5s, t.cdss)
    )
    track = np.zeros(mrna_length, dtype=int)

    exon_ends = np.cumsum([len(exon) for exon in t.exons], dtype=int)
    # Keep only ends whose last base (end - 1) lies inside the track.
    inside = exon_ends[(exon_ends >= 1) & (exon_ends <= mrna_length)]
    track[inside - 1] = 1
    return track

# convert to one hot
# Byte value -> one-hot column (A=0, C=1, G=2, T=3); -1 for every other byte.
_OH_LOOKUP = np.full(256, -1, dtype=np.int64)
_OH_LOOKUP[np.frombuffer(b"ACGT", dtype=np.uint8)] = np.arange(4)


def seq_to_oh(seq):
    """One-hot encode a nucleotide string.

    Args:
        seq (str): Nucleotide sequence.

    Returns:
        np.ndarray: Integer array of shape (len(seq), 4) with columns A, C, G, T.
        Rows for any character other than upper-case A/C/G/T (e.g. ``N`` or
        lower-case bases) are all zeros.
    """
    oh = np.zeros((len(seq), 4), dtype=int)
    if not seq:
        return oh
    # Vectorised lookup instead of a per-base Python loop. Non-ASCII characters
    # become a single b"?" each, so the row count still equals len(seq).
    codes = np.frombuffer(seq.encode("ascii", errors="replace"), dtype=np.uint8)
    cols = _OH_LOOKUP[codes]
    rows = np.flatnonzero(cols >= 0)
    oh[rows, cols[rows]] = 1
    return oh

def create_one_hot_encoding(t, genome):
    """One-hot encode the spliced (exon-only) sequence of a transcript.

    Args:
        t (gk.Transcript): The transcript object.
        genome: Genome (or variant genome) whose ``dna`` method returns the
            sequence of each exon interval.

    Returns:
        np.ndarray: Integer array with one row per base of the concatenated exon
        sequences and columns A, C, G, T (see ``seq_to_oh``).
    """
    exon_sequences = (genome.dna(exon) for exon in t.exons)
    return seq_to_oh("".join(exon_sequences))

def create_six_track_encoding(t, genome, channels_last=False):
    """Build the six-track encoding of a transcript.

    Tracks, in order:
      * one-hot A, C, G, T of the exon sequence (``create_one_hot_encoding``);
      * codon-start track (``create_cds_track``);
      * exon-boundary track (``create_splice_track``).

    Args:
        t (gk.Transcript): The transcript object.
        genome: Genome (or variant genome) used to fetch the exon sequences.
        channels_last (bool): If True, return shape (L, 6). Otherwise (the
            default), return the transpose with shape (6, L).

    Returns:
        np.ndarray: Integer array of shape (6, L), or (L, 6) if ``channels_last``.

    Raises:
        ValueError: If the exon sequence length differs from the summed
            UTR/CDS lengths used by the annotation tracks.
    """
    oh = create_one_hot_encoding(t, genome)
    cds_track = create_cds_track(t)
    splice_track = create_splice_track(t)
    if not (len(oh) == len(cds_track) == len(splice_track)):
        raise ValueError(
            f"Track length mismatch for {getattr(t, 'id', t)}: "
            f"sequence={len(oh)}, cds={len(cds_track)}, splice={len(splice_track)}"
        )
    six_track = np.concatenate([oh, cds_track[:, None], splice_track[:, None]], axis=1)
    return six_track if channels_last else six_track.T

In [4]:
def create_block(
    d_model,
    ssm_cfg=None,
    norm_epsilon=1e-5,
    residual_in_fp32=False,
    fused_add_norm=False,
    layer_idx=None,
    device=None,
    dtype=None,
):
    """Build one residual ``Block`` with a ``Mamba`` mixer and a LayerNorm.

    Args:
        d_model: Hidden size of the block.
        ssm_cfg: Extra keyword arguments forwarded to ``Mamba`` (None means none).
        norm_epsilon: ``eps`` of the block's ``nn.LayerNorm``.
        residual_in_fp32: Forwarded to ``Block``.
        fused_add_norm: Forwarded to ``Block``.
        layer_idx: Layer index passed to ``Mamba`` and stored on the returned block.
        device: Device for the mixer and norm parameters.
        dtype: Dtype for the mixer and norm parameters.

    Returns:
        The constructed ``Block`` with its ``layer_idx`` attribute set.
    """
    tensor_kwargs = dict(device=device, dtype=dtype)
    mixer_kwargs = {} if ssm_cfg is None else ssm_cfg
    block = Block(
        d_model,
        partial(Mamba, layer_idx=layer_idx, **mixer_kwargs, **tensor_kwargs),
        norm_cls=partial(nn.LayerNorm, eps=norm_epsilon, **tensor_kwargs),
        fused_add_norm=fused_add_norm,
        residual_in_fp32=residual_in_fp32,
    )
    block.layer_idx = layer_idx
    return block


class MixerModel(
    nn.Module,
    PyTorchModelHubMixin,
):
    """Stack of Mamba blocks over a linearly embedded multi-track input.

    Each position's ``input_dim`` channels are projected to ``d_model`` by a
    linear layer. The result passes through ``n_layer`` residual blocks and a
    final LayerNorm.
    """

    def __init__(
        self,
        d_model: int,
        n_layer: int,
        input_dim: int,
        ssm_cfg=None,
        norm_epsilon: float = 1e-5,
        rms_norm: bool = False,
        initializer_cfg=None,
        fused_add_norm=False,
        residual_in_fp32=False,
        device=None,
        dtype=None,
    ) -> None:
        """Create the model.

        Args:
            d_model: Hidden size.
            n_layer: Number of blocks.
            input_dim: Number of input channels (tracks) per position.
            ssm_cfg: Extra keyword arguments for every ``Mamba`` mixer.
            norm_epsilon: ``eps`` for all LayerNorms.
            rms_norm: Accepted for config compatibility but currently unused;
                all norms are ``nn.LayerNorm``.
            initializer_cfg: Extra keyword arguments for ``_init_weights``.
            fused_add_norm: Forwarded to each block.
            residual_in_fp32: Forwarded to each block.
            device: Device for the parameters.
            dtype: Dtype for the parameters.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32

        self.embedding = nn.Linear(input_dim, d_model, **factory_kwargs)

        self.layers = nn.ModuleList(
            [
                create_block(
                    d_model,
                    ssm_cfg=ssm_cfg,
                    norm_epsilon=norm_epsilon,
                    residual_in_fp32=residual_in_fp32,
                    fused_add_norm=fused_add_norm,
                    layer_idx=i,
                    **factory_kwargs,
                )
                for i in range(n_layer)
            ]
        )

        self.norm_f = nn.LayerNorm(d_model, eps=norm_epsilon, **factory_kwargs)

        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
            )
        )

    def forward(self, x, inference_params=None, channel_last=False):
        """Embed ``x`` and return per-position hidden states.

        Args:
            x: Input of shape (B x C x L), or (B x L x C) if ``channel_last``.
            inference_params: Forwarded to every block.
            channel_last: Whether ``x`` is already (B x L x C).

        Returns:
            Normalised hidden states, one ``d_model`` vector per position.
        """
        if not channel_last:
            x = x.transpose(1, 2)

        hidden_states = self.embedding(x)
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                hidden_states, residual, inference_params=inference_params
            )

        # Add the last block's output to the running residual before the final norm.
        residual = (hidden_states + residual) if residual is not None else hidden_states
        return self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))

    def representation(
        self,
        x: torch.Tensor,
        lengths: torch.Tensor,
        aggr: str = "mean",
        channel_last: bool = False,
    ) -> torch.Tensor:
        """Get global representation of input data.

        Args:
            x: Data to embed. Has shape (B x C x L) if not channel_last.
            lengths: Unpadded length of each data input.
            aggr: "mean", "max", or "no_aggr" (per-position outputs, no pooling).
            channel_last: Expects input of shape (B x L x C).

        Returns:
            Global representation of shape (B x H) for "mean"/"max", or the
            per-position output of ``forward`` for "no_aggr".

        Raises:
            ValueError: If ``aggr`` is not one of the supported values.
        """
        # Validate before running the (expensive) forward pass.
        if aggr not in ("mean", "max", "no_aggr"):
            raise ValueError(
                f"Unknown aggr {aggr!r}; expected 'mean', 'max' or 'no_aggr'."
            )

        out = self.forward(x, channel_last=channel_last)

        if aggr == "mean":
            return mean_unpadded(out, lengths)
        if aggr == "max":
            return max_unpadded(out, lengths)
        return out


def mean_unpadded(x: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    """Average ``x`` over its second (sequence) dimension, ignoring padding.

    Args:
        x: Tensor of shape (B x L x H).
        lengths: Tensor of shape (B). The number of valid leading positions in
            each sequence.

    Returns:
        Tensor of shape (B x H): the sum of the first ``lengths[b]`` rows of
        ``x[b]`` divided by ``lengths[b]``.
    """
    positions = torch.arange(x.size(1), device=x.device)
    valid = positions.unsqueeze(0) < lengths.unsqueeze(1)
    totals = (x * valid.unsqueeze(-1)).sum(dim=1)
    return totals / lengths.float().unsqueeze(-1)


def max_unpadded(x: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    """Take max of tensor across second dimension without padding.

    Padded positions are filled with the lowest value of ``x``'s dtype before
    the reduction (-inf for floating tensors), so they can never win the max.
    Zeroing them instead would return 0 whenever all valid entries are negative.
    A sequence with length 0 therefore yields that lowest value.

    Args:
        x: Tensor to take unpadded max. Has shape (B x L x H).
        lengths: Tensor of unpadded lengths. Has shape (B)

    Returns:
        Max tensor of shape (B x H).
    """
    mask = torch.arange(x.size(1), device=x.device)[None, :] < lengths[:, None]
    if x.is_floating_point():
        fill = float("-inf")
    elif x.dtype == torch.bool:
        fill = False
    else:
        fill = torch.iinfo(x.dtype).min
    masked_tensor = x.masked_fill(~mask.unsqueeze(-1), fill)
    max_tensor, _ = masked_tensor.max(dim=1)
    return max_tensor
    

def _init_weights(
    module,
    n_layer,
    initializer_range=0.02,  # Now only used for embedding layer.
    rescale_prenorm_residual=True,
    n_residuals_per_layer=1,  # Change to 2 if we have MLP
):
    if isinstance(module, nn.Linear):
        if module.bias is not None:
            if not getattr(module.bias, "_no_reinit", False):
                nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=initializer_range)

    if rescale_prenorm_residual:
        for name, p in module.named_parameters():
            if name in ["out_proj.weight", "fc2.weight"]:
                nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                with torch.no_grad():
                    p /= math.sqrt(n_residuals_per_layer * n_layer)

def load_model(run_path: str, checkpoint_name: str, device='cpu') -> nn.Module:
    """Load trained model located at specified path.

    Reads ``model_config.json`` from ``run_path``. For older runs whose model
    config lacks ``n_tracks``, it also reads ``data_config.json``. It then
    builds a ``MixerModel`` and loads the weights stored under the ``model.``
    prefix of the checkpoint's ``state_dict``.

    Args:
        run_path: Path where run data is located.
        checkpoint_name: Name of model checkpoint file inside ``run_path``.
        device: ``map_location`` for ``torch.load``. The model is built without
            an explicit device, so callers still need ``.to(...)`` to move it.

    Returns:
        Model with loaded weights.
    """
    with open(os.path.join(run_path, "model_config.json"), "r") as f:
        model_params = json.load(f)

    # TODO: Temp backwards compatibility
    if "n_tracks" in model_params:
        n_tracks = model_params["n_tracks"]
    else:
        with open(os.path.join(run_path, "data_config.json"), "r") as f:
            n_tracks = json.load(f)["n_tracks"]

    model = MixerModel(
        d_model=model_params["ssm_model_dim"],
        n_layer=model_params["ssm_n_layers"],
        input_dim=n_tracks,
    )
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(
        os.path.join(run_path, checkpoint_name), map_location=torch.device(device)
    )

    # Strip the exact "model." prefix. `str.lstrip("model")` removes a *set of
    # characters*, which mangles keys such as "model_ema.*". Other keys are skipped.
    prefix = "model."
    state_dict = {
        k[len(prefix):]: v
        for k, v in checkpoint["state_dict"].items()
        if k.startswith(prefix)
    }

    model.load_state_dict(state_dict)
    return model

In [5]:
checkpoint = "epoch_22_step_20000_new.ckpt"
# Set ORTHRUS_MODEL_DIR to run outside the original cluster path.
model_repository = os.environ.get(
    "ORTHRUS_MODEL_DIR",
    "/work/gr-fe/saadat/tools/orthrus/Orthrus/HF_model/orthrus_large_6_track/",
)
# load_model maps checkpoint tensors to `device` but builds the module without
# one, so the explicit .to() is still required. eval() marks it for inference.
model = load_model(model_repository, checkpoint_name=checkpoint, device='cuda')
model = model.to(torch.device('cuda')).eval()
model

MixerModel(
  (embedding): Linear(in_features=6, out_features=512, bias=True)
  (layers): ModuleList(
    (0-5): 6 x Block(
      (mixer): Mamba(
        (in_proj): Linear(in_features=512, out_features=2048, bias=False)
        (conv1d): Conv1d(1024, 1024, kernel_size=(4,), stride=(1,), padding=(3,), groups=1024)
        (act): SiLU()
        (x_proj): Linear(in_features=1024, out_features=64, bias=False)
        (dt_proj): Linear(in_features=32, out_features=1024, bias=True)
        (out_proj): Linear(in_features=1024, out_features=512, bias=False)
      )
      (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    )
  )
  (norm_f): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
)

In [6]:
def get_embeds(transcript_id, genome, aggr, device=None):
    """Embed one transcript with the notebook-global ``model``.

    Args:
        transcript_id (str): Transcript ID, looked up with ``find_transcript``.
        genome: Genome (or variant genome) providing the transcript and sequence.
        aggr (str): Aggregation passed to ``model.representation``
            ("mean", "max" or "no_aggr").
        device: Device for the input tensors. Defaults to the device of the
            model's parameters.

    Returns:
        torch.Tensor: Output of ``model.representation`` for a batch of one.
        It is computed without autograd tracking.
    """
    if device is None:
        device = next(model.parameters()).device
    transc = find_transcript(genome, transcript_id)
    sixt = create_six_track_encoding(transc, genome)  # channels first: (6, L)
    x = torch.tensor(sixt, dtype=torch.float32, device=device).unsqueeze(0)
    lengths = torch.tensor([x.shape[2]], device=device)
    # Inference only: avoid building (and holding on to) an autograd graph.
    with torch.no_grad():
        return model.representation(x, lengths, aggr)


def get_var_genome(row, ref_genome):
    """Build a ``VariantGenome`` carrying the single-nucleotide variant in ``row``.

    ``row['ref']`` and ``row['alt']`` come from the HGVSc string, which is
    relative to the transcript. For minus-strand transcripts they are
    complemented before building the "chromosome:start:ref:alt" variant
    string. ``row`` itself is not modified.

    Args:
        row: Mapping/Series with 'Transcript_ID', 'chromosome', 'start', 'ref', 'alt'.
        ref_genome: Reference genome.

    Returns:
        VariantGenome: ``ref_genome`` with the variant applied.

    Raises:
        KeyError: If the transcript is on the minus strand and ref/alt is not
            one of A, C, G, T.
    """
    complement = {'A': 'T', 'T': 'A', 'C': 'G', 'G': 'C'}

    transcript = find_transcript(ref_genome, row['Transcript_ID'])

    # Work on local copies; mutating `row` could write back into the DataFrame.
    ref, alt = row['ref'], row['alt']
    if transcript.strand == '-':
        ref, alt = complement[ref], complement[alt]

    variant = ref_genome.variant(f"{row['chromosome']}:{row['start']}:{ref}:{alt}")
    return VariantGenome(ref_genome, variant)


In [20]:
df['var_token_idx'] = None

ref_embeds = {}
alt_embeds = {}

# Device the model lives on; encodings are moved there before inference.
embed_device = next(model.parameters()).device


def _embed_six_track(six_track):
    """Return the per-position (no_aggr) embedding of one (6 x L) encoding as bfloat16 on CPU."""
    x = torch.tensor(six_track, dtype=torch.float32, device=embed_device).unsqueeze(0)
    lengths = torch.tensor([x.shape[2]], device=embed_device)
    with torch.no_grad():
        out = model.representation(x, lengths, 'no_aggr')
    return out.squeeze().to(torch.bfloat16).cpu()


# get embeddings
for index, row in df.iterrows():

    if index % 100 == 0:
        print(index)

    # insert variant
    temp_var_genome = get_var_genome(row, ref_genome)

    # encode reference and variant transcript once; reused for the embeddings below
    transc_ref = find_transcript(ref_genome, row['Transcript_ID'])
    sixt_ref = create_six_track_encoding(transc_ref, ref_genome)

    transc_alt = find_transcript(temp_var_genome, row['Transcript_ID'])
    sixt_alt = create_six_track_encoding(transc_alt, temp_var_genome)

    # find index of variant (first position whose encoding differs)
    if sixt_ref.shape != sixt_alt.shape:
        raise ValueError(
            f"{row['var_id']}: encoding shape changed {sixt_ref.shape} -> {sixt_alt.shape}"
        )
    var_token_idx = np.flatnonzero(np.any(sixt_alt != sixt_ref, axis=0))
    if var_token_idx.size == 0:
        raise ValueError(f"{row['var_id']}: variant does not change the transcript encoding")

    df.loc[index, 'var_token_idx'] = var_token_idx[0]

    # get embeddings and assign
    ref_embeds[row['var_id']] = _embed_six_track(sixt_ref)
    alt_embeds[row['var_id']] = _embed_six_track(sixt_alt)

0
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
1700
1800
1900
2000
2100
2200
2300
2400
2500
2600
2700
2800
2900
3000
3100
3200
3300
3400
3500
3600
3700
3800
3900
4000
4100
4200


In [None]:
# save df (gzip compression is inferred from the .gz suffix)
df.to_csv('../data/tcga_processed.tsv.gz', sep='\t', index=False)

# Save embeddings as pickles (var_id -> bfloat16 tensor).
# NOTE: pickle.load can execute arbitrary code; only load these files from trusted sources.
for pickle_path, embeds in (
    ('../data/tcga_ref_embeds.pkl', ref_embeds),
    ('../data/tcga_alt_embeds.pkl', alt_embeds),
):
    with open(pickle_path, 'wb') as pickle_file:
        pickle.dump(embeds, pickle_file)
