In [None]:
# !pip install pytorch-ignite



In [None]:
# !pip install -q kaggle
# from google.colab import files
# files.upload() # choose kaggle.json
# !mkdir ~/.kaggle
# !cp kaggle.json ~/.kaggle/
# !chmod 600 ~/.kaggle/kaggle.json
# !kaggle competitions download -c ml-intensive-yandex-autumn-2023
# !mkdir lung_dataset
# !unzip ml-intensive-yandex-autumn-2023.zip -d lung_dataset

In [3]:
import os

import torch
import torchvision
from PIL import Image
from ignite.engine import Events
from ignite.engine import create_supervised_trainer
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torchsummary import summary
from torchvision.transforms.v2 import Compose, PILToTensor, ToDtype

In [25]:
# Run configuration: every tunable constant for this notebook lives here.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
SEED = 42
torch.manual_seed(SEED)  # makes weight init and loader shuffling reproducible
NUM_EPOCHS = 0  # 0 => trainer.run performs no training epochs
BATCH_SIZE = 20
BUTCH_SIZE = BATCH_SIZE  # misspelled alias kept because later cells reference this name
IMAGE_DIR = '../data/train_images/'
MASK_DIR = '../data/train_lung_masks/'
TEST_IMAGE_DIR = '../data/test_images/'
TEST_MASK_DIR = '../data/test_lung_masks/'  # predicted test masks are written here
LEARNING_RATE = 1e-4
USE_CHECKPOINT = True  # load saved weights into the model before training/inference
GENERATE_IMAGES = True  # run inference on TEST_IMAGE_DIR
TRAIN_MODEL = not GENERATE_IMAGES  # training and inference-only modes are mutually exclusive

In [26]:
# Image preprocessing: PIL image -> integer tensor -> float32 tensor with values
# rescaled by ToDtype(scale=True) (to [0, 1] for uint8 input).
transform = Compose(
    [PILToTensor(), ToDtype(torch.float32, scale=True)]
)

In [27]:
from datasets.LungDataset import LungDataset

# Training data: images from IMAGE_DIR with masks from MASK_DIR
# (presumably paired by file name inside LungDataset — confirm there).
train_dataset = LungDataset(image_dir=IMAGE_DIR, mask_dir=MASK_DIR, transform=transform)
# Shuffle so each epoch visits the batches in a different order; a fixed order
# lets the optimizer pick up spurious ordering effects during training.
train_loader = DataLoader(train_dataset, batch_size=BUTCH_SIZE, shuffle=True)

In [29]:
from models.UNET import UNET

# U-Net with a single (grayscale) input channel and a single output channel
# (per-pixel mask logit), moved to the configured device.
model = UNET(in_channels=1, out_channels=1, features=[32, 64, 128])
model = model.to(DEVICE)
# Print a layer-by-layer summary using the shape of the first training image.
sample_shape = train_dataset[0][0].shape
summary(model, sample_shape)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 256, 256]             288
       BatchNorm2d-2         [-1, 32, 256, 256]              64
              ReLU-3         [-1, 32, 256, 256]               0
            Conv2d-4         [-1, 32, 256, 256]           9,216
       BatchNorm2d-5         [-1, 32, 256, 256]              64
              ReLU-6         [-1, 32, 256, 256]               0
        DoubleConv-7         [-1, 32, 256, 256]               0
         MaxPool2d-8         [-1, 32, 128, 128]               0
            Conv2d-9         [-1, 64, 128, 128]          18,432
      BatchNorm2d-10         [-1, 64, 128, 128]             128
             ReLU-11         [-1, 64, 128, 128]               0
           Conv2d-12         [-1, 64, 128, 128]          36,864
      BatchNorm2d-13         [-1, 64, 128, 128]             128
             ReLU-14         [-1, 64, 1

In [None]:
# One output channel of raw logits per pixel -> sigmoid + binary cross-entropy,
# fused in BCEWithLogitsLoss for numerical stability.
loss_function = nn.BCEWithLogitsLoss().to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Ignite does not move batches to the model's device by itself: without `device=`
# the inputs stay on the CPU while the model lives on DEVICE, and the first step fails on CUDA.
trainer = create_supervised_trainer(model, optimizer, loss_function, device=DEVICE)

In [30]:
def save_checkpoint(state, filename='checkpoint.pth.tar'):
    """Serialize a checkpoint dict to `filename` with torch.save.

    Args:
        state: dict to persist; in this notebook it holds 'state_dict' (model
            weights) and 'optimizer' (optimizer state).
        filename: destination path.
    """
    torch.save(state, filename)


def load_checkpoint(checkpoint, model, optimizer=None):
    """Restore model (and optionally optimizer) state from a loaded checkpoint dict.

    Args:
        checkpoint: dict obtained from torch.load on a file written by
            save_checkpoint; must contain 'state_dict'.
        model: module whose weights are replaced via load_state_dict.
        optimizer: optional optimizer. When given and the checkpoint has an
            'optimizer' entry, its state (e.g. Adam moment estimates, learning
            rates) is restored too, so resumed training does not start the
            optimizer from scratch.
    """
    model.load_state_dict(checkpoint['state_dict'])
    if optimizer is not None and 'optimizer' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])
    
