In [1]:
import torch
from tqdm import trange

import utils as ut
import evaluate as eva

In [2]:
class SVM():
    def __init__(self, kwargs):
        self.model_name = 'SVM'
        self.W = torch.randn(kwargs['feature_dim']+1, kwargs['class_num'])
        self.N = kwargs['class_num']
        self.init_weight()
        
    def init_weight(self):
        torch.nn.init.xavier_normal_(self.W)
        
    def train(self, X, Y, alpha=0.01, reg=2.5e4, vec=True):
        """
        功能: 完成训练过程，包括(1)求解损失, 计算梯度. (2) 正则化，计算梯度，(3)更新参数 
        输入:
            X(Tensor):(N, K:3*32*32+1)
            Y(Tensor):(N)
            alpha(float):                   # 学习率
            reg(float):                     # 正则化系数
        输出:
            L(int):(1)                      # 损失，范围给主程序以打印显示        
        """
        
        # 计算梯度与正则化
        if vec:
            L, dW = self.cal_dw_with_vec(X, Y, reg)
        else:
            L, dW = self.cal_dw_with_loop(X, Y, reg)
            
        # 更新参数
        self.W -= alpha * dW
        return L
    
    def cal_dw_with_loop(self, X, Y, reg):
        """
        功能： 计算损失和梯度
        输入:
            X(Tensor):(N, K:3*32*32+1)
            Y(Tensor):(N)
            reg(float):                    # 正则化系数
        输出:
            L(int): 1                      # 损失               
            dW(Tensor):(K+1,C)             # 参数梯度       
        """
        L = 0.0
        N = X.size(0)
        F, C = self.W.size()
        dW = torch.zeros(F, C)
        
        # (1) 求解损失
        for idx, Xi in enumerate(X):
            yi = Y[idx]
            scores = Xi.matmul(self.W)
            syi = scores[yi]
            for j in range(self.N):
                if j == yi:
                    continue
                sj = scores[j]
                if syi - sj - 1 < 0:
                    L += (sj - syi + 1).item()
                    dW[:,j] += Xi.t()
                    dW[:,yi] -= Xi.t()
        
        # (2) 正则化
        L = L / N +  0.5*reg*torch.sum(torch.pow(self.W, 2)).item()
        dW = dW / N +  reg*self.W
        
        return L, dW
    
    def cal_dw_with_vec(self, X, Y, reg):
        """
        功能： 计算损失和梯度
        输入:
            X(Tensor):(N, K:3*32*32+1)
            Y(Tensor):(N)
            reg(float):                    # 正则化系数
        输出:
            L(int): 1                      # 损失               
            dW(Tensor):(K+1,C)             # 参数梯度       
        """
        
        N = X.size(0)
        F, C = self.W.size()
        
        score = X.matmul(self.W)                                       # (N, C)
        correct = score[range(N), Y.tolist()].unsqueeze(1)             # (N, 1)
        score = torch.relu(score-correct+1)                            # (N, C)
        score[range(N), Y.tolist()] = 0
        
        L = torch.sum(score).item()
        L = L / N +  0.5*reg*torch.sum(torch.pow(self.W, 2)).item()
        
        
        dW = torch.zeros(F, C)
        mask = torch.zeros(N, C)
        mask[score>0] = 1                                              # (N,C)
        mask[range(N), Y.tolist()] = -torch.sum(mask, 1)               # (N,C)
        dW = X.t().matmul(mask)                                        # (F,C)

        dW = dW / N + reg*self.W
        return L, dW
        
    def predict(self, X):
        """
        功能: 预测输入数据标签
        输入:
            X(Tensor): (N, 3*32*32)
        输出:
            labels(Tensor): (N)
        """
        S = X.matmul(self.W)  # (N, C)
        return torch.max(S, 1)[1]
        
        

In [3]:
train_set, test_set = ut.data_load('./data')

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [4]:
batch_num = 200
opt = {"feature_dim":3*32*32, "class_num":10}
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_num, shuffle=True, num_workers=4)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_num, shuffle=False, num_workers=4)

In [None]:
def train(alpha, reg, epoches):
    """
    功能：完成训练过程
    输入:
        alpha(int):(1)     # 学习率
        reg(int):(1)       # 正则化系数
        epoches(int):(1)   # 迭代次数
    输出:
        svmEr(class) 训练好的模型
        alpha
        reg
    """
    svmEr = SVM(opt)
    for epoch in range(epoches):
        train_data_interator = enumerate(train_loader)
        train_steps = test_steps = (len(train_set) + batch_num - 1) // batch_num

        t = trange(train_steps)
        loss_avg = ut.RunningAverage()
        print("epoch:{}".format(epoch))
        for i in t:
            idx, data = next(train_data_interator)
            X_batch, Y_batch = data
            X_batch = X_batch.view(X_batch.size(0), -1)
            X_batch = torch.cat((torch.ones(X_batch.size(0),1), X_batch), 1)   

            loss = svmEr.train(X_batch, Y_batch, alpha=1e-4, reg=1, vec=True)
            loss_avg.update(loss)
            t.set_postfix(loss='{:05.3f}/{:05.3f}'.format(loss_avg(), loss))
        print(loss_avg())
    return svmEr

def evaluate(svmEr):
    """
    功能：使用训练好的模型进行预测，并评测结果
    输入: 
        svmEr(class) 训练好的模型
    输出: 
        acc(int):(1) 模型准确率
    """
    test_data_interator = enumerate(test_loader)
    test_steps = test_steps = (len(test_set) + batch_num - 1) // batch_num

    t = trange(test_steps)
    Y_predict = []
    Y_true = []
    for i in t:
        idx, data = next(test_data_interator)
        X_batch, Y_batch = data
        Y_true.extend(Y_batch.tolist())
        X_batch = X_batch.view(X_batch.size(0), -1)
        X_batch = torch.cat((torch.ones(X_batch.size(0),1), X_batch), 1)   

        y = svmEr.predict(X_batch)
        Y_predict.extend(y.tolist())
        
    Y_predict = torch.LongTensor(Y_predict)
    Y_true = torch.LongTensor(Y_true)
    acc = torch.sum(Y_predict == Y_true).item() /len(Y_predict)
    
    return acc
    

