In [1]:
%matplotlib inline
import torch
import torchvision
from IPython import display
from matplotlib import pyplot as plt
import numpy as np
import random

In [2]:
def use_svg_display():
    # 用矢量图显示
    display.set_matplotlib_formats('svg')

def set_figsize(figsize=(3.5, 2.5)):
    use_svg_display()
    # 设置图的尺寸
    plt.rcParams['figure.figsize'] = figsize

In [3]:
!ls ./data/iris

iris_test.csv     iris_training.csv


In [4]:
class IrisDataSet(torch.utils.data.Dataset):
    def __init__(self, csv_path, sep=','):
        with open(csv_path, 'r') as f:
            lines = f.readlines()[1:]
            datas = [line.strip().split(sep) for line in lines]
            datas_np = np.array(datas, dtype=np.float32)
            self.featurs = torch.tensor(datas_np[:,:4], dtype=torch.float32)
            self.labels = torch.tensor(datas_np[:,-1], dtype=torch.float32)
    
    def __len__(self):
        return len(self.featurs)
    
    def __getitem__(self, index):
        return self.featurs[index], self.labels[index]
    

In [5]:
batch_size = 32
train_csv_path = './data/iris/iris_training.csv'
train_data_set = IrisDataSet(train_csv_path)
train_data_iter = torch.utils.data.DataLoader(train_data_set, batch_size=batch_size, shuffle=True)

test_csv_path = './data/iris/iris_test.csv'
test_data_set = IrisDataSet(test_csv_path)
test_data_iter = torch.utils.data.DataLoader(test_data_set, batch_size=batch_size)

print(len(train_data_set), len(test_data_set))

120 30


### From scratch

In [6]:
w = torch.randn((4,3), dtype=torch.float32)
b = torch.zeros((1,3), dtype=torch.float32)

w.requires_grad_(True)
b.requires_grad_(True)

w, b

(tensor([[ 0.0808, -0.5288, -0.1669],
         [-1.3157, -0.0643, -1.2287],
         [-0.5900,  2.5818, -0.8695],
         [ 1.8603,  0.4940,  0.3751]], requires_grad=True),
 tensor([[0., 0., 0.]], requires_grad=True))

In [7]:
def softmax(x):
    x = x - x.max(dim=1, keepdim=True).values
    return x.exp() / x.exp().sum(dim=1, keepdim=True)

In [8]:
def cross_entropy_loss(y, y_hat):
    """
    y=[0,1,2]
    y_hat = [
        [0.7,0.2,0.1],
        [0.2,0.5,0.3],
        [0.1,0.1,0.8],
    ]
    """
    y_hat = softmax(y_hat)
    return -torch.log(
        y_hat.gather(dim=1, index=y.type(torch.long).view(-1,1))+1e-5
    ).sum()

def sgd(params, lr, bs):
    for param in params:
        param.data -= param.grad / bs * lr
        
def net(x):
    return torch.mm(x, w) + b
        
def accuracy(y_hat, y):
    return (y_hat.argmax(dim=1) == y).float().mean().item()

def evaluate_accuracy(data_iter, net):
    acc_sum, n = 0.0, 0
    for X, y in data_iter:
        acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()
        n += y.shape[0]
    return acc_sum / n

In [11]:
lr = 0.1
num_epochs = 20
params = [w, b]
loss = cross_entropy_loss
for epoch in range(1, num_epochs+1):
    train_l_sum, train_acc_sum, n = 0.0, 0.0, 0
    for x, y in train_data_iter:
        for param in params:
            if param.grad is not None:
                param.grad.data.zero_()
        y_hat = net(x)
        l = cross_entropy_loss(y, y_hat)
        l.backward()
        sgd(params, lr, batch_size)
        
        train_l_sum += l.item()
        train_acc_sum += (y_hat.argmax(dim=1)==y).float().sum().item()
        n += y.shape[0]
    print('epoch %d loss %.4f train_acc %.4f test_acc %.4f' %
          (epoch, train_l_sum/n, train_acc_sum/n, evaluate_accuracy(test_data_iter, net)))
        

