# BCG-Unet Note

Please ensure that you have used `poetry` or `pip` to install the dependencies.

> Run `poetry install` or `pip install -r requirements.txt` to install them.

> If you are on Colab, the dependencies are already satisfied.


## Download Dataset

We need to download the dataset if it does not exist on your local machine.


In [None]:
# Download dataset if not present
import hashlib
import os
import requests
from tqdm import tqdm


DATASET_PATH = os.path.normpath("./data/EyeClose1_noscan.mat")
DATASET_URL = "https://bcgunet-data.csie.cool/EyeClose1_noscan.mat"
# Expected SHA-256 digest of the dataset file; used to detect a bad download.
SHA = "d36a16ad45302971843a6d413eb833f407fd4146c3e32dfe23bdd8a17ccba2cb"


# `exist_ok=True` makes directory creation idempotent.
os.makedirs("data", exist_ok=True)
if not os.path.exists(DATASET_PATH):
    # Download under a temporary name: an interrupted or corrupted download
    # must never be mistaken for a valid dataset by the existence check above
    # on the next run.
    partial_path = DATASET_PATH + ".part"
    digest = hashlib.sha256()
    with requests.get(DATASET_URL, stream=True, timeout=30) as res:
        # Fail loudly on HTTP errors instead of saving an error page to disk.
        res.raise_for_status()
        # `tqdm.wrapattr` only wraps `write` for progress reporting; the
        # underlying file is closed by its own context manager.
        with open(partial_path, "wb") as raw_file, tqdm.wrapattr(
            raw_file,
            "write",
            miniters=1,
            desc="Downloading dataset",
            total=int(res.headers.get("content-length", default="0")),
        ) as file:
            for chunk in res.iter_content(chunk_size=4096):
                file.write(chunk)
                # Hash while streaming so the file is not re-read afterwards.
                digest.update(chunk)
    if digest.hexdigest() != SHA:
        os.remove(partial_path)
        # Raised explicitly (rather than via `assert`) so the check still
        # runs when Python is started with -O.
        raise AssertionError("Dataset is corrupted!")
    # Publish the verified file under its final name.
    os.replace(partial_path, DATASET_PATH)
    print("Dataset downloaded!")
else:
    print("Dataset already exists!")

## Load the Dataset

There is one ECG channel and 31 EEG (BCE, BCG-corrupted EEG) channels in the dataset.

Also, a 31-channel OBS processed EEG is provided.


In [None]:
import os
import numpy as np
import h5py

if not os.path.exists(DATASET_PATH):
    raise FileNotFoundError(DATASET_PATH)

# Open the file read-only inside a context manager so the HDF5 handle is
# released as soon as the arrays have been read.
with h5py.File(DATASET_PATH, "r") as f:
    # `np.array(...)` reads each dataset into memory, so the arrays stay
    # usable after the file is closed.
    ECG = np.array(f["ECG"]).flatten()
    # BCG Corrupted EEG; transposed so that later cells can index it as
    # BCE[channel, sample].
    BCE = np.array(f["EEG_before_bcg"]).T
    # OBS-processed EEG, transposed the same way as BCE.
    EEG_OBS = np.array(f["EEG"]).T
print("shape of ECG:     ", ECG.shape)
print("shape of EEG:     ", BCE.shape)
print("shape of EEG_OBS: ", EEG_OBS.shape)

<!-- Let's use the data between seconds 4 and 14 of those signals. -->


In [None]:
import matplotlib.pyplot as plt

SAMPLE_FREQ = 5000
START_SEC = 4
END_SEC = 14

# Keep references to the complete recordings under the *_all names.
ECG_all, BCE_all, EEG_OBS_all = ECG, BCE, EEG_OBS

# Cropping to the START_SEC..END_SEC window is currently disabled, so the
# working arrays are the full recordings. To re-enable it:
# ECG = ECG[SAMPLE_FREQ * START_SEC : SAMPLE_FREQ * END_SEC]
# BCE = BCE.T[SAMPLE_FREQ * START_SEC : SAMPLE_FREQ * END_SEC].T
# EEG_OBS = EEG_OBS.T[SAMPLE_FREQ * START_SEC : SAMPLE_FREQ * END_SEC].T

