In [1]:
%matplotlib inline
import torch
import torchvision
from IPython import display
from matplotlib import pyplot as plt
import numpy as np
import random

In [2]:
def use_svg_display():
    # 用矢量图显示
    display.set_matplotlib_formats('svg')

def set_figsize(figsize=(3.5, 2.5)):
    use_svg_display()
    # 设置图的尺寸
    plt.rcParams['figure.figsize'] = figsize

In [3]:
!ls ./data/iris

iris_test.csv     iris_training.csv


In [4]:
class IrisDataSet(torch.utils.data.Dataset):
    def __init__(self, csv_path, sep=','):
        with open(csv_path, 'r') as f:
            lines = f.readlines()[1:]
            datas = [line.strip().split(sep) for line in lines]
            datas_np = np.array(datas, dtype=np.float32)
            self.featurs = torch.tensor(datas_np[:,:4], dtype=torch.float32)
            self.labels = torch.tensor(datas_np[:,-1], dtype=torch.float32)
    
    def __len__(self):
        return len(self.featurs)
    
    def __getitem__(self, index):
        return self.featurs[index], self.labels[index]
    

In [5]:
batch_size = 32
train_csv_path = './data/iris/iris_training.csv'
train_data_set = IrisDataSet(train_csv_path)
train_data_iter = torch.utils.data.DataLoader(train_data_set, batch_size=batch_size, shuffle=True)

test_csv_path = './data/iris/iris_test.csv'
test_data_set = IrisDataSet(test_csv_path)
test_data_iter = torch.utils.data.DataLoader(test_data_set, batch_size=batch_size)

print(len(train_data_set), len(test_data_set))

120 30


### From scratch

In [6]:
w = torch.randn((4,3), dtype=torch.float32)
b = torch.zeros((1,3), dtype=torch.float32)

w.requires_grad_(True)
b.requires_grad_(True)

w, b

(tensor([[ 0.0808, -0.5288, -0.1669],
         [-1.3157, -0.0643, -1.2287],
         [-0.5900,  2.5818, -0.8695],
         [ 1.8603,  0.4940,  0.3751]], requires_grad=True),
 tensor([[0., 0., 0.]], requires_grad=True))

In [7]:
def softmax(x):
    x = x - x.max(dim=1, keepdim=True).values
    return x.exp() / x.exp().sum(dim=1, keepdim=True)

In [8]:
def cross_entropy_loss(y, y_hat):
    """
    y=[0,1,2]
    y_hat = [
        [0.7,0.2,0.1],
        [0.2,0.5,0.3],
        [0.1,0.1,0.8],
    ]
    """
    y_hat = softmax(y_hat)
    return -torch.log(
        y_hat.gather(dim=1, index=y.type(torch.long).view(-1,1))+1e-5
    ).sum()

def sgd(params, lr, bs):
    for param in params:
        param.data -= param.grad / bs * lr
        
def net(x):
    return torch.mm(x, w) + b
        
def accuracy(y_hat, y):
    return (y_hat.argmax(dim=1) == y).float().mean().item()

def evaluate_accuracy(data_iter, net):
    acc_sum, n = 0.0, 0
    for X, y in data_iter:
        acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()
        n += y.shape[0]
    return acc_sum / n

In [11]:
lr = 0.1
num_epochs = 20
params = [w, b]
loss = cross_entropy_loss
for epoch in range(1, num_epochs+1):
    train_l_sum, train_acc_sum, n = 0.0, 0.0, 0
    for x, y in train_data_iter:
        for param in params:
            if param.grad is not None:
                param.grad.data.zero_()
        y_hat = net(x)
        l = cross_entropy_loss(y, y_hat)
        l.backward()
        sgd(params, lr, batch_size)
        
        train_l_sum += l.item()
        train_acc_sum += (y_hat.argmax(dim=1)==y).float().sum().item()
        n += y.shape[0]
    print('epoch %d loss %.4f train_acc %.4f test_acc %.4f' %
          (epoch, train_l_sum/n, train_acc_sum/n, evaluate_accuracy(test_data_iter, net)))
        

