In [1]:
import os
import pandas as pd, numpy as np
from glob import glob
import matplotlib.pyplot as plt
from tqdm import tqdm
VER = 1

In [2]:
import pyarrow.parquet as pq
from torch.utils.data import Dataset
import torch
from sklearn.impute import SimpleImputer
from torch.utils.data import DataLoader
from torchvision import transforms

"""
# CAN RUN THIS FROM ANY NOTEBOOK
 
from spectrogram_preprocessor import *
from torch.utils.data import DataLoader
from torchvision import transforms

spectrogram_dataset = SpectrogramDataset("train", transform=transforms.Compose([
    MiddleCrop(), Impute(), LogTransform(), StackFrequencyBands()])
    )

dataloader = DataLoader(spectrogram_dataset, batch_size=32,
                        shuffle=True, num_workers=0)


for i_batch, sample_batched in enumerate(dataloader):
    print(i_batch, sample_batched["values"].shape) #, "labels: ", sample_batched[1].shape)
    print(sample_batched["seizure_vote"].shape)
    print(sample_batched["lpd_vote"].shape)
    print(sample_batched["gpd_vote"].shape)
    print(sample_batched["lrda_vote"].shape)
    print(sample_batched["grda_vote"].shape)
    print(len(sample_batched["target"])) # for some reason target is a list
    # observe 4th batch and stop.
    if i_batch == 3:
        break

"""

