### 加载数据

In [12]:
from openpyxl import Workbook
import os
import math

# Input files all live in one local data directory.
base_dir = r'D:\项目资料\基因表达\20220817'
genes_file, patients_file, m_genes_file = (
    os.path.join(base_dir, file_name)
    for file_name in ('Diff_genes_name.csv', '最终纳入的患者.csv', 'm_all.csv')
)

# Gene table: after skipping the header row, column 1 is used as the key and
# column 2 as the value of gene_dict (presumably gene id -> gene symbol, judging
# by the correlation output further down — confirm against the CSV).
gene_dict = {}
with open(genes_file) as genes_fh:
    next(genes_fh, None)  # skip the header row
    for row in genes_fh:
        fields = row.strip().split(',')
        gene_dict[fields[1]] = fields[2]
# Fixed gene ordering; position in this list is the column index of a gene
# in every expression vector built later.
gene_id_arr = list(gene_dict)
gene_num = len(gene_id_arr)
print('基因数量：', gene_num)

# Study subjects: first column of every non-header row of the patients file.
ensemble = []
with open(patients_file, encoding='utf-8') as patients_fh:
    next(patients_fh, None)  # skip the header row
    ensemble.extend(row.strip().split(',')[0] for row in patients_fh)

print('患者数量：', len(ensemble))


基因数量： 379
患者数量： 435


### 加载并过滤基因数据

In [13]:
# Build one expression vector per patient from the expression matrix file(s).
#
# Layout read by this cell: the first row is a header whose cells (after the
# first) are sample names; the first 12 characters of a sample name are the
# patient id. Every following row starts with a gene id (anything after the
# first '.' is dropped) followed by one value per sample column.
#
# Only patients listed in `ensemble` and genes listed in `gene_id_arr` are kept.
# Each kept value v is stored as log10(v + 1).

# O(1) lookups instead of list.index / list membership on every row and column.
gene_index = {gene_id: position for position, gene_id in enumerate(gene_id_arr)}
ensemble_ids = set(ensemble)

# (patient_id, column index) -> expression vector ordered like gene_id_arr.
column_data = {}
for file in [m_genes_file]:
    with open(file) as f:
        # Stream the file line by line rather than loading it whole.
        for idx, line in enumerate(f):
            arr = line.strip().split(',')
            if idx == 0:
                # enumerate gives the true column position even if two header
                # cells are identical (arr.index would return the first match).
                for col, name in enumerate(arr[1:], start=1):
                    patient_id = name[:12]
                    if patient_id in ensemble_ids:
                        column_data[(patient_id, col)] = [0] * len(gene_id_arr)
            else:
                gene_id_idx = gene_index.get(arr[0].split('.')[0])
                if gene_id_idx is None:
                    continue  # gene not in the selected gene list
                for (patient_id, pidx), values in column_data.items():
                    values[gene_id_idx] = math.log(float(arr[pidx]) + 1, 10)


# Collapse to one vector per patient. When a patient has several sample
# columns, the last column encountered wins (same as before).
gene_data_ = {}
pid = ''
for (patient_id, pidx), values in column_data.items():
    gene_data_[patient_id] = values
    pid = patient_id
gene_data = gene_data_

# Guard the summary so an empty match reports 0 genes instead of raising KeyError('').
gene_count = len(gene_data[pid]) if gene_data else 0
print('基因数据中患者数: %s，基因数据中基因量：%s' % (len(gene_data), gene_count))


基因数据中患者数: 435，基因数据中基因量：379


### 基因预测模型定义

In [14]:
import numpy as np
import pandas as pd
import os
import time
import torch
from torch import nn
from torch.utils.data import Subset, DataLoader, Dataset
from torch import optim
from sklearn.metrics import roc_auc_score
import sys

sys.path.append('HE2RNA')
from model import HE2RNA, fit, predict
from utils import compute_metrics

In [15]:
# Keyword arguments unpacked into HE2RNA(**model_params) by the training cells.
# output_dim follows the number of selected genes.
# NOTE(review): 'device' is fixed to 'cuda' — confirm a GPU is available where this runs.
model_params = dict(
    input_dim=2048,
    output_dim=gene_num,
    layers=[256, 256],
    ks=[10, 20, 50, 100, 200, 500, 1000, 2000, 5000],
    dropout=0.25,
    nonlin=nn.ReLU(),
    device='cuda',
)

