# BYOL in PyTorch

This notebook was written by Ying Liu (yili@di.ku.dk), a student in the Department of Computer Science at Københavns Universitet.

## Algorithm

<div style="display: flex; justify-content: center; align-items: center; background: white;">
    <img src="https://imgur.com/rBZRkOf.png" style="width: 20%; margin-right: 1rem;" />
    <img src="https://imgur.com/uudkfAk.png" style="width: 20%;" />
</div>


There are three important components: the **encoder** $f$ (ResNet-50), the **projector** $g$ and the **predictor** $q$.

- Online Network: encoder + projector + predictor, with parameters $\theta$;

- Target Network: encoder + projector, with parameters $\xi$.

During the training, the online network is updated by the **loss function** while the target network is updated slowly by the **exponential moving average** (EMA) of the online network.

## Env Selection

In [None]:
# Runtime environment selector: either 'local' or 'colab'.
# Uncomment the 'colab' line (and comment out the 'local' one) when running on Colab.
# env = 'colab'
env = 'local'

## Google Mount

Mount to Google Drive to access my data.

In [None]:
# Mount Google Drive only when the notebook runs on Colab; the import is kept
# inside the branch so that a local run never tries to import google.colab.
if env == 'colab':
    from google.colab import drive
    drive.mount('/content/drive')

## Global Area

This contains import and global definition.

### Import Declaration

Download any needed libs and import them.

In [None]:
!pip install timm
!pip install pydicom
!pip install beartype
!pip install torchlars
!pip install accelerate

In [None]:
import os
import copy
import json
import timm
import torch
import random
import pydicom
import argparse
import numpy as np
import pandas as pd
from torch import nn
from PIL import Image
import multiprocessing
from torch import optim
from pathlib import Path
# from torchlars import LARS
from functools import wraps
from torch.nn import Module
from beartype import beartype
from torchvision import models
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torch.distributed as dist
from torch.nn import SyncBatchNorm
from accelerate import Accelerator
from beartype.typing import Optional
from torchvision import transforms as T
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

### Global Definition

This part contains all the global definitions.

In [None]:
# All data lives below one Google Drive folder; every path is derived from it.
_DATA_ROOT = '/content/drive/MyDrive/master-thesis/data'

# raw data grouped by type, and the json file mapping subjects to addresses
path_data_folder = _DATA_ROOT + '/fetched_data_by_type'
path_dict_map = path_data_folder + '/all_address_dict.json'

# folder with the last-phase 3DCT arrays, and the json map stored inside it
path_3dct_last_phase_folder = _DATA_ROOT + '/imgsnarr_3dct_last_phase'
path_dict_map_4dct_last_phase = path_3dct_last_phase_folder + '/all_address_dict_4dct_last_phase.json'

# saved mean / std files for the zero-padded and the mean-padded dataset
path_3dct_last_phase_padded_zeros_mean = path_3dct_last_phase_folder + '/padded_zeros_mean.npy'
path_3dct_last_phase_padded_mean_mean = path_3dct_last_phase_folder + '/padded_mean_mean.npy'
path_3dct_last_phase_padded_zeros_std = path_3dct_last_phase_folder + '/padded_zeros_std.npy'
path_3dct_last_phase_padded_mean_std = path_3dct_last_phase_folder + '/padded_mean_std.npy'

In [None]:
# Training hyper-parameters.
BATCH_SIZE = 16
MAX_EPOCHS = 300

# learning rate: 0.2 scaled by the ratio of the batch size to 256
_BASE_LR, _BASE_BATCH_SIZE = 0.2, 256
LR = _BASE_LR * (BATCH_SIZE / _BASE_BATCH_SIZE)

GLOBAL_WEIGTH_DECAY = 1.5e-6

## Load Data

**<font color=red>No need to run this section if the DICOM arrays have already been saved!!!</font>**

**NB**: The total number of 4DCT scans is 82, however, there is one more in the `map_dict` which is formed by counting the number of slices. By checking the number of folders right up the slices, 4DCT should have 20 instead of 10, for delineation. Therefore, by checking again, there is one cbct instead of ct classified in subject 119 - '12-13-2000-NA-p4-13619'.

### Read mapping File

This is to read the json file which records all the subject's 4DCT images with the corresponding addresses.

In [None]:
# Parse the json address map and list the subjects (its top-level keys).
dict_map = json.loads(Path(path_dict_map).read_text())

subjects = [*dict_map]
print('(subjects):', subjects)

### Reform and Correct mapping

One scan was mistakenly classified as a 4DCT, so it has to be corrected in the dict_map.

In [None]:
# Re-root every recorded address below path_data_folder/<subject>/, keeping only
# the last backslash-separated component of the recorded path.
for subject in subjects:
    for modality in ('CT', 'CBCT'):
        recorded_paths = dict_map[subject][modality]
        # slice assignment updates the existing list in place
        recorded_paths[:] = [
            os.path.join(path_data_folder, subject, recorded.split('\\')[-1])
            for recorded in recorded_paths
        ]
        dict_map[subject][modality] = recorded_paths
print(dict_map)

# Correct the misclassified scan of subject 119: the folder whose name starts
# with '12-13' is a CBCT scan that was recorded under 'CT'.
# The CT list is rebuilt instead of deleted from while it is being enumerated,
# because deleting during iteration skips the element after every removal.
subject_119 = subjects[19]
kept_ct, moved_to_cbct = [], []
for path in dict_map[subject_119]['CT']:
    # basename uses the platform separator, like the os.path.join that built the path
    if os.path.basename(path).startswith('12-13'):
        moved_to_cbct.append(path)
    else:
        kept_ct.append(path)
dict_map[subject_119]['CT'] = kept_ct
dict_map[subject_119]['CBCT'].extend(moved_to_cbct)
print(dict_map[subjects[19]]['CT'])

### Specify Phase

Choose a fixed phase, the end of the inhalation, which might be the last one among 10.

Besides, since indexing with [-1] does not always select the last phase, the second cell adds a condition to make sure the last phase is chosen.

In [None]:
# get the folders right under the folder
def get_subfolders(path_parent):
    """Return the paths of the directories directly inside ``path_parent``.

    ``os.listdir`` yields entries in arbitrary order, so the result is sorted
    by name; this makes positional access such as ``[-1]`` reproducible.

    Args:
        path_parent: Path of the directory to inspect.

    Returns:
        A name-sorted list of ``os.path.join(path_parent, name)`` for every
        entry that is a directory (empty when there is none).
    """
    return [
        os.path.join(path_parent, name)
        for name in sorted(os.listdir(path_parent))
        if os.path.isdir(os.path.join(path_parent, name))
    ]

# preview: the last sub-folder listed for the first CT scan of subjects[1]
get_subfolders(dict_map[subjects[1]]['CT'][0])[-1]

In [None]:
# Build dict_map_4dct_last_phase {subject: [folder_path]}: for every CT scan,
# the sub-folder that holds its last phase.
dict_map_4dct_last_phase = {}
for subject in subjects:
    dict_map_4dct_last_phase[subject] = []
    for path_4dct in dict_map[subject]['CT']:
        phase_folders = get_subfolders(path_4dct)  # list the scan folder only once
        last_phase_folder = phase_folders[-1]
        # a last folder holding a single file is presumably not a phase
        # (TODO confirm); fall back to the folder before it
        if len(os.listdir(last_phase_folder)) == 1:
            last_phase_folder = phase_folders[-2]
        dict_map_4dct_last_phase[subject].append(last_phase_folder)
print(dict_map_4dct_last_phase)

# save the dict_map_4dct_last_phase; its parent folder may not exist yet at
# this point of the notebook, so create it before opening the file
os.makedirs(os.path.dirname(path_dict_map_4dct_last_phase), exist_ok=True)
with open(path_dict_map_4dct_last_phase, 'w') as f:
    json.dump(dict_map_4dct_last_phase, f)

