In [1]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [48]:
import os
from pathlib import Path
import skimage.external.tifffile as tiff

from common import Statistics, dataset_source
from resources.conv_learner import *
from resources.plots import *
from typing import Union, List

In [3]:
# Dataset root as a plain string; this is the form passed to get_data and prepare_from_path.
PATH = "datasets/yeast_v4.2/"
# Same location as a pathlib.Path, for the directory-walking helpers.
data_path = Path(PATH)

In [4]:
# Genotype labels. NOTE(review): the folder names printed further down carry a
# "Cit1_MC_" prefix in front of these labels — confirm which form the data loader uses.
CLASSES = ('WT', 'mfb1KO', 'mfb1KO_mmr1KO', 'mmr1KO')
NUM_CLASSES = len(CLASSES)
# Batch size handed to get_data as `bs`.
BATCH_SIZE = 64
# Image size handed to get_data as `sz`.
SIZE = 200

In [44]:
def Xdataset_source(source: Path) -> dict:
    """Collect the class sub-directories of every dataset split.

    Looks inside ``source/train``, ``source/val`` and ``source/test`` and keeps
    each immediate child that is a directory.

    Args:
        source: root folder of the dataset.

    Returns:
        Dict keyed by split folder name ("train", "val", "test"); each value is
        the list of class-directory Paths in the order ``iterdir()`` yields them.

    Raises:
        FileNotFoundError: propagated from ``Path.iterdir`` when a split folder
            does not exist.
    """
    dir_dict = {}
    for split in ("train", "val", "test"):
        ds_dir = source / split
        # is_dir has to be *called*: the bare bound method is always truthy,
        # which let plain files slip into the class-directory list.
        dir_dict[ds_dir.name] = [child for child in ds_dir.iterdir() if child.is_dir()]
    return dir_dict
            


In [45]:
def Xread_class(class_dir: Path, class_images: List):
    """Read every TIFF file in ``class_dir`` and append it to ``class_images``.

    Args:
        class_dir: folder holding the images of one class.
        class_images: list that is extended in place with the arrays returned
            by ``tiff.imread``.

    Returns:
        None; the result is delivered through ``class_images``.
    """
    for file in class_dir.iterdir():
        # Decide on the real extension of a regular file. A substring test on
        # the whole path also fired on parent folders containing ".tif" and on
        # names such as "img.tif.bak", and tried to read sub-directories.
        if file.is_file() and file.suffix.lower() in ('.tif', '.tiff'):
            class_images.append(tiff.imread(str(file)))

In [97]:
data_dirs = Xdataset_source(data_path)
# Divisor applied to the raw pixel statistics — assumes 16-bit TIFF data, TODO confirm.
norm_value = 65536

# Group the class folders by statistics key. "train" and "val" folders of the
# same class share one key (their images are pooled); "test" folders get a
# "_test" suffix and are kept apart.
dirs_by_class = {}
for key in data_dirs.keys():
    tag = "_test" if key == 'test' else ""
    for class_dir in data_dirs[key]:
        class_name = class_dir.name + tag
        print(class_name)
        dirs_by_class.setdefault(class_name, []).append(class_dir)

stats = {}
for class_name, class_dirs in dirs_by_class.items():
    class_images = []
    for dir_ in class_dirs: # e.g. the train and val folder of one class
        # read from each dir and append to the images
        Statistics.read_class(dir_, class_images)

    print(f"working on: {class_name}")
    # Reduces over axes 0, 2 and 3 of the stacked images — assumes each image
    # is (channels, H, W) so one value per channel remains; TODO confirm.
    mean = np.mean(class_images, axis=(0, 2, 3)) / norm_value
    stdev = np.std(class_images, axis=(0, 2, 3)) / norm_value

    stats[class_name] = (mean, stdev)


Cit1_MC_mfb1KO
Cit1_MC_mfb1KO_mmr1KO
Cit1_MC_mmr1KO
Cit1_MC_WT
Cit1_MC_mfb1KO
Cit1_MC_mfb1KO_mmr1KO
Cit1_MC_mmr1KO
Cit1_MC_WT
Cit1_MC_mfb1KO_test
Cit1_MC_mfb1KO_mmr1KO_test
Cit1_MC_mmr1KO_test
Cit1_MC_WT_test


