In [14]:
x = torch.randint(10, (20,))
print(x)
x = np.random.permutation(10)[x]
torch.tensor(x)

tensor([4, 0, 9, 8, 1, 6, 2, 5, 0, 4, 9, 9, 2, 8, 5, 7, 3, 1, 0, 3])


tensor([6, 5, 3, 0, 8, 4, 7, 2, 5, 6, 3, 3, 7, 0, 2, 1, 9, 8, 5, 9])

In [18]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torchvision
import torchvision.transforms as transforms
from torch.autograd.functional import hessian
import torch.utils.data as data_utils
import matplotlib.pyplot as plt
import random

from tqdm.notebook import tqdm 
import pickle
import matplotlib.pyplot as plt
import matplotlib.colors
import seaborn as sns
import os
import copy
from torch.nn.utils import _stateless

batch_size = 1
num_workers = 1
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# device

details = {}
details['use_db'] = 'mnist'
details['result_root_dir']='results/t0/'
details['result_path']='try1_t10_r1'
details['ratio'] = 1
details['book_keep_freq'] = 20
details['g_times'] = 8
details['g_epochs'] = 1
details['alpha_0']= 0.001
details['freq_reduce_by'] = 20
details['freq_reduce_after'] = 100

details['training_step_limit'] = 100000 ## this is to train for max updates per epochs
details['stop_hess_computation'] = 20000 ## Stop computing hessian after calculated these many times

details['g_weight'] = int(details['ratio']*20000/(784+10))
print(f'selected weight:{details["g_weight"]}')

with open(details['result_root_dir']+'details_'+details['result_path']+'.txt', 'w+') as f:
    for key, val in details.items():
        content = key + ' : '+str(val) + '\n'
        f.write(content)
        
torch.manual_seed(3407)
np.random.seed(3407)
torch.cuda.manual_seed_all(3407)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# torch.use_deterministic_algorithms(True)
# os.environ["CUBLAS_WORKSPACE_CONFIG"]=":4096:2"

train_data_all = torchvision.datasets.MNIST(
    root="./data",
    train=True,
    download=True,
    transform=transforms.ToTensor()
)

test_data_all = torchvision.datasets.MNIST(
    root="./data",
    train=False,
    download=True,
    transform=transforms.ToTensor()
)
# print(f'train data:{train_data}')
# print(f'test data:{test_data}')
def get_random_subset(train_data_all, test_data_all, corrupt_label=False):    
    # train_indices = torch.arange(20000)
    test_indices = torch.arange(500)
    train_indices = torch.randint(60000-1, (2000,))
    # print(f'train indices:{train_indices[:10]}')
    train_data = data_utils.Subset(train_data_all, train_indices)
    test_data = data_utils.Subset(test_data_all, test_indices)
    # print(f'train data:{train_data}')
    # print(f'test data:{test_data}')
    def seed_worker(worker_id):
        worker_seed = torch.initial_seed() % 2**32
        np.random.seed(worker_seed)
        random.seed(worker_seed)

    g = torch.Generator()
    g.manual_seed(0)

    train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=batch_size,
        num_workers=num_workers,
        worker_init_fn=seed_worker,
        generator=g,
    )
    test_loader = torch.utils.data.DataLoader(
        test_data,
        batch_size=256,
        num_workers=num_workers,
        worker_init_fn=seed_worker,
        generator=g,
    )
    print(f'train data size:{len(train_loader.dataset)}')
    print(f'test data size:{len(test_loader.dataset)}')
    X_mat, y_mat = torch.Tensor(len(train_loader.dataset),784), torch.Tensor(len(train_loader.dataset)).long()
    for i, (data, label) in enumerate(train_loader):
        X_mat[i] = data.flatten()
        y_mat[i] = label.flatten()
    if corrupt_label:
        y_mat = torch.tensor(np.random.permutation(10)[y_mat])
    print(f'X_mat shape:{X_mat.shape}, y_mat shape:{y_mat.shape}')
    return train_loader, test_loader, X_mat, y_mat


class Net(nn.Module):

    def __init__(self, input_features, hidden_layers, output_size):
        super(Net, self).__init__()
        self.layers = len(hidden_layers) + 1
        self.total_params_len = 0
        self.fc1 = nn.Linear(input_features, hidden_layers[0])
        self.total_params_len += input_features*hidden_layers[0] + hidden_layers[0]
        self.fc2 = nn.Linear(hidden_layers[0], output_size)
        self.total_params_len += hidden_layers[0]*output_size + output_size
        
        ### Others required params
        self.param_list = []

    def forward(self, x):
        x = torch.flatten(x, 1) # flatten all dimensions except the batch dimension
        # print('x shape in forward',x.shape)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
        

