In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import time
import matplotlib.pyplot as plt

from torchvision import datasets, transforms
# from tensorboardX import SummaryWriter

use_cuda = False
device = torch.device("cuda" if use_cuda else "cpu")
batch_size = 64
np.random.seed(42)
torch.manual_seed(42)


## Dataloaders
train_dataset = datasets.MNIST('mnist_data/', train=True, download=True, transform=transforms.Compose(
    [transforms.ToTensor()]
))
test_dataset = datasets.MNIST('mnist_data/', train=False, download=True, transform=transforms.Compose(
    [transforms.ToTensor()]
))

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)

## Simple NN. You can change this if you want. If you change it, mention the architectural details in your report.
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.f1 = nn.Linear(28*28, 50)
        self.f2 = nn.Linear(50, 50)
        self.f3 = nn.Linear(50, 50)
        self.out = nn.Linear(50, 10)
        
    def forward(self, x):
        x = x.view(-1, 28*28)
        x = F.relu(self.f1(x))
        x = F.relu(self.f2(x))
        x = F.relu(self.f3(x))
        x = self.out(x)
        return x

class Normalize(nn.Module):
    def forward(self, x):
        return (x - 0.1307)/0.3081

# Add the data normalization as a first "layer" to the network
# this allows us to search for adverserial examples to the real image, rather than
# to the normalized image
model = nn.Sequential(Normalize(), Net())

model = model.to(device)
model.train()

Sequential(
  (0): Normalize()
  (1): Net(
    (f1): Linear(in_features=784, out_features=50, bias=True)
    (f2): Linear(in_features=50, out_features=50, bias=True)
    (f3): Linear(in_features=50, out_features=50, bias=True)
    (out): Linear(in_features=50, out_features=10, bias=True)
  )
)

In [2]:
def train_model(model, num_epochs, learning_rate=0.1, momentum=0.95):
    entropy_loss = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)
    
    for epoch in range(num_epochs):
        model.train() 
        for (data, target) in train_loader:
            data = data.to(device)
            target = target.to(device)
            
            optimizer.zero_grad()           
            outputs = model(data)           
            loss = entropy_loss(outputs, target)
            loss.backward()                 
            optimizer.step()                

In [4]:
train_model(model, 20)
torch.save(model.state_dict(), 'weights.pt')

In [66]:
def evaluate_model(model, test_loader, device):
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad(): 
        for data, target in test_loader:
            data = data.to(device)
            target = target.to(device)
            outputs = model(data) 
            _, predicted = torch.max(outputs, 1) 
            total += target.size(0)
            correct += (predicted == target).sum().item()

    accuracy = 100.0 * correct / total
    print(f'Model Accuracy on MNIST Test Set: {accuracy:.2f}%')
    return accuracy

accuracy = evaluate_model(model, test_loader, device)

Model Accuracy on MNIST Test Set: 8.52%


In [63]:
def interval_analysis(model, x, eps):
    f1_weight = model[1].f1.weight.data 
    f1_bias = model[1].f1.bias.data   
    
    f2_weight = model[1].f2.weight.data
    f2_bias = model[1].f2.bias.data  
    
    f3_weight = model[1].f3.weight.data 
    f3_bias = model[1].f3.bias.data 
    
    out_weight = model[1].out.weight.data 
    out_bias = model[1].out.bias.data 
    
    x_normalized = (x - 0.1307) / 0.3081
    
    x_normalized = x_normalized.view(-1)
    x_lower = x_normalized - eps
    x_upper = x_normalized + eps

    f1_lower = torch.zeros(50)
    f1_upper = torch.zeros(50)
    
    for i in range(f1_weight.size(0)):
        sum_lower = 0.0
        sum_upper = 0.0
        for j in range(f1_weight.size(1)): 
            w = f1_weight[i][j].item()
            if w >=0:
                sum_lower += w * x_lower[j]
                sum_upper += w * x_upper[j]
            else:
                sum_lower += w * x_upper[j]
                sum_upper += w * x_lower[j]
        f1_lower[i] = sum_lower + f1_bias[i]
        f1_upper[i] = sum_upper + f1_bias[i]

    f1_lower = torch.clamp(f1_lower, min=0.0)
    f1_upper = torch.clamp(f1_upper, min=0.0)

    f2_lower = torch.zeros(50)
    f2_upper = torch.zeros(50)
    
    for i in range(f2_weight.size(0)):
        sum_lower = 0.0
        sum_upper = 0.0
        for j in range(f2_weight.size(1)):
            w = f2_weight[i][j].item()
            if w >=0:
                sum_lower += w * f1_lower[j]
                sum_upper += w * f1_upper[j]
            else:
                sum_lower += w * f1_upper[j]
                sum_upper += w * f1_lower[j]
        f2_lower[i] = sum_lower + f2_bias[i]
        f2_upper[i] = sum_upper + f2_bias[i]

    f2_lower = torch.clamp(f2_lower, min=0.0)
    f2_upper = torch.clamp(f2_upper, min=0.0)

    f3_lower = torch.zeros(50)
    f3_upper = torch.zeros(50)
    
    for i in range(f3_weight.size(0)):
        sum_lower = 0.0
        sum_upper = 0.0
        for j in range(f3_weight.size(1)):
            w = f3_weight[i][j].item()
            if w >=0:
                sum_lower += w * f2_lower[j]
                sum_upper += w * f2_upper[j]
            else:
                sum_lower += w * f2_upper[j]
                sum_upper += w * f2_lower[j]
        f3_lower[i] = sum_lower + f3_bias[i]
        f3_upper[i] = sum_upper + f3_bias[i]
    
    f3_lower = torch.clamp(f3_lower, min=0.0)
    f3_upper = torch.clamp(f3_upper, min=0.0)

    out_lower = torch.zeros(10)
    out_upper = torch.zeros(10)
    
    for i in range(out_weight.size(0)):
        sum_lower = 0.0
        sum_upper = 0.0
        for j in range(out_weight.size(1)):
            w = out_weight[i][j].item()
            if w >=0:
                sum_lower += w * f3_lower[j]
                sum_upper += w * f3_upper[j]
            else:
                sum_lower += w * f3_upper[j]
                sum_upper += w * f3_lower[j]
        out_lower[i] = sum_lower + out_bias[i]
        out_upper[i] = sum_upper + out_bias[i]
    print(out_lower)
    print(out_upper)
    return out_lower, out_upper


