-
Notifications
You must be signed in to change notification settings - Fork 27
/
mcmc.py
104 lines (90 loc) · 3.82 KB
/
mcmc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
"""Analysis element for the Markov Chain Monte Carlo (MCMC) method.
For more details on the method see the `easyvvuq.sampling.MCMCSampler` class.
The analysis part of Markov Chain Monte Carlo consists of approximating the
distribution of the input variables from the results obtained by evaluating the samples.
"""
import pandas as pd
from .base import BaseAnalysisElement
from .results import AnalysisResults
class MCMCAnalysisResults(AnalysisResults):
    """The analysis results class for MCMC. You will not need to instantiate this
    class manually.

    Parameters
    ----------
    chains: dict
        A dictionary with pandas DataFrames that correspond to an MCMC chain each.
        A chain consists of points that MCMC has visited. From this a distribution
        of the input variables can be constructed by means of a simple histogram.
        Columns are looked up with ``(input_parameter, 0)`` tuple keys.
    """

    def __init__(self, chains):
        self.chains = chains

    def plot_hist(self, input_parameter, chain=None, skip=0, merge=True, bins=20):
        """Will plot a histogram for a given input parameter.

        Parameters
        ----------
        input_parameter: str
            An input parameter name to draw the histogram for.
        chain: int, optional
            Key of the chain to be plotted. Only used when `merge` is False.
        skip: int
            How many steps to skip at the start of each chain (for getting rid
            of burn-in).
        merge: bool
            If set to True will use all chains to construct the histogram.
        bins: int or sequence, optional
            Passed to `matplotlib.pyplot.hist`; defaults to 20 bins.

        Raises
        ------
        ValueError
            If `merge` is False and no `chain` is given.
        """
        import matplotlib.pyplot as plt
        column = (input_parameter, 0)
        if merge:
            # Series.append returns a new object instead of mutating in place
            # (and was removed in pandas 2.0), so the previous loop silently
            # plotted only the first chain. Concatenate the trimmed chains.
            samples = pd.concat(
                [frame[column].iloc[skip:] for frame in self.chains.values()],
                ignore_index=True,
            )
        elif chain is None:
            raise ValueError("chain must be given when merge is False")
        else:
            samples = self.chains[chain][column].iloc[skip:]
        plt.hist(samples, bins)

    def plot_chains(self, input_parameter, chain=None):
        """Will plot the chains with the input parameter value in the y axis.

        Parameters
        ----------
        input_parameter: str
            Input parameter name.
        chain: int, optional
            Key of the chain to plot. If omitted, every chain is plotted.
        """
        import matplotlib.pyplot as plt
        column = (input_parameter, 0)
        selected = list(self.chains) if chain is None else [chain]
        for key in selected:
            plt.plot(self.chains[key][column])
class MCMCAnalysis(BaseAnalysisElement):
    """The analysis part of the MCMC method in EasyVVUQ.

    Parameters
    ----------
    sampler: MCMCSampler
        An instance of MCMCSampler used to generate MCMC samples. Only its
        `inputs` attribute is read here (presumably the list of input
        parameter names -- confirm against the sampler).
    """

    def __init__(self, sampler):
        self.sampler = sampler

    def analyse(self, df):
        """Reconstruct the MCMC chains from the evaluated samples.

        Performs some pre-processing on the chains in order to be able to
        construct histograms or other estimates of the input distribution.

        Parameters
        ----------
        df: DataFrame
            DataFrame with the results obtained by evaluating the samples
            generated by the MCMC sampler. It must contain ``('chain_id', 0)``
            and ``('iteration', 0)`` columns plus the sampler's input columns.

        Returns
        -------
        MCMCAnalysisResults
            Wraps a dict mapping each chain id to a DataFrame with one column
            per sampler input and one row per reconstructed chain step.
        """
        chain_key = ('chain_id', 0)
        iteration_key = ('iteration', 0)
        chains = {}
        for chain_id in df[chain_key].unique():
            rows = df[df[chain_key] == chain_id]
            # Average only the input columns. Averaging the whole frame would
            # also touch unrelated, possibly non-numeric columns, which
            # DataFrame.mean rejects in pandas >= 2.0.
            values = rows[self.sampler.inputs].groupby(rows[iteration_key]).mean()
            chains[chain_id] = self._expand_chain(values)
        return MCMCAnalysisResults(chains)

    @staticmethod
    def _expand_chain(values):
        """Repeat each recorded point until the next recorded iteration."""
        columns = {column: [] for column in values.columns}
        iterations = values.index.values
        # NOTE(review): presumably an iteration is only recorded when the chain
        # moves, so a point is held for (next_iteration - iteration) steps. The
        # last recorded point is dropped since no following iteration bounds it.
        for current, following in zip(iterations[:-1], iterations[1:]):
            repeat = int(following - current)
            row = values.loc[current]
            for column, samples in columns.items():
                samples.extend([row[column]] * repeat)
        # A chain with fewer than two recorded iterations yields an empty frame
        # with the input columns instead of raising IndexError.
        return pd.DataFrame(columns)