In [0]:
import scipy as sp
import scipy.misc  # NOTE(review): neither `sp` nor `scipy.misc` is referenced elsewhere in this notebook as shown
import imageio
import matplotlib.pyplot as plt
import numpy as np
from google.colab import drive
# Mount Google Drive; every image, mask and the saved model used below live
# under '/content/drive/My Drive/confocal for segmentation'.
drive.mount('/content/drive')
%matplotlib inline
# Report the attached GPU; later cells move the model and batches with .cuda().
!nvidia-smi

In [0]:
# Preview one training image next to its ground-truth segmentation mask.
im = imageio.imread('/content/drive/My Drive/confocal for segmentation/colonies/train/day 4_box 2.bmp')
mask = imageio.imread('/content/drive/My Drive/confocal for segmentation/segmentation/train/day 4_box 2.bmp')

fig, (ax_im, ax_mask) = plt.subplots(1, 2, figsize=(10, 8))
ax_im.imshow(im)
ax_im.set_title('image')
ax_im.axis('off')
ax_mask.imshow(mask, cmap='gray')
ax_mask.set_title('ground-truth mask')
ax_mask.axis('off')
plt.show()


In [0]:
def calc_iou(prediction, ground_truth):
    """Pooled intersection-over-union of binary masks over a batch.

    Pixels with a value > 0 count as foreground. Intersections and unions are
    summed over all images before dividing, so the result is one IoU for the
    whole batch rather than a mean of per-image IoUs.

    Args:
        prediction: sequence (or array whose first axis indexes images) of
            predicted masks.
        ground_truth: matching sequence of ground-truth masks; element ``i``
            is compared with ``prediction[i]``.

    Returns:
        float: intersection / union, or NaN when the union is empty (no
        foreground anywhere, or no images at all). NaN lets ``np.nanmean``
        skip such batches instead of raising ZeroDivisionError.
    """
    intersection, union = 0, 0
    for i in range(len(prediction)):
        pred_fg = np.asarray(prediction[i]) > 0
        true_fg = np.asarray(ground_truth[i]) > 0
        # Exact integer pixel counts; float32 accumulation loses integer
        # precision once the running total exceeds 2**24.
        intersection += int(np.count_nonzero(pred_fg & true_fg))
        union += int(np.count_nonzero(pred_fg | true_fg))
    if union == 0:
        return float('nan')
    return intersection / union

In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.autograd import Variable

from scipy.special import expit

from PIL import Image
import random
from os.path import join
from os import listdir
from PIL import ImageFile
# Let PIL open image files that are truncated on disk instead of raising an error.
ImageFile.LOAD_TRUNCATED_IMAGES = True

from IPython.display import clear_output

### U-Net
The U-Net building blocks and model below are taken from https://github.com/milesial/Pytorch-UNet.

