In [1]:
pip install timm


Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting timm
  Obtaining dependency information for timm from https://files.pythonhosted.org/packages/e7/0e/ef97f6d8c399bf5842af0dd5a4f5ac55b2f169d62e29ecbf7663e1cb1438/timm-1.0.9-py3-none-any.whl.metadata
  Downloading timm-1.0.9-py3-none-any.whl.metadata (42 kB)
     ---------------------------------------- 0.0/42.4 kB ? eta -:--:--
     --------- ------------------------------ 10.2/42.4 kB ? eta -:--:--
     -------------------------------------- 42.4/42.4 kB 412.3 kB/s eta 0:00:00
Downloading timm-1.0.9-py3-none-any.whl (2.3 MB)
   ---------------------------------------- 0.0/2.3 MB ? eta -:--:--
   ---------------------------------------- 0.0/2.3 MB ? eta -:--:--
   --- ------------------------------------ 0.2/2.3 MB 2.2 MB/s eta 0:00:01
   ------- -------------------------------- 0.5/2.3 MB 3.6 MB/s eta 0:00:01
   --------- ------------------------------ 0.6/2.3 MB 3.5 MB/s eta 0:00:01
   --------- ------

In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader


from torchvision import transforms

from torch.utils.data import random_split

# --- Tunable settings for the data pipeline ----------------------------------
SEED = 42              # makes the train/validation split and shuffling repeatable
TRAIN_FRACTION = 0.8   # share of the official training set kept for training
BATCH_SIZE = 64
NUM_WORKERS = 2

# Seed PyTorch's global RNG so stochastic steps that draw from it are
# reproducible when the notebook is run top to bottom from a fresh kernel.
torch.manual_seed(SEED)

# Preprocessing: resize to 224x224, convert to a tensor, then normalise the
# single grayscale channel with mean 0.5 / std 0.5.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Official Fashion-MNIST splits (downloaded to ./data on first use).
full_trainset = torchvision.datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)

classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot')

# Split the official training set into training and validation subsets
# (80% / 20%). The explicit seeded generator keeps the partition identical
# across runs, so validation numbers from different runs are comparable.
train_size = int(TRAIN_FRACTION * len(full_trainset))
val_size = len(full_trainset) - train_size
trainset, valset = random_split(
    full_trainset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SEED),
)
assert len(trainset) + len(valset) == len(full_trainset)

# Create dataloaders: only the training loader is shuffled.
trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
valloader = DataLoader(valset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
testloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import timm  # For Vision Transformer (ViT) models

# Load the pretrained ViT-Tiny model using timm
vit_tiny = timm.create_model('vit_tiny_patch16_224', pretrained=True)

# Freeze the pretrained backbone. Only the head's parameters are given to the
# optimizer below, so the backbone is never updated either way; without this,
# autograd would still backpropagate through every backbone layer on each step
# and keep accumulating gradients that optimizer.zero_grad() never clears.
vit_tiny.requires_grad_(False)

# Replace the final classification head to match 10 classes (Fashion-MNIST).
# A freshly constructed nn.Linear has trainable parameters, so the head is the
# only part of the model that receives gradients.
vit_tiny.head = nn.Linear(vit_tiny.head.in_features, 10)

# Move the model to the appropriate device (GPU if available)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
vit_tiny = vit_tiny.to(device)

# Define the optimizer (head parameters only) and the loss function
optimizer = optim.Adam(vit_tiny.head.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

model.safetensors:   0%|          | 0.00/22.9M [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


In [7]:

# Since Fashion-MNIST has grayscale images (1 channel),
# we need to repeat the single channel to 3 channels to fit into ViT.
def convert_grayscale_to_rgb(images):
    """Tile a single-channel image batch along the channel axis to get 3 channels.

    Args:
        images: tensor whose dimension 1 is the channel axis; the notebook
            passes batches taken from the Fashion-MNIST loaders.

    Returns:
        A tensor with dimension 1 repeated three times. A 4-D batch that
        already has 3 channels is returned unchanged, so calling this twice
        (or on RGB data) does not produce a 9-channel tensor.
    """
    if images.dim() == 4 and images.size(1) == 3:
        return images
    return images.repeat(1, 3, 1, 1)  # Repeat the 1 channel to 3 channels

# Training loop with validation: each epoch makes one optimisation pass over
# trainloader, then measures loss and accuracy on valloader.
epochs = 5
for epoch in range(epochs):
    # ---- Training phase ----
    vit_tiny.train()
    running_loss = 0.0
    train_progress = tqdm(trainloader, desc=f'Epoch {epoch + 1}/{epochs}', leave=False)

    for i, (inputs, labels) in enumerate(train_progress):
        inputs, labels = inputs.to(device), labels.to(device)

        # Convert grayscale images to RGB (1 channel to 3 channels)
        inputs = convert_grayscale_to_rgb(inputs)

        # Standard step: clear old gradients, forward, backward, update.
        optimizer.zero_grad()
        outputs = vit_tiny(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Show the running mean of the per-batch losses on the progress bar.
        running_loss += loss.item()
        train_progress.set_postfix(loss=running_loss / (i + 1))

    # ---- Validation phase ----
    vit_tiny.eval()
    val_loss_sum = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in valloader:
            inputs, labels = inputs.to(device), labels.to(device)

            # Convert grayscale images to RGB
            inputs = convert_grayscale_to_rgb(inputs)

            outputs = vit_tiny(inputs)
            loss = criterion(outputs, labels)

            # criterion (nn.CrossEntropyLoss with default settings) yields the
            # batch mean, so weight it by the batch size: a smaller final
            # batch must not count as much as a full one.
            n_samples = labels.size(0)
            val_loss_sum += loss.item() * n_samples

            _, predicted = torch.max(outputs, 1)
            total += n_samples
            correct += (predicted == labels).sum().item()

    # Exact per-sample averages over the whole validation set.
    val_loss = val_loss_sum / total
    val_accuracy = 100 * correct / total
    print(f'Epoch {epoch + 1}/{epochs}, Validation Loss: {val_loss:.3f}, Validation Accuracy: {val_accuracy:.2f}%')

print('Finished Training')

  x = F.scaled_dot_product_attention(
                                                                        

Epoch 1/5, Validation Loss: 0.346, Validation Accuracy: 87.72%


                                                                        

Epoch 2/5, Validation Loss: 0.337, Validation Accuracy: 88.03%


                                                                        

Epoch 3/5, Validation Loss: 0.325, Validation Accuracy: 88.08%


                                                                        

Epoch 4/5, Validation Loss: 0.309, Validation Accuracy: 88.83%


                                                                        

Epoch 5/5, Validation Loss: 0.306, Validation Accuracy: 89.17%
Finished Training


In [10]:
# Final evaluation on the held-out test set.
# Put the model in eval mode explicitly instead of relying on whatever mode
# the previously executed cell happened to leave it in.
vit_tiny.eval()

correct = 0
total = 0
with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)  # Move data to the same device as the model
        images = convert_grayscale_to_rgb(images)

        outputs = vit_tiny(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Report the number of images actually evaluated rather than a hard-coded count.
print(f'Accuracy of the network on the {total} test images: {100 * correct / total:.2f}%')


Accuracy of the network on the 10000 test images: 88.42%
