# StatQuest: Adding Reinforcement Learning to a Neural Network

Copyright 2025, Joshua Starmer

----

In this tutorial we illustrate how to add **Reinforcement Learning** (**RL**) to a simple **Neural Network**. Specifically, we illustrate how to add the **Policy Gradients** **RL** method to a **Neural Network** that is designed to help us decide where to go eat fries, either **'Squatch's Fry Shack** or **Norm's Fry Hut**, based on how hungry we are. If we are very hungry, than we will want to go to the place that has a high probability of giving us a large order of fries. However, if we are not very hungry, we want to go to the place that has a low probability of giving us a large order of fries.

![The neural network we will implement](images/nn_plus_fries.png)

**NOTE:** To make this example a little bit interesting, there is a **20%** chance that **'Squatch** will give us a large order of fries and an **80%** chance that **Norm** will give us a large order of fries.

----

# Import the modules that will do all the work

The very first thing we need to do is load a bunch of Python modules. Python itself is just a basic programming language. These modules give us extra functionality to generate random numbers and draw fancy graphs.

In [30]:
import torch ## torch let's us create tensors and also provides helper functions
import torch.nn as nn ## torch.nn gives us nn.Module(), nn.Embedding() and nn.Linear()
import torch.nn.functional as F # This gives us sigmoid()???
from torch.optim import Adam # optim contains many optimizers. This time we're using Adam

import lightning as L ## Lightning makes it easier to write, optimize and scale our code
from torch.utils.data import TensorDataset, DataLoader ## We'll store our data in DataLoaders

----

# Create Training Data

Training data? I thought this was **Reinforcement Learning**. Why do we need training data?

From the perspective of doing **Reinforcement Learning**, we don't need training data, since the model itself does the exploring. Our model will create its own inputs and discover the corresponding "labels". Then it will use a reward to modify the "label" depending on whether or not it was correct.

That said, the **Lightning Trainer** needs something it can iterate in order to train the model. So, here we will create some data that we will, ultimately,  ignore in our model's `training_step()` method.

In [31]:
## In order ot use the Lightning Trainer, 
## we need to give it something it can iterate on.
## So here is some data that we will completely ignore in the
## training_step() method.
training_inputs = torch.tensor([0.8, 0.1, 0.2, 0.9])
training_labels = torch.tensor([1.0, 0.0, 0.0, 1.0])

## Now let's package everything up into a DataLoader...
training_dataset = TensorDataset(training_inputs, training_labels) 
dataloader = DataLoader(training_dataset)

----

# Create a simple Neural Network with a trainable Bias

![the neural network that we will implement](images/just_nn.png)

Because the goal is to illustrate how to add **Reinforcement Learning** to a neural network, and not how to create a really fancy neural network, we're going to use an incredibly simple neural network that has a single trainable parameter, a **Bias** term. And we'll use **Reinforcement Learning** to train that one **Bias**. That said, the method easily extends as many parameters we need to train if we, ultimately, create a fancy neural network.

In [32]:
class simpleNN_with_RL(L.LightningModule):
    
    def __init__(self):
        
        super().__init__()

        L.seed_everything(seed=42)
        
        self.weight = torch.tensor(20) # We won't train this weight.

        ## However, we will train this bias.
        ## The ideal value sis -10.
        self.bias = nn.Parameter(torch.tensor(0.0))

        ## gamma is like a learning rate applied directly to the reward.
        self.gamma = torch.tensor(0.99)
        
        ## We need this to keep track of the reward.
        self.reward = torch.tensor(0) 

    
    def forward(self, inputs):
        ## A forward pass through a super simple neural network
        ## NOTE: In reinforcement learning lingo, a neural network
        ##       that we want to train is called a "policy network".

        p_norm = torch.sigmoid(inputs * self.weight + self.bias)

        return p_norm

    
    def configure_optimizers(self): 
        ## Configure the optimizer we want to use for backpropagation.
        return Adam(self.parameters(), lr=0.1)
    
    
    def training_step(self, batch, batch_idx): 
        ## Take a step during gradient descent with policy gradients.

        ## First, decide how hungry we are...
        how_hungry = torch.rand(1).to("mps")
        
        ## Now pass how_hungry through the neural network to get a probability
        ## for going to Norm's.
        ## NOTE: In reinforcement learning lingo, the neural network
        ##       we want to train is called a "policy network".
        outputs = self.forward(how_hungry)   
        # print("outputs:", outputs)
        
        ## now figure out if we go to Norm's or Squatch's
        ## by picking a random number between 0 and 1...
        rand_num = torch.rand(1).to("mps")
        # print("rand_num:", rand_num)

        ## ...and then comparing that number to the output
        ## from the neural network (ahem, "policy network").
        ## If the random number is < the output, go to Norm's.
        if(rand_num < outputs):
            ## go to norm's
            # print("\tgoing to norms!")
            
            ## Now determine if Norm is giving us a large order or not
            if (torch.rand(1) < 0.8): # Norm gave us a large order...

                ## if how_hungry is > 0.5, then we are hungry and want a larger order...
                if(how_hungry > 0.5): # 
                    self.reward = 1 # We are hungry and happy we got a large order
                else:
                    self.reward = -1 # We are not hungry and not happy we got a large
            else: # Norm gave us a small order
                if(how_hungry > 0.5):
                    self.reward = -1  # We are hungry and sad we got a small order
                else:
                    self.reward = 1 # We are not hungry and happy we got a small order
        else:
            ## go to squatch's
            # print("\tgoing to squatch!")

            # Now convert the probability of visiting Norm to the probability of visiting Squatch
            outputs = 1 - outputs
            
            ## Now determine if Squatch is giving us a large order or not
            if (torch.rand(1) < 0.2): # Squatch gave us a large order...
                if(how_hungry > 0.5): 
                    self.reward = 1 # We are hungry and happy squatch gave us a large order
                else:
                    self.reward = -1 # We are not hungry and sad we got a large order
            else: # Squatch gave us a small order
                if(how_hungry > 0.5): 
                    self.reward = -1 # We are hungry and sad squatch gave us a small order
                else:
                    self.reward = 1 # We are not hungry and happy we got a small order

        ## Now that we have visited either Norm or 'Squatch and ordered fries,
        ## we can calculate the loss.
        ## NOTE: We're using cross entropy which is...
        ## CE_squatch = -1 * log(1-output) * gamma * reward
        ## CE_norm = -1 * log(output) * gamma * reward
        ## When pytorch calculates the derivative of the cross entropy
        ## the scaled reward still have the desired effect.
        loss = -1 * torch.log(outputs) * self.gamma * self.reward
        
        return loss

