In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from model import Model, CustomDataset
from utils import rgb_to_gray
import matplotlib.pyplot as plt
import gymnasium as gym
import numpy as np
import imageio
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")

In [2]:
class PongAgent:
    def __init__(self, hidden_conv_layer_dims, hidden_lin_layer_dims, game = "ALE/Pong-v5", epsilon = 0.5, gamma=0.99, epochs=25, num_actions=6) -> None:
        self.target_network = Model(hidden_conv_layer_dims, hidden_lin_layer_dims)
        self.target_network._save__model__('0')
        self.online_network = torch.load("epoch_0_model.pth")
        self.gamma = gamma
        self.epsilon = epsilon
        self.epochs = epochs
        self.decay = pow(100*self.epsilon, 1/self.epochs)
        self.num_skip_frame = 4
        self.num_actions = num_actions
        self.game = game
        self.epochs = epochs
    def evaluate(self, iter) -> None:
        env = gym.make("ALE/Pong-v5", render_mode="rgb_array")
        for ____ in range(1):
            obs, _ = env.reset()
            images = []
            reward = 0
            action = 0
            state = [rgb_to_gray(obs)]
            for __ in range(self.num_skip_frame-1):
                obs, rew, done, info, _ = env.step(action=action)
                images.append(env.render())
                if done:
                    break
                reward += rew
                obs = rgb_to_gray(obs)
                state.append(obs)
            if len(state) != 4:
                print("F")
                continue
            while not done:
                if state is None:
                    break
                Q_s = self.online_network(state).detach()
                action = torch.argmax(Q_s).detach().item()
                obs, rew, done, info, _ = env.step(action=action)
                images.append(env.render())
                reward += rew * (1-done)
                state = state[:-1].append(rgb_to_gray(obs))
            print(len(images))
            imageio.mimsave(f"./gifs/iter_{iter}_{____+1}.gif", images, duration=1)
        print(f"Average reward after {iter} iterations : {reward/5}")
        return reward/5
    
    def choose_action(self, Q_s) -> int:
        if np.random.random()<self.epsilon:
            return np.random.randint(0,self.num_actions)
        else:
            if not isinstance(Q_s, torch.Tensor):
                Q_s = torch.tensor(Q_s)
            return torch.argmax(Q_s).item()
        
    def _generate_training_data(self):
        replay_buffer_data = []
        replay_buffer_labels = []
        env = gym.make(self.game)
        num_episodes = 1
        for episode_no in tqdm(range(num_episodes), desc=f"Simulating episodes "):
            rewards = []
            obs, _ = env.reset()
            done = False
            action = 0
            while not done:
                reward = 0
                state = []
                for __ in range(self.num_skip_frame):
                    obs, rew, done, info, _ = env.step(action=action)
                    if done:
                        break
                    reward += rew
                    obs = rgb_to_gray(obs)
                    state.append(obs)
                if len(state) != 4:
                    break
                state = np.array(state)
                rewards.append(reward)
                Q_s = self.online_network(state[np.newaxis, :]).detach()
                action = self.choose_action(Q_s)
                replay_buffer_data.append((state, action))
            for i in range(len(rewards)-2, -1, -1):
                rewards[i] += self.gamma * rewards[i+1]
            replay_buffer_labels.extend(rewards)
        return replay_buffer_data, replay_buffer_labels
    
    def _train_agent(self):
        rewards = []
        for epoch in range(self.epochs):
            for __ in range(3):
                replay_buffer_data, replay_buffer_labels = self._generate_training_data()
                train_dataset = CustomDataset(replay_buffer_data, replay_buffer_labels)
                self.online_network._train__instance__(train_dataset=train_dataset)
            self.online_network._save__model__(str(epoch+1))
            del self.target_network
            del replay_buffer_data
            del replay_buffer_labels
            self.target_network = torch.load(f"epoch_{epoch+1}_model.pth")
            print("Model Updated and saved!")
            print(self.evaluate(epoch+1))
            # rewards.append(self.evaluate(epoch+1))
            self.epsilon = self.epsilon / self.decay
        # plt.plot(rewards)
        # plt.show()

