# Install Library

In [1]:
# !pip install -q tqdm
!pip install -q tensorboard
!pip install -Uq segmentation-models-pytorch
!pip install -q pytorch-lightning
!pip install -q monai
!pip install -q timm
!pip install -q transformers
!pip install -q torchmetrics
# !pip install -Uq openmim
# !mim install -q mmcv-full
# !pip install -q mmsegmentation

In [2]:
from __future__ import division, print_function

import collections
import datetime
import logging
import os
import random

from monai.networks import nets
import numpy as np
import segmentation_models_pytorch as smp
import torch
import torchmetrics
import torchvision.transforms.functional as F  # TODO Fにするのは, transforms.functionalなのか、nn.functionalなのか、両方しないのか
from PIL import Image
from matplotlib import pyplot as plt
import timm
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
import torchmetrics
from torchvision import transforms
from torchvision.utils import draw_segmentation_masks
from tqdm import tqdm
from transformers import ViTModel
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint, ModelSummary
from pytorch_lightning.loggers import TensorBoardLogger

Matplotlib is building the font cache; this may take a moment.


2023-03-06 23:11:13,579 - Created a temporary directory at /tmp/tmp12lgxvmp
2023-03-06 23:11:13,580 - Writing /tmp/tmp12lgxvmp/_remote_module_non_scriptable.py


# データの読み込み

In [3]:
def get_subject_image_idx(path: str) -> tuple:
    filename = os.path.splitext(os.path.basename(path))[0]
    subject, image_idx = map(int, filename.split('_')[:2])
    return subject, image_idx


def get_data_path():
    train_dir = '../input/ultrasound-nerve-segmentation/train'
    input_img_paths = []
    target_paths = []

    for filename in os.listdir(train_dir):
        if filename.endswith("mask.tif"):
            target_paths.append(os.path.join(train_dir, filename))
        elif filename.endswith(".tif"):
            input_img_paths.append(os.path.join(train_dir, filename))

    input_img_paths.sort(key=lambda x: get_subject_image_idx(x))
    target_paths.sort(key=lambda x: get_subject_image_idx(x))
    data_paths = [(input_img, target) for input_img, target in zip(input_img_paths, target_paths)]

    return data_paths

## データの確認

In [4]:
def in_nerve(target_path):
    target = Image.open(target_path)
    return target.getextrema()[1] > 0

num_data = len(get_data_path())
print('{} samples in total.'.format(num_data))

5635 samples in total.


## Datasetの作成

In [28]:
class UltrasoundNerveDataset(Dataset):
    """Ultrasound image and Nerve structure dataset return the PIL image."""

    def __init__(self, is_val: bool=None, val_stride: int=0, only_nerve_imgs: bool=False, balance: bool=False, transform=None):
        self.data_paths = get_data_path()
        random.Random(111).shuffle(self.data_paths)

        if only_nerve_imgs:
            self.data_paths = [(input_img, target) for input_img, target in self.data_paths if in_nerve(target)]
            assert self.data_paths
            
        if balance:
            positive_data_paths = []
            negative_data_paths = []
            for input_img, target in self.data_paths:
                if in_nerve(target):
                    positive_data_paths.append((input_img, target))
                else:
                    negative_data_paths.append((input_img, target))
            self.data_paths = positive_data_paths + negative_data_paths[:len(positive_data_paths)]
            assert self.data_paths

        if is_val:
            assert val_stride > 0, val_stride
            self.data_paths = self.data_paths[::val_stride]
            assert self.data_paths
        elif val_stride > 0:
            del self.data_paths[::val_stride]
            assert self.data_paths

        print("{} {} samples".format(len(self.data_paths), "validation" if is_val else "training"))
        self.transform = transform

    def __len__(self):
        return len(self.data_paths)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        input_img_path, target_path = self.data_paths[idx]
        img = Image.open(input_img_path)
        target = Image.open(target_path)

        sample = {'img': img, 'target': target}

        if self.transform:
            sample = self.transform(sample)

        return sample

## 前処理

In [6]:
# ToTensor
# - Image: Tensorにして、MaxMinScale
# - Target: そのままTensorにする

