# 加载数据
## 通义加载feature和label

In [None]:
import numpy as np
import tensorflow as tf
from keras.utils import to_categorical

class KerasDatasetGenerator:
    """Build a shuffled, batched ``tf.data.Dataset`` from a pair of ``.npy`` files.

    Features and labels are loaded with ``np.load``; labels are one-hot
    encoded with ``to_categorical`` before the dataset is built.
    """

    def __init__(self, features_path, labels_path, batch_size=128,
                 num_classes=19 * 19, channels_last=False):
        """
        Args:
            features_path: Path to the ``.npy`` file holding the feature array.
            labels_path: Path to the ``.npy`` file holding the label array.
            batch_size: Number of samples per batch produced by ``create_dataset``.
            num_classes: Number of one-hot classes for the labels. The default
                19 * 19 presumably means one class per point of a 19x19 board
                (TODO confirm against how the labels were generated).
            channels_last: If True, 4-D features are transposed from
                (N, C, H, W) to (N, H, W, C) before the dataset is built.
        """
        self.features_path = features_path
        self.labels_path = labels_path
        self.batch_size = batch_size
        self.num_classes = num_classes
        self.channels_last = channels_last

    def load_data(self):
        """Load and return ``(features, labels)`` as NumPy arrays."""
        features = np.load(self.features_path)
        labels = np.load(self.labels_path)
        return features, labels

    @staticmethod
    def _to_channels_last(features):
        """Transpose a 4-D (N, C, H, W) array to (N, H, W, C).

        Raises:
            ValueError: if ``features`` is not 4-D.
        """
        if features.ndim != 4:
            raise ValueError(
                f"channels_last conversion expects 4-D features, got shape {features.shape}"
            )
        return np.transpose(features, (0, 2, 3, 1))

    def preprocess_data(self, features, labels):
        """Optionally reorder feature axes and one-hot encode the labels.

        Add any further preprocessing (normalisation, standardisation, ...)
        here.

        Returns:
            ``(features, one_hot_labels)`` where ``one_hot_labels`` has
            ``self.num_classes`` columns.
        """
        if self.channels_last:
            features = self._to_channels_last(features)
        return features, to_categorical(labels, num_classes=self.num_classes)

    def create_dataset(self):
        """Load, preprocess, shuffle, batch and prefetch the data.

        Returns:
            A ``tf.data.Dataset`` yielding ``(features_batch, labels_batch)``.

        Raises:
            ValueError: if the features and labels arrays differ in length.
        """
        features, labels = self.load_data()
        if len(features) != len(labels):
            raise ValueError(
                f"features and labels differ in length: {len(features)} vs {len(labels)}"
            )
        features, labels = self.preprocess_data(features, labels)

        dataset = tf.data.Dataset.from_tensor_slices((features, labels))
        # A buffer as large as the dataset gives a full shuffle. tf.data
        # requires buffer_size > 0, so clamp it for an empty input.
        dataset = dataset.shuffle(buffer_size=max(1, len(features)))
        dataset = dataset.batch(self.batch_size)
        # Prefetch so input preparation overlaps with training steps.
        dataset = dataset.prefetch(tf.data.AUTOTUNE)

        return dataset

def main(features_path="./data/KGS-2007-19-11644-train_features_100.npy",
         labels_path="./data/KGS-2007-19-11644-train_labels_100.npy",
         batch_size=128, preview_batches=10):
    """Build the training dataset and print the shapes of a few batches.

    Args:
        features_path: Path to the features ``.npy`` file.
        labels_path: Path to the labels ``.npy`` file.
        batch_size: Samples per batch.
        preview_batches: Maximum number of batches whose shapes are printed.
    """
    generator = KerasDatasetGenerator(features_path, labels_path, batch_size=batch_size)
    train_dataset = generator.create_dataset()

    # take(n) counts *batches* and stops early when the dataset is shorter,
    # so no length check (and no second load of the .npy files) is needed.
    for features_batch, labels_batch in train_dataset.take(preview_batches):
        print("Features batch shape:", features_batch.shape)
        print("Labels batch shape:", labels_batch.shape)
        print("\n---\n")

    # train_dataset can now be passed to a Keras model for training, e.g.
    # model.fit(train_dataset, epochs=..., validation_data=...)


if __name__ == "__main__":
    main()

## 从 channels_first 转换为 channels_last 

In [None]:
# 将数据从 channels_first 转换为 channels_last 格式
# preprocessed_features = tf.transpose(preprocessed_features, perm=[0, 2, 3, 1])  # 注意 perm 参数的顺序

## 使用 TensorFlow 加载数据

In [5]:
import tensorflow as tf

# Minimal tf.data example: each inner list becomes one dataset element, and
# draining the NumPy iterator shows the elements as NumPy arrays.
dataset = tf.data.Dataset.from_tensor_slices([[1, 2], [3, 4], [5, 6]])
[*dataset.as_numpy_iterator()]

2024-04-20 16:03:26.727083: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


[array([1, 2], dtype=int32),
 array([3, 4], dtype=int32),
 array([5, 6], dtype=int32)]