# One figure per signal; the 2-D arrays are transposed so that each channel
# is drawn as its own line.
for _title, _series in (("ECG", ECG), ("BCE", BCE.T), ("EEG_OBS", EEG_OBS.T)):
    plt.figure(figsize=(20, 5))
    plt.plot(_series)
    plt.title(_title)
    plt.show()

## Get the Baseline of each EEG Channel


In [None]:
from scipy.signal import butter, sosfilt


def butter_bandpass_filter(
    data: np.ndarray, fs: float, lowcut: float, highcut: float, order=5
):
    """Band-pass ``data`` with a digital Butterworth filter.

    The filter is designed as second-order sections and applied with
    ``sosfilt`` (a single forward pass along the last axis of ``data``).

    Args:
        data: Signal to filter.
        fs: Sampling frequency, in the same unit as the cut-offs.
        lowcut: Lower cut-off frequency.
        highcut: Upper cut-off frequency.
        order: Order passed to ``scipy.signal.butter``.

    Returns:
        The filtered signal, with the same shape as ``data``.
    """
    # `butter` expects cut-offs as fractions of the Nyquist frequency.
    nyquist = fs / 2
    band = [lowcut / nyquist, highcut / nyquist]
    sections = butter(order, band, btype="band", analog=False, output="sos")
    return sosfilt(sections, data)


# Band-pass every channel between 0.5 Hz and 0.4 * SAMPLE_FREQ.
# `butter_bandpass_filter` applies `sosfilt` along the last axis, so a 2-D
# array indexed [channel, sample] is filtered row by row in a single call.
# This removes the hard-coded channel count and the per-channel Python loop.
# The results carry the filter's floating-point output dtype.
BCE_filtered = butter_bandpass_filter(BCE, SAMPLE_FREQ, 0.5, SAMPLE_FREQ * 0.4)
BCE_all_filtered = butter_bandpass_filter(
    BCE_all, SAMPLE_FREQ, 0.5, SAMPLE_FREQ * 0.4
)

# The baseline is whatever the band-pass filter removed from each channel.
BCE_baseline = BCE - BCE_filtered
BCE_all_baseline = BCE_all - BCE_all_filtered

for _title, _series in (
    ("BCE_filtered", BCE_filtered.T),
    ("BCE_baseline", BCE_baseline.T),
):
    plt.figure(figsize=(20, 5))
    plt.plot(_series)
    plt.title(_title)
    plt.show()

In [None]:
# Store losses between different methods.
# Each call to `run` appends one tuple `(iter_loss, nfilter, nlayer)`, where
# `iter_loss` is the list of mean losses recorded every 50 training steps.
loss_data = []

## Construct the BCG-Unet


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class Conv(nn.Module):
    """Weight-normalised Conv1d => GroupNorm => ReLU.

    The convolution uses kernel size 3 with padding 1.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        num_groups: Number of groups for ``nn.GroupNorm``.
    """

    def __init__(self, in_channels, out_channels, num_groups=4):
        super().__init__()
        base_conv = nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1)
        # Weight normalisation is applied on top of the plain convolution.
        self.conv = nn.utils.weight_norm(base_conv)
        self.group_norm = nn.GroupNorm(num_groups=num_groups, num_channels=out_channels)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply convolution, group normalisation and ReLU, in that order."""
        return self.activation(self.group_norm(self.conv(x)))


class DoubleConv(nn.Module):
    """Two ``Conv`` blocks applied back to back.

    Args:
        in_channels: Channels entering the first ``Conv``.
        out_channels: Channels leaving the second ``Conv``.
        mid_channels: Channels between the two blocks; falls back to
            ``out_channels`` when omitted (or otherwise falsy).
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        mid_channels = mid_channels or out_channels
        stages = [
            Conv(in_channels, mid_channels),
            Conv(mid_channels, out_channels),
        ]
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both convolution blocks."""
        out = self.double_conv(x)
        return out


class Down(nn.Module):
    """Encoder stage: ``MaxPool1d(2)`` followed by a ``DoubleConv``.

    Args:
        in_channels: Channels entering the stage.
        out_channels: Channels produced by the stage.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool1d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        """Pool ``x`` by a factor of two, then convolve."""
        out = self.maxpool_conv(x)
        return out