In [None]:
lrs = [1e-2, 1e-3, 1e-4, 1e-5]
reg_strs = [0, 1, 10, 100, 1000]

result = {}

best_lr = None
best_reg = None
best_svm = None
best_acc = -1

for lr in lrs:
    for reg in reg_strs:
        svmEr = train(lr, reg, 25)
        acc = evaluate(svmEr)
        print("lr:{}; reg:{}; acc:{}".format(lr, reg, acc))
        if acc > best_acc:
            best_lr = lr
            best_reg = reg
            best_svm = svmEr
        result[(lr, reg)] = acc
print("the best: lr:{}; reg:{}; acc:{}".format(best_lr, best_reg, best_acc))

  2%|▏         | 6/250 [00:00<00:04, 59.65it/s, loss=19.626/19.076]

epoch:0


100%|██████████| 250/250 [00:01<00:00, 128.34it/s, loss=17.916/16.739]

17.916246122131355



  4%|▎         | 9/250 [00:00<00:02, 87.78it/s, loss=16.962/16.527]

epoch:1


100%|██████████| 250/250 [00:02<00:00, 121.80it/s, loss=16.359/15.812]

16.35899254180909



  4%|▍         | 11/250 [00:00<00:02, 107.86it/s, loss=15.887/15.926]

epoch:2


100%|██████████| 250/250 [00:02<00:00, 122.62it/s, loss=15.422/14.556]

15.422245274658213



  4%|▍         | 11/250 [00:00<00:02, 107.69it/s, loss=15.068/14.713]

epoch:3


100%|██████████| 250/250 [00:01<00:00, 126.77it/s, loss=14.724/14.656]


14.72368852920533


  4%|▎         | 9/250 [00:00<00:02, 85.51it/s, loss=14.483/14.247]

epoch:4


100%|██████████| 250/250 [00:02<00:00, 124.18it/s, loss=14.140/13.464]

14.140059371109006



  4%|▍         | 11/250 [00:00<00:02, 104.50it/s, loss=13.871/14.231]

epoch:5


100%|██████████| 250/250 [00:01<00:00, 126.94it/s, loss=13.625/12.891]

13.62452166725158



  4%|▍         | 11/250 [00:00<00:02, 108.54it/s, loss=13.250/13.248]

epoch:6


100%|██████████| 250/250 [00:02<00:00, 121.78it/s, loss=13.159/13.136]


13.159336671218869


  4%|▎         | 9/250 [00:00<00:02, 87.76it/s, loss=12.820/12.731]

epoch:7


100%|██████████| 250/250 [00:02<00:00, 118.04it/s, loss=12.731/12.259]


12.731060939178466


  4%|▍         | 10/250 [00:00<00:02, 99.92it/s, loss=12.546/12.278]

epoch:8


100%|██████████| 250/250 [00:01<00:00, 126.49it/s, loss=12.335/12.217]


12.335066028442393


  4%|▎         | 9/250 [00:00<00:02, 84.67it/s, loss=12.044/12.404]

epoch:9


100%|██████████| 250/250 [00:02<00:00, 124.61it/s, loss=11.966/11.266]

11.965536507186881



  4%|▍         | 10/250 [00:00<00:02, 98.35it/s, loss=11.786/12.106]

epoch:10


100%|██████████| 250/250 [00:02<00:00, 119.00it/s, loss=11.621/11.060]

11.62132484992981



  4%|▍         | 10/250 [00:00<00:02, 98.08it/s, loss=11.517/11.710]

epoch:11


100%|██████████| 250/250 [00:02<00:00, 124.29it/s, loss=11.296/11.574]

11.296423723754872



  4%|▍         | 10/250 [00:00<00:02, 96.99it/s, loss=11.003/11.242]

epoch:12


100%|██████████| 250/250 [00:02<00:00, 122.79it/s, loss=10.993/11.024]


10.992526661071778


  4%|▎         | 9/250 [00:00<00:02, 83.33it/s, loss=10.799/10.465]

epoch:13


100%|██████████| 250/250 [00:01<00:00, 126.51it/s, loss=10.705/11.083]

10.705010594329833



  4%|▎         | 9/250 [00:00<00:02, 83.94it/s, loss=10.579/10.664]

epoch:14


100%|██████████| 250/250 [00:01<00:00, 125.87it/s, loss=10.436/9.917]

10.435513696746824



  4%|▎         | 9/250 [00:00<00:02, 89.82it/s, loss=10.275/10.178]

epoch:15


100%|██████████| 250/250 [00:02<00:00, 125.10it/s, loss=10.178/10.164]

10.17831240562439



  4%|▍         | 11/250 [00:00<00:02, 106.64it/s, loss=10.062/10.335]

epoch:16


100%|██████████| 250/250 [00:02<00:00, 123.30it/s, loss=9.938/9.928]

9.938086416778564



  4%|▎         | 9/250 [00:00<00:02, 88.91it/s, loss=9.653/9.865]

epoch:17


100%|██████████| 250/250 [00:02<00:00, 120.55it/s, loss=9.710/9.548]

9.709956866836546



  4%|▎         | 9/250 [00:00<00:02, 89.13it/s, loss=9.561/10.212]

epoch:18


100%|██████████| 250/250 [00:02<00:00, 123.01it/s, loss=9.493/9.535]


9.493302478179936


  4%|▍         | 10/250 [00:00<00:02, 98.41it/s, loss=9.395/9.489] 

epoch:19


100%|██████████| 250/250 [00:02<00:00, 123.30it/s, loss=9.288/9.249]

9.288450702667237



  4%|▍         | 10/250 [00:00<00:02, 86.46it/s, loss=9.192/8.905]

epoch:20


