<a href="https://colab.research.google.com/github/NoCodeProgram/deepLearning/blob/main/transformer/tinyVIT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torchvision
import torch.nn as nn
import torchvision.transforms as transforms
import torch.optim as optim
from torch.utils.data import DataLoader
print(torch.__version__)

2.5.1+cu121


In [2]:
if torch.backends.mps.is_available():
    my_device = torch.device('mps')
elif torch.cuda.is_available():
    my_device = torch.device('cuda')
else:
    my_device = torch.device('cpu')

print(my_device)

cuda


In [3]:
# Data loading
transform_train = transforms.Compose([
    transforms.RandomVerticalFlip(),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Load datasets
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

train_loader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
test_loader = DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:05<00:00, 29.3MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [4]:
class PatchEmbed(nn.Module):
    def __init__(self, patch_size=4, in_channels=3, embed_dim=48):
        super().__init__()
        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size
        )

    def forward(self, x):
        x = self.proj(x)  # (B, embed_dim, H/patch_size, W/patch_size)
        x = x.flatten(2)  # (B, embed_dim, n_patches)
        x = x.transpose(1, 2)  # (B, n_patches, embed_dim)
        return x

In [7]:
class TinyViT(nn.Module):
    def __init__(self):
        super().__init__()

        img_size = 32
        patch_size = 4
        in_channels = 3
        embed_dim = 48
        num_heads = 4
        dropout = 0.1
        num_layers = 4
        num_classes = 10
        mlp_ratio = 4.0

        self.patch_embed = PatchEmbed(
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim
        )

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # Calculate number of patches for position embedding
        n_patches = (img_size // patch_size) ** 2  # Assuming 32x32 input
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, embed_dim))

        # Define encoder layer
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=int(embed_dim * mlp_ratio),
            dropout=dropout,
            batch_first=True,
            norm_first=True
        )

        # Create transformer encoder
        self.transformer = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers
        )

        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.GELU(),  # or nn.ReLU()
            nn.Dropout(dropout),
            nn.Linear(embed_dim, num_classes)
        )
    def forward(self, x):
        x = self.patch_embed(x)
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)

        x = x + self.pos_embed

        x = self.transformer(x)
        x = self.norm(x)
        x = x[:, 0]  # Take cls token
        x = self.head(x)
        return x

In [8]:
model = TinyViT()

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-3)

model = model.to(my_device)




In [None]:
for epoch in range(100):
    model.train()
    running_loss = 0.0

    for i, (images, labels) in enumerate(train_loader):
        images, labels = images.to(my_device), labels.to(my_device)

        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()


    print(f'Epoch [{epoch+1}], Step [{i+1}/{len(train_loader)}], '
            f'Loss: {running_loss/100:.4f}')
    running_loss = 0.0

    # Validation
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(my_device), labels.to(my_device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f'Test Accuracy: {accuracy:.2f}%')

Epoch [1], Step [391/391], Loss: 7.1850
Test Accuracy: 39.71%
Epoch [2], Step [391/391], Loss: 6.1883
Test Accuracy: 45.55%
Epoch [3], Step [391/391], Loss: 5.7046
Test Accuracy: 48.06%
Epoch [4], Step [391/391], Loss: 5.4545
Test Accuracy: 50.41%
Epoch [5], Step [391/391], Loss: 5.2948
Test Accuracy: 51.14%
Epoch [6], Step [391/391], Loss: 5.1815
Test Accuracy: 53.66%
Epoch [7], Step [391/391], Loss: 5.0526
Test Accuracy: 54.08%
Epoch [8], Step [391/391], Loss: 4.9464
Test Accuracy: 54.93%
Epoch [9], Step [391/391], Loss: 4.8681
Test Accuracy: 56.37%
Epoch [10], Step [391/391], Loss: 4.7644
Test Accuracy: 56.55%
Epoch [11], Step [391/391], Loss: 4.7315
Test Accuracy: 56.17%
Epoch [12], Step [391/391], Loss: 4.6344
Test Accuracy: 57.39%
Epoch [13], Step [391/391], Loss: 4.5733
Test Accuracy: 57.11%
Epoch [14], Step [391/391], Loss: 4.5229
Test Accuracy: 57.76%
Epoch [15], Step [391/391], Loss: 4.4779
Test Accuracy: 59.87%
Epoch [16], Step [391/391], Loss: 4.4258
Test Accuracy: 58.63%
E