In [1]:
import warnings
# 忽视警告
warnings.filterwarnings('ignore')

import cv2
from PIL import Image
import numpy as np
import copy
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import ImageFolder
import torchvision.transforms as T
from torch.utils.data import DataLoader

In [2]:
def processing_data(data_path, height=224, width=224, batch_size=64,
                    test_split=0.1):
    """
    数据处理部分
    :param data_path: 数据路径
    :param height:高度
    :param width: 宽度
    :param batch_size: 每次读取图片的数量
    :param test_split: 测试集划分比例
    :return: 
    """
    transforms = T.Compose([
        T.Resize((height, width)),
        T.RandomHorizontalFlip(0.1),  # 进行随机水平翻转
        T.RandomVerticalFlip(0.1),  # 进行随机竖直翻转
        T.ToTensor(),  # 转化为张量
        T.Normalize([0], [1]),  # 归一化
    ])

    dataset = ImageFolder(data_path, transform=transforms)
    # 划分数据集
    train_size = int((1-test_split)*len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
    # 创建一个 DataLoader 对象
    train_data_loader = DataLoader(train_dataset, batch_size=batch_size,shuffle=True)
    valid_data_loader = DataLoader(test_dataset, batch_size=batch_size,shuffle=True)

    return train_data_loader, valid_data_loader

In [3]:
torch.cuda.is_available()

True

In [4]:
data_path = './dataset'
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
train_data_loader, valid_data_loader = processing_data(data_path=data_path, height=112, width=112, batch_size=512)
epochs = 200
print('加载完成...')

加载完成...


In [5]:
from model import Arcface
from model import Backbone
class ArcfaceNet(nn.Module):
    def __init__(self, classes=10):
        super(ArcfaceNet, self).__init__()
        self.Resnet = Backbone(num_layers=50,drop_ratio=0.6,mode='ir_se').to(device)
        self.Resnet.load_state_dict(torch.load('model_ir_se50.pth'))
        self.head = Arcface(embedding_size=512, classnum=classes).to(device)

    def forward(self, x):
        out = self.Resnet(x)
        return out

model = ArcfaceNet(classes=10)
optimizer = optim.Adam(model.parameters(), lr=1e-3,weight_decay=5e-4)  # 优化器
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                 'max', 
                                                 factor=0.5,
                                                 patience=2)
criterion = nn.CrossEntropyLoss()  

In [6]:
def train(model,epochs):
    model.train()
    loss_list = []  # 存储损失函数值
    running_loss = 0.    
    for epoch in range(epochs):
        print('epoch {} started'.format(epoch))                            
        for batch_idx,(imgs, labels) in tqdm(enumerate(train_data_loader, 1)):
            imgs = imgs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            embeddings = model(imgs)
            thetas = model.head(embeddings, labels)
            loss = nn.CrossEntropyLoss()(thetas, labels)
            loss.backward()
            running_loss += loss.item()
            optimizer.step  
            loss_list.append(loss)
        print('step:' + str(epoch + 1) + '/' + str(epochs) + ' || Total Loss: %.4f' % (loss))
    torch.save(model.state_dict(), './results/temp.pth')
train(model,epochs)

epoch 0 started


0it [00:00, ?it/s]

step:1/200 || Total Loss: 35.4354
epoch 1 started


0it [00:00, ?it/s]

step:2/200 || Total Loss: 35.2008
epoch 2 started


0it [00:00, ?it/s]

step:3/200 || Total Loss: 36.1949
epoch 3 started


0it [00:00, ?it/s]

step:4/200 || Total Loss: 35.2480
epoch 4 started


0it [00:00, ?it/s]

step:5/200 || Total Loss: 35.8609
epoch 5 started


0it [00:00, ?it/s]

step:6/200 || Total Loss: 35.6770
epoch 6 started


0it [00:00, ?it/s]

step:7/200 || Total Loss: 35.5949
epoch 7 started


0it [00:00, ?it/s]

step:8/200 || Total Loss: 35.4605
epoch 8 started


