In [26]:

import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

# Custom Dataset Class
class MNISTDigitsDataset(Dataset):
    def __init__(self, root_dirs, transform=None):
        self.image_paths = []
        for root_dir in root_dirs:
            for file in os.listdir(root_dir):
                if file.endswith('.png'):
                    self.image_paths.append(os.path.join(root_dir, file))
        self.transform = transform


    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]

        # Load the image
        image_name = os.path.basename(image_path)
        # Extract label from the filename (e.g., '0428.png' -> [0, 4, 2, 8])
        label = [int(digit) for digit in image_name.split('.')[0]]

        label_num = torch.sum(torch.tensor(label)).item()  # Sum the elements and convert to a Python scalar


        # print("label ",label_num, "sum ",sum(label))
        # Convert label to one-hot encoding (4 digits, 10 classes per digit)
        one_hot_label = torch.zeros(40, dtype=torch.long)  # 10 classes * 4 digits = 40
        for i, digit in enumerate(label):
            one_hot_label[i * 10 + digit] = 1  # Set the corresponding class to 1


        return one_hot_label.view(-1,4,10) , label_num ,image_name


# Create Dataset and DataLoader
#image_dir = r"./exterim/images"  # Update with your path
root_dirs = [
    r"./exterim/images",  # Update with your paths
    r"./exterim/images2"  # Add additional directories here
]

dataset = MNISTDigitsDataset(root_dirs=root_dirs)

from torch.utils.data import random_split
print("length of dataset",len(dataset))
train_size = int(0.9 * len(dataset))
test_size = len(dataset) - train_size



# DataLoader for batching
batch_size = 16


train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

# DataLoader for training and test sets
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Example usage
for one_hot_label, label_num,label in train_loader:
    print(one_hot_label.shape)  # Should print torch.Size([32, 1, 40, 168])
    # print(label_num.shape)  # Should print torch.Size([32, 40]) for one-hot encoded labels

    print(one_hot_label[0])
    print(label_num[0])
    print(label[0])

    break


length of dataset 2775
torch.Size([16, 1, 4, 10])
tensor([[[0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
         [0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 1, 0, 0]]])
tensor(25)
5857.png


In [27]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MNISTSumModel(nn.Module):
    def __init__(self):
        super(MNISTSumModel, self).__init__()
        
        # MLP layers (flatten the 1x4x10 output)
        self.fc1 = nn.Linear(40, 64)  # 40 input features (4 digits * 10 classes)
        self.fc22 = nn.Linear(64, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 1)  # Output a single scalar (the sum of digits)

    def forward(self, x):
        # Apply softmax to each 10-length vector (per digit)
        x = x.float()
        x = F.softmax(x, dim=-1)
        
        # Flatten the input (4 digits * 10 classes)
        x = x.view(-1, 40)  # Shape becomes (batch_size, 40)

        # MLP layers
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc22(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)  # Single scalar output
        return x


In [28]:
import torch.optim as optim

# Instantiate the model
model = MNISTSumModel()

# Loss and optimizer
criterion = nn.MSELoss()  # Mean Squared Error Loss for regression
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
epochs = 25  # Set number of epochs
for epoch in range(epochs):
    model.train()  # Set model to training mode
    
    running_loss = 0.0
    for one_hot_label, label_num, _ in train_loader:
        optimizer.zero_grad()  # Zero the gradients
        
        # Forward pass
        output = model(one_hot_label)
        
        # Compute loss (compare output to label_num)
        loss = criterion(output.squeeze(), label_num.float())  # Squeeze to remove extra dimensions
        
        # Backward pass and optimize
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
    
    avg_loss = running_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}")


Epoch [1/25], Loss: 104.8147
Epoch [2/25], Loss: 26.4548
Epoch [3/25], Loss: 16.7540
Epoch [4/25], Loss: 2.5075
Epoch [5/25], Loss: 0.0616
Epoch [6/25], Loss: 0.0073
Epoch [7/25], Loss: 0.0036
Epoch [8/25], Loss: 0.0025
Epoch [9/25], Loss: 0.0018
Epoch [10/25], Loss: 0.0014
Epoch [11/25], Loss: 0.0011
Epoch [12/25], Loss: 0.0009
Epoch [13/25], Loss: 0.0007
Epoch [14/25], Loss: 0.0006
Epoch [15/25], Loss: 0.0005
Epoch [16/25], Loss: 0.0004
Epoch [17/25], Loss: 0.0003
Epoch [18/25], Loss: 0.0003
Epoch [19/25], Loss: 0.0003
Epoch [20/25], Loss: 0.0002
Epoch [21/25], Loss: 0.0002
Epoch [22/25], Loss: 0.0002
Epoch [23/25], Loss: 0.0002
Epoch [24/25], Loss: 0.0001
Epoch [25/25], Loss: 0.0001