100%|██████████| 250/250 [00:02<00:00, 123.49it/s, loss=9.095/9.419]


9.094563897895812


  4%|▍         | 11/250 [00:00<00:02, 106.33it/s, loss=9.048/8.729]

epoch:21


100%|██████████| 250/250 [00:02<00:00, 121.40it/s, loss=8.911/9.272]

8.910736973114023



  4%|▍         | 10/250 [00:00<00:02, 94.11it/s, loss=8.760/8.737]

epoch:22


100%|██████████| 250/250 [00:02<00:00, 119.10it/s, loss=8.737/8.125]

8.73744777889252



  4%|▎         | 9/250 [00:00<00:02, 88.94it/s, loss=8.669/9.182]

epoch:23


100%|██████████| 250/250 [00:01<00:00, 127.56it/s, loss=8.570/8.196]

8.570184391937254



  4%|▍         | 10/250 [00:00<00:02, 97.92it/s, loss=8.413/8.295]

epoch:24


100%|██████████| 250/250 [00:01<00:00, 125.77it/s, loss=8.415/8.484]


8.414628829498287


100%|██████████| 50/50 [00:00<00:00, 124.10it/s]


lr:0.01; reg:0; acc:0.35


  2%|▏         | 6/250 [00:00<00:04, 59.96it/s, loss=19.167/18.455]

epoch:0


100%|██████████| 250/250 [00:02<00:00, 117.74it/s, loss=17.816/17.050]


17.815554900207513


  5%|▌         | 13/250 [00:00<00:02, 116.09it/s, loss=16.875/16.814]

epoch:1


100%|██████████| 250/250 [00:02<00:00, 124.62it/s, loss=16.267/15.873]

16.2671507395935



  4%|▍         | 10/250 [00:00<00:02, 99.22it/s, loss=15.679/15.545]

epoch:2


100%|██████████| 250/250 [00:02<00:00, 123.49it/s, loss=15.352/14.546]

15.351803334198003



  4%|▍         | 10/250 [00:00<00:02, 98.55it/s, loss=15.034/14.419]

epoch:3


100%|██████████| 250/250 [00:02<00:00, 120.65it/s, loss=14.670/14.062]

14.669615088043209



  5%|▍         | 12/250 [00:00<00:02, 114.68it/s, loss=14.444/14.898]

epoch:4


100%|██████████| 250/250 [00:02<00:00, 123.14it/s, loss=14.096/13.548]


14.095687815322872


  4%|▍         | 10/250 [00:00<00:02, 97.94it/s, loss=13.866/14.251]

epoch:5


100%|██████████| 250/250 [00:02<00:00, 121.20it/s, loss=13.585/13.280]

13.584698061141962



  4%|▎         | 9/250 [00:00<00:02, 88.82it/s, loss=13.429/13.204]

epoch:6


100%|██████████| 250/250 [00:02<00:00, 122.64it/s, loss=13.121/12.958]

13.120814876632693



  3%|▎         | 8/250 [00:00<00:03, 78.73it/s, loss=12.975/13.468]

epoch:7


100%|██████████| 250/250 [00:02<00:00, 124.36it/s, loss=12.695/12.388]


12.695199081573483


  4%|▍         | 11/250 [00:00<00:02, 109.21it/s, loss=12.410/12.669]

epoch:8


100%|██████████| 250/250 [00:02<00:00, 122.15it/s, loss=12.299/12.102]


12.299339632110602


  4%|▍         | 10/250 [00:00<00:02, 97.06it/s, loss=12.123/12.005]

epoch:9


100%|██████████| 250/250 [00:02<00:00, 123.93it/s, loss=11.932/11.142]


11.932037286682132


  4%|▍         | 10/250 [00:00<00:02, 99.29it/s, loss=11.815/12.076]

epoch:10


100%|██████████| 250/250 [00:02<00:00, 122.17it/s, loss=11.588/11.073]

11.588364110717773



  4%|▍         | 11/250 [00:00<00:02, 108.25it/s, loss=11.378/10.928]

epoch:11


100%|██████████| 250/250 [00:02<00:00, 115.60it/s, loss=11.264/11.060]


11.264279318084705


  4%|▍         | 10/250 [00:00<00:02, 96.48it/s, loss=11.083/10.841]

epoch:12


100%|██████████| 250/250 [00:01<00:00, 125.75it/s, loss=10.962/10.264]

10.962209168014528



  4%|▍         | 11/250 [00:00<00:02, 107.82it/s, loss=10.718/10.597]

epoch:13


100%|██████████| 250/250 [00:01<00:00, 125.18it/s, loss=10.677/9.965]

10.677119106140138



  4%|▎         | 9/250 [00:00<00:02, 87.90it/s, loss=10.585/10.645]

epoch:14


100%|██████████| 250/250 [00:02<00:00, 122.04it/s, loss=10.407/10.308]

10.40736568916321



  4%|▎         | 9/250 [00:00<00:02, 89.07it/s, loss=10.320/10.254]

epoch:15


100%|██████████| 250/250 [00:02<00:00, 124.46it/s, loss=10.152/9.818]

10.152180160598757



  4%|▎         | 9/250 [00:00<00:02, 87.85it/s, loss=10.006/9.442] 

epoch:16


100%|██████████| 250/250 [00:02<00:00, 119.21it/s, loss=9.912/9.854]

9.911775660705565



  4%|▍         | 11/250 [00:00<00:02, 105.42it/s, loss=9.819/9.695]

epoch:17


100%|██████████| 250/250 [00:02<00:00, 124.09it/s, loss=9.686/9.429]


9.686147310333261


  4%|▍         | 11/250 [00:00<00:02, 108.43it/s, loss=9.490/9.036]

epoch:18


100%|██████████| 250/250 [00:01<00:00, 125.74it/s, loss=9.471/9.548]

9.471020856246945



  4%|▍         | 11/250 [00:00<00:02, 107.80it/s, loss=9.376/9.722]

epoch:19


