In [1]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

# Custom Dataset Class
class MNISTDigitsDataset(Dataset):
    def __init__(self, root_dirs, transform=None):
        #self.image_dir = image_dir
        #self.image_files = [f for f in os.listdir(image_dir) if f.endswith('.png')]
        #self.transform = transform

        self.image_paths = []
        for root_dir in root_dirs:
            for file in os.listdir(root_dir):
                if file.endswith('.png'):
                    self.image_paths.append(os.path.join(root_dir, file))
        self.transform = transform


    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        #image_path = os.path.join(self.image_dir, image_name)
        


        # Load the image
        image = Image.open(image_path).convert('L')  # Convert to grayscale

        image_name = os.path.basename(image_path)
        # Extract label from the filename (e.g., '0428.png' -> [0, 4, 2, 8])
        label = [int(digit) for digit in image_name.split('.')[0]]
        
        # Convert label to one-hot encoding (4 digits, 10 classes per digit)
        one_hot_label = torch.zeros(40, dtype=torch.long)  # 10 classes * 4 digits = 40
        for i, digit in enumerate(label):
            one_hot_label[i * 10 + digit] = 1  # Set the corresponding class to 1
        
        # Apply transformations (if any)
        if self.transform:
            image = self.transform(image)
        
        return image, one_hot_label

# Define image transformations
transform = transforms.Compose([
#    transforms.RandomAffine(degrees=15, translate=(0.1, 0.1)),  # Random rotation ±15° and shifts up to 10%
    transforms.Resize((40, 168)),  # Resize image to the correct size
    transforms.ToTensor(),         # Convert image to Tensor
    transforms.Normalize((0.5,), (0.5,))  # Normalize (for grayscale images)
])

# Create Dataset and DataLoader
#image_dir = r"./exterim/images"  # Update with your path
root_dirs = [
    r"./exterim/images",  # Update with your paths
    r"./exterim/images2"  # Add additional directories here
]

dataset = MNISTDigitsDataset(root_dirs=root_dirs, transform=transform)

from torch.utils.data import random_split
print("length of dataset",len(dataset))
train_size = int(0.9 * len(dataset))
test_size = len(dataset) - train_size


# DataLoader for batching
batch_size = 16

train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

# DataLoader for training and test sets
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Example usage
for images, labels in train_loader:
    print(images.shape)  # Should print torch.Size([32, 1, 40, 168])
    print(labels.shape)  # Should print torch.Size([32, 40]) for one-hot encoded labels

    print(labels[0])
    # Display the first image in the batch
    # plt.imshow(images[0].squeeze(0), cmap='gray')  # Remove the channel dimension for display
    # plt.show()

    break



#####################################



import torch
import torch.nn as nn
import torch.nn.functional as F

class MNISTDigitModel(nn.Module):
    def __init__(self, num_blocks, kernel_size, activation, pool, dropout):
        super(MNISTDigitModel, self).__init__()
        self.num_blocks = num_blocks
        self.kernel_size = kernel_size
        self.activation = activation
        self.pool = pool
        self.dropout = dropout
        
        layers = []
        in_channels = 1  # Grayscale input images
        out_channels = 64  # Initial number of filters
        
        # Add convolutional blocks
        for _ in range(num_blocks):
            layers.append(nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding='same'),
                self._get_activation(activation),
                nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, padding='same'),
                self._get_activation(activation),
                self._get_pool(pool),
                nn.Dropout(dropout)
            ))
            in_channels = out_channels
            out_channels *= 2  # Double the filters after each block
        

            
        
        
        self.conv_blocks = nn.Sequential(*layers)
        
        # Dummy input to calculate the flattened size
        with torch.no_grad():
            dummy_input = torch.zeros(1, 1, 40, 168)
            flattened_size = self.conv_blocks(dummy_input).numel()
        
        # Fully connected layers
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flattened_size, 512),
            self._get_activation(activation),
            nn.Dropout(dropout),
            nn.Linear(512, 40)  # 40 output classes (10 per digit for 4 digits)
        )
    
    def _get_activation(self, activation):
        if activation == 'relu':
            return nn.ReLU()
        elif activation == 'sigmoid':
            return nn.Sigmoid()
        else:
            raise ValueError("Activation not supported")
    
    def _get_pool(self, pool):
        if pool == 'max':
            return nn.MaxPool2d(2)
        elif pool == 'avg':
            return nn.AvgPool2d(2)
        else:
            raise ValueError("Pooling method not supported")
    
    def forward(self, x):
        x = self.conv_blocks(x)
        x = self.fc(x)

        x = x.view(-1, 4, 10)
        return x

