## Over parameterized network Non decreasing lr

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torchvision
import torchvision.transforms as transforms
from torch.autograd.functional import hessian
import torch.utils.data as data_utils
import matplotlib.pyplot as plt
import random

from tqdm.notebook import tqdm 
import pickle
import matplotlib.pyplot as plt
import matplotlib.colors
import seaborn as sns
import os
import copy
from torch.nn.utils import _stateless

batch_size = 1
num_workers = 1
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
# device

details = {}
details['use_db'] = 'mnist'
details['result_root_dir']='results/t0/'
details['result_path']='try1_t8_w25'
details['g_weight'] = [25]
# details['ratio'] = 15
details['book_keep_freq'] = 20
details['g_times'] = 8
details['g_epochs'] = 10000
details['alpha_0']= 0.003
details['freq_reduce_by'] = 20
details['freq_reduce_after'] = 100

details['training_step_limit'] = 2000000 ## this is to train for max updates per epochs
details['stop_hess_computation'] = 1000000 ## Stop computing hessian after calculated these many times


print(f'selected weight:{details["g_weight"]}')

with open(details['result_root_dir']+'details_'+details['result_path']+'.txt', 'w+') as f:
    for key, val in details.items():
        content = key + ' : '+str(val) + '\n'
        f.write(content)
        
torch.manual_seed(3407)
np.random.seed(3407)
torch.cuda.manual_seed_all(3407)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# torch.use_deterministic_algorithms(True)
# os.environ["CUBLAS_WORKSPACE_CONFIG"]=":4096:2"

train_data_all = torchvision.datasets.MNIST(
    root="./data",
    train=True,
    download=True,
    transform=transforms.ToTensor()
)

test_data_all = torchvision.datasets.MNIST(
    root="./data",
    train=False,
    download=True,
    transform=transforms.ToTensor()
)
# print(f'train data:{train_data}')
# print(f'test data:{test_data}')

def get_random_subset(train_data_all, test_data_all):    
    # train_indices = torch.arange(20000)
    test_indices = torch.arange(256)
    train_indices = torch.randint(60000-1, (2000,))
    # print(f'train indices:{train_indices[:10]}')
    train_data = data_utils.Subset(train_data_all, train_indices)
    test_data = data_utils.Subset(test_data_all, test_indices)
    # print(f'train data:{train_data}')
    # print(f'test data:{test_data}')
    def seed_worker(worker_id):
        worker_seed = torch.initial_seed() % 2**32
        np.random.seed(worker_seed)
        random.seed(worker_seed)

    g = torch.Generator()
    g.manual_seed(0)

    train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=batch_size,
        num_workers=num_workers,
        worker_init_fn=seed_worker,
        generator=g,
    )
    test_loader = torch.utils.data.DataLoader(
        test_data,
        batch_size=256,
        num_workers=num_workers,
        worker_init_fn=seed_worker,
        generator=g,
    )
    print(f'train data size:{len(train_loader.dataset)}')
    print(f'test data size:{len(test_loader.dataset)}')
    # X_mat, y_mat = torch.Tensor(len(train_loader.dataset),784), torch.Tensor(len(train_loader.dataset)).long()
    # for i, (data, label) in enumerate(train_loader):
    #     X_mat[i] = data.flatten()
    #     y_mat[i] = label.flatten()
    # print(f'X_mat shape:{X_mat.shape}, y_mat shape:{y_mat.shape}')
    return train_loader, test_loader #, X_mat, y_mat


class Net(nn.Module):

    def __init__(self, input_features, hidden_layers, output_size):
        super(Net, self).__init__()
        self.layers = len(hidden_layers) + 1
        self.total_params_len = 0
        self.fc_layers = nn.ModuleList()
        prev_weight = input_features
        
        for i, weight in enumerate(hidden_layers):
            self.fc_layers.append(nn.Linear(prev_weight, weight))
            self.total_params_len += prev_weight*weight + weight
            prev_weight = weight
        
        self.fc_last = nn.Linear(hidden_layers[-1], output_size)
        self.total_params_len += hidden_layers[-1]*output_size + output_size
        
        ### Others required params
        self.param_list = []

    def forward(self, x):
        x = torch.flatten(x, 1) # flatten all dimensions except the batch dimension
        # print('x shape in forward',x.shape)
        for fc_layer in self.fc_layers:
            x = F.relu(fc_layer(x))
        x = self.fc_last(x)
        return x
    
    def fit(self, X, Y, X_val, Y_val, epochs, batch_size=1, **kwargs):
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(self.parameters())
        

