In [1]:
# Install a pip package in the current Jupyter kernel
import sys
!{sys.executable} -m pip install seaborn

Defaulting to user installation because normal site-packages is not writeable


In [2]:
import os
import numpy as np
from tqdm import tqdm
from skimage.io import imread, imshow, imread_collection, concatenate_images
import torch
import matplotlib.pyplot as plt
from PIL import Image
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNN
import torchvision.transforms as T
from torchvision.transforms import functional as F
from skimage import draw,io,segmentation
import time
import sys
import torch.optim as optim
import helper_functions as myutils
from imgaug.augmentables.segmaps import SegmentationMapsOnImage
import imgaug as ia
from imgaug import augmenters as iaa
import seaborn as sns
import pandas as pd

In [3]:
# Runtime configuration: compute device, random seeds and dataset location.
device = myutils.set_to_gpu()  # project helper returning the torch device used by `.to(device)` below

# Seeds for the project helper (presumably torch/numpy/random — see helper_functions)
# and for imgaug, which drives the training augmentations.
TORCH_SEED = 4
IMGAUG_SEED = 5
myutils.setup_seed(TORCH_SEED)
ia.seed(IMGAUG_SEED)

myutils.gpu_mem_allocated()

# Root of the DSB-2018 stage-1 training set. Set the DSB_TRAIN_PATH environment
# variable to point elsewhere instead of editing this machine-specific default.
TRAIN_PATH = os.environ.get("DSB_TRAIN_PATH", "/scratch_tmp/snk218/stage1_train")
IMG_CHANNELS = 3

There are 1 GPU(s) available.
We will use the GPU: Tesla T4
Device being used: cuda


In [4]:
# Validation split borrowed from the 3rd-place DSB-2018 entry "DeepRetina"
# (https://github.com/Gelu74/DSB_2018), whose author notes: "The dataset doesn't have a
# standard train/val split, so I picked a variety of images to serve as a validation set."
# Each entry is an image-folder name under TRAIN_PATH. NucleusDataset leaves these IDs out of
# the training split (train=True) and uses exactly these IDs for validation (train=False).
VAL_IMAGE_IDS = """
    0c2550a23b8a0f29a7575de8c61690d3c31bc897dd5ba66caec201d201a278c2
    92f31f591929a30e4309ab75185c96ff4314ce0a7ead2ed2c2171897ad1da0c7
    1e488c42eb1a54a3e8412b1f12cde530f950f238d71078f2ede6a85a02168e1f
    c901794d1a421d52e5734500c0a2a8ca84651fb93b19cec2f411855e70cae339
    8e507d58f4c27cd2a82bee79fe27b069befd62a46fdaed20970a95a2ba819c7b
    60cb718759bff13f81c4055a7679e81326f78b6a193a2d856546097c949b20ff
    da5f98f2b8a64eee735a398de48ed42cd31bf17a6063db46a9e0783ac13cd844
    9ebcfaf2322932d464f15b5662cae4d669b2d785b8299556d73fffcae8365d32
    1b44d22643830cd4f23c9deadb0bd499fb392fb2cd9526d81547d93077d983df
    97126a9791f0c1176e4563ad679a301dac27c59011f579e808bbd6e9f4cd1034
    e81c758e1ca177b0942ecad62cf8d321ffc315376135bcbed3df932a6e5b40c0
    f29fd9c52e04403cd2c7d43b6fe2479292e53b2f61969d25256d2d2aca7c6a81
    0ea221716cf13710214dcd331a61cea48308c3940df1d28cfc7fd817c83714e1
    3ab9cab6212fabd723a2c5a1949c2ded19980398b56e6080978e796f45cbbc90
    ebc18868864ad075548cc1784f4f9a237bb98335f9645ee727dac8332a3e3716
    bb61fc17daf8bdd4e16fdcf50137a8d7762bec486ede9249d92e511fcb693676
    e1bcb583985325d0ef5f3ef52957d0371c96d4af767b13e48102bca9d5351a9b
    947c0d94c8213ac7aaa41c4efc95d854246550298259cf1bb489654d0e969050
    cbca32daaae36a872a11da4eaff65d1068ff3f154eedc9d3fc0c214a4e5d32bd
    f4c4db3df4ff0de90f44b027fc2e28c16bf7e5c75ea75b0a9762bbb7ac86e7a3
    4193474b2f1c72f735b13633b219d9cabdd43c21d9c2bb4dfc4809f104ba4c06
    f73e37957c74f554be132986f38b6f1d75339f636dfe2b681a0cf3f88d2733af
    a4c44fc5f5bf213e2be6091ccaed49d8bf039d78f6fbd9c4d7b7428cfcb2eda4
    cab4875269f44a701c5e58190a1d2f6fcb577ea79d842522dcab20ccb39b7ad2
    8ecdb93582b2d5270457b36651b62776256ade3aaa2d7432ae65c14f07432d49
""".split()

