In [1]:
import torch
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [89]:
import os
import collections
import pandas
# import intel_extension_for_pytorch as ipex

import numpy as np
import skimage.draw
import pathlib
import torchvision

# Module-level debugging accumulator: each Echo(...) built for a split other
# than "EXTERNAL_TEST" appends its file-name list and its traced-frame mapping.
file_names = []
class Echo(torchvision.datasets.VisionDataset):
    """Video dataset yielding ``(video, target)`` pairs.

    For every split except ``"EXTERNAL_TEST"`` the constructor reads
    ``FileList_new.csv`` (video-level labels) and ``VolumeTracings_new.csv``
    (per-frame traces) from ``root`` and keeps only the videos that have at
    least two traced frames.

    Args:
        root: Dataset directory; ``"/content/EchoNet-Dynamic"`` when ``None``.
        split: Split name (upper-cased internally). ``"ALL"`` keeps every row
            of the file list, ``"EXTERNAL_TEST"`` lists the files found in
            ``external_test_location`` instead of reading the CSV files, any
            other value selects the rows whose ``Split`` column matches.
        target_type: A target name or a list of names. Recognised names are
            ``"Filename"``, ``"LargeIndex"``, ``"SmallIndex"``,
            ``"LargeFrame"``, ``"SmallFrame"``, ``"LargeTrace"`` and
            ``"SmallTrace"``; any other name is looked up as a column of the
            file list.
        mean: Scalar, or array reshaped to ``(3, 1, 1, 1)``, subtracted from
            the video.
        std: Scalar, or array reshaped to ``(3, 1, 1, 1)``, the video is
            divided by.
        length: Number of frames per clip; ``None`` takes
            ``frames // period``.
        period: Step between sampled frames.
        max_length: Upper bound applied to the clip length (``None`` for no
            bound).
        clips: Number of randomly placed clips, or ``"all"`` for every
            possible start position.
        pad: When not ``None``, the clip is zero-padded by this many pixels on
            each side and a random crop of the original size is returned.
        rotate: Stored on the instance; not used by the code in this class.
        noise: When not ``None``, fraction of pixel positions set to zero
            before normalisation.
        target_transform: Optional callable applied to the target.
        external_test_location: Directory of videos used when ``split`` is
            ``"EXTERNAL_TEST"``.
    """

    def __init__(self, root=None,
                 split="train", target_type="EF",
                 mean=0., std=1.,
                 length=16, period=2,
                 max_length=250,
                 clips=1,
                 pad=None,
                 rotate=None,
                 noise=None,
                 target_transform=None,
                 external_test_location=None):
        # Initialise the parent exactly once (it used to be called twice).
        super().__init__(root, target_transform=target_transform)

        if root is None:
            root = "/content/EchoNet-Dynamic"

        self.root = pathlib.Path(root)
        self.split = split.upper()
        if not isinstance(target_type, list):
            target_type = [target_type]
        self.target_type = target_type
        self.mean = mean
        self.std = std
        self.length = length
        self.max_length = max_length
        self.period = period
        self.clips = clips
        self.pad = pad
        self.rotate = rotate
        self.noise = noise
        self.target_transform = target_transform
        self.external_test_location = external_test_location

        self.fnames, self.outcome = [], []

        if self.split == "EXTERNAL_TEST":
            self.fnames = sorted(os.listdir(self.external_test_location))
        else:
            # Load video-level labels
            with open(self.root / "FileList_new.csv") as f:
                data = pandas.read_csv(f)
            # Series.map returns a new Series; assign it back so the
            # upper-casing actually takes effect before filtering.
            data["Split"] = data["Split"].map(lambda x: x.upper())

            if self.split != "ALL":
                data = data[data["Split"] == self.split]

            self.header = data.columns.tolist()
            # Assume avi if no suffix. Names that already carry a suffix are
            # kept unchanged (not dropped) so fnames stays aligned with outcome.
            self.fnames = [
                fn if os.path.splitext(fn)[1] else fn + ".avi"
                for fn in data["FileName"].tolist()
            ]
            print("videos way before: ",len(self.fnames))
            self.outcome = data.values.tolist()

            # Check that files are present
            missing = set(self.fnames) - set(os.listdir(os.path.join(self.root, "Videos")))
            if len(missing) != 0:
                print("{} videos could not be found in {}:".format(len(missing), os.path.join(self.root, "Videos")))
                for f in sorted(missing):
                    print("\t", f)
                raise FileNotFoundError(os.path.join(self.root, "Videos", sorted(missing)[0]))

            # Load traces
            # frames: filename -> traced frame numbers in order of first appearance
            # trace:  filename -> frame number -> list (later array) of points
            self.frames = collections.defaultdict(list)
            self.trace = collections.defaultdict(_defaultdict_of_lists)

            with open(self.root / "VolumeTracings_new.csv") as f:
                header = f.readline().strip().split(",")
                if header == ["FileName", "X1", "Y1", "X2", "Y2", "Frame"]:
                    for line in f:
                        filename, x1, y1, x2, y2, frame = line.strip().split(',')
                        x1 = float(x1)
                        y1 = float(y1)
                        x2 = float(x2)
                        y2 = float(y2)
                        frame = int(frame)
                        if frame not in self.trace[filename]:
                            self.frames[filename].append(frame)
                        self.trace[filename][frame].append((x1, y1, x2, y2))
                elif header == ["FileName", "X", "Y", "Frame"]:
                    # TODO: probably could merge
                    for line in f:
                        filename, x, y, frame = line.strip().split(',')
                        x = float(x)
                        y = float(y)
                        frame = int(frame)
                        if frame not in self.trace[filename]:
                            self.frames[filename].append(frame)
                        self.trace[filename][frame].append((x, y))
            for filename in self.frames:
                for frame in self.frames[filename]:
                    self.trace[filename][frame] = np.array(self.trace[filename][frame])
            print("videos before: ",len(self.fnames))
            file_names.append(self.fnames)
            file_names.append(self.frames)
            # Keep only videos with at least two traced frames.
            keep = [len(self.frames[f]) >= 2 for f in self.fnames]
            self.fnames = [f for (f, k) in zip(self.fnames, keep) if k]
            print("videos : ",len(self.fnames))
            self.outcome = [f for (f, k) in zip(self.outcome, keep) if k]

    def __getitem__(self, index):
        """Return ``(video, target)`` for the video at ``index``.

        ``video`` is a float32 array of shape ``(c, length, h, w)`` when
        ``clips == 1``; otherwise the clips are stacked along a new leading
        axis. ``target`` is a single value, or a tuple when several target
        types were requested.
        """
        # Find filename of video
        if self.split == "EXTERNAL_TEST":
            video = os.path.join(self.external_test_location, self.fnames[index])
        elif self.split == "CLINICAL_TEST":
            video = os.path.join(self.root, "ProcessedStrainStudyA4c", self.fnames[index])
        else:
            video = os.path.join(self.root, "Videos", self.fnames[index])

        # Load video into np.array
        video = loadvideo(video).astype(np.float32)

        # Add simulated noise (black out random pixels)
        # 0 represents black at this point (video has not been normalized yet)
        if self.noise is not None:
            n = video.shape[1] * video.shape[2] * video.shape[3]
            ind = np.random.choice(n, round(self.noise * n), replace=False)
            # Decompose each flat index into (frame, row, column) coordinates.
            f = ind % video.shape[1]
            ind //= video.shape[1]
            i = ind % video.shape[2]
            ind //= video.shape[2]
            j = ind
            video[:, f, i, j] = 0

        # Apply normalization
        if isinstance(self.mean, (float, int)):
            video -= self.mean
        else:
            video -= self.mean.reshape(3, 1, 1, 1)

        if isinstance(self.std, (float, int)):
            video /= self.std
        else:
            video /= self.std.reshape(3, 1, 1, 1)

        # Set number of frames
        c, f, h, w = video.shape
        if self.length is None:
            # Take as many frames as possible
            length = f // self.period
        else:
            # Take specified number of frames
            length = self.length

        if self.max_length is not None:
            # Shorten videos to max_length
            length = min(length, self.max_length)

        if f < length * self.period:
            # Pad video with frames filled with zeros if too short
            # 0 represents the mean color (dark grey), since this is after normalization
            video = np.concatenate((video, np.zeros((c, length * self.period - f, h, w), video.dtype)), axis=1)
            c, f, h, w = video.shape  # pylint: disable=E0633

        if self.clips == "all":
            # Take all possible clips of desired length
            start = np.arange(f - (length - 1) * self.period)
        else:
            # Take random clips from video
            start = np.random.choice(f - (length - 1) * self.period, self.clips)

        # Gather targets
        target = []
        for t in self.target_type:
            key = self.fnames[index]
            if t == "Filename":
                target.append(self.fnames[index])
            elif t == "LargeIndex":
                # Last traced frame of this video
                # (presumably the largest/diastolic trace -- TODO confirm ordering)
                target.append(np.int64(self.frames[key][-1]))
            elif t == "SmallIndex":
                # First traced frame of this video
                # (presumably the smallest/systolic trace -- TODO confirm ordering)
                target.append(np.int64(self.frames[key][0]))
            elif t in ["LargeFrame", "SmallFrame"]:
                if t == "LargeFrame":
                    frame = self.frames[key][-1]
                else:
                    frame = self.frames[key][0]

                if frame is None or frame >= video.shape[1]:
                    # np.nan instead of math.nan: `math` is not imported here.
                    target.append(np.full((video.shape[0], video.shape[2], video.shape[3]), np.nan, video.dtype))
                else:
                    target.append(video[:, frame, :, :])
            elif t in ["LargeTrace", "SmallTrace"]:
                if t == "LargeTrace":
                    frame = self.frames[key][-1]
                else:
                    frame = self.frames[key][0]
                if frame is None or frame >= video.shape[1]:
                    # np.nan instead of math.nan: `math` is not imported here.
                    mask = np.full((video.shape[2], video.shape[3]), np.nan, np.float32)
                else:
                    points = self.trace[key][frame]

                    if points.shape[1] == 4:
                        # Two-sided trace: walk down one side and back up the other.
                        x1, y1, x2, y2 = points[:, 0], points[:, 1], points[:, 2], points[:, 3]
                        x = np.concatenate((x1[1:], np.flip(x2[1:])))
                        y = np.concatenate((y1[1:], np.flip(y2[1:])))
                    else:
                        assert points.shape[1] == 2
                        x, y = points[:, 0], points[:, 1]

                    # Rasterise the polygon into a binary mask of frame size.
                    rows, cols = skimage.draw.polygon(np.rint(y).astype(np.int64), np.rint(x).astype(np.int64), (video.shape[2], video.shape[3]))
                    mask = np.zeros((video.shape[2], video.shape[3]), np.float32)
                    mask[rows, cols] = 1
                target.append(mask)
            else:
                if self.split == "CLINICAL_TEST" or self.split == "EXTERNAL_TEST":
                    target.append(np.float32(0))
                else:
                    # Any other name is read from the matching file-list column.
                    target.append(self.outcome[index][self.header.index(t)])

        if target != []:
            target = tuple(target) if len(target) > 1 else target[0]
            if self.target_transform is not None:
                target = self.target_transform(target)

        # Select clips from video
        video = tuple(video[:, s + self.period * np.arange(length), :, :] for s in start)
        if self.clips == 1:
            video = video[0]
        else:
            video = np.stack(video)

        if self.pad is not None:
            # Add padding of zeros (mean color of videos)
            # Crop of original size is taken out
            # (Used as augmentation)
            c, l, h, w = video.shape
            temp = np.zeros((c, l, h + 2 * self.pad, w + 2 * self.pad), dtype=video.dtype)
            temp[:, :, self.pad:-self.pad, self.pad:-self.pad] = video  # pylint: disable=E1130
            i, j = np.random.randint(0, 2 * self.pad, 2)
            video = temp[:, :, i:(i + h), j:(j + w)]

        return video, target

    def __len__(self):
        """Return the number of videos kept after filtering."""
        return len(self.fnames)

    def extra_repr(self) -> str:
        """Additional information to add at end of __repr__."""
        lines = ["Target type: {target_type}", "Split: {split}"]
        return '\n'.join(lines).format(**self.__dict__)


