## 初期設定

### ライブラリのインストール

In [1]:
!pip install segmentation_models_pytorch
!pip install timm

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting segmentation_models_pytorch
  Downloading segmentation_models_pytorch-0.3.0-py3-none-any.whl (97 kB)
[K     |████████████████████████████████| 97 kB 3.4 MB/s 
Collecting timm==0.4.12
  Downloading timm-0.4.12-py3-none-any.whl (376 kB)
[K     |████████████████████████████████| 376 kB 20.7 MB/s 
[?25hCollecting pretrainedmodels==0.7.4
  Downloading pretrainedmodels-0.7.4.tar.gz (58 kB)
[K     |████████████████████████████████| 58 kB 7.2 MB/s 
Collecting efficientnet-pytorch==0.7.1
  Downloading efficientnet_pytorch-0.7.1.tar.gz (21 kB)
Collecting munch
  Downloading munch-2.5.0-py2.py3-none-any.whl (10 kB)
Building wheels for collected packages: efficientnet-pytorch, pretrainedmodels
  Building wheel for efficientnet-pytorch (setup.py) ... [?25l[?25hdone
  Created wheel for efficientnet-pytorch: filename=efficientnet_pytorch-0.7.1-py3-none-any.whl size=16446 sha256=220fa087

### ライブラリのインポート

In [2]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import os
import random
from tqdm.notebook import tqdm

import pandas as pd
import numpy as np
import cv2

import segmentation_models_pytorch as smp
from albumentations import *

import torch
from torch.utils.data import Dataset, DataLoader

### Google Drive のマウント

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


### データセットの準備

In [4]:
!mkdir /root/.kaggle
!cp /content/drive/MyDrive/kaggle/kaggle.json /root/.kaggle/

In [5]:
!kaggle datasets download -d ogog0128/hubmap-2022-lung-expand -p /content/
!unzip /content/hubmap-2022-lung-expand.zip -d /content
!rm /content/hubmap-2022-lung-expand.zip 

[1;30;43mストリーミング出力は最後の 5000 行に切り捨てられました。[0m
  inflating: /content/hubmap-2022-256x256/masks/31958_0005.png  
  inflating: /content/hubmap-2022-256x256/masks/31958_0007.png  
  inflating: /content/hubmap-2022-256x256/masks/32009_0000.png  
  inflating: /content/hubmap-2022-256x256/masks/32009_0001.png  
  inflating: /content/hubmap-2022-256x256/masks/32009_0002.png  
  inflating: /content/hubmap-2022-256x256/masks/32009_0003.png  
  inflating: /content/hubmap-2022-256x256/masks/32009_0004.png  
  inflating: /content/hubmap-2022-256x256/masks/32009_0005.png  
  inflating: /content/hubmap-2022-256x256/masks/32009_0006.png  
  inflating: /content/hubmap-2022-256x256/masks/32009_0007.png  
  inflating: /content/hubmap-2022-256x256/masks/32009_0008.png  
  inflating: /content/hubmap-2022-256x256/masks/32126_0000.png  
  inflating: /content/hubmap-2022-256x256/masks/32126_0001.png  
  inflating: /content/hubmap-2022-256x256/masks/32126_0002.png  
  inflating: /content/hubmap-2022-256x256/ma

In [6]:
!kaggle datasets download -d thedevastator/hubmap-2022-256x256 -p /content/
!unzip /content/hubmap-2022-256x256.zip -d /content/hubmap-2022-256x256_org
!rm /content/hubmap-2022-256x256.zip 

[1;30;43mストリーミング出力は最後の 5000 行に切り捨てられました。[0m
  inflating: /content/hubmap-2022-256x256_org/masks/164_0005.png  
  inflating: /content/hubmap-2022-256x256_org/masks/164_0006.png  
  inflating: /content/hubmap-2022-256x256_org/masks/164_0007.png  
  inflating: /content/hubmap-2022-256x256_org/masks/164_0008.png  
  inflating: /content/hubmap-2022-256x256_org/masks/16564_0001.png  
  inflating: /content/hubmap-2022-256x256_org/masks/16564_0002.png  
  inflating: /content/hubmap-2022-256x256_org/masks/16564_0003.png  
  inflating: /content/hubmap-2022-256x256_org/masks/16564_0004.png  
  inflating: /content/hubmap-2022-256x256_org/masks/16564_0005.png  
  inflating: /content/hubmap-2022-256x256_org/masks/16564_0007.png  
  inflating: /content/hubmap-2022-256x256_org/masks/16564_0008.png  
  inflating: /content/hubmap-2022-256x256_org/masks/16609_0000.png  
  inflating: /content/hubmap-2022-256x256_org/masks/16609_0001.png  
  inflating: /content/hubmap-2022-256x256_org/masks/16609_0002.pn

In [7]:
!kaggle competitions download -c hubmap-organ-segmentation -p /content/
!unzip /content/hubmap-organ-segmentation.zip -d /content/hubmap-organ-segmentation/
!rm /content/hubmap-organ-segmentation.zip

Downloading hubmap-organ-segmentation.zip to /content
100% 5.76G/5.78G [00:33<00:00, 268MB/s]
100% 5.78G/5.78G [00:33<00:00, 186MB/s]
Archive:  /content/hubmap-organ-segmentation.zip
  inflating: /content/hubmap-organ-segmentation/sample_submission.csv  
  inflating: /content/hubmap-organ-segmentation/test.csv  
  inflating: /content/hubmap-organ-segmentation/test_images/10078.tiff  
  inflating: /content/hubmap-organ-segmentation/train.csv  
  inflating: /content/hubmap-organ-segmentation/train_annotations/10044.json  
  inflating: /content/hubmap-organ-segmentation/train_annotations/10274.json  
  inflating: /content/hubmap-organ-segmentation/train_annotations/10392.json  
  inflating: /content/hubmap-organ-segmentation/train_annotations/10488.json  
  inflating: /content/hubmap-organ-segmentation/train_annotations/10610.json  
  inflating: /content/hubmap-organ-segmentation/train_annotations/10611.json  
  inflating: /content/hubmap-organ-segmentation/train_annotations/10651.json  


In [8]:
!kaggle datasets download -d ogog0128/hubmap-my-misc -p /content/
!unzip /content/hubmap-my-misc.zip -d /content/hubmap-my-misc/
!rm /content/hubmap-my-misc.zip

Downloading hubmap-my-misc.zip to /content
 76% 18.0M/23.7M [00:00<00:00, 74.1MB/s]
100% 23.7M/23.7M [00:00<00:00, 94.1MB/s]
Archive:  /content/hubmap-my-misc.zip
  inflating: /content/hubmap-my-misc/fold_0_train.csv  
  inflating: /content/hubmap-my-misc/fold_0_valid.csv  
  inflating: /content/hubmap-my-misc/fold_1_train.csv  
  inflating: /content/hubmap-my-misc/fold_1_valid.csv  
  inflating: /content/hubmap-my-misc/fold_2_train.csv  
  inflating: /content/hubmap-my-misc/fold_2_valid.csv  
  inflating: /content/hubmap-my-misc/fold_3_train.csv  
  inflating: /content/hubmap-my-misc/fold_3_valid.csv  


In [9]:
!kaggle datasets download -d ogog0128/hubmap-mydata
!unzip /content/hubmap-mydata.zip -d /content/hubmap-mydata/
!rm /content/hubmap-mydata.zip

Downloading hubmap-mydata.zip to /content
 88% 33.0M/37.7M [00:00<00:00, 115MB/s] 
100% 37.7M/37.7M [00:00<00:00, 115MB/s]
Archive:  /content/hubmap-mydata.zip
  inflating: /content/hubmap-mydata/hubmap-2022-for-Train/masks/10044.png  
  inflating: /content/hubmap-mydata/hubmap-2022-for-Train/masks/10274.png  
  inflating: /content/hubmap-mydata/hubmap-2022-for-Train/masks/10392.png  
  inflating: /content/hubmap-mydata/hubmap-2022-for-Train/masks/10488.png  
  inflating: /content/hubmap-mydata/hubmap-2022-for-Train/masks/10610.png  
  inflating: /content/hubmap-mydata/hubmap-2022-for-Train/masks/10611.png  
  inflating: /content/hubmap-mydata/hubmap-2022-for-Train/masks/10651.png  
  inflating: /content/hubmap-mydata/hubmap-2022-for-Train/masks/10666.png  
  inflating: /content/hubmap-mydata/hubmap-2022-for-Train/masks/10703.png  
  inflating: /content/hubmap-mydata/hubmap-2022-for-Train/masks/10892.png  
  inflating: /content/hubmap-mydata/hubmap-2022-for-Train/masks/10912.png  
  in

## 学習環境設定

### Configuration

In [405]:
from torch.optim import optimizer
class config:
    competition = 'Hubmap-2022'
    train_name = 'lung-DataExpand'
    train_type = 'Patch' # 'Patch' or 'NoPatch'
    nfold = 4
    fold = 3
    SEED = 2020
    BASE = '/content/hubmap-organ-segmentation/'
    FOLD_BASE = '/content/hubmap-my-misc/'
    TRAIN = {'Patch':'/content/hubmap-2022-256x256/train/',
             'NoPatch':'/content/hubmap-mydata/hubmap-2022-for-Train/train'}
    MASKS = {'Patch':'/content/hubmap-2022-256x256/masks/',
             'NoPatch':'/content/hubmap-mydata/hubmap-2022-for-Train/masks'}
    VALID_TRAIN = '/content/hubmap-2022-256x256_org/train/'
    VALID_MASK = '/content/hubmap-2022-256x256_org/masks/'
    LABELS = '/content/hubmap-organ-segmentation/train.csv'
    SAVE_BASE = '/content/drive/MyDrive/Colab Notebooks/save_models'
    NUM_WORKERS = 2
    organ_type = {'lung':1, 'kidney':2, 'largeintestine':3, 'prostate':4, 'spleen':5}

    ARCH = 'unetplusplus'
    BACKBONE ='efficientnet-b7'
    WEIGHTS = 'imagenet'

    MAX_SAVEMODEL = 8
    model_soups = True

    pre_epoch = 64
    pre_batchsize = 64
    pre_init_lr = 1e-1
    pre_sch_step = 128
    pre_sch_gamma = 0.1
    pre_BN_fix = False

    post_epoch = 256
    post_batchsize = 8
    post_init_lr = 1e-2
    post_sch_step = 96
    post_sch_gamma = 0.1
    post_BN_fix = True
    
    p = 1.0
    train_transform = Compose([
        HorizontalFlip(p=0.5),
        VerticalFlip(),
        RandomRotate90(p=1),
        # Morphology
        ShiftScaleRotate(shift_limit=0, scale_limit=(-0.2, 0.2), rotate_limit=(-30, 30),
                         interpolation=1, border_mode=0, value=(0, 0, 0), p=0.5),
        GaussNoise(var_limit=(0, 50.0), mean=0, p=0.5),
        GaussianBlur(blur_limit=(3, 7), p=0.5),
        # Color
        RandomBrightnessContrast(brightness_limit=0.35, contrast_limit=0.5,
                                 brightness_by_max=True, p=0.5),
        HueSaturationValue(hue_shift_limit=30, sat_shift_limit=30,
                           val_shift_limit=0, p=0.5),
        OneOf([
            OpticalDistortion(p=0.3),
            GridDistortion(p=.1),
            PiecewiseAffine(p=0.3),
        ], p=0.3),
    ], p=p)
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [406]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True
    
seed_everything(config.SEED)

In [407]:
TRAIN_PHASE = {'Not Start':0, 'PreTrain':1, 'PostTrain':2, 'Finish':3}

Phase = 0
Resume = False

### 学習保存場所

In [408]:
SAVE_PATH = os.path.join(config.SAVE_BASE, config.competition, config.train_type, config.train_name, f'fold_{config.fold}')
MODEL_NAME =  f'model_fold{config.fold}_{config.BACKBONE}'
os.makedirs(os.path.join(SAVE_PATH, '01_PreTrain'), exist_ok=True)
os.makedirs(os.path.join(SAVE_PATH, '02_PostTrain', 'Models'), exist_ok=True)

### 学習フェーズの確認

In [409]:
if os.path.isfile(os.path.join(SAVE_PATH, '01_PreTrain', f'{MODEL_NAME}.pth')):
    if os.path.isfile(os.path.join(SAVE_PATH, '02_PostTrain', f'{MODEL_NAME}.pth')):
        Phase = TRAIN_PHASE['Finish']
        raise Exception()
    else:
        Phase = TRAIN_PHASE['PostTrain']
        config.WEIGHTS = None
        MODEL_PATH = os.path.join(SAVE_PATH, '02_PostTrain', 'checkpoint.pth')
        if os.path.isfile(MODEL_PATH):
            Resume = True
        else:
            Resume = False
            MODEL_PATH = os.path.join(SAVE_PATH, '01_PreTrain', f'{MODEL_NAME}.pth')        
        SAVE_PATH = os.path.join(SAVE_PATH, '02_PostTrain')
else:
    SAVE_PATH = os.path.join(SAVE_PATH, '01_PreTrain')
    MODEL_PATH = os.path.join(SAVE_PATH, 'checkpoint.pth')
    if os.path.isfile(MODEL_PATH):
        Resume = True
        config.WEIGHTS = None
        Phase = TRAIN_PHASE['PreTrain']
    else:
        Resume = False
        Phase = TRAIN_PHASE['Not Start']


CHECKPOINT_PATH = os.path.join(SAVE_PATH, 'checkpoint.pth')

### モデルデータ・再開パラメータ読み込み

In [410]:
model = smp.create_model(arch=config.ARCH, encoder_weights=config.WEIGHTS, encoder_name=config.BACKBONE, classes=1, activation=None)

In [411]:
if Phase < TRAIN_PHASE['PostTrain']:
    INIT_LR = config.pre_init_lr
    END_EPOCH = config.pre_epoch
    BATCH_SIZE = config.pre_batchsize
else:
    INIT_LR = config.post_init_lr
    END_EPOCH = config.post_epoch
    BATCH_SIZE = config.post_batchsize

optimizer = torch.optim.SGD([dict(params=model.parameters(), lr=INIT_LR, momentum=0.9)])

In [412]:
if Phase < TRAIN_PHASE['PostTrain']:
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=config.pre_sch_step, gamma=config.pre_sch_gamma)
else:
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=config.post_sch_step, gamma=config.post_sch_gamma)

In [413]:
loss_score = {'train_loss':[], 'train_score':[], 'valid_loss':[], 'valid_score':[], 'dice_score':[], 'lung':[], 'kidney':[], 'largeintestine':[], 'prostate':[], 'spleen':[]}
save_epochs = []
_epoch = -1
if Phase > TRAIN_PHASE['Not Start']:
    if Resume:
        checkpoint = torch.load(MODEL_PATH)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(config.device)

        _epoch = checkpoint['epoch']
        loss_score = checkpoint['loss_score']
        save_epochs = checkpoint['save_epoch']
    else:
        state_dict = torch.load(MODEL_PATH, map_location=torch.device('cpu'))
        model.load_state_dict(state_dict)

model = model.float()
model = model.to(config.device)

for param in model.parameters():
    param.requires_grad = True

### Dataset定義

In [414]:
def img2tensor(img,dtype:np.dtype=np.float32):
    if img.ndim==2 : img = np.expand_dims(img,2)
    img = np.transpose(img,(2,0,1))
    return torch.from_numpy(img.astype(dtype, copy=False))

class HuBMAPDataset_Patch(Dataset):
    def __init__(self, fold=0, train=True, tfms=None):
        if train:
            filename = config.FOLD_BASE + f'fold_{fold}_train.csv'
        else:
            filename = config.FOLD_BASE + f'fold_{fold}_valid.csv'

        self.df = pd.read_csv(filename)
        #ids = set(df[df['organ'] == 'lung'].id.astype(str).values)

        ids = set(pd.read_csv(filename).id.astype(str).values)
        if train:
            self.fnames = [fname for fname in os.listdir(config.TRAIN['Patch']) if fname.split('_')[0] in ids]
        else:
            self.fnames = [fname for fname in os.listdir(config.VALID_TRAIN) if fname.split('_')[0] in ids]
        self.train = train
        self.tfms = tfms


        # hubmap-2022-256x256
        self.mean = np.array([0.7720342, 0.74582646, 0.76392896])
        self.std = np.array([0.24745085, 0.26182273, 0.25782376])
        
    def __len__(self):
        return len(self.fnames)
    
    def __getitem__(self, idx):
        fname = self.fnames[idx]
        id = int(fname.split('_')[0])
        organ= self.df[self.df['id'] == id]['organ'].values[0]
        
        if self.train:
            img = cv2.cvtColor(cv2.imread(os.path.join(config.TRAIN['Patch'],fname)), cv2.COLOR_BGR2RGB)
            mask = cv2.imread(os.path.join(config.MASKS['Patch'],fname),cv2.IMREAD_GRAYSCALE)
        else:
            img = cv2.cvtColor(cv2.imread(os.path.join(config.VALID_TRAIN,fname)), cv2.COLOR_BGR2RGB)
            mask = cv2.imread(os.path.join(config.VALID_MASK,fname),cv2.IMREAD_GRAYSCALE)
            
        if self.tfms is not None:
            augmented = self.tfms(image=img,mask=mask)
            img,mask = augmented['image'],augmented['mask']
        return img2tensor((img/255.0 - self.mean)/self.std),img2tensor(mask), organ

class HuBMAPDataset_NoPatch(Dataset):
    def __init__(self, fold=0, train=True, tfms=None):
        if train:
            filename = config.FOLD_BASE + f'fold_{fold}_train.csv'
        else:
            filename = config.FOLD_BASE + f'fold_{fold}_valid.csv'

        #df = pd.read_csv(filename)
        #ids = set(df[df['organ'] == 'lung'].id.astype(str).values)

        ids = set(pd.read_csv(filename).id.astype(str).values)
        #self.fnames = [fname for fname in os.listdir(config.TRAIN) if fname.split('_')[0] in ids]
        self.fnames = [fname for fname in os.listdir(config.TRAIN['NoPatch']) if fname.split('.')[0] in ids]
        self.train = train
        self.tfms = tfms


        # input/hubmap-2022-for-Train
        self.mean = np.array([0.82829359, 0.80269771, 0.82058153])
        self.std = np.array([0.14989631, 0.17862655, 0.16854124])
        
    def __len__(self):
        return len(self.fnames)
    
    def __getitem__(self, idx):
        fname = self.fnames[idx]
        img = cv2.cvtColor(cv2.imread(os.path.join(config.TRAIN['NoPatch'],fname)), cv2.COLOR_BGR2RGB)
        mask = cv2.imread(os.path.join(config.MASKS['NoPatch'],fname),cv2.IMREAD_GRAYSCALE)
        if self.tfms is not None:
            augmented = self.tfms(image=img,mask=mask)
            img,mask = augmented['image'],augmented['mask']
        return img2tensor((img/255.0 - self.mean)/self.std),img2tensor(mask)

### Loss Function定義

In [415]:
"""
Lovasz-Softmax and Jaccard hinge loss in PyTorch
Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License)
"""

import torch
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
try:
    from itertools import  ifilterfalse
except ImportError: # py3k
    from itertools import  filterfalse


def lovasz_grad(gt_sorted):
    """
    Computes gradient of the Lovasz extension w.r.t sorted errors
    See Alg. 1 in paper
    """
    p = len(gt_sorted)
    gts = gt_sorted.sum()
    intersection = gts - gt_sorted.float().cumsum(0)
    union = gts + (1 - gt_sorted).float().cumsum(0)
    jaccard = 1. - intersection / union
    if p > 1: # cover 1-pixel case
        jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
    return jaccard


def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True):
    """
    IoU for foreground class
    binary: 1 foreground, 0 background
    """
    if not per_image:
        preds, labels = (preds,), (labels,)
    ious = []
    for pred, label in zip(preds, labels):
        intersection = ((label == 1) & (pred == 1)).sum()
        union = ((label == 1) | ((pred == 1) & (label != ignore))).sum()
        if not union:
            iou = EMPTY
        else:
            iou = float(intersection) / union
        ious.append(iou)
    iou = f_mean(ious)    # mean accross images if per_image
    return 100 * iou


def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False):
    """
    Array of IoU for each (non ignored) class
    """
    if not per_image:
        preds, labels = (preds,), (labels,)
    ious = []
    for pred, label in zip(preds, labels):
        iou = []    
        for i in range(C):
            if i != ignore: # The ignored label is sometimes among predicted classes (ENet - CityScapes)
                intersection = ((label == i) & (pred == i)).sum()
                union = ((label == i) | ((pred == i) & (label != ignore))).sum()
                if not union:
                    iou.append(EMPTY)
                else:
                    iou.append(float(intersection) / union)
        ious.append(iou)
    ious = map(f_mean, zip(*ious)) # mean accross images if per_image
    return 100 * np.array(ious)


# --------------------------- BINARY LOSSES ---------------------------


def lovasz_hinge(logits, labels, per_image=True, ignore=None):
    """
    Binary Lovasz hinge loss
      logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
      labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
      per_image: compute the loss per image instead of per batch
      ignore: void class id
    """
    if per_image:
        loss = f_mean(lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore))
                          for log, lab in zip(logits, labels))
    else:
        loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore))
    return loss


def lovasz_hinge_flat(logits, labels):
    """
    Binary Lovasz hinge loss
      logits: [P] Variable, logits at each prediction (between -\infty and +\infty)
      labels: [P] Tensor, binary ground truth labels (0 or 1)
      ignore: label to ignore
    """
    if len(labels) == 0:
        # only void pixels, the gradients should be 0
        return logits.sum() * 0.
    signs = 2. * labels.float() - 1.
    errors = (1. - logits * Variable(signs))
    errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
    perm = perm.data
    gt_sorted = labels[perm]
    grad = lovasz_grad(gt_sorted)
    #loss = torch.dot(F.relu(errors_sorted), Variable(grad))
    loss = torch.dot(F.elu(errors_sorted)+1, Variable(grad))
    return loss


def flatten_binary_scores(scores, labels, ignore=None):
    """
    Flattens predictions in the batch (binary case)
    Remove labels equal to 'ignore'
    """
    scores = scores.view(-1)
    labels = labels.view(-1)
    if ignore is None:
        return scores, labels
    valid = (labels != ignore)
    vscores = scores[valid]
    vlabels = labels[valid]
    return vscores, vlabels


class StableBCELoss(torch.nn.modules.Module):
    def __init__(self):
         super(StableBCELoss, self).__init__()
    def forward(self, input, target):
         neg_abs = - input.abs()
         loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log()
         return loss.mean()


def binary_xloss(logits, labels, ignore=None):
    """
    Binary Cross entropy loss
      logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
      labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
      ignore: void class id
    """
    logits, labels = flatten_binary_scores(logits, labels, ignore)
    loss = StableBCELoss()(logits, Variable(labels.float()))
    return loss


# --------------------------- MULTICLASS LOSSES ---------------------------


def lovasz_softmax(probas, labels, only_present=False, per_image=False, ignore=None):
    """
    Multi-class Lovasz-Softmax loss
      probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1)
      labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
      only_present: average only on classes present in ground truth
      per_image: compute the loss per image instead of per batch
      ignore: void class labels
    """
    if per_image:
        loss = f_mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), only_present=only_present)
                          for prob, lab in zip(probas, labels))
    else:
        loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), only_present=only_present)
    return loss


