In [1]:
# Imports
import torch
import torch.nn.functional as F  # Parameterless functions, like (some) activation functions
import torchvision.datasets as datasets  # Standard datasets
import torchvision.transforms as transforms  # Transformations we can perform on our dataset for augmentation
from torch import optim  # For optimizers like SGD, Adam, etc.
from torch import nn  # All neural network modules
from torch.utils.data import (
    DataLoader,
)  # Gives easier dataset managment by creating mini batches etc.
from tqdm import tqdm  # For nice progress bar!

In [2]:
import numpy as np

**Standard Relu**

In [3]:
# Simple CNN
class CNN(nn.Module):
    def __init__(self, in_channels=1, num_classes=10):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=8,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(
            in_channels=8,
            out_channels=16,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.fc1 = nn.Linear(16 * 7 * 7, 250) # replace this with siren.
        self.fc2 = nn.Linear(250,250)
        self.fc3 = nn.Linear(250, num_classes)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = x.reshape(x.shape[0], -1)
        x = F.relu(self.fc1(x)) # replace this with siren.
        x = F.relu(self.fc2(x))
        x = self.fc2(x)
        return x


In [4]:
device = torch.device('cuda')

In [5]:
# Hyperparameters
in_channels = 1
num_classes = 10
learning_rate = 3e-4 # karpathy's constant
batch_size = 64
num_epochs = 3

In [6]:
# Load Data
train_dataset = datasets.MNIST(
    root="dataset/", train=True, transform=transforms.ToTensor(), download=True
)
test_dataset = datasets.MNIST(
    root="dataset/", train=False, transform=transforms.ToTensor(), download=True
)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to dataset/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 233515000.56it/s]

Extracting dataset/MNIST/raw/train-images-idx3-ubyte.gz to dataset/MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to dataset/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 101453679.92it/s]


Extracting dataset/MNIST/raw/train-labels-idx1-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 75361135.41it/s]

Extracting dataset/MNIST/raw/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz





Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 16858875.02it/s]


Extracting dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw



In [7]:
# Initialize network
model_relu = CNN(in_channels=in_channels, num_classes=num_classes).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model_relu.parameters(), lr=learning_rate)

In [8]:
# Train Network
for epoch in range(num_epochs):
    for batch_idx, (data, targets) in enumerate(tqdm(train_loader)):
        # Get data to cuda if possible
        data = data.to(device=device)
        targets = targets.to(device=device)

        # forward
        scores = model_relu(data)
        loss = criterion(scores, targets)

        # backward
        optimizer.zero_grad()
        loss.backward()

        # gradient descent or adam step
        optimizer.step()

100%|██████████| 938/938 [00:14<00:00, 63.29it/s] 
100%|██████████| 938/938 [00:09<00:00, 98.19it/s] 
100%|██████████| 938/938 [00:10<00:00, 89.19it/s] 


In [9]:
# Check accuracy on training & test to see how good our model
def check_accuracy(loader, model):
    num_correct = 0
    num_samples = 0
    model.eval()

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device=device)
            y = y.to(device=device)

            scores = model(x)
            _, predictions = scores.max(1)
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)

    model.train()
    return num_correct / num_samples


print(f"Accuracy on training set: {check_accuracy(train_loader, model_relu)*100:.2f}")
print(f"Accuracy on test set: {check_accuracy(test_loader, model_relu)*100:.2f}")

Accuracy on training set: 96.97
Accuracy on test set: 96.89


In [10]:
def compute_gradient (func,inp, **kwargs):
    inp.requires_grad = True
    loss = func(inp, **kwargs)
    loss.backward()
    inp.requires_grad = False
    return inp.grad.data

In [11]:
def func(inp, model = None, target = None):
    out = model(inp)
    loss = torch.nn.functional.nll_loss(out,torch.LongTensor([target]).to(device))
    #print(f"Loss:  {loss.item()}")
    return loss

In [12]:
def attack (tensor, model, eps = 1e-3, n_iter = 5000):
    number = 5000
    new_tensor = tensor.unsqueeze(0).detach().clone()
    new_tensor = new_tensor.to(device = device)
    orig_prediction = model(new_tensor).argmax()
    #print(f"Original Prediction: {orig_prediction.item()}")

    for i in tqdm(range(n_iter)):
        model.zero_grad()

        grad = compute_gradient(func, new_tensor, model = model, target = orig_prediction.item())
        #new_tensor = torch.clamp(new_tensor + eps * grad.sign(), -2, 2)
        new_tensor = new_tensor + eps * grad.sign()
        new_prediction = model(new_tensor).argmax()

        if new_prediction != orig_prediction:
            #print(f"New Prediction: {new_prediction.item()}")
            number = i
            break
    return number


In [13]:
num1 = attack(train_dataset[0][0], model_relu)
num1

  1%|          | 48/5000 [00:00<00:27, 183.10it/s]


48

In [14]:
# now its time for siren.
class SineLayer(nn.Module):
    # See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of omega_0.

    # If is_first=True, omega_0 is a frequency factor which simply multiplies the activations before the
    # nonlinearity. Different signals may require different omega_0 in the first layer - this is a
    # hyperparameter.

    # If is_first=False, then the weights will be divided by omega_0 so as to keep the magnitude of
    # activations constant, but boost gradients to the weight matrix (see supplement Sec. 1.5)

    def __init__(self, in_features, out_features, bias=True,
                 is_first=False, omega_0=30,is_linear = False):
        super().__init__()
        self.omega_0 = omega_0
        self.is_first = is_first
        self.is_linear = is_linear
        self.in_features = in_features
        self.linear = nn.Linear(in_features, out_features, bias=bias)

        self.init_weights()

    def init_weights(self):
        with torch.no_grad():
            if self.is_first:
                self.linear.weight.uniform_(-1 / self.in_features,
                                             1 / self.in_features)
            else:
                self.linear.weight.uniform_(-np.sqrt(6 / self.in_features) / self.omega_0,
                                             np.sqrt(6 / self.in_features) / self.omega_0)

    def forward(self, input):
        if self.is_linear:
          return (self.omega_0 * self.linear(input))
        return torch.sin(self.omega_0 * self.linear(input))

    def forward_with_intermediate(self, input):
        # For visualization of activation distributions
        intermediate = self.omega_0 * self.linear(input)
        return torch.sin(intermediate), intermediate


class Siren(nn.Module):
    def __init__(self, in_features, hidden_features, hidden_layers, out_features, outermost_linear=False,
                 first_omega_0=30, hidden_omega_0=30.):
        super().__init__()

        self.net = []
        self.net.append(SineLayer(in_features, hidden_features,
                                  is_first=True, omega_0=first_omega_0))

        for i in range(hidden_layers):
            self.net.append(SineLayer(hidden_features, hidden_features,
                                      is_first=False, omega_0=hidden_omega_0))

        if outermost_linear:
            final_linear = nn.Linear(hidden_features, out_features)

            with torch.no_grad():
                final_linear.weight.uniform_(-np.sqrt(6 / hidden_features) / hidden_omega_0,
                                              np.sqrt(6 / hidden_features) / hidden_omega_0)

            self.net.append(final_linear)
        else:
            self.net.append(SineLayer(hidden_features, out_features,
                                      is_first=False, omega_0=hidden_omega_0))

        self.net = nn.Sequential(*self.net)

    def forward(self, coords):
        coords = coords.clone().detach().requires_grad_(True) # allows to take derivative w.r.t. input
        output = self.net(coords)
        return output, coords

    def forward_with_activations(self, coords, retain_grad=False):
        '''Returns not only model output, but also intermediate activations.
        Only used for visualizing activations later!'''
        activations = OrderedDict()

        activation_count = 0
        x = coords.clone().detach().requires_grad_(True)
        activations['input'] = x
        for i, layer in enumerate(self.net):
            if isinstance(layer, SineLayer):
                x, intermed = layer.forward_with_intermediate(x)

                if retain_grad:
                    x.retain_grad()
                    intermed.retain_grad()

                activations['_'.join((str(layer.__class__), "%d" % activation_count))] = intermed
                activation_count += 1
            else:
                x = layer(x)

                if retain_grad:
                    x.retain_grad()

            activations['_'.join((str(layer.__class__), "%d" % activation_count))] = x
            activation_count += 1

        return activations

In [15]:
# Simple CNN
class siren_CNN(nn.Module):
    def __init__(self, in_channels=1, num_classes=10):
        super(siren_CNN, self).__init__()
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=8,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(
            in_channels=8,
            out_channels=16,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        #self.fc1 = nn.Linear(16 * 7 * 7, 250) # replace this with siren.
        self.fc1 = SineLayer(16 * 7 * 7, 250, is_first=True, omega_0=30)
        self.fc2 = SineLayer(250, 250, is_first=False, omega_0=30)
        self.fc3 = nn.Linear(250,num_classes)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = x.reshape(x.shape[0], -1)
        #x = F.relu(self.fc1(x)) # replace this with siren.
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x


In [16]:
# Initialize network
model_siren = siren_CNN(in_channels=in_channels, num_classes=num_classes).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model_siren.parameters(), lr=learning_rate)

In [17]:
# Train Network
for epoch in range(num_epochs):
    for batch_idx, (data, targets) in enumerate(tqdm(train_loader)):
        # Get data to cuda if possible
        data = data.to(device=device)
        targets = targets.to(device=device)

        # forward
        scores = model_siren(data)
        loss = criterion(scores, targets)

        # backward
        optimizer.zero_grad()
        loss.backward()

        # gradient descent or adam step
        optimizer.step()

100%|██████████| 938/938 [00:09<00:00, 95.43it/s] 
100%|██████████| 938/938 [00:09<00:00, 94.88it/s] 
100%|██████████| 938/938 [00:09<00:00, 96.21it/s] 


In [18]:
print(f"Accuracy on training set: {check_accuracy(train_loader, model_siren)*100:.2f}")
print(f"Accuracy on test set: {check_accuracy(test_loader, model_siren)*100:.2f}")

Accuracy on training set: 98.55
Accuracy on test set: 97.85


In [19]:
num1 = attack(train_dataset[0][0], model_siren)
num1

  1%|          | 45/5000 [00:00<00:11, 446.84it/s]


45

