In [None]:
import numpy as np
import torch
import torch.nn as nn
from torchsummary import summary
from torch.optim import Adam
from tqdm import tqdm
import matplotlib.pyplot as plt

In [None]:
# Load the 256x256 image tiles and their masks.
# Both paths are placeholders and must be replaced with real .npy locations.
imgs, masks = (np.load(path) for path in ('PATH_to_imgs', 'PATH_tp_masks'))

In [None]:
# Augment images and masks: every sample is passed through each transform
# below, and all results are appended to the original arrays.
import albumentations as A
import random

# Python's `random` module is reseeded with this value right before each
# distortion transform (same positions as in the hand-unrolled version).
AUG_SEED = 7

# (transform, reseed_before_call) pairs, applied in this order to every sample.
# The transforms are built once instead of once per image per transform.
AUGMENTATIONS = [
    (A.HorizontalFlip(p=1), False),
    (A.VerticalFlip(p=1), False),
    (A.Transpose(p=1), False),
    (A.RandomRotate90(p=1), False),
    (A.ElasticTransform(p=1, alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03), True),
    (A.GridDistortion(p=1), True),
    (A.OpticalDistortion(distort_limit=2, shift_limit=0.5, p=1), True),
]

imgs_aug = []
masks_aug = []
for i in range(len(imgs)):
    for transform, reseed in AUGMENTATIONS:
        if reseed:
            random.seed(AUG_SEED)
        # Image and mask go through the same call so they stay aligned.
        augmented = transform(image=imgs[i], mask=masks[i])
        imgs_aug.append(augmented['image'])
        masks_aug.append(augmented['mask'])

# NOTE: `imgs` / `masks` are overwritten here, so re-running this cell without
# reloading the data augments the already augmented set again.
imgs = np.concatenate((imgs, np.array(imgs_aug)), axis = 0)
masks = np.concatenate((masks, np.array(masks_aug)), axis = 0)

In [None]:
# Convert colour masks to categorical (class-index) form
def image_cat(image, class_num, black_color = 128):
    """Turn a colour mask into a label map stored in channel 0.

    Args:
        image: array-like indexed as [row, col, channel]; channels 0 and 2
            are read.
        class_num: size of the last axis of the returned array.
        black_color: threshold; a channel value >= this counts as "set".

    Returns:
        Float array of shape (rows, cols, class_num), all zeros except
        channel 0, which holds 1 where input channel 0 is set and 2 where
        input channel 2 is set (2 wins when both are set).
    """
    pixels = np.asarray(image)
    labels = np.zeros((pixels.shape[0], pixels.shape[1], class_num))
    label_plane = labels[:, :, 0]  # view: writes land in `labels`
    label_plane[pixels[:, :, 0] >= black_color] = 1
    # Applied second on purpose so that label 2 overrides label 1.
    label_plane[pixels[:, :, 2] >= black_color] = 2
    return labels

# Label maps for every mask: image_cat is asked for a single channel, which is
# then squeezed away so each entry is a 2-D integer array of class indices.
segms = np.array(
    [image_cat(mask, 1, black_color = 128) for mask in masks],
    int,
).squeeze(3)

In [None]:
# Train/test split and DataLoader construction
# The split index used to be the placeholder string 'boundary sample', which
# np.split cannot interpret. It is now derived from TRAIN_FRACTION.
TRAIN_FRACTION = 0.8  # assumed default share of samples for training; adjust as needed
SPLIT_SEED = 42       # fixed seed so the split is the same on every run

split_rng = np.random.default_rng(SPLIT_SEED)
ix = split_rng.permutation(len(imgs))
boundary = int(len(ix) * TRAIN_FRACTION)
tr, ts = np.split(ix, [boundary])
print(len(tr), len(ts))

# np.rollaxis(..., 3, 1) moves the last axis to position 1 (channels-first),
# which is the layout nn.Conv2d expects.
train_batch = torch.utils.data.DataLoader(list(zip(np.rollaxis(imgs[tr], 3, 1), segms[tr])),
                                          batch_size=8, shuffle=True, pin_memory=True)

# Evaluation data does not need shuffling; a fixed order keeps the per-batch
# averaged metrics repeatable.
test_batch = torch.utils.data.DataLoader(list(zip(np.rollaxis(imgs[ts], 3, 1), segms[ts])),
                                         batch_size=8, shuffle=False, pin_memory=True)

