In [1]:
from google.colab import drive
drive.mount('/content/gdrive', force_remount=True)
import sys
sys.path.append('/content/gdrive/My Drive/Colab Notebooks/Lib_files')

Mounted at /content/gdrive


In [2]:
import torch
print(torch.cuda.device_count())
print(torch.cuda.get_device_name(0))
cuda0 = torch.device('cuda:0')

1
Tesla P100-PCIE-16GB


In [3]:
from psutil import virtual_memory
ram_gb = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

if ram_gb < 20:
  print('Not using a high-RAM runtime')
else:
  print('You are using a high-RAM runtime!')

Your runtime has 54.8 gigabytes of available RAM

You are using a high-RAM runtime!


In [None]:
"### MNIST FRAMEWORK K-FAC solver"
import torch
import torchvision
import torchvision.transforms as transforms
import os
import time
import numpy as np

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from QE_KLD_WRM_correct_classif_implementation import QE_KLD_WRM_Optimizer

from torch.utils.data.dataloader import default_collate

#torch.set_default_tensor_type('torch.cuda.FloatTensor') 
#torch.set_default_tensor_type(torch.DoubleTensor)
torch.multiprocessing.set_start_method('spawn')

torch.backends.cudnn.enabled = True # False
torch.backends.cudnn.benchmark = True

# ------------------------------
# ---- Training parameters -----
n_epochs = 50
opti_type = 'QE_KLD_WRM'
# l_rate = 0.01;
def l_rate_function(epoch_n, iter_n):
    if epoch_n == 1:
        if iter_n < 3:
            return 0.01
        else:
            return 0.01
    elif epoch_n == 2:
        return 0.01
    elif epoch_n >= 3 and epoch_n < 30:
        return 0.01
    elif epoch_n >= 30:
        return 0.01

kfac_clip = 1e-1; KFAC_damping = 1e-02; stat_decay = 0.5

WD = 0.001#1
lambdaa = 0.0#1#1#7 #0.007
batch_size_train = batch_size_test = 512
# ONLY FOR SAVED FILE NAME: beta1 and beta2 are just 2 channels for filename ==
beta1 = WD
beta2 = KFAC_damping
# ====================================================
KFAC_matrix_update_frequency = 30

momentum = 0.0
my_clip_threshold = 2.0
number_inner_SGD_steps = 7
inner_lr_factor = 0.1
force_lr_on_final_step_flag = False
inner_momentum = 0.5
capacity_number_of_prev_nets_stored = 3

log_interval = 200 #int(200 *batch_size_train/8192)
basic_path = '/content/gdrive/My Drive/P_data/results{}_MNIST'.format(opti_type)
error_write_path = '/content/gdrive/My Drive/P_data/Errors/err_{}_MNIST'.format(opti_type)

random_seed = 0
torch.manual_seed(random_seed)

#------------------------------------------------------------------------------
#--------------------------- DATA LOADERS -------------------------------------
#------------------------------------------------------------------------------
def collation_fct(x):
  return  tuple(x_.to(cuda0) for x_ in default_collate(x))

# Data Normalisation parameters
global_data_mean = 0.1307
global_data_std = 0.3081
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('../data_lecun', train=True, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (global_data_mean,), (global_data_std,))
                               ])),
    batch_size=batch_size_train, shuffle=True, num_workers = 0, collate_fn = collation_fct) # pin_memory=True,
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('../data_lecun', train=False, download=True, transform=transforms.Compose([
        transforms.ToTensor(),
    ])),
    batch_size=batch_size_train, shuffle=True, num_workers = 0, collate_fn = collation_fct) # pin_memory=True,

# -----------------------------------------------------------------------------
#-------------------------------- building the NET ----------------------------
scale = 1
class Net(nn.Module):
    def __init__(self, nodes_dropout=False):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 5, kernel_size=5)
        self.conv2 = nn.Conv2d(5, 7, kernel_size=5)
        if nodes_dropout == True:
            self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(7 * 16, 30) # nn.Linear(7 * 16, 30)
        self.fc2 = nn.Linear(30, 10) # nn.Linear(30, 10)
        #initialize parameters structure to map gradients from parameter format to math format and viceversa
        self.parameter_structure = [] #[i.shape for i in self.parameters()]
        self.coarser_param_structure = []
        for item in self.parameters():
            #print('Parameter on device {}'.format(item.device))
            item = item.shape
            number_of_elements = 1
            current_list = []
            for j in item:
                number_of_elements = number_of_elements * j
                current_list.append(j)
            if len(current_list) == 1:
                current_list.append(1)
            self.parameter_structure.append(current_list)
            self.coarser_param_structure.append(number_of_elements)
        self.even_coarser_param_structure = list(np.array([self.coarser_param_structure[i] for i in range(len(self.coarser_param_structure)) if i % 2 == 1]) + np.array([self.coarser_param_structure[i] for i in range(len(self.coarser_param_structure)) if i % 2 == 0]))
        self.number_of_parameters = np.sum(self.coarser_param_structure)

    def forward(self, x, nodes_dropout=False):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        if nodes_dropout == True:
            x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        else:
            x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 112)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x)

