In [1]:
import torch
import numpy as np
from ALOptimizer import ALOptimizer

In [2]:
cuda = False
torch.manual_seed(42)

<torch._C.Generator at 0x2bd57d36310>

In [3]:
import torchvision
import torchvision.transforms.v2 as transforms

dataset = torchvision.datasets.MNIST('/data', download=True, train=True)

batch_size = 16

transform = transforms.Compose(
    [transforms.ToImage(),
     transforms.ToDtype(torch.float32, scale=True)])
trainset = torchvision.datasets.MNIST('/data', train=True, transform=transform)
# trainset = torch.utils.data.
trainset = torch.utils.data.Subset(trainset, np.arange(0, 20000))
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=False, num_workers=0,
                                          generator=torch.Generator(device='cuda') if cuda else None)

testset = torchvision.datasets.MNIST('/data', train=False, transform=transform)

true_test, constr_test = torch.utils.data.random_split(testset, [0.95, 0.05])
constr_test_loader = torch.utils.data.DataLoader(constr_test, batch_size=batch_size,
                                          shuffle=False, num_workers=0,
                                          generator=torch.Generator(device='cuda') if cuda else None)
true_test_loader = torch.utils.data.DataLoader(true_test, batch_size=batch_size,
                                          shuffle=False, num_workers=0,
                                          generator=torch.Generator(device='cuda') if cuda else None)

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self, n_in, n_out):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        # self.double()
    
    def forward(self, input):
        c1 = F.relu(self.conv1(input))
        s2 = F.max_pool2d(c1, (2, 2))
        c3 = F.relu(self.conv2(s2))
        s4 = F.max_pool2d(c3, 2)
        s4 = torch.flatten(s4, 1)
        f5 = F.relu(self.fc1(s4))   
        f6 = F.relu(self.fc2(f5))
        output = self.fc3(f6)
        return output

***
***
A constraint on the loss on a separate subset of data. This is a stochastic constraint.

In [5]:
def test_loss_constr(net, loss_fn, testset, threshold):
    loss = 0
    
    for i, (inputs, labels) in enumerate(testset):
        out = net.forward(inputs)
        loss += loss_fn(out, labels)
    if i > 0:
        loss /= i
    return torch.max((loss-threshold), torch.zeros(1))

In [6]:
n_classes = 10
class_test_net = Net(1, n_classes)
loss = torch.nn.CrossEntropyLoss()

def constraint(net):
    return test_loss_constr(net, loss, constr_test_loader, 0.1)

alo = ALOptimizer(net=class_test_net, loss_fn=loss, m=1, constraint_fn=constraint)

In [None]:
alo.optimize(trainloader, maxiter=3, epochs=3)

0, 1, 2.299325942993164, 2.272899866104126

In [8]:
acc = 0
with torch.no_grad():
    for i, data in enumerate(true_test):
        inputs, labels = data
        out = class_test_net.forward(inputs.unsqueeze(0))
        if np.argmax(out.detach().numpy()) == labels:
            acc+=1

acc/len(true_test)

0.9806315789473684

In [9]:
with torch.no_grad():
    x = test_loss_constr(class_test_net, loss, constr_test_loader, 0)
    print(x)

tensor([0.0446])


In [10]:
with torch.no_grad():
    x = test_loss_constr(class_test_net, loss, true_test_loader, 0)
    print(x)

tensor([0.0688])


***
***

An L2 constraint on the weights. This is a deterministic constraint.

We get

$$
L := l(outputs, labels) + \lambda*h(W) + 0.5 \lambda r * h(W)^2 
$$

where

$$
h(W) := \begin{cases} 0 \qquad \textrm{if} \quad  {||W||_2}^2 - c \leq 0 \\ {||W||_2}^2 -c  \quad \textrm{otherwise} \end{cases}
$$

In [6]:
def total_w_l2_constr(params, c):
    l2 = 0
    for param in params:
        l2 += torch.sum(torch.square(param))
    cval = torch.max(l2 - c, torch.zeros(1, dtype=param.dtype))
    return cval

In [7]:
n_classes = 10
class_test_net = Net(1, n_classes)
loss = torch.nn.CrossEntropyLoss()

def constraint(net):
    return total_w_l2_constr(net.parameters(), 15)

alo = ALOptimizer(net=class_test_net, loss_fn=loss, m=1, constraint_fn=constraint, lr=0.005)

In [7]:
alo.optimize(trainloader, maxiter=5, epochs=5)

-------, 0.5727880597114563, 0.0796251296997070345


-------, 0.6938788294792175, 0.0863895416259765645


-------, 0.6064719557762146, 0.0888954162597656625


-------, 0.6860342025756836, 0.0273464202880859462


-------, 0.6928727030754089, 0.0074728012084961262




In [20]:
l2 = 0
for param in class_test_net.parameters():
    l2 += torch.sum(torch.square(param))
l2

tensor(14.6636, grad_fn=<AddBackward0>)

In [21]:
acc = 0
with torch.no_grad():
    for i, data in enumerate(true_test):
        inputs, labels = data
        out = class_test_net(inputs.unsqueeze(0))
        if np.argmax(out.detach().numpy()) == labels:
            acc+=1

acc/len(true_test)

0.7371578947368421

***


In [8]:
n_classes = 10
class_test_net = Net(1, n_classes)
loss = torch.nn.CrossEntropyLoss()

def constraint(net):
    return total_w_l2_constr(net.parameters(), 15)

alo = ALOptimizer(net=class_test_net, loss_fn=loss, m=1, constraint_fn=constraint, lr=0.005)

In [9]:
alo.optimize_cond(trainloader, maxiter=5, epochs=5)

4, 4, 1249, 0.42003297805786133, 0.05862693786621094555

***
***
Pytorch L2 regularization

In [27]:
n_classes = 10
reg_net = Net(1, n_classes)
loss = torch.nn.CrossEntropyLoss()
op = torch.optim.Adam(reg_net.parameters(), lr=0.005, weight_decay=1e-2)

for epoch in range(5):
    for i, data in enumerate(trainloader):
        op.zero_grad()
        inputs, labels = data
        outputs = reg_net(inputs)
        loss_eval = loss(outputs, labels)
        loss_eval.backward()
        op.step()

In [28]:
l2 = 0
for param in reg_net.parameters():
    l2 += torch.sum(torch.square(param))
l2

tensor(48.7833, grad_fn=<AddBackward0>)

In [29]:
acc = 0
with torch.no_grad():
    for i, data in enumerate(true_test):
        inputs, labels = data
        out = reg_net(inputs.unsqueeze(0))
        if np.argmax(out.detach().numpy()) == labels:
            acc+=1

acc/len(true_test)

0.9395789473684211

***
***
profiling

In [17]:
from torch.profiler import profile, record_function, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    with record_function("model_inference"):
        total_w_l2_constr(class_test_net.parameters(), 15)

In [18]:
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

-----------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-----------------------  ------------  ------------  ------------  ------------  ------------  ------------  
        model_inference        50.87%       1.959ms       100.00%       3.851ms       3.851ms             1  
           aten::square         5.67%     218.400us        21.42%     824.800us      82.480us            10  
              aten::pow        15.58%     600.100us        15.75%     606.400us      60.640us            10  
              aten::sum        13.14%     505.900us        13.54%     521.600us      52.160us            10  
              aten::add         0.61%      23.600us        11.58%     446.000us     446.000us             1  
               aten::to         5.12%     197.000us        11.36%     437.300us      36.442us            12  
         a