Implementation of a variational quantum circuit (VQC) for reinforcement learning on OpenAI Gym's FrozenLake environment.

Inspired by https://github.com/ycchen1989/Var-QuantumCircuits-DeepRL/blob/master/Code/QML_DQN_FROZEN_LAKE.py

Implemented using Qiskit and PyTorch

In [None]:
# Install Qiskit and the GPU build of the Aer simulator (the Agent's backend requests device='GPU')
%pip install qiskit
%pip install qiskit-aer-gpu

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting qiskit
  Downloading qiskit-0.42.1.tar.gz (14 kB)
  Preparing metadata (setup.py) ... done
Collecting qiskit-terra==0.23.3
  Downloading qiskit_terra-0.23.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (5.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 5.1/5.1 MB 54.8 MB/s eta 0:00:00
Collecting qiskit-aer==0.12.0
  Downloading qiskit_aer-0.12.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (12.8 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 12.8/12.8 MB 82.5 MB/s eta 0:00:00
Collecting qiskit-ibmq-provider==0.20.2
  Downloading qiskit_ibmq_provider-0.20.2-py3-none-any.whl (241 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 241.5/241.5 KB 25.8 MB/s eta 0:00:00
Collecting websocket-client>=1.5.1
  Download

In [None]:
import torch
from torch.nn import MSELoss
from torch.autograd import Variable

from qiskit import QuantumCircuit
from qiskit.circuit import Parameter
from qiskit.compiler import transpile
from qiskit import Aer

import gym

import numpy as np
from tqdm import tqdm
import time
import random

In [None]:
class ReplayMemory:
    """Bounded experience buffer for off-policy training.

    Once ``capacity`` transitions are held, storing a new one evicts a
    uniformly random existing entry (not the oldest) and appends the new
    transition at the end.
    """

    def __init__(self, capacity):
        """capacity: maximum number of transitions kept at once."""
        self.capacity = capacity
        self.transitions = []

    def store_transition(self, transition):
        """Append ``transition``, first evicting a random entry if the buffer is full."""
        is_full = len(self.transitions) >= self.capacity
        if is_full:
            evict_at = random.randint(0, self.capacity - 1)
            self.transitions.pop(evict_at)
        self.transitions.append(transition)

    def size(self):
        """Number of transitions currently stored."""
        return len(self.transitions)

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored transitions chosen uniformly at random."""
        return random.sample(self.transitions, batch_size)

In [None]:
class Agent():
    """Variational quantum circuit (VQC) used as an action-value approximator.

    The circuit, built once in ``__init__``, has three stages:
      1. state preparation: RX(pi * theta_i) then RZ(pi * phi_i) on every wire,
         where theta_i and phi_i are both bound to bit i of the encoded state;
      2. one entangling/rotation layer: a CNOT ladder followed by trainable
         RX(alpha_i), RY(beta_i), RZ(gamma_i) on every wire;
      3. measurement of every qubit into the classical bit of the same index.

    The Q-value of action ``i`` is the share of all measured 1s that landed on
    qubit ``i`` (see ``get_q_values``).
    """

    def __init__(self, numQubits=4, depth=1, device='GPU'):
        """Build the parameterized circuit and the Aer simulator backend.

        numQubits: number of wires; also the number of state bits and actions.
        depth: stored for reference only -- exactly one CNOT/rotation layer
            is built regardless of its value.
        device: Aer simulation device. 'GPU' (the previous hard-coded value)
            needs qiskit-aer-gpu; pass 'CPU' on machines without a GPU.
        """
        # Number of qubits used (# of wires)
        self.numQubits = numQubits

        # Number of times the CNOT / rotation module is meant to be applied
        # (currently unused: init_layer is called exactly once)
        self.depth = depth

        # Action-value function approximator
        self.qc = QuantumCircuit(numQubits, numQubits)

        # State encoding parameters
        # thetas: angles (in units of pi) for the RX encoding gates
        # phis: angles (in units of pi) for the RZ encoding gates
        self.thetas = [Parameter(f'theta_{i}') for i in range(self.numQubits)]
        self.phis = [Parameter(f'phi_{i}') for i in range(self.numQubits)]

        # Trainable rotation parameters of the variational layer
        self.alpha_rotations = [Parameter(f'alpha_{i}') for i in range(self.numQubits)]
        self.beta_rotations = [Parameter(f'beta_{i}') for i in range(self.numQubits)]
        self.gamma_rotations = [Parameter(f'gamma_{i}') for i in range(self.numQubits)]

        # Backend that executes the circuit
        self.backend = Aer.get_backend('qasm_simulator')
        self.backend.set_options(device=device)

        self.state_preparation()
        self.init_layer()
        self.init_measurement()

    def get_binary_state_encoding(self, state):
        """Return the integer ``state`` as a big-endian list of ``numQubits`` bits.

        For the 4x4 FrozenLake map the states are 0..15, e.g. 5 -> [0, 1, 0, 1].
        States needing more than ``numQubits`` bits produce a longer list,
        which ``bind_state_preparation_parameters`` rejects.
        """
        return [int(bit) for bit in format(state, f'0{self.numQubits}b')]

    def state_preparation(self):
        """Append the parameterized state-encoding gates to ``self.qc``.

        Angles are scaled by pi, so a bound bit of 1 gives a pi rotation and a
        bit of 0 leaves the wire untouched.
        """
        for wire in range(self.numQubits):
            self.qc.rx(np.pi * self.thetas[wire], wire)

        for wire in range(self.numQubits):
            self.qc.rz(np.pi * self.phis[wire], wire)

    def bind_state_preparation_parameters(self, state, circuit):
        """Return a copy of ``circuit`` with theta/phi bound to the bits of ``state``.

        Raises:
            ValueError: if ``state`` does not fit in ``numQubits`` bits.
        """
        angles = self.get_binary_state_encoding(state)
        if len(angles) != self.numQubits:
            raise ValueError(f'state {state} does not fit in {self.numQubits} qubits')

        values = dict(zip(self.thetas, angles))
        values.update(zip(self.phis, angles))
        return circuit.assign_parameters(values)

    def init_layer(self):
        """Append one CNOT ladder plus trainable RX/RY/RZ rotations on every wire."""
        for wire in range(self.numQubits - 1):
            self.qc.cx(wire, wire + 1)

        for wire in range(self.numQubits):
            self.qc.rx(self.alpha_rotations[wire], wire)
            self.qc.ry(self.beta_rotations[wire], wire)
            self.qc.rz(self.gamma_rotations[wire], wire)

    def bind_layer(self, alphas, betas, gammas, circuit):
        """Return a copy of ``circuit`` with the layer rotations bound.

        alphas / betas / gammas: per-wire rotation angles for RX / RY / RZ.
            Each element may be a Python number, a NumPy scalar or a
            one-element torch tensor (e.g. a row of a 3 x numQubits tensor).

        Raises:
            ValueError: if any of the three sequences is not ``numQubits`` long.
        """
        groups = ((self.alpha_rotations, alphas, 'alphas'),
                  (self.beta_rotations, betas, 'betas'),
                  (self.gamma_rotations, gammas, 'gammas'))

        values = {}
        for parameters, angles, label in groups:
            if len(angles) != self.numQubits:
                raise ValueError(f'expected {self.numQubits} {label}, got {len(angles)}')
            values.update(zip(parameters, (float(angle) for angle in angles)))

        return circuit.assign_parameters(values)

    def init_measurement(self):
        """Measure every qubit into the classical bit with the same index."""
        for wire in range(self.numQubits):
            self.qc.measure(wire, wire)

    def _ones_per_wire(self, result_counts):
        """Return, per qubit, how many shots measured 1 on it.

        result_counts: Qiskit counts dict mapping bitstrings to hit counts.
        Qiskit bitstrings are little-endian (the LAST character is classical
        bit 0), so each string is reversed before indexing by wire.
        """
        counts = [0] * self.numQubits
        for bitstring, hits in result_counts.items():
            # Register separators (spaces) appear only with several classical registers
            bits = bitstring.replace(' ', '')[::-1]
            for wire in range(self.numQubits):
                counts[wire] += int(bits[wire]) * hits
        return counts

    def run_job(self, circuit, num_iterations=10):
        """Execute a fully bound ``circuit`` for ``num_iterations`` shots.

        Returns a list whose entry ``wire`` is the number of shots in which
        qubit ``wire`` was measured as 1.
        """
        job = self.backend.run(transpile(circuit, self.backend), shots=num_iterations)
        result_counts = job.result().get_counts(circuit)
        return self._ones_per_wire(result_counts)

    def get_q_values(self, state, params, num_iterations=10, epsilon=0.75):
        """Return one score per action, normalized to sum to 1.

        Index 0 -> LEFT, 1 -> DOWN, 2 -> RIGHT, 3 -> UP (qubit i scores action i).

        state: integer state to evaluate.
        params: (alphas, betas, gammas), e.g. a 3 x numQubits tensor.
        num_iterations: number of shots used to estimate the scores.
        epsilon: probability of returning random (normalized) scores instead
            of running the circuit.

        If no shot measured a 1 on any wire, all scores are 0.0.
        """
        if np.random.rand() < epsilon:
            random_q = np.random.rand(self.numQubits)
            return [value / np.sum(random_q) for value in random_q]

        # assign_parameters returns new circuits, so self.qc stays unbound
        bound = self.bind_state_preparation_parameters(state, self.qc)
        bound = self.bind_layer(*params, bound)

        counts = self.run_job(bound, num_iterations)

        total = sum(counts)
        if total == 0:
            return [0.0] * len(counts)
        return [count / total for count in counts]

In [None]:
def _circuit_probabilities(agent, state, params, shots):
    """Run ``agent``'s circuit once and return, per wire, the fraction of shots measuring 1.

    Uses the agent's own binding/execution helpers, so the wire-to-action
    mapping matches ``Agent.get_q_values`` (before its normalization).
    params: (3, numQubits) tensor of alphas, betas, gammas, not tracked by autograd.
    """
    circuit = agent.bind_state_preparation_parameters(state, agent.qc.copy())
    circuit = agent.bind_layer(*params, circuit)
    counts = agent.run_job(circuit, shots)
    return torch.tensor(counts, dtype=params.dtype) / shots


class _ParameterShiftProbabilities(torch.autograd.Function):
    """Autograd bridge between the simulated circuit and the torch parameters.

    forward: per-wire fraction of shots measuring 1 (1-D tensor).
    backward: parameter-shift rule, df/dx = [f(x + pi/2) - f(x - pi/2)] / 2,
    which is exact (up to shot noise) because every trainable parameter feeds
    exactly one RX/RY/RZ gate. Costs 2 * params.numel() circuit runs per call.
    """

    @staticmethod
    def forward(ctx, params, agent, state, shots):
        ctx.agent = agent
        ctx.state = state
        ctx.shots = shots
        ctx.save_for_backward(params)
        return _circuit_probabilities(agent, state, params.detach(), shots)

    @staticmethod
    def backward(ctx, grad_output):
        (params,) = ctx.saved_tensors
        flat = params.detach().reshape(-1)
        grad = torch.zeros_like(flat)
        shift = np.pi / 2
        for k in range(flat.numel()):
            plus = flat.clone()
            plus[k] += shift
            minus = flat.clone()
            minus[k] -= shift
            f_plus = _circuit_probabilities(ctx.agent, ctx.state, plus.reshape(params.shape), ctx.shots)
            f_minus = _circuit_probabilities(ctx.agent, ctx.state, minus.reshape(params.shape), ctx.shots)
            grad[k] = torch.dot(grad_output, (f_plus - f_minus) / 2)
        # No gradients for agent, state or shots
        return grad.reshape(params.shape), None, None, None


class Runner():
    """Deep-Q-learning loop driving the VQC ``Agent`` on FrozenLake-v1 (non-slippery)."""

    def __init__(self, num_episodes=100, epsilon=0.1, capacity=1000, batch_size=16, shots=10):
        """Set up the environment, the online/target agents and the optimizer.

        num_episodes: number of episodes run by ``run``.
        epsilon: probability of taking a random action while acting.
        capacity: capacity of the replay memory.
        batch_size: transitions per training step (training starts once the
            memory holds at least this many).
        shots: circuit shots per Q-value estimate (10 matches Agent's default).
        """
        self.num_episodes = num_episodes
        self.epsilon = epsilon
        self.memory = ReplayMemory(capacity)
        self.agent = Agent(numQubits=4, depth=1)
        self.target_agent = Agent(numQubits=4, depth=1)
        # Rows are alphas, betas, gammas; uniform in [-1, 1)
        self.parameters = torch.tensor(np.random.rand(3, 4) * 2 - 1, dtype=torch.float64, requires_grad=True)
        self.target_parameters = self.parameters.clone().detach()
        self.env = gym.make('FrozenLake-v1', is_slippery=False)
        self.batch_size = batch_size
        self.shots = shots
        self.terminal_state = 15
        self.opt = torch.optim.RMSprop([self.parameters], lr=0.01, alpha=0.99, eps=1e-08)
        self.gamma = 0.1
        self.steps = 0
        self.update_target_every_n_steps = 10
        self.labels = torch.tensor([])
        self.predictions = torch.tensor([])
        self.cumulative_rewards = []

    def criterion(self, y_pred, y):
        """Mean squared error between predictions and TD targets."""
        return torch.mean((y_pred - y) ** 2)

    def backward_step(self):
        """Optimizer closure: recompute the loss on the current batch and backpropagate it."""
        self.opt.zero_grad()
        loss = self.criterion(self.predictions, self.labels)
        loss.backward()
        return loss

    def _reset_env(self):
        """Reset the env and return the observation (old and new gym APIs)."""
        outcome = self.env.reset()
        # gym>=0.26 returns (observation, info); older versions the observation only
        return outcome[0] if isinstance(outcome, tuple) else outcome

    def _step_env(self, action):
        """Step the env; return (observation, reward, terminated, done) for both gym APIs.

        ``terminated`` is True only for real episode ends (goal or hole), not
        time-limit truncation; ``done`` covers both.
        """
        outcome = self.env.step(action)
        if len(outcome) == 5:
            observation, reward, terminated, truncated, _ = outcome
            return observation, reward, terminated, terminated or truncated
        observation, reward, done, info = outcome
        terminated = done and not info.get('TimeLimit.truncated', False)
        return observation, reward, terminated, done

    def _predict(self, state, action):
        """Differentiable normalized Q-value of ``action`` in ``state`` (as in Agent.get_q_values)."""
        probs = _ParameterShiftProbabilities.apply(self.parameters, self.agent, state, self.shots)
        total = probs.sum()
        if total.item() == 0:
            # Mirrors get_q_values returning 0 here, without dividing by zero
            return probs[action]
        return probs[action] / total

    def _train_on_batch(self, transitions):
        """Take one optimizer step on a minibatch of (s, a, r, s', terminated) transitions."""
        labels = []
        for _, _, reward, next_state, terminated in transitions:
            label = float(reward)
            # Terminal states (goal or hole) have no future value to bootstrap from
            if not terminated:
                next_q = self.target_agent.get_q_values(
                    next_state, self.target_parameters, num_iterations=self.shots, epsilon=0)
                label += self.gamma * np.max(next_q)
            labels.append(label)
        self.labels = torch.tensor(labels, dtype=torch.float64)

        # Greedy (epsilon=0) predictions that stay connected to self.parameters
        self.predictions = torch.stack([self._predict(state, action) for state, action, *_ in transitions])

        self.opt.step(self.backward_step)

    def run(self):
        """Run ``num_episodes`` episodes, training after every step once enough experience exists.

        Appends each episode's total (shaped) reward to ``self.cumulative_rewards``.
        """
        for _ in tqdm(range(self.num_episodes)):
            state = self._reset_env()
            done = False
            cumulative_ep_reward = 0

            while not done:
                self.steps += 1

                # Epsilon-greedy action from the CURRENT state
                q_values = self.agent.get_q_values(
                    state, self.parameters, num_iterations=self.shots, epsilon=self.epsilon)
                action = int(np.argmax(q_values))

                observation, reward, terminated, done = self._step_env(action)

                # Reward shaping to penalize falling in a hole
                if terminated and observation != self.terminal_state:
                    reward -= 1

                cumulative_ep_reward += reward

                self.memory.store_transition((state, action, reward, observation, terminated))
                state = observation

                # Not enough experience for a minibatch yet
                if self.memory.size() < self.batch_size:
                    continue

                self._train_on_batch(self.memory.sample(self.batch_size))

                # Periodically sync the target network
                if self.steps % self.update_target_every_n_steps == 0:
                    self.target_parameters = self.parameters.clone().detach()

            self.cumulative_rewards.append(cumulative_ep_reward)
            

In [None]:
# Train for 10 episodes; later cells inspect r.parameters, r.target_parameters and r.agent
r = Runner(num_episodes=10)
r.run()

  deprecation(
  deprecation(
 10%|█         | 1/10 [00:00<00:02,  3.15it/s]

None
None
tensor([ 0.0372,  0.0410,  0.0021,  0.0327,  0.0372,  0.0964,  0.0069,  0.0051,
         0.0247,  0.0549, -0.0015, -0.0056,  0.1410, -0.0009,  0.0213,  0.0604],
       dtype=torch.float64)
None
None
tensor([ 0.0250,  0.0231,  0.0077,  0.0185,  0.0365,  0.0372,  0.0561,  0.1204,
         0.0174,  0.0784, -0.0050,  0.0283,  0.0497,  0.0304,  0.0086,  0.0298],
       dtype=torch.float64)
None
None
tensor([0.0260, 0.0033, 0.0156, 0.0153, 0.0541, 0.0246, 0.0228, 0.0368, 0.1591,
        0.0357, 0.0087, 0.0740, 0.0438, 0.0303, 0.0518, 0.0153],
       dtype=torch.float64)


 20%|██        | 2/10 [00:01<00:08,  1.02s/it]

None
None
tensor([-0.0062,  0.0710,  0.0556,  0.0044,  0.0727,  0.0559,  0.1726,  0.0156,
         0.1648,  0.0631,  0.0309,  0.0513,  0.0108, -0.0030,  0.0119,  0.0441],
       dtype=torch.float64)
None
None
tensor([-0.0066,  0.0260,  0.0146,  0.0454,  0.0043,  0.0282,  0.0020,  0.0315,
         0.0294,  0.0237,  0.0160,  0.1516,  0.0200,  0.1533,  0.0781,  0.0056],
       dtype=torch.float64)
None
None
tensor([-0.0009,  0.0049,  0.0230,  0.0058,  0.0175,  0.0299,  0.0443,  0.0026,
         0.0326,  0.0185,  0.0347,  0.0041, -0.0031,  0.0214,  0.1439,  0.0225],
       dtype=torch.float64)
None
None
tensor([ 0.0511,  0.0236,  0.0358,  0.0289,  0.1630, -0.0062,  0.0270,  0.0567,
         0.0342,  0.0244,  0.0976,  0.0494,  0.0085, -0.0056,  0.0341,  0.0308],
       dtype=torch.float64)
None
None
tensor([ 0.0074,  0.0345,  0.0545,  0.0496,  0.0163,  0.0449,  0.0584,  0.1514,
         0.1399,  0.0607,  0.0389,  0.0574, -0.0023,  0.0009,  0.0455,  0.0454],
       dtype=torch.float64)
None


 30%|███       | 3/10 [00:03<00:09,  1.34s/it]

None
None
tensor([ 3.3820e-02,  1.5141e-02,  6.7641e-02,  2.7572e-02,  6.5841e-02,
         3.2475e-02,  3.1903e-02, -4.8462e-06,  1.4724e-01,  3.8538e-02,
         2.6660e-02,  3.5954e-02,  3.1025e-02,  3.2114e-02,  1.4417e-01,
         3.2856e-02], dtype=torch.float64)
None
None
tensor([ 0.0120, -0.0004,  0.0499,  0.0156,  0.0192,  0.0082,  0.0022,  0.0365,
         0.0459,  0.0025, -0.0028,  0.0560,  0.1606,  0.0141,  0.0211,  0.1191],
       dtype=torch.float64)
None
None
tensor([0.0204, 0.0009, 0.0789, 0.0493, 0.0153, 0.0376, 0.1252, 0.0285, 0.0401,
        0.0311, 0.0354, 0.0227, 0.1700, 0.0223, 0.0047, 0.0232],
       dtype=torch.float64)
None
None
tensor([ 0.0212,  0.0074,  0.0239,  0.0378,  0.0443,  0.0114,  0.0101, -0.0010,
        -0.0016,  0.0127,  0.0447,  0.0188,  0.0612,  0.0323,  0.1279,  0.0488],
       dtype=torch.float64)
None
None
tensor([ 0.0166,  0.0275,  0.0602,  0.0438, -0.0027,  0.0489,  0.0455,  0.0193,
         0.0141,  0.1355,  0.0232,  0.0409,  0.0114, -0.0

 40%|████      | 4/10 [00:05<00:08,  1.48s/it]

None
None
tensor([-1.0430e-04,  2.4430e-02,  2.5568e-02, -6.0291e-03,  4.6171e-02,
         2.3302e-02,  1.2154e-01, -4.0002e-03,  3.2773e-02,  1.4708e-02,
         3.7695e-02,  4.8357e-02,  1.2808e-01,  5.0780e-02, -6.2500e-03,
         4.4219e-02], dtype=torch.float64)
None
None
tensor([ 0.0206,  0.1620,  0.0249,  0.0202,  0.0149,  0.0366,  0.0535,  0.0083,
         0.0379,  0.0475,  0.0382,  0.0245,  0.1415,  0.0370,  0.0180, -0.0032],
       dtype=torch.float64)
None
None
tensor([ 0.0337,  0.0368,  0.1769,  0.0563,  0.0266, -0.0004,  0.1026,  0.0462,
         0.0166,  0.0156,  0.0270,  0.0409,  0.0402,  0.0317,  0.0037,  0.0129],
       dtype=torch.float64)
None
None
tensor([3.3854e-02, 2.4027e-02, 3.4507e-02, 2.1337e-02, 1.7940e-02, 1.5266e-02,
        3.2675e-02, 1.4611e-02, 1.6438e-01, 2.4964e-02, 2.8362e-02, 6.2850e-02,
        4.5736e-03, 3.1651e-02, 2.3587e-02, 1.6929e-05], dtype=torch.float64)
None
None
tensor([0.0544, 0.0580, 0.0189, 0.0373, 0.0483, 0.0172, 0.0134, 0.0335, 

 50%|█████     | 5/10 [00:06<00:06,  1.39s/it]

None
None
tensor([ 0.0018, -0.0061,  0.0671,  0.0114,  0.0407,  0.0206,  0.0562,  0.0101,
         0.0192,  0.0319,  0.0237,  0.1329,  0.0915, -0.0025,  0.0728,  0.0222],
       dtype=torch.float64)
None
None
tensor([0.0224, 0.1192, 0.0069, 0.0165, 0.0285, 0.0431, 0.0165, 0.0315, 0.0318,
        0.0269, 0.0295, 0.0717, 0.1589, 0.0345, 0.0559, 0.0504],
       dtype=torch.float64)
None
None
tensor([ 0.0420,  0.1465,  0.0134,  0.0277,  0.0367,  0.0280,  0.0171,  0.0593,
         0.1466,  0.1760,  0.0425, -0.0057, -0.0025,  0.0297,  0.0588,  0.0283],
       dtype=torch.float64)
None
None
tensor([ 0.0272,  0.0298,  0.0003,  0.0229,  0.0467,  0.0303,  0.0418,  0.0169,
        -0.0057,  0.0222,  0.1193,  0.0182,  0.0086,  0.0712,  0.0240,  0.0333],
       dtype=torch.float64)
None
None
tensor([0.0252, 0.0248, 0.1508, 0.0028, 0.0211, 0.0468, 0.0106, 0.0059, 0.0245,
        0.0140, 0.1495, 0.0378, 0.0044, 0.0429, 0.0189, 0.0011],
       dtype=torch.float64)
None
None
tensor([ 0.0092,  0.0265,  

 60%|██████    | 6/10 [00:07<00:05,  1.40s/it]

None
None
tensor([ 0.0307,  0.0691,  0.0199,  0.1569,  0.0325,  0.0338,  0.0138,  0.0471,
         0.0316,  0.0411,  0.0351,  0.0368, -0.0046,  0.0270,  0.1307,  0.1197],
       dtype=torch.float64)
None
None
tensor([-0.0028,  0.0017,  0.0283,  0.0205,  0.0500, -0.0041,  0.0306, -0.0004,
         0.0005,  0.0147,  0.0262,  0.0501,  0.0256,  0.0292,  0.1544,  0.0316],
       dtype=torch.float64)
None
None
tensor([-0.0036,  0.0027,  0.0117,  0.0090,  0.0541,  0.0202,  0.1475,  0.0371,
         0.0215,  0.0105,  0.0310,  0.0223,  0.1339, -0.0055,  0.0004,  0.0111],
       dtype=torch.float64)
None
None
tensor([ 5.3600e-04,  3.3899e-02,  6.0266e-03,  7.2745e-03,  3.1389e-02,
        -4.2551e-05,  2.3709e-02,  2.0833e-02,  2.7729e-02,  6.3420e-03,
         1.1346e-02,  2.8371e-02, -3.6765e-03,  7.6337e-03,  4.3561e-02,
         3.1304e-02], dtype=torch.float64)
None
None
tensor([ 0.0238,  0.0456,  0.0724, -0.0023,  0.0163,  0.1543,  0.0196,  0.0151,
         0.0263,  0.0168,  0.0587,  0.052

 70%|███████   | 7/10 [00:08<00:03,  1.13s/it]

None
None
tensor([ 0.0449,  0.1455,  0.0332,  0.1850,  0.0035,  0.1634,  0.0017,  0.0170,
         0.0319,  0.0187,  0.0432,  0.0438,  0.0230, -0.0018,  0.0046, -0.0045],
       dtype=torch.float64)
None
None
tensor([ 0.0149,  0.0403,  0.0188, -0.0021,  0.0442,  0.0122,  0.0182,  0.0310,
         0.0428,  0.0531,  0.0258,  0.0335,  0.1555,  0.0058,  0.0288, -0.0037],
       dtype=torch.float64)


 80%|████████  | 8/10 [00:09<00:01,  1.06it/s]

None
None
tensor([-0.0052,  0.0112,  0.0352,  0.0212,  0.0045,  0.0368,  0.0383,  0.0067,
         0.1627,  0.0401,  0.0242,  0.1626,  0.0226,  0.0115,  0.2012,  0.0527],
       dtype=torch.float64)
None
None
tensor([-0.0059,  0.0426,  0.0318,  0.0642,  0.1543,  0.1282,  0.0360,  0.0194,
         0.0263,  0.0222,  0.0256,  0.0266,  0.1543,  0.0067,  0.1283,  0.0186],
       dtype=torch.float64)
None
None
tensor([-0.0030,  0.0253,  0.1443,  0.0110,  0.0926,  0.0303,  0.0305,  0.0332,
         0.0544,  0.0033,  0.1554,  0.1266,  0.0425,  0.0278,  0.0438,  0.0370],
       dtype=torch.float64)
None
None
tensor([0.1546, 0.0381, 0.0249, 0.1637, 0.1711, 0.0295, 0.0061, 0.0480, 0.0264,
        0.0399, 0.0096, 0.0179, 0.1503, 0.0015, 0.0157, 0.0336],
       dtype=torch.float64)
None
None
tensor([ 0.0014,  0.0132,  0.0200,  0.0345,  0.0289, -0.0011,  0.0301,  0.0643,
         0.0143,  0.0679,  0.0289,  0.0033,  0.1684,  0.0104,  0.0095,  0.0120],
       dtype=torch.float64)
None
None
tensor([ 0.

 90%|█████████ | 9/10 [00:11<00:01,  1.55s/it]

None
None
tensor([0.0228, 0.0382, 0.0389, 0.0387, 0.0315, 0.1495, 0.0400, 0.0196, 0.0307,
        0.0267, 0.0183, 0.0118, 0.0131, 0.0295, 0.0605, 0.0204],
       dtype=torch.float64)
None
None
tensor([0.0391, 0.1225, 0.0099, 0.0346, 0.0024, 0.0674, 0.0333, 0.0439, 0.1498,
        0.0723, 0.0339, 0.0073, 0.0304, 0.0220, 0.0433, 0.0347],
       dtype=torch.float64)


100%|██████████| 10/10 [00:12<00:00,  1.23s/it]

None
None
tensor([ 0.0328,  0.0244,  0.0048,  0.0423,  0.0226,  0.0402,  0.0299, -0.0032,
        -0.0005,  0.0433,  0.0240,  0.0202,  0.0290,  0.0129,  0.1408,  0.0386],
       dtype=torch.float64)
None
None
tensor([ 0.0054,  0.1618, -0.0038,  0.0229, -0.0043,  0.0418,  0.0019,  0.0103,
         0.0534,  0.0801, -0.0044,  0.0746,  0.0365,  0.0343,  0.0119,  0.0488],
       dtype=torch.float64)





In [None]:
# Trained circuit parameters (rows: alphas, betas, gammas of the rotation layer)
r.parameters

tensor([[-0.9693,  0.8310,  0.8167,  0.1281],
        [-0.9694, -0.5553, -0.0303, -0.3287],
        [-0.5342, -0.3335,  0.9402,  0.0235]], dtype=torch.float64,
       requires_grad=True)

In [None]:
# Greedy (epsilon=0) Q-value estimates for state 0, averaged over 100 shots
r.agent.get_q_values(0, r.parameters, num_iterations=100, epsilon=0)

[0.03896103896103896,
 0.2077922077922078,
 0.2987012987012987,
 0.45454545454545453]

In [None]:
# Target-network parameters: a detached copy of r.parameters, refreshed every update_target_every_n_steps steps
r.target_parameters

tensor([[-0.9693,  0.8310,  0.8167,  0.1281],
        [-0.9694, -0.5553, -0.0303, -0.3287],
        [-0.5342, -0.3335,  0.9402,  0.0235]], dtype=torch.float64)

In [None]:
# Sanity-check the raw environment API: with the installed gym version, step()
# returns a 4-tuple (observation, reward, done, info)
env = gym.make('FrozenLake-v1', is_slippery=False)
env.reset()
env.step(0)

(0, 0.0, False, {'prob': 1.0})