In [1]:
# magics
%load_ext autoreload
%autoreload 2

# Imports

In [2]:
# Regular Imports
import os, sys, configparser
import time


In [3]:
# Scientific imports
from joblib import Memory

import dask
from dask.distributed import Client, LocalCluster
from dask_jobqueue import SGECluster


  defaults = yaml.load(f)


In [4]:
# Dask functions
from functions import fit, get_reference_map, load_sample, evaluate_model, cluster_outliers, filter_clusters, estimate_bdcs, make_event_map, make_shell_maps, make_event_table
from pandda_analyse.criticise import make_map

In [5]:
# Multi Dataset Crystalography imports
import multi_dataset_crystalography as mdc # import MultiCrystalDataset
from multi_dataset_crystalography.utils import DefaultPanDDADataloader
from multi_dataset_crystalography.dataset.sample_loader import PanddaDiffractionDataTruncater



In [6]:
# PanDDA imports
from pandda_analyse.config import PanDDAConfig
from pandda_analyse.event_model import PanDDAEventModel
# from pandda_analyse.processor import ProcessModelSeriel
from pandda_analyse.event_model_distributed import PanDDAEventModelDistributed



# Args

In [7]:
# Arguments
arguments = None
# Config
# Path to the PanDDA analyse INI file. Defaults to the development checkout but
# can be overridden with the PANDDA_ANALYSE_CONFIG environment variable so the
# notebook can run against another installation without editing this cell.
config_path = os.environ.get(
    "PANDDA_ANALYSE_CONFIG",
    "/dls/science/groups/i04-1/conor_dev/pandda/lib-python/pandda/pandda_analyse/pandda_analyse/analyse_config.ini")
config = configparser.ConfigParser()


# Config

In [8]:
# ConfigParser.read silently skips files it cannot open and returns the list of
# files it did read, so fail loudly here instead of later with a confusing
# missing-section error.
config_files_read = config.read(config_path)
if not config_files_read:
    raise IOError("Could not read PanDDA config file: {}".format(config_path))
config_files_read

['/dls/science/groups/i04-1/conor_dev/pandda/lib-python/pandda/pandda_analyse/pandda_analyse/analyse_config.ini']

In [9]:
# Wrap the parsed INI in PanDDAConfig; later cells take dataset transforms,
# loaders, models and output settings from `pandda_config`.
pandda_config = PanDDAConfig(config)  # Maps options to code abstractions

# Memory

In [10]:
from joblib import Memory
# On-disk joblib cache used to memoise expensive dataset transforms.
# Defaults to the development cache directory; override with PANDDA_CACHE_DIR
# to run outside the original environment.
memory = Memory(os.environ.get("PANDDA_CACHE_DIR",
                               "/dls/science/groups/i04-1/conor_dev/cache"),
                verbose=0)

# Dataset

In [11]:
%%time
# Get Dataset
pandda_dataset = mdc.dataset.dataset.MultiCrystalDataset(dataloader=pandda_config.dataloader,
                                         sample_loader=pandda_config.sample_loader
                                        )

CPU times: user 21.7 s, sys: 1.2 s, total: 22.9 s
Wall time: 23.5 s


# Get reference Model

In [12]:
%%time
# Get reference model
# Choose the reference model from the loaded datasets and attach it to the
# dataset's sample loader.
reference = pandda_config.get_reference(pandda_dataset.datasets)
pandda_dataset.sample_loader.reference = reference

CPU times: user 8.43 s, sys: 816 ms, total: 9.24 s
Wall time: 9.51 s


# Apply transforms to dataset

In [13]:
%%time
# Apply dataset transforms
if "data_check" in pandda_config.dataset_transforms:
    transform = pandda_config.dataset_transforms["data_check"]
    dataset = transform(pandda_dataset, reference)
    print("###### {} ######".format(transform.name))
    for block, record in transform.log().items():
        print("# # {} # #".format(block))
        print(record)

###### PanddaDataChecker ######
# # Rejected datasets # #
Datasets rejected: 
PDK2-x0279: rejected - rmsd to reference
PDK2-x0251: rejected - rmsd to reference
PDK2-x0107: rejected - rmsd to reference
PDK2-x0238: rejected - rmsd to reference
PDK2-x0878: rejected - rmsd to reference

CPU times: user 23.4 s, sys: 492 ms, total: 23.9 s
Wall time: 24.2 s


In [14]:
%%time
if "scale_diffraction" in pandda_config.dataset_transforms:
    transform = pandda_config.dataset_transforms["scale_diffraction"]
    dataset = memory.cache(transform)(dataset, reference)
    print("###### {} ######".format(transform.name))
    for block, record in transform.log().items():
        print("# # {} # #".format(block))
        print(record)

  This is separate from the ipykernel package so we can avoid doing imports until
  interpolation=interpolation)


###### PanddaDiffractionScaler ######
# # Rejected datasets # #
Datasets rejected: 
PDK2-x0641: rejected - scaling failed
PDK2-x0318: rejected - scaling failed
PDK2-x0317: rejected - scaling failed

CPU times: user 4min 56s, sys: 1min 19s, total: 6min 15s
Wall time: 6min 22s


If this happens often in your code, it can cause performance problems 
(results will be correct in all cases). 
The reason for this is probably some large input arguments for a wrapped
 function (e.g. large strings).
THIS IS A JOBLIB ISSUE. If you can, kindly provide the joblib's team with an
 example so that they can fix the problem.
  This is separate from the ipykernel package so we can avoid doing imports until


In [15]:
%%time
if "filter_structure" in pandda_config.dataset_transforms:
    transform = pandda_config.dataset_transforms["filter_structure"]
    dataset = transform(dataset, reference)
    print("###### {} ######".format(transform.name))
    for block, record in transform.log().items():
        print("# # {} # #".format(block))
        print(record)

###### PanddaDatasetFilterer ######
# # Rejected datasets # #
Datasets rejected: 
PDK2-x0281: rejected - non-identical structures
PDK2-x0365: rejected - non-identical structures
PDK2-x0658: rejected - non-identical structures
PDK2-x0288: rejected - non-identical structures
PDK2-x0182: rejected - non-identical structures
PDK2-x0347: rejected - non-identical structures
PDK2-x0676: rejected - non-identical structures
PDK2-x0585: rejected - non-identical structures
PDK2-x0248: rejected - non-identical structures
PDK2-x0106: rejected - non-identical structures
PDK2-x0287: rejected - non-identical structures
PDK2-x0229: rejected - non-identical structures
PDK2-x0289: rejected - non-identical structures
PDK2-x0299: rejected - non-identical structures
PDK2-x0310: rejected - non-identical structures
PDK2-x0316: rejected - non-identical structures
PDK2-x0192: rejected - non-identical structures
PDK2-x0050: rejected - non-identical structures
PDK2-x0236: rejected - non-identical structures
PDK2-x

In [16]:
%%time
if "filter_wilson" in pandda_config.dataset_transforms:
    transform = pandda_config.dataset_transforms["filter_wilson"]
    dataset = transform(dataset, reference)
    print("###### {} ######".format(transform.name))
    for block, record in transform.log().items():
        print("# # {} # #".format(block))
        print(record)

  return 0.6745*devs/mdev


###### PanddaDatasetFiltererWilsonRMSD ######
# # Rejected datasets # #
Datasets rejected: 


CPU times: user 25.1 s, sys: 435 ms, total: 25.5 s
Wall time: 25.5 s


  col='scaled_wilson_rmsd_all_z') > self.max_wilson_plot_z_score) or \
  col='scaled_wilson_rmsd_<4A_z') > self.max_wilson_plot_z_score) or \
  col='scaled_wilson_rmsd_>4A_z') > self.max_wilson_plot_z_score) or \
  col='scaled_wilson_ln_rmsd_z') > self.max_wilson_plot_z_score) or \
  col='scaled_wilson_ln_dev_z') > self.max_wilson_plot_z_score):


