Convolutional network slower than tensorflow on CPU #2350

Open
alerem18 opened this issue Oct 27, 2023 · 13 comments

Comments

@alerem18

```julia
using Flux
using MLDatasets: MNIST, CIFAR10, CIFAR100
using Flux: logitcrossentropy, setup, Adam, train!
using Flux.OneHotArrays: onehotbatch, onecold
using Statistics: mean
using Flux.MLUtils: DataLoader
using ProgressBars: tqdm, set_postfix
using Flux.Zygote: ignore

# ------------------------ DATA ---------------------
TRAIN = MNIST(split=:train)
TEST = MNIST(split=:test)

x_train, y_train = TRAIN.features, TRAIN.targets
x_test, y_test = TEST.features, TEST.targets

# Add a channel dimension: (28, 28, N) -> (28, 28, 1, N)
x_train = Flux.unsqueeze(x_train, dims=3)
x_test = Flux.unsqueeze(x_test, dims=3)
y_train_encoded = onehotbatch(y_train, 0:9)

TRAIN_LOADER = DataLoader((x_train, y_train_encoded); batchsize=128, shuffle=true)
TEST_LOADER = DataLoader((x_test, y_test); batchsize=128, shuffle=false)

# -------------------- MODEL ------------------------------
model = Flux.@autosize (28, 28, 1, 1) Chain(
    Conv((3, 3), 1=>32, relu),
    MaxPool((2, 2)),
    Conv((3, 3), 32=>64, relu),
    MaxPool((2, 2)),
    Flux.flatten,
    Dropout(0.5),
    Dense(_, 10)
)

optimizer = setup(Adam(0.001), model)

# --------------------- HELPER ----------------------------------
function accuracy(m, LOADER)
    corrects = 0
    total = 0
    for (X, Y) ∈ LOADER
        total += length(Y)
        corrects += sum(onecold(m(X), 0:9) .== Y)
    end

    return corrects / total
end

# ------------------- TRAIN ----------------------------------------
function train_loop(model, optimizer, train_loader, test_loader; epochs=5)
    for epoch ∈ 1:epochs
        iter = tqdm(train_loader)
        total = 0
        corrects = 0
        for (X, Y) ∈ iter
            train!(model, [(X, Y)], optimizer) do m, features, labels
                predicted = m(features)
                b_size = size(features)[end]
                ignore() do 
                    corrects += sum(onecold(predicted, 0:9) .== onecold(labels, 0:9))
                    total += b_size
                end
                # reuse the forward pass above instead of calling m(features) a second time
                logitcrossentropy(predicted, labels)
            end
            set_postfix(iter, accuracy=corrects / total)
        end

        val_accuracy = accuracy(model, test_loader)
        @info "Epoch $epoch/$epochs | Accuracy : $val_accuracy"
    end
end

train_loop(model, optimizer, TRAIN_LOADER, TEST_LOADER)
```

Each epoch in Flux takes about 1 minute and 10 seconds, while each epoch in TensorFlow takes about 15 seconds.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

# Model / data parameters
num_classes = 10
input_shape = (28, 28, 1)

# Load the data and split it between train and test sets
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Scale images to the [0, 1] range
x_train = x_train.astype("float32") / 255
x_test = x_test.astype("float32") / 255
# Make sure images have shape (28, 28, 1)
x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)
print("x_train shape:", x_train.shape)
print(x_train.shape[0], "train samples")
print(x_test.shape[0], "test samples")

# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

model = keras.Sequential(
    [
        keras.Input(shape=input_shape),
        layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Flatten(),
        layers.Dropout(0.5),
        layers.Dense(num_classes, activation="softmax"),
    ]
)

model.summary()

batch_size = 128
epochs = 15

model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])

model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.1)
```

@darsnack
Member

darsnack commented Oct 28, 2023

That's not the intended use for Flux.train!. This function is meant to iterate over an entire epoch, not a single batch. Try writing your loop as

```julia
function train_loop(model, optimizer, train_loader, test_loader; epochs=5)
    for epoch ∈ 1:epochs
        iter = tqdm(train_loader)
        total = 0
        corrects = 0
        for (X, Y) ∈ iter
            grads = Flux.gradient(model) do m
                predicted = m(X)
                ignore() do 
                    b_size = size(X)[end]
                    corrects += sum(onecold(predicted, 0:9) .== onecold(Y, 0:9))  # the labels are Y here
                    total += b_size
                end
                logitcrossentropy(predicted, Y)
            end
            # grads[1] is the gradient w.r.t. `model` (Julia indexing starts at 1)
            optimizer, model = Flux.update!(optimizer, model, grads[1])
            set_postfix(iter, accuracy=corrects / total)
        end

        val_accuracy = accuracy(model, test_loader)
        @info "Epoch $epoch/$epochs | Accuracy : $val_accuracy"
    end
end
```
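For reference, a minimal sketch of the epoch-level usage described above: one `Flux.train!` call consumes the whole loader, and the loss closure receives the model plus one `(X, Y)` batch. It assumes Flux 0.14's explicit API (`Flux.setup` and `train!(loss, model, data, opt_state)`); random stand-in data replaces MNIST so the snippet runs on its own, and the explicit `Dense(1600 => 10)` stands in for `@autosize`.

```julia
using Flux
using Flux: logitcrossentropy, onehotbatch
using Flux.MLUtils: DataLoader

# Random stand-in data (assumption): 512 single-channel 28×28 images with labels 0:9
x = rand(Float32, 28, 28, 1, 512)
y = onehotbatch(rand(0:9, 512), 0:9)
loader = DataLoader((x, y); batchsize=128, shuffle=true)

# Same architecture as the issue; 28 -> 26 -> 13 -> 11 -> 5, so 5×5×64 = 1600 features
model = Chain(
    Conv((3, 3), 1 => 32, relu), MaxPool((2, 2)),
    Conv((3, 3), 32 => 64, relu), MaxPool((2, 2)),
    Flux.flatten, Dropout(0.5), Dense(1600 => 10),
)
opt_state = Flux.setup(Adam(1e-3), model)

for epoch in 1:2
    # A single train! call iterates over every (X, Y) batch in `loader`: one full epoch
    Flux.train!(model, loader, opt_state) do m, X, Y
        logitcrossentropy(m(X), Y)
    end
end
```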

@alerem18
Author

I did that already: same speed, even a little slower.

@mcabbott mcabbott changed the title from "Flux is Too Slow" to "Convolutional network slower than tensorflow on CPU" on Oct 29, 2023
@mcabbott
Member

My guess is that this is NNlib's CPU implementations of Conv etc. being sub-optimal. That's the target of e.g. FluxML/NNlib.jl#540, and seeing whether that PR speeds up this example might be helpful. (And if it does, finding a way to push that PR forwards.)

Otherwise, isolating exactly which operations are slower would be more helpful than overall times. Xref earlier issue about the same thing #2300
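A minimal sketch of timing the convolution kernels in isolation: NNlib's forward `conv` next to its two backward kernels, `∇conv_data` (gradient with respect to the input) and `∇conv_filter` (gradient with respect to the weights). The shapes assume the first layer of the issue's model at batch size 128, and the upstream gradient is a stand-in array of ones. Running the same script against a checkout of FluxML/NNlib.jl#540 would be one way to see whether that PR helps here.

```julia
using Flux.NNlib: conv, ∇conv_data, ∇conv_filter, DenseConvDims
using BenchmarkTools

# First conv layer of the issue's model: 28×28×1 input, 3×3 kernel, 1 => 32 channels, batch 128
x = rand(Float32, 28, 28, 1, 128)
w = rand(Float32, 3, 3, 1, 32)
cdims = DenseConvDims(x, w)

y  = conv(x, w, cdims)       # forward output, 26×26×32×128
dy = ones(Float32, size(y))  # stand-in for the gradient arriving from the next layer

@btime conv($x, $w, $cdims)           # forward pass
@btime ∇conv_data($dy, $w, $cdims)    # backward w.r.t. the input
@btime ∇conv_filter($x, $dy, $cdims)  # backward w.r.t. the weights
```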

@alerem18
Author

Will there be any updates?

@ToucheSir
Member

Have you seen the linked PR at FluxML/NNlib.jl#540? Other than contributing performance improvements to NNlib itself, best thing would be to do some benchmarking of what the bottlenecks in the Julia code are with a profiler. Ideally you could narrow it down to 1-2 types of layers which could be compared directly against their equivalents in PyTorch.
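A minimal profiling sketch along those lines, using Julia's built-in `Profile` standard library. It assumes a random stand-in batch shaped like one MNIST batch of 128 and a slightly reduced version of the issue's model (no `Dropout`); it profiles repeated forward+backward calls after a warm-up, so compilation time is excluded.

```julia
using Flux
using Flux: logitcrossentropy, onehotbatch
using Profile

# Stand-in batch (assumption): 128 random 28×28 images with random labels
x = rand(Float32, 28, 28, 1, 128)
y = onehotbatch(rand(0:9, 128), 0:9)

# The issue's model without Dropout
model = Chain(
    Conv((3, 3), 1 => 32, relu), MaxPool((2, 2)),
    Conv((3, 3), 32 => 64, relu), MaxPool((2, 2)),
    Flux.flatten, Dense(1600 => 10),
)

loss(m, x, y) = logitcrossentropy(m(x), y)

Flux.gradient(m -> loss(m, x, y), model)  # warm-up so compilation is not profiled

Profile.clear()
@profile for _ in 1:50
    Flux.gradient(m -> loss(m, x, y), model)
end

# Flat listing sorted by sample count: entries with the most samples are the hot spots
Profile.print(format=:flat, sortedby=:count, mincount=10)
```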

@alerem18
Author

Whatever it is, it's related to the backward pass; the forward pass in Flux is already faster than PyTorch, or at least the same speed.

@ToucheSir
Member

That's why I asked to narrow it down. If you can find which specific layers are slower on the backwards path and provide a MWE demonstrating that, then we have something to work with.
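A sketch of such a per-layer comparison, under assumptions: it walks through a reduced version of the issue's model (no `Dropout`), feeds each layer the activation it would actually receive for a random 128-image batch, and reports the minimum time for the forward pass next to forward+backward via `Flux.gradient`. The helper name `time_layers` is hypothetical.

```julia
using Flux
using BenchmarkTools

# The issue's model without Dropout; 5×5×64 = 1600 features reach the Dense layer
model = Chain(
    Conv((3, 3), 1 => 32, relu), MaxPool((2, 2)),
    Conv((3, 3), 32 => 64, relu), MaxPool((2, 2)),
    Flux.flatten, Dense(1600 => 10),
)

# Hypothetical helper: times each layer on the activation it receives inside the chain
function time_layers(model, x)
    h = x
    for (i, layer) in enumerate(model.layers)
        fwd = @belapsed $layer($h)  # minimum time in seconds
        # Forward + backward: gradient w.r.t. the layer (its parameters, if any) and its input
        bwd = @belapsed Flux.gradient((l, a) -> sum(l(a)), $layer, $h)
        println("layer $i ($layer): forward ", round(fwd * 1e6; digits=1), " μs, ",
                "forward+backward ", round(bwd * 1e6; digits=1), " μs")
        h = layer(h)  # feed this layer's output to the next one
    end
    return h
end

out = time_layers(model, rand(Float32, 28, 28, 1, 128))
```

Running the same shapes through the corresponding PyTorch modules with `.backward()` would give the layer-by-layer comparison against PyTorch.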

@aminaqi
Copy link

aminaqi commented Dec 19, 2023

MWE

Here are CPU tests; I've not tested with GPU.

FeedForward Flux:

```julia
using Flux
using BenchmarkTools

m = Conv((3, 3), 1 => 16; stride=(2, 2), pad=1)

A = Float32.(randn(28, 28, 1, 100))

# compile for the first time
m(A)

@btime m(A)
```

753.000 μs (76 allocations: 2.44 MiB)

FeedForward Pytorch:

```python
import torch
import torch.nn as nn

m = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=(3, 3), stride=(2, 2), padding=1)
A = torch.randn((100, 1, 28, 28))
%timeit m(A)
```

172 µs ± 7.32 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

Flux is significantly slower (almost 6 times) than PyTorch on CPU! Using the same approach for a Dense layer, PyTorch is 1.7 times faster than Flux. RNNs in Flux are also significantly slower than PyTorch (about 6 times), just like the CNN, since we need to loop over sequences.

@ToucheSir
Member

@aminaqi that's a different issue, namely FluxML/NNlib.jl#234. As mentioned in that issue and the linked Discourse discussion, make sure you're starting Julia with multiple threads and using MKL for a proper apples-to-apples comparison with PyTorch.
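A minimal check of that setup before benchmarking, assuming MKL.jl is installed and Julia 1.7 or newer (where loading MKL.jl switches the BLAS backend through libblastrampoline). It only reports the thread counts and the active BLAS library; it does not change any Flux settings.

```julia
# Start Julia with several threads first, e.g. `julia --threads=6` (or JULIA_NUM_THREADS=6)
using LinearAlgebra
using MKL  # assumption: MKL.jl is installed; loading it switches the BLAS backend to MKL

println("Julia threads: ", Threads.nthreads())
println("BLAS threads:  ", BLAS.get_num_threads())
println("BLAS config:   ", BLAS.get_config())
```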

For this issue, it's not clear where the exact slowdown(s) come from. What I'm sure of is that it can't be solely the conv forward pass, which is what you're benchmarking.

PS: It looks like the formatting on your comments got messed up? Every one quotes the entirety of the one before it, and it probably shouldn't.

@aminaqi
Copy link

aminaqi commented Dec 21, 2023

I've started Julia with 6 threads. Even with multiple threads it's still significantly slower than PyTorch, and that's only the forward pass; there's a slowdown on the backward pass too, which makes Flux about 10 times slower than PyTorch or TensorFlow.
This applies not only to Conv but also to RNNs.

@ToucheSir
Member

Are you seeing Julia be 10x slower on the forward and backwards pass, for CNNs and RNNs, against PyTorch and TensorFlow? I'm pretty sure we are slower on all of those, but 10x for all of them would not be expected. If that's really what you're seeing, I'd recommend starting a Discourse thread with some MWEs for the various benchmarks and linking back to that here. It's possible that Flux itself is only a small part of the issue there, and Discourse will allow more folks to weigh in on what other parts of your code may be contributing (only Flux maintainers really follow this issue tracker).

Either way, the performance gap being discussed in this issue already has a reasonable benchmark. It just needs to be narrowed down to a couple of layers and/or profiled so we can see what the bottlenecks are to take action on them. If nobody has bandwidth to do that, then I'm not sure there's much else to discuss here.
