In [None]:
!nvidia-smi

## kaggle

In [None]:
!mkdir -p ~/.kaggle
!cp /content/drive/MyDrive/datas/kaggle.json  ~/.kaggle/

In [None]:
!chmod 600 /root/.kaggle/kaggle.json

In [None]:
!kaggle datasets download -d thedevastator/hubmap-2022-256x256

In [None]:
# !mkdir hubmap
!unzip /content/hubmap-2022-256x256.zip -d hubmap >/dev/null

In [None]:
!pip install pytorch-lightning

In [None]:
%load_ext tensorboard

## Start code

In [None]:
!pip install timm

In [None]:
from typing import Optional, List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ConstantLR, LinearLR
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset
import torchvision.transforms as transforms
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

from timm import create_model

import numpy as np
import cv2
import matplotlib.pyplot as plt

import os
import albumentations as albu
import random

import pandas as pd

from sklearn.model_selection import KFold

## Load data and create the train/validation split

In [None]:
# DATA_DIR = './data/CamVid/'

# # load repo with data if it is not exists
# if not os.path.exists(DATA_DIR):
#     print('Loading data...')
#     os.system('git clone https://github.com/alexgkendall/SegNet-Tutorial ./data')
#     print('Done!')

# x_train_dir = os.path.join(DATA_DIR, 'train')
# y_train_dir = os.path.join(DATA_DIR, 'trainannot')

# x_valid_dir = os.path.join(DATA_DIR, 'val')
# y_valid_dir = os.path.join(DATA_DIR, 'valannot')

# x_test_dir = os.path.join(DATA_DIR, 'test')
# y_test_dir = os.path.join(DATA_DIR, 'testannot')

In [None]:
SEED = 43
BATCH_SIZE = 64

In [None]:
def seed_everything(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Also exports ``PYTHONHASHSEED`` and sets the cuDNN flags used by this
    notebook.

    Args:
        seed: integer seed applied to every generator.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Same seed for every generator the notebook relies on.
    for apply_seed in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        apply_seed(seed)

    torch.backends.cudnn.deterministic = True
    # Enabling the cuDNN autotuner gives roughly a 10% speedup, at the price
    # of possibly introducing some run-to-run stochasticity in the results.
    torch.backends.cudnn.benchmark = True
    
seed_everything(SEED)

In [None]:
DATA_DIR = './hubmap'
NFOLD = 5

x_train_dir = os.path.join(DATA_DIR, 'train')
y_train_dir = os.path.join(DATA_DIR, 'masks')
y_train_dir

In [None]:
# os.listdir returns entries in arbitrary order, so sort the listing first:
# otherwise the seeded KFold below would produce different folds depending on
# the filesystem, even with the same SEED.
ids = sorted(os.listdir(os.path.join(DATA_DIR, 'train')))
kf = KFold(n_splits=NFOLD, random_state=SEED, shuffle=True)

df = pd.DataFrame(ids, columns=['filename'])

# Only the first fold is needed for the data sanity checks that follow.
train, test = next(iter(kf.split(ids)))
train_file = df.iloc[train]['filename'].to_list()
test_file = df.iloc[test]['filename'].to_list()


In [None]:
train_file[:10]

## Utils

In [None]:
WIDTH, HEIGHT = 256, 256

In [None]:
def get_training_augmentation2(p=1.0):
    """Build the training augmentation pipeline (flips, affine, distortion, colour, noise).

    Args:
        p: probability of applying the whole pipeline to a sample.

    Returns:
        albu.Compose: the composed augmentation.
    """
    # The imgaug-backed `IAAPiecewiseAffine` is deprecated and absent from
    # newer albumentations releases, where `PiecewiseAffine` replaces it.
    # Prefer the native transform and fall back to the legacy name only when
    # the installed version predates it.
    piecewise_affine = getattr(albu, 'PiecewiseAffine', None) or albu.IAAPiecewiseAffine

    return albu.Compose([
        albu.HorizontalFlip(),
        albu.VerticalFlip(),
        albu.RandomRotate90(),
        albu.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=15, p=0.9,
                              border_mode=cv2.BORDER_REFLECT),
        # At most one geometric distortion per sample.
        albu.OneOf([
            albu.OpticalDistortion(p=0.3),
            albu.GridDistortion(p=.1),
            piecewise_affine(p=0.3),
        ], p=0.3),
        # At most one contrast/brightness adjustment per sample.
        albu.OneOf([
            albu.CLAHE(clip_limit=2),
            albu.RandomBrightnessContrast(),
        ], p=0.3),
        albu.GaussNoise(var_limit=(10.0, 50.0), p=0.3)
    ], p=p)