class PILToTensor:
    def __call__(self, sample):
        img_pil, target = sample['img'], sample['target']
        img = F.to_tensor(img_pil)
        img_vis = F.pil_to_tensor(img_pil)
        target = torch.as_tensor(np.array(target), dtype=torch.float64)
        target = torch.div(target, 255, rounding_mode='floor')
        sample = {'img': img, 'target': target, 'img_vis': img_vis}
        return sample

class Normalize:
    def __init__(self):
        # Mean and std are calculated with all data.
        self.mean = 0.3898
        self.std = 0.2219

    def __call__(self, sample):
        img = sample['img']
        img = F.normalize(img, self.mean, self.std)
        sample['img'] = img
        return sample


In [7]:
def select_transform():
    return transforms.Compose([
        PILToTensor(),
        Normalize(),
        ])


# PyTorch Lightning Data Module

In [8]:
class UltrasoundDataModule(pl.LightningDataModule):
    def __init__(self, only_nerve_imgs: bool=False, balance: bool=False, transform=None, batch_size=8):
        super().__init__()
        self.transform = transform
        if self.transform is None:
            self.transform = transforms.Compose([PILToTensor(), Normalize()])
        self.save_hyperparameters()

    def setup(self, stage):
        self.train_dataset = UltrasoundNerveDataset(is_val=False, val_stride=3, only_nerve_imgs=self.hparams.only_nerve_imgs, balance=self.hparams.balance, transform=self.transform)
        self.val_dataset = UltrasoundNerveDataset(is_val=True, val_stride=3, only_nerve_imgs=self.hparams.only_nerve_imgs, balance=self.hparams.balance, transform=self.transform)
        if stage == 'test':
            self.test_dataset = UltrasoundNerveDataset(transform=self.transform)
        if stage == 'fit':
            self.predict_dataset = UltrasoundNerveDataset(transform=self.transform)
    
    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.hparams.batch_size, shuffle=True)
    
    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.hparams.batch_size, shuffle=True)
    
    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.hparams.batch_size, shuffle=True)

    def predict_dataloader(self):
        return DataLoader(self.predict_dataset, batch_size=self.hparams.batch_size, shuffle=True)


## Create Model

### Model for debug

In [9]:
class ModelForDebug(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 64, 3, padding="same", padding_mode="zeros")
        self.conv2 = nn.Conv2d(64, 1, 3, padding="same", padding_mode="zeros")
        self.batch_norm = nn.BatchNorm2d(64)


    def forward(self, x):
        out = self.batch_norm(self.conv1(x))
        out = torch.sigmoid(self.conv2(out))
        return out


### UNet with Batch Normalization

In [10]:
class UNetBatchNormConvBlock(nn.Module):
    def __init__(self, in_size, out_size, kernel_size=3, activation=torch.relu):
        super(UNetBatchNormConvBlock, self).__init__()
        self.batch_norm = nn.BatchNorm2d(out_size)
        self.batch_norm2 = nn.BatchNorm2d(out_size)
        self.conv = nn.Conv2d(in_size, out_size, kernel_size)
        self.conv2 = nn.Conv2d(out_size, out_size, kernel_size)
        self.activation = activation

    def forward(self, x):
        out = self.activation(self.batch_norm(self.conv(x)))
        out = self.activation(self.batch_norm2(self.conv2(out)))

        return out


class UNetBatchNormUpBlock(nn.Module):
    def __init__(self, in_size, out_size, kernel_size=3, activation=torch.relu):
        super(UNetBatchNormUpBlock, self).__init__()
        self.up = nn.ConvTranspose2d(in_size, out_size, 2, stride=2)
        self.conv = nn.Conv2d(in_size, out_size, kernel_size)
        self.conv2 = nn.Conv2d(out_size, out_size, kernel_size)
        self.batch_norm = nn.BatchNorm2d(out_size)
        self.batch_norm2 = nn.BatchNorm2d(out_size)
        self.activation = activation

    def center_crop(self, layer: torch.Tensor, target_height, target_width):
        batch_size, n_channels, layer_height, layer_width = layer.size()
        height_remove = (layer_height - target_height) // 2
        width_remove = (layer_width - target_width) // 2
        return layer[:, :, height_remove:(height_remove + target_height), width_remove:(width_remove + target_width)]

    def forward(self, x, bridge):
        up = self.up(x)
        crop1 = self.center_crop(bridge, up.size()[2], up.size()[3])
        out = torch.cat([up, crop1], 1)
        out = self.activation(self.batch_norm(self.conv(out)))
        out = self.activation(self.batch_norm2(self.conv2(out)))

        return out