In [22]:
# Simple CNN
class siren_updated_CNN(nn.Module):
    def __init__(self, in_channels=1, num_classes=10):
        super(siren_updated_CNN, self).__init__()
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=8,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(
            in_channels=8,
            out_channels=16,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        #self.fc1 = nn.Linear(16 * 7 * 7, 250) # replace this with siren.
        self.fc1 = SineLayer(16 * 7 * 7, 250, is_first=True, omega_0=30)
        self.fc2 = SineLayer(250, 250, is_first=False, omega_0=30)
        self.fc3 = SineLayer(250,num_classes, is_first = False)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = x.reshape(x.shape[0], -1)
        #x = F.relu(self.fc1(x)) # replace this with siren.
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x

In [23]:
# Initialize network
model_siren_updated = siren_updated_CNN(in_channels=in_channels, num_classes=num_classes).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model_siren_updated.parameters(), lr=learning_rate)

In [24]:
# Train Network
for epoch in range(num_epochs):
    for batch_idx, (data, targets) in enumerate(tqdm(train_loader)):
        # Get data to cuda if possible
        data = data.to(device=device)
        targets = targets.to(device=device)

        # forward
        scores = model_siren_updated(data)
        loss = criterion(scores, targets)

        # backward
        optimizer.zero_grad()
        loss.backward()

        # gradient descent or adam step
        optimizer.step()

100%|██████████| 938/938 [00:08<00:00, 108.39it/s]
100%|██████████| 938/938 [00:09<00:00, 101.65it/s]
100%|██████████| 938/938 [00:09<00:00, 100.65it/s]


In [25]:
print(f"Accuracy on training set: {check_accuracy(train_loader, model_siren_updated)*100:.2f}")
print(f"Accuracy on test set: {check_accuracy(test_loader, model_siren_updated)*100:.2f}")

Accuracy on training set: 98.61
Accuracy on test set: 98.09


In [26]:
num1 = attack(train_dataset[0][0], model_siren_updated)
num1

  0%|          | 14/5000 [00:00<00:14, 335.78it/s]


14

In [27]:
# Simple CNN
class relu_updated_CNN(nn.Module):
    def __init__(self, in_channels=1, num_classes=10):
        super(relu_updated_CNN, self).__init__()
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=8,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(
            in_channels=8,
            out_channels=16,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.fc1 = nn.Linear(16 * 7 * 7, 250) # replace this with siren.
        self.fc2 = nn.Linear(250,250)
        self.fc3 = nn.Linear(250, num_classes)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = x.reshape(x.shape[0], -1)
        x = F.relu(self.fc1(x)) # replace this with siren.
        x = F.relu(self.fc2(x))
        x = self.fc2(x)
        x = torch.sin(x)
        return x


In [28]:
# Initialize network
model_relu_updated = relu_updated_CNN(in_channels=in_channels, num_classes=num_classes).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model_relu_updated.parameters(), lr=learning_rate)

In [29]:
# Train Network
for epoch in range(num_epochs):
    for batch_idx, (data, targets) in enumerate(tqdm(train_loader)):
        # Get data to cuda if possible
        data = data.to(device=device)
        targets = targets.to(device=device)

        # forward
        scores = model_relu_updated(data)
        loss = criterion(scores, targets)

        # backward
        optimizer.zero_grad()
        loss.backward()

        # gradient descent or adam step
        optimizer.step()

100%|██████████| 938/938 [00:10<00:00, 93.76it/s]
100%|██████████| 938/938 [00:09<00:00, 94.32it/s]
100%|██████████| 938/938 [00:09<00:00, 101.46it/s]


In [30]:
print(f"Accuracy on training set: {check_accuracy(train_loader, model_relu_updated)*100:.2f}")
print(f"Accuracy on test set: {check_accuracy(test_loader, model_relu_updated)*100:.2f}")

Accuracy on training set: 91.81
Accuracy on test set: 91.75


In [31]:
num1 = attack(train_dataset[0][0], model_relu_updated)
num1

  0%|          | 25/5000 [00:00<00:12, 399.38it/s]


25

In [32]:
def compute_avg_n_iter (model):
    sum = 0
    for i in tqdm(range(1000)):
        sum += attack(train_dataset[i][0], model)
    return sum/1000

In [33]:
def new_compute_avg_n_iter (model):
    sum = 0
    counter = 0
    i = 0
    while(counter<1000):
        siren_prediction = model_siren(train_dataset[i][0].unsqueeze(0).to(device)).argmax()
        relu_prediction = model_relu(train_dataset[i][0].unsqueeze(0).to(device)).argmax()
        actual = train_dataset[i][1]
        if (siren_prediction == actual and relu_prediction == actual):
            sum += attack(train_dataset[i][0], model)
            counter += 1
        i += 1
    return sum/1000

In [34]:
def new_compute_avg_n_iter_test (model_1, model_2):
    sum_1 = 0
    sum_2 = 0
    counter = 0
    i = 0
    while(counter<100):
        model_1_prediction = model_1(test_dataset[i][0].unsqueeze(0).to(device)).argmax()
        model_2_prediction = model_2(test_dataset[i][0].unsqueeze(0).to(device)).argmax()
        actual = test_dataset[i][1]
        print(counter)
        if (model_1_prediction == actual and model_2_prediction == actual):
            sum_1 += attack(test_dataset[i][0], model_1)
            sum_2 += attack(test_dataset[i][0], model_2)
            counter += 1
        i += 1
    return sum_1/100, sum_2/100

In [35]:
print(f"avg n_iter for model_relu and model_siren = {new_compute_avg_n_iter_test(model_relu, model_siren)}")

0


  2%|▏         | 123/5000 [00:00<00:10, 454.53it/s]
  2%|▏         | 85/5000 [00:00<00:11, 445.56it/s]


1


  3%|▎         | 162/5000 [00:00<00:10, 444.80it/s]
  1%|▏         | 71/5000 [00:00<00:11, 429.48it/s]


2


  1%|▏         | 72/5000 [00:00<00:10, 478.34it/s]
  1%|▏         | 64/5000 [00:00<00:12, 384.21it/s]


3


  3%|▎         | 160/5000 [00:00<00:10, 454.89it/s]
  2%|▏         | 106/5000 [00:00<00:11, 417.70it/s]


4


  2%|▏         | 87/5000 [00:00<00:11, 442.17it/s]
  2%|▏         | 97/5000 [00:00<00:12, 403.94it/s]


5


  2%|▏         | 108/5000 [00:00<00:11, 442.16it/s]
  2%|▏         | 99/5000 [00:00<00:11, 428.40it/s]


6


  1%|          | 56/5000 [00:00<00:11, 421.77it/s]
  1%|          | 53/5000 [00:00<00:11, 438.32it/s]


7


  1%|          | 33/5000 [00:00<00:10, 469.61it/s]
  1%|          | 28/5000 [00:00<00:11, 421.23it/s]


8


  1%|          | 54/5000 [00:00<00:10, 491.25it/s]
  0%|          | 3/5000 [00:00<00:16, 300.68it/s]


9


  1%|          | 51/5000 [00:00<00:10, 458.08it/s]
  0%|          | 22/5000 [00:00<00:17, 284.07it/s]


10


  4%|▎         | 184/5000 [00:00<00:12, 378.82it/s]
  2%|▏         | 124/5000 [00:00<00:13, 359.05it/s]


11


  2%|▏         | 102/5000 [00:00<00:13, 362.96it/s]
  2%|▏         | 80/5000 [00:00<00:14, 347.52it/s]


12


  2%|▏         | 76/5000 [00:00<00:13, 371.55it/s]
  1%|          | 32/5000 [00:00<00:14, 348.99it/s]


13


  3%|▎         | 165/5000 [00:00<00:14, 341.52it/s]
  2%|▏         | 86/5000 [00:00<00:12, 391.05it/s]


14


  2%|▏         | 107/5000 [00:00<00:15, 306.71it/s]
  2%|▏         | 111/5000 [00:00<00:13, 364.09it/s]


15


  2%|▏         | 80/5000 [00:00<00:10, 470.68it/s]
  1%|▏         | 71/5000 [00:00<00:11, 435.38it/s]


16


  2%|▏         | 77/5000 [00:00<00:17, 283.70it/s]
  1%|          | 59/5000 [00:00<00:10, 454.41it/s]


17


  2%|▏         | 110/5000 [00:00<00:10, 450.12it/s]
  2%|▏         | 95/5000 [00:00<00:17, 272.82it/s]


18


  0%|          | 5/5000 [00:00<00:13, 370.70it/s]
  0%|          | 2/5000 [00:00<00:18, 263.31it/s]


19


  3%|▎         | 161/5000 [00:00<00:09, 496.15it/s]
  2%|▏         | 91/5000 [00:00<00:11, 423.59it/s]


20
20


  1%|          | 58/5000 [00:00<00:10, 458.68it/s]
  2%|▏         | 96/5000 [00:00<00:13, 353.30it/s]


21


  2%|▏         | 87/5000 [00:00<00:13, 360.56it/s]
  1%|          | 40/5000 [00:00<00:15, 328.61it/s]


22


  3%|▎         | 172/5000 [00:00<00:12, 372.48it/s]
  1%|          | 54/5000 [00:00<00:14, 351.33it/s]


23


  1%|▏         | 73/5000 [00:00<00:12, 398.61it/s]
  3%|▎         | 126/5000 [00:00<00:14, 346.51it/s]


24


  4%|▍         | 212/5000 [00:00<00:12, 391.52it/s]
  2%|▏         | 81/5000 [00:00<00:11, 439.49it/s]


25


  2%|▏         | 120/5000 [00:00<00:10, 467.54it/s]
  3%|▎         | 144/5000 [00:00<00:11, 435.46it/s]


26


  2%|▏         | 119/5000 [00:00<00:10, 480.96it/s]
  3%|▎         | 128/5000 [00:00<00:10, 452.43it/s]


27


  4%|▍         | 204/5000 [00:00<00:10, 468.81it/s]
  2%|▏         | 117/5000 [00:00<00:10, 453.69it/s]


28


  1%|▏         | 71/5000 [00:00<00:10, 476.77it/s]
  2%|▏         | 95/5000 [00:00<00:10, 456.82it/s]


29


  3%|▎         | 151/5000 [00:00<00:10, 477.95it/s]
  1%|          | 57/5000 [00:00<00:11, 415.08it/s]


30


  2%|▏         | 90/5000 [00:00<00:10, 462.66it/s]
  2%|▏         | 106/5000 [00:00<00:11, 417.73it/s]


31


  2%|▏         | 123/5000 [00:00<00:10, 468.12it/s]
  2%|▏         | 99/5000 [00:00<00:11, 442.48it/s]


