# MNIST Example with Data Logging in DataFed


## Import Libraries


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from m3util.util.IO import make_folder
from datafed_torchflow.pytorch import TorchLogger


## Parameters to Update


## Builds the CNN


In [None]:
# Define the CNN architecture
class SimpleCNN(nn.Module):
    """Small convolutional classifier for single-channel digit images.

    Architecture: two stages of (3x3 conv -> ReLU -> 2x2 max-pool), followed by
    two fully connected layers. With padding=1, each 3x3 convolution keeps the
    height and width. Each pool halves them, so a 28x28 input reaches the
    classifier as a 64 x 7 x 7 feature map. Inputs whose spatial size does not
    reduce to 7x7 raise an error in ``fc1``.

    Args:
        in_channels: Number of input image channels (1 for MNIST).
        num_classes: Number of output logits (10 for the digits 0-9).
    """

    def __init__(self, in_channels=1, num_classes=10):
        super().__init__()
        # Convolutional layers
        self.conv1 = nn.Conv2d(
            in_channels=in_channels, out_channels=32, kernel_size=3, padding=1
        )
        self.conv2 = nn.Conv2d(
            in_channels=32, out_channels=64, kernel_size=3, padding=1
        )

        # Max pooling layer, shared by both stages (halves height and width)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Fully connected layers; 64 * 7 * 7 is the flattened feature size
        # of a 28x28 input after the two pooling stages.
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, num_classes)  # one raw logit per class

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``."""
        # Apply convolutional layers with ReLU and max pooling
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))

        # Flatten everything except the batch dimension. Unlike
        # ``view(-1, 64 * 7 * 7)``, this never folds leftover spatial elements
        # into the batch axis. A wrongly sized input therefore fails loudly in
        # fc1 instead of silently producing predictions for the wrong number
        # of samples.
        x = torch.flatten(x, 1)

        # Apply fully connected layers with ReLU and final output
        x = F.relu(self.fc1(x))
        return self.fc2(x)


## Define transformations for data preprocessing


In [None]:
# Preprocessing applied to every MNIST image: convert it to a PyTorch tensor,
# then standardize it with the MNIST mean (0.1307) and std (0.3081).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])


## Load the MNIST dataset


In [None]:
# Build the training split first, then the test split. Both are stored under
# ./data (download=True) and use the same preprocessing transform.
train_dataset, test_dataset = (
    datasets.MNIST(
        root="./data", train=use_train_split, download=True, transform=transform
    )
    for use_train_split in (True, False)
)

# Data loaders: reshuffle training batches of 64; iterate the test split in
# fixed order with batches of 1000.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)


## Instantiate the model, loss function, optimizer, and DataFed TorchLogger


In [None]:
# Run tag appended to both the DataFed collection path and the local model path.
suffix = "111024"
notebook_path = "./4_pytorch_logger.ipynb"


criterion = nn.CrossEntropyLoss()  # Loss function for multi-class classification

learning_rate = 0.001

# Build the network ONCE and bind the optimizer to that instance's parameters.
# If the optimizer were built from a separate SimpleCNN(), optimizer.step()
# would update weights that are never trained or logged.
model = SimpleCNN()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)  # Adam optimizer

# The logger receives the same model/optimizer pair that training uses.
model_dict = {"model": model, "optimizer": optimizer}

In [None]:
torchlogger = TorchLogger(
    # Model/optimizer pair whose checkpoints are logged.
    model_dict=model_dict,
    # Notebook recorded alongside each record (presumably for provenance; confirm).
    script_path=notebook_path,
    # Shape of one preprocessed training image tensor.
    input_data_shape=train_dataset[0][0].shape,
    # Remote DataFed location and local model directory share the run tag.
    DataFed_path=f"MEM679-Fall2024/Class/{suffix}",
    local_model_path=f"examples/model/{suffix}",
    logging=True,
)

## Training function

This function calls TorchLogger.save, which does the following:

1. Saves the model checkpoint
1. Identifies the appropriate metadata for the model (including DataFed provenance dependencies)
1. Identifies and navigates to the appropriate DataFed project and collection
1. Creates a DataFed data record with this metadata
1. Saves the model weights file or, if the user specified a local zip file instead, uses that zip file so that multiple files can be uploaded to the same DataFed data record
1. Uploads the zip file to the DataFed data record generated in the previous steps


In [None]:
def train(
    model,
    device,
    train_loader,
    optimizer,
    criterion,
    epoch,
    base_local_file_name,
    local_vars,
):
    """Train ``model`` for one epoch, then checkpoint it via ``torchlogger.save``.

    Progress is printed every 100 batches. After the epoch, the final batch's
    loss and the optimizer's learning rate are passed as metadata to the
    module-level ``torchlogger.save``.

    Args:
        model: Network to train; it is switched to training mode.
        device: Device each batch is moved to (should match the model's).
        train_loader: DataLoader yielding ``(data, target)`` batches.
        optimizer: Optimizer bound to ``model``'s parameters.
        criterion: Loss function applied to ``(output, target)``.
        epoch: Epoch number, used in the progress output and checkpoint name.
        base_local_file_name: Directory the checkpoint path is built under;
            created via ``make_folder`` if missing.
        local_vars: Caller variables forwarded unchanged to ``torchlogger.save``.

    Raises:
        ValueError: If ``train_loader`` yields no batches, so there is no loss
            to report or checkpoint.
    """
    make_folder(base_local_file_name)  # ensure the path exists to save the weights

    model.train()  # Set the model to training mode

    loss = None  # remains None only if the loader yields no batches
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()  # Zero the gradients

        # Forward pass
        output = model(data)
        loss = criterion(output, target)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        if batch_idx % 100 == 0:
            print(
                f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} "
                f"({100.0 * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}"
            )

    if loss is None:
        raise ValueError(
            "train_loader yielded no batches; there is no loss to log or checkpoint"
        )

    # The reported training loss is the loss of the epoch's final batch.
    final_loss = loss.item()
    file_name = f"MNIST_epoch_{epoch}_loss_{final_loss:.4e}"
    local_file_path = f"{base_local_file_name}/{file_name}.pkl"

    torchlogger.save(
        file_name,
        epoch=epoch,
        training_loss=final_loss,
        local_file_path=local_file_path,
        local_vars=local_vars,
        # Log the rate the optimizer is actually using rather than a global,
        # so the metadata cannot drift from the real training configuration.
        model_hyperparameters={"learning_rate": optimizer.param_groups[0]["lr"]},
    )


## Testing function


In [None]:
def test(model, device, test_loader, criterion):
    model.eval()  # Set the model to evaluation mode
    test_loss = 0
    correct = 0
    with torch.no_grad():  # Disable gradient calculation for evaluation
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)

            # Forward pass
            output = model(data)
            test_loss += criterion(output, target).item()  # Sum up the batch loss
            pred = output.argmax(
                dim=1, keepdim=True
            )  # Get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    accuracy = 100.0 * correct / len(test_loader.dataset)

    print(
        f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} "
        f"({accuracy:.2f}%)\n"
    )


## Instantiate the DataFed Configuration


## Train Model


In [None]:
# Train the network instance registered with the TorchLogger in model_dict,
# rather than a fresh SimpleCNN(). Presumably the logger checkpoints
# model_dict["model"]; confirm against TorchLogger.
model = model_dict["model"]
# Bind the optimizer to THIS model's parameters so optimizer.step() updates
# the weights being trained. Keep model_dict pointing at the optimizer in use
# (this only reaches the logger if it holds a reference to this same dict;
# confirm).
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
model_dict["optimizer"] = optimizer

In [None]:
# Train and test the CNN
# Select the GPU when PyTorch can see one; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model.to(device)

In [None]:
n_epochs = 5
for epoch in range(1, n_epochs + 1):
    # At notebook/module top level, locals() is the global namespace; a
    # snapshot of its items is forwarded to TorchLogger.save below.
    local_vars = locals()

    train(
        model=model,
        device=device,
        train_loader=train_loader,
        optimizer=optimizer,
        criterion=criterion,
        epoch=epoch,
        # Keep local weights under the same run tag (suffix) used for the
        # DataFed path, instead of a hard-coded, out-of-date tag.
        base_local_file_name=f"model/{suffix}/weights",
        local_vars=list(local_vars.items()),
    )
    test(model=model, device=device, test_loader=test_loader, criterion=criterion)