100%|██████████| 250/250 [00:02<00:00, 117.56it/s, loss=9.267/9.332]

9.267187007369992



  4%|▍         | 10/250 [00:00<00:02, 97.52it/s, loss=9.113/8.707]

epoch:20


100%|██████████| 250/250 [00:02<00:00, 124.58it/s, loss=9.074/9.159]

9.07420670124054



  4%|▍         | 10/250 [00:00<00:02, 98.77it/s, loss=8.999/9.593]

epoch:21


100%|██████████| 250/250 [00:01<00:00, 125.85it/s, loss=8.891/8.568]


8.890999304885868


  4%|▎         | 9/250 [00:00<00:02, 88.02it/s, loss=8.833/8.464]

epoch:22


100%|██████████| 250/250 [00:02<00:00, 122.33it/s, loss=8.719/8.996]


8.718575936126708


  4%|▍         | 10/250 [00:00<00:02, 97.46it/s, loss=8.664/8.993]

epoch:23


100%|██████████| 250/250 [00:02<00:00, 119.91it/s, loss=8.554/8.189]


8.55384531528473


  4%|▍         | 11/250 [00:00<00:02, 107.46it/s, loss=8.670/8.790]

epoch:24


100%|██████████| 250/250 [00:02<00:00, 122.15it/s, loss=8.397/8.337]


8.396745621376043


100%|██████████| 50/50 [00:00<00:00, 119.16it/s]

lr:0.01; reg:1; acc:0.35



  4%|▍         | 10/250 [00:00<00:02, 85.62it/s, loss=19.644/19.481]

epoch:0


100%|██████████| 250/250 [00:02<00:00, 121.84it/s, loss=17.969/16.801]

17.969020517578123



  5%|▍         | 12/250 [00:00<00:02, 118.74it/s, loss=16.873/16.700]

epoch:1


100%|██████████| 250/250 [00:01<00:00, 129.32it/s, loss=16.377/15.570]

16.377091684265135



  4%|▎         | 9/250 [00:00<00:02, 85.24it/s, loss=15.896/15.870]

epoch:2


100%|██████████| 250/250 [00:02<00:00, 121.99it/s, loss=15.439/14.289]


15.439497570953368


  4%|▍         | 10/250 [00:00<00:02, 98.39it/s, loss=15.102/14.895]

epoch:3


100%|██████████| 250/250 [00:01<00:00, 125.89it/s, loss=14.744/14.114]

14.744209224548337



  4%|▍         | 10/250 [00:00<00:02, 94.22it/s, loss=14.399/14.042]

epoch:4


100%|██████████| 250/250 [00:02<00:00, 127.99it/s, loss=14.160/14.170]


14.160371120529177


  4%|▍         | 10/250 [00:00<00:02, 96.87it/s, loss=13.837/13.768]

epoch:5


100%|██████████| 250/250 [00:02<00:00, 121.74it/s, loss=13.643/13.413]

13.643439978256223



  4%|▍         | 10/250 [00:00<00:02, 95.82it/s, loss=13.408/13.001]

epoch:6


100%|██████████| 250/250 [00:01<00:00, 126.54it/s, loss=13.176/12.886]

13.176308521728513



  4%|▍         | 10/250 [00:00<00:02, 96.69it/s, loss=13.035/12.776]

epoch:7


100%|██████████| 250/250 [00:02<00:00, 120.91it/s, loss=12.747/12.281]

12.74721038528442



  4%|▍         | 11/250 [00:00<00:02, 108.55it/s, loss=12.527/12.777]

epoch:8


100%|██████████| 250/250 [00:02<00:00, 124.65it/s, loss=12.349/12.204]

12.348974196624752



  4%|▍         | 11/250 [00:00<00:02, 105.74it/s, loss=12.121/12.077]

epoch:9


100%|██████████| 250/250 [00:01<00:00, 127.33it/s, loss=11.979/12.211]

11.979202127990714



  4%|▎         | 9/250 [00:00<00:02, 86.73it/s, loss=11.855/11.696]

epoch:10


100%|██████████| 250/250 [00:01<00:00, 125.35it/s, loss=11.632/11.460]


11.63214548751831


  4%|▍         | 10/250 [00:00<00:02, 96.61it/s, loss=11.533/11.678]

epoch:11


100%|██████████| 250/250 [00:02<00:00, 120.53it/s, loss=11.307/11.119]


11.307030237045288


  4%|▎         | 9/250 [00:00<00:02, 89.05it/s, loss=11.115/11.101]

epoch:12


100%|██████████| 250/250 [00:01<00:00, 132.81it/s, loss=11.002/11.101]

11.001533670196528



  4%|▎         | 9/250 [00:00<00:02, 89.03it/s, loss=10.900/10.879]

epoch:13


100%|██████████| 250/250 [00:02<00:00, 121.09it/s, loss=10.715/10.523]

10.71457227821351



  4%|▍         | 10/250 [00:00<00:02, 96.61it/s, loss=10.630/10.252]

epoch:14


100%|██████████| 250/250 [00:02<00:00, 124.07it/s, loss=10.442/10.351]


10.442303704071035


  4%|▎         | 9/250 [00:00<00:02, 89.36it/s, loss=10.242/10.240]

epoch:15


100%|██████████| 250/250 [00:02<00:00, 121.41it/s, loss=10.187/10.199]

10.18744029663086



  4%|▍         | 10/250 [00:00<00:02, 96.44it/s, loss=10.111/10.343]

epoch:16


100%|██████████| 250/250 [00:01<00:00, 125.81it/s, loss=9.944/9.997]

9.94431532180786



  4%|▍         | 10/250 [00:00<00:02, 96.37it/s, loss=9.745/9.650]

epoch:17


100%|██████████| 250/250 [00:02<00:00, 123.91it/s, loss=9.716/9.661]


9.715758679504392


  4%|▎         | 9/250 [00:00<00:02, 85.95it/s, loss=9.567/9.770]

epoch:18


100%|██████████| 250/250 [00:02<00:00, 120.07it/s, loss=9.499/9.556]

