In [None]:
import torchvision.transforms as transforms
import torch.nn as nn
import torchvision
import math
import matplotlib.pyplot as plt
import torch
import urllib
import numpy as np
import PIL
import mne
import helper as h
import functionss as f
import os

In [None]:
def reverse_z_normalisation(X_normalized, std_X, mean_X):
    """Undo a z-score normalisation, mapping values back to their original scale.

    Computes ``X_normalized * std_X + mean_X`` elementwise (numpy / torch
    broadcasting rules apply).

    Args:
        X_normalized: z-normalised values.
        std_X: standard deviation used for the normalisation.
        mean_X: mean used for the normalisation.

    Returns:
        The de-normalised values.
    """
    rescaled = X_normalized * std_X
    return rescaled + mean_X

In [None]:
# Names of the 16 EEG channels passed to f.array_to_edf.
# NOTE(review): assumed to follow the row order of the loaded .npy arrays — confirm.
ch_names = "Fp1 F3 C3 P3 O1 F7 Fz Fp2 F4 C4 P4 O2 F8 T6 Cz Pz".split()

In [None]:
def load_data_from_array(dataDir, group_label):
    """Load every ``.npy`` file in ``dataDir`` and give each file its own group label.

    Files are read in sorted path order so results are reproducible
    (``os.scandir`` yields entries in arbitrary, filesystem-dependent order).
    Labels are consecutive integers: if ``group_label`` is empty the first file
    gets 0, otherwise labelling continues at ``max(group_label) + 1``. Every
    first-axis entry (sample) of a file receives that file's label; an empty
    file adds no labels and does not consume one.

    Args:
        dataDir: directory to scan (not recursive).
        group_label: list of existing labels; it is extended IN PLACE and returned.

    Returns:
        ``(group_label, data)`` where ``data`` is the list of loaded arrays in the
        same order as their labels.
    """
    data = []
    # Track the next label directly instead of re-scanning the list with np.max per file.
    next_label = max(group_label) + 1 if group_label else 0
    # Context manager closes the directory handle; sorting makes the order deterministic.
    with os.scandir(dataDir) as entries:
        npy_paths = sorted(
            entry.path
            for entry in entries
            if entry.is_file() and entry.name.endswith('.npy')
        )
    for path in npy_paths:
        data_array = np.load(path)
        group_label.extend([next_label] * len(data_array))
        data.append(data_array)
        if len(data_array):
            next_label += 1
    return group_label, data

In [None]:
# Directory with the pre-cut EEG segments (.npy). Set the EEG_ARRAY_DIR environment
# variable to use another location instead of editing this machine-specific default.
DATA_DIR = os.environ.get('EEG_ARRAY_DIR', '/home/ubuntu/Diffusion/own_diffusion/arrays_16/selected')
group_label, data = load_data_from_array(DATA_DIR, [])

In [None]:
# Stack the per-file arrays along the first (sample) axis into one array.
x_data = np.concatenate(data, axis=0)

In [None]:
# Inspect the stacked raw data dimensions.
np.shape(x_data)

In [None]:
# z-normalise the raw data; `mean` and `std` are kept so samples can later be mapped
# back with reverse_z_normalisation.
# NOTE(review): whether mean/std are scalars or per-channel arrays depends on
# f.ft_z_normalize (not visible here) — confirm.
x_data_norm, mean, std = f.ft_z_normalize(x_data)

In [None]:
# Sanity check: plot the first z-normalised segment through f.array_to_edf.
f.array_to_edf(x_data_norm[0], ch_names, 'eeg', 'test.edf').plot()

In [None]:
# Round-trip check: undo the z-normalisation of the first segment; the plot should match the raw data.
f.array_to_edf(reverse_z_normalisation( x_data_norm[0],std, mean), ch_names, 'eeg', 'test.edf').plot()

In [None]:
# Dimensions of the normalised data before adding the image-channel axis.
np.shape(x_data_norm)

In [None]:
# Add a singleton "image channel" axis at position 1 for Conv2d input,
# presumably (n, 16, 1024) -> (n, 1, 16, 1024).
# Guarded so re-running this cell does not keep inserting extra axes.
if x_data_norm.ndim == 3:
    x_data_norm = np.expand_dims(x_data_norm, axis=1)

In [None]:
# Dimensions after inserting the channel axis.
np.shape(x_data_norm)

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Convert the z-normalised data into the training tensor (it is wrapped in a TensorDataset below).
# NOTE(review): h.transform_array is not visible here — presumably a min/max scaling that also
# returns the (x_min, x_max) later passed to h.reverse_transform_array; confirm in helper.py.
torch_image ,x_min, x_max = h.transform_array(x_data_norm)

In [None]:
class DiffusionModel:
    """Linear-schedule DDPM helper with a one-shot forward (noising) step and a reverse step.

    Args:
        start_shedule: first value of the linear beta schedule (misspelt name kept for
            backward compatibility; stored as ``self.start_schedule``).
        end_schedule: last value of the linear beta schedule.
        timesteps: number of diffusion steps ``T``; valid timesteps are ``0 .. T-1``.
        beta: divisor applied to the Gaussian noise in :meth:`forward` (noise std is
            ``1 / beta``). NOTE(review): :meth:`backward` and the sampling loop draw
            unit-variance noise, so ``beta != 1`` makes training and sampling
            inconsistent — confirm this is intended.
    """

    def __init__(self, start_shedule=0.0001, end_schedule=0.02, timesteps=300, beta = 2):
        self.start_schedule = start_shedule
        self.end_schedule = end_schedule
        self.timesteps = timesteps
        self.beta = beta

        self.betas = torch.linspace(self.start_schedule, self.end_schedule, self.timesteps)
        self.alphas = 1 - self.betas
        # Cumulative product abar_t = alpha_0 * ... * alpha_t.
        self.alpha_cumprod = torch.cumprod(self.alphas, dim=0)

    def forward(self, x0, t, device):
        """Noise ``x0`` directly to timestep ``t``.

        Args:
            x0: clean batch; its first dimension is the batch size.
            t: 1-D long tensor with one timestep per sample.
            device: device the outputs are moved to.

        Returns:
            ``(x_t, noise)`` with ``x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * noise``
            and ``noise = randn_like(x0) / beta``.
        """
        noise = torch.randn_like(x0)/self.beta
        sqrt_alphas_cumprod_t = self.get_index_from_list(self.alpha_cumprod.sqrt(), t, x0.shape)
        sqrt_one_minus_alphas_cumprod_t = self.get_index_from_list(torch.sqrt(1. - self.alpha_cumprod), t, x0.shape)
        mean = sqrt_alphas_cumprod_t.to(device) * x0.to(device)
        noise_term = sqrt_one_minus_alphas_cumprod_t.to(device) * noise.to(device)
        return mean + noise_term, noise.to(device)

    def backward(self, x, t, model, **kwargs):
        """One reverse step ``x_t -> x_{t-1}`` using the model's noise prediction.

        ``model(x, t, **kwargs)`` must return the predicted noise with the shape of ``x``.
        The mean is ``1/sqrt(alpha_t) * (x - beta_t / sqrt(1 - abar_t) * eps_hat)``.
        Samples with ``t > 0`` get Gaussian noise with variance ``beta_t`` added;
        samples at ``t == 0`` get the mean only. ``t`` may hold a different timestep
        per sample (the previous scalar ``if t == 0`` raised for batches of size > 1).
        """
        betas_t = self.get_index_from_list(self.betas, t, x.shape)
        sqrt_one_minus_alphas_cumprod_t = self.get_index_from_list(torch.sqrt(1. - self.alpha_cumprod), t, x.shape)
        sqrt_recip_alphas_t = self.get_index_from_list(torch.sqrt(1.0 / self.alphas), t, x.shape)
        mean = sqrt_recip_alphas_t * (x - betas_t * model(x, t, **kwargs) / sqrt_one_minus_alphas_cumprod_t)
        posterior_variance_t = betas_t

        is_last_step = (t == 0)
        if bool(is_last_step.all()):
            # Every sample is at the final step: no noise is drawn (same as before for batch size 1).
            return mean
        noise = torch.randn_like(x)
        # Zero the noise for samples already at t == 0 so mixed-timestep batches work.
        keep_noise = (~is_last_step).to(x.dtype).reshape(x.shape[0], *((1,) * (x.dim() - 1)))
        return mean + keep_noise * torch.sqrt(posterior_variance_t) * noise

    @staticmethod
    def get_index_from_list(values, t, x_shape):
        """Pick ``values[t]`` per sample and reshape to ``(batch, 1, 1, ...)`` for broadcasting.

        ``values`` is a 1-D schedule tensor; the gather runs on the CPU and the
        result is moved to ``t.device``. ``len(t)`` must equal ``x_shape[0]``.
        """
        batch_size = x_shape[0]
        result = values.gather(-1, t.cpu())

        return result.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)


In [None]:
# Noise schedule used for both training and sampling (the class defaults, spelled out).
diffusion_model = DiffusionModel(start_shedule=0.0001, end_schedule=0.02, timesteps=300, beta=2)

In [None]:
class SinusoidalPositionEmbeddings(nn.Module):
    """Transformer-style sinusoidal embedding of diffusion timesteps.

    Maps a 1-D tensor of ``n`` timesteps to an ``(n, dim)`` tensor: the first
    ``dim // 2`` columns are ``sin`` and the next ``dim // 2`` are ``cos`` features
    at geometrically spaced frequencies. For odd ``dim`` a zero column is appended
    so the width always equals ``dim`` (``Block.time_mlp`` is ``Linear(dim, ...)``).
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        # max(..., 1) avoids a ZeroDivisionError when half_dim == 1 (dim 2 or 3).
        scale = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=device) * -scale)
        args = time[:, None] * freqs[None, :]
        embeddings = torch.cat((args.sin(), args.cos()), dim=-1)
        if self.dim % 2:
            # Odd dim: pad one zero feature so the output width matches self.dim.
            pad = torch.zeros(embeddings.shape[0], 1, dtype=embeddings.dtype, device=device)
            embeddings = torch.cat((embeddings, pad), dim=-1)
        return embeddings

class Block(nn.Module):
    """One U-Net stage: two convolutions conditioned on the timestep (and optionally a
    label), followed by a stride-2 resampling convolution.

    Args:
        channels_in: input channels. Up-sampling blocks receive the concatenation of the
            previous output and the skip connection, so their ``conv1`` takes
            ``2 * channels_in`` channels.
        channels_out: output channels.
        time_embedding_dims: width of the sinusoidal timestep embedding.
        labels: if truthy, ``forward`` requires a ``labels=`` keyword argument and adds a
            learned projection of it to the features.
        num_filters: kernel size of ``conv1``; padding is ``num_filters // 2`` so odd
            kernel sizes keep H and W (the old hard-coded padding of 1 only did for 3).
        downsample: ``True`` halves H and W (stride-2 ``Conv2d``), ``False`` doubles them
            (stride-2 ``ConvTranspose2d``).
    """

    def __init__(self, channels_in, channels_out, time_embedding_dims, labels, num_filters = 3, downsample=True):
        super().__init__()

        self.time_embedding_dims = time_embedding_dims
        self.time_embedding = SinusoidalPositionEmbeddings(time_embedding_dims)
        self.labels = labels
        if labels:
            # NOTE(review): Linear(1, ...) implies labels shaped (batch, 1) — confirm with callers.
            self.label_mlp = nn.Linear(1, channels_out)

        self.downsample = downsample
        # "Same" padding for odd kernels; equals the previous hard-coded 1 for the default 3.
        padding = num_filters // 2

        if downsample:
            self.conv1 = nn.Conv2d(channels_in, channels_out, num_filters, padding=padding)
            self.final = nn.Conv2d(channels_out, channels_out, 4, 2, 1)
        else:
            self.conv1 = nn.Conv2d(2 * channels_in, channels_out, num_filters, padding=padding)
            self.final = nn.ConvTranspose2d(channels_out, channels_out, 4, 2, 1)

        self.bnorm1 = nn.BatchNorm2d(channels_out)
        self.bnorm2 = nn.BatchNorm2d(channels_out)

        self.conv2 = nn.Conv2d(channels_out, channels_out, 3, padding=1)
        self.time_mlp = nn.Linear(time_embedding_dims, channels_out)
        self.relu = nn.ReLU()

    def forward(self, x, t, **kwargs):
        o = self.bnorm1(self.relu(self.conv1(x)))
        # Project the timestep embedding to channels_out and broadcast it over H and W.
        o_time = self.relu(self.time_mlp(self.time_embedding(t)))
        o = o + o_time[(..., ) + (None, ) * 2]
        if self.labels:
            label = kwargs.get('labels')
            if label is None:
                raise ValueError("Block was built with labels=True but forward() got no 'labels' keyword argument")
            o_label = self.relu(self.label_mlp(label))
            o = o + o_label[(..., ) + (None, ) * 2]

        o = self.bnorm2(self.relu(self.conv2(o)))

        return self.final(o)

class UNet(nn.Module):
    """Small U-Net that predicts the noise contained in an image-shaped input.

    Encoder: an input 3x3 ``Conv2d`` followed by one down-sampling ``Block`` per
    consecutive pair of widths in ``sequence_channels``. Decoder: the mirrored
    up-sampling blocks, each fed the current features concatenated with the matching
    encoder output. A final 1x1 convolution maps back to ``img_channels``.

    With the default five widths there are four stride-2 stages, so H and W presumably
    need to be divisible by 16.
    """

    def __init__(self, img_channels = 1, time_embedding_dims = 128, labels = False, sequence_channels = (64, 128, 256, 512, 1024)):
        super().__init__()
        self.time_embedding_dims = time_embedding_dims

        down_pairs = list(zip(sequence_channels, sequence_channels[1:]))
        widths_rev = sequence_channels[::-1]
        up_pairs = list(zip(widths_rev, widths_rev[1:]))

        self.downsampling = nn.ModuleList(
            [Block(c_in, c_out, time_embedding_dims, labels) for c_in, c_out in down_pairs]
        )
        self.upsampling = nn.ModuleList(
            [Block(c_in, c_out, time_embedding_dims, labels, downsample=False) for c_in, c_out in up_pairs]
        )
        self.conv1 = nn.Conv2d(img_channels, sequence_channels[0], 3, padding=1)
        self.conv2 = nn.Conv2d(sequence_channels[0], img_channels, 1)

    def forward(self, x, t, **kwargs):
        features = self.conv1(x.float())
        skips = []
        for down in self.downsampling:
            features = down(features, t, **kwargs)
            skips.append(features)
        # Decoder consumes the encoder outputs deepest-first.
        for up, skip in zip(self.upsampling, skips[::-1]):
            features = up(torch.cat((features, skip), dim=1), t, **kwargs)
        return self.conv2(features)

In [None]:
# Training hyper-parameters.
# NOTE(review): NO_EPOCHS is not used by the training loop, which runs a literal 250 epochs.
LR = 0.0001
BATCH_SIZE = 128
NO_EPOCHS = 500
PRINT_FREQUENCY = 40
VERBOSE = True

# Fresh, unconditional U-Net and its optimiser.
unet = UNet(labels=False)
unet.to(device)
optimizer = torch.optim.Adam(unet.parameters(), lr=LR)


In [None]:
def plot_noise_distribution(noise, predicted_noise):
    """Overlay density histograms of the true and the predicted noise, then show the plot.

    Args:
        noise: ground-truth noise tensor added by the forward process.
        predicted_noise: the model's noise prediction.

    Both tensors are detached before ``.numpy()``, so tensors that still require grad
    (such as a fresh model output) are accepted instead of raising a RuntimeError.
    """
    plt.hist(noise.detach().cpu().numpy().flatten(), density = True, alpha = 0.8, label = "ground truth noise")
    plt.hist(predicted_noise.detach().cpu().numpy().flatten(), density = True, alpha = 0.8, label = "predicted noise")
    plt.xlabel("noise value")
    plt.ylabel("density")
    plt.legend()
    plt.show()


In [None]:
# Size of the training tensor fed to the DataLoader.
torch_image.size()

In [None]:
# Average per-channel extremes: for every sample and every EEG channel take the min/max
# over the last (time) axis, then average those values over all samples and channels.
# Previously only channel 0 of each sample (x_data_norm[i][0][0]) entered the average.
# Assumes x_data_norm is (n_samples, 1, n_channels, n_times), e.g. (150, 1, 16, 1024).
per_channel_min = x_data_norm.min(axis=-1)
per_channel_max = x_data_norm.max(axis=-1)
x_min_av = np.mean(per_channel_min)
x_max_av = np.mean(per_channel_max)

In [None]:
# Global extremes over the whole normalised dataset (all samples, channels and time points).
x_min_global, x_max_global = x_data_norm.min(), x_data_norm.max()

In [None]:
# Compare the three candidate (min, max) pairs used for reverse scaling further down.
for lo, hi in ((x_min_global, x_max_global), (x_min, x_max), (x_min_av, x_max_av)):
    print(lo, hi)

In [None]:
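# Alternative normalisation: scale with a global min/max instead.
# NOTE(review): `xmin`/`xmax` are not defined in this notebook (cf. x_min_global / x_max_global).
#tensor = h.transform_array_global(x_data, xmin, xmax)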

In [None]:
# NOTE: this import belongs in the first cell; kept here so the cell still runs on its own.
from torch.utils.data import TensorDataset, DataLoader

# Wrap the training tensor; iterating a DataLoader over it yields 1-element lists [batch].
dataset = TensorDataset(torch_image)

In [None]:
# Shuffled mini-batches of BATCH_SIZE (previously a duplicated literal 128).
# drop_last keeps every batch exactly BATCH_SIZE long.
# NOTE(review): with few samples (~150 per the comments) this leaves one batch per epoch and
# discards the remainder; with fewer than BATCH_SIZE samples it yields no batches at all.
trainloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=8, drop_last=True)

In [None]:
# Re-initialise a fresh model and optimiser right before training.
# NOTE(review): duplicates the earlier setup cell; running it discards any trained weights.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

LR = 0.0001
BATCH_SIZE = 128
NO_EPOCHS = 500
PRINT_FREQUENCY = 40
VERBOSE = True

unet = UNet(labels=False).to(device)
optimizer = torch.optim.Adam(unet.parameters(), lr=LR)


In [None]:
# Train the noise predictor: draw random timesteps, noise each batch with the forward
# process and regress the added noise with an MSE loss.
# NOTE(review): runs 250 epochs although NO_EPOCHS = 500 is configured — confirm which is intended.
unet.train()  # the sampling cell calls unet.eval(); re-running training must restore train mode
for epoch in range(250):
    mean_epoch_loss = []
    for batch in trainloader:
        batch = batch[0].to(device)
        # One timestep per sample, sized from the actual batch rather than BATCH_SIZE.
        t = torch.randint(0, diffusion_model.timesteps, (batch.shape[0],)).long().to(device)
        batch_noisy, noise = diffusion_model.forward(batch, t, device)
        predicted_noise = unet(batch_noisy, t)

        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(predicted_noise.float(), noise.float())
        mean_epoch_loss.append(loss.item())
        loss.backward()
        optimizer.step()
    if not mean_epoch_loss:
        raise RuntimeError("trainloader yielded no batches (dataset smaller than the batch size with drop_last=True?)")
    if epoch % PRINT_FREQUENCY == 0:
        print('---')
        print(f"Epoch: {epoch} | Train Loss {np.mean(mean_epoch_loss)}")
        if VERBOSE:
            # detach: predicted_noise still carries the autograd graph and .numpy() rejects such tensors.
            plot_noise_distribution(noise.detach(), predicted_noise.detach())

In [None]:
# Save the whole pickled nn.Module (architecture + weights), not only its state_dict.
torch.save(obj=unet, f='unet_conv2d_with_15_samples_znorm')

In [None]:
# Load a previously saved full-model checkpoint (a pickled nn.Module).
# NOTE(review): this loads '..._13samples', not the file written by the save cell above
# ('..._15_samples_znorm') — confirm which checkpoint is intended.
# map_location lets a GPU-trained checkpoint load on a CPU-only machine.
# torch.load unpickles arbitrary objects: only load checkpoints from trusted sources.
# On PyTorch >= 2.6 loading a full module additionally needs weights_only=False.
unet = torch.load('unet_conv2d_with_13samples', map_location=device)

In [None]:
# EEG channel names (same list as near the top; repeated so this section runs on its own).
ch_names = list(('Fp1', 'F3', 'C3', 'P3', 'O1', 'F7', 'Fz', 'Fp2', 'F4', 'C4', 'P4', 'O2', 'F8', 'T6', 'Cz', 'Pz'))

In [None]:
# Number of samples generated with the reverse diffusion process below.
NUM_DISPLAY_IMAGES: int = 5

In [None]:
# Uninitialised output buffer for the generated (1, 16, 1024) samples, allocated on `device`.
imgs_tensor = torch.empty((NUM_DISPLAY_IMAGES, 1, 16, 1024), device=device)

In [None]:
# Generate NUM_DISPLAY_IMAGES samples one at a time, running the reverse process from
# Gaussian noise down to t = 0.
# NOTE(review): sampling starts from unit-variance noise while DiffusionModel.forward
# trains on noise scaled by 1/beta — confirm this mismatch is intended.
unet.eval().to(device)  # loop-invariant: switch mode/device once instead of on every step
with torch.no_grad():
    for i in range(NUM_DISPLAY_IMAGES):
        img = torch.randn(1, 1, 16, 1024).to(device)
        for j in reversed(range(diffusion_model.timesteps)):
            t = torch.full((1,), j, dtype=torch.long, device=device)
            img = diffusion_model.backward(x=img, t=t, model=unet)
        # Keep only the final, fully denoised sample (was copied on every step before).
        imgs_tensor[i] = img[0]

In [None]:
# Number of generated samples held in the buffer.
imgs_tensor.size(0)

In [None]:
def reverse_z_normalisation(X_normalized, std_X, mean_X):
    """Return ``X_normalized * std_X + mean_X``, the inverse of a z-score normalisation.

    NOTE(review): identical re-definition of the function defined near the top of the
    notebook; one of the two can be removed.
    """
    scaled_back = X_normalized * std_X
    return scaled_back + mean_X

In [None]:
# Undo only the z-normalisation of the first generated sample and plot it.
# NOTE(review): the model was trained on h.transform_array output, so the min/max scaling
# is not undone here — compare with the next cell.
array = reverse_z_normalisation(imgs_tensor[0].view(16, 1024), std, mean)
raw = f.array_to_edf(array.cpu().numpy(), ch_names, 'eeg', 'test_eeg', sfreq=256)
raw.plot()

In [None]:
# Map the first generated sample back to signal scale with three candidate (min, max) pairs.
# Forward pipeline: x_data -> z-normalise -> h.transform_array -> model input, so the inverse
# must undo the min/max scaling FIRST and the z-normalisation SECOND (the order used to be reversed).
# NOTE(review): assumes h.reverse_transform_array returns z-space values that support
# elementwise arithmetic with std/mean — confirm against helper.py.
array = imgs_tensor[0].view(16, 1024)


def _sample_to_raw(sample, lo, hi):
    """Undo min/max scaling with (lo, hi), then the z-normalisation, and build a raw object."""
    z_space = h.reverse_transform_array(sample, lo, hi)
    signal = reverse_z_normalisation(z_space, std, mean)
    return f.array_to_edf(signal, ch_names, 'eeg', 'test.edf')


raw_global = _sample_to_raw(array, x_min_global, x_max_global)
raw_local = _sample_to_raw(array, x_min, x_max)
raw_av = _sample_to_raw(array, x_min_av, x_max_av)

In [None]:
# Plot the three reconstructions one after another.
for reconstruction in (raw_global, raw_local, raw_av):
    reconstruction.plot()
    plt.show()

In [None]:
# Undo only the min/max scaling of every generated sample, convert it and plot it.
# NOTE(review): despite the list name this uses the x_min/x_max returned by h.transform_array,
# not x_min_global/x_max_global, and the z-normalisation is not undone — confirm intent.
raw_list_global_min_max = []
for img in imgs_tensor:
    array = img.view(16, 1024)
    raw_array = h.reverse_transform_array(array, x_min, x_max, changed_shape=False)
    raw = f.array_to_edf(raw_array, ch_names, 'eeg', 'test.edf')
    raw_list_global_min_max.append(raw)
    raw.plot()

In [None]:
# Same as above, but undo the scaling with the averaged per-channel min/max.
raw_list_av_min_max = []
for img in imgs_tensor:
    array = img.view(16, 1024)
    raw_array = h.reverse_transform_array(array, x_min_av, x_max_av, changed_shape=False)
    raw = f.array_to_edf(raw_array, ch_names, 'eeg', 'test.edf')
    raw_list_av_min_max.append(raw)
    raw.plot()

In [None]:
def print_psd(raw_list):
    """Compute and plot the channel-averaged power spectral density of each raw object.

    Args:
        raw_list: iterable of objects exposing ``compute_psd()`` (e.g. MNE raw objects).

    A raw whose PSD cannot be computed is reported and skipped. Previously a bare
    ``except`` printed 'ok' and then plotted the spectrum left over from the previous
    item, or crashed with an unbound ``spectrum`` on the first item.
    """
    for raw in raw_list:
        try:
            spectrum = raw.compute_psd()
        except Exception as exc:
            print(f"compute_psd failed for {raw!r}: {exc}")
            continue
        spectrum.plot(average=True, picks="data", exclude="bads")
        plt.show()

In [None]:
# PSDs of the samples rescaled with x_min / x_max.
print_psd(raw_list=raw_list_global_min_max)

In [None]:
# PSDs of the samples rescaled with the averaged per-channel min / max.
print_psd(raw_list=raw_list_av_min_max)

In [None]:
# Reference: one original (un-normalised) training segment, for comparing its PSD with
# the generated samples. NOTE(review): index 5 looks like an arbitrary example — confirm.
raw_from_source = f.array_to_edf(x_data[5], ch_names, 'eeg', 'test_eeg',sfreq=256)

In [None]:
# PSD of the real reference segment.
print_psd(raw_list=[raw_from_source])