In [1]:
from utils import *
import torch.nn as nn

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def Coordinate_descend_genep(dim, comp_dim, generator=None):
    """Build a random signed coordinate-selection matrix.

    Picks ``comp_dim`` distinct coordinates out of ``dim`` uniformly at random and
    returns the matching columns of the ``dim x dim`` identity, each multiplied by
    an independent random sign (+1 or -1).

    Args:
        dim: Full dimension; number of rows of the result.
        comp_dim: Compressed dimension; number of columns. Must satisfy
            ``0 <= comp_dim <= dim``.
        generator: Optional ``torch.Generator`` used for both the coordinate
            permutation and the signs, so draws can be made reproducible.
            ``None`` (default) uses torch's global RNG, as before.

    Returns:
        Float tensor ``P`` of shape ``(dim, comp_dim)``. Its columns are distinct
        signed standard basis vectors, so ``P.T @ P`` is the ``comp_dim x comp_dim``
        identity. ``P @ P.T`` is only a 0/1 diagonal projection onto the selected
        coordinates. It equals the ``dim x dim`` identity only when
        ``comp_dim == dim``.

    Raises:
        AssertionError: if ``comp_dim`` is not in ``[0, dim]``.
    """
    assert 0 <= comp_dim <= dim, (
        f"compression dimension must be between 0 and dim ({dim}), got {comp_dim}"
    )
    identity = torch.eye(dim)
    select_col = torch.randperm(dim, generator=generator)[:comp_dim]
    # Draw {0, 1} and map to {-1, +1}. Cast first so the product keeps the identity's float dtype.
    sign = torch.randint(0, 2, (comp_dim,), generator=generator).to(identity.dtype) * 2 - 1
    return identity[:, select_col] * sign

In [3]:
class SubScafLinearTest(nn.Module):
    """Linear layer whose weight update lives in a low-dimensional subspace.

    The effective weight is ``W = comp_mat @ b + x``, where:
      * ``x`` is a frozen copy of the wrapped layer's weight, shape
        ``(out_features, in_features)``;
      * ``comp_mat`` is a fixed matrix of shape ``(out_features, comp_dim)``;
      * ``b`` is the only new trainable parameter, shape
        ``(comp_dim, in_features)``. It starts at zero, so the layer initially
        computes exactly what ``wraped_model`` does.

    The wrapped layer's bias (a Parameter, or ``None``) is reused as ``self.bias``.

    Args:
        comp_dim: Subspace dimension (number of columns of ``comp_mat``).
        comp_mat: Tensor of shape ``(out_features, comp_dim)``. It is converted to
            the wrapped weight's device/dtype.
        wraped_model: The ``nn.Linear`` whose weight/bias are taken over.

    Raises:
        ValueError: if ``comp_mat`` does not have shape ``(out_features, comp_dim)``.
    """

    def __init__(self, comp_dim: int, comp_mat: torch.Tensor, wraped_model: nn.Linear):
        # nn.Module must be initialised before any Parameter or buffer is assigned.
        super().__init__()
        weight = wraped_model.weight
        factory_kwargs = {'device': weight.device, 'dtype': weight.dtype}
        out_features, in_features = weight.shape
        if tuple(comp_mat.shape) != (out_features, comp_dim):
            raise ValueError(
                f"comp_mat must have shape ({out_features}, {comp_dim}), "
                f"got {tuple(comp_mat.shape)}"
            )
        self.comp_dim = comp_dim
        self.in_features = in_features
        self.out_features = out_features
        # Buffers rather than parameters: not trained, but moved by .to()/.cuda()
        # and included in state_dict.
        self.register_buffer('comp_mat', comp_mat.detach().to(**factory_kwargs))
        self.register_buffer('x', weight.detach().clone())
        self.b = nn.Parameter(torch.zeros((comp_dim, in_features), **factory_kwargs))
        # Reuse the wrapped bias; linear() accepts None when the layer has no bias.
        self.bias = wraped_model.bias

    def forward(self, input):
        """Return ``input @ (comp_mat @ b + x).T + bias``."""
        weight = self.comp_mat @ self.b + self.x
        return nn.functional.linear(input, weight, self.bias)

In [4]:
# Seed the global RNG so the layer initialisation and the random projection are reproducible.
torch.manual_seed(0)
comp_dim = 64
wraped_module = nn.Linear(512, 512, bias=False)
# comp_mat has shape (out_features, comp_dim).
comp_mat = Coordinate_descend_genep(wraped_module.out_features, comp_dim)
# NOTE(review): this builds `SubScafLinear` (presumably from `utils`), not the
# `SubScafLinearTest` class defined above — confirm which one is meant to be tested.
model = SubScafLinear(comp_dim, comp_mat, wraped_module)

In [5]:
# Overwrite b with ones in place, outside autograd (preferred over reassigning `.data`).
with torch.no_grad():
    model.b.fill_(1.0)

In [6]:
# Effective weight x + comp_mat @ b, transposed.
# F.linear(I, W) returns W.T, so this should match the hook output of the identity-input forward pass.
(model.x + comp_mat @ model.b).T

