# Attention-Based Deep Multiple Instance Learning
## Project: Implementation with Breast Histology Images
### Ayooluwa Odemuyiwa



# The Paper

 https://arxiv.org/pdf/1802.04712.pdf

# Loading and Processing Images

The Breast Cancer images can be found here: https://bioimage.ucsb.edu/research/bio-segmentation

This histopathology dataset consists of 58 H&E images, each 896 × 768 pixels. One image was not downloadable, so we worked with 57 instead. Each large image is labeled as either


*   benign *or*
*   malignant

Every image is divided into 32 × 32 patches. If a patch is 75% or more white, we discard that patch. I ran the script locally and saved the patches in a local directory.

The following code snippet was helpful in figuring out how to tile the actual images, but more work was needed to make the bags: https://gist.github.com/zeraien/2503530


Below, I create a folder for each bag in my directory. I do this by looping over the image files in my directory, tiling each image, and encoding its label in the folder name.

In [None]:
from PIL import Image
import sys
import os
import cv2
import numpy as np

def pixel_count(im, white_threshold=0.75):
    '''Report whether a tile has enough non-white content to be kept.

    Parameters
    ----------
    im : numpy.ndarray
        Tile pixels, either 2-D (grayscale) or 3-D with the channels on the
        last axis.
    white_threshold : float, optional
        Fraction of white pixels at or above which the tile is rejected.

    Returns
    -------
    bool
        False if at least `white_threshold` of the pixels are pure white
        (255 in every channel) or the tile is empty, True otherwise.
    '''
    im = np.asarray(im)
    if im.size == 0:
        # An empty tile has nothing worth keeping (and would divide by zero).
        return False
    if im.ndim >= 3:
        # A pixel is white only when every channel is 255. Counting channel
        # values one by one would over-count by the number of channels.
        white_mask = np.all(im == 255, axis=-1)
    else:
        white_mask = im == 255
    # Normalise by the tile's real pixel count instead of assuming 32 x 32.
    percent = float(np.count_nonzero(white_mask)) / float(white_mask.size)
    return percent < white_threshold

def make_patches(tile_width, tile_height, image, title, bag_num):
    '''Tile `image` and save every sufficiently non-white tile into a bag folder.

    The folder is named ``Bag<bag_num>L<label>`` and is created in the current
    working directory the first time a tile is kept. Each tile is written as
    ``x<left>y<top><label>.png``.

    Parameters
    ----------
    tile_width, tile_height : int
        Tile size in pixels; both must divide the image size exactly.
    image : PIL image
        Source image; only `size` and `crop` are used here.
    title : str
        Name or path of the source file. The label is 'B' when it contains
        "benign" and 'C' otherwise.
    bag_num : int
        Number used in the bag folder name.

    Returns
    -------
    None
        A message is printed and nothing is written when the image size is
        not a multiple of the tile size.
    '''
    width, height = image.size
    if width % tile_width != 0 or height % tile_height != 0:
        print("The image is not able to be split")
        return

    # The label and the folder name are the same for every tile of the image.
    label = 'B' if "benign" in title else 'C'
    bag_dir = "Bag" + str(bag_num) + "L" + label

    for currenty in range(0, height, tile_height):
        for currentx in range(0, width, tile_width):
            tile = image.crop((currentx, currenty, currentx + tile_width, currenty + tile_height))
            if not pixel_count(np.array(tile)):
                # Drop only this mostly-white tile. Returning here would
                # abandon every remaining tile of the image.
                continue
            os.makedirs(bag_dir, exist_ok=True)
            tile.save(bag_dir + "/" + "x" + str(currentx) + "y" + str(currenty) + label + ".png", "PNG")

# Folder that holds the original .tif images
directory = 'breast_cells'

# Walk the folder and turn each .tif file into one numbered bag
# (57 bags for the full dataset)
bag = 1
for filename in os.listdir(directory):
    f = os.path.join(directory, filename)
    if not f.endswith("tif"):
        continue
    image = Image.open(f)
    make_patches(32, 32, image, f, bag)
    bag += 1


# The Model

Some parameters were adjusted and differ from the original paper's implementation.

In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as F


