# Lab #4

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
import torchvision.transforms.functional as TF
from IPython.display import clear_output
import matplotlib.pyplot as plt
from picamera2 import Picamera2
from PIL import Image
import cv2


# Open the Raspberry Pi camera and configure it for 300x300 still captures.
camera = Picamera2()
config = camera.create_still_configuration(
    main={"size": (300, 300)},
)
camera.configure(config)
# Wipe whatever text the camera initialisation printed into this cell.
clear_output()

### Train a Machine Learning model

- We load the MNIST dataset of 28×28 grayscale images of handwritten digits (0–9). It comes pre-split into a training set (60,000 images) and a test set (10,000 images). Each image is converted to a tensor with pixel values in [0, 1] and then normalized to [-1, 1] using mean 0.5 and standard deviation 0.5.

In [None]:
# ToTensor scales pixels to [0, 1]; Normalize then maps them to [-1, 1]
# via (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)


def _mnist_split(train, shuffle):
    """Return (dataset, loader) for one MNIST split, downloading into data/ if needed."""
    split = datasets.MNIST('data/', download=True, train=train, transform=transform)
    return split, torch.utils.data.DataLoader(split, batch_size=64, shuffle=shuffle)


# Training batches are reshuffled every epoch; test batches keep a fixed order.
trainset, trainloader = _mnist_split(train=True, shuffle=True)
testset, testloader = _mnist_split(train=False, shuffle=False)


- Let's print some examples from the dataset.

In [None]:
# Grab the first batch of (unshuffled) test images to preview.
example_data, example_targets = next(iter(testloader))

fig = plt.figure()
for i in range(6):
    plt.subplot(2, 3, i + 1)
    # example_data[i][0] drops the single channel so imshow receives a 2-D array.
    plt.imshow(example_data[i][0], cmap='gray', interpolation='none')
    plt.title(f"Ground Truth: {example_targets[i].item()}")
    plt.xticks([])
    plt.yticks([])
# Lay out once, after every subplot exists, instead of on each iteration.
plt.tight_layout()
plt.show()

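- Here, we define a simple fully connected (dense) neural network for digit classification. Each 28×28 image is flattened into 784 values and passed through layers of size 128, 64 and 10, giving one output per digit. Training repeatedly feeds batches of images through the network and computes the loss against the true labels. It then updates the weights with stochastic gradient descent.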

In [None]:
class Net(nn.Module):
    """Three-layer fully connected digit classifier.

    Each input sample is flattened to a 784-element vector and mapped through
    784 -> 128 -> 64 -> 10 linear layers with ReLU activations in between.
    ``forward`` returns per-class log-probabilities (``log_softmax`` over
    dim 1), which is the input format ``nn.NLLLoss`` expects.
    """

    def __init__(self):
        super().__init__()
        # The attribute names fc1/fc2/fc3 become the state_dict keys, so they
        # must stay stable for saved weights (model.pth) to load.
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        batch_size = x.size(0)
        flat = x.view(batch_size, -1)
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        logits = self.fc3(hidden)
        return F.log_softmax(logits, dim=1)

model = Net()
# Net.forward returns log_softmax outputs, so NLLLoss yields the cross-entropy loss.
criterion = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.003)

# Training loop
epochs = 5
losses = []  # mean training loss of each epoch, in order
print('Training started. Please wait...')
model.train()
for e in range(epochs):
    running_loss = 0.0
    for images, labels in trainloader:
        optimizer.zero_grad()
        output = model(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # Average over all batches once the epoch has finished.
    epoch_loss = running_loss / len(trainloader)
    print(f"Epoch: [{e+1}/{epochs}] | Training loss: {epoch_loss}")
    losses.append(epoch_loss)

# Plot the training loss; x values are 1-based to match the printed epoch numbers.
plt.plot(range(1, epochs + 1), losses, marker='o')
plt.xlabel('Epoch')
plt.ylabel('Training Loss')
plt.title('Training Loss Over Epochs')
plt.show()

# Save the model
torch.save(model.state_dict(), 'model.pth')

- After training, we evaluate the model on the test dataset by computing its accuracy. Accuracy is the percentage of test images whose predicted digit matches the true label.

In [None]:
# Inference mode. Net has no dropout/batch-norm, so this changes nothing today,
# but it keeps evaluation correct if such layers are added later.
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in testloader:
        outputs = model(images)
        predicted = outputs.argmax(dim=1)  # index of the highest log-probability
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# True division keeps fractional percentages instead of truncating them.
accuracy = 100 * correct / total
print(f'Accuracy of the network on the {total} test images: {accuracy:.2f}%')

- This section demonstrates a practical application of the trained model. We take saved images of handwritten digits and preprocess them to match the model's input format: grayscale, inverted so the digit is light on a dark background, thresholded, and resized to 28×28. We then use the model to predict each digit.

In [None]:
def import_image(image_path, invert=True, contrast=1.5, threshold=80, border=2):
    """Load a photo of a handwritten digit and convert it to a 28x28 model input.

    Pipeline: grayscale -> contrast gain -> optional inversion -> 9x9 Gaussian
    blur -> binary threshold -> area resize to 28x28 -> zero a frame of
    ``border`` pixels around the edge.

    Args:
        image_path: path of the image file to read with OpenCV (str or path-like).
        invert: if True, invert intensities so dark ink on white paper becomes a
            light digit on a dark background (the MNIST convention).
        contrast: gain passed as ``alpha`` to ``cv2.convertScaleAbs``.
        threshold: after blurring, pixels above this value become 255, the rest 0.
        border: width in pixels of the edge frame set to 0 after resizing
            (presumably to suppress paper edges/shadows near the frame).
            0 disables it.

    Returns:
        A 28x28 uint8 NumPy array.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
    """
    image = cv2.imread(str(image_path))
    if image is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    adjusted = cv2.convertScaleAbs(gray, alpha=contrast, beta=0)
    if invert:
        adjusted = cv2.bitwise_not(adjusted)
    blurred = cv2.GaussianBlur(adjusted, (9, 9), 0)
    _, binary = cv2.threshold(blurred, threshold, 255, cv2.THRESH_BINARY)
    resized = cv2.resize(binary, (28, 28), interpolation=cv2.INTER_AREA)
    # Guard against border == 0: a slice like [-0:] would cover the whole image.
    if border > 0:
        resized[:border, :] = 0
        resized[-border:, :] = 0
        resized[:, :border] = 0
        resized[:, -border:] = 0
    return resized


def predict_digit(image, net=None):
    """Classify one preprocessed digit image with the trained network.

    Args:
        image: an image accepted by ``TF.to_tensor``, e.g. the 28x28 uint8
            array produced by ``import_image``.
        net: the model to run; defaults to the notebook's global ``model``.

    Returns:
        The predicted digit as a Python int.
    """
    if net is None:
        net = model
    tensor = TF.to_tensor(image)
    # Same normalization as the training transform: (x - 0.5) / 0.5.
    tensor = TF.normalize(tensor, (0.5,), (0.5,))
    batch = tensor.unsqueeze(0)  # Add batch dimension
    # Pure inference: skip building an autograd graph.
    with torch.no_grad():
        output = net(batch)
    return output.argmax(dim=1).item()


# Example usage: preprocess and classify two saved sample photos.
sample_paths = ('data/test.png', 'data/test2.jpg')
results = []
for path in sample_paths:
    sample = import_image(path, invert=True)
    results.append((sample, predict_digit(sample)))
(image1, predicted_digit), (image2, predicted_digit2) = results

# Show both preprocessed inputs side by side with their predictions.
fig = plt.figure(figsize=(10, 5))
for position, (panel_image, panel_digit) in enumerate(results, start=1):
    plt.subplot(1, 2, position)
    plt.imshow(panel_image, cmap='gray')
    plt.title(f'Predicted: {panel_digit}')
    plt.axis('off')
plt.show()


## Experiment

1. Handwrite 3 different digits on a sheet of white paper.
2. Take a picture of each digit with the camera.
3. Run the model on each picture.
4. Record the actual and predicted digits in the results table below.
5. Calculate the accuracy (number of correct predictions ÷ 3).

### Capture Image (Re-run this section to capture new images)

- Let's start by capturing an image with the camera and saving it.

In [None]:
import time

# Start the camera, wait briefly so the first frame is not taken before
# exposure settles, then save one still. stop() (not close()) keeps the
# camera configured so this cell can be re-run to capture new images.
camera.start()
try:
    time.sleep(1)
    camera.capture_file("data/capture.jpg")
finally:
    camera.stop()
clear_output()

# Load a detached copy of the capture so the file handle is closed right away.
with Image.open("data/capture.jpg") as captured:
    image = captured.copy()
plt.imshow(image)
plt.show()

### Test model

In [None]:
# Preprocess the freshly captured photo and classify it.
image = import_image('data/capture.jpg', invert=True)
predicted_digit = predict_digit(image)

# Display exactly what the network received (the preprocessed 28x28 input).
plt.imshow(image, cmap='gray', interpolation='none')
plt.title(f"Prediction: {predicted_digit}")
for hide_ticks in (plt.xticks, plt.yticks):
    hide_ticks([])
plt.show()

### RESULTS:
   
|       | **Actual** | **Predicted** |
|:-----:|:----------:|:-------------:|
| **1** |      _     |       _       |
| **2** |      _     |       _       |
| **3** |      _     |       _       |