In [None]:
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
import math
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, ConcatDataset
from data_reader import LoadData  # 数据读取


In [None]:
# Data and output locations
# NOTE(review): both paths are absolute, machine-specific Windows paths; adjust them when
# running the notebook elsewhere.
route = r"D:\data\MNIST"  # dataset root; the loaders read <route>/train/<digit> and <route>/val/<digit>
result_save_path = r"C:\Users\haokw\Documents\GitHub\VScode\Gan_dataset_expansion\model\CNN_model1"  # where the model and the loss plot are saved
drop_last = False  # whether to drop a final incomplete batch; True is an option when data is plentiful
# Create the output directory together with any missing parent folders. exist_ok=True makes
# the call idempotent, so re-running the cell (or a concurrent creation) does not raise.
os.makedirs(result_save_path, exist_ok=True)

In [None]:
# Training hyper-parameters
lr = 0.002  # learning rate handed to the SGD optimizer of the classifier
batch_size = 128  # number of samples per batch
num_epoch = 100  # epoch bound; the training loop iterates over range(0, num_epoch+1)
output_loss_Interval_ratio = 1  # print the training loss every this many epochs
test_interval = 1  # evaluate the validation accuracy every this many epochs
# Network-structure parameters
input_number_of_channels = 1  # number of input channels: 3 for RGB, 1 for grayscale

In [None]:
# CNN classifier
class classification_model(nn.Module):
    """Small three-convolution image classifier that outputs raw class scores.

    The network is purely convolutional: the last convolution has one output
    channel per class, and ``forward`` flattens everything after the batch
    dimension. No softmax is applied, so the output is un-normalised scores.

    Args:
        output_classes: number of output channels of the final convolution
            (one score per class when the final feature map is 1x1).
        input_number_of_channels: number of channels of the input images.
    """

    def __init__(self, output_classes, input_number_of_channels):
        super().__init__()

        # Shapes in the comments below hold for a 28x28 input with batch size m.
        stages = [
            # 5x5 conv with padding 2 keeps the spatial size -> (m, 6, 28, 28)
            nn.Conv2d(input_number_of_channels, 6, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            # 2x2 pooling halves the spatial size -> (m, 6, 14, 14)
            nn.MaxPool2d(kernel_size=2, stride=2),

            # unpadded 5x5 conv -> (m, 16, 10, 10)
            nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),
            nn.ReLU(),
            # -> (m, 16, 5, 5)
            nn.MaxPool2d(kernel_size=2, stride=2),

            # unpadded 5x5 conv over a 5x5 map -> (m, output_classes, 1, 1)
            nn.Conv2d(16, output_classes, kernel_size=5, stride=1, padding=0),
        ]
        self.structure = nn.Sequential(*stages)

    def forward(self, x):
        """Return the class scores for ``x``, flattened to (batch, -1)."""
        scores = self.structure(x)
        batch = scores.size(0)
        return scores.reshape(batch, -1)

In [None]:
# Use the GPU when one is available and otherwise stay on the CPU, so that building the
# model does not crash on a machine without CUDA (an unconditional .cuda() would).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Cross-entropy on the raw scores returned by the model (the model applies no softmax).
criterion = nn.CrossEntropyLoss()
model = classification_model(output_classes=10, input_number_of_channels=input_number_of_channels).to(device)
# SGD with momentum over all model parameters; the learning rate comes from the config cell.
optimer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)

In [None]:
# Build the training and validation datasets and their loaders
def _digit_folders_dataset(split):
    """Concatenate the LoadData datasets of the folders <route>/<split>/0 ... 9."""
    per_digit = []
    for number in range(0, 10):
        folder = os.path.join(route, split, str(number))
        per_digit.append(LoadData(folder, input_number_of_channels=input_number_of_channels))
    return ConcatDataset(per_digit)


train_dataset = _digit_folders_dataset('train')  # dataset
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=drop_last,
)  # dataloader

val_dataset = _digit_folders_dataset('val')  # dataset
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=drop_last,
)  # dataloader

In [None]:
# Weight (re-)initialisation hook, intended to be passed to nn.Module.apply
def init_weights(m):
    """Fill ``m.weight`` in place with values drawn uniformly from [-0.1, 0.1).

    Modules without a ``weight`` attribute are left untouched, and so are
    modules whose ``weight`` is not a tensor (for example layers built with
    ``affine=False``, which register ``weight`` as ``None``). Biases are not
    modified.

    Args:
        m: the module to initialise.
    """
    weight = getattr(m, 'weight', None)
    # hasattr() alone is not enough: a parameter registered as None still "exists",
    # and nn.init.uniform_(None, ...) would raise.
    if isinstance(weight, torch.Tensor):
        nn.init.uniform_(weight, -0.1, 0.1)

