# Set Up Data for GAPIT

In [None]:
# Data ----
from dataG2F.core import get_data
from dataG2F.qol  import ensure_dir_path_exists

# Data Utilities ----
import numpy  as np
import pandas as pd

# Model Building  ----
## General ====
# import torch
# from   torch import nn
# import torch.nn.functional as F
# from   torch.utils.data import Dataset
# from   torch.utils.data import DataLoader

# from vnnpaper.zma import \
#     BigDataset,    \
#     plDNN_general, \
#     mask_parents,  \
#     vnn_factory_1, \
#     vnn_factory_2, \
#     vnn_factory_3

import os 

In [None]:
# torch.set_float32_matmul_precision('medium')

In [None]:
# init_notebook_plotting()

## Setup

In [None]:
# Working directory for this notebook: the filtered HapMap files are read from here and the
# position lists / phenotype files produced below are written here. Keep the trailing '/':
# later cells build paths by plain string concatenation (cache_path + filename).
cache_path = '../nbs_artifacts/zma_g2f_individual_gapit/genotypes_holding/'

In [None]:
## Settings ====
# Data settings: the traits (response variables) for which phenotype files are prepared below.
params_data = dict(
    y_var=[
        # Description quoted from competition data readme
        'Yield_Mg_ha',     # Grain yield in Mg per ha at 15.5% grain moisture, using plot area without alley (Mg/ha).
        'Pollen_DAP_days', # Number of days after planting that 50% of plants in the plot began shedding pollen.
        'Silk_DAP_days',   # Number of days after planting that 50% of plants in the plot had visible silks.
        'Plant_Height_cm', # Measured as the distance between the base of a plant and the ligule of the flag leaf (centimeter).
        'Ear_Height_cm',   # Measured as the distance from the ground to the primary ear bearing node (centimeter).
        'Grain_Moisture',  # Water content in grain at harvest (percentage).
        'Twt_kg_m3',       # Shelled grain test weight (kg/m3), a measure of grain density.
    ],
)

In [None]:
# Last non-empty component of cache_path (e.g. 'genotypes_holding'); not used by the cells below.
save_prefix = list(filter(None, cache_path.split('/')))[-1]

# Residual-strategy suffix is disabled (params_data currently has no 'y_resid_strat' key):
# if 'None' != params_data['y_resid_strat']:
#     save_prefix = save_prefix+'_'+params_data['y_resid_strat']

# Create the working directory up front; later cells list and write files inside it.
ensure_dir_path_exists(dir_path = cache_path)

## Load Data

In [None]:
# Data Prep ----
# obs_geno_lookup          = get_data('obs_geno_lookup')
# Phenotype table from dataG2F; later cells use its Hybrid, Year and Env columns plus the
# trait columns listed in params_data['y_var'].
phno                     = get_data('phno')

In [None]:
# Build the position list TASSEL uses to subset the full genotype file to ACGT gene sites.
# (An earlier version derived positions by parsing KEGG gene entries; this version writes out
# exactly the site names stored in dataG2F.)
# Only needed while the filtered HapMap file does not yet exist in cache_path.
if '5_Genotype_Data_All_Years_acgt.hmp.txt' not in os.listdir(cache_path):
    ACGT_gene_slice_loci = get_data('KEGG_slices_names') # 'KEGG_slices_names': 'ACGT_gene_site_name_list.pkl',
    # Flatten the list of per-slice site-name lists (sum(..., []) copies the list on every step).
    acgt_site_names = [site for gene_slice in ACGT_gene_slice_loci for site in gene_slice]

    # cache_path already ends in 'genotypes_holding/'; appending that directory again pointed at a
    # nested directory that is never created, so open() failed.
    with open(cache_path + 'zma_filter_pos_acgt.txt', 'w') as f:
        # Each site name has every 'S' removed and is split on '_' into tab-separated fields
        # (presumably 'S<chrom>_<pos>' -> '<chrom>\t<pos>'; confirm against dataG2F).
        f.writelines('\t'.join(e.replace('S', '').split('_')) + '\n' for e in acgt_site_names)

    # After this file is written, use TASSEL to filter
    # dataG2F/data_ext/zma/g2fc/Training_Data/5_Genotype_Data_All_Years.hmp.txt
    # down to these positions, saving it as 5_Genotype_Data_All_Years_acgt.hmp.txt in cache_path.
    # dataG2F/data_ext/zma/g2fc/Training_Data/5_Genotype_Data_All_Years.hmp.txt

