Below are links to resources related to this notebook:
* [W&B Project](https://wandb.ai/vincenttu/torch_vs_tf_talmo_lab?workspace=user-vincenttu)
* [GitHub](https://github.com/alckasoc/sleap_keypoint_tf_torch)



# Install SLEAP
Don't forget to set **Runtime** -> **Change runtime type...** -> **GPU** as the accelerator.

In [1]:
!pip install sleap -qqq
!pip install albumentations -qqq
!pip install nvidia-ml-py3 -qqq

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m64.4/64.4 MB[0m [31m21.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m60.6/60.6 kB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.5/4.5 MB[0m [31m29.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.7/1.7 MB[0m [31m61.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m60.5/60.5 MB[0m [31m9.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.9/84.9 kB[0m [31m5.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m83.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.9/43.9 MB[0m [31m34.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
import os
import gc
import random
import math
import time
import collections
from typing import Sequence, Tuple, Text, Union, Optional, List

import nvidia_smi

import numpy as np
import albumentations as A

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from torch.profiler import profile, record_function, ProfilerActivity

import sleap

sleap.versions()

SLEAP: 1.3.0
TensorFlow: 2.8.4
Numpy: 1.22.4
Python: 3.10.11
OS: Linux-5.15.107+-x86_64-with-glibc2.31


In [3]:
!pip install wandb -qqq
import wandb
wandb.login()

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m39.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m184.3/184.3 kB[0m [31m18.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m205.1/205.1 kB[0m [31m18.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.7/62.7 kB[0m [31m6.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for pathtools (setup.py) ... [?25l[?25hdone


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

# Utils

In [4]:
def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

def get_vram():
    nvidia_smi.nvmlInit()

    deviceCount = nvidia_smi.nvmlDeviceGetCount()
    for i in range(deviceCount):
        handle = nvidia_smi.nvmlDeviceGetHandleByIndex(i)
        info = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)
        output = ("Device {}: {}, Memory : ({:.2f}% free): {} (total), {} (free), {} (used)"
              .format(i, nvidia_smi.nvmlDeviceGetName(handle), 100*info.free/info.total, 
                      info.total/(1024 ** 3), info.free/(1024 ** 3), info.used/(1024 ** 3)))

    nvidia_smi.nvmlShutdown()

    return output

def get_param_count(model):
  trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
  total_params = sum(p.numel() for p in model.parameters())
  nontrainable_params = total_params - trainable_params

  return trainable_params, nontrainable_params, total_params

In [5]:
seed = 42
seed_everything(seed)

# Download training data

In [6]:
!curl -L --output labels.slp https://storage.googleapis.com/sleap-data/datasets/wt_gold.13pt/tracking_split2/train.pkg.slp
!ls -lah

!curl -L --output val_labels.slp https://storage.googleapis.com/sleap-data/datasets/wt_gold.13pt/tracking_split2/val.pkg.slp
!ls -lah

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  619M  100  619M    0     0  32.0M      0  0:00:19  0:00:19 --:--:-- 35.1M
total 620M
drwxr-xr-x 1 root root 4.0K May 19 09:01 .
drwxr-xr-x 1 root root 4.0K May 19 08:42 ..
drwxr-xr-x 4 root root 4.0K May 17 20:40 .config
-rw-r--r-- 1 root root 620M May 19 09:01 labels.slp
drwxr-xr-x 1 root root 4.0K May 17 20:41 sample_data
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 77.2M  100 77.2M    0     0  24.6M      0  0:00:03  0:00:03 --:--:-- 24.6M
total 697M
drwxr-xr-x 1 root root 4.0K May 19 09:01 .
drwxr-xr-x 1 root root 4.0K May 19 08:42 ..
drwxr-xr-x 4 root root 4.0K May 17 20:40 .config
-rw-r--r-- 1 root root 620M May 19 09:01 labels.slp
drwxr-xr-x 1 root root 4.0K May 17 20:41 sample_data
-rw-r--r-- 1 root root

# Load the training data

In [7]:
# SLEAP Labels files (.slp) can include the images as well as labeled instances and
# other metadata for a project.
labels = sleap.load_file("labels.slp")
labels = labels.with_user_labels_only()
labels.describe()

Skeleton: Skeleton(description=None, nodes=[head, thorax, abdomen, wingL, wingR, forelegL4, forelegR4, midlegL4, midlegR4, hindlegL4, hindlegR4, eyeL, eyeR], edges=[thorax->head, thorax->abdomen, thorax->wingL, thorax->wingR, thorax->forelegL4, thorax->forelegR4, thorax->midlegL4, thorax->midlegR4, thorax->hindlegL4, thorax->hindlegR4, head->eyeL, head->eyeR], symmetries=[midlegL4<->midlegR4, eyeL<->eyeR, hindlegL4<->hindlegR4, forelegL4<->forelegR4, wingL<->wingR])
Videos: ['labels.slp', 'labels.slp', 'labels.slp', 'labels.slp', 'labels.slp', 'labels.slp', 'labels.slp', 'labels.slp', 'labels.slp', 'labels.slp', 'labels.slp', 'labels.slp', 'labels.slp', 'labels.slp', 'labels.slp', 'labels.slp', 'labels.slp', 'labels.slp', 'labels.slp', 'labels.slp', 'labels.slp', 'labels.slp', 'labels.slp', 'labels.slp', 'labels.slp', 'labels.slp', 'labels.slp', 'labels.slp', 'labels.slp', 'labels.slp']
Frames (user/predicted): 1,600/0
Instances (user/predicted): 3,200/0
Tracks: [Track(spawned_on=0, na

In [8]:
# Let's also do the same for the val labels.
val_labels = sleap.load_file("val_labels.slp")
val_labels = val_labels.with_user_labels_only()
val_labels.describe()

Skeleton: Skeleton(description=None, nodes=[head, thorax, abdomen, wingL, wingR, forelegL4, forelegR4, midlegL4, midlegR4, hindlegL4, hindlegR4, eyeL, eyeR], edges=[thorax->head, thorax->abdomen, thorax->wingL, thorax->wingR, thorax->forelegL4, thorax->forelegR4, thorax->midlegL4, thorax->midlegR4, thorax->hindlegL4, thorax->hindlegR4, head->eyeL, head->eyeR], symmetries=[wingL<->wingR, forelegL4<->forelegR4, eyeL<->eyeR, hindlegL4<->hindlegR4, midlegL4<->midlegR4])
Videos: ['val_labels.slp', 'val_labels.slp', 'val_labels.slp', 'val_labels.slp', 'val_labels.slp', 'val_labels.slp', 'val_labels.slp', 'val_labels.slp', 'val_labels.slp', 'val_labels.slp', 'val_labels.slp', 'val_labels.slp', 'val_labels.slp', 'val_labels.slp', 'val_labels.slp', 'val_labels.slp', 'val_labels.slp', 'val_labels.slp', 'val_labels.slp', 'val_labels.slp', 'val_labels.slp', 'val_labels.slp', 'val_labels.slp', 'val_labels.slp', 'val_labels.slp', 'val_labels.slp', 'val_labels.slp', 'val_labels.slp', 'val_labels.slp'

In [9]:
# Labels are list-like containers whose elements are LabeledFrames
print(f"Number of labels: {len(labels)}")

labeled_frame = labels[0]
labeled_frame

Number of labels: 1600


LabeledFrame(video=HDF5Video('labels.slp'), frame_idx=166050, instances=2)

In [10]:
# LabeledFrames are containers for instances that were labeled in a single frame
instance = labeled_frame[0]
instance

Instance(video=Video(filename=labels.slp, shape=(66, 1024, 1024, 1), backend=HDF5Video), frame_idx=166050, points=[head: (491.6, 187.7), thorax: (474.4, 224.8), abdomen: (459.9, 262.2), wingL: (448.3, 271.7), wingR: (452.1, 273.5), forelegL4: (478.5, 175.9), forelegR4: (499.9, 177.9), midlegL4: (440.6, 216.4), midlegR4: (510.1, 242.7), hindlegL4: (437.2, 234.3), hindlegR4: (490.9, 266.7), eyeL: (477.5, 193.2), eyeR: (498.4, 201.2)], track=Track(spawned_on=0, name='female'))

In [11]:
# They can be converted to numpy arrays where each row corresponds to the coordinates
# of a different body part:
pts = instance.numpy()
pts

rec.array([[491.58118169, 187.72078779],
           [474.3603939 , 224.80196948],
           [459.90098474, 262.16236338],
           [448.26137864, 271.72078779],
           [452.08118169, 273.54059084],
           [478.5       , 175.90098474],
           [499.94157558, 177.90098474],
           [440.58118169, 216.3603939 ],
           [510.12177253, 242.72078779],
           [         nan,          nan],
           [490.90098474, 266.72078779],
           [477.54059084, 193.16236338],
           [498.40098474, 201.18019695]],
          dtype=float64)

# Setup training data generation

In [12]:
def update_not_shown_nodes(not_shown_nodes, node_names, new_nodes):
  nodes_not_in_aug = np.array(list(set(node_names).difference(set(new_nodes))))
  not_shown_in_aug_or_original_ind = np.in1d(node_names, nodes_not_in_aug).nonzero()[0]
  not_shown_nodes[not_shown_in_aug_or_original_ind] = True 

  return not_shown_nodes


def update_kp(kp, not_shown_nodes, node_names, new_nodes):
  shown_after_aug_ind = np.in1d(node_names, new_nodes).nonzero()[0]
  shown_ind = np.in1d(not_shown_nodes, True).nonzero()[0]

  assert len(shown_after_aug_ind) == kp.shape[0]

  kp_ = np.zeros((len(node_names), 2))
  kp_[shown_after_aug_ind] = kp
  kp_[shown_ind] = 0

  return kp_


def make_grid_vectors(
    image_height: int, image_width: int, output_stride: int = 1):

    xv = torch.arange(0, image_width, step=output_stride).to(torch.float32)
    yv = torch.arange(0, image_height, step=output_stride).to(torch.float32)
    return xv, yv

def make_confmaps(
    points: torch.Tensor, xv: torch.Tensor, yv: torch.Tensor, sigma: float):

    x = torch.reshape(points[:, 0], (1, 1, -1))
    y = torch.reshape(points[:, 1], (1, 1, -1))
    cm = torch.exp(
        -((torch.reshape(xv, (1, -1, 1)) - x) ** 2 + (torch.reshape(yv, (-1, 1, 1)) - y) ** 2)
        / (2 * sigma ** 2)
    )

    # Replace NaNs with 0.
    cm = torch.where(torch.isnan(cm), 0.0, cm)
    return cm

def get_bbox_coords_on_centroid(anchor_coords, crop_size, img_size):
    (cx, cy) = anchor_coords

    # [bottom left     top right]
    # [  x1, y1,         x2, y2 ]
    bbox = [
        max(-crop_size / 2 + cx, 0),
        max(-crop_size / 2 + cy, 0),
        min(crop_size / 2 + cx, img_size[0]),
        min(crop_size / 2 + cy, img_size[1])
    ]

    return bbox

# My refactored version of this dataset generator.
class DataGenerator(Dataset):
    def __init__(self, 
      labels, 
      img_size=160,
      anchor_name="thorax",
      sigma=1.5,
      output_stride=2,
      rot_range=(-180, 180),
      is_train=True
    ):
        self.labels = labels.with_user_labels_only()
        self.labels.remove_empty_instances(keep_empty_frames=False)

        self.indices = []
        for frame_idx, l in enumerate(self.labels):
          inst_indices = np.arange(0, len(l.instances)).tolist()
          self.indices.extend([(frame_idx, i) for i in inst_indices])

        self.img_size = img_size

        assert anchor_name in self.labels.skeleton.node_names
        self.anchor_name = anchor_name

        # Assuming 1 skeleton.
        assert len(labels.skeletons) == 1
        self.node_names = labels.skeletons[0].node_names

        self.sigma = sigma
        self.output_stride = output_stride
        self.rot_range = rot_range

        self.tfm = A.Compose([
            A.Rotate(limit=list(self.rot_range), p=0.5)
        ], keypoint_params=A.KeypointParams(format='xy', label_fields=['class_labels']))

        self.xv, self.yv = make_grid_vectors(
            image_height=self.img_size, 
            image_width=self.img_size, 
            output_stride=self.output_stride
        )

        self.is_train = is_train

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        frame_idx, inst_idx = self.indices[idx]
        lf = self.labels[frame_idx]
        instance = lf[inst_idx]
        img = lf.image
        kp = instance.numpy()

        # NaNs to 0 and clip.
        assert kp.shape == (len(self.node_names), 2)
        not_shown_nodes = np.isnan(kp).any(axis=1)
        kp = np.nan_to_num(kp, nan=0)
        kp = np.concatenate((np.clip(kp[:, :1], 0, img.shape[1]), 
                              np.clip(kp[:, 1:], 0, img.shape[0])), 
                            axis=1)

        if self.is_train:
            # Apply augmentations.
            output = self.tfm(image=img, keypoints=kp, class_labels=self.node_names)
            img, kp, new_nodes = output["image"], np.array(output["keypoints"]), output["class_labels"]

            # Update not_shown_nodes and kp.
            not_shown_nodes = update_not_shown_nodes(not_shown_nodes, self.node_names, new_nodes)
            kp = update_kp(kp, not_shown_nodes, self.node_names, new_nodes)

        # Get bbox coordinate based on centroid.
        bbox = get_bbox_coords_on_centroid(
          kp[self.node_names.index(self.anchor_name)].tolist(), 
          self.img_size, img.shape[:2]
        )

        # Crop and pad.
        x1, y1, x2, y2 = bbox
        tfm_crop = A.Compose([
          A.Crop(int(round(x1)), int(round(y1)), int(round(x2)), int(round(y2))),
          A.PadIfNeeded(min_height=self.img_size, min_width=self.img_size)
        ], keypoint_params=A.KeypointParams(format='xy', label_fields=['class_labels']))

        crop_kp = tfm_crop(image=img, keypoints=kp, class_labels=self.node_names)
        crop, kp, new_nodes = crop_kp["image"], np.array(crop_kp["keypoints"]), crop_kp["class_labels"]
        crop = torch.Tensor(crop).permute(2, 0, 1)

        # Update not_shown_nodes and kp.
        not_shown_nodes = update_not_shown_nodes(not_shown_nodes, self.node_names, new_nodes)
        kp = update_kp(kp, not_shown_nodes, self.node_names, new_nodes)
        kp = torch.Tensor(kp)

        # Get confidence map.
        xv, yv = make_grid_vectors(
          image_height=self.img_size, 
          image_width=self.img_size, 
          output_stride=self.output_stride
        )

        cm = make_confmaps(
          points=kp, 
          xv=self.xv, 
          yv=self.yv, 
          sigma=self.sigma
        )
        cm = cm.permute(2, 0, 1)

        return crop, cm

In [13]:
# My refactored version of this dataset generator.
class DataGenerator(Dataset):
    def __init__(self, 
      labels, 
      img_size=160,
      anchor_name="thorax",
      sigma=1.5,
      output_stride=2,
      rot_range=(-180, 180),
      is_train=True
    ):
        self.labels = labels.with_user_labels_only()
        self.labels.remove_empty_instances(keep_empty_frames=False)

        self.indices = []
        for frame_idx, l in enumerate(self.labels):
          inst_indices = np.arange(0, len(l.instances)).tolist()
          self.indices.extend([(frame_idx, i) for i in inst_indices])

        self.img_size = img_size

        assert anchor_name in self.labels.skeleton.node_names
        self.anchor_name = anchor_name

        # Assuming 1 skeleton.
        assert len(labels.skeletons) == 1
        self.node_names = labels.skeletons[0].node_names

        self.sigma = sigma
        self.output_stride = output_stride
        self.rot_range = rot_range

        self.tfm = A.Compose([
            A.Rotate(limit=list(self.rot_range), p=0.5)
        ], keypoint_params=A.KeypointParams(format='xy', label_fields=['class_labels']))

        self.xv, self.yv = make_grid_vectors(
            image_height=self.img_size, 
            image_width=self.img_size, 
            output_stride=self.output_stride
        )

        self.is_train = is_train

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        frame_idx, inst_idx = self.indices[idx]
        lf = self.labels[frame_idx]
        instance = lf[inst_idx]
        img = lf.image
        kp = instance.numpy()

        # NaNs to 0 and clip.
        assert kp.shape == (len(self.node_names), 2)
        not_shown_nodes = np.isnan(kp).any(axis=1)
        kp = np.nan_to_num(kp, nan=0)
        kp = np.concatenate((np.clip(kp[:, :1], 0, img.shape[1]), 
                              np.clip(kp[:, 1:], 0, img.shape[0])), 
                            axis=1)

        if self.is_train:
            # Apply augmentations.
            output = self.tfm(image=img, keypoints=kp, class_labels=self.node_names)
            img, kp, new_nodes = output["image"], np.array(output["keypoints"]), output["class_labels"]

            # Update not_shown_nodes and kp.
            not_shown_nodes = update_not_shown_nodes(not_shown_nodes, self.node_names, new_nodes)
            kp = update_kp(kp, not_shown_nodes, self.node_names, new_nodes)

        # Get bbox coordinate based on centroid.
        bbox = get_bbox_coords_on_centroid(
          kp[self.node_names.index(self.anchor_name)].tolist(), 
          self.img_size, img.shape[:2]
        )

        # Crop and pad.
        x1, y1, x2, y2 = bbox
        tfm_crop = A.Compose([
          A.Crop(int(round(x1)), int(round(y1)), int(round(x2)), int(round(y2))),
          A.PadIfNeeded(min_height=self.img_size, min_width=self.img_size)
        ], keypoint_params=A.KeypointParams(format='xy', label_fields=['class_labels']))

        crop_kp = tfm_crop(image=img, keypoints=kp, class_labels=self.node_names)
        crop, kp, new_nodes = crop_kp["image"], np.array(crop_kp["keypoints"]), crop_kp["class_labels"]
        crop = torch.Tensor(crop).permute(2, 0, 1)

        # Update not_shown_nodes and kp.
        not_shown_nodes = update_not_shown_nodes(not_shown_nodes, self.node_names, new_nodes)
        kp = update_kp(kp, not_shown_nodes, self.node_names, new_nodes)
        kp = torch.Tensor(kp)

        # Get confidence map.
        xv, yv = make_grid_vectors(
          image_height=self.img_size, 
          image_width=self.img_size, 
          output_stride=self.output_stride
        )

        cm = make_confmaps(
          points=kp, 
          xv=self.xv, 
          yv=self.yv, 
          sigma=self.sigma
        )
        cm = cm.permute(2, 0, 1)

        return crop, cm

# Setting up a neural network model

In [14]:
class MaxPool2dWithSamePadding(nn.MaxPool2d):

    def _calc_same_pad(self, i: int, k: int, s: int, d: int) -> int:
        return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.padding == "same":
            ih, iw = x.size()[-2:]

            pad_h = self._calc_same_pad(i=ih, 
                                        k=self.kernel_size if type(self.kernel_size) is int else self.kernel_size[0], 
                                        s=self.stride if type(self.stride) is int else self.stride[0], 
                                        d=self.dilation if type(self.dilation) is int else self.dilation[0])
            pad_w = self._calc_same_pad(i=iw, 
                                        k=self.kernel_size if type(self.kernel_size) is int else self.kernel_size[1], 
                                        s=self.stride if type(self.stride) is int else self.stride[1], 
                                        d=self.dilation if type(self.dilation) is int else self.dilation[1])

            if pad_h > 0 or pad_w > 0:
                x = F.pad(
                    x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]
                )
            self.padding = 0

        return F.max_pool2d(x, self.kernel_size, self.stride,
                                    self.padding, self.dilation, ceil_mode=self.ceil_mode,
                                    return_indices=self.return_indices)

def get_act_fn(activation: str) -> nn.Module:
    activations = {
        'relu': nn.ReLU(),
        'sigmoid': nn.Sigmoid(),
        'tanh': nn.Tanh()
    }

    return activations[activation]

class SimpleConvBlock(nn.Module):
    def __init__(self,
        in_channels: int,
        pool: bool = True,
        pooling_stride: int = 2,
        pool_before_convs: bool = False,
        num_convs: int = 2,
        filters: int = 32,
        kernel_size: int = 3,
        use_bias: bool = True,
        batch_norm: bool = False,
        activation: Text = "relu"
    ) -> None:
        super().__init__()

        self.in_channels = in_channels
        self.pool = pool
        self.pooling_stride = pooling_stride
        self.pool_before_convs = pool_before_convs
        self.num_convs = num_convs
        self.filters = filters
        self.kernel_size = kernel_size
        self.use_bias = use_bias
        self.batch_norm = batch_norm
        self.activation = activation

        self.blocks = []
        if pool and pool_before_convs:
            self.blocks.append(
                MaxPool2dWithSamePadding(
                    kernel_size=2,
                    stride=pooling_stride,
                    padding="same"
                )
            )

        for i in range(num_convs):
            self.blocks.append(
                nn.Conv2d(
                    in_channels=in_channels if i == 0 else filters,
                    out_channels=filters,
                    kernel_size=kernel_size,
                    stride=1,
                    padding="same",
                    bias=use_bias
                )
            )

            if batch_norm:
                self.blocks.append(
                    nn.BatchNorm2d(filters)
                )

            self.blocks.append(
                get_act_fn(activation)  
            )


        if pool and not pool_before_convs:
            self.blocks.append(
                MaxPool2dWithSamePadding(
                    kernel_size=2,
                    stride=pooling_stride,
                    padding="same"
                )
            )

        self.blocks = nn.Sequential(*self.blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.blocks(x)

class Encoder(nn.Module):
    def __init__(self,
        in_channels: int = 3,
        filters: int = 64,
        down_blocks: int = 4, 
        filters_rate: Union[float, int] = 2,
        current_stride: int = 2,
        stem_blocks: int = 0,
        convs_per_block: int = 2,
        kernel_size: Union[int, Tuple[int, int]] = 3,
        middle_block: bool = True,
        block_contraction: bool = False
    ) -> None:
        super().__init__()

        self.in_channels = in_channels
        self.filters = filters
        self.down_blocks = down_blocks
        self.filters_rate = filters_rate
        self.current_stride = current_stride
        self.stem_blocks = stem_blocks
        self.convs_per_block = convs_per_block
        self.kernel_size = kernel_size
        self.middle_block = middle_block
        self.block_contraction = block_contraction

        self.encoder_stack = nn.ModuleList([])
        for block in range(down_blocks):
            prev_block_filters = -1 if block==0 else block_filters
            block_filters = int(
                filters * (filters_rate ** (block + stem_blocks))
            )

            self.encoder_stack.append(
                SimpleConvBlock(
                    in_channels=in_channels if block == 0 else prev_block_filters,
                    pool=(block > 0),
                    pool_before_convs=True,
                    pooling_stride=2,
                    num_convs=convs_per_block,
                    filters=block_filters,
                    kernel_size=kernel_size,
                    use_bias=True,
                    batch_norm=False,
                    activation="relu"
                )
            )
        after_block_filters = block_filters

        self.encoder_stack.append(
            MaxPool2dWithSamePadding(
                kernel_size=2,
                stride=2,
                padding="same"
            )
        )

        # Create a middle block (like the CARE implementation).
        if middle_block:
            if convs_per_block > 1:
                # First convs are one exponent higher than the last encoder block.
                block_filters = int(
                    filters * (filters_rate ** (down_blocks + stem_blocks))
                )
                self.encoder_stack.append(
                    SimpleConvBlock(
                        in_channels=after_block_filters,
                        pool=False,
                        pool_before_convs=False,
                        pooling_stride=2,
                        num_convs=convs_per_block - 1,
                        filters=block_filters,
                        kernel_size=kernel_size,
                        use_bias=True,
                        batch_norm=False,
                        activation="relu",
                    )
                )

            if block_contraction:
                # Contract the channels with an exponent lower than the last encoder block.
                block_filters = int(
                    filters * (filters_rate ** (down_blocks + stem_blocks - 1))
                )
            else:
                # Keep the block output filters the same.
                block_filters = int(
                    filters * (filters_rate ** (down_blocks + stem_blocks))
                )

            self.encoder_stack.append(
                SimpleConvBlock(
                    in_channels=block_filters,
                    pool=False,
                    pool_before_convs=False,
                    pooling_stride=2,
                    num_convs=1,
                    filters=block_filters,
                    kernel_size=kernel_size,
                    use_bias=True,
                    batch_norm=False,
                    activation="relu",
                )
            )

        self.intermediate_features = {}
        for i, block in enumerate(self.encoder_stack):
            if isinstance(block, SimpleConvBlock) and block.pool:
                current_stride *= block.pooling_stride

            if current_stride not in self.intermediate_features.values():
                self.intermediate_features[i] = current_stride

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = []
        for i in range(len(self.encoder_stack)):
            x = self.encoder_stack[i](x)
            
            if i in self.intermediate_features.keys():
                features.append(x)

        return x, features[1:][::-1]

class SimpleUpsamplingBlock(nn.Module):
    def __init__(self,
        x_in_shape: int, 
        current_stride: int,
        upsampling_stride: int = 2,
        interp_method: Text = "bilinear",
        refine_convs: int = 2,
        refine_convs_filters: int = 64,
        refine_convs_kernel_size: int = 3,
        refine_convs_use_bias: bool = True,
        refine_convs_batch_norm: bool = True,
        refine_convs_batch_norm_before_activation: bool = True,
        refine_convs_activation: Text = "relu"
    ) -> None:
        super().__init__()

        self.x_in_shape = x_in_shape
        self.current_stride = current_stride
        self.upsampling_stride = upsampling_stride
        self.interp_method = interp_method
        self.refine_convs = refine_convs
        self.refine_convs_filters = refine_convs_filters
        self.refine_convs_kernel_size = refine_convs_kernel_size
        self.refine_convs_use_bias = refine_convs_use_bias
        self.refine_convs_batch_norm = refine_convs_batch_norm
        self.refine_convs_batch_norm_before_activation = refine_convs_batch_norm_before_activation
        self.refine_convs_activation = refine_convs_activation

        self.blocks = nn.ModuleList([])
        if current_stride is not None:
            # Append the strides to the block prefix.
            new_stride = current_stride // upsampling_stride

        # Upsample via interpolation.
        self.blocks.append(
            nn.Upsample(
                scale_factor=upsampling_stride,
                mode=interp_method,
            )
        )


        # Add further convolutions to refine after upsampling and/or skip.
        for i in range(refine_convs):
            filters = refine_convs_filters
            self.blocks.append(
                nn.Conv2d(
                    in_channels=x_in_shape if i==0 else filters,
                    out_channels=filters,
                    kernel_size=refine_convs_kernel_size,
                    stride=1,
                    padding="same",
                    bias=refine_convs_use_bias
                )
            )

            if (
                refine_convs_batch_norm
                and refine_convs_batch_norm_before_activation
            ):
                self.blocks.append(nn.BatchNorm2d(num_features=filters))


            self.blocks.append(
                get_act_fn(refine_convs_activation)  
            )

            if (
                refine_convs_batch_norm
                and not refine_convs_batch_norm_before_activation
            ):
                self.blocks.append(nn.BatchNorm2d(num_features=filters))


    def forward(self, x: torch.Tensor, feature: torch.Tensor) -> torch.Tensor:
        for idx, b in enumerate(self.blocks):
            if idx == 1:  # Right after upsampling or convtranspose2d.
                x = torch.concat((x, feature), dim=1)
            x = b(x)

        return x

class Decoder(nn.Module):
    def __init__(self,
        x_in_shape: int,
        current_stride: int,
        filters: int = 64,
        up_blocks: int = 4,
        down_blocks: int = 3,
        filters_rate: int = 2,
        stem_blocks: int = 0,
        convs_per_block: int = 2,
        kernel_size: int = 3,
        block_contraction: bool = False
    ) -> None:
        super().__init__()

        self.x_in_shape = x_in_shape
        self.current_stride = current_stride
        self.filters = filters
        self.up_blocks = up_blocks
        self.down_blocks = down_blocks
        self.filters_rate = filters_rate
        self.stem_blocks = stem_blocks
        self.convs_per_block = convs_per_block
        self.kernel_size = kernel_size
        self.block_contraction = block_contraction

        self.decoder_stack = nn.ModuleList([])
        for block in range(up_blocks):
            prev_block_filters_in = -1 if block == 0 else block_filters_in
            block_filters_in = int(
                filters
                * (
                    filters_rate
                    ** (down_blocks + stem_blocks - 1 - block)
                )
            )
            if block_contraction:
                block_filters_out = int(
                    filters
                    * (
                        filters_rate
                        ** (down_blocks + stem_blocks - 2 - block)
                    )
                )
            else:
                block_filters_out = block_filters_in

            next_stride = current_stride // 2

            self.decoder_stack.append(
                SimpleUpsamplingBlock(
                    x_in_shape=(x_in_shape + block_filters_in) if block == 0 else (prev_block_filters_in + block_filters_in), 
                    current_stride=current_stride,
                    upsampling_stride=2,
                    interp_method="bilinear",
                    refine_convs=self.convs_per_block,
                    refine_convs_filters=block_filters_out,
                    refine_convs_kernel_size=self.kernel_size,
                    refine_convs_batch_norm=False,
                )
            )

            current_stride = next_stride

    def forward(self, x: torch.Tensor, features: List[torch.Tensor]) -> torch.Tensor:
        for i in range(len(self.decoder_stack)):
            x = self.decoder_stack[i](x, features[i])

        return x

class Unet(nn.Module):
    def __init__(self,
        in_channels: int = 1,
        kernel_size: int = 3,
        filters: int = 32,
        filters_rate: int = 1.5,
        stem_blocks: int = 0,
        down_blocks: int = 4, 
        up_blocks: int = 3,
        convs_per_block: int = 2,
        middle_block: bool = True,
        block_contraction: bool = False
    ) -> None:
        super().__init__()

        self.enc = Encoder(
            in_channels=in_channels,
            filters=filters,
            down_blocks=down_blocks,
            filters_rate=filters_rate,
            stem_blocks=stem_blocks,
            convs_per_block=convs_per_block,
            kernel_size=kernel_size,
            middle_block=middle_block,
            block_contraction=block_contraction
        )

        current_stride = int(
            np.prod(
                [block.pooling_stride for block in self.enc.encoder_stack if hasattr(block, "pool") and block.pool]
                + [1]
            )
        )
        
        x_in_shape = int(
            filters * (filters_rate ** (down_blocks + stem_blocks))
        )

        self.dec = Decoder(x_in_shape=x_in_shape, current_stride=current_stride, filters=filters,
            up_blocks=up_blocks, down_blocks=down_blocks, filters_rate=filters_rate
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, features = self.enc(x)
        x = self.dec(x, features)
        return x


# Train the model

In [15]:
import multiprocessing

cores = multiprocessing.cpu_count()
cores

2

## Vanilla PyTorch with Mixed Precision

In [None]:
!pip freeze > requirements.txt

filters = 32
filters_rate = 1.5
down_blocks = 4
stem_blocks = 0
up_blocks = 3

# Some of this code is redundant. I kept the previous cells un-deleted just for reference.
for i in range(5): 
    train_ds = DataGenerator(labels)
    train_dl = DataLoader(
        train_ds,
        batch_size=4,
        shuffle=True,
        num_workers=cores,
        pin_memory=True,
        drop_last=True,
        prefetch_factor=2
    )

    val_ds = DataGenerator(val_labels, is_train=False)
    val_dl = DataLoader(
        val_ds,
        batch_size=4,
        shuffle=False,
        num_workers=cores,
        pin_memory=True,
        drop_last=True,
        prefetch_factor=2
    )

    unet = Unet(filters=filters, 
                filters_rate=filters_rate, 
                down_blocks=down_blocks, 
                stem_blocks=stem_blocks, 
                up_blocks=up_blocks)

    in_channels = int(
        filters
        * (
            filters_rate
            ** (down_blocks + stem_blocks - 1 - up_blocks + 1)
        )
    )
    model = nn.Sequential(*[
        unet, 
        nn.Conv2d(in_channels=in_channels, out_channels=13, kernel_size=1, padding="same")    
    ])

    opt = torch.optim.Adam(model.parameters(), lr=1e-4)
    scaler = torch.cuda.amp.GradScaler()
    model = model.to("cuda")

    run = wandb.init(
        project="torch_vs_tf_talmo_lab", 
        name=f"torch_baseline_run{i}", 
        config={
            "is_tf": False,
            "device_and_memory": get_vram(),
            "seed": seed,
            "model_param_count": get_param_count(model)
        }, 
        tags=["baseline"],
        notes="This experiment was done in a Google Colab Notebook."
    )

    # Log dependencies.
    artifact = wandb.Artifact("Dependencies", type="dependencies")
    artifact.add_file("requirements.txt", name=f"requirements.txt")
    run.log_artifact(artifact)

    for epoch in range(3):
        _ = model.train()
        start_time = time.time()
        train_loss = 0
        for idx, batch in enumerate(train_dl):
            X, y = batch
            X = X.to("cuda")
            y = y.to("cuda")
            
            with torch.autocast("cuda"):
                y_preds = model(X)
                loss = nn.MSELoss()(y_preds, y)
            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()
            opt.zero_grad()
                    
            if idx % 100 == 0:
              print(f"Epoch: {epoch} | Loss: {loss:.5f}")

            train_loss += loss

        train_loss /= (idx+1)
        train_time = time.time() - start_time
        print(f"TRAIN: --- {train_time}s seconds ---")

        _  = model.eval()
        start_time = time.time()
        val_loss = 0
        for idx, batch in enumerate(val_dl):
            X, y = batch
            X = X.to("cuda")
            y = y.to("cuda")

            with torch.no_grad():
                y_preds = model(X)
                loss = nn.MSELoss()(y_preds, y)

            val_loss += loss

        val_loss /= (idx+1)
        val_time = time.time() - start_time
        print(f"VAL: --- {val_time}s seconds ---")

        run.log({
            "train_loss": train_loss,
            "val_loss": val_loss,
            "train_time": train_time,
            "val_time": val_time,
            "total_time": train_time + val_time
        })

    del model, opt, scaler, train_ds, train_dl, val_ds, val_dl
    gc.collect()
    run.finish()

[34m[1mwandb[0m: Currently logged in as: [33mvincenttu[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch: 0 | Loss: 0.01214
Epoch: 0 | Loss: 0.00036
Epoch: 0 | Loss: 0.00029
Epoch: 0 | Loss: 0.00028
Epoch: 0 | Loss: 0.00028
Epoch: 0 | Loss: 0.00027
Epoch: 0 | Loss: 0.00028
Epoch: 0 | Loss: 0.00027
TRAIN: --- 70.5167441368103s seconds ---
VAL: --- 5.927274465560913s seconds ---
Epoch: 1 | Loss: 0.00027
Epoch: 1 | Loss: 0.00027
Epoch: 1 | Loss: 0.00027
Epoch: 1 | Loss: 0.00027


KeyboardInterrupt: ignored

## PyTorch Lightning

In [None]:
!pip install lightning -qqq

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m25.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m66.4/66.4 kB[0m [31m6.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m69.7/69.7 kB[0m [31m9.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m55.5/55.5 kB[0m [31m6.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m557.1/557.1 kB[0m [31m35.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m238.7/238.7 kB[0m [31m28.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m66.9/66.9 kB[0m [31m9.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m519.2/519.2 kB[0m [31m50.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
import lightning.pytorch as pl
from lightning.pytorch.callbacks import TQDMProgressBar
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import Callback

pl.seed_everything(seed, workers=True)

INFO: Global seed set to 42
INFO:lightning.fabric.utilities.seed:Global seed set to 42


42

In [None]:
class FullUNet(pl.LightningModule):
  def __init__(self,
    n_nodes = 13,
    filters = 32,
    filters_rate = 1.5,
    down_blocks = 4,
    stem_blocks = 0,
    up_blocks = 3,
    lr = 1e-4
  ):
    super().__init__()
    
    self.train_losses = []
    self.val_losses = []

    self.n_nodes = n_nodes
    self.filters = filters
    self.filters_rate = filters_rate
    self.down_blocks = down_blocks
    self.stem_blocks = stem_blocks
    self.up_blocks = up_blocks
    self.lr = lr
    
    self.unet = Unet(
        filters=filters, 
        filters_rate=filters_rate, 
        down_blocks=down_blocks, 
        stem_blocks=stem_blocks, 
        up_blocks=up_blocks
    )

    in_channels = int(
      filters
      * (
          filters_rate
          ** (down_blocks + stem_blocks - 1 - up_blocks + 1)
      )
    )

    self.last_conv = nn.Conv2d(in_channels=in_channels, out_channels=n_nodes, kernel_size=1, padding="same") 

  def forward(self, x):
    x = self.unet(x)
    return self.last_conv(x)

  def training_step(self, batch, batch_idx):
    X, y = batch
    y_preds = self.forward(X)
    loss = nn.MSELoss()(y_preds, y)
    if batch_idx % 100 == 0:
        print(f"Loss: {loss:.5f}")
    self.train_losses.append(loss.item())
    return loss

  def validation_step(self, batch, batch_idx):
    X, y = batch
    y_preds = self.forward(X)
    loss = nn.MSELoss()(y_preds, y)
    self.val_losses.append(loss.item())
    return loss

  def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
    return optimizer

In [None]:
class LogLossAndTimeCallback(Callback):
  def __init__(self):
    super().__init__()
    self.state = {
        "train_start_time": 0,
        "train_time": 0,
        "val_start_time": 0,
        "val_time": 0,
    }

  def load_state_dict(self, state_dict):
    self.state.update(state_dict)

  def state_dict(self):
    return self.state.copy()

  def on_train_epoch_start(self, trainer, pl_module):
    self.state["train_start_time"] = time.time()

  def on_validation_epoch_start(self, trainer, pl_module):
    pl_module.log("train_loss", np.mean(pl_module.train_losses))
    pl_module.train_losses = []
    self.state["train_time"] = time.time() - self.state["train_start_time"]
    pl_module.log("train_time", self.state["train_time"])

    self.state["val_start_time"] = time.time()

  def on_validation_epoch_end(self, trainer, pl_module):
    pl_module.log("val_loss", np.mean(pl_module.val_losses))
    pl_module.val_losses = []
    self.state["val_time"] = time.time() - self.state["val_start_time"]
    pl_module.log("val_time", self.state["val_time"])
    pl_module.log("total_time", self.state["train_time"] + self.state["val_time"])

  def on_init_start(self, trainer, pl_module):
    artifact = wandb.Artifact("Dependencies", type="dependencies")
    artifact.add_file("requirements.txt", name=f"requirements.txt")
    pl_module.logger.experiment.log_artifact(artifact)

In [None]:
!pip freeze > requirements.txt

n_nodes = len(labels.skeleton)
filters = 32
filters_rate = 1.5
down_blocks = 4
stem_blocks = 0
up_blocks = 3

batch_size = 4

# Some of this code is redundant. I kept the previous cells un-deleted just for reference.
for i in range(5): 
    train_ds = DataGenerator(labels)
    train_dl = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=cores,
        pin_memory=True,
        drop_last=True,
        prefetch_factor=2
    )

    val_ds = DataGenerator(val_labels, is_train=False)
    val_dl = DataLoader(
        val_ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=cores,
        pin_memory=True,
        drop_last=True,
        prefetch_factor=2
    )

    model = FullUNet(n_nodes)

    wandb_logger = WandbLogger(
        name=f"torch_lightning_run{i}", 
        project="torch_vs_tf_talmo_lab",
        config={
            "is_lightning": "true",
            "is_tf": False,
            "device_and_memory": get_vram(),
            "seed": seed,
            "model_param_count": get_param_count(model),
        }, 
        tags=["lightning"],
        notes="This experiment was done in a Google Colab Notebook."

    )
    trainer = pl.Trainer(
        precision=16,  # Mixed Precision.
        logger=wandb_logger, 
        fast_dev_run=False,
        callbacks=[
            TQDMProgressBar(refresh_rate=1),
            LogLossAndTimeCallback() 
        ],
        max_epochs=3,
        overfit_batches=0.0,
        enable_checkpointing=True,
        enable_progress_bar=True,
        enable_model_summary=True,
        deterministic=False,
        benchmark=True,
        default_root_dir=None
    )

    # # Log dependencies.
    # artifact = wandb.Artifact("Dependencies", type="dependencies")
    # artifact.add_file("requirements.txt", name=f"requirements.txt")
    # run.log_artifact(artifact)

    trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl)

    # run.log({
    #     "train_loss": train_loss,
    #     "val_loss": val_loss,
    #     "train_time": train_time,
    #     "val_time": val_time,
    #     "total_time": train_time + val_time
    # })

    # run.finish()


    wandb.finish()

INFO: Using 16bit Automatic Mixed Precision (AMP)
INFO:lightning.pytorch.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name      | Type   | Params
-------------------------------------
0 | unet      | Unet   | 1.3 M 
1 | last_conv | Conv2d | 637   
-------------------------------------
1.3 M     Trainable params
0         Non-trainable pa

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Loss: 0.01214
Loss: 0.00039
Loss: 0.00030
Loss: 0.00028
Loss: 0.00028
Loss: 0.00027
Loss: 0.00028
Loss: 0.00027


Validation: 0it [00:00, ?it/s]

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016669271116666096, max=1.0…

Problem at: /usr/local/lib/python3.10/dist-packages/lightning/pytorch/loggers/wandb.py 405 experiment


CommError: ignored

In [None]:
run.finish()

## PyTorch Lightning Fabric

In [None]:
!pip install lightning -qqq

In [None]:
from lightning.fabric import Fabric

In [None]:
!pip freeze > requirements.txt

filters = 32
filters_rate = 1.5
down_blocks = 4
stem_blocks = 0
up_blocks = 3

# Some of this code is redundant. I kept the previous cells un-deleted just for reference.
for i in range(5): 
    fabric = Fabric(accelerator="cuda", precision="16-mixed")

    train_ds = DataGenerator(labels)
    train_dl = DataLoader(
        train_ds,
        batch_size=4,
        shuffle=True,
        num_workers=cores,
        pin_memory=True,
        drop_last=True,
        prefetch_factor=2
    )

    val_ds = DataGenerator(val_labels, is_train=False)
    val_dl = DataLoader(
        val_ds,
        batch_size=4,
        shuffle=False,
        num_workers=cores,
        pin_memory=True,
        drop_last=True,
        prefetch_factor=2
    )

    unet = Unet(filters=filters, 
                filters_rate=filters_rate, 
                down_blocks=down_blocks, 
                stem_blocks=stem_blocks, 
                up_blocks=up_blocks)

    in_channels = int(
        filters
        * (
            filters_rate
            ** (down_blocks + stem_blocks - 1 - up_blocks + 1)
        )
    )
    model = nn.Sequential(*[
        unet, 
        nn.Conv2d(in_channels=in_channels, out_channels=13, kernel_size=1, padding="same")    
    ])

    opt = torch.optim.Adam(model.parameters(), lr=1e-4)

    model, opt = fabric.setup(model, opt)
    train_dl, val_dl = fabric.setup_dataloaders(train_dl, val_dl)

    run = wandb.init(
        project="torch_vs_tf_talmo_lab", 
        name=f"torch_lightning_fabric_run{i}", 
        config={
            "is_lightning_fabric": "true",
            "device_and_memory": get_vram(),
            "seed": seed,
            "model_param_count": get_param_count(model)
        }, 
        tags=["lightning_fabric"],
        notes="This experiment was done in a Google Colab Notebook."
    )

    # Log dependencies.
    artifact = wandb.Artifact("Dependencies", type="dependencies")
    artifact.add_file("requirements.txt", name=f"requirements.txt")
    run.log_artifact(artifact)

    for epoch in range(3):
        _ = model.train()
        start_time = time.time()
        train_loss = 0
        for idx, batch in enumerate(train_dl):
            X, y = batch

            with fabric.autocast():
                y_preds = model(X)
                loss = nn.MSELoss()(y_preds, y)
            fabric.backward(loss)
            opt.step()
            opt.zero_grad()
                    
            if idx % 100 == 0:
              print(f"Epoch: {epoch} | Loss: {loss:.5f}")

            train_loss += loss

        train_loss /= (idx+1)
        train_time = time.time() - start_time
        print(f"TRAIN: --- {train_time}s seconds ---")

        _  = model.eval()
        start_time = time.time()
        val_loss = 0
        for idx, batch in enumerate(val_dl):
            X, y = batch

            with torch.no_grad():
                y_preds = model(X)
                loss = nn.MSELoss()(y_preds, y)

            val_loss += loss

        val_loss /= (idx+1)
        val_time = time.time() - start_time
        print(f"VAL: --- {val_time}s seconds ---")

        run.log({
            "train_loss": train_loss,
            "val_loss": val_loss,
            "train_time": train_time,
            "val_time": val_time,
            "total_time": train_time + val_time
        })

    del fabric, model, opt, train_ds, train_dl, val_ds, val_dl
    gc.collect()
    run.finish()

## Hugging Face Accelerate

In [None]:
!pip install accelerate -qqq

In [None]:
from accelerate import Accelerator

In [None]:
!pip freeze > requirements.txt

filters = 32
filters_rate = 1.5
down_blocks = 4
stem_blocks = 0
up_blocks = 3

# Some of this code is redundant. I kept the previous cells un-deleted just for reference.
for i in range(5): 
    accelerator = Accelerator(mixed_precision="fp16")
    device = accelerator.device

    train_ds = DataGenerator(labels)
    train_dl = DataLoader(
        train_ds,
        batch_size=4,
        shuffle=True,
        num_workers=cores,
        pin_memory=True,
        drop_last=True,
        prefetch_factor=2
    )

    val_ds = DataGenerator(val_labels, is_train=False)
    val_dl = DataLoader(
        val_ds,
        batch_size=4,
        shuffle=False,
        num_workers=cores,
        pin_memory=True,
        drop_last=True,
        prefetch_factor=2
    )

    unet = Unet(filters=filters, 
                filters_rate=filters_rate, 
                down_blocks=down_blocks, 
                stem_blocks=stem_blocks, 
                up_blocks=up_blocks)

    in_channels = int(
        filters
        * (
            filters_rate
            ** (down_blocks + stem_blocks - 1 - up_blocks + 1)
        )
    )
    model = nn.Sequential(*[
        unet, 
        nn.Conv2d(in_channels=in_channels, out_channels=13, kernel_size=1, padding="same")    
    ])

    opt = torch.optim.Adam(model.parameters(), lr=1e-4)

    model, opt, train_dl, val_dl = accelerator.prepare(
        model, opt, train_dl, val_dl
    )

    run = wandb.init(
        project="torch_vs_tf_talmo_lab", 
        name=f"torch_accelerate_run{i}", 
        config={
            "is_accelerate": "true",
            "device_and_memory": get_vram(),
            "seed": seed,
            "model_param_count": get_param_count(model)
        }, 
        tags=["accelerate"],
        notes="This experiment was done in a Google Colab Notebook."
    )

    # Log dependencies.
    artifact = wandb.Artifact("Dependencies", type="dependencies")
    artifact.add_file("requirements.txt", name=f"requirements.txt")
    run.log_artifact(artifact)

    for epoch in range(3):
        _ = model.train()
        start_time = time.time()
        train_loss = 0
        for idx, batch in enumerate(train_dl):
            X, y = batch
            X = X.to(device)
            y = y.to(device)
            
            y_preds = model(X)
            loss = nn.MSELoss()(y_preds, y)

            accelerator.backward(loss)
            opt.step()
            opt.zero_grad()
                    
            if idx % 100 == 0:
              print(f"Epoch: {epoch} | Loss: {loss:.5f}")

            train_loss += loss

        train_loss /= (idx+1)
        train_time = time.time() - start_time
        print(f"TRAIN: --- {train_time}s seconds ---")

        _  = model.eval()
        start_time = time.time()
        val_loss = 0
        for idx, batch in enumerate(val_dl):
            X, y = batch
            X = X.to("cuda")
            y = y.to("cuda")

            with torch.no_grad():
                y_preds = model(X)
                loss = nn.MSELoss()(y_preds, y)

            val_loss += loss

        val_loss /= (idx+1)
        val_time = time.time() - start_time
        print(f"VAL: --- {val_time}s seconds ---")

        run.log({
            "train_loss": train_loss,
            "val_loss": val_loss,
            "train_time": train_time,
            "val_time": val_time,
            "total_time": train_time + val_time
        })

    del model, opt, train_ds, train_dl, val_ds, val_dl
    gc.collect()
    run.finish()

## MosaicML Composer

In [None]:
!pip install mosaicml -qqq

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m565.9/565.9 kB[0m [31m10.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m518.6/518.6 kB[0m [31m39.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.9/61.9 kB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.6/42.6 kB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
import composer
from composer import Callback, State
from composer.trainer import Trainer
from composer.loggers import WandBLogger
from composer.models import ComposerModel
from composer.utils.reproducibility import seed_all
from composer.core import Evaluator, Precision
from composer.metrics import LossMetric

import composer.functional as cf

In [None]:
seed_all(seed)

### With Trainer

⚠️ Note, the `Trainer` does run correctly, but I did face a few finnicky errors with `wandb` run initialization and proper logging. The custom callback implementation is suboptimal and, as of now, I could not find a better way to log the time and loss. The wandb dashboard is also not entirely correct. As of May 2023, Composer is still in its alpha stages. Unless your training script is straightforward, I would suggest leveraging Composer's functional API paired with PyTorch Lightning's Trainer. 

⚠️ Additionally, my custom callback and logger slow down training by about 20 seconds. Validation time is not affected much. Without callbacks and a logger, training and validation combined take about 75 seconds, comparable to the other frameworks. This does not take into consideration the training optimization algorithms Composer provides.

In the subsequent section, I'll adapt a minimal integration of Composer into a vanilla PyTorch script.

In [None]:
class LogLossAndTimeCallback(Callback):
  def __init__(self):
    super().__init__()

    self.state = {
        "train_start_time": 0,
        "train_time": 0,
        "val_start_time": 0,
        "val_time": 0,
    }
    self.train_losses = []
    self.train_minibatch_losses = []
    self.val_losses = []

  def fit_start(self, state, logger):
    artifact = wandb.Artifact("Dependencies", type="dependencies")
    artifact.add_file("requirements.txt", name=f"requirements.txt")
    wandb.log_artifact(artifact)

  def after_loss(self, state, logger):
    self.train_minibatch_losses.append(state.loss.item())

  def after_train_batch(self, state, logger):
    batch_loss = np.mean(self.train_minibatch_losses)
    self.train_minibatch_losses = []
    self.train_losses.append(batch_loss)

  def epoch_start(self, state, logger):
    self.state["train_start_time"] = time.time()

  def epoch_end(self, state, logger):
    wandb.log({"train_loss": np.mean(self.train_losses)})
    self.train_losses = []
    self.state["train_time"] = time.time() - self.state["train_start_time"]
    wandb.log({"train_time": self.state["train_time"]})

  def eval_start(self, state, logger):
    self.state["val_start_time"] = time.time()

  def eval_end(self, state, logger):
    wandb.log({"val_loss": torch.mean(torch.Tensor(state.model.val_losses)).item()})
    state.model.val_losses = []
    self.state["val_time"] = time.time() - self.state["val_start_time"]
    wandb.log({"val_time": self.state["val_time"]})
    wandb.log({"total_time": self.state["train_time"] + self.state["val_time"]})

    self.state["val_start_time"] = time.time()

In [None]:
class FullUNetMosaicML(ComposerModel):
  def __init__(self,
    n_nodes = 13,
    filters = 32,
    filters_rate = 1.5,
    down_blocks = 4,
    stem_blocks = 0,
    up_blocks = 3,
    lr = 1e-4
  ):
    super().__init__()

    self.train_losses = []
    self.val_losses = []

    self.n_nodes = n_nodes
    self.filters = filters
    self.filters_rate = filters_rate
    self.down_blocks = down_blocks
    self.stem_blocks = stem_blocks
    self.up_blocks = up_blocks
    self.lr = lr
    
    self.unet = Unet(
        filters=filters, 
        filters_rate=filters_rate, 
        down_blocks=down_blocks, 
        stem_blocks=stem_blocks, 
        up_blocks=up_blocks
    )

    in_channels = int(
      filters
      * (
          filters_rate
          ** (down_blocks + stem_blocks - 1 - up_blocks + 1)
      )
    )

    self.last_conv = nn.Conv2d(in_channels=in_channels, out_channels=n_nodes, kernel_size=1, padding="same") 

  def forward(self, batch):
    x, _ = batch
    x = self.unet(x)
    x = self.last_conv(x)
    return x

  def eval_forward(self, batch, outputs=None):
    x, y = batch
    x = self.unet(x)
    x = self.last_conv(x)
    loss = nn.MSELoss()(x, y)
    self.val_losses.append(loss)
    return x, y

  def loss(self, outputs, batch):
    _, y = batch
    loss = nn.MSELoss()(outputs, y)
    return loss

In [None]:
train_ds = DataGenerator(labels)
train_dl = DataLoader(
    train_ds,
    batch_size=4,
    shuffle=True,
    num_workers=cores,
    pin_memory=True,
    drop_last=True,
    prefetch_factor=2
)

val_ds = DataGenerator(val_labels, is_train=False)
val_dl = DataLoader(
    val_ds,
    batch_size=4,
    shuffle=False,
    num_workers=cores,
    pin_memory=True,
    drop_last=True,
    prefetch_factor=2
)

In [None]:
model = FullUNetMosaicML()
opt = torch.optim.Adam(model.parameters(), lr=1e-4)

In [None]:
!pip freeze > requirements.txt

logger = WandBLogger(
    project="torch_vs_tf_talmo_lab", 
    name=f"torch_composer_run0", 
    tags=["composer"], 
    log_artifacts=False, rank_zero_only=True,
    init_kwargs={
      "config": {
          "is_composer": "true",
          "device_and_memory": get_vram(),
          "seed": seed,
          "model_param_count": get_param_count(model)
      }, 
      "notes": "This experiment was done in a Google Colab Notebook."     
    }
)

val_dl_eval = Evaluator(
    label='eval',
    dataloader=val_dl,
    metric_names=['MSELoss']
)

trainer = Trainer(
    model=model,
    train_dataloader=train_dl,
    eval_dataloader=val_dl_eval,
    max_duration=3,
    optimizers=opt,
    device="gpu",
    loggers=logger,
    callbacks=[LogLossAndTimeCallback()],
    seed=seed,
    precision=Precision.AMP_FP16,
)

VBox(children=(Label(value='0.001 MB of 0.003 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.445353…

0,1
loss/train/total,█▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
time/batch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
time/batch_in_epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
time/epoch,▁
time/sample,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
time/sample_in_epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train_loss,▁
train_time,▁
trainer/device_train_microbatch_size,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
loss/train/total,0.00027
time/batch,799.0
time/batch_in_epoch,799.0
time/epoch,0.0
time/sample,3196.0
time/sample_in_epoch,3196.0
train_loss,0.00037
train_time,83.53102
trainer/device_train_microbatch_size,4.0


In [None]:
trainer.fit()

******************************
Config:
node_name: unknown because NODENAME environment variable not set
num_gpus_per_node: 1
num_nodes: 1
rank_zero_seed: 42

******************************


train          Epoch   0:    0%|| 0/800 [00:00<?, ?ba/s]         

eval           Epoch   0:    0%|| 0/100 [00:00<?, ?ba/s]         

train          Epoch   1:    0%|| 0/800 [00:00<?, ?ba/s]         

eval           Epoch   1:    0%|| 0/100 [00:00<?, ?ba/s]         

ERROR:composer.core.engine:Error running WandBLogger.post_close().
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/composer/core/engine.py", line 540, in _close
    callback.post_close()
  File "/usr/local/lib/python3.10/dist-packages/composer/loggers/wandb_logger.py", line 322, in post_close
    wandb.finish(0)
  File "/usr/local/lib/python3.10/dist-packages/wandb/sdk/wandb_run.py", line 3705, in finish
    wandb.run.finish(exit_code=exit_code, quiet=quiet)
  File "/usr/local/lib/python3.10/dist-packages/wandb/sdk/wandb_run.py", line 394, in wrapper
    return func(self, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/wandb/sdk/wandb_run.py", line 333, in wrapper
    raise e
  File "/usr/local/lib/python3.10/dist-packages/wandb/sdk/wandb_run.py", line 328, in wrapper
    wandb._attach(run=self)
  File "/usr/local/lib/python3.10/dist-packages/wandb/sdk/wandb_init.py", line 876, in _attach
    raise UsageError(f"Unable to attach to run

train          Epoch   2:    0%|| 0/800 [00:00<?, ?ba/s]         

ERROR:composer.core.engine:Error running WandBLogger.post_close().
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/composer/core/engine.py", line 540, in _close
    callback.post_close()
  File "/usr/local/lib/python3.10/dist-packages/composer/loggers/wandb_logger.py", line 322, in post_close
    wandb.finish(0)
  File "/usr/local/lib/python3.10/dist-packages/wandb/sdk/wandb_run.py", line 3705, in finish
    wandb.run.finish(exit_code=exit_code, quiet=quiet)
  File "/usr/local/lib/python3.10/dist-packages/wandb/sdk/wandb_run.py", line 394, in wrapper
    return func(self, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/wandb/sdk/wandb_run.py", line 333, in wrapper
    raise e
  File "/usr/local/lib/python3.10/dist-packages/wandb/sdk/wandb_run.py", line 328, in wrapper
    wandb._attach(run=self)
  File "/usr/local/lib/python3.10/dist-packages/wandb/sdk/wandb_init.py", line 876, in _attach
    raise UsageError(f"Unable to attach to run

eval           Epoch   2:    0%|| 0/100 [00:00<?, ?ba/s]         

ERROR:composer.core.engine:Error running WandBLogger.post_close().
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/composer/core/engine.py", line 540, in _close
    callback.post_close()
  File "/usr/local/lib/python3.10/dist-packages/composer/loggers/wandb_logger.py", line 322, in post_close
    wandb.finish(0)
  File "/usr/local/lib/python3.10/dist-packages/wandb/sdk/wandb_run.py", line 3705, in finish
    wandb.run.finish(exit_code=exit_code, quiet=quiet)
  File "/usr/local/lib/python3.10/dist-packages/wandb/sdk/wandb_run.py", line 394, in wrapper
    return func(self, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/wandb/sdk/wandb_run.py", line 333, in wrapper
    raise e
  File "/usr/local/lib/python3.10/dist-packages/wandb/sdk/wandb_run.py", line 328, in wrapper
    wandb._attach(run=self)
  File "/usr/local/lib/python3.10/dist-packages/wandb/sdk/wandb_init.py", line 876, in _attach
    raise UsageError(f"Unable to attach to run

### With Functional API in Vanilla PyTorch

All algorithms provided are within `composer.algorithms` and also available via `composer.functional`. Algorithms are divided into 3 categories:
* data augmentation
* model surgery (affects the model)
* training specific

The data augmentation techniques can be implemented in the `Dataset` or applied in the training script after a batch is returned. However, Composer, to my knowledge, doesn't support keypoint estimation yet. I'll only consider methods for model surgery and training specific algorithms that don't modify the confidence map labels.

In [None]:
!pip freeze > requirements.txt

filters = 32
filters_rate = 1.5
down_blocks = 4
stem_blocks = 0
up_blocks = 3

# Some of this code is redundant. I kept the previous cells un-deleted just for reference.
for i in range(5): 
    train_ds = DataGenerator(labels)
    train_dl = DataLoader(
        train_ds,
        batch_size=4,
        shuffle=True,
        num_workers=cores,
        pin_memory=True,
        drop_last=True,
        prefetch_factor=2
    )

    val_ds = DataGenerator(val_labels, is_train=False)
    val_dl = DataLoader(
        val_ds,
        batch_size=4,
        shuffle=False,
        num_workers=cores,
        pin_memory=True,
        drop_last=True,
        prefetch_factor=2
    )

    unet = Unet(filters=filters, 
                filters_rate=filters_rate, 
                down_blocks=down_blocks, 
                stem_blocks=stem_blocks, 
                up_blocks=up_blocks)

    in_channels = int(
        filters
        * (
            filters_rate
            ** (down_blocks + stem_blocks - 1 - up_blocks + 1)
        )
    )
    model = nn.Sequential(*[
        unet, 
        nn.Conv2d(in_channels=in_channels, out_channels=13, kernel_size=1, padding="same")    
    ])

    opt = torch.optim.Adam(model.parameters(), lr=1e-4)
    scaler = torch.cuda.amp.GradScaler()
    model = model.to("cuda")

    # Model surgery: apply squeeze and excite!
    cf.apply_squeeze_excite(
        model,
        optimizers=opt,
        min_channels=128,
        latent_channels=64
    )

    run = wandb.init(
        project="torch_vs_tf_talmo_lab", 
        name=f"torch_composer_run{i}", 
        config={
            "is_composer": "true",
            "device_and_memory": get_vram(),
            "seed": seed,
            "model_param_count": get_param_count(model)
        }, 
        tags=["composer"],
        notes="This experiment was done in a Google Colab Notebook."
    )

    # Log dependencies.
    artifact = wandb.Artifact("Dependencies", type="dependencies")
    artifact.add_file("requirements.txt", name=f"requirements.txt")
    run.log_artifact(artifact)

    for epoch in range(3):
        _ = model.train()
        start_time = time.time()
        train_loss = 0
        for idx, batch in enumerate(train_dl):
            X, y = batch
            X = X.to("cuda")
            y = y.to("cuda")
            
            with torch.autocast("cuda"):
                y_preds = model(X)
                loss = nn.MSELoss()(y_preds, y)
            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()
            opt.zero_grad()
                    
            if idx % 100 == 0:
              print(f"Epoch: {epoch} | Loss: {loss:.5f}")

            train_loss += loss

        train_loss /= (idx+1)
        train_time = time.time() - start_time
        print(f"TRAIN: --- {train_time}s seconds ---")

        _  = model.eval()
        start_time = time.time()
        val_loss = 0
        for idx, batch in enumerate(val_dl):
            X, y = batch
            X = X.to("cuda")
            y = y.to("cuda")

            with torch.no_grad():
                y_preds = model(X)
                loss = nn.MSELoss()(y_preds, y)

            val_loss += loss

        val_loss /= (idx+1)
        val_time = time.time() - start_time
        print(f"VAL: --- {val_time}s seconds ---")

        run.log({
            "train_loss": train_loss,
            "val_loss": val_loss,
            "train_time": train_time,
            "val_time": val_time,
            "total_time": train_time + val_time
        })

    del model, opt, scaler, train_ds, train_dl, val_ds, val_dl
    gc.collect()
    run.finish()

Epoch: 0 | Loss: 0.03348
Epoch: 0 | Loss: 0.00045
Epoch: 0 | Loss: 0.00032
Epoch: 0 | Loss: 0.00030
Epoch: 0 | Loss: 0.00029
Epoch: 0 | Loss: 0.00028
Epoch: 0 | Loss: 0.00028
Epoch: 0 | Loss: 0.00028
TRAIN: --- 66.11166882514954s seconds ---
VAL: --- 5.944556474685669s seconds ---
Epoch: 1 | Loss: 0.00028
Epoch: 1 | Loss: 0.00027
Epoch: 1 | Loss: 0.00026
Epoch: 1 | Loss: 0.00028
Epoch: 1 | Loss: 0.00028
Epoch: 1 | Loss: 0.00027
Epoch: 1 | Loss: 0.00027
Epoch: 1 | Loss: 0.00027
TRAIN: --- 67.85983324050903s seconds ---
VAL: --- 5.8227386474609375s seconds ---
Epoch: 2 | Loss: 0.00027
Epoch: 2 | Loss: 0.00027
Epoch: 2 | Loss: 0.00027
Epoch: 2 | Loss: 0.00026
Epoch: 2 | Loss: 0.00027
Epoch: 2 | Loss: 0.00027
Epoch: 2 | Loss: 0.00027
Epoch: 2 | Loss: 0.00027
TRAIN: --- 65.82826566696167s seconds ---
VAL: --- 5.578967094421387s seconds ---


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
total_time,▃█▁
train_loss,█▁▁
train_time,▂█▁
val_loss,█▄▁
val_time,█▆▁

0,1
total_time,71.40723
train_loss,0.00027
train_time,65.82827
val_loss,0.00027
val_time,5.57897


Epoch: 0 | Loss: 0.01273
Epoch: 0 | Loss: 0.00041
Epoch: 0 | Loss: 0.00031
Epoch: 0 | Loss: 0.00029
Epoch: 0 | Loss: 0.00028
Epoch: 0 | Loss: 0.00028
Epoch: 0 | Loss: 0.00028
Epoch: 0 | Loss: 0.00028
TRAIN: --- 65.93731164932251s seconds ---
VAL: --- 5.426535129547119s seconds ---
Epoch: 1 | Loss: 0.00027
Epoch: 1 | Loss: 0.00027
Epoch: 1 | Loss: 0.00027
Epoch: 1 | Loss: 0.00027
Epoch: 1 | Loss: 0.00027
Epoch: 1 | Loss: 0.00027
Epoch: 1 | Loss: 0.00027
Epoch: 1 | Loss: 0.00027
TRAIN: --- 68.60044479370117s seconds ---
VAL: --- 5.886198997497559s seconds ---
Epoch: 2 | Loss: 0.00027
Epoch: 2 | Loss: 0.00027
Epoch: 2 | Loss: 0.00027
Epoch: 2 | Loss: 0.00027
Epoch: 2 | Loss: 0.00027
Epoch: 2 | Loss: 0.00027
Epoch: 2 | Loss: 0.00026
Epoch: 2 | Loss: 0.00026
TRAIN: --- 66.09496188163757s seconds ---
VAL: --- 6.071070194244385s seconds ---


0,1
total_time,▁█▃
train_loss,█▁▁
train_time,▁█▁
val_loss,█▆▁
val_time,▁▆█

0,1
total_time,72.16603
train_loss,0.00027
train_time,66.09496
val_loss,0.00026
val_time,6.07107


Epoch: 0 | Loss: 0.02120
Epoch: 0 | Loss: 0.00039
Epoch: 0 | Loss: 0.00030
Epoch: 0 | Loss: 0.00028
Epoch: 0 | Loss: 0.00028
Epoch: 0 | Loss: 0.00027
Epoch: 0 | Loss: 0.00028
Epoch: 0 | Loss: 0.00028
TRAIN: --- 65.88023710250854s seconds ---
VAL: --- 5.117318391799927s seconds ---
Epoch: 1 | Loss: 0.00027
Epoch: 1 | Loss: 0.00028
Epoch: 1 | Loss: 0.00027
Epoch: 1 | Loss: 0.00027
Epoch: 1 | Loss: 0.00027
Epoch: 1 | Loss: 0.00027
Epoch: 1 | Loss: 0.00027
Epoch: 1 | Loss: 0.00027
TRAIN: --- 67.48214650154114s seconds ---
VAL: --- 6.644918918609619s seconds ---
Epoch: 2 | Loss: 0.00027
Epoch: 2 | Loss: 0.00027
Epoch: 2 | Loss: 0.00027
Epoch: 2 | Loss: 0.00027
Epoch: 2 | Loss: 0.00027
Epoch: 2 | Loss: 0.00026
Epoch: 2 | Loss: 0.00027
Epoch: 2 | Loss: 0.00026
TRAIN: --- 66.08652782440186s seconds ---
VAL: --- 4.989238262176514s seconds ---


0,1
total_time,▁█▁
train_loss,█▁▁
train_time,▁█▂
val_loss,█▅▁
val_time,▂█▁

0,1
total_time,71.07577
train_loss,0.00027
train_time,66.08653
val_loss,0.00026
val_time,4.98924


Epoch: 0 | Loss: 0.01121
Epoch: 0 | Loss: 0.00041
Epoch: 0 | Loss: 0.00031
Epoch: 0 | Loss: 0.00028
Epoch: 0 | Loss: 0.00028
Epoch: 0 | Loss: 0.00028
Epoch: 0 | Loss: 0.00027
Epoch: 0 | Loss: 0.00027
TRAIN: --- 66.76165175437927s seconds ---
VAL: --- 5.399186134338379s seconds ---
Epoch: 1 | Loss: 0.00027
Epoch: 1 | Loss: 0.00027
Epoch: 1 | Loss: 0.00027
Epoch: 1 | Loss: 0.00027
Epoch: 1 | Loss: 0.00027
Epoch: 1 | Loss: 0.00027
Epoch: 1 | Loss: 0.00027
Epoch: 1 | Loss: 0.00027
TRAIN: --- 67.72224473953247s seconds ---
VAL: --- 6.672939777374268s seconds ---
Epoch: 2 | Loss: 0.00026
Epoch: 2 | Loss: 0.00027
Epoch: 2 | Loss: 0.00026
Epoch: 2 | Loss: 0.00026
Epoch: 2 | Loss: 0.00027
Epoch: 2 | Loss: 0.00026
Epoch: 2 | Loss: 0.00025
Epoch: 2 | Loss: 0.00025
TRAIN: --- 67.36186408996582s seconds ---
VAL: --- 5.866755962371826s seconds ---


0,1
total_time,▁█▄
train_loss,█▁▁
train_time,▁█▅
val_loss,█▆▁
val_time,▁█▄

0,1
total_time,73.22862
train_loss,0.00026
train_time,67.36186
val_loss,0.00025
val_time,5.86676


Epoch: 0 | Loss: 0.01585
Epoch: 0 | Loss: 0.00040
Epoch: 0 | Loss: 0.00031
Epoch: 0 | Loss: 0.00028
Epoch: 0 | Loss: 0.00028
Epoch: 0 | Loss: 0.00028
Epoch: 0 | Loss: 0.00027
Epoch: 0 | Loss: 0.00027
TRAIN: --- 66.80710625648499s seconds ---
VAL: --- 5.4922308921813965s seconds ---
Epoch: 1 | Loss: 0.00027
Epoch: 1 | Loss: 0.00027
Epoch: 1 | Loss: 0.00027
Epoch: 1 | Loss: 0.00027
Epoch: 1 | Loss: 0.00027
Epoch: 1 | Loss: 0.00026
Epoch: 1 | Loss: 0.00026
Epoch: 1 | Loss: 0.00027
TRAIN: --- 68.87202477455139s seconds ---
VAL: --- 6.0507121086120605s seconds ---
Epoch: 2 | Loss: 0.00027
Epoch: 2 | Loss: 0.00027
Epoch: 2 | Loss: 0.00026
Epoch: 2 | Loss: 0.00026
Epoch: 2 | Loss: 0.00026
Epoch: 2 | Loss: 0.00026
Epoch: 2 | Loss: 0.00026
Epoch: 2 | Loss: 0.00026
TRAIN: --- 67.38010573387146s seconds ---
VAL: --- 6.8685829639434814s seconds ---


0,1
total_time,▁█▆
train_loss,█▁▁
train_time,▁█▃
val_loss,█▆▁
val_time,▁▄█

0,1
total_time,74.24869
train_loss,0.00026
train_time,67.38011
val_loss,0.00026
val_time,6.86858


## Optimizing PyTorch Code

The following training script is "Vanilla PyTorch with Mixed Precision". I selected this script because it was the fastest out of all the libraries. Here, I implement a few tricks to speed up training, though they don't seem to provide much noticeable benefit here. 

My next step is to profile this code and identify and mitigate bottlenecks. The code below is a little messy!

In [16]:
# Benchmark.
torch.backends.cudnn.benchmark = True

In [34]:
!pip freeze > requirements.txt

filters = 32
filters_rate = 1.5
down_blocks = 4
stem_blocks = 0
up_blocks = 3


# Some of this code is redundant. I kept the previous cells un-deleted just for reference.
for i in range(1): 
    train_ds = DataGenerator(labels)
    train_dl = DataLoader(
        train_ds,
        batch_size=4,
        shuffle=True,
        num_workers=cores,
        pin_memory=True,
        drop_last=True,
        prefetch_factor=2
    )

    val_ds = DataGenerator(val_labels, is_train=False)
    val_dl = DataLoader(
        val_ds,
        batch_size=4,
        shuffle=False,
        num_workers=cores,
        pin_memory=True,
        drop_last=True,
        prefetch_factor=2
    )

    unet = Unet(filters=filters, 
                filters_rate=filters_rate, 
                down_blocks=down_blocks, 
                stem_blocks=stem_blocks, 
                up_blocks=up_blocks)

    in_channels = int(
        filters
        * (
            filters_rate
            ** (down_blocks + stem_blocks - 1 - up_blocks + 1)
        )
    )
    model = nn.Sequential(*[
        unet, 
        nn.Conv2d(in_channels=in_channels, out_channels=13, kernel_size=1, padding="same")    
    ])

    # Using AdamW instead of Adam.
    opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
    scaler = torch.cuda.amp.GradScaler()
    model = model.to("cuda")

    # model = torch.compile(model)

    # run = wandb.init(
    #     project="torch_vs_tf_talmo_lab", 
    #     name=f"torch_optim_run{i}", 
    #     config={
    #         "is_tf": False,
    #         "device_and_memory": get_vram(),
    #         "seed": seed,
    #         "model_param_count": get_param_count(model)
    #     }, 
    #     tags=["optimizing"],
    #     notes="This experiment was done in a Google Colab Notebook."
    # )

    # Log dependencies.
    # artifact = wandb.Artifact("Dependencies", type="dependencies")
    # artifact.add_file("requirements.txt", name=f"requirements.txt")
    # run.log_artifact(artifact)

    with profile(activities=[
          ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
      for epoch in range(1):
          _ = model.train()
          start_time = time.time()
          train_loss = 0
          for idx, batch in enumerate(train_dl):

              with record_function("loading_data"):
                  X, y = batch
                  X = X.to("cuda")
                  y = y.to("cuda")
              
              with torch.autocast("cuda"), record_function("forward_pass"):
                  y_preds = model(X)
                  loss = nn.MSELoss()(y_preds, y)
              scaler.scale(loss).backward()
              scaler.step(opt)
              scaler.update()

              break
          break

              # Setting gradients to None instead of 0.
    #           opt.zero_grad(set_to_none=True)
                      
    #           if idx % 100 == 0:
    #             print(f"Epoch: {epoch} | Loss: {loss:.5f}")

    #           train_loss += loss

    #       train_loss /= (idx+1)
    #       train_time = time.time() - start_time
    #       print(f"TRAIN: --- {train_time}s seconds ---")

    #       _  = model.eval()
    #       start_time = time.time()
    #       val_loss = 0
    #       for idx, batch in enumerate(val_dl):
    #           X, y = batch
    #           X = X.to("cuda")
    #           y = y.to("cuda")

    #           with torch.no_grad():
    #               y_preds = model(X)
    #               loss = nn.MSELoss()(y_preds, y)

    #           val_loss += loss

    #       val_loss /= (idx+1)
    #       val_time = time.time() - start_time
    #       print(f"VAL: --- {val_time}s seconds ---")

    #       # run.log({
    #       #     "train_loss": train_loss,
    #       #     "val_loss": val_loss,
    #       #     "train_time": train_time,
    #       #     "val_time": val_time,
    #       #     "total_time": train_time + val_time
    #       # })

    # del model, opt, scaler, train_ds, train_dl, val_ds, val_dl
    # gc.collect()
    # run.finish()

In [35]:
print(prof.key_averages(group_by_stack_n=5).table(sort_by='self_cpu_time_total', row_limit=5))

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
enumerate(DataLoader)#_MultiProcessingDataLoaderIter...        64.98%     197.318ms        65.00%     197.362ms     197.362ms       0.000us         0.00%       0.000us       0.000us             1  
                             aten::convolution_backward         4.03%      12.248ms         5.72%      17.380ms       1.022ms       6.593ms        40.92%       7.416ms     436.235us            17  
autograd:

It looks like my suspicion is true. The Dataset does take a bit of time to load! Here I demonstrate using a profiler for performance tuning. I won't update the dataset here as this isn't directly relevant to the report for this notebook, but you can imagine if the dataset bottleneck was alleviated, training would resume even faster!

More details on optimizing the PyTorch code can be found in the W&B Report!