In [1]:
import pyvips
import torch, torchvision
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset as Dataset
from torch.autograd import Variable
import torch.nn.functional as F

from torchvision import datasets, models
import torchvision.transforms as transforms

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow
plt.switch_backend('agg')

import time, os, copy, datetime, glob
from tqdm import tqdm
from skimage import io
from PIL import Image
Image.MAX_IMAGE_PIXELS = None

In [2]:
format_to_dtype = {
    'uchar': np.uint8,
    'char': np.int8,
    'ushort': np.uint16,
    'short': np.int16,
    'uint': np.uint32,
    'int': np.int32,
    'float': np.float32,
    'double': np.float64,
    'complex': np.complex64,
    'dpcomplex': np.complex128,
}

# vips image to numpy array
def vips2numpy(vi):
    return np.ndarray(buffer=vi.write_to_memory(),
                      dtype=format_to_dtype[vi.format],
                      shape=[vi.height, vi.width, vi.bands])
NORM_PATH = '/BrainSeg/normalization.npy'
norm = np.load(NORM_PATH,allow_pickle=True).item()
print(norm)

trans = transforms.Compose([
#                               transforms.RandomHorizontalFlip(),
#                               transforms.RandomVerticalFlip(),
#                               transforms.RandomRotation(180),
#                               transforms.ColorJitter(brightness=0.1, contrast=0.2,saturation=0.2, hue=0.02),
#                               transforms.RandomAffine(0, translate=(0.05,0.05), scale=(0.9,1.1), shear=10),
                              transforms.ToTensor(),
                              transforms.Normalize(norm['mean'], norm['std'])
                              ])

{'mean': array([0.77906426, 0.74919518, 0.77529276]), 'std': array([0.13986633, 0.15931302, 0.17665639])}


In [3]:
import torch
from torch import nn
### comments starting with "###" are my (Toluwa's) notes

class CRF(nn.Module):
    def __init__(self, num_nodes, iteration=10):
        """Initialize the CRF module
        Args:
            num_nodes: int, number of nodes/patches within the fully CRF
            iteration: int, number of mean field iterations, e.g. 10
        """
        super(CRF, self).__init__()
        self.num_nodes = num_nodes
        self.iteration = iteration
        self.W = nn.Parameter(torch.zeros(1, num_nodes, num_nodes))

    def forward(self, feats, logits):
        """Performing the CRF. Algorithm details is explained below:
        Within the paper, I formulate the CRF distribution using negative
        energy and cost, e.g. cosine distance, to derive pairwise potentials
        following the convention in energy based models. But for implementation
        simplicity, I use reward, e.g. cosine similarity to derive pairwise
        potentials. So now, pairwise potentials would encourage high reward for
        assigning (y_i, y_j) with the same label if (x_i, x_j) are similar, as
        measured by cosine similarity, pairwise_sim. For
        pairwise_potential_E = torch.sum(
            probs * pairwise_potential - (1 - probs) * pairwise_potential,
            dim=2, keepdim=True
        )
        This is taking the expectation of pairwise potentials using the current
        marginal distribution of each patch being tumor, i.e. probs. There are
        four cases to consider when taking the expectation between (i, j):
        1. i=T,j=T; 2. i=N,j=T; 3. i=T,j=N; 4. i=N,j=N
        probs is the marginal distribution of each i being tumor, therefore
        logits > 0 means tumor and logits < 0 means normal. Given this, the
        full expectation equation should be:
        [probs * +pairwise_potential] + [(1 - probs) * +pairwise_potential] +
                    case 1                            case 2
        [probs * -pairwise_potential] + [(1 - probs) * -pairwise_potential]
                    case 3                            case 4
        positive sign rewards logits to be more tumor and negative sign rewards
        logits to be more normal. But because of label compatibility, i.e. the
        indicator function within equation 3 in the paper, case 2 and case 3
        are dropped, which ends up being:
        probs * pairwise_potential - (1 - probs) * pairwise_potential
        In high level speaking, if (i, j) embedding are different, then
        pairwise_potential, as computed as cosine similarity, would approach 0,
        which then as no affect anyway. if (i, j) embedding are similar, then
        pairwise_potential would be a positive reward. In this case,
        if probs -> 1, then pairwise_potential promotes tumor probability;
        if probs -> 0, then -pairwise_potential promotes normal probability.
        Args:
            feats: 3D tensor with the shape of
            [batch_size, num_nodes, embedding_size], where num_nodes is the
            number of patches within a grid, e.g. 9 for a 3x3 grid;
            embedding_size is the size of extracted feature representation for
            each patch from ResNet, e.g. 512
            logits: 3D tensor with shape of [batch_size, num_nodes, 1], the
            logit of each patch within the grid being tumor before CRF
        Returns:
            logits: 3D tensor with shape of [batch_size, num_nodes, 1], the
            logit of each patch within the grid being tumor after CRF
        """
        ###We can formulate the above as 0 for bg
        ###0.5 for white matter, 1 for grey matter
        ###p=2 means nuclear norm
        #print("input logits are with shape", logits, logits.shape)
        feats_norm = torch.norm(feats, p=2, dim=2, keepdim=True)
        pairwise_norm = torch.bmm(feats_norm,
                                  torch.transpose(feats_norm, 1, 2))
        pairwise_dot = torch.bmm(feats, torch.transpose(feats, 1, 2))
        # cosine similarity between feats
        pairwise_sim = pairwise_dot / pairwise_norm
        # symmetric constraint for CRF weights
        W_sym = (self.W + torch.transpose(self.W, 1, 2)) / 2
        pairwise_potential = pairwise_sim * W_sym
        unary_potential = logits.clone()

        for i in range(self.iteration):
            # current Q after normalizing the logits
            ###probs = torch.transpose(logits.sigmoid(), 1, 2)
            #print("logits before", logits.shape, pairwise_potential.shape)
            probs = torch.transpose(logits.softmax(2, torch.float32), 1, 2)
            #print("logits now", probs, probs.shape, pairwise_potential.shape)
            
            # taking expectation of pairwise_potential using current Q
            ###Toluwa - this may need updating because original was formulated assuming there were only 2 classes
            
            ##Probability that it belongs to class 0
            pairwise_potential_E0 = torch.sum(
                probs[:,0:1,:] * pairwise_potential - (1 - probs[:,0:1,:]) * pairwise_potential,
                dim=2, keepdim=True)
            
            ##Probability that it belongs to class 1
            pairwise_potential_E1 = torch.sum(
            probs[:,1:2,:] * pairwise_potential - (1 - probs[:,0:1,:]) * pairwise_potential,
            dim=2, keepdim=True)
            
            ##Probability that it belongs to class 2          
            pairwise_potential_E2 = torch.sum(
            probs[:,2:,:] * pairwise_potential - (1 - probs[:,0:1,:]) * pairwise_potential,
            dim=2, keepdim=True)
            
            pairwise_potential_E = torch.cat((pairwise_potential_E0, pairwise_potential_E1), 2)
            pairwise_potential_E = torch.cat((pairwise_potential_E, pairwise_potential_E2), 2)
            
            #print("unary potential has shape", unary_potential.shape, pairwise_potential_E.shape)
            logits = unary_potential + pairwise_potential_E

        #print("Logits shape is", logits.shape, logits)
        return logits

    def __repr__(self):
        return 'CRF(num_nodes={}, iteration={})'.format(
            self.num_nodes, self.iteration
        )