network = Net()
network.to(cuda0)
for p in network.parameters():  print('After moving, parameters on device {}'.format(p.device))
#network.load_state_dict(torch.load('./model_0.pth'))
print('the number of parameters is {}'.format(sum([p.numel() for p in network.parameters()])))
# -------------------------------------------

# --------------------------------- DEFINE THE OPTIMISER ----------------------
# -----------------------------------------------------------------------------
def regularized_loss_fct(output,target,network, lambdaa):
    KFAC_matrix_loss = F.cross_entropy(output,target)
    l2_reg = torch.tensor(0., device = cuda0)
    for param in network.parameters():
        l2_reg += torch.norm(param)
    KFAC_matrix_loss += lambdaa * l2_reg
    
    loss_for_gradient = F.cross_entropy(output,target)
    l2_reg = torch.tensor(0., device = cuda0)
    for param in network.parameters():
        l2_reg += torch.norm(param)
    loss_for_gradient += lambdaa * l2_reg
    
    #print('the parameter vector norm is {}'.format(l2_reg))
    return KFAC_matrix_loss, loss_for_gradient, l2_reg


# note that LBFGS computes the same
optimizer = QE_KLD_WRM_Optimizer(network, network_generating_function = Net, lr_function = l_rate_function, momentum = 0.0, 
                                stat_decay = stat_decay, kl_clip = kfac_clip, damping = KFAC_damping, weight_decay = WD, 
                                Ts = KFAC_matrix_update_frequency,
                                Tf = KFAC_matrix_update_frequency, my_clip_threshold = my_clip_threshold,
                                number_inner_SGD_steps = number_inner_SGD_steps,
                                inner_lr_factor = inner_lr_factor,
                                force_lr_on_final_step_flag = force_lr_on_final_step_flag, 
                                inner_momentum = inner_momentum, capacity_number_of_prev_nets_stored = capacity_number_of_prev_nets_stored)

scheduler = None
#scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda = [lambda1, lambda2])

# ------------------------------------------------------------------------------

# track performance of parameters and progress
train_losses = []; train_losses_per_epoch = []; train_accuracy_per_epoch = []

train_accuracy = []; test_accuracy = []; time_per_epoch_ = []; time_per_iter = []
train_counter = []; test_losses = []
test_counter = [i * len(train_loader.dataset) for i in range(n_epochs + 1)]

##initial network

