In [0]:
import numpy as np
import os
import imutils
import time
import cv2
from google.colab.patches import cv2_imshow
import torch
from copy import deepcopy

In [0]:
def ReadCellImages(dir, imgSize):
  """Load every PNG in ``dir`` as a colour image plus a class label from its name.

  Args:
    dir: directory containing the images. Joined with ``os.path.join``, so it
      works with or without a trailing separator. (Name kept for backward
      compatibility even though it shadows the builtin.)
    imgSize: ``(height, width)`` that every image is expected to have.

  Returns:
    ``(X, y)``: ``X`` is a float64 array of shape ``(N, height, width, 3)``
    holding the pixels as returned by ``cv2.imread``; ``y`` is a list of N int
    labels. A file name containing ``'1'`` maps to 0, else ``'2'`` to 1, else
    ``'3'`` to 2 (first match wins, same precedence as before).

  Raises:
    OSError: if ``cv2.imread`` cannot read a file.
    ValueError: if a file name contains none of '1', '2', '3' (previously such a
      file was loaded into ``X`` without a label, misaligning ``X`` and ``y``).
  """
  # Sort so the sample order (and thus batch composition) does not depend on
  # the filesystem's directory-listing order.
  images = sorted(name for name in os.listdir(dir) if name.endswith('png'))

  X = np.empty([len(images), imgSize[0], imgSize[1], 3], dtype=float)
  y = []

  for count, image in enumerate(images):
    path = os.path.join(dir, image)
    pixels = cv2.imread(path)
    if pixels is None:
      # cv2.imread signals failure by returning None rather than raising.
      raise OSError('could not read image: {}'.format(path))
    X[count] = pixels

    for digit, label in (('1', 0), ('2', 1), ('3', 2)):
      if digit in image:
        y.append(label)
        break
    else:
      raise ValueError('cannot infer a label from file name: {}'.format(image))

  return (X, y)

In [0]:
# Training set: every PNG under cells/, each expected to be a 23x23, 3-channel patch.
train_data = ReadCellImages(dir='cells/', imgSize=(23, 23))

In [0]:
def CreateBatch(data, index, batch_size, sample_shape=(3, 23, 23)):
  """Slice up to ``batch_size`` samples starting at ``index`` out of ``data``.

  Args:
    data: ``(images, labels)`` pair. ``images`` is indexable along its first
      axis (e.g. the ``(N, H, W, 3)`` array from ``ReadCellImages``); ``labels``
      is a sequence of integer class ids aligned with ``images``.
    index: position of the first sample in the batch.
    batch_size: maximum number of samples; the final batch is simply shorter
      when fewer than ``batch_size`` samples remain.
    sample_shape: per-sample ``(C, H, W)`` shape fed to the network.

  Returns:
    ``(x, y)``: ``x`` is a float32 tensor of shape ``(n, *sample_shape)`` and
    ``y`` an int64 tensor of shape ``(n,)``.

  Raises:
    ValueError: if the slice is empty (``index`` at or past the end of data).

  NOTE(review): each sample is *reshaped*, not transposed, to (C, H, W). For
  HWC images this reinterprets the memory layout instead of moving the colour
  axis first, mixing channels and pixels. The sliding-window cell applies the
  same flat reshape to its windows, so training and inference stay consistent;
  switching to ``np.transpose(img, (2, 0, 1))`` must be done on both paths at
  once and the model retrained.
  """
  # Python slicing already clamps at the end of the sequence; the previous
  # manual clamp (`len - index - 1`) dropped the last sample of a short batch.
  images = np.asarray(data[0][index:index + batch_size])
  labels = data[1][index:index + batch_size]
  if len(images) == 0:
    raise ValueError('empty batch: index {} is past the end of {} samples'.format(index, len(data[0])))

  # Flattening the whole batch in C order and reshaping is equivalent to
  # reshaping each sample separately and concatenating them.
  x = torch.as_tensor(images, dtype=torch.float32).reshape(len(images), *sample_shape)
  y = torch.tensor(labels, dtype=torch.long)
  return (x, y)