32
32


  3%|▎         | 129/5000 [00:00<00:09, 491.45it/s]
  2%|▏         | 92/5000 [00:00<00:10, 456.30it/s]


33


  3%|▎         | 142/5000 [00:00<00:12, 388.69it/s]
  1%|          | 57/5000 [00:00<00:15, 323.49it/s]


34


  3%|▎         | 128/5000 [00:00<00:12, 387.47it/s]
  1%|          | 48/5000 [00:00<00:13, 355.85it/s]


35


  2%|▏         | 116/5000 [00:00<00:13, 355.10it/s]
  2%|▏         | 122/5000 [00:00<00:13, 362.74it/s]


36


  2%|▏         | 103/5000 [00:00<00:12, 379.62it/s]
  1%|▏         | 68/5000 [00:00<00:14, 345.79it/s]


37


  2%|▏         | 80/5000 [00:00<00:14, 351.03it/s]
  3%|▎         | 133/5000 [00:00<00:10, 451.78it/s]


38


  1%|▏         | 71/5000 [00:00<00:10, 486.21it/s]
  2%|▏         | 113/5000 [00:00<00:11, 439.68it/s]


39


  2%|▏         | 98/5000 [00:00<00:10, 489.91it/s]
  2%|▏         | 79/5000 [00:00<00:11, 429.37it/s]


40


  2%|▏         | 114/5000 [00:00<00:09, 489.27it/s]
  2%|▏         | 120/5000 [00:00<00:10, 455.38it/s]


41


  2%|▏         | 111/5000 [00:00<00:10, 472.38it/s]
  0%|          | 1/5000 [00:00<00:29, 166.85it/s]


42


  2%|▏         | 97/5000 [00:00<00:11, 443.25it/s]
  0%|          | 25/5000 [00:00<00:11, 437.62it/s]


43


  2%|▏         | 113/5000 [00:00<00:10, 462.15it/s]
  1%|          | 34/5000 [00:00<00:11, 443.06it/s]


44


  2%|▏         | 85/5000 [00:00<00:11, 443.11it/s]
  2%|▏         | 100/5000 [00:00<00:11, 413.68it/s]


45


  3%|▎         | 139/5000 [00:00<00:11, 415.20it/s]
  1%|          | 52/5000 [00:00<00:12, 411.85it/s]


46


  2%|▏         | 75/5000 [00:00<00:11, 410.83it/s]
  2%|▏         | 118/5000 [00:00<00:11, 415.05it/s]


47


  3%|▎         | 153/5000 [00:00<00:10, 464.61it/s]
  3%|▎         | 131/5000 [00:00<00:11, 434.46it/s]


48


  2%|▏         | 95/5000 [00:00<00:10, 475.24it/s]
  2%|▏         | 76/5000 [00:00<00:11, 427.40it/s]


49


  3%|▎         | 126/5000 [00:00<00:10, 484.25it/s]
  2%|▏         | 119/5000 [00:00<00:10, 465.03it/s]


50


  3%|▎         | 134/5000 [00:00<00:09, 497.68it/s]
  2%|▏         | 76/5000 [00:00<00:11, 438.29it/s]


51


  2%|▏         | 110/5000 [00:00<00:09, 491.10it/s]
  2%|▏         | 78/5000 [00:00<00:10, 465.32it/s]


52


  3%|▎         | 146/5000 [00:00<00:10, 475.40it/s]
  1%|          | 56/5000 [00:00<00:12, 395.59it/s]


53


  3%|▎         | 145/5000 [00:00<00:11, 437.38it/s]
  3%|▎         | 126/5000 [00:00<00:11, 426.22it/s]


54


  3%|▎         | 155/5000 [00:00<00:10, 465.83it/s]
  3%|▎         | 143/5000 [00:00<00:11, 440.87it/s]


55


  2%|▏         | 102/5000 [00:00<00:10, 480.17it/s]
  2%|▏         | 99/5000 [00:00<00:11, 430.05it/s]


56


  2%|▏         | 98/5000 [00:00<00:11, 432.14it/s]
  2%|▏         | 81/5000 [00:00<00:11, 415.41it/s]


57


  2%|▏         | 86/5000 [00:00<00:10, 450.97it/s]
  0%|          | 21/5000 [00:00<00:13, 372.79it/s]


58


  2%|▏         | 124/5000 [00:00<00:11, 421.30it/s]
  1%|▏         | 71/5000 [00:00<00:11, 427.01it/s]


59


  2%|▏         | 98/5000 [00:00<00:15, 309.83it/s]
  0%|          | 3/5000 [00:00<00:24, 203.06it/s]


60


  0%|          | 20/5000 [00:00<00:17, 280.16it/s]
  1%|          | 41/5000 [00:00<00:16, 305.37it/s]


61


  1%|          | 61/5000 [00:00<00:13, 363.37it/s]
  1%|          | 42/5000 [00:00<00:14, 330.74it/s]


62


  3%|▎         | 129/5000 [00:00<00:13, 370.03it/s]
  2%|▏         | 92/5000 [00:00<00:13, 353.12it/s]


63


  1%|          | 43/5000 [00:00<00:13, 378.67it/s]
  1%|          | 39/5000 [00:00<00:14, 337.01it/s]


64


  1%|          | 36/5000 [00:00<00:13, 379.64it/s]
  1%|          | 26/5000 [00:00<00:14, 348.50it/s]


65


  2%|▏         | 121/5000 [00:00<00:13, 349.96it/s]
  2%|▏         | 112/5000 [00:00<00:14, 337.67it/s]


66


  3%|▎         | 130/5000 [00:00<00:10, 478.92it/s]
  1%|          | 50/5000 [00:00<00:11, 432.06it/s]


67


  3%|▎         | 169/5000 [00:00<00:09, 485.52it/s]
  2%|▏         | 88/5000 [00:00<00:10, 452.39it/s]


68


  2%|▏         | 114/5000 [00:00<00:10, 475.83it/s]
  2%|▏         | 119/5000 [00:00<00:11, 435.77it/s]


69


  5%|▌         | 265/5000 [00:00<00:09, 478.03it/s]
  2%|▏         | 94/5000 [00:00<00:11, 424.45it/s]


70


  3%|▎         | 145/5000 [00:00<00:10, 443.98it/s]
  1%|          | 56/5000 [00:00<00:11, 446.60it/s]


71
71


  2%|▏         | 111/5000 [00:00<00:10, 470.99it/s]
  2%|▏         | 86/5000 [00:00<00:10, 473.76it/s]


72


  2%|▏         | 118/5000 [00:00<00:10, 488.16it/s]
  1%|▏         | 64/5000 [00:00<00:11, 429.26it/s]


73


  2%|▏         | 117/5000 [00:00<00:09, 499.32it/s]
  2%|▏         | 84/5000 [00:00<00:11, 445.12it/s]


74


  2%|▏         | 120/5000 [00:00<00:09, 489.27it/s]
  2%|▏         | 83/5000 [00:00<00:10, 448.17it/s]


75


  1%|          | 30/5000 [00:00<00:11, 424.41it/s]
  0%|          | 19/5000 [00:00<00:11, 417.03it/s]


76


  3%|▎         | 142/5000 [00:00<00:10, 479.16it/s]
  1%|          | 44/5000 [00:00<00:10, 454.55it/s]


77


  1%|          | 54/5000 [00:00<00:10, 474.66it/s]
  1%|▏         | 63/5000 [00:00<00:12, 402.03it/s]


78


  1%|          | 57/5000 [00:00<00:11, 417.37it/s]
  2%|▏         | 81/5000 [00:00<00:13, 373.35it/s]


79


  5%|▍         | 234/5000 [00:00<00:09, 483.94it/s]
  1%|▏         | 65/5000 [00:00<00:11, 423.98it/s]


80


  2%|▏         | 114/5000 [00:00<00:10, 457.21it/s]
  2%|▏         | 92/5000 [00:00<00:10, 464.97it/s]


81


  3%|▎         | 128/5000 [00:00<00:09, 495.75it/s]
  1%|          | 38/5000 [00:00<00:11, 442.33it/s]


82


  3%|▎         | 136/5000 [00:00<00:10, 474.57it/s]
  2%|▏         | 109/5000 [00:00<00:11, 441.26it/s]


83


  3%|▎         | 138/5000 [00:00<00:09, 491.60it/s]
  2%|▏         | 80/5000 [00:00<00:10, 456.09it/s]


84


  1%|▏         | 73/5000 [00:00<00:10, 468.26it/s]
  0%|          | 10/5000 [00:00<00:13, 381.47it/s]


85


  3%|▎         | 153/5000 [00:00<00:10, 478.16it/s]
  2%|▏         | 84/5000 [00:00<00:11, 429.90it/s]


86


  2%|▏         | 79/5000 [00:00<00:10, 477.27it/s]
  2%|▏         | 107/5000 [00:00<00:10, 453.37it/s]


87


  3%|▎         | 126/5000 [00:00<00:10, 464.58it/s]
  2%|▏         | 75/5000 [00:00<00:11, 425.55it/s]


88


  2%|▏         | 115/5000 [00:00<00:13, 371.52it/s]
  2%|▏         | 91/5000 [00:00<00:13, 361.72it/s]


89


  0%|          | 21/5000 [00:00<00:13, 369.76it/s]
  0%|          | 2/5000 [00:00<00:22, 218.98it/s]


90


  2%|▎         | 125/5000 [00:00<00:13, 373.70it/s]
  1%|▏         | 66/5000 [00:00<00:14, 345.64it/s]


91


  1%|          | 49/5000 [00:00<00:12, 392.99it/s]
  1%|▏         | 68/5000 [00:00<00:13, 366.45it/s]


92


  1%|          | 52/5000 [00:00<00:12, 389.99it/s]
  1%|          | 59/5000 [00:00<00:14, 335.57it/s]


93


  1%|          | 26/5000 [00:00<00:15, 328.43it/s]
  1%|          | 32/5000 [00:00<00:17, 279.43it/s]


94


  2%|▏         | 97/5000 [00:00<00:13, 351.49it/s]
  2%|▏         | 80/5000 [00:00<00:10, 450.84it/s]


95


  2%|▏         | 89/5000 [00:00<00:10, 479.82it/s]
  2%|▏         | 84/5000 [00:00<00:11, 439.44it/s]


96


  2%|▏         | 105/5000 [00:00<00:10, 475.61it/s]
  2%|▏         | 76/5000 [00:00<00:10, 459.12it/s]


97


  2%|▏         | 96/5000 [00:00<00:10, 490.00it/s]
  1%|          | 53/5000 [00:00<00:11, 447.61it/s]