# torch.save(network.state_dict(), './resultsSGD/model.pth')
def save_data_():
    # basic_path = os.path.join(basic_path, '/')
    train_losses_cpu = []; train_losses_per_epoch_cpu = []; train_accuracy_per_epoch_cpu = []
    train_accuracy_cpu = []; test_accuracy_cpu = []; time_per_epoch__cpu = []; time_per_iter_cpu = []
    test_losses_cpu = []

    for trl,tim,trlpe,tracc,traccpe,timpe,teacc,tel in zip(train_losses, time_per_iter, train_losses_per_epoch, train_accuracy,
                                                           train_accuracy_per_epoch, time_per_epoch_, 
                                                           test_accuracy_cpu, test_losses_cpu):
      train_losses_cpu.append(trl.cpu());time_per_iter_cpu.append(tim.cpu());
      train_losses_per_epoch_cpu.append(trlpe.cpu()); train_accuracy_cpu.append(tracc.cpu());
      train_accuracy_per_epoch_cpu.append(traccpe.cpu());  time_per_epoch__cpu.append(timpe.cpu());  test_accuracy_cpu.append(teacc.cpu());
      test_losses_cpu.append(tel.cpu());
    
    np.save(os.path.join(basic_path,
                         '{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_run{}_{}'.format(opti_type, batch_size_train,
                                                    beta1,
                                                    beta2, l_rate_function(40,40), stat_decay, 
                                                    my_clip_threshold,
                                                    number_inner_SGD_steps, inner_lr_factor, force_lr_on_final_step_flag, inner_momentum, capacity_number_of_prev_nets_stored,
                                                    random_seed, 'train_losses')),
            train_losses_cpu)
    
    np.save(os.path.join(basic_path,
                         '{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_run{}_{}'.format(opti_type, batch_size_train,
                                                    beta1,
                                                    beta2, l_rate_function(40,40), stat_decay, 
                                                    my_clip_threshold,
                                                    number_inner_SGD_steps, inner_lr_factor, force_lr_on_final_step_flag, inner_momentum, capacity_number_of_prev_nets_stored,
                                                    random_seed, 'time_per_iter')),
            time_per_iter_cpu)

    np.save(os.path.join(basic_path,
                         '{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_run{}_{}'.format(opti_type, batch_size_train,
                                                    beta1,
                                                    beta2, l_rate_function(40,40), stat_decay, 
                                                    my_clip_threshold, 
                                                    number_inner_SGD_steps, inner_lr_factor, force_lr_on_final_step_flag, inner_momentum, capacity_number_of_prev_nets_stored,
                                                    random_seed, 'train_losses_per_epoch')),
            train_losses_per_epoch_cpu)
    np.save(os.path.join(basic_path,
                         '{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_run{}_{}'.format(opti_type, batch_size_train,
                                                    beta1,
                                                    beta2, l_rate_function(40,40), stat_decay, 
                                                    my_clip_threshold, 
                                                    number_inner_SGD_steps, inner_lr_factor, force_lr_on_final_step_flag, inner_momentum, capacity_number_of_prev_nets_stored,
                                                    random_seed,  'train_accuracy')),
            train_accuracy_cpu)
    np.save(os.path.join(basic_path,
                         '{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_run{}_{}'.format(opti_type, batch_size_train,
                                                    beta1,
                                                    beta2, l_rate_function(40,40), stat_decay, 
                                                    my_clip_threshold, 
                                                    number_inner_SGD_steps, inner_lr_factor, force_lr_on_final_step_flag, inner_momentum, capacity_number_of_prev_nets_stored,
                                                    random_seed,  'train_accuracy_per_epoch')),
            train_accuracy_per_epoch_cpu)
    np.save(os.path.join(basic_path,
                         '{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_run{}_{}'.format(opti_type, batch_size_train,
                                                    beta1,
                                                    beta2, l_rate_function(40,40), stat_decay, 
                                                    my_clip_threshold, 
                                                    number_inner_SGD_steps, inner_lr_factor, force_lr_on_final_step_flag, inner_momentum, capacity_number_of_prev_nets_stored,
                                                    random_seed,  'time_per_epoch_')),
            time_per_epoch__cpu)
    np.save(os.path.join(basic_path,
                         '{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_run{}_{}'.format(opti_type, batch_size_train,
                                                    beta1,
                                                    beta2, l_rate_function(40,40), stat_decay, 
                                                    my_clip_threshold, 
                                                    number_inner_SGD_steps, inner_lr_factor, force_lr_on_final_step_flag, inner_momentum, capacity_number_of_prev_nets_stored,
                                                    random_seed,  'test_accuracy')),
            test_accuracy_cpu)
    np.save(os.path.join(basic_path,
                         '{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_run{}_{}'.format(opti_type, batch_size_train,
                                                    beta1,
                                                    beta2, l_rate_function(40,40), stat_decay, 
                                                    my_clip_threshold, 
                                                    number_inner_SGD_steps, inner_lr_factor, force_lr_on_final_step_flag, inner_momentum, capacity_number_of_prev_nets_stored,
                                                    random_seed,  'test_losses')),
            test_losses_cpu)

def train(epoch, step_counter, log_interval = log_interval):
    network.train()
    correct = 0
    time_epoch = 0
    optimizer.epoch_number = epoch
    # previous_step = np.array([0])
    for batch_idx, (data, target) in enumerate(train_loader):
        step_counter = step_counter + 1
        start = time.time()
        optimizer.zero_grad()
        #data = data.double()
        output = network(data)
        pred = output.data.max(1, keepdim=True)[1]
        correct += pred.eq(target.data.view_as(pred)).sum()

        KFAC_matrix_loss, loss_for_gradient, l2_reg = regularized_loss_fct(output, target, network, lambdaa)
        #t1 = time.time()
        
        # update network weights
        optimizer.zero_grad()
        
        ''' assemble <KFAC matrix> loss to compute KFAC matrix'''
        if optimizer.steps % KFAC_matrix_update_frequency == 0:
            optimizer.acc_stats = True
            KFAC_matrix_loss.backward(retain_graph=True)
        optimizer.acc_stats = False
    
        ''' compute gradient of <Policy loss> (precond by KFAC^{-1}) and then take step'''
        ''' also need to compute and return the gradient for TRUEish F^{-1}g computation'''
        optimizer.zero_grad()
        loss_for_gradient.backward()
        QE_direction = optimizer.step(epoch_number = epoch, data = data, error_write_path = error_write_path)
        loss_value = loss_for_gradient.detach()
        end = time.time()

        time_per_iter.append(start - end)
        time_epoch = time_epoch + (end - start)
        train_losses.append(loss_value)
        if scheduler == None:
            pass
        else:
            scheduler.step()

        if batch_idx % log_interval == 0:
            train_counter.append(
                (batch_idx * 64) + ((epoch - 1) * len(train_loader.dataset)))
            # change the saving path
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), loss_value))
            # print('\n negative evals steps:{} \n positive eval steps:{}\n'.format(negative_eigenvalue_steps,nonnegative_eigenvalue_steps))
            torch.save(network.state_dict(), '/content/gdrive/My Drive/P_data/results{}_MNIST/model.pth'.format(opti_type))
            torch.save(optimizer.state_dict(), '/content/gdrive/My Drive/P_data/results{}_MNIST/optimizer.pth'.format(opti_type))
            print('param norm: {}'.format(l2_reg))

            save_data_()

            #print('the number {} shoudl be (the number above)*lambdaa '
            #      'if reg works'.format(loss.detach()-F.cross_entropy(output,target)))

    accc = 100. * correct / len(train_loader.dataset)
    train_accuracy.append(accc)
    time_per_epoch_.append(time_epoch)
    train_losses_per_epoch.append(loss_value)
    train_accuracy_per_epoch.append(accc)
    save_data_()
    return step_counter