In [None]:
# Make a reduced set of SNPs for GWAS, starting from the ACGT-filtered HapMap file
# (produced with TASSEL from zma_filter_pos_acgt.txt, see the previous cell).
hmp_path = cache_path+'5_Genotype_Data_All_Years_acgt.hmp.txt'
# Fail early with an actionable message if the TASSEL filtering step has not been run yet.
if not os.path.isfile(hmp_path):
    raise FileNotFoundError(
        f'{hmp_path} not found; filter the full genotype file with TASSEL using zma_filter_pos_acgt.txt first.'
    )

In [None]:
# Read chromosome and position (HapMap columns 3 and 4) for every SNP, skipping the header row.
# Lines are streamed, and maxsplit=4 avoids splitting all genotype calls on each line just to
# read two fields.
with open(hmp_path, 'r') as f:
    next(f, None)  # header row (rs#, alleles, chrom, pos, ..., taxa)
    dat = pd.DataFrame([line.split('\t', 4)[2:4] for line in f], columns=['chrom', 'pos'])
dat['chrom'] = dat['chrom'].astype(int)
dat['pos']   = dat['pos'].astype(int)

In [None]:
def _dist_to_snp(dat0):
    cols = list(dat0)
    dat0 = np.asarray(dat0)

    # diff to next
    dist_next = dat0[:, 1] - np.concatenate([np.asarray(np.nan)[None], dat0[:-1, 1]])
    dist_prev = np.concatenate([dat0[1:, 1],  np.asarray(np.nan)[None]]) - dat0[:, 1]

    # Looking for the maximum distance so that we perfer the edges of a gene instead of the center
    # Consider:
    # 
    # a            b b b           c     d
    #
    # We'll keep the obs on the edge. Decreasing the number of obs to use, we would discard all the bs before c even though the 
    # b gene is further away from a and c than c is from d. b snps are close to themselves so b will appear to be close.

    maxdist = pd.DataFrame(
        np.concatenate([dist_next[:, None], dist_prev[:, None]], axis=1).max(axis=1), columns=['Dist']
    )

    return pd.DataFrame(np.concatenate([dat0, maxdist], axis = 1), columns=cols+['dist']) 

# _dist_to_snp(dat0 = dat.loc[mask, ])

In [None]:
# Keep up to `n_snps` SNPs, preferring those with the largest gap to a neighbouring SNP
# (scored per chromosome by _dist_to_snp), then restore genomic order.
n_snps = 30000

x = pd.concat([_dist_to_snp(dat0 = dat.loc[(dat.chrom == chrom), ]) for chrom in dat.chrom.unique()])
# The first/last SNP of each chromosome has only one neighbour (dist is NaN); score it with the
# largest observed gap so chromosome ends rank first.
x['dist'] = x['dist'].fillna(x['dist'].max())
x = (
    x.sort_values('dist', ascending=False)
     .head(n_snps)
     .sort_values(['chrom', 'pos'])
     .reset_index(drop=True)
)

# 'dist' is the LARGER of the two neighbour gaps measured in the unfiltered panel, so it bounds
# the gap to one neighbour only; a retained SNP's nearest neighbour may be much closer.
print(f'Using {len(x)} SNPs; each has a gap of at least {int(x.dist.min())} to one of its neighbours in the unfiltered panel')