# compute the total number of 4dct
total_4dct = sum(len(dict_map_4dct_last_phase[subject]) for subject in subjects)
print(total_4dct, dict_map_4dct_last_phase[subjects[0]])

### Collect DICOM Files

By maintaining a dict_map_4dct_last_phase, all the corresponding 3DCT (H, W, Slice) images have been loaded with respect to the subjects, and so that each subject could get the corresponding 3DCT images by reading the dcm files.

In [None]:
def collect_and_convert_dcm_files(path_folder):
    """Read every ``.dcm`` file of a folder into one (H, W, S) array.

    ``os.listdir`` returns names in arbitrary order, which previously decided
    the order of the slice axis. The files are now sorted by name and, when
    every file provides an integer ``InstanceNumber``, re-ordered by it, so
    the slice axis is reproducible.

    Args:
        path_folder: Folder that directly contains the ``.dcm`` files.

    Returns:
        ``(dcm_files, volume)``: the file paths and the stacked pixel arrays
        transposed to (H, W, S), both in the same slice order.

    Raises:
        ValueError: If the folder contains no ``.dcm`` file.
    """
    dcm_files = sorted(
        os.path.join(path_folder, f) for f in os.listdir(path_folder) if f.endswith('.dcm')
    )
    if not dcm_files:
        raise ValueError(f'no .dcm files found in {path_folder}')

    slices = [pydicom.dcmread(dcm_file) for dcm_file in dcm_files]

    try:
        instance_numbers = [int(s.InstanceNumber) for s in slices]
    except (AttributeError, TypeError, ValueError):
        # attribute missing or empty in at least one file: keep the name order
        instance_numbers = None

    if instance_numbers is not None:
        # sorted() is stable, so equal numbers keep their name order
        order = sorted(range(len(slices)), key=instance_numbers.__getitem__)
        dcm_files = [dcm_files[k] for k in order]
        slices = [slices[k] for k in order]

    volume = np.array([s.pixel_array for s in slices])  # (S, H, W)
    return dcm_files, volume.transpose(1, 2, 0)

# smoke test on the first last-phase folder of subjects[4]; the trailing
# expression displays the type and shape of the returned array
_, array_slices = collect_and_convert_dcm_files(dict_map_4dct_last_phase[subjects[4]][0])
type(array_slices), array_slices.shape

In [None]:
# {subject: [image]} where an image := 3D array (H, W, s), one per last-phase folder
dict_3dct_last_phase = {
    subject: [
        collect_and_convert_dcm_files(path_4dct_last_phase)[1]  # [1] is the array
        for path_4dct_last_phase in dict_map_4dct_last_phase[subject]
    ]
    for subject in subjects
}

In [None]:
# Make sure the output folder and one sub-folder per subject exist.
# The sub-folders are created unconditionally (exist_ok=True): creating them
# only when the root folder was missing made np.save fail whenever the root
# already existed without them.
folder_was_missing = not os.path.exists(path_3dct_last_phase_folder)
os.makedirs(path_3dct_last_phase_folder, exist_ok=True)
for subject in subjects:
    os.makedirs(os.path.join(path_3dct_last_phase_folder, subject), exist_ok=True)
if folder_was_missing:
    print('Last Phase 3DCT Images in Array - Folders Created')

# save the dict_3dct_last_phase as <folder>/<subject>/3dct_<i>.npy
# (np.save appends the .npy suffix)
for subject in subjects:
    for i, image_3dct in enumerate(dict_3dct_last_phase[subject]):
        np.save(os.path.join(path_3dct_last_phase_folder, subject, f'3dct_{i}'), image_3dct)
print('All Saved')

## Load DICOM Array

All the 3DCT images, each with shape (H, W, Slice), have been saved as `.npy` files by the code above, and their folder addresses are recorded in a json file. They can therefore be loaded directly into a dict_3dct_last_phase `{subject: [(3dct image)]}`.

In [None]:
# read back the {subject: [folder_path]} map saved as json
with open(path_dict_map_4dct_last_phase, 'r') as f:
    dict_map_4dct_last_phase = json.load(f)

subjects = list(dict_map_4dct_last_phase.keys())

# One 3dct_<i>.npy file per recorded folder; rebuild dict_3dct_last_phase from them
dict_3dct_last_phase = {
    subject: [
        np.load(os.path.join(path_3dct_last_phase_folder, subject, f'3dct_{i}.npy'))
        for i in range(len(dict_map_4dct_last_phase[subject]))
    ]
    for subject in subjects
}

## Registered?

To see whether the slices are registered for same patient or for all of them.

### Inter

In [None]:
# Plot every slice (last axis) of the first volume of subjects[0].
_volume = dict_3dct_last_phase[subjects[0]][0]
_subject_label = subjects[0].split('_')[0]
print(_volume.shape) # only 1 ct for subject 0
for i in range(_volume.shape[2]):
    plt.figure(figsize=(10, 6))
    plt.imshow(_volume[:, :, i], cmap=plt.cm.bone)
    plt.title(f"{_subject_label} - slice {i}")
    plt.show()
del _volume, _subject_label  # keep no extra reference to the volume

In [None]:
# Plot every slice (last axis) of the first volume of subjects[1].
_volume = dict_3dct_last_phase[subjects[1]][0]
_subject_label = subjects[1].split('_')[0]
print(_volume.shape) # only 1 ct for subject 1
for i in range(_volume.shape[2]):
    plt.figure(figsize=(10, 6))
    plt.imshow(_volume[:, :, i], cmap=plt.cm.bone)
    plt.title(f"{_subject_label} - slice {i}")
    plt.show()
del _volume, _subject_label  # keep no extra reference to the volume

### Intra

In [None]:
# Report how many volumes subjects[19] has and the shape of each of them.
_volumes_19 = dict_3dct_last_phase[subjects[19]]
print('num of ct for sub.19', len(_volumes_19))
for image_3dct in _volumes_19:
    print(image_3dct.shape)

# first ct of subject 19: plot every slice
_volume = _volumes_19[0]
_subject_label = subjects[19].split('_')[0]
for i in range(_volume.shape[2]):
    plt.figure(figsize=(10, 6))
    plt.imshow(_volume[:, :, i], cmap=plt.cm.bone)
    plt.title(f"{_subject_label} - slice {i}")
    plt.show()
del _volumes_19, _volume, _subject_label  # keep no extra references to the volumes

In [None]:
# second ct of subjects[19]: plot every slice
_volume = dict_3dct_last_phase[subjects[19]][1]
_subject_label = subjects[19].split('_')[0]
for i in range(_volume.shape[2]):
    plt.figure(figsize=(10, 6))
    plt.imshow(_volume[:, :, i], cmap=plt.cm.bone)
    plt.title(f"{_subject_label} - slice {i}")
    plt.show()
del _volume, _subject_label  # keep no extra reference to the volume

In [None]:
# third ct of subjects[19]: plot every slice
_volume = dict_3dct_last_phase[subjects[19]][2]
_subject_label = subjects[19].split('_')[0]
for i in range(_volume.shape[2]):
    plt.figure(figsize=(10, 6))
    plt.imshow(_volume[:, :, i], cmap=plt.cm.bone)
    plt.title(f"{_subject_label} - slice {i}")
    plt.show()
del _volume, _subject_label  # keep no extra reference to the volume

## Data Preprocess

Aim is to compose all the data in the format of `(82, 512, 512, #slice)`.

In [None]:
# largest size of the last (slice) axis over all 3dct volumes; 0 if there are none
max_num_slices = max(
    (
        image_3dct.shape[2]
        for subject in subjects
        for image_3dct in dict_3dct_last_phase[subject]
    ),
    default=0,
)
print('Maximum number of slices:', max_num_slices)

### Fill

By the data, 3DCT are different from each other along the slice dimension. Therefore, a padding technique should be considered to make them be composable to be (N, H, W, S).

#### Deprecated - First trial: Zero Padding

First trial: Generate the 0 slice to fill the 3DCT.

