In [1]:
# magics
%load_ext autoreload
%autoreload 2

# Imports

In [2]:
# Regular Imports
import os, sys, configparser
import time


In [3]:
# Scientific imports
from dask.distributed import Client, LocalCluster
from dask_jobqueue import SGECluster


  defaults = yaml.load(f)


In [4]:
# Multi Dataset Crystalography imports
import multi_dataset_crystalography as mdc # import MultiCrystalDataset
from multi_dataset_crystalography.utils import DefaultPanDDADataloader


In [5]:
# PanDDA imports
from pandda_analyse.config import PanDDAConfig
from pandda_analyse.event_model import PanDDAEventModel
# from pandda_analyse.processor import ProcessModelSeriel
from pandda_analyse.event_model_distributed import PanDDAEventModelDistributed, load, fit, evaluate, criticise



# Args

In [6]:
# Placeholder for run arguments; nothing is supplied yet.
arguments = None

# Start with an empty parser; the file at config_path is read into it in a later cell.
config = configparser.ConfigParser()
# NOTE(review): absolute, site-specific path; consider making it configurable.
config_path = "/dls/science/groups/i04-1/conor_dev/pandda/lib-python/pandda/pandda_analyse/pandda_analyse/analyse_config.ini"


# Config

In [7]:
config.read(config_path)

['/dls/science/groups/i04-1/conor_dev/pandda/lib-python/pandda/pandda_analyse/pandda_analyse/analyse_config.ini']

In [8]:
pandda_config = PanDDAConfig(config)  # Maps options to code abstractions

# Dataset

In [9]:
# Build the multi-crystal dataset from the two loaders exposed by the config object
pandda_dataset = mdc.dataset.dataset.MultiCrystalDataset(
    dataloader=pandda_config.dataloader,
    sample_loader=pandda_config.sample_loader,
)

# Get reference Model

In [10]:
# Get reference model
reference = pandda_config.get_reference(pandda_dataset.datasets)
pandda_dataset.sample_loader.reference = reference

# Apply transforms to dataset

In [11]:
# Apply dataset transforms
# Start from the untransformed dataset so that `dataset` is always bound for the
# cells below, even when no "data_check" transform is configured.
dataset = pandda_dataset
if "data_check" in pandda_config.dataset_transforms:
    transform = pandda_config.dataset_transforms["data_check"]
    dataset = transform(pandda_dataset, reference)
    # Report what the transform logged: a header with its name, then each block
    print("###### {} ######".format(transform.name))
    for block, record in transform.log().items():
        print("# # {} # #".format(block))
        print(record)

###### PanddaDataChecker ######
# # Rejected datasets # #
Datasets rejected: 
PDK2-x0279: rejected - rmsd to reference
PDK2-x0251: rejected - rmsd to reference
PDK2-x0107: rejected - rmsd to reference
PDK2-x0238: rejected - rmsd to reference
PDK2-x0878: rejected - rmsd to reference



In [None]:
# Optional "scale_diffraction" stage: runs only when the config registers it.
stage = "scale_diffraction"
if stage in pandda_config.dataset_transforms:
    transform = pandda_config.dataset_transforms[stage]
    dataset = transform(dataset, reference)
    print("###### {} ######".format(transform.name))
    # Echo every (block, record) pair the transform logged
    for title, entry in transform.log().items():
        print("# # {} # #".format(title))
        print(entry)

In [None]:
# Optional "filter_structure" stage: runs only when the config registers it.
stage = "filter_structure"
if stage in pandda_config.dataset_transforms:
    transform = pandda_config.dataset_transforms[stage]
    dataset = transform(dataset, reference)
    print("###### {} ######".format(transform.name))
    # Echo every (block, record) pair the transform logged
    for title, entry in transform.log().items():
        print("# # {} # #".format(title))
        print(entry)