# The model and its Adam optimizer (lr 3e-4, weight decay 1e-5) are created
# inside the training cells, once per run / per fold, not here.

# Passed as `params=` to HE2RNA's fit().
batch_size = 16
training_params = dict(
    max_epochs=100,
    patience=5,
    batch_size=batch_size,
    num_workers=0,
)

# Output locations handed to fit(): log directory and checkpoint root.
logdir = 'HE2RNA/logs'
savedir = 'HE2RNA/ckpts'
valid_projects = None

In [16]:
from scipy import stats
def stat_cor(pred_arr, labels_arr, out_files):
    """Report the per-gene Pearson correlation between labels and predictions.

    Genes are taken in the iteration order of the module-level ``gene_dict``;
    the gene at position ``i`` is matched with column ``i`` of every row.

    Args:
        pred_arr: sequence of rows of predicted values (one row per sample).
        labels_arr: sequence of rows of reference values, same layout.
        out_files: path of the report file; it is overwritten.

    Each gene yields one record ``gene, gene_dict[gene], r, p`` that is printed
    comma-separated and written tab-separated (one line per gene).
    """
    with open(out_files, 'w') as report:
        for column, (gene, gene_name) in enumerate(gene_dict.items()):
            observed = [sample[column] for sample in labels_arr]
            predicted = [sample[column] for sample in pred_arr]
            r, p = stats.pearsonr(observed, predicted)
            record = [str(value) for value in (gene, gene_name, r, p)]
            print(','.join(record))
            report.write('\t'.join(record) + '\n')
            

def fit_model(model, train_set, val_set, test_set, optimizer, k):
    """Train ``model`` for fold ``k`` through HE2RNA's ``fit``.

    The shared notebook settings (``valid_projects``, ``training_params``,
    ``logdir``, ``savedir``) are read from module scope. The checkpoint path
    handed to ``fit`` is ``<savedir>/model_<k>``.

    Args:
        model: model to train.
        train_set: training dataset.
        val_set: validation dataset.
        test_set: dataset forwarded as ``test_set``.
        optimizer: optimizer bound to ``model``'s parameters.
        k: fold identifier, used only to name the checkpoint directory.

    Returns:
        None; whatever ``fit`` returns is discarded.
    """
    fold_checkpoint = os.path.join(savedir, 'model_' + str(k))
    keyword_settings = dict(
        test_set=test_set,
        params=training_params,
        optimizer=optimizer,
        logdir=logdir,
        path=fold_checkpoint,
    )
    fit(model, train_set, val_set, valid_projects, **keyword_settings)

#     stat_cor(preds, labels, os.path.join(savedir, 'results_per_fold_%s.txt' % k))
#     report = {}
#     report['correlation_fold_' + str(k)] = compute_metrics(labels, preds)
#     report = pd.DataFrame(report)
#     report.to_csv(os.path.join(savedir, 'results_per_fold_%s.csv' % k), index=False)


### 加载特征数据

In [17]:
import os
# Root folder of the extracted feature files (a network share).
feature_base_path = r'\\192.168.0.60\public\模型算法组\训练数据\zhangkuo\基础AI\DX-color-result0727特征'
# Note: the "val" path points at the folder named 'test'.
train_feature_path, val_feature_path = (
    os.path.join(feature_base_path, split_dir) for split_dir in ('train', 'test')
)

def get_all_files(path):
    """Recursively collect the ``.txt`` files found below ``path``.

    The traversal order is made deterministic (directories and file names are
    visited in sorted order), so the returned list — and anything split by
    position from it, such as an unshuffled K-fold — is the same on every run
    and every filesystem.

    Args:
        path: directory to walk. A missing directory yields an empty list.

    Returns:
        list of full file paths whose name ends with the ``.txt`` extension.
    """
    results = []
    for root, dirs, files in os.walk(path):
        # Sorting in place steers os.walk's top-down descent order.
        dirs.sort()
        results.extend(
            os.path.join(root, file_name)
            for file_name in sorted(files)
            # Require the real extension; a bare 'txt' suffix also matched
            # names such as 'notes_txt'.
            if file_name.endswith('.txt')
        )
    return results