In [None]:
# first trial: append all-zero slices until every volume has max_num_slices slices
dict_3dct_last_phase_padded_zeros = {}
for subject in subjects:
    dict_3dct_last_phase_padded_zeros[subject] = []
    for image_3dct in dict_3dct_last_phase[subject]:
        rows, cols, num_slices = image_3dct.shape
        missing = max_num_slices - num_slices
        if missing > 0:
            # np.zeros defaults to float64, so the concatenated volume is promoted accordingly
            image_3dct = np.concatenate((image_3dct, np.zeros((rows, cols, missing))), axis=2)
        dict_3dct_last_phase_padded_zeros[subject].append(image_3dct)
print(dict_3dct_last_phase_padded_zeros[subjects[0]][0].shape)

#### Second Trial: Mean Padding

Second trial: Generate the mean 2D slice to fill the 3DCT.

In [None]:
# Bring every 3dct to (H, W, max_num_slices): only the end of the last axis is
# padded, with a constant equal to the mean of the whole volume.
dict_map_3dct_last_phase_padded_mean = {
    subject: [
        np.pad(
            image_3dct,
            ((0, 0), (0, 0), (0, max_num_slices - image_3dct.shape[2])),  # slices to add
            mode='constant',
            constant_values=image_3dct.mean(),
        )
        for image_3dct in dict_3dct_last_phase[subject]
    ]
    for subject in subjects
}
print(dict_map_3dct_last_phase_padded_mean[subjects[0]][0].shape)

visualize

In [None]:
# Plot every slice of the first mean-padded volume of subjects[0].
_padded_volume = dict_map_3dct_last_phase_padded_mean[subjects[0]][0]
_subject_label = subjects[0].split('_')[0]
for i in range(_padded_volume.shape[2]):
    plt.figure(figsize=(10, 6))
    plt.imshow(_padded_volume[:, :, i], cmap=plt.cm.bone)
    plt.title(f"Mean Padded {_subject_label} - slice {i}")
    plt.show()
del _padded_volume, _subject_label  # keep no extra reference to the volume

### Dataset - Current Version is Mean

Compose and permute the dataset, and compute the mean and std for each channel.

#### Dataset Composition (N, H, W, S)

In [None]:
# Stack all mean-padded volumes into one (N, H, W, S) array.
# Every volume gets its own row: indexing rows by subject made several scans of
# one subject overwrite each other and left the remaining rows all-zero.
# N and the volume shape are taken from the data instead of being hard-coded.
_padded_volumes = [
    image_3dct
    for subject in subjects
    for image_3dct in dict_map_3dct_last_phase_padded_mean[subject]
]
# float32 is the dtype the data is converted to afterwards anyway and needs
# half the memory of NumPy's default float64
dataset_3dct_last_phase_padded_mean = np.zeros(
    (len(_padded_volumes), *_padded_volumes[0].shape), dtype=np.float32
)
for i, image_3dct in enumerate(_padded_volumes):
    dataset_3dct_last_phase_padded_mean[i] = image_3dct
del _padded_volumes  # keep no extra references to the volumes
print(dataset_3dct_last_phase_padded_mean.shape)

#### Dataset Permutation (N, S, H, W)

In [None]:
# free RAM: drop the reference to the unpadded volumes so they can be collected
dict_3dct_last_phase = None

In [None]:
# recompose to be the right order (N, C, H, W): the slice axis becomes the channel axis.
# torch.as_tensor reuses the NumPy buffer when the array is already float32,
# whereas torch.tensor always makes a full copy of this large array.
dataset_3dct_last_phase_padded_mean = torch.as_tensor(
    dataset_3dct_last_phase_padded_mean, dtype=torch.float32
).permute(0, 3, 1, 2)

print(type(dataset_3dct_last_phase_padded_mean), dataset_3dct_last_phase_padded_mean.shape)

In [None]:
# free RAM: release the dict of mean-padded volumes. The dict is named
# dict_map_3dct_last_phase_padded_mean; the name used before did not exist,
# so nothing was released.
dict_map_3dct_last_phase_padded_mean = None

### Dataset Definition `class`

Model operations work with this class to support pytorch.

#### Dataset Declaration

Declare the dataset class that makes the data usable with PyTorch.

In [None]:
class CustomDataset(Dataset):
    """Dataset over an indexable array; item ``idx`` is ``data[idx]``."""

    def __init__(self, data):
        """Store ``data`` and cache the sizes of its first three axes."""
        dims = data.shape
        self.data = data
        self.length = dims[0]      # number of items
        self.num_slices = dims[1]  # size of axis 1
        self.image_size = dims[2]  # size of axis 2

    def __len__(self):
        """Return the number of items (size of axis 0)."""
        return self.length

    def __getitem__(self, idx):
        """Return the item stored at position ``idx``."""
        item = self.data[idx]
        return item

In [None]:
# wrap the (N, C, H, W) tensor in the PyTorch dataset, then drop the
# module-level name; the dataset object keeps its own reference to the data
dataset = CustomDataset(dataset_3dct_last_phase_padded_mean)
dataset_3dct_last_phase_padded_mean = None

#### Dataset Split

Split the training_data, and test_data.

In [None]:
# 72 items go to training, whatever is left goes to testing
train_length = 72
test_length = len(dataset) - train_length

# random_split draws a random, non-overlapping partition with these sizes
_split_lengths = [train_length, test_length]
training_data, test_data = random_split(dataset, _split_lengths)

print("Size (Trainset):", len(training_data))
print("Size (Testset):", len(test_data))

Apply dataloader to load them, train_dataloader and test_dataloader.

In [None]:
# both loaders use the same batch size and reshuffle their split
_loader_options = {'batch_size': BATCH_SIZE, 'shuffle': True}
train_dataloader = DataLoader(training_data, **_loader_options)
test_dataloader = DataLoader(test_data, **_loader_options)

###### Early Stopping Strategy

To employ early stopping, with the use of validation set.

In [None]:
# Carve a validation set out of the current training data: split_ratio is the
# validation share, the rest stays in the training set.
split_ratio = 0.2
_full_train_set = train_dataloader.dataset
print('Original dataset size ==> ', len(_full_train_set))
train_size = int((1-split_ratio)*len(_full_train_set))
val_size = len(_full_train_set) - train_size
train_dataset, val_dataset = random_split(_full_train_set, [train_size, val_size])

# rebuild the training loader on the reduced set and add a validation loader
_loader_options = {'batch_size': BATCH_SIZE, 'shuffle': True, 'num_workers': 2}
train_dataloader = torch.utils.data.DataLoader(train_dataset, **_loader_options)
val_dataloader = torch.utils.data.DataLoader(val_dataset, **_loader_options)
print('Train size ==> ', len(train_dataset))
print('Validation size ==> ', len(val_dataset))

#### Compute Mean and Std for Data Augment.

Compute the mean and std with 168 dimensions which are for the further data augmentations.

##### Compute and save

In [None]:
# Exact per-channel (axis 1) mean and unbiased std over the whole training set.
# Averaging per-batch stds does not give the dataset std, and averaging
# per-batch means over-weights a smaller last batch. The batches are therefore
# merged with the pairwise mean / sum-of-squared-deviations update.
_reduce_dims = (0, 2, 3)
_count = 0     # values seen so far per channel
_mean = None   # running per-channel mean (float64)
_m2 = None     # running per-channel sum of squared deviations from the mean

for batch in train_dataloader:
    batch_count = batch.numel() // batch.shape[1]
    batch_mean = torch.mean(batch, dim=_reduce_dims).double()
    batch_m2 = torch.var(batch, dim=_reduce_dims, unbiased=False).double() * batch_count
    if _mean is None:
        _count, _mean, _m2 = batch_count, batch_mean, batch_m2
        continue
    delta = batch_mean - _mean
    total = _count + batch_count
    _mean = _mean + delta * (batch_count / total)
    _m2 = _m2 + batch_m2 + delta ** 2 * (_count * batch_count / total)
    _count = total