In [None]:
# Model architecture
def _double_conv(in_ch, mid_ch, out_ch, momentum):
    """Return the layers Conv3x3-BN-ReLU-Conv3x3-BN-ReLU as a list.

    The first convolution maps in_ch -> mid_ch, the second mid_ch -> out_ch.
    Padding of 1 keeps the spatial size unchanged.
    """
    return [
        nn.Conv2d(in_channels=in_ch, out_channels=mid_ch, kernel_size=3, padding=(1, 1)),
        nn.BatchNorm2d(mid_ch, momentum=momentum),
        nn.ReLU(),
        nn.Conv2d(in_channels=mid_ch, out_channels=out_ch, kernel_size=3, padding=(1, 1)),
        nn.BatchNorm2d(out_ch, momentum=momentum),
        nn.ReLU(),
    ]


class unet_model(nn.Module):
    """U-Net style encoder/decoder with four pooling stages and skip connections.

    Args:
        input_nbr: number of input channels (default 5).
        num_classes: number of output channels / logits per pixel (default 3).
        num_ch: channel width of the first encoder stage; deeper stages use
            multiples of it (default 64).
        batchNorm_momentum: momentum passed to every BatchNorm2d (default 0.1).

    The input height and width must be divisible by 16: the four 2x poolings
    are undone by four 2x upsamplings, and the skip concatenations require
    matching spatial sizes.

    Attribute names and the layer order inside each nn.Sequential are the
    same as in the hand-unrolled version, so state_dict keys are unchanged.
    """

    def __init__(self, input_nbr=5, num_classes=3, num_ch=64, batchNorm_momentum=0.1):
        super().__init__()

        c = num_ch
        m = batchNorm_momentum

        # Encoder: channel width doubles at every stage.
        self.enc_conv0 = nn.Sequential(*_double_conv(input_nbr, c, c, m))
        self.pool0 = nn.MaxPool2d(kernel_size=2, return_indices=False)

        self.enc_conv1 = nn.Sequential(*_double_conv(c, c * 2, c * 2, m))
        self.pool1 = nn.MaxPool2d(kernel_size=2, return_indices=False)

        self.enc_conv2 = nn.Sequential(*_double_conv(c * 2, c * 4, c * 4, m))
        self.pool2 = nn.MaxPool2d(kernel_size=2, return_indices=False)

        self.enc_conv3 = nn.Sequential(*_double_conv(c * 4, c * 8, c * 8, m))
        self.pool3 = nn.MaxPool2d(kernel_size=2, return_indices=False)

        # Bottleneck widens to 16c and narrows back to 8c, so that after
        # concatenation with the 8c skip tensor the decoder sees 16c channels.
        self.bottleneck_enc = nn.Sequential(*_double_conv(c * 8, c * 16, c * 8, m))

        # Decoder: each stage takes (upsampled features + skip) channels.
        self.upsample0 = nn.Upsample(scale_factor=2)
        self.dec_conv0 = nn.Sequential(*_double_conv(c * 16, c * 8, c * 4, m))

        self.upsample1 = nn.Upsample(scale_factor=2)
        self.dec_conv1 = nn.Sequential(*_double_conv(c * 8, c * 4, c * 2, m))

        self.upsample2 = nn.Upsample(scale_factor=2)
        self.dec_conv2 = nn.Sequential(*_double_conv(c * 4, c * 2, c, m))

        self.upsample3 = nn.Upsample(scale_factor=2)
        # The last stage ends with a 1x1 convolution producing per-class logits.
        self.dec_conv3 = nn.Sequential(
            *_double_conv(c * 2, c, c, m),
            nn.Conv2d(in_channels=c, out_channels=num_classes, kernel_size=1),
        )

    def forward(self, x):
        """Return raw per-pixel logits of shape (N, num_classes, H, W)."""
        # Encoder; the pre-pooling activations e0, e2, e4, e6 are the skips.
        e0 = self.enc_conv0(x)
        e1 = self.pool0(e0)
        e2 = self.enc_conv1(e1)
        e3 = self.pool1(e2)
        e4 = self.enc_conv2(e3)
        e5 = self.pool2(e4)
        e6 = self.enc_conv3(e5)
        e7 = self.pool3(e6)

        b = self.bottleneck_enc(e7)

        # Decoder: upsample, concatenate the matching skip on the channel
        # axis, then convolve.
        d0 = self.upsample0(b)
        d0 = self.dec_conv0(torch.cat((d0, e6), dim=1))
        d1 = self.upsample1(d0)
        d1 = self.dec_conv1(torch.cat((d1, e4), dim=1))
        d2 = self.upsample2(d1)
        d2 = self.dec_conv2(torch.cat((d2, e2), dim=1))
        d3 = self.upsample3(d2)
        d3 = self.dec_conv3(torch.cat((d3, e0), dim=1))
        return d3

