In [1]:
from PIL import Image
import numpy as np
import tifffile

# Load data

In [2]:
# Raw input slice (TIFF) read from disk; np.array(..., copy=True) keeps an
# independent ndarray copy for the downstream processing steps.
INPUT_IMAGE_PATH = '../../../pictures/rec_00076.tif'
input_image = tifffile.imread(INPUT_IMAGE_PATH)
input_matrix = np.array(input_image, copy=True)

In [3]:
# Shape of the loaded input slice.
np.shape(input_matrix)

(2615, 2615)

In [4]:
output_matrix = tifffile.imread('../../../pictures/Label_400_final.tif')

# The label mask must be pixel-aligned with the input slice: inputs and labels are
# augmented separately below and then paired by index in train_test_split, so a
# shape mismatch here would silently produce mismatched (input, label) pairs.
assert output_matrix.shape == input_matrix.shape, (
    f"label shape {output_matrix.shape} != input shape {input_matrix.shape}"
)

output_matrix.shape

(2615, 2615)

## Preprocess Data

In [5]:
from kern_segmentation.src.preprocessing.Preparer import Preparer

# Preprocessing parameters passed straight through to Preparer.calculate.
PREP_KERNEL_SIZE = 3
PREP_SIGMA = 0.5

# Besides the processed image, Preparer returns two auxiliary maps; going by their
# names they are presumably Sobel gradients along x and y (TODO confirm in Preparer).
data, sobelx, sobely = Preparer.calculate(
    data=input_matrix,
    kernel_size=PREP_KERNEL_SIZE,
    sigma=PREP_SIGMA,
)

## Augmentation

In [6]:
from kern_segmentation.src.preprocessing.Augmentator import Augmentator

# Augment the processed image together with its two auxiliary maps (X_train below
# has 3 channels, so presumably they end up stacked as channels — TODO confirm).
# NOTE: this rebinds `data`; re-running this cell without first re-running the
# preprocessing cell feeds already-augmented data back into augment_data.
augmented = Augmentator.augment_data(data, sobelx, sobely)
data = augmented

100%|██████████| 20/20 [00:12<00:00,  1.55it/s]
100%|██████████| 20/20 [00:10<00:00,  1.99it/s]
100%|██████████| 20/20 [00:09<00:00,  2.04it/s]


In [7]:
import matplotlib.pyplot as plt

# Augment the label mask. The split below pairs data[i] with label[i], so this is
# assumed to apply the same transforms, in the same order, as augment_data
# (NOTE(review): confirm in Augmentator, especially if augmentations are random).
label = Augmentator.augment_label(output_matrix)

100%|██████████| 20/20 [00:09<00:00,  2.07it/s]


In [8]:
from sklearn.model_selection import train_test_split

# NOTE(review): 65% of the samples are held out, so only 35% are used for training;
# confirm this is intended.
TEST_FRACTION = 0.65
SPLIT_SEED = 42

# NOTE(review): the split runs after augmentation; if augment_data emits several
# variants of the same source patch, those variants can land in both train and test.
X_train, X_test, y_train, y_test = train_test_split(
    data,
    label,
    test_size=TEST_FRACTION,
    random_state=SPLIT_SEED,
)

In [9]:
# Shape of the training inputs; the trailing axis is presumably the channel axis.
np.shape(X_train)

(5040, 128, 128, 3)

# AI models

## Fully convolutional network (FCNN)

In [None]:
from torch import nn


class ConvBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> activation -> Dropout2d at stride 1.

    Args:
        in_channels: number of input feature maps.
        out_channels: number of output feature maps.
        kernel_size: convolution kernel size (int or pair). Padding is derived from
            it, so odd kernels keep the input height/width unchanged.
        activation: 'relu' selects nn.ReLU; any other value selects nn.PReLU.
        dropout: Dropout2d probability (default 0.2, the previous fixed value).
    """

    def __init__(self, in_channels, out_channels, kernel_size, activation, dropout=0.2):
        super(ConvBlock, self).__init__()
        # The previous hard-coded padding=1 only preserved the spatial size for 3x3
        # kernels; kernel_size // 2 does so for every odd kernel (and gives 1 for 3).
        if isinstance(kernel_size, int):
            padding = kernel_size // 2
        else:
            padding = tuple(k // 2 for k in kernel_size)
        self.conv_layer = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding)
        self.batch_norm_layer = nn.BatchNorm2d(out_channels)
        self.activation_layer = nn.ReLU() if activation == 'relu' else nn.PReLU()
        self.dropout_layer = nn.Dropout2d(p=dropout)

    def forward(self, x):
        """Apply conv, batch norm, activation and dropout, in that order."""
        x = self.conv_layer(x)
        x = self.batch_norm_layer(x)
        x = self.activation_layer(x)
        x = self.dropout_layer(x)
        return x

class FullyCNN(nn.Module):
    """Fully convolutional network: a stack of 3x3 ConvBlocks followed by a 1-channel 3x3 conv head.

    Args:
        in_channels: number of channels of the input tensor.
        pretrained_weights: optional state dict (any mapping) or a path to a file
            written with ``torch.save(model.state_dict(), path)``. When given, it is
            loaded into the freshly built network (it was previously ignored).
        hidden_channels: width of every hidden ConvBlock (default 32).
        n_blocks: number of hidden->hidden ConvBlocks after the input block
            (default 8); stored as ``self.n``.
    """

    def __init__(self, in_channels, pretrained_weights=None, hidden_channels=32, n_blocks=8):
        super(FullyCNN, self).__init__()

        self.conv_layers_3x3 = nn.ModuleList([ConvBlock(in_channels, hidden_channels, kernel_size=3, activation='relu')])

        self.n = n_blocks
        for _ in range(self.n):
            self.conv_layers_3x3.append(ConvBlock(hidden_channels, hidden_channels, kernel_size=3, activation='relu'))

        self.output_layer = nn.Conv2d(hidden_channels, 1, kernel_size=3, stride=1, padding=1)

        if pretrained_weights is not None:
            self._load_pretrained(pretrained_weights)

    def _load_pretrained(self, pretrained_weights):
        """Load a state-dict mapping, or a path to a saved state dict, into this network."""
        from collections.abc import Mapping

        import torch  # this cell only imports `nn`, so bring torch in locally

        if isinstance(pretrained_weights, Mapping):
            state_dict = pretrained_weights
        else:
            # torch.load unpickles the file: only load weight files from trusted sources.
            state_dict = torch.load(pretrained_weights, map_location='cpu')
        self.load_state_dict(state_dict)

    def forward(self, x):
        """Run x through every ConvBlock in order, then the 1-channel output conv."""
        for conv_layer in self.conv_layers_3x3:
            x = conv_layer(x)
        x = self.output_layer(x)
        return x

# Train models

In [None]:
from torch.utils.data import TensorDataset, DataLoader
import torch

# Convert the split arrays to float32 tensors in one explicit step. This replaces a
# try/except that caught every exception and retried with .astype(np.float32).
# np.ascontiguousarray converts any numeric dtype (and list input). It also yields
# positive, contiguous strides, which torch.from_numpy requires.
X_tensor = torch.from_numpy(np.ascontiguousarray(X_train, dtype=np.float32))
y_tensor = torch.from_numpy(np.ascontiguousarray(y_train, dtype=np.float32))
# NOTE(review): X_train is channels-last (N, H, W, C) per the shape shown above,
# while nn.Conv2d expects (N, C, H, W); confirm that `train` permutes the input.

dataset = TensorDataset(X_tensor, y_tensor)

batch_size = 32
# NOTE(review): `batch_size` is not used by this loader, which uses batch_size=1.
# In the visible cells, train_loader is only passed to visualize_random_sample
# (train() receives `dataset`); confirm which batch size is intended.
train_loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=4)

## Useful functions

## Train the FCNN

In [None]:
from kern_segmentation.src.utils import train

# train FCNN
EPOCH = 1

# Seed torch so that weight initialisation and Dropout2d masks are reproducible
# across runs. This also seeds data shuffling if `train` uses a torch DataLoader.
# The value matches the random_state=42 used for the train/test split.
torch.manual_seed(42)

# NOTE(review): in_channels=2, but X_train above has 3 entries on its last axis;
# confirm that `train` selects 2 channels, otherwise Conv2d will reject the input.
model = FullyCNN(in_channels=2)
criterion = nn.L1Loss()

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)


def log_metrics(*args):
    """Print the positional metrics reported by `train` as a single tuple."""
    print(args)


train(model, dataset, criterion, optimizer, EPOCH, log_metrics, 'fcnn_a')

In [None]:
from kern_segmentation.src.utils import visualize_random_sample
%%time
# better test on 4396
visualize_random_sample(model, train_loader)