Skip to content

Commit

Permalink
Add img_size argument
Browse files Browse the repository at this point in the history
  • Loading branch information
adrhill committed Feb 8, 2023
1 parent 95b13d9 commit 09d5be4
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 8 deletions.
5 changes: 4 additions & 1 deletion src/datasets/vision/imagenet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ Dict{String, Any} with 8 entries:
"class_names" => Vector{SubString{String}}[["tench", "Tinca tinca"], ["goldfish", "C
"metadata_path" => "/Users/funks/.julia/datadeps/ImageNet/devkit/data/meta.mat"
"n_classes" => 1000
"img_size" => (224, 224)
"wnid_to_label" => Dict("n07693725"=>932, "n03775546"=>660, "n01689811"=>45, "n0210087
julia> dataset.metadata["class_names"][y]
Expand Down Expand Up @@ -120,6 +121,7 @@ ImageNet(Tx::Type; kws...) = ImageNet(; Tx, kws...)
function ImageNet(
Tx::Type,
split::Symbol;
img_size::Tuple{Int,Int}=(224, 224),
preprocess=ImageNetReader.default_preprocess,
inverse_preprocess=ImageNetReader.default_inverse_preprocess,
dir=nothing,
Expand Down Expand Up @@ -157,9 +159,10 @@ function ImageNet(
metadata["features_dir"] = features_dir
metadata["n_observations"] = n_observations
metadata["n_classes"] = ImageNetReader.NCLASSES
metadata["img_size"] = img_size

# Create FileDataset
dataset = ImageNetReader.get_file_dataset(Tx, preprocess, features_dir)
dataset = ImageNetReader.get_file_dataset(Tx, img_size, preprocess, features_dir)
@assert length(dataset) == n_observations

targets = [
Expand Down
9 changes: 5 additions & 4 deletions src/datasets/vision/imagenet_reader/ImageNetReader.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,17 @@ import ..@lazy
@lazy import JpegTurbo = "b835a17e-a41a-41e7-81f0-2f016b05efe0"

const NCLASSES = 1000
const IMGSIZE = (224, 224)

include("preprocess.jl")

function get_file_dataset(Tx::Type{<:Real}, preprocess::Function, dir::AbstractString)
function get_file_dataset(
Tx::Type{<:Real}, img_size::Tuple{Int,Int}, preprocess::Function, dir::AbstractString
)
# Construct a function that loads images from FileDataset path,
# applies preprocessing and converts to type Tx.
function load_image(file::AbstractString)
im = JpegTurbo.jpeg_decode(RGB{Tx}, file; preferred_size=IMGSIZE)
return Tx.(preprocess(im))
im = JpegTurbo.jpeg_decode(RGB{Tx}, file; preferred_size=img_size)
return preprocess(im, img_size)
end
return FileDataset(load_image, dir, "*.JPEG")
end
Expand Down
6 changes: 3 additions & 3 deletions src/datasets/vision/imagenet_reader/preprocess.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Image preprocessing defaults for ImageNet models.

function default_preprocess(im::AbstractMatrix{<:AbstractRGB})
im = channelview(center_crop(im))
function default_preprocess(im::AbstractMatrix{<:AbstractRGB}, outsize)
im = channelview(center_crop(im, outsize))
return PermutedDimsArray(im, (3, 2, 1)) # Convert from Image.jl's CHW to Flux's WHC
end

Expand All @@ -12,7 +12,7 @@ function default_inverse_preprocess(x::AbstractArray{T,N}) where {T,N}
end

# Take rectangle of pixels of shape `outsize` at the center of image `im`
function center_crop(im::AbstractMatrix, outsize=IMGSIZE)
function center_crop(im::AbstractMatrix, outsize)
h2, w2 = div.(outsize, 2) # half height, half width of view
h_adjust, w_adjust = _adjust.(outsize)
return @view im[
Expand Down

0 comments on commit 09d5be4

Please sign in to comment.