train_mean = _mean.float()
# divide by (count - 1): same unbiased estimator as torch.std's default
train_std = torch.sqrt(_m2 / (_count - 1)).float()

print("Trainset Mean:", train_mean.shape)
print("Trainset Std:", train_std.shape)

In [None]:
# write the statistics to disk as NumPy arrays
np.save(path_3dct_last_phase_padded_mean_mean, train_mean.numpy())
np.save(path_3dct_last_phase_padded_mean_std, train_std.numpy())

# drop the in-memory copies; the values are reloaded from disk when needed
train_batch_mean = train_batch_std = train_mean = train_std = None

##### Load the computed Mean and Std

In [None]:
# read the saved mean and std arrays back from disk
mean, std = (
    np.load(path_3dct_last_phase_padded_mean_mean),
    np.load(path_3dct_last_phase_padded_mean_std),
)

## Model Implementation

Adjust the implementation of BYOL-PyTorch to fit the 3DCT data in my case.

**Citation**: [lucidrains](https://github.com/lucidrains/byol-pytorch/) and the other one [AI Summer](https://theaisummer.com/byol/).

### Helper Functions

In [None]:
# fall back to def_val only when val is None (other falsy values are kept)
def default(val, def_val):
    """Return ``val`` unless it is ``None``; in that case return ``def_val``."""
    if val is not None:
        return val
    return def_val

# collapse everything after the first axis into one axis
def flatten(t):
    """Reshape ``t`` to 2D, e.g. (N, C, H, W) -> (N, C*H*W); axis 0 is kept."""
    leading = t.shape[0]
    return t.reshape((leading, -1))

def singleton(cache_key):
    """Build a method decorator that caches the first result on the instance.

    The result is stored in the instance attribute named ``cache_key``; as long
    as that attribute is not ``None`` it is returned without calling the
    decorated method again. The attribute must already exist on the instance.
    """
    def inner_fn(fn):  # receives the method to decorate
        @wraps(fn)  # keep the name/docstring of the decorated method
        def wrapper(self, *args, **kwargs):
            cached = getattr(self, cache_key)
            if cached is None:
                # nothing cached yet: compute once and remember it
                cached = fn(self, *args, **kwargs)
                setattr(self, cache_key, cached)
            return cached
        return wrapper
    return inner_fn

# the device of a module is taken from its first parameter
def get_module_device(module):
    """Return the device of the first parameter yielded by ``module.parameters()``."""
    first_param = next(iter(module.parameters()))
    return first_param.device

# switch gradient tracking on or off for a whole model
def set_requires_grad(model, val):
    """Assign ``val`` to ``requires_grad`` of every parameter of ``model``."""
    for param in model.parameters():
        param.requires_grad = val

def MaybeSyncBatchnorm(is_distributed = None):
    """Pick the batch-norm class: ``nn.SyncBatchNorm`` or ``nn.BatchNorm1d``.

    When ``is_distributed`` is None the choice follows the process group:
    distributed means it is initialized with a world size above 1.
    """
    detected = dist.is_initialized() and dist.get_world_size() > 1
    use_sync = default(is_distributed, detected)
    if use_sync:
        return nn.SyncBatchNorm
    return nn.BatchNorm1d

### Loss Function

In [None]:
def loss_fn(x, y):
    """Return ``2 - 2 * cos(x, y)`` computed along the last dimension."""
    # scale both inputs to unit L2 norm along the last dimension
    x_unit = F.normalize(x, p=2, dim=-1)
    y_unit = F.normalize(y, p=2, dim=-1)
    # the dot product of unit vectors is their cosine similarity
    cosine = (x_unit * y_unit).sum(dim=-1)
    return 2 - 2 * cosine

### <font color=red>Augmentation Utils</font>

#### Deprecated Version

In [None]:
# augmentation utils
class RandomApply(nn.Module):
    """Apply ``fn`` to the input with probability ``p``, else pass it through."""

    def __init__(self, fn, p):
        super().__init__()
        self.fn = fn
        self.p = p

    def forward(self, x):
        # one random draw per call decides whether fn is skipped
        skip = random.random() > self.p
        return x if skip else self.fn(x)

class MultiChannelColorJitter():
    """Run one ``T.ColorJitter`` on every channel of an image separately."""

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        self.color_jitter = T.ColorJitter(brightness, contrast, saturation, hue)

    def __call__(self, image):
        # split(1) cuts the first axis into single-channel pieces; each piece
        # is jittered on its own and the pieces are joined again on that axis
        jittered = [self.color_jitter(piece) for piece in image.split(1)]
        return torch.cat(jittered, dim=0)

class MultiChannelRandomGrayscale:
    """Run ``T.RandomGrayscale(p)`` on every channel of an image separately."""

    def __init__(self, p=0.1):
        self.p = p

    def __call__(self, image):
        to_gray = T.RandomGrayscale(p=self.p)
        # single-channel pieces along the first axis, transformed one by one
        pieces = [to_gray(piece) for piece in image.split(1)]
        return torch.cat(pieces, dim=0)

class MultiChannelGaussianBlur:
    """Run ``T.GaussianBlur`` on every channel of an image separately."""

    def __init__(self, kernel_size, sigma=(0.1, 2.0)):
        self.kernel_size = kernel_size
        self.sigma = sigma

    def __call__(self, image):
        blur = T.GaussianBlur(self.kernel_size, self.sigma)
        # single-channel pieces along the first axis, blurred one by one
        pieces = [blur(piece) for piece in image.split(1)]
        return torch.cat(pieces, dim=0)

##### Deprecated nn.Module Version

In [None]:
# augmentation utils
class RandomApply(nn.Module):
    """Wrap ``fn`` so that it is only applied with probability ``p``."""

    def __init__(self, fn, p):
        super().__init__()
        self.fn = fn
        self.p = p

    def forward(self, x):
        # a draw at or below p applies fn; otherwise x is returned as is
        if random.random() <= self.p:
            return self.fn(x)
        return x

# each augmentation follow is:
    # 1. split the image into individual channels, my case is 168 channels (H, W, 168)
    # 2. apply the augmentation to each channel
    # 3. concatenate the augmented channels along the channel dimension
# the final output is the augmented image with the same shape (H, W, 168)

class MultiChannelColorJitter(nn.Module):
    """Channel-wise ``T.ColorJitter`` for a single image."""

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        super().__init__()
        self.color_jitter = T.ColorJitter(brightness, contrast, saturation, hue)

    def forward(self, image):
        jittered = []
        # one single-channel piece at a time along the first axis
        for piece in image.split(1):
            jittered.append(self.color_jitter(piece))
        # join the pieces again on the first axis
        return torch.cat(jittered, dim=0)

class MultiChannelRandomGrayscale(nn.Module):
    """Channel-wise ``T.RandomGrayscale(p)`` for a single image."""

    def __init__(self, p=0.1):
        super().__init__()
        self.p = p

    def forward(self, image):
        to_gray = T.RandomGrayscale(p=self.p)
        converted = []
        # one single-channel piece at a time along the first axis
        for piece in image.split(1):
            converted.append(to_gray(piece))
        return torch.cat(converted, dim=0)

class MultiChannelGaussianBlur(nn.Module):
    """Channel-wise ``T.GaussianBlur`` for a single image."""

    def __init__(self, kernel_size, sigma=(0.1, 2.0)):
        super().__init__()
        self.kernel_size = kernel_size
        self.sigma = sigma

    def forward(self, image):
        blur = T.GaussianBlur(self.kernel_size, self.sigma)
        blurred = []
        # one single-channel piece at a time along the first axis
        for piece in image.split(1):
            blurred.append(blur(piece))
        return torch.cat(blurred, dim=0)

class MultiChannelRandomHorizontalFlip(nn.Module):
    """Channel-wise ``T.RandomHorizontalFlip(p)`` for a single image."""

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, image):
        flip = T.RandomHorizontalFlip(p=self.p)
        flipped = []
        # every single-channel piece is passed to the transform on its own
        for piece in image.split(1):
            flipped.append(flip(piece))
        return torch.cat(flipped, dim=0)

