Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Blocks and container added for Text Dataset #205

Closed
wants to merge 19 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/FastAI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ export Vision
include("Tabular/Tabular.jl")
@reexport using .Tabular

include("Text/Text.jl")
@reexport using .Text


include("deprecations.jl")
export
Expand Down Expand Up @@ -171,6 +174,7 @@ export
LabelMulti,
Many,
TableRow,
TextRow,
Continuous,
Image,

Expand Down
60 changes: 60 additions & 0 deletions src/Text/Text.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
module Text


using ..FastAI
using DataFrames
using ..FastAI:
# blocks
Block, WrapperBlock, AbstractBlock, OneHotTensor, OneHotTensorMulti, Label,
LabelMulti, wrapped, Continuous, getencodings, getblocks, encodetarget, encodeinput,
# encodings
Encoding, StatefulEncoding, OneHot,
# visualization
ShowText,
# other
Context, Training, Validation, FASTAI_METHOD_REGISTRY, registerlearningtask!

# for tests
using ..FastAI: testencoding

# extending
import ..FastAI:
blockmodel, blockbackbone, blocklossfn, encode, decode, checkblock,
encodedblock, decodedblock, showblock!, mockblock, setup


import DataAugmentation
import DataFrames: DataFrame
import Flux: Embedding, Chain, Dropout, Dense, Parallel
import PrettyTables
import Requires: @require
import ShowCases: ShowCase
import Tables
import Statistics

using InlineTest


# Blocks
include("blocks/textrow.jl")

# Encodings
include("encodings/textpreprocessing.jl")


include("recipes.jl")


function __init__()
_registerrecipes()
@require Makie="ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" begin
import .Makie
import .Makie: @recipe, @lift
import .FastAI: ShowMakie
include("makie.jl")
end
end

export TextRow

end
70 changes: 70 additions & 0 deletions src/Text/blocks/textrow.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@


# TextRow

"""
TextRow{M, N}(catcols, contcols, categorydict) <: Block

`Block` for table rows with M categorical and N continuous columns. `data`
is valid if it satisfies the `AbstractRow` interface in Tables.jl, values
present in indices for categorical and continuous columns are consistent,
and `data` is indexable by the elements of `catcols` and `contcols`.
"""
struct TextRow{M,N,T} <: Block
catcols::NTuple{M}
contcols::NTuple{N}
categorydict::T
end

function TextRow(catcols, contcols)
TextRow{length(catcols),length(contcols)}(catcols, contcols, categorydict)
end

function checkblock(block::TextRow, x)
columns = Tables.columnnames(x)
(
all(col -> col ∈ columns, (block.catcols..., block.contcols...)) &&
all(
col ->
haskey(block.categorydict, col) &&
(ismissing(x[col]) || x[col] ∈ block.categorydict[col]),
block.catcols,
) &&
all(col -> ismissing(x[col]) || x[col] isa Number, block.contcols)
)
end

function mockblock(block::TextRow)
cols = (block.catcols..., block.contcols...)
vals = map(cols) do col
col in block.catcols ? rand(block.categorydict[col]) : rand()
end
return NamedTuple(zip(cols, vals))
end

"""
setup(TextRow, data[; catcols, contcols])

Create a `TextRow` block from data container `data::TextDataset`. If the
categorical and continuous columns are not specified manually, try to
guess them from the dataset's column types.
"""
function setup(::Type{TextRow}, data; catcols=nothing, contcols=nothing)
catcols_, contcols_ = getcoltypes(data)
catcols = isnothing(catcols) ? catcols_ : catcols
contcols = isnothing(contcols) ? contcols_ : contcols

return TextRow(
catcols,
contcols,
gettransformdict(data, DataAugmentation.Categorify, catcols),
)
end

function Base.show(io::IO, block::TextRow)
print(io, ShowCase(block, (:catcols, :contcols), show_params=false, new_lines=true))
end

# ## Interpretation

# function showblock!(io, ::ShowText, block::TextRow, obs) end
188 changes: 188 additions & 0 deletions src/Text/encodings/textpreprocessing.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
"""
EncodedTextRow{M, N} <: Block

Block for processed rows having a tuple of M categorical and
N continuous value collections.
"""
struct EncodedTextRow{M, N, T} <: Block
catcols::NTuple{M}
contcols::NTuple{N}
categorydict::T
end

function EncodedTextRow(catcols, contcols, categorydict)
EncodedTextRow{length(catcols), length(contcols)}(catcols, contcols, categorydict)
end

function checkblock(::EncodedTextRow{M, N}, x::Tuple{Vector, Vector}) where {M, N}
length(x[1]) == M && length(x[2]) == N
end


function showblock!(io, ::ShowText, block::EncodedTextRow, obs)
print(io, "EncodedTextRow(...)")
end


# ## Encoding

"""
TextPreprocessing <: Encoding

Encodes a `TextRow` by applying the following preprocessing steps:
- [`DataAugmentation.NormalizeRow`](#) (for normalizing a row of data for continuous columns)
- [`DataAugmentation.FillMissing`](#) (for filling missing values)
- [`DataAugmentation.Categorify`](#) (for label encoding categorical columns,
which can be later used for indexing into embedding matrices)
or a sequence of these transformations.
"""
struct TextPreprocessing{T} <: Encoding
tfms::T
end

