In [49]:
import torch
import torch.nn as nn
import numpy as np
from data_loader import data_load
from unet_pytorch import build_unet
from ramps import get_current_consistency_weight, update_ema_variables
from glob import glob
import tensorflow as tf
from time import time
from datetime import datetime
from monai.data import decollate_batch
from monai.losses import DiceLoss
from monai.metrics import DiceMetric
from monai.transforms import (
    Compose,
    AsDiscrete,
    EnsureType,
)
import os
from torch.utils.tensorboard import SummaryWriter

# TensorBoard writer shared by the training loop (per-step train loss and
# per-epoch validation Dice). No explicit log directory is given, so the
# writer uses SummaryWriter's default run directory.
writer = SummaryWriter(log_dir=None)
def load_dataset(batch, folder, label=True, buffer_size=1000):
    """Build a dataset from the ``*.gz`` record shards found in ``folder``.

    Every parsed feature is a float32 ``tf.io.FixedLenFeature``.

    Args:
        batch: batch size forwarded to ``data_load``.
        folder: directory searched (non-recursively) for ``*.gz`` files.
        label: if True, parse ``BANDS + TARGETS`` at ``KERNEL_SHAPE`` and return
            ``get_training_dataset()``. If False, parse only ``BANDS`` at
            ``BUFFERED_SHAPE`` and return ``get_pridiction_dataset()``.
        buffer_size: forwarded to ``data_load``. It appears to be the shuffle
            buffer size; the run log reports "Filling up shuffle buffer ... of 1000".

    Returns:
        The dataset object produced by ``data_load``. Callers iterate it with
        ``as_numpy_iterator()``.
    """
    if label:
        feature_names = BANDS + TARGETS
        patch_shape = KERNEL_SHAPE
    else:
        # Unlabeled patches are stored with the extra KERNEL_BUFFER margin.
        feature_names = BANDS
        patch_shape = BUFFERED_SHAPE

    # One independent FixedLenFeature per feature name.
    description = {
        name: tf.io.FixedLenFeature(shape=patch_shape, dtype=tf.float32)
        for name in feature_names
    }

    loader = data_load(
        glob(f"{folder}/*.gz"),
        BANDS,
        description,
        response=TARGETS,
        batch_size=batch,
        buffer_size=buffer_size,
    )
    if not label:
        return loader.get_pridiction_dataset()
    return loader.get_training_dataset()


# %% configuration
# Input channels fed to the network (order matters) and the label band.
BANDS = ["blue", "green", "red", "nir", "swir1", "swir2", "ndvi", "nirv"]
TARGETS = ["cropland"]
NCLASS = 2

# Labeled patches are KERNEL_SHAPE. Unlabeled patches carry an extra
# KERNEL_BUFFER margin (half on each side), which the training loop crops
# away with [X_BUFFER:X_BUFFERED, Y_BUFFER:Y_BUFFERED].
KERNEL_SHAPE = [256, 256]
KERNEL_BUFFER = [128, 128]
BUFFERED_SHAPE = list(map(sum, zip(KERNEL_SHAPE, KERNEL_BUFFER)))
X_BUFFER, Y_BUFFER = (margin // 2 for margin in KERNEL_BUFFER)
X_BUFFERED = X_BUFFER + KERNEL_SHAPE[0]
Y_BUFFERED = Y_BUFFER + KERNEL_SHAPE[1]

# Output location for checkpoints; run_time tags this run's checkpoint name.
model_folder = "/bess23/huaize/semi-supervised/models/"
run_time = datetime.now().strftime("%m_%d_%H_%M_%S")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch = 128


# %% data
# Three pipelines, built in this order:
#   dataset      - labeled training split (batches indexed as [0]=inputs, [1]=labels)
#   val_dataset  - labeled validation split
#   test_dataset - despite its name, the *unlabeled* pool consumed by the
#                  mean-teacher consistency loss during training
dataset, val_dataset, test_dataset = (
    load_dataset(
        batch, "/bess23/huaize/semi-supervised/data/labeled/train/", label=True
    ),
    load_dataset(
        batch, "/bess23/huaize/semi-supervised/data/labeled/valid", label=True
    ),
    load_dataset(
        batch, "/bess23/huaize/semi-supervised/data/unlabeled", label=False
    ),
)

# %% models
# Student network (optimized by `opt`) and mean-teacher network (never given
# to an optimizer; it only changes through update_ema_variables).
# Both are placed on `device` rather than via `.cuda()`, so the notebook also
# runs on CPU-only hosts, where `device` falls back to CPU. On CPU,
# DataParallel simply wraps the module, and `.module` remains available for
# checkpointing.
model = nn.DataParallel(build_unet(len(BANDS), NCLASS).to(device))
ema_model = nn.DataParallel(build_unet(len(BANDS), NCLASS).to(device))

# %% optimization
# Schedule: the consistency loss is switched on once epoch >= MeanTeacherEpoch.
max_epochs, MeanTeacherEpoch = 1000, 50
lr = 3e-4

# Only the student's parameters are optimized; the teacher follows via EMA.
opt = torch.optim.Adam(params=model.parameters(), lr=lr)
# Supervised loss: Dice on softmaxed logits against one-hot-encoded labels.
loss_function = DiceLoss(softmax=True, to_onehot_y=True)
# Validation metric: mean Dice over foreground classes only.
dice_metric = DiceMetric(reduction="mean", include_background=False)

# %% train
# max_epochs and MeanTeacherEpoch are defined once, next to lr in the
# optimizer section above. Assigning them a second time here would silently
# override any edit made there.
val_interval = 1  # run validation every `val_interval` epochs
best_metric = -1
best_metric_epoch = -1
iter_num = 0  # global optimizer-step counter (drives consistency ramp and EMA)
epoch_loss_values = []
metric_values = []
# Dice post-processing: predictions argmax -> one-hot, labels -> one-hot.
post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=NCLASS)])
post_label = Compose([EnsureType(), AsDiscrete(to_onehot=NCLASS)])