def get_features(file_path):
    """Read a whitespace-separated numeric text file into a list of rows.

    Each non-blank line becomes one list of floats. Splitting is done on any
    run of whitespace, so trailing spaces, repeated spaces and blank lines no
    longer make ``float()`` fail; well-formed single-space files parse exactly
    as before.

    Args:
        file_path: path of the text file to read.

    Returns:
        list of rows, each a list of floats (empty list for an empty file).

    Raises:
        ValueError: if a token cannot be parsed as a float.
    """
    features = []
    with open(file_path) as f:
        for line in f:  # stream instead of materialising every line first
            tokens = line.split()
            if not tokens:
                continue  # blank line
            features.append([float(token) for token in tokens])
    return features

In [7]:
# import math
# train_files = get_all_files(train_feature_path)
# val_files = get_all_files(val_feature_path)

# line_counts = {}
# for file_path in train_files + val_files:
#     ct = 0
#     with open(file_path) as f:
#         print('processing %s' % file_path)
#         for idx,line in enumerate(f):
#             arr = line.split(' ')
#             for x in arr:
#                 if float(x) in [math.inf, -math.inf, math.nan]:
#                     print('illegal line: %s->%s' % (file_path, idx))
        
            
# #             ct = ct + 1
        
# #         print(file_path, ct)

# #     if ct not in line_counts:
# #         line_counts[ct] = 0
# #     line_counts[ct] = line_counts[ct] + 1

# # print(sorted([(k, v) for k, v in line_counts.items()], key=lambda x: x[0]))

processing \\192.168.0.60\public\模型算法组\训练数据\zhangkuo\基础AI\DX-color-result0727特征\train\M0\TCGA-3L-AA1B-01Z-00-DX1.8923A151-A690-40B7-9E5A-FCBEDFC2394F.txt
processing \\192.168.0.60\public\模型算法组\训练数据\zhangkuo\基础AI\DX-color-result0727特征\train\M0\TCGA-3L-AA1B-01Z-00-DX2.17CE3683-F4B1-4978-A281-8F620C4D77B4.txt


KeyboardInterrupt: 

In [28]:
# arr = sorted([(k, v) for k, v in line_counts.items()], key=lambda x: x[0])
# ct, delta = 0, 2000
# for k, v in arr:
#     if k >= delta:
#         ct = ct + v
# print(ct)

351