0it [00:00, ?it/s]

step:9/200 || Total Loss: 36.1350
epoch 9 started


0it [00:00, ?it/s]

step:10/200 || Total Loss: 35.1898
epoch 10 started


0it [00:00, ?it/s]

step:11/200 || Total Loss: 35.6202
epoch 11 started


0it [00:00, ?it/s]

step:12/200 || Total Loss: 34.9543
epoch 12 started


0it [00:00, ?it/s]

step:13/200 || Total Loss: 35.1941
epoch 13 started


0it [00:00, ?it/s]

step:14/200 || Total Loss: 35.1961
epoch 14 started


0it [00:00, ?it/s]

step:15/200 || Total Loss: 35.3187
epoch 15 started


0it [00:00, ?it/s]

step:16/200 || Total Loss: 35.5941
epoch 16 started


0it [00:00, ?it/s]

step:17/200 || Total Loss: 35.5175
epoch 17 started


0it [00:00, ?it/s]

step:18/200 || Total Loss: 35.7539
epoch 18 started


0it [00:00, ?it/s]

step:19/200 || Total Loss: 35.7622
epoch 19 started


0it [00:00, ?it/s]

step:20/200 || Total Loss: 35.7343
epoch 20 started


0it [00:00, ?it/s]

step:21/200 || Total Loss: 35.4086
epoch 21 started


0it [00:00, ?it/s]

step:22/200 || Total Loss: 35.4598
epoch 22 started


0it [00:00, ?it/s]

step:23/200 || Total Loss: 35.4120
epoch 23 started


0it [00:00, ?it/s]

step:24/200 || Total Loss: 35.4821
epoch 24 started


0it [00:00, ?it/s]

step:25/200 || Total Loss: 35.8600
epoch 25 started


0it [00:00, ?it/s]

step:26/200 || Total Loss: 35.8224
epoch 26 started


0it [00:00, ?it/s]

step:27/200 || Total Loss: 35.4431
epoch 27 started


0it [00:00, ?it/s]

step:28/200 || Total Loss: 36.1704
epoch 28 started


0it [00:00, ?it/s]

step:29/200 || Total Loss: 35.0511
epoch 29 started


0it [00:00, ?it/s]

step:30/200 || Total Loss: 35.9170
epoch 30 started


0it [00:00, ?it/s]

step:31/200 || Total Loss: 36.0697
epoch 31 started


0it [00:00, ?it/s]

step:32/200 || Total Loss: 35.5642
epoch 32 started


0it [00:00, ?it/s]

step:33/200 || Total Loss: 35.4301
epoch 33 started


0it [00:00, ?it/s]

step:34/200 || Total Loss: 35.1970
epoch 34 started


0it [00:00, ?it/s]

step:35/200 || Total Loss: 35.0312
epoch 35 started


0it [00:00, ?it/s]

step:36/200 || Total Loss: 35.1647
epoch 36 started


0it [00:00, ?it/s]

step:37/200 || Total Loss: 35.6655
epoch 37 started


0it [00:00, ?it/s]

step:38/200 || Total Loss: 35.7085
epoch 38 started


0it [00:00, ?it/s]

step:39/200 || Total Loss: 35.1649
epoch 39 started


0it [00:00, ?it/s]

step:40/200 || Total Loss: 35.2633
epoch 40 started


0it [00:00, ?it/s]

step:41/200 || Total Loss: 35.3267
epoch 41 started


0it [00:00, ?it/s]

step:42/200 || Total Loss: 35.0248
epoch 42 started


0it [00:00, ?it/s]

step:43/200 || Total Loss: 35.5528
epoch 43 started


0it [00:00, ?it/s]

step:44/200 || Total Loss: 35.4315
epoch 44 started


0it [00:00, ?it/s]

step:45/200 || Total Loss: 35.4286
epoch 45 started


0it [00:00, ?it/s]

step:46/200 || Total Loss: 35.5612
epoch 46 started


0it [00:00, ?it/s]

step:47/200 || Total Loss: 35.7621
epoch 47 started


