Skip to content

Commit

Permalink
speed improvements for lua
Browse files Browse the repository at this point in the history
  • Loading branch information
FourierTransformer committed Mar 31, 2019
1 parent b083444 commit b6da316
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 88 deletions.
14 changes: 8 additions & 6 deletions encoder.lua
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
-- CSV Encoder for ftcsv

-- lua/luajit load compat
local M = {}
local luaCompatibility = {}
if type(jit) == 'table' or _ENV then
M.load = _G.load
-- luajit and lua 5.2+
luaCompatibility.load = _G.load
else
M.load = loadstring
-- lua 5.1
luaCompatibility.load = loadstring
end

local function delimitField(field)
Expand Down Expand Up @@ -48,7 +49,7 @@ local function csvLineGenerator(inputTable, delimiter, headers)
-- so we're just going to pass it in
arguments.delimitField = delimitField

return M.load(outputFunc), arguments, 0
return luaCompatibility.load(outputFunc), arguments, 0

end

Expand Down Expand Up @@ -90,7 +91,8 @@ local function getHeadersFromOptions(options)
local headers = nil
if options then
if options.fieldsToKeep ~= nil then
assert(type(options.fieldsToKeep) == "table", "ftcsv only takes in a list (as a table) for the optional parameter 'fieldsToKeep'. You passed in '" .. tostring(options.headers) .. "' of type '" .. type(options.headers) .. "'.")
assert(
type(options.fieldsToKeep) == "table", "ftcsv only takes in a list (as a table) for the optional parameter 'fieldsToKeep'. You passed in '" .. tostring(options.headers) .. "' of type '" .. type(options.headers) .. "'.")
headers = options.fieldsToKeep
end
end
Expand Down
184 changes: 102 additions & 82 deletions ftcsv.lua
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,37 @@ local ftcsv = {
}

-- luajit/lua compatability layer
local M = {}
local luaCompatibility = {}

-- perf
local sbyte = string.byte
local ssub = string.sub

-- determine the real headers as opposed to the header mapping
local function determineRealHeaders(headerField, fieldsToKeep)
local realHeaders = {}
local headerSet = {}
for i = 1, #headerField do
if not headerSet[headerField[i]] then
if fieldsToKeep ~= nil and fieldsToKeep[headerField[i]] then
table.insert(realHeaders, headerField[i])
headerSet[headerField[i]] = true
elseif fieldsToKeep == nil then
table.insert(realHeaders, headerField[i])
headerSet[headerField[i]] = true
end
end
end
return realHeaders
end

-- luajit specific speedups
-- luajit performs faster with iterating over string.byte,
-- whereas vanilla lua performs faster with string.find
if type(jit) == 'table' then
luaCompatibility.LuaJIT = true
-- finds the end of an escape sequence
function M.findClosingQuote(i, inputLength, inputString, quote, doubleQuoteEscape)
function luaCompatibility.findClosingQuote(i, inputLength, inputString, quote, doubleQuoteEscape)
local currentChar, nextChar = sbyte(inputString, i), nil
while i <= inputLength do
-- print(i)
Expand All @@ -65,8 +84,9 @@ if type(jit) == 'table' then
end

else
luaCompatibility.LuaJIT = false
-- vanilla lua closing quote finder
function M.findClosingQuote(i, inputLength, inputString, quote, doubleQuoteEscape)
function luaCompatibility.findClosingQuote(i, inputLength, inputString, quote, doubleQuoteEscape)
local j, difference
i, j = inputString:find('"+', i)
if j == nil then return end
Expand All @@ -77,21 +97,12 @@ else
-- print("difference", difference, "I", i, "J", j)
if difference >= 1 then doubleQuoteEscape = true end
if difference == 1 then
return M.findClosingQuote(j+1, inputLength, inputString, quote, doubleQuoteEscape)
return luaCompatibility.findClosingQuote(j+1, inputLength, inputString, quote, doubleQuoteEscape)
end
return j-1, doubleQuoteEscape
end
end

-- load an entire file into memory
local function loadFile(textFile)
local file = io.open(textFile, "r")
if not file then error("ftcsv: File not found at " .. textFile) end
local allLines = file:read("*all")
file:close()
return allLines
end

-- creates a new field
local function createField(inputString, quote, fieldStart, i, doubleQuoteEscape)
local field
Expand All @@ -110,6 +121,21 @@ local function createField(inputString, quote, fieldStart, i, doubleQuoteEscape)
return field
end

