# Classify Fluorescence Images as Polarized

This notebook implements a basic CNN to classify fluorescent images of embryos as having or lacking polarized caps.

In [1]:
pip install mxnet

Note: you may need to restart the kernel to use updated packages.


In [2]:
from __future__ import print_function
import h5py
import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn
from mxnet import autograd as ag
import utils
import os
import numpy as np

# Fixing the random seed
mx.random.seed(42)

In [3]:
import mxnet.ndarray as F

class Net(gluon.Block):
    def __init__(self, **kwargs):
        super(Net, self).__init__(**kwargs)
        with self.name_scope():
            # layers created in name_scope will inherit name space
            # from parent layer.
            self.conv1 = nn.Conv2D(20, kernel_size=(5,5))
            self.pool1 = nn.MaxPool2D(pool_size=(2,2), strides = (2,2))
            self.conv2 = nn.Conv2D(50, kernel_size=(5,5))
            self.pool2 = nn.MaxPool2D(pool_size=(2,2), strides = (2,2))
            self.fc1 = nn.Dense(500)
            self.fc2 = nn.Dense(10)

    def forward(self, x):
        x = self.pool1(F.tanh(self.conv1(x)))
        x = self.pool2(F.tanh(self.conv2(x)))
        # 0 means copy over size from corresponding dimension.
        # -1 means infer size from the rest of dimensions.
        x = x.reshape((0, -1))
        x = F.tanh(self.fc1(x))
        x = F.tanh(self.fc2(x))
        return x

In [11]:
net = Net()

# set the context on GPU is available otherwise CPU
ctx = [mx.gpu() if mx.test_utils.list_gpus() else mx.cpu()]
net.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx)
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.03})

# preprocess the data
fluo_data_path = "../data/video_fluo_data"
embryo_idx = 1
fluo = h5py.File(os.path.join(fluo_data_path,'embryo_'+str(embryo_idx)+'.mat'))
arrays = {}
for k, v in fluo.items():
    arrays[k] = np.array(v)
fluo_video = arrays['data']
pol_state = arrays['anno']

fluo_video = np.array([utils.get_middle_z(fluo_video)])
fluo_video = np.moveaxis(fluo_video, -1, 0)

batch_size = 100
train_data = mx.io.NDArrayIter(fluo_video, pol_state[0], batch_size, shuffle=True)
# val_data = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size)

  fluo = h5py.File(os.path.join(fluo_data_path,'embryo_'+str(embryo_idx)+'.mat'))


(60000, 1, 28, 28)
<mxnet.io.io.NDArrayIter object at 0x1412b1fd0>


In [None]:
epoch = 10

# Use Accuracy as the evaluation metric.
metric = mx.metric.Accuracy()
softmax_cross_entropy_loss = gluon.loss.SoftmaxCrossEntropyLoss()

for i in range(epoch):
    # Reset the train data iterator.
    train_data.reset()
    # Loop over the train data iterator.
    for batch in train_data:
        # Splits train data into multiple slices along batch_axis
        # and copy each slice into a context.
        data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
        # Splits train labels into multiple slices along batch_axis
        # and copy each slice into a context.
        label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
        outputs = []
        # Inside training scope
        with ag.record():
            for x, y in zip(data, label):
                z = net(x)
                # Computes softmax cross entropy loss.
                loss = softmax_cross_entropy_loss(z, y)
                # Backpropogate the error for one iteration.
                loss.backward()
                outputs.append(z)
                
        print(outputs)
        # Updates internal evaluation
        metric.update(label, outputs)
        # Make one step of parameter update. Trainer needs to know the
        # batch size of data to normalize the gradient by 1/batch_size.
        trainer.step(batch.data[0].shape[0])
    # Gets the evaluation result.
    name, acc = metric.get()
    # Reset evaluation result to initial state.
    metric.reset()
    print('training acc at epoch %d: %s=%f'%(i, name, acc))


[2.2103627 2.5780544 2.677043  1.9435673 2.4546053 1.8247279 2.417427
 1.9807328 2.5779767 2.1709526 2.1524036 1.9793866 2.3828225 2.047558
 2.0884633 1.9752182 2.531736  2.1055849 2.37225   1.8707216 2.025089
 2.2103627 2.5780544 2.677043  1.9435673 2.4546053 1.8247279 2.417427
 1.9807328 2.5779767 2.1709526 2.1524036 1.9793866 2.3828225 2.047558
 2.0884633 1.9752182 2.531736  2.1055849 2.37225   1.8707216 2.025089
 2.2103627 2.5780544 2.677043  1.9435673 2.4546053 1.8247279 2.417427
 1.9807328 2.5779767 2.1709526 2.1524036 1.9793866 2.3828225 2.047558
 2.0884633 1.9752182 2.531736  2.1055849 2.37225   1.8707216 2.025089
 2.2103627 2.5780544 2.677043  1.9435673 2.4546053 1.8247279 2.417427
 1.9807328 2.5779767 2.1709526 2.1524036 1.9793866 2.3828225 2.047558
 2.0884633 1.9752182 2.531736  2.1055849 2.37225   1.8707216 2.025089
 2.2103627 2.5780544 2.677043  1.9435673 2.4546053 1.8247279 2.417427
 1.9807328 2.5779767 2.1709526 2.1524036 1.9793866 2.3828225 2.047558
 2.0884633 1.975218