In [None]:
# Install runtime dependencies into this kernel's environment (%pip targets the kernel, unlike !pip).
# NOTE(review): versions are unpinned — pin them (pkg==X.Y.Z) so the notebook reproduces later.
%pip install torch torchvision tensorboard

In [None]:
import datetime

import torch
from torch.utils.tensorboard import SummaryWriter

# Give every execution its own TensorBoard run directory, named after its start time
# (e.g. logs/fit/20240131-142530), so separate runs do not share one directory.
run_timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = f"logs/fit/{run_timestamp}"
writer = SummaryWriter(log_dir)

In [None]:
def train(model, train_loader, criterion, optimizer, device, epoch,
          log_interval=100, tb_writer=None):
    """Run one training epoch, print progress, and log epoch loss/accuracy to TensorBoard.

    Args:
        model: module being trained; switched to training mode via ``model.train()``.
        train_loader: iterable of ``(data, target)`` batches; ``len(train_loader)`` and
            ``train_loader.dataset`` are used only for the progress message.
        criterion: loss called as ``criterion(output, target)``. Assumed to use mean
            reduction (PyTorch's default) — TODO confirm for the criterion actually used.
        optimizer: optimizer that updates ``model``'s parameters each batch.
        device: device every batch is moved to before the forward pass.
        epoch: epoch number; shown in progress output and used as the TensorBoard step.
        log_interval: print a progress line every ``log_interval`` batches (must be > 0).
        tb_writer: object with ``add_scalar(tag, value, step)``; when ``None`` the
            notebook-level ``writer`` is used.

    Returns:
        ``(train_loss, train_accuracy)``: sample-weighted mean loss over the epoch and
        accuracy in percent.

    Raises:
        ValueError: if ``train_loader`` yields no samples.
    """
    if tb_writer is None:
        tb_writer = writer  # notebook-level SummaryWriter
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        batch_size = target.size(0)
        # loss is a per-batch mean; weighting by batch size keeps a short final batch
        # from being over-counted in the epoch average.
        running_loss += loss.item() * batch_size
        _, predicted = output.max(1)
        total += batch_size
        correct += predicted.eq(target).sum().item()

        if batch_idx % log_interval == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                  f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')

    if total == 0:
        raise ValueError("train_loader yielded no samples")
    train_loss = running_loss / total
    train_accuracy = 100. * correct / total

    # Log the training loss and accuracy
    tb_writer.add_scalar('Loss/train', train_loss, epoch)
    tb_writer.add_scalar('Accuracy/train', train_accuracy, epoch)

    return train_loss, train_accuracy

In [None]:
def evaluate(model, test_loader, criterion, device, epoch, tb_writer=None):
    """Evaluate ``model`` on ``test_loader``, print a summary, and log loss/accuracy to TensorBoard.

    Args:
        model: module to evaluate; switched to eval mode via ``model.eval()`` and run
            under ``torch.no_grad()``.
        test_loader: iterable of ``(data, target)`` batches.
        criterion: loss called as ``criterion(output, target)``. Assumed to use mean
            reduction (PyTorch's default, same assumption as ``train``) — TODO confirm.
        device: device every batch is moved to.
        epoch: epoch number, used as the TensorBoard step.
        tb_writer: object with ``add_scalar(tag, value, step)``; when ``None`` the
            notebook-level ``writer`` is used.

    Returns:
        Test accuracy in percent.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    if tb_writer is None:
        tb_writer = writer  # notebook-level SummaryWriter
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            batch_size = target.size(0)
            # criterion returns a per-batch mean, so scale it back to a per-batch sum
            # before averaging over samples; otherwise the loss shrinks by ~batch_size.
            test_loss += criterion(output, target).item() * batch_size
            _, predicted = output.max(1)
            total += batch_size
            correct += predicted.eq(target).sum().item()

    if total == 0:
        raise ValueError("test_loader yielded no samples")
    # Divide by samples actually seen (not len(dataset)) so subset samplers stay correct.
    test_loss /= total
    accuracy = 100. * correct / total
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{total} ({accuracy:.2f}%)\n')

    # Log the test loss and accuracy
    tb_writer.add_scalar('Loss/test', test_loss, epoch)
    tb_writer.add_scalar('Accuracy/test', accuracy, epoch)

    return accuracy

In [None]:
# Train for `num_epochs` epochs, evaluating after each one and checkpointing the best weights.
# NOTE(review): model, train_loader, test_loader, criterion, optimizer, device and num_epochs
# are not defined in the cells shown here; they must exist in the kernel before this cell runs.
best_accuracy = 0
try:
    for epoch in range(1, num_epochs + 1):
        train_loss, train_accuracy = train(model, train_loader, criterion, optimizer, device, epoch)
        test_accuracy = evaluate(model, test_loader, criterion, device, epoch)

        # Log the learning rate of the first parameter group
        writer.add_scalar('Learning Rate', optimizer.param_groups[0]['lr'], epoch)

        # Keep only the weights with the highest test accuracy seen so far
        if test_accuracy > best_accuracy:
            best_accuracy = test_accuracy
            torch.save(model.state_dict(), 'best_vit_mnist.pth')

    print(f'Best accuracy: {best_accuracy:.2f}%')
finally:
    # Close the TensorBoard writer even if an epoch raises or the kernel is interrupted
    # (KeyboardInterrupt), so events logged so far are flushed to disk.
    writer.close()

In [None]:
# Load TensorBoard's notebook extension and show the dashboard inline. --logdir points at
# logs/fit, the parent of the timestamped run directories created by the writer setup.
%load_ext tensorboard
%tensorboard --logdir logs/fit