In [1]:
import torch 
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import math

torch.manual_seed(1) #reproducible
EPOCH = 4
BATCH_SIZE = 128

train_data = torchvision.datasets.MNIST(
    root='./mnist', #保存位置
    train=True, #training set
    transform=torchvision.transforms.ToTensor(), #converts a PIL.Image to torch.FloatTensor(C*H*W) in range(0.0,1.0)
    download=True
)
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
test_data = torchvision.datasets.MNIST(
    root='./MNIST',
    train=False,
    transform=torchvision.transforms.ToTensor()
)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=BATCH_SIZE,shuffle=False)

In [10]:
torch.manual_seed(27)
# network structure
class CNN(nn.Module):
    def __init__(self,D_in,H,D_out):
        super(CNN, self).__init__()
        self.fc1 = nn.Linear(D_in,H)
        #torch.nn.init.normal(self.fc1.weight, mean=0, std=0.01)
        nn.init.xavier_normal(self.fc1.weight,gain = 1)
        nn.init.constant(self.fc1.bias, 0.1)
        
        self.fc2 = nn.Linear(H,D_out)
        #torch.nn.init.normal(self.fc2.weight, mean=0, std=0.01)
        nn.init.xavier_normal(self.fc2.weight, gain = 1)
        nn.init.constant(self.fc2.bias, 0.1)
       # self.out = nn.Linear(10,10)
        
    def forward(self, x):
        x = x.view(-1, 784)
        x = F.relu(self.fc1(x))
        x = F.softmax(self.fc2(x))
        output = x
        #output = self.out(x)
        return output

D_in,H,D_out = 784,10,10
cnn = CNN(D_in,H,D_out)

# initial hyperparameter
learning_rate = 0.3
u0 = 1
#loss function:cross-entropy with l2 regularizaiton
loss_func = nn.CrossEntropyLoss()
# inital all using viariables
EMAg = []
EMAg_2 = []
EMAx = []
EMAx_2 = []
EMAxg = []
EMAu = []
beta = []

prev_grad=[] #store params grad of last iter
prev_x=[]    #store params of last iter