class MultiChannelRandomResizedCrop(nn.Module):
    """Channel-wise ``T.RandomResizedCrop`` (antialiased) for a single image."""

    def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.)):
        super().__init__()
        self.size = size
        self.scale = scale
        self.ratio = ratio

    def forward(self, image):
        crop = T.RandomResizedCrop(self.size, scale=self.scale, ratio=self.ratio, antialias=True)
        cropped = []
        # every single-channel piece is passed to the transform on its own
        for piece in image.split(1):
            cropped.append(crop(piece))
        return torch.cat(cropped, dim=0)

class MultiChannelNormalize(nn.Module):
    """Normalize channel ``i`` of an image with ``mean[i]`` and ``std[i]``."""

    def __init__(self, mean, std):
        super().__init__()
        self.mean = mean
        self.std = std

    def forward(self, image):
        as_float = image.to(torch.float32)
        normalized = []
        # piece i along the first axis is normalized with the i-th statistics
        for i, piece in enumerate(as_float.split(1)):
            normalized.append(T.Normalize(self.mean[i], self.std[i])(piece))
        return torch.cat(normalized, dim=0).to(torch.float32)

#### Corrected

In [None]:
# augmentation utils
class RandomApply(nn.Module):
    """Stochastic wrapper: run ``fn`` with probability ``p``, else identity."""

    def __init__(self, fn, p):
        super().__init__()
        self.fn = fn
        self.p = p

    def forward(self, x):
        draw = random.random()
        # a draw above p leaves the input untouched
        if draw > self.p:
            return x
        transformed = self.fn(x)
        return transformed

# the input seems to be x: (N, C, H, W)
class MultiChannelColorJitter(nn.Module):
    """Apply one ``T.ColorJitter`` to every channel of every sample of a batch."""

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        super().__init__()
        self.color_jitter = T.ColorJitter(brightness, contrast, saturation, hue)

    def forward(self, x):
        # x := (N, C, H, W)
        out = torch.zeros_like(x)
        for idx, sample in enumerate(x):
            # jitter each single-channel piece, then rebuild the (C, H, W) sample
            jittered = [self.color_jitter(piece) for piece in sample.split(1)]
            out[idx] = torch.cat(jittered, dim=0)
        return out

class MultiChannelRandomGrayscale(nn.Module):
    """Apply ``T.RandomGrayscale(p)`` to every channel of every sample of a batch."""

    def __init__(self, p=0.1):
        super().__init__()
        self.p = p

    def forward(self, x):
        # x := (N, C, H, W)
        to_gray = T.RandomGrayscale(p=self.p)
        out = torch.zeros_like(x)
        for idx, sample in enumerate(x):
            # transform each single-channel piece, then rebuild the sample
            converted = [to_gray(piece) for piece in sample.split(1)]
            out[idx] = torch.cat(converted, dim=0)
        return out

class MultiChannelRandomHorizontalFlip(nn.Module):
    """Apply ``T.RandomHorizontalFlip(p)`` to every channel of every sample of a batch."""

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, x):
        # x := (N, C, H, W)
        flip = T.RandomHorizontalFlip(p=self.p)
        out = torch.zeros_like(x)
        for idx, sample in enumerate(x):
            # each single-channel piece is passed to the transform on its own
            flipped = [flip(piece) for piece in sample.split(1)]
            out[idx] = torch.cat(flipped, dim=0)
        return out

class MultiChannelGaussianBlur(nn.Module):
    """Apply ``T.GaussianBlur`` to every channel of every sample of a batch."""

    def __init__(self, kernel_size, sigma=(0.1, 2.0)):
        super().__init__()
        self.kernel_size = kernel_size
        self.sigma = sigma

    def forward(self, x):
        # x := (N, C, H, W)
        blur = T.GaussianBlur(self.kernel_size, self.sigma)
        out = torch.zeros_like(x)
        for idx, sample in enumerate(x):
            # each single-channel piece is passed to the transform on its own
            blurred = [blur(piece) for piece in sample.split(1)]
            out[idx] = torch.cat(blurred, dim=0)
        return out

class MultiChannelRandomResizedCrop(nn.Module):
    """Apply ``T.RandomResizedCrop`` to every channel of every sample of a batch.

    The output is assembled from the cropped channels, so its spatial size is
    the requested ``size``. Writing the crops into a buffer shaped like the
    input only worked when ``size`` equalled the input height and width.

    Args:
        size: Output size handed to ``T.RandomResizedCrop``.
        scale: Range of the crop area, handed to ``T.RandomResizedCrop``.
        ratio: Range of the crop aspect ratio, handed to ``T.RandomResizedCrop``.
    """

    def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.)):
        super().__init__()
        self.size = size
        self.scale = scale
        self.ratio = ratio

    def forward(self, x):
        """Crop a batch ``x`` := (N, C, H, W) and return (N, C, *size)."""
        if x.shape[0] == 0:
            # nothing to crop; an empty batch is returned as before
            return torch.zeros_like(x)

        crop = T.RandomResizedCrop(self.size, scale=self.scale, ratio=self.ratio, antialias=True)
        # NOTE(review): every channel is passed to the transform separately, so
        # each one presumably gets its own random crop window - confirm intended.
        samples = [
            torch.cat([crop(piece) for piece in sample.split(1)], dim=0)
            for sample in x
        ]
        return torch.stack(samples)

class MultiChannelNormalize(nn.Module):
    """Normalize channel ``c`` of a batch as ``(x - mean[c]) / std[c]``.

    The whole batch is normalized in one broadcast operation. Unlike the
    per-sample loop it replaces, the input tensor is never written to, and
    integer input yields a float32 result instead of failing.

    Args:
        mean: Per-channel means (tensor or array) with at least C entries.
        std: Per-channel standard deviations with at least C entries.
    """

    def __init__(self, mean, std):
        super().__init__()
        self.mean = mean
        self.std = std

    def forward(self, x):
        """Return the normalized copy of ``x`` := (N, C, H, W).

        Raises:
            IndexError: If fewer than C means or stds are available.
            ValueError: If one of the used stds is zero.
        """
        if not x.is_floating_point():
            x = x.to(torch.float32)  # out-of-place cast; the caller's tensor is untouched

        num_channels = x.shape[1]
        mean = torch.as_tensor(self.mean, dtype=x.dtype, device=x.device).reshape(-1)
        std = torch.as_tensor(self.std, dtype=x.dtype, device=x.device).reshape(-1)
        if mean.numel() < num_channels or std.numel() < num_channels:
            raise IndexError(
                f'need {num_channels} per-channel statistics, '
                f'got {mean.numel()} means and {std.numel()} stds'
            )
        # only the first C statistics are used, one per channel
        mean = mean[:num_channels].reshape(1, -1, 1, 1)
        std = std[:num_channels].reshape(1, -1, 1, 1)
        if (std == 0).any():
            raise ValueError('std evaluated to zero, leading to division by zero.')

        return (x - mean) / std

### Exponential Moving Average

In [None]:
# Exponential moving average
class EMA():
    """Blend an old value with a new one using the decay factor ``beta``."""

    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_average(self, old, new):
        """Return ``old * beta + (1 - beta) * new``, or ``new`` when ``old`` is None."""
        if old is not None:
            return old * self.beta + (1 - self.beta) * new
        # no previous value: the new one is the average
        return new


def update_moving_average(ema_updater, ma_model, current_model):
    """Move every parameter of ``ma_model`` towards ``current_model``.

    The parameters of both models are paired in iteration order; each
    ``ma_model`` parameter is replaced by
    ``ema_updater.update_average(old, new)``.
    """
    paired_params = zip(current_model.parameters(), ma_model.parameters())
    for source_param, averaged_param in paired_params:
        averaged_param.data = ema_updater.update_average(
            averaged_param.data, source_param.data
        )

### MLP

Class for projector and predictor.

