<a href="https://colab.research.google.com/github/Olveir/Digit-generationApp/blob/main/training_script.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [10]:
# Train a model from scratch on MNIST and save it
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import os

# Set up data
transform = transforms.ToTensor()
trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

# Define a simple CNN
class DigitGenerator(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 256),
            nn.ReLU(),
            nn.Linear(256, 64)
        )
        self.decoder = nn.Sequential(
            nn.Linear(64 + 10, 256),
            nn.ReLU(),
            nn.Linear(256, 28 * 28),
            nn.Sigmoid()
        )

    def forward(self, x, labels):
        x = self.encoder(x.view(-1, 28 * 28))
        x = torch.cat([x, labels], dim=1)
        x = self.decoder(x)
        return x.view(-1, 1, 28, 28)

model = DigitGenerator()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# One-hot encoding for labels
def one_hot(labels, num_classes=10):
    return torch.eye(num_classes)[labels]

# Train
for epoch in range(5):
    for images, labels in trainloader:
        labels_onehot = one_hot(labels)
        outputs = model(images, labels_onehot)
        loss = criterion(outputs, images)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")

# Save model
os.makedirs("models", exist_ok=True)
torch.save(model.state_dict(), "models/mnist_generator.pth")


100%|██████████| 9.91M/9.91M [00:00<00:00, 34.7MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.19MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 9.89MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 4.38MB/s]


Epoch 1, Loss: 0.0132
Epoch 2, Loss: 0.0075
Epoch 3, Loss: 0.0060
Epoch 4, Loss: 0.0055
Epoch 5, Loss: 0.0045
