In [1]:
from test_utils import *
# Load the data returned by load_weight_dict_by_tag for tag "c0" from the
# grad_lists directory of one FedAMC run.
tag = "c0"
# NOTE(review): `filter` shadows the builtin and is not passed to the loader in
# the visible cells — confirm whether it is still needed.
filter = LayerFilter(
    unselect_keys=['bn', 'num_batches_tracked', 'downsample'],
    all_select_keys=['layer'],
)
weight_dict = load_weight_dict_by_tag(
    '../out/FedAMC-amc_alexnet_cosine_0.97_0.2_4_0.65/2024-08-30-15:13:54/grad_lists',
    tag,
)
# cal_similay_by_tag('weight_lists/202405281805-avg-ss/',"server", unselect_keys=['bn', 'num_batches_tracked','downsample'], describe='202405281805-avg-ss')


In [5]:
# List every parameter name together with its shape for weight_dict[0][0].
for name, tensor in weight_dict[0][0].items():
    print(f"'{name}' : {tensor.shape},")

'base.conv1.weight' : torch.Size([64, 3, 5, 5]),
'base.conv1.bias' : torch.Size([64]),
'base.bn1.weight' : torch.Size([64]),
'base.bn1.bias' : torch.Size([64]),
'base.bn1.running_mean' : torch.Size([64]),
'base.bn1.running_var' : torch.Size([64]),
'base.bn1.num_batches_tracked' : torch.Size([]),
'base.conv2.weight' : torch.Size([192, 64, 5, 5]),
'base.conv2.bias' : torch.Size([192]),
'base.bn2.weight' : torch.Size([192]),
'base.bn2.bias' : torch.Size([192]),
'base.bn2.running_mean' : torch.Size([192]),
'base.bn2.running_var' : torch.Size([192]),
'base.bn2.num_batches_tracked' : torch.Size([]),
'base.conv3.weight' : torch.Size([384, 192, 3, 3]),
'base.conv3.bias' : torch.Size([384]),
'base.bn3.weight' : torch.Size([384]),
'base.bn3.bias' : torch.Size([384]),
'base.bn3.running_mean' : torch.Size([384]),
'base.bn3.running_var' : torch.Size([384]),
'base.bn3.num_batches_tracked' : torch.Size([]),
'base.conv4.weight' : torch.Size([256, 384, 3, 3]),
'base.conv4.bias' : torch.Size([256]),
'ba

In [236]:
from typing import Dict