class Train_nn:
    
    def __init__(self, input_features, hidden_layers, output_size, lr, decay=True):
        self.model = Net(input_features, hidden_layers=hidden_layers, output_size=output_size)
        self.model.to(device)
        self.loss_fn = nn.CrossEntropyLoss()
        if decay:
            lr_lambda = lambda it: 1/(it+1)
        else:
            lr_lambda = lambda it: 1
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=lr)
        self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda= lr_lambda)
        
    def get_loss(self, X, y, params=None):
        # if params is not None:
        assert False, "Model not initialized with given params"
        op = self.model(X)
        loss = self.loss_fn(pred, y)
        return loss
    
    def get_gradient(self):
        params = (self.model.parameters())
        grad_norm_sq = torch.tensor(0, dtype=float).to(device)
        # print('grad Norm init:', grad_norm_sq)
        for param in self.model.parameters():
            temp = param.grad.data.pow(2).sum()
            # print(f'param grad norm \n\tsum:{temp.data}')#,\n\tshape:{param.shape}')
            grad_norm_sq += temp
            
        return grad_norm_sq.sqrt().cpu()
    
    def get_gradientv2(self, X, y):
        names = list(n for n, _ in self.model.named_parameters())
        def loss_fun_grad(*params):
            out: torch.Tensor = _stateless.functional_call(self.model, {n: p for n, p in zip(names, params)}, X)
            local_loss = self.loss_fn(out, y)
            return local_loss
        grad_mat = torch.autograd.grad(loss_fun_grad, tuple(self.model.parameters()))
        # print(f'len of hess mat:{len(hess_mat)}')
        # print(f'hess_mat[0] shape:{len(hess_mat[0])}')
        # print(f'hess_mat[0][0] shape:{hess_mat[0][0].shape}')
        grad_norm = torch.tensor(0.).to(device)
        for i in range(len(grad_mat)):
            for j in range(len(grad_mat[0])):
                grad_norm+= grad_mat[i][j].pow(2).sum()
        grad_norm = grad_norm.sqrt()
        # print(f'v2 hess norm{hess_norm}')
        return grad_norm.cpu()
    
    def try_operator_norm(self, hess_mat):
        for i in len(hess_mat):
            for j in len(hess_mat[0]):
                torch.unsqueeze(hess_mat[i][i],0)
        hess_tensor_dim = list(hess_mat[0][0].shape)
        hess_tensor_dim += [n*2,n*2]
        hess_mat_np = np.zeros(shape=hess_tensor_dim)
        hess_tensor = torch.tensor(hess_mat_np)
        torch.cat(hess_mat, out=hess_tensor)
        
        hess_mat.reshpe(n*2,n*2)
        hess_norm = torch.linalg.norm(hess_mat, 2)
        assert False, "Not working"
    
    def get_hessian(self, X, y):
        prev_params = copy.deepcopy(list(self.model.parameters()))
        n = self.model.layers
        def local_model(*params):
            # print(f'len of params:{len(params)}')
            # print(f'shape of params[0]:{params[0].shape}')
            # with torch.no_grad():
            #initialize model with given params
            i = 0
            for i, param in enumerate(self.model.parameters()):
                param.data = params[i]
            pred = self.model(X)
            loss = self.loss_fn(pred, y)
            # print(f'loss type:{type(loss)}')
            return loss
        p =list(self.model.parameters())
        hess_mat = hessian(local_model, tuple(p))
        hess_norm = torch.tensor(0.).to(device)
        for i in range(len(hess_mat)):
            for j in range(len(hess_mat[0])):
                hess_norm+= hess_mat[i][j].pow(2).sum()
        
        # print(f'Hess mat len:{len(hess_mat)}')
        # print(f'Hess mat[0] len:{len(hess_mat[0])}')
        # print(f'Hess mat[0][0] shape:{hess_mat[0][0].shape}')
        
        hess_norm = hess_norm.sqrt()
        # print(f'hess norm:{hess_norm}')
        
        # Reinitialize the original params to model
        for i, param in enumerate(self.model.parameters()):
                param.data = prev_params[i]
        
        return hess_norm
    
    def get_hessianv2(self, X,y):
        names = list(n for n, _ in self.model.named_parameters())
        def loss_fun_hess(*params):
            out: torch.Tensor = _stateless.functional_call(self.model, {n: p for n, p in zip(names, params)}, X)
            local_loss = self.loss_fn(out, y)
            return local_loss
        hess_mat = hessian(loss_fun_hess, tuple(self.model.parameters()))
        # print(f'len of hess mat:{len(hess_mat)}')
        # print(f'hess_mat[0] shape:{len(hess_mat[0])}')
        # print(f'hess_mat[0][0] shape:{hess_mat[0][0].shape}')
        hess_norm = torch.tensor(0.).to(device)
        for i in range(len(hess_mat)):
            for j in range(len(hess_mat[0])):
                hess_norm+= hess_mat[i][j].pow(2).sum()
        hess_norm = hess_norm.sqrt()
        # print(f'v2 hess norm{hess_norm}')
        return hess_norm.cpu()
        
    def fit(self, train_loader, test_loader, epochs, store_grads=False, store_hessian=False, store_gen_err=False, store_weights=False, store_pt_loss=True, store_freq = 20, freq_reduce_by=None, freq_reduce_after=None):
        
        ## For Book keeping results ##
        self.grads_norms = []
        self.grads_normsv2 = []
        self.param_list = []
        self.hess_norms = []
        self.gen_err = []
        self.train_loss = []
        self.val_loss = []
        self.point_loss = []
        ## Initializing values ##
        terminate_training = False
        store_count = 0
        
        for epoch in tqdm(range(epochs), total=epochs, unit="epoch", disable=True):
            if terminate_training == True:
                break
            for batch, (X, y) in tqdm(enumerate(train_loader), total=len(train_loader), unit='batch'):
                if batch>details['training_step_limit']:
                    terminate_training = True
                    break
                
                X, y =X.to(device), y.to(device)
                pred = self.model(X)
                loss = self.loss_fn(pred, y)

                # Backpropagation
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                
                ## Saving point loss
                if store_pt_loss and (batch%store_freq==0):
                    self.point_loss.append(loss.item())
                    
                ## Saving the weights
                if store_weights and (batch%store_freq==0):
                    current_params = tuple(self.model.parameters())
                    self.param_list.append(current_params)
                
                ## computing and saving the gradient
                if store_grads and (batch% store_freq == 0):
                    # store_count += 1
                    # # print(f'\tstore_freq:{store_freq}, batch:{batch}')
                    # if store_count%freq_reduce_after==0:
                    #     store_freq += freq_reduce_by
                    #     # print(f'store freq:{store_freq}, batch:{batch}')
                    grad_norm_per_update = self.get_gradient()
                    print('grad:', grad_norm_per_update)
                    # print('\tgrad norm:', grad_norm_per_update)
                    self.grads_norms.append(grad_norm_per_update)
                    # self.grads_normsv2.append(self.get_gradientv2(X,y))
                ## computing and saving hessian
                if store_hessian and (batch% store_freq==0):
                    #assert False, "Not implemented"
                    self.optimizer.zero_grad()
                    hess_val = self.get_hessianv2(X,y)
                    print('hess:',hess_val)
                    self.hess_norms.append(hess_val)
                    store_count += 1
                    if store_count%freq_reduce_after==0:
                        store_freq += freq_reduce_by
                
                ## computing and storing the generalization error
                if store_gen_err and (batch% store_freq == 0):
                    assert False, "fix reducing freq to get it working and fastX, fasty"
                    train_loss, test_loss, point_loss=0, 0, 0
                    with torch.no_grad():
                        for sub_batch, (X_local,y_local) in enumerate(train_loader):
                            # if sub_batch> batch: # only taking the encountered points to calculate train loss
                            #     break
                            X_local, y_local = X_local.to(device), y_local.to(device)
                            pred_local = self.model(X_local)
                            train_loss += self.loss_fn(pred_local, y_local).item()
                    train_loss = train_loss/(batch+1)
                    with torch.no_grad():
                        for sub_batch, (X_local,y_local) in enumerate(test_loader):
                            X_local, y_local = X_local.to(device), y_local.to(device)
                            pred_local = self.model(X_local)
                            test_loss += self.loss_fn(pred_local, y_local).item()
                    test_batch_size = len(test_loader)
                    # print(f"Number of batches in test:{len(test_loader)}")
                    test_loss = test_loss/ len(test_loader)
                    self.train_loss.append(train_loss)
                    self.val_loss.append(test_loss)
                
                if batch % 1000 == 0:
                    loss, current = loss.item(), batch * len(X)
                    correct = 0
                    test_loss = 0
                    with torch.no_grad():
                        for X, y in test_loader:
                            X, y = X.to(device), y.to(device)
                            pred = self.model(X)
                            test_loss += self.loss_fn(pred, y).item()
                            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
                    acc = 100*correct/len(test_loader.dataset)
                    print(f"\taccuracy:{acc}")#, at batch:{batch}")
                    print(f"\tloss: {loss:>7f}")
                
                    # print(f'Learning rate:{self.scheduler.get_last_lr()}')
                self.scheduler.step()
            
def exp_get_lp_sm(train_data_all, test_data_all, op_features, weight = 10, times = 8, epochs = 1, root_dir='', path=None, clear_file = True, freq_reduce_by=10, freq_reduce_after=100):
    grad_list        = []
    hess_norm_list   = []
    if path is not None:
        grad_file_path = root_dir+'grad_'+path
        hess_file_path = root_dir+'hess_'+path
        # gen_file_path = root_dir+'gen_'+path
        if clear_file:
            with open(grad_file_path, 'w+') as f:
                f.write('')
            with open(hess_file_path, 'w+') as f:
                f.write('')
    
    train_loader, test_loader = get_random_subset(train_data_all, test_data_all)
    for t in range(times):
        print(f'Time:{t}')
        train_model = Train_nn(784, weight, op_features, lr= details['alpha_0'], decay=False)
        train_model.fit(train_loader, test_loader, epochs=epochs, store_grads=True, store_hessian=True, store_freq=details['book_keep_freq'],  store_gen_err=False, store_pt_loss=False, store_weights=False, freq_reduce_by = freq_reduce_by, freq_reduce_after=freq_reduce_after, )
        
        with open(grad_file_path,'a+') as f:
            f.write(' '.join([str(grad) for grad in train_model.grads_norms]) + '\n')
        with open(hess_file_path,'a+') as f:
            f.write(' '.join([str(hess) for hess in train_model.hess_norms]) + '\n')
        
        hess_norm_list.append(train_model.hess_norms)
        grad_list.append(train_model.grads_norms)
         
    return grad_list, hess_norm_list

