In [None]:
import torch

In [None]:
# Show the installed PyTorch version.
torch.__version__

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms

In [None]:
import os
import numpy as np
import cv2
from skimage.io import imread, imsave
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt

**Import model-building modules**

In [None]:
import torch.nn as nn
import torch.nn.functional as F

from torchvision import models

**Define the model**

In [None]:
class VGG19(nn.Module):
    """VGG-19 convolutional trunk followed by a three-layer fully connected head.

    The convolutional ``features`` and the ``avgpool`` layer are taken from
    torchvision's VGG-19; the head (25088 -> 4096 -> 4096 -> 5) is defined here.

    Args:
        pretrained: forwarded to ``models.vgg19`` to decide whether torchvision
            should load its own pretrained weights into the trunk.
    """

    def __init__(self, pretrained=False):
        super().__init__()

        # Reuse only the trunk and the pooling layer of torchvision's VGG-19.
        base = models.vgg19(pretrained=pretrained)
        self.backbone = base.features
        self.avgpool = base.avgpool

        # Fully connected head ending in five output logits.
        self.fc1 = nn.Linear(25088, 4096)
        self.fc2 = nn.Linear(4096, 4096)
        self.fc3 = nn.Linear(4096, 5)

    def forward(self, x):
        """Return ``(features, logits)`` for a batch ``x``.

        ``features`` is the ReLU-activated output of ``fc2`` and ``logits`` is
        the output of ``fc3`` applied to those features.
        """
        pooled = self.avgpool(self.backbone(x))
        flat = pooled.flatten(1)  # keep the batch axis, merge everything else

        hidden = F.relu(self.fc1(flat))
        features = F.relu(self.fc2(hidden))

        return features, self.fc3(features)

**Load pre-trained checkpoint**

In [None]:
# Build the architecture without torchvision's own pretrained weights; the
# parameters are taken from the checkpoint file instead.
net = VGG19(pretrained=False)
# map_location="cpu" lets a checkpoint that was saved from a GPU load on a
# CPU-only machine; the model is moved to the selected device in a later cell.
net.load_state_dict(torch.load("../../ckpt/vgg19_wsi224.pt", map_location="cpu"))
# Switch to inference mode.
net.eval()

In [None]:
# Use the first CUDA GPU when one is available, otherwise stay on the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
net.to(device)

**Evaluate model with pre-trained weights on the original scale**

In [None]:
# Validation-time preprocessing: convert to a tensor and normalize each channel
# with the mean/std values conventionally used for ImageNet-trained torchvision
# models. No resizing is done here, so the images in the folder are expected to
# share one size (presumably 224x224 given the checkpoint name -- TODO confirm).
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

batch_size = 32

# ImageFolder derives the class label of each image from its sub-directory.
testset = torchvision.datasets.ImageFolder("../../data/path-dt-msu-wsi/val", transform=transform)
# Shuffling brings nothing to evaluation; a fixed order keeps runs reproducible.
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)


In [None]:
# Top-1 accuracy of the loaded network over everything in `testloader`.
correct = 0
total = 0

# Inference only, so gradient tracking is switched off.
with torch.no_grad():
    for inputs, labels in tqdm(testloader):
        inputs, labels = inputs.to(device), labels.to(device)

        # The network returns (features, logits); only the logits are needed.
        outputs = net(inputs)[1]

        # Predicted class = index of the largest logit in each row.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# True division keeps the fractional part of the percentage, and an empty
# loader yields NaN instead of raising ZeroDivisionError.
accuracy = 100 * correct / total if total else float('nan')
print(f'Accuracy of the network: {accuracy:.2f} %')

**Testing on real multi-scale WSI via OpenSlide library**

In [None]:
# Directory holding the OpenSlide DLLs (only relevant on Windows). It can be
# overridden with the OPENSLIDE_PATH environment variable so the notebook does
# not depend on the directory layout of one particular machine.
import os

OPENSLIDE_PATH = os.environ.get('OPENSLIDE_PATH', r'C:\tools\openslide-win64-20171122\bin')

if hasattr(os, 'add_dll_directory') and os.path.isdir(OPENSLIDE_PATH):
    # Python >= 3.8 on Windows: the DLL directory has to be registered
    # explicitly while the openslide module is imported.
    with os.add_dll_directory(OPENSLIDE_PATH):
        import openslide
else:
    # Other platforms, or no such directory: rely on the normal library lookup.
    import openslide
    
import large_image

from openslide import open_slide
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

In [None]:
import os

import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Location of the whole-slide image (.svs) that is opened in the next cell.
wsi_path = '../../data/wsi_image.svs'

In [None]:
# Open the slide file through OpenSlide's open_slide helper.
slide = open_slide(wsi_path)

In [None]:
# Show the properties (metadata) reported for the slide.
slide.properties

In [None]:
# Show the dimensions of every level of the slide.
slide.level_dimensions

In [None]:
def maxmin_norm(img):
    """Linearly rescale ``img`` so its minimum maps to 0 and its maximum to 1.

    Args:
        img: array-like of numeric values (any shape, at least one element).

    Returns:
        ``np.float32`` array with the same shape as ``img``. A constant input
        (max == min) has no defined min-max scaling; it is returned as all
        zeros instead of the NaNs a division by zero would produce.
    """
    img = np.asarray(img, dtype=np.float32)
    lo = img.min()
    span = img.max() - lo
    if span == 0:
        # E.g. a completely blank region: avoid 0/0 -> NaN.
        return np.zeros_like(img)
    return ((img - lo) / span).astype(np.float32)

