# YNet - Dataset 10.4:

Data from Experiment (2), Mitochondria = Cit1-mCherry 

### Importing utilities:

In [1]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [2]:
import os
from pathlib import Path
import skimage.external.tifffile as tiff

from common import Statistics, dataset_source
from resources.conv_learner import *
from resources.plots import *
from pprint import pprint
import matplotlib.pyplot as plt
%matplotlib inline

#### Setting up variables

In [3]:
PATH = "datasets/yeast_v10.4/"
data_path = Path(PATH)

In [4]:
# Class names listed in label-index order (index i <-> CLASSES[i]); the order and
# spelling follow the "class to index mapping" printed when the data is loaded.
CLASSES = ('WT', 'mfb1KO', 'mmr1KO', 'mmr1KO-mfb1KO', 'LatA-5uM', 'dnm1KO', 'fzo1KO')
NUM_CLASSES = len(CLASSES)
BATCH_SIZE = 64  # batch size handed to get_data()
SIZE = 200       # image size handed to get_data() and on to the transforms

#### Calculating normalization statistics

Note that we are setting up train & val data, as well as test. Within test, we are here including a mutant cell type that the model never trains on. The idea is to use the feature space developed during training to evaluate novel cell types by similarity to the landmarks that the model was trained on.

In [5]:
# stats_name = "yeast_v10.2_per_class.dict"
# Look up the dataset entries under data_path; the result is indexed by
# 'train', 'val' and 'test' below (presumably one entry per class folder).
classes = Statistics.source_class(data_path)

# Pair each 'train' entry with the 'val' entry at the same position;
# the 'test' entries are wrapped one per tuple.
train_val = zip(classes['train'], classes['val'])
test_ = zip(classes['test'])
 
# Per-class statistics; get_data() later passes them on as normalization stats.
main_stats = Statistics.per_class(train_val)
test_stats = Statistics.per_class(test_)

working on: datasets\yeast_v10.4\train\02_mfb1KO
working on: datasets\yeast_v10.4\val\02_mfb1KO
working on: datasets\yeast_v10.4\train\02_mmr1KO
working on: datasets\yeast_v10.4\val\02_mmr1KO
working on: datasets\yeast_v10.4\train\02_mmr1KO-mfb1KO
working on: datasets\yeast_v10.4\val\02_mmr1KO-mfb1KO
working on: datasets\yeast_v10.4\train\02_WT
working on: datasets\yeast_v10.4\val\02_WT
working on: datasets\yeast_v10.4\train\03_dnm1KO
working on: datasets\yeast_v10.4\val\03_dnm1KO
working on: datasets\yeast_v10.4\train\03_fzo1KO
working on: datasets\yeast_v10.4\val\03_fzo1KO
working on: datasets\yeast_v10.4\train\03_LatA-5uM
working on: datasets\yeast_v10.4\val\03_LatA-5uM
working on: datasets\yeast_v10.4\train\03_WT
working on: datasets\yeast_v10.4\val\03_WT
working on: datasets\yeast_v10.4\train\04_WT
working on: datasets\yeast_v10.4\val\04_WT
working on: datasets\yeast_v10.4\test\01_mfb1KO
working on: datasets\yeast_v10.4\test\01_mmr1KO
working on: datasets\yeast_v10.4\test\01_WT
wo

In [6]:
# List the statistics stored for every train/val class.
for class_name, class_stats in main_stats.items():
    print(f"{class_name}: \t \t \t {class_stats}")

02_mfb1KO: 	 	 	 (array([0.00794, 0.00484]), array([0.00075, 0.00163]))
02_mmr1KO: 	 	 	 (array([0.00799, 0.00503]), array([0.0008 , 0.00186]))
02_mmr1KO-mfb1KO: 	 	 	 (array([0.00791, 0.00489]), array([0.00073, 0.00162]))
02_WT: 	 	 	 (array([0.00796, 0.00478]), array([0.00075, 0.00149]))
03_dnm1KO: 	 	 	 (array([0.02515, 0.00477]), array([0.0025 , 0.00192]))
03_fzo1KO: 	 	 	 (array([0.02517, 0.0047 ]), array([0.00202, 0.00202]))
03_LatA-5uM: 	 	 	 (array([0.0253, 0.0049]), array([0.0024, 0.0017]))
03_WT: 	 	 	 (array([0.02536, 0.00459]), array([0.00255, 0.00147]))
04_WT: 	 	 	 (array([0.02535, 0.00493]), array([0.00215, 0.00156]))