class UNetBatchNorm(nn.Module):
    def __init__(self, imsize):
        super(UNetBatchNorm, self).__init__()
        self.imsize = imsize

        self.activation = torch.relu

        self.pool1 = nn.MaxPool2d(2)
        self.pool2 = nn.MaxPool2d(2)
        self.pool3 = nn.MaxPool2d(2)

        self.conv_block1_64 = UNetBatchNormConvBlock(1, 64)
        self.conv_block64_128 = UNetBatchNormConvBlock(64, 128)
        self.conv_block128_256 = UNetBatchNormConvBlock(128, 256)
        self.conv_block256_512 = UNetBatchNormConvBlock(256, 512)

        self.up_block512_256 = UNetBatchNormUpBlock(512, 256)
        self.up_block256_128 = UNetBatchNormUpBlock(256, 128)
        self.up_block128_64 = UNetBatchNormUpBlock(128, 64)

        self.last = nn.Conv2d(64, 1, 1)
        self.pad = nn.ConstantPad2d(44, 0)

    def forward(self, x):
        block1 = self.conv_block1_64(x)
        pool1 = self.pool1(block1)

        block2 = self.conv_block64_128(pool1)
        pool2 = self.pool2(block2)

        block3 = self.conv_block128_256(pool2)
        pool3 = self.pool3(block3)

        block4 = self.conv_block256_512(pool3)

        up1 = self.up_block512_256(block4, block3)

        up2 = self.up_block256_128(up1, block2)

        up3 = self.up_block128_64(up2, block1)

        return self.pad(torch.sigmoid(self.last(up3)))

### Segmentation Models Pytorch

In [11]:
class UNetPlusPlus(nn.Module):
    def __init__(self, imsize):
        super(UNetPlusPlus, self).__init__()
        self.imsize = imsize
        self.pad = nn.ConstantPad2d(14, 0)
        self.model = smp.UnetPlusPlus(
            encoder_name="resnet34",
            encoder_weights="imagenet",
            in_channels=1,
            classes=1,
            activation='sigmoid'
        )

    def forward(self, x):
        x = self.pad(x)
        out =  self.model(x)
        out = F.crop(out, 14, 14, 420, 580)
        return out

class PSPNet(nn.Module):
    def __init__(self, imsize):
        super(PSPNet, self).__init__()
        self.imsize = imsize
        self.pad = nn.ConstantPad2d(14, 0)
        self.model = smp.PSPNet(
            encoder_name="resnet34",
            encoder_weights="imagenet",
            in_channels=1,
            classes=1,
            activation='sigmoid'
        )

    def forward(self, x):
        x = self.pad(x)
        out =  self.model(x)
        out = F.crop(out, 14, 14, 420, 580)
        return out

class LinkNet(nn.Module):
    def __init__(self, imsize):
            super(LinkNet, self).__init__()
            self.imsize = imsize
            self.pad = nn.ConstantPad2d(14, 0)
            self.model = smp.Linknet(
                encoder_name="resnet34",
                encoder_weights="imagenet",
                in_channels=1,
                classes=1,
                activation='sigmoid'
            )

    def forward(self, x):
        x = self.pad(x)
        out =  self.model(x)
        out = F.crop(out, 14, 14, 420, 580)
        return out

class MAnet(nn.Module):
    def __init__(self, imsize):
        super(MAnet, self).__init__()
        self.imsize = imsize
        self.pad = nn.ConstantPad2d(14, 0)
        self.model = smp.MAnet(
            encoder_name="resnet34",
            encoder_weights="imagenet",
            in_channels=1,
            classes=1,
            activation='sigmoid'
        )

    def forward(self, x):
        x = self.pad(x)
        out =  self.model(x)
        out = F.crop(out, 14, 14, 420, 580)
        return out

