# 第8章 绘图和可视化

## matplotlib API入门

In [None]:
# IPython magic: render matplotlib figures inline in the notebook output
%matplotlib inline

In [None]:
import numpy as np
import pandas as pd
from pandas import Series, DataFrame
import matplotlib.pyplot as plt #约定plt是matplotlib.pyplot的简写

In [None]:
# A first line plot: the integers 0..9 against their position
data = np.arange(10)
plt.plot(data)

## Figure和Subplot
matplotlib的图像都位于Figure对象中。subplot是Figure中的子图。

In [None]:
fig = plt.figure()  # create the Figure
# Add subplots at positions 1, 2 and 3 of a 2x2 grid (the 4th slot stays empty)
ax1, ax2, ax3 = (fig.add_subplot(2, 2, pos) for pos in (1, 2, 3))

In [None]:
#为不同的subplot绘图
ax1.plot(np.sin(np.linspace(0,2*3.14,20)))
ax2.plot(np.cos(np.linspace(0,2*3.14,20)))
ax3.plot(np.tan(np.linspace(0,2*3.14,20)))
fig #在ipython notebook以外请用fig.show()来显示绘图

In [None]:
fig = plt.figure()  # fresh Figure for the next example
# Subplots 1..3 of a 2x2 grid, created left-to-right, top-to-bottom
ax1, ax2, ax3 = [fig.add_subplot(2, 2, k) for k in range(1, 4)]

In [None]:
from numpy.random import randn

# 'k--' is a format string: black ('k') dashed ('--') line
ax3.plot(randn(50).cumsum(), 'k--')
xs = np.arange(30)
ax2.scatter(xs, xs + 3 * randn(30))
# Assign to _ so the notebook does not echo the (counts, bins, patches) tuple
_ = ax1.hist(randn(100), bins=20, color='k', alpha=0.3)
fig

In [None]:
# One call creates the figure plus a 2x3 array of Axes objects
fig, axes = plt.subplots(nrows=2, ncols=3)
axes

![subplots的选项](subplots的选项.png)

### 调整subplot周围的间距

函数`subplots_adjust(left=None, bottom=None, right=None, top=None, wspace=None, hspace=None)`用于调整subplot周围的间距。

In [None]:
fig, axes = plt.subplots(2, 2, sharex=True, sharey=True)
# axes.flat walks the grid row by row, the same order as axes[i, j] with i outer
for ax in axes.flat:
    ax.hist(np.random.randn(500), bins=50, color='k', alpha=0.5)
# Remove both the horizontal and the vertical gap between subplots
plt.subplots_adjust(wspace=0, hspace=0)

## 颜色、标记和线型

In [None]:
x = np.arange(10)
y = 2 * x + 1
# Top: format-string shorthand (green, dashed)
plt.subplot(2, 1, 1)
plt.plot(x, y, 'g--')
# Bottom: the equivalent style spelled out with keyword arguments (red, dashed)
plt.subplot(2, 1, 2)
plt.plot(x, y, linestyle='--', color='r')

In [None]:
plt.plot(randn(30).cumsum(), color='k', marker='o', linestyle='--')  # same as format 'ko--'

In [None]:
walk = randn(30).cumsum()  # cumulative sum of 30 standard-normal draws
plt.plot(walk, color='#FF00FF', linestyle='dashed', marker='o')

In [None]:
data = np.random.randn(30).cumsum()  # random walk, reused by the next cell
plt.plot(data, color='k', linestyle='--', label='Default')

In [None]:
# Redraw the 'Default' line here as well: with the inline backend every cell
# starts a new figure, so otherwise the legend would show only 'steps-post'
# and the two draw styles could not be compared.
plt.plot(data, 'k--', label='Default')
plt.plot(data, 'k-', drawstyle='steps-post', label='steps-post')
plt.legend(loc='best')

## 刻度、标签和图例

### 设置标题、轴标签、刻度以及刻度标签

In [None]:
fig = plt.figure()
ax = fig.add_subplot(111)  # a single subplot, same as (1, 1, 1)
walk = np.random.randn(1000).cumsum()
ax.plot(walk)

In [None]:
# Major x ticks every 250 steps: 0, 250, 500, 750, 1000
ticks = ax.set_xticks(list(range(0, 1001, 250)))
fig

In [None]:
tick_names = ['one', 'two', 'three', 'four', 'five']
# One label per tick set above, rotated 30 degrees, small font
labels = ax.set_xticklabels(tick_names, rotation=30, fontsize='small')
fig

