In [2]:
%matplotlib inline
import torch
import torchvision
from IPython import display
from matplotlib import pyplot as plt
import numpy as np
import random

In [3]:
def use_svg_display():
    # 用矢量图显示
    display.set_matplotlib_formats('svg')

def set_figsize(figsize=(3.5, 2.5)):
    use_svg_display()
    # 设置图的尺寸
    plt.rcParams['figure.figsize'] = figsize

In [4]:
!ls ./data/iris

iris_test.csv     iris_training.csv


In [5]:
class IrisDataSet(torch.utils.data.Dataset):
    def __init__(self, csv_path, sep=','):
        with open(csv_path, 'r') as f:
            lines = f.readlines()[1:]
            datas = [line.strip().split(sep) for line in lines]
            datas_np = np.array(datas, dtype=np.float32)
            self.featurs = torch.tensor(datas_np[:,:4], dtype=torch.float32)
            self.labels = torch.tensor(datas_np[:,-1], dtype=torch.float32)
    
    def __len__(self):
        return len(self.featurs)
    
    def __getitem__(self, index):
        return self.featurs[index], self.labels[index]
    

In [39]:
batch_size = 120
train_csv_path = './data/iris/iris_training.csv'
train_data_set = IrisDataSet(train_csv_path)
train_data_iter = torch.utils.data.DataLoader(train_data_set, batch_size=batch_size, shuffle=True)

test_csv_path = './data/iris/iris_test.csv'
test_data_set = IrisDataSet(test_csv_path)
test_data_iter = torch.utils.data.DataLoader(test_data_set, batch_size=batch_size)

print(len(train_data_set), len(test_data_set))

120 30


### From scratch

In [11]:
w = torch.randn((4,3), dtype=torch.float32)
b = torch.zeros((1,3), dtype=torch.float32)

w.requires_grad_(True)
b.requires_grad_(True)

w, b

(tensor([[-0.9521, -0.8950, -0.9914],
         [-1.2383, -1.0109,  1.3255],
         [ 0.4372,  1.4624,  0.9612],
         [-0.5675, -1.9333, -0.2380]], requires_grad=True),
 tensor([[0., 0., 0.]], requires_grad=True))

In [12]:
def softmax(x):
    x = x - x.max(dim=1, keepdim=True).values
    return x.exp() / x.exp().sum(dim=1, keepdim=True)

In [13]:
def cross_entropy_loss(y, y_hat):
    """
    y=[0,1,2]
    y_hat = [
        [0.7,0.2,0.1],
        [0.2,0.5,0.3],
        [0.1,0.1,0.8],
    ]
    """
    y_hat = softmax(y_hat)
    return -torch.log(
        y_hat.gather(dim=1, index=y.type(torch.long).view(-1,1))+1e-5
    ).sum()

def sgd(params, lr, bs):
    for param in params:
        param.data -= param.grad / bs * lr
        
def net(x):
    return torch.mm(x, w) + b
        
def accuracy(y_hat, y):
    return (y_hat.argmax(dim=1) == y).float().mean().item()

def evaluate_accuracy(data_iter, net):
    acc_sum, n = 0.0, 0
    for X, y in data_iter:
        acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()
        n += y.shape[0]
    return acc_sum / n

In [14]:
lr = 0.1
num_epochs = 20
params = [w, b]
loss = cross_entropy_loss
for epoch in range(1, num_epochs+1):
    train_l_sum, train_acc_sum, n = 0.0, 0.0, 0
    for x, y in train_data_iter:
        for param in params:
            if param.grad is not None:
                param.grad.data.zero_()
        y_hat = net(x)
        l = cross_entropy_loss(y, y_hat)
        l.backward()
        sgd(params, lr, batch_size)
        
        train_l_sum += l.item()
        train_acc_sum += (y_hat.argmax(dim=1)==y).float().sum().item()
        n += y.shape[0]
    print('epoch %d loss %.4f train_acc %.4f test_acc %.4f' %
          (epoch, train_l_sum/n, train_acc_sum/n, evaluate_accuracy(test_data_iter, net)))
        