def get_training_augmentation():
    """Return a pipeline that pads to at least HEIGHT x WIDTH, then randomly crops to HEIGHT x WIDTH."""
    pad_step = albu.PadIfNeeded(min_height=HEIGHT, min_width=WIDTH, always_apply=True, border_mode=0)
    crop_step = albu.RandomCrop(height=HEIGHT, width=WIDTH, always_apply=True)
    return albu.Compose([pad_step, crop_step])


def get_grayaug():
    """Return a pipeline whose single step is ``albu.ToGray`` applied with probability 1."""
    return albu.Compose([albu.ToGray(p=1.0)])

# def get_validation_augmentation():
#     """画像のshapeが32で割り切れるようにPaddingするための関数"""
#     test_transform = [
#         albu.PadIfNeeded(384, 480)
#     ]
#     return albu.Compose(test_transform)

def to_tensor(x, **kwargs):
    """Convert a channels-last array to a channels-first float32 array.

    Args:
        x: array of shape (H, W, C), or (H, W) for a single-channel image/mask.
        **kwargs: ignored; accepted so the function can be used as an
            ``albu.Lambda`` callback.

    Returns:
        float32 array of shape (C, H, W); a 2-D input yields (1, H, W).
    """
    # Grayscale masks arrive without a channel axis; add one so the transpose
    # below does not fail (mirrors what img2tensor does).
    if x.ndim == 2:
        x = x[:, :, None]
    return x.transpose(2, 0, 1).astype('float32')

def get_preprocessing(preprocessing_fn):
    """Build the preprocessing pipeline applied after augmentation.

    Args:
        preprocessing_fn (callable): data normalization function applied to
            the image (can be specific to each pretrained network).

    Returns:
        albu.Compose: normalization followed by the ``to_tensor`` conversion
        of both image and mask.
    """
    normalize_step = albu.Lambda(image=preprocessing_fn)
    to_tensor_step = albu.Lambda(image=to_tensor, mask=to_tensor)
    return albu.Compose([normalize_step, to_tensor_step])

# Visualization helper
def visualize(**images):
    """Plot the given images side by side in a single row.

    Each keyword name becomes the panel title (underscores turned into
    spaces, title-cased); each value is passed to ``plt.imshow``.
    """
    panel_count = len(images)
    plt.figure(figsize=(16, 5))
    for position, (label, picture) in enumerate(images.items(), start=1):
        plt.subplot(1, panel_count, position)
        plt.xticks([])
        plt.yticks([])
        plt.title(label.replace('_', ' ').title())
        plt.imshow(picture)
    plt.show()

## dataset

In [None]:
mean = np.array([0.7720342, 0.74582646, 0.76392896])
std = np.array([0.24745085, 0.26182273, 0.25782376])

def img2tensor(img, dtype: np.dtype = np.float32):
    """Turn an (H, W) or (H, W, C) NumPy image into a (C, H, W) torch tensor.

    Args:
        img: NumPy array; a 2-D array gets a trailing channel axis first.
        dtype: NumPy dtype the data is cast to before wrapping.

    Returns:
        torch.Tensor created with ``torch.from_numpy``.
    """
    with_channels = img[:, :, np.newaxis] if img.ndim == 2 else img
    channels_first = with_channels.transpose(2, 0, 1)
    return torch.from_numpy(channels_first.astype(dtype, copy=False))


