/
probability.py
232 lines (201 loc) · 7.7 KB
/
probability.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
from manimlib.constants import *
from manimlib.mobject.geometry import Line
from manimlib.mobject.geometry import Rectangle
from manimlib.mobject.mobject import Mobject
from manimlib.mobject.svg.brace import Brace
from manimlib.mobject.svg.tex_mobject import TexMobject
from manimlib.mobject.svg.tex_mobject import TextMobject
from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.utils.color import color_gradient
from manimlib.utils.iterables import tuplify
# Tolerance used by SampleSpace.complete_p_list: a leftover probability whose
# magnitude is at most this value is treated as zero and is not appended.
EPSILON = 0.0001
class SampleSpace(Rectangle):
    """Rectangle that can be split into proportional parts and annotated.

    Parts are themselves ``SampleSpace`` instances stretched along one
    dimension; braces and labels can be generated for each part.
    """

    CONFIG = {
        "height": 3,
        "width": 3,
        "fill_color": DARK_GREY,
        "fill_opacity": 1,
        "stroke_width": 0.5,
        "stroke_color": LIGHT_GREY,
        ##
        "default_label_scale_val": 1,
    }

    def add_title(self, title="Sample space", buff=MED_SMALL_BUFF):
        """Create a text title above this mobject, store it and add it.

        A title wider than this mobject is resized to this mobject's width
        before being positioned. The result is kept in ``self.title``.
        """
        # TODO, should this really exist in SampleSpaceScene
        heading = TextMobject(title)
        if heading.get_width() > self.get_width():
            heading.set_width(self.get_width())
        heading.next_to(self, UP, buff=buff)
        self.title = heading
        self.add(heading)

    def add_label(self, label):
        """Remember ``label`` in ``self.label`` (it is not added as a child)."""
        self.label = label

    def complete_p_list(self, p_list):
        """Return ``p_list`` as a new list, padded with ``1 - sum(p_list)``.

        The remainder is appended only when its absolute value is greater
        than ``EPSILON``.
        """
        completed = list(tuplify(p_list))
        leftover = 1.0 - sum(completed)
        if abs(leftover) > EPSILON:
            completed.append(leftover)
        return completed

    def get_division_along_dimension(self, p_list, dim, colors, vect):
        """Build a ``VGroup`` of parts sized by ``p_list`` along ``dim``.

        Each part is a fresh ``SampleSpace`` fitted onto this mobject,
        stretched by its proportion along ``dim``, and placed so that its
        ``-vect`` edge sits where the previous part's ``vect`` edge ended
        (starting from this mobject's ``-vect`` edge). Fill colors come from
        a gradient over ``colors`` with one entry per part.
        """
        proportions = self.complete_p_list(p_list)
        gradient = color_gradient(colors, len(proportions))
        start_side = -vect
        anchor = self.get_edge_center(start_side)
        parts = VGroup()
        for proportion, fill in zip(proportions, gradient):
            piece = SampleSpace()
            piece.set_fill(fill, 1)
            piece.replace(self, stretch=True)
            piece.stretch(proportion, dim)
            piece.move_to(anchor, start_side)
            # The next piece begins where this one ends.
            anchor = piece.get_edge_center(vect)
            parts.add(piece)
        return parts

    def get_horizontal_division(
        self, p_list,
        colors=[GREEN_E, BLUE_E],
        vect=DOWN
    ):
        """Return the parts of a division along dimension 1."""
        stretch_dim = 1
        return self.get_division_along_dimension(
            p_list, stretch_dim, colors, vect
        )

    def get_vertical_division(
        self, p_list,
        colors=[MAROON_B, YELLOW],
        vect=RIGHT
    ):
        """Return the parts of a division along dimension 0."""
        stretch_dim = 0
        return self.get_division_along_dimension(
            p_list, stretch_dim, colors, vect
        )

    def divide_horizontally(self, *args, **kwargs):
        """Create a horizontal division, keep it in ``horizontal_parts``, add it."""
        division = self.get_horizontal_division(*args, **kwargs)
        self.horizontal_parts = division
        self.add(division)

    def divide_vertically(self, *args, **kwargs):
        """Create a vertical division, keep it in ``vertical_parts``, add it."""
        division = self.get_vertical_division(*args, **kwargs)
        self.vertical_parts = division
        self.add(division)

    def get_subdivision_braces_and_labels(
        self, parts, labels, direction,
        buff=SMALL_BUFF,
        min_num_quads=1
    ):
        """Create one brace and one label per part on the ``direction`` side.

        Labels that are already ``Mobject`` instances are used as given;
        anything else is turned into a ``TexMobject`` and scaled by
        ``default_label_scale_val``. The braces, labels and the arguments
        used are recorded on ``parts`` (``braces``, ``labels``,
        ``label_kwargs``).

        Returns a ``VGroup`` holding the brace group and the label group.
        """
        brace_group = VGroup()
        label_group = VGroup()
        for label, part in zip(labels, parts):
            brace = Brace(
                part, direction, min_num_quads=min_num_quads, buff=buff
            )
            if isinstance(label, Mobject):
                label_mob = label
            else:
                label_mob = TexMobject(label)
                label_mob.scale(self.default_label_scale_val)
            label_mob.next_to(brace, direction, buff)

            brace_group.add(brace)
            label_group.add(label_mob)

        parts.braces = brace_group
        parts.labels = label_group
        parts.label_kwargs = {
            "labels": label_group.copy(),
            "direction": direction,
            "buff": buff,
        }
        return VGroup(brace_group, label_group)

    def get_side_braces_and_labels(self, labels, direction=LEFT, **kwargs):
        """Braces and labels for ``horizontal_parts`` (which must exist)."""
        assert hasattr(self, "horizontal_parts")
        return self.get_subdivision_braces_and_labels(
            self.horizontal_parts, labels, direction, **kwargs
        )

    def get_top_braces_and_labels(self, labels, **kwargs):
        """Braces and labels above ``vertical_parts`` (which must exist)."""
        assert hasattr(self, "vertical_parts")
        return self.get_subdivision_braces_and_labels(
            self.vertical_parts, labels, UP, **kwargs
        )

    def get_bottom_braces_and_labels(self, labels, **kwargs):
        """Braces and labels below ``vertical_parts`` (which must exist)."""
        assert hasattr(self, "vertical_parts")
        return self.get_subdivision_braces_and_labels(
            self.vertical_parts, labels, DOWN, **kwargs
        )

    def add_braces_and_labels(self):
        """Add whatever braces/labels have been recorded on existing parts."""
        for parts_name in ("horizontal_parts", "vertical_parts"):
            if hasattr(self, parts_name):
                parts = getattr(self, parts_name)
                for decoration in ("braces", "labels"):
                    if hasattr(parts, decoration):
                        self.add(getattr(parts, decoration))

    def __getitem__(self, index):
        """Index into horizontal parts, else vertical parts, else ``split()``."""
        for parts_name in ("horizontal_parts", "vertical_parts"):
            if hasattr(self, parts_name):
                return getattr(self, parts_name)[index]
        return self.split()[index]