In [None]:
# Optional "filter_wilson" stage: runs only when the config registers it.
stage = "filter_wilson"
if stage in pandda_config.dataset_transforms:
    transform = pandda_config.dataset_transforms[stage]
    dataset = transform(dataset, reference)
    print("###### {} ######".format(transform.name))
    # Echo every (block, record) pair the transform logged
    for title, entry in transform.log().items():
        print("# # {} # #".format(title))
        print(entry)

In [None]:
# Optional "align" stage: runs only when the config registers it.
stage = "align"
if stage in pandda_config.dataset_transforms:
    transform = pandda_config.dataset_transforms[stage]
    dataset = transform(dataset, reference)
    print("###### {} ######".format(transform.name))
    # Echo every (block, record) pair the transform logged
    for title, entry in transform.log().items():
        print("# # {} # #".format(title))
        print(entry)

# Setup Output

In [None]:
# Calling the configured output object on the dataset yields the `tree` that is
# handed to the models below.
tree = pandda_config.pandda_output(dataset)

# Echo whatever the output step logged, one titled section per block
for title, entry in pandda_config.pandda_output.log().items():
    print("####### {} ########".format(title))
    print(entry)

# Model

In [None]:
# Assemble the event model from the components supplied by the config object
pandda_event_model = PanDDAEventModel(
    pandda_config.statistical_model,
    pandda_config.clusterer,
    pandda_config.event_finder,
    bdc_calculator=pandda_config.bdc_calculator,
    statistics=[],
    map_maker=pandda_config.map_maker,
    event_table_maker=pandda_config.event_table_maker,
    # NOTE(review): configparser returns strings; confirm a str is acceptable for cpus
    cpus=config["args"]["cpus"],
    tree=tree,
)

# Partition

In [None]:
dataset.partitions = pandda_config.partitioner(dataset.datasets)

# Fit, evaluate, Criticise - Single dataset, Direct eval

##  VALIDATED!

In [None]:
# # Main Loop
# dataloader = DefaultPanDDADataloader(min_train_datasets=60, 
#                                      max_test_datasets=60)


In [None]:
# idx, ds = dataloader(dataset).next()

In [None]:
# context = pandda_event_model(idx, ds, reference)

In [None]:
# model = context.__enter__()

In [None]:
# model.fit()  # Fit the statistical model


In [None]:
# dtags = model.dataset.partition_samples("test").keys()
# truncated_datasets = model.dataset.sample_loader.truncated_datasets
# sample_loaders = {dtag: lambda d: model.dataset.sample_loader.get_sample(2.13786674777, d)
#                   for dtag
#                   in dtags}

In [None]:
# result = model.evaluate_single(sample_loaders["PDK2-x0621"],
#                                truncated_datasets["PDK2-x0621"],
#                                model.dataset.sample_loader.ref_map)

In [None]:
# model.grid = model.dataset.sample_loader.grid
# model.grid

In [None]:
# model.criticise_single(sample_loaders["PDK2-x0621"],
#                        truncated_datasets["PDK2-x0621"],
#                        model.dataset.sample_loader.ref_map,
#                        result[2],
#                        result[3],
#                        tree)

In [None]:
# model.dataset.max_res = max([d.data.summary.high_res for dtag, d in model.dataset.datasets.items()])

In [None]:
# model.evaluate_all()

In [None]:
# model.criticise_all(tree)

# Fit, evaluate, Criticise - All datasets, Local dask cluster

## TESTING

In [None]:
# # Main Loop
# dataloader = DefaultPanDDADataloader(min_train_datasets=60, 
#                                      max_test_datasets=60)


In [None]:
# ds = [(idx, d) for idx, d in dataloader(dataset)]

In [None]:
# # Set up client
# # client = dask.distributed.Client(scheduler_file="scheduler.json")
# cluster = LocalCluster(n_workers=2, threads_per_worker=1)
# client = Client(cluster)

In [None]:
# client

