In [1]:
from utils import *
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from pickle import dump
import os 
#os.environ['PYTORCH_NO_CUDA_MEMORY_CACHING'] = '1'
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:21"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def Coordinate_descend_genep(dim, comp_dim):
    assert dim >= comp_dim, "compression dimension must be smaller than dimension"
    ide = torch.eye(dim, requires_grad=False)
    select_col = torch.randperm(dim)[:comp_dim]
    sign = torch.randint(0, 2, (comp_dim, ))
    sign = sign * 2 - 1
    # XXX make clear whether PTP is I or PPT is I
    P = ide[:, select_col] * sign
    return P

def mem():
    #torch.cuda.empty_cache()
    print('memory allocated: ' + str((torch.cuda.memory_allocated() / 1024)) )
    print('memory reserved: ' + str(torch.cuda.memory_reserved() / 1024) )
    print('max memory allocated: ' + str(torch.cuda.max_memory_allocated() / 1024) )
    print('max memory reserved: ' + str(torch.cuda.max_memory_reserved() / 1024) )

class SubScafLinearTest(nn.Module):
    """
    Linear network with compressed dimension
    """
    def __init__(self, comp_dim: int, comp_mat: torch.Tensor, wraped_model: nn.Linear):
        self.comp_mat = comp_mat
        self.comp_dim = comp_dim
        device = wraped_model.weight.device
        dtype = wraped_model.weight.dtype
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.x = wraped_model.weight.detach().clone()
        super().__init__()
        self.b = nn.Parameter(torch.zeros((comp_dim, wraped_model.in_features), **factory_kwargs))
    
    def comp_mul(self, b):
        return self.comp_mat @ b + self.x

    def forward(self, input):
        #def compute_linear(input):
            #weight = self.comp_mat @ self.b + self.x
            #return F.linear(input, weight)
        #return checkpoint(compute_linear, input, use_reentrant=False)
        mem()
        x = F.linear(self.comp_mat, self.b.T)
        print(x.grad_fn)
        mem()
        x = self.x + x
        print(x.grad_fn)
        mem()
        output = F.linear(input, x)
        print(output.grad_fn)
        mem()
        return output 

class SubScafLayer(nn.Module):
    def __init__(self, comp_dim, comp_mat, wraped_model):
        super().__init__()
        self.layer1 = SubScafLinearTest(comp_dim, comp_mat, wraped_model)
        self.layer2 = SubScafLinearTest(comp_dim, comp_mat, wraped_model)
    
    def forward(self, input):
        x = checkpoint(self.layer1, input)
        x = self.layer2(x)
        return x

In [5]:
dir(model)

