DataLoader.nobs could make use of partial flag to return final number of samples being used ? #76

Closed
Karthik-d-k opened this issue Apr 14, 2022 · 3 comments


@Karthik-d-k
Contributor

Karthik-d-k commented Apr 14, 2022

I recently started learning Flux for deep learning and came across this unusual behavior of the DataLoader object.
If we create a DataLoader object as follows:

julia> dl = DataLoader(rand(Int8, 10, 64), batchsize=30, partial=true)
DataLoader{Matrix{Int8}, Random._GLOBAL_RNG}(Int8[-73 -49 … 65 57; 82 -99 … -125 -72; … ; -109 23 … 14 -68; -60 -90 … -121 70], 30, 64, true, false, Random._GLOBAL_RNG())

julia> dl.nobs
64

We get 64 as the total number of samples that the DataLoader will use, which is correct.

But when we set partial=false, dl.nobs is still set to the same value, 64.

In that case I would expect dl.nobs to be 60, because the last 4 samples are thrown away (the final, incomplete mini-batch is dropped).
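The arithmetic behind the expected value of 60 can be checked with a few lines of plain Julia (no Flux or MLUtils needed). The variable names here are just for illustration:

```julia
# Worked arithmetic for the example above: 64 observations, batchsize = 30
nobs = 64
batchsize = 30

full_batches = nobs ÷ batchsize   # number of complete mini-batches
used = batchsize * full_batches   # samples consumed when partial=false
dropped = nobs % batchsize        # samples left in the incomplete last batch

println("partial=true  uses $nobs samples")
println("partial=false uses $used samples, dropping $dropped")
```

With these numbers this gives 2 full batches, 60 samples used and 4 dropped.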

As I couldn't find any docs for dl.nobs, this is my current understanding; please correct me if I'm missing something obvious here.

If my understanding is correct, two changes could be made in the main/src/dataloader.jl file:

function DataLoader(data; batchsize=1, shuffle=false, partial=true, rng=GLOBAL_RNG)
    batchsize > 0 || throw(ArgumentError("Need positive batchsize"))
    nobs = numobs(data)
+   partial || (nobs -= nobs % batchsize)     # subtract last mini-batch samples when `partial=false`
    if nobs < batchsize
        @warn "Number of observations less than batchsize, decreasing the batchsize to $nobs"
        batchsize = nobs
    end
    DataLoader(data, batchsize, nobs, partial, shuffle, rng)
end
function Base.length(d::DataLoader)
    n = d.nobs / d.batchsize
-   d.partial ? ceil(Int, n) : floor(Int, n)  # removing this line as we would get correct `n`
end
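For reference, the `ceil`/`floor` choice in `Base.length` corresponds to ceiling vs. floor integer division. Note that in Julia `/` on two integers always returns a `Float64` (e.g. `64 / 30 ≈ 2.13`, and even `60 / 30 == 2.0`), so some conversion to `Int` is still needed for `length`. A minimal sketch of the batch count using only integer operations; `num_batches` is a hypothetical helper, not part of Flux or MLUtils:

```julia
# Hypothetical helper (not a Flux/MLUtils API): mini-batches per epoch,
# computed with integer division so the result is always an Int.
num_batches(nobs::Integer, batchsize::Integer, partial::Bool) =
    partial ? cld(nobs, batchsize) :  # ceiling division keeps the short last batch
              nobs ÷ batchsize        # floor division drops it

println(num_batches(64, 30, true))   # two full batches plus one batch of 4
println(num_batches(64, 30, false))  # the trailing 4 samples are dropped
```

Here `num_batches(64, 30, true)` is 3 and `num_batches(64, 30, false)` is 2.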

I'm new to Julia and looking forward to learning and improving 😃🤞

@darsnack
Member

dl.nobs is not an officially supported API; it's just an internal detail of how DataLoader is implemented. In this case, it is used to keep track of the number of observations in the original data container. This is why you cannot find any docs on it; it's not meant to be used by users.

In general, for Julia packages, accessing fields of a struct by name is not considered part of an API unless explicitly documented. I don't think we would want to introduce dl.nobs as an API in this case. Perhaps there could be an alternate function we could introduce here that returns this value.

@Karthik-d-k
Contributor Author

Thanks for the clarification 😊
I was just trying to find an easy way to get the total number of samples used by the DataLoader, and found dl.nobs, which I thought would be a perfect fit.
As you mentioned, an alternate function that does this job sounds perfectly fine to me.

@CarloLucibello
Member

CarloLucibello commented Jun 28, 2022

In the DataLoader docstring we have

The original data is preserved in the `data` field of the DataLoader.

so one way to get the total number of samples in the partial=true case is

numobs(dl.data)

For partial=false the number of samples consumed in an epoch is

 dl.batchsize * (numobs(dl.data) ÷ dl.batchsize)

and if shuffle=true different epochs leave out different samples.
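The formula above, and the remark about shuffling, can be illustrated in plain Julia. This is a sketch only: `samples_per_epoch` and `left_out_indices` are hypothetical helpers (not MLUtils APIs), and the shuffle part assumes each epoch draws a fresh permutation of the observation indices and then drops the trailing incomplete batch:

```julia
using Random

# Hypothetical helper (not an MLUtils API): samples consumed per epoch,
# following the formula above.
samples_per_epoch(nobs::Integer, batchsize::Integer, partial::Bool) =
    partial ? nobs : batchsize * (nobs ÷ batchsize)

# Assumed shuffle=true, partial=false behavior: permute 1:nobs each epoch,
# then skip whatever falls into the incomplete last batch.
function left_out_indices(rng::AbstractRNG, nobs::Integer, batchsize::Integer)
    perm = randperm(rng, nobs)
    used = samples_per_epoch(nobs, batchsize, false)
    return sort(perm[used+1:end])   # indices skipped in this epoch
end

rng = MersenneTwister(42)
epoch1 = left_out_indices(rng, 64, 30)
epoch2 = left_out_indices(rng, 64, 30)
println("epoch 1 skips ", epoch1)
println("epoch 2 skips ", epoch2)
```

Each epoch skips 4 of the 64 observations, but because the permutation is redrawn, the skipped indices generally differ from one epoch to the next.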

I don't think an interface function like num_samples_consumed_per_epoch would be of general use, so it may not be worth exposing. I'll close the issue for now, but if anyone thinks we should add this API to MLUtils.jl, please ping for reopening.