0it [00:00, ?it/s]

step:48/200 || Total Loss: 35.6619
epoch 48 started


0it [00:00, ?it/s]

step:49/200 || Total Loss: 35.6906
epoch 49 started


0it [00:00, ?it/s]

step:50/200 || Total Loss: 35.8379
epoch 50 started


0it [00:00, ?it/s]

step:51/200 || Total Loss: 35.3993
epoch 51 started


0it [00:00, ?it/s]

step:52/200 || Total Loss: 35.1924
epoch 52 started


0it [00:00, ?it/s]

step:53/200 || Total Loss: 35.0504
epoch 53 started


0it [00:00, ?it/s]

step:54/200 || Total Loss: 35.4077
epoch 54 started


0it [00:00, ?it/s]

step:55/200 || Total Loss: 35.3527
epoch 55 started


0it [00:00, ?it/s]

step:56/200 || Total Loss: 35.5075
epoch 56 started


0it [00:00, ?it/s]

step:57/200 || Total Loss: 35.3398
epoch 57 started


0it [00:00, ?it/s]

step:58/200 || Total Loss: 35.5605
epoch 58 started


0it [00:00, ?it/s]

step:59/200 || Total Loss: 35.2098
epoch 59 started


0it [00:00, ?it/s]

step:60/200 || Total Loss: 35.4546
epoch 60 started


0it [00:00, ?it/s]

step:61/200 || Total Loss: 35.9160
epoch 61 started


0it [00:00, ?it/s]

step:62/200 || Total Loss: 35.1337
epoch 62 started


0it [00:00, ?it/s]

step:63/200 || Total Loss: 35.2761
epoch 63 started


0it [00:00, ?it/s]

step:64/200 || Total Loss: 36.1416
epoch 64 started


0it [00:00, ?it/s]

step:65/200 || Total Loss: 34.9724
epoch 65 started


0it [00:00, ?it/s]

step:66/200 || Total Loss: 36.0079
epoch 66 started


0it [00:00, ?it/s]

step:67/200 || Total Loss: 35.4796
epoch 67 started


0it [00:00, ?it/s]

step:68/200 || Total Loss: 35.6744
epoch 68 started


0it [00:00, ?it/s]

step:69/200 || Total Loss: 35.4160
epoch 69 started


0it [00:00, ?it/s]

step:70/200 || Total Loss: 35.6032
epoch 70 started


0it [00:00, ?it/s]

step:71/200 || Total Loss: 34.9901
epoch 71 started


0it [00:00, ?it/s]

step:72/200 || Total Loss: 35.1065
epoch 72 started


0it [00:00, ?it/s]

step:73/200 || Total Loss: 35.6675
epoch 73 started


0it [00:00, ?it/s]

step:74/200 || Total Loss: 35.2646
epoch 74 started


0it [00:00, ?it/s]

step:75/200 || Total Loss: 35.9482
epoch 75 started


0it [00:00, ?it/s]

step:76/200 || Total Loss: 35.7362
epoch 76 started


0it [00:00, ?it/s]

step:77/200 || Total Loss: 35.4524
epoch 77 started


0it [00:00, ?it/s]

step:78/200 || Total Loss: 35.4709
epoch 78 started


0it [00:00, ?it/s]

step:79/200 || Total Loss: 34.8921
epoch 79 started


0it [00:00, ?it/s]

step:80/200 || Total Loss: 35.2036
epoch 80 started


0it [00:00, ?it/s]

step:81/200 || Total Loss: 34.7844
epoch 81 started


0it [00:00, ?it/s]

step:82/200 || Total Loss: 36.1515
epoch 82 started


0it [00:00, ?it/s]

step:83/200 || Total Loss: 35.5726
epoch 83 started


0it [00:00, ?it/s]

step:84/200 || Total Loss: 35.8926
epoch 84 started


0it [00:00, ?it/s]

step:85/200 || Total Loss: 35.2120
epoch 85 started


0it [00:00, ?it/s]

step:86/200 || Total Loss: 35.2193
epoch 86 started