# %% training loop
# torch.save below writes into model_folder. Create it up front so the first
# "best" checkpoint does not fail on a fresh machine.
os.makedirs(model_folder, exist_ok=True)

for epoch in range(max_epochs):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")
    start_time = time()
    model.train()
    epoch_loss = 0
    step = 0
    # The consistency term only contributes after the warm-up phase. Before
    # that, skip the unlabeled and teacher forward passes whose results would
    # be discarded.
    use_consistency = epoch >= MeanTeacherEpoch
    train_loader = dataset.as_numpy_iterator()
    unlabeled_train_loader = test_dataset.as_numpy_iterator()
    # zip() stops at the shorter iterator, so the epoch length is bounded by
    # whichever of the labeled and unlabeled datasets runs out first.
    for labeled_batch, unlabeled_batch in zip(train_loader, unlabeled_train_loader):
        step += 1
        labeled_inputs = torch.tensor(labeled_batch[0]).to(device)
        labels = torch.tensor(labeled_batch[1]).to(device)
        opt.zero_grad()

        outputs = model(labeled_inputs)
        supervised_loss = loss_function(outputs, labels)

        if use_consistency:
            # Unlabeled patches carry a KERNEL_BUFFER margin. Keep the central
            # KERNEL_SHAPE window (assumes channels-first (B, C, H, W) batches).
            unlabeled_crop = unlabeled_batch[
                :, :, X_BUFFER:X_BUFFERED, Y_BUFFER:Y_BUFFERED
            ]
            unlabeled_inputs = torch.tensor(unlabeled_crop).to(device)
            # Small clipped Gaussian perturbation applied to the teacher's inputs.
            noise_labeled = torch.clamp(
                torch.randn_like(labeled_inputs) * 0.1, -0.2, 0.2
            )
            noise_unlabeled = torch.clamp(
                torch.randn_like(unlabeled_inputs) * 0.1, -0.2, 0.2
            )

            # Student predictions must keep their autograd graph. Under
            # torch.no_grad() the consistency term would contribute no
            # gradient and never train the student.
            soft_out = torch.softmax(outputs, dim=1)
            soft_unlabeled = torch.softmax(model(unlabeled_inputs), dim=1)

            # Teacher (EMA) predictions serve only as targets: no gradient.
            with torch.no_grad():
                soft_aug = torch.softmax(
                    ema_model(labeled_inputs + noise_labeled), dim=1
                )
                soft_unlabeled_aug = torch.softmax(
                    ema_model(unlabeled_inputs + noise_unlabeled), dim=1
                )

            consistency_loss = torch.mean((soft_out - soft_aug) ** 2) + torch.mean(
                (soft_unlabeled - soft_unlabeled_aug) ** 2
            )
        else:
            consistency_loss = 0.0

        consistency_weight = get_current_consistency_weight(iter_num // 150)
        iter_num += 1
        loss = supervised_loss + consistency_weight * consistency_loss
        loss.backward()
        opt.step()
        update_ema_variables(model, ema_model, 0.99, iter_num)

        loss_value = loss.item()
        epoch_loss += loss_value
        print(f"train_loss: {loss_value:.4f}")
        # Log per optimizer step. Using `epoch` as the step would put every
        # batch of an epoch on the same x value.
        writer.add_scalar("Loss/train", loss_value, iter_num)

    if step == 0:
        raise RuntimeError(
            "no training batches: the labeled or unlabeled dataset is empty"
        )
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)

    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            for val_data in val_dataset.as_numpy_iterator():
                val_inputs = torch.tensor(val_data[0]).to(device)
                val_labels = torch.tensor(val_data[1]).to(device)
                val_outputs = model(val_inputs)
                val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
                val_labels = [post_label(i) for i in decollate_batch(val_labels)]
                # accumulate the metric for the current batch
                dice_metric(y_pred=val_outputs, y=val_labels)

            # aggregate the final mean Dice over the whole validation set
            metric = dice_metric.aggregate().item()
            print(f"val dice: {metric}")
            writer.add_scalar("val dice", metric, epoch)
            # reset the accumulator for the next validation round
            dice_metric.reset()

        metric_values.append(metric)
        if metric > best_metric:
            best_metric = metric
            best_metric_epoch = epoch + 1
            torch.save(
                model.module.state_dict(),
                os.path.join(model_folder, f"best_{run_time}.pth"),
            )
            print("saved new best metric model")
        print(
            f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
            f"\nbest mean dice: {best_metric:.4f} "
            f"at epoch: {best_metric_epoch}"
        )

    print(f"epoch time = {time() - start_time}")
    # Make this epoch's scalars visible in TensorBoard even if the run is
    # interrupted later.
    writer.flush()