epoch 1 loss 2.7619 train_acc 0.3500 test_acc 0.3000
epoch 2 loss 1.2208 train_acc 0.3083 test_acc 0.4000
epoch 3 loss 1.0305 train_acc 0.4500 test_acc 0.6667
epoch 4 loss 0.9041 train_acc 0.6250 test_acc 0.6667
epoch 5 loss 0.7688 train_acc 0.6750 test_acc 0.7333
epoch 6 loss 0.6900 train_acc 0.7500 test_acc 0.7333
epoch 7 loss 0.6026 train_acc 0.7250 test_acc 0.6667
epoch 8 loss 0.5608 train_acc 0.7750 test_acc 0.8000
epoch 9 loss 0.4720 train_acc 0.8417 test_acc 0.7667
epoch 10 loss 0.5459 train_acc 0.7750 test_acc 0.5667
epoch 11 loss 0.5393 train_acc 0.7500 test_acc 0.8000
epoch 12 loss 0.6139 train_acc 0.7083 test_acc 0.7000
epoch 13 loss 0.4216 train_acc 0.8500 test_acc 0.7667
epoch 14 loss 0.4237 train_acc 0.8500 test_acc 0.6000
epoch 15 loss 0.6393 train_acc 0.7333 test_acc 0.7667
epoch 16 loss 0.5694 train_acc 0.7417 test_acc 0.6000
epoch 17 loss 0.4653 train_acc 0.7333 test_acc 0.8000
epoch 18 loss 0.4732 train_acc 0.7417 test_acc 0.5333
epoch 19 loss 0.5131 train_acc 0.7250

### By torch

In [61]:
class LinearNet(torch.nn.Module):
    def __init__(self, input, hidden, output):
        super().__init__()
        self.linear1 = torch.nn.Linear(input, hidden)
        self.linear2 = torch.nn.Linear(hidden, output)
        
    def forward(self, x):
        x = torch.nn.ReLU()(self.linear1(x))
        x = self.linear2(x)
        return x

In [104]:
import wandb
wandb.init(project="test-project", name="test5", reinit=True)

W&B Run: https://app.wandb.ai/ringares/test-project/runs/srntbnxu

In [105]:
lr = 1e-3
loss = torch.nn.CrossEntropyLoss()

net = LinearNet(4, 4, 3)

optimizer = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=5e-2)
torch.nn.init.normal_(net.linear1.weight, 0., 0.1)
torch.nn.init.constant_(net.linear1.bias, 0.)
torch.nn.init.normal_(net.linear2.weight, 0., 0.1)
torch.nn.init.constant_(net.linear2.bias, 0.)

Parameter containing:
tensor([0., 0., 0.], requires_grad=True)

In [106]:
wandb.watch(net)
net.train()

LinearNet(
  (linear1): Linear(in_features=4, out_features=4, bias=True)
  (linear2): Linear(in_features=4, out_features=3, bias=True)
)

In [107]:
def train(num_epochs, batch_size, net, loss, train_data_iter, test_data_iter, optimizer):
    for epoch in range(1, num_epochs+1):
        train_l_sum, train_acc_sum, n = 0.0, 0.0, 0
        for x, y in train_data_iter:
            for param in params:
                if param.grad is not None:
                    param.grad.data.zero_()
            y_hat = net(x)
            l = cross_entropy_loss(y, y_hat)
            l.backward()
            optimizer.step()

            train_l_sum += l.item()
            train_acc_sum += (y_hat.argmax(dim=1)==y).float().sum().item()
            n += y.shape[0]
        wandb.log({"loss": train_l_sum/n, 
                   "train_acc":train_acc_sum/n, 
                   "evaluate_accuracy":evaluate_accuracy(test_data_iter, net)})
        print('epoch %d loss %.4f train_acc %.4f test_acc %.4f' %
              (epoch, train_l_sum/n, train_acc_sum/n, evaluate_accuracy(test_data_iter, net)))

In [108]:
train(
    num_epochs=1100,
    batch_size=batch_size,
    net = net,
    loss = loss,
    train_data_iter=train_data_iter,
    test_data_iter=test_data_iter,
    optimizer=optimizer
)

epoch 1 loss 1.1179 train_acc 0.6500 test_acc 0.7333
epoch 2 loss 1.1166 train_acc 0.6500 test_acc 0.7333
epoch 3 loss 1.1154 train_acc 0.6500 test_acc 0.7333
epoch 4 loss 1.1143 train_acc 0.6500 test_acc 0.7333
epoch 5 loss 1.1132 train_acc 0.6417 test_acc 0.7333
epoch 6 loss 1.1121 train_acc 0.6417 test_acc 0.7333
epoch 7 loss 1.1109 train_acc 0.6417 test_acc 0.7333
epoch 8 loss 1.1097 train_acc 0.6333 test_acc 0.7333
epoch 9 loss 1.1085 train_acc 0.6333 test_acc 0.7333
epoch 10 loss 1.1073 train_acc 0.6333 test_acc 0.7333
epoch 11 loss 1.1061 train_acc 0.6333 test_acc 0.7333
epoch 12 loss 1.1049 train_acc 0.6333 test_acc 0.7000
epoch 13 loss 1.1036 train_acc 0.6167 test_acc 0.7000
epoch 14 loss 1.1024 train_acc 0.6167 test_acc 0.7000
epoch 15 loss 1.1012 train_acc 0.6083 test_acc 0.6667
epoch 16 loss 1.1000 train_acc 0.6000 test_acc 0.5667
epoch 17 loss 1.0988 train_acc 0.6000 test_acc 0.5667
epoch 18 loss 1.0977 train_acc 0.5917 test_acc 0.5667
epoch 19 loss 1.0966 train_acc 0.5333