### Bam.

----

# Run some input values through neural network to see what it does before training

In [47]:
model = simpleNN_with_RL()

Seed set to 42


First, let's run a relatively large input value, **0.9**, through the neural network. A relatively large input value (a value close to **1**) indicates that we are hungry and would like to go somewhere that has a good chance of giving us a large order of fries.

In [48]:
model(torch.tensor(0.9))

tensor(1., grad_fn=<SigmoidBackward0>)

The output, **0.9995**, means there is a high probability that we will go to **Norm's** to get fries. This makes sense because we are hungry (the input was **0.9**) and there a good chance that **Norm** will give us a large order of fries.

Now let's run a relatively small input value, **0.1**, thorugh the neural network. A relatively small input value (a value close to **0**) indicates that we are not hungry and would like to go somewhere that has a good chance of serving us a small order of fries.

In [49]:
model(torch.tensor(0.1))

tensor(0.8808, grad_fn=<SigmoidBackward0>)

The output, **0.8808**, means there is a high probability that we will go to **Norm's** to get fries. This does not make sense because we are not hungry (the input was **0.1**) and there is a good chance that **Norm** will give us a large order of fries. This means we'll waste fries, and that is no good at all. This means we need to train our model!

----

# Now train the NN

In [52]:
model = simpleNN_with_RL()
trainer = L.Trainer(max_epochs=70)
trainer.fit(model, train_dataloaders=dataloader)

Seed set to 42
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name         | Type | Params | Mode
---------------------------------------------
  | other params | n/a  | 1      | n/a 
---------------------------------------------
1         Trainable params
0         Non-trainable params
1         Total params
0.000     Total estimated model params size (MB)


Training: |                                               | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=70` reached.


Now that we've trained the model, we can re-run the input values we used before to see if the outputs have changed. We'll start with a relatively high input value, **0.9**, that indicates that we are hungry.

In [53]:
model(torch.tensor(0.9))

tensor(0.9995, grad_fn=<SigmoidBackward0>)

The result, **1**, means there is a **100%** chance that we will go to **Norm's**. And this makes sense because we are hungry, the input values was 0.9, and there's a good chance that Norm will give us a lot of fries.

Now let's see what happens when we use a relatively small input value, **0.1**, that indicates that we are not hungry.

In [54]:
model(torch.tensor(0.1))

tensor(0.0002, grad_fn=<SigmoidBackward0>)

Unlike before we trained the model, now the output is close to 0, meaning there is low probability that we will go to Norm's. This means there is a high probabilty that we will go to 'Squatch's. And that makes sense, because there is a good chance that 'Squatch will eat most of the fries and only give us a small order. This means we will not waste any fries, and that is a good thing!

## DOUBLE BAM!!

Now let's print out all of the trainable parameters, meaning, the one **Bias**, to see what the value is, after training.

In [44]:
for name, param in model.named_parameters():
    print(name, torch.round(param.data, decimals=2))

bias tensor(-10.3400)


That value, **-10.3**, is very close to the ideal value, **-10**, and that means that **Reinforcement Learning** successfully trained the model.

# TRIPLE BAM!!!