class Train_nn:
    
    def __init__(self, input_features, hidden_layers, output_size, lr, dont_decay=False, l2_reg=0.1):
        self.model = Net(input_features, hidden_layers=hidden_layers, output_size=output_size)
        self.model.to(device)
        self.loss_fn = nn.CrossEntropyLoss()
        
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=lr, weight_decay=l2_reg)
        if dont_decay:
            lambda_lr = lambda it: 1
        else:
            lambda_lr = lambda it: 1/(it+1)
        self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lambda_lr)
        
    def get_loss(self, X, y, params=None):
        # if params is not None:
        assert False, "Model not initialized with given params"
        op = self.model(X)
        loss = self.loss_fn(pred, y)
        return loss
    
    def get_gradient(self):
        params = (self.model.parameters())
        grad_norm_sq = torch.tensor(0, dtype=float).to(device)
        # print('grad Norm init:', grad_norm_sq)
        for param in self.model.parameters():
            temp = param.grad.data.pow(2).sum()
            # print(f'param grad norm \n\tsum:{temp.data}')#,\n\tshape:{param.shape}')
            grad_norm_sq += temp
        grad_norm = grad_norm_sq.sqrt()
        return grad_norm.cpu()
    
    def try_operator_norm(self, hess_mat):
        for i in len(hess_mat):
            for j in len(hess_mat[0]):
                torch.unsqueeze(hess_mat[i][i],0)
        hess_tensor_dim = list(hess_mat[0][0].shape)
        hess_tensor_dim += [n*2,n*2]
        hess_mat_np = np.zeros(shape=hess_tensor_dim)
        hess_tensor = torch.tensor(hess_mat_np)
        torch.cat(hess_mat, out=hess_tensor)
        
        hess_mat.reshpe(n*2,n*2)
        hess_norm = torch.linalg.norm(hess_mat, 2)
        assert False, "Not working"
    
    def get_hessian(self, X, y):
        prev_params = copy.deepcopy(list(self.model.parameters()))
        n = self.model.layers
        def local_model(*params):
            # print(f'len of params:{len(params)}')
            # print(f'shape of params[0]:{params[0].shape}')
            # with torch.no_grad():
            #initialize model with given params
            i = 0
            for i, param in enumerate(self.model.parameters()):
                param.data = params[i]
            pred = self.model(X)
            loss = self.loss_fn(pred, y)
            # print(f'loss type:{type(loss)}')
            return loss
        p =list(self.model.parameters())
        hess_mat = hessian(local_model, tuple(p))
        hess_norm = torch.tensor(0.).to(device)
        for i in range(len(hess_mat)):
            for j in range(len(hess_mat[0])):
                hess_norm+= hess_mat[i][j].pow(2).sum()
        
        # print(f'Hess mat len:{len(hess_mat)}')
        # print(f'Hess mat[0] len:{len(hess_mat[0])}')
        # print(f'Hess mat[0][0] shape:{hess_mat[0][0].shape}')
        
        hess_norm = hess_norm.sqrt()
        # print(f'hess norm:{hess_norm}')
        
        # Reinitialize the original params to model
        for i, param in enumerate(self.model.parameters()):
                param.data = prev_params[i]
        
        return hess_norm
    
    def get_hessianv2(self, X,y):
        names = list(n for n, _ in self.model.named_parameters())
        def loss_fun_hess(*params):
            out: torch.Tensor = _stateless.functional_call(self.model, {n: p for n, p in zip(names, params)}, X)
            local_loss = self.loss_fn(out, y)
            return local_loss
        hess_mat = hessian(loss_fun_hess, tuple(self.model.parameters()))
        
        hess_norm = torch.tensor(0.).to(device)
        for i in range(len(hess_mat)):
            for j in range(len(hess_mat[0])):
                hess_norm+= hess_mat[i][j].pow(2).sum()
        hess_norm = hess_norm.sqrt()
        # print(f'v2 hess norm{hess_norm}')
        return hess_norm.cpu()
        
    def fit(self, train_loader, test_loader, epochs, store_grads=False, store_hessian=False, store_gen_err=False, store_weights=False, store_pt_loss=True, store_freq = 20, freq_reduce_by=None, freq_reduce_after=None, fast_X_train=None, fast_y_train=None):
        
        ## For Book keeping results ##
        self.grads_norms = []
        self.param_list = []
        self.hess_norms = []
        self.gen_err = []
        self.train_loss = []
        self.val_loss = []
        self.point_loss = []
        ## Initializing values ##
        terminate_training = False
        store_count = 0
        ## Moving to gpu
        fast_X_train = fast_X_train.to(device)
        fast_y_train = fast_y_train.to(device)
        
        
        for epoch in tqdm(range(epochs), total=epochs, unit="epoch", disable=False):
            if terminate_training == True:
                break
            for batch, (X, y) in tqdm(enumerate(train_loader), total=len(train_loader), unit='batch',disable=True):
                # if batch>300:
                #     terminate_training = True
                #     break
                # print('y shape:',y.shape)
                batch_size = len(y)
                X, y =X.to(device), y.to(device)
                ## assigning the corrupted label
                y = fast_y_train[batch*batch_size: (batch+1)*batch_size]
                ####
                pred = self.model(X)
                loss = self.loss_fn(pred, y)
                # Backpropagation
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                
                ## Saving point loss
                if store_pt_loss and (batch%store_freq==0):
                    self.point_loss.append(loss.item())
                    
                ## Saving the weights
                if store_weights and (batch%store_freq==0):
                    current_params = list(self.model.parameters())
                    self.param_list.append(current_params)
                
                ## computing and saving the gradient
                if store_grads and (batch% store_freq == 0):
                    grad_norm_per_update = self.get_gradient()
                    self.grads_norms.append(grad_norm_per_update)
                    
                ## computing and saving hessian
                if store_hessian and (batch% store_freq==0):
                    self.optimizer.zero_grad()
                    hess_norms_per_update = self.get_hessianv2(X,y)
                    # print(f'\thess norm:{hess_norms_per_update}')
                    self.hess_norms.append(hess_norms_per_update)
                    
                ## computing and storing the generalization error
                if store_gen_err and (batch% store_freq == 0):
                    train_loss, test_loss=0, 0
                    if (fast_y_train is None) or (fast_X_train is None):
                        assert False, "not given fast_X_train and fast_y_train"
                        with torch.no_grad():
                            for sub_batch, (X_local,y_local) in enumerate(train_loader):
                                if epoch==0 and sub_batch> batch: # only taking the visited points to calculate train loss
                                    break
                                X_local, y_local = X_local.to(device), y_local.to(device)
                                pred_local = self.model(X_local)
                                train_loss += self.loss_fn(pred_local, y_local).item()
                        train_loss = train_loss/(batch+1)
                    else:
                        # print('using fast train loss, epoch', epoch)
                        with torch.no_grad():
                            if epoch==0:
                                pred_local = self.model(fast_X_train[:batch+1])
                                train_loss = self.loss_fn(pred_local, fast_y_train[:batch+1]).item()
                                # print(f'train_loss:{train_loss}')
                            else:
                                pred_local = self.model(fast_X_train)
                                train_loss = self.loss_fn(pred_local, fast_y_train).item()
                    with torch.no_grad():
                        for sub_batch, (X_local,y_local) in enumerate(test_loader):
                            X_local, y_local = X_local.to(device), y_local.to(device)
                            pred_local = self.model(X_local)
                            test_loss += self.loss_fn(pred_local, y_local).item()
                    test_batch_size = len(test_loader)
                    # print(f"Number of batches in test:{len(test_loader)}")
                    test_loss = test_loss/ len(test_loader)
                    self.train_loss.append(train_loss)
                    self.val_loss.append(test_loss)
                    self.gen_err.append(train_loss - test_loss)
                    print(f'train loss:{train_loss}',end=' ')
                    print(f'test loss :{test_loss}',end=' ')
                    print(f'gen loss  :{self.gen_err[-1]}')
                    store_count += 1
                    if store_count%freq_reduce_after==0:
                        store_freq += freq_reduce_by
                        
                if batch % 1000 == 0:
                    loss, current = loss.item(), batch * len(X)
                    correct = 0
                    test_loss = 0
                    with torch.no_grad():
                        pred_local = self.model(fast_X_train)
                        correct += (pred_local.argmax(1) == fast_y_train).type(torch.float).sum().item()
                    # print('data points', fast_X_train.shape[0])
                    train_acc = 100* correct/fast_X_train.shape[0]
                    correct=0
                    print(f'\ttrain acc:{train_acc}')
                    with torch.no_grad():
                        for X, y in test_loader:
                            X, y = X.to(device), y.to(device)
                            pred = self.model(X)
                            test_loss += self.loss_fn(pred, y).item()
                            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
                    acc = 100*correct/len(test_loader.dataset)
                    print(f"\ttest acc   :{acc}")#, at batch:{batch}")
                    print(f"\tloss       : {loss:>7f}")
                    print(f'\tlr rate:{self.scheduler.get_last_lr()}')
                self.scheduler.step()
            