TextPreprocessing(td::Datasets.TextDataset) = TextPreprocessing(gettransforms(td))

function encodedblock(::TextPreprocessing, block::TextRow)
EncodedTextRow(block.catcols, block.contcols, block.categorydict)
end

# function encode(tt::TextPreprocessing, _, block::TextRow, row)
# columns = Tables.columnnames(row)
# usedrow = NamedTuple(filter(
# x -> x[1] ∈ block.catcols || x[1] ∈ block.contcols,
# collect(zip(columns, row))
# ))
# tfmrow = DataAugmentation.apply(
# tt.tfms,
# DataAugmentation.TabularItem(usedrow, keys(usedrow))
# ).data
# catvals = collect(map(col -> tfmrow[col], block.catcols))
# contvals = collect(map(col -> tfmrow[col], block.contcols))
# (catvals, contvals)
# end


function setup(::Type{TextPreprocessing}, block::TextRow, data::TextDataset)
return TextPreprocessing(gettransforms(data, block.catcols, block.contcols))
end



# ## `blockmodel`


"""
blockmodel(inblock::TableRow{M, N}, outblock::Union{Continuous, OneHotTensor{0}}, backbone=nothing) where {M, N}

Contruct a model for tabular classification or regression. `backbone` should be a
NamedTuple of categorical, continuous, and a finalclassifier layer, with
the first two taking in batches of corresponding row value matrices.
"""

"""
blockmodel(::EncodedTableRow, ::OneHotTensor[, backbone])

Create a model for tabular classification. `backbone` should be named tuple
`(categorical = ..., continuous = ...)`. See [`TabularModel`](#) for more info.
"""
# function blockmodel(inblock::EncodedTableRow, outblock::OneHotTensor{0}, backbone)
# TabularModel(
# backbone.categorical,
# backbone.continuous,
# Dense(100, length(outblock.classes))
# )
# end


"""
blockmodel(::EncodedTableRow, ::Continuous[, backbone])

Create a model for tabular regression. `backbone` should be named tuple
`(categorical = ..., continuous = ...)`. See [`TabularModel`](#) for more info.
"""
# function blockmodel(inblock::EncodedTableRow, outblock::Continuous, backbone)
# TabularModel(
# backbone.categorical,
# backbone.continuous,
# Dense(100, outblock.size)
# )
# end


# function blockbackbone(inblock::EncodedTextRow{M, N}) where {M, N}
# embedszs = _get_emb_sz(collect(map(col->length(inblock.categorydict[col]), inblock.catcols)))
# catback = tabular_embedding_backbone(embedszs)
# contback = tabular_continuous_backbone(N)
# return (categorical = catback, continuous = contback)
# end


# ## Utilities

"""
The helper functions defined below can be used for quickly constructing a dictionary,
which will be required for creating various tabular transformations available in DataAugmentation.jl.

These functions assume that the table in the TableDataset object td has Tables.jl columnaccess interface defined.
"""
function gettransformdict(td, ::Type{DataAugmentation.NormalizeRow}, cols)
dict = Dict()
map(cols) do col
vals = skipmissing(Tables.getcolumn(td.table, col))
dict[col] = (Statistics.mean(vals), Statistics.std(vals))
end
dict
end

function gettransformdict(td, ::Type{DataAugmentation.FillMissing}, cols)
dict = Dict()
map(cols) do col
vals = skipmissing(Tables.getcolumn(td.table, col))
dict[col] = Statistics.median(vals)
end
dict
end

function gettransformdict(td, ::Type{DataAugmentation.Categorify}, cols)
dict = Dict()
map(cols) do col
vals = Tables.getcolumn(td.table, col)
dict[col] = unique(vals)
end
dict
end

"""
getcoltypes(td::Datasets.TableDataset)

Returns the categorical and continuous columns present in a `TextDataset`.
"""
function getcoltypes(td::Datasets.TextDataset)
schema = Tables.schema(td.table)

contcols = Tuple(name for (name, T) in zip(schema.names, schema.types)
if T <: Union{<:Number, <:Union{Missing, <:Number}})

catcols = Tuple(name for name in schema.names if !(name in contcols))
catcols, contcols
end

"""
gettransforms(td::Datasets.TableDataset)

Returns a composition of basic tabular transformations constructed
for the given TableDataset.
"""
function gettransforms(td::TextDataset, catcols, contcols)
normstats = gettransformdict(td, DataAugmentation.NormalizeRow, contcols)
fmvals = gettransformdict(td, DataAugmentation.FillMissing, contcols)
catdict = gettransformdict(td, DataAugmentation.Categorify, catcols)
normalize = DataAugmentation.NormalizeRow(normstats, contcols)
categorify = DataAugmentation.Categorify(catdict, catcols)
fm = DataAugmentation.FillMissing(fmvals, contcols)

return fm |> normalize |> categorify
end


gettransforms(td::TextDataset) = gettransforms(td, getcoltypes(td)...)
1 change: 1 addition & 0 deletions src/Text/makie.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# No Makie recipes yet, text is better I guess