class Attention(nn.Module):
    """Attention-pooling multiple-instance classifier over a bag of patches.

    Each patch goes through a small convolutional feature extractor and a
    linear layer, giving one L-dimensional embedding per instance. An
    attention network scores the instances, the scores are soft-maxed over
    the bag, and the weighted sum of embeddings is fed to a sigmoid
    classifier.

    Parameters
    ----------
    in_channels : int, optional
        Number of channels of every patch (default 3).
    patch_size : int, optional
        Height and width of every square patch (default 32). It fixes the
        input width of the first linear layer.
    """

    def __init__(self, in_channels=3, patch_size=32):
        super(Attention, self).__init__()
        self.L = 500
        self.D = 128
        self.K = 1

        self.feature_extractor_part1 = nn.Sequential(
            nn.Conv2d(in_channels, 20, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),
            nn.Conv2d(20, 50, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2)
        )

        # Measure the flattened size of one patch after the conv stack so the
        # linear layer sees a whole patch per row (50 * 5 * 5 for 32 x 32
        # input), not one spatial position per row.
        with torch.no_grad():
            probe = torch.zeros(1, in_channels, patch_size, patch_size)
            self.feature_dim = self.feature_extractor_part1(probe).numel()

        self.feature_extractor_part2 = nn.Sequential(
            nn.Linear(self.feature_dim, self.L),
            nn.ReLU(),
        )

        self.attention = nn.Sequential(
            nn.Linear(self.L, self.D),
            nn.Tanh(),
            nn.Linear(self.D, self.K)
        )

        self.classifier = nn.Sequential(
            nn.Linear(self.L*self.K, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Classify one bag.

        `x` is squeezed on dim 0 first, so a bag delivered with a leading
        batch dimension of size 1 becomes (N, C, H, W).

        Returns
        -------
        Y_prob : tensor, K x 1
            Sigmoid output of the classifier.
        Y_hat : tensor, K x 1
            `Y_prob` thresholded at 0.5, as floats.
        A : tensor, K x N
            Attention weights, soft-maxed over the N instances.
        """
        x = x.squeeze(0)

        H = self.feature_extractor_part1(x)
        H = H.flatten(start_dim=1)  # N x feature_dim, one row per instance
        H = self.feature_extractor_part2(H)  # NxL

        A = self.attention(H)  # NxK
        A = torch.transpose(A, 1, 0)  # KxN
        A = F.softmax(A, dim=1)  # softmax over N

        M = torch.mm(A, H)  # KxL

        Y_prob = self.classifier(M)
        Y_hat = torch.ge(Y_prob, 0.5).float()

        return Y_prob, Y_hat, A

    # AUXILIARY METHODS
    def calculate_classification_error(self, X, Y):
        """Return (error, Y_hat), where error is 1 minus the mean agreement of Y_hat with Y."""
        Y = Y.float()
        _, Y_hat, _ = self.forward(X)
        error = 1. - Y_hat.eq(Y).cpu().float().mean().item()

        return error, Y_hat

    def calculate_objective(self, X, Y):
        """Return (negative Bernoulli log-likelihood of Y, attention weights)."""
        Y = Y.float()
        Y_prob, _, A = self.forward(X)
        # Keep the probability away from 0 and 1 so both logs stay finite.
        Y_prob = torch.clamp(Y_prob, min=1e-5, max=1. - 1e-5)
        neg_log_likelihood = -1. * (Y * torch.log(Y_prob) + (1. - Y) * torch.log(1. - Y_prob))  # negative log bernoulli

        return neg_log_likelihood, A

class GatedAttention(nn.Module):
    """Gated-attention multiple-instance classifier over a bag of patches.

    Same pipeline as plain attention pooling, except that the attention
    score of an instance is computed from the element-wise product of a
    tanh branch and a sigmoid (gate) branch.

    Parameters
    ----------
    in_channels : int, optional
        Number of channels of every patch (default 1).
    patch_size : int, optional
        Height and width of every square patch (default 28, which gives the
        original 50 * 4 * 4 flattened feature size). Pass ``in_channels=3,
        patch_size=32`` for RGB 32 x 32 patches.
    """

    def __init__(self, in_channels=1, patch_size=28):
        super(GatedAttention, self).__init__()
        self.L = 500
        self.D = 128
        self.K = 1

        self.feature_extractor_part1 = nn.Sequential(
            nn.Conv2d(in_channels, 20, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),
            nn.Conv2d(20, 50, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2)
        )

        # Measure the flattened size of one patch after the conv stack rather
        # than hard-coding it for a single input resolution.
        with torch.no_grad():
            probe = torch.zeros(1, in_channels, patch_size, patch_size)
            self.feature_dim = self.feature_extractor_part1(probe).numel()

        self.feature_extractor_part2 = nn.Sequential(
            nn.Linear(self.feature_dim, self.L),
            nn.ReLU(),
        )

        self.attention_V = nn.Sequential(
            nn.Linear(self.L, self.D),
            nn.Tanh()
        )

        self.attention_U = nn.Sequential(
            nn.Linear(self.L, self.D),
            nn.Sigmoid()
        )

        self.attention_weights = nn.Linear(self.D, self.K)

        self.classifier = nn.Sequential(
            nn.Linear(self.L*self.K, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Classify one bag.

        `x` is squeezed on dim 0 first, so a bag delivered with a leading
        batch dimension of size 1 becomes (N, C, H, W).

        Returns
        -------
        Y_prob : tensor, K x 1
            Sigmoid output of the classifier.
        Y_hat : tensor, K x 1
            `Y_prob` thresholded at 0.5, as floats.
        A : tensor, K x N
            Attention weights, soft-maxed over the N instances.
        """
        x = x.squeeze(0)

        H = self.feature_extractor_part1(x)
        H = H.flatten(start_dim=1)  # N x feature_dim, one row per instance
        H = self.feature_extractor_part2(H)  # NxL

        A_V = self.attention_V(H)  # NxD
        A_U = self.attention_U(H)  # NxD
        A = self.attention_weights(A_V * A_U) # element wise multiplication # NxK
        A = torch.transpose(A, 1, 0)  # KxN
        A = F.softmax(A, dim=1)  # softmax over N

        M = torch.mm(A, H)  # KxL

        Y_prob = self.classifier(M)
        Y_hat = torch.ge(Y_prob, 0.5).float()

        return Y_prob, Y_hat, A

    # AUXILIARY METHODS
    def calculate_classification_error(self, X, Y):
        """Return (error, Y_hat), where error is 1 minus the mean agreement of Y_hat with Y."""
        Y = Y.float()
        _, Y_hat, _ = self.forward(X)
        error = 1. - Y_hat.eq(Y).cpu().float().mean().item()

        return error, Y_hat

    def calculate_objective(self, X, Y):
        """Return (negative Bernoulli log-likelihood of Y, attention weights)."""
        Y = Y.float()
        Y_prob, _, A = self.forward(X)
        # Keep the probability away from 0 and 1 so both logs stay finite.
        Y_prob = torch.clamp(Y_prob, min=1e-5, max=1. - 1e-5)
        neg_log_likelihood = -1. * (Y * torch.log(Y_prob) + (1. - Y) * torch.log(1. - Y_prob))  # negative log bernoulli

        return neg_log_likelihood, A



# New Section

In [None]:
!unzip data_bags.zip

## Created a DataLoader

In [None]:
"""PyTorch dataset object that loads the breast histology patch folders as bags."""

import numpy as np
import torch
import torch.utils.data as data_utils
from torchvision import datasets, transforms
import os
from tensorflow.keras.preprocessing.image import img_to_array, load_img
import numpy as np

class Histopathology(data_utils.Dataset):
    """Dataset of bags, one bag per sub-folder of `img_dir`.

    Every image file in a sub-folder is loaded, converted to RGB, passed
    through `transform`, and stacked with its siblings into one tensor. The
    bag label is read from the character just before the four-character
    file extension of the bag's first file name: 'C' means 1, 'B' means 0.

    Parameters
    ----------
    img_dir : str
        Folder whose sub-folders are the bags.
    transform : callable
        Applied to each PIL image; must return a tensor so that the
        instances of a bag can be stacked.
    """

    # Label character in the file name -> class id.
    _LABEL_CODES = {'C': 1, 'B': 0}

    def __init__(self, img_dir, transform):
        self.img_dir = img_dir
        self.transform = transform
        self.bags, self.labels = self._create_bags()

    def _create_bags(self):
        """Load every bag folder and return (bags, labels) as parallel lists."""
        import warnings

        bags = []
        bag_label = []
        # Sorted listing keeps the bag order, and so any seeded split,
        # reproducible across machines.
        for bag_name in sorted(os.listdir(self.img_dir)):
            bag_dir = os.path.join(self.img_dir, bag_name)
            if not os.path.isdir(bag_dir):
                # Stray files next to the bag folders are not bags.
                continue

            img_paths = [os.path.join(bag_dir, name)
                         for name in sorted(os.listdir(bag_dir))]
            if not img_paths:
                warnings.warn("Skipping bag folder with no images: " + bag_dir)
                continue

            # Slicing (not indexing) so a very short name yields '' instead
            # of raising IndexError.
            label = self._LABEL_CODES.get(img_paths[0][-5:-4])
            if label is None:
                # Appending the bag without a label would shift every later
                # label by one position.
                warnings.warn("Skipping bag folder with unknown label: " + bag_dir)
                continue

            instances = [self.transform(load_img(each_img).convert('RGB'))
                         for each_img in img_paths]
            bags.append(torch.stack(instances, 0))
            bag_label.append(label)

        self.labels = bag_label
        self.bags = bags
        return bags, bag_label

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        bag = self.bags[index]
        label = self.labels[index]
        return bag, label



In [None]:
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler


# Build the bag dataset from the extracted patch folders
path = "/content/data_bags"
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize([32, 32]),
])
dataset = Histopathology(path, transform)

