In [11]:
import copy
import csv
import functools
import glob
import os
from collections import namedtuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import SimpleITK as sitk
import torch
import torch.cuda
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from disk import getCache
from dsets import PrepcacheLunaDataset, getCtSampleSize
from logconf import logging
from util import XyzTuple, xyz2irc
from util import enumerateWithEstimate

# Module-level logger for this notebook; INFO by default
# (switch to logging.WARN for quieter output).
log = logging.getLogger(__name__)
log.setLevel(level=logging.INFO)

In [12]:
log = logging.getLogger(__name__)
# Notebook-wide verbosity: DEBUG here (use logging.INFO / logging.WARN to quiet it).
log.setLevel(logging.DEBUG)

# Field layout for a bundle of per-CT masks.
MaskTuple = namedtuple(
    "MaskTuple",
    "raw_dense_mask, dense_mask, body_mask, air_mask, raw_candidate_mask, candidate_mask, lung_mask, neg_mask, pos_mask",
)

# One entry per annotated nodule or non-nodule candidate (see getCandidateInfoList).
CandidateInfoTuple = namedtuple(
    "CandidateInfoTuple",
    "isNodule_bool, hasAnnotation_bool, isMal_bool, diameter_mm, series_uid, center_xyz",
)

# Dataset location. Set LUNA_DATA_DIR / LUNA_SUBSET in the environment to point
# at another copy of the data instead of editing this machine-specific path.
ct_path = os.environ.get("LUNA_DATA_DIR", "/media/lim/Elements/3723295")
annotations_path = os.path.join(ct_path, "annotations_with_malignancy.csv")
candidate_path = os.path.join(ct_path, "candidates.csv")
ctscan_path = os.path.join(ct_path, "subset", os.environ.get("LUNA_SUBSET", "subset0"))

# Cache named "raw" from disk.getCache; used by the @raw_cache.memoize
# decorators further down.
raw_cache = getCache("raw")


_MALIGNANCY_TRUE = frozenset({"1", "1.0", "true"})
_MALIGNANCY_FALSE = frozenset({"0", "0.0", "false"})


def _parseMalignancyFlag(value):
    """Parse the malignancy column of the annotations CSV into a bool.

    Accepts numeric ("1.0"/"0.0", "1"/"0") and textual ("True"/"False")
    encodings, case-insensitively and ignoring surrounding whitespace.

    Raises:
        ValueError: if ``value`` is not one of the recognised encodings.
    """
    token = value.strip().lower()
    if token in _MALIGNANCY_TRUE:
        return True
    if token in _MALIGNANCY_FALSE:
        return False
    raise ValueError("Unrecognised malignancy flag: {!r}".format(value))


@functools.lru_cache(1)
def getCandidateInfoList(requireOnDisk_bool=True):
    """Return every nodule annotation plus the non-nodule candidates, sorted.

    Rows from ``annotations_path`` become entries with ``isNodule_bool`` and
    ``hasAnnotation_bool`` set, carrying the diameter and malignancy flag.
    From ``candidate_path`` only rows whose class column is 0 are kept
    (presumably because nodules are already covered by the annotation file).
    The list is sorted in descending tuple order, which puts nodules first,
    and is cached by ``lru_cache``.

    Args:
        requireOnDisk_bool: if True, skip series whose ``<uid>.mhd`` file is
            not present in ``ctscan_path``. This lets the data be used before
            every subset has been downloaded.

    Returns:
        list of ``CandidateInfoTuple``.

    Raises:
        ValueError: if a malignancy value is not a recognised encoding.
    """
    mhd_list = glob.glob(os.path.join(ctscan_path, "*.mhd"))
    presentOnDisk_set = {os.path.splitext(os.path.basename(p))[0] for p in mhd_list}

    candidateInfo_list = []
    with open(annotations_path, "r", newline="") as f:
        reader = csv.reader(f)
        next(reader, None)  # header row
        for row in reader:
            series_uid = row[0]
            if series_uid not in presentOnDisk_set and requireOnDisk_bool:
                continue
            annotationCenter_xyz = tuple(float(x) for x in row[1:4])
            annotationDiameter_mm = float(row[4])
            isMal_bool = _parseMalignancyFlag(row[5])

            candidateInfo_list.append(
                CandidateInfoTuple(
                    True,
                    True,
                    isMal_bool,
                    annotationDiameter_mm,
                    series_uid,
                    annotationCenter_xyz,
                )
            )

    with open(candidate_path, "r", newline="") as f:
        reader = csv.reader(f)
        next(reader, None)  # header row
        for row in reader:
            series_uid = row[0]
            if series_uid not in presentOnDisk_set and requireOnDisk_bool:
                continue

            isNodule_bool = bool(int(row[4]))
            if isNodule_bool:
                # Nodules come from the annotation file above.
                continue

            candidateCenter_xyz = tuple(float(x) for x in row[1:4])
            candidateInfo_list.append(
                CandidateInfoTuple(
                    False,
                    False,
                    False,
                    0.0,
                    series_uid,
                    candidateCenter_xyz,
                )
            )

    candidateInfo_list.sort(reverse=True)
    return candidateInfo_list

