In [None]:
!pip install shortuuid

In [None]:
import sys
import os
import math
import base64
import boto3
import sagemaker
import matplotlib.pyplot as plt
import numpy as np
import collections
from collections import defaultdict
from PIL import Image
import sklearn
from sklearn.metrics import ConfusionMatrixDisplay
from matplotlib.ticker import NullFormatter
from sklearn import manifold, datasets
from time import time

To avoid altering the conda environment of this notebook and introducing versioning conflicts, we do not `pip install -r` the bioims CLI requirements.txt file, although that would be preferable to using pip directly here.

In [None]:
s3c = boto3.client('s3')

In [None]:
%pwd

In [None]:
bioimsArtifactBucket='bioimagesearchbasestack-bioimagesearchdatabucketa-16h77xh6oyxmm'

In [None]:
# assumes cwd=/root/bioimage-search/datasets/bbbc-021/notebooks
# The path below is relative to the current working directory. It is inserted at
# the front of the module search path so `import bioims` picks up the CLI sources
# found there.
sys.path.insert(0, "../../../cli/bioims/src")
import bioims as bi

In [None]:
# Create a SageMaker session, then look up its default bucket and the execution
# role of this notebook. The role is printed in the next cell so the managed
# policy described under "Permissions" can be attached to it.
sagemaker_session = sagemaker.Session()
bucket = sagemaker_session.default_bucket()
role = sagemaker.get_execution_role()

In [None]:
print(role)

## Prerequisites

### Permissions
This notebook requires adding the "BioimageSearch" managed policy to the above SageMaker execution role. Do this by using the IAM console to add the policy to the above role. The policy arn will be something like: arn:aws:iam::580829821648:policy/BioimageSearchResourcePermissionsStack-biomageSearchManagedPolicy9CB9C1D7-SXEV4WNUCZ7V

### TrainID
This notebook does a baseline, or sanity-check, on the training of the bbbc-021 dataset with a particular model, the context for which is specified by a BioimageSearch 'trainId', an output of the training process. If the evaluation of the results looks reasonable, then the comprehensive 'mechanism of action' (MOA) training series can be run, which creates a separate model for each chemical compound with known MOA. This collection of models, in turn, can be evaluated to determine the likelihood of whether a treatment (i.e., the application of a particular compound at a particular concentration) of unknown MOA would be properly classified. If so, then the model is likely to be useful for assigning MOA to treatments with unknown mechanisms of action, and more broadly for representing a molecule in 'MOA space'.

In [None]:
# bbbc021: trainId = 'r6KEudzQCuUtDwCzziiMZT'
# bbbc021-128:
trainId = 'vT44kUtLi7jnSGC7VXG7iT'

### Steps
* Get the Embedding for the TrainId
* Get the dimensions for the Embedding
* Get the list of compatible plates
* For each plate:
 * Get the origin row of each image to get its metadata
 * Get the embeddings for the specified TrainId
* Combine the metadata and embeddings into a tabular object
* Visualize separability
 * Compute the average embedding for each well
 * Label each well by known MOA (for cases where MOA is known)
 * Use a projection method (e.g., t-sne) to view the separability of results
* Create 'baseline' confusion matrix (not valid due to circular model inclusion, but just for sanity check)
 * For each treatment with known MOA:
   * Compute average across corresponding wells
   * Find the MOA of its nearest neighbor
 * Plot matrix

In [None]:
trainClient = bi.client('training-configuration')

In [None]:
trainInfo = trainClient.getTraining(trainId)

In [None]:
trainInfo

In [None]:
embeddingInfo = trainClient.getEmbeddingInfo(trainInfo['embeddingName'])

In [None]:
embeddingInfo

In [None]:
imageClient = bi.client('image-management')

In [None]:
plates = imageClient.listCompatiblePlates(embeddingInfo['inputWidth'], embeddingInfo['inputHeight'], embeddingInfo['inputDepth'], embeddingInfo['inputChannels'])

In [None]:
plates

In [None]:
imagePlateExample = imageClient.getImagesByPlateId(plates[0]['plateId'])

In [None]:
imagePlateExample[0]

In [None]:
embeddingPlateExample = imageClient.getImagesByPlateIdAndTrainId(plates[0]['plateId'], trainId)

In [None]:
embeddingPlateExample[0]

