In [1]:
import torch
import torch.utils.data as Data
import torch.nn.functional as F

BATCH_SIZE = 20
SEED = 1
# 0 = load batches in the main process. The dataset is a small in-memory
# TensorDataset, so worker subprocesses (re-spawned every epoch) only add cost.
NUM_WORKERS = 0

# Seed torch's global RNG so the fake data below, and every later torch random
# draw in the notebook (weight init, DataLoader shuffling), is reproducible.
torch.manual_seed(SEED)

# Make fake data: two Gaussian blobs (std 1) centred at (2, 2) and (-2, -2).
n_data = torch.ones(100, 2)
x0 = torch.normal(2 * n_data, 1)     # class0 x data, shape (100, 2)
y0 = torch.zeros(100)                # class0 labels, shape (100,)
x1 = torch.normal(-2 * n_data, 1)    # class1 x data, shape (100, 2)
y1 = torch.ones(100)                 # class1 labels, shape (100,)
x = torch.cat((x0, x1), dim=0).float()  # shape (200, 2), 32-bit float
y = torch.cat((y0, y1)).long()          # shape (200,), 64-bit int class indices for CrossEntropyLoss

torch_dataset = Data.TensorDataset(x, y)

loader = Data.DataLoader(
    dataset=torch_dataset,      # torch TensorDataset format
    batch_size=BATCH_SIZE,      # mini batch size
    shuffle=True,               # random shuffle for training
    num_workers=NUM_WORKERS,    # subprocesses for loading data (0 = main process)
)

In [2]:
class Net(torch.nn.Module):
    """Two-layer MLP classifier: Linear -> ReLU -> Linear.

    Args:
        n_feature: number of input features per sample.
        n_hidden: width of the hidden layer.
        n_output: number of output logits (one per class).
    """

    def __init__(self, n_feature, n_hidden, n_output):
        super().__init__()
        self.hidden = torch.nn.Linear(n_feature, n_hidden)
        self.out = torch.nn.Linear(n_hidden, n_output)

    def forward(self, x):
        """Return raw (un-normalised) logits with last dimension n_output."""
        # F.relu is not in-place: its result must be assigned, otherwise the
        # activation is silently dropped and the network collapses to a
        # single linear map.
        x = F.relu(self.hidden(x))
        return self.out(x)

In [3]:
# The model: 2 input features -> 10 hidden units -> 2 class logits.
net = Net(n_feature=2, n_hidden=10, n_output=2)
# Plain SGD over all of the model's parameters, learning rate 0.02.
opt = torch.optim.SGD(params=net.parameters(), lr=0.02)
# Log-softmax + negative log-likelihood over the logits, against integer labels.
loss_func = torch.nn.CrossEntropyLoss()

In [4]:
for epoch in range(5):   # train on the entire dataset 5 times
    for step, (batch_x, batch_y) in enumerate(loader):  # one optimisation step per mini-batch
        out = net(batch_x)              # raw logits, shape (batch, 2)
        loss = loss_func(out, batch_y)

        opt.zero_grad()
        loss.backward()
        opt.step()

        # Accuracy of the predictions made *before* this step's update.
        # int / int keeps the value a plain Python float (e.g. 0.35, not a
        # float32 rounding artefact).
        prediction = torch.argmax(out, dim=1)
        accuracy = (prediction == batch_y).sum().item() / batch_y.numel()
        print('Epoch: ', epoch, '| Step: ', step, '| accuracy: ', accuracy)

Epoch:  0 | Step:  0 | accuracy:  0.35
Epoch:  0 | Step:  1 | accuracy:  0.95
Epoch:  0 | Step:  2 | accuracy:  1.0
Epoch:  0 | Step:  3 | accuracy:  1.0
Epoch:  0 | Step:  4 | accuracy:  1.0
Epoch:  0 | Step:  5 | accuracy:  1.0
Epoch:  0 | Step:  6 | accuracy:  1.0
Epoch:  0 | Step:  7 | accuracy:  1.0
Epoch:  0 | Step:  8 | accuracy:  1.0
Epoch:  0 | Step:  9 | accuracy:  1.0
Epoch:  1 | Step:  0 | accuracy:  1.0
Epoch:  1 | Step:  1 | accuracy:  1.0
Epoch:  1 | Step:  2 | accuracy:  1.0
Epoch:  1 | Step:  3 | accuracy:  1.0
Epoch:  1 | Step:  4 | accuracy:  1.0
Epoch:  1 | Step:  5 | accuracy:  1.0
Epoch:  1 | Step:  6 | accuracy:  1.0
Epoch:  1 | Step:  7 | accuracy:  1.0
Epoch:  1 | Step:  8 | accuracy:  1.0
Epoch:  1 | Step:  9 | accuracy:  1.0
Epoch:  2 | Step:  0 | accuracy:  1.0
Epoch:  2 | Step:  1 | accuracy:  1.0
Epoch:  2 | Step:  2 | accuracy:  1.0
Epoch:  2 | Step:  3 | accuracy:  1.0
Epoch:  2 | Step:  4 | accuracy:  1.0
Epoch:  2 | Step:  5 | accuracy:  1.0
Epoch:  2 