tensor([[ 5.9137e-03,  1.0026e+00,  3.3829e-02,  ...,  2.2875e-02,
         -6.2389e-04, -1.7577e-02],
        [ 1.4547e-02,  9.6451e-01, -3.7899e-02,  ...,  3.8589e-02,
          1.9935e-02,  3.4524e-02],
        [ 3.2876e-02,  9.8613e-01, -2.2340e-02,  ..., -2.6213e-02,
         -1.7955e-02,  2.3092e-02],
        ...,
        [-3.7569e-02,  1.0426e+00,  3.9935e-02,  ...,  4.3619e-03,
          1.6771e-02, -6.1447e-03],
        [-1.1713e-02,  9.9415e-01,  1.2940e-02,  ..., -1.1452e-02,
          1.6718e-02, -1.4886e-02],
        [-3.7856e-02,  1.0128e+00,  3.2255e-02,  ..., -3.5552e-02,
         -1.3113e-02, -3.8475e-02]], grad_fn=<PermuteBackward0>)

In [7]:
activation_values = []


def hook(module, inputs, output):
    """Forward hook: record a detached copy of the module's output in `activation_values`."""
    activation_values.append(output.detach().clone())


# If this cell was run before, remove the previous hook first. Otherwise re-running it
# stacks hooks and records duplicate activations. RemovableHandle.remove() is safe to call twice.
if "handle" in globals():
    handle.remove()
handle = model.register_forward_hook(hook)

In [8]:
# Identity input: the layer returns I @ W.T == W.T, i.e. the effective weight transposed.
input_data = torch.eye(wraped_module.in_features)
try:
    output = model(input_data)
finally:
    # Always detach the hook, even if the forward pass fails, so later calls stop recording.
    handle.remove()
# Print the shape and value of each recorded activation (messages are in Chinese).
for i, activations in enumerate(activation_values, start=1):
    print(f"第 {i} 个激活值的形状: {activations.shape}")
    print(f"第 {i} 个激活值: {activations}")

第 1 个激活值的形状: torch.Size([512, 512])
第 1 个激活值: tensor([[ 5.9137e-03,  1.0026e+00,  3.3829e-02,  ...,  2.2875e-02,
         -6.2389e-04, -1.7577e-02],
        [ 1.4547e-02,  9.6451e-01, -3.7899e-02,  ...,  3.8589e-02,
          1.9935e-02,  3.4524e-02],
        [ 3.2876e-02,  9.8613e-01, -2.2340e-02,  ..., -2.6213e-02,
         -1.7955e-02,  2.3092e-02],
        ...,
        [-3.7569e-02,  1.0426e+00,  3.9935e-02,  ...,  4.3619e-03,
          1.6771e-02, -6.1447e-03],
        [-1.1713e-02,  9.9415e-01,  1.2940e-02,  ..., -1.1452e-02,
          1.6718e-02, -1.4886e-02],
        [-3.7856e-02,  1.0128e+00,  3.2255e-02,  ..., -3.5552e-02,
         -1.3113e-02, -3.8475e-02]])


In [57]:
# NOTE(review): executed out of order (In [57] appears before In [20]). Its output
# differs from the identical `model.x` cell that follows, so the two come from
# different runs. Restart and run all cells top to bottom to get consistent output.
model.x

tensor([[ 0.0417, -0.0276, -0.0326,  ...,  0.0225,  0.0314,  0.0001],
        [ 0.0125,  0.0263,  0.0419,  ...,  0.0306,  0.0375,  0.0388],
        [-0.0300, -0.0118, -0.0378,  ..., -0.0366, -0.0139, -0.0109],
        ...,
        [-0.0262,  0.0070, -0.0359,  ..., -0.0078, -0.0306,  0.0115],
        [ 0.0354,  0.0183,  0.0306,  ..., -0.0296, -0.0359, -0.0303],
        [ 0.0071,  0.0410,  0.0226,  ..., -0.0129, -0.0354,  0.0304]])

In [20]:
# NOTE(review): duplicate of the earlier `model.x` display. No visible cell modifies
# `x`, yet the two outputs differ, so at least one is stale — consider deleting one.
model.x

tensor([[-0.0191,  0.0315, -0.0157,  ..., -0.0055, -0.0170, -0.0283],
        [-0.0164,  0.0325,  0.0270,  ...,  0.0212, -0.0297,  0.0428],
        [-0.0123,  0.0415,  0.0061,  ..., -0.0230, -0.0186, -0.0173],
        ...,
        [ 0.0039, -0.0417, -0.0124,  ..., -0.0081, -0.0392,  0.0304],
        [ 0.0023,  0.0092,  0.0231,  ..., -0.0239, -0.0293,  0.0124],
        [ 0.0159,  0.0249, -0.0316,  ..., -0.0335,  0.0371, -0.0409]])

In [21]:
# NOTE(review): the output shows `b` as all zeros, although the cell above sets it to
# ones. This output predates that cell (or followed a re-run of model construction)
# and is stale.
model.b

Parameter containing:
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], requires_grad=True)