In [1]:
%load_ext autoreload
%autoreload 2
#required packages
import os
import glob
import numpy as np
import pandas as pd
pd.options.mode.chained_assignment = None  # default='warn'
from scipy import signal
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn import cluster
import pickle
from datetime import datetime
from tqdm import tqdm
#need to run as administrator, otherwise takes too long
# import umap
#my funcs
from ecog_functions import *

#inductive 
from sklearn.base import BaseEstimator, clone
from sklearn.cluster import AgglomerativeClustering
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import DecisionBoundaryDisplay
from sklearn.utils.metaestimators import available_if
from sklearn.utils.validation import check_is_fitted

In [2]:
#two options - either train a new clusterer per rat (MODE = 'train'), or apply the clusterer already trained on baseline days (MODE = 'apply')

In [3]:
#global params
# -- run mode --
# 'apply' loads the clusterer already trained for a rat, 'train' fits a new one (one clusterer per rat)
MODE = 'apply'
OVERWRITE = False      # when False, recordings whose output pickle already exists are skipped
WARN_BY_CORR = False   # when True, print a warning about highly correlated channels

# -- input --
# alternative data location: r"G:\DATA\tween_cbn_sleep\*\timestamped\*.txt"
DATA_PATH = r"C:\Users\marty\Documents\DATA\water_soluble_cbn_sleep\*\timestamped\*.txt"
SAMPLE_RATE = 1000     # Hz, passed as the sampling rate to every filter

# -- features --
RELATIVE = True        # relative (fraction of summed ECoG band power) instead of absolute power

# -- output file-name suffixes --
DATA_SUFFIX = '_clustered_rolled_50cl_fullfixed.pkl'   # per rat/day clustered feature frame
MODEL_SUFFIX = '_final_50cl.pkl'                       # per rat clusterer

In [4]:
#select DAYS and RATS
# rat ids in processing order; each is wrapped in backslashes so that it only matches a whole
# directory component of a Windows path (e.g. '\\5\\' does not match '\\51\\')
rats = ['\\' + str(rat_id) + '\\' for rat_id in (51, 52, 53, 54, 55,
                                                  41, 42, 43, 44, 45, 46,
                                                  21, 23, 24, 25, 26,
                                                  31, 32, 33, 34, 35,
                                                  5, 6, 7, 8, 10,
                                                  11, 12, 13, 14, 16)]
#read paths - every .txt file matched by DATA_PATH
paths = glob.glob(DATA_PATH)
#for training
# days = ['20241903','20230604','20232905','20230811','20240801']

In [5]:
# second data location, merged into the `paths` list built from DATA_PATH
EXTRA_PATHS = r"G:\DATA\tween_cbn_sleep\*\timestamped\*.txt"
# order-preserving de-duplication: re-running this cell must not list (and later read) the same files twice
paths = list(dict.fromkeys(paths + glob.glob(EXTRA_PATHS)))

