In [1]:
from utils import *
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from pickle import dump
import os 
#os.environ['PYTORCH_NO_CUDA_MEMORY_CACHING'] = '1'
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:21"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def Coordinate_descend_genep(dim, comp_dim):
    assert dim >= comp_dim, "compression dimension must be smaller than dimension"
    ide = torch.eye(dim, requires_grad=False)
    select_col = torch.randperm(dim)[:comp_dim]
    sign = torch.randint(0, 2, (comp_dim, ))
    sign = sign * 2 - 1
    # XXX make clear whether PTP is I or PPT is I
    P = ide[:, select_col] * sign
    return P

def mem():
    #torch.cuda.empty_cache()
    print('memory allocated: ' + str((torch.cuda.memory_allocated() / 1024)) )
    print('memory reserved: ' + str(torch.cuda.memory_reserved() / 1024) )
    print('max memory allocated: ' + str(torch.cuda.max_memory_allocated() / 1024) )
    print('max memory reserved: ' + str(torch.cuda.max_memory_reserved() / 1024) )

class SubScafLinearTest(nn.Module):
    """
    Linear network with compressed dimension
    """
    def __init__(self, comp_dim: int, comp_mat: torch.Tensor, wraped_model: nn.Linear):
        self.comp_mat = comp_mat
        self.comp_dim = comp_dim
        device = wraped_model.weight.device
        dtype = wraped_model.weight.dtype
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.x = wraped_model.weight.detach().clone()
        super().__init__()
        self.b = nn.Parameter(torch.zeros((comp_dim, wraped_model.in_features), **factory_kwargs))
    
    def comp_mul(self, b):
        return self.comp_mat @ b + self.x

    def forward(self, input):
        #def compute_linear(input):
            #weight = self.comp_mat @ self.b + self.x
            #return F.linear(input, weight)
        #return checkpoint(compute_linear, input, use_reentrant=False)
        mem()
        x = F.linear(self.comp_mat, self.b.T)
        print(x.grad_fn)
        mem()
        x = self.x + x
        print(x.grad_fn)
        mem()
        output = F.linear(input, x)
        print(output.grad_fn)
        mem()
        return output 

class SubScafLayer(nn.Module):
    def __init__(self, comp_dim, comp_mat, wraped_model):
        super().__init__()
        self.layer1 = SubScafLinearTest(comp_dim, comp_mat, wraped_model)
        self.layer2 = SubScafLinearTest(comp_dim, comp_mat, wraped_model)
    
    def forward(self, input):
        x = checkpoint(self.layer1, input)
        x = self.layer2(x)
        return x

In [3]:
comp_dim = 64
wraped_module = nn.Linear(512, 512, bias=False)
comp_mat = Coordinate_descend_genep(wraped_module.out_features, comp_dim)
model = SubScafLinearTest(comp_dim, comp_mat, wraped_module)

activation_values = []
def hook(module, input, output):
    activation_values.append(output.detach().clone())
handle = model.register_forward_hook(hook)

In [3]:
#torch.cuda.memory._record_memory_history(enabled='all')
device = "cuda:0"
comp_dim = 64
wraped_module = nn.Linear(512, 512, bias=False).to(device)
comp_mat = Coordinate_descend_genep(wraped_module.out_features, comp_dim).to(device)
model = SubScafLinearTest(comp_dim, comp_mat, wraped_module).to(device)
input_data = torch.eye(512).to(device)
output = model(input_data)
mem()
output.sum().backward()
opt = torch.optim.SGD(model.parameters())
opt.zero_grad()
#s = torch.cuda.memory._snapshot()
#with open(f"snapshot.pickle", "wb") as f:
    #dump(s, f)

memory allocated: 3328.0
<MmBackward0 object at 0x78015159fe80>
memory allocated: 12672.0
<AddBackward0 object at 0x78015159fe50>
memory allocated: 12672.0
<MmBackward0 object at 0x78023ebacd30>
memory allocated: 13696.0
memory allocated: 12672.0


In [27]:

output = model(input_data)
mem()
output.sum().backward()
del output
mem()

memory allocated: 19968.0
memory allocated: 19072.0


In [3]:
#torch.cuda.memory._record_memory_history(enabled='all')
device = "cuda:0"
comp_dim = 64
wraped_module = nn.Linear(512, 512, bias=False).to(device)
comp_mat = Coordinate_descend_genep(wraped_module.out_features, comp_dim).to(device)
model = SubScafLayer(comp_dim, comp_mat, wraped_module).to(device)
del wraped_module
mem()
input_data = torch.eye(512).to(device)
mem()
output = model(input_data)
mem()
#del output
output.sum().backward()
mem()
#s = torch.cuda.memory._snapshot()
#with open(f"snapshot.pickle", "wb") as f:
    #dump(s, f)

memory allocated: 2432.0
memory reserved: 4096.0
max memory allocated: 3456.0
max memory reserved: 4096.0
memory allocated: 3456.0
memory reserved: 4096.0
max memory allocated: 3456.0
max memory reserved: 4096.0
memory allocated: 3456.0
memory reserved: 4096.0
max memory allocated: 3456.0
max memory reserved: 4096.0
None
memory allocated: 12800.0
memory reserved: 26624.0
max memory allocated: 12800.0
max memory reserved: 26624.0
None
memory allocated: 12800.0
memory reserved: 26624.0
max memory allocated: 13824.0
max memory reserved: 26624.0
None
memory allocated: 13824.0
memory reserved: 26624.0
max memory allocated: 13824.0
max memory reserved: 26624.0
memory allocated: 12800.0
memory reserved: 26624.0
max memory allocated: 13824.0
max memory reserved: 26624.0
<MmBackward0 object at 0x7ec61479b6d0>
memory allocated: 13824.0
memory reserved: 26624.0
max memory allocated: 13824.0
max memory reserved: 26624.0
<AddBackward0 object at 0x7ec61479a170>
memory allocated: 13824.0
memory reser