In [5]:
# Training-time geometry (used only by NucleusDataset with train=True).
# T.Resize(MIN_DIM) rescales each image so that its *shorter* side is exactly MIN_DIM
# (smaller images are upscaled, larger ones downscaled). A random CROP_DIM x CROP_DIM
# window is then cut whenever either side is still larger than CROP_DIM.
MIN_DIM = 512   # target length of the shorter image side after resizing
CROP_DIM = 512  # side length of the random square training crop

In [6]:
# Nucleus data set

# pytorch Mask R CNN model expects an Image (PIL) and, for ground truth targets, a dictionary with these keys:
# boxes (FloatTensor[N, 4]): the coordinates of the N bounding boxes in [x0, y0, x1, y1] format, ranging from 0 to W and 0 to H
# labels (Int64Tensor[N]): the label for each bounding box
# image_id (Int64Tensor[1]): an image identifier. It should be unique between all the images in the dataset, and is used during evaluation
# area (Tensor[N]): The area of the bounding box. This is used during evaluation with the COCO metric, to separate the metric scores between small, medium and large boxes.
# iscrowd (UInt8Tensor[N]): instances with iscrowd=True will be ignored during evaluation.
# (optionally) masks (UInt8Tensor[N, H, W]): The segmentation masks for each one of the objects
# One note on the labels. The model considers class 0 as background. If your dataset does not contain the background class, you should not have 0 in your labels.

