Примечание: часть функционала все еще находится в процессе разработки и тестирования, поэтому может содержать ошибки

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.models as models
from torchvision.models.feature_extraction import create_feature_extractor
from torch.utils.data import Dataset, DataLoader
import torch.optim.lr_scheduler
import SimpleITK as sitk
import matplotlib.pyplot as plt
from IPython.display import clear_output
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import numpy as np
import os
import random
from PIL import Image
import cv2
import time
from lion_pytorch import Lion

  check_for_updates()


In [None]:
#Классы датасетов (старые; больше не используются из-за более низкой производительности: при том же количестве эпох они дают результат не хуже нового, но обучение занимает значительно больше времени и сильно зависит от скорости диска)
class TrainDataset(Dataset):
    """Legacy training dataset: one random slice per CT volume, read from disk on each access.

    Every ``__getitem__`` call re-reads both the CT volume and its mask with
    SimpleITK, picks a random index along the third (z) image axis, min-max
    normalizes the image slice to [0, 1] and applies the transform pipeline
    (resize to 128x128 + tensor conversion).

    Args:
        folder: directory scanned for ``.mhd`` CT volumes.
        gt_folder: directory holding, for each CT file name, the mask volume
            with the same file name. Defaults to ``'seg-lungs-LUNA16'``, the
            path that used to be hard-coded.
    """

    def __init__(self, folder, gt_folder='seg-lungs-LUNA16') -> None:
        self.images, self.masks = self._collect_pairs(folder, gt_folder)

        self.transform = A.Compose([
            A.Resize(128, 128),
            # Optional augmentations, currently disabled:
            #A.RandomResizedCrop(height=128, width=128, scale=(0.95, 1.0), ratio=(1.0, 1.0), p=1),
            #A.Rotate(limit=5, p=0.7),
            #A.RandomBrightnessContrast(brightness_limit=0.05, contrast_limit=0.05, p=0.5),
            #A.GaussNoise(var_limit=(5.0, 20.0), p=0.1),
            ToTensorV2()
        ])

    @staticmethod
    def _collect_pairs(folder, gt_folder):
        """Return parallel lists ``(ct_paths, mask_paths)`` for every ``.mhd`` file in *folder*.

        Each mask path is the same file name joined onto *gt_folder*; its
        existence is not checked here.
        """
        images, masks = [], []
        for fname in os.listdir(folder):
            # endswith rather than a substring test, so names like 'x.mhd.bak' are skipped
            if fname.endswith('.mhd'):
                images.append(os.path.join(folder, fname))
                masks.append(os.path.join(gt_folder, fname))
        return images, masks

    @staticmethod
    def _min_max_normalize(img):
        """Scale *img* linearly to [0, 1].

        A constant slice (max == min) is mapped to zeros instead of the NaNs
        that dividing by a zero range would produce.
        """
        lo = img.min()
        span = img.max() - lo
        if span == 0:
            return np.zeros_like(img)
        return (img - lo) / span

    def __getitem__(self, index):
        """Return the transformed ``(image, mask)`` pair for a random slice of scan *index*."""
        volume = sitk.ReadImage(self.images[index])
        slice_num = random.randint(0, volume.GetDepth() - 1)
        img = sitk.GetArrayFromImage(volume[:, :, slice_num]).astype(np.float32)
        # Same slice index in the mask volume; assumes CT and mask share geometry (TODO confirm).
        mask = sitk.GetArrayFromImage(sitk.ReadImage(self.masks[index])[:, :, slice_num])
        img = self._min_max_normalize(img)

        transformed = self.transform(image=img, mask=mask)
        return transformed['image'], transformed['mask']

    def __len__(self):
        """Number of CT volumes found (one sample per volume per pass)."""
        return len(self.images)


class ValDataset(Dataset):
    """Legacy validation dataset: one random slice per CT volume, no augmentation.

    Each ``__getitem__`` call re-reads the CT volume and its mask with
    SimpleITK, picks a random index along the third (z) image axis, min-max
    normalizes the image slice to [0, 1], resizes to 128x128 and converts to
    tensors.

    NOTE(review): the slice index is random, so repeated validation passes
    see different slices and the metric is noisy; confirm this is intended.

    Args:
        folder: directory scanned for ``.mhd`` CT volumes.
        gt_folder: directory holding, for each CT file name, the mask volume
            with the same file name. Defaults to ``'seg-lungs-LUNA16'``, the
            path that used to be hard-coded.
    """

    def __init__(self, folder, gt_folder='seg-lungs-LUNA16') -> None:
        self.images, self.masks = self._collect_pairs(folder, gt_folder)

        self.transform = A.Compose([
            A.Resize(128, 128),
            ToTensorV2()
        ])

    @staticmethod
    def _collect_pairs(folder, gt_folder):
        """Return parallel lists ``(ct_paths, mask_paths)`` for every ``.mhd`` file in *folder*.

        Each mask path is the same file name joined onto *gt_folder*; its
        existence is not checked here.
        """
        images, masks = [], []
        for fname in os.listdir(folder):
            # endswith rather than a substring test, so names like 'x.mhd.bak' are skipped
            if fname.endswith('.mhd'):
                images.append(os.path.join(folder, fname))
                masks.append(os.path.join(gt_folder, fname))
        return images, masks

    @staticmethod
    def _min_max_normalize(img):
        """Scale *img* linearly to [0, 1].

        A constant slice (max == min) is mapped to zeros instead of the NaNs
        that dividing by a zero range would produce.
        """
        lo = img.min()
        span = img.max() - lo
        if span == 0:
            return np.zeros_like(img)
        return (img - lo) / span

    def __getitem__(self, index):
        """Return the transformed ``(image, mask)`` pair for a random slice of scan *index*."""
        volume = sitk.ReadImage(self.images[index])
        slice_num = random.randint(0, volume.GetDepth() - 1)
        img = sitk.GetArrayFromImage(volume[:, :, slice_num]).astype(np.float32)
        # Same slice index in the mask volume; assumes CT and mask share geometry (TODO confirm).
        mask = sitk.GetArrayFromImage(sitk.ReadImage(self.masks[index])[:, :, slice_num])
        img = self._min_max_normalize(img)

        transformed = self.transform(image=img, mask=mask)
        return transformed['image'], transformed['mask']

    def __len__(self):
        """Number of CT volumes found (one sample per volume per pass)."""
        return len(self.images)

Было предпринято еще несколько попыток изменить класс датасета (их не привожу из-за большого объема); в итоге они привели к оптимальному и гибкому варианту:

In [3]:
#Класс датасета (оптимальный для работы с разным объемом оперативной памяти, используется в проекте)
class CustomDataset(Dataset):
    """Memory-bounded random-slice dataset (the variant used in the project).

    Only ``batch`` CT volumes and their masks are held in RAM at a time. Each
    ``__getitem__`` ignores *index*: it picks a random loaded volume and a
    random slice of it. After about ``max_uses_per_scan`` draws per loaded
    volume, the next chunk of volumes is read from disk.

    Args:
        folder: directory scanned for ``.mhd`` CT volumes.
        gt_folder: directory holding the mask volume with the same file name.
        size: value reported by ``__len__``, i.e. the number of samples per
            epoch; independent of the number of files.
        batch: number of CT volumes loaded into RAM simultaneously.
        max_uses_per_scan: average number of slices drawn from each loaded
            volume before the next chunk is loaded.
        transforms: callable invoked as ``transforms(image=..., mask=...)``
            that returns a dict with ``'image'`` and ``'mask'`` keys.

    Raises:
        ValueError: if *folder* contains no ``.mhd`` files.
    """

    def __init__(self, folder, gt_folder, size, batch, max_uses_per_scan, transforms) -> None:
        self.images = []
        self.masks = []
        self.batch = batch  # number of CT volumes loaded into RAM at once
        self.max_uses_per_scan = max_uses_per_scan  # average random slices per volume before reloading
        self.folder = folder
        self.gt_folder = gt_folder
        self.counter = 0
        self.size = size

        # Remap raw mask labels {0, 3, 4, 5} to contiguous class ids {0, 1, 2, 3}.
        self.mapping = {0: 0, 3: 1, 4: 2, 5: 3}
        self.lookup_table = self._create_lookup_table()

        for fname in os.listdir(folder):
            # endswith rather than a substring test, so names like 'x.mhd.bak' are skipped
            if fname.endswith('.mhd'):
                self.images.append(os.path.join(folder, fname))
                self.masks.append(os.path.join(gt_folder, fname))

        self.transform = transforms
        self.current_ct_index = 0
        # (ct_path, mask_path) pairs; the authoritative file list from here on.
        self.scans = list(zip(self.images, self.masks))
        if not self.scans:
            # Fail early and clearly instead of an IndexError on the first __getitem__.
            raise ValueError(f"no .mhd files found in {folder!r}")
        random.shuffle(self.scans)
        self.load()

    def _create_lookup_table(self):
        """Build a dense array mapping raw label -> class id.

        Labels not listed in ``self.mapping`` but within ``0..max(mapping)``
        map to 0. Larger labels cause an IndexError when the table is indexed.
        """
        max_label = max(self.mapping.keys())
        lookup = np.zeros(max_label + 1, dtype=np.int64)
        for original_label, new_label in self.mapping.items():
            lookup[original_label] = new_label
        return lookup

    @staticmethod
    def _min_max_normalize(img):
        """Scale *img* linearly to [0, 1]; a constant slice maps to zeros instead of NaN."""
        lo = img.min()
        span = img.max() - lo
        if span == 0:
            return np.zeros_like(img)
        return (img - lo) / span

    def load(self):
        """Read the next chunk of up to ``batch`` (CT, mask) volume pairs into memory.

        When fewer than ``batch`` scans remain in the current shuffled order,
        the order is reshuffled and reading restarts from the beginning.
        """
        if self.current_ct_index + self.batch > len(self.scans):
            self.current_ct_index = 0
            random.shuffle(self.scans)

        chunk = self.scans[self.current_ct_index : self.current_ct_index + self.batch]

        # NOTE: self.masks held mask *paths* after __init__; from here on it holds
        # the loaded mask images aligned with self.imgs (paths live in self.scans).
        self.imgs = [sitk.ReadImage(img_path) for img_path, _ in chunk]
        self.masks = [sitk.ReadImage(mask_path) for _, mask_path in chunk]

        self.current_ct_index += self.batch
        self.counter = 0

    def __getitem__(self, index):
        """Return a transformed ``(image, mask)`` random slice; *index* is ignored by design."""
        # Base the reload threshold and the sampling on the number of volumes actually
        # loaded: if batch exceeds the number of scans, fewer than `batch` are in memory
        # and randint(0, batch - 1) would index past the end of self.imgs.
        if self.counter >= self.max_uses_per_scan * len(self.imgs):
            self.load()

        scan_idx = random.randrange(len(self.imgs))
        volume = self.imgs[scan_idx]

        slice_num = random.randint(0, volume.GetDepth() - 1)

        img = sitk.GetArrayFromImage(volume[:, :, slice_num]).astype(np.float32)
        mask = sitk.GetArrayFromImage(self.masks[scan_idx][:, :, slice_num])

        img = self._min_max_normalize(img)
        mask = self.lookup_table[mask]

        self.counter += 1

        transformed = self.transform(image=img, mask=mask)
        return transformed['image'], transformed['mask']

    def __len__(self):
        """Samples per epoch, as configured by ``size``."""
        return self.size

In [5]:
#Небольшая нейросеть U-Net для тестирования работоспособности и использования в качестве бейзлайна
class EncoderBlock(nn.Module):
    """U-Net encoder stage that halves the spatial size and exposes a skip tensor.

    conv3x3 -> BN -> [skip tap] -> ReLU -> conv3x3 (stride 2) -> BN -> ReLU.

    ``forward`` returns ``(x, skip_connection)``:
    - ``x`` has ``out_channels`` channels at ceil(H/2) x ceil(W/2).
    - ``skip_connection`` is the full-resolution ``bn1`` output, taken *before*
      the ReLU so that negative responses reach the decoder.

    NOTE: the previous version used ``nn.ReLU(inplace=True)``, which overwrote
    the aliased skip tensor, so checkpoints trained with it saw post-ReLU skips.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        # Must NOT be in-place: skip_connection aliases the bn1 output, and an
        # in-place ReLU would silently clamp it too.
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=2)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        x = self.bn1(self.conv1(x))
        skip_connection = x  # pre-ReLU features, preserved because self.relu is out-of-place
        x = self.relu(x)
        x = self.relu(self.bn2(self.conv2(x)))
        return x, skip_connection

class DecoderBlock(nn.Module):
    """U-Net decoder stage: 2x transposed-conv upsampling, skip fusion, two conv layers.

    The upsampled tensor (``mid_channels``) is concatenated with the encoder
    skip along channels and fed to ``conv1``. ``conv1`` expects
    ``in_channels`` inputs, so ``mid_channels`` plus the skip's channel count
    must equal ``in_channels``. The skip must also match the upsampled
    spatial size (2x the input's).
    """

    def __init__(self, in_channels, mid_channels, out_channels):
        super().__init__()
        same_pad = dict(kernel_size=3, padding=1)
        self.upconv = nn.ConvTranspose2d(in_channels, mid_channels, kernel_size=2, stride=2)
        self.conv1 = nn.Conv2d(in_channels, out_channels, **same_pad)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, **same_pad)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x, skip_connection):
        fused = torch.cat([self.upconv(x), skip_connection], dim=1)
        hidden = self.relu(self.bn1(self.conv1(fused)))
        return self.relu(self.bn2(self.conv2(hidden)))

class MiniUnet(nn.Module):
    """Small two-level U-Net used as a smoke test and baseline.

    Resolution path for an H x W input (H and W divisible by 4):
    H -> H/2 (encoder1) -> H/4 (encoder2, bottleneck) -> H/2 (decoder1) -> H (decoder2).
    ``final_conv`` maps 64 channels to ``out_channels`` un-activated per-pixel scores.

    NOTE: ``EncoderBlock``/``DecoderBlock`` are looked up in the global namespace
    when the model is constructed. This requires the ``EncoderBlock`` variant
    whose ``forward`` returns ``(features, skip)``.
    """

    def __init__(self, in_channels=1, out_channels=4):
        super().__init__()

        # Encoder: each block halves H, W and returns a full-resolution skip.
        self.encoder1 = EncoderBlock(in_channels, 64)
        self.encoder2 = EncoderBlock(64, 128)

        self.bottleneck = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
        )

        # Decoder: upsample 2x and fuse with the matching skip (mid channels == skip channels).
        self.decoder1 = DecoderBlock(256, 128, 128)
        self.decoder2 = DecoderBlock(128, 64, 64)

        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        features, skip_hi = self.encoder1(x)
        features, skip_lo = self.encoder2(features)
        features = self.decoder1(self.bottleneck(features), skip_lo)
        features = self.decoder2(features, skip_hi)
        return self.final_conv(features)

In [None]:
#Классический U-Net
class EncoderBlock(nn.Module):
    """U-Net encoder stage that halves the spatial size and exposes a skip tensor.

    conv3x3 -> BN -> [skip tap] -> ReLU -> conv3x3 (stride 2) -> BN -> ReLU.

    ``forward`` returns ``(x, skip_connection)``:
    - ``x`` has ``out_channels`` channels at ceil(H/2) x ceil(W/2).
    - ``skip_connection`` is the full-resolution ``bn1`` output, taken *before*
      the ReLU so that negative responses reach the decoder.

    NOTE: the previous version used ``nn.ReLU(inplace=True)``, which overwrote
    the aliased skip tensor, so checkpoints trained with it saw post-ReLU skips.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        # Must NOT be in-place: skip_connection aliases the bn1 output, and an
        # in-place ReLU would silently clamp it too.
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=2)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        x = self.bn1(self.conv1(x))
        skip_connection = x  # pre-ReLU features, preserved because self.relu is out-of-place
        x = self.relu(x)
        x = self.relu(self.bn2(self.conv2(x)))
        return x, skip_connection

