In [105]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from sklearn import preprocessing
from diffusers import UNet1DModel, DDPMScheduler, DDIMScheduler
from tqdm.auto import tqdm
from torch.utils.data import DataLoader, Dataset
from torch.nn import functional as F

In [106]:
def get_rms(records, multi_channels):
    """Return the mean of the squared samples of ``records``.

    Despite the name, no square root is taken: the value is the mean square
    (signal power), not the root mean square.

    Parameters
    ----------
    records : numpy.ndarray
        With ``multi_channels == 0`` a single record; with
        ``multi_channels == 1`` a collection of records indexed along the
        first axis.
    multi_channels : int
        ``0`` -> sum of squares of ``records`` divided by ``len(records)``.
        ``1`` -> the same quantity computed per record and averaged over
        the records.

    Returns
    -------
    float or None
        The mean square; ``None`` when ``multi_channels`` is neither 0 nor 1.
    """
    if multi_channels == 1:
        n = records.shape[0]
        total = 0
        for i in range(n):
            # The division must happen after the sum. Dividing the Python
            # list ``[records[i]**2]`` by an int raises TypeError.
            total += np.sum([records[i]**2]) / len(records[i])
        return total / n

    if multi_channels == 0:
        return np.sum([records**2]) / len(records)

    return None


def snr(signal, noisy):
    """Express the ratio ``signal / noisy`` in decibels (10 * log10)."""
    ratio = signal / noisy
    return 10 * np.log10(ratio)

In [107]:
def random_signal(signal, comb):
    """Stack ``comb`` independently row-shuffled copies of a 2-D array.

    Each copy is ``signal`` with its rows reordered by a fresh
    ``np.random.permutation``; the result has shape
    ``(comb, signal.shape[0], signal.shape[1])``.
    """
    n_rows, n_cols = signal.shape[0], signal.shape[1]
    shuffled_copies = [
        signal[np.random.permutation(n_rows), :].reshape(n_rows, n_cols)
        for _ in range(comb)
    ]
    return np.array(shuffled_copies)

In [108]:
def _mix_at_snr(clean_epochs, noise_epochs, snr_amplitude):
    """Standardise each epoch pair and add the noise at the requested level.

    ``snr_amplitude[k]`` is the square root of the linear power ratio for
    epoch ``k``. Returns ``(noisy, clean)`` as arrays with one row per epoch,
    where ``clean`` holds the standardised clean epochs.
    """
    noisy_rows = []
    clean_rows = []
    for k in range(clean_epochs.shape[0]):
        # preprocessing.scale returns a standardised copy of the epoch.
        noise = preprocessing.scale(noise_epochs[k])
        clean = preprocessing.scale(clean_epochs[k])

        alpha = get_rms(clean, 0) / (get_rms(noise, 0) * snr_amplitude[k])
        noisy_rows.append(clean + noise * alpha)
        clean_rows.append(clean)
    return np.array(noisy_rows), np.array(clean_rows)