class Up(nn.Module):
    """Decoder stage: upsample, merge with a skip connection, then ``DoubleConv``.

    Args:
        in_channels: Channel count of the concatenated tensor fed to the
            ``DoubleConv`` (skip channels plus upsampled channels).
        out_channels: Channels produced by the stage.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.up = nn.Upsample(scale_factor=2, mode="linear", align_corners=True)
        self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to the length of ``x2`` and fuse the two.

        Args:
            x1: Tensor coming from the deeper stage.
            x2: Skip-connection tensor from the matching encoder stage.
        """
        upsampled = self.up(x1)
        # Dimension 2 is the one being resampled; pad it so both tensors
        # line up, putting the odd sample (if any) on the right.
        gap = x2.size(2) - upsampled.size(2)
        left = gap // 2
        upsampled = F.pad(upsampled, [left, gap - left])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)


class OutConv(nn.Module):
    """Output head: a kernel-size-1 ``Conv1d`` mapping features to outputs.

    Args:
        in_channels: Number of incoming feature channels.
        out_channels: Number of output channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Project ``x`` onto the output channels."""
        projected = self.conv(x)
        return projected


class UNet1d(nn.Module):
    """1-D U-Net built from ``nlayer`` ``Down`` stages and ``nlayer`` ``Up`` stages.

    Sub-modules are registered as ``inc``, ``down1..down{nlayer}``,
    ``up1..up{nlayer}`` and ``outc``.

    Args:
        n_channels: Channels of the input signal.
        n_classes: Channels of the output.
        nfilter: Channel width after the input convolution; it doubles at
            every encoder stage except the last one.
        nlayer: Number of encoder (and decoder) stages.
    """

    def __init__(self, n_channels, n_classes, nfilter=24, nlayer=4):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.nfilter = nfilter
        self.nlayer = nlayer

        def width(level):
            # Channel width at a given depth of the network.
            return nfilter * 2**level

        self.inc = DoubleConv(n_channels, nfilter)

        # Encoder: every stage doubles the width, except the deepest one,
        # which keeps it.
        for depth in range(1, nlayer):
            self.add_module(f"down{depth}", Down(width(depth - 1), width(depth)))
        self.add_module(
            f"down{nlayer}", Down(width(nlayer - 1), width(nlayer - 1))
        )

        # Decoder: each stage receives the concatenation of the skip tensor
        # and the upsampled tensor; the output width never drops below
        # `nfilter`.
        for step in range(nlayer):
            remaining = nlayer - step
            self.add_module(
                f"up{step + 1}",
                Up(width(remaining), width(max(remaining - 2, 0))),
            )

        self.outc = OutConv(nfilter, n_classes)

    def forward(self, x):
        """Run the encoder, then the decoder with skip connections, then the head."""
        x = self.inc(x)

        # Remember the input of every encoder stage for the skip connections.
        skips = []
        for depth in range(1, self.nlayer + 1):
            skips.append(x)
            x = getattr(self, f"down{depth}")(x)

        # Skips are consumed deepest-first.
        for step in range(1, self.nlayer + 1):
            x = getattr(self, f"up{step}")(x, skips.pop())

        return self.outc(x)


# Hyper-parameters shared by the training and prediction cells.
EEG_CHANNEL = 31  # number of output channels of the network
LEARNING_RATE = 1e-3  # passed to the scheduler as its peak learning rate
ITERATION_STEPS = 5000  # number of training windows / optimizer steps
WINDOW = 2 * SAMPLE_FREQ  # length, in samples, of one training window

torch.cuda.empty_cache()
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Build the network with its default depth and move it to the chosen device.
NET = UNet1d(n_channels=1, n_classes=EEG_CHANNEL, nfilter=8)
NET = NET.to(device)

print(NET)

## Let's Train the BCG-Unet