In [64]:

!pip install tqdm
from tqdm import tqdm
eps = [0.01 * i for i in range(1, 11)]
robust_rates = []

for e in eps:
    robust_count = 0
    process_count = 0
    test_num = 10

    for idx, (data, target) in enumerate(tqdm(test_loader)):
        process_count += 1
        data = data.to(device)
        target = target.item()
        

        f_l, f_u = interval_analysis(model, data, e)
        c_l = f_l[target]

        incorrect_u = f_u
        incorrect_u[target] = -1000

        max_incorrect_u = torch.max(incorrect_u)
        
        if c_l > max_incorrect_u:
            print("correct_lower: ", c_l)
            print("max_incorrect_upper: ", max_incorrect_u)
            robust_count +=1
        if process_count >= test_num:
            break
    robust_rate = robust_count / test_num
    print("robust_rate: ", robust_rate)


Defaulting to user installation because normal site-packages is not writeable


  0%|          | 1/10000 [00:01<3:04:19,  1.11s/it]

tensor([-1.2872, -0.8924, -1.0657, -0.8974, -1.1955, -1.0123, -0.5290, -0.8308,
        -0.9499, -0.7195])
tensor([0.5968, 1.0441, 1.0219, 0.9192, 0.8673, 0.9053, 1.0133, 0.7626, 0.9839,
        1.1293])


  0%|          | 2/10000 [00:02<3:06:24,  1.12s/it]

tensor([-1.2875, -0.8726, -1.0299, -0.9247, -1.1996, -0.9723, -0.5380, -0.7559,
        -0.9268, -0.7370])
tensor([0.6320, 1.0234, 1.0519, 0.8790, 0.8748, 0.8993, 1.0087, 0.7828, 0.9987,
        1.0804])


  0%|          | 3/10000 [00:03<3:05:54,  1.12s/it]

tensor([-0.9034, -0.5835, -0.7567, -0.6644, -0.8544, -0.6672, -0.3129, -0.5485,
        -0.6226, -0.4706])
tensor([0.3691, 0.7179, 0.6364, 0.5691, 0.5430, 0.6310, 0.7399, 0.5185, 0.6873,
        0.7679])


  0%|          | 4/10000 [00:04<3:05:43,  1.11s/it]

tensor([-1.2160, -0.8610, -1.0051, -0.9060, -1.1578, -0.9263, -0.4833, -0.7332,
        -0.8927, -0.6557])
tensor([0.6089, 0.9627, 0.9558, 0.7665, 0.7964, 0.8554, 0.9507, 0.7581, 0.9830,
        1.0732])


  0%|          | 5/10000 [00:05<3:05:15,  1.11s/it]

tensor([-1.3729, -0.9728, -1.1262, -0.9649, -1.2757, -1.0593, -0.5636, -0.8684,
        -0.9992, -0.7336])
tensor([0.7057, 1.0843, 1.0955, 0.9680, 0.9420, 0.9924, 1.0558, 0.8411, 1.0806,
        1.2221])


  0%|          | 6/10000 [00:06<3:05:27,  1.11s/it]

tensor([-1.0254, -0.6751, -0.8582, -0.7497, -0.9663, -0.7684, -0.3884, -0.6415,
        -0.7279, -0.5460])
tensor([0.4554, 0.8332, 0.7644, 0.6760, 0.6479, 0.7331, 0.8295, 0.5961, 0.7891,
        0.8982])


  0%|          | 7/10000 [00:07<3:05:18,  1.11s/it]

tensor([-1.0603, -0.6892, -0.8587, -0.7363, -0.9669, -0.7590, -0.3934, -0.6578,
        -0.7932, -0.5358])
tensor([0.4479, 0.8079, 0.7993, 0.6795, 0.6639, 0.7557, 0.8291, 0.6085, 0.7530,
        0.9486])


  0%|          | 8/10000 [00:08<3:06:53,  1.12s/it]

tensor([-1.1826, -0.8107, -0.9873, -0.8916, -1.1303, -0.8908, -0.4731, -0.7296,
        -0.8599, -0.6383])
tensor([0.5872, 0.9754, 0.9278, 0.7794, 0.7787, 0.8560, 0.9492, 0.7408, 0.9369,
        1.0756])


  0%|          | 9/10000 [00:10<3:08:51,  1.13s/it]

tensor([-1.4319, -1.0212, -1.1922, -1.0185, -1.3645, -1.1009, -0.6316, -0.9329,
        -1.0641, -0.8062])
tensor([0.7181, 1.1383, 1.1603, 1.0147, 0.9686, 1.0605, 1.0736, 0.8528, 1.1322,
        1.2681])


  0%|          | 9/10000 [00:11<3:28:04,  1.25s/it]


tensor([-1.2530, -0.8382, -1.0481, -0.9338, -1.1863, -0.9970, -0.4956, -0.8345,
        -0.9211, -0.6779])
tensor([0.6249, 1.0384, 0.9999, 0.8736, 0.8421, 0.8917, 0.9937, 0.7349, 0.9994,
        1.1540])
robust_rate:  0.0


  0%|          | 1/10000 [00:01<3:10:57,  1.15s/it]

tensor([-2.2538, -1.7231, -1.9010, -1.6444, -2.1558, -1.9227, -1.0715, -1.5782,
        -1.8038, -1.3805])
tensor([1.3172, 1.9415, 2.0474, 1.7664, 1.6930, 1.6642, 1.8288, 1.4371, 1.8501,
        2.0931])


  0%|          | 2/10000 [00:02<3:09:32,  1.14s/it]

tensor([-2.3130, -1.7550, -1.9192, -1.7013, -2.2184, -1.9378, -1.1263, -1.5571,
        -1.8380, -1.4354])
