Skip to content

Commit

Permalink
fixed sizes returned by MNIST to include batchSize
Browse files Browse the repository at this point in the history
  • Loading branch information
abeschneider committed Jan 9, 2017
1 parent e6da841 commit 4a65ec6
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions Sources/DataLoader/mnist.swift
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,8 @@ public class MNISTData: Sequence, IteratorProtocol, Shuffable, SupervisedData {
public var index:Int
public var batchSize:Int

public var imageSize:Extent { return Extent(28, 28) }
public var labelSize:Extent { return Extent(1) }
public var imageSize:Extent { return Extent(batchSize, 28, 28) }
public var labelSize:Extent { return Extent(batchSize) }

public var count:Int { return images.shape[0] }

Expand Down Expand Up @@ -201,8 +201,8 @@ public class MNISTData: Sequence, IteratorProtocol, Shuffable, SupervisedData {
let end:Int = index+batchSize
let sz:Int = end >= self.count ? self.count-1 : end

let image = Tensor<I>(Extent(sz, imageSize[0], imageSize[1]))
let label = Tensor<I>(Extent(sz))
let image = Tensor<I>(imageSize)
let label = Tensor<I>(labelSize)
for i in 0..<sz {
let index = indices[i]
image[i, all, all] = images[index, all, all]
Expand Down

0 comments on commit 4a65ec6

Please sign in to comment.