In [3]:
hidden_conv_layer_dims = [(4, 16, 8, 3, 0), (16, 32, 5, 3, 2), (32, 64, 3, 2, 0)]
hidden_lin_layer_dims = [(1024, 256), (256, 32), (32, 6)]
agent = PongAgent(hidden_conv_layer_dims, hidden_lin_layer_dims, game="ALE/Pong-v5", gamma=0.99, epsilon=0.5, epochs=25, num_actions=6)
agent._train_agent()

A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]
Simulating episodes : 100%|██████████| 1/1 [00:00<00:00,  1.39it/s]
 20%|██        | 1/5 [00:00<00:01,  3.11it/s]

Epoch 1/5, Loss: 2015.3626299417851


 40%|████      | 2/5 [00:00<00:01,  2.63it/s]

Epoch 2/5, Loss: 1794.6113699496127


 60%|██████    | 3/5 [00:01<00:00,  2.47it/s]

Epoch 3/5, Loss: 2166.755949401349


 80%|████████  | 4/5 [00:01<00:00,  2.55it/s]

Epoch 4/5, Loss: 2430.8568893028555


100%|██████████| 5/5 [00:01<00:00,  2.52it/s]


Epoch 5/5, Loss: 2885.903962423922


Simulating episodes : 100%|██████████| 1/1 [00:00<00:00,  1.51it/s]
 20%|██        | 1/5 [00:00<00:01,  2.52it/s]

Epoch 1/5, Loss: 135.3876697869582


 40%|████      | 2/5 [00:00<00:01,  2.51it/s]

Epoch 2/5, Loss: 173.68698336977403


 60%|██████    | 3/5 [00:01<00:00,  2.62it/s]

Epoch 3/5, Loss: 165.98407338089078


 80%|████████  | 4/5 [00:01<00:00,  2.75it/s]

Epoch 4/5, Loss: 224.48913784476866


100%|██████████| 5/5 [00:01<00:00,  2.72it/s]


Epoch 5/5, Loss: 254.6912870917314


Simulating episodes : 100%|██████████| 1/1 [00:00<00:00,  1.34it/s]
 20%|██        | 1/5 [00:00<00:01,  2.64it/s]

Epoch 1/5, Loss: 4422.524444109222


 40%|████      | 2/5 [00:00<00:01,  2.71it/s]

Epoch 2/5, Loss: 3867.188726817224


 60%|██████    | 3/5 [00:01<00:00,  2.68it/s]

Epoch 3/5, Loss: 4314.18010147157


 80%|████████  | 4/5 [00:01<00:00,  2.70it/s]

Epoch 4/5, Loss: 3338.1718790562763


100%|██████████| 5/5 [00:01<00:00,  2.67it/s]

Epoch 5/5, Loss: 4328.017787305697
Model Updated and saved!
4
Average reward after 1 iterations : 0.0
0.0



Simulating episodes : 100%|██████████| 1/1 [00:00<00:00,  1.39it/s]
 20%|██        | 1/5 [00:00<00:01,  2.69it/s]

Epoch 1/5, Loss: 69.0498244900171


 40%|████      | 2/5 [00:00<00:01,  2.64it/s]

Epoch 2/5, Loss: 61.02939651999585


 60%|██████    | 3/5 [00:01<00:00,  2.66it/s]

Epoch 3/5, Loss: 81.1839470135679


 80%|████████  | 4/5 [00:01<00:00,  2.60it/s]

Epoch 4/5, Loss: 77.2955452364891


100%|██████████| 5/5 [00:01<00:00,  2.57it/s]


Epoch 5/5, Loss: 72.23258309706108


Simulating episodes : 100%|██████████| 1/1 [00:00<00:00,  1.33it/s]
 20%|██        | 1/5 [00:00<00:01,  2.40it/s]

