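# How to use gReLU with external PyTorch models

This notebook wraps the pretrained Basset model from the Kipoi model zoo in a gReLU `BaseModel`, places it inside a gReLU `LightningModel`, and uses it to predict the effects of genetic variants from the gReLU Alzheimer's variant tutorial dataset.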

In [11]:
import os

import numpy as np
import pandas as pd
from torch import nn

import grelu.lightning
import grelu.resources
import grelu.variant
from grelu.model.models import BaseModel
from grelu.sequence.format import convert_input_type

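## Load a PyTorch model

Here we use the pretrained Basset model from the Kipoi model zoo and keep only its underlying PyTorch module.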

In [2]:
import kipoi

# Download (or reuse a cached copy of) the pretrained Basset model from the
# Kipoi model zoo. The Kipoi dataloader is not needed: gReLU prepares the inputs.
kipoi_model = kipoi.get_model(
    'Basset',
    with_dataloader=False,
)

Already up to date.
Using downloaded and verified file: /root/.kipoi/models/Basset/downloaded/model_files/weights/4878981d84499eb575abd0f3b45570d3


  self.model.load_state_dict(torch.load(weights))


In [23]:
# Keep only the underlying PyTorch module and move it to the CPU.
# Bind it to a new name (`model`, used by the conversion cells below) instead of
# overwriting `kipoi_model`, so this cell can be re-run safely.
model = kipoi_model.model.to('cpu')

In [108]:
# Task metadata for the Basset outputs: one row per task, giving its name and
# the source peak file it was derived from.
tasks = pd.read_table(
    'https://raw.github.com/davek44/Basset/refs/heads/master/data/models/targets.txt',
    header=None,
    names=['name', 'source'],
)
tasks.head()

Unnamed: 0,name,source
0,8988T,encode/wgEncodeAwgDnaseDuke8988tUniPk.narrowPe...
1,AoSMC,encode/wgEncodeAwgDnaseDukeAosmcUniPk.narrowPe...
2,Chorion,encode/wgEncodeAwgDnaseDukeChorionUniPk.narrow...
3,CLL,encode/wgEncodeAwgDnaseDukeCllUniPk.narrowPeak.gz
4,Fibrobl,encode/wgEncodeAwgDnaseDukeFibroblUniPk.narrow...


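## Convert it to a gReLU-compatible model

We split the Kipoi network into a gReLU `embedding` and `head`. Two small wrapper modules add a trailing singleton axis: one to the input, so the Basset layers accept gReLU's one-hot sequences, and one to the output, so predictions have the `(batch, tasks, length)` layout gReLU works with.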

In [81]:
class InputReshape(nn.Module):
    """Append a trailing singleton axis to the model input.

    gReLU passes one-hot sequences, presumably shaped ``(batch, 4, length)``.
    This module turns them into ``(batch, 4, length, 1)``, which the Kipoi Basset
    layers appear to expect (4D input) -- confirm against the Kipoi model spec.
    """

    def forward(self, x):
        return x.unsqueeze(dim=-1)


class OutputReshape(nn.Module):
    """Append a trailing length axis to the head output.

    Turns per-task predictions of shape ``(batch, n_tasks)`` into
    ``(batch, n_tasks, 1)``, i.e. a ``(batch, tasks, length)`` tensor with a
    length of 1, matching the output layout gReLU works with.
    """

    def forward(self, x):
        return x.unsqueeze(dim=-1)

In [88]:
# Split the Kipoi network in two:
# - modules 0-20 become the gReLU embedding, preceded by InputReshape;
# - the sub-modules inside module 21 become the head, followed by OutputReshape.
# Any modules after index 21 are not used (the output activation is set on the
# LightningModel instead).
embedding = nn.Sequential(InputReshape(), *model[:21])
head = nn.Sequential(*model[21], OutputReshape())

In [89]:
# Assemble a gReLU BaseModel from the embedding and head built from the Kipoi layers.
basemodel = BaseModel(head=head, embedding=embedding)

In [90]:
# Number of Basset output tasks. Check it against the task table so the task
# names stored on the LightningModel line up with the model's outputs.
N_TASKS = 164
assert len(tasks) == N_TASKS, f"expected {N_TASKS} tasks, found {len(tasks)} in the task table"
basemodel.head.n_tasks = N_TASKS

## Wrap it in a LightningModel

In [91]:
# Wrap the BaseModel in a gReLU LightningModel so gReLU's prediction utilities
# (e.g. variant effect prediction below) can drive it. This needs the
# `grelu.lightning` submodule to be imported explicitly rather than relying on
# other grelu modules to load it as a side effect.
lm = grelu.lightning.LightningModel(model=basemodel)

In [92]:
# Use a sigmoid as the LightningModel's output activation (presumably applied to
# the head output at prediction time), so predictions fall in (0, 1).
lm.activation = nn.Sigmoid()

In [78]:
# Store the task names on the model, presumably so gReLU can label its outputs
# (they appear as `odds.var` below). `tasks` is a DataFrame, which has no
# `.tolist()`, so select the `name` column before converting it to a list.
lm.data_params['tasks'] = {'name': tasks['name'].tolist()}

## Test the model

In [None]:
# Sanity-check input: a dummy 600-bp all-A sequence, one-hot encoded with a
# leading batch axis.
oh = convert_input_type('A' * 600, 'one_hot', add_batch_axis=True)
oh.shape

In [None]:
# Forward pass through the LightningModel on the dummy input. Because of
# OutputReshape, the output presumably has shape (batch, n_tasks, 1) -- confirm by running.
lm(oh).shape

## Load some variant data

In [71]:
# Download the variant dataset artifact from the Alzheimer's variant tutorial
# project (the wandb log below comes from this download), then locate its variant table.
variant_dir = grelu.resources.get_artifact(
    name='dataset',
    project='alzheimers-variant-tutorial',
).download()
variant_file = os.path.join(variant_dir, "variants.txt")

[34m[1mwandb[0m:   1 of 1 files downloaded.  


In [72]:
# The variant table is tab-separated.
variants = pd.read_csv(variant_file, sep='\t')

In [74]:
# Keep only the first 5 variants, presumably to keep this example quick to run.
variants = variants.iloc[:5]

## Predict variant effects

In [98]:
# Predict the effect of each variant on every Basset task. With a sigmoid
# output activation, "subtract" gives alt - ref differences in predicted
# probability, not a log2 fold change.
odds = grelu.variant.predict_variant_effects(
    variants=variants,
    model=lm,
    devices=0,  # Run on GPU 0
    num_workers=8,
    batch_size=128,
    genome="hg38",
    compare_func="subtract",  # Return the difference (alt - ref) between predictions
    return_ad=True,  # Return an anndata object.
    rc=True,  # Average predictions over both strands (forward and reverse complement).
    seq_len=600,  # Must match the model's input length (600 bp, as in the test above).
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]


making dataset
Predicting DataLoader 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 115.59it/s]




In [102]:
# Variant metadata: one row per variant (anndata observations).
odds.obs.iloc[:5]

Unnamed: 0,snpid,chrom,pos,alt,ref,rsid,zscore,pval,nsum,neff,direction,eaf,beta,se
0,6:32630634_G_A,chr6,32630634,G,A,6:32630634,3.974476,7.1e-05,71639,71639.0,?+?+,0.2237,0.025194,0.006339
1,6:32630797_A_G,chr6,32630797,A,G,6:32630797,4.040244,5.3e-05,71639,71639.0,?+?+,0.2435,0.024866,0.006155
2,6:32630824_T_C,chr6,32630824,T,C,6:32630824,3.921736,8.8e-05,71639,71639.0,?+?+,0.1859,0.02663,0.00679
3,6:32630829_G_A,chr6,32630829,G,A,6:32630829,4.044549,5.2e-05,71639,71639.0,?+?+,0.1859,0.027463,0.00679
4,6:32630925_T_A,chr6,32630925,T,A,6:32630925,3.942586,8.1e-05,71639,71639.0,?+?+,0.2137,0.025407,0.006444


In [103]:
# Task metadata: one row per Basset task (anndata variables).
odds.var.iloc[:5]

8988T
AoSMC
Chorion
CLL
Fibrobl


In [106]:
# Variant effects (alt - ref) for every variant (rows) and the first 10 tasks (columns).
odds.X[:, 0:10]

array([[ 6.7059975e-04, -7.1943970e-05,  8.2078017e-04,  3.0160247e-04,
         9.9612400e-04,  3.2657199e-04,  1.5163550e-04,  2.2395590e-04,
         2.5888509e-04,  2.3600331e-04],
       [-3.1625712e-04,  1.8137204e-04, -3.8895570e-04, -1.2878567e-04,
        -2.7760863e-04, -4.8756599e-05, -1.8160080e-04, -8.2039332e-05,
        -9.1010530e-05, -1.2605428e-04],
       [ 2.2092855e-03,  1.1844642e-03,  4.3695448e-03,  2.1502160e-04,
         2.7354609e-02,  1.2441045e-02,  1.1652061e-03,  2.2177270e-04,
         2.4229416e-04,  3.9629173e-04],
       [-6.0073775e-04,  1.1721095e-03, -1.4877692e-03, -5.5570388e-05,
        -5.9145242e-03,  6.3451566e-04, -1.7841649e-04, -1.7306185e-05,
         2.7102651e-06, -5.8941077e-05],
       [ 1.2307693e-03, -3.2268488e-04,  1.6341531e-03, -1.3554969e-04,
         1.0393305e-02,  4.1489732e-03,  5.7467003e-04, -7.7596225e-05,
         4.6995294e-05, -2.5936752e-04]], dtype=float32)