In [29]:
torch.save(model.state_dict(), 'checkpoints/decoder.pth') 

In [30]:
model.eval()  # Set model to evaluation mode
test_loss = 0.0
predictions = []  # List to store predicted values
actuals = []  # List to store actual values

with torch.no_grad():
    for one_hot_label, label_num, _ in test_loader:
        output = model(one_hot_label)
        
        # Calculate loss
        loss = criterion(output.squeeze(), label_num.float())
        test_loss += loss.item()
        
        # Append the actual and predicted values
        predictions.extend(output.squeeze().cpu().numpy())  # Convert tensor to numpy for easier printing
        actuals.extend(label_num.cpu().numpy())  # Convert tensor to numpy for easier printing

# Calculate average test loss
avg_test_loss = test_loss / len(test_loader)
print(f"Test Loss: {avg_test_loss:.4f}")

# Print actual vs predicted values
print("Actual vs Predicted:")
for actual, predicted in zip(actuals, predictions):
    print(f"Actual: {actual}, Predicted: {predicted:.4f}")


Test Loss: 0.0002
Actual vs Predicted:
Actual: 20, Predicted: 20.0027
Actual: 23, Predicted: 23.0012
Actual: 7, Predicted: 7.0282
Actual: 19, Predicted: 19.0077
Actual: 19, Predicted: 19.0017
Actual: 27, Predicted: 27.0027
Actual: 24, Predicted: 24.0052
Actual: 5, Predicted: 5.0744
Actual: 13, Predicted: 13.0038
Actual: 14, Predicted: 14.0032
Actual: 29, Predicted: 29.0070
Actual: 21, Predicted: 21.0086
Actual: 16, Predicted: 16.0037
Actual: 18, Predicted: 18.0062
Actual: 13, Predicted: 13.0070
Actual: 16, Predicted: 16.0095
Actual: 21, Predicted: 20.9999
Actual: 9, Predicted: 9.0041
Actual: 17, Predicted: 17.0069
Actual: 19, Predicted: 19.0091
Actual: 9, Predicted: 9.0028
Actual: 14, Predicted: 14.0045
Actual: 6, Predicted: 5.9957
Actual: 18, Predicted: 18.0033
Actual: 18, Predicted: 18.0053
Actual: 17, Predicted: 17.0017
Actual: 24, Predicted: 24.0103
Actual: 11, Predicted: 11.0039
Actual: 9, Predicted: 9.0067
Actual: 15, Predicted: 15.0054
Actual: 20, Predicted: 20.0105
Actual: 17, 

In [31]:
# Instantiate the model again
model = MNISTSumModel()

# Load the saved model weights
model.load_state_dict(torch.load('checkpoints/decoder.pth'))
model.eval()  # Set the model to evaluation mode for inference
print("Model loaded from decoder.pth")

# Example inference loop
with torch.no_grad():
    for one_hot_label, label_num, image_name in test_loader:
        output = model(one_hot_label)
        print(f"Predicted: {output}, Actual: {label_num}")



Model loaded from decoder.pth
Predicted: tensor([[20.0027],
        [23.0012],
        [ 7.0282],
        [19.0077],
        [19.0017],
        [27.0027],
        [24.0052],
        [ 5.0744],
        [13.0038],
        [14.0032],
        [29.0070],
        [21.0086],
        [16.0037],
        [18.0062],
        [13.0070],
        [16.0095]]), Actual: tensor([20, 23,  7, 19, 19, 27, 24,  5, 13, 14, 29, 21, 16, 18, 13, 16])
Predicted: tensor([[20.9999],
        [ 9.0041],
        [17.0069],
        [19.0091],
        [ 9.0028],
        [14.0045],
        [ 5.9957],
        [18.0033],
        [18.0053],
        [17.0017],
        [24.0103],
        [11.0039],
        [ 9.0067],
        [15.0054],
        [20.0105],
        [17.0038]]), Actual: tensor([21,  9, 17, 19,  9, 14,  6, 18, 18, 17, 24, 11,  9, 15, 20, 17])
Predicted: tensor([[19.0004],
        [12.0053],
        [20.0016],
        [25.0098],
        [21.0061],
        [19.0007],
        [21.0105],
        [21.0027],
        [16

  model.load_state_dict(torch.load('checkpoints/decoder.pth'))
