In [None]:
import scipy.io as sio
from torch.utils.data import DataLoader 
import torch
from torch import nn
import numpy as np
from torchsummary import summary
import matplotlib.pyplot as plt
import torch.nn.functional as F

In [None]:
# Mount Google Drive so the .mat files read and written below are reachable on the local filesystem.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
# Load the training / test datasets from Google Drive and wrap them in DataLoaders.
filename1 = '/content/drive/MyDrive/denoiser/deconvolver_sin1.mat'
batch_size = 64

data1 = sio.loadmat(filename1)

# NOTE(review): the training loop reads channel 0 (noisy input) and channel 2 (clean target)
# of each example and reshapes them to length 500, so these arrays are presumably shaped
# (n_examples, n_channels, 500) — confirm against the .mat file.
train1 = data1['train']
test1 = data1['test']

train_data = DataLoader(train1, batch_size=batch_size, shuffle=True, num_workers=2)
# Keep the test order fixed: shuffling does not help evaluation. It also makes the reported
# per-epoch test loss (a mean of batch means, whose last batch may be smaller) and the
# example plotted from the last test batch vary from epoch to epoch.
test_data = DataLoader(test1, batch_size=batch_size, shuffle=False, num_workers=2)

In [None]:
# Define the CNN
class denoiser(nn.Module):
  """Single-channel 1-D CNN: five bias-free Conv1d layers separated by in-place ReLUs.

  The layers live in ``self.conv`` (an ``nn.Sequential``) with the convolutions at indices
  0, 2, 4, 6 and 8 and the ReLUs at the odd indices. Every convolution uses stride 1 and
  padding ``kernel_size // 2`` with an odd kernel, so the output has the same length as the input.
  """

  def __init__(self):
    super().__init__()
    # (in_channels, out_channels, kernel_size) of each convolution, in forward order.
    specs = ((1, 2, 9), (2, 2, 17), (2, 2, 17), (2, 2, 37), (2, 1, 9))
    layers = []
    for position, (c_in, c_out, k) in enumerate(specs):
      if position > 0:
        # A ReLU sits between consecutive convolutions (never after the last one).
        layers.append(nn.ReLU(inplace=True))
      layers.append(nn.Conv1d(c_in, c_out, kernel_size=k, padding=k // 2, bias=False))
    self.conv = nn.Sequential(*layers)

  def forward(self, x):
    """Apply the convolution stack to ``x`` of shape (batch, 1, length)."""
    return self.conv(x)

In [None]:
# Seed torch's global RNG so the random re-initialisation below (and the DataLoader
# shuffling used during training) is reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

model = denoiser()


def _set_conv_weight(layer, values):
  """Copy ``values`` into ``layer.weight`` in place, keeping the parameter's dtype and device.

  Assigning to ``.data`` directly (as ``torch.from_numpy`` on a MATLAB double array would)
  silently changes the parameter's dtype/shape, which later breaks float32 training.

  Raises:
    ValueError: if ``values`` does not have exactly the layer's weight shape.
  """
  new_weight = torch.as_tensor(values, dtype=layer.weight.dtype, device=layer.weight.device)
  if new_weight.shape != layer.weight.shape:
    raise ValueError(
        f'weight shape {tuple(new_weight.shape)} does not match layer weight shape '
        f'{tuple(layer.weight.shape)}')
  with torch.no_grad():
    layer.weight.copy_(new_weight)


# Load the pre-set CNN as initialization: keys 'conv1'..'conv5' go to model.conv[0], [2], ..., [8].
weight = sio.loadmat('/content/drive/MyDrive/denoiser/deconv22_pre.mat')
for i in range(0, 9, 2):
  _set_conv_weight(model.conv[i], weight[f'conv{i // 2 + 1}'])

# Re-initialize the two trainable layers with uniform [0, 1) random values.
_set_conv_weight(model.conv[2], torch.rand(2, 2, 17))
_set_conv_weight(model.conv[4], torch.rand(2, 2, 17))

# Freeze the first and the last two conv layers; only conv[2] and conv[4] are trained.
for i in (0, 6, 8):
  model.conv[i].weight.requires_grad_(False)

In [None]:
# Set training parameters
epochs = 700
# Fall back to the CPU when no GPU is attached so the notebook still runs without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)
# Optimise only the parameters that were left trainable (requires_grad=True).
optimizer = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=1e-3)
# Multiply the learning rate by 0.1 every 300 scheduler steps (the training loop steps once per epoch).
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=300, gamma=0.1)
loss = nn.MSELoss().to(device)
# torchsummary defaults to device="cuda"; pass the actual device type so it also works on CPU.
summary(model, (1, 500), device=device.type)

