In [45]:
import tarfile
import os
import shutil
from astropy.io import fits
import numpy as np
from scipy.ndimage import zoom
import re
import tensorflow as tf

# Specify the directory containing the .tar files
directory_path = '/Users/matt/Dropbox/learning_neuralnets/colombialensing/'

# image_size
image_size = 256
suffix = f"_{image_size}"

# Use a regular expression to match .tar files with the desired suffix
pattern = re.compile(rf"{suffix}.tar$")

# List all matching .tar files in the directory
all_tar_files = [f for f in os.listdir(directory_path) if pattern.search(f)]

def get_labels_for_file(tar_file_name):
    """
    Extracts labels from the tar file name.
    For the file "Om0.183_si0.958_256.tar", the labels will be [0.183, 0.958].
    
    Args:
    - tar_file_name (str): Name of the tar file.
    
    Returns:
    - list: List containing the two labels extracted from the filename.
    """
    # Split the filename on underscores
    parts = tar_file_name.split('_')

    # Extract the numeric values for 'Om' and 'si'
    om_label = float(parts[0][2:])
    si_label = float(parts[1][2:])
    
    return [om_label, si_label]



def extractBatchOld(tar_file):
    tar_file_path = os.path.join(directory_path, tar_file)
    
    # Extract the tar archive
    with tarfile.open(tar_file_path, 'r') as archive:
        archive.extractall(path=directory_path)

    dir_name = os.path.splitext(tar_file)[0]
    dir_path = os.path.join(directory_path, dir_name)

    all_files = os.listdir(dir_path)
    fits_files = [f for f in all_files if f.endswith('.fits')]

    # Create a numpy array of desired shape
    data_array = np.empty((len(fits_files), image_size, image_size), dtype=np.float16)

    for idx, file in enumerate(fits_files):
        with fits.open(os.path.join(dir_path, file)) as hdul:
            
            original_data = hdul[0].data
            
            #get rid of NANs, which affects a few files
            if np.isnan(original_data).any():
                continue
                
            data_array[idx] = original_data
            
            
    #since all files have the same labels
    labels = get_labels_for_file(tar_file)
    
    batch_size = len(fits_files)
    batch_labels = np.array([labels for i in range(batch_size)])
    print(np.shape( batch_labels ))
    
    WL_labels = tf.convert_to_tensor(batch_labels)

    # At this point, data_array contains all the data from the fits files and is ready to be used with TensorFlow
    # For example, you can convert it to a TensorFlow tensor:
    WL_tensor = tf.convert_to_tensor(data_array)

    shutil.rmtree(dir_path)

    print(f"Extracted data from {tar_file}")
    return [WL_tensor, WL_labels]



def extractBatch(number_files):
    
    for tar_file in all_tar_files:
    tar_file_path = os.path.join(directory_path, tar_file)
    
    # Extract the tar archive
    with tarfile.open(tar_file_path, 'r') as archive:
        archive.extractall(path=directory_path)

    dir_name = os.path.splitext(tar_file)[0]
    dir_path = os.path.join(directory_path, dir_name)

    all_files = os.listdir(dir_path)
    fits_files = [f for f in all_files if f.endswith('.fits')]

    # Create a numpy array of desired shape
    data_array = np.empty((len(fits_files), image_size, image_size), dtype=np.float16)

    for idx, file in enumerate(fits_files):
        with fits.open(os.path.join(dir_path, file)) as hdul:
            
            original_data = hdul[0].data
            
            #get rid of NANs, which affects a few files
            if np.isnan(original_data).any():
                continue
                
            data_array[idx] = original_data
            
            
    #since all files have the same labels
    labels = get_labels_for_file(tar_file)
    
    batch_size = len(fits_files)
    batch_labels = np.array([labels for i in range(batch_size)])
    print(np.shape( batch_labels ))
    
    WL_labels = tf.convert_to_tensor(batch_labels)

    # At this point, data_array contains all the data from the fits files and is ready to be used with TensorFlow
    # For example, you can convert it to a TensorFlow tensor:
    WL_tensor = tf.convert_to_tensor(data_array)

    shutil.rmtree(dir_path)

    print(f"Extracted data from {tar_file}")
    return [WL_tensor, WL_labels]