In [None]:
def run(nfilter=8, nlayer=5):
    """Train a fresh ``UNet1d`` to predict the band-passed BCE channels from ECG.

    At every step a random window of ``WINDOW`` samples is cut from the
    global ``ECG`` and ``BCE_filtered`` arrays; the network is fed the ECG
    window and optimised with MSE against the BCE window. The mean loss of
    each group of 50 steps is recorded, appended to the global ``loss_data``
    as ``(iter_loss, nfilter, nlayer)`` and plotted.

    Args:
        nfilter: ``nfilter`` argument of ``UNet1d``.
        nlayer: ``nlayer`` argument of ``UNet1d``.

    Returns:
        The trained network. Earlier versions returned nothing, which left
        callers with no access to the trained weights.
    """
    torch.cuda.empty_cache()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    net = UNet1d(
        n_channels=1, n_classes=EEG_CHANNEL, nfilter=nfilter, nlayer=nlayer
    ).to(device)

    optimizer = torch.optim.Adam(net.parameters())
    optimizer.zero_grad()
    maxlen = ECG.size

    # The scheduler drives the learning rate, peaking at LEARNING_RATE.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=LEARNING_RATE, total_steps=ITERATION_STEPS
    )

    # Random window start positions, chosen so every window fits the signal.
    picked_idx = (
        np.random.random_sample(ITERATION_STEPS) * (maxlen - WINDOW)
    ).astype(int)

    iter_loss = []
    loss_list = []

    pbar = tqdm(picked_idx)
    for count, idx in enumerate(pbar, start=1):
        ECG_batch = ECG[idx : idx + WINDOW]
        BCE_batch = BCE_filtered[:, idx : idx + WINDOW]
        # Add the batch (and, for ECG, channel) dimensions the network expects.
        ECG_data = torch.from_numpy(ECG_batch[None, ...][None, ...]).float().to(device)
        BCE_data = torch.from_numpy(BCE_batch[None, ...]).float().to(device)

        logits = net(ECG_data)
        loss = nn.functional.mse_loss(logits, BCE_data)
        loss_list.append(loss.item())

        loss.backward()  # Accumulate the gradients
        optimizer.step()  # Update network weights according to the optimizer
        optimizer.zero_grad()  # Reset the gradients
        scheduler.step()

        # Report and record the mean loss of the last 50 steps.
        if count % 50 == 0:
            pbar.set_description(
                f"Loss {np.mean(loss_list):.3f}, lr: {optimizer.param_groups[0]['lr']:.5f}"
            )
            iter_loss.append(np.mean(loss_list))
            loss_list = []

    loss_data.append((iter_loss, nfilter, nlayer))

    plt.figure(figsize=(20, 5))
    plt.plot(iter_loss)
    plt.title(f"Loss ({nfilter} filters, {nlayer} layers)")
    plt.show()

    return net


# Rebind the global NET to the trained model; otherwise later cells would run
# their predictions with the untrained network created before training.
NET = run(nfilter=8, nlayer=4)

plt.figure(figsize=(40, 10))
for loss in loss_data:
    plt.plot(loss[0], label=f"{loss[1]} filters, {loss[2]} layers")
plt.title(f"Losses of {len(loss_data)} runs")
# Without a legend the per-run labels set above are never displayed.
plt.legend()
plt.show()

## Let's Make a Prediction with BCG-Unet


In [None]:
# Feed the whole ECG recording as a single batch with a single channel.
ECG_data = torch.from_numpy(ECG_all[None, ...][None, ...]).float().to(device)
BCE_data = torch.from_numpy(BCE_all_filtered[None, ...]).float().to(device)

# Dimension 0 is the batch; the channel count is dimension 1.
print("ECG Channel: ", ECG_data.shape[1])

with torch.no_grad():
    logits = NET(ECG_data)
    pred_loss = nn.functional.mse_loss(logits, BCE_data).item()

# Drop the batch dimension to get one predicted BCG trace per EEG channel.
BCG_pred = logits.cpu().detach().numpy()[0, ...]
# The network is compared against the band-passed signal, so the prediction
# is subtracted from the band-passed signal and the baseline removed by the
# filter is then added back exactly once. (Starting from `BCE_all`, which
# already contains the baseline, would add it twice.)
EEG_pred = BCE_all_filtered - BCG_pred + BCE_all_baseline


print("BCG Channel: ", BCG_pred.shape[0])

print("prediction loss: ", pred_loss)

for _title, _series in (
    ("BCG_pred (channel:" + str(BCG_pred.shape[0]) + ")", BCG_pred.T),
    ("EEG_pred", EEG_pred.T),
    ("EEG_OBS", EEG_OBS_all.T),
):
    plt.figure(figsize=(20, 5))
    plt.plot(_series)
    plt.title(_title)
    plt.show()