<a href="https://colab.research.google.com/github/Orasz/CNN4COVID19/blob/main/unetDiceLoss.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [24]:
!pip install git+https://github.com/albumentations-team/albumentations.git

Collecting git+https://github.com/albumentations-team/albumentations.git
  Cloning https://github.com/albumentations-team/albumentations.git to /tmp/pip-req-build-9c5thrvk
  Running command git clone -q https://github.com/albumentations-team/albumentations.git /tmp/pip-req-build-9c5thrvk
Building wheels for collected packages: albumentations
  Building wheel for albumentations (setup.py) ... [?25l[?25hdone
  Created wheel for albumentations: filename=albumentations-0.5.2-cp37-none-any.whl size=86173 sha256=781c3c475be02ca0d33004af224a87a2e9fabc25c5d68829b66616a0d6fe6d76
  Stored in directory: /tmp/pip-ephem-wheel-cache-xf3ruqig/wheels/e2/85/3e/2a40fac5cc1f43ced656603bb2fca1327b30ec7de1b1b66517
Successfully built albumentations


In [25]:

import os
from PIL import Image
from torch.utils.data import Dataset
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.optim as optim


#DATASET


class QaTaDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = os.listdir(image_dir)
        self.masks = os.listdir(mask_dir)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        img_path = os.path.join(self.image_dir, self.images[index])
        mask_path = os.path.join(self.mask_dir, self.masks[index])
        image = np.array(Image.open(img_path).convert("RGB"))
        mask = np.array(Image.open(mask_path).convert("L"), dtype=np.float32)
        mask[mask == 255.0] = 1.0

        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]

        return image, mask

In [26]:
#UTILS
import torch
import torchvision
from torch.utils.data import DataLoader

def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    print("=> Saving checkpoint")
    torch.save(state, filename)

def load_checkpoint(checkpoint, model):
    print("=> Loading checkpoint")
    model.load_state_dict(checkpoint["state_dict"])

def get_loaders(
    train_dir,
    train_maskdir,
    val_dir,
    val_maskdir,
    batch_size,
    train_transform,
    val_transform,
    num_workers=4,
    pin_memory=True,
):
    train_ds = QaTaDataset(
        image_dir=train_dir,
        mask_dir=train_maskdir,
        transform=train_transform,
    )

    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=True,
    )

    val_ds = QaTaDataset(
        image_dir=val_dir,
        mask_dir=val_maskdir,
        transform=val_transform,
    )

    val_loader = DataLoader(
        val_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=False,
    )

    return train_loader, val_loader

def check_accuracy(loader, model, device="cuda"):
    num_correct = 0
    num_pixels = 0
    dice_score = 0
    model.eval()

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device).unsqueeze(1)
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()
            num_correct += (preds == y).sum()
            num_pixels += torch.numel(preds)
            dice_score += (2 * (preds * y).sum()) / (
                (preds + y).sum() + 1e-8
            )

    print(
        f"Got {num_correct}/{num_pixels} with acc {num_correct/num_pixels*100:.2f}"
    )
    print(f"Dice score: {dice_score/len(loader)}")
    model.train()

def save_predictions_as_imgs(
    loader, model, folder="saved_images/", device="cuda"
):
    model.eval()
    for idx, (x, y) in enumerate(loader):
        x = x.to(device=device)
        with torch.no_grad():
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()
        torchvision.utils.save_image(
            preds, f"{folder}/pred_{idx}.png"
        )
        torchvision.utils.save_image(y.unsqueeze(1), f"{folder}{idx}.png")

    model.train()

In [27]:
from torch import nn
import torch


