# Segmentation Model for scI Datasets

This model segments the separation lanes that the classification model identified as containing protein bands.

We gratefully acknowledge Prof. Jonathan Shewchuk and the teaching assistants of the Spring 2019 offering of UC Berkeley’s CS289A course for helpful discussions and initial code for some of this work.  



# Notebook Initialization



In [14]:
## Mounting Google Drive (Colab only)
## Outside Colab the google.colab package cannot be imported; in that case
## the mount is skipped instead of stopping the notebook, so this cell no
## longer has to be commented out by hand.

try:
    from google.colab import drive
except ImportError:
    drive = None
    print("google.colab is not available; skipping the Google Drive mount")
else:
    drive.mount('/content/gdrive')

## Loading libraries

# Classic libraries
import os
import numpy as np
import pandas as pd
import time
import matplotlib.pyplot as plt

# Image (scikit-image, PIL) and scikit-learn libraries
from skimage import io, transform
from PIL import Image
from sklearn import metrics

# Pytorch libraries
import torch
import torch.nn as nn
from torch.utils import data
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import torchvision.models as models
from torchsummary import summary



Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [19]:
## Dataset Files

# Working directory for this experiment (edit for your own setup)
os.chdir('/content/gdrive/My Drive/Herrlab/projects/segmentationproject/experiments')

# Folder that holds the data (edit for your own setup). The trailing slash is
# required: file names are appended to it by plain string concatenation.
dataset_path = '/content/gdrive/My Drive/Herrlab/projects/segmentationproject/datasets/mtl/'

# CSV files describing the training, validation and test splits
train_file, val_file, test_file = (
    dataset_path + csv_name
    for csv_name in ('unet_train.csv', 'unet_validate.csv', 'unet_test.csv')
)

# Where the trained model weights are written
model_save_file = '/content/gdrive/My Drive/Herrlab/projects/segmentationproject/experiments/exp5-18/unet_github.ckpt'


## Dataset Loading Functions

In [20]:
def imgLoad(imagePath):
    '''
    Loads an image with the PIL Image library and returns its pixel data
    as a NumPy array of dtype int64.

    Parameters
    ----------
    imagePath : path of the image file, passed to PIL's Image.open.

    Returns
    -------
    numpy.ndarray holding the pixel values cast to int64. The array keeps
    whatever shape NumPy derives from the PIL image; no extra channel axis
    is added here.
    '''

    # The context manager closes the underlying file as soon as the pixels
    # have been copied into the array, rather than leaving that to the
    # garbage collector.
    with Image.open(imagePath) as image_file:
        image = np.array(image_file).astype('int64')
    return image

def div_n(x,n):
    '''
    Per-side padding that moves x up towards the next multiple of n.

    Both arguments are converted to float first. If x is already a
    multiple of n the result is 0. Otherwise the gap between x and the
    next multiple of n is halved and truncated to an int, so an odd gap
    drops its extra half unit.
    '''
    length = float(x)
    multiple = float(n)

    if length % multiple != 0:
        # Distance from x up to the next multiple of n
        gap = np.ceil(length / multiple) * multiple - length
        return int(gap / 2.0)

    return 0