In [39]:
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense

 #make less redundnat

def create_cnn_model(input_shape):
    model = Sequential()
    model.add(Conv2D(32, (3, 3), activation='relu', input_shape=input_shape))
    model.add(BatchNormalization())
    model.add(MaxPooling2D((2, 2)))
    model.add(Conv2D(64, (3, 3), activation='relu'))
    model.add(BatchNormalization())
    model.add(MaxPooling2D((2, 2)))
    model.add(Flatten())
    model.add(Dense(64, activation='relu'))
    model.add(Dense(2, activation='softmax'))  
    
    return model

model = create_cnn_model((image_size, image_size, 1))  # Assuming grayscale images


In [49]:
model.compile(optimizer='adam', loss='mse', metrics=['accuracy'])

In [41]:
model.summary()

Model: "sequential_4"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_8 (Conv2D)           (None, 254, 254, 32)      320       
                                                                 
 max_pooling2d_8 (MaxPooling  (None, 127, 127, 32)     0         
 2D)                                                             
                                                                 
 conv2d_9 (Conv2D)           (None, 125, 125, 64)      18496     
                                                                 
 max_pooling2d_9 (MaxPooling  (None, 62, 62, 64)       0         
 2D)                                                             
                                                                 
 flatten_4 (Flatten)         (None, 246016)            0         
                                                                 
 dense_8 (Dense)             (None, 64)               

In [50]:
count =0
for tar_file in all_tar_files:
    # ... [Code to extract and preprocess the data into data_array]

    [WL_tensor, labels]  = extractBatchold(tar_file)
    #tensor = tf.convert_to_tensor(data_array)
    
    # Assuming you have corresponding labels for each batch
    #labels = get_labels_for_batch(tar_file)  # You'd need to define this function

    print("labels shape = ", np.shape(labels))
    # Reshape tensor to be (-1, image_size, image_size, 1) to fit the CNN input shape
    WL_tensor = tf.reshape(WL_tensor, (-1, image_size, image_size, 1))

    # Train the model on this batch
    model.fit(WL_tensor, labels, epochs=5)  # You can adjust the number of epochs as needed

    print("Completed ", count/len(all_tar_files))
    count += 1


(512, 2)
Extracted data from Om0.518_si0.611_256.tar
labels shape =  (512, 2)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
(512, 2)
Extracted data from Om0.375_si0.332_256.tar
labels shape =  (512, 2)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
(512, 2)
Extracted data from Om0.513_si0.463_256.tar
labels shape =  (512, 2)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
(512, 2)
Extracted data from Om0.393_si0.686_256.tar
labels shape =  (512, 2)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
(512, 2)
Extracted data from Om0.650_si0.565_256.tar
labels shape =  (512, 2)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
(512, 2)
Extracted data from Om0.263_si0.795_256.tar
labels shape =  (512, 2)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
(512, 2)
Extracted data from Om0.294_si0.991_256.tar
labels shape =  (512, 2)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
(512, 2)
Extracted data from Om0.227_si0.793_256.tar
labels shape =  (512, 2)
Epoch 1/5
Epoch 2/5
Epoch 

Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
(512, 2)
Extracted data from Om0.313_si0.633_256.tar
labels shape =  (512, 2)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
(512, 2)
Extracted data from Om0.267_si0.642_256.tar
labels shape =  (512, 2)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
(512, 2)
Extracted data from Om0.460_si0.897_256.tar
labels shape =  (512, 2)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
(512, 2)
Extracted data from Om0.291_si0.775_256.tar
labels shape =  (512, 2)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
(512, 2)
Extracted data from Om0.361_si0.935_256.tar
labels shape =  (512, 2)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
(512, 2)
Extracted data from Om0.213_si0.536_256.tar
labels shape =  (512, 2)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
(512, 2)
Extracted data from Om0.274_si0.786_256.tar
labels shape =  (512, 2)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
(512, 2)
Extracted data from Om0.234_si0.864_256.tar
labels shap

