# Cell Detection

*Implementing the paper "Robust Cell Detection and Segmentation in Histopathological Images Using Sparse Reconstruction and Stacked Denoising Autoencoders"*



In [5]:
import data_utils
import random
import numpy as np
import torch
import matplotlib.pyplot as plt
import math
import time


In [6]:
# Parameters

learning_rate_init = 3.7*1e-1   # Learning rate for finding minimum
reg_init = 0                    # Regularisation strength
num_iter_init = 1500            # Number of iterations
K = 15                          # Value of K (Size of basis)
patch_size = 45                 # Size of patches with centered cells
batch_size = 10                 # Size of data being processed each iteration
D_in = patch_size*patch_size*3  # Dimension of each patch = number of pixels in each patch
num_workers = 4                 # Number of workers


In [7]:
# Load Data

dtype = torch.double
device = torch.device("cpu")

train_data = data_utils.malaria_dataset(json_file='../datasets/malaria/training.json',
                      root_dir='../datasets/malaria',
                      transform=data_utils.crop_bounding_box((patch_size,patch_size)))

training_data_loader = torch.utils.data.DataLoader(train_data,
                           batch_size=batch_size,
                           shuffle=True,
                           num_workers=num_workers,
                           collate_fn=data_utils.collate_fn)


In [None]:
# Basis Vector Initialisation

for i_batch, batch_sample in enumerate(training_data_loader):
    
    patches_batch = batch_sample['patches']
    num_patches = len(patches_batch)
    
    # Reshape sampled batch
    x = patches_batch.view(num_patches,-1)
    
    # Initialise w
    w = torch.randn(num_patches, num_patches, device=device, dtype=dtype, requires_grad=True)
    
    for epoch in range(num_iter_init):
      # Evaluate Loss
      # w = torch.nn.functional.normalize(w, p=2, dim=1)
      temp1 = x - w.mm(x) + w.diag()[:,None].expand_as(x)*x
      #print((temp1**2).mean())
      temp2 = temp1.view(num_patches*D_in).pow(2).mean()
      x_norm = (x**2).sum(1).view(-1, 1)
      y_norm = x_norm.view(1, -1)
      dist = x_norm + y_norm - 2.0 * torch.mm(x, torch.transpose(x, 0, 1))
      dist_norm = torch.nn.functional.normalize(dist, p=2, dim=1)
      dist_exp = torch.exp(dist_norm)
      loss = temp2 + reg_init*((dist_exp * w).view(num_patches*num_patches).pow(2).sum())
    
      # Print loss and iteration
      print(epoch, loss.item())
    
      # Calculate Gradient
      loss.backward()
    
      # Update weights
      with torch.no_grad():
        w -= learning_rate_init * w.grad
        w.grad.zero_()
        
    if(i_batch == 0):
      break
    
w = torch.nn.functional.normalize(w, p=2, dim=1)
basis_index = w.sum(0).sort()[1][K:] # The index of elements selected as the basis for the dictionary
basis = x[basis_index] # Basis Initialisation
coeff = (w.clone().t()[basis_index]).t() # The coefficients corresponding to the initialised basis


0 233.08105407246416
1 153.7304763214542
2 104.4186107468876
3 73.76477063487596
4 54.70008278183999
5 42.833865204295016
6 35.43893758972069
7 30.821368516137824
8 27.929007332790242
9 26.108332032130637
10 24.95342044125814
11 24.212141857450522
12 23.727889951537506
13 23.403391379954236
14 23.17822625894445
15 23.014865474465907
16 22.889998046595117
17 22.78914430107315
18 22.70330995172134
19 22.626907852397437
20 22.556467135725658
21 22.48983142093014
22 22.425660796909735
23 22.36312248910206
24 22.301698723436093
25 22.24106738487784
26 22.181027890864893
27 22.121455149057322
28 22.062270959095567
29 22.003426249356963
30 21.94489004365214
31 21.8866426080913
32 21.828671194380405
33 21.7709673958378
34 21.713525505120323
35 21.656341494143028
36 21.59941238046008
37 21.542735833691847
38 21.486309931047817
39 21.4301330054619
40 21.374203551248968
41 21.318520165491883
42 21.263081511620427
43 21.20788629677478
44 21.152933257731835
45 21.098221152149016
46 21.0437487531129

366 11.351451646021028
367 11.335781947323634
368 11.320164455283615
369 11.304598941367335
370 11.28908517816342
371 11.273622939376828
372 11.258211999822917
373 11.242852135421538
374 11.227543123191603
375 11.212284741245055
376 11.19707676878118
377 11.181918986081016
378 11.166811174501373
379 11.151753116469768
380 11.136744595478397
381 11.121785396078735
382 11.106875303875844
383 11.092014105523017
384 11.077201588716518
385 11.062437542189354
386 11.047721755707041
387 11.033054020060813
388 11.018434127063664
389 11.0038618695443
390 10.989337041341324
391 10.974859437299418
392 10.960428853262595
393 10.946045086069846
394 10.931707933550012
395 10.917417194515771
396 10.903172668759574
397 10.888974157048008
398 10.874821461116554
399 10.860714383665265
400 10.846652728352556
401 10.83263629979164
402 10.818664903544427
403 10.804738346117109
404 10.790856434955188
405 10.77701897843852
406 10.763225785876617
407 10.749476667503542
408 10.735771434473465
409 10.7221098988