# 1. Dataset class derived from torch.utils.data.Dataset
class Dataset(torch.utils.data.Dataset):
    """Image/mask segmentation dataset read from two parallel directories.

    Each file name in ``file_list`` must exist in both ``images_dir`` and
    ``masks_dir``.

    Args:
        images_dir: directory containing the input images.
        masks_dir: directory containing the mask images.
        file_list: list of file names to serve.
        augmentation: optional callable invoked as
            ``augmentation(image=..., mask=...)`` returning a dict with
            ``'image'`` and ``'mask'`` keys.
        preprocessing: optional callable with the same calling convention,
            applied after augmentation.
    """

    def __init__(
            self,
            images_dir,
            masks_dir,
            file_list,
            augmentation=None,
            preprocessing=None,
    ):
        self.ids = file_list
        self.images_dir = images_dir
        self.masks_dir = masks_dir

        self.augmentation = augmentation
        self.preprocessing = preprocessing
        # Kept for experiments; __getitem__ does not currently apply it.
        self.to_gray = get_grayaug()

    @staticmethod
    def _read_image(path, flags):
        """Read ``path`` with OpenCV, raising if the file cannot be loaded."""
        # cv2.imread signals failure by returning None instead of raising,
        # which would otherwise surface later as an obscure error.
        loaded = cv2.imread(path, flags)
        if loaded is None:
            raise FileNotFoundError("Could not read image file: {}".format(path))
        return loaded

    # 3. Return one (image, mask) training pair
    def __getitem__(self, i):
        """Load sample ``i`` and return ``(image, mask)`` as channels-first tensors.

        The image is read as RGB and scaled by 1/255; the mask is read as a
        single-channel image and returned with its stored values unchanged.

        Raises:
            FileNotFoundError: if the image or mask file cannot be read.
        """
        file_name = self.ids[i]
        image = self._read_image(os.path.join(self.images_dir, file_name), cv2.IMREAD_COLOR)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = self._read_image(os.path.join(self.masks_dir, file_name), cv2.IMREAD_GRAYSCALE)

        # Apply augmentation jointly to image and mask.
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        # Apply preprocessing.
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        return img2tensor(image/255.0),img2tensor(mask)

    # 4. Number of samples in the dataset
    def __len__(self):
        return len(self.ids)

In [None]:
# # # データセットのインスタンスを作成
train_dataset = Dataset(
    x_train_dir, 
    y_train_dir, 
    train_file,
    augmentation=get_training_augmentation(), 
)

print(train_dataset.__len__())
for i in range(10) :
  x,y= train_dataset[i]
  print(y.max())
  # plt.imshow(x.permute(1,2,0))
  break

In [None]:
from torch.utils.data import DataLoader
CLASSES = ['car']


# データセットのインスタンスを作成
train_dataset = Dataset(
    x_train_dir, 
    y_train_dir, 
    train_file,
    augmentation=get_training_augmentation2(p=0.9), 
)

valid_dataset = Dataset(
    x_train_dir, 
    y_train_dir, 
    test_file,
    augmentation=None, 
)

# データローダーの作成
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

In [None]:
row, col = 8, 8
plt.figure(figsize=(20,20))

# Preview the first training batch. The number of panels is derived from the
# batch itself (capped at the grid size) so the cell works for any BATCH_SIZE
# and does not rely on variables left over from earlier cells.
data, mask = next(iter(train_loader))
for i in range(min(len(data), row * col)):
    plt.subplot(col, row, i+1)
    # CHW tensor -> HWC layout expected by matplotlib.
    plt.imshow(data[i].permute(1,2,0))
    plt.axis('off')

## U-net model with timm

In [None]:
""" A simple U-Net w/ timm backbone encoder
Based off an old version of Unet in https://github.com/qubvel/segmentation_models.pytorch
Hacked together by Ross Wightman
"""
class Unet(nn.Module):
    """U-Net for semantic segmentation built on a timm feature-extractor encoder.

    Args:
        backbone: timm model name used as the encoder (created with
            ``features_only=True``).
        backbone_kwargs: optional dict of extra keyword arguments forwarded to
            ``timm.create_model``.
        backbone_indices: passed to ``timm.create_model`` as ``out_indices``.
        decoder_use_batchnorm: if ``False``, ``norm_layer=None`` is handed to
            the decoder instead of ``norm_layer``.
        decoder_channels: output channel count of each decoder block.
        in_chans: number of input image channels.
        num_classes: number of output channels produced by the decoder's
            final convolution.
        center: if ``True``, the decoder adds a block on the encoder head.
        norm_layer: normalization layer class handed to the decoder.
        pretrained: whether timm should load pretrained encoder weights. Set
            to ``False`` when the weights are restored from a checkpoint
            anyway, to avoid an unnecessary download.
    NOTE: This is based off an old version of Unet in https://github.com/qubvel/segmentation_models.pytorch
    """

    def __init__(
            self,
            backbone='resnet50',
            backbone_kwargs=None,
            backbone_indices=None,
            decoder_use_batchnorm=True,
            decoder_channels=(256, 128, 64, 32, 16),
            in_chans=1,
            num_classes=5,
            center=False,
            norm_layer=nn.BatchNorm2d,
            pretrained=True,
    ):
        super().__init__()
        backbone_kwargs = backbone_kwargs or {}
        # NOTE some models need different backbone indices specified based on the alignment of features
        # and some models won't have a full enough range of feature strides to work properly.
        encoder = create_model(
            backbone, features_only=True, out_indices=backbone_indices, in_chans=in_chans,
            pretrained=pretrained, **backbone_kwargs)
        # Reverse the per-stage channel counts so they line up with the
        # reversed feature list that forward() gives the decoder.
        encoder_channels = encoder.feature_info.channels()[::-1]
        self.encoder = encoder

        if not decoder_use_batchnorm:
            norm_layer = None
        self.decoder = UnetDecoder(
            encoder_channels=encoder_channels,
            decoder_channels=decoder_channels,
            final_channels=num_classes,
            norm_layer=norm_layer,
            center=center,
        )

    def forward(self, x: torch.Tensor):
        """Run the encoder, reverse its feature list and decode it into the output map."""
        x = self.encoder(x)
        x.reverse()  # torchscript doesn't work with [::-1]
        x = self.decoder(x)
        return x


