Profiling decorrelation and node perturbation

In [13]:
import numpy as np
import torch
import torch.nn as nn
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader, Subset
from node_perturbation.node_perturbation import NPLinear
from node_perturbation.utils import np_train
from decorrelation.utils import decor_train
from decorrelation.decorrelation import Decorrelation, DecorLinear
import matplotlib.pyplot as plt
import argparse

import cProfile
import pstats
from torch.profiler import profile, record_function, ProfilerActivity

# Run on the first CUDA device when one is visible, otherwise stay on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Seed every RNG in play (NumPy, PyTorch CPU, PyTorch CUDA) with the same value
# so that repeated runs of the notebook start from the same random state.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

Testing the effect of decorrelating after larger intervals

Profiling node perturbation

In [5]:
# Hyperparameters for the profiled run, bundled in a Namespace because np_train
# takes them as a single `args` object.
args = argparse.Namespace(lr=1e-4, decor_lr=1e-1, kappa=1e-2, epochs=3, full=True)
lossfun = torch.nn.CrossEntropyLoss().to(device)

# Noise distribution handed to NPLinear.
# This was previously the chained assignment
#     sampler = torch.distributions.Distribution = torch.distributions.Normal(...)
# which also rebound the `torch.distributions.Distribution` base class to this
# Normal instance for the rest of the kernel session. A type annotation is what
# was intended, so the library attribute is no longer overwritten.
sampler: torch.distributions.Distribution = torch.distributions.Normal(0.0, 1e-3)

# Variant without the decorrelation layer, kept for reference:
# model = NPLinear(784, 10, sampler=sampler, device=device)
# model, L1, D, T = np_train(args, model, lossfun, train_loader, device)

# Decorrelation over the 784 inputs, followed by one 784 -> 10 NPLinear layer.
model = nn.Sequential(
    Decorrelation(784, decor_lr=args.decor_lr, kappa=args.kappa, full=args.full),
    NPLinear(784, 10, sampler=sampler, device=device),
)

# NOTE(review): `train_loader` is not defined in any cell visible in this
# notebook; it presumably comes from a removed or out-of-order cell. Confirm it
# exists before running on a fresh kernel.
# cProfile.run executes the statement string in the __main__ namespace and
# writes the collected statistics to the file 'restats' in the working directory.
cProfile.run('np_train(args, model, lossfun, train_loader, device)', 'restats')

# Report the ten most expensive entries, first by cumulative time (callees
# included), then by time spent inside each function itself.
p = pstats.Stats('restats')
p.sort_stats('cumulative').print_stats(10)
p.sort_stats('time').print_stats(10)


epoch 0  	time:0.000 s	bp loss: 5.198018	decorrelation loss: 6.596981
epoch 1  	time:0.658 s	bp loss: 5.074780	decorrelation loss: 2.382838
epoch 2  	time:0.668 s	bp loss: 5.007250	decorrelation loss: 0.513329
epoch 3  	time:0.662 s	bp loss: 4.983179	decorrelation loss: 0.261576
Wed Mar 13 14:21:15 2024    restats

         564699 function calls (559649 primitive calls) in 2.296 seconds

   Ordered by: cumulative time
   List reduced from 306 to 10 due to restriction <10>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    2.296    2.296 {built-in method builtins.exec}
        1    0.000    0.000    2.296    2.296 <string>:1(<module>)
        1    0.008    0.008    2.296    2.296 /Users/marcel.vangerven/Code/github/decorrelation/node_perturbation/utils.py:7(np_train)
       45    0.001    0.000    1.298    0.029 /Users/marcel.vangerven/Code/github/decorrelation/decorrelation/decorrelation.py:14(decor_update)
       45    1.275    0.02

<pstats.Stats at 0x10f973410>

PyTorch profiling. See https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html for more examples.

In [6]:
# Same hyperparameters as the cProfile run, redefined here so this cell does not
# depend on the previous one having been executed.
args = argparse.Namespace(lr=1e-4, decor_lr=1e-1, kappa=1e-2, epochs=3, full=True)
lossfun = torch.nn.CrossEntropyLoss().to(device)

# Noise distribution handed to NPLinear.
# This was previously the chained assignment
#     sampler = torch.distributions.Distribution = torch.distributions.Normal(...)
# which also rebound the `torch.distributions.Distribution` base class to this
# Normal instance for the rest of the kernel session. A type annotation is what
# was intended, so the library attribute is no longer overwritten.
sampler: torch.distributions.Distribution = torch.distributions.Normal(0.0, 1e-3)

# Build a fresh model so the profile is not affected by earlier training.
model = nn.Sequential(
    Decorrelation(784, decor_lr=args.decor_lr, kappa=args.kappa, full=args.full),
    NPLinear(784, 10, sampler=sampler, device=device),
)

# Only CPU activity is recorded. record_shapes=True keeps operator input shapes
# so the results can later be grouped with key_averages(group_by_input_shape=True).
# NOTE(review): `train_loader` is not defined in any cell visible in this
# notebook; confirm it exists before running on a fresh kernel.
with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    # Label the whole training call so it shows up as one row in the report.
    with record_function("model_training"):
        np_train(args, model, lossfun, train_loader, device)

STAGE:2024-03-13 14:21:15 68950:9119183 ActivityProfilerController.cpp:314] Completed Stage: Warm Up


epoch 0  	time:0.000 s	bp loss: 5.074006	decorrelation loss: 6.577436
epoch 1  	time:0.650 s	bp loss: 4.999786	decorrelation loss: 2.381298
epoch 2  	time:0.660 s	bp loss: 4.955505	decorrelation loss: 0.510715
epoch 3  	time:0.652 s	bp loss: 4.917325	decorrelation loss: 0.261559


STAGE:2024-03-13 14:21:18 68950:9119183 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-03-13 14:21:18 68950:9119183 ActivityProfilerController.cpp:324] Completed Stage: Post Processing


In [7]:
# Average the recorded events per operator name and print the ten rows with the
# largest total CPU time.
op_table = prof.key_averages().table(
    sort_by="cpu_time_total",
    row_limit=10,
)
print(op_table)

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         model_training         2.67%      60.702ms       100.00%        2.274s        2.274s             1  
                                           aten::matmul         0.04%       1.013ms        66.10%        1.503s       6.679ms           225  
                                               aten::mm        66.06%        1.502s        66.06%        1.502s       6.675ms           225  
enumerate(DataLoader)#_SingleProcessDataLoaderIter._...        16.32%     370.989ms        25.00%     568.347ms       8.880ms            64  
      

In [8]:
# Same report, but with events additionally grouped by operator input shapes,
# so calls of one operator on different shapes appear as separate rows.
shape_table = prof.key_averages(group_by_input_shape=True).table(
    sort_by="cpu_time_total",
    row_limit=10,
)
print(shape_table)

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  -------------------------------------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls                                 Input Shapes  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  -------------------------------------------  
                                         model_training         2.67%      60.702ms       100.00%        2.274s        2.274s             1                                           []  
                                           aten::matmul         0.02%     343.000us        45.31%        1.030s      22.892ms            45                     [[784, 784], [784, 784]]  
                                               aten::mm        45