def prepare_data(comb, *, eeg_path='./data/EEG_all_epochs.npy',
                 noise_path='./data/EMG_all_epochs.npy',
                 snr_db_range=(-7.0, 3.0), n_test_snr=11,
                 train_frac=0.9, seed=None):
    """Build noisy/clean training and test sets from two epoch files.

    Parameters
    ----------
    comb : int
        Number of independently reshuffled copies of the train and test
        epochs that are stacked (multiplies the number of samples).
    eeg_path, noise_path : str
        ``.npy`` files holding the clean and the noise epochs, one epoch per
        row. Both files are expected to have the same epoch length.
    snr_db_range : tuple of float
        ``(low, high)`` SNR in dB. Training SNRs are drawn uniformly from the
        range; test SNRs are ``n_test_snr`` evenly spaced levels across it.
    n_test_snr : int
        Number of test SNR levels (first axis of the test arrays).
    train_frac : float
        Fraction of the matched epochs used for training.
    seed : int or None
        When given, seeds NumPy's global RNG so the shuffles and the drawn
        SNRs are reproducible.

    Returns
    -------
    list
        ``[X_train, y_train, X_test, y_test]`` where the train arrays have
        shape ``(N, 1, t)`` and the test arrays ``(n_test_snr, M, 1, t)``;
        ``X_*`` are the noisy mixtures and ``y_*`` the standardised clean
        epochs.
    """
    if seed is not None:
        np.random.seed(seed)

    eeg_data = np.load(eeg_path)
    noise_data = np.load(noise_path)

    eeg_random = np.squeeze(random_signal(signal=eeg_data, comb=1))
    noise_random = np.squeeze(random_signal(signal=noise_data, comb=1))

    # Match the number of clean epochs to the number of noise epochs.
    reuse_num = noise_random.shape[0] - eeg_random.shape[0]
    if reuse_num >= 0:
        # Too few clean epochs: prepend re-used ones. They stay at the front,
        # so the tail used for the test split only holds epochs that were
        # not duplicated.
        eeg_reuse = eeg_random[0: reuse_num, :]
        eeg_random = np.vstack([eeg_reuse, eeg_random])
    else:
        # More clean than noise epochs: crop. A negative slice bound would
        # otherwise leave the two sets with different lengths.
        eeg_random = eeg_random[0: noise_random.shape[0], :]
    print(f'EEG shape after crop and resuse to match EMG samples: {eeg_random.shape[0]}')

    t = noise_random.shape[1]
    train_num = round(eeg_random.shape[0] * train_frac)
    test_num = round(eeg_random.shape[0] - train_num)

    train_eeg = eeg_random[0: train_num, :]
    test_eeg = eeg_random[train_num: train_num + test_num, :]

    train_noise = noise_random[0: train_num, :]
    test_noise = noise_random[train_num: train_num + test_num, :]

    # Augment: ``comb`` reshuffled copies, flattened back to (epochs, t).
    # Clean and noise are shuffled independently, which creates new pairings.
    EEG_train = random_signal(signal=train_eeg, comb=comb).reshape(comb * train_eeg.shape[0], t)
    NOISE_train = random_signal(signal=train_noise, comb=comb).reshape(comb * train_noise.shape[0], t)

    EEG_test = random_signal(signal=test_eeg, comb=comb).reshape(comb * test_eeg.shape[0], t)
    NOISE_test = random_signal(signal=test_noise, comb=comb).reshape(comb * test_noise.shape[0], t)

    print(f"train data clean shape: {EEG_train.shape}")
    print(f"train data noise shape: {NOISE_train.shape}")

    snr_low, snr_high = snr_db_range

    # Training: one random SNR per epoch.
    SNR_train_dB = np.random.uniform(snr_low, snr_high, (EEG_train.shape[0]))
    print(SNR_train_dB.shape)
    SNR_train = np.sqrt(10**(0.1*(SNR_train_dB)))
    X_train, y_train = _mix_at_snr(EEG_train, NOISE_train, SNR_train)

    # Test: the same epochs mixed at every fixed SNR level.
    SNR_test_dB = np.linspace(snr_low, snr_high, num=n_test_snr)
    SNR_test = np.sqrt(10 ** (0.1 * SNR_test_dB))

    all_sn_test = []
    all_eeg_test = []
    for level in SNR_test:
        per_epoch_level = np.full(EEG_test.shape[0], level)
        sn_test, eeg_test = _mix_at_snr(EEG_test, NOISE_test, per_epoch_level)
        all_sn_test.append(sn_test)
        all_eeg_test.append(eeg_test)

    X_test = np.array(all_sn_test)
    y_test = np.array(all_eeg_test)

    # Insert the channel axis expected by 1-D convolutional models.
    X_train = np.expand_dims(X_train, axis=1)
    y_train = np.expand_dims(y_train, axis=1)

    X_test = np.expand_dims(X_test, axis=2)
    y_test = np.expand_dims(y_test, axis=2)

    print(X_train.shape, y_train.shape)
    print(X_test.shape, y_test.shape)

    return [X_train, y_train, X_test, y_test]

In [109]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

Using device: cuda


In [110]:
# Build the datasets with 11 reshuffled copies of the train/test epochs.
X_train, y_train, X_test, y_test = prepare_data(comb=11)

