In [None]:
import torchvision.transforms as transforms
import torch.nn as nn
import torchvision
import math
import matplotlib.pyplot as plt
import torch
import urllib
import numpy as np
import PIL
import mne
import helper as h
import functionss as f

In [None]:
EDF_PATH = '/home/ubuntu/Diffusion/own_diffusion/test_EEGs/H S10 EC.edf'
# preload=True reads the samples into memory so channels can be dropped/renamed below.
raw_h = mne.io.read_raw_edf(EDF_PATH, preload=True)

In [None]:
# Channels that are not part of the EEG "image"; on_missing="ignore" keeps the drop
# harmless when the cell is re-run.
DROP_CHANNELS = [
    "EEG A2-A1",
    "EEG 23A-23R",
    "EEG 24A-24R",
    "EEG T6-LE",
    "EEG Cz-LE",
    "EEG Pz-LE",
]
# EDF labels ("EEG <site>-LE") -> bare 10-20 site names.
RENAME_CHANNELS = {
    "EEG Fp1-LE": "Fp1",
    "EEG F3-LE": "F3",
    "EEG C3-LE": "C3",
    "EEG P3-LE": "P3",
    "EEG O1-LE": "O1",
    "EEG F7-LE": "F7",
    "EEG Fz-LE": "Fz",
    "EEG Fp2-LE": "Fp2",
    "EEG F4-LE": "F4",
    "EEG C4-LE": "C4",
    "EEG P4-LE": "P4",
    "EEG O2-LE": "O2",
    "EEG F8-LE": "F8",
    "EEG T3-LE": "T3",
    "EEG T5-LE": "T5",
    "EEG T4-LE": "T4",
}

raw_h.drop_channels(DROP_CHANNELS, on_missing="ignore")
# Only rename labels that are still present: rename_channels rejects unknown channel
# names, so without this filter a second run of this cell (after the labels were
# already renamed) would fail.
raw_h = raw_h.rename_channels(
    {old: new for old, new in RENAME_CHANNELS.items() if old in raw_h.ch_names}
)

In [None]:
# Channel labels left after the drop/rename cell; reused below to rebuild mne Raw objects.
ch_names = raw_h.ch_names

In [None]:
# Number of kept channels. This becomes the image height, which the default U-Net
# (4 stride-2 stages) needs to be divisible by 16 — presumably 16 here, TODO confirm.
len(ch_names)

In [None]:
# Whole recording as an array (MNE's get_data layout: channels x time samples).
array = raw_h.get_data()

In [None]:
array.shape

In [None]:
# One 1024-sample window: samples 6144-7167 along the time axis.
# NOTE(review): the checkpoint / FIF names later in the notebook say "24_28", which would
# match a 256 Hz rate (seconds 24-28) — confirm against raw_h.info['sfreq'].
array_cut = array[:, 1024*6:1024*7]

In [None]:
# Rebuild an mne Raw for the window via the project helper (presumably it also writes
# 'test.edf' — verify in functionss.array_to_edf).
raw = f.array_to_edf(array_cut, ch_names, 'eeg', 'test.edf')

In [None]:
# Visual check of the training window.
raw.plot()