grad_list, hess_norm_list = exp_get_lp_sm(train_data_all, test_data_all, op_features=10, 
              weight=details['g_weight'], times=details['g_times'], 
              epochs=details['g_epochs'], root_dir=details['result_root_dir'], 
              path=details['result_path'], freq_reduce_by=details['freq_reduce_by'], 
              freq_reduce_after=details['freq_reduce_after'])


selected weight:[25]
train data size:2000
test data size:256
Time:0


  0%|          | 0/2000 [00:00<?, ?batch/s]

grad: tensor(4.5351, dtype=torch.float64)
hess: tensor(58.4720)
	accuracy:5.859375
	loss: 2.299834
grad: tensor(3.5895, dtype=torch.float64)
hess: tensor(39.1836)
grad: tensor(2.2147, dtype=torch.float64)
hess: tensor(26.5113)
grad: tensor(3.9376, dtype=torch.float64)
hess: tensor(50.1822)
grad: tensor(2.2974, dtype=torch.float64)
hess: tensor(26.8003)
grad: tensor(2.9156, dtype=torch.float64)
hess: tensor(33.6059)
grad: tensor(5.0878, dtype=torch.float64)
hess: tensor(49.7428)
grad: tensor(3.5337, dtype=torch.float64)
hess: tensor(44.7176)
grad: tensor(3.6666, dtype=torch.float64)
hess: tensor(45.3936)
grad: tensor(2.7725, dtype=torch.float64)
hess: tensor(30.5925)
grad: tensor(4.3674, dtype=torch.float64)
hess: tensor(45.4752)
grad: tensor(2.2625, dtype=torch.float64)
hess: tensor(44.0066)
grad: tensor(3.6647, dtype=torch.float64)
hess: tensor(35.5010)
grad: tensor(3.1759, dtype=torch.float64)
hess: tensor(45.3310)
grad: tensor(3.1716, dtype=torch.float64)
hess: tensor(33.1775)
grad:

  0%|          | 0/2000 [00:00<?, ?batch/s]

grad: tensor(0.6732, dtype=torch.float64)
hess: tensor(12.1683)
	accuracy:80.859375
	loss: 0.038415
grad: tensor(2.7152, dtype=torch.float64)
hess: tensor(16.9440)
grad: tensor(3.7142, dtype=torch.float64)
hess: tensor(25.6924)
grad: tensor(4.2961, dtype=torch.float64)
hess: tensor(53.0776)
grad: tensor(10.2782, dtype=torch.float64)
hess: tensor(77.2385)
grad: tensor(9.0738, dtype=torch.float64)
hess: tensor(63.0244)
grad: tensor(5.6897, dtype=torch.float64)
hess: tensor(43.7116)
grad: tensor(6.2583, dtype=torch.float64)
hess: tensor(37.6650)
grad: tensor(5.0936, dtype=torch.float64)
hess: tensor(52.7293)
grad: tensor(6.3582, dtype=torch.float64)
hess: tensor(48.2746)
grad: tensor(0.9966, dtype=torch.float64)
hess: tensor(14.3808)
grad: tensor(11.2264, dtype=torch.float64)
hess: tensor(73.9123)
grad: tensor(0.9698, dtype=torch.float64)
hess: tensor(15.8960)
grad: tensor(5.6511, dtype=torch.float64)
hess: tensor(40.1757)
grad: tensor(2.9539, dtype=torch.float64)
hess: tensor(32.3606)
gr

  0%|          | 0/2000 [00:00<?, ?batch/s]

grad: tensor(0.0593, dtype=torch.float64)
hess: tensor(1.4905)
	accuracy:87.109375
	loss: 0.002547
grad: tensor(1.1570, dtype=torch.float64)
hess: tensor(9.7434)
grad: tensor(1.8118, dtype=torch.float64)
hess: tensor(17.3874)
grad: tensor(3.2804, dtype=torch.float64)
hess: tensor(56.6832)
grad: tensor(11.3960, dtype=torch.float64)
hess: tensor(101.9069)
grad: tensor(13.9606, dtype=torch.float64)
hess: tensor(94.2925)
grad: tensor(5.4576, dtype=torch.float64)
hess: tensor(49.4632)
grad: tensor(8.0192, dtype=torch.float64)
hess: tensor(48.9318)
grad: tensor(1.9218, dtype=torch.float64)
hess: tensor(30.1544)
grad: tensor(7.4399, dtype=torch.float64)
hess: tensor(63.2250)
grad: tensor(0.1346, dtype=torch.float64)
hess: tensor(2.6641)
grad: tensor(15.8370, dtype=torch.float64)
hess: tensor(103.1075)
grad: tensor(0.1872, dtype=torch.float64)
hess: tensor(4.1127)
grad: tensor(3.8975, dtype=torch.float64)
hess: tensor(36.9925)
grad: tensor(0.9333, dtype=torch.float64)
hess: tensor(14.3537)
gra

  0%|          | 0/2000 [00:00<?, ?batch/s]

grad: tensor(0.0133, dtype=torch.float64)
hess: tensor(0.3750)
	accuracy:90.234375
	loss: 0.000504
grad: tensor(1.5034, dtype=torch.float64)
hess: tensor(25.9389)
grad: tensor(3.2425, dtype=torch.float64)
hess: tensor(61.9475)
grad: tensor(11.5457, dtype=torch.float64)
hess: tensor(59.0236)
grad: tensor(4.9679, dtype=torch.float64)
hess: tensor(51.4653)
grad: tensor(0.7729, dtype=torch.float64)
hess: tensor(17.6038)
grad: tensor(8.7204, dtype=torch.float64)
hess: tensor(76.5834)
grad: tensor(1.0467, dtype=torch.float64)
hess: tensor(11.8190)
grad: tensor(0.0759, dtype=torch.float64)
hess: tensor(1.8382)
grad: tensor(0.0308, dtype=torch.float64)
hess: tensor(0.7116)
grad: tensor(0.1770, dtype=torch.float64)
hess: tensor(3.4407)
grad: tensor(2.5265, dtype=torch.float64)
hess: tensor(34.4412)
grad: tensor(0.2842, dtype=torch.float64)
hess: tensor(4.3476)
grad: tensor(3.6161, dtype=torch.float64)
hess: tensor(44.7855)
grad: tensor(0.1913, dtype=torch.float64)
hess: tensor(3.5110)
grad: ten

  0%|          | 0/2000 [00:00<?, ?batch/s]

grad: tensor(0.0045, dtype=torch.float64)
hess: tensor(0.1325)
	accuracy:91.015625
	loss: 0.000160
