In [2]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

# Custom Dataset Class
class MNISTDigitsDataset(Dataset):
    def __init__(self, root_dirs, transform=None):
        #self.image_dir = image_dir
        #self.image_files = [f for f in os.listdir(image_dir) if f.endswith('.png')]
        #self.transform = transform

        self.image_paths = []
        for root_dir in root_dirs:
            for file in os.listdir(root_dir):
                if file.endswith('.png'):
                    self.image_paths.append(os.path.join(root_dir, file))
        self.transform = transform


    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        #image_path = os.path.join(self.image_dir, image_name)
        


        # Load the image
        image = Image.open(image_path).convert('L')  # Convert to grayscale

        image_name = os.path.basename(image_path)
        # Extract label from the filename (e.g., '0428.png' -> [0, 4, 2, 8])
        label = [int(digit) for digit in image_name.split('.')[0]]
        

        #added
        label_num = torch.sum(torch.tensor(label)).item()  # Sum the elements and convert to a Python scalar


        # Convert label to one-hot encoding (4 digits, 10 classes per digit)
        one_hot_label = torch.zeros(40, dtype=torch.long)  # 10 classes * 4 digits = 40
        for i, digit in enumerate(label):
            one_hot_label[i * 10 + digit] = 1  # Set the corresponding class to 1
        
        # Apply transformations (if any)
        if self.transform:
            image = self.transform(image)
        
        return image, one_hot_label ,label_num ,image_name

# Define image transformations
transform = transforms.Compose([
#    transforms.RandomAffine(degrees=15, translate=(0.1, 0.1)),  # Random rotation ±15° and shifts up to 10%
    transforms.Resize((40, 168)),  # Resize image to the correct size
    transforms.ToTensor(),         # Convert image to Tensor
    transforms.Normalize((0.5,), (0.5,))  # Normalize (for grayscale images)
])

# Create Dataset and DataLoader
#image_dir = r"./exterim/images"  # Update with your path
root_dirs = [
    r"./exterim/images",  # Update with your paths
    r"./exterim/images2"  # Add additional directories here
]

dataset = MNISTDigitsDataset(root_dirs=root_dirs, transform=transform)

from torch.utils.data import random_split
print("length of dataset",len(dataset))
train_size = int(0.9 * len(dataset))
test_size = len(dataset) - train_size


# DataLoader for batching
batch_size = 16

train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

# DataLoader for training and test sets
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Example usage
for images, labels ,label_num,image_name in train_loader:
    print(images.shape)  # Should print torch.Size([32, 1, 40, 168])
    print(labels.shape)  # Should print torch.Size([32, 40]) for one-hot encoded labels

    print(labels[0])
    print(label_num[0])
    print(image_name[0])
    # Display the first image in the batch
    # plt.imshow(images[0].squeeze(0), cmap='gray')  # Remove the channel dimension for display
    # plt.show()

    break


length of dataset 2775
torch.Size([16, 1, 40, 168])
torch.Size([16, 40])
tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0])
tensor(21)
5772.png


In [3]:
#class 1


import torch
import torch.nn as nn
import torch.nn.functional as F

class MNISTDigitModel(nn.Module):
    def __init__(self, num_blocks, kernel_size, activation, pool, dropout):
        super(MNISTDigitModel, self).__init__()
        self.num_blocks = num_blocks
        self.kernel_size = kernel_size
        self.activation = activation
        self.pool = pool
        self.dropout = dropout
        
        layers = []
        in_channels = 1  # Grayscale input images
        out_channels = 64  # Initial number of filters
        
        # Add convolutional blocks
        for _ in range(num_blocks):
            layers.append(nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding='same'),
                self._get_activation(activation),
                nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, padding='same'),
                self._get_activation(activation),
                self._get_pool(pool),
                nn.Dropout(dropout)
            ))
            in_channels = out_channels
            out_channels *= 2  # Double the filters after each block
            
        
        self.conv_blocks = nn.Sequential(*layers)
        
        # Dummy input to calculate the flattened size
        with torch.no_grad():
            dummy_input = torch.zeros(1, 1, 40, 168)
            flattened_size = self.conv_blocks(dummy_input).numel()
        
        # Fully connected layers
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flattened_size, 512),
            self._get_activation(activation),
            nn.Dropout(dropout),
            nn.Linear(512, 40)  # 40 output classes (10 per digit for 4 digits)
        )
    
    def _get_activation(self, activation):
        if activation == 'relu':
            return nn.ReLU()
        elif activation == 'sigmoid':
            return nn.Sigmoid()
        else:
            raise ValueError("Activation not supported")
    
    def _get_pool(self, pool):
        if pool == 'max':
            return nn.MaxPool2d(2)
        elif pool == 'avg':
            return nn.AvgPool2d(2)
        else:
            raise ValueError("Pooling method not supported")
    
    def forward(self, x):
        x = self.conv_blocks(x)
        x = self.fc(x)

        x = x.view(-1, 4, 10)
        return x


In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MNISTSumModel(nn.Module):
    def __init__(self):
        super(MNISTSumModel, self).__init__()
        
        # MLP layers (flatten the 1x4x10 output)
        self.fc1 = nn.Linear(40, 64)  # 40 input features (4 digits * 10 classes)
        self.fc22 = nn.Linear(64, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 1)  # Output a single scalar (the sum of digits)

    def forward(self, x):
        # Apply softmax to each 10-length vector (per digit)
        x = x.float()
        x = F.softmax(x, dim=-1)
        
        # Flatten the input (4 digits * 10 classes)
        x = x.view(-1, 40)  # Shape becomes (batch_size, 40)

        # MLP layers
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc22(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)  # Single scalar output
        return x


