-
Notifications
You must be signed in to change notification settings - Fork 0
/
readWriteCSV.py
42 lines (38 loc) · 1.33 KB
/
readWriteCSV.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import csv
import numpy as np
def readCSV(filename, dtype="int64", verbose=True):
    """Read a numeric CSV file into a 2-D numpy array.

    Each non-blank CSV row becomes one row of the result; all rows must
    have the same number of fields. The module was written for an N x 17
    label matrix, but the column count is not enforced here.

    Args:
        filename: Path of the CSV file to read.
        dtype: "int64" (default) or "float" select np.int64 / np.float64;
            any other value is passed unchanged to ``ndarray.astype``.
        verbose: If True, print a progress message after opening the file.

    Returns:
        numpy.ndarray of shape (N, M) with the requested dtype. A file with
        no data rows yields an empty array of shape (0, 0).

    Raises:
        ValueError: if a field cannot be converted to ``dtype`` or the rows
            have differing lengths.
    """
    # Map the two short names this module accepts onto numpy types.
    if dtype == "int64":
        dtype = np.int64
    elif dtype == "float":
        dtype = np.float64

    # The csv module expects files to be opened with newline="".
    with open(filename, "r", newline="") as csv_file:
        if verbose:
            print("Reading labels from {}...".format(filename))
        # csv.reader yields [] for blank lines (e.g. a trailing newline);
        # drop them so they don't become zero-width rows that break stacking.
        rows = [row for row in csv.reader(csv_file, dialect="excel") if row]

    if not rows:
        # np.vstack/concatenate cannot stack an empty list.
        return np.empty((0, 0), dtype=dtype)
    # Convert each row separately (as before) so ragged input still raises
    # ValueError; vstack turns each 1-D row into one row of the result.
    return np.vstack([np.asarray(row).astype(dtype) for row in rows])
def writeCSV(filename, label_mat, verbose=True, n_cols=17):
    """Write a 2-D numpy array to a CSV file, one array row per CSV row.

    Args:
        filename: Path of the CSV file to create or overwrite.
        label_mat: 2-D numpy.ndarray to write.
        verbose: If True (default), print progress messages.
        n_cols: Required number of columns (default 17, the label-matrix
            width this module was written for); pass None to accept any
            width.

    Raises:
        AssertionError: with message "mat format error" if ``label_mat`` is
            not a 2-D ndarray with ``n_cols`` columns.
    """
    # Raise explicitly instead of using `assert` so validation survives
    # `python -O`; AssertionError is kept so callers see the same type.
    if not isinstance(label_mat, np.ndarray):
        raise AssertionError("mat format error")
    if label_mat.ndim != 2 or (
        n_cols is not None and label_mat.shape[1] != n_cols
    ):
        raise AssertionError("mat format error")

    # newline="" lets csv.writer control line endings itself.
    with open(filename, "w", newline="") as csvfile:
        if verbose:
            print("Writing labels to {}...".format(filename))
        writer = csv.writer(csvfile, dialect="excel")
        writer.writerows(label_mat)
    # Report success only after the file has been closed (and flushed).
    if verbose:
        print("Successful!")
if __name__ == "__main__":
    # Round-trip a label CSV: read it, then write it back out.
    # Paths may be given on the command line; the defaults are the
    # original hard-coded locations. (A leftover `ipdb.set_trace()`
    # debugging breakpoint was removed.)
    import argparse

    parser = argparse.ArgumentParser(
        description="Read a label CSV and write it back out."
    )
    parser.add_argument(
        "src",
        nargs="?",
        default="c:/users/pzq/desktop/result-merge-1202.csv",
        help="CSV file to read",
    )
    parser.add_argument(
        "dst",
        nargs="?",
        default="c:/users/pzq/desktop/test.csv",
        help="CSV file to write",
    )
    args = parser.parse_args()

    y_pred_mat = readCSV(args.src)
    writeCSV(args.dst, y_pred_mat)