In [2]:
!pip install -q dagshub mlflow

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m251.0/251.0 kB[0m [31m8.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m26.7/26.7 MB[0m [31m45.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.7/5.7 MB[0m [31m63.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m233.2/233.2 kB[0m [31m15.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m147.8/147.8 kB[0m [31m11.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m114.6/114.6 kB[0m [31m6.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m85.0/85.0 kB[0m [31m6.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m76.4/76.4 kB[0m [31m5.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import mlflow
import dagshub

# Initialize DagsHub (mlflow=True so the MLflow runs below are tracked in this repo)
dagshub.init(repo_owner='s.carlosj.28', repo_name='moe_image_class', mlflow=True)

# Reproducibility: seed PyTorch's random number generators before the model and
# the data loaders are created, so weight initialisation and shuffling repeat
# across "Restart & Run All".
# NOTE(review): some GPU kernels may still be non-deterministic; see the
# PyTorch reproducibility notes if bit-exact reruns are required.
seed = 42
torch.manual_seed(seed)

# Device configuration: use the GPU when one is visible, otherwise the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
num_epochs = 50        # full passes over the training set
batch_size = 64        # samples per optimisation step
learning_rate = 0.001  # Adam step size
num_layers = 3         # stacked Transformer encoder layers
num_heads = 4          # attention heads per encoder layer
d_model = 128          # token embedding width
d_ff = 256             # hidden width of each encoder feed-forward block

# Data loading and preprocessing
# Every image is resized to 32x32, converted to a tensor, and normalised with
# mean 0.5 / std 0.5 on its single channel.
preprocessing_steps = [
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
]
transform = transforms.Compose(preprocessing_steps)

# Arguments shared by the train and test splits of MNIST.
mnist_options = dict(root='./data', download=True, transform=transform)
trainset = torchvision.datasets.MNIST(train=True, **mnist_options)
testset = torchvision.datasets.MNIST(train=False, **mnist_options)

# Only the training loader shuffles; evaluation keeps the dataset order.
loader_options = dict(batch_size=batch_size, num_workers=2)
trainloader = DataLoader(trainset, shuffle=True, **loader_options)
testloader = DataLoader(testset, shuffle=False, **loader_options)

# Define the Transformer model
class TransformerModel(nn.Module):
    """Patch-based Transformer encoder classifier.

    The image is cut into non-overlapping square patches by a strided
    convolution, each patch becomes one ``d_model``-wide token, the tokens go
    through a stack of ``nn.TransformerEncoderLayer`` blocks, are averaged
    over the token axis and mapped to class logits by a linear layer.

    NOTE(review): no positional encoding is added to the tokens. Together
    with mean pooling this makes the logits independent of the order of the
    patches (dropout in training mode aside). Confirm this is intended for
    the baseline.

    Args:
        num_layers: Number of stacked encoder layers.
        num_heads: Attention heads per encoder layer.
        d_model: Width of each patch token.
        d_ff: Hidden width of the feed-forward block in each encoder layer.
        num_classes: Number of output logits.
        in_channels: Channels of the input image (default 1, as before).
        patch_size: Side length and stride of the square patches (default 8,
            as before). Input height/width should be multiples of this value;
            otherwise the convolution drops the remainder rows/columns.
    """

    def __init__(self, num_layers, num_heads, d_model, d_ff, num_classes,
                 in_channels=1, patch_size=8):
        super().__init__()
        # kernel_size == stride, so every pixel belongs to exactly one patch.
        self.patch_embed = nn.Conv2d(in_channels, d_model, kernel_size=patch_size, stride=patch_size)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=num_heads, dim_feedforward=d_ff, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc = nn.Linear(d_model, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for an image batch ``x``."""
        # (B, C, H, W) -> (B, d_model, H/p, W/p) -> (B, d_model, N) -> (B, N, d_model)
        x = self.patch_embed(x).flatten(2).transpose(1, 2)
        x = self.transformer(x)
        # Average the patch tokens into a single vector per image.
        x = x.mean(dim=1)
        return self.fc(x)

# Build the network from the hyperparameters above and move it to the chosen device
model = TransformerModel(
    num_layers=num_layers,
    num_heads=num_heads,
    d_model=d_model,
    d_ff=d_ff,
    num_classes=10,
)
model = model.to(device)

# Cross-entropy objective, optimised with Adam over all model parameters
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

# Training function
def train(model, dataloader, criterion, optimizer, device):
    """Run one training pass over ``dataloader`` and report epoch metrics.

    Args:
        model: Network to optimise; it is switched to training mode.
        dataloader: Iterable of ``(images, labels)`` batches.
        criterion: Loss callable applied to ``(outputs, labels)``.
            Assumed to return the mean loss over the batch (the default
            reduction of ``nn.CrossEntropyLoss``) — confirm if another
            reduction is used.
        optimizer: Optimizer that updates ``model``'s parameters.
        device: Device the batches are moved to before the forward pass.

    Returns:
        tuple[float, float]: ``(epoch_loss, epoch_acc)`` where ``epoch_loss``
        is the mean loss per sample and ``epoch_acc`` is the accuracy in
        percent, measured on the outputs produced during training. Both are
        ``0.0`` when the dataloader yields no samples.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    total = 0

    for images, labels in tqdm(dataloader):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        n_samples = labels.size(0)
        # Weight each batch mean by its size so a smaller final batch does not
        # count as much as a full one in the epoch average.
        loss_sum += loss.item() * n_samples
        total += n_samples
        correct += outputs.argmax(dim=1).eq(labels).sum().item()

    # An empty dataloader would otherwise raise ZeroDivisionError.
    if total == 0:
        return 0.0, 0.0

    epoch_loss = loss_sum / total
    epoch_acc = 100. * correct / total
    return epoch_loss, epoch_acc

# Evaluation function
def evaluate(model, dataloader, device):
    """Compute the classification accuracy of ``model`` over ``dataloader``.

    The model is switched to evaluation mode (and left in it) and gradients
    are disabled for the whole pass.

    Args:
        model: Network to evaluate.
        dataloader: Iterable of ``(images, labels)`` batches.
        device: Device the batches are moved to before the forward pass.

    Returns:
        float: Accuracy in percent, or ``0.0`` when the dataloader yields no
        samples.
    """
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in tqdm(dataloader):
            images, labels = images.to(device), labels.to(device)
            # The predicted class is the index of the largest logit.
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    # An empty dataloader would otherwise raise ZeroDivisionError.
    if total == 0:
        return 0.0

    accuracy = 100. * correct / total
    return accuracy

# MLflow experiment
mlflow.set_experiment("MNIST_Transformer_Baseline")

with mlflow.start_run(run_name="transformer_baseline"):
    # Log every hyperparameter with one batched call instead of one call per
    # value; the tracking server is remote, so fewer calls means fewer round trips.
    mlflow.log_params({
        "model": "Transformer",
        "num_layers": num_layers,
        "num_heads": num_heads,
        "d_model": d_model,
        "d_ff": d_ff,
        "num_epochs": num_epochs,
        "batch_size": batch_size,
        "learning_rate": learning_rate,
    })

    # Training loop: one optimisation pass, then a test-set pass, per epoch
    for epoch in range(num_epochs):
        train_loss, train_acc = train(model, trainloader, criterion, optimizer, device)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {train_loss:.4f}, Accuracy: {train_acc:.2f}%")

        test_accuracy = evaluate(model, testloader, device)

        # All three per-epoch metrics share the same step, so send them together.
        mlflow.log_metrics(
            {
                "train_loss": train_loss,
                "train_accuracy": train_acc,
                "test_accuracy": test_accuracy,
            },
            step=epoch,
        )

    # Final evaluation of the trained weights (also runs when num_epochs is 0)
    test_accuracy = evaluate(model, testloader, device)
    print(f"Final Test Accuracy: {test_accuracy:.2f}%")
    mlflow.log_metric("final_test_accuracy", test_accuracy)

    # Save the model as an artifact of this run
    mlflow.pytorch.log_model(model, "model")

print("Training completed.")




Open the following link in your browser to authorize the client:
https://dagshub.com/login/oauth/authorize?state=6a2a33ff-37f0-4909-a0bd-1db487d386e2&client_id=32b60ba385aa7cecf24046d8195a71c07dd345d9657977863b52e7748e0f0f28&middleman_request_id=734065ff5cb2a75a3c3624e71f36592071cad06a5d87e1a3eb70fcf1fec0da71




Output()

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:01<00:00, 9373705.43it/s] 


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 1199577.09it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz





Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 1867367.16it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 1939577.35it/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



2024/10/21 15:09:00 INFO mlflow.tracking.fluent: Experiment with name 'MNIST_Transformer_Baseline' does not exist. Creating a new experiment.
100%|██████████| 938/938 [01:52<00:00,  8.32it/s]


Epoch 1/50, Loss: 0.5852, Accuracy: 81.03%


100%|██████████| 157/157 [00:05<00:00, 26.51it/s]
100%|██████████| 938/938 [01:53<00:00,  8.25it/s]


Epoch 2/50, Loss: 0.2970, Accuracy: 90.56%


100%|██████████| 157/157 [00:06<00:00, 24.15it/s]
100%|██████████| 938/938 [01:50<00:00,  8.49it/s]


Epoch 3/50, Loss: 0.2469, Accuracy: 92.27%


100%|██████████| 157/157 [00:07<00:00, 21.94it/s]
100%|██████████| 938/938 [01:53<00:00,  8.27it/s]


Epoch 4/50, Loss: 0.2175, Accuracy: 93.05%


100%|██████████| 157/157 [00:08<00:00, 19.51it/s]
100%|██████████| 938/938 [01:56<00:00,  8.07it/s]


Epoch 5/50, Loss: 0.1949, Accuracy: 93.73%


100%|██████████| 157/157 [00:06<00:00, 26.12it/s]
100%|██████████| 938/938 [01:54<00:00,  8.21it/s]


Epoch 6/50, Loss: 0.1829, Accuracy: 94.20%


100%|██████████| 157/157 [00:06<00:00, 23.30it/s]
100%|██████████| 938/938 [01:52<00:00,  8.35it/s]


Epoch 7/50, Loss: 0.1714, Accuracy: 94.49%


100%|██████████| 157/157 [00:10<00:00, 15.12it/s]
100%|██████████| 938/938 [01:53<00:00,  8.24it/s]


Epoch 8/50, Loss: 0.1575, Accuracy: 94.96%


100%|██████████| 157/157 [00:06<00:00, 25.57it/s]
100%|██████████| 938/938 [02:00<00:00,  7.77it/s]


Epoch 9/50, Loss: 0.1497, Accuracy: 95.11%


100%|██████████| 157/157 [00:08<00:00, 18.69it/s]
100%|██████████| 938/938 [01:56<00:00,  8.04it/s]


Epoch 10/50, Loss: 0.1471, Accuracy: 95.23%


100%|██████████| 157/157 [00:06<00:00, 25.93it/s]
100%|██████████| 938/938 [01:54<00:00,  8.22it/s]


Epoch 11/50, Loss: 0.1356, Accuracy: 95.58%


100%|██████████| 157/157 [00:06<00:00, 25.90it/s]
100%|██████████| 938/938 [01:54<00:00,  8.17it/s]


Epoch 12/50, Loss: 0.1291, Accuracy: 95.82%


100%|██████████| 157/157 [00:08<00:00, 18.77it/s]
100%|██████████| 938/938 [01:56<00:00,  8.07it/s]


Epoch 13/50, Loss: 0.1263, Accuracy: 95.94%


100%|██████████| 157/157 [00:06<00:00, 22.43it/s]
100%|██████████| 938/938 [01:54<00:00,  8.17it/s]


Epoch 14/50, Loss: 0.1228, Accuracy: 95.96%


100%|██████████| 157/157 [00:06<00:00, 25.74it/s]
100%|██████████| 938/938 [01:55<00:00,  8.12it/s]


Epoch 15/50, Loss: 0.1167, Accuracy: 96.20%


100%|██████████| 157/157 [00:08<00:00, 19.02it/s]
100%|██████████| 938/938 [01:55<00:00,  8.10it/s]


Epoch 16/50, Loss: 0.1135, Accuracy: 96.19%


100%|██████████| 157/157 [00:07<00:00, 20.41it/s]
100%|██████████| 938/938 [01:54<00:00,  8.22it/s]


Epoch 17/50, Loss: 0.1129, Accuracy: 96.27%


100%|██████████| 157/157 [00:06<00:00, 25.76it/s]
100%|██████████| 938/938 [01:54<00:00,  8.16it/s]


Epoch 18/50, Loss: 0.1047, Accuracy: 96.51%


100%|██████████| 157/157 [00:06<00:00, 23.26it/s]
100%|██████████| 938/938 [01:55<00:00,  8.10it/s]


Epoch 19/50, Loss: 0.1033, Accuracy: 96.57%


100%|██████████| 157/157 [00:08<00:00, 19.09it/s]
100%|██████████| 938/938 [01:56<00:00,  8.03it/s]


Epoch 20/50, Loss: 0.1014, Accuracy: 96.67%


100%|██████████| 157/157 [00:06<00:00, 25.50it/s]
100%|██████████| 938/938 [01:57<00:00,  7.98it/s]


Epoch 21/50, Loss: 0.0955, Accuracy: 96.92%


100%|██████████| 157/157 [00:08<00:00, 19.41it/s]
100%|██████████| 938/938 [01:57<00:00,  7.95it/s]


Epoch 22/50, Loss: 0.0951, Accuracy: 96.90%


100%|██████████| 157/157 [00:06<00:00, 25.44it/s]
100%|██████████| 938/938 [01:56<00:00,  8.08it/s]


Epoch 23/50, Loss: 0.0948, Accuracy: 96.88%


100%|██████████| 157/157 [00:07<00:00, 22.28it/s]
100%|██████████| 938/938 [01:57<00:00,  7.96it/s]


Epoch 24/50, Loss: 0.0899, Accuracy: 97.13%


100%|██████████| 157/157 [00:07<00:00, 21.95it/s]
100%|██████████| 938/938 [01:55<00:00,  8.14it/s]


Epoch 25/50, Loss: 0.0863, Accuracy: 97.12%


100%|██████████| 157/157 [00:06<00:00, 25.01it/s]
100%|██████████| 938/938 [01:57<00:00,  8.00it/s]


Epoch 26/50, Loss: 0.0863, Accuracy: 97.16%


100%|██████████| 157/157 [00:08<00:00, 17.89it/s]
100%|██████████| 938/938 [01:59<00:00,  7.82it/s]


Epoch 27/50, Loss: 0.0845, Accuracy: 97.20%


100%|██████████| 157/157 [00:07<00:00, 19.83it/s]
100%|██████████| 938/938 [01:57<00:00,  7.99it/s]


Epoch 28/50, Loss: 0.0843, Accuracy: 97.20%


100%|██████████| 157/157 [00:06<00:00, 24.33it/s]
100%|██████████| 938/938 [01:56<00:00,  8.08it/s]


Epoch 29/50, Loss: 0.0806, Accuracy: 97.41%


100%|██████████| 157/157 [00:08<00:00, 19.39it/s]
100%|██████████| 938/938 [01:57<00:00,  7.97it/s]


Epoch 30/50, Loss: 0.0805, Accuracy: 97.35%


100%|██████████| 157/157 [00:06<00:00, 25.53it/s]
100%|██████████| 938/938 [01:55<00:00,  8.12it/s]


Epoch 31/50, Loss: 0.0805, Accuracy: 97.28%


100%|██████████| 157/157 [00:08<00:00, 18.89it/s]
100%|██████████| 938/938 [01:57<00:00,  7.97it/s]


Epoch 32/50, Loss: 0.0760, Accuracy: 97.36%


100%|██████████| 157/157 [00:07<00:00, 21.72it/s]
100%|██████████| 938/938 [01:57<00:00,  8.00it/s]


Epoch 33/50, Loss: 0.0755, Accuracy: 97.48%


100%|██████████| 157/157 [00:08<00:00, 19.50it/s]
100%|██████████| 938/938 [01:56<00:00,  8.04it/s]


Epoch 34/50, Loss: 0.0739, Accuracy: 97.49%


100%|██████████| 157/157 [00:06<00:00, 25.15it/s]
100%|██████████| 938/938 [01:57<00:00,  7.97it/s]


Epoch 35/50, Loss: 0.0761, Accuracy: 97.50%


100%|██████████| 157/157 [00:08<00:00, 19.31it/s]
100%|██████████| 938/938 [01:54<00:00,  8.22it/s]


Epoch 36/50, Loss: 0.0682, Accuracy: 97.72%


100%|██████████| 157/157 [00:07<00:00, 20.08it/s]
100%|██████████| 938/938 [01:55<00:00,  8.12it/s]


Epoch 37/50, Loss: 0.0692, Accuracy: 97.70%


100%|██████████| 157/157 [00:06<00:00, 25.45it/s]
100%|██████████| 938/938 [01:58<00:00,  7.91it/s]


Epoch 38/50, Loss: 0.0688, Accuracy: 97.71%


100%|██████████| 157/157 [00:08<00:00, 19.46it/s]
100%|██████████| 938/938 [01:55<00:00,  8.09it/s]


Epoch 39/50, Loss: 0.0648, Accuracy: 97.89%


100%|██████████| 157/157 [00:07<00:00, 21.04it/s]
100%|██████████| 938/938 [01:58<00:00,  7.94it/s]


Epoch 40/50, Loss: 0.0679, Accuracy: 97.71%


100%|██████████| 157/157 [00:07<00:00, 21.32it/s]
100%|██████████| 938/938 [01:57<00:00,  8.02it/s]


Epoch 41/50, Loss: 0.0639, Accuracy: 97.85%


100%|██████████| 157/157 [00:06<00:00, 25.01it/s]
100%|██████████| 938/938 [01:57<00:00,  8.01it/s]


Epoch 42/50, Loss: 0.0659, Accuracy: 97.85%


100%|██████████| 157/157 [00:07<00:00, 19.71it/s]
100%|██████████| 938/938 [01:59<00:00,  7.85it/s]


Epoch 43/50, Loss: 0.0635, Accuracy: 97.88%


100%|██████████| 157/157 [00:07<00:00, 21.30it/s]
100%|██████████| 938/938 [01:57<00:00,  7.96it/s]


Epoch 44/50, Loss: 0.0625, Accuracy: 97.89%


100%|██████████| 157/157 [00:06<00:00, 24.45it/s]
100%|██████████| 938/938 [01:57<00:00,  7.98it/s]


Epoch 45/50, Loss: 0.0618, Accuracy: 97.96%


100%|██████████| 157/157 [00:07<00:00, 20.48it/s]
100%|██████████| 938/938 [01:58<00:00,  7.92it/s]


Epoch 46/50, Loss: 0.0592, Accuracy: 98.05%


100%|██████████| 157/157 [00:06<00:00, 25.29it/s]
100%|██████████| 938/938 [01:56<00:00,  8.06it/s]


Epoch 47/50, Loss: 0.0621, Accuracy: 97.88%


100%|██████████| 157/157 [00:06<00:00, 22.87it/s]
100%|██████████| 938/938 [01:59<00:00,  7.86it/s]


Epoch 48/50, Loss: 0.0591, Accuracy: 97.97%


100%|██████████| 157/157 [00:06<00:00, 24.94it/s]
100%|██████████| 938/938 [01:58<00:00,  7.91it/s]


Epoch 49/50, Loss: 0.0579, Accuracy: 98.06%


100%|██████████| 157/157 [00:08<00:00, 18.97it/s]
100%|██████████| 938/938 [01:58<00:00,  7.91it/s]


Epoch 50/50, Loss: 0.0587, Accuracy: 98.03%


100%|██████████| 157/157 [00:06<00:00, 25.30it/s]
100%|██████████| 157/157 [00:08<00:00, 18.80it/s]


Final Test Accuracy: 97.28%


2024/10/21 16:52:58 INFO mlflow.tracking._tracking_service.client: 🏃 View run transformer_baseline at: https://dagshub.com/s.carlosj.28/moe_image_class.mlflow/#/experiments/5/runs/3a081420cbc640f996d3d491845d345f.
2024/10/21 16:52:58 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: https://dagshub.com/s.carlosj.28/moe_image_class.mlflow/#/experiments/5.


Training completed.
