PyTorch implementation of PredNet.

Original Repository

This repo was first created by leido and was a helpful starting point. This implementation fixes known issues in leido's code, such as blurry, black-and-white predictions.

Here's an example plot generated using prednet_relu_bug with default hyperparameters:

This implementation includes features not present in the original code, such as the ability to toggle peephole connections on or off, switch between tied and untied bias weights, and switch between multiplicative and subtractive gating as developed by Costa et al.
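
For intuition, here is a minimal sketch of how the two gating modes differ inside an LSTM-style cell update, assuming the subtractive variant follows the subLSTM formulation of Costa et al.; the repo's ConvLSTMCell applies the same idea with convolutional gates, so this is an illustration rather than the actual code.

```python
import torch

def cell_update(i, f, o, g, c_prev, gating_mode='mul'):
    # i, f, o: sigmoid gate activations; g: candidate input
    # (tanh for 'mul'; sigmoid for 'sub' in the subLSTM formulation).
    # Illustrative only -- not the repo's ConvLSTMCell code.
    if gating_mode == 'mul':
        # Standard (multiplicative) LSTM gating.
        c = f * c_prev + i * g
        h = o * torch.tanh(c)
    else:
        # Subtractive gating (subLSTM, Costa et al.): gates subtract
        # rather than multiply.
        c = f * c_prev + g - i
        h = torch.sigmoid(c) - o
    return h, c
```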

Details

"Deep Predictive Coding Networks for Video Prediction and Unsupervised Learning"(https://arxiv.org/abs/1605.08104)

The PredNet is a deep recurrent convolutional neural network that is inspired by the neuroscience concept of predictive coding (Rao and Ballard, 1999; Friston, 2005).
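
The central predictive-coding operation is the error unit: at each layer, the positive and negative differences between the target A and the prediction A_hat are rectified separately and concatenated along the channel dimension (Lotter et al., 2016). A minimal sketch, assuming tensors of shape (batch, channels, height, width):

```python
import torch
import torch.nn.functional as F

def prediction_error(A, A_hat):
    # Rectified positive and negative prediction errors, stacked
    # channel-wise, as described in the PredNet paper.
    return torch.cat([F.relu(A - A_hat), F.relu(A_hat - A)], dim=1)
```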

The original paper's code is written in Keras. Examples and the project website can be found here.

In leido's code, ConvLSTMCell is borrowed from here.

However, we significantly revamped this module with a proper implementation of peephole connections, gating options, and a more readable style.
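
As background, a peephole connection simply lets a gate depend on the cell state as well as the input and hidden state. A hedged sketch of a single gate (conv_x, conv_h, and w_c are illustrative names, not the repo's; the actual ConvLSTMCell may organize this differently):

```python
import torch

def gate(x, h, c_prev, conv_x, conv_h, w_c, peephole=True):
    # Pre-activation from the input and hidden state
    # (conv_x, conv_h: nn.Conv2d layers).
    pre = conv_x(x) + conv_h(h)
    if peephole:
        # Peephole term: the gate also "sees" the previous cell state
        # through a learned elementwise weight w_c.
        pre = pre + w_c * c_prev
    return torch.sigmoid(pre)
```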

Training, Validation, and Testing

Training a PredNet model is done via kitti_train.py. Feel free to adjust the following training and model hyperparameters within the script (a sketch of the defaults follows the lists below):

Training parameters

  • num_epochs: default- 150
  • batch_size: default- 4
  • lr: learning rate, default- 0.001
  • nt: length of video sequences, default- 10
  • n_train_seq: number of video sequences per training epoch, default- 500
  • n_val_seq: number of video sequences used for validation, default- 100

Model hyperparameters

  • loss_mode: 'L_0' or 'L_all', default- 'L_0'
  • peephole: toggles inclusion of peephole connections within the ConvLSTM, default- False
  • lstm_tied_bias: toggles the tying of biases within the ConvLSTM, default- False
  • gating_mode: toggles between multiplicative 'mul' and subtractive 'sub' gating within the ConvLSTM, default- 'mul'
  • A_channels & R_channels: number of channels within each layer of PredNet, default- (3, 48, 96, 192)
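
For orientation, the defaults above might be set near the top of kitti_train.py roughly as follows (the grouping and placement are illustrative; only the names and default values come from the lists above):

```python
# Training parameters (defaults listed above)
num_epochs = 150
batch_size = 4
lr = 0.001          # learning rate
nt = 10             # length of each video sequence
n_train_seq = 500   # sequences per training epoch
n_val_seq = 100     # sequences used for validation

# Model hyperparameters (defaults listed above)
loss_mode = 'L_0'               # 'L_0' or 'L_all'
peephole = False
lstm_tied_bias = False
gating_mode = 'mul'             # 'mul' or 'sub'
A_channels = (3, 48, 96, 192)
R_channels = (3, 48, 96, 192)
```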

After training is complete, the script saves two versions of the model: prednet-*-best.pt (the version with the lowest loss on the validation set) and prednet-*.pt (the version saved after the last epoch).
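
A sketch of that checkpointing logic, assuming placeholder run_epoch/run_validation callables and simplified filenames (the actual script's loop and naming may differ):

```python
import torch

def train_with_checkpoints(model, num_epochs, run_epoch, run_validation):
    # run_epoch / run_validation are placeholders for the script's
    # training and validation passes.
    best_val_loss = float('inf')
    for epoch in range(num_epochs):
        run_epoch(model)
        val_loss = run_validation(model)
        if val_loss < best_val_loss:
            # Keep the weights with the lowest validation loss.
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'models/prednet-best.pt')
    # Also keep the weights from the final epoch.
    torch.save(model.state_dict(), 'models/prednet.pt')
```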

To test your models using kitti_test.py, transfer them into your 'models' folder, set the testing and model hyperparameters accordingly, then run the script. It should output the MSE between the ground-truth and predicted sequences, as well as the MSE obtained if the model simply predicted the previous frame at each time step.
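
For reference, a minimal sketch of the two reported numbers, assuming prediction and ground-truth tensors of shape (batch, nt, channels, height, width); the script's actual evaluation code may differ:

```python
import torch

def test_mse(pred, gt):
    # MSE of the model's predictions, skipping the first frame
    # (there is no previous frame to condition on).
    mse_model = torch.mean((pred[:, 1:] - gt[:, 1:]) ** 2)
    # Baseline: simply repeat the previous ground-truth frame.
    mse_prev = torch.mean((gt[:, :-1] - gt[:, 1:]) ** 2)
    return mse_model.item(), mse_prev.item()
```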

A brief word on hyperparameters

The default parameters listed above reproduce the results in the paper when using prednet_relu_bug. However, prednet underperforms with these parameters and overfits the data. After a coarse hyperparameter search, we found that shrinking the model helped alleviate overfitting.

Data

Acquiring the dataset requires two steps: 1) downloading the zip files and 2) extracting and processing the images. Step 1 is done by running the download_raw_data_.sh scripts found in kitti_raw_data\raw<category>. Step 2 is handled by running process_kitti.py.
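
Conceptually, the processing step reads the extracted frames, resizes them, and stacks them into per-video arrays; process_kitti.py is the authoritative implementation. A hypothetical sketch (the 128x160 resolution and flat frame-directory layout are assumptions):

```python
import os
import numpy as np
from PIL import Image

def load_video(frame_dir, height=128, width=160):
    # Read the frames of one recording in order and resize them.
    frames = []
    for name in sorted(os.listdir(frame_dir)):
        img = Image.open(os.path.join(frame_dir, name)).convert('RGB')
        frames.append(np.asarray(img.resize((width, height)), dtype=np.uint8))
    return np.stack(frames)   # shape: (num_frames, height, width, 3)
```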