class Conv2dBnAct(nn.Module):
    """Convolution followed by an optional normalization layer and an activation.

    Args:
        in_channels: input channel count of the convolution.
        out_channels: output channel count of the convolution.
        kernel_size: convolution kernel size.
        padding: convolution padding.
        stride: convolution stride.
        act_layer: activation class, instantiated with ``inplace=True``.
        norm_layer: normalization class instantiated with ``out_channels``,
            or ``None`` to skip normalization. Without normalization the
            convolution gets a bias term, since nothing after it supplies a
            learnable shift.
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding=0,
                 stride=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
        super().__init__()
        use_norm = norm_layer is not None
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=not use_norm)
        # The attribute keeps the name `bn` so state-dict keys of models built
        # with a normalization layer are unchanged.
        self.bn = norm_layer(out_channels) if use_norm else nn.Identity()
        self.act = act_layer(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.act(x)
        return x


class DecoderBlock(nn.Module):
    """Upsample, optionally concatenate a skip tensor, then apply two 3x3 conv blocks.

    Args:
        in_channels: channel count after the optional skip concatenation.
        out_channels: channel count produced by both convolutions.
        scale_factor: nearest-neighbour upsampling factor; ``1.0`` disables
            upsampling.
        act_layer: activation class used by both conv blocks.
        norm_layer: normalization class used by both conv blocks, or ``None``
            for no normalization. Previously ``None`` silently fell back to
            ``nn.BatchNorm2d``, so disabling normalization had no effect.
    """

    def __init__(self, in_channels, out_channels, scale_factor=2.0, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
        super().__init__()
        conv_args = dict(kernel_size=3, padding=1, act_layer=act_layer, norm_layer=norm_layer)
        self.scale_factor = scale_factor
        self.conv1 = Conv2dBnAct(in_channels, out_channels, **conv_args)
        self.conv2 = Conv2dBnAct(out_channels, out_channels, **conv_args)

    def forward(self, x, skip: Optional[torch.Tensor] = None):
        if self.scale_factor != 1.0:
            x = F.interpolate(x, scale_factor=self.scale_factor, mode='nearest')
        if skip is not None:
            # Concatenate along the channel axis.
            x = torch.cat([x, skip], dim=1)
        x = self.conv1(x)
        x = self.conv2(x)
        return x


class UnetDecoder(nn.Module):
    """Stack of ``DecoderBlock``s followed by a 1x1 output convolution.

    Args:
        encoder_channels: channel counts of the encoder features, head first.
        decoder_channels: output channel count of each decoder block.
        final_channels: output channel count of the final 1x1 convolution.
        norm_layer: normalization layer class forwarded to every block.
        center: if ``True``, a non-upsampling ``DecoderBlock`` is applied to
            the encoder head; otherwise the head passes through unchanged.
    """

    def __init__(
            self,
            encoder_channels,
            decoder_channels=(256, 128, 64, 32, 16),
            final_channels=1,
            norm_layer=nn.BatchNorm2d,
            center=False,
    ):
        super().__init__()

        head_chs = encoder_channels[0]
        self.center = (
            DecoderBlock(head_chs, head_chs, scale_factor=1.0, norm_layer=norm_layer)
            if center else nn.Identity()
        )

        # Block i consumes the previous stage's output plus the matching skip
        # feature; the last entry of `lateral_chs` is 0 (no skip).
        upstream_chs = [head_chs] + list(decoder_channels[:-1])
        lateral_chs = list(encoder_channels[1:]) + [0]

        self.blocks = nn.ModuleList()
        for up_chs, skip_chs, out_chs in zip(upstream_chs, lateral_chs, decoder_channels):
            self.blocks.append(DecoderBlock(up_chs + skip_chs, out_chs, norm_layer=norm_layer))
        self.final_conv = nn.Conv2d(decoder_channels[-1], final_channels, kernel_size=(1, 1))

        self._init_weight()

    def _init_weight(self):
        """Kaiming-normal init for conv weights; BatchNorm weight set to 1 and bias to 0."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def forward(self, x: List[torch.Tensor]):
        """Decode a list of encoder features (head first, remaining entries used as skips)."""
        skips = x[1:]
        skip_count = len(skips)
        out = self.center(x[0])
        for idx, block in enumerate(self.blocks):
            # Blocks beyond the available skips run without a skip tensor.
            out = block(out, skips[idx] if idx < skip_count else None)
        return self.final_conv(out)

