In [None]:
import numpy as np
import matplotlib
from matplotlib import pyplot as plt
from skimage import io
from os.path import expanduser
from tqdm import tqdm
HOME = expanduser("~")
import os
import cv2

In [None]:
import torch
import torch.utils.data
from PIL import Image
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

In [None]:
def get_model_instance_segmentation(num_classes):
    """Build a Mask R-CNN (ResNet-50 FPN) with fresh box and mask heads.

    The network is created with ``pretrained=True`` (COCO weights), then both
    prediction heads are replaced by new, untrained predictors sized for
    ``num_classes`` outputs.

    Parameters
    ----------
    num_classes : int
        Number of classes handed to both the box predictor and the mask
        predictor.

    Returns
    -------
    The modified torchvision Mask R-CNN model.
    """
    net = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
    heads = net.roi_heads

    # Box head: keep the classifier's input width, swap in a new predictor.
    box_in_features = heads.box_predictor.cls_score.in_features
    heads.box_predictor = FastRCNNPredictor(box_in_features, num_classes)

    # Mask head: keep the input channel count, use a 256-channel hidden layer.
    mask_in_channels = heads.mask_predictor.conv5_mask.in_channels
    mask_hidden_channels = 256
    heads.mask_predictor = MaskRCNNPredictor(mask_in_channels, mask_hidden_channels, num_classes)

    return net

# Absolute location of the trained weights on the shared filesystem.
# It is deliberately NOT joined with HOME: os.path.join discards every
# component that precedes an absolute one, so such a join would be a no-op
# that only looks home-relative.
modelpath = '/net/birdstore/Active_Atlas_Data/data_root/brains_info/masks/mask.model.pth'

# Two output classes are requested for both prediction heads.
loaded_model = get_model_instance_segmentation(num_classes=2)
if os.path.exists(modelpath):
    # NOTE: torch.load unpickles the file -- only load checkpoints from a
    # trusted source.  map_location keeps every tensor on the CPU.
    state_dict = torch.load(modelpath, map_location=torch.device('cpu'))
    loaded_model.load_state_dict(state_dict)
else:
    # Without the checkpoint the model keeps its freshly created heads, so
    # downstream predictions will not reflect the trained mask model.
    print('no model to load')

# Converts a PIL image / ndarray into a torch tensor for the model.
transform = torchvision.transforms.ToTensor()


In [None]:
def combine_dims(a):
    """Collapse a stack of masks into a single 2-D array.

    Parameters
    ----------
    a : numpy.ndarray
        3-D array indexed as ``a[i, :, :]``; the first axis enumerates the
        individual masks.

    Returns
    -------
    numpy.ndarray
        * ``a.shape[0] == 0``: an array of shape ``(a.shape[1], a.shape[2])``
          filled with 255.
        * ``a.shape[0] == 1``: a copy of the only mask.
        * ``a.shape[0] >= 2``: ``np.add`` of the first two masks; any further
          masks are ignored.
    """
    count = a.shape[0]
    if count == 0:
        # Nothing to combine: return an all-255 plane of the right size.
        return np.zeros([a.shape[1], a.shape[2]]) + 255
    if count == 1:
        # A single mask has no partner to add; indexing a[1] here would
        # raise IndexError, so hand back the mask itself (as a new array).
        return a[0, :, :].copy()
    # For boolean masks np.add behaves as a logical OR.
    return np.add(a[0, :, :], a[1, :, :])

def greenify_mask(image):
    """Turn a label image into a 3-channel uint8 mask.

    Every position where ``image == 1`` is set to 255 in all three channels
    (i.e. white, despite the function's name); every other position is 0.
    The channels are stacked along axis 2.
    """
    selected = image == 1
    channel = np.zeros_like(image).astype(np.uint8)
    channel[selected] = 255
    # np.stack copies its inputs, so reusing one plane three times yields
    # three independent, identical channels.
    return np.stack([channel, channel, channel], axis=2)

def merge_mask(image, mask):
    """Stack an image and its mask into one 3-channel array for inspection.

    Channels along axis 2 are, in order: an all-zero uint8 plane shaped like
    ``image``, ``image`` itself, and ``mask``.
    """
    empty_channel = np.zeros_like(image).astype(np.uint8)
    return np.stack((empty_channel, image, mask), axis=2)


In [None]:
#DIR = os.path.join(HOME, 'programming', 'dk39')
# Root of the prepared data for this brain; all working folders live below it.
DIR = '/net/birdstore/Active_Atlas_Data/data_root/pipeline_data/DK63/preps'
INPUT, MASKS, GREENS, TESTS = (
    os.path.join(DIR, subdir)
    for subdir in ('CH1/normalized', 'thumbnail_masked', 'thumbnail_green', 'thumbnail_test')
)
# Only TESTS is created here if it is missing.
os.makedirs(TESTS, exist_ok=True)
# Directory entries of INPUT in sorted order.
files = os.listdir(INPUT)
files.sort()

In [None]:
# File names (zero-padded to three digits, '.tif' suffix) for the listed numbers.
bads = ['{:03d}.tif'.format(number) for number in (116, 119, 127, 132, 140)]

In [None]:
%%time
loaded_model.eval()
file = '011.tif'
infile = os.path.join(INPUT, file)
outpath = os.path.join(MASKS, file)
test_path = os.path.join(TESTS, file)
img = Image.open(infile)
input = transform(img)
input = input.unsqueeze(0)
with torch.no_grad():
    pred = loaded_model(input)
pred_score = list(pred[0]['scores'].detach().numpy())
masks = [(pred[0]['masks']>0.5).squeeze().detach().cpu().numpy()]
mask = masks[0]
dims = mask.ndim
if dims > 2:
    mask = combine_dims(mask)

#del img
#raw_img = cv2.imread(infile, -1)
raw_img = np.array(img)
mask = mask.astype(np.uint8)
mask[mask>0] = 255
merged_img = merge_mask(raw_img, mask)
#cv2.imwrite(test_path, merged_img)    
fig=plt.figure(figsize=(26,18), dpi= 100, facecolor='w', edgecolor='k')
plt.imshow(merged_img, cmap="gray")
plt.title('merged:{}'.format(file), fontsize=30)
plt.tick_params(axis='x', labelsize=30)
plt.tick_params(axis='y', labelsize=30)
plt.show()

In [None]:
# Load one saved "green" thumbnail, binarise channel index 2 and display it.
green_file = '024.tif'
infile = os.path.join(GREENS, green_file)
img = cv2.imread(infile, -1)
if img is None:
    # cv2.imread reports a missing or unreadable file by returning None.
    raise FileNotFoundError('could not read image: {}'.format(infile))
# Index 2 of the last axis; with OpenCV's BGR channel order that is the red
# channel, hence the 'r' in the title.  This is a view, so the assignment
# below also modifies img.
mask = img[:,:,2]
mask[mask>0] = 255
print(mask.dtype, mask.shape, np.unique(mask))

fig = plt.figure(figsize=(26,18), dpi=100, facecolor='w', edgecolor='k')
# Plot the mask computed above and label it with the file actually loaded.
plt.imshow(mask, cmap="gray")
plt.title('r:{}'.format(green_file), fontsize=30)
plt.tick_params(axis='x', labelsize=30)
plt.tick_params(axis='y', labelsize=30)
plt.show()