In [18]:
class MyIterableDataset(torch.utils.data.IterableDataset):
    """Stream (features, expression vector) pairs, one per feature file.

    For every path in ``files`` the first 12 characters of the file name are
    taken as the patient id. Files whose patient id is not a key of the
    module-level ``gene_data`` are skipped. The rows returned by
    ``get_features`` are repeated cyclically and cut to exactly ``num_tiles``
    rows, then transposed, so each yielded feature tensor has ``num_tiles``
    columns.

    Args:
        files: sequence of feature-file paths.
        num_tiles: fixed number of rows every sample is padded/truncated to
            (defaults to 8000, the value previously hard-coded).
    """

    def __init__(self, files, num_tiles=8000):
        # super(MyIterableDataset).__init__() created an unbound super object
        # and never ran the parent initialiser.
        super().__init__()
        self.files = files
        self.num_tiles = num_tiles

    def __iter__(self):
        return self.data_generator()

    def data_generator(self):
        """Yield ``(features.t(), labels)`` tensors for each usable file."""
        import warnings

        for file_path in self.files:
            patient_id = os.path.basename(file_path)[:12]
            if patient_id not in gene_data:
                continue

            rows = get_features(file_path)
            if not rows:
                # Nothing to repeat: the previous padding loop never ended here.
                warnings.warn('skipping feature file with no rows: %s' % file_path)
                continue

            # Convert only the rows that can end up in the sample, then repeat
            # on the tensor; this gives the same values as extending the
            # Python list up to num_tiles rows before conversion.
            base = torch.Tensor(rows[:self.num_tiles])
            repeats = -(-self.num_tiles // base.shape[0])  # ceiling division
            features = base.repeat(repeats, 1)[:self.num_tiles]

            yield features.t(), torch.Tensor(gene_data[patient_id])

In [20]:
# val_files = get_all_files(val_feature_path)
# train_dev_files = get_all_files(train_feature_path)

# test_files = [

# ]

# import shutil
# for file in val_files + train_dev_files:
#     if '-'.join(os.path.basename(file).split('-')[:6]).split('.')[0] in test_files:
#         shutil.move(file, file.replace('train', 'test'))


In [19]:
# 跑全部训练数据
# Train one model on the complete training split, validating on the files
# under val_feature_path.
val_files = get_all_files(val_feature_path)
train_dev_files = np.array(get_all_files(train_feature_path))

model = HE2RNA(**model_params)
optim_params = dict(params=model.parameters(), lr=3e-4, weight_decay=1e-5)
optimizer = optim.Adam(**optim_params)

train_dataset, dev_dataset, test_dataset = (
    MyIterableDataset(train_dev_files),
    MyIterableDataset(val_files),
    MyIterableDataset(val_files),
)

# Note: test_dataset is built but the call below passes dev_dataset as the
# test set as well; checkpoints go to <savedir>/model.
fit(
    model,
    train_dataset,
    dev_dataset,
    None,
    test_set=dev_dataset,
    params=training_params,
    optimizer=optimizer,
    logdir=logdir,
    path=os.path.join(savedir, 'model'),
)

0it [00:05, ?it/s]


In [None]:
# 5-fold训练
# 5-fold training over the training files; the files under val_feature_path
# are handed to every fold as the test set.
val_files = get_all_files(val_feature_path)
train_dev_files = np.array(get_all_files(train_feature_path))
test_dataset = MyIterableDataset(val_files)

from sklearn.model_selection import KFold
KF = KFold(n_splits=5)
for k, (train_idxs, dev_idxs) in enumerate(KF.split(train_dev_files), start=1):
    fold_train_files = train_dev_files[train_idxs]
    fold_dev_files = train_dev_files[dev_idxs]
    print('KFold: %d, train files: %d, dev files: %d' % (k, len(fold_train_files), len(fold_dev_files)))

    # Resume logic as written: folds 1-2 are skipped, fold 3 continues from
    # its saved checkpoint, later folds start from a fresh model.
    if k <= 2:
        continue
    if k == 3:
        # NOTE(review): torch.load unpickles a full model object — only load
        # checkpoints from a trusted source.
        model = torch.load('HE2RNA/ckpts/model_3/model.pt')
    else:
        model = HE2RNA(**model_params)

    # A new optimizer is bound to the parameters of this fold's model.
    optim_params = dict(params=model.parameters(), lr=3e-4, weight_decay=1e-5)
    optimizer = optim.Adam(**optim_params)

    train_dataset = MyIterableDataset(fold_train_files)
    dev_dataset = MyIterableDataset(fold_dev_files)
    fit_model(model, train_dataset, dev_dataset, test_dataset, optimizer, k)


KFold: 1, train files: 284, dev files: 71
KFold: 2, train files: 284, dev files: 71
KFold: 3, train files: 284, dev files: 71


18it [13:56, 46.48s/it]


Epoch 1/100 - 836.62s




loss: 0.3355, val loss: 0.4913
correlations: nan
save model: HE2RNA/ckpts\model_3\model.pt


18it [13:52, 46.23s/it]


Epoch 2/100 - 1135.27s




loss: 0.2953, val loss: 0.5061
correlations: nan


18it [13:55, 46.41s/it]


Epoch 3/100 - 1133.86s




loss: 0.3135, val loss: 0.4626
correlations: nan
save model: HE2RNA/ckpts\model_3\model.pt


18it [13:49, 46.10s/it]


Epoch 4/100 - 1126.65s




loss: 0.2891, val loss: 0.4333
correlations: nan
save model: HE2RNA/ckpts\model_3\model.pt


18it [13:55, 46.40s/it]


Epoch 5/100 - 1133.87s




loss: 0.2973, val loss: 0.4603
correlations: nan


18it [13:56, 46.46s/it]


Epoch 6/100 - 1138.63s




loss: 0.2903, val loss: 0.4339
correlations: nan


18it [13:56, 46.49s/it]


Epoch 7/100 - 1137.86s




loss: 0.2809, val loss: 0.3985
correlations: nan
save model: HE2RNA/ckpts\model_3\model.pt


18it [14:00, 46.71s/it]


Epoch 8/100 - 1139.63s




loss: 0.2812, val loss: 0.4330
correlations: nan


18it [13:50, 46.15s/it]


Epoch 9/100 - 1128.46s




loss: 0.2842, val loss: 0.4076
correlations: nan


18it [13:55, 46.41s/it]


Epoch 10/100 - 1134.25s




loss: 0.2912, val loss: 0.3639
correlations: nan
save model: HE2RNA/ckpts\model_3\model.pt


18it [13:50, 46.13s/it]


Epoch 11/100 - 1126.44s




loss: 0.2919, val loss: 0.4000
correlations: nan


18it [13:47, 45.97s/it]


Epoch 12/100 - 1123.05s




loss: 0.2793, val loss: 0.4208
correlations: nan


18it [13:52, 46.26s/it]


Epoch 13/100 - 1129.66s




loss: 0.2885, val loss: 0.4345
correlations: nan


18it [13:38, 45.50s/it]


Epoch 14/100 - 1114.05s




loss: 0.2789, val loss: 0.3809
correlations: nan


18it [14:02, 46.81s/it]


Epoch 15/100 - 1136.26s




loss: 0.2773, val loss: 0.3650
correlations: nan
Early stopping at epoch 15
KFold: 4, train files: 284, dev files: 71


18it [16:37, 55.44s/it]


Epoch 1/100 - 997.96s




loss: 0.6779, val loss: 0.6173
correlations: nan
save model: HE2RNA/ckpts\model_4\model.pt


18it [14:36, 48.70s/it]


Epoch 2/100 - 1121.25s




loss: 0.3744, val loss: 0.6390
correlations: nan


18it [14:39, 48.87s/it]


Epoch 3/100 - 1123.66s




loss: 0.3509, val loss: 0.5307
correlations: nan
save model: HE2RNA/ckpts\model_4\model.pt


18it [15:08, 50.48s/it]


Epoch 4/100 - 1157.82s




loss: 0.3325, val loss: 0.4780
correlations: nan
save model: HE2RNA/ckpts\model_4\model.pt


18it [14:44, 49.16s/it]


Epoch 5/100 - 1130.29s




loss: 0.3223, val loss: 0.4590
correlations: nan
save model: HE2RNA/ckpts\model_4\model.pt


18it [14:44, 49.13s/it]


Epoch 6/100 - 1129.43s




loss: 0.2967, val loss: 0.5084
correlations: nan


18it [14:47, 49.33s/it]


Epoch 7/100 - 1133.30s




loss: 0.3139, val loss: 0.4419
correlations: nan
save model: HE2RNA/ckpts\model_4\model.pt


18it [15:22, 51.25s/it]


Epoch 8/100 - 1174.02s




loss: 0.3021, val loss: 0.4577
correlations: nan


18it [15:43, 52.41s/it]


Epoch 9/100 - 1199.42s




loss: 0.3165, val loss: 0.4230
correlations: nan
save model: HE2RNA/ckpts\model_4\model.pt


18it [15:06, 50.34s/it]


Epoch 10/100 - 1157.03s




loss: 0.2991, val loss: 0.3964
correlations: nan
save model: HE2RNA/ckpts\model_4\model.pt


18it [15:14, 50.78s/it]


Epoch 11/100 - 1169.67s




loss: 0.2999, val loss: 0.4204
correlations: nan


18it [14:40, 48.90s/it]


Epoch 12/100 - 1126.85s




loss: 0.2986, val loss: 0.4171
correlations: nan


18it [14:44, 49.14s/it]


Epoch 13/100 - 1128.84s




loss: 0.2891, val loss: 0.4115
correlations: nan


18it [15:54, 53.01s/it]


Epoch 14/100 - 1200.65s




loss: 0.2936, val loss: 0.3554
correlations: nan
save model: HE2RNA/ckpts\model_4\model.pt


16it [12:54, 57.16s/it]

In [8]:
# 5-fold模型结果取表达值平均
# Run every fold's saved model on the files under val_feature_path and keep
# the predictions of each fold; they are averaged in a later cell.
# `labels` is overwritten on every pass, so it ends up holding the labels
# returned for the last model.
val_files = get_all_files(val_feature_path)
all_preds, labels = [], []
for path in ['model_1', 'model_2', 'model_3', 'model_4', 'model_5']:
    print(path)
    checkpoint_file = os.path.join('HE2RNA/ckpts', path, 'model.pt')
    model = torch.load(checkpoint_file)
    test_dataset = MyIterableDataset(val_files)
    test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=0)
    preds, labels = predict(model, test_loader)
    all_preds.append(preds)
    print(len(all_preds))

