# In [None]:
import numpy as np
import torch
import os
import random
from torch.utils.data import Dataset
import torch.nn.functional as F

# Seed Python's built-in `random` module so later calls to it are reproducible.
random.seed(1)
rmb_label = {"ASD": 0, "TD": 1}      # maps class (directory) name -> integer label

class ABIDEtxtDataset(Dataset):
    """Dataset of ABIDE ``.1D`` text files served as correlation matrices.

    Every sample is loaded with ``np.loadtxt``; the correlation matrix between
    the file's columns is returned as a float tensor with a leading singleton
    dimension, together with the integer label derived from the name of the
    directory that contains the file.
    """

    # Directory name -> integer class label.
    label_name = {"ASD": 0, "TD": 1}

    def __init__(self, data_dir, transform=None):
        """
        Index every ``.1D`` file found below ``data_dir``.

        :param data_dir: str, root path of the data set; files are expected in
            sub-directories named after the keys of ``label_name``
        :param transform: optional callable applied to the correlation tensor
            of each sample before it is returned
        """
        # data_info holds (path, label) pairs; the DataLoader reads samples by index.
        self.data_info = self.get_txt_info(data_dir, self.label_name)
        self.transform = transform
        print('Number of samples:', len(self.data_info))
        # Guard the preview: an empty data set must not crash inside a print.
        if self.data_info:
            print('Sample info:', self.data_info[0])

    def __getitem__(self, index):
        """
        Load one sample.

        :param index: int, position in ``data_info``
        :return: tuple ``(cor, label)`` where ``cor`` is the (optionally
            transformed) correlation tensor and ``label`` is an int
        """
        path_txt, label = self.data_info[index]
        txt_data = np.loadtxt(path_txt)
        # np.corrcoef treats each row as a variable, so transpose to correlate
        # the columns of the file with each other.
        cor = np.corrcoef(txt_data.T)
        # Convert to a float32 tensor and add a leading singleton dimension.
        cor = torch.from_numpy(cor).float().unsqueeze(0)

        if self.transform is not None:
            # Apply the transform to the tensor that is actually returned.
            cor = self.transform(cor)

        return cor, label

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.data_info)

    @staticmethod
    def get_txt_info(data_dir, label_map=None):
        """
        Collect ``(path, label)`` pairs for all ``.1D`` files below ``data_dir``.

        :param data_dir: str, root directory to walk
        :param label_map: optional dict mapping directory name -> label;
            defaults to ``ABIDEtxtDataset.label_name``
        :return: list of ``(str, int)`` tuples in a deterministic (sorted) order
        """
        if label_map is None:
            label_map = ABIDEtxtDataset.label_name

        data_info = []
        for root, dirs, _ in os.walk(data_dir):
            # Sorting in place makes os.walk's traversal order reproducible.
            dirs.sort()
            for sub_dir in dirs:
                # Directories that do not name a class carry no label; skip
                # them instead of failing with a KeyError.
                if sub_dir not in label_map:
                    continue
                label = int(label_map[sub_dir])
                class_dir = os.path.join(root, sub_dir)

                # os.listdir order is arbitrary; sort so that sample indices
                # are stable across runs and machines.
                for txt_name in sorted(os.listdir(class_dir)):
                    if txt_name.endswith('.1D'):
                        data_info.append((os.path.join(class_dir, txt_name), label))

        return data_info