# Features+Head Ensemble Starter [LB 0.34] for HMS Brain Comp
Features+Head is a combined ensemble starter notebook for Kaggle's HMS brain competition. We can train 4 different models using:
- Kaggle's spectrograms (CV 0.6123 – LB 0.41)
- Chris's EEG spectrograms (modified version) (CV 0.6288 – LB 0.39)
- Both Kaggle and EEG spectrograms (CV 0.5768 – LB 0.37)
- Chris's [WaveNet][4] (CV 0.6992 – LB 0.41)

**The Ensemble achieves LB 0.34** 

Great discussion [here][5] by @KOLOO that led to the latest score!

Features+Head Starter uses Chris Deotte's Kaggle dataset [here][1]. It also uses a modified version of Chris's EEG spectrograms [here][3].

### Train and Infer Tips

This notebook can be used both to train and to submit (infer) to the Kaggle LB. When training, set `submission = False`. You can also set `TEST_MODE = True` to load only 500 samples instead of the whole dataset, which makes quick testing possible.

To train a specific model type, set `DATA_TYPE` to one of `'both'`, `'eeg'`, `'kaggle'`, or `'raw'`:
- `kaggle` trains on Kaggle's spectrograms
- `eeg` trains on the EEG spectrograms
- `both` trains on Kaggle's and the EEG spectrograms
- `raw` trains on the raw EEG signal with WaveNet

To submit after training, save the models in the `LOAD_MODELS_FROM` dataset, then run this notebook with `submission = True`.

Once all the models are saved to `LOAD_MODELS_FROM` and ready to ensemble, set `submission = True` and `ENSEMBLE = True`, then set the model versions specified earlier, along with their `LBs` for the weighted ensemble.
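
The text above does not state how the `LBs` are turned into ensemble weights. As a minimal sketch, assuming inverse-LB weighting (lower LB is better, so it gets a larger weight), the blend could look like this. The helper names `inverse_lb_weights` and `blend` and the prediction values are hypothetical and are not taken from this notebook.

```python
def inverse_lb_weights(lb_scores):
    """Turn leaderboard scores (lower is better) into weights that sum to 1."""
    raw = [1.0 / s for s in lb_scores]
    total = sum(raw)
    return [r / total for r in raw]


def blend(predictions, weights):
    """Weighted average of per-model class probabilities for one sample."""
    n_classes = len(predictions[0])
    mixed = [
        sum(w * p[c] for w, p in zip(weights, predictions))
        for c in range(n_classes)
    ]
    norm = sum(mixed)  # renormalise to guard against rounding drift
    return [m / norm for m in mixed]


# LB scores of the four single models listed at the top of this notebook.
lb_scores = [0.41, 0.39, 0.37, 0.41]
weights = inverse_lb_weights(lb_scores)

# Made-up predictions over the 6 vote classes, one row per model.
predictions = [
    [0.10, 0.10, 0.20, 0.10, 0.10, 0.40],
    [0.05, 0.15, 0.25, 0.05, 0.10, 0.40],
    [0.10, 0.05, 0.30, 0.05, 0.15, 0.35],
    [0.15, 0.10, 0.20, 0.10, 0.05, 0.40],
]
blended = blend(predictions, weights)
print([round(w, 3) for w in weights])
print([round(b, 3) for b in blended])
```

Other schemes (equal weights, or weights tuned on out-of-fold predictions) fit the same `blend` function; only the weight vector changes.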

This notebook is kept as generic as possible so it is easy to extend and to try different experiments.

What you could do:
- Swap the EfficientNetB(0-7) backbone with `LOAD_BACKBONE_FROM`
- Enable data augmentation by setting the DataGenerator's parameter `augment = True`
- Try different image configurations as input.
- Tune the WaveNet model.


This notebook is a direct descendant of Chris's notebook [here][2].

[1]: https://www.kaggle.com/datasets/cdeotte/brain-spectrograms
[2]: https://www.kaggle.com/code/cdeotte/efficientnetb2-starter-lb-0-57
[3]: https://www.kaggle.com/datasets/nartaa/eeg-spectrograms
[4]: https://www.kaggle.com/competitions/hms-harmful-brain-activity-classification/discussion/468684
[5]: https://www.kaggle.com/competitions/hms-harmful-brain-activity-classification/discussion/477461

In [1]:
import os, random
import tensorflow as tf
import tensorflow
import tensorflow.keras.backend as K
import pandas as pd, numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model


VER = 45
DATA_TYPE = 'both' # both|eeg|kaggle|raw
TEST_MODE = False
submission = False


np.random.seed(21)
random.seed(21)
tf.random.set_seed(21)

# USE SINGLE GPU, MULTIPLE GPUS 
gpus = tf.config.list_physical_devices('GPU')
# WE USE MIXED PRECISION
tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})
if len(gpus)>1:
    strategy = tf.distribute.MirroredStrategy()
    print(f'Using {len(gpus)} GPUs')
else:
    # FALL BACK TO THE CPU WHEN NO GPU IS VISIBLE
    strategy = tf.distribute.OneDeviceStrategy(device="/gpu:0" if gpus else "/cpu:0")
    print(f'Using {len(gpus)} GPU')

Using 0 GPU


# Load and Create Non-Overlapping EEG ID Train Data
The competition data description says that test data does not have multiple crops from the same `eeg_id`. Therefore we will train and validate using only 1 crop per `eeg_id`. There is a discussion about this [here][1].

[1]: https://www.kaggle.com/competitions/hms-harmful-brain-activity-classification/discussion/467021
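
To make the "1 crop per `eeg_id`" idea concrete, here is a small sketch on a toy table: rows that share an `eeg_id` are collapsed into one row, metadata keeps its first value, votes are summed and then normalised to a probability distribution. The table `toy`, the result `agg`, and the shortened `TARGETS` list are made up for illustration only.

```python
import pandas as pd

TARGETS = ['seizure_vote', 'other_vote']  # shortened for the toy example

# Three labelled crops: two from eeg_id 1, one from eeg_id 2 (made-up values).
toy = pd.DataFrame({
    'eeg_id': [1, 1, 2],
    'spectrogram_id': [10, 10, 20],
    'seizure_vote': [3, 1, 0],
    'other_vote': [0, 2, 5],
})

# One row per eeg_id: keep the first metadata value, sum the votes.
agg = toy.groupby('eeg_id').agg(
    {'spectrogram_id': 'first', **{t: 'sum' for t in TARGETS}}
).reset_index()

# Normalise the summed votes so each row sums to 1.
agg[TARGETS] = agg[TARGETS] / agg[TARGETS].values.sum(axis=1, keepdims=True)
print(agg)
```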

In [2]:
TARGETS = ['seizure_vote', 'lpd_vote', 'gpd_vote', 'lrda_vote', 'grda_vote', 'other_vote']
FEATS2 = ['Fp1','T3','C3','O1','Fp2','C4','T4','O2']
FEAT2IDX = {x:y for x,y in zip(FEATS2,range(len(FEATS2)))}

def eeg_from_parquet(parquet_path):

    eeg = pd.read_parquet(parquet_path, columns=FEATS2)
    rows = len(eeg)
    offset = (rows-10_000)//2
    eeg = eeg.iloc[offset:offset+10_000]
    data = np.zeros((10_000,len(FEATS2)))
    for j,col in enumerate(FEATS2):
        
        # FILL NAN
        x = eeg[col].values.astype('float32')
        m = np.nanmean(x)
        if np.isnan(x).mean()<1: x = np.nan_to_num(x,nan=m)
        else: x[:] = 0
        
        data[:,j] = x

    return data

def add_kl(data):
    import torch
    labels = data[TARGETS].values + 1e-5

    # compute kl-loss with uniform distribution by pytorch
    data['kl'] = torch.nn.functional.kl_div(
        torch.log(torch.tensor(labels)),
        torch.tensor([1 / 6] * 6),
        reduction='none'
    ).sum(dim=1).numpy()
    return data

def reset_seed(seed):
    np.random.seed(seed)
    random.seed(seed)
    tf.random.set_seed(seed)
    
if not submission:
    train = pd.read_csv('/scratch/eecs545w24_class_root/eecs545w24_class/shared_data/hms_data/raw_data/train.csv')
    TARGETS = ['seizure_vote', 'lpd_vote', 'gpd_vote', 'lrda_vote', 'grda_vote', 'other_vote']
    META = ['spectrogram_id','spectrogram_label_offset_seconds','patient_id','expert_consensus']
    train = train.groupby('eeg_id')[META+TARGETS
                           ].agg({**{m:'first' for m in META},**{t:'sum' for t in TARGETS}}).reset_index() 
    train[TARGETS] = train[TARGETS]/train[TARGETS].values.sum(axis=1,keepdims=True)
    train.columns = ['eeg_id','spec_id','offset','patient_id','target'] + TARGETS
    train = add_kl(train)
    print(train.head(1).to_string())

   eeg_id    spec_id  offset  patient_id target  seizure_vote  lpd_vote  gpd_vote  lrda_vote  grda_vote  other_vote        kl
0  568657  789577333     0.0       20654  Other           0.0       0.0      0.25        0.0   0.166667    0.583333  4.584192
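
The `kl` column printed above comes from `add_kl`, which passes the log of the smoothed votes as input and the uniform distribution as target, so it measures KL(uniform || votes). The sketch below recomputes that value in plain Python for the printed row, assuming the displayed 0.166667 and 0.583333 are exactly 1/6 and 7/12. The function name `kl_uniform_to_votes` is hypothetical.

```python
import math


def kl_uniform_to_votes(votes, eps=1e-5):
    """KL(uniform || votes + eps), the same argument order as in add_kl."""
    u = 1.0 / len(votes)
    return sum(u * (math.log(u) - math.log(v + eps)) for v in votes)


# Vote distribution of the row printed above (eeg_id 568657).
row = [0.0, 0.0, 0.25, 0.0, 1 / 6, 7 / 12]
kl = kl_uniform_to_votes(row)
print(round(kl, 4))

# A row where the experts are split evenly gives a value close to zero.
kl_flat = kl_uniform_to_votes([1 / 6] * 6)
print(round(kl_flat, 4))
```

Rows where the votes concentrate on few classes get a large `kl`, while rows with evenly split votes get a value near zero.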


# Read Train Spectrograms and EEGs

We read three inputs: Chris's [Kaggle dataset here][1], which contains all the 11k spectrograms; the modified version of Chris's EEG spectrograms [here][2]; and Chris's EEG signals [here][3].

[1]: https://www.kaggle.com/datasets/cdeotte/brain-spectrograms
[2]: https://www.kaggle.com/datasets/nartaa/eeg-spectrograms
[3]: https://www.kaggle.com/datasets/cdeotte/brain-eegs

In [None]:
%%time
if not submission:
    # FOR TESTING SET TEST_MODE TO TRUE
    if TEST_MODE:
        train = train.sample(500,random_state=42).reset_index(drop=True)
        spectrograms = {}
        for i,e in enumerate(train.spec_id.values):
            if i%100==0: print(i,', ',end='')
            x = pd.read_parquet(f'/kaggle/input/hms-harmful-brain-activity-classification/train_spectrograms/{e}.parquet')
            spectrograms[e] = x.values
        all_eegs = {}
        for i,e in enumerate(train.eeg_id.values):
            if i%100==0: print(i,', ',end='')
            x = np.load(f'/kaggle/input/eeg-spectrograms/EEG_Spectrograms/{e}.npy')
            all_eegs[e] = x
        all_raw_eegs = {}
        for i,e in enumerate(train.eeg_id.values):
            if i%100==0: print(i,', ',end='')
            x = eeg_from_parquet(f'/kaggle/input/hms-harmful-brain-activity-classification/train_eegs/{e}.parquet')              
            all_raw_eegs[e] = x
    else:
        spectrograms = None
        all_eegs = None
        all_raw_eegs = None
        if DATA_TYPE=='both' or DATA_TYPE=='kaggle':
            spectrograms = np.load('/scratch/eecs545w24_class_root/eecs545w24_class/shared_data/HMSEnsemble/brain-spectrograms/specs.npy',allow_pickle=True).item()
        if DATA_TYPE=='both' or DATA_TYPE=='eeg':
            all_eegs = np.load('/scratch/eecs545w24_class_root/eecs545w24_class/shared_data/HMSEnsemble/eeg-spectrograms/eeg_specs.npy',allow_pickle=True).item()
        if DATA_TYPE=='raw':
            all_raw_eegs = np.load('/kaggle/input/brain-eegs/eegs.npy',allow_pickle=True).item()

# DATA GENERATOR
This data generator outputs 512x512x3 images; the Kaggle spectrogram and EEG spectrogram images are concatenated together into a single image (for `DATA_TYPE = 'raw'` it outputs a 2000x8 signal instead). To use data augmentation, set `augment = True` when creating the train data generator.
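
As a rough sketch of how one spectrogram tile ends up in the 512x512x3 image, the steps are: clip, take the log, min-max scale to the 0-255 range, crop 300 time steps down to 256, and paste the tile into a fixed row band of one channel. The helper name `to_image_range` and the random input are assumptions for illustration; the generator itself reads real spectrograms and fills many more tiles.

```python
import numpy as np


def to_image_range(spec, eps=1e-5):
    """Clip, log-transform and min-max scale a spectrogram to roughly [0, 255]."""
    img = np.clip(spec, np.exp(-4), np.exp(8))
    img = np.log(img)
    img = np.nan_to_num(img, nan=0.0)
    mn, mx = img.min(), img.max()
    return 255 * (img - mn) / (mx - mn + eps)


# Synthetic stand-in for one 100 (frequency) x 300 (time) spectrogram tile.
rng = np.random.default_rng(0)
tile = rng.uniform(0.0, 100.0, size=(100, 300)).astype('float32')

canvas = np.zeros((512, 512, 3), dtype='float32')
scaled = to_image_range(tile)

# Drop 22 time steps on each side (300 -> 256), then paste the tile into
# rows 56..155 of the left half of channel 0.
canvas[56:156, :256, 0] = scaled[:, 22:-22]
print(canvas.shape, float(scaled.min()), float(scaled.max()))
```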

In [None]:
import albumentations as albu
from scipy.signal import butter, lfilter

class DataGenerator():
    'Generates data for Keras'
    def __init__(self, data, specs=None, eeg_specs=None, raw_eegs=None, augment=False, mode='train', data_type=DATA_TYPE): 
        self.data = data
        self.augment = augment
        self.mode = mode
        self.data_type = data_type
        self.specs = specs
        self.eeg_specs = eeg_specs
        self.raw_eegs = raw_eegs
        self.on_epoch_end()
        
    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index):
        X, y = self.data_generation(index)
        if self.augment: X = self.augmentation(X)
        return X, y
    
    def __call__(self):
        for i in range(self.__len__()):
            yield self.__getitem__(i)
            
            if i == self.__len__()-1:
                self.on_epoch_end()
                
    def on_epoch_end(self):
        if self.mode=='train': 
            self.data = self.data.sample(frac=1).reset_index(drop=True)
    
    def data_generation(self, index):
        if self.data_type == 'both':
            X,y = self.generate_all_specs(index)
        elif self.data_type == 'eeg' or self.data_type == 'kaggle':
            X,y = self.generate_specs(index)
        elif self.data_type == 'raw':
            X,y = self.generate_raw(index)

        return X,y
    
    def generate_all_specs(self, index):
        X = np.zeros((512,512,3),dtype='float32')
        y = np.zeros((6,),dtype='float32')
        
        row = self.data.iloc[index]
        if self.mode=='test': 
            offset = 0
        else:
            offset = int(row.offset/2)
            
        eeg = self.eeg_specs[row.eeg_id]
        spec = self.specs[row.spec_id]
        
        imgs = [spec[offset:offset+300,k*100:(k+1)*100].T for k in [0,2,1,3]] # to match kaggle with eeg
        img = np.stack(imgs,axis=-1)
        # LOG TRANSFORM SPECTROGRAM
        img = np.clip(img,np.exp(-4),np.exp(8))
        img = np.log(img)
            
        # STANDARDIZE PER IMAGE
        img = np.nan_to_num(img, nan=0.0)    
            
        mn = img.flatten().min()
        mx = img.flatten().max()
        ep = 1e-5
        img = 255 * (img - mn) / (mx - mn + ep)
        
        X[0_0+56:100+56,:256,0] = img[:,22:-22,0] # LL_k
        X[100+56:200+56,:256,0] = img[:,22:-22,2] # RL_k
        X[0_0+56:100+56,:256,1] = img[:,22:-22,1] # LP_k
        X[100+56:200+56,:256,1] = img[:,22:-22,3] # RP_k
        X[0_0+56:100+56,:256,2] = img[:,22:-22,2] # RL_k
        X[100+56:200+56,:256,2] = img[:,22:-22,1] # LP_k
        
        X[0_0+56:100+56,256:,0] = img[:,22:-22,0] # LL_k
        X[100+56:200+56,256:,0] = img[:,22:-22,2] # RL_k
        X[0_0+56:100+56,256:,1] = img[:,22:-22,1] # LP_k
        X[100+56:200+56,256:,1] = img[:,22:-22,3] # RP_K
        
        # EEG
        img = eeg
        mn = img.flatten().min()
        mx = img.flatten().max()
        ep = 1e-5
        img = 255 * (img - mn) / (mx - mn + ep)
        X[200+56:300+56,:256,0] = img[:,22:-22,0] # LL_e
        X[300+56:400+56,:256,0] = img[:,22:-22,2] # RL_e
        X[200+56:300+56,:256,1] = img[:,22:-22,1] # LP_e
        X[300+56:400+56,:256,1] = img[:,22:-22,3] # RP_e
        X[200+56:300+56,:256,2] = img[:,22:-22,2] # RL_e
        X[300+56:400+56,:256,2] = img[:,22:-22,1] # LP_e
        
        X[200+56:300+56,256:,0] = img[:,22:-22,0] # LL_e
        X[300+56:400+56,256:,0] = img[:,22:-22,2] # RL_e
        X[200+56:300+56,256:,1] = img[:,22:-22,1] # LP_e
        X[300+56:400+56,256:,1] = img[:,22:-22,3] # RP_e

        if self.mode!='test':
            y[:] = row[TARGETS]
        
        return X,y
    
    def generate_specs(self, index):
        X = np.zeros((512,512,3),dtype='float32')
        y = np.zeros((6,),dtype='float32')
        
        row = self.data.iloc[index]
        if self.mode=='test': 
            offset = 0
        else:
            offset = int(row.offset/2)
            
        if self.data_type == 'eeg':
            img = self.eeg_specs[row.eeg_id]
        elif self.data_type == 'kaggle':
            spec = self.specs[row.spec_id]
            imgs = [spec[offset:offset+300,k*100:(k+1)*100].T for k in [0,2,1,3]] # to match kaggle with eeg
            img = np.stack(imgs,axis=-1)
            # LOG TRANSFORM SPECTROGRAM
            img = np.clip(img,np.exp(-4),np.exp(8))
            img = np.log(img)
            
            # STANDARDIZE PER IMAGE
            img = np.nan_to_num(img, nan=0.0)    
            
        mn = img.flatten().min()
        mx = img.flatten().max()
        ep = 1e-5
        img = 255 * (img - mn) / (mx - mn + ep)
        
        X[0_0+56:100+56,:256,0] = img[:,22:-22,0]
        X[100+56:200+56,:256,0] = img[:,22:-22,2]
        X[0_0+56:100+56,:256,1] = img[:,22:-22,1]
        X[100+56:200+56,:256,1] = img[:,22:-22,3]
        X[0_0+56:100+56,:256,2] = img[:,22:-22,2]
        X[100+56:200+56,:256,2] = img[:,22:-22,1]
        
        X[0_0+56:100+56,256:,0] = img[:,22:-22,0]
        X[100+56:200+56,256:,0] = img[:,22:-22,1]
        X[0_0+56:100+56,256:,1] = img[:,22:-22,2]
        X[100+56:200+56,256:,1] = img[:,22:-22,3]
        
        X[200+56:300+56,:256,0] = img[:,22:-22,0]
        X[300+56:400+56,:256,0] = img[:,22:-22,1]
        X[200+56:300+56,:256,1] = img[:,22:-22,2]
        X[300+56:400+56,:256,1] = img[:,22:-22,3]
        X[200+56:300+56,:256,2] = img[:,22:-22,3]
        X[300+56:400+56,:256,2] = img[:,22:-22,2]
        
        X[200+56:300+56,256:,0] = img[:,22:-22,0]
        X[300+56:400+56,256:,0] = img[:,22:-22,2]
        X[200+56:300+56,256:,1] = img[:,22:-22,1]
        X[300+56:400+56,256:,1] = img[:,22:-22,3]
        
        if self.mode!='test':
            y[:] = row[TARGETS]
        
        return X,y
    
    def generate_raw(self,index):
        X = np.zeros((10_000,8),dtype='float32')
        y = np.zeros((6,),dtype='float32')
        
        row = self.data.iloc[index]
        eeg = self.raw_eegs[row.eeg_id]
            
        # FEATURE ENGINEER
        X[:,0] = eeg[:,FEAT2IDX['Fp1']] - eeg[:,FEAT2IDX['T3']]
        X[:,1] = eeg[:,FEAT2IDX['T3']] - eeg[:,FEAT2IDX['O1']]
            
        X[:,2] = eeg[:,FEAT2IDX['Fp1']] - eeg[:,FEAT2IDX['C3']]
        X[:,3] = eeg[:,FEAT2IDX['C3']] - eeg[:,FEAT2IDX['O1']]
            
        X[:,4] = eeg[:,FEAT2IDX['Fp2']] - eeg[:,FEAT2IDX['C4']]
        X[:,5] = eeg[:,FEAT2IDX['C4']] - eeg[:,FEAT2IDX['O2']]
            
        X[:,6] = eeg[:,FEAT2IDX['Fp2']] - eeg[:,FEAT2IDX['T4']]
        X[:,7] = eeg[:,FEAT2IDX['T4']] - eeg[:,FEAT2IDX['O2']]
            
        # STANDARDIZE
        X = np.clip(X,-1024,1024)
        X = np.nan_to_num(X, nan=0) / 32.0
            
        # BUTTER LOW-PASS FILTER
        X = self.butter_lowpass_filter(X)
        # Downsample
        X = X[::5,:]
        
        if self.mode!='test':
            y[:] = row[TARGETS]
                
        return X,y
        
    def butter_lowpass_filter(self, data, cutoff_freq=20, sampling_rate=200, order=4):
        nyquist = 0.5 * sampling_rate
        normal_cutoff = cutoff_freq / nyquist
        b, a = butter(order, normal_cutoff, btype='low', analog=False)
        filtered_data = lfilter(b, a, data, axis=0)
        return filtered_data
    
    def resize(self, img,size):
        composition = albu.Compose([
                albu.Resize(size[0],size[1])
            ])
        return composition(image=img)['image']
            
    def augmentation(self, img):
        composition = albu.Compose([
                albu.HorizontalFlip(p=0.4)
            ])
        return composition(image=img)['image']

In [None]:
class VAEDataGenerator(DataGenerator):
    def __getitem__(self,index):
        x,y=super().__getitem__(index)
        return (x)/256,(x)/256

# DISPLAY DATA GENERATOR
Below we display example data generator spectrogram images and raw EEG signals.

In [None]:
if not submission and DATA_TYPE!='raw':
    gen = DataGenerator(train, augment=False, specs=spectrograms, eeg_specs=all_eegs, data_type=DATA_TYPE)
    for x,y in gen:
        break
    # SHOW EACH OF THE 3 IMAGE CHANNELS
    for c in range(3):
        plt.imshow(x[:,:,c])
        plt.title(f'Target = {y.round(1)}',size=12)
        plt.yticks([])
        plt.ylabel('Frequencies (Hz)',size=12)
        plt.xlabel('Time (sec)',size=12)
        plt.show()
    
if not submission and DATA_TYPE=='raw':
    gen = DataGenerator(train, raw_eegs=all_raw_eegs, data_type=DATA_TYPE)
    for x,y in gen:
        plt.figure(figsize=(20,4))
        offset = 0
        for j in range(x.shape[-1]):
            if j!=0: offset -= x[:,j].min()
            plt.plot(range(2_000),x[:,j]+offset,label=f'feature {j+1}')
            offset += x[:,j].max()
        plt.legend()
        plt.show()
        break

# Autoencoder

In [8]:
"""def halved_glorot_uniform(shape, dtype=None):
    initializer = tf.keras.initializers.GlorotUniform()
    weights = initializer(shape, dtype)
    return weights / 2.0
tf.keras.layers.Dense.Conv2D = halved_glorot_uniform"""


In [9]:
import tensorflow as tf

class ResNetBlock(tf.keras.layers.Layer):
    def __init__(self, in_channels, kernel_size, modify=False, bn=True):
        super(ResNetBlock, self).__init__()
        self.modify = modify
        if modify == 'downsample':
            self.conv1 = tf.keras.layers.Conv2D(in_channels*2, kernel_size, strides=2, padding='same', use_bias=False, activation=tf.nn.relu, kernel_regularizer=tf.keras.regularizers.l2(0.005))
            self.conv2 = tf.keras.layers.Conv2D(in_channels*2, kernel_size, padding='same', use_bias=False, activation=tf.nn.relu, kernel_regularizer=tf.keras.regularizers.l2(0.005))
            if bn:
                self.bn1 = tf.keras.layers.BatchNormalization()
                self.bn2 = tf.keras.layers.BatchNormalization()
            else:
                self.bn1 = tf.keras.layers.Layer()
                self.bn2 = tf.keras.layers.Layer()
        elif modify == 'upsample':
            self.conv1 = tf.keras.layers.Conv2DTranspose(in_channels//2, kernel_size, strides=2, padding='same', output_padding=1, use_bias=False, activation=tf.nn.relu, kernel_regularizer=tf.keras.regularizers.l2(0.005))
            self.conv2 = tf.keras.layers.Conv2D(in_channels//2, kernel_size, padding='same', use_bias=False, activation=tf.nn.relu, kernel_regularizer=tf.keras.regularizers.l2(0.005))
            self.bn1 = tf.keras.layers.BatchNormalization()
            self.bn2 = tf.keras.layers.BatchNormalization()
        else:
            self.conv1 = tf.keras.layers.Conv2D(in_channels, kernel_size, padding='same', activation=tf.nn.relu, kernel_regularizer=tf.keras.regularizers.l2(0.005))
            self.conv2 = tf.keras.layers.Conv2D(in_channels, kernel_size, padding='same', activation=tf.nn.relu, kernel_regularizer=tf.keras.regularizers.l2(0.005))
            self.bn1 = tf.keras.layers.BatchNormalization()
            self.bn2 = tf.keras.layers.BatchNormalization()
        self.act = tf.keras.layers.ReLU()
        if modify == 'downsample':
            self.proj = tf.keras.layers.Conv2D(in_channels*2, kernel_size, strides=2, padding='same', activation=tf.nn.relu, kernel_regularizer=tf.keras.regularizers.l2(0.005))
        if modify == 'upsample':
            self.proj = tf.keras.layers.Conv2DTranspose(in_channels//2, kernel_size, strides=2, padding='same', output_padding=1, activation=tf.nn.relu, kernel_regularizer=tf.keras.regularizers.l2(0.005))

    def call(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.act(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.modify:
            x = self.proj(x)
        out = x + out
        out = self.act(out)
        return out


In [10]:
from sklearn.model_selection import train_test_split
# FIX THE SPLIT SEED SO THE TRAIN/TEST SPLIT IS REPRODUCIBLE
train, test = train_test_split(train, train_size=0.80, random_state=21)

In [11]:
x_train = VAEDataGenerator(train, augment=False, specs=spectrograms, eeg_specs=all_eegs, raw_eegs=all_raw_eegs)
x_train = tf.data.Dataset.from_generator(generator=x_train, 
                                               output_signature=(tf.TensorSpec(shape=(512,512,3), dtype=tf.float32),
                                                                 tf.TensorSpec(shape=(512,512,3), dtype=tf.float32))).batch(16).prefetch(tf.data.AUTOTUNE)
x_test = VAEDataGenerator(test, augment=False, specs=spectrograms, eeg_specs=all_eegs, raw_eegs=all_raw_eegs)
x_test= tf.data.Dataset.from_generator(generator=x_test, 
                                               output_signature=(tf.TensorSpec(shape=(512,512,3), dtype=tf.float32),
                                                                 tf.TensorSpec(shape=(512,512,3), dtype=tf.float32))).batch(16).prefetch(tf.data.AUTOTUNE)

In [None]:
for x,y in x_train:
    print(x.shape)
    print(y.shape)
    print(np.max(x))
    print(np.min(x))
    break

2024-03-28 10:24:24.052828: W tensorflow/core/grappler/optimizers/auto_mixed_precision.cc:2303] No (suitable) GPUs detected, skipping auto_mixed_precision graph optimizer


In [None]:
latent_space_dim = 512*3

In [None]:
inputs = tf.keras.layers.Input(shape=(512,512,3))
conv = tf.keras.layers.Conv2D(3, 7, 1, padding='same')(inputs)
eres1 = ResNetBlock(16, 3, modify='downsample')(conv)
eres2 = ResNetBlock(32, 3, modify='downsample')(eres1)
eres3 = ResNetBlock(64, 3, modify='downsample')(eres2)
eres4 = ResNetBlock(128, 3, modify='downsample')(eres3)
eres5 = ResNetBlock(256, 3, modify='downsample')(eres4)
eres6 = ResNetBlock(512, 3, modify='downsample')(eres5)

shape_before_flatten = tensorflow.keras.backend.int_shape(eres6)[1:]
encoder_flatten = tensorflow.keras.layers.Flatten()(eres6)

z_mean_l  = tensorflow.keras.layers.Dense(units=latent_space_dim, name="encoder_mu")
z_mean = z_mean_l(encoder_flatten) 
z_log_var_l  = tensorflow.keras.layers.Dense(units=latent_space_dim, name="encoder_log_variance", kernel_initializer='zeros', kernel_regularizer=tf.keras.regularizers.l2(3))
z_log_var = z_log_var_l(encoder_flatten)

#encoder_mu_log_variance_model = tensorflow.keras.models.Model(enc_input_layer, (encoder_mu, encoder_log_variance), name="encoder_mu_log_variance_model")

@tf.function
def sampling(args):
    z_mean_, z_log_var_ = args
    #tf.print(z_mean_)
    #tf.print(z_log_var_)
    #tf.print(tf.reduce_max(z_log_var_))
    epsilon = K.random_normal(shape=(K.shape(z_mean_)[0], latent_space_dim))
    #tf.print(epsilon)
    result = z_mean_ + K.exp(z_log_var_ / 2) * epsilon
    #tf.print(result)
    return tf.debugging.check_numerics(result, "NaN detected in sampling")

# Reparameterization trick
z = tf.keras.layers.Lambda(sampling)([z_mean, z_log_var])

encoder = tf.keras.Model(inputs, [z_mean, z_log_var, z], name='encoder')
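
The `sampling` function above implements the reparameterization trick: instead of sampling `z` directly, it samples standard normal noise and shifts and scales it with the encoder outputs. A plain-Python sketch of the same formula, with the hypothetical helper name `sample_latent` and made-up inputs:

```python
import math
import random


def sample_latent(z_mean, z_log_var, rng):
    """Draw z = mean + exp(log_var / 2) * eps with eps ~ N(0, 1) per dimension."""
    return [
        m + math.exp(lv / 2.0) * rng.gauss(0.0, 1.0)
        for m, lv in zip(z_mean, z_log_var)
    ]


rng = random.Random(21)
z_mean = [0.5, -1.0, 2.0]
z_log_var = [-20.0, -20.0, -20.0]  # tiny variance: samples stay near the mean
z = sample_latent(z_mean, z_log_var, rng)
print([round(v, 3) for v in z])
```

Because the noise is drawn separately from the encoder outputs, gradients can flow through `z_mean` and `z_log_var` during training.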

In [None]:
dec_input_layer = tf.keras.layers.Input(shape=(latent_space_dim,))  # shape must be a tuple, hence the trailing comma
decoder_dense_layer1 = tensorflow.keras.layers.Dense(units=np.prod(shape_before_flatten), name="decoder_dense_1")(dec_input_layer)
decoder_reshape = tensorflow.keras.layers.Reshape(target_shape=shape_before_flatten)(decoder_dense_layer1)
dres1 = ResNetBlock(1024, 3, modify='upsample')(decoder_reshape)
dres2 = ResNetBlock(512, 3, modify='upsample')(dres1)
dres3 = ResNetBlock(256, 3, modify='upsample')(dres2)
dres4 = ResNetBlock(128, 3, modify='upsample')(dres3)
dres5 = ResNetBlock(64, 3, modify='upsample')(dres4)
dres6 = ResNetBlock(32, 3, modify='upsample')(dres5)
dconv = tf.keras.layers.Conv2D(3, 3, 1, padding='same', activation = tf.nn.sigmoid)(dres6)

decoder = tf.keras.models.Model(dec_input_layer, dconv, name="decoder_model")

In [None]:
#VAE:

#vae_input = tf.keras.layers.Input(shape=(512, 512, 3), name="VAE_input")
vae_encoder_output = encoder(inputs)
outputs = decoder(vae_encoder_output[2])
vae = tf.keras.models.Model(inputs, outputs, name="VAE")

In [None]:
encoder.summary()

In [None]:
decoder.summary()

In [None]:
vae.summary()

In [None]:
from tensorflow.keras.losses import mse
reconstruction_loss = mse(K.flatten(inputs), K.flatten(outputs))
reconstruction_loss *= 512*512*3
kl_loss = -0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=1)
B = 1
vae_loss = K.mean(B * reconstruction_loss + kl_loss)
vae.add_loss(vae_loss)
vae.add_metric(kl_loss, name="kl_loss")
vae.add_metric(reconstruction_loss, name="reconstruction_loss")
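# NOTE: `learning_rate` REPLACES THE DEPRECATED `lr` ALIAS (10e-6 == 1e-5).
# RECENT KERAS RELEASES ACCEPT ONLY ONE OF clipnorm/clipvalue/global_clipnorm; KEEP A SINGLE ONE IF THIS CALL RAISES.
vae.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5, global_clipnorm=10e-6, clipvalue=10e-6, weight_decay=1))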

#vae.fit(x_train, epochs=500, batch_size=batch_size, validation_data=(x_test, None))
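
The `kl_loss` line above is the closed-form KL divergence between the encoder's diagonal Gaussian and a standard normal prior. A small plain-Python check of that formula for a single sample, using the hypothetical helper name `kl_to_standard_normal`:

```python
import math


def kl_to_standard_normal(z_mean, z_log_var):
    """-0.5 * sum(1 + log_var - mean^2 - exp(log_var)) for one sample."""
    return -0.5 * sum(
        1.0 + lv - m * m - math.exp(lv)
        for m, lv in zip(z_mean, z_log_var)
    )


# Posterior equal to the prior: the penalty vanishes.
kl_zero = kl_to_standard_normal([0.0, 0.0], [0.0, 0.0])
# Moving one mean by 1 while keeping unit variance costs 0.5.
kl_shifted = kl_to_standard_normal([1.0, 0.0], [0.0, 0.0])
print(abs(kl_zero), kl_shifted)
```

With `B = 1`, this term is added to the scaled reconstruction error without any extra weighting.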

In [None]:
np.array([next(iter(x_train))[0][0].numpy()]).shape

In [None]:
from tqdm import tqdm
from tensorflow.keras.callbacks import LambdaCallback
def plot_callback(epoch, logs):
    x = np.array([next(iter(x_train))[0][0].numpy()])
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))

    # Plot x
    axes[0].imshow(x[0][:,:,0])
    axes[0].set_yticks([])
    axes[0].set_ylabel('Frequencies (Hz)', size=12)
    axes[0].set_xlabel('Time (sec)', size=12)
    # Plot out
    out = vae.predict(x)
    axes[1].imshow(out[0][:,:,0])
    axes[1].set_yticks([])
    axes[1].set_ylabel('Frequencies (Hz)', size=12)
    axes[1].set_xlabel('Time (sec)', size=12)

    plt.show()
    
class PrintValidationLoss(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        val_loss = logs.get('val_loss')
        val_kl_loss = logs.get('val_kl_loss')
        val_reconstruction_loss = logs.get("val_reconstruction_loss")
        print(f'Validation Loss: {val_loss} - kl_loss: {val_kl_loss} - reconstruction_loss: {val_reconstruction_loss}')
# Define the LambdaCallback
plot_callback_lambda = LambdaCallback(on_epoch_end=plot_callback)
#m_checkpoint = tf.keras.callbacks.ModelCheckpoint(filepath='/kaggle/working', save_weights_only = True, period=5)

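# Fit the model. The datasets are already batched, so batch_size is not passed;
# validation_data is supplied so that PrintValidationLoss has val_* values to report.
History = vae.fit(x_train, epochs=10, validation_data=x_test, callbacks=[plot_callback_lambda,PrintValidationLoss()])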

In [None]:
vae.save_weights('final_output')