model_1
1
model_2
2
model_3
3
model_4
4
model_5
5


In [9]:
# 计算相关系数和P值
# Average the predictions over the fold axis (axis 0 of the stacked list),
# then report per-gene correlation and p-value to <savedir>/results_per_fold.txt.
preds = np.asarray(all_preds).mean(axis=0)
stat_cor(preds, labels, os.path.join(savedir, 'results_per_fold.txt'))

ENSG00000005981,ASB4,0.057984525494429034,0.5915318862225967
ENSG00000007038,PRSS21,-0.12212664781426619,0.25699689937853804
ENSG00000007216,SLC13A2,0.16302515892565086,0.12911086719031825
ENSG00000007350,TKTL1,-0.15493959891537337,0.14947156710235524
ENSG00000007402,CACNA2D2,-0.012844692101669533,0.9054528872657591
ENSG00000029559,IBSP,0.033271870269288135,0.7582799643800965
ENSG00000036565,SLC18A1,-0.26626577155739234,0.01215804094659939
ENSG00000057149,SERPINB3,-0.04019327962002365,0.7100376080949337
ENSG00000063015,SEZ6,-0.10289290345204201,0.3401064251405881
ENSG00000068985,PAGE1,nan,nan
ENSG00000075673,ATP12A,-0.08566398974805044,0.4274460154689305
ENSG00000081800,SLC13A1,nan,nan
ENSG00000088726,TMEM40,0.10915778523553656,0.3113571028040993
ENSG00000096395,MLN,nan,nan
ENSG00000101443,WFDC2,-0.053631280092617445,0.6197040243039623
ENSG00000103355,PRSS33,0.03576633682316324,0.7407755185559847
ENSG00000104327,CALB1,0.08639098518598878,0.42351988986107286
ENSG00000104755,ADAM2,nan,na