9.49941695266724



  4%|▍         | 10/250 [00:00<00:02, 99.72it/s, loss=9.302/9.030]

epoch:19


100%|██████████| 250/250 [00:02<00:00, 124.29it/s, loss=9.295/9.145]

9.294908153419486



  4%|▎         | 9/250 [00:00<00:02, 87.74it/s, loss=9.241/9.125]

epoch:20


100%|██████████| 250/250 [00:02<00:00, 118.68it/s, loss=9.100/9.518]

9.09957903896332



  4%|▎         | 9/250 [00:00<00:02, 89.87it/s, loss=8.921/8.912]

epoch:21


100%|██████████| 250/250 [00:02<00:00, 123.19it/s, loss=8.916/9.140]


8.916270267791747


  4%|▍         | 10/250 [00:00<00:02, 91.27it/s, loss=8.849/9.194]

epoch:22


100%|██████████| 250/250 [00:02<00:00, 122.42it/s, loss=8.742/9.105]


8.741738113861084


  4%|▍         | 11/250 [00:00<00:02, 106.59it/s, loss=8.637/8.619]

epoch:23


100%|██████████| 250/250 [00:02<00:00, 123.04it/s, loss=8.576/8.549]


8.576164241523744


  4%|▍         | 10/250 [00:00<00:02, 96.82it/s, loss=8.518/8.272]

epoch:24


100%|██████████| 250/250 [00:02<00:00, 124.48it/s, loss=8.418/8.312]


8.418285245666505


100%|██████████| 50/50 [00:00<00:00, 112.47it/s]

lr:0.01; reg:10; acc:0.3487



  4%|▎         | 9/250 [00:00<00:03, 77.26it/s, loss=19.302/19.268]

epoch:0


100%|██████████| 250/250 [00:02<00:00, 119.83it/s, loss=17.923/16.761]


17.9230865510559


  4%|▍         | 11/250 [00:00<00:02, 103.54it/s, loss=16.955/16.633]

epoch:1


100%|██████████| 250/250 [00:02<00:00, 120.77it/s, loss=16.266/15.695]

16.26553041168214



  4%|▎         | 9/250 [00:00<00:02, 84.98it/s, loss=15.698/15.638]

epoch:2


100%|██████████| 250/250 [00:02<00:00, 124.84it/s, loss=15.270/14.888]

15.26991767105103



  4%|▍         | 10/250 [00:00<00:02, 97.16it/s, loss=14.875/14.897]

epoch:3


100%|██████████| 250/250 [00:02<00:00, 121.12it/s, loss=14.561/14.336]

14.560876787719728



  5%|▍         | 12/250 [00:00<00:02, 114.49it/s, loss=14.302/13.822]

epoch:4


100%|██████████| 250/250 [00:02<00:00, 123.37it/s, loss=13.983/14.085]

13.983393429260252



  4%|▎         | 9/250 [00:00<00:02, 88.95it/s, loss=13.793/14.011]

epoch:5


100%|██████████| 250/250 [00:02<00:00, 119.59it/s, loss=13.476/12.644]

13.475838845596313



  4%|▍         | 10/250 [00:00<00:02, 97.71it/s, loss=13.312/12.845]

epoch:6


100%|██████████| 250/250 [00:02<00:00, 121.88it/s, loss=13.018/12.764]

13.01848428672791



  4%|▍         | 10/250 [00:00<00:02, 98.85it/s, loss=12.899/13.073]

epoch:7


100%|██████████| 250/250 [00:02<00:00, 123.17it/s, loss=12.600/12.540]

12.600032804183966



  4%|▍         | 11/250 [00:00<00:02, 108.53it/s, loss=12.229/11.619]

epoch:8


100%|██████████| 250/250 [00:01<00:00, 129.02it/s, loss=12.211/12.060]

12.210707184753415



  4%|▎         | 9/250 [00:00<00:02, 89.15it/s, loss=12.062/11.653]

epoch:9


100%|██████████| 250/250 [00:02<00:00, 118.54it/s, loss=11.849/11.391]


11.849416738967893


  4%|▎         | 9/250 [00:00<00:02, 89.84it/s, loss=11.661/11.352]

epoch:10


100%|██████████| 250/250 [00:02<00:00, 121.90it/s, loss=11.510/11.834]


11.510071171951296


  4%|▎         | 9/250 [00:00<00:02, 87.56it/s, loss=11.242/11.149]

epoch:11


100%|██████████| 250/250 [00:02<00:00, 123.05it/s, loss=11.192/10.933]

11.191961520156848



  4%|▍         | 11/250 [00:00<00:02, 103.76it/s, loss=11.112/10.635]

epoch:12


100%|██████████| 250/250 [00:01<00:00, 127.51it/s, loss=10.893/10.760]

10.89263427009582



  4%|▍         | 11/250 [00:00<00:02, 105.09it/s, loss=10.730/10.923]

epoch:13


100%|██████████| 250/250 [00:02<00:00, 120.36it/s, loss=10.611/10.633]

10.611241938858035



  4%|▍         | 10/250 [00:00<00:02, 99.36it/s, loss=10.394/10.013]

epoch:14


100%|██████████| 250/250 [00:02<00:00, 122.96it/s, loss=10.345/9.885]


10.345049800415048


  4%|▍         | 10/250 [00:00<00:02, 96.45it/s, loss=10.291/10.523]

epoch:15


100%|██████████| 250/250 [00:01<00:00, 126.68it/s, loss=10.093/10.497]


10.09336397888183


  4%|▍         | 11/250 [00:00<00:02, 104.98it/s, loss=10.012/9.880] 

epoch:16


100%|██████████| 250/250 [00:01<00:00, 126.04it/s, loss=9.857/10.084]

9.856569747695923



  2%|▏         | 5/250 [00:00<00:04, 49.49it/s, loss=9.785/9.889]

epoch:17


100%|██████████| 250/250 [00:02<00:00, 122.06it/s, loss=9.631/9.345]