In [None]:
# Per-class statistics computed through the helpers imported from `common`;
# `stats_dict` is what get_data receives further down.
# NOTE(review): presumably `save_name` makes Statistics.per_class persist the result under that name — confirm in common.
# NOTE(review): a later cell unpacks dataset_source() into two names instead of three; only one arity can be right.
stats_name = "yeast_v4.2_per_class.dict"
train_dirs, val_dirs, test_dirs  = dataset_source(data_path) 
stats_dict = Statistics.per_class(train_dirs, val_dirs, test_dirs ,save_name=stats_name)

In [None]:
# Debug cell: repeats the first two steps of get_data to inspect how the
# label-keyed statistics are re-keyed by class index.
# NOTE(review): bs is the literal 64 here rather than BATCH_SIZE.
create, lbl2index = ImageClassifierData.prepare_from_path(PATH, val_name='test', test_name='yeast_v3_test_v1', test_with_labels=True, bs=64)
# Same statistics values, keyed by lbl2index[label] instead of the label.
stats_dict_ = {lbl2index[key]: val for key, val in stats_dict.items()}
print(stats_dict)
print(lbl2index)
print(stats_dict_)

In [5]:
def get_data(path: str, sz, bs, stats):
    """Build the model data object with per-class normalisation transforms.

    Args:
        path: dataset root handed to ``ImageClassifierData.prepare_from_path``.
        sz: image size given to ``tfms_from_stats``; padding is ``sz//8``.
        bs: batch size.
        stats: dict keyed by class label; it is re-keyed by class index before
            being passed to ``tfms_from_stats``.

    Returns:
        The result of calling the ``create`` callable returned by
        ``prepare_from_path`` with the transforms.
    """
    make_data, label_to_index = ImageClassifierData.prepare_from_path(
        path,
        val_name='test',
        test_name='yeast_v3_test_v1',
        test_with_labels=True,
        bs=bs,
    )

    # Re-key the statistics by class index instead of class label.
    indexed_stats = {}
    for label, value in stats.items():
        indexed_stats[label_to_index[label]] = value

    # even without transformations and padding -> failure
    transforms = tfms_from_stats(indexed_stats, sz, aug_tfms=[RandomDihedral()], pad=sz//8)
    print('\n class to index mapping:\n', label_to_index)
    return make_data(transforms)

### the eventual sub-function of ImageClassifierData (read_dirs) expects subdirectories for each class: 
### e.g. all "test/cat.png" images should be in a "cat" folder. 

In [None]:
# Build the model data object from the dataset root, image size, batch size and per-class statistics.
data = get_data(PATH,SIZE, BATCH_SIZE,stats_dict)

In [None]:
# Pull one batch from the training data loader for inspection.
x, y = next(iter(data.trn_dl))

In [None]:
# Show one slice of a single de-normalised training image; denorm receives the
# label together with the image.
idx = 30
tiff.imshow(data.trn_ds.denorm(x[idx], y[idx]).squeeze()[:,:,0]); #denorm function called has a rollaxis() hence indexing changes.

# Training setup

In [None]:
# Sanity check: displays whether PyTorch can see a CUDA device.
torch.cuda.is_available()

## ResNet_with_Batchnorm

In [None]:
class BnLayer(nn.Module):
    """Conv -> ReLU -> hand-rolled per-channel batch normalisation.

    In training mode each forward pass normalises with the statistics of the
    current batch and remembers them; in eval mode the remembered statistics
    of the last training batch are reused.

    Args:
        ni: number of input channels.
        nf: number of output channels.
        stride: convolution stride.
        kernel_size: convolution kernel size (padding stays fixed at 1).
        eps: small constant added to the standard deviation so a constant
            channel (e.g. one silenced by the ReLU) does not divide by zero.
    """

    def __init__(self, ni, nf, stride=2, kernel_size=3, eps=1e-5):
        super().__init__()
        self.conv = nn.Conv2d(ni, nf, kernel_size=kernel_size, stride=stride,
                              bias=False, padding=1)
        # Learnable per-channel shift (a) and scale (m).
        self.a = nn.Parameter(torch.zeros(nf,1,1))
        self.m = nn.Parameter(torch.ones(nf,1,1))
        self.eps = eps
        # Kept as plain attributes (not buffers) so the state_dict keys are
        # unchanged and previously saved weights still load.
        self.means = None
        self.stds = None

    def forward(self, x):
        """Apply conv + ReLU, then normalise every channel and rescale it."""
        x = F.relu(self.conv(x))
        if self.training or self.means is None:
            # One row per channel, holding every value of that channel in the batch.
            x_chan = x.transpose(0,1).contiguous().view(x.size(1), -1)
            means = x_chan.mean(1)[:,None,None]
            stds = x_chan.std(1)[:,None,None]
            if self.training:
                # Store detached copies: eval only needs the values, and
                # holding the graph would pin the whole batch in memory.
                self.means, self.stds = means.detach(), stds.detach()
            # else: eval before any training pass (e.g. freshly loaded
            # weights) — fall back to the batch statistics instead of raising
            # AttributeError on the missing stored statistics.
        else:
            means, stds = self.means, self.stds
        return (x - means) / (stds + self.eps) * self.m + self.a

In [None]:
class ResnetLayer(BnLayer):
    """BnLayer with an identity skip connection: output = input + BnLayer(input)."""

    def forward(self, x):
        # The sum needs the BnLayer output to be shape-compatible with x.
        residual = super().forward(x)
        return x + residual

In [None]:
class Resnet(nn.Module):
    """Small residual network: a stem convolution followed by stages of
    one BnLayer and two ResnetLayers, then global max-pool and a linear head.

    Args:
        layers: channel widths; ``layers[0]`` is the stem output width and each
            consecutive pair ``(layers[i], layers[i+1])`` defines one stage.
        c: number of output classes.
        in_channels: channels of the input images (default 2, as before).
    """

    def __init__(self, layers, c, in_channels=2):
        super().__init__()
        # The stem width follows layers[0] instead of a hard-coded 10, so it
        # always matches the input width of the first BnLayer.
        self.conv1 = nn.Conv2d(in_channels, layers[0], kernel_size=5, stride=1, padding=2)
        # One BnLayer per stage (default stride 2) that changes the channel width.
        self.layers = nn.ModuleList([BnLayer(layers[i], layers[i+1])
            for i in range(len(layers) - 1)])
        # Two stride-1 residual layers per stage at the stage's output width.
        self.layers2 = nn.ModuleList([ResnetLayer(layers[i+1], layers[i + 1], 1)
            for i in range(len(layers) - 1)])
        self.layers3 = nn.ModuleList([ResnetLayer(layers[i+1], layers[i + 1], 1)
            for i in range(len(layers) - 1)])
        self.out = nn.Linear(layers[-1], c)

    def forward(self, x):
        """Return log-probabilities over the ``c`` classes for a batch of images."""
        x = self.conv1(x)
        for l,l2,l3 in zip(self.layers, self.layers2, self.layers3):
            x = l3(l2(l(x)))
        # Global max-pool to 1x1, then flatten to (batch, features).
        x = F.adaptive_max_pool2d(x, 1)
        x = x.view(x.size(0), -1)
        return F.log_softmax(self.out(x), dim=-1)

In [None]:
# Weight decay, passed as `wds` to every learn.fit call below.
wd=1e-5

In [None]:
# Wrap the custom Resnet (five channel widths, 4 output classes) and the data object in a learner.
learn = ConvLearner.from_model_data(Resnet([10, 20, 40, 80, 160], 4), data)

In [None]:
learn.summary()

In [None]:
# NOTE(review): presumably fit(lr, n_cycles, ...) as in fastai 0.7, i.e. 8 cycles of 4 epochs — confirm against resources.conv_learner.
%time learn.fit(1e-2, 8, cycle_len=4, wds=wd)

In [None]:
# at very little overfitting we have 43% accuracy
# best_save_name presumably stores the best weights under that name — confirm.
%time learn.fit(1e-2, 8, wds=wd, cycle_len=10, use_clr=(20,8, 0.95, 0.85), best_save_name='Objective_A_Resnet_per_class_1')

In [None]:
# Same schedule with a 10x lower learning rate and longer cycles.
%time learn.fit(1e-3, 8, wds=wd, cycle_len=20, use_clr=(20,8, 0.95, 0.85), best_save_name='Objective_A_Resnet_per_class_2')

In [None]:
# Reload the weights saved under the second best_save_name.
learn.load('Objective_A_Resnet_per_class_2')

In [None]:
# One more single-epoch cycle on top of the reloaded weights.
%time learn.fit(1e-3, 1, wds=wd, cycle_len=1, use_clr=(20,8, 0.95, 0.85))

## Predictions

In [None]:
# Reload the saved weights before predicting.
learn.load('Objective_A_Resnet_per_class_2')

In [None]:
# Display the per-class statistics currently held in the kernel.
stats_dict

In [None]:
# Recomputes the statistics under a new save name and REPLACES the global
# `stats_dict` used by get_data and the prediction cell below.
# NOTE(review): dataset_source() is unpacked into two names here but into three
# (train, val, test) in the earlier statistics cell, and per_class receives two
# positional arguments instead of three — confirm the intended signatures in common.
stats_name = "yeast_v4.2_test_per_class.dict"
test_dirs, train_dirs = dataset_source(data_path)
stats_dict = Statistics.per_class(test_dirs, train_dirs,save_name=stats_name)

In [None]:
# Per-class mean/std of the labelled test folder.
# Built from `data_path` (defined at the top of the notebook) rather than from
# `path`, which is only assigned in a later cell.
norm_path = data_path / 'yeast_v3_test_v1'
xstats = {}

for d in norm_path.iterdir():
    if not d.is_dir():
        continue  # only class folders are read
    class_images = []
    Statistics.read_class(d, class_images)
    # Same divisor (65536) and reduction axes as the statistics cell above —
    # assumes 16-bit images stacked as (N, C, H, W), TODO confirm.
    mean = np.mean(class_images, axis=(0,2,3)) / 65536
    std = np.std(class_images, axis=(0,2,3)) / 65536
    # Accumulate one entry per class instead of overwriting a one-item dict.
    xstats[d.name] = (mean, std)

In [None]:
# Debugging leftover: displays the result of Statistics.mro() (presumably the
# class's method resolution order); nothing else in the notebook uses it.
Statistics.mro()

In [None]:
# Manual single-image prediction: rebuilds the transforms outside get_data and
# pushes one test image through the validation transform and the model.
# NOTE(review): sz / bs repeat SIZE / BATCH_SIZE, and `path` assigned here is
# also read by the earlier test-folder statistics cell (hidden state).
sz =200
path = PATH
bs = 64
create , lbl2index = ImageClassifierData.prepare_from_path(path, val_name='test', test_name='yeast_v3_test_v1', test_with_labels=True, bs=bs)
# Statistics re-keyed by class index, as in get_data.
stats_dictX = {lbl2index[key]: val for key, val in stats_dict.items()}
trn_Xtfms, val_Xtfms = tfms_from_stats(stats_dictX, sz, aug_tfms=[RandomDihedral()], pad=sz//8)


# Second file of the test set.
fn = PATH+data.test_ds.fnames[1]
im = open_image(fn)
# NOTE(review): the label passed to the transform is fixed to 1 regardless of
# the image's real class — presumably it selects which class statistics are applied; confirm.
Nor_im = val_Xtfms(im, y=1)
# [None] adds a leading batch axis of size 1.
preds = learn.predict_array(Nor_im[0][None])
# argmax over the flattened prediction array.
print(np.argmax(np.exp(preds)))



In [None]:
# List the file names of the test dataset.
data.test_ds.fnames

In [None]:
# Show one slice (index 1 on the first axis) of the second test sample's image.
tiff.imshow(data.test_ds[1][0][1,:,:])
# np.amax(data.trn_ds[0][0][0,:,:])

In [None]:
# NOTE(review): `y` is whichever cell assigned it last (training batch labels or TTA targets).
len(y)

In [None]:
# Predictions together with targets; the cells below index the test result with
# [0] for the log-predictions.
# NOTE(review): `log_preds` holds the whole returned object here, whereas the
# TTA cell later unpacks its result into (log_preds, y).
log_preds = learn.predict_with_targs()
log_testpreds = learn.predict_with_targs(is_test=True)

In [None]:
# Turn the test log-predictions (element 0 of the result) into probabilities
# and take the most likely class per row.
testprobs = np.exp(log_testpreds[0])
preds = np.argmax(testprobs, axis=1)
# Prints the full returned object, not only the predictions.
print(log_testpreds)
# print(log_testpreds[1])

In [None]:
# Exponentiate only the log-predictions (element 0); element 1 is indexed
# separately elsewhere in this notebook and must not go through np.exp.
testprobs = np.exp(log_testpreds[0])
# preds = np.argmax(testprobs, axis=1)
testprobs

## Analysis

In [None]:
# Overwrites the globals `log_preds` and `y` with the TTA log-predictions and targets.
log_preds, y = learn.TTA() # run predictions with TTA

### Confusion matrix

In [None]:
# Plot confusion matrix
# Average the log-predictions over the first (augmentation) axis, then take the
# most likely class per image.
log_preds_mean = np.mean(log_preds, axis=0)
preds = np.argmax(log_preds_mean, axis=1)
# True labels first, predictions second. Assumes `confusion_matrix` is
# sklearn.metrics.confusion_matrix(y_true, y_pred) pulled in by the star
# imports — TODO confirm. With the arguments swapped the matrix is transposed.
cm = confusion_matrix(y, preds)
plot_confusion_matrix(cm, data.classes)

In [None]:
# Shape check of the averaged log-predictions.
log_preds_mean.shape

### Analyse images 

#### Show random correctly/incorrectly classified images:

In [None]:
# Average the TTA log-predictions over the first axis (presumably original + augmented views — confirm).
log_preds_mean = np.mean(log_preds, axis=0)
# Index of the most likely class per image (one of the four classes, not just 0 or 1).
preds = np.argmax(log_preds_mean, axis=1)

In [None]:
# probs = np.exp(log_preds_mean[:,0]) # prediction(WT)
# Probabilities for every class (all columns), used by the plotting helpers below.
probs = np.exp(log_preds_mean) # predictions

In [None]:
def rand_by_mask(mask, n=4):
    """Pick up to ``n`` distinct random indices where ``mask`` is true.

    Args:
        mask: boolean array.
        n: maximum number of indices to draw (default 4).

    Returns:
        1-D integer array with ``min(n, number of true entries)`` indices;
        empty when nothing matches.
    """
    idxs = np.where(mask)[0]
    # Sampling without replacement fails when asked for more items than exist,
    # so cap the sample size at the number of matches.
    k = min(n, len(idxs))
    if k == 0:
        return idxs
    return np.random.choice(idxs, k, replace=False)
def rand_by_correct(is_correct):
    """Random validation indices whose prediction correctness equals ``is_correct``.

    Reads the notebook globals ``preds`` and ``data.val_y``.
    """
    prediction_hit = preds == data.val_y
    return rand_by_mask(prediction_hit == is_correct)

In [None]:
def plots(ims, channel, figsize=(12,6), rows=1, titles=None):
    """Draw a grid of images, one subplot per image.

    Args:
        ims: array indexed as ``ims[image, channel, row, col]``.
        channel: index of the channel to show, or None to show the sum over
            axis 1 (the channel axis).
        figsize: figure size passed to ``plt.figure``.
        rows: number of subplot rows.
        titles: optional sequence with one title per image.
    """
    f = plt.figure(figsize=figsize)
    n_ims = len(ims)
    # Round the column count up so every image gets a slot even when n_ims is
    # not a multiple of rows.
    cols = -(-n_ims // rows)
    # The channel sum does not depend on the loop index: compute it once.
    summed = np.sum(ims, axis=1) if channel is None else None
    for i in range(n_ims):
        sp = f.add_subplot(rows, cols, i+1)
        sp.axis('Off')
        if titles is not None: sp.set_title(titles[i], fontsize=11)
        if channel is not None: sp.imshow(ims[i,channel,:,:])
        else: sp.imshow(summed[i,:,:])

In [None]:
def plot_val_with_title_from_ds_no_denorm(idxs, title, channel=None):
    """Plot validation images (not de-normalised) with true/predicted labels.

    Reads the notebook globals ``data``, ``preds`` and ``probs``.

    Args:
        idxs: indices into ``data.val_ds``.
        title: heading printed above the figure.
        channel: optional single channel to show; None shows the channel sum.

    Returns:
        The return value of ``plots``, or None when ``idxs`` is empty.
    """
    if len(idxs) == 0:
        # np.stack cannot stack an empty list; there is nothing to draw.
        print(f"{title}: no images to show")
        return None

    # Lists, not generators: np.stack requires a real sequence.
    imgs = np.stack([data.val_ds[x][0] for x in idxs]) # get images by idx
    corr_lbl = np.stack([data.val_ds[x][1] for x in idxs]) # get correct label from data.val_ds by idx
    pred_lbl = np.stack([preds[x] for x in idxs]) # get predicted label from preds by idx
    p_max = [np.amax(probs[x,:]) for x in idxs] # get highest probability from probs by idx

    # Pair the three lists by position; indexing them with the label *values*
    # attached the wrong label and probability to each image.
    title_fin = [f"true = {true_lbl}\n predicted: {guess}\n  p = {p}"
                 for true_lbl, guess, p in zip(corr_lbl, pred_lbl, p_max)]
    print(title)

    return plots(imgs, channel, rows=1, titles=title_fin, figsize=(16,8))

In [None]:
# load from ds - not denormalized! 
plot_val_with_title_from_ds_no_denorm(rand_by_correct(True), "Correctly classified")
#optionally pass channel arg. to select single channel

In [None]:
plot_val_with_title_from_ds_no_denorm(rand_by_correct(False), "Incorrectly classified")

#### Show most correctly/incorrectly classified images per class:

In [None]:
def most_by_mask(mask, y, mult):
    """Return up to four masked indices ranked by ``mult * probs[:, y]`` ascending.

    Reads the notebook global ``probs``.
    """
    candidates = np.where(mask)[0]
    class_probs = probs[:,y][candidates]
    ranking = np.argsort(mult * class_probs)
    return candidates[ranking[:4]]

def most_by_correct(y, is_correct):
    """Extreme examples of class ``y`` among correct or incorrect predictions.

    With ``is_correct`` true the ranking sign is -1 (highest ``probs[:, y]``
    first); otherwise it is 1 (lowest first). Reads the notebook globals
    ``preds`` and ``data.val_y``.
    """
    mult = 1
    if is_correct:
        mult = -1
    matches = (preds == data.val_y) == is_correct
    of_class = data.val_y == y
    return most_by_mask(matches & of_class, y, mult)

In [None]:
# Class index 0 is titled WT, 1 mfb1KO, 2 mfb1KO-mmr1KO, 3 mmr1KO in the calls
# below — presumably matching the printed class-to-index mapping; confirm.
plot_val_with_title_from_ds_no_denorm(most_by_correct(0, True), "Most correctly classified WT")

In [None]:
# For misclassified images the ranking is ascending in the probability of the
# true class, so the first images are the ones the model was most wrong about.
plot_val_with_title_from_ds_no_denorm(most_by_correct(0, False), "Most incorrectly classified WT") # logic?

In [None]:
plot_val_with_title_from_ds_no_denorm(most_by_correct(1, True), "Most correctly classified mfb1KO") 

In [None]:
plot_val_with_title_from_ds_no_denorm(most_by_correct(1, False), "Most incorrectly classified mfb1KO")

In [None]:
plot_val_with_title_from_ds_no_denorm(most_by_correct(2, True), "Most correctly classified mfb1KO-mmr1KO")

In [None]:
plot_val_with_title_from_ds_no_denorm(most_by_correct(3, True), "Most correctly classified mmr1KO")

In [None]:
# The "most incorrectly classified" plots for classes 2 and 3 are not included; add them the same way.
# etc.

#### Show (most) uncertain images

In [None]:
# Six images whose highest class probability is lowest, i.e. the least confident predictions.
# NOTE(review): the extra alias `t` is not used anywhere else in the visible notebook.
most_uncertain = t = np.argsort(np.amax(probs, axis = 1))[:6] # get best "guess" per image and list the least confident ones
plot_val_with_title_from_ds_no_denorm(most_uncertain, "Most uncertain predictions")