In [None]:
# Train the CNN
SIGNAL_LEN = 500  # samples per signal; matches the input size given to summary()


def _select_channel(batch, channel):
  """Return ``batch[:, channel, :]`` reshaped to (-1, 1, SIGNAL_LEN) as float32 on ``device``.

  NOTE(review): batches appear to be (batch, channels, SIGNAL_LEN) with channel 0 the noisy
  input and channel 2 the clean target (channel 1 is not used here) — confirm against the data.
  """
  return batch[:, channel, :].view(-1, 1, SIGNAL_LEN).to(device=device, dtype=torch.float32)


def _plot_example(clean, denoised, noisy, epoch):
  """Plot one example's clean target, network output and noisy input in three stacked panels."""
  fig, axes = plt.subplots(3, 1, figsize=(10, 10))
  panels = ((clean, 'Pure Signal'), (denoised, 'Denoised Signal'), (noisy, 'Signal with Noise'))
  for ax, (trace, title) in zip(axes, panels):
    ax.plot(trace)
    ax.set_title(title)
  fig.suptitle(f'Epoch {epoch}')
  # Show (and release) the figure now instead of letting figures pile up until the cell ends.
  plt.show()


train_loss_epoch = []
test_loss_epoch = []
for epoch in range(epochs):
  # Training pass.
  model.train()
  train_losses = []
  for training in train_data:
    x_train = _select_channel(training, 0)
    y_train = _select_channel(training, 2)
    optimizer.zero_grad()
    y_pred = model(x_train)
    loss_train = loss(y_pred, y_train)
    loss_train.backward()
    optimizer.step()
    train_losses.append(loss_train.item())
  train_loss_epoch.append(np.mean(train_losses))
  scheduler.step()

  # Evaluation pass: no gradients are needed, so skip building the autograd graph.
  model.eval()
  test_losses = []
  with torch.no_grad():
    for testing in test_data:
      x_test = _select_channel(testing, 0)
      y_test = _select_channel(testing, 2)
      y_pred = model(x_test)
      loss_test = loss(y_pred, y_test)
      test_losses.append(loss_test.item())
  test_loss_epoch.append(np.mean(test_losses))

  if epoch % 5 == 0:
    print("Epoch: %d      train loss: %f      test loss: %f" % (epoch, train_loss_epoch[-1], test_loss_epoch[-1]))

  # Visualise the first example of the last test batch after the first epoch and every 100 epochs.
  if epoch % 100 == 99 or epoch == 0:
    signal = y_test.view(-1, SIGNAL_LEN).cpu().detach().numpy()
    denoised = y_pred.view(-1, SIGNAL_LEN).cpu().detach().numpy()
    noisy = x_test.view(-1, SIGNAL_LEN).cpu().detach().numpy()
    _plot_example(signal[0], denoised[0], noisy[0], epoch)


In [None]:
# Save the trained CNN: export each conv weight (model.conv[0], [2], ..., [8]) under the same
# 'conv1'..'conv5' keys that the initialisation file uses.
TRAINED_WEIGHTS_PATH = '/content/drive/MyDrive/denoiser/deconv22.mat'
weight_trained = {}
for i in range(0, 9, 2):
  weight_trained[f'conv{i // 2 + 1}'] = model.conv[i].weight.data.cpu().numpy()
sio.savemat(TRAINED_WEIGHTS_PATH, mdict=weight_trained)