In [None]:
import torch
import urllib
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np
import time
import mne
import helper as h
import functionss as f

In [None]:
# Path to the single EEG recording this notebook works on (kept in one place for easy changes).
EDF_PATH = '/home/ubuntu/hyperparameter_tuning/data/H_EC/H S7 EC.edf'
# preload=True loads the signal data into memory.
raw = mne.io.read_raw_edf(EDF_PATH, preload=True)

In [None]:
# Channels removed before analysis; any that are absent from the file are ignored.
DROP_CHANNELS = [
    "EEG A2-A1",
    "EEG 23A-23R",
    "EEG 24A-24R",
    "EEG T6-LE",
    "EEG Cz-LE",
    "EEG Pz-LE",
]
# Map the EDF labels (e.g. "EEG Fp1-LE") to bare electrode names (e.g. "Fp1").
RENAME_CHANNELS = {
    "EEG Fp1-LE": "Fp1",
    "EEG F3-LE": "F3",
    "EEG C3-LE": "C3",
    "EEG P3-LE": "P3",
    "EEG O1-LE": "O1",
    "EEG F7-LE": "F7",
    "EEG Fz-LE": "Fz",
    "EEG Fp2-LE": "Fp2",
    "EEG F4-LE": "F4",
    "EEG C4-LE": "C4",
    "EEG P4-LE": "P4",
    "EEG O2-LE": "O2",
    "EEG F8-LE": "F8",
    "EEG T3-LE": "T3",
    "EEG T5-LE": "T5",
    "EEG T4-LE": "T4",
}
raw.drop_channels(DROP_CHANNELS, on_missing="ignore")
raw = raw.rename_channels(RENAME_CHANNELS)

In [None]:
# Channel names left after the drop/rename cell; the count is shown as the cell output.
ch_names = raw.ch_names
len(ch_names)

In [None]:
N_SAMPLES = 1024  # number of leading samples kept per channel
# (channels, samples) data of the recording, truncated to the first N_SAMPLES samples.
array = raw.get_data()[:, :N_SAMPLES]

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Convert the (channels, samples) array into a torch tensor with the project helper.
# NOTE(review): the exact transform (e.g. any normalisation) lives in helper.py and is not visible here.
torch_eeg = h.transform_array(array)

In [None]:
# First two dimensions of the transformed tensor; their product is the flattened length below.
x, y = torch_eeg.shape[:2]
print(x, y)


In [None]:
# NOTE(review): identical to the earlier transform cell; re-running it only recomputes torch_eeg.
torch_eeg = h.transform_array(array)

In [None]:
# Flatten the (x, y) tensor into a single row of length x*y for the 1-D model.
# NOTE(review): if torch_eeg is a torch.Tensor, `.reshape` returns ONE tensor of shape (1, x*y),
# so unpacking it into three names raises ValueError. `min_val`/`max_val` (needed by
# h.reverse_transform_array) presumably should come from the transform helper; confirm.
torch_reshaped, min_val, max_val = torch_eeg.reshape(1,x*y)
torch_reshaped.shape

In [None]:
# Round-trip check: undo the transform on the flattened tensor, rebuild an EDF and plot it.
# Use the channel names actually present after the drop/rename step. The previous hardcoded
# list still named the dropped T6/Cz/Pz channels and omitted the kept T3/T5/T4, so the
# reconstructed channels were mislabelled. This also keeps len(ch_names) equal to the
# number of rows in the data.
ch_names = list(raw.ch_names)
raw_array = h.reverse_transform_array(torch_reshaped, min_val, max_val, changed_shape=True)
raw_re = f.array_to_edf(raw_array, ch_names, 'eeg', 'test.edf')
raw_re.plot()

In [None]:
# Plot the loaded recording for visual comparison with the round-tripped version.
raw.plot()