class BarChart(VGroup):
    """Group holding two axes, optional y-axis labels, bars and bar labels.

    Bar heights are ``value / max_value`` of the configured chart height.
    """

    CONFIG = {
        "height": 4,
        "width": 6,
        "n_ticks": 4,
        "tick_width": 0.2,
        "label_y_axis": True,
        "y_axis_label_height": 0.25,
        "max_value": 1,
        "bar_colors": [BLUE, YELLOW],
        "bar_fill_opacity": 0.8,
        "bar_stroke_width": 3,
        "bar_names": [],
        "bar_label_scale_val": 0.75,
    }

    def __init__(self, values, **kwargs):
        """Build axes and bars for ``values`` and center the chart.

        When ``max_value`` is ``None`` it is replaced by ``max(values)``.
        """
        VGroup.__init__(self, **kwargs)
        if self.max_value is None:
            self.max_value = max(values)

        self.add_axes()
        self.add_bars(values)
        self.center()

    def add_axes(self):
        """Create both axes with ``n_ticks + 1`` ticks and optional labels.

        Stores ``x_axis`` and ``y_axis`` on the instance, plus
        ``y_axis_labels`` when ``label_y_axis`` is truthy.
        """
        n_marks = self.n_ticks + 1
        x_axis = Line(self.tick_width * LEFT / 2, self.width * RIGHT)
        y_axis = Line(MED_LARGE_BUFF * DOWN, self.height * UP)

        ticks = VGroup()
        tick_heights = np.linspace(0, self.height, n_marks)
        tick_values = np.linspace(0, self.max_value, n_marks)
        for y in tick_heights:
            tick = Line(LEFT, RIGHT)
            tick.set_width(self.tick_width)
            tick.move_to(y * UP)
            ticks.add(tick)
        y_axis.add(ticks)

        self.add(x_axis, y_axis)
        self.x_axis = x_axis
        self.y_axis = y_axis

        if not self.label_y_axis:
            return

        labels = VGroup()
        for tick, value in zip(ticks, tick_values):
            # Tick values are shown rounded to two decimals.
            text = str(np.round(value, 2))
            label = TexMobject(text)
            label.set_height(self.y_axis_label_height)
            label.next_to(tick, LEFT, SMALL_BUFF)
            labels.add(label)
        self.y_axis_labels = labels
        self.add(labels)

    def add_bars(self, values):
        """Create one rectangle per value plus labels from ``bar_names``.

        The chart width is divided into ``2 * len(values) + 1`` equal slots;
        bar ``i`` is one slot wide and its lower-left corner is placed at
        ``(2 * i + 1)`` slots to the right. Stores ``bars`` and
        ``bar_labels`` on the instance.
        """
        slot = float(self.width) / (2 * len(values) + 1)

        bars = VGroup()
        for position, value in enumerate(values):
            bar_height = (value / self.max_value) * self.height
            bar = Rectangle(
                height=bar_height,
                width=slot,
                stroke_width=self.bar_stroke_width,
                fill_opacity=self.bar_fill_opacity,
            )
            offset = (2 * position + 1) * slot
            bar.move_to(offset * RIGHT, DOWN + LEFT)
            bars.add(bar)
        bars.set_color_by_gradient(*self.bar_colors)

        bar_labels = VGroup()
        for bar, name in zip(bars, self.bar_names):
            caption = TexMobject(str(name))
            caption.scale(self.bar_label_scale_val)
            caption.next_to(bar, DOWN, SMALL_BUFF)
            bar_labels.add(caption)

        self.add(bars, bar_labels)
        self.bars = bars
        self.bar_labels = bar_labels

    def change_bar_values(self, values):
        """Resize existing bars to ``values`` while keeping their bottoms fixed."""
        for bar, value in zip(self.bars, values):
            anchor = bar.get_bottom()
            new_height = (value / self.max_value) * self.height
            bar.stretch_to_fit_height(new_height)
            # Re-align the bottom edge with where it was before stretching.
            bar.move_to(anchor, DOWN)

    def copy(self):
        """Return ``self.deepcopy()``."""
        return self.deepcopy()