98


  3%|▎         | 172/5000 [00:00<00:09, 486.35it/s]
  2%|▏         | 79/5000 [00:00<00:11, 443.63it/s]


99


  3%|▎         | 144/5000 [00:00<00:09, 502.12it/s]
  1%|          | 39/5000 [00:00<00:11, 449.42it/s]

avg n_iter for model_relu and model_siren = (108.8, 74.36)





In [36]:
print(f"avg n_iter for model_relu and model_siren_updated = {new_compute_avg_n_iter_test(model_relu, model_siren_updated)}")

0


  2%|▏         | 123/5000 [00:00<00:10, 469.53it/s]
  5%|▌         | 266/5000 [00:00<00:11, 426.73it/s]


1


  3%|▎         | 162/5000 [00:00<00:10, 476.05it/s]
  3%|▎         | 131/5000 [00:00<00:11, 430.89it/s]


2


  1%|▏         | 72/5000 [00:00<00:10, 485.76it/s]
  3%|▎         | 152/5000 [00:00<00:11, 433.14it/s]


3


  3%|▎         | 160/5000 [00:00<00:10, 466.94it/s]
  3%|▎         | 137/5000 [00:00<00:11, 434.76it/s]


4


  2%|▏         | 87/5000 [00:00<00:10, 470.42it/s]
100%|██████████| 5000/5000 [00:12<00:00, 404.48it/s]


5


  2%|▏         | 108/5000 [00:00<00:10, 476.83it/s]
 13%|█▎        | 673/5000 [00:01<00:11, 389.70it/s]


6


  1%|          | 56/5000 [00:00<00:11, 447.06it/s]
  0%|          | 20/5000 [00:00<00:12, 396.54it/s]


7


  1%|          | 33/5000 [00:00<00:10, 467.25it/s]
  0%|          | 22/5000 [00:00<00:12, 395.62it/s]


8


  1%|          | 54/5000 [00:00<00:10, 452.47it/s]
  1%|          | 31/5000 [00:00<00:12, 411.93it/s]


9


  1%|          | 51/5000 [00:00<00:10, 474.34it/s]
  1%|          | 36/5000 [00:00<00:12, 411.20it/s]


10


  4%|▎         | 184/5000 [00:00<00:10, 474.61it/s]
100%|██████████| 5000/5000 [00:12<00:00, 410.30it/s]


11


  2%|▏         | 102/5000 [00:00<00:10, 460.02it/s]
 74%|███████▎  | 3681/5000 [00:09<00:03, 406.60it/s]


12


  2%|▏         | 76/5000 [00:00<00:10, 450.72it/s]
  1%|          | 43/5000 [00:00<00:11, 440.64it/s]


13


  3%|▎         | 165/5000 [00:00<00:09, 492.90it/s]
100%|██████████| 5000/5000 [00:12<00:00, 411.15it/s]


14


  2%|▏         | 107/5000 [00:00<00:10, 476.06it/s]
  1%|▏         | 66/5000 [00:00<00:11, 436.72it/s]


15


  2%|▏         | 80/5000 [00:00<00:10, 483.01it/s]
  5%|▍         | 230/5000 [00:00<00:10, 436.33it/s]


16


  2%|▏         | 77/5000 [00:00<00:09, 494.90it/s]
  1%|          | 37/5000 [00:00<00:11, 443.33it/s]


17


  2%|▏         | 110/5000 [00:00<00:09, 489.69it/s]
  4%|▍         | 218/5000 [00:00<00:10, 436.83it/s]


18
18


  3%|▎         | 161/5000 [00:00<00:10, 479.99it/s]
100%|██████████| 5000/5000 [00:12<00:00, 415.28it/s]


19


  1%|          | 33/5000 [00:00<00:10, 479.76it/s]
  1%|          | 27/5000 [00:00<00:11, 449.03it/s]


20


  1%|          | 58/5000 [00:00<00:09, 496.60it/s]
  1%|▏         | 66/5000 [00:00<00:11, 423.90it/s]


21


  2%|▏         | 87/5000 [00:00<00:11, 432.60it/s]
100%|██████████| 5000/5000 [00:12<00:00, 409.47it/s]


22


  3%|▎         | 172/5000 [00:00<00:09, 503.43it/s]
  7%|▋         | 361/5000 [00:00<00:12, 366.79it/s]


23


  1%|▏         | 73/5000 [00:00<00:12, 383.29it/s]
  2%|▏         | 84/5000 [00:00<00:16, 295.35it/s]


24


  4%|▍         | 212/5000 [00:00<00:12, 368.60it/s]
  5%|▍         | 240/5000 [00:00<00:13, 343.53it/s]


25


  2%|▏         | 120/5000 [00:00<00:09, 502.87it/s]
 94%|█████████▍| 4701/5000 [00:11<00:00, 426.25it/s]


26


  2%|▏         | 119/5000 [00:00<00:12, 382.23it/s]
100%|██████████| 5000/5000 [00:11<00:00, 418.51it/s]


27


  4%|▍         | 204/5000 [00:00<00:12, 369.12it/s]
100%|██████████| 5000/5000 [00:11<00:00, 420.99it/s]


28


  1%|▏         | 71/5000 [00:00<00:13, 365.08it/s]
  2%|▏         | 94/5000 [00:00<00:14, 334.25it/s]


29


  3%|▎         | 151/5000 [00:00<00:10, 444.89it/s]
  3%|▎         | 161/5000 [00:00<00:11, 428.33it/s]


30


  2%|▏         | 90/5000 [00:00<00:09, 513.32it/s]
  1%|          | 54/5000 [00:00<00:11, 444.89it/s]


31


  2%|▏         | 123/5000 [00:00<00:10, 481.64it/s]
100%|██████████| 5000/5000 [00:12<00:00, 415.27it/s]


32
32


  3%|▎         | 129/5000 [00:00<00:10, 480.89it/s]
  5%|▍         | 244/5000 [00:00<00:10, 435.80it/s]


33


  3%|▎         | 142/5000 [00:00<00:09, 502.46it/s]
  2%|▏         | 107/5000 [00:00<00:11, 429.11it/s]


34


  3%|▎         | 128/5000 [00:00<00:10, 468.04it/s]
 14%|█▍        | 705/5000 [00:01<00:10, 410.02it/s]


35


  2%|▏         | 116/5000 [00:00<00:09, 488.64it/s]
  1%|          | 55/5000 [00:00<00:11, 422.19it/s]


36


  2%|▏         | 103/5000 [00:00<00:10, 480.16it/s]
  3%|▎         | 136/5000 [00:00<00:11, 427.08it/s]


37


  2%|▏         | 80/5000 [00:00<00:10, 486.12it/s]
 15%|█▍        | 737/5000 [00:01<00:09, 427.42it/s]


38


  1%|▏         | 71/5000 [00:00<00:10, 471.23it/s]
  1%|          | 46/5000 [00:00<00:12, 395.51it/s]


39


  2%|▏         | 98/5000 [00:00<00:10, 486.47it/s]
  3%|▎         | 142/5000 [00:00<00:11, 429.54it/s]


40


  2%|▏         | 114/5000 [00:00<00:09, 491.22it/s]
100%|██████████| 5000/5000 [00:11<00:00, 417.89it/s]


41


  2%|▏         | 111/5000 [00:00<00:09, 508.86it/s]
100%|██████████| 5000/5000 [00:11<00:00, 417.44it/s]


42


  2%|▏         | 97/5000 [00:00<00:10, 467.78it/s]
100%|██████████| 5000/5000 [00:12<00:00, 392.76it/s]


43


  2%|▏         | 113/5000 [00:00<00:09, 493.57it/s]
100%|██████████| 5000/5000 [00:12<00:00, 416.49it/s]


44


  2%|▏         | 85/5000 [00:00<00:10, 489.91it/s]
  1%|          | 51/5000 [00:00<00:12, 410.02it/s]


45


  3%|▎         | 139/5000 [00:00<00:09, 496.41it/s]
100%|██████████| 5000/5000 [00:11<00:00, 418.29it/s]


46


  2%|▏         | 75/5000 [00:00<00:09, 501.10it/s]
100%|██████████| 5000/5000 [00:11<00:00, 420.31it/s]


47


  3%|▎         | 153/5000 [00:00<00:09, 517.50it/s]
100%|██████████| 5000/5000 [00:11<00:00, 421.87it/s]


48


  2%|▏         | 95/5000 [00:00<00:10, 465.63it/s]
  2%|▏         | 109/5000 [00:00<00:10, 448.46it/s]


49


  3%|▎         | 126/5000 [00:00<00:10, 484.06it/s]
100%|██████████| 5000/5000 [00:11<00:00, 422.22it/s]


50


  3%|▎         | 134/5000 [00:00<00:09, 505.85it/s]
  2%|▏         | 100/5000 [00:00<00:11, 433.93it/s]


51


  2%|▏         | 110/5000 [00:00<00:10, 486.34it/s]
100%|██████████| 5000/5000 [00:11<00:00, 421.69it/s]


52


  3%|▎         | 146/5000 [00:00<00:09, 487.18it/s]
  2%|▏         | 78/5000 [00:00<00:11, 432.73it/s]


53


  3%|▎         | 145/5000 [00:00<00:09, 497.46it/s]
  1%|▏         | 66/5000 [00:00<00:11, 436.36it/s]


54


  3%|▎         | 155/5000 [00:00<00:10, 484.22it/s]
100%|██████████| 5000/5000 [00:11<00:00, 421.18it/s]


55


  2%|▏         | 102/5000 [00:00<00:09, 511.88it/s]
 18%|█▊        | 924/5000 [00:02<00:11, 354.83it/s]


56


  2%|▏         | 98/5000 [00:00<00:10, 486.85it/s]
  1%|          | 61/5000 [00:00<00:11, 437.65it/s]


57


  2%|▏         | 86/5000 [00:00<00:10, 490.64it/s]
  1%|          | 42/5000 [00:00<00:12, 412.89it/s]


58


  2%|▏         | 124/5000 [00:00<00:10, 467.67it/s]
  2%|▏         | 95/5000 [00:00<00:11, 435.29it/s]


59


  2%|▏         | 98/5000 [00:00<00:10, 489.22it/s]
  5%|▍         | 242/5000 [00:00<00:10, 437.15it/s]


60
60


  1%|          | 61/5000 [00:00<00:09, 494.34it/s]
100%|██████████| 5000/5000 [00:12<00:00, 413.63it/s]