class NucleusDataset(torch.utils.data.Dataset):
    """DSB-2018 nucleus dataset in the (image, target) format torchvision's Mask R-CNN expects.

    Each sample directory ``<root>/<image_id>/`` contains ``images/<image_id>.png`` and a
    ``masks/`` folder holding one binary PNG per nucleus.

    Args:
        root: dataset root directory (e.g. "stage1_train").
        train: if True, apply ``img_aug_seq`` (when given), resize so the shorter side is
            ``min_dim`` and take a random ``crop_dim`` x ``crop_dim`` crop.
        min_dim: shorter-side length passed to ``T.Resize`` for training.
        crop_dim: side length of the random training crop.
        filter_ids: image IDs held out for validation. With train=True they are excluded;
            with train=False exactly these IDs are used. If empty/None, every image
            directory under ``root`` is used.
        img_aug_seq: optional imgaug augmenter applied jointly to image and masks (train only).

    ``__getitem__`` returns ``(image_tensor, target_dict)``, or ``None`` when no usable
    nucleus mask survives cropping; ``nucleus_collate_fn`` drops those samples.
    """

    def __init__(self, root, train=True, min_dim=MIN_DIM, crop_dim=CROP_DIM, filter_ids=None, img_aug_seq=None):
        self.root = root
        self.train = train
        self.resize_t = T.Resize(min_dim)
        self.crop_dim = crop_dim
        self.aug_seq = img_aug_seq

        if filter_ids and not train:
            # Validation split: exactly the held-out IDs (copied so the caller's list is not aliased).
            self.img_ids = list(filter_ids)
        else:
            # Sort the listing: os.walk order is filesystem-dependent and set() iteration order
            # changes with hash randomisation. Either would change the index -> image mapping
            # (and hence seeded shuffles) from run to run.
            all_ids = sorted(next(os.walk(root))[1])
            held_out = set(filter_ids or ())
            self.img_ids = [img_id for img_id in all_ids if img_id not in held_out]

    def __getitem__(self, idx):
        # A string idx is treated as an image ID (handy interactively); DataLoaders pass
        # integer positions. Either way `idx` ends up as the integer position.
        if isinstance(idx, str):
            image_id = idx
            idx = self.img_ids.index(image_id)
        else:
            image_id = self.img_ids[idx]

        sample_dir = os.path.join(self.root, image_id)

        # Load image as RGB
        with Image.open(os.path.join(sample_dir, "images", image_id + ".png")) as raw:
            image = raw.convert("RGB")

        # Load masks: one binary image per nucleus, stacked as an (N, H, W) bool array.
        # Sorted for a deterministic instance order.
        mask_dir = os.path.join(sample_dir, "masks")
        mask_files = sorted(next(os.walk(mask_dir))[2])
        masks = np.empty((len(mask_files), image.height, image.width), dtype=bool)
        for i, mask_file in enumerate(mask_files):
            masks[i] = imread(os.path.join(mask_dir, mask_file)) != 0

        if self.train and self.aug_seq is not None:
            # imgaug wants segmentation maps as (H, W, C) with one channel per nucleus.
            image_np = np.array(image)
            segmap = SegmentationMapsOnImage(np.stack(masks, axis=-1), shape=image_np.shape)
            image_aug, segmap_aug = self.aug_seq(image=image_np, segmentation_maps=segmap)
            masks = np.transpose(segmap_aug.get_arr(), (2, 0, 1))  # back to (N, H, W)
            image = F.to_pil_image(image_aug, mode="RGB")

        masks = torch.as_tensor(masks, dtype=torch.uint8)
        if self.train:
            # Resize image and masks identically, then random-crop both with the same window.
            image = self.resize_t(image)
            masks = self.resize_t(masks)

            w, h = image.size
            if w > self.crop_dim or h > self.crop_dim:
                tw = min(self.crop_dim, w)
                th = min(self.crop_dim, h)
                top = torch.randint(0, h - th + 1, size=(1,)).item()
                left = torch.randint(0, w - tw + 1, size=(1,)).item()
                image = image.crop((left, top, left + tw, top + th))
                masks = masks[..., top:top + th, left:left + tw]

        # Bounding boxes in torchvision's (x0, y0, x1, y1) order (not numpy's row/col order).
        boxes = []
        keep = []
        for i in range(len(masks)):
            ys, xs = np.where(masks[i])
            if len(ys) == 0:
                continue  # nucleus lies entirely outside the crop
            xmin, xmax = np.min(xs), np.max(xs)
            ymin, ymax = np.min(ys), np.max(ys)
            # Skip single-row/column masks: torchvision rejects boxes without positive width/height.
            if xmax - xmin > 0 and ymax - ymin > 0:
                boxes.append([xmin, ymin, xmax, ymax])
                keep.append(i)

        num_objs = len(keep)
        if num_objs == 0:
            return None

        masks = masks[keep, :]

        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.ones((num_objs,), dtype=torch.int64)  # single class: nucleus
        masks = torch.as_tensor(masks, dtype=torch.uint8)
        idx = torch.as_tensor(idx, dtype=torch.int64)

        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

        # No crowd annotations: instances with iscrowd=True would be ignored during evaluation.
        iscrowd = torch.zeros((num_objs,), dtype=torch.int64)

        target = {}
        target["boxes"] = boxes        # (N, 4) float boxes, x0, y0, x1, y1
        target["labels"] = labels      # all ones: the only class is "nucleus"
        target["masks"] = masks        # (N, H, W) uint8 binary masks, one per nucleus
        target["image_id"] = idx       # integer position of the image in self.img_ids
        target["area"] = area          # bounding-box areas (not mask pixel counts)
        target["iscrowd"] = iscrowd    # always False

        image = F.to_tensor(image)
        return image, target

    def __len__(self):
        return len(self.img_ids)
    
def nucleus_collate_fn(batch):
    """DataLoader collate: discard samples the dataset returned as None, then transpose
    [(img, tgt), ...] into ((img, ...), (tgt, ...)). Nothing is stacked, so images of
    different sizes can share a batch."""
    usable = [sample for sample in batch if sample is not None]
    return tuple(zip(*usable))

In [7]:
def _draw_box(canvas, box):
    """Draw a red outline for one (x0, y0, x1, y1) box onto the HxWx3 ``canvas`` in place."""
    x0, y0, x1, y1 = [int(v.item()) for v in box]
    # skimage takes the start as (row, col) = (y, x) and the extent as (height, width).
    rr, cc = draw.rectangle_perimeter((y0, x0), extent=(y1 - y0, x1 - x0), shape=canvas.shape)
    canvas[rr, cc, :] = [255, 0, 0]


