In [1]:
import torch
from vit_pytorch import ViT
import os
from torch.utils.tensorboard import SummaryWriter
import utils
from torchvision import transforms
import my_dataset
from tqdm import tqdm
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import math

In [2]:
# Load the image path lists and integer class labels for each split.
# Calls run in the order train, val, test (paths first, then labels).
image_path_train, image_path_val, image_path_test = (
    utils.read_file(split_file)
    for split_file in ("./data/train-path.txt", "./data/val-path.txt", "./data/test-path.txt")
)
labels_train, labels_val, labels_test = (
    utils.read_file(split_file, "int")
    for split_file in ("./data/train-anno.txt", "./data/val-anno.txt", "./data/test-anno.txt")
)

In [3]:
# Per-split preprocessing. Both pipelines produce 256x256 tensors normalised with
# mean=std=0.5 per channel, which maps ToTensor's [0, 1] range to [-1, 1].
data_transform = {
        # Training: random crop/rescale and horizontal flip for augmentation.
        "train": transforms.Compose([transforms.RandomResizedCrop(256),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
        # Validation must be deterministic so metrics are comparable across epochs:
        # resize the shorter side to 256, then take the central 256x256 region
        # (the ViT below is built with image_size=256).
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(256),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])}

In [4]:
# Wrap each split's paths/labels in the project dataset, using the split's transform.
train_dataset, val_dataset = (
    my_dataset.MyDataSet(images_path=split_paths,
                         images_class=split_labels,
                         transform=data_transform[split_name])
    for split_paths, split_labels, split_name in (
        (image_path_train, labels_train, "train"),
        (image_path_val, labels_val, "val"),
    )
)

In [5]:
BATCH_SIZE = 16
# Number of DataLoader worker processes, bounded by:
#   - the CPU count (os.cpu_count() may return None when it cannot be determined,
#     which would make min() raise TypeError, so fall back to 1),
#   - the batch size (0 = load in the main process when BATCH_SIZE is 1),
#   - a hard cap of 8.
nw = min([os.cpu_count() or 1, BATCH_SIZE if BATCH_SIZE > 1 else 0, 8])  # number of workers
print('Using {} dataloader workers every process'.format(nw))

Using 8 dataloader workers every process


In [6]:
# Batch loaders for each split; both use the dataset's own collate_fn and pinned
# host memory for faster host-to-GPU copies.
# Training batches are reshuffled every epoch. Validation order does not change the
# aggregated metrics, so it is kept fixed (shuffle=False) for reproducible evaluation.
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=BATCH_SIZE,
                                           shuffle=True,
                                           pin_memory=True,
                                           num_workers=nw,
                                           collate_fn=train_dataset.collate_fn)

val_loader = torch.utils.data.DataLoader(val_dataset,
                                         batch_size=BATCH_SIZE,
                                         shuffle=False,
                                         pin_memory=True,
                                         num_workers=nw,
                                         collate_fn=val_dataset.collate_fn)

In [7]:
# Prefer the first CUDA GPU when one is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [8]:
# Vision Transformer classifier from vit_pytorch. image_size matches the 256-pixel
# crops produced by data_transform; 256 / 32 gives an 8 x 8 grid of patches.
model = ViT(
    num_classes=20,   # NOTE(review): assumed to equal the number of label values in *-anno.txt
    image_size=256,
    patch_size=32,
    dim=1024,         # embedding width of each token
    depth=6,          # number of transformer blocks
    heads=16,         # attention heads per block
    mlp_dim=2048,     # hidden width of the feed-forward layers
    dropout=0.1,
    emb_dropout=0.1,  # dropout on the patch embeddings
)
model = model.to(device)

In [9]:
# Optimisation hyperparameters.
epochs = 80   # number of training epochs
lr = 1e-3     # initial SGD learning rate
lrf = 1e-2    # learning rate reached after `epochs` epochs, as a fraction of lr (cosine schedule)

In [10]:
loss_function = torch.nn.CrossEntropyLoss()