class SpectrogramDataset(Dataset):
    """EEG spectrograms dataset backed by one parquet file per spectrogram.

    In "train" mode the annotation CSV is aggregated by
    ``process_training_csv`` and each item carries the vote columns and the
    ``target`` label. In "test" mode the CSV is read as-is and each item
    carries only the spectrogram values and the ``patient_id``.
    """

    # Vote columns copied from the annotation row into every training sample.
    _VOTE_COLUMNS = ("seizure_vote", "lpd_vote", "gpd_vote",
                     "lrda_vote", "grda_vote", "other_vote")

    def __init__(self, data_type, csv_file="/kaggle/input/hms-harmful-brain-activity-classification/train.csv", root_dir="/kaggle/input/hms-harmful-brain-activity-classification", transform=None):
        """
        Arguments:
            data_type (string): Either "train" or "test".
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory containing the ``train_spectrograms``
                and ``test_spectrograms`` folders.
            transform (callable, optional): Optional transform to be applied
                on a sample.

        Raises:
            ValueError: If ``data_type`` is neither "train" nor "test".
        """
        if data_type not in ("train", "test"):
            # Previously an unknown value left data_path/df_train unset and
            # only failed later with an AttributeError.
            raise ValueError(
                f"data_type must be 'train' or 'test', got {data_type!r}")
        self.data_type = data_type
        self.csv_file = csv_file
        self.root_dir = root_dir
        self.data_path = root_dir + ("/train_spectrograms" if data_type == "train"
                                     else "/test_spectrograms")
        self.df_train = self._load_annotations()
        self.transform = transform

    def _load_annotations(self):
        """Load the annotation table the same way for __init__ and reset."""
        if self.data_type == "train":
            return process_training_csv(self.csv_file)
        return pd.read_csv(self.csv_file)

    def _load_values(self, spectrogram_id):
        """Read one spectrogram parquet file and drop its first (time) column."""
        parquet_path = os.path.join(self.data_path, str(spectrogram_id) + ".parquet")
        return pq.read_table(parquet_path).to_pandas().values[:, 1:]

    def reset(self):
        """Reload the annotation table from the csv file given at construction."""
        self.df_train = self._load_annotations()

    def __len__(self):
        return len(self.df_train)

    def __getitem__(self, idx):
        """Return the sample at position ``idx`` as a dict.

        Train samples contain "values", the six vote columns and "target";
        test samples contain "values" and "patient_id".
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        row = self.df_train.iloc[idx]

        if self.data_type == "train":
            sample = {"values": self._load_values(row['spec_id']),
                      "min": row['min'],
                      "max": row['max']}
            if self.transform:
                sample = self.transform(sample)

            result = {"values": sample["values"]}
            for column in self._VOTE_COLUMNS:
                result[column] = row[column]
            result["target"] = row['target']
            return result

        # Test rows expose spectrogram_id / eeg_id / patient_id and have no
        # label offsets, so "min"/"max" are passed to the transform as 0.
        sample = {"values": self._load_values(row['spectrogram_id']),
                  "min": 0,
                  "max": 0}
        if self.transform:
            sample = self.transform(sample)

        return {"values": sample["values"],
                "patient_id": row['patient_id']}


def process_training_csv(csv_file):
    """
    Collapse the annotation csv to one row per eeg_id (csv preprocessing
    taken from the example notebook).

    The returned frame has the columns: eeg_id, spec_id (first
    spectrogram_id of the group), min / max (smallest / largest
    spectrogram_label_offset_seconds of the group), patient_id (first of
    the group), the last six csv columns summed per group and rescaled so
    each row sums to 1, and target (first expert_consensus of the group).
    """
    raw = pd.read_csv(csv_file)
    # The vote columns are taken positionally: the last six columns of the csv.
    vote_columns = raw.columns[-6:]
    per_eeg = raw.groupby('eeg_id')

    # One pass for the per-recording identifiers and label-offset range.
    train = per_eeg.agg(**{
        'spec_id': ('spectrogram_id', 'first'),
        'min': ('spectrogram_label_offset_seconds', 'min'),
        'max': ('spectrogram_label_offset_seconds', 'max'),
        'patient_id': ('patient_id', 'first'),
    })

    # Sum the votes per recording, then turn the counts into shares per row.
    vote_totals = per_eeg[vote_columns].agg('sum').values
    vote_shares = vote_totals / vote_totals.sum(axis=1, keepdims=True)
    for position, column in enumerate(vote_columns):
        train[column] = vote_shares[:, position]

    train['target'] = per_eeg['expert_consensus'].agg('first')

    # Turn eeg_id from the index back into a regular column.
    train = train.reset_index()
    print('Train non-overlapp eeg_id shape:', train.shape)
    return train


class MiddleCrop(object):
    """Cut a fixed-length window of rows out of the spectrogram in a sample.

    The window starts at the row matching the midpoint of the sample's
    "min" and "max" offsets and spans ``output_size`` rows; all columns
    are kept. If fewer rows remain, the crop is shorter.

    Args:
        output_size: Number of rows (time bins) to keep.
    """

    def __init__(self, output_size=300):
        self.output_size = output_size

    def __call__(self, sample):
        # Dividing the sum by 4 = averaging min and max (/2) and converting
        # seconds to rows (/2, the spectrogram has one value per 2 seconds).
        first_row = int((sample["min"] + sample["max"]) // 4)
        window = slice(first_row, first_row + self.output_size)
        return {
            "values": sample["values"][window, :],
            "min": 0,
            "max": self.output_size * 2,
        }
    
class Impute(object):
    """
    Replace NaNs in the sample's "values" with the mean of their column.

    The imputer is re-fitted on every sample, so the means come from that
    sample only. "min" and "max" are passed through unchanged.
    """

    def __init__(self):
        # By default SimpleImputer drops columns that are entirely NaN,
        # which changes the width of the spectrogram and shifts every
        # column index used downstream. keep_empty_features=True keeps such
        # columns (filled with 0) so the output shape matches the input.
        try:
            self.nan_imputer = SimpleImputer(strategy='mean',
                                             keep_empty_features=True)
        except TypeError:
            # Older scikit-learn releases do not know keep_empty_features;
            # fall back to the previous behaviour there.
            self.nan_imputer = SimpleImputer(strategy='mean')

    def __call__(self, sample):
        imputed = self.nan_imputer.fit_transform(sample["values"])
        return {"values": imputed, "min": sample["min"], "max": sample["max"]}
    
class StackFrequencyBands(object):
    """Split the spectrogram columns into 4 equal groups and stack them.

    Args:
        sample: dict whose "values" is a 2-D array (e.g. 300x400)
        returns: dict whose "values" has a new leading axis of length 4
            (e.g. 4x300x100); "min" and "max" are passed through.
    """
    def __call__(self, sample):
        column_groups = np.split(sample["values"], 4, axis=1)
        stacked = np.stack(column_groups, axis=0)
        return {"values": stacked, "min": sample["min"], "max": sample["max"]}

class LogTransform(object):
    """Apply an element-wise log(x + 1) to the spectrogram in a sample.

    Args:
        sample: dict whose "values" is a numeric array of any shape
        returns: dict with "values" of the same shape; "min" and "max"
            are passed through.
    """
    def __call__(self, sample):
        shifted = sample["values"] + 1
        return {
            "values": np.log(shifted),
            "min": sample["min"],
            "max": sample["max"],
        }


In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torch.nn.init as init
import torch.nn.init as init
import gc
# Run a garbage-collection pass and release PyTorch's cached CUDA memory
# before the model is defined (presumably to free memory held by earlier
# runs of this notebook -- confirm).
gc.collect()
torch.cuda.empty_cache()

"""
Ideas to Prevent NaN Losses
1. Normalize Data Better
2. Less Deep / Wide Architecture
3. CNN instead of FCNN
"""
class AE(torch.nn.Module):
    """Fully connected autoencoder over a flattened spectrogram slice.

    The input width is ``numFrequencies * numRows``. The encoder narrows it
    through three 2048-unit layers to ``numFeatures`` (ReLU after every
    layer); the decoder mirrors this and ends in a Sigmoid. There are no
    normalisation layers. All Linear weights use Xavier-uniform init.
    """

    @staticmethod
    def _dense_stack(widths, last_activation):
        """Chain Linear+ReLU pairs over ``widths``, ending in ``last_activation``."""
        layers = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(n_in, n_out))
            layers.append(nn.ReLU())
        layers[-1] = last_activation
        return nn.Sequential(*layers)

    def __init__(self, numFrequencies, numRows, numFeatures=100):
        super().__init__()

        flat_width = numFrequencies * numRows
        hidden = 2048

        # Encoder: flat input -> 2048 -> 2048 -> 2048 -> numFeatures
        self.encoder = self._dense_stack(
            [flat_width, hidden, hidden, hidden, numFeatures], nn.ReLU())

        # Decoder: numFeatures -> 2048 -> 2048 -> 2048 -> flat output in (0, 1)
        self.decoder = self._dense_stack(
            [numFeatures, hidden, hidden, hidden, flat_width], nn.Sigmoid())

        # Xavier initialization for every Linear weight (biases untouched)
        for layer in (*self.encoder, *self.decoder):
            if isinstance(layer, nn.Linear):
                init.xavier_uniform_(layer.weight)

    def forward(self, x):
        """Encode ``x`` and return its reconstruction."""
        return self.decoder(self.encoder(x))


In [4]:
# Flattened width of each EEG band per spectrogram row: a per-band count
# times 4. Presumably the count is the number of frequency columns in the
# band and the 4 is the number of channel groups (LL, RL, LP, RP) -- confirm
# against the spectrogram column names.
delta_frequencies = 4 * 18
theta_frequencies = 4 * 20
alpha_frequencies = 4 * 21
beta_frequencies = 4 * 41

In [5]:
%time
# ENGINEER FEATURES
import warnings
warnings.filterwarnings('ignore')

PATH = '/kaggle/input/hms-harmful-brain-activity-classification/train_spectrograms/'

# Number of spectrogram columns, excluding the first (time) column.
SPEC_FREQS = len(pd.read_parquet(f'{PATH}1000086677.parquet').columns[1:])
print(f"Num Frequencies: {SPEC_FREQS}")
numFeatures = 400

# `use_gpu` was referenced without ever being defined, so a single-GPU
# machine raised NameError. Honour a value set in an earlier cell if there
# is one, otherwise default to using the GPU.
use_gpu = globals().get('use_gpu', True)
gpu_count = torch.cuda.device_count()

if gpu_count > 1:
    device = torch.cuda.current_device()
    print('Use Multi GPU', device)
elif gpu_count == 1 and use_gpu:
    device = torch.cuda.current_device()
    print('Use GPU', device)
else:
    print("use CPU")
    device = torch.device('cpu')  # sets the device to be CPU
    print(device)
# device = torch.device('cpu') # delete when issue resolved

print("Using: ", device)

# Define delta feature autoencoder: its input is the flattened delta band
# of a 300-row crop, compressed to `numFeatures` values.
model_delta = AE(delta_frequencies, 300, numFeatures=numFeatures)
model_delta = model_delta.to(device)
if gpu_count > 1:
    # NOTE: once wrapped, state_dict() keys carry a "module." prefix.
    model_delta = nn.DataParallel(module=model_delta)
loss_function_delta = torch.nn.MSELoss()
optimizer_delta = torch.optim.Adam(model_delta.parameters(),
                            lr = 1e-4,
                            )

# No gamma waves in this data!

CPU times: user 3 µs, sys: 0 ns, total: 3 µs
Wall time: 7.39 µs
Num Frequencies: 400
Use Multi GPU 0
Using:  0


In [6]:
from scipy import signal
import functools


@functools.lru_cache(maxsize=None)
def _reference_spectrogram_columns(parquet_path):
    """Return the non-time column names of one spectrogram file, cached per path."""
    return tuple(pd.read_parquet(parquet_path).columns[1:])  # like LR_14.32


@functools.lru_cache(maxsize=32)
def _band_column_indices(cols):
    """Map each EEG band to four lists of column positions, one per channel group."""
    channel_groups = ('LL', 'RL', 'LP', 'RP')
    eeg_bands = {'Delta': (0.5, 4), 'Theta': (4, 8), 'Alpha': (8, 12), 'Beta': (12, 30)}
    indices = {"Alpha": [], "Delta": [], "Theta": [], "Beta": []}

    for channel_group in channel_groups:
        for band, (low, high) in eeg_bands.items():
            # Band limits are inclusive at both ends.
            indices[band].append([
                idx for idx, col in enumerate(cols)
                if channel_group in col
                and low <= float(col.split("_")[1]) <= high
            ])
    return indices


def extract_frequency_band_features(segment, cols=None):
    """Group the columns of one spectrogram segment by EEG band.

    Args:
        segment: 2-D array or CPU tensor indexed as (row, column), with
            columns in the same order as ``cols``.
        cols: Optional sequence of column names of the form
            ``<group>_<frequency>``. When omitted, the names are read once
            from a reference parquet file under ``PATH`` and cached, rather
            than re-read on every call.

    Returns:
        dict with keys "Alpha", "Delta", "Theta", "Beta". Each value is a
        1-D numpy array: for the channel groups LL, RL, LP, RP in that
        order, the segment's columns within the band, flattened row by row
        and joined end to end.
    """
    if cols is None:
        cols = _reference_spectrogram_columns(f'{PATH}1000086677.parquet')

    band_indices = _band_column_indices(tuple(cols))

    band_datapoints = {}
    for band, per_group in band_indices.items():
        # Convert each slice to numpy before joining; building an array
        # from a list of tensors falls back to slow element-wise conversion.
        parts = [np.asarray(segment[:, idxs]).ravel() for idxs in per_group]
        # join all 4 group signals into one to reconstruct in autoencoder
        band_datapoints[band] = np.concatenate(parts)
    return band_datapoints
            

In [7]:
from torch.utils.data import DataLoader, TensorDataset
# Debugging aid for locating the operation that produces NaN gradients.
torch.autograd.set_detect_anomaly(True)

batch_size = 100

spectrogram_dataset = SpectrogramDataset("train", transform=transforms.Compose([
    MiddleCrop(), Impute(), LogTransform()])
    )

dataloader = DataLoader(spectrogram_dataset, batch_size=batch_size,
                        shuffle=True, num_workers=2)

# Ask the loader for its batch count: `len // batch_size + 1` is one too
# many whenever the dataset size is a multiple of the batch size.
num_batches = len(dataloader)

print(f"Training Autoencoder on {len(spectrogram_dataset)} datapoints with batch size {batch_size}")
print(f"Batches {num_batches}:", end=' ')
num_epochs = 8 # fine with 2-3 epochs but should do more with GPU if possible

best_loss = float('inf')
best_epoch = -1
DELTA_PATH = "/kaggle/working/model_delta_latest.pth"
BEST_DELTA_PATH = "/kaggle/working/model_delta_best.pth"

# Flattened delta-band length expected for a full 300-row crop.
expected_delta_len = delta_frequencies * 300
loss_delta = None  # defined up front so the final `del` cannot fail

for epoch in tqdm(range(num_epochs)):

    epoch_loss_delta = 0.0
    batches_trained = 0

    for i, sample_batched in enumerate(dataloader):
        input_delta_list = []

        for eeg_segment in sample_batched["values"]:

            vals = extract_frequency_band_features(eeg_segment)["Delta"]

            # Samples whose crop came out short cannot be fed to the
            # fixed-width network, so they are skipped.
            if len(vals) != expected_delta_len:
                continue

            # Min-max scale to [0, 1] to match the decoder's Sigmoid output.
            # A constant or non-finite segment would divide by zero and
            # poison the loss with NaNs, so it is skipped as well.
            lowest = vals.min()
            span = vals.max() - lowest
            if not np.isfinite(span) or span <= 0:
                continue
            input_delta_list.append((vals - lowest) / span)

        if input_delta_list:
            # Stack first: torch.tensor on a list of arrays is very slow.
            input_delta_batch = torch.tensor(
                np.stack(input_delta_list), dtype=torch.float32).to(device)

            # Forward pass through the autoencoder
            output_delta_batch = model_delta(input_delta_batch)

            # Calculate loss and perform optimization for delta autoencoder
            loss_delta = loss_function_delta(output_delta_batch, input_delta_batch)
            optimizer_delta.zero_grad()
            loss_delta.backward()
            optimizer_delta.step()

            # Accumulate epoch loss
            epoch_loss_delta += loss_delta.item()
            batches_trained += 1

            # Clean up to avoid memory issues
            del output_delta_batch, input_delta_batch

        del input_delta_list

        if i % 5 == 0:
            print(f"Done batch {i}", end = '... ')

    # Average over the batches that actually produced a loss.
    avg_loss_delta = epoch_loss_delta / max(batches_trained, 1)

    print(f"Epoch {epoch} Summary: Avg Loss Delta: {avg_loss_delta}")

    # NOTE: when model_delta is wrapped in nn.DataParallel, the saved keys
    # carry a "module." prefix; load into a wrapped model or strip it.
    if avg_loss_delta < best_loss:
        best_loss = avg_loss_delta
        best_epoch = epoch
        print(f"Saving new best model epoch {epoch} at {BEST_DELTA_PATH}")
        torch.save(model_delta.state_dict(), BEST_DELTA_PATH)

    # Save the latest model parameters after every epoch
    print(f"Saving model at {DELTA_PATH}")
    torch.save(model_delta.state_dict(), DELTA_PATH)

del model_delta, loss_delta, optimizer_delta
torch.cuda.empty_cache()

Train non-overlapp eeg_id shape: (17089, 12)
Training Autoencoder on 17089 datapoints with batch size 100
Batches 171: 

  0%|          | 0/8 [00:00<?, ?it/s]

Done batch 0... Done batch 5... Done batch 10... Done batch 15... Done batch 20... Done batch 25... Done batch 30... Done batch 35... Done batch 40... Done batch 45... Done batch 50... Done batch 55... Done batch 60... Done batch 65... Done batch 70... Done batch 75... Done batch 80... Done batch 85... Done batch 90... Done batch 95... Done batch 100... Done batch 105... Done batch 110... Done batch 115... Done batch 120... Done batch 125... Done batch 130... Done batch 135... Done batch 140... Done batch 145... Done batch 150... Done batch 155... Done batch 160... Done batch 165... Done batch 170... Epoch 0 Summary: Avg Loss Delta: 0.031748422154644775
Saving new best model epoch 0 at /kaggle/working/model_delta_best.pth
Saving model at /kaggle/working/model_delta_latest.pth


 12%|█▎        | 1/8 [14:21<1:40:31, 861.65s/it]

Done batch 0... Done batch 5... Done batch 10... Done batch 15... Done batch 20... Done batch 25... Done batch 30... Done batch 35... Done batch 40... Done batch 45... Done batch 50... Done batch 55... Done batch 60... Done batch 65... Done batch 70... Done batch 75... Done batch 80... Done batch 85... Done batch 90... Done batch 95... Done batch 100... Done batch 105... Done batch 110... Done batch 115... Done batch 120... Done batch 125... Done batch 130... Done batch 135... Done batch 140... Done batch 145... Done batch 150... Done batch 155... Done batch 160... Done batch 165... Done batch 170... Epoch 1 Summary: Avg Loss Delta: 0.015824495743938356
Saving new best model epoch 1 at /kaggle/working/model_delta_best.pth
Saving model at /kaggle/working/model_delta_latest.pth


 25%|██▌       | 2/8 [27:58<1:23:31, 835.28s/it]

Done batch 0... Done batch 5... Done batch 10... Done batch 15... Done batch 20... Done batch 25... Done batch 30... Done batch 35... Done batch 40... Done batch 45... Done batch 50... Done batch 55... Done batch 60... Done batch 65... Done batch 70... Done batch 75... Done batch 80... Done batch 85... Done batch 90... Done batch 95... Done batch 100... Done batch 105... Done batch 110... Done batch 115... Done batch 120... Done batch 125... Done batch 130... Done batch 135... Done batch 140... Done batch 145... Done batch 150... Done batch 155... Done batch 160... Done batch 165... Done batch 170... Epoch 2 Summary: Avg Loss Delta: 0.014028833323969828
Saving new best model epoch 2 at /kaggle/working/model_delta_best.pth
Saving model at /kaggle/working/model_delta_latest.pth


 38%|███▊      | 3/8 [41:40<1:09:05, 829.16s/it]

Done batch 0... Done batch 5... Done batch 10... Done batch 15... Done batch 20... Done batch 25... Done batch 30... Done batch 35... Done batch 40... Done batch 45... Done batch 50... Done batch 55... Done batch 60... Done batch 65... Done batch 70... Done batch 75... Done batch 80... Done batch 85... Done batch 90... Done batch 95... Done batch 100... Done batch 105... Done batch 110... Done batch 115... Done batch 120... Done batch 125... Done batch 130... Done batch 135... Done batch 140... Done batch 145... Done batch 150... Done batch 155... Done batch 160... Done batch 165... Done batch 170... Epoch 3 Summary: Avg Loss Delta: 0.012836086122613204
Saving new best model epoch 3 at /kaggle/working/model_delta_best.pth
Saving model at /kaggle/working/model_delta_latest.pth


 50%|█████     | 4/8 [55:29<55:17, 829.34s/it]  

Done batch 0... Done batch 5... Done batch 10... Done batch 15... Done batch 20... Done batch 25... Done batch 30... Done batch 35... Done batch 40... Done batch 45... Done batch 50... Done batch 55... Done batch 60... Done batch 65... Done batch 70... Done batch 75... Done batch 80... Done batch 85... Done batch 90... Done batch 95... Done batch 100... Done batch 105... Done batch 110... Done batch 115... Done batch 120... Done batch 125... Done batch 130... Done batch 135... Done batch 140... Done batch 145... Done batch 150... Done batch 155... Done batch 160... Done batch 165... Done batch 170... Epoch 4 Summary: Avg Loss Delta: 0.012463808838517694
Saving new best model epoch 4 at /kaggle/working/model_delta_best.pth
Saving model at /kaggle/working/model_delta_latest.pth


 62%|██████▎   | 5/8 [1:09:24<41:34, 831.36s/it]

Done batch 0... Done batch 5... Done batch 10... Done batch 15... Done batch 20... Done batch 25... Done batch 30... Done batch 35... Done batch 40... Done batch 45... Done batch 50... Done batch 55... Done batch 60... Done batch 65... Done batch 70... Done batch 75... Done batch 80... Done batch 85... Done batch 90... Done batch 95... Done batch 100... Done batch 105... Done batch 110... Done batch 115... Done batch 120... Done batch 125... Done batch 130... Done batch 135... Done batch 140... Done batch 145... Done batch 150... Done batch 155... Done batch 160... Done batch 165... Done batch 170... Epoch 5 Summary: Avg Loss Delta: 0.011888765098189402
Saving new best model epoch 5 at /kaggle/working/model_delta_best.pth
Saving model at /kaggle/working/model_delta_latest.pth


 75%|███████▌  | 6/8 [1:23:02<27:33, 826.79s/it]

Done batch 0... Done batch 5... Done batch 10... Done batch 15... Done batch 20... Done batch 25... Done batch 30... Done batch 35... Done batch 40... Done batch 45... Done batch 50... Done batch 55... Done batch 60... Done batch 65... Done batch 70... Done batch 75... Done batch 80... Done batch 85... Done batch 90... Done batch 95... Done batch 100... Done batch 105... Done batch 110... Done batch 115... Done batch 120... Done batch 125... Done batch 130... Done batch 135... Done batch 140... Done batch 145... Done batch 150... Done batch 155... Done batch 160... Done batch 165... Done batch 170... Epoch 6 Summary: Avg Loss Delta: 0.011647678072950994
Saving new best model epoch 6 at /kaggle/working/model_delta_best.pth
Saving model at /kaggle/working/model_delta_latest.pth


 88%|████████▊ | 7/8 [1:36:36<13:42, 822.38s/it]

Done batch 0... Done batch 5... Done batch 10... Done batch 15... Done batch 20... Done batch 25... Done batch 30... Done batch 35... Done batch 40... Done batch 45... Done batch 50... Done batch 55... Done batch 60... Done batch 65... Done batch 70... Done batch 75... Done batch 80... Done batch 85... Done batch 90... Done batch 95... Done batch 100... Done batch 105... Done batch 110... Done batch 115... Done batch 120... Done batch 125... Done batch 130... Done batch 135... Done batch 140... Done batch 145... Done batch 150... Done batch 155... Done batch 160... Done batch 165... Done batch 170... Epoch 7 Summary: Avg Loss Delta: 0.01160341649376161
Saving new best model epoch 7 at /kaggle/working/model_delta_best.pth
Saving model at /kaggle/working/model_delta_latest.pth


100%|██████████| 8/8 [1:50:24<00:00, 828.07s/it]