In [None]:
class DiffusionModel:
    """DDPM-style linear noise schedule with forward (noising) and backward (denoising) steps.

    Precomputes ``timesteps`` betas spaced linearly between ``start_shedule`` and
    ``end_schedule``, ``alphas = 1 - betas`` and their cumulative product.

    Note: the Gaussian noise used here is scaled by 1/2 (std 0.5), both when noising in
    ``forward`` and when re-injecting noise in ``backward``.
    """

    def __init__(self, start_shedule=0.0001, end_schedule=0.02, timesteps=300):
        # `start_shedule` (sic) keeps its original spelling so keyword callers keep working.
        self.start_schedule = start_shedule
        self.end_schedule = end_schedule
        self.timesteps = timesteps

        self.betas = torch.linspace(self.start_schedule, self.end_schedule, self.timesteps)
        self.alphas = 1 - self.betas
        self.alpha_cumprod = torch.cumprod(self.alphas, dim=0)

    def forward(self, x0, t, device, min_val, max_val):
        """Noise ``x0`` directly to timestep ``t`` in closed form.

        Args:
            x0: clean batch; the first dimension is the batch dimension.
            t: 1-D long tensor holding one timestep index per batch element.
            device: device the returned tensors are moved to.
            min_val, max_val: currently unused; kept for interface compatibility.

        Returns:
            ``(x_t, noise)``: the noised batch and the noise that was added, both on ``device``.
        """
        # TODO: derive the noise scale from the data (std / min / max) instead of a fixed 1/2.
        noise = torch.randn_like(x0) / 2
        sqrt_alphas_cumprod_t = self.get_index_from_list(self.alpha_cumprod.sqrt(), t, x0.shape)
        sqrt_one_minus_alphas_cumprod_t = self.get_index_from_list(
            torch.sqrt(1. - self.alpha_cumprod), t, x0.shape)
        mean = sqrt_alphas_cumprod_t.to(device) * x0.to(device)
        variance = sqrt_one_minus_alphas_cumprod_t.to(device) * noise.to(device)
        return mean + variance, noise.to(device)

    def backward(self, x, t, model, **kwargs):
        """Run one reverse-diffusion step x_t -> x_{t-1}.

        Calls ``model(x, t, **kwargs)`` to predict the noise in ``x`` and returns the
        denoised estimate. Fresh noise (std 0.5, scaled by sqrt(beta_t)) is added for every
        batch element whose timestep is not 0. Elements at t == 0 receive the mean only.
        ``t`` may mix zero and non-zero timesteps across the batch.
        """
        betas_t = self.get_index_from_list(self.betas, t, x.shape)
        sqrt_one_minus_alphas_cumprod_t = self.get_index_from_list(
            torch.sqrt(1. - self.alpha_cumprod), t, x.shape)
        sqrt_recip_alphas_t = self.get_index_from_list(torch.sqrt(1.0 / self.alphas), t, x.shape)
        mean = sqrt_recip_alphas_t * (x - betas_t * model(x, t, **kwargs) / sqrt_one_minus_alphas_cumprod_t)
        posterior_variance_t = betas_t

        # Whole batch at the last step: return the mean without drawing any noise.
        if bool((t == 0).all()):
            return mean

        noise = torch.randn_like(x) / 2
        variance = torch.sqrt(posterior_variance_t) * noise
        # Per-element mask: the old scalar `if t == 0` raised for batches of size > 1.
        nonzero_mask = (t != 0).to(mean.dtype).reshape(x.shape[0], *((1,) * (len(x.shape) - 1)))
        return mean + nonzero_mask.to(mean.device) * variance

    @staticmethod
    def get_index_from_list(values, t, x_shape):
        """Pick ``values[t]`` per batch element and reshape to broadcast against ``x_shape``.

        The result has shape ``(batch, 1, ..., 1)`` and is placed on ``t``'s device.
        """
        batch_size = x_shape[0]
        result = values.gather(-1, t.cpu())
        return result.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)


In [None]:
# Instance of the notebook's DiffusionModel, used by the noising demo in the next cells.
diffusion_model = DiffusionModel()

In [None]:
# Noise NO_OF_IMAGES copies of the recording at evenly spaced timesteps 0 .. T-1.
NO_OF_IMAGES = 5
batch_images = torch.stack(NO_OF_IMAGES * [torch_reshaped])
t = torch.linspace(0, diffusion_model.timesteps - 1, NO_OF_IMAGES).to(torch.int64)
noisy_edfs, _ = diffusion_model.forward(
    batch_images, t, 'cpu', min_val=min_val, max_val=max_val
)


In [None]:
# Show the recording after noising to each of the evenly spaced timesteps.
for edf in noisy_edfs:
    raw_array = h.reverse_transform_array(edf, min_val, max_val, changed_shape=True)
    # Separate name so the loaded recording in `raw` is not overwritten by the noisy copies.
    raw_noisy = f.array_to_edf(raw_array, ch_names, 'eeg', 'test.edf')
    raw_noisy.plot()

In [None]:
# NOTE(review): replaces the instance of the notebook's own DiffusionModel with helper.DiffusionModel.
# The training and sampling cells below use this one; confirm both implementations agree.
diffusion_model = h.DiffusionModel()