In [6]:
# #select specific days for specific rats - optional step, when specific days matter (ex. for apply mode)
# Dates are 'YYYYDDMM' strings (e.g. '20243003' is followed by '20240104'); the processing loop
# selects the files whose path contains the date string.
days_rat51 = ['20241703','20241803','20241903','20242003','20242103','20242203','20242303','20242403','20242503','20242603','20242703','20242803','20242903','20243003','20240104','20240204','20240304','20240404','20240504','20240604']
days_rat52 = ['20241703','20241803','20241903','20242003','20242103','20242203','20242303','20242403','20242503','20242603','20242703','20242803','20242903','20243003','20240104','20240204','20240304','20240404','20240504','20240604']
days_rat53 = ['20241703','20241803','20241903','20242003','20242103','20242203','20242303','20242403','20242503','20242603','20242703','20242803','20242903','20243003','20240104','20240204','20240304','20240404','20240504','20240604']
days_rat54 = ['20241703','20241803','20241903','20242003','20242103','20242203','20242303','20242403','20242503','20242603','20242703','20242803','20242903','20243003','20240104','20240204','20240304','20240404','20240504','20240604']
days_rat55 = ['20241703','20241803','20241903','20242003','20242103','20242203','20242303','20242403','20242503','20242603','20242703','20242803','20242903','20243003','20240104','20240204','20240304','20240404','20240504','20240604']
days_rat41 = ['20242202','20242302','20242402','20242502','20242602','20242702','20242802','20242902','20240203','20240303','20240403','20240503','20241103']
days_rat42 = ['20242202','20242302','20242402','20242502','20242602','20242702','20242802','20242902','20240103','20240203','20240303','20240403','20240503','20241103','20241203','20241303','20241403','20241503']
days_rat43 = ['20242202','20242302','20242402','20242502','20242602','20242702','20242802','20242902','20240103','20240203','20240303','20240403','20240503','20241103','20241203','20241303','20241403','20241503']
days_rat44 = ['20242202','20242302','20242402','20242502','20242602','20242702','20242802','20242902','20240103','20240203','20240303','20240403','20240503','20241103','20241203','20241303','20241403','20241503']
days_rat45 = ['20242202','20242302','20242402','20242502','20242602','20242702','20242802','20242902','20240103','20240203','20240303','20240403','20240503','20241303','20241403','20241503']
days_rat46 = ['20242202','20242302','20242402','20242502','20242602','20242702','20242802','20242902','20240103','20240203','20240303','20240403','20240503','20241103','20241203','20241303','20241403','20241503']
days_rat21 = ['20230711','20230811','20230911','20231011','20231111','20231211','20231311','20231411','20231811','20231911','20232011','20233011','20230112','20230212']
days_rat23 = ['20230711','20230811','20230911','20231011','20231111','20231211','20231311','20231411','20231511','20231611','20231711','20231811','20233011','20230112','20230212','20230312','20230412']
days_rat24 = ['20230711','20230811','20230911','20231011','20231111','20231211','20231311','20231411','20231511','20231611','20231711','20231811','20231911','20230112','20230212','20230312','20230412']
days_rat25 = ['20230711','20230811','20230911','20231011','20231111','20231211','20231311','20231411','20231511','20231811','20231911']
days_rat26 = ['20230711','20230811','20230911','20231011','20231111','20231211','20231311','20231411','20231511','20231711','20231811','20231911','20233011','20230112','20230212']
days_rat31 = ['20240801','20240901','20241001','20241101','20241201','20241301','20241401','20242001','20242101','20242201','20242301','20242401','20242501']
days_rat32 = ['20232912','20233012','20233112','20240101','20240201','20240301','20240401','20240501','20240601','20240801','20240901','20241001','20241101','20241201','20241301','20241401','20241501','20242001','20242101','20242201','20242301','20242401','20242501','20242601']
days_rat33 = ['20232912','20233012','20233112','20240101','20240201','20240301','20240401','20240501','20240601','20240801','20240901','20241001','20241101','20241201','20241301','20241401','20241501','20242001','20242101','20242201','20242301','20242401','20242501','20242601']
days_rat34 = ['20232912','20233012','20233112','20240101','20240201','20240301','20240401','20240501','20240601','20240801','20240901','20241001','20241101','20241201','20241301','20241401','20242001','20242101','20242201','20242301','20242401','20242501','20242601']
days_rat35 = ['20232912','20233012','20233112','20240101','20240201','20240301','20240401','20240501','20240601','20240801','20240901','20241001','20241101','20241201','20241301','20241401','20241501','20242001','20242101','20242201','20242301','20242401','20242501','20242601']
days_rat5 = ['20230504','20230604','20230704','20230804','20230904','20231004','20231104']
days_rat6 = ['20230504','20230604','20230704','20230804','20230904']
days_rat7 = ['20230504','20230604','20230704','20230804']
days_rat8 = ['20230504','20230604','20230704','20230804','20230904']
days_rat10 = ['20230504','20230604','20230704']
# rats 11-16: '20232905' and '20233005' used to be missing a comma between them, which silently
# concatenated them into one never-matching string and skipped both days
days_rat11 = ['20232705','20232805','20232905','20233005','20233105','20230106','20230206','20230306','20230406','20230506','20232206','20232306','20232406','20232506','20232606']
days_rat12 = ['20232705','20232805','20232905','20233005','20233105','20230106','20230406','20230506','20230606','20230706','20230806','20232206','20232306']
days_rat13 = ['20232705','20232805','20232905','20233005','20233105','20230106','20230406','20230506','20230606','20230706','20230806','20230906','20231006','20232206','20232306','20232406','20232506','20232606']
days_rat14 = ['20232705','20232805','20232905','20233005','20233105','20230106','20230206','20230306','20230406','20230506','20230606','20230706','20230806','20230906','20231006','20232206','20232306','20232406','20232506','20232606']
days_rat16 = ['20232705','20232805','20232905','20233005','20233105','20230106','20230206','20230306','20230406','20230506','20230606','20230706','20232306','20232406','20232506']
# #put to a list - order must match `rats` one-to-one
days_allrats = [days_rat51, days_rat52, days_rat53, days_rat54, days_rat55,
                days_rat41, days_rat42, days_rat43, days_rat44, days_rat45, days_rat46,
                days_rat21, days_rat23, days_rat24, days_rat25, days_rat26,
                days_rat31, days_rat32, days_rat33, days_rat34, days_rat35,
                days_rat5, days_rat6, days_rat7, days_rat8,
                days_rat10, days_rat11, days_rat12, days_rat13, days_rat14, days_rat16]