Epoch 1/5, Loss: 1329.6862642048113


 40%|████      | 2/5 [00:00<00:01,  2.48it/s]

Epoch 2/5, Loss: 3166.1030010017125


 60%|██████    | 3/5 [00:01<00:00,  2.58it/s]

Epoch 3/5, Loss: 3284.561545381028


 80%|████████  | 4/5 [00:01<00:00,  2.60it/s]

Epoch 4/5, Loss: 2424.806686149934


100%|██████████| 5/5 [00:01<00:00,  2.61it/s]


Epoch 5/5, Loss: 1644.1288917437982


Simulating episodes : 100%|██████████| 1/1 [00:00<00:00,  1.45it/s]
 20%|██        | 1/5 [00:00<00:01,  2.90it/s]

Epoch 1/5, Loss: 706.1140447281407


 40%|████      | 2/5 [00:00<00:01,  2.78it/s]

Epoch 2/5, Loss: 972.904068884026


 60%|██████    | 3/5 [00:01<00:00,  2.73it/s]

Epoch 3/5, Loss: 734.1968633548588


 80%|████████  | 4/5 [00:01<00:00,  2.67it/s]

Epoch 4/5, Loss: 556.611030978049


100%|██████████| 5/5 [00:01<00:00,  2.75it/s]

Epoch 5/5, Loss: 658.4878891897868
Model Updated and saved!
4
Average reward after 2 iterations : 0.0
0.0



Simulating episodes : 100%|██████████| 1/1 [00:00<00:00,  1.39it/s]
 20%|██        | 1/5 [00:00<00:01,  2.19it/s]

Epoch 1/5, Loss: 2543.752980782602


 40%|████      | 2/5 [00:00<00:01,  2.47it/s]

Epoch 2/5, Loss: 1217.6571961908467


 60%|██████    | 3/5 [00:01<00:00,  2.53it/s]

Epoch 3/5, Loss: 757.0702598899986


 80%|████████  | 4/5 [00:01<00:00,  2.54it/s]

Epoch 4/5, Loss: 1217.1183562572346


100%|██████████| 5/5 [00:01<00:00,  2.55it/s]


Epoch 5/5, Loss: 828.3591374394623


Simulating episodes : 100%|██████████| 1/1 [00:00<00:00,  1.53it/s]
 20%|██        | 1/5 [00:00<00:01,  2.88it/s]

Epoch 1/5, Loss: 87.48060031073052


 40%|████      | 2/5 [00:00<00:01,  2.86it/s]

Epoch 2/5, Loss: 87.47551454191367


 60%|██████    | 3/5 [00:01<00:00,  2.88it/s]

Epoch 3/5, Loss: 65.6719926581019


 80%|████████  | 4/5 [00:01<00:00,  2.90it/s]

Epoch 4/5, Loss: 75.77019032804026


100%|██████████| 5/5 [00:01<00:00,  2.89it/s]


Epoch 5/5, Loss: 21.26876137365348


Simulating episodes : 100%|██████████| 1/1 [00:00<00:00,  1.47it/s]
 20%|██        | 1/5 [00:00<00:01,  2.52it/s]

Epoch 1/5, Loss: 2429.9111377818654


 40%|████      | 2/5 [00:00<00:01,  2.40it/s]

Epoch 2/5, Loss: 2766.1702856303623


 60%|██████    | 3/5 [00:01<00:00,  2.44it/s]

Epoch 3/5, Loss: 3351.8422273912993


 80%|████████  | 4/5 [00:01<00:00,  2.46it/s]

Epoch 4/5, Loss: 2127.080867606898


100%|██████████| 5/5 [00:02<00:00,  2.49it/s]

Epoch 5/5, Loss: 3276.2703585816807
Model Updated and saved!
4
Average reward after 3 iterations : 0.0
0.0



Simulating episodes : 100%|██████████| 1/1 [00:00<00:00,  1.32it/s]
 20%|██        | 1/5 [00:00<00:01,  2.43it/s]