epoch 1 loss 0.5828 train_acc 0.7250 test_acc 0.8000
epoch 2 loss 0.4597 train_acc 0.7667 test_acc 0.7333
epoch 3 loss 0.4526 train_acc 0.8000 test_acc 0.7667
epoch 4 loss 0.8481 train_acc 0.6250 test_acc 0.5333
epoch 5 loss 0.6300 train_acc 0.6417 test_acc 0.5333
epoch 6 loss 0.6672 train_acc 0.6667 test_acc 0.6333
epoch 7 loss 0.4050 train_acc 0.7583 test_acc 0.6000
epoch 8 loss 0.4357 train_acc 0.7667 test_acc 0.8333
epoch 9 loss 0.4468 train_acc 0.7583 test_acc 0.5333
epoch 10 loss 0.4980 train_acc 0.7500 test_acc 0.8667
epoch 11 loss 0.4511 train_acc 0.7667 test_acc 0.9333
epoch 12 loss 0.3844 train_acc 0.8167 test_acc 0.7000
epoch 13 loss 0.3843 train_acc 0.8000 test_acc 0.5667
epoch 14 loss 0.3886 train_acc 0.8000 test_acc 0.8000
epoch 15 loss 0.3577 train_acc 0.8667 test_acc 0.8000
epoch 16 loss 0.4792 train_acc 0.8083 test_acc 0.7333
epoch 17 loss 0.5466 train_acc 0.6750 test_acc 0.6333
epoch 18 loss 0.4900 train_acc 0.7333 test_acc 0.9333
epoch 19 loss 0.3434 train_acc 0.9167

### By torch

In [39]:
class LinearNet(torch.nn.Module):
    def __init__(self, input, output):
        super().__init__()
        self.linear = torch.nn.Linear(input, output)
        
    def forward(self, x):
        return self.linear(x)

In [49]:
lr = 3e-4
loss = torch.nn.CrossEntropyLoss()

net = LinearNet(4, 3)

optimizer = torch.optim.Adam(net.parameters())
torch.nn.init.normal_(net.linear.weight, 0., 0.1)
torch.nn.init.constant_(net.linear.bias, 0.)

Parameter containing:
tensor([0., 0., 0.], requires_grad=True)

In [50]:
def train(num_epochs, batch_size, net, loss, train_data_iter, test_data_iter, optimizer):
    for epoch in range(1, num_epochs+1):
        train_l_sum, train_acc_sum, n = 0.0, 0.0, 0
        for x, y in train_data_iter:
            for param in params:
                if param.grad is not None:
                    param.grad.data.zero_()
            y_hat = net(x)
            l = cross_entropy_loss(y, y_hat)
            l.backward()
            optimizer.step()

            train_l_sum += l.item()
            train_acc_sum += (y_hat.argmax(dim=1)==y).float().sum().item()
            n += y.shape[0]
        print('epoch %d loss %.4f train_acc %.4f test_acc %.4f' %
              (epoch, train_l_sum/n, train_acc_sum/n, evaluate_accuracy(test_data_iter, net)))

In [52]:
train(
    num_epochs=1000,
    batch_size=batch_size,
    net = net,
    loss = loss,
    train_data_iter=train_data_iter,
    test_data_iter=test_data_iter,
    optimizer=optimizer
)

epoch 1 loss 0.6342 train_acc 0.6500 test_acc 0.7333
epoch 2 loss 0.6423 train_acc 0.6500 test_acc 0.7333
epoch 3 loss 0.6468 train_acc 0.6500 test_acc 0.7333
epoch 4 loss 0.6473 train_acc 0.6500 test_acc 0.7333
epoch 5 loss 0.6437 train_acc 0.6500 test_acc 0.7333
epoch 6 loss 0.6351 train_acc 0.6500 test_acc 0.7333
epoch 7 loss 0.6235 train_acc 0.6500 test_acc 0.7333
epoch 8 loss 0.6056 train_acc 0.6500 test_acc 0.7333
epoch 9 loss 0.5863 train_acc 0.6500 test_acc 0.7333
epoch 10 loss 0.5628 train_acc 0.6500 test_acc 0.7333
epoch 11 loss 0.5394 train_acc 0.6500 test_acc 0.7333
epoch 12 loss 0.5154 train_acc 0.6500 test_acc 0.7333
epoch 13 loss 0.4926 train_acc 0.6583 test_acc 0.7333
epoch 14 loss 0.4778 train_acc 0.7833 test_acc 0.9333
epoch 15 loss 0.4600 train_acc 0.9500 test_acc 0.9333
epoch 16 loss 0.4529 train_acc 0.9167 test_acc 0.6333
epoch 17 loss 0.4494 train_acc 0.7417 test_acc 0.5667
epoch 18 loss 0.4517 train_acc 0.7000 test_acc 0.5333
epoch 19 loss 0.4589 train_acc 0.7000