In [None]:
def MLP(dim, projection_size, hidden_size=4096, sync_batchnorm=None):
    """Build the Linear -> BatchNorm -> ReLU -> Linear block (projector / predictor).

    The batch-norm class is chosen by ``MaybeSyncBatchnorm(sync_batchnorm)``.
    """
    norm_class = MaybeSyncBatchnorm(sync_batchnorm)
    layers = [
        nn.Linear(dim, hidden_size),
        norm_class(hidden_size),
        nn.ReLU(inplace=True),
        nn.Linear(hidden_size, projection_size),
    ]
    return nn.Sequential(*layers)

### NetWrapper

The purpose of this class is to wrap the model and the ema_model, which means,
make the model and the ema_model as the attributes of the class.

In [None]:
# a wrapper class for the base neural network
# it intercepts the output of one hidden layer through a forward hook
# and pipes it into the projector net

class NetWrapper(nn.Module):
    """Wrap ``net`` and expose a hidden-layer representation and its projection.

    Args:
        net: The base network.
        projection_size: Output size of the projector MLP.
        projection_hidden_size: Hidden size of the projector MLP.
        layer: Where to read the representation. ``-1`` uses the output of
            ``net`` itself; any other int indexes ``net.children()`` (the
            default ``-2`` is the second-to-last child); a str is looked up
            in ``net.named_modules()``.
        sync_batchnorm: Forwarded to ``MLP`` when the projector is created.
    """

    def __init__(self, net, projection_size, projection_hidden_size, layer=-2, sync_batchnorm=None):
        super().__init__()
        self.net = net
        self.layer = layer

        # the projector is created lazily, once the representation size is known
        self.projector = None
        self.projection_size = projection_size
        self.projection_hidden_size = projection_hidden_size

        self.sync_batchnorm = sync_batchnorm

        # hidden-layer outputs captured by the hook, keyed by device
        self.hidden = {}
        self.hook_registered = False

    def _find_layer(self):
        """Return the module selected by ``self.layer``, or None if it is not found."""
        # a string selects the layer by its qualified name
        if isinstance(self.layer, str):
            return dict(self.net.named_modules()).get(self.layer, None)
        # an integer selects the layer by its position among the direct children
        if isinstance(self.layer, int):
            return list(self.net.children())[self.layer]
        return None

    def _hook(self, _, inputs, output):
        """Forward hook: store the flattened layer output under the input's device."""
        device = inputs[0].device
        self.hidden[device] = flatten(output)

    def _register_hook(self):
        """Attach ``_hook`` to the selected layer (done once)."""
        layer = self._find_layer()
        assert layer is not None, f'hidden layer ({self.layer}) not found'
        layer.register_forward_hook(self._hook)
        self.hook_registered = True

    @singleton('projector') # create the projector only once
    def _get_projector(self, hidden):
        """Create the projector MLP sized for ``hidden`` (N, dim), matching its dtype/device."""
        _, dim = hidden.shape
        projector = MLP(dim, self.projection_size, self.projection_hidden_size, sync_batchnorm=self.sync_batchnorm)
        return projector.to(hidden)

    def get_representation(self, x):
        """Run ``net`` on ``x`` and return the selected representation."""
        if self.layer == -1:
            return self.net(x)

        # register the hook => it stores the hidden layer output during the forward pass
        if not self.hook_registered:
            self._register_hook()

        self.hidden.clear()
        _ = self.net(x)
        # .get() keeps the assertion below reachable: direct indexing raised a
        # KeyError first whenever the hook had not stored anything for this device
        hidden = self.hidden.get(x.device)
        self.hidden.clear()

        assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
        return hidden

    def forward(self, x, return_projection = True):
        """Return ``(projection, representation)``, or only the representation.

        Args:
            x: Input batch for ``net``.
            return_projection: When False, skip the projector and return the
                representation alone.
        """
        representation = self.get_representation(x)

        if not return_projection:
            return representation

        projector = self._get_projector(representation)
        projection = projector(representation)
        return projection, representation

### Main BYOL

In [None]:
class BYOL(nn.Module):
    """BYOL learner: online encoder + predictor trained against a target encoder.

    ``net`` is wrapped in a ``NetWrapper`` (the online encoder), followed by an
    ``MLP`` predictor. The target encoder is created on first use as a deep copy
    of the online encoder with ``set_requires_grad(..., False)`` applied.
    ``forward`` returns the BYOL loss for a batch, and
    ``update_moving_average`` runs the EMA update of the target encoder.

    Args:
        net: backbone network that provides the representation.
        image_size: side length used by the default random resized crop and by
            the mock batch pushed through the model at construction time.
        hidden_layer: layer name or index in ``net``, passed to ``NetWrapper``.
        projection_size: output width of the projector and predictor.
        projection_hidden_size: hidden width of the projector and predictor.
        augment_fn: first augmentation pipeline; resolved through ``default``
            against the pipeline built in ``__init__``.
        augment_fn2: second augmentation pipeline; resolved through ``default``
            against the first pipeline.
        moving_average_decay: value handed to the ``EMA`` updater.
        use_momentum: when falsy, the online encoder doubles as the target.
        sync_batchnorm: forwarded to ``NetWrapper``.
        in_channels: channel count of the mock batch used to instantiate the
            lazily created modules. Defaults to 168, the value that used to be
            hard-coded for this notebook's data.
    """

    def __init__(
            self,
            net,
            image_size,
            hidden_layer = -2,
            projection_size = 256,
            projection_hidden_size = 4096,
            augment_fn = None,
            augment_fn2 = None,
            moving_average_decay = 0.99,
            use_momentum = True,
            sync_batchnorm = None,
            in_channels = 168
    ):
        super().__init__()
        self.net = net

        # default SimCLR augmentation
        # (`mean` and `std` are module-level values defined elsewhere in the notebook)
        default_aug = torch.nn.Sequential(
            RandomApply(
                MultiChannelColorJitter(0.8, 0.8, 0.8, 0.2),
                p = 0.3
            ),
            MultiChannelRandomGrayscale(p=0.2),
            MultiChannelRandomHorizontalFlip(),
            RandomApply(
                MultiChannelGaussianBlur((3, 3), (1.0, 2.0)),
                p = 0.2
            ),
            MultiChannelRandomResizedCrop((image_size, image_size)),
            MultiChannelNormalize(
                mean=torch.tensor(mean),
                std=torch.tensor(std)),
        )

        self.augment1 = default(augment_fn, default_aug)
        self.augment2 = default(augment_fn2, self.augment1)

        self.online_encoder = NetWrapper(
            net,
            projection_size,
            projection_hidden_size,
            layer = hidden_layer,
            sync_batchnorm = sync_batchnorm
        )

        self.use_momentum = use_momentum
        # created on demand by _get_target_encoder
        self.target_encoder = None
        self.target_ema_updater = EMA(moving_average_decay)

        self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size)

        # get device of network and make wrapper same device
        device = get_module_device(net)
        self.to(device)

        # send a mock image tensor to instantiate singleton parameters; the
        # channel count is configurable (this notebook's data is (512, 512, 168))
        self.forward(torch.randn(2, in_channels, image_size, image_size, device=device))

    @singleton('target_encoder')
    def _get_target_encoder(self):
        """Return a deep copy of the online encoder with gradients switched off."""
        target_encoder = copy.deepcopy(self.online_encoder)
        set_requires_grad(target_encoder, False)
        return target_encoder

    def reset_moving_average(self):
        """Drop the current target encoder and reset the attribute to ``None``."""
        del self.target_encoder
        self.target_encoder = None

    def update_moving_average(self):
        """Run the EMA update of the target encoder from the online encoder.

        Raises:
            AssertionError: if momentum is disabled or the target encoder has
                not been created yet.
        """
        assert self.use_momentum, 'you do not need to update the moving average, since you have turned off momentum for the target encoder'
        assert self.target_encoder is not None, 'target encoder has not been created yet'
        update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder)

    def forward(
            self,
            x,
            return_embedding = False,
            return_projection = True
    ):
        """Compute the BYOL loss for ``x``, or return the online encoder output.

        Args:
            x: input batch; in training mode it must hold more than one sample.
            return_embedding: when truthy, skip the loss and return whatever
                the online encoder returns for ``x``.
            return_projection: forwarded to the online encoder when
                ``return_embedding`` is truthy.

        Returns:
            The mean of the two cross-view loss terms, or the online encoder
            output when ``return_embedding`` is truthy.
        """
        assert not (self.training and x.shape[0] == 1), 'you must have greater than 1 sample when training, due to the batchnorm in the projection layer'

        if return_embedding:
            return self.online_encoder(x, return_projection = return_projection)

        # two augmented views of the same batch, stacked along the batch axis
        image_one, image_two = self.augment1(x), self.augment2(x)
        images = torch.cat([image_one, image_two], dim=0)

        online_projections, _ = self.online_encoder(images)
        online_predictions = self.online_predictor(online_projections)

        # split the predictions back into the two views
        online_pred_one, online_pred_two = online_predictions.chunk(2, dim=0)

        with torch.no_grad():
            target_encoder = self._get_target_encoder() if self.use_momentum else self.online_encoder

            target_projections, _ = target_encoder(images)
            # one detach is enough; the extra per-chunk detach calls were redundant
            target_projections = target_projections.detach()

            target_proj_one, target_proj_two = target_projections.chunk(2, dim=0)

        # each online prediction is compared with the target projection of the other view
        loss_one = loss_fn(online_pred_one, target_proj_two)
        loss_two = loss_fn(online_pred_two, target_proj_one)

        loss = loss_one + loss_two
        return loss.mean()