def lovasz_softmax_flat(probas, labels, only_present=False):
    """
    Multi-class Lovasz-Softmax loss
      probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1)
      labels: [P] Tensor, ground truth labels (between 0 and C - 1)
      only_present: average only on classes present in ground truth
    """
    C = probas.size(1)
    losses = []
    for c in range(C):
        fg = (labels == c).float() # foreground for class c
        if only_present and fg.sum() == 0:
            continue
        errors = (Variable(fg) - probas[:, c]).abs()
        errors_sorted, perm = torch.sort(errors, 0, descending=True)
        perm = perm.data
        fg_sorted = fg[perm]
        losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted))))
    return f_mean(losses)


def flatten_probas(probas, labels, ignore=None):
    """
    Flattens predictions in the batch
    """
    B, C, H, W = probas.size()
    probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C)  # B * H * W, C = P, C
    labels = labels.view(-1)
    if ignore is None:
        return probas, labels
    valid = (labels != ignore)
    vprobas = probas[valid.nonzero().squeeze()]
    vlabels = labels[valid]
    return vprobas, vlabels

def xloss(logits, labels, ignore=None):
    """
    Cross entropy loss
    """
    return F.cross_entropy(logits, Variable(labels), ignore_index=255)


# --------------------------- HELPER FUNCTIONS ---------------------------