In [None]:
# Example of the stored form: a string holding the text of a bytes literal (b'...')
# whose content is base64 data.
#e1 = "b'i6qJPbpKHL5IZSK+akeqPRVh570Wlvm869mrvXp6Pj5qKMy918wAvnCsQr6zHe+91nTsvalCALx3twG+Et21PRMvhb1Qzkw8mda2PYVZtD2RXyu7ggIWPs4AnT2keSE9oQ8ePqkEvL0bN4K97uzDPQ1NuD22JbO8TTCPvHkVCr0='"
# Take the stored embedding string of the first record returned for the example plate.
e1 = embeddingPlateExample[0]['Item']['embedding']

In [None]:
# Split on single quotes; for a string of the form b'<payload>' the payload is element 1.
e2 = e1.split("\'")

In [None]:
# Take the quoted payload and convert it from str to bytes, as base64.decodebytes expects bytes.
e3=e2[1].encode()

In [None]:
# Decode the base64 payload to raw bytes, then reinterpret those bytes as a 1-D float32 vector.
e4 = base64.decodebytes(e3)
e5 = np.frombuffer(e4, dtype=np.float32)

In [None]:
e5

In [None]:
embeddingWidth=len(e5)

In [None]:
embeddingWidth

In [None]:
plateMap={}

In [None]:
# Build plateMap: plateId -> {"images": {imageId: image record},
#                             "embeddings": {imageId: embedding record}}
# Each embedding record gains an 'np' entry holding its decoded float32 vector.
for plateEntry in plates:
    plateId = plateEntry['plateId']
    print("Adding plateId {}".format(plateId))
    images = imageClient.getImagesByPlateId(plateId)
    embeddings = imageClient.getImagesByPlateIdAndTrainId(plateId, trainId)

    # Index the image records of this plate by imageId.
    imageMap = {entry['Item']['imageId']: entry['Item'] for entry in images}
    print("  found {} image entries".format(len(imageMap)))

    # Index the embedding records by imageId, skipping responses without an 'Item'.
    embeddingMap = {}
    for entry in embeddings:
        if 'Item' not in entry:
            continue
        record = entry['Item']
        recordImageId = record['imageId']
        # The stored string has the form b'<base64>': keep the quoted part and decode it.
        payload = record['embedding'].split("\'")[1].encode()
        record['np'] = np.frombuffer(base64.decodebytes(payload), dtype=np.float32)
        embeddingMap[recordImageId] = record
    print("  found {} embedding entries".format(len(embeddingMap)))

    plateMap[plateId] = {"images": imageMap, "embeddings": embeddingMap}

# Review Plate Images

Here I'd like to select a plate and visualize all composite images.

In [None]:
artifactClient=bi.client('artifact')

In [None]:
artifactClient.getArtifacts('1xDNMw2ZFhpSGDTppgyeMU', 'origin')

In [None]:
plateId = 'tUyR81nttbs4oTerCmeY1W'

In [None]:
def showImage(im):
    """Show one image in a tall, tick-free figure and print its size and format.

    im: an image object exposing `.size` and `.format` that matplotlib's
        imshow can render.
    """
    _, axes = plt.subplots(figsize=(4,15))
    axes.set_xticks([])
    axes.set_yticks([])
    print("shape=", im.size)
    print("format=", im.format)
    axes.imshow(im)
    plt.show()

In [None]:
def displayImages(images, columns, fx, fy):
    """Fetch each S3 key in `images` from the artifact bucket and draw them on a grid.

    images:  sized iterable of S3 object keys inside `bioimsArtifactBucket`
    columns: number of grid columns; the row count is derived from len(images)
    fx, fy:  figure width and height handed to matplotlib as figsize

    Each tile is drawn without ticks and labelled with its position in `images`.
    """
    plt.figure(figsize=(fx,fy))
    gridCols = columns
    gridRows = math.ceil(len(images)/gridCols)
    for position, objectKey in enumerate(images):
        response = s3c.get_object(Bucket=bioimsArtifactBucket, Key=objectKey)
        picture = Image.open(response['Body'])
        plt.subplot(gridRows, gridCols, position+1)
        plt.xticks([])
        plt.yticks([])
        plt.imshow(picture)
        plt.xlabel(position)

