-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathread_csv.lua
52 lines (40 loc) · 1 KB
/
read_csv.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
#!/usr/bin/env th
-- Read CSV file
--- Split a string into the non-empty pieces found between separator characters.
--
-- `sep` is placed inside a Lua pattern character set (`[^...]`), so every
-- character in it is an independent delimiter and pattern classes such as
-- `%s` are honoured. A literal percent sign must therefore be written `%%`.
--
-- Empty pieces are dropped: `("a,,b"):split(",")` yields `{"a", "b"}`, so
-- positions do not line up with columns when a field is empty.
--
-- @param sep delimiter character set; defaults to "," when nil
-- @return array table of the pieces, in order of appearance
function string:split(sep)
    if sep == nil then
        sep = ","
    end

    local fields = {}

    -- An empty set would build the malformed pattern "([^]+)"; with nothing
    -- to split on, the whole (non-empty) string is the only piece.
    if sep == "" then
        if self ~= "" then
            fields[1] = self
        end
        return fields
    end

    local pattern = string.format("([^%s]+)", sep)
    -- gmatch walks the matches directly instead of having gsub build a
    -- replacement string that is thrown away.
    for piece in self:gmatch(pattern) do
        fields[#fields + 1] = piece
    end
    return fields
end
local filePath = 'train.csv'

-- First pass: count the lines and take the column count from the header,
-- so the tensor can be allocated once at its final size.
local COLS -- number of fields in the header line
local lineCount = 0
for line in io.lines(filePath) do
    if lineCount == 0 then
        COLS = #line:split(',')
    end
    lineCount = lineCount + 1
end
assert(lineCount > 0, filePath .. ' is empty: expected at least a header line')
local ROWS = lineCount - 1 -- Minus 1 because of header

-- Second pass: read data from CSV to tensor
local csvFile = assert(io.open(filePath, 'r')) -- fail with the OS error message
csvFile:read() -- skip the header line
local data = torch.Tensor(ROWS, COLS)
local row = 0
for line in csvFile:lines() do
    row = row + 1
    for col, field in ipairs(line:split(',')) do
        -- Convert explicitly so a non-numeric field is reported with its
        -- location instead of relying on implicit string coercion.
        local value = tonumber(field)
        if value == nil then
            csvFile:close()
            error(string.format('%s: line %d, field %d is not numeric: %q',
                filePath, row + 1, col, field))
        end
        data[row][col] = value
    end
end
csvFile:close()

-- Serialize tensor
local outputFilePath = 'train.th7'
torch.save(outputFilePath, data)

-- Deserialize tensor object
local restored_data = torch.load(outputFilePath)

-- Sanity check: the restored tensor should report the same size as the original
print(data:size())
print(restored_data:size())