epoch 1 loss 0.5828 train_acc 0.7250 test_acc 0.8000
epoch 2 loss 0.4597 train_acc 0.7667 test_acc 0.7333
epoch 3 loss 0.4526 train_acc 0.8000 test_acc 0.7667
epoch 4 loss 0.8481 train_acc 0.6250 test_acc 0.5333
epoch 5 loss 0.6300 train_acc 0.6417 test_acc 0.5333
epoch 6 loss 0.6672 train_acc 0.6667 test_acc 0.6333
epoch 7 loss 0.4050 train_acc 0.7583 test_acc 0.6000
epoch 8 loss 0.4357 train_acc 0.7667 test_acc 0.8333
epoch 9 loss 0.4468 train_acc 0.7583 test_acc 0.5333
epoch 10 loss 0.4980 train_acc 0.7500 test_acc 0.8667
epoch 11 loss 0.4511 train_acc 0.7667 test_acc 0.9333
epoch 12 loss 0.3844 train_acc 0.8167 test_acc 0.7000
epoch 13 loss 0.3843 train_acc 0.8000 test_acc 0.5667
epoch 14 loss 0.3886 train_acc 0.8000 test_acc 0.8000
epoch 15 loss 0.3577 train_acc 0.8667 test_acc 0.8000
epoch 16 loss 0.4792 train_acc 0.8083 test_acc 0.7333
epoch 17 loss 0.5466 train_acc 0.6750 test_acc 0.6333
epoch 18 loss 0.4900 train_acc 0.7333 test_acc 0.9333
epoch 19 loss 0.3434 train_acc 0.9167

### By torch

In [53]:
class LinearNet(torch.nn.Module):
    def __init__(self, input, output):
        super().__init__()
        self.linear = torch.nn.Linear(input, output)
        
    def forward(self, x):
        return self.linear(x)

In [54]:
lr = 3e-4
loss = torch.nn.CrossEntropyLoss()

net = LinearNet(4, 3)

optimizer = torch.optim.Adam(net.parameters())
torch.nn.init.normal_(net.linear.weight, 0., 0.1)
torch.nn.init.constant_(net.linear.bias, 0.)

Parameter containing:
tensor([0., 0., 0.], requires_grad=True)

In [55]:
def train(num_epochs, batch_size, net, loss, train_data_iter, test_data_iter, optimizer):
    for epoch in range(1, num_epochs+1):
        train_l_sum, train_acc_sum, n = 0.0, 0.0, 0
        for x, y in train_data_iter:
            for param in params:
                if param.grad is not None:
                    param.grad.data.zero_()
            y_hat = net(x)
            l = cross_entropy_loss(y, y_hat)
            l.backward()
            optimizer.step()

            train_l_sum += l.item()
            train_acc_sum += (y_hat.argmax(dim=1)==y).float().sum().item()
            n += y.shape[0]
        print('epoch %d loss %.4f train_acc %.4f test_acc %.4f' %
              (epoch, train_l_sum/n, train_acc_sum/n, evaluate_accuracy(test_data_iter, net)))

In [56]:
train(
    num_epochs=1000,
    batch_size=batch_size,
    net = net,
    loss = loss,
    train_data_iter=train_data_iter,
    test_data_iter=test_data_iter,
    optimizer=optimizer
)

epoch 1 loss 1.0280 train_acc 0.3500 test_acc 0.2667
epoch 2 loss 1.0247 train_acc 0.3500 test_acc 0.2667
epoch 3 loss 1.0216 train_acc 0.3583 test_acc 0.2667
epoch 4 loss 1.0182 train_acc 0.3750 test_acc 0.2667
epoch 5 loss 1.0145 train_acc 0.3833 test_acc 0.3333
epoch 6 loss 1.0105 train_acc 0.4250 test_acc 0.4000
epoch 7 loss 1.0071 train_acc 0.5333 test_acc 0.4333
epoch 8 loss 1.0039 train_acc 0.6083 test_acc 0.4667
epoch 9 loss 1.0005 train_acc 0.6750 test_acc 0.5000
epoch 10 loss 0.9973 train_acc 0.7000 test_acc 0.5000
epoch 11 loss 0.9935 train_acc 0.7000 test_acc 0.5333
epoch 12 loss 0.9899 train_acc 0.7000 test_acc 0.5333
epoch 13 loss 0.9865 train_acc 0.7000 test_acc 0.5333
epoch 14 loss 0.9829 train_acc 0.7000 test_acc 0.5333
epoch 15 loss 0.9796 train_acc 0.7000 test_acc 0.5333
epoch 16 loss 0.9759 train_acc 0.7000 test_acc 0.5333
epoch 17 loss 0.9727 train_acc 0.7000 test_acc 0.5333
epoch 18 loss 0.9692 train_acc 0.7000 test_acc 0.5333
epoch 19 loss 0.9660 train_acc 0.7000

