In [1]:
import os

import cv2
import torch
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from omegaconf import OmegaConf
import albumentations as A

from src.utils import instantiate_from_config, get_obj_from_str

In [2]:
# Load the experiment configuration once and echo it as YAML so the
# notebook output records exactly which settings this run used.
CONFIG_PATH = './configs/base_config.yaml'
config = OmegaConf.load(CONFIG_PATH)
print(OmegaConf.to_yaml(config))

common:
  gpus:
  - 0
  seed: 17
  folds_count: null
  batch_size: 2
  num_workers: 2
  epochs: 256
  exp_name: test_experiment
  wandb: false
model:
  target: segmentation_models_pytorch.Unet
  params:
    encoder_name: resnet34
    classes: 4
criterions:
- target: segmentation_models_pytorch.losses.FocalLoss
  params:
    mode: multiclass
  weight: 0.5
  name: focal
- target: segmentation_models_pytorch.losses.JaccardLoss
  params:
    mode: multiclass
  weight: 0.5
  name: jaccard
optimizers:
- target: torch.optim.Adam
  params:
    lr: 0.002
  scheduler:
    target: pl_bolts.optimizers.lr_scheduler.LinearWarmupCosineAnnealingLR
    params:
      warmup_epochs: 16
      max_epochs: 256
      warmup_start_lr: 0.001
    additional:
      monitor: iou_valid
metrics:
- target: segmentation_models_pytorch.utils.metrics.IoU
  params:
    threshold: 0.5
  name: iou
  use_bg: false
callbacks:
- target: pytorch_lightning.callbacks.LearningRateMonitor
- target: pytorch_lightning.callbacks.Mod

In [3]:
# Seed the run from the config; the call stays the last expression so the
# cell still displays the seed that was applied.
SEED = config['common']['seed']
pl.seed_everything(SEED, workers=True)

Global seed set to 17


17

In [4]:
# TRAIN_IMAGES_FOLDER = '/home/user/datasets/hubmap-organ-segmentation/train_images'
# TRAIN_MASKS_FOLDER = '/home/user/datasets/hubmap-organ-segmentation/train_masks'

