In [None]:
import torch
import copy

In [None]:
class DQN:
    """Deep Q-network agent with a prediction network and a target network.

    ``model`` is the prediction network; it is the only network updated by
    ``train_one_step``. ``model2`` is the target network: it starts as an exact
    copy of ``model`` and changes only when ``update_target`` is called.
    ``get_maxQ`` reads from the target network, ``get_qvals`` from the
    prediction network.
    """

    def __init__(self, state_size, action_size=4, learning_rate=0.001, hidden_size=24):
        """Build both networks, the MSE loss and an Adam optimizer.

        Args:
            state_size: Number of input features per state.
            action_size: Number of discrete actions (one Q-value output per action).
            learning_rate: Adam step size for the prediction network.
            hidden_size: Width of each of the two hidden layers.
        """
        self.model = torch.nn.Sequential(
            torch.nn.Linear(state_size, hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_size, hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_size, action_size),
        )
        # deepcopy duplicates the weights, so the target network starts
        # identical to the prediction network.
        self.model2 = copy.deepcopy(self.model)
        self.loss_fn = torch.nn.MSELoss()
        self.learning_rate = learning_rate
        # Only the prediction network is optimised; model2 changes solely via update_target().
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)

    def update_target(self):
        """Copy the prediction network's weights into the target network.

        Call this at regular intervals during training.
        """
        self.model2.load_state_dict(self.model.state_dict())

    def get_qvals(self, state):
        """Return the prediction network's Q-values for one state.

        Args:
            state: Array-like state (e.g. a NumPy array or tensor). A leading
                batch dimension of size 1 is prepended before the forward pass.

        Returns:
            A float32 tensor without gradient tracking; for a 1-D state its
            shape is ``(1, action_size)``. Because no graph is attached,
            ``.numpy()`` can be called on it directly.
        """
        state = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
        # Q-values here are only used for acting, so recording an autograd
        # graph would just waste memory.
        with torch.no_grad():
            return self.model(state)

    def get_maxQ(self, state):
        """Return the largest Q-value the target network assigns to ``state``.

        Args:
            state: Array-like state (NumPy array or tensor), in the same form
                accepted by ``get_qvals``.

        Returns:
            The maximum over all output Q-values as a Python float.
        """
        # Accept NumPy input like get_qvals does; tensors pass through unchanged
        # (apart from a cast to float32).
        state = torch.as_tensor(state, dtype=torch.float32)
        # The result feeds a TD target, which must not be differentiated through.
        with torch.no_grad():
            q_values = self.model2(state)
        return torch.max(q_values).item()

    def train_one_step(self, states, actions, targets):
        """Perform one gradient step on the prediction network.

        Args:
            states: Sequence of states. Each is either 1-D ``(state_size,)`` or
                already batched as ``(1, state_size)``.
            actions: Parallel sequence of integer action indices taken in each state.
            targets: Parallel sequence of TD targets (already discounted) for
                the taken actions.

        Returns:
            The minibatch MSE loss as a Python float (useful for debugging).
        """
        # atleast_2d lets both (state_size,) and (1, state_size) states stack
        # into a (batch, state_size) matrix.
        state_batch = torch.cat(
            [torch.atleast_2d(torch.as_tensor(s, dtype=torch.float32)) for s in states]
        )
        action_batch = torch.as_tensor(actions, dtype=torch.long).reshape(-1, 1)
        # squeeze(1) (not squeeze()) keeps shape (batch,) even when batch == 1.
        predicted = self.model(state_batch).gather(dim=1, index=action_batch).squeeze(1)
        # Targets are constants for this step; detach so no gradient flows into them.
        expected = torch.as_tensor(targets, dtype=torch.float32).reshape(-1).detach()
        loss = self.loss_fn(predicted, expected)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()