def exp_get_lp_sm(train_data_all, test_data_all, op_features, weight = 10, times = 8, epochs = 1, root_dir='', path=None, clear_file = True, freq_reduce_by=10, freq_reduce_after=100):
    grad_list        = []
    hess_norm_list   = []
    if path is not None:
        grad_file_path = root_dir+'grad_'+path
        hess_file_path = root_dir+'hess_'+path
        gen_file_path = root_dir+'gen_'+path
        if clear_file:
            with open(grad_file_path, 'w+') as f:
                f.write('')
            with open(hess_file_path, 'w+') as f:
                f.write('')
            with open(gen_file_path, 'w+') as f:
                f.write('')
            
    for t in range(times):
        train_loader, test_loader, fast_X_train, fast_y_train = get_random_subset(train_data_all, test_data_all)
        print(f'Time:{t}')
        train_model = Train_nn(784, [weight], op_features, lr= details['alpha_0'])
        train_model.fit(train_loader, test_loader, epochs=epochs, store_grads=True, store_hessian=True, store_freq=details['book_keep_freq'],  store_gen_err=True, store_pt_loss=False, store_weights=False, freq_reduce_by = freq_reduce_by, freq_reduce_after=freq_reduce_after, fast_X_train=fast_X_train, fast_y_train=fast_y_train)
        
        with open(grad_file_path,'a+') as f:
            f.write(' '.join([str(grad) for grad in train_model.grads_norms]) + '\n')
        with open(hess_file_path,'a+') as f:
            f.write(' '.join([str(hess) for hess in train_model.hess_norms]) + '\n')
        with open(gen_file_path,'a+') as f:
            f.write(' '.join([str(gen_e) for gen_e in train_model.gen_err]) + '\n')
        
        hess_norm_list.append(train_model.hess_norms)
        grad_list.append(train_model.grads_norms)
         
    return grad_list, hess_norm_list
