In [1]:

import torch
from torch import nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

from einops import rearrange, repeat
from einops.layers.torch import Rearrange
import logging

# device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
# log 
# record the process of training
logging.basicConfig(
    filename='vit.log',  
    level=logging.INFO,         
    format='%(asctime)s %(message)s',  
)
logger_all = logging.getLogger("logger_all")



In [3]:

# make single item into a tuple of 2 same items
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# FeedForward network with two fully connected layers and GELU activation
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

# Multi-head self-attention
class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        
        # Dimension of concatenated attention heads
        inner_dim = dim_head *  heads
        
        # whether the dimension of the attention's output is same to the expected dimension
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.dim_head = dim_head
        
        # qkv scaling
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)

        # get attention weight
        self.attend = nn.Softmax(dim = -1)
        
        self.dropout = nn.Dropout(dropout)

        # compute qkv at one time
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        
        # transfer the dimension of attention's output into expected output dimension
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        x = self.norm(x)
        
        # get q, k, v
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
        
        # get attention weight
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)
        attn = self.dropout(attn)
        
        # compute attention out put
        out = torch.matmul(attn, v)
        
        # reshape 
        out = rearrange(out, 'b h n d -> b n (h d)')
        
        # change dimension
        return self.to_out(out)

# Original Transformer
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        
        self.norm = nn.LayerNorm(dim)
        
        # several layers of rotational attention
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        return self.norm(x)

# original Vision Transformer
class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
        
        # layers for transferring patches into token
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        # positional embedding require gradient
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))

        # initial class token require gradient
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)
        
        # original transformer
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Linear(dim, num_classes)

    def forward(self, img):
        
        # Embed image patches into a sequence of tokens
        x = self.to_patch_embedding(img)
        
        b, n, _ = x.shape
        
        # add class token
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        
        # add positional embedding
        x += self.pos_embedding[:, :(n + 1)]
        
        x = self.dropout(x)

        x = self.transformer(x)

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        
        # get classification results
        return self.mlp_head(x)


In [4]:
# my_rovit = RoViT(
#     image_size = 256,
#     patch_size = 16,
#     num_classes = 100,
#     dim = 512,
#     depth = 6,
#     heads = 16,
#     mlp_dim = 1024,
#     dropout = 0.1,
#     emb_dropout = 0.1
# )


In [5]:
# a = torch.randn(8,3,256,256)
# b = my_rovit(a)

In [6]:

# reshape image size and transform image into tensor
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),  
])

# train dataset and dataloader
# use CIFAR-100
train_batch = 128
trainset = torchvision.datasets.CIFAR100(root="./cifar-100",
                                         train=True,
                                         download=True,
                                         transform=transform)
# num_worker should be same to the number of CPU core
trainloader = torch.utils.data.DataLoader(
    trainset, 
    batch_size=train_batch, 
    shuffle=True, 
    num_workers=16)


# test dataset and dataloader
# use CIFAR-100
test_batch = 256
testset = torchvision.datasets.CIFAR100(root="./cifar-100",
                                        train=False,
                                        download=True,
                                        transform=transform)
# num_worker should be same to the number of CPU core
testloader = torch.utils.data.DataLoader(
    testset, 
    batch_size=test_batch, 
    shuffle=False, 
    num_workers=16)

Files already downloaded and verified
Files already downloaded and verified


In [None]:
# original model
my_vit = ViT(
    image_size = 256,
    patch_size = 16,
    num_classes = 100,
    dim = 512,
    depth = 6,
    heads = 16,
    mlp_dim = 1024,
    dropout = 0.1,
    emb_dropout = 0.1
).to(device)

In [None]:
# import os

# folder_path = './ckpt/vit/'

# file_names = os.listdir(folder_path)
# file_names = ["rovit_last.pth"]+file_names
# files = [x for x in file_names if x[-4:]==".pth" and "last" not in x]
# cur_epoch = len(files)

In [None]:

# Loss function for classification
criterion = nn.CrossEntropyLoss()

lr = 5e-4
optimizer = optim.Adam(my_vit.parameters(), lr=lr)