def f_mean(l, ignore_nan=False, empty=0):
    """
    nanmean compatible with generators.
    """
    l = iter(l)
    if ignore_nan:
        l = ifilterfalse(np.isnan, l)
    try:
        n = 1
        acc = next(l)
    except StopIteration:
        if empty == 'raise':
            raise ValueError('Empty mean')
        return empty
    for n, v in enumerate(l, 2):
        acc += v
    if n == 1:
        return acc
    return acc / n

In [416]:
def symmetric_lovasz(outputs, targets):
    return 0.5*(lovasz_hinge(outputs, targets) + lovasz_hinge(-outputs, 1.0 - targets))

In [417]:
def dice_loss(pred, target, smooth = 1.):
    pred = pred.contiguous()
    target = target.contiguous()
    intersection = (pred * target).sum(dim=2).sum(dim=2)
    loss = (1 - ((2. * intersection + smooth) / (pred.sum(dim=2).sum(dim=2) + target.sum(dim=2).sum(dim=2) + smooth)))
    return loss.mean()


def calc_loss(pred, target, metrics=None, bce_weight=0.5):
    # Dice LossとCategorical Cross Entropyを混ぜていい感じにしている
    bce = F.binary_cross_entropy_with_logits(pred, target)
    pred = torch.sigmoid(pred)
    dice = dice_loss(pred, target)
    loss = bce * bce_weight + dice * (1 - bce_weight)
    return loss