tensor([1.3863, 1.9745, 2.1324, 1.7933, 1.7550, 1.7083, 1.8539, 1.4865, 1.9128,
        2.1001])


  0%|          | 3/10000 [00:03<3:09:29,  1.14s/it]

tensor([-1.7899, -1.3319, -1.5212, -1.3303, -1.7251, -1.4848, -0.8206, -1.2248,
        -1.4027, -1.0695])
tensor([1.0090, 1.5345, 1.5627, 1.3511, 1.3003, 1.3324, 1.4621, 1.1257, 1.4676,
        1.6373])


  0%|          | 4/10000 [00:04<3:09:11,  1.14s/it]

tensor([-2.2676, -1.7499, -1.9057, -1.6800, -2.1753, -1.9129, -1.0751, -1.5472,
        -1.8017, -1.3664])
tensor([1.3546, 1.9293, 2.0652, 1.7136, 1.6874, 1.6774, 1.8201, 1.4688, 1.9023,
        2.1046])


  0%|          | 5/10000 [00:05<3:07:38,  1.13s/it]

tensor([-2.4003, -1.8677, -1.9977, -1.7487, -2.2881, -2.0095, -1.1421, -1.6532,
        -1.9022, -1.4362])
tensor([1.4505, 2.0071, 2.1750, 1.8624, 1.8041, 1.7834, 1.9006, 1.5500, 1.9835,
        2.2279])


  0%|          | 6/10000 [00:06<3:07:54,  1.13s/it]

tensor([-1.8217, -1.3506, -1.5440, -1.3533, -1.7499, -1.5014, -0.8453, -1.2518,
        -1.4305, -1.0839])
tensor([1.0368, 1.5669, 1.5974, 1.3841, 1.3309, 1.3637, 1.4800, 1.1425, 1.4902,
        1.6817])


  0%|          | 7/10000 [00:07<3:10:18,  1.14s/it]

tensor([-1.9899, -1.4967, -1.6582, -1.4536, -1.8959, -1.6243, -0.9243, -1.3719,
        -1.6068, -1.1790])
tensor([1.1437, 1.6654, 1.7896, 1.5008, 1.4565, 1.4872, 1.6018, 1.2644, 1.5884,
        1.8677])


  0%|          | 8/10000 [00:09<3:08:45,  1.13s/it]

tensor([-2.1250, -1.6170, -1.7843, -1.5916, -2.0452, -1.7591, -1.0040, -1.4490,
        -1.6916, -1.2800])
tensor([1.2663, 1.8316, 1.9157, 1.6089, 1.5879, 1.5945, 1.7244, 1.3827, 1.7616,
        1.9880])


  0%|          | 9/10000 [00:10<3:11:03,  1.15s/it]

tensor([-2.5262, -1.9594, -2.1355, -1.8534, -2.4393, -2.1250, -1.2454, -1.7818,
        -2.0336, -1.5416])
tensor([1.5286, 2.1547, 2.3186, 1.9794, 1.9048, 1.9259, 1.9942, 1.6073, 2.0994,
        2.3594])


  0%|          | 9/10000 [00:11<3:31:51,  1.27s/it]


tensor([-2.2769, -1.6997, -1.9068, -1.6875, -2.1823, -1.9414, -1.0728, -1.6017,
        -1.8062, -1.3729])
tensor([1.3533, 1.9688, 2.0753, 1.7781, 1.7070, 1.6846, 1.8214, 1.4349, 1.8935,
        2.1398])
robust_rate:  0.0


  0%|          | 1/10000 [00:01<3:28:05,  1.25s/it]

tensor([-3.1308, -2.4726, -2.6394, -2.2948, -3.0049, -2.7231, -1.5655, -2.2441,
        -2.5772, -1.9812])
tensor([1.9414, 2.7289, 2.9622, 2.5323, 2.4392, 2.3486, 2.5436, 2.0346, 2.6138,
        2.9383])


  0%|          | 2/10000 [00:02<3:24:10,  1.23s/it]

tensor([-3.3213, -2.6173, -2.7744, -2.4536, -3.2015, -2.8623, -1.7003, -2.3312,
        -2.7300, -2.1181])
tensor([2.1160, 2.8860, 3.1912, 2.6791, 2.6170, 2.5016, 2.6737, 2.1736, 2.7993,
        3.0856])


  0%|          | 3/10000 [00:03<3:24:34,  1.23s/it]

tensor([-2.7239, -2.1259, -2.3112, -2.0293, -2.6367, -2.3336, -1.3533, -1.9345,
        -2.2294, -1.7016])
tensor([1.6841, 2.3792, 2.5396, 2.1648, 2.0945, 2.0646, 2.2171, 1.7657, 2.2823,
        2.5470])


  0%|          | 4/10000 [00:04<3:25:20,  1.23s/it]

tensor([-3.2309, -2.5651, -2.7196, -2.3862, -3.1104, -2.7917, -1.6250, -2.2732,
        -2.6419, -2.0281])
tensor([2.0448, 2.7974, 3.0758, 2.5650, 2.5074, 2.4269, 2.5999, 2.1280, 2.7408,
        3.0271])


  0%|          | 5/10000 [00:06<3:40:04,  1.32s/it]

tensor([-3.4163, -2.7351, -2.8550, -2.5018, -3.2776, -2.9381, -1.7197, -2.4240,
        -2.7985, -2.1331])
tensor([2.1803, 2.9240, 3.2433, 2.7536, 2.6686, 2.5739, 2.7272, 2.2439, 2.8674,
        3.2097])


  0%|          | 6/10000 [00:07<3:37:18,  1.30s/it]

tensor([-2.7174, -2.1109, -2.2984, -2.0226, -2.6234, -2.3154, -1.3545, -1.9305,
        -2.2223, -1.6920])
tensor([1.6836, 2.3748, 2.5349, 2.1640, 2.0935, 2.0647, 2.2047, 1.7562, 2.2706,
        2.5504])


  0%|          | 7/10000 [00:08<3:33:47,  1.28s/it]

tensor([-2.9413, -2.3079, -2.4668, -2.1623, -2.8294, -2.4937, -1.4692, -2.0926,
        -2.4430, -1.8292])
