In [1]:
import os
import random

import numpy as np
import torch
import yaml
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from tqdm import tqdm

import my_dataset
import utils
from models import ResNet18
from models.vit import ViT
from models_DLA.dla_simple import SimpleDLA

In [2]:
# ---- Training configuration ----
batch_size = 32
size = 64      # input image size for the (commented-out) ViT model only
lr = 0.1       # initial SGD learning rate
epochs = 230   # number of training epochs
save = False   # write checkpoints to ./weights_LDA during training

# Seed every RNG the notebook uses (Python `random`, NumPy, PyTorch) so the
# training subset, DataLoader shuffling and weight initialisation repeat across
# Restart & Run All. GPU kernels may still be non-deterministic.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
# writer = SummaryWriter(log_dir = 'logs_LDA')

In [4]:
# Split lists read by utils.read_file (passed to the dataset as `images_path`).
train_data_all = utils.read_file("../cifar10/train_data.txt")
val_data = utils.read_file("../cifar10/val_data.txt")
test_data = utils.read_file("../cifar10/test_data.txt")  # not used further in this notebook

# Train on a fixed-size random subset of the training list. A dedicated, seeded
# RNG makes the subset identical on every run, regardless of what else has
# consumed the global `random` state. If the list is shorter than the requested
# size, the whole list is used instead of random.sample raising ValueError.
TRAIN_SUBSET_SIZE = 10500
TRAIN_SUBSET_SEED = 0
train_data = random.Random(TRAIN_SUBSET_SEED).sample(
    train_data_all, min(TRAIN_SUBSET_SIZE, len(train_data_all)))

# Normalisation shared by both splits (per-channel mean 0.5, std 0.2).
normalize = transforms.Normalize([0.5, 0.5, 0.5], [0.2, 0.2, 0.2])
data_transform = {
    # Augmentation: random 32x32 crop from a 4px-padded image + random horizontal flip.
    "train": transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]),
    "val": transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ]),
}

train_dataset = my_dataset.MyDataSet_CIFAR_Tracin(images_path=train_data,
                                                  transform=data_transform["train"])
val_dataset = my_dataset.MyDataSet_CIFAR_Tracin(images_path=val_data,
                                                transform=data_transform["val"])

In [5]:
# DataLoader worker count: capped at 8, at the CPU count, and at the batch size.
# A batch size of 1 or less loads data in the main process (0 workers).
# os.cpu_count() returns None when the count is undetermined, which would make
# min() raise TypeError, so fall back to 1 in that case.
cpu_count = os.cpu_count() or 1
nw = min([cpu_count, batch_size if batch_size > 1 else 0, 8])  # number of workers
print('Using {} dataloader workers every process'.format(nw))

Using 8 dataloader workers every process


In [6]:
def _make_loader(dataset, shuffle):
    """Wrap `dataset` in a DataLoader with the notebook's shared batch/worker settings.

    The dataset supplies its own `collate_fn`. Only the training loader shuffles.
    """
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        pin_memory=True,
        num_workers=nw,
        collate_fn=dataset.collate_fn,
    )


train_loader = _make_loader(train_dataset, shuffle=True)
val_loader = _make_loader(val_dataset, shuffle=False)

In [7]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [8]:
# Backbone for this run. Alternatives imported at the top of the notebook,
# kept for reference only (not used here):
#   ViT(image_size=size, patch_size=4, num_classes=10, dim=512, depth=6,
#       heads=8, mlp_dim=512, dropout=0.1, emb_dropout=0.1)
#   SimpleDLA()
# NOTE(review): `size` is 64 but the train transform crops to 32x32; check
# ViT's image_size before re-enabling it.
model = ResNet18()
model = model.to(device)

In [9]:
# input = torch.randn(1, 3, 32, 32).to(device)
# writer.add_graph(model, input)

In [10]:
loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=lr,
                            momentum=0.9, weight_decay=5e-4)
# The scheduler is stepped once per epoch in the training loop, so anneal over
# the whole run. CosineAnnealingLR keeps following the cosine past T_max (the
# LR rises again). With the previous hard-coded T_max=200 and epochs=230, the
# LR increased during the final 30 epochs.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

In [11]:
# Scalar tags for the (currently commented-out) TensorBoard logging in the training loop.
tags = [
    "train_loss2",
    "train_acc2",
    "val_loss2",
    "val_acc2",
    "learning_rate2",
]
best_acc = 0  # NOTE(review): not read anywhere in this notebook

In [12]:
# Per-sample validation records returned by utils.evaluate_save; the cells after
# training export them to YAML. val_path / val_l are overwritten every epoch
# (the val loader uses shuffle=False). val_pre / val_ac get one entry appended
# per epoch.
val_l = []
val_pre = []
val_ac = []
val_path = []

# Checkpoint schedule when `save` is enabled: every epoch for the first
# DENSE_CHECKPOINT_EPOCHS epochs, then every CHECKPOINT_EVERY-th epoch
# (checkpoint-1 ... checkpoint-40, checkpoint-50, checkpoint-60, ...).
CHECKPOINT_DIR = './weights_LDA'
DENSE_CHECKPOINT_EPOCHS = 40
CHECKPOINT_EVERY = 10


def _should_checkpoint(epoch):
    """Return True if the state after 0-based `epoch` should be written to disk."""
    return epoch < DENSE_CHECKPOINT_EPOCHS or (epoch + 1) % CHECKPOINT_EVERY == 0