# Split settings
batch_size = 1
validation_split = .4
shuffle_dataset = True
random_seed = 100

# Index bookkeeping: the first `split` (optionally shuffled) indices are
# held out for validation, the remainder is used for training
dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor(validation_split * dataset_size))
if shuffle_dataset:
    np.random.seed(random_seed)
    np.random.shuffle(indices)
train_indices = indices[split:]
val_indices = indices[:split]

# Samplers restrict each loader to its own subset of the shared dataset
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

train_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    sampler=train_sampler,
)
validation_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    sampler=valid_sampler,
)



In [None]:
# Sanity check: number of batches the validation loader yields.
print(len(validation_loader))

2


In [None]:
import os
from tensorflow.keras.preprocessing.image import img_to_array, load_img
import numpy as np
def generate_batch(path):
    """Load every bag folder under `path` as scaled image arrays with labels.

    Parameters
    -----------------
    path : str
        Folder whose sub-folders each hold the images of one bag.

    Returns
    -----------------
    bags : list
        One ``(images, labels)`` tuple per bag folder. ``images`` is a list
        of float32 arrays scaled by 1/255; ``labels`` holds one entry per
        image whose file name carries a label character ('C' -> 1,
        'B' -> 0) just before the four-character extension.
    """
    bags = []
    # Honour the caller's `path`; it used to be overwritten by a fixed folder.
    for each_path_idx in os.listdir(path):
        bag_dir = os.path.join(path, each_path_idx)
        if not os.path.isdir(bag_dir):
            # Stray files next to the bag folders are not bags.
            continue
        img = []
        bag_label = []
        for each_img in os.listdir(bag_dir):
            each_img = os.path.join(bag_dir, each_img)
            img_raw = load_img(each_img).convert('RGB') # this is a PIL image
            img_data = img_to_array(img_raw) / 255  # pixel values scaled to [0, 1]
            img_data = np.asarray(img_data).astype('float32')
            # NOTE(review): the image is already RGB, so stacking three copies
            # adds a fourth axis of size 3. This looks like a leftover from
            # grayscale input; kept so the returned shape does not change.
            # Confirm whether it is intended.
            img_tensor = np.stack((img_data,)*3, axis=-1)
            img.append(img_tensor)
            if (each_img[-5] == 'C'):
                bag_label.append(1)
            if (each_img[-5] == 'B'):
                bag_label.append(0)
        bags.append((img, bag_label))

    return bags

