In [1]:
import torch
from torch.autograd import Variable
import torch.nn.functional as F

In [2]:
import matplotlib.pyplot as plt
%matplotlib inline

In [3]:
torch.manual_seed(1)    # reproducible

# Make fake two-class data: 100 two-dimensional points per class, drawn from
# unit-variance normals centred on (2, 2) for class 0 and (-2, -2) for class 1.
n_data = torch.ones(100, 2)
# The per-element mean tensor is passed positionally: the keyword was spelled
# `means` in early PyTorch releases and `mean` in later ones, so naming it
# raises TypeError on one side or the other.
x0 = torch.normal(2*n_data, std=1)      # class0 x data (tensor), shape=(100, 2)
y0 = torch.zeros(100)                   # class0 y data (tensor), shape=(100,)
x1 = torch.normal(-2*n_data, std=1)     # class1 x data (tensor), shape=(100, 2)
y1 = torch.ones(100)                    # class1 y data (tensor), shape=(100,)
x = torch.cat((x0, x1), 0).type(torch.FloatTensor)  # shape (200, 2), 32-bit floating point
# Class labels must be integer class indices (LongTensor = 64-bit integer)
# for the cross-entropy loss used later in the notebook.
y = torch.cat((y0, y1), 0).type(torch.LongTensor)   # shape (200,)

# Early PyTorch releases only track gradients through Variable wrappers.
# NOTE(review): on releases where Tensor and Variable are merged this wrapping
# is believed to be a harmless no-op; it is kept for backward compatibility.
x, y = Variable(x), Variable(y)

In [4]:
# method 1: define the network as an explicit Module subclass
class Net(torch.nn.Module):
    """Fully connected network with one hidden layer: Linear -> ReLU -> Linear."""

    def __init__(self, n_feature, n_hidden, n_output):
        """Create the two linear layers.

        n_feature -- number of input features per sample
        n_hidden  -- number of units in the hidden layer
        n_output  -- number of values produced per sample
        """
        super(Net, self).__init__()
        # Layers are registered in forward order: hidden first, then out.
        self.hidden = torch.nn.Linear(n_feature, n_hidden)
        self.out = torch.nn.Linear(n_hidden, n_output)

    def forward(self, x):
        """Return the output layer's raw values for input `x` (no final activation)."""
        hidden_activation = F.relu(self.hidden(x))
        return self.out(hidden_activation)

In [5]:
# method 2: build the same Linear -> ReLU -> Linear stack with Sequential,
# which names the layers by position (0, 1, 2) instead of by attribute
net2 = torch.nn.Sequential(
    torch.nn.Linear(in_features=2, out_features=10),
    torch.nn.ReLU(),
    torch.nn.Linear(in_features=10, out_features=2),
)

In [7]:
# instantiate the class-based network: 2 inputs, 10 hidden units, 2 outputs
net1 = Net(n_feature=2, n_hidden=10, n_output=2)
print(net1)

Net (
  (hidden): Linear (2 -> 10)
  (out): Linear (10 -> 2)
)


In [6]:
# show the layer structure of the Sequential model for comparison with net1
print(net2)

Sequential (
  (0): Linear (2 -> 10)
  (1): ReLU ()
  (2): Linear (10 -> 2)
)


In [13]:
# Cross-entropy loss for classification.
loss_func = torch.nn.CrossEntropyLoss()
# Plain stochastic gradient descent over net2's parameters, learning rate 0.5.
optimizer = torch.optim.SGD(params=net2.parameters(), lr=0.5)

In [14]:
# Train net2 for 100 steps; every step uses the whole data set (x, y) at once.
for t in range(100):
    optimizer.zero_grad()               # clear gradients accumulated by the previous step
    prediction = net2(x)                # forward pass: output of the final Linear layer
    loss = loss_func(prediction, y)     # loss of the predictions against the labels
    loss.backward()                     # backpropagate to fill in parameter gradients
    optimizer.step()                    # apply the SGD update

In [15]:
# save method 1: serialize the entire module object to 'net.pkl'
torch.save(obj=net2, f='net.pkl')

In [16]:
# save method 2: serialize only the parameters (the state_dict) to 'net_params.pkl'
torch.save(obj=net2.state_dict(), f='net_params.pkl')

In [17]:
# Restore the entire network object that was saved to 'net.pkl'.
# NOTE(review): this unpickles a whole Python object, so only load files you
# trust. Recent PyTorch releases are understood to default torch.load to
# weights_only=True, which rejects whole-module pickles -- confirm against the
# installed version before relying on this cell.
net3 = torch.load('net.pkl')

In [20]:
# Restore from saved parameters: the file holds parameters only, so first
# rebuild a network with the same layer layout as net2, then load them into it.
net4 = torch.nn.Sequential(
    torch.nn.Linear(in_features=2, out_features=10),
    torch.nn.ReLU(),
    torch.nn.Linear(in_features=10, out_features=2),
)
net4.load_state_dict(torch.load('net_params.pkl'))