Epoch 1/5, Loss: 247.80057636822661


 40%|████      | 2/5 [00:00<00:01,  2.38it/s]

Epoch 2/5, Loss: 320.1910284221909


 60%|██████    | 3/5 [00:01<00:00,  2.46it/s]

Epoch 3/5, Loss: 84.66874607698786


 80%|████████  | 4/5 [00:01<00:00,  2.58it/s]

Epoch 4/5, Loss: 42.941309977560614


100%|██████████| 5/5 [00:01<00:00,  2.57it/s]


Epoch 5/5, Loss: 93.85301141461159


Simulating episodes : 100%|██████████| 1/1 [00:00<00:00,  1.51it/s]
 20%|██        | 1/5 [00:00<00:01,  2.39it/s]

Epoch 1/5, Loss: 278.8064523891034


 40%|████      | 2/5 [00:00<00:01,  2.44it/s]

Epoch 2/5, Loss: 41.130782578663414


 60%|██████    | 3/5 [00:01<00:00,  2.61it/s]

Epoch 3/5, Loss: 270.08224951726186


 80%|████████  | 4/5 [00:01<00:00,  2.68it/s]

Epoch 4/5, Loss: 128.44346580642062


100%|██████████| 5/5 [00:01<00:00,  2.62it/s]


Epoch 5/5, Loss: 234.98021089638831


Simulating episodes : 100%|██████████| 1/1 [00:00<00:00,  1.23it/s]
 20%|██        | 1/5 [00:00<00:01,  2.33it/s]

Epoch 1/5, Loss: 1409.3727695672794


 40%|████      | 2/5 [00:00<00:01,  2.51it/s]

Epoch 2/5, Loss: 2007.706868483131


 60%|██████    | 3/5 [00:01<00:00,  2.71it/s]

Epoch 3/5, Loss: 1679.4863853147292


 80%|████████  | 4/5 [00:01<00:00,  2.72it/s]

Epoch 4/5, Loss: 1379.0895483857869


100%|██████████| 5/5 [00:01<00:00,  2.63it/s]

Epoch 5/5, Loss: 2033.5371708369844
Model Updated and saved!
4
Average reward after 4 iterations : 0.0
0.0



Simulating episodes : 100%|██████████| 1/1 [00:00<00:00,  1.39it/s]
 20%|██        | 1/5 [00:00<00:01,  2.69it/s]

Epoch 1/5, Loss: 754.0858466512291


 40%|████      | 2/5 [00:00<00:01,  2.81it/s]

Epoch 2/5, Loss: 1367.553886630259


 60%|██████    | 3/5 [00:01<00:00,  2.65it/s]

Epoch 3/5, Loss: 436.914886858054


 80%|████████  | 4/5 [00:01<00:00,  2.60it/s]

Epoch 4/5, Loss: 902.5602160281571


100%|██████████| 5/5 [00:01<00:00,  2.61it/s]


Epoch 5/5, Loss: 193.54012998965345


Simulating episodes : 100%|██████████| 1/1 [00:00<00:00,  1.44it/s]
 20%|██        | 1/5 [00:00<00:01,  2.57it/s]

Epoch 1/5, Loss: 2285.865378318392


 40%|████      | 2/5 [00:00<00:01,  2.62it/s]

Epoch 2/5, Loss: 3206.82499453593


 60%|██████    | 3/5 [00:01<00:00,  2.60it/s]

Epoch 3/5, Loss: 3088.784696349238


 80%|████████  | 4/5 [00:01<00:00,  2.52it/s]

Epoch 4/5, Loss: 2340.9303254494876


100%|██████████| 5/5 [00:01<00:00,  2.56it/s]


Epoch 5/5, Loss: 2058.0831080759112


Simulating episodes : 100%|██████████| 1/1 [00:00<00:00,  1.25it/s]
 20%|██        | 1/5 [00:00<00:01,  2.28it/s]

Epoch 1/5, Loss: 1308.5350407895462


 40%|████      | 2/5 [00:00<00:01,  2.37it/s]