61


  3%|▎         | 129/5000 [00:00<00:10, 449.74it/s]
  4%|▍         | 213/5000 [00:00<00:11, 432.20it/s]


62


  1%|          | 43/5000 [00:00<00:10, 491.78it/s]
  1%|          | 35/5000 [00:00<00:12, 400.95it/s]


63


  1%|          | 36/5000 [00:00<00:10, 456.87it/s]
  2%|▏         | 115/5000 [00:00<00:11, 424.10it/s]


64


  2%|▏         | 121/5000 [00:00<00:09, 498.03it/s]
100%|██████████| 5000/5000 [00:13<00:00, 375.22it/s]


65


  3%|▎         | 130/5000 [00:00<00:16, 303.06it/s]
100%|██████████| 5000/5000 [00:12<00:00, 396.49it/s]


66


  3%|▎         | 169/5000 [00:00<00:09, 489.42it/s]
100%|██████████| 5000/5000 [00:12<00:00, 413.92it/s]


67


  2%|▏         | 114/5000 [00:00<00:10, 475.36it/s]
  4%|▍         | 221/5000 [00:00<00:11, 432.55it/s]


68


  5%|▌         | 265/5000 [00:00<00:09, 492.41it/s]
100%|██████████| 5000/5000 [00:12<00:00, 405.65it/s]


69


  3%|▎         | 145/5000 [00:00<00:09, 507.37it/s]
100%|██████████| 5000/5000 [00:13<00:00, 367.70it/s]


70
70


  2%|▏         | 111/5000 [00:00<00:10, 488.02it/s]
  1%|          | 41/5000 [00:00<00:11, 437.11it/s]


71


  2%|▏         | 118/5000 [00:00<00:10, 476.28it/s]
 49%|████▊     | 2435/5000 [00:06<00:06, 400.15it/s]


72


  2%|▏         | 117/5000 [00:00<00:10, 462.07it/s]
100%|██████████| 5000/5000 [00:12<00:00, 399.72it/s]


73


  2%|▏         | 120/5000 [00:00<00:09, 502.68it/s]
100%|██████████| 5000/5000 [00:12<00:00, 386.28it/s]


74


  1%|          | 30/5000 [00:00<00:10, 460.66it/s]
  1%|          | 29/5000 [00:00<00:11, 416.23it/s]


75


  3%|▎         | 142/5000 [00:00<00:09, 488.08it/s]
100%|██████████| 5000/5000 [00:12<00:00, 412.40it/s]


76


  1%|          | 54/5000 [00:00<00:10, 473.93it/s]
 59%|█████▊    | 2935/5000 [00:07<00:05, 401.63it/s]


77


  1%|          | 57/5000 [00:00<00:10, 469.92it/s]
100%|██████████| 5000/5000 [00:12<00:00, 412.01it/s]


78


  5%|▍         | 234/5000 [00:00<00:10, 469.36it/s]
100%|██████████| 5000/5000 [00:11<00:00, 420.51it/s]


79


  2%|▏         | 114/5000 [00:00<00:10, 482.30it/s]
  5%|▌         | 268/5000 [00:00<00:13, 357.94it/s]


80


  3%|▎         | 128/5000 [00:00<00:23, 208.11it/s]
  2%|▎         | 125/5000 [00:00<00:14, 332.90it/s]


81


  3%|▎         | 136/5000 [00:00<00:13, 352.85it/s]
  2%|▏         | 80/5000 [00:00<00:14, 339.13it/s]


82


  3%|▎         | 138/5000 [00:00<00:09, 504.44it/s]
 39%|███▉      | 1970/5000 [00:05<00:09, 332.43it/s]


83


  1%|▏         | 73/5000 [00:00<00:10, 465.64it/s]
100%|██████████| 5000/5000 [00:13<00:00, 362.04it/s]


84


  3%|▎         | 153/5000 [00:00<00:13, 350.33it/s]
  3%|▎         | 158/5000 [00:00<00:21, 222.25it/s]


85


  2%|▏         | 79/5000 [00:00<00:18, 262.54it/s]
 19%|█▊        | 932/5000 [00:02<00:10, 389.17it/s]


86


  3%|▎         | 126/5000 [00:00<00:12, 383.21it/s]
100%|██████████| 5000/5000 [00:12<00:00, 396.80it/s]


87


  2%|▏         | 115/5000 [00:00<00:10, 476.46it/s]
 52%|█████▏    | 2607/5000 [00:06<00:05, 430.25it/s]


88


  0%|          | 21/5000 [00:00<00:12, 395.63it/s]
  0%|          | 3/5000 [00:00<00:18, 265.40it/s]


89


  2%|▎         | 125/5000 [00:00<00:11, 430.93it/s]
100%|██████████| 5000/5000 [00:12<00:00, 395.66it/s]


90


  1%|          | 49/5000 [00:00<00:13, 367.95it/s]
  1%|          | 29/5000 [00:00<00:14, 334.09it/s]


91


  1%|          | 52/5000 [00:00<00:12, 382.85it/s]
100%|██████████| 5000/5000 [00:17<00:00, 287.44it/s]


92


  1%|          | 26/5000 [00:00<00:11, 451.49it/s]
  1%|          | 40/5000 [00:00<00:11, 415.09it/s]


93


  2%|▏         | 97/5000 [00:00<00:13, 353.44it/s]
  1%|          | 47/5000 [00:00<00:11, 431.51it/s]


94


  2%|▏         | 89/5000 [00:00<00:12, 379.30it/s]
100%|██████████| 5000/5000 [00:14<00:00, 344.19it/s]


95


  2%|▏         | 105/5000 [00:00<00:16, 294.09it/s]
  1%|          | 56/5000 [00:00<00:12, 399.38it/s]


96


  2%|▏         | 96/5000 [00:00<00:12, 389.59it/s]
  3%|▎         | 157/5000 [00:00<00:13, 365.66it/s]


97


  3%|▎         | 172/5000 [00:00<00:10, 468.75it/s]
100%|██████████| 5000/5000 [00:14<00:00, 351.46it/s]


98


  3%|▎         | 144/5000 [00:00<00:15, 303.92it/s]
  5%|▌         | 273/5000 [00:01<00:19, 244.74it/s]


99


  4%|▍         | 225/5000 [00:00<00:15, 299.93it/s]
100%|██████████| 5000/5000 [00:12<00:00, 385.38it/s]

avg n_iter for model_relu and model_siren_updated = (111.13, 2082.4)





In [37]:
print(f"avg n_iter for model_relu and model_relu_updated = {new_compute_avg_n_iter_test(model_relu, model_relu_updated)}")

0


  2%|▏         | 123/5000 [00:00<00:10, 478.19it/s]
  3%|▎         | 155/5000 [00:00<00:09, 493.62it/s]


1


  3%|▎         | 162/5000 [00:00<00:09, 489.59it/s]
  3%|▎         | 159/5000 [00:00<00:10, 477.75it/s]


2


  1%|▏         | 72/5000 [00:00<00:10, 481.20it/s]
  3%|▎         | 147/5000 [00:00<00:10, 484.74it/s]


3


  3%|▎         | 160/5000 [00:00<00:10, 476.99it/s]
  5%|▍         | 244/5000 [00:00<00:10, 452.16it/s]


4


  2%|▏         | 87/5000 [00:00<00:10, 451.13it/s]
  2%|▏         | 78/5000 [00:00<00:11, 429.89it/s]


5


  2%|▏         | 108/5000 [00:00<00:11, 439.87it/s]
  6%|▌         | 290/5000 [00:00<00:10, 457.07it/s]


6


  1%|          | 56/5000 [00:00<00:10, 479.02it/s]
  2%|▏         | 83/5000 [00:00<00:10, 464.38it/s]


7
7


  1%|          | 54/5000 [00:00<00:10, 469.29it/s]
  0%|          | 14/5000 [00:00<00:11, 441.67it/s]


8


  1%|          | 51/5000 [00:00<00:09, 504.47it/s]
  0%|          | 0/5000 [00:00<?, ?it/s]


9


  4%|▎         | 184/5000 [00:00<00:09, 498.08it/s]
  2%|▏         | 101/5000 [00:00<00:10, 472.15it/s]


10


  2%|▏         | 102/5000 [00:00<00:10, 472.23it/s]
  1%|▏         | 71/5000 [00:00<00:10, 470.51it/s]


11


  2%|▏         | 76/5000 [00:00<00:10, 458.56it/s]
  1%|          | 37/5000 [00:00<00:10, 472.68it/s]


12


  3%|▎         | 165/5000 [00:00<00:09, 484.17it/s]
  4%|▍         | 211/5000 [00:00<00:10, 437.87it/s]


13


  2%|▏         | 107/5000 [00:00<00:12, 397.44it/s]
  3%|▎         | 128/5000 [00:00<00:10, 451.89it/s]


14


  2%|▏         | 80/5000 [00:00<00:10, 463.72it/s]
  4%|▎         | 176/5000 [00:00<00:09, 489.46it/s]


15


  2%|▏         | 77/5000 [00:00<00:10, 472.16it/s]
  1%|          | 43/5000 [00:00<00:10, 493.71it/s]


16


  2%|▏         | 110/5000 [00:00<00:10, 487.55it/s]
  6%|▌         | 286/5000 [00:00<00:09, 495.82it/s]


17
17


  3%|▎         | 161/5000 [00:00<00:10, 480.64it/s]
  2%|▏         | 100/5000 [00:00<00:11, 413.35it/s]


18


  1%|          | 33/5000 [00:00<00:13, 365.77it/s]
  0%|          | 15/5000 [00:00<00:14, 340.26it/s]


19


  1%|          | 58/5000 [00:00<00:12, 385.27it/s]
  2%|▏         | 85/5000 [00:00<00:12, 393.19it/s]


20


  2%|▏         | 87/5000 [00:00<00:12, 389.87it/s]
  1%|          | 62/5000 [00:00<00:12, 390.08it/s]


21


  3%|▎         | 172/5000 [00:00<00:12, 401.92it/s]
  2%|▏         | 99/5000 [00:00<00:12, 384.36it/s]


22


  1%|▏         | 73/5000 [00:00<00:14, 350.26it/s]
  1%|▏         | 66/5000 [00:00<00:13, 370.40it/s]


23


  4%|▍         | 212/5000 [00:00<00:10, 469.82it/s]
  3%|▎         | 168/5000 [00:00<00:09, 487.28it/s]