### 評価関数の定義

In [418]:
def compute_dice_score(probability, mask):
    N = len(probability)
    p = probability.reshape(N,-1)
    t = mask.reshape(N,-1)

    p = p>0.5
    t = t>0.5
    uion = p.sum(-1) + t.sum(-1)
    overlap = (p*t).sum(-1)
    dice = 2*overlap/(uion+0.0001)
    return dice

In [419]:
# dice with automatic threshold selection
class Dice_th():
    def __init__(self, ths=np.arange(0.0,1.0,0.01), axis=1): 
        self.axis = axis
        self.ths = ths
        
    def reset(self): 
        self.inter = torch.zeros(len(self.ths))
        self.union = torch.zeros(len(self.ths))
        
    def accumulate(self, prob, mask):
        N = len(prob)
        pred = prob.reshape(N,-1)
        targ = mask.reshape(N,-1)
        
        for i,th in enumerate(self.ths):
            p = (pred > th).float()
            self.inter[i] += (p*targ).float().sum().item()
            self.union[i] += (p+targ).float().sum().item()

    @property
    def value(self):
        dices = torch.where(self.union > 0.0, 
                2.0*self.inter/self.union, torch.zeros_like(self.union))
        return dices.max()

## 学習


### データーローダー

In [420]:
if config.train_type == 'Patch':
    ds_t = HuBMAPDataset_Patch(fold=config.fold, train=True, tfms=config.train_transform)
    ds_v = HuBMAPDataset_Patch(fold=config.fold, train=False)
