In [28]:
import torch
import torch.nn as nn

In [29]:
class ParamListTest(nn.Module):
    """Toy module demonstrating how ``nn.ParameterList`` registers parameters.

    Holds ten independent 10x10 parameters, each drawn with ``torch.randn``.
    """

    def __init__(self):
        super().__init__()
        matrices = []
        for _ in range(10):
            # One randn call per matrix so each parameter is sampled separately.
            matrices.append(nn.Parameter(torch.randn(10, 10)))
        self.params = nn.ParameterList(matrices)

    def forward(self, x):
        """Repeatedly transform ``x`` using pairs of stored matrices.

        At step ``idx`` the input is left-multiplied by both
        ``params[idx // 2]`` and ``params[idx]``, and the two products are
        summed. ``x`` must be a 2-D tensor with 10 rows, as required by
        ``mm`` against a 10x10 matrix; the result has the same shape as ``x``.
        """
        # ParameterList supports len() and integer indexing like a list.
        for idx in range(len(self.params)):
            paired = self.params[idx // 2]
            current = self.params[idx]
            x = paired.mm(x) + current.mm(x)
        return x

# Two stacked ParamListTest modules: nn.Sequential prefixes each child's
# parameter names with its position in the container.
model = nn.Sequential(*(ParamListTest() for _ in range(2)))

# Only the first registered parameter is shown, which is enough to see the
# naming scheme (recorded output: "0.params.0 torch.Size([10, 10])").
name, param = next(iter(model.named_parameters()))
print(name, param.shape)


class ParamDictTest(nn.Module):
    """Toy module demonstrating how ``nn.ParameterDict`` registers parameters.

    Stores two 5x10 parameters, each drawn with ``torch.randn``, under the
    string keys ``'R1'`` and ``'R2'``.
    """

    def __init__(self):
        super().__init__()
        # Keys are created in order R1, R2, with one randn draw per key.
        self.params = nn.ParameterDict(
            {key: nn.Parameter(torch.randn(5, 10)) for key in ('R1', 'R2')}
        )

    def forward(self, x, choice):
        """Left-multiply ``x`` by the parameter stored under ``choice``.

        ``choice`` is a key of ``self.params`` (``'R1'`` or ``'R2'``).
        ``x`` must be a 2-D tensor with 10 rows, as required by ``mm``
        against a 5x10 matrix; the result has 5 rows.
        """
        selected = self.params[choice]
        return torch.mm(selected, x)

# Same experiment with ParamDictTest: parameter names now end in the
# dictionary key instead of a list index.
model = nn.Sequential(*(ParamDictTest() for _ in range(2)))

# Only the first registered parameter is shown
# (recorded output: "0.params.R1 torch.Size([5, 10])").
name, param = next(iter(model.named_parameters()))
print(name, param.shape)

0.params.0 torch.Size([10, 10])
0.params.R1 torch.Size([5, 10])


In [43]:
from tools.rotation_utils import RotateModule

# Sweep every (hidden_size, num_attention_heads) combination, with hidden_size
# as the outer axis, and report each RotateModule's parameters and output shape.
hidden_sizes = [512, 1024, 2048, 4096]
num_attention_headses = [8, 16, 32]
for hidden_size, num_attention_heads in (
    (size, heads) for size in hidden_sizes for heads in num_attention_headses
):
    rotate_module = RotateModule(hidden_size, num_attention_heads=num_attention_heads)
    # Print the parameter dict first, then call the module to get its matrix.
    # NOTE(review): in the recorded output the matrix side equals
    # hidden_size // num_attention_heads; confirm against RotateModule.
    print(rotate_module.params_dict)
    print(f"==> (hidden_size, num_attention_heads): ({hidden_size}, {num_attention_heads}) rotate matrix: {rotate_module().shape}")

ParameterDict(
    (R0): Parameter containing: [torch.FloatTensor of size 8x8]
    (R1): Parameter containing: [torch.FloatTensor of size 8x8]
)
==> (hidden_size, num_attention_heads): (512, 8) rotate matrix: torch.Size([64, 64])
ParameterDict(
    (R0): Parameter containing: [torch.FloatTensor of size 4x4]
    (R1): Parameter containing: [torch.FloatTensor of size 8x8]
)
==> (hidden_size, num_attention_heads): (512, 16) rotate matrix: torch.Size([32, 32])
ParameterDict(
    (R0): Parameter containing: [torch.FloatTensor of size 4x4]
    (R1): Parameter containing: [torch.FloatTensor of size 4x4]
)
==> (hidden_size, num_attention_heads): (512, 32) rotate matrix: torch.Size([16, 16])
ParameterDict(
    (R0): Parameter containing: [torch.FloatTensor of size 8x8]
    (R1): Parameter containing: [torch.FloatTensor of size 16x16]
)
==> (hidden_size, num_attention_heads): (1024, 8) rotate matrix: torch.Size([128, 128])
ParameterDict(
    (R0): Parameter containing: [torch.FloatTensor of size