In [66]:
from prometheus_client import values
from torch.distributions import Normal

from models.actor import ValueNetwork, PolicyNetwork, Actor, Critic
from models.vae import ResVAE
import torch
import torch.nn as nn
import numpy as np
import gymnasium as gym
import pygame
from torchvision import transforms
from tqdm import tqdm
import matplotlib.pyplot as plt

In [67]:
def load_checkpoint(model, path, device):
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])

In [71]:
class PPO(nn.Module):
    def __init__(self, policy_net, value_net, state_dim, action_dim, gamma=0.99, batch_size=100, epsilone=0.2):
        super(PPO, self).__init__()
        self.policy_net = policy_net
        self.value_net = value_net
        self.transform = transforms.Compose([
                                             transforms.ToPILImage(),
                                             transforms.Resize(96),
                                             transforms.ToTensor(),
                                             ])

        self.gamma = gamma
        self.batch_size = batch_size
        self.epsilone = epsilone

        self.policy_opt = torch.optim.Adam(self.policy_net.parameters(), lr=3e-4)
        self.value_opt = torch.optim.Adam(self.value_net.parameters(), lr=3e-4)

    def get_action(self, state):
        state = self.transform(state)
        print(state)
        with torch.no_grad():
            mean, log_std = self.policy_net(state.unsqueeze(0))
        dist = Normal(mean, torch.exp(log_std))
        action = dist.sample()
        log_prob = dist.log_prob(action)
        return action.numpy(), log_prob.numpy().squeeze()

    def compute_returns(self, rewards, dones):
        returns = np.zeros_like(rewards)
        running_return = 0

        for t in reversed(range(len(rewards))):
            running_return = rewards[t] + self.gamma * running_return * (1 - dones[t])
            returns[t] = running_return

        return returns

    def fit(self, states, actions, rewards, dones, old_log_probs):
        states = np.array(states)
        actions = np.array(actions)
        rewards = np.array(rewards)
        dones = np.array(dones)
        old_log_probs = np.array(old_log_probs)

        returns = self.compute_returns(rewards, dones)

        states = torch.stack([self.transform(img) for img in states])

        actions_tensor = torch.FloatTensor(actions)
        returns_tensor = torch.FloatTensor(returns).unsqueeze(1)
        old_log_probs_tensor = torch.FloatTensor(old_log_probs).unsqueeze(1)

        n_samples = len(states)
        indices = np.arange(n_samples)
        np.random.shuffle(indices)

        all_values = self.value_net(states)

        advantages = returns_tensor - all_values.detach()

        for start in range(0, n_samples, self.batch_size):
            end = start + self.batch_size
            batch_indices = indices[start:end]

            # Батч данных
            b_states = states[batch_indices]
            b_actions = actions_tensor[batch_indices]
            b_returns = returns_tensor[batch_indices]
            b_old_log_probs = old_log_probs_tensor[batch_indices]
            b_advantage = advantages[batch_indices]

            b_mean, b_log_std = self.policy_net(b_states)
            b_dist = Normal(b_mean, torch.exp(b_log_std))
            b_new_log_probs = b_dist.log_prob(b_actions)

            if len(b_new_log_probs.shape) > 1:
                b_new_log_probs = b_new_log_probs.sum(dim=-1, keepdim=True)

            b_ratio = torch.exp(b_new_log_probs - b_old_log_probs)

            policy_loss_1 = b_ratio * b_advantage
            policy_loss_2 = torch.clamp(b_ratio, 1. - self.epsilone, 1. + self.epsilone) * b_advantage
            policy_loss = -torch.min(policy_loss_1, policy_loss_2).mean()

            self.policy_opt.zero_grad()
            policy_loss.backward()
            self.policy_opt.step()


            b_values = self.value_net(b_states)
            value_loss = torch.nn.functional.mse_loss(b_values, b_returns)

            self.value_opt.zero_grad()
            value_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.value_net.parameters(), 0.5)  # Clip gradients
            self.value_opt.step()



In [72]:
env = gym.make("CarRacing-v3", render_mode="rgb_array")

In [73]:
vae = ResVAE(3, hidden_dim=4096, z_dim=32)
load_checkpoint(vae, './checkpoints/checkpoint_epoch_20.pt', "cpu")
policy = PolicyNetwork(vae, Actor(32, 3))
critic = ValueNetwork(vae, Critic(32))

In [74]:
ppo = PPO(policy, critic, (96, 96, 3), 3)

In [75]:
state = env.reset()[0]

In [76]:
EPISODE_N = 5
SIM_NUM = 10

In [77]:
total_rewards = []
for episode in range(EPISODE_N):
    states, actions, rewards, dones, episode_log_probs = [], [], [], [], []

    for _ in tqdm(range(SIM_NUM)):
        total_reward = 0
        state = env.reset()[0]
        for t in range(2000):
            states.append(state)

            action, log_prob = ppo.get_action(state)
            actions.append(action)

            state, reward, done, _, _ = env.step(action[0])
            rewards.append(reward)
            dones.append(done)
            episode_log_probs.append(log_prob)

            total_reward += reward

        total_rewards.append(total_reward)

    ppo.fit(states, actions, rewards, dones, episode_log_probs)

plt.plot(total_rewards)

  0%|          | 0/10 [00:00<?, ?it/s]

tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]])
tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 

  0%|          | 0/10 [00:01<?, ?it/s]

tensor([[[0.3922, 0.3922, 0.3922,  ..., 0.3922, 0.3922, 0.3922],
         [0.3922, 0.3922, 0.3922,  ..., 0.3922, 0.3922, 0.3922],
         [0.3922, 0.3922, 0.3922,  ..., 0.3922, 0.3922, 0.3922],
         ...,
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

        [[0.7922, 0.7922, 0.7922,  ..., 0.7922, 0.7922, 0.7922],
         [0.7922, 0.7922, 0.7922,  ..., 0.7922, 0.7922, 0.7922],
         [0.7922, 0.7922, 0.7922,  ..., 0.7922, 0.7922, 0.7922],
         ...,
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

        [[0.3922, 0.3922, 0.3922,  ..., 0.3922, 0.3922, 0.3922],
         [0.3922, 0.3922, 0.3922,  ..., 0.3922, 0.3922, 0.3922],
         [0.3922, 0.3922, 0.3922,  ..., 0.3922, 0.3922, 0.




KeyboardInterrupt: 

In [64]:
env = gym.make("CarRacing-v3", render_mode="human")

In [65]:
total_rewards = []
for episode in range(EPISODE_N):
    #states, actions, rewards, dones, episode_log_probs = [], [], [], [], []

    for _ in tqdm(range(SIM_NUM)):
        total_reward = 0
        state = env.reset()[0]
        for t in range(2000):
            #states.append(state)

            action, log_prob = ppo.get_action(state)
            #actions.append(action)

            state, reward, done, _, _ = env.step(action[0])
            #rewards.append(reward)
            #dones.append(done)
            #episode_log_probs.append(log_prob)

            total_reward += reward

        total_rewards.append(total_reward)

    #ppo.fit(states, actions, rewards, dones, episode_log_probs)

plt.plot(total_rewards)

  0%|          | 0/10 [00:38<?, ?it/s]


KeyboardInterrupt: 