class roiData(data.Dataset):
    ''' Object representing a dataset. Code adapted from
    CS289 HW6 assignment, Spring 2019 offering @ UC Berkeley.

    This object reads a CSV label file and keeps its 'roiPath' and
    'segmentPath' columns. When indexed as roidata[i], both files named in
    the ith row are loaded with the supplied loader and returned as
    (sample, label).

    Parameters
    ----------
    label_file : CSV file with a header row that contains the columns
        'roiPath' and 'segmentPath'.
    img_load_function : callable that takes a path and returns the image.
    transform : optional callable applied to the sample and, in a separate
        call, to the label.
    dataset_path : optional string prepended to every path in the CSV.
    normalize : optional pair (mean, std); the sample is replaced by
        (sample - mean) / std. The label is never normalized.
    ndivisible : when non-zero, the first sample is loaded once and div_n
        is used to derive a per-side padding for each of its two
        dimensions from this value. The same padding is then applied to
        every item. 0 (the default) or None disables padding.

    NOTE: PyTorch kept having issues with normalization, so I just decided
    to do it myself.
    '''
    def __init__(self, label_file, img_load_function, transform=None, dataset_path=None, normalize=None,ndivisible=0):
        'Initialization'

        self.label_file = label_file
        self.loader = img_load_function
        # Copy the two-column selection so the path columns can be rewritten
        # below on a frame this object owns.
        self.data = pd.read_csv(self.label_file,header=0)[['roiPath','segmentPath']].copy()
        self.transform = transform
        self.normalize = normalize

        # Option to add a path to the image files, if located in some other
        # folder
        if dataset_path is not None:
            self.data['roiPath'] = dataset_path+self.data['roiPath']
            self.data['segmentPath'] = dataset_path+self.data['segmentPath']

        # Getting padding info. A truthiness test replaces the former
        # `is not 0` identity comparison, and the requested divisor is now
        # actually forwarded to div_n instead of a hard-coded 16.
        if ndivisible:
            path,label = self.data.iloc[0]
            sample = self.loader(path)
            [x,y] = sample.shape

            x_padding = div_n(x,ndivisible)
            y_padding = div_n(y,ndivisible)

            self.padding = [x_padding,y_padding]
        else:
            self.padding = [0,0]

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.data)

    def __getitem__(self,idx):
        'Generates one sample of data: returns (sample, label) for row idx'
        path,label = self.data.iloc[idx]
        label = self.loader(label)
        sample = self.loader(path)

        if self.normalize is not None:
            sample = (sample - self.normalize[0])/self.normalize[1]

        # The transform is called separately on the sample and on the label,
        # so a random transform would not be guaranteed to make the same
        # choice for both. Only deterministic transforms are safe here.
        if self.transform is not None:
            sample = self.transform(sample)
            label = self.transform(label)

        # F.pad needs tensors; as_tensor leaves existing tensors untouched
        # and wraps arrays, so the transform=None case no longer fails here.
        sample = torch.as_tensor(sample)
        label = torch.as_tensor(label)

        # F.pad takes the last dimension first: (left, right, top, bottom)
        pad = (self.padding[1],self.padding[1],self.padding[0],self.padding[0])
        sample = F.pad(sample,pad=pad)
        label = F.pad(label,pad=pad)

        return sample,label



# Normalization

In [21]:
# Batch size used while estimating the normalization statistics
batchSize = 10

# Un-normalized, un-padded view of the training set; it is only used to
# estimate the pixel mean and standard deviation.
tensor_transform_normalization = transforms.Compose([transforms.ToTensor()])
train_dataset_normalization = roiData(
    train_file,
    imgLoad,
    transform=tensor_transform_normalization,
    dataset_path=dataset_path,
)
train_loader_normalization = data.DataLoader(
    train_dataset_normalization,
    batch_size=batchSize,
    shuffle=True,
    num_workers=2,
)

# Per-batch mean, per-batch mean of squares, and per-batch sample count
mean = []
meansq = []
sample_length = []

print("Training samples: {}".format(len(train_dataset_normalization.data)))

for sample, label in train_loader_normalization:
    mean.append(np.array(sample).mean())
    meansq.append(np.array(sample**2).mean())
    sample_length.append(len(label))

# Fraction of the dataset contributed by each batch, used as its weight
batch_p = np.asarray(sample_length) / len(train_dataset_normalization)

# Weighted E[x] and E[x^2]; the variance follows as E[x^2] - E[x]^2
sample_mean = np.sum(batch_p * np.asarray(mean))
sample_var = np.sum(batch_p * np.asarray(meansq)) - sample_mean ** 2
sample_std = np.sqrt(sample_var)

print("Sample mean: {}".format(sample_mean))
print("Sample stdev: {}".format(sample_std))


Training samples: 153
Sample mean: 13502.82817769608
Sample stdev: 8676.41096301769


## Loading Data

