# Training on MNIST data

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter

# ---- Run configuration ----

# Data loading: samples per batch for the training and evaluation loaders.
batch_size, test_batch_size = 64, 1000

# Optimisation: number of epochs, the learning rate handed to the optimizer,
# and the `gamma` factor handed to the StepLR scheduler.
epochs = 10
lr = 1.0
gamma = 0.7

# Reproducibility: value passed to torch.manual_seed.
seed = 1

# Runtime switches.
no_accel = False    # True skips accelerator/CUDA/MPS selection and uses the CPU
dry_run = False     # True stops after the first logged batch and the first epoch
save_model = False  # True writes the trained state_dict to disk after training

# Logging: how many batches between progress reports, and the TensorBoard run
# directory.
log_interval = 10
tensorboard_log_dir = 'runs/mnist_experiment_1'

# --- Model Definition ---
class Net(nn.Module):
    """Small convolutional classifier that emits 10 log-probabilities per input.

    Layout: two 3x3 convolutions (1 -> 32 -> 64 channels, stride 1), a 2x2
    max-pool, dropout, then two linear layers (9216 -> 128 -> 10) with a
    second dropout in between.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        self.dropout1 = nn.Dropout(p=0.25)
        self.dropout2 = nn.Dropout(p=0.5)
        # 9216 = 64 * 12 * 12: the flattened feature size obtained when the
        # network is fed 1x28x28 images (28 -> 26 -> 24 via the unpadded
        # convolutions, then 12 after pooling).
        self.fc1 = nn.Linear(in_features=9216, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Return log-softmax class scores (over dim 1) for the batch ``x``."""
        # Convolutional feature extractor.
        features = F.relu(self.conv1(x))
        features = F.max_pool2d(F.relu(self.conv2(features)), 2)
        features = self.dropout1(features)

        # Classifier head on the flattened features (batch dim preserved).
        flat = torch.flatten(features, 1)
        hidden = self.dropout2(F.relu(self.fc1(flat)))
        return F.log_softmax(self.fc2(hidden), dim=1)


def train(current_epoch, model, device, train_loader, optimizer, writer):
    """Run one training pass over ``train_loader``.

    For every batch the model is updated with the negative log-likelihood
    loss. Whenever ``batch_idx`` is a multiple of the module-level
    ``log_interval`` the progress is printed and the loss is written to
    ``writer`` under ``'Loss/train'``. When the module-level ``dry_run`` flag
    is set, the pass stops right after the first such report.

    Args:
        current_epoch: Epoch number; used in the progress line and to compute
            the TensorBoard step as
            ``(current_epoch - 1) * len(train_loader) + batch_idx``.
        model: Module to train; switched to training mode here.
        device: Device the batches are moved to.
        train_loader: Iterable of ``(data, target)`` batches that also exposes
            ``__len__`` and a ``dataset`` attribute.
        optimizer: Optimizer stepping the model's parameters.
        writer: Object with an ``add_scalar(tag, value, step)`` method.
    """
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.to(device)
        target = target.to(device)

        optimizer.zero_grad()
        loss = F.nll_loss(model(data), target)
        loss.backward()
        optimizer.step()

        # Only every `log_interval`-th batch is reported.
        if batch_idx % log_interval != 0:
            continue

        loss_value = loss.item()
        batches_per_epoch = len(train_loader)
        # Samples seen so far are estimated from the current batch's size.
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            current_epoch, batch_idx * len(data), len(train_loader.dataset),
            100. * batch_idx / batches_per_epoch, loss_value))
        global_step = (current_epoch - 1) * batches_per_epoch + batch_idx
        writer.add_scalar('Loss/train', loss_value, global_step)
        if dry_run:
            break


