In [2]:
from algorithm.ppo import Learner
from configs.args_parser import Parameters
from configs import Factor_dictionary
from algorithm.analysis import MO_Analysis
from algorithm.utils import MO_Stats
# import mo_gymnasium as mo_gym 
from torch.utils.tensorboard import SummaryWriter
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from io import BytesIO
import os
import io
env_id = 'minecart-v0'


  from .autonotebook import tqdm as notebook_tqdm


In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

def constant_init(module, val, bias=0):
    """Fill a module's ``weight`` with ``val`` and its ``bias`` with ``bias``.

    Each tensor is skipped when the module does not expose the attribute or
    the attribute is ``None`` (e.g. a layer built with ``bias=False``).
    """
    weight_param = getattr(module, 'weight', None)
    if weight_param is not None:
        nn.init.constant_(weight_param, val)

    bias_param = getattr(module, 'bias', None)
    if bias_param is not None:
        nn.init.constant_(bias_param, bias)

class Block(nn.Module):
    """
    Tanh MLP used for both the trainable block and the locked block.

    Layout: ``Linear(input_dim, hidden_size[0]) -> Tanh`` followed by one
    ``Linear -> Tanh`` pair per remaining hidden width (``fc_linear``), then a
    final ``Linear(hidden_size[-1], output_dim)`` with no activation
    (``last_layer``).

    Args:
        input_dim: number of input features.
        output_dim: number of output features.
        hidden_size: non-empty sequence of hidden layer widths.

    Raises:
        ValueError: if ``hidden_size`` is empty.
    """
    def __init__(self, input_dim, output_dim, hidden_size=(512, 512)):
        super().__init__()
        # Copy the widths: storing the caller's (or a shared default) list by
        # reference would let a mutation on one instance leak into others.
        hidden_size = list(hidden_size)
        if not hidden_size:
            raise ValueError("hidden_size must contain at least one layer width")

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_size = hidden_size

        # Consecutive (in, out) width pairs: input -> h0 -> h1 -> ...
        widths = [input_dim] + hidden_size
        layers = []
        for in_features, out_features in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(in_features, out_features))
            layers.append(nn.Tanh())
        self.fc_linear = nn.Sequential(*layers)

        self.last_dim = hidden_size[-1]
        self.last_layer = nn.Linear(hidden_size[-1], output_dim)

    def forward(self, x):
        """Apply the hidden Tanh stack, then the linear output projection."""
        hidden = self.fc_linear(x)
        return self.last_layer(hidden)