## Trainer

This section contains the helper functions and the trainer class used to train the model.

### Helper Functions

In [None]:
def exists(v):
    """Return ``True`` for any value other than ``None``."""
    if v is None:
        return False
    return True

def cycle(dl):
    """Yield the items of ``dl`` endlessly, restarting it whenever it runs out.

    Args:
        dl: a re-iterable object such as a data loader or a list.

    Raises:
        ValueError: if a full pass over ``dl`` produces no items. Without this
            check an empty ``dl`` would make the generator spin forever
            without ever yielding.
    """
    while True:
        produced = False
        for batch in dl:
            produced = True
            yield batch
        if not produced:
            raise ValueError('cannot cycle over an empty iterable')

### Commented MockDataset

In [None]:
# Mock dataset kept for reference (commented out)
#class MockDataset(Dataset):
#    def __init__(self, image_size, length):
#        self.length = length
#        self.image_size = image_size

#    def __len__(self):
#        return self.length

#    def __getitem__(self, idx):
#        return torch.randn(168, self.image_size, self.image_size)

### Main Trainer from Official

In [None]:
class BYOLTrainer(Module):
    """Training driver that wraps ``net`` in ``BYOL`` and runs it with Accelerate.

    Calling the trainer (``trainer()``) runs ``num_train_steps`` optimisation
    steps. Every ``checkpoint_every`` steps the main process saves
    ``net.state_dict()`` into ``checkpoint_folder``.

    Args:
        net: backbone module to train.
        image_size: forwarded to ``BYOL``.
        hidden_layer: forwarded to ``BYOL``.
        learning_rate: passed to the optimizer as ``lr``.
        dataset: dataset wrapped in a shuffling ``DataLoader``.
        num_train_steps: number of optimisation steps per call.
        batch_size: ``DataLoader`` batch size.
        optimizer_klass: optimizer class, called with the BYOL parameters.
        checkpoint_every: step interval between checkpoints.
        checkpoint_folder: directory for checkpoints (created if missing).
        byol_kwargs: extra keyword arguments for ``BYOL``.
        optimizer_kwargs: extra keyword arguments for the optimizer.
        accelerator_kwargs: keyword arguments for ``Accelerator``.
    """

    @beartype
    def __init__(
        self,
        net: Module,
        *,
        image_size: int,
        hidden_layer: str,
        learning_rate: float,
        dataset: Dataset,
        num_train_steps: int,
        batch_size: int = 16,
        optimizer_klass = Adam,
        checkpoint_every: int = 1000,
        checkpoint_folder: str = './checkpoints',
        # NOTE: these dict defaults are shared between calls; they are only
        # unpacked with ** below and never mutated
        byol_kwargs: dict = dict(),
        optimizer_kwargs: dict = dict(),
        accelerator_kwargs: dict = dict(),
    ):
        super().__init__()
        self.accelerator = Accelerator(**accelerator_kwargs)

        # with more than one process, switch the backbone to synchronised batch norm
        if dist.is_initialized() and dist.get_world_size() > 1:
            net = SyncBatchNorm.convert_sync_batchnorm(net)

        self.net = net

        self.byol = BYOL(net, image_size = image_size, hidden_layer = hidden_layer, **byol_kwargs)

        self.optimizer = optimizer_klass(self.byol.parameters(), lr = learning_rate, **optimizer_kwargs)

        self.dataloader = DataLoader(dataset, shuffle = True, batch_size = batch_size)

        self.num_train_steps = num_train_steps

        self.checkpoint_every = checkpoint_every
        self.checkpoint_folder = Path(checkpoint_folder)
        self.checkpoint_folder.mkdir(exist_ok = True, parents = True)
        assert self.checkpoint_folder.is_dir()

        # prepare with accelerate
        (
            self.byol,
            self.optimizer,
            self.dataloader
        ) = self.accelerator.prepare(
            self.byol,
            self.optimizer,
            self.dataloader
        )

        # number of completed optimisation steps, stored as a module buffer
        self.register_buffer('step', torch.tensor(0))

    def wait(self):
        """Block until every process reaches this point."""
        return self.accelerator.wait_for_everyone()

    def print(self, msg):
        """Print ``msg`` through the accelerator."""
        return self.accelerator.print(msg)

    def forward(self):
        """Run ``num_train_steps`` training steps, checkpointing periodically.

        The step counter starts from the ``step`` buffer and is written back to
        it after every step, so the buffer reflects the number of completed steps.
        """
        step = self.step.item()
        data_it = cycle(self.dataloader)

        for _ in range(self.num_train_steps):
            images = next(data_it)

            with self.accelerator.autocast():
                loss = self.byol(images)
                self.accelerator.backward(loss)

            self.print(f'loss: {loss.item():.3f}')

            # apply the freshly computed gradients first, then clear them for
            # the next iteration; clearing before step() discarded the gradients
            self.optimizer.step()
            self.optimizer.zero_grad()

            self.wait()

            self.byol.update_moving_average()

            self.wait()

            if not (step % self.checkpoint_every) and self.accelerator.is_main_process:
                checkpoint_num = step // self.checkpoint_every
                checkpoint_path = self.checkpoint_folder / f'checkpoint.{checkpoint_num}.pt'
                torch.save(self.net.state_dict(), str(checkpoint_path))

            self.wait()

            step += 1
            # keep the registered buffer in sync with the local counter
            self.step.add_(1)

        self.print('training complete')

## Train

Two training paths follow: one uses the official trainer, the other is a self-implemented training loop.

### Official Trainer Version

In [None]:
# ResNet-50 backbone from timm, requested with pretrained weights and 168 input channels
resnet50 = timm.create_model('resnet50', pretrained=True, in_chans=168)

In [None]:
# shell command run from the notebook: invoke the Accelerate CLI `config` command with `default`
! accelerate config 'default'

In [None]:
#dataset = MockDataset(512, 10000)
# `dataset` must already be defined by an earlier cell; the former
# `dataset = dataset` self-assignment was a no-op and has been removed
trainer = BYOLTrainer(
    resnet50,
    dataset = dataset,
    image_size = 512,
    hidden_layer = 'global_pool',
    learning_rate = 3e-4,
    num_train_steps = 100_000,
    batch_size = 16,
    checkpoint_every = 1000     # improved model will be saved periodically to ./checkpoints folder
)

