In [1]:
import torch
import torch.nn.functional as F
import torch.optim as optim

In [2]:
from torch.utils.data import DataLoader, TensorDataset

def get_data():
    x_data = [[1, 2], [2, 3], [3, 1], [4, 3], [5, 3], [6, 2]]
    y_data = [[0], [0], [0], [1], [1], [1]]
    x_train = torch.FloatTensor(x_data)
    y_train = torch.FloatTensor(y_data)
    
    return DataLoader(dataset=TensorDataset(x_train, y_train), batch_size=2, shuffle=True)

In [3]:
model = torch.nn.Sequential(
    torch.nn.Linear(2, 1), 
    torch.nn.Sigmoid()
)

# torch 로 model 을 정의-> 가중치 자동 생성
optimizer = optim.SGD(model.parameters(), lr=1)


In [4]:
epochs = 100000
loader = get_data()

for epoch in range(1 + epochs):
    loss = None
    for batch_idx, data in enumerate(loader):
        X, y = data
        forward = model(X)
        loss = F.binary_cross_entropy(forward, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    assert loss is not None
    if epoch % 10000 == 0:
        print(f'Epoch: {epoch:4d}/{epochs:4d}, loss: {loss.item():.4f} ')

Epoch:    0/100000, loss: 0.2564 
Epoch: 10000/100000, loss: 0.0013 
Epoch: 20000/100000, loss: 0.0004 
Epoch: 30000/100000, loss: 0.0003 
Epoch: 40000/100000, loss: 0.0000 
Epoch: 50000/100000, loss: 0.0000 
Epoch: 60000/100000, loss: 0.0000 
Epoch: 70000/100000, loss: 0.0001 
Epoch: 80000/100000, loss: 0.0002 
Epoch: 90000/100000, loss: 0.0000 
Epoch: 100000/100000, loss: 0.0001 


In [5]:
X = torch.FloatTensor([[1, 2], [2, 3], [3, 1], [4, 3], [5, 3], [6, 2]])
y = torch.FloatTensor([[0], [0], [0], [1], [1], [1]])

pred = model(X)
print(pred)
pred = pred >= torch.FloatTensor([0.5])
print(pred)
correct_classification = pred.float() == y

accuracy = correct_classification.sum().item() / len(correct_classification)
print(f'Accuracy: {accuracy * 100}')

tensor([[2.1881e-10],
        [1.1948e-04],
        [1.3939e-04],
        [9.9983e-01],
        [1.0000e+00],
        [1.0000e+00]], grad_fn=<SigmoidBackward0>)
tensor([[False],
        [False],
        [False],
        [ True],
        [ True],
        [ True]])
Accuracy: 100.0