In [3]:
device = "cuda:0"
mem()
model = nn.Linear(32, 32, bias=False).to(device)
mem()
input_data = torch.randn((1, 32)).to(device)
mem()
output = model(input_data)
#with torch.no_grad():
    #output = model(input_data)
mem()

output.sum().backward()

mem()

memory allocated: 0.0
memory reserved: 0.0
max memory allocated: 0.0
max memory reserved: 0.0
memory allocated: 4.0
memory reserved: 2048.0
max memory allocated: 4.0
max memory reserved: 2048.0
memory allocated: 4.5
memory reserved: 2048.0
max memory allocated: 4.5
max memory reserved: 2048.0
memory allocated: 8325.0
memory reserved: 22528.0
max memory allocated: 8325.0
max memory reserved: 22528.0
memory allocated: 16649.0
memory reserved: 22528.0
max memory allocated: 16650.5
max memory reserved: 22528.0


In [1]:
from transformers import (
    AutoTokenizer,
    AutoConfig,
    LlamaConfig,
    LlamaForCausalLM,
    DataCollatorForLanguageModeling,
    get_cosine_schedule_with_warmup,
)
model_config = AutoConfig.from_pretrained("configs/llama_1b.json")
model = LlamaForCausalLM(model_config)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 2048, padding_idx=31999)
    (layers): ModuleList(
      (0-23): 24 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (v_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=5461, bias=False)
          (up_proj): Linear(in_features=2048, out_features=5461, bias=False)
          (down_proj): Linear(in_features=5461, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-06)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-06)
      )
    )
    (norm): LlamaRMSNorm((2048,), eps=1e-0

In [23]:
import numpy as np
def Coordinate_descend_genep(d, r, n=1):
    sum_p = np.zeros((d, r)) 
    for _ in range(n):
        ide = np.eye(d)
        col_num = np.arange(d)
        select_col = np.random.choice(col_num, r, replace=False)
        sign = np.random.choice([-1, 1], r)
        P = np.sqrt(d / r) * ide[:, select_col] * sign
        #P = ide[:, select_col] * sign
        sum_p += P
    P = sum_p / n
    return P


def Spherical_smoothing_genep(d, r, n=1):
    sum_p = np.zeros((d, r)) 
    for _ in range(n):
        z = np.random.randn(d, d)
        Q, R = np.linalg.qr(z)
        D = np.diag(np.sign(np.diag(R)))
        Q = Q @ D
        R = D @ R
        assert np.allclose(Q @ R, z, atol=1e-7), "the QR decomposion is not accuracy"
        #P = np.sqrt(d / r) * Q[:, :r]
        P = Q[:, :r]
        sum_p += P
    P = sum_p / n
    return P

def gene_random_matrix(in_dim, out_dim):
    return np.random.randn(in_dim, out_dim) / np.sqrt(out_dim)

In [10]:
import numpy as np
worker_num = 20
total_sample_size = 1000
d = 10
def genHeLS(total_sample_size, d, worker_num):
    assert total_sample_size % worker_num == 0, "please make sure total_sample can be divided for each worker equally"
    mean = np.random.uniform(low=-1, high=1, size=worker_num)
    scale = np.sqrt(np.random.uniform(low=0.5, high=1.5, size=worker_num))
    a_list = []
    for i in range(worker_num):
        a_list.append(np.random.normal(loc=mean[i], scale=scale[i], size=(total_sample_size // worker_num, d)))
    a = np.concatenate(a_list, axis=0)
    ans = np.random.randn(d, 1)
    b = a @ ans + 0.1 * np.random.randn(total_sample_size, 1)
    return a, b

def solLS(A, b):
    x_sol = np.linalg.inv(A.T@A)@(A.T@b)
    print(A.T@(A@x_sol-b) / len(A))
    return x_sol

In [25]:
pcd = Coordinate_descend_genep(10, 5)
pss = Spherical_smoothing_genep(10, 5)
prd = gene_random_matrix(10, 5)

pcd.T @ pcd

array([[2., 0., 0., 0., 0.],
       [0., 2., 0., 0., 0.],
       [0., 0., 2., 0., 0.],
       [0., 0., 0., 2., 0.],
       [0., 0., 0., 0., 2.]])

In [6]:
np.sum(prd, axis=0)

array([-0.07551295, -0.43366106,  0.36777972,  0.63740934,  0.45284414])

In [7]:
prd

array([[-0.10841447, -0.65429265,  0.5515432 ,  0.38592721, -0.40341137],
       [ 0.35419363,  0.3494931 ,  0.40770049,  0.2784732 ,  0.55071407],
       [-0.06773897,  0.57525541, -0.29882595, -0.2461836 , -0.54525363],
       [ 0.51437674,  0.01865143,  0.06029494, -0.70781416,  0.02839217],
       [-0.02919308,  0.38172098,  0.05224194,  0.62015402, -0.08248262],
       [ 0.50908027,  0.37083368,  0.15635518,  0.49316109,  0.36576279],
       [ 0.08453869, -0.14064056, -0.15259858, -0.52102426, -0.10319151],
       [-0.38667629, -0.79663312, -0.52137695,  0.11998081,  0.28849406],
       [-0.46859249,  0.04807989, -0.37163389,  0.27534065, -0.05068043],
       [-0.47708699, -0.58612922,  0.48407934, -0.06060563,  0.40450061]])