In [26]:
# Augmenting transform: tensor conversion followed by a random horizontal flip.
# It is defined here but not passed to any of the datasets below, which all
# use tensor_transform.
train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomHorizontalFlip(),
])

# Plain conversion into a tensor
tensor_transform = transforms.Compose([transforms.ToTensor()])

# Number of DataLoader worker processes
workers = 2

# Batch size
batchSize = 16

# Options shared by the three splits: same transform, same normalization
# statistics (estimated on the training set) and padding to a multiple of 16.
dataset_options = dict(
    transform=tensor_transform,
    dataset_path=dataset_path,
    normalize=[sample_mean, sample_std],
    ndivisible=16,
)

train_dataset = roiData(train_file, imgLoad, **dataset_options)
test_dataset = roiData(test_file, imgLoad, **dataset_options)
val_dataset = roiData(val_file, imgLoad, **dataset_options)

# Only the training loader shuffles
loader_options = dict(batch_size=batchSize, num_workers=workers)

train_loader = data.DataLoader(train_dataset, shuffle=True, **loader_options)
test_loader = data.DataLoader(test_dataset, shuffle=False, **loader_options)
val_loader = data.DataLoader(val_dataset, shuffle=False, **loader_options)


# Model Tuning

In [None]:
# This is just a check to make sure it properly loaded.
# It reports the length of the dataset built in the data-loading cell instead
# of constructing a second, un-normalized and un-padded dataset under the same
# name, so `train_dataset` stays the object that `train_loader` wraps.
print("Training dataset length: "+str(len(train_dataset)))

Training dataset length: 153


In [23]:
## Model hyperparameters

# Number of passes over the training set
num_epochs = 100

# Learning rate handed to the optimizer
learning_rate = 1e-5

# Loss function
criterion = nn.CrossEntropyLoss()

# Model