EEG shape after crop and resuse to match EMG samples: 5598
train data clean shape: (55418, 512)
train data noise shape: (55418, 512)
(55418,)
(55418, 1, 512) (55418, 1, 512)
(11, 6160, 1, 512) (11, 6160, 1, 512)


In [111]:
from sklearn.model_selection import train_test_split

# Hold out 10% of the training rows as a validation set (fixed random_state).
# NOTE(review): prepare_data(11) builds X_train from 11 reshuffled copies of
# the same clean epochs, so a plain row-wise split can place copies of one
# clean epoch on both sides -- the validation loss may look optimistic.
# NOTE(review): this cell overwrites X_train / y_train, so re-running it
# splits the already reduced training set again.
X_train, X_val, y_train, y_val = train_test_split(
    X_train, y_train, test_size = 0.1, random_state=42
)

In [112]:
# Cast every train/validation array to float32 before handing it to torch.
X_train, y_train, X_val, y_val = (
    split.astype(np.float32) for split in (X_train, y_train, X_val, y_val)
)

In [113]:
# Sanity check: shapes of the validation inputs and targets.
print(X_val.shape, y_val.shape)

(5542, 1, 512) (5542, 1, 512)


In [114]:
# Wrap the NumPy arrays as torch tensors. No .to(device) happens here; the
# training loop moves each batch to the device itself.
X_train = torch.from_numpy(X_train)
y_train = torch.from_numpy(y_train)
X_val = torch.from_numpy(X_val)
y_val = torch.from_numpy(y_val)

# Report the container type and shape of each split.
print(type(X_train), type(y_train))
print(X_train.shape)
print(type(X_val), type(y_val))
print(X_val.shape)

<class 'torch.Tensor'> <class 'torch.Tensor'>
torch.Size([49876, 1, 512])
<class 'torch.Tensor'> <class 'torch.Tensor'>
torch.Size([5542, 1, 512])


In [115]:
class EEgDataSet(Dataset):
    """Dataset of paired (noisy, clean) signals held as float32 tensors.

    Parameters
    ----------
    X_noisy : torch.Tensor
        Noisy inputs, indexed along the first dimension.
    y_clean : torch.Tensor
        Clean targets; must have the same number of samples as ``X_noisy``.

    Raises
    ------
    ValueError
        If the two tensors disagree on the number of samples.
    """

    def __init__(self, X_noisy: torch.Tensor, y_clean: torch.Tensor):
        if X_noisy.shape[0] != y_clean.shape[0]:
            # Without this check a mismatch only surfaces later, as an
            # IndexError in the middle of an epoch.
            raise ValueError(
                f"X_noisy has {X_noisy.shape[0]} samples but y_clean has {y_clean.shape[0]}"
            )
        self.xn = X_noisy.float()
        self.yn = y_clean.float()

    def __len__(self) -> int:
        """Number of (noisy, clean) pairs."""
        return self.xn.shape[0]

    def __getitem__(self, i):
        """Return the ``i``-th ``(noisy, clean)`` pair."""
        return self.xn[i], self.yn[i]
    

In [None]:
# 1-D U-Net used as the noise predictor of the diffusion model.
# in_channels=2 matches the input built in train(): the noised clean signal
# concatenated with the 1-channel conditioner output. out_channels=1 is the
# predicted noise.
# NOTE(review): in diffusers' 1-D U-Net blocks the plain DownBlock1D /
# UpBlock1D variants may not consume the timestep embedding (the *NoSkip
# variants concatenate it) -- confirm the network is actually conditioned on t.
model = UNet1DModel(
    sample_size=512,
    in_channels=2,
    out_channels=1,
    layers_per_block=3,
    block_out_channels=(64,128,256),
    down_block_types=(
        "DownBlock1D",
        "DownBlock1D",
        "DownBlock1D",
    ),
    up_block_types=(
        "UpBlock1D",
        "UpBlock1D",
        "UpBlock1D",
    ),
)
model.to(device)
# Bare attribute access: the cell displays the bound method's repr (which
# includes the module tree); parameters() is not called.
model.parameters