def _defaultdict_of_lists():
    return collections.defaultdict(list)

In [90]:
# Dataset root handed to Echo (it must contain the CSV files and a Videos/ folder).
data_dir = "/kaggle/input/echonet-pediatric/Dataset/A4C"
# Estimate normalisation statistics from a random sample of the training split
# (get_mean_and_std draws at most 128 videos by default).
mean, std = get_mean_and_std(Echo(root=data_dir, split="train"))

videos way before:  2580
videos before:  2580
videos :  2483


100%|██████████| 16/16 [00:01<00:00,  8.02it/s]


In [50]:
"""Utility functions for videos, plotting and computing performance metrics."""

import os
import typing

import cv2  # pytype: disable=attribute-error
import matplotlib
import numpy as np
import torch
import tqdm


def loadvideo(filename: str) -> np.ndarray:
    """Load a video file into a numpy array.

    Args:
        filename: Path of the video to read.

    Returns:
        A ``uint8`` array of shape ``(3, frames, height, width)`` holding the
        frames converted from BGR to RGB.

    Raises:
        FileNotFoundError: If ``filename`` does not exist.
        ValueError: If a frame announced by the container cannot be read.
    """
    if not os.path.exists(filename):
        raise FileNotFoundError(filename)
    capture = cv2.VideoCapture(filename)

    # Release the capture handle even when decoding fails part-way through.
    try:
        frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))

        v = np.zeros((frame_count, frame_height, frame_width, 3), np.uint8)

        for count in range(frame_count):
            ret, frame = capture.read()
            if not ret:
                raise ValueError("Failed to load frame #{} of {}.".format(count, filename))

            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            v[count, :, :] = frame
    finally:
        capture.release()

    # (frames, height, width, channels) -> (channels, frames, height, width)
    v = v.transpose((3, 0, 1, 2))

    return v


def savevideo(filename: str, array: np.ndarray, fps: typing.Union[float, int] = 1):
    """Write an RGB video array to ``filename`` using the ``mp4v`` codec.

    Args:
        filename: Output path.
        array: Array of shape ``(3, frames, height, width)`` in RGB order.
        fps: Frame rate of the written video.

    Raises:
        ValueError: If the first dimension of ``array`` is not 3.
    """
    c, _, height, width = array.shape

    if c != 3:
        raise ValueError("savevideo expects array of shape (channels=3, frames, height, width), got shape ({})".format(", ".join(map(str, array.shape))))
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(filename, fourcc, fps, (width, height))

    # Release the writer so the container is finalised, even on error.
    try:
        # Iterate frame by frame as (height, width, channels).
        for frame in array.transpose((1, 2, 3, 0)):
            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
            out.write(frame)
    finally:
        out.release()


def get_mean_and_std(dataset: torch.utils.data.Dataset,
                     samples: int = 128,
                     batch_size: int = 8,
                     num_workers: int = 4):
    """Estimate the per-channel mean and standard deviation of a video dataset.

    Args:
        dataset: Dataset whose items start with a video tensor; after batching
            the first two axes are swapped and the result is viewed as
            ``(3, -1)``, so three channels on axis 1 are expected.
        samples: Number of items drawn at random (without replacement) for the
            estimate; ``None`` or a value not smaller than ``len(dataset)``
            uses the whole dataset.
        batch_size: Batch size of the internal ``DataLoader``.
        num_workers: Worker count of the internal ``DataLoader``.

    Returns:
        Tuple ``(mean, std)`` of ``float32`` arrays with one entry per channel.
    """
    if samples is not None and len(dataset) > samples:
        indices = np.random.choice(len(dataset), samples, replace=False)
        dataset = torch.utils.data.Subset(dataset, indices)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True)

    n = 0  # number of elements taken per channel
    s1 = 0.  # sum of elements along channels (ends up as np.array of dimension (channels,))
    s2 = 0.  # sum of squares of elements along channels (ends up as np.array of dimension (channels,))
    for (x, *_) in tqdm.tqdm(dataloader):
        x = x.transpose(0, 1).contiguous().view(3, -1)
        n += x.shape[1]
        # Accumulate in float64: float32 running sums over millions of pixels
        # lose enough precision to distort the variance computed below.
        s1 += torch.sum(x, dim=1, dtype=torch.float64).numpy()
        s2 += torch.sum(x ** 2, dim=1, dtype=torch.float64).numpy()
    mean = s1 / n  # type: np.ndarray
    # E[x^2] - E[x]^2 can dip slightly below zero through rounding; clamp it
    # so the square root never yields NaN.
    std = np.sqrt(np.maximum(s2 / n - mean ** 2, 0.))  # type: np.ndarray

    mean = mean.astype(np.float32)
    std = std.astype(np.float32)

    return mean, std