epoch 167 loss 0.2978 train_acc 0.8833 test_acc 0.9667
epoch 168 loss 0.2762 train_acc 0.9250 test_acc 0.9667
epoch 169 loss 0.2605 train_acc 0.9583 test_acc 0.9667
epoch 170 loss 0.2564 train_acc 0.9833 test_acc 0.9667
epoch 171 loss 0.2534 train_acc 0.9667 test_acc 0.9333
epoch 172 loss 0.2576 train_acc 0.9500 test_acc 0.9000
epoch 173 loss 0.2645 train_acc 0.9167 test_acc 0.7667
epoch 174 loss 0.2816 train_acc 0.8500 test_acc 0.7000
epoch 175 loss 0.2994 train_acc 0.8167 test_acc 0.6333
epoch 176 loss 0.3202 train_acc 0.7917 test_acc 0.6000
epoch 177 loss 0.3488 train_acc 0.7750 test_acc 0.6000
epoch 178 loss 0.3744 train_acc 0.7333 test_acc 0.5667
epoch 179 loss 0.4058 train_acc 0.7167 test_acc 0.5667
epoch 180 loss 0.4371 train_acc 0.7083 test_acc 0.5667
epoch 181 loss 0.4708 train_acc 0.7000 test_acc 0.5333
epoch 182 loss 0.5049 train_acc 0.7000 test_acc 0.5333
epoch 183 loss 0.5340 train_acc 0.7000 test_acc 0.5333
epoch 184 loss 0.5662 train_acc 0.7000 test_acc 0.5333
epoch 185 

epoch 317 loss 0.7994 train_acc 0.7000 test_acc 0.5333
epoch 318 loss 0.7996 train_acc 0.7000 test_acc 0.5333
epoch 319 loss 0.7978 train_acc 0.7000 test_acc 0.5333
epoch 320 loss 0.7946 train_acc 0.7000 test_acc 0.5333
epoch 321 loss 0.7894 train_acc 0.7000 test_acc 0.5333
epoch 322 loss 0.7821 train_acc 0.7000 test_acc 0.5333
epoch 323 loss 0.7730 train_acc 0.7000 test_acc 0.5333
epoch 324 loss 0.7631 train_acc 0.7000 test_acc 0.5333
epoch 325 loss 0.7509 train_acc 0.7000 test_acc 0.5333
epoch 326 loss 0.7367 train_acc 0.7000 test_acc 0.5333
epoch 327 loss 0.7205 train_acc 0.7000 test_acc 0.5667
epoch 328 loss 0.7049 train_acc 0.7000 test_acc 0.5667
epoch 329 loss 0.6845 train_acc 0.7000 test_acc 0.5667
epoch 330 loss 0.6651 train_acc 0.7000 test_acc 0.5667
epoch 331 loss 0.6438 train_acc 0.7000 test_acc 0.5667
epoch 332 loss 0.6202 train_acc 0.7000 test_acc 0.5667
epoch 333 loss 0.5993 train_acc 0.7000 test_acc 0.5667
epoch 334 loss 0.5738 train_acc 0.7167 test_acc 0.5667
epoch 335 

epoch 470 loss 0.6976 train_acc 0.7250 test_acc 0.5667
epoch 471 loss 0.6932 train_acc 0.7250 test_acc 0.5667
epoch 472 loss 0.6874 train_acc 0.7333 test_acc 0.5667
epoch 473 loss 0.6804 train_acc 0.7333 test_acc 0.5667
epoch 474 loss 0.6723 train_acc 0.7333 test_acc 0.5667
epoch 475 loss 0.6627 train_acc 0.7333 test_acc 0.5667
epoch 476 loss 0.6532 train_acc 0.7333 test_acc 0.5667
epoch 477 loss 0.6417 train_acc 0.7333 test_acc 0.5667
epoch 478 loss 0.6276 train_acc 0.7333 test_acc 0.6000
epoch 479 loss 0.6140 train_acc 0.7417 test_acc 0.6000
epoch 480 loss 0.5990 train_acc 0.7417 test_acc 0.6000
epoch 481 loss 0.5833 train_acc 0.7500 test_acc 0.6000
epoch 482 loss 0.5661 train_acc 0.7500 test_acc 0.6000
epoch 483 loss 0.5502 train_acc 0.7500 test_acc 0.6000
epoch 484 loss 0.5308 train_acc 0.7583 test_acc 0.6333
epoch 485 loss 0.5112 train_acc 0.7667 test_acc 0.6333
epoch 486 loss 0.4936 train_acc 0.7750 test_acc 0.6333
epoch 487 loss 0.4706 train_acc 0.7833 test_acc 0.6333
epoch 488 