writer.close()

----------
epoch 1/1000


2023-02-20 16:50:25.799313: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:392] Filling up shuffle buffer (this may take a while): 863 of 1000
2023-02-20 16:50:27.537125: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:417] Shuffle buffer filled.


train_loss: 0.6715
train_loss: 0.6638
epoch 1 average loss: 0.6677


2023-02-20 16:50:45.397752: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:392] Filling up shuffle buffer (this may take a while): 841 of 1000
2023-02-20 16:50:47.705159: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:417] Shuffle buffer filled.


val dice: 0.06324220448732376
saved new best metric model
current epoch: 1 current mean dice: 0.0632
best mean dice: 0.0632 at epoch: 1
epoch time = 66.15352702140808
----------
epoch 2/1000


2023-02-20 16:51:31.927782: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:392] Filling up shuffle buffer (this may take a while): 872 of 1000
2023-02-20 16:51:33.555652: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:417] Shuffle buffer filled.


train_loss: 0.6588
train_loss: 0.6563
epoch 2 average loss: 0.6575


2023-02-20 16:51:51.404000: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:392] Filling up shuffle buffer (this may take a while): 864 of 1000
2023-02-20 16:51:53.183581: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:417] Shuffle buffer filled.


val dice: 0.06324220448732376
current epoch: 2 current mean dice: 0.0632
best mean dice: 0.0632 at epoch: 1
epoch time = 64.47439932823181
----------
epoch 3/1000


2023-02-20 16:52:36.407489: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:392] Filling up shuffle buffer (this may take a while): 881 of 1000
2023-02-20 16:52:37.913131: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:417] Shuffle buffer filled.


train_loss: 0.6557
train_loss: 0.6540
epoch 3 average loss: 0.6549


2023-02-20 16:52:55.873923: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:392] Filling up shuffle buffer (this may take a while): 863 of 1000
2023-02-20 16:52:57.655387: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:417] Shuffle buffer filled.


val dice: 0.0443502813577652
current epoch: 3 current mean dice: 0.0444
best mean dice: 0.0632 at epoch: 1
epoch time = 64.49993944168091
----------
epoch 4/1000


2023-02-20 16:53:40.901301: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:392] Filling up shuffle buffer (this may take a while): 885 of 1000
2023-02-20 16:53:42.337597: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:417] Shuffle buffer filled.


train_loss: 0.6547
train_loss: 0.6472
epoch 4 average loss: 0.6509


2023-02-20 16:54:00.059807: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:392] Filling up shuffle buffer (this may take a while): 842 of 1000
2023-02-20 16:54:02.214330: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:417] Shuffle buffer filled.