In [None]:
ax.set(title='My first matplotlib plot')  # Axes.set forwards to set_title
fig

In [None]:
ax.set(xlabel='Stages')  # Axes.set forwards to set_xlabel
fig

### 添加图例

In [None]:
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
# Each label= becomes one entry in the legend
for fmt, name in [('r', 'one'), ('g--', 'two'), ('b.', 'three')]:
    ax.plot(randn(1000).cumsum(), fmt, label=name)
ax.legend(loc='best')  # draw the legend

In [None]:
# Use a CJK-capable font so the Chinese legend labels render, and draw the
# minus sign as ASCII '-' (presumably because that font lacks the Unicode minus glyph)
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
# Each label= becomes one entry in the legend
for fmt, name in (('r', '曲线1'), ('g--', '曲线2'), ('b.', '曲线3')):
    ax.plot(randn(1000).cumsum(), fmt, label=name)
ax.legend(loc='best')  # draw the legend

## 注解以及在Subplot上绘图

In [None]:
from datetime import datetime

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)

# 'SPX' column of spx.csv, indexed by the parsed dates in the first column
data = pd.read_csv('spx.csv', index_col=0, parse_dates=True)
spx = data['SPX']
spx.plot(ax=ax, style='k-')

crisis_data = [
    (datetime(2007, 10, 11), 'Peak of bull market'),
    (datetime(2008, 3, 12), 'Bear Stearns Fails'),
    (datetime(2008, 9, 15), 'Lehman Bankruptcy'),
]

# For each event, place the text at value+225 with an arrow pointing down to value+75;
# asof() takes the last available value on or before the date
for when, text in crisis_data:
    level = spx.asof(when)
    ax.annotate(
        text,
        xy=(when, level + 75),
        xytext=(when, level + 225),
        arrowprops=dict(facecolor='black', headwidth=4, width=2, headlength=4),
        horizontalalignment='left',
        verticalalignment='top',
    )

# Zoom in on 2007-2010
ax.set_xlim(['1/1/2007', '1/1/2011'])
ax.set_ylim([600, 1800])

ax.set_title('Important dates in the 2008-2009 financial crisis')

## 绘制几何图形

In [None]:
fig = plt.figure(figsize=(12, 6))
ax = fig.add_subplot(1, 1, 1)
# Shapes are built as standalone patch objects, then attached with add_patch
rect = plt.Rectangle((0.2, 0.75), 0.4, 0.15, color='k', alpha=0.3)  # anchor (x, y), width, height
circ = plt.Circle((0.7, 0.2), 0.15, color='b', alpha=0.3)  # centre, radius
pgon = plt.Polygon([[0.15, 0.15], [0.35, 0.4], [0.2, 0.6]],  # triangle vertices
                   color='g', alpha=0.5)
for shape in (rect, circ):
    ax.add_patch(shape)
ax.add_patch(pgon)

## 将图表保存到文件

In [None]:
# Save the figure object explicitly: the inline backend closes figures at the end
# of each cell, so plt.savefig() here would write a new, empty figure.
fig.savefig('figpath.svg')

In [None]:
# Save via the figure object (the inline backend has already closed the current
# pyplot figure); 400 dpi, bounding box trimmed to the drawn content
fig.savefig('figpath.png', dpi=400, bbox_inches='tight')

In [None]:
from io import BytesIO

# Render into an in-memory buffer instead of a file; the format comes from
# rcParams['savefig.format'] ('png' by default). Use the figure object because
# the inline backend has already closed the current pyplot figure.
buffer = BytesIO()
fig.savefig(buffer)
plot_data = buffer.getvalue()

![savefig选项](savefig选项.png)

## pandas中的绘图函数

### 线型图

In [None]:
# Close every open figure before starting the pandas plotting examples
plt.close('all')

In [None]:
walk = np.random.randn(10).cumsum()
s = pd.Series(walk, index=np.arange(0, 100, 10))
s.plot.line()  # Series.plot() draws a line plot by default

In [None]:
# 4 independent random walks of 10 steps (cumulative sum down each column)
df = pd.DataFrame(
    np.random.randn(10, 4).cumsum(axis=0),
    columns=['A', 'B', 'C', 'D'],
    index=np.arange(0, 100, 10),
)
df.plot()  # one line per column, index on the x axis

### 柱状图

In [None]:
fig, axes = plt.subplots(2, 1)
data = pd.Series(np.random.rand(16), index=list('abcdefghijklmnop'))
# Vertical bars on the top axes, horizontal bars on the bottom axes
data.plot.bar(ax=axes[0], color='k', alpha=0.7)
data.plot.barh(ax=axes[1], color='k', alpha=0.7)

