# Optimized inner loop of load_svmlight_file.
#
# Authors: Mathieu Blondel <mathieu@mblondel.org>
#          Lars Buitinck
#          Olivier Grisel <olivier.grisel@ensta.org>
# License: BSD 3 clause

import array
from cpython cimport array
cimport cython
from libc.string cimport strchr

cimport numpy as np
import numpy as np
import scipy.sparse as sp

from ..externals.six import b

np.import_array()


cdef bytes COMMA = u','.encode('ascii')
cdef bytes COLON = u':'.encode('ascii')
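

# Each data line handled below is expected to have the form
#     <target> [qid:<int>] <index>:<value> <index>:<value> ... [# comment]
# where <target> is a single float, or a (possibly empty) comma-separated
# list of floats in multilabel mode, and everything after '#' is ignored.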
@cython.boundscheck(False)
@cython.wraparound(False)
def _load_svmlight_file(f, dtype, bint multilabel, bint zero_based,
                        bint query_id, long long offset, long long length):
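    """Parse the open SVMlight/LibSVM file-like object ``f``.

    Returns ``(dtype, data, indices, indptr, labels, query)``: the CSR
    components ``data``/``indices``/``indptr`` as growable arrays, the
    targets in ``labels`` (a list of tuples in multilabel mode), and the
    query ids in ``query``; the Python-level caller does the further
    conversion to NumPy/SciPy containers.
    """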
    cdef array.array data, indices, indptr
    cdef bytes line
    cdef char *hash_ptr
    cdef char *line_cstr
    cdef int idx, prev_idx
    cdef Py_ssize_t i
    cdef bytes qid_prefix = b('qid')
    cdef Py_ssize_t n_features
    cdef long long offset_max = offset + length if length > 0 else -1

    # Special-case float32 but use float64 for everything else;
    # the Python code will do further conversions.
    if dtype == np.float32:
        data = array.array("f")
    else:
        dtype = np.float64
        data = array.array("d")
    indices = array.array("i")
    indptr = array.array("i", [0])
    query = np.arange(0, dtype=np.int64)

    if multilabel:
        labels = []
    else:
        labels = array.array("d")

    if offset > 0:
        f.seek(offset)
        # drop the current line that might be truncated and is to be
        # fetched by another call
        f.readline()

    for line in f:
        # skip comments
        line_cstr = line
        hash_ptr = strchr(line_cstr, '#')
        if hash_ptr != NULL:
            line = line[:hash_ptr - line_cstr]

        line_parts = line.split()
        if len(line_parts) == 0:
            continue

        target, features = line_parts[0], line_parts[1:]
        if multilabel:
            if COLON in target:
                target, features = [], line_parts[0:]
            else:
                target = [float(y) for y in target.split(COMMA)]
            target.sort()
            labels.append(tuple(target))
        else:
            array.resize_smart(labels, len(labels) + 1)
            labels[len(labels) - 1] = float(target)

        prev_idx = -1
        n_features = len(features)
        if n_features and features[0].startswith(qid_prefix):
            _, value = features[0].split(COLON, 1)
            if query_id:
                query.resize(len(query) + 1)
                query[len(query) - 1] = np.int64(value)
            features.pop(0)
            n_features -= 1

        for i in xrange(0, n_features):
            idx_s, value = features[i].split(COLON, 1)
            idx = int(idx_s)
            if idx < 0 or not zero_based and idx == 0:
                raise ValueError(
                    "Invalid index %d in SVMlight/LibSVM data file." % idx)
            if idx <= prev_idx:
                raise ValueError("Feature indices in SVMlight/LibSVM data "
                                 "file should be sorted and unique.")

            array.resize_smart(indices, len(indices) + 1)
            indices[len(indices) - 1] = idx

            array.resize_smart(data, len(data) + 1)
            data[len(data) - 1] = float(value)

            prev_idx = idx

        array.resize_smart(indptr, len(indptr) + 1)
        indptr[len(indptr) - 1] = len(data)

        if offset_max != -1 and f.tell() > offset_max:
            # Stop here and let another call deal with the following.
            break

    return (dtype, data, indices, indptr, labels, query)
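

# Illustrative sketch (not part of this module): the Python-level wrapper
# around this inner loop is expected to turn the returned tuple into
# SciPy/NumPy containers roughly as follows, e.g. assuming the tuple is
# bound to `result` and `multilabel` is the flag that was passed in:
#
#     dtype, data, indices, indptr, labels, query = result
#     X = sp.csr_matrix((np.frombuffer(data, dtype=dtype),
#                        np.frombuffer(indices, dtype=np.intc),
#                        np.frombuffer(indptr, dtype=np.intc)))
#     y = labels if multilabel else np.frombuffer(labels, dtype=np.float64)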