epoch 169 loss 0.6071 train_acc 0.7000 test_acc 0.5333
epoch 170 loss 0.6076 train_acc 0.7000 test_acc 0.5333
epoch 171 loss 0.6077 train_acc 0.7000 test_acc 0.5333
epoch 172 loss 0.6086 train_acc 0.7000 test_acc 0.5333
epoch 173 loss 0.6089 train_acc 0.7000 test_acc 0.5333
epoch 174 loss 0.6093 train_acc 0.7000 test_acc 0.5333
epoch 175 loss 0.6097 train_acc 0.7000 test_acc 0.5333
epoch 176 loss 0.6097 train_acc 0.7000 test_acc 0.5333
epoch 177 loss 0.6100 train_acc 0.7000 test_acc 0.5333
epoch 178 loss 0.6098 train_acc 0.7000 test_acc 0.5333
epoch 179 loss 0.6094 train_acc 0.7000 test_acc 0.5333
epoch 180 loss 0.6087 train_acc 0.7000 test_acc 0.5333
epoch 181 loss 0.6079 train_acc 0.7000 test_acc 0.5333
epoch 182 loss 0.6066 train_acc 0.7000 test_acc 0.5333
epoch 183 loss 0.6050 train_acc 0.7000 test_acc 0.5333
epoch 184 loss 0.6031 train_acc 0.7000 test_acc 0.5333
epoch 185 loss 0.6010 train_acc 0.7000 test_acc 0.5333
epoch 186 loss 0.5988 train_acc 0.7000 test_acc 0.5333
epoch 187 

epoch 324 loss 0.4773 train_acc 0.7750 test_acc 0.8000
epoch 325 loss 0.4757 train_acc 0.7833 test_acc 0.8000
epoch 326 loss 0.4737 train_acc 0.7833 test_acc 0.8000
epoch 327 loss 0.4715 train_acc 0.7833 test_acc 0.8000
epoch 328 loss 0.4691 train_acc 0.7833 test_acc 0.8333
epoch 329 loss 0.4665 train_acc 0.7917 test_acc 0.9000
epoch 330 loss 0.4634 train_acc 0.8000 test_acc 0.9000
epoch 331 loss 0.4608 train_acc 0.8333 test_acc 0.9000
epoch 332 loss 0.4577 train_acc 0.8667 test_acc 0.9000
epoch 333 loss 0.4547 train_acc 0.8750 test_acc 0.9333
epoch 334 loss 0.4514 train_acc 0.8833 test_acc 0.9333
epoch 335 loss 0.4489 train_acc 0.9083 test_acc 0.9667
epoch 336 loss 0.4465 train_acc 0.9167 test_acc 0.9667
epoch 337 loss 0.4435 train_acc 0.9667 test_acc 0.9667
epoch 338 loss 0.4412 train_acc 0.9667 test_acc 0.9667
epoch 339 loss 0.4383 train_acc 0.9667 test_acc 0.9667
epoch 340 loss 0.4361 train_acc 0.9500 test_acc 0.9333
epoch 341 loss 0.4344 train_acc 0.9583 test_acc 0.9333
epoch 342 

epoch 501 loss 0.3983 train_acc 0.7167 test_acc 0.5667
epoch 502 loss 0.3972 train_acc 0.7167 test_acc 0.5667
epoch 503 loss 0.3958 train_acc 0.7167 test_acc 0.5667
epoch 504 loss 0.3944 train_acc 0.7250 test_acc 0.5667
epoch 505 loss 0.3928 train_acc 0.7333 test_acc 0.5667
epoch 506 loss 0.3909 train_acc 0.7417 test_acc 0.5667
epoch 507 loss 0.3887 train_acc 0.7417 test_acc 0.5667
epoch 508 loss 0.3863 train_acc 0.7417 test_acc 0.6000
epoch 509 loss 0.3836 train_acc 0.7417 test_acc 0.6000
epoch 510 loss 0.3812 train_acc 0.7417 test_acc 0.6000
epoch 511 loss 0.3786 train_acc 0.7500 test_acc 0.6000
epoch 512 loss 0.3757 train_acc 0.7500 test_acc 0.6000
epoch 513 loss 0.3728 train_acc 0.7583 test_acc 0.6000
epoch 514 loss 0.3698 train_acc 0.7667 test_acc 0.6000
epoch 515 loss 0.3671 train_acc 0.7667 test_acc 0.6000
epoch 516 loss 0.3638 train_acc 0.7667 test_acc 0.6333
epoch 517 loss 0.3611 train_acc 0.7667 test_acc 0.6333
epoch 518 loss 0.3583 train_acc 0.7833 test_acc 0.6333
epoch 519 

