In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import numpy as np
import openslide
from probreg import cpd
from probreg import transformation as tf
import cv2
from PIL import Image
from pathlib import Path
import pandas as pd
from tqdm import tqdm
import json
import wandb
import platform

In [3]:
import sys
sys.path.append("..")
from registration_tree import Rect, QuadTree

In [4]:
sweep_config = {
    'method': 'random', #'bayes' # 'random'
}

metric = {
    'name': 'dist_mean',
    'goal': 'minimise'   
    }


sweep_config['metric'] = metric

In [5]:
parameters_dict = {
    'point_extractor': {
        'values': ["sift"] #'orb', 
        },
    'maxFeatures': {
        #'values': [64, 128, 256, 512, 768, 1024, 2048]
        'values': [64, 128, 256, 512, 768, 1024, 2048]
        },
    'crossCheck': {
        'values': [False]
        },
    'flann': {
        'values': [False, True]
        },
    'ratio': {
        #'values': [.1, .2, .3, .4, .5, .6, .7, .8, .9]
        'values': [.3, .4, .5, .6, .7, .8, .9]
        },
    'use_gray': {
        'values': [True, False]
        },
    'homography': {
        'values': [True, False]
        },
    'filter_outliner': {
        'values': [False]
        },
    'target_depth': {
        'values': [0]
        },
    'thumbnail_size': {
        #'values': [(1024, 1024), (2048, 2048), (4096, 4096)]
        'values': [(1024, 1024), (2048, 2048), (4096, 4096), (8192, 8192)] #
        },    
    
    'image_type': {
        'values': ["Cyto"] #, "Cyto"
        },
    
    'source_scanner': {
        'values': ["Aperio"]
        },
}
sweep_config['parameters'] = parameters_dict

In [6]:
sweep_id = wandb.sweep(sweep_config, project="quadtree")
sweep_id

Create sweep with ID: nvchnev8
Sweep URL: https://wandb.ai/christianml/quadtree/sweeps/nvchnev8


'nvchnev8'

In [7]:
def add_help_fields(frame):
    
    frame["image_name_stem"] = [Path(image_name).stem for image_name in frame["image_name"]]    
    frame["patient_id"] = [name.split("_")[2] for name in frame["image_name"]]

    frame["x1"] = [json.loads(vector.replace("\'","\""))['x1'] for vector in frame["vector"]]
    frame["y1"] = [json.loads(vector.replace("\'","\""))['y1'] for vector in frame["vector"]]

    frame["x2"] = [json.loads(vector.replace("\'","\""))['x2'] for vector in frame["vector"]]
    frame["y2"] = [json.loads(vector.replace("\'","\""))['y2'] for vector in frame["vector"]]

    frame["center_x"] = [x1 + ((x2-x1) / 2) for x1, x2 in zip(frame["x1"], frame["x2"])]
    frame["center_y"] = [y1 + ((y2-y1) / 2) for y1, y2 in zip(frame["y1"], frame["y2"])]
    
    frame["center"] = [np.array((center_x, center_y)) for center_x, center_y in zip(frame["center_x"], frame["center_y"])]

    frame["anno_width"] = [x2-x1 for x1, x2 in zip(frame["x1"], frame["x2"])]
    frame["anno_height"]= [y2-y1 for y1, y2 in zip(frame["y1"], frame["y2"])]
    
    return frame

In [8]:
folder = Path("..")

slide_folder = Path("D:/Datasets/ScannerStudy")
if slide_folder.exists() == False:
    slide_folder = Path("/data/ScannerStudy")
if slide_folder.exists() == False:
    slide_folder = Path("/mnt/d/Datasets/ScannerStudy")
if slide_folder.exists() == False:
    slide_folder = Path("/data/ScannerStudy")
    
slide_files = {path.name: path for path in slide_folder.glob("*/*/*.*")}

