In [1]:
import sys
sys.path.insert(0,'../../../')
from lib.data_processing import GenNLPMaskedDataset
from transformers import ElectraForMaskedLM, ElectraTokenizer, ElectraConfig, TrainingArguments
import pandas as pd
import numpy as np
from lib.utils import general as g
from lib.config.config_class import page_config
from lib.model.overwriter import OTrainingArguments, OTrainer
from lib.utils.metrics import evalpred_to_word, r2_score_transformers
import json
import os
from IPython.display import clear_output

# Config

In [2]:
# Load the experiment configuration from its JSON file.
config_path = '/client/user1/cuongdev/GenImputation/data/train/electra_G1K_22_hs37d5/config.json'
config = None
with g.reading(config_path) as cf:
    config = json.load(cf)
# json.load returns None for a file whose top-level value is `null`; fail early in that case.
assert config is not None, "config can't none"

In [3]:
# Region ids 1..12 and batch indices 0..8 used to resolve the training files.
regions = list(range(1, 13))
batchs = list(range(9))
# Training paths are resolved for every (region, batch) pair; test paths only for batch 0.
train_region_paths = page_config.get_file_paths(
    config[page_config.file_train_prefix],
    page_config.page,
    regions,
    batchs,
)
test_region_paths = page_config.get_file_paths(
    config[page_config.file_test_prefix],
    page_config.page,
    regions,
    [0],
)
# Vocabulary file for the tokenizer and the save-directory template (formatted with the region id later).
vocab_file = config[page_config.vocab_file]
save_dir = config[page_config.save_dir]

In [4]:
# Training arguments come straight from the config file.
training_args = OTrainingArguments(**config[page_config.train_args])
# Remember the un-formatted output/logging templates: the loops below overwrite
# training_args.output_dir / logging_dir with the per-region formatted value.
output_dir, logging_dir = training_args.output_dir, training_args.logging_dir
seed = training_args.seed
# Model configuration and tokenizer shared by every region.
modeling_args = ElectraConfig(**config[page_config.model_args])
tokenizer = ElectraTokenizer(vocab_file=vocab_file)

In [None]:
# Train one ELECTRA masked-LM per region, in order. For each region:
#   1. build train / eval / test datasets from that region's batch files,
#   2. start from the previous region's saved model when its directory exists,
#      otherwise from a freshly initialised model,
#   3. train, save the model, and write the test metrics next to it.
for i, region in enumerate(regions):
    # Replace the previous iteration's cell output instead of appending to it.
    clear_output(wait=True)
    save_path = save_dir.format(region)
    # Directory the previous region's model would have been saved to.
    prevert_path = save_dir.format(region-1)
    # Train and eval data: every batch but the last is used for training,
    # the last batch is held out for evaluation.
    train_batch_paths = train_region_paths[i]
    train_dataset = GenNLPMaskedDataset(
        train_batch_paths[:-1],
        tokenizer,
        seed=seed,
        masked_by_flag=True,
        # masked_per=0.15,
        only_input=True,
        force_create=True)
    eval_dataset = GenNLPMaskedDataset(train_batch_paths[-1:],tokenizer,seed=seed,masked_by_flag=True,only_input=True)
    # Test data for this region (test paths were resolved for batch 0 only).
    test_batch_paths = test_region_paths[i]
    test_dataset = GenNLPMaskedDataset(test_batch_paths,tokenizer,seed=seed,masked_by_flag=True,only_input=True)
    # Model: size the config from the tokenizer, then build a new model.
    modeling_args.vocab_size = tokenizer.vocab_size
    modeling_args.max_position_embeddings = 1300
    electra_model = ElectraForMaskedLM(modeling_args)
    # Warm start: if the previous region's model directory exists, the freshly
    # built model above is discarded and that saved model is loaded instead.
    if os.path.isdir(prevert_path):
        electra_model = ElectraForMaskedLM.from_pretrained(prevert_path)
    # Point checkpoints and logs at this region's directories
    # (output_dir / logging_dir hold the un-formatted templates).
    training_args.output_dir = output_dir.format(region)
    training_args.logging_dir = logging_dir.format(region)
    trainer = OTrainer(
        model = electra_model,
        args=training_args,
        train_dataset = train_dataset,
        eval_dataset = eval_dataset,
        compute_metrics = r2_score_transformers,
    )
    trainer.train()
    trainer.save_model(save_path)
    # Evaluate on the test set and store the metrics as JSON inside save_path.
    output_test = trainer.predict(test_dataset)
    metrics = output_test.metrics
    test_result_path = os.path.join(save_path,'test_result.json')
    with g.writing(test_result_path) as trf:
        json.dump(metrics,trf)

