In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset,DataLoader
import torch.nn.functional as F

from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler

import pandas as pd
import numpy as np

In [2]:
# 使用gpu
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.cuda.is_available()

True

In [3]:
iris = datasets.load_iris()
iris_x = iris.data
iris_y = iris.target
print(type(iris_x))
print(iris_x.shape)
print(type(iris_y))
print(iris_y.shape)

<class 'numpy.ndarray'>
(150, 4)
<class 'numpy.ndarray'>
(150,)


In [4]:
minmax = MinMaxScaler()
iris_x = minmax.fit_transform(iris_x)

In [5]:
train_x,valid_x,train_y,valid_y = train_test_split(iris_x,iris_y,test_size=0.15)
print(train_x.shape)
print(train_y.shape)
print(valid_x.shape)
print(valid_y.shape)

(127, 4)
(127,)
(23, 4)
(23,)


In [6]:
train_x = torch.tensor(train_x,dtype=torch.float32)
train_y = torch.tensor(train_y,dtype=torch.long)
valid_x = torch.tensor(valid_x,dtype=torch.float32)
valid_y = torch.tensor(valid_y,dtype=torch.long)

In [7]:
class dataset(Dataset):
  def __init__(self,x,y):
    self.x = x
    self.y = y
    self.n_sample = len(x)
  def __getitem__(self,index):
    return self.x[index],self.y[index]
  def __len__(self):
    return self.n_sample
  
train_set = dataset(train_x,train_y)

In [8]:
train_loader = DataLoader(dataset=train_set ,batch_size=10, shuffle=True)

In [9]:
class Model(nn.Module):
  def __init__(self):
    super(Model,self).__init__()
    self.fc1 = nn.Linear(in_features=4, out_features=24)
    self.fc2 = nn.Linear(in_features=24, out_features=24)
    self.fc3 = nn.Linear(in_features=24, out_features=3)
  def forward(self,x):
    x = self.fc1(x)
    x = F.relu(x)
    x = self.fc2(x)
    x = F.relu(x)
    x = self.fc3(x)
    return x

model = Model()
model.to(device)
train_x = train_x.to(device)
train_y = train_y.to(device)

In [10]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
epoch = 30
n_batch=len(train_loader)

In [11]:
for i in range(epoch):
  for j,(samples,labels) in enumerate(train_loader):
    samples = samples.to(device)
    labels = labels.to(device)

    pre = model(samples)
    labels=labels.view(-1)
    loss = criterion(pre,labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f"epoch = {i+1}/{epoch} batch = {j+1}/{n_batch} loss = {loss:.4f}",end=' ')
    with torch.no_grad():
      pre = model(train_x)
      _,pre = torch.max(pre,1)
      n_sample = len(train_x)
      n_correct = ( train_y.view(-1)==pre ).sum()
      print(f"train_acc = {n_correct/n_sample:.4f}")

epoch = 1/30 batch = 1/13 loss = 1.1015 train_acc = 0.3307
epoch = 1/30 batch = 2/13 loss = 1.0642 train_acc = 0.3307
epoch = 1/30 batch = 3/13 loss = 1.0369 train_acc = 0.3307
epoch = 1/30 batch = 4/13 loss = 1.1066 train_acc = 0.3307
epoch = 1/30 batch = 5/13 loss = 1.1070 train_acc = 0.3307
epoch = 1/30 batch = 6/13 loss = 1.1416 train_acc = 0.3307
epoch = 1/30 batch = 7/13 loss = 1.1131 train_acc = 0.3307
epoch = 1/30 batch = 8/13 loss = 1.1469 train_acc = 0.3307
epoch = 1/30 batch = 9/13 loss = 1.0561 train_acc = 0.3307
epoch = 1/30 batch = 10/13 loss = 1.1528 train_acc = 0.3307
epoch = 1/30 batch = 11/13 loss = 1.1441 train_acc = 0.3307
epoch = 1/30 batch = 12/13 loss = 1.0682 train_acc = 0.3307
epoch = 1/30 batch = 13/13 loss = 1.1298 train_acc = 0.3307
epoch = 2/30 batch = 1/13 loss = 1.1504 train_acc = 0.3307
epoch = 2/30 batch = 2/13 loss = 1.1433 train_acc = 0.3307
epoch = 2/30 batch = 3/13 loss = 1.0788 train_acc = 0.3307
epoch = 2/30 batch = 4/13 loss = 1.1358 train_acc = 

In [12]:
# 儲存model
FILE = 'model.pt'
torch.save(model.state_dict(), FILE)

In [13]:
# 讀取儲存後的model
model = Model()
model.load_state_dict(torch.load(FILE))

<All keys matched successfully>

In [14]:
with torch.no_grad():
  pre = model(valid_x)
  _,pre = torch.max(pre,1)
  n_sample = len(valid_x)
  n_correct = ( valid_y.view(-1)==pre ).sum()
  print(f"valid_acc = {n_correct/n_sample}")

valid_acc = 0.95652174949646
