<a href="https://colab.research.google.com/github/AnnyKong/svm-cnn-idc-detection/blob/master/cnn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
from google.colab import drive
# Make Google Drive available inside this Colab runtime under the path below.
DRIVE_MOUNTPOINT = '/content/drive'
drive.mount(DRIVE_MOUNTPOINT)

In [0]:
# Side length (pixels) of the square patches used; images of other sizes are skipped when scanning.
IMG_DIM = 50
# Local checkout of the dataset repository (cloned / entered in the next cell).
BASE_DIR = '/content/breast-histopathology'

In [0]:
! [ ! -d $BASE_DIR ] && git clone https://nick_lrc@bitbucket.org/nick_lrc/breast-histopathology.git
% cd $BASE_DIR
! rm -rf [0-9]*.pt

In [0]:
import os
import glob
from PIL import Image

def _square_images(pattern, dim):
  """Return the sorted paths matching `pattern` whose image size is exactly (dim, dim).

  Image.open is lazy, so reading `.size` only parses the header; each file is
  closed immediately so scanning a large folder does not pile up open handles.
  Sorting makes the result independent of filesystem listing order.
  """
  kept = []
  for path in sorted(glob.glob(pattern)):
    with Image.open(path) as image:
      if image.size == (dim, dim):
        kept.append(path)
  return kept

# data/0 holds the label-0 (negative) patches, data/1 the label-1 (positive) ones.
# Patches that are not IMG_DIM x IMG_DIM are dropped.
negatives = _square_images('data/0/*', IMG_DIM)
positives = _square_images('data/1/*', IMG_DIM)

print(f'Negative: {len(negatives)}')
print(f'Positive: {len(positives)}')

In [0]:
# Training hyperparameters.
EPOCHS = 8          # passes over the training set
BATCH_SIZE = 128    # samples per DataLoader batch (train / validate / test)
LOG_INTERVAL = 200  # print a progress line every N training batches
LR = 0.1            # learning rate given to the optimizer
WEIGHT_DECAY = 0    # weight decay given to the optimizer
# Currently unused settings (momentum / StepLR schedule experiments):
# MOMENTUM = 0.9
# LR_STEP_SIZE = 5
# GAMMA = 0.1

In [0]:
import numpy as np

def train_validate_test_split(data, train_ratio=0.8, rng=None):
  """Randomly split `data` into (train, validate, test) numpy arrays.

  The shuffled items are first split `train_ratio` / rest into a training pool
  and the test set; the pool is then split again with the same ratio into
  train and validate. With the default 0.8 that is 64% / 16% / 20%.

  Args:
    data: sequence of items (e.g. file paths). It is not modified.
    train_ratio: fraction kept on the "train" side of each of the two splits.
    rng: optional numpy Generator / RandomState for a reproducible split;
      defaults to numpy's global random state.

  Returns:
    Tuple (train, validate, test) of numpy arrays.
  """
  # permutation returns a shuffled *copy*; np.random.shuffle used to reorder
  # the caller's list in place as a side effect.
  source = np.random if rng is None else rng
  shuffled = source.permutation(np.asarray(data))
  train, test = np.split(shuffled, [int(train_ratio * len(shuffled))])
  train, validate = np.split(train, [int(train_ratio * len(train))])
  return train, validate, test

def negative_positive_merge(negatives, positives, negative_label=0, positive_label=1):
  """Concatenate two classes of samples and shuffle them together with their labels.

  Returns (images, labels) as numpy arrays in the same random order; every
  element of `negatives` is labelled `negative_label`, every element of
  `positives` is labelled `positive_label`.
  """
  combined_labels = np.array(
      [negative_label for _ in negatives] + [positive_label for _ in positives])
  combined_images = np.concatenate([negatives, positives])
  order = np.random.permutation(len(combined_images))
  return combined_images[order], combined_labels[order]