val dice: 0.0004262778384145349
current epoch: 4 current mean dice: 0.0004
best mean dice: 0.0632 at epoch: 1
epoch time = 64.15835094451904
----------
epoch 5/1000


2023-02-20 16:54:45.059036: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:392] Filling up shuffle buffer (this may take a while): 873 of 1000
2023-02-20 16:54:46.667772: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:417] Shuffle buffer filled.


train_loss: 0.6501
train_loss: 0.6384
epoch 5 average loss: 0.6443


2023-02-20 16:55:04.423478: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:392] Filling up shuffle buffer (this may take a while): 858 of 1000
2023-02-20 16:55:06.269858: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:417] Shuffle buffer filled.


val dice: 0.0007091833977028728
current epoch: 5 current mean dice: 0.0007
best mean dice: 0.0632 at epoch: 1
epoch time = 64.1013195514679
----------
epoch 6/1000


2023-02-20 16:55:49.166468: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:392] Filling up shuffle buffer (this may take a while): 886 of 1000
2023-02-20 16:55:50.590364: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:417] Shuffle buffer filled.


train_loss: 0.6421
train_loss: 0.6424
epoch 6 average loss: 0.6423


2023-02-20 16:56:08.470297: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:392] Filling up shuffle buffer (this may take a while): 870 of 1000
2023-02-20 16:56:10.138322: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:417] Shuffle buffer filled.


val dice: 0.0030716643668711185
current epoch: 6 current mean dice: 0.0031
best mean dice: 0.0632 at epoch: 1
epoch time = 63.92766094207764
----------
epoch 7/1000


2023-02-20 16:56:53.095886: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:392] Filling up shuffle buffer (this may take a while): 866 of 1000
2023-02-20 16:56:54.814008: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:417] Shuffle buffer filled.


train_loss: 0.6407
train_loss: 0.6421
epoch 7 average loss: 0.6414


2023-02-20 16:57:12.697591: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:392] Filling up shuffle buffer (this may take a while): 855 of 1000
2023-02-20 16:57:14.596703: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:417] Shuffle buffer filled.


val dice: 0.010214196518063545
current epoch: 7 current mean dice: 0.0102
best mean dice: 0.0632 at epoch: 1
epoch time = 64.32542490959167
----------
epoch 8/1000


2023-02-20 16:57:57.425751: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:392] Filling up shuffle buffer (this may take a while): 885 of 1000
2023-02-20 16:57:58.865437: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:417] Shuffle buffer filled.


train_loss: 0.6382
train_loss: 0.6343
epoch 8 average loss: 0.6363


2023-02-20 16:58:16.680343: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:392] Filling up shuffle buffer (this may take a while): 867 of 1000
2023-02-20 16:58:18.403946: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:417] Shuffle buffer filled.


val dice: 0.016067251563072205
current epoch: 8 current mean dice: 0.0161
best mean dice: 0.0632 at epoch: 1
epoch time = 63.746607065200806
----------
epoch 9/1000


2023-02-20 16:59:01.174642: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:392] Filling up shuffle buffer (this may take a while): 880 of 1000
2023-02-20 16:59:02.690624: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:417] Shuffle buffer filled.


train_loss: 0.6335
train_loss: 0.6359
epoch 9 average loss: 0.6347


2023-02-20 16:59:20.567136: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:392] Filling up shuffle buffer (this may take a while): 842 of 1000
2023-02-20 16:59:22.646092: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:417] Shuffle buffer filled.


val dice: 0.02213538996875286
current epoch: 9 current mean dice: 0.0221
best mean dice: 0.0632 at epoch: 1
epoch time = 64.9588794708252
----------
epoch 10/1000


2023-02-20 17:00:06.118672: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:392] Filling up shuffle buffer (this may take a while): 865 of 1000
2023-02-20 17:00:07.826002: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:417] Shuffle buffer filled.


train_loss: 0.6340
train_loss: 0.6298
epoch 10 average loss: 0.6319


2023-02-20 17:00:25.665590: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:392] Filling up shuffle buffer (this may take a while): 868 of 1000
2023-02-20 17:00:27.387951: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:417] Shuffle buffer filled.


val dice: 0.03220274671912193
current epoch: 10 current mean dice: 0.0322
best mean dice: 0.0632 at epoch: 1
epoch time = 64.24831557273865
----------
epoch 11/1000


2023-02-20 17:01:10.371837: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:392] Filling up shuffle buffer (this may take a while): 862 of 1000
2023-02-20 17:01:12.131059: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:417] Shuffle buffer filled.