# Run predictions and show the r2 plot

In [None]:
# Reload every region's saved model, run prediction on that region's test set
# and collect the labels / predicted words per region.
true_data = []
pred_data = []
for i, region in enumerate(regions):
    # Replace the previous iteration's cell output instead of appending to it.
    clear_output(wait=True)
    # Train and eval data: every batch but the last for training, the last for eval.
    train_batch_paths = train_region_paths[i]
    train_dataset = GenNLPMaskedDataset(
        train_batch_paths[:-1],
        tokenizer,
        seed=seed,
        masked_by_flag=True,
        # masked_per=0.15,
        only_input=True,
        force_create=True)
    eval_dataset = GenNLPMaskedDataset(train_batch_paths[-1:],tokenizer,seed=seed,masked_by_flag=True,only_input=True)
    # Test data for this region.
    test_batch_paths = test_region_paths[i]
    test_dataset = GenNLPMaskedDataset(test_batch_paths,tokenizer,seed=seed,masked_by_flag=True,only_input=True,force_create=True)
    # NOTE(review): the model below is loaded with from_pretrained(save_path)
    # and modeling_args is not passed to it, so these two assignments do not
    # feed the loaded model. The training cell uses 1300 here, not 2000.
    modeling_args.vocab_size = tokenizer.vocab_size
    modeling_args.max_position_embeddings = 2000
    save_path = save_dir.format(region)
    # Load the model saved for this region by the training cell.
    electra_model = ElectraForMaskedLM.from_pretrained(save_path)
    training_args.output_dir = output_dir.format(region)
    training_args.logging_dir = logging_dir.format(region)
    # `Trainer` is never imported in this notebook (NameError); use OTrainer,
    # the trainer class imported at the top and used by the training cell.
    trainer = OTrainer(
        model = electra_model,
        args=training_args,
        train_dataset = train_dataset,
        eval_dataset = eval_dataset,
        compute_metrics = r2_score_transformers,
    )
    output_test = trainer.predict(test_dataset)
    # `logits_to_word` is not defined or imported anywhere in this notebook;
    # `evalpred_to_word` is the helper imported from lib.utils.metrics.
    # NOTE(review): confirm it returns (labels, predicted words) in this order.
    labels, top_word = evalpred_to_word(output_test)
    true_data.append(labels)
    pred_data.append(top_word)

In [None]:
# List the variant files found in the corpus directory and sort them in place
# so they can be indexed consistently later.
corpus_dir = '/client/user1/cuongdev/GenImputation/data/train/electra_G1K_22_hs37d5/corpus_dir/'
variant_ids = page_config.get_file_paths_in_dir(corpus_dir, page_config.variant)
variant_ids.sort()

In [None]:
import pandas as pd

In [None]:
# Read the variant table of every region and stack them into a single frame.
# NOTE(review): variant_ids is indexed with the region id (1..12), not the loop
# position, so variant_ids[0] is never read — confirm this is intended.
region_frames = [
    pd.read_csv(variant_ids[region], sep=page_config.page_split_params)
    for region in regions
]
# Concatenate once instead of re-concatenating the growing frame on every
# iteration (which copies all accumulated rows each time). df_origin stays
# None when there are no regions, as before.
df_origin = pd.concat(region_frames) if region_frames else None

