# How to use gReLU with external PyTorch models

In [1]:
import numpy as np
import pandas as pd
from torch import nn
import grelu.resources
import os
from grelu.model.models import BaseModel
from grelu.sequence.format import convert_input_type
from grelu.sequence.utils import generate_random_sequences

  from .autonotebook import tqdm as notebook_tqdm


## Load a PyTorch model

In [2]:
import kipoi

# Fetch the Basset model from Kipoi. The Kipoi dataloader is skipped because
# sequences are passed to the model directly later in this notebook.
kipoi_model = kipoi.get_model(
    'Basset',
    with_dataloader=False,
)

Already up to date.
Using downloaded and verified file: /root/.kipoi/models/Basset/downloaded/model_files/weights/4878981d84499eb575abd0f3b45570d3


  self.model.load_state_dict(torch.load(weights))


In [3]:
# Kipoi wraps the network; the underlying torch module lives in `.model`.
# `getattr` with a default keeps this cell re-runnable: after the first run
# `kipoi_model` is already the bare module (which normally has no `.model`
# attribute), so a plain `kipoi_model.model` would raise AttributeError.
kipoi_model = getattr(kipoi_model, "model", kipoi_model).to('cpu')

In [4]:
# Table of the Basset prediction targets. The file has no header row, so the
# two columns are named explicitly.
tasks = pd.read_table(
    'https://raw.github.com/davek44/Basset/refs/heads/master/data/models/targets.txt',
    names=['name', 'source'],
    header=None,
)
tasks.head()

Unnamed: 0,name,source
0,8988T,encode/wgEncodeAwgDnaseDuke8988tUniPk.narrowPe...
1,AoSMC,encode/wgEncodeAwgDnaseDukeAosmcUniPk.narrowPe...
2,Chorion,encode/wgEncodeAwgDnaseDukeChorionUniPk.narrow...
3,CLL,encode/wgEncodeAwgDnaseDukeCllUniPk.narrowPeak.gz
4,Fibrobl,encode/wgEncodeAwgDnaseDukeFibroblUniPk.narrow...


In [5]:
# Number of tasks = number of rows in the targets table.
tasks.shape[0]

164

## Make it gReLU compatible

In [6]:
class AddFinalAxis(nn.Module):
    """Parameter-free layer that appends a trailing axis of size 1 to its input."""

    def forward(self, x):
        """Return `x` with one extra dimension at the end, e.g. (N, C) -> (N, C, 1)."""
        return x.unsqueeze(dim=-1)

In [7]:
# Embedding: append a trailing axis to the (N, 4, L) one-hot input, then run
# the first 21 layers of the external network.
# NOTE(review): presumably the Kipoi network expects a 4-D (N, 4, L, 1) input - confirm.
embedding = nn.Sequential(AddFinalAxis(), *kipoi_model[:21])

# Head: the sub-layers of layer 21, followed by a trailing axis so the output
# has shape (N, n_tasks, 1).
head = nn.Sequential(*kipoi_model[21], AddFinalAxis())
# Number of output tasks, set as an attribute on the head.
head.n_tasks = len(tasks)

## Wrap it in a LightningModel

In [9]:
# `grelu.lightning` is a submodule, so import it explicitly instead of relying
# on another grelu import having loaded it as a side effect.
import grelu.lightning

# Wrap the external embedding and head modules in a gReLU LightningModel.
lm = grelu.lightning.LightningModel(
    model_params={'model_type': 'BaseModel', 'embedding': embedding, 'head': head}
)

In [10]:
# Set a sigmoid as the model's output activation.
# NOTE(review): assumes `head` emits logits, i.e. the external network's own
# final activation (if any) is not part of the layers sliced out above - confirm.
lm.activation = nn.Sigmoid()

In [11]:
# Attach the task metadata to the model as a dict mapping each column
# ('name', 'source') to a list with one entry per task.
lm.data_params['tasks'] = tasks.to_dict(orient="list")

## Test the model

In [12]:
# Five random one-hot encoded sequences of length 600, with a fixed seed so the
# cell is reproducible.
test_input = generate_random_sequences(
    n=5,
    seq_len=600,
    seed=0,
    output_format='one_hot',
)
test_input.shape

torch.Size([5, 4, 600])

In [13]:
# Forward pass of the random sequences through the wrapped model.
test_output = lm(test_input)

In [14]:
# Output shape: the recorded run gives (5, 164, 1), i.e. 5 sequences x 164 tasks x 1.
test_output.shape

torch.Size([5, 164, 1])

In [15]:
# Sanity check on the value range of the predictions (inside (0, 1) in the recorded run).
test_output.min(), test_output.max()