Using 30000 SNPS, all SNPS are at least 1782 from the next closest


In [None]:
# Write the retained SNPs as a chromosome<TAB>position list for TASSEL filtering.
# The file is opened once in 'w' mode, so re-running this cell rewrites it instead of appending
# a duplicate copy of every position (the previous per-row 'a' mode did that).
with open(cache_path+f'zma_filter_pos_{n_snps}.txt', 'w') as f:
    for chrom, pos in zip(x['chrom'], x['pos']):
        f.write(f'{int(chrom)}\t{int(pos)}\n')

In [None]:
# get taxa: HapMap header fields after the first 11 metadata columns are the genotyped taxa.
# NOTE(review): this reads the '_vnn' file while the cells above use '_acgt'; both are assumed
# to hold the same taxa (filtered from the same full file) — confirm.
hmp_path = cache_path+'5_Genotype_Data_All_Years_vnn.hmp.txt'
with open(hmp_path, 'r') as f:
    header_line = f.readline()  # separate name so the SNP position frame `dat` is not clobbered
# Strip the line terminator; otherwise the last taxon keeps a trailing '\n' and never matches phno.Hybrid.
taxa = header_line.rstrip('\r\n').split('\t')[11:]
mask = phno.Hybrid.isin(taxa)

In [None]:
# Create phenotype data for all y vars
def _summarise_phenotype(obs, y, eb):
    """Collapse observations of trait `y` to one value per hybrid (small enough for BLINK/TASSEL).

    obs: DataFrame with columns Hybrid, Year, Env and `y`.
    eb:  environmental-residual flag.
         False -> average within Env, then within Year, then per Hybrid (a mean of means).
         True  -> subtract each (Year, Env) mean from `y`, then average the residuals per Hybrid.

    Returns a DataFrame with columns ['Taxa', y].
    """
    if not eb:
        hybrid_means = (
            obs.groupby(['Hybrid', 'Year', 'Env']).agg('mean').reset_index().drop(columns='Env')
               .groupby(['Hybrid', 'Year']).agg('mean').reset_index().drop(columns='Year')
               .groupby(['Hybrid']).agg('mean').reset_index()
        )
    else:
        env_means = (
            obs.drop(columns='Hybrid')
               .groupby(['Year', 'Env']).agg('mean').reset_index()
               .rename(columns={y: 'env_mean'})
        )
        resid = obs.merge(env_means)  # joins on the shared Year/Env columns
        resid[y] = resid[y] - resid['env_mean']
        hybrid_means = resid.loc[:, ['Hybrid', y]].groupby(['Hybrid']).agg('mean').reset_index()
    return hybrid_means.rename(columns={'Hybrid': 'Taxa'})


def _write_tassel_phenotype(pheno, out_path):
    """Write a (Taxa, trait) frame as a TASSEL phenotype file that looks like:

    <Phenotype>
    taxa    data
    Taxa    YVarName
    CML61   0.9
    CI31A   0.8
    """
    with open(out_path, 'w') as f:
        f.write("<Phenotype>\ntaxa\tdata\n")
        f.write('\t'.join(pheno.columns)+'\n')
        for taxon, value in pheno.itertuples(index=False, name=None):
            # Python renders a missing mean as 'nan'; write 'NaN' instead, the spelling Java
            # parsers accept. NOTE(review): confirm against TASSEL's missing-value convention.
            f.write(f"{taxon}\t{'NaN' if pd.isna(value) else value}\n")


for y in params_data['y_var']:
    for eb in (False, True):  # eb is the environmental residual bool
        obs = phno.loc[mask, ['Hybrid', 'Year', 'Env', y]].reset_index(drop = True)
        pheno = _summarise_phenotype(obs, y, eb)
        out_path = cache_path + 'phno' + ('_eres' if eb else '') + f'_{y}.txt'
        _write_tassel_phenotype(pheno, out_path)