In [0]:
class CNN_classifier(torch.nn.Module):
  """Small CNN mapping a 3-channel 23x23 patch to logits over 3 cell types.

  Architecture: Conv2d(3->32, k=3) -> ReLU -> MaxPool2d(2, stride 2) ->
  Linear(32*10*10 -> 3). For a 23x23 input the conv yields 21x21 and the pool
  10x10, which is where the 32*10*10 flattened feature size comes from.

  ``forward`` returns raw, unnormalised logits; the notebook applies
  LogSoftmax / Softmax to them separately.
  """

  def __init__(self):
    super().__init__()

    self.conv_model = torch.nn.Sequential(
        torch.nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3),
        torch.nn.ReLU(),
        torch.nn.MaxPool2d(kernel_size=2, stride=2)
    )

    # 32 channels x 10 x 10 positions: valid only for 23x23 inputs.
    self.linear = torch.nn.Linear(in_features=32*10*10, out_features=3)

  def forward(self, batch):
    """Map a ``(batch, 3, 23, 23)`` float tensor to ``(batch, 3)`` logits."""
    features = self.conv_model(batch)
    # flatten(start_dim=1) keeps the batch axis intact. An input with the wrong
    # spatial size now fails loudly in the Linear layer, where view(-1, 3200)
    # could silently regroup one sample's features into several rows.
    features = torch.flatten(features, start_dim=1)
    return self.linear(features)

In [0]:
# Seed torch's global RNG so the randomly initialised weights (and therefore the
# training results printed below) are reproducible from a fresh kernel.
torch.manual_seed(0)
cnn = CNN_classifier()

In [0]:
# Vanilla SGD (default: no momentum, no weight decay) over all CNN parameters.
optimizer = torch.optim.SGD(params=cnn.parameters(), lr=0.01)

In [0]:
# Training objective: NLLLoss applied to LogSoftmax output is cross-entropy on
# the raw logits, averaged over the batch.
m = torch.nn.LogSoftmax(dim=1)
loss_builder = torch.nn.NLLLoss(reduction='mean')
# Plain class probabilities, used to threshold prediction confidence.
s = torch.nn.Softmax(dim=1)

In [32]:
batch_size = 3
n_epochs = 10
train_loss_list = []  # per-epoch mean training loss
train_acc_list = []   # per-epoch training accuracy
n_train = len(train_data[0])
assert n_train > 0, 'no training images were loaded'

for epoch in range(n_epochs):
  # --- one pass of mini-batch SGD over the training set ---
  cnn.train()
  epoch_loss_sum = 0.0
  for i in range(0, n_train, batch_size):
    x, gold = CreateBatch(train_data, i, batch_size)
    y = cnn(x)
    loss = loss_builder(m(y), gold)

    cnn.zero_grad()
    loss.backward()
    # Clamp every gradient element to [-5, 5]. This is element-wise value
    # clipping, not norm clipping (that would be clip_grad_norm_).
    torch.nn.utils.clip_grad_value_(cnn.parameters(), 5.)
    optimizer.step()

    # `loss` is the mean over this batch; weight it by the batch length so a
    # shorter final batch does not skew the epoch average.
    epoch_loss_sum += loss.item() * len(gold)

  train_loss = epoch_loss_sum / n_train

  # --- accuracy on the training set, one sample at a time ---
  cnn.eval()
  train_acc = 0
  with torch.no_grad():  # inference only: no autograd graph needed
    for i in range(n_train):
      x, gold = CreateBatch(train_data, i, 1)
      probs = s(cnn(x)).numpy()
      # A softmax confidence below 0.9 becomes "unknown" (3). That never equals
      # a gold label (0, 1 or 2), so it counts as a miss.
      if np.max(probs) < 0.9: gold_predicted = 3
      else: gold_predicted = int(np.argmax(probs))

      if gold_predicted == gold.item(): train_acc += 1

  train_acc /= n_train

  train_loss_list.append(train_loss)
  train_acc_list.append(train_acc)

  print("Epoch: {:d}/{:d}".format(epoch+1,n_epochs))
  print ("Train Avg Loss:", train_loss, "\t\tTrain Accuracy:", train_acc)
  print()

