In [4]:
import os
import numpy as np
import torch
import torch.utils.data as data_utils
from torchvision import datasets, transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import random
from datetime import datetime
import torchvision.models as models
from torchsummary import summary
import cv2 as cv
from scipy.stats import binom
import pandas as pd
import pickle

### Attention-MIL model
* The Attention-MIL model consists of three parts:
    * (1) A feature extractor, in most cases a pretrained CNN
    * (2) An attention module, which assigns a weight to every patch in a bag and aggregates the patch features using those weights
    * (3) A final classifier, a one- or two-layer fully connected network (FCN) that produces the final probability score for the bag
* In the following implementation:
    * (1) The class `Attention_modern` implements (2) the attention module and (3) the final classifier, and takes (1) the feature extractor as an input
    * (2) Three feature extractors are provided (VGG, Inception and ResNet). For all of them, I chose to load the pretrained version and freeze the convolutional layers
* For reference, MIL-attention was proposed in this manuscript: [https://arxiv.org/abs/1802.04712](https://arxiv.org/abs/1802.04712).

In [7]:
def load_vgg16():
    """Load the pretrained VGG-16 from torch.hub and freeze its leading children.

    The first two top-level child modules (in ``children()`` order) have
    ``requires_grad`` set to ``False`` on all of their parameters; every
    later child is left untouched.

    Returns:
        The module returned by ``torch.hub.load`` with those parameters frozen.
    """
    vgg = torch.hub.load('pytorch/vision:v0.6.0', 'vgg16', pretrained=True)
    # Children 1 and 2 are the ones whose 1-based position is below 3.
    frozen_children = list(vgg.children())[:2]
    for block in frozen_children:
        for weight in block.parameters():
            weight.requires_grad = False
    return vgg
def load_inception():
    """Load the pretrained Inception-v3 from torch.hub and freeze its leading children.

    The first nine top-level child modules (in ``children()`` order) have
    ``requires_grad`` set to ``False`` on all of their parameters; every
    later child is left untouched.

    Returns:
        The module returned by ``torch.hub.load`` with those parameters frozen.
    """
    inception = torch.hub.load('pytorch/vision:v0.6.0', 'inception_v3', pretrained=True)
    # Children 1..9 are the ones whose 1-based position is below 10.
    frozen_children = list(inception.children())[:9]
    for block in frozen_children:
        for weight in block.parameters():
            weight.requires_grad = False
    return inception
def load_resnet18():
    """Load the pretrained ResNet-18 from torch.hub and freeze its leading children.

    The first seventeen top-level child modules (in ``children()`` order)
    have ``requires_grad`` set to ``False`` on all of their parameters; any
    later child is left untouched.

    NOTE(review): torchvision's resnet18 appears to expose fewer than 17
    top-level children, so this presumably freezes the whole network,
    including the final ``fc`` layer -- confirm that this is intended.

    Returns:
        The module returned by ``torch.hub.load`` with those parameters frozen.
    """
    resnet = torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True)
    # Children 1..17 are the ones whose 1-based position is below 18.
    frozen_children = list(resnet.children())[:17]
    for block in frozen_children:
        for weight in block.parameters():
            weight.requires_grad = False
    return resnet

