## **ANCIS Training and Detection**
Attentive Neural Cell Instance Segmentation (ANCIS)

**Code:** ANCIS-Pytorch https://github.com/yijingru/ANCIS-Pytorch

**Paper:** Jingru Yi, Pengxiang Wu, Menglin Jiang, Qiaoying Huang, Daniel J. Hoeppner, Dimitris N. Metaxas, Attentive neural cell instance segmentation, Medical Image Analysis, Volume 55, 2019, Pages 228-240, https://doi.org/10.1016/j.media.2019.05.004.

---

**Begin:** First, gather your data into two directories (folders): one for the original images to be processed and one for the data labels.

If you are only testing and do not have labels, that is fine; labels are optional for testing.

If you are training, labels are required. Separate annotation files are not required.

Validation images are recommended for training.
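If all of your labeled images sit in one folder, you can carve out a validation set before uploading (Section II suggests roughly 10-20% of the training set). The sketch below is a hypothetical local helper, not part of ANCIS; it assumes each label file has exactly the same filename as its image, which is how the notebook pairs them during processing.

```python
import os
import random
import shutil


def split_train_val(img_dir, lbl_dir, val_img_dir, val_lbl_dir, val_frac=0.15, seed=0):
    """Move a random fraction of image/label pairs into validation folders.

    Hypothetical helper (not part of ANCIS). Assumes each label file has the
    same filename as its image. Returns the sorted filenames that were moved.
    """
    # Only consider images that actually have a matching label file
    names = sorted(
        n for n in os.listdir(img_dir)
        if os.path.isfile(os.path.join(img_dir, n))
        and os.path.isfile(os.path.join(lbl_dir, n))
    )
    if not names:
        return []
    n_val = max(1, round(len(names) * val_frac))
    # A seeded RNG makes the split reproducible
    val_names = sorted(random.Random(seed).sample(names, n_val))
    os.makedirs(val_img_dir, exist_ok=True)
    os.makedirs(val_lbl_dir, exist_ok=True)
    for n in val_names:
        shutil.move(os.path.join(img_dir, n), os.path.join(val_img_dir, n))
        shutil.move(os.path.join(lbl_dir, n), os.path.join(val_lbl_dir, n))
    return val_names
```

For example, `split_train_val("images", "labels", "val_images", "val_labels", val_frac=0.15)` moves about 15% of the pairs; upload the remaining folders as training data and the new ones as validation data.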

---
NOTE: To enable the GPU, open the Runtime menu at the top, choose Change runtime type, then select GPU under Hardware accelerator.

---

**Next:**  Run all sections of the code.  Use defaults if unsure. 

---

**Then:**  Train or Test

---

**NOTES:** Click the folder icon on the left side to see files, then click the "up arrow on folder" icon to see the full list.

Noise detection results may vary from image to image.

# **I. Download Requirements and Import Libraries**

In [None]:
#@markdown ___
#@markdown ## **A. Mount Drive**
#@markdown Using Drive?:
DRV = True #@param {type: "boolean"}
from google.colab import files
from google.colab import drive
if DRV == True:
  drive.mount('/content/drive')


In [None]:
#@markdown ___
#@markdown ## **B. Install Dependencies**
if DRV:
  %cd /content/drive/My Drive/Colab Notebooks/ANCIS
else:
  %cd
  !git clone --quiet https://github.com/yijingru/ANCIS-Pytorch.git
  %cd ANCIS-Pytorch

!pip install opencv-python
# Quote the version spec so the shell does not treat ">" as output redirection
!pip install "torch>0.4.0"

In [None]:
%cd
#@markdown ___
#@markdown ## **C. Import Libraries**

import os
import cv2
import numpy as np
import pickle
import skimage
import skimage.io  # skimage.io.imread is used in later cells
from skimage.measure import label, regionprops
import shutil

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler

# **II. Set Data Paths, Import Data and Process**

In [None]:
#@markdown ___
%cd

#@markdown ## **1. Set Data Paths and Load Images**

#@markdown
#@markdown ### **A. Uploading New Images?**
New = True #@param {type:"boolean"}
#@markdown

#@markdown ### **B. Training or Testing?**
#@markdown Both training and validation data are required for the "Train" option.

#@markdown *The Validation Set can be about 10-20% the size of the training set.*

#@markdown *Additionally, image labels are required for training and validation, but not for testing.*
Select = 'Test' #@param ["Test", "Train"] {type:"string"}
#@markdown

# Run Only to Clear Data
#@markdown ### **C. Do You Want to Clear Old Image Data?**
#@markdown New data will be added to your paths. If images are already in those paths,
#@markdown they will remain unless this box is checked. You can also create new path folders for your new images.
Clr = True #@param {type:"boolean"}
#@markdown

#@markdown . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . .

#@markdown Advanced Options:
#@markdown ### **D. Select Datapaths** 
#@markdown Paths set in the cloud; leave defaults unless adding new separate datasets.

#@markdown You can also set Drive paths here if your images are in a Drive folder.
DATA_PATH = "/root/datasets" #@param{type:"string"}
TRAIN_PATH = "/root/datasets/train" #@param{type:"string"}
LABEL_PATH = "/root/datasets/labels" #@param{type:"string"}
VAL_PATH = "/root/datasets/val" #@param{type:"string"}
VAL_LABEL_PATH = "/root/datasets/vallabel" #@param{type:"string"}
TEST_PATH = "/root/datasets/test" #@param{type:"string"}
TEST_LABEL_PATH = "/root/datasets/testlabel" #@param{type:"string"}
#@markdown *Test set does not require labels.*

annoDir = "/root/datasets/Labels"

#@markdown

if Clr == True:
  # Delete every data folder (including the annotation folder) so no stale images remain
  for old_path in [LABEL_PATH, TRAIN_PATH, TEST_PATH, VAL_PATH, TEST_LABEL_PATH,
                   VAL_LABEL_PATH, DATA_PATH, annoDir]:
    if os.path.exists(old_path):
      shutil.rmtree(old_path)
#@markdown 
if not os.path.exists(annoDir):
  os.makedirs(annoDir)

#@markdown . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . .

#@markdown ### **E. Use buttons below and pop-up windows to select all desired images and labels**
#@markdown Select each image individually; the Colab upload dialog does not accept directories. Click and drag, Shift+click, Ctrl+A, etc. all work.
if New == True:
  if Select == "Train":
    if not os.path.exists(TRAIN_PATH):
      os.makedirs(TRAIN_PATH)
    os.chdir(TRAIN_PATH)
    print('****************************************')
    print('Select all training images')
    print('****************************************')
    imagesTr = files.upload()

    if not os.path.exists(LABEL_PATH):
      os.makedirs(LABEL_PATH)
    os.chdir(LABEL_PATH)
    print('****************************************')
    print('Select all corresponding training labels')
    print('****************************************')
    labelsTr = files.upload()
  
    if not os.path.exists(VAL_PATH):
      os.makedirs(VAL_PATH)
    os.chdir(VAL_PATH)
    print('****************************************')
    print('Select all validation images')
    print('****************************************')
    imagesV = files.upload()

    if not os.path.exists(VAL_LABEL_PATH):
      os.makedirs(VAL_LABEL_PATH)
    os.chdir(VAL_LABEL_PATH)
    print('****************************************')
    print('Select all corresponding validation labels')
    print('****************************************')
    labelsV = files.upload()

  elif Select == "Test":
    if not os.path.exists(TEST_PATH):
      os.makedirs(TEST_PATH)
    os.chdir(TEST_PATH)
    print('****************************************')
    print('Select all test images')
    print('****************************************')
    imagesTe = files.upload()

    if not os.path.exists(TEST_LABEL_PATH):
      os.makedirs(TEST_LABEL_PATH)
    os.chdir(TEST_LABEL_PATH)
    print('****************************************')
    print('Select all corresponding test labels, or hit cancel')
    print('****************************************')
    labelsTe = files.upload()
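The processing step pairs each image with its label purely by filename (it reads `os.path.join(LABEL_PATH, name)` for every image), so a label whose name differs from its image makes processing fail when it tries to read the missing file. A minimal, hypothetical check you could run in a new cell after uploading:

```python
import os


def find_unpaired(image_dir, label_dir):
    """Return (images with no label, labels with no image), matched by exact filename.

    Hypothetical helper; mirrors how the processing step pairs files by name.
    """
    def file_names(d):
        return {n for n in os.listdir(d) if os.path.isfile(os.path.join(d, n))}

    images, labels = file_names(image_dir), file_names(label_dir)
    return sorted(images - labels), sorted(labels - images)

# Example in a new cell: missing, extra = find_unpaired(TRAIN_PATH, LABEL_PATH)
```

Both lists should be empty for training and validation data; for testing without labels, every image will appear in the first list, which is expected.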

In [None]:
#@markdown ___
#@markdown ## **2. Noise Detection and Removal**
%cd
#@markdown Would you like to conduct noise detection and removal?

#@markdown *Noise removal is only beneficial for test images.*
NR = False #@param {type: "boolean"}
#@markdown

#@markdown --- Select Noise Weights.  Load pre-trained, or upload your own.
NOISE_WEIGHTS = "Tissue" #@param ["Tissue", "Cell", "Select"] {type: "string"}

def bin_ndarray(ndarray, new_shape, operation='sum'):

    operation = operation.lower()
    if not operation in ['sum', 'mean']:
        raise ValueError("Operation not supported.")
    if ndarray.ndim != len(new_shape):
        raise ValueError("Shape mismatch: {} -> {}".format(ndarray.shape,
                                                           new_shape))
    compression_pairs = [(d, c//d) for d,c in zip(new_shape,
                                                  ndarray.shape)]
    flattened = [l for p in compression_pairs for l in p]
    ndarray = ndarray.reshape(flattened)
    for i in range(len(new_shape)):
        op = getattr(ndarray, operation)
        ndarray = op(-1*(i+1))
    return ndarray

#@markdown  . . . . . . . . . . . . . . . . . . . . . . . . . .  . . . . . . . . . . . . . . . . . . . . . . . . . .  . . . . . . . . . . . . . . . . . . . . . . . . . .  . . . . . . . . . . . . . . . . . . . . . . . . . .

#@markdown Advanced Options:

#@markdown ### **A. Select Target Noisy Image Path**
NOISE_IMG_PATH = "/root/datasets/test" #@param {type: "string"}
#@markdown ### **B. DeNoised Image Save Path**
IMG_SAVE_PATH = "/root/datasets/imagesDN" #@param {type: "string"}
if not os.path.exists(IMG_SAVE_PATH):
  os.makedirs(IMG_SAVE_PATH)
#@markdown ### **C. Noise Detections Save Path**
NOISE_MAP = "/root/datasets/Noise/map" #@param {type:"string"}
if not os.path.exists(NOISE_MAP):
  os.makedirs(NOISE_MAP)
#@markdown

if NOISE_WEIGHTS == "Tissue":
  NW = '/content/drive/My Drive/Colab Notebooks/weights/unet_noise_tissue.hdf5'
elif NOISE_WEIGHTS == "Cell":
  NW = '/content/drive/My Drive/Colab Notebooks/weights/unet_noise_cell_line.hdf5'
elif NOISE_WEIGHTS == "Select":
  nweights = files.upload()  # returns {filename: file bytes}
  for name, data in nweights.items():
    with open(name, 'wb') as f:
      f.write(data)
      print('saved file', name)
    NW = os.path.join(os.getcwd(), name)

if NR == True:
  %cd /root/
  if not os.path.exists('unet'):
    !git clone --quiet https://github.com/zhixuhao/unet.git
  %cd /root/unet/

  from model import *
  from data import *
  import numpy as np 
  import cv2
  import os
  import glob
  import skimage.io as io
  import skimage.transform as trans

  model = load_model(NW)
  test_path = NOISE_IMG_PATH
  save_path = NOISE_MAP
  save_path2 = IMG_SAVE_PATH

  def image_normalized(file_path):

      #img = cv2.imread(file_path,0)
      img = skimage.io.imread(file_path)
      img_shape = img.shape
      M = img_shape[0]
      N = img_shape[1]
      container = np.zeros((M,N,1,1))
      image_size = (img_shape[1],img_shape[0])
      img_standard = bin_ndarray(img*1.2, (M,N), operation='mean')
      #img_standard = cv2.resize(img, (M, M), interpolation=cv2.INTER_CUBIC)
      img_new = img_standard
      imgT = img_standard
      img_new = np.asarray([img_new / 255.])
      return img_new,image_size, imgT

  for name in os.listdir(test_path):
    image_path = os.path.join(test_path,name)
    if os.path.isdir(image_path):
      continue
    ll = len(name)
    img,img_size, imgT = image_normalized(image_path)
    img = np.reshape(img,img.shape+(1,))
    results = model.predict(img)

    # Predicted noise map, multiplied by 255 for saving
    out = 255*results[0,:,:,0]
    cv2.imwrite(os.path.join(save_path, name), out)

    # Denoised image = input image minus predicted noise
    imgDN = imgT - out
    cv2.imwrite(os.path.join(save_path2, name), imgDN)

    print(name)
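`bin_ndarray` (defined in the noise cell and again in Section 3) shrinks an image by averaging non-overlapping blocks, so each target dimension must divide the original one evenly: 1024×1024 → 512×512 works, while 1000×1000 → 512×512 fails inside its `reshape`. Both call sites also pass `img*1.2`, which scales intensities by 1.2 before averaging. The sketch below is an illustrative 2-D equivalent with an up-front divisibility check; the name `block_mean` is hypothetical.

```python
import numpy as np


def block_mean(img, new_shape):
    """Downsample a 2-D array by averaging non-overlapping blocks.

    Illustrative equivalent of bin_ndarray(img, new_shape, operation='mean')
    for 2-D input; the name block_mean is hypothetical.
    """
    h, w = img.shape
    m, n = new_shape
    # Each output pixel must cover a whole (h // m) x (w // n) block
    if h % m or w % n:
        raise ValueError(f"{(h, w)} is not evenly divisible into {(m, n)}")
    # Axes 1 and 3 index positions inside each block; averaging them collapses the block
    return img.reshape(m, h // m, n, w // n).mean(axis=(1, 3))


# Each output value is the mean of one 2x2 block of the 4x4 input
small = block_mean(np.arange(16, dtype=float).reshape(4, 4), (2, 2))
```

If your images are not an exact multiple of the target size, the commented-out `cv2.resize` calls in the notebook are the interpolating alternative.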


In [None]:
#@markdown ___
#@markdown ## **3. Process Data**
#@markdown Resizing, finding data metrics, and rewriting images to network specifications.

os.chdir(DATA_PATH)
#@markdown ### **A. Do You Want to Resize Your Images?**
#@markdown *Downsizing your data can decrease training time and memory requirements.*
#@markdown *If you are using one of our pretrained models, this is required (512x512).*
#@markdown

Resize = True #@param  {type:"boolean"}
#@markdown Resize to MxN
M = 512 #@param {type:"integer"}
N = 512 #@param {type:"integer"}
#@markdown

#@markdown ### **B. Do you want to write new data to your path?** 
#@markdown Required if loading new data; not required if the data has already been processed.
Write = True #@param  {type:"boolean"}

def bin_ndarray(ndarray, new_shape, operation='sum'):
    """
    J.F. Sebastian
    Bins an ndarray in all axes based on the target shape, by summing or
        averaging.

    Number of output dimensions must match number of input dimensions and 
        new axes must divide old ones.

    Example
    -------
    >>> m = np.arange(0,100,1).reshape((10,10))
    >>> n = bin_ndarray(m, new_shape=(5,5), operation='sum')
    >>> print(n)

    [[ 22  30  38  46  54]
     [102 110 118 126 134]
     [182 190 198 206 214]
     [262 270 278 286 294]
     [342 350 358 366 374]]

    """
    operation = operation.lower()
    if not operation in ['sum', 'mean']:
        raise ValueError("Operation not supported.")
    if ndarray.ndim != len(new_shape):
        raise ValueError("Shape mismatch: {} -> {}".format(ndarray.shape,
                                                           new_shape))
    compression_pairs = [(d, c//d) for d,c in zip(new_shape,
                                                  ndarray.shape)]
    flattened = [l for p in compression_pairs for l in p]
    ndarray = ndarray.reshape(flattened)
    for i in range(len(new_shape)):
        op = getattr(ndarray, operation)
        ndarray = op(-1*(i+1))
    return ndarray

MP = 0
Num = 0
fin = 0
Spath = DATA_PATH
if Select == "Train":
  PATH = TRAIN_PATH
elif Select == "Test":
  PATH = TEST_PATH

NME = Select
seti=0

while fin == 0:
  print(NME)
  
  if Select == 'Test' and len(os.listdir(TEST_LABEL_PATH)) == 0:
    seti = 1
  if NME == 'Val':
    if len(os.listdir(VAL_LABEL_PATH)) == 0:
      seti = 1

  for name in os.listdir(PATH):
    if Select == "Train":
      path = os.path.join(TRAIN_PATH, name)
      path2 = os.path.join(LABEL_PATH, name)
      if NR == True:
        path = os.path.join(IMG_SAVE_PATH, name)
      if PATH == VAL_PATH:
        path = os.path.join(VAL_PATH, name)
        path2 = os.path.join(VAL_LABEL_PATH, name)
    elif Select == "Test":
      path = os.path.join(TEST_PATH, name)
      path2 = os.path.join(TEST_LABEL_PATH, name)
      if NR == True:
        path = os.path.join(IMG_SAVE_PATH, name)

    ll = len(name)
    #img = cv2.imread(path,0)
    img = skimage.io.imread(path)
    
    # Get Extension
    if Num == 0:
      nme, ext = os.path.splitext(name)
    img = img.astype('uint8')

    # Resize
    if Resize == True:
      #img = cv2.resize(img,(M,M),interpolation=cv2.INTER_LINEAR)
      img = bin_ndarray(img*1.2, (M,M), operation='mean')
    sh = img.shape

    # Pixel Average
    if PATH != VAL_PATH:
      if len(sh) == 3 or 4:
        MP = (np.mean(img,axis=(0,1)) + MP)
      else:
        MP = (np.mean(img) + MP)

    Num = Num + 1
    print(name)
    if Write == True:
      if not os.path.exists(os.path.join(Spath, NME, name[0:ll-4], "images")):
        os.makedirs(os.path.join(Spath, NME, name[0:ll-4], "images"))
      cv2.imwrite(os.path.join(Spath, NME, name[0:ll-4], "images", name[0:ll-4]+ '.png'), img)

    if Select == "Train" or (Select == "Test" and seti == 0):
      #img2 = cv2.imread(path2)
      img2 = skimage.io.imread(path2)
      # Give each connected object in the label image its own integer ID
      img2 = label(img2)
      # cv2 dsize is (width, height); nearest-neighbor interpolation keeps the IDs intact
      img2 = cv2.resize(img2,(N,M),interpolation=cv2.INTER_NEAREST)
      P = img2.max()
      out = np.zeros([M, N, P])

      if Write == True:
        os.makedirs(os.path.join(Spath, NME, name[0:ll-4], "masks"), exist_ok=True)
        if len(img2.shape) > 2:
          if img2.shape[2] > 1:
            img2 = cv2.cvtColor(img2.astype('uint8'),cv2.COLOR_BGR2GRAY)
        cv2.imwrite(os.path.join(annoDir, name[0:ll-4] + '.png'),img2)
        for n in range(1,P+1):
          # Binary (0/1) mask covering every pixel of instance n
          out[:,:,n-1] = (img2 == n)
          cv2.imwrite(os.path.join(Spath, NME, name[0:ll-4], "masks", name[0:ll-4] + "_" + str(n-1) + '.png'),out[:,:,n-1])

    elif Select == "Test" and seti == 1:
      
      # No test labels: write a placeholder label and a single full-image mask
      img3 = np.ones([M,N])

      if Write == True:
        cv2.imwrite(os.path.join(TEST_LABEL_PATH, name[0:ll-4] + '.png'),img3)
        cv2.imwrite(os.path.join(annoDir, name[0:ll-4] + '.png'),img3)
        os.makedirs(os.path.join(Spath, NME, name[0:ll-4], "masks"), exist_ok=True)
        cv2.imwrite(os.path.join(Spath, NME, name[0:ll-4], "masks", name[0:ll-4] + "_0.png"),img3)
  
  if NME == "Train":
    PATH = VAL_PATH
    NME = "Val"

    if len(os.listdir(VAL_PATH)) == 0:
      if not os.path.exists(os.path.join(Spath,NME,"val1")):
        os.makedirs(os.path.join(Spath,NME,"val1","images"))
        os.makedirs(os.path.join(Spath,NME,"val1","masks"))
      if len(img2.shape) > 2:
        if img2.shape[2] > 1:
          img2 = cv2.cvtColor(img2.astype('uint8'),cv2.COLOR_BGR2GRAY)
      img3 = np.ones([M,N,1])
      cv2.imwrite(os.path.join(Spath,NME,"val1","images","val1.png"),img3)
      cv2.imwrite(os.path.join(Spath,NME,"val1","masks","val1_0.png"),img3)
      cv2.imwrite(os.path.join(annoDir,"val1.png"),img3)
      fin = 1
      
  else:
    fin = 1

MP = MP/Num
if len(sh) == 2:  # grayscale: repeat the scalar mean for three channels
  MP2 = np.array([MP,MP,MP])
else:
  MP2 = MP
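For training (and for labeled test data), the processing cell turns each label image into one binary mask per object: `skimage.measure.label` gives each connected object an integer ID, and each ID `n` is written to `<DATA_PATH>/<Train|Val|Test>/<name>/masks/<name>_<n-1>.png`. The sketch below reproduces that split with NumPy broadcasting; `instance_masks` and `mask_paths` are hypothetical names, and it assumes the label image is already 2-D.

```python
import os
import numpy as np


def instance_masks(label_img):
    """Split a 2-D integer label image into per-instance 0/1 masks, shape (H, W, P).

    Label 0 is background; label n (1..P) becomes channel n-1, matching the
    masks/<name>_<n-1>.png numbering. Hypothetical helper name.
    """
    label_img = np.asarray(label_img)
    ids = np.arange(1, int(label_img.max()) + 1)
    # Broadcasting (H, W, 1) against (P,) compares every pixel with every ID at once
    return (label_img[..., None] == ids).astype(np.uint8)


def mask_paths(root, split, stem, num):
    """File layout written by the processing cell: <root>/<split>/<stem>/masks/<stem>_<k>.png."""
    return [os.path.join(root, split, stem, "masks", f"{stem}_{k}.png") for k in range(num)]
```

Building all masks in one comparison avoids per-pixel Python loops, which matters for label images with hundreds of nuclei.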

# **III. Configuration**



### **Configuration Parameters**

Brief explanation of the configuration parameters available below.

--**trainDir**: Training Directory Containing Images (default = /root/datasets/Train)

--**valDir**: Validation Image Directory for Training (default = /root/datasets/Val)

--**cacheDir**: Directory for Training Cache Files (default = /content/drive/My Drive/Colab Notebooks/ANCIS/cache)

--**batch_size**: Number of Images Processed at a Time.  More = faster, but the number is limited by available memory (default = 2)

--**multi_gpu**: Train Using Multiple GPUs, where Available (default = False)

--**num_workers**: Number of Subprocesses Used to Load Data in Parallel (default = 4)

--**init_lr**: Initial Learning Rate (default = 0.001)

--**num_epochs**: Number of Training Epochs. More may result in better training at the cost of time. Too many could lead to overfitting and reduced accuracy. (default = 200)

--**decayEpoch**: Epoch at Which the Learning Rate Is Multiplied by 0.1 (default = 100)

--**Dec_weight_Dst**: Saving Directory for Trained Detection Network Weights. (default = /content/drive/My Drive/Colab Notebooks/ANCIS/Dec_Weights)

--**Seg_weight_Dst**: Saving Directory for Trained Segmentation Network Weights. (default = /content/drive/My Drive/Colab Notebooks/ANCIS/Seg_Weights)

--**Dec_log_Files**: Saving Directory for Detection Training Log Files. (default = /content/drive/My Drive/Colab Notebooks/ANCIS/Dec_Weights)

--**Seg_log_Files**: Saving Directory for Segmentation Training Log Files. (default = /content/drive/My Drive/Colab Notebooks/ANCIS/Seg_Weights)

--**img_height**: Height, or Rows, of Training Images (default = 512)

--**img_width**: Width, or Columns, of Training Images (default = 512)

--**num_classes**: Number of Classes to Train on, as in "Nucleus" and "Background". (default = 2)

--**top_k**: Maximum Number of Detections Kept per Image. (default = 200)

--**conf_thresh**: Confidence Threshold. Minimum Confidence Score for a Positive Detection for Testing or Validation. (default = 0.5)

--**nms_thresh**: Non-Maximum Suppression (NMS) Overlap Threshold. Higher values allow more detections, at the cost of more overlapping boxes. (default = 0.7)

--**seg_thresh**: Segmentation Threshold. Minimum score for a positive segmentation. (default = 0.5)
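The training cell builds `lr_scheduler.MultiStepLR` with milestones `[decayEpoch, num_epochs]` and `gamma=0.1`, so the learning rate drops tenfold once training reaches `decayEpoch`; the second milestone is never reached inside `range(num_epochs)`. The pure-Python sketch below (hypothetical `multistep_lr`) computes the rate in effect for a given epoch, assuming the scheduler is stepped once at the end of every epoch.

```python
def multistep_lr(epoch, init_lr, milestones, gamma=0.1):
    """Learning rate in effect during a 0-based epoch under a MultiStepLR-style schedule.

    The rate is multiplied by gamma once for every milestone <= epoch.
    Hypothetical helper; assumes one scheduler step at the end of each epoch.
    """
    return init_lr * gamma ** sum(1 for m in milestones if m <= epoch)


# Form values in this notebook: init_lr=0.001, decayEpoch=8, num_epochs=10
schedule = [multistep_lr(e, 0.001, [8, 10]) for e in range(10)]
```

With these values, epochs 0-7 train at 0.001 and epochs 8-9 at 0.0001.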

In [None]:
#@markdown ___
#@markdown ## **1. Network Configuration Parameters**
#@markdown
%cd /content/drive/My Drive/Colab Notebooks/ANCIS

#@markdown ### **A. Train Directory**
trainDir = DATA_PATH + "/Train" #@param {type: "raw"}
#@markdown ### **B. Test Directory**
testDir = DATA_PATH + "/Test" #@param {type: "raw"}
#@markdown ### **C. Validation Directory**
valDir = DATA_PATH + "/Val" #@param {type: "raw"}
#@markdown ### **D. Cache Directory**
cacheDir = "/content/drive/My Drive/Colab Notebooks/ANCIS/cache" #@param {type: "string"}
if not os.path.exists(cacheDir):
  os.makedirs(cacheDir)
#@markdown ### **E. Batch Size**
batch_size = 2 #@param {type: "integer"}
#@markdown ### **F. Use Multiple GPUs (local)**
multi_gpu = False #@param {type: "boolean"}
#@markdown ### **G. Number of Workers**
num_workers = 4 #@param {type: "integer"}
#@markdown ### **H. Initial Learning Rate**
init_lr = 0.001 #@param {type: "number"}
#@markdown ### **I. Number of Epochs**
num_epochs = 10 #@param {type: "integer"}
#@markdown ### **J. Epoch to Begin Learning Rate Decay**
decayEpoch = 8 #@param {type: "integer"}
#@markdown ### **K. Detection Weights Save Path**
Dec_weight_Dst = "/content/drive/My Drive/Colab Notebooks/ANCIS/Dec_Weights" #@param {type: "string"}
if not os.path.exists(Dec_weight_Dst):
  os.makedirs(Dec_weight_Dst)
#@markdown ### **L. Save Path for Segmentation Weights**
Seg_weight_Dst = "/content/drive/My Drive/Colab Notebooks/ANCIS/Seg_Weights" #@param {type: "string"}
if not os.path.exists(Seg_weight_Dst):
  os.makedirs(Seg_weight_Dst)
#@markdown ### **M. Save Path for Detection Training Logs**
Dec_log_Files = "/content/drive/My Drive/Colab Notebooks/ANCIS/Dec_Weights" #@param {type: "string"}
if not os.path.exists(Dec_log_Files):
  os.makedirs(Dec_log_Files)
#@markdown ### **N. Save Path for Segmentation Training Logs**
Seg_log_Files = "/content/drive/My Drive/Colab Notebooks/ANCIS/Seg_Weights" #@param {type: "string"}
if not os.path.exists(Seg_log_Files):
  os.makedirs(Seg_log_Files)
#@markdown ### **O. Number of Classes**
num_classes = 2 #@param {type: "integer"}
#@markdown ### **P. Number of Detections to Keep**
top_k = 200 #@param {type: "integer"}
#@markdown ### **Q. Confidence Threshold**
conf_thresh = 0.5 #@param {type: "number"}
#@markdown ### **R. NMS Threshold**
nms_thresh = 0.7 #@param {type: "number"}
#@markdown ### **S. Segmentation threshold**
seg_thresh = 0.5 #@param {type: "number"}
#@markdown ### **T. Visualize Augmented Training Datasets**
vis = False #@param {type: "boolean"}

imgSuffix = '.png'
annoSuffix = '.png'
img_height = M
img_width = N
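The confidence, NMS, and top-k settings above control how raw SSD outputs are pruned into final detections. The repository's `Detect` class is not shown here, so what follows is a generic greedy non-maximum suppression sketch in NumPy (hypothetical `postprocess`), assuming boxes in the `[y1, x1, y2, x2]` order the notebook writes to its results files. It drops boxes scoring at or below `conf_thresh`, keeps the highest-scoring box, discards remaining boxes whose IoU with it exceeds `nms_thresh`, and stops after `top_k` boxes.

```python
import numpy as np


def iou(box, boxes):
    """Intersection-over-union of one [y1, x1, y2, x2] box with an (N, 4) array of boxes."""
    y1 = np.maximum(box[0], boxes[:, 0])
    x1 = np.maximum(box[1], boxes[:, 1])
    y2 = np.minimum(box[2], boxes[:, 2])
    x2 = np.minimum(box[3], boxes[:, 3])
    inter = np.clip(y2 - y1, 0, None) * np.clip(x2 - x1, 0, None)
    area = (box[2] - box[0]) * (box[3] - box[1])
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    return inter / (area + areas - inter)


def postprocess(boxes, scores, conf_thresh=0.5, nms_thresh=0.7, top_k=200):
    """Confidence filter + greedy NMS + top-k cap (generic sketch, not ANCIS's Detect)."""
    keep = scores > conf_thresh
    boxes, scores = boxes[keep], scores[keep]
    order = np.argsort(-scores)  # highest score first
    kept = []
    while order.size > 0 and len(kept) < top_k:
        best = order[0]
        kept.append(best)
        rest = order[1:]
        # Keep only boxes that overlap the chosen box by at most nms_thresh
        order = rest[iou(boxes[best], boxes[rest]) <= nms_thresh]
    kept = np.array(kept, dtype=int)
    return boxes[kept], scores[kept]
```

Raising `nms_thresh` suppresses fewer overlapping boxes, which is why higher values yield more (and more overlapping) detections for tightly packed cells.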

# **IV. Detection (Region Proposal) Network**

In [None]:
%cd /content/drive/My Drive/Colab Notebooks/ANCIS
#@markdown ___
#@markdown ## **1. Train Detection Network**
#@markdown

#@markdown Train the Single Shot Detector (SSD) used by ANCIS for region detection.

import torch
from dec_utils import *
from models import dec_net
from dec_utils import dec_transforms, dec_eval, dec_dataset_kaggle

def collater(data):
    imgs = []
    bboxes = []
    labels = []
    for sample in data:
        imgs.append(sample[0])
        bboxes.append(sample[1])
        labels.append(sample[2])
    return torch.stack(imgs,0), bboxes, labels

data_transforms = {
'train': dec_transforms.Compose([dec_transforms.ConvertImgFloat(),
                                dec_transforms.PhotometricDistort(),
                                dec_transforms.Expand(),
                                dec_transforms.RandomSampleCrop(),
                                dec_transforms.RandomMirror_w(),
                                dec_transforms.RandomMirror_h(),
                                dec_transforms.Resize(img_height, img_width),
                                dec_transforms.ToTensor()]),

'val': dec_transforms.Compose([dec_transforms.ConvertImgFloat(),
                              dec_transforms.Resize(img_height, img_width),
                              dec_transforms.ToTensor()])
}

dsets = {'train': dec_dataset_kaggle.NucleiCell(trainDir, annoDir, data_transforms['train'],
                                            imgSuffix=imgSuffix, annoSuffix=annoSuffix),
    'val': dec_dataset_kaggle.NucleiCell(valDir, annoDir, data_transforms['val'],
                                          imgSuffix=imgSuffix, annoSuffix=annoSuffix)}

dataloader = torch.utils.data.DataLoader(dsets['train'],
                                    batch_size = batch_size,
                                    shuffle = True,
                                    num_workers = num_workers,
                                    collate_fn = collater,
                                    pin_memory = True)

model = dec_net.resnetssd50(pretrained=True, num_classes=num_classes)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

if multi_gpu:
  model = nn.DataParallel(model)
model = model.to(device)

optimizer = optim.SGD(model.parameters(), lr=init_lr, momentum=0.9)
scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[decayEpoch, num_epochs], gamma=0.1)
criterion = DecLoss(img_height=img_height,
                img_width= img_width,
                num_classes=num_classes,
                variances=[0.1, 0.2])

if vis:
  # Colab cannot open OpenCV windows, so each augmented sample is shown inline with cv2_imshow
  from google.colab.patches import cv2_imshow
  for idx in range(len(dsets['train'])):
      img, bboxes, labels = dsets['train'].__getitem__(idx)
      # CHW tensor -> contiguous HWC array that cv2 drawing functions accept
      img = np.ascontiguousarray(img.numpy().transpose(1, 2, 0)*255)
      bboxes = bboxes.numpy()
      labels = labels.numpy()
      for bbox in bboxes:
          # cv2.rectangle needs integer pixel coordinates
          y1, x1, y2, x2 = [int(v) for v in bbox]
          cv2.rectangle(img, (x1, y1), (x2, y2), (255, 255, 255), 2, lineType=1)
      cv2_imshow(np.uint8(img))

# for validation data -----------------------------------
detector = Detect(num_classes=num_classes,
              top_k=top_k,
              conf_thresh=conf_thresh,
              nms_thresh=nms_thresh,
              variance=[0.1, 0.2])
anchorGen = Anchors(img_height, img_width)
anchors = anchorGen.forward()
if not os.path.exists(cacheDir):
  os.mkdir(cacheDir)
# --------------------------------------------------------
train_loss_dict = []
ap05_dict = []
ap07_dict = []
for epoch in range(num_epochs):
  print('Epoch {}/{}'.format(epoch, num_epochs - 1))
  print('-' * 10)

  for phase in ['train', 'val']:
      if phase == 'train':
          model.train()
          running_loss = 0.0
          for inputs, bboxes, labels in dataloader:
              inputs = inputs.to(device)
              # zero the parameter gradients
              optimizer.zero_grad()

              # forward
              # track history if only in train
              with torch.set_grad_enabled(phase == 'train'):
                  outputs = model(inputs)
                  loss_locs, loss_conf = criterion(outputs, bboxes, labels)
                  loss = loss_locs + loss_conf
                  # backward + optimize only if in training phase
                  if phase == 'train':
                      loss.backward()
                      optimizer.step()

              # statistics
              running_loss += loss.item() * inputs.size(0)

          # Step the LR scheduler once per epoch, after the optimizer updates (PyTorch >= 1.1 order)
          scheduler.step()
          epoch_loss = running_loss / len(dsets[phase])

          print('{} Loss: {:.4f}'.format(phase, epoch_loss))
          train_loss_dict.append(epoch_loss)
          np.savetxt(Dec_log_Files + '/dec_train_loss.txt', train_loss_dict, fmt='%.6f')
          if epoch % 5 == 0:
              torch.save(model.state_dict(),
                        os.path.join(Dec_weight_Dst, '{:d}_{:.4f}_model.pth'.format(epoch, epoch_loss)))
          torch.save(model.state_dict(), os.path.join(Dec_weight_Dst, 'end_model.pth'))

      else:
          model.eval()   # Set model to evaluate mode
          det_file = os.path.join(cacheDir, 'detections.pkl')
          all_boxes = [[[] for _ in range(len(dsets['val']))] for _ in range(num_classes)]
          for img_idx in range(len(dsets['val'])):
              ori_img = dsets['val'].load_img(img_idx)
              h,w,c = ori_img.shape
              inputs, gt_bboxes, gt_labels = dsets['val'].__getitem__(img_idx)  # [3, 512, 640], [3, 4], [3, 1]
              # run model
              inputs = inputs.unsqueeze(0).to(device)
              with torch.no_grad():
                  locs, conf = model(inputs)
              detections = detector(locs, conf, anchors)
              for cls_idx in range(1, detections.size(1)):
                  dets = detections[0, cls_idx, :]
                  mask = dets[:, 0].gt(0.).expand(5, dets.size(0)).t()
                  dets = torch.masked_select(dets, mask).view(-1, 5)
                  if dets.shape[0] == 0:
                      continue
                  pred_boxes = dets[:, 1:].cpu().numpy()
                  pred_score = dets[:, 0].cpu().numpy()
                  pred_boxes[:,0] /= img_height
                  pred_boxes[:,1] /= img_width
                  pred_boxes[:,2] /= img_height
                  pred_boxes[:,3] /= img_width
                  pred_boxes[:,0] *= h
                  pred_boxes[:,1] *= w
                  pred_boxes[:,2] *= h
                  pred_boxes[:,3] *= w
                  cls_dets = np.hstack((pred_boxes, pred_score[:, np.newaxis])).astype(np.float32, copy=False)
                  all_boxes[cls_idx][img_idx] = cls_dets

          with open(det_file, 'wb') as f:
              pickle.dump(all_boxes, f, pickle.HIGHEST_PROTOCOL)

          for cls_ind, cls in enumerate(dsets['val'].labelmap):
              filename = dec_eval.get_voc_results_file_template('test', cls, cacheDir)
              with open(filename, 'wt') as f:
                  for im_ind, index in enumerate(dsets['val'].img_files):
                      dets = all_boxes[cls_ind+1][im_ind]
                      if len(dets) == 0:  # no detections for this image (works for [] and empty arrays)
                          continue
                      for k in range(dets.shape[0]):
                          # format: [img_file  confidence, y1, x1, y2, x2] save to call for multiple times
                          f.write('{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f}\n'.format(index,
                                                                                    dets[k, -1],
                                                                                    dets[k, 0],
                                                                                    dets[k, 1],
                                                                                    dets[k, 2],
                                                                                    dets[k, 3]))

          #ap05, ap07 = dec_eval.do_python_eval(dsets=dsets['val'],
          #                                    output_dir=cacheDir,
          #                                    offline=False,
          #                                    use_07=True)
          #print('ap05:{:.4f}, ap07:{:.4f}'.format(ap05, ap07))
          #ap05_dict.append(ap05)
          #np.savetxt(Dec_log_Files + '/dec_ap_05.txt', ap05_dict, fmt='%.6f')
          #ap07_dict.append(ap07)
          #np.savetxt(Dec_log_Files + '/dec_ap_07.txt', ap07_dict, fmt='%.6f')
print('Finish')
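During validation the detector returns boxes in network-input pixels (`img_height` × `img_width`), and the loop above rescales them to the original image size `h` × `w` before writing one line per detection to the results file. The same arithmetic and line format as a standalone sketch (the function names are hypothetical):

```python
import numpy as np


def rescale_boxes(boxes, net_hw, orig_hw):
    """Map [y1, x1, y2, x2] boxes from network-input pixels to original-image pixels."""
    net_h, net_w = net_hw
    h, w = orig_hw
    # y coordinates scale with height, x coordinates with width
    scale = np.array([h / net_h, w / net_w, h / net_h, w / net_w])
    return np.asarray(boxes, dtype=np.float64) * scale


def format_det_line(img_id, score, box):
    """One results-file line: '<img_id> <score> <y1> <x1> <y2> <x2>'."""
    return '{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f}'.format(img_id, score, *box)
```

For a 512×512 network input and a 1024×2048 original image, the box `[256, 128, 512, 512]` maps to `[512, 512, 1024, 2048]`.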

In [None]:
#@markdown ___
#@markdown ## **2. Test Detection Network**

#@markdown Test Region Proposal (Detection) Network using pre-trained weights

#@markdown ### **A. Select Detection Network Weights for Testing**
DECWEIGHT = 'Tissue' #@param ['Tissue', 'Cell', 'Combine', 'Kaggle', 'Select', 'Drive'] {type: 'string'}
if DECWEIGHT == 'Tissue':
  weightTst = "/content/drive/MyDrive/Colab Notebooks/weights/ANCIS/DecWeights/ANCIS_DecWeight_Tissue/end_model.pth"
elif DECWEIGHT == 'Cell':
  weightTst = "/content/drive/MyDrive/Colab Notebooks/weights/ANCIS/DecWeights/ANCIS_DecWeight_Cell/end_model.pth"
elif DECWEIGHT == 'Combine':
  weightTst = "/content/drive/MyDrive/Colab Notebooks/weights/ANCIS/DecWeights/ANCIS_DecWeight_Combine/end_model.pth"
elif DECWEIGHT == 'Kaggle':
  weightTst = "/content/drive/MyDrive/Colab Notebooks/weights/ANCIS/DecWeights/ANCIS_DecWeight_Kaggle/end_model.pth"
elif DECWEIGHT == 'Select':
  weights = files.upload()  # returns {filename: file bytes}
  for name, data in weights.items():
    with open(name, 'wb') as f:
      f.write(data)
      print('saved file', name)
    weightTst = os.path.join(os.getcwd(), name)
elif DECWEIGHT == 'Drive':
  #@markdown -- If selecting detection weights stored in the cloud or Drive, provide path
  dec_weights = "/content/drive/MyDrive/Colab Notebooks/ANCIS/Dec_Weights/end_model.pth" #@param {type: "string"}
  weightTst = dec_weights

#@markdown ### **B. Detection Test Result Save Path**
SAVE_PATH = "/root/Save/RPN" #@param {type: "string"}
if not os.path.exists(SAVE_PATH):
  os.makedirs(SAVE_PATH)

%cd /content/drive/MyDrive/Colab Notebooks/ANCIS/

import argparse

from dec_utils import *
from models import dec_net
from dec_utils import dec_transforms
import cv2
import matplotlib
from matplotlib.pyplot import imshow
from dec_utils import dec_transforms, dec_eval, dec_dataset_kaggle
from google.colab.patches import cv2_imshow

def load_dec_weights(dec_model, dec_weights):
    print('Resuming detection weights from {} ...'.format(dec_weights))
    dec_dict = torch.load(dec_weights, map_location=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
    dec_dict_update = {}
    for k in dec_dict:
        if k.startswith('module') and not k.startswith('module_list'):
            dec_dict_update[k[7:]] = dec_dict[k]
        else:
            dec_dict_update[k] = dec_dict[k]
    dec_model.load_state_dict(dec_dict_update, strict=True)
    return dec_model

data_transforms = dec_transforms.Compose([dec_transforms.ConvertImgFloat(),
                                    dec_transforms.Resize(M, M),
                                    dec_transforms.ToTensor()])

dsets = dec_dataset_kaggle.NucleiCell(testDir, annoDir, data_transforms,
                    imgSuffix=imgSuffix, annoSuffix=annoSuffix)

model = dec_net.resnetssd50(pretrained=True, num_classes=num_classes)
model = load_dec_weights(model, weightTst)


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()
detector = Detect(num_classes=num_classes,
                  top_k=top_k,
                  conf_thresh=conf_thresh,
                  nms_thresh=nms_thresh,
                  variance=[0.1, 0.2])
anchorGen = Anchors(M, M)
anchors = anchorGen.forward()
#cv2.namedWindow('img')
names = dsets.load_img_ids()
print(names)
for img_idx in range(len(dsets)):
    ori_img = dsets.load_img(img_idx)
    h,w,c = ori_img.shape
    inputs, gt_bboxes, _ = dsets.__getitem__(img_idx)  # [3, 512, 640], [3, 4], [3, 1]
    inputs = inputs.unsqueeze(0).to(device)
    ames = names[img_idx]
    ame = (os.path.basename(ames))
    print(ame)
    with torch.no_grad():
        locs, conf = model(inputs)
    detections = detector(locs, conf, anchors)
    for cls_idx in range(1, detections.size(1)):
        dets = detections[0, cls_idx, :]
        mask = dets[:, 0].gt(0.).expand(5, dets.size(0)).t()
        dets = torch.masked_select(dets, mask).view(-1, 5)
        if dets.shape[0] == 0:
            continue
        dets = dets.cpu().numpy()
        for i in range(dets.shape[0]):
            box = dets[i,1:]
            score = dets[i,0]
            y1,x1,y2,x2 = box
            y1 = float(y1)/M
            x1 = float(x1)/M
            y2 = float(y2)/M
            x2 = float(x2)/M
            y1 = int(float(y1)*h)
            x1 = int(float(x1)*w)
            y2 = int(float(y2)*h)
            x2 = int(float(x2)*w)
            cv2.rectangle(ori_img, (x1, y1), (x2, y2), (0, 255, 0), 2, 2)
            cv2.putText(ori_img, "%.2f"%score, (x1, y1 + 20), cv2.FONT_HERSHEY_SIMPLEX, 0.6,
                        (255, 0, 255))
    #imshow(ori_img)
    cv2.imwrite(os.path.join(SAVE_PATH, ame), ori_img)
    cv2_imshow(ori_img)
    #k = cv2.waitKey(0)
    #if k & 0xFF == ord('q'):
       # cv2.destroyAllWindows()
      #  exit()
cv2.destroyAllWindows()

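The detection test cell maps each predicted box from the resized `M` x `M` network input back to the original image by dividing each coordinate by `M` and multiplying by the original height `h` or width `w`. A minimal vectorised sketch of that conversion, assuming `(y1, x1, y2, x2)` boxes in network-input pixels as in the loop above; `boxes_to_original` is a hypothetical helper, not part of ANCIS-Pytorch:

```python
import numpy as np

def boxes_to_original(boxes, net_size, orig_h, orig_w):
    """Rescale (y1, x1, y2, x2) boxes from a square net_size x net_size input
    to integer pixel coordinates in an orig_h x orig_w image."""
    boxes = np.asarray(boxes, dtype=np.float64) / net_size  # normalise to [0, 1]
    scale = np.array([orig_h, orig_w, orig_h, orig_w], dtype=np.float64)
    # astype(int) truncates toward zero, like int() in the notebook loop
    return (boxes * scale).astype(int)

# Centre quarter of a 512 x 512 input, mapped back to a 1024 (h) x 768 (w) image
print(boxes_to_original([[128, 128, 384, 384]], net_size=512, orig_h=1024, orig_w=768))
# -> [[256 192 768 576]]
```

Rows use the height and columns use the width, so non-square originals are handled correctly even though the network input is square.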
# **IV. Instance Segmentation**
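Each `load_dec_weights` helper in this notebook strips a leading `module.` from checkpoint keys. `torch.nn.DataParallel` adds that prefix when a wrapped (multi-GPU) model is saved, and removing it lets the weights load into a plain single-device model with `strict=True`. A minimal sketch of the same idea on a toy `nn.Linear`, assuming a plain `state_dict` checkpoint; `strip_dataparallel_prefix` is a hypothetical helper. It tests for the full `'module.'` prefix, which also leaves keys such as `module_list.0.weight` untouched:

```python
from collections import OrderedDict

import torch
import torch.nn as nn

def strip_dataparallel_prefix(state_dict):
    """Return a copy of state_dict with a leading 'module.' removed from each key."""
    prefix = 'module.'
    return OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )

# Simulate a checkpoint saved from a DataParallel-wrapped model
model = nn.Linear(4, 2)
wrapped_state = nn.DataParallel(model).state_dict()
print(list(wrapped_state.keys()))  # ['module.weight', 'module.bias']

# The stripped keys match a plain model, so strict loading succeeds
fresh = nn.Linear(4, 2)
fresh.load_state_dict(strip_dataparallel_prefix(wrapped_state), strict=True)
```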

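The segmentation training cell below decays the learning rate with `lr_scheduler.ExponentialLR(optimizer, gamma=0.98)`, so after `e` scheduler steps the rate is `init_lr * 0.98**e`. A minimal sketch comparing that closed form with PyTorch's scheduler on a dummy parameter; `init_lr = 1e-4` is an arbitrary illustrative value, not the notebook's setting:

```python
import torch
from torch import optim
from torch.optim import lr_scheduler

init_lr = 1e-4  # illustrative value only, not the notebook's setting
gamma = 0.98    # decay factor used by the segmentation training cell

param = torch.nn.Parameter(torch.zeros(1))
optimizer = optim.Adam([param], lr=init_lr)
scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=gamma)

lrs = []
for epoch in range(100):
    lrs.append(optimizer.param_groups[0]['lr'])  # rate used during this epoch
    optimizer.step()   # a real epoch would run its batches here
    scheduler.step()   # decay once per epoch, after the optimizer updates

# Closed form: lr at epoch e is init_lr * gamma**e
print(round(lrs[50] / init_lr, 3))  # 0.364
```

With `gamma=0.98` the rate falls to about a third of `init_lr` after 50 epochs and to roughly 13% after 100 (0.98**100 is about 0.133), which is worth keeping in mind when choosing `num_epochs`.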
In [None]:
%cd /content/drive/My Drive/Colab Notebooks/ANCIS
#@markdown ___
#@markdown ## **1. Train Segmentation Network**

#@markdown Train the instance segmentation network. You must first either train the detection (region proposal) network or upload pretrained detection weights.
#@markdown

#@markdown ### **A. Select Detection Network Weights to use for Segmentation Training**
DECWEIGHT = 'Tissue' #@param ['Tissue', 'Cell', 'Combine', 'Kaggle', 'Select', 'Drive'] {type: 'string'}
if DECWEIGHT == 'Tissue':
  dec_weights = "/content/drive/MyDrive/Colab Notebooks/weights/ANCIS/DecWeights/ANCIS_DecWeight_Tissue/end_model.pth"
elif DECWEIGHT == 'Cell':
  dec_weights = "/content/drive/MyDrive/Colab Notebooks/weights/ANCIS/DecWeights/ANCIS_DecWeight_Cell/end_model.pth"
elif DECWEIGHT == 'Combine':
  dec_weights = "/content/drive/MyDrive/Colab Notebooks/weights/ANCIS/DecWeights/ANCIS_DecWeight_Combine/end_model.pth"
elif DECWEIGHT == 'Kaggle':
  dec_weights = "/content/drive/MyDrive/Colab Notebooks/weights/ANCIS/DecWeights/ANCIS_DecWeight_Kaggle/end_model.pth"
elif DECWEIGHT == 'Select':
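  weights = files.upload()  # Colab saves each uploaded file to the working directory
  for name, data in weights.items():
    # Write the bytes explicitly: reopening with 'wb' truncates the file,
    # so opening it without writing would leave an empty weight file
    with open(name, 'wb') as f:
      f.write(data)
    print('saved file', name)
    dec_weights = os.path.join(os.getcwd(), name)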
elif DECWEIGHT == 'Drive':
  #@markdown -- If selecting detection weights stored in the cloud or Drive, provide path
  dec_weights = "/content/drive/MyDrive/Colab Notebooks/ANCIS/Dec_Weights/end_model.pth" #@param {type: "string"}


import argparse
import torch.optim as optim
from torch.optim import lr_scheduler

from seg_utils import *
from dec_utils import *
from seg_utils import seg_transforms, seg_dataset_kaggle, seg_eval_kaggle

from models import dec_net_seg, seg_net
import cv2
import os
parser = argparse.ArgumentParser(description='Segmentation Training')
parser.add_argument('--img_height', default=img_height, type=int, help='train image height')
parser.add_argument('--img_width', default=img_width, type=int, help='train image width')
parser.add_argument('--conf_thresh', default=conf_thresh, type=float, help='Detection confidence threshold')
parser.add_argument('--seg_thresh', default=seg_thresh, type=float, help='Segmentation mask threshold')
args = parser.parse_args("")
def collater(data):
    imgs = []
    bboxes = []
    labels = []
    masks = []
    for sample in data:
        imgs.append(sample[0])
        bboxes.append(sample[1])
        labels.append(sample[2])
        masks.append(sample[3])
    return torch.stack(imgs,0), bboxes, labels, masks

def load_dec_weights(dec_model, dec_weights):
    print('Resuming detection weights from {} ...'.format(dec_weights))
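    dec_dict = torch.load(dec_weights, map_location=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
    dec_dict_update = {}
    for k in dec_dict:
        # Strip the 'module.' prefix that nn.DataParallel adds to checkpoint keys
        if k.startswith('module') and not k.startswith('module_list'):
            dec_dict_update[k[7:]] = dec_dict[k]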
        else:
            dec_dict_update[k] = dec_dict[k]
    dec_model.load_state_dict(dec_dict_update, strict=True)
    return dec_model

# ................. Training Code .................

#-----------------load detection model -------------------------
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dec_model = dec_net_seg.resnetssd50(pretrained=False, num_classes=num_classes)
dec_model = load_dec_weights(dec_model, dec_weights)
dec_model = dec_model.to(device)
#-------------------------------------------------------------------
dec_model.eval()        # detector set to 'evaluation' mode
for param in dec_model.parameters():
    param.requires_grad = False
#-----------------load segmentation model -------------------------
seg_model =  seg_net.SEG_NET(num_classes=num_classes)
seg_model= seg_model.to(device)
##--------------------------------------------------------------
data_transforms = {
    'train': seg_transforms.Compose([seg_transforms.ConvertImgFloat(),
                                      seg_transforms.PhotometricDistort(),
                                      seg_transforms.Expand(),
                                      seg_transforms.RandomSampleCrop(),
                                      seg_transforms.RandomMirror_w(),
                                      seg_transforms.RandomMirror_h(),
                                      seg_transforms.Resize(img_height, img_width),
                                      seg_transforms.ToTensor()]),

    'val': seg_transforms.Compose([seg_transforms.ConvertImgFloat(),
                                    seg_transforms.Resize(img_height, img_width),
                                    seg_transforms.ToTensor()])
}


dsets = {'train': seg_dataset_kaggle.NucleiCell(trainDir, annoDir, data_transforms['train'],
                              imgSuffix=imgSuffix, annoSuffix=annoSuffix),
          'val': seg_dataset_kaggle.NucleiCell(valDir, annoDir, data_transforms['val'],
                              imgSuffix=imgSuffix, annoSuffix=annoSuffix)}

dataloader = torch.utils.data.DataLoader(dsets['train'],
                                          batch_size = batch_size,
                                          shuffle = True,
                                          num_workers = num_workers,
                                          collate_fn = collater,
                                          pin_memory = True)



optimizer = optim.Adam(params=filter(lambda p: p.requires_grad, seg_model.parameters()), lr=init_lr)
scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.98, last_epoch=-1)
criterion = SEG_loss(height=img_height, width=img_width)


if vis:
    # Preview augmented training samples (boxes + masks) inline.
    # cv2.imshow/cv2.waitKey are disabled in Colab, so cv2_imshow is used instead.
    from google.colab.patches import cv2_imshow
    for idx in range(len(dsets['train'])):
        img, bboxes, labels, masks = dsets['train'].__getitem__(idx)
        img = img.numpy().transpose(1, 2, 0).copy()*255
        print(img.shape)
        bboxes = bboxes.numpy()
        labels = labels.numpy()
        masks = masks.numpy()
        for obj in range(bboxes.shape[0]):  # separate name so the outer idx is not shadowed
            y1, x1, y2, x2 = bboxes[obj, :]
            y1 = int(y1)
            x1 = int(x1)
            y2 = int(y2)
            x2 = int(x2)
            cv2.rectangle(img, (x1, y1), (x2, y2), (255, 255, 255), 2, lineType=1)
            mask = masks[obj, :, :]
            img = map_mask_to_image(mask, img, color=np.random.rand(3))
        cv2_imshow(img)

# for validation data -----------------------------------
detector = Detect(num_classes=num_classes,
                  top_k=top_k,
                  conf_thresh=conf_thresh,
                  nms_thresh=nms_thresh,
                  variance=[0.1, 0.2])
anchorGen = Anchors(img_height, img_width)
anchors = anchorGen.forward()
# --------------------------------------------------------
train_loss_dict = []
ap05_dict = []
ap07_dict = []
loss2 = torch.tensor(0.5)  # fallback loss, reported if the criterion returns None for a batch
for epoch in range(num_epochs):
    print('Epoch {}/{}'.format(epoch, num_epochs - 1))
    print('-' * 10)

    for phase in ['train', 'val']:
        if phase == 'train':
            seg_model.train()
            running_loss = 0.0
            for inputs, bboxes, labels, masks in dataloader:
                inputs = inputs.to(device)
                # The detector is frozen, so run it without tracking gradients
                with torch.no_grad():
                    locs, conf, feat_seg = dec_model(inputs)
                    detections = detector(locs, conf, anchors)

                optimizer.zero_grad()
                with torch.enable_grad():
                    outputs = seg_model(detections, feat_seg)
                    loss = criterion(outputs, bboxes, labels, masks)
                    if loss is None:
                        # No usable loss for this batch: skip the update and
                        # report the last valid loss instead
                        loss = loss2
                    else:
                        loss2 = loss.detach()
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)

            # Decay the learning rate once per epoch, after the optimizer updates
            # (PyTorch >= 1.1 expects scheduler.step() after optimizer.step())
            scheduler.step()

            epoch_loss = running_loss / len(dsets[phase])

            print('{} Loss: {:.4f}'.format(phase, epoch_loss))
            train_loss_dict.append(epoch_loss)
            np.savetxt(Seg_log_Files + '/seg_train_loss.txt', train_loss_dict, fmt='%.6f')
            if epoch % 50 == 0:
                torch.save(seg_model.state_dict(),
                            os.path.join(Seg_weight_Dst, '{:d}_{:.4f}_model.pth'.format(epoch, epoch_loss)))
            torch.save(seg_model.state_dict(), os.path.join(Seg_weight_Dst, 'end_model.pth'))

        else:
            if epoch % 50 == 0:
                seg_model.eval()   # Set model to evaluate mode
                ap_05, ap_07 = seg_eval_kaggle.do_python_eval(dsets=dsets[phase], dec_model=dec_model, seg_model=seg_model,
                                                        detector=detector, anchors=anchors, device=device,
                                                        args=args, offline=False)
                # print('ap05:{:.4f}, ap07:{:.4f}'.format(ap05, ap07))
                ap05_dict.append(ap_05)
                np.savetxt(Seg_log_Files + '/seg_ap_05.txt', ap05_dict, fmt='%.6f')
                ap07_dict.append(ap_07)
                np.savetxt(Seg_log_Files + '/seg_ap_07.txt', ap07_dict, fmt='%.6f')

print('Finish')

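The training loop above logs `ap_05` and `ap_07` from `seg_eval_kaggle.do_python_eval`. The names suggest average precision at IoU thresholds of 0.5 and 0.7; that is an assumption, so check `seg_eval_kaggle` for the exact matching rule. The quantity such thresholds apply to, mask intersection over union, can be sketched as follows; `mask_iou` is a hypothetical helper, not part of the repository:

```python
import numpy as np

def mask_iou(pred, gt):
    """Intersection over union of two binary masks with the same shape."""
    pred = np.asarray(pred, dtype=bool)
    gt = np.asarray(gt, dtype=bool)
    union = np.logical_or(pred, gt).sum()
    if union == 0:
        return 0.0  # both masks empty: treat as no match
    return np.logical_and(pred, gt).sum() / union

pred = np.zeros((10, 10), dtype=np.uint8)
gt = np.zeros((10, 10), dtype=np.uint8)
pred[2:8, 2:8] = 1  # 6 x 6 = 36 pixels
gt[3:9, 3:9] = 1    # same size, shifted by one pixel: 5 x 5 = 25 pixels overlap
iou = mask_iou(pred, gt)
print(round(iou, 3))           # 25 / (36 + 36 - 25) = 25 / 47 -> 0.532
print(iou >= 0.5, iou >= 0.7)  # True False: a match at IoU 0.5 but not at 0.7
```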
In [None]:
%cd /content/drive/MyDrive/Colab Notebooks/ANCIS/

#@markdown ___
#@markdown ## **2. Test Segmentation Network**

#@markdown Test the Instance Segmentation Network.  
#@markdown
#@markdown ### **A. Select Segmentation Weights**
SEGWEIGHT = 'Tissue' #@param ['Tissue', 'Cell', 'Combine', 'Kaggle', 'Select', 'Drive'] {type: 'string'}
if SEGWEIGHT == 'Tissue':
  seg_weights = "/content/drive/MyDrive/Colab Notebooks/weights/ANCIS/SegWeights/ANCIS_SegWeight_Tissue/end_model.pth"
elif SEGWEIGHT == 'Cell':
  seg_weights = "/content/drive/MyDrive/Colab Notebooks/weights/ANCIS/SegWeights/ANCIS_SegWeight_Cell/end_model.pth"
elif SEGWEIGHT == 'Combine':
  seg_weights = "/content/drive/MyDrive/Colab Notebooks/weights/ANCIS/SegWeights/ANCIS_SegWeight_Combine/end_model.pth"
elif SEGWEIGHT == 'Kaggle':
  seg_weights = "/content/drive/MyDrive/Colab Notebooks/weights/ANCIS/SegWeights/ANCIS_SegWeight_Kaggle/end_model.pth"
elif SEGWEIGHT == 'Select':
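  weights = files.upload()  # Colab saves each uploaded file to the working directory
  for name, data in weights.items():
    # Write the bytes explicitly: reopening with 'wb' truncates the file,
    # so opening it without writing would leave an empty weight file
    with open(name, 'wb') as f:
      f.write(data)
    print('saved file', name)
    seg_weights = os.path.join(os.getcwd(), name)
elif SEGWEIGHT == 'Drive':
  #@markdown --- If selecting segmentation weights stored in the cloud or Drive, provide the path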
  seg_weights = "/content/drive/MyDrive/Colab Notebooks/ANCIS/Seg_Weights/end_model.pth" #@param {type: "string"}

#@markdown ### **B. Select Detection Weights**
DECWEIGHT = 'Tissue' #@param ['Tissue', 'Cell', 'Combine', 'Kaggle', 'Select', 'Drive'] {type: 'string'}
if DECWEIGHT == 'Tissue':
  dec_weights = "/content/drive/MyDrive/Colab Notebooks/weights/ANCIS/DecWeights/ANCIS_DecWeight_Tissue/end_model.pth"
elif DECWEIGHT == 'Cell':
  dec_weights = "/content/drive/MyDrive/Colab Notebooks/weights/ANCIS/DecWeights/ANCIS_DecWeight_Cell/end_model.pth"
elif DECWEIGHT == 'Combine':
  dec_weights = "/content/drive/MyDrive/Colab Notebooks/weights/ANCIS/DecWeights/ANCIS_DecWeight_Combine/end_model.pth"
elif DECWEIGHT == 'Kaggle':
  dec_weights = "/content/drive/MyDrive/Colab Notebooks/weights/ANCIS/DecWeights/ANCIS_DecWeight_Kaggle/end_model.pth"
elif DECWEIGHT == 'Select':
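  weights = files.upload()  # Colab saves each uploaded file to the working directory
  for name, data in weights.items():
    # Write the bytes explicitly: reopening with 'wb' truncates the file,
    # so opening it without writing would leave an empty weight file
    with open(name, 'wb') as f:
      f.write(data)
    print('saved file', name)
    dec_weights = os.path.join(os.getcwd(), name)
elif DECWEIGHT == 'Drive':
  #@markdown --- If selecting detection weights stored in the cloud or Drive, provide the path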
  dec_weights = "/content/drive/MyDrive/Colab Notebooks/ANCIS/Dec_Weights/end_model.pth" #@param {type: "string"}

#@markdown ### **C. Segmented Image Save Path**
SAVE_PATH = "/root/Save/Segment" #@param {type: "string"}
if not os.path.exists(SAVE_PATH):
  os.makedirs(SAVE_PATH)

#@markdown ### **D. Conduct Additional Segmentation Post-Processing?**
#@markdown *Overlap Removal (experimental)*
PROC = True #@param {type: 'boolean'}

#@markdown ### **E. Overlay Post-processed Images for Display (takes longer to set up)**
OVER = True #@param {type: 'boolean'}

if OVER:
  !pip install colorspacious
  !git clone --quiet https://github.com/taketwo/glasbey.git
  import colorspacious
  from glasbey import Glasbey
  from skimage.color import label2rgb

  color = np.array(([1,0,0],[0,1,0],[0,0,1],[1,1,0],[0,1,1],[1,0,1],[1,0.5,0],[0.5,1,0],[0,1,0.5],[0,0.5,1],[1,0,0.5],[0.5,0,1],[1,0.5,0.25],[0.25,0.5,1],[1,0.25,0.5],[0.5,0.25,1],[0.5,1,0.25],[0.25,1,0.5]),np.float32)
  gb = Glasbey(base_palette=color, chroma_range = (60,100), no_black=True)
  c4 = gb.generate_palette(size=18)
  #c4 = gb.load_palette('/content/drive/MyDrive/Colab Notebooks/ANCIS/glasbey/rgb_cam02ucs_lut.npz')
  color4 = c4[1:]

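  def normalized(rgb):
    # Chromaticity normalisation: divide each channel by the per-pixel channel sum
    # and rescale to 0-255. Uses the input's own size (not a fixed 512x512) and
    # leaves background pixels, whose channel sum is 0, at 0 instead of dividing by zero.
    rgb = rgb.astype(np.float32)
    total = rgb.sum(axis=2, keepdims=True)
    norm = np.divide(rgb, total, out=np.zeros_like(rgb), where=total > 0) * 255.0
    norm_rgb = cv2.convertScaleAbs(norm)
    return norm_rgb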

  def overlay(mask, orig, clr):
    maskPR = label(mask)
    labels = label2rgb(label=maskPR, bg_label=0, bg_color=(0, 0, 0), colors=clr)
    L2 = normalized(labels)
    if len(orig.shape) < 3: 
      O2 = cv2.cvtColor(orig.astype('uint8'), cv2.COLOR_GRAY2BGR)
    else:
      O2 = orig
    comb = cv2.addWeighted(L2.astype('float64'),0.5,O2.astype('float64'),0.5,0)
    return comb


import argparse
import math  # used by the overlap-splitting post-processing below
import torch.optim as optim
from torch.optim import lr_scheduler

from seg_utils import *
from dec_utils import *
from seg_utils import seg_transforms, seg_dataset_kaggle, seg_eval

from models import dec_net_seg, seg_net
import cv2
import os

from google.colab.patches import cv2_imshow

def load_dec_weights(dec_model, dec_weights):
    print('Resuming detection weights from {} ...'.format(dec_weights))
    dec_dict = torch.load(dec_weights, map_location=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
    dec_dict_update = {}
    for k in dec_dict:
        if k.startswith('module') and not k.startswith('module_list'):
            dec_dict_update[k[7:]] = dec_dict[k]
        else:
            dec_dict_update[k] = dec_dict[k]
    dec_model.load_state_dict(dec_dict_update, strict=True)
    return dec_model

#-----------------load detection model -------------------------
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dec_model = dec_net_seg.resnetssd50(pretrained=False, num_classes=num_classes)
dec_model = load_dec_weights(dec_model, dec_weights)
dec_model = dec_model.to(device)
dec_model.eval()
#-----------------load segmentation model -------------------------
seg_model =  seg_net.SEG_NET(num_classes=num_classes)
seg_model.load_state_dict(torch.load(seg_weights, map_location=torch.device('cpu')))
seg_model= seg_model.to(device)
seg_model.eval()
##--------------------------------------------------------------
data_transforms = seg_transforms.Compose([seg_transforms.ConvertImgFloat(),
                                    seg_transforms.Resize(img_height, img_width),
                                    seg_transforms.ToTensor()])


dsets = seg_dataset_kaggle.NucleiCell(testDir, annoDir, data_transforms,
                                      imgSuffix=imgSuffix, annoSuffix=annoSuffix)

# for validation data -----------------------------------
names = dsets.load_img_ids()
detector = Detect(num_classes=num_classes,
                  top_k=top_k,
                  conf_thresh=conf_thresh,
                  nms_thresh=nms_thresh,
                  variance=[0.1, 0.2])
anchorGen = Anchors(img_height, img_width)
anchors = anchorGen.forward()
if not os.path.exists(SAVE_PATH + '/masks'):
    os.makedirs(os.path.join(SAVE_PATH,'masks'))
if not os.path.exists(SAVE_PATH + '/seg'):
    os.makedirs(os.path.join(SAVE_PATH,'seg'))
# for img_idx in [1,55,57,72,78,123]:
for img_idx in range(len(dsets)):
  print('loading {}/{} image'.format(img_idx + 1, len(dsets)))
  inputs, gt_boxes, gt_classes, gt_masks = dsets.__getitem__(img_idx)
  ori_img = dsets.load_img(img_idx)
  image = ori_img.copy()
  #ori_img_copy = ori_img.copy()
  #bboxes, labels, masks = dsets.load_annotation(dsets.img_files[img_idx])
  #for mask in masks:
  #    ori_img = map_mask_to_image(mask, ori_img, color=np.random.rand(3))
  h,w,c = ori_img.shape
  x = inputs.unsqueeze(0)
  x = x.to(device)
  # Inference only: disable gradient tracking to save GPU memory
  with torch.no_grad():
    locs, conf, feat_seg = dec_model(x)
    detections = detector(locs, conf, anchors)
    outputs = seg_model(detections, feat_seg)
  mask_patches, mask_dets = outputs
  ames = names[img_idx]
  ame = (os.path.basename(ames))
  print(ame)
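  # Per-image accumulators, allocated at the original image size (h, w) because each
  # detection mask is resized to (w, h) before it is added:
  #   zees  - instance label map (0 = background, d+1 = d-th detection)
  #   maskD - union of the masks placed so far (a value of 2 marks a new overlap)
  #   diff  - scratch map of pixels where a new mask overlaps an existing one
  zees = np.zeros([h, w], dtype='uint8')
  maskD = np.zeros([h, w], dtype='uint8')
  diff = np.zeros([h, w], dtype='uint8')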
  for b_mask_patches, b_mask_dets in zip(mask_patches, mask_dets):
    nd = len(b_mask_dets)
    # Step1: rearrange mask_patches and mask_dets
    for d in range(nd):
      d_mask = np.zeros((img_height, img_width), dtype=np.float32)
      d_mask_det = b_mask_dets[d].data.cpu().numpy()
      d_mask_patch = b_mask_patches[d].data.cpu().numpy()
      d_bbox = d_mask_det[0:4]
      d_conf = d_mask_det[4]
      d_class = d_mask_det[5]
      if d_conf < conf_thresh:
        continue
      [y1, x1, y2, x2] = d_bbox
      y1 = np.maximum(0, np.int32(np.round(y1)))
      x1 = np.maximum(0, np.int32(np.round(x1)))
      y2 = np.minimum(np.int32(np.round(y2)), img_height - 1)
      x2 = np.minimum(np.int32(np.round(x2)), img_width - 1)
      d_mask_patch = cv2.resize(d_mask_patch, (x2 - x1 + 1, y2 - y1 + 1))
      d_mask_patch = np.where(d_mask_patch >= seg_thresh, 1., 0.)
      d_mask[y1:y2 + 1, x1:x2 + 1] = d_mask_patch
      d_mask = cv2.resize(d_mask, dsize=(w,h), interpolation=cv2.INTER_NEAREST)
      ori_img = map_mask_to_image(d_mask, ori_img, color=np.random.rand(3))
      #zees = (d+1)*d_mask + zees

      # Additional post-processing (PROC): drop small fragments and resolve overlaps
      if not PROC:
        zees = (d+1)*d_mask + zees
      else:
        d_mask = d_mask.astype('uint8')
        #d_mask[d_mask>0] = 1
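        M2 = label(d_mask)
        props2 = regionprops(M2)
        # Noise filter: drop connected fragments smaller than 750 px (hard-coded;
        # may need tuning for other image sizes or magnifications)
        for m in range(0,M2.max()):
          if props2[m].area < 750:
            M2[M2==props2[m].label] = 0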
        M2[M2 > 0] = 1
        d_mask = M2*d_mask
        props2 = regionprops(d_mask)
        maskD = maskD + d_mask
        if maskD.max() <= 1:
          zees = zees + (d+1)*d_mask
        else:
          try:
            diff[maskD > 1] = 1
            diff2 = diff.copy()
            pd = regionprops(diff)

            area2 = props2[0].area 
            aread = pd[0].area
            Vals = diff*zees # Find value of existing region label, under new overlap
            vals = Vals[Vals>0] # Not zero
            vals = vals[vals != d+1] # Not the current label
            vals = list(set(vals)) # Really should only be one left
            z2 = np.zeros(zees.shape, dtype='uint8')
            z2[zees == vals[0]] = 1
            props1 = regionprops(z2)
            area1 = props1[0].area
            div1 = aread/area1  # overlap as a fraction of the existing instance
            div2 = aread/area2  # overlap as a fraction of the new instance
            zees = zees + (d+1)*d_mask

            if div1 < 0.15 and div2 < 0.15:
              zees[diff > 0] = vals[0]
              #zees[zees==d+1] = vals[0]
            elif div1 < 0.15 and div2 > 0.15:
              zees[diff > 0] = d+1
              #zees[zees==vals[0]] = d+1
            elif div1 > 0.15 and div2 < 0.15:
              zees[diff > 0] = vals[0]
              #zees[zees==d+1] = vals[0]
            elif div1 > 0.15 and div2 > 0.15 and div1 < 0.6 and div2 < 0.6:
              y0, x0 = pd[0].centroid
              orientation = pd[0].orientation

              x1 = x0 - math.sin(orientation) * 0.55 * pd[0].major_axis_length
              y1 = y0 - math.cos(orientation) * 0.55 * pd[0].major_axis_length
              x2 = x0 + math.sin(orientation) * 0.55 * pd[0].major_axis_length
              y2 = y0 + math.cos(orientation) * 0.55 * pd[0].major_axis_length 

              cv2.line(diff, (int(x2),int(y2)), (int(x0),int(y0)), (0, 0, 0), thickness=2)
              cv2.line(diff, (int(x1),int(y1)), (int(x0),int(y0)), (0, 0, 0), thickness=2)

              lbl1 = label(diff)
              lbl1 = lbl1.astype('uint8')
              cv2.line(lbl1, (int(x2),int(y2)), (int(x0),int(y0)), (1, 1, 1), thickness=2)
              cv2.line(lbl1, (int(x1),int(y1)), (int(x0),int(y0)), (1, 1, 1), thickness=2)
              lbl2 = lbl1*diff2
              zees[lbl2 == 2] = d+1
              zees[lbl2 == 1] = vals[0]
                                      
            elif div1 > 0.6 or div2 > 0.6:
              if area1 > area2:
                zees[diff > 0] = vals[0]
                zees[zees==d+1] = vals[0]
              elif area2 > area1:
                zees[diff > 0] = d+1
                zees[zees==vals[0]] = d+1
            
          except Exception as e:
            print(e)
            continue

      maskD[maskD > 1] = 1
      diff = np.zeros([zees.shape[0], zees.shape[1]], dtype='uint8')

  zees = label(zees)
  propsz = regionprops(zees)
  try:
    for m in range(0,zees.max()):
      if propsz[m].area < 750:
        zees[zees==propsz[m].label] = 0
  except Exception as e:
    print(e)
  #cv2_imshow(ori_img)
  if OVER and PROC:
    ovr = overlay(zees, image, color4)
    cv2_imshow(ovr)
  #k = cv2.waitKey(0)
  #if k&0xFF==ord('q'):
  #    cv2.destroyAllWindows()
  #    exit()
  #elif k&0xFF==ord('s'):
      # cv2.imwrite('kaggle_imgs/{}_ori.png'.format(img_idx), ori_img_copy)
  cv2.imwrite(SAVE_PATH + '/masks/' + ame, zees)
  if OVER and PROC:
    cv2.imwrite(SAVE_PATH + '/seg/' + ame, ovr)
  else:
    cv2.imwrite(SAVE_PATH + '/seg/' + ame, ori_img)
cv2.destroyAllWindows()
print('Finish')

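When `PROC` is enabled, the segmentation test cell resolves each overlap between a new mask and an existing instance with two ratios: `div1` (overlap area / existing instance area) and `div2` (overlap area / new instance area), compared against 0.15 and 0.6. The branch order can be summarised as a pure function; `resolve_overlap` and its return strings are hypothetical names for illustration. Boundary cases that no branch of the notebook handles (a ratio exactly equal to 0.15 or 0.6, or equal areas in the merge branch) are reported as `'unresolved'`:

```python
def resolve_overlap(overlap_area, existing_area, new_area, low=0.15, high=0.6):
    """Mirror the branch order of the notebook's overlap handling.

    Returns 'existing' (overlap goes to the existing instance), 'new' (overlap
    goes to the new instance), 'split' (overlap is cut along its major axis),
    'merge_existing' / 'merge_new' (both masks become the larger instance),
    or 'unresolved' for the boundary cases the notebook does not handle.
    """
    div1 = overlap_area / existing_area  # share of the existing instance
    div2 = overlap_area / new_area       # share of the new instance
    if div1 < low and div2 < low:
        return 'existing'
    if div1 < low and div2 > low:
        return 'new'
    if div1 > low and div2 < low:
        return 'existing'
    if low < div1 < high and low < div2 < high:
        return 'split'
    if div1 > high or div2 > high:
        if existing_area > new_area:
            return 'merge_existing'
        if new_area > existing_area:
            return 'merge_new'
    return 'unresolved'

print(resolve_overlap(10, 1000, 1000))   # existing: 1% of each instance
print(resolve_overlap(50, 1000, 200))    # new: 5% of existing, 25% of new
print(resolve_overlap(300, 1000, 1000))  # split: 30% of each
print(resolve_overlap(700, 1000, 900))   # merge_existing: existing instance is larger
```

Small overlaps are handed to whichever instance they barely touch, moderate ones are split, and large ones merge the two masks, which suits touching nuclei that the detector sometimes reports twice.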
In [None]:
#@markdown ## **3. Zip and Download Predictions to Local Drive**
#@markdown If the download does not start, check whether your browser is blocking it.

import shutil

output_filename = 'Results' #@param {type: 'string'}
dir_name = SAVE_PATH
shutil.make_archive(output_filename, 'zip', dir_name, verbose=1)

files.download(output_filename + '.zip')
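Both the per-mask clean-up and the final pass over `zees` in the segmentation test cell discard connected regions smaller than 750 pixels; this is the notebook's noise filter, and the hard-coded threshold may need tuning for other image sizes or magnifications. A minimal sketch of the same filter with `skimage.measure.label` and `regionprops`; `remove_small_regions` and its `min_area` argument are hypothetical names:

```python
import numpy as np
from skimage.measure import label, regionprops

def remove_small_regions(label_map, min_area=750):
    """Zero out connected regions smaller than min_area pixels.

    label() treats neighbouring pixels with different non-zero values as
    separate regions, so touching instances stay distinct.
    """
    labels = label(np.asarray(label_map))
    for region in regionprops(labels):
        if region.area < min_area:
            labels[labels == region.label] = 0
    return labels

demo = np.zeros((100, 100), dtype=np.uint8)
demo[0:40, 0:40] = 1    # 1600 px: kept
demo[60:70, 60:70] = 2  # 100 px: removed as noise
demo[0:40, 40:60] = 3   # 800 px, touches region 1 with a different value: kept
out = remove_small_regions(demo)
print(sorted(int(r.area) for r in regionprops(out)))  # [800, 1600]
```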