### 使用MLP预测，然后合并结果

In [14]:
# Scratch check: what os.path.basename returns for one feature-file path on the
# network share (the cell output below shows the bare file name).
os.path.basename('\\\\192.168.0.60\\public\\模型算法组\\训练数据\\zhangkuo\\基础AI\\DX-color-result0727特征\\test\\M0\\TCGA-A6-2672-01Z-00-DX1.e2a845c8-6d77-4120-9f43-abec84a66c1c.txt')

'TCGA-A6-2672-01Z-00-DX1.e2a845c8-6d77-4120-9f43-abec84a66c1c.txt'

In [15]:
from scipy import stats
def stat_cor(pred_arr, labels_arr, out_files=None):
    """Print the per-gene Pearson correlation between labels and predictions.

    This redefinition replaces the earlier three-argument ``stat_cor`` in the
    kernel. ``out_files`` is accepted as an optional third argument so that
    calls written for either definition keep working regardless of which cell
    ran last.

    Genes are taken in the iteration order of the module-level ``gene_dict``;
    the gene at position ``i`` is matched with column ``i`` of every row.

    Args:
        pred_arr: sequence of rows of predicted values (one row per sample).
        labels_arr: sequence of rows of reference values, same layout.
        out_files: optional report path. When given, one tab-separated line
            ``gene, gene_dict[gene], r, p`` per gene is written to it
            (the file is overwritten). When None, nothing is written.
    """
    report_lines = []
    for idx, gene in enumerate(gene_dict.keys()):
        X1 = [row[idx] for row in labels_arr]
        X2 = [row[idx] for row in pred_arr]
        r, p = stats.pearsonr(X1, X2)
        print('%s,%s,%s,%s' % (gene, gene_dict[gene], r, p))
        report_lines.append('%s\t%s\t%s\t%s\n' % (gene, gene_dict[gene], r, p))

    if out_files is not None:
        with open(out_files, 'w') as f:
            f.writelines(report_lines)
    