class NeuralNet(nn.Module):
    '''
    U-Net-style encoder/decoder for per-pixel classification.

    The encoder applies five convolution blocks (1 -> 64 -> 128 -> 256 ->
    512 -> 1024 channels) separated by four 2x2 max-pool steps. The decoder
    upsamples four times with stride-2 transposed convolutions; after each
    upsampling the matching encoder output is cropped if needed and
    concatenated on the channel axis before the next convolution block.
    A final 1x1 convolution maps the 64 remaining channels to 2 outputs.
    '''

    def convblock(self,input_layers,output_layers):
        '''
        Builds two 3x3 convolutions (stride 1, padding 1), each followed by
        a ReLU. The first maps input_layers -> output_layers channels, the
        second keeps output_layers channels.
        '''
        block = nn.Sequential(
            nn.Conv2d(input_layers,output_layers,kernel_size=3, stride = 1, padding=1),
            nn.ReLU(),
            nn.Conv2d(output_layers,output_layers,kernel_size=3, stride = 1, padding=1),
            nn.ReLU()
        )
        return block

    def copy_and_crop(self,x,x_skip):
        '''
        Crops x_skip to the size of x in dimensions 2 and 3 and concatenates
        the two along the channel axis (x_skip channels first).

        When the size difference in a dimension is odd, the extra element
        is removed from the leading edge. Both dimensions now follow this
        same rule; previously the last dimension removed too little for an
        odd difference, which made the concatenation fail.
        '''
        x_pad = x_skip.shape[2] - x.shape[2]
        y_pad = x_skip.shape[3] - x.shape[3]

        # Split each difference into trailing = floor(d/2) and
        # leading = d - trailing, so the two always add up to d.
        x_pad_right = x_pad // 2
        x_pad_left = x_pad - x_pad_right

        y_pad_right = y_pad // 2
        y_pad_left = y_pad - y_pad_right

        # Negative padding crops. F.pad lists the last dimension first.
        x_skip_cropped = F.pad(x_skip,pad=(-y_pad_left,-y_pad_right,-x_pad_left,-x_pad_right,0,0,0,0))

        final_tensor = torch.cat((x_skip_cropped,x),dim=1)

        return final_tensor

    def __init__(self):
        super(NeuralNet, self).__init__()

        self.pool = nn.MaxPool2d(2,2)

        # Encoder blocks
        self.d1 = self.convblock(1,64)      # Layer 1
        self.d2 = self.convblock(64,128)    # Layer 2
        self.d3 = self.convblock(128,256)   # Layer 3
        self.d4 = self.convblock(256,512)   # Layer 4
        self.d5 = self.convblock(512,1024)  # Layer 5

        # Decoder blocks; each input is the concatenation of a skip tensor
        # and an upsampled tensor, hence twice the output channel count.
        self.u6 = self.convblock(1024,512)  # Layer 6
        self.u7 = self.convblock(512,256)   # Layer 7
        self.u8 = self.convblock(256,128)   # Layer 8
        self.u9 = self.convblock(128,64)    # Layer 9

        # Stride-2 transposed convolutions that double the spatial size
        self.upconv5 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2, padding=0)
        self.upconv6 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2, padding=0)
        self.upconv7 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2, padding=0)
        self.upconv8 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2, padding=0)

        # 1x1 convolution producing the 2 output channels
        self.final_cov = nn.Conv2d(64,2,kernel_size=1,stride=1,padding=0)

    def forward(self, x):
        '''
        Runs the network on a batch with one input channel and returns the
        2-channel output of the final 1x1 convolution.
        '''

        # Left Network
        x1 = self.d1(x)
        x1_pool = self.pool(x1)

        x2 = self.d2(x1_pool)
        x2_pool = self.pool(x2)

        x3 = self.d3(x2_pool)
        x3_pool = self.pool(x3)

        x4 = self.d4(x3_pool)
        x4_pool = self.pool(x4)

        x5 = self.d5(x4_pool)
        x5_upconv = self.upconv5(x5)

        # Right network
        x6_tensor = self.copy_and_crop(x5_upconv,x4)
        x6 = self.u6(x6_tensor)
        x6_upconv = self.upconv6(x6)

        x7_tensor = self.copy_and_crop(x6_upconv,x3)
        x7 = self.u7(x7_tensor)
        x7_upconv = self.upconv7(x7)

        x8_tensor = self.copy_and_crop(x7_upconv,x2)
        x8 = self.u8(x8_tensor)
        x8_upconv = self.upconv8(x8)

        x9_tensor = self.copy_and_crop(x8_upconv,x1)
        x9 = self.u9(x9_tensor)
        x_final = self.final_cov(x9)

        return x_final

# Use the first CUDA device when one is available, otherwise the CPU
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0") if use_cuda else torch.device("cpu")

# Build the network and move its parameters to the chosen device
model = NeuralNet()
model = model.to(device)

# Adam optimizer over all model parameters
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

## Running the Model

In [27]:
##################################
#                                #
#            TRAINING            #
#                                #
##################################

# Trains `model` for `num_epochs` passes over `train_loader`, recording the
# loss and pixel accuracy of every step, then writes the weights to
# `model_save_file`.

# Per-step history (one entry per batch, over all epochs)
training_loss = []
training_accuracy = []

# Filled during the final epoch only
training_predictions = []
training_labels = []
training_micrographs = []

print('Beginning training..')
total_step = len(train_loader) # To calculate total number of steps.
start = time.time()