In [None]:
# model = Unet(backbone="dla60_res2net", in_chans=3, num_classes=1)

In [None]:
# model.load_state_dict(torch.load('/content/drive/MyDrive/datas/unet/train_fold0.pth'))

## Loss and metrics

In [None]:
def dice_loss(pred, target, smooth = 1.):
    """Soft Dice loss, averaged over the first two dimensions.

    Args:
        pred: 4-D tensor of predictions.
        target: 4-D tensor of targets, same shape as ``pred``.
        smooth: constant added to numerator and denominator.

    Returns:
        Scalar tensor ``mean(1 - (2 * intersection + smooth) / (sum(pred) + sum(target) + smooth))``.
    """
    def _sum_last_two(t):
        # Collapse dims 2 and 3 one after the other.
        return t.sum(dim=2).sum(dim=2)

    pred = pred.contiguous()
    target = target.contiguous()

    overlap = _sum_last_two(pred * target)
    numerator = 2. * overlap + smooth
    denominator = _sum_last_two(pred) + _sum_last_two(target) + smooth
    per_map_loss = (1 - (numerator / denominator))
    return per_map_loss.mean()


def calc_loss(pred, target, metrics=None, bce_weight=0.5):
    """Blend binary cross-entropy (on logits) with Dice loss (on sigmoid probabilities).

    Args:
        pred: raw logits.
        target: target tensor with the same shape as ``pred``.
        metrics: unused.
        bce_weight: weight of the BCE term; the Dice term gets ``1 - bce_weight``.

    Returns:
        Scalar loss tensor.
    """
    bce_term = F.binary_cross_entropy_with_logits(pred, target)
    dice_term = dice_loss(torch.sigmoid(pred), target)
    return bce_term * bce_weight + dice_term * (1 - bce_weight)

In [None]:
#  PyTorch version

SMOOTH = 1e-6

def dice_pytorch(outputs: torch.Tensor, labels: torch.Tensor):
    """Batch-mean Dice coefficient between two binary masks.

    Both arguments are combined with ``&``, so they are expected to be
    boolean (or integer) tensors. A size-1 dimension 1 is squeezed from each,
    so BATCH x 1 x H x W inputs become BATCH x H x W.

    Returns:
        Scalar tensor: mean over the batch of
        ``(2 * |A & B| + SMOOTH) / (|A| + |B| + SMOOTH)``.
    """
    pred_mask = outputs.squeeze(1)
    true_mask = labels.squeeze(1)

    # Zero whenever either mask is empty.
    overlap = (pred_mask & true_mask).float().sum((1, 2))
    # Zero only when both masks are empty; SMOOTH then avoids 0/0.
    total = pred_mask.float().sum((1,2))  +  true_mask.float().sum((1,2))

    per_sample = (2*overlap + SMOOTH) / (total + SMOOTH)
    return per_sample.mean()
    