class DeepLabV3Plus(nn.Module):
    def __init__(self, imsize):
        super(DeepLabV3Plus, self).__init__()
        self.imsize = imsize
        self.pad = nn.ConstantPad2d(14, 0)
        self.model = smp.DeepLabV3Plus(
            encoder_name="resnet34",
            encoder_weights="imagenet",
            in_channels=1,
            classes=1,
            activation='sigmoid'
        )

    def forward(self, x):
        x = self.pad(x)
        out =  self.model(x)
        out = F.crop(out, 14, 14, 420, 580)
        return out

### MONAI

In [12]:
# def vit():
#     model = nets.ViT(in_channels=1, img_size=420*580, patch_size=8)
#     return model

### Timm

In [13]:
# def swin():
#     return timm.create_model('swin_base_patch4_window7_448')

## モデルを選択する

In [14]:
def select_model(model_str: str):
    model = None
    model_str = model_str.lower()
    if model_str == "debug":
        model = ModelForDebug()
    elif model_str == 'unet_bn':
        model = UNetBatchNorm(imsize=420*580)
    elif model_str == 'unet_plus_plus':
        model = UNetPlusPlus(imsize=420*580)
    elif model_str == 'manet':
        model = MAnet(imsize=420*580)
    elif model_str == 'pspnet':
        model = PSPNet(imsize=420*580)
    elif model_str == 'linknet':
        model = LinkNet(imsize=420*580)
    elif model_str == 'deep_lab_v3_plus':
        model = DeepLabV3Plus(imsize=420*580)
    # TODO ここから下は未実装
    # elif model_str == 'vit':
    #     model = Vit()
    # elif model_str == 'swin':
    #     model = swin()

    return model

# 損失関数と指標定義

In [15]:
class DiceLoss():
    def __call__(self, prediction, target):
        loss = self.calculate_loss(prediction, target)
        return loss

    @staticmethod
    def calculate_loss(prediction, target):
        epsilon = 1

        dice_prediction = prediction.sum(dim=[1, 2, 3])    # 要素iには、バッチiの、各ピクセルについて陽性である確率の総和
        dice_label = target.sum(dim=[1, 2, 3])    # 要素iには、バッチiの陽性ラベルの数
        dice_correct = (target * prediction).sum(dim=[1, 2, 3])
        
        dice_ratio = (2 * dice_correct + epsilon) / (dice_prediction + dice_label + epsilon)

        return 1 - dice_ratio
    
    
class SegmentationMetrics():

    def __call__(self, prediction, target):
        metrics_dict = self.calculate_metrics(prediction, target)
        return metrics_dict

    @staticmethod
    def calculate_metrics(prediction, target):
        epsilon = 1e-7
        prediction = (prediction > 0.5).to(torch.float32)
        dice_prediction = prediction.sum(dim=[1, 2, 3])
        dice_label = target.sum(dim=[1, 2, 3])
        dice_correct = (prediction * target).sum(dim=[1, 2, 3])
      
        dice_coefficient = (2 * dice_correct + epsilon) / (dice_prediction + dice_label + epsilon)
  
        precision = (dice_correct + epsilon) / (dice_prediction + epsilon)
        recall = (dice_correct + epsilon) / (dice_label + epsilon)
        f1 = 2 * (precision * recall) / (precision + recall)
        
        metrics_dict = {
            'dice_coefficient': dice_coefficient.mean(),
            'precision': precision.mean(),
            'recall': recall.mean(),
            'f1': f1.mean()
        }

        return metrics_dict

# PyTorch Lightning Module