In [23]:
# NOTE(review): pred_max_arr and labels_max_arr are not assigned in any visible
# cell — confirm where they are produced before re-running this cell.
stat_cor(pred_max_arr, labels_max_arr)

ENSG00000005981,ASB4,-0.030443755460308922,0.7913235342576515
ENSG00000007038,PRSS21,0.03786920350218425,0.7420271235471032
ENSG00000007216,SLC13A2,-0.14590838453257224,0.20243072416706573
ENSG00000007350,TKTL1,0.03245718038044584,0.7778665864575642
ENSG00000007402,CACNA2D2,0.008386052109706735,0.9419102788590765
ENSG00000029559,IBSP,0.2949524682633532,0.00875463791934883
ENSG00000036565,SLC18A1,-0.11764470305971175,0.30498473306381796
ENSG00000057149,SERPINB3,-0.02175041395546692,0.8500807451293372
ENSG00000063015,SEZ6,-0.3484878464517442,0.0017677534394488204
ENSG00000068985,PAGE1,0.15717476248561738,0.16935571719319684
ENSG00000075673,ATP12A,0.0015944805179871,0.9889458870393479
ENSG00000081800,SLC13A1,-0.15765389534533647,0.16804296514896988
ENSG00000088726,TMEM40,0.10458129527182848,0.36217585173375905
ENSG00000096395,MLN,0.08258598091723698,0.4722438061759445
ENSG00000101443,WFDC2,0.10575386824305572,0.35678891173046684
ENSG00000103355,PRSS33,-0.015283981583851,0.8943413819234552

In [19]:
# NOTE(review): pred_min_arr and labels_arr are not assigned in any visible
# cell — confirm where they are produced before re-running this cell.
stat_cor(pred_min_arr, labels_arr)

ENSG00000005981,ASB4,0.08461255833228458,0.4614049966151423
ENSG00000007038,PRSS21,-0.12859229429984295,0.26185124842947854
ENSG00000007216,SLC13A2,0.0740538583649224,0.5193470238564177
ENSG00000007350,TKTL1,0.013530212995364162,0.9064070354520173
ENSG00000007402,CACNA2D2,-0.07663568776283694,0.504847285021527
ENSG00000029559,IBSP,-0.2364353754011208,0.037153347097998835
ENSG00000036565,SLC18A1,0.08065088485296844,0.48272030175287817
ENSG00000057149,SERPINB3,-0.025941999356550394,0.8216273561409368
ENSG00000063015,SEZ6,0.2527539755388899,0.025574845345135582
ENSG00000068985,PAGE1,0.04357645096036088,0.7048191012489392
ENSG00000075673,ATP12A,-0.042497995107528365,0.7118002818118601
ENSG00000081800,SLC13A1,0.10011256063494595,0.383159871672875
ENSG00000088726,TMEM40,-0.03669284794009285,0.7497736311166264
ENSG00000096395,MLN,0.016774343884257175,0.8841066126241026
ENSG00000101443,WFDC2,-0.12200946368079849,0.28726399517784756
ENSG00000103355,PRSS33,0.0035983904784849994,0.975056508975891

In [21]:
# NOTE(review): pred_mean_arr and labels_arr are not assigned in any visible
# cell — confirm where they are produced before re-running this cell.
stat_cor(pred_mean_arr, labels_arr)