def show_data(img, label, model_output=False):
    """Plot an image with its nucleus masks (and boxes, for model output).

    Args:
        img: a torch tensor (converted with ``F.to_pil_image(mode="RGB")``) or an HxWx3 array.
        label: dict with ``masks`` and ``boxes`` tensors. With ``model_output=True`` it must
            also hold ``scores``. Otherwise ``image_id`` is printed.
        model_output: if True, keep detections with score >= 0.5 (stopping at the first lower
            score, since predictions are sorted high to low) and skip any mask that overlaps
            one already drawn. Show three panels: boxes, label image, mask boundaries.
            If False, show a single panel of mask boundaries.
    """
    img_ = np.array(F.to_pil_image(img, mode="RGB")) if isinstance(img, torch.Tensor) else img

    if model_output:
        fig, ax = plt.subplots(1, 3, figsize=(20, 30))
    else:
        fig, ax = plt.subplots(1, 1, figsize=(20, 20))

    out_img = img_.copy()

    if len(label['masks']) > 0:
        labeled_img = np.zeros(label['masks'][0].squeeze().detach().numpy().shape, dtype='uint8')

        if model_output:
            for i, bbox in enumerate(label["boxes"]):
                if not label['scores'][i] >= 0.5:
                    break  # scores are sorted high to low
                if i + 1 >= 256:
                    continue  # label i+1 would not fit in the uint8 label image
                mask = (label['masks'][i].squeeze().detach().numpy() > 0.5).astype('uint8')
                # Mask R-CNN may return overlapping masks; keep only the higher-scoring one.
                if np.max(mask * labeled_img) > 0:
                    continue
                labeled_img = labeled_img + mask * (i + 1)
                _draw_box(out_img, bbox)
            marked = segmentation.mark_boundaries(img_, labeled_img)
        else:
            for i, mask in enumerate(label['masks']):
                if i + 1 < 256:
                    labeled_img = labeled_img + mask.detach().numpy() * (i + 1)
            marked = segmentation.mark_boundaries(img_, labeled_img)
            # Boxes go onto out_img, which only the model_output layout displays.
            for bbox in label["boxes"]:
                _draw_box(out_img, bbox)
    else:
        print("Alert: No masked nuclei in image.")
        marked = out_img
        # Blank label image so the three-panel layout can still draw its middle panel
        # (previously this raised NameError for an empty prediction).
        labeled_img = np.zeros(img_.shape[:2], dtype='uint8')

    if model_output:
        print(f"Dim: h={img_.shape[0]} w={img_.shape[1]}")
        ax[0].imshow(out_img)
        ax[1].imshow(labeled_img)
        ax[2].imshow(marked)
    else:
        print(f"Image ID: {label['image_id']}")
        print(f"Dim: h={img_.shape[0]} w={img_.shape[1]}")
        ax.imshow(marked)


In [10]:
from torchvision.models import resnet101
from torchvision.ops import misc as misc_nn_ops
from torchvision.ops.feature_pyramid_network import LastLevelMaxPool
from torchvision.models.detection.backbone_utils import BackboneWithFPN
from torchvision.models.detection.anchor_utils import AnchorGenerator
def resnet_with_fpn(backbone, trainable_backbone_layers):
    """Freeze part of a torchvision ResNet and wrap it in a 256-channel FPN.

    Only parameters whose names start with one of the last ``trainable_backbone_layers``
    stage prefixes (counting back from "layer4"; a value of 5 also unfreezes "conv1" and
    "bn1") keep ``requires_grad``. layer1..layer4 feed the FPN as feature maps "0".."3",
    and LastLevelMaxPool adds one extra pooled level.
    """
    stage_prefixes = ["layer4", "layer3", "layer2", "layer1", "conv1"][:trainable_backbone_layers]
    if trainable_backbone_layers == 5:
        stage_prefixes.append("bn1")
    trainable = tuple(stage_prefixes)

    for param_name, param in backbone.named_parameters():
        # str.startswith(()) is False, so an empty prefix tuple freezes everything.
        if not param_name.startswith(trainable):
            param.requires_grad_(False)

    fpn_extra = LastLevelMaxPool()
    stages = (1, 2, 3, 4)
    return_layers = {f"layer{stage}": str(position) for position, stage in enumerate(stages)}

    # torchvision's ResNet leaves `inplanes` at layer4's output width (2048 for ResNet-101),
    # so stage k outputs (inplanes // 8) * 2**(k-1) channels.
    stage1_channels = backbone.inplanes // 8
    in_channels_list = [stage1_channels * 2 ** (stage - 1) for stage in stages]

    return BackboneWithFPN(backbone,
                           return_layers,
                           in_channels_list,
                           256,
                           extra_blocks=fpn_extra)