In [None]:
def plot_noise_distribution(noise, predicted_noise):
    """Overlay density-normalised histograms of the true and the predicted noise.

    Both tensors are detached before conversion to NumPy. Tensors that still carry
    autograd history (such as the raw U-Net output) can therefore be passed directly;
    ``Tensor.numpy()`` raises on them otherwise.
    """
    plt.hist(noise.detach().cpu().numpy().flatten(), density=True, alpha=0.8, label="ground truth noise")
    plt.hist(predicted_noise.detach().cpu().numpy().flatten(), density=True, alpha=0.8, label="predicted noise")
    plt.legend()
    plt.show()

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class SinusoidalPositionEmbeddings(nn.Module):
    """Sinusoidal embedding of timesteps.

    Maps a 1-D tensor of ``B`` timesteps to a ``(B, 2 * (dim // 2))`` tensor. The first half
    holds sines and the second half cosines, at frequencies ``10000 ** (-k / (dim // 2 - 1))``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        half = self.dim // 2
        scale = math.log(10000) / (half - 1)
        freqs = torch.exp(-scale * torch.arange(half, device=time.device))
        angles = time.unsqueeze(1) * freqs.unsqueeze(0)
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)


class Block(nn.Module):
    """One U-Net stage: two 3-tap 1-D convolutions with a timestep bias, then resampling.

    With ``downsample=True`` the output length is halved by a stride-2 ``Conv1d``. With
    ``downsample=False`` it is doubled by a stride-2 ``ConvTranspose1d``, and the input
    must have ``2 * channels_in`` channels (features concatenated with the skip connection).
    """

    def __init__(self, channels_in, channels_out, time_embedding_dims, downsample=True):
        super().__init__()
        self.time_embedding_dims = time_embedding_dims
        self.time_embedding = SinusoidalPositionEmbeddings(time_embedding_dims)

        # Up-sampling blocks receive the skip connection concatenated on the channel axis.
        conv1_in = channels_in if downsample else 2 * channels_in
        resample_cls = nn.Conv1d if downsample else nn.ConvTranspose1d
        self.conv1 = nn.Conv1d(conv1_in, channels_out, kernel_size=3, padding=1)
        self.final = resample_cls(channels_out, channels_out, kernel_size=4, stride=2, padding=1)

        self.bnorm1 = nn.BatchNorm1d(channels_out)
        self.bnorm2 = nn.BatchNorm1d(channels_out)
        self.conv2 = nn.Conv1d(channels_out, channels_out, kernel_size=3, padding=1)
        self.time_mlp = nn.Linear(time_embedding_dims, channels_out)
        self.relu = nn.ReLU()

    def forward(self, x, t, **kwargs):
        # conv -> ReLU -> BatchNorm (ReLU before normalisation, as designed).
        features = self.bnorm1(self.relu(self.conv1(x)))
        # One bias per channel from the timestep embedding, broadcast over the length axis.
        time_bias = self.relu(self.time_mlp(self.time_embedding(t)))
        features = features + time_bias[..., None]
        features = self.bnorm2(self.relu(self.conv2(features)))
        return self.final(features)

class UNet(nn.Module):
    """1-D U-Net that predicts the noise in a ``(batch, in_channels, length)`` signal.

    Channel widths go 64 -> 128 -> 256 on the way down and back up symmetrically. Each
    up-sampling stage concatenates the matching down-sampling output. There are two
    stride-2 stages, so ``length`` must be divisible by 4 for the skip connections to line up.

    Args:
        time_embedding_dims: size of the sinusoidal timestep embedding inside each block.
        in_channels: channels of the input signal (default 1).
        out_channels: channels of the predicted output (default 1).
    """

    def __init__(self, time_embedding_dims=128, in_channels=1, out_channels=1):
        super().__init__()
        down_channels = [64, 128, 256]
        up_channels = [256, 128, 64]

        # Lift the input channels to the first feature width.
        self.initial = nn.Conv1d(in_channels, down_channels[0], kernel_size=3, padding=1)
        self.downs = nn.ModuleList([
            Block(down_channels[i], down_channels[i + 1], time_embedding_dims)
            for i in range(len(down_channels) - 1)
        ])
        self.ups = nn.ModuleList([
            Block(up_channels[i], up_channels[i + 1], time_embedding_dims, downsample=False)
            for i in range(len(up_channels) - 1)
        ])
        # 1x1 projection back to the output channels.
        self.final = nn.Conv1d(up_channels[-1], out_channels, 1)

    def forward(self, x, t):
        residuals = []
        out = self.initial(x.float())
        for down in self.downs:
            out = down(out, t)
            residuals.append(out)
        # Deepest residual first; the first up block therefore sees the bottleneck twice.
        for up, res in zip(self.ups, reversed(residuals)):
            out = up(torch.cat((out, res), dim=1), t)
        return self.final(out)

In [None]:
# Instantiate an untrained U-Net with the default 128-dim time embedding.
unet = UNet()


In [None]:
# Training configuration.
NO_EPOCHS = 300
PRINT_FREQUENCY = 10
LR = 0.0001
BATCH_SIZE = 128
VERBOSE = True
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Move the model before building the optimizer (as the PyTorch docs recommend), so the
# optimizer tracks the parameters that are actually trained on `device`.
unet = unet.to(device)
optimizer = torch.optim.Adam(unet.parameters(), lr=LR)

In [None]:
# Overfit the U-Net on copies of the single recording. Each "epoch" is one optimizer step
# on BATCH_SIZE copies noised at independently sampled timesteps.
unet = unet.to(device)  # the noisy batch lives on `device`, so the model must too
unet.train()  # the sampling cell switches to eval(); restore BatchNorm training mode on re-runs
for epoch in range(NO_EPOCHS):
    start_time = time.time()
    # BATCH_SIZE copies of the same recording, shape (BATCH_SIZE, 1, length).
    batch = torch.stack([torch_reshaped] * BATCH_SIZE)
    t = torch.randint(0, diffusion_model.timesteps, (BATCH_SIZE,)).long().to(device)
    batch_noisy_images, noise = diffusion_model.forward(batch, t, device, min_val=min_val, max_val=max_val)
    predicted_noise = unet(batch_noisy_images, t)

    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(predicted_noise.float(), noise.float())
    loss.backward()
    optimizer.step()

    train_loss = loss.item()
    epoch_duration = time.time() - start_time
    print(f"Epoch: {epoch} | Train Loss {train_loss} | Duration {epoch_duration//60:.0f}m {epoch_duration%60:.0f}s")
    if epoch % PRINT_FREQUENCY == 0:
        print('---')
        print(f"Epoch: {epoch} | Train Loss {train_loss}")
        if VERBOSE:
            # Detach: the prediction carries autograd history and .numpy() would raise on it.
            plot_noise_distribution(noise, predicted_noise.detach())


In [None]:
# Save the whole pickled module (not just its state_dict). Loading it requires the UNet class
# to be defined in the loading session.
# NOTE(review): the file name says "HS8EC" but the recording loaded above is "H S7 EC"; confirm.
torch.save(unet, 'unet_Conv1d_HS8EC')

In [None]:
# Load the full pickled model saved above.
# map_location lets a checkpoint written on a GPU load on a CPU-only machine.
# weights_only=False is required to unpickle a whole nn.Module on PyTorch >= 2.6 (where the
# default became True). Unpickling can execute code, so only load files you created yourself.
unet = torch.load('unet_Conv1d_HS8EC', map_location=device, weights_only=False)

In [None]:
# Generate a new signal by running the reverse diffusion from random noise.
unet.eval()  # BatchNorm must use its running statistics while sampling
with torch.no_grad():
    imgs = []
    raw_list = []
    # Same (1, 1, length) shape the model was trained on; this was previously hardcoded as 16384.
    img = torch.randn(1, *torch_reshaped.shape).to(device)
    for i in reversed(range(diffusion_model.timesteps)):
        t = torch.full((1,), i, dtype=torch.long, device=device)
        img = diffusion_model.backward(img, t, unet)
        # Keep every 50th intermediate sample (including the final t = 0 one).
        if i % 50 == 0:
            imgs.append(img[0])
            raw_array = h.reverse_transform_array(img[0], min_val, max_val, changed_shape=True)
            # Separate name so the loaded recording in `raw` is not overwritten.
            raw_generated = f.array_to_edf(raw_array, ch_names, 'eeg', 'test.edf')
            raw_list.append(raw_generated)
            raw_generated.plot()

In [None]:
# NOTE(review): imgs[5] is a normalised (1, length) tensor from the sampling cell, while
# raw.get_data() is a (channels, samples) array, so these shapes do not broadcast in general.
# Compare both in the same space (e.g. reverse-transform imgs[5] first); confirm intent.
(imgs[5] - raw.get_data()).mean()

In [None]:
# Plot the last collected sample (the one taken at timestep 0).
raw_list[-1].plot()

In [None]:
def print_psd(raw_list):
    """Plot the channel-averaged power spectral density of each recording in ``raw_list``.

    A recording whose PSD cannot be computed is reported and skipped. Previously, a bare
    ``except`` fell through to plotting an undefined or stale ``spectrum``.
    """
    for raw in raw_list:
        try:
            spectrum = raw.compute_psd()
        except Exception as exc:  # best-effort: keep going with the remaining recordings
            print(f"Skipping recording, PSD computation failed: {exc!r}")
            continue
        spectrum.plot(average=True, picks="data", exclude="bads")
        plt.show()

In [None]:
# PSD of every intermediate sample collected during sampling.
print_psd(raw_list)

In [None]:
# PSD of `raw` for comparison with the generated samples.
# NOTE(review): `raw` is the loaded recording only if no earlier loop rebound the name; confirm.
print_psd([raw])

In [None]:
# Save the final generated sample as FIF; overwrite=True replaces an existing file.
raw_list[-1].save('overfitted_conv1d.fif', overwrite=True)