# Looping over epochs
for epoch in range(num_epochs):

    # Batch training
    model.train()
    print('epoch {}'.format(epoch+1))

    for i, (local_batch,local_labels) in enumerate(train_loader):
        local_batch = local_batch.float()

        # Transfer to GPU
        local_ims, local_labels = local_batch.to(device), local_labels.to(device)

        # Forward pass. Calling the module (rather than .forward directly)
        # goes through nn.Module.__call__, so registered hooks run.
        outputs = model(local_ims)

        # Reshaping for loss. The spatial size is read from every batch
        # instead of being cached from the first epoch.
        x = local_labels.shape[-1]
        y = local_labels.shape[-2]
        local_batch_size = len(local_labels)

        # Loss: logits as (batch, classes, pixels), targets as (batch, pixels)
        local_labels_for_loss = torch.reshape(local_labels,(local_batch_size,x*y))
        outputs_for_loss = torch.reshape(outputs,(local_batch_size,outputs.shape[1],x*y))

        loss = criterion(outputs_for_loss, local_labels_for_loss)
        training_loss.append(loss.item())

        # Predicted class per pixel; detach() replaces the legacy .data
        _, predicted = torch.max(outputs.detach(), 1)

        # Accuracy score over every pixel of the batch
        predicted_reshape = torch.reshape(predicted,(local_batch_size*x*y,1)).cpu().numpy()
        local_labels_for_score = torch.reshape(local_labels_for_loss,(local_batch_size*x*y,1)).cpu().numpy()

        # accuracy_score expects (y_true, y_pred)
        score = metrics.accuracy_score(local_labels_for_score,predicted_reshape)
        training_accuracy.append(score)

        # If last epoch, save the predictions
        if epoch == num_epochs-1:
            training_predictions.append(predicted.cpu().numpy())
            training_labels.append(local_labels.cpu().numpy())
            training_micrographs.append(local_ims.cpu().numpy())

        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Printing results
        if (i+1) % 5 == 0:
            print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                  .format(epoch+1, num_epochs, i+1, total_step, loss.item()))
            print('Training accuracy {:.4f} %'.format(100*score))
            print('Time: '+str(time.time()-start))


torch.save(model.state_dict(), model_save_file)

Beginning training..
epoch 1
Epoch [1/100], Step [5/10], Loss: 0.3712
Training accuracy 89.5466 %
Time: 1.4224250316619873
Epoch [1/100], Step [10/10], Loss: 0.3596
Training accuracy 89.8571 %
Time: 2.5181751251220703
epoch 2
Epoch [2/100], Step [5/10], Loss: 0.3399
Training accuracy 89.0155 %
Time: 3.907350778579712
Epoch [2/100], Step [10/10], Loss: 0.3319
Training accuracy 89.0905 %
Time: 5.008713245391846
epoch 3
Epoch [3/100], Step [5/10], Loss: 0.3030
Training accuracy 89.2097 %
Time: 6.4027228355407715
Epoch [3/100], Step [10/10], Loss: 0.2414
Training accuracy 90.9014 %
Time: 7.505269765853882
epoch 4
Epoch [4/100], Step [5/10], Loss: 0.2525
Training accuracy 89.5940 %
Time: 8.907092332839966
Epoch [4/100], Step [10/10], Loss: 0.2532
Training accuracy 87.7457 %
Time: 10.023545742034912
epoch 5
Epoch [5/100], Step [5/10], Loss: 0.2188
Training accuracy 89.1241 %
Time: 11.471235752105713
Epoch [5/100], Step [10/10], Loss: 0.2058
Training accuracy 89.7983 %
Time: 12.58778929710388

In [28]:
##################################
#                                #
#          VALIDATION            #
#                                #
##################################

# Evaluates `model` on `val_loader` and prints the pixel accuracy averaged
# over batches, each weighted by its share of the validation set.

model.eval()

# Initializing variables
validation_loss = []
validation_predictions = []
validation_labels = []
validation_accuracy = []
validation_micrographs = []

print("Starting validation")

# No gradients are needed for evaluation; no_grad stops autograd from
# recording a graph for every batch.
with torch.no_grad():
    for i, (local_batch,local_labels) in enumerate(val_loader):

        # Loading data
        local_batch = local_batch.float()
        local_ims, local_labels = local_batch.to(device), local_labels.to(device)

        # Evaluation (module call, so registered hooks run)
        outputs = model(local_ims)

        # Reshaping for loss
        x = local_labels.shape[-1]
        y = local_labels.shape[-2]
        local_batch_size = len(local_labels)

        # Validation loss
        local_labels_for_loss = torch.reshape(local_labels,(local_batch_size,x*y))
        outputs_for_loss = torch.reshape(outputs,(local_batch_size,outputs.shape[1],x*y))

        loss = criterion(outputs_for_loss, local_labels_for_loss)
        validation_loss.append(loss.item())

        # Predictions via max
        _, predicted = torch.max(outputs, 1)
        predicted_reshape = torch.reshape(predicted,(local_batch_size*x*y,1)).cpu().numpy()
        local_labels_for_score = torch.reshape(local_labels_for_loss,(local_batch_size*x*y,1)).cpu().numpy()

        # accuracy_score expects (y_true, y_pred)
        score = metrics.accuracy_score(local_labels_for_score,predicted_reshape)
        validation_accuracy.append(score)

        validation_predictions.append(predicted.cpu().numpy())
        validation_labels.append(local_labels.cpu().numpy())
        validation_micrographs.append(local_ims.cpu().numpy())

