### **Training a DNN for HC18 segmentation task**

**Description:** In this notebook, we provide the code for:

1.   Loading and preparing data for training (Functions)
2.   Model architecture (Class)
3.   Training function
4.   Second-order optimizers (Class)
5.   Helpers to train with each optimizer and save
6.   Utilities to ensure reproducibility
7.   Training-and-saving loop over seeds and optimizers

**STEP 1 - Loading and preparing data for training (Functions)**

In [1]:
import pandas as pd
import numpy as np
import PIL
import os
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import TensorDataset, DataLoader

def extract_files(path):
  '''Load every (image, annotation mask) pair stored in a folder.

  Each '.png' file whose name does not contain '_Annotation' is read as an image and
  min-max normalised to [0, 1] (a constant image becomes all zeros). Its companion
  '<name>_Annotation.png' is read as the mask and binarised to 0.0 / 1.0 (pixels > 0 -> 1.0).

  path: folder where the images and their annotations are stored
        (with or without a trailing path separator)

  return: list, list (images and masks respectively, float numpy arrays,
          in sorted file-name order)
  '''

  images = []
  masks = []

  def read(file_name):
    # `import PIL` alone does not guarantee that the PIL.Image submodule is loaded.
    from PIL import Image
    # The context manager closes the file handle once the pixels are copied into numpy.
    with Image.open(os.path.join(path, file_name)) as img:
      return np.array(img)

  def format_mask(mask):
    return (mask > 0).astype(float)

  def normalize(im):
    im = im.astype(float)
    lo, hi = np.min(im), np.max(im)
    if hi == lo:  # constant image: avoid the 0 / 0 -> NaN of min-max scaling
      return np.zeros_like(im)
    return (im - lo) / (hi - lo)

  # os.listdir order is arbitrary and filesystem-dependent. The seeded train / val
  # split downstream indexes into these lists, so a stable order is needed for the
  # split to be reproducible across machines.
  for file in sorted(os.listdir(path)):
    if '_Annotation' in file or not file.endswith('.png'):
      continue
    images.append(normalize(read(file)))
    masks.append(format_mask(read(file.replace('.png', '_Annotation.png'))))

  return images, masks


def load_data(trpath = './Data/format_train/', 
              vpath = './Data/format_val/'):
  
  '''Load the training and validation folders and pool them into one sample set.

  Pooling lets the train / val split be redrawn for every random shuffling seed
  (see split_train_val).

  trpath: folder of the original training images and masks
  vpath: folder of the original validation images and masks

  return: list, list (pooled images and masks respectively; the samples of trpath
  come first, followed by those of vpath)
  '''

  # Each folder is read exactly once.
  train_input, train_output = extract_files(trpath)
  val_input, val_output = extract_files(vpath)

  # Plain list concatenation keeps the same arrays and avoids building an
  # intermediate (N, H, W) copy of the whole dataset.
  return train_input + val_input, train_output + val_output


def split_train_val(train_input_, SEED=0, train_output_=None, n_val=300):

  '''Shuffle the pooled samples with seed SEED and split them into validation / training sets.

  train_input_: sequence of images (samples x H x W)
  SEED: seed of the shuffling permutation (legacy NumPy RandomState)
  train_output_: masks matching train_input_. When omitted, the notebook-level global
    `train_output_` is used, so existing callers that only pass the images keep working.
  n_val: number of validation samples (the first n_val shuffled indexes, default 300)

  Return: numpy array, numpy array, numpy array, numpy array
  (train images, train masks, val images, val masks resp.)

  Raises ValueError when no masks are available or when images and masks differ in count.
  '''

  if train_output_ is None:
    # Backward-compatible fallback on the global defined by the notebook's data-loading cell.
    train_output_ = globals().get('train_output_')
    if train_output_ is None:
      raise ValueError('split_train_val: pass the masks with train_output_=...')

  inputs = np.asarray(train_input_)
  outputs = np.asarray(train_output_)
  if len(inputs) != len(outputs):
    raise ValueError('split_train_val: {} images but {} masks'.format(len(inputs), len(outputs)))

  indexes = np.arange(len(inputs))
  # Same permutation as np.random.seed(SEED); np.random.shuffle(indexes),
  # but without touching NumPy's global random state.
  np.random.RandomState(SEED).shuffle(indexes)

  v_indexes = indexes[:n_val]
  tr_indexes = indexes[n_val:]

  train_input = inputs[tr_indexes]
  train_output = outputs[tr_indexes]

  val_input = inputs[v_indexes]
  val_output = outputs[v_indexes]

  print(len(val_input), len(train_input))

  return train_input, train_output, val_input, val_output


class MyDataset(TensorDataset):
    """Map-style dataset that yields (image, mask) tensor pairs for PyTorch training.

    data: indexable collection of 2D input images (numpy arrays)
    targets: indexable collection of 2D masks, aligned with `data`
    """

    def __init__(self, data, targets):
        # TensorDataset.__init__ is not called: items are converted lazily in __getitem__.
        self.data = data
        self.targets = targets

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        image = self.data[index].astype(float)
        mask = self.targets[index].astype(float)
        # Image: prepend a channel axis -> float32 tensor of shape (1, H, W).
        image_tensor = torch.Tensor(image[np.newaxis, :, :])
        # Mask: float32 tensor converted to int64 class indices of shape (H, W).
        mask_tensor = torch.Tensor(mask).long()
        return image_tensor, mask_tensor