In [None]:
# # Get base distributed model
# pandda_event_model_distributed = PanDDAEventModelDistributed(pandda_config.statistical_model,
#                                       pandda_config.clusterer,
#                                       pandda_config.event_finder,
#                                         dataset=dataset,
#                                       bdc_calculator=pandda_config.bdc_calculator,
#                                       statistics=[],
#                                       map_maker=pandda_config.map_maker, 
#                                       event_table_maker=pandda_config.event_table_maker,
#                                       cpus=config["args"]["cpus"],
#                                       tree=tree)

In [None]:
# pandda_event_model_distributed.instantiate(reference,
#                                            tree)

In [None]:
# # Instantiate models
# models = [pandda_event_model_distributed.clone(dataset=d, 
#                                    name=idx)
#          for idx, d
#          in ds]

In [None]:
# models

In [None]:
# # Load model components
# models_loaded = client.map(load,
#                     models,
#                           pure=False)

In [None]:
# models_loaded

In [None]:
# # Fit models over nodes
# models_fit = client.map(fit, 
#                            models_loaded,
#                        pure=False)


In [None]:
# models_fit

In [None]:
# # Ask models to process
# models_evaluated = client.map(evaluate, 
#                               models_fit,
#                              pure=False)


In [None]:
# models_evaluated

In [None]:
# # Ask models to criticise
# event_tables = client.map(criticise, 
#                                models_evaluated, 
#                                pure=False)


In [None]:
# event_tables

In [None]:
# event_tables_results =[e.result() for e in event_tables]

In [None]:
# client.close()

# Fit, evaluate, criticise - single dataset

In [None]:
import dask

# Set distributed's admin tick limit to 120s (see the distributed configuration
# reference for what the limit governs). The bare `dask.config.config` expression
# that used to sit here had no effect and was dropped.
dask.config.set({"distributed.admin.tick.limit": "120s"})

In [None]:
# cluster = LocalCluster(n_workers=2, threads_per_worker=4)

In [None]:
# client = Client(cluster)

In [None]:
# SGE-backed dask cluster definition.
# The keyword is `resource_spec`; the earlier spelling `resourcce_spec` meant the
# m_mem_free request was never passed through as a resource specification.
cluster = SGECluster(queue="medium.q",
                     cores=20,
                     processes=5,
                     memory="64GB",
                     resource_spec="m_mem_free=64G",
                     python="/dls/science/groups/i04-1/conor_dev/ccp4/build/bin/cctbx.python")

In [None]:
cluster.scale(5)

In [None]:
client = Client(cluster)

In [None]:
client

In [None]:
# Main Loop
dataloader = DefaultPanDDADataloader(min_train_datasets=60, 
                                     max_test_datasets=60)


In [None]:
ds = [(idx, d) for idx, d in dataloader(dataset)]

In [None]:
# Base distributed model; per-dataset copies are cloned from it in a later cell
pandda_event_model_distributed = PanDDAEventModelDistributed(
    pandda_config.statistical_model,
    pandda_config.clusterer,
    pandda_config.event_finder,
    dataset=dataset,
    bdc_calculator=pandda_config.bdc_calculator,
    statistics=[],
    map_maker=pandda_config.map_maker,
    event_table_maker=pandda_config.event_table_maker,
    # NOTE(review): configparser returns strings; confirm a str is acceptable for cpus
    cpus=config["args"]["cpus"],
    tree=tree,
)

In [None]:
# pandda_event_model_distributed.instantiate(reference,
#                                            tree)

In [None]:
# One clone of the base distributed model per (index, dataset) pair in `ds`,
# named after the index
models = [
    pandda_event_model_distributed.clone(dataset=item_dataset, name=item_index)
    for item_index, item_dataset in ds
]

In [None]:
# Load model components
# model_loaded = load(models[0])

In [None]:
def _graph_key(dtag):
    """Return the graph key for a dataset tag (hyphens replaced by underscores)."""
    return dtag.replace("-", "_")


dsk = {}

# Shared inputs are the same for every model, so register them once.
dsk["reference"] = reference
dsk["tree"] = tree