9.630577645111076



  5%|▍         | 12/250 [00:00<00:02, 115.90it/s, loss=9.492/9.859]

epoch:18


100%|██████████| 250/250 [00:02<00:00, 121.46it/s, loss=9.418/9.419]

9.418150564422604



  4%|▍         | 10/250 [00:00<00:02, 95.19it/s, loss=9.308/9.041]

epoch:19


100%|██████████| 250/250 [00:02<00:00, 124.91it/s, loss=9.218/8.928]


9.218057455062864


  5%|▍         | 12/250 [00:00<00:02, 116.57it/s, loss=9.178/9.533]

epoch:20


100%|██████████| 250/250 [00:01<00:00, 129.54it/s, loss=9.027/9.196]

9.02686303646087



  4%|▎         | 9/250 [00:00<00:02, 83.84it/s, loss=8.884/8.916]

epoch:21


100%|██████████| 250/250 [00:01<00:00, 126.73it/s, loss=8.844/8.819]


8.844129444732669


  4%|▎         | 9/250 [00:00<00:02, 85.91it/s, loss=8.888/8.930]

epoch:22


100%|██████████| 250/250 [00:02<00:00, 123.89it/s, loss=8.675/8.219]

8.674747849121092



  4%|▍         | 10/250 [00:00<00:02, 98.19it/s, loss=8.694/8.937]

epoch:23


100%|██████████| 250/250 [00:02<00:00, 126.29it/s, loss=8.512/8.383]

8.511701023139956



  4%|▎         | 9/250 [00:00<00:02, 87.11it/s, loss=8.439/8.794]

epoch:24


100%|██████████| 250/250 [00:02<00:00, 120.65it/s, loss=8.358/8.555]


8.357806935462953


100%|██████████| 50/50 [00:00<00:00, 115.40it/s]

lr:0.01; reg:100; acc:0.3522



  4%|▎         | 9/250 [00:00<00:02, 84.99it/s, loss=19.236/18.851]

epoch:0


100%|██████████| 250/250 [00:01<00:00, 125.20it/s, loss=17.468/16.591]

17.468252722473146



  4%|▍         | 10/250 [00:00<00:02, 96.71it/s, loss=16.500/16.337]

epoch:1


100%|██████████| 250/250 [00:01<00:00, 126.10it/s, loss=16.083/15.550]


16.083373804321287


  4%|▎         | 9/250 [00:00<00:02, 87.45it/s, loss=15.639/15.621]

epoch:2


100%|██████████| 250/250 [00:02<00:00, 123.63it/s, loss=15.239/15.074]

15.239159704742429



  4%|▍         | 10/250 [00:00<00:02, 96.66it/s, loss=14.844/14.681]

epoch:3


100%|██████████| 250/250 [00:02<00:00, 122.76it/s, loss=14.578/13.791]

14.577643055725096



  4%|▍         | 10/250 [00:00<00:02, 97.81it/s, loss=14.250/14.285]

epoch:4


100%|██████████| 250/250 [00:01<00:00, 127.44it/s, loss=14.015/13.621]


14.015341282730104


  4%|▍         | 10/250 [00:00<00:02, 94.54it/s, loss=13.735/13.505]

epoch:5


100%|██████████| 250/250 [00:01<00:00, 126.39it/s, loss=13.515/14.041]


13.514606290130615


  4%|▍         | 11/250 [00:00<00:02, 105.09it/s, loss=13.236/13.397]

epoch:6


100%|██████████| 250/250 [00:02<00:00, 123.64it/s, loss=13.058/12.444]

13.058060020141607



  4%|▍         | 11/250 [00:00<00:02, 109.25it/s, loss=12.855/13.048]

epoch:7


100%|██████████| 250/250 [00:02<00:00, 125.94it/s, loss=12.638/12.182]

12.638117534713752



  5%|▍         | 12/250 [00:00<00:02, 114.44it/s, loss=12.476/12.318]

epoch:8


100%|██████████| 250/250 [00:01<00:00, 127.22it/s, loss=12.248/11.980]

12.247604243240362



  4%|▍         | 11/250 [00:00<00:02, 102.91it/s, loss=12.055/11.969]

epoch:9


100%|██████████| 250/250 [00:02<00:00, 120.61it/s, loss=11.885/11.468]

11.885130422210695



  4%|▍         | 10/250 [00:00<00:02, 96.57it/s, loss=11.657/11.919]

epoch:10


100%|██████████| 250/250 [00:01<00:00, 125.42it/s, loss=11.544/11.411]

11.543612541275026



  4%|▎         | 9/250 [00:00<00:02, 88.69it/s, loss=11.282/11.029]

epoch:11


100%|██████████| 250/250 [00:02<00:00, 121.20it/s, loss=11.224/11.546]


11.22448576576233


  4%|▎         | 9/250 [00:00<00:02, 88.08it/s, loss=11.008/10.582]

epoch:12


100%|██████████| 250/250 [00:02<00:00, 121.41it/s, loss=10.925/10.292]


10.925404503478996


  4%|▍         | 10/250 [00:00<00:02, 97.54it/s, loss=10.873/11.459]

epoch:13


100%|██████████| 250/250 [00:01<00:00, 125.14it/s, loss=10.641/10.268]


10.641123804855342


  4%|▍         | 10/250 [00:00<00:02, 96.74it/s, loss=10.448/10.863]

epoch:14


100%|██████████| 250/250 [00:01<00:00, 126.37it/s, loss=10.373/10.602]

10.373445298538208



  4%|▎         | 9/250 [00:00<00:02, 85.44it/s, loss=10.282/10.354]

epoch:15


100%|██████████| 250/250 [00:02<00:00, 123.08it/s, loss=10.122/10.390]

10.12238635894775



  4%|▍         | 11/250 [00:00<00:02, 108.42it/s, loss=10.022/10.226]

epoch:16


100%|██████████| 250/250 [00:02<00:00, 118.69it/s, loss=9.884/10.016]