def iou_pytorch(outputs: torch.Tensor, labels: torch.Tensor):
    """Batch-mean intersection-over-union between two binary masks.

    Both arguments are combined with ``&`` / ``|``, so they are expected to
    be boolean (or integer) tensors. A size-1 dimension 1 is squeezed from
    each, so BATCH x 1 x H x W inputs become BATCH x H x W.

    Returns:
        Scalar tensor: mean over the batch of
        ``(|A & B| + SMOOTH) / (|A | B| + SMOOTH)``.
    """
    pred_mask = outputs.squeeze(1)
    true_mask = labels.squeeze(1)

    # Zero whenever either mask is empty.
    overlap = (pred_mask & true_mask).float().sum((1, 2))
    # Zero only when both masks are empty; SMOOTH then avoids 0/0.
    combined = (pred_mask | true_mask).float().sum((1, 2))

    per_sample = (overlap + SMOOTH) / (combined + SMOOTH)
    return per_sample.mean()
    
    
# Numpy version
# Well, it's the same function, so I'm going to omit the comments

def iou_numpy(outputs: np.array, labels: np.array):
    """Per-sample thresholded IoU score for NumPy masks.

    Args:
        outputs: boolean/integer array of shape (BATCH, 1, H, W).
        labels: boolean/integer array of shape (BATCH, H, W) or
            (BATCH, 1, H, W).

    Returns:
        Array of shape (BATCH,): each IoU mapped to
        ``ceil(clip(20 * (iou - 0.5), 0, 10)) / 10``.
    """
    outputs = outputs.squeeze(1)
    # Drop the labels' channel axis too when present, as iou_pytorch does.
    # Otherwise (BATCH, 1, H, W) labels broadcast against (BATCH, H, W)
    # outputs into a (BATCH, BATCH, H, W) array and the sums are meaningless.
    if labels.ndim == outputs.ndim + 1 and labels.shape[1] == 1:
        labels = labels.squeeze(1)

    intersection = (outputs & labels).sum((1, 2))
    union = (outputs | labels).sum((1, 2))

    iou = (intersection + SMOOTH) / (union + SMOOTH)

    thresholded = np.ceil(np.clip(20 * (iou - 0.5), 0, 10)) / 10

    return thresholded  # Or thresholded.mean()

In [None]:
# 損失関数
loss_fn = calc_loss

# # 評価関数
# metrics = [
#     iou_pytorch,
# ]

# # 最適化関数
# optimizer = torch.optim.Adam([ 
#     dict(params=model.parameters(), lr=0.0001),
# ])

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

In [None]:
class LitUnet(pl.LightningModule):
    """LightningModule wrapping the timm-backed ``Unet`` for single-class segmentation.

    Args:
        lr: learning rate handed to the Adam optimizer.
    """

    def __init__(self, lr=0.05):
        super().__init__()
        self.save_hyperparameters()
        self.lr = lr
        self.model = Unet(backbone="dla60_res2net", in_chans=3, num_classes=1)

    def forward(self, x):
        """Return the raw (pre-sigmoid) logits of the wrapped model."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        data, label = batch
        pred = self.model(data)
        loss = loss_fn(pred, label)
        self.log('train_loss', loss)
        return loss

    def _evaluate(self, batch, prefix):
        """Shared validation/test logic: log ``{prefix}_loss`` and ``{prefix}_iou``."""
        data, label = batch
        pred = self.model(data)
        loss = loss_fn(pred, label)

        # `pred` holds logits (the loss applies the sigmoid itself), so turn
        # it into probabilities before thresholding at 0.5. Thresholding the
        # raw logits at 0.5 would under-count positive pixels.
        pred_mask = torch.sigmoid(pred.detach()).cpu() > 0.5
        # The logged key keeps its historical `_iou` name (the checkpoint
        # callback monitors `val_iou`), but the value is a Dice coefficient.
        score = dice_pytorch(pred_mask, label.cpu() > 0.5).item()

        self.log(f'{prefix}_loss', loss, prog_bar=True)
        self.log(f'{prefix}_iou', score, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        self._evaluate(batch, 'val')

    def test_step(self, batch, batch_idx):
        self._evaluate(batch, 'test')

    def configure_optimizers(self):
        """Adam with a per-step linear LR decay from 1x to 0.1x over 50*50 steps."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        scheduler_dict = {
            'scheduler': LinearLR(optimizer, start_factor=1, end_factor=0.1, total_iters=50*50),
            'interval': 'step',
        }
        return {'optimizer': optimizer, 'lr_scheduler': scheduler_dict}