def bootstrap(a, b, func, samples=10000):
    """Compute a statistic together with bootstrapped 5% / 95% bounds.

    Args:
        a: First sequence of values.
        b: Second sequence, indexed in parallel with ``a``.
        func: Callable ``func(a, b)`` returning a sortable statistic.
        samples: Number of bootstrap resamples.

    Returns:
        Tuple ``(func(a, b), lower, upper)`` where ``lower`` and ``upper`` are
        the sorted resampled statistics at positions ``round(0.05 * samples)``
        and ``round(0.95 * samples)``.
    """
    a = np.array(a)
    b = np.array(b)

    bootstraps = []
    for _ in range(samples):
        # Resample indices with replacement, same size as the input.
        ind = np.random.choice(len(a), len(a))
        bootstraps.append(func(a[ind], b[ind]))
    bootstraps = sorted(bootstraps)

    # For small sample counts round(0.95 * n) can equal n (e.g. n = 1 or 10),
    # which used to raise IndexError; clamp to the last valid position.
    last = len(bootstraps) - 1
    lower = bootstraps[min(round(0.05 * len(bootstraps)), last)]
    upper = bootstraps[min(round(0.95 * len(bootstraps)), last)]

    return func(a, b), lower, upper


def latexify():
    """Update matplotlib's rcParams: pdf backend, size-8 text, serif fonts."""
    size = 8
    sized_keys = (
        'axes.titlesize',
        'axes.labelsize',
        'font.size',
        'legend.fontsize',
        'xtick.labelsize',
        'ytick.labelsize',
    )

    params = {'backend': 'pdf'}
    # Every text element shares the same point size.
    params.update({key: size for key in sized_keys})
    params['font.family'] = 'DejaVu Serif'
    params['font.serif'] = 'Computer Modern'

    matplotlib.rcParams.update(params)


def dice_similarity_coefficient(inter, union):
    """Computes the dice similarity coefficient.

    Args:
        inter (iterable): iterable of the intersections
        union (iterable): iterable of the unions

    Returns:
        ``2 * sum(inter) / (sum(union) + sum(inter))``.
    """
    # Sum the intersections once: iterating `inter` twice silently yields 0
    # the second time when it is a generator or other one-shot iterator.
    total_inter = sum(inter)
    return 2 * total_inter / (sum(union) + total_inter)

# R(2+1)D Model: currently disabled (the implementation below is commented out)
 "A Closer Look at Spatiotemporal Convolutions for Action Recognition"

In [55]:
# import math

# import torch.nn as nn
# from torch.nn.modules.utils import _triple

# class SpatioTemporalConv(nn.Module):
#     def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True):
#         super(SpatioTemporalConv, self).__init__()
#         kernel_size = _triple(kernel_size)
#         stride = _triple(stride)
#         padding = _triple(padding)
#         spatial_kernel_size =  [1, kernel_size[1], kernel_size[2]]
#         spatial_stride =  [1, stride[1], stride[2]]
#         spatial_padding =  [0, padding[1], padding[2]]
#         temporal_kernel_size = [kernel_size[0], 1, 1]
#         temporal_stride =  [stride[0], 1, 1]
#         temporal_padding =  [padding[0], 0, 0]
#         intermed_channels = int(math.floor((kernel_size[0] * kernel_size[1] * kernel_size[2] * in_channels * out_channels)/ \
#                             (kernel_size[1]* kernel_size[2] * in_channels + kernel_size[0] * out_channels)))
#         self.spatial_conv = nn.Conv3d(in_channels, intermed_channels, spatial_kernel_size,
#                                     stride=spatial_stride, padding=spatial_padding, bias=bias)
#         self.bn = nn.BatchNorm3d(intermed_channels)
#         self.relu = nn.ReLU()
#         self.temporal_conv = nn.Conv3d(intermed_channels, out_channels, temporal_kernel_size, 
#                                     stride=temporal_stride, padding=temporal_padding, bias=bias)
#     def forward(self, x):
#         x = self.relu(self.bn(self.spatial_conv(x)))
#         x = self.temporal_conv(x)
#         return x


# class SpatioTemporalResBlock(nn.Module):
#     r"""Single block for the ResNet network. Uses SpatioTemporalConv in 
#         the standard ResNet block layout (conv->batchnorm->ReLU->conv->batchnorm->sum->ReLU)
        
#         Args:
#             in_channels (int): Number of channels in the input tensor.
#             out_channels (int): Number of channels in the output produced by the block.
#             kernel_size (int or tuple): Size of the convolving kernels.
#             downsample (bool, optional): If ``True``, the output size is to be smaller than the input. Default: ``False``
#         """
#     def __init__(self, in_channels, out_channels, kernel_size, downsample=False):
#         super(SpatioTemporalResBlock, self).__init__()
#         self.downsample = downsample
#         padding = kernel_size//2

#         if self.downsample:
#             # downsample with stride =2 the input x
#             self.downsampleconv = SpatioTemporalConv(in_channels, out_channels, 1, stride=2)
#             self.downsamplebn = nn.BatchNorm3d(out_channels)
#             self.conv1 = SpatioTemporalConv(in_channels, out_channels, kernel_size, padding=padding, stride=2)
#         else:
#             self.conv1 = SpatioTemporalConv(in_channels, out_channels, kernel_size, padding=padding)
#         self.bn1 = nn.BatchNorm3d(out_channels)
#         self.relu1 = nn.ReLU()
#         self.conv2 = SpatioTemporalConv(out_channels, out_channels, kernel_size, padding=padding)
#         self.bn2 = nn.BatchNorm3d(out_channels)
#         self.outrelu = nn.ReLU()

#     def forward(self, x):
#         res = self.relu1(self.bn1(self.conv1(x)))    
#         res = self.bn2(self.conv2(res))

#         if self.downsample:
#             x = self.downsamplebn(self.downsampleconv(x))

#         return self.outrelu(x + res)


# class SpatioTemporalResLayer(nn.Module):
#     r"""Forms a single layer of the ResNet network, with a number of repeating 
#     blocks of same output size stacked on top of each other
        
#         Args:
#             in_channels (int): Number of channels in the input tensor.
#             out_channels (int): Number of channels in the output produced by the layer.
#             kernel_size (int or tuple): Size of the convolving kernels.
#             layer_size (int): Number of blocks to be stacked to form the layer
#             block_type (Module, optional): Type of block that is to be used to form the layer. Default: SpatioTemporalResBlock. 
#             downsample (bool, optional): If ``True``, the first block in layer will implement downsampling. Default: ``False``
#         """

#     def __init__(self, in_channels, out_channels, kernel_size, layer_size, block_type=SpatioTemporalResBlock, downsample=False):
        
#         super(SpatioTemporalResLayer, self).__init__()
#         self.block1 = block_type(in_channels, out_channels, kernel_size, downsample)
#         self.blocks = nn.ModuleList([])
#         for i in range(layer_size - 1):
#             self.blocks += [block_type(out_channels, out_channels, kernel_size)]

#     def forward(self, x):
#         x = self.block1(x)
#         for block in self.blocks:
#             x = block(x)

#         return x


# class R2Plus1DNet(nn.Module):
#     r"""Forms the overall ResNet feature extractor by initializng 5 layers, with the number of blocks in 
#     each layer set by layer_sizes, and by performing a global average pool at the end producing a 
#     512-dimensional vector for each element in the batch.
        
#         Args:
#             layer_sizes (tuple): An iterable containing the number of blocks in each layer
#             block_type (Module, optional): Type of block that is to be used to form the layers. Default: SpatioTemporalResBlock. 
#         """
#     def __init__(self, layer_sizes, block_type=SpatioTemporalResBlock):
#         super(R2Plus1DNet, self).__init__()
#         self.conv1 = SpatioTemporalConv(3, 64, [3, 7, 7], stride=[1, 2, 2], padding=[1, 3, 3])
#         self.conv2 = SpatioTemporalResLayer(64, 64, 3, layer_sizes[0], block_type=block_type)
#         self.conv3 = SpatioTemporalResLayer(64, 128, 3, layer_sizes[1], block_type=block_type, downsample=True)
#         self.conv4 = SpatioTemporalResLayer(128, 256, 3, layer_sizes[2], block_type=block_type, downsample=True)
#         self.conv5 = SpatioTemporalResLayer(256, 512, 3, layer_sizes[3], block_type=block_type, downsample=True)
#         self.pool = nn.AdaptiveAvgPool3d(1)
    