9.88352467338562



  4%|▎         | 9/250 [00:00<00:02, 87.30it/s, loss=9.768/10.001]

epoch:17


100%|██████████| 250/250 [00:02<00:00, 118.63it/s, loss=9.658/9.824]

9.65782411117553



  4%|▍         | 11/250 [00:00<00:02, 109.10it/s, loss=9.438/9.021]

epoch:18


100%|██████████| 250/250 [00:02<00:00, 123.91it/s, loss=9.444/9.556]

9.443991833953858



  4%|▍         | 11/250 [00:00<00:02, 109.14it/s, loss=9.333/9.178]

epoch:19


100%|██████████| 250/250 [00:02<00:00, 121.90it/s, loss=9.243/9.073]

9.242528081741336



  5%|▍         | 12/250 [00:00<00:02, 116.64it/s, loss=9.116/9.115]

epoch:20


100%|██████████| 250/250 [00:01<00:00, 127.31it/s, loss=9.050/8.620]

9.049820881271362



  4%|▍         | 11/250 [00:00<00:02, 104.55it/s, loss=8.982/8.931]

epoch:21


100%|██████████| 250/250 [00:02<00:00, 123.67it/s, loss=8.869/8.055]

8.869377263755792



  4%|▎         | 9/250 [00:00<00:02, 84.33it/s, loss=8.673/8.389]

epoch:22


100%|██████████| 250/250 [00:01<00:00, 126.07it/s, loss=8.696/8.521]

8.69579874507904



  4%|▎         | 9/250 [00:00<00:02, 83.70it/s, loss=8.700/8.785]

epoch:23


100%|██████████| 250/250 [00:01<00:00, 127.60it/s, loss=8.533/8.338]


8.532903923301694


  4%|▎         | 9/250 [00:00<00:02, 88.37it/s, loss=8.589/8.596]

epoch:24


100%|██████████| 250/250 [00:01<00:00, 127.47it/s, loss=8.378/8.445]


8.37849558074951


100%|██████████| 50/50 [00:00<00:00, 129.35it/s]

lr:0.01; reg:1000; acc:0.3475



  4%|▎         | 9/250 [00:00<00:02, 88.71it/s, loss=19.567/19.562]

epoch:0


100%|██████████| 250/250 [00:02<00:00, 124.07it/s, loss=18.065/17.181]

18.065440831756597



  5%|▍         | 12/250 [00:00<00:02, 117.21it/s, loss=17.088/17.040]

epoch:1


100%|██████████| 250/250 [00:02<00:00, 120.51it/s, loss=16.425/15.664]

16.42518378341675



  4%|▎         | 9/250 [00:00<00:03, 76.56it/s, loss=15.864/15.898]

epoch:2


100%|██████████| 250/250 [00:02<00:00, 121.87it/s, loss=15.456/15.496]


15.456286075592045


  4%|▍         | 10/250 [00:00<00:02, 98.29it/s, loss=15.029/14.582]

epoch:3


100%|██████████| 250/250 [00:02<00:00, 124.08it/s, loss=14.752/14.521]


14.752382411041259


  4%|▍         | 11/250 [00:00<00:02, 107.73it/s, loss=14.549/14.482]

epoch:4


100%|██████████| 250/250 [00:02<00:00, 123.82it/s, loss=14.165/13.835]

14.164590626068115



  4%|▎         | 9/250 [00:00<00:02, 88.69it/s, loss=13.920/13.515]

epoch:5


100%|██████████| 250/250 [00:01<00:00, 126.83it/s, loss=13.646/13.564]

13.646297334213262



  4%|▍         | 10/250 [00:00<00:02, 97.98it/s, loss=13.436/13.498]

epoch:6


100%|██████████| 250/250 [00:02<00:00, 124.20it/s, loss=13.179/13.036]


13.178773640670775


  4%|▍         | 10/250 [00:00<00:02, 100.00it/s, loss=12.999/12.613]

epoch:7


100%|██████████| 250/250 [00:01<00:00, 128.52it/s, loss=12.750/12.648]

12.749676422195439



  4%|▍         | 11/250 [00:00<00:02, 102.10it/s, loss=12.429/12.402]

epoch:8


100%|██████████| 250/250 [00:02<00:00, 122.07it/s, loss=12.351/12.496]

12.350718756408687



  4%|▍         | 10/250 [00:00<00:02, 99.80it/s, loss=12.119/11.887]

epoch:9


100%|██████████| 250/250 [00:02<00:00, 124.60it/s, loss=11.980/12.255]

11.980370386810305



  4%|▍         | 11/250 [00:00<00:02, 104.09it/s, loss=11.809/12.035]

epoch:10


100%|██████████| 250/250 [00:02<00:00, 124.74it/s, loss=11.634/11.531]

11.63373385154725



  4%|▍         | 10/250 [00:00<00:02, 95.83it/s, loss=11.359/11.429]

epoch:11


100%|██████████| 250/250 [00:02<00:00, 123.76it/s, loss=11.308/11.089]

11.307634508590697



  4%|▍         | 10/250 [00:00<00:02, 96.69it/s, loss=11.079/10.694]

epoch:12


100%|██████████| 250/250 [00:02<00:00, 122.01it/s, loss=11.003/10.600]

11.00273172706604



  4%|▎         | 9/250 [00:00<00:02, 88.38it/s, loss=10.958/10.756]

epoch:13


100%|██████████| 250/250 [00:02<00:00, 118.93it/s, loss=10.714/10.592]


10.713924961776744


  4%|▎         | 9/250 [00:00<00:02, 81.64it/s, loss=10.573/10.637]

epoch:14


100%|██████████| 250/250 [00:02<00:00, 124.93it/s, loss=10.442/10.124]

10.442297568588257



  4%|▍         | 10/250 [00:00<00:02, 99.20it/s, loss=10.446/10.534]

epoch:15


100%|██████████| 250/250 [00:01<00:00, 128.71it/s, loss=10.185/9.831]