['T_destination',
 '__annotations__',
 '__call__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattr__',
 '__getattribute__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setstate__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_apply',
 '_backward_hooks',
 '_backward_pre_hooks',
 '_buffers',
 '_call_impl',
 '_compiled_call_impl',
 '_forward_hooks',
 '_forward_hooks_always_called',
 '_forward_hooks_with_kwargs',
 '_forward_pre_hooks',
 '_forward_pre_hooks_with_kwargs',
 '_get_backward_hooks',
 '_get_backward_pre_hooks',
 '_get_name',
 '_is_full_backward_hook',
 '_load_from_state_dict',
 '_load_state_dict_post_hooks',
 '_load_state_dict_pre_hooks',
 '_maybe_warn_non_full_backward_hook',
 '_modules',
 '_named_members',
 '_non_persistent_buffers_se

In [3]:
comp_dim = 64
wraped_module = nn.Linear(512, 512, bias=False)
comp_mat = Coordinate_descend_genep(wraped_module.out_features, comp_dim)
model = SubScafLinearTest(comp_dim, comp_mat, wraped_module)

activation_values = []
def hook(module, input, output):
    activation_values.append(output.detach().clone())
handle = model.register_forward_hook(hook)

In [3]:
#torch.cuda.memory._record_memory_history(enabled='all')
device = "cuda:0"
comp_dim = 64
wraped_module = nn.Linear(512, 512, bias=False).to(device)
comp_mat = Coordinate_descend_genep(wraped_module.out_features, comp_dim).to(device)
model = SubScafLinearTest(comp_dim, comp_mat, wraped_module).to(device)
input_data = torch.eye(512).to(device)
output = model(input_data)
mem()
output.sum().backward()
opt = torch.optim.SGD(model.parameters())
opt.zero_grad()
#s = torch.cuda.memory._snapshot()
#with open(f"snapshot.pickle", "wb") as f:
    #dump(s, f)

memory allocated: 3328.0
<MmBackward0 object at 0x78015159fe80>
memory allocated: 12672.0
<AddBackward0 object at 0x78015159fe50>
memory allocated: 12672.0
<MmBackward0 object at 0x78023ebacd30>
memory allocated: 13696.0
memory allocated: 12672.0


In [27]:

output = model(input_data)
mem()
output.sum().backward()
del output
mem()

memory allocated: 19968.0
memory allocated: 19072.0


In [3]:
#torch.cuda.memory._record_memory_history(enabled='all')
device = "cuda:0"
comp_dim = 64
wraped_module = nn.Linear(512, 512, bias=False).to(device)
comp_mat = Coordinate_descend_genep(wraped_module.out_features, comp_dim).to(device)
model = SubScafLayer(comp_dim, comp_mat, wraped_module).to(device)
del wraped_module
mem()
input_data = torch.eye(512).to(device)
mem()
output = model(input_data)
mem()
#del output
output.sum().backward()
mem()
#s = torch.cuda.memory._snapshot()
#with open(f"snapshot.pickle", "wb") as f:
    #dump(s, f)

memory allocated: 2432.0
memory reserved: 4096.0
max memory allocated: 3456.0
max memory reserved: 4096.0
memory allocated: 3456.0
memory reserved: 4096.0
max memory allocated: 3456.0
max memory reserved: 4096.0
memory allocated: 3456.0
memory reserved: 4096.0
max memory allocated: 3456.0
max memory reserved: 4096.0
None
memory allocated: 12800.0
memory reserved: 26624.0
max memory allocated: 12800.0
max memory reserved: 26624.0
None
memory allocated: 12800.0
memory reserved: 26624.0
max memory allocated: 13824.0
max memory reserved: 26624.0
None
memory allocated: 13824.0
memory reserved: 26624.0
max memory allocated: 13824.0
max memory reserved: 26624.0
memory allocated: 12800.0
memory reserved: 26624.0
max memory allocated: 13824.0
max memory reserved: 26624.0
<MmBackward0 object at 0x7ec61479b6d0>
memory allocated: 13824.0
memory reserved: 26624.0
max memory allocated: 13824.0
max memory reserved: 26624.0
<AddBackward0 object at 0x7ec61479a170>
memory allocated: 13824.0
memory reser



In [3]:
device = "cuda:0"
mem()
model = nn.Linear(32, 32, bias=False).to(device)
mem()
input_data = torch.randn((1, 32)).to(device)
mem()
output = model(input_data)
#with torch.no_grad():
    #output = model(input_data)
mem()

output.sum().backward()

mem()

memory allocated: 0.0
memory reserved: 0.0
max memory allocated: 0.0
max memory reserved: 0.0
memory allocated: 4.0
memory reserved: 2048.0
max memory allocated: 4.0
max memory reserved: 2048.0
memory allocated: 4.5
memory reserved: 2048.0
max memory allocated: 4.5
max memory reserved: 2048.0
memory allocated: 8325.0
memory reserved: 22528.0
max memory allocated: 8325.0
max memory reserved: 22528.0
memory allocated: 16649.0
memory reserved: 22528.0
max memory allocated: 16650.5
max memory reserved: 22528.0


In [9]:
model.weight.device

device(type='cuda', index=0)

In [None]:
del model