epoch 162 loss 0.7154 train_acc 0.7000 test_acc 0.5333
epoch 163 loss 0.7122 train_acc 0.7000 test_acc 0.5333
epoch 164 loss 0.7090 train_acc 0.7000 test_acc 0.5333
epoch 165 loss 0.7058 train_acc 0.7000 test_acc 0.5333
epoch 166 loss 0.7027 train_acc 0.7000 test_acc 0.5333
epoch 167 loss 0.6996 train_acc 0.7000 test_acc 0.5333
epoch 168 loss 0.6965 train_acc 0.7000 test_acc 0.5333
epoch 169 loss 0.6934 train_acc 0.7000 test_acc 0.5333
epoch 170 loss 0.6903 train_acc 0.7000 test_acc 0.5333
epoch 171 loss 0.6872 train_acc 0.7000 test_acc 0.5333
epoch 172 loss 0.6841 train_acc 0.7000 test_acc 0.5333
epoch 173 loss 0.6810 train_acc 0.7000 test_acc 0.5333
epoch 174 loss 0.6780 train_acc 0.7000 test_acc 0.5333
epoch 175 loss 0.6749 train_acc 0.7000 test_acc 0.5333
epoch 176 loss 0.6718 train_acc 0.7000 test_acc 0.5333
epoch 177 loss 0.6687 train_acc 0.7000 test_acc 0.5333
epoch 178 loss 0.6656 train_acc 0.7000 test_acc 0.5333
epoch 179 loss 0.6625 train_acc 0.7000 test_acc 0.5333
epoch 180 

epoch 330 loss 0.4193 train_acc 0.7583 test_acc 0.5667
epoch 331 loss 0.4180 train_acc 0.7583 test_acc 0.6000
epoch 332 loss 0.4168 train_acc 0.7583 test_acc 0.6000
epoch 333 loss 0.4155 train_acc 0.7583 test_acc 0.6000
epoch 334 loss 0.4142 train_acc 0.7667 test_acc 0.6333
epoch 335 loss 0.4129 train_acc 0.7667 test_acc 0.6333
epoch 336 loss 0.4116 train_acc 0.7667 test_acc 0.6333
epoch 337 loss 0.4103 train_acc 0.7667 test_acc 0.6333
epoch 338 loss 0.4090 train_acc 0.7667 test_acc 0.6333
epoch 339 loss 0.4077 train_acc 0.7750 test_acc 0.6333
epoch 340 loss 0.4064 train_acc 0.7917 test_acc 0.6333
epoch 341 loss 0.4051 train_acc 0.7917 test_acc 0.6667
epoch 342 loss 0.4038 train_acc 0.7917 test_acc 0.6667
epoch 343 loss 0.4026 train_acc 0.7917 test_acc 0.6667
epoch 344 loss 0.4013 train_acc 0.7917 test_acc 0.6667
epoch 345 loss 0.4000 train_acc 0.7917 test_acc 0.6667
epoch 346 loss 0.3987 train_acc 0.7917 test_acc 0.6667
epoch 347 loss 0.3975 train_acc 0.8000 test_acc 0.6667
epoch 348 

epoch 488 loss 0.2331 train_acc 0.9667 test_acc 0.9667
epoch 489 loss 0.2318 train_acc 0.9667 test_acc 0.9667
epoch 490 loss 0.2305 train_acc 0.9667 test_acc 0.9667
epoch 491 loss 0.2293 train_acc 0.9667 test_acc 0.9667
epoch 492 loss 0.2281 train_acc 0.9667 test_acc 0.9667
epoch 493 loss 0.2269 train_acc 0.9750 test_acc 0.9667
epoch 494 loss 0.2258 train_acc 0.9833 test_acc 0.9667
epoch 495 loss 0.2248 train_acc 0.9833 test_acc 0.9667
epoch 496 loss 0.2239 train_acc 0.9833 test_acc 0.9667
epoch 497 loss 0.2230 train_acc 0.9833 test_acc 0.9667
epoch 498 loss 0.2223 train_acc 0.9833 test_acc 0.9667
epoch 499 loss 0.2216 train_acc 0.9833 test_acc 0.9667
epoch 500 loss 0.2210 train_acc 0.9833 test_acc 0.9667
epoch 501 loss 0.2205 train_acc 0.9833 test_acc 0.9667
epoch 502 loss 0.2200 train_acc 0.9750 test_acc 0.9667
epoch 503 loss 0.2197 train_acc 0.9750 test_acc 0.9667
epoch 504 loss 0.2194 train_acc 0.9750 test_acc 0.9667
epoch 505 loss 0.2191 train_acc 0.9750 test_acc 0.9667
epoch 506 