def test():
    network.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        data_iter = iter(test_loader)
        for dummy_index in range(len(test_loader)):
            data, target = next(data_iter)
            #print('Data on device {}'.format(data.device))
            #data = data.double()
            output = network(data)
            test_loss += F.cross_entropy(output, target, size_average=False).detach()
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).sum()
    test_loss /= len(test_loader.dataset)
    test_losses.append(test_loss)
    accc = 100. * correct / len(test_loader.dataset)
    test_accuracy.append(accc)
    print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

test()
step_counter = 0
print('\n', optimizer.m_aa, '\n')
for epoch in range(1, n_epochs + 1):
    t1 = time.time()
    step_counter = train(epoch, step_counter)
    test()
    t2 = time.time()
    print('Took {}s'.format(t2-t1))
print('\nDone!')
save_data = True


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data_lecun/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ../data_lecun/MNIST/raw/train-images-idx3-ubyte.gz to ../data_lecun/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data_lecun/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ../data_lecun/MNIST/raw/train-labels-idx1-ubyte.gz to ../data_lecun/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data_lecun/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ../data_lecun/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data_lecun/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data_lecun/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ../data_lecun/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data_lecun/MNIST/raw

After moving, parameters on device cuda:0
After moving, parameters on device cuda:0
After moving, parameters on device cuda:0
After moving, parameters on device cuda:0
After moving, parameters on device cuda:0
After moving, parameters on device cuda:0
After moving, parameters on device cuda:0
After moving, parameters on device cuda:0
the number of parameters is 4712
The device is cuda:0
The device is cuda:0
The device is cuda:0
The device is cuda:0
The device is cuda:0
The device is cuda:0
The device is cuda:0
The device is cuda:0





Test set: Avg. loss: 2.3062, Accuracy: 1200/10000 (12%)


 {} 



	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at  ../torch/csrc/utils/python_arg_parser.cpp:1050.)
  p.grad.data.add_(self.weight_decay, p.data)
The default behavior has changed from using the upper triangular portion of the matrix by default to using the lower triangular portion.
L, _ = torch.symeig(A, upper=upper)
should be replaced with
L = torch.linalg.eigvalsh(A, UPLO='U' if upper else 'L')
and
L, V = torch.symeig(A, eigenvectors=True)
should be replaced with
L, V = torch.linalg.eigh(A, UPLO='U' if upper else 'L') (Triggered internally at  ../aten/src/ATen/native/BatchLinearAlgebra.cpp:2499.)
  self.m_aa_augmented[m], eigenvectors=True)


param norm: 8.815337181091309


  p = F.softmax(no_softmax_p); q = F.softmax(no_softmax_q)


CLIPPING ACTIVATED in inner solver!
CLIPPING ACTIVATED in inner solver!
CLIPPING ACTIVATED in inner solver!
CLIPPING ACTIVATED in inner solver!
CLIPPING ACTIVATED in inner solver!
CLIPPING ACTIVATED in inner solver!
CLIPPING ACTIVATED in inner solver!
CLIPPING ACTIVATED in inner solver!
CLIPPING ACTIVATED in inner solver!
CLIPPING ACTIVATED in inner solver!
CLIPPING ACTIVATED in inner solver!
CLIPPING ACTIVATED in inner solver!
CLIPPING ACTIVATED in inner solver!
CLIPPING ACTIVATED in inner solver!

Test set: Avg. loss: 2.0384, Accuracy: 4902/10000 (49%)

Took 27.078907251358032s
param norm: 19.171422958374023
