# Mavis 数据集分析

本 notebook 用于分析 Mavis 数据集。Mavis 是一个视觉问答数据集，包含图像以及相关的问题和答案。为节省时间，下文只加载并分析训练集目录中的第一个 Parquet 文件。

In [None]:
import pandas as pd
import os
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Root directory of the training split (local path; adjust for your machine).
DATASET_PATH = "/Users/jia/datasets/data/mavis/train"

# The plot titles and axis labels in this notebook are Chinese. Matplotlib's
# default sans-serif font has no CJK glyphs, so that text would render as
# empty boxes. Prefer an installed CJK-capable font, keeping the defaults as
# fallbacks behind it.
from matplotlib import font_manager

_CJK_FONT_CANDIDATES = (
    "PingFang SC",
    "Heiti SC",
    "STHeiti",
    "Arial Unicode MS",
    "Hiragino Sans GB",
    "SimHei",
    "Microsoft YaHei",
    "Noto Sans CJK SC",
    "WenQuanYi Micro Hei",
)
_installed_fonts = {entry.name for entry in font_manager.fontManager.ttflist}
_available_cjk_fonts = [
    name for name in _CJK_FONT_CANDIDATES if name in _installed_fonts
]
if _available_cjk_fonts:
    plt.rcParams["font.sans-serif"] = _available_cjk_fonts + list(
        plt.rcParams["font.sans-serif"]
    )
    # Use an ASCII hyphen for negative tick labels; some CJK fonts may lack
    # the Unicode minus glyph.
    plt.rcParams["axes.unicode_minus"] = False
else:
    print("警告: 未找到中文字体, 图表中的中文可能无法正常显示")

print("正在检查数据集...")
if not os.path.exists(DATASET_PATH):
    print(f"错误: 数据集路径 {DATASET_PATH} 不存在")
else:
    print(f"数据集路径: {DATASET_PATH}")

In [None]:
# List the dataset directory and select the Parquet files.
# The listing is sorted because os.listdir returns entries in arbitrary
# order, and the next cell loads parquet_files[0]; sorting makes that choice
# reproducible.
if os.path.isdir(DATASET_PATH):
    files = sorted(os.listdir(DATASET_PATH))
else:
    # The setup cell only prints a warning for a missing path. Fall back to
    # empty lists so the following cells skip their work instead of raising.
    files = []
parquet_files = [f for f in files if f.endswith(".parquet")]

print(f"\n数据文件总数: {len(files)}")
print(f"Parquet文件数量: {len(parquet_files)}")
print("\n前10个文件:")
for file_name in parquet_files[:10]:
    print(f"  - {file_name}")

In [None]:
# Load the first entry of `parquet_files` into `df`. Only this one file is
# analysed by the cells below.
if parquet_files:
    first_file = os.path.join(DATASET_PATH, parquet_files[0])
    print(f"\n正在加载文件: {first_file}")
    try:
        df = pd.read_parquet(first_file)
    except Exception as load_error:
        # Report and carry on; the later cells check for `df` before using it.
        print(f"加载文件时出错: {load_error}")
    else:
        print(f"数据集形状: {df.shape}")
        print(f"列名: {list(df.columns)}")

In [None]:
# Show column dtypes, non-null counts and memory usage of the loaded frame.
# DataFrame.info() writes its report to stdout itself and returns None, so it
# is called directly; wrapping it in print() would emit a stray "None".
if 'df' in locals():
    print("\n数据集基本信息:")
    df.info()

In [None]:
# Show the first 5 rows of the loaded frame.
# A bare `df.head()` nested inside an `if` is not the cell's last top-level
# expression, so Jupyter would not render it. Print it explicitly instead.
if 'df' in locals():
    print("\n数据集前5行:")
    print(df.head())

In [None]:
# Count missing values (None / NaN) in every column of the loaded frame.
if 'df' in locals():
    null_counts = df.isna().sum()
    print("\n缺失值统计:")
    print(null_counts)

In [None]:
# Question length in characters: summary statistics plus a histogram.
if 'df' in locals() and 'question' in df.columns:
    # `.str.len()` yields NaN for missing or non-string entries.
    df['question_length'] = df['question'].str.len()

    print("\n问题长度统计:")
    print(df['question_length'].describe())

    # Plot only rows that have a length. If every entry were missing, the
    # histogram would have no finite data range to bin.
    question_lengths = df['question_length'].dropna()
    if question_lengths.empty:
        print("没有可绘制的问题长度数据")
    else:
        plt.figure(figsize=(10, 6))
        plt.hist(question_lengths, bins=50, alpha=0.7)
        plt.title('问题长度分布')
        plt.xlabel('问题长度 (字符数)')
        plt.ylabel('频次')
        plt.show()

In [None]:
# Answer length in characters: summary statistics plus a histogram.
if 'df' in locals() and 'answer' in df.columns:
    # `.str.len()` yields NaN for missing or non-string entries.
    df['answer_length'] = df['answer'].str.len()

    print("\n答案长度统计:")
    print(df['answer_length'].describe())

    # Plot only rows that have a length. If every entry were missing, the
    # histogram would have no finite data range to bin.
    answer_lengths = df['answer_length'].dropna()
    if answer_lengths.empty:
        print("没有可绘制的答案长度数据")
    else:
        plt.figure(figsize=(10, 6))
        plt.hist(answer_lengths, bins=50, alpha=0.7)
        plt.title('答案长度分布')
        plt.xlabel('答案长度 (字符数)')
        plt.ylabel('频次')
        plt.show()

In [None]:
# Print a few randomly sampled rows (fixed seed so reruns show the same rows).


def truncate_text(value, limit=300):
    """Return *value* as text, truncated to at most *limit* characters.

    A trailing "..." is appended when the text was cut. Non-string values
    (e.g. None or NaN from missing cells) are converted with ``str`` rather
    than failing on slicing or ``len``.
    """
    text = value if isinstance(value, str) else str(value)
    return text[:limit] + ("..." if len(text) > limit else "")


def has_image_payload(image):
    """Return False for missing or empty image cells, True otherwise.

    Treated as missing: None, pd.NA, float NaN, and empty lists/arrays.
    Anything else (bytes, dicts, non-empty arrays, ...) counts as present.
    """
    if image is None or image is pd.NA:
        return False
    if isinstance(image, float) and np.isnan(image):
        return False
    if isinstance(image, (list, np.ndarray)) and len(image) == 0:
        return False
    return True


if 'df' in locals():
    print("\n样本数据:")
    sample_size = min(5, len(df))
    sample_df = df.sample(n=sample_size, random_state=42)

    for sample_no, (_row_label, row) in enumerate(sample_df.iterrows(), start=1):
        print(f"\n样本 {sample_no}:")
        # `in` on a Series tests its index, i.e. the column names.
        if 'split' in row:
            print(f"  数据集分割: {row['split']}")
        if 'question' in row:
            print(f"  问题: {truncate_text(row['question'])}")
        if 'answer' in row:
            print(f"  答案: {truncate_text(row['answer'])}")
        if 'image' in row:
            status = "存在" if has_image_payload(row['image']) else "不存在"
            print(f"  图像信息: {status}")

In [None]:
# Final summary: number of Parquet files found and what was loaded.
summary_lines = ["\n数据集总结:", "=" * 50]

if 'parquet_files' in locals():
    summary_lines.append(f"总文件数: {len(parquet_files)}")

if 'df' in locals():
    summary_lines += [
        f"当前加载文件的样本数: {len(df)}",
        f"列名: {list(df.columns)}",
    ]

summary_lines += ["=" * 50, "分析完成"]
print("\n".join(summary_lines))