In [13]:
# Parse the annotation/candidate CSVs once (lru_cache) for the series whose .mhd is on disk.
CandidateInfoList = getCandidateInfoList()

In [14]:
# Total number of annotation + non-nodule candidate entries for the on-disk series.
len(CandidateInfoList)

56928

In [15]:
@functools.lru_cache(1)
def getCandidateInfoDict(requireOnDisk_bool=True):
    """Group ``getCandidateInfoList(requireOnDisk_bool)`` by ``series_uid``.

    Returns:
        dict mapping each series UID to the list of its candidate tuples, in
        the order they appear in the candidate list. Cached by ``lru_cache``.
    """
    grouped = {}
    for info in getCandidateInfoList(requireOnDisk_bool):
        bucket = grouped.get(info.series_uid)
        if bucket is None:
            bucket = []
            grouped[info.series_uid] = bucket
        bucket.append(info)
    return grouped


# Group candidates by series UID; the value shown is the number of on-disk
# series that have at least one candidate entry.
CandidateInfoDict = getCandidateInfoDict()
len(CandidateInfoDict)

89

In [16]:
class Ct:
    """One CT series loaded from its ``.mhd`` file, plus a nodule mask.

    Attributes set in ``__init__``:
        hu_a: float32 voxel array from ``sitk.GetArrayFromImage``, indexed
            below as ``[index, row, col]``.
        series_uid: the series UID this scan was loaded for.
        origin_xyz, vxSize_xyz, direction_a: image geometry passed to
            ``xyz2irc`` to convert ``center_xyz`` positions to voxel indices.
        positiveInfo_list: this series' candidate tuples with ``isNodule_bool``.
        positive_mask: boolean array shaped like ``hu_a`` marking nodule voxels
            (see ``buildAnnotationMask``).
        positive_indexes: ascending list of axis-0 slice indices that contain
            at least one positive voxel.
    """

    def __init__(self, series_uid):
        # Same directory getCandidateInfoList scans, rather than a second
        # hard-coded copy of the path.
        mhd_list = glob.glob(os.path.join(ctscan_path, "{}.mhd".format(series_uid)))
        if not mhd_list:
            raise FileNotFoundError(
                "No .mhd file for series {} under {}".format(series_uid, ctscan_path)
            )
        mhd_path = mhd_list[0]

        ct_mhd = sitk.ReadImage(mhd_path)
        self.hu_a = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)

        # CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
        # HU are scaled oddly, with 0 g/cc (air, approximately) being -1000 and 1 g/cc (water) being 0.

        self.series_uid = series_uid

        self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin())
        self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing())
        self.direction_a = np.array(ct_mhd.GetDirection()).reshape(3, 3)

        candidateInfo_list = getCandidateInfoDict()[self.series_uid]

        self.positiveInfo_list = [
            candidate_tup
            for candidate_tup in candidateInfo_list
            if candidate_tup.isNodule_bool
        ]
        self.positive_mask = self.buildAnnotationMask(self.positiveInfo_list)
        self.positive_indexes = (
            self.positive_mask.sum(axis=(1, 2)).nonzero()[0].tolist()
        )

    def _findRadius(self, center_irc, axis, threshold_hu, start_radius=2):
        """Grow a symmetric half-width around ``center_irc`` along ``axis``.

        The radius grows while the voxels at ``+radius`` and ``-radius`` are
        both above ``threshold_hu``. When either probe would leave the volume,
        the previous radius is returned. The low side is bounds-checked
        explicitly because a negative NumPy index does not raise; it silently
        wraps to the far end of the array.
        """
        shape = self.hu_a.shape
        radius = start_radius
        while True:
            hi_ndx = list(center_irc)
            lo_ndx = list(center_irc)
            hi_ndx[axis] += radius
            lo_ndx[axis] -= radius
            in_bounds = all(0 <= v < n for v, n in zip(hi_ndx, shape)) and all(
                0 <= v < n for v, n in zip(lo_ndx, shape)
            )
            if not in_bounds:
                return radius - 1
            if not (
                self.hu_a[tuple(hi_ndx)] > threshold_hu
                and self.hu_a[tuple(lo_ndx)] > threshold_hu
            ):
                return radius
            radius += 1

    def buildAnnotationMask(self, positiveInfo_list, threshold_hu=-700):
        """Return a boolean mask (shaped like ``hu_a``) of annotated nodule voxels.

        For each entry, a box is grown around its centre voxel: along each
        axis the half-width starts at 2 and grows while the voxels at both
        ends stay above ``threshold_hu``. The mask is every voxel inside any
        box whose own value is above ``threshold_hu``.

        Args:
            positiveInfo_list: candidate tuples providing ``center_xyz``.
            threshold_hu: HU cut-off; voxels at or below it are never masked.
        """
        # Builtin ``bool``: ``np.bool`` was removed in NumPy 1.24 and only
        # resolves if some other cell has monkey-patched it back.
        boundingBox_a = np.zeros_like(self.hu_a, dtype=bool)

        for candidateInfo_tup in positiveInfo_list:
            center_irc = xyz2irc(
                candidateInfo_tup.center_xyz,
                self.origin_xyz,
                self.vxSize_xyz,
                self.direction_a,
            )
            center = (
                int(center_irc.index),
                int(center_irc.row),
                int(center_irc.col),
            )

            radii = [self._findRadius(center, axis, threshold_hu) for axis in range(3)]

            # Clamp each start at 0: a negative start would be read by NumPy
            # as "from the end", producing an empty or misplaced box.
            box = tuple(
                slice(max(c - r, 0), c + r + 1) for c, r in zip(center, radii)
            )
            boundingBox_a[box] = True

        mask_a = boundingBox_a & (self.hu_a > threshold_hu)

        return mask_a

    def getRawCandidate(self, center_xyz, width_irc):
        """Crop ``width_irc`` voxels around ``center_xyz`` from the CT and mask.

        Near a volume edge the window is shifted (not padded) so it stays
        inside the array. Both returned chunks come from basic slicing, so
        they are views into ``hu_a`` / ``positive_mask``.

        Returns:
            ``(ct_chunk, pos_chunk, center_irc)``.

        Raises:
            AssertionError: if the centre voxel lies outside the volume.
        """
        center_irc = xyz2irc(
            center_xyz, self.origin_xyz, self.vxSize_xyz, self.direction_a
        )

        slice_list = []
        for axis, center_val in enumerate(center_irc):
            start_ndx = int(round(center_val - width_irc[axis] / 2))
            end_ndx = int(start_ndx + width_irc[axis])

            assert center_val >= 0 and center_val < self.hu_a.shape[axis], repr(
                [
                    self.series_uid,
                    center_xyz,
                    self.origin_xyz,
                    self.vxSize_xyz,
                    center_irc,
                    axis,
                ]
            )

            if start_ndx < 0:
                # Window starts before the volume: shift it to begin at 0.
                start_ndx = 0
                end_ndx = int(width_irc[axis])

            if end_ndx > self.hu_a.shape[axis]:
                # Window runs past the volume: shift it to end at the edge.
                end_ndx = self.hu_a.shape[axis]
                start_ndx = int(self.hu_a.shape[axis] - width_irc[axis])

            slice_list.append(slice(start_ndx, end_ndx))

        ct_chunk = self.hu_a[tuple(slice_list)]
        pos_chunk = self.positive_mask[tuple(slice_list)]

        return ct_chunk, pos_chunk, center_irc

