In [1]:
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
import numpy as np

from tqdm import tqdm
from IPython.display import clear_output
import matplotlib.pyplot as plt

In [93]:
class ActorCritic(nn.Module):
    """Minimal network made of a single linear layer wrapped in ``nn.Sequential``.

    Despite the name, only an ``actor`` head is defined; there is no critic.

    Args:
        in_features: Size of the last dimension of the input (default ``2``).
        out_features: Size of the last dimension of the output (default ``1``).
    """

    def __init__(self, in_features=2, out_features=1):
        super().__init__()
        # Kept inside nn.Sequential so parameter names stay "actor.0.weight" / "actor.0.bias".
        self.actor = nn.Sequential(
            nn.Linear(in_features, out_features),
        )

    def forward(self, x):
        """Apply the actor layer to ``x`` and return the result."""
        return self.actor(x)

In [121]:
model1 = ActorCritic()
model2 = ActorCritic()

# Only model1 is optimised; model2 merely contributes to the combined weights.
optimizer1 = torch.optim.SGD(model1.parameters(), lr=0.1)

x = torch.rand((128, 2))


model3 = ActorCritic()


# Overwrite model3's weights with the element-wise sum model1 + model2.
# The copy happens in place under no_grad so model3's parameters stay plain
# leaf tensors and no autograd history from model1/model2 is attached to them.
with torch.no_grad():
    for param1, param2, param3 in zip(model1.parameters(), model2.parameters(), model3.parameters()):
        param3.copy_(param1 + param2)
        print(param1.data)  # model1's weights before the optimiser step

model3(x).mean().backward()

# Hand model3's gradients to model1. They are cloned so the two models do not
# share one gradient tensor (an in-place change to one would otherwise alter the other).
for param1, param3 in zip(model1.parameters(), model3.parameters()):
    param1.grad = param3.grad.clone()

# Updates model1 using the gradients computed on model3.
optimizer1.step()

for param1 in model1.parameters():
    print(param1.data)  # model1's weights after the optimiser step

tensor([[-0.3053, -0.5017]])
tensor([-0.0585])
tensor([[-0.3566, -0.5534]])
tensor([-0.1585])


In [78]:
# Gradient of the first actor parameter of each model (None where no gradient has been set).
tuple(next(model.actor.parameters()).grad for model in (model1, model2, model3))

(None, None, tensor([[0.5532]]))

---

In [2]:
import copy

class Linear(nn.Module):
    """Linear layer whose output is an ``alpha``-weighted sum of several anchor layers.

    ``n_anchors`` independent ``nn.Linear(in_channels, out_channels)`` modules are
    held in ``self.anchors``. The forward pass evaluates all of them on the same
    input and combines their outputs with the per-anchor weights in ``alpha``.

    Args:
        n_anchors: Number of anchor ``nn.Linear`` modules to create.
        in_channels: Input feature size of every anchor.
        out_channels: Output feature size of every anchor.
        bias: Whether each anchor has a bias term.
        same_init: If true, every anchor starts as a deep copy of one freshly
            initialised layer (identical values, separate parameters);
            otherwise each anchor is initialised independently.
    """

    def __init__(self, n_anchors, in_channels, out_channels, bias = True, same_init = False):
        super().__init__()
        self.n_anchors = n_anchors
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.is_bias = bias

        if same_init:
            # Deep copies give equal starting values without sharing parameter tensors.
            anchor = nn.Linear(in_channels, out_channels, bias=self.is_bias)
            anchors = [copy.deepcopy(anchor) for _ in range(n_anchors)]
        else:
            anchors = [nn.Linear(in_channels, out_channels, bias=self.is_bias) for _ in range(n_anchors)]
        self.anchors = nn.ModuleList(anchors)

    def forward(self, x, alpha):
        """Return ``sum_i alpha[..., i] * anchors[i](x)``.

        Args:
            x: Input passed unchanged to every anchor.
            alpha: Tensor whose last dimension indexes the anchors; it is
                broadcast over the output-feature dimension.
        """
        # Anchor outputs stacked on a new trailing axis: (..., out_channels, n_anchors).
        xs = torch.stack([anchor(x) for anchor in self.anchors], dim=-1)

        # A singleton axis lets alpha broadcast across out_channels without
        # materialising out_channels copies of it.
        alpha = alpha.unsqueeze(-2)
        return (xs * alpha).sum(-1)

class Sequential(nn.Sequential):
    """``nn.Sequential`` variant that threads an extra argument ``t`` to anchor layers.

    Children that are ``Linear`` (the anchor layer defined in this notebook) or a
    nested ``Sequential`` of this type are called as ``module(input, t)``; every
    other child is called as ``module(input)``.
    """

    def forward(self, input, t):
        """Run the children in order, passing ``t`` to those that accept it."""
        for module in self:
            # A nested Sequential also requires ``t``; calling it with one
            # argument would raise TypeError.
            if isinstance(module, (Linear, Sequential)):
                input = module(input, t)
            else:
                input = module(input)
        return input

tensor(-0.3001, grad_fn=<MeanBackward0>)

In [49]:
# Two bias-free anchors, each mapping 2 input features to 3 output features.
linear = Linear(n_anchors=2, in_channels=2, out_channels=3, bias=False)

x = torch.rand(128, 2)
# Equal weight of 1 on both anchors.
alpha = torch.ones(2)
# Back-propagate the mean output so each anchor's weight receives a gradient.
linear(x, alpha).mean().backward()

In [50]:
# Weight gradient of the first anchor (the weight is an nn.Linear's first parameter).
linear.anchors[0].weight.grad

tensor([[0.1646, 0.1663],
        [0.1646, 0.1663],
        [0.1646, 0.1663]])

In [51]:
# Weight gradient of the second anchor (the weight is an nn.Linear's first parameter).
linear.anchors[1].weight.grad

tensor([[0.1646, 0.1663],
        [0.1646, 0.1663],
        [0.1646, 0.1663]])