grad: tensor(1.2360, dtype=torch.float64)
hess: tensor(23.3221)
grad: tensor(3.0189, dtype=torch.float64)
hess: tensor(62.1040)
grad: tensor(12.3685, dtype=torch.float64)
hess: tensor(62.3671)
grad: tensor(4.3480, dtype=torch.float64)
hess: tensor(50.0266)
grad: tensor(0.3639, dtype=torch.float64)
hess: tensor(9.2279)
grad: tensor(9.7236, dtype=torch.float64)
hess: tensor(86.6660)
grad: tensor(0.7911, dtype=torch.float64)
hess: tensor(9.5734)
grad: tensor(0.0434, dtype=torch.float64)
hess: tensor(1.1162)
grad: tensor(0.0132, dtype=torch.float64)
hess: tensor(0.3217)
grad: tensor(0.0991, dtype=torch.float64)
hess: tensor(2.0703)
grad: tensor(1.8367, dtype=torch.float64)
hess: tensor(28.3656)
grad: tensor(0.1925, dtype=torch.float64)
hess: tensor(3.1722)
grad: tensor(3.7434, dtype=torch.float64)
hess: tensor(49.2356)
grad: tensor(0.1275, dtype=torch.float64)
hess: tensor(2.4867)
grad: tenso

  0%|          | 0/2000 [00:00<?, ?batch/s]

grad: tensor(0.0018, dtype=torch.float64)
hess: tensor(0.0542)
	accuracy:91.40625
	loss: 0.000060
grad: tensor(1.0255, dtype=torch.float64)
hess: tensor(20.6394)
grad: tensor(2.7201, dtype=torch.float64)
hess: tensor(59.5393)
grad: tensor(12.8666, dtype=torch.float64)
hess: tensor(60.9015)
grad: tensor(3.7952, dtype=torch.float64)
hess: tensor(47.6078)
grad: tensor(0.2062, dtype=torch.float64)
hess: tensor(5.5783)
grad: tensor(10.4202, dtype=torch.float64)
hess: tensor(93.8042)
grad: tensor(0.6474, dtype=torch.float64)
hess: tensor(8.2042)
grad: tensor(0.0286, dtype=torch.float64)
hess: tensor(0.7663)
grad: tensor(0.0074, dtype=torch.float64)
hess: tensor(0.1874)
grad: tensor(0.0651, dtype=torch.float64)
hess: tensor(1.4286)
grad: tensor(1.3419, dtype=torch.float64)
hess: tensor(22.2155)
grad: tensor(0.1406, dtype=torch.float64)
hess: tensor(2.4419)
grad: tensor(3.6830, dtype=torch.float64)
hess: tensor(50.2570)
grad: tensor(0.1014, dtype=torch.float64)
hess: tensor(2.1399)
grad: tenso

  0%|          | 0/2000 [00:00<?, ?batch/s]

grad: tensor(0.0008, dtype=torch.float64)
hess: tensor(0.0247)
	accuracy:91.40625
	loss: 0.000026
grad: tensor(0.4831, dtype=torch.float64)
hess: tensor(6.6571)
grad: tensor(13.6404, dtype=torch.float64)
hess: tensor(142.7451)
grad: tensor(3.2497, dtype=torch.float64)
hess: tensor(44.0412)
grad: tensor(0.6178, dtype=torch.float64)
hess: tensor(14.1735)
grad: tensor(0.0111, dtype=torch.float64)
hess: tensor(0.2798)
grad: tensor(0.0200, dtype=torch.float64)
hess: tensor(0.5536)
grad: tensor(0.1219, dtype=torch.float64)
hess: tensor(2.6078)
grad: tensor(0.0449, dtype=torch.float64)
hess: tensor(1.1378)
grad: tensor(0.1054, dtype=torch.float64)
hess: tensor(1.9113)
grad: tensor(0.3723, dtype=torch.float64)
hess: tensor(4.2067)
grad: tensor(0.0292, dtype=torch.float64)
hess: tensor(0.7262)
grad: tensor(1.9577, dtype=torch.float64)
hess: tensor(32.8240)
	accuracy:88.28125
	loss: 0.009677
grad: tensor(0.0703, dtype=torch.float64)
hess: tensor(1.2370)
grad: tensor(0.0319, dtype=torch.float64)


  0%|          | 0/2000 [00:00<?, ?batch/s]

grad: tensor(0.0004, dtype=torch.float64)
hess: tensor(0.0118)
	accuracy:91.40625
	loss: 0.000011
grad: tensor(0.4200, dtype=torch.float64)
hess: tensor(6.0446)
grad: tensor(13.9863, dtype=torch.float64)
hess: tensor(146.5914)
grad: tensor(2.7982, dtype=torch.float64)
hess: tensor(40.5312)
grad: tensor(0.5318, dtype=torch.float64)
hess: tensor(12.7566)
grad: tensor(0.0079, dtype=torch.float64)
hess: tensor(0.2057)
grad: tensor(0.0151, dtype=torch.float64)
hess: tensor(0.4284)
grad: tensor(0.0830, dtype=torch.float64)
hess: tensor(1.8437)
grad: tensor(0.0363, dtype=torch.float64)
hess: tensor(0.9505)
grad: tensor(0.0816, dtype=torch.float64)
hess: tensor(1.5351)
grad: tensor(0.3337, dtype=torch.float64)
hess: tensor(3.8528)
grad: tensor(0.0226, dtype=torch.float64)
hess: tensor(0.5808)
grad: tensor(1.5938, dtype=torch.float64)
hess: tensor(28.2442)
	accuracy:88.671875
	loss: 0.007423
grad: tensor(0.0529, dtype=torch.float64)
hess: tensor(0.9614)
grad: tensor(0.0260, dtype=torch.float64)

  0%|          | 0/2000 [00:00<?, ?batch/s]

grad: tensor(0.0002, dtype=torch.float64)
hess: tensor(0.0060)
	accuracy:91.40625
	loss: 0.000005
grad: tensor(0.3816, dtype=torch.float64)
hess: tensor(5.6503)
grad: tensor(14.4766, dtype=torch.float64)
hess: tensor(150.0987)
grad: tensor(2.3487, dtype=torch.float64)
hess: tensor(33.1273)
grad: tensor(0.4653, dtype=torch.float64)
hess: tensor(11.5811)
grad: tensor(0.0056, dtype=torch.float64)
hess: tensor(0.1508)
grad: tensor(0.0120, dtype=torch.float64)
hess: tensor(0.3529)
grad: tensor(0.0580, dtype=torch.float64)
hess: tensor(1.3285)
grad: tensor(0.0297, dtype=torch.float64)
hess: tensor(0.8041)
grad: tensor(0.0633, dtype=torch.float64)
hess: tensor(1.2279)
grad: tensor(0.2950, dtype=torch.float64)
hess: tensor(3.4906)
grad: tensor(0.0187, dtype=torch.float64)
hess: tensor(0.4921)
grad: tensor(1.3132, dtype=torch.float64)
hess: tensor(24.4141)
	accuracy:88.671875
	loss: 0.005933
grad: tensor(0.0424, dtype=torch.float64)
hess: tensor(0.7906)
grad: tensor(0.0230, dtype=torch.float64)

  0%|          | 0/2000 [00:00<?, ?batch/s]

grad: tensor(9.3450e-05, dtype=torch.float64)
hess: tensor(0.0032)
	accuracy:91.015625
	loss: 0.000003