<bound method Module.parameters of UNet1DModel(
  (time_proj): GaussianFourierProjection()
  (down_blocks): ModuleList(
    (0): DownBlock1D(
      (down): Downsample1d()
      (resnets): ModuleList(
        (0): ResConvBlock(
          (conv_skip): Conv1d(2, 32, kernel_size=(1,), stride=(1,), bias=False)
          (conv_1): Conv1d(2, 32, kernel_size=(5,), stride=(1,), padding=(2,))
          (group_norm_1): GroupNorm(1, 32, eps=1e-05, affine=True)
          (gelu_1): GELU(approximate='none')
          (conv_2): Conv1d(32, 32, kernel_size=(5,), stride=(1,), padding=(2,))
          (group_norm_2): GroupNorm(1, 32, eps=1e-05, affine=True)
          (gelu_2): GELU(approximate='none')
        )
        (1-2): 2 x ResConvBlock(
          (conv_1): Conv1d(32, 32, kernel_size=(5,), stride=(1,), padding=(2,))
          (group_norm_1): GroupNorm(1, 32, eps=1e-05, affine=True)
          (gelu_1): GELU(approximate='none')
          (conv_2): Conv1d(32, 32, kernel_size=(5,), stride=(1,), padding=(

In [117]:
def multiscale_cond(x):
    """Stack ``x`` with two smoothed versions of itself along dim 1.

    For each factor in (2, 4) the input is average-pooled with that kernel
    and stride (``ceil_mode=True``) and linearly interpolated back to the
    original length. The result is ``cat([x, smooth_2, smooth_4], dim=1)``.
    """
    length = x.shape[-1]
    scales = [x]
    for factor in (2, 4):
        pooled = F.avg_pool1d(x, factor, factor, ceil_mode=True)
        restored = F.interpolate(pooled, size=length, mode="linear", align_corners=False)
        scales.append(restored)
    return torch.cat(scales, dim=1)

In [118]:
class Conditioner(nn.Module):
    """Small 1-D convolutional stack mapping ``in_channels`` to ``out_channels``.

    Two Conv1d -> GroupNorm(8 groups) -> SiLU stages are followed by a final
    Conv1d. Every convolution is padded so the signal length is preserved.
    ``hidden`` must be divisible by the 8 normalisation groups.
    """

    def __init__(self, in_channels=3, out_channels=1, hidden=64):
        super().__init__()
        # Layers are created in network order so the indices inside
        # ``self.net`` (and therefore the state-dict keys) stay 0..6.
        stem = [
            nn.Conv1d(in_channels, hidden, kernel_size=7, padding=3, bias=False),
            nn.GroupNorm(8, hidden),
            nn.SiLU(),
        ]
        middle = [
            nn.Conv1d(hidden, hidden, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(8, hidden),
            nn.SiLU(),
        ]
        head = [nn.Conv1d(hidden, out_channels, kernel_size=3, padding=1)]
        self.net = nn.Sequential(*stem, *middle, *head)

    def forward(self, x):
        """Apply the convolutional stack to ``x``."""
        features = self.net(x)
        return features

# Conditioner with its defaults (3 input channels -> 1 output channel),
# placed on the selected device.
cond_net = Conditioner().to(device)

In [119]:
# Training noise schedule: 400 diffusion steps with a linear beta schedule.
scheduler = DDPMScheduler(
    num_train_timesteps=400,
    beta_schedule="linear"
)

# Alternative kept for reference (1000 training steps); currently disabled.
# scheduler = DDPMScheduler(
#     num_train_timesteps=1000,
#     beta_schedule="linear"
# )

In [120]:
from torch import amp
from torch.optim.lr_scheduler import CosineAnnealingLR
import copy

def train(model, scheduler, X_train, y_train, X_val, y_val,
           *, epochs=10, batch_size=512, lr=2e-4,
           wd=1e-5, grad_clip=1.0, cond_net):
    """Train a conditional noise-prediction model and return its EMA copy.

    Each step draws a random timestep and Gaussian noise per sample, noises
    the clean signal with ``scheduler.add_noise`` and asks ``model`` to
    predict that noise from the noised signal concatenated with
    ``cond_net(multiscale_cond(x_noisy))``. The loss is the MSE between the
    predicted and the true noise.

    Parameters
    ----------
    model : torch.nn.Module
        Noise predictor; ``model(x, t)`` must return an object with a
        ``.sample`` attribute. Trained in place.
    scheduler
        Object providing ``config.num_train_timesteps`` and ``add_noise``.
    X_train, y_train, X_val, y_val : torch.Tensor
        Noisy inputs / clean targets for training and validation.
    epochs : int
        Number of epochs. The first 5 use a linear learning-rate warm-up;
        the remaining ones follow cosine annealing down to 1e-5.
    batch_size : int
        Batch size of both loaders.
    lr, wd : float
        Peak learning rate and weight decay of AdamW.
    grad_clip : float or None
        Max gradient norm; clipping is skipped when falsy.
    cond_net : torch.nn.Module
        Conditioning network, trained jointly with ``model``.

    Returns
    -------
    torch.nn.Module
        Exponential moving average (decay 0.999) of ``model``'s parameters.
        Only ``model`` is averaged; ``cond_net`` keeps its live weights.

    Raises
    ------
    ValueError
        If the training set holds fewer samples than ``batch_size`` (the
        training loader drops incomplete batches, so it would be empty).
    """
    # Pinned host memory only helps (and only avoids a warning) on CUDA.
    pin = str(device).startswith("cuda")

    train_dl = DataLoader(EEgDataSet(X_train, y_train),
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True,
                          pin_memory=pin)
    if len(train_dl) == 0:
        raise ValueError(
            "training set is smaller than batch_size; no full batch to train on"
        )

    # Keep the last partial batch for validation. Dropping it discarded
    # samples and left nothing to average when the validation set was
    # smaller than batch_size.
    val_dl = DataLoader(EEgDataSet(X_val, y_val),
                        batch_size=batch_size,
                        shuffle=False,
                        drop_last=False,
                        pin_memory=pin)

    trainable = list(model.parameters()) + list(cond_net.parameters())
    optimizer = torch.optim.AdamW(trainable, lr=lr, weight_decay=wd)
    scaler = amp.GradScaler(device=device)

    ema_decay = 0.999
    ema_model = copy.deepcopy(model).to(device).eval()

    @torch.no_grad()
    def ema_update():
        # ema <- decay * ema + (1 - decay) * current
        for p, q in zip(model.parameters(), ema_model.parameters()):
            q.data.mul_(ema_decay).add_(p.data, alpha=1 - ema_decay)

    def noise_prediction_loss(net, x_noisy, x_clean):
        # Shared by the training and validation passes; ``net`` is the live
        # model during training and the EMA copy during validation.
        batch = x_clean.size(0)
        t = torch.randint(0, scheduler.config.num_train_timesteps, (batch,), device=device).long()
        noise = torch.randn_like(x_clean)
        x_t = scheduler.add_noise(x_clean, noise, t)
        cond = cond_net(multiscale_cond(x_noisy))
        x_in = torch.cat([x_t, cond], dim=1)
        with amp.autocast(device_type=device, dtype=torch.float16):
            pred_noise = net(x_in, t).sample
            loss = F.mse_loss(pred_noise, noise)
        return loss

    warmup_epochs = 5
    # T_max must stay positive even when epochs <= warmup_epochs.
    lr_sched = CosineAnnealingLR(optimizer, T_max=max(1, epochs - warmup_epochs), eta_min=1e-5)
    model.train()
    cond_net.train()

    for e in range(1, epochs + 1):

        if e <= warmup_epochs:
            # Linear warm-up: set the learning rate by hand for this epoch.
            warm_lr = lr * e/warmup_epochs
            for pg in optimizer.param_groups:
                pg["lr"] = warm_lr

        total = 0.0
        n = 0
        val_total = 0.0
        val_n = 0

        for x_noisy, x_clean in train_dl:
            x_noisy = x_noisy.to(device, non_blocking=True).float()
            x_clean = x_clean.to(device, non_blocking=True).float()
            B = x_clean.size(0)

            loss = noise_prediction_loss(model, x_noisy, x_clean)

            optimizer.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            if grad_clip:
                # Gradients must be unscaled before their norm is measured.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(trainable, grad_clip)
            scaler.step(optimizer)
            scaler.update()
            ema_update()

            total += loss.item() * B
            n += B

        if e > warmup_epochs:
            lr_sched.step()

        # Validation uses the EMA weights together with the live cond_net.
        model.eval()
        cond_net.eval()
        with torch.no_grad():
            for x_noisy, x_clean in val_dl:
                x_noisy = x_noisy.to(device).float()
                x_clean = x_clean.to(device).float()
                B = x_clean.size(0)

                loss = noise_prediction_loss(ema_model, x_noisy, x_clean)
                val_total += loss.item() * B
                val_n += B
        val_loss = val_total / val_n if val_n else float("nan")
        model.train()
        cond_net.train()
        print(f"epoch {e}: train loss = {total/n: .4f}, val_los = {val_loss:.4f}")

    return ema_model
            

In [121]:
@torch.no_grad()
def denoise(model, scheduler, x_noisy, *, strength=0.25, num_inference_steps=50, cond_net, batch_size=128):
    """Denoise signals by partially noising them and running reverse diffusion.

    The input is noised to an intermediate timestep of the inference
    schedule and then stepped through the remaining timesteps, with the
    model conditioned on ``cond_net(multiscale_cond(x_noisy))``.

    Parameters
    ----------
    model : torch.nn.Module
        Noise predictor; ``model(x, t)`` must return an object with ``.sample``.
    scheduler
        Inference scheduler providing ``set_timesteps``, ``timesteps``,
        ``add_noise`` and ``step``.
    x_noisy : torch.Tensor
        Noisy signals, indexed along the first dimension.
    strength : float
        Fraction of ``num_inference_steps`` that is skipped at the start of
        ``scheduler.timesteps``; the start index is clamped to
        ``[1, num_inference_steps - 1]``.
        NOTE(review): a larger value therefore runs fewer steps. If the
        schedule is ordered from most to least noise, that also means less
        noise is added -- the opposite of the usual img2img "strength".
        Confirm this is intended.
    num_inference_steps : int
        Length of the inference schedule.
    cond_net : torch.nn.Module
        Conditioning network.
    batch_size : int
        Number of signals processed per chunk (default 128).

    Returns
    -------
    torch.Tensor
        Denoised signals on the CPU, in input order.
    """
    model.eval()
    cond_net.eval()
    x = x_noisy.float().to(device)

    scheduler.set_timesteps(num_inference_steps)
    t_start = int(max(1, min(num_inference_steps-1, round(strength * num_inference_steps))))
    timesteps = scheduler.timesteps[t_start:]
    start_t = scheduler.timesteps[t_start]

    if x.size(0) == 0:
        # Nothing to denoise; torch.cat on an empty list would raise.
        return x.detach().cpu()

    outs = []
    for i in range(0, x.size(0), batch_size):
        chunk = x[i:i+batch_size]
        noise = torch.randn_like(chunk)
        # Start from the noisy observation diffused to the start timestep.
        x_t = scheduler.add_noise(chunk, noise, start_t)
        # The conditioning depends only on the observation, so it is
        # computed once per chunk.
        c = cond_net(multiscale_cond(chunk))
        for t in timesteps:
            x_in = torch.cat([x_t, c], dim=1)
            eps = model(x_in, t).sample
            x_t = scheduler.step(eps, t, x_t).prev_sample

        outs.append(x_t.detach().cpu())
    return torch.cat(outs, dim=0)

In [122]:
def rrmse_time(yhat, y):
    """Relative RMSE along the last dim, averaged over all leading dims.

    Per signal: ``rms(yhat - y) / (rms(y) + 1e-8)``. Returns a Python float.
    """
    error_rms = (yhat - y).pow(2).mean(dim=-1).sqrt()
    reference_rms = y.pow(2).mean(dim=-1).sqrt() + 1e-8
    relative = error_rms / reference_rms
    return relative.mean().item()

def cc(yhat, y):
    """Pearson correlation along the last dim, averaged over leading dims.

    Both inputs are mean-centred along the last dim; ``1e-8`` is added to
    the product of their norms to avoid division by zero. Returns a float.
    """
    yhat_centred = yhat - yhat.mean(dim=-1, keepdim=True)
    y_centred = y - y.mean(dim=-1, keepdim=True)

    covariance = (yhat_centred * y_centred).sum(dim=-1)
    scale = yhat_centred.norm(dim=-1) * y_centred.norm(dim=-1) + 1e-8

    correlation = covariance / scale
    return correlation.mean().item()

In [123]:
# Train for 100 epochs (peak lr 1e-3, no weight decay); train() returns the
# EMA copy of the model, which is what the evaluation below uses.
ema_model = train(model, scheduler, X_train, y_train, X_val, y_val, epochs=100, batch_size=512, lr=1e-3, wd=0, cond_net=cond_net)

epoch 1: train loss =  0.8233, val_los = 1.7311
epoch 2: train loss =  0.6703, val_los = 1.3009
epoch 3: train loss =  0.6656, val_los = 1.0726
epoch 4: train loss =  0.6587, val_los = 0.9241
epoch 5: train loss =  0.6562, val_los = 0.8302
epoch 6: train loss =  0.6517, val_los = 0.7768
epoch 7: train loss =  0.6501, val_los = 0.7517
epoch 8: train loss =  0.6478, val_los = 0.7255
epoch 9: train loss =  0.6461, val_los = 0.7142
epoch 10: train loss =  0.6455, val_los = 0.7045
epoch 11: train loss =  0.6452, val_los = 0.6980
epoch 12: train loss =  0.6442, val_los = 0.6948
epoch 13: train loss =  0.6428, val_los = 0.6866
epoch 14: train loss =  0.6426, val_los = 0.6817
epoch 15: train loss =  0.6418, val_los = 0.6820
epoch 16: train loss =  0.6417, val_los = 0.6710
epoch 17: train loss =  0.6409, val_los = 0.6732
epoch 18: train loss =  0.6412, val_los = 0.6672
epoch 19: train loss =  0.6405, val_los = 0.6642
epoch 20: train loss =  0.6405, val_los = 0.6593
epoch 21: train loss =  0.638

In [124]:
# Move the complete test arrays to the device. No dtype cast happens here,
# so the tensors keep the dtype of the NumPy arrays from prepare_data;
# denoise() casts its input to float32 itself.
X_test_t = torch.from_numpy(X_test).to(device)
y_test_t = torch.from_numpy(y_test).to(device)

# Shape of the test set at the first SNR level.
print(X_test_t[0].shape)

torch.Size([6160, 1, 512])


In [125]:
# Build a DDIM scheduler for inference from the training scheduler's config.
inference_scheduler = DDIMScheduler.from_config(scheduler.config)

In [126]:
# Evaluate a subset of the test SNR levels. ``s`` indexes the first axis of
# the test tensors; prepare_data spaces 11 levels evenly from -7 to 3 dB,
# so index s corresponds to (-7 + s) dB.
for s in [0, 2, 5, 6, 8, 10]:
    print(f"Testing for SNR: {-7+s}")
    y_ref = y_test_t[s].to(device)
    y_hat = denoise(ema_model, inference_scheduler, X_test_t[s], strength=0.6, num_inference_steps=150, cond_net=cond_net)
    # denoise() returns its result on the CPU; move it next to the reference.
    y_hat = y_hat.to(device)
    m = rrmse_time(y_hat, y_ref)
    c = cc(y_hat, y_ref)
    print(f"RRMSE_t={m:.4f}, CC={c:.4f}")

Testing for SNR: -7
RRMSE_t=0.9415, CC=0.4687
Testing for SNR: -5
RRMSE_t=0.8828, CC=0.5309
Testing for SNR: -2
RRMSE_t=0.7959, CC=0.6199
Testing for SNR: -1
RRMSE_t=0.7688, CC=0.6469
Testing for SNR: 1
RRMSE_t=0.7203, CC=0.6941
Testing for SNR: 3
RRMSE_t=0.6810, CC=0.7315