tensor([1.8372, 2.5257, 2.7912, 2.3383, 2.2643, 2.2314, 2.3728, 1.9205, 2.4215,
        2.7908])


  0%|          | 8/10000 [00:10<3:31:19,  1.27s/it]

tensor([-3.0014, -2.3664, -2.5226, -2.2397, -2.8975, -2.5608, -1.5006, -2.1109,
        -2.4598, -1.8793])
tensor([1.8921, 2.6133, 2.8293, 2.3742, 2.3315, 2.2766, 2.4325, 1.9815, 2.5287,
        2.8266])


  0%|          | 9/10000 [00:11<3:28:15,  1.25s/it]

tensor([-3.5428, -2.8340, -3.0015, -2.6280, -3.4434, -3.0664, -1.8217, -2.5610,
        -2.9290, -2.2379])
tensor([2.2823, 3.0909, 3.3983, 2.8762, 2.7743, 2.7211, 2.8365, 2.3113, 2.9990,
        3.3566])


  0%|          | 9/10000 [00:12<3:53:24,  1.40s/it]


tensor([-3.2125, -2.4964, -2.6955, -2.3784, -3.0952, -2.7903, -1.6090, -2.3036,
        -2.6274, -2.0104])
tensor([2.0293, 2.8105, 3.0586, 2.5979, 2.4998, 2.4131, 2.5748, 2.0793, 2.7053,
        3.0418])
robust_rate:  0.0


  0%|          | 1/10000 [00:01<3:19:05,  1.19s/it]

tensor([-3.9896, -3.2046, -3.3618, -2.9293, -3.8374, -3.5056, -2.0513, -2.8938,
        -3.3351, -2.5709])
tensor([2.5526, 3.4954, 3.8583, 3.2828, 3.1671, 3.0161, 3.2388, 2.6195, 3.3597,
        3.7617])


  0%|          | 2/10000 [00:02<3:18:46,  1.19s/it]

tensor([-4.2576, -3.4174, -3.5625, -3.1526, -4.1154, -3.7191, -2.2289, -3.0432,
        -3.5560, -2.7562])
tensor([2.7880, 3.7249, 4.1717, 3.4961, 3.4113, 3.2301, 3.4318, 2.8115, 3.6173,
        3.9903])


  0%|          | 3/10000 [00:03<3:19:22,  1.20s/it]

tensor([-3.7097, -2.9640, -3.1443, -2.7633, -3.5973, -3.2319, -1.9148, -2.6841,
        -3.0992, -2.3706])
tensor([2.3942, 3.2698, 3.5687, 3.0250, 2.9321, 2.8362, 3.0131, 2.4393, 3.1437,
        3.5040])


  0%|          | 4/10000 [00:04<3:19:37,  1.20s/it]

tensor([-4.1172, -3.3175, -3.4678, -3.0386, -3.9701, -3.6013, -2.1287, -2.9452,
        -3.4178, -2.6373])
tensor([2.6751, 3.5925, 4.0010, 3.3454, 3.2609, 3.1180, 3.3168, 2.7306, 3.5124,
        3.8739])


  0%|          | 5/10000 [00:05<3:20:20,  1.20s/it]

tensor([-4.4062, -3.5790, -3.6912, -3.2363, -4.2403, -3.8439, -2.2814, -3.1787,
        -3.6705, -2.8120])
tensor([2.8871, 3.8159, 4.2774, 3.6195, 3.5089, 3.3478, 3.5291, 2.9163, 3.7329,
        4.1613])


  0%|          | 6/10000 [00:07<3:20:39,  1.20s/it]

tensor([-3.6773, -2.9285, -3.1089, -2.7371, -3.5598, -3.1914, -1.8992, -2.6592,
        -3.0694, -2.3445])
tensor([2.3758, 3.2394, 3.5392, 3.0016, 2.9104, 2.8153, 2.9825, 2.4143, 3.1098,
        3.4813])


  0%|          | 7/10000 [00:08<3:21:21,  1.21s/it]

tensor([-3.9004, -3.1260, -3.2790, -2.8775, -3.7663, -3.3703, -2.0166, -2.8214,
        -3.2855, -2.4833])
tensor([2.5284, 3.3897, 3.7949, 3.1794, 3.0763, 2.9818, 3.1485, 2.5784, 3.2634,
        3.7183])


  0%|          | 8/10000 [00:09<3:25:45,  1.24s/it]

tensor([-3.8567, -3.0961, -3.2436, -2.8720, -3.7282, -3.3405, -1.9864, -2.7557,
        -3.2097, -2.4620])
tensor([2.5014, 3.3755, 3.7187, 3.1209, 3.0563, 2.9413, 3.1197, 2.5635, 3.2742,
        3.6461])


  0%|          | 9/10000 [00:10<3:25:23,  1.23s/it]

tensor([-4.5859, -3.7239, -3.8835, -3.4076, -4.4651, -4.0209, -2.4154, -3.3540,
        -3.8493, -2.9542])
tensor([3.0402, 4.0329, 4.4973, 3.7933, 3.6627, 3.5371, 3.6840, 3.0275, 3.9132,
        4.3637])


  0%|          | 9/10000 [00:12<3:44:52,  1.35s/it]


tensor([-4.1880, -3.3315, -3.5155, -3.0993, -4.0444, -3.6759, -2.1668, -3.0382,
        -3.4857, -2.6739])
tensor([2.7281, 3.6838, 4.0786, 3.4506, 3.3239, 3.1724, 3.3614, 2.7483, 3.5523,
        3.9803])
robust_rate:  0.0


  0%|          | 1/10000 [00:01<3:23:55,  1.22s/it]

tensor([-4.8307, -3.9227, -4.0692, -3.5499, -4.6541, -4.2718, -2.5286, -3.5279,
        -4.0764, -3.1487])
tensor([3.1514, 4.2440, 4.7361, 4.0183, 3.8801, 3.6689, 3.9174, 3.1929, 4.0900,
        4.5651])


  0%|          | 2/10000 [00:02<3:22:46,  1.22s/it]