else:
    ds_t = HuBMAPDataset_NoPatch(fold=config.fold, train=True, tfms=config.train_transform)
    ds_v = HuBMAPDataset_NoPatch(fold=config.fold, train=False)
    
t_dataloader = DataLoader(dataset=ds_t, batch_size=BATCH_SIZE, shuffle=True, num_workers=config.NUM_WORKERS, pin_memory=True, drop_last=True)
v_dataloader = DataLoader(dataset=ds_v, batch_size=BATCH_SIZE, shuffle=False, num_workers=config.NUM_WORKERS, pin_memory=True)

### Batch Normalized 層固定

In [421]:
def freeze_bn(net, sw=False):
    for m in net.modules():
        if isinstance(m, torch.nn.BatchNorm2d):
            m.eval()
            m.weight.requires_grad = sw
            m.bias.requires_grad = sw

### パラメータ更新の制御

In [422]:
if Phase < TRAIN_PHASE['PostTrain']:
    freeze_bn(model, sw=not config.pre_BN_fix)
    for param in model.encoder.parameters():
        param.requires_grad = False
else:
    freeze_bn(model, sw=not config.post_BN_fix)

### チェックポイントの保存

In [423]:
def save_checkpoint(epoch, model, optim, loss_score, save_epoch, save_path):

    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss_score': loss_score,
        'save_epoch': save_epochs
    }, save_path)