# grad_list, hess_norm_list=[],[]
# grad_list, hess_norm_list = exp_get_lp_sm(train_data_all, test_data_all, op_features=10, 
#               weight=details['g_weight'], times=details['g_times'], 
#               epochs=details['g_epochs'], root_dir=details['result_root_dir'], 
#               path=details['result_path'], freq_reduce_by=details['freq_reduce_by'], 
#               freq_reduce_after=details['freq_reduce_after'])


selected weight:25


In [19]:
train_loader, test_loader, fast_X_train, fast_y_train = get_random_subset(train_data_all, test_data_all, corrupt_label=True)

train data size:2000
test data size:500
X_mat shape:torch.Size([2000, 784]), y_mat shape:torch.Size([2000])


In [22]:
train_model = Train_nn(784, [10], 10, lr= 1., dont_decay = False, l2_reg=0.000)
train_model.fit(train_loader, test_loader, epochs=100, store_grads=False, store_hessian=False, store_freq=10000,  store_gen_err=False, store_pt_loss=False, store_weights=False, freq_reduce_by = 100, freq_reduce_after=10000, fast_X_train=fast_X_train, fast_y_train=fast_y_train)

  0%|          | 0/100 [00:00<?, ?epoch/s]

	train acc:11.25
	test acc   :8.0
	loss       : 2.402829
	lr rate:[1.0]
	train acc:14.75
	test acc   :6.8
	loss       : 2.376818
	lr rate:[0.000999000999000999]
	train acc:13.55
	test acc   :10.8
	loss       : 2.002281
	lr rate:[0.0004997501249375312]
	train acc:12.5
	test acc   :11.6
	loss       : 2.328550
	lr rate:[0.0003332222592469177]
	train acc:12.45
	test acc   :12.0
	loss       : 2.017007
	lr rate:[0.00024993751562109475]
	train acc:12.05
	test acc   :12.8
	loss       : 2.306098
	lr rate:[0.0001999600079984003]
	train acc:12.25
	test acc   :13.0
	loss       : 2.023251
	lr rate:[0.00016663889351774705]
	train acc:12.1
	test acc   :13.0
	loss       : 2.291227
	lr rate:[0.00014283673760891302]
	train acc:12.15
	test acc   :13.0
	loss       : 2.026544
	lr rate:[0.00012498437695288088]
	train acc:12.15
	test acc   :13.0
	loss       : 2.280087
	lr rate:[0.00011109876680368848]


KeyboardInterrupt: 

In [32]:
train_loader, test_loader, fast_X_train, fast_y_train = get_random_subset(train_data_all, test_data_all)

train data size:2000
test data size:500
X_mat shape:torch.Size([2000, 784]), y_mat shape:torch.Size([2000])


