In [2]:
import utils
from torchvision import transforms
import my_dataset
import os
import torch
from tqdm import tqdm
from vit_pytorch import ViT
from vit_pytorch.t2t import T2TViT
from vit_pytorch.cct import CCT
from linformer import Linformer
from torchsummary import summary

In [3]:
epochs = 60
lr = 2e-5
gamma = 0.7
batch_size = 64


In [4]:
train_data = utils.read_file("../cifar10/train_data.txt")
val_data = utils.read_file("../cifar10/val_data.txt")
test_data = utils.read_file("../cifar10/test_data.txt")

In [5]:
data_transform = {
        "train": transforms.Compose([transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
        "val": transforms.Compose([transforms.ToTensor(),
                                   transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])}

In [6]:
train_dataset = my_dataset.MyDataSet_CIFAR(images_path=train_data,
                        transform=data_transform["train"])

val_dataset = my_dataset.MyDataSet_CIFAR(images_path=val_data,
                        transform=data_transform["val"])

In [7]:
nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
print('Using {} dataloader workers every process'.format(nw))

Using 8 dataloader workers every process


In [8]:
train_loader = torch.utils.data.DataLoader(train_dataset,
                                            batch_size = batch_size,
                                            shuffle=True,
                                            pin_memory=True,
                                            num_workers=nw,
                                            collate_fn=train_dataset.collate_fn)

val_loader = torch.utils.data.DataLoader(val_dataset,
                                            batch_size = batch_size,
                                            shuffle=True,
                                            pin_memory=True,
                                            num_workers=nw,
                                            collate_fn=val_dataset.collate_fn)

In [9]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [12]:
# model = ViT(
#     image_size = 256,
#     patch_size = 32,
#     num_classes = 20,
#     dim = 1024,
#     depth = 6,
#     heads = 16,
#     mlp_dim = 2048,
#     dropout = 0.1,
#     emb_dropout = 0.1
# ).to(device)

# model = T2TViT(
#     dim = 512,
#     image_size = 32,
#     depth = 10,
#     heads = 12,
#     mlp_dim = 512,
#     num_classes = 10,
#     t2t_layers = ((7, 4), (3, 2), (3, 2)) # tuples of the kernel size and stride of each consecutive layers of the initial token to token module
# ).to(device)

model = CCT(
    img_size = 32,
    embedding_dim = 384,
    n_conv_layers = 2,
    kernel_size = 7,
    stride = 2,
    padding = 3,
    pooling_kernel_size = 3,
    pooling_stride = 2,
    pooling_padding = 1,
    num_layers = 14,
    num_heads = 6,
    mlp_radio = 3.,
    num_classes = 10,
    positional_embedding = 'learnable', # ['sine', 'learnable', 'none']
).to(device)

In [13]:
# loss function
loss_function = torch.nn.CrossEntropyLoss()
# optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# scheduler
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=gamma)

In [14]:
for epoch in range(epochs):
    
    model.train()
    accu_loss = torch.zeros(1).to(device)  # 累计损失
    accu_num = torch.zeros(1).to(device)  # 累计预测正确的样本数
    optimizer.zero_grad()

    sample_num = 0
    data_loader = tqdm(train_loader)
    for step, data in enumerate(data_loader):
        images, labels = data

        sample_num += images.shape[0]

        pred = model(images.to(device))
        
        pred_classes = torch.max(pred, dim=1)[1]  # 预测的类别，[1]是标签索引
       
        
        accu_num += torch.eq(pred_classes, labels.to(device)).sum()
        loss = loss_function(pred, labels.to(device))
        loss.backward()
        
        accu_loss += loss.detach()
        
        data_loader.desc = "[train epoch {}] loss: {:.3f}, acc: {:.3f}".format(epoch,
                                                                               accu_loss.item() / (step + 1),
                                                                               accu_num.item() / sample_num)
        optimizer.step()  # 更新
        optimizer.zero_grad()
    
    train_loss =  accu_loss.item() / (step + 1)
    train_acc = accu_num.item() / sample_num
    scheduler.step()
    val_loss, val_acc = utils.evaluate(model=model,
                                data_loader=val_loader,
                                device=device,
                                epoch=epoch)

[train epoch 0] loss: 1.583, acc: 0.428: 100%|██████████| 547/547 [00:27<00:00, 19.78it/s]
[valid epoch 0] loss: 1.377, acc: 0.505: 100%|██████████| 157/157 [00:01<00:00, 89.70it/s] 
[train epoch 1] loss: 1.298, acc: 0.538: 100%|██████████| 547/547 [00:24<00:00, 22.39it/s]
[valid epoch 1] loss: 1.263, acc: 0.547: 100%|██████████| 157/157 [00:01<00:00, 92.11it/s] 
[train epoch 2] loss: 1.201, acc: 0.574: 100%|██████████| 547/547 [00:24<00:00, 22.43it/s]
[valid epoch 2] loss: 1.205, acc: 0.566: 100%|██████████| 157/157 [00:01<00:00, 91.29it/s] 
[train epoch 3] loss: 1.133, acc: 0.599: 100%|██████████| 547/547 [00:24<00:00, 22.55it/s]
[valid epoch 3] loss: 1.184, acc: 0.579: 100%|██████████| 157/157 [00:01<00:00, 91.65it/s] 
[train epoch 4] loss: 1.092, acc: 0.616: 100%|██████████| 547/547 [00:24<00:00, 22.43it/s]
[valid epoch 4] loss: 1.164, acc: 0.587: 100%|██████████| 157/157 [00:01<00:00, 90.64it/s] 
[train epoch 5] loss: 1.061, acc: 0.627: 100%|██████████| 547/547 [00:24<00:00, 22.35