Modifying DQN model to accept 3D images #7
I am using 3D frames, not 2D, so this should work with your data if it is also 3D. There are two models in the code, 2D and 3D, and both should work so far.
@crypdick Did you manage to make the code work on your data?
@amiralansary I can't be sure yet. I finally got my network training yesterday evening, and epoch 3 has just started (training is slow: only ~0.8 iterations per second). When I use …
I usually get around 3-4 it/s using the default big architecture on a GTX 1080. You can try a tiny architecture first until you see decent results, then move to a bigger one. If you do not need 3D convolutions, switching to 2D convolutions will be much faster. You can also monitor performance with TensorBoard by pointing it at the train_log directory. Signs such as the agent playing more games over time, or an increasing number of successful episodes, are healthy indicators that it is learning.
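To give a feel for why 2D convolutions are so much cheaper when the data is genuinely 2D, here is a rough multiply-accumulate (MAC) count for a single stride-1, same-padded convolution layer. The layer sizes (45-pixel input taken from the (45, 45, 45) patch size mentioned in this thread, 5-wide kernels, 4 input channels, 32 filters) are illustrative assumptions, not the repository's actual architecture.

```python
def conv_macs(spatial_dims, kernel_size, in_channels, out_channels):
    """Multiply-accumulate count of one stride-1, same-padded convolution layer.

    spatial_dims: output size along each spatial axis, e.g. (45, 45) for 2D
    or (45, 45, 45) for 3D. kernel_size is the kernel width on every axis.
    """
    positions = 1
    for size in spatial_dims:
        positions *= size
    kernel_volume = kernel_size ** len(spatial_dims)
    return positions * kernel_volume * in_channels * out_channels

# Illustrative layer (assumed sizes): 4 input channels, 32 filters, 5-wide kernels.
macs_2d = conv_macs((45, 45), kernel_size=5, in_channels=4, out_channels=32)
macs_3d = conv_macs((45, 45, 45), kernel_size=5, in_channels=4, out_channels=32)

print(f"2D layer: {macs_2d:,} MACs")    # 2D layer: 6,480,000 MACs
print(f"3D layer: {macs_3d:,} MACs")    # 3D layer: 1,458,000,000 MACs
print(f"ratio: {macs_3d // macs_2d}x")  # ratio: 225x (45 extra positions x 5 extra kernel taps)
```

Actual wall-clock speed also depends on memory bandwidth and the GPU kernels used, so treat the ratio as a rough guide to relative cost rather than a speed prediction.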
@amiralansary I wanted to double-check something with you. In the network graph, you slice comb_state to get the current state and the future state. So the minimum FRAME_HISTORY has to be 2, right? I don't understand why you don't just return (pre-state, action, reward, terminal, post-state).
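The slicing can be sketched in NumPy under one assumption: that comb_state stores FRAME_HISTORY + 1 consecutive frames along its last axis, as the trailing 5 in the FakeData shape [BATCH_SIZE, 45, 45, 45, 5] quoted in this thread suggests for FRAME_HISTORY = 4. The current state and the next state are then two overlapping windows. split_comb_state is a hypothetical helper, not a function from the repository.

```python
import numpy as np

def split_comb_state(comb_state, frame_history):
    """Split a combined state into (state, next_state).

    Assumes the last axis holds frame_history + 1 consecutive frames:
    state uses frames [0, frame_history), next_state uses [1, frame_history + 1).
    """
    if comb_state.shape[-1] != frame_history + 1:
        raise ValueError("expected frame_history + 1 frames on the last axis")
    state = comb_state[..., :frame_history]
    next_state = comb_state[..., 1:]
    return state, next_state

# FRAME_HISTORY = 4: five stacked frames, matching the [B, 45, 45, 45, 5] FakeData shape.
batch = np.zeros((2, 45, 45, 45, 5), dtype=np.uint8)
s, s_next = split_comb_state(batch, frame_history=4)
print(s.shape, s_next.shape)  # (2, 45, 45, 45, 4) (2, 45, 45, 45, 4)

# FRAME_HISTORY = 1: two stacked frames, one for the current and one for the next state.
single = np.zeros((2, 45, 45, 45, 2), dtype=np.uint8)
s, s_next = split_comb_state(single, frame_history=1)
print(s.shape, s_next.shape)  # (2, 45, 45, 45, 1) (2, 45, 45, 45, 1)
```

Under this assumption, FRAME_HISTORY = 1 still yields a valid (state, next state) pair, because the combined state simply holds two frames.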
Hi @amiralansary, I have a similar question. I want to apply your code to landmark detection on ordinary natural images (with three dimensions: H, W, C). How should I modify your code, and where should I pay the most attention? (For example, changing the convolutions from conv3D to conv2D and the image size from (45, 45, 45) to (X, X, 3).)
Hi @amiralansary, at what point, in terms of number of epochs, did you start getting reasonable results? Do you have a graph that shows the number of successes vs. epochs? Thanks.
@sunalbert You may need to use conv2d instead of conv3d if the input is 2D (check Model2D). You will also have to trace the other processes and make sure they handle inputs and outputs correctly. One last thing: if your inputs are colour images rather than grayscale, you may end up using conv3d anyway, because the channel dimension is already used to store the frame history.
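The colour-image point can be illustrated by tracking array shapes only. In this minimal NumPy sketch, stacking T colour frames of shape (H, W, 3) gives (H, W, 3, T): the colour axis becomes a third "spatial" axis and the frame history stays in the channel position, which is why a 3D convolution still fits. The second helper shows an alternative that is an assumption, not something suggested in this thread: folding colour and history into one channel axis of size 3 * T so a 2D convolution can be used. The helper names and the 64 x 64 frame size are hypothetical.

```python
import numpy as np

def stack_colour_history(frames):
    """Stack T colour frames of shape (H, W, 3) into (H, W, 3, T).

    Colour acts as a third spatial axis; the frame history occupies the
    trailing channel axis, i.e. a conv3d-style layout once a batch axis is added.
    """
    return np.stack(frames, axis=-1)

def fold_colour_into_channels(history):
    """Alternative (assumption): merge colour and history into 3 * T channels
    so that a 2D convolution over (H, W) can be used instead."""
    h, w, c, t = history.shape
    return history.reshape(h, w, c * t)

frames = [np.zeros((64, 64, 3), dtype=np.uint8) for _ in range(4)]  # T = 4 frames
history = stack_colour_history(frames)

print(history.shape)                             # (64, 64, 3, 4)
print(fold_colour_into_channels(history).shape)  # (64, 64, 12)
```

Folding is cheaper but mixes colour and time in the first layer's channels; keeping them separate with conv3d preserves the structure at a higher compute cost.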
@gravity1989 The time or number of iterations needed varies with the complexity of the target landmark. You can visualize the agent's performance during training using the saved models (with the play task), or use TensorBoard to visualize the curves and monitor performance. Here are some of the performance curves on training data for a cardiac landmark. For the results reported in the paper, training was stopped at 500k iterations for all experiments. In practice, though, you can keep training until the agent reaches a satisfactory accuracy.
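One way to turn "keep training until the agent performs satisfactorily" into a concrete rule is sketched below. TrainingMonitor is a hypothetical helper, not part of the repository: it keeps a rolling success rate over recent episodes (what counts as a "success", e.g. ending close enough to the landmark, is up to you) and stops either at a fixed budget (defaulting to the 500k iterations used for the paper's results) or earlier once the rolling rate reaches a target you choose. The 0.9 target and 100-episode window are arbitrary placeholders.

```python
from collections import deque

class TrainingMonitor:
    """Hypothetical stopping rule: fixed iteration budget or rolling success rate."""

    def __init__(self, max_iters=500_000, target_success_rate=0.9, window=100):
        self.max_iters = max_iters
        self.target = target_success_rate
        self.recent = deque(maxlen=window)  # outcomes of the most recent episodes

    def record_episode(self, success):
        self.recent.append(bool(success))

    def success_rate(self):
        return sum(self.recent) / len(self.recent) if self.recent else 0.0

    def should_stop(self, iteration):
        # Only trust the success rate once the window holds `window` episodes.
        window_full = len(self.recent) == self.recent.maxlen
        return iteration >= self.max_iters or (
            window_full and self.success_rate() >= self.target
        )

# Usage: a tiny window so the effect is visible.
monitor = TrainingMonitor(target_success_rate=0.75, window=4)
for outcome in [True, True, False, True]:
    monitor.record_episode(outcome)
print(monitor.success_rate())      # 0.75
print(monitor.should_stop(1_000))  # True
```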
Thanks, @amiralansary.
Thanks, @amiralansary.
@amiralansary @crypdick @gravity1989 @sunalbert My training speed is far too slow. My current environment is: … I used the example data for training. Because of the GPU memory limit, I used these parameters: … and these GPU and CPU settings: … To exclude the effect of data loading, I used FakeData:
dataflow = FakeData([[BATCH_SIZE, 45, 45, 45, 5], [BATCH_SIZE], [BATCH_SIZE], [BATCH_SIZE]], size=1000, random=False, dtype=['uint8', 'float32', 'int8', 'bool'])
with a minimal training setting:
return TrainConfig( …
The training speed is 28 seconds per iteration. Even when I reduce the model complexity by commenting out some Conv3D and Pool3D layers in
with argscope(Conv3D, nl=PReLU.symbolic_function, use_bias=True):
the training speed is still 22 seconds per iteration. That is about 100x slower than your training speed. I want to know why, and
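Comparing numbers reported in different units (seconds per iteration here, iterations per second earlier in the thread) is easy to get wrong, so a small timing sketch may help. iterations_per_second times any zero-argument callable that runs one training step; the step itself is hypothetical here (in practice it would run one update on a FakeData batch, so data loading stays excluded). Warm-up calls are skipped so one-off startup costs do not distort the measurement. slowdown converts seconds per iteration against a reference throughput.

```python
import time

def iterations_per_second(step, n_iters=50, warmup=5):
    """Measure the throughput of `step`, a zero-argument callable running
    one training iteration (hypothetical placeholder here)."""
    for _ in range(warmup):
        step()  # excluded from timing: one-off startup costs
    start = time.perf_counter()
    for _ in range(n_iters):
        step()
    elapsed = time.perf_counter() - start
    return n_iters / elapsed

def slowdown(seconds_per_iter, reference_its_per_sec):
    """How many times slower a run is than a reference throughput.

    Reference seconds per iteration = 1 / reference_its_per_sec, so the
    ratio is seconds_per_iter / (1 / reference_its_per_sec).
    """
    return seconds_per_iter * reference_its_per_sec

# 28 s/iter against the 3-4 it/s reported on a GTX 1080:
print(slowdown(28, 3), slowdown(28, 4))  # 84 112
```

This confirms that "about 100x slower" is consistent with 28 seconds per iteration versus 3-4 iterations per second.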
I've had some difficulties modifying your code to work directly on image stacks. Your RL model uses the past few 2D frames as channels and does 3D convolutions on that frame history. Instead, I want my agent to see only the current step (i.e. FRAME_HISTORY = 1) but for the inputs to be single-channel image stacks. I was hoping you could give me some insight.
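With FRAME_HISTORY = 1 and single-channel image stacks, the input can still keep the 5D layout a 3D convolution expects; the trailing channel axis just has size 1. A minimal NumPy sketch, assuming channels-last input (batch, depth, height, width, channels), which is the default NDHWC format of TensorFlow's tf.nn.conv3d. build_state is a hypothetical helper, and the 45-voxel size is taken from the patch size mentioned in this thread.

```python
import numpy as np

FRAME_HISTORY = 1  # the agent sees only the current step

def build_state(history):
    """Stack the most recent FRAME_HISTORY single-channel volumes.

    Each volume is a (depth, height, width) image stack; the result has shape
    (depth, height, width, FRAME_HISTORY), so the history occupies the
    trailing channel axis.
    """
    if len(history) != FRAME_HISTORY:
        raise ValueError(f"expected {FRAME_HISTORY} volume(s), got {len(history)}")
    return np.stack(history, axis=-1)

volume = np.zeros((45, 45, 45), dtype=np.uint8)  # one single-channel image stack
state = build_state([volume])
batch = np.stack([state, state])                  # add the batch axis

print(state.shape)  # (45, 45, 45, 1)
print(batch.shape)  # (2, 45, 45, 45, 1): channels-last 5D input for a 3D convolution
```

Because the layout is unchanged, the 3D model's convolution layers should not need restructuring; only code that assumes more than one history frame would need checking.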