In [None]:
# This notebook processes the statistics data and visualises it.
import csv 
import os
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import seaborn as sns
# Render figures inline in the notebook instead of opening a separate window.
%matplotlib inline
# Render high-DPI ("retina") figures so plots are not blurry.
%config InlineBackend.figure_format = 'retina'
# Configure seaborn's plotting context and style.
# NOTE(review): sns.set(...) below also resets the context to its default
# ('notebook'), so the set_context call is redundant.
sns.set_context('notebook')
sns.set(style="ticks", color_codes=True)
# Silence all warnings.
# NOTE(review): this also hides pandas/seaborn deprecation warnings; consider
# filtering only specific warning categories.
import warnings
warnings.filterwarnings('ignore')

# Display the value of every bare expression in a cell, not only the last one.
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = 'all'

In [None]:
# Paths of the per-run PSNR statistics, one CSV per weight-initialisation scheme.
# All files live in the same directory, so build them from one base path instead
# of repeating the absolute path five times.
# NOTE(review): absolute, user-specific path; consider making it relative to the
# project root or reading it from an environment variable.
STATS_DIR = "/home/xzhang/Documents/我的模型/data/results/statistics"
file_path_list = [f"{STATS_DIR}/{scheme}_psnr.csv"
                  for scheme in ("xavier_norm", "xavier_uniform",
                                 "kaiming_norm", "kaiming_uniform", "default")]

# NOTE(review): not used by any visible cell, and it points at data/statistics
# rather than data/results/statistics like the paths above; confirm it is needed.
file_path = "/home/xzhang/Documents/我的模型/data/statistics/xavier_norm.csv"

In [None]:

# CSV writer used during training to persist the collected per-run curves.
def write_csv(my_dict, file_name):
    """Write per-run curves to a CSV file, one row per iteration.

    Args:
        my_dict: mapping from run id as a string ("0", "1", ..., "N-1") to the
            sequence of per-iteration values recorded for that run.
        file_name: path of the CSV file to create (overwritten if it exists).

    The header is ``iters, 0, 1, ..., N-1``. Each following row holds the
    iteration index and every run's value at that iteration. Runs shorter than
    the longest one are padded with empty cells, which ``pandas.read_csv`` reads
    back as NaN.

    Raises:
        KeyError: if a run id in ``"0".."N-1"`` is missing from ``my_dict``.
    """
    nb_cols = len(my_dict)
    columns = [my_dict[str(j)] for j in range(nb_cols)]
    # The row count used to come from the hard-coded key "1". That raised KeyError
    # for a single run, and raised IndexError or silently truncated when runs
    # differ in length. Use the longest run instead.
    nb_rows = max((len(col) for col in columns), default=0)
    with open(file_name, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.writer(csvfile)
        # Header row.
        writer.writerow(['iters'] + list(range(nb_cols)))
        # Data rows; a run that has no value at iteration i gets an empty cell.
        for i in range(nb_rows):
            writer.writerow([i] + [col[i] if i < len(col) else '' for col in columns])
            


In [None]:
# Load every raw PSNR CSV and stack them row-wise into one long frame.
# Each file is tagged with its initialisation scheme ('type') and distribution
# ('dist'), parsed from a name of the form "<type>_<dist>_psnr.csv". A file without
# a distribution part ("default_psnr.csv") gets dist='default' rather than the
# metric name 'psnr'.
# NOTE(review): 'test' is the 0-based row position inside each file. If the files
# were written by write_csv, that is the iteration index, not a run id. The name
# is kept because later cells refer to it.
metric_suffix = '_psnr'
df_list = []
for path in file_path_list:
    stem = os.path.splitext(os.path.basename(path))[0]
    if stem.endswith(metric_suffix):
        stem = stem[:-len(metric_suffix)]
    parts = stem.split('_')
    # Deliberately not named `type`: a global `type` would shadow the builtin
    # for every later cell.
    init_type = parts[0]
    init_dist = parts[1] if len(parts) > 1 else 'default'

    # Drop the first column (presumably the 'iters' counter) and keep one column per run.
    runs = pd.read_csv(path).iloc[:, 1:]
    runs.insert(0, 'test', np.arange(len(runs)))
    runs.insert(1, 'type', init_type)
    runs.insert(2, 'dist', init_dist)
    df_list.append(runs)
# ignore_index: every file restarts at 0, and duplicate index labels can break
# seaborn plots that reindex the data.
df_brut = pd.concat(df_list, axis=0, ignore_index=True)
df_brut.head()

In [None]:
# Summarise each PSNR file: for every run (column), find the peak PSNR and the
# first row index at which it occurs (presumably the iteration index).
# Files are tagged with 'type'/'dist' parsed from "<type>_<dist>_psnr.csv".
# "default_psnr.csv" gets dist='default' instead of the metric name 'psnr'.
metric_suffix = '_psnr'
df_list = []
for path in file_path_list:
    stem = os.path.splitext(os.path.basename(path))[0]
    if stem.endswith(metric_suffix):
        stem = stem[:-len(metric_suffix)]
    parts = stem.split('_')
    # Deliberately not named `type`, to avoid shadowing the builtin notebook-wide.
    init_type = parts[0]
    init_dist = parts[1] if len(parts) > 1 else 'default'

    runs = pd.read_csv(path).iloc[:, 1:]
    # One row per run: the run's column name, its maximum, and the row label of
    # the first maximum.
    summary = pd.concat([runs.max(), runs.idxmax()], axis=1).reset_index()
    summary.columns = ['test', 'max_value', 'max_index']
    summary.insert(1, 'type', init_type)
    summary.insert(2, 'dist', init_dist)
    df_list.append(summary)
# ignore_index: avoid duplicate index labels from the per-file frames.
df = pd.concat(df_list, axis=0, ignore_index=True)
df.tail()
df.head()
df.describe()

In [None]:
# For each categorical plot kind, draw two faceted figures (one column per
# initialisation scheme): the PSNR peak value of every run, then the
# 'max_index' at which that peak was reached.
for kind in ('box', 'boxen', 'violin'):
    for y_col, title in (('max_value', 'PSNR peak'),
                         ('max_index', 'iterations for PSNR peak')):
        ax = sns.catplot(x='dist', y=y_col, col='type', kind=kind, data=df)
        plt.suptitle(title)


In [None]:
# Paths of the per-run MSE (vs. ground truth) statistics, one CSV per
# initialisation scheme. They are listed in the same scheme order as the PSNR
# files, so the categorical plots for both metrics share the same x/column order
# (seaborn orders string categories by first appearance).
# NOTE(review): absolute, user-specific path; consider making it configurable.
STATS_DIR = "/home/xzhang/Documents/我的模型/data/results/statistics"
file_path_list = [f"{STATS_DIR}/{scheme}_mse_gt.csv"
                  for scheme in ("xavier_norm", "xavier_uniform",
                                 "kaiming_norm", "kaiming_uniform", "default")]

In [None]:
# Load every raw MSE CSV and stack them row-wise into one long frame.
# Each file is tagged with its initialisation scheme ('type') and distribution
# ('dist'), parsed from a name of the form "<type>_<dist>_mse_gt.csv". A file
# without a distribution part ("default_mse_gt.csv") gets dist='default' rather
# than the metric name 'mse'.
# NOTE(review): 'test' is the 0-based row position inside each file (presumably
# the iteration index, not a run id). The name is kept for later cells.
metric_suffix = '_mse_gt'
df_list = []
for path in file_path_list:
    stem = os.path.splitext(os.path.basename(path))[0]
    if stem.endswith(metric_suffix):
        stem = stem[:-len(metric_suffix)]
    parts = stem.split('_')
    # Deliberately not named `type`, to avoid shadowing the builtin notebook-wide.
    init_type = parts[0]
    init_dist = parts[1] if len(parts) > 1 else 'default'

    # Drop the first column (presumably the 'iters' counter) and keep one column per run.
    runs = pd.read_csv(path).iloc[:, 1:]
    runs.insert(0, 'test', np.arange(len(runs)))
    runs.insert(1, 'type', init_type)
    runs.insert(2, 'dist', init_dist)
    df_list.append(runs)
# ignore_index: every file restarts at 0, and duplicate index labels can break
# seaborn plots that reindex the data.
df_brut = pd.concat(df_list, axis=0, ignore_index=True)
df_brut.head()

In [None]:
# Summarise each MSE file: for every run (column), find the minimum MSE and the
# first row index at which it occurs (presumably the iteration index).
# Files are tagged with 'type'/'dist' parsed from "<type>_<dist>_mse_gt.csv".
# "default_mse_gt.csv" gets dist='default' instead of the metric name 'mse'.
metric_suffix = '_mse_gt'
df_list = []
for path in file_path_list:
    stem = os.path.splitext(os.path.basename(path))[0]
    if stem.endswith(metric_suffix):
        stem = stem[:-len(metric_suffix)]
    parts = stem.split('_')
    # Deliberately not named `type`, to avoid shadowing the builtin notebook-wide.
    init_type = parts[0]
    init_dist = parts[1] if len(parts) > 1 else 'default'

    runs = pd.read_csv(path).iloc[:, 1:]
    # One row per run: the run's column name, its minimum, and the row label of
    # the first minimum.
    summary = pd.concat([runs.min(), runs.idxmin()], axis=1).reset_index()
    summary.columns = ['test', 'min_value', 'min_index']
    summary.insert(1, 'type', init_type)
    summary.insert(2, 'dist', init_dist)
    df_list.append(summary)
# ignore_index: avoid duplicate index labels from the per-file frames.
df = pd.concat(df_list, axis=0, ignore_index=True)
df.tail()
df.head()
df.describe()

In [None]:
# For each categorical plot kind, draw two faceted figures (one column per
# initialisation scheme): the minimum MSE of every run, then the 'min_index'
# at which that minimum was reached.
for kind in ('box', 'boxen', 'violin'):
    for y_col, title in (('min_value', 'min mse'),
                         ('min_index', 'iterations for min mse')):
        ax = sns.catplot(x='dist', y=y_col, col='type', kind=kind, data=df)
        plt.suptitle(title)

In [1]:

# Box plot of the per-run summary values, grouped by initialisation scheme (x)
# and distribution (hue), with each box annotated by its number of runs.
# NOTE: in a top-to-bottom run, `df` holds the min-MSE summary built just above
# (columns min_value/min_index). The PSNR-peak summary (max_value) is only present
# if that cell was run last, so use whichever value column exists.
value_col = 'max_value' if 'max_value' in df.columns else 'min_value'

# Pin the category orders explicitly so the annotation positions below match the
# boxes seaborn draws (first-appearance order, seaborn's default for strings).
type_order = list(pd.unique(df['type']))
dist_order = list(pd.unique(df['dist']))
ax = sns.boxplot(x='type', y=value_col, hue='dist', data=df,
                 order=type_order, hue_order=dist_order)

# Compute the median and run count per (type, dist) box together, so they stay
# aligned. The previous version paired groupby-sorted medians with count-sorted
# value_counts(), and indexed them by x-tick rather than by box.
stats = df.groupby(['type', 'dist'])[value_col].agg(['median', 'count'])

# seaborn's default dodging: a total width of 0.8 per category, split evenly
# across all hue levels and centred on the category's integer x position.
total_width = 0.8
box_width = total_width / len(dist_order)
y_min, y_max = ax.get_ylim()
y_pad = 0.01 * (y_max - y_min)  # scale-independent offset above the median line
for (type_name, dist_name), median, count in stats.itertuples(name=None):
    x = (type_order.index(type_name)
         + (dist_order.index(dist_name) - (len(dist_order) - 1) / 2) * box_width)
    ax.text(x, median + y_pad, f"n: {int(count)}", horizontalalignment='center',
            size='x-small', color='w', weight='semibold')

NameError: name 'sns' is not defined

In [None]:
# Inspect the columns of the stacked raw frame (rich display instead of print).
df_brut.columns

In [None]:
# Per-iteration curves of every run, one facet column per initialisation scheme.
# relplot expects long-form data with a single y column. Passing a list of 100
# hard-coded column names as y is treated as one data vector, not as 100 columns.
# So melt the per-run columns into (run, value) pairs first; this also removes
# the assumption of exactly 100 runs.
# NOTE(review): 'test' here is the row position within each file (presumably the
# iteration index). With many runs, kind='line' bootstraps a confidence band over
# all runs per x, which may be slow.
curves_long = df_brut.melt(id_vars=['test', 'type', 'dist'],
                           var_name='run', value_name='value')
for kind in ('scatter', 'line'):
    grid = sns.relplot(x='test', y='value', hue='dist', col='type',
                       kind=kind, data=curves_long)


