# Финальный пайплайн обучения и теста


In [None]:
!pip install vesuvius
!vesuvius.accept_terms --yes

In [None]:
import vesuvius
from vesuvius import Volume

In [None]:
from matplotlib import pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.amp import GradScaler, autocast
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch.optim as optim
from torch.nn.utils import clip_grad_norm_

import numpy as np
from tqdm import tqdm

In [None]:
class VolumetricDataset(Dataset):
    def __init__(self, volume, label, tile_size, stride, validation_zone, valid=False, min_ink_fraction=0.05):
        """
        Initialize the dataset with the volume and label.

        volume (np.ndarray): The volumetric image of shape (Z, Y, X).
        label (np.ndarray): The 2D label of shape (Y, X).
        tile_size (int): The size of the tiles to extract along the Y and X dimensions.
        stride (int): The stride for extracting tiles along the Y and X dimensions.
        validation_zone (sequence of 4 ints): Held-out rectangle
            [y_min, y_max, x_min, x_max], half-open, in the same coordinates as ``label``.
        valid (bool): If True, keep only tiles lying fully inside ``validation_zone``.
            If False, keep only tiles that do not overlap it and contain enough ink.
        min_ink_fraction (float): A training tile is kept only when the sum of its
            label values divided by the tile area exceeds this value.
            It is ignored when ``valid`` is True.
        """

        self.volume = volume
        self.label = label
        self.tile_size = tile_size
        self.stride = stride
        self.validation_zone = validation_zone
        self.valid = valid
        self.min_ink_fraction = min_ink_fraction
        self.tiles, self.labels, self.corners = self.extract_tiles()

    def extract_tiles(self):
        """
        Extract 3D tiles from the volume and corresponding 2D labels.

        Returns:
            tiles (list): 3D tiles of shape (Z, tile_size, tile_size).
            labels (list): 2D label tiles of shape (tile_size, tile_size).
            corners (list): [y, x] top-left corner of every tile.
        """
        _, height, width = self.volume.shape
        size = self.tile_size
        y_min, y_max, x_min, x_max = self.validation_zone
        tiles, labels, corners = [], [], []

        # Slide a square window over the Y/X plane; every tile keeps the full Z depth.
        for y in range(0, height - size + 1, self.stride):
            for x in range(0, width - size + 1, self.stride):
                if self.valid:
                    inside = (y_min <= y and y + size <= y_max
                              and x_min <= x and x + size <= x_max)
                    if not inside:
                        continue
                else:
                    # Two rectangles overlap only if their row ranges AND column ranges
                    # overlap. Requiring separation along both axes (the old check)
                    # also discarded every tile in the validation rows/columns band.
                    overlaps_rows = y < y_max and y + size > y_min
                    overlaps_cols = x < x_max and x + size > x_min
                    if overlaps_rows and overlaps_cols:
                        continue

                label_tile = self.label[y:y + size, x:x + size]
                # Training tiles must carry enough ink to be informative.
                if not self.valid and np.sum(label_tile) / size ** 2 <= self.min_ink_fraction:
                    continue

                tiles.append(self.volume[:, y:y + size, x:x + size])
                labels.append(label_tile)
                corners.append([y, x])

        return tiles, labels, corners

    def __len__(self):
        return len(self.tiles)

    def __getitem__(self, idx):
        """Return (tile, label, corner) tensors shaped (1, Z, T, T), (1, T, T) and (1, 2)."""
        tile = torch.tensor(self.tiles[idx], dtype=torch.float32).unsqueeze(0)
        label = torch.tensor(self.labels[idx], dtype=torch.float32).unsqueeze(0)
        corners = torch.tensor(self.corners[idx], dtype=torch.int).unsqueeze(0)
        return tile, label, corners