Epoch 2/5, Loss: 2066.4500322212957


 60%|██████    | 3/5 [00:01<00:00,  2.38it/s]

Epoch 3/5, Loss: 1585.342835177496


 80%|████████  | 4/5 [00:01<00:00,  2.49it/s]

Epoch 4/5, Loss: 1625.5857462841998


100%|██████████| 5/5 [00:02<00:00,  2.44it/s]

Epoch 5/5, Loss: 2399.3318634159623
Model Updated and saved!
4
Average reward after 5 iterations : 0.0
0.0



Simulating episodes : 100%|██████████| 1/1 [00:00<00:00,  1.51it/s]
 20%|██        | 1/5 [00:00<00:01,  2.65it/s]

Epoch 1/5, Loss: 21493.070964642047


 40%|████      | 2/5 [00:00<00:01,  2.84it/s]

Epoch 2/5, Loss: 23652.954759696666


 60%|██████    | 3/5 [00:01<00:00,  2.59it/s]

Epoch 3/5, Loss: 16083.620363994414


 80%|████████  | 4/5 [00:01<00:00,  2.52it/s]

Epoch 4/5, Loss: 18549.053253669765


100%|██████████| 5/5 [00:01<00:00,  2.60it/s]


Epoch 5/5, Loss: 18508.661768955462


Simulating episodes : 100%|██████████| 1/1 [00:00<00:00,  1.43it/s]
 20%|██        | 1/5 [00:00<00:01,  2.67it/s]

Epoch 1/5, Loss: 169.68862654314228


 40%|████      | 2/5 [00:00<00:01,  2.68it/s]

Epoch 2/5, Loss: 109.98465268179066


 60%|██████    | 3/5 [00:01<00:00,  2.66it/s]

Epoch 3/5, Loss: 127.68745782823187


 80%|████████  | 4/5 [00:01<00:00,  2.61it/s]

Epoch 4/5, Loss: 116.54666602685182


100%|██████████| 5/5 [00:01<00:00,  2.65it/s]


Epoch 5/5, Loss: 103.68852769382222


Simulating episodes : 100%|██████████| 1/1 [00:00<00:00,  1.32it/s]
 20%|██        | 1/5 [00:00<00:01,  2.74it/s]

Epoch 1/5, Loss: 17919.77955653653


 40%|████      | 2/5 [00:00<00:01,  2.81it/s]

Epoch 2/5, Loss: 16085.166401887145


 60%|██████    | 3/5 [00:01<00:00,  2.59it/s]

Epoch 3/5, Loss: 16288.126062075342


 80%|████████  | 4/5 [00:01<00:00,  2.50it/s]

Epoch 4/5, Loss: 12628.541127851908


100%|██████████| 5/5 [00:01<00:00,  2.58it/s]

Epoch 5/5, Loss: 17008.233749237068
Model Updated and saved!
4
Average reward after 6 iterations : 0.0
0.0



Simulating episodes : 100%|██████████| 1/1 [00:00<00:00,  1.40it/s]
 20%|██        | 1/5 [00:00<00:01,  2.60it/s]

Epoch 1/5, Loss: 561.9238794325706


 40%|████      | 2/5 [00:00<00:01,  2.69it/s]

Epoch 2/5, Loss: 858.589204832757


 60%|██████    | 3/5 [00:01<00:00,  2.79it/s]

Epoch 3/5, Loss: 469.6599697544776


 80%|████████  | 4/5 [00:01<00:00,  2.80it/s]

Epoch 4/5, Loss: 609.3594398993431


100%|██████████| 5/5 [00:01<00:00,  2.76it/s]


Epoch 5/5, Loss: 556.3552872127152


Simulating episodes : 100%|██████████| 1/1 [00:00<00:00,  1.48it/s]
 20%|██        | 1/5 [00:00<00:01,  3.06it/s]

Epoch 1/5, Loss: 17495.503434696115


 40%|████      | 2/5 [00:00<00:00,  3.05it/s]