In [None]:

checkpoint_callback = ModelCheckpoint(
                   dirpath="./", 
                   save_top_k=3, 
                   monitor="val_iou",
                   mode="max",
               )

# model = LitUnet(lr=0.0002)

# trainer = pl.Trainer(
#     progress_bar_refresh_rate=1,
#     max_epochs=2,
#     gpus=1,
#     logger=pl.loggers.TensorBoardLogger('lightning_logs/', name='unet'),
#     callbacks=[LearningRateMonitor(logging_interval='step'),
#                checkpoint_callback,
#                ],
#     precision=16,
# )

# checkpoint_callback = ModelCheckpoint(
#     save_top_k=10,
#     monitor="val_loss",
#     mode="min",
#     dirpath="my/path/",
#     filename="sample-mnist-{epoch:02d}-{val_loss:.2f}",
# )

In [None]:
# model = LitUnet(lr=0.0002)

In [None]:
# os.listdir returns entries in arbitrary order; sort them so the seeded
# KFold split assigns the same files to the same folds on every machine.
ids = sorted(os.listdir(os.path.join(DATA_DIR, 'train')))
kf = KFold(n_splits=NFOLD,random_state=SEED,shuffle=True)

df = pd.DataFrame(ids, columns=['filename'])
for idx, (train, test) in enumerate(kf.split(ids)):
    train_file = df.iloc[train]['filename'].to_list()
    test_file = df.iloc[test]['filename'].to_list()

    # Datasets for this fold
    train_dataset = Dataset(
        x_train_dir,
        y_train_dir,
        train_file,
        augmentation=get_training_augmentation2(p=0.9),
    )

    valid_dataset = Dataset(
        x_train_dir,
        y_train_dir,
        test_file,
        augmentation=None,
    )

    # Data loaders for this fold
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
    valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    print("[fold", idx, "]", "-"*80)
    model = LitUnet(lr=0.0002)

    # A fresh checkpoint callback per fold. Reusing one instance across folds
    # carries its best-score/top-k bookkeeping over, so `best_model_path`
    # could still point at a checkpoint from an earlier fold.
    checkpoint_callback = ModelCheckpoint(
        dirpath=os.path.join("checkpoints", "fold{}".format(idx)),
        save_top_k=3,
        monitor="val_iou",
        mode="max",
    )

    trainer = pl.Trainer(
        progress_bar_refresh_rate=1,
        max_epochs=50,
        gpus=1,
        logger=pl.loggers.TensorBoardLogger('lightning_logs/', name='unet'),
        callbacks=[LearningRateMonitor(logging_interval='step'),
                   checkpoint_callback,
                   ],
        precision=16,
    )
    trainer.fit(model,
                train_loader,
                valid_loader,
                )

    # Export the plain Unet weights of this fold's best checkpoint.
    model = LitUnet.load_from_checkpoint(checkpoint_path=checkpoint_callback.best_model_path)
    torch.save(model.model.state_dict(), "train_fold{}.pth".format(idx))


In [None]:
!cp *.pth /content/drive/MyDrive/datas/unet

In [None]:
model.model.load_state_dict(torch.load("train_fold0.pth"))

In [None]:
trainer.test(
    model,
    dataloaders = valid_loader
    )

In [None]:
# trainer.fit(model, 
#             train_loader,
#             valid_loader,
#             )
# trainer.save_checkpoint("unet.ckpt")

In [None]:
%tensorboard --logdir lightning_logs/unet

In [None]:
# from tqdm import tqdm
# from collections import OrderedDict

# EPOCH = 40

# model.to(device)

# for i in range(EPOCH) :
#   model.train()
#   tot, cnt = 0,0
#   with tqdm(train_loader) as pbar:
#     for batch, D in enumerate(pbar):
#       data, label = D[0], D[1]
#       # print(data.shape, label.shape)
#       data = data.to(device)
#       label = label.to(device)
#       pred = model(data)
#       loss = loss_fn(pred, label)
#       optimizer.zero_grad()
#       loss.backward()
#       optimizer.step()
#       loss  = loss.detach().item()
#       tot += loss
#       cnt += 1
#       pbar.set_postfix(OrderedDict(loss=loss))
#     print("")
#     print("[EPOCH {}] train loss={}".format(i+1, tot/cnt))