In [None]:
# Keep every row but drop the first and last column of each per-region array.
# NOTE(review): presumably these are the tokenizer's special-token positions — confirm.
# NOTE: this cell overwrites its own inputs, so re-running it trims one more
# column from each side.
inner_columns = (slice(None), slice(1, -1))
true_data = [region_labels[inner_columns] for region_labels in true_data]
pred_data = [region_preds[inner_columns] for region_preds in pred_data]

In [None]:
# Join the per-region arrays along axis 1, then display both shapes as a sanity check.
y_true, y_pred = (
    np.concatenate(region_arrays, axis=1)
    for region_arrays in (true_data, pred_data)
)
(y_true.shape, y_pred.shape)

In [None]:
from lib.data_processing import process_ouput as po

In [None]:
# Boolean mask selecting the variants whose 'flag' column equals 0.
masked_indexs = (df_origin['flag'] == 0).values

In [None]:
# Fold each allele frequency into a minor-allele frequency: keep AF when it is
# <= 0.5, otherwise use 1 - AF. Then keep only the rows selected by masked_indexs.
allele_freqs = df_origin['AF'].values
mafs = np.where(allele_freqs <= 0.5, allele_freqs, 1 - allele_freqs)
mafs = mafs[masked_indexs]

In [None]:
# Transpose so the first axis can be indexed with the variant mask, keep only
# the masked variants, and plot r2 against MAF for the single prediction set.
masked_true = y_true.T[masked_indexs]
masked_pred = y_pred.T[masked_indexs]
po.plot_r2_by_maf(mafs, masked_true, [masked_pred])

In [None]:
# Load the space-separated, header-less .gen file for every region.
# The file number is region + 1.
paper_format = '/client/user1/cuongdev/GenImputation/temp/chr22_{}.gen'
paper_data = []
for region in regions:
    region_table = (
        pd.read_csv(paper_format.format(region + 1), sep=' ', header=None)
        # Column 0 is discarded; columns 1-4 become the variant key columns.
        .drop(columns=[0])
        .rename(columns={1:'CHROM',2:'POS',3:'REF',4:'ALT'})
    )
    # Overwrite CHROM with the constant 22 for every row.
    region_table['CHROM'] = np.full(region_table.shape[0], 22)
    paper_data.append(region_table)

In [None]:
# Stack the per-region .gen tables row-wise into one frame.
paper_r10 = pd.concat(objs=paper_data)

In [None]:
# Key columns that identify a variant in both tables.
cols = [
    'CHROM',
    'POS',
    'REF',
    'ALT',
]

In [None]:
# Variants present in both tables: inner join on the variant key columns.
df_origin[cols].merge(paper_r10, how='inner', on=cols)

In [None]:
# Show the dtypes of the key columns in df_origin (to compare with paper_r10's).
df_origin.loc[:, cols].dtypes

In [None]:
# Show the dtypes of the key columns in paper_r10 (to compare with df_origin's).
paper_r10.loc[:, cols].dtypes

In [None]:
# Display the concatenated variant table for visual inspection.
df_origin

In [None]:
import pandas as pd

In [None]:
# Shape of the batch-0 variant table for region indices 0..19, read with
# read_csv's default separator.
variant_template = '/client/user1/cuongdev/GenImputation/data/train/electra_G1K_22_hs37d5/corpus_dir/G1K_22_hs37d5_biallelic_train.r00{:02d}.b0000.variant.gz'
[pd.read_csv(variant_template.format(region_idx)).shape for region_idx in range(20)]

In [None]:
# Inspect the external region-0 variant file.
external_variant_path = '/client/user1/cuongdev/GenImputation/data/external/region_info/rnnimp.chr22.r0000.variant.gz'
pd.read_csv(external_variant_path)

In [None]:
# Inspect the training corpus' variant file for region 0, batch 0.
train_variant_path = '/client/user1/cuongdev/GenImputation/data/train/electra_G1K_22_hs37d5/corpus_dir/G1K_22_hs37d5_biallelic_train.r0000.b0000.variant.gz'
pd.read_csv(train_variant_path)