In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from actorcritic import Network, ActorCriticAgent, EnvironmentWrapper  
import gymnasium as gym

In [75]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cuda')

In [76]:
# Function to load a model
def load_model(model_path, input_size, output_size, hidden_sizes, is_policy, device=device):
    """Build a project ``Network`` and restore its weights from a saved state_dict.

    Args:
        model_path: Path of a file written with ``torch.save(model.state_dict(), ...)``.
        input_size, output_size, hidden_sizes, is_policy: Passed unchanged to
            ``Network``; they must match the architecture the checkpoint was saved from.
        device: Target device. The default is the module-level ``device`` as it was
            when this cell ran (Python evaluates default values at definition time).

    Returns:
        The ``Network`` with the loaded weights, moved to ``device``.
    """
    model = Network(input_size, output_size, hidden_sizes, is_policy)
    # map_location remaps tensors saved on another device (e.g. a GPU checkpoint
    # opened on a CPU-only machine) instead of failing inside torch.load.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints
    # (newer PyTorch releases accept weights_only=True to restrict this).
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    return model.to(device)

# Function to freeze a network
def freeze_network(network):
    """Disable gradient tracking for every parameter of ``network`` in place.

    Returns None; the module itself is mutated.
    """
    # nn.Module.requires_grad_ flips requires_grad on each of parameters().
    network.requires_grad_(False)