Epoch: 1/10
Train Avg Loss: 48.87565994262695 		Train Accurancy: 0.3333333333333333

Epoch: 2/10
Train Avg Loss: 4265.38037109375 		Train Accurancy: 0.6666666666666666

Epoch: 3/10
Train Avg Loss: 389.4669494628906 		Train Accurancy: 1.0

Epoch: 4/10
Train Avg Loss: 0.0 		Train Accurancy: 1.0

Epoch: 5/10
Train Avg Loss: 0.0 		Train Accurancy: 1.0

Epoch: 6/10
Train Avg Loss: 0.0 		Train Accurancy: 1.0

Epoch: 7/10
Train Avg Loss: 0.0 		Train Accurancy: 1.0

Epoch: 8/10
Train Avg Loss: 0.0 		Train Accurancy: 1.0

Epoch: 9/10
Train Avg Loss: 0.0 		Train Accurancy: 1.0

Epoch: 10/10
Train Avg Loss: 0.0 		Train Accurancy: 1.0



In [0]:
def SlidingWindow(image, stepSize, windowSize):
  """Lazily yield ``(x, y, patch)`` for windows on a ``stepSize`` grid over ``image``.

  ``windowSize`` is ``(width, height)``. ``x``/``y`` are the column/row of the
  window's top-left corner. Windows near the right or bottom edge are clipped
  by numpy slicing and can be smaller than ``windowSize``; the caller is
  expected to filter those out.
  """
  for top in range(0, image.shape[0], stepSize):
    for left in range(0, image.shape[1], stepSize):
      patch = image[top:top + windowSize[1], left:left + windowSize[0]]
      yield left, top, patch

In [0]:
image_path = 'black_bubbles_1.png'
image = cv2.imread(image_path)
if image is None:
  # cv2.imread returns None instead of raising when a file is missing or
  # cannot be decoded; fail here rather than with an obscure error later.
  raise FileNotFoundError('could not read image: ' + image_path)
# Window (width, height); must match the 23x23 patches the CNN was trained on.
(winW, winH) = (23, 23)

In [0]:
cell_1_count = 0
cell_2_count = 0
cell_3_count = 0

# NOTE(review): with stepSize=1, neighbouring windows overlap almost entirely,
# so one physical cell is likely counted in many windows. These totals count
# confident *windows*, not distinct cells; a true cell count would need some
# form of non-maximum suppression.
cnn.eval()  # loop-invariant, so set once
with torch.no_grad():  # inference only: no autograd graph needed
  for (_x, _y, window) in SlidingWindow(image, stepSize=1, windowSize=(winW, winH)):
    # Windows touching the right/bottom border are clipped short; skip them.
    if window.shape[0] != winH or window.shape[1] != winW:
      continue

    # Same flat (1, 3, H, W) reshape the training images go through in
    # CreateBatch, so the network sees the layout it was trained on. Nothing
    # below mutates `window`, so no defensive copy is needed.
    window_reshaped = window.reshape((1, 3, winH, winW))
    patch, _gold = CreateBatch((window_reshaped, [0]), 0, 1)  # label is a dummy
    probs = s(cnn(patch)).numpy()

    # Same 0.9 confidence threshold used in the training-accuracy evaluation.
    if np.max(probs) >= 0.9:
      pred = int(np.argmax(probs))
      if pred == 0: cell_1_count += 1
      elif pred == 1: cell_2_count += 1
      elif pred == 2: cell_3_count += 1

In [0]:
# Report how many confident windows were assigned to each cell type.
for label, count in (('Cell Type 1 Count', cell_1_count),
                     ('Cell Type 2 Count', cell_2_count),
                     ('Cell Type 3 Count', cell_3_count)):
  print(label, count)