0it [00:00, ?it/s]

step:87/200 || Total Loss: 35.0705
epoch 87 started


0it [00:00, ?it/s]

step:88/200 || Total Loss: 35.6018
epoch 88 started


0it [00:00, ?it/s]

step:89/200 || Total Loss: 35.2281
epoch 89 started


0it [00:00, ?it/s]

step:90/200 || Total Loss: 35.8652
epoch 90 started


0it [00:00, ?it/s]

step:91/200 || Total Loss: 35.7913
epoch 91 started


0it [00:00, ?it/s]

step:92/200 || Total Loss: 35.6592
epoch 92 started


0it [00:00, ?it/s]

step:93/200 || Total Loss: 35.1586
epoch 93 started


0it [00:00, ?it/s]

step:94/200 || Total Loss: 35.7470
epoch 94 started


0it [00:00, ?it/s]

step:95/200 || Total Loss: 35.5728
epoch 95 started


0it [00:00, ?it/s]

step:96/200 || Total Loss: 35.3773
epoch 96 started


0it [00:00, ?it/s]

step:97/200 || Total Loss: 35.7419
epoch 97 started


0it [00:00, ?it/s]

step:98/200 || Total Loss: 36.1861
epoch 98 started


0it [00:00, ?it/s]

step:99/200 || Total Loss: 35.3238
epoch 99 started


0it [00:00, ?it/s]

step:100/200 || Total Loss: 34.9947
epoch 100 started


0it [00:00, ?it/s]

step:101/200 || Total Loss: 35.3743
epoch 101 started


0it [00:00, ?it/s]

step:102/200 || Total Loss: 35.3637
epoch 102 started


0it [00:00, ?it/s]

step:103/200 || Total Loss: 35.6734
epoch 103 started


0it [00:00, ?it/s]

step:104/200 || Total Loss: 35.8105
epoch 104 started


0it [00:00, ?it/s]

step:105/200 || Total Loss: 35.8428
epoch 105 started


0it [00:00, ?it/s]

step:106/200 || Total Loss: 35.4949
epoch 106 started


0it [00:00, ?it/s]

step:107/200 || Total Loss: 35.7924
epoch 107 started


0it [00:00, ?it/s]

step:108/200 || Total Loss: 35.7064
epoch 108 started


0it [00:00, ?it/s]

step:109/200 || Total Loss: 36.3906
epoch 109 started


0it [00:00, ?it/s]

step:110/200 || Total Loss: 35.4201
epoch 110 started


0it [00:00, ?it/s]

step:111/200 || Total Loss: 35.9459
epoch 111 started


0it [00:00, ?it/s]

step:112/200 || Total Loss: 35.6840
epoch 112 started


0it [00:00, ?it/s]

step:113/200 || Total Loss: 35.4327
epoch 113 started


0it [00:00, ?it/s]

step:114/200 || Total Loss: 35.1327
epoch 114 started


0it [00:00, ?it/s]

step:115/200 || Total Loss: 35.5390
epoch 115 started


0it [00:00, ?it/s]

step:116/200 || Total Loss: 35.5166
epoch 116 started


0it [00:00, ?it/s]

step:117/200 || Total Loss: 35.4669
epoch 117 started


0it [00:00, ?it/s]

step:118/200 || Total Loss: 35.4836
epoch 118 started


0it [00:00, ?it/s]

step:119/200 || Total Loss: 35.6589
epoch 119 started


0it [00:00, ?it/s]

step:120/200 || Total Loss: 34.8233
epoch 120 started


0it [00:00, ?it/s]

step:121/200 || Total Loss: 35.1824
epoch 121 started


0it [00:00, ?it/s]

step:122/200 || Total Loss: 35.2520
epoch 122 started


0it [00:00, ?it/s]

step:123/200 || Total Loss: 35.7415
epoch 123 started


0it [00:00, ?it/s]

step:124/200 || Total Loss: 35.5204
epoch 124 started


0it [00:00, ?it/s]