In [17]:
@functools.lru_cache(maxsize=1, typed=True)
def getCt(series_uid):
    """Return a ``Ct`` for ``series_uid``; only the most recent one stays cached."""
    ct = Ct(series_uid)
    return ct


@raw_cache.memoize(typed=True)
def getCtRawCandidate(series_uid, center_xyz, width_irc):
    """Crop ``width_irc`` voxels around ``center_xyz``, clipped to [-1000, 1000] HU.

    Returns ``(ct_chunk, pos_chunk, center_irc)`` from ``Ct.getRawCandidate``,
    with ``ct_chunk`` clipped. Results are memoized by ``raw_cache``.
    """
    ct = getCt(series_uid)
    ct_chunk, pos_chunk, center_irc = ct.getRawCandidate(center_xyz, width_irc)
    # ct_chunk is a slice view into the cached ct.hu_a. Clipping it in place
    # (out=ct_chunk) silently rewrote the shared volume that later getCt()
    # callers see; np.clip without ``out`` returns a fresh array instead.
    ct_chunk = np.clip(ct_chunk, -1000, 1000)
    return ct_chunk, pos_chunk, center_irc


@raw_cache.memoize(typed=True)
def getCtSampleSize(series_uid):
    """Return ``(slice_count, positive_indexes)`` for one CT series.

    Builds a fresh ``Ct`` instead of going through ``getCt``; the result is
    memoized by ``raw_cache``. NOTE: this shadows the ``getCtSampleSize``
    imported from ``dsets`` in the first cell.
    """
    volume = Ct(series_uid)
    slice_count = int(volume.hu_a.shape[0])
    return slice_count, volume.positive_indexes

In [18]:
# Compatibility shim for code that still spells the boolean dtype ``np.bool``
# (the alias was removed in NumPy 1.24). This patches the numpy module
# globally, so it must run before any such code executes.
np.bool = np.bool_