tensor([-5.1450, -4.1777, -4.3096, -3.8162, -4.9839, -4.5320, -2.7301, -3.7169,
        -4.3402, -3.3635])
tensor([3.4252, 4.5180, 5.1012, 4.2701, 4.1629, 3.9201, 4.1509, 3.4181, 4.3929,
        4.8458])


  0%|          | 3/10000 [00:03<3:23:38,  1.22s/it]

tensor([-4.6445, -3.7622, -3.9337, -3.4588, -4.5096, -4.0857, -2.4475, -3.3946,
        -3.9234, -3.0074])
tensor([3.0677, 4.1123, 4.5473, 3.8439, 3.7260, 3.5661, 3.7684, 3.0797, 3.9612,
        4.4100])


  0%|          | 4/10000 [00:04<3:26:46,  1.24s/it]

tensor([-5.0363, -4.1000, -4.2437, -3.7189, -4.8629, -4.4426, -2.6503, -3.6446,
        -4.2251, -3.2675])
tensor([3.3289, 4.4175, 4.9600, 4.1510, 4.0412, 3.8355, 4.0602, 3.3550, 4.3138,
        4.7552])


  0%|          | 5/10000 [00:06<3:35:54,  1.30s/it]

tensor([-5.3660, -4.3972, -4.5032, -3.9484, -5.1743, -4.7233, -2.8265, -3.9105,
        -4.5152, -3.4715])
tensor([3.5732, 4.6804, 5.2804, 4.4607, 4.3254, 4.0985, 4.3067, 3.5699, 4.5734,
        5.0833])


  0%|          | 6/10000 [00:07<3:45:11,  1.35s/it]

tensor([-4.6258, -3.7384, -3.9088, -3.4423, -4.4845, -4.0577, -2.4381, -3.3791,
        -3.9058, -2.9902])
tensor([3.0597, 4.0932, 4.5324, 3.8321, 3.7173, 3.5554, 3.7503, 3.0645, 3.9390,
        4.3997])


  0%|          | 7/10000 [00:09<3:46:56,  1.36s/it]

tensor([-4.8573, -3.9425, -4.0908, -3.5898, -4.7012, -4.2470, -2.5638, -3.5508,
        -4.1253, -3.1360])
tensor([3.2189, 4.2564, 4.7965, 4.0211, 3.8894, 3.7320, 3.9240, 3.2352, 4.1058,
        4.6462])


  0%|          | 8/10000 [00:10<3:44:48,  1.35s/it]

tensor([-4.7081, -3.8241, -3.9607, -3.5021, -4.5547, -4.1158, -2.4700, -3.3973,
        -3.9579, -3.0422])
tensor([3.1074, 4.1336, 4.6043, 3.8635, 3.7767, 3.6011, 3.8034, 3.1429, 4.0139,
        4.4626])


  0%|          | 9/10000 [00:11<3:34:44,  1.29s/it]

tensor([-5.5951, -4.5848, -4.7367, -4.1629, -5.4513, -4.9446, -2.9886, -4.1247,
        -4.7412, -3.6459])
tensor([3.7681, 4.9434, 5.5554, 4.6793, 4.5207, 4.3276, 4.5038, 3.7168, 4.7962,
        5.3370])


  0%|          | 9/10000 [00:12<3:57:02,  1.42s/it]


tensor([-5.1398, -4.1472, -4.3156, -3.8023, -4.9696, -4.5405, -2.7109, -3.7556,
        -4.3243, -3.3231])
tensor([3.4076, 4.5367, 5.0734, 4.2824, 4.1280, 3.9131, 4.1307, 3.3998, 4.3787,
        4.8945])
robust_rate:  0.0


  0%|          | 1/10000 [00:01<3:11:55,  1.15s/it]

tensor([-5.7561, -4.7162, -4.8470, -4.2337, -5.5544, -5.1136, -3.0551, -4.2258,
        -4.8920, -3.7830])
tensor([3.8136, 5.0689, 5.7041, 4.8276, 4.6655, 4.3892, 4.6648, 3.8262, 4.8958,
        5.4522])


  0%|          | 2/10000 [00:02<3:13:44,  1.16s/it]

tensor([-6.0722, -4.9726, -5.0899, -4.5084, -5.8899, -5.3802, -3.2554, -4.4219,
        -5.1595, -3.9984])
tensor([4.0923, 5.3481, 6.0736, 5.0806, 4.9488, 4.6423, 4.9024, 4.0524, 5.2045,
        5.7404])


  0%|          | 3/10000 [00:03<3:11:00,  1.15s/it]

tensor([-5.5573, -4.5442, -4.7052, -4.1367, -5.4002, -4.9203, -2.9685, -4.0874,
        -4.7279, -3.6313])
tensor([3.7251, 4.9336, 5.5028, 4.6449, 4.5025, 4.2778, 4.5059, 3.7052, 4.7598,
        5.2918])


  0%|          | 4/10000 [00:04<3:11:20,  1.15s/it]

tensor([-5.9541, -4.8816, -5.0184, -4.3994, -5.7564, -5.2837, -3.1712, -4.3435,
        -5.0321, -3.8963])
tensor([3.9835, 5.2431, 5.9192, 4.9552, 4.8202, 4.5517, 4.8022, 3.9793, 5.1147,
        5.6374])


  0%|          | 5/10000 [00:05<3:11:12,  1.15s/it]

tensor([-6.3185, -5.2091, -5.3088, -4.6547, -6.1013, -5.5961, -3.3680, -4.6369,
        -5.3541, -4.1264])
tensor([4.2542, 5.5388, 6.2767, 5.2966, 5.1366, 4.8436, 5.0790, 4.2187, 5.4071,
        5.9985])


  0%|          | 6/10000 [00:06<3:11:05,  1.15s/it]

tensor([-5.5271, -4.5086, -4.6694, -4.1107, -5.3627, -4.8814, -2.9514, -4.0637,
        -4.6996, -3.6047])
tensor([3.7078, 4.9045, 5.4746, 4.6231, 4.4841, 4.2578, 4.4782, 3.6802, 4.7265,
        5.2707])


  0%|          | 7/10000 [00:08<3:14:15,  1.17s/it]