# Split each class on its own (so every split keeps roughly the same class ratio),
# then merge negatives and positives per split with a joint shuffle.
negative_splits = train_validate_test_split(negatives)
positive_splits = train_validate_test_split(positives)
negative_train, negative_validate, negative_test = negative_splits
positive_train, positive_validate, positive_test = positive_splits

(image_train, label_train), (image_validate, label_validate), (image_test, label_test) = [
    negative_positive_merge(negative_part, positive_part)
    for negative_part, positive_part in zip(negative_splits, positive_splits)
]

print(f'Train   : {len(image_train)}')
print(f'Validate: {len(image_validate)}')
print(f'Test    : {len(image_test)}')

In [0]:
from torch.utils.data import Dataset, DataLoader
import torch
import multiprocessing

class BreastHistopathologyDataset(Dataset):
  """Map-style dataset over image file paths and their labels.

  Each item is (image, label): image is a float32 array of shape (3, H, W)
  scaled to [0, 1]; label is returned as stored in `labels`.
  """

  def __init__(self, images, labels):
    # images: sequence of file paths; labels: same-length sequence of class ids.
    self.images = images
    self.labels = labels

  def __len__(self):
    return len(self.images)

  def __getitem__(self, index):
    # Close the file as soon as the pixels are read instead of relying on
    # garbage collection (many DataLoader workers each open many files).
    # convert('RGB') guarantees three channels (grayscale / RGBA files would
    # otherwise break the 3-channel input the classifiers' first Conv2d expects).
    with Image.open(self.images[index]) as pil_image:
      pixels = np.array(pil_image.convert('RGB'))
    image = pixels / 255.
    # HWC -> CHW, the layout Conv2d consumes.
    image = np.moveaxis(image, -1, -3).astype(np.float32)
    label = self.labels[index]
    return (image, label)

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
# Worker processes and pinned host memory are only enabled when a GPU is present.
kwargs = {'num_workers': multiprocessing.cpu_count(), 'pin_memory': True} if use_cuda else {}

# Only training benefits from reshuffling every epoch. evaluate() and
# eval_to_confusion_matrix() just sum over batches, so validate/test iterate
# in a fixed order, which keeps evaluation runs reproducible.
train_loader = DataLoader(BreastHistopathologyDataset(image_train, label_train),
                          batch_size=BATCH_SIZE, shuffle=True, **kwargs)
validate_loader = DataLoader(BreastHistopathologyDataset(image_validate, label_validate),
                             batch_size=BATCH_SIZE, shuffle=False, **kwargs)
test_loader = DataLoader(BreastHistopathologyDataset(image_test, label_test),
                         batch_size=BATCH_SIZE, shuffle=False, **kwargs)

In [0]:
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler

from skimage.feature import hog