In [77]:
class Adapter(nn.Module):
    """Lateral adapter: one linear projection followed by a ReLU.

    Maps a source model's hidden activation of width ``input_size`` to
    ``output_size`` non-negative features. There is no hidden layer; add
    layers to ``adapter_layers`` if more capacity is needed.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        # The attribute name and Sequential indices determine the state_dict keys
        # ("adapter_layers.0.weight" / "adapter_layers.0.bias"); keep them stable
        # for checkpoint compatibility.
        projection = nn.Linear(input_size, output_size)
        self.adapter_layers = nn.Sequential(projection, nn.ReLU())

    def forward(self, x):
        """Return ``relu(x @ W.T + b)``, applied over the last dimension of ``x``."""
        adapted = self.adapter_layers(x)
        return adapted


In [99]:
class ProgressiveNetwork(nn.Module):
    """Trainable progressive-network column with lateral links to frozen source models.

    The column has exactly two hidden layers. Each source model's first hidden
    activation goes through its own ``Adapter`` and is concatenated with the
    column's first hidden activation before ``hidden_layer``. The same happens
    with the sources' second hidden activation before ``output_layer``.

    Args:
        input_size: Width of the last dimension of the input ``x``.
        output_size: Width of the output.
        hidden_sizes: Widths of the column's two hidden layers (exactly 2 entries).
        adapter_size: Output width of every lateral adapter.
        source_models: Previously trained models whose hidden activations are reused.
            Each must expose ``hidden_sizes`` and a sequential ``network``, where
            ``network[:2]`` produces the first hidden activation and ``network[2:4]``
            maps it to the second (presumably Linear/ReLU pairs -- TODO confirm
            against ``Network``). They are deliberately kept in a plain list, so they
            are NOT registered submodules: their weights are not in
            ``parameters()``/``state_dict()`` and ``.to()`` does not move them.
        is_policy: If True, ``forward`` returns a softmax over the last dimension.
        device: Device the column (including its adapters) is moved to.

    Raises:
        ValueError: If ``hidden_sizes`` does not have exactly two entries.
    """

    def __init__(self, input_size=6, output_size=3, hidden_sizes=[64, 64], adapter_size=2, source_models=[], is_policy=False, device=torch.device('cpu')):
        super().__init__()
        if len(hidden_sizes) != 2:
            raise ValueError(
                f"ProgressiveNetwork supports exactly 2 hidden layers, got hidden_sizes={hidden_sizes!r}"
            )
        # Copy the list arguments. The defaults are shared mutable objects, and
        # aliasing a caller's list would let later mutations leak into this model.
        self.source_models = list(source_models)
        self.hidden_sizes = list(hidden_sizes)
        self.is_policy = is_policy

        # One [first-layer adapter, second-layer adapter] pair per source model.
        # nn.ModuleList (not a plain Python list) registers the adapters. They then
        # appear in parameters() (so an optimizer built from them trains them), are
        # moved by .to(), and are saved in state_dict().
        # NOTE(review): checkpoints saved before adapters were registered have no
        # "adapters.*" keys; load those with load_state_dict(..., strict=False).
        self.adapters = nn.ModuleList([
            nn.ModuleList([
                Adapter(source_model.hidden_sizes[0], adapter_size),
                Adapter(source_model.hidden_sizes[1], adapter_size),
            ])
            for source_model in self.source_models
        ])

        # hidden_layer and output_layer also receive adapter_size features per source.
        lateral_size = len(self.source_models) * adapter_size
        self.input_layer = nn.Linear(input_size, hidden_sizes[0])
        self.hidden_layer = nn.Linear(hidden_sizes[0] + lateral_size, hidden_sizes[1])
        self.output_layer = nn.Linear(hidden_sizes[1] + lateral_size, output_size)

        # Move the whole column, not only the adapters, so it is usable immediately.
        self.to(device)

    def forward(self, x):
        """Run the column on ``x``.

        Returns softmax probabilities over the last dimension when ``is_policy``
        is True; otherwise returns the raw output-layer values.
        """
        # Hidden activations of every source model. Applying network[2:4] to the
        # result of network[:2] equals network[:4](x) without recomputing layer one.
        source_h1 = [model.network[:2](x) for model in self.source_models]
        source_h2 = [model.network[2:4](h) for model, h in zip(self.source_models, source_h1)]

        # This column's first hidden layer, widened with the adapted source features.
        h1 = F.relu(self.input_layer(x))
        lateral1 = [adapters[0](h) for adapters, h in zip(self.adapters, source_h1)]
        h2 = F.relu(self.hidden_layer(torch.cat([h1, *lateral1], dim=-1)))

        # Second hidden layer, widened the same way, feeds the output layer.
        lateral2 = [adapters[1](h) for adapters, h in zip(self.adapters, source_h2)]
        output = self.output_layer(torch.cat([h2, *lateral2], dim=-1))

        if self.is_policy:
            return F.softmax(output, dim=-1)
        return output


In [100]:
# Progressive policy: load the policy networks trained on the source tasks,
# freeze them, and build a new trainable column that reads them via adapters.
# NOTE(review): CartPole-v1 has 4-dim observations and 2 actions; the 6/3 sizes
# here presumably rely on EnvironmentWrapper padding to a shared space -- confirm.
source_networks = [
    load_model(path, input_size=6, output_size=3, hidden_sizes=[64, 64], is_policy=True)
    for path in (
        'models/Acrobot-v1_policy_network.pth',
        'models/MountainCarContinuous-v0_policy_network.pth',
    )
]

# Source columns stay fixed; only the new column (and its adapters) should learn.
for net in source_networks:
    freeze_network(net)

policy_network = ProgressiveNetwork(
    input_size=6,
    output_size=3,
    hidden_sizes=[64, 64],
    adapter_size=2,
    source_models=source_networks,
    is_policy=True,
    device=device,
)

In [101]:
# Progressive value function: same recipe as the policy, using the source tasks'
# value networks (single output, no softmax).
# NOTE(review): input_size=6 presumably matches EnvironmentWrapper's padded
# observation size for CartPole -- confirm.
source_networks = [
    load_model(path, input_size=6, output_size=1, hidden_sizes=[64, 64], is_policy=False)
    for path in (
        'models/Acrobot-v1_value_network.pth',
        'models/MountainCarContinuous-v0_value_network.pth',
    )
]

# Source columns stay fixed; only the new column (and its adapters) should learn.
for net in source_networks:
    freeze_network(net)

value_network = ProgressiveNetwork(
    input_size=6,
    output_size=1,
    hidden_sizes=[64, 64],
    adapter_size=2,
    source_models=source_networks,
    is_policy=False,
    device=device,
)

In [102]:
# hyperparameters
# state_size / action_size / hidden_sizes mirror the sizes passed to
# ProgressiveNetwork above (6, 3, [64, 64]); keep them in sync.
config = {
    'experiment': 'ProgressiveCartPole',
    # Use the GPU only when one is available, so `torch.device(config['device'])`
    # further down also works on CPU-only machines.
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'state_size': 6,
    'action_size': 3,
    'hidden_sizes': [64, 64],
    'lr_actor': 0.001,
    'lr_critic': 0.005,
    'verbosity': 10,
    'env_name': 'CartPole-v1',
    'gamma': 0.99,
    'reward_threshold': 475.0,
    'max_episodes': 2000,
    'max_steps': 500,
    'update_frequency': 500,
}

In [103]:
# Initialize the environment
# Create the environment named in config['env_name'] and wrap it with the project's
# EnvironmentWrapper. Its observation/action handling is not visible here -- it
# presumably adapts CartPole to state_size/action_size; TODO confirm.
env = gym.make(config['env_name'])
env_wrapper = EnvironmentWrapper(env)

# Initialize the ActorCriticAgent
# The agent's own networks are replaced by the progressive columns in the next cell.
agent = ActorCriticAgent(config)

In [104]:
# Swap the progressive columns in for the agent's default networks, on the device
# named in the config. This rebinds the global `device` chosen earlier.
# NOTE(review): if ActorCriticAgent builds its optimizers in __init__ from its own
# networks' parameters, those optimizers still point at the replaced networks, and
# the progressive columns are never updated. The flat average reward in the
# training log below would be consistent with that. TODO confirm, and rebuild the
# agent's optimizers after this swap if so.
device = torch.device(config['device'])
agent.policy_network, agent.value_network = (
    policy_network.to(device),
    value_network.to(device),
)

In [105]:
# Train on the wrapped environment; all run-length and update settings come from config.
results = agent.train(
    env_wrapper,
    **{
        key: config[key]
        for key in ('max_episodes', 'max_steps', 'reward_threshold', 'update_frequency')
    },
)

Episode 0, Avg Reward: 24.0, PLoss: 25.374746322631836, VLoss: 24.224069595336914
Episode 10, Avg Reward: 22.727272727272727, PLoss: 17.677387237548828, VLoss: 16.180065155029297
Episode 20, Avg Reward: 20.0, PLoss: 22.273283004760742, VLoss: 21.151987075805664
Episode 30, Avg Reward: 18.548387096774192, PLoss: 11.310555458068848, VLoss: 10.14748764038086
Episode 40, Avg Reward: 18.682926829268293, PLoss: 16.089935302734375, VLoss: 15.144342422485352
Episode 50, Avg Reward: 19.45098039215686, PLoss: 14.265267372131348, VLoss: 13.189603805541992
Episode 60, Avg Reward: 18.540983606557376, PLoss: 9.83056354522705, VLoss: 10.174091339111328
Episode 70, Avg Reward: 17.830985915492956, PLoss: 17.141817092895508, VLoss: 16.17665672302246
Episode 80, Avg Reward: 17.91358024691358, PLoss: 13.511711120605469, VLoss: 13.202524185180664
Episode 90, Avg Reward: 17.791208791208792, PLoss: 45.274593353271484, VLoss: 40.23586654663086
Episode 100, Avg Reward: 17.69, PLoss: 22.799264907836914, VLoss: 

In [106]:
# save the model
# Checkpoint names follow config['experiment'], using the same
# models/<name>_{policy,value}_network.pth layout the source models were loaded from,
# so renaming the experiment in the config also renames its checkpoints.
# NOTE(review): only registered submodules end up in state_dict(); the frozen source
# models held in ProgressiveNetwork's plain `source_models` list are not saved here.
torch.save(agent.policy_network.state_dict(), f"models/{config['experiment']}_policy_network.pth")
torch.save(agent.value_network.state_dict(), f"models/{config['experiment']}_value_network.pth")