In [1]:
import os
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
from vit_pytorch import ViT
from torch.nn import CrossEntropyLoss

import utils
import criterion
from datasets.cpn_vit import CPNvit
from utils import ext_transforms as et

torch.cuda.set_device(0)
os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
devices = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'devices: {devices}')

  from .autonotebook import tqdm as notebook_tqdm


devices: cuda


In [2]:
# num class 1024
# 32 x 32 / block size 2^4 = 16 > randn crop

v = ViT(
    image_size = 512,
    patch_size = 16,
    num_classes = 1024,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
).to(devices)

# img = torch.randn(1, 3, 256, 256).to(devices)
# preds = v(img) # (1, 1000)

In [3]:
transform = et.ExtCompose([
            et.ExtRandomCrop(size=(512, 512), is_crop=True, pad_if_needed=True),
            et.ExtToTensor(),
            et.ExtNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
            
dst = CPNvit(root='/data1/sdi/datasets', datatype='CPN', image_set='train',
            transform=transform, is_rgb=True, dver='splits/v5/3')
loader = DataLoader(dst, batch_size=8,
                        shuffle=True, num_workers=2, drop_last=True)
print(f'len [train]: {len(dst)}')

len [train]: 374


In [4]:
optimizer = optim.SGD(v.parameters(), 
                        lr=0.1, 
                        weight_decay=5e-4,
                        momentum=0.9)
scheduler = utils.PolyLR(optimizer, 2000, power=0.9)

costfunction = CrossEntropyLoss()

In [5]:
for epoch in range(0, 2000):

    v.train()
    running_loss = 0.0
    running_correct = 0.0

    for i, (images, labels) in enumerate(loader):
        images = images.to(devices)
        labels = labels.to(devices)

        outputs = v(images)
        probs = nn.Softmax(dim=1)(outputs)
        preds = torch.max(probs, 1)[1]

        optimizer.zero_grad()
        loss = costfunction(outputs, labels)
        loss.backward()

        optimizer.step()
        running_loss += loss.item() * images.size(0)
        running_correct += torch.sum(preds == labels)

    scheduler.step()
    epoch_loss = running_loss / len(loader.dataset)
    epoch_acc = running_correct.float() / len(loader.dataset)
    
    print(f'running correct: {epoch_acc:.4f}')

running correct: 0.0053
running correct: 0.0080
running correct: 0.0107


In [3]:
from torchsummary import summary

summary(v, (3, 512, 512))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
         Rearrange-1            [-1, 1024, 768]               0
            Linear-2           [-1, 1024, 1024]         787,456
           Dropout-3           [-1, 1025, 1024]               0
         LayerNorm-4           [-1, 1025, 1024]           2,048
            Linear-5           [-1, 1025, 3072]       3,145,728
           Softmax-6       [-1, 16, 1025, 1025]               0
           Dropout-7       [-1, 16, 1025, 1025]               0
            Linear-8           [-1, 1025, 1024]       1,049,600
           Dropout-9           [-1, 1025, 1024]               0
        Attention-10           [-1, 1025, 1024]               0
          PreNorm-11           [-1, 1025, 1024]               0
        LayerNorm-12           [-1, 1025, 1024]           2,048
           Linear-13           [-1, 1025, 2048]       2,099,200
             GELU-14           [-1, 102