In [None]:
def fractal_dimension(binary):
    '''
    Calculate the box-counting fractal dimension of binary images.

    For every box size r in (1, 2, 4, 6, 8, 10) that fits in a block, count the
    boxes N(r) containing at least one True pixel. Then fit a least-squares line
    to the points (ln(1/r), ln N(r)); its slope is the fractal dimension.

    Args:
        binary: bool tensor of shape (batch_size, blocks_num, height, width).

    Returns:
        float tensor of shape (batch_size, blocks_num). It is all zeros when fewer
        than two box sizes fit, because a line fit needs at least two points.
    '''
    device = binary.device
    batch_size, blocks_num, height, width = binary.shape
    min_dim = min(height, width)
    # Plain Python ints: used directly as pad/unfold sizes, with no device round-trips.
    scales = [s for s in (1, 2, 4, 6, 8, 10) if s <= min_dim]
    if len(scales) < 2:
        return torch.zeros(batch_size, blocks_num, device=device)

    binary_flat = binary.reshape(-1, height, width)
    counts = []
    for scale in scales:
        # Pad with False so the block splits into whole scale x scale boxes;
        # the padding never creates an occupied box.
        pad_h = (scale - height % scale) % scale
        pad_w = (scale - width % scale) % scale
        padded = F.pad(binary_flat, (0, pad_w, 0, pad_h), value=False)

        # split into boxes and count the boxes that contain any True pixel
        boxes = padded.unfold(1, scale, scale).unfold(2, scale, scale)
        counts.append(boxes.any(dim=-1).any(dim=-1).sum(dim=(1, 2)))

    scale_t = torch.tensor(scales, device=device, dtype=torch.float32)
    log_scales = torch.log(1 / scale_t + 1e-8)  # ln(1/r)
    # ln N(r); the 1e-8 keeps empty blocks finite (they end up with slope ~0)
    log_counts = torch.log(torch.stack(counts, dim=1).float() + 1e-8)

    # Batched least squares: [ln(1/r), 1] @ [slope, intercept]^T ~= ln N(r)
    design = torch.stack([log_scales, torch.ones_like(log_scales)], dim=1)
    design = design.repeat(binary_flat.size(0), 1, 1)
    coeffs = torch.linalg.lstsq(design, log_counts.unsqueeze(-1)).solution.squeeze(-1)

    # The slope (first coefficient) is the fractal dimension. Reshape so both
    # return paths have the same (batch_size, blocks_num) shape.
    return coeffs[..., 0].reshape(batch_size, blocks_num)


def calc_fractal_features(input_volume, windows=(2, 4, 8), thresholds=None, absolute=True, add_image_channel=False):
    '''
    Calculate fractal-dimension feature maps of a volume.

    Each (sample, z-slice) plane is binarised with every threshold:
    ``binary = image > threshold``. In relative mode (``absolute=False``, or when
    ``thresholds`` is None) each threshold is first multiplied by that plane's mean.
    The binary plane is then split into non-overlapping ``window x window``
    blocks. The box-counting fractal dimension of each block is painted back onto
    the block's pixels.

    Args:
        input_volume: tensor of shape (batch, channels, z, y, x); channels is
            expected to be 1.
        windows: block sizes. Sizes above min(y, x) // 2 are dropped.
        thresholds: binarisation thresholds. None means a single relative
            threshold of 1.0, i.e. the per-plane mean.
        absolute: if False, thresholds are relative to the per-plane mean.
        add_image_channel: prepend the raw image as feature channel 0.

    Returns:
        tensor of shape (batch, n_features, z, y, x), where
        n_features = len(thresholds) * len(kept windows) + add_image_channel.
        Channels are ordered threshold-major, window-minor.
    '''
    device = input_volume.device
    batch_size, channels, z, y, x = input_volume.shape  # channels = 1

    volume = input_volume.reshape(batch_size * channels, z, y, x)

    # Windows are capped at half of the smaller spatial side; use plain ints.
    max_window_size = min(y, x) // 2
    windows = [int(w) for w in windows if int(w) <= max_window_size]

    if thresholds is None:
        thresholds = [1.]
        absolute = False
    if not absolute:
        # relative mode: scale every threshold by the mean of each (sample, z) plane
        avg = volume.mean(dim=(2, 3))
        thresholds = [threshold * avg for threshold in thresholds]

    num_features = len(thresholds) * len(windows) + int(bool(add_image_channel))
    features = torch.zeros(batch_size, num_features, z, y, x, device=device)
    cur_channel = 0

    if add_image_channel:
        features[:, 0] = volume
        cur_channel += 1

    for threshold in thresholds:
        # relative thresholds are per-plane tensors; align them with (batch, z, y, x)
        if isinstance(threshold, torch.Tensor):
            threshold = threshold.view(batch_size, z, 1, 1)

        binary = volume > threshold
        for window in windows:
            # Pad with False so y and x become multiples of the window
            pad_y = (window - y % window) % window
            pad_x = (window - x % window) % window
            padded_binary = F.pad(binary, (0, pad_x, 0, pad_y), value=False)

            # split the plane into window x window blocks
            unfolded = padded_binary.unfold(2, window, window).unfold(3, window, window)
            _, _, y_unf, x_unf, _, _ = unfolded.shape

            # fractal dimension of every block
            fd = fractal_dimension(
                unfolded.reshape(batch_size * z, y_unf * x_unf, window, window)
            ).reshape(batch_size, z, y_unf, x_unf)

            # Paint each block's value onto exactly its own pixels, then crop the
            # padding. Unlike nearest interpolation, this stays aligned when y/x are
            # not multiples of the window. No squeeze: z may legitimately be 1.
            fd_full = fd.repeat_interleave(window, dim=2).repeat_interleave(window, dim=3)
            features[:, cur_channel] = fd_full[:, :, :y, :x]
            cur_channel += 1

    return features

