<img src="../rsag_convex.png" alt="algoconvex" />
<img src="../x_update.png" alt="x_update" />
<img src="../mean.png" alt="mean" />
<img src="../rsag_composite.png" alt="algo" />

__Parameters :__
- $\alpha$: (1-$\alpha$) weight of aggregated x on current state, i.e. momentum
- $\lambda$: learning rate
- $\beta$: change for aggregated x
- $p_k$ termination probability



In [1]:
from  torch.optim import Adam, SGD, RMSprop
import torch
from torch.nn import functional as F
from torch import nn
import numpy as np
import warnings
import copy
import torch.utils.data as data_utils
# import EarlyStopping
# from pytorchtools import EarlyStopping

In [2]:
print('Using PyTorch version:', torch.__version__)
if torch.cuda.is_available():
    print('Using GPU, device name:', torch.cuda.get_device_name(0))
    device = torch.device('cuda')
else:
    print('No GPU found, using CPU instead.') 
    device = torch.device('cpu')

Using PyTorch version: 2.1.2+cu121
Using GPU, device name: NVIDIA GeForce GTX 1660 Ti


In [3]:
import path
import sys
sys.path.append('../')

from models import MLP
from optimizers import RSAG, AccSGD
from util import DataLoader
from util import calc_accuracy, train_model, HPScheduler


ModuleNotFoundError: No module named 'models'

### Run MLP:
__TUNE DIFFERENT OPTIMIZERS__:
- Nesterov w/ weight decay w/ Scheduled LR (SGD)
- Momentum w/ weight decay w/ Scheduled LR (SGD)
- Basic SGD
- Adagrad?
- Adam?



In [None]:
data_loader = DataLoader()
loaders = data_loader.get_loaders()
# loss_function = torch.nn.CrossEntropyLoss()
# model = MLP().to(device)
# print(model)

# optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, nesterov=True, momentum=0.9)
# optimizer = RSAG(model.parameters(), lr=1e-4, alpha=.9, beta=9e-5)


In [None]:
class MLP(nn.Module): 
    """
    Very simple
    2 hidden layer MLP with 512 and 512 hidden units respectively
    """
    def __init__(self, input_dim=28*28, output_dim=10, h=512):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(input_dim, h),
            nn.Linear(h, output_dim)
        )

    def forward(self, x):
        # x = x.view(-1,28*28)
        # x = F.relu(self.layers[0](x))
        return self.layers(x)
    


### RSAG Pytorch

