# Script to run the simulation experiment with multiple embedded motifs (10 by default).
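# Example invocation (options are defined below; the output path is illustrative):
#   python multiple.py --n-motif 10 --n-seq 10000 --epochs 1000 --save-dir ./out/simulation/multiple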
import logging
logger = logging.getLogger(__name__)
import click
from pathlib import Path
import torch
from torch import optim
from raptgen import models
from raptgen.models import CNN_Mul_VAE, LSTM_Mul_VAE, CNNLSTM_Mul_VAE
from raptgen.models import CNN_AR_VAE, LSTM_AR_VAE, CNNLSTM_AR_VAE
from raptgen.models import CNN_PHMM_VAE, LSTM_PHMM_VAE, CNNLSTM_PHMM_VAE
from raptgen.data import SequenceGenerator, SingleRound
dir_path = Path(__file__).resolve().parent
default_path = str((dir_path / "../out/simulation/multiple").resolve())
@click.command(help='run experiment with multiple motifs', context_settings=dict(show_default=True))
@click.option("--n-motif", help="the number of motifs to embed", type=int, default=10)
@click.option("--n-seq", help="the number of sequences to generate", type=int, default=10000)
@click.option("--seed", help="random seed for reproducible sequence generation", type=int, default=0)
@click.option("--error-rate", help="the rate at which generated sequences are mutated", type=float, default=0.1)
@click.option("--epochs", help="the maximum number of training epochs", type=int, default=1000)
@click.option("--threshold", help="stop training after this many epochs without improvement in loss", type=int, default=50)
@click.option("--use-cuda/--no-cuda", help="use CUDA if available", default=True)
@click.option("--cuda-id", help="the id of the CUDA device to run on", type=int, default=0)
@click.option("--save-dir", help="path to save results", type=click.Path(), default=default_path)
@click.option("--reg-epochs", help="the number of epochs to apply state-transition regularization", type=int, default=50)
@click.option("--multi", help="the number of times to train each model", type=int, default=1)
@click.option("--only-cnn/--all-models", help="train only the CNN encoder variants instead of all encoder types", default=False)
def main(n_motif, n_seq, seed, error_rate, epochs, threshold, cuda_id, use_cuda, save_dir, reg_epochs, multi, only_cnn):
logger.info(f"saving to {save_dir}")
save_dir = Path(save_dir).expanduser()
    save_dir.mkdir(exist_ok=True, parents=True)
# generate sequences
fwd_adapter = "AAAAA"
rev_adapter = "GGGGG"
    generator = SequenceGenerator(
        num_motifs=n_motif,
        seed=seed,
        fix_random_region_length=True,
        error_rate=error_rate,
        generate_motifs=True,
        add_primer=True,
        forward_primer=fwd_adapter,
        reverse_primer=rev_adapter,
        paired=False)
reads, motif_indices = generator.sample(n_seq)
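    # persist the generated reads and their motif assignments for later inspection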
with open(save_dir/"seqences.txt","w") as f:
for index, read in zip(motif_indices, reads):
f.write(f"{index}, {read}\n")
with open(save_dir/"motifs.txt","w") as f:
for motif in generator.motifs:
f.write(f"{motif}\n")
    experiment = SingleRound(
        reads,
        forward_adapter=fwd_adapter,
        reverse_adapter=rev_adapter)
# training
train_loader, test_loader = experiment.get_dataloader()
device = torch.device(f"cuda:{cuda_id}" if (use_cuda and torch.cuda.is_available()) else "cpu")
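    # shared keyword arguments for raptgen.models.train; force_matching applies the
    # state-transition regularization for the first reg_epochs epochs, and
    # beta_schedule is assumed to anneal the KL weight during training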
    train_kwargs = {
        "epochs": epochs,
        "threshold": threshold,
        "device": device,
        "train_loader": train_loader,
        "test_loader": test_loader,
        "save_dir": save_dir,
        "beta_schedule": True,
        "force_matching": True,
        "force_epochs": reg_epochs,
    }
    # instantiate and train the models
target_len = experiment.random_region_length
results = dict()
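    # train each model `multi` times; repeated runs get an index suffix in the filename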
for i in range(multi):
        eval_models = [
            CNN_Mul_VAE(target_len=target_len, embed_size=2),
            CNN_AR_VAE(embed_size=2),
            CNN_PHMM_VAE(motif_len=target_len, embed_size=2),
        ]
        if not only_cnn:
            eval_models.extend([
                LSTM_Mul_VAE(target_len=target_len, embed_size=2),
                LSTM_AR_VAE(embed_size=2),
                LSTM_PHMM_VAE(motif_len=target_len, embed_size=2),
                CNNLSTM_Mul_VAE(target_len=target_len, embed_size=2),
                CNNLSTM_AR_VAE(embed_size=2),
                CNNLSTM_PHMM_VAE(motif_len=target_len, embed_size=2),
            ])
for model in eval_models:
            model_str = type(model).__name__.lower()
if multi > 1:
model_str += f"_{i}"
model_str += ".mdl"
print (f"training {model_str}")
optimizer = optim.Adam(model.parameters())
model = model.to(device)
            train_kwargs.update({
                "model": model,
                "model_str": model_str,
                "optimizer": optimizer})
results[model_str] = models.train(**train_kwargs)
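            # release cached GPU memory between training runs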
torch.cuda.empty_cache()
if __name__ == "__main__":
Path("./.log").mkdir(parents=True, exist_ok=True)
    log_format = '%(levelname)s : %(name)s : %(asctime)s : %(message)s'
    logging.basicConfig(
        filename='.log/logger.log',
        level=logging.DEBUG,
        format=log_format)
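    # also echo INFO-level messages to the console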
console = logging.StreamHandler()
console.setLevel(logging.INFO)
logging.getLogger('').addHandler(console)
main()