# loop over model blocks
for model in models:
    # Get model name
    name = str(model.name)
    print(name)

    # NOTE(review): `model_loaded` is not assigned in this cell; the only visible
    # assignment is the commented-out `load(models[0])` cell above. Confirm it is
    # meant to be the loaded counterpart of `model`.
    test_datasets = model_loaded.dataset.partition_datasets("test")
    train_datasets = model_loaded.dataset.partition_datasets("train")

    # Get dataset tags. A set union works whether keys() returns a list or a view.
    dtags = set(test_datasets.keys()) | set(train_datasets.keys())
    # Every stage below addresses a dataset through the same underscore key.
    dataset_keys = {dtag: _graph_key(dtag) for dtag in dtags}

    # Get resolution: the largest high_res value across both partitions
    resolutions_test = max([d.data.summary.high_res for dtag, d
                            in test_datasets.items()])
    resolutions_train = max([d.data.summary.high_res for dtag, d
                             in train_datasets.items()])
    max_res = max(resolutions_test, resolutions_train)

    model_key = "{}_model".format(name)
    loaded_model_key = "{}_loaded_model".format(name)
    max_res_key = "{}_max_res".format(name)
    sample_loader_key = "{}_sample_loader".format(name)
    ref_map_key = "{}_ref_map".format(name)
    fit_model_key = "{}_fit_model".format(name)

    dsk[model_key] = model

    # Load datasets
    for dtag in dtags:
        dsk[dataset_keys[dtag]] = model.dataset.datasets[dtag]

    # Load model
    dsk[loaded_model_key] = (lambda m, r, t: m.instantiate(r, t),
                             model_key,
                             "reference",
                             "tree")

    dsk[max_res_key] = max_res

    # Get sample loader
    sample_loader = model.dataset.sample_loader
    dsk[sample_loader_key] = (lambda m: m.dataset.sample_loader,
                              loaded_model_key)

    # ref map
    dsk[ref_map_key] = (lambda sl: sl.ref_map,
                        sample_loader_key)

    # Load maps
    for dtag in dtags:
        dsk["{}_{}_map".format(name, dataset_keys[dtag])] = (lambda sl, r, _d: sl.get_sample(r, _d),
                                                             sample_loader_key,
                                                             max_res_key,
                                                             dataset_keys[dtag])

    # Fit model
    # NOTE(review): the train/test lists name the dataset tasks, not the "*_map"
    # tasks, and the task result is whatever statistical_model.fit returns, yet
    # later stages call evaluate_single/criticise_single on it. Confirm both.
    dsk[fit_model_key] = (lambda m, train, test: m.statistical_model.fit(train, test),
                          loaded_model_key,
                          [dataset_keys[dtag] for dtag in train_datasets.keys()],
                          [dataset_keys[dtag] for dtag in test_datasets.keys()])

    # Find events
    for dtag in dtags:
        dsk["{}_{}_events".format(name, dataset_keys[dtag])] = (lambda m, s, _d, ref: m.evaluate_single(s, _d, ref),
                                                                fit_model_key,
                                                                "{}_{}_map".format(name, dataset_keys[dtag]),
                                                                dataset_keys[dtag],
                                                                ref_map_key)

    # Criticise
    # Uses the underscore keys registered above. The hyphenated names used before
    # matched no task, so dask would have passed them through as plain strings.
    for dtag in dtags:
        dsk["{}_{}_event_table".format(name, dataset_keys[dtag])] = (lambda m, _d, e: m.criticise_single(_d, e),
                                                                     fit_model_key,
                                                                     dataset_keys[dtag],
                                                                     "{}_{}_events".format(name, dataset_keys[dtag]))

    # Join
    dsk["{}_event_table".format(name)] = (lambda m, et: m.criticise_all(et),
                                          fit_model_key,
                                          ["{}_{}_events".format(name, dataset_keys[dtag]) for dtag in dtags])

    # Only the first model is wired into the graph for now.
    break
    

In [None]:
# dsk["0_loaded_model"]

In [None]:
# client.get(dsk, "0_loaded_model")

In [None]:
# dsk["0_sample_loader"]

In [None]:
# client.get(dsk, "0_sample_loader")

