# Semantic and instance segmentation of images
This notebook shows a combined approach to _semantic_ and _instance_ segmentation of images. For the example used here, microscope slides containing lice, this means segmenting the image so that each segment represents a unique object (an _instance_) and that object belongs to a particular (_semantic_) class. For these slides, the classes are: background, specimens, labels, barcodes, and type labels.

In [None]:
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision import transforms

from segmentation.datasets import Slides
from segmentation.instances import DiscriminativeLoss, mean_shift, visualise_embeddings, visualise_instances
from segmentation.network import SemanticInstanceSegmentation
from segmentation.training import train

# Define model
The model is a neural network with two heads: one for the semantic class embeddings and one for the instance embeddings. A discriminative loss function encourages embeddings from the same instance to be closer to each other than to embeddings from any other instance.

In [None]:
# Create the two-headed segmentation network and the discriminative loss module.
# Both are moved to the GPU, and both are later handed to `train`.
model, instance_clustering = (
    SemanticInstanceSegmentation().cuda(),
    DiscriminativeLoss().cuda(),
)

# Load data

In [None]:
# Random augmentation followed by conversion to a tensor. The same pipeline is
# applied to the images and (wrapped in target_transform) to the targets. Each
# call draws its own random parameters, so matching crops/rotations/flips
# presumably rely on the Slides dataset re-seeding the RNG. TODO: confirm in
# segmentation.datasets.
transform = transforms.Compose([
    transforms.RandomRotation(5),
    transforms.RandomCrop((256, 768)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
])


def _to_integer_labels(tensor):
    """Rescale a ToTensor output by 255 and cast it to int64 label IDs."""
    return (tensor * 255).long()


target_transform = transforms.Compose([transform, transforms.Lambda(_to_integer_labels)])

# Shared loader settings. num_workers is left at its default (single-process).
# WARNING: Don't use multiple workers for loading! Doesn't work with setting random seed
loader_options = dict(batch_size=6, drop_last=True, shuffle=True)

train_data = Slides(download=True, train=True, root='data',
                    transform=transform, target_transform=target_transform)
train_loader = torch.utils.data.DataLoader(train_data, **loader_options)

test_data = Slides(download=True, train=False, root='data',
                   transform=transform, target_transform=target_transform)
test_loader = torch.utils.data.DataLoader(test_data, **loader_options)

# Train

In [None]:
# Run training with the GPU model, the DiscriminativeLoss module
# (instance_clustering) and the train/test loaders built above.
# NOTE(review): the checkpoints loaded in the Evaluate cell
# ('models/epoch_800') are presumably written by train(); confirm in
# segmentation.training.
train(model, instance_clustering, train_loader, test_loader)

# Evaluate

In [None]:
# Load a trained checkpoint and switch the network to inference mode.
model.load_state_dict(torch.load('models/epoch_800'))
model.eval()

# Draw a single (augmented) training sample. Use a dedicated loader instead of
# rebinding `train_loader`: otherwise re-running the Train cell afterwards would
# silently train with batch size 1 and without drop_last.
eval_loader = torch.utils.data.DataLoader(train_data, batch_size=1, shuffle=True)
image, labels, instances = next(iter(eval_loader))
# Instance IDs are offset by one; the result is only used for display in the
# next cell. NOTE(review): the reason for the offset is not visible here.
instances = instances + 1

# Inference only: no autograd graph is needed, which saves GPU memory.
with torch.no_grad():
    logits, instance_embeddings = model(image.cuda())

current_logits = logits[0]
current_labels = labels[0, 0].cuda()
current_instances = instances[0].cuda()

# Per-pixel class prediction: index of the highest logit along the class axis.
predicted_class = current_logits.max(0)[1]

# Cluster the instance embeddings separately for every predicted class.
# predicted_instances[c] stays None when no pixel is predicted as class c.
num_classes = current_logits.shape[0]
embedding_dim = instance_embeddings.shape[1]
predicted_instances = [None] * num_classes
for class_index in range(num_classes):
    mask = predicted_class.view(-1) == class_index
    if mask.any():
        # (embedding_dim, n_pixels_of_this_class) embeddings for mean_shift.
        label_embedding = instance_embeddings[0].view(embedding_dim, -1)[:, mask]
        label_embedding = label_embedding.cpu().numpy()

        predicted_instances[class_index] = mean_shift(label_embedding)

### Visualise training results
Note that for _semantic_ segmentation the colours correspond to semantic classes. For _instance_ segmentation, the colours represent unique instances whose IDs are assigned in arbitrary order, so an instance's ID (colour) will generally not match the one in the ground truth.

In [None]:
plt.rcParams['image.cmap'] = 'Paired'

# The class indices run from 0 to 4 (five classes). Pinning the colour range
# gives a class the same colour in the ground-truth and predicted panels.
# Without it, imshow rescales each image to its own min/max, so an image that
# lacks some class would shift every class colour.
class_colour_range = {'vmin': 0, 'vmax': 4}

fig, axes = plt.subplots(3, 2, figsize=(15, 10))
for ax in axes.flatten():
    ax.axis('off')

axes[0, 0].set_title('Original image')
axes[0, 0].imshow(image[0].data.numpy().transpose(1, 2, 0))
axes[1, 0].set_title('Ground truth classes')
axes[1, 0].imshow(current_labels.cpu().numpy().squeeze(), **class_colour_range)
axes[2, 0].set_title('Ground truth instances')
axes[2, 0].imshow(current_instances.cpu().numpy().squeeze())
axes[1, 1].set_title('Predicted classes')
axes[1, 1].imshow(predicted_class.cpu().numpy().squeeze(), **class_colour_range)
instance_image = visualise_instances(predicted_instances, predicted_class, num_classes=5)
axes[2, 1].set_title('Predicted instances')
axes[2, 1].imshow(instance_image)

### Explicitly clear memory on GPU
<small>All variables in the notebook live in the same global scope, so the tensors they reference are not freed until the names are deleted or reassigned. Delete the derived data before running the model again.</small>

In [None]:
# Drop every reference to the derived data so its memory (including GPU
# tensors) can be released before the model is run again.
# Popping from globals() replaces a single `del a, b, ...` statement. That
# statement stops with a NameError at the first missing name (e.g. a skipped
# cell, or this cell being re-run) and leaves all later names alive.
for _name in ('logits', 'instance_embeddings', 'instance_image', 'image', 'labels',
              'instances', 'current_logits', 'current_labels', 'current_instances',
              'mask', 'label_embedding', 'predicted_class', 'predicted_instances'):
    globals().pop(_name, None)
del _name

### Evaluate on full image

In [None]:
# Margins (left, right, top, bottom) cropped off before the forward pass and
# zero-padded back afterwards. NOTE(review): the reason for these exact values
# is not visible here; presumably they make the input size compatible with the
# network.
crop = (16, 16, 37, 37)

image_original = torch.Tensor((plt.imread('test/010666874_816412_1428113.JPG') / 255).transpose(2, 0, 1)).unsqueeze(0)
# F.pad with negative amounts crops the tensor.
image = F.pad(image_original, tuple(-margin for margin in crop))

# Inference only: skipping the autograd graph matters for a full-size image.
with torch.no_grad():
    logits, instance_embeddings = model(image.cuda())

current_logits = logits[0]
# Restore the original spatial size. The padded border has all-zero logits (a
# tie between all classes) and all-zero embeddings.
predicted_class = F.pad(current_logits, crop).max(0)[1]
instance_embeddings = F.pad(instance_embeddings, crop)[0]

# Cluster the instance embeddings separately for every predicted class.
# predicted_instances[c] stays None when no pixel is predicted as class c.
num_classes = current_logits.shape[0]
embedding_dim = instance_embeddings.shape[0]
predicted_instances = [None] * num_classes
for class_index in range(num_classes):
    mask = predicted_class.view(-1) == class_index
    if mask.any():
        # (embedding_dim, n_pixels_of_this_class) embeddings for mean_shift.
        label_embedding = instance_embeddings.view(embedding_dim, -1)[:, mask]
        label_embedding = label_embedding.cpu().numpy()

        predicted_instances[class_index] = mean_shift(label_embedding)

In [None]:
fig, axes = plt.subplots(3, 1, figsize=(10, 12))
for ax in axes:
    ax.axis('off')

axes[0].set_title('Original image')
axes[0].imshow(image_original[0].data.numpy().transpose(1, 2, 0))

# Pin the colour range to the five class indices (0-4). This gives each class a
# fixed colour instead of one that depends on which classes this image contains.
axes[1].set_title('Predicted classes')
axes[1].imshow(predicted_class.cpu().numpy(), vmin=0, vmax=4)

axes[2].set_title('Predicted instances')
axes[2].imshow(visualise_instances(predicted_instances, predicted_class, num_classes=5))