To run this notebook, load it in a local Jupyter instance (`pip install jupyter`). You'll also need these dependencies:

```
pip install tf-nightly
pip install google-cloud-storage
pip install requests
pip install google-api-python-client
```

You may also need to run this if you're not inside a Google Cloud VM:

```
gcloud auth application-default login
```

You need to configure [OAuth](https://support.google.com/cloud/answer/6158849?hl=en). It's a complicated process, best described [here](https://github.com/googleapis/google-api-python-client/blob/master/docs/client-secrets.md). In the end you download the `client_secrets.json` file and pass its path via `--client_secrets` below.

In [None]:
import argparse
from astronet import tune

config_name = 'local_global_new'

# Populate tune.FLAGS as if tune.py had been invoked from the command line, then
# build the API client from it. The model name is spelled 'AstroCNNModel' to
# match the name passed to models.get_model_config in later cells and to the
# tune.py command at the bottom of this notebook.
tune.FLAGS = tune.parser.parse_args([
  '--client_secrets', '../client_secrets.json',
  '--model', 'AstroCNNModel',
  '--config_name', config_name,
  # NOTE(review): presumably an empty value is enough here because this
  # notebook only inspects existing studies and does not train — confirm.
  '--train_files', '',
])

client = tune.initialize_client()

In [None]:
import pandas as pd

pd.set_option('display.max_colwidth', 100)

# List every study under the tuner's parent resource, newest first.
resp = client.projects().locations().studies().list(
  parent=tune.study_parent()).execute()
# The JSON response may omit the 'studies' key entirely when there are none;
# fail with a readable message instead of a bare KeyError.
study_list = resp.get('studies', [])
assert study_list, 'No studies found under {}'.format(tune.study_parent())
studies = pd.DataFrame(study_list).sort_values('createTime', ascending=False)
studies.head(5)

In [None]:
# `studies` keeps its original integer labels after sort_values, so this selects
# the row whose *label* (the index column shown in the table above) is
# STUDY_LABEL — not the STUDY_LABEL-th row of the sorted table.
STUDY_LABEL = 14
# Trials with recall below this are treated as degenerate and dropped.
MIN_RECALL = 0.05

study = studies.loc[STUDY_LABEL, 'name']
study_id = '{}/studies/{}'.format(tune.study_parent(), study.split('/')[-1])
print(study_id)
resp = client.projects().locations().studies().trials().list(
  parent=study_id).execute()


def final_metric(trial, name):
  """Return the value of metric `name` from `trial`'s final measurement.

  Raises ValueError if the metric is absent or reported more than once.
  """
  value, = (m['value'] for m in trial['finalMeasurement']['metrics']
            if m['metric'] == name)
  return value


# Parallel lists, one entry per kept trial; a list position is what the plot
# below labels and what `best` later indexes.
metrics_r = []
metrics_p = []
params = []
trial_ids = []
for trial in resp.get('trials', []):
  # Skip trials without a final measurement (presumably unfinished or
  # infeasible — TODO confirm against the API's trial states).
  if 'finalMeasurement' not in trial:
    continue

  r = final_metric(trial, 'r')
  p = final_metric(trial, 'p')
  if r < MIN_RECALL:
    continue

  params.append(trial['parameters'])
  metrics_r.append(r)
  metrics_p.append(p)
  trial_ids.append(int(trial['name'].split('/')[-1]))

# Count of kept trials (max(trial_ids) is an ID, not a count).
print(len(trial_ids), 'valid trials')

In [None]:
import matplotlib
from matplotlib import pyplot as plt

matplotlib.rcParams.update({'font.size': 16})

# Number of points to label, ranked by (recall, precision).
TOP_K = 5

fig, ax = plt.subplots(figsize=(16, 16))
ax.scatter(metrics_r, metrics_p)
ax.set_xlabel("recall")
ax.set_ylabel("precision")
ax.set_title("precision vs. recall per trial")

# Each label is the point's position in trial_ids / params (not its trial ID),
# so it can be copied directly into `best` below. Tuples compare
# lexicographically, so ranking is by recall first, then precision.
sorted_metrics = sorted(zip(metrics_r, metrics_p))
if sorted_metrics:
  # Clamp so fewer than TOP_K trials labels all of them instead of raising.
  threshold = sorted_metrics[-min(TOP_K, len(sorted_metrics))]
  for i, point in enumerate(zip(metrics_r, metrics_p)):
    if point >= threshold:
      ax.annotate(' {}'.format(i), point)

In [None]:
# `best` is a position in trial_ids / params / metrics_r / metrics_p (the number
# drawn next to each point in the plot), NOT the trial ID reported by the API.
best = 786
# Reject out-of-range and negative positions; Python would otherwise silently
# accept a negative index and count from the end.
assert 0 <= best < len(trial_ids), (
  'best={} is out of range for {} valid trials'.format(best, len(trial_ids)))
print(trial_ids[best], 'recall={:.3f} precision={:.3f}'.format(
  metrics_r[best], metrics_p[best]))

In [None]:
import pprint
from astronet import models

# Load the named config, apply each parameter of trial `best` to its hparams
# through tune.map_param, and show the resulting hyperparameters.
config = models.get_model_config('AstroCNNModel', config_name)
hparams = config['hparams']
for trial_param in params[best]:
  tune.map_param(hparams, trial_param)

pprint.pprint(hparams)

In [None]:
import difflib
import pprint
from astronet import models

# Compare the untouched config with a copy that has trial `best`'s parameters
# applied, so only the settings the tuner changed are printed (n=0: no context).
config1 = models.get_model_config('AstroCNNModel', config_name)
config2 = models.get_model_config('AstroCNNModel', config_name)
for param in params[best]:
  tune.map_param(config2['hparams'], param)

pp = pprint.PrettyPrinter()
# lineterm='' because splitlines() yields lines without trailing newlines; with
# the default lineterm='\n' the ---/+++/@@ lines would end in '\n' and show up
# followed by blank lines after '\n'.join.
print('\n'.join(difflib.unified_diff(
  pp.pformat(config1).splitlines(), pp.pformat(config2).splitlines(),
  fromfile='base config',
  tofile='trial {}'.format(trial_ids[best]),
  n=0,
  lineterm='',
)))

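Command used to launch the tuning study itself. The glob arguments are quoted so the shell passes them to `tune.py` unexpanded; unquoted, shells such as zsh abort with "no matches found".

```
python astronet/tune.py --model=AstroCNNModel --config_name=local_global_new --train_files='astronet/tfrecords-new/test-0000[0-5]*' --eval_files='astronet/tfrecords-new/test-0000[6-6]*' --train_steps=7000 --tune_trials=1000
```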