In [None]:
import numpy as np
from torchsummary import summary
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
from torch.optim import Adam
from tqdm import tqdm
import matplotlib.pyplot as plt

In [None]:
# load the preprocessed data and the matching masks
# (both paths are template placeholders to be replaced with real .npy locations)
DATA_PATH = 'PATH_to_DATA'
MASKS_PATH = 'PATH_to_MASKS'
data = np.load(DATA_PATH)
masks = np.load(MASKS_PATH)

In [None]:
# split the sample indices into training, validation and test subsets
# A fixed seed keeps the same samples in each subset on every
# Restart & Run All; without it the test set changes between runs.
SPLIT_SEED = 42
ix = np.random.default_rng(SPLIT_SEED).permutation(len(data))
# 'boundary samples' is a template placeholder: replace it with the two cut
# positions, e.g. [n_train, n_train + n_val], so that np.split returns three parts.
tr, val, ts = np.split(ix, ['boundary samples'])
print(len(tr), len(val), len(ts))

In [None]:
# build the data loaders for the three subsets
batch_size = 16


def _make_loader(indices, shuffle):
    """Pair every selected sample with its mask and wrap the pairs in a DataLoader."""
    pairs = list(zip(data[indices], masks[indices]))
    return DataLoader(pairs,
                      batch_size=batch_size,
                      shuffle=shuffle,
                      pin_memory=True)


# only the training loader reshuffles its samples
data_train = _make_loader(tr, shuffle=True)
data_val = _make_loader(val, shuffle=False)
data_test = _make_loader(ts, shuffle=False)

In [None]:
# choose the compute device: the GPU when CUDA is usable, the CPU otherwise
train_on_gpu = torch.cuda.is_available()
device = torch.device('cuda' if train_on_gpu else 'cpu')
print('CUDA is available!  Training on GPU ...' if train_on_gpu
      else 'CUDA is not available.  Training on CPU ...')

In [None]:
# model initialisation
# MODEL is a template placeholder for the network class to be trained.
model = MODEL().to(device)
# summary() prints the layer table itself, so wrapping it in print() only
# appended its return value to the output.
# assumes the network takes 3-channel 256x256 inputs -- TODO confirm
summary(model, (3, 256, 256))

In [None]:
# training hyper-parameters
LEARNING_RATE = 1e-4
num_epochs = 50

# objective, optimiser and the gradient scaler used for mixed-precision training
scaler = torch.cuda.amp.GradScaler()
loss_fn = nn.CrossEntropyLoss()
optimizer = Adam(params=model.parameters(), lr=LEARNING_RATE)

In [None]:
# train the model
# history["loss"] stores 1 - mean Dice per epoch (a Dice-based proxy, not the
# cross-entropy value that is optimised); history["score"] stores the mean IoU.
history = {"epochs": np.arange(num_epochs)+1, "score": [], "loss": []}
num_batches = len(data_train)
model.train()
for epoch in range(num_epochs):
    dice_score = 0
    iou_score = 0
    loop = tqdm(data_train, total=num_batches)
    # The batch is called `images`, not `data`: reusing `data` would overwrite
    # the dataset array loaded at the top of the notebook.
    for images, targets in loop:
        images = images.to(device)
        targets = targets.to(device).long()

        with torch.cuda.amp.autocast():
            predictions = model(images)
            loss = loss_fn(predictions, targets)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Metrics reuse the logits of the training forward pass instead of
        # running the model a second time, and are computed without autograd.
        # argmax over raw logits selects the same class as argmax over softmax.
        # The Dice/IoU expressions operate on the class indices directly, so
        # they are only meaningful for 0/1 masks.
        with torch.no_grad():
            preds = torch.argmax(predictions, dim=1)
            dice_score += (2 * (preds * targets).sum()) / ((preds + targets).sum() + 1e-8)
            iou_score += (((preds & targets).float().sum((1, 2)) + 1e-8) / ((preds | targets).float().sum((1, 2)) + 1e-8)).mean().item()

        loop.set_postfix(loss=loss.item())

    # Kept as a 0-dim tensor (moved to the CPU) because the plotting cell
    # converts every stored loss with .item().
    avg_loss = (1 - (dice_score / num_batches)).cpu()
    avg_score = iou_score / num_batches
    print('loss: %f' % avg_loss)
    print('score: %f' % avg_score)
    print('epoch: %d' % epoch)
    history["score"].append(avg_score)
    history["loss"].append(avg_loss)

In [None]:
# save the model as a TorchScript archive
# (the path is a template placeholder to be replaced with a real location)
MODEL_PATH = 'PATH_to_MODEL'
model_scripted = torch.jit.script(model)
model_scripted.save(MODEL_PATH)

