In [1]:
import torch
import torch.nn as nn
from actorcritic import Network  # Replace with the actual import for your Network class



In [None]:
# Function to load a model
def load_model(model_path, input_size, output_size, hidden_sizes, is_policy, map_location="cpu"):
    """Build a ``Network`` and fill it with weights from a saved state dict.

    Args:
        model_path: Path to a checkpoint file containing a state dict.
        input_size: Forwarded unchanged to the ``Network`` constructor.
        output_size: Forwarded unchanged to the ``Network`` constructor.
        hidden_sizes: Forwarded unchanged to the ``Network`` constructor.
        is_policy: Forwarded unchanged to the ``Network`` constructor.
        map_location: Forwarded to ``torch.load``. Defaults to ``"cpu"`` so a
            checkpoint that was saved from a CUDA device can still be read on
            a machine without a GPU. ``load_state_dict`` copies the values
            into the parameters the freshly built model already owns, so this
            does not change which device the returned model lives on.

    Returns:
        The constructed network with the checkpoint's weights loaded.
    """
    model = Network(input_size, output_size, hidden_sizes, is_policy)
    # NOTE(review): torch.load unpickles the file. Only load checkpoints from
    # a trusted source (or pass weights_only=True on PyTorch versions that
    # support it).
    state_dict = torch.load(model_path, map_location=map_location)
    model.load_state_dict(state_dict)
    return model

# Function to freeze a network
def freeze_network(network):
    """Turn off gradient tracking for every parameter of ``network``.

    The parameters are modified in place; nothing is returned.
    """
    for weight in network.parameters():
        # In-place flag flip; equivalent to assigning requires_grad = False.
        weight.requires_grad_(False)






In [None]:
# Modified Network class to accept additional inputs from source networks
class ProgressiveNetwork(nn.Module):
    """MLP whose input is augmented with hidden features from source networks.

    Every source network is fed the same input ``x`` as this network. The
    activations of its top hidden layer are detached and concatenated to ``x``
    along the last dimension before the result goes through this network's
    own layers.

    Args:
        input_size: Number of features in ``x``.
        output_size: Number of outputs of the final linear layer.
        hidden_sizes: Widths of the hidden layers, in order. May be empty, in
            which case the model is a single linear layer.
        source_networks: Optional iterable of modules exposing a ``network``
            attribute that is an ``nn.Sequential``. ``None`` (the default)
            means no sources.
        is_policy: When true, a softmax over the last dimension is appended.

    Notes:
        * The sources are kept in a plain Python list, so they are NOT
          registered as sub-modules: they do not appear in ``parameters()`` or
          ``state_dict()`` and are not moved by ``.to(device)``.
        * Each source receives ``x`` unchanged, so it must accept inputs with
          ``input_size`` features.
    """

    def __init__(self, input_size, output_size, hidden_sizes, source_networks=None, is_policy=True):
        super(ProgressiveNetwork, self).__init__()
        self.is_policy = is_policy
        # A fresh list per instance: a mutable default argument would be
        # shared by every instance created without explicit sources.
        self.source_networks = list(source_networks) if source_networks is not None else []

        # Width of the first layer = own input + every source's lateral
        # features, measured on the source itself rather than assumed to
        # equal this network's last hidden width.
        lateral_size = sum(
            self._lateral_feature_size(source, hidden_sizes[-1])
            for source in self.source_networks
        )
        adjusted_input_size = input_size + lateral_size

        # Building the layers (same creation order as before, so seeded
        # initialisation and state_dict keys are unchanged).
        sizes = [adjusted_input_size, *hidden_sizes]
        layers = []
        for fan_in, fan_out in zip(sizes[:-1], sizes[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(sizes[-1], output_size))

        self.network = nn.Sequential(*layers)
        if is_policy:
            self.network.add_module("softmax", nn.Softmax(dim=-1))

    @staticmethod
    def _feature_extractor(source):
        """Return the part of ``source.network`` that ends at its top hidden layer."""
        stack = source.network
        # Drop the output layer...
        end = len(stack) - 1
        # ...and, for policy networks, the softmax that sits after it.
        # Slicing with [:-1] alone would return the logits instead of the
        # hidden activations.
        if len(stack) > 0 and isinstance(stack[-1], nn.Softmax):
            end -= 1
        return stack[:max(end, 0)]

    @classmethod
    def _lateral_feature_size(cls, source, default):
        """Return the feature width produced by ``_feature_extractor(source)``.

        Falls back to ``default`` when the extractor contains no linear layer
        to read the width from.
        """
        for layer in reversed(list(cls._feature_extractor(source))):
            if isinstance(layer, nn.Linear):
                return layer.out_features
        return default

    def forward(self, x):
        """Run ``x`` through the sources, append their features, apply the MLP."""
        # Detach so no gradient flows back into the source networks.
        lateral = [self._feature_extractor(source)(x).detach() for source in self.source_networks]
        if lateral:
            # Concatenate on the feature (last) axis so both batched and
            # unbatched inputs work, matching the softmax's dim=-1.
            x = torch.cat([x] + lateral, dim=-1)

        return self.network(x)

In [None]:
# Example: build a Progressive Network on top of two pretrained source policies.
# The *_input_size / *_output_size variables are assumed to be defined elsewhere.
source_networks = [
    load_model(checkpoint_path, n_inputs, n_outputs, [64, 128, 64], True)
    for checkpoint_path, n_inputs, n_outputs in (
        ('Acrobot-v1_policy_network.pth', acrobot_input_size, acrobot_output_size),
        ('MountainCarContinuous-v0_policy_network.pth', mcc_input_size, mcc_output_size),
    )
]

# Source networks are frozen: their parameters stop requiring gradients.
for net in source_networks:
    freeze_network(net)

# Target network for CartPole (or another target task).
# cartpole_input_size and cartpole_output_size must be defined beforehand;
# set is_policy according to whether this is a policy or a value network.
target_network = ProgressiveNetwork(
    cartpole_input_size,
    cartpole_output_size,
    [64, 128, 64],
    source_networks,
    is_policy=True,
)