class DecoderBlock(nn.Module):
    """U-Net decoder stage: 2x transposed-conv upsampling, skip fusion, two conv layers.

    The upsampled tensor (``mid_channels``) is stacked with the encoder skip
    along the channel axis and fed to ``conv1``, which expects ``in_channels``
    inputs. So ``mid_channels`` plus the skip's channel count must equal
    ``in_channels``, and the skip must match the upsampled spatial size.
    """

    def __init__(self, in_channels, mid_channels, out_channels):
        super().__init__()
        same_pad = dict(kernel_size=3, padding=1)
        self.upconv = nn.ConvTranspose2d(in_channels, mid_channels, kernel_size=2, stride=2)
        self.conv1 = nn.Conv2d(in_channels, out_channels, **same_pad)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, **same_pad)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x, skip_connection):
        # stack the upsampled features with the skip along channels
        stacked = torch.cat([self.upconv(x), skip_connection], dim=1)
        hidden = self.relu(self.bn1(self.conv1(stacked)))
        return self.relu(self.bn2(self.conv2(hidden)))

class Unet(nn.Module):
    """Classic four-level U-Net.

    Four stride-2 encoder stages (64 -> 512 channels), a 1024-channel double-conv
    bottleneck, and four decoder stages that upsample and fuse the matching
    encoder skip (deepest skip first). Input H and W should be divisible by 16
    so that the upsampled tensors line up with the skips. ``final_conv``
    returns ``out_channels`` un-activated per-pixel scores.

    NOTE: ``EncoderBlock``/``DecoderBlock`` are resolved from the global
    namespace at construction time. This needs the ``EncoderBlock`` whose
    ``forward`` returns ``(features, skip)``.
    """

    def __init__(self, in_channels=1, out_channels=4):
        super().__init__()

        self.encoder1 = EncoderBlock(in_channels, 64)
        self.encoder2 = EncoderBlock(64, 128)
        self.encoder3 = EncoderBlock(128, 256)
        self.encoder4 = EncoderBlock(256, 512)

        self.bottleneck = nn.Sequential(
            nn.Conv2d(512, 1024, kernel_size=3, padding=1),
            nn.BatchNorm2d(1024),
            nn.ReLU(inplace=True),
            nn.Conv2d(1024, 1024, kernel_size=3, padding=1),
            nn.BatchNorm2d(1024),
            nn.ReLU(inplace=True),
        )

        self.decoder1 = DecoderBlock(1024, 512, 512)
        self.decoder2 = DecoderBlock(512, 256, 256)
        self.decoder3 = DecoderBlock(256, 128, 128)
        self.decoder4 = DecoderBlock(128, 64, 64)

        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        skips = []
        for encoder in (self.encoder1, self.encoder2, self.encoder3, self.encoder4):
            x, skip = encoder(x)
            skips.append(skip)

        x = self.bottleneck(x)

        # decoder1 consumes the deepest skip, decoder4 the shallowest
        decoders = (self.decoder1, self.decoder2, self.decoder3, self.decoder4)
        for decoder, skip in zip(decoders, reversed(skips)):
            x = decoder(x, skip)

        return self.final_conv(x)

In [None]:
#увеличенный U-Net (требует слишком много вычислительных ресурсов)
class MaxUnet(nn.Module):
    """Enlarged five-level U-Net (very compute- and memory-heavy).

    Five stride-2 encoder stages (64 -> 1024 channels), a 2048-channel
    triple-conv bottleneck, and five decoder stages that upsample and fuse the
    matching encoder skip (deepest skip first). Input H and W should be
    divisible by 32 so that the upsampled tensors line up with the skips.
    ``final_conv`` returns ``out_channels`` un-activated per-pixel scores.

    NOTE: ``EncoderBlock``/``DecoderBlock`` are resolved from the global
    namespace at construction time. This needs the ``EncoderBlock`` whose
    ``forward`` returns ``(features, skip)``.
    """

    def __init__(self, in_channels=1, out_channels=4):
        super().__init__()

        self.encoder1 = EncoderBlock(in_channels, 64)
        self.encoder2 = EncoderBlock(64, 128)
        self.encoder3 = EncoderBlock(128, 256)
        self.encoder4 = EncoderBlock(256, 512)
        self.encoder5 = EncoderBlock(512, 1024)

        self.bottleneck = nn.Sequential(
            nn.Conv2d(1024, 2048, kernel_size=3, padding=1),
            nn.BatchNorm2d(2048),
            nn.ReLU(inplace=True),
            nn.Conv2d(2048, 2048, kernel_size=3, padding=1),
            nn.BatchNorm2d(2048),
            nn.ReLU(inplace=True),
            nn.Conv2d(2048, 2048, kernel_size=3, padding=1),
            nn.BatchNorm2d(2048),
            nn.ReLU(inplace=True),
        )

        self.decoder1 = DecoderBlock(2048, 1024, 1024)
        self.decoder2 = DecoderBlock(1024, 512, 512)
        self.decoder3 = DecoderBlock(512, 256, 256)
        self.decoder4 = DecoderBlock(256, 128, 128)
        self.decoder5 = DecoderBlock(128, 64, 64)

        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        encoders = (self.encoder1, self.encoder2, self.encoder3, self.encoder4, self.encoder5)
        skips = []
        for encoder in encoders:
            x, skip = encoder(x)
            skips.append(skip)

        x = self.bottleneck(x)

        # decoder1 consumes the deepest skip, decoder5 the shallowest
        decoders = (self.decoder1, self.decoder2, self.decoder3, self.decoder4, self.decoder5)
        for decoder, skip in zip(decoders, reversed(skips)):
            x = decoder(x, skip)

        return self.final_conv(x)

In [None]:
#Архитектура U-Net++

class EncoderBlock(nn.Module):
    """Double-conv block: (conv3x3 -> BN -> ReLU) twice, spatial size preserved.

    Returns a single feature tensor with ``out_channels`` channels; it does not
    downsample, so pooling is left to the caller.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        same_pad = dict(kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(in_channels, out_channels, **same_pad)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, **same_pad)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        hidden = self.relu(self.bn1(self.conv1(x)))
        return self.relu(self.bn2(self.conv2(hidden)))

class NestedDecoderBlock(nn.Module):
    """U-Net++ nested node: concatenate inputs along channels, then conv3x3 -> BN -> ReLU.

    ``None`` entries in ``*inputs`` are dropped before concatenation; the
    remaining tensors must share batch and spatial sizes, and their total
    channel count must equal ``in_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, *inputs):
        present = tuple(t for t in inputs if t is not None)
        return self.relu(self.bn(self.conv(torch.cat(present, dim=1))))

class UnetPlusPlus(nn.Module):
    """U-Net++ (nested U-Net) with five resolution levels and dense skip pathways.

    Node ``xi_j`` lives at resolution level ``i`` and column ``j``:
    - Level 0 is full resolution; each level below is halved by ``self.pool``.
    - Column 0 is the encoder backbone; ``j >= 1`` are nested decoder nodes.
    Each nested node concatenates all earlier nodes of its level with the
    upsampled node ``x(i+1)_(j-1)``. Only the final full-resolution node
    ``x0_4`` feeds the 1x1 output conv (no deep supervision).

    Input H and W should be divisible by 16 (four 2x poolings) so that the
    upsampled tensors match the same-level tensors in ``torch.cat``.

    NOTE: ``EncoderBlock``/``NestedDecoderBlock`` are resolved from the global
    namespace at construction time. This expects the ``EncoderBlock`` variant
    that returns a single tensor and does not downsample.

    Args:
        in_channels: number of input image channels.
        out_channels: number of output maps produced by ``self.final``.
    """
    def __init__(self, in_channels=1, out_channels=4):
        super().__init__()

        # Channel width per resolution level (index 0 = full resolution).
        nb_filter = [32, 64, 128, 256, 512]

        # 2x downsampling between levels; 2x bilinear upsampling from level i+1 back to level i.
        self.pool = nn.MaxPool2d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Backbone column j = 0: conv{i}_0 runs at level i.
        self.conv0_0 = EncoderBlock(in_channels, nb_filter[0])
        self.conv1_0 = EncoderBlock(nb_filter[0], nb_filter[1])
        self.conv2_0 = EncoderBlock(nb_filter[1], nb_filter[2])
        self.conv3_0 = EncoderBlock(nb_filter[2], nb_filter[3])
        self.conv4_0 = EncoderBlock(nb_filter[3], nb_filter[4])

        # Nested nodes conv{i}_{j}. Input width = j * nb_filter[i] (all earlier
        # nodes of level i) + nb_filter[i+1] (the upsampled node from level i+1).
        self.conv0_1 = NestedDecoderBlock(nb_filter[0]+nb_filter[1], nb_filter[0])
        self.conv1_1 = NestedDecoderBlock(nb_filter[1]+nb_filter[2], nb_filter[1])
        self.conv2_1 = NestedDecoderBlock(nb_filter[2]+nb_filter[3], nb_filter[2])
        self.conv3_1 = NestedDecoderBlock(nb_filter[3]+nb_filter[4], nb_filter[3])

        self.conv0_2 = NestedDecoderBlock(nb_filter[0]*2+nb_filter[1], nb_filter[0])
        self.conv1_2 = NestedDecoderBlock(nb_filter[1]*2+nb_filter[2], nb_filter[1])
        self.conv2_2 = NestedDecoderBlock(nb_filter[2]*2+nb_filter[3], nb_filter[2])

        self.conv0_3 = NestedDecoderBlock(nb_filter[0]*3+nb_filter[1], nb_filter[0])
        self.conv1_3 = NestedDecoderBlock(nb_filter[1]*3+nb_filter[2], nb_filter[1])

        self.conv0_4 = NestedDecoderBlock(nb_filter[0]*4+nb_filter[1], nb_filter[0])

        # 1x1 projection of the full-resolution node x0_4 to out_channels maps.
        self.final = nn.Conv2d(nb_filter[0], out_channels, kernel_size=1)

    def forward(self, input):
        # Each group below adds one deeper backbone node, then the diagonal of
        # nested nodes it makes computable, ending at level 0.
        x0_0 = self.conv0_0(input)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1))

        x2_0 = self.conv2_0(self.pool(x1_0))
        x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], 1))
        x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1))

        x3_0 = self.conv3_0(self.pool(x2_0))
        x2_1 = self.conv2_1(torch.cat([x2_0, self.up(x3_0)], 1))
        x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1)], 1))
        x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], 1))

        x4_0 = self.conv4_0(self.pool(x3_0))
        x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up(x3_1)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2)], 1))
        x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], 1))

        # Un-activated per-pixel scores at input resolution.
        output = self.final(x0_4)
        return output

In [None]:
#Первая версия U-Net++ с механизмом внимания (механизм внимания задумывался только в конце сети, но применялся уже после вычисления выхода, поэтому на результат фактически не влияет)

class EncoderBlock(nn.Module):
    """Double-conv block: (conv3x3 -> BN -> ReLU) twice, spatial size preserved.

    Returns a single feature tensor with ``out_channels`` channels; it does not
    downsample, so pooling is left to the caller.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        same_pad = dict(kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(in_channels, out_channels, **same_pad)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, **same_pad)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        hidden = self.relu(self.bn1(self.conv1(x)))
        return self.relu(self.bn2(self.conv2(hidden)))

class NestedDecoderBlock(nn.Module):
    """U-Net++ nested node: concatenate inputs along channels, then conv3x3 -> BN -> ReLU.

    ``None`` entries in ``*inputs`` are dropped before concatenation; the
    remaining tensors must share batch and spatial sizes, and their total
    channel count must equal ``in_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, *inputs):
        present = tuple(t for t in inputs if t is not None)
        return self.relu(self.bn(self.conv(torch.cat(present, dim=1))))