# guard against typos such as a missing comma (implicit string concatenation) or a truncated date
assert all(len(d) == 8 and d.isdigit() for rat_days in days_allrats for d in rat_days), \
    'malformed day string in days_allrats'

In [7]:
#selects each rat
# For every (rat, day): read the raw .txt recordings, band-pass filter them, merge the two ECoG channels
# with PCA, turn the ECoG bands + EMG into smoothed, standardized power features, label every row with
# the rat's inductive clusterer (training one when needed) and pickle the resulting frame.

# ECoG frequency bands (Hz). They do not depend on the recording, so the filters are built once.
BANDWIDTHS = {
    'delta': (0.5, 4),
    'theta': (4, 8),
    'alpha': (8, 13),
    'beta': (13, 30),
    'gamma': (30, 45)
}
bandwidth_filters = {
    name: signal.butter(N=20, Wn=frange, btype='bp', output='sos', fs=SAMPLE_RATE)
    for name, frange in BANDWIDTHS.items()
}
RESULTS_DIR = 'C:/Users/marty/Projects/sleep_new/res_temp/'
N_CLUSTERS = 50
# rows sampled (without replacement) for agglomerative clustering, to save memory
CLUSTER_SUBSET_SIZE = 50000
# seeded so the training subset - and therefore a newly trained clusterer - is reproducible
rng = np.random.default_rng(42)

# fail fast instead of silently saving unclustered features
if MODE not in ('apply', 'train'):
    raise ValueError(f"MODE must be 'apply' or 'train', got {MODE!r}")
# zip() below would silently drop rats without a day list (or vice versa)
assert len(rats) == len(days_allrats), 'rats and days_allrats must be aligned one-to-one'


def _warn_high_corr(data):
    """Print how many distinct channel pairs correlate above 0.9 (the 'time' column is ignored)."""
    high = (data.drop('time', axis=1).corr() > 0.9).to_numpy()
    # strict upper triangle: each pair is counted once and self-correlation is excluded,
    # whatever the number of channels
    n_pairs = int(np.triu(high, k=1).sum())
    if n_pairs > 0:
        print(f'warning: found {n_pairs} highly correlated channel pairs')
    else:
        print('all corrs < 0.9')


def _filter_recording(data):
    """Filter ECoG (0.5-45 Hz) and EMG (5-100 Hz, notch_50 enabled) with ecog_functions.filter_channel."""
    filtered_data = pd.DataFrame()
    for ch in ('l_ecog', 'r_ecog'):
        filtered_data[ch] = filter_channel(data[ch], fstart=0.5, fstop=45, sr=SAMPLE_RATE, center=True, notch_50=False)
    #LFP data not needed for sleep scoring
    filtered_data['emg'] = filter_channel(data['emg'], fstart=5, fstop=100, sr=SAMPLE_RATE, center=True, notch_50=True)
    # index is attached after the columns, exactly as before, so no index alignment takes place
    filtered_data.index = data.index
    return filtered_data