In [None]:
class RSAG(torch.optim.Optimizer):
    r"""
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float): learning rate (lambda) (required)
        kappa (float): lambda  (default: 1000)
        xi (float, optional): statistical advantage parameter (default: 10)
        smallConst (float, optional): any value <=1 (default: 0.7)
    Example:
        >>> from RSAG import *
        >>> optimizer = RSAG(model.parameters(), lr=0.1, kappa = 1000.0, xi = 10.0)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()
    """

    def __init__(self, 
                 params, 
                 lr=0.01, 
                 alpha = 0.1, 
                 beta = 0.1): #, smallConst = 0.7, weight_decay=0):
        #defaults = dict(lr=lr, kappa=kappa, xi, smallConst=smallConst,
                        # weight_decay=weight_decay)
        
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if alpha < 0.0 or alpha > 1.0:
            raise ValueError("Invalid alpha: {}".format(alpha))
        if beta < 0.0:
            raise ValueError("Invalid beta: {}".format(beta))
        
        defaults = dict(lr=[lr], alpha=[alpha], beta=[beta], step=0)

        
        # ASSUME PARAMS ARE ALREADY PROJECTED
        super(RSAG, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(RSAG, self).__setstate__(state)

    def step(self, closure=None):
        """ Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()
            print('hereee')
        # print('param_group',self.param_groups)
        for group in self.param_groups:
            # weight_decay = group['weight_decay']
            lr = group['lr']
            alpha, beta = group['alpha'], group['beta']
            alpha_bar = 1.0-alpha
            momentum_buffer_list = []
            # step = group['step']
            # print('group', group)

            # INITIALIZE GROUPS
            # params_with_grad, d_p_list, momentum_buffer_list = [], [], []
            # for p in group['params']:
            #     if p.grad is not None:
            #         params_with_grad.append(p)
            #         d_p_list.append(p.grad)
            #         # if p.grad.is_sparse:
            #         #     has_sparse_grad = True

            #         state = self.state[p]
            #         if 'momentum_aggr' not in state:
            #             momentum_buffer_list.append(None)
            #         else:
            #             momentum_buffer_list.append(state['momentum_buffer'])
            
            # UPDATE GROUPS
            for p in group['params']:
                if p.grad is None:
                    continue

                d_w = p.grad.data
                param_state = self.state[p]
                w = p.data

                
                momentum_aggr = torch.tensor(w - beta[-1]*d_w)
                
                if group['step'] == 0:
                    alpha.append(alpha[-1])
                    beta.append(beta[-1])
                    lr.append(lr[-1])

                    if param_state['momentum_aggr'] is None:
                        param_state['momentum_aggr'] = {}

                    param_state['momentum_aggr'][p] = momentum_aggr

                    alpha_val = (1-alpha[-1])*beta[-2]+alpha[-1]*lr[-1]
                    p.data.add_(d_w, alpha=-alpha_val)
                    
                else:
                    alpha.append(alpha[-1])
                    beta.append(beta[-1])
                    lr.append(lr[-1])

                    mu = 1-alpha[-1] + alpha[-1]/alpha[-2]
                    xi = (1-alpha[-1])*beta[-2] + alpha[-1]*lr[-1]
                    eta = (1-alpha[-2])*alpha[-1]/alpha[-2]

                    p.data.mul_(mu)
                    p.data.add_(w, alpha=xi)
                    p.data.add_(param_state['momentum_aggr'][p], alpha=-eta)
                    param_state['momentum_aggr'][p] = momentum_aggr

                    # param_state['momentum_aggr'] = momentum_aggr

                # if weight_decay != 0:
                #     grad_d.add_(weight_decay, p.data)
                
                # if 'momentum_aggr' not in param_state:
                #     param_state['momentum_aggr'] = copy.deepcopy(p.data)
                #     param_state['prev_momentum_aggr'] = copy.deepcopy(p.data)
                # buf = param_state['momentum_aggr']
                # aggr_grad = (buf-param_state['prev_momentum_aggr'])
                # aggr_grad.mul_(alpha_bar)
                # aggr_grad.add_(d_w, alpha=alpha)
                
                # param_state['prev_momentum_aggr'] = copy.deepcopy(buf)
                
                # # Update momentum buffer:'
                # buf.mul_(alpha_bar)
                # buf.add_(p.data, alpha=alpha)
                # buf.add_(aggr_grad, alpha=-beta)
                
                # p.data.add_(aggr_grad, alpha=-lr)
                

        return loss

In [None]:
def train_model(
                model,
                loss_function,
                optimizer,
                loaders,
                device='cpu',
                verbose=True,
                save_path=None,
                log_path=None,
                n_epochs=200,
                print_every=1
                ):
    log = {}
    log['loss'], log['accuracy'] = [], []
    log['v_loss'], log['v_accuracy'] = [], []
    
    log['v_loss_std'], log['v_accuracy_std'] = [], []
    log['loss_std'], log['accuracy_std'] = [], []

    best_acc = 0.0
    

    # initialize the early_stopping object
    # early_stopping = EarlyStopping(patience=20, verbose=True)

    model.to(device)
    model.train()
    for epoch in range(0,n_epochs):
        print(f'Starting Epoch {epoch+1}')

        current_loss, total_acc = [], []
        v_loss, v_acc = [], []

        for data, targets in loaders['train']:
            # inputs, targets = data
            # inputs, targets = inputs.float(), targets.float()
            # targets = targets.reshape((targets.shape[0], 1))
            
            # Copy data and targets to GPU
            data = data.to(device)
            targets = targets.to(device)
            
            optimizer.zero_grad()

            outputs = model(data)

            # Calculate the loss
            loss = loss_function(outputs, targets)
            # current_loss += loss

            # Backpropagation
            loss.backward()
            optimizer.step()

            current_loss.append(loss.item())
            total_acc.append(calc_accuracy(outputs, targets))
            
        # Validation
        model.eval()

        for data, targets in loaders['valid']:
            data = data.to(device)
            targets = targets.to(device)
            
            outputs = model(data)

            loss = loss_function(outputs, targets)
            v_loss.append(loss.item())
            v_acc.append(calc_accuracy(outputs, targets))        

        print(f'Epoch {epoch+1} finished')
        # current_loss /= len(loaders['train'])
        # total_acc /= len(loaders['train'])
        # print('loss {:.4f}'.format(current_loss))
        # print('Accuracy:  {:.4f}'.format(total_acc))

        update_log(log, current_loss, total_acc, v_loss, v_acc, len(loaders['train']), len(loaders['valid']))

        if verbose:
            if epoch%print_every == 0:
                print('Epoch {}/{}'.format(epoch+1, n_epochs))
                print('-' * 10)
                print('Loss {:.4f}'.format(log['loss'][-1]))
                print('Accuracy:  {:.4f}'.format(log['accuracy'][-1]))
                print('Validation Loss {:.4f}'.format(log['v_loss'][-1]))
                print('Validation Accuracy:  {:.4f}'.format(log['v_accuracy'][-1]))
        
        if len(log['v_accuracy']) > 1 and (np.abs(log['v_accuracy'][-1]-log['v_accuracy'][-2])<0.1):
            print('Early stopping at epoch %d'%epoch)
            break
        # early_stopping(log['v_loss'][-1], model)
        
        # if early_stopping.early_stop:
        #     print("Early stopping")
        #     break

        if log['v_accuracy'][-1] > best_acc:
            best_acc = log['v_accuracy'][-1]
    
            # load the last checkpoint with the best model
            model.load_state_dict(torch.load('checkpoint.pt'))
    
    if save_path is not None:
        torch.save(model.state_dict(), save_path)
        print('Model saved to %s'%save_path)

    if log_path is not None:
        df = pd.DataFrame.from_dict(log)
        df.to_csv(log_path)
        print('Log saved to %s'%log_path)

    print("Training has completed")
    return log, best_acc

def update_log(log, current_loss, total_acc, v_loss, v_acc, train_len, valid_len):
        log['loss_std'].append(np.std(current_loss))
        log['accuracy_std'].append(np.std(total_acc))
        current_loss = sum(current_loss)/train_len
        total_acc = sum(total_acc)/train_len
        log['loss'].append(current_loss)
        log['accuracy'].append(total_acc)

        
        log['v_loss_std'].append(np.std(v_loss))
        log['v_accuracy_std'].append(np.std(v_acc))
        v_loss = sum(v_loss)/valid_len
        v_acc = sum(v_acc)/valid_len
        log['v_loss'].append(v_loss)
        log['v_accuracy'].append(v_acc)

        return

In [31]:
model = MLP().to(device)
print(model)

loss_function = torch.nn.CrossEntropyLoss()

optimizer = RSAG(model.parameters(),  lr=1e-4, alpha=.9, beta=9e-5)
log, best_acc = train_model(model, loss_function, optimizer, loaders, device)


MLP(
  (layers): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=784, out_features=512, bias=True)
    (2): Linear(in_features=512, out_features=10, bias=True)
  )
)
Starting Epoch 1


TypeError: unsupported operand type(s) for -: 'float' and 'list'

In [81]:
def train_with_hyperparameters(alpha_values, lr_values, save_log=False):
    loss_function = torch.nn.CrossEntropyLoss()
    best_alpha, best_lr = 0.0, 0.0
    best_accuracy = 0.0
    v_accs, acc_std, v_loss, loss_std = [], [], [], []
    acc, loss = [], []
    
    for alpha in alpha_values:
        for lr in lr_values:
            beta = lr * alpha
            
            print(f"----------- Training with alpha={alpha}, lr={lr} -----------------")
            
            model = MLP().to(device)
            optimizer = RSAG(model.parameters(), lr=lr, alpha=alpha, beta=beta)
            log = train_model(model, loaders, optimizer, loss_function, device, epochs=20)
            
            if log['v_accuracy'][-1] > best_accuracy:
                print(f"Found a new best accuracy: {log['v_accuracy'][-1]}")
                print(f"best alpha: {alpha}, best lr: {lr}")
                best_accuracy = log['v_accuracy'][-1]
                best_alpha = alpha
                best_lr = lr
            
            v_accs.append(log['v_accuracy'])
            acc_std.append(log['v_accuracy_std'])
            v_loss.append(log['v_loss'])
            loss_std.append(log['v_loss_std'])
            acc.append(log['accuracy'])
            loss.append(log['loss'])
            

    
    return best_alpha, best_lr, v_accs, acc_std, v_loss, loss_std, acc, loss


IndentationError: unindent does not match any outer indentation level (<tokenize>, line 58)

In [46]:

optimizer = RSAG(model.parameters(), lr=1e-4, alpha=.9, beta=9e-5)
# optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, nesterov=True, momentum=0.9)
log, best_acc = train_model(model, loaders, optimizer, loss_function, device, epochs=5)
valid_loss = log['v_loss']
train_loss = log['loss']

Starting Epoch 1
Epoch 1 finished
loss 4.6072
Accuracy:  7.4900
Starting Epoch 2
Epoch 2 finished
loss 4.6027
Accuracy:  8.8633
Starting Epoch 3
Epoch 3 finished
loss 4.5982
Accuracy:  10.3017
Starting Epoch 4
Epoch 4 finished
loss 4.5937
Accuracy:  11.6917
Starting Epoch 5
Epoch 5 finished
loss 4.5893
Accuracy:  13.1383
Training has completed


## Visualize Loss and Early Stopping

In [None]:
# visualize the loss as the network trained
fig = plt.figure(figsize=(10,8))
plt.plot(range(1,len(train_loss)+1),train_loss, label='Training Loss')
plt.plot(range(1,len(valid_loss)+1),valid_loss,label='Validation Loss')

# find position of lowest validation loss
minposs = valid_loss.index(min(valid_loss))+1 
plt.axvline(minposs, linestyle='--', color='r',label='Early Stopping Checkpoint')

plt.xlabel('epochs')
plt.ylabel('loss')
plt.ylim(0, 0.5) # consistent scale
plt.xlim(0, len(train_loss)+1) # consistent scale
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.show()
fig.savefig('loss_plot.png', bbox_inches='tight')

## Test Model

In [None]:
# initialize lists to monitor test loss and accuracy
test_loss = 0.0
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))

model.eval() # prep model for evaluation

for data, target in test_loader:
    if len(target.data) != batch_size:
        break
    # forward pass: compute predicted outputs by passing inputs to the model
    output = model(data)
    # calculate the loss
    loss = loss_function(output, target)
    # update test loss 
    test_loss += loss.item()*data.size(0)
    # convert output probabilities to predicted class
    _, pred = torch.max(output, 1)
    # compare predictions to true label
    correct = np.squeeze(pred.eq(target.data.view_as(pred)))
    # calculate test accuracy for each object class
    for i in range(batch_size):
        label = target.data[i]
        class_correct[label] += correct[i].item()
        class_total[label] += 1

# calculate and print avg test loss
test_loss = test_loss/len(test_loader.dataset)
print('Test Loss: {:.6f}\n'.format(test_loss))

for i in range(10):
    if class_total[i] > 0:
        print('Test Accuracy of %5s: %2d%% (%2d/%2d)' % (
            str(i), 100 * class_correct[i] / class_total[i],
            np.sum(class_correct[i]), np.sum(class_total[i])))
    else:
        print('Test Accuracy of %5s: N/A (no training examples)' % (classes[i]))

print('\nTest Accuracy (Overall): %2d%% (%2d/%2d)' % (
    100. * np.sum(class_correct) / np.sum(class_total),
    np.sum(class_correct), np.sum(class_total)))