In [5]:
class TrainDataset(Dataset):
    """Segmentation dataset pairing each image with the same-named mask file.

    Image file names are discovered via ``get_img_names(images_dir, img_format=...)``;
    the mask for an image is read from ``masks_dir`` under the identical file name.

    NOTE(review): ``get_img_names``, ``preprocess_image``, ``preprocess_mask2onehot``
    and ``preprocess_single_mask`` are not imported in the visible cells -- confirm
    they are defined before this cell is executed on a fresh kernel.

    Args:
        images_dir: Directory containing the input images.
        masks_dir: Directory containing the masks (same file names as the images).
        labels: Label specification forwarded to the mask preprocessing helpers.
        img_w: Target width forwarded to the preprocessing helpers.
        img_h: Target height forwarded to the preprocessing helpers.
        augs: Optional callable invoked as ``augs(image=..., mask=...)`` that returns
            a mapping with ``'image'`` and ``'mask'`` keys; ``None`` disables it.
        img_format: Image file extension forwarded to ``get_img_names``.
    """

    def __init__(self, images_dir, masks_dir, labels, img_w=256, img_h=256, augs=None, img_format='png'):
        self.img_names = get_img_names(images_dir, img_format=img_format)
        self.images_dir = images_dir
        self.masks_dir = masks_dir
        self.labels = labels
        self.img_w = img_w
        self.img_h = img_h
        self.augs = augs

    def __len__(self):
        """Return the number of discovered image files."""
        return len(self.img_names)

    def __getitem__(self, index):
        """Load, augment and preprocess one sample.

        Returns:
            dict with keys ``'image'``, ``'oh_mask'`` and ``'sg_mask'`` holding the
            outputs of the corresponding preprocessing helpers.

        Raises:
            FileNotFoundError: If the image or its mask cannot be read from disk.
        """
        img_name = self.img_names[index]
        img_path = os.path.join(self.images_dir, img_name)
        msk_path = os.path.join(self.masks_dir, img_name)

        # cv2.imread signals failure by returning None instead of raising, so
        # check explicitly; otherwise the None surfaces later as an opaque
        # error inside the augmentations or the preprocessing helpers.
        image = cv2.imread(img_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image file: {img_path}")

        mask = cv2.imread(msk_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask file: {msk_path}")

        # Augment image and mask jointly so they stay spatially aligned.
        if self.augs is not None:
            item = self.augs(image=image, mask=mask)
            image = item['image']
            mask = item['mask']

        image = preprocess_image(image, img_w=self.img_w, img_h=self.img_h)
        # Both mask encodings are derived from the same (possibly augmented) mask.
        oh_mask = preprocess_mask2onehot(mask, self.labels, img_w=self.img_w, img_h=self.img_h)
        sg_mask = preprocess_single_mask(mask, self.labels, img_w=self.img_w, img_h=self.img_h)

        return {
            'image': image,
            'oh_mask': oh_mask,
            'sg_mask': sg_mask,
        }

def get_train_augs():
    """Build the albumentations pipeline applied to training samples.

    The pipeline always takes a random square crop and then applies the
    remaining colour/geometry transforms, each with its own probability.
    """
    crop_size = 512

    # NOTE(review): the HueSaturationValue limits below are fractional (0.2);
    # albumentations documents these limits in integer units with defaults
    # around 20-30, so the effect may be much weaker than intended -- confirm.
    transforms = [
        A.RandomCrop(crop_size, crop_size, p=1),
        A.ToGray(p=0.15),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.15),
        A.Rotate(limit=180, border_mode=cv2.BORDER_WRAP, p=0.5),
        A.HueSaturationValue(
            hue_shift_limit=0.2,
            sat_shift_limit=0.2,
            val_shift_limit=0.2,
            p=0.5,
        ),
        A.RandomBrightnessContrast(
            brightness_limit=(-0.1, 0.1),
            contrast_limit=(-0.1, 0.1),
            p=0.5,
        ),
    ]
    return A.Compose(transforms, p=1.0)


def get_valid_augs():
    """Return the validation augmentations: none, so samples pass through unaugmented."""
    valid_augs = None
    return valid_augs

In [6]:
class LightningModel(pl.LightningModule):
    """Config-driven LightningModule.

    The network, loss terms, metrics, optimizers/schedulers and callbacks are
    all instantiated from ``config``:

    * ``config['model']``      -- passed to ``instantiate_from_config``.
    * ``config['criterions']`` -- list of entries with ``name`` and ``weight``;
      the total loss is the weighted sum of all entries.
    * ``config['metrics']``    -- list of entries with ``name`` and ``use_bg``.
    * ``config['optimizers']`` -- list of entries with ``target``, optional
      ``params`` and an optional ``scheduler`` sub-entry.
    * ``config['callbacks']``  -- optional list of callback entries.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.model = instantiate_from_config(config['model'])

        # Loss callables and their scalar weights, both keyed by the config 'name'.
        self.criterions = {x['name']: instantiate_from_config(x) for x in config['criterions']}
        self.crit_weights = {x['name']: x['weight'] for x in config['criterions']}

        # Metric callables and, per metric, whether channel 0 is included.
        self.metrics = {x['name']: instantiate_from_config(x) for x in config['metrics']}
        self.use_bg = {x['name']: x['use_bg'] for x in config['metrics']}

    def forward(self, x):
        """Run the wrapped model on ``x``."""
        return self.model(x)

    def _common_step(self, batch, batch_idx, stage):
        """Shared train/valid/test step: compute and log losses and metrics.

        Args:
            batch: Mapping with ``'image'``, ``'sg_mask'`` and ``'oh_mask'`` entries.
            batch_idx: Index of the batch (unused).
            stage: Suffix used in the logged names (``'train'``, ``'valid'``, ``'test'``).

        Returns:
            dict with the weighted total loss under ``'loss'``.
        """
        gt_img = batch['image']
        sg_mask = batch['sg_mask'].long()
        oh_mask = batch['oh_mask']
        pr_msk = self.model(gt_img)

        # Weighted sum of all configured loss terms; each term is logged on its own.
        loss = 0
        for c_name, criterion in self.criterions.items():
            c_loss = criterion(pr_msk, sg_mask) * self.crit_weights[c_name]
            self.log(f"{c_name}_loss_{stage}", c_loss, on_epoch=True, prog_bar=True)
            # Out-of-place accumulation: never mutates a tensor that was already logged.
            loss = loss + c_loss
        self.log(f"total_loss_{stage}", loss, on_epoch=True, prog_bar=True)

        # NOTE(review): pr_msk is handed to the metrics exactly as the model
        # returns it. If the model emits raw logits and the metric applies a
        # threshold without an activation, the threshold acts on logits -- confirm.
        for m_name, metric in self.metrics.items():
            # When use_bg is false, channel 0 is dropped from both tensors
            # (presumably the background class -- verify against the label order).
            first_channel = 0 if self.use_bg[m_name] else 1
            metric_value = metric(
                pr_msk[:, first_channel:, :, :],
                oh_mask[:, first_channel:, :, :],
            )
            self.log(f"{m_name}_{stage}", metric_value, on_epoch=True, prog_bar=True)

        return {
            'loss': loss,
        }

    def training_step(self, batch, batch_idx):
        """Training step; logs with the ``train`` suffix."""
        return self._common_step(batch, batch_idx, 'train')

    def validation_step(self, batch, batch_idx):
        """Validation step; logs with the ``valid`` suffix."""
        return self._common_step(batch, batch_idx, 'valid')

    def test_step(self, batch, batch_idx):
        """Test step; logs with the ``test`` suffix."""
        return self._common_step(batch, batch_idx, 'test')

    def configure_optimizers(self):
        """Build optimizers and (optional) LR schedulers from ``config['optimizers']``.

        Each entry must provide ``target``; ``params`` and ``scheduler`` are
        optional. A scheduler entry provides ``target`` and may provide ``params``
        and ``additional`` (extra keys merged into the scheduler dict, e.g.
        ``monitor``).

        Returns:
            Tuple ``(optimizers, schedulers)`` of two lists; ``schedulers`` holds
            one dict per optimizer that declared a scheduler.
        """
        optimizers = []
        schedulers = []

        for item in self.config['optimizers']:
            optimizer_cls = get_obj_from_str(item['target'])
            # `or {}` also covers an explicit null `params:` in the YAML.
            optimizer = optimizer_cls(self.parameters(), **(item.get('params') or {}))
            optimizers.append(optimizer)

            # A scheduler is optional: entries without one just contribute the optimizer.
            scheduler_cfg = item.get('scheduler')
            if scheduler_cfg is None:
                continue

            scheduler_cls = get_obj_from_str(scheduler_cfg['target'])
            scheduler = scheduler_cls(
                optimizer=optimizer,
                **(scheduler_cfg.get('params') or {}))
            schedulers.append({
                'scheduler': scheduler,
                **(scheduler_cfg.get('additional') or {}),
            })

        return optimizers, schedulers

    def configure_callbacks(self):
        """Instantiate the callbacks listed under ``config['callbacks']`` (if any)."""
        callback_cfgs = self.config.get('callbacks') or []
        return [instantiate_from_config(item) for item in callback_cfgs]

In [7]:
class DataModule(pl.LightningDataModule):
    """LightningDataModule that builds the train/valid datasets from ``config``.

    Datasets come from ``config['datasets']['train']`` and
    ``config['datasets']['valid']``; loader options are read from
    ``config['common']`` with per-split defaults.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Augmentations are attached after instantiation, overriding whatever
        # the dataset was constructed with.
        self.train = instantiate_from_config(config['datasets']['train'])
        self.train.augs = get_train_augs()

        self.valid = instantiate_from_config(config['datasets']['valid'])
        self.valid.augs = get_valid_augs()

    def _make_loader(self, dataset, drop_last_default, shuffle_default):
        """Create a DataLoader for ``dataset`` using the shared ``common`` options.

        ``drop_last_default`` and ``shuffle_default`` are used only when the
        corresponding key is absent from ``config['common']``.
        """
        common_cfg = self.config['common']
        loader_kwargs = {
            'batch_size': common_cfg.get('batch_size', 1),
            'num_workers': common_cfg.get('num_workers', 1),
            'drop_last': common_cfg.get('drop_last', drop_last_default),
            'pin_memory': common_cfg.get('pin_memory', True),
            'shuffle': common_cfg.get('shuffle', shuffle_default),
        }
        return DataLoader(dataset, **loader_kwargs)

    def train_dataloader(self):
        """Training loader: shuffles and drops the last partial batch by default."""
        return self._make_loader(self.train, drop_last_default=True, shuffle_default=True)

    def val_dataloader(self):
        """Validation loader: keeps order and every sample by default."""
        return self._make_loader(self.valid, drop_last_default=False, shuffle_default=False)

In [8]:
# Build the model and data pipeline from the shared config, then launch training.
common_cfg = config['common']

model = LightningModel(config)
datamodule = DataModule(config)

trainer = pl.Trainer(
    max_epochs=common_cfg['epochs'],
    gpus=common_cfg['gpus'],
)
trainer.fit(model, datamodule)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
The following callbacks returned in `LightningModule.configure_callbacks` will override existing callbacks passed to Trainer: ModelCheckpoint
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type | Params
-------------------------------
0 | model | Unet | 24.4 M
-------------------------------
24.4 M    Trainable params
0         Non-trainable params
24.4 M    Total params
97.747    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