def get_dataloaders(train_input, train_output, val_input, val_output, BATCH_SIZE = 5):

  '''Create the PyTorch data loaders and the class weights of the cross-entropy loss.

  train_input (samples x H x W): input train images
  train_output (samples x H x W): output train masks
  val_input (samples x H x W): input val images
  val_output (samples x H x W): output val masks
  BATCH_SIZE: training batch size (the validation loader uses batches of one sample)

  return : torch dataloader, torch dataloader, numpy array
  (training loader, val loader, weights [w_background, w_foreground])

  raises ValueError: if train_output is empty, or if its pixels are all background
  or all foreground (an inverse-frequency weight would then be infinite).
  '''

  # Inverse-frequency weights: each class is weighted by 1 / (its share of the training pixels).
  if len(train_output) == 0:
    raise ValueError('get_dataloaders: train_output is empty')
  all_pixs = np.concatenate([np.ravel(mask) for mask in train_output])
  props = np.sum(all_pixs > 0) / len(all_pixs)  # share of foreground pixels
  if not 0 < props < 1:
    raise ValueError('get_dataloaders: the training masks must contain both background '
                     'and foreground pixels (foreground share = {})'.format(props))
  w = 1 / np.array([1 - props, props])

  # shuffle=False: every epoch visits the training samples in the same order.
  train_data = MyDataset(train_input, train_output)
  train_dl = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=False)
  # Batch size 1 for validation; the val branch of get_outputs keeps only the first sample of a batch.
  valid_data = MyDataset(val_input, val_output)
  valid_dl = DataLoader(valid_data, batch_size=1, shuffle=False)

  return train_dl, valid_dl, w

**STEP 2 - Model architecture (Class)**

In [2]:
import torch
from torch import nn

class UNET(nn.Module):
    """U-Net-style encoder/decoder producing per-pixel class scores.

    Input:  (B, in_channels, H, W) tensor. H and W should be divisible by 32: each of
            the five contracting blocks halves the resolution (max-pool), each of the five
            expanding blocks doubles it (transposed conv), and skip concatenations
            need matching spatial sizes.
    Output: (B, out_channels, H, W) raw scores (the last layer is a ConvTranspose2d,
            no final activation).
    """

    # Sandra Marcadent pytorch implementation of 2D Unet [1]

    # [1] @article{Ronneberger2015, 
    # author = {Olaf Ronneberger and Philipp Fischer and Thomas Brox}, 
    # month = {5}, 
    # title = {U-Net: Convolutional Networks for Biomedical Image # Segmentation},
    # url = {http://arxiv.org/abs/1505.04597},
    # year = {2015},
    #}

    def __init__(self, in_channels, out_channels):

        super().__init__()

        self.out_channels = out_channels

        # Contracting path: channels 64 -> 1024, resolution H/2 -> H/32.
        self.conv1 = self.contract_block(in_channels, 64, 3, 1)
        self.conv2 = self.contract_block(64, 128, 3, 1)
        self.conv3 = self.contract_block(128, 256, 3, 1)
        self.conv4 = self.contract_block(256, 512, 3, 1)
        self.conv5 = self.contract_block(512, 1024, 3, 1)

        # Expanding path: inputs of upconv4..1 are doubled by the skip concatenation.
        self.upconv5 = self.expand_block(1024, 512, 3, 1)
        self.upconv4 = self.expand_block(512*2, 256, 3, 1)
        self.upconv3 = self.expand_block(256*2, 128, 3, 1)
        self.upconv2 = self.expand_block(128*2, 64, 3, 1)
        self.upconv1 = self.expand_block(64*2, self.out_channels, 3, 1)

    def forward(self, x):
        # Defined as `forward` (not `__call__`) so that nn.Module.__call__ runs it
        # together with any registered forward hooks; `model(x)` works as before.

        # downsampling part
        conv1 = self.conv1(x)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)
        conv4 = self.conv4(conv3)
        conv5 = self.conv5(conv4)

        # upsampling part with skip connections (concatenated along channels)
        upconv5 = self.upconv5(conv5)
        upconv4 = self.upconv4(torch.cat([upconv5, conv4], 1))
        upconv3 = self.upconv3(torch.cat([upconv4, conv3], 1))
        upconv2 = self.upconv2(torch.cat([upconv3, conv2], 1))
        upconv1 = self.upconv1(torch.cat([upconv2, conv1], 1))

        return upconv1

    def contract_block(self, in_channels, out_channels, kernel_size, padding):
        """Two (Conv2d -> BatchNorm2d -> ReLU) layers followed by a 2x2 max-pool (halves H and W)."""

        contract = nn.Sequential(
            torch.nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.ReLU(),
            torch.nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
                                 )

        return contract

    def expand_block(self, in_channels, out_channels, kernel_size, padding):
        """Two (Conv2d -> BatchNorm2d -> ReLU) layers followed by a stride-2 transposed conv (doubles H and W)."""

        expand = nn.Sequential(torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=padding),
                            torch.nn.BatchNorm2d(out_channels),
                            torch.nn.ReLU(),
                            torch.nn.Conv2d(out_channels, out_channels, kernel_size, stride=1, padding=padding),
                            torch.nn.BatchNorm2d(out_channels),
                            torch.nn.ReLU(),
                            torch.nn.ConvTranspose2d(out_channels, out_channels, kernel_size=2, stride=2, padding=0, output_padding=0))
        return expand
    
   