In [None]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
class DiffusionModel:
    """Linear-beta DDPM noise schedule with a forward (noising) and a backward
    (denoising) step.

    Args:
        start_shedule: first beta of the linear schedule (name, including the typo,
            kept for backward compatibility).
        end_schedule: last beta of the linear schedule.
        timesteps: number of diffusion steps T.
    """

    def __init__(self, start_shedule=0.0001, end_schedule=0.02, timesteps=300):
        self.start_schedule = start_shedule
        self.end_schedule = end_schedule
        self.timesteps = timesteps

        # Schedule tensors live on the CPU; get_index_from_list moves the gathered
        # values to the device of the timestep tensor.
        self.betas = torch.linspace(self.start_schedule, self.end_schedule, self.timesteps)
        self.alphas = 1 - self.betas
        self.alpha_cumprod = torch.cumprod(self.alphas, dim=0)

    def forward(self, x0, t, device):
        """Noise a clean batch x0 directly to timestep t.

        Args:
            x0: clean batch; the first dimension is the batch.
            t: 1-D long tensor with one timestep per batch element.
            device: device the returned tensors are moved to.

        Returns:
            (x_t, noise) where x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * noise.
        """
        # NOTE(review): the noise is scaled by 1/2 (std 0.5), whereas `backward` adds
        # unit-variance noise and sampling typically starts from torch.randn. Kept as-is
        # because existing checkpoints were trained this way — confirm it is intended.
        noise = torch.randn_like(x0) / 2
        sqrt_alphas_cumprod_t = self.get_index_from_list(self.alpha_cumprod.sqrt(), t, x0.shape)
        sqrt_one_minus_alphas_cumprod_t = self.get_index_from_list(torch.sqrt(1. - self.alpha_cumprod), t, x0.shape)
        mean = sqrt_alphas_cumprod_t.to(device) * x0.to(device)
        noise_term = sqrt_one_minus_alphas_cumprod_t.to(device) * noise.to(device)
        return mean + noise_term, noise.to(device)

    def backward(self, x, t, model, **kwargs):
        """One reverse-diffusion step x_t -> x_{t-1}.

        Calls `model(x, t, **kwargs)` to predict the noise, computes the posterior
        mean, and adds sqrt(beta_t) * N(0, 1) noise to every batch element whose
        timestep is not 0 (elements at t == 0 get the mean only).

        Args:
            x: current noisy batch; the first dimension is the batch.
            t: 1-D long tensor with one timestep per batch element.
            model: noise predictor called as model(x, t, **kwargs).
        """
        betas_t = self.get_index_from_list(self.betas, t, x.shape)
        sqrt_one_minus_alphas_cumprod_t = self.get_index_from_list(torch.sqrt(1. - self.alpha_cumprod), t, x.shape)
        sqrt_recip_alphas_t = self.get_index_from_list(torch.sqrt(1.0 / self.alphas), t, x.shape)
        predicted_noise = model(x, t, **kwargs)
        mean = sqrt_recip_alphas_t * (x - betas_t * predicted_noise / sqrt_one_minus_alphas_cumprod_t)
        # Reverse-step variance sigma_t^2 = beta_t.
        posterior_variance_t = betas_t

        # `if t == 0` only works for a single-element t; decide per sample instead so
        # batched sampling (identical or mixed timesteps) also works.
        is_last_step = t == 0
        if bool(is_last_step.all()):
            return mean
        noise = torch.randn_like(x)
        keep_noise = (~is_last_step).to(device=x.device, dtype=x.dtype)
        keep_noise = keep_noise.reshape(x.shape[0], *((1,) * (x.dim() - 1)))
        return mean + keep_noise * torch.sqrt(posterior_variance_t) * noise

    @staticmethod
    def get_index_from_list(values, t, x_shape):
        """Gather values[t] per batch element, shaped to broadcast against x_shape.

        Returns a (batch, 1, 1, ...) tensor on t's device.
        """
        batch_size = x_shape[0]
        result = values.gather(-1, t.cpu())
        return result.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)


In [None]:
# Linear beta schedule 1e-4 -> 0.02 over 300 steps (the class defaults, spelled out).
diffusion_model = DiffusionModel(start_shedule=0.0001, end_schedule=0.02, timesteps=300)