In [None]:
# NOTE(review): BYOLTrainer.__init__ already creates its own Accelerator and
# prepares its model, optimizer and dataloader -- confirm this second
# prepare() on the trainer itself is actually needed.
accelerator = Accelerator()
trainer = accelerator.prepare(trainer)
# calling the trainer runs its forward(), i.e. the whole training loop
trainer()

### Self-Implemented Train Version

In [None]:
# Pick the training device: CUDA when available, otherwise CPU (MPS is not checked here).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#### Without Early Stopping

In [None]:
# 168-channel ResNet-50 without pretrained weights, moved to the training device
resnet = timm.create_model('resnet50', pretrained=False, in_chans=168)
resnet = resnet.to(device)
# BYOL learner reading the representation from the 'global_pool' layer
byol = BYOL(resnet, image_size=512, hidden_layer='global_pool')
byol = byol.to(device)
#resnet = (timm.create_model('resnet50', pretrained=False, in_chans=168))
#byol = (BYOL(resnet, image_size=512, hidden_layer='global_pool'))

In [None]:
# mean training loss of every epoch, used for plotting afterwards
losses_per_epoch = []

opt = torch.optim.Adam(byol.parameters(), lr=3e-4)
EPOCHS = 100

for epoch in range(EPOCHS):
    # per-batch losses of the current epoch
    losses = []
    for images in train_dataloader:
        images = images.to(device)

        loss = byol(images)
        losses.append(loss.item())

        opt.zero_grad()
        loss.backward()
        opt.step()

        # move the target encoder towards the online encoder after each step
        byol.update_moving_average()

    epoch_loss = sum(losses)/len(losses)
    losses_per_epoch.append(epoch_loss)

    print(f"Epoch {epoch + 1} / {EPOCHS}, Loss: {epoch_loss}")

print("Training Complete")

# only the backbone weights are saved
torch.save(resnet.state_dict(), './resnet.pt')

In [None]:
# training-loss curve, one point per epoch, drawn on the current axes
ax = plt.gca()
ax.plot(losses_per_epoch, label='Training Loss')
ax.set(xlabel='Epoch', ylabel='Loss', title='Training Loss Over Epochs')
ax.legend()
plt.show()

#### Early Stopping Version

In [None]:
# fresh 168-channel ResNet-50 (no pretrained weights) for the early-stopping run
resnet = timm.create_model('resnet50', pretrained=False, in_chans=168)
resnet = resnet.to(device)
# BYOL learner reading the representation from the 'global_pool' layer
byol = BYOL(resnet, image_size=512, hidden_layer='global_pool')
byol = byol.to(device)
#resnet = (timm.create_model('resnet50', pretrained=False, in_chans=168))
#byol = (BYOL(resnet, image_size=512, hidden_layer='global_pool'))

##### Optimizer

Use LARS optimizer with a cosine decay learning rate schedule, without restarts, over 1000 epochs, with a warm-up period of 10 epochs.

In [None]:
# NOTE(review): `LARS` is used below, but the `from torchlars import LARS` line
# in the import cell is commented out -- this cell raises NameError unless LARS
# is imported somewhere else.
# NOTE(review): LR, GLOBAL_WEIGTH_DECAY and MAX_EPOCHS are not defined in this
# part of the notebook; presumably an earlier cell sets them -- confirm.
# plain SGD is the base optimizer handed to LARS
base_opt = optim.SGD(byol.parameters(), lr=LR, weight_decay=GLOBAL_WEIGTH_DECAY)
opt = LARS(optimizer=base_opt, eps=1e-8, trust_coef=0.001)
# cosine annealing over MAX_EPOCHS; no warm-up scheduler is created in this cell
scheduler = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=MAX_EPOCHS)

##### Main Train Code

In [None]:
# early-stopping state
best_validation_loss = float('inf')
best_epoch = 0
no_improvement_count = 0
patience = 10        # epochs without improvement tolerated before stopping
min_delta = 0.001    # minimum drop in validation loss that counts as an improvement
early_stop_epoch = 0

# mean training / validation loss of every epoch, used for plotting afterwards
losses_per_epoch = []
val_losses_per_epoch = []

# opt = torch.optim.Adam(byol.parameters(), lr=3e-4)
EPOCHS = MAX_EPOCHS

for epoch in range(EPOCHS):
    losses = []

    # back to training mode (the validation pass below switches to eval mode)
    byol.train()
    for images in train_dataloader:
        images = images.to(device)

        loss = byol(images)
        losses.append(loss.item())

        opt.zero_grad()
        loss.backward()
        opt.step()

        # move the target encoder towards the online encoder after each step
        byol.update_moving_average()

    epoch_loss = sum(losses)/len(losses)
    losses_per_epoch.append(epoch_loss)

    print(f"Epoch {epoch + 1} / {EPOCHS}, Loss: {epoch_loss}")

    # validation step
    val_losses = []
    # eval mode keeps validation batches from updating BatchNorm running
    # statistics; no_grad disables gradient tracking
    byol.eval()
    with torch.no_grad():
        for val_images in val_dataloader:
            val_images = val_images.to(device)

            val_loss = byol(val_images)
            val_losses.append(val_loss.item())

    val_loss_per_epoch = sum(val_losses)/len(val_losses)
    val_losses_per_epoch.append(val_loss_per_epoch)
    print(f'Epoch {epoch + 1} / {EPOCHS}, Val Loss: {val_loss_per_epoch}')

    # update lr scheduler
    scheduler.step()

    if val_loss_per_epoch < best_validation_loss - min_delta:
        best_validation_loss = val_loss_per_epoch
        best_epoch = epoch + 1 # record the best one for now
        no_improvement_count = 0
        # save the best model
        torch.save(resnet.state_dict(), '/content/drive/MyDrive/master-thesis/byol-results/resnet.pt')
    else:
        no_improvement_count += 1

    if no_improvement_count >= patience:
        print("Early stopping at epoch: ", epoch+1)
        early_stop_epoch = epoch + 1
        break

print("Training Complete, Best Epoch {}".format(best_epoch))

In [None]:
# training and validation loss curves, one point per epoch, on the current axes
ax = plt.gca()
ax.plot(losses_per_epoch, label='Training Loss')
ax.plot(val_losses_per_epoch, label='Validation Loss')
ax.set(xlabel='Epoch', ylabel='Loss', title='Training and Validation Loss')
ax.legend()
plt.show()

## Performance

In [None]:
# rebuild the backbone and restore the weights saved during training
resnet = (timm.create_model('resnet50', pretrained=False, in_chans=168)).to(device)
# map_location lets a checkpoint saved on a GPU load on whichever device is
# selected here (e.g. a CPU-only machine)
resnet.load_state_dict(torch.load("/content/drive/MyDrive/master-thesis/byol-results/resnet.pt", map_location=device))
byol = (BYOL(resnet, image_size=512, hidden_layer='global_pool')).to(device)

In [None]:
# one embedding tensor per training batch
all_embeddings = []

# inference only: eval mode stops BatchNorm statistics from being updated, and
# no_grad keeps the stored embeddings from holding on to autograd graphs
byol.eval()
with torch.no_grad():
    for images in train_dataloader:
        images = images.to(device)
        # with return_embedding=True the model returns (projection, embedding)
        _, embeddings = byol(images, return_embedding=True)
        all_embeddings.append(embeddings)

In [None]:
# NOTE(review): `first_batch` is not defined in this part of the notebook --
# presumably an earlier cell sets it; confirm.
first_batch = list(first_batch)

# compute the embedding vectors
# the backbone is built with in_chans=168 and lives on `device`, so the probe
# batch needs 168 channels and must be created on the same device
imgs = torch.randn(2, 168, 512, 512, device=device)
with torch.no_grad():
    projection, embedding = byol(imgs, return_embedding = True)

# list holding the embedding vector of each image in the batch
embeddings = list(embedding)

# shape of the embedding tensor (the list above has no .shape attribute)
embedding.shape