Loading weights into TinyCUDA #6

Closed

ZakSingh opened this issue Dec 8, 2021 · 7 comments

ZakSingh commented Dec 8, 2021

Hi! I'm very excited by TinyCUDA and I'd like to test it out for an inference task on a pre-trained model. I have the network weights as a .npy file and I'd ideally like to load them into the fully fused MLP. From a quick scan of the codebase it looks like there isn't any way to load pre-computed model weights (please correct me if I'm wrong). Do you have any advice on how I could go about accomplishing this?

Tom94 (Collaborator) commented Dec 8, 2021

Hi there,

tiny-cuda-nn does let you load pre-computed model weights.

You can use Trainer::set_params or Trainer::set_params_full_precision, depending on the precision in which you have the params available.

  • Trainer::set_params expects the same precision tiny-cuda-nn was compiled for (half precision by default).
  • Trainer::set_params_full_precision takes float params and does the casting for you.

These methods expect a CPU pointer to densely laid out network parameters: first layer, followed by the hidden layers, followed by the output layer.

All in row-major memory order. (Depending on how your tensors were laid out when you were training your model, you might have to sneak in a transposition... which would mean column-major after all. Sorry about this confusion -- I don't think there's a common standard that could easily be used here.)

Lastly, note that tiny-cuda-nn does not support biases, so if you haven't already, you'll need to make sure that the pre-trained model purely uses weight matrices + activations.
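
For example, loading the weights from a flat float buffer could look like the following sketch. load_npy_floats is a hypothetical helper (use any .npy reader you like), and the exact set_params_full_precision signature may vary between versions, so check the Trainer header:

#include <stdexcept>
#include <vector>

// Hedged sketch: hand densely packed float weights to the trainer.
std::vector<float> params = load_npy_floats("weights.npy"); // hypothetical loader

// The buffer must contain exactly as many floats as the network has params.
if (params.size() != network->n_params()) {
	throw std::runtime_error{"weight buffer size does not match network->n_params()"};
}

trainer->set_params_full_precision(params.data(), params.size());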

ZakSingh (Author) commented Dec 16, 2021

Thanks for the detailed response! I've now had a chance to give it a shot, but I can't seem to get the number of params to match:

Here's the Keras summary of the TensorFlow implementation of my network:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_22 (InputLayer)        [(None, 39)]              0         
_________________________________________________________________
dense_189 (Dense)            (None, 64)                2496      
_________________________________________________________________
dense_190 (Dense)            (None, 64)                4096      
_________________________________________________________________
dense_191 (Dense)            (None, 64)                4096      
_________________________________________________________________
dense_192 (Dense)            (None, 64)                4096      
_________________________________________________________________
dense_193 (Dense)            (None, 64)                4096      
_________________________________________________________________
dense_194 (Dense)            (None, 64)                4096      
_________________________________________________________________
dense_195 (Dense)            (None, 64)                4096      
_________________________________________________________________
dense_196 (Dense)            (None, 64)                4096      
_________________________________________________________________
dense_197 (Dense)            (None, 4)                 256       
=================================================================
Total params: 31,424
Trainable params: 31,424
Non-trainable params: 0
_________________________________________________________________

And here is my config for TinyCudaNN:

const uint8_t n_frequencies = 6;
const uint32_t n_input_dims = 3 + 3 * 2 * n_frequencies; // = 39 (same as TensorFlow)
const uint32_t n_output_dims = 4;

json config = {
	{"loss", {{"otype", "L2"}}},
	{"optimizer", {
		{"otype", "Adam"},
		{"learning_rate", 1e-5},
		{"beta1", 0.9f},
		{"beta2", 0.99f},
	}},
	{"encoding", {{"otype", "Identity"}, {"scale", 1.0}, {"offset", 0.0}}},
	{"network", {
		{"otype", "FullyFusedMLP"},
		{"n_neurons", 64},
		{"n_hidden_layers", 8},
		{"activation", "ReLU"},
		{"output_activation", "None"},
	}},
};

When I call trainer->set_params_full_precision with my weights, I get an error:

Trainer: Initializing 32768 params and resetting training.
Uncaught exception: Can't set params because CPU buffer has the wrong size.

So my TensorFlow implementation has 31,424 parameters, while the TinyCudaNN version has 32,768.

When I run the network->layer_sizes() method, I get the following output:

64, 48
64, 64
64, 64
64, 64
64, 64
64, 64
64, 64
64, 64
16, 64

It looks like my problem stems from the input and output layers. I'm not sure where 48 is coming from (I would have thought 39, as that's the size of my input layer). Similarly, I would expect the 16 in the final layer to be 4, as that's my specified number of output neurons. Is there anything going on behind the scenes that could be causing this discrepancy?

Tom94 (Collaborator) commented Dec 16, 2021

Hi there, many apologies, I totally forgot to explain the following detail: the hardware matrix multipliers (TensorCores) operate on 16x16 matrix chunks, so the input and output layers are padded to the nearest multiple of 16.

For the input layer (after encoding): the padded dimensions get a value of 1 (not zero) to help the first layer of the neural network implicitly learn a bias term.

For the output layer: any padded dimensions are trimmed away when calling network->inference().

So the following needs to change on the Keras side:

  • the input (after encoding) needs to be padded with ones until it's 48 dimensions wide.
  • the output needs to have a width of 16, where the last 12 dimensions are ignored in the optimization.
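
If you're loading the pre-trained weights rather than retraining, the corresponding change is to zero-pad the weight matrices themselves to the sizes reported by layer_sizes(), as in this hedged sketch (zero weights against the constant-1 input padding act as a zero bias, and the zero output rows are trimmed at inference; pad_weights is a hypothetical helper):

#include <cstdint>
#include <vector>

// Hedged sketch: zero-pad a row-major (rows x cols) weight block to
// (padded_rows x padded_cols). E.g. first layer 64x39 -> 64x48 and
// output layer 4x64 -> 16x64 (transpose from Keras's (in, out) layout first).
std::vector<float> pad_weights(const std::vector<float>& w,
                               uint32_t rows, uint32_t cols,
                               uint32_t padded_rows, uint32_t padded_cols) {
	std::vector<float> padded(padded_rows * padded_cols, 0.0f);
	for (uint32_t r = 0; r < rows; ++r) {
		for (uint32_t c = 0; c < cols; ++c) {
			padded[r * padded_cols + c] = w[r * cols + c];
		}
	}
	return padded;
}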

By the way, I noticed that in your setup you encode the first 3 dimensions with the Frequency encoding and pass 3 extra dimensions through. You can achieve this behavior in tiny-cuda-nn with the following Composite encoding:

	"otype": "Composite",
	"nested": [
		{
			"n_dims_to_encode": 3, // Spatial dims
			"otype": "Frequency",
			"n_frequencies": 6
		},
		{
			// Number of remaining linear dims is automatically derived
			"otype": "Identity"
		}
	]

(You can reverse the two nested encodings if it's the first 3 dimensions you would like to pass through.)
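
Dropped into the C++ config, that could look like the following sketch. Note that n_input_dims is then the raw 6 (3 spatial + 3 extra); the encoding expands them to 3*2*6 + 3 = 39 dims, which tiny-cuda-nn pads to 48:

// Hedged sketch: the Composite encoding in nlohmann::json initializer syntax.
json encoding = {
	{"otype", "Composite"},
	{"nested", {
		{{"n_dims_to_encode", 3}, {"otype", "Frequency"}, {"n_frequencies", 6}},
		{{"otype", "Identity"}},
	}},
};
// Then replace the Identity encoding in the config: config["encoding"] = encoding;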

Tom94 (Collaborator) commented Dec 16, 2021

Another implementation detail: CutlassMLP and CutlassResnet pad to a multiple of 8 rather than 16, because they make different use of the TensorCores. It's probably best to make extra sure by looking at layer_sizes(), as you already figured out.
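
For instance, here's a hedged sketch of deriving the expected parameter count from layer_sizes() (assuming it returns the (output, input) size pairs printed above):

#include <cstddef>
#include <iostream>

// Hedged sketch: sum the padded layer sizes to get the exact number of
// weights tiny-cuda-nn expects; it should match network->n_params().
size_t expected = 0;
for (const auto& size : network->layer_sizes()) {
	expected += (size_t)size.first * size.second;
}
std::cout << "expected params: " << expected << std::endl; // 32768 in your case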

I'll have a think about how to expose all this information more elegantly in the future...

ZakSingh (Author) commented

Thanks! One step closer... I've got my network fixed and the weights loaded, but now I'm hitting some issues with network->inference. Namely, I'm getting quite a few CUDA errors:

Could not free memory: CUDA Error: cudaFree(rawptr) failed with error an illegal memory access was encountered
Could not free memory: CUDA Error: cudaFree(rawptr) failed with error an illegal memory access was encountered
Could not free memory: CUDA Error: cudaFree(rawptr) failed with error an illegal memory access was encountered
CUDA Error: cudaEventDestroy(m_training_splitk_events[i]) failed with error an illegal memory access was encountered
CUDA Error: cudaStreamDestroy(m_training_splitk_streams[i]) failed with error an illegal memory access was encountered
[... the cudaEventDestroy/cudaStreamDestroy pair repeats eight more times ...]
Could not free memory: CUDA Error: cudaFree(rawptr) failed with error an illegal memory access was encountered
[... repeated three more times ...]
Uncaught exception: CUDA Error: cudaMemcpy(host_data.data(), prediction.data(), host_data.size() * sizeof(float), cudaMemcpyDeviceToHost) failed with error an illegal memory access was encountered

I've checked, and the cudaFree(rawptr) errors are coming from the call in virtual ~GPUMemory() in gpu_memory.h.

I'm not sure what's going on with all the m_training_splitk_events and m_training_splitk_streams errors, as I don't load any training data or attempt to train the network (I'm only using it for inference).

I've tried running with compute-sanitizer to see what's going on, and the logs are filled with thousands of out-of-bounds access errors from tcnn::frequency_encoding. (I switched to using the frequency encoding for the entire input.) I don't think they're the cause of the above errors, but it may be a problem. For example:

========= Invalid __global__ read of size 4 bytes
=========     at 0x3e0 in void tcnn::frequency_encoding<__half>(unsigned int, unsigned int, unsigned int, unsigned int, tcnn::PitchedPtr<const float>, tcnn::PitchedPtr<T1>, float *)
=========     by thread (37,0,0) in block (1,0,0)
=========     Address 0x563956bd6144 is out of bounds
=========     Saved host backtrace up to driver entry point at kernel launch time
=========     Host Frame:cuLaunchKernel [0x2efabe]
=========                in /usr/lib/x86_64-linux-gnu/libcuda.so.1
=========     Host Frame: [0x1fa33b]
=========                in /home/ubuntu/l352/tinycudanerf/tiny-cuda-nn/build/./mlp_tinynerf
...

However, none of the above errors is fatal, and the program continues execution. But when I try to copy the prediction matrix from the device to the CPU to read it, I get another illegal memory access error, this time fatal:

GPUMatrix<float> prediction(n_output_dims, image_width * image_height);
...
std::vector<float> host_data(image_width * image_height * 4);
CUDA_CHECK_THROW(cudaMemcpy(host_data.data(), prediction.data(), host_data.size() * sizeof(float), cudaMemcpyDeviceToHost));
Uncaught exception: CUDA Error: cudaMemcpy(host_data.data(), prediction.data(), host_data.size() * sizeof(float), cudaMemcpyDeviceToHost) failed with error an illegal memory access was encountered

Do you have any advice on how I should go about debugging this? Could it be caused by running out of memory? (I'm running on a T4 GPU.) Thanks again for the help!

Tom94 (Collaborator) commented Dec 17, 2021

I think you identified the root cause correctly using compute-sanitizer. The error is definitely fatal -- which is why all the following cuda* calls fail (including your memcpy).

The reason execution continues at first is that destructors aren't permitted to throw exceptions, so they just print the errors they encounter (to allow the rest of the program to clean up).

I recommend double-checking the inputs to the encoding/NetworkWithInputEncoding; something is likely wrong there. If it were insufficient memory, you'd instead get an appropriate error at allocation time.
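
For reference, here's a minimal sketch of a device-resident input for inference (batch_size and host_input are placeholders, and the exact inference overload may differ by version):

// Hedged sketch: the input matrix handed to inference must live in GPU
// memory. If the data only exists on the host, kernels dereferencing it
// (e.g. the encoding) can raise exactly these illegal memory accesses.
GPUMatrix<float> input(n_input_dims, batch_size);
CUDA_CHECK_THROW(cudaMemcpy(input.data(), host_input.data(),
                            host_input.size() * sizeof(float),
                            cudaMemcpyHostToDevice));
network->inference(input, prediction);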

ZakSingh (Author) commented

I had forgotten to copy the input vector to the device. Silly silly... Thanks again for the help! Closing this for now.