In [4]:
import torch
import torch.nn as nn
import math

__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152']


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=3, num_nodes=1,
                 use_crf=True):
        """Constructs a ResNet model.
        Args:
            num_classes: int, since we are doing binary classification
                (tumor vs normal), num_classes is set to 1 and sigmoid instead
                of softmax is used later
            num_nodes: int, number of nodes/patches within the fully CRF
            use_crf: bool, use the CRF component or not
        """
        ###Jokes, we're NOT doing binary classification so modify this for multi-classification
        ###Mainly we're doing softmax and setting num_classes to 3

        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)
        #print("Model debug block", block, block.expansion)
        self.crf = CRF(num_nodes) if use_crf else None

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        """
        Args:
            x: 5D tensor with shape of
            [batch_size, grid_size, 3, crop_size, crop_size],
            where grid_size is the number of patches within a grid (e.g. 9 for
            a 3x3 grid); crop_size is 224 by default for ResNet input;
        Returns:
            logits, 2D tensor with shape of [batch_size, grid_size], the logit
            of each patch within the grid being tumor
        """
        #print("X shape is", x.shape)
        batch_size, grid_size, _, crop_size = x.shape[0:4]
        # flatten grid_size dimension and combine it into batch dimension
        x = x.view(-1, 3, crop_size, crop_size)

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        # feats means features, i.e. patch embeddings from ResNet
        feats = x.view(x.size(0), -1)
        #print("feats shape", feats.shape, feats[0].shape)
        #print("feats are", feats)

        logits = self.fc(feats)

        # restore grid_size dimension for CRF
        feats = feats.view((batch_size, grid_size, -1))
        logits = logits.view((batch_size, grid_size, -1))

        if self.crf:
            logits = self.crf(feats, logits)
            
        #print("Final logits shape before squeezeis ", logits.shape, logits)

        logits = torch.squeeze(logits)
        ##Toluwa adding this to return which class we're picking
        ##print("Final logits shape is ", logits.shape, logits)
        #_, logits = torch.max(logits, 2)
        #print("Final logits shape 2 is", logits.shape, logits)
        return logits