In [None]:
# Model initialisation: choose the device, move the network there and print
# a layer-by-layer summary for a 5-channel 256x256 input.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
model = unet_model()
model = model.to(DEVICE)
summary(model, (5, 256, 256))

In [None]:
# Training configuration
num_epochs = 50
LEARNING_RATE = 1e-4

# Cross-entropy over the per-pixel class logits produced by the model.
loss_fn = nn.CrossEntropyLoss()
optimizer = Adam(params=model.parameters(), lr=LEARNING_RATE)
# Gradient scaler used together with autocast in the training loop.
scaler = torch.cuda.amp.GradScaler()

In [None]:
# Model training
def _batch_dice_iou(preds, targets, num_classes):
    """Return (dice, iou) for one batch, macro-averaged over classes.

    For every class index the prediction and target are turned into boolean
    masks and pooled over the whole batch. Classes that occur in neither
    the prediction nor the target are skipped. Both values are Python floats
    in [0, 1].
    """
    dice_vals = []
    iou_vals = []
    for cls in range(num_classes):
        pred_mask = preds == cls
        true_mask = targets == cls
        inter = (pred_mask & true_mask).sum().item()
        total = pred_mask.sum().item() + true_mask.sum().item()
        if total == 0:
            continue  # class absent from both prediction and target
        dice_vals.append(2 * inter / total)
        iou_vals.append(inter / (total - inter))  # total - inter == |union|
    if not dice_vals:
        return 0.0, 0.0
    return sum(dice_vals) / len(dice_vals), sum(iou_vals) / len(iou_vals)


history = {"epochs": np.arange(num_epochs)+1, "score": [], "loss": []}
for epoch in range(num_epochs):
    model.train()  # make sure BatchNorm uses batch statistics while training
    dice_score = 0.0
    iou_score = 0.0
    loop = tqdm(enumerate(train_batch), total=len(train_batch))
    for batch_idx, (data, targets) in loop:
        data = data.to(DEVICE)
        targets = targets.to(DEVICE)
        targets = targets.type(torch.long)
        with torch.cuda.amp.autocast():
            predictions = model(data)
            loss = loss_fn(predictions, targets)

        # Metrics reuse the logits from the single forward pass above.
        # Running the model a second time doubled the work and updated the
        # BatchNorm running statistics twice per batch. Argmax of the logits
        # equals argmax of their softmax, so no softmax is needed.
        with torch.no_grad():
            preds = predictions.argmax(dim=1)
            batch_dice, batch_iou = _batch_dice_iou(preds, targets, predictions.shape[1])
        dice_score += batch_dice
        iou_score += batch_iou

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        loop.set_postfix(loss=loss.item())

    n_batches = len(train_batch)
    # The curve stored under "loss" is 1 - mean Dice, not the cross-entropy
    # being optimised. np.float64 keeps an .item() method, which the plotting
    # cell calls on every stored loss value.
    avg_loss = np.float64(1 - dice_score / n_batches)
    avg_score = iou_score / n_batches
    print('loss: %f' % avg_loss)
    print('score: %f' % avg_score)
    print('epoch: %d' % epoch)
    history["score"].append(avg_score)
    history["loss"].append(avg_loss)

In [None]:
# Save the model: compile it to TorchScript and write it to disk.
# 'PATH_to_MODEL' is a placeholder for the real output path.
model_scripted = torch.jit.script(model)
torch.jit.save(model_scripted, 'PATH_to_MODEL')