In [22]:
class LitSegmentationModule(pl.LightningModule):
    def __init__(self, model, lr=0.01):
        super().__init__()
        self.model = select_model(model_str=model)
        self.lr = lr
        self.save_hyperparameters()
        self.example_input_array = torch.Tensor(8, 1, 420, 580)
        self.loss_func = self.configure_loss_function()
        self.metrics_calculater = self.configure_metrics()
        self.training_step_outputs = {}
        self.validation_step_outputs = {}
        
    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        img, target = batch['img'], batch['target']
        target = target.view(img.shape)
        prediction = self.model(img).view(img.shape)
        loss = self.loss_func(prediction, target)
        loss_mean = loss.mean()
        self.log('train_loss', loss_mean, on_epoch=True)
        output_dict = self.draw_mask_on_image(batch, prediction, target)
        output_dict['loss'] = loss
        self.training_step_outputs = output_dict

        return loss_mean
    
    
    def validation_step(self, batch, batch_idx):
        img, target = batch['img'], batch['target']
        target = target.view(img.shape)
        prediction = self.model(img)
        loss = self.loss_func(prediction, target)
        loss_mean = loss.mean()
        metrics = self.metrics_calculater(prediction, target)
        metrics['val_loss'] = loss_mean
        self.log_dict(metrics, on_epoch=True)
        output_dict = self.draw_mask_on_image(batch, prediction, target)
        output_dict['loss'] = loss
        self.validation_step_outputs = output_dict
        return loss_mean

    def validation_epoch_end(self, validation_step_outputs):
        if not self.training_step_outputs:
            return
        training_outputs = self.training_step_outputs
        validation_outputs = self.validation_step_outputs
        logger = self.logger.experiment
        
        epochs = self.current_epoch
        
        for mode_str, output in zip(['training', 'val'], [training_outputs, validation_outputs]):
            self.add_image(mode_str, output, epochs, logger)
        
    def draw_mask_on_image(self, batch, predictions, target):
        self.model.eval()

        # visualize image
        imgs_vis = batch['img_vis'].cpu()

        # target mask
        targets = torch.squeeze(target, 1)
        targets = targets.to(torch.bool).cpu()

        # output mask
        predictions = predictions >= 0.5
        predictions = torch.squeeze(predictions, 1).cpu()

        # input image
        rg = torch.zeros(imgs_vis.shape[0:1] +  (2,) + imgs_vis.shape[2:4], dtype=torch.uint8).cpu()
        imgs_rgb = torch.cat((rg, imgs_vis), 1).to(dtype=torch.uint8)

        # mask
        output_with_mask = [draw_segmentation_masks(img_rgb, masks=output, alpha=.3, colors="#FFFFFF") for img_rgb, output in zip(imgs_rgb, predictions)]
        target_with_mask = [draw_segmentation_masks(img_rgb, masks=target, alpha=.3, colors="#FFFFFF") for img_rgb, target in zip(imgs_rgb, targets)]
        
        output_dict = {'output': output_with_mask, 'target': target_with_mask}
        return output_dict

    @staticmethod
    def add_image(mode_str, output, epochs, logger):
        for image_ndx, (output_mask, target_mask, loss) in enumerate(zip(output['output'], output['target'], output['loss'])):

            logger.add_image(f'{mode_str}/E{epochs}_#{image_ndx}_prediction_{loss.item():.3f}loss', output_mask, dataformats='CHW')
            logger.add_image(f'{mode_str}/E{epochs}_#{image_ndx}_label', target_mask, dataformats='CHW')
            logger.flush()

    
    def configure_loss_function(self):
#         loss_func = torchmetrics.Dice(zero_division=1e-7, num_classes=2, average=None)
        loss_func = DiceLoss()
        return loss_func

    def configure_metrics(self):
        metrics_calculater = SegmentationMetrics()
        return metrics_calculater

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

# Execution

In [31]:
# root_dir
default_root_dir = os.path.join(os.path.dirname(os.path.abspath(os.getcwd())), 'working')
log_dir = os.path.join(default_root_dir, 'logs')

# Reproducibility
seed_everything = pl.seed_everything(42, workers=True)

# Dataloader
only_nerve_imgs = False
balance = True
data_module = UltrasoundDataModule()

# Callbacks
early_stopping = EarlyStopping(monitor='val_loss', verbose=True, patience=3)
model_checkpoint = ModelCheckpoint(monitor='val_loss', verbose=True)
model_summary = ModelSummary(max_depth=5)

callbacks = [early_stopping, model_checkpoint, model_summary]

# Model
# model = "debug", "Init", "UNet", "UNet_Pad", "UNet_BN", "UNet_with_library", "MANet", "Vit", "Swin"
model = LitSegmentationModule(model="Deep_Lab_v3_plus")