**STEP 3 - Training function** 

In [3]:
import torchvision
import numpy as np
import time


def _train_step(model, optimizer, loss_fn, x, y, mode):
    '''Run one optimisation step on the mini-batch (x, y); return the detached (outputs, loss)
    of the forward pass made before the parameter update.'''
    if mode == 'Hessian':
        captured = {}

        def closure():
            outputs = model(x)
            optimizer.zero_grad()
            loss = loss_fn(outputs, y)
            loss.backward(create_graph=True)  # needed to compute the Hessian
            # Keep the predictions of this forward pass: no second forward is needed
            # (which would update BatchNorm running statistics a second time).
            captured['outputs'] = outputs.detach()
            captured['loss'] = loss.detach()
            return loss

        optimizer.step(closure)  # the optimizer must call the closure
        return captured['outputs'], captured['loss']

    optimizer.zero_grad()
    outputs = model(x)
    loss = loss_fn(outputs, y)
    loss.backward()
    optimizer.step()
    return outputs.detach(), loss.detach()


def _log_step(phase, step, loss):
    '''Print the current loss and the allocated GPU memory (in MB).'''
    print(phase, ' --- Current step: {}  Loss: {}  AllocMem (Mb): {}'.format(
        step, loss.item(), torch.cuda.memory_allocated()/1024/1024))


def train(model, device, train_dl, valid_dl, loss_fn, optimizer, epochs=100,
          lr_base=0.01, save=False, savepath='', mode=''):

    '''Train a DNN and (optionally) save its predictions at every epoch for frequency analysis.

    model: torch.nn network
    device: gpu or cpu torch device
    train_dl: training dataloader
    valid_dl: validation dataloader
    loss_fn: loss function to optimize
    optimizer: torch optimizer. In 'Hessian' mode its step() must accept and call a closure.
    epochs: number of passes over train_dl
    lr_base: unused (the learning rate is set on the optimizer); kept for compatibility
    save: if True, save 'val_epoch=<e>.npy' and 'train_epoch=<e>.npy' after every epoch
    savepath: prefix prepended to the saved file names (end it with '/' for a folder)
    mode: 'Hessian' (loss back-propagated with create_graph=True inside a closure,
          for second-order optimizers) or '' (default, first-order step)

    In both modes the saved / logged training predictions and losses come from the
    forward pass made before the parameter update of that mini-batch.

    return: torch.nn
    (The trained neural network, moved back to the CPU)
    '''

    start = time.time()
    model.to(device, non_blocking=False)
    loss_fn = loss_fn.to(device)

    for epoch in range(epochs):

        print('Epoch {}/{}'.format(epoch, epochs - 1))
        print('-' * 10)

        train_outputs = []
        val_outputs = []

        # ---- training phase
        model.train(True)
        for step, (x, y) in enumerate(train_dl):
            x = x.to(device)
            y = y.to(device)
            outputs, loss = _train_step(model, optimizer, loss_fn, x, y, mode)

            if save:
                train_outputs.append(get_outputs(outputs))  # numpy rows for frequency analysis

            if step % 10 == 0:
                _log_step('train', step, loss)

        # ---- validation phase
        model.train(False)  # evaluation mode (BatchNorm uses its running statistics)
        with torch.no_grad():
            for step, (x, y) in enumerate(valid_dl):
                x = x.to(device)
                y = y.to(device)
                outputs = model(x)
                loss = loss_fn(outputs, y)

                # loss visu to check that the training goes well
                if epoch % 2 == 0 and step % 10 == 0:
                    _log_step('valid', step, loss)

                if save:
                    val_outputs.append(get_outputs(outputs, val=True))

        if save:
            # Save the results for further analysis
            np.save(savepath+'val_epoch='+str(epoch)+'.npy', np.concatenate(val_outputs, axis=0))
            np.save(savepath+'train_epoch='+str(epoch)+'.npy', np.concatenate(train_outputs, axis=0))

    loss_fn = loss_fn.cpu()
    model = model.cpu()

    time_elapsed = time.time() - start
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

    return model


def get_outputs(y_hat, val=False):

  '''Turn a batch of network scores into flattened predicted label maps.

  y_hat: torch tensor of shape (B x C x H x W)
  val: when True, only the first sample of the batch is kept

  return: float32 numpy array of shape (B, H*W), or (1, H*W) when val is True.
  Each row is the argmax label map of one sample, flattened in column-major ('F') order.
  '''

  labels = torch.argmax(y_hat.float(), dim=1).float().cpu().numpy()  # (B, H, W)
  n_rows = 1 if val else labels.shape[0]
  rows = [labels[b].flatten('F')[np.newaxis, :] for b in range(n_rows)]
  return np.concatenate(rows, axis=0)


