In [None]:
import numpy as np
from joblib import Memory
from sklearn.model_selection import RepeatedStratifiedKFold, cross_val_score

from pre_processing.preprocessing import Pre_Processing
from transfer_learning.tl_classifier import TL_Classifier
from transfer_learning import TLSplitter, encode_datasets
from loaddata import Dataset_Left_Right_MI

# Experiment parameters
dataset_name = 'Pan2023'
fs = 250
freqband = [8,30]
datapath = r'E:\工作进展\小论文2023会议\数据处理python\datasets'

# Load data: subjects 1, 2 and 3, each fetched separately so that every
# subject ends up as its own entry (domain) in the lists handed to encode_datasets.
low_hz, high_hz = freqband
dataset = Dataset_Left_Right_MI(dataset_name, fs,
                                fmin=low_hz, fmax=high_hz,
                                tmin=0, tmax=4,
                                path=datapath)
sdata, slabel = [], []
for subject_id in range(1, 4):
    subject_data, subject_label = dataset.get_data([subject_id])
    sdata.append(subject_data)
    slabel.append(subject_label)

X, y_enc, domain = encode_datasets(sdata, slabel)
print(X.shape, y_enc.shape, len(domain))
print(domain)

In [None]:
import numpy as np
from joblib import Memory
from sklearn.model_selection import RepeatedStratifiedKFold, cross_val_score

from pre_processing.preprocessing import Pre_Processing
from transfer_learning.tl_classifier import TL_Classifier
from transfer_learning import TLSplitter, encode_datasets
from loaddata import Dataset_Left_Right_MI

# The parameters (dataset_name, fs, freqband, datapath) and the encoded data
# (X, y_enc, domain) are produced by the data-loading cell above. This cell
# used to contain a verbatim copy of that code, which re-read the whole
# dataset a second time for no benefit.

# Joblib cache directory passed to the classifier
cachedir = '../my_cache_directory'
memory = Memory(cachedir, verbose=0)

# Transfer-learning target: the first encoded domain
target_domain = domain[0]

# Preprocessing step. fs_old is given the loading rate `fs` (previously a
# repeated literal 250 that would silently go stale if `fs` were changed).
preprocess = Pre_Processing(fs_new=160, fs_old=fs,
                            n_channels=np.arange(0, 28),
                            start_time=0.5, end_time=3.5,
                            lowcut=None, highcut=None)

# Baseline pipeline configuration: EA / CSP / MIC-K / SVM
Model = TL_Classifier(dpa_method='EA',
                      fee_method='CSP',
                      fes_method='MIC-K',
                      clf_method='SVM',
                      pre_est=preprocess.process,
                      memory=memory,
                      target_domain=target_domain,
                      )

# Cross-validation: 5-fold stratified, repeated 5 times with a fixed seed,
# wrapped by TLSplitter for the target domain
cv = RepeatedStratifiedKFold(n_splits=5, n_repeats=5, random_state=42)
tl_cv = TLSplitter(target_domain=target_domain, cv=cv)



In [None]:
# Alternative pipeline configuration: RPA / TS / MIC-K / LR.
# Rebinds `Model`; `preprocess` and `memory` come from the setup cell.
# (A stray, unused debug assignment `a = 0` was removed from this cell.)
Model = TL_Classifier(dpa_method='RPA',
                      fee_method='TS',
                      fes_method='MIC-K',
                      clf_method='LR',
                      pre_est=preprocess.process,
                      memory=memory,
                      target_domain=domain[0],
                      )

In [None]:
# Manual cross-validation: refit Model on each training split, score it on
# the matching test split, then summarise the per-fold accuracies.
acc = []
for train_idx, test_idx in tl_cv.split(X, y_enc):
    Model.fit(X[train_idx], y_enc[train_idx])
    fold_score = Model.score(X[test_idx], y_enc[test_idx])
    acc.append(fold_score)
    print("Score: %0.2f" % fold_score)

mean_acc, std_acc = np.mean(acc), np.std(acc)
print("Accuracy: %0.2f (+/- %0.2f)" % (mean_acc, std_acc))

