This code generates **Figure 3 and 4**

In [9]:
## Required Liberaries: 
#! pip3 install torch 
#! pip3 install torchvision
# and GPUs

In [10]:
import numpy as np 
import torch 
import torchvision
from matplotlib import pyplot as plt
import math

Loading cifar dataset

In [11]:
batch_size_train = 500
train_loader = torch.utils.data.DataLoader(
  torchvision.datasets.CIFAR10('./data/', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.5,0.5,0.5), (0.5,0.5,0.5))
                             ])),
  batch_size=batch_size_train, shuffle=True)

Files already downloaded and verified


In [12]:
# checking the hardware
# !nvidia-smi --query-gpu=gpu_name,driver_version,memory.total --format=csv

In [13]:
### tensor types, to run the code on CPU please update the types 
dtype = torch.cuda.FloatTensor # CPU -> torch.FloatTensor
dtype_labels = torch.cuda.LongTensor # CPU -> torch.cuda.LongTensor

Implementation of a multiLayer prectron

In [14]:
class MLP(torch.nn.Module): 
    def __init__(self,layer_num,input_size,width,out_size,bias_on=False,act = None): # act is the activation function
        super().__init__()
        self.layer_num = layer_num
        self.width = width
        self.out_size = out_size
        self.bias_on = bias_on
        self.act = act
        self.layers = torch.nn.ModuleList() # contains linear layers 
        self.layers.append(torch.nn.Linear(input_size,width,bias=bias_on))
        for i in range(layer_num-2): 
              self.layers.append(torch.nn.Linear(width,width,bias=bias_on))
        self.layers.append(torch.nn.Linear(width,out_size,bias=bias_on))
        self.act = act
    def forward(self,x): 
        out = x
        index = 0
        for lay in self.layers: # passing input through the layers 
            index = index +1 
            if self.act is not None and index<self.layer_num: 
                out = self.act(lay(out))
            else: 
                out = lay(out)
        return out
    def forward_minus(self,x): # skip the output layer
        out = x
        index = 0
        for lay in self.layers:
            index = index +1 
            if self.act is not None and index<self.layer_num-1: 
                out = self.act(lay(out))
        return out


MLP with batch normalization

In [15]:
class BNMLP(torch.nn.Module): 
    def __init__(self,layer_num,input_size,width,out_size,bias_on=False,act = None): # act is the activation function
        super().__init__()
        self.layer_num = layer_num
        self.width = width
        self.out_size = out_size
        self.bias_on = bias_on
        self.act = act
        self.layers = torch.nn.ModuleList() # linear layers
        self.layers.append(torch.nn.Linear(input_size,width,bias=bias_on))
        for i in range(layer_num-2): 
              self.layers.append(torch.nn.Linear(width,width,bias=bias_on))
        self.layers.append(torch.nn.Linear(width,out_size,bias=bias_on))
        self.act = act
        self.bns = torch.nn.ModuleList() # batch normalization layers
        for i in range(layer_num-1): 
              self.bns.append(torch.nn.BatchNorm1d(num_features=width)) # for MLPs we use 1d batch normalization
    def forward(self,x): 
        out = x
        for i in range(self.layer_num-1):
            if self.act is not None: 
                out = self.act(self.bns[i](self.layers[i](out)))
            else: 
                out = self.bns[i](self.layers[i](out))
       
        out = self.layers[self.layer_num-1](out)
        return out
    

Here, we implement xavier's initialization of weights

In [16]:
layer_num = 50
out_size =10 
input_size = 3*32*32
width = batch_size_train
net = MLP(layer_num, input_size,width,out_size)
net = net.cuda()
def xavier_init(m): # xavier initialization 
    if isinstance(m, torch.nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight, gain=torch.nn.init.calculate_gain('relu'))
        if m.bias is not None:
            m.bias.data.fill_(0)
def kaiming_init(m):
    if isinstance(m, torch.nn.Linear):
        torch.nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='relu')
        if m.bias is not None:
            m.bias.data.fill_(0)


AssertionError: Torch not compiled with CUDA enabled

The following functions implements the iterative initialization. Let $H_\ell \in R^{d \times n}$ denotes hidden representations in layer $\ell$. we use SVD decompostion $H_\ell = U_\ell \Sigma_\ell V_\ell^\top$ to initialize weights in layer $\ell$ as 
$$W_\ell = \frac{1}{\| \Sigma^{1/2} \|_F } V'_\ell \Sigma_\ell^{1/2} U_\ell^\top$$ 
where $V'_\ell$ is a slice of $V_\ell$. 