In [None]:
# client.get(dsk, "0_max_res")

In [None]:
# client.get(dsk, "PDK2_x0384")

In [None]:
# dsk["0_PDK2_x0384_map"]

In [None]:
# client.get(dsk, "0_PDK2_x0384_map")

In [None]:
dsk["0_fit_model"]

In [None]:
client.get(dsk, "0_fit_model")

In [None]:
dsk["0_PDK2_x0384_events"]

In [None]:
client.get(dsk, "0_PDK2_x0384_events")

In [None]:
dsk["0_max_res"]

In [None]:
dsk["PDK2-x0384"]

In [None]:
client.get(dsk, "0_PDK2-x0384_map")

In [None]:
client.get(dsk, "0_event_table")

In [None]:
# NOTE(review): no key is passed here; Client.get needs the key(s) to compute as
# its second argument. Decide which task(s) this cell should request.
client.get(dsk, )

In [None]:
# dsk["model_loaded"] = model_loaded

In [None]:
# All dataset tags from the "test" and "train" partitions. The set union gives the
# same result as concatenating the key lists, and also works when keys() returns
# views (Python 3), where `+` is not defined.
dtags = (set(model_loaded.dataset.partition_datasets("test").keys())
         | set(model_loaded.dataset.partition_datasets("train").keys()))

In [None]:
# Largest high_res value within each partition...
resolutions_test = max(
    d.data.summary.high_res
    for _, d in model_loaded.dataset.partition_datasets("test").items()
)
resolutions_train = max(
    d.data.summary.high_res
    for _, d in model_loaded.dataset.partition_datasets("train").items()
)
# ...and the larger of the two is used for every sample.
max_res = max(resolutions_test, resolutions_train)



In [None]:
sample_loader = model_loaded.dataset.sample_loader

In [None]:
# sample_loaders = {dtag: lambda d: sample_loader.get_sample(res, d)
#                           for dtag
#                           in dtags}

In [None]:
# Register one sample-loading task per dataset tag, keyed by the raw tag
dsk.update({
    dtag: (sample_loader.get_sample, max_res, model_loaded.dataset.datasets[dtag])
    for dtag in dtags
})

In [None]:
# Task that fits the statistical model; its two arguments are the lists of
# train and test tags, which name the sample tasks registered under those tags.
train_tags = [dtag for dtag, _ in model_loaded.dataset.partition_datasets("train").items()]
test_tags = [dtag for dtag, _ in model_loaded.dataset.partition_datasets("test").items()]
dsk["params"] = (model_loaded.statistical_model.fit, train_tags, test_tags)

In [None]:
dsk

In [None]:
params = client.get(dsk, "params")

In [None]:
# Submit the fit step for the loaded model to the cluster; `model_fit` is the
# future returned by client.submit, and the next cell calls .result() on it
model_fit = client.submit(fit, model_loaded)


In [None]:
model_fit.result()

In [None]:
# Ask the model to process.
# `model_fit` is a future, so evaluate has to be submitted through the client,
# which resolves the future before calling it. The next cell calls .result() on
# `model_evaluated`, so it must be a future too.
model_evaluated = client.submit(evaluate, model_fit)


In [None]:
model_evaluated.result()

In [None]:
# Ask the model to criticise.
# Uses `model_evaluated` from the evaluate step (`models_evaluated` is never
# assigned in this section) and submits through the client like the fit step.
event_tables = client.submit(criticise, model_evaluated)


In [None]:
client.reset()

In [None]:
client.restart()

In [None]:
client.close()

# Fit, evaluate, Criticise - All datasets, Dask distributed

## TESTING

In [None]:
# SGE-backed dask cluster with one core and one process per job.
# The keyword is `resource_spec`; the earlier spelling `resourcce_spec` meant the
# m_mem_free request was never passed through as a resource specification.
cluster = SGECluster(queue="medium.q",
                     cores=1,
                     processes=1,
                     memory="64GB",
                     resource_spec="m_mem_free=64G",
                     python="/dls/science/groups/i04-1/conor_dev/ccp4/build/bin/cctbx.python")

