-
Notifications
You must be signed in to change notification settings - Fork 0
/
cifar_10.jl
39 lines (34 loc) · 1009 Bytes
/
cifar_10.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
"""
    load_cifar_10(n::Union{Nothing,Int}=nothing)

Load the CIFAR-10 training split.

Images are flattened so that `X` has one column per observation, and values are rescaled
with `x ↦ 2x - 1` (intended to map `[0, 1]` onto `[-1, 1]`). Labels are returned as a plain
vector.

# Arguments
- `n`: if given, undersample the split to `n` observations via `subsample`; `nothing`
  (the default) keeps every observation.

# Returns
A tuple `(X, y)` of the feature matrix and the label vector.

# Throws
- `ArgumentError` if `n` is given and is not positive.
"""
function load_cifar_10(n::Union{Nothing,Int}=nothing)
    return _load_cifar_10_split(:train, n)
end

"""
    load_cifar_10_test(n::Union{Nothing,Int}=nothing)

Load the CIFAR-10 test split, preprocessed exactly like the training split returned by
`load_cifar_10`.

# Arguments
- `n`: if given, undersample the split to `n` observations via `subsample`; `nothing`
  (the default) keeps every observation.

# Returns
A tuple `(X, y)` of the feature matrix and the label vector.

# Throws
- `ArgumentError` if `n` is given and is not positive.
"""
function load_cifar_10_test(n::Union{Nothing,Int}=nothing)
    return _load_cifar_10_split(:test, n)
end

"""
    _load_cifar_10_split(split::Symbol, n::Union{Nothing,Int})

Load one CIFAR-10 split (`:train` or `:test`), flatten and rescale its features, and
optionally undersample it to `n` observations. Shared by the public loaders so that both
splits always receive identical preprocessing.
"""
function _load_cifar_10_split(split::Symbol, n::Union{Nothing,Int})
    # Validate up front so a bad `n` fails before the (large) dataset is read.
    if !isnothing(n) && n < 1
        throw(ArgumentError("n must be a positive number of samples, got $n"))
    end

    X, y = MLDatasets.CIFAR10(split)[:]  # `[:]` yields (features, targets) for all observations

    # Reshape to one column per observation.
    X = Flux.flatten(X)
    # Rescale to [-1, 1] in place: the array comes straight from a dataset object that is
    # discarded here, so overwriting it avoids allocating a second full-size copy.
    @. X = 2 * X - 1

    # Materialize the labels as a plain, independently owned vector.
    y = collect(y)

    if !isnothing(n)
        X, y = subsample(X, y, n)
    end
    return (X, y)
end