In [1]:
import wandb
import torch
import pandas as pd

from grelu.lightning import LightningModel
import pytorch_lightning as pl
from grelu.sequence.utils import get_unique_length, resize

  from .autonotebook import tqdm as notebook_tqdm


## wandb login

In [2]:
wandb.login(host="https://api.wandb.ai")

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mavantikalal[0m ([33mgrelu[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
run = wandb.init(entity='grelu', project='enformer', job_type='copy', name='copy-mouse') # Initialize a W&B Run

## Paths

In [4]:
targets_path = 'https://raw.githubusercontent.com/calico/basenji/master/manuscripts/cross2020/targets_mouse.txt'

In [5]:
sequences_path = '/gstore/data/resbioai/grelu/enformer/sequences-mouse.bed'

## Process tasks

In [6]:
tasks = pd.read_csv(targets_path, sep='\t', index_col=0)
print(len(tasks))
tasks.head(3)

1643


Unnamed: 0_level_0,genome,identifier,file,clip,scale,sum_stat,description
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
5313,1,ENCFF866ZTV,/home/drk/tillage/datasets/mouse/dnase/encode/...,32,2,mean,DNASE:B6D2F1/J 416B
5314,1,ENCFF695LHM,/home/drk/tillage/datasets/mouse/dnase/encode/...,32,2,mean,DNASE:BALB/cAnN A20
5315,1,ENCFF079SPZ,/home/drk/tillage/datasets/mouse/dnase/encode/...,32,2,mean,DNASE:C57BL/6 B cell male adult (8 weeks)


In [7]:
tasks = tasks.reset_index(drop=True)
tasks = tasks.drop(columns=["genome"])
tasks["assay"] = tasks["description"].apply(lambda x: x.split(":")[0])
tasks["sample"] = tasks["description"].apply(lambda x: ":".join(x.split(":")[1:]))
tasks = tasks.rename(columns={"identifier":"name"})
tasks.head()

Unnamed: 0,name,file,clip,scale,sum_stat,description,assay,sample
0,ENCFF866ZTV,/home/drk/tillage/datasets/mouse/dnase/encode/...,32,2,mean,DNASE:B6D2F1/J 416B,DNASE,B6D2F1/J 416B
1,ENCFF695LHM,/home/drk/tillage/datasets/mouse/dnase/encode/...,32,2,mean,DNASE:BALB/cAnN A20,DNASE,BALB/cAnN A20
2,ENCFF079SPZ,/home/drk/tillage/datasets/mouse/dnase/encode/...,32,2,mean,DNASE:C57BL/6 B cell male adult (8 weeks),DNASE,C57BL/6 B cell male adult (8 weeks)
3,ENCFF798VSP,/home/drk/tillage/datasets/mouse/dnase/encode/...,32,2,mean,DNASE:C57BL/6 splenic B cell male adult (8 weeks),DNASE,C57BL/6 splenic B cell male adult (8 weeks)
4,ENCFF474GND,/home/drk/tillage/datasets/mouse/dnase/encode/...,32,2,mean,DNASE:C57BL/6 cerebellum male adult (8 weeks),DNASE,C57BL/6 cerebellum male adult (8 weeks)


In [8]:
tasks = tasks.to_dict(orient="list")

## Process intervals

In [9]:
intervals = pd.read_table(sequences_path, header=None, names = ['chrom', 'start', 'end', 'split'])
intervals.head()

Unnamed: 0,chrom,start,end,split
0,chr4,34106647,34237719,train
1,chr5,52207747,52338819,train
2,chr19,20136862,20267934,train
3,chr14,61845439,61976511,train
4,chr15,6592346,6723418,train


In [10]:
intervals.split.value_counts()

split
train    29295
valid     2209
test      2017
Name: count, dtype: int64

In [11]:
get_unique_length(intervals)

131072

In [12]:
intervals = resize(intervals, 196608)
intervals.head()

Unnamed: 0,chrom,start,end,split
0,chr4,34073879,34270487,train
1,chr5,52174979,52371587,train
2,chr19,20104094,20300702,train
3,chr14,61812671,62009279,train
4,chr15,6559578,6756186,train


In [13]:
train_intervals = intervals[intervals.split=='train'].iloc[:, :3]
val_intervals = intervals[intervals.split=='valid'].iloc[:, :3]
test_intervals = intervals[intervals.split=='test'].iloc[:, :3]
del intervals

## Initialize model

In [14]:
model_params={
    'model_type':'EnformerModel',
    'final_act_func': 'softplus',
    'final_pool_func':None,
    'n_tasks': 1643,
    'crop_len':320,
}
train_params={'task':'regression', 'loss':'mse'}

model = LightningModel(model_params, train_params)

## Load weights

In [15]:
state_dict = torch.load("/data/enformer/torch_weights/mouse.h5")
model.model.load_state_dict(state_dict)

  state_dict = torch.load("/data/enformer/torch_weights/mouse.h5")


<All keys matched successfully>

## Add hparams

In [16]:
model.data_params["train"] = dict()
model.data_params["val"] = dict()
model.data_params["test"] = dict()

In [17]:
model.data_params["train"]["seq_len"] = 196608
model.data_params["train"]["label_len"] = 896 * 128
model.data_params["train"]["genome"] = "mm10"
model.data_params["train"]["bin_size"] = 128
model.data_params["train"]["max_seq_shift"] = 3
model.data_params["train"]["rc"] = True

## Add tasks

In [18]:
model.data_params["tasks"] = tasks

## Add intervals

In [19]:
model.data_params["train"]["intervals"] = train_intervals.to_dict(orient='list')
model.data_params["val"]["intervals"] = val_intervals.to_dict(orient='list')
model.data_params["test"]["intervals"] = test_intervals.to_dict(orient='list')

## Save

In [20]:
trainer = pl.Trainer()
try:
    trainer.predict(model) 
except:
    trainer.save_checkpoint('/data/enformer/torch_weights/mouse.ckpt')

Trainer will use only 1 of 8 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=8)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A100-SXM4-80GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]


## Upload

In [21]:
artifact = wandb.Artifact(
    'mouse', 
    type='model',
    metadata={
        'model_params':model.model_params, 
        'train_params':model.train_params, 
        'data_params':model.data_params
    }
)
artifact.add_file(local_path='/data/enformer/torch_weights/mouse.ckpt', name='model.ckpt')
run.log_artifact(artifact)



<Artifact mouse>

In [22]:
run.finish() 