In [None]:
# Same evaluation through scikit-learn's cross_val_score, with n_jobs=15
# parallel jobs.
scores = cross_val_score(Model, X, y_enc, n_jobs=15, cv=tl_cv)
score_mean, score_std = scores.mean(), scores.std()
print("Accuracy: %0.2f (+/- %0.2f)" % (score_mean, score_std))

In [None]:
from machine_learning.csp import FBCSP
from utils import check_pipeline_compatibility as cpc
from utils import ensure_pipeline

est = FBCSP(fs=250)

# Print, in this order: whether the estimator object is callable, the
# result of the pipeline-compatibility check, and ensure_pipeline's output.
for probe in (callable, cpc, ensure_pipeline):
    print(probe(est))

In [None]:
import numpy as np
from sklearn.model_selection import RepeatedStratifiedKFold
from transfer_learning import TLSplitter

# 5-fold stratified CV repeated 3 times (fixed seed), wrapped for target
# domain 'S1' with no_calibration=True.
# NOTE(review): 'S1' is hard-coded here while earlier cells use domain[0];
# confirm they refer to the same subject.
cv = RepeatedStratifiedKFold(random_state=42, n_repeats=3, n_splits=5)
tl_cv = TLSplitter(cv=cv, target_domain='S1', no_calibration=True)

In [None]:
from sklearn.model_selection import StratifiedShuffleSplit

# 10 stratified random splits with a fixed seed, wrapped for target domain
# 'S1' with no_calibration=False.
cv = StratifiedShuffleSplit(random_state=42, n_splits=10)
tl_cv = TLSplitter(no_calibration=False, cv=cv, target_domain='S1')

In [None]:
# Shrink the training fraction of the wrapped splitter to 20%. test_size is
# left unset, so scikit-learn uses the complement of train_size for the test part.
setattr(tl_cv.cv, 'train_size', 0.2)

In [None]:
# Confirm the training fraction currently configured on the wrapped splitter
current_train_size = tl_cv.cv.train_size
print(current_train_size)

In [None]:
# Print the number of training and test indices produced for every fold
for fold in tl_cv.split(X, y_enc):
    print(*map(len, fold))

In [None]:
from inspect import signature
# Import exactly the names the registry needs instead of `import *`, so the
# cell's dependencies are visible and nothing else in the namespace is
# silently shadowed.
from transfer_learning.utilities import (
    SVC, LDA, LR, KNN, DTC, RFC, ETC, ABC, GBC, GNB, MLP,
    CSP, TRCSP, MDM, FgMDM, TSclassifier, RKNN, RKSVM, TS,
    TLCenter, TLStretch, TLRotate, RCT, STR, ROT,
)

# Registry mapping lower-case short names to estimator classes
# (all provided by transfer_learning.utilities).
classifiers = {
    'svm': SVC,
    'lda': LDA,
    'lr':  LR,
    'knn': KNN,
    'dtc': DTC,
    'rfc': RFC,
    'etc': ETC,
    'abc': ABC,
    'gbc': GBC,
    'gnb': GNB,
    'mlp': MLP,
    'csp': CSP,
    'trcsp': TRCSP,
    'mdm': MDM,
    'fgmdm': FgMDM,
    'tsclassifier': TSclassifier,
    'rknn': RKNN,
    'rksvm': RKSVM,
    'ts': TS,
    'tlcenter': TLCenter,
    'tlstretch': TLStretch,
    'tlrotate': TLRotate,
    'rct': RCT,
    'str': STR,
    'rot': ROT,
}