# Example usage
# model = MNISTDigitModel(num_blocks=1, kernel_size=3, activation='relu', pool='max', dropout=0.1)
# print(model)


########################################################




from torch import nn, optim


# Ensure the checkpoint directory exists
checkpoint_dir = './checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)


new_dropout = 0.1
model = MNISTDigitModel(num_blocks=5, kernel_size=3, activation='relu', pool='max', dropout=new_dropout)


# Load checkpoint if it exists
latest_checkpoint_path = os.path.join(checkpoint_dir, 'checkpoint_epoch_621.pth')
start_epoch = 0  # Default start epoch
if os.path.exists(latest_checkpoint_path):
    print(f"Loading checkpoint from {latest_checkpoint_path}...")
    checkpoint = torch.load(latest_checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']
    print(f"Resuming training from epoch {start_epoch}...")


new_learning_rate = 0.00001
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=new_learning_rate)






print("training started")
# Training Loop
num_epochs = 2  # Set the number of epochs
for epoch in range(start_epoch,start_epoch+num_epochs):
    model.train()  # Set the model to training mode
    running_loss = 0.0
    for images, labels in train_loader:       
        # Zero the parameter gradients
        optimizer.zero_grad()

        # Reshape labels to [batch_size, 4, 10] (assuming they are one-hot encoded)
        labels = labels.view(-1, 4, 10)  # Shape: [batch_size, 4, 10]

        # Forward pass
        outputs = model(images)  # Shape: [batch_size, 4, 10]

        # Compute loss for each digit independently
        loss = 0
        for i in range(4):  # Loop through each digit
            digit_labels = labels[:, i, :].argmax(dim=1)  # Convert one-hot to class index
            digit_outputs = outputs[:, i, :]

            # print("-- ",digit_labels)

            # CrossEntropyLoss expects [batch_size, num_classes] for inputs and [batch_size] for labels
            loss += criterion(digit_outputs, digit_labels)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # Accumulate loss
        running_loss += loss.item()

    # Print loss for every epoch
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')


    checkpoint = {
        'epoch': epoch + 1,  # Save the next epoch number
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }

    # Save the checkpoint
    if epoch%10==0:
        checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pth')
        torch.save(checkpoint, checkpoint_path)
        print(f"Checkpoint saved to {checkpoint_path}")


        
    with torch.no_grad():  # Disable gradient computation for evaluation
        for images, labels in test_loader:
            # Forward pass
            outputs = model(images)  # Shape: [batch_size, 4, 10]
            # print(images.shape)
            # print(outputs.shape)
            
            labels = labels.view(-1, 4, 10)
            print(labels.shape) 
            original_label = torch.argmax(labels, dim=2) 
            print("Original :", original_label[:3])

            print("outut ",outputs[0].shape)
            predictions = torch.argmax(outputs, dim=2)  # Shape: [4] with the predicted digit for each group
            print("Predictions: ",predictions[:3])

            break
    
print("hi")

# Update the latest checkpoint pointer
latest_checkpoint_path = os.path.join(checkpoint_dir, 'latest_checkpoint.pth')
torch.save(checkpoint, latest_checkpoint_path)
print(f"Latest checkpoint saved to {latest_checkpoint_path}")






length of dataset 2775
torch.Size([16, 1, 40, 168])
torch.Size([16, 40])
tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0])
Loading checkpoint from ./checkpoints\checkpoint_epoch_621.pth...
Resuming training from epoch 621...
training started


  checkpoint = torch.load(latest_checkpoint_path)


Epoch [622/2], Loss: 0.0847
torch.Size([16, 4, 10])
Original : tensor([[6, 2, 9, 6],
        [1, 4, 7, 3],
        [9, 0, 9, 3]])
outut  torch.Size([4, 10])
Predictions:  tensor([[6, 2, 9, 6],
        [1, 9, 7, 3],
        [9, 0, 9, 3]])
Epoch [623/2], Loss: 0.0589
torch.Size([16, 4, 10])
Original : tensor([[6, 2, 9, 6],
        [1, 4, 7, 3],
        [9, 0, 9, 3]])
outut  torch.Size([4, 10])
Predictions:  tensor([[6, 6, 9, 6],
        [1, 4, 7, 3],
        [9, 0, 9, 3]])
hi
Latest checkpoint saved to ./checkpoints\latest_checkpoint.pth