def test(model, device, test_loader, writer, current_epoch):
    """Evaluate ``model`` on ``test_loader`` and report loss and accuracy.

    The summed negative log-likelihood and the number of correct argmax
    predictions are accumulated over all batches with gradients disabled,
    then divided by ``len(test_loader.dataset)``. The result is printed and
    written to ``writer`` under ``'Loss/test'`` and ``'Accuracy/test'`` with
    ``current_epoch`` as the step.

    Args:
        model: Module to evaluate; switched to evaluation mode here.
        device: Device the batches are moved to.
        test_loader: Iterable of ``(data, target)`` batches exposing a
            ``dataset`` attribute with a length.
        writer: Object with an ``add_scalar(tag, value, step)`` method.
        current_epoch: Step value attached to the logged scalars.
    """
    model.eval()
    summed_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(device)
            target = target.to(device)
            log_probs = model(data)
            # Sum (rather than average) per batch so that the division by the
            # dataset size below yields a per-sample mean.
            summed_loss += F.nll_loss(log_probs, target, reduction='sum').item()
            pred = log_probs.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    num_samples = len(test_loader.dataset)
    test_loss = summed_loss / num_samples
    accuracy = 100. * correct / num_samples

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, num_samples, accuracy))
    writer.add_scalar('Loss/test', test_loss, current_epoch)
    writer.add_scalar('Accuracy/test', accuracy, current_epoch)


# TensorBoard writer shared by the training and evaluation routines.
writer = SummaryWriter(tensorboard_log_dir)
print(f"TensorBoard logs will be saved to: {tensorboard_log_dir}")

# Device selection. The generic `torch.accelerator` API is preferred when this
# torch build provides it; otherwise CUDA, then MPS, are probed explicitly.
# Setting `no_accel` disables every non-CPU option.
accel_allowed = not no_accel
use_accel = (
    hasattr(torch, 'accelerator')
    and torch.accelerator.is_available()
    and accel_allowed
)

if use_accel:
    device = torch.accelerator.current_accelerator()
    print(f"Using accelerator: {device}")
elif accel_allowed and torch.cuda.is_available():
    device = torch.device("cuda")
    print("Using CUDA")
elif accel_allowed and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Using MPS")
else:
    device = torch.device("cpu")
    print("Using CPU")

torch.manual_seed(seed)

# Shuffling is part of the training procedure, not of the hardware setup, so
# it is configured for every run. Previously `shuffle` was only set inside the
# accelerator branch, which left the training loader unshuffled (the
# DataLoader default) whenever the script ran on CPU.
train_kwargs = {'batch_size': batch_size, 'shuffle': True}
test_kwargs = {'batch_size': test_batch_size, 'shuffle': False}

# A worker process and pinned host memory are only requested when the batches
# will be copied to a non-CPU device.
if use_accel or (device.type != 'cpu'):
    common_kwargs = {'num_workers': 1, 'pin_memory': True}
    train_kwargs.update(common_kwargs)
    test_kwargs.update(common_kwargs)

# Report the effective loader configuration on every run, CPU included.
print(f"Train Dataloader kwargs: {train_kwargs}")
print(f"Test Dataloader kwargs: {test_kwargs}")


# Constants handed to transforms.Normalize as (mean, std); one-element tuples,
# i.e. one value for a single image channel.
mnist_mean, mnist_std = (0.1307,), (0.3081,)
preprocessing_steps = [
    transforms.ToTensor(),
    transforms.Normalize(mnist_mean, mnist_std),
]
transform = transforms.Compose(preprocessing_steps)

# dataset1 is the training split (requested with download=True), dataset2 the
# evaluation split read from the same directory.
try:
    dataset1 = datasets.MNIST('../data', train=True, download=True, transform=transform)
    dataset2 = datasets.MNIST('../data', train=False, transform=transform)
except Exception as e:
    # Release the TensorBoard writer before propagating the failure.
    print(f"Error downloading or loading dataset: {e}")
    writer.close()
    raise

# Loaders used by the training loop and the evaluation routine.
train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
# Separate shuffled loader over the training split, used for the example plots.
vis_train_loader = torch.utils.data.DataLoader(dataset1, shuffle=True, batch_size=batch_size)


# Show the first few images of one shuffled training batch with their labels.
print("\nPlotting a few initial training examples...")
initial_examples = enumerate(vis_train_loader)
_, (initial_example_data, initial_example_targets) = next(initial_examples)