### 学習ループ

In [424]:
def select_del_epochs(loss_score, save_epochs):

    #loss_score = {'train_loss':[], 'train_score':[], 'valid_loss':[], 'valid_score':[], 'dice_score':[], 'lung':[], 'kidney':[], 'largeintestine':[], 'prostate':[], 'spleen':[]}
    keep_epochs  = set(np.argsort(np.array(loss_score['dice_score']))[::-1][:config.MAX_SAVEMODEL]) 
    keep_epochs |= set(np.argsort(np.array(loss_score['lung']))[::-1][:config.MAX_SAVEMODEL])
    keep_epochs |= set(np.argsort(np.array(loss_score['kidney']))[::-1][:config.MAX_SAVEMODEL])
    keep_epochs |= set(np.argsort(np.array(loss_score['largeintestine']))[::-1][:config.MAX_SAVEMODEL])
    keep_epochs |= set(np.argsort(np.array(loss_score['prostate']))[::-1][:config.MAX_SAVEMODEL])
    keep_epochs |= set(np.argsort(np.array(loss_score['spleen']))[::-1][:config.MAX_SAVEMODEL])

    del_epochs = set(save_epochs) - keep_epochs

    return del_epochs

In [425]:
scaler = torch.cuda.amp.GradScaler()
dice = Dice_th()
dice.reset()
organ_score = {}
for key in config.organ_type:
    organ_score[key] = Dice_th()
    organ_score[key].reset()
EPOCH_MODELS_PATH = os.path.join(SAVE_PATH, 'Models')