local function determineTotalColumnCount(headerField, fieldsToKeep)
local totalColumnCount = 0
local headerFieldSet = {}
for _, header in pairs(headerField) do
-- count unique columns and
-- also figure out if it's a field to keep
if not headerFieldSet[header] and
(fieldsToKeep == nil or fieldsToKeep[header]) then
headerFieldSet[header] = true
totalColumnCount = totalColumnCount + 1
end
end
return totalColumnCount
end

-- main function used to parse
local function parseString(inputString, delimiter, i, headerField, fieldsToKeep, buffered)

Expand All @@ -125,12 +151,17 @@ local function parseString(inputString, delimiter, i, headerField, fieldsToKeep,
local doubleQuoteEscape, emptyIdentified = false, false
local exit = false

local skipIndex
local charPatterToSkip = "[" .. delimiter .. "\r\n]"


--bytes
local CR = sbyte("\r")
local LF = sbyte("\n")
local quote = sbyte('"')
local delimiterByte = sbyte(delimiter)


local assignValue
local outResults
-- the headers haven't been set yet.
Expand All @@ -139,13 +170,15 @@ local function parseString(inputString, delimiter, i, headerField, fieldsToKeep,
headerField = {}
assignValue = function()
headerField[fieldNum] = field
doubleQuoteEscape = false
emptyIdentified = false
return true
end
else
outResults = {}
outResults[1] = {}
assignValue = function()
doubleQuoteEscape = false
emptyIdentified = false
if headerField[fieldNum] ~= nil then
outResults[lineNum][headerField[fieldNum]] = field
Expand All @@ -155,15 +188,8 @@ local function parseString(inputString, delimiter, i, headerField, fieldsToKeep,
end
end

-- calculate the initial line count (note: this can include duplicates)
local headerFieldsExist = {}
local initialLineCount = 0
for _, value in pairs(headerField) do
if not headerFieldsExist[value] and (fieldsToKeep == nil or fieldsToKeep[value]) then
headerFieldsExist[value] = true
initialLineCount = initialLineCount + 1
end
end
-- totalColumnCount based on unique headers and fieldsToKeep
local totalColumnCount = determineTotalColumnCount(headerField, fieldsToKeep)

while i <= inputLength do
-- go by two chars at a time! currentChar is set at the bottom.
Expand Down Expand Up @@ -192,7 +218,7 @@ local function parseString(inputString, delimiter, i, headerField, fieldsToKeep,
emptyIdentified = false
end

i, doubleQuoteEscape = M.findClosingQuote(i+1, inputLength, inputString, quote, doubleQuoteEscape)
i, doubleQuoteEscape = luaCompatibility.findClosingQuote(i+1, inputLength, inputString, quote, doubleQuoteEscape)
-- print("I VALUE", i, doubleQuoteEscape)
skipChar = 1

Expand All @@ -205,18 +231,18 @@ local function parseString(inputString, delimiter, i, headerField, fieldsToKeep,
-- print("FIELD", field, "FIELDEND", headerField[fieldNum], lineNum)
assignValue()
end
doubleQuoteEscape = false

fieldNum = fieldNum + 1
fieldStart = i + 1
-- print("fs+1:", fieldStart)

-- newline?!
elseif (currentChar == CR or currentChar == LF) then
elseif (currentChar == LF or currentChar == CR) then
if fieldsToKeep == nil or fieldsToKeep[headerField[fieldNum]] then
-- create the new field
field = createField(inputString, quote, fieldStart, i, doubleQuoteEscape)

-- used to exit for the headers...
exit = assignValue()
if exit then
if (currentChar == CR and nextChar == LF) then
Expand All @@ -226,7 +252,6 @@ local function parseString(inputString, delimiter, i, headerField, fieldsToKeep,
end
end
end
doubleQuoteEscape = false

-- determine how line ends
if (currentChar == CR and nextChar == LF) then
Expand All @@ -237,18 +262,17 @@ local function parseString(inputString, delimiter, i, headerField, fieldsToKeep,
end

-- incrememnt for new line
if fieldNum < initialLineCount then
if fieldNum < totalColumnCount then
-- sometimes in buffered mode, the buffer starts with a newline
-- this skips the newline and lets the parsing continue.
if lineNum == 1 and fieldNum == 1 and buffered then
-- print("fieldNum", fieldNum)
-- print("initialLineCount", initialLineCount)
-- print("totalColumnCount", totalColumnCount)
-- print("lineNum", lineNum)
-- print(i)
fieldStart = i + 1 + skipChar
lineStart = fieldStart
else
-- return "YA"
error('ftcsv: too few columns in row ' .. lineNum)
end
else
Expand All @@ -260,6 +284,12 @@ local function parseString(inputString, delimiter, i, headerField, fieldsToKeep,
-- print("fs:", fieldStart)
end

elseif luaCompatibility.LuaJIT == false then
skipIndex = inputString:find(charPatterToSkip, i)
if skipIndex then
skipChar = skipChar + (skipIndex - i - 1)
end

end
-- this happens when you can't find a closing quote - usually means in the middle of a buffer
if i == nil and buffered then
Expand Down Expand Up @@ -292,28 +322,20 @@ local function parseString(inputString, delimiter, i, headerField, fieldsToKeep,
return headerField, i-1
end

-- clean up last line if it's weird (this happens when there is a CRLF newline at end of file)
-- doing a count gets it to pick up the oddballs
local finalLineCount = 0
local lastValue = nil
for _, v in pairs(outResults[lineNum]) do
finalLineCount = finalLineCount + 1
lastValue = v
end

-- this indicates a CRLF
-- print("Final/Initial", finalLineCount, initialLineCount)
if finalLineCount == 1 and lastValue == "" then
outResults[lineNum] = nil

-- otherwise there might not be enough line
elseif finalLineCount < initialLineCount then
if buffered then
-- TODO: look into buffered here, as there's likely an edge case here.
if fieldNum < totalColumnCount then
-- indicates last field was really just a CRLF,
-- so, it can be removed
if fieldNum == 1 and field == "" then
outResults[lineNum] = nil
-- print(#outResults)
return outResults, lineStart
else
error('ftcsv: too few columns in row ' .. lineNum)
if buffered then
outResults[lineNum] = nil
-- print(#outResults)
return outResults, lineStart
else
error('ftcsv: too few columns in row ' .. lineNum)
end
end
end

Expand Down Expand Up @@ -364,6 +386,34 @@ local function handleHeaders(headerField, options)
return headerField
end

local function includesBOM(inputString)
return string.byte(inputString, 1) == 239
and string.byte(inputString, 2) == 187
and string.byte(inputString, 3) == 191
end

-- load an entire file into memory
local function loadFile(textFile)
local file = io.open(textFile, "r")
if not file then error("ftcsv: File not found at " .. textFile) end
local allLines = file:read("*all")
file:close()
return allLines
end

local function initializeInputFromStringOrFile(inputFile, options)
-- handle input via string or file!
local inputString
if options.loadFromString then inputString = inputFile
else inputString = loadFile(inputFile) end

-- if they sent in an empty file...
if inputString == "" then
error('ftcsv: Cannot parse an empty file')
end
return inputString
end

local function parseOptions(delimiter, options)
-- delimiter MUST be one character
assert(#delimiter == 1 and type(delimiter) == "string", "the delimiter must be of string type and exactly one character")
Expand Down Expand Up @@ -408,46 +458,16 @@ local function parseOptions(delimiter, options)

end

-- determine the real headers as opposed to the header mapping
local function determineRealHeaders(headerField, fieldsToKeep)
local realHeaders = {}
local headerSet = {}
for i = 1, #headerField do
if not headerSet[headerField[i]] then
if fieldsToKeep ~= nil and fieldsToKeep[headerField[i]] then
table.insert(realHeaders, headerField[i])
headerSet[headerField[i]] = true
elseif fieldsToKeep == nil then
table.insert(realHeaders, headerField[i])
headerSet[headerField[i]] = true
end
end
end
return realHeaders
end

-- runs the show!
function ftcsv.parse(inputFile, delimiter, options)
-- make sure options make sense and get fields to keep
local options, fieldsToKeep = parseOptions(delimiter, options)

-- handle input via string or file!
local inputString
if options.loadFromString then inputString = inputFile
else inputString = loadFile(inputFile) end

-- if they sent in an empty file...
if inputString == "" then
error('ftcsv: Cannot parse an empty file')
end
local inputString = initializeInputFromStringOrFile(inputFile, options)

-- parse through the headers!
-- determine start of input
local startLine = 1

-- check for BOM
if string.byte(inputString, 1) == 239
and string.byte(inputString, 2) == 187
and string.byte(inputString, 3) == 191 then
if includesBOM(inputString) then
startLine = 4
end

Expand Down

0 comments on commit b6da316

Please sign in to comment.