def layer_output_dimension(in_dimension, kernel_size, stride=1, padding=0, dilation=1):
  """Output side length of a Conv2d / MaxPool2d layer along one spatial axis.

  Uses PyTorch's formula
  floor((in + 2*padding - dilation*(kernel_size - 1) - 1) / stride) + 1,
  which for the default dilation=1 equals (in - kernel_size + 2*padding) // stride + 1.

  Note: nn.MaxPool2d defaults its stride to kernel_size, but this helper
  defaults stride to 1, so pass the stride explicitly for pooling layers.
  """
  return (in_dimension + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

class BreastHistopathologyClassifier(nn.Module):
  """Small CNN mapping a 3-channel IMG_DIM x IMG_DIM image to 2 class logits.

  Layers: Conv(3->32, 3x3) + ReLU, Conv(32->64, 3x3) + ReLU, 2x2 max-pool,
  dropout, then Linear -> ReLU -> dropout -> Linear(128 -> 2).
  `self.accuracy` tracks the best validation accuracy passed to
  `save_best_model`, so checkpoints are only written on improvement.
  The submodule names `conv` and `fc` are the state_dict keys of saved
  checkpoints; keep them stable so existing `.pt` files still load.
  """

  def __init__(self):
    super(BreastHistopathologyClassifier, self).__init__()
    self.conv = nn.Sequential(
        nn.Conv2d(3, 32, 3),
        nn.ReLU(),
        nn.Conv2d(32, 64, 3),
        nn.ReLU(),
        nn.MaxPool2d(2, 2),
        nn.Dropout(0.25),
    )

    # Spatial side length after conv(3) -> conv(3) -> maxpool(2, stride 2).
    side = layer_output_dimension(IMG_DIM, 3)
    side = layer_output_dimension(side, 3)
    side = layer_output_dimension(side, 2, 2)
    # Flattened feature count fed to the head: side x side positions x 64 channels.
    self.conv_out_dim = side * side * 64

    self.fc = nn.Sequential(
        nn.Linear(self.conv_out_dim, 128),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(128, 2),
    )
    # Best validation accuracy seen so far (updated by save_best_model).
    self.accuracy = 0

  def forward(self, images):
    """Return logits of shape (batch, 2) for images of shape (batch, 3, IMG_DIM, IMG_DIM)."""
    out = self.conv(images)
    # flatten(1) keeps the batch dimension explicit: a wrongly sized input now
    # fails in the Linear layer instead of being silently re-batched by view(-1, ...).
    out = torch.flatten(out, 1)
    return self.fc(out)

  def loss(self, pred, label, reduction='mean'):
    """Cross-entropy between logits `pred` (batch, 2) and integer class labels.

    `label` may be (batch,) or (batch, 1). It is reshaped to (batch,) rather than
    squeezed: squeeze() turned a batch of one into a 0-d target, which
    F.cross_entropy does not accept for 2-D logits.
    """
    return F.cross_entropy(pred, label.reshape(-1).long(), reduction=reduction)

  def save_best_model(self, accuracy, dest):
    """Save the state_dict to `dest` if `accuracy` beats the best value seen so far."""
    if self.accuracy < accuracy:
      self.accuracy = accuracy
      torch.save(self.state_dict(), dest)
      print(f'Saved best model to {dest}')

class BreastHistopathologyClassifierWithHog(nn.Module):
  """CNN classifier whose head also receives a HOG descriptor of each image.

  HOG features (scikit-image `hog`) are computed on the CPU per image and
  concatenated *before* the flattened conv features. The first Linear layer's
  input layout depends on that order, so it is kept as in the original code.
  Submodule names `conv` and `fc` are the state_dict keys of saved checkpoints.
  """

  def __init__(self):
    super(BreastHistopathologyClassifierWithHog, self).__init__()
    self.conv = nn.Sequential(
        nn.Conv2d(3, 32, 3),
        nn.ReLU(),
        nn.Conv2d(32, 64, 3),
        nn.ReLU(),
        nn.MaxPool2d(2, 2),
        nn.Dropout(0.25),
    )

    # Parameters actually passed to skimage.feature.hog.
    self.hog_orientations = 8
    self.hog_pixels_per_cell = (10, 10)
    self.hog_cells_per_block = (5, 5)
    # Legacy attributes kept for any code that reads them: hog_ppc is the number
    # of HOG cells per image side (IMG_DIM // 10); hog_cpb is not used anywhere.
    self.hog_ppc = 5
    self.hog_cpb = 10

    # Spatial side length after conv(3) -> conv(3) -> maxpool(2, stride 2).
    side = layer_output_dimension(IMG_DIM, 3)
    side = layer_output_dimension(side, 3)
    side = layer_output_dimension(side, 2, 2)
    self.conv_out_dim = side * side * 64
    # 200 for IMG_DIM=50 with the parameters above (same width as the original 5*5*8).
    self.hog_out_dim = self._hog_feature_length(
        IMG_DIM, IMG_DIM, self.hog_orientations,
        self.hog_pixels_per_cell, self.hog_cells_per_block)

    self.fc = nn.Sequential(
        nn.Linear(self.conv_out_dim + self.hog_out_dim, 128),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(128, 2),
    )
    # Best validation accuracy seen so far (updated by save_best_model).
    self.accuracy = 0

  @staticmethod
  def _hog_feature_length(height, width, orientations, pixels_per_cell, cells_per_block):
    """Length of the flattened skimage HOG descriptor for an image of this size."""
    cells_y = height // pixels_per_cell[0]
    cells_x = width // pixels_per_cell[1]
    blocks_y = cells_y - cells_per_block[0] + 1
    blocks_x = cells_x - cells_per_block[1] + 1
    return blocks_y * blocks_x * cells_per_block[0] * cells_per_block[1] * orientations

  def _hog_descriptor(self, image_hwc):
    """HOG descriptor of one channels-last (H, W, C) image, for old and new scikit-image."""
    params = dict(orientations=self.hog_orientations,
                  pixels_per_cell=self.hog_pixels_per_cell,
                  cells_per_block=self.hog_cells_per_block,
                  block_norm='L2')
    try:
      return hog(image_hwc, channel_axis=-1, **params)
    except TypeError:
      # scikit-image < 0.19 has no `channel_axis`; it used `multichannel` instead.
      return hog(image_hwc, multichannel=True, **params)

  def loss(self, pred, label, reduction='mean'):
    """Cross-entropy between logits `pred` (batch, 2) and integer class labels.

    `label` is reshaped to (batch,) rather than squeezed, so a batch of one
    stays 1-D (a 0-d target is rejected by F.cross_entropy for 2-D logits).
    """
    return F.cross_entropy(pred, label.reshape(-1).long(), reduction=reduction)

  def save_best_model(self, accuracy, dest):
    """Save the state_dict to `dest` if `accuracy` beats the best value seen so far."""
    if self.accuracy < accuracy:
      self.accuracy = accuracy
      torch.save(self.state_dict(), dest)
      print(f'Saved best model to {dest}')

  def forward(self, images):
    """Return logits (batch, 2) from HOG features plus conv features of (batch, 3, H, W) images."""
    out = self.conv(images)
    out = torch.flatten(out, 1)

    # skimage works on numpy arrays, so HOG runs on the CPU one image at a time.
    # Tensors are (C, H, W); skimage wants channels last.
    images_hwc = np.moveaxis(images.detach().cpu().numpy(), 1, -1)
    hog_features = np.stack([self._hog_descriptor(image) for image in images_hwc])
    # Follow the conv output's device/dtype instead of a global `device`.
    h_out = torch.from_numpy(hog_features).to(device=out.device, dtype=out.dtype)

    # HOG first, conv second (same concatenation order as the original code).
    out = torch.cat((h_out, out), dim=1)
    return self.fc(out)

model = BreastHistopathologyClassifier().to(device)
# Start from the weights in LR0_1_016.pt (presumably from an earlier run; confirm it
# exists in the working directory). map_location lets a checkpoint saved on a GPU
# load on a CPU-only runtime.
model.load_state_dict(torch.load('LR0_1_016.pt', map_location=device))
optimizer = optim.Adadelta(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
# scheduler = lr_scheduler.StepLR(optimizer, step_size=LR_STEP_SIZE, gamma=GAMMA)

In [0]:
import time
import traceback
import matplotlib.pyplot as plt

from sklearn.metrics import confusion_matrix

def train(model, device, train_loader, optimizer, epoch, log_interval):
  """Run one training epoch and return the mean per-sample training loss.

  Args:
    model: module with a `loss(pred, label)` method (mean reduction by default).
    device: torch device the batches are moved to.
    train_loader: DataLoader yielding (image, label) batches.
    optimizer: optimizer stepping `model`'s parameters.
    epoch: epoch number, used only in log lines.
    log_interval: print a progress line every `log_interval` batches (and on the last).

  Returns:
    Sum of per-sample losses over the epoch divided by the dataset size. Weights
    change between batches, so this is a running average during the epoch. It is
    on the same per-sample scale as the loss returned by `evaluate`.
  """
  model.train()
  loss_sum = 0.0
  samples_seen = 0
  num_samples = len(train_loader.dataset)
  num_batches = len(train_loader)

  for i, (image, label) in enumerate(train_loader):
    image = image.to(device)
    label = label.to(device)
    optimizer.zero_grad()
    pred = model(image)
    loss = model.loss(pred, label)
    loss.backward()
    optimizer.step()

    batch_size = len(image)
    # loss is a batch *mean*; scale it back to a batch sum so dividing by the
    # dataset size yields a true per-sample average (previously it came out
    # roughly BATCH_SIZE times too small).
    loss_sum += loss.item() * batch_size
    samples_seen += batch_size

    if i % log_interval == 0 or i == num_batches - 1:
      # Counting real batch sizes keeps the last (possibly smaller) batch at 100%.
      print('{} Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            time.ctime(time.time()),
            epoch,
            samples_seen,
            num_samples,
            100.0 * samples_seen / num_samples,
            loss.item()))
  return loss_sum / num_samples

def evaluate(model, device, eval_loader, eval_type):
  """Score `model` on `eval_loader` without gradients and print a summary line.

  Returns (mean per-sample loss, accuracy in percent). `eval_type` is only used
  as the label at the start of the printed line.
  """
  model.eval()
  total_loss = 0
  total_correct = 0

  with torch.no_grad():
    for batch_images, batch_labels in eval_loader:
      batch_images, batch_labels = batch_images.to(device), batch_labels.to(device)
      logits = model(batch_images)
      total_loss += model.loss(logits, batch_labels, reduction='sum').item()
      predicted = logits.max(dim=1).indices
      total_correct += predicted.eq(batch_labels.view_as(predicted)).sum().item()

  n_samples = len(eval_loader.dataset)
  mean_loss = total_loss / n_samples
  accuracy_pct = 100.0 * total_correct / n_samples
  print('{} set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        eval_type, mean_loss, total_correct, n_samples, accuracy_pct))
  return mean_loss, accuracy_pct

def eval_to_confusion_matrix(model, device, eval_loader):
  """Evaluate `model` and return (mean per-sample loss, 2x2 confusion matrix).

  The matrix is [[tn, fp], [fn, tp]] as nested lists of ints: rows are the true
  label (0, 1), columns the predicted label (0, 1), accumulated over all batches.

  Raises:
    ValueError: if a true or predicted label is outside {0, 1}.
  """
  model.eval()
  loss_sum = 0
  # Flat counts for the cells (true, pred) = (0,0), (0,1), (1,0), (1,1).
  counts = torch.zeros(4, dtype=torch.long)

  with torch.no_grad():
    for image, label in eval_loader:
      image = image.to(device)
      label = label.to(device)
      pred = model(image)
      loss_sum += model.loss(pred, label, reduction='sum').item()
      predicted = pred.max(1)[1].reshape(-1).long().cpu()
      truth = label.reshape(-1).long().cpu()
      if ((truth < 0) | (truth > 1) | (predicted < 0) | (predicted > 1)).any():
        raise ValueError('eval_to_confusion_matrix expects binary labels 0/1')
      # Counting cells directly (instead of sklearn's per-batch confusion_matrix,
      # which returns a 1x1 matrix for single-class batches) keeps the 2x2
      # layout for every batch.
      counts += torch.bincount(2 * truth + predicted, minlength=4)

  loss = loss_sum / len(eval_loader.dataset)
  tn, fp, fn, tp = counts.tolist()
  return loss, [[tn, fp], [fn, tp]]

def plot(epochs, history, title, history_label, figsize=(20, 10)):
  """Line-plot `history` against `epochs`, save it as <snake_case title>.png, then show it."""
  fig, ax = plt.subplots(figsize=figsize)
  ax.plot(epochs, history)
  ax.set_title(title)
  ax.set_xlabel("epoch")
  ax.set_ylabel(history_label)
  filename = f'{title.lower().replace(" ", "_")}.png'
  fig.savefig(filename)
  plt.show()

In [0]:
# Per-epoch (epoch, value) histories, plotted once training stops.
train_loss_history = []
validate_loss_history = []
validate_accuracy_history = []
validate_accuracy = 0
# Stays 0 if no epoch starts (e.g. EPOCHS == 0), so `finally` can tell nothing ran.
epoch = 0


def _plot_history(history, title, history_label):
  """Plot (epoch, value) pairs with `plot`, skipping histories with no entries.

  Without this guard, `zip(*[])` leaves nothing to unpack and the `finally`
  clause itself raised ValueError when training stopped during epoch 1.
  """
  if not history:
    print(f'{title}: no data recorded, skipping plot')
    return
  epochs, values = zip(*history)
  plot(epochs, values, title, history_label)


try:
  for epoch in range(1, EPOCHS + 1):
    train_loss = train(model, device, train_loader, optimizer, epoch, LOG_INTERVAL)
    validate_loss, validate_accuracy = evaluate(model, device, validate_loader, "Validate")
    train_loss_history.append((epoch, train_loss))
    validate_loss_history.append((epoch, validate_loss))
    validate_accuracy_history.append((epoch, validate_accuracy))
    model.save_best_model(validate_accuracy, f'{epoch:03}.pt')

  # NOTE(review): this scores the last-epoch weights, not the best checkpoint saved above.
  evaluate(model, device, test_loader, "Test")

except KeyboardInterrupt:
  print('Interrupted')
except Exception:
  # Best effort: report the failure but still save/plot whatever was recorded.
  # (A bare `except:` would also swallow SystemExit.)
  traceback.print_exc()
finally:
  if epoch > 0:
    model.save_best_model(validate_accuracy, f'{epoch:03}.pt')
  _plot_history(train_loss_history, 'Train Loss', 'loss')
  _plot_history(validate_loss_history, 'Validation Loss', 'loss')
  _plot_history(validate_accuracy_history, 'Validation Accuracy', 'accuracy')

In [0]:
def accuracy(cm):
  """Fraction of correct predictions for a 2x2 confusion matrix [[tn, fp], [fn, tp]].

  Returns float('nan') when all counts are zero (accuracy is undefined).
  """
  tn = cm[0][0]
  fp = cm[0][1]
  fn = cm[1][0]
  tp = cm[1][1]
  total = tn + fp + fn + tp
  return (tn + tp) / total if total else float('nan')
def recall(cm):
  """Recall tp / (tp + fn) for a 2x2 confusion matrix [[tn, fp], [fn, tp]].

  Returns float('nan') when there are no true positives or false negatives
  (no actual positives), where recall is undefined.
  """
  fn = cm[1][0]
  tp = cm[1][1]
  actual_positives = tp + fn
  return tp / actual_positives if actual_positives else float('nan')
def precision(cm):
  """Precision tp / (tp + fp) for a 2x2 confusion matrix [[tn, fp], [fn, tp]].

  Returns float('nan') when nothing was predicted positive, where precision is undefined.
  """
  fp = cm[0][1]
  tp = cm[1][1]
  predicted_positives = tp + fp
  return tp / predicted_positives if predicted_positives else float('nan')
def all_info(tn, fp, fn, tp):
  """Print accuracy ('a'), recall ('r') and precision ('p') for the given confusion counts."""
  matrix = [[tn, fp],
            [fn, tp]]
  for tag, metric in (('a', accuracy), ('r', recall), ('p', precision)):
    print(tag, metric(matrix))

In [0]:
# Score the saved HOG-augmented model on the test set.
model = BreastHistopathologyClassifierWithHog().to(device)
# map_location lets a checkpoint saved on a GPU load on a CPU-only runtime.
model.load_state_dict(torch.load('LR0_1_Hog_010.pt', map_location=device))

loss, cm = eval_to_confusion_matrix(model, device, test_loader)
print(loss)
print(cm)
print(accuracy(cm))
print(accuracy(cm))

In [0]:
# demo

import random
from mpl_toolkits.axes_grid1 import ImageGrid

def try_img(model, img_path, ground_truth):
  """Classify one image file with `model` and return what is needed to display it.

  Args:
    model: classifier returning (1, 2) logits for a (1, 3, H, W) input.
    img_path: path of the image file.
    ground_truth: known label; returned unchanged for display.

  Returns:
    (PIL image, predicted class index as int, ground_truth).
  """
  model.eval()

  # Copy the pixels into memory so the file is closed before returning; the
  # copy is what gets displayed later.
  with Image.open(img_path) as source:
    img_display = source.copy()

  # Same preprocessing as BreastHistopathologyDataset: scale to [0, 1], HWC -> CHW.
  im = np.array(img_display) / 255.
  im = np.moveaxis(im, -1, -3).astype(np.float32)
  image = torch.from_numpy(im).unsqueeze(0).to(device)
  # Inference only: no autograd graph needed.
  with torch.no_grad():
    pred = model(image)
  return img_display, pred.max(1)[1].item(), ground_truth

def show_grid(imgs, preds, truths):
  """Draw up to 20 images on a 4x5 grid titled with truth/prediction and save it as demo.png."""
  figure = plt.figure(figsize=(12., 12.))
  axes_grid = ImageGrid(figure, 111, nrows_ncols=(4, 5), axes_pad=.8)

  for axis, (picture, predicted, truth) in zip(axes_grid, zip(imgs, preds, truths)):
    axis.imshow(picture)
    axis.set_title('ground truth: {}\nprediction: {}'.format(truth, predicted))
  plt.savefig('demo.png')

model = BreastHistopathologyClassifier().to(device)
# map_location lets a checkpoint saved on a GPU load on a CPU-only runtime.
model.load_state_dict(torch.load('LR0_1_016.pt', map_location=device))

imgs = []
preds = []
truths = []
# 10 random positive test patches, then 10 random negative ones (sampled with
# replacement): 20 images for show_grid's 4x5 grid.
for paths, ground_truth in ((positive_test, 1), (negative_test, 0)):
  for _ in range(10):
    im, p, t = try_img(model, random.choice(paths), ground_truth)
    imgs.append(im)
    preds.append(p)
    truths.append(t)

show_grid(imgs, preds, truths)


In [0]:
# calculate stats
# Confusion counts per run as (tn, fp, fn, tp), presumably copied from earlier evaluations.
# NOTE(review): CNN rows have 39291 negatives while SVM rows have 98227, so the two
# models appear to have been evaluated on different sets. Also, CNN row 2 has
# 13934 positives vs 15754 in the other CNN rows, and SVM row 1 has 38384 positives
# vs 39384 in the other SVM rows. These may be transcription errors; verify them.
cnn_runs = [
    (34891, 4400, 3139, 12615),
    (36297, 2994, 2426, 11508),
    (36647, 2644, 6126, 9628),
    (36144, 3147, 3393, 12361),
]
svm_runs = [
    (64284, 33943, 6114, 32270),
    (49800, 48427, 4606, 34778),
    (54092, 44134, 4919, 34465),
    (56823, 41404, 5733, 33651),
    (0, 98227, 0, 39384),
    (49539, 48688, 4503, 34881),
    (48188, 50039, 4248, 35136),
    (3046, 95181, 274, 39110),
]

# cnn
for counts in cnn_runs:
  all_info(*counts)

# svm
print('svm')
for counts in svm_runs:
  all_info(*counts)