class Luna2dSegmentationDataset(Dataset):
    """2-D segmentation samples: a stack of context slices plus the centre mask.

    Args:
        val_stride: when > 0, every ``val_stride``-th series (sorted UID order)
            belongs to the validation split.
        isValSet_bool: truthy keeps only the validation series (requires
            ``val_stride > 0``). Otherwise, when ``val_stride > 0``, those
            series are removed. The value also selects the log label
            (None -> "general", True -> "validation", False -> "training").
        series_uid: restrict the dataset to this single series.
        contextSlices_count: number of slices taken on each side of the
            target slice.
        fullCt_bool: if True, sample every slice of each CT; otherwise only
            slices that contain positive (nodule) voxels.
    """

    def __init__(
        self,
        val_stride=0,
        isValSet_bool=None,
        series_uid=None,
        contextSlices_count=3,
        fullCt_bool=False,
    ):
        self.contextSlices_count = contextSlices_count
        self.fullCt_bool = fullCt_bool

        if series_uid:
            self.series_list = [series_uid]
        else:
            self.series_list = sorted(getCandidateInfoDict().keys())

        if isValSet_bool:
            assert val_stride > 0, val_stride
            self.series_list = self.series_list[::val_stride]
            assert self.series_list
        elif val_stride > 0:
            del self.series_list[::val_stride]
            assert self.series_list

        self.sample_list = []
        for series_uid in self.series_list:
            index_count, positive_indexes = getCtSampleSize(series_uid)

            if self.fullCt_bool:
                self.sample_list += [
                    (series_uid, slice_ndx) for slice_ndx in range(index_count)
                ]
            else:
                self.sample_list += [
                    (series_uid, slice_ndx) for slice_ndx in positive_indexes
                ]

        self.candidateInfo_list = getCandidateInfoList()

        # Filter into a new list so the lru_cached list is never mutated.
        series_set = set(self.series_list)
        self.candidateInfo_list = [
            cit for cit in self.candidateInfo_list if cit.series_uid in series_set
        ]

        self.pos_list = [nt for nt in self.candidateInfo_list if nt.isNodule_bool]

        log.info(
            "{!r}: {} {} series, {} slices, {} nodules".format(
                self,
                len(self.series_list),
                {None: "general", True: "validation", False: "training"}[isValSet_bool],
                len(self.sample_list),
                len(self.pos_list),
            )
        )

    def __len__(self):
        return len(self.sample_list)

    def __getitem__(self, ndx):
        # Wrap around so indices past the end reuse earlier samples.
        series_uid, slice_ndx = self.sample_list[ndx % len(self.sample_list)]
        return self.getitem_fullSlice(series_uid, slice_ndx)

    def getitem_fullSlice(self, series_uid, slice_ndx):
        """Return ``(ct_t, pos_t, series_uid, slice_ndx)`` for one CT slice.

        ``ct_t`` has shape ``(2 * contextSlices_count + 1, rows, cols)``. The
        in-plane size is taken from the CT itself rather than assumed to be
        512 x 512, and context slices past either end of the volume repeat
        the edge slice. Values are clamped to [-1000, 1000] HU. ``pos_t`` is
        the ``(1, rows, cols)`` mask of the centre slice.
        """
        ct = getCt(series_uid)
        slice_count = ct.hu_a.shape[0]
        ct_t = torch.zeros(
            (self.contextSlices_count * 2 + 1,) + tuple(ct.hu_a.shape[1:])
        )

        start_ndx = slice_ndx - self.contextSlices_count
        end_ndx = slice_ndx + self.contextSlices_count + 1
        for i, context_ndx in enumerate(range(start_ndx, end_ndx)):
            context_ndx = min(max(context_ndx, 0), slice_count - 1)
            ct_t[i] = torch.from_numpy(
                np.asarray(ct.hu_a[context_ndx], dtype=np.float32)
            )

        # CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
        # HU are scaled oddly, with 0 g/cc (air, approximately) being -1000 and 1 g/cc (water) being 0.
        # The lower bound gets rid of negative density stuff used to indicate out-of-FOV
        # The upper bound nukes any weird hotspots and clamps bone down
        ct_t.clamp_(-1000, 1000)

        pos_t = torch.from_numpy(ct.positive_mask[slice_ndx]).unsqueeze(0)

        return ct_t, pos_t, ct.series_uid, slice_ndx

In [19]:
# Unsplit dataset over every on-disk series (val_stride=0). With the default
# fullCt_bool=False, the logged slice count is the number of slices that
# contain nodule voxels.
ds = Luna2dSegmentationDataset()

2024-11-01 22:31:45,019 INFO     pid:55736 __main__:051:__init__ <__main__.Luna2dSegmentationDataset object at 0x704ae5557610>: 89 general series, 971 slices, 112 nodules


In [20]:
# CT context stack of the first sample: (2 * contextSlices_count + 1, rows, cols),
# shown below as (7, 512, 512).
ds[0][0].shape

torch.Size([7, 512, 512])

In [31]:
import torch
import torch.nn as nn
import torch.optim
import torch
import torch.cuda
from torch.optim import SGD, Adam
from torch.utils.data import DataLoader

from util import enumerateWithEstimate
from logconf import logging
from unet import UNet
import random
import math