In [None]:
class SinusoidalPositionEmbeddings(nn.Module):
    """Transformer-style sinusoidal embedding of timesteps.

    For a 1-D tensor of timesteps, returns a (len(time), 2 * (dim // 2)) tensor whose
    first half is sin(t * w_k) and second half cos(t * w_k), with
    w_k = 10000 ** (-k / (dim // 2 - 1)) for k = 0 .. dim // 2 - 1.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        half = self.dim // 2
        log_step = math.log(10000) / (half - 1)
        # Geometric frequency ladder from 1 down to 1/10000, built on the input's device.
        freqs = torch.exp(-log_step * torch.arange(half, device=time.device))
        # Outer product: one row per timestep, one column per frequency.
        angles = time[:, None] * freqs[None, :]
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

class Block(nn.Module):
    """One U-Net stage: two 2-D convs with timestep (and optional label) conditioning,
    followed by a stride-2 down- or up-sampling conv.

    Args:
        channels_in: input channels. In upsampling mode the input is the concatenation
            of the previous output and a skip connection, so conv1 takes 2 * channels_in.
        channels_out: output channels.
        time_embedding_dims: size of the sinusoidal timestep embedding.
        labels: if truthy, a label passed as forward(..., labels=...) is projected with a
            Linear(1, channels_out) and added to the features.
        num_filters: kernel size of the first conv (despite the name). Padding is
            num_filters // 2, which preserves H and W for odd kernel sizes.
        downsample: True -> halve H and W with a stride-2 Conv2d; False -> double them
            with a stride-2 ConvTranspose2d.
    """

    def __init__(self, channels_in, channels_out, time_embedding_dims, labels, num_filters = 3, downsample=True):
        super().__init__()

        self.time_embedding_dims = time_embedding_dims
        self.time_embedding = SinusoidalPositionEmbeddings(time_embedding_dims)
        self.labels = labels
        if labels:
            # Assumes labels arrive as (batch, 1) float tensors — TODO confirm.
            self.label_mlp = nn.Linear(1, channels_out)

        self.downsample = downsample
        # "Same" padding for the first conv; a hard-coded padding of 1 is only correct
        # for the default kernel size of 3.
        padding = num_filters // 2

        if downsample:
            self.conv1 = nn.Conv2d(channels_in, channels_out, num_filters, padding=padding)
            self.final = nn.Conv2d(channels_out, channels_out, 4, 2, 1)
        else:
            self.conv1 = nn.Conv2d(2 * channels_in, channels_out, num_filters, padding=padding)
            self.final = nn.ConvTranspose2d(channels_out, channels_out, 4, 2, 1)

        self.bnorm1 = nn.BatchNorm2d(channels_out)
        self.bnorm2 = nn.BatchNorm2d(channels_out)

        self.conv2 = nn.Conv2d(channels_out, channels_out, 3, padding=1)
        self.time_mlp = nn.Linear(time_embedding_dims, channels_out)
        self.relu = nn.ReLU()

    def forward(self, x, t, **kwargs):
        """Apply the stage to a 4-D batch x (batch, C, H, W) at 1-D timesteps t."""
        o = self.bnorm1(self.relu(self.conv1(x)))
        # Project the timestep embedding to channels_out and broadcast it over H and W.
        o_time = self.relu(self.time_mlp(self.time_embedding(t)))
        o = o + o_time[(..., ) + (None, ) * 2]
        if self.labels:
            label = kwargs.get('labels')
            if label is None:
                raise ValueError("Block was built with labels=True but forward() got no 'labels' keyword")
            o_label = self.relu(self.label_mlp(label))
            o = o + o_label[(..., ) + (None, ) * 2]

        o = self.bnorm2(self.relu(self.conv2(o)))

        return self.final(o)

class UNet(nn.Module):
    """Small 2-D U-Net that predicts the noise contained in x at timestep t.

    Args:
        img_channels: channels of the input and output image.
        time_embedding_dims: size of the timestep embedding used by every Block.
        labels: whether the Blocks are label-conditioned (labels are then passed to
            forward as a `labels=` keyword).
        sequence_channels: channel widths of the encoder stages; the decoder mirrors them.

    Each of the len(sequence_channels) - 1 encoder Blocks maps H to floor(H / 2)
    (stride-2 conv, kernel 4, padding 1) and each decoder Block doubles it, so H and W
    must be divisible by 2 ** (len(sequence_channels) - 1) (16 with the defaults) for
    the skip connections to line up.
    """

    def __init__(self, img_channels = 1, time_embedding_dims = 128, labels = False, sequence_channels = (64, 128, 256, 512, 1024)):
        super().__init__()
        self.time_embedding_dims = time_embedding_dims
        reversed_channels = sequence_channels[::-1]

        self.downsampling = nn.ModuleList([
            Block(c_in, c_out, time_embedding_dims, labels)
            for c_in, c_out in zip(sequence_channels, sequence_channels[1:])
        ])
        self.upsampling = nn.ModuleList([
            Block(c_in, c_out, time_embedding_dims, labels, downsample=False)
            for c_in, c_out in zip(reversed_channels, reversed_channels[1:])
        ])
        self.conv1 = nn.Conv2d(img_channels, sequence_channels[0], 3, padding=1)
        # 1x1 projection back to the image channels.
        self.conv2 = nn.Conv2d(sequence_channels[0], img_channels, 1)

    def forward(self, x, t, **kwargs):
        """Predict the noise in x (batch, img_channels, H, W) at 1-D timesteps t.

        Extra keyword arguments (e.g. labels=...) are forwarded to every Block.
        """
        residuals = []
        o = self.conv1(x.float())
        # Encoder: keep each stage's output for the mirrored decoder stage.
        for ds in self.downsampling:
            o = ds(o, t, **kwargs)
            residuals.append(o)
        # Decoder: concatenate with the matching skip connection along channels.
        for us, res in zip(self.upsampling, reversed(residuals)):
            o = us(torch.cat((o, res), dim=1), t, **kwargs)

        return self.conv2(o)

In [None]:
# Seed PyTorch (all devices) so weight init, timestep sampling and noise are reproducible.
SEED = 0
torch.manual_seed(SEED)

unet = UNet(labels=False)
unet.to(device)

# Training hyper-parameters.
NO_EPOCHS = 500        # optimisation steps (the training loop draws one batch per epoch)
PRINT_FREQUENCY = 40   # log (and optionally plot) every N epochs
LR = 0.0001
BATCH_SIZE = 128       # copies of the single training window per step
optimizer = torch.optim.Adam(unet.parameters(), lr=LR)
VERBOSE= True          # plot noise histograms when logging


In [None]:
def plot_noise_distribution(noise, predicted_noise, bins=None):
    """Overlay density-normalised histograms of the true and the predicted noise.

    Args:
        noise: tensor of ground-truth noise (any shape, any device).
        predicted_noise: tensor of predicted noise (any shape, any device); it may
            still require grad.
        bins: histogram bins; None uses matplotlib's default.
    """
    # detach(): a tensor coming straight out of the model still requires grad, and
    # .numpy() refuses such tensors even inside a torch.no_grad() block.
    true_values = noise.detach().cpu().numpy().flatten()
    predicted_values = predicted_noise.detach().cpu().numpy().flatten()

    fig, ax = plt.subplots()
    ax.hist(true_values, bins=bins, density=True, alpha=0.8, label="ground truth noise")
    ax.hist(predicted_values, bins=bins, density=True, alpha=0.8, label="predicted noise")
    ax.set(title="Noise distribution", xlabel="value", ylabel="density")
    ax.legend()
    plt.show()


In [None]:
# Convert the window with the project helper; min_val / max_val are kept so the
# sampling cell can undo the transform (presumably a min-max scaling — verify in helper.py).
torch_image, min_val, max_val = h.transform_array(array_cut)
# Leading channel axis, so a stacked batch is (batch, 1, H, W) for the Conv2d U-Net.
torch_image = torch_image[None]
torch_image.shape

In [None]:
# Train the U-Net to predict the noise added to copies of the single training window.
# The batch is identical every epoch, so build it once.
batch = torch.stack([torch_image] * BATCH_SIZE)

# Sampling switches the network to eval mode; put BatchNorm back into training mode in
# case this cell is re-run afterwards.
unet.train()

# Uses the configured NO_EPOCHS (the loop was previously hard-coded to 300).
for epoch in range(NO_EPOCHS):
    mean_epoch_loss = []

    # One random timestep per batch element.
    t = torch.randint(0, diffusion_model.timesteps, (BATCH_SIZE,)).long().to(device)
    batch_noisy_images, noise = diffusion_model.forward(batch, t, device)
    predicted_noise = unet(batch_noisy_images, t)

    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(predicted_noise.float(), noise.float())
    mean_epoch_loss.append(loss.item())
    loss.backward()
    optimizer.step()

    if epoch % PRINT_FREQUENCY == 0:
        print('---')
        print(f"Epoch: {epoch} | Train Loss {np.mean(mean_epoch_loss)}")
        if VERBOSE:
            with torch.no_grad():
                # detach(): predicted_noise still requires grad, and converting it to
                # NumPy for plotting fails otherwise (no_grad does not change it).
                plot_noise_distribution(noise, predicted_noise.detach())

In [None]:
# Save the trained U-Net. torch.save on the module object pickles the whole nn.Module,
# so loading it requires the same class definitions (and full unpickling);
# unet.state_dict() would be a more portable checkpoint.
torch.save(unet, 'unet_conv2d_overfitting_HS10_dec 24_28')

In [None]:
import torch
import io
# The checkpoint is a pickled nn.Module (saved with torch.save(unet, ...)), so full
# unpickling is needed: PyTorch >= 2.6 defaults torch.load to weights_only=True, which
# rejects it. Only load checkpoints you created yourself — unpickling can run code.
# map_location lets a GPU-trained checkpoint load on a CPU-only machine.
# NOTE(review): `optimizer` still references the previous model's parameters; recreate
# it before training this loaded model further.
unet = torch.load('unet_conv2d_overfitting_HS10_dec 24_28', map_location=device, weights_only=False)


In [None]:
# Generate a window by running the reverse diffusion chain from random noise, keeping
# every SAVE_EVERY-th intermediate step as an mne Raw.
SAVE_EVERY = 50
# Sample at the same (H, W) the model was trained on.
n_channels, n_samples = torch_image.shape[-2:]

unet.eval()  # BatchNorm uses running statistics while sampling
with torch.no_grad():
    imgs = []
    raw_list = []
    # NOTE(review): DiffusionModel.forward trains with noise of std 0.5, but sampling
    # starts from unit-variance noise — confirm which scale is intended.
    img = torch.randn(1, 1, n_channels, n_samples).to(device)
    for i in reversed(range(diffusion_model.timesteps)):
        t = torch.full((1,), i, dtype=torch.long, device=device)
        img = diffusion_model.backward(img, t, unet)
        if i % SAVE_EVERY == 0:
            imgs.append(img[0])
            # (1, 1, H, W) -> (H, W). A separate name keeps the recording's `array`
            # from the earlier cells intact.
            sample_array = img[0].view(n_channels, n_samples)
            raw_array = h.reverse_transform_array(sample_array, min_val, max_val, changed_shape=False)
            # `raw` is rebound to each generated sample; later cells use the last one.
            raw = f.array_to_edf(raw_array, ch_names, 'eeg', 'test.edf')
            raw_list.append(raw)
            raw.plot()

In [None]:
# Plot the last stored sample — the one taken at the final reverse step (i == 0).
raw_list[-1].plot()

In [None]:
# Mean difference between the last stored diffusion output and `raw.get_data()`
# (`raw` is the last generated sample if the sampling cell has run).
# .cpu().numpy(): imgs holds (possibly CUDA) tensors, which cannot be combined with the
# NumPy array from get_data() directly. imgs[-1] is the final step (was imgs[5]).
# NOTE(review): the operands are taken before and after reverse_transform_array, i.e.
# on different scales — confirm this comparison is what was intended.
(imgs[-1].cpu().numpy() - raw.get_data()).mean()

In [None]:
def print_psd(raw_list):
    """Plot the channel-averaged power spectral density of each Raw in raw_list.

    Raws whose PSD cannot be computed are reported and skipped, instead of plotting a
    stale spectrum from a previous item (or failing on an unbound name).

    Args:
        raw_list: iterable of mne Raw objects.
    """
    for raw in raw_list:
        try:
            spectrum = raw.compute_psd()
        except Exception as err:
            print(f'Skipping PSD for {raw!r}: {err}')
            continue
        spectrum.plot(average=True, picks="data", exclude="bads")
        plt.show()

In [None]:
# PSD of every stored diffusion sample.
print_psd(raw_list)

In [None]:
# PSD of `raw`. If the sampling cell has run, this is the last generated sample rather
# than the training window built earlier.
print_psd([raw])

In [None]:
# Save the final generated sample (raw_list[-1]) as a FIF file; overwrite=True replaces an existing file.
raw_list[-1].save('/home/ubuntu/Diffusion/made_eegs/overfitted_conv2d_HS10_sec24_28.fif', overwrite=True)