In [None]:
# Compute metrics on a data split
def check_accuracy(loader, model):
    """Evaluate pixel accuracy, Dice and IoU of ``model`` over ``loader``.

    Args:
        loader: iterable of (inputs, targets) batches. Inputs are fed to the
            model as-is; targets hold one integer class index per pixel.
        model: nn.Module returning logits of shape (N, num_classes, H, W).

    Returns:
        dict with float entries "accuracy" (percent of correctly classified
        pixels), "dice" and "iou". Dice and IoU are computed per class from
        boolean masks, averaged over the classes present in a batch, then
        averaged over batches. The same three numbers are also printed.

    The model is switched to eval mode for the evaluation and is put back
    into whatever mode it was in before the call.
    """
    def _dice_iou(preds, targets, num_classes):
        # Macro-average over classes; a class missing from both the
        # prediction and the target is skipped.
        dice_vals, iou_vals = [], []
        for cls in range(num_classes):
            pred_mask = preds == cls
            true_mask = targets == cls
            inter = (pred_mask & true_mask).sum().item()
            total = pred_mask.sum().item() + true_mask.sum().item()
            if total == 0:
                continue
            dice_vals.append(2 * inter / total)
            iou_vals.append(inter / (total - inter))  # total - inter == |union|
        if not dice_vals:
            return 0.0, 0.0
        return sum(dice_vals) / len(dice_vals), sum(iou_vals) / len(iou_vals)

    # Run on the device the model lives on (CPU if it has no parameters).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    num_correct = 0
    num_pixels = 0
    dice_total = 0.0
    iou_total = 0.0
    num_batches = 0

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device)
                y = y.to(device)
                logits = model(x)
                # argmax of the logits equals argmax of their softmax
                preds = logits.argmax(dim=1)
                num_correct += (preds == y).sum().item()
                num_pixels += preds.numel()
                batch_dice, batch_iou = _dice_iou(preds, y, logits.shape[1])
                dice_total += batch_dice
                iou_total += batch_iou
                num_batches += 1
    finally:
        # Restore the caller's mode instead of always forcing train mode.
        model.train(was_training)

    # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
    accuracy = 100.0 * num_correct / max(num_pixels, 1)
    dice = dice_total / max(num_batches, 1)
    iou = iou_total / max(num_batches, 1)

    print(f"Got {num_correct}/{num_pixels} with acc {accuracy:.2f}")
    print(f"Dice score: {dice}")
    print(f"IoU score: {iou}")
    return {"accuracy": accuracy, "dice": dice, "iou": iou}

# check_accuracy prints its own report, so its return value is stored
# instead of being passed to print() a second time.
train_metrics = check_accuracy(train_batch, model)
test_metrics = check_accuracy(test_batch, model)

In [None]:
# Plot the training curves
def make_graph(history, model_name, loss_name):
    """Draw the training loss and score curves side by side.

    Args:
        history: mapping with "epochs" (NumPy array of x values), "loss" and
            "score" (one value per epoch).
        model_name: text inserted into the figure title.
        loss_name: text inserted into the figure title.

    Shows the figure with plt.show(); returns nothing.
    """
    fig, axes = plt.subplots(1, 2, figsize = (14, 7))
    epochs = history["epochs"]
    # (series, line colour, y-axis label, panel title) for the two panels
    panels = (
        (history["loss"], "red", "loss", "Loss vs epoch"),
        (history["score"], "blue", "score", "Score vs epoch"),
    )
    for axis, (series, colour, y_label, title) in zip(axes, panels):
        axis.plot(epochs, series, label = "train", color = colour)
        axis.legend(fontsize = 14)
        axis.grid(linestyle = "--")
        axis.tick_params(labelsize = 14)
        axis.set_xlabel("epoch", fontsize = 14)
        axis.set_ylabel(y_label, fontsize = 14)
        axis.set_title(title, fontsize = 16)
        axis.set_xlim(left = 0, right = epochs.max())
        axis.set_ylim(bottom = 0)
    plt.suptitle(f"Model = {model_name}, loss = {loss_name}", fontsize = 18, y=1.05)
    plt.tight_layout()
    plt.show()

# Convert the stored losses to plain Python floats before plotting.
# float() accepts both 0-dim tensors and values that are already floats, so
# this cell can be re-run safely. Calling .item() on each entry failed on the
# second run because the entries had already been replaced by floats.
history['loss'] = [float(value) for value in history['loss']]
make_graph(history, "MODEL_NAME", "LOSS_NAME")