class SlideSVDCompress:
    """Compress tensors against a sliding low-rank orthonormal basis ``U``.

    A tensor is reshaped into slices of length ``L`` (rows of a ``(-1, L)``
    matrix) and projected onto the ``K`` columns of ``U`` (shape ``(L, K)``).
    Only the ``(K, n)`` coefficient matrix has to be sent.
    ``update_basis_by_vector`` refreshes up to ``D`` columns of ``U`` from the
    residual of a new tensor and returns the changed columns, so that a peer
    compressor can mirror them via ``update_basis``.
    """

    def __init__(self, K, D, L):
        self.K = K  # number of columns of U
        self.D = D  # number of columns that may be replaced per basis update
        self.L = L  # length of each parameter slice (rows of U)
        self.U = None  # (L, K) basis; None until the first update

    def update_basis(self, update_dict: Dict[int, torch.Tensor]):
        """Apply column updates produced by ``update_basis_by_vector``.

        Args:
            update_dict: maps a column index of ``U`` to its new values, as a
                1-D tensor of length ``L`` (an ``(L, 1)`` column is also accepted).

        On the first call (``U`` is None) the dict must contain exactly ``K``
        entries; they are placed side by side, in key order, to build ``U``.
        """
        if self.U is None:
            assert len(update_dict) == self.K, f"First update_dict length must be {self.K}"
            # Columns arrive as 1-D vectors; torch.cat(..., dim=1) cannot join
            # 1-D tensors, so stack them as columns instead.
            self.U = torch.stack([update_dict[k].reshape(-1) for k in sorted(update_dict)], dim=1)

        # key = column position to overwrite, value = new column
        for k, v in update_dict.items():
            self.U[:, k] = v.reshape(-1).clone().detach()

    def update_basis_by_vector(self, vector, update_threshold=0):
        """Fit (first call) or refresh the basis from ``vector``.

        The first call takes the top-``K`` left singular vectors of the
        ``(L, n)`` slice matrix. Later calls (when ``D > 0``) add the top-``D``
        left singular vectors of the residual as candidates. They then keep
        the ``K`` columns with the largest squared projection energy, and
        accept the swap only if the relative residual-norm reduction is at
        least ``update_threshold``.

        Returns:
            dict mapping column index -> new column (1-D tensor of length ``L``).
            It is empty when nothing changed: the vector cannot be sliced into
            length-``L`` pieces, ``D == 0``, the residual is already zero, or
            the improvement is below ``update_threshold``.

        Raises:
            ValueError: if ``vector`` yields fewer than ``K`` slices.
        """
        flatten_L = vector.numel() if len(vector.shape) == 1 else (vector.numel() // vector.shape[0])
        if flatten_L % self.L != 0:
            return {}
        vector = vector.reshape(-1, self.L)
        if self.K > vector.shape[0]:
            raise ValueError(f"K {self.K} must less than vector.shape[0] {vector.shape[0]}")
        update_dict = {}
        vector_t = vector.T  # (L, n): one slice per column

        if self.U is None:
            # Initial basis from the SVD of the slice matrix.
            U, _, _ = torch.linalg.svd(vector_t, full_matrices=False)
            self.U = U[:, :self.K]
            # Every column is new; clone so callers do not hold views into self.U.
            for i in range(self.K):
                update_dict[i] = self.U[:, i].clone().detach()

        elif self.D > 0:
            # Residual of reconstructing vector with the current basis.
            a = self.U.T @ vector_t
            e = vector_t - self.U @ a
            e_norm = e.norm()
            if e_norm == 0:
                # vector already lies in span(U): nothing to gain, and the
                # relative-improvement test below would divide by zero.
                return {}

            # Candidate columns: top-D directions of the residual.
            U_e, _, _ = torch.linalg.svd(e, full_matrices=False)
            U_K_e = torch.cat([self.U, U_e[:, :self.D]], dim=1)

            # Squared projection energy of vector on each of the K+D columns.
            a = U_K_e.T @ vector_t
            contribution = torch.sum(a ** 2, dim=1)
            _, min_indices = torch.topk(contribution, k=self.D, largest=False)

            min_indices_set = set(min_indices.tolist())
            wait_D_update_set = set(range(self.K, self.K + self.D))
            sub_index = min_indices_set - wait_D_update_set  # old columns to drop
            add_index = wait_D_update_set - min_indices_set  # new columns to keep

            # Move the kept candidates into the slots of the dropped columns.
            U_K_e[:, list(sub_index)] = U_K_e[:, list(add_index)]
            U_K = U_K_e[:, :self.K]
            e_2 = vector_t - U_K @ (U_K.T @ vector_t)

            # Skip the update if the relative error reduction is below the threshold.
            if (e_norm - e_2.norm()) / e_norm < update_threshold:
                return {}

            self.U = U_K
            # Report only the columns that changed.
            for i in sub_index:
                update_dict[i] = U_K_e[:, i].clone().detach()

        return update_dict

    def compress(self, vector):
        """Project ``vector`` onto the basis.

        Returns:
            ``(a, e)``, where ``a`` is the ``(K, n)`` coefficient matrix and
            ``e`` is the reconstruction residual in ``vector``'s original shape.
            If ``vector`` cannot be split into length-``L`` slices, it is
            returned unchanged as ``(vector, 0)``.
        """
        original_shape = vector.shape
        flatten_L = vector.numel() if len(vector.shape) == 1 else (vector.numel() // vector.shape[0])
        if flatten_L % self.L != 0:
            print(f"vector.numel() // vector.shape[0] {flatten_L} can't divide L {self.L}. Return itself")
            return vector, 0

        vector_t = vector.reshape(-1, self.L).T
        a = self.U.T @ vector_t
        e = vector_t - self.U @ a
        return a, e.T.reshape(original_shape)

    def uncompress(self, a: torch.Tensor, shape=None):
        """Reconstruct a tensor from coefficients ``a``.

        If ``a`` already has the target ``shape`` (the pass-through case of
        ``compress``), it is returned as-is. Otherwise ``(U @ a).T`` is
        returned, reshaped to ``shape`` when one is given.
        """
        if a.shape == shape:
            return a
        elif shape is None:
            return (self.U @ a).T
        else:
            return (self.U @ a).T.reshape(shape)

In [None]:
class CompressorCombinModel:
    """Keep one ``SlideSVDCompress`` per named parameter and apply them model-wide.

    ``setting_dict`` maps a parameter name to the ``(K, D, L)`` positional
    arguments of that parameter's compressor.
    """

    def __init__(self, setting_dict: Dict[str, tuple]):
        self.setting_dict = setting_dict
        self.compressor_dict: Dict[str, SlideSVDCompress] = {}
        for key, value in setting_dict.items():
            self.compressor_dict[key] = SlideSVDCompress(*value)

    def compress(self, model_params: Dict[str, Tensor], can_update_basis_func=None, **kwargs):
        """Compress every parameter, optionally refreshing bases first.

        Args:
            model_params: parameter name -> tensor; every name must have a
                configured compressor (otherwise KeyError).
            can_update_basis_func: optional predicate, called once per
                parameter with ``**kwargs``. When it returns True, that
                parameter's basis is refreshed before compressing.

        Returns:
            ``(compress_dict, combin_update_dict)``. The first maps names to
            the first element returned by ``SlideSVDCompress.compress``. The
            second maps names to the column updates to send to a peer; it is
            only filled when ``can_update_basis_func`` is given.
        """
        compress_dict = {}
        combin_update_dict = {}
        for key, value in model_params.items():
            compressor = self.compressor_dict[key]
            if can_update_basis_func is not None:
                if can_update_basis_func(**kwargs):
                    combin_update_dict[key] = compressor.update_basis_by_vector(value)
                else:
                    combin_update_dict[key] = {}
            compress_dict[key], _ = compressor.compress(value)
        return compress_dict, combin_update_dict

    def uncompress(self, compress_model_params: Dict[str, Tensor], target_model_params: Dict[str, Tensor]):
        """Write reconstructions into ``target_model_params`` in place.

        Each entry is reshaped to the shape of the existing target entry.
        """
        for key, value in compress_model_params.items():
            target_model_params[key] = self.compressor_dict[key].uncompress(value, target_model_params[key].shape)

    def update(self, combin_update_dict: Dict[str, Dict[int, Tensor]]):
        """Mirror the column updates produced by a peer's ``compress``."""
        for key, value in combin_update_dict.items():
            if not value:
                # Nothing changed for this parameter. update_basis({}) is a
                # no-op once U exists, but it trips the first-call length
                # assertion while U is still None.
                continue
            self.compressor_dict[key].update_basis(value)


In [1]:
def cal_K_D_L(M, K=None, D=None, L=None):
    """Derive the missing one of (K, D, L) for an M-element parameter and report savings.

    Args:
        M: total number of elements of the parameter.
        K: number of basis columns.
        D: number of basis columns refreshed per update.
        L: slice length.

    At most one of K, D, L may be None. A missing value is derived so that the
    two transmitted parts roughly balance (K * M / L ~= D * L):
    K = D*L*L // M, D = K*M // (L*L), or L = int(sqrt(M*K // D)).

    Returns:
        ``(K, D, L, ratio)`` with ``ratio = 1 - (K / L + D * L / M)``. This is
        the fraction of the M raw values saved when sending K * M / L
        coefficients plus D * L refreshed basis values.

    Raises:
        ValueError: if the combination violates the size constraints, if
            K < D, or if two or more of K, D, L are None.
    """
    if K is None and D is not None and L is not None:
        max_K = L // 2
        K = D * L * L // M
        if K > max_K:
            raise ValueError(f"Not good. K is too large, max is {max_K}")
    elif K is not None and D is None and L is not None:
        # Same L // K >= 2 rule as the fully specified branch below.
        if L // K < 2:
            raise ValueError(f"Not good. L // K is too small, min is 2, now is {L // K}")
        max_D = M // (4 * K)
        D = K * M // (L * L)
        if D > max_D:
            raise ValueError(f"Not good. D is too large, max is {max_D}")
    elif K is not None and D is not None and L is None:
        if K * D * 4 >= M:
            raise ValueError(f"Not good. K * D * 4 is too large, max is {M // 4}, now is {K * D * 4}")
        min_L = 2 * K
        L = int(math.sqrt(M * K // D))
        if L < min_L:
            raise ValueError(f"Not good. L is too small, min is {min_L}")
    elif K is not None and D is not None and L is not None:
        if K * D * 4 >= M:
            raise ValueError(f"Not good. K * D * 4 is too large, max is {M // 4}, now is {K * D * 4}")
        if L // K < 2:
            raise ValueError(f"Not good. L // K is too small, min is 2, now is {L // K}")
    else:
        raise ValueError(f"K, D, L must have one None")

    if K < D:
        raise ValueError(f"Not good. K < D, now is {K} < {D}")
    # Exact transmitted fraction. The old shortcut 2*L*D/M only equals this
    # when K*M/L == D*L holds exactly, which integer flooring breaks
    # (e.g. D == 0 reported ratio 1.0).
    r = K / L + D * L / M
    return K, D, L, 1 - r



In [12]:
# Compression ratio for a 4096-element parameter with K=2, D=2, L=64.
K, D, L, r = cal_K_D_L(M=4096, K=2, D=2, L=64)
print(f"K={K}, D={D}, L={L}, r={r}")

K=2, D=2, L=64, r=0.9375


In [302]:
# Grid-search K, D and L for a 4096-element parameter and keep the combination
# with the largest compression ratio r.
max_r = 0
max_K = 0
max_D = 0
max_L = 0
for K in range(2, 21, 2):
    for D in range(1, K):
        for L in range(2 * K, 1024):
            try:
                # All three are given, so cal_K_D_L echoes K, D, L back;
                # unpack into throwaways so the loop variables are not clobbered.
                _, _, _, r_cand = cal_K_D_L(4096, K=K, D=D, L=L)
            except ValueError:
                # Combination rejected by cal_K_D_L's constraints.
                continue
            if r_cand > max_r:
                max_r = r_cand
                max_K = K
                max_D = D
                max_L = L
    print(f"{K}-th max_K={max_K}, max_D={max_D}, max_L={max_L}, max_r={max_r}")
print(f"max_K={max_K}, max_D={max_D}, max_L={max_L}, max_r={max_r}")


2-th max_K=2, max_D=1, max_L=91, max_r=0.955805181146978
4-th max_K=2, max_D=1, max_L=91, max_r=0.955805181146978
6-th max_K=2, max_D=1, max_L=91, max_r=0.955805181146978
8-th max_K=2, max_D=1, max_L=91, max_r=0.955805181146978
10-th max_K=2, max_D=1, max_L=91, max_r=0.955805181146978
12-th max_K=2, max_D=1, max_L=91, max_r=0.955805181146978
14-th max_K=2, max_D=1, max_L=91, max_r=0.955805181146978
16-th max_K=2, max_D=1, max_L=91, max_r=0.955805181146978
18-th max_K=2, max_D=1, max_L=91, max_r=0.955805181146978
20-th max_K=2, max_D=1, max_L=91, max_r=0.955805181146978
max_K=2, max_D=1, max_L=91, max_r=0.955805181146978


In [331]:
# Basis with K=2 columns, D=2 replaceable columns per refresh, slice length L=64,
# fitted on entry 0 of 'classifier.bn7.bias' (presumably the first round — confirm).
compresser = SlideSVDCompress(K=2, D=2, L=64)
v_5 = weight_dict[0][0]['classifier.bn7.bias']
update_dict = compresser.update_basis_by_vector(v_5)

In [332]:
# For each later entry of 'classifier.bn7.bias', compare the reconstruction
# cosine similarity before and after refreshing the basis with that entry.
UPDATE_EVERY = 1  # refresh the basis every N entries; 1 = every entry
for i in range(1, 50):
    v_test = weight_dict[i][0]['classifier.bn7.bias']
    a_test, e_test = compresser.compress(v_test)
    g_test = compresser.uncompress(a_test, v_test.shape)
    before = cos_similar(g_test, v_test)

    if i % UPDATE_EVERY != 0:
        # No refresh this round, so only the pre-update residual exists
        # (the old branch read e_test_after, which is undefined or stale here).
        print(f"Th-{i}: before:{before:10.5f} e:{e_test.norm()}")
        continue

    update_dict = compresser.update_basis_by_vector(v_test, update_threshold=0)
    a_test, e_test_after = compresser.compress(v_test)
    g_test = compresser.uncompress(a_test, v_test.shape)
    after = cos_similar(g_test, v_test)
    if after < before:
        if e_test_after.norm() > e_test.norm():
            print(f"Th-{i}: before:{before:10.5f}, after:{after:10.5f} ==> Update_dict:{update_dict.keys()} e:{e_test.norm()} after e:{e_test_after.norm()}")
        else:
            print(f"Th-{i}: before:{before:10.5f}, after:{after:10.5f} ==> Update_dict:{update_dict.keys()} e_test_after < e_test but after_similar < before_similar")
    else:
        print(f"Th-{i}: before:{before:10.5f}, after:{after:10.5f} ==> Update_dict:{update_dict.keys()}")

Th-1: before:   0.45079, after:   0.50904 ==> Update_dict:dict_keys([1])
Th-2: before:   0.17858, after:   0.74591 ==> Update_dict:dict_keys([0, 1])
Th-3: before:   0.03031, after:   0.84314 ==> Update_dict:dict_keys([0, 1])
Th-4: before:   0.09780, after:   0.73217 ==> Update_dict:dict_keys([0, 1])
Th-5: before:   0.14634, after:   0.90785 ==> Update_dict:dict_keys([0, 1])
Th-6: before:   0.05894, after:   0.79023 ==> Update_dict:dict_keys([0, 1])
Th-7: before:   0.10013, after:   0.82852 ==> Update_dict:dict_keys([0, 1])
Th-8: before:   0.00043, after:   0.93388 ==> Update_dict:dict_keys([0, 1])
Th-9: before:   0.70187, after:   0.94095 ==> Update_dict:dict_keys([1])
Th-10: before:   0.88185, after:   0.96219 ==> Update_dict:dict_keys([1])
Th-11: before:   0.49033, after:   0.81116 ==> Update_dict:dict_keys([0, 1])
Th-12: before:   0.00123, after:   0.90453 ==> Update_dict:dict_keys([0, 1])
Th-13: before:   0.27537, after:   0.83382 ==> Update_dict:dict_keys([0, 1])
Th-14: before:   

In [95]:
# Cosine similarity between each 'classifier.fc1.weight' entry and its
# reconstruction from the current basis (no basis refresh in this cell).
# NOTE(review): this cell's execution count (95) predates the cell that builds
# `compresser` (331), so the outputs below came from an earlier compresser.
for i in range(1, 50):
    v_test = weight_dict[i][0]['classifier.fc1.weight']
    a_test, e_test = compresser.compress(v_test)
    before = cos_similar(compresser.uncompress(a_test, v_test.shape), v_test)
    print(f"{i}-th {before}")

1-th 0.3954788148403168
2-th 0.38685572147369385
3-th 0.38566067814826965
4-th 0.4078787863254547
5-th 0.39017173647880554
6-th 0.38016679883003235
7-th 0.3959686756134033
8-th 0.398908793926239
9-th 0.3915354311466217
10-th 0.4029323160648346
11-th 0.40529799461364746
12-th 0.392452597618103
13-th 0.3967442512512207
14-th 0.3904167711734772
15-th 0.3983614444732666
16-th 0.3839189410209656
17-th 0.39532017707824707
18-th 0.3997047245502472
19-th 0.3955429792404175
20-th 0.3920809030532837
21-th 0.38374629616737366
22-th 0.391897976398468
23-th 0.38946467638015747
24-th 0.3910553753376007
25-th 0.3862779438495636
26-th 0.385198712348938
27-th 0.38837939500808716
28-th 0.3965727388858795
29-th 0.3846668004989624
30-th 0.38984793424606323
31-th 0.3796626925468445
32-th 0.38867342472076416
33-th 0.3884607255458832
34-th 0.3864929974079132
35-th 0.38782989978790283
36-th 0.39541861414909363
37-th 0.3933473825454712
38-th 0.38715362548828125
39-th 0.3878777325153351
40-th 0.3889986276626587