In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import numpy as np
import openslide
from probreg import cpd
from probreg import transformation as tf
import cv2
from PIL import Image
from pathlib import Path
import pandas as pd
from tqdm import tqdm
import json
import wandb

In [3]:
import sys
sys.path.append("..")
from registration_tree import Rect, QuadTree

In [8]:
def add_help_fields(frame):
    
    frame["image_name_stem"] = [Path(image_name).stem for image_name in frame["image_name"]]    
    frame["patient_id"] = [name.split("_")[2] for name in frame["image_name"]]

    frame["x1"] = [json.loads(vector.replace("\'","\""))['x1'] for vector in frame["vector"]]
    frame["y1"] = [json.loads(vector.replace("\'","\""))['y1'] for vector in frame["vector"]]

    frame["x2"] = [json.loads(vector.replace("\'","\""))['x2'] for vector in frame["vector"]]
    frame["y2"] = [json.loads(vector.replace("\'","\""))['y2'] for vector in frame["vector"]]

    frame["center_x"] = [x1 + ((x2-x1) / 2) for x1, x2 in zip(frame["x1"], frame["x2"])]
    frame["center_y"] = [y1 + ((y2-y1) / 2) for y1, y2 in zip(frame["y1"], frame["y2"])]
    
    frame["center"] = [np.array((center_x, center_y)) for center_x, center_y in zip(frame["center_x"], frame["center_y"])]

    frame["anno_width"] = [x2-x1 for x1, x2 in zip(frame["x1"], frame["x2"])]
    frame["anno_height"]= [y2-y1 for y1, y2 in zip(frame["y1"], frame["y2"])]
    
    return frame

In [9]:
folder = Path("..")

slide_files = {path.name: path for path in Path("D:/Datasets/ScannerStudy").glob("*/*/*.*")}

In [13]:
def train(config=None):
    
    # Initialize a new wandb run
    with wandb.init(config=config):
        # If called by wandb.agent, as below,
        # this config will be set by Sweep Controller
        config = wandb.config
        
        annotations = add_help_fields(pd.read_csv(folder / "Validation/GT.csv"))
        annotations = annotations[annotations["image_type"] == config.image_type]
        
        source_scanner_annotations = annotations[annotations["scanner"] == config.source_scanner]
    
        dist_list, q_list, sigma2_list = [], [], []
                        
        step = 0
        for patient_id in tqdm(source_scanner_annotations["patient_id"].unique()):

            source_annos = source_scanner_annotations[source_scanner_annotations["patient_id"] == patient_id]
            source_anno = source_annos.iloc[0]

            target_patient_annotations = annotations[annotations["patient_id"] == patient_id]

            for target_image_name in tqdm(target_patient_annotations["image_name"].unique()):
                image_dist_list = []

                target_annos = target_patient_annotations[target_patient_annotations["image_name"] == target_image_name]
                target_anno = target_annos.iloc[0]

                source_slide = openslide.OpenSlide(str(slide_files[source_anno.image_name]))
                target_slide = openslide.OpenSlide(str(slide_files[target_anno.image_name]))

                source_dimension = Rect.create(Rect, 0, 0, source_slide.dimensions[0], source_slide.dimensions[1])
                target_dimension = Rect.create(Rect, 0, 0, target_slide.dimensions[0], target_slide.dimensions[1])


                qtree = QuadTree(source_dimension, source_slide, target_dimension, target_slide, **config)
                                  
                q_list.append(qtree.q)
                sigma2_list.append(qtree.sigma2)
                
                intersections = list(set(source_annos["type_name"]).intersection(target_annos["type_name"]))
                
                for type_name in intersections:

                    source_anno = source_annos[source_annos["type_name"] == type_name].iloc[0]
                    target_anno = target_annos[target_annos["type_name"] == type_name].iloc[0]

                    box = [source_anno.center_x, source_anno.center_y, source_anno.anno_width, source_anno.anno_height]
                    target_box = [target_anno.center_x, target_anno.center_y, target_anno.anno_width, target_anno.anno_height]

                    trans_box = qtree.transform_boxes(np.array([box]))[0]

                    distance = np.linalg.norm(target_box[:2]-trans_box[:2])

                    dist_list.append(distance)
                    image_dist_list.append(distance)
                    
                image_dist_list = np.array(image_dist_list)
                wandb.log({
                    "dist_mean_image": image_dist_list.mean(),
                    "dist_mean_image": image_dist_list.min(),
                    "dist_mean_image": image_dist_list.max(),
                    "step": step,
                })
                
                step += 1
        
        dist_list, q_list, sigma2_list = np.array(dist_list), np.array(q_list), np.array(sigma2_list)
        
        wandb.log({
            "dist_mean": dist_list.mean(),
            "dist_min": dist_list.min(),
            "dist_max": dist_list.max(),
            
            "q_mean": q_list.mean(),
            "q_min": q_list.min(),
            "q_max": q_list.max(),
            
            "sigma2_mean": sigma2_list.mean(),
            "sigma2_min": sigma2_list.min(),
            "sigma2_max": sigma2_list.max(),
        })

In [14]:
sweep_id = "q6j7pg6g"

In [15]:
wandb.agent(sweep_id, train, project="quadtree")

wandb: Agent Starting Run: 0txsjquq with config:
wandb: 	crossCheck: False
wandb: 	filter_outliner: False
wandb: 	flann: True
wandb: 	homography: True
wandb: 	image_type: CCMCT
wandb: 	maxFeatures: 256
wandb: 	point_extractor: sift
wandb: 	ratio: 0.8
wandb: 	source_scanner: Aperio
wandb: 	target_depth: 0
wandb: 	thumbnail_size: [1024, 1024]
wandb: 	use_gray: True
Failed to query for notebook name, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable
wandb: wandb version 0.10.15 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade


  0%|                                                                                            | 0/5 [00:00<?, ?it/s]
  0%|                                                                                            | 0/4 [00:00<?, ?it/s][A
 25%|█████████████████████                                                               | 1/4 [00:01<00:05,  1.95s/it][A
 50%|██████████████████████████████████████████                                          | 2/4 [00:05<00:04,  2.30s/it][A
 75%|███████████████████████████████████████████████████████████████                     | 3/4 [00:07<00:02,  2.47s/it][A
100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:24<00:00,  6.06s/it][A
 20%|████████████████▊                                                                   | 1/5 [00:24<01:36, 24.24s/it]
  0%|                                                                                            | 0/4 [00:00<?, ?it/s][A
 25%|█████████████████

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
dist_mean_image,81.87275
step,19
image_name,N1_CCMCT_22108_1.ndp...
_step,20
_runtime,122
_timestamp,1611581078
dist_mean,63.64235
dist_min,0.0
dist_max,574.64732
q_mean,-1.0


0,1
dist_mean_image,▂▁▂▁▂▁▂▇▂▁▂▁▂▁▃█▂▁▇▂
step,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
_step,▁▁▂▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇▇██
_runtime,▁▁▁▂▃▃▃▃▃▃▃▄▄▄▄▅▅▆▆██
_timestamp,▁▁▁▂▃▃▃▃▃▃▃▄▄▄▄▅▅▆▆██
dist_mean,▁
dist_min,▁
dist_max,▁
q_mean,▁
q_min,▁


wandb: Ctrl + C detected. Stopping sweep.