@torch.jit.script
def autocrop(encoder_layer: torch.Tensor, decoder_layer: torch.Tensor):
    """
    Center-crops the encoder_layer to the size of the decoder_layer,
    so that merging (concatenation) between levels/blocks is possible.
    This is only necessary for input sizes != 2**n for 'same' padding and always required for 'valid' padding.
    """
    if encoder_layer.shape[2:] != decoder_layer.shape[2:]:
        ds = encoder_layer.shape[2:]
        es = decoder_layer.shape[2:]
        assert ds[0] >= es[0]
        assert ds[1] >= es[1]
        if encoder_layer.dim() == 4:  # 2D
            encoder_layer = encoder_layer[
                            :,
                            :,
                            ((ds[0] - es[0]) // 2):((ds[0] + es[0]) // 2),
                            ((ds[1] - es[1]) // 2):((ds[1] + es[1]) // 2)
                            ]
        elif encoder_layer.dim() == 5:  # 3D
            assert ds[2] >= es[2]
            encoder_layer = encoder_layer[
                            :,
                            :,
                            ((ds[0] - es[0]) // 2):((ds[0] + es[0]) // 2),
                            ((ds[1] - es[1]) // 2):((ds[1] + es[1]) // 2),
                            ((ds[2] - es[2]) // 2):((ds[2] + es[2]) // 2),
                            ]
    return encoder_layer, decoder_layer


def conv_layer(dim: int):
    if dim == 3:
        return nn.Conv3d
    elif dim == 2:
        return nn.Conv2d


def get_conv_layer(in_channels: int,
                   out_channels: int,
                   kernel_size: int = 3,
                   stride: int = 1,
                   padding: int = 1,
                   bias: bool = True,
                   dim: int = 2):
    return conv_layer(dim)(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
                           bias=bias)


def conv_transpose_layer(dim: int):
    if dim == 3:
        return nn.ConvTranspose3d
    elif dim == 2:
        return nn.ConvTranspose2d


def get_up_layer(in_channels: int,
                 out_channels: int,
                 kernel_size: int = 2,
                 stride: int = 2,
                 dim: int = 3,
                 up_mode: str = 'transposed',
                 ):
    if up_mode == 'transposed':
        return conv_transpose_layer(dim)(in_channels, out_channels, kernel_size=kernel_size, stride=stride)
    else:
        return nn.Upsample(scale_factor=2.0, mode=up_mode)


def maxpool_layer(dim: int):
    if dim == 3:
        return nn.MaxPool3d
    elif dim == 2:
        return nn.MaxPool2d


def get_maxpool_layer(kernel_size: int = 2,
                      stride: int = 2,
                      padding: int = 0,
                      dim: int = 2):
    return maxpool_layer(dim=dim)(kernel_size=kernel_size, stride=stride, padding=padding)


def get_activation(activation: str):
    if activation == 'relu':
        return nn.ReLU()
    elif activation == 'leaky':
        return nn.LeakyReLU(negative_slope=0.1)
    elif activation == 'elu':
        return nn.ELU()


def get_normalization(normalization: str,
                      num_channels: int,
                      dim: int):
    if normalization == 'batch':
        if dim == 3:
            return nn.BatchNorm3d(num_channels)
        elif dim == 2:
            return nn.BatchNorm2d(num_channels)
    elif normalization == 'instance':
        if dim == 3:
            return nn.InstanceNorm3d(num_channels)
        elif dim == 2:
            return nn.InstanceNorm2d(num_channels)
    elif 'group' in normalization:
        num_groups = int(normalization.partition('group')[-1])  # get the group size from string
        return nn.GroupNorm(num_groups=num_groups, num_channels=num_channels)


class Concatenate(nn.Module):
    def __init__(self):
        super(Concatenate, self).__init__()

    def forward(self, layer_1, layer_2):
        x = torch.cat((layer_1, layer_2), 1)

        return x


class DownBlock(nn.Module):
    """
    A helper Module that performs 2 Convolutions and 1 MaxPool.
    An activation follows each convolution.
    A normalization layer follows each convolution.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 pooling: bool = True,
                 activation: str = 'relu',
                 normalization: str = None,
                 dim: str = 2,
                 conv_mode: str = 'same'):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.pooling = pooling
        self.normalization = normalization
        if conv_mode == 'same':
            self.padding = 1
        elif conv_mode == 'valid':
            self.padding = 0
        self.dim = dim
        self.activation = activation

        # conv layers
        self.conv1 = get_conv_layer(self.in_channels, self.out_channels, kernel_size=3, stride=1, padding=self.padding,
                                    bias=True, dim=self.dim)
        self.conv2 = get_conv_layer(self.out_channels, self.out_channels, kernel_size=3, stride=1, padding=self.padding,
                                    bias=True, dim=self.dim)

        # pooling layer
        if self.pooling:
            self.pool = get_maxpool_layer(kernel_size=2, stride=2, padding=0, dim=self.dim)

        # activation layers
        self.act1 = get_activation(self.activation)
        self.act2 = get_activation(self.activation)

        # normalization layers
        if self.normalization:
            self.norm1 = get_normalization(normalization=self.normalization, num_channels=self.out_channels,
                                           dim=self.dim)
            self.norm2 = get_normalization(normalization=self.normalization, num_channels=self.out_channels,
                                           dim=self.dim)

    def forward(self, x):
        y = self.conv1(x)  # convolution 1
        y = self.act1(y)  # activation 1
        if self.normalization:
            y = self.norm1(y)  # normalization 1
        y = self.conv2(y)  # convolution 2
        y = self.act2(y)  # activation 2
        if self.normalization:
            y = self.norm2(y)  # normalization 2

        before_pooling = y  # save the outputs before the pooling operation
        if self.pooling:
            y = self.pool(y)  # pooling
        return y, before_pooling


class UpBlock(nn.Module):
    """
    A helper Module that performs 2 Convolutions and 1 UpConvolution/Upsample.
    An activation follows each convolution.
    A normalization layer follows each convolution.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 activation: str = 'relu',
                 normalization: str = None,
                 dim: int = 3,
                 conv_mode: str = 'same',
                 up_mode: str = 'transposed'
                 ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.normalization = normalization
        if conv_mode == 'same':
            self.padding = 1
        elif conv_mode == 'valid':
            self.padding = 0
        self.dim = dim
        self.activation = activation
        self.up_mode = up_mode

        # upconvolution/upsample layer
        self.up = get_up_layer(self.in_channels, self.out_channels, kernel_size=2, stride=2, dim=self.dim,
                               up_mode=self.up_mode)

        # conv layers
        self.conv0 = get_conv_layer(self.in_channels, self.out_channels, kernel_size=1, stride=1, padding=0,
                                    bias=True, dim=self.dim)
        self.conv1 = get_conv_layer(2 * self.out_channels, self.out_channels, kernel_size=3, stride=1,
                                    padding=self.padding,
                                    bias=True, dim=self.dim)
        self.conv2 = get_conv_layer(self.out_channels, self.out_channels, kernel_size=3, stride=1, padding=self.padding,
                                    bias=True, dim=self.dim)

        # activation layers
        self.act0 = get_activation(self.activation)
        self.act1 = get_activation(self.activation)
        self.act2 = get_activation(self.activation)

        # normalization layers
        if self.normalization:
            self.norm0 = get_normalization(normalization=self.normalization, num_channels=self.out_channels,
                                           dim=self.dim)
            self.norm1 = get_normalization(normalization=self.normalization, num_channels=self.out_channels,
                                           dim=self.dim)
            self.norm2 = get_normalization(normalization=self.normalization, num_channels=self.out_channels,
                                           dim=self.dim)

        # concatenate layer
        self.concat = Concatenate()

    def forward(self, encoder_layer, decoder_layer):
        """ Forward pass
        Arguments:
            encoder_layer: Tensor from the encoder pathway
            decoder_layer: Tensor from the decoder pathway (to be up'd)
        """
        up_layer = self.up(decoder_layer)  # up-convolution/up-sampling
        cropped_encoder_layer, dec_layer = autocrop(encoder_layer, up_layer)  # cropping

        if self.up_mode != 'transposed':
            # We need to reduce the channel dimension with a conv layer
            up_layer = self.conv0(up_layer)  # convolution 0
        up_layer = self.act0(up_layer)  # activation 0
        if self.normalization:
            up_layer = self.norm0(up_layer)  # normalization 0

        merged_layer = self.concat(up_layer, cropped_encoder_layer)  # concatenation
        y = self.conv1(merged_layer)  # convolution 1
        y = self.act1(y)  # activation 1
        if self.normalization:
            y = self.norm1(y)  # normalization 1
        y = self.conv2(y)  # convolution 2
        y = self.act2(y)  # acivation 2
        if self.normalization:
            y = self.norm2(y)  # normalization 2
        return y


class UNet(nn.Module):
    def __init__(self,
                 in_channels: int = 1,
                 out_channels: int = 2,
                 n_blocks: int = 4,
                 start_filters: int = 32,
                 activation: str = 'relu',
                 normalization: str = 'batch',
                 conv_mode: str = 'same',
                 dim: int = 2,
                 up_mode: str = 'transposed'
                 ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.n_blocks = n_blocks
        self.start_filters = start_filters
        self.activation = activation
        self.normalization = normalization
        self.conv_mode = conv_mode
        self.dim = dim
        self.up_mode = up_mode

        self.down_blocks = []
        self.up_blocks = []

        # create encoder path
        for i in range(self.n_blocks):
            num_filters_in = self.in_channels if i == 0 else num_filters_out
            num_filters_out = self.start_filters * (2 ** i)
            pooling = True if i < self.n_blocks - 1 else False

            down_block = DownBlock(in_channels=num_filters_in,
                                   out_channels=num_filters_out,
                                   pooling=pooling,
                                   activation=self.activation,
                                   normalization=self.normalization,
                                   conv_mode=self.conv_mode,
                                   dim=self.dim)

            self.down_blocks.append(down_block)

        # create decoder path (requires only n_blocks-1 blocks)
        for i in range(n_blocks - 1):
            num_filters_in = num_filters_out
            num_filters_out = num_filters_in // 2

            up_block = UpBlock(in_channels=num_filters_in,
                               out_channels=num_filters_out,
                               activation=self.activation,
                               normalization=self.normalization,
                               conv_mode=self.conv_mode,
                               dim=self.dim,
                               up_mode=self.up_mode)

            self.up_blocks.append(up_block)

        # final convolution
        self.conv_final = get_conv_layer(num_filters_out, self.out_channels, kernel_size=1, stride=1, padding=0,
                                         bias=True, dim=self.dim)

        # add the list of modules to current module
        self.down_blocks = nn.ModuleList(self.down_blocks)
        self.up_blocks = nn.ModuleList(self.up_blocks)

        # initialize the weights
        self.initialize_parameters()

    @staticmethod
    def weight_init(module, method, **kwargs):
        if isinstance(module, (nn.Conv3d, nn.Conv2d, nn.ConvTranspose3d, nn.ConvTranspose2d)):
            method(module.weight, **kwargs)  # weights

    @staticmethod
    def bias_init(module, method, **kwargs):
        if isinstance(module, (nn.Conv3d, nn.Conv2d, nn.ConvTranspose3d, nn.ConvTranspose2d)):
            method(module.bias, **kwargs)  # bias

    def initialize_parameters(self,
                              method_weights=nn.init.xavier_uniform_,
                              method_bias=nn.init.zeros_,
                              kwargs_weights={},
                              kwargs_bias={}
                              ):
        for module in self.modules():
            self.weight_init(module, method_weights, **kwargs_weights)  # initialize weights
            self.bias_init(module, method_bias, **kwargs_bias)  # initialize bias

    def forward(self, x: torch.tensor):
        encoder_output = []

        # Encoder pathway
        for module in self.down_blocks:
            x, before_pooling = module(x)
            encoder_output.append(before_pooling)

        # Decoder pathway
        for i, module in enumerate(self.up_blocks):
            before_pool = encoder_output[-(i + 2)]
            x = module(before_pool, x)

        x = self.conv_final(x)

        return x

    def __repr__(self):
        attributes = {attr_key: self.__dict__[attr_key] for attr_key in self.__dict__.keys() if '_' not in attr_key[0] and 'training' not in attr_key}
        d = {self.__class__.__name__: attributes}
        return f'{d}'

In [36]:
#TRAINING
# Hyperparameters etc.
LEARNING_RATE = 1e-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 16
NUM_EPOCHS = 10
NUM_WORKERS = 2
IMAGE_HEIGHT = 256  
IMAGE_WIDTH = 256 
PIN_MEMORY = True
LOAD_MODEL = False
TRAIN_IMG_DIR = "/content/drive/MyDrive/UNET/train_img/"
TRAIN_MASK_DIR = "/content/drive/MyDrive/UNET/train_masks/"
VAL_IMG_DIR = "/content/drive/MyDrive/UNET/val_img/"
VAL_MASK_DIR = "/content/drive/MyDrive/UNET/val_masks/"

class BinaryDiceLoss(nn.Module):
    """Dice loss of binary class
    Args:
        smooth: A float number to smooth loss, and avoid NaN error, default: 1
        p: Denominator value: \sum{x^p} + \sum{y^p}, default: 2
        predict: A tensor of shape [N, *]
        target: A tensor of shape same with predict
        reduction: Reduction method to apply, return mean over batch if 'mean',
            return sum if 'sum', return a tensor of shape [N,] if 'none'
    Returns:
        Loss tensor according to arg reduction
    Raise:
        Exception if unexpected reduction
    """
    def __init__(self, smooth=1, p=2, reduction='mean'):
        super(BinaryDiceLoss, self).__init__()
        self.smooth = smooth
        self.p = p
        self.reduction = reduction

    def forward(self, predict, target):
        assert predict.shape[0] == target.shape[0], "predict & target batch size don't match"
        predict = predict.contiguous().view(predict.shape[0], -1)
        target = target.contiguous().view(target.shape[0], -1)

        num = torch.sum(torch.mul(predict, target), dim=1) + self.smooth
        den = torch.sum(predict.pow(self.p) + target.pow(self.p), dim=1) + self.smooth

        loss = 1 - num / den

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        elif self.reduction == 'none':
            return loss
        else:
            raise Exception('Unexpected reduction {}'.format(self.reduction))

def train_fn(loader, model, optimizer, loss_fn, scaler):
    loop = tqdm(loader)
    loss2 = nn.BCEWithLogitsLoss()
    counter = 1
    for batch_idx, (data, targets) in enumerate(loop):
        data = data.to(device=DEVICE)
        targets = targets.float().unsqueeze(1).to(device=DEVICE)

        # forward
        with torch.cuda.amp.autocast():
            predictions = model(data)
            loss = loss_fn(predictions, targets) + loss2(predictions,targets)

        # backward
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # update tqdm loop
        loop.set_postfix(loss=loss.item())
        counter +=1


train_transform = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            #A.centerCrop(224),
            A.Rotate(limit=35, p=1.0),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.1),
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
)

val_transforms = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            #A.centerCrop(224),
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
)

model = UNet(in_channels=3,
             out_channels=1,
             n_blocks=4,
             start_filters=32,
             activation='relu',
             normalization='batch',
             conv_mode='same',
             dim=2).to(DEVICE)
             #    model = UNET(in_channels=3, out_channels=1).to(DEVICE)
loss_fn = BinaryDiceLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

train_loader, val_loader = get_loaders(
        TRAIN_IMG_DIR,
        TRAIN_MASK_DIR,
        VAL_IMG_DIR,
        VAL_MASK_DIR,
        BATCH_SIZE,
        train_transform,
        val_transforms,
        NUM_WORKERS,
        PIN_MEMORY,
)

check_accuracy(val_loader, model, device=DEVICE)
scaler = torch.cuda.amp.GradScaler()
counter = 1
for epoch in range(NUM_EPOCHS):
        print(f"epoch: {counter}")
        train_fn(train_loader, model, optimizer, loss_fn, scaler)
        check_accuracy(val_loader, model, device=DEVICE)
        counter+=1
        # print some examples to a folder
save_predictions_as_imgs(
            val_loader, model, folder="/content/drive/MyDrive/UNET/saved_images/", device=DEVICE)









  0%|          | 0/148 [00:00<?, ?it/s][A[A[A[A[A[A

Got 5743595/38731776 with acc 14.83
Dice score: 0.2555655241012573
epoch: 1








  0%|          | 0/148 [00:00<?, ?it/s, loss=1.94][A[A[A[A[A[A





  1%|          | 1/148 [00:00<01:07,  2.18it/s, loss=1.94][A[A[A[A[A[A





  1%|          | 1/148 [00:00<01:07,  2.18it/s, loss=1.96][A[A[A[A[A[A





  1%|▏         | 2/148 [00:00<00:58,  2.51it/s, loss=1.96][A[A[A[A[A[A





  1%|▏         | 2/148 [00:00<00:58,  2.51it/s, loss=1.88][A[A[A[A[A[A





  2%|▏         | 3/148 [00:00<00:50,  2.88it/s, loss=1.88][A[A[A[A[A[A





  2%|▏         | 3/148 [00:01<00:50,  2.88it/s, loss=1.86][A[A[A[A[A[A





  3%|▎         | 4/148 [00:01<00:44,  3.23it/s, loss=1.86][A[A[A[A[A[A





  3%|▎         | 4/148 [00:01<00:44,  3.23it/s, loss=1.84][A[A[A[A[A[A





  3%|▎         | 5/148 [00:01<00:40,  3.51it/s, loss=1.84][A[A[A[A[A[A





  3%|▎         | 5/148 [00:01<00:40,  3.51it/s, loss=1.82][A[A[A[A[A[A





  4%|▍         | 6/148 [00:01<00:37,  3.77it/s, loss=1.82][A[A[A[A[A[A





  4%|▍         | 6

Got 26095760/38731776 with acc 67.38
Dice score: 0.39436763525009155
epoch: 2








  0%|          | 0/148 [00:00<?, ?it/s, loss=1.53][A[A[A[A[A[A





  1%|          | 1/148 [00:00<01:13,  1.99it/s, loss=1.53][A[A[A[A[A[A





  1%|          | 1/148 [00:00<01:13,  1.99it/s, loss=1.52][A[A[A[A[A[A





  1%|▏         | 2/148 [00:00<01:01,  2.38it/s, loss=1.52][A[A[A[A[A[A





  1%|▏         | 2/148 [00:00<01:01,  2.38it/s, loss=1.59][A[A[A[A[A[A





  2%|▏         | 3/148 [00:00<00:51,  2.80it/s, loss=1.59][A[A[A[A[A[A





  2%|▏         | 3/148 [00:01<00:51,  2.80it/s, loss=1.55][A[A[A[A[A[A





  3%|▎         | 4/148 [00:01<00:45,  3.15it/s, loss=1.55][A[A[A[A[A[A





  3%|▎         | 4/148 [00:01<00:45,  3.15it/s, loss=1.55][A[A[A[A[A[A





  3%|▎         | 5/148 [00:01<00:41,  3.46it/s, loss=1.55][A[A[A[A[A[A





  3%|▎         | 5/148 [00:01<00:41,  3.46it/s, loss=1.56][A[A[A[A[A[A





  4%|▍         | 6/148 [00:01<00:38,  3.72it/s, loss=1.56][A[A[A[A[A[A





  4%|▍         | 6

Got 24836227/38731776 with acc 64.12
Dice score: 0.39483004808425903
epoch: 3








  0%|          | 0/148 [00:00<?, ?it/s, loss=1.56][A[A[A[A[A[A





  1%|          | 1/148 [00:00<01:14,  1.98it/s, loss=1.56][A[A[A[A[A[A





  1%|          | 1/148 [00:00<01:14,  1.98it/s, loss=1.54][A[A[A[A[A[A





  1%|▏         | 2/148 [00:00<01:01,  2.37it/s, loss=1.54][A[A[A[A[A[A





  1%|▏         | 2/148 [00:00<01:01,  2.37it/s, loss=1.52][A[A[A[A[A[A





  2%|▏         | 3/148 [00:00<00:53,  2.73it/s, loss=1.52][A[A[A[A[A[A





  2%|▏         | 3/148 [00:01<00:53,  2.73it/s, loss=1.52][A[A[A[A[A[A





  3%|▎         | 4/148 [00:01<00:46,  3.07it/s, loss=1.52][A[A[A[A[A[A





  3%|▎         | 4/148 [00:01<00:46,  3.07it/s, loss=1.55][A[A[A[A[A[A





  3%|▎         | 5/148 [00:01<00:42,  3.38it/s, loss=1.55][A[A[A[A[A[A





  3%|▎         | 5/148 [00:01<00:42,  3.38it/s, loss=1.59][A[A[A[A[A[A





  4%|▍         | 6/148 [00:01<00:39,  3.63it/s, loss=1.59][A[A[A[A[A[A





  4%|▍         | 6

Got 27185830/38731776 with acc 70.19
Dice score: 0.4152595102787018
epoch: 4








  0%|          | 0/148 [00:00<?, ?it/s, loss=1.57][A[A[A[A[A[A





  1%|          | 1/148 [00:00<01:14,  1.99it/s, loss=1.57][A[A[A[A[A[A





  1%|          | 1/148 [00:00<01:14,  1.99it/s, loss=1.49][A[A[A[A[A[A





  1%|▏         | 2/148 [00:00<01:01,  2.36it/s, loss=1.49][A[A[A[A[A[A





  1%|▏         | 2/148 [00:00<01:01,  2.36it/s, loss=1.56][A[A[A[A[A[A





  2%|▏         | 3/148 [00:00<00:52,  2.77it/s, loss=1.56][A[A[A[A[A[A





  2%|▏         | 3/148 [00:01<00:52,  2.77it/s, loss=1.52][A[A[A[A[A[A





  3%|▎         | 4/148 [00:01<00:45,  3.14it/s, loss=1.52][A[A[A[A[A[A





  3%|▎         | 4/148 [00:01<00:45,  3.14it/s, loss=1.55][A[A[A[A[A[A





  3%|▎         | 5/148 [00:01<00:41,  3.45it/s, loss=1.55][A[A[A[A[A[A





  3%|▎         | 5/148 [00:01<00:41,  3.45it/s, loss=1.57][A[A[A[A[A[A





  4%|▍         | 6/148 [00:01<00:37,  3.74it/s, loss=1.57][A[A[A[A[A[A





  4%|▍         | 6

Got 27828948/38731776 with acc 71.85
Dice score: 0.42170655727386475
epoch: 5








  0%|          | 0/148 [00:00<?, ?it/s, loss=1.55][A[A[A[A[A[A





  1%|          | 1/148 [00:00<01:16,  1.92it/s, loss=1.55][A[A[A[A[A[A





  1%|          | 1/148 [00:00<01:16,  1.92it/s, loss=1.57][A[A[A[A[A[A





  1%|▏         | 2/148 [00:00<01:03,  2.29it/s, loss=1.57][A[A[A[A[A[A





  1%|▏         | 2/148 [00:00<01:03,  2.29it/s, loss=1.55][A[A[A[A[A[A





  2%|▏         | 3/148 [00:00<00:53,  2.70it/s, loss=1.55][A[A[A[A[A[A





  2%|▏         | 3/148 [00:01<00:53,  2.70it/s, loss=1.51][A[A[A[A[A[A





  3%|▎         | 4/148 [00:01<00:47,  3.06it/s, loss=1.51][A[A[A[A[A[A





  3%|▎         | 4/148 [00:01<00:47,  3.06it/s, loss=1.56][A[A[A[A[A[A





  3%|▎         | 5/148 [00:01<00:42,  3.37it/s, loss=1.56][A[A[A[A[A[A





  3%|▎         | 5/148 [00:01<00:42,  3.37it/s, loss=1.52][A[A[A[A[A[A





  4%|▍         | 6/148 [00:01<00:38,  3.67it/s, loss=1.52][A[A[A[A[A[A





  4%|▍         | 6

Got 27333931/38731776 with acc 70.57
Dice score: 0.4230419397354126
epoch: 6








  0%|          | 0/148 [00:00<?, ?it/s, loss=1.54][A[A[A[A[A[A





  1%|          | 1/148 [00:00<01:15,  1.95it/s, loss=1.54][A[A[A[A[A[A





  1%|          | 1/148 [00:00<01:15,  1.95it/s, loss=1.51][A[A[A[A[A[A





  1%|▏         | 2/148 [00:00<01:02,  2.32it/s, loss=1.51][A[A[A[A[A[A





  1%|▏         | 2/148 [00:00<01:02,  2.32it/s, loss=1.53][A[A[A[A[A[A





  2%|▏         | 3/148 [00:00<00:54,  2.67it/s, loss=1.53][A[A[A[A[A[A





  2%|▏         | 3/148 [00:01<00:54,  2.67it/s, loss=1.55][A[A[A[A[A[A





  3%|▎         | 4/148 [00:01<00:47,  3.06it/s, loss=1.55][A[A[A[A[A[A





  3%|▎         | 4/148 [00:01<00:47,  3.06it/s, loss=1.5] [A[A[A[A[A[A





  3%|▎         | 5/148 [00:01<00:41,  3.42it/s, loss=1.5][A[A[A[A[A[A





  3%|▎         | 5/148 [00:01<00:41,  3.42it/s, loss=1.51][A[A[A[A[A[A





  4%|▍         | 6/148 [00:01<00:38,  3.68it/s, loss=1.51][A[A[A[A[A[A





  4%|▍         | 6/

Got 27956250/38731776 with acc 72.18
Dice score: 0.4243256449699402
epoch: 7








  0%|          | 0/148 [00:00<?, ?it/s, loss=1.48][A[A[A[A[A[A





  1%|          | 1/148 [00:00<01:14,  1.97it/s, loss=1.48][A[A[A[A[A[A





  1%|          | 1/148 [00:00<01:14,  1.97it/s, loss=1.52][A[A[A[A[A[A





  1%|▏         | 2/148 [00:00<01:01,  2.37it/s, loss=1.52][A[A[A[A[A[A





  1%|▏         | 2/148 [00:00<01:01,  2.37it/s, loss=1.5] [A[A[A[A[A[A





  2%|▏         | 3/148 [00:00<00:52,  2.74it/s, loss=1.5][A[A[A[A[A[A





  2%|▏         | 3/148 [00:01<00:52,  2.74it/s, loss=1.52][A[A[A[A[A[A





  3%|▎         | 4/148 [00:01<00:46,  3.08it/s, loss=1.52][A[A[A[A[A[A





  3%|▎         | 4/148 [00:01<00:46,  3.08it/s, loss=1.5] [A[A[A[A[A[A





  3%|▎         | 5/148 [00:01<00:41,  3.41it/s, loss=1.5][A[A[A[A[A[A





  3%|▎         | 5/148 [00:01<00:41,  3.41it/s, loss=1.53][A[A[A[A[A[A





  4%|▍         | 6/148 [00:01<00:38,  3.66it/s, loss=1.53][A[A[A[A[A[A





  4%|▍         | 6/1

Got 27327477/38731776 with acc 70.56
Dice score: 0.4233703315258026
epoch: 8








  0%|          | 0/148 [00:00<?, ?it/s, loss=1.56][A[A[A[A[A[A





  1%|          | 1/148 [00:00<01:16,  1.92it/s, loss=1.56][A[A[A[A[A[A





  1%|          | 1/148 [00:00<01:16,  1.92it/s, loss=1.51][A[A[A[A[A[A





  1%|▏         | 2/148 [00:00<01:04,  2.26it/s, loss=1.51][A[A[A[A[A[A





  1%|▏         | 2/148 [00:00<01:04,  2.26it/s, loss=1.51][A[A[A[A[A[A





  2%|▏         | 3/148 [00:01<00:54,  2.66it/s, loss=1.51][A[A[A[A[A[A





  2%|▏         | 3/148 [00:01<00:54,  2.66it/s, loss=1.55][A[A[A[A[A[A





  3%|▎         | 4/148 [00:01<00:47,  3.06it/s, loss=1.55][A[A[A[A[A[A





  3%|▎         | 4/148 [00:01<00:47,  3.06it/s, loss=1.49][A[A[A[A[A[A





  3%|▎         | 5/148 [00:01<00:42,  3.33it/s, loss=1.49][A[A[A[A[A[A





  3%|▎         | 5/148 [00:01<00:42,  3.33it/s, loss=1.52][A[A[A[A[A[A





  4%|▍         | 6/148 [00:01<00:39,  3.62it/s, loss=1.52][A[A[A[A[A[A





  4%|▍         | 6

Got 28864686/38731776 with acc 74.52
Dice score: 0.4345478415489197
epoch: 9








  0%|          | 0/148 [00:00<?, ?it/s, loss=1.47][A[A[A[A[A[A





  1%|          | 1/148 [00:00<01:15,  1.94it/s, loss=1.47][A[A[A[A[A[A





  1%|          | 1/148 [00:00<01:15,  1.94it/s, loss=1.48][A[A[A[A[A[A





  1%|▏         | 2/148 [00:00<01:04,  2.26it/s, loss=1.48][A[A[A[A[A[A





  1%|▏         | 2/148 [00:01<01:04,  2.26it/s, loss=1.5] [A[A[A[A[A[A





  2%|▏         | 3/148 [00:01<00:55,  2.62it/s, loss=1.5][A[A[A[A[A[A





  2%|▏         | 3/148 [00:01<00:55,  2.62it/s, loss=1.47][A[A[A[A[A[A





  3%|▎         | 4/148 [00:01<00:47,  3.01it/s, loss=1.47][A[A[A[A[A[A





  3%|▎         | 4/148 [00:01<00:47,  3.01it/s, loss=1.45][A[A[A[A[A[A





  3%|▎         | 5/148 [00:01<00:42,  3.38it/s, loss=1.45][A[A[A[A[A[A





  3%|▎         | 5/148 [00:01<00:42,  3.38it/s, loss=1.5] [A[A[A[A[A[A





  4%|▍         | 6/148 [00:01<00:38,  3.68it/s, loss=1.5][A[A[A[A[A[A





  4%|▍         | 6/1

Got 28588054/38731776 with acc 73.81
Dice score: 0.43276509642601013
epoch: 10








  0%|          | 0/148 [00:00<?, ?it/s, loss=1.5][A[A[A[A[A[A





  1%|          | 1/148 [00:00<01:16,  1.92it/s, loss=1.5][A[A[A[A[A[A





  1%|          | 1/148 [00:00<01:16,  1.92it/s, loss=1.49][A[A[A[A[A[A





  1%|▏         | 2/148 [00:00<01:03,  2.30it/s, loss=1.49][A[A[A[A[A[A





  1%|▏         | 2/148 [00:00<01:03,  2.30it/s, loss=1.51][A[A[A[A[A[A





  2%|▏         | 3/148 [00:00<00:53,  2.69it/s, loss=1.51][A[A[A[A[A[A





  2%|▏         | 3/148 [00:01<00:53,  2.69it/s, loss=1.49][A[A[A[A[A[A





  3%|▎         | 4/148 [00:01<00:47,  3.05it/s, loss=1.49][A[A[A[A[A[A





  3%|▎         | 4/148 [00:01<00:47,  3.05it/s, loss=1.45][A[A[A[A[A[A





  3%|▎         | 5/148 [00:01<00:42,  3.38it/s, loss=1.45][A[A[A[A[A[A





  3%|▎         | 5/148 [00:01<00:42,  3.38it/s, loss=1.49][A[A[A[A[A[A





  4%|▍         | 6/148 [00:01<00:39,  3.60it/s, loss=1.49][A[A[A[A[A[A





  4%|▍         | 6/1

Got 28602015/38731776 with acc 73.85
Dice score: 0.41587552428245544