# Training loop
num_epochs = 50
for epoch in range(num_epochs):
    # Set model to training mode
    my_vit.train() 
    
    # running loss for each epoch
    running_loss = 0.0
    
    # counter of processed items for each epoch 
    cnt = 0
    
    for i, (images, labels) in enumerate(trainloader):
        
        # Move data to GPU
        images = images.to(device)
        labels = labels.to(device)

        # classification result
        outputs = my_vit(images)
        
        # loss
        loss = criterion(outputs, labels)

        # optimize model
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # count running loss and processed items
        running_loss += loss.item()
        b = images.shape[0]
        cnt += b
        
        # log message for each iteration
        msg = f'Epoch [{epoch + 1}/{num_epochs}], Item [{cnt}/{len(trainset)}], Loss: {running_loss / (i+1):.4f}'
        logger_all.info(msg)
        print(msg)
    

    # log message for each epoch
    msg = f'Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(trainloader):.4f}'
    logger_all.info(msg)
    print(msg)
    
    # save checkpoint
    if (epoch + 1)%5==0:
        model_save_path = f'./ckpt/vit/vit_{epoch+1}.pth'
        torch.save(my_vit.state_dict(), model_save_path)
        print(f'Model parameters saved to {model_save_path}')

model_save_path = f'./ckpt/vit/vit_last.pth'
torch.save(my_vit.state_dict(), model_save_path)
print(f'Model parameters saved to {model_save_path}')

In [None]:

# my_vit.eval()  
# correct = 0
# total = 0
# 
# with torch.no_grad():
#     for images, labels in testloader:
#         images = images.to(device)
#         labels = labels.to(device)
#         outputs = my_vit(images)
#         _, predicted = torch.max(outputs.data, 1)
#         total += labels.size(0)
#         correct += (predicted == labels).sum().item()
# 
# accuracy = 100 * correct / total
# print(f'Accuracy on CIFAR-100 test images: {accuracy:.2f}%')

In [None]:
# my_vit = ViT(
#     image_size = 256,
#     patch_size = 16,
#     num_classes = 100,
#     dim = 512,
#     depth = 6,
#     heads = 16,
#     mlp_dim = 1024,
#     dropout = 0.1,
#     emb_dropout = 0.1
# ).to(device)
# model_save_path = './ckpt/vit/vit_50.pth'
# my_vit.load_state_dict(torch.load(model_save_path))
# my_vit = my_vit.to(device)
# my_vit.eval()  
# print("Model loaded and ready for inference.")

In [None]:
# a = torch.randn(8,3,256,256).to(device)
# b = my_vit(a)

In [None]:
# my_vit.eval()  
# correct = 0
# total = 0
# 
# with torch.no_grad():
#     for images, labels in testloader:
#         images = images.to(device)
#         labels = labels.to(device)
#         outputs = my_vit(images)
#         _, predicted = torch.max(outputs.data, 1)
#         total += labels.size(0)
#         correct += (predicted == labels).sum().item()
# 
# accuracy = 100 * correct / total
# print(f'Accuracy on CIFAR-100 test images: {accuracy:.2f}%')

In [12]:
res = []

In [13]:
# Evaluation for models at different epoches
for e in range(5,55,5):
    my_vit = ViT(
    image_size = 256,
    patch_size = 16,
    num_classes = 100,
    dim = 512,
    depth = 6,
    heads = 16,
    mlp_dim = 1024,
    dropout = 0.1,
    emb_dropout = 0.1
).to(device)
    
    # load checkpoint
    model_save_path = f'./ckpt/vit/vit_{e}.pth'
    my_vit.load_state_dict(torch.load(model_save_path))
    
    # mode model to GPU
    my_vit = my_vit.to(device)
    
    # Set model to evaluation mode
    my_vit.eval() 

    # Accuracy calculation
    # counter for test
    total = 0
    correct = 0
    with torch.no_grad():
        for images, labels in testloader:
            
            # Move data to GPU
            images = images.to(device)
            labels = labels.to(device)
            
            # classification results
            outputs = my_vit(images)
            
            # predicted class label
            _, predicted = torch.max(outputs.data, 1)
            
            # update counters
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    # get test result for an epoch
    accuracy = 100 * correct / total
    res.append(accuracy)

In [14]:
print(res)

[28.53, 33.39, 34.52, 36.76, 37.55, 37.77, 38.59, 37.21, 38.02, 38.76, 38.14, 39.11, 39.58, 39.0, 38.45, 38.9, 38.91, 39.15, 38.67, 39.36, 39.03, 39.61, 40.27, 40.07, 39.9, 39.98, 39.8, 40.08, 40.68, 39.94]


In [None]:
# import os

# # 指定文件夹路径
# directory = 'ckpt/vit'

# # 列出目录中的所有文件
# for filename in os.listdir(directory):
#     # 检查文件是否以 .pth 结尾
#     if filename.endswith(".pth"):
#         if not filename.endswith('5.pth') and not filename.endswith('0.pth'):
#                 file_path = os.path.join(directory, filename)
#                 # 删除文件
#                 os.remove(file_path)
#                 print(f"Deleted: {file_path}")