In [None]:
def get_image_scale(wsi, x_c, y_c, h, w):
    """Read a level-0 region of ``wsi`` centred on ``(x_c, y_c)`` and normalize it.

    Args:
        wsi: slide object exposing ``read_region(location, level, size)``
            (e.g. an OpenSlide handle).
        x_c, y_c: centre of the region in level-0 pixel coordinates.
        h, w: height and width of the region in level-0 pixels.

    Returns:
        RGB array of shape ``(h, w, 3)`` rescaled by ``maxmin_norm``.
    """
    # OpenSlide addresses a region by its top-left corner given as (x, y) and
    # takes the size as (width, height): x pairs with w, y pairs with h.
    x_0 = x_c - w // 2
    y_0 = y_c - h // 2

    # Read from the slide that was passed in, not from a notebook global.
    region = wsi.read_region((x_0, y_0), level=0, size=(w, h))
    # Drop the alpha channel before converting to an array.
    region = np.array(region.convert('RGB'))

    return maxmin_norm(region)

**Play a little bit with different scales extraction**

In [None]:
# Centre of the probed location, in level-0 pixel coordinates.
# Alternative (x_c, y_c) locations kept for reference:
#   (2000, 5000), (16000, 16000), (9000, 9000)
x_c, y_c = 2000, 6000

# Base patch size; it is multiplied by the scale factor when regions are read.
h_orig = w_orig = 64

In [None]:
# Extract a region `scale` times larger than the base patch along each side.
scale = 10

smaller_region = get_image_scale(slide, x_c, y_c, h=h_orig * scale, w=w_orig * scale)

In [None]:
# Preview the extracted region.
plt.imshow(smaller_region)

**Feed multi-scale samples to the fine-tuned NN**

In [None]:
from PIL import Image
from matplotlib import cm

In [None]:
# Preprocessing for regions cut from the slide: resize, take the central
# 224x224 crop, convert to a tensor and normalize each channel.
apply_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

**Review predictions and model's confidence in the predictions on various scales**

In [None]:
# Classify the same slide location at growing fields of view and keep the
# fc2 feature vector of every pass for the analysis further down.
X = []
SCALE_GRID = [1, 2, 4, 8, 10, 20, 30, 40, 50]

# Inference only: no_grad avoids building an autograd graph on every pass.
with torch.no_grad():
    for scale in SCALE_GRID:
        slide_region = get_image_scale(slide, x_c, y_c, h_orig*scale, w_orig*scale)
        # The region comes back as normalized floats; turn it into an 8-bit
        # image so the PIL-based transforms can be applied.
        img = Image.fromarray((np.clip(slide_region, 0, 1) * 255).astype(np.uint8))

        # [None, ...] adds the batch axis the network expects.
        img_t = apply_transform(img)[None, ...].to(device)
        f, logits = net(img_t)
        X.append(f[0].cpu().numpy())

        probs = F.softmax(logits, dim=-1)
        pred = probs.argmax(dim=-1).item()

        print("Scale: {0}; Prediction: {1}; Probability: {2:.2f}".format(scale,
                                                                         pred,
                                                                         probs[0, pred].item() * 100))

**Look into deep features activation...**

In [None]:
# Weight matrix of the last layer, transposed so that each column holds the
# weight vector of one output class.
W = net.fc3.weight.detach().cpu().numpy().T
W.shape

In [None]:
# Euclidean (L2) length of every column of W: with an integer axis and the
# default `ord`, norm() treats each column as a vector.
l2_norm = np.linalg.norm(W, axis=0)
l2_norm.shape

In [None]:
# Turn the per-scale feature vectors collected earlier into one array
# (one row per entry of SCALE_GRID).
X = np.array(X)
X.shape

In [None]:
# Project every feature vector onto each class weight vector and divide by
# the length of that weight vector (row vector broadcast over the scales).
dot = X.dot(W) / l2_norm.reshape(1, -1)

In [None]:
# Keep only the positive projections; the remaining entries become zero.
dot_pos = dot * (dot > 0)
# Last expression of the cell, so the notebook displays the shape.
dot.shape

In [None]:
# One curve per output class: normalized projection of the features onto that
# class's weight vector, as a function of the extraction scale.
fig, ax = plt.subplots()
for i in range(dot.shape[1]):  # one column of `dot` per class
    ax.plot(SCALE_GRID, dot[:, i], label=str(i))
ax.set_xlabel("Scale factor")
ax.set_ylabel("Normalized projection onto class weight")
ax.set_title("Class response vs. extraction scale")
ax.legend(title="Class")

In [None]:
# Write the figure to out.png (relative to the current working directory).
fig.savefig("out.png")

**Visualize an entire WSI image**

In [None]:
# Request a thumbnail of the whole slide with a requested size of 600x600.
slide_thimb_600 = slide.get_thumbnail(size=(600,600))

In [None]:
# Display the thumbnail.
slide_thimb_600