if USE_CHECKPOINT:
    # map_location=DEVICE remaps saved tensors onto the current device, so a
    # checkpoint written on GPU can still be loaded on a CPU-only machine.
    saved_checkpoint = torch.load('checkpoint-5.pth.tar', map_location=DEVICE)
    load_checkpoint(saved_checkpoint, model)

In [65]:
def log_iter_loss(engine):
    """Print the engine's latest output tagged with its epoch and iteration counters.

    For a trainer built by create_supervised_trainer the output is the batch loss.
    """
    state = engine.state
    epoch, iteration, loss = state.epoch, state.iteration, state.output
    print(f'Epoch[{epoch}] - Iter[{iteration}]: loss = {loss}')

def create_checkpoint():
    """Save the current model and optimizer state to the default checkpoint file.

    Takes no arguments so it can be registered directly as an ignite event handler.
    """
    save_checkpoint({
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    })

def log_accuracy(loader):
    """Print pixel accuracy and mean per-batch Dice score of `model` over `loader`.

    Predictions are sigmoid(model(x)) thresholded at 0.5. The model is put in
    eval mode for the pass and switched back to train mode afterwards, even if
    evaluation raises or is interrupted.

    Args:
        loader: iterable of (image, mask) batches. Assumes masks are float
            tensors in {0, 1} shaped like the model output — TODO confirm
            against LungDataset.
    """
    num_correct = 0
    num_pixels = 0
    dice_score = 0
    num_batches = 0
    model.eval()
    try:
        with torch.no_grad():
            for x, y in loader:
                x = x.to(DEVICE)
                y = y.to(DEVICE)
                preds = (torch.sigmoid(model(x)) > 0.5).float()
                num_correct += (preds == y).sum()
                num_pixels += torch.numel(preds)
                # Per-batch Dice; the epsilon avoids 0/0 when prediction and mask are both empty.
                dice_score += (2 * (preds * y).sum()) / ((preds + y).sum() + 1e-8)
                num_batches += 1
    finally:
        # Back to train mode so subsequent training keeps updating BatchNorm statistics.
        model.train()

    if num_batches == 0 or num_pixels == 0:
        print('Accuracy: n/a (empty loader)')
        return
    print(f'Accuracy: {float(num_correct) / num_pixels * 100:.2f}')
    print(f'Dice score: {float(dice_score) / num_batches:.4f}')