step:125/200 || Total Loss: 35.7126
epoch 125 started


0it [00:00, ?it/s]

step:126/200 || Total Loss: 35.5406
epoch 126 started


0it [00:00, ?it/s]

step:127/200 || Total Loss: 35.7592
epoch 127 started


0it [00:00, ?it/s]

step:128/200 || Total Loss: 35.4891
epoch 128 started


0it [00:00, ?it/s]

step:129/200 || Total Loss: 36.0218
epoch 129 started


0it [00:00, ?it/s]

step:130/200 || Total Loss: 35.6334
epoch 130 started


0it [00:00, ?it/s]

step:131/200 || Total Loss: 35.7609
epoch 131 started


0it [00:00, ?it/s]

step:132/200 || Total Loss: 35.2352
epoch 132 started


0it [00:00, ?it/s]

step:133/200 || Total Loss: 35.6103
epoch 133 started


0it [00:00, ?it/s]

step:134/200 || Total Loss: 35.5967
epoch 134 started


0it [00:00, ?it/s]

step:135/200 || Total Loss: 36.0601
epoch 135 started


0it [00:00, ?it/s]

step:136/200 || Total Loss: 35.6226
epoch 136 started


0it [00:00, ?it/s]

step:137/200 || Total Loss: 35.7336
epoch 137 started


0it [00:00, ?it/s]

step:138/200 || Total Loss: 36.3225
epoch 138 started


0it [00:00, ?it/s]

step:139/200 || Total Loss: 35.6208
epoch 139 started


0it [00:00, ?it/s]

step:140/200 || Total Loss: 36.1231
epoch 140 started


0it [00:00, ?it/s]

step:141/200 || Total Loss: 35.8532
epoch 141 started


0it [00:00, ?it/s]

step:142/200 || Total Loss: 35.0178
epoch 142 started


0it [00:00, ?it/s]

step:143/200 || Total Loss: 35.3967
epoch 143 started


0it [00:00, ?it/s]

step:144/200 || Total Loss: 35.3995
epoch 144 started


0it [00:00, ?it/s]

step:145/200 || Total Loss: 35.6307
epoch 145 started


0it [00:00, ?it/s]

step:146/200 || Total Loss: 35.5817
epoch 146 started


0it [00:00, ?it/s]

step:147/200 || Total Loss: 35.8862
epoch 147 started


0it [00:00, ?it/s]

step:148/200 || Total Loss: 35.8339
epoch 148 started


0it [00:00, ?it/s]

step:149/200 || Total Loss: 35.6128
epoch 149 started


0it [00:00, ?it/s]

step:150/200 || Total Loss: 35.6895
epoch 150 started


0it [00:00, ?it/s]

step:151/200 || Total Loss: 35.7234
epoch 151 started


0it [00:00, ?it/s]

step:152/200 || Total Loss: 35.9415
epoch 152 started


0it [00:00, ?it/s]

step:153/200 || Total Loss: 35.4873
epoch 153 started


0it [00:00, ?it/s]

step:154/200 || Total Loss: 35.2460
epoch 154 started


0it [00:00, ?it/s]

step:155/200 || Total Loss: 35.6401
epoch 155 started


0it [00:00, ?it/s]

step:156/200 || Total Loss: 34.9585
epoch 156 started


0it [00:00, ?it/s]

step:157/200 || Total Loss: 35.0083
epoch 157 started


0it [00:00, ?it/s]

step:158/200 || Total Loss: 35.6422
epoch 158 started


0it [00:00, ?it/s]

step:159/200 || Total Loss: 35.1751
epoch 159 started


0it [00:00, ?it/s]

step:160/200 || Total Loss: 35.3257
epoch 160 started


0it [00:00, ?it/s]

step:161/200 || Total Loss: 35.4287
epoch 161 started


0it [00:00, ?it/s]

step:162/200 || Total Loss: 35.8224
epoch 162 started


0it [00:00, ?it/s]

step:163/200 || Total Loss: 35.7260
epoch 163 started


0it [00:00, ?it/s]

