## 1. Pick the AG results to be plotted (set `DATA` in the configuration cell below)

### Specify:
- datasets that should be loaded, and the label you want to assign to them (e.g. GAN, GAN_versionXX, ...)
- output directory name (will be created if needed)
- SNP position file(s)  
- number of individuals to keep from each dataset (randomly subsampled if this is smaller than the total number)

## 2. Imports and general color dictionary

In [None]:
import seaborn as sns
import pandas as pd
import numpy as np
import importlib
import os
from src.short import plot_utils as plu
# Folder containing the plotfig_utils_*.ipynb helper notebooks executed with %run below.
dirscript = "src/short"

In [None]:
# Configuration: pick the dataset size with DATA. Figure paths are built from
# repoDIR and input haplotype files from samplesDIR.
repoDIR = "./"
samplesDIR = "./data1kg"
DATA = "10K"
# DATA = "805"
nsub = 5000  # individuals kept per dataset (randomly subsampled if smaller than the total)

if DATA == "10K":
    outDir = repoDIR + "FIGS/10K/"
    infiles = {'Real': samplesDIR + "/1000G_real_genomes/10K_SNP_1000G_real.hapt.test",
               'Truth': samplesDIR + "/1000G_real_genomes/10K_SNP_1000G_real.hapt.train",
               'GAN': samplesDIR + "/GAN_AGs/10K_SNP_GAN_AG_10800Epochs.hapt.zip",
               'RBM': samplesDIR + "/RBM_AGs/10K_SNP_RBM_AG_1050epochs.hapt.zip",
               'HCLT': samplesDIR + "/HCLT_AGs/" + "10K_SNP_HCLT_AG.hapt",
               }
    # NOTE(review): lowercase "10k" here, unlike the other 10K file names --
    # confirm it matches the legend file on disk.
    realposfname = samplesDIR + "/1000G_real_genomes/10k_SNP.legend"

elif DATA == "805":
    outDir = repoDIR + "FIGS/805/"
    infiles = {'Real': samplesDIR + "/1000G_real_genomes/805_SNP_1000G_real.hapt.test",
               'Truth': samplesDIR + "/1000G_real_genomes/805_SNP_1000G_real.hapt.train",
               'GAN': samplesDIR + "/GAN_AGs/805_SNP_GAN_AG_20000epochs.hapt.zip",
               'RBM': samplesDIR + "/RBM_AGs/805_SNP_RBM_AG_800Epochs.hapt.zip",
               'HCLT': samplesDIR + "/HCLT_AGs/" + "805_SNP_HCLT_AG.hapt",
               }
    realposfname = samplesDIR + "/1000G_real_genomes/805_SNP.legend"

else:
    # Fail fast: otherwise outDir/infiles/realposfname stay undefined and the
    # problem only shows up later as a confusing NameError in another cell.
    raise ValueError(f"Unknown DATA={DATA!r}; expected '10K' or '805'")

# Same SNP positions for all datasets, so the file is just repeated for all keys.
position_fname = {key: realposfname for key in infiles}

print("- Datasets under study:\n", infiles)

In [None]:
# General colors
allcolpal = dict({'Real':"#95a5a6",
                  'Truth': "#95a5a6",
                  'GAN':"#3498db", 
                  'RBM':"#e74c3c", 
                #   'Indep': "#2ecc71",
                #   'Markov': "#a6761d",
                #   'HMM': 'gold',
                #   'Strudel': 'lightpink',
                  'HCLT': '#6a3d9a'
              })

# Update current color palette to the dataset type in infiles 
colpal =  {key:allcolpal[key] for key in infiles.keys()}
sns.set_palette(colpal.values())

sns.palplot(sns.color_palette())
print(f"- Output Directory for figures: {outDir}\n",
      f"- Real dataset positions: {realposfname}\n",
      f"- Sample size:{nsub}")


## 3. Run the notebooks to plot all figures, or only a subset of summary statistics (for faster results)