In [0]:
class double_conv(nn.Module):
    '''Two stacked stages of 3x3 conv (same padding) -> BatchNorm -> ReLU.

    Spatial size is preserved; channels go in_ch -> out_ch -> out_ch.
    '''
    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        # First stage reads in_ch channels, second reads the first stage's out_ch.
        for source_ch in (in_ch, out_ch):
            layers += [
                nn.Conv2d(source_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)


class inconv(nn.Module):
    """Input stage: a double_conv applied at full resolution (no pooling)."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x):
        return self.conv(x)


class down(nn.Module):
    """Encoder stage: 2x2 max-pool (halving H and W) followed by a double_conv."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = double_conv(in_ch, out_ch)
        self.mpconv = nn.Sequential(pool, convs)

    def forward(self, x):
        return self.mpconv(x)


class up(nn.Module):
    """Decoder stage: upsample x2, pad to the skip tensor's size, concatenate, double_conv.

    Args:
        in_ch: channels after concatenating the upsampled input with the skip tensor.
        out_ch: output channels.
        bilinear: if True, upsample with parameter-free bilinear interpolation;
            otherwise with a learned 2x2 stride-2 transposed convolution that
            maps in_ch // 2 channels to in_ch // 2 channels.
    """

    def __init__(self, in_ch, out_ch, bilinear=True):
        super().__init__()
        self.up = (
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            if bilinear
            else nn.ConvTranspose2d(in_ch // 2, in_ch // 2, 2, stride=2)
        )
        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x1, x2):
        upsampled = self.up(x1)
        # Tensors are NCHW. Pad the upsampled map so its H/W match the skip
        # tensor x2; F.pad takes (left, right, top, bottom).
        diff_y = x2.size(2) - upsampled.size(2)
        diff_x = x2.size(3) - upsampled.size(3)
        top, left = diff_y // 2, diff_x // 2
        upsampled = F.pad(upsampled, (left, diff_x - left, top, diff_y - top))
        return self.conv(torch.cat([x2, upsampled], dim=1))


class outconv(nn.Module):
    """Output head: a 1x1 convolution mapping feature channels to per-pixel outputs."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

In [0]:
class UNet(nn.Module):
    """U-Net with a 4-level encoder/decoder, skip connections and a 1x1 output head.

    Args:
        n_channels: number of input image channels.
        n_classes: number of output channels; the head is a plain 1x1 conv, so
            outputs are raw logits (no activation applied).
    """

    def __init__(self, n_channels, n_classes):
        super().__init__()
        # Encoder. Registration order is kept so parameter initialisation and
        # state_dict keys stay identical to previously saved checkpoints.
        self.inc = inconv(n_channels, 64)
        self.down1 = down(64, 128)
        self.down2 = down(128, 256)
        self.down3 = down(256, 512)
        self.down4 = down(512, 512)
        # Decoder: each in_ch = upsampled channels + skip channels.
        self.up1 = up(1024, 256)
        self.up2 = up(512, 128)
        self.up3 = up(256, 64)
        self.up4 = up(128, 64)
        self.outc = outconv(64, n_classes)

    def forward(self, x):
        # Collect encoder outputs; the deepest seeds the decoder and the rest
        # are consumed as skip connections from deepest to shallowest.
        features = [self.inc(x)]
        for encoder_stage in (self.down1, self.down2, self.down3, self.down4):
            features.append(encoder_stage(features[-1]))
        out = features.pop()
        for decoder_stage in (self.up1, self.up2, self.up3, self.up4):
            out = decoder_stage(out, features.pop())
        return self.outc(out)

In [0]:
# Seed every RNG this notebook draws from (weight initialisation, DataLoader
# shuffling, the per-item augmentation seeds drawn with np.random) so that
# Restart & Run All is reproducible. cuDNN kernels may still be nondeterministic.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

net = UNet(n_channels=3, n_classes=1).cuda()

## Loading the Data

In [0]:
def load_image_dataset(which='train',
                       root='/content/drive/My Drive/confocal for segmentation',
                       size=(512, 512)):
    """Load colony images (and their segmentation masks) for one data split.

    Images are read from ``<root>/colonies/<which>``; the mask with the same
    file name is read from ``<root>/segmentation/<which>``.

    Args:
        which: split sub-directory ('train', 'val' or 'test'). For 'test' no
            masks are read and ``None`` placeholders are returned instead.
        root: dataset root directory.
        size: (width, height) every image and mask is resized to.

    Returns:
        (X, Y): lists aligned by index; X holds RGB PIL images, Y holds
        grayscale ('L') PIL masks, or ``None`` for the test split.
    """
    image_dir = join(root, 'colonies', which)
    mask_dir = join(root, 'segmentation', which)
    X, Y = [], []
    # os.listdir order is arbitrary; sort so sample order (e.g. X_test[0]) is reproducible.
    for file in sorted(listdir(image_dir)):
        with Image.open(join(image_dir, file)) as img:
            # LANCZOS is the filter formerly aliased as ANTIALIAS (removed in Pillow 10).
            X.append(img.convert('RGB').resize(size, Image.LANCZOS))
        if which != 'test':
            with Image.open(join(mask_dir, file)) as mask:
                # Nearest-neighbour keeps masks binary; a smoothing filter creates
                # intermediate values that the later `> 0` threshold turns into
                # foreground, growing the mask boundary.
                Y.append(mask.convert('L').resize(size, Image.NEAREST))
        else:
            Y.append(None)
    return X, Y

In [0]:
X_train, Y_train = load_image_dataset('train')
X_val, Y_val = load_image_dataset('val')
# The test split has no masks (its Y list holds only None), so keep just the
# images; indexing avoids binding `_`, which shadows IPython's last-output variable.
X_test = load_image_dataset('test')[0]

# Fail fast on an empty split: the training/validation loops would otherwise
# iterate over nothing and report NaN metrics.
assert X_train, "no training images found"
assert X_val, "no validation images found"
assert X_test, "no test images found"

In [0]:
class AugmentedDataset(Dataset):
    """Map-style dataset applying paired random augmentations to inputs and masks.

    A fresh seed is drawn per item and used for both ``transform_X`` and
    ``transform_Y``, so random geometric transforms (e.g. flips) make the same
    decisions for an image and its mask.

    Args:
        X: indexable collection of inputs (e.g. PIL images).
        Y: indexable collection of masks aligned with ``X``, or ``None``/empty
            for an unlabeled set (``__getitem__`` then returns only the input).
        transform_X: optional callable applied to each input.
        transform_Y: optional callable applied to each mask.
    """

    def __init__(self, X, Y, transform_X=None, transform_Y=None):
        self.X, self.Y = X, Y
        self.transform_X = transform_X
        self.transform_Y = transform_Y

    @staticmethod
    def _seeded(transform, value, seed):
        """Apply ``transform`` with Python's and torch's CPU RNGs seeded to ``seed``.

        torchvision's random transforms draw from torch's RNG (older releases
        used Python's ``random``), so seeding only ``random`` left image and
        mask flips independent. Both are seeded here; the torch CPU RNG state
        is restored afterwards so the global torch stream is not disturbed.
        """
        random.seed(seed)
        with torch.random.fork_rng(devices=[]):
            torch.default_generator.manual_seed(seed)
            return transform(value)

    def __getitem__(self, index):
        x = self.X[index]
        seed = int(np.random.randint(0xBADBEEF))
        if self.transform_X:
            x = self._seeded(self.transform_X, x, seed)
        # Explicit None/len check instead of truthiness so array-like Y works.
        if self.Y is None or len(self.Y) == 0:
            return x

        y = self.Y[index]
        if self.transform_Y:
            y = self._seeded(self.transform_Y, y, seed)
        # Binarise the mask and add a leading channel axis -> (1, H, W) float32.
        y = (np.array(y)[None, :, :] > 0).astype(np.float32)
        return x, y

    def __len__(self):
        return len(self.X)

In [0]:
# Per-channel normalisation constants.
# NOTE(review): these match the widely used CIFAR-10 channel statistics rather
# than statistics of this dataset — consider recomputing them on X_train.
means = np.array((0.4914, 0.4822, 0.4465))
stds = np.array((0.2023, 0.1994, 0.2010))

# Training inputs: random flips (intended to be mirrored on the masks —
# AugmentedDataset reseeds with the same per-item seed before each transform),
# slight colour jitter, tensor conversion, small Gaussian pixel noise, normalisation.
transform_X = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ColorJitter(brightness=0.01, contrast=0.01, saturation=0.01, hue=0.01),
    transforms.ToTensor(),
    transforms.Lambda(lambda t: t + torch.randn_like(t) * 0.02),
    transforms.Normalize(means, stds),
])

# Training masks: only the geometric transforms, in the same order as above.
transform_Y = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
])

train_dataset = AugmentedDataset(X_train, Y_train, transform_X, transform_Y)
train_dataloader = DataLoader(train_dataset, batch_size=7, shuffle=True)

# Evaluation preprocessing: deterministic, no augmentation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(means, stds),
])

val_dataset = AugmentedDataset(X_val, Y_val, transform_test, None)
val_dataloader = DataLoader(val_dataset, batch_size=3, shuffle=False)

## Network training

In [0]:
EPOCHS = 89

# Adam with learning rate 5e-4 and a small L2 weight decay.
opt = torch.optim.Adam(net.parameters(), lr=5e-4, weight_decay=1e-6)

# Per-epoch histories filled by the training loop.
train_loss, val_accuracy = [], []

In [0]:
for epoch in range(EPOCHS):
    # Training pass: one optimiser step per batch; record the mean batch loss.
    net.train()
    batch_loss = []
    for X_batch, Y_batch in train_dataloader:
        logits = net(X_batch.cuda())
        loss = F.binary_cross_entropy_with_logits(logits, Y_batch.cuda())
        opt.zero_grad()
        loss.backward()
        opt.step()
        batch_loss.append(loss.item())
    train_loss.append(np.nanmean(batch_loss))

    # Validation pass: pooled IoU of sigmoid(logits) > 0.16 per batch.
    # NOTE(review): test-time prediction thresholds at 0.33 — confirm which cut-off is intended.
    net.eval()
    val_iou = []
    # no_grad replaces `Variable(..., volatile=True)`, which has had no effect
    # since PyTorch 0.4 and so built an autograd graph for every batch.
    with torch.no_grad():
        for X_batch, Y_batch in val_dataloader:
            Y_pred = net(X_batch.cuda()).cpu().numpy()
            val_iou.append(calc_iou(expit(Y_pred) > 0.16, Y_batch.numpy()))
    val_accuracy.append(np.nanmean(val_iou))

    clear_output(wait=True)
    print(f"Epoch {epoch}")
    print(f"Training loss: {train_loss[-1]}")
    print(f"Validation IOU: {val_accuracy[-1]}")

In [0]:
# Learning curves: mean training loss and mean validation IoU per epoch.
fig, (ax_loss, ax_iou) = plt.subplots(1, 2, figsize=(12, 4))
ax_loss.plot(train_loss)
ax_loss.set(title='Training loss', xlabel='epoch', ylabel='BCE-with-logits loss')
ax_iou.plot(val_accuracy)
ax_iou.set(title='Validation IoU', xlabel='epoch', ylabel='IoU')
plt.show()

## Test Evaluation

In [0]:
# Unlabeled test set: deterministic preprocessing, one image per batch, dataset order.
test_dataset = AugmentedDataset(X_test, None, transform_test)
test_dataloader = DataLoader(test_dataset, batch_size=1)

In [0]:
# Predict binary masks for the test images: sigmoid(logits) > 0.33.
# NOTE(review): the validation loop thresholds at 0.16 — confirm which cut-off is intended.
net.eval()  # BatchNorm must use running statistics regardless of earlier cell state
Y_pred = []
with torch.no_grad():  # inference only; avoid building autograd graphs
    for X_batch in test_dataloader:
        logits = net(X_batch.cuda()).cpu().numpy()
        Y_pred.append(expit(logits) > 0.33)
Y_pred = np.concatenate(Y_pred)

### Let's look at one result

In [0]:
from IPython.display import display

# Show the first test image next to its predicted binary mask (foreground -> 255).
# `Image` is already imported from PIL in the imports cell; mode 'L' is inferred
# for a 2-D uint8 array, so the deprecated `mode` argument is not passed.
sample_image = X_test[0]
sample_mask = Image.fromarray(np.uint8(Y_pred[0, 0, :, :] * 255))
display(sample_image, sample_mask)

### Saving the trained model weights

In [0]:
# Persist only the learned parameters (state_dict), not the pickled module.
# NOTE(review): despite the "tensor" file name, this holds the full UNet state_dict.
model_weights = net.state_dict()
torch.save(model_weights, "/content/drive/My Drive/confocal for segmentation/tensor.pt")