In [17]:
%%time
if "align" in pandda_config.dataset_transforms:
    transform = pandda_config.dataset_transforms["align"]
    dataset = memory.cache(transform)(dataset, reference)
    print("###### {} ######".format(transform.name))
    for block, record in transform.log().items():
        print("# # {} # #".format(block))
        print(record)

  This is separate from the ipykernel package so we can avoid doing imports until
[Parallel(n_jobs=15)]: Using backend LokyBackend with 15 concurrent workers.
[Parallel(n_jobs=15)]: Done   1 tasks      | elapsed:    6.1s
[Parallel(n_jobs=15)]: Done   2 tasks      | elapsed:    8.5s
[Parallel(n_jobs=15)]: Done   3 tasks      | elapsed:   10.9s
[Parallel(n_jobs=15)]: Done   4 tasks      | elapsed:   13.6s
[Parallel(n_jobs=15)]: Done   5 tasks      | elapsed:   16.0s
[Parallel(n_jobs=15)]: Done   6 tasks      | elapsed:   18.6s
[Parallel(n_jobs=15)]: Done   7 tasks      | elapsed:   20.9s
[Parallel(n_jobs=15)]: Done   8 tasks      | elapsed:   23.1s
[Parallel(n_jobs=15)]: Done   9 tasks      | elapsed:   25.6s
[Parallel(n_jobs=15)]: Done  10 tasks      | elapsed:   28.0s
[Parallel(n_jobs=15)]: Done  11 tasks      | elapsed:   30.5s
[Parallel(n_jobs=15)]: Done  12 tasks      | elapsed:   32.9s
[Parallel(n_jobs=15)]: Done  13 tasks      | elapsed:   35.3s
[Parallel(n_jobs=15)]: Done  14 tas