In [9]:
def train(config=None):
    
    # Initialize a new wandb run
    with wandb.init(config=config):
        # If called by wandb.agent, as below,
        # this config will be set by Sweep Controller
        config = wandb.config
        
        annotations = add_help_fields(pd.read_csv(folder / "Validation/GT.csv"))
        annotations = annotations[annotations["image_type"] == config.image_type]
        
        source_scanner_annotations = annotations[annotations["scanner"] == config.source_scanner]
    
        dist_list, mean_reg_error_list = [], []
                        
        step = 0
        for patient_id in tqdm(source_scanner_annotations["patient_id"].unique()):

            source_annos = source_scanner_annotations[source_scanner_annotations["patient_id"] == patient_id]
            source_anno = source_annos.iloc[0]

            target_patient_annotations = annotations[annotations["patient_id"] == patient_id]

            for target_image_name in tqdm(target_patient_annotations["image_name"].unique()):
                image_dist_list = []

                target_annos = target_patient_annotations[target_patient_annotations["image_name"] == target_image_name]
                target_anno = target_annos.iloc[0]
                
                if source_anno.scanner == target_anno.scanner:
                    continue

                source_slide = openslide.OpenSlide(str(slide_files[source_anno.image_name]))
                target_slide = openslide.OpenSlide(str(slide_files[target_anno.image_name]))

                source_dimension = Rect.create(Rect, 0, 0, source_slide.dimensions[0], source_slide.dimensions[1])
                target_dimension = Rect.create(Rect, 0, 0, target_slide.dimensions[0], target_slide.dimensions[1])


                qtree = QuadTree(source_dimension, source_slide, target_dimension, target_slide, debug=False, **config)
                                  
                mean_reg_error_list.append(qtree.mean_reg_error)
                
                intersections = list(set(source_annos["type_name"]).intersection(target_annos["type_name"]))
                
                for type_name in intersections:

                    source_anno = source_annos[source_annos["type_name"] == type_name].iloc[0]
                    target_anno = target_annos[target_annos["type_name"] == type_name].iloc[0]

                    box = [source_anno.center_x, source_anno.center_y, source_anno.anno_width, source_anno.anno_height]
                    target_box = [target_anno.center_x, target_anno.center_y, target_anno.anno_width, target_anno.anno_height]

                    trans_box = qtree.transform_boxes(np.array([box]))[0]

                    distance = np.linalg.norm(target_box[:2]-trans_box[:2])

                    dist_list.append(distance)
                    image_dist_list.append(distance)
                    
                image_dist_list = np.array(image_dist_list)
                wandb.log({
                    "dist_mean_image": image_dist_list.mean(),
                    "dist_mean_image": image_dist_list.min(),
                    "dist_mean_image": image_dist_list.max(),
                    "mean_reg_error_image": qtree.mean_reg_error,
                    "step": step,
                })
                
                step += 1
        
        dist_list, mean_reg_error_list = np.array(dist_list), np.array(mean_reg_error_list),
        
        wandb.log({
            "dist_mean": dist_list.mean(),
            "dist_min": dist_list.min(),
            "dist_max": dist_list.max(),
            
            "mean_reg_error": mean_reg_error_list.mean(),
        })

In [10]:
sweep_id

'nvchnev8'

In [None]:
wandb.agent(sweep_id, train)

