In [1]:
import utils
from torchvision import transforms
import my_dataset
import os
import torch
from tqdm import tqdm
from models.vit import ViT
from torch.utils.tensorboard import SummaryWriter

In [2]:
# Index of the saved checkpoint to resume from. It selects the training list
# ("checkpoint<k>0_0.3.txt") and the weights file ("model-<k>0.pth"), and the
# later cells treat it as a multiple of 10 epochs already completed.
checkpoint: int = 6

# Mini-batch size used by both the training and the validation DataLoader.
batch_size: int = 32

In [3]:
# Learning-rate tables with one entry per checkpoint (checkpoint k reads index k - 1).
# `lrs` is not referenced by the cells visible here; `lrs2` supplies the
# optimizer's starting learning rate when training resumes.
lrs = [
    1.98e-4,
    1.925e-4,
    1.8e-4,
    1.67e-4,
    1.5e-4,
    1.3e-4,
]
lrs2 = [
    3e-4,
    2.5e-4,
    2.3e-4,
    2e-4,
    1.8e-4,
    1.5e-4,
]
# TensorBoard writer shared by the training loop; event files are written
# under the relative directory 'logs'.
writer = SummaryWriter(log_dir = 'logs')

In [4]:
# Number of DataLoader worker processes: capped by the CPU count, the batch
# size (0 workers when batch_size <= 1) and a hard limit of 8.
# os.cpu_count() returns None when the CPU count cannot be determined, which
# would make min() raise TypeError, so fall back to a single CPU in that case.
available_cpus = os.cpu_count() or 1
nw = min(available_cpus, batch_size if batch_size > 1 else 0, 8)  # number of workers
print('Using {} dataloader workers every process'.format(nw))

Using 8 dataloader workers every process


In [5]:
# Sample lists: the training list is specific to the checkpoint being resumed,
# the validation list is the fixed CIFAR-10 validation file.
train_data = utils.read_file("./tracin_file/checkpoint" + str(checkpoint) + "0_0.3.txt")
val_data = utils.read_file("../cifar10/val_data.txt")

# Per-channel normalisation (x - 0.5) / 0.5, shared by both pipelines.
norm_mean = [0.5, 0.5, 0.5]
norm_std = [0.5, 0.5, 0.5]

data_transform = {
    # Training: random 32x32 crop after 4px padding plus horizontal flip.
    # NOTE(review): Resize(32) follows a crop that already yields 32x32, so it
    # looks redundant; it is kept so the augmentation pipeline is unchanged.
    "train": transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.Resize(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std),
    ]),
    # Validation: no augmentation, only tensor conversion and normalisation.
    "val": transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std),
    ]),
}

train_dataset = my_dataset.MyDataSet_CIFAR(images_path=train_data,
                                           transform=data_transform["train"])
val_dataset = my_dataset.MyDataSet_CIFAR(images_path=val_data,
                                         transform=data_transform["val"])

# Settings common to both loaders.
loader_kwargs = dict(batch_size=batch_size,
                     pin_memory=True,
                     num_workers=nw)

train_loader = torch.utils.data.DataLoader(train_dataset,
                                           shuffle=True,
                                           collate_fn=train_dataset.collate_fn,
                                           **loader_kwargs)

# Validation only aggregates loss/accuracy over the whole set, so shuffling
# adds nondeterminism without any benefit; iterate in a fixed order instead.
val_loader = torch.utils.data.DataLoader(val_dataset,
                                         shuffle=False,
                                         collate_fn=val_dataset.collate_fn,
                                         **loader_kwargs)

In [6]:
# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [7]:
# Hyper-parameters of the Vision Transformer: 32x32 inputs split into 4x4
# patches, 10 output classes, 6 transformer layers with 8 heads each.
vit_hparams = dict(
    image_size=32,
    patch_size=4,
    num_classes=10,
    dim=512,
    depth=6,
    heads=8,
    mlp_dim=512,
    dropout=0.1,
    emb_dropout=0.1,
)

# Build the network, then move its parameters to the selected device.
model = ViT(**vit_hparams)
model = model.to(device)

In [8]:
# Restore the weights saved for the selected checkpoint, mapping the stored
# tensors onto the current device.
weights_path = "./weights/model-" + str(checkpoint) + "0.pth"
checkpoint_state = torch.load(weights_path, map_location=device)
model.load_state_dict(checkpoint_state)

# Switch to training mode; as the cell's last expression this also echoes the
# module structure in the notebook output.
model.train()