24


  2%|▏         | 120/5000 [00:00<00:09, 492.77it/s]
  3%|▎         | 131/5000 [00:00<00:10, 473.58it/s]


25


  2%|▏         | 119/5000 [00:00<00:10, 486.11it/s]
  2%|▏         | 91/5000 [00:00<00:10, 468.53it/s]


26


  4%|▍         | 204/5000 [00:00<00:09, 492.80it/s]
  6%|▌         | 294/5000 [00:00<00:10, 469.89it/s]


27


  1%|▏         | 71/5000 [00:00<00:10, 481.26it/s]
  5%|▍         | 245/5000 [00:00<00:09, 481.80it/s]


28


  3%|▎         | 151/5000 [00:00<00:10, 472.88it/s]
  1%|          | 31/5000 [00:00<00:10, 472.84it/s]


29


  2%|▏         | 90/5000 [00:00<00:10, 488.36it/s]
  1%|          | 53/5000 [00:00<00:10, 474.18it/s]


30
30
30


  3%|▎         | 129/5000 [00:00<00:10, 474.06it/s]
  2%|▏         | 88/5000 [00:00<00:09, 508.57it/s]


31


  3%|▎         | 142/5000 [00:00<00:09, 489.68it/s]
  4%|▍         | 194/5000 [00:00<00:09, 490.53it/s]


32


  3%|▎         | 128/5000 [00:00<00:10, 478.55it/s]
  2%|▏         | 113/5000 [00:00<00:10, 479.05it/s]


33


  2%|▏         | 116/5000 [00:00<00:10, 482.48it/s]
  5%|▌         | 252/5000 [00:00<00:10, 468.80it/s]


34


  2%|▏         | 103/5000 [00:00<00:13, 353.54it/s]
  3%|▎         | 142/5000 [00:00<00:13, 365.58it/s]


35


  2%|▏         | 80/5000 [00:00<00:13, 352.04it/s]
  5%|▌         | 250/5000 [00:00<00:12, 378.52it/s]


36


  1%|▏         | 71/5000 [00:00<00:13, 374.11it/s]
  3%|▎         | 131/5000 [00:00<00:12, 378.28it/s]


37


  2%|▏         | 98/5000 [00:00<00:10, 469.35it/s]
  3%|▎         | 128/5000 [00:00<00:10, 467.30it/s]


38


  2%|▏         | 114/5000 [00:00<00:10, 487.10it/s]
  2%|▏         | 104/5000 [00:00<00:10, 477.77it/s]


39


  2%|▏         | 111/5000 [00:00<00:12, 392.45it/s]
  0%|          | 5/5000 [00:00<00:16, 306.29it/s]


40
40


  2%|▏         | 113/5000 [00:00<00:13, 372.83it/s]
  3%|▎         | 126/5000 [00:00<00:13, 362.41it/s]


41


  2%|▏         | 85/5000 [00:00<00:13, 366.47it/s]
  1%|          | 47/5000 [00:00<00:12, 389.11it/s]


42


  3%|▎         | 139/5000 [00:00<00:12, 380.35it/s]
  3%|▎         | 144/5000 [00:00<00:14, 339.44it/s]


43


  2%|▏         | 75/5000 [00:00<00:15, 327.82it/s]
  0%|          | 11/5000 [00:00<00:12, 408.75it/s]


44


  3%|▎         | 153/5000 [00:00<00:10, 476.90it/s]
  2%|▏         | 112/5000 [00:00<00:10, 463.88it/s]


45


  2%|▏         | 95/5000 [00:00<00:09, 490.82it/s]
  2%|▏         | 107/5000 [00:00<00:10, 469.12it/s]


46


  3%|▎         | 126/5000 [00:00<00:10, 480.76it/s]
  1%|          | 32/5000 [00:00<00:10, 484.27it/s]


47


  3%|▎         | 134/5000 [00:00<00:09, 505.32it/s]
  6%|▋         | 323/5000 [00:00<00:09, 475.96it/s]


48


  2%|▏         | 110/5000 [00:00<00:09, 490.70it/s]
  5%|▌         | 263/5000 [00:00<00:09, 485.73it/s]


49


  3%|▎         | 146/5000 [00:00<00:09, 502.65it/s]
  1%|          | 61/5000 [00:00<00:10, 478.77it/s]


50


  3%|▎         | 145/5000 [00:00<00:10, 474.75it/s]
  6%|▌         | 311/5000 [00:00<00:09, 488.00it/s]


51


  3%|▎         | 155/5000 [00:00<00:10, 457.09it/s]
  2%|▏         | 99/5000 [00:00<00:10, 489.82it/s]


52


  2%|▏         | 102/5000 [00:00<00:09, 490.87it/s]
  3%|▎         | 144/5000 [00:00<00:09, 489.58it/s]


53


  2%|▏         | 98/5000 [00:00<00:09, 495.97it/s]
  0%|          | 24/5000 [00:00<00:23, 214.97it/s]


54


  2%|▏         | 86/5000 [00:00<00:15, 322.66it/s]
  2%|▏         | 122/5000 [00:00<00:10, 486.89it/s]


55


  2%|▏         | 124/5000 [00:00<00:12, 390.60it/s]
  2%|▏         | 104/5000 [00:00<00:15, 309.44it/s]


56


  2%|▏         | 98/5000 [00:00<00:09, 495.97it/s]
  1%|▏         | 70/5000 [00:00<00:10, 471.75it/s]


57


  0%|          | 20/5000 [00:00<00:22, 224.94it/s]
  0%|          | 20/5000 [00:00<00:32, 154.39it/s]


58
58


  3%|▎         | 129/5000 [00:00<00:10, 468.02it/s]
  3%|▎         | 149/5000 [00:00<00:10, 473.54it/s]


59


  1%|          | 43/5000 [00:00<00:10, 481.19it/s]
  0%|          | 3/5000 [00:00<00:17, 279.88it/s]


60


  1%|          | 36/5000 [00:00<00:10, 469.37it/s]
  1%|          | 56/5000 [00:00<00:10, 452.37it/s]


61


  2%|▏         | 121/5000 [00:00<00:10, 472.21it/s]
  2%|▏         | 92/5000 [00:00<00:10, 469.03it/s]


62


  3%|▎         | 130/5000 [00:00<00:09, 504.13it/s]
  1%|          | 33/5000 [00:00<00:10, 452.27it/s]


63


  3%|▎         | 169/5000 [00:00<00:10, 481.82it/s]
  2%|▏         | 114/5000 [00:00<00:13, 363.78it/s]


64


  2%|▏         | 114/5000 [00:00<00:12, 395.99it/s]
  6%|▌         | 300/5000 [00:00<00:12, 375.72it/s]


65


  5%|▌         | 265/5000 [00:00<00:16, 285.81it/s]
  3%|▎         | 128/5000 [00:00<00:09, 487.48it/s]


66


  3%|▎         | 145/5000 [00:00<00:09, 490.28it/s]
  6%|▌         | 287/5000 [00:00<00:10, 444.13it/s]


67
67


  2%|▏         | 111/5000 [00:00<00:12, 399.84it/s]
  3%|▎         | 134/5000 [00:00<00:12, 402.03it/s]


68


  2%|▏         | 118/5000 [00:00<00:10, 460.03it/s]
  2%|▏         | 98/5000 [00:00<00:10, 459.04it/s]


69
69


  2%|▏         | 120/5000 [00:00<00:10, 473.00it/s]
  0%|          | 5/5000 [00:00<00:13, 372.15it/s]


70


  1%|          | 30/5000 [00:00<00:10, 482.86it/s]
  1%|          | 47/5000 [00:00<00:12, 412.63it/s]


71


  3%|▎         | 142/5000 [00:00<00:10, 473.68it/s]
  2%|▏         | 110/5000 [00:00<00:10, 464.29it/s]


72
72


  1%|          | 57/5000 [00:00<00:10, 453.98it/s]
  2%|▏         | 87/5000 [00:00<00:10, 484.77it/s]


73


  5%|▍         | 234/5000 [00:00<00:10, 472.70it/s]
  3%|▎         | 130/5000 [00:00<00:10, 476.60it/s]


74


  2%|▏         | 114/5000 [00:00<00:10, 478.41it/s]
  2%|▏         | 80/5000 [00:00<00:11, 442.45it/s]


75


  3%|▎         | 128/5000 [00:00<00:10, 452.58it/s]
  1%|          | 33/5000 [00:00<00:11, 441.02it/s]


76


  3%|▎         | 136/5000 [00:00<00:10, 449.45it/s]
  2%|▏         | 98/5000 [00:00<00:10, 474.37it/s]


77


  3%|▎         | 138/5000 [00:00<00:09, 499.49it/s]
  4%|▍         | 215/5000 [00:00<00:09, 492.66it/s]


78
78


  3%|▎         | 153/5000 [00:00<00:11, 412.76it/s]
  2%|▏         | 115/5000 [00:00<00:10, 474.41it/s]


79


  2%|▏         | 79/5000 [00:00<00:10, 481.91it/s]
  3%|▎         | 153/5000 [00:00<00:10, 483.45it/s]


80


  3%|▎         | 126/5000 [00:00<00:16, 289.35it/s]
  1%|          | 55/5000 [00:00<00:10, 462.25it/s]


81


  2%|▏         | 115/5000 [00:00<00:09, 488.64it/s]
  2%|▏         | 105/5000 [00:00<00:10, 470.17it/s]


82
82


  2%|▎         | 125/5000 [00:00<00:10, 479.10it/s]
  1%|          | 53/5000 [00:00<00:10, 449.87it/s]


83


  1%|          | 49/5000 [00:00<00:10, 484.43it/s]
  5%|▌         | 262/5000 [00:00<00:09, 478.70it/s]


84


  1%|          | 52/5000 [00:00<00:15, 326.60it/s]
  1%|▏         | 72/5000 [00:00<00:13, 354.33it/s]


85


  1%|          | 26/5000 [00:00<00:12, 384.57it/s]
  0%|          | 8/5000 [00:00<00:19, 255.75it/s]


86


  2%|▏         | 97/5000 [00:00<00:12, 401.84it/s]
  2%|▏         | 85/5000 [00:00<00:12, 381.57it/s]


87


  2%|▏         | 89/5000 [00:00<00:13, 357.67it/s]
  2%|▏         | 95/5000 [00:00<00:12, 391.85it/s]


88


  2%|▏         | 105/5000 [00:00<00:12, 394.08it/s]
  1%|          | 59/5000 [00:00<00:13, 360.46it/s]