In [None]:
cluster.scale(3)

In [None]:
time.sleep(15)

In [None]:
# NOTE(review): this rebinds `cluster`, so the SGECluster created and scaled in the
# cells above is no longer referenced. Confirm that switching to a local cluster
# here is intended, and close the SGE cluster first if so.
cluster = LocalCluster(n_workers=2, threads_per_worker=1)

In [None]:
client = Client(cluster)

In [None]:
client

In [None]:
# Main Loop
dataloader = DefaultPanDDADataloader(min_train_datasets=60, 
                                     max_test_datasets=60)


In [None]:
ds = [(idx, d) for idx, d in dataloader(dataset)]

In [None]:
# Base distributed model; it is instantiated and then cloned per dataset below
pandda_event_model_distributed = PanDDAEventModelDistributed(
    pandda_config.statistical_model,
    pandda_config.clusterer,
    pandda_config.event_finder,
    dataset=dataset,
    bdc_calculator=pandda_config.bdc_calculator,
    statistics=[],
    map_maker=pandda_config.map_maker,
    event_table_maker=pandda_config.event_table_maker,
    # NOTE(review): configparser returns strings; confirm a str is acceptable for cpus
    cpus=config["args"]["cpus"],
    tree=tree,
)

In [None]:
pandda_event_model_distributed.instantiate(reference,
                                           tree)

In [None]:
# One clone of the base distributed model per (index, dataset) pair in `ds`,
# named after the index
models = [
    pandda_event_model_distributed.clone(dataset=item_dataset, name=item_index)
    for item_index, item_dataset in ds
]

In [None]:
# models

In [None]:
# # Load model components
# models_loaded = client.map(load,
#                     models,
#                           pure=False)

In [None]:
# models_loaded

In [None]:
# # Fit models over nodes
# models_fit = client.map(fit, 
#                            models_loaded,
#                        pure=False)


In [None]:
# models_fit

In [None]:
# # Ask models to process
# models_evaluated = client.map(evaluate, 
#                               models_fit,
#                              pure=False)


In [None]:
# models_evaluated

In [None]:
# # Ask models to criticise
# event_tables = client.map(criticise, 
#                                models_evaluated, 
#                                pure=False)


In [None]:
# event_tables

In [None]:
# event_tables_results =[e.result() for e in event_tables]

In [None]:
# Build one four-stage chain (load -> fit -> evaluate -> criticise) per model;
# each stage takes the previous stage's key as its only argument.
dsk = {}
for i, model in enumerate(models):
    load_key = "load_{}".format(i)
    fit_key = "fit_{}".format(i)
    evaluate_key = "evaluate_{}".format(i)
    criticise_key = "criticise_{}".format(i)

    dsk[load_key] = (load, model)
    dsk[fit_key] = (fit, load_key)
    dsk[evaluate_key] = (evaluate, fit_key)
    dsk[criticise_key] = (criticise, evaluate_key)


In [None]:
dsk

In [None]:
dsk_combined, deps = dask.optimization.fuse(dsk)

In [None]:
# Request only the final (criticise) task of every per-model chain
criticise_keys = ["criticise_{}".format(i) for i in range(len(models))]
client.get(dsk_combined, criticise_keys)

In [None]:
client.close()

In [None]:
# Visualise tasks


In [None]:
# Ask models for events
event_tables = event_tables.results()

In [None]:
# Design sketch: hand each model built from the dataloader to a processor context manager.
# NOTE(review): the ProcessModelSeriel import is commented out in the imports cell,
# so this name is unbound as the notebook stands.
# NOTE(review): the loop rebinds `dataset`, and elsewhere in this notebook
# dataloader(dataset) is unpacked as (idx, d) pairs. Confirm what
# pandda_event_model should receive here.
with ProcessModelSeriel() as P:            

    for dataset in dataloader(dataset):

        P(pandda_event_model(dataset))
        # call with self as model: model.fit(); model.evaluate(); model.criticise()
            # Serial: run immediately
            # Qsub: pick model

    # exit: 
        # Serial: just go on
        # qsub: wait for the jobs to complete