In [426]:
_epoch += 1
for epoch in range(_epoch, END_EPOCH, 1):
    train_loss = []
    train_score = []
    valid_loss = []
    valid_score = []
    pbar = tqdm(t_dataloader, desc = 'description')
    model.train()
    for img, label, organ in pbar:
        x = img.to(config.device)
        label = label.to(config.device)

        optimizer.zero_grad()

        with torch.cuda.amp.autocast():
            y = model(x)

        loss_val = symmetric_lovasz(y.float(), label)
        losses_value = loss_val.item()
        score = compute_dice_score(y.float(), label)

        scaler.scale(loss_val).backward()
        scaler.step(optimizer)
        scaler.update()

        train_loss.append(losses_value)
        train_score.extend(score.cpu().numpy().tolist())

        pbar.set_description(f"Epoch: {epoch+1}, LR: {scheduler.get_last_lr()}, loss: {np.mean(train_loss):.3f}, DICE: {np.mean(train_score):.3f}")

    with torch.no_grad():
        model.eval()
        for img, label, organ in v_dataloader:
            x = img.to(config.device)
            label = label.to(config.device)

            with torch.cuda.amp.autocast():
                y = model(x)

            loss_val = symmetric_lovasz(y.float(), label)

            losses_value = loss_val.item()
            score = compute_dice_score(y.float(), label)
            dice.accumulate(y.float(), label)
            for i, i_organ in enumerate(organ): organ_score[i_organ].accumulate(y.float()[i:i+1,:,:,:], label[i:i+1,:,:,:])

            valid_loss.append(losses_value)
            valid_score.extend(score.cpu().numpy().tolist())

    #loss_score = {'train_loss':[], 'train_score':[], 'valid_loss':[], 'valid_score':[], 'dice_score':[], 'lung':[], 'kidney':[], 'largeintestine':[], 'prostate':[], 'spleen':[]}
    loss_score['train_loss'].append(np.mean(train_loss))
    loss_score['train_score'].append(np.mean(train_score))
    loss_score['valid_loss'].append(np.mean(valid_loss))
    loss_score['valid_score'].append(np.mean(valid_score))
    loss_score['dice_score'].append(dice.value)
    
    loss_score['lung'].append(organ_score['lung'].value)
    loss_score['kidney'].append(organ_score['kidney'].value)
    loss_score['largeintestine'].append(organ_score['largeintestine'].value)
    loss_score['prostate'].append(organ_score['prostate'].value)
    loss_score['spleen'].append(organ_score['spleen'].value)

    print(f"Train Loss: {loss_score['train_loss'][-1]}, Train DICE: {loss_score['train_score'][-1]}")
    print(f"Valid Loss: {loss_score['valid_loss'][-1]}, Valid DICE: {loss_score['valid_score'][-1]}, Dice_th : {dice.value}")
    print(f"lung: {loss_score['lung'][-1]}, kidney: {loss_score['kidney'][-1]}, largeintestine : {loss_score['largeintestine'][-1]}, prostate: {loss_score['prostate'][-1]}, spleen : {loss_score['spleen'][-1]}")

    # 評価値に基づきモデルを保存
    if Phase == TRAIN_PHASE['PostTrain']:
        save_epochs.append(epoch)
        model_name = os.path.join(EPOCH_MODELS_PATH, MODEL_NAME+f'_{epoch}.pth')
        torch.save(model.state_dict(), model_name)

        del_epochs = select_del_epochs(loss_score, save_epochs)

        for del_ep in del_epochs:
            model_name = os.path.join(EPOCH_MODELS_PATH, MODEL_NAME+f'_{del_ep}.pth')
            os.remove(model_name)
            save_epochs.remove(del_ep)

    dice.reset()
    scheduler.step()

    save_checkpoint(epoch, model, optimizer, loss_score, save_epochs, CHECKPOINT_PATH)

## 最終モデル保存

### 重み結合

In [427]:
def sum_model_params(sdA, sdB):
    """ modelA + modelB """
    sd = sdA.copy()
    for key in sdA:
        sd[key] = (sdB[key] + sdA[key])
    return sd

In [428]:
def multi_model_params(sd, a):
    for key in sd:
        sd[key] = sd[key] * a
    return sd

### ModelSoups

In [429]:
ep = save_epochs[np.argsort(np.array(loss_score["dice_score"])[save_epochs])[::-1][0]]

In [430]:
print(ep)

209


In [431]:
np.sort(np.array(loss_score["dice_score"])[save_epochs])[::-1][:5]

array([0.8410327 , 0.84072953, 0.84068286, 0.84051627, 0.84027   ],
      dtype=float32)

In [432]:
model_name = os.path.join(EPOCH_MODELS_PATH, MODEL_NAME+f'_{ep}.pth')
state_dict = torch.load(model_name,map_location=torch.device('cpu'))
torch.save(state_dict, os.path.join(SAVE_PATH, MODEL_NAME)+f'_score_dice.pth')

In [292]:
"""
if config.model_soups and Phase == TRAIN_PHASE['PostTrain']:

    for organ_key in config.organ_type:

        ModelSoupOrder = np.argsort(np.array(loss_score[organ_key])[save_epochs])[::-1]
        max_score = -1.0
        combine_model_epochs = []

        dice = Dice_th()
        dice.reset()
        organ_score = {}
        for key in config.organ_type:
            organ_score[key] = Dice_th()
            organ_score[key].reset()

        pbar = tqdm(ModelSoupOrder, desc = 'Model Soup')
        for loop, idx in enumerate(pbar):
            
            epoch = save_epochs[idx]
            model_name = os.path.join(EPOCH_MODELS_PATH, MODEL_NAME+f'_{epoch}.pth')
            state_dict_sum = torch.load(model_name,map_location=torch.device('cpu'))

            for cm_epoch in combine_model_epochs:

                model_name = os.path.join(EPOCH_MODELS_PATH, MODEL_NAME+f'_{cm_epoch}.pth')
                state_dict_read = torch.load(model_name,map_location=torch.device('cpu'))
                state_dict_sum = sum_model_params(state_dict_sum, state_dict_read)

            model = smp.create_model(arch=config.ARCH, encoder_weights=None, encoder_name=config.BACKBONE, classes=1, activation=None)
            model.load_state_dict(multi_model_params(state_dict_sum, 1.0 / (len(combine_model_epochs) + 1)))
            model = model.float()
            model = model.eval()
            model = model.to(config.device)

            with torch.no_grad():
                for img, label, organ in v_dataloader:
                    x = img.to(config.device)
                    label = label.to(config.device)

                    with torch.cuda.amp.autocast():
                        y = model(x)
                    dice.accumulate(y.float(), label)
                    for i, i_organ in enumerate(organ): organ_score[i_organ].accumulate(y.float()[i:i+1,:,:,:], label[i:i+1,:,:,:])

            print(organ_score[organ_key].value)
            if organ_score[organ_key].value > max_score:
                max_score = organ_score[organ_key].value
                combine_model_epochs.append(epoch)

            dice.reset()
            for key in config.organ_type:
                organ_score[key].reset()
        
        for i, cm_epoch in enumerate(combine_model_epochs):

            model_name = os.path.join(EPOCH_MODELS_PATH, MODEL_NAME+f'_{cm_epoch}.pth')
            state_dict_read = torch.load(model_name,map_location=torch.device('cpu'))
            if i == 0:
                state_dict_sum = state_dict_read.copy()
            else:
                state_dict_sum = sum_model_params(state_dict_sum, state_dict_read)

        model = smp.create_model(arch=config.ARCH, encoder_weights=None, encoder_name=config.BACKBONE, classes=1, activation=None)
        state_dict = multi_model_params(state_dict_sum, 1.0 / len(combine_model_epochs))

        torch.save(state_dict, os.path.join(SAVE_PATH, MODEL_NAME)+f'_{organ_key}.pth')
        
        print(f'{organ_key} : {max_score}')
else:
    torch.save(model.state_dict(), os.path.join(SAVE_PATH, MODEL_NAME)+'.pth')
"""