In [None]:
# Load every bag under /content/breast_bags as (images, labels) tuples.
bags = generate_batch("/content/breast_bags")

In [None]:
from six import b
from torch.autograd import Variable
import torch.optim as optim
import tensorflow as tf

# Network under training and its Adam optimiser; both are module-level
# names that train() and test() read directly.
model = Attention()
optimizer = optim.Adam(
    model.parameters(),
    lr=5e-4,
    betas=(0.9, 0.999),
    weight_decay=1e-4,
)
def train(epoch, train_loader):
    """Run one optimisation pass over `train_loader` and print the averages.

    Uses the module-level `model` and `optimizer`.

    Parameters
    ----------
    epoch : int
        Only used as the label in the printed summary; each call performs
        exactly one pass over the loader.
    train_loader : iterable of (data, label)
        Yields one bag per step; ``label[0]`` is taken as the bag label.

    Raises
    ------
    ValueError
        If the loader yields no bags, since the averages would be undefined.
    """
    num_bags = len(train_loader)
    if num_bags == 0:
        raise ValueError("train_loader yields no bags")

    model.train()
    train_loss = 0.
    train_error = 0.
    for data, label in train_loader:
        bag_label = torch.as_tensor(label[0])
        # reset gradients
        optimizer.zero_grad()
        # calculate loss and metrics
        loss, _ = model.calculate_objective(data, bag_label)
        # .item() accumulates a plain float whatever the loss tensor's rank.
        train_loss += loss.item()
        with torch.no_grad():
            # The error is only a metric, so no graph is needed for it.
            error, _ = model.calculate_classification_error(data, bag_label)
        train_error += error
        # backward pass
        loss.backward()
        # step
        optimizer.step()

    # calculate loss and error for epoch
    train_loss /= num_bags
    train_error /= num_bags

    print('Epoch: {}, Loss: {:.4f}, Train error: {:.4f}'.format(epoch, train_loss, train_error))