def _save_checkpoint(epoch):
    """Save epoch count, model and optimizer state as checkpoint-<epoch+1>.pth.tar."""
    # torch.save does not create missing directories.
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    state = {
        'epoch': epoch + 1,                   # number of completed epochs
        'state_dict': model.state_dict(),     # model parameters
        'optimizer': optimizer.state_dict(),  # optimizer state
    }
    torch.save(state, os.path.join(CHECKPOINT_DIR, 'checkpoint-' + str(epoch + 1) + '.pth.tar'))


for epoch in range(epochs):
    # ---- train for one epoch ----
    model.train()
    accu_loss = torch.zeros(1).to(device)  # running sum of per-batch mean losses
    accu_num = torch.zeros(1).to(device)   # running count of correct predictions
    sample_num = 0
    num_batches = 0

    data_loader = tqdm(train_loader)
    for data in data_loader:
        images, labels, p = data
        images = images.to(device)
        labels = labels.to(device)
        sample_num += images.shape[0]
        num_batches += 1

        optimizer.zero_grad()
        pred = model(images)
        loss = loss_function(pred, labels)
        loss.backward()
        optimizer.step()

        pred_classes = torch.max(pred, dim=1)[1]  # index of the highest score per sample
        accu_num += torch.eq(pred_classes, labels).sum()
        accu_loss += loss.detach()

        data_loader.desc = "[train epoch {}] loss: {:.3f}, acc: {:.3f}".format(
            epoch, accu_loss.item() / num_batches, accu_num.item() / sample_num)

    # max(..., 1) avoids ZeroDivisionError if the loader yields no batches.
    train_loss = accu_loss.item() / max(num_batches, 1)
    train_acc = accu_num.item() / max(sample_num, 1)

    # ---- validate and record per-sample results ----
    # NOTE(review): save=True is passed unconditionally; it is independent of
    # the `save` flag that controls checkpointing below.
    val_loss, val_acc, paths, ls, pres, acs = utils.evaluate_save(
        model=model, data_loader=val_loader, device=device, epoch=epoch, save=True)
    val_path = paths
    val_pre.append(pres)
    val_ac.append(acs)
    val_l = ls
    # writer.add_scalar(tags[0], train_loss, epoch)
    # writer.add_scalar(tags[1], train_acc, epoch)
    # writer.add_scalar(tags[2], val_loss, epoch)
    # writer.add_scalar(tags[3], val_acc, epoch)
    # writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch)

    scheduler.step()
    if save and _should_checkpoint(epoch):
        _save_checkpoint(epoch)

[train epoch 0] loss: 2.406, acc: 0.182: 100%|██████████| 329/329 [00:11<00:00, 29.03it/s]
[valid epoch 0] loss: 1.989, acc: 0.255: 100%|██████████| 313/313 [00:03<00:00, 87.91it/s] 
[train epoch 1] loss: 1.956, acc: 0.268: 100%|██████████| 329/329 [00:10<00:00, 32.49it/s]
[valid epoch 1] loss: 1.809, acc: 0.342: 100%|██████████| 313/313 [00:03<00:00, 94.01it/s] 
[train epoch 2] loss: 1.851, acc: 0.314: 100%|██████████| 329/329 [00:10<00:00, 32.65it/s]
[valid epoch 2] loss: 1.743, acc: 0.357: 100%|██████████| 313/313 [00:03<00:00, 96.00it/s] 
[train epoch 3] loss: 1.784, acc: 0.344: 100%|██████████| 329/329 [00:10<00:00, 32.64it/s]
[valid epoch 3] loss: 1.749, acc: 0.357: 100%|██████████| 313/313 [00:03<00:00, 96.17it/s] 
[train epoch 4] loss: 1.699, acc: 0.374: 100%|██████████| 329/329 [00:10<00:00, 32.31it/s]
[valid epoch 4] loss: 1.770, acc: 0.366: 100%|██████████| 313/313 [00:03<00:00, 89.85it/s] 
[train epoch 5] loss: 1.646, acc: 0.393: 100%|██████████| 329/329 [00:10<00:00, 31.87

In [None]:
def _sample_major(per_epoch):
    """Turn a per-epoch list of per-sample results into a sample-major array.

    np.asarray stacks the list with the epoch axis first. Swapping axes 0 and 1
    puts the sample axis first, so row i holds sample i's values across epochs.
    For 2-D input this equals `.T`. Unlike `.T`, it leaves any trailing axes in
    place. Inputs with fewer than 2 dimensions are returned unchanged.
    """
    stacked = np.asarray(per_epoch)
    return np.swapaxes(stacked, 0, 1) if stacked.ndim >= 2 else stacked


# The training loop builds val_pre / val_ac as Python lists. Convert them only
# while they are still lists: re-running this cell on already-converted arrays
# would otherwise transpose them back and silently scramble the export.
if isinstance(val_pre, list):
    val_pre = _sample_major(val_pre)
if isinstance(val_ac, list):
    val_ac = _sample_major(val_ac)

In [None]:
# One record per validation sample: its path, its label, and its per-epoch
# "pre" / "acc" values (row `index` of the sample-major val_pre / val_ac).
vals = [
    {
        "path": path,
        "label": val_l[index].item(),
        "pre": val_pre[index].tolist(),
        "acc": val_ac[index].tolist(),
    }
    for index, path in enumerate(val_path)
]

In [None]:
import yaml

In [None]:
# Export the per-sample validation history. Create the target directory first,
# because open(..., "w") does not create missing parent directories.
os.makedirs("./train_detail", exist_ok=True)
with open("./train_detail/val_data_resnet18.yaml", "w", encoding="utf-8") as f:
    yaml.dump(vals, f, allow_unicode=True)