In [2]:
!pip install -q dagshub mlflow

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m251.0/251.0 kB[0m [31m6.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m26.7/26.7 MB[0m [31m61.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.7/5.7 MB[0m [31m90.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m233.2/233.2 kB[0m [31m16.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m147.8/147.8 kB[0m [31m9.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m114.6/114.6 kB[0m [31m7.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m85.0/85.0 kB[0m [31m6.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m76.4/76.4 kB[0m [31m5.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import mlflow
import dagshub

# Initialize DagsHub (also configures MLflow to log to the DagsHub tracking server)
dagshub.init(repo_owner='s.carlosj.28', repo_name='moe_image_class', mlflow=True)

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Reproducibility: torch.manual_seed seeds the global RNG used for model weight
# initialisation and for the shuffling order of the training DataLoader.
SEED = 42
torch.manual_seed(SEED)

# Hyperparameters
num_epochs = 50
batch_size = 64
learning_rate = 0.001

# Data loading and preprocessing.
# Resize MNIST digits to 32x32: CNNModel's fc1 layer is sized for 32x32 inputs.
# ToTensor scales pixels to [0, 1]; Normalize(0.5, 0.5) maps them to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Page-locked host memory speeds up host->GPU copies; it is pointless on CPU.
pin_memory = device.type == "cuda"
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=pin_memory)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=pin_memory)

# Define the CNN model
class CNNModel(nn.Module):
    """Small CNN classifier: two conv+ReLU+max-pool stages, then two linear layers.

    Both 3x3 convolutions use padding=1, so they preserve spatial size. Each
    2x2 max-pool halves it, so a square input of side ``input_size`` reaches
    the classifier as a 64 x (input_size // 4) x (input_size // 4) feature map
    (64 * 8 * 8 for the default 32x32 input).

    Args:
        in_channels: number of input image channels (default 1, grayscale).
        num_classes: number of output logits (default 10).
        input_size: side length of the square input images (default 32).
    """

    def __init__(self, in_channels=1, num_classes=10, input_size=32):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        # Two floor-halving pools: floor(floor(n / 2) / 2) == n // 4.
        feat_side = input_size // 4
        self.fc1 = nn.Linear(64 * feat_side * feat_side, 128)
        self.fc2 = nn.Linear(128, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map an image batch to unnormalised class logits (one row per sample)."""
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        # Flatten only the trailing (C, H, W) dims so the batch dimension is
        # preserved. The old view(-1, 64*8*8) silently turned a wrongly-sized
        # input into extra "samples"; a size mismatch now raises in fc1.
        x = torch.flatten(x, start_dim=-3)
        x = self.relu(self.fc1(x))
        return self.fc2(x)

# Training objective: cross-entropy over the raw logits produced by CNNModel.
criterion = nn.CrossEntropyLoss()

# Build the network and move it onto the selected device (Module.to returns the module).
model = CNNModel()
model = model.to(device)

# Adam over every trainable parameter of the network.
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

# Training function
def train(model, dataloader, criterion, optimizer, device):
    """Run one training epoch and return its loss and accuracy.

    Args:
        model: the network to train; it is switched to train mode.
        dataloader: iterable of (images, labels) batches.
        criterion: loss function called as criterion(outputs, labels).
        optimizer: optimizer over model's parameters; stepped once per batch.
        device: device that each batch is moved to before the forward pass.

    Returns:
        (epoch_loss, epoch_acc): epoch_loss is the per-sample mean loss over the
        whole epoch; epoch_acc is the percentage of correctly classified samples.

    Raises:
        ValueError: if the dataloader yields no samples.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in tqdm(dataloader):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_n = labels.size(0)
        # Assumes criterion uses reduction='mean' (the nn.CrossEntropyLoss default).
        # Weighting by batch size keeps a short final batch from skewing the
        # epoch average, unlike averaging the per-batch means.
        running_loss += loss.item() * batch_n
        _, predicted = outputs.max(1)
        total += batch_n
        correct += predicted.eq(labels).sum().item()

    if total == 0:
        raise ValueError("train(): dataloader yielded no samples")

    epoch_loss = running_loss / total
    epoch_acc = 100. * correct / total
    return epoch_loss, epoch_acc

# Evaluation function
def evaluate(model, dataloader, device):
    """Compute classification accuracy (in percent) of ``model`` on ``dataloader``.

    The model is put in eval mode and run without gradient tracking. Each
    (inputs, targets) batch is moved to ``device``, and the arg-max logit is
    taken as the predicted class.
    """
    model.eval()
    n_correct, n_seen = 0, 0

    with torch.no_grad():
        for batch_x, batch_y in tqdm(dataloader):
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            preds = model(batch_x).max(dim=1).indices
            n_seen += batch_y.size(0)
            n_correct += (preds == batch_y).sum().item()

    return 100. * n_correct / n_seen

# MLflow experiment
mlflow.set_experiment("MNIST_CNN_Baseline")

with mlflow.start_run(run_name="cnn_baseline"):
    # Log all hyperparameters in one call (a single tracking-server request).
    mlflow.log_params({
        "model": "CNN",
        "num_epochs": num_epochs,
        "batch_size": batch_size,
        "learning_rate": learning_rate,
    })

    # Training loop: train one epoch, then evaluate on the test set.
    test_accuracy = None
    for epoch in range(num_epochs):
        train_loss, train_acc = train(model, trainloader, criterion, optimizer, device)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {train_loss:.4f}, Accuracy: {train_acc:.2f}%")

        test_accuracy = evaluate(model, testloader, device)
        mlflow.log_metrics(
            {
                "train_loss": train_loss,
                "train_accuracy": train_acc,
                "test_accuracy": test_accuracy,
            },
            step=epoch,
        )

    # Final evaluation. The weights do not change after the last epoch's
    # evaluation, so that result is the final accuracy. Only run a fresh pass
    # over the test set when no epoch was run (num_epochs == 0).
    if test_accuracy is None:
        test_accuracy = evaluate(model, testloader, device)
    print(f"Final Test Accuracy: {test_accuracy:.2f}%")
    mlflow.log_metric("final_test_accuracy", test_accuracy)

    # Save the model
    mlflow.pytorch.log_model(model, "model")

print("Training completed.")




Open the following link in your browser to authorize the client:
https://dagshub.com/login/oauth/authorize?state=65c9d552-a566-4b30-8258-d7a979d46e59&client_id=32b60ba385aa7cecf24046d8195a71c07dd345d9657977863b52e7748e0f0f28&middleman_request_id=bbcfb9e18996a770f2e84e05b3e0e3bbdc05cf3f5cc2da8ef75388e21bf00188




Output()

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 47981755.41it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 1866555.11it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 2293912.44it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 2934010.28it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw




2024/10/21 15:09:27 INFO mlflow.tracking.fluent: Experiment with name 'MNIST_CNN_Baseline' does not exist. Creating a new experiment.
100%|██████████| 938/938 [02:02<00:00,  7.68it/s]


Epoch 1/50, Loss: 0.1677, Accuracy: 95.04%


100%|██████████| 157/157 [00:12<00:00, 12.89it/s]
100%|██████████| 938/938 [02:08<00:00,  7.28it/s]


Epoch 2/50, Loss: 0.0446, Accuracy: 98.60%


100%|██████████| 157/157 [00:11<00:00, 13.25it/s]
100%|██████████| 938/938 [02:11<00:00,  7.15it/s]


Epoch 3/50, Loss: 0.0307, Accuracy: 99.04%


100%|██████████| 157/157 [00:11<00:00, 14.00it/s]
100%|██████████| 938/938 [02:12<00:00,  7.08it/s]


Epoch 4/50, Loss: 0.0237, Accuracy: 99.26%


100%|██████████| 157/157 [00:11<00:00, 13.93it/s]
100%|██████████| 938/938 [02:10<00:00,  7.21it/s]


Epoch 5/50, Loss: 0.0178, Accuracy: 99.40%


100%|██████████| 157/157 [00:11<00:00, 13.52it/s]
100%|██████████| 938/938 [02:10<00:00,  7.18it/s]


Epoch 6/50, Loss: 0.0138, Accuracy: 99.58%


100%|██████████| 157/157 [00:12<00:00, 12.83it/s]
100%|██████████| 938/938 [02:15<00:00,  6.93it/s]


Epoch 7/50, Loss: 0.0116, Accuracy: 99.63%


100%|██████████| 157/157 [00:12<00:00, 12.85it/s]
100%|██████████| 938/938 [02:09<00:00,  7.23it/s]


Epoch 8/50, Loss: 0.0093, Accuracy: 99.70%


100%|██████████| 157/157 [00:12<00:00, 12.72it/s]
100%|██████████| 938/938 [02:11<00:00,  7.15it/s]


Epoch 9/50, Loss: 0.0072, Accuracy: 99.76%


100%|██████████| 157/157 [00:12<00:00, 12.59it/s]
100%|██████████| 938/938 [02:10<00:00,  7.20it/s]


Epoch 10/50, Loss: 0.0071, Accuracy: 99.75%


100%|██████████| 157/157 [00:12<00:00, 12.62it/s]
100%|██████████| 938/938 [02:11<00:00,  7.15it/s]


Epoch 11/50, Loss: 0.0040, Accuracy: 99.86%


100%|██████████| 157/157 [00:12<00:00, 12.95it/s]
100%|██████████| 938/938 [02:14<00:00,  6.98it/s]


Epoch 12/50, Loss: 0.0047, Accuracy: 99.84%


100%|██████████| 157/157 [00:12<00:00, 12.63it/s]
100%|██████████| 938/938 [02:09<00:00,  7.24it/s]


Epoch 13/50, Loss: 0.0055, Accuracy: 99.83%


100%|██████████| 157/157 [00:12<00:00, 12.61it/s]
100%|██████████| 938/938 [02:11<00:00,  7.11it/s]


Epoch 14/50, Loss: 0.0063, Accuracy: 99.80%


100%|██████████| 157/157 [00:12<00:00, 12.82it/s]
100%|██████████| 938/938 [02:10<00:00,  7.19it/s]


Epoch 15/50, Loss: 0.0034, Accuracy: 99.90%


100%|██████████| 157/157 [00:12<00:00, 12.70it/s]
100%|██████████| 938/938 [02:11<00:00,  7.14it/s]


Epoch 16/50, Loss: 0.0042, Accuracy: 99.87%


100%|██████████| 157/157 [00:12<00:00, 12.71it/s]
100%|██████████| 938/938 [02:11<00:00,  7.16it/s]


Epoch 17/50, Loss: 0.0049, Accuracy: 99.83%


100%|██████████| 157/157 [00:12<00:00, 12.72it/s]
100%|██████████| 938/938 [02:10<00:00,  7.19it/s]


Epoch 18/50, Loss: 0.0029, Accuracy: 99.90%


100%|██████████| 157/157 [00:12<00:00, 12.97it/s]
100%|██████████| 938/938 [02:11<00:00,  7.13it/s]


Epoch 19/50, Loss: 0.0026, Accuracy: 99.92%


100%|██████████| 157/157 [00:12<00:00, 12.89it/s]
100%|██████████| 938/938 [02:10<00:00,  7.17it/s]


Epoch 20/50, Loss: 0.0045, Accuracy: 99.85%


100%|██████████| 157/157 [00:12<00:00, 12.78it/s]
100%|██████████| 938/938 [02:12<00:00,  7.08it/s]


Epoch 21/50, Loss: 0.0028, Accuracy: 99.92%


100%|██████████| 157/157 [00:12<00:00, 12.69it/s]
100%|██████████| 938/938 [02:10<00:00,  7.19it/s]


Epoch 22/50, Loss: 0.0013, Accuracy: 99.96%


100%|██████████| 157/157 [00:12<00:00, 12.81it/s]
100%|██████████| 938/938 [02:10<00:00,  7.16it/s]


Epoch 23/50, Loss: 0.0037, Accuracy: 99.88%


100%|██████████| 157/157 [00:11<00:00, 13.66it/s]
100%|██████████| 938/938 [02:12<00:00,  7.09it/s]


Epoch 24/50, Loss: 0.0028, Accuracy: 99.92%


100%|██████████| 157/157 [00:10<00:00, 14.53it/s]
100%|██████████| 938/938 [02:10<00:00,  7.20it/s]


Epoch 25/50, Loss: 0.0027, Accuracy: 99.91%


100%|██████████| 157/157 [00:12<00:00, 12.95it/s]
100%|██████████| 938/938 [02:12<00:00,  7.10it/s]


Epoch 26/50, Loss: 0.0037, Accuracy: 99.89%


100%|██████████| 157/157 [00:12<00:00, 12.97it/s]
100%|██████████| 938/938 [02:10<00:00,  7.18it/s]


Epoch 27/50, Loss: 0.0019, Accuracy: 99.94%


100%|██████████| 157/157 [00:10<00:00, 15.23it/s]
100%|██████████| 938/938 [02:11<00:00,  7.11it/s]


Epoch 28/50, Loss: 0.0017, Accuracy: 99.95%


100%|██████████| 157/157 [00:12<00:00, 12.93it/s]
100%|██████████| 938/938 [02:10<00:00,  7.18it/s]


Epoch 29/50, Loss: 0.0031, Accuracy: 99.91%


100%|██████████| 157/157 [00:12<00:00, 12.96it/s]
100%|██████████| 938/938 [02:11<00:00,  7.13it/s]


Epoch 30/50, Loss: 0.0024, Accuracy: 99.92%


100%|██████████| 157/157 [00:12<00:00, 12.72it/s]
100%|██████████| 938/938 [02:09<00:00,  7.22it/s]


Epoch 31/50, Loss: 0.0022, Accuracy: 99.93%


100%|██████████| 157/157 [00:11<00:00, 13.65it/s]
100%|██████████| 938/938 [02:10<00:00,  7.19it/s]


Epoch 32/50, Loss: 0.0005, Accuracy: 99.99%


100%|██████████| 157/157 [00:10<00:00, 14.63it/s]
100%|██████████| 938/938 [02:12<00:00,  7.10it/s]


Epoch 33/50, Loss: 0.0048, Accuracy: 99.86%


100%|██████████| 157/157 [00:11<00:00, 13.14it/s]
100%|██████████| 938/938 [02:10<00:00,  7.21it/s]


Epoch 34/50, Loss: 0.0019, Accuracy: 99.94%


100%|██████████| 157/157 [00:11<00:00, 13.34it/s]
100%|██████████| 938/938 [02:10<00:00,  7.18it/s]


Epoch 35/50, Loss: 0.0007, Accuracy: 99.97%


100%|██████████| 157/157 [00:12<00:00, 12.88it/s]
100%|██████████| 938/938 [02:10<00:00,  7.20it/s]


Epoch 36/50, Loss: 0.0022, Accuracy: 99.95%


100%|██████████| 157/157 [00:12<00:00, 12.64it/s]
100%|██████████| 938/938 [02:11<00:00,  7.11it/s]


Epoch 37/50, Loss: 0.0028, Accuracy: 99.91%


100%|██████████| 157/157 [00:12<00:00, 12.84it/s]
100%|██████████| 938/938 [02:11<00:00,  7.13it/s]


Epoch 38/50, Loss: 0.0029, Accuracy: 99.92%


100%|██████████| 157/157 [00:12<00:00, 12.85it/s]
100%|██████████| 938/938 [02:11<00:00,  7.12it/s]


Epoch 39/50, Loss: 0.0023, Accuracy: 99.94%


100%|██████████| 157/157 [00:14<00:00, 11.12it/s]
100%|██████████| 938/938 [02:11<00:00,  7.12it/s]


Epoch 40/50, Loss: 0.0006, Accuracy: 99.98%


100%|██████████| 157/157 [00:12<00:00, 12.78it/s]
100%|██████████| 938/938 [02:10<00:00,  7.17it/s]


Epoch 41/50, Loss: 0.0026, Accuracy: 99.92%


100%|██████████| 157/157 [00:12<00:00, 13.06it/s]
100%|██████████| 938/938 [02:12<00:00,  7.11it/s]


Epoch 42/50, Loss: 0.0020, Accuracy: 99.94%


100%|██████████| 157/157 [00:12<00:00, 12.92it/s]
100%|██████████| 938/938 [02:13<00:00,  7.04it/s]


Epoch 43/50, Loss: 0.0021, Accuracy: 99.93%


100%|██████████| 157/157 [00:12<00:00, 12.99it/s]
100%|██████████| 938/938 [02:13<00:00,  7.03it/s]


Epoch 44/50, Loss: 0.0014, Accuracy: 99.97%


100%|██████████| 157/157 [00:12<00:00, 12.65it/s]
100%|██████████| 938/938 [02:19<00:00,  6.74it/s]


Epoch 45/50, Loss: 0.0016, Accuracy: 99.95%


100%|██████████| 157/157 [00:11<00:00, 14.04it/s]
100%|██████████| 938/938 [02:16<00:00,  6.87it/s]


Epoch 46/50, Loss: 0.0028, Accuracy: 99.91%


100%|██████████| 157/157 [00:12<00:00, 12.86it/s]
100%|██████████| 938/938 [02:14<00:00,  6.97it/s]


Epoch 47/50, Loss: 0.0014, Accuracy: 99.96%


100%|██████████| 157/157 [00:14<00:00, 11.04it/s]
100%|██████████| 938/938 [02:19<00:00,  6.70it/s]


Epoch 48/50, Loss: 0.0023, Accuracy: 99.94%


100%|██████████| 157/157 [00:10<00:00, 15.57it/s]
100%|██████████| 938/938 [02:15<00:00,  6.94it/s]


Epoch 49/50, Loss: 0.0022, Accuracy: 99.94%


100%|██████████| 157/157 [00:12<00:00, 12.94it/s]
100%|██████████| 938/938 [02:19<00:00,  6.75it/s]


Epoch 50/50, Loss: 0.0009, Accuracy: 99.98%


100%|██████████| 157/157 [00:11<00:00, 13.73it/s]
100%|██████████| 157/157 [00:12<00:00, 12.88it/s]


Final Test Accuracy: 99.27%


2024/10/21 17:11:17 INFO mlflow.tracking._tracking_service.client: 🏃 View run cnn_baseline at: https://dagshub.com/s.carlosj.28/moe_image_class.mlflow/#/experiments/6/runs/a2ed479e5d444ab78f98022f890b4150.
2024/10/21 17:11:17 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: https://dagshub.com/s.carlosj.28/moe_image_class.mlflow/#/experiments/6.


Training completed.