**STEP 4 - Second-order optimizers (Class)**

AdaHessian

In [4]:
import torch
from torch.optim import Optimizer
from pdb import set_trace as bp


class AdaHessian(torch.optim.Optimizer):
    """
    Implements the AdaHessian algorithm from "ADAHESSIAN: An Adaptive Second Order Optimizer for Machine Learning"
    Arguments:
        params (iterable) -- iterable of parameters to optimize or dicts defining parameter groups
        lr (float, optional) -- learning rate (default: 0.1)
        betas ((float, float), optional) -- coefficients used for computing running averages of gradient and the squared hessian trace (default: (0.9, 0.999))
        eps (float, optional) -- term added to the denominator to improve numerical stability (default: 1e-8)
        weight_decay (float, optional) -- decoupled (AdamW-style) weight decay (default: 0.0)
        hessian_power (float, optional) -- exponent of the hessian trace (default: 1.0)
        update_each (int, optional) -- compute the hessian trace approximation only after *this* number of steps (to save time) (default: 1)
        n_samples (int, optional) -- how many times to sample `z` for the approximation of the hessian trace (default: 1)
        average_conv_kernel (bool, optional) -- for 4D (convolution) weights, replace the Hessian estimate by its
            absolute value averaged over the two kernel dimensions (default: False)

    Usage: the gradients must carry a graph, i.e. the loss must be back-propagated with
    `create_graph=True` (either before `step()` or inside the closure passed to it), because
    `set_hessian` differentiates the gradients a second time.

    Reference: https://github.com/amirgholami/adahessian

    """

    def __init__(self, params, lr=0.1, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0, 
                 hessian_power=1.0, update_each=1, n_samples=1, average_conv_kernel=False):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= hessian_power <= 1.0:
            raise ValueError(f"Invalid Hessian power value: {hessian_power}")

        self.n_samples = n_samples
        self.update_each = update_each
        self.average_conv_kernel = average_conv_kernel

        # use a separate generator that deterministically generates the same `z`s across all GPUs in case of distributed training
        self.generator = torch.Generator().manual_seed(2147483647)

        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, hessian_power=hessian_power)
        super(AdaHessian, self).__init__(params, defaults)

        # The Hessian diagonal estimate is stored as an attribute `hess` on each parameter.
        for p in self.get_params():
            p.hess = 0.0
            self.state[p]["hessian step"] = 0

    def get_params(self):
        """
        Gets all parameters in all param_groups that require gradients
        """

        return (p for group in self.param_groups for p in group['params'] if p.requires_grad)

    def zero_hessian(self):
        """
        Zeros out the accumulated hessian traces (only on steps where they will be recomputed).
        """

        for p in self.get_params():
            if not isinstance(p.hess, float) and self.state[p]["hessian step"] % self.update_each == 0:
                p.hess.zero_()

    @torch.no_grad()
    def set_hessian(self):
        """
        Computes the Hutchinson approximation of the hessian trace and accumulates it for each trainable parameter.
        """

        params = []
        for p in filter(lambda p: p.grad is not None, self.get_params()):
            if self.state[p]["hessian step"] % self.update_each == 0:  # compute the trace only each `update_each` step
                params.append(p)
            self.state[p]["hessian step"] += 1

        if len(params) == 0:
            return

        if self.generator.device != params[0].device:  # hackish way of casting the generator to the right device
            self.generator = torch.Generator(params[0].device).manual_seed(2147483647)

        grads = [p.grad for p in params]

        for i in range(self.n_samples):
            zs = [torch.randint(0, 2, p.size(), generator=self.generator, device=p.device) * 2.0 - 1.0 for p in params]  # Rademacher distribution {-1.0, 1.0}
            h_zs = torch.autograd.grad(grads, params, grad_outputs=zs, only_inputs=True, retain_graph=i < self.n_samples - 1)
            for h_z, z, p in zip(h_zs, zs, params):
                p.hess += h_z * z / self.n_samples  # approximate the expected values of z*(H@z)

    @torch.no_grad()
    def step(self, closure=None):
        """
        Performs a single optimization step.
        Arguments:
            closure (callable, optional) -- a closure that reevaluates the model, back-propagates the loss
                with create_graph=True and returns the loss (default: None)
        Returns:
            the loss returned by the closure, or None
        """

        loss = None
        if closure is not None:
            # step() runs under no_grad, but the closure has to build an autograd graph
            # (and back-propagate through it), so autograd is re-enabled for it.
            with torch.enable_grad():
                loss = closure()

        self.zero_hessian()
        self.set_hessian()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None or p.hess is None:
                    continue

                if self.average_conv_kernel and p.dim() == 4:
                    p.hess = torch.abs(p.hess).mean(dim=[2, 3], keepdim=True).expand_as(p.hess).clone()

                # Perform correct stepweight decay as in AdamW
                p.mul_(1 - group['lr'] * group['weight_decay'])

                state = self.state[p]

                # State initialization (state already holds the "hessian step" counter)
                if 'step' not in state:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p.data)  # Exponential moving average of gradient values
                    state['exp_hessian_diag_sq'] = torch.zeros_like(p.data)  # Exponential moving average of Hessian diagonal square values

                exp_avg, exp_hessian_diag_sq = state['exp_avg'], state['exp_hessian_diag_sq']
                beta1, beta2 = group['betas']
                state['step'] += 1

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(p.grad, alpha=1 - beta1)
                exp_hessian_diag_sq.mul_(beta2).addcmul_(p.hess, p.hess, value=1 - beta2)

                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                k = group['hessian_power']
                denom = (exp_hessian_diag_sq / bias_correction2).pow_(k / 2).add_(group['eps'])

                # make update
                step_size = group['lr'] / bias_correction1
                p.addcdiv_(exp_avg, denom, value=-step_size)

        return loss


