<a href="https://colab.research.google.com/github/Maokami/he-dnn/blob/main/cifar10.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# NOTE(review): `pbd` is a third-party package downloaded from PyPI (see the
# install log), not the standard-library debugger `pdb`, and nothing else in
# this notebook references `pbd`. Confirm it is really needed (or whether
# `import pdb` was intended) before installing an unpinned package.
!pip install pbd
import pbd

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pbd
  Downloading pbd-0.9.zip (3.9 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: pbd
  Building wheel for pbd (setup.py) ... [?25l[?25hdone
  Created wheel for pbd: filename=pbd-0.9-py3-none-any.whl size=3901 sha256=b32af25959da448d789a41394604a6633323e6bb8e954907b485c794d26da170
  Stored in directory: /root/.cache/pip/wheels/11/e6/3c/d392e61cd24131b41765167d17227decc96693f8e0625c809f
Successfully built pbd
Installing collected packages: pbd
Successfully installed pbd-0.9


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.autograd as autograd
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets, transforms
from tqdm import tqdm

class SquareActivation(nn.Module):
  """Parameter-free element-wise square activation: f(x) = x**2."""

  def forward(self, x):
    return torch.square(x)

class QuadraticActivation(nn.Module):
  """Learnable quadratic activation f(x) = a*x**2 + b*x (no constant term).

  `a` and `b` are scalar parameters of shape (1,), each drawn from a normal
  distribution with standard deviation 0.01 at construction time (`a` first).
  """

  def __init__(self):
    super().__init__()
    init_std = 0.01
    self.a = nn.Parameter(torch.randn(1) * init_std)
    self.b = nn.Parameter(torch.randn(1) * init_std)

  def forward(self, x):
    quadratic_term = self.a * x.pow(2)
    linear_term = self.b * x
    return quadratic_term + linear_term

class CIFAR10Model(nn.Module):
  """Small CNN classifier for 3-channel 32x32 images (CIFAR-10 by default).

  Layout: 2 x [conv-act-conv-act-pool-dropout] -> flatten -> fc(512)-act-dropout
  -> fc(num_classes). Module indices inside `layers` are identical for every
  variant, so conv/linear weights sit at the same `layers.<i>` keys.

  Activation selection (first match wins):
    * `poly_activation` given -> that module
    * `use_poly`              -> SquareActivation (x**2)
    * `use_quadratic`         -> QuadraticActivation (a*x**2 + b*x)
    * otherwise               -> nn.ReLU
  Only the ReLU baseline uses max pooling; every other variant uses average
  pooling.

  NOTE: a single activation instance is reused at every activation position, so
  a learnable activation (e.g. QuadraticActivation) shares its parameters across
  all layers.

  Args:
    use_poly: use the square activation.
    use_quadratic: use the learnable quadratic activation.
    poly_activation: optional custom activation module; takes precedence.
    num_classes: number of output logits (default 10).
  """

  def __init__(self, use_poly=False, use_quadratic=False, poly_activation=None, num_classes=10):
    super().__init__()
    if poly_activation is not None:
      self.activation = poly_activation
    elif use_poly:
      self.activation = SquareActivation()
    elif use_quadratic:
      self.activation = QuadraticActivation()
    else:
      self.activation = nn.ReLU()

    # Decide pooling from the flags (not isinstance) so an explicitly supplied
    # poly_activation always gets average pooling, as before.
    uses_relu = poly_activation is None and not use_poly and not use_quadratic
    pool_cls = nn.MaxPool2d if uses_relu else nn.AvgPool2d
    act = self.activation

    self.layers = nn.Sequential(
      nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
      act,
      nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=0),
      act,
      pool_cls(kernel_size=2, stride=2, padding=0),
      nn.Dropout(0.25),
      nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
      act,
      nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
      act,
      pool_cls(kernel_size=2, stride=2, padding=0),
      nn.Dropout(0.25),
      nn.Flatten(),
      # Spatial size for a 32x32 input: 32 -> 32 -> 30 -> pool 15 -> 15 -> 13 -> pool 6.
      nn.Linear(6 * 6 * 64, 512),
      act,
      nn.Dropout(0.5),
      nn.Linear(512, num_classes),
    )

  def forward(self, x):
    return self.layers(x)


In [None]:
# Make sure autograd anomaly detection is off (it is slow); this undoes any
# earlier enable in the same kernel session. Used as a plain function call, so
# the trailing `;` suppresses the echo of the returned object.
autograd.set_detect_anomaly(False);

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x7f099c030550>

In [None]:
# Experiment configuration shared by the training cells below.
SEED = 0
batch_size = 32
epochs = 30

# Seed torch's global RNG (CPU and CUDA) so weight initialisation, dropout masks
# and the DataLoader shuffle order repeat across Restart & Run All.
# GPU kernels (e.g. cuDNN) may still introduce small run-to-run differences.
torch.manual_seed(SEED)

# ToTensor() yields pixels in [0, 1]; Normalize with mean/std 0.5 maps them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)


Files already downloaded and verified
Files already downloaded and verified
cuda:0


In [None]:
def train_dnn(model, trainloader, epochs, device, max_grad_norm=None):
    """Train `model` in place with Adam (default hyper-parameters) and cross-entropy.

    Args:
        model: nn.Module producing class logits; it is moved to `device`.
        trainloader: iterable of (inputs, labels) batches.
        epochs: number of full passes over `trainloader`.
        device: device to train on.
        max_grad_norm: optional. If given, the global gradient L2 norm is
            clipped to this value before every optimiser step; the default
            None leaves gradients untouched.

    Returns:
        The same `model` object, left in training mode.
    """
    model.to(device)
    # Re-enable dropout: dnn_inference() puts the model in eval mode, so a model
    # that was evaluated before would otherwise be trained without dropout.
    model.train()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters())

    for epoch in range(epochs):
        running_loss = 0.0
        progress_bar = tqdm(trainloader, desc=f"Epoch {epoch + 1}")
        for batch_idx, (inputs, labels) in enumerate(progress_bar):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            if max_grad_norm is not None:
                nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()

            # Show the mean loss over the batches seen so far in this epoch.
            running_loss += loss.item()
            progress_bar.set_postfix({"Loss": running_loss / (batch_idx + 1)})

    return model

def dnn_inference(model, testloader, device):
    """Return the top-1 accuracy of `model` on `testloader`, as a percentage.

    The model is moved to `device` and switched to eval mode (it stays in eval
    mode afterwards). No gradients are tracked.

    Args:
        model: nn.Module producing class logits.
        testloader: iterable of (images, labels) batches.
        device: device to evaluate on.

    Returns:
        float: 100 * correct / total.

    Raises:
        ValueError: if `testloader` yields no samples.
    """
    model.to(device)
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in tqdm(testloader, desc="Inference"):
            images, labels = images.to(device), labels.to(device)
            # argmax returns the first maximal index on ties, like torch.max.
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError("testloader yielded no samples; accuracy is undefined")
    return 100 * correct / total

In [None]:
# Baseline: ReLU activations with max pooling, trained from scratch.
original_model = train_dnn(CIFAR10Model(), trainloader, epochs, device)
accuracy = dnn_inference(original_model, testloader, device)
print("Original Model", original_model, f'Accuracy: {accuracy:.2f}%\n', sep="\n")

Epoch 1: 100%|██████████| 1563/1563 [00:34<00:00, 45.93it/s, Loss=1.54]
Epoch 2: 100%|██████████| 1563/1563 [00:26<00:00, 58.88it/s, Loss=1.17]
Epoch 3: 100%|██████████| 1563/1563 [00:26<00:00, 58.27it/s, Loss=1.01]
Epoch 4: 100%|██████████| 1563/1563 [00:26<00:00, 58.14it/s, Loss=0.919]
Epoch 5: 100%|██████████| 1563/1563 [00:26<00:00, 58.66it/s, Loss=0.859]
Epoch 6: 100%|██████████| 1563/1563 [00:26<00:00, 58.68it/s, Loss=0.815]
Epoch 7: 100%|██████████| 1563/1563 [00:26<00:00, 59.28it/s, Loss=0.78]
Epoch 8: 100%|██████████| 1563/1563 [00:26<00:00, 58.72it/s, Loss=0.747]
Epoch 9: 100%|██████████| 1563/1563 [00:27<00:00, 55.93it/s, Loss=0.727]
Epoch 10: 100%|██████████| 1563/1563 [00:28<00:00, 55.71it/s, Loss=0.71]
Epoch 11: 100%|██████████| 1563/1563 [00:28<00:00, 55.57it/s, Loss=0.688]
Epoch 12: 100%|██████████| 1563/1563 [00:27<00:00, 56.77it/s, Loss=0.672]
Epoch 13: 100%|██████████| 1563/1563 [00:26<00:00, 60.10it/s, Loss=0.659]
Epoch 14: 100%|██████████| 1563/1563 [00:26<00:00, 5

Original Model
CIFAR10Model(
  (activation): ReLU()
  (layers): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
    (3): ReLU()
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Dropout(p=0.25, inplace=False)
    (6): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU()
    (8): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (9): ReLU()
    (10): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (11): Dropout(p=0.25, inplace=False)
    (12): Flatten(start_dim=1, end_dim=-1)
    (13): Linear(in_features=2304, out_features=512, bias=True)
    (14): ReLU()
    (15): Dropout(p=0.5, inplace=False)
    (16): Linear(in_features=512, out_features=10, bias=True)
  )
)
Accuracy: 79.34%






In [None]:
# Transplant the ReLU-trained weights into the x**2 / average-pooling variant
# without retraining. In the recorded run this dropped accuracy to ~10%,
# i.e. chance level for 10 classes.
poly_model = CIFAR10Model(use_poly=True)
poly_model.load_state_dict(original_model.state_dict())
accuracy = dnn_inference(poly_model, testloader, device)
print("Poly Model", poly_model, f'Accuracy: {accuracy:.2f}%\n', sep="\n")

Inference: 100%|██████████| 313/313 [00:03<00:00, 92.01it/s] 

Poly Model
CIFAR10Model(
  (activation): SquareActivation()
  (layers): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): SquareActivation()
    (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
    (3): SquareActivation()
    (4): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (5): Dropout(p=0.25, inplace=False)
    (6): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): SquareActivation()
    (8): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (9): SquareActivation()
    (10): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (11): Dropout(p=0.25, inplace=False)
    (12): Flatten(start_dim=1, end_dim=-1)
    (13): Linear(in_features=2304, out_features=512, bias=True)
    (14): SquareActivation()
    (15): Dropout(p=0.5, inplace=False)
    (16): Linear(in_features=512, out_features=10, bias=True)
  )
)
Accuracy: 9.99%