"\nif config.model_soups and Phase == TRAIN_PHASE['PostTrain']:\n\n    for organ_key in config.organ_type:\n\n        ModelSoupOrder = np.argsort(np.array(loss_score[organ_key])[save_epochs])[::-1]\n        max_score = -1.0\n        combine_model_epochs = []\n\n        dice = Dice_th()\n        dice.reset()\n        organ_score = {}\n        for key in config.organ_type:\n            organ_score[key] = Dice_th()\n            organ_score[key].reset()\n\n        pbar = tqdm(ModelSoupOrder, desc = 'Model Soup')\n        for loop, idx in enumerate(pbar):\n            \n            epoch = save_epochs[idx]\n            model_name = os.path.join(EPOCH_MODELS_PATH, MODEL_NAME+f'_{epoch}.pth')\n            state_dict_sum = torch.load(model_name,map_location=torch.device('cpu'))\n\n            for cm_epoch in combine_model_epochs:\n\n                model_name = os.path.join(EPOCH_MODELS_PATH, MODEL_NAME+f'_{cm_epoch}.pth')\n                state_dict_read = torch.load(model_name,map_location=t

In [None]:
"""
if config.model_soups and Phase == TRAIN_PHASE['PostTrain']:

    ModelSoupOrder = np.argsort(np.array(loss_score['dice_score'])[save_epochs])[::-1]
    max_score = -1.0
    combine_model_epochs = []

    dice = Dice_th()
    dice.reset()

    pbar = tqdm(ModelSoupOrder, desc = 'Model Soup')
    for loop, idx in enumerate(pbar):
        
        epoch = save_epochs[idx]
        model_name = os.path.join(EPOCH_MODELS_PATH, MODEL_NAME+f'_{epoch}.pth')
        state_dict_sum = torch.load(model_name,map_location=torch.device('cpu'))

        for cm_epoch in combine_model_epochs:

            model_name = os.path.join(EPOCH_MODELS_PATH, MODEL_NAME+f'_{cm_epoch}.pth')
            state_dict_read = torch.load(model_name,map_location=torch.device('cpu'))
            state_dict_sum = sum_model_params(state_dict_sum, state_dict_read)

        model = smp.create_model(arch=config.ARCH, encoder_weights=None, encoder_name=config.BACKBONE, classes=1, activation=None)
        model.load_state_dict(multi_model_params(state_dict_sum, 1.0 / (len(combine_model_epochs) + 1)))
        model = model.float()
        model = model.eval()
        model = model.to(config.device)

        with torch.no_grad():
            for img, label, organ in v_dataloader:
                x = img.to(config.device)
                label = label.to(config.device)

                with torch.cuda.amp.autocast():
                    y = model(x)
                dice.accumulate(y.float(), label)

        print(dice.value)
        if dice.value > max_score:
            max_score = dice.value
            combine_model_epochs.append(epoch)

        dice.reset()
    
    for i, cm_epoch in enumerate(combine_model_epochs):

        model_name = os.path.join(EPOCH_MODELS_PATH, MODEL_NAME+f'_{cm_epoch}.pth')
        state_dict_read = torch.load(model_name,map_location=torch.device('cpu'))
        if i == 0:
            state_dict_sum = state_dict_read.copy()
        else:
            state_dict_sum = sum_model_params(state_dict_sum, state_dict_read)

    model = smp.create_model(arch=config.ARCH, encoder_weights=None, encoder_name=config.BACKBONE, classes=1, activation=None)
    state_dict = multi_model_params(state_dict_sum, 1.0 / len(combine_model_epochs))

    torch.save(state_dict, os.path.join(SAVE_PATH, MODEL_NAME)+f'.pth')
    
    print(f'DICE_Score : {max_score}')
"""

In [None]:
#!kaggle datasets version -m "update models" -p '/content/save_models' --dir-mode zip