ViT(
  (to_patch_embedding): Sequential(
    (0): Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=4, p2=4)
    (1): Linear(in_features=48, out_features=512, bias=True)
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (transformer): Transformer(
    (layers): ModuleList(
      (0): ModuleList(
        (0): PreNorm(
          (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (fn): Attention(
            (attend): Softmax(dim=-1)
            (to_qkv): Linear(in_features=512, out_features=1536, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=512, out_features=512, bias=True)
              (1): Dropout(p=0.1, inplace=False)
            )
          )
        )
        (1): PreNorm(
          (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (fn): FeedForward(
            (net): Sequential(
              (0): Linear(in_features=512, out_features=512, bias=True)
              (1): GELU(approximate=none)
    

In [9]:
# Classification loss.
loss_function = torch.nn.CrossEntropyLoss()

# Adam starts from the learning rate tabulated for this checkpoint
# (checkpoint k reads lrs2[k - 1]).
start_lr = lrs2[checkpoint - 1]
optimizer = torch.optim.Adam(model.parameters(), lr=start_lr)

# Cosine annealing with T_max set to the epochs remaining out of 100.
annealing_epochs = 100 - checkpoint * 10
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, annealing_epochs)

In [10]:
# TensorBoard scalar names, in the order the training loop writes them.
tags = [
    "train_loss",
    "train_acc",
    "val_loss",
    "val_acc",
    "learning_rate",
]

# Best validation accuracy seen so far (not updated by the cells visible here).
best_acc = 0

In [11]:
# Resume training from the loaded checkpoint.
# The schedule has 100 - 10 * checkpoint epochs left, but a single run of this
# cell is capped at MAX_EPOCHS_PER_RUN epochs. The cap is folded into the range
# bound instead of a manual counter with `break`.
MAX_EPOCHS_PER_RUN = 20
epochs_this_run = min(100 - 10 * checkpoint, MAX_EPOCHS_PER_RUN)

# Every scalar tag is suffixed with the starting epoch of this run, and the
# x-axis is offset by the epochs already covered by the checkpoint.
tag_suffix = "_" + str(checkpoint * 10)

for epoch in range(epochs_this_run):

    model.train()
    accu_loss = torch.zeros(1).to(device)  # accumulated loss over the epoch
    accu_num = torch.zeros(1).to(device)  # accumulated number of correct predictions
    optimizer.zero_grad()

    sample_num = 0
    num_batches = 0
    data_loader = tqdm(train_loader)
    for num_batches, (images, labels) in enumerate(data_loader, start=1):
        # Move each tensor to the device once and reuse it below.
        images = images.to(device)
        labels = labels.to(device)

        sample_num += images.shape[0]

        pred = model(images)

        # Index of the largest logit per sample ([1] selects the argmax indices).
        pred_classes = torch.max(pred, dim=1)[1]
        accu_num += torch.eq(pred_classes, labels).sum()

        loss = loss_function(pred, labels)
        loss.backward()

        accu_loss += loss.detach()

        # Running mean loss and running accuracy shown on the progress bar.
        data_loader.desc = "[train epoch {}] loss: {:.3f}, acc: {:.3f}".format(epoch,
                                                                               accu_loss.item() / num_batches,
                                                                               accu_num.item() / sample_num)
        optimizer.step()  # apply the parameter update
        optimizer.zero_grad()

    # An empty loader would otherwise fail below with an unrelated NameError or
    # ZeroDivisionError; report the real cause instead.
    if num_batches == 0:
        raise RuntimeError("train_loader yielded no batches; cannot compute epoch statistics")

    train_loss = accu_loss.item() / num_batches
    train_acc = accu_num.item() / sample_num
    val_loss, val_acc = utils.evaluate(model=model,
                                       data_loader=val_loader,
                                       device=device,
                                       epoch=epoch)

    # Values are listed in the same order as `tags`.
    global_epoch = epoch + 10 * checkpoint
    epoch_scalars = (train_loss,
                     train_acc,
                     val_loss,
                     val_acc,
                     optimizer.param_groups[0]["lr"])
    for tag, value in zip(tags, epoch_scalars):
        writer.add_scalar(tag + tag_suffix, value, global_epoch)

    scheduler.step()

# Push buffered events to disk so the last epochs are visible in TensorBoard.
writer.flush()

[train epoch 0] loss: 0.474, acc: 0.830: 100%|██████████| 329/329 [00:16<00:00, 20.07it/s]
[valid epoch 0] loss: 1.454, acc: 0.606: 100%|██████████| 313/313 [00:04<00:00, 63.48it/s]
[train epoch 1] loss: 0.438, acc: 0.848: 100%|██████████| 329/329 [00:15<00:00, 20.60it/s]
[valid epoch 1] loss: 1.709, acc: 0.573: 100%|██████████| 313/313 [00:04<00:00, 62.84it/s]
[train epoch 2] loss: 0.395, acc: 0.860: 100%|██████████| 329/329 [00:16<00:00, 20.56it/s]
[valid epoch 2] loss: 1.663, acc: 0.583: 100%|██████████| 313/313 [00:04<00:00, 62.78it/s]
[train epoch 3] loss: 0.369, acc: 0.871: 100%|██████████| 329/329 [00:15<00:00, 20.58it/s]
[valid epoch 3] loss: 1.752, acc: 0.572: 100%|██████████| 313/313 [00:04<00:00, 62.62it/s]
[train epoch 4] loss: 0.334, acc: 0.881: 100%|██████████| 329/329 [00:16<00:00, 20.54it/s]
[valid epoch 4] loss: 1.757, acc: 0.577: 100%|██████████| 313/313 [00:05<00:00, 62.46it/s]
[train epoch 5] loss: 0.329, acc: 0.884: 100%|██████████| 329/329 [00:16<00:00, 20.46it/s]