In [None]:
# Train the x**2 / average-pooling variant from scratch with the same schedule.
# In the recorded run the loss spiked (~1.4e5 in epoch 2) before settling.
retrained_model = train_dnn(CIFAR10Model(use_poly=True), trainloader, epochs, device)
accuracy = dnn_inference(retrained_model, testloader, device)
print("Retrained Model", retrained_model, f'Accuracy: {accuracy:.2f}%\n', sep="\n")

Epoch 1: 100%|██████████| 1563/1563 [00:26<00:00, 59.06it/s, Loss=462]
Epoch 2: 100%|██████████| 1563/1563 [00:26<00:00, 59.15it/s, Loss=1.37e+5]
Epoch 3: 100%|██████████| 1563/1563 [00:26<00:00, 58.75it/s, Loss=16.9]
Epoch 4: 100%|██████████| 1563/1563 [00:26<00:00, 59.07it/s, Loss=7.23]
Epoch 5: 100%|██████████| 1563/1563 [00:27<00:00, 57.58it/s, Loss=3.85]
Epoch 6: 100%|██████████| 1563/1563 [00:26<00:00, 58.37it/s, Loss=3.02]
Epoch 7: 100%|██████████| 1563/1563 [00:26<00:00, 59.06it/s, Loss=2.73]
Epoch 8: 100%|██████████| 1563/1563 [00:26<00:00, 59.23it/s, Loss=2.48]
Epoch 9: 100%|██████████| 1563/1563 [00:26<00:00, 58.95it/s, Loss=2.38]
Epoch 10: 100%|██████████| 1563/1563 [00:26<00:00, 58.93it/s, Loss=2.32]
Epoch 11: 100%|██████████| 1563/1563 [00:26<00:00, 59.48it/s, Loss=2.29]
Epoch 12: 100%|██████████| 1563/1563 [00:26<00:00, 59.15it/s, Loss=2.29]
Epoch 13: 100%|██████████| 1563/1563 [00:26<00:00, 59.53it/s, Loss=2.29]
Epoch 14: 100%|██████████| 1563/1563 [00:26<00:00, 59.68it

Retrained Model
CIFAR10Model(
  (activation): SquareActivation()
  (layers): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): SquareActivation()
    (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
    (3): SquareActivation()
    (4): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (5): Dropout(p=0.25, inplace=False)
    (6): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): SquareActivation()
    (8): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (9): SquareActivation()
    (10): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (11): Dropout(p=0.25, inplace=False)
    (12): Flatten(start_dim=1, end_dim=-1)
    (13): Linear(in_features=2304, out_features=512, bias=True)
    (14): SquareActivation()
    (15): Dropout(p=0.5, inplace=False)
    (16): Linear(in_features=512, out_features=10, bias=True)
  )
)
Accuracy: 31.14%






