## Data and problem statement

The MNIST dataset contains handwritten digits as gray-scale images with pixel sizes of 28-by-28. The pixel values are converted to float numbers and normalized with minimum-maximum scaling. The dataset is labeled with ten categories, represents digits of 0-9.

A supervised image classification problem is proposed to demonstrate the application of the Swin Transformer. By taking preprocessed grayscale images as inputs, the Swin Transformer is trained to classify the ten image labels.


In [1]:
!pip install timm




Defaulting to user installation because normal site-packages is not writeable


In [2]:
!pip install --upgrade jupyter ipywidgets


Defaulting to user installation because normal site-packages is not writeable


In [3]:
import timm
print(timm.__version__)  # To verify installation


  from .autonotebook import tqdm as notebook_tqdm


1.0.15


In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import timm  
import os

In [5]:
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"

In [6]:

# Enable cuDNN optimization for performance boost
torch.backends.cudnn.benchmark = True

# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [7]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Reduce size for faster training
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])


In [8]:
# Load dataset 
train_dataset = torchvision.datasets.MNIST(root="./data", train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.MNIST(root="./data", train=False, transform=transform, download=True)

In [9]:
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=True, persistent_workers=True)
test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False, num_workers=4, pin_memory=True, persistent_workers=True)



In [10]:

# Define Swin Transformer Model
class SwinModel(nn.Module):
    def __init__(self, num_classes=10):
        super(SwinModel, self).__init__()
        self.swin = timm.create_model("swin_s3_tiny_224", pretrained=True, num_classes=num_classes)


    def forward(self, x):
        return self.swin(x)

In [11]:
# Initialize model
model = SwinModel(num_classes=10).to(device)

In [12]:
# Define loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)  # AdamW is better for transformers
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)  # Reduce LR every 3 epochs

In [13]:
# Enable mixed precision training
scaler = torch.amp.GradScaler("cuda")

In [14]:
from torch.amp import GradScaler, autocast

scaler = GradScaler(device="cuda")  # Use correct syntax

def train_model(model, train_loader, criterion, optimizer, epochs=5):
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for batch_idx, (images, labels) in enumerate(train_loader):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()

            with autocast(device_type="cuda"):  # Mixed precision
                outputs = model(images)
                loss = criterion(outputs, labels)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            total_loss += loss.item()

            # Show progress every 10 batches
            if batch_idx % 10 == 0:
                print(f"Epoch [{epoch+1}/{epochs}], Batch [{batch_idx}/{len(train_loader)}], Loss: {loss.item():.4f}")

        print(f"Epoch {epoch+1}/{epochs} completed. Average Loss: {total_loss / len(train_loader):.4f}")





In [15]:
train_model(model, train_loader, criterion, optimizer, epochs=1)


Epoch [1/1], Batch [0/30000], Loss: 2.1016
Epoch [1/1], Batch [10/30000], Loss: 1.6470
Epoch [1/1], Batch [20/30000], Loss: 2.0225
Epoch [1/1], Batch [30/30000], Loss: 1.3716
Epoch [1/1], Batch [40/30000], Loss: 3.6309
Epoch [1/1], Batch [50/30000], Loss: 2.3115
Epoch [1/1], Batch [60/30000], Loss: 3.3623
Epoch [1/1], Batch [70/30000], Loss: 2.6348
Epoch [1/1], Batch [80/30000], Loss: 1.6543
Epoch [1/1], Batch [90/30000], Loss: 2.4692
Epoch [1/1], Batch [100/30000], Loss: 2.1953
Epoch [1/1], Batch [110/30000], Loss: 2.9629
Epoch [1/1], Batch [120/30000], Loss: 1.8965
Epoch [1/1], Batch [130/30000], Loss: 1.6411
Epoch [1/1], Batch [140/30000], Loss: 3.6162
Epoch [1/1], Batch [150/30000], Loss: 2.0730
Epoch [1/1], Batch [160/30000], Loss: 2.0894
Epoch [1/1], Batch [170/30000], Loss: 2.0029
Epoch [1/1], Batch [180/30000], Loss: 2.9443
Epoch [1/1], Batch [190/30000], Loss: 2.3604
Epoch [1/1], Batch [200/30000], Loss: 2.4756
Epoch [1/1], Batch [210/30000], Loss: 2.6797
Epoch [1/1], Batch [2

In [16]:
def evaluate_model(model, test_loader, criterion):
    model.eval()  # Set to evaluation mode
    total_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            total_loss += loss.item()

            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

    avg_loss = total_loss / len(test_loader)
    accuracy = 100 * correct / total
    print(f"Test Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%")

    return avg_loss, accuracy

# Run evaluation
evaluate_model(model, test_loader, criterion)





Test Loss: 2.3016, Accuracy: 11.35%


(2.301613800764084, 11.35)