tensor([-5.7545, -4.7121, -4.8541, -4.2602, -5.5820, -5.0712, -3.0776, -4.2375,
        -4.9175, -3.7506])
tensor([3.8692, 5.0711, 5.7390, 4.8103, 4.6537, 4.4370, 4.6539, 3.8528, 4.8960,
        5.5184])


  0%|          | 8/10000 [00:09<3:12:57,  1.16s/it]

tensor([-5.6165, -4.6007, -4.7254, -4.1731, -5.4372, -4.9444, -2.9867, -4.0838,
        -4.7560, -3.6622])
tensor([3.7553, 4.9460, 5.5512, 4.6579, 4.5468, 4.3057, 4.5350, 3.7617, 4.8050,
        5.3350])


  0%|          | 9/10000 [00:10<3:13:07,  1.16s/it]

tensor([-6.6130, -5.4517, -5.5966, -4.9220, -6.4441, -5.8766, -3.5665, -4.9013,
        -5.6393, -4.3433])
tensor([4.5003, 5.8617, 6.6222, 5.5719, 5.3869, 5.1248, 5.3297, 4.4113, 5.6876,
        6.3186])


  0%|          | 9/10000 [00:11<3:35:31,  1.29s/it]


tensor([-6.0707, -4.9477, -5.0997, -4.4919, -5.8772, -5.3893, -3.2427, -4.4597,
        -5.1459, -3.9611])
tensor([4.0740, 5.3736, 6.0501, 5.0971, 4.9161, 4.6389, 4.8861, 4.0386, 5.1911,
        5.7894])
robust_rate:  0.0


  0%|          | 1/10000 [00:01<3:16:08,  1.18s/it]

tensor([-6.6843, -5.5131, -5.6291, -4.9210, -6.4589, -5.9602, -3.5838, -4.9298,
        -5.7116, -4.4202])
tensor([4.4793, 5.9011, 6.6764, 5.6412, 5.4556, 5.1144, 5.4171, 4.4611, 5.7063,
        6.3438])


  0%|          | 2/10000 [00:02<3:15:15,  1.17s/it]

tensor([-6.9950, -5.7660, -5.8662, -5.1961, -6.7910, -6.2239, -3.7786, -5.1227,
        -5.9750, -4.6315])
tensor([4.7559, 6.1738, 7.0415, 5.8886, 5.7320, 5.3605, 5.6513, 4.6848, 6.0122,
        6.6290])


  0%|          | 3/10000 [00:03<3:19:17,  1.20s/it]

tensor([-6.4565, -5.3152, -5.4650, -4.8050, -6.2778, -5.7432, -3.4813, -4.7710,
        -5.5209, -4.2470])
tensor([4.3725, 5.7433, 6.4441, 5.4345, 5.2678, 4.9788, 5.2338, 4.3208, 5.5467,
        6.1598])


  0%|          | 4/10000 [00:04<3:19:34,  1.20s/it]

tensor([-6.8481, -5.6442, -5.7731, -5.0624, -6.6272, -6.1017, -3.6796, -5.0232,
        -5.8188, -4.5093])
tensor([4.6228, 6.0462, 6.8538, 5.7386, 5.5791, 5.2497, 5.5244, 4.5881, 5.8956,
        6.4965])


  0%|          | 5/10000 [00:06<3:23:37,  1.22s/it]

tensor([-7.2883, -6.0396, -6.1281, -5.3761, -7.0475, -6.4841, -3.9206, -5.3762,
        -6.2110, -4.7934])
tensor([4.9513, 6.4136, 7.2939, 6.1482, 5.9627, 5.6014, 5.8670, 4.8817, 6.2561,
        6.9326])


  0%|          | 6/10000 [00:07<3:23:25,  1.22s/it]

tensor([-6.4157, -5.2699, -5.4200, -4.7707, -6.2294, -5.6948, -3.4577, -4.7397,
        -5.4832, -4.2127])
tensor([4.3471, 5.7050, 6.4042, 5.4037, 5.2406, 4.9505, 5.1976, 4.2875, 5.5038,
        6.1282])


  0%|          | 7/10000 [00:08<3:21:32,  1.21s/it]

tensor([-6.6682, -5.4966, -5.6275, -4.9414, -6.4757, -5.9086, -3.5992, -4.9331,
        -5.7239, -4.3768])
tensor([4.5284, 5.8962, 6.6965, 5.6119, 5.4316, 5.1503, 5.3948, 4.4796, 5.6972,
        6.4016])


  0%|          | 8/10000 [00:09<3:23:43,  1.22s/it]

tensor([-6.4962, -5.3550, -5.4676, -4.8246, -6.2939, -5.7478, -3.4882, -4.7512,
        -5.5309, -4.2642])
tensor([4.3861, 5.7362, 6.4715, 5.4288, 5.2946, 4.9918, 5.2461, 4.3633, 5.5742,
        6.1827])


  0%|          | 9/10000 [00:10<3:21:41,  1.21s/it]

tensor([-7.5906, -6.2851, -6.4207, -5.6489, -7.3958, -6.7701, -4.1223, -5.6455,
        -6.5006, -5.0138])
tensor([5.2026, 6.7419, 7.6468, 6.4295, 6.2183, 5.8887, 6.1210, 5.0775, 6.5427,
        7.2599])


  0%|          | 9/10000 [00:12<3:42:45,  1.34s/it]


tensor([-7.0223, -5.7658, -5.9014, -5.1979, -6.8062, -6.2588, -3.7850, -5.1809,
        -5.9862, -4.6146])
tensor([4.7576, 6.2298, 7.0502, 5.9307, 5.7244, 5.3821, 5.6596, 4.6919, 6.0237,
        6.7050])
robust_rate:  0.0


  0%|          | 1/10000 [00:01<3:26:47,  1.24s/it]

tensor([-7.6091, -6.3068, -6.4087, -5.6071, -7.3611, -6.8047, -4.1098, -5.6318,
        -6.5279, -5.0548])