In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F


# Instantiate the models
new_dropout = 0.1
digit_model = MNISTDigitModel(num_blocks=5, kernel_size=3, activation='relu', pool='max', dropout=new_dropout)
sum_model = MNISTSumModel()


checkpoint_dir = './checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)
latest_checkpoint_path = os.path.join(checkpoint_dir, 'checkpoint_epoch_621.pth')
start_epoch = 0
if os.path.exists(latest_checkpoint_path):
    print(f"Loading checkpoint from {latest_checkpoint_path}...")
    checkpoint = torch.load(latest_checkpoint_path)
    digit_model.load_state_dict(checkpoint['model_state_dict'])
    start_epoch = checkpoint['epoch']
    print(f"Resuming training from epoch {start_epoch}...")

    # for param in digit_model.parameters():
    #     param.requires_grad = False


sum_model_path = "checkpoints/decoder.pth"
if os.path.exists(sum_model_path):
    print(f"Loading pretrained MNISTSumModel from {sum_model_path}...")
    sum_model.load_state_dict(torch.load(sum_model_path))

    # for param in sum_model.parameters():
    #     param.requires_grad = False

Loading checkpoint from ./checkpoints\checkpoint_epoch_621.pth...
Resuming training from epoch 621...
Loading pretrained MNISTSumModel from checkpoints/decoder.pth...


  checkpoint = torch.load(latest_checkpoint_path)
  sum_model.load_state_dict(torch.load(sum_model_path))


In [7]:
# Set models to evaluation mode
digit_model.eval()
sum_model.eval()

# Lists to store results
all_predictions = []
all_actuals = []

# Disable gradient calculations for inference
with torch.no_grad():
    for batch_idx, (images, _, labels_sum, image_names) in enumerate(test_loader):
        # Step 1: Get output from digit model
        # Assuming digit_output has shape [batch_size, 4, 10]
        digit_output = digit_model(images)  # Shape: [batch_size, 4, 10]
        digit_output = digit_output.unsqueeze(1)  # Shape: [batch_size, 1, 4, 10]
        # print(digit_output.shape)

        # Prepare a list to hold the one-hot labels for each batch element
        one_hot_labels = []
        # Iterate through all batch elements
        for batch_idx in range(digit_output.shape[0]):
            predictions = torch.argmax(digit_output[batch_idx], dim=2)
            one_hot_label = torch.zeros(40, dtype=torch.long)
            for i, digit in enumerate(predictions[0]):
                one_hot_label[i * 10 + digit] = 1
            one_hot_label = one_hot_label.view(-1, 4, 10)
            one_hot_labels.append(one_hot_label)

        # Stack the one-hot labels into a batch of size [batch_size, 1, 4, 10]
        one_hot_labels = torch.stack(one_hot_labels, dim=0)
        print(one_hot_labels.shape)


        # Step 2: Feed the one-hot labels to the sum model
        sum_output = sum_model(one_hot_labels) 

        # Get predicted digit probabilities
        digit_probs = F.softmax(digit_output, dim=-1)  # Shape: [batch_size, 4, 10]
        
        # Get predicted digits (indices of maximum probability for each position)
        predicted_digits = torch.argmax(digit_probs, dim=-1)  # Shape: [batch_size, 4]
        
        # Get predicted sums
        predicted_sums = sum_output.squeeze()
        actual_sums = labels_sum.float()
        
        # Store predictions and actuals
        # For example, to append the sum of predictions for the batch
        all_predictions.append(predicted_sums.sum().item())  # Sum and convert to scalar
        all_actuals.append(actual_sums.sum().item())        # Sum and convert to scalar

        
        # Print results for this batch
        for i in range(len(images)):
            print(f"\nImage: {image_names[i]}")
            print(f"Predicted sum: {predicted_sums[i]:.2f}")
            print(f"Actual sum: {actual_sums[i]:.2f}")
            
        # Optional: break after first batch to see sample results
        break


torch.Size([16, 1, 4, 10])

Image: 6296.png
Predicted sum: 23.00
Actual sum: 23.00

Image: 2513.png
Predicted sum: 11.01
Actual sum: 11.00

Image: 1215.png
Predicted sum: 9.00
Actual sum: 9.00

Image: 6008.png
Predicted sum: 14.00
Actual sum: 14.00

Image: 5990.png
Predicted sum: 23.00
Actual sum: 23.00

Image: 8838.png
Predicted sum: 27.00
Actual sum: 27.00

Image: 5907.png
Predicted sum: 21.01
Actual sum: 21.00

Image: 3802.png
Predicted sum: 13.00
Actual sum: 13.00

Image: 3644.png
Predicted sum: 17.00
Actual sum: 17.00

Image: 9810.png
Predicted sum: 18.00
Actual sum: 18.00

Image: 9416.png
Predicted sum: 20.00
Actual sum: 20.00

Image: 6429.png
Predicted sum: 21.00
Actual sum: 21.00

Image: 2578.png
Predicted sum: 22.01
Actual sum: 22.00

Image: 6980.png
Predicted sum: 23.00
Actual sum: 23.00

Image: 3140.png
Predicted sum: 8.00
Actual sum: 8.00

Image: 7862.png
Predicted sum: 23.01
Actual sum: 23.00