def _band_components(filtered_data):
    """Merge both ECoG channels into their first PCA component and split it into the BANDWIDTHS bands.

    The filtered EMG is appended unchanged as the last column ('emg').
    """
    pca = PCA(n_components=2)
    merged_ecog = pca.fit_transform(filtered_data[['l_ecog', 'r_ecog']])[:, 0]
    print(pca.explained_variance_ratio_)
    freq_comps = pd.DataFrame()
    freq_comps.index = filtered_data.index
    for band, sos in bandwidth_filters.items():
        freq_comps[band] = signal.sosfiltfilt(sos, merged_ecog)
    freq_comps['emg'] = filtered_data['emg']
    return freq_comps


def _power_features(freq_comps):
    """Turn each component into power (signal_to_power), optionally make ECoG bands relative,
    standardize every column and smooth with a 60-row rolling median."""
    pows, index = [], []
    for col in freq_comps.columns:
        power, ind, _ = signal_to_power(freq_comps[col])
        pows.append(power)
        index.append(ind)
    freq_powers = pd.DataFrame(pows).T
    freq_powers.columns = freq_comps.columns
    freq_powers.index = index[0]
    if RELATIVE:
        # each ECoG band divided by the row's summed ECoG power; EMG is excluded because it is not
        # cortical and would otherwise dominate the sum
        cortical = freq_powers.columns.drop('emg')
        freq_powers[cortical] = freq_powers[cortical].div(freq_powers[cortical].sum(axis=1), axis=0)
    scaler = StandardScaler()
    freq_powers[freq_powers.columns] = scaler.fit_transform(freq_powers[freq_powers.columns])
    # the rolling median leaves the first 59 rows NaN; they are back-filled
    return freq_powers.rolling(60).median().bfill()


def _train_inductive_clusterer(X):
    """Ward-cluster a random subset of X into N_CLUSTERS groups and return the fitted InductiveClusterer.

    InductiveClusterer is not defined in this notebook (presumably ecog_functions; it appears to be the
    scikit-learn "inductive clustering" example, which fits a clone of `clusterer` on its input and
    trains `classifier` on the resulting labels - TODO confirm). Fitting it on the subset rather than on
    all of X is what keeps the agglomerative step within memory; the classifier then labels all of X.
    """
    subset = X[rng.choice(len(X), size=min(CLUSTER_SUBSET_SIZE, len(X)), replace=False)]
    clusterer = AgglomerativeClustering(n_clusters=N_CLUSTERS, linkage='ward')
    classifier = RandomForestClassifier(random_state=42)
    return InductiveClusterer(clusterer, classifier).fit(subset)


for rat, days in tqdm(list(zip(rats, days_allrats))):
    rat_id = str(rat).strip('\\')
    for day in days:
        print(f'{rat}, {day} started')
        sel_paths = [p for p in paths if day in p and rat in p]
        print(len(sel_paths))

        savepath = RESULTS_DIR + rat_id + '_' + str(day) + DATA_SUFFIX
        #don't overwrite
        if os.path.isfile(savepath) and not OVERWRITE:
            print(f'{savepath} exists, skipping')
            continue
        if not sel_paths:
            print(f'no recordings found for {rat}, {day}, skipping')
            continue

        #read data, convert to mV and index by timestamp
        data = read_txt(sel_paths)
        if len(data) < 1:
            continue
        data = convert_to_mv(data)
        data.index = pd.to_datetime(data['time'])
        print('data read done')

        if WARN_BY_CORR:
            _warn_high_corr(data)

        filtered_data = _filter_recording(data)
        print('filtering done')

        freq_comps = _band_components(filtered_data)
        freq_powers = _power_features(freq_comps)

        ind_learner_path = './models/ind_' + rat_id + MODEL_SUFFIX
        # per-recording copy of MODE: falling back to training for one rat must NOT switch every later
        # rat/day to training (that would retrain and overwrite their saved clusterers)
        mode = MODE
        if mode == 'apply' and not os.path.isfile(ind_learner_path):
            print(f"{ind_learner_path} doesn't exist, training instead")
            mode = 'train'

        X = freq_powers.values
        if mode == 'apply':
            # NOTE: pickle.load can execute arbitrary code - only load models produced by this notebook
            with open(ind_learner_path, 'rb') as f:
                inductive_learner = pickle.load(f)
        else:
            inductive_learner = _train_inductive_clusterer(X)
            os.makedirs(os.path.dirname(ind_learner_path), exist_ok=True)
            with open(ind_learner_path, 'wb') as f:
                pickle.dump(inductive_learner, f)

        freq_powers['cluster'] = inductive_learner.predict(X)
        freq_powers['cluster'] = freq_powers['cluster'].astype('category')

        if mode == 'train':
            # median feature profile per cluster; 'sleep' starts as 'unknown' for later labelling
            grouped = freq_powers.groupby('cluster').median()
            grouped['cluster'] = grouped.index
            grouped['sleep'] = 'unknown'
            os.makedirs('./res_temp', exist_ok=True)
            with open('./res_temp/grouped_cls_' + rat_id + MODEL_SUFFIX, 'wb') as f:
                pickle.dump(grouped, f)

        with open(savepath, 'wb') as f:
            pickle.dump(freq_powers, f)

        print(f'{rat}, {day} done')