10.184775684585569


  4%|▎         | 9/250 [00:00<00:02, 89.47it/s, loss=10.111/9.990] 

epoch:16


100%|██████████| 250/250 [00:02<00:00, 122.53it/s, loss=9.944/9.944]

9.943885758209229



  4%|▍         | 10/250 [00:00<00:02, 96.50it/s, loss=9.654/9.540]

epoch:17


100%|██████████| 250/250 [00:02<00:00, 122.47it/s, loss=9.716/10.406]

9.715616761322023



  4%|▍         | 10/250 [00:00<00:02, 96.77it/s, loss=9.718/9.633]

epoch:18


100%|██████████| 250/250 [00:02<00:00, 129.28it/s, loss=9.499/9.093]


9.498930023040769


  4%|▍         | 10/250 [00:00<00:02, 96.51it/s, loss=9.386/9.264]

epoch:19


100%|██████████| 250/250 [00:02<00:00, 121.04it/s, loss=9.293/9.107]

9.293206131553653



  5%|▍         | 12/250 [00:00<00:02, 113.44it/s, loss=9.241/9.941]

epoch:20


100%|██████████| 250/250 [00:02<00:00, 124.80it/s, loss=9.098/8.978]


9.097881811218265


  4%|▎         | 9/250 [00:00<00:02, 87.77it/s, loss=8.942/8.792]

epoch:21


100%|██████████| 250/250 [00:02<00:00, 118.94it/s, loss=8.915/9.348]

8.914557628631597



  5%|▍         | 12/250 [00:00<00:01, 119.21it/s, loss=8.760/8.699]

epoch:22


100%|██████████| 250/250 [00:02<00:00, 124.82it/s, loss=8.740/9.115]


8.740255799903867


  5%|▍         | 12/250 [00:00<00:02, 117.99it/s, loss=8.721/8.665]

epoch:23


100%|██████████| 250/250 [00:01<00:00, 128.88it/s, loss=8.573/8.085]

8.57292573589325



  4%|▍         | 10/250 [00:00<00:02, 96.67it/s, loss=8.428/8.704]

epoch:24


100%|██████████| 250/250 [00:02<00:00, 120.76it/s, loss=8.417/8.392]


8.416777057037352


100%|██████████| 50/50 [00:00<00:00, 121.33it/s]


lr:0.001; reg:0; acc:0.3501


  4%|▎         | 9/250 [00:00<00:02, 84.94it/s, loss=18.988/18.893]

epoch:0


100%|██████████| 250/250 [00:02<00:00, 117.27it/s, loss=17.775/17.542]

17.775390431213378



  4%|▎         | 9/250 [00:00<00:02, 85.96it/s, loss=16.622/16.472]

epoch:1


100%|██████████| 250/250 [00:02<00:00, 118.30it/s, loss=16.190/15.924]


16.19020367370605


  4%|▍         | 11/250 [00:00<00:02, 106.08it/s, loss=15.618/15.936]

epoch:2


100%|██████████| 250/250 [00:02<00:00, 123.71it/s, loss=15.262/15.129]

15.26187125106812



  4%|▍         | 11/250 [00:00<00:02, 107.87it/s, loss=14.828/14.494]

epoch:3


100%|██████████| 250/250 [00:01<00:00, 127.16it/s, loss=14.572/14.235]

14.572403966064465



  4%|▍         | 10/250 [00:00<00:02, 96.81it/s, loss=14.243/14.082]

epoch:4


100%|██████████| 250/250 [00:02<00:00, 116.12it/s, loss=13.995/13.657]

13.995466187973019



  4%|▍         | 10/250 [00:00<00:02, 99.62it/s, loss=13.804/13.883]

epoch:5


100%|██████████| 250/250 [00:02<00:00, 127.29it/s, loss=13.488/13.384]


13.487838833770759


  4%|▎         | 9/250 [00:00<00:02, 88.85it/s, loss=13.183/13.052]

epoch:6


100%|██████████| 250/250 [00:02<00:00, 122.85it/s, loss=13.028/12.996]


13.028180516281113


  5%|▍         | 12/250 [00:00<00:02, 116.85it/s, loss=12.787/12.774]

epoch:7


100%|██████████| 250/250 [00:01<00:00, 125.20it/s, loss=12.606/12.134]

12.6064019708252



  4%|▍         | 10/250 [00:00<00:02, 96.90it/s, loss=12.404/11.927]

epoch:8


100%|██████████| 250/250 [00:01<00:00, 126.26it/s, loss=12.216/11.938]

12.215849483184806



  4%|▍         | 10/250 [00:00<00:02, 98.63it/s, loss=11.977/12.128]

epoch:9


100%|██████████| 250/250 [00:02<00:00, 121.49it/s, loss=11.853/11.238]

11.852835023651112



  4%|▍         | 10/250 [00:00<00:02, 94.47it/s, loss=11.711/11.513]

epoch:10


100%|██████████| 250/250 [00:02<00:00, 121.52it/s, loss=11.511/11.363]

11.51145655311584



  4%|▍         | 10/250 [00:00<00:02, 96.40it/s, loss=11.231/11.116]

epoch:11


100%|██████████| 250/250 [00:02<00:00, 123.14it/s, loss=11.195/11.005]

11.194531906051637



  4%|▍         | 10/250 [00:00<00:02, 95.89it/s, loss=11.069/10.912]

epoch:12


100%|██████████| 250/250 [00:02<00:00, 122.63it/s, loss=10.893/10.532]

10.893212919616692



  4%|▍         | 11/250 [00:00<00:02, 108.80it/s, loss=10.721/10.951]

epoch:13


100%|██████████| 250/250 [00:02<00:00, 123.11it/s, loss=10.611/10.821]

10.611330493316649



  4%|▎         | 9/250 [00:00<00:03, 78.83it/s, loss=10.360/10.229]

epoch:14


 14%|█▎        | 34/250 [00:00<00:02, 95.30it/s, loss=10.378/10.665]