In [None]:
# Sketch of the serial loop: fit, evaluate and criticise one model per dataloader item.
# NOTE(review): the loop rebinds `dataset`, and elsewhere in this notebook
# dataloader(dataset) is unpacked as (idx, d) pairs. Confirm the call signature
# that pandda_event_model expects.
for dataset in dataloader(dataset):
    
    with pandda_event_model(dataset) as model:

        # Fit model
        print("Fitting model")
        model.fit()  # Fit the statistical model

        # Evaluate model
        model.evaluate_parallel()  # Evaluate the fitted model on maps, finding events

        # Criticise Model
        model.criticise() # Stores statistics from model fitting and evaluation


# Criticise Run

## TESTING

In [None]:
# Criticise Model
# NOTE(review): PanDDARunStatistics is only defined in a cell further down, and
# PanDDARunGraphs, PanDDARunHTML and PanDDARunLog are neither defined nor imported
# in the visible part of this notebook. This cell cannot run top to bottom as is.
pandda_statistics = PanDDARunStatistics(dataset, model)  # Generates statistics from dataset and model
PanDDARunGraphs(pandda_statistics)  # Produces a set of graphs of statistics
pandda_html = PanDDARunHTML(pandda_statistics)  # Produces a HTML from Statistics
PanDDARunLog()  # Produces a log of data processing from dataset, model, graphs and HTML


In [None]:
# Design sketch: same loop as the ProcessModelSeriel cell above, but with a
# pre-built `processor` object as the context manager.
# NOTE(review): `processor` is not assigned anywhere in the visible part of this
# notebook, and the loop rebinds `dataset`.
with processor as P:            

    for dataset in dataloader(dataset):

        P(pandda_event_model(dataset))
        # call with self as model: model.fit(); model.evaluate(); model.criticise()
            # Serial: run immediately
            # Qsub: pick model

    # exit: 
        # Serial: just go on
        # qsub: wait for the jobs to complete
        

In [None]:
class PanDDARunStatistics(object):
    """Run a collection of statistics and record what each one logged.

    Parameters
    ----------
    dataset, model
        Stored on the instance as ``dataset`` and ``model``; the constructor
        does not otherwise use them.
    statistics : iterable, optional
        Objects exposing ``name``, ``calculate()`` and ``log()``. Defaults to
        no statistics; this replaces the former ``[...]`` placeholder.

    Attributes
    ----------
    trace : dict
        Maps each statistic's ``name`` to the value of its ``log()``, or to the
        formatted exception message if ``calculate()`` or ``log()`` raised.
    """

    def __init__(self, dataset, model, statistics=None):
        self.dataset = dataset
        self.model = model
        self.statistics = list(statistics) if statistics is not None else []
        self.trace = {}

        for statistic in self.statistics:
            try:
                statistic.calculate()
                self.trace[statistic.name] = statistic.log()

            # Best effort: one failing statistic is recorded and must not stop the rest.
            except Exception as e:
                self.trace[statistic.name] = "{}".format(e)
        

In [None]:
class OutputNativeMaps(object):
    """Skeleton with ``calculate`` and ``log`` methods; neither is implemented yet.

    The original methods had no body, which is a syntax error, so they now
    raise NotImplementedError until real implementations are written.
    """

    def calculate(self, samples):
        """Placeholder for the calculation over ``samples``.

        Raises
        ------
        NotImplementedError
            Always, until the method is implemented.
        """
        # NOTE(review): `self` was added on the assumption that this is an instance method; confirm.
        raise NotImplementedError("OutputNativeMaps.calculate is not implemented yet")

    def log(self):
        """Placeholder for the log output.

        Raises
        ------
        NotImplementedError
            Always, until the method is implemented.
        """
        raise NotImplementedError("OutputNativeMaps.log is not implemented yet")
    