print('all done')

 45%|████▌     | 14/31 [00:00<00:00, 130.71it/s]

\51\, 20241703 started
11
C:/Users/marty/Projects/sleep_new/res_temp/51_20241703_clustered_rolled_50cl_fullfixed.pkl exists, skipping
\51\, 20241803 started
23
C:/Users/marty/Projects/sleep_new/res_temp/51_20241803_clustered_rolled_50cl_fullfixed.pkl exists, skipping
\51\, 20241903 started
25
C:/Users/marty/Projects/sleep_new/res_temp/51_20241903_clustered_rolled_50cl_fullfixed.pkl exists, skipping
\51\, 20242003 started
25
C:/Users/marty/Projects/sleep_new/res_temp/51_20242003_clustered_rolled_50cl_fullfixed.pkl exists, skipping
\51\, 20242103 started
24
C:/Users/marty/Projects/sleep_new/res_temp/51_20242103_clustered_rolled_50cl_fullfixed.pkl exists, skipping
\51\, 20242203 started
24
C:/Users/marty/Projects/sleep_new/res_temp/51_20242203_clustered_rolled_50cl_fullfixed.pkl exists, skipping
\51\, 20242303 started
24
C:/Users/marty/Projects/sleep_new/res_temp/51_20242303_clustered_rolled_50cl_fullfixed.pkl exists, skipping
\51\, 20242403 started
25
C:/Users/marty/Projects/sleep_new/re

 45%|████▌     | 14/31 [00:20<00:00, 130.71it/s]

G:\DATA\tween_cbn_sleep\33\timestamped\202406010947.txt
G:\DATA\tween_cbn_sleep\33\timestamped\202406010954.txt
G:\DATA\tween_cbn_sleep\33\timestamped\202406011058.txt
G:\DATA\tween_cbn_sleep\33\timestamped\202406011158.txt
G:\DATA\tween_cbn_sleep\33\timestamped\202406011258.txt
G:\DATA\tween_cbn_sleep\33\timestamped\202406011358.txt
G:\DATA\tween_cbn_sleep\33\timestamped\202406011458.txt
G:\DATA\tween_cbn_sleep\33\timestamped\202406011558.txt
G:\DATA\tween_cbn_sleep\33\timestamped\202406011658.txt
G:\DATA\tween_cbn_sleep\33\timestamped\202406011758.txt
G:\DATA\tween_cbn_sleep\33\timestamped\202406011858.txt
G:\DATA\tween_cbn_sleep\33\timestamped\202406011958.txt
G:\DATA\tween_cbn_sleep\33\timestamped\202406012058.txt
G:\DATA\tween_cbn_sleep\33\timestamped\202406012157.txt
G:\DATA\tween_cbn_sleep\33\timestamped\202406012257.txt
G:\DATA\tween_cbn_sleep\33\timestamped\202406012357.txt
data read done
filtering done
[0.90968224 0.09031776]
\33\, 20240601 done
\33\, 20240801 started
24
G:\D

 58%|█████▊    | 18/31 [16:06<11:37, 53.67s/it] 


KeyboardInterrupt: 