# training iteration
for epoch in range(EPOCH):
    running_loss = 0.0                        # loss to show
    #training each mini-batch in dataloader
    for i, data in enumerate(train_loader,0):
       # get the inputs
        inputs, labels = data

        # wrap them in Variable
        inputs, labels = Variable(inputs), Variable(labels)

        # zero the parameter gradients
        cnn.zero_grad()

        # forward + backward
        outputs = cnn(inputs)
        loss = loss_func(outputs, labels)
        loss.backward()
       
        # implementation of cSGD algo
        for index,params in enumerate(cnn.parameters(),0):
            
            #update parameters process
            if(i==0 and epoch==0):
                #lists store EMA of weight_grad, weight_grad_square, weight, and weight_square
                
                EMAg.append(torch.Tensor(params.size()))
                EMAx.append(torch.Tensor(params.size()))
                EMAx_2.append(torch.Tensor(params.size()))
                EMAg_2.append(torch.Tensor(params.size()))
                EMAxg.append(torch.Tensor(params.size()))
                prev_grad.append(torch.Tensor(params.size()))
                prev_x.append(torch.Tensor(params.size()))
                n = params.dim()
                if n ==1:
                    for item in range(0,params.size()[0]):
                        EMAg[index][item] = params.grad.data[item]
                        prev_grad[index][item] = params.grad.data[item]
                        EMAx[index][item] = params.data[item]
                        prev_x[index][item] = params.data[item]
                        EMAg_2[index][item] = params.grad.data[item]**2
                        EMAx_2[index][item] = params.data[item]**2
                        EMAxg[index][item] = params.data[item]*params.grad.data[item]**2
                    
                if n ==2:
                    for item1 in range(0,params.size()[0]):
                        for item2 in range(0,params.size()[1]):
                            EMAg[index][item1,item2] = params.grad.data[item1,item2]
                            EMAx[index][item1,item2] = params.data[item1,item2]
                            prev_grad[index][item1,item2] = params.grad.data[item1,item2]
                            prev_x[index][item1,item2] = params.data[item1,item2]
                            EMAg_2[index][item1,item2] = params.grad.data[item1,item2]**2
                            EMAx_2[index][item1,item2] = params.data[item1,item2]**2
                            EMAxg[index][item1,item2] = params.data[item1,item2]*params.grad.data[item1,item2]
                
                #lists store EMA of u and beta
                EMAu.append(torch.Tensor(params.size()))
                beta.append(torch.ones(params.size())*0.9)
                EMAu[index] = torch.ones(EMAu[index].size())*u0
                #print(EMAg[index][0,64])
                #print(EMAg_2[index][0,1])
            else:
               
                one = torch.ones(params.size())
                #print(EMAg_2[index][0,1])
                #update EMA
                EMAg[index] = (beta[index])*EMAg[index]+(one-beta[index])*params.grad.data
                EMAg_2[index] = (beta[index])*EMAg_2[index] + (one-beta[index])*(prev_grad[index]**2)
                EMAx[index] = (beta[index])*EMAx[index] + (one-beta[index])*prev_x[index]
                EMAx_2[index] = (beta[index])*EMAx_2[index]+(one-beta[index])*(prev_x[index]**2)
                EMAxg[index] = (beta[index])*EMAxg[index] + (one-beta[index])*(prev_grad[index]*prev_x[index])
                
                
                #cal a,b,sigma,u*
                n = params.dim()
                a = torch.Tensor(params.size())
                b = torch.Tensor(params.size())
                sigma = torch.Tensor(params.size())
                u = torch.Tensor(params.size())
            
            
                if n == 1:
                    for item in range(0,params.size()[0]):
                        prev_x[index][item] = params.data[item]
                        prev_grad[index][item] = params.grad.data[item]
                        #cal a
                        if EMAxg[index][item]==EMAg[index][item]*EMAx[index][item]:
                             a[item] = 0
                        elif (EMAx_2[index][item]-EMAx[index][item]**2)==0:
                             a[item] = 10000
                        else:
                             a[item] = (EMAxg[index][item]-EMAg[index][item]*EMAx[index][item])/(EMAx_2[index][item]-EMAx[index][item]**2)
                        
                        
                        #cal sigma
                        sigma[item] = EMAg_2[index][item] - EMAg[index][item]**2
                        
                        #cal u*
                        if(a[item]<= 0):
                            u[item] = 1
                        else:
                            if(EMAg[index][item]==0):
                                u[item] = 0.0
                            elif(sigma[item]==0 or a[item]==0):
                                u[item] = 1.0
                            else:
                                u[item] = min(1,(EMAg[index][item]**2)/(learning_rate*sigma[item]*a[item])) 
                        
                    
                        #cal beta
                        if (EMAg_2[index][item]==EMAg[index][item]**2):
                            beta[index][item] = 0.9
                        elif (EMAg_2[index][item]==0):
                            beta[index][item] = 1000
                        else:
                            beta[index][item] = 0.9+(0.999-0.9)*(EMAg_2[index][item]-EMAg[index][item]**2)/(EMAg_2[index][item])
                        #update EMA u
                        EMAu[index][item] = (1-beta[index][item])*EMAu[index][item] + beta[index][item]*u[item]
                if n == 2:
                    for item1 in range(0,params.size()[0]):
                        for item2 in range(0,params.size()[1]):
                            prev_grad[index][item1,item2] = params.grad.data[item1,item2]
                            prev_x[index][item1,item2] = params.data[item1,item2]
                            #cal a
                            if ((EMAxg[index][item1,item2]-EMAg[index][item1,item2]*EMAx[index][item1,item2])==0):
                                a[item1,item2] = 0
                            elif (EMAx_2[index][item1,item2]-EMAx[index][item1,item2]**2)==0:
                                a[item1,item2] = 10000
                            else:
                                a[item1,item2] =(EMAxg[index][item1,item2]-EMAg[index][item1,item2]*EMAx[index][item1,item2])/(EMAx_2[index][item1,item2]-EMAx[index][item1,item2]**2)            
                                #print(a[item1,item2])
                        
                            #cal sigma
                            sigma[item1,item2] = EMAg_2[index][item1,item2] - math.pow(EMAg[index][item1,item2],2)
                            #print(sigma[item1,item2])
                            #cal u*
                            if(a[item1,item2]<= 0):
                                u[item1,item2] = 1.0
                            else:
                                if(EMAg[index][item1,item2]==0):
                                    u[item1,item2] = 0.0
                                elif(sigma[item1,item2]==0 or a[item1,item2]==0):
                                    u[item1,item2] = 1.0
                                else:
                                    u[item1,item2] = min(1,(EMAg[index][item1,item2]**2)/(learning_rate*sigma[item1,item2]*a[item1,item2])) 
                        
                                #print(a[item1,item2],b[item1,item2],sigma[item1,item2])
                                #print(u[item1,item2])
                                #print(a[item1,item2]*((EMAx[index][item1,item2]-b[item1,item2])**2)/(learning_rate*sigma[item1,item2]))
                            
                            #cal beta
                            if EMAg_2[index][item1,item2]==EMAg[index][item1,item2]**2:
                                beta[index][item1,item2] = 0.9
                            elif (EMAg_2[index][item1,item2]==0):
                                beta[index][item1,item2] = 1000
                            else:
                                beta[index][item1,item2] = 0.9+(0.999-0.9)*(EMAg_2[index][item1,item2]-EMAg[index][item1,item2]**2)/(EMAg_2[index][item1,item2])
                            #update EMAu
                            EMAu[index][item1,item2] = (1-beta[index][item1,item2])*EMAu[index][item1,item2] + (beta[index][item1,item2])*u[item1,item2]
                            #print(EMAu[index][item1,item2])
            #update weight and bias
            #prev_grad[index] = params.grad.data
            #prev_x[index] = params.data
            params.data -= learning_rate *EMAu[index]* params.grad.data
        
        #print("%d %d"%(epoch, i))
        running_loss += loss.data[0]
        if i % 100 == 99:    # print every 2000 mini-batches
            #print(EMAu[0])
            print('[%d, %5d] loss: %.8f' %(epoch + 1, i + 1, running_loss / 100))
            running_loss = 0.0
print('Finished Training')

[1,   100] loss: 2.38456159
[1,   200] loss: 2.38904151


KeyboardInterrupt: 