In [None]:
# compute the metrics on a data split
def check_accuracy(loader, model, device=None):
    """Evaluate ``model`` on ``loader`` and report pixel accuracy, Dice and IoU.

    The Dice and IoU expressions multiply / bit-combine the class indices
    directly, so they are only meaningful for 0/1 masks.

    Args:
        loader: sized iterable of ``(inputs, targets)`` batches.
        model: network returning per-class logits with classes on dimension 1.
        device: device the batches are moved to; when ``None`` it is taken
            from the model's first parameter (CPU if the model has none).

    Returns:
        dict with ``"accuracy"`` (percent of matching pixels), ``"dice"`` and
        ``"iou"`` (both averaged over batches) as Python floats. The same
        values are printed.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    num_correct = 0
    num_pixels = 0
    dice_score = 0.0
    iou_score = 0.0

    # Remember the current mode so evaluation does not silently switch a
    # model that was in eval mode back to training mode.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device)
                # integer targets, as in the training loop; the bitwise IoU
                # below is not defined for floating-point masks
                y = y.to(device).long()
                # argmax over raw logits selects the same class as argmax over softmax
                preds = torch.argmax(model(x), dim=1)
                num_correct += (preds == y).sum().item()
                num_pixels += torch.numel(preds)
                dice_score += ((2 * (preds * y).sum()) / ((preds + y).sum() + 1e-8)).item()
                iou_score += (((preds & y).float().sum((1, 2)) + 1e-8) / ((preds | y).float().sum((1, 2)) + 1e-8)).mean().item()
    finally:
        model.train(was_training)

    num_batches = len(loader)
    metrics = {
        "accuracy": num_correct / num_pixels * 100,
        "dice": dice_score / num_batches,
        "iou": iou_score / num_batches,
    }
    print(f"Got {num_correct}/{num_pixels} with acc {metrics['accuracy']:.2f}")
    print(f"Dice score: {metrics['dice']}")
    print(f"IoU score: {metrics['iou']}")
    return metrics

# check_accuracy prints its own report, so its return value is not printed;
# a header line tells the three reports apart.
for split_name, split_loader in (("train", data_train), ("val", data_val), ("test", data_test)):
    print(f"--- {split_name} ---")
    check_accuracy(split_loader, model)

In [None]:
# plot the training curves
def make_graph(history, model_name, loss_name):
    """Show the per-epoch training loss and score in two side-by-side panels.

    Args:
        history: mapping with "epochs" (array of epoch numbers), "loss" and
            "score" (one value per epoch).
        model_name: model label used in the figure title.
        loss_name: loss label used in the figure title.
    """
    fig, axes = plt.subplots(1, 2, figsize=(14, 7))
    epochs = history["epochs"]
    # (values, line colour, y-axis label, panel title) for the left and right panel
    panels = (
        (history["loss"], "red", "loss", "Loss vs epoch"),
        (history["score"], "blue", "score", "Score vs epoch"),
    )
    for axis, (values, colour, y_label, title) in zip(axes, panels):
        axis.plot(epochs, values, label="train", color=colour)
        axis.legend(fontsize=14)
        axis.grid(linestyle="--")
        axis.tick_params(labelsize=14)
        axis.set_xlabel("epoch", fontsize=14)
        axis.set_ylabel(y_label, fontsize=14)
        axis.set_title(title, fontsize=16)
        axis.set_xlim(left=0, right=epochs.max())
        axis.set_ylim(bottom=0)
    plt.suptitle(f"Model = {model_name}, loss = {loss_name}", fontsize=18, y=1.05)
    plt.tight_layout()
    plt.show()

# Convert the stored losses to plain floats before plotting. float() accepts
# both 0-dim tensors and values that are already floats, so this cell can be
# re-run safely (calling .item() on an already converted float raised).
history['loss'] = [float(value) for value in history['loss']]
make_graph(history, "MODEL_NAME", "LOSS_NAME")

In [None]:
# visualise inputs, predictions and ground-truth masks for the first batch of data_val
# NOTE(review): the original comment spoke of test images while the loop reads
# data_val -- confirm which split is intended.
MAX_ROWS = 3  # at most this many samples are drawn, one per row

first_batch = next(iter(data_val), None)
if first_batch is not None:
    x, y = first_batch

    # Predict in eval mode and without autograd so that dropout is disabled and
    # batch-norm statistics are not updated by a plotting cell; the previous
    # mode is restored afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # argmax over raw logits selects the same class as argmax over softmax
            preds = torch.argmax(model(x.to(device)), dim=1).cpu()
    finally:
        model.train(was_training)

    # A batch may hold fewer than MAX_ROWS samples; squeeze=False keeps `ax`
    # two-dimensional even for a single row.
    n_rows = min(MAX_ROWS, len(x))
    fig, ax = plt.subplots(n_rows, 3, figsize=(14, 14), squeeze=False)
    for row in range(n_rows):
        panels = (
            # channels-first -> channels-last for imshow
            ('Image', np.transpose(np.array(x[row].cpu()), (1, 2, 0))),
            ('Prediction', np.array(preds[row])),
            ('Mask', np.array(y[row])),
        )
        for col, (title, picture) in enumerate(panels):
            ax[row, col].set_title(title)
            ax[row, col].axis("off")
            ax[row, col].imshow(picture)