# Model testing

## Includes

In [None]:
# magic command for matplotlib
%matplotlib notebook

# mass includes
import os, sys, warnings
import torch as t
import numpy as np
import matplotlib.pyplot as plt
from pickle import dump
from umap import UMAP
from sklearn.mixture import GaussianMixture as GMM
from cv2 import imread, imwrite

# Make every sub-folder of the working directory importable, skipping any
# directory whose path contains 'evals'; the notebook modules imported
# below are looked up through these entries.
paths = []
for folder, _, _ in os.walk('.'):
    if 'evals' in folder:
        continue
    paths.append(folder)
sys.path.extend(paths)

from ipynb.fs.full.config import Config
from ipynb.fs.full.network import *
from ipynb.fs.full.dataLoader import *
from ipynb.fs.full.util import *

## Initialization

In [None]:
# define models
opt = Config()
net_enc = Encoder().to(opt.device)
net_sum = Summarizer().to(opt.device)

# load pre-trained weights
net_enc.load()
net_sum.load()

# dataset for testing: test split, normalized, no random transforms
test_dataset = ImageSet(opt,
                        mode='test',
                        norm=True,
                        rand_trans=False,
                        mask_out=True)
# one sample per batch, in dataset order (the PyTorch defaults, made explicit):
# the inference cell folds the batch dimension away with
# `view(n, -1, h, w)`, which is only correct for a batch size of 1
test_loader = t.utils.data.DataLoader(test_dataset,
                                      batch_size=1,
                                      shuffle=False)

# make result folder (no-op if it already exists)
result_path = os.path.join(opt.save_path, 'results')
os.makedirs(result_path, exist_ok=True)

## Run pre-trained models

In [None]:
# set to evaluation mode
net_enc.eval()
net_sum.eval()

# one (file_name, global_encoding, sample_index) tuple per selected window
encode_list = []

# inference only: disable autograd once for the whole pass
with t.no_grad():
    for s_idx, (sample, f_names) in enumerate(test_loader):
        # channels 0-2 are the image, channel 3 is the mask
        imgs = sample[:, :, :3, :, :].to(opt.device)
        masks = sample[:, :, 3, :, :].to(opt.device).unsqueeze(2)

        # reshape for per-image indexing; folding the leading batch dim
        # into n assumes the loader's batch size is 1
        _, n, _, h, w = sample.size()
        imgs = imgs.view(n, -1, h, w)
        masks = masks.view(n, -1, h, w)

        # top-left offsets of every crop that fits inside the image; the
        # "+ 1" keeps the last valid offset (size - crop_size), which the
        # exclusive range bound would otherwise drop (and with it every
        # window when the image is exactly crop_size wide/high)
        h_starts = range(0, h - opt.crop_size + 1, opt.win_stride)
        w_starts = range(0, w - opt.crop_size + 1, opt.win_stride)

        for f_idx, f_name in enumerate(f_names):
            pred_list = []
            iou_list = []
            for h_idx in h_starts:
                for w_idx in w_starts:
                    in_img = imgs[f_idx, :, h_idx:h_idx + opt.crop_size,
                                  w_idx:w_idx + opt.crop_size].unsqueeze(0)
                    in_mask = masks[f_idx, :, h_idx:h_idx + opt.crop_size,
                                    w_idx:w_idx + opt.crop_size].unsqueeze(0)

                    # extract encodings
                    pred_loc_feats, _ = net_enc(in_img, in_mask)
                    pred_glb_feats = net_sum(pred_loc_feats)
                    glb_feats_np = pred_glb_feats.cpu().squeeze().numpy()
                    mask_np = in_mask.cpu().squeeze().numpy()

                    # register predictions; the score is the fraction of
                    # the crop covered by the mask
                    pred_list.append(glb_feats_np)
                    iou_list.append(np.sum(mask_np) / opt.crop_size**2)

            if not pred_list:
                warnings.warn('no %s-pixel crop window fits in image %s; '
                              'skipped' % (opt.crop_size, f_name[0]))
                continue

            # select the top_n windows with the highest mask coverage; an
            # explicit start index makes top_n == 0 select nothing, whereas
            # the slice [-0:] would select every window
            top_start = max(len(iou_list) - opt.top_n, 0)
            top_n_list = np.argsort(iou_list)[top_start:]

            # register samples
            for t_idx in top_n_list:
                encode_list.append((f_name[0], pred_list[t_idx], s_idx))

## Dimension reduction

In [None]:
# Split encode_list, whose items are (file_name, features, sample_index),
# into the feature matrix and the per-window sample indices used as colours.
all_labels = [s_label for _, _, s_label in encode_list]
# n_samples × n_encodings
all_data = np.stack([s_feats for _, s_feats, _ in encode_list], axis=0)

# Project the encodings to 2-D with UMAP (fixed seed for reproducibility).
reducer = UMAP(random_state=42)
all_embeds = reducer.fit_transform(all_data)

# Scatter the embedding, coloured by test-sample index.
plt.figure(figsize=(4, 3))
plt.scatter(x=all_embeds[:, 0], y=all_embeds[:, 1], c=all_labels)

## Fit GMM clusters and export images

In [None]:
# fit a GMM on the UMAP embeddings and predict one cluster label per window
cluster_list = GMM(n_components=opt.data_part[1],
                   init_params='k-means++',
                   max_iter=1000,
                   random_state=42).fit_predict(all_embeds)

# group the window labels by source image name
cluster_dict = {}
for (f_name, _, _), label in zip(encode_list, cluster_list):
    cluster_dict.setdefault(f_name, []).append(label)

# export every image under its dominant cluster
for f_name, labels in cluster_dict.items():
    # read the 3 channels unchanged (flag -1); imread returns None on failure
    src_path = os.path.join(opt.data_path, f_name)
    r = imread(src_path + '-actin.tif', -1)
    b = imread(src_path + '-DNA.tif', -1)
    g = imread(src_path + '-pH3.tif', -1)
    for chan, suffix in ((r, '-actin.tif'), (b, '-DNA.tif'), (g, '-pH3.tif')):
        if chan is None:
            raise FileNotFoundError('cannot read %s' % (src_path + suffix))

    # assemble as a colour image; OpenCV writes channels in BGR order,
    # hence the (b, g, r) stacking
    out_img = np.stack([b, g, r], axis=-1).astype('float32')

    # joint min-max scaling to [0, 1] over all channels; a constant image
    # would divide by zero, so it is written as black instead
    lo, hi = np.min(out_img), np.max(out_img)
    if hi > lo:
        out_img = (out_img - lo) / (hi - lo)
    else:
        out_img = np.zeros_like(out_img)
    out_img = (out_img * 255).astype('uint8')

    # most frequent cluster among this image's windows (ties -> lowest label)
    dom_label = np.argmax(np.bincount(labels))

    # save image according to the label; imwrite reports failure via its
    # return value rather than raising, so surface it
    out_path = os.path.join(result_path,
                            'M%02d-F%s.png' % (dom_label, f_name))
    if not imwrite(out_path, out_img):
        warnings.warn('failed to write %s' % out_path)

print('Classification results have been saved to %s' % result_path)