#   model.eval()
#   tot, cnt = 0,0
#   iou = 0.0
#   with tqdm(valid_loader) as pbar:
#     for batch, D in enumerate(pbar):
#       data, label = D[0], D[1]
#       # print(data.shape, label.shape)
#       data = data.to(device)
#       label = label.to(device)
#       pred = model(data)
#       loss = loss_fn(pred, label)
#       loss = loss.detach().item()
#       iou += iou_pytorch(pred.cpu()>0.5, label.cpu()>0.5).detach().item()
#       # print(iou)
#       tot += loss
#       cnt += 1
#       pbar.set_postfix(OrderedDict(loss=loss))
#     print("")
#     print("[EPOCH {}] valid loss={} iou = {}".format(i+1, tot/cnt, iou/cnt))
#     print("-"*80)
#     print("")
    

In [None]:
# checkpoint_callback.best_model_path
model = LitUnet.load_from_checkpoint(checkpoint_path=checkpoint_callback.best_model_path)

In [None]:
trainer.test(
    model,
    dataloaders = valid_loader
    )

In [None]:
num = 102
x, y = valid_dataset[num]
x.shape

row, col = 3,1
plt.figure(figsize=(20,20))

plt.subplot(col, row, 1)
plt.imshow(x[0])
plt.axis('off')

plt.subplot(col, row, 2)
plt.imshow(y[0])
plt.axis('off')

y.max()

In [None]:
x.unsqueeze(0).shape

In [None]:
model.to(device)
model.eval()
# Inference only: skip autograd bookkeeping to save memory and time.
# `pred` holds the model's raw logits.
with torch.no_grad():
    pred = model.forward(x.unsqueeze(0).to(device)).cpu()

In [None]:
row, col = 3,1
plt.figure(figsize=(20,20))

# Panel 1: first channel of the input image
plt.subplot(col, row, 1)
plt.imshow(x[0])
plt.axis('off')

# Panel 2: ground-truth mask
plt.subplot(col, row, 2)
plt.imshow(y[0])
plt.axis('off')

# Panel 3: predicted mask. `pred` contains logits, so apply the sigmoid
# before thresholding the probability at 0.5.
plt.subplot(col, row, 3)
plt.imshow(torch.sigmoid(pred[0][0]).detach().numpy() > 0.5)
plt.axis('off')

In [None]:
# check iou function
# NOTE: this deliberately redefines `iou_pytorch` for debugging. Unlike the
# earlier definition in this notebook it prints intermediate values and
# returns a (thresholded mean, IoU mean) tuple, and it shadows that earlier
# definition for every cell executed afterwards.
SMOOTH = 1e-6
def iou_pytorch(outputs: torch.Tensor, labels: torch.Tensor):
    """Debug variant of the IoU metric.

    Args:
        outputs: boolean/integer mask tensor, BATCH x 1 x H x W or BATCH x H x W.
        labels: boolean/integer mask tensor, BATCH x 1 x H x W or BATCH x H x W.

    Returns:
        Tuple ``(thresholded.mean(), iou.mean())`` where ``thresholded`` maps
        each per-sample IoU to ``ceil(clamp(20 * (iou - 0.5), 0, 10)) / 10``.
    """
    outputs = outputs.squeeze(1)  # BATCH x 1 x H x W => BATCH x H x W
    # Squeeze the labels as well (as the earlier definition does); otherwise
    # BATCH x 1 x H x W labels broadcast against the squeezed outputs into a
    # BATCH x BATCH x H x W tensor and the sums below are wrong.
    labels = labels.squeeze(1)

    intersection = (outputs & labels).float().sum((1, 2))  # Zero if truth or prediction is empty
    union = (outputs | labels).float().sum((1, 2))         # Zero if both are empty

    iou = (intersection + SMOOTH) / (union + SMOOTH)  # Smoothed to avoid 0/0

    print(union, intersection, iou)

    thresholded = torch.clamp(20 * (iou - 0.5), 0, 10).ceil() / 10  # Equivalent to comparing against thresholds

    return thresholded.mean(), iou.mean()

# `pred` contains logits, so threshold the sigmoid probability, not the logit.
pred_mask = torch.sigmoid(pred) > 0.5
iou_pytorch(pred_mask, y>0.5), dice_pytorch(pred_mask, y>0.5)