class AttentionGate(nn.Module):
    """Additive attention gate that rescales features ``x`` using a gating signal ``g``.

    ``g`` (``F_g`` channels) and ``x`` (``F_l`` channels) are each projected
    to ``F_int`` channels with a 1x1 conv + BN. The projections are summed and
    passed through ReLU, then a 1x1 conv + BN + sigmoid yields one coefficient
    per pixel in (0, 1) that multiplies ``x``. ``g`` and ``x`` must have the
    same spatial size.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int),
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int),
        )
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        coefficients = self.psi(self.relu(self.W_g(g) + self.W_x(x)))
        return x * coefficients

class AttentionUnetPlusPlus1(nn.Module):
    """First attention U-Net++ attempt; functionally a plain U-Net++.

    Same nested encoder/decoder graph as U-Net++: node ``xi_j`` is at level
    ``i`` (halved per level by ``self.pool``) and column ``j``, and only the
    full-resolution node ``x0_4`` feeds ``self.final``.

    The four attention gates are still constructed, so parameter names and
    checkpoints stay compatible, but they are not applied. Gating
    intermediate nodes after ``x0_4`` has been computed cannot influence the
    output, so that dead computation was removed from ``forward``. It only
    cost compute and updated the gates' BatchNorm running statistics.

    Input H and W should be divisible by 16 (four 2x poolings).

    Args:
        in_channels: number of input image channels.
        out_channels: number of output maps produced by ``self.final``.
    """

    def __init__(self, in_channels=1, out_channels=4):
        super().__init__()

        # Channel width per resolution level (index 0 = full resolution).
        nb_filter = [32, 64, 128, 256, 512]

        self.pool = nn.MaxPool2d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Backbone column j = 0.
        self.conv0_0 = EncoderBlock(in_channels, nb_filter[0])
        self.conv1_0 = EncoderBlock(nb_filter[0], nb_filter[1])
        self.conv2_0 = EncoderBlock(nb_filter[1], nb_filter[2])
        self.conv3_0 = EncoderBlock(nb_filter[2], nb_filter[3])
        self.conv4_0 = EncoderBlock(nb_filter[3], nb_filter[4])

        # Nested nodes: input width = j * nb_filter[i] + nb_filter[i+1].
        self.conv0_1 = NestedDecoderBlock(nb_filter[0]+nb_filter[1], nb_filter[0])
        self.conv1_1 = NestedDecoderBlock(nb_filter[1]+nb_filter[2], nb_filter[1])
        self.conv2_1 = NestedDecoderBlock(nb_filter[2]+nb_filter[3], nb_filter[2])
        self.conv3_1 = NestedDecoderBlock(nb_filter[3]+nb_filter[4], nb_filter[3])

        self.conv0_2 = NestedDecoderBlock(nb_filter[0]*2+nb_filter[1], nb_filter[0])
        self.conv1_2 = NestedDecoderBlock(nb_filter[1]*2+nb_filter[2], nb_filter[1])
        self.conv2_2 = NestedDecoderBlock(nb_filter[2]*2+nb_filter[3], nb_filter[2])

        self.conv0_3 = NestedDecoderBlock(nb_filter[0]*3+nb_filter[1], nb_filter[0])
        self.conv1_3 = NestedDecoderBlock(nb_filter[1]*3+nb_filter[2], nb_filter[1])

        self.conv0_4 = NestedDecoderBlock(nb_filter[0]*4+nb_filter[1], nb_filter[0])

        # Attention gates: kept for state_dict compatibility; NOT used in forward().
        self.attention1 = AttentionGate(F_g=nb_filter[1], F_l=nb_filter[0], F_int=nb_filter[0]//2)
        self.attention2 = AttentionGate(F_g=nb_filter[2], F_l=nb_filter[1], F_int=nb_filter[1]//2)
        self.attention3 = AttentionGate(F_g=nb_filter[3], F_l=nb_filter[2], F_int=nb_filter[2]//2)
        self.attention4 = AttentionGate(F_g=nb_filter[4], F_l=nb_filter[3], F_int=nb_filter[3]//2)

        self.final = nn.Conv2d(nb_filter[0], out_channels, kernel_size=1)

    def forward(self, input):
        x0_0 = self.conv0_0(input)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1))

        x2_0 = self.conv2_0(self.pool(x1_0))
        x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], 1))
        x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1))

        x3_0 = self.conv3_0(self.pool(x2_0))
        x2_1 = self.conv2_1(torch.cat([x2_0, self.up(x3_0)], 1))
        x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1)], 1))
        x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], 1))

        x4_0 = self.conv4_0(self.pool(x3_0))
        x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up(x3_1)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2)], 1))
        x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], 1))

        # Un-activated per-pixel scores at input resolution.
        return self.final(x0_4)

In [None]:
#Вторая версия U-Net++ с механизмом внимания (внедрение механизма внимания в skip-connections)

class EncoderBlock(nn.Module):
    """Double-conv block: (conv3x3 -> BN -> ReLU) twice, spatial size preserved.

    Returns a single feature tensor with ``out_channels`` channels; it does not
    downsample, so pooling is left to the caller.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        same_pad = dict(kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(in_channels, out_channels, **same_pad)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, **same_pad)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        hidden = self.relu(self.bn1(self.conv1(x)))
        return self.relu(self.bn2(self.conv2(hidden)))

class AttentionBlock(nn.Module):
    """Additive attention gate that rescales skip features ``x`` using a gating signal ``g``.

    ``g`` (``F_g`` channels) and ``x`` (``F_l`` channels) are each projected
    to ``F_int`` channels with a 1x1 conv + BN. The projections are summed and
    passed through ReLU, then a 1x1 conv + BN + sigmoid yields one coefficient
    per pixel in (0, 1) that multiplies ``x``. ``g`` and ``x`` must have the
    same spatial size.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, padding=0),
            nn.BatchNorm2d(F_int),
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, padding=0),
            nn.BatchNorm2d(F_int),
        )
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, padding=0),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        coefficients = self.psi(self.relu(self.W_g(g) + self.W_x(x)))
        return x * coefficients

class NestedDecoderBlock(nn.Module):
    """U-Net++ nested node: concatenate inputs along channels, then conv3x3 -> BN -> ReLU.

    ``None`` entries in ``*inputs`` are dropped before concatenation; the
    remaining tensors must share batch and spatial sizes, and their total
    channel count must equal ``in_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, *inputs):
        present = tuple(t for t in inputs if t is not None)
        return self.relu(self.bn(self.conv(torch.cat(present, dim=1))))