# Calculating Accuracy Score: weight each batch by its fraction of the dataset
val_weights = [len(batch_predictions)/len(val_dataset) for batch_predictions in validation_predictions]

print('Validation accuracy {:.4f} %'.format(100 * np.average(validation_accuracy,weights=val_weights)))

Starting validation
Validation accuracy 97.8208 %


In [30]:
##################################
#                                #
#          TESTING               #
#                                #
##################################

# Evaluates `model` on `test_loader` and prints the pixel accuracy averaged
# over batches, each weighted by its share of the test set.

print("Test dataset length: {}".format(len(test_dataset)))

model.eval()

# Initializing variables
test_loss = []
test_predictions = []
test_labels = []
test_accuracy = []
test_micrographs = []

print("Starting testing")

# No gradients are needed for evaluation; no_grad stops autograd from
# recording a graph for every batch.
with torch.no_grad():
    for i, (local_batch,local_labels) in enumerate(test_loader):

        # Loading data
        local_batch = local_batch.float()
        local_ims, local_labels = local_batch.to(device), local_labels.to(device)

        # Evaluation (module call, so registered hooks run)
        outputs = model(local_ims)

        # Reshaping for loss
        x = local_labels.shape[-1]
        y = local_labels.shape[-2]
        local_batch_size = len(local_labels)

        # Test loss
        local_labels_for_loss = torch.reshape(local_labels,(local_batch_size,x*y))
        outputs_for_loss = torch.reshape(outputs,(local_batch_size,outputs.shape[1],x*y))

        loss = criterion(outputs_for_loss, local_labels_for_loss)
        test_loss.append(loss.item())

        # Predictions via max
        _, predicted = torch.max(outputs, 1)
        predicted_reshape = torch.reshape(predicted,(local_batch_size*x*y,1)).cpu().numpy()
        local_labels_for_score = torch.reshape(local_labels_for_loss,(local_batch_size*x*y,1)).cpu().numpy()

        # accuracy_score expects (y_true, y_pred)
        score = metrics.accuracy_score(local_labels_for_score,predicted_reshape)
        test_accuracy.append(score)

        test_predictions.append(predicted.cpu().numpy())
        # The test loader does not shuffle, so the labels could be read
        # from the dataset in order; they are stored for consistency.
        test_labels.append(local_labels.cpu().numpy())
        test_micrographs.append(local_ims.cpu().numpy())

# Calculating Accuracy Score: weight each batch by its fraction of the dataset
test_weights = [len(batch_predictions)/len(test_dataset) for batch_predictions in test_predictions]

print('Test accuracy {:.4f} %'.format(100 * np.average(test_accuracy,weights=test_weights)))

Test dataset length: 27
Starting testing
Test accuracy 94.2025 %


In [31]:
# Saving the test predictions output

# Flatten the per-batch prediction arrays into one list with a single
# 2-D entry per test image.
test_predictions_array = []

for batch_predictions in test_predictions:
    # Keep rows 4..203 and all columns of every image in the batch.
    # NOTE(review): this presumably strips the 4-row padding that roiData adds
    # to 200-row images -- TODO confirm against the dataset dimensions.
    test_predictions_array.extend(batch_predictions[:, 4:204, :])

np.save("/content/gdrive/My Drive/Herrlab/projects/segmentationproject/datasets/mtl/test_predictions_array",test_predictions_array)