tensor([5.1428, 6.7308, 7.6454, 6.4517, 6.2415, 5.8365, 6.1666, 5.0936, 6.5145,
        7.2330])


  0%|          | 2/10000 [00:02<3:22:59,  1.22s/it]

tensor([-7.9015, -6.5456, -6.6302, -5.8699, -7.6758, -7.0525, -4.2934, -5.8114,
        -6.7757, -5.2540])
tensor([5.4071, 6.9849, 7.9909, 6.6837, 6.5019, 6.0668, 6.3862, 5.3058, 6.8056,
        7.5007])


  0%|          | 3/10000 [00:03<3:19:42,  1.20s/it]

tensor([-7.3521, -6.0832, -6.2218, -5.4707, -7.1519, -6.5630, -3.9919, -5.4521,
        -6.3108, -4.8604])
tensor([5.0172, 6.5500, 7.3817, 6.2211, 6.0301, 5.6770, 5.9591, 4.9338, 6.3306,
        7.0241])


  0%|          | 4/10000 [00:04<3:21:19,  1.21s/it]

tensor([-7.7599, -6.4233, -6.5426, -5.7388, -7.5151, -6.9349, -4.1989, -5.7166,
        -6.6213, -5.1335])
tensor([5.2761, 6.8660, 7.8069, 6.5374, 6.3536, 5.9623, 6.2609, 5.2095, 6.6926,
        7.3743])


  0%|          | 5/10000 [00:06<3:21:02,  1.21s/it]

tensor([-8.2351, -6.8503, -6.9263, -6.0800, -7.9705, -7.3487, -4.4602, -6.0960,
        -7.0477, -5.4433])
tensor([5.6322, 7.2664, 8.2868, 6.9784, 6.7682, 6.3402, 6.6348, 5.5291, 7.0838,
        7.8453])


  0%|          | 6/10000 [00:07<3:18:59,  1.19s/it]

tensor([-7.3112, -6.0378, -6.1768, -5.4364, -7.1034, -6.5146, -3.9683, -5.4208,
        -6.2731, -4.8261])
tensor([4.9918, 6.5116, 7.3418, 6.1903, 6.0029, 5.6486, 5.9229, 4.9005, 6.2876,
        6.9925])


  0%|          | 7/10000 [00:08<3:17:39,  1.19s/it]

tensor([-7.5747, -6.2742, -6.3929, -5.6145, -7.3602, -6.7382, -4.1156, -5.6222,
        -6.5234, -4.9983])
tensor([5.1801, 6.7122, 7.6459, 6.4077, 6.2026, 5.8572, 6.1294, 5.0998, 6.4907,
        7.2753])


  0%|          | 8/10000 [00:09<3:16:07,  1.18s/it]

tensor([-7.3806, -6.1137, -6.2140, -5.4795, -7.1554, -6.5560, -3.9920, -5.4228,
        -6.3103, -4.8702])
tensor([5.0201, 6.5313, 7.3976, 6.2037, 6.0463, 5.6824, 5.9621, 4.9682, 6.3484,
        7.0348])


  0%|          | 9/10000 [00:10<3:15:34,  1.17s/it]

tensor([-8.5533, -7.1080, -7.2313, -6.3635, -8.3328, -7.6493, -4.6697, -6.3778,
        -7.3487, -5.6740])
tensor([5.8941, 7.6087, 8.6560, 7.2742, 7.0375, 6.6407, 6.9010, 5.7344, 7.3841,
        8.1863])


  0%|          | 9/10000 [00:11<3:39:44,  1.32s/it]


tensor([-7.9757, -6.5834, -6.7040, -5.9043, -7.7359, -7.1299, -4.3272, -5.9032,
        -6.8274, -5.2691])
tensor([5.4423, 7.0870, 8.0515, 6.7654, 6.5350, 6.1268, 6.4342, 5.3456, 6.8583,
        7.6221])
robust_rate:  0.0


  0%|          | 1/10000 [00:01<3:08:53,  1.13s/it]

tensor([-8.5414, -7.1071, -7.1944, -6.2993, -8.2708, -7.6562, -4.6399, -6.3392,
        -7.3507, -5.6943])
tensor([5.8119, 7.5673, 8.6224, 7.2686, 7.0332, 6.5640, 6.9222, 5.7313, 7.3295,
        8.1293])


  0%|          | 2/10000 [00:02<3:08:22,  1.13s/it]

tensor([-8.8129, -7.3287, -7.3985, -6.5469, -8.5653, -7.8853, -4.8112, -6.5034,
        -7.5806, -5.8795])
tensor([6.0614, 7.8010, 8.9453, 7.4826, 7.2758, 6.7771, 7.1248, 5.9300, 7.6029,
        8.3776])


  0%|          | 3/10000 [00:03<3:11:38,  1.15s/it]

tensor([-8.2477, -6.8512, -6.9785, -6.1364, -8.0259, -7.3829, -4.5025, -6.1333,
        -7.1007, -5.4738])
tensor([5.6619, 7.3566, 8.3193, 7.0077, 6.7924, 6.3752, 6.6845, 5.5468, 7.1144,
        7.8883])


  0%|          | 4/10000 [00:04<3:12:26,  1.16s/it]

tensor([-8.6676, -7.1990, -7.3089, -6.4120, -8.3990, -7.7648, -4.7159, -6.4071,
        -7.4202, -5.7552])
tensor([5.9264, 7.6822, 8.7557, 7.3328, 7.1248, 6.6716, 6.9943, 5.8281, 7.4863,
        8.2482])


  0%|          | 5/10000 [00:05<3:12:50,  1.16s/it]

tensor([-9.1496, -7.6354, -7.6978, -6.7588, -8.8627, -8.1851, -4.9816, -6.7915,
        -7.8553, -6.0712])
tensor([6.2900, 8.0896, 9.2457, 7.7814, 7.5459, 7.0536, 7.3766, 6.1549, 7.8840,
        8.7264])


  0%|          | 6/10000 [00:07<3:17:38,  1.19s/it]

tensor([-8.2068, -6.8058, -6.9335, -6.1021, -7.9775, -7.3344, -4.4789, -6.1019,
        -7.0630, -5.4396])