def get_model_instance_segmentation_resnet101(pretrained, num_classes, max_detections_per_img):
    """Build a Mask R-CNN on a ResNet-101 + FPN backbone with small anchors for nuclei.

    Args:
        pretrained: load ImageNet weights for ResNet-101. If True, batch-norm is frozen
            (FrozenBatchNorm2d) and only the last 3 ResNet stages are trainable. Otherwise
            regular BatchNorm2d is used and all 5 stages are trainable.
        num_classes: number of classes including background (2 = background + nucleus).
        max_detections_per_img: passed as ``box_detections_per_img``. torchvision's default
            of 100 is too low for images containing many nuclei.

    Returns:
        A torchvision ``MaskRCNN`` model.
    """
    if pretrained:
        norm_layer = misc_nn_ops.FrozenBatchNorm2d
        trainable_layers = 3  # allow final 3 stages to be trained
    else:
        # Fully-qualified: a bare `nn` name is not imported, which used to raise NameError here.
        norm_layer = torch.nn.BatchNorm2d
        trainable_layers = 5  # allow all stages to be trained

    backbone_resnet = resnet101(pretrained=pretrained, progress=True, norm_layer=norm_layer)
    backbone = resnet_with_fpn(backbone_resnet, trainable_backbone_layers=trainable_layers)

    # One anchor size per FPN level ("0".."3" plus the extra max-pooled level). The sizes are
    # smaller than torchvision's default (32..512 px) to suit small nuclei.
    anchor_generator = AnchorGenerator(sizes=((8,), (16,), (32,), (64,), (128,)),
                                       aspect_ratios=((0.5, 1.0, 2.0),) * 5)

    roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0', '1', '2', '3'],
                                                    output_size=7,
                                                    sampling_ratio=2)

    mask_roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0', '1', '2', '3'],
                                                         output_size=14,
                                                         sampling_ratio=2)

    model = MaskRCNN(backbone,
                     num_classes=num_classes,
                     rpn_anchor_generator=anchor_generator,
                     box_roi_pool=roi_pooler,
                     mask_roi_pool=mask_roi_pooler,
                     box_detections_per_img=max_detections_per_img)

    # Re-create the box and mask predictors for num_classes (same construction as
    # get_model_instance_segmentation uses for the ResNet-50 model).
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask,
                                                       hidden_layer,
                                                       num_classes)
    return model



In [11]:
# Build the ResNet-101 variant once (ImageNet backbone, background + nucleus, up to 500 detections).
model0 = get_model_instance_segmentation_resnet101(
    pretrained=True,
    num_classes=2,
    max_detections_per_img=500,
)

In [12]:

def get_model_instance_segmentation(pretrained, pretrained_backbone, num_classes, max_detections_per_img):
    """Build torchvision's ResNet-50-FPN Mask R-CNN with heads sized for ``num_classes``.

    Args:
        pretrained: load the COCO-trained detector weights.
        pretrained_backbone: load ImageNet weights for the ResNet-50 backbone.
        num_classes: number of classes including background (2 = background + nucleus).
        max_detections_per_img: passed as ``box_detections_per_img``. The default of 100
            is too low for images containing more nuclei than that.

    Returns:
        The model, with freshly initialised box and mask predictors.
    """
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(
        pretrained=pretrained,
        pretrained_backbone=pretrained_backbone,
        box_detections_per_img=max_detections_per_img,
    )
    heads = model.roi_heads

    # Replace the box classifier/regressor with one predicting num_classes.
    heads.box_predictor = FastRCNNPredictor(heads.box_predictor.cls_score.in_features, num_classes)

    # Replace the mask head the same way, keeping a 256-channel hidden layer.
    mask_hidden = 256
    heads.mask_predictor = MaskRCNNPredictor(heads.mask_predictor.conv5_mask.in_channels,
                                             mask_hidden,
                                             num_classes)
    return model


In [13]:
# ResNet-50 Mask R-CNN: ImageNet backbone only (no COCO detector weights), 2 classes, 500 detections.
model1 = get_model_instance_segmentation(
    num_classes=2,
    max_detections_per_img=500,
    pretrained=False,
    pretrained_backbone=True,
)