In [7]:
# List the statistics stored for every test class.
for class_name, class_stats in test_stats.items():
    print(f"{class_name}: \t \t \t {class_stats}")

01_mfb1KO: 	 	 	 (array([0.0211 , 0.00454]), array([0.00151, 0.00165]))
01_mmr1KO: 	 	 	 (array([0.02115, 0.00486]), array([0.00158, 0.00193]))
01_WT: 	 	 	 (array([0.0211 , 0.00449]), array([0.00149, 0.00129]))
03_axl1KO: 	 	 	 (array([0.02548, 0.00475]), array([0.00221, 0.00144]))
03_bud1KO: 	 	 	 (array([0.02544, 0.00459]), array([0.00223, 0.00142]))
03_DMSO: 	 	 	 (array([0.02535, 0.00494]), array([0.00216, 0.00156]))
03_DTT: 	 	 	 (array([0.02586, 0.00496]), array([0.00224, 0.00167]))
03_Eth: 	 	 	 (array([0.02533, 0.00469]), array([0.00226, 0.00134]))
03_LatA-05uM: 	 	 	 (array([0.02535, 0.00498]), array([0.00252, 0.00168]))


## Defining datasets:

In [58]:
def tfms_for_test(stats, sz):
    """Build the test-set transforms for image size `sz`.

    A Normalize and a Denormalize object are created from `stats` and handed
    to image_gen together with `sz`; cropping is disabled (CropType.NO).

    Returns:
        Whatever image_gen returns for these arguments.
    """
    return image_gen(Normalize(stats), Denormalize(stats), sz, crop_type=CropType.NO)

