In [None]:
# Notebook setup: draw matplotlib figures inline and turn on IPython's
# autoreload extension in mode 2 so edited modules are picked up on re-run.
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [None]:
from resources.conv_learner import *
import os
from pathlib import Path
from os.path import basename
import skimage.external.tifffile as tiff
import pickle


In [None]:
# Dataset root, kept as a plain string (PATH) and as a pathlib.Path (data_path).
PATH = "datasets/yeast_v4"
data_path = Path(PATH)

In [None]:
# Collect the per-class image directories of the test and train splits.
# A directory two levels below `data_path` counts as "test" or "train" when
# that word appears in its path (same rule as before).
test_dirs , train_dirs = [], []

# Path.iterdir() yields entries in arbitrary order, so walk both levels in
# sorted order to make the result reproducible between runs and machines.
for ds_dir in sorted(data_path.iterdir()):
    # skip macOS metadata and stray files sitting next to the split folders
    if 'DS_Store' in str(ds_dir) or not ds_dir.is_dir():
        continue
    for class_dir in sorted(ds_dir.iterdir()):
        # only real directories can hold images; a stray file here would
        # otherwise crash the later `iterdir()` over the class directories
        if 'DS_Store' in str(class_dir) or not class_dir.is_dir():
            continue
        if 'test' in str(class_dir): test_dirs.append(class_dir)
        elif 'train' in str(class_dir): train_dirs.append(class_dir)

# Order both lists by class-folder name so that position i refers to the same
# class in each list (the statistics cell pairs them up with zip()).
test_dirs.sort(key=lambda d: d.name)
train_dirs.sort(key=lambda d: d.name)

In [None]:
import numpy as np

IMG_SIZE = 200      # images must be IMG_SIZE x IMG_SIZE in their last two axes
MAX_PIXEL = 65535   # divisor applied to mean/std (largest unsigned 16-bit value)

# stats maps class name -> (per-channel mean, per-channel std), both scaled by MAX_PIXEL
stats = {}

# Group the split directories by class-folder name instead of zipping the two
# lists positionally: the lists come from directory listings whose order is not
# guaranteed, so position i of one list need not match position i of the other.
# Test directories are added first, matching the original [test, train] order.
dirs_by_class = {}
for split_dir in list(test_dirs) + list(train_dirs):
    dirs_by_class.setdefault(split_dir.name, []).append(split_dir)

for class_name in sorted(dirs_by_class):
    class_images = []
    for dir_ in dirs_by_class[class_name]:
        # sorted() snapshots the listing first, because files may be deleted below
        for file in sorted(dir_.iterdir()):
            # skip macOS metadata and sub-directories; neither is a TIFF image
            if 'DS_Store' in file.name or not file.is_file():
                continue
            # `tiff` is the alias under which tifffile is imported at the top of
            # the notebook; the bare name `tifffile` is not bound by those imports
            image = tiff.imread(str(file))
            if image.shape[-2:] == (IMG_SIZE, IMG_SIZE):
                class_images.append(image)
            else:
                # WARNING: destructive - permanently deletes images of the wrong size
                os.remove(str(file)) # get rid of the non square images
                print(f"removed file: {str(file)}")

    # calc std
    print(class_name)
    if not class_images:
        # np.mean/np.std over an empty list would fail on axis=(0,2,3)
        print(f"no usable images for class: {class_name}")
        continue
    # stack once and reuse for both reductions; axis=(0,2,3) keeps axis 1 only,
    # which assumes every image is 3-D with channels first - TODO confirm
    stacked = np.stack(class_images)
    mean = np.mean(stacked, axis=(0,2,3)) / MAX_PIXEL
    stdev = np.std(stacked, axis=(0,2,3)) / MAX_PIXEL
    stats[class_name] = (mean, stdev)

In [None]:
# Persist the per-class normalisation statistics (the `stats` dict) as a pickle.

with open('./norm_stats.dict','wb') as stats_file:
    pickle.dump(stats, stats_file)