(tensor(0.1328, grad_fn=<MinBackward1>),
 tensor(0.8120, grad_fn=<MaxBackward1>))

## Load some variant data

In [16]:
# Download the 'dataset' artifact of the 'alzheimers-variant-tutorial' project.
# The result of `.download()` is used below as the local directory path.
variant_dir = grelu.resources.get_artifact(
    project='alzheimers-variant-tutorial',
    name='dataset'
).download()

# Path of the variant table inside the downloaded directory.
variant_file = os.path.join(variant_dir, "variants.txt")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33manony-mouse-180959755991866352[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m:   1 of 1 files downloaded.  


In [36]:
# Load only the first 400 rows (variants) of the tab-delimited variant file.
test_variants = pd.read_table(variant_file, nrows=400)

## Predict variant effects

In [39]:
import torch
import grelu.variant

# NOTE: despite the name, `odds` holds differences between the alt and ref
# predicted probabilities (compare_func="subtract"), not odds ratios.
odds = grelu.variant.predict_variant_effects(
    variants=test_variants,
    model=lm,
    seq_len=600,
    # Run on GPU 0 when CUDA is available; otherwise fall back to the CPU so the
    # notebook still runs on machines without a GPU.
    devices=0 if torch.cuda.is_available() else "cpu",
    num_workers=0,
    batch_size=128,
    genome="hg38",
    compare_func="subtract", # Return the difference between alt and ref predicted probabilities
    return_ad=True, # Return an anndata object.
    rc = True, # Reverse complement the ref/alt predictions and average them.
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]


making dataset


/opt/conda/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=255` in the `DataLoader` to improve performance.


Predicting DataLoader 0: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 78.20it/s]




In [44]:
# (n_variants, n_tasks): one row per variant, one column per task.
odds.shape

(400, 164)

In [40]:
# Per-variant metadata (one row per input variant).
odds.obs.head()

Unnamed: 0,snpid,chrom,pos,alt,ref,rsid,zscore,pval,nsum,neff,direction,eaf,beta,se
0,6:32630634_G_A,chr6,32630634,G,A,6:32630634,3.974476,7.1e-05,71639,71639.0,?+?+,0.2237,0.025194,0.006339
1,6:32630797_A_G,chr6,32630797,A,G,6:32630797,4.040244,5.3e-05,71639,71639.0,?+?+,0.2435,0.024866,0.006155
2,6:32630824_T_C,chr6,32630824,T,C,6:32630824,3.921736,8.8e-05,71639,71639.0,?+?+,0.1859,0.02663,0.00679
3,6:32630829_G_A,chr6,32630829,G,A,6:32630829,4.044549,5.2e-05,71639,71639.0,?+?+,0.1859,0.027463,0.00679
4,6:32630925_T_A,chr6,32630925,T,A,6:32630925,3.942586,8.1e-05,71639,71639.0,?+?+,0.2137,0.025407,0.006444


In [41]:
# Per-task metadata, indexed by task name.
odds.var.head()

Unnamed: 0_level_0,source
name,Unnamed: 1_level_1
8988T,encode/wgEncodeAwgDnaseDuke8988tUniPk.narrowPe...
AoSMC,encode/wgEncodeAwgDnaseDukeAosmcUniPk.narrowPe...
Chorion,encode/wgEncodeAwgDnaseDukeChorionUniPk.narrow...
CLL,encode/wgEncodeAwgDnaseDukeCllUniPk.narrowPeak.gz
Fibrobl,encode/wgEncodeAwgDnaseDukeFibroblUniPk.narrow...


In [42]:
# Predicted effects for the first 5 variants x first 5 tasks.
odds.X[:5, :5]

array([[ 6.70550740e-04, -7.25158025e-05,  8.21364112e-04,
         3.01392458e-04,  9.98657197e-04],
       [-3.16521619e-04,  1.80897303e-04, -3.88936605e-04,
        -1.28909480e-04, -2.76844949e-04],
       [ 2.20930134e-03,  1.18435547e-03,  4.36861161e-03,
         2.14962347e-04,  2.73526926e-02],
       [-6.00683503e-04,  1.17215188e-03, -1.48754101e-03,
        -5.56568848e-05, -5.91373816e-03],
       [ 1.23135047e-03, -3.22295120e-04,  1.63462386e-03,
        -1.35493581e-04,  1.03949215e-02]], dtype=float32)

In [43]:
# Largest and smallest predicted effect across all variants and tasks.
odds.X.max(), odds.X.min()

(0.4989372, -0.35747033)