In [59]:
def get_data(path: str, sz, bs):
    """Create the model data object for the dataset stored at `path`.

    Reads the module-level `main_stats` and `test_stats` dictionaries, re-keys
    them through the label mappings returned by prepare_from_path, builds the
    transforms from them and prints both label mappings.

    Args:
        path: dataset root containing the 'val' and 'test' sub-folders named below.
        sz: image size passed to the transforms (also sets the padding, sz//8).
        bs: batch size.

    Returns:
        The object produced by calling the `create` factory with the transforms.
    """
    create, lbl2index, lbl2index_test = ImageClassifierData.prepare_from_path(
        path, val_name='val', bs=bs, num_workers=1,
        test_name='test', test_with_labels=True, balance=False)

    # Re-key the per-class stats: folder name -> first entry of that folder's mapping.
    main_stats_X = {lbl2index[key][0]: val for key, val in main_stats.items()}
    tfms = tfms_from_stats(main_stats_X, sz, aug_tfms=[RandomDihedral()], pad=sz//8)

    test_stats_X = {lbl2index_test[key][0]: val for key, val in test_stats.items()}
    # Append the test transforms to the tuple returned by tfms_from_stats.
    tfms += (tfms_for_test(test_stats_X, sz), )

    # The two mappings get distinct headers so they can be told apart in the output.
    print('\n class to index mapping:\n', lbl2index)
    print('\n test class to index mapping:\n', lbl2index_test)
    return create(tfms)

In [60]:
data = get_data(PATH,SIZE,BATCH_SIZE)


 class to index mapping:
 {'02_WT': [0, 0, 'WT'], '02_mfb1KO': [1, 1, 'mfb1KO'], '02_mmr1KO': [2, 2, 'mmr1KO'], '02_mmr1KO-mfb1KO': [3, 3, 'mmr1KO-mfb1KO'], '03_LatA-5uM': [4, 4, 'LatA-5uM'], '03_WT': [5, 0, 'WT'], '03_dnm1KO': [6, 5, 'dnm1KO'], '03_fzo1KO': [7, 6, 'fzo1KO'], '04_WT': [8, 0, 'WT']}

 class to index mapping:
 {'01_WT': [0, 0, 'WT'], '01_mfb1KO': [1, 1, 'mfb1KO'], '01_mmr1KO': [2, 2, 'mmr1KO'], '03_DMSO': [3, 3, 'DMSO'], '03_DTT': [4, 4, 'DTT'], '03_Eth': [5, 5, 'Eth'], '03_LatA-05uM': [6, 6, 'LatA-05uM'], '03_axl1KO': [7, 7, 'axl1KO'], '03_bud1KO': [8, 8, 'bud1KO']}


In [66]:
def analyze_batch_composition():
    """Print the average per-batch share of every class in the training loader.

    Goes through `data.trn_dl` once, counts how often each class index
    (0 .. NUM_CLASSES-1) occurs among the labels of every batch, averages the
    counts over all batches and divides by BATCH_SIZE.

    NOTE(review): a final batch smaller than BATCH_SIZE is still divided by
    BATCH_SIZE, exactly as before — confirm whether that is intended.
    """
    per_batch_counts = []
    for _x, y in data.trn_dl:
        labels = list(to_np(y))
        per_batch_counts.append([labels.count(j) for j in range(NUM_CLASSES)])

    # Stack once at the end instead of growing an array batch by batch.
    means = np.mean(np.array(per_batch_counts), axis=0) / BATCH_SIZE
    print(means)

In [67]:
analyze_batch_composition()

[0.3186  0.08232 0.1189  0.10823 0.08994 0.15549 0.11509]


In [25]:
x, y = next(iter(data.trn_dl))

In [None]:
print(len(data.val_dl.dataset.y))
print(len(data.trn_dl.dataset.y))
print(len(data.test_dl.dataset.y))

In [28]:
data.trn_dl.sampler#.__dict__.keys()

<torch.utils.data.sampler.WeightedRandomSampler at 0x20a22d973c8>

### Inspect loaded data:

Displaying the same image with and without normalization.

In [None]:
# specify which image-index (position inside the batch `x`, `y`)
idx = 6

# copy the selected image to the CPU as an independent NumPy array
xx = x[idx].cpu().numpy().copy()
yy = y[idx]
# show the image summed over axis 0 (presumably the channel axis — the model expects 2 input channels)
#
#sp.axis('Off')
#sp.set_title("Norm", fontsize=11)
figure, _ ,_ = tiff.imshow(np.sum(xx, axis=0))
figure.set_size_inches(6,6)
figure.add_subplot(111)

# figure2, _, _ = tiff.imshow(np.sum(data.trn_ds.denorm(xx,yy).squeeze() * 65536, axis=2)) # not very elegant atm. 
# figure2.set_size_inches(6,6)
# label belonging to the displayed image
print(yy)

# Training setup

In [None]:
torch.cuda.is_available()

## ResNet_with_Batchnorm

Defining network architecture. 

In [29]:
class BnLayer(nn.Module):
    """Convolution + ReLU followed by a hand-written per-channel normalization.

    In training mode every output channel is standardized with the mean and
    standard deviation of the current batch, and those statistics are
    remembered. In eval mode the statistics of the most recent training batch
    are reused. A learnable per-channel scale (`m`) and shift (`a`) are applied
    afterwards.

    The remembered statistics are plain attributes, not parameters or buffers,
    so they are not part of `state_dict()`: a freshly constructed or freshly
    loaded layer needs one training-mode forward pass before it can run in
    eval mode.

    NOTE(review): a channel whose activations are all identical has a standard
    deviation of 0 and the division below is unguarded — consider an epsilon.

    Args:
        ni: number of input channels.
        nf: number of output channels.
        stride: stride of the convolution.
        kernel_size: kernel size of the convolution.
    """

    def __init__(self, ni, nf, stride=2, kernel_size=3):
        super().__init__()
        # Padding follows the kernel size (1 for the default 3x3 kernel), so
        # odd kernels other than 3 also keep the spatial size at stride 1.
        self.conv = nn.Conv2d(ni, nf, kernel_size=kernel_size, stride=stride,
                              bias=False, padding=kernel_size // 2)
        self.a = nn.Parameter(torch.zeros(nf,1,1))
        self.m = nn.Parameter(torch.ones(nf,1,1))
        # Statistics of the last training batch; filled in by forward().
        self.means = None
        self.stds = None

    def forward(self, x):
        x = F.relu(self.conv(x))
        if self.training:
            # Flatten to (channels, everything else) to get per-channel statistics.
            x_chan = x.transpose(0,1).contiguous().view(x.size(1), -1)
            means = x_chan.mean(1)[:,None,None]
            stds = x_chan.std(1)[:,None,None]
            # Keep detached copies for eval mode so the stored tensors do not
            # hold on to this batch's autograd graph; the live tensors are
            # still used below, so gradients flow exactly as before.
            self.means = means.detach()
            self.stds = stds.detach()
        else:
            if self.means is None or self.stds is None:
                raise RuntimeError(
                    "BnLayer has no stored batch statistics: run at least one "
                    "forward pass in training mode before evaluating.")
            means, stds = self.means, self.stds
        return (x - means) / stds * self.m + self.a

In [30]:
class ResnetLayer(BnLayer):
    """BnLayer with a skip connection: the input is added to the layer's output."""

    def forward(self, x):
        residual = super().forward(x)
        return x + residual

In [31]:
class Resnet(nn.Module):
    """Small residual network built from BnLayer / ResnetLayer blocks.

    A 5x5 stem convolution maps the input to `layers[0]` channels. Each stage
    then applies a BnLayer (default stride 2) that changes the width from
    `layers[i]` to `layers[i+1]`, followed by two stride-1 ResnetLayers at the
    new width. The result is max-pooled to 1x1 and classified with a linear
    layer; the output is log-softmax over `c` classes.

    Args:
        layers: channel widths; `layers[0]` is the stem width, the last entry
            feeds the classifier.
        c: number of output classes.
        in_channels: number of channels of the input images (default 2).
    """

    def __init__(self, layers, c, in_channels=2):
        super().__init__()
        # The stem width follows layers[0] so it always matches the first BnLayer.
        self.conv1 = nn.Conv2d(in_channels, layers[0], kernel_size=5, stride=1, padding=2)
        self.layers = nn.ModuleList([BnLayer(layers[i], layers[i+1])
            for i in range(len(layers) - 1)])
        self.layers2 = nn.ModuleList([ResnetLayer(layers[i+1], layers[i + 1], 1)
            for i in range(len(layers) - 1)])
        self.layers3 = nn.ModuleList([ResnetLayer(layers[i+1], layers[i + 1], 1)
            for i in range(len(layers) - 1)])
        self.out = nn.Linear(layers[-1], c)

    def forward(self, x):
        x = self.conv1(x)
        # One stage = width-changing layer followed by two residual layers.
        for l,l2,l3 in zip(self.layers, self.layers2, self.layers3):
            x = l3(l2(l(x)))
        x = F.adaptive_max_pool2d(x, 1)
        x = x.view(x.size(0), -1)
        return F.log_softmax(self.out(x), dim=-1)

In [32]:
wd=1e-5 # weight-decay/L2 regularization 

In [33]:
# The number of outputs follows NUM_CLASSES instead of a hard-coded 7.
learn = ConvLearner.from_model_data(Resnet([10, 20, 40, 80, 160], NUM_CLASSES), data)

In [None]:
learn.summary()

In [34]:
%time learn.fit(1e-2, 8, cycle_len=1, wds=wd)

HBox(children=(IntProgress(value=0, description='Epoch', max=8), HTML(value='')))

EPOCH 0 ---------------------------------------- STEP 0                                                                                                                                                              
mean: [31.90244  7.70732 11.5122  10.43902  8.46341 15.14634 11.14634]
stdev: [6.48377 2.9652  3.80065 3.813   3.47884 4.07587 4.57758]

[WT]: 84.76%
[mfb1KO]:  0.0%
[mmr1KO]: 11.43%
[mmr1KO-mfb1KO]:  0.0%
[LatA-5uM]:  0.0%
[dnm1KO]:  0.0%
[fzo1KO]: 54.29%
epoch      trn_loss   val_loss   accuracy   
    0      11.752529  4.601089   0.355556  
EPOCH 1 ---------------------------------------- STEP 1                                                                                                                                                              
mean: [31.90244  7.82927 11.46341 10.5122   8.4878  15.2439  11.07317]
stdev: [6.83179 3.56748 3.62981 3.90198 3.67006 4.04716 4.84587]

[WT]: 58.1%
[mfb1KO]:  0.0%
[mmr1KO]: 25.71%
[mmr1KO-mfb1KO]: 8.571%
[LatA-5uM]:  0.0%
[d

[array([1.32415]), 0.4825396772414919]

In [65]:
np.sum([31.80488, 7.65854, 11.53659, 10.39024,  8.58537, 15.2439,  11.17073])

96.39025000000001

In [None]:
# Plot the loss recorded by the learner's scheduler on a fresh figure.
plt.figure()
# The return value is deliberately not bound to `x`: that name holds the
# training batch used by the image-inspection cell.
learn.sched.plot_loss()
plt.ylabel('Loss')
plt.xlabel('Iterations')
plt.show()

In [None]:
learn.sched.plot_lr()

#### Run some more cycles - error & accuracy should continuously improve

Note: cycle len = number of epochs per cycle

In [None]:
%time learn.fit(1e-2, 10, wds=wd, cycle_len=2, use_clr=(20,8, 0.95, 0.85))

In [None]:
%time learn.fit(1e-3, 3, wds=wd, cycle_len=2, cycle_mult = 2, use_clr=(20,8, 0.95, 0.85))

In [None]:
wd=1e-4 # weight-decay/L2 regularization 

In [None]:
%time learn.fit(1e-3, 3, wds=wd, cycle_len=2, cycle_mult=2, use_clr=(20,8, 0.95, 0.85))

In [None]:
wd=1e-4 # weight-decay/L2 regularization 

In [None]:
%time learn.fit(1e-2, 2, wds=wd, cycle_len=2, cycle_mult=2, use_clr=(20,8, 0.95, 0.85)) #, best_save_name = 'YNet_Res_v10.3'

## Analysis/Model evaluation

This is one of the major areas that needs improvement in our workflow. The tools we have so far (confusion matrix and manual inspection of images) are essential but definitely not sufficient to ensure that our model learns something biologically relevant. Ideas are welcome!

### ...after training 

In [None]:
learn.save('YNet_Res_v10.4_release_1.0')

In [None]:
## Load model:
# NOTE(review): this loads the 'v10.3' checkpoint, whereas the save cell above
# writes 'YNet_Res_v10.4_release_1.0' — confirm which weights are intended.
learn.load('YNet_Res_v10.3_release_1.0')

In [None]:
# %time learn.fit(1e-10, 1, wds=wd, cycle_len=1)

In [None]:
# NOTE(review): presumably run so the BnLayer modules see a training-mode pass
# after loading — BnLayer only sets its means/stds while self.training is true.
# Confirm what warm_up does with this (tiny) value.
learn.warm_up(1e-10)

In [None]:
log_preds, y = learn.TTA(n_aug=4) # run predictions with TTA

### Confusion matrix

In [None]:
# Plot confusion matrix
log_preds_mean = np.mean(log_preds, axis=0)  # average over the TTA passes (axis 0)
preds = np.argmax(log_preds_mean, axis=1)    # predicted class index per image
# True labels go first, predictions second, so rows are true classes and columns
# are predicted classes (assumes confusion_matrix follows the scikit-learn
# signature confusion_matrix(y_true, y_pred)).
cm = confusion_matrix(y, preds)
plot_confusion_matrix(cm, data.classes)

## Test-set eval

### Visualizing train and test datasets as exposed by dataloader

## @Yinan, please take the functionality of the next 2 cells and transfer it to the data_vis.py

In [None]:
# Scatter the per-image means of the first two channels for every training batch.
dl_ = data.trn_dl

plt.style.use('seaborn-whitegrid')
fig = plt.figure()
ax = plt.subplot(111)

for x_, y_ in dl_:
    # mean over axes 2 and 3 of each image -> one value per image and channel
    im_means = np.mean(to_np(x_), axis=(2,3))
    ax.plot(im_means[:,0], im_means[:,1], 'o', color = 'C0' , alpha=0.5)

plt.xlim(-0.4,0.4)
plt.ylim(-1, 4)

In [None]:
# Scatter the per-image means of the first two channels for every test batch.
dl_t = data.test_dl

plt.style.use('seaborn-whitegrid')
fig = plt.figure()
ax = plt.subplot(111)

for x_, y_ in dl_t:
    # mean over axes 2 and 3 of each image -> one value per image and channel
    im_means = np.mean(to_np(x_), axis=(2,3))
    ax.plot(im_means[:,0], im_means[:,1], 'o', color = 'C1' , alpha=0.5)

plt.xlim(-0.4,0.4)
plt.ylim(-1, 4)

### Inference

In [None]:
# Predict on the test set; the model ends in log_softmax, so the outputs are log-probabilities.
test_log_preds, targs = learn.predict_with_targs(is_test=True)
# exponentiate to get back to probabilities
testprobs = np.exp(test_log_preds)
# most probable class index per test image (re-binds `preds`, which the confusion-matrix cell also assigns)
preds = np.argmax(testprobs, axis=1)

## @James, please transfer the functionality of the next 5 cells into the data_vis.py

In [None]:
# @James, there is definitely a simpler way of generating test_lbl2idx_ than calling this entire line. Please trim it down. 

_, lbl2idx_, test_lbl2idx_ = ImageClassifierData.prepare_from_path(PATH, val_name='val', bs=64, num_workers=1, test_name='test', test_with_labels=True)


In [None]:
# make predictions dictionary
# NOTE(review): slicing `preds` by running offsets assumes the test predictions are
# grouped by source class, in the key order of test_lbl2idx_ — confirm.

src_indices = list(data.test_dl.dataset.src_idx)
stop = 0
preds_dict = {}
for i, key in enumerate(test_lbl2idx_.keys()):
    start, stop = stop, stop + src_indices.count(i)
    preds_dict[key] = list(preds[start:stop])
    print(f"{key} predictions ready ({stop - start} elements)")

In [None]:
# For every test class: fraction of its predictions that fall on each model class.
preds_rel = {
    key: {cls: key_preds.count(i) / len(key_preds) for i, cls in enumerate(data.classes)}
    for key, key_preds in preds_dict.items()
}

In [None]:
def plot_test_preds(targets, preds_rel):
    """Draw one horizontal bar chart of predicted-class fractions per target.

    Args:
        targets: a key of `preds_rel`, or a list of such keys; anything that
            is not a list is treated as a single target.
        preds_rel: mapping target -> {class name -> fraction}; the bars follow
            the order of the module-level `data.classes`.
    """
    targets = targets if isinstance(targets, list) else [targets]

    # two charts per row; the figure height grows with the number of rows
    n_rows = math.ceil(len(targets) / 2)

    # plotting:
    plt.figure(figsize=(12, 4 * n_rows))
    grid = plt.GridSpec(n_rows, 2)
    grid.update(wspace = 0.4)

    for i, targ in enumerate(targets):
        fractions = [preds_rel[targ][cls] for cls in data.classes] # extracting data
        panel = plt.subplot(grid[i])
        panel.barh(data.classes, fractions)
        panel.set_title(targ)
        panel.set_xlim(0,1)

    plt.show()

In [None]:
test_classes = list(test_lbl2idx_.keys())

plot_test_preds(test_classes, preds_rel)
# plot_test_preds(['01_WT', '03_WT', '03_fzo1KO', '01_mfb1KO'], preds_rel)
# plot_test_preds(['01_WT'], preds_rel)

### Analyse images 

#### Show random correct/incorrectly classified images:

In [None]:
log_preds_mean = np.mean(log_preds, axis=0) # averages predictions on original + 4 TTA images
preds = np.argmax(log_preds_mean, axis=1) # index of the highest-scoring class for each image

In [None]:
# probs = np.exp(log_preds_mean[:,0]) # prediction(WT)
probs = np.exp(log_preds_mean) # predictions

In [None]:
def rand_by_mask(mask, n=4):
    """Return up to `n` distinct, randomly chosen indices where `mask` is true.

    Args:
        mask: boolean array-like; its true positions are the candidates.
        n: maximum number of indices to draw (default 4).

    Returns:
        Array with min(n, number of true positions) indices, drawn without
        replacement. Fewer than `n` candidates no longer raises an error.
    """
    candidates = np.where(mask)[0]
    return np.random.choice(candidates, min(n, len(candidates)), replace=False)
def rand_by_correct(is_correct):
    """Pick random validation indices whose prediction correctness equals `is_correct`.

    Correctness is the element-wise comparison of the module-level `preds`
    with `data.val_y`; the selection itself is delegated to rand_by_mask.
    """
    prediction_is_right = preds == data.val_y
    return rand_by_mask(prediction_is_right == is_correct)

In [None]:
def plots(ims, channel, figsize=(12,6), rows=1, titles=None):
    """Show the images in `ims` side by side on one figure.

    Args:
        ims: array indexed as [image, channel, row, column].
        channel: channel to display, or None to display the sum over axis 1.
        figsize: size of the figure.
        rows: number of subplot rows; the column count is len(ims)//rows.
        titles: optional sequence with one title per image.
    """
    fig = plt.figure(figsize=figsize)
    n_images = len(ims)
    n_cols = n_images // rows
    for i in range(n_images):
        sp = fig.add_subplot(rows, n_cols, i + 1)
        sp.axis('Off')
        if titles is not None:
            sp.set_title(titles[i], fontsize=11)
        if channel is None:
            image = np.sum(ims, axis=1)[i,:,:]
        else:
            image = ims[i,channel,:,:]
        plt.imshow(image)

In [None]:
def plot_val_with_title_from_ds_no_denorm(idxs, title, channel=None):
    """Plot validation images (not denormalized) with true/predicted labels as titles.

    Reads the module-level `data`, `preds` and `probs`.

    Args:
        idxs: indices into `data.val_ds` (and into `preds` / `probs`).
        title: heading printed before the figure.
        channel: channel to display, or None to let plots() sum the channels.

    Returns:
        The return value of plots().
    """
    # Fetch every sample once; element 0 is used as the image, element 1 as the label.
    samples = [data.val_ds[x] for x in idxs]
    imgs = np.stack([sample[0] for sample in samples])   # get images by idx
    corr_lbl = np.stack([sample[1] for sample in samples])  # get correct label by idx
    pred_lbl = np.stack([preds[x] for x in idxs])  # get predicted label from preds by idx
    p_max = [np.amax(probs[x,:]) for x in idxs]  # get highest probability from probs by idx

    # One title per plotted image, built position by position. Indexing these
    # arrays with the label values themselves picks the wrong entries or runs
    # out of range.
    title_fin = [f"true = {true_lbl}\n predicted: {guess}\n  p = {p}"
                 for true_lbl, guess, p in zip(corr_lbl, pred_lbl, p_max)]
    print(title)

    return plots(imgs, channel, rows=1, titles=title_fin, figsize=(16,8))

### Plot images according to predictions

In [None]:
# load from ds - not denormalized! 
plot_val_with_title_from_ds_no_denorm(rand_by_correct(True), "Correctly classified")
#optionally pass channel arg. to select single channel

In [None]:
plot_val_with_title_from_ds_no_denorm(rand_by_correct(False), "Incorrectly classified")

#### Show most correct/incorrectly classified images per class:

In [None]:
def most_by_mask(mask, y, mult):
    """Return up to 4 indices selected by `mask`, ranked by the probability of class `y`.

    The candidates are sorted ascending by `mult * probs[:, y]` (module-level
    `probs`), so mult=-1 yields the highest probabilities first and mult=1
    the lowest.
    """
    candidates = np.where(mask)[0]
    ranking = np.argsort(mult * probs[candidates, y])
    return candidates[ranking[:4]]

def most_by_correct(y, is_correct):
    """Select the most extreme validation images of true class `y`.

    With is_correct=True the correctly classified images with the highest
    probability for `y` are returned; with is_correct=False the misclassified
    images with the lowest probability for `y`.
    """
    if is_correct:
        mult = -1
    else:
        mult = 1
    correctness_matches = (preds == data.val_y) == is_correct
    belongs_to_class = data.val_y == y
    return most_by_mask(correctness_matches & belongs_to_class, y, mult)

In [None]:
plot_val_with_title_from_ds_no_denorm(most_by_correct(0, True), "Most correctly classified WT")

In [None]:
plot_val_with_title_from_ds_no_denorm(most_by_correct(0, False), "Most incorrectly classified WT") # logic?

In [None]:
plot_val_with_title_from_ds_no_denorm(most_by_correct(1, True), "Most correctly classified mfb1KO") 

In [None]:
plot_val_with_title_from_ds_no_denorm(most_by_correct(1, False), "Most incorrectly classified mfb1KO")

In [None]:
# class index 2 is mmr1KO according to the printed class-to-index mapping
plot_val_with_title_from_ds_no_denorm(most_by_correct(2, True), "Most correctly classified mmr1KO")

In [None]:
# class index 3 is mmr1KO-mfb1KO according to the printed class-to-index mapping
plot_val_with_title_from_ds_no_denorm(most_by_correct(3, True), "Most correctly classified mmr1KO-mfb1KO")

In [None]:
# etc.

#### Show (most) uncertain images

In [None]:
# Rank the images by their highest class probability and keep the 6 lowest;
# the same index array is also bound to `t`.
most_uncertain = t = np.argsort(np.amax(probs, axis = 1))[:6] # get best "guess" per image and list the least confident ones
plot_val_with_title_from_ds_no_denorm(most_uncertain, "Most uncertain predictions")