ENSG00000005981,ASB4,0.03033634415665392,0.7920431835512495
ENSG00000007038,PRSS21,-0.15986913522610188,0.16207078224253554
ENSG00000007216,SLC13A2,0.24785569902376273,0.028675227249536424
ENSG00000007350,TKTL1,0.10363982051877206,0.36653704198214326
ENSG00000007402,CACNA2D2,-0.009625569703540341,0.9333425515802353
ENSG00000029559,IBSP,-0.0651711268444628,0.5707914537608454
ENSG00000036565,SLC18A1,0.13260274863436436,0.2471387392255085
ENSG00000057149,SERPINB3,-0.1510474430275085,0.18681323264594973
ENSG00000063015,SEZ6,0.12260410089045093,0.28490362222498045
ENSG00000068985,PAGE1,0.10392670332019747,0.36520472853249925
ENSG00000075673,ATP12A,-0.00974717996757429,0.9325023552861534
ENSG00000081800,SLC13A1,0.2342682318346237,0.03897792894609891
ENSG00000088726,TMEM40,-0.16196127794538892,0.15657600522944237
ENSG00000096395,MLN,0.05962161322451442,0.6040926798983984
ENSG00000101443,WFDC2,-0.09233801254151783,0.4213619327314606
ENSG00000103355,PRSS33,-0.07094111393008437,0.537104231746364

In [49]:
# NOTE(review): pred_quantile_arr and labels_arr are not assigned in any visible
# cell — confirm where they are produced before re-running this cell. The genes
# in the output below differ from the earlier runs, so gene_dict presumably held
# a different gene list when this was executed — verify.
stat_cor(pred_quantile_arr, labels_arr)

ENSG00000007171,NOS2,-0.00042727658839101807,0.9970377178745256
ENSG00000021826,CPS1,0.046562781922578136,0.6856159476309227
ENSG00000026559,KCNG1,0.015038488682477648,0.8960289569420999
ENSG00000057468,MSH4,0.03748392434597125,0.7445614774119119
ENSG00000058404,CAMK2B,0.019389910973877725,0.8661914639782119
ENSG00000064655,EYA2,0.044902959681714545,0.6962655531677766
ENSG00000066405,CLDN18,0.17343885725387087,0.12886972678023953
ENSG00000075461,CACNG4,0.026670870467643243,0.8167019235949223
ENSG00000080166,DCT,-0.08900008816761489,0.4384126373089254
ENSG00000081051,AFP,0.030359743059869257,0.7918863974841348
ENSG00000084674,APOB,-0.24737845702426017,0.028993578853101054
ENSG00000088726,TMEM40,0.15047390816598663,0.18851162272823313
ENSG00000092345,DAZL,0.10869812578766648,0.3434821879593246
ENSG00000093134,VNN3,-0.013815661763889122,0.9044416758887885
ENSG00000095777,MYO3A,0.04413753262318597,0.701196614884642
ENSG00000096395,MLN,0.06065398473455846,0.59783326314781
ENSG00000099769,IG

In [None]:

# def gen_tensor_from_path(feature_path, batch_size):
#     features, labels = [], []
#     for subfolder in ['M0', 'M1', 'M2']:
#         print('subfolder: %s' % subfolder)
#         folder = os.path.join(feature_path, subfolder)
#         for file_name in os.listdir(folder):
#             print('file_name: %s' % file_name)
#             patient_id = file_name[:12]
#             if patient_id not in gene_data[subfolder]:
#                 continue

#             file_path = os.path.join(folder, file_name)
#             features_ = get_features(file_path)
#             features.extend(features_)

#             gene_arr = gene_data[subfolder][patient_id]
#             labels.extend([gene_arr for i in range(len(features_))])
            
#             if len(features) >= batch_size:
#                 rt_features, rt_labels = features[:batch_size], labels[:batch_size]
#                 features, labels = features[batch_size:], labels[batch_size:]
#                 yield rt_features, rt_labels
                
# #     print("特征维度：%s, 标签维度：%s" % (len(features), len(labels)))

# def data_generator(feature_path):
#     for subfolder in ['M0', 'M1', 'M2']:
#         folder = os.path.join(feature_path, subfolder)
#         for file_name in os.listdir(folder):
#             patient_id = file_name[:12]
#             if patient_id not in gene_data:
#                 continue

#             file_path = os.path.join(folder, file_name)
#             features = get_features(file_path)
#             gene_arr = gene_data[patient_id]
#             labels = [gene_arr for i in range(len(features))]
            
#             for feature, label in zip(features, labels):
#                 yield feature, label