In [None]:
# Learnable quadratic activation a*x**2 + b*x with average pooling, trained from scratch.
# CIFAR10Model reuses one QuadraticActivation instance at every activation site,
# so the single (a, b) pair printed below is shared by all layers.
# A dedicated epoch budget is used so the shared `epochs` setting is not
# silently overwritten for cells that run after this one.
quadratic_epochs = 50
quadratic_model = CIFAR10Model(use_quadratic=True)
print(quadratic_model.activation.a)
print(quadratic_model.activation.b)
quadratic_model = train_dnn(quadratic_model, trainloader, quadratic_epochs, device)
accuracy = dnn_inference(quadratic_model, testloader, device)
print(quadratic_model.activation.a)
print(quadratic_model.activation.b)
print("Quadratic Model")
print(quadratic_model)
print(f'Accuracy: {accuracy:.2f}%\n')

Parameter containing:
tensor([0.0081], requires_grad=True)
Parameter containing:
tensor([0.0070], requires_grad=True)


Epoch 1: 100%|██████████| 1563/1563 [00:27<00:00, 56.48it/s, Loss=2.3]
Epoch 2: 100%|██████████| 1563/1563 [00:27<00:00, 56.24it/s, Loss=2.3]
Epoch 3: 100%|██████████| 1563/1563 [00:27<00:00, 56.24it/s, Loss=2.3]
Epoch 4: 100%|██████████| 1563/1563 [00:28<00:00, 55.67it/s, Loss=2.3]
Epoch 5: 100%|██████████| 1563/1563 [00:28<00:00, 55.65it/s, Loss=2.3]
Epoch 6: 100%|██████████| 1563/1563 [00:28<00:00, 55.79it/s, Loss=2.29]
Epoch 7: 100%|██████████| 1563/1563 [00:28<00:00, 55.16it/s, Loss=1.84]
Epoch 8: 100%|██████████| 1563/1563 [00:28<00:00, 54.98it/s, Loss=1.61]
Epoch 9: 100%|██████████| 1563/1563 [00:28<00:00, 55.45it/s, Loss=1.46]
Epoch 10: 100%|██████████| 1563/1563 [00:28<00:00, 55.80it/s, Loss=3.75]
Epoch 11: 100%|██████████| 1563/1563 [00:27<00:00, 56.51it/s, Loss=1.9]
Epoch 12: 100%|██████████| 1563/1563 [00:28<00:00, 55.70it/s, Loss=1.84]
Epoch 13: 100%|██████████| 1563/1563 [00:28<00:00, 55.80it/s, Loss=1.8]
Epoch 14: 100%|██████████| 1563/1563 [00:27<00:00, 56.15it/s, Loss=

Parameter containing:
tensor([0.0195], device='cuda:0', requires_grad=True)
Parameter containing:
tensor([-0.0850], device='cuda:0', requires_grad=True)
Quadratic Model
CIFAR10Model(
  (activation): QuadraticActivation()
  (layers): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): QuadraticActivation()
    (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
    (3): QuadraticActivation()
    (4): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (5): Dropout(p=0.25, inplace=False)
    (6): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): QuadraticActivation()
    (8): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (9): QuadraticActivation()
    (10): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (11): Dropout(p=0.25, inplace=False)
    (12): Flatten(start_dim=1, end_dim=-1)
    (13): Linear(in_features=2304, out_features=512, bias=True)
    (14): QuadraticActivation()
    (15): Dropout(p=0.5, inplace=

In [None]:
# Composite-polynomial ReLU approximation (alpha = 13, B = 50 per the original note).
# Source : Precise Approximation of Convolutional Neural Networks for Homomorphically Encrypted Data
# NOTE(review): the code scales inputs by 1/1000, not 1/50 -- confirm which bound B is intended.
epochs = 30
class SnuActivation(nn.Module):
  """ReLU approximated with composite odd polynomials.

  Computes (t + t * y) / 2 * scale with t = x / scale and
  y = p3(p2(p1(t))), where p1, p2, p3 are odd polynomials of degree 15, 15 and 27.
  Because (t + t * sign(t)) / 2 == relu(t), y plays the role of sign(t).
  The approximation targets |x / scale| <= 1; outside that range the
  high-degree terms dominate.

  Args:
    scale: input scaling bound (default 1000.0, the value used originally).
  """

  # Coefficients of the odd powers t**1, t**3, t**5, ... (lowest degree first).
  _P1 = (24.5589415425004, -669.660449716894, 6672.99848301339, -30603.6656163898,
         73188.4032987787, -94443.3217050084, 62325.4094212546, -16494.6744117805)
  _P2 = (9.35625636035439, -59.16389639336264, 148.860930626448, -175.812874878582,
         109.111299685955, -36.6768839978755, 6.31846290311294, -0.437113415082177)
  _P3 = (5.07813569758861, -30.7329918137186, 144.109746812809, -459.661688826142,
         1021.52064470459, -1620.56256708877, 1864.67646416570, -1567.49300877143,
         960.970309093422, -424.326161871646, 131.278509256003, -26.9812576626115,
         3.30651387315565, -0.182742944627533)

  def __init__(self, scale=1000.0):
    super().__init__()
    self.scale = scale

  @staticmethod
  def _odd_poly(t, coeffs):
    """Evaluate sum_i coeffs[i] * t**(2*i + 1) with Horner's rule in t**2."""
    t_sq = t * t
    acc = coeffs[-1]
    for c in reversed(coeffs[:-1]):
      acc = acc * t_sq + c
    return acc * t

  def forward(self, x):
    t = x / self.scale
    y = self._odd_poly(t, self._P1)
    y = self._odd_poly(y, self._P2)
    y = self._odd_poly(y, self._P3)
    return (t + t * y) / 2 * self.scale

# Composite-polynomial ReLU approximation with average pooling, trained from scratch.
snu_model = train_dnn(CIFAR10Model(poly_activation=SnuActivation()), trainloader, epochs, device)
accuracy = dnn_inference(snu_model, testloader, device)
print("Snu Model", snu_model, f'Accuracy: {accuracy:.2f}%\n', sep="\n")

Epoch 1: 100%|██████████| 1563/1563 [01:04<00:00, 24.25it/s, Loss=1.44]
Epoch 2: 100%|██████████| 1563/1563 [01:02<00:00, 25.12it/s, Loss=0.98]
Epoch 3: 100%|██████████| 1563/1563 [01:02<00:00, 25.11it/s, Loss=0.797]
Epoch 4: 100%|██████████| 1563/1563 [01:03<00:00, 24.62it/s, Loss=0.694]
Epoch 5: 100%|██████████| 1563/1563 [01:02<00:00, 24.85it/s, Loss=0.615]
Epoch 6: 100%|██████████| 1563/1563 [01:02<00:00, 24.96it/s, Loss=0.561]
Epoch 7: 100%|██████████| 1563/1563 [01:02<00:00, 24.86it/s, Loss=0.513]
Epoch 8: 100%|██████████| 1563/1563 [01:02<00:00, 24.81it/s, Loss=0.472]
Epoch 9: 100%|██████████| 1563/1563 [01:02<00:00, 24.91it/s, Loss=0.441]
Epoch 10: 100%|██████████| 1563/1563 [01:03<00:00, 24.64it/s, Loss=0.422]
Epoch 11: 100%|██████████| 1563/1563 [01:03<00:00, 24.63it/s, Loss=0.396]
Epoch 12: 100%|██████████| 1563/1563 [01:03<00:00, 24.62it/s, Loss=0.376]
Epoch 13: 100%|██████████| 1563/1563 [01:03<00:00, 24.47it/s, Loss=0.359]
Epoch 14: 100%|██████████| 1563/1563 [01:03<00:00

Snu Model
CIFAR10Model(
  (activation): SnuActivation()
  (layers): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): SnuActivation()
    (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
    (3): SnuActivation()
    (4): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (5): Dropout(p=0.25, inplace=False)
    (6): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): SnuActivation()
    (8): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (9): SnuActivation()
    (10): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (11): Dropout(p=0.25, inplace=False)
    (12): Flatten(start_dim=1, end_dim=-1)
    (13): Linear(in_features=2304, out_features=512, bias=True)
    (14): SnuActivation()
    (15): Dropout(p=0.5, inplace=False)
    (16): Linear(in_features=512, out_features=10, bias=True)
  )
)
Accuracy: 80.99%