INFO - 2021-01-26 15:44:48,448 - pyagent - Starting sweep agent: entity=None, project=None, count=None
[34m[1mwandb[0m: Agent Starting Run: i1fzmr8i with config:
[34m[1mwandb[0m: 	crossCheck: False
[34m[1mwandb[0m: 	filter_outliner: False
[34m[1mwandb[0m: 	flann: False
[34m[1mwandb[0m: 	homography: False
[34m[1mwandb[0m: 	image_type: Cyto
[34m[1mwandb[0m: 	maxFeatures: 128
[34m[1mwandb[0m: 	point_extractor: sift
[34m[1mwandb[0m: 	ratio: 0.9
[34m[1mwandb[0m: 	source_scanner: Aperio
[34m[1mwandb[0m: 	target_depth: 0
[34m[1mwandb[0m: 	thumbnail_size: [1024, 1024]
[34m[1mwandb[0m: 	use_gray: True
[34m[1mwandb[0m: Currently logged in as: [33mchristianml[0m (use `wandb login --relogin` to force relogin)


  0%|          | 0/5 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s][A
 25%|██▌       | 1/4 [00:06<00:20,  6.68s/it][A
 75%|███████▌  | 3/4 [00:14<00:04,  4.71s/it][A
100%|██████████| 4/4 [00:18<00:00,  4.62s/it][A
 20%|██        | 1/5 [00:18<01:13, 18.48s/it]
  0%|          | 0/4 [00:00<?, ?it/s][A
 25%|██▌       | 1/4 [00:04<00:12,  4.07s/it][A
 75%|███████▌  | 3/4 [00:10<00:03,  3.37s/it][A
100%|██████████| 4/4 [00:12<00:00,  3.24s/it][A
 40%|████      | 2/5 [00:31<00:45, 15.23s/it]
  0%|          | 0/4 [00:00<?, ?it/s][A
 25%|██▌       | 1/4 [00:04<00:14,  4.82s/it][A
 75%|███████▌  | 3/4 [00:11<00:03,  3.87s/it][A
100%|██████████| 4/4 [00:17<00:00,  4.29s/it][A
 60%|██████    | 3/5 [00:48<00:32, 16.12s/it]
  0%|          | 0/4 [00:00<?, ?it/s][A
 25%|██▌       | 1/4 [00:01<00:04,  1.56s/it][A
 75%|███████▌  | 3/4 [00:06<00:02,  2.26s/it][A
100%|██████████| 4/4 [00:09<00:00,  2.27s/it][A
 80%|████████  | 4/5 [00:57<00:13, 13.35s/it]
  0%|          | 0/4 [00:0

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
dist_mean_image,38.85395
mean_reg_error_image,6373.07629
step,14.0
_step,15.0
_runtime,76.0
_timestamp,1611672365.0
dist_mean,125.45839
dist_min,0.65812
dist_max,3858.84867
mean_reg_error,4313.90214


0,1
dist_mean_image,▁▁█▁▁▁▁▁▁▁▁▁▁▁▁
mean_reg_error_image,▂▃█▁▁▃▁▂▃▄▄▆▂▃▅
step,▁▁▂▃▃▃▄▅▅▅▆▇▇▇█
_step,▁▁▂▂▃▃▄▄▅▅▆▆▇▇██
_runtime,▁▂▂▃▃▄▄▅▅▆▆▆▇███
_timestamp,▁▂▂▃▃▄▄▅▅▆▆▆▇███
dist_mean,▁
dist_min,▁
dist_max,▁
mean_reg_error,▁


[34m[1mwandb[0m: Agent Starting Run: tw96jtrp with config:
[34m[1mwandb[0m: 	crossCheck: False
[34m[1mwandb[0m: 	filter_outliner: False
[34m[1mwandb[0m: 	flann: False
[34m[1mwandb[0m: 	homography: False
[34m[1mwandb[0m: 	image_type: Cyto
[34m[1mwandb[0m: 	maxFeatures: 512
[34m[1mwandb[0m: 	point_extractor: sift
[34m[1mwandb[0m: 	ratio: 0.5
[34m[1mwandb[0m: 	source_scanner: Aperio
[34m[1mwandb[0m: 	target_depth: 0
[34m[1mwandb[0m: 	thumbnail_size: [8192, 8192]
[34m[1mwandb[0m: 	use_gray: True


  0%|          | 0/5 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s][A
 25%|██▌       | 1/4 [21:32<1:04:36, 1292.22s/it][A
 75%|███████▌  | 3/4 [58:11<19:09, 1149.52s/it]  [A
100%|██████████| 4/4 [1:01:32<00:00, 923.05s/it][A
 20%|██        | 1/5 [1:01:32<4:06:08, 3692.20s/it]
  0%|          | 0/4 [00:00<?, ?it/s][A
 25%|██▌       | 1/4 [25:59<1:17:59, 1559.92s/it][A
 50%|█████     | 2/4 [26:00<21:25, 642.94s/it]   [A