#     def forward(self, x):
#         x = self.conv1(x)
#         x = self.conv2(x)
#         x = self.conv3(x)
#         x = self.conv4(x)
#         x = self.conv5(x)

#         x = self.pool(x)
        
#         return x.view(-1, 512)

# class R2Plus1DClassifier(nn.Module):
#     def __init__(self, num_classes, layer_sizes, block_type=SpatioTemporalResBlock):
#         super(R2Plus1DClassifier, self).__init__()

#         self.res2plus1d = R2Plus1DNet(layer_sizes, block_type)
#         self.linear = nn.Linear(512, num_classes)

#     def forward(self, x):
#         x = self.res2plus1d(x)
#         x = self.linear(x) 

#         return x   

# r21d_model = R2Plus1DClassifier(num_classes=1, layer_sizes = [2, 2, 2, 2])

# C3D Model : Done
"Learning Spatiotemporal Features with 3D Convolutional Networks"

In [56]:
import torch.nn as nn
import torch
import torch.nn.functional as F


class OutConv(nn.Module):
    """Point-wise (kernel size 1) 3D convolution from ``in_channels`` to ``out_channels``."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Apply the 1x1x1 convolution to ``x``."""
        projected = self.conv(x)
        return projected


class DoubleConv3D(nn.Module):
    """Two stacked (3x3x3 convolution -> batch norm -> ReLU) stages for 3D input.

    ``mid_channels`` is the width between the two stages; it falls back to
    ``out_channels`` when not given (or falsy).
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        mid_channels = mid_channels or out_channels

        stages = [
            # first stage: in -> mid
            nn.Conv3d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm3d(mid_channels),
            nn.ReLU(inplace=True),
            # second stage: mid -> out
            nn.Conv3d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both convolution stages."""
        out = self.double_conv(x)
        return out


class Up3D(nn.Module):
    """Upsample, align with a skip tensor, concatenate and double-convolve (3D).

    With ``bilinear=True`` the upsampling is a trilinear ``nn.Upsample``;
    otherwise a transposed convolution that halves the channel count is used.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        if not bilinear:
            # learned upsampling that also halves the channels
            self.up = nn.ConvTranspose3d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv3D(in_channels, out_channels)
        else:
            # interpolation keeps the channels; the double conv narrows them
            self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
            self.conv = DoubleConv3D(in_channels, out_channels, in_channels // 2)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to the size of ``x2`` and fuse both."""
        upsampled = self.up(x1)

        # Size gaps along depth, height and width (axes 2, 3, 4).
        gap_d, gap_h, gap_w = (x2.size()[axis] - upsampled.size()[axis] for axis in (2, 3, 4))

        # F.pad expects the last axis first: (W_left, W_right, H_..., D_...).
        pad_spec = [
            gap_w // 2, gap_w - gap_w // 2,
            gap_h // 2, gap_h - gap_h // 2,
            gap_d // 2, gap_d - gap_d // 2,
        ]
        upsampled = F.pad(upsampled, pad_spec)

        # Skip tensor first, then the upsampled one, along the channel axis.
        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)