Epoch 2/5, Loss: 19050.238294672556


 60%|██████    | 3/5 [00:01<00:00,  2.90it/s]

Epoch 3/5, Loss: 15077.77709777763


 80%|████████  | 4/5 [00:01<00:00,  2.78it/s]

Epoch 4/5, Loss: 16013.96327716214


100%|██████████| 5/5 [00:01<00:00,  2.78it/s]


Epoch 5/5, Loss: 17118.97359332376


Simulating episodes : 100%|██████████| 1/1 [00:00<00:00,  1.30it/s]
 20%|██        | 1/5 [00:00<00:01,  2.35it/s]

Epoch 1/5, Loss: 10.062501206288431


 40%|████      | 2/5 [00:00<00:01,  2.55it/s]

Epoch 2/5, Loss: 54.150266489090114


 60%|██████    | 3/5 [00:01<00:00,  2.69it/s]

Epoch 3/5, Loss: 9.153297336182979


 80%|████████  | 4/5 [00:01<00:00,  2.71it/s]

Epoch 4/5, Loss: 4.006253375281093


100%|██████████| 5/5 [00:01<00:00,  2.68it/s]

Epoch 5/5, Loss: 87.01628492179381
Model Updated and saved!
4
Average reward after 7 iterations : 0.0
0.0



Simulating episodes : 100%|██████████| 1/1 [00:00<00:00,  1.47it/s]
 20%|██        | 1/5 [00:00<00:01,  2.71it/s]

Epoch 1/5, Loss: 1884.6014383339452


 40%|████      | 2/5 [00:00<00:01,  2.65it/s]

Epoch 2/5, Loss: 772.4909288444047


 60%|██████    | 3/5 [00:01<00:00,  2.68it/s]

Epoch 3/5, Loss: 1816.8432024067813


 80%|████████  | 4/5 [00:01<00:00,  2.65it/s]

Epoch 4/5, Loss: 534.0643079794232


100%|██████████| 5/5 [00:01<00:00,  2.68it/s]


Epoch 5/5, Loss: 358.2479325553297


Simulating episodes : 100%|██████████| 1/1 [00:00<00:00,  1.43it/s]
 20%|██        | 1/5 [00:00<00:01,  2.74it/s]

Epoch 1/5, Loss: 15146.324222830559


 40%|████      | 2/5 [00:00<00:01,  2.91it/s]

Epoch 2/5, Loss: 14478.51459465409


 60%|██████    | 3/5 [00:01<00:00,  2.95it/s]

Epoch 3/5, Loss: 17112.380659398787


 80%|████████  | 4/5 [00:01<00:00,  2.99it/s]

Epoch 4/5, Loss: 12745.892187804426


100%|██████████| 5/5 [00:01<00:00,  2.80it/s]


Epoch 5/5, Loss: 12389.089654557374


Simulating episodes : 100%|██████████| 1/1 [00:01<00:00,  1.25s/it]
 20%|██        | 1/5 [00:00<00:02,  1.98it/s]

Epoch 1/5, Loss: 242.05281615944386


 40%|████      | 2/5 [00:00<00:01,  2.04it/s]

Epoch 2/5, Loss: 207.25463887298918


 60%|██████    | 3/5 [00:01<00:00,  2.06it/s]

Epoch 3/5, Loss: 21.669477667490263


 80%|████████  | 4/5 [00:01<00:00,  2.03it/s]

Epoch 4/5, Loss: 115.8799694868161


100%|██████████| 5/5 [00:02<00:00,  2.01it/s]

Epoch 5/5, Loss: 11.931142315662111
Model Updated and saved!





4
Average reward after 8 iterations : 0.0
0.0


Simulating episodes : 100%|██████████| 1/1 [00:00<00:00,  1.14it/s]
 20%|██        | 1/5 [00:00<00:01,  2.33it/s]

Epoch 1/5, Loss: 11195.825024006459


 40%|████      | 2/5 [00:00<00:01,  2.44it/s]