train_loss: 0.6301
train_loss: 0.6288
epoch 11 average loss: 0.6295


2023-02-20 17:01:30.067050: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:392] Filling up shuffle buffer (this may take a while): 861 of 1000
2023-02-20 17:01:31.886291: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:417] Shuffle buffer filled.


val dice: 0.04833564534783363
current epoch: 11 current mean dice: 0.0483
best mean dice: 0.0632 at epoch: 1
epoch time = 63.9665162563324
----------
epoch 12/1000


2023-02-20 17:02:14.329148: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:392] Filling up shuffle buffer (this may take a while): 868 of 1000
2023-02-20 17:02:16.038507: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:417] Shuffle buffer filled.


train_loss: 0.6250
train_loss: 0.6285
epoch 12 average loss: 0.6268


2023-02-20 17:02:33.939066: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:392] Filling up shuffle buffer (this may take a while): 856 of 1000
2023-02-20 17:02:35.823066: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:417] Shuffle buffer filled.


val dice: 0.05443393439054489
current epoch: 12 current mean dice: 0.0544
best mean dice: 0.0632 at epoch: 1
epoch time = 63.909377574920654
----------
epoch 13/1000


2023-02-20 17:03:18.246678: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:392] Filling up shuffle buffer (this may take a while): 842 of 1000
2023-02-20 17:03:20.304624: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:417] Shuffle buffer filled.


train_loss: 0.6277
train_loss: 0.6241
epoch 13 average loss: 0.6259


2023-02-20 17:03:38.379943: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:392] Filling up shuffle buffer (this may take a while): 855 of 1000
2023-02-20 17:03:40.263439: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:417] Shuffle buffer filled.


val dice: 0.06668917089700699
saved new best metric model
current epoch: 13 current mean dice: 0.0667
best mean dice: 0.0667 at epoch: 13
epoch time = 64.69853782653809
----------
epoch 14/1000


2023-02-20 17:04:22.953645: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:392] Filling up shuffle buffer (this may take a while): 876 of 1000
2023-02-20 17:04:24.544671: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:417] Shuffle buffer filled.


train_loss: 0.6273
train_loss: 0.6302
epoch 14 average loss: 0.6288


2023-02-20 17:04:42.446498: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:392] Filling up shuffle buffer (this may take a while): 860 of 1000
2023-02-20 17:04:44.266579: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:417] Shuffle buffer filled.


val dice: 0.06294019520282745
current epoch: 14 current mean dice: 0.0629
best mean dice: 0.0667 at epoch: 13
epoch time = 63.95479345321655
----------
epoch 15/1000


2023-02-20 17:05:26.909358: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:392] Filling up shuffle buffer (this may take a while): 885 of 1000
2023-02-20 17:05:28.357094: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:417] Shuffle buffer filled.


train_loss: 0.6239
train_loss: 0.6246
epoch 15 average loss: 0.6242


2023-02-20 17:05:46.164057: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:392] Filling up shuffle buffer (this may take a while): 853 of 1000
2023-02-20 17:05:48.104675: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:417] Shuffle buffer filled.


val dice: 0.06285203248262405
current epoch: 15 current mean dice: 0.0629
best mean dice: 0.0667 at epoch: 13
epoch time = 64.84436774253845
----------
epoch 16/1000


2023-02-20 17:06:31.758774: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:392] Filling up shuffle buffer (this may take a while): 844 of 1000
2023-02-20 17:06:33.742988: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:417] Shuffle buffer filled.


train_loss: 0.6238
train_loss: 0.6239
epoch 16 average loss: 0.6238


2023-02-20 17:06:51.594299: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:392] Filling up shuffle buffer (this may take a while): 845 of 1000
2023-02-20 17:06:53.639008: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:417] Shuffle buffer filled.


val dice: 0.06822775304317474
saved new best metric model
current epoch: 16 current mean dice: 0.0682
best mean dice: 0.0682 at epoch: 16
epoch time = 64.83030247688293
----------
epoch 17/1000


2023-02-20 17:07:36.587917: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:392] Filling up shuffle buffer (this may take a while): 878 of 1000
2023-02-20 17:07:38.128531: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:417] Shuffle buffer filled.


train_loss: 0.6232
train_loss: 0.6261
epoch 17 average loss: 0.6247


2023-02-20 17:07:55.961886: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:392] Filling up shuffle buffer (this may take a while): 859 of 1000
2023-02-20 17:07:57.777244: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:417] Shuffle buffer filled.