In [None]:
def displayThumbnailsForPlate(plateId):
    """Display the 2D thumbnails of all images on a plate.

    plateId: key into the notebook-level `plateMap`.

    For every image of the plate, its 'origin' artifacts are listed and each
    artifact whose key ends with 'thumbnail-2d.png' is collected. The collected
    keys are then drawn 10 per row by displayImages.
    """
    thumbnailKeys = []
    for imageId in plateMap[plateId]['images']:
        for artifact in artifactClient.getArtifacts(imageId, 'origin'):
            s3key = artifact['artifact']
            # Test every artifact. Previously this check ran once after the loop,
            # so only the last artifact of an image was examined, and an image
            # with no artifacts reused a stale (or undefined) s3key.
            if s3key.endswith('thumbnail-2d.png'):
                # assumes the artifact string is "<prefix>#<s3 key>" — TODO confirm
                thumbnailKeys.append(s3key.split('#')[1])
    displayImages(thumbnailKeys, 10, 20, 40)

In [None]:
displayThumbnailsForPlate(plateId)

Next, we create a mapping from MOA->well-embedding, where we will naively take the average of the embeddings for each well (there are several images per well). An alternative would be taking the median.

In [None]:
wellMap = {}

In [None]:
# Group per-image embeddings by well:
#   wellMap[wellId] -> [{'label': trainLabel, 'np': embedding vector}, ...]
# The map is reset here so that re-running this cell does not append every image twice.
wellMap = {}
imageCount = 0
# The loop variables are deliberately not named `plateId` or `embeddingInfo`:
# those notebook-level names are set in earlier cells (the plate selected for
# thumbnail review and the training's embedding description), and clobbering
# them would break re-running those cells.
for currentPlateId, currentPlate in plateMap.items():
    print("plate {}".format(currentPlateId))
    plateImages = currentPlate['images']
    plateEmbeddings = currentPlate['embeddings']
    for imageId, imageInfo in plateImages.items():
        if imageId not in plateEmbeddings:
            continue
        imageEmbedding = plateEmbeddings[imageId]
        # Keep only images that have both a label and a decoded embedding.
        if 'trainLabel' in imageInfo and 'np' in imageEmbedding:
            wellMap.setdefault(imageInfo['wellId'], []).append({
                'label': imageInfo['trainLabel'],
                'np': imageEmbedding['np'],
            })
            imageCount += 1
print("Found {} wells".format(len(wellMap)))
print("Found {} images with MOA labels and embeddings".format(imageCount))

In [None]:
moaMap={}

In [None]:
# Average the image embeddings of each well and group the per-well means by label:
#   moaMap[label] -> [mean embedding of well 1, mean embedding of well 2, ...]
# The map is reset here so that re-running this cell does not append every well
# a second time, which would inflate embeddingCount and everything derived from it.
moaMap = {}
for wellId, wellImages in wellMap.items():
    # The well takes the label of its first image.
    # NOTE(review): assumes all images of a well share one label — confirm.
    label = wellImages[0]['label']
    wellMean = np.mean(np.asarray([entry['np'] for entry in wellImages]), axis=0)
    moaMap.setdefault(label, []).append(wellMean)

In [None]:
embeddingCount=0
for label in moaMap:
    wellArr = moaMap[label]
    print("label {} has {} entries".format(label, len(wellArr)))
    embeddingCount += len(wellArr)

In [None]:
embeddingCount

In [None]:
# Collect the labels present in moaMap (in its iteration order) and start an
# empty label -> class index map.
moaSortedArr = list(moaMap)
moaLabelMap = {}

In [None]:
moaSortedArr

In [None]:
moaSortedArr.sort()

In [None]:
moaSortedArr

In [None]:
# Assign each label its position in moaSortedArr as the integer class index.
moaLabelMap.update((name, position) for position, name in enumerate(moaSortedArr))

In [None]:
orderedEmbedding = np.empty( (embeddingCount, embeddingWidth), dtype=np.float32 )

In [None]:
orderedLabels = np.empty(embeddingCount, dtype=np.int32)

In [None]:
# Flatten moaMap into the preallocated arrays: one row per well mean, with the
# matching integer class index stored at the same position in orderedLabels.
row = 0
for moaName, wellMeans in moaMap.items():
    classIndex = moaLabelMap[moaName]
    for wellMean in wellMeans:
        orderedEmbedding[row] = wellMean
        orderedLabels[row] = classIndex
        row += 1