In [None]:
# Bare expression: the notebook displays it as the cell output (outDir as set by the configuration cell).
"Figures will be saved in {} or its subdirectories".format(outDir)

In [None]:
## Print one more time the names of the datasets that will be loaded.
## Every path should exist; otherwise check that your setup is correct.
_missing_inputs = []
for _fname in infiles.values():
    _found = os.path.exists(_fname)
    print(f"- Input file {_fname} exists: {_found}")
    if not _found:
        _missing_inputs.append(_fname)

# Summarize so a broken setup is not lost among the per-file lines above.
if _missing_inputs:
    print(f"WARNING: {len(_missing_inputs)} input file(s) not found: {_missing_inputs}")

In [None]:
# Setup options (transformations, sumstats to compute, etc.) and the output
# directory, derived from the outDir set in the configuration cell.

importlib.reload(plu)  # useful only if plot_utils changed since it was imported (dev)
boolComputeAATS = True  # if False, notebook 5 reloads previously computed AATS instead of computing it
figwi = 12  # controls the size of some figures

# set allchecks to False for a first rapid scan
# set to True for computing/plotting all sumstats and scores (long, better on a cluster)
allchecks = False

# pick the transformations you want to apply to the datasets
# For no transformation choose
# transformations=None
transformations = {'to_minor_encoding': False, 'min_af': 0, 'max_af': 1}

if transformations is not None:
    tname = ';'.join(f'{k}-{v}' for k, v in transformations.items())
else:
    tname = 'none'
tname = tname + ';allchecks-' + str(int(allchecks)) + ';n-' + str(nsub)


def _strip_settings_subdir(path):
    """Return *path* without a trailing settings subdirectory added by this cell.

    Every tname built above contains ';allchecks-', so a last path component
    containing it was appended by a previous run of this cell. Stripping it lets
    the cell be re-run (even with changed settings) without nesting one
    settings directory inside another (outDir/tname/tname/...).
    """
    parent, last = os.path.split(path.rstrip('/'))
    return parent + '/' if ';allchecks-' in last else path


outDir = os.path.join(_strip_settings_subdir(outDir), tname + '/')
print(f"- Figures will be saved in {outDir}")
if os.path.exists(outDir):
    print('    - This directory exists, the following files might be overwritten:')
    print('    -', os.listdir(outDir))

### Compute summary statistics
**You can pick which notebooks to execute** (and comment out the other lines).  
Only the **first one is mandatory** (plotfig_utils_1_INIT.ipynb).  
It loads the datasets, applies basic transformations if requested, and initializes a few variables (such as a dictionary of haplotypes, allele counts, fixed-site vectors, etc.).

In [None]:
# Execute the helper notebooks found in {dirscript}. Only the INIT notebook is
# required; the AF, PCA and LD ones can be commented out to save time.
# NOTE(review): the helpers presumably rely on variables defined in the cells
# above (infiles, outDir, transformations, nsub, ...) -- confirm.
# NOTE(review): `-p` asks %run to profile the run; check whether profiling is
# wanted (IPython may ignore it for .ipynb files). The trailing `# ...` text on
# a magic line may be passed to %run as arguments, so it is left untouched.
%run -p {dirscript}/plotfig_utils_1_INIT.ipynb  # mandatory, all lines below are optional
%run -p {dirscript}/plotfig_utils_2_AF.ipynb 
%run -p {dirscript}/plotfig_utils_3_PCA.ipynb
%run -p {dirscript}/plotfig_utils_4_LD.ipynb

In [None]:
DIST = True
AATS = True
boolComputeAATS = True
%run -p {dirscript}/plotfig_utils_5_DIST_AATS.ipynb # computationnally long

In [None]:
# Optional, slow check run only when allchecks is True (set in the setup-options
# cell); per its file name, the helper notebook presumably computes 3-point
# correlations -- confirm.
if allchecks:
    %run -p {dirscript}/plotfig_utils_6_3pointcorr.ipynb  # computationnally long