def test(validation_loader):
    """Evaluate the module-level `model` on `validation_loader` and print the averages.

    Parameters
    ----------
    validation_loader : iterable of (data, label)
        Yields one bag per step; ``label[0]`` is taken as the bag label.

    Raises
    ------
    ValueError
        If the loader yields no bags, since the averages would be undefined.
    """
    num_bags = len(validation_loader)
    if num_bags == 0:
        raise ValueError("validation_loader yields no bags")

    model.eval()
    test_loss = 0.
    test_error = 0.
    # No parameter is updated here, so skip building autograd graphs.
    with torch.no_grad():
        for data, label in validation_loader:
            data = torch.as_tensor(data)
            bag_label = torch.as_tensor(label[0])
            loss, attention_weights = model.calculate_objective(data, bag_label)
            test_loss += loss.item()
            error, predicted_label = model.calculate_classification_error(data, bag_label)
            test_error += error

    test_error /= num_bags
    test_loss /= num_bags

    print('\nTest Set, Loss: {:.4f}, Test error: {:.4f}'.format(test_loss, test_error))

Cross Validation



In [None]:
# 10-fold cross validation, repeated 10 times over the same folds.
# Fold i validates on indices [i*seg, i*seg + seg) and trains on the rest.
# With seg = int(total_size * fraction), any indices past 10*seg are never
# used for validation.
# NOTE(review): the module-level model and optimizer are not re-created
# between folds, so later folds start from weights that have already been
# trained on their validation bags - confirm whether that is intended.
total_size = len(dataset)
fraction = 0.1
seg = int(total_size * fraction)
for repeat in range(10):
    # A separate name for the outer loop keeps the fold index `i` from
    # shadowing it.
    for i in range(10):
        vall = i * seg
        valr = vall + seg

        train_indices = list(range(0, vall)) + list(range(valr, total_size))
        val_indices = list(range(vall, valr))

        train_set = torch.utils.data.dataset.Subset(dataset, train_indices)
        validation_set = torch.utils.data.dataset.Subset(dataset, val_indices)

        train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size)
        # Evaluate on the held-out fold only; loading the full dataset here
        # would score the model on its own training bags.
        validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=batch_size)

        # The first argument is only the epoch label that train() prints.
        train(100, train_loader)
        test(validation_loader)


Epoch: 100, Loss: 0.6947, Train error: 0.4722

Test Set, Loss: 0.6910, Test error: 0.4500
Epoch: 100, Loss: 0.6916, Train error: 0.4444

Test Set, Loss: 0.6897, Test error: 0.4500
Epoch: 100, Loss: 0.6908, Train error: 0.4444

Test Set, Loss: 0.6889, Test error: 0.4500
Epoch: 100, Loss: 0.6936, Train error: 0.4722

Test Set, Loss: 0.6890, Test error: 0.4500
Epoch: 100, Loss: 0.6889, Train error: 0.4444

Test Set, Loss: 0.6875, Test error: 0.4500
Epoch: 100, Loss: 0.6841, Train error: 0.4167

Test Set, Loss: 0.6855, Test error: 0.4500
Epoch: 100, Loss: 0.6904, Train error: 0.4444

Test Set, Loss: 0.6868, Test error: 0.4500
Epoch: 100, Loss: 0.6868, Train error: 0.4444

Test Set, Loss: 0.6837, Test error: 0.4500
Epoch: 100, Loss: 0.6853, Train error: 0.4444

Test Set, Loss: 0.6814, Test error: 0.4500
Epoch: 100, Loss: 0.6922, Train error: 0.4722


KeyboardInterrupt: ignored

In [None]:
# One more pass over the current train_loader; 300 is only the printed epoch label.
train(300, train_loader)

Epoch: 300, Loss: 0.6886, Train error: 0.4722


In [None]:
# Evaluate on whatever validation_loader is currently bound to.
test(validation_loader)


Test Set, Loss: 0.6782, Test error: 0.4500
