# 数据集探索

本笔记本用于探索数据集的结构和特征。

In [None]:
import torch
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import pprint
from tqdm.notebook import tqdm
from omegaconf import OmegaConf
from data.dataset import Dataset
from data.datamodule import DataModule
from hydra import compose, initialize

from hydra.core.config_store import ConfigStore

def load_config(config):
    """Register ``config`` with Hydra's ConfigStore and compose it.

    The node is stored under the name ``"cfg"`` and then composed inside a
    Hydra ``initialize`` context whose ``config_path`` is ``"configs"``
    (``version_base="1.3"``).

    Args:
        config: The node to register; passed straight through to
            ``ConfigStore.store(node=...)``.

    Returns:
        Whatever ``compose(config_name="cfg")`` produces.
    """
    ConfigStore.instance().store(name="cfg", node=config)
    # Returning from inside the `with` still runs the context's exit handler.
    with initialize(config_path="configs", version_base="1.3"):
        return compose(config_name="cfg")

## 1. 数据集基本信息

In [None]:
from configs.experiments.base import BaseTrainConfig

# Compose the base training config through Hydra.
cfg = load_config(BaseTrainConfig)

# Dump the composed config as YAML for a quick visual check.
cfg_yaml = OmegaConf.to_yaml(cfg)
print(cfg_yaml)

In [None]:
import seaborn as sns
import numpy as np

# NOTE(review): `info` is not defined in any cell visible here — presumably a
# pandas DataFrame with a 'views' column loaded elsewhere; confirm before running.
# Looking the column up first also fails fast if 'views' is missing.
views = info['views']
# log1p-transformed views, shared by the last two plots and the summary below.
log_views = np.log1p(views)

# One figure holding a 2x2 grid of panels.
plt.figure(figsize=(15, 10))

# Panel 1: raw distribution.
plt.subplot(2, 2, 1)
sns.histplot(data=info, x='views', bins=50)
plt.title('Views Distribution')
plt.xlabel('Views')
plt.ylabel('Frequency')

# Panel 2: same histogram on a logarithmic axis.
plt.subplot(2, 2, 2)
sns.histplot(data=info, x='views', bins=50, log_scale=True)
plt.title('Views Distribution (Logarithmic Scale)')
plt.xlabel('Views (Logarithmic)')
plt.ylabel('Frequency')

# Panel 3: histogram of the log1p-transformed values.
plt.subplot(2, 2, 3)
sns.histplot(data=log_views, bins=50)
plt.title('log1p(Views) Distribution')
plt.xlabel('log1p(Views)')
plt.ylabel('Frequency')

# Panel 4: box plot of the log1p-transformed values.
plt.subplot(2, 2, 4)
sns.boxplot(y=log_views)
plt.title('log1p(Views) Boxplot')
plt.ylabel('log1p(Views)')

plt.tight_layout()
plt.show()

# Summary statistics on the raw values.
print("Views Statistics:")
print(f"Min: {views.min():,.0f}")
print(f"Max: {views.max():,.0f}")
print(f"Average: {views.mean():,.0f}")
print(f"Median: {views.median():,.0f}")
print(f"Standard Deviation: {views.std():,.0f}")

# Summary statistics on the log1p-transformed values.
print(f"\nView Statistics after logarithmic transform:")
print(f"Min: {log_views.min():.2f}")
print(f"Max: {log_views.max():.2f}")
print(f"Average: {log_views.mean():.2f}")
print(f"Median: {log_views.median():.2f}")
print(f"Standard Deviation: {log_views.std():.2f}")

## 2. 目标值分布分析

In [None]:
import statistics

# Collect every target value as a plain Python number.
# Index-based access is kept on purpose: `dataset` is only known to support
# `len()` and `[]` here, not safe direct iteration.
# NOTE(review): `dataset` is not defined in any cell visible here — confirm it
# is created before this cell runs.
targets = [dataset[i]['target'].item() for i in tqdm(range(len(dataset)))]

# Plot the target distribution: linear scale on the left, log scale on the right.
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
sns.histplot(targets, bins=50)
plt.title('目标值分布')
plt.xlabel('目标值')
plt.ylabel('频数')

plt.subplot(1, 2, 2)
sns.histplot(targets, bins=50, log_scale=True)
plt.title('目标值分布 (对数尺度)')
plt.xlabel('目标值 (对数)')
plt.ylabel('频数')

plt.tight_layout()
plt.show()

# Basic summary statistics.
print(f"目标值统计信息:")
print(f"最小值: {min(targets):.2f}")
print(f"最大值: {max(targets):.2f}")
print(f"平均值: {sum(targets)/len(targets):.2f}")
# statistics.median averages the two middle values for even-length data;
# the previous `sorted(targets)[len(targets)//2]` returned the upper one only.
print(f"中位数: {statistics.median(targets):.2f}")

## 3. 图像特征分析

In [None]:
# Grab the 'image' entry of the first few samples (at most five).
sample_count = min(5, len(dataset))
sample_images = [dataset[idx]['image'] for idx in range(sample_count)]

# Print shape and basic value statistics for each sampled image (1-based numbering).
for number, image in enumerate(sample_images, start=1):
    print(f"\n图像 {number} 统计信息:")
    print(f"形状: {image.shape}")
    print(f"最小值: {image.min().item():.3f}")
    print(f"最大值: {image.max().item():.3f}")
    print(f"平均值: {image.mean().item():.3f}")
    print(f"标准差: {image.std().item():.3f}")

## 4. 元数据分析

In [None]:
# Collect metadata: one list of values per field of interest.
metadata = {
    key: []
    for key in ('title', 'description', 'channel_title', 'category_id')
}

for i in tqdm(range(len(dataset))):
    # Fetch the sample once; the previous version re-indexed the dataset for
    # every membership test and every append (up to 8 lookups per sample).
    sample = dataset[i]
    for key, values in metadata.items():
        # Fields absent from a sample are simply skipped.
        if key in sample:
            values.append(sample[key])

# Report the unique-value count and a few examples for each collected field.
# The earlier str / non-str branches printed exactly the same thing, so they
# are merged into one.
for key, values in metadata.items():
    if values:
        print(f"\n{key} 统计信息:")
        print(f"唯一值数量: {len(set(values))}")
        print(f"示例值: {values[:3]}")

## 5. 数据加载器测试

In [None]:
# Smoke-test the data loaders.
datamodule = DataModule(
    dataset_path="../dataset",
    train_transform=None,  # use the default transform (per the original note)
    test_transform=None,   # use the default transform (per the original note)
    batch_size=32,
    num_workers=4,
    seed=42,
    val_split=0.1
)

# Build the train / validation loaders.
train_loader = datamodule.train_dataloader()
val_loader = datamodule.val_dataloader()

print(f"训练集批次数量: {len(train_loader)}")
# NOTE(review): this is a truthiness check, so a loader object that reports
# zero length is skipped just like None — confirm that is intended.
if val_loader:
    print(f"验证集批次数量: {len(val_loader)}")

# Pull a single batch and describe each entry.
batch = next(iter(train_loader))
print("\n批次数据示例:")
for key, value in batch.items():
    if isinstance(value, torch.Tensor):
        # Tensors: show only the shape.
        print(f"{key}: Tensor shape {value.shape}")
    elif isinstance(value, list):
        # Lists: show the first two items.
        print(f"{key}: {value[:2]}")
    else:
        print(f"{key}: {value}")