class ControlBlock(nn.Module):
    """
    Pair of identically shaped blocks (one locked, one trainable) joined by
    two zero-initialised adapter layers that also see ``extra_input``.

    Forward computation::

        delta_x = zero_layer1(cat[x, extra_input])
        y_      = trainable_block(x + delta_x)
        delta_y = zero_layer2(cat[y_, extra_input])
        out     = locked_block(x) + delta_y

    Because the last Linear of each zero layer starts at zero, a freshly
    constructed ControlBlock returns exactly ``locked_block(x)``.

    Args:
        input_dim: feature width of ``x``.
        feature_dim: feature width of ``extra_input``.
        output_dim: feature width of the result.
        hidden_size: hidden widths shared by both blocks; ``hidden_size[0]``
            is also the hidden width of each zero layer.
        allow_retrain: whether the locked block's parameters receive gradients.
    """
    def __init__(self, input_dim, feature_dim, output_dim, hidden_size=(512, 512), allow_retrain=False):
        super().__init__()
        hidden_size = list(hidden_size)
        #! trainable and locked block must have the same size
        # Each block gets its own copy of the widths so neither aliases the other.
        self.trainable_block = Block(input_dim, output_dim, list(hidden_size))
        self.locked_block = Block(input_dim, output_dim, list(hidden_size))

        ## TODO change 1: the zero layers need state or style information
        self.zero_layer1 = nn.Sequential(
            nn.Linear(feature_dim + input_dim, hidden_size[0]),
            nn.LeakyReLU(),
            nn.Linear(hidden_size[0], input_dim)
        )
        self.zero_layer2 = nn.Sequential(
            nn.Linear(output_dim + feature_dim, hidden_size[0]),
            nn.LeakyReLU(),
            nn.Linear(hidden_size[0], output_dim)
        )
        self.allow_retrain = allow_retrain
        self.init()
        self._set_parameter()

    def _set_parameter(self):
        """Mark the trainable block as trainable and the locked block per ``allow_retrain``."""
        for p in self.trainable_block.parameters():
            p.requires_grad = True
        for p in self.locked_block.parameters():
            p.requires_grad = self.allow_retrain

    def init(self):
        """Zero the output Linear of both zero layers.

        The zero layers are ``nn.Sequential`` containers and have no
        ``weight``/``bias`` of their own, so the initialiser must target a
        Linear inside them. Only the last Linear is zeroed: that makes the
        adapter output zero at start while its weight still receives a
        non-zero gradient from the (randomly initialised) hidden layer.
        Zeroing both Linears would leave only the final bias trainable.
        """
        for zero_layer in (self.zero_layer1, self.zero_layer2):
            constant_init(zero_layer[-1], val=0.0)

    def forward(self, x, extra_input):
        """Return the locked block's output plus the conditioned correction.

        ``x`` and ``extra_input`` are concatenated along dim 1, so both must be
        2-D with the same leading (batch) size.
        """
        delta_x = self.zero_layer1(torch.cat([x, extra_input], dim=1))
        y_ = self.trainable_block(x + delta_x)
        delta_y = self.zero_layer2(torch.cat([y_, extra_input], dim=1))
        # The base signal comes from the locked copy; the trainable copy only
        # contributes through the zero-initialised correction.
        return self.locked_block(x) + delta_y

    def load_expert_state_dict(self, state_dict):
        """Load the same ``Block`` state dict into both copies and re-apply the grad flags."""
        self.locked_block.load_state_dict(state_dict)
        self.trainable_block.load_state_dict(state_dict)
        self._set_parameter()

