In [1]:
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.models as models

  from .autonotebook import tqdm as notebook_tqdm


In [6]:


# Train a randomly initialised ViT-B/32 on the CIFAR-10 training split and
# save the resulting weights to 'vit_model.pth'.

# Hyperparameters, gathered in one place so they can be tuned without
# hunting through the cell.
SEED = 42
IMAGE_SIZE = (224, 224)
BATCH_SIZE = 32
NUM_WORKERS = 2
NUM_CLASSES = 10
NUM_EPOCHS = 5  # Adjust number of epochs as needed
# NOTE(review): a previous run of this cell logged a loss that stayed around
# 2.0-2.1 for three epochs; 1e-3 may be too high for training a ViT from
# scratch with AdamW -- consider a lower rate and/or warmup. Value unchanged.
LEARNING_RATE = 0.001
LOG_EVERY = 100  # number of mini-batches averaged per loss report
CHECKPOINT_PATH = 'vit_model.pth'

# Seed PyTorch's RNG so weight init and shuffling are repeatable across runs.
torch.manual_seed(SEED)

# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define transformations: resize every image to IMAGE_SIZE, convert it to a
# tensor, then normalise each of the three channels with mean 0.5 / std 0.5.
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load dataset (downloaded into ./data if it is not already there)
train_set = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

# Define ViT model. `weights=None` requests random initialisation; it replaces
# the legacy `pretrained=False` argument, which torchvision has deprecated.
model = models.vit_b_32(weights=None, num_classes=NUM_CLASSES)
model = model.to(device)  # Move model to GPU if available

# Define loss function and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)

# Put the model in training mode explicitly: in a notebook the same object may
# have been switched to eval() by a previously executed cell.
model.train()

# Training loop
for epoch in range(NUM_EPOCHS):
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)  # Move data to GPU if available

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # Report the mean loss over each full window of LOG_EVERY batches.
        # Batches after the last full window of an epoch are not reported,
        # because running_loss is reset at the start of the next epoch.
        if i % LOG_EVERY == LOG_EVERY - 1:
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / LOG_EVERY))
            running_loss = 0.0

print('Finished Training')

# Save model (weights only, as a state_dict)
torch.save(model.state_dict(), CHECKPOINT_PATH)

Files already downloaded and verified


  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)


[1,   100] loss: 2.185
[1,   200] loss: 2.116
[1,   300] loss: 2.135
[1,   400] loss: 2.101
[1,   500] loss: 2.115
[1,   600] loss: 2.106
[1,   700] loss: 2.092
[1,   800] loss: 2.073
[1,   900] loss: 2.079
[1,  1000] loss: 2.083
[1,  1100] loss: 2.052
[1,  1200] loss: 2.133
[1,  1300] loss: 2.081
[1,  1400] loss: 2.059
[1,  1500] loss: 2.066
[2,   100] loss: 2.056
[2,   200] loss: 2.076
[2,   300] loss: 2.134
[2,   400] loss: 2.079
[2,   500] loss: 2.059
[2,   600] loss: 2.068
[2,   700] loss: 2.066
[2,   800] loss: 2.075
[2,   900] loss: 2.068
[2,  1000] loss: 2.030
[2,  1100] loss: 2.017
[2,  1200] loss: 2.038
[2,  1300] loss: 2.031
[2,  1400] loss: 2.004
[2,  1500] loss: 2.049
[3,   100] loss: 2.028
[3,   200] loss: 2.033
[3,   300] loss: 2.078
[3,   400] loss: 2.096
[3,   500] loss: 2.066
[3,   600] loss: 2.100
[3,   700] loss: 2.075
[3,   800] loss: 2.085
[3,   900] loss: 2.097
[3,  1000] loss: 2.071
[3,  1100] loss: 2.080
[3,  1200] loss: 2.076
[3,  1300] loss: 2.050
[3,  1400] 

In [8]:
# Assuming you have trained your ViT model and saved it as 'vit_model.pth'

import torch
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.models as models

# Evaluate the checkpoint saved as 'vit_model.pth' on the CIFAR-10 test split
# and print the percentage of correctly classified images.

# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the saved model. `weights=None` replaces the deprecated
# `pretrained=False`; the architecture must match the one that was trained.
model = models.vit_b_32(weights=None, num_classes=10)
# map_location lets a checkpoint written on a GPU machine load on a CPU-only
# one; weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state_dict needs.
state_dict = torch.load('vit_model.pth', map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model = model.to(device)

# Define transformations for validation data (same pipeline as in training:
# resize to 224x224, convert to tensor, normalise with mean 0.5 / std 0.5).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load validation dataset (CIFAR-10 with train=False); no shuffling is needed
# because only an aggregate accuracy is computed.
val_set = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
val_loader = DataLoader(val_set, batch_size=32, shuffle=False, num_workers=2)

# Set the model to evaluation mode
model.eval()

correct = 0
total = 0
# Gradients are not needed for inference, so skip building the autograd graph.
with torch.no_grad():
    for images, labels in val_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        # Predicted class = index of the largest output along the class axis.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# %d truncates the percentage to an integer.
print('Accuracy of the network on the validation images: %d %%' % (100 * correct / total))

Files already downloaded and verified
Accuracy of the network on the validation images: 22 %