# Only parameters that require gradients are handed to the optimizer.
pg = list(filter(lambda param: param.requires_grad, model.parameters()))
optimizer = optim.SGD(pg, lr=lr, momentum=0.9, weight_decay=5E-5)


def lf(x):
    """Learning-rate multiplier for epoch ``x`` used by LambdaLR.

    Follows half a cosine wave: 1.0 at ``x == 0`` decaying to ``lrf`` at
    ``x == epochs`` (reads the global ``epochs`` and ``lrf``).
    """
    cosine_factor = (1 + math.cos(x * math.pi / epochs)) / 2
    return cosine_factor * (1 - lrf) + lrf


scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
best_acc = 0

In [11]:
# Train for `epochs` epochs, evaluating on the validation set after each one.
for epoch in range(epochs):
    # Re-enable training mode every epoch: utils.evaluate runs after each epoch and
    # presumably switches the model to eval mode (NOTE(review): confirm in utils.py).
    # Without this, dropout would stay disabled from the second epoch on.
    model.train()
    accu_loss = torch.zeros(1).to(device)  # accumulated training loss
    accu_num = torch.zeros(1).to(device)  # accumulated number of correctly predicted samples
    optimizer.zero_grad()

    sample_num = 0
    num_batches = 0  # counted explicitly so an empty loader does not leave it undefined
    data_loader = tqdm(train_loader)
    for step, data in enumerate(data_loader):
        images, labels = data
        images = images.to(device)
        labels = labels.to(device)  # moved once and reused for accuracy and loss

        num_batches = step + 1
        sample_num += images.shape[0]

        pred = model(images)
        pred_classes = torch.max(pred, dim=1)[1]  # predicted class; [1] selects the argmax indices
        accu_num += torch.eq(pred_classes, labels).sum()

        loss = loss_function(pred, labels)
        loss.backward()
        accu_loss += loss.detach()

        data_loader.desc = "[train epoch {}] loss: {:.3f}, acc: {:.3f}".format(epoch,
                                                                               accu_loss.item() / num_batches,
                                                                               accu_num.item() / sample_num)
        optimizer.step()  # update parameters
        optimizer.zero_grad()

    # max(..., 1) guards against ZeroDivisionError if the loader yielded no batches.
    train_loss = accu_loss.item() / max(num_batches, 1)
    train_acc = accu_num.item() / max(sample_num, 1)
    scheduler.step()
    val_loss, val_acc = utils.evaluate(model=model,
                                       data_loader=val_loader,
                                       device=device,
                                       epoch=epoch)
    # Track the best validation accuracy seen so far (initialised in the optimizer cell).
    if val_acc > best_acc:
        best_acc = val_acc

[train epoch 0] loss: 2.494, acc: 0.385: 100%|██████████| 749/749 [01:41<00:00,  7.40it/s]
[valid epoch 0] loss: 2.482, acc: 0.395: 100%|██████████| 215/215 [00:20<00:00, 10.61it/s]
[train epoch 1] loss: 2.428, acc: 0.387: 100%|██████████| 749/749 [01:37<00:00,  7.65it/s]
[valid epoch 1] loss: 2.406, acc: 0.394: 100%|██████████| 215/215 [00:20<00:00, 10.71it/s]
[train epoch 2] loss: 2.397, acc: 0.393: 100%|██████████| 749/749 [01:38<00:00,  7.60it/s]
[valid epoch 2] loss: 2.385, acc: 0.367: 100%|██████████| 215/215 [00:19<00:00, 10.97it/s]
[train epoch 3] loss: 2.368, acc: 0.392: 100%|██████████| 749/749 [01:37<00:00,  7.66it/s]
[valid epoch 3] loss: 2.356, acc: 0.399: 100%|██████████| 215/215 [00:19<00:00, 10.85it/s]
[train epoch 4] loss: 2.360, acc: 0.390: 100%|██████████| 749/749 [01:37<00:00,  7.66it/s]
[valid epoch 4] loss: 2.356, acc: 0.379: 100%|██████████| 215/215 [00:19<00:00, 10.88it/s]
[train epoch 5] loss: 2.343, acc: 0.396: 100%|██████████| 749/749 [01:37<00:00,  7.67it/s]