In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import gc
import os
import random
import re
from pathlib import Path

import numpy as np
import torch
from tqdm import tqdm

In [3]:
import cv2
from fastai import *
from fastai.vision import *

In [4]:
# Root folder of the training images. Set the WBC_DATA_DIR environment variable to
# point elsewhere instead of editing this hard-coded default.
path = Path(os.environ.get('WBC_DATA_DIR', '/data/Datasets/WhiteBloodCancer/train/'))

In [5]:
# Seed every RNG the notebook relies on: numpy (also used for the random validation
# split below), Python's `random`, and torch (model initialisation, data shuffling).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [9]:
# Every image file under `path`, searched recursively through sub-folders.
fnames = get_image_files(path, recurse=True)
# Fail fast with a clear message if the dataset path is wrong, rather than an
# obscure error later when the DataBunch is built from an empty file list.
assert len(fnames) > 0, f'No image files found under {path}'
# The label is the 'hem' / 'all' suffix just before the '.bmp' extension (group 1).
# The dot is escaped so it matches only a literal '.', not any character.
pat = re.compile(r'^.*(hem|all)\.bmp$')

In [7]:
# Image size handed to the DataBunch, and the mini-batch size.
size, bs = 224, 64

In [11]:
# Train/validation DataBunch: labels come from group 1 of `pat` matched against each
# file name, and 10% of the files are held out for validation.
# Normalise with ImageNet statistics rather than statistics estimated from a batch:
# the models trained below use an ImageNet-pretrained ResNet34 backbone, which
# expects inputs normalised the same way as during its pre-training.
data = ImageDataBunch.from_name_re(path, fnames, pat, size=size, bs=bs, valid_pct=0.1).normalize(imagenet_stats)

In [12]:
from torch.autograd import Variable

def one_hot_embedding(labels, num_classes):
    """Return a float one-hot encoding of integer class indices.

    Args:
        labels: tensor of integer class indices, on any device.
        num_classes: width of the encoding.

    Returns:
        CPU float tensor of shape ``labels.shape + (num_classes,)`` holding row ``k``
        of the identity matrix for every label ``k``.
    """
    # The identity matrix is built on CPU, so index it with CPU labels; callers move
    # the result to the device they need.
    return torch.eye(num_classes)[labels.detach().cpu()]


class FocalLoss(nn.Module):
    """Sigmoid focal loss (Lin et al., "Focal Loss for Dense Object Detection").

    Targets are one-hot encoded over ``num_classes + 1`` columns and the LAST column
    is dropped. Likewise the last logit column of ``pred`` is dropped, so ``pred`` must
    have ``num_classes + 1`` columns. Each remaining logit is scored as an independent
    binary (sigmoid) problem.

    NOTE(review): with ``num_classes=1`` on a model that outputs 2 logits, only logit 0
    receives gradient. A softmax/argmax over both logits at inference still reads the
    untrained logit 1. Confirm this is intended.

    Args:
        num_classes: number of scored classes; the summed loss is divided by it.
        alpha: weight for positive targets (negatives get ``1 - alpha``).
        gamma: focusing exponent applied to ``(1 - p_t)``.
    """

    def __init__(self, num_classes, alpha=0.25, gamma=1.):
        super().__init__()
        self.num_classes = num_classes
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, pred, targ, reduction='none'):
        """Compute the focal loss for a batch.

        Args:
            pred: logits, expected shape ``(batch, num_classes + 1)``.
            targ: integer class indices, shape ``(batch,)``.
            reduction: accepted only so callers passing ``reduction=`` do not fail.
                It is IGNORED: the result is always the summed loss divided by
                ``num_classes``.

        Returns:
            Scalar tensor on ``pred``'s device.
        """
        # One-hot over num_classes + 1 columns, then drop the last column.
        t = one_hot_embedding(targ, self.num_classes + 1)[:, :-1].contiguous()
        # Follow pred's device/dtype instead of a hard-coded .cuda(), so the loss also
        # works on CPU-only machines and when evaluated on CPU tensors.
        t = t.to(device=pred.device, dtype=pred.dtype)
        x = pred[:, :-1]
        # The focal modulating weight is treated as a constant (no gradient through it).
        w = self.get_weight(x, t).detach()
        # reduction='sum' is the non-deprecated spelling of size_average=False.
        return F.binary_cross_entropy_with_logits(x, t, w, reduction='sum') / self.num_classes

    def get_weight(self, x, t):
        """Per-element focal weight ``alpha_t * (1 - p_t) ** gamma``.

        ``p_t`` is ``sigmoid(x)`` where ``t == 1`` and ``1 - sigmoid(x)`` where
        ``t == 0``. ``alpha_t`` is ``alpha`` for positives and ``1 - alpha`` for
        negatives.
        """
        p = x.sigmoid()
        pt = p * t + (1 - p) * (1 - t)
        w = self.alpha * t + (1 - self.alpha) * (1 - t)
        return w * (1 - pt).pow(self.gamma)

In [None]:
# Grid search over the focal-loss hyper-parameters.
# values[i, j] = number of misclassified validation images when training with
# gamma = gammas[i] and alpha = alphas[j].
# NOTE(review): FocalLoss weights negatives by (1 - alpha), so alpha > 1 gives them a
# negative weight; consider restricting alphas to [0, 1].
# NOTE(review): 20 x 50 = 1000 training runs of 4 epochs each -- a very long run.
alphas = np.arange(0.0, 2.0, 0.1)
gammas = np.arange(0.0, 5.0, 0.1)

# `float` instead of `np.float`, which was removed in NumPy 1.24.
values = np.zeros((len(gammas), len(alphas)), dtype=float)
# Row index follows gammas and column index follows alphas, and each value is passed
# to the matching FocalLoss argument (previously alpha and gamma were swapped).
for i, gamma in enumerate(gammas):
    for j, alpha in enumerate(alphas):
        learn = create_cnn(data, models.resnet34, metrics=[error_rate],
                           loss_func=FocalLoss(num_classes=1, alpha=alpha, gamma=gamma))
        learn.fit_one_cycle(4, 1e-2)

        # Interpret with the DataBunch's default loss: FocalLoss ignores `reduction`
        # and always returns a single scalar.
        learn.loss_func = data.loss_func
        interp = ClassificationInterpretation.from_learner(learn)

        # Misclassifications = every confusion-matrix entry off the diagonal.
        cm = interp.confusion_matrix()
        values[i, j] = int(cm.sum() - np.trace(cm))

        # Release this run's model before the next one is built, so memory does
        # not accumulate across iterations.
        del learn, interp
        gc.collect()
        torch.cuda.empty_cache()

epoch,train_loss,valid_loss,error_rate




In [None]:
import numpy as np
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 -- registers the '3d' projection on older Matplotlib
import matplotlib.pyplot as plt

# values[i, j] is the misclassification count for gammas[i] / alphas[j].
# meshgrid(alphas, gammas) yields grids of the same (len(gammas), len(alphas)) shape.
Xm, Ym = np.meshgrid(alphas, gammas)

fig = plt.figure()
# Recent Matplotlib no longer lets Axes3D(fig) attach itself to the figure, which
# leaves the plot empty; request a 3-D subplot explicitly instead.
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(Xm, Ym, values)
ax.set_xlabel('alpha')
ax.set_ylabel('gamma')
ax.set_zlabel('misclassified validation images')
ax.set_title('Focal loss grid search');