epoch 648 loss 0.1531 train_acc 0.9333 test_acc 0.9667
epoch 649 loss 0.1476 train_acc 0.9417 test_acc 0.9667
epoch 650 loss 0.1422 train_acc 0.9500 test_acc 0.9667
epoch 651 loss 0.1371 train_acc 0.9500 test_acc 0.9333
epoch 652 loss 0.1324 train_acc 0.9500 test_acc 0.9333
epoch 653 loss 0.1280 train_acc 0.9583 test_acc 0.9333
epoch 654 loss 0.1240 train_acc 0.9583 test_acc 0.9667
epoch 655 loss 0.1205 train_acc 0.9583 test_acc 0.9667
epoch 656 loss 0.1175 train_acc 0.9583 test_acc 0.9667
epoch 657 loss 0.1151 train_acc 0.9750 test_acc 0.9667
epoch 658 loss 0.1132 train_acc 0.9750 test_acc 0.9667
epoch 659 loss 0.1120 train_acc 0.9833 test_acc 0.9667
epoch 660 loss 0.1113 train_acc 0.9833 test_acc 0.9667
epoch 661 loss 0.1113 train_acc 0.9833 test_acc 0.9667
epoch 662 loss 0.1120 train_acc 0.9833 test_acc 0.9667
epoch 663 loss 0.1132 train_acc 0.9833 test_acc 0.9667
epoch 664 loss 0.1152 train_acc 0.9750 test_acc 0.9667
epoch 665 loss 0.1177 train_acc 0.9667 test_acc 0.9667
epoch 666 

epoch 813 loss 0.0942 train_acc 0.9667 test_acc 0.9667
epoch 814 loss 0.1019 train_acc 0.9667 test_acc 0.9667
epoch 815 loss 0.1108 train_acc 0.9667 test_acc 0.9667
epoch 816 loss 0.1212 train_acc 0.9667 test_acc 0.9667
epoch 817 loss 0.1329 train_acc 0.9583 test_acc 0.9667
epoch 818 loss 0.1459 train_acc 0.9500 test_acc 0.9667
epoch 819 loss 0.1603 train_acc 0.9333 test_acc 0.9000
epoch 820 loss 0.1760 train_acc 0.9333 test_acc 0.9000
epoch 821 loss 0.1931 train_acc 0.9083 test_acc 0.9000
epoch 822 loss 0.2113 train_acc 0.8917 test_acc 0.9000
epoch 823 loss 0.2308 train_acc 0.8917 test_acc 0.9000
epoch 824 loss 0.2515 train_acc 0.8917 test_acc 0.8667
epoch 825 loss 0.2733 train_acc 0.8833 test_acc 0.8333
epoch 826 loss 0.2962 train_acc 0.8750 test_acc 0.8333
epoch 827 loss 0.3200 train_acc 0.8667 test_acc 0.8000
epoch 828 loss 0.3448 train_acc 0.8667 test_acc 0.8000
epoch 829 loss 0.3703 train_acc 0.8583 test_acc 0.8000
epoch 830 loss 0.3966 train_acc 0.8583 test_acc 0.8000
epoch 831 

epoch 983 loss 0.7358 train_acc 0.7583 test_acc 0.8667
epoch 984 loss 0.6879 train_acc 0.7667 test_acc 0.8667
epoch 985 loss 0.6404 train_acc 0.7833 test_acc 0.8667
epoch 986 loss 0.5936 train_acc 0.8000 test_acc 0.9000
epoch 987 loss 0.5476 train_acc 0.8083 test_acc 0.9333
epoch 988 loss 0.5029 train_acc 0.8250 test_acc 0.9333
epoch 989 loss 0.4597 train_acc 0.8333 test_acc 0.9333
epoch 990 loss 0.4181 train_acc 0.8500 test_acc 0.9333
epoch 991 loss 0.3785 train_acc 0.8583 test_acc 0.9333
epoch 992 loss 0.3409 train_acc 0.8750 test_acc 0.9333
epoch 993 loss 0.3056 train_acc 0.8917 test_acc 0.9333
epoch 994 loss 0.2725 train_acc 0.8917 test_acc 0.9333
epoch 995 loss 0.2419 train_acc 0.9000 test_acc 0.9333
epoch 996 loss 0.2138 train_acc 0.9083 test_acc 0.9333
epoch 997 loss 0.1881 train_acc 0.9083 test_acc 0.9667
epoch 998 loss 0.1649 train_acc 0.9250 test_acc 0.9667
epoch 999 loss 0.1442 train_acc 0.9250 test_acc 0.9667
epoch 1000 loss 0.1261 train_acc 0.9250 test_acc 0.9667
epoch 100

In [72]:
import os
torch.save(net.state_dict(), os.path.join(wandb.run.dir, "model.h5"))
wandb.save('model.h5')

[]