grad: tensor(0.3601, dtype=torch.float64)
hess: tensor(5.5104)
grad: tensor(14.5520, dtype=torch.float64)
hess: tensor(151.3354)
grad: tensor(1.9719, dtype=torch.float64)
hess: tensor(30.9321)
grad: tensor(0.3937, dtype=torch.float64)
hess: tensor(10.1300)
grad: tensor(0.0042, dtype=torch.float64)
hess: tensor(0.1165)
grad: tensor(0.0097, dtype=torch.float64)
hess: tensor(0.2934)
grad: tensor(0.0405, dtype=torch.float64)
hess: tensor(0.9538)
grad: tensor(0.0248, dtype=torch.float64)
hess: tensor(0.6884)
grad: tensor(0.0505, dtype=torch.float64)
hess: tensor(1.0069)
grad: tensor(0.2689, dtype=torch.float64)
hess: tensor(3.2494)
grad: tensor(0.0157, dtype=torch.float64)
hess: tensor(0.4235)
grad: tensor(1.1412, dtype=torch.float64)
hess: tensor(21.9109)
	accuracy:89.453125
	loss: 0.004795
grad: tensor(0.0338, dtype=torch.float64)
hess: tensor(0.6368)
grad: tensor(0.0207, dtype=torch.flo

  0%|          | 0/2000 [00:00<?, ?batch/s]

grad: tensor(5.1343e-05, dtype=torch.float64)
hess: tensor(0.0018)
	accuracy:91.015625
	loss: 0.000002
grad: tensor(0.6678, dtype=torch.float64)
hess: tensor(9.2663)
grad: tensor(13.9366, dtype=torch.float64)
hess: tensor(138.0523)
grad: tensor(0.0281, dtype=torch.float64)
hess: tensor(0.8808)
grad: tensor(0.0033, dtype=torch.float64)
hess: tensor(0.0923)
grad: tensor(9.6263, dtype=torch.float64)
hess: tensor(124.9718)
grad: tensor(0.0170, dtype=torch.float64)
hess: tensor(0.4313)
grad: tensor(0.3928, dtype=torch.float64)
hess: tensor(9.8151)
grad: tensor(0.2424, dtype=torch.float64)
hess: tensor(2.9879)
grad: tensor(1.3539, dtype=torch.float64)
hess: tensor(19.3193)
grad: tensor(0.0629, dtype=torch.float64)
hess: tensor(1.0726)
	accuracy:89.0625
	loss: 0.003833
grad: tensor(6.7324, dtype=torch.float64)
hess: tensor(65.4687)
grad: tensor(0.9122, dtype=torch.float64)
hess: tensor(22.6086)
grad: tensor(11.1125, dtype=torch.float64)
hess: tensor(80.0655)
grad: tensor(1.2082, dtype=torch.f

  0%|          | 0/2000 [00:00<?, ?batch/s]

grad: tensor(3.1098e-05, dtype=torch.float64)
hess: tensor(0.0011)
	accuracy:91.40625
	loss: 0.000001
grad: tensor(0.6163, dtype=torch.float64)
hess: tensor(8.8203)
grad: tensor(12.9525, dtype=torch.float64)
hess: tensor(135.8791)
grad: tensor(0.0206, dtype=torch.float64)
hess: tensor(0.6577)
grad: tensor(0.0028, dtype=torch.float64)
hess: tensor(0.0808)
grad: tensor(8.6227, dtype=torch.float64)
hess: tensor(120.5818)
grad: tensor(0.0137, dtype=torch.float64)
hess: tensor(0.3553)
grad: tensor(0.3619, dtype=torch.float64)
hess: tensor(9.2467)
grad: tensor(0.2230, dtype=torch.float64)
hess: tensor(2.8169)
grad: tensor(1.2879, dtype=torch.float64)
hess: tensor(18.8604)
grad: tensor(0.0560, dtype=torch.float64)
hess: tensor(0.9954)
	accuracy:89.0625
	loss: 0.003297
grad: tensor(6.7860, dtype=torch.float64)
hess: tensor(67.7225)
grad: tensor(0.9038, dtype=torch.float64)
hess: tensor(23.0043)
grad: tensor(10.8327, dtype=torch.float64)
hess: tensor(82.2995)
grad: tensor(1.1045, dtype=torch.fl

  0%|          | 0/2000 [00:00<?, ?batch/s]

grad: tensor(2.0783e-05, dtype=torch.float64)
hess: tensor(0.0007)
	accuracy:91.40625
	loss: 0.000001
grad: tensor(0.5768, dtype=torch.float64)
hess: tensor(8.4897)
grad: tensor(12.2112, dtype=torch.float64)
hess: tensor(134.4156)
grad: tensor(0.0149, dtype=torch.float64)
hess: tensor(0.4835)
grad: tensor(0.0023, dtype=torch.float64)
hess: tensor(0.0687)
grad: tensor(8.0286, dtype=torch.float64)
hess: tensor(108.7518)
grad: tensor(0.0103, dtype=torch.float64)
hess: tensor(0.2723)
grad: tensor(0.3258, dtype=torch.float64)
hess: tensor(8.5074)
grad: tensor(0.2042, dtype=torch.float64)
hess: tensor(2.6030)
grad: tensor(1.3462, dtype=torch.float64)
hess: tensor(21.1110)
grad: tensor(0.0482, dtype=torch.float64)
hess: tensor(0.8773)
	accuracy:89.453125
	loss: 0.002776
grad: tensor(6.8510, dtype=torch.float64)
hess: tensor(70.0010)
grad: tensor(0.8788, dtype=torch.float64)
hess: tensor(23.0609)
grad: tensor(11.3279, dtype=torch.float64)
hess: tensor(84.8255)
grad: tensor(1.0248, dtype=torch.

  0%|          | 0/2000 [00:00<?, ?batch/s]

grad: tensor(8.5815e-06, dtype=torch.float64)
hess: tensor(0.0003)
	accuracy:91.796875
	loss: 0.000000
grad: tensor(0.5393, dtype=torch.float64)
hess: tensor(8.1592)
grad: tensor(11.7546, dtype=torch.float64)
hess: tensor(134.2470)
grad: tensor(0.0116, dtype=torch.float64)
hess: tensor(0.3821)
grad: tensor(0.0020, dtype=torch.float64)
hess: tensor(0.0592)
grad: tensor(7.5289, dtype=torch.float64)
hess: tensor(108.7723)
grad: tensor(0.0079, dtype=torch.float64)
hess: tensor(0.2150)
grad: tensor(0.3054, dtype=torch.float64)
hess: tensor(8.1306)
grad: tensor(0.1915, dtype=torch.float64)
hess: tensor(2.4973)
grad: tensor(1.3151, dtype=torch.float64)
hess: tensor(21.1770)
grad: tensor(0.0415, dtype=torch.float64)
hess: tensor(0.7715)
	accuracy:89.453125
	loss: 0.002340
grad: tensor(6.7879, dtype=torch.float64)
hess: tensor(67.9265)
grad: tensor(0.8550, dtype=torch.float64)
hess: tensor(22.9148)
grad: tensor(11.2584, dtype=torch.float64)
hess: tensor(87.0411)
grad: tensor(0.9431, dtype=torch

  0%|          | 0/2000 [00:00<?, ?batch/s]

grad: tensor(7.6855e-06, dtype=torch.float64)
hess: tensor(0.0003)
	accuracy:91.796875
	loss: 0.000000
grad: tensor(0.4965, dtype=torch.float64)
hess: tensor(7.4739)
grad: tensor(11.3509, dtype=torch.float64)
hess: tensor(132.5065)
grad: tensor(0.0084, dtype=torch.float64)
hess: tensor(0.2831)
grad: tensor(0.0017, dtype=torch.float64)
hess: tensor(0.0513)
grad: tensor(6.6930, dtype=torch.float64)
hess: tensor(107.9810)
grad: tensor(0.0060, dtype=torch.float64)
hess: tensor(0.1670)
grad: tensor(0.2884, dtype=torch.float64)
hess: tensor(7.8232)
grad: tensor(0.1749, dtype=torch.float64)
hess: tensor(2.2721)
grad: tensor(1.2384, dtype=torch.float64)
hess: tensor(20.5200)
grad: tensor(0.0360, dtype=torch.float64)
hess: tensor(0.6825)
	accuracy:89.453125
	loss: 0.001986
grad: tensor(6.7627, dtype=torch.float64)
hess: tensor(69.5947)
grad: tensor(0.8208, dtype=torch.float64)
hess: tensor(22.4855)
grad: tensor(10.8996, dtype=torch.float64)
hess: tensor(89.2255)
grad: tensor(0.8356, dtype=torch

  0%|          | 0/2000 [00:00<?, ?batch/s]

grad: tensor(1.9132e-06, dtype=torch.float64)
hess: tensor(0.0001)
	accuracy:91.796875
	loss: 0.000000
grad: tensor(0.4447, dtype=torch.float64)
hess: tensor(15.5862)
grad: tensor(0.9078, dtype=torch.float64)
hess: tensor(17.4332)
grad: tensor(9.2988, dtype=torch.float64)
hess: tensor(114.4373)
grad: tensor(0.0046, dtype=torch.float64)
hess: tensor(0.1539)
grad: tensor(0.0046, dtype=torch.float64)
hess: tensor(0.1290)
grad: tensor(0.0135, dtype=torch.float64)
hess: tensor(0.3096)
grad: tensor(0.0471, dtype=torch.float64)
hess: tensor(1.2938)
grad: tensor(0.4094, dtype=torch.float64)
hess: tensor(9.2054)
	accuracy:89.453125
	loss: 0.001617
grad: tensor(0.0021, dtype=torch.float64)
hess: tensor(0.0728)
grad: tensor(0.7774, dtype=torch.float64)
hess: tensor(21.8480)
grad: tensor(0.1803, dtype=torch.float64)
hess: tensor(6.9330)
grad: tensor(1.3064, dtype=torch.float64)
hess: tensor(31.2436)
grad: tensor(0.0236, dtype=torch.float64)
hess: tensor(0.8529)
grad: tensor(0.6996, dtype=torch.flo

  0%|          | 0/2000 [00:00<?, ?batch/s]

grad: tensor(1.2634e-06, dtype=torch.float64)
hess: tensor(7.1804e-05)
	accuracy:92.1875
	loss: 0.000000
grad: tensor(0.3638, dtype=torch.float64)
hess: tensor(13.1110)
grad: tensor(0.7871, dtype=torch.float64)
hess: tensor(15.5545)
grad: tensor(9.0772, dtype=torch.float64)
hess: tensor(114.6154)
grad: tensor(0.0042, dtype=torch.float64)
hess: tensor(0.1448)
grad: tensor(0.0036, dtype=torch.float64)
hess: tensor(0.1028)
grad: tensor(0.0111, dtype=torch.float64)
hess: tensor(0.2594)
grad: tensor(0.0442, dtype=torch.float64)
hess: tensor(1.2379)
grad: tensor(0.3568, dtype=torch.float64)
hess: tensor(8.1640)
	accuracy:90.625
	loss: 0.001420
grad: tensor(0.0017, dtype=torch.float64)
hess: tensor(0.0615)
grad: tensor(0.7572, dtype=torch.float64)
hess: tensor(21.6653)
grad: tensor(0.1719, dtype=torch.float64)
hess: tensor(6.7210)
grad: tensor(1.2539, dtype=torch.float64)
hess: tensor(30.5877)
grad: tensor(0.0163, dtype=torch.float64)
hess: tensor(0.6013)
grad: tensor(0.6620, dtype=torch.floa

  0%|          | 0/2000 [00:00<?, ?batch/s]

grad: tensor(8.8151e-07, dtype=torch.float64)
hess: tensor(4.6367e-05)
	accuracy:91.796875
	loss: 0.000000
grad: tensor(0.2996, dtype=torch.float64)
hess: tensor(11.0666)
grad: tensor(0.6958, dtype=torch.float64)
hess: tensor(14.0876)
grad: tensor(9.1151, dtype=torch.float64)
hess: tensor(116.2242)
grad: tensor(0.0038, dtype=torch.float64)
hess: tensor(0.1311)
grad: tensor(0.0028, dtype=torch.float64)
hess: tensor(0.0826)
grad: tensor(0.0090, dtype=torch.float64)
hess: tensor(0.2150)
grad: tensor(0.0404, dtype=torch.float64)
hess: tensor(1.1515)
grad: tensor(0.3027, dtype=torch.float64)
hess: tensor(7.0692)
	accuracy:90.234375
	loss: 0.001174
grad: tensor(0.0015, dtype=torch.float64)
hess: tensor(0.0532)
grad: tensor(0.7281, dtype=torch.float64)
hess: tensor(21.2986)
grad: tensor(0.1747, dtype=torch.float64)
hess: tensor(6.9358)
grad: tensor(1.2467, dtype=torch.float64)
hess: tensor(30.9175)
grad: tensor(0.0111, dtype=torch.float64)
hess: tensor(0.4165)
grad: tensor(0.6092, dtype=torch

  0%|          | 0/2000 [00:00<?, ?batch/s]

grad: tensor(6.7558e-07, dtype=torch.float64)
hess: tensor(3.5862e-05)
	accuracy:91.796875
	loss: 0.000000
grad: tensor(0.2556, dtype=torch.float64)
hess: tensor(9.6597)
grad: tensor(0.6043, dtype=torch.float64)
hess: tensor(12.5468)
grad: tensor(8.0874, dtype=torch.float64)
hess: tensor(111.0664)
grad: tensor(0.0034, dtype=torch.float64)
hess: tensor(0.1190)
grad: tensor(0.0022, dtype=torch.float64)
hess: tensor(0.0644)
grad: tensor(0.0072, dtype=torch.float64)
hess: tensor(0.1748)
grad: tensor(0.0394, dtype=torch.float64)
hess: tensor(1.1435)
grad: tensor(0.2555, dtype=torch.float64)
hess: tensor(6.0478)
	accuracy:90.234375
	loss: 0.001045
grad: tensor(0.0013, dtype=torch.float64)
hess: tensor(0.0458)
grad: tensor(0.7111, dtype=torch.float64)
hess: tensor(21.1711)
grad: tensor(0.1683, dtype=torch.float64)
hess: tensor(6.7851)
grad: tensor(1.2062, dtype=torch.float64)
hess: tensor(30.4094)
grad: tensor(0.0082, dtype=torch.float64)
hess: tensor(0.3116)
grad: tensor(0.5777, dtype=torch.

  0%|          | 0/2000 [00:00<?, ?batch/s]

grad: tensor(4.7744e-07, dtype=torch.float64)
hess: tensor(2.6116e-05)
	accuracy:91.796875
	loss: 0.000000
grad: tensor(0.2472, dtype=torch.float64)
hess: tensor(9.4988)
grad: tensor(0.5164, dtype=torch.float64)
hess: tensor(10.9845)
grad: tensor(8.0734, dtype=torch.float64)
hess: tensor(112.1919)
grad: tensor(0.0032, dtype=torch.float64)
hess: tensor(0.1133)
grad: tensor(0.0016, dtype=torch.float64)
hess: tensor(0.0482)
grad: tensor(0.0061, dtype=torch.float64)
hess: tensor(0.1502)
grad: tensor(0.0368, dtype=torch.float64)
hess: tensor(1.0856)
grad: tensor(0.2255, dtype=torch.float64)
hess: tensor(5.4345)
	accuracy:90.234375
	loss: 0.000865
grad: tensor(0.0011, dtype=torch.float64)
hess: tensor(0.0419)
grad: tensor(0.6813, dtype=torch.float64)
hess: tensor(20.6979)
grad: tensor(0.1784, dtype=torch.float64)
hess: tensor(7.2838)
grad: tensor(1.1186, dtype=torch.float64)
hess: tensor(28.9008)
grad: tensor(0.0056, dtype=torch.float64)
hess: tensor(0.2165)
grad: tensor(0.5561, dtype=torch.

  0%|          | 0/2000 [00:00<?, ?batch/s]

grad: tensor(3.4903e-07, dtype=torch.float64)
hess: tensor(1.9080e-05)
	accuracy:92.1875
	loss: 0.000000
grad: tensor(0.1911, dtype=torch.float64)
hess: tensor(7.5057)
grad: tensor(0.4617, dtype=torch.float64)
hess: tensor(10.0191)
grad: tensor(7.5235, dtype=torch.float64)
hess: tensor(109.4864)
grad: tensor(0.0029, dtype=torch.float64)
hess: tensor(0.1030)
grad: tensor(0.0012, dtype=torch.float64)
hess: tensor(0.0382)
grad: tensor(0.0051, dtype=torch.float64)
hess: tensor(0.1262)
grad: tensor(0.0356, dtype=torch.float64)
hess: tensor(1.0678)
grad: tensor(0.1887, dtype=torch.float64)
hess: tensor(4.6226)
	accuracy:90.234375
	loss: 0.000753
grad: tensor(0.0009, dtype=torch.float64)
hess: tensor(0.0359)
grad: tensor(0.6167, dtype=torch.float64)
hess: tensor(19.1527)
grad: tensor(0.1579, dtype=torch.float64)
hess: tensor(6.5558)
grad: tensor(1.0948, dtype=torch.float64)
hess: tensor(28.7319)
grad: tensor(0.5935, dtype=torch.float64)
hess: tensor(21.2442)
grad: tensor(0.5436, dtype=torch.f

  0%|          | 0/2000 [00:00<?, ?batch/s]

grad: tensor(3.0002e-07, dtype=torch.float64)
hess: tensor(1.6662e-05)
	accuracy:92.1875
	loss: 0.000000
grad: tensor(0.0746, dtype=torch.float64)
hess: tensor(1.8584)
grad: tensor(5.9774, dtype=torch.float64)
hess: tensor(70.8735)
grad: tensor(0.4380, dtype=torch.float64)
hess: tensor(7.2381)
grad: tensor(0.0024, dtype=torch.float64)
hess: tensor(0.0691)
grad: tensor(0.2084, dtype=torch.float64)
hess: tensor(6.2951)
grad: tensor(0.0343, dtype=torch.float64)
hess: tensor(1.0461)
grad: tensor(0.0070, dtype=torch.float64)
hess: tensor(0.2170)
	accuracy:90.234375
	loss: 0.000646
grad: tensor(0.0077, dtype=torch.float64)
hess: tensor(0.1944)
grad: tensor(2.9026, dtype=torch.float64)
hess: tensor(73.4673)
grad: tensor(0.3881, dtype=torch.float64)
hess: tensor(11.0089)
grad: tensor(0.5685, dtype=torch.float64)
hess: tensor(20.7445)
grad: tensor(0.5137, dtype=torch.float64)
hess: tensor(13.7284)
grad: tensor(0.0073, dtype=torch.float64)
hess: tensor(0.2658)
grad: tensor(0.4202, dtype=torch.fl

  0%|          | 0/2000 [00:00<?, ?batch/s]

grad: tensor(2.2303e-07, dtype=torch.float64)
hess: tensor(1.1958e-05)
	accuracy:92.1875
	loss: 0.000000
grad: tensor(0.0672, dtype=torch.float64)
hess: tensor(1.6964)
grad: tensor(5.7417, dtype=torch.float64)
hess: tensor(69.9043)
grad: tensor(0.4245, dtype=torch.float64)
hess: tensor(7.1069)
grad: tensor(0.0019, dtype=torch.float64)
hess: tensor(0.0576)
grad: tensor(0.2029, dtype=torch.float64)
hess: tensor(6.2087)
grad: tensor(0.0330, dtype=torch.float64)
hess: tensor(1.0197)
grad: tensor(0.0057, dtype=torch.float64)
hess: tensor(0.1789)
	accuracy:90.234375
	loss: 0.000538
grad: tensor(0.0070, dtype=torch.float64)
hess: tensor(0.1791)
grad: tensor(2.6474, dtype=torch.float64)
hess: tensor(69.7256)
grad: tensor(0.3531, dtype=torch.float64)
hess: tensor(10.2007)
grad: tensor(0.5442, dtype=torch.float64)
hess: tensor(20.1694)
grad: tensor(0.4856, dtype=torch.float64)
hess: tensor(13.1930)
grad: tensor(0.0070, dtype=torch.float64)
hess: tensor(0.2565)
grad: tensor(0.3964, dtype=torch.fl

  0%|          | 0/2000 [00:00<?, ?batch/s]

grad: tensor(1.7944e-07, dtype=torch.float64)
hess: tensor(9.6984e-06)
	accuracy:92.578125
	loss: 0.000000
grad: tensor(0.0592, dtype=torch.float64)
hess: tensor(1.5147)
grad: tensor(5.4753, dtype=torch.float64)
hess: tensor(68.5719)
grad: tensor(0.4298, dtype=torch.float64)
hess: tensor(7.2767)
grad: tensor(0.0017, dtype=torch.float64)
hess: tensor(0.0520)
grad: tensor(0.1952, dtype=torch.float64)
hess: tensor(6.0541)
grad: tensor(0.0301, dtype=torch.float64)
hess: tensor(0.9424)
grad: tensor(0.0049, dtype=torch.float64)
hess: tensor(0.1578)
	accuracy:90.625
	loss: 0.000483
grad: tensor(0.0063, dtype=torch.float64)
hess: tensor(0.1649)
grad: tensor(2.6253, dtype=torch.float64)
hess: tensor(70.1268)
grad: tensor(0.3123, dtype=torch.float64)
hess: tensor(9.1292)
grad: tensor(0.5310, dtype=torch.float64)
hess: tensor(20.0129)
grad: tensor(0.4595, dtype=torch.float64)
hess: tensor(12.6706)
grad: tensor(0.0067, dtype=torch.float64)
hess: tensor(0.2506)
grad: tensor(0.3595, dtype=torch.floa

  0%|          | 0/2000 [00:00<?, ?batch/s]

grad: tensor(1.4719e-07, dtype=torch.float64)
hess: tensor(8.1444e-06)
	accuracy:92.578125
	loss: 0.000000
grad: tensor(0.0573, dtype=torch.float64)
hess: tensor(1.4838)
grad: tensor(5.3355, dtype=torch.float64)
hess: tensor(67.9918)
grad: tensor(0.4160, dtype=torch.float64)
hess: tensor(7.1360)
grad: tensor(0.0014, dtype=torch.float64)
hess: tensor(0.0432)
grad: tensor(0.1875, dtype=torch.float64)
hess: tensor(5.8963)
grad: tensor(0.0306, dtype=torch.float64)
hess: tensor(0.9715)
grad: tensor(0.0039, dtype=torch.float64)
hess: tensor(0.1268)
	accuracy:90.625
	loss: 0.000418
grad: tensor(0.0057, dtype=torch.float64)
hess: tensor(0.1500)
grad: tensor(2.4853, dtype=torch.float64)
hess: tensor(68.1781)
grad: tensor(0.2697, dtype=torch.float64)
hess: tensor(8.0227)
grad: tensor(0.5203, dtype=torch.float64)
hess: tensor(19.9057)
grad: tensor(0.4297, dtype=torch.float64)
hess: tensor(12.0314)
grad: tensor(0.0063, dtype=torch.float64)
hess: tensor(0.2348)
grad: tensor(0.3357, dtype=torch.floa

  0%|          | 0/2000 [00:00<?, ?batch/s]

grad: tensor(1.1275e-07, dtype=torch.float64)
hess: tensor(5.8590e-06)
	accuracy:92.1875
	loss: 0.000000
grad: tensor(0.0523, dtype=torch.float64)
hess: tensor(1.3696)
grad: tensor(5.1632, dtype=torch.float64)
hess: tensor(67.3815)
grad: tensor(0.4196, dtype=torch.float64)
hess: tensor(7.2689)
grad: tensor(0.0012, dtype=torch.float64)
hess: tensor(0.0370)
grad: tensor(0.1814, dtype=torch.float64)
hess: tensor(5.7617)
grad: tensor(0.0283, dtype=torch.float64)
hess: tensor(0.9108)
grad: tensor(0.0034, dtype=torch.float64)
hess: tensor(0.1116)
	accuracy:90.625
	loss: 0.000358
grad: tensor(0.0052, dtype=torch.float64)
hess: tensor(0.1383)
grad: tensor(2.3603, dtype=torch.float64)
hess: tensor(66.3696)
grad: tensor(0.2577, dtype=torch.float64)
hess: tensor(7.7683)
grad: tensor(0.5093, dtype=torch.float64)
hess: tensor(19.7856)
grad: tensor(0.4108, dtype=torch.float64)
hess: tensor(11.6494)
grad: tensor(0.0063, dtype=torch.float64)
hess: tensor(0.2374)
grad: tensor(0.3199, dtype=torch.float6

  0%|          | 0/2000 [00:00<?, ?batch/s]

grad: tensor(1.0061e-07, dtype=torch.float64)
hess: tensor(5.2404e-06)
	accuracy:92.1875
	loss: 0.000000
grad: tensor(0.0454, dtype=torch.float64)
hess: tensor(1.2004)
grad: tensor(4.8942, dtype=torch.float64)
hess: tensor(65.6725)
grad: tensor(0.4051, dtype=torch.float64)
hess: tensor(7.1020)
grad: tensor(0.0010, dtype=torch.float64)
hess: tensor(11.3047)
grad: tensor(0.0059, dtype=torch.float64)
hess: tensor(0.2260)
grad: tensor(0.2971, dtype=torch.float64)
hess: tensor(9.1746)


  0%|          | 0/2000 [00:00<?, ?batch/s]

grad: tensor(8.2714e-08, dtype=torch.float64)
hess: tensor(4.2832e-06)
	accuracy:92.1875
	loss: 0.000000
grad: tensor(0.0447, dtype=torch.float64)
hess: tensor(1.1983)
grad: tensor(4.6963, dtype=torch.float64)
hess: tensor(64.3478)
grad: tensor(0.4031, dtype=torch.float64)
hess: tensor(7.1403)
grad: tensor(0.0009, dtype=torch.float64)
hess: tensor(0.0279)
grad: tensor(0.1641, dtype=torch.float64)
hess: tensor(5.3287)
grad: tensor(0.0720, dtype=torch.float64)
hess: tensor(1.0781)
grad: tensor(0.0746, dtype=torch.float64)
hess: tensor(2.0099)
	accuracy:90.234375
	loss: 0.000284
grad: tensor(0.0048, dtype=torch.float64)
hess: tensor(0.1441)
grad: tensor(0.0364, dtype=torch.float64)
hess: tensor(1.2977)
grad: tensor(0.8498, dtype=torch.float64)
hess: tensor(24.6118)
grad: tensor(1.1930, dtype=torch.float64)
hess: tensor(29.5503)
grad: tensor(3.1529, dtype=torch.float64)
hess: tensor(67.8077)
grad: tensor(2.4049, dtype=torch.float64)
hess: tensor(47.6072)


  0%|          | 0/2000 [00:00<?, ?batch/s]

grad: tensor(6.7558e-08, dtype=torch.float64)
hess: tensor(3.5184e-06)
	accuracy:91.796875
	loss: 0.000000
grad: tensor(5.4086, dtype=torch.float64)
hess: tensor(111.3229)
grad: tensor(0.0410, dtype=torch.float64)
hess: tensor(1.4278)
grad: tensor(0.0021, dtype=torch.float64)
hess: tensor(0.0852)
grad: tensor(0.0026, dtype=torch.float64)
hess: tensor(0.1104)
grad: tensor(0.0685, dtype=torch.float64)
hess: tensor(1.0347)
grad: tensor(0.0684, dtype=torch.float64)
hess: tensor(1.8607)
	accuracy:90.234375
	loss: 0.000251
grad: tensor(0.0042, dtype=torch.float64)
hess: tensor(0.1245)
grad: tensor(0.0352, dtype=torch.float64)
hess: tensor(1.2685)
grad: tensor(0.7994, dtype=torch.float64)
hess: tensor(23.4442)
grad: tensor(1.1545, dtype=torch.float64)
hess: tensor(29.0213)
grad: tensor(3.2148, dtype=torch.float64)
hess: tensor(69.3204)
grad: tensor(2.3322, dtype=torch.float64)
hess: tensor(47.7333)


  0%|          | 0/2000 [00:00<?, ?batch/s]

grad: tensor(5.6885e-08, dtype=torch.float64)
hess: tensor(2.9817e-06)
	accuracy:91.796875
	loss: 0.000000
grad: tensor(5.0280, dtype=torch.float64)
hess: tensor(108.1678)
grad: tensor(0.0362, dtype=torch.float64)
hess: tensor(1.2769)
grad: tensor(0.0021, dtype=torch.float64)
hess: tensor(0.0842)
grad: tensor(0.0025, dtype=torch.float64)
hess: tensor(0.1063)


KeyboardInterrupt: 