In [None]:
# Initialise the model: nn.Module.apply calls init_weights on every sub-module and on the model itself
model.apply(init_weights)

In [30]:

loss_list = []  # mean per-batch training loss of every epoch
acc_list = []  # validation accuracy of every evaluated epoch

# Follow whatever device the model already lives on instead of hard-coding CUDA.
device = next(model.parameters()).device


def _train_one_epoch(net, loader, loss_fn, optimizer, device, desc):
    """Run one optimisation pass over ``loader`` and return the mean per-batch loss."""
    net.train()
    running_loss = 0.0
    for img, label in tqdm(loader, desc=desc):
        label = torch.as_tensor(label, dtype=torch.long).to(device)
        output = net(img.to(device))  # forward pass
        loss = loss_fn(output, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Accumulate a plain Python float: summing the loss tensors themselves would keep
        # their autograd history referenced for the whole epoch.
        running_loss += loss.item()
    # max(..., 1) avoids a ZeroDivisionError when the loader yields no batch.
    return running_loss / max(len(loader), 1)


def _evaluate(net, loader, device, desc):
    """Return the fraction of correctly classified samples in ``loader``."""
    net.eval()
    correct = 0
    total = 0
    # No gradients are needed for evaluation; skipping them saves memory and time.
    with torch.no_grad():
        for img, label in tqdm(loader, desc=desc):
            label = torch.as_tensor(label, dtype=torch.long).to(device)
            prediction = torch.argmax(net(img.to(device)), dim=1)
            correct += (prediction == label).sum().item()
            total += len(img)
    # Count samples rather than averaging per-batch accuracies, which would over-weight
    # a smaller final batch.
    return correct / max(total, 1)


# NOTE(review): range(0, num_epoch + 1) runs num_epoch + 1 epochs (0 .. num_epoch inclusive).
# Kept as in the original; confirm whether num_epoch epochs were intended.
for epoch in range(0, num_epoch + 1):
    # Training
    epoch_loss = _train_one_epoch(model, train_loader, criterion, optimer, device,
                                  desc=f'Epoch[{epoch}] train')
    loss_list.append(epoch_loss)

    # Evaluation, once every test_interval epochs
    tested = epoch % test_interval == 0
    if tested:
        acc_list.append(_evaluate(model, val_loader, device, desc=f'Epoch[{epoch}] test'))

    # Report the training loss once every output_loss_Interval_ratio epochs
    if epoch % output_loss_Interval_ratio == 0:
        print('Epoch[{}/{}],loss:{:.6f}'.format(epoch, num_epoch, epoch_loss))

    # Report the accuracy whenever an evaluation was run in this epoch
    if tested:
        print('Epoch[{}/{}],acc:{:.6f}'.format(epoch, num_epoch, acc_list[-1]))

    # Make sure the output directory exists *before* anything is written into it.
    os.makedirs(result_save_path, exist_ok=True)

    # Save the loss/accuracy curves. An explicit figure that is closed afterwards keeps
    # figures from piling up and leaves no empty figure in the cell output.
    fig, ax = plt.subplots()
    ax.plot(range(len(loss_list)), loss_list, label="loss")
    # Accuracy is only recorded every test_interval epochs, hence the scaled x positions.
    ax.plot([i * test_interval for i in range(len(acc_list))], acc_list, label="acc")
    ax.set_xlabel("epoch")
    ax.set_ylabel("loss")
    ax.legend()
    fig.savefig(os.path.join(result_save_path, 'loss.jpg'))
    plt.close(fig)

    # Save the latest model. NOTE(review): this pickles the whole module object, so loading
    # it requires the class definition to be importable; saving model.state_dict() is the
    # more portable option but would change the file format consumers of last.pth expect.
    torch.save(model, os.path.join(result_save_path, 'last.pth'))

Epoch[0] train: 100%|██████████| 40/40 [00:01<00:00, 20.24it/s]
Epoch[0] test: 100%|██████████| 40/40 [00:02<00:00, 19.06it/s]


Epoch[0/100],loss:0.138184
Epoch[0/100],acc:0.938672


Epoch[1] train: 100%|██████████| 40/40 [00:01<00:00, 20.01it/s]
Epoch[1] test:   0%|          | 0/40 [00:00<?, ?it/s]


KeyboardInterrupt: 

<Figure size 640x480 with 0 Axes>