In [1]:
import wandb
import torch
import pandas as pd

from grelu.lightning import LightningModel
import pytorch_lightning as pl
from grelu.sequence.utils import get_unique_length, resize

  from .autonotebook import tqdm as notebook_tqdm


## wandb login

In [2]:
# Authenticate against the public W&B server; the API key comes from the
# local wandb configuration/environment, not from this notebook.
wandb.login(host="https://api.wandb.ai")

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mavantikalal[0m ([33mgrelu[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [12]:
# Start a W&B run under grelu/enformer that records this notebook as the
# program that produced the uploaded artifact.
# NOTE(review): program_relpath receives the same absolute path as
# program_abspath — confirm this is intended rather than a repo-relative path.
notebook_path = '/code/github/gReLU-applications/enformer/save_wandb_enformer_human.ipynb'
run = wandb.init(
    entity='grelu',
    project='enformer',
    job_type='copy',
    name='copy-human',
    settings=wandb.Settings(
        program_relpath=notebook_path,
        program_abspath=notebook_path,
    ),
)

In [13]:
# Attach the source code to the active run (`run` is the object returned by
# wandb.init above, i.e. the same run as `wandb.run`).
run.log_code()



## Paths

In [14]:
# Basenji "cross2020" human targets table, read over HTTP later in this notebook.
# NOTE(review): this URL tracks the `master` branch, so its contents may change;
# consider pinning to a specific commit for reproducibility.
targets_path = 'https://raw.githubusercontent.com/calico/basenji/master/manuscripts/cross2020/targets_human.txt'

In [15]:
# Headerless BED-like file of training intervals. It is read later with the
# columns chrom, start, end, split. This absolute path is specific to the
# cluster this notebook was run on.
sequences_path = '/gstore/data/resbioai/grelu/enformer/sequences.bed'

## Process tasks

In [16]:
# Load the tab-separated targets table (its 5313 rows match n_tasks used for
# the model below), then report the row count and preview the first rows.
tasks = pd.read_csv(targets_path, index_col=0, sep='\t')
print(tasks.shape[0])
tasks.head(n=3)

5313


Unnamed: 0_level_0,genome,identifier,file,clip,scale,sum_stat,description
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
0,0,ENCFF833POA,/home/drk/tillage/datasets/human/dnase/encode/...,32,2,mean,DNASE:cerebellum male adult (27 years) and mal...
1,0,ENCFF110QGM,/home/drk/tillage/datasets/human/dnase/encode/...,32,2,mean,DNASE:frontal cortex male adult (27 years) and...
2,0,ENCFF880MKD,/home/drk/tillage/datasets/human/dnase/encode/...,32,2,mean,DNASE:chorion


In [17]:
# Reshape the Basenji targets table into the task metadata stored with the model:
#   - drop the original integer index and the "genome" column,
#   - split "description" ("ASSAY:sample text") at its FIRST ":" into
#     "assay" and "sample" (sample keeps any later ":" characters),
#   - rename "identifier" -> "name".
# errors="ignore" on drop keeps this cell idempotent: previously a second run
# raised KeyError because "genome" had already been removed. The other steps
# are already no-ops on a second run.
tasks = (
    tasks
    .reset_index(drop=True)
    .drop(columns=["genome"], errors="ignore")
    .assign(
        # text before the first ":" (the whole string when there is no ":")
        assay=lambda df: df["description"].map(lambda d: d.partition(":")[0]),
        # text after the first ":" (empty string when there is no ":")
        sample=lambda df: df["description"].map(lambda d: d.partition(":")[2]),
    )
    .rename(columns={"identifier": "name"})
)
tasks.head()

Unnamed: 0,name,file,clip,scale,sum_stat,description,assay,sample
0,ENCFF833POA,/home/drk/tillage/datasets/human/dnase/encode/...,32,2,mean,DNASE:cerebellum male adult (27 years) and mal...,DNASE,cerebellum male adult (27 years) and male adul...
1,ENCFF110QGM,/home/drk/tillage/datasets/human/dnase/encode/...,32,2,mean,DNASE:frontal cortex male adult (27 years) and...,DNASE,frontal cortex male adult (27 years) and male ...
2,ENCFF880MKD,/home/drk/tillage/datasets/human/dnase/encode/...,32,2,mean,DNASE:chorion,DNASE,chorion
3,ENCFF463ZLQ,/home/drk/tillage/datasets/human/dnase/encode/...,32,2,mean,DNASE:Ishikawa treated with 0.02% dimethyl sul...,DNASE,Ishikawa treated with 0.02% dimethyl sulfoxide...
4,ENCFF890OGQ,/home/drk/tillage/datasets/human/dnase/encode/...,32,2,mean,DNASE:GM03348,DNASE,GM03348


In [18]:
# Convert the task table to a plain {column: [values, ...]} dict so it can be
# stored in model.data_params and the W&B artifact metadata.
# NOTE(review): this rebinds `tasks` from a DataFrame to a dict, so the pandas
# cells above cannot be re-run without re-reading the table.
tasks = tasks.to_dict(orient="list")

## Process intervals

In [19]:
# Read the headerless interval file and label its four columns. set_axis
# raises if the file does not have exactly four columns.
intervals = pd.read_table(sequences_path, header=None).set_axis(
    ['chrom', 'start', 'end', 'split'], axis=1
)
intervals.head()

Unnamed: 0,chrom,start,end,split
0,chr18,928386,1059458,train
1,chr4,113630947,113762019,train
2,chr11,18427720,18558792,train
3,chr16,85805681,85936753,train
4,chr3,158386188,158517260,train


In [20]:
# Number of intervals assigned to each split (train / valid / test).
intervals["split"].value_counts()

split
train    34021
valid     2213
test      1937
Name: count, dtype: int64

In [21]:
# Common length of the raw intervals (131072 bp in the recorded output).
# Presumably get_unique_length errors if lengths differ — confirm in gReLU.
interval_len = get_unique_length(intervals)
interval_len

131072

In [22]:
# Resize every interval to the model input length (1536 bins x 128 bp). The
# recorded output shows each interval extended by 32768 bp on both sides.
input_seq_len = 196608
intervals = resize(intervals, input_seq_len)
intervals.head()

Unnamed: 0,chrom,start,end,split
0,chr18,895618,1092226,train
1,chr4,113598179,113794787,train
2,chr11,18394952,18591560,train
3,chr16,85772913,85969521,train
4,chr3,158353420,158550028,train


In [23]:
# Split intervals by their "split" label, keeping only the first three
# (coordinate) columns: chrom, start, end.
coord_cols = intervals.columns[:3]
train_intervals, val_intervals, test_intervals = (
    intervals.loc[intervals["split"] == split_label, coord_cols]
    for split_label in ("train", "valid", "test")
)
del intervals

## Initialize model

In [24]:
# Enformer architecture settings for gReLU's LightningModel. n_tasks must match
# the number of output tracks in the pretrained human weights loaded below,
# and the number of rows in the targets table.
model_params = {
    'model_type': 'EnformerModel',
    'final_act_func': 'softplus',
    'final_pool_func': None,
    'n_tasks': 5313,
    # Assumes crop_len counts bins removed from EACH end:
    # 196608 bp / 128 bp = 1536 bins, and 1536 - 2 * 320 = 896 bins, which
    # matches label_len = 896 * 128 in data_params — TODO confirm in gReLU.
    'crop_len': 320,
}
train_params = {'task': 'regression', 'loss': 'mse'}

# Fail fast if the task table and the output head disagree in size.
assert len(tasks["name"]) == model_params['n_tasks'], (
    f"targets table has {len(tasks['name'])} tasks "
    f"but model_params['n_tasks'] is {model_params['n_tasks']}"
)

model = LightningModel(model_params, train_params)

## Load weights

In [25]:
# Load the converted PyTorch Enformer weights into the wrapped nn.Module. The
# file holds a state_dict (it is passed to load_state_dict) despite its .h5
# extension. load_state_dict is strict by default, so any missing or unexpected
# key raises.
# weights_only=True restricts unpickling to tensors and plain containers. This
# avoids arbitrary code execution from the file and silences the warning that
# torch.load emitted here. map_location="cpu" lets the file load even when the
# device it was saved from is unavailable; load_state_dict copies the tensors
# into the model's own parameters.
state_dict = torch.load(
    "/data/enformer/torch_weights/human.h5",
    map_location="cpu",
    weights_only=True,
)
model.model.load_state_dict(state_dict)

  state_dict = torch.load("/data/enformer/torch_weights/human.h5")


<All keys matched successfully>

## Add hparams

In [26]:
# Give each data split its own empty settings dict; these are populated later
# in this notebook.
for split_key in ("train", "val", "test"):
    model.data_params[split_key] = dict()

In [27]:
# Record the training-data settings that go with the pretrained weights:
# hg38 inputs of 196608 bp, labels spanning 896 bins of 128 bp, and
# (presumably) up to 3 bp random shift plus reverse-complement augmentation.
model.data_params["train"].update(
    {
        "seq_len": 196608,
        "label_len": 896 * 128,
        "genome": "hg38",
        "bin_size": 128,
        "max_seq_shift": 3,
        "rc": True,
    }
)

## Add tasks

In [28]:
# Store the per-task metadata (a {column: [values]} dict built from the targets
# table) on the model. It is uploaded as part of data_params in the W&B
# artifact metadata and presumably also saved with the checkpoint's
# hyper-parameters.
model.data_params["tasks"] = tasks

## Add intervals

In [29]:
# Store each split's coordinates as {column: [values]} dicts alongside the
# other data settings.
for split_key, split_intervals in (
    ("train", train_intervals),
    ("val", val_intervals),
    ("test", test_intervals),
):
    model.data_params[split_key]["intervals"] = split_intervals.to_dict(orient='list')

## Save

In [30]:
# Write a Lightning checkpoint of the model.
# trainer.predict is called only so the model gets attached to the Trainer
# before save_checkpoint. With no prediction data configured, the call appears
# to be expected to raise — TODO confirm. Only ordinary exceptions are caught
# (KeyboardInterrupt / SystemExit still propagate), and the error is reported
# instead of silently swallowed.
trainer = pl.Trainer()
try:
    trainer.predict(model)
except Exception as err:  # expected: no predict dataloader was supplied
    print(f"trainer.predict raised (expected): {type(err).__name__}: {err}")

# Save unconditionally. Previously the checkpoint was written only inside the
# except branch, so a successful predict() would have skipped the save.
trainer.save_checkpoint('/data/enformer/torch_weights/human.ckpt')

Trainer will use only 1 of 8 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=8)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A100-SXM4-80GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]


## Upload

In [31]:
# Package the checkpoint as a W&B "model" artifact named 'human'. The model,
# training and data settings are stored as artifact metadata, and the file is
# uploaded under the name model.ckpt.
artifact_metadata = dict(
    model_params=model.model_params,
    train_params=model.train_params,
    data_params=model.data_params,
)
artifact = wandb.Artifact('human', type='model', metadata=artifact_metadata)
artifact.add_file(local_path='/data/enformer/torch_weights/human.ckpt', name='model.ckpt')
run.log_artifact(artifact)



<Artifact human>

In [32]:
# End the W&B run started at the top of this notebook.
run.finish() 