In [1]:
import torch.nn.functional as F
import numpy as np
import torch
#import gym
import argparse
import os
import yaml
import torch.nn as nn
#from sh_tools4maniskill import TD3, New_Trans_RB, env_constructor
import copy
#########################################################
from collections import defaultdict

import mani_skill.envs
import gymnasium as gym
from mani_skill.utils.wrappers.flatten import FlattenRGBDObservationWrapper
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
########################################################
from torchvision.models import efficientnet_b0, mobilenet_v2
from torch.utils.tensorboard import SummaryWriter

device = "cuda:1" if torch.cuda.is_available() else "cpu"

  from .autonotebook import tqdm as notebook_tqdm


In [1]:

import gymnasium
gymnasium.__version__

'0.29.1'

In [3]:

import mujoco
mujoco.__version__

'2.3.3'

In [3]:

class StateDictWrapper(gym.ObservationWrapper):
    def __init__(self, env):
        super(StateDictWrapper, self).__init__(env)

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)  # Получаем obs и info от reset()
        # Преобразуем только obs в словарь с ключом 'state'
        return {'state': obs}, info

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)  # Получаем obs, reward, done, info и т.д.
        # Преобразуем только obs в словарь с ключом 'state'
        return {'state': obs}, reward, terminated, truncated, info 
    

def env_constructor(env_name, num_envs, obs_mode, reconf_freq=None):
    if obs_mode == 'state':
        env_kwargs = dict(obs_mode=obs_mode, sim_backend="gpu", control_mode="pd_joint_delta_pos")
        env = gym.make(env_name, num_envs=num_envs, reconfiguration_freq=reconf_freq, **env_kwargs)
        env = ManiSkillVectorEnv(env, num_envs, ignore_terminations=True, record_metrics=True)
        env = StateDictWrapper(env)
        s_d = env.observation_space.shape[-1]
        a_d = env.action_space.shape[-1]
        return env, s_d, a_d
    elif obs_mode == 'rgb':
        env_kwargs = dict(obs_mode=obs_mode, sim_backend="gpu", control_mode="pd_joint_delta_pos")
        env = gym.make(env_name, num_envs=num_envs, reconfiguration_freq=reconf_freq, **env_kwargs)
        env = FlattenRGBDObservationWrapper(env, rgb=True, depth=False, state=True)
        env = ManiSkillVectorEnv(env, num_envs=num_envs, ignore_terminations=True, record_metrics=True)
        s_d = env.observation_space['state'].shape[-1]
        a_d = env.action_space.shape[-1]
        return env, s_d, a_d
    elif obs_mode == 'rgbd':
        env_kwargs = dict(obs_mode=obs_mode, sim_backend="gpu", control_mode="pd_joint_delta_pos")
        env = gym.make(env_name, num_envs=num_envs, reconfiguration_freq=reconf_freq, **env_kwargs)
        env = FlattenRGBDObservationWrapper(env, rgb=True, depth=True, state=True)
        env = ManiSkillVectorEnv(env, num_envs=num_envs, ignore_terminations=True, record_metrics=True)
        s_d = env.observation_space['state'].shape[-1]
        a_d = env.action_space.shape[-1]
        return env, s_d, a_d


In [3]:

env, s_d, a_d = env_constructor('PushCube', 3, 'rgb', reconf_freq=None)
obs = env.reset()


  logger.warn(


In [2]:


class Actor(nn.Module):
    def __init__(self, d_model, state_dim, action_dim, max_action, obs_mode):
        super(Actor, self).__init__()
        
        in_channels = 3 if obs_mode=='rgb' else 4
        
        self.cnn = mobilenet_v2(pretrained=True)
                                                                  
        self.cnn_fc = nn.Linear(1000, d_model - d_model//4)

        self.state_fc = nn.Linear(state_dim, d_model//4)

        self.out_fc = nn.Linear(d_model, action_dim)
        
        self.max_action = max_action
        

    def forward(self, img, state):
        if len( img.shape ) == 5: # ne, bs, h, w, c
            ne, bs, h, w, c = img.shape[0], img.shape[1], img.shape[2], img.shape[3], img.shape[4]
            img = torch.permute(img, (0, 1, 4, 2, 3))
            img = img.reshape(ne*bs, c, h, w)
        img = self.cnn(img/255.0)
        img = img.reshape(ne, bs, 1000)
        img = self.cnn_fc( img ) # ne, bs, d_model-d_model//4
        
        state = self.state_fc(state) # ne, bs, d_model//4
        x = torch.cat([img, state],2) # ne, bs, d_model
        
        return self.max_action * torch.tanh(self.out_fc(x))

class Critic(nn.Module):
    def __init__(self, d_model, state_dim, action_dim, obs_mode):
        super(Critic, self).__init__()
        
        in_channels = 3 if obs_mode=='rgb' else 4
        self.cnn = mobilenet_v2(pretrained=True)                 
        self.cnn_fc = nn.Linear(1000, d_model - d_model//4)
        
        self.state_fc = nn.Linear(state_dim, d_model//4)
        
        self.action_fc = nn.Linear(action_dim, d_model)
        
        # Q1 architecture
        self.hidden_1 = nn.Linear(d_model*2, d_model)
        self.hidden_2 = nn.Linear(d_model*2, d_model)
        # Q2 architecture
        self.out_fc_1 = nn.Linear(d_model, 1)
        self.out_fc_2 = nn.Linear(d_model, 1)
        
    def forward(self, img, state, action):
        if len( img.shape ) == 5: # ne, bs, h, w, c
            ne, bs, h, w, c = img.shape[0], img.shape[1], img.shape[2], img.shape[3], img.shape[4]
            img = torch.permute(img, (0, 1, 4, 2, 3))
            img = img.reshape(ne*bs, c, h, w)
        img = self.cnn(img/255.0)
        img = img.reshape(ne, bs, 1000)
        img = self.cnn_fc( img ) # ne, bs, d_model-d_model//4
        
        state = self.state_fc(state) # ne, bs, d_model//4
        action = self.action_fc(action)
        
        sa = torch.cat([img, state, action],2) # ne, bs, d_model
        
        q1 = F.relu(self.hidden_1(sa))
        q1 = self.out_fc_1(q1)
        
        q2 = F.relu(self.hidden_2(sa))
        q2 = self.out_fc_2(q2)

        return q1, q2


    def Q1(self, img, state, action):
        
        if len( img.shape ) == 5: # ne, bs, h, w, c
            ne, bs, h, w, c = img.shape[0], img.shape[1], img.shape[2], img.shape[3], img.shape[4]
            img = torch.permute(img, (0, 1, 4, 2, 3))
            img = img.reshape(ne*bs, c, h, w)
        img = self.cnn(img/255.0)
        img = img.reshape(ne, bs, 1000)
        img = self.cnn_fc( img ) # ne, bs, d_model-d_model//4
        
        state = self.state_fc(state) # ne, bs, d_model//4
        action = self.action_fc(action)
        
        sa = torch.cat([img, state, action],2) # ne, bs, d_model
        
        q1 = F.relu(self.hidden_1(sa))
        q1 = self.out_fc_1(q1)
        
        return q1
class CustomTransformerEncoder(nn.Module):
    def __init__(self, d_model, n_heads, dim_feedforward, dropout, wo_ffn, norm_first, use_gate, gate_mode, mode, layer_num=None):
        super(CustomTransformerEncoder, self).__init__()
        
        self.norm_first = norm_first
        self.use_gate = use_gate
        self.wo_ffn = wo_ffn
        self.mode = mode
        
        if mode == 'Trans':
            self.self_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=n_heads, dropout=dropout, batch_first=True)

        
        self.layer_norm1 = nn.LayerNorm(d_model)
        if not self.wo_ffn:
            self.layer_norm2 = nn.LayerNorm(d_model)
        
        
        self.dropout1 = nn.Dropout(dropout)
        if not self.wo_ffn:
            self.dropout2 = nn.Dropout(dropout)
            self.dropout_ffn = nn.Dropout(dropout)
            self.linear1 = nn.Linear(d_model, dim_feedforward)
            self.linear2 = nn.Linear(dim_feedforward, d_model)
            self.relu = torch.nn.ReLU()

    def forward(self, src):                             # src = bs, seq_len, d_model
        
        skip_connection = src
        
        if self.norm_first:
            src = self.layer_norm1(src)                 #bs, seq_len, d_model
        
        
        if self.mode == 'Trans':
            src2, _ = self.self_attn(src, src, src)

        
        connection = skip_connection + self.dropout1(src2)
        
        if not self.norm_first:
            connection = self.layer_norm1(connection)

        if self.wo_ffn:
            return connection
        ###########FFN PART##############
        skip_connection2 = connection
        if self.norm_first:
            connection = self.layer_norm2(connection)
        
        src3 = self.linear2(self.dropout_ffn(self.relu(self.linear1(connection))))  #bs, seq_len, d_model
        
        if self.use_gate:
            connection2, percentage2 = self.gate(skip_connection2, self.relu(src3))  # ВОЗМОЖНО ПОСЛЕ RELU НАДО ТОЖЕ ДОБАВИТЬ ДРОПАУТ
        else: 
            connection2 = skip_connection2 + self.dropout2(src3)
        
        if not self.norm_first:
            connection2 = self.layer_norm2(connection2)
        
        return connection2#, (percentage1, percentage2)
    
class Trans_Critic(nn.Module):
    def __init__(self, state_dim, action_dim, d_model=256, num_heads=2, num_layers=1, obs_mode='rgb'):
        super(Trans_Critic, self).__init__()
        
        self.d_model = d_model
        in_channels = 3 if obs_mode=='rgb' else 4
        self.cnn = mobilenet_v2(pretrained=True)                 
        self.cnn_fc = nn.Linear(1000, d_model - d_model//4)
        self.state_fc = nn.Linear(state_dim, d_model//4)
        self.action_fc = nn.Linear(action_dim, d_model//2)
        self.transformer_encoder = CustomTransformerEncoder(d_model, num_heads, 512, 0.05, False, False, False, 'GRU', 'Trans')
        # Q1 and Q2
        self.out_fc_1 = nn.Linear(d_model+d_model//2, 1)
        self.out_fc_2 = nn.Linear(d_model+d_model//2, 1)


    def forward(self, img, state, action):
        n_e, bs, cont, s_d = state.shape
        state = state.view(-1, cont, s_d)
        img = img.view(-1, 3, 128, 128)
        state = self.state_fc(state)  # n_e*bs, cont, d_model//4
        img = self.cnn(img)            # n_e*bs*cont, 1000
        img = self.cnn_fc(img)          # n_e*bs*cont,  d_model-d_model//4
        img = img.view(-1, cont, self.d_model-self.d_model//4)  # n_e*bs, cont,  d_model-d_model//4
        
        x = torch.cat([img, state],-1)  # n_e*bs, cont, d_model
        transformer_out = self.transformer_encoder(x)  # n_e*bs, cont, d_model
        transformer_out = transformer_out[:, -1, :].view(n_e, bs, self.d_model)    # n_e, bs, d_model

        action = self.action_fc(action)                     # n_e, bs, d_model//2
        sa = torch.cat([transformer_out, action], dim=-1)   # n_e, bs, d_model+d_model//2
        
        q1 = self.out_fc_1(sa)
        q2 = self.out_fc_2(sa)
        return q1, q2

    def Q1(self, img, state, action):
        n_e, bs, cont, s_d = state.shape
        state = state.view(-1, cont, s_d)
        img = img.view(-1, 3, 128, 128)
        state = self.state_fc(state)  # n_e*bs, cont, d_model//4
        img = self.cnn(img)            # n_e*bs*cont, 1000
        img = self.cnn_fc(img)          # n_e*bs*cont,  d_model-d_model//4
        img = img.view(-1, cont, self.d_model-self.d_model//4)  # n_e*bs, cont,  d_model-d_model//4
        
        x = torch.cat([img, state],-1)  # n_e*bs, cont, d_model
        transformer_out = self.transformer_encoder(x)  # n_e*bs, cont, d_model
        transformer_out = transformer_out[:, -1, :].view(n_e, bs, self.d_model)    # n_e, bs, d_model

        action = self.action_fc(action)                     # n_e, bs, d_model//2
        sa = torch.cat([transformer_out, action], dim=-1)   # n_e, bs, d_model+d_model//2
        
        q1 = self.out_fc_1(sa)
        return q1

In [3]:
img = torch.randn(5, 16, 5, 128, 128, 3).to(device)
state = torch.randn(5, 16, 5, 25).to(device)
actions = torch.randn(5, 16, 8).to(device)

In [4]:
tr = Trans_Critic(25, 8, d_model=512, num_heads=2, num_layers=1, obs_mode='rgb').to(device)



In [6]:
q1 = tr.Q1(img, state, actions)

In [11]:
q1.shape, q2.shape

(torch.Size([5, 16, 1]), torch.Size([5, 16, 1]))

In [13]:
q1.shape

torch.Size([5, 16, 1])

In [8]:
actions = actor(img, state)

In [8]:
q1, q2 = critic(img, state, actions)

In [9]:
q1.shape

torch.Size([5, 32, 1])