class ExpertNet(nn.Module):
    """
    Expert actor-critic network without style conditioning.

    A linear state encoder feeds a shared ``Block`` trunk, which branches into
    a policy ``Block`` + head and a value ``Block`` + head.

    Args:
        state_dim: width of the state part of the input (input to the encoder).
        style_dim: accepted for signature parity with ``PolicyNet``; not used
            by this class.
        num_acts: number of discrete actions (policy head width).
    """
    def __init__(self, state_dim, style_dim, num_acts):
        super().__init__()
        self.state_encoder = nn.Linear(state_dim, 512)
        self.block1 = Block(512, 512, hidden_size=[512, 512])

        self.policy_block = Block(512, 256, hidden_size=[512, 512])
        self.value_block = Block(512, 256, hidden_size=[512, 512])

        self.policy_head = nn.Sequential(
            nn.LeakyReLU(),
            nn.Linear(256, num_acts)
        )

        self.value_head = nn.Sequential(
            nn.LeakyReLU(),
            nn.Linear(256, 1)
        )

        self.softmax = nn.Softmax(dim=1)

    def init(self):
        """Re-initialise weights: Kaiming-normal for encoder/blocks, orthogonal for heads.

        Not called by ``__init__``; invoke explicitly when this scheme is wanted.
        Biases are left untouched.
        """
        nn.init.kaiming_normal_(self.state_encoder.weight)
        # Blocks are containers without a ``weight`` of their own, so walk
        # their Linear layers instead of touching the Block directly.
        for block in (self.block1, self.policy_block, self.value_block):
            for layer in block.modules():
                if isinstance(layer, nn.Linear):
                    nn.init.kaiming_normal_(layer.weight)
        for head in (self.policy_head, self.value_head):
            for layer in head:
                if isinstance(layer, (nn.Conv2d, nn.Linear)):
                    nn.init.orthogonal_(layer.weight)

    def forward(self, state, ma):
        """Compute masked action probabilities, log-probabilities and the value.

        Args:
            state: 2-D tensor; the last 6 columns are dropped before encoding.
            ma: action mask passed to ``masked_fill``; True/1 means invalid.

        Returns:
            ``(probs, log_probs, value)``.
        """
        # ma: action mask: 1 means invalid!!
        # NOTE(review): the trailing width is hard-coded to 6; style_dim is not consulted.
        state_ = state[:, :-6]
        state_emb = self.state_encoder(state_)

        mid = self.block1(state_emb)
        policy_mid = self.policy_block(mid)
        value_mid = self.value_block(mid)

        policy_out = self.policy_head(policy_mid)
        # Invalid actions get -inf logits so they receive zero probability.
        y = policy_out.masked_fill(ma, -np.inf)
        probs = self.softmax(y - y.max(1)[0].unsqueeze(1))
        log_probs = F.log_softmax(y, dim=1)
        value = self.value_head(value_mid)
        return probs, log_probs, value

    def save_expert_state_dict(self):
        """Return a dict of per-submodule state dicts keyed by ``<name>_state_dict``."""
        expert_state = {
            "state_encoder_state_dict": self.state_encoder.state_dict(),
            "block1_state_dict": self.block1.state_dict(),
            "policy_block_state_dict": self.policy_block.state_dict(),
            "value_block_state_dict": self.value_block.state_dict(),
            "policy_head_state_dict": self.policy_head.state_dict(),
            "value_head_state_dict": self.value_head.state_dict(),
        }
        return expert_state

class PolicyNet(nn.Module):
    """
    Style-conditioned actor-critic network built from ``ControlBlock`` stages.

    The input is split into a state part and a trailing style part of width
    ``style_dim``; each is linearly encoded to 512 features. The style
    embedding is passed as ``extra_input`` to every ``ControlBlock``.

    Args:
        state_dim: width of the state part of the input.
        style_dim: width of the trailing style part of the input.
        num_acts: number of discrete actions (policy head width).
        allow_retrain: forwarded to every ``ControlBlock``.
    """
    def __init__(self, state_dim, style_dim, num_acts, allow_retrain=False):
        super().__init__()
        # Kept so forward() splits the input by the same width the style
        # encoder was built for.
        self.style_dim = style_dim
        #! Encoder: state style
        self.state_encoder = nn.Linear(state_dim, 512)
        self.style_encoder = nn.Linear(style_dim, 512)

        self.block1 = ControlBlock(512, 512, 512, [512, 512], allow_retrain=allow_retrain)

        self.policy_block = ControlBlock(512, 512, 256, [512, 512], allow_retrain=allow_retrain)
        self.value_block = ControlBlock(512, 512, 256, [512, 512], allow_retrain=allow_retrain)

        self.policy_head = nn.Sequential(
            nn.LeakyReLU(),
            nn.Linear(256, num_acts)
        )

        self.value_head = nn.Sequential(
            nn.LeakyReLU(),
            nn.Linear(256, 1)
        )
        self.softmax = nn.Softmax(dim=1)

    def init(self):
        """Re-initialise weights: Kaiming-normal for both encoders, orthogonal for heads.

        Not called by ``__init__``; the ControlBlocks are left untouched.
        """
        for encoder in (self.state_encoder, self.style_encoder):
            nn.init.kaiming_normal_(encoder.weight)
        for head in (self.policy_head, self.value_head):
            for layer in head:
                if isinstance(layer, (nn.Conv2d, nn.Linear)):
                    nn.init.orthogonal_(layer.weight)

    def forward(self, state, ma):
        """Compute masked action probabilities, log-probabilities and the value.

        Args:
            state: 2-D tensor whose last ``style_dim`` columns are the style
                vector and whose remaining columns are the state.
            ma: action mask passed to ``masked_fill``; True/1 means invalid.

        Returns:
            ``(probs, log_probs, value)``.
        """
        # ma: action mask: 1 means invalid!!
        # Split by column index rather than a negative slice so the split
        # follows style_dim (and stays well-defined for style_dim == 0).
        split = state.size(1) - self.style_dim
        state_ = state[:, :split]
        style = state[:, split:]
        state_emb = self.state_encoder(state_)
        style_emb = self.style_encoder(style)

        mid = self.block1(state_emb, style_emb)
        policy_mid = self.policy_block(mid, style_emb)
        value_mid = self.value_block(mid, style_emb)

        policy_out = self.policy_head(policy_mid)
        # Invalid actions get -inf logits so they receive zero probability.
        y = policy_out.masked_fill(ma, -np.inf)
        probs = self.softmax(y - y.max(1)[0].unsqueeze(1))
        log_probs = F.log_softmax(y, dim=1)
        value = self.value_head(value_mid)
        return probs, log_probs, value

    def load_expert_state_dict(self, state_dict: dict):
        """Load whichever expert sub-state-dicts are present in ``state_dict``.

        Encoder and heads are loaded directly; each ControlBlock receives the
        expert ``Block`` weights through its own ``load_expert_state_dict``.
        Missing keys are skipped silently; the style encoder is never loaded.
        """
        #! allow retrain
        direct_targets = (
            ('state_encoder_state_dict', self.state_encoder),
            ('policy_head_state_dict', self.policy_head),
            ('value_head_state_dict', self.value_head),
        )
        for key, module in direct_targets:
            if key in state_dict:
                module.load_state_dict(state_dict[key])

        control_targets = (
            ('block1_state_dict', self.block1),
            ('policy_block_state_dict', self.policy_block),
            ('value_block_state_dict', self.value_block),
        )
        for key, control_block in control_targets:
            if key in state_dict:
                control_block.load_expert_state_dict(state_dict[key])
            
        

