In [None]:
# 第一个cell - 添加项目路径
import sys
import os
sys.path.append('..')  # 添加项目根目录到Python路径

# 导入库
import torch
import matplotlib.pyplot as plt
import numpy as np
from src.data.loader import load_mnist_data
from src.data.preprocessing import visualize_samples, print_data_summary

In [None]:
# Cell 2 — load the MNIST data and print a summary of the train/test loaders.
# Batch size passed to load_mnist_data; hoisted into a named constant so it is
# easy to find and tune instead of being buried inside the call.
BATCH_SIZE = 64
train_loader, test_loader = load_mnist_data(batch_size=BATCH_SIZE)
# Print a summary of both loaders (via the project's print_data_summary helper).
print_data_summary(train_loader, test_loader)

In [None]:
# Cell 3 — visual sanity check: hand a few samples from the training loader
# to the project's visualize_samples helper.
num_preview_samples = 8  # number of samples requested from visualize_samples
visualize_samples(train_loader, num_samples=num_preview_samples)

In [None]:
# 第四个cell - 数据分布分析
import seaborn as sns

# Count how many samples of each class a data loader yields.
def get_class_distribution(data_loader, num_classes=10):
    """Return per-class sample counts for an iterable of ``(inputs, labels)`` batches.

    Args:
        data_loader: Any iterable yielding ``(inputs, labels)`` pairs, e.g. a
            ``torch.utils.data.DataLoader``. Only ``labels`` is read; it may be a
            tensor or a sequence of ints.
        num_classes: Number of classes. Every label must lie in
            ``[0, num_classes)``. Defaults to 10.

    Returns:
        A 1-D tensor of length ``num_classes`` in torch's default float dtype
        (the same dtype as the original ``torch.zeros`` accumulator). Entry ``k``
        holds the number of samples labelled ``k``. An empty loader yields all zeros.

    Raises:
        ValueError: If any label is outside ``[0, num_classes)``.
    """
    # Accumulate in int64 so counts are exact; convert once on return.
    class_counts = torch.zeros(num_classes, dtype=torch.long)
    for _, labels in data_loader:
        batch_labels = torch.as_tensor(labels).reshape(-1).long().cpu()
        if batch_labels.numel() == 0:
            continue
        lowest = int(batch_labels.min())
        highest = int(batch_labels.max())
        # Validate explicitly. Plain tensor indexing would silently wrap negative
        # labels around, and bincount would grow past num_classes for large ones.
        if lowest < 0 or highest >= num_classes:
            raise ValueError(
                f"label out of range [0, {num_classes}): min={lowest}, max={highest}"
            )
        # Count the whole batch in one vectorised call instead of looping per label.
        class_counts += torch.bincount(batch_labels, minlength=num_classes)
    return class_counts.to(torch.get_default_dtype())

train_class_counts = get_class_distribution(train_loader)
test_class_counts = get_class_distribution(test_loader)

# Plot the class distribution of the training and test sets side by side.
# NOTE(review): matplotlib's default font likely lacks CJK glyphs, so the Chinese
# titles/labels below may render as empty boxes unless a CJK-capable font is
# configured (e.g. via plt.rcParams["font.sans-serif"]) — confirm on target setup.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

# Both panels get identical styling. Drive them from one loop rather than two
# copy-pasted stanzas so the panels cannot drift apart.
for ax, counts, title in (
    (ax1, train_class_counts, '训练集类别分布'),  # training set
    (ax2, test_class_counts, '测试集类别分布'),   # test set
):
    # One bar and one tick per counted class (derived from the counts, not hard-coded).
    class_ids = range(len(counts))
    ax.bar(class_ids, counts.numpy())
    ax.set_title(title)
    ax.set_xlabel('数字')
    ax.set_ylabel('样本数量')
    ax.set_xticks(class_ids)

plt.tight_layout()
plt.show()