step:164/200 || Total Loss: 35.1646
epoch 164 started


0it [00:00, ?it/s]

step:165/200 || Total Loss: 35.7789
epoch 165 started


0it [00:00, ?it/s]

step:166/200 || Total Loss: 35.3721
epoch 166 started


0it [00:00, ?it/s]

step:167/200 || Total Loss: 35.8526
epoch 167 started


0it [00:00, ?it/s]

step:168/200 || Total Loss: 35.3787
epoch 168 started


0it [00:00, ?it/s]

step:169/200 || Total Loss: 34.7916
epoch 169 started


0it [00:00, ?it/s]

step:170/200 || Total Loss: 35.5393
epoch 170 started


0it [00:00, ?it/s]

step:171/200 || Total Loss: 35.8118
epoch 171 started


0it [00:00, ?it/s]

step:172/200 || Total Loss: 35.3366
epoch 172 started


0it [00:00, ?it/s]

step:173/200 || Total Loss: 35.5456
epoch 173 started


0it [00:00, ?it/s]

step:174/200 || Total Loss: 35.3797
epoch 174 started


0it [00:00, ?it/s]

step:175/200 || Total Loss: 35.5616
epoch 175 started


0it [00:00, ?it/s]

step:176/200 || Total Loss: 35.2308
epoch 176 started


0it [00:00, ?it/s]

step:177/200 || Total Loss: 35.6645
epoch 177 started


0it [00:00, ?it/s]

step:178/200 || Total Loss: 35.7620
epoch 178 started


0it [00:00, ?it/s]

step:179/200 || Total Loss: 35.1838
epoch 179 started


0it [00:00, ?it/s]

step:180/200 || Total Loss: 35.9053
epoch 180 started


0it [00:00, ?it/s]

step:181/200 || Total Loss: 34.9805
epoch 181 started


0it [00:00, ?it/s]

step:182/200 || Total Loss: 35.6043
epoch 182 started


0it [00:00, ?it/s]

step:183/200 || Total Loss: 35.7987
epoch 183 started


0it [00:00, ?it/s]

step:184/200 || Total Loss: 35.7091
epoch 184 started


0it [00:00, ?it/s]

step:185/200 || Total Loss: 35.3962
epoch 185 started


0it [00:00, ?it/s]

step:186/200 || Total Loss: 36.3949
epoch 186 started


0it [00:00, ?it/s]

step:187/200 || Total Loss: 35.8973
epoch 187 started


0it [00:00, ?it/s]

step:188/200 || Total Loss: 35.9324
epoch 188 started


0it [00:00, ?it/s]

step:189/200 || Total Loss: 35.1016
epoch 189 started


0it [00:00, ?it/s]

step:190/200 || Total Loss: 35.3569
epoch 190 started


0it [00:00, ?it/s]

step:191/200 || Total Loss: 36.0154
epoch 191 started


0it [00:00, ?it/s]

step:192/200 || Total Loss: 35.5672
epoch 192 started


0it [00:00, ?it/s]

step:193/200 || Total Loss: 35.4538
epoch 193 started


0it [00:00, ?it/s]

step:194/200 || Total Loss: 35.4746
epoch 194 started


0it [00:00, ?it/s]

step:195/200 || Total Loss: 36.2494
epoch 195 started


0it [00:00, ?it/s]

step:196/200 || Total Loss: 35.8582
epoch 196 started


0it [00:00, ?it/s]

step:197/200 || Total Loss: 35.7553
epoch 197 started


0it [00:00, ?it/s]

step:198/200 || Total Loss: 35.4577
epoch 198 started


0it [00:00, ?it/s]

step:199/200 || Total Loss: 35.5307
epoch 199 started


0it [00:00, ?it/s]

step:200/200 || Total Loss: 35.5122


In [7]:
train_result = []
for loss in loss_list :
    train_result.append(loss.detach().to('cpu').numpy())
train_result

NameError: name 'loss_list' is not defined

In [None]:
plt.plot(train_result,label = "loss")
plt.legend()
plt.show()