Epoch 2/5, Loss: 13547.97915837806


 60%|██████    | 3/5 [00:01<00:00,  2.26it/s]

Epoch 3/5, Loss: 8780.501074263508


 80%|████████  | 4/5 [00:01<00:00,  2.19it/s]

Epoch 4/5, Loss: 8551.23674414961


100%|██████████| 5/5 [00:02<00:00,  2.18it/s]

Epoch 5/5, Loss: 9392.926313352804



Simulating episodes : 100%|██████████| 1/1 [00:00<00:00,  1.18it/s]
 20%|██        | 1/5 [00:00<00:01,  2.60it/s]

Epoch 1/5, Loss: 5.227209493903777


 40%|████      | 2/5 [00:00<00:01,  2.22it/s]

Epoch 2/5, Loss: 5.09191236664881


 60%|██████    | 3/5 [00:01<00:00,  2.12it/s]

Epoch 3/5, Loss: 2.7913726615027743


 80%|████████  | 4/5 [00:01<00:00,  2.01it/s]

Epoch 4/5, Loss: 7.255638096128928


100%|██████████| 5/5 [00:02<00:00,  2.04it/s]

Epoch 5/5, Loss: 3.4944777166454335



Simulating episodes : 100%|██████████| 1/1 [00:01<00:00,  1.44s/it]
 20%|██        | 1/5 [00:00<00:02,  1.74it/s]

Epoch 1/5, Loss: 1394.4909871436512


 40%|████      | 2/5 [00:01<00:01,  1.78it/s]

Epoch 2/5, Loss: 905.6746872941363


 60%|██████    | 3/5 [00:01<00:01,  1.72it/s]

Epoch 3/5, Loss: 1217.747195327642


 80%|████████  | 4/5 [00:02<00:00,  1.64it/s]

Epoch 4/5, Loss: 2810.3306396486614


100%|██████████| 5/5 [00:03<00:00,  1.63it/s]

Epoch 5/5, Loss: 432.5402417653638
Model Updated and saved!
4
Average reward after 9 iterations : 0.0
0.0



Simulating episodes : 100%|██████████| 1/1 [00:00<00:00,  1.54it/s]
 20%|██        | 1/5 [00:00<00:01,  3.02it/s]

Epoch 1/5, Loss: 13661.130759533393


 40%|████      | 2/5 [00:00<00:01,  2.94it/s]

Epoch 2/5, Loss: 15970.118427394726


 60%|██████    | 3/5 [00:01<00:00,  2.84it/s]

Epoch 3/5, Loss: 14897.411305402497


 80%|████████  | 4/5 [00:01<00:00,  2.86it/s]

Epoch 4/5, Loss: 22183.87821625021


100%|██████████| 5/5 [00:01<00:00,  2.89it/s]


Epoch 5/5, Loss: 16264.840461765263


Simulating episodes : 100%|██████████| 1/1 [00:00<00:00,  1.45it/s]
 20%|██        | 1/5 [00:00<00:01,  2.77it/s]

Epoch 1/5, Loss: 8144.530910818903


 40%|████      | 2/5 [00:00<00:01,  2.70it/s]

Epoch 2/5, Loss: 7284.804992860859


 60%|██████    | 3/5 [00:01<00:00,  2.58it/s]

Epoch 3/5, Loss: 15918.143839012231


 80%|████████  | 4/5 [00:01<00:00,  2.68it/s]

Epoch 4/5, Loss: 12778.356423061185


 80%|████████  | 4/5 [00:01<00:00,  2.36it/s]


KeyboardInterrupt: 

In [None]:
x = torch.tensor([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12], [13, 14, 15, 16, 17, 18]])
ac = torch.tensor([0, 2])
ac.unsqueeze(1)
print(x)
y = x.gather(1, ac.unsqueeze(1))
print(len(y.shape))

tensor([[ 1,  2,  3,  4,  5,  6],
        [ 7,  8,  9, 10, 11, 12],
        [13, 14, 15, 16, 17, 18]])
2