SCRN

In [5]:
class SCRNOptimizer(Optimizer):
    """Stochastic Cubic-Regularized Newton (SCRN) optimizer.

    At every step, and for every parameter group, `cubic_subsolver` approximately
    minimises the cubic model of the loss around the current point,

        m(delta) = g^T delta + 1/2 delta^T H delta + ro/6 * ||delta||^3,

    using Hessian-vector products obtained by differentiating the stored gradients.
    The resulting displacement is then added to the parameters of that group.

    Requirements: the loss must be back-propagated with `create_graph=True`
    (typically inside the closure passed to `step`), and every parameter of every
    group must receive a gradient.

    Arguments:
        params -- iterable of parameters or dicts defining parameter groups
        inner_itr (int) -- maximum number of gradient-descent iterations of the
            sub-solver in the small-gradient branch (default: 10)
        ro (float) -- weight of the cubic regularisation term (default: 0.1)
        l (float) -- smoothness constant; sets the branch threshold l**2 / ro and
            the sub-solver step size 1 / (20 * l) (default: 0.5)
        epsilon (float) -- target accuracy, used in the perturbation scale (default: 1e-3)
        c_prime (float) -- perturbation factor: sigma = c_prime * sqrt(epsilon * ro) / l (default: 0.1)
        step_size (float) -- accepted for backward compatibility but ignored; the
            sub-solver step size is always 1 / (20 * l)
    """

    def __init__(self, params, inner_itr=10, ro=0.1, l=0.5, epsilon=1e-3, c_prime=0.1, step_size=0.001):

        self.ro = ro
        self.l = l
        self.epsilon = epsilon
        self.c_prime = c_prime
        self.inner_itr = inner_itr
        self.step_size = 1 / (20 * l)  # the `step_size` argument is not used
        self.iteration = -1
        defaults = dict()
        self.sqr_grads_norms = 0
        self.last_grad_norm = 0

        super(SCRNOptimizer, self).__init__(params, defaults)
        for group in self.param_groups:
            for p in group['params']:
                self.state[p]['displacement'] = torch.zeros_like(p)

    def compute_norm_of_list_var(self, array_):
        """
        Args:
        param array_: list of tensors
        return:
        Euclidean norm (python float) of the concatenation of the flattened tensors
        """
        return sum(t.norm(2).item() ** 2 for t in array_) ** 0.5

    def inner_product_of_list_var(self, array1_, array2_):
        """
        Args:
        param array1_: list of tensors
        param array2_: list of tensors (same shapes as array1_)
        return:
        The inner product of the flattened lists (0-dim tensor)
        """
        return sum(torch.sum(a * b) for a, b in zip(array1_, array2_))

    def cubic_subsolver(self, grads, param, grad_norm: float, epsilon: float, ro: float, l: float):
        """
        Approximately minimise the cubic model for the tensors in `param` and store the
        solution as their 'displacement' state.

        - large gradient (grad_norm >= l**2 / self.ro): closed-form step along -g (Cauchy point);
        - otherwise: `inner_itr` gradient-descent iterations on the model, starting from 0,
          with the gradient perturbed by sigma * U[0, 1) noise.

        return: (value of the cubic model at the solution, as float; norm of the solution)
        """
        deltas = [0] * len(grads)
        g_tildas = [0] * len(grads)

        with torch.no_grad():
            if grad_norm >= l ** 2 / self.ro:

                # Hessian-vector product H g
                hvp = torch.autograd.grad(outputs=grads, inputs=param,
                                          grad_outputs=grads, retain_graph=True)

                g_t_dot_bg_t = self.inner_product_of_list_var(grads, hvp) / (ro * (grad_norm ** 2))
                R_c = -g_t_dot_bg_t + (g_t_dot_bg_t ** 2 + 2 * grad_norm / ro) ** 0.5

                for i in range(len(grads)):
                    deltas[i] = -R_c * grads[i].clone() / grad_norm

            else:
                sigma = self.c_prime * (epsilon * ro) ** 0.5 / l
                for i in range(len(grads)):
                    # *_like allocates on the gradient's device/dtype, so this branch also runs on GPU
                    deltas[i] = torch.zeros_like(grads[i])
                    khi = torch.rand_like(grads[i])
                    g_tildas[i] = grads[i].clone() + sigma * khi
                for _ in range(self.inner_itr):
                    # Hessian-vector product H delta
                    hvp = torch.autograd.grad(outputs=grads, inputs=param,
                                              grad_outputs=deltas, retain_graph=True)
                    deltas_norm = self.compute_norm_of_list_var(deltas)
                    if self.compute_norm_of_list_var(hvp) > 200:  # stop on a blowing-up curvature term
                        break

                    for i in range(len(grads)):
                        deltas[i] = deltas[i] - self.step_size * (
                                g_tildas[i] + hvp[i] + ro / 2 * deltas_norm * deltas[i])

        # Value of the cubic model at delta: g.delta + 1/2 delta.H delta + ro/6 ||delta||^3
        hvp = torch.autograd.grad(outputs=grads, inputs=param,
                                  grad_outputs=deltas, retain_graph=True)
        deltas_norm = self.compute_norm_of_list_var(deltas)
        delta_m = ro / 6 * deltas_norm ** 3  # the cubic term involves the whole vector: added once
        for g, d, h in zip(grads, deltas, hvp):
            delta_m = delta_m + torch.sum(g * d) + 0.5 * torch.sum(d * h)

        # Store the displacement of exactly the parameters solved for here.
        for p, d in zip(param, deltas):
            self.state[p]['displacement'] = d

        return delta_m.item(), deltas_norm

    def cubic_finalsolver(self, grads, param, epsilon: float, ro: float, l: float):
        """
        Gradient descent on the cubic model until its gradient norm drops below epsilon / 2,
        then store the result as the 'displacement' of the tensors in `param`.
        Not called by `step`; the loop has no iteration cap.
        """
        with torch.no_grad():
            deltas = [torch.zeros_like(g) for g in grads]
            grads_m = [g.clone() for g in grads]
            while self.compute_norm_of_list_var(grads_m) > epsilon / 2:
                hvp = torch.autograd.grad(outputs=grads, inputs=param, grad_outputs=deltas, retain_graph=True)
                for i in range(len(grads)):
                    deltas[i] = deltas[i] - self.step_size * grads_m[i]
                deltas_norm = self.compute_norm_of_list_var(deltas)
                for i in range(len(grads)):
                    grads_m[i] = grads[i] + hvp[i] + ro / 2 * deltas_norm * deltas[i]

            # update the displacement
            for p, d in zip(param, deltas):
                self.state[p]['displacement'] = d

    def update_parameters(self, ):
        """Add the stored displacement to every parameter of every group."""

        for group in self.param_groups:
            with torch.no_grad():
                for p in group['params']:
                    p.add_(self.state[p]['displacement'].clone())

    def step(self, closure=None):
        """Performs a single optimization step.
        Args:
            closure (callable, optional): A closure that reevaluates the model,
                back-propagates the loss with create_graph=True and returns the loss.
        Returns:
            the loss returned by the closure, or None
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        self.iteration += 1

        for group in self.param_groups:
            param = list(group['params'])
            grads = [p.grad for p in param]
            with torch.no_grad():
                grad_norm = sum(g.norm(2).item() ** 2 for g in grads) ** 0.5

            self.cubic_subsolver(grads, param, grad_norm, self.epsilon, self.ro, self.l)

            # Apply this group's displacement (and only this group's).
            with torch.no_grad():
                for p in param:
                    p.add_(self.state[p]['displacement'].clone())

        return loss

    # eval Hessian matrix
    def eval_hessian(self, loss_grad, params):
        """Build the dense Hessian (numpy array) of the flattened gradients `loss_grad`
        w.r.t. `params`. Debug helper only: memory and time grow quadratically with the
        number of parameters."""
        cnt = 0
        for g in loss_grad:
            g_vector = g.contiguous().view(-1) if cnt == 0 else torch.cat([g_vector, g.contiguous().view(-1)])
            cnt = 1
        l = g_vector.size(0)
        hessian = torch.zeros(l, l)
        for idx in range(l):
            grad2rd = torch.autograd.grad(g_vector[idx], params, create_graph=True)
            cnt = 0
            for g in grad2rd:
                g2 = g.contiguous().view(-1) if cnt == 0 else torch.cat([g2, g.contiguous().view(-1)])
                cnt = 1
            hessian[idx] = g2
        return hessian.cpu().data.numpy()


**STEP 5 - Helpers to train with each optimizer and save**

In [6]:
def _default_device():
    '''Return the CUDA device when available, otherwise the CPU.'''
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _train_and_save(make_optimizer, device, train_dl, valid_dl, w, LR_BASE, savepath, SEED, mode=''):
    '''Shared body of the train_* helpers.

    Creates the seed-specific result folder, builds a fresh UNET(1, 2) and its optimizer,
    trains it for 100 epochs with a weighted cross-entropy loss while saving the
    predictions, and returns the trained model.

    make_optimizer: callable mapping the model parameters to a torch optimizer
    mode: forwarded to `train` ('Hessian' for second-order optimizers)
    '''
    savepath_ = savepath.replace('seed = 0', 'seed = '+str(SEED))  # adapt the savepath with the seed
    print(savepath_)
    os.makedirs(savepath_, exist_ok=True)  # create the folder

    unet = UNET(1, 2)
    opt = make_optimizer(unet.parameters())
    loss_fn = torch.nn.CrossEntropyLoss(weight=torch.Tensor(w))
    model = train(unet, device, train_dl, valid_dl, loss_fn, opt, epochs=100,
                  lr_base=LR_BASE, save=True, savepath=savepath_, mode=mode)

    torch.cuda.empty_cache()  # release cached GPU memory between runs
    return model


def train_Adam(device,
               train_dl, 
               valid_dl,
                w,
                LR_BASE = 0.01,
                savepath='./Results/Logistic_loss epochs=100 val seed = 0/Unet_Adam_batch=5/',
                SEED=0):

    '''Trains a unet model with Adam and weighted cross-entropy loss. Saves the results for frequency analysis
    in a newly created folder which corresponds to the optimization method. 
    device: gpu or cpu torch device (first positional argument)
      train_dl: training dataloader
      valid_dl: validation dataloader
      w: weights for the cross-entropy loss
      LR_BASE: learning rate
      savepath: result folder; 'seed = 0' in it is replaced by the actual seed
      SEED: the seed used to initialize the model and split the dataset (only used for the savepath here)
    return: the trained model
    '''
    return _train_and_save(lambda params: torch.optim.Adam(params, lr=LR_BASE),
                           device, train_dl, valid_dl, w, LR_BASE, savepath, SEED)


def train_SGD(train_dl, 
               valid_dl,
                w,
                LR_BASE = 0.01,
                savepath='./Results/Logistic_loss epochs=100 val seed = 0/Unet_SGD_batch=5/',
                SEED=0,
                device=None):

    '''Trains a unet model with SGD and weighted cross-entropy loss. Saves the results for frequency analysis
    in a newly created folder which corresponds to the optimization method. 
      train_dl: training dataloader
      valid_dl: validation dataloader
      w: weights for the cross-entropy loss
      LR_BASE: learning rate
      savepath: result folder; 'seed = 0' in it is replaced by the actual seed
      SEED: the seed used to initialize the model and split the dataset (only used for the savepath here)
      device: gpu or cpu torch device (default: CUDA when available, else CPU)
    return: the trained model
    '''
    if device is None:
        device = _default_device()
    return _train_and_save(lambda params: torch.optim.SGD(params, lr=LR_BASE),
                           device, train_dl, valid_dl, w, LR_BASE, savepath, SEED)


def train_SCRN(train_dl, 
               valid_dl,
                w,
                LR_BASE = 0.01,
                savepath='./Results/Logistic_loss epochs=100 val seed = 0/Unet_SCRN(r=5)_batch=5/',
                SEED=0,
                device=None):

    '''Trains a unet model with SCRN (l=LR_BASE, ro=5) and weighted cross-entropy loss. Saves the results for
    frequency analysis in a newly created folder which corresponds to the optimization method. 
      train_dl: training dataloader
      valid_dl: validation dataloader
      w: weights for the cross-entropy loss
      LR_BASE: passed to SCRNOptimizer as its `l` constant
      savepath: result folder; 'seed = 0' in it is replaced by the actual seed
      SEED: the seed used to initialize the model and split the dataset (only used for the savepath here)
      device: gpu or cpu torch device (default: CUDA when available, else CPU)
    return: the trained model
    '''
    if device is None:
        device = _default_device()
    return _train_and_save(lambda params: SCRNOptimizer(params, l=LR_BASE, ro=5),
                           device, train_dl, valid_dl, w, LR_BASE, savepath, SEED, mode='Hessian')


def train_AdaHessian(train_dl, 
               valid_dl,
                w,
                LR_BASE = 0.01,
                savepath='./Results/Logistic_loss epochs=100 val seed = 0/Unet_AdaHessian_batch=5/',
                SEED=0,
                device=None):

    '''Trains a unet model with AdaHessian and weighted cross-entropy loss. Saves the results for frequency analysis
    in a newly created folder which corresponds to the optimization method. 
      train_dl: training dataloader
      valid_dl: validation dataloader
      w: weights for the cross-entropy loss
      LR_BASE: learning rate
      savepath: result folder; 'seed = 0' in it is replaced by the actual seed
      SEED: the seed used to initialize the model and split the dataset (only used for the savepath here)
      device: gpu or cpu torch device (default: CUDA when available, else CPU)
    return: the trained model
    '''
    if device is None:
        device = _default_device()
    return _train_and_save(lambda params: AdaHessian(params, lr=LR_BASE),
                           device, train_dl, valid_dl, w, LR_BASE, savepath, SEED, mode='Hessian')

**STEP 6 - Utilities to ensure reproducibility**

In [7]:
"""Utilities for ensuring that experiments are deterministic."""
import random
import sys
import warnings

import numpy as np

# Process-wide seed stored by set_seed (None until set_seed has been called).
seed_ = None
# TensorFlow Probability SeedStream created by set_seed when tensorflow is already
# imported and tensorflow_probability is importable; stays None otherwise.
seed_stream_ = None


def set_seed(seed):
    """Seed every random number generator used by the experiments.

    The seed is reduced modulo 4294967294 and stored in the module-level `seed_`.
    It is then applied to Python's `random` and NumPy's global RNG. It is also applied to
    TensorFlow (plus a TFP SeedStream stored in `seed_stream_`) and to PyTorch
    (which is also switched to deterministic, non-benchmarking cuDNN), but only
    when those libraries are already imported.

    Args:
        seed (int): A positive integer

    """
    global seed_, seed_stream_  # pylint: disable=global-statement
    seed = seed % 4294967294
    seed_ = seed
    random.seed(seed)
    np.random.seed(seed)

    loaded_modules = sys.modules
    if 'tensorflow' in loaded_modules:
        import tensorflow as tf  # pylint: disable=import-outside-toplevel
        tf.compat.v1.set_random_seed(seed)
        try:
            import tensorflow_probability as tfp  # pylint: disable=import-outside-toplevel
            seed_stream_ = tfp.util.SeedStream(seed_, salt='garage')
        except ImportError:
            pass

    if 'torch' in loaded_modules:
        warnings.warn('Enabeling deterministic mode in PyTorch can have a performance impact when using GPU.')
        import torch  # pylint: disable=import-outside-toplevel
        torch.manual_seed(seed)
        cudnn = torch.backends.cudnn
        cudnn.deterministic = True
        cudnn.benchmark = False


def get_seed():
    """Return the process-wide random seed stored by set_seed.

    Returns:
        int: The process-wide random seed (None if set_seed was never called)

    """
    current_seed = seed_
    return current_seed


def get_tf_seed_stream():
    """Draw the next seed for a TensorFlow op from the process-wide SeedStream.

    When no stream exists yet, set_seed(0) is called first.

    Returns:
        int: The next value of the seed stream, reduced modulo 4294967294.

    """
    if seed_stream_ is None:
        # NOTE(review): set_seed only creates the stream when tensorflow is already imported
        # and tensorflow_probability is importable; otherwise the call below fails.
        set_seed(0)
    next_seed = seed_stream_()
    return next_seed % 4294967294


**STEP 7 - Training loop**

In [8]:
# We use GPU resources when available.
# NOTE: this assignment only creates a Python variable; it does NOT make CUDA launches
# synchronous. For debugging, set os.environ['CUDA_LAUNCH_BLOCKING'] = '1' before CUDA is
# initialised (it slows training down, so it is left disabled here).
CUDA_LAUNCH_BLOCKING=1

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [9]:
train_input_, train_output_ = load_data()
SEEDS = [0,1,2,3,4]


for SEED in SEEDS:

  # The seed also changes the train / val split.
  train_input, train_output, val_input, val_output = split_train_val(train_input_,SEED=SEED)
  train_dl, valid_dl, w = get_dataloaders(train_input, train_output, val_input, val_output, BATCH_SIZE = 5)

  # Re-seed before every run: each helper builds a fresh UNET, so all four optimizers
  # start from the same initial weights for a given SEED, whatever ran before them.
  set_seed(SEED)
  train_SGD(train_dl, valid_dl, w, SEED=SEED)
  set_seed(SEED)
  train_Adam(device, train_dl, valid_dl, w, SEED=SEED)  # train_Adam takes the device as first argument
  set_seed(SEED)
  train_SCRN(train_dl, valid_dl, w, SEED=SEED)
  set_seed(SEED)
  train_AdaHessian(train_dl, valid_dl, w, SEED=SEED)
  
            
  



300 699
/content/drive/MyDrive/OptML_Data/Results/Logistic_loss epochs=100 val seed = 0/Unet_SGD_batch=5/
Epoch 0/99
----------
train  --- Current step: 0  Loss: 0.7322268486022949  AllocMem (Mb): 250.8642578125
train  --- Current step: 10  Loss: 0.6976431012153625  AllocMem (Mb): 250.6142578125
train  --- Current step: 20  Loss: 0.7056930661201477  AllocMem (Mb): 250.6142578125
train  --- Current step: 30  Loss: 0.6738709211349487  AllocMem (Mb): 250.6142578125
train  --- Current step: 40  Loss: 0.6912842392921448  AllocMem (Mb): 250.6142578125
train  --- Current step: 50  Loss: 0.6731741428375244  AllocMem (Mb): 250.6142578125
train  --- Current step: 60  Loss: 0.6902633905410767  AllocMem (Mb): 250.6142578125
train  --- Current step: 70  Loss: 0.6783270239830017  AllocMem (Mb): 250.6142578125
train  --- Current step: 80  Loss: 0.6764160990715027  AllocMem (Mb): 250.6142578125
train  --- Current step: 90  Loss: 0.6800898313522339  AllocMem (Mb): 250.6142578125
train  --- Current step

KeyboardInterrupt: ignored