In [None]:
genus = pd.Index(['A', 'B', 'C', 'D'], name='Genus')  # named column index
rows = ['one', 'two', 'three', 'four', 'five', 'six']
df = pd.DataFrame(np.random.rand(6, 4), index=rows, columns=genus)
df

In [None]:
df.plot.bar()  # grouped bars: one group per row, one bar per column

In [None]:
tips = pd.read_csv('tips.csv')
# Frequency table: one row per day, one column per party size
party_counts = pd.crosstab(index=tips['day'], columns=tips['size'])
party_counts

In [None]:
# Keep party-size columns 2 through 5. The column labels are the integer sizes,
# so use label-based .loc (slice inclusive of 5); .ix has been removed from pandas.
party_counts = party_counts.loc[:, 2:5]

In [None]:
# Normalize each row so the party-size shares of a day sum to 1
row_totals = party_counts.sum(axis=1).astype(float)
party_pcts = party_counts.div(row_totals, axis=0)
party_pcts

In [None]:
party_pcts.plot.bar(stacked=True)  # one stacked bar per day

### 直方图和密度图

In [None]:
# Tip as a fraction of the total bill
tips['tip_pct'] = tips['tip'].div(tips['total_bill'])
tips['tip_pct'].hist(bins=50)

In [None]:
tips['tip_pct'].plot.kde()  # kernel density estimate of the tip fraction

In [None]:
comp1 = np.random.normal(loc=0, scale=1, size=200)  # 200 draws from N(0, 1)

In [None]:
comp2 = np.random.normal(loc=10, scale=2, size=200)  # 200 draws from N(10, 2**2), variance 4

In [None]:
values = pd.Series(np.concatenate((comp1, comp2)))  # the two normal samples combined

In [None]:
# density=True scales the histogram to unit area so it is comparable with the KDE
# curve; the old `normed` keyword was deprecated and removed from matplotlib's hist.
values.hist(bins=100, alpha=0.3, color='k', density=True)

In [None]:
values.plot.kde(style='k--')  # dashed black density curve

### 散布图

In [None]:
macro = pd.read_csv('macrodata.csv')
data = macro[['cpi', 'm1', 'tbilrate', 'unemp']]
# Row-to-row change of the log values; rows containing NaN (at least the first,
# produced by diff) are dropped
trans_data = np.log(data).diff().dropna()
trans_data.tail(5)

In [None]:
x_col, y_col = 'm1', 'unemp'
plt.scatter(trans_data[x_col], trans_data[y_col])
plt.title('Changes in log %s vs. log %s' % (x_col, y_col))

## 绘制地图：图形化显示海地地震危机数据

In [None]:
# Haiti earthquake crisis reports
data = pd.read_csv(filepath_or_buffer='Haiti.csv')
data

In [None]:
data[['INCIDENT DATE', 'LATITUDE', 'LONGITUDE']].head(10)  # first 10 rows of date and position

In [None]:
data['CATEGORY'].head(6)  # raw comma-separated category strings

In [None]:
# Summary statistics of the numeric columns; the LATITUDE/LONGITUDE ranges here
# motivate the bounding-box filter applied in the next cell
data.describe()

In [None]:
# Keep reports inside the Haiti bounding box that carry at least one category
in_lat = (data.LATITUDE > 18) & (data.LATITUDE < 20)
in_lon = (data.LONGITUDE > -75) & (data.LONGITUDE < -70)
data = data[in_lat & in_lon & data.CATEGORY.notnull()]

In [None]:
def to_cat_list(catstr):
    """Split a comma-separated category string into trimmed, non-empty entries.

    Example: ``'1. Urgences | Emergency, 3. Public Health, '`` ->
    ``['1. Urgences | Emergency', '3. Public Health']``.
    """
    entries = []
    for piece in catstr.split(','):
        piece = piece.strip()
        if piece:  # drop blanks left by trailing or doubled commas
            entries.append(piece)
    return entries

In [None]:
def get_all_categories(cat_series):
    """Return the sorted list of distinct categories found in *cat_series*.

    Each element of *cat_series* is a comma-separated category string, split
    with ``to_cat_list``. An empty input yields an empty list.
    """
    cat_sets = (set(to_cat_list(x)) for x in cat_series)
    # set().union(*...) rather than set.union(*...): the unbound form raises
    # TypeError when there are no sets to unpack (empty input).
    return sorted(set().union(*cat_sets))