In [None]:
class UNet_with_FF(nn.Module):
    '''
    2D U-Net head on top of 3D fractal-feature inputs.

    The input is the tensor built by ``calc_fractal_features`` with the same
    ``windows`` / ``thresholds`` / ``add_image_channel`` settings, shaped
    (batch, in_channels, z, y, x). Processing happens in three steps:
    - a 3D conv lifts the input to 128 channels;
    - a learned softmax weighting over the z axis collapses depth;
    - a 3-level 2D U-Net produces a single logit map of shape (batch, 1, y, x).

    y and x must be divisible by 8 so the pooled and upsampled skip connections
    line up.
    '''

    def __init__(self, z=16, y=256, x=256, windows=(2, 4, 8, 16), thresholds=None, add_image_channel=False):
        super().__init__()
        self.y = y
        self.x = x
        self.z = z
        # calc_fractal_features drops windows above min(y, x) // 2. Apply the same
        # rule here so in_channels matches the number of feature channels.
        max_window_size = min(y, x) // 2
        self.windows = [w for w in windows if w <= max_window_size]
        self.thresholds = thresholds
        threshold_num = 1 if thresholds is None else len(thresholds)
        in_channels = len(self.windows) * threshold_num
        if add_image_channel:
            in_channels += 1
        self.in_channels = in_channels

        # input layers (registration order is kept for checkpoint compatibility)
        self.conv3d = nn.Conv3d(self.in_channels, 128, kernel_size=3, padding=1)
        self.attn = nn.Sequential(
            nn.Conv3d(128, 1, kernel_size=1),
            nn.Softmax(dim=2)  # softmax over z-dimension (depth)
        )

        self.in_conv = nn.Conv2d(128, 128, kernel_size=3, padding=1)

        # contracting path
        self.enc_conv1 = self.double_conv(128, 128)
        self.enc_conv2 = self.double_conv(128, 256)
        self.enc_conv3 = self.double_conv(256, 512)
        self.enc_conv4 = self.double_conv(512, 1024)

        # expansive path
        self.up_trans1 = self.up_conv(1024, 512)
        self.dec_conv1 = self.double_conv(1024, 512)
        self.up_trans2 = self.up_conv(512, 256)
        self.dec_conv2 = self.double_conv(512, 256)
        self.up_trans3 = self.up_conv(256, 128)
        self.dec_conv3 = self.double_conv(256, 128)

        # final output
        self.out_conv = nn.Conv2d(128, 1, kernel_size=1)

    def forward(self, x):
        """Map (batch, in_channels, z, y, x) features to (batch, 1, y, x) logits."""
        x = self.conv3d(x)
        # attention-weighted sum over depth collapses the 3D volume to a 2D map
        attn = self.attn(x)
        x = torch.sum(x * attn, dim=2)

        x = self.in_conv(x)

        # contracting path
        x1 = self.enc_conv1(x)
        x2 = self.enc_conv2(F.max_pool2d(x1, kernel_size=2))
        x3 = self.enc_conv3(F.max_pool2d(x2, kernel_size=2))
        x4 = self.enc_conv4(F.max_pool2d(x3, kernel_size=2))

        # expansive path with skip connections
        x = self.up_trans1(x4)
        x = torch.cat([x, x3], dim=1)
        x = self.dec_conv1(x)

        x = self.up_trans2(x)
        x = torch.cat([x, x2], dim=1)
        x = self.dec_conv2(x)

        x = self.up_trans3(x)
        x = torch.cat([x, x1], dim=1)
        x = self.dec_conv3(x)

        return self.out_conv(x)

    def double_conv(self, in_channels, out_channels):
        """Two 3x3 same-padding convolutions, each followed by ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    def up_conv(self, in_channels, out_channels):
        """2x spatial upsampling via a stride-2 transposed convolution."""
        return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)


def initialize_weights(model):
    '''
    Re-initialise the convolution and batch-norm parameters of ``model`` in place.

    Conv2d, Conv3d and ConvTranspose2d layers get Kaiming-normal weights
    (fan_out mode, ReLU gain) and zero biases. BatchNorm2d/3d layers get
    weight 1 and bias 0. All other module types are left untouched.
    '''
    conv_kinds = (nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d)
    norm_kinds = (nn.BatchNorm2d, nn.BatchNorm3d)
    for layer in model.modules():
        if isinstance(layer, conv_kinds):
            nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)
            continue
        if isinstance(layer, norm_kinds):
            nn.init.constant_(layer.weight, 1)
            nn.init.constant_(layer.bias, 0)

In [None]:
segment_id = 20230827161847

tile_size = 256
stride = 128  # stride for moving the tile in the YX dimension
batch_size = 4
z_depth = 16  # thickness of the tile
z_center = 32  # the Z window is [z_center - z_depth//2, z_center + z_depth//2)

segment = Volume(segment_id, normalize=True)
# held-out rectangle [y_min, y_max, x_min, x_max], in the cropped coordinates below
validation_rect = [3260, 3260 + 512, 1860, 1860 + 512]

# Crop the volume and label once and share them between both datasets. Previously
# each dataset sliced `segment` itself, which doubled the slicing work (and any
# data loading Volume slicing may trigger).
volume_crop = segment[(z_center - z_depth // 2):(z_center + z_depth // 2), 200:5600, 1000:4600, 0]
label_crop = segment.inklabel[200:5600, 1000:4600] / 255

valid_dataset = VolumetricDataset(volume_crop, label_crop, tile_size, stride,
                                  validation_zone=validation_rect, valid=True)
train_dataset = VolumetricDataset(volume_crop, label_crop, tile_size, stride,
                                  validation_zone=validation_rect, valid=False)

valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [None]:
import random


def get_random_thresholds(k, min_val=0.05, max_val=0.55, step=0.05, max_size=5):
    """Draw ``k`` distinct, sorted random subsets of candidate thresholds.

    Candidates are ``np.arange(min_val, max_val, step)``. Each subset has between
    1 and ``max_size`` elements, capped at the number of candidates.

    Raises:
        ValueError: if fewer than ``k`` distinct subsets exist. Without this check
            the sampling loop would never terminate.
    """
    import math

    # candidates do not change between draws; build them once
    possible_values = np.arange(min_val, max_val, step).tolist()
    largest = min(max_size, len(possible_values))
    available = sum(math.comb(len(possible_values), n) for n in range(1, largest + 1))
    if k > available:
        raise ValueError(f"requested {k} distinct threshold sets but only {available} exist")

    ans = []
    while len(ans) < k:
        # subset size never exceeds the population, so random.sample cannot fail
        n = random.randint(1, largest)
        random_numbers = sorted(random.sample(possible_values, k=n))
        if random_numbers not in ans:
            ans.append(random_numbers)
    return ans

def get_random_windows(k, tile_size=256, max_size=7):
    """Draw ``k`` distinct, sorted random subsets of power-of-two window sizes.

    Candidates are 2, 4, 8, ... up to the first power of two that is at least
    ``tile_size // 2``. Each subset has between 1 and ``max_size`` elements,
    capped at the number of candidates.

    Raises:
        ValueError: if fewer than ``k`` distinct subsets exist. Without this check
            the sampling loop would never terminate.
    """
    import math

    # candidates do not change between draws; build them once
    possible_values = [2]
    while possible_values[-1] < tile_size // 2:
        possible_values.append(possible_values[-1] * 2)

    largest = min(max_size, len(possible_values))
    available = sum(math.comb(len(possible_values), n) for n in range(1, largest + 1))
    if k > available:
        raise ValueError(f"requested {k} distinct window sets but only {available} exist")

    ans = []
    while len(ans) < k:
        # subset size never exceeds the population, so random.sample cannot fail
        n = random.randint(1, largest)
        random_numbers = sorted(random.sample(possible_values, k=n))
        if random_numbers not in ans:
            ans.append(random_numbers)
    return ans

# Fixed hyper-parameter grid for the sweep below.
abs_thresholds_search = [[0.05, 0.2, 0.35000000000000003]]
relative_thresholds_search = [[0.5, 1.75, 2.25]]
# Dyadic ladders [2, 4, ..., 2**top] for top = 7 down to 2, then a few sparse picks.
widndows_search = [[2 ** p for p in range(1, top + 1)] for top in range(7, 1, -1)] + [
    [2, 8, 32, 128],
    [2, 16, 128],
    [2, 128],
    [4, 16, 128],
]
print(abs_thresholds_search, relative_thresholds_search, "\n", widndows_search)

In [None]:
from tqdm import tqdm
import itertools
import time

clip_value = 10.0

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_EPOCHS = 50
TIME_BUDGET_S = 18000  # wall-clock budget (seconds) for the whole sweep
TARGET_TRAIN_LOSS = 0.08  # a configuration stops once its mean train loss drops below this
early_stopping = 6  # stop after more than this many epochs without val-loss improvement


def _fractal_inputs(batch_tiles, windows, thresholds, add_image_channel):
    """Convert a batch of raw tiles into the fractal-feature tensor fed to the model."""
    return calc_fractal_features(batch_tiles,
                                 windows=windows,
                                 thresholds=thresholds,
                                 absolute=True,
                                 add_image_channel=add_image_channel)


def _train_one_epoch(model, optimizer, scaler, criterion, windows, thresholds, add_image_channel):
    """Run one mixed-precision optimisation pass over ``dataloader``; return the mean batch loss."""
    model.train()
    running_loss = 0.0
    for batch_tiles, batch_labels, _ in tqdm(dataloader, position=0, leave=False):
        batch_tiles, batch_labels = batch_tiles.to(device), batch_labels.float().to(device)
        optimizer.zero_grad()

        with autocast(device_type=device.type):
            outputs = model(_fractal_inputs(batch_tiles, windows, thresholds, add_image_channel))
            loss = criterion(outputs, batch_labels)

        scaler.scale(loss).backward()
        # unscale before clipping so the clip threshold applies to the true gradient norm
        scaler.unscale_(optimizer)
        clip_grad_norm_(model.parameters(), clip_value)
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()
    return running_loss / max(len(dataloader), 1)


def _validate(model, criterion, windows, thresholds, add_image_channel, canvas_shape):
    """Evaluate on ``valid_dataloader``.

    Returns ``(val_loss, letter_predictions)``:
    - ``val_loss`` is the mean per-batch loss as a Python float;
    - ``letter_predictions`` is a ``canvas_shape`` map where each pixel is the mean
      sigmoid probability over the validation tiles covering it (0 where none do).
    """
    letter_predictions = np.zeros(canvas_shape, dtype=np.float32)
    counter_predictions = np.zeros(canvas_shape, dtype=np.float32)
    val_loss = 0.0

    model.eval()
    with torch.no_grad():
        for batch_tiles, batch_labels, corners in valid_dataloader:
            batch_tiles, batch_labels = batch_tiles.to(device), batch_labels.float().to(device)

            with autocast(device_type=device.type):
                outputs = model(_fractal_inputs(batch_tiles, windows, thresholds, add_image_channel))
                # .item() keeps val_loss a plain float (clean logs and checkpoint names)
                val_loss += criterion(outputs, batch_labels).item()
                # apply sigmoid to get probabilities from logits
                predictions = torch.sigmoid(outputs)

            # copy the whole batch to the host once instead of once per tile
            prediction_tiles = predictions[:, 0].float().cpu().numpy()
            # corners are stored as [y, x] top-left positions of each tile
            for prediction_tile, (y_start, x_start) in zip(prediction_tiles, corners.squeeze(1).numpy()):
                region = (slice(y_start, y_start + tile_size), slice(x_start, x_start + tile_size))
                letter_predictions[region] += prediction_tile
                counter_predictions[region] += 1

    val_loss /= max(len(valid_dataloader), 1)

    # avoid division by zero for pixels no validation tile covers
    counter_predictions[counter_predictions == 0] = 1
    # average overlapping tile predictions
    letter_predictions /= counter_predictions
    return val_loss, letter_predictions


def _save_final(model, optimizer, label_crop, letter_predictions, windows, thresholds, add_image_channel, val_loss):
    """Save a ground-truth vs prediction figure and the final checkpoint for one configuration."""
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))
    panels = ((label_crop / 255, 'Ground Truth Label'), (letter_predictions, 'Model Prediction'))
    for ax, (image, title) in zip(axes, panels):
        ax.imshow(image, cmap='gray')
        ax.set_title(title)
        ax.axis('off')
    fig.savefig(f"Model windows={windows}, thresholds={thresholds}, add_image_channel={add_image_channel}.png")
    # one figure is created per configuration; close it so they do not accumulate
    plt.close(fig)
    torch.save({'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict()},
               f'final_model_windows_{windows}_thresholds_{thresholds}_image_{add_image_channel}_{val_loss}.pth')


label_crop = segment.inklabel[200:5600, 1000:4600]
start_time = time.time()

with open("logs.txt", "w") as logs:
    sweep = itertools.product(abs_thresholds_search, widndows_search, [False, True])
    for thresholds, windows, add_image_channel in sweep:
        # The budget covers the whole sweep. Once it is spent, stop instead of
        # giving every remaining configuration a single, meaningless epoch.
        if time.time() - start_time > TIME_BUDGET_S:
            print("Time budget exhausted, skipping remaining configurations.")
            break

        print(f"Model windows={windows}, thresholds={thresholds}, abs=True, add_image_channel={add_image_channel}")
        logs.write(f"Model {windows} {thresholds} {add_image_channel}\n")
        model = UNet_with_FF(z=z_depth, y=tile_size, x=tile_size,
                             windows=windows, thresholds=thresholds, add_image_channel=add_image_channel)
        initialize_weights(model)
        model = model.to(device)
        criterion = nn.BCEWithLogitsLoss().to(device)
        optimizer = optim.AdamW(model.parameters(), lr=1e-4)
        scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
        # loss scaling is only meaningful for CUDA float16 autocast
        scaler = GradScaler(enabled=device.type == "cuda")
        no_improvement = 0
        best_score = float("inf")

        for epoch in tqdm(range(NUM_EPOCHS)):
            train_loss = _train_one_epoch(model, optimizer, scaler, criterion,
                                          windows, thresholds, add_image_channel)
            scheduler.step()
            print(f'Epoch [{epoch+1}/{NUM_EPOCHS}], running loss: {train_loss}')

            val_loss, letter_predictions = _validate(model, criterion, windows, thresholds,
                                                     add_image_channel, label_crop.shape)
            print("Val Loss:", val_loss)

            if val_loss < best_score:
                no_improvement = 0
                best_score = val_loss
                torch.save({'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict()},
                           f'best_model_windows_{windows}_thresholds_{thresholds}_image_{add_image_channel}_epoch_{epoch}_{val_loss}.pth')
            else:
                no_improvement += 1

            logs.write(f"{epoch} {train_loss} {val_loss}\n")
            logs.flush()  # keep the log usable if a later configuration crashes

            # the last epoch also counts as finished, so its model is always saved and plotted
            finished = (train_loss < TARGET_TRAIN_LOSS
                        or time.time() - start_time > TIME_BUDGET_S
                        or no_improvement > early_stopping
                        or epoch == NUM_EPOCHS - 1)
            if finished:
                print(f"Final loss {train_loss}")
                _save_final(model, optimizer, label_crop, letter_predictions,
                            windows, thresholds, add_image_channel, val_loss)
                break

print("Training completed.")