class C3D(nn.Module):
    """
    The C3D network as described in
    Tran, Du, et al. "Learning spatiotemporal features with 3d convolutional networks."
    Proceedings of the IEEE international conference on computer vision. 2015.

    This variant has two outputs: the fully connected head applied to the last
    feature stage, and a decoder (``up1``..``up3``, ``final_up``, ``out_seg``)
    fed by the earlier feature stages through skip connections.

    Args:
        num_classes: Output width of the fully connected head and number of
            channels of the decoder output.
        input_channel: Number of channels of the input video.

    NOTE(review): the feature stages contain no activation between their
    convolutions -- confirm this is intended.
    """

    def __init__(self, num_classes, input_channel=3):
        super(C3D, self).__init__()

        self.feature1 = nn.Sequential(
            nn.Conv3d(input_channel, 64, kernel_size=(3, 3, 3), padding=(1, 1, 1)),
            nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2))
        )
        self.feature2 = nn.Sequential(
            nn.Conv3d(64, 128, kernel_size=(3, 3, 3), padding=(1, 1, 1)),
            nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))
        )
        self.feature3 = nn.Sequential(
            nn.Conv3d(128, 256, kernel_size=(3, 3, 3), padding=(1, 1, 1)),
            nn.Conv3d(256, 256, kernel_size=(3, 3, 3), padding=(1, 1, 1)),
            nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2)),
        )
        self.feature4 = nn.Sequential(
            nn.Conv3d(256, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1)),
            nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1)),
            nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))
        )
        self.feature5 = nn.Sequential(
            nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1)),
            nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1)),
            # Fixed (1, 4, 4) output -> 512 * 1 * 4 * 4 = 8192 features for fc.
            nn.AdaptiveMaxPool3d(output_size=(1, 4, 4))
        )

        # The head works on flat (N, features) activations, so plain
        # element-wise Dropout is used; Dropout3d is channel-wise, expects
        # 4-D/5-D input and is deprecated for 2-D tensors. Dropout has no
        # parameters, so state_dict keys are unaffected.
        self.fc = nn.Sequential(
            nn.Linear(8192, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(4096, num_classes)
        )
        self.up1 = Up3D(in_channels=512, out_channels=256, bilinear=False)
        self.up2 = Up3D(in_channels=256, out_channels=128, bilinear=False)
        self.up3 = Up3D(in_channels=128, out_channels=64, bilinear=False)

        self.final_up = nn.Sequential(
            nn.Upsample(scale_factor=(1, 2, 2), mode='trilinear', align_corners=True),  # Only upscale width & height
            nn.Conv3d(64, 32, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm3d(32),
            nn.ReLU(inplace=True)
        )

        self.out_seg = OutConv(32, num_classes)
        self.__init_weight()

    def forward(self, x):
        """Return ``(logits, x)``: the head output and the decoder output."""
        # Encoder
        x1 = self.feature1(x)
        x2 = self.feature2(x1)
        x3 = self.feature3(x2)
        x4 = self.feature4(x3)
        x5 = self.feature5(x4)

        # Decoder with skip connections from the earlier stages
        x = self.up1(x4, x3)
        x = self.up2(x, x2)
        x = self.up3(x, x1)
        x = self.final_up(x)
        x = self.out_seg(x)

        # Fully connected head on the flattened last stage
        logits = self.fc(x5.view(-1, 8192))
        return logits,x

    def __init_weight(self):
        """Kaiming-initialise Conv3d weights; set BatchNorm3d weight to 1 and bias to 0."""
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

# Instantiate the network with num_classes=1 (one head output, one decoder output channel).
c3d_model = C3D(num_classes=1)

# X3D Model: currently disabled (the implementation below is commented out)
"X3D: Expanding Architectures for Efficient Video Recognition models"

In [57]:
# import math
# from functools import partial

# import torch
# import torch.nn as nn
# import torch.nn.functional as F


# class SubBatchNorm3d(nn.Module):
#     """ FROM SLOWFAST """
#     def __init__(self, num_splits, **args):
#         super(SubBatchNorm3d, self).__init__()
#         self.num_splits = num_splits
#         self.num_features = args["num_features"]
#         # Keep only one set of weight and bias.
#         if args.get("affine", True):
#             self.affine = True
#             args["affine"] = False
#             self.weight = torch.nn.Parameter(torch.ones(self.num_features))
#             self.bias = torch.nn.Parameter(torch.zeros(self.num_features))
#         else:
#             self.affine = False
#         self.bn = nn.BatchNorm3d(**args)
#         args["num_features"] = self.num_features * self.num_splits
#         self.split_bn = nn.BatchNorm3d(**args)

#     def _get_aggregated_mean_std(self, means, stds, n):
#         mean = means.view(n, -1).sum(0) / n
#         std = (
#             stds.view(n, -1).sum(0) / n
#             + ((means.view(n, -1) - mean) ** 2).view(n, -1).sum(0) / n
#         )
#         return mean.detach(), std.detach()

#     def aggregate_stats(self):
#         """Synchronize running_mean, and running_var. Call this before eval."""
#         if self.split_bn.track_running_stats:
#             (
#                 self.bn.running_mean.data,
#                 self.bn.running_var.data,
#             ) = self._get_aggregated_mean_std(
#                 self.split_bn.running_mean,
#                 self.split_bn.running_var,
#                 self.num_splits,
#             )

#     def forward(self, x):
#         if self.training:
#             n, c, t, h, w = x.shape
#             x = x.view(n // self.num_splits, c * self.num_splits, t, h, w)
#             x = self.split_bn(x)
#             x = x.view(n, c, t, h, w)
#         else:
#             x = self.bn(x)
#         if self.affine:
#             x = x * self.weight.view((-1, 1, 1, 1))
#             x = x + self.bias.view((-1, 1, 1, 1))
#         return x


# class Swish(nn.Module):
#     """ FROM SLOWFAST """
#     """Swish activation function: x * sigmoid(x)."""
#     def __init__(self):
#         super(Swish, self).__init__()

#     def forward(self, x):
#         return SwishEfficient.apply(x)


# class SwishEfficient(torch.autograd.Function):
#     """ FROM SLOWFAST """
#     """Swish activation function: x * sigmoid(x)."""
#     @staticmethod
#     def forward(ctx, x):
#         result = x * torch.sigmoid(x)
#         ctx.save_for_backward(x)
#         return result

#     @staticmethod
#     def backward(ctx, grad_output):
#         x = ctx.saved_variables[0]
#         sigmoid_x = torch.sigmoid(x)
#         return grad_output * (sigmoid_x * (1 + x * (1 - sigmoid_x)))


# def conv3x3x3(in_planes, out_planes, stride=1):
#     return nn.Conv3d(in_planes,
#                      out_planes,
#                      kernel_size=3,
#                      stride=(1,stride,stride),
#                      padding=1,
#                      bias=False,
#                      groups=in_planes
#                      )


# def conv1x1x1(in_planes, out_planes, stride=1):
#     return nn.Conv3d(in_planes,
#                      out_planes,
#                      kernel_size=1,
#                      stride=(1,stride,stride),
#                      bias=False)


# class Bottleneck(nn.Module):
#     def __init__(self, in_planes, planes, stride=1, downsample=None, index=0, base_bn_splits=8):
#         super(Bottleneck, self).__init__()

#         self.index = index
#         self.base_bn_splits = base_bn_splits
#         self.conv1 = conv1x1x1(in_planes, planes[0])
#         self.bn1 = SubBatchNorm3d(num_splits=self.base_bn_splits, num_features=planes[0], affine=True) #nn.BatchNorm3d(planes[0])
#         self.conv2 = conv3x3x3(planes[0], planes[0], stride)
#         self.bn2 = SubBatchNorm3d(num_splits=self.base_bn_splits, num_features=planes[0], affine=True) #nn.BatchNorm3d(planes[0])
#         self.conv3 = conv1x1x1(planes[0], planes[1])
#         self.bn3 = SubBatchNorm3d(num_splits=self.base_bn_splits, num_features=planes[1], affine=True) #nn.BatchNorm3d(planes[1])
#         self.swish = Swish() #nn.Hardswish()
#         self.relu = nn.ReLU(inplace=True)
#         if self.index % 2 == 0:
#             width = self.round_width(planes[0])
#             self.global_pool = nn.AdaptiveAvgPool3d((1,1,1))
#             self.fc1 = nn.Conv3d(planes[0], width, kernel_size=1, stride=1)
#             self.fc2 = nn.Conv3d(width, planes[0], kernel_size=1, stride=1)
#             self.sigmoid = nn.Sigmoid()
#         self.downsample = downsample
#         self.stride = stride

#     def round_width(self, width, multiplier=0.0625, min_width=8, divisor=8):
#         if not multiplier:
#             return width

#         width *= multiplier
#         min_width = min_width or divisor
#         width_out = max(
#             min_width, int(width + divisor / 2) // divisor * divisor
#         )
#         if width_out < 0.9 * width:
#             width_out += divisor
#         return int(width_out)


#     def forward(self, x):
#         residual = x

#         out = self.conv1(x)
#         out = self.bn1(out)
#         out = self.relu(out)

#         out = self.conv2(out)
#         out = self.bn2(out)
#         # Squeeze-and-Excitation
#         if self.index % 2 == 0:
#             se_w = self.global_pool(out)
#             se_w = self.fc1(se_w)
#             se_w = self.relu(se_w)
#             se_w = self.fc2(se_w)
#             se_w = self.sigmoid(se_w)
#             out = out * se_w
#         out = self.swish(out)

#         out = self.conv3(out)
#         out = self.bn3(out)

#         if self.downsample is not None:
#             residual = self.downsample(x)

#         out += residual
#         out = self.relu(out)

#         return out


# class ResNet(nn.Module):

#     def __init__(self,
#                  block,
#                  layers,
#                  block_inplanes,
#                  n_input_channels=3,
#                  shortcut_type='B',
#                  widen_factor=1.0,
#                  dropout=0.5,
#                  n_classes=1,
#                  base_bn_splits=8,
#                  task='class'):
#         super(ResNet, self).__init__()

#         block_inplanes = [(int(x * widen_factor),int(y * widen_factor)) for x,y in block_inplanes]
#         self.index = 0
#         self.base_bn_splits = base_bn_splits
#         self.task = task

#         self.in_planes = block_inplanes[0][1]

#         self.conv1_s = nn.Conv3d(n_input_channels,
#                                self.in_planes,
#                                kernel_size=(1, 3, 3),
#                                stride=(1, 2, 2),
#                                padding=(0, 1, 1),
#                                bias=False)
#         self.conv1_t = nn.Conv3d(self.in_planes,
#                                self.in_planes,
#                                kernel_size=(5, 1, 1),
#                                stride=(1, 1, 1),
#                                padding=(2, 0, 0),
#                                bias=False,
#                                groups=self.in_planes)
#         self.bn1 = SubBatchNorm3d(num_splits=self.base_bn_splits, num_features=self.in_planes, affine=True) #nn.BatchNorm3d(self.in_planes)
#         self.relu = nn.ReLU(inplace=True)
#         self.layer1 = self._make_layer(block,
#                                        block_inplanes[0],
#                                        layers[0],
#                                        shortcut_type,
#                                        stride=2)
#         self.layer2 = self._make_layer(block,
#                                        block_inplanes[1],
#                                        layers[1],
#                                        shortcut_type,
#                                        stride=2)
#         self.layer3 = self._make_layer(block,
#                                        block_inplanes[2],
#                                        layers[2],
#                                        shortcut_type,
#                                        stride=2)
#         self.layer4 = self._make_layer(block,
#                                        block_inplanes[3],
#                                        layers[3],
#                                        shortcut_type,
#                                        stride=2)
#         self.conv5 = nn.Conv3d(block_inplanes[3][1],
#                                block_inplanes[3][0],
#                                kernel_size=(1, 1, 1),
#                                stride=(1, 1, 1),
#                                padding=(0, 0, 0),
#                                bias=False)
#         self.bn5 = SubBatchNorm3d(num_splits=self.base_bn_splits, num_features=block_inplanes[3][0], affine=True) #nn.BatchNorm3d(block_inplanes[3][0])
#         if task == 'class':
#             self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
#         elif task == 'loc':
#             self.avgpool = nn.AdaptiveAvgPool3d((None, 1, 1))
#         self.fc1 = nn.Conv3d(block_inplanes[3][0], 2048, bias=False, kernel_size=1, stride=1)
#         self.fc2 = nn.Linear(2048, n_classes)
#         self.dropout = nn.Dropout(dropout)

#         for m in self.modules():
#             if isinstance(m, nn.Conv3d):
#                 nn.init.kaiming_normal_(m.weight,
#                                         mode='fan_out',
#                                         nonlinearity='relu')

#     def _downsample_basic_block(self, x, planes, stride):
#         out = F.avg_pool3d(x, kernel_size=1, stride=stride)
#         zero_pads = torch.zeros(out.size(0), planes - out.size(1), out.size(2),
#                                 out.size(3), out.size(4))
#         if isinstance(out.data, torch.cuda.FloatTensor):
#             zero_pads = zero_pads.cuda()

#         out = torch.cat([out.data, zero_pads], dim=1)

#         return out

#     def _make_layer(self, block, planes, blocks, shortcut_type, stride=1):
#         downsample = None
#         if stride != 1 or self.in_planes != planes[1]:
#             if shortcut_type == 'A':
#                 downsample = partial(self._downsample_basic_block,
#                                      planes=planes[1],
#                                      stride=stride)
#             else:
#                 downsample = nn.Sequential(
#                     conv1x1x1(self.in_planes, planes[1], stride),
#                     SubBatchNorm3d(num_splits=self.base_bn_splits, num_features=planes[1], affine=True) #nn.BatchNorm3d(planes[1])
#                     )

#         layers = []
#         layers.append(
#             block(in_planes=self.in_planes,
#                   planes=planes,
#                   stride=stride,
#                   downsample=downsample,
#                   index=self.index,
#                   base_bn_splits=self.base_bn_splits))
#         self.in_planes = planes[1]
#         self.index += 1
#         for i in range(1, blocks):
#             layers.append(block(self.in_planes, planes, index=self.index, base_bn_splits=self.base_bn_splits))
#             self.index += 1

#         self.index = 0
#         return nn.Sequential(*layers)


#     def replace_logits(self, n_classes):
#         self.fc2 = nn.Linear(2048, n_classes)


#     def update_bn_splits_long_cycle(self, long_cycle_bn_scale):
#         for m in self.modules():
#             if isinstance(m, SubBatchNorm3d):
#                 m.num_splits = self.base_bn_splits * long_cycle_bn_scale
#                 m.split_bn = nn.BatchNorm3d(num_features=m.num_features*m.num_splits, affine=False).to(m.weight.device)
#         return self.base_bn_splits * long_cycle_bn_scale


#     def aggregate_sub_bn_stats(self):
#         """find all SubBN modules and aggregate sub-BN stats."""
#         count = 0
#         for m in self.modules():
#             if isinstance(m, SubBatchNorm3d):
#                 m.aggregate_stats()
#                 count += 1
#         return count


#     def forward(self, x):
#         x = self.conv1_s(x)
#         print(x.shape)
#         x = self.conv1_t(x)
#         print(x.shape)
#         x = self.bn1(x)
#         print(x.shape)
#         x = self.relu(x)
#         print(x.shape)

#         x = self.layer1(x)
#         print(x.shape)
#         x = self.layer2(x)
#         print(x.shape)
#         x = self.layer3(x)
#         print(x.shape)
#         x = self.layer4(x)
#         print(x.shape)

#         x = self.conv5(x)
#         print(x.shape)
#         x = self.bn5(x)
#         print(x.shape)
#         x = self.relu(x)
#         print(x.shape)

#         x = self.avgpool(x)
#         print(x.shape)

#         x = self.fc1(x)
#         print(x.shape)
#         x = self.relu(x)
#         print(x.shape)

#         if self.task == 'class':
#             x = x.squeeze(4).squeeze(3).squeeze(2) # B C
#             x = self.dropout(x)
#             x = self.fc2(x)
#         if self.task == 'loc':
#             x = x.squeeze(4).squeeze(3).permute(0,2,1) # B T C
#             x = self.dropout(x)
#             x = self.fc2(x).permute(0,2,1) # B C T

#         return x


# def replace_logits(self, n_classes):
#         self.fc2 = nn.Linear(2048, n_classes)


# def get_inplanes(version):
#     planes = {'S':[(54,24), (108,48), (216,96), (432,192)],
#               'M':[(54,24), (108,48), (216,96), (432,192)],
#               'XL':[(72,32), (162,72), (306,136), (630,280)]}
#     return planes[version]


# def get_blocks(version):
#     blocks = {'S':[3,5,11,7],
#               'M':[3,5,11,7],
#               'XL':[5,10,25,15]}
#     return blocks[version]


# def generate_model(x3d_version, **kwargs):
#     model = ResNet(Bottleneck, get_blocks(x3d_version), get_inplanes(x3d_version), **kwargs)
#     return model

# x3d_model = generate_model('S')

# Training : MultiTask

In [100]:
# --- Dataset configuration -------------------------------------------------
data_dir = "/kaggle/input/echonet-pediatric/Dataset/A4C"

# Normalisation statistics are estimated from the training split.
mean, std = get_mean_and_std(Echo(root=data_dir, split="train"))
frames = 32

# Keyword arguments shared by every split. The order of `target_type` is the
# order in which the targets are unpacked from each batch during training.
kwargs = dict(
    target_type=[
        "EF",
        "LargeFrame",
        "SmallFrame",
        "LargeTrace",
        "SmallTrace",
        "LargeIndex",
        "SmallIndex",
    ],
    mean=mean,
    std=std,
    length=200,
    period=1,
)

print(f"\n\nMean : {mean} \nStandard Deviation : {std}")

# Reference values (hard-coded) printed for comparison with the statistics above.
print("\n\nThis is what I have got the Echonet Dynamic Dataset - Adults Echo : ")
print(
    "Mean : [33.66532  33.742973 33.911003] \n"
    "Standard Deviation : [50.45345  50.4825   50.614986]\n\n"
)

# One Echo dataset per split, built in the order train -> val.
dataset = {
    split_name: Echo(root=data_dir, split=split_name, **kwargs)
    for split_name in ("train", "val")
}

videos way before:  2580
videos before:  2580
videos :  2483


100%|██████████| 16/16 [00:01<00:00,  8.08it/s]




Mean : [26.098072 24.678114 28.73842 ] 
Standard Deviation : [45.68355  42.283268 47.268494]


This is what I have got the Echonet Dynamic Dataset - Adults Echo : 
Mean : [33.66532  33.742973 33.911003] 
Standard Deviation : [50.45345  50.4825   50.614986]


videos way before:  2580
videos before:  2580
videos :  2483
videos way before:  336
videos before:  336
videos :  326


In [107]:
# Build a loader over the validation split and look at a single batch.
ds = dataset["val"]
dataloader = torch.utils.data.DataLoader(
    ds,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=True,
    pin_memory=True,
    drop_last=False,
)

# `i` (the video batch) and `j` (its targets) stay bound after the loop so that
# later cells can inspect them; only the first batch is consumed.
for i, j in dataloader:
    print(i.shape)
    break

torch.Size([16, 3, 200, 112, 112])


In [111]:
# For every clip in the batch pick the single frame whose index is given by the
# last target in `j` (presumably "SmallIndex", the last entry of
# kwargs["target_type"] -- TODO confirm). The batch dimension is read from `i`
# instead of being hard-coded to 16, so this also works for other batch sizes
# and for a short final batch.
selected_frames = i[torch.arange(i.size(0)), :, j[-1], :, :]

In [123]:
import math

# --- Run configuration -----------------------------------------------------
num_epochs = 1
batch_size = 4
num_workers = 4
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Directory receiving the log, checkpoints and evaluation files.
output_seg = "MultiTask Model"
os.makedirs(output_seg, exist_ok=True)

# --- Model, optimiser and learning-rate schedule ---------------------------
model = c3d_model
optim = torch.optim.SGD(
    model.parameters(),
    lr=1e-5,
    momentum=0.9,
    weight_decay=1e-5,
)
# An infinite step period means StepLR never decays the learning rate.
lr_step_period = math.inf
scheduler = torch.optim.lr_scheduler.StepLR(optim, lr_step_period)

In [124]:
# Move the model to the device selected in the setup cell (CUDA if available, else CPU).
model = model.to(device)

In [61]:
def collate_fn(x):
    """Merge a list of ``(clips, target)`` samples into one batch.

    Each sample's array is joined with the others along axis 1, after which
    axes 0 and 1 are swapped so that the concatenated axis comes first.

    Returns:
        A tuple ``(batch, targets, lengths)`` where ``batch`` is the resulting
        tensor, ``targets`` is the tuple of per-sample targets and ``lengths``
        lists each sample's size along axis 1 (so the batch can be split back).
    """
    arrays, targets = zip(*x)
    lengths = [arr.shape[1] for arr in arrays]
    joined = np.concatenate(arrays, axis=1)
    batch = torch.as_tensor(np.swapaxes(joined, 0, 1))
    return batch, targets, lengths

In [126]:
def _segment_annotated_frames(model, seg_output, sample_idx, frame_idx, trace):
    """Score the segmentation of one annotated frame per sample.

    Picks frame ``frame_idx[b]`` of ``seg_output[b]`` for every sample ``b``,
    feeds the selected frames to ``model`` and compares channel 0 of its
    ``"out"`` entry with ``trace``.

    Returns:
        ``(loss, inter, union)``: the summed BCE-with-logits loss, and the
        per-sample intersection / union pixel counts between the thresholded
        prediction (logit > 0) and the trace (> 0).
    """
    frames = seg_output[sample_idx, :, frame_idx, :, :]
    # NOTE(review): the same `model` is called here on 4-D frame batches and is
    # expected to return a mapping with an "out" key, while the video call in
    # seg_run_epoch returns a 2-tuple -- confirm the model supports both.
    logits = model(frames)["out"][:, 0, :, :]
    loss = torch.nn.functional.binary_cross_entropy_with_logits(
        logits, trace, reduction="sum"
    )
    pred = (logits > 0).float()
    truth = (trace > 0).float()
    n_samples = sample_idx.numel()
    inter = (pred * truth).view(n_samples, -1).sum(dim=1)
    union = ((pred + truth) > 0).float().view(n_samples, -1).sum(dim=1)
    return loss, inter, union


def seg_run_epoch(model, dataloader, train, optim, device, save_all=False, block_size=25):
    """Run one epoch of joint EF regression and segmentation.

    Args:
        model: network called as ``model(X)`` returning ``(ef_pred, seg_output)``.
        dataloader: yields ``(X, (outcome, large_frame, small_frame,
            large_trace, small_trace, large_index, small_index))``.
        train: when true the model is put in training mode, gradients are
            enabled and ``optim`` is stepped after every batch.
        optim: optimiser; only used when ``train`` is true.
        device: device the batch tensors are moved to.
        save_all: when true the EF predictions are returned as a list with one
            tensor per batch instead of a single concatenated tensor.
        block_size: unused; kept so existing callers keep working.

    Returns:
        Tuple ``(loss, large_inter, large_union, small_inter, small_union,
        ef_loss, yhat, y)``:
        ``loss`` is the accumulated loss divided by the sample count and by
        ``112 * 112``; the four inter/union tensors hold per-sample pixel
        counts; ``ef_loss`` is the (detached) MSE of the last batch; ``yhat``
        are the detached EF predictions and ``y`` the matching targets.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    total_loss = 0.0  # cumulative loss as a Python float
    n = 0             # number of samples processed
    pos = torch.tensor(0.0, device=device)
    neg = torch.tensor(0.0, device=device)
    pos_pix = None    # per-pixel positive counts, accumulated over batches
    neg_pix = None
    ef_loss = None

    model.train(train)

    # Running intersection / union totals, only used for the progress bar.
    large_inter = torch.tensor(0.0, device=device)
    large_union = torch.tensor(0.0, device=device)
    small_inter = torch.tensor(0.0, device=device)
    small_union = torch.tensor(0.0, device=device)
    large_inter_list = []
    large_union_list = []
    small_inter_list = []
    small_union_list = []
    yhat_list = []
    y_list = []

    with torch.set_grad_enabled(train):
        with tqdm.tqdm(total=len(dataloader)) as pbar:
            for (X_input, (outcome, _large_frame, _small_frame,
                           large_trace, small_trace,
                           large_index, small_index)) in dataloader:
                X_input = X_input.to(device)
                outcome = outcome.to(device)
                large_trace = large_trace.to(device)
                small_trace = small_trace.to(device)
                large_index = large_index.to(device)
                small_index = small_index.to(device)

                # Pixel counts of the human traces (overall and per pixel).
                pos += (large_trace == 1).sum()
                pos += (small_trace == 1).sum()
                neg += (large_trace == 0).sum()
                neg += (small_trace == 0).sum()

                current_pos_pix = (large_trace == 1).sum(dim=0) + (small_trace == 1).sum(dim=0)
                current_neg_pix = (large_trace == 0).sum(dim=0) + (small_trace == 0).sum(dim=0)
                if pos_pix is None:
                    pos_pix = current_pos_pix
                    neg_pix = current_neg_pix
                else:
                    pos_pix = pos_pix + current_pos_pix
                    neg_pix = neg_pix + current_neg_pix

                # Forward pass: EF prediction and the tensor the frames are picked from.
                outputs, seg_output = model(X_input)
                ef_pred = outputs.view(-1)
                ef_loss = torch.nn.functional.mse_loss(ef_pred, outcome.float())

                # Store detached copies: keeping the live tensors would hold every
                # batch's autograd graph in memory until the end of the epoch.
                yhat_list.append(ef_pred.detach())
                y_list.append(outcome.detach().reshape(-1))

                batch_size = X_input.size(0)
                indices = torch.arange(batch_size, device=device)

                # Large frame first, then small frame (same order as before, which
                # matters for layers that update running statistics).
                loss_large, inter_l, union_l = _segment_annotated_frames(
                    model, seg_output, indices, large_index, large_trace
                )
                large_inter += inter_l.sum()
                large_union += union_l.sum()
                large_inter_list.append(inter_l)
                large_union_list.append(union_l)

                loss_small, inter_s, union_s = _segment_annotated_frames(
                    model, seg_output, indices, small_index, small_trace
                )
                small_inter += inter_s.sum()
                small_union += union_s.sum()
                small_inter_list.append(inter_s)
                small_union_list.append(union_s)

                loss = (loss_large + loss_small) / 2 + ef_loss
                if train:
                    optim.zero_grad()
                    loss.backward()
                    optim.step()

                total_loss += loss.item()
                n += batch_size

                # Baselines shown in the progress bar: entropy of the global
                # foreground rate, and mean of the per-pixel entropies.
                p_val = (pos / (pos + neg + 1e-10)).item()
                global_entropy = (
                    -p_val * math.log(p_val + 1e-10)
                    - (1 - p_val) * math.log(1 - p_val + 1e-10)
                )
                # The +1 / +2 smoothing keeps p_pix strictly inside (0, 1), so the
                # logarithms are finite. The entropy is taken per pixel and only
                # then averaged (not the entropy of the averaged probability).
                p_pix = (pos_pix + 1) / (pos_pix + neg_pix + 2)
                pixel_entropy = (
                    -p_pix * torch.log(p_pix) - (1 - p_pix) * torch.log(1 - p_pix)
                ).mean().item()
                dice_large = 2 * large_inter / (large_union + large_inter + 1e-10)
                dice_small = 2 * small_inter / (small_union + small_inter + 1e-10)

                info_str = "{:.4f} (SEG:{:.4f}, EF:{:.4f}) / {:.4f} {:.4f}, {:.4f}, {:.4f}".format(
                    total_loss / n / (112 * 112),
                    loss.item() / batch_size / (112 * 112),
                    ef_loss.item() / batch_size,
                    global_entropy,
                    pixel_entropy,
                    dice_large.item(),
                    dice_small.item()
                )
                pbar.set_postfix_str(info_str)
                pbar.update()

    if n == 0:
        raise ValueError("seg_run_epoch: the dataloader produced no batches")

    yhat = yhat_list if save_all else torch.cat(yhat_list)
    y_out = torch.cat(y_list)
    large_inter_tensor = torch.cat(large_inter_list)
    large_union_tensor = torch.cat(large_union_list)
    small_inter_tensor = torch.cat(small_inter_list)
    small_union_tensor = torch.cat(small_union_list)

    return (
        total_loss / n / (112 * 112),
        large_inter_tensor,
        large_union_tensor,
        small_inter_tensor,
        small_union_tensor,
        ef_loss.detach(),
        yhat,
        y_out,
    )


In [127]:
import time

run_test = True

# A torch.device never compares equal to a plain string, so `device == "cuda"`
# is always False; test the device type instead.
use_pinned_memory = torch.device(device).type == "cuda"


def _to_numpy(values):
    """Return `values` (tensor or array-like) as a detached CPU NumPy array."""
    if torch.is_tensor(values):
        return values.detach().cpu().numpy()
    return np.asarray(values)


# NOTE(review): `sklearn`, `plt`, `echonet`, `bootstrap` and
# `dice_similarity_coefficient` are not defined in this cell and must already be
# in scope. The R2 computations below also need `seg_run_epoch` to return the
# ground-truth targets as its last value; r2_score raises if that is empty.
with open(os.path.join(output_seg, "log.csv"), "a") as f:
    epoch_resume = 0
    bestLoss = float("inf")
    try:
        # Resume from the last checkpoint if there is one.
        checkpoint = torch.load(os.path.join(output_seg, "checkpoint.pt"), map_location=device)
        model.load_state_dict(checkpoint['state_dict'])
        optim.load_state_dict(checkpoint['opt_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_dict'])
        epoch_resume = checkpoint["epoch"] + 1
        bestLoss = checkpoint["best_loss"]
        f.write("Resuming from epoch {}\n".format(epoch_resume))
    except FileNotFoundError:
        f.write("Starting run from scratch\n")

    for epoch in range(epoch_resume, num_epochs):
        print("Epoch #{}".format(epoch), flush=True)
        for phase in ['train', 'val']:
            start_time = time.time()
            for dev in range(torch.cuda.device_count()):
                torch.cuda.reset_peak_memory_stats(dev)

            ds = dataset[phase]
            dataloader = torch.utils.data.DataLoader(
                ds, batch_size=batch_size, num_workers=num_workers, shuffle=True,
                pin_memory=use_pinned_memory, drop_last=(phase == "train"))

            loss, large_inter, large_union, small_inter, small_union, ef_loss, yhat, y = seg_run_epoch(
                model, dataloader, phase == "train", optim, device)
            overall_dice = 2 * (large_inter.sum() + small_inter.sum()) / (
                large_union.sum() + large_inter.sum() + small_union.sum() + small_inter.sum())
            large_dice = 2 * large_inter.sum() / (large_union.sum() + large_inter.sum())
            small_dice = 2 * small_inter.sum() / (small_union.sum() + small_inter.sum())
            # sklearn works on CPU arrays, not on (possibly CUDA) tensors.
            r2 = sklearn.metrics.r2_score(_to_numpy(y), _to_numpy(yhat))

            # Joining the values avoids any placeholder/argument count mismatch.
            row = (
                epoch,
                phase,
                loss,
                overall_dice.item(),
                large_dice.item(),
                small_dice.item(),
                float(ef_loss),
                r2,
                time.time() - start_time,
                large_inter.numel(),
                sum(torch.cuda.max_memory_allocated(dev) for dev in range(torch.cuda.device_count())),
                sum(torch.cuda.max_memory_reserved(dev) for dev in range(torch.cuda.device_count())),
                batch_size,
            )
            f.write(",".join(map(str, row)) + "\n")
            f.flush()
        scheduler.step()

        # Save checkpoint; `loss` and `r2` are those of the validation phase,
        # which runs last.
        save = {
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'best_loss': bestLoss,
            'loss': loss,
            'r2': r2,
            'opt_dict': optim.state_dict(),
            'scheduler_dict': scheduler.state_dict(),
        }
        torch.save(save, os.path.join(output_seg, "checkpoint.pt"))
        if loss < bestLoss:
            torch.save(save, os.path.join(output_seg, "best.pt"))
            bestLoss = loss

    # Load best weights
    if num_epochs != 0:
        checkpoint = torch.load(os.path.join(output_seg, "best.pt"), map_location=device)
        model.load_state_dict(checkpoint['state_dict'])
        f.write("Best validation loss {} from epoch {}\n".format(checkpoint["loss"], checkpoint["epoch"]))

    if run_test:
        # Run on validation and test
        for split in ["val", "test"]:
            # Use a dedicated name so the global `dataset` dict is not overwritten.
            split_ds = Echo(root=data_dir, split=split, **kwargs)
            dataloader = torch.utils.data.DataLoader(
                split_ds, batch_size=batch_size, num_workers=num_workers,
                shuffle=False, pin_memory=use_pinned_memory)
            loss, large_inter, large_union, small_inter, small_union, ef_loss, yhat, y = seg_run_epoch(
                model, dataloader, False, None, device)

            # Everything below (NumPy, sklearn, matplotlib) needs CPU arrays.
            large_inter = _to_numpy(large_inter)
            large_union = _to_numpy(large_union)
            small_inter = _to_numpy(small_inter)
            small_union = _to_numpy(small_union)
            y = _to_numpy(y)
            yhat = _to_numpy(yhat)

            # Per-video dice scores.
            overall_dice = 2 * (large_inter + small_inter) / (large_union + large_inter + small_union + small_inter)
            large_dice = 2 * large_inter / (large_union + large_inter)
            small_dice = 2 * small_inter / (small_union + small_inter)
            with open(os.path.join(output_seg, "{}_dice.csv".format(split)), "w") as g:
                # The header matches the four columns actually written per row.
                g.write("Filename, Overall, Large, Small\n")
                for (filename, overall, large, small) in zip(split_ds.fnames, overall_dice, large_dice, small_dice):
                    g.write("{},{},{},{}\n".format(filename, overall, large, small))

            f.write("{} dice (overall): {:.4f} ({:.4f} - {:.4f})\n".format(split, *bootstrap(np.concatenate((large_inter, small_inter)), np.concatenate((large_union, small_union)), dice_similarity_coefficient)))
            f.write("{} dice (large):   {:.4f} ({:.4f} - {:.4f})\n".format(split, *bootstrap(large_inter, large_union, dice_similarity_coefficient)))
            f.write("{} dice (small):   {:.4f} ({:.4f} - {:.4f})\n".format(split, *bootstrap(small_inter, small_union, dice_similarity_coefficient)))
            f.write("{} (one clip) R2:   {:.3f} ({:.3f} - {:.3f})\n".format(split, *echonet.utils.bootstrap(y, yhat, sklearn.metrics.r2_score)))
            f.write("{} (one clip) MAE:  {:.2f} ({:.2f} - {:.2f})\n".format(split, *echonet.utils.bootstrap(y, yhat, sklearn.metrics.mean_absolute_error)))
            f.write("{} (one clip) RMSE: {:.2f} ({:.2f} - {:.2f})\n".format(split, *tuple(map(math.sqrt, echonet.utils.bootstrap(y, yhat, sklearn.metrics.mean_squared_error)))))

            f.flush()
            # Files go to `output_seg` like every other artefact of this cell
            # (`output` is not defined here), and the file names come from the
            # split that was just evaluated.
            with open(os.path.join(output_seg, "{}_predictions.csv".format(split)), "w") as g:
                for (filename, pred) in zip(split_ds.fnames, yhat):
                    # A single prediction per video is a scalar; treat it as one clip.
                    for (clip, p) in enumerate(np.atleast_1d(pred)):
                        g.write("{},{},{:.4f}\n".format(filename, clip, p))
            echonet.utils.latexify()
            # Average over clips (a no-op when there is one prediction per video).
            yhat = np.array([np.mean(pred) for pred in yhat])

            # Plot actual and predicted EF
            fig = plt.figure(figsize=(3, 3))
            lower = min(y.min(), yhat.min())
            upper = max(y.max(), yhat.max())
            plt.scatter(y, yhat, color="k", s=1, edgecolor=None, zorder=2)
            plt.plot([0, 100], [0, 100], linewidth=1, zorder=3)
            plt.axis([lower - 3, upper + 3, lower - 3, upper + 3])
            plt.gca().set_aspect("equal", "box")
            plt.xlabel("Actual EF (%)")
            plt.ylabel("Predicted EF (%)")
            plt.xticks([10, 20, 30, 40, 50, 60, 70, 80])
            plt.yticks([10, 20, 30, 40, 50, 60, 70, 80])
            plt.grid(color="gainsboro", linestyle="--", linewidth=1, zorder=1)
            plt.tight_layout()
            plt.savefig(os.path.join(output_seg, "{}_scatter.pdf".format(split)))
            plt.close(fig)

            # Plot AUROC
            fig = plt.figure(figsize=(3, 3))
            plt.plot([0, 1], [0, 1], linewidth=1, color="k", linestyle="--")
            for thresh in [35, 40, 45, 50]:
                fpr, tpr, _ = sklearn.metrics.roc_curve(y > thresh, yhat)
                print(thresh, sklearn.metrics.roc_auc_score(y > thresh, yhat))
                plt.plot(fpr, tpr)

            plt.axis([-0.01, 1.01, -0.01, 1.01])
            plt.xlabel("False Positive Rate")
            plt.ylabel("True Positive Rate")
            plt.tight_layout()
            plt.savefig(os.path.join(output_seg, "{}_roc.pdf".format(split)))
            plt.close(fig)


Epoch #0


  checkpoint = torch.load(os.path.join(output_seg, "checkpoint.pt"))
  0%|          | 0/620 [00:01<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 1.20 GiB. GPU 0 has a total capacity of 15.89 GiB of which 1.09 GiB is free. Process 2553 has 14.79 GiB memory in use. Of the allocated memory 14.38 GiB is allocated by PyTorch, and 116.33 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)