In [1]:
import pandas as pd
import pytorch_lightning as pl
import torch
import wandb

from grelu.lightning import LightningModel
from grelu.sequence.format import convert_input_type
from grelu.sequence.utils import get_unique_length, resize

  from .autonotebook import tqdm as notebook_tqdm


## Set up W&B (wandb)

In [2]:
# relogin=True re-prompts for an API key (see output) and appends it to the netrc file,
# so this cell needs interactive input on every run.
wandb.login(relogin=True, host="https://api.wandb.ai")

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter:

  ········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mavantikalal[0m ([33mgrelu[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
# Initialize a W&B run (job_type 'copy') in grelu/borzoi; checkpoint artifacts are logged to it later.
run = wandb.init(
    entity='grelu',
    project='borzoi',
    job_type='copy',
    name='copy-mouse',
)

## Paths

In [4]:
# Inputs: the mouse target (task) table and the mm10 interval BED file.
tasks_path, intervals_path = (
    '/code/borzoi/data/targets_mouse.txt.gz',
    '/gstore/data/resbioai/grelu/borzoi-data/mm10/sequences.bed',
)

## Process tasks

In [5]:
# Task table: the row count (printed) should equal the model's number of output tasks.
tasks = pd.read_table(tasks_path, index_col=0)
print(tasks.shape[0])

2608


In [6]:
# Split each description of the form "ASSAY:sample" at the FIRST ':' only:
# the assay is the part before it, the sample is everything after it (the sample
# may itself contain ':'; it is "" when there is no ':'). str.partition does this
# in one vectorized pass instead of two per-row lambdas.
tasks = (
    tasks
    .reset_index(drop=True)
    .assign(
        assay=lambda df: df["description"].str.partition(":")[0],
        sample=lambda df: df["description"].str.partition(":")[2],
    )
    .rename(columns={"identifier": "name"})
)
tasks.head(2)

Unnamed: 0,name,file,clip,clip_soft,scale,sum_stat,strand_pair,description,assay,sample
0,CNhs10464+,/home/drk/tillage/datasets/mouse/cage/fantom/C...,768,384,1.0,sum,1,"CAGE:placenta, adult pregnant day17",CAGE,"placenta, adult pregnant day17"
1,CNhs10464-,/home/drk/tillage/datasets/mouse/cage/fantom/C...,768,384,1.0,sum,0,"CAGE:placenta, adult pregnant day17",CAGE,"placenta, adult pregnant day17"


In [7]:
# Convert the task table to a dict mapping column name -> list of values; this dict is what
# is assigned to lm.data_params["tasks"] before checkpoints are saved (presumably so it can be
# serialized with the checkpoint/W&B metadata). Note: from here on `tasks` is a dict, not a DataFrame.
tasks = tasks.to_dict(orient="list")

## Process intervals

In [8]:
# Header-less BED: chromosome, start, end, and a fold label (used to assign splits next).
intervals = pd.read_table(intervals_path, header=None).set_axis(
    ['chrom', 'start', 'end', 'fold'], axis=1
)
intervals.head()

Unnamed: 0,chrom,start,end,fold
0,chr1,46257174,46453782,fold0
1,chr2,83512641,83709249,fold0
2,chr7,16218353,16414961,fold0
3,chr3,113724419,113921027,fold0
4,chr3,107470140,107666748,fold0


In [9]:
# fold3 -> test, fold4 -> val, every other fold -> train.
intervals['split'] = intervals['fold'].map({'fold3': 'test', 'fold4': 'val'}).fillna('train')
intervals['split'].value_counts()

split
train    36950
val       6318
test      6101
Name: count, dtype: int64

In [10]:
# Sanity check on the raw interval length: the output shows 196608 bp (= 6144 bins * 32 bp,
# the label_len used when saving checkpoints). Presumably raises if lengths differ — confirm in grelu.
get_unique_length(intervals)

196608

In [11]:
# Resize every interval to 524288 bp, the model input length (the seq_len stored in data_params).
# How resize places the new window (e.g. centered) is not visible here — check grelu.sequence.utils.resize.
intervals = resize(intervals, 524288)

In [12]:
# One frame per split, keeping only the first three columns (presumably chrom/start/end after resize);
# the combined frame is then dropped to free memory.
train_intervals = intervals.loc[intervals['split'].eq('train')].iloc[:, :3]
val_intervals = intervals.loc[intervals['split'].eq('val')].iloc[:, :3]
test_intervals = intervals.loc[intervals['split'].eq('test')].iloc[:, :3]
del intervals

## Initialize model

In [13]:
# Borzoi architecture the mouse replicate weights are loaded into.
model_params = {
    "model_type": "BorzoiModel",
    # One output per row of the task table (2608 for mouse). Derived rather than hard-coded so the
    # head stays in sync with `tasks`; a mismatch with the weights then fails in load_state_dict.
    "n_tasks": len(tasks["name"]),
    "final_act_func": "softplus",
    "final_pool_func": None,
    # Appears to be output bins trimmed from each end: 524288 / 32 - 2 * 5120 = 6144,
    # matching the asserted output length — confirm against BorzoiModel.
    "crop_len": 5120,
}
train_params = {
    "task": "regression",
    "loss": "mse",
}
lm = LightningModel(model_params, train_params)

## Save checkpoints

In [14]:
# One-hot encode a dummy all-'A' sequence of the model input length (524288 bp); it is used to
# shape-check each replicate's output before saving. (convert_input_type is imported at the top.)
# Note: the name `input` shadows the builtin; it is kept because the checkpoint loop refers to it.
input = convert_input_type(['A' * 524288], "one_hot")
input.shape

torch.Size([1, 4, 524288])

In [15]:
# Scratch arithmetic: 6144 output bins * 32 bp per bin = 196608 bp, the label_len stored in data_params.
6144*32

196608

In [17]:
# Data parameters stored in every replicate checkpoint. They do not depend on the replicate,
# so they are built once. The intervals come from the mm10 BED file (intervals_path), so the
# genome is "mm10" (it was previously mis-set to "hg38").
N_BINS = 6144  # output bins per sequence
BIN_SIZE = 32  # bp per output bin
SEQ_LEN = 524288  # model input length in bp (same as the resize above)

lm.data_params["tasks"] = tasks
lm.data_params["train"] = {
    "seq_len": SEQ_LEN,
    "label_len": N_BINS * BIN_SIZE,
    "genome": "mm10",
    "bin_size": BIN_SIZE,
    "intervals": train_intervals.to_dict(orient="list"),
}
lm.data_params["val"] = {"intervals": val_intervals.to_dict(orient="list")}
lm.data_params["test"] = {"intervals": test_intervals.to_dict(orient="list")}

# (batch, tasks, bins) expected for the single dummy sequence.
expected_shape = (1, len(tasks["name"]), N_BINS)

for rep in range(4):
    # Converted PyTorch weights for replicate `rep`.
    state_dict = torch.load(f"/data/borzoi/torch_weights/mouse_fold{rep}.h5")
    lm.model.load_state_dict(state_dict)

    # Shape check only: no_grad avoids building an autograd graph for a 524288 bp forward pass.
    with torch.no_grad():
        assert lm(input).shape == expected_shape

    trainer = pl.Trainer()
    # trainer.predict is called without a dataloader, so it is expected to raise. It appears to be
    # used only to attach `lm` to the trainer so save_checkpoint has a model to save — TODO confirm.
    # Catch Exception rather than using a bare except, so KeyboardInterrupt/SystemExit still propagate.
    try:
        trainer.predict(lm)
    except Exception:
        pass

    trainer.save_checkpoint(f'/data/borzoi/torch_weights/mouse_rep{rep}.ckpt')

  state_dict = torch.load(f"/data/borzoi/torch_weights/mouse_fold{rep}.h5")
Trainer will use only 1 of 8 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=8)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A100-SXM4-80GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
  state_dict = torch.load(f"/data/borzoi/torch_weights/mouse_fold{rep}.h5")
Trainer will use only 1 of 8 GPUs 

## Upload to wandb

In [18]:
# Hyperparameters attached to every uploaded W&B model artifact.
metadata = dict(
    model_params=lm.model_params,
    train_params=lm.train_params,
    data_params=lm.data_params,
)

In [19]:
# Log each replicate checkpoint as its own W&B model artifact; all share the same metadata.
for rep in range(4):
    artifact = wandb.Artifact(f'mouse_rep{rep}', type='model', metadata=metadata)
    ckpt = f'/data/borzoi/torch_weights/mouse_rep{rep}.ckpt'
    artifact.add_file(ckpt)
    run.log_artifact(artifact)



In [20]:
# End the W&B run created by wandb.init at the top of the notebook.
run.finish()