def resnet18(**kwargs):
    """Constructs a ResNet-18 model.
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)

    return model


def resnet34(**kwargs):
    """Constructs a ResNet-34 model.
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)

    return model


def resnet50(**kwargs):
    """Constructs a ResNet-50 model.
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)

    return model


def resnet101(**kwargs):
    """Constructs a ResNet-101 model.
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)

    return model


def resnet152(**kwargs):
    """Constructs a ResNet-152 model.
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)

    return model

In [5]:
MODELS = {'resnet18': resnet18,
          'resnet34': resnet34,
          'resnet50': resnet50,
          'resnet101': resnet101,
          'resnet152': resnet152}

In [6]:
BATCH_SIZE = 1
PATCH_SIZE = 3 

IMG_HEIGHT = PATCH_SIZE*256
IMG_WIDTH  = PATCH_SIZE*256
cfg = {
 "model": "resnet18",
 "use_crf": True,
 "batch_size": BATCH_SIZE,
 "image_size": IMG_HEIGHT,
 "patch_size": 256,
 "crop_size": 768,
 "lr": 0.0001,
 "momentum": 0.9,
 "epoch": 20,
 "log_every": 100
}


if cfg['image_size'] % cfg['patch_size'] != 0:
    raise Exception('Image size / patch size != 0 : {} / {}'.format(cfg['image_size'], cfg['patch_size']))

patch_per_side = cfg['image_size'] // cfg['patch_size']
grid_size = patch_per_side * patch_per_side

In [21]:
BASE_DIR = '/BrainSeg/'
IMG_DIR  = BASE_DIR + 'norm_png/'
SAVE_DATA_DIR = BASE_DIR + 'Classify_Results/CRF_0913/GMBData/'
if not os.path.exists(SAVE_DATA_DIR):
    os.makedirs(SAVE_DATA_DIR)
print(SAVE_DATA_DIR)
SAVE_IMG_DIR = BASE_DIR + 'Classify_Results/CRF_0913/GMBMasks/'
if not os.path.exists(SAVE_IMG_DIR):
    os.makedirs(SAVE_IMG_DIR)
print(SAVE_IMG_DIR)

filenames = glob.glob(IMG_DIR + '*.png')
filenames = [filename.split('/')[-1] for filename in filenames]
filenames = sorted(filenames)
total_image_num = len(filenames)

idx_image_to_eval = slice(21,22)      # Modify this line to select images to eval
filenames = filenames[idx_image_to_eval]
print('Run for the following %d images out of %d images:\n' % (len(filenames), total_image_num), filenames)

/BrainSeg/Classify_Results/CRF_0913/GMBData/
/BrainSeg/Classify_Results/CRF_0913/GMBMasks/
Run for the following 1 images out of 30 images:
 ['NA4945-02_AB17-24.png']


In [9]:
def act(image):
    image = trans(image)
    image = np.array(image, dtype=np.float32).transpose((0, 1, 2))
    _patch_size = 256
    _crop_size = 224
    #flatten it 
    img_flat = np.zeros(
    (9, 3, _crop_size, _crop_size),
    dtype=np.float32)

    ### This part is from the NCRF code
    idx = 0
    for x_idx in range(3):
        for y_idx in range(3):
            # center crop each patch
            x_start = int(
                (x_idx + 0.5) * _patch_size - _crop_size / 2)
            x_end = x_start + _crop_size
            y_start = int(
                (y_idx + 0.5) * _patch_size - _crop_size / 2)
            y_end = y_start + _crop_size
            img_flat[idx] = image[:, x_start:x_end, y_start:y_end]

            idx += 1

    return img_flat

In [10]:
model = torch.load('/BrainSeg/Codes/CRF/LF0811_Checkpoints/PatchedCRF_3x3_23.pkl')
model = model.cuda()
model = model.eval()

In [20]:
total = sum([param.nelement() for param in model.parameters()])
# print("Number of parameter: %.6fM" % (total/1e6))
print("Number of parameter: %.1f" % (total))

Number of parameter: 11178132.0


In [10]:
WSI = pyvips.Image.new_from_file(IMG_DIR + filenames[0])

In [15]:
print('0')
img = WSI.extract_area(0,0,768,768)
print('1')
# img.write_to_file('/BrainSeg/Codes/CRF/Checkpoints/img.png')
# img = Image.open('/BrainSeg/Codes/CRF/Checkpoints/img.png')
img = vips2numpy(img)
img = Image.fromarray(img, 'RGB')

img = act(img)
print('4')
img = torch.from_numpy(img)
print('5')            
#             print(tile_img)
tile_img = img.unsqueeze(0)

tile_img = Variable(tile_img)
print('6')
tile_img = tile_img.cuda()
print('7')
predict = model(tile_img)
print('8')
_, p = torch.max(predict.data, 1)
print(p[4].item())

0
1
4
5
6
7
8
0


In [42]:
# pred = predict[4,:]
print(pred)
_, p = torch.max(predict.data, 1)
print(p[4].item())

tensor([ 4.0174, -1.8069, -0.2735], device='cuda:0', grad_fn=<SliceBackward>)
0


In [22]:
# Start evaluating
t = tqdm(total=len(filenames))
# print(filenames)
for filename in filenames:
  # print(filename)
  # in_img = np.array(Image.open(IMG_DIR + filename)) # Out of RAM
    in_img = pyvips.Image.new_from_file(IMG_DIR + filename)
  
    t.set_description_str("Image " + filename + ' - (%d, %d)' % (in_img.width, in_img.height))
    t.refresh()
    t.write("", end=' ')

    num_cols = int((in_img.width-768)/128)
    num_rows = int((in_img.height-768)/128)
    nums = np.zeros((num_rows, num_cols), dtype='uint8')
#     PRED = np.zeros((num_rows, num_cols, 3), dtype='uint8')

    for i in range(num_rows):
        for j in range(num_cols):
#             w, h = 256, 256
            w, h = 768, 768
#             if i == num_rows - 1: # if at last row
#                 h = 128
#             if j == num_cols - 1: # if at last column
#                 w = 128
            
            tile_img = in_img.extract_area(128*j,128*i,w,h) # c, r, w, h
            tile_img = vips2numpy(tile_img)
#             print(tile_img.shape)
#             if i == num_rows - 1: # if at last row
#                 tile_img = np.pad(tile_img, ((0, 128),(0, 0),(0, 0)), 'constant', constant_values=((0, 0),))
       
#             if j == num_cols - 1: # if at last column
#                 tile_img = np.pad(tile_img, ((0, 0),(0, 128),(0, 0)), 'constant', constant_values=((0, 0),))
      
            tile_img = Image.fromarray(tile_img, 'RGB')
            tile_img = act(tile_img)
            tile_img = torch.from_numpy(tile_img)
            
#             print(tile_img)
            tile_img = tile_img.unsqueeze(0)
            tile_img = Variable(tile_img)
            tile_img = tile_img.cuda()
            predict = model(tile_img)

            # a = predict.data[0,0].item()
            # b = predict.data[0,1].item()
            # c = predict.data[0,2].item()
#             PRED[i, j, 0] = predict.data[0,0].item()
#             PRED[i, j, 1] = predict.data[0,1].item()
#             PRED[i, j, 2] = predict.data[0,2].item()
            # print('a:',a,'b:',b,'c',c)
            # print('predict.data',predict.data)
#             _, predict = torch.max(predict.data, 1)
            _, p = torch.max(predict.data, 1)
            value = p[4].item()
#             print(value)
            #print(type(value)) #<class 'int'>
            nums[i, j] = value
      
    np.save(SAVE_DATA_DIR + filename.split('.')[-2] + '.npy', nums)
#     np.save(SAVE_DATA_DIR + filename.split('.')[-2] + 'value.npy', PRED)
    nums = np.repeat(nums[:, :, np.newaxis], 3, axis=2)
    #print(nums.shape, nums.dtype)

    # nums[:,:,0] = RED, nums[:,:,1] = Green, nums[:,:,2] = Blue
    idx_1 = np.where(nums[:,:,0] == 1)  # Index of label 1 (WM)
    idx_2 = np.where(nums[:,:,0] == 2)  # Index of label 2 (GM)

    # For label 0, leave as black color
    # For label 1, set to cyan color: R0G255B255
    nums[:,:,0].flat[np.ravel_multi_index(idx_1, nums[:,:,0].shape)] = 0
    nums[:,:,1].flat[np.ravel_multi_index(idx_1, nums[:,:,1].shape)] = 255
    nums[:,:,2].flat[np.ravel_multi_index(idx_1, nums[:,:,2].shape)] = 255
    # For label 2, set to yellow color: R255G255B0
    nums[:,:,0].flat[np.ravel_multi_index(idx_2, nums[:,:,0].shape)] = 255
    nums[:,:,1].flat[np.ravel_multi_index(idx_2, nums[:,:,1].shape)] = 255
    nums[:,:,2].flat[np.ravel_multi_index(idx_2, nums[:,:,2].shape)] = 0

    save_img = Image.fromarray(nums, 'RGB')
    save_img.save(SAVE_IMG_DIR + filename.split('.')[-2] + '.png')

    t.update()
t.close()

Image NA4945-02_AB17-24.png - (49800, 43128):   0%|          | 0/1 [00:00<?, ?it/s]

 

Image NA4945-02_AB17-24.png - (49800, 43128): 100%|██████████| 1/1 [56:27<00:00, 3387.55s/it]