class AttentionUnetPlusPlus2(nn.Module):
    """U-Net++ with additive attention gates on the same-level skip inputs (version 2).

    Node ``xi_j`` is at level ``i`` (halved per level by ``self.pool``) and
    column ``j``. Before a same-level node ``xi_j`` is concatenated into a
    nested node, it is gated by ``attention{i+1}``, using the upsampled node
    ``x(i+1)_j`` as the gating signal. One gate module per level is shared by
    all columns of that level. Only ``x0_4`` feeds ``self.final``.

    Input H and W should be divisible by 16 (four 2x poolings).

    NOTE(review): the default ``out_channels=8`` differs from the other models
    in this notebook (default 4) and from the 4 classes produced by
    ``CustomDataset``'s label mapping. Confirm this default is intended.

    Args:
        in_channels: number of input image channels.
        out_channels: number of output maps produced by ``self.final``.
    """
    def __init__(self, in_channels=1, out_channels=8):
        super().__init__()

        # Channel width per resolution level (index 0 = full resolution).
        nb_filter = [32, 64, 128, 256, 512]

        self.pool = nn.MaxPool2d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Backbone column j = 0.
        self.conv0_0 = EncoderBlock(in_channels, nb_filter[0])
        self.conv1_0 = EncoderBlock(nb_filter[0], nb_filter[1])
        self.conv2_0 = EncoderBlock(nb_filter[1], nb_filter[2])
        self.conv3_0 = EncoderBlock(nb_filter[2], nb_filter[3])
        self.conv4_0 = EncoderBlock(nb_filter[3], nb_filter[4])

        # attention{i+1} gates level-i features (F_l = nb_filter[i]) with the
        # upsampled level-(i+1) node (F_g = nb_filter[i+1]).
        self.attention1 = AttentionBlock(F_g=nb_filter[1], F_l=nb_filter[0], F_int=nb_filter[0]//2)
        self.attention2 = AttentionBlock(F_g=nb_filter[2], F_l=nb_filter[1], F_int=nb_filter[1]//2)
        self.attention3 = AttentionBlock(F_g=nb_filter[3], F_l=nb_filter[2], F_int=nb_filter[2]//2)
        self.attention4 = AttentionBlock(F_g=nb_filter[4], F_l=nb_filter[3], F_int=nb_filter[3]//2)

        # Nested nodes: input width = j * nb_filter[i] (gated same-level nodes) + nb_filter[i+1].
        self.conv0_1 = NestedDecoderBlock(nb_filter[0]+nb_filter[1], nb_filter[0])
        self.conv1_1 = NestedDecoderBlock(nb_filter[1]+nb_filter[2], nb_filter[1])
        self.conv2_1 = NestedDecoderBlock(nb_filter[2]+nb_filter[3], nb_filter[2])
        self.conv3_1 = NestedDecoderBlock(nb_filter[3]+nb_filter[4], nb_filter[3])

        self.conv0_2 = NestedDecoderBlock(nb_filter[0]*2+nb_filter[1], nb_filter[0])
        self.conv1_2 = NestedDecoderBlock(nb_filter[1]*2+nb_filter[2], nb_filter[1])
        self.conv2_2 = NestedDecoderBlock(nb_filter[2]*2+nb_filter[3], nb_filter[2])

        self.conv0_3 = NestedDecoderBlock(nb_filter[0]*3+nb_filter[1], nb_filter[0])
        self.conv1_3 = NestedDecoderBlock(nb_filter[1]*3+nb_filter[2], nb_filter[1])

        self.conv0_4 = NestedDecoderBlock(nb_filter[0]*4+nb_filter[1], nb_filter[0])

        self.final = nn.Conv2d(nb_filter[0], out_channels, kernel_size=1)

    def forward(self, input):
        x0_0 = self.conv0_0(input)
        x1_0 = self.conv1_0(self.pool(x0_0))

        # Attention and Up-sampling: each *_att tensor is computed once and
        # reused by every later nested node of the same level.
        x0_0_att = self.attention1(g=self.up(x1_0), x=x0_0)
        x0_1 = self.conv0_1(torch.cat([x0_0_att, self.up(x1_0)], 1))

        x2_0 = self.conv2_0(self.pool(x1_0))
        x1_0_att = self.attention2(g=self.up(x2_0), x=x1_0)
        x1_1 = self.conv1_1(torch.cat([x1_0_att, self.up(x2_0)], 1))

        x0_1_att = self.attention1(g=self.up(x1_1), x=x0_1)
        x0_2 = self.conv0_2(torch.cat([x0_0_att, x0_1_att, self.up(x1_1)], 1))

        x3_0 = self.conv3_0(self.pool(x2_0))
        x2_0_att = self.attention3(g=self.up(x3_0), x=x2_0)
        x2_1 = self.conv2_1(torch.cat([x2_0_att, self.up(x3_0)], 1))

        x1_1_att = self.attention2(g=self.up(x2_1), x=x1_1)
        x1_2 = self.conv1_2(torch.cat([x1_0_att, x1_1_att, self.up(x2_1)], 1))

        x0_2_att = self.attention1(g=self.up(x1_2), x=x0_2)
        x0_3 = self.conv0_3(torch.cat([x0_0_att, x0_1_att, x0_2_att, self.up(x1_2)], 1))

        x4_0 = self.conv4_0(self.pool(x3_0))
        x3_0_att = self.attention4(g=self.up(x4_0), x=x3_0)
        x3_1 = self.conv3_1(torch.cat([x3_0_att, self.up(x4_0)], 1))

        x2_1_att = self.attention3(g=self.up(x3_1), x=x2_1)
        x2_2 = self.conv2_2(torch.cat([x2_0_att, x2_1_att, self.up(x3_1)], 1))

        x1_2_att = self.attention2(g=self.up(x2_2), x=x1_2)
        x1_3 = self.conv1_3(torch.cat([x1_0_att, x1_1_att, x1_2_att, self.up(x2_2)], 1))

        x0_3_att = self.attention1(g=self.up(x1_3), x=x0_3)
        x0_4 = self.conv0_4(torch.cat([x0_0_att, x0_1_att, x0_2_att, x0_3_att, self.up(x1_3)], 1))

        # Un-activated per-pixel scores at input resolution.
        output = self.final(x0_4)
        return output

In [None]:
#Третья версия U-Net++ с механизмом внимания (введение residual connections)

class EncoderBlock(nn.Module):
    """Residual double-conv block: ReLU(conv-BN-ReLU-conv-BN(x) + shortcut(x)).

    The shortcut is a 1x1 conv projection when ``in_channels != out_channels``
    and the identity otherwise. Spatial size is preserved.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        same_pad = dict(kernel_size=3, padding=1)
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, **same_pad),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, **same_pad),
            nn.BatchNorm2d(out_channels),
        )
        self.shortcut = (
            nn.Identity() if in_channels == out_channels
            else nn.Conv2d(in_channels, out_channels, kernel_size=1)
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        identity_path = self.shortcut(x)
        return self.relu(self.conv(x) + identity_path)

class AttentionBlock(nn.Module):
    """Additive attention gate that rescales skip features ``x`` using a gating signal ``g``.

    ``g`` (``F_g`` channels) and ``x`` (``F_l`` channels) are each projected
    to ``F_int`` channels with a 1x1 conv + BN. The projections are summed and
    passed through ReLU, then a 1x1 conv + BN + sigmoid yields one coefficient
    per pixel in (0, 1) that multiplies ``x``. ``g`` and ``x`` must have the
    same spatial size.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()
        self.W_g = nn.Sequential(nn.Conv2d(F_g, F_int, kernel_size=1), nn.BatchNorm2d(F_int))
        self.W_x = nn.Sequential(nn.Conv2d(F_l, F_int, kernel_size=1), nn.BatchNorm2d(F_int))
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        coefficients = self.psi(self.relu(self.W_g(g) + self.W_x(x)))
        return x * coefficients

class NestedDecoderBlock(nn.Module):
    """Residual U-Net++ nested node.

    ``None`` entries in ``*inputs`` are dropped and the rest concatenated
    along channels (total must equal ``in_channels``). The result is
    ReLU(conv_path(x) + shortcut(x)), where:
    - ``conv_path`` is conv3x3-BN-ReLU-conv3x3-BN;
    - ``shortcut`` is always a 1x1 conv projection, even if
      ``in_channels == out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        same_pad = dict(kernel_size=3, padding=1)
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, **same_pad),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, **same_pad),
            nn.BatchNorm2d(out_channels),
        )
        self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, *inputs):
        merged = torch.cat([t for t in inputs if t is not None], dim=1)
        projected = self.shortcut(merged)
        return self.relu(self.conv(merged) + projected)

class AttentionUnetPlusPlus3(nn.Module):
    """Attention U-Net++: nested U-Net decoder with additive attention gates.

    Five encoder levels (32..512 filters) feed a dense nested decoder. Every
    skip tensor is gated by an ``AttentionBlock`` before concatenation. The
    gating signal is the upsampled feature from the level below. The attention
    gate of each level is shared across all nested columns of that level.

    Upsampling resizes each deeper feature to the exact spatial size of the
    skip it is combined with. For spatial sides that are multiples of 16 this
    is identical to the previous fixed x2 upsample (``align_corners=True``
    derives the scale from the sizes). Other sizes, where max-pooling floors
    odd dimensions, now also work instead of failing in ``torch.cat``.

    Args:
        in_channels: number of input image channels.
        out_channels: number of output channels of the final 1x1 conv (logits).
    """

    def __init__(self, in_channels=1, out_channels=8):
        super().__init__()

        nb_filter = [32, 64, 128, 256, 512]

        self.pool = nn.MaxPool2d(2, 2)
        # Parameter-free; kept so existing references to ``model.up`` still work.
        # forward() uses the size-matched ``_up_to`` instead.
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        self.conv0_0 = EncoderBlock(in_channels, nb_filter[0])
        self.conv1_0 = EncoderBlock(nb_filter[0], nb_filter[1])
        self.conv2_0 = EncoderBlock(nb_filter[1], nb_filter[2])
        self.conv3_0 = EncoderBlock(nb_filter[2], nb_filter[3])
        self.conv4_0 = EncoderBlock(nb_filter[3], nb_filter[4])

        # Attention gate per level: gated by the level below (F_g), gating skips of this level (F_l).
        self.attention1 = AttentionBlock(F_g=nb_filter[1], F_l=nb_filter[0], F_int=nb_filter[0]//2)
        self.attention2 = AttentionBlock(F_g=nb_filter[2], F_l=nb_filter[1], F_int=nb_filter[1]//2)
        self.attention3 = AttentionBlock(F_g=nb_filter[3], F_l=nb_filter[2], F_int=nb_filter[2]//2)
        self.attention4 = AttentionBlock(F_g=nb_filter[4], F_l=nb_filter[3], F_int=nb_filter[3]//2)

        # Node x{i}_{j} concatenates j same-level tensors with one upsampled deeper tensor.
        self.conv0_1 = NestedDecoderBlock(nb_filter[0]+nb_filter[1], nb_filter[0])
        self.conv1_1 = NestedDecoderBlock(nb_filter[1]+nb_filter[2], nb_filter[1])
        self.conv2_1 = NestedDecoderBlock(nb_filter[2]+nb_filter[3], nb_filter[2])
        self.conv3_1 = NestedDecoderBlock(nb_filter[3]+nb_filter[4], nb_filter[3])

        self.conv0_2 = NestedDecoderBlock(nb_filter[0]*2+nb_filter[1], nb_filter[0])
        self.conv1_2 = NestedDecoderBlock(nb_filter[1]*2+nb_filter[2], nb_filter[1])
        self.conv2_2 = NestedDecoderBlock(nb_filter[2]*2+nb_filter[3], nb_filter[2])

        self.conv0_3 = NestedDecoderBlock(nb_filter[0]*3+nb_filter[1], nb_filter[0])
        self.conv1_3 = NestedDecoderBlock(nb_filter[1]*3+nb_filter[2], nb_filter[1])

        self.conv0_4 = NestedDecoderBlock(nb_filter[0]*4+nb_filter[1], nb_filter[0])

        self.final = nn.Conv2d(nb_filter[0], out_channels, kernel_size=1)

    @staticmethod
    def _up_to(x, ref):
        """Bilinearly resize ``x`` to the spatial size of ``ref`` (align_corners=True).

        When ``ref`` is exactly twice the size of ``x`` this equals a x2
        ``nn.Upsample(..., align_corners=True)``.
        """
        return F.interpolate(x, size=ref.shape[-2:], mode='bilinear', align_corners=True)

    def forward(self, input):
        """Return per-pixel logits with ``out_channels`` channels at the input resolution."""
        x0_0 = self.conv0_0(input)
        x1_0 = self.conv1_0(self.pool(x0_0))

        # Each upsampled tensor is computed once and reused as both gate and concat input.
        up1_0 = self._up_to(x1_0, x0_0)
        x0_0_att = self.attention1(g=up1_0, x=x0_0)
        x0_1 = self.conv0_1(torch.cat([x0_0_att, up1_0], dim=1))

        x2_0 = self.conv2_0(self.pool(x1_0))
        up2_0 = self._up_to(x2_0, x1_0)
        x1_0_att = self.attention2(g=up2_0, x=x1_0)
        x1_1 = self.conv1_1(torch.cat([x1_0_att, up2_0], dim=1))

        up1_1 = self._up_to(x1_1, x0_1)
        x0_1_att = self.attention1(g=up1_1, x=x0_1)
        x0_2 = self.conv0_2(torch.cat([x0_0_att, x0_1_att, up1_1], dim=1))

        x3_0 = self.conv3_0(self.pool(x2_0))
        up3_0 = self._up_to(x3_0, x2_0)
        x2_0_att = self.attention3(g=up3_0, x=x2_0)
        x2_1 = self.conv2_1(torch.cat([x2_0_att, up3_0], dim=1))

        up2_1 = self._up_to(x2_1, x1_1)
        x1_1_att = self.attention2(g=up2_1, x=x1_1)
        x1_2 = self.conv1_2(torch.cat([x1_0_att, x1_1_att, up2_1], dim=1))

        up1_2 = self._up_to(x1_2, x0_2)
        x0_2_att = self.attention1(g=up1_2, x=x0_2)
        x0_3 = self.conv0_3(torch.cat([x0_0_att, x0_1_att, x0_2_att, up1_2], dim=1))

        x4_0 = self.conv4_0(self.pool(x3_0))
        up4_0 = self._up_to(x4_0, x3_0)
        x3_0_att = self.attention4(g=up4_0, x=x3_0)
        x3_1 = self.conv3_1(torch.cat([x3_0_att, up4_0], dim=1))

        up3_1 = self._up_to(x3_1, x2_1)
        x2_1_att = self.attention3(g=up3_1, x=x2_1)
        x2_2 = self.conv2_2(torch.cat([x2_0_att, x2_1_att, up3_1], dim=1))

        up2_2 = self._up_to(x2_2, x1_2)
        x1_2_att = self.attention2(g=up2_2, x=x1_2)
        x1_3 = self.conv1_3(torch.cat([x1_0_att, x1_1_att, x1_2_att, up2_2], dim=1))

        up1_3 = self._up_to(x1_3, x0_3)
        x0_3_att = self.attention1(g=up1_3, x=x0_3)
        x0_4 = self.conv0_4(torch.cat([x0_0_att, x0_1_att, x0_2_att, x0_3_att, up1_3], dim=1))

        output = self.final(x0_4)
        return output

In [None]:
#Четвертая версия Attention U-Net++ с механизмом внимания (увеличение глубины нейронной сети)

class EncoderBlock(nn.Module):
    """Residual encoder block with three 3x3 convolutions.

    The main branch is (conv-BN-ReLU) x2 followed by conv-BN. The shortcut is
    a 1x1 conv when the channel counts differ and an identity otherwise. The
    two branches are summed and passed through ReLU.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        modules = []
        width = in_channels
        for _ in range(2):
            modules += [
                nn.Conv2d(width, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
            width = out_channels
        modules += [
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
        ]
        self.conv = nn.Sequential(*modules)
        self.shortcut = (
            nn.Conv2d(in_channels, out_channels, kernel_size=1)
            if in_channels != out_channels
            else nn.Identity()
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ReLU(conv_branch(x) + shortcut(x))."""
        identity = self.shortcut(x)
        return self.relu(self.conv(x) + identity)

class AttentionBlock(nn.Module):
    """Additive attention gate: rescale skip features ``x`` using gating signal ``g``.

    Args:
        F_g: channels of ``g``.
        F_l: channels of ``x``.
        F_int: channels of the intermediate 1x1 projections.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()
        self.W_g = nn.Sequential(nn.Conv2d(F_g, F_int, kernel_size=1), nn.BatchNorm2d(F_int))
        self.W_x = nn.Sequential(nn.Conv2d(F_l, F_int, kernel_size=1), nn.BatchNorm2d(F_int))
        self.psi = nn.Sequential(nn.Conv2d(F_int, 1, kernel_size=1), nn.BatchNorm2d(1), nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        """Return ``x`` scaled by a single-channel sigmoid map computed from ``g`` and ``x``."""
        combined = self.W_g(g)
        combined = combined + self.W_x(x)
        attention_map = self.psi(self.relu(combined))
        return x * attention_map

class NestedDecoderBlock(nn.Module):
    """Deeper residual U-Net++ decoder node: concat inputs, three 3x3 convs, 1x1 shortcut.

    Args:
        in_channels: total channel count of the concatenated inputs.
        out_channels: channel count of the output.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stack = []
        for idx, c_in in enumerate((in_channels, out_channels, out_channels)):
            stack.append(nn.Conv2d(c_in, out_channels, kernel_size=3, padding=1))
            stack.append(nn.BatchNorm2d(out_channels))
            # No ReLU after the last BN: it is applied after the residual sum.
            if idx < 2:
                stack.append(nn.ReLU(inplace=True))
        self.conv = nn.Sequential(*stack)
        self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, *inputs):
        """Concatenate the non-None inputs along channels and apply the residual stack."""
        merged = torch.cat([t for t in inputs if t is not None], dim=1)
        skip = self.shortcut(merged)
        body = self.conv(merged)
        return self.relu(body + skip)

class AttentionUnetPlusPlus4(nn.Module):
    """Attention U-Net++ (deeper blocks variant) with additive attention gates.

    Five encoder levels (32..512 filters) feed a dense nested decoder. Every
    skip tensor is gated by an ``AttentionBlock`` whose gating signal is the
    upsampled feature from the level below. Each level's gate is shared across
    all nested columns of that level.

    Upsampling resizes each deeper feature to the exact spatial size of the
    skip it is combined with. For sides that are multiples of 16 this is
    identical to a fixed x2 upsample (``align_corners=True`` derives the scale
    from the sizes). Other sizes, where max-pooling floors odd dimensions,
    also work instead of failing in ``torch.cat``.

    Args:
        in_channels: number of input image channels.
        out_channels: number of output channels of the final 1x1 conv (logits).
    """

    def __init__(self, in_channels=1, out_channels=8):
        super().__init__()

        nb_filter = [32, 64, 128, 256, 512]

        self.pool = nn.MaxPool2d(2, 2)
        # Parameter-free; kept so existing references to ``model.up`` still work.
        # forward() uses the size-matched ``_up_to`` instead.
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        self.conv0_0 = EncoderBlock(in_channels, nb_filter[0])
        self.conv1_0 = EncoderBlock(nb_filter[0], nb_filter[1])
        self.conv2_0 = EncoderBlock(nb_filter[1], nb_filter[2])
        self.conv3_0 = EncoderBlock(nb_filter[2], nb_filter[3])
        self.conv4_0 = EncoderBlock(nb_filter[3], nb_filter[4])

        # Attention blocks: one per level, gated by the level below.
        self.attention1 = AttentionBlock(F_g=nb_filter[1], F_l=nb_filter[0], F_int=nb_filter[0]//2)
        self.attention2 = AttentionBlock(F_g=nb_filter[2], F_l=nb_filter[1], F_int=nb_filter[1]//2)
        self.attention3 = AttentionBlock(F_g=nb_filter[3], F_l=nb_filter[2], F_int=nb_filter[2]//2)
        self.attention4 = AttentionBlock(F_g=nb_filter[4], F_l=nb_filter[3], F_int=nb_filter[3]//2)

        # Node x{i}_{j} concatenates j same-level tensors with one upsampled deeper tensor.
        self.conv0_1 = NestedDecoderBlock(nb_filter[0]+nb_filter[1], nb_filter[0])
        self.conv1_1 = NestedDecoderBlock(nb_filter[1]+nb_filter[2], nb_filter[1])
        self.conv2_1 = NestedDecoderBlock(nb_filter[2]+nb_filter[3], nb_filter[2])
        self.conv3_1 = NestedDecoderBlock(nb_filter[3]+nb_filter[4], nb_filter[3])

        self.conv0_2 = NestedDecoderBlock(nb_filter[0]*2+nb_filter[1], nb_filter[0])
        self.conv1_2 = NestedDecoderBlock(nb_filter[1]*2+nb_filter[2], nb_filter[1])
        self.conv2_2 = NestedDecoderBlock(nb_filter[2]*2+nb_filter[3], nb_filter[2])

        self.conv0_3 = NestedDecoderBlock(nb_filter[0]*3+nb_filter[1], nb_filter[0])
        self.conv1_3 = NestedDecoderBlock(nb_filter[1]*3+nb_filter[2], nb_filter[1])

        self.conv0_4 = NestedDecoderBlock(nb_filter[0]*4+nb_filter[1], nb_filter[0])

        self.final = nn.Conv2d(nb_filter[0], out_channels, kernel_size=1)

    @staticmethod
    def _up_to(x, ref):
        """Bilinearly resize ``x`` to the spatial size of ``ref`` (align_corners=True).

        When ``ref`` is exactly twice the size of ``x`` this equals a x2
        ``nn.Upsample(..., align_corners=True)``.
        """
        return F.interpolate(x, size=ref.shape[-2:], mode='bilinear', align_corners=True)

    def forward(self, input):
        """Return per-pixel logits with ``out_channels`` channels at the input resolution."""
        # Encoder stage
        x0_0 = self.conv0_0(input)
        x1_0 = self.conv1_0(self.pool(x0_0))

        # Attention and nested decoder; each upsampled tensor is computed once
        # and reused as both the gating signal and the concat input.
        up1_0 = self._up_to(x1_0, x0_0)
        x0_0_att = self.attention1(g=up1_0, x=x0_0)
        x0_1 = self.conv0_1(torch.cat([x0_0_att, up1_0], dim=1))

        x2_0 = self.conv2_0(self.pool(x1_0))
        up2_0 = self._up_to(x2_0, x1_0)
        x1_0_att = self.attention2(g=up2_0, x=x1_0)
        x1_1 = self.conv1_1(torch.cat([x1_0_att, up2_0], dim=1))

        up1_1 = self._up_to(x1_1, x0_1)
        x0_1_att = self.attention1(g=up1_1, x=x0_1)
        x0_2 = self.conv0_2(torch.cat([x0_0_att, x0_1_att, up1_1], dim=1))

        x3_0 = self.conv3_0(self.pool(x2_0))
        up3_0 = self._up_to(x3_0, x2_0)
        x2_0_att = self.attention3(g=up3_0, x=x2_0)
        x2_1 = self.conv2_1(torch.cat([x2_0_att, up3_0], dim=1))

        up2_1 = self._up_to(x2_1, x1_1)
        x1_1_att = self.attention2(g=up2_1, x=x1_1)
        x1_2 = self.conv1_2(torch.cat([x1_0_att, x1_1_att, up2_1], dim=1))

        up1_2 = self._up_to(x1_2, x0_2)
        x0_2_att = self.attention1(g=up1_2, x=x0_2)
        x0_3 = self.conv0_3(torch.cat([x0_0_att, x0_1_att, x0_2_att, up1_2], dim=1))

        x4_0 = self.conv4_0(self.pool(x3_0))
        up4_0 = self._up_to(x4_0, x3_0)
        x3_0_att = self.attention4(g=up4_0, x=x3_0)
        x3_1 = self.conv3_1(torch.cat([x3_0_att, up4_0], dim=1))

        up3_1 = self._up_to(x3_1, x2_1)
        x2_1_att = self.attention3(g=up3_1, x=x2_1)
        x2_2 = self.conv2_2(torch.cat([x2_0_att, x2_1_att, up3_1], dim=1))

        up2_2 = self._up_to(x2_2, x1_2)
        x1_2_att = self.attention2(g=up2_2, x=x1_2)
        x1_3 = self.conv1_3(torch.cat([x1_0_att, x1_1_att, x1_2_att, up2_2], dim=1))

        up1_3 = self._up_to(x1_3, x0_3)
        x0_3_att = self.attention1(g=up1_3, x=x0_3)
        x0_4 = self.conv0_4(torch.cat([x0_0_att, x0_1_att, x0_2_att, x0_3_att, up1_3], dim=1))

        output = self.final(x0_4)
        return output

In [None]:
#DeepLabV3+ с экстрактором признаков MobileNetV2 (модифицированной для одноканальных изображений)

class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling head (DeepLabV3).

    Five parallel branches run over the same input:
    - a 1x1 conv;
    - three 3x3 dilated convs, with dilations taken from ``rates``;
    - an image-level global-average-pool branch, upsampled back to the input size.

    Each branch is followed by BatchNorm and ReLU. The branches are
    concatenated and fused by 1x1 conv -> BN -> ReLU -> Dropout(0.5).

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of every branch and of the fused output.
        rates: dilation rates of the three 3x3 branches (any indexable
            sequence; only the first three entries are used).

    Note: the pooling branch applies BatchNorm to a 1x1 map. In training mode,
    a batch of size 1 therefore makes BatchNorm raise.
    """

    # Immutable default: a list default would be shared between all calls.
    def __init__(self, in_channels, out_channels, rates=(6, 12, 18)):
        super(ASPP, self).__init__()
        self.aspp1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        self.aspp2 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=rates[0], dilation=rates[0], bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        self.aspp3 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=rates[1], dilation=rates[1], bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        self.aspp4 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=rates[2], dilation=rates[2], bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        self.conv1 = nn.Conv2d(out_channels * 5, out_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Return fused multi-scale features with ``out_channels`` channels at x's spatial size."""
        size = x.shape[2:]
        aspp1 = self.aspp1(x)
        aspp2 = self.aspp2(x)
        aspp3 = self.aspp3(x)
        aspp4 = self.aspp4(x)
        global_avg = self.global_avg_pool(x)
        # Broadcast the 1x1 image-level feature back to the input resolution.
        global_avg = F.interpolate(global_avg, size=size, mode='bilinear', align_corners=True)
        concat = torch.cat([aspp1, aspp2, aspp3, aspp4, global_avg], dim=1)
        concat = self.conv1(concat)
        concat = self.bn1(concat)
        concat = self.relu(concat)
        concat = self.dropout(concat)
        return concat

class Decoder(nn.Module):
    """DeepLabV3+ decoder: fuse upsampled high-level features with projected low-level features.

    The low-level features are projected by 1x1 conv -> BN -> ReLU. The
    high-level input is bilinearly resized to their spatial size and the two
    are concatenated. The result is refined by two 3x3 conv-BN-ReLU layers and
    passed through Dropout(0.5).

    Args:
        low_level_in: channels of the low-level feature map.
        low_level_out: channels after the 1x1 low-level projection.
        out_channels: channels of the high-level input and of the output.
    """

    def __init__(self, low_level_in, low_level_out, out_channels):
        super(Decoder, self).__init__()
        self.conv1 = nn.Conv2d(low_level_in, low_level_out, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(low_level_out)
        self.relu = nn.ReLU(inplace=True)

        def conv_bn_relu(c_in, c_out):
            return [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        self.conv2 = nn.Sequential(
            *conv_bn_relu(low_level_out + out_channels, out_channels),
            *conv_bn_relu(out_channels, out_channels),
        )
        self.dropout = nn.Dropout(0.5)

    def forward(self, x, low_level_feat):
        """Return decoded features at the spatial size of ``low_level_feat``."""
        skip = self.relu(self.bn1(self.conv1(low_level_feat)))
        upsampled = nn.functional.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=True)
        fused = self.conv2(torch.cat([upsampled, skip], dim=1))
        return self.dropout(fused)

class DeepLabV3Plus_MobileNetV2(nn.Module):
    """DeepLabV3+ with a MobileNetV2 backbone modified for single-channel input.

    The backbone's first conv is replaced by a conv with one input channel.
    Its weights are the per-filter mean of the original RGB weights; bias and
    dilation are copied when present. High-level features ('features.18',
    1280 channels) go through ASPP. Low-level features ('features.3',
    24 channels) feed the decoder.

    Args:
        num_classes: number of output channels (logits).
        pretrained: load ImageNet weights for the backbone.
    """

    def __init__(self, num_classes=2, pretrained=True):
        super(DeepLabV3Plus_MobileNetV2, self).__init__()
        mobilenet = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1 if pretrained else None)

        # Replace the first convolution so the network accepts 1-channel input.
        first_conv = mobilenet.features[0][0]
        if first_conv.in_channels != 1:
            new_first_conv = nn.Conv2d(
                1,
                first_conv.out_channels,
                kernel_size=first_conv.kernel_size,
                stride=first_conv.stride,
                padding=first_conv.padding,
                dilation=first_conv.dilation,
                bias=first_conv.bias is not None
            )
            with torch.no_grad():
                # Initialise each filter as the mean over the original input channels.
                # NOTE(review): averaging (rather than summing) scales the response to
                # a grayscale image by ~1/3 relative to feeding it replicated as RGB,
                # while the following pretrained BatchNorm keeps its running stats;
                # consider .sum(dim=1) if the pretrained statistics matter — confirm.
                new_first_conv.weight.copy_(first_conv.weight.mean(dim=1, keepdim=True))
                if first_conv.bias is not None:
                    # Previously left randomly initialised; keep the pretrained bias.
                    new_first_conv.bias.copy_(first_conv.bias)
            mobilenet.features[0][0] = new_first_conv

        # Extract the needed intermediate layers.
        self.backbone = create_feature_extractor(
            mobilenet,
            return_nodes={
                'features.18': 'high_level',  # last MobileNetV2 feature layer (1280 channels)
                'features.3': 'low_level'     # earlier layer with 24 channels, used by the decoder
            }
        )
        self.aspp = ASPP(in_channels=1280, out_channels=256)
        self.decoder = Decoder(low_level_in=24, low_level_out=48, out_channels=256)
        self.final_conv = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, kernel_size=1)
        )

    def forward(self, x):
        """Return ``num_classes`` logits resized to the spatial size of ``x``."""
        features = self.backbone(x)
        high_level = features['high_level']
        low_level = features['low_level']

        aspp_out = self.aspp(high_level)
        decoder_out = self.decoder(aspp_out, low_level)
        out = self.final_conv(decoder_out)
        out = F.interpolate(out, size=x.shape[2:], mode='bilinear', align_corners=True)
        return out

In [None]:
#Использование CBAM-attention (Convolutional Block Attention Module) 
class CBAM(nn.Module):
    """Convolutional Block Attention Module: channel attention followed by spatial attention.

    Channel attention: a shared 1x1-conv MLP is applied to the global average-
    and max-pooled descriptors. The two outputs are summed and passed through
    a sigmoid, which rescales the channels.

    Spatial attention: the channel-wise mean and max maps are stacked and
    convolved to one channel. A sigmoid of that map rescales the positions.

    The output has the same shape as the input.

    Args:
        channels: number of input/output channels.
        reduction_ratio: bottleneck reduction of the channel MLP. The hidden
            width is clamped to at least 1, so ``channels < reduction_ratio``
            no longer yields a zero-width (degenerate) MLP.
        kernel_size: spatial-attention conv kernel. Use an odd value so that
            ``padding=kernel_size // 2`` preserves the spatial size.
    """

    def __init__(self, channels, reduction_ratio=16, kernel_size=7):
        super(CBAM, self).__init__()
        hidden = max(1, channels // reduction_ratio)

        # Channel Attention Module
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.fc = nn.Sequential(
            nn.Conv2d(channels, hidden, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channels, 1, bias=False)
        )
        self.sigmoid_channel = nn.Sigmoid()

        # Spatial Attention Module
        self.conv_spatial = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid_spatial = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` refined by channel attention and then spatial attention."""
        # Channel attention: the same MLP is shared by the avg and max descriptors.
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        channel_attn = self.sigmoid_channel(avg_out + max_out)
        x = x * channel_attn

        # Spatial attention over the channel-wise mean and max maps.
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        spatial_attn = torch.cat([avg_out, max_out], dim=1)
        spatial_attn = self.sigmoid_spatial(self.conv_spatial(spatial_attn))
        x = x * spatial_attn

        return x

class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling head with CBAM refinement of the fused output.

    Five parallel branches run over the same input:
    - a 1x1 conv;
    - three 3x3 dilated convs, with dilations taken from ``rates``;
    - a global-average-pool branch, upsampled back to the input size.

    Each branch is followed by BatchNorm and ReLU. The branches are
    concatenated and fused by 1x1 conv -> BN -> ReLU -> Dropout(0.5); a CBAM
    module is then applied.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of every branch and of the output.
        rates: dilation rates of the three 3x3 branches (any indexable
            sequence; only the first three entries are used).

    Note: the pooling branch applies BatchNorm to a 1x1 map. In training mode,
    a batch of size 1 therefore makes BatchNorm raise.
    """

    # Immutable default: a list default would be shared between all calls.
    def __init__(self, in_channels, out_channels, rates=(2, 4, 6)):
        super(ASPP, self).__init__()
        self.aspp1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        self.aspp2 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=rates[0], dilation=rates[0], bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        self.aspp3 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=rates[1], dilation=rates[1], bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        self.aspp4 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=rates[2], dilation=rates[2], bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        self.conv1 = nn.Conv2d(out_channels * 5, out_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(0.5)

        # CBAM attention applied to the fused output.
        self.cbam = CBAM(out_channels)

    def forward(self, x):
        """Return CBAM-refined multi-scale features at x's spatial size."""
        size = x.shape[2:]
        aspp1 = self.aspp1(x)
        aspp2 = self.aspp2(x)
        aspp3 = self.aspp3(x)
        aspp4 = self.aspp4(x)
        global_avg = self.global_avg_pool(x)
        # Broadcast the 1x1 image-level feature back to the input resolution.
        global_avg = F.interpolate(global_avg, size=size, mode='bilinear', align_corners=True)
        concat = torch.cat([aspp1, aspp2, aspp3, aspp4, global_avg], dim=1)
        concat = self.conv1(concat)
        concat = self.bn1(concat)
        concat = self.relu(concat)
        concat = self.dropout(concat)

        # Apply CBAM
        concat = self.cbam(concat)

        return concat

class Decoder(nn.Module):
    """DeepLabV3+ decoder with a CBAM attention step on the output.

    The low-level features are projected (1x1 conv -> BN -> ReLU). The
    high-level input is bilinearly resized to their size and concatenated with
    them. The result is refined by two 3x3 conv-BN-ReLU layers, then
    Dropout(0.5), then CBAM.

    Args:
        low_level_in: channels of the low-level feature map.
        low_level_out: channels after the 1x1 low-level projection.
        out_channels: channels of the high-level input and of the output.
    """

    def __init__(self, low_level_in, low_level_out, out_channels):
        super(Decoder, self).__init__()
        self.conv1 = nn.Conv2d(low_level_in, low_level_out, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(low_level_out)
        self.relu = nn.ReLU(inplace=True)

        fused_in = low_level_out + out_channels
        refine = []
        for c_in in (fused_in, out_channels):
            refine.extend([
                nn.Conv2d(c_in, out_channels, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        self.conv2 = nn.Sequential(*refine)
        self.dropout = nn.Dropout(0.5)

        self.cbam = CBAM(out_channels)

    def forward(self, x, low_level_feat):
        """Return attention-refined decoded features at the size of ``low_level_feat``."""
        projected = self.relu(self.bn1(self.conv1(low_level_feat)))
        resized = F.interpolate(x, size=projected.shape[2:], mode='bilinear', align_corners=True)
        out = self.dropout(self.conv2(torch.cat([resized, projected], dim=1)))
        return self.cbam(out)

class DeepLabV3Plus_MobileNetV2(nn.Module):
    """DeepLabV3+ on a MobileNetV2 backbone, with a learned 1->3 channel input adapter.

    A 3x3 conv -> BN -> ReLU maps the single-channel input to 3 channels for
    the unmodified (optionally ImageNet-pretrained) backbone. High-level
    features ('features.18', 1280 channels) go through ASPP. Low-level
    features ('features.3', 24 channels) feed the decoder.

    Args:
        num_classes: number of output channels (logits).
        pretrained: load ImageNet weights for the backbone.
    """

    def __init__(self, num_classes=2, pretrained=True):
        super(DeepLabV3Plus_MobileNetV2, self).__init__()
        weights = models.MobileNet_V2_Weights.IMAGENET1K_V1 if pretrained else None
        mobilenet = models.mobilenet_v2(weights=weights)

        self.adapter = nn.Sequential(
            nn.Conv2d(1, 3, kernel_size=3, padding=1),
            nn.BatchNorm2d(3),
            nn.ReLU(inplace=True),
        )

        node_map = {'features.18': 'high_level', 'features.3': 'low_level'}
        self.backbone = create_feature_extractor(mobilenet, return_nodes=node_map)
        self.aspp = ASPP(in_channels=1280, out_channels=256)
        self.decoder = Decoder(low_level_in=24, low_level_out=48, out_channels=256)

        head = [
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, kernel_size=1),
        ]
        self.final_conv = nn.Sequential(*head)

    def forward(self, x):
        """Return ``num_classes`` logits at the spatial size of the adapted input."""
        adapted = self.adapter(x)
        feats = self.backbone(adapted)
        context = self.aspp(feats['high_level'])
        logits = self.final_conv(self.decoder(context, feats['low_level']))
        # The adapter's 3x3/pad-1 conv keeps the input's spatial size.
        return F.interpolate(logits, size=adapted.shape[2:], mode='bilinear', align_corners=True)

In [None]:
#Использование EfficientNetB0 вместо MobileNetV2

class CBAM(nn.Module):
    """Convolutional Block Attention Module: channel attention, then spatial attention.

    Channel attention: a shared 1x1-conv MLP is applied to the global average-
    and max-pooled descriptors. The two outputs are summed and passed through
    a sigmoid, which rescales the channels.

    Spatial attention: the channel-wise mean and max maps are stacked and
    convolved to one channel. A sigmoid of that map rescales the positions.

    The output has the same shape as the input.

    Args:
        channels: number of input/output channels.
        reduction_ratio: bottleneck reduction of the channel MLP. The hidden
            width is clamped to at least 1, so ``channels < reduction_ratio``
            no longer yields a zero-width (degenerate) MLP.
        kernel_size: spatial-attention conv kernel. Use an odd value so that
            ``padding=kernel_size // 2`` preserves the spatial size.
    """

    def __init__(self, channels, reduction_ratio=16, kernel_size=7):
        super(CBAM, self).__init__()
        hidden = max(1, channels // reduction_ratio)

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.fc = nn.Sequential(
            nn.Conv2d(channels, hidden, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channels, 1, bias=False)
        )
        self.sigmoid_channel = nn.Sigmoid()

        self.conv_spatial = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid_spatial = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` refined by channel attention and then spatial attention."""
        # Channel attention: the same MLP is shared by the avg and max descriptors.
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        x = x * self.sigmoid_channel(avg_out + max_out)

        # Spatial attention over the channel-wise mean and max maps.
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        spatial_attn = self.conv_spatial(torch.cat([avg_out, max_out], dim=1))
        x = x * self.sigmoid_spatial(spatial_attn)

        return x

class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling head with CBAM refinement of the fused output.

    Five parallel branches run over the same input:
    - a 1x1 conv;
    - three 3x3 dilated convs, with dilations taken from ``rates``;
    - a global-average-pool branch, upsampled back to the input size.

    Each branch is followed by BatchNorm and ReLU. The branches are
    concatenated and fused by 1x1 conv -> BN -> ReLU -> Dropout(0.5); a CBAM
    module is then applied.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of every branch and of the output.
        rates: dilation rates of the three 3x3 branches (any indexable
            sequence; only the first three entries are used).

    Note: the pooling branch applies BatchNorm to a 1x1 map. In training mode,
    a batch of size 1 therefore makes BatchNorm raise.
    """

    # Immutable default: a list default would be shared between all calls.
    def __init__(self, in_channels, out_channels, rates=(2, 4, 6)):
        super(ASPP, self).__init__()
        self.aspp1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        self.aspp2 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=rates[0], dilation=rates[0], bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        self.aspp3 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=rates[1], dilation=rates[1], bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        self.aspp4 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=rates[2], dilation=rates[2], bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        self.conv1 = nn.Conv2d(out_channels * 5, out_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(0.5)

        self.cbam = CBAM(out_channels)

    def forward(self, x):
        """Return CBAM-refined multi-scale features at x's spatial size."""
        size = x.shape[2:]
        aspp1 = self.aspp1(x)
        aspp2 = self.aspp2(x)
        aspp3 = self.aspp3(x)
        aspp4 = self.aspp4(x)
        global_avg = self.global_avg_pool(x)
        # Broadcast the 1x1 image-level feature back to the input resolution.
        global_avg = F.interpolate(global_avg, size=size, mode='bilinear', align_corners=True)
        concat = torch.cat([aspp1, aspp2, aspp3, aspp4, global_avg], dim=1)
        concat = self.conv1(concat)
        concat = self.bn1(concat)
        concat = self.relu(concat)
        concat = self.dropout(concat)

        concat = self.cbam(concat)

        return concat

class Decoder(nn.Module):
    """DeepLabV3+ decoder followed by CBAM attention.

    Steps:
    1. Project the low-level features with 1x1 conv -> BN -> ReLU.
    2. Resize the high-level input bilinearly to match them, and concatenate.
    3. Refine with two 3x3 conv-BN-ReLU layers.
    4. Apply Dropout(0.5), then CBAM.

    Args:
        low_level_in: channels of the low-level feature map.
        low_level_out: channels after the 1x1 low-level projection.
        out_channels: channels of the high-level input and of the output.
    """

    def __init__(self, low_level_in, low_level_out, out_channels):
        super(Decoder, self).__init__()
        self.conv1 = nn.Conv2d(low_level_in, low_level_out, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(low_level_out)
        self.relu = nn.ReLU(inplace=True)
        first = nn.Conv2d(low_level_out + out_channels, out_channels, kernel_size=3, padding=1, bias=False)
        first_bn = nn.BatchNorm2d(out_channels)
        first_act = nn.ReLU(inplace=True)
        second = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
        second_bn = nn.BatchNorm2d(out_channels)
        second_act = nn.ReLU(inplace=True)
        self.conv2 = nn.Sequential(first, first_bn, first_act, second, second_bn, second_act)
        self.dropout = nn.Dropout(0.5)

        self.cbam = CBAM(out_channels)

    def forward(self, x, low_level_feat):
        """Return attention-refined decoded features at the size of ``low_level_feat``."""
        low = self.conv1(low_level_feat)
        low = self.relu(self.bn1(low))
        target_size = low.shape[2:]
        high = F.interpolate(x, size=target_size, mode='bilinear', align_corners=True)
        merged = torch.cat([high, low], dim=1)
        return self.cbam(self.dropout(self.conv2(merged)))

class DeepLabV3Plus_EfficientNetB0(nn.Module):
    """DeepLabV3+ on an EfficientNet-B0 backbone, with a learned 1->3 channel input adapter.

    High-level features come from 'features.8' (1280 channels, the last layer
    before the classifier). Low-level features come from 'features.2'
    (24 channels). ASPP uses small dilation rates (1, 2, 3) for a narrower
    context, which suits smaller input images.

    Args:
        num_classes: number of output channels (logits).
        pretrained: load ImageNet weights for the backbone.
    """

    def __init__(self, num_classes=2, pretrained=True):
        super(DeepLabV3Plus_EfficientNetB0, self).__init__()
        weights = models.EfficientNet_B0_Weights.IMAGENET1K_V1 if pretrained else None
        efficientnet = models.efficientnet_b0(weights=weights)

        self.adapter = nn.Sequential(
            nn.Conv2d(1, 3, kernel_size=3, padding=1),
            nn.BatchNorm2d(3),
            nn.ReLU(inplace=True),
        )

        # Extract the required layers from EfficientNet-B0.
        node_map = {'features.8': 'high_level', 'features.2': 'low_level'}
        self.backbone = create_feature_extractor(efficientnet, return_nodes=node_map)
        self.aspp = ASPP(in_channels=1280, out_channels=256, rates=[1, 2, 3])
        self.decoder = Decoder(low_level_in=24, low_level_out=48, out_channels=256)

        head = [
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, kernel_size=1),
        ]
        self.final_conv = nn.Sequential(*head)

    def forward(self, x):
        """Return ``num_classes`` logits at the spatial size of the adapted input."""
        adapted = self.adapter(x)
        feats = self.backbone(adapted)
        context = self.aspp(feats['high_level'])
        logits = self.final_conv(self.decoder(context, feats['low_level']))
        # The adapter's 3x3/pad-1 conv keeps the input's spatial size.
        return F.interpolate(logits, size=adapted.shape[2:], mode='bilinear', align_corners=True)

In [None]:
#Использование архитектуры MedT - Medical Transformer

class ConvBlock(nn.Module):
    """Single Conv2d -> BatchNorm2d -> ReLU unit with configurable kernel, padding and stride."""

    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, stride=1):
        super(ConvBlock, self).__init__()
        layers = (
            nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, stride=stride),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply conv, batch norm and ReLU to ``x``."""
        out = self.conv(x)
        return out

class TransformerBlock(nn.Module):
    """Stack of ``nn.TransformerEncoderLayer`` applied over the spatial positions of a feature map.

    A (B, C, H, W) map is flattened into H*W tokens of dimension C, laid out
    sequence-first as (H*W, B, C) to match the layers' default
    ``batch_first=False``. The tokens are passed through the layers and
    reshaped back to (B, C, H, W). No positional encoding is added.

    Args:
        in_channels: token dimension (``d_model``); must be divisible by ``num_heads``.
        num_heads: attention heads per layer.
        num_layers: number of encoder layers.
    """

    def __init__(self, in_channels, num_heads=8, num_layers=4):
        super(TransformerBlock, self).__init__()
        stack = [nn.TransformerEncoderLayer(d_model=in_channels, nhead=num_heads) for _ in range(num_layers)]
        self.layers = nn.ModuleList(stack)

    def forward(self, x):
        """Return the transformer-mixed feature map with the same shape as ``x``."""
        batch, channels, height, width = x.shape
        tokens = x.view(batch, channels, -1).permute(2, 0, 1)
        for encoder_layer in self.layers:
            tokens = encoder_layer(tokens)
        return tokens.permute(1, 2, 0).view(batch, channels, height, width)

class MedT(nn.Module):
    """U-Net-shaped segmenter with a transformer bottleneck.

    Structure:
    - Four ConvBlock encoder stages (64/128/256/512 channels), each followed
      by 2x2 max pooling.
    - A TransformerBlock on the 1/16-resolution bottleneck.
    - Four transposed-conv decoder stages, each concatenating the matching
      encoder skip.

    The spatial input size must be divisible by 16 for the skip concatenations
    to line up.

    Args:
        img_size: unused by this implementation (kept for API compatibility).
        in_channels: number of input image channels.
        num_classes: number of output channels (logits).
    """

    def __init__(self, img_size=128, in_channels=1, num_classes=8):
        super(MedT, self).__init__()
        # Contracting path.
        self.encoder1 = ConvBlock(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.encoder2 = ConvBlock(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.encoder3 = ConvBlock(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.encoder4 = ConvBlock(256, 512)
        self.pool4 = nn.MaxPool2d(2)

        # Transformer bottleneck.
        self.transformer = TransformerBlock(512, num_heads=8, num_layers=4)

        # Expanding path: decoder input = upsampled channels + skip channels.
        self.upconv4 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.decoder4 = ConvBlock(256 + 512, 256)
        self.upconv3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder3 = ConvBlock(128 + 256, 128)
        self.upconv2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder2 = ConvBlock(64 + 128, 64)
        self.upconv1 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.decoder1 = ConvBlock(32 + 64, 32)

        self.conv_last = nn.Conv2d(32, num_classes, kernel_size=1)

    def forward(self, x):
        """Return ``num_classes`` logits at the input resolution."""
        # Keep each pre-pool activation as a skip connection.
        skip1 = self.encoder1(x)
        skip2 = self.encoder2(self.pool1(skip1))
        skip3 = self.encoder3(self.pool2(skip2))
        skip4 = self.encoder4(self.pool3(skip3))

        bottleneck = self.transformer(self.pool4(skip4))

        d = self.decoder4(torch.cat([self.upconv4(bottleneck), skip4], dim=1))
        d = self.decoder3(torch.cat([self.upconv3(d), skip3], dim=1))
        d = self.decoder2(torch.cat([self.upconv2(d), skip2], dim=1))
        d = self.decoder1(torch.cat([self.upconv1(d), skip1], dim=1))
        return self.conv_last(d)

In [None]:
#Transformer U-Net

class ConvBlock(nn.Module):
    """3x3 convolution (padding 1, spatial size preserved) -> BatchNorm -> ReLU."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        # A single nn.Sequential named `conv` keeps state_dict keys as conv.0.*, conv.1.*
        # so existing checkpoints still load.
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv(x)
        return out

class Encoder(nn.Module):
    """U-Net contracting path.

    channels - sequence [c0, c1, ..., cn]; stage k maps c_k -> c_{k+1} with two ConvBlocks
        and is followed by 2x2 max-pooling.

    forward returns (pooled output of the last stage, list of pre-pool stage outputs ordered
    from shallowest to deepest); the list serves as the skip connections.
    """

    def __init__(self, channels):
        super().__init__()
        self.layers = nn.ModuleList(
            nn.Sequential(ConvBlock(c_in, c_out), ConvBlock(c_out, c_out))
            for c_in, c_out in zip(channels[:-1], channels[1:])
        )

    def forward(self, x):
        skips = []
        for stage in self.layers:
            x = stage(x)
            skips.append(x)  # stored before pooling, at this stage's resolution
            x = F.max_pool2d(x, kernel_size=2)
        return x, skips

class Decoder(nn.Module):
    """U-Net expanding path.

    channels - sequence [c0, c1, ..., cn]; stage k upsamples c_k -> c_{k+1} with a stride-2
        transposed convolution, concatenates the matching skip feature and refines the result
        with two ConvBlocks.

    The first refinement ConvBlock takes c_k input channels. This assumes every skip feature
    carries c_k - c_{k+1} channels, which holds for the channel-halving layout used by
    TransUNet in this file.
    """

    def __init__(self, channels):
        super().__init__()
        self.layers = nn.ModuleList()
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            self.layers.append(nn.Sequential(
                nn.ConvTranspose2d(c_in, c_out, kernel_size=2, stride=2),
                ConvBlock(c_in, c_out),
                ConvBlock(c_out, c_out),
            ))

    def forward(self, x, features):
        """features - skip features ordered shallow -> deep (as returned by Encoder);
        they are consumed from the deepest one."""
        for depth, (upsample, conv_a, conv_b) in enumerate(self.layers):
            up = upsample(x)
            merged = torch.cat([up, features[-(depth + 1)]], dim=1)
            x = conv_b(conv_a(merged))
        return x

class TransformerBlock(nn.Module):
    """Pre-norm transformer encoder layer: x + MHA(LN(x)), then x + FFN(LN(x)).

    nn.MultiheadAttention is built without batch_first, so inputs are laid out as
    (sequence, batch, embed_dim).
    """

    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        # Creation order (attn, norm1, ffn, norm2) is kept so seeded initialisation matches.
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
        self.norm1 = nn.LayerNorm(embed_dim)
        hidden = embed_dim * 4
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, embed_dim),
            nn.Dropout(dropout),
        )
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        normed = self.norm1(x)
        attended = self.attn(normed, normed, normed)[0]
        x = x + attended
        return x + self.ffn(self.norm2(x))

class TransUNet(nn.Module):
    """U-Net with a transformer bottleneck.

    Pipeline:
    1. Encoder: 4 stages, 16x downsampling.
    2. Two ConvBlocks (base*8 -> base*16).
    3. Non-overlapping patch embedding (kernel = stride = patch_size).
    4. `num_layers` TransformerBlocks over the patch tokens.
    5. Transposed conv back to the bottleneck grid.
    6. Decoder with skip connections.
    7. 1x1 conv to `num_classes` logits.

    The input height and width must be divisible by 16 * patch_size; otherwise the
    reconstructed bottleneck does not line up with the encoder skip features.

    img_size - nominal input size; kept for interface compatibility, no layer depends on it.

    NOTE(review): the tokens get no positional embedding. With the defaults (128x128 input,
    patch_size=8) the 8x8 bottleneck yields a single token, so attention has nothing to
    attend across.
    """
    def __init__(self, img_size=128, in_channels=1, num_classes=8, base_channels=64, num_heads=8, num_layers=12, embed_dim=512, patch_size=8):
        super(TransUNet, self).__init__()
        self.encoder = Encoder([in_channels, base_channels, base_channels*2, base_channels*4, base_channels*8])
        self.bottleneck = nn.Sequential(
            ConvBlock(base_channels*8, base_channels*16),
            ConvBlock(base_channels*16, base_channels*16)
        )
        self.patch_embedding = nn.Conv2d(base_channels*16, embed_dim, kernel_size=patch_size, stride=patch_size)

        self.transformer = nn.Sequential(*[TransformerBlock(embed_dim, num_heads) for _ in range(num_layers)])

        self.proj_back = nn.ConvTranspose2d(embed_dim, base_channels*16, kernel_size=patch_size, stride=patch_size)

        self.decoder = Decoder([base_channels*16, base_channels*8, base_channels*4, base_channels*2, base_channels])

        self.final_conv = nn.Conv2d(base_channels, num_classes, kernel_size=1)

    def forward(self, x):
        """x - (B, in_channels, H, W); returns logits (B, num_classes, H, W).

        Raises ValueError if H or W is not divisible by 2**depth * patch_size.
        """
        factor = (2 ** len(self.encoder.layers)) * self.patch_embedding.stride[0]
        if x.shape[-2] % factor or x.shape[-1] % factor:
            raise ValueError(
                f'TransUNet expects height and width divisible by {factor}, got {tuple(x.shape[-2:])}'
            )

        x, features = self.encoder(x)
        x = self.bottleneck(x)

        tokens = self.patch_embedding(x)  # (B, embed_dim, h, w)
        B, E, h, w = tokens.shape
        # TransformerBlock's attention is not batch_first: (sequence, batch, embed).
        x = tokens.flatten(2).permute(2, 0, 1)
        x = self.transformer(x)
        # Back to a (B, embed_dim, h, w) map, using the grid the embedding actually produced.
        x = x.permute(1, 2, 0).reshape(B, E, h, w)
        x = self.proj_back(x)

        x = self.decoder(x, features)
        return self.final_conv(x)

Прочие архитектуры, основанные на vision-transformers, здесь не привожу, поскольку они не показали высокой эффективности.

In [7]:
#Функционалы потерь
class DiceLoss(nn.Module):
    """Soft multi-class Dice loss computed from raw logits.

    Softmax probabilities are compared with one-hot targets per sample and per class. The loss
    is 1 - Dice, averaged over samples and classes (background included).

    eps - smoothing term added to numerator and denominator.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, inputs, targets):
        """inputs - logits (B, C, H, W); targets - integer class map (B, H, W)."""
        probs = F.softmax(inputs, dim=1)
        one_hot = F.one_hot(targets, num_classes=inputs.size(1))
        one_hot = one_hot.permute(0, 3, 1, 2).float()

        spatial = (2, 3)
        overlap = (probs * one_hot).sum(dim=spatial)
        total = (probs + one_hot).sum(dim=spatial)
        dice = (2. * overlap + self.eps) / (total + self.eps)
        return (1 - dice).mean()

class FocalLoss(nn.Module):
    """Multi-class focal loss computed from raw logits, averaged over pixels.

    Per pixel: -alpha_t * (1 - p_t) ** gamma * log(p_t), where p_t is the softmax probability
    of the true class.

    gamma - focusing parameter; gamma=0 gives (optionally alpha-weighted) cross-entropy
    alpha - optional per-class weights (sequence of length C), applied to each pixel's true class
    eps - kept for backward compatibility; no longer needed because log_softmax is stable
    """
    def __init__(self, gamma=2, alpha=None, eps=1e-6):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.eps = eps

    def forward(self, inputs, targets):
        """inputs - logits (B, C, H, W); targets - int64 class map (B, H, W).

        Returns a scalar tensor.
        """
        log_probs = F.log_softmax(inputs, dim=1)
        # Log-probability of the true class at every pixel: (B, H, W).
        # Previously the mean also ran over the C-1 zero (non-target) entries,
        # which shrank the loss by a factor of C.
        log_pt = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)
        pt = log_pt.exp()

        focal_loss = -((1 - pt) ** self.gamma) * log_pt

        if self.alpha is not None:
            alpha = torch.as_tensor(self.alpha, dtype=inputs.dtype, device=inputs.device)
            focal_loss = alpha[targets] * focal_loss

        return focal_loss.mean()

class IOULoss(nn.Module):
    """Soft multi-class IoU (Jaccard) loss computed from raw logits.

    Soft intersection is sum(p * y); soft union is sum(p + y - p * y), taken per sample and
    class. The loss is 1 - IoU averaged over samples and classes (background included).

    eps - smoothing term added to numerator and denominator.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, inputs, targets):
        """inputs - logits (B, C, H, W); targets - integer class map (B, H, W)."""
        probs = F.softmax(inputs, dim=1)
        one_hot = F.one_hot(targets, num_classes=inputs.size(1))
        one_hot = one_hot.permute(0, 3, 1, 2).float()

        both = probs * one_hot
        overlap = both.sum(dim=(2, 3))
        joint = (probs + one_hot - both).sum(dim=(2, 3))

        score = (overlap + self.eps) / (joint + self.eps)
        return (1 - score).mean()

class HybridLoss(nn.Module):
    """Weighted sum of DiceLoss, FocalLoss and IOULoss.

    weight - three weights [dice, focal, iou]; equal thirds when None.
    """

    def __init__(self, weight=None):
        super().__init__()
        self.dice_loss = DiceLoss()
        self.focal_loss = FocalLoss()
        self.iou_loss = IOULoss()
        # Without explicit weights, every term contributes equally.
        self.weight = [1/3, 1/3, 1/3] if weight is None else weight

    def forward(self, inputs, targets):
        dice_term = self.dice_loss(inputs, targets)
        focal_term = self.focal_loss(inputs, targets)
        iou_term = self.iou_loss(inputs, targets)

        w_dice, w_focal, w_iou = self.weight[0], self.weight[1], self.weight[2]
        return w_dice * dice_term + w_focal * focal_term + w_iou * iou_term

In [13]:
#Метрики качества
class DiceCoefficient(nn.Module):
    """Hard Dice score averaged over samples and foreground classes 1..num_classes-1.

    Background (label 0) is skipped. A class absent from both prediction and target in a
    sample scores 1 (eps / eps).
    """

    def __init__(self, num_classes, eps=1e-8):
        super().__init__()
        self.num_classes = num_classes
        self.eps = eps

    def forward(self, inputs, targets, use_argmax=True):
        """inputs - logits (B, C, H, W) when use_argmax, otherwise a label map (B, H, W);
        targets - label map (B, H, W). Returns a scalar tensor."""
        preds = torch.argmax(inputs, dim=1) if use_argmax else inputs

        per_class = []
        for label in range(1, self.num_classes):
            p = (preds == label).float()
            g = (targets == label).float()
            overlap = (p * g).sum(dim=(1, 2))
            total = p.sum(dim=(1, 2)) + g.sum(dim=(1, 2))
            per_class.append((2 * overlap + self.eps) / (total + self.eps))

        return torch.stack(per_class, dim=1).mean()

class OneClassDiceCoefficient(nn.Module):
    """Hard Dice score of a single class `class_num`, averaged over the batch.

    num_classes - kept for a signature parallel to DiceCoefficient (unused in the computation)
    class_num - label id whose Dice is measured
    eps - smoothing term
    empty_score - score for a sample where the class is absent from both prediction and target.
        The default 1.0 treats a correct "nothing here" prediction as perfect, matching
        DiceCoefficient. Pass 0.0 to reproduce the earlier behaviour of this metric.
    """
    def __init__(self, num_classes, class_num, eps=1e-8, empty_score=1.0):
        super(OneClassDiceCoefficient, self).__init__()
        self.num_classes = num_classes
        self.eps = eps
        self.class_num = class_num
        self.empty_score = empty_score

    def forward(self, inputs, targets, use_argmax=True):
        """inputs - logits (B, C, H, W) when use_argmax, otherwise a label map (B, H, W);
        targets - label map (B, H, W). Returns a scalar tensor."""
        preds = torch.argmax(inputs, dim=1) if use_argmax else inputs

        cls = self.class_num
        pred_mask = (preds == cls).float()
        target_mask = (targets == cls).float()

        intersection = (pred_mask * target_mask).sum(dim=(1, 2))
        union = pred_mask.sum(dim=(1, 2)) + target_mask.sum(dim=(1, 2))

        dice = (2 * intersection + self.eps) / (union + self.eps)
        # Class absent from both prediction and target: use the configured score explicitly.
        dice = torch.where(union > 0, dice, torch.full_like(dice, float(self.empty_score)))

        return dice.mean()

class IOU(nn.Module):
    """Hard IoU averaged over samples and foreground classes 1..num_classes-1.

    Background (label 0) is skipped. A class absent from both prediction and target in a
    sample scores 1 (eps / eps).
    """

    def __init__(self, num_classes, eps=1e-8):
        super().__init__()
        self.num_classes = num_classes
        self.eps = eps

    def forward(self, inputs, targets, use_argmax=True):
        """inputs - logits (B, C, H, W) when use_argmax, otherwise a label map (B, H, W);
        targets - label map (B, H, W). Returns a scalar tensor."""
        preds = torch.argmax(inputs, dim=1) if use_argmax else inputs

        scores = []
        for label in range(1, self.num_classes):
            p = (preds == label).float()
            g = (targets == label).float()
            both = p * g
            overlap = torch.sum(both, dim=(1, 2))
            joint = torch.sum(p + g - both, dim=(1, 2))
            scores.append((overlap + self.eps) / (joint + self.eps))

        return torch.stack(scores, dim=1).mean()

class OneClassIOU(nn.Module):
    """Hard IoU of a single class `class_num`, averaged over the batch.

    num_classes - kept for a signature parallel to IOU (unused in the computation)
    class_num - label id whose IoU is measured
    eps - smoothing term
    empty_score - score for a sample where the class is absent from both prediction and target.
        The default 1.0 treats a correct "nothing here" prediction as perfect, matching IOU.
        Pass 0.0 to reproduce the earlier behaviour of this metric.
    """
    def __init__(self, num_classes, class_num, eps=1e-8, empty_score=1.0):
        super(OneClassIOU, self).__init__()
        self.num_classes = num_classes
        self.eps = eps
        self.class_num = class_num
        self.empty_score = empty_score

    def forward(self, inputs, targets, use_argmax=True):
        """inputs - logits (B, C, H, W) when use_argmax, otherwise a label map (B, H, W);
        targets - label map (B, H, W). Returns a scalar tensor."""
        preds = torch.argmax(inputs, dim=1) if use_argmax else inputs

        cls = self.class_num
        pred_cls = (preds == cls).float()
        target_cls = (targets == cls).float()

        intersection = torch.sum(pred_cls * target_cls, dim=(1, 2))
        union = torch.sum(pred_cls + target_cls - pred_cls * target_cls, dim=(1, 2))

        iou = (intersection + self.eps) / (union + self.eps)
        # Class absent from both prediction and target: use the configured score explicitly.
        iou = torch.where(union > 0, iou, torch.full_like(iou, float(self.empty_score)))

        return iou.mean()

class PixelAccuracy(nn.Module):
    """Fraction of pixels whose predicted label equals the target label.

    ignore_index - optional target label (e.g. background) whose pixels are excluded from
        both numerator and denominator.
    """

    def __init__(self, ignore_index=None):
        super().__init__()
        self.ignore_index = ignore_index

    def forward(self, inputs, targets, use_argmax=True):
        """inputs - logits (B, C, H, W) when use_argmax, otherwise a label map;
        targets - label map. Returns a scalar tensor."""
        preds = torch.max(inputs, dim=1)[1] if use_argmax else inputs

        if self.ignore_index is None:
            hits = (preds == targets).sum().float()
            count = targets.numel()
        else:
            keep = targets != self.ignore_index
            hits = (preds[keep] == targets[keep]).sum().float()
            count = keep.sum().float()

        return hits / (count + 1e-8)

class Precision(nn.Module):
    """Precision TP / (TP + FP) over the foreground classes 1..num_classes-1.

    Computed per sample and class, averaged over classes for each sample, then over the batch.
    A class never predicted in a sample scores 1 (eps / eps).
    """

    def __init__(self, num_classes, eps=1e-8):
        super().__init__()
        self.num_classes = num_classes
        self.eps = eps

    def forward(self, inputs, targets, use_argmax=True):
        """inputs - logits (B, C, H, W) when use_argmax, otherwise a label map;
        targets - label map. Returns a scalar tensor."""
        preds = torch.max(inputs, dim=1)[1] if use_argmax else inputs
        batch = preds.size(0)

        scores = []
        for label in range(1, self.num_classes):
            predicted = preds == label
            hits = (predicted & (targets == label)).float().view(batch, -1).sum(dim=1)
            n_predicted = predicted.float().view(batch, -1).sum(dim=1)
            scores.append((hits + self.eps) / (n_predicted + self.eps))

        per_sample = torch.stack(scores, dim=1).mean(dim=1)
        return per_sample.mean()

class OneClassPrecision(nn.Module):
    """Precision TP / (TP + FP) of a single class `class_num`, averaged over the batch.

    A sample in which the class is never predicted scores 1 (eps / eps), as in Precision.
    num_classes is stored for a signature parallel to Precision but is not used.
    """

    def __init__(self, num_classes, class_num, eps=1e-8):
        super().__init__()
        self.num_classes = num_classes
        self.eps = eps
        self.class_num = class_num

    def forward(self, inputs, targets, use_argmax=True):
        """inputs - logits (B, C, H, W) when use_argmax, otherwise a label map;
        targets - label map. Returns a scalar tensor."""
        preds = torch.max(inputs, dim=1)[1] if use_argmax else inputs
        batch = preds.size(0)

        predicted = preds == self.class_num
        hits = (predicted & (targets == self.class_num)).float().view(batch, -1).sum(dim=1)
        n_predicted = predicted.float().view(batch, -1).sum(dim=1)

        return ((hits + self.eps) / (n_predicted + self.eps)).mean()

class Recall(nn.Module):
    """Recall TP / (TP + FN) over the foreground classes 1..num_classes-1.

    Computed per sample and class, averaged over classes for each sample, then over the batch.
    A class absent from the target of a sample scores 1 (eps / eps).
    """

    def __init__(self, num_classes, eps=1e-8):
        super().__init__()
        self.num_classes = num_classes
        self.eps = eps

    def forward(self, inputs, targets, use_argmax=True):
        """inputs - logits (B, C, H, W) when use_argmax, otherwise a label map;
        targets - label map. Returns a scalar tensor."""
        preds = torch.max(inputs, dim=1)[1] if use_argmax else inputs

        scores = []
        for label in range(1, self.num_classes):
            actual = targets == label
            hits = ((preds == label) & actual).float().view(preds.size(0), -1).sum(dim=1)
            n_actual = actual.float().view(targets.size(0), -1).sum(dim=1)
            scores.append((hits + self.eps) / (n_actual + self.eps))

        per_sample = torch.stack(scores, dim=1).mean(dim=1)
        return per_sample.mean()

class OneClassRecall(nn.Module):
    """Recall TP / (TP + FN) of a single class `class_num`, averaged over the batch.

    num_classes - kept for a signature parallel to Recall (unused in the computation)
    class_num - label id whose recall is measured
    eps - smoothing term
    empty_score - score for a sample whose target contains no pixel of the class (recall is
        undefined there). The default 1.0 matches Recall, which yields eps / eps = 1 in that
        case. Pass 0.0 to reproduce the earlier behaviour of this metric.
    """
    def __init__(self, num_classes, class_num, eps=1e-8, empty_score=1.0):
        super(OneClassRecall, self).__init__()
        self.num_classes = num_classes
        self.eps = eps
        self.class_num = class_num
        self.empty_score = empty_score

    def forward(self, inputs, targets, use_argmax=True):
        """inputs - logits (B, C, H, W) when use_argmax, otherwise a label map;
        targets - label map. Returns a scalar tensor."""
        preds = torch.max(inputs, dim=1)[1] if use_argmax else inputs

        true_positive = ((preds == self.class_num) & (targets == self.class_num)).float().view(preds.size(0), -1).sum(dim=1)
        actual_positive = (targets == self.class_num).float().view(targets.size(0), -1).sum(dim=1)

        recall_cls = (true_positive + self.eps) / (actual_positive + self.eps)
        # No positives in the target: recall is undefined, use the configured score.
        recall_cls = torch.where(actual_positive > 0, recall_cls, torch.full_like(recall_cls, float(self.empty_score)))

        return recall_cls.mean()

In [33]:
# Device used by the training/evaluation code below. The notebook was developed on Apple MPS;
# fall back to CUDA or CPU so it also runs on other machines.
if torch.backends.mps.is_available():
    DEVICE = 'mps'
elif torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
BATCH_SIZE = 32

def train_segmentation_model(model, optimizer, criterion, train_transforms, val_transforms, header, metrics, n_epoch = 101,
                             checkpoint_prefix = 'mini_unet', log_path = 'segmentation_logs.txt'):
    """
    Train a segmentation model; validate, log and checkpoint after every epoch.

    model - model instance to train (expected to already be on DEVICE)
    optimizer - optimizer
    criterion - loss function, called as criterion(logits, int64_masks)
    train_transforms - augmentations for training
    val_transforms - augmentations for validation/test (usually only Resize)
    header - experiment description appended to the log before training starts
    metrics - list of (metric_module, name) pairs. Each metric is called as
        metric(logits, int64_masks) and must return a one-element tensor.
    n_epoch - number of training epochs. An "epoch" is not a full pass over the data, only
        the `size` samples of CustomDataset, because of limited compute.
    checkpoint_prefix - checkpoints are saved as f'{checkpoint_prefix}_{epoch}.pth'. The default
        keeps the historical 'mini_unet_{epoch}.pth' names; use a distinct prefix per
        experiment so runs do not overwrite each other's checkpoints.
    log_path - text file the header and per-epoch reports are appended to
    """
    with open(log_path, 'a', encoding='utf-8') as file:
        file.write(header + '\n')
    losses_train, losses_val = [], []
    n_metrics = len(metrics)

    train_dataset = CustomDataset('subset0', 'seg-lungs-LUNA16', 960, 20, 10, train_transforms)
    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_dataset = CustomDataset('subset1', 'seg-lungs-LUNA16', 384, 3, 10, val_transforms)
    # Order does not matter for the averaged validation numbers; keep it deterministic.
    val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    for epoch in range(1, n_epoch + 1):
        # Running sums: slots [0, n_metrics) are train metrics, the rest are validation metrics.
        current_metrics = [0] * (2 * n_metrics)

        model.train()
        train_loss = 0
        for inputs, masks in tqdm(train_dataloader):
            inputs = inputs.to(DEVICE)
            masks = masks.to(DEVICE).long()

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            # Metrics are for logging only; keep them out of the autograd graph.
            with torch.no_grad():
                for i in range(n_metrics):
                    current_metrics[i] += metrics[i][0](outputs, masks).item()

        losses_train.append(train_loss / len(train_dataloader))

        model.eval()
        val_loss = 0
        # Debug: number of validation-mask pixels with label 0, 1, 2, 3 (class balance check).
        class_counts = [0, 0, 0, 0]
        with torch.no_grad():
            for inputs, masks in tqdm(val_dataloader):
                inputs = inputs.to(DEVICE)
                masks = masks.to(DEVICE)

                outputs = model(inputs)
                val_loss += criterion(outputs, masks.long()).item()

                # One reduction per class and batch instead of a host sync per sample.
                for label in range(len(class_counts)):
                    class_counts[label] += (masks == label).sum().item()

                for i in range(n_metrics):
                    current_metrics[i + n_metrics] += metrics[i][0](outputs, masks.long()).item()

        print(*class_counts)  # debug output (can be used to balance the classes)
        losses_val.append(val_loss / len(val_dataloader))

        for i in range(n_metrics):
            current_metrics[i] = current_metrics[i] / len(train_dataloader)
            current_metrics[i + n_metrics] = current_metrics[i + n_metrics] / len(val_dataloader)

        out_string = f'Iteration: {epoch}\ntrain loss: {losses_train[-1]}\n'
        for i in range(n_metrics):
            out_string += f'train {metrics[i][1]}: {current_metrics[i]}\n'

        out_string += f'test loss: {losses_val[-1]}\n'
        for i in range(n_metrics):
            out_string += f'test {metrics[i][1]}: {current_metrics[i + n_metrics]}\n'

        with open(log_path, 'a', encoding='utf-8') as file:
            file.write(out_string + '\n')

        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()
        }, f'{checkpoint_prefix}_{epoch}.pth')
        print(out_string)

In [35]:
# Training transforms: for now only a resize to the 128x128 network input.
# The crop / rotation / noise augmentations below are disabled experiments.
train_transforms = A.Compose(
    [
        A.Resize(128, 128),
        # A.RandomResizedCrop(height=128, width=128, scale=(0.95, 1.0), ratio=(1.0, 1.0), p=1),
        # A.Rotate(limit=5, p=0.7),
        # A.GaussNoise(var_limit=(5.0, 20.0), p=0.1),
        ToTensorV2(),
    ]
)
# Validation/test transforms: resize only.
val_transforms = A.Compose(
    [
        A.Resize(128, 128),
        ToTensorV2(),
    ]
)

In [None]:
# Model training: Mini U-Net, cross-entropy loss, Adam optimizer
model = MiniUnet().to(DEVICE)
optimizer = torch.optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()
metrics = [(PixelAccuracy(), 'pixel accuracy'),
           (DiceCoefficient(4), 'mean dice'),
           (OneClassDiceCoefficient(4, 1), 'right lung dice'),
           (OneClassDiceCoefficient(4, 2), 'left lung dice'),
           (OneClassDiceCoefficient(4, 3), 'trachea dice'),
           (IOU(4), 'mean IOU'),
           (OneClassIOU(4, 1), 'right lung IOU'),
           (OneClassIOU(4, 2), 'left lung IOU'),
           (OneClassIOU(4, 3), 'trachea IOU'),
           (Precision(4), 'mean precision'),
           (OneClassPrecision(4, 1), 'right lung precision'),
           (OneClassPrecision(4, 2), 'left lung precision'),
           (OneClassPrecision(4, 3), 'trachea precision'),
           (Recall(4), 'mean recall'),
           (OneClassRecall(4, 1), 'right lung recall'),
           (OneClassRecall(4, 2), 'left lung recall'),
           (OneClassRecall(4, 3), 'trachea recall')]
# The header's dataset tuples must match what train_segmentation_model builds:
# CustomDataset(..., 20, 10, ...) for training and CustomDataset(..., 3, 10, ...) for validation.
train_segmentation_model(model, optimizer, criterion, train_transforms, val_transforms, 'Mini U-Net, cross-entropy loss, adam optimizer, batch size = 32, only resize to 128x128 augmentations, datasets = (20, 10), (3, 10)', metrics, n_epoch = 30)

In [None]:
# Example with the Lion optimizer
model = MiniUnet().to(DEVICE)
optimizer = Lion(model.parameters(), lr=1e-4, weight_decay=1e-2)
criterion = nn.CrossEntropyLoss()
metrics = [(PixelAccuracy(), 'pixel accuracy'),
           (DiceCoefficient(4), 'mean dice'),
           (OneClassDiceCoefficient(4, 1), 'right lung dice'),
           (OneClassDiceCoefficient(4, 2), 'left lung dice'),
           (OneClassDiceCoefficient(4, 3), 'trachea dice'),
           (IOU(4), 'mean IOU'),
           (OneClassIOU(4, 1), 'right lung IOU'),
           (OneClassIOU(4, 2), 'left lung IOU'),
           (OneClassIOU(4, 3), 'trachea IOU'),
           (Precision(4), 'mean precision'),
           (OneClassPrecision(4, 1), 'right lung precision'),
           (OneClassPrecision(4, 2), 'left lung precision'),
           (OneClassPrecision(4, 3), 'trachea precision'),
           (Recall(4), 'mean recall'),
           (OneClassRecall(4, 1), 'right lung recall'),
           (OneClassRecall(4, 2), 'left lung recall'),
           (OneClassRecall(4, 3), 'trachea recall')]
# The header's dataset tuples must match what train_segmentation_model builds:
# CustomDataset(..., 20, 10, ...) for training and CustomDataset(..., 3, 10, ...) for validation.
# NOTE(review): checkpoints are written as mini_unet_{epoch}.pth, the same names as the Adam run,
# so running both cells overwrites the earlier checkpoints.
train_segmentation_model(model, optimizer, criterion, train_transforms, val_transforms, 'Mini U-Net, cross-entropy loss, lion optimizer, batch size = 32, only resize to 128x128 augmentations, datasets = (20, 10), (3, 10)', metrics, n_epoch = 30)

In [None]:
# Visual check of the network (outdated; needs rework after the training-pipeline optimisation)
accs = []  # per-slice pixel accuracies collected by show_input_output
# Prefer Apple MPS (the development device) and fall back to CUDA or CPU elsewhere.
if torch.backends.mps.is_available():
    DEVICE = 'mps'
elif torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
def show_input_output(model, name):
    """
    Plot, slice by slice, the input CT slice, the model's segmentation and the ground-truth mask.

    model - segmentation model, already moved to DEVICE
    name - file name of the scan in 'subset1'; the mask with the same name is read from
        'seg-lungs-LUNA16'

    For every slice, the pixel accuracy (prediction vs. mask, both at 128x128) is appended
    once to the global list ``accs``.

    NOTE(review): mask values are compared with the predicted class ids as-is (no label
    remapping happens here). Confirm they use the same encoding as the training dataset.
    """
    global accs
    model.eval()
    # Read the scan and its mask with SimpleITK
    image = sitk.ReadImage(os.path.join('subset1', name))
    ground_true = sitk.ReadImage(os.path.join('seg-lungs-LUNA16', name))

    # Built once rather than per slice. The label mask is resized with nearest-neighbour
    # interpolation so no blended (non-existent) label values appear on region borders.
    image_transform = A.Compose([
        A.Resize(128, 128),
        ToTensorV2()
    ])
    mask_transform = A.Compose([
        A.Resize(128, 128, interpolation=cv2.INTER_NEAREST),
        ToTensorV2()
    ])

    # Every slice, including the last one (range(depth - 1) would skip it).
    for i in range(image.GetDepth()):
        input_image = sitk.GetArrayFromImage(image[:, :, i]).astype(np.float32)
        low, high = input_image.min(), input_image.max()
        # Min-max normalisation to [0, 1]; a constant slice becomes zeros instead of NaN.
        if high > low:
            input_image = (input_image - low) / (high - low)
        else:
            input_image = np.zeros_like(input_image)
        input_image = image_transform(image=input_image)['image'].unsqueeze(0).to(DEVICE)

        ground_true_image = sitk.GetArrayFromImage(ground_true[:, :, i])
        ground_true_image = mask_transform(image=ground_true_image)['image'].unsqueeze(0).to(DEVICE)

        with torch.no_grad():
            output_image = model(input_image)
            output_image = torch.argmax(output_image, dim=1)

        # Exactly one accuracy value per slice.
        accs.append((output_image == ground_true_image).float().mean().item())

        # Convert tensors to numpy arrays for plotting
        input_image = input_image.squeeze().cpu().numpy()
        output_image = output_image.squeeze().cpu().numpy()
        ground_true_image = ground_true_image.squeeze().cpu().numpy()

        fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))

        ax1.imshow(input_image, cmap = 'gray')
        ax1.set_title(f'Входное изображение. Срез {i}')

        ax2.imshow(output_image, cmap='plasma')
        ax2.set_title('Сегментация нейронной сети')

        ax3.imshow(ground_true_image, cmap='plasma')
        ax3.set_title('Истинная маска')

        clear_output(wait = True)
        plt.show()
        # One figure is created per slice; release it so they do not accumulate.
        plt.close(fig)

model = MiniUnet().to(DEVICE)
# map_location lets a checkpoint saved on another device (e.g. MPS) load on the current one.
checkpoint = torch.load('mini_unet_8.pth', map_location=DEVICE)
model.load_state_dict(checkpoint['model_state_dict'])

show_input_output(model, '1.3.6.1.4.1.14519.5.2.1.6279.6001.340012777775661021262977442176.mhd')
print(sum(accs)/len(accs))