-
Notifications
You must be signed in to change notification settings - Fork 1
/
data.py
27 lines (23 loc) · 1.03 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import numpy as np
class Dataset:
def __init__(self, means1=[-5, 0], means2=[0, 5]):
self._rand_sample(means1, means2)
def _rand_sample(self, means1, means2):
means1 = np.array(means1)
means2 = np.array(means2)
covar = np.array([1, 0, 0, 1]).reshape(2,2)
x1 = np.random.multivariate_normal(means1, covar, size=200)
x2 = np.random.multivariate_normal(means2, covar, size=200)
y1 = np.ones((200, 1))
y2 = np.ones((200, 1)) * -1
#y2 = np.ones((200, 1)) * -1
self._split(x1, y1, x2, y2)
def _split(self, x1, y1, x2, y2):
num1 = x1.shape[0]
train_num1 = int(num1 * 0.8)
num2 = x2.shape[0]
train_num2 = int(num2 * 0.8)
self.x_train = np.concatenate((x1[:train_num1],x2[:train_num2]),axis=0)
self.y_train = np.concatenate((y1[:train_num1],y2[:train_num2]),axis=0)
self.x_test = np.concatenate((x1[train_num1:],x2[train_num2:]),axis=0)
self.y_test = np.concatenate((y1[train_num1:],y2[train_num2:]),axis=0)