In [14]:
# Mask R-CNN can return overlapping instance masks, but for scoring each pixel may belong
# to at most one nucleus. filter_output keeps a mask only if it does not overlap any
# higher-scoring mask that has already been kept.
def filter_output(model_output, mask_threshold=0.5):
    """Drop detections whose binarised mask overlaps an earlier (higher-scoring) kept mask.

    Args:
        model_output: one Mask R-CNN prediction dict of tensors (``boxes``, ``labels``,
            ``scores``, ``masks``; torchvision returns masks as (N, 1, H, W)), sorted by
            descending score.
        mask_threshold: probability above which a mask pixel counts as foreground.

    Returns:
        The same dict, modified in place. Every value is converted to a detached CPU numpy
        array, and the overlapping detections are removed along axis 0. A prediction with
        no detections comes back as empty arrays instead of raising IndexError.
    """
    masks = model_output['masks']
    remove_idx = []
    if len(masks) > 0:
        occupied = np.zeros(masks[0].squeeze().shape, dtype=bool)
        for i, mask in enumerate(masks):
            foreground = mask.detach().cpu().numpy().squeeze() > mask_threshold
            if np.any(foreground & occupied):
                remove_idx.append(i)
            else:
                occupied |= foreground

    for key, value in list(model_output.items()):
        model_output[key] = np.delete(value.detach().cpu().numpy(), remove_idx, axis=0)

    return model_output

In [15]:
def load_checkpoint(model, optimizer, scheduler, filename='checkpoint.pt.tar'):
    """Restore training state written by ``train_model`` if ``filename`` exists.

    The model, optimizer and scheduler must already be constructed; their states are
    updated in place. ``scheduler`` may be None, in which case any saved scheduler state
    is ignored.

    Returns:
        (model, optimizer, scheduler, start_epoch, meanAP_arr, loss_arr,
         best_model_epoch, best_meanAP). Without a checkpoint these are the inputs plus
        epoch 0, empty histories and zero best values.
    """
    start_epoch = 0
    best_model_epoch = 0
    best_meanAP = 0
    meanAP_arr = []
    loss_arr = []

    if not os.path.isfile(filename):
        print("=> no checkpoint found at '{}'".format(filename))
        return (model, optimizer, scheduler, start_epoch,
                meanAP_arr, loss_arr, best_model_epoch, best_meanAP)

    print("=> loading checkpoint '{}'".format(filename))
    # map_location='cpu' lets a checkpoint written on a GPU load on any machine. The model is
    # moved with .to(device) below, and optimizer.load_state_dict casts its state to the
    # parameters' device. NOTE: torch.load unpickles, so only load checkpoints you created.
    checkpoint = torch.load(filename, map_location='cpu')
    start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    if scheduler is not None and checkpoint.get('scheduler') is not None:
        scheduler.load_state_dict(checkpoint['scheduler'])

    loss_arr = checkpoint['loss_arr']
    meanAP_arr = checkpoint['meanAP_arr']
    best_meanAP = checkpoint['best_meanAP']
    best_model_epoch = checkpoint['best_model_epoch']

    model = model.to(device)

    print("=> loaded checkpoint '{}' (epoch {})".format(filename, checkpoint['epoch']))

    return (model, optimizer, scheduler, start_epoch,
            meanAP_arr, loss_arr, best_model_epoch, best_meanAP)


In [16]:
def _train_one_epoch(model, loader, optimizer):
    """Run one optimisation pass over ``loader``; return the mean (per image) summed loss."""
    model.train()
    running_loss = 0.0
    running_total = 0
    for images, targets in loader:
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        optimizer.zero_grad()
        loss_dict = model(images, targets)  # in train mode Mask R-CNN returns a dict of loss terms
        losses = sum(loss_dict.values())
        losses.backward()
        optimizer.step()

        # Weight by batch size so the epoch mean is per image even if the last batch is short.
        running_loss += losses.item() * len(images)
        running_total += len(images)
    return float(running_loss / running_total)


def _validate_one_epoch(model, loader):
    """Return the mean over images of ``myutils.compute_ap_range``'s meanAP on ``loader``."""
    model.eval()
    running_meanAP = 0.0
    running_total = 0
    # Evaluation needs no gradients. Without no_grad, autograd records every forward pass
    # and holds its activations in GPU memory until the outputs are released.
    with torch.no_grad():
        for images, targets in loader:
            images = [image.to(device) for image in images]
            outputs = model(images)
            for output, target in zip(outputs, targets):
                pred = filter_output(output)
                # (N, 1, H, W) -> (H, W, N). Unlike squeeze() + stack, this keeps the layout
                # when N == 1 and also works when N == 0.
                pred_masks = np.transpose(pred['masks'][:, 0], (1, 2, 0))
                meanAP, _all_AP, _precisions, _recalls = myutils.compute_ap_range(
                    myutils.rearrange_boxes(target['boxes'].numpy()),
                    target['labels'].numpy(),
                    np.stack(target['masks'], axis=-1),
                    myutils.rearrange_boxes(pred['boxes']),
                    pred['labels'],
                    pred['scores'],
                    pred_masks,
                    verbose=0)
                running_meanAP += meanAP
            running_total += len(images)
    return float(running_meanAP / running_total)


