In [1]:
# Make Google Drive visible inside the Colab runtime; the dataset archives
# extracted in the next cell are read from it.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [None]:
!unzip ./drive/MyDrive/images.zip -d ./data/
!unzip ./drive/MyDrive/masks.zip -d ./data/
!unzip ./drive/MyDrive/train.csv.zip -d ./data/

In [6]:
!pip install -U segmentation-models-pytorch

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting segmentation-models-pytorch
  Downloading segmentation_models_pytorch-0.3.2-py3-none-any.whl (106 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m106.7/106.7 kB[0m [31m11.4 MB/s[0m eta [36m0:00:00[0m
Collecting pretrainedmodels==0.7.4 (from segmentation-models-pytorch)
  Downloading pretrainedmodels-0.7.4.tar.gz (58 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.8/58.8 kB[0m [31m8.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting efficientnet-pytorch==0.7.1 (from segmentation-models-pytorch)
  Downloading efficientnet_pytorch-0.7.1.tar.gz (21 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting timm==0.6.12 (from segmentation-models-pytorch)
  Downloading timm-0.6.12-py3-none-any.whl (549 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m549.1/549.

In [1]:
import os

from tqdm import tqdm

from typing import Callable, Tuple, Any

import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split

import torch
from torch.optim import Adam
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

from segmentation_models_pytorch import Unet
from segmentation_models_pytorch.losses import JaccardLoss, DiceLoss

import albumentations as A
from albumentations.pytorch import ToTensorV2 
from albumentations import (HorizontalFlip,
                            VerticalFlip,
                            Normalize,
                            Compose)

import matplotlib.pyplot as plt

import cv2

In [2]:
class LungsDataset(Dataset):
    """Image / binary-mask pairs read from disk with OpenCV.

    Every row of ``file_list`` names one image file (column ``"ImageId"``,
    resolved against ``img_dir``) and the matching mask file (column
    ``"MaskId"``, resolved against ``mask_dir``).

    Args:
        img_dir: Directory holding the image files.
        mask_dir: Directory holding the mask files.
        file_list: Table with ``"ImageId"`` and ``"MaskId"`` columns. Rows are
            addressed by position, so the index labels do not matter.
        transform: Callable invoked as ``transform(image=..., mask=...)`` that
            returns a mapping with ``"image"`` and ``"mask"`` entries.
        mask_threshold: Pixels of the mask's first channel that are greater
            than or equal to this value become 1.0; all others become 0.0.
    """

    def __init__(self,
                 img_dir: str,
                 mask_dir: str,
                 file_list: pd.DataFrame,
                 transform: Callable,
                 mask_threshold: int = 240) -> None:
        super().__init__()

        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.file_list = file_list
        self.transform = transform
        self.mask_threshold = mask_threshold

    @staticmethod
    def _read(path: str) -> np.ndarray:
        """Load ``path`` with ``cv2.imread``, failing loudly if it cannot be read."""
        pixels = cv2.imread(path)
        if pixels is None:
            # cv2.imread signals failure by returning None instead of raising;
            # surface it here rather than as an opaque TypeError further down.
            raise FileNotFoundError(f"Could not read image file: {path}")
        return pixels

    def __getitem__(self, index) -> Tuple:
        """Return the transformed ``(image, mask)`` pair at position ``index``.

        Raises:
            FileNotFoundError: If the image or the mask file cannot be read.
        """
        # Positional lookup: works even when file_list keeps the shuffled
        # index labels produced by a train/validation split.
        img_name = self.file_list["ImageId"].iloc[index]
        mask_name = self.file_list["MaskId"].iloc[index]

        img = self._read(os.path.join(self.img_dir, img_name))
        mask = self._read(os.path.join(self.mask_dir, mask_name))

        # OpenCV decodes to BGR channel order, while the ImageNet mean/std used
        # for normalisation in this notebook are listed in RGB order.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # Binarise the first mask channel: >= threshold -> 1.0, otherwise 0.0.
        mask = (mask[:, :, 0] >= self.mask_threshold).astype(np.float32)

        transformed = self.transform(image=img, mask=mask)
        return transformed['image'], transformed['mask']

    def __len__(self) -> int:
        """Number of image/mask pairs, i.e. rows in ``file_list``."""
        return len(self.file_list)

In [3]:
BATCH_SIZE = 8
# Fixes the train/validation split so every run (and every kernel restart)
# evaluates on the same held-out rows.
SEED = 42

# Per-channel normalisation statistics, matching the ImageNet-pretrained
# encoder weights the model is created with.
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

# Normalise, then convert image and mask to torch tensors.
transform = A.Compose([
    Normalize(mean=mean, std=std, p=1),
    ToTensorV2(),
])

img_dir = "./data/images/"
mask_dir = "./data/masks/"
file_list = pd.read_csv("./data/train.csv")

# 80/20 split; the index is reset so each subset is numbered 0..n-1.
train_list, val_list = train_test_split(file_list, test_size=0.2, random_state=SEED)
train_list = train_list.reset_index(drop=True)
val_list = val_list.reset_index(drop=True)

train_dataset = LungsDataset(img_dir, mask_dir, train_list, transform)
val_dataset = LungsDataset(img_dir, mask_dir, val_list, transform)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
# Validation only averages the per-batch loss, so shuffling buys nothing;
# keeping the order fixed makes the evaluation deterministic.
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

In [4]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
device

device(type='cuda')

In [5]:
def train_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Args:
        model: Network to train; switched to training mode.
        loader: Iterable of ``(inputs, targets)`` batches supporting ``len()``.
        criterion: Callable ``criterion(outputs, targets)`` returning the loss.
        optimizer: Optimizer stepped once per batch.
        device: Device both tensors of every batch are moved to.

    Returns:
        Sum of the detached per-batch losses divided by ``len(loader)``.
    """
    model.train()
    running_total = 0
    for inputs, targets in tqdm(loader):
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()

        # Detach so the running sum does not keep the autograd graph alive.
        running_total = running_total + loss.detach()

    return running_total / len(loader)

def val_epoch(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradients; return the mean batch loss.

    Args:
        model: Network to evaluate; switched to evaluation mode.
        loader: Iterable of ``(inputs, targets)`` batches supporting ``len()``.
        criterion: Callable ``criterion(outputs, targets)`` returning the loss.
        device: Device both tensors of every batch are moved to.

    Returns:
        Sum of the detached per-batch losses divided by ``len(loader)``.
    """
    model.eval()
    loss_sum = 0
    # No parameter is updated here, so gradient tracking is disabled throughout.
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            loss_sum = loss_sum + criterion(model(inputs), targets).detach()

    return loss_sum / len(loader)

def train(model, train_loader, val_loader, num_epochs, criterion, optimizer, device):
    """Alternate one training pass and one validation pass for ``num_epochs`` epochs.

    After every epoch the mean training and validation losses are printed,
    labelled with the 1-based epoch number. Nothing is returned.
    """
    for epoch_number in range(1, num_epochs + 1):
        mean_train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
        mean_val_loss = val_epoch(model, val_loader, criterion, device)
        print(f'[Epoch {epoch_number}] train loss: {mean_train_loss:.3f}; val loss: {mean_val_loss:.3f}')

In [6]:
# U-Net with an EfficientNet-B2 encoder initialised from ImageNet weights,
# placed on the selected device.
model = Unet(encoder_name='efficientnet-b2', encoder_weights="imagenet").to(device)

In [7]:
# Two candidate overlap-based losses for the binary mask; the training call
# chooses which one is used as the criterion.
jaccard_loss, dice_loss = JaccardLoss(mode='binary'), DiceLoss(mode='binary')

# Adam with its default hyper-parameters over all model parameters.
adam = Adam(params=model.parameters())

In [8]:
# Train for a fixed number of epochs, optimising the Dice loss with Adam.
num_epochs = 10
train(model, train_loader, val_loader, num_epochs, dice_loss, adam, device)

100%|██████████| 1671/1671 [16:10<00:00,  1.72it/s]


[Epoch 1] train loss: 0.126; val loss: 0.126


100%|██████████| 1671/1671 [16:08<00:00,  1.73it/s]


[Epoch 2] train loss: 0.102; val loss: 0.094


100%|██████████| 1671/1671 [16:07<00:00,  1.73it/s]


[Epoch 3] train loss: 0.092; val loss: 0.090


 47%|████▋     | 788/1671 [07:37<08:32,  1.72it/s]


KeyboardInterrupt: ignored