fig = plt.figure(figsize=(10, 4))
# Never index past the end of the batch (it may hold fewer than 6 samples).
num_initial_examples = min(6, len(initial_example_data))
for i in range(num_initial_examples):
    plt.subplot(2, 3, i + 1)
    img = initial_example_data[i][0] # Still normalized
    plt.imshow(img, cmap='gray', interpolation='none')
    plt.title("Ground Truth: {}".format(initial_example_targets[i]))
    plt.xticks([])
    plt.yticks([])

# Lay out the grid once, after every subplot has its image and title.
# Calling tight_layout at the top of each iteration recomputed the layout
# repeatedly and never took the last panel's title into account.
plt.tight_layout()
plt.show()

In [None]:
# Model, optimizer and learning-rate schedule (the schedule is stepped once
# per epoch with step_size=1, so `gamma` is applied after every epoch).
model = Net().to(device)
optimizer = optim.Adadelta(model.parameters(), lr=lr)
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

# Logging the graph is best-effort: a failure is reported but does not stop
# the run. The batch plotted earlier serves as the example input.
sample_input_for_graph = initial_example_data.to(device)
try:
    writer.add_graph(model, sample_input_for_graph)
except Exception as e:
    print(f"Could not log model graph: {e}")
else:
    print("Model graph logged to TensorBoard.")


print(f"Starting training for {epochs} epochs...")
first_epoch, last_epoch = 1, epochs
for epoch in range(first_epoch, last_epoch + 1):
    # One training pass, one evaluation pass, then decay the learning rate.
    train(epoch, model, device, train_loader, optimizer, writer)
    test(model, device, test_loader, writer, epoch)
    scheduler.step()
    if not dry_run:
        continue
    print("Dry run complete after one epoch.")
    break
print("Training finished.")

# Optionally persist the trained weights.
if save_model:
    torch.save(model.state_dict(), "mnist_cnn.pt")
    print("Model saved to mnist_cnn.pt")

print("\nShowing a few training examples with model predictions...")

# The writer is closed in `finally` so that its event file is flushed and
# released even if inference or plotting fails.
try:
    model.eval() # Set model to evaluation mode

    # `iter()` starts a fresh pass over the loader, so `next()` always yields
    # a batch here; no StopIteration fallback is needed.
    vis_batch_data, vis_batch_targets = next(iter(vis_train_loader))
    vis_batch_data, vis_batch_targets = vis_batch_data.to(device), vis_batch_targets.to(device)

    with torch.no_grad(): # No need to track gradients for inference
        output = model(vis_batch_data)
        predictions = output.argmax(dim=1, keepdim=False) # Get the index of the max log-probability

    # Move data back to CPU for plotting if it was on GPU
    vis_batch_data_cpu = vis_batch_data.cpu()
    vis_batch_targets_cpu = vis_batch_targets.cpu()
    predictions_cpu = predictions.cpu()

    num_images_to_show = 9
    fig = plt.figure(figsize=(9, 9))
    plt.suptitle("Model Predictions on Training Examples", fontsize=16)
    # Never index past the end of the batch.
    for i in range(min(num_images_to_show, len(vis_batch_data_cpu))):
        plt.subplot(3, 3, i + 1)

        image = vis_batch_data_cpu[i][0]
        plt.imshow(image, cmap='gray', interpolation='none')

        true_label = vis_batch_targets_cpu[i].item()
        predicted_label = predictions_cpu[i].item()

        # Green title for a correct prediction, red for a wrong one.
        title_color = 'green' if predicted_label == true_label else 'red'
        plt.title(f"True: {true_label}\nPred: {predicted_label}", color=title_color)
        plt.xticks([])
        plt.yticks([])

    # Lay out the grid once, after all panels have their titles; the `rect`
    # leaves the top 4% of the figure free for the suptitle.
    plt.tight_layout(rect=[0, 0, 1, 0.96])
    plt.show()
    print("Finished showing predictions.")
finally:
    # Close the TensorBoard SummaryWriter
    writer.close()
    print("TensorBoard writer closed.")