class Attention_modern(nn.Module):
    """Attention-based multiple-instance-learning head over a supplied feature extractor.

    The module scores every instance of a bag with a small attention MLP,
    softmax-normalises the scores across the bag, uses them to form a
    weighted sum of the instance features, and feeds that aggregate to a
    single-output sigmoid classifier.

    Args:
        cnn: Feature extractor module. Its output is passed to
            ``nn.Linear(1000, 64)`` and to ``torch.mm``, so it must return a
            2-D tensor with 1000 columns (one row per instance).
        focal_loss: When ``True``, ``calculate_objective`` uses the focal
            variant of the loss instead of the plain binary cross-entropy.
    """

    def __init__(self,cnn,focal_loss=False):
        super(Attention_modern,self).__init__()
        self.L = 1000  # feature width expected from the feature extractor
        self.D = 64    # hidden width of the attention MLP
        self.K = 1     # number of attention scores produced per instance
        self.focal_loss = focal_loss
        self.feature_extractor = cnn
        self.attention = nn.Sequential(
            nn.Linear(self.L, self.D),
            nn.Tanh(),
            nn.Linear(self.D, self.K))
        self.classifier = nn.Sequential(
            nn.Linear(self.L*self.K, 1),
            nn.Sigmoid()
        )

    def forward(self,x):
        """Run one bag through the model.

        Args:
            x: Bag of instances. A leading dimension of size 1 is squeezed
                away before the feature extractor is applied.

        Returns:
            Tuple ``(Y_prob, Y_hat, A)``: the sigmoid output of the
            classifier for the first aggregated row, that output thresholded
            at 0.5 (as float), and the attention weights with one row per
            attention score and one column per instance.
        """
        x = x.squeeze(0)
        H = self.feature_extractor(x)       # one feature row per instance
        A = self.attention(H)               # one score column per attention head
        A = torch.transpose(A,1,0)          # heads as rows, instances as columns
        A = F.softmax(A,dim=1)              # weights sum to 1 across the instances
        M = torch.mm(A,H)                   # attention-weighted sum of instance features
        Y_prob = self.classifier(M)[0]
        Y_hat = torch.ge(Y_prob,0.5).float()
        return Y_prob, Y_hat, A

    def calculate_classification_error(self, X, Y):
        """Return ``(error, Y_hat, Y_prob)`` for one bag.

        ``error`` is a Python float: one minus the mean agreement between the
        thresholded prediction and the label ``Y``.
        """
        Y = Y.float()
        Y_prob, Y_hat,_ = self.forward(X)
        error = 1. - Y_hat.eq(Y).cpu().float().mean().item()

        return error, Y_hat, Y_prob

    def calculate_objective(self, X, Y):
        """Compute the training loss for one bag.

        With ``focal_loss`` disabled this is the binary cross-entropy between
        the clamped bag probability and ``Y``. With it enabled, the loss is
        ``-(1 - p_t) ** gamma * log(p_t)`` where ``p_t`` is the probability
        assigned to the true label and ``gamma`` is 5 when ``p_t < 0.2`` and
        3 otherwise.

        Args:
            X: Bag of instances, as accepted by ``forward``.
            Y: Bag label tensor. Both 0-dim and 1-element tensors are
                accepted; in the focal branch only its first element is read.

        Returns:
            Tuple ``(neg_log_likelihood, A)`` with the loss tensor and the
            attention weights from ``forward``.
        """
        Y = Y.float()
        Y_prob, _, A = self.forward(X)
        # Keep the probability away from 0 and 1 so the logs stay finite.
        Y_prob = torch.clamp(Y_prob, min=1e-5, max=1. - 1e-5)
        if not self.focal_loss:
            neg_log_likelihood = -1. * (Y * torch.log(Y_prob) + (1. - Y) * torch.log(1. - Y_prob))
            return neg_log_likelihood, A

        # reshape(-1) lets a 0-dim label be read the same way as a 1-element
        # one; indexing the numpy view directly failed for 0-dim labels.
        label = Y.detach().reshape(-1)[0].item()
        # p_t: probability the model assigns to the true class.
        p_t = 1 - Y_prob if label == 0 else Y_prob
        # gamma is chosen from a detached scalar, so it is a constant for autograd.
        gamma = 5 if p_t.detach().reshape(-1)[0].item() < 0.2 else 3
        neg_log_likelihood = -1. * (1 - p_t) ** gamma * torch.log(p_t)
        return neg_log_likelihood, A

In [8]:
# Build the attention-MIL model on the VGG-16 feature extractor with the
# focal loss enabled, then print its layer layout.
model = Attention_modern(cnn=load_vgg16(), focal_loss=True)
print(model)

Using cache found in /cis/home/zwang/.cache/torch/hub/pytorch_vision_v0.6.0


Attention_modern(
  (feature_extractor): VGG(
    (features): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): ReLU(inplace=True)
      (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (6): ReLU(inplace=True)
      (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (8): ReLU(inplace=True)
      (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): ReLU(inplace=True)
      (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (13): ReLU(inplace=True)
      (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (15): ReLU(inpla