epoch 644 loss 0.5607 train_acc 0.7833 test_acc 0.6333
epoch 645 loss 0.5663 train_acc 0.7833 test_acc 0.6333
epoch 646 loss 0.5707 train_acc 0.7833 test_acc 0.6333
epoch 647 loss 0.5745 train_acc 0.7833 test_acc 0.6333
epoch 648 loss 0.5774 train_acc 0.7833 test_acc 0.6333
epoch 649 loss 0.5798 train_acc 0.7833 test_acc 0.6333
epoch 650 loss 0.5809 train_acc 0.7833 test_acc 0.6333
epoch 651 loss 0.5815 train_acc 0.7833 test_acc 0.6333
epoch 652 loss 0.5811 train_acc 0.7833 test_acc 0.6333
epoch 653 loss 0.5799 train_acc 0.7833 test_acc 0.6333
epoch 654 loss 0.5779 train_acc 0.7833 test_acc 0.6333
epoch 655 loss 0.5750 train_acc 0.7833 test_acc 0.6333
epoch 656 loss 0.5715 train_acc 0.7833 test_acc 0.6333
epoch 657 loss 0.5671 train_acc 0.7833 test_acc 0.6333
epoch 658 loss 0.5614 train_acc 0.7833 test_acc 0.6333
epoch 659 loss 0.5560 train_acc 0.7833 test_acc 0.6667
epoch 660 loss 0.5486 train_acc 0.8000 test_acc 0.6667
epoch 661 loss 0.5405 train_acc 0.8000 test_acc 0.6667
epoch 662 

epoch 823 loss 0.0953 train_acc 0.9583 test_acc 0.9667
epoch 824 loss 0.0997 train_acc 0.9583 test_acc 0.9667
epoch 825 loss 0.1059 train_acc 0.9583 test_acc 0.9667
epoch 826 loss 0.1102 train_acc 0.9583 test_acc 0.9333
epoch 827 loss 0.1155 train_acc 0.9583 test_acc 0.9333
epoch 828 loss 0.1222 train_acc 0.9583 test_acc 0.9333
epoch 829 loss 0.1280 train_acc 0.9583 test_acc 0.9333
epoch 830 loss 0.1340 train_acc 0.9583 test_acc 0.9333
epoch 831 loss 0.1409 train_acc 0.9583 test_acc 0.9333
epoch 832 loss 0.1478 train_acc 0.9583 test_acc 0.9333
epoch 833 loss 0.1552 train_acc 0.9583 test_acc 0.9333
epoch 834 loss 0.1639 train_acc 0.9583 test_acc 0.9333
epoch 835 loss 0.1711 train_acc 0.9583 test_acc 0.9333
epoch 836 loss 0.1794 train_acc 0.9417 test_acc 0.9333
epoch 837 loss 0.1870 train_acc 0.9417 test_acc 0.9333
epoch 838 loss 0.1954 train_acc 0.9333 test_acc 0.9333
epoch 839 loss 0.2025 train_acc 0.9250 test_acc 0.9333
epoch 840 loss 0.2123 train_acc 0.9167 test_acc 0.9333
epoch 841 

epoch 987 loss 0.2763 train_acc 0.8667 test_acc 0.9333
epoch 988 loss 0.2836 train_acc 0.8667 test_acc 0.9333
epoch 989 loss 0.2910 train_acc 0.8667 test_acc 0.9333
epoch 990 loss 0.2981 train_acc 0.8667 test_acc 0.9333
epoch 991 loss 0.3048 train_acc 0.8667 test_acc 0.9333
epoch 992 loss 0.3110 train_acc 0.8667 test_acc 0.9333
epoch 993 loss 0.3174 train_acc 0.8667 test_acc 0.9333
epoch 994 loss 0.3229 train_acc 0.8667 test_acc 0.9333
epoch 995 loss 0.3282 train_acc 0.8667 test_acc 0.9333
epoch 996 loss 0.3334 train_acc 0.8667 test_acc 0.9333
epoch 997 loss 0.3384 train_acc 0.8667 test_acc 0.9333
epoch 998 loss 0.3428 train_acc 0.8583 test_acc 0.9333
epoch 999 loss 0.3462 train_acc 0.8583 test_acc 0.9333
epoch 1000 loss 0.3496 train_acc 0.8583 test_acc 0.9333