def generate_test_lung_masks():
    """Predict and save a binary lung mask for every file in TEST_IMAGE_DIR.

    Each image is loaded as grayscale, preprocessed with `transform`, and passed
    through `model`. The mask (sigmoid probability > 0.5, the same rule
    `log_accuracy` uses) is saved under the same file name in TEST_MASK_DIR,
    which is created if missing. The model runs in eval mode and is returned to
    train mode afterwards, even if generation raises or is interrupted.
    """
    os.makedirs(TEST_MASK_DIR, exist_ok=True)
    to_pil = torchvision.transforms.ToPILImage()

    model.eval()
    try:
        for image_name in os.listdir(TEST_IMAGE_DIR):
            img_path = os.path.join(TEST_IMAGE_DIR, image_name)
            # Context manager releases the file handle once the pixels are converted.
            with Image.open(img_path) as raw_image:
                image = transform(raw_image.convert('L')).to(device=DEVICE)

            with torch.no_grad():
                # transform yields (C, H, W); the batch dimension goes in front.
                logits = model(image.unsqueeze(0))
                # The model outputs logits (trained with BCEWithLogitsLoss), so threshold
                # the sigmoid probability, not the raw logit.
                predict_image = (torch.sigmoid(logits) > 0.5).float()

            new_mask_path = os.path.join(TEST_MASK_DIR, image_name)
            to_pil(predict_image[0]).save(new_mask_path)
    finally:
        model.train()

In [None]:
# Log the training loss every 1000 iterations.
every_1000_iterations = Events.ITERATION_COMPLETED(every=1000)
trainer.add_event_handler(every_1000_iterations, log_iter_loss)
# At the end of each epoch: report accuracy/Dice on the training loader, then save a checkpoint.
trainer.add_event_handler(Events.EPOCH_COMPLETED, log_accuracy, train_loader)
trainer.add_event_handler(Events.EPOCH_COMPLETED, create_checkpoint)
# When the run completes, write predicted masks for the test images.
trainer.add_event_handler(Events.COMPLETED, generate_test_lung_masks)

<ignite.engine.events.RemovableEventHandle at 0x7d3654da90c0>

In [66]:
# Inference-only mode: write predicted test masks using the model's current weights
# (loaded from the checkpoint when USE_CHECKPOINT is set).
if GENERATE_IMAGES:
    generate_test_lung_masks()

KeyboardInterrupt: 

In [None]:
# Training mode (TRAIN_MODEL is the negation of GENERATE_IMAGES). The trainer's
# COMPLETED handler also writes the test masks once the run ends.
if TRAIN_MODEL:
    trainer.run(data=train_loader, max_epochs=NUM_EPOCHS)

State:
	iteration: 0
	epoch: 0
	epoch_length: 1350
	max_epochs: 0
	output: <class 'NoneType'>
	batch: <class 'NoneType'>
	metrics: <class 'dict'>
	dataloader: <class 'torch.utils.data.dataloader.DataLoader'>
	seed: <class 'NoneType'>
	times: <class 'dict'>

In [None]:
# !zip -r lung_dataset/data/test_lung_masks.zip lung_dataset/data/test_lung_masks
# from google.colab import files
# files.download("lung_dataset/data/test_lung_masks.zip")

[1;30;43mВыходные данные были обрезаны до нескольких последних строк (5000).[0m
  adding: lung_dataset/data/test_lung_masks/img_2257.png (deflated 3%)
  adding: lung_dataset/data/test_lung_masks/img_5721.png (deflated 2%)
  adding: lung_dataset/data/test_lung_masks/img_5161.png (deflated 3%)
  adding: lung_dataset/data/test_lung_masks/img_37.png (deflated 3%)
  adding: lung_dataset/data/test_lung_masks/img_2057.png (deflated 3%)
  adding: lung_dataset/data/test_lung_masks/img_1429.png (deflated 4%)
  adding: lung_dataset/data/test_lung_masks/img_1361.png (deflated 4%)
  adding: lung_dataset/data/test_lung_masks/img_3914.png (deflated 3%)
  adding: lung_dataset/data/test_lung_masks/img_1145.png (deflated 4%)
  adding: lung_dataset/data/test_lung_masks/img_1807.png (deflated 2%)
  adding: lung_dataset/data/test_lung_masks/img_5897.png (deflated 4%)
  adding: lung_dataset/data/test_lung_masks/img_4519.png (deflated 2%)
  adding: lung_dataset/data/test_lung_masks/img_6095.png (deflated 2

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>