epoch 655 loss 0.3240 train_acc 0.8000 test_acc 0.7000
epoch 656 loss 0.3216 train_acc 0.8083 test_acc 0.7000
epoch 657 loss 0.3186 train_acc 0.8167 test_acc 0.7000
epoch 658 loss 0.3162 train_acc 0.8333 test_acc 0.7000
epoch 659 loss 0.3128 train_acc 0.8417 test_acc 0.7333
epoch 660 loss 0.3099 train_acc 0.8500 test_acc 0.7333
epoch 661 loss 0.3079 train_acc 0.8667 test_acc 0.7333
epoch 662 loss 0.3048 train_acc 0.8750 test_acc 0.7667
epoch 663 loss 0.3021 train_acc 0.8750 test_acc 0.7667
epoch 664 loss 0.3005 train_acc 0.8917 test_acc 0.8000
epoch 665 loss 0.2970 train_acc 0.8917 test_acc 0.8000
epoch 666 loss 0.2950 train_acc 0.9000 test_acc 0.8333
epoch 667 loss 0.2927 train_acc 0.9083 test_acc 0.8667
epoch 668 loss 0.2903 train_acc 0.9167 test_acc 0.8667
epoch 669 loss 0.2885 train_acc 0.9250 test_acc 0.9000
epoch 670 loss 0.2867 train_acc 0.9417 test_acc 0.9000
epoch 671 loss 0.2851 train_acc 0.9417 test_acc 0.9000
epoch 672 loss 0.2836 train_acc 0.9417 test_acc 0.9000
epoch 673 

epoch 804 loss 0.2945 train_acc 0.8333 test_acc 0.7000
epoch 805 loss 0.2939 train_acc 0.8333 test_acc 0.7000
epoch 806 loss 0.2932 train_acc 0.8333 test_acc 0.7000
epoch 807 loss 0.2921 train_acc 0.8333 test_acc 0.7000
epoch 808 loss 0.2912 train_acc 0.8333 test_acc 0.7000
epoch 809 loss 0.2901 train_acc 0.8333 test_acc 0.7000
epoch 810 loss 0.2888 train_acc 0.8500 test_acc 0.7333
epoch 811 loss 0.2873 train_acc 0.8583 test_acc 0.7333
epoch 812 loss 0.2858 train_acc 0.8583 test_acc 0.7333
epoch 813 loss 0.2844 train_acc 0.8583 test_acc 0.7333
epoch 814 loss 0.2824 train_acc 0.8667 test_acc 0.7333
epoch 815 loss 0.2808 train_acc 0.8667 test_acc 0.7333
epoch 816 loss 0.2790 train_acc 0.8667 test_acc 0.7333
epoch 817 loss 0.2772 train_acc 0.8667 test_acc 0.7333
epoch 818 loss 0.2752 train_acc 0.8667 test_acc 0.7667
epoch 819 loss 0.2729 train_acc 0.8750 test_acc 0.7667
epoch 820 loss 0.2713 train_acc 0.8750 test_acc 0.7667
epoch 821 loss 0.2689 train_acc 0.8917 test_acc 0.8000
epoch 822 

epoch 962 loss 0.2246 train_acc 0.9333 test_acc 0.9000
epoch 963 loss 0.2257 train_acc 0.9250 test_acc 0.9000
epoch 964 loss 0.2269 train_acc 0.9250 test_acc 0.9000
epoch 965 loss 0.2279 train_acc 0.9250 test_acc 0.9000
epoch 966 loss 0.2292 train_acc 0.9250 test_acc 0.9000
epoch 967 loss 0.2304 train_acc 0.9250 test_acc 0.9000
epoch 968 loss 0.2315 train_acc 0.9250 test_acc 0.9000
epoch 969 loss 0.2326 train_acc 0.9167 test_acc 0.9000
epoch 970 loss 0.2336 train_acc 0.9167 test_acc 0.8667
epoch 971 loss 0.2346 train_acc 0.9167 test_acc 0.8333
epoch 972 loss 0.2354 train_acc 0.9083 test_acc 0.8333
epoch 973 loss 0.2364 train_acc 0.9083 test_acc 0.8333
epoch 974 loss 0.2372 train_acc 0.9083 test_acc 0.8333
epoch 975 loss 0.2378 train_acc 0.9083 test_acc 0.8333
epoch 976 loss 0.2386 train_acc 0.9083 test_acc 0.8333
epoch 977 loss 0.2392 train_acc 0.9083 test_acc 0.8333
epoch 978 loss 0.2398 train_acc 0.9083 test_acc 0.8333
epoch 979 loss 0.2402 train_acc 0.9083 test_acc 0.8333
epoch 980 