In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader
from tqdm import tqdm
import time
import matplotlib.pyplot as plt
from einops.layers.torch import Rearrange

# Define transformations and load datasets (the same preprocessing is used for both splits).
transform = transforms.Compose([
    # Collapse to a single channel: VisionTransformer below builds its patch embedding for
    # 1 input channel by default. (ViTs are usually trained on RGB; switching would require
    # changing that channel count and the Normalize statistics here.)
    transforms.Grayscale(),
    transforms.Resize((224, 224)),  # must match VisionTransformer's image_size (224 by default)
    transforms.ToTensor(),  # PIL image -> float tensor with values in [0, 1]
    transforms.Normalize(mean=[0.5], std=[0.5]),  # (x - 0.5) / 0.5 maps [0, 1] to [-1, 1]
])

# ImageFolder expects one sub-folder per class under each root directory.
train_dataset = datasets.ImageFolder(root='archive/train', transform=transform)
test_dataset = datasets.ImageFolder(root='archive/test', transform=transform)

# ImageFolder numbers classes from each root's own sorted sub-folder names, so a class folder
# missing from (or extra in) one split would silently shift every label after it. Fail loudly.
if train_dataset.class_to_idx != test_dataset.class_to_idx:
    raise ValueError(
        f"train/test class folders differ: {train_dataset.classes} vs {test_dataset.classes}"
    )

train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)

# Select the compute device: use the GPU when PyTorch can see one, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

# Define a simple Vision Transformer Model
class PatchEmbedding(nn.Module):
    """Cut an image into non-overlapping square patches and embed each one linearly.

    A Conv2d whose kernel size equals its stride is the same as tiling the image into
    ``patch_size x patch_size`` blocks and applying one shared linear map to every block.

    Args:
        patch_size: side length of each square patch (also used as the conv stride).
        in_channels: number of channels in the input image.
        embed_size: dimensionality of each output patch token.
    """

    def __init__(self, patch_size=16, in_channels=1, embed_size=1024):
        super(PatchEmbedding, self).__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv2d(
            in_channels=in_channels,
            out_channels=embed_size,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Turn images (B, C, H, W) into a patch-token sequence (B, N, E).

        N = (H // patch_size) * (W // patch_size); edge pixels that do not fill a whole
        patch are dropped by the unpadded convolution.
        """
        patch_grid = self.proj(x)                   # (B, E, H/P, W/P)
        flat = patch_grid.flatten(start_dim=2)      # (B, E, N)
        return flat.permute(0, 2, 1)                # (B, N, E)

class VisionTransformer(nn.Module):
    """Minimal Vision Transformer (ViT) image classifier.

    Pipeline: patch embedding -> prepend a learned CLS token -> add learned positional
    embeddings -> Transformer encoder -> classify from the CLS token's output.

    Args:
        image_size: side length of the square input images (patch count assumes H == W).
        patch_size: side length of each square patch.
        num_classes: number of output logits.
        dim: width of every token embedding.
        depth: number of stacked encoder layers.
        heads: attention heads per layer; must divide ``dim`` (nn.MultiheadAttention requirement).
        mlp_dim: hidden width of each encoder layer's feed-forward sublayer.
        channels: number of input image channels (default 1, i.e. grayscale).
    """

    def __init__(self, image_size=224, patch_size=16, num_classes=7, dim=1024, depth=6, heads=8,
                 mlp_dim=2048, channels=1):
        super().__init__()
        num_patches = (image_size // patch_size) ** 2
        # One learned position vector per patch plus one for the CLS slot.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.patch_embedding = PatchEmbedding(patch_size, channels, dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        # Tokens are laid out as (batch, sequence, dim). nn.TransformerEncoderLayer defaults to
        # batch_first=False, which would read dim 0 as the sequence axis and make attention mix
        # *different images* at the same token position instead of patches within one image.
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(dim, heads, mlp_dim, dropout=0.1, batch_first=True),
            num_layers=depth,
        )
        self.to_cls_token = nn.Identity()
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes),
        )

    def forward(self, img):
        """Return class logits (B, num_classes) for a batch of images (B, channels, H, W)."""
        p = self.patch_embedding(img)  # (B, N, dim)
        b, n, _ = p.shape
        cls_tokens = self.cls_token.expand(b, -1, -1)  # (B, 1, dim)
        x = torch.cat((cls_tokens, p), dim=1)  # (B, N + 1, dim)
        # Out-of-place add; slicing lets inputs with fewer patches reuse the prefix of the table.
        x = x + self.pos_embedding[:, : n + 1]
        x = self.transformer(x)
        x = self.to_cls_token(x[:, 0])  # CLS token output, (B, dim)
        return self.mlp_head(x)

# Initialize model, optimizer, loss function, and scheduler.
# Size the classifier head from the class folders ImageFolder actually found, so the number of
# logits cannot silently disagree with the labels the loaders produce (VisionTransformer's
# default is 7).
model = VisionTransformer(num_classes=len(train_dataset.classes)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
# Multiplies the learning rate by 0.95 each time scheduler.step() is called (once per epoch in training).
scheduler = ExponentialLR(optimizer, gamma=0.95)

# Training and validation loop. Per-epoch metrics are collected in lists for later plotting;
# no plotting is done in this cell.
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm

# Relies on model, optimizer, criterion, scheduler, train_loader, test_loader and device
# defined earlier in the notebook. The test split doubles as the validation set here.


def _train_one_epoch(model, loader, optimizer, criterion, device, desc):
    """Run one optimization pass over ``loader``; return the sample-weighted mean loss."""
    model.train()
    total_loss = 0.0
    for images, labels in tqdm(loader, desc=desc):
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Assuming a mean-reduced criterion (nn.CrossEntropyLoss default), re-weight by batch
        # size so a short final batch does not skew the epoch average.
        total_loss += loss.item() * images.size(0)
    return total_loss / len(loader.dataset)


def _evaluate(model, loader, criterion, device):
    """Return the sample-weighted mean loss of ``model`` over ``loader`` without updating weights."""
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            loss = criterion(model(images), labels)
            total_loss += loss.item() * images.size(0)
    return total_loss / len(loader.dataset)


# Lists to store metrics
train_losses = []
val_losses = []
learning_rates = []

num_epochs = 10
for epoch in range(num_epochs):
    # LR in effect for this epoch; the scheduler is only stepped after validation.
    current_lr = scheduler.get_last_lr()[0]

    avg_train_loss = _train_one_epoch(
        model, train_loader, optimizer, criterion, device,
        desc=f'Epoch {epoch+1}/{num_epochs}',
    )
    train_losses.append(avg_train_loss)
    learning_rates.append(current_lr)

    # Validation phase
    avg_val_loss = _evaluate(model, test_loader, criterion, device)
    val_losses.append(avg_val_loss)

    # Update learning rate for the next epoch
    scheduler.step()

    # Report the LR actually used during this epoch (the value stored in learning_rates),
    # not the already-decayed LR that get_last_lr() returns after scheduler.step().
    print(f'Epoch {epoch+1}, Train Loss: {avg_train_loss}, Val Loss: {avg_val_loss}, Learning Rate: {current_lr}')

<a style='text-decoration:none;line-height:16px;display:flex;color:#5B5B62;padding:10px;justify-content:end;' href='https://deepnote.com?utm_source=created-in-deepnote-cell&projectId=0a047d2a-49ae-49e7-9442-011956428446' target="_blank">
Created in <span style='font-weight:600;margin-left:4px;'>Deepnote</span></a>