In [35]:
train_model = Train_nn(784, [10], 10, lr= 0.3, dont_decay = True, l2_reg=0.000)
train_model.fit(train_loader, test_loader, epochs=100, store_grads=False, store_hessian=False, store_freq=10000,  store_gen_err=False, store_pt_loss=False, store_weights=False, freq_reduce_by = 100, freq_reduce_after=10000, fast_X_train=fast_X_train, fast_y_train=fast_y_train)

  0%|          | 0/100 [00:00<?, ?epoch/s]

	train acc:1.85
	test acc   :1.2
	loss       : 2.498376
	lr rate:[0.3]
	train acc:23.35
	test acc   :20.6
	loss       : 2.318381
	lr rate:[0.0002997002997002997]
	train acc:25.1
	test acc   :21.6
	loss       : 0.418449
	lr rate:[0.00014992503748125936]
	train acc:27.05
	test acc   :24.0
	loss       : 2.297851
	lr rate:[9.99666777740753e-05]
	train acc:27.35
	test acc   :25.0
	loss       : 0.348010
	lr rate:[7.498125468632842e-05]
	train acc:28.7
	test acc   :25.2
	loss       : 2.291203
	lr rate:[5.998800239952009e-05]
	train acc:29.2
	test acc   :25.8
	loss       : 0.313152
	lr rate:[4.9991668055324116e-05]
	train acc:29.85
	test acc   :26.0
	loss       : 2.284672
	lr rate:[4.2851021282673906e-05]
	train acc:30.1
	test acc   :26.8
	loss       : 0.290971
	lr rate:[3.749531308586426e-05]
	train acc:30.5
	test acc   :26.8
	loss       : 2.279595
	lr rate:[3.3329630041106545e-05]
	train acc:31.2
	test acc   :27.4
	loss       : 0.275081
	lr rate:[2.9997000299970004e-05]
	train acc:31.95
	tes

## for corrupted labels

In [37]:
train_loaderr, test_loaderr, fast_X_trainr, fast_y_trainr = get_random_subset(train_data_all, test_data_all, corrupt_label=True)

train data size:2000
test data size:500
X_mat shape:torch.Size([2000, 784]), y_mat shape:torch.Size([2000])


In [38]:
train_model = Train_nn(784, [10], 10, lr= 0.3, dont_decay = False, l2_reg=0.000)
train_model.fit(train_loaderr, test_loaderr, epochs=100, store_grads=False, store_hessian=False, store_freq=10000,  store_gen_err=False, store_pt_loss=False, store_weights=False, freq_reduce_by = 100, freq_reduce_after=10000, fast_X_train=fast_X_trainr, fast_y_train=fast_y_trainr)

  0%|          | 0/100 [00:00<?, ?epoch/s]

	train acc:8.8
	test acc   :10.0
	loss       : 2.057150
	lr rate:[0.3]
	train acc:10.65
	test acc   :13.2
	loss       : 2.526260
	lr rate:[0.0002997002997002997]
	train acc:10.2
	test acc   :11.6
	loss       : 5.509175
	lr rate:[0.00014992503748125936]
	train acc:10.25
	test acc   :11.6
	loss       : 2.521321
	lr rate:[9.99666777740753e-05]
	train acc:10.25
	test acc   :11.6
	loss       : 5.236142
	lr rate:[7.498125468632842e-05]
	train acc:10.3
	test acc   :11.6
	loss       : 2.518581
	lr rate:[5.998800239952009e-05]
	train acc:10.3
	test acc   :11.6
	loss       : 5.094241
	lr rate:[4.9991668055324116e-05]
	train acc:10.3
	test acc   :11.4
	loss       : 2.516590
	lr rate:[4.2851021282673906e-05]
	train acc:10.35
	test acc   :11.6
	loss       : 5.005261
	lr rate:[3.749531308586426e-05]
	train acc:10.25
	test acc   :11.4
	loss       : 2.515024
	lr rate:[3.3329630041106545e-05]
	train acc:10.3
	test acc   :11.2
	loss       : 4.938525
	lr rate:[2.9997000299970004e-05]
	train acc:10.4
	tes

KeyboardInterrupt: 

In [17]:
torch.randint(10, (9,))

tensor([7, 9, 6, 1, 8, 7, 8, 8, 6])

In [43]:
fast_y_train[:10]

tensor([0, 0, 2, 5, 5, 5, 1, 6, 7, 4])

In [49]:
y_corrupted = fast_y_train[np.random.permutation(len(fast_y_train))]
y_corrupted

tensor([8, 9, 7,  ..., 6, 9, 4])