In [14]:
# Build the style-conditioned policy and the plain expert with matching sizes.
mas = PolicyNet(state_dim=20, style_dim=6, num_acts=4)
expert = ExpertNet(state_dim=20, style_dim=6, num_acts=4)

In [5]:
# torch.jit.save does not create missing directories, so make sure tmp/ exists
# before exporting (otherwise this cell fails on a fresh checkout).
os.makedirs('tmp', exist_ok=True)

# Compile both networks to TorchScript and write them to disk.
expert_jit = torch.jit.script(expert)
mas_jit = torch.jit.script(mas)
torch.jit.save(expert_jit, 'tmp/expert_jit')
torch.jit.save(mas_jit, 'tmp/mas_jit')


In [18]:
# Export the expert's per-submodule state dicts and copy them into the policy net.
expert_state_dict = expert.save_expert_state_dict()
mas.load_expert_state_dict(expert_state_dict)

In [20]:
# torch.save does not create missing directories; ensure tmp/ exists first.
os.makedirs('tmp', exist_ok=True)
torch.save(expert_state_dict, 'tmp/expert_state_dict')

In [21]:
import glob

# glob returns entries in filesystem order, which is not stable across
# machines; sort so that positional indexing into the list is reproducible.
model_files = sorted(glob.glob('tmp/*'))

In [23]:
# tmp/ also contains the TorchScript archives (expert_jit, mas_jit), which are
# not state-dict mappings, so pick the checkpoint by name instead of by position.
state_dict_files = [path for path in model_files if os.path.basename(path) == 'expert_state_dict']
if not state_dict_files:
    raise FileNotFoundError("no 'expert_state_dict' checkpoint found in model_files")
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
sd = torch.load(state_dict_files[-1])
mas.load_expert_state_dict(sd)