Определим Unet

In [None]:
%pip install torch torchvision kaggle opencv-python scikit-learn



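Берём [отсюда](https://www.kaggle.com/settings) API-токен (файл `kaggle.json`) и кладём его в `~/.kaggle/` (на Windows — `%USERPROFILE%\.kaggle\`). В Colab это делает следующая ячейка: достаточно загрузить в неё `kaggle.json`.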

In [None]:
import pathlib

from google.colab import files

uploaded = files.upload()

for fn in uploaded.keys():
    print('User uploaded file "{name}" with length {length} bytes'.format(
        name=fn, length=len(uploaded[fn])))

if not uploaded:
    raise RuntimeError("No file uploaded: please select your kaggle.json")

# Write the token from the uploaded bytes instead of `mv kaggle.json`: Colab may
# save the upload under a different name (e.g. when a file with that name
# already exists), which made the hard-coded shell move fail.
token_name = "kaggle.json" if "kaggle.json" in uploaded else next(iter(uploaded))
kaggle_dir = pathlib.Path.home() / ".kaggle"
kaggle_dir.mkdir(parents=True, exist_ok=True)
token_path = kaggle_dir / "kaggle.json"
token_path.write_bytes(uploaded[token_name])
# Same as the original `chmod 600`: the Kaggle client warns about tokens readable by others.
token_path.chmod(0o600)
# files.upload() also saved a copy in the working directory; remove it so the
# credential does not linger next to the notebook.
pathlib.Path(token_name).unlink(missing_ok=True)

Saving kaggle.json to kaggle.json
User uploaded file "kaggle.json" with length 69 bytes


In [None]:
import kaggle
import pathlib
import torch
import torchvision
import zipfile
import cv2
import os
import numpy as np
from torch.nn import BCEWithLogitsLoss
from PIL import Image
from torch.utils.data import Dataset, random_split


# Seed torch's global RNG (all devices) so the random train/val split and the
# weight initialisation further down are reproducible across runs.
SEED = 42
torch.manual_seed(SEED)

if torch.cuda.is_available():
    print("Мы счастливые обладатели видеокарты от AMD/Apple/Nvidia, потому можем позволить себе использовать GPU для ускорения")
    device = torch.device("cuda")
else:
    print("GPU ускорение нам недоступно, а значит, придётся запастись терпением")
    device = torch.device("cpu")

COMPETITION = "ml-intensive-yandex-autumn-2023"
download_dir = pathlib.Path("dataset")
dataset_path = download_dir / "data"

# Only contact the Kaggle API while the data is not extracted yet, so re-running
# the notebook works without network access once the dataset is on disk.
if not dataset_path.exists():
    kaggle.api.competition_download_files(COMPETITION, path=str(download_dir))
    with zipfile.ZipFile(download_dir / f"{COMPETITION}.zip", "r") as zip_ref:
        zip_ref.extractall(download_dir)

# os.listdir order depends on the filesystem; sort it so the file list (and
# therefore the seeded random split) is identical on every machine.
imagePaths = sorted(os.listdir(dataset_path / "train_images"))
# Masks are expected to share the file name of their image.
maskPaths = [str(dataset_path / "train_lung_masks" / filename) for filename in imagePaths]

Мы счастливые обладатели видеокарты от AMD/Apple/Nvidia, потому можем позволить себе использовать GPU для ускорения


In [None]:
class SegmentationDataset(Dataset):
    """Pairs each training image with the mask stored under the same file name.

    Args:
        image_paths: file names (relative to both directories) to expose, in order.
        image_dir: directory containing the input images.
        segmentation_dir: directory containing the masks.
        transform_image: callable applied to the RGB ``PIL.Image``.
        transform_mask: callable applied to the single-channel ("L") mask image.
    """

    def __init__(self, image_paths, image_dir, segmentation_dir, transform_image, transform_mask):
        super().__init__()
        self.image_dir = image_dir
        self.segmentation_dir = segmentation_dir
        self.transform_image = transform_image
        self.transform_mask = transform_mask
        self.images_paths = image_paths

    def __len__(self):
        return len(self.images_paths)

    def __getitem__(self, index):
        """Return ``(transform_image(image), transform_mask(mask))`` for sample ``index``."""
        image_name = self.images_paths[index]

        # `with` closes the underlying file as soon as the pixels are decoded;
        # a bare Image.open keeps the handle open until garbage collection,
        # which can exhaust file descriptors over a long epoch.
        with Image.open(os.path.join(self.image_dir, image_name)) as image_file:
            image = image_file.convert("RGB")
        with Image.open(os.path.join(self.segmentation_dir, image_name)) as mask_file:
            seg = mask_file.convert("L")

        return self.transform_image(image), self.transform_mask(seg)

In [None]:
from torch.utils.data import DataLoader


def load_data_set(image_paths, image_dir, segmentation_dir, transforms, batch_size=8, shuffle=True,
                  val_size=2000, seed=None):
    """Build ``(train_loader, val_loader)`` over a randomly split SegmentationDataset.

    Args:
        image_paths, image_dir, segmentation_dir: forwarded to SegmentationDataset.
        transforms: pair ``(image_transform, mask_transform)``.
        batch_size: batch size of both loaders.
        shuffle: whether the training loader reshuffles every epoch. The
            validation loader is never shuffled, so its batches (and the preview
            images saved from them) are comparable between epochs.
        val_size: number of samples held out for validation; all remaining
            samples are used for training.
        seed: optional seed for the split; ``None`` uses torch's global RNG.

    Raises:
        ValueError: if ``val_size`` is outside ``[0, len(dataset)]``.
    """
    dataset = SegmentationDataset(
        image_paths,
        image_dir,
        segmentation_dir,
        transform_image=transforms[0],
        transform_mask=transforms[1]
    )

    # Previously hard-coded as [25000, 2000], which raised for any dataset that
    # was not exactly 27000 images; derive the train size from the data instead.
    if not 0 <= val_size <= len(dataset):
        raise ValueError(f"val_size={val_size} must be between 0 and {len(dataset)}")
    train_size = len(dataset) - val_size

    split_kwargs = {} if seed is None else {"generator": torch.Generator().manual_seed(seed)}
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size], **split_kwargs)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader

In [None]:
from torchsummary import summary
from torch import nn
import torch
import torchvision.transforms.functional as TF


class DoubleConv(nn.Module):
    """Two stacked (Conv2d -> BatchNorm2d -> ReLU) stages.

    The first convolution maps ``in_channels -> out_channels``, the second
    ``out_channels -> out_channels``; both use the same kernel size, stride and
    padding. Convolutions carry no bias because the BatchNorm that follows adds
    its own learnable shift. The misspelled ``kernal_size`` parameter name is
    kept because callers pass it by keyword.
    """

    def __init__(self, in_channels, out_channels, kernal_size, strides, padding):
        super(DoubleConv, self).__init__()
        stages = []
        for stage_in_channels in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in_channels, out_channels, kernal_size, strides, padding, bias=False),
                nn.BatchNorm2d(num_features=out_channels),
                nn.ReLU(inplace=True),
            ])
        # Sequential indices (conv.0 ... conv.5) match the state_dict keys of saved checkpoints.
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)


class UNet(nn.Module):
    """U-Net encoder/decoder that produces per-pixel logits.

    Args:
        in_channels: channel count of the input images (3 for RGB).
        num_segmentations: number of output channels. The output is raw logits
            (no sigmoid), as expected by BCEWithLogitsLoss.
        features: channel widths of the encoder stages, shallowest first; the
            bottleneck uses ``2 * features[-1]`` channels.

    The output has the same spatial size as the input; odd intermediate sizes
    are handled by resizing the upsampled map to its skip connection.
    """

    def __init__(self, in_channels, num_segmentations=1, features=(64, 128, 256, 512)):
        # A tuple default avoids the shared-mutable-default pitfall; any
        # indexable, reversible sequence of ints works.
        super(UNet, self).__init__()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.bottleneck = DoubleConv(
            in_channels=features[-1],
            out_channels=features[-1] * 2,
            kernal_size=3,
            strides=1,
            padding=1
        )
        self.output = nn.Conv2d(
            in_channels=features[0],
            out_channels=num_segmentations,
            kernel_size=1
        )
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        stage_in_channels = in_channels
        for feature in features:
            self.downs.append(DoubleConv(
                in_channels=stage_in_channels,
                out_channels=feature,
                kernal_size=3,
                strides=1,
                padding=1
            ))
            stage_in_channels = feature

        for feature in reversed(features):
            # Used only as an indexed container ([0] upsample, [1] fuse), never
            # called as a Sequential; the layout is kept so checkpoint keys
            # (ups.{i}.0.*, ups.{i}.1.*) stay loadable.
            self.ups.append(nn.Sequential(
                nn.ConvTranspose2d(
                    in_channels=feature * 2,
                    out_channels=feature,
                    kernel_size=2,
                    stride=2,
                    padding=0
                ),
                DoubleConv(
                    in_channels=feature * 2,
                    out_channels=feature,
                    kernal_size=3,
                    padding=1,
                    strides=1
                )
            ))

    def forward(self, x):
        skip_connections = []
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        for up, skip in zip(self.ups, reversed(skip_connections)):
            upsample, fuse = up[0], up[1]
            x = upsample(x)
            # MaxPool2d floors odd sizes, so the upsampled map can be a pixel
            # smaller than its skip connection; only spatial dims can differ.
            # Bilinear/align_corners=False matches torchvision's TF.resize for
            # upscaling, without its antialias deprecation warning.
            if x.shape[2:] != skip.shape[2:]:
                x = nn.functional.interpolate(
                    x, size=skip.shape[2:], mode="bilinear", align_corners=False
                )
            x = fuse(torch.cat((skip, x), dim=1))

        return self.output(x)

# Instantiate the network for a 3-channel (RGB) input on the selected device.
model = UNet(in_channels=3).to(device)

Приступим к обучению

In [None]:
import torch
from torchvision.transforms import transforms
from tqdm import tqdm
import torchvision

config = {
    "lr": 1e-3,
    "batch_size": 16,
    "image_dir": "./dataset/data/train_images",
    "segmentation_dir": "./dataset/data/train_lung_masks",
    "image_paths": imagePaths,
    "epochs": 10,
    "checkpoint": "./checkpoint/lungs_segmentation_v1.pth",
    "optimiser": "./checkpoint/lungs_segmentation_v1_optim.pth",
    "continue_train": False,
    "device": "cuda" if torch.cuda.is_available() else "cpu"
}

print(f"Training using {config['device']}")

# Images: resize to the network input size and scale to [0, 1].
# Normalize with mean 0 / std 1 is an identity op, kept as a placeholder.
transforms_image = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize((0., 0., 0.), (1., 1., 1.))
])

# Masks: nearest-neighbour resizing keeps the original label values. The
# default bilinear filter blends mask edges into intermediate values that then
# leak into the BCE targets and the Dice computation.
transforms_mask = transforms.Compose([
    transforms.Resize((256, 256), interpolation=torchvision.transforms.InterpolationMode.NEAREST),
    transforms.ToTensor(),
    transforms.Normalize((0.,), (1.,))
])

train_dataset, val_dataset = load_data_set(
    config['image_paths'],
    config['image_dir'],
    config['segmentation_dir'],
    transforms=[transforms_image, transforms_mask],
    batch_size=config['batch_size']
)

print("loaded", len(train_dataset), "batches")

model = UNet(3).to(config['device'])
optimiser = torch.optim.Adam(params=model.parameters(), lr=config['lr'])

if config['continue_train']:
    # map_location lets a checkpoint written on a GPU be resumed on a CPU-only
    # machine (and vice versa) instead of failing on device deserialisation.
    state_dict = torch.load(config['checkpoint'], map_location=config['device'])
    optimiser_state = torch.load(config['optimiser'], map_location=config['device'])
    model.load_state_dict(state_dict)
    optimiser.load_state_dict(optimiser_state)

loss_fn = torch.nn.BCEWithLogitsLoss()
# Gradient scaling only applies to CUDA mixed precision; enable it explicitly
# there instead of relying on the CPU fallback warning-and-disable path.
scaler = torch.cuda.amp.GradScaler(enabled=config['device'] == "cuda")

model.train()


def _ensure_parent_dir(path):
    """Create the directory that will contain ``path`` (no-op if it exists or ``path`` has no directory part)."""
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def check_accuracy_and_save(model, optimiser, epoch):
    """Checkpoint model and optimiser, then report validation pixel accuracy and Dice.

    Args:
        model: the network being trained; switched to eval mode for validation
            and always switched back to train mode afterwards.
        optimiser: optimiser whose state is saved next to the model weights.
        epoch: epoch index, used to name the saved preview images.

    Side effects: writes ``config['checkpoint']`` and ``config['optimiser']``,
    and saves the last validation batch's predicted and true masks to
    ``./test/pred/{epoch}.png`` and ``./test/true/{epoch}.png``.
    """
    # None of these directories are created elsewhere; on a fresh runtime the
    # first save would raise FileNotFoundError after a full training epoch.
    _ensure_parent_dir(config['checkpoint'])
    _ensure_parent_dir(config['optimiser'])
    torch.save(model.state_dict(), config['checkpoint'])
    torch.save(optimiser.state_dict(), config['optimiser'])

    pred_path = f"./test/pred/{epoch}.png"
    true_path = f"./test/true/{epoch}.png"
    _ensure_parent_dir(pred_path)
    _ensure_parent_dir(true_path)

    num_correct = 0
    num_pixel = 0
    dice_score = 0.0
    num_batches = 0
    last_preds = last_y = None

    model.eval()
    try:
        with torch.no_grad():
            for x, y in val_dataset:
                x = x.to(config['device'])
                y = y.to(config['device'])

                preds = (torch.sigmoid(model(x)) > 0.5).float()
                num_correct += (preds == y).sum().item()
                num_pixel += preds.numel()
                # Per-batch Dice; the epsilon guards batches with empty masks and predictions.
                dice_score += ((2 * (preds * y).sum()) / ((preds + y).sum() + 1e-8)).item()
                num_batches += 1
                last_preds, last_y = preds, y
    finally:
        model.train()

    if num_batches == 0:
        print("Validation loader is empty; skipping metrics")
        return

    # Saving once (instead of every batch) writes the same final files: the last batch.
    torchvision.utils.save_image(last_preds, pred_path)
    torchvision.utils.save_image(last_y, true_path)

    print(f"Pixel accuracy = {num_correct / num_pixel:.4f} ({num_correct}/{num_pixel})")
    print(
        f"Dice Score = {dice_score/num_batches}"
    )


def train():
    """Train ``model`` for ``config['epochs']`` epochs, validating and checkpointing after each.

    Uses mixed precision (autocast + the module-level ``scaler``) on CUDA and
    plain fp32 on CPU.
    """
    use_amp = config['device'] == "cuda"
    for epoch in range(config['epochs']):
        loop = tqdm(train_dataset, desc=f"epoch {epoch + 1}/{config['epochs']}")
        for image, seg in loop:
            image = image.to(config['device'])
            seg = seg.float().to(config['device'])

            # torch.autocast replaces the deprecated torch.cuda.amp.autocast and
            # is explicitly disabled on CPU instead of emitting a warning.
            with torch.autocast(device_type=config['device'], enabled=use_amp):
                pred = model(image)
                loss = loss_fn(pred, seg)

            optimiser.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimiser)
            scaler.update()

            loop.set_postfix(loss=loss.item())
        check_accuracy_and_save(model, optimiser, epoch)


# Inside a Jupyter kernel __name__ is always "__main__", so running this cell starts training.
if __name__ == "__main__":
    train()

Training using cuda
loaded 1563 batches


100%|██████████| 1563/1563 [11:14<00:00,  2.32it/s, loss=0.0389]


Dice Score = 0.9719495177268982


100%|██████████| 1563/1563 [11:14<00:00,  2.32it/s, loss=0.0304]


Dice Score = 0.9750511050224304


100%|██████████| 1563/1563 [11:15<00:00,  2.31it/s, loss=0.0329]


Dice Score = 0.9366233348846436


100%|██████████| 1563/1563 [11:14<00:00,  2.32it/s, loss=0.0338]


Dice Score = 0.9754799604415894


100%|██████████| 1563/1563 [11:13<00:00,  2.32it/s, loss=0.0187]


Dice Score = 0.9774037599563599


100%|██████████| 1563/1563 [11:16<00:00,  2.31it/s, loss=0.0411]


Dice Score = 0.9781410694122314


100%|██████████| 1563/1563 [11:16<00:00,  2.31it/s, loss=0.0163]


Dice Score = 0.9771878719329834


100%|██████████| 1563/1563 [11:17<00:00,  2.31it/s, loss=0.0227]


Dice Score = 0.9700518846511841


  2%|▏         | 28/1563 [00:12<11:26,  2.24it/s, loss=0.0169]


KeyboardInterrupt: ignored