In [None]:
def get_english(cat):
    """Split a category string into ``(code, english_name)``.

    ``cat`` looks like ``'2. Urgences logistiques | Vital Lines'``. The code is
    the text before the first '.'. When the name is bilingual
    ('French | English'), the part after the first '|' is kept. Surrounding
    whitespace is stripped from the name.
    """
    # maxsplit=1 so a '.' inside the name does not break the two-way unpacking
    code, names = cat.split('.', 1)
    if '|' in names:
        # split on the bare '|' so 'French|English' (no spaces) also works;
        # strip() below removes the spaces of the usual ' | ' separator
        names = names.split('|')[1]
    return code, names.strip()

In [None]:
# Example: a French | English category string -> ('2', 'Vital Lines')
get_english('2. Urgences logistiques | Vital Lines')

In [None]:
all_cats = get_all_categories(data['CATEGORY'])  # sorted, de-duplicated category strings

In [None]:
# category code -> English category name
english_mapping = {code: name for code, name in map(get_english, all_cats)}
english_mapping['2a']

In [None]:
# Look up the English name of another category code
english_mapping['6c']

In [None]:
def get_code(seq):
    """Return the code (text before the first '.') of every non-empty entry in *seq*.

    Example: ``['2a. Food Shortage', '', '1. Emergency']`` -> ``['2a', '1']``.
    """
    return [entry.partition('.')[0] for entry in seq if entry]

In [None]:
all_codes = get_code(all_cats)
code_index = pd.Index(np.unique(all_codes))  # sorted unique category codes
# Indicator matrix: one row per report, one 0/1 column per category code
shape = (len(data), len(code_index))
dummy_frame = DataFrame(np.zeros(shape), index=data.index, columns=code_index)

In [None]:
# First six indicator columns. The labels are code strings, so slice by position
# with .iloc (.ix has been removed from pandas).
dummy_frame.iloc[:, :6]

In [None]:
# Mark each report's categories: set the indicator column of every listed code to 1
for row, cat in zip(data.index, data.CATEGORY):
    codes = get_code(to_cat_list(cat))
    dummy_frame.loc[row, codes] = 1  # label-based; .ix has been removed from pandas

data = data.join(dummy_frame.add_prefix('category_'))

In [None]:
# Columns 10-14 by position (.iloc; .ix has been removed from pandas)
data.iloc[:, 10:15]

以下部分需要安装basemap toolkit（https://matplotlib.org/basemap/）。

In [None]:
from mpl_toolkits.basemap import Basemap
import matplotlib.pyplot as plt
def basic_haiti_map(ax=None, lllat=17.25, urlat=20.25, lllon=-75, urlon=-71):
    """Create a Basemap of the Haiti region and draw its boundaries.

    Parameters
    ----------
    ax : matplotlib Axes, optional
        Axes to draw into (passed straight through to ``Basemap``).
    lllat, urlat : float
        Latitudes of the lower-left and upper-right map corners.
    lllon, urlon : float
        Longitudes of the lower-left and upper-right map corners.

    Returns
    -------
    Basemap
        The map instance; call it as ``m(lon, lat)`` to get projected coordinates.
    """
    # Stereographic projection centred on the middle of the bounding box;
    # resolution='f' requests full-resolution boundary data.
    m = Basemap(ax=ax, projection='stere',
                lon_0=(urlon + lllon) / 2, lat_0=(urlat + lllat) / 2,
                llcrnrlat=lllat, urcrnrlat=urlat,
                llcrnrlon=lllon, urcrnrlon=urlon,
                resolution='f')
    # Draw coastlines, state and country boundaries.
    m.drawcoastlines()
    m.drawstates()
    m.drawcountries()
    return m

In [None]:
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(12, 10))
fig.subplots_adjust(hspace=0.05, wspace=0.05)

to_plot = ['2a', '1', '3c', '7a']  # one category code per subplot
lllat, urlat = 17.25, 20.25  # latitude bounds
lllon, urlon = -75, -71  # longitude bounds

for code, ax in zip(to_plot, axes.flat):
    m = basic_haiti_map(ax, lllat=lllat, urlat=urlat, lllon=lllon, urlon=urlon)
    # Reports tagged with this category code
    cat_data = data[data['category_{}'.format(code)] == 1]
    # Project lon/lat into map coordinates before plotting
    x, y = m(cat_data.LONGITUDE, cat_data.LATITUDE)
    m.plot(x, y, 'k.', alpha=0.5)
    ax.set_title('{}: {}'.format(code, english_mapping[code]))

## Python图形化工具生态系统

### Chaco

### mayavi

### 其他库

### 图形化工具的未来