### Data Loading and Preparation

In [9]:
import os, random
import numpy as np
import pandas as pd
import torch
from torch.utils.data import IterableDataset, DataLoader, get_worker_info

DATA_ROOT = "/kaggle/input/hms-harmful-brain-activity-classification" # dataset folder
train = pd.read_csv(f"{DATA_ROOT}/train.csv") # reads into Dataframe
SPEC_DIR = f"{DATA_ROOT}/train_spectrograms" # folder containing spectrogram parquet files.
def soft_label(row):
    v = np.array([
        row["seizure_vote"], row["lpd_vote"], row["gpd_vote"],
        row["lrda_vote"], row["grda_vote"], row["other_vote"],
    ], dtype=np.float32)
    v = v / (v.sum() + 1e-6) # normalizes so values sum to 1
    return v

# Keep only columns needed for training (saving RAM)
# “Load spectrogram <id>”
# “Crop around <offset>”
# “Use label distribution <soft>”
use_cols = ["spectrogram_id", "spectrogram_label_offset_seconds",
            "seizure_vote","lpd_vote","gpd_vote","lrda_vote","grda_vote","other_vote"] # votes used to build label
df = train[use_cols].copy()
df["soft"] = df.apply(soft_label, axis=1) # adds a new column soft which contains the 6D probability vector for each row

display("Required Training columns:", df)

# Shape of the first paraquet file
tmp = pd.read_parquet(f"{SPEC_DIR}/{df.iloc[0].spectrogram_id}.parquet")
print(tmp.shape)
print(tmp.columns[:15])

# For reference
CLASSES = ["seizure", "lpd", "gpd", "lrda", "grda", "other"]

'Required Training columns:'

Unnamed: 0,spectrogram_id,spectrogram_label_offset_seconds,seizure_vote,lpd_vote,gpd_vote,lrda_vote,grda_vote,other_vote,soft
0,353733,0.0,3,0,0,0,0,0,"[0.9999997, 0.0, 0.0, 0.0, 0.0, 0.0]"
1,353733,6.0,3,0,0,0,0,0,"[0.9999997, 0.0, 0.0, 0.0, 0.0, 0.0]"
2,353733,8.0,3,0,0,0,0,0,"[0.9999997, 0.0, 0.0, 0.0, 0.0, 0.0]"
3,353733,18.0,3,0,0,0,0,0,"[0.9999997, 0.0, 0.0, 0.0, 0.0, 0.0]"
4,353733,24.0,3,0,0,0,0,0,"[0.9999997, 0.0, 0.0, 0.0, 0.0, 0.0]"
...,...,...,...,...,...,...,...,...,...
106795,2147388374,12.0,0,0,0,3,0,0,"[0.0, 0.0, 0.0, 0.9999997, 0.0, 0.0]"
106796,2147388374,14.0,0,0,0,3,0,0,"[0.0, 0.0, 0.0, 0.9999997, 0.0, 0.0]"
106797,2147388374,16.0,0,0,0,3,0,0,"[0.0, 0.0, 0.0, 0.9999997, 0.0, 0.0]"
106798,2147388374,18.0,0,0,0,3,0,0,"[0.0, 0.0, 0.0, 0.9999997, 0.0, 0.0]"


(320, 401)
Index(['time', 'LL_0.59', 'LL_0.78', 'LL_0.98', 'LL_1.17', 'LL_1.37',
       'LL_1.56', 'LL_1.76', 'LL_1.95', 'LL_2.15', 'LL_2.34', 'LL_2.54',
       'LL_2.73', 'LL_2.93', 'LL_3.13'],
      dtype='object')


In [10]:
# ---- Spectrogram window loader from parquet  to tensor window----
#loads one spectrogram file and extracts a small time window.
def load_spec_window_parquet(spec_id: int, offset_s: float, window_seconds: float = 10.0):
    path = os.path.join(SPEC_DIR, f"{spec_id}.parquet")
    spec_df = pd.read_parquet(path)

    # Separate time and frequency data
    time = spec_df["time"].to_numpy()
    freq_mat = spec_df.drop(columns=["time"]).to_numpy(dtype=np.float32)
    # freq_mat shape: (time_steps, freq_bins)

    # Find center index closest to offset_s
    center_idx = np.argmin(np.abs(time - offset_s))

    # Estimate time resolution
    dt = np.median(np.diff(time))  # seconds per row
    half_window = int(round((window_seconds / 2) / dt))

    start = max(0, center_idx - half_window)
    end   = min(len(time), center_idx + half_window)

    window = freq_mat[start:end, :]  # (time, freq)

    # Pad if near edges
    if window.shape[0] < 2 * half_window:
        pad = 2 * half_window - window.shape[0]
        window = np.pad(window, ((0, pad), (0, 0)), mode="constant")

    # Log + normalize
    window = np.log1p(np.maximum(window, 0))
    window = (window - window.mean()) / (window.std() + 1e-6)

    # CNN format: (C, H, W) = (1, freq, time)
    window = window.T  # (freq, time)
    return window[None, :, :]  # (1, freq, time)