[Parallel(n_jobs=15)]: Done 130 tasks      | elapsed:   50.4s
[Parallel(n_jobs=15)]: Done 131 tasks      | elapsed:   50.5s
[Parallel(n_jobs=15)]: Done 132 tasks      | elapsed:   50.5s
[Parallel(n_jobs=15)]: Done 133 tasks      | elapsed:   50.7s
[Parallel(n_jobs=15)]: Done 134 tasks      | elapsed:   50.8s
[Parallel(n_jobs=15)]: Done 135 tasks      | elapsed:   51.1s
[Parallel(n_jobs=15)]: Done 136 tasks      | elapsed:   51.3s
[Parallel(n_jobs=15)]: Done 137 tasks      | elapsed:   51.3s
[Parallel(n_jobs=15)]: Done 138 tasks      | elapsed:   51.4s
[Parallel(n_jobs=15)]: Done 139 tasks      | elapsed:   51.4s
[Parallel(n_jobs=15)]: Done 140 tasks      | elapsed:   51.8s
[Parallel(n_jobs=15)]: Done 141 tasks      | elapsed:   51.8s
[Parallel(n_jobs=15)]: Done 142 tasks      | elapsed:   51.8s
[Parallel(n_jobs=15)]: Done 143 tasks      | elapsed:   51.9s
[Parallel(n_jobs=15)]: Done 144 tasks      | elapsed:   51.9s
[Parallel(n_jobs=15)]: Done 145 tasks      | elapsed:   51.9s
[Paralle

[Parallel(n_jobs=15)]: Done 264 tasks      | elapsed:  1.0min
[Parallel(n_jobs=15)]: Done 265 tasks      | elapsed:  1.0min
[Parallel(n_jobs=15)]: Done 266 tasks      | elapsed:  1.0min
[Parallel(n_jobs=15)]: Done 267 tasks      | elapsed:  1.0min
[Parallel(n_jobs=15)]: Done 268 tasks      | elapsed:  1.0min
[Parallel(n_jobs=15)]: Done 269 tasks      | elapsed:  1.0min
[Parallel(n_jobs=15)]: Done 270 tasks      | elapsed:  1.0min
[Parallel(n_jobs=15)]: Done 271 tasks      | elapsed:  1.1min
[Parallel(n_jobs=15)]: Done 272 tasks      | elapsed:  1.1min
[Parallel(n_jobs=15)]: Done 273 tasks      | elapsed:  1.1min
[Parallel(n_jobs=15)]: Done 274 tasks      | elapsed:  1.1min
[Parallel(n_jobs=15)]: Done 275 tasks      | elapsed:  1.1min
[Parallel(n_jobs=15)]: Done 276 tasks      | elapsed:  1.1min
[Parallel(n_jobs=15)]: Done 277 tasks      | elapsed:  1.1min
[Parallel(n_jobs=15)]: Done 278 tasks      | elapsed:  1.1min
[Parallel(n_jobs=15)]: Done 279 tasks      | elapsed:  1.1min
[Paralle

[Parallel(n_jobs=15)]: Done 397 tasks      | elapsed:  1.2min
[Parallel(n_jobs=15)]: Done 398 tasks      | elapsed:  1.2min
[Parallel(n_jobs=15)]: Done 399 tasks      | elapsed:  1.2min
[Parallel(n_jobs=15)]: Done 400 tasks      | elapsed:  1.2min
[Parallel(n_jobs=15)]: Done 401 tasks      | elapsed:  1.2min
[Parallel(n_jobs=15)]: Done 402 tasks      | elapsed:  1.2min
[Parallel(n_jobs=15)]: Done 403 tasks      | elapsed:  1.2min
[Parallel(n_jobs=15)]: Done 404 tasks      | elapsed:  1.2min
[Parallel(n_jobs=15)]: Done 405 tasks      | elapsed:  1.2min
[Parallel(n_jobs=15)]: Done 406 tasks      | elapsed:  1.2min
[Parallel(n_jobs=15)]: Done 407 tasks      | elapsed:  1.2min
[Parallel(n_jobs=15)]: Done 408 tasks      | elapsed:  1.2min
[Parallel(n_jobs=15)]: Done 409 tasks      | elapsed:  1.2min
[Parallel(n_jobs=15)]: Done 410 tasks      | elapsed:  1.2min
[Parallel(n_jobs=15)]: Done 411 tasks      | elapsed:  1.2min
[Parallel(n_jobs=15)]: Done 412 tasks      | elapsed:  1.3min
[Paralle

[Parallel(n_jobs=15)]: Done 531 tasks      | elapsed:  1.4min
[Parallel(n_jobs=15)]: Done 532 tasks      | elapsed:  1.4min
[Parallel(n_jobs=15)]: Done 533 tasks      | elapsed:  1.4min
[Parallel(n_jobs=15)]: Done 534 tasks      | elapsed:  1.4min
[Parallel(n_jobs=15)]: Done 535 tasks      | elapsed:  1.4min
[Parallel(n_jobs=15)]: Done 536 tasks      | elapsed:  1.4min
[Parallel(n_jobs=15)]: Done 537 tasks      | elapsed:  1.4min
[Parallel(n_jobs=15)]: Done 538 tasks      | elapsed:  1.4min
[Parallel(n_jobs=15)]: Done 539 tasks      | elapsed:  1.4min
[Parallel(n_jobs=15)]: Done 540 tasks      | elapsed:  1.4min
[Parallel(n_jobs=15)]: Done 541 tasks      | elapsed:  1.4min
[Parallel(n_jobs=15)]: Done 542 tasks      | elapsed:  1.4min
[Parallel(n_jobs=15)]: Done 543 tasks      | elapsed:  1.4min
[Parallel(n_jobs=15)]: Done 544 tasks      | elapsed:  1.4min
[Parallel(n_jobs=15)]: Done 545 tasks      | elapsed:  1.4min
[Parallel(n_jobs=15)]: Done 546 tasks      | elapsed:  1.5min
[Paralle

[Parallel(n_jobs=15)]: Done 666 tasks      | elapsed:  1.6min
[Parallel(n_jobs=15)]: Done 667 tasks      | elapsed:  1.6min
[Parallel(n_jobs=15)]: Done 668 tasks      | elapsed:  1.6min
[Parallel(n_jobs=15)]: Done 669 tasks      | elapsed:  1.6min
[Parallel(n_jobs=15)]: Done 670 tasks      | elapsed:  1.6min
[Parallel(n_jobs=15)]: Done 671 tasks      | elapsed:  1.6min
[Parallel(n_jobs=15)]: Done 672 tasks      | elapsed:  1.6min
[Parallel(n_jobs=15)]: Done 673 tasks      | elapsed:  1.6min
[Parallel(n_jobs=15)]: Done 674 tasks      | elapsed:  1.6min
[Parallel(n_jobs=15)]: Done 675 tasks      | elapsed:  1.6min
[Parallel(n_jobs=15)]: Done 676 tasks      | elapsed:  1.6min
[Parallel(n_jobs=15)]: Done 677 tasks      | elapsed:  1.6min
[Parallel(n_jobs=15)]: Done 678 tasks      | elapsed:  1.6min
[Parallel(n_jobs=15)]: Done 679 tasks      | elapsed:  1.6min
[Parallel(n_jobs=15)]: Done 680 tasks      | elapsed:  1.6min
[Parallel(n_jobs=15)]: Done 681 tasks      | elapsed:  1.6min
[Paralle

###### PanddaDefaultStructureAligner ######
# # aligned_datasets # #
Datasets aligned: 
PDK2-x0811
PDK2-x0621
PDK2-x0810
PDK2-x0082
PDK2-x0083
PDK2-x0080
PDK2-x0081
PDK2-x0086
PDK2-x0087
PDK2-x0084
PDK2-x0085
PDK2-x0815
PDK2-x0088
PDK2-x0089
PDK2-x0814
PDK2-x0165
PDK2-x0164
PDK2-x0167
PDK2-x0166
PDK2-x0161
PDK2-x0160
PDK2-x0163
PDK2-x0162
PDK2-x0245
PDK2-x0247
PDK2-x0240
PDK2-x0241
PDK2-x0242
PDK2-x0243
PDK2-x0819
PDK2-x0627
PDK2-x0788
PDK2-x0750
PDK2-x0400
PDK2-x0051
PDK2-x0401
PDK2-x0047
PDK2-x0019
PDK2-x0018
PDK2-x0786
PDK2-x0011
PDK2-x0010
PDK2-x0013
PDK2-x0012
PDK2-x0015
PDK2-x0014
PDK2-x0017
PDK2-x0016
PDK2-x0420
PDK2-x0421
PDK2-x0422
PDK2-x0424
PDK2-x0425
PDK2-x0427
PDK2-x0429
PDK2-x0540
PDK2-x0058
PDK2-x0538
PDK2-x0539
PDK2-x0684
PDK2-x0685
PDK2-x0686
PDK2-x0687
PDK2-x0680
PDK2-x0681
PDK2-x0729
PDK2-x0683
PDK2-x0727
PDK2-x0726
PDK2-x0725
PDK2-x0724
PDK2-x0688
PDK2-x0689
PDK2-x0721
PDK2-x0720
PDK2-x0334
PDK2-x0335
PDK2-x0336
PDK2-x0337
PDK2-x0330
PDK2-x0331
PDK2-x0332
PDK2-x0531

If this happens often in your code, it can cause performance problems 
(results will be correct in all cases). 
The reason for this is probably some large input arguments for a wrapped
 function (e.g. large strings).
THIS IS A JOBLIB ISSUE. If you can, kindly provide the joblib's team with an
 example so that they can fix the problem.
  This is separate from the ipykernel package so we can avoid doing imports until


# Get Grid

In [18]:
%%time
# Build the sampling grid from the reference via the dataset's sample loader.
grid = dataset.sample_loader.get_grid(reference)

----------------------------------->>>
Atomic Mask Summary:
Total Mask Size (1D): 823828
Outer Mask Size (1D): 905167
Inner Mask Size (1D): 81339
Masked Grid Min/Max: ((7, 69, 108), (144, 52, 62))
----------------------------------->>>
Atomic Mask Summary:
Total Mask Size (1D): 800827
Outer Mask Size (1D): 2072287
Inner Mask Size (1D): 1271460
Masked Grid Min/Max: ((0, 0, 1), (151, 133, 160))
CPU times: user 33.2 s, sys: 2.4 s, total: 35.6 s
Wall time: 47.1 s


# Setup Output

In [19]:
%%time
tree = pandda_config.pandda_output(dataset)
for block, record in pandda_config.pandda_output.log().items():
    print("####### {} ########".format(block))
    print(record)

####### pandda_created ########
Output directory created: 
True

CPU times: user 3.21 s, sys: 2.78 s, total: 5.99 s
Wall time: 35.7 s


# Model

In [20]:
%%time
# Define Model
pandda_event_model = PanDDAEventModel(pandda_config.statistical_model,
                                      pandda_config.clusterer,
                                      pandda_config.event_finder,
                                      bdc_calculator=pandda_config.bdc_calculator,
                                      statistics=[],
                                      map_maker=pandda_config.map_maker, 
                                      event_table_maker=pandda_config.event_table_maker,
                                      cpus=config["args"]["cpus"],
                                      tree=tree)

CPU times: user 184 µs, sys: 22 µs, total: 206 µs
Wall time: 202 µs


# Partition

In [21]:
%%time
# Partition the datasets with the configured partitioner and store the result on `dataset`.
dataset.partitions = pandda_config.partitioner(dataset.datasets)

CPU times: user 704 µs, sys: 0 ns, total: 704 µs
Wall time: 587 µs


# Dask Functions

In [22]:
# Local aliases for configured pipeline components
statistical_model = pandda_config.statistical_model
sample_loader = pandda_config.sample_loader

# Event finding components
clusterer, event_finder, bdc_calculator = (pandda_config.clusterer,
                                           pandda_config.event_finder,
                                           pandda_config.bdc_calculator)

# Criticism components
criticiser, criticiser_all = (pandda_config.criticiser,
                              pandda_config.criticiser_all)

# Fit, evaluate, criticise - single dataset

In [23]:
# Raise distributed's event-loop tick-duration warning limit to 300s; long
# blocking tasks are expected in this workflow.
tick_limit_setting = {"distributed.admin.tick.limit": "300s"}
dask.config.set(tick_limit_setting)

<dask.config.set at 0x7f1d20c4f450>

In [24]:
# SGE-backed dask cluster. Each job has 10 cores and 64GB split across 5 worker
# processes; scale(60) asks for 60 workers in total.
cluster = SGECluster(
    queue="medium.q",
    project="labxchem",
    cores=10,
    processes=5,
    memory="64GB",
    resource_spec="m_mem_free=64G,redhat_release=rhel7",
    python="/dls/science/groups/i04-1/conor_dev/ccp4/build/bin/cctbx.python",
    walltime="03:00:00",
)
cluster.scale(60)

  "diagnostics_port has been deprecated. "


In [25]:
# client.close()

In [26]:
# del cluster
# del client

In [27]:
# cluster = LocalCluster(n_workers=5, 
#                       memory_limit="25GB"
#                       )

In [28]:
# client = Client(n_workers=1)

In [29]:
# Fixed pause after cluster.scale() to give SGE time to start worker jobs
# before the client connects.
# NOTE(review): a fixed sleep does not guarantee the workers are up.
time.sleep(30)

In [30]:
# Connect a distributed client to the SGE cluster; later cells scatter,
# persist and compute through `client`.
client = Client(cluster)

In [31]:
# cluster.close()

In [32]:
# client.restart()

In [33]:
# time.sleep(15)

In [34]:
# Display the client summary (scheduler address, worker/core/memory totals).
client

0,1
Client  Scheduler: tcp://172.23.159.7:35352  Dashboard: http://172.23.159.7:8787/status,Cluster  Workers: 60  Cores: 120  Memory: 768.00 GB


In [35]:
# client.close()

In [36]:
# client.restart()

In [37]:
# print(dir(client))
# workers = client.scheduler.workers_to_close(n=8)
# client.scheduler.retire_workers(workers=workers,close_workers=True,remove=True)

In [38]:
# cluster.restart()
# client.restart()

In [39]:
# cluster.scale(8)

In [40]:
# cluster.close()

In [41]:
# Main Loop
# Split the datasets into resolution shells. Per the parameter names, each shell
# gets at least 60 training datasets and at most 60 test datasets.
dataloader = DefaultPanDDADataloader(
    min_train_datasets=60,
    max_test_datasets=60,
)


In [42]:
# Materialise every (shell index, shell dataset) pair the dataloader yields.
ds = []
for shell_index, shell_dataset in dataloader(dataset):
    ds.append((shell_index, shell_dataset))

Got all train datasets
Sorted datasets
Collected trian datasets
Got all test datasets
sorted test datasets
yielding dataset
Dataset 0 of length 60; res limits (2.13786674777,2.13786674777)
yielding dataset
Dataset 1 of length 111; res limits (2.13849713107,2.23009496576)
yielding dataset
Dataset 2 of length 120; res limits (2.23981287927,2.31509360885)
yielding dataset
Dataset 3 of length 120; res limits (2.31816393209,2.3755079232)
yielding dataset
Dataset 4 of length 120; res limits (2.37600830711,2.45904047421)
yielding dataset
Dataset 5 of length 120; res limits (2.45982275972,2.5285230784)
yielding dataset
Dataset 6 of length 120; res limits (2.5290507321,2.58989421755)
yielding dataset
Dataset 7 of length 120; res limits (2.59023952047,2.64984944033)
yielding dataset
Dataset 8 of length 120; res limits (2.64990070685,2.71924167548)
yielding dataset
Dataset 9 of length 120; res limits (2.71972848577,2.78869368508)
yielding dataset
Dataset 10 of length 117; res limits (2.7888705626

# Define Task graph

In [43]:
# dsk = {}

# Define Functions

In [44]:
# # Sample Loader
# dsk["sample_loader"] = dataset.sample_loader
# dsk["reference_map_getter"] = pandda_config.reference_map_getter
# dsk["map_loader"] = pandda_config.map_loader

# # Statistical Model
# dsk["statistical_model"] = pandda_config.statistical_model

# # Event Finding
# # # Clusterer
# dsk["clusterer"] = pandda_config.clusterer
# # # event finder
# dsk["event_finder"] = pandda_config.event_finder
# # # bdc clus
# dsk["bdc_calculator"] = pandda_config.bdc_calculator

# # Criticism
# # # Map maker
# dsk["map_maker"] = pandda_config.map_maker
# # # Event table maker
# dsk["event_table_maker"] = pandda_config.event_table_maker

# Define Data

### Global Data

In [45]:
# # Tree
# dsk["tree"] = tree

# # Reference
# dsk["reference"] = reference

# # Grid object
# # Doesn't work because of multiprocessing inside grid getting!
# # dsk["grid"] = (lambda sl, ref: sl.get_grid(reference=ref),
# #                   "sample_loader",
# #                   "reference")
# dsk["grid"] = grid


In [46]:
class TruncatedDatasets:
    """Thin lookup wrapper around a mapping of truncated datasets."""

    def __init__(self, truncated_datasets):
        # Kept by reference; no copy is made.
        self.datasets = truncated_datasets

    def get_dataset(self, name):
        """Return the dataset stored under ``name`` in the wrapped mapping."""
        lookup = self.datasets
        return lookup[name]

### Shell Data

In [47]:
%%time
from dask import delayed

 
name = ds[0][0]
d = ds[0][1]
print(d)

# ###############################################
# Get resolution
# ###############################################
resolutions_test = max([dts.data.summary.high_res for dtag, dts
                            in d.partition_datasets("test").items()])
resolutions_train = max([dts.data.summary.high_res for dtag, dts
                             in d.partition_datasets("train").items()])
max_res = max(resolutions_test, resolutions_train)


# ###############################################
# Instantiate sheel variable names
# ###############################################

# Dataset names
dtags = set(d.partition_datasets("test").keys()
                + d.partition_datasets("train").keys()
                )

dask_dtags = {"{}".format(dtag.replace("-", "_")): dtag
             for dtag
             in dtags}
train_dtags = [dtag 
               for dtag 
               in dask_dtags
               if (dask_dtags[dtag] in d.partition_datasets("train").keys())]
test_dtags = [dtag 
                   for dtag 
                   in dask_dtags
                   if (dask_dtags[dtag] in d.partition_datasets("test").keys())]   


# ###############################################
# Truncate datasets
# ###############################################
# TODO: move to imports section

truncated_reference, truncated_datasets = PanddaDiffractionDataTruncater()(d.datasets,
                                                                           reference)

# ###############################################
# Load computed variables into dask
# ###############################################
# Load shell sample laoder
shell_sample_loader = sample_loader.instantiate(grid, reference)

# Add truncated reference
shell_reference = truncated_reference

# Rename trucnated datasets
for ddtag, dtag in dask_dtags.items():
    truncated_datasets[ddtag] = truncated_datasets[dtag]

# record max res of shell datasets
shell_max_res = max_res


<multi_dataset_crystalography.dataset.dataset.MultiCrystalDataset instance at 0x7f1ceab633f8>
CPU times: user 1.28 s, sys: 71.1 ms, total: 1.35 s
Wall time: 1.36 s


In [48]:

# ###############################################
# Generate maps
# ###############################################

# Lazy task producing this shell's reference map at the shell resolution limit
shell_ref_map = delayed(get_reference_map)(
    pandda_config.reference_map_getter,
    reference,
    shell_max_res,
    grid,
)

# One lazy map-loading task per dataset alias; each depends on shell_ref_map
xmaps = {
    dtag: delayed(load_sample)(
        pandda_config.map_loader,
        truncated_datasets[dtag],
        grid,
        shell_ref_map,
        shell_max_res,
    )
    for dtag in dask_dtags
}

# ###############################################
# Fit statistical model to training sets
# ###############################################


In [49]:
def identity(x):
    """Return ``x`` unchanged; used as a trivial dask task body."""
    return x

def get_persisted(x):
    """Run ``x`` through a trivial dask task and return the computed value."""
    wrapped = dask.delayed(identity)(x)
    return wrapped.compute()

In [50]:
%%time
# xmaps_delayed = dask.delayed(dict)([(dtag, xmaps_loaded[dtag]) for dtag in dask_dtags])
# xmaps_persisted_future = client.persist(xmaps_delayed)
xmaps_persisted_futures = client.persist([xmaps[dtag] for dtag in dask_dtags])

  (<multi_dataset_crystalography.dataset.sample_load ... 37866747767121)
Consider scattering large objects ahead of time
with client.scatter to reduce scheduler burden and 
keep data on workers

    future = client.submit(func, big_data)    # bad

    big_future = client.scatter(big_data)     # good
    future = client.submit(func, big_future)  # good
  % (format_bytes(len(b)), s)


CPU times: user 1min 44s, sys: 7.68 s, total: 1min 51s
Wall time: 1min 52s


In [51]:
# client.compute(xmaps_persisted_futures[0]).result()



In [52]:
# xmaps_computed = get_persisted(xmaps_persisted)

In [53]:
# xmaps_futures_dict = {dtag: xmaps_persisted_futures[i] for i, dtag in enumerate(dask_dtags)}
# xmaps_futures_dict

In [54]:
%%time
# xmaps_computed = [client.compute(xmaps_futures_dict[dtag]).result() for dtag in dask_dtags]
xmaps_computed = {dtag: client.compute(xmaps_persisted_futures[i]).result() 
                  for i, dtag 
                  in enumerate(dask_dtags)}

CPU times: user 57.8 s, sys: 22 s, total: 1min 19s
Wall time: 3min 35s


In [55]:
%%time

shell_fit_model = fit(pandda_config.statistical_model, 
                                         [xmaps_computed[dtag] for dtag in train_dtags], 
                                         [xmaps_computed[dtag] for dtag in test_dtags]
                                        )

	### Fitting mu!
	### Fitting sigma_uncertainty!
1   5    10   15   20   25   30   35   40   45   50   55   60   
|                                                           |
	### Fitting sigma_adjusted!
CPU times: user 1min 30s, sys: 10.6 s, total: 1min 40s
Wall time: 2min 4s


In [56]:
%%time
# Scatter the fitted model to the workers once, so tasks can reference it as a future.
shell_fit_model_scattered = client.scatter(shell_fit_model)

CPU times: user 748 ms, sys: 313 ms, total: 1.06 s
Wall time: 3.01 s


In [57]:
%%time
xmaps_scattered = client.scatter([xmaps_computed[dtag] for dtag in dask_dtags])

CPU times: user 5.71 s, sys: 2.01 s, total: 7.72 s
Wall time: 12.2 s


In [58]:
# Index the scattered map futures by dataset alias
xmaps_scattered_dict = {}
for position, dtag in enumerate(dask_dtags):
    xmaps_scattered_dict[dtag] = xmaps_scattered[position]

In [59]:
# Scatter the sampling grid once; event-finding tasks take this future instead of the grid itself.
grid_scattered = client.scatter(grid)

In [60]:
%%time
# ###############################################
# Find events
# ###############################################
zmaps = {}
clusters = {}
events = {}
bdcs = {}
for dtag in dask_dtags:
    # Get z maps by evaluating model on maps
    zmaps[dtag]  = delayed(evaluate_model)(shell_fit_model,
                                               xmaps_scattered_dict[dtag]
                                              )

    # Cluster outlying points in z maps
    clusters[dtag]  = delayed(cluster_outliers)(pandda_config.clusterer,
                                                truncated_datasets[dtag],
                                                zmaps[dtag],
                                                grid_scattered
                                              )


    # Find events by filtering the clusters
    events[dtag]  = delayed(filter_clusters)(pandda_config.event_finder,
                                               truncated_datasets[dtag],
                                             clusters[dtag],
                                               grid_scattered
                                              )





CPU times: user 19.8 ms, sys: 6.9 ms, total: 26.7 ms
Wall time: 22.2 ms


In [61]:
# events_delayed = dask.delayed(dict)([(dtag, events[dtag]) for dtag in dask_dtags])
# cluster_outliers(pandda_config.clusterer, truncated_datasets['PDK2_x0008'], x['PDK2_x0008'], grid)

In [62]:
# zmaps_delayed = dask.delayed(dict)([(dtag, zmaps[dtag]) for dtag in dask_dtags])


In [63]:
# x = zmaps_delayed.compute()

In [64]:
# x

In [65]:
# clusters_delayed = dask.delayed(dict)([(dtag, clusters[dtag]) for dtag in dask_dtags])

In [66]:
# y = clusters_delayed.compute()

In [67]:
# y
# right(article(widget(email)))
# right(email) | content | widget
# right(email) | content | (ask() | )
# view.run(email)
# view: return page | div
# page: return content | div
# ask() | greet : reader(str->str)
# greet: str -> reader
# 

In [68]:
%%time
events_persisted_futures = client.persist([events[dtag] for dtag in dask_dtags])


CPU times: user 24.7 s, sys: 6.05 s, total: 30.7 s
Wall time: 30.8 s


In [69]:
# events_persisted_futures

In [70]:
%%time
events_computed = {dtag: client.compute(events_persisted_futures[i]).result() 
                  for i, dtag 
                  in enumerate(dask_dtags)}

CPU times: user 43.6 s, sys: 17.3 s, total: 1min
Wall time: 3min 41s


In [71]:
# Scatter the computed event results once and index the futures by alias.
# (Previously the dict was built from xmaps_scattered, so each alias pointed
# at its map rather than its events.)
events_scattered = client.scatter([events_computed[dtag] for dtag in dask_dtags])
events_scattered_dict = {dtag: events_scattered[i] for i, dtag in enumerate(dask_dtags)}

In [72]:
# Calculate background correction factors: one lazy task per dataset alias.
# NOTE(review): events[dtag] is the original delayed object; it presumably
# reuses the results persisted above via shared keys while those futures are held.
for dtag in dask_dtags:
    bdc_task = delayed(estimate_bdcs)
    bdcs[dtag] = bdc_task(pandda_config.bdc_calculator,
                          truncated_datasets[dtag],
                          xmaps_scattered_dict[dtag],
                          shell_ref_map,
                          events[dtag],
                          grid_scattered)

In [73]:
%%time
bdcs_persisted_futures = client.persist([bdcs[dtag] for dtag in dask_dtags])

CPU times: user 30.2 s, sys: 7.1 s, total: 37.3 s
Wall time: 37.3 s


In [74]:
%%time
bdcs_computed = {dtag: client.compute(bdcs_persisted_futures[i]).result() 
                  for i, dtag 
                  in enumerate(dask_dtags)}



CPU times: user 16.4 s, sys: 11 s, total: 27.4 s
Wall time: 52 s


In [75]:
# Inspect the correction factors: alias -> {(dataset tag, event number): value}
bdcs_computed

{'PDK2_x0008': OrderedDict(),
 'PDK2_x0012': OrderedDict(),
 'PDK2_x0021': OrderedDict(),
 'PDK2_x0033': OrderedDict(),
 'PDK2_x0071': OrderedDict([(('PDK2-x0071', 1), 0.87),
              (('PDK2-x0071', 2), 0.84)]),
 'PDK2_x0081': OrderedDict(),
 'PDK2_x0087': OrderedDict([(('PDK2-x0087', 1), 0.84)]),
 'PDK2_x0100': OrderedDict(),
 'PDK2_x0138': OrderedDict([(('PDK2-x0138', 1), 0.88),
              (('PDK2-x0138', 2), 0.88)]),
 'PDK2_x0173': OrderedDict(),
 'PDK2_x0187': OrderedDict([(('PDK2-x0187', 1), 0.78),
              (('PDK2-x0187', 2), 0.84)]),
 'PDK2_x0190': OrderedDict([(('PDK2-x0190', 1), 0.83)]),
 'PDK2_x0214': OrderedDict([(('PDK2-x0214', 1), 0.88)]),
 'PDK2_x0219': OrderedDict(),
 'PDK2_x0275': OrderedDict(),
 'PDK2_x0329': OrderedDict([(('PDK2-x0329', 1), 0.9299999999999999),
              (('PDK2-x0329', 2), 0.86)]),
 'PDK2_x0352': OrderedDict([(('PDK2-x0352', 1), 0.83)]),
 'PDK2_x0383': OrderedDict([(('PDK2-x0383', 1), 0.84),
              (('PDK2-x0383', 2), 0.75)])

In [76]:
# Inspect the event-finding results per alias. Per the displayed output each is
# (number of events, list of cluster (points, values) pairs, list of Event objects).
events_computed

{'PDK2_x0008': (0, [], []),
 'PDK2_x0012': (0, [], []),
 'PDK2_x0021': (0, [], []),
 'PDK2_x0033': (0, [], []),
 'PDK2_x0071': (2,
  [(<scitbx_array_family_flex_ext.vec3_double at 0x7f1ce8722aa0>,
    <scitbx_array_family_flex_ext.double at 0x7f1ce87223c0>),
   (<scitbx_array_family_flex_ext.vec3_double at 0x7f1ce8722730>,
    <scitbx_array_family_flex_ext.double at 0x7f1ce8722788>)],
  [<pandda.analyse.events.Event at 0x7f1cf4fb0790>,
   <pandda.analyse.events.Event at 0x7f1cf0af8490>]),
 'PDK2_x0081': (0, [], []),
 'PDK2_x0087': (1,
  [(<scitbx_array_family_flex_ext.vec3_double at 0x7f1cf82f1b50>,
    <scitbx_array_family_flex_ext.double at 0x7f1cf82f1aa0>)],
  [<pandda.analyse.events.Event at 0x7f1cf5defc10>]),
 'PDK2_x0100': (0, [], []),
 'PDK2_x0138': (2,
  [(<scitbx_array_family_flex_ext.vec3_double at 0x7f1cf9286cb0>,
    <scitbx_array_family_flex_ext.double at 0x7f1cf9286208>),
   (<scitbx_array_family_flex_ext.vec3_double at 0x7f1cf9788998>,
    <scitbx_array_family_flex_ext.d

In [77]:
# events_delayed = dask.delayed(dict)([(dtag, events[dtag]) for dtag in dask_dtags])

In [78]:
# %%time
# events_persisted_future = events_delayed.persist()

In [79]:
# get_persisted(events_persisted)

In [80]:
# events_persisted_future.compute()

In [81]:
# bdcs_delayed = dask.delayed(dict)([(dtag, bdcs[dtag]) for dtag in dask_dtags])

In [82]:
# %%time
# bdcs_persisted_future = bdcs_delayed.persist()

In [83]:
# bdcs_persisted_future.compute()

In [98]:
# Criticise each individual dataset: build one lazy event-map task per event.
# bdcs_computed[dtag] maps (dataset tag, 1-based event number) -> correction
# factor. The event number (minus one) indexes the cluster list (position 1)
# and the Event list (position 2) of events_computed[dtag].
event_maps = {}
for dtag in dask_dtags:
    event_maps[dtag] = {}
    for event_id, bdc in bdcs_computed[dtag].items():
        event_index = event_id[1] - 1
        event_pair = (events_computed[dtag][1][event_index],
                      events_computed[dtag][2][event_index])
        event_maps[dtag][event_id] = delayed(make_event_map)(make_map,
                                                             tree,
                                                             xmaps_scattered_dict[dtag],
                                                             truncated_datasets[dtag],
                                                             shell_ref_map,
                                                             event_pair,
                                                             bdc,
                                                             shell_fit_model_scattered,
                                                             grid_scattered)

In [None]:
%%time
# event_maps_persisted_futures = client.persist([[event_maps[dtag][event_id] for event_id, bdc in bdcs] for dtag in dask_dtags])
event_maps_persisted_futures = {}
for dtag in dask_dtags:
    event_maps_persisted_futures[dtag] = {}
    for event_id, bdc in bdcs_computed[dtag].items():
        event_maps_persisted_futures[dtag][event_id] = client.persist(event_maps[dtag][event_id])
# get() | lambda name: put("tintin") | lambda: unit(hello name) | liftList | 

In [None]:
# Inspect the configured map maker (note: the event-map cell passes `make_map` rather than this object)
pandda_config.map_maker

In [None]:
%%time
# map_maker, tree, map_loader, truncated_dataset, ref_map, events, bdcs, grid, statistical_model
# event_maps_computed = {dtag: {event_id: client.compute(event_maps_persisted_futures[i]).result() 
#                   for i, dtag 
#                   in enumerate(dask_dtags)}
event_maps_computed = {}
for dtag in dask_dtags:
    event_maps_computed[dtag] = {}
    for event_id, bdc in bdcs_computed[dtag].items():
        event_maps_computed[dtag][event_id] = client.compute(event_maps[dtag][event_id]).result()

In [None]:
# Inspect the gathered event maps: dtag -> event_id -> make_event_map result
event_maps_computed

In [None]:
shell_maps = delayed(make_shell_maps)(map_maker, tree, name, reference, shell_ref_map)

In [None]:
shell_maps_persisted_futures = client.persist(shell_maps)
shell_maps_computed = shell_maps_persisted_futures.result()

# Make event table

In [None]:
event_table = delayed(make_event_table)(event_table_maker, 
                                        tree, 
                                        name, 
                                        d, 
                                        events_computed
                                       )

    

In [None]:
event_table_persisted_future = client.persist(event_table)
event_table_computed = event_table_persisted_future.result()

In [None]:

# # Join the event tables for each dataset to get a shell table
# shell_event_table = delayed(criticise_all)(shell_fit_model,
#                                                 [event_tables[dtag] for dtag in dask_dtags])

In [None]:
# %%time
# event_maps_delayed = dask.delayed(dict)([(dtag, event_maps[dtag]) for dtag in dask_dtags])
# event_maps_computed = event_maps_delayed.compute()

In [None]:
# %%time
# shell_maps_computed = shell_maps.compute()

In [None]:
# %%time
# shell_event_table.compute()

In [None]:
# def shell(idx, ds):
    
#     dsk = {}
#     # ###############################################
#     # Get resolution
#     # ###############################################
#     resolutions_test = max([dts.data.summary.high_res for dtag, dts
#                                 in d.partition_datasets("test").items()])
#     resolutions_train = max([dts.data.summary.high_res for dtag, dts
#                                  in d.partition_datasets("train").items()])
#     max_res = max(resolutions_test, resolutions_train)
    
    
#     # ###############################################
#     # Instantiate sheel variable names
#     # ###############################################
    
#     name = str(idx)
#     shell_sample_loader = "{}_sample_loader".format(name)
#     shell_ref_map = "{}_ref_map".format(name)
#     shell_loaded_model = "{}_loaded_model".format(name)
#     shell_max_res = "{}_max_res".format(name)
#     shell_fit_model = "{}_fit_model".format(name)
#     shell_event_table = "{}_event_table".format(name)
#     shell_reference = "{}_reference".format(name)
#     map_loader = "map_loader"
#     shell_datasets = "{}_truncated_datasets".format(name)
    
#     # Dataset names
    
#     dtags = set(d.partition_datasets("test").keys()
#                     + d.partition_datasets("train").keys()
#                     )
    
#     dask_dtags = {"{}".format(dtag.replace("-", "_")): dtag
#                  for dtag
#                  in dtags}
#     train_dtags = [dtag 
#                    for dtag 
#                    in dask_dtags
#                    if (dask_dtags[dtag] in d.partition_datasets("train").keys())]
#     test_dtags = [dtag 
#                        for dtag 
#                        in dask_dtags
#                        if (dask_dtags[dtag] in d.partition_datasets("test").keys())]   
#     print(test_dtags)
#     print(train_dtags)
    
#     truncated_datasets = {dtag: "{}_{}_truncated_dataset".format(name, dtag) for dtag in dask_dtags}
#     xmaps = {dtag: "{}_{}_xmap".format(name, dtag) for dtag in dask_dtags}
#     zmaps = {dtag: "{}_{}_zmap".format(name, dtag) for dtag in dask_dtags}
#     clusters = {dtag: "{}_{}_clusters".format(name, dtag) for dtag in dask_dtags}
#     events = {dtag: "{}_{}_events".format(name, dtag) for dtag in dask_dtags}
#     bdcs = {dtag: "{}_{}_bdcs".format(name, dtag) for dtag in dask_dtags}
#     event_tables = {dtag: "{}_{}_event_table".format(name, dtag) for dtag in dask_dtags}

#     # ###############################################
#     # Truncate datasets
#     # ###############################################
#     # TODO: move to imports section
 
#     truncated_reference, truncated_datasets_local = PanddaDiffractionDataTruncater()(d.datasets,
#                                                                                reference)
    
#     # ###############################################
#     # Load computed variables into dask
#     # ###############################################
    
#     # Load shell sample laoder
#     dsk[shell_sample_loader] = (lambda sl, g, r: sl.instantiate(g, r),
#                                 "sample_loader",
#                                 "grid",
#                                 "reference")
    
#     dsk[shell_datasets] = TruncatedDatasets(truncated_datasets_local)
    
#     for ddtag, dtag in dask_dtags.items():
#         print("{} : {}".format(truncated_datasets[ddtag], truncated_datasets_local[dtag]))
#         dsk[truncated_datasets[ddtag]] = truncated_datasets_local[dtag]
        
#     # Add truncated reference
#     dsk[shell_reference] = truncated_reference
    
#     # record max res of shell datasets
#     dsk[shell_max_res] = max_res
    
#     # ###############################################
#     # Generate maps
#     # ###############################################
    
#     # Generate reference map for shell
#     dsk[shell_ref_map] = (get_reference_map,
#                           "reference_map_getter", 
#                           "reference", 
#                           shell_max_res, 
#                           "grid")
#     # Load maps
#     for dtag in dask_dtags:
#         dsk[xmaps[dtag]] = (load_sample,
#                             map_loader, 
#                             truncated_datasets[dtag],
#                             "grid",
#                             shell_ref_map,
#                             shell_max_res)
    
#     # ###############################################
#     # Fit statistical model to trianing sets
#     # ###############################################
#     dsk[shell_fit_model] = (fit,
#                            "statistical_model", 
#                              [xmaps[dtag] for dtag in train_dtags], 
#                              [xmaps[dtag] for dtag in test_dtags]
#                           )
    
# #     # ###############################################
# #     # Find events
# #     # ###############################################
# #     for dtag in dask_dtags:
# #         # Get z maps by evaluating model on maps
# #         dsk[zmaps[dtag]]  = (evaluate,
# #                                                    shell_fit_model,
# #                                                    xmaps[dtag]
# #                                                   )

# #         # Cluster outlying points in z maps
# #         dsk[clusters[dtag]]  = (cluster,
# #                                                    shell_fit_model,
# #                                                    xmaps[dtag],
# #                                                    truncated_datasets[dtag],
# #                                                    shell_ref_map
# #                                                   )
        
        
# #         # Find events by filtering the clusters
# #         dsk[events[dtag]]  = (filter_clusters,
# #                                                    shell_fit_model,
# #                                                    xmaps[dtag],
# #                                                    truncated_datasets[dtag],
# #                                                    shell_ref_map
# #                                                   )
        
# #         # Calculate background correction factors
# #         dsk[bdcs[dtag]]  = (estimate_bdcs,
# #                                                    shell_fit_model,
# #                                                    xmaps[dtag],
# #                                                    truncated_datasets[dtag],
# #                                                    shell_ref_map
# #                                                   )
        
    
# #     # Criticise each indiidual dataset (generate statistics, event map and event table)
# #     for dtag in dask_dtags:
# #         dsk[event_tables[dtag]] = (criticise,
# #                                    shell_fit_model,
# #                                    truncated_datasets[dtag],
# #                                    events[dtag]
# #                                   )
    
# #     # Join the event tables for each dataset to get a shell table
# #     dsk[shell_event_table] = (criticise_all,
# #                             shell_fit_model,
# #                             [event_tables[dtag] for dtag in dask_dtags])

#     return dsk
    

In [None]:
# loop over model blocks
for idx, d in ds:
    
    dic = shell(idx, d)
    
    dsk.update(dic)
    
    break
    

In [None]:
print("##################################################")
print("Reference\n")
print("{}: {}".format('reference', dsk["reference"]))

print("##################################################")
print("Reference\n")
print("{}: {}".format('grid', dsk["grid"]))

print("##################################################")
print("Reference\n")
print("{}: {}".format('0_max_res', dsk["0_max_res"]))

print("##################################################")
print("sample loader\n")
print("{}: {}".format('0_sample_loader', dsk['0_sample_loader']))

print("##################################################")
print("map loader\n")
print("{}: {}".format('map_loader', dsk['map_loader']))

print("##################################################")
print("Truncated dataset\n")
print("{}: {}".format('0_PDK2_x0384_truncated_dataset', dsk['0_PDK2_x0384_truncated_dataset']))

print("##################################################")
print("Reference map\n")
print("{}: {}".format('0_ref_map', dsk["0_ref_map"]))

print("##################################################")
print("xmap\n")
print("{}: {}".format("0_PDK2_x0384_xmap", dsk["0_PDK2_x0407_xmap"]))

print("##################################################")
print("fit model\n")
print("{}: {}".format("0_fit_model", dsk["0_fit_model"]))

# print("##################################################")
# print("zmaps\n")
# print("{}: {}".format("0_PDK2_x0384_zmap", dsk["0_PDK2_x0384_zmap"]))


# print("##################################################")
# print("Cluster\n")
# print("{}: {}".format("0_PDK2_x0384_cluster", dsk["0_PDK2_x0384_events"]))

# print("##################################################")
# print("Events\n")
# print("{}: {}".format("0_PDK2_x0384_events", dsk["0_PDK2_x0384_events"]))

# print("##################################################")
# print("Event table\n")
# print("0_PDK2_x0384_event_table: {}".format(dsk["0_PDK2_x0384_event_table"]))

# print("##################################################")
# print("Shel event table\n")
# print("0_event_table: {}".format(dsk["0_event_table"]))

In [None]:
# dsk["0_truncated_datasets"]

In [None]:
# client.get(dsk, "0_truncated_datasets")

In [None]:
# client.get(dsk, "0_PDK2_x0407_truncated_dataset")

In [None]:
# client.get(dsk, "0_ref_map")

In [None]:
# client.get(dsk, "0_max_res")

In [None]:
# client.get(dsk, "map_loader")

In [None]:
# client.get(dsk, "0_PDK2_x0407_xmap")

In [None]:
# client.get(dsk, "statistical_model")

In [None]:
# client.get(dsk, "grid")

In [None]:
client.get(dsk, "0_fit_model")

In [None]:
client.get(dsk, "0_PDK2_x0384_zmap")

In [None]:
client.get(dsk, "0_PDK2_x0384_cluster")

In [None]:
client.get(dsk, "0_PDK2_x0384_events")

In [None]:
client.get(dsk, "0_PDK2_x0384_event_table")

In [None]:
client.close()

In [None]:
# dsk["model_loaded"] = model_loaded

In [None]:
dtags = set(model_loaded.dataset.partition_datasets("test").keys()
                    + model_loaded.dataset.partition_datasets("train").keys()
                    )

In [None]:
resolutions_test = max([d.data.summary.high_res for dtag, d
                                in model_loaded.dataset.partition_datasets("test").items()])
resolutions_train = max([d.data.summary.high_res for dtag, d
                                 in model_loaded.dataset.partition_datasets("train").items()])
max_res = max(resolutions_test, resolutions_train)



In [None]:
sample_loader = model_loaded.dataset.sample_loader

In [None]:
# sample_loaders = {dtag: lambda d: sample_loader.get_sample(res, d)
#                           for dtag
#                           in dtags}

In [None]:
for dtag in dtags:
    dsk[dtag] = (sample_loader.get_sample, max_res, model_loaded.dataset.datasets[dtag])

In [None]:
dsk["params"] = (model_loaded.statistical_model.fit, 
                 [dtag for dtag, d in model_loaded.dataset.partition_datasets("train").items()], 
                 [dtag for dtag, d in model_loaded.dataset.partition_datasets("test").items()])

In [None]:
dsk

In [None]:
params = client.get(dsk, "params")

In [None]:
# Load models over nodes
model_fit = client.submit(fit, model_loaded)


In [None]:
model_fit.result()

In [None]:
# Ask models to process
model_evaluated = evaluate(model_fit)


In [None]:
model_evaluated.result()

In [None]:
# Ask models to criticise
event_tables = criticise(models_evaluated)


In [None]:
client.reset()

In [None]:
client.restart()

In [None]:
client.close()

# Fit, Evaluate, Criticise — All Datasets (Dask Distributed)

## TESTING

In [None]:
cluster = SGECluster(queue="medium.q",
                     cores=1,
                     processes=1,
                           memory="64GB",
                           resourcce_spec="m_mem_free=64G",
                    python="/dls/science/groups/i04-1/conor_dev/ccp4/build/bin/cctbx.python")

In [None]:
cluster.scale(3)

In [None]:
time.sleep(15)

In [None]:
cluster = LocalCluster(n_workers=2, threads_per_worker=1)

In [None]:
client = Client(cluster)

In [None]:
client

In [None]:
# Main Loop
dataloader = DefaultPanDDADataloader(min_train_datasets=60, 
                                     max_test_datasets=60)


In [None]:
ds = [(idx, d) for idx, d in dataloader(dataset)]

In [None]:
# Get base distributed model
pandda_event_model_distributed = PanDDAEventModelDistributed(pandda_config.statistical_model,
                                      pandda_config.clusterer,
                                      pandda_config.event_finder,
                                        dataset=dataset,
                                      bdc_calculator=pandda_config.bdc_calculator,
                                      statistics=[],
                                      map_maker=pandda_config.map_maker, 
                                      event_table_maker=pandda_config.event_table_maker,
                                      cpus=config["args"]["cpus"],
                                      tree=tree)

In [None]:
pandda_event_model_distributed.instantiate(reference,
                                           tree)

In [None]:
# Instantiate models
models = [pandda_event_model_distributed.clone(dataset=d, 
                                   name=idx)
         for idx, d
         in ds]

In [None]:
# models

In [None]:
# # Load model moponents
# models_loaded = client.map(load,
#                     models,
#                           pure=False)

In [None]:
# models_loaded

In [None]:
# # Load models over nodes
# models_fit = client.map(fit, 
#                            models_loaded,
#                        pure=False)


In [None]:
# models_fit

In [None]:
# # Ask models to process
# models_evaluated = client.map(evaluate, 
#                               models_fit,
#                              pure=False)


In [None]:
# models_evaluated

In [None]:
# # Ask models to criticise
# event_tables = client.map(criticise, 
#                                models_evaluated, 
#                                pure=False)


In [None]:
# event_tables

In [None]:
# event_tables_results =[e.result() for e in event_tables]

In [None]:
dsk = {}
for i, model in enumerate(models):
    dsk["load_{}".format(i)] = (load, model)
    dsk["fit_{}".format(i)] = (fit, "load_{}".format(i))
    dsk["evaluate_{}".format(i)] = (evaluate, "fit_{}".format(i))
    dsk["criticise_{}".format(i)] = (criticise, "evaluate_{}".format(i))


In [None]:
dsk

In [None]:
dsk_combined, deps = dask.optimization.fuse(dsk)

In [None]:
client.get(dsk_combined, ["criticise_{}".format(i)
                          for i, model
                          in enumerate(models)])

In [None]:
client.close()

In [None]:
# Visualise tasks


In [None]:
# Ask models for events: gather the event tables back from the cluster.
# Neither a Future nor a list of Futures has a `.results()` method;
# client.gather accepts a single Future or a (nested) collection of them.
# NOTE(review): the client is closed a few cells above — reconnect first.
event_tables = client.gather(event_tables)

In [None]:
# NOTE(review): the ProcessModelSeriel import is commented out in the import
# cells, and pandda_event_model is not bound in any visible cell.
with ProcessModelSeriel() as P:

    # Iterate under a new name so the outer `dataset` is not rebound to the
    # last block (which would break every later dataloader(dataset) call).
    for shell_dataset in dataloader(dataset):

        P(pandda_event_model(shell_dataset))
        # call with self as model: model.fit(); model.evaluate(); model.criticise()
            # Serial: run immediately
            # Qsub: pick model

    # exit:
        # Serial: just go on
        # qsub: wait for the jobs to complete

In [None]:
# Fit, evaluate and criticise one model per dataset block.
# A distinct loop name keeps the outer `dataset` pointing at the full dataset;
# reusing `dataset` rebound it to the last block for every later cell.
# NOTE(review): pandda_event_model is not bound in any visible cell (only the
# PanDDAEventModel class is imported).
for shell_dataset in dataloader(dataset):

    with pandda_event_model(shell_dataset) as model:

        # Fit the statistical model
        print("Fitting model")
        model.fit()

        # Evaluate the fitted model on maps, finding events
        model.evaluate_parallel()

        # Criticise the model: store statistics from fitting and evaluation
        model.criticise()


# Criticise Run

## TESTING

In [None]:
# Criticise Model (run-level design sketch).
# NOTE(review): PanDDARunGraphs, PanDDARunHTML and PanDDARunLog are not imported
# or defined in any visible cell, and PanDDARunStatistics is only defined further
# down the notebook — this cell fails on a fresh kernel until those exist.
# NOTE(review): `dataset` and `model` are whatever earlier loop cells left bound —
# confirm which dataset/model the run statistics should describe.
pandda_statistics = PanDDARunStatistics(dataset, model)  # Generates statistics from dataset and model
PanDDARunGraphs(pandda_statistics)  # Produces a set of graphs of statistics
pandda_html = PanDDARunHTML(pandda_statistics)  # Produces an HTML report from the statistics
PanDDARunLog()  # Produces a log of data processing from dataset, model, graphs and HTML


In [None]:
# NOTE(review): `processor` and pandda_event_model are not bound in any visible cell.
with processor as P:

    # Iterate under a new name so the outer `dataset` is not rebound to the
    # last block (which would break every later dataloader(dataset) call).
    for shell_dataset in dataloader(dataset):

        P(pandda_event_model(shell_dataset))
        # call with self as model: model.fit(); model.evaluate(); model.criticise()
            # Serial: run immediately
            # Qsub: pick model

    # exit:
        # Serial: just go on
        # qsub: wait for the jobs to complete
        

In [None]:
class PanDDARunStatistics(object):
    """Run every run-level statistic and record its log, or its error.

    Parameters
    ----------
    dataset, model
        The dataset and model the statistics describe; stored as
        ``self.dataset`` and ``self.model``.
    statistics : iterable, optional
        Objects exposing ``name``, ``calculate()`` and ``log()``. Defaults to
        no statistics (the original body only held a ``[...]`` placeholder).

    Attributes
    ----------
    trace : dict
        ``name -> statistic.log()`` for statistics that succeeded, or
        ``name -> str(error)`` for statistics whose calculate()/log() raised.
    """

    def __init__(self, dataset, model, statistics=None):
        self.dataset = dataset
        self.model = model
        self.statistics = list(statistics) if statistics is not None else []
        self.trace = {}

        for statistic in self.statistics:
            # Resolve the name before running, so a failure can still be recorded
            # even for objects without a `name` attribute.
            name = getattr(statistic, "name", repr(statistic))
            try:
                statistic.calculate()
                self.trace[name] = statistic.log()
            except Exception as e:
                # Best-effort: one broken statistic must not stop the others.
                self.trace[name] = "{}".format(e)

In [None]:
class OutputNativeMaps:
    
    def calculate(samples):
        
        
        
    def log():
        
    