89


  2%|▏         | 96/5000 [00:00<00:13, 375.38it/s]
  2%|▏         | 112/5000 [00:00<00:11, 420.57it/s]


90


  3%|▎         | 172/5000 [00:00<00:09, 483.29it/s]
  4%|▍         | 222/5000 [00:00<00:09, 479.99it/s]


91


  3%|▎         | 144/5000 [00:00<00:09, 492.35it/s]
  6%|▌         | 305/5000 [00:00<00:09, 482.06it/s]


92


  4%|▍         | 225/5000 [00:00<00:10, 476.65it/s]
  2%|▏         | 117/5000 [00:00<00:10, 483.84it/s]


93


  2%|▏         | 76/5000 [00:00<00:10, 473.21it/s]
  2%|▏         | 76/5000 [00:00<00:09, 494.64it/s]


94


  2%|▏         | 80/5000 [00:00<00:09, 501.16it/s]
  1%|          | 32/5000 [00:00<00:11, 442.31it/s]


95


  3%|▎         | 151/5000 [00:00<00:10, 484.51it/s]
  3%|▎         | 168/5000 [00:00<00:10, 458.28it/s]


96


  1%|          | 38/5000 [00:00<00:11, 450.18it/s]
  2%|▏         | 79/5000 [00:00<00:11, 432.84it/s]


97


  1%|          | 52/5000 [00:00<00:10, 451.45it/s]
  1%|          | 61/5000 [00:00<00:11, 447.57it/s]


98


  3%|▎         | 130/5000 [00:00<00:10, 479.35it/s]
  1%|▏         | 68/5000 [00:00<00:10, 454.66it/s]


99


  2%|▏         | 81/5000 [00:00<00:10, 483.12it/s]
  1%|          | 53/5000 [00:00<00:10, 452.91it/s]

avg n_iter for model_relu and model_relu_updated = (111.62, 117.44)





In [38]:
def attack_cross (tensor, model, eps = 1e-3, n_iter = 5000):
    number = 5000
    new_tensor = tensor.unsqueeze(0).detach().clone()
    new_tensor = new_tensor.to(device = device)
    orig_prediction = model(new_tensor).argmax()
    #print(f"Original Prediction: {orig_prediction.item()}")

    for i in tqdm(range(n_iter)):
        model.zero_grad()

        grad = compute_gradient(func, new_tensor, model = model, target = orig_prediction.item())
        #new_tensor = torch.clamp(new_tensor + eps * grad.sign(), -2, 2)
        new_tensor = new_tensor + eps * grad.sign()
        new_prediction = model(new_tensor).argmax()

        if new_prediction != orig_prediction:
            #print(f"New Prediction: {new_prediction.item()}")
            number = i
            break
    return number, new_tensor

In [39]:
def new_compute_avg_n_iter_test_cross (model_1, model_2):
    sum_1 = 0
    sum_2 = 0
    correct_predictions_1 = 0
    correct_predictions_2 = 0
    counter = 0
    i = 0
    while(counter<100):
        model_1_prediction = model_1(test_dataset[i][0].unsqueeze(0).to(device)).argmax()
        model_2_prediction = model_2(test_dataset[i][0].unsqueeze(0).to(device)).argmax()
        actual = test_dataset[i][1]
        print(counter)
        if (model_1_prediction == actual and model_2_prediction == actual):
            val_1, val_2 = attack_cross(test_dataset[i][0], model_1)
            sum_1 += val_1
            adv_example_from_model_1 = val_2
            adv_example_from_model_1_on_model_2_prediction = model_2(adv_example_from_model_1).argmax()
            if (adv_example_from_model_1_on_model_2_prediction == actual):
              correct_predictions_2 += 1
            val_1, val_2 = attack_cross(test_dataset[i][0], model_2)
            sum_2 += val_1
            adv_example_from_model_2 = val_2
            adv_example_from_model_2_on_model_1_prediction = model_1(adv_example_from_model_2).argmax()
            if (adv_example_from_model_2_on_model_1_prediction == actual):
              correct_predictions_1 += 1

            counter += 1
        i += 1
    return sum_1/100, sum_2/100, correct_predictions_1, correct_predictions_2

In [40]:
new_compute_avg_n_iter_test_cross(model_relu, model_siren_updated)

0


  2%|▏         | 123/5000 [00:00<00:09, 502.12it/s]
  5%|▌         | 266/5000 [00:00<00:10, 444.36it/s]


1


  3%|▎         | 162/5000 [00:00<00:09, 508.47it/s]
  3%|▎         | 131/5000 [00:00<00:11, 419.65it/s]


2


  1%|▏         | 72/5000 [00:00<00:10, 482.23it/s]
  3%|▎         | 152/5000 [00:00<00:11, 435.87it/s]


3


  3%|▎         | 160/5000 [00:00<00:10, 482.76it/s]
  3%|▎         | 137/5000 [00:00<00:11, 412.85it/s]


4


  2%|▏         | 87/5000 [00:00<00:09, 510.14it/s]
100%|██████████| 5000/5000 [00:11<00:00, 418.59it/s]


5


  2%|▏         | 108/5000 [00:00<00:09, 498.38it/s]
 13%|█▎        | 673/5000 [00:01<00:09, 441.52it/s]


6


  1%|          | 56/5000 [00:00<00:11, 441.29it/s]
  0%|          | 20/5000 [00:00<00:11, 435.22it/s]


7


  1%|          | 33/5000 [00:00<00:10, 490.86it/s]
  0%|          | 22/5000 [00:00<00:11, 420.61it/s]


8


  1%|          | 54/5000 [00:00<00:09, 498.12it/s]
  1%|          | 31/5000 [00:00<00:10, 452.75it/s]


9


  1%|          | 51/5000 [00:00<00:10, 479.53it/s]
  1%|          | 36/5000 [00:00<00:10, 451.67it/s]


10


  4%|▎         | 184/5000 [00:00<00:09, 482.42it/s]
100%|██████████| 5000/5000 [00:11<00:00, 417.78it/s]


11


  2%|▏         | 102/5000 [00:00<00:10, 465.67it/s]
 74%|███████▎  | 3681/5000 [00:08<00:03, 413.54it/s]


12


  2%|▏         | 76/5000 [00:00<00:10, 477.29it/s]
  1%|          | 43/5000 [00:00<00:10, 453.95it/s]


13


  3%|▎         | 165/5000 [00:00<00:09, 484.35it/s]
100%|██████████| 5000/5000 [00:12<00:00, 399.92it/s]


14


  2%|▏         | 107/5000 [00:00<00:10, 486.81it/s]
  1%|▏         | 66/5000 [00:00<00:11, 442.57it/s]


15


  2%|▏         | 80/5000 [00:00<00:09, 503.72it/s]
  5%|▍         | 230/5000 [00:00<00:10, 434.06it/s]


16


  2%|▏         | 77/5000 [00:00<00:10, 483.78it/s]
  1%|          | 37/5000 [00:00<00:10, 459.54it/s]


17


  2%|▏         | 110/5000 [00:00<00:10, 477.34it/s]
  4%|▍         | 218/5000 [00:00<00:10, 444.37it/s]


18
18


  3%|▎         | 161/5000 [00:00<00:09, 506.80it/s]
100%|██████████| 5000/5000 [00:11<00:00, 419.04it/s]


19


  1%|          | 33/5000 [00:00<00:10, 451.78it/s]
  1%|          | 27/5000 [00:00<00:11, 417.86it/s]


20


  1%|          | 58/5000 [00:00<00:10, 459.80it/s]
  1%|▏         | 66/5000 [00:00<00:12, 403.84it/s]


21


  2%|▏         | 87/5000 [00:00<00:10, 463.29it/s]
100%|██████████| 5000/5000 [00:11<00:00, 426.74it/s]


22


  3%|▎         | 172/5000 [00:00<00:09, 483.60it/s]
  7%|▋         | 361/5000 [00:00<00:11, 416.59it/s]


23


  1%|▏         | 73/5000 [00:00<00:10, 465.06it/s]
  2%|▏         | 84/5000 [00:00<00:13, 371.96it/s]


24


  4%|▍         | 212/5000 [00:00<00:09, 496.84it/s]
  5%|▍         | 240/5000 [00:00<00:11, 421.61it/s]


25


  2%|▏         | 120/5000 [00:00<00:12, 379.34it/s]
 94%|█████████▍| 4701/5000 [00:11<00:00, 420.77it/s]


26


  2%|▏         | 119/5000 [00:00<00:09, 497.11it/s]
100%|██████████| 5000/5000 [00:11<00:00, 419.55it/s]


27


  4%|▍         | 204/5000 [00:00<00:09, 483.30it/s]
100%|██████████| 5000/5000 [00:12<00:00, 414.93it/s]


28


  1%|▏         | 71/5000 [00:00<00:10, 455.97it/s]
  2%|▏         | 94/5000 [00:00<00:12, 399.43it/s]


29


  3%|▎         | 151/5000 [00:00<00:13, 369.57it/s]
  3%|▎         | 161/5000 [00:00<00:14, 335.22it/s]


30


  2%|▏         | 90/5000 [00:00<00:13, 358.85it/s]
  1%|          | 54/5000 [00:00<00:16, 301.57it/s]


31


  2%|▏         | 123/5000 [00:00<00:12, 393.42it/s]
100%|██████████| 5000/5000 [00:12<00:00, 404.51it/s]


32
32


  3%|▎         | 129/5000 [00:00<00:12, 390.30it/s]
  5%|▍         | 244/5000 [00:00<00:11, 430.60it/s]


33


  3%|▎         | 142/5000 [00:00<00:09, 487.31it/s]
  2%|▏         | 107/5000 [00:00<00:11, 425.84it/s]


34


  3%|▎         | 128/5000 [00:00<00:09, 490.66it/s]
 14%|█▍        | 705/5000 [00:01<00:09, 441.78it/s]


35


  2%|▏         | 116/5000 [00:00<00:10, 466.57it/s]
  1%|          | 55/5000 [00:00<00:11, 437.69it/s]


36


  2%|▏         | 103/5000 [00:00<00:09, 507.05it/s]
  3%|▎         | 136/5000 [00:00<00:11, 436.10it/s]


37


  2%|▏         | 80/5000 [00:00<00:11, 438.90it/s]
 15%|█▍        | 737/5000 [00:01<00:10, 414.15it/s]


38


  1%|▏         | 71/5000 [00:00<00:10, 488.80it/s]
  1%|          | 46/5000 [00:00<00:13, 379.67it/s]