def supports_sample_weight(clf_name):
    """Report whether a registered estimator's ``fit`` accepts ``sample_weight``.

    Parameters
    ----------
    clf_name : str
        Case-insensitive key into the module-level ``classifiers`` dict.

    Returns
    -------
    bool
        True when ``sample_weight`` is an explicit parameter of the class's
        ``fit``; False when it is not, or when the class defines no ``fit``.

    Raises
    ------
    ValueError
        If ``clf_name`` is not a key of ``classifiers``.
    """
    clf_class = classifiers.get(clf_name.lower())
    if clf_class is None:
        raise ValueError(f"Classifier '{clf_name}' is not recognized.")

    # Inspect ``fit`` on the class itself rather than on a fresh instance:
    # constructors that require arguments would raise TypeError, and building
    # an estimator just to read a signature is wasted work.
    fit_method = getattr(clf_class, 'fit', None)
    if fit_method is None:
        return False

    # The unbound signature also lists ``self``; only the name lookup matters.
    return 'sample_weight' in signature(fit_method).parameters

# Exercise supports_sample_weight for each registered name, in a fixed order.
# 'tlstretch' and 'tlrotate' are registered but not exercised here.
# NOTE(review): the answers depend on the installed library versions, so no
# expected values are hard-coded.
for clf_key in ('svm', 'lda', 'lr', 'knn', 'dtc', 'rfc', 'etc', 'abc', 'gbc',
                'gnb', 'mlp', 'csp', 'trcsp', 'mdm', 'fgmdm', 'tsclassifier',
                'rknn', 'rksvm', 'ts', 'tlcenter', 'rct', 'str', 'rot'):
    print('Method ' + clf_key, supports_sample_weight(clf_key))



In [None]:
from transfer_learning.rpa import TLCenter_online
from pyriemann.utils.covariance import covariances
import numpy as np

# Smoke test of TLCenter_online.get_recenter on random covariance matrices.
# Use a dedicated, seeded array: the previous version assigned the noise to
# `X`, overwriting the dataset array loaded at the top of the notebook.
rng_demo = np.random.default_rng(42)
# 20 random arrays of shape 28 x 28, read by covariances as
# (n_matrices, n_channels, n_times).
X_noise = rng_demo.random((20, 28, 28))
C = covariances(X_noise, estimator='lwf')
TLCenter_online().get_recenter(C, sample_weight=None)




In [None]:
from transfer_learning.tl_classifier import TL_Classifier
from utils import extract_dict_keys

# Arguments for extract_dict_keys: module path, class name, method name and
# the dict name inside it ('prealignments' in TL_Classifier.check_dpa).
dpa_lookup = ('transfer_learning.tl_classifier', 'TL_Classifier', 'check_dpa', 'prealignments')
DPA_METHODS = extract_dict_keys(*dpa_lookup)

In [None]:
import numpy as np

# Demo: collapse every axis before the last three into the sample axis, so an
# array of shape (..., samples, channels, time_points) becomes 3-D. The axis
# names follow the variable names below.
# Use a dedicated, seeded array instead of overwriting the dataset array `X`.
rng_reshape = np.random.default_rng(0)
X_demo = rng_reshape.random((5, 10, 20))
input_shape = X_demo.shape
sample_count, channel_count, time_point_count = input_shape[-3], input_shape[-2], input_shape[-1]
# np.prod(()) is the float 1.0, which made the count a float and reshape raise
# TypeError; force an integer product.
new_sample_count = int(np.prod(input_shape[:-3], dtype=np.int64)) * sample_count
print(new_sample_count)

new_X = X_demo.reshape((new_sample_count, channel_count, time_point_count))
# Check the data survived the reshape: undoing it must give back the input.
# (The old check compared new_X with an identical reshape and was always True.)
print(np.array_equal(new_X.reshape(input_shape), X_demo))
print(new_X.shape)

In [14]:
y = np.random.randint(0, 2, (5, 1))
print(y)
# Repeat each row twice along axis 0: shape (5, 1) -> (10, 1)
new_y = y.repeat(2, axis=0)
print(new_y)
# Tile the column twice along axis 1: shape (5, 1) -> (5, 2)
new_y = np.tile(y, (1, 2))
print(new_y)

[[0]
 [0]
 [0]
 [1]
 [0]]
[[0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [1]
 [1]
 [0]
 [0]]
[[0 0]
 [0 0]
 [0 0]
 [1 1]
 [0 0]]