In [32]:
class UNetWrapper(nn.Module):
    """UNet with per-channel input batch norm and a sigmoid output head.

    All keyword arguments are forwarded to ``UNet``; ``in_channels`` also
    sizes the input ``BatchNorm2d``.
    """

    def __init__(self, **kwargs):
        super().__init__()

        # Per-channel normalisation of the raw input before the UNet.
        self.input_batchnorm = nn.BatchNorm2d(kwargs["in_channels"])
        self.unet = UNet(**kwargs)
        # Sigmoid maps the UNet output to per-pixel probabilities in (0, 1).
        self.final = nn.Sigmoid()

        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal weights and small uniform biases for conv/linear layers."""
        init_set = {
            nn.Conv2d,
            nn.Conv3d,
            nn.ConvTranspose2d,
            nn.ConvTranspose3d,
            nn.Linear,
        }
        for m in self.modules():
            # Exact type match: subclasses of these layers are left untouched.
            if type(m) in init_set:
                nn.init.kaiming_normal_(
                    m.weight.data, mode="fan_out", nonlinearity="relu", a=0
                )
                if m.bias is not None:
                    _fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(
                        m.weight.data
                    )
                    bound = 1 / math.sqrt(fan_out)
                    # uniform_(tensor, a, b) samples U(a, b). The previous
                    # normal_(tensor, -bound, bound) read the bounds as
                    # (mean, std), giving every bias a negative mean.
                    nn.init.uniform_(m.bias, -bound, bound)

        # nn.init.constant_(self.unet.last.bias, -4)
        # nn.init.constant_(self.unet.last.bias, 4)

    def forward(self, input_batch):
        bn_output = self.input_batchnorm(input_batch)
        un_output = self.unet(bn_output)
        fn_output = self.final(un_output)
        return fn_output


class SegmentationAugmentation(nn.Module):
    """Random 2-D affine augmentation (flip/offset/scale/rotate) plus noise.

    The same random transform is applied to the input and the label within
    one call. The label is resampled as float and re-thresholded at 0.5.

    Args:
        flip: if truthy, each spatial axis is mirrored with probability 0.5.
        offset: maximum translation per axis, in ``affine_grid``'s normalised
            coordinates.
        scale: maximum relative scale change per axis (factor ``1 ± scale``).
        rotate: if truthy, rotate by a uniformly random angle in [0, 2π).
        noise: multiplier for standard-normal noise added to the input.
    """

    def __init__(self, flip=None, offset=None, scale=None, rotate=None, noise=None):
        super().__init__()

        self.flip = flip
        self.offset = offset
        self.scale = scale
        self.rotate = rotate
        self.noise = noise

    def forward(self, input_g, label_g):
        transform_t = self._build2dTransformMatrix()
        transform_t = transform_t.expand(input_g.shape[0], -1, -1)
        transform_t = transform_t.to(input_g.device, torch.float32)
        # affine_grid expects (N, 2, 3): the top two rows of the 3x3 matrix.
        affine_t = F.affine_grid(
            transform_t[:, :2], input_g.size(), align_corners=False
        )

        augmented_input_g = F.grid_sample(
            input_g, affine_t, padding_mode="border", align_corners=False
        )
        augmented_label_g = F.grid_sample(
            label_g.to(torch.float32),
            affine_t,
            padding_mode="border",
            align_corners=False,
        )

        if self.noise:
            noise_t = torch.randn_like(augmented_input_g)
            noise_t *= self.noise

            augmented_input_g += noise_t

        return augmented_input_g, augmented_label_g > 0.5

    def _build2dTransformMatrix(self):
        """Return a random 3x3 homogeneous 2-D transform.

        Only the top two rows reach ``F.affine_grid``, so the translation must
        be stored in column 2 of those rows.
        """
        transform_t = torch.eye(3)

        for i in range(2):
            if self.flip:
                if random.random() > 0.5:
                    transform_t[i, i] *= -1

            if self.offset:
                random_float = random.random() * 2 - 1
                # Row i, column 2 is the translation term. Writing
                # transform_t[2, i] (bottom row) was discarded by forward(),
                # so the offset augmentation had no effect.
                transform_t[i, 2] = self.offset * random_float

            if self.scale:
                random_float = random.random() * 2 - 1
                transform_t[i, i] *= 1.0 + self.scale * random_float

        if self.rotate:
            angle_rad = random.random() * math.pi * 2
            s = math.sin(angle_rad)
            c = math.cos(angle_rad)

            rotation_t = torch.tensor([[c, -s, 0], [s, c, 0], [0, 0, 1]])

            transform_t = transform_t @ rotation_t

        return transform_t

In [33]:
# Run on CUDA when a GPU is visible, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [35]:
# Training-time augmentation settings; keys are SegmentationAugmentation's
# keyword arguments.
augmentation_dict = {
    "flip": True,
    "offset": 0.03,
    "scale": 0.2,
    "rotate": True,
    "noise": 25.0,
}

# Only use CUDA when a GPU is actually visible. Hard-coding True logged
# "Using CUDA; 0 devices", requested pinned memory, and made
# `batch_size *= torch.cuda.device_count()` produce 0 on CPU-only machines.
use_cuda = torch.cuda.is_available()

In [36]:
def initModel():
    """Build the UNet segmentation model and its augmentation module.

    Returns:
        ``(segmentation_model, augmentation_model)``. When ``use_cuda`` is
        set, both are wrapped in ``nn.DataParallel`` if more than one GPU is
        visible, and moved to ``device``.
    """
    segmentation_model = UNetWrapper(
        in_channels=7,
        n_classes=1,
        depth=3,
        wf=4,
        padding=True,
        batch_norm=True,
        up_mode="upconv",
    )

    # Unpack the settings as keyword arguments. Passing the dict positionally
    # bound the whole dict to ``flip``, which silently left offset, scale,
    # rotate and noise disabled.
    augmentation_model = SegmentationAugmentation(**augmentation_dict)

    if use_cuda:
        log.info("Using CUDA; {} devices.".format(torch.cuda.device_count()))
        if torch.cuda.device_count() > 1:
            segmentation_model = nn.DataParallel(segmentation_model)
            augmentation_model = nn.DataParallel(augmentation_model)
        segmentation_model = segmentation_model.to(device)
        augmentation_model = augmentation_model.to(device)

    return segmentation_model, augmentation_model


# Instantiate the model/augmentation pair (moved to ``device`` when use_cuda is set).
segmentation_model, augmentation_model = initModel()

2024-11-01 22:47:17,959 INFO     pid:55736 __main__:015:initModel Using CUDA; 1 devices.


In [46]:
# Base batch size (before any multi-GPU scaling), DataLoader worker count,
# and number of training epochs.
batch_size, num_workers, epochs = 32, 8, 1

In [37]:
class TrainingLuna2dSegmentationDataset(Luna2dSegmentationDataset):
    """Training dataset that serves random crops around annotated nodules.

    Indexing cycles through ``pos_list``. Each item is a random
    ``_CROP_SIZE`` x ``_CROP_SIZE`` window cut from a ``_FETCH_SIZE`` x
    ``_FETCH_SIZE`` block centred on the nodule, spanning
    ``2 * contextSlices_count + 1`` slices, with the centre slice's mask as
    the label.
    """

    # In-plane size of the block fetched around each nodule, and of the
    # random sub-crop returned from it.
    _FETCH_SIZE = 96
    _CROP_SIZE = 64

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.ratio_int = 2

    def __len__(self):
        # Fixed nominal epoch length; __getitem__ wraps around pos_list.
        return 300000

    def shuffleSamples(self):
        """Shuffle ``candidateInfo_list`` and ``pos_list`` in place."""
        random.shuffle(self.candidateInfo_list)
        random.shuffle(self.pos_list)

    def __getitem__(self, ndx):
        candidateInfo_tup = self.pos_list[ndx % len(self.pos_list)]
        return self.getitem_trainingCrop(candidateInfo_tup)

    def getitem_trainingCrop(self, candidateInfo_tup):
        """Return ``(ct_t, pos_t, series_uid, slice_ndx)`` for one random crop.

        ``ct_t`` is float32 and ``pos_t`` is long. Both are cropped to
        ``_CROP_SIZE`` x ``_CROP_SIZE`` in-plane. ``pos_t`` holds only the
        centre slice.
        """
        context = self.contextSlices_count
        ct_a, pos_a, center_irc = getCtRawCandidate(
            candidateInfo_tup.series_uid,
            candidateInfo_tup.center_xyz,
            (context * 2 + 1, self._FETCH_SIZE, self._FETCH_SIZE),
        )
        # The centre slice sits at index ``context`` of the fetched stack.
        pos_a = pos_a[context : context + 1]

        # randrange excludes its upper bound; +1 lets the crop reach the last
        # row/column of the fetched block.
        max_offset = self._FETCH_SIZE - self._CROP_SIZE
        row_offset = random.randrange(0, max_offset + 1)
        col_offset = random.randrange(0, max_offset + 1)
        row_slice = slice(row_offset, row_offset + self._CROP_SIZE)
        col_slice = slice(col_offset, col_offset + self._CROP_SIZE)

        ct_t = torch.from_numpy(ct_a[:, row_slice, col_slice]).to(torch.float32)
        pos_t = torch.from_numpy(pos_a[:, row_slice, col_slice]).to(torch.long)

        slice_ndx = center_irc.index

        return ct_t, pos_t, candidateInfo_tup.series_uid, slice_ndx

In [50]:
def initOptimizer():
    """Create an Adam optimizer (default hyper-parameters) over ``segmentation_model``."""
    params = segmentation_model.parameters()
    # Alternative: SGD(segmentation_model.parameters(), lr=0.001, momentum=0.99)
    return Adam(params)


# Bound to the current segmentation_model's parameters; re-run this cell if the model is rebuilt.
optimizer = initOptimizer()

In [39]:
def _scaledBatchSize():
    """Return the global ``batch_size`` times the visible GPU count (when using CUDA).

    Reads the module-level ``batch_size`` without rebinding it. The previous
    ``batch_size = batch_size`` inside each function made the name local and
    raised UnboundLocalError. The device count is clamped to at least 1 so a
    CUDA flag with no visible GPU cannot yield a batch size of 0.
    """
    scaled = batch_size
    if use_cuda:
        scaled *= max(1, torch.cuda.device_count())
    return scaled


def _buildLoader(dataset):
    """Wrap ``dataset`` in a DataLoader using the notebook-level settings."""
    return DataLoader(
        dataset,
        batch_size=_scaledBatchSize(),
        num_workers=num_workers,
        pin_memory=use_cuda,
    )


def initTrainDl():
    """DataLoader over the training split (every series not in the val stride)."""
    train_ds = TrainingLuna2dSegmentationDataset(
        val_stride=10,
        isValSet_bool=False,
        contextSlices_count=3,
    )
    return _buildLoader(train_ds)


def initValDl():
    """DataLoader over the validation split (every 10th series)."""
    val_ds = Luna2dSegmentationDataset(
        val_stride=10,
        isValSet_bool=True,
        contextSlices_count=3,
    )
    return _buildLoader(val_ds)

In [45]:
# Training split: every series except each 10th one (which goes to validation).
train_ds = TrainingLuna2dSegmentationDataset(
    val_stride=10,
    isValSet_bool=False,
    contextSlices_count=3,
)

# Scale by the visible GPU count (initModel uses nn.DataParallel with >1 GPU)
# into a separate name, so the global batch_size is never rebound and this
# cell stays idempotent.
train_batch_size = batch_size * (max(1, torch.cuda.device_count()) if use_cuda else 1)

train_dl = DataLoader(
    train_ds,
    batch_size=train_batch_size,
    num_workers=num_workers,
    pin_memory=use_cuda,
)

2024-11-01 23:01:08,759 INFO     pid:55736 __main__:051:__init__ <__main__.TrainingLuna2dSegmentationDataset object at 0x704af1ff69d0>: 80 training series, 890 slices, 101 nodules


In [47]:
# Validation split: every 10th series.
val_ds = Luna2dSegmentationDataset(
    val_stride=10,
    isValSet_bool=True,
    contextSlices_count=3,
)

# Use a local scaled size instead of `batch_size *= ...`. Mutating the global
# compounded on every re-run of this cell and skewed the "size {}*{}" epoch
# log, which multiplies batch_size by the device count again. max(1, ...)
# also avoids a zero batch size when no GPU is visible.
val_batch_size = batch_size * (max(1, torch.cuda.device_count()) if use_cuda else 1)

val_dl = DataLoader(
    val_ds,
    batch_size=val_batch_size,
    num_workers=num_workers,
    pin_memory=use_cuda,
)

2024-11-01 23:03:05,657 INFO     pid:55736 __main__:051:__init__ <__main__.Luna2dSegmentationDataset object at 0x704af1faf2d0>: 9 validation series, 81 slices, 11 nodules


In [48]:
def diceLoss(prediction_g, label_g, epsilon=1):
    """Per-sample soft Dice loss: ``1 - (2*|P·L| + eps) / (|P| + |L| + eps)``.

    Args:
        prediction_g: predicted probabilities, shape ``(batch, ...)``.
        label_g: ground-truth mask (bool or numeric), same shape.
        epsilon: smoothing term; when prediction and label are both empty the
            ratio is 1 and the loss 0.

    Returns:
        Tensor of shape ``(batch,)``, one loss value per sample.
    """
    # Reduce over every non-batch dimension, so (N, C, H, W) behaves exactly
    # as before and (N, C, D, H, W) volumes also reduce to (N,).
    sum_dims = list(range(1, prediction_g.dim()))
    diceLabel_g = label_g.sum(dim=sum_dims)
    dicePrediction_g = prediction_g.sum(dim=sum_dims)
    diceCorrect_g = (prediction_g * label_g).sum(dim=sum_dims)

    diceRatio_g = (2 * diceCorrect_g + epsilon) / (
        dicePrediction_g + diceLabel_g + epsilon
    )

    return 1 - diceRatio_g

In [49]:
# Row indices into the (METRICS_SIZE, n_samples) metrics tensors filled by
# computeBatchLoss; the commented-out rows are not populated here.
# METRICS_LABEL_NDX = 0
METRICS_LOSS_NDX = 1
# METRICS_FN_LOSS_NDX = 2
# METRICS_ALL_LOSS_NDX = 3

# METRICS_PTP_NDX = 4
# METRICS_PFN_NDX = 5
# METRICS_MFP_NDX = 6
METRICS_TP_NDX = 7
METRICS_FN_NDX = 8
METRICS_FP_NDX = 9

METRICS_SIZE = 10


def computeBatchLoss(
    batch_ndx, batch_tup, batch_size, metrics_g, classificationThreshold=0.5
):
    """Run one batch through the model and record per-sample metrics.

    Args:
        batch_ndx: index of the batch within the epoch. Together with
            ``batch_size`` it locates this batch's columns in ``metrics_g``.
        batch_tup: collated ``(input_t, label_t, series_list, slice_ndx_list)``.
        batch_size: the DataLoader's batch size.
        metrics_g: ``(METRICS_SIZE, n_samples)`` tensor, updated in place with
            the per-sample Dice loss and TP/FN/FP pixel counts.
        classificationThreshold: probability above which a pixel counts as a
            predicted positive.

    Returns:
        Scalar loss: mean Dice loss plus 8x the mean Dice loss restricted to
        labelled pixels (penalising false negatives).
    """
    input_t, label_t, _series_list, _slice_ndx_list = batch_tup

    input_g = input_t.to(device, non_blocking=True)
    label_g = label_t.to(device, non_blocking=True)

    if segmentation_model.training and augmentation_dict:
        input_g, label_g = augmentation_model(input_g, label_g)

    # Normalise the mask to bool. The training dataset yields torch.long
    # labels, and `~` on an integer tensor is bitwise NOT (~1 == -2), which
    # corrupted the false-positive count whenever augmentation was skipped.
    label_g = label_g.to(torch.bool)

    prediction_g = segmentation_model(input_g)

    diceLoss_g = diceLoss(prediction_g, label_g)
    fnLoss_g = diceLoss(prediction_g * label_g, label_g)

    start_ndx = batch_ndx * batch_size
    end_ndx = start_ndx + input_t.size(0)

    with torch.no_grad():
        predictionBool_g = (prediction_g[:, 0:1] > classificationThreshold).to(
            torch.float32
        )

        tp = (predictionBool_g * label_g).sum(dim=[1, 2, 3])
        fn = ((1 - predictionBool_g) * label_g).sum(dim=[1, 2, 3])
        fp = (predictionBool_g * (~label_g)).sum(dim=[1, 2, 3])

        metrics_g[METRICS_LOSS_NDX, start_ndx:end_ndx] = diceLoss_g
        metrics_g[METRICS_TP_NDX, start_ndx:end_ndx] = tp
        metrics_g[METRICS_FN_NDX, start_ndx:end_ndx] = fn
        metrics_g[METRICS_FP_NDX, start_ndx:end_ndx] = fp

    return diceLoss_g.mean() + fnLoss_g.mean() * 8

In [51]:
def doTraining(epoch_ndx, train_dl):
    """Run one training epoch over ``train_dl`` and return per-sample metrics.

    Puts ``segmentation_model`` in train mode and reshuffles the dataset via
    its ``shuffleSamples()``. For each batch it then zeroes the gradients,
    computes the loss with ``computeBatchLoss`` (which also fills the metrics
    tensor), back-propagates, and steps ``optimizer``.

    Returns:
        CPU tensor of shape ``(METRICS_SIZE, len(train_dl.dataset))``.
    """
    trnMetrics_g = torch.zeros(METRICS_SIZE, len(train_dl.dataset), device=device)
    segmentation_model.train()
    train_dl.dataset.shuffleSamples()

    batch_iter = enumerateWithEstimate(
        train_dl,
        "E{} Training".format(epoch_ndx),
        start_ndx=train_dl.num_workers,
    )
    for batch_ndx, batch_tup in batch_iter:
        optimizer.zero_grad()

        loss_var = computeBatchLoss(
            batch_ndx, batch_tup, train_dl.batch_size, trnMetrics_g
        )
        loss_var.backward()

        optimizer.step()

    return trnMetrics_g.to("cpu")

In [52]:
def doValidation(epoch_ndx, val_dl):
    """Evaluate ``segmentation_model`` on ``val_dl`` with gradients disabled.

    Returns:
        CPU tensor ``(METRICS_SIZE, len(val_dl.dataset))`` filled in by
        ``computeBatchLoss``.
    """
    with torch.no_grad():
        sample_count = len(val_dl.dataset)
        valMetrics_g = torch.zeros(METRICS_SIZE, sample_count, device=device)
        segmentation_model.eval()

        description = "E{} Validation ".format(epoch_ndx)
        progress = enumerateWithEstimate(
            val_dl, description, start_ndx=val_dl.num_workers
        )
        for batch_ndx, batch_tup in progress:
            computeBatchLoss(batch_ndx, batch_tup, val_dl.batch_size, valMetrics_g)

    return valMetrics_g.to("cpu")

In [53]:
# Epoch driver: train every epoch; validate on epoch 1 and then every
# `validation_cadence` epochs.
best_score = 0.0
validation_cadence = 5

for epoch_ndx in range(1, epochs + 1):
    device_multiplier = torch.cuda.device_count() if use_cuda else 1
    log.info(
        "Epoch {} of {}, {}/{} batches of size {}*{}".format(
            epoch_ndx,
            epochs,
            len(train_dl),
            len(val_dl),
            batch_size,
            device_multiplier,
        )
    )

    trnMetrics_t = doTraining(epoch_ndx, train_dl)
    # TODO: logMetrics(epoch_ndx, "trn", trnMetrics_t)

    run_validation = epoch_ndx == 1 or epoch_ndx % validation_cadence == 0
    if not run_validation:
        continue

    valMetrics_t = doValidation(epoch_ndx, val_dl)
    # TODO: re-enable once these helpers exist in this notebook:
    #   score = logMetrics(epoch_ndx, "val", valMetrics_t)
    #   best_score = max(score, best_score)
    #   saveModel("seg", epoch_ndx, score == best_score)
    #   logImages(epoch_ndx, "trn", train_dl)
    #   logImages(epoch_ndx, "val", val_dl)

2024-11-01 23:19:36,241 INFO     pid:55736 __main__:004:<module> Epoch 1 of 1, 9375/3 batches of size 32*1


KeyboardInterrupt: 