In [None]:
import torch
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_wine

In [None]:
# Load the UCI wine recognition data as a sklearn Bunch; the cells below read
# its `.data` (feature matrix) and `.target` (integer class labels).
wine_dataset = load_wine(return_X_y=False)

In [None]:
# Seed every RNG used downstream (this split, torch weight initialisation,
# and the np.random.permutation mini-batch shuffling in the training loop)
# so that Restart & Run All reproduces the printed accuracies.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Hold out 30% for testing. `stratify` keeps the class proportions of the full
# dataset in both splits, which matters on a test set this small.
X_train, X_test, y_train, y_test = train_test_split(
    wine_dataset.data,
    wine_dataset.target,
    test_size=0.3,
    shuffle=True,
    stratify=wine_dataset.target,
    random_state=SEED,
)

In [None]:
# Run on the first CUDA GPU when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
  device = 'cuda:0'
else:
  device = 'cpu'

In [None]:
# Standardise features (zero mean, unit variance per column) before they reach
# the sigmoid hidden layer: load_wine returns raw, unscaled measurements, and
# large-magnitude inputs push the sigmoid into saturation, slowing training.
# Statistics come from the training split only so the test set does not leak
# into preprocessing.
feature_mean = X_train.mean(axis=0)
feature_std = X_train.std(axis=0)
feature_std[feature_std == 0] = 1.0  # constant column -> avoid division by zero

# Move everything to the selected device: float32 features, int64 class labels
# (CrossEntropyLoss expects integer class indices as targets).
X_train = torch.tensor((X_train - feature_mean) / feature_std, dtype=torch.float32, device=device)
X_test = torch.tensor((X_test - feature_mean) / feature_std, dtype=torch.float32, device=device)
y_train = torch.tensor(y_train, dtype=torch.long, device=device)
y_test = torch.tensor(y_test, dtype=torch.long, device=device)

In [None]:
class Model(torch.nn.Module):
  """Single-hidden-layer classifier: Linear -> Sigmoid -> Linear.

  Args:
    n_input: number of input features per sample.
    n_hidden_neurons: width of the hidden layer.
    n_output: number of output scores (classes). Defaults to 3, the value that
      was previously hard-coded, so existing calls behave exactly as before.

  forward() returns unnormalised scores (no softmax is applied), which is the
  input torch.nn.CrossEntropyLoss expects.
  """

  def __init__(self, n_input, n_hidden_neurons, n_output=3):
    super().__init__()
    self.fc1 = torch.nn.Linear(n_input, n_hidden_neurons)
    self.ac1 = torch.nn.Sigmoid()
    self.fc2 = torch.nn.Linear(n_hidden_neurons, n_output)

  def forward(self, x):
    """Map inputs of shape (..., n_input) to scores of shape (..., n_output)."""
    x = self.fc1(x)
    x = self.ac1(x)
    return self.fc2(x)


In [None]:
# One input neuron per feature column, 10 hidden neurons; place the
# parameters on the same device as the data tensors.
n_features = X_train.shape[1]
model = Model(n_features, n_hidden_neurons=10)
model = model.to(device)

In [None]:
# CrossEntropyLoss takes unnormalised class scores and integer class targets;
# Adam updates every model parameter with a fixed learning rate.
learning_rate = 1e-3
loss = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

In [None]:
# Mini-batch training for 1000 epochs; every 10th epoch, print test accuracy.
batch_size = 20
n_train = X_train.shape[0]  # loop-invariant, hoisted out of the epoch loop
for epoch in range(1000):
  model.train()
  # Fresh random order of the training indices each epoch.
  order = np.random.permutation(n_train)
  for start_index in range(0, n_train, batch_size):
    optimizer.zero_grad()
    batch_id = order[start_index : start_index + batch_size]

    X_batch = X_train[batch_id]
    y_batch = y_train[batch_id]

    # Call the module rather than .forward() so any registered hooks run.
    predict = model(X_batch)

    loss_prediction = loss(predict, y_batch)
    loss_prediction.backward()
    optimizer.step()

  if epoch % 10 == 0:
    # Evaluation needs no autograd graph; eval() keeps this correct should
    # dropout/batch-norm layers ever be added to the model.
    model.eval()
    with torch.no_grad():
      predict = model(X_test).argmax(dim=1)
    print((predict == y_test).cpu().numpy().mean())

0.24074074074074073
0.7407407407407407
0.7222222222222222
0.7222222222222222
0.7407407407407407
0.7407407407407407
0.7407407407407407
0.7407407407407407
0.7407407407407407
0.7222222222222222
0.7407407407407407
0.7962962962962963
0.7777777777777778
0.8518518518518519
0.8888888888888888
0.8888888888888888
0.8888888888888888
0.9074074074074074
0.9074074074074074
0.9074074074074074
0.9629629629629629
0.9444444444444444
0.9629629629629629
0.9259259259259259
0.9259259259259259
0.9629629629629629
0.9629629629629629
0.9259259259259259
0.9629629629629629
0.9444444444444444
0.9629629629629629
0.9814814814814815
0.9444444444444444
0.9259259259259259
0.9444444444444444
0.9444444444444444
0.9444444444444444
0.9444444444444444
0.9444444444444444
0.9629629629629629
0.9259259259259259
0.9444444444444444
0.9629629629629629
0.9629629629629629
0.9444444444444444
0.9814814814814815
0.9814814814814815
0.9629629629629629
0.9259259259259259
0.9444444444444444
0.9259259259259259
0.9629629629629629
0.962962962