39


  2%|▏         | 98/5000 [00:00<00:09, 500.49it/s]
  3%|▎         | 142/5000 [00:00<00:11, 440.27it/s]


40


  2%|▏         | 114/5000 [00:00<00:10, 464.51it/s]
100%|██████████| 5000/5000 [00:12<00:00, 412.61it/s]


41


  2%|▏         | 111/5000 [00:00<00:10, 472.46it/s]
100%|██████████| 5000/5000 [00:12<00:00, 415.43it/s]


42


  2%|▏         | 97/5000 [00:00<00:10, 480.04it/s]
100%|██████████| 5000/5000 [00:11<00:00, 417.05it/s]


43


  2%|▏         | 113/5000 [00:00<00:09, 496.70it/s]
100%|██████████| 5000/5000 [00:11<00:00, 418.56it/s]


44


  2%|▏         | 85/5000 [00:00<00:10, 463.93it/s]
  1%|          | 51/5000 [00:00<00:11, 419.33it/s]


45


  3%|▎         | 139/5000 [00:00<00:10, 480.11it/s]
100%|██████████| 5000/5000 [00:11<00:00, 420.48it/s]


46


  2%|▏         | 75/5000 [00:00<00:09, 492.65it/s]
100%|██████████| 5000/5000 [00:11<00:00, 417.26it/s]


47


  3%|▎         | 153/5000 [00:00<00:09, 489.59it/s]
100%|██████████| 5000/5000 [00:12<00:00, 413.54it/s]


48


  2%|▏         | 95/5000 [00:00<00:09, 496.71it/s]
  2%|▏         | 109/5000 [00:00<00:11, 437.43it/s]


49


  3%|▎         | 126/5000 [00:00<00:09, 500.23it/s]
100%|██████████| 5000/5000 [00:11<00:00, 419.51it/s]


50


  3%|▎         | 134/5000 [00:00<00:09, 492.79it/s]
  2%|▏         | 100/5000 [00:00<00:11, 413.02it/s]


51


  2%|▏         | 110/5000 [00:00<00:09, 506.96it/s]
100%|██████████| 5000/5000 [00:12<00:00, 391.30it/s]


52


  3%|▎         | 146/5000 [00:00<00:10, 452.57it/s]
  2%|▏         | 78/5000 [00:00<00:11, 421.69it/s]


53


  3%|▎         | 145/5000 [00:00<00:10, 475.61it/s]
  1%|▏         | 66/5000 [00:00<00:12, 392.27it/s]


54


  3%|▎         | 155/5000 [00:00<00:10, 459.09it/s]
100%|██████████| 5000/5000 [00:11<00:00, 423.70it/s]


55


  2%|▏         | 102/5000 [00:00<00:10, 455.92it/s]
 18%|█▊        | 924/5000 [00:02<00:11, 353.05it/s]


56


  2%|▏         | 98/5000 [00:00<00:12, 398.22it/s]
  1%|          | 61/5000 [00:00<00:11, 427.52it/s]


57


  2%|▏         | 86/5000 [00:00<00:10, 471.65it/s]
  1%|          | 42/5000 [00:00<00:11, 432.71it/s]


58


  2%|▏         | 124/5000 [00:00<00:09, 490.95it/s]
  2%|▏         | 95/5000 [00:00<00:11, 440.46it/s]


59


  2%|▏         | 98/5000 [00:00<00:10, 488.12it/s]
  5%|▍         | 242/5000 [00:00<00:10, 439.30it/s]


60
60


  1%|          | 61/5000 [00:00<00:10, 456.79it/s]
100%|██████████| 5000/5000 [00:12<00:00, 414.43it/s]


61


  3%|▎         | 129/5000 [00:00<00:09, 502.86it/s]
  4%|▍         | 213/5000 [00:00<00:10, 436.15it/s]


62


  1%|          | 43/5000 [00:00<00:10, 475.31it/s]
  1%|          | 35/5000 [00:00<00:11, 445.47it/s]


63


  1%|          | 36/5000 [00:00<00:09, 508.56it/s]
  2%|▏         | 115/5000 [00:00<00:10, 459.75it/s]


64


  2%|▏         | 121/5000 [00:00<00:10, 483.43it/s]
100%|██████████| 5000/5000 [00:11<00:00, 418.44it/s]


65


  3%|▎         | 130/5000 [00:00<00:10, 472.30it/s]
100%|██████████| 5000/5000 [00:11<00:00, 419.63it/s]


66


  3%|▎         | 169/5000 [00:00<00:10, 482.12it/s]
100%|██████████| 5000/5000 [00:11<00:00, 417.99it/s]


67


  2%|▏         | 114/5000 [00:00<00:10, 467.72it/s]
  4%|▍         | 221/5000 [00:00<00:10, 438.99it/s]


68


  5%|▌         | 265/5000 [00:00<00:09, 478.67it/s]
100%|██████████| 5000/5000 [00:11<00:00, 419.70it/s]


69


  3%|▎         | 145/5000 [00:00<00:10, 479.89it/s]
100%|██████████| 5000/5000 [00:12<00:00, 412.90it/s]


70
70


  2%|▏         | 111/5000 [00:00<00:09, 500.03it/s]
  1%|          | 41/5000 [00:00<00:11, 437.84it/s]


71


  2%|▏         | 118/5000 [00:00<00:10, 475.43it/s]
 49%|████▊     | 2435/5000 [00:06<00:06, 401.25it/s]


72


  2%|▏         | 117/5000 [00:00<00:13, 356.74it/s]
100%|██████████| 5000/5000 [00:12<00:00, 407.88it/s]


73


  2%|▏         | 120/5000 [00:00<00:14, 337.89it/s]
100%|██████████| 5000/5000 [00:12<00:00, 414.40it/s]


74


  1%|          | 30/5000 [00:00<00:14, 337.58it/s]
  1%|          | 29/5000 [00:00<00:17, 291.38it/s]


75


  3%|▎         | 142/5000 [00:00<00:09, 487.23it/s]
100%|██████████| 5000/5000 [00:12<00:00, 410.92it/s]


76


  1%|          | 54/5000 [00:00<00:10, 469.76it/s]
 59%|█████▊    | 2935/5000 [00:06<00:04, 424.98it/s]


77


  1%|          | 57/5000 [00:00<00:10, 458.00it/s]
100%|██████████| 5000/5000 [00:12<00:00, 408.44it/s]


78


  5%|▍         | 234/5000 [00:00<00:09, 495.71it/s]
100%|██████████| 5000/5000 [00:12<00:00, 408.67it/s]


79


  2%|▏         | 114/5000 [00:00<00:11, 431.17it/s]
  5%|▌         | 268/5000 [00:00<00:11, 400.98it/s]


80


  3%|▎         | 128/5000 [00:00<00:10, 461.06it/s]
  2%|▎         | 125/5000 [00:00<00:11, 418.83it/s]


81


  3%|▎         | 136/5000 [00:00<00:10, 471.11it/s]
  2%|▏         | 80/5000 [00:00<00:11, 420.96it/s]


82


  3%|▎         | 138/5000 [00:00<00:14, 336.63it/s]
 39%|███▉      | 1970/5000 [00:05<00:07, 383.24it/s]


83


  1%|▏         | 73/5000 [00:00<00:10, 460.22it/s]
100%|██████████| 5000/5000 [00:12<00:00, 386.66it/s]


84


  3%|▎         | 153/5000 [00:00<00:10, 467.69it/s]
  3%|▎         | 158/5000 [00:00<00:11, 416.42it/s]


85


  2%|▏         | 79/5000 [00:00<00:10, 486.86it/s]
 19%|█▊        | 932/5000 [00:02<00:09, 432.04it/s]


86


  3%|▎         | 126/5000 [00:00<00:10, 467.51it/s]
100%|██████████| 5000/5000 [00:12<00:00, 412.59it/s]


87


  2%|▏         | 115/5000 [00:00<00:10, 477.74it/s]
 52%|█████▏    | 2607/5000 [00:06<00:06, 393.09it/s]


88


  0%|          | 21/5000 [00:00<00:11, 446.93it/s]
  0%|          | 3/5000 [00:00<00:17, 289.47it/s]


89


  2%|▎         | 125/5000 [00:00<00:10, 483.67it/s]
100%|██████████| 5000/5000 [00:12<00:00, 410.20it/s]


90


  1%|          | 49/5000 [00:00<00:11, 431.49it/s]
  1%|          | 29/5000 [00:00<00:12, 412.32it/s]


91


  1%|          | 52/5000 [00:00<00:10, 450.00it/s]
100%|██████████| 5000/5000 [00:12<00:00, 414.98it/s]


92


  1%|          | 26/5000 [00:00<00:10, 475.09it/s]
  1%|          | 40/5000 [00:00<00:10, 453.47it/s]


93


  2%|▏         | 97/5000 [00:00<00:09, 515.04it/s]
  1%|          | 47/5000 [00:00<00:11, 419.14it/s]


94


  2%|▏         | 89/5000 [00:00<00:09, 499.41it/s]
100%|██████████| 5000/5000 [00:12<00:00, 412.55it/s]


95


  2%|▏         | 105/5000 [00:00<00:10, 470.10it/s]
  1%|          | 56/5000 [00:00<00:11, 422.68it/s]


96


  2%|▏         | 96/5000 [00:00<00:10, 461.91it/s]
  3%|▎         | 157/5000 [00:00<00:11, 437.03it/s]


97


  3%|▎         | 172/5000 [00:00<00:09, 483.57it/s]
100%|██████████| 5000/5000 [00:12<00:00, 416.21it/s]


98


  3%|▎         | 144/5000 [00:00<00:10, 444.02it/s]
  5%|▌         | 273/5000 [00:00<00:12, 391.89it/s]


99


  4%|▍         | 225/5000 [00:00<00:10, 472.52it/s]
100%|██████████| 5000/5000 [00:11<00:00, 420.40it/s]


(111.13, 2082.4, 97, 99)

In [None]:
print(f"avg n_iter for model_relu and model_siren = {new_compute_avg_n_iter_test(model_relu, model_siren)}")

In [None]:
print(f"avg n_iter for model_siren = {new_compute_avg_n_iter_test(model_siren)}")

In [None]:
print(f"avg n_iter for model_siren_updated = {new_compute_avg_n_iter_test(model_siren_updated)}")

In [None]:
print(f"avg n_iter for model_relu_updated = {new_compute_avg_n_iter_test(model_relu_updated)}")