def train_model(model, dataloader, optimizer, scheduler, num_epochs = 10, file_pre="model_state"):
    """Train and validate ``model``, checkpointing every epoch and resuming when possible.

    Args:
        model: Mask R-CNN, expected to already be on ``device``.
        dataloader: dict with 'train' and 'validate' DataLoaders yielding (images, targets).
        optimizer: optimiser over the model's trainable parameters.
        scheduler: LR scheduler stepped once per epoch, or None.
        num_epochs: total number of epochs, including any already completed in the checkpoint.
        file_pre: prefix for "<file_pre>_checkpoint.pt.tar" (latest state, used to resume)
            and "<file_pre>_checkpoint-epoch_<i>.pt.tar" (state when epoch i set a new best).

    Returns:
        (model loaded with the best epoch's weights, best_model_epoch, best_meanAP,
         meanAP_arr, loss_arr)
    """
    since = time.time()
    checkpoint_file = f"{file_pre}_checkpoint.pt.tar"

    # Resume once up front. Later epochs continue from the in-memory state, which matches
    # the checkpoint written at the end of the previous epoch.
    (model,
     optimizer,
     scheduler,
     start_epoch,
     meanAP_arr,
     loss_arr,
     best_model_epoch,
     best_meanAP) = load_checkpoint(model, optimizer, scheduler, filename=checkpoint_file)

    for epoch in range(start_epoch, num_epochs):
        print('Epoch: {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        epoch_loss = _train_one_epoch(model, dataloader['train'], optimizer)
        print(myutils.gpu_mem_allocated())
        print('Phase: train, epoch loss: {:.6f}'.format(epoch_loss))
        loss_arr.append(epoch_loss)
        if scheduler is not None:
            scheduler.step()

        epoch_meanAP = _validate_one_epoch(model, dataloader['validate'])
        print(myutils.gpu_mem_allocated())
        print('Phase: val, epoch meanAP: {:.6f}'.format(epoch_meanAP))
        meanAP_arr.append(epoch_meanAP)
        if epoch_meanAP > best_meanAP:
            best_meanAP = epoch_meanAP
            best_model_epoch = epoch

        # Save state at the end of every epoch; keep an extra copy when this epoch is the best.
        checkpoint = {'epoch': epoch + 1,
                      'state_dict': model.state_dict(),
                      'optimizer': optimizer.state_dict(),
                      'scheduler': scheduler.state_dict() if scheduler is not None else None,
                      'meanAP_arr': meanAP_arr,
                      'loss_arr': loss_arr,
                      'best_meanAP': best_meanAP,
                      'best_model_epoch': best_model_epoch,
                      }
        torch.save(checkpoint, checkpoint_file)
        if best_model_epoch == epoch:
            torch.save(checkpoint, f"{file_pre}_checkpoint-epoch_{epoch}.pt.tar")
        del checkpoint

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print(f"Best val epoch: {best_model_epoch}")
    print('Best val meanAP: {:6f}'.format(best_meanAP))

    best_file = f"{file_pre}_checkpoint-epoch_{best_model_epoch}.pt.tar"
    if os.path.isfile(best_file):
        checkpoint = torch.load(best_file, map_location='cpu')
        model.load_state_dict(checkpoint['state_dict'])
    else:
        print(f"=> best-epoch checkpoint '{best_file}' not found; returning current weights")

    return model, best_model_epoch, best_meanAP, meanAP_arr, loss_arr

In [17]:
# Sanity check before training: display the value from the project helper gpu_mem_allocated
# (presumably GPU memory currently allocated; the 0.0 below suggests nothing is on the GPU yet).
myutils.gpu_mem_allocated()

0.0

In [18]:
# Hyperparameter search over learning rate and batch size.
# TODO: consider tuning further hyperparameters (e.g. augmentation strength).

# Image augmentations: each training sample receives between 0 and 4 of these, chosen at random.
seq = iaa.SomeOf((0, 4), [
    iaa.Fliplr(0.5),
    iaa.Flipud(0.5),
    iaa.OneOf([iaa.Affine(rotate=90),
               iaa.Affine(rotate=180),
               iaa.Affine(rotate=270)]),
    iaa.Multiply((0.8, 1.5)),
    iaa.GaussianBlur(sigma=(0.0, 5.0)),
    iaa.OneOf([iaa.ScaleX((0.25, 2)),
               iaa.ScaleY((0.25, 2))]),
])

tune_params = True
if tune_params:
    dataset_dict = {
        'train': NucleusDataset(TRAIN_PATH, train=True, filter_ids=VAL_IMAGE_IDS, img_aug_seq=seq),
        'validate': NucleusDataset(TRAIN_PATH, train=False, filter_ids=VAL_IMAGE_IDS),
    }

    num_epochs = 3

    learning_rates = [0.00001, 0.0001, 0.001, 0.01]
    batch_sizes = [4, 6]

    all_best_losses = []
    all_best_meanAPs = []
    all_best_epochs = []

    overall_best_meanAP = 0
    overall_best_loss = None  # training loss at the best configuration's best epoch
    overall_best_lr = 0
    overall_best_bs = 0

    load_from_file = True  # reuse results of configurations finished in an earlier session
    for learn_rate in learning_rates:
        for bs in batch_sizes:
            print()
            print(f"Learning rate: {learn_rate}")
            print(f"Batch size: {bs}")

            save_file = f"MRCNN_tune-lr_{learn_rate}_bs_{bs}_best_model_state.pt"

            if load_from_file and os.path.exists(save_file):
                # Load into best_model_state. The stats below read it; this used to be assigned
                # to an unused name, so stale or undefined results were reported.
                best_model_state = torch.load(save_file, map_location="cpu")
            else:
                data_loader_dict = {
                    'train': torch.utils.data.DataLoader(dataset_dict['train'],
                                                         batch_size=bs,
                                                         shuffle=True,
                                                         num_workers=4,
                                                         collate_fn=nucleus_collate_fn),
                    'validate': torch.utils.data.DataLoader(dataset_dict['validate'],
                                                            batch_size=bs,
                                                            shuffle=False,
                                                            num_workers=4,
                                                            collate_fn=nucleus_collate_fn),
                }

                model = get_model_instance_segmentation(pretrained=True,
                                                        pretrained_backbone=False,
                                                        num_classes=2,
                                                        max_detections_per_img=500)
                model.to(device)

                params = [p for p in model.parameters() if p.requires_grad]
                optimizer = torch.optim.Adam(params, lr=learn_rate)
                # Halve the learning rate after every epoch.
                scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: 0.5 ** epoch)

                (best_model,
                 best_model_epoch,
                 best_meanAP,
                 meanAP_arr,
                 loss_arr) = train_model(model,
                                         data_loader_dict,
                                         optimizer,
                                         scheduler,
                                         num_epochs,
                                         file_pre=f"MRCNN_tune-lr_{learn_rate}_bs_{bs}_chk")

                best_model_state = {
                    'best_model': best_model.state_dict(),
                    'best_model_epoch': best_model_epoch,
                    'best_meanAP': best_meanAP,
                    'meanAP_arr': meanAP_arr,
                    'loss_arr': loss_arr,
                }
                torch.save(best_model_state, save_file)

            best_epoch = best_model_state['best_model_epoch']
            cur_best_meanAP = best_model_state['meanAP_arr'][best_epoch]
            cur_best_loss = best_model_state['loss_arr'][best_epoch]

            all_best_losses.append(cur_best_loss)
            all_best_meanAPs.append(cur_best_meanAP)
            all_best_epochs.append(best_epoch)
            if cur_best_meanAP > overall_best_meanAP:
                overall_best_meanAP = cur_best_meanAP
                overall_best_loss = cur_best_loss
                overall_best_bs = bs
                overall_best_lr = learn_rate

    print(f"Best batch size: {overall_best_bs}")
    print(f"Best learning rate: {overall_best_lr}")
    print(f"Training loss at best epoch: {overall_best_loss}")
    print(f"Best validation meanAP: {overall_best_meanAP}")


Learning rate: 1e-05
Batch size: 4
=> no checkpoint found at 'MRCNN_tune-lr_1e-05_bs_4_chk_checkpoint.pt.tar'
Epoch: 0/2
----------


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


KeyboardInterrupt: 