val dice: 0.07096101343631744
saved new best metric model
current epoch: 17 current mean dice: 0.0710
best mean dice: 0.0710 at epoch: 17
epoch time = 63.82534837722778
----------
epoch 18/1000


2023-02-20 17:08:40.411954: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:392] Filling up shuffle buffer (this may take a while): 882 of 1000
2023-02-20 17:08:41.900527: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:417] Shuffle buffer filled.


train_loss: 0.6264
train_loss: 0.6246
epoch 18 average loss: 0.6255


2023-02-20 17:08:59.683016: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:392] Filling up shuffle buffer (this may take a while): 844 of 1000
2023-02-20 17:09:01.765423: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:417] Shuffle buffer filled.


val dice: 0.05570559203624725
current epoch: 18 current mean dice: 0.0557
best mean dice: 0.0710 at epoch: 17
epoch time = 64.58151578903198
----------
epoch 19/1000


2023-02-20 17:09:44.989727: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:392] Filling up shuffle buffer (this may take a while): 879 of 1000
2023-02-20 17:09:46.526507: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:417] Shuffle buffer filled.


train_loss: 0.6182
train_loss: 0.6235
epoch 19 average loss: 0.6209


2023-02-20 17:10:04.606923: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:392] Filling up shuffle buffer (this may take a while): 852 of 1000
2023-02-20 17:10:06.557804: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:417] Shuffle buffer filled.


val dice: 0.08475010842084885
saved new best metric model
current epoch: 19 current mean dice: 0.0848
best mean dice: 0.0848 at epoch: 19
epoch time = 64.81890416145325
----------
epoch 20/1000


2023-02-20 17:10:49.813410: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:392] Filling up shuffle buffer (this may take a while): 875 of 1000
2023-02-20 17:10:51.447646: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:417] Shuffle buffer filled.


train_loss: 0.6190
train_loss: 0.6176
epoch 20 average loss: 0.6183


2023-02-20 17:11:09.342445: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:392] Filling up shuffle buffer (this may take a while): 853 of 1000
2023-02-20 17:11:11.265853: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:417] Shuffle buffer filled.


val dice: 0.08755668997764587
saved new best metric model
current epoch: 20 current mean dice: 0.0876
best mean dice: 0.0876 at epoch: 20
epoch time = 64.39274334907532
----------
epoch 21/1000


2023-02-20 17:11:54.203443: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:392] Filling up shuffle buffer (this may take a while): 885 of 1000
2023-02-20 17:11:55.651329: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:417] Shuffle buffer filled.


train_loss: 0.6181
train_loss: 0.6213
epoch 21 average loss: 0.6197


2023-02-20 17:12:13.458591: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:392] Filling up shuffle buffer (this may take a while): 863 of 1000
2023-02-20 17:12:15.243446: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:417] Shuffle buffer filled.


val dice: 0.0723801851272583
current epoch: 21 current mean dice: 0.0724
best mean dice: 0.0876 at epoch: 20
epoch time = 63.75405502319336
----------
epoch 22/1000


2023-02-20 17:12:57.967667: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:392] Filling up shuffle buffer (this may take a while): 852 of 1000
2023-02-20 17:12:59.868599: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:417] Shuffle buffer filled.


train_loss: 0.6196
train_loss: 0.6200
epoch 22 average loss: 0.6198


2023-02-20 17:13:17.843697: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:392] Filling up shuffle buffer (this may take a while): 860 of 1000
2023-02-20 17:13:19.683686: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:417] Shuffle buffer filled.


val dice: 0.056619495153427124
current epoch: 22 current mean dice: 0.0566
best mean dice: 0.0876 at epoch: 20
epoch time = 64.61617732048035
----------
epoch 23/1000


2023-02-20 17:14:02.561264: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:392] Filling up shuffle buffer (this may take a while): 865 of 1000
2023-02-20 17:14:04.265851: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:417] Shuffle buffer filled.


train_loss: 0.6205
train_loss: 0.6162
epoch 23 average loss: 0.6184


2023-02-20 17:14:22.283624: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:392] Filling up shuffle buffer (this may take a while): 851 of 1000
2023-02-20 17:14:24.224725: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:417] Shuffle buffer filled.


In [18]:
# Scratch inspection of the running epoch loss.
# NOTE(review): this cell's execution count (18) is lower than the training
# cell's (49), so its displayed output (0) was produced before that run and
# does not reflect it. Re-run this cell or delete it before sharing.
epoch_loss

0