In [None]:
def improved_init(net, width):
   # picking a large mini batch of inputs
    batch_size_train = 3*width
    loader = torch.utils.data.DataLoader(
      torchvision.datasets.CIFAR10('./data/', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.5,0.5,0.5), (0.5,0.5,0.5))
                             ])),
     batch_size=batch_size_train, shuffle=True)
    examples = enumerate(loader)
    batch_idx, (images, example_targets) = next(examples)
    H0 = images.view(-1,3*32*32).type(dtype)
   
    H = H0
    gamma = 0.1
    layer_num = net.layer_num
    # passing the input and compute the hidden representation and weights iteratively  
    for i in range(layer_num-1):
        if i>0: 
            Hdata = H.data
            u,s, v = torch.svd(Hdata) # svd computation 
            wd = net.layers[i].weight.data.size(0)
            w = u[0:wd,0:wd].mm(torch.diag(1/torch.pow(s[0:wd],1))).mm(v.t()[0:wd,:]) # weight initialization
            net.layers[i].weight.data = w
            net.layers[i].weight.data = w 
            net.layers[i].weight.data = net.layers[i].weight.data/torch.norm(net.layers[i](H)) # normalization factor
        if net.act is not None:
            H = net.act(net.layers[i](H))
        else:
            H = net.layers[i](H)
   
    return net


In [None]:
def compute_orthogonality_gap(net):
  loader = torch.utils.data.DataLoader(
      torchvision.datasets.CIFAR10('./data/', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.5,0.5,0.5), (0.5,0.5,0.5))
                             ])),
     batch_size=batch_size_train, shuffle=True)
  avg_gap = 0
  index = 0 
  for x,y in loader: 
    index += 1
    x = x.type(dtype)
    y = y.type(dtype_labels)
    out = net.forward(torch.flatten(x,1))
    u1,s1,v1 = torch.svd(out.data)
    s1 = s1/torch.norm(s1)
    s2 = torch.tensor(np.ones(s1.size(0))/math.sqrt(s1.size(0))).type(dtype)
    avg_gap += torch.norm(s1-s2)
  avg_gap = avg_gap/index
  print('average of the orthogonality gap is ',avg_gap)
  return avg_gap

In [None]:
def train(net,train_loader,epochs_num = 10,stepsize=0.01):
    loss_function = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(),lr=stepsize)
    conv = []
    gaps = []
    itrs = []
    for i in range(epochs_num):
        current_loss = 0 
        r_size = 0
        for x,y in train_loader: 
            x = x.type(dtype)
            y = y.type(dtype_labels)
            r_size += x.size(0)
            optimizer.zero_grad()
            predy = net(torch.flatten(x,1))
            loss = loss_function(predy,y)
            loss.backward()
            optimizer.step()
            current_loss += loss.item()*batch_size_train
        if i % 5 == 0:
          gaps.append(compute_orthogonality_gap(net))   
          conv.append(current_loss/r_size)
          itrs.append(i+1)
          print(current_loss/r_size) 
    return conv,gaps,itrs

In [None]:
# computing the decay in the orthogonality gap for the iterative initialization ## in figure 3.b 
ACTIVE = torch.nn.functional.relu
convs =[] # convergence rate
gaps = [] # orthogonality gap during training
repeat = 5
for i in range(repeat):
  mynet = MLP(20, input_size,1000,out_size,act=torch.nn.functional.relu,bias_on = False)
  mynet = mynet.cuda()
  mynet.apply(xavier_init)
  conv, gap,itrs = train(mynet,train_loader,epochs_num = 50)
  convs.append(conv)
  gaps.append(gap)

The convergence for different initialization methods

In [None]:
# Figures 3.a, 4.a, and 4.b
layers = [15,30,45,60,75]
bias_on = False
width = 800
epochs = 30
repeat = 4
ACTIVE = torch.nn.functional.relu
ouresults = np.zeros((len(layers),epochs,repeat))
xavir_result = np.zeros((len(layers),epochs,repeat))
for j in range(repeat):
  for i in range(len(layers)): 
      layer_num = layers[i]
      mynet = MLP(layer_num, input_size,width,out_size,act=ACTIVE,bias_on = bias_on)
      mynet = mynet.cuda()
      mynet.apply(xavier_init)
      mynet = improved_init(mynet,width)
      print(i,'our initialization -----')
      res = train(mynet,train_loader,epochs_num = epochs)
      ouresults[i,:,j] = res[0]
      print(i,'xavier initialization----')
      netx = MLP(layer_num, input_size,width,out_size,act=ACTIVE,bias_on = bias_on)
      netx = netx.cuda()
      netx = netx.apply(xavier_init)
      result_xavier =  train(netx,train_loader,epochs_num = epochs)
      xavir_result[i,:,j] =result_xavier[0]
     