# Logger
model_name = type(model.model).__name__
data_mode = 'only_nerve' if only_nerve_imgs else 'balance' if balance else 'full_data'
JST = datetime.timezone(datetime.timedelta(hours=+9), 'JST')
now_str = datetime.datetime.now(JST).strftime('%Y%m%d_%H.%M')
log_name = '_'.join([now_str, model_name, data_mode, 'logs'])
tensorboard_logger = TensorBoardLogger(save_dir=log_dir, name=log_name)


# Training

# TODO When submit the final results, change benchmark=False, and deterministic=True to ensure reproducibility.
trainer = pl.Trainer(fast_dev_run=False, devices='auto', accelerator='auto', default_root_dir=default_root_dir, logger=tensorboard_logger,  callbacks=callbacks, max_epochs=40, check_val_every_n_epoch=3, auto_scale_batch_size='power', auto_lr_find=True, benchmark=False, num_sanity_val_steps=0, enable_progress_bar=True)
trainer.tune(model, data_module)
trainer.fit(model, data_module)

2023-03-07 00:25:14,704 - Global seed set to 42
2023-03-07 00:25:15,108 - Trainer already configured with model summary callbacks: [<class 'pytorch_lightning.callbacks.model_summary.ModelSummary'>]. Skipping setting a default `ModelSummary` callback.
2023-03-07 00:25:15,123 - GPU available: True (cuda), used: True
2023-03-07 00:25:15,124 - TPU available: False, using: 0 TPU cores
2023-03-07 00:25:15,124 - IPU available: False, using: 0 IPUs
2023-03-07 00:25:15,125 - HPU available: False, using: 0 HPUs
2023-03-07 00:25:15,127 - Missing logger folder: /home/ec2-user/SageMaker/working/logs/20230307_09.25_DeepLabV3Plus_balance_logs
<__main__.UltrasoundNerveDataset object at 0x7f5741e36f70>: 3756 training samples
<__main__.UltrasoundNerveDataset object at 0x7f5741e365e0>: 1879 validation samples
<__main__.UltrasoundNerveDataset object at 0x7f5741e365b0>: 5635 training samples
2023-03-07 00:25:15,314 - LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
2023-03-07 00:25:16,061 - `Trainer.fit` stopped:

Finding best initial lr:   0%|          | 0/100 [00:00<?, ?it/s]

2023-03-07 00:28:32,246 - `Trainer.fit` stopped: `max_steps=100` reached.
2023-03-07 00:28:32,248 - Learning rate set to 0.0009120108393559097
2023-03-07 00:28:32,249 - Restoring states from the checkpoint path at /home/ec2-user/SageMaker/working/.lr_find_61a8b2a1-501b-4c9e-856e-44e79a140070.ckpt
2023-03-07 00:28:32,464 - Restored all states from the checkpoint file at /home/ec2-user/SageMaker/working/.lr_find_61a8b2a1-501b-4c9e-856e-44e79a140070.ckpt
<__main__.UltrasoundNerveDataset object at 0x7f57321908e0>: 3756 training samples
<__main__.UltrasoundNerveDataset object at 0x7f573869b3a0>: 1879 validation samples
<__main__.UltrasoundNerveDataset object at 0x7f583cb3b130>: 5635 training samples
2023-03-07 00:28:32,743 - LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
2023-03-07 00:28:32,773 - 
   | Name                                       | Type                 | Params | In sizes                                                                                                       | Out si

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

2023-03-07 00:39:36,137 - Metric val_loss improved. New best score: 0.418
2023-03-07 00:39:36,140 - Epoch 2, global step 354: 'val_loss' reached 0.41771 (best 0.41771), saving model to '/home/ec2-user/SageMaker/working/logs/20230307_09.25_DeepLabV3Plus_balance_logs/version_0/checkpoints/epoch=2-step=354.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

2023-03-07 00:50:24,771 - Epoch 5, global step 708: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

2023-03-07 01:01:58,544 - Epoch 8, global step 1062: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

2023-03-07 01:13:29,920 - Monitored metric val_loss did not improve in the last 3 records. Best score: 0.418. Signaling Trainer to stop.
2023-03-07 01:13:29,924 - Epoch 11, global step 1416: 'val_loss' was not in top 1


In [None]:
timm.list_models('*swin*')