In [None]:
# Number of neighbours counted per example when building the confusion matrix.
near_neighbors_per_example = 10
# Pairwise dot products between all rows: gram_matrix[a, b] = sum_e emb[a, e] * emb[b, e].
gram_matrix = np.einsum("ae,be->ab", orderedEmbedding, orderedEmbedding)
# For each row keep the indices of the (near_neighbors_per_example + 1) largest
# dot products, ordered from smallest to largest. The confusion-matrix cell later
# drops the last column with [:-1].
# NOTE(review): dropping the last column presumes the largest dot product of an
# example is with itself. That is not guaranteed for embeddings of differing
# norm (e.g. per-well means) — confirm or normalize before ranking.
near_neighbors = np.argsort(gram_matrix.T)[:, -(near_neighbors_per_example + 1) :]

In [None]:
# Map each class index to the row positions in orderedLabels that carry it.
class_idx_to_train_idxs = defaultdict(list)
for position in range(len(orderedLabels)):
    class_idx_to_train_idxs[orderedLabels[position]].append(position)

In [None]:
num_classes = 13
confusion_matrix = np.zeros((num_classes, num_classes))

# Row = class of the query example, column = class of one of its near neighbours.
for class_idx in range(num_classes):
    # Only the first near_neighbors_per_example examples of each class are used as queries.
    for query_idx in class_idx_to_train_idxs[class_idx][:near_neighbors_per_example]:
        # Classes of the query's neighbours, leaving out the last-ranked column.
        neighbor_classes = orderedLabels[near_neighbors[query_idx][:-1]]
        # np.add.at accumulates repeated indices, so each neighbour adds one count.
        np.add.at(confusion_matrix, (class_idx, neighbor_classes), 1)

NOTE: in the confusion matrix below, category 'DMSO' (predicted label 3) is equivalent to 'no treatment', since DMSO is the chemical control buffer. Therefore, we should not be surprised to see off-diagonals that represent a kind of 'best guess' network outcome.

In [None]:
# Display a confusion matrix.
# Tick labels are the class indices "0" .. "12"; moaLabelMap gives the label behind each index.
labels = [str(class_number) for class_number in range(13)]
plt.rcParams["figure.figsize"] = (30,15)
disp = ConfusionMatrixDisplay(confusion_matrix=confusion_matrix, display_labels=labels)
disp.plot(include_values=True, cmap="viridis", ax=None, xticks_rotation="vertical")
plt.show()

In [None]:
moaLabelMap

In [None]:
colorList=['b', '#55ff55', 'r', 'c', 'm', 'y', 'k', '#eeeeee', '#777777', '#ff9999', '#880000', '#009900', '#000088']

NOTE: we remove DMSO from the t-sne plot since it is the 'no treatment' control

In [None]:
# One boolean mask over orderedLabels per class, used to select that class's
# points when plotting. The unused collections.Counter pass over each mask was
# removed; its result was never read.
embeddingClassMembership = []
for c in range(num_classes):
    if c == 3:
        # Class 3 is DMSO, the 'no treatment' control: class indices are never
        # negative, so comparing with -1 yields an all-False mask.
        mask = orderedLabels == -1
    else:
        mask = orderedLabels == c
    embeddingClassMembership.append(mask)

In [None]:
# Project the well embeddings to 2-D with t-SNE at two perplexities, drawn side by side.
(fig2, subplots2) = plt.subplots(1, 2, figsize=(20, 10))
perplexities = [10, 100]

for panel, perplexity in zip(subplots2, perplexities):
    # Fixed random_state keeps the projection reproducible between runs.
    projection = manifold.TSNE(
        n_components=2, init='random', random_state=0, perplexity=perplexity
    ).fit_transform(orderedEmbedding)
    panel.set_title("Perplexity=%d" % perplexity)
    for cl in range(num_classes):
        # Class 3 (DMSO) is not drawn.
        if cl == 3:
            continue
        members = embeddingClassMembership[cl]
        panel.scatter(projection[members, 0], projection[members, 1], c=colorList[cl])
    panel.xaxis.set_major_formatter(NullFormatter())
    panel.yaxis.set_major_formatter(NullFormatter())
    panel.axis('tight')