In [None]:
# Dataset root used by get_data(); re-declared here with a trailing slash
# (it was first set without one near the top of the notebook).
PATH = "datasets/yeast_v4/"

# Names of the four classes in this dataset.
classes = (
    'WT',
    'mfb1KO',
    'mfb1KO_mmr1KO',
    'mmr1KO',
)

In [None]:
def get_data(sz,bs):
    """Build the data object for image size `sz` and batch size `bs`.

    Reads the module-level `PATH` and `stats`. The per-class statistics are
    re-keyed from class name to the label index reported by
    `ImageClassifierData.from_paths_and_stats`, turned into transforms, and
    handed to the `create` callable returned by that same call.
    The label-to-index mapping is printed as a side effect.
    """
    create, lbl2index = ImageClassifierData.from_paths_and_stats(PATH, val_name='test', bs=bs)

    # class name -> stats becomes label index -> stats
    stats_dict = {}
    for label, class_stats in stats.items():
        stats_dict[lbl2index[label]] = class_stats

    # original author's note: even without transformations and padding -> failure
    augmentations = [RandomDihedral()]
    tfms = tfms_from_stats(stats_dict, sz, aug_tfms=augmentations, pad=sz//8)

    print(lbl2index)
    return create(tfms)

### Note: read_dirs, the helper that ImageClassifierData eventually calls, expects one subdirectory per class:
### e.g. every "cat" image under "test/" should be stored in a "cat" folder.

In [None]:
# Batch size handed to get_data() when the data object is built.
bs=64

In [None]:
# Build the data object with sz=200 and the batch size chosen above.
data = get_data(200,bs)

In [None]:
# Pull one batch from the training dataloader for inspection
# (presumably x = images and y = labels - confirm against the dataloader).
x,y=next(iter(data.trn_dl))

In [None]:
# Print the y half of the batch fetched above.
print(y)

In [None]:
# Preview one image of the batch after undoing the normalisation.
idx = 0
sample = data.trn_ds.denorm(x[idx], y[idx]).squeeze()
# the denorm function called has a rollaxis(), hence the indexing changes:
# the third axis is sliced at index 1 here
tiff.imshow(sample[:,:,1]);

# Training setup

In [None]:
class SimpleNet(nn.Module):
    """Fully connected classifier: flatten -> Linear/ReLU stack -> log-softmax.

    Args:
        layers: sequence of layer widths. ``layers[0]`` is the flattened input
            size and ``layers[-1]`` the number of outputs; one ``nn.Linear`` is
            created for every consecutive pair.

    Raises:
        ValueError: if fewer than two widths are given, because no layer could
            be built and ``forward`` would have nothing to return.
    """
    def __init__(self, layers):
        super().__init__()
        if len(layers) < 2:
            raise ValueError("layers must contain at least an input and an output size")
        self.layers = nn.ModuleList([
            nn.Linear(layers[i], layers[i + 1]) for i in range(len(layers) - 1)])

    def forward(self, x):
        """Return log-probabilities of shape ``(batch, layers[-1])``."""
        # flatten everything except the batch dimension
        x = x.view(x.size(0), -1)
        last = len(self.layers) - 1
        for i, l in enumerate(self.layers):
            x = l(x)
            # ReLU only between layers: the final layer's raw output feeds log_softmax
            if i < last: x = F.relu(x)
        return F.log_softmax(x, dim=-1)

In [None]:
# The last width is the number of output classes. It was hard-coded to 2 although
# this notebook declares four classes (see `classes`) and the other models use 4;
# derive it from `classes` so the output layer covers every label index.
learn = ConvLearner.from_model_data(SimpleNet([200*200*2, 40, len(classes)]), data) #(!) change channel-number accordingly

In [None]:
# Display the learner together with the element count of every model parameter tensor.
learn, [o.numel() for o in learn.model.parameters()]

In [None]:
# Show the learner's summary of the SimpleNet model.
learn.summary()

In [None]:
# Learning rate used by the SimpleNet fit call below.
lr=1e-5

In [None]:
# Time the SimpleNet training run. The positional 200 and cycle_len=1 go straight
# to learn.fit (presumably 200 cycles of one epoch each - confirm against the learner).
%time learn.fit(lr, 200, cycle_len=1)

## ConvNet

In [None]:
class ConvNet(nn.Module):
    """Stack of 5x5 stride-1 convolutions, global max pool, linear classifier.

    Args:
        layers: channel counts; one ``nn.Conv2d`` is built per consecutive pair,
            so ``layers[0]`` is the number of input channels.
        c: number of outputs of the final linear layer.
    """
    def __init__(self, layers, c):
        super().__init__()
        convs = []
        for n_in, n_out in zip(layers[:-1], layers[1:]):
            # kernel 5 with padding 2 at stride 1 keeps the spatial size unchanged
            convs.append(nn.Conv2d(n_in, n_out, kernel_size=5, stride=1, padding=(2,2)))
        self.layers = nn.ModuleList(convs)
        self.pool = nn.AdaptiveMaxPool2d(1)
        self.out = nn.Linear(layers[-1], c)

    def forward(self, x):
        """Return log-probabilities of shape ``(batch, c)``."""
        for conv in self.layers:
            x = F.relu(conv(x))
        # pool every channel down to 1x1, then drop the two spatial axes
        pooled = self.pool(x)
        flat = pooled.view(pooled.size(0), -1)
        return F.log_softmax(self.out(flat), dim=-1)

In [None]:
# Wrap a ConvNet (2 input channels -> 20 -> 40 -> 80 conv channels, 4 outputs) in a learner.
learn = ConvLearner.from_model_data(ConvNet([2, 20, 40, 80], 4), data)

In [None]:
learn.summary() ### NOTE: learner.summary hard-codes the number of channels to 3

In [None]:
# Learning rate used by the ConvNet fit call below.
# NOTE(review): 1e-10 is many orders of magnitude below the other rates in this
# notebook (1e-5, 1e-2, 1e-3) - confirm it is intentional.
lr=1e-10

In [None]:
# Time the ConvNet training run with the learning rate set above
# (positional 10 and cycle_len=1 are passed straight to learn.fit).
%time learn.fit(lr, 10, cycle_len=1)

In [None]:
# Check whether PyTorch reports a usable CUDA device.
torch.cuda.is_available()

## ResNet_with_Batchnorm

In [None]:
class BnLayer(nn.Module):
    """Conv -> ReLU -> hand-written per-channel normalisation with learned scale/shift.

    While training, each forward pass normalises with the mean and standard
    deviation of the current batch (per output channel) and remembers them.
    In eval mode the statistics of the most recent training batch are reused.

    Args:
        ni: number of input channels.
        nf: number of output channels.
        stride: convolution stride (default 2).
        kernel_size: convolution kernel size (default 3); padding is fixed at 1.
        eps: small constant added to the standard deviation so that a channel
            with zero variance (e.g. all zeros after the ReLU) does not produce
            a division by zero.
    """
    def __init__(self, ni, nf, stride=2, kernel_size=3, eps=1e-5):
        super().__init__()
        self.conv = nn.Conv2d(ni, nf, kernel_size=kernel_size, stride=stride,
                              bias=False, padding=1)
        self.a = nn.Parameter(torch.zeros(nf,1,1))  # learned shift
        self.m = nn.Parameter(torch.ones(nf,1,1))   # learned scale
        self.eps = eps
        # statistics of the last training batch; None until the first training pass
        self.means = None
        self.stds = None

    def forward(self, x):
        """Apply conv + ReLU, then normalise every channel and scale/shift it."""
        x = F.relu(self.conv(x))
        if self.training or self.means is None:
            # one row per channel holding every activation of that channel
            x_chan = x.transpose(0,1).contiguous().view(x.size(1), -1)
            means = x_chan.mean(1)[:,None,None]
            stds  = x_chan.std (1)[:,None,None]
            if self.training:
                # keep detached copies so eval passes do not hold on to the
                # autograd graph of the last training batch
                self.means, self.stds = means.detach(), stds.detach()
            # in eval mode with nothing stored yet we fall back to the batch
            # statistics instead of failing on the missing attributes
        else:
            means, stds = self.means, self.stds
        return (x-means) / (stds + self.eps) *self.m + self.a

In [None]:
class ResnetLayer(BnLayer):
    """BnLayer with a skip connection: the input is added to the layer's output."""
    def forward(self, x):
        # the parent's output must have the same shape as x for the sum to work
        residual = super().forward(x)
        return x + residual

In [None]:
class Resnet(nn.Module):
    """Small residual network built from BnLayer / ResnetLayer stages.

    A 5x5 stem convolution maps the input to ``layers[0]`` channels. Then, for
    every consecutive pair in ``layers``, one stage runs: a ``BnLayer`` that
    changes the channel count (with its default stride) followed by two
    stride-1 ``ResnetLayer`` blocks at the new width. The result is globally
    max-pooled and classified by a linear layer.

    Args:
        layers: channel widths; ``layers[0]`` is the stem's output width.
        c: number of outputs of the final linear layer.
        in_channels: channels of the input images (default 2, the value that
            was previously hard-coded).
    """
    def __init__(self, layers, c, in_channels=2):
        super().__init__()
        # the stem must emit layers[0] channels because the first BnLayer expects
        # exactly that many; this used to be a literal 10
        self.conv1 = nn.Conv2d(in_channels, layers[0], kernel_size=5, stride=1, padding=2)
        self.layers = nn.ModuleList([BnLayer(layers[i], layers[i+1])
            for i in range(len(layers) - 1)])
        self.layers2 = nn.ModuleList([ResnetLayer(layers[i+1], layers[i + 1], 1)
            for i in range(len(layers) - 1)])
        self.layers3 = nn.ModuleList([ResnetLayer(layers[i+1], layers[i + 1], 1)
            for i in range(len(layers) - 1)])
        self.out = nn.Linear(layers[-1], c)

    def forward(self, x):
        """Return log-probabilities of shape ``(batch, c)``."""
        x = self.conv1(x)
        # one stage = channel-changing BnLayer followed by two residual blocks
        for l,l2,l3 in zip(self.layers, self.layers2, self.layers3):
            x = l3(l2(l(x)))
        x = F.adaptive_max_pool2d(x, 1)
        x = x.view(x.size(0), -1)
        return F.log_softmax(self.out(x), dim=-1)

In [None]:
# Wrap a Resnet with widths 10 -> 20 -> 40 -> 80 -> 160 and 4 outputs in a learner.
learn = ConvLearner.from_model_data(Resnet([10, 20, 40, 80, 160], 4), data)

In [None]:
# Show the learner's summary of the Resnet model.
learn.summary()

In [None]:
# Weight decay passed as wds= to the Resnet fit calls below.
wd=1e-5

In [None]:
# First timed Resnet run: lr 1e-2, positional 8, cycle_len=4, weight decay wd, no use_clr.
%time learn.fit(1e-2, 8, cycle_len=4, wds=wd)

In [None]:
# Same lr, weight decay and cycle_len on the same learner, now with use_clr=(20,8, 0.95, 0.85).
%time learn.fit(1e-2, 8, wds=wd, cycle_len=4, use_clr=(20,8, 0.95, 0.85))

In [None]:
# Keep training the same learner with cycle_len raised to 40 (other arguments unchanged).
%time learn.fit(1e-2, 8, wds=wd, cycle_len=40, use_clr=(20,8, 0.95, 0.85))

In [None]:
# Keep training the same learner with the learning rate lowered to 1e-3 (other arguments unchanged).
%time learn.fit(1e-3, 8, wds=wd, cycle_len=40, use_clr=(20,8, 0.95, 0.85))

In [None]:
# Save the trained Resnet learner under the name 'tmp_resnet_clr' via learn.save.
learn.save('tmp_resnet_clr')