tensor([5.6365, 7.3183, 8.2794, 6.9769, 6.7652, 6.3468, 6.6483, 5.5134, 7.0715,
        7.8568])


  0%|          | 7/10000 [00:08<3:16:36,  1.18s/it]

tensor([-8.4895, -7.0590, -7.1649, -6.2932, -8.2525, -7.5752, -4.6364, -6.3174,
        -7.3302, -5.6260])
tensor([5.8370, 7.5352, 8.6044, 7.2105, 6.9804, 6.5707, 6.8712, 5.7256, 7.2916,
        8.1562])


  0%|          | 8/10000 [00:09<3:14:36,  1.17s/it]

tensor([-8.2954, -6.8985, -6.9859, -6.1582, -8.0477, -7.3931, -4.5128, -6.1180,
        -7.1170, -5.4980])
tensor([5.6770, 7.3543, 8.3560, 7.0065, 6.8241, 6.3960, 6.7038, 5.5939, 7.1493,
        7.9157])


  0%|          | 9/10000 [00:10<3:14:11,  1.17s/it]

tensor([-9.4902, -7.9094, -8.0198, -7.0591, -9.2449, -8.5035, -5.2025, -7.0898,
        -8.1746, -6.3159])
tensor([6.5672, 8.4511, 9.6379, 8.0957, 7.8338, 7.3721, 7.6588, 6.3734, 8.2021,
        9.0881])


  0%|          | 9/10000 [00:11<3:35:28,  1.29s/it]


tensor([-8.9080, -7.3837, -7.4897, -6.5964, -8.6456, -7.9815, -4.8573, -6.6106,
        -7.6501, -5.9087])
tensor([6.1114, 7.9235, 9.0285, 7.5823, 7.3267, 6.8543, 7.1898, 5.9833, 7.6734,
        8.5184])
robust_rate:  0.0


  0%|          | 1/10000 [00:01<3:13:30,  1.16s/it]

tensor([-9.4737, -7.9073, -7.9801, -6.9914, -9.1805, -8.5078, -5.1700, -7.0466,
        -8.1735, -6.3338])
tensor([6.4810, 8.4037, 9.5994, 8.0854, 7.8249, 7.2915, 7.6778, 6.3691, 8.1446,
        9.0255])


  0%|          | 2/10000 [00:02<3:18:02,  1.19s/it]

tensor([-9.7260, -8.1121, -8.1691, -7.2244, -9.4560, -8.7207, -5.3310, -7.1973,
        -8.3859, -6.5061])
tensor([6.7172, 8.6223, 9.9019, 8.2839, 8.0520, 7.4893, 7.8651, 6.5547, 8.4023,
        9.2568])


  0%|          | 3/10000 [00:03<3:19:09,  1.20s/it]

tensor([-9.1561, -7.6304, -7.7454, -6.8108, -8.9122, -8.2142, -5.0199, -6.8238,
        -7.9018, -6.0968])
tensor([6.3148, 8.1742, 9.2709, 7.8051, 7.5650, 7.0836, 7.4208, 6.1682, 7.9096,
        8.7637])


  0%|          | 4/10000 [00:04<3:21:06,  1.21s/it]

tensor([-9.5624, -7.9643, -8.0650, -7.0757, -9.2706, -8.5837, -5.2255, -7.0883,
        -8.2077, -6.3685])
tensor([6.5675, 8.4871, 9.6913, 8.1180, 7.8857, 7.3706, 7.7179, 6.4380, 8.2697,
        9.1099])


  0%|          | 5/10000 [00:06<3:23:19,  1.22s/it]

tensor([-10.0633,  -8.4193,  -8.4688,  -7.4367,  -9.7540,  -9.0211,  -5.5018,
         -7.4860,  -8.6611,  -6.6982])
tensor([ 6.9462,  8.9116, 10.2030,  8.5833,  8.3227,  7.7662,  8.1174,  6.7799,
         8.6839,  9.6062])


  0%|          | 6/10000 [00:07<3:27:29,  1.25s/it]

tensor([-9.1024, -7.5738, -7.6903, -6.7678, -8.8516, -8.1543, -4.9895, -6.7830,
        -7.8529, -6.0530])
tensor([6.2812, 8.1249, 9.2170, 7.7635, 7.5275, 7.0450, 7.3736, 6.1264, 7.8553,
        8.7210])


  0%|          | 7/10000 [00:08<3:23:11,  1.22s/it]

tensor([-9.4043, -7.8437, -7.9369, -6.9720, -9.1448, -8.4123, -5.1572, -7.0126,
        -8.1369, -6.2538])
tensor([6.4939, 8.3582, 9.5629, 8.0133, 7.7581, 7.2842, 7.6129, 6.3513, 8.0925,
        9.0371])


  0%|          | 8/10000 [00:09<3:18:54,  1.19s/it]

tensor([-9.2102, -7.6832, -7.7579, -6.8370, -8.9400, -8.2301, -5.0335, -6.8132,
        -7.9238, -6.1257])
tensor([6.3339, 8.1773, 9.3145, 7.8093, 7.6018, 7.1095, 7.4455, 6.2196, 7.9502,
        8.7966])


  0%|          | 9/10000 [00:10<3:17:05,  1.18s/it]

tensor([-10.4103,  -8.6966,  -8.7940,  -7.7426, -10.1408,  -9.3419,  -5.7258,
         -7.7890,  -8.9862,  -6.9461])
tensor([ 7.2280,  9.2774, 10.6016,  8.9023,  8.6151,  8.0897,  8.4019,  7.0003,
         9.0049,  9.9736])


  0%|          | 9/10000 [00:11<3:41:07,  1.33s/it]

tensor([-9.8383, -8.1822, -8.2739, -7.2871, -9.5534, -8.8315, -5.3864, -7.3167,
        -8.4712, -6.5469])
tensor([ 6.7792,  8.7585, 10.0035,  8.3976,  8.1169,  7.5803,  7.9439,  6.6198,
         8.4869,  9.4130])
robust_rate:  0.0



