# Introduction :

FITS files are a common format for storing astronomical data. They are used by many telescopes and observatories, and the format is flexible enough to store many different types of data.
We have a dataset of FITS images from a telescope, each showing a timer. Our aim in this notebook is to recognize the timer's digits using a CNN.

## Datasets :

For our 1st model we will train on the well-known [MNIST](http://yann.lecun.com/exdb/mnist/) dataset as a first try, then test it on the timer digits.

For the 2nd model we will train on a labeled dataset of digits cropped from the telescope camera images themselves. This data comes from a private source, so we can't share it; only a segment of it is shared in this notebook.

## Content :

1. [Importing Libraries](#1)
2. [Loading and Preprocessing Data](#2)
3. [MNIST Model](#3)
4. [Custom Model](#4)
5. [Conclusion](#5)
 

# 1. Importing Libraries

In [None]:

# FITS manipulation
from astropy.io import fits
from astropy.visualization import simple_norm

# image processing
from PIL import Image
from PIL import ImageEnhance
from PIL import ImageFilter
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os

# Modeling
import torch

from torchvision.datasets import MNIST
from torchvision import transforms

from torch.utils.data import Dataset
from torch.utils.data import DataLoader

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# Evaluation

from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report
from sklearn.metrics import accuracy_score
from sklearn.metrics import roc_curve, auc
from sklearn.metrics import roc_auc_score
import seaborn as sns

# 2. Loading and Preprocessing Data

### 2.1 - Processing the Timer Images

In [None]:
# Define the path to your FITS file
fits_file_path = 'fits_file/selected_images.fits'

# Open the FITS file
hdulist = fits.open(fits_file_path)

A FITS file is made of ```HDUs (Header Data Units)```. Each HDU can contain a table or an image; in our case they all contain images. Each HDU has a header and a data part: the header holds metadata such as the date, the telescope used and the exposure time, while the data part holds the image itself.
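To make the HDU structure concrete, here is a minimal sketch that builds a small in-memory HDU list with astropy instead of reading the private file. The arrays and `DATE` values are made up for illustration; only `fits.PrimaryHDU`, `fits.ImageHDU` and `fits.HDUList` come from astropy. The list is named `demo_hdus` so it does not overwrite the notebook's `hdulist`.

```python
import numpy as np
from astropy.io import fits

# Two small synthetic frames standing in for telescope images (hypothetical data)
frame0 = np.zeros((100, 300), dtype=np.float32)
frame1 = np.ones((100, 300), dtype=np.float32)

# Each HDU pairs a header (metadata keywords) with a data array
primary = fits.PrimaryHDU(data=frame0)
primary.header['DATE'] = '2023-01-01T00:00:00.000'  # hypothetical timestamp
extension = fits.ImageHDU(data=frame1)
extension.header['DATE'] = '2023-01-01T00:00:00.250'  # hypothetical timestamp

demo_hdus = fits.HDUList([primary, extension])

for i, hdu in enumerate(demo_hdus):
    # header -> metadata, data -> pixel array
    print(i, hdu.header['DATE'], hdu.data.shape)
```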

In [None]:
hdus = [hdulist[i].data for i in range(len(hdulist))]


plt.imshow(hdus[0], cmap='gray', norm=simple_norm(hdus[0], 'sqrt', percent=99.5))
plt.show()

**Our aim now is to extract the time from the images. The 1st step is to crop each image so that it only includes the timer:**

In [None]:
for i in hdus[0:3]:
    roi = i[-80:-10, 30:270]
    plt.imshow(roi, cmap='gray', norm=simple_norm(i, 'sqrt', percent=99.5))
    plt.show()
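Because the slice uses negative row indices, the crop is anchored to the bottom edge of the frame: it always yields a 70 x 240 region whatever the frame height, as long as the timer sits at the same distance from the bottom. A quick check on a synthetic array (the 200 x 400 size is an assumption, not the real camera resolution):

```python
import numpy as np

# A hypothetical 200 x 400 frame whose pixel values encode their own position
demo_frame = np.arange(200 * 400).reshape(200, 400)

# Same slice as above: rows from 80 to 10 pixels above the bottom edge, columns 30..269
demo_roi = demo_frame[-80:-10, 30:270]

print(demo_roi.shape)  # (70, 240)
```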

**The digits look a bit overlapped, so let's create a function that separates them by inserting black columns between them:**

In [None]:
def replace_columns_with_zeros(img):
    # Get the pixel data as a NumPy array
    img_data = np.array(img)

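    # Zero each specified column and insert black columns before it.
    # The indices go from highest to lowest, so each insertion does not shift
    # the positions that are still to be processed.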
    indices = [225,206,197,187,172,160,149,142,132,120,114,103,86,75,64,55,46,37,27,18]
    for index in indices :
        if index == 172 or index == 86:
            for _ in range(15):
                img_data[:,index] = 0
                img_data = np.insert(img_data, index, 0, axis=1)
        else:
            for _ in range(7):
                img_data[:,index] = 0
                img_data = np.insert(img_data, index, 0, axis=1)


    # Convert the NumPy array back to an image
    img = Image.fromarray(img_data)

    return img


# now let's test the function

roi = hdus[0][-80:-10, 30:270]
# normalize the display with the same frame the ROI comes from
plt.imshow(roi, cmap='gray', norm=simple_norm(hdus[0], 'sqrt', percent=99.5))
plt.imsave('roi.png', roi, cmap='gray')
plt.title('Original')
plt.show()
roi = replace_columns_with_zeros(roi)
plt.imshow(roi, cmap='gray', norm=simple_norm(hdus[0], 'sqrt', percent=99.5))
plt.title('With Gaps')
plt.show()
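The reason the column indices in `replace_columns_with_zeros` run from highest to lowest: `np.insert` shifts every column to the right of the insertion point, so inserting on the left first would move the later targets. A simplified sketch with a hypothetical helper `insert_gaps` (unlike the original, it does not also zero the column at each index):

```python
import numpy as np

def insert_gaps(img, indices, width):
    """Hypothetical helper: insert `width` black columns before each column in `indices`."""
    out = np.asarray(img).copy()
    # Work right to left so that each insertion leaves the remaining (smaller) indices valid
    for index in sorted(indices, reverse=True):
        out = np.insert(out, [index] * width, 0, axis=1)
    return out

row = np.array([[1, 2, 3, 4]])
print(insert_gaps(row, [1, 3], width=2))  # [[1 0 0 2 3 0 0 4]]
```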

Now we're going to create a function that does general preprocessing on the images. It does the following:

1. Enhance the contrast of the image: in our case the time is the white part of the image, so we make it whiter and the background darker.

2. Convert the image to grayscale, so that we can use methods that only work on grayscale images.

3. ```Thresholding```: every pixel with an intensity below a threshold value (75 here) is set to 0. Binary thresholding would also set every remaining pixel to 255; here we keep the original intensity of the remaining pixels (threshold-to-zero).

4. ```Sharpening```: the images are a bit blurry, so we apply an unsharp mask to make the digit edges crisper.
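The two thresholding variants differ only in what happens to the bright pixels. A NumPy sketch on a few made-up intensities, using the same threshold of 75 as `process_image`:

```python
import numpy as np

pixels = np.array([10, 74, 75, 120, 250], dtype=np.uint8)  # made-up intensities
threshold = 75

# Binary thresholding: everything becomes either 0 or 255
binary = np.where(pixels < threshold, 0, 255).astype(np.uint8)

# Threshold-to-zero (what process_image does): dark pixels become 0, bright ones keep their value
to_zero = np.where(pixels < threshold, 0, pixels).astype(np.uint8)

print(binary)   # [  0   0 255 255 255]
print(to_zero)  # [  0   0  75 120 250]
```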

In [None]:
def process_image(img):
    # Convert the ROI array to a PIL Image (for easy handling)
    roi_image = Image.fromarray(img)

    # Attach a DPI value (metadata only; it does not change the pixels)
    roi_image.info['dpi'] = (500, 500)

    # Separate the digits by inserting black columns between them
    roi_image = replace_columns_with_zeros(roi_image)

    # Enhance the contrast
    enhancer = ImageEnhance.Contrast(roi_image)
    roi_image = enhancer.enhance(2)

    # Convert the image to 8-bit grayscale
    roi_image = roi_image.convert('L')

    # Threshold-to-zero: pixels darker than the threshold become black, the others are kept
    threshold_value = 75
    roi_image = roi_image.point(lambda x: 0 if x < threshold_value else x)

    # Sharpen the digit edges with an unsharp mask
    roi_image = roi_image.filter(ImageFilter.UnsharpMask(radius=2, percent=150, threshold=1))

    return roi_image


# Now let's test the function
roi = hdus[0][-80:-10, 30:270]
process_image(roi)
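The last step of `process_image`, `ImageFilter.UnsharpMask`, sharpens rather than truly deblurs: it blurs a copy of the image and adds back the difference between the original and the blurred copy, which darkens the dark side of an edge and brightens the bright side. A 1-D NumPy sketch of that idea, with a 3-pixel box blur standing in for PIL's Gaussian blur, `amount=1.5` playing the role of `percent=150`, and no equivalent of PIL's `threshold` parameter:

```python
import numpy as np

def unsharp_1d(signal, amount=1.5):
    """Sharpen a 1-D signal: add back `amount` times the detail removed by a blur."""
    s = np.asarray(signal, dtype=float)
    padded = np.pad(s, 1, mode='edge')
    # 3-pixel box blur (PIL's UnsharpMask uses a Gaussian blur instead)
    blurred = (padded[:-2] + padded[1:-1] + padded[2:]) / 3
    # Keep the result in the 8-bit intensity range
    return np.clip(s + amount * (s - blurred), 0, 255)

edge = [50, 50, 50, 150, 150, 150]  # a soft step from dark to bright
print(unsharp_1d(edge))  # approximately [50, 50, 0, 200, 150, 150]
```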

Even with this preprocessing, the background with all those stars is still a problem, so let's try to remove it. To do so, we write a function ```custom_process``` that deals with the major remaining issues in the images.

It also removes the noise that results from the gaps we inserted between the digits: small groups of white pixels that don't belong to any digit. For every mid-intensity pixel (between 55 and 100), the function counts how many pixels in the surrounding 9x7 window are brighter than it by more than 55; if there are fewer than 8, the pixel is treated as isolated noise and that window is set to black.

**Here's what the function looks like:**

In [None]:
def custom_process(image):
    
    
    # Convert the image to a NumPy array
    image_data = np.array(image)
    
    # Create a copy of the image data to modify
    modified_data = np.copy(image_data)

    # Only pixels with an intensity strictly between threshold_value and 100 are candidates for removal
    threshold_value = 55

    # Half-sizes (rows, columns) of the surrounding window: a 9 x 7 window
    matrix_size = (4, 3)

    # Half-size of a second, square 7 x 7 window (it lies inside the 9 x 7 one)
    matrix_size_3x3 = 3

    # A candidate pixel is kept only if at least this many pixels in its window
    # are brighter than it by more than 55
    min_pixels_above_threshold = 8

    # Iterate over every pixel far enough from the border for the window to fit
    for i in range(matrix_size[0], image_data.shape[0] - matrix_size[0]):
        for j in range(matrix_size[1], image_data.shape[1] - matrix_size[1]):

            if image_data[i, j] > threshold_value and image_data[i, j] < 100:

                surrounding_matrix = image_data[i - matrix_size[0]:i + matrix_size[0] + 1, j - matrix_size[1]:j + matrix_size[1] + 1]

                # Too few much-brighter neighbours: treat the pixel as isolated noise
                # and black out its neighbourhood in the output copy
                if np.sum(surrounding_matrix > image_data[i, j] + 55) < min_pixels_above_threshold:

                    modified_data[i - matrix_size[0]:i + matrix_size[0] + 1, j - matrix_size[1]:j + matrix_size[1] + 1] = 0
                    modified_data[i - matrix_size_3x3:i + matrix_size_3x3 + 1, j - matrix_size_3x3:j + matrix_size_3x3 + 1] = 0

    
    modified_image = Image.fromarray(modified_data)

    return modified_image
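`custom_process` judges each pixel by its local window. A more literal version of "remove the white groups smaller than a threshold" is connected-component filtering. The sketch below shows that alternative technique, not what the notebook uses; `remove_small_blobs` and its parameters are hypothetical. It groups 4-connected pixels brighter than `on_threshold` with a breadth-first search and blacks out every group with fewer than `min_size` pixels:

```python
from collections import deque

import numpy as np

def remove_small_blobs(image, min_size, on_threshold=0):
    """Hypothetical helper: zero out 4-connected groups of pixels brighter than
    `on_threshold` that contain fewer than `min_size` pixels. Returns a copy."""
    data = np.array(image)  # copies the input (also accepts a PIL image)
    mask = data > on_threshold
    seen = np.zeros(mask.shape, dtype=bool)
    height, width = mask.shape
    for r in range(height):
        for c in range(width):
            if not mask[r, c] or seen[r, c]:
                continue
            # Breadth-first search collecting one connected group
            group = []
            queue = deque([(r, c)])
            seen[r, c] = True
            while queue:
                y, x = queue.popleft()
                group.append((y, x))
                for dy, dx in ((1, 0), (-1, 0), (0, 1), (0, -1)):
                    ny, nx = y + dy, x + dx
                    if 0 <= ny < height and 0 <= nx < width and mask[ny, nx] and not seen[ny, nx]:
                        seen[ny, nx] = True
                        queue.append((ny, nx))
            # Small group: treat it as noise
            if len(group) < min_size:
                for y, x in group:
                    data[y, x] = 0
    return data

# Usage: a lone bright pixel is removed, a 3x3 block survives
demo = np.zeros((5, 5), dtype=np.uint8)
demo[0, 0] = 200
demo[2:5, 2:5] = 200
cleaned = remove_small_blobs(demo, min_size=4)
```

In practice, `scipy.ndimage.label` can do the labelling step instead of the hand-written search.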

**We can now create a loop that extracts the images from the ```fits``` file and applies the processing functions to them. Let's view the results:**

In [None]:
for i, hdu in enumerate(hdulist[:3]):
    if hdu.data is not None:
        
        # Extract the image data as a NumPy array
        image_data = hdu.data

        # Crop the image to the defined ROI (bottom left)
        roi = image_data[-80:-10, 30:270]
        
        # General preprocessing (gaps, contrast, grayscale, thresholding, sharpening)
        roi_image = process_image(roi)

        
        plt.imshow(roi_image,cmap='gray')
        plt.show()
        
        # Custom processing
        roi_image = custom_process(roi_image)
        
        # Save the image
        roi_image.save('roi_{}.png'.format(i))

### 2.2 - Labeling the Timer Images

The original FITS file is too large, so we use this part of it: its metadata contains the time each image was taken (the ```DATE``` header keyword), which we can use as labels for the digits cropped from the images. Let's extract the time from the metadata and assign it to the digit images.

In [None]:
# digit extractor function

def extract_digits(hdu, image, j, labels):

    pattern = hdu.header['DATE']

    # keep only the digits found in the last 8 characters of the DATE value
    digits = [int(char) for char in pattern[-8:] if char.isdigit()]
    # start-end column indices of the 3 digits we label
    indices = [[320, 336], [337, 353], [352, 368]]

    for i, index in enumerate(indices):
        # Crop the digit (columns index[0]..index[1], rows 18..54)
        digit_image = image.crop((index[0], 18, index[1], 54))

        # convert the image to a NumPy array
        digit_image = np.array(digit_image)

        # add 6 dark columns to the left and right of the image
        digit_image = np.pad(digit_image, ((0, 0), (6, 6)), 'constant', constant_values=(0, 0))

        # save the image into a folder named after its digit (disabled here)
        # plt.imsave('digits_data/{}/{}.png'.format(digits[i], len(labels[digits[i]]) + i), digit_image, cmap='gray')

        # record an id for this sample under its digit class
        labels[digits[i]].append(len(labels[digits[i]]) + i)

    return labels
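The labelling relies on the slice `pattern[-8:]`: only the last 8 characters of the `DATE` value are scanned, and separators such as `:` and `.` are skipped, so `digits` can hold fewer than 8 entries. A sketch with a made-up ISO-style timestamp (an assumption; the real header values are not shown in this notebook), using a hypothetical helper `date_digits`:

```python
def date_digits(date_value, tail=8):
    """Hypothetical helper: digits found in the last `tail` characters of a DATE string,
    using the same expression as extract_digits."""
    return [int(char) for char in date_value[-tail:] if char.isdigit()]

# Hypothetical DATE value
print(date_digits('2023-05-12T10:15:42.123'))  # [5, 4, 2, 1, 2, 3]
```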
        

In [None]:
all_labels = {0:[],1:[],2:[],3:[],4:[],5:[],6:[],7:[],8:[],9:[]}
for j, hdu in enumerate(hdulist[::2]):   # we only take the even-indexed HDUs; see the "Important Note" below
    roi = hdu.data[-80:-10, 30:270]
    roi_image = process_image(roi)
    roi_image = custom_process(roi_image)
    all_labels = extract_digits(hdu, roi_image, j, all_labels)



In [None]:
#printing the number of digits in each class
for i in range(10):
    print('Number of {}s: {}'.format(i, len(all_labels[i])))

**Now everything is set for the FITS extraction and preprocessing. Next we move to the MNIST dataset and create our first model.**

#### Important Note :
We took the even-indexed HDUs because the physicists who gave us the data used software that fills in the time in the metadata by adding the exposure time to the time of the previous image. This method isn't accurate because the interval between images isn't constant, so some of the DATE values in the headers we used were wrong, and we had to clean them after storing the digits in their class folders.

The good news is that the resulting folder ```new_folder``` contains the images with their correct labels, so you can use it directly without extracting the time from the metadata. If you want to work from the original FITS file instead, use the code above to extract the time from the metadata and label the images, then clean them as we did.

# 3. MNIST Model

### A - Building the Model

The MNIST dataset is a dataset of handwritten digits. It contains 60,000 training images and 10,000 test images, each 28x28 pixels. We will train our model on it, then test it on the timer images.

In [None]:
# Defining the CNN model
class CNNModel(nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(7*7*64, 128)
        self.fc2 = nn.Linear(128, 10)  

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = x.view(-1, 7*7*64)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
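The `7*7*64` in `fc1` and in `x.view` comes from tracking the spatial size through the network: each 3x3 convolution with `padding=1` keeps the size, and each 2x2 max pooling halves it. A plain-Python check using the standard output-size formula, floor((size + 2*padding - kernel) / stride) + 1 (the helper names are ours):

```python
def conv2d_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution along one axis (floor division, as in PyTorch)."""
    return (size + 2 * padding - kernel) // stride + 1

def maxpool_out(size, kernel=2):
    """Spatial output size of max pooling whose stride equals its kernel size."""
    return (size - kernel) // kernel + 1

side = 28  # MNIST images are 28x28
side = maxpool_out(conv2d_out(side, kernel=3, padding=1))  # conv1 + pool: 28 -> 28 -> 14
side = maxpool_out(conv2d_out(side, kernel=3, padding=1))  # conv2 + pool: 14 -> 14 -> 7
print(side, side * side * 64)  # 7 3136
```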



### B - Training the Model

In [None]:
# Define the training process
def train(model, train_loader, optimizer, criterion, device):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()


def evaluate(model, test_loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    accuracy = 100 * correct / total
    return accuracy


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = MNIST(root='./data', train=False, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)


model = CNNModel().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
num_epochs = 1  # only used by the commented-out loop below
# Uncomment this loop for a simpler output (the next cell trains and plots the accuracy instead):
# for epoch in range(num_epochs):
#     train(model, train_loader, optimizer, criterion, device)
#     accuracy = evaluate(model, test_loader, device)
#     print(f"Epoch {epoch+1}/{num_epochs}, Accuracy: {accuracy:.2f}%")


In [None]:
# Evaluation 

train_accuracy_list = []
test_accuracy_list = []


num_epochs = 3  # you can change it just to see the differences in the plot 
for epoch in range(num_epochs):
    train(model, train_loader, optimizer, criterion, device)
    train_accuracy = evaluate(model, train_loader, device)
    test_accuracy = evaluate(model, test_loader, device)
    train_accuracy_list.append(train_accuracy)
    test_accuracy_list.append(test_accuracy)
    print(f"Epoch {epoch+1}/{num_epochs}, Train Accuracy: {train_accuracy:.2f}%, Test Accuracy: {test_accuracy:.2f}%")


epochs = range(1, num_epochs+1)
plt.plot(epochs, train_accuracy_list, label='Train Accuracy')
plt.plot(epochs, test_accuracy_list, label='Test Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.title('Model Accuracy')
plt.legend()
plt.show()

In [None]:
model.eval()  # inference mode
k = 0
for i, hdu in enumerate(hdulist):
    if hdu.data is not None:
        if k > 3:  # only look at the first few images
            break
        k += 1

        image_data = hdu.data
        roi = image_data[-80:-10, 30:270]
        roi_image = process_image(roi)

        plt.imshow(roi_image, cmap='gray')
        plt.show()

        roi_image = custom_process(roi_image)

        # the list of start-end column indices for each digit
        indices = [[6, 22], [22, 36], [55, 71], [71, 87], [104, 120], [121, 137], [154, 170], [171, 187], [203, 219], [223, 239], [254, 270], [270, 286], [303, 319], [320, 336], [337, 353], [352, 368], [380, 396]]

        for index in indices:
            digit_image = roi_image.crop((index[0], 18, index[1], 54))
            plt.imshow(digit_image, cmap='gray')
            plt.show()

            # pad 6 black columns on each side, then resize to the 28x28 MNIST input size
            digit_image = np.array(digit_image)
            digit_image = np.pad(digit_image, ((0, 0), (6, 6)), 'constant', constant_values=(0, 0))
            digit_image = Image.fromarray(digit_image)
            digit_image = digit_image.resize((28, 28))

            # threshold-to-zero to drop faint leftovers
            threshold_value = 100
            digit_image = digit_image.point(lambda x: 0 if x < threshold_value else x)
            digit_image = np.array(digit_image)

            # shape (1, 1, 28, 28): a batch of one single-channel image
            digit_image = torch.from_numpy(digit_image).float().unsqueeze(0).unsqueeze(0)

            # scale to [0, 1], then apply the same normalization as the MNIST training transform
            digit_image = (digit_image / 255.0 - 0.1307) / 0.3081

            with torch.no_grad():
                digit = torch.argmax(model(digit_image.to(device))).item()

            print(digit, end=' ')

            break  # remove this line to predict all digits
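Inputs at inference time must go through the same scaling as the training data. With `ToTensor()` followed by `Normalize((0.1307,), (0.3081,))`, a black pixel maps to about -0.42 and a white one to about 2.82; feeding raw [0, 1] values instead shifts every input the network sees. A NumPy sketch of that mapping (`to_mnist_input` is a hypothetical helper):

```python
import numpy as np

MNIST_MEAN, MNIST_STD = 0.1307, 0.3081  # the values used in the MNIST training transform

def to_mnist_input(pixels_uint8):
    """Hypothetical helper: map 0-255 pixels the way ToTensor() + Normalize((0.1307,), (0.3081,)) does."""
    scaled = np.asarray(pixels_uint8, dtype=np.float32) / 255.0  # ToTensor: [0, 255] -> [0, 1]
    return (scaled - MNIST_MEAN) / MNIST_STD                     # Normalize: (x - mean) / std

print(to_mnist_input([0, 255]))  # roughly -0.424 for black, 2.821 for white
```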
    



**We see that this model doesn't perform well on the timer images. This may be due to the difference in nature between the 2 datasets (handwritten digits versus the timer's digits).**

# 4. Model based on Timer Images

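Now we will train a model on the timer images themselves. We use a small CNN similar to the MNIST one but adapted to this data: it takes a 3-channel input (```ImageFolder``` loads images as RGB by default), uses 5x5 convolutions without padding, and has 3 fully connected layers.

For this model we load the images and labels from the folders we created in the preprocessing part with torchvision's ```ImageFolder```, which expects one sub-folder per class inside ```new_folder/train``` and ```new_folder/test```.

### A - Datasets and Dataloaders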

In [None]:
from torchvision.datasets import ImageFolder
# Define the transform
transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])


root_dir = "new_folder"

# Load the train and test datasets using ImageFolder
train_dataset = ImageFolder(root=os.path.join(root_dir, "train"), transform=transform)
test_dataset = ImageFolder(root=os.path.join(root_dir, "test"), transform=transform)


train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)
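`ImageFolder` derives labels from folder names: it sorts the class sub-folders by name and numbers them in that order, so with folders named `0` to `9` the label index equals the digit itself. The sketch below mimics that mapping with the standard library only; `folder_classes` is a hypothetical helper, not part of torchvision (use `train_dataset.class_to_idx` to inspect the real mapping):

```python
import os
import tempfile

def folder_classes(root):
    """Hypothetical helper mimicking ImageFolder: sorted sub-folder names -> 0, 1, 2, ..."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    return {name: idx for idx, name in enumerate(classes)}

with tempfile.TemporaryDirectory() as tmp_root:
    # Build an empty train/ layout with one folder per digit
    for digit in range(10):
        os.makedirs(os.path.join(tmp_root, 'train', str(digit)))
    mapping = folder_classes(os.path.join(tmp_root, 'train'))

print(mapping)  # each folder name maps to the digit it holds
```

Because the names are sorted as strings, a folder named `10` would land between `1` and `2`; single-character digit folders avoid that.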


### B - Building the Model

In [None]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 5)
        self.fc1 = nn.Linear(64 * 4 * 4, 256)  # 28 -> conv 24 -> pool 12 -> conv 8 -> pool 4
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
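`fc1` expects `64 * 4 * 4` inputs because of the 28x28 `Resize` in the transform: each 5x5 convolution without padding trims 4 pixels, and each 2x2 pooling halves the size. A plain-Python check (the helper name is ours); changing the resize target would require changing `fc1` too:

```python
def conv_then_pool(size, kernel):
    # 'valid' convolution (stride 1, no padding) followed by 2x2 max pooling
    return (size - kernel + 1) // 2

side = 28
for kernel in (5, 5):  # conv1 and conv2 both use 5x5 kernels
    side = conv_then_pool(side, kernel)

print(side, 64 * side * side)  # 4 1024
```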

### C - Training

In [None]:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Net().to(device)


criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)  # Adjusted the learning rate


# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Print the mean loss over the batches of this epoch
    print(f"[Epoch {epoch+1}] Loss: {running_loss / len(train_loader):.3f}")




### D - Evaluation

In [None]:
# Testing loop
model.eval()

y_pred = []
y_true = []
correct = 0
total = 0
with torch.no_grad():
    for data in test_loader:
        images, labels = data[0].to(device), data[1].to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        y_pred.extend(predicted.tolist())
        y_true.extend(labels.tolist())
        total += labels.size(0)  
        correct += (predicted == labels).sum().item()  

accuracy = (100 * correct / total)
classification = classification_report(y_true, y_pred)
confusion = confusion_matrix(y_true, y_pred)



print(f"Accuracy on test set: {accuracy:.2f}%")
print("Classification Report:")
print(classification)
print("Confusion Matrix:")
print(confusion)
plt.figure(figsize=(8, 6))
sns.heatmap(confusion, annot=True, fmt="d", cmap="Blues")
plt.xlabel("Predicted")
plt.ylabel("True")
plt.show()
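In scikit-learn's `confusion_matrix`, rows are the true labels and columns the predicted ones, so correct predictions sit on the diagonal. The accuracy (as a fraction rather than a percentage) and each class's recall can be read off the matrix directly; a sketch on a made-up 3-class matrix:

```python
import numpy as np

# Hypothetical confusion matrix: rows = true class, columns = predicted class
confusion_demo = np.array([
    [8, 2, 0],
    [1, 9, 0],
    [0, 0, 10],
])

accuracy_demo = np.trace(confusion_demo) / confusion_demo.sum()     # correct / total
recall_demo = np.diag(confusion_demo) / confusion_demo.sum(axis=1)  # per true class

print(accuracy_demo)  # 0.9
print(recall_demo)    # [0.8 0.9 1. ]
```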




**Perfect! Now we have a model that can recognize the digits in the timer images. Let's finish by filling in the DATE of the odd-indexed HDUs.**

In [None]:

model.eval()
k = 0
for i, hdu in enumerate(hdulist[1::2]):  # fill the odd-indexed HDUs; the even-indexed ones already have their DATE (see the "Important Note")
    if hdu.data is not None:

        if k > 10:   # only process the first 11 images; remove these lines to process all of them
            break
        k = k + 1
        
        plt.imshow(hdu.data, cmap='gray')
        plt.show()

        image_data = hdu.data

        roi = image_data[-80:-10, 30:270]
        roi_image = process_image(roi)
        roi_image = custom_process(roi_image)
        
        # the list of start-end indices for each digit
        indices = [[6, 22], [22, 36], [55, 71], [71, 87], [104, 120], [121, 137], [154, 170], [171, 187], [203, 219], [223,239],[254,270],[270,286],[303,319],[320,336],[337,353],[352,368],[380,396]]
        
        date = ""
        
        # Loop through the indices and predict the digit for each sub-image
        for index in indices:

            digit_image = roi_image.crop((index[0], 18, index[1], 54))
            digit_image = np.array(digit_image)
            digit_image = np.repeat(digit_image[:, :, np.newaxis], 3, axis=2)
            digit_image = np.pad(digit_image, ((0, 0), (6, 6), (0, 0)), 'constant', constant_values=(0, 0))
            digit_image = Image.fromarray(digit_image) 
            digit_image = digit_image.resize((28, 28))
            
            digit_image = transforms.ToTensor()(digit_image)
            digit_image = digit_image.unsqueeze(0)
        
            digit_image = transforms.Normalize((0.5,), (0.5,))(digit_image)
            with torch.no_grad():
                digit = torch.argmax(model(digit_image.to(device))).item()

            date += str(digit)

        # Set the DATE header to the concatenated predicted digits
        hdu.header['DATE'] = date
        print(date)
        print(hdu.header['DATE'])


The model seems to do well. Some of its mistakes are off by ±1 from the correct digit (the predicted class, i.e. the argmax of the output, lands on a neighbouring digit); more epochs and more precise cropping should improve this. Also, the most important digits to recognize are the milliseconds, and the model does well on them.

# 5. Conclusion

**We have created 2 models to recognize the digits in the timer images: the 1st one is trained on the MNIST dataset and the 2nd one on the timer images themselves.**

**The 2nd model performs better than the 1st one, but it's still not perfect. We could add validation data to confirm its performance; if it keeps doing well, it would be a good starting point for optimizing it, for fun and learning.**

**Thank you for reading this notebook, I hope you liked it! If you have any questions or suggestions, please leave them in the comments.**