Epoch 3/5
Epoch 4/5
Epoch 5/5
(512, 2)
Extracted data from Om0.192_si0.727_256.tar
labels shape =  (512, 2)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
(512, 2)
Extracted data from Om0.283_si0.805_256.tar
labels shape =  (512, 2)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
(512, 2)
Extracted data from Om0.184_si0.829_256.tar
labels shape =  (512, 2)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
(512, 2)
Extracted data from Om0.252_si0.811_256.tar
labels shape =  (512, 2)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
(455, 2)
Extracted data from Om0.251_si0.807_256.tar
labels shape =  (455, 2)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
(512, 2)
Extracted data from Om0.349_si0.274_256.tar
labels shape =  (512, 2)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
(512, 2)
Extracted data from Om0.452_si0.454_256.tar
labels shape =  (512, 2)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
(512, 2)
Extracted data from Om0.220_si0.912_256.tar
labels shape =  (512,

Epoch 4/5
Epoch 5/5
(512, 2)
Extracted data from Om0.191_si0.924_256.tar
labels shape =  (512, 2)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
(512, 2)
Extracted data from Om0.513_si0.256_256.tar
labels shape =  (512, 2)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
(512, 2)
Extracted data from Om0.268_si0.801_256.tar
labels shape =  (512, 2)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
(512, 2)
Extracted data from Om0.515_si0.089_256.tar
labels shape =  (512, 2)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
(512, 2)
Extracted data from Om0.259_si0.875_256.tar
labels shape =  (512, 2)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
(512, 2)
Extracted data from Om0.261_si0.802_256.tar
labels shape =  (512, 2)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
(512, 2)
Extracted data from Om0.260_si0.800_256.tar
labels shape =  (512, 2)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
(512, 2)
Extracted data from Om0.528_si0.167_256.tar
labels shape =  (512, 2)
Epoch 

Epoch 4/5
Epoch 5/5
(512, 2)
Extracted data from Om0.410_si0.927_256.tar
labels shape =  (512, 2)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
(512, 2)
Extracted data from Om0.273_si1.204_256.tar
labels shape =  (512, 2)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
(512, 2)
Extracted data from Om0.275_si0.766_256.tar
labels shape =  (512, 2)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
(512, 2)
Extracted data from Om0.249_si0.764_256.tar
labels shape =  (512, 2)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
(512, 2)
Extracted data from Om0.315_si0.746_256.tar
labels shape =  (512, 2)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
(512, 2)
Extracted data from Om0.610_si0.397_256.tar
labels shape =  (512, 2)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
(512, 2)
Extracted data from Om0.317_si0.837_256.tar
labels shape =  (512, 2)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
(512, 2)
Extracted data from Om0.253_si0.589_256.tar
labels shape =  (512, 2)
Epoch 

Epoch 4/5
Epoch 5/5
(512, 2)
Extracted data from Om0.253_si0.852_256.tar
labels shape =  (512, 2)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
(512, 2)
Extracted data from Om0.638_si0.250_256.tar
labels shape =  (512, 2)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
(512, 2)
Extracted data from Om0.365_si0.524_256.tar
labels shape =  (512, 2)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
(512, 2)
Extracted data from Om0.224_si1.013_256.tar
labels shape =  (512, 2)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
(512, 2)
Extracted data from Om0.566_si0.520_256.tar
labels shape =  (512, 2)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
(512, 2)
Extracted data from Om0.496_si0.338_256.tar
labels shape =  (512, 2)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
(512, 2)
Extracted data from Om0.195_si1.095_256.tar
labels shape =  (512, 2)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
(512, 2)
Extracted data from Om0.403_si0.757_256.tar
labels shape =  (512, 2)
Epoch 

Epoch 4/5
Epoch 5/5
(512, 2)
Extracted data from Om0.268_si0.727_256.tar
labels shape =  (512, 2)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
(512, 2)
Extracted data from Om0.195_si0.994_256.tar
labels shape =  (512, 2)
Epoch 1/5

KeyboardInterrupt: 