### Prerequisites

This notebook was last updated on May 3, 2022. If the `MyModel` class in `./my_model.py` has changed since then, this code may not run as expected.

In [1]:
import os
import json
import numpy as np

### Load the Model

In [2]:
import torch
import torchvision.models

import my_model
import utils

In [3]:
# Location of the saved weights; passed to MyModel.load_model in the next cell.
path_to_model_weights = 'model.pth'

In [4]:
# DenseNet backbone with four dense blocks of two layers each and three output
# classes (one per severity class used by WrappedModel further down).
model_densenet = torchvision.models.DenseNet(
    num_classes=3,
    growth_rate=32,
    block_config=(2, 2, 2, 2),
    num_init_features=64,
    bn_size=4,
    drop_rate=0,
)
# Only the backbone and a CPU device are supplied. The loss function, optimizer
# and directories are left as None -- presumably unused when only running
# inference; confirm against my_model.MyModel.
model_mymodel = my_model.MyModel(
    model=model_densenet,
    device=torch.device('cpu'),
    loss_fn=None,
    optimizer=None,
    checkpoint_dir=None,
    model_dir=None,
)

In [5]:
# Load the saved weights from `path_to_model_weights` via MyModel.load_model.
model_mymodel.load_model(path_to_model_weights)

  return torch._C._cuda_getDeviceCount() > 0


In [7]:
# test the model real quick to make sure things work
from unified_image_reader import Image
img_path = "/workspaces/dev-container/testing/data/whole_slide_images/92321.tif"
img = Image(img_path)
print(img.dims, img.number_of_regions())
region = img.get_region(100)

(45000, 40000) 6786


In [8]:
# Classify the single sample region; diagnose_region yields an integer class
# index (the same value WrappedModel.diagnose uses as a vote key).
print(model_mymodel.diagnose_region(region))
# TODO: also exercise diagnose_wsi on a whole slide; it is not tested here.

1


### Wrapper on MyModel for WebApp Compatibility

In [9]:
from model_manager_for_web_app import ManagedModel
from filtration import FilterManager, FilterBlackAndWhite, FilterHSV, FilterFocusMeasure

In [10]:
class WrappedModel(ManagedModel):
    """Adapt a ``my_model.MyModel`` to the WebApp's ``ManagedModel`` interface.

    Every region of a slide is passed through a filter chain. Each region that
    passes is classified by the wrapped model, and the per-region class indices
    are combined into one weighted-average diagnosis for the whole slide.

    Args:
        model: the wrapped ``MyModel``. Its ``diagnose_region`` is called per
            region and its ``.model`` attribute is moved to the chosen torch
            device on the first ``diagnose`` call.
        classes: class names, indexed by the integer that
            ``model.diagnose_region`` returns.
        aggregation_weights: one weight per class index. The weighted mean of
            the votes is rounded and used as an index into ``classes``.
    """

    def __init__(
        self,
        model,
        classes=('Mild', 'Moderate', 'Severe'),
        aggregation_weights=(0, 1, 2),
    ):
        self.model = model
        self.classes = classes
        self.aggregation_weights = aggregation_weights
        # Applied to every region; only regions for which this returns exactly
        # True are classified (see diagnose).
        self.filtration = FilterManager([
            FilterBlackAndWhite(),
            FilterHSV(),
            FilterFocusMeasure()
        ])
        # Device selection is deferred to the first diagnose() call --
        # presumably so it happens where the deserialized model actually runs.
        self._checked_for_device = False

    def diagnose(self, region_stream):
        """Return one class name for a whole stream of regions.

        Args:
            region_stream: iterable of regions (e.g. an image reader object or
                a tqdm wrapper around one).

        Returns:
            The entry of ``self.classes`` selected by rounding the
            ``aggregation_weights``-weighted mean of the per-region votes.

        Raises:
            ValueError: if no region passes filtration, so there are no votes
                to aggregate.
        """
        # First call only: move the torch module to the GPU when one is available.
        if not self._checked_for_device:
            print("checking for device to migrate")
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            # NOTE(review): only the inner torch module is moved. MyModel itself
            # was constructed with device=cpu; confirm diagnose_region sends
            # its inputs to the module's device, not a stored CPU device.
            self.model.model.to(device)
            self._checked_for_device = True
            print("finished migrating device")
        # One vote counter per class index.
        votes = {i: 0 for i in range(len(self.classes))}
        for region in region_stream:
            # `is True` deliberately rejects truthy non-bool filter results.
            if self.filtration(region) is True:
                region_diagnosis = self.model.diagnose_region(region)
                votes[region_diagnosis] += 1
        print(votes)
        total_votes = sum(votes.values())
        if total_votes == 0:
            # Without this guard the average below divides by zero.
            raise ValueError(
                "no regions passed filtration; cannot aggregate a diagnosis"
            )
        weighted_total = sum(
            count * self.aggregation_weights[index]
            for index, count in votes.items()
        )
        # round() uses round-half-to-even: a mean of exactly 0.5 maps to
        # index 0 and 1.5 maps to index 2.
        return self.classes[round(weighted_total / total_votes)]

### Prepare Model for Saving

1. Wrap the model in a subclass of `ManagedModel`, as `WrappedModel` does above.
2. Register any dependencies that may not be available to the WebApp when this model is deserialized.
    - To decide whether a module needs registering, check whether any code used to build the serialized object (for example `my_model` and `utils`) would be missing from the WebApp's environment when it deserializes the model.

In [11]:
# Wrap MyModel using the default classes ('Mild', 'Moderate', 'Severe') and
# aggregation weights (0, 1, 2).
model_wrapped = WrappedModel(model_mymodel)

In [12]:
# Set to True to run a full-slide diagnosis as an end-to-end check. It iterates
# over every region of `img` (6786 for the sample slide), so it is slow.
test_diagnosis = False
if test_diagnosis:
    from tqdm import tqdm as loadingbar
    # `img` serves as the region stream; tqdm only adds a progress bar.
    # Print the result: inside an `if` block Jupyter does not display the
    # value of a bare expression, so it would otherwise be discarded.
    print(model_wrapped.diagnose(loadingbar(img)))

### Save the Model

In [13]:
from model_manager_for_web_app import ModelManager

In [14]:
# Serialize the wrapped model under the name "kevin_test". overwrite_model=True
# presumably replaces any existing model with that name. my_model and utils are
# registered so they are available when the WebApp deserializes the model.
# NOTE(review): WrappedModel also stores filter objects from the `filtration`
# module; if that module is not installed alongside the WebApp, it may need to
# be registered here as well.
# TODO: replace the placeholder model_info text with a real description.
model_manager = ModelManager()
model_manager.save_model(
    model_name="kevin_test",
    model=model_wrapped,
    model_info={"info": "idk man anything you want to record"},
    overwrite_model=True,
    dependency_modules=[my_model, utils],
)