#-------STREAMING DATASET CLASS------------#
class HMSParquetStream(IterableDataset):
    def __init__(self, df, shuffle=True, infinite=True, window_seconds=10):
        self.df = df.reset_index(drop=True)
        self.shuffle = shuffle
        self.infinite = infinite #if True, it loops forever (continuous feed)
        self.window_seconds = window_seconds
        
    #--Iterator: main streaming loop---#
    #GPU → does math (CNN forward + backward)
    #CPU workers → read files, crop spectrograms, normalize, convert to tensors
    
    def __iter__(self):
        worker = get_worker_info() #Background CPU processes
        start, step = (0, 1) if worker is None else (worker.id, worker.num_workers)
        
        idxs = list(range(len(self.df))) #Create list of indices
        #Infinite loop (continuous feed)
        while True: 
            if self.shuffle:
                random.shuffle(idxs)
            #Worker-sharded iteration
            for k in range(start, len(idxs), step):
                row = self.df.iloc[idxs[k]] #Load one training row from the dataset
                
                #Load spectrogram window -> opens, crops, returns shape (1, F, window_seconds)
                x = load_spec_window_parquet(
                    int(row.spectrogram_id),
                    float(row.spectrogram_label_offset_seconds),
                    seconds=self.window_seconds
                )
                y = row.soft #Get soft labels
                
                #Yield tensors to DataLoader -> Each yielded item is (input_tensor, label_tensor).
                yield torch.tensor(x, dtype=torch.float32), torch.tensor(y, dtype=torch.float32) 
            
            #If infinite=False, it runs only one “epoch” and stops.
            if not self.infinite:
                break

# CREATE STREAM + DATALOADER
# stream: your continuous generator of samples
# DataLoader:
# batch_size=32: groups 32 samples into a batch
# num_workers=2: runs 2 parallel worker processes to load data
# pin_memory=True: speeds CPU→GPU transfer
# persistent_workers=True: keeps workers alive between iterations (faster)

stream = HMSParquetStream(df, shuffle=True, infinite=True, window_seconds=10)
loader = DataLoader(stream, batch_size=32, num_workers=2, pin_memory=True, persistent_workers=True)



### CNN Training Example  
The DataLoader continuously pulls spectrogram windows from streaming dataset (which loads+extracts them from parquet), batches them into x, and the loop repeatedly runs forward → loss → backward → optimizer step to train the CNN.

**Spectrogram → CNN → window-level probabilities**

**FETCHING THE EXTRACTED SPECTROGRAM WINDOWS**
* loader is a DataLoader built from HMSParquetStream(IterableDataset).
* When the loop asks for the next batch, the DataLoader tells its workers: “Give me the next samples.”
* Each worker runs the dataset’s __iter__() method, which does this for each sample:
* picks one row from df  
* calls load_spec_window_parquet(spec_id, offset_s, window_seconds)
* which reads the parquet file, finds the correct time row closest to offset_s
* slices a window of rows (time) and all columns (freq)
* normalizes it
* returns (1, freq, time) tensor
* The DataLoader takes 32 of these samples and stacks them into a batch:
* x becomes shape: (B, 1, freq, time), where B=32
* y becomes shape: (B, 6) soft labels
* So the “fetching” is not done by the CNN directly — it’s done by:
* ✅ DataLoader → ✅ IterableDataset.__iter__() → ✅ load_spec_window_parquet()

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device) #Moves your CNN weights to GPU/CPU
opt = torch.optim.AdamW(model.parameters(), lr=1e-3) #Creates an optimizer (AdamW) that will update the model weights to reduce loss.

for step, (x, y) in enumerate(loader): #Training starts
    x = x.to(device, non_blocking=True) #Moves the batch to GPU.  
    y = y.to(device, non_blocking=True) # non_blocking=True-> CPU starts the transfer and immediately moves on to the next task
    
    logits = model(x)  # (B, 6) #Feeds the spectrogram batch into CNN. logits gives raw score per class
    loss = torch.nn.functional.kl_div(
        torch.log_softmax(logits, dim=1), #converts logits to log probabilities and computes KL divergence loss bwt model pred distr. & soft lable distr.
        y,
        reduction="batchmean"
    )

    loss.backward() #Computes gradients (how each weight should change to reduce loss).
    opt.step() # Updates model weights using those gradients.
    opt.zero_grad() # Clears gradients so they don’t accumulate into the next step.

    #LOGGING + STOPPING
    if step % 100 == 0:
        print(step, float(loss)) #Prints loss every 100 batches.

    if step == 5000:   # stops after 5000 batches since it's streaming infinitely
        break

### For monitoring the training outputs:

In [None]:
#Loss Curve
losses.append(loss.item())

#Accuracy (window-level)
pred = logits.argmax(dim=1)
true = y.argmax(dim=1)
acc = (pred == true).float().mean()

#Save Checkpoints
if step % 1000 == 0:
    torch.save(model.state_dict(), f"model_step_{step}.pth")


### After training Save & Reload the model



In [None]:
# Saves only the learned weights
MODEL_PATH = "cnn_hms_trained.pth"
torch.save(model.state_dict(), MODEL_PATH)


In [None]:
#Reload the model
model = MyCNN()        # same architecture as training
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
model = model.to(device)


### Freeze + Inference
output is final predicted class (per window)

In [None]:
model.eval() #freezes the current batch and uses the running mean/variance learned during training
with torch.no_grad(): # The "Memory Saver". Freezes the weights
    probs = model(x) # probs here are logits (raw scores
probs = torch.softmax(model(x), dim=1) # final class probabilities
preds = probs.argmax(dim=1) # final predicted class (per window) [ P(seizure), P(lpd), P(gpd), P(lrda), P(grda), P(other) ]


### Scenario 2 (CNN + LSTM/Transformer) is only for aggregating across windows/patients.
Below is the **freeze + Inference step**
Completely disables gradient computation for parameters  
Prevents any weight updates  
Mostly used for:  
transfer learning  
feature extractors  
partial freezing  

In [None]:
model.eval()
for p in model.parameters():
    p.requires_grad = False