In [1]:
import keras
import numpy as np
import os
import cv2

2023-11-19 19:10:44.193312: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0


In [2]:
train_image_folder = '/raid/mpsych/OMAMA/DATA/data/train'
train_npz_folder = '/raid/mpsych/OMAMA/DATA/data/2d_resized_512/images'

In [3]:
img_height = 512
img_width = 512
batch_size = 32

In [6]:
def custom_data_generator(image_folder, npz_folder, batch_size, img_height, img_width):
    print(f"Generator called")
    image_files = [f for f in os.listdir(image_folder) if f.endswith('.png')]
    npz_files = [f for f in os.listdir(npz_folder) if f.endswith('.npz')]
    total_files = len(image_files)

    print(f"Total Files: {total_files}")

    indices = np.arange(total_files)
    np.random.shuffle(indices)

    num_images_loaded = 0

    while True:
        print(f"Entered loop")
        for i in range(0, total_files, batch_size):
            batch_indices = indices[i:i + batch_size]

            print(f"Batch Indices: {batch_indices}")

            batch_images = []
            batch_npz = []
            batch_labels = []

            for idx in batch_indices:
                img_file = os.path.join(image_folder, image_files[idx])
                npz_file = os.path.join(npz_folder, npz_files[idx])

                image = cv2.imread(img_file, cv2.IMREAD_GRAYSCALE)
                image = cv2.resize(image, (img_width, img_height))
                image = np.expand_dims(image, axis=-1) 

                with np.load(npz_file, allow_pickle=True) as data:
                    npz = data['data']
                npz = np.expand_dims(npz, axis=-1)

                batch_images.append(image)
                batch_npz.append(npz)
                
                # Assign class labels: 1 for PNG, 0 for NPZ
                if img_file.endswith('.png'):
                    batch_labels.append([1, 0])
                if npz_file.endswith('.npz'):
                    batch_labels.append([0, 1])
               

                num_images_loaded += 1

            if len(batch_images) > 0:
                yield (np.array(batch_images + batch_npz), np.array(batch_labels))
            else:
                break

In [7]:
NUMBER_OF_CLASSES = 2

In [8]:
model = keras.models.Sequential()
model.add(keras.layers.Conv2D(32, kernel_size=(3, 3),
                             activation='relu',
                             input_shape=(img_height, img_width, 1)))
model.add(keras.layers.Conv2D(64, (3, 3), activation='relu'))
model.add(keras.layers.MaxPooling2D(pool_size=(2, 2)))
model.add(keras.layers.Dropout(0.25))
model.add(keras.layers.Flatten())
model.add(keras.layers.Dense(128, activation='relu'))
model.add(keras.layers.Dropout(0.5))
model.add(keras.layers.Dense(NUMBER_OF_CLASSES, activation='softmax'))

2023-11-19 19:11:29.445426: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcuda.so.1
2023-11-19 19:11:29.474592: E tensorflow/stream_executor/cuda/cuda_driver.cc:328] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
2023-11-19 19:11:29.474657: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:169] retrieving CUDA diagnostic information for host: chimera12
2023-11-19 19:11:29.474663: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:176] hostname: chimera12
2023-11-19 19:11:29.475089: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:200] libcuda reported version is: 450.172.1
2023-11-19 19:11:29.475107: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:204] kernel reported version is: 450.172.1
2023-11-19 19:11:29.475111: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:310] kernel version seems to match DSO: 450.172.1
2023-11-19 19:11:29.475643: I tensorflow/core/platform/cpu_fe

In [9]:
model.compile(loss=keras.losses.categorical_crossentropy,
              optimizer=keras.optimizers.Adadelta(),
              metrics=['accuracy'])

In [10]:
# Create the data generators
train_generator = custom_data_generator(train_image_folder, train_npz_folder, batch_size, img_height, img_width)

try:
    model.fit(
        train_generator,
        epochs=3,
        steps_per_epoch=50,
        verbose=1
    )
except Exception as e:
    print("An error occurred:", str(e))

Generator called
Total Files: 99999
Entered loop
Batch Indices: [48450 36446 81243 83443 88547 70652 97282 32430 74840 87637 64122 76843
 45454 76497 65969 26650 51464 47581 30336 43789 60910 61957 99570 73029
 29717 25368  7012 33687 34493 83543 67339 87207]


2023-11-19 19:11:36.258792: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:176] None of the MLIR Optimization Passes are enabled (registered 2)
2023-11-19 19:11:36.279160: I tensorflow/core/platform/profile_utils/cpu_utils.cc:114] CPU Frequency: 2245720000 Hz


Epoch 1/3
 1/50 [..............................] - ETA: 44:17 - loss: 90.9620 - accuracy: 0.5625Batch Indices: [ 6163 42140 17112 62278 20454 93927 51930 19575 74274 36386 34137 71019
 62639 99498 50346 71836 55522 69954 33866 98543 18882 73131 49989 52466
 82931 90455 85525 76142 94466 88049  6496 65518]
 2/50 [>.............................] - ETA: 42:35 - loss: 135.3896 - accuracy: 0.5781Batch Indices: [53246 82319  5331 37637 44175 75655 52171 32907 99526  7422 65100 56856
 74360 33393 46264 73776 28879 71319 74818 19382 11402  1235 15013 82298
 16136 68335 47882 85447  1229 80561 16971 62961]
 3/50 [>.............................] - ETA: 42:00 - loss: 403.6102 - accuracy: 0.5521Batch Indices: [97322 74457 18170 68929 89956 87171 23446 60787 68846 53928 99679 59914
 34927 82805 98419 47175 51934 73375 35572 79509 24445 25445 56326 50010
 38688 10876 58529 88202 36345 73766 86269  1999]
 4/50 [=>............................] - ETA: 41:03 - loss: 350.1594 - accuracy: 0.5430Batch Indi

 17443 24410 55859 68946 80440  8907 62245  1138 63188 53043 50951 63489
 97623 58490  2575  8257 97025 85135 75284 28758]
 92550 67637 37209 95765 16292 59497 75956 15390  5225 43055 39132 58130
 81575 99433 83230 43647 25726 80203 70714  2058]
 54662 82245 35050 51691 56121  9730 13929 14429 56537 76174 45924 63756
 87828 51796 34695 93391 25573 41065 72729 21828]
 74545 44925 89670 58963 31091 12254 62099 36584 26643 71816  7264 32779
 58230 86653 50599 46169 33319  7882 65810  1285]
 17230 36300 44998 26013 15850 68844 50966 45131 40711 15016 61207 99572
 97664  9652 45955 62409 70755 38252 26797 77932]
  7995 53401 67095 81647 42474 44247 32773 86049 48707 27526 94530 80744
  9562  7720 64114 66243 12295 11168 73801 43912]
 82336 37242 33251 76049 13800 65909 51046 17191  2536 16594 77378 72231
 96979 15134 65200 36676 96244 60529 12634 62137]
 46743 72155 84076 40404 34294 15023 28324 45726  5187 24600 49229 11185
 13631 94908 76676 70575 74957 57851 18597  7015]
  9821 46943 921

 7/50 [===>..........................] - ETA: 38:18 - loss: 1.6576 - accuracy: 0.5223Batch Indices: [25081 19774 41474 54818 67565 79181 45834 66279 71082 43955 37348 45970
 66584  7677 44785 59352 58314 35895 47996  7847 11815  4311 59187 78369
 81927 13979 56318 80069 56541 61410 14390 55114]
 8/50 [===>..........................] - ETA: 37:24 - loss: 1.5538 - accuracy: 0.5371Batch Indices: [10643 76483 12322 85401 77320 11488 71051 87076 59791 32408 53862 39035
 84335 81522 65549 22022 67567 23618 67812 12826 22938 76036 10001 18534
  1227 16752 18192 59959 24193 79452 56113 33260]
 9/50 [====>.........................] - ETA: 36:31 - loss: 1.6303 - accuracy: 0.5243Batch Indices: [27956 96891 44047  3003 93065 24681 80146 26673 96304 71611 89894 68250
 35493 65938 63765 53175  3725 93857 28590  1787 94866 73453 58966 58004
 73693 59792 40886 22825 59620 29501 39976   787]
10/50 [=====>........................] - ETA: 35:38 - loss: 1.6371 - accuracy: 0.5266Batch Indices: [  913 85682

 24764 55193 34020 47935 94324 38329 30144 57594   731 84634 17017 14843
 21136 21873 51775 27519   215   880 64184 49449]
  9148 77882 54011 67447 15478 98716 78702  5523 31481 36833 88978 72843
 67471 54939 40845  5383 82447  2596 42276 76700]
 87460 11359 94693 48831 95851 58318  7319 94555 46477 87249 39192 26780
  3781 43481 72146 58654 36157 64383 23677 60163]
 24093 40891 23203 65631 21790 51931 13179 71137  8080 52153 85669 13154
 23112  2802 33607 72504 14136 71418 54516 97044]
 94225 29089 50095 24966 88120 31445 55649 64770 41808 49684 63244  1005
 70417 71339 26609 47166 49557 43502 93836 81184]
 92577 73988 49182 26436 16521 31248 81260 64427 26245 65339  5812 41635
 31149 13828 16922 88843 17286 28921 42541 65789]
 30068 70609 56302 93983 60270 52645 24235 42653 11122 75657 48883 29046
 65083 70745 54190 71206 74401 71381 53427 24145]
  1451 75016 65481 94145 51159 19149 80514 31647  6987 93278 22102 84346
 82887 27730  9601 64013 62597 45356 74540 79803]
 83588 63096 960

 86684 29184 59157 42671 27202  5770  4036 18767  3823 34424 61387 96224
 42186 39372 85543 65048 18569 25583 28645 94594]
 80258 92117 95127 38828 65883 19218 50765 56220 59234 80381 59308 82554
 27655 53482 37439 89692 85994 28846 43049 36038]
 17057 64234 70365 36188 53666 33341 36261 14479 21067 74396 85976  5399
 24821 88567 27300 83802 42569 19228 61289 98368]
 72825 97052 80358 14616 61928 49002 21782  1305 36373 75310 79594   866
 27040 38249 31482 16124 78508 63623 19480 98348]
 66948 88138 58695  6755 41648 43781 95880 99694 43368 63287 20682  2186
 15607 20521 92656  6593  5289 50412 78095 69844]
  6095 29195 72127 69157 44586  6692 92113 83513 55719 83645 19426 31402
 46623  1238 69967 82198 71488 19700 71315 79326]
 80551 59896 65910 56952 42430 98479 88111 85004  7228 75751 84034 99444
 26869 83889 48823 33745 17010 73853 85348 50728]
 32857 65708 73340 16247 93123 57225 69854 56078 11964 99708 50161 55779
 66949 14282 40113 45568 79943 63483 18243 87931]
 48513 89880 254

 28053 86668 31033 81140 70080 89357 86508 43139 13070 38278 31441 35502
 20944 51641  4676 47642 57244 70287 92454 27659]
 11973 86831 26819 48712 93941 17702 66793   879    87 18064 35325 87885
 56293 81671 59059 20131 99387 35880 56555 35282]
 72160 49143 45713 33355 56128 83797 44128  3192 63950 87246 69197  4164
 61301 75945 61973 61182 96580 70229 30451 71223]
 53119 53850  1633 52357 63524 15988 50470 28575 34249 94286 10212 67157
 77321 15430 92553 99894  4740 66139  2260 38872]
 42596 27612  5998 93223 80839 26040 58397 30229 28790 27396  8117 79551
 43717 10007 56722 37418 74353 48958 72276 75325]
 74702 88909  2508  1776 36736 35831 28175   460 78824 79542  7764 85709
 96123 12663 51865 11678 85638 65496 26156 91292]
 13098 87519 69350 19809 53199  5914  2604 61264 32814 54730 32655 41360
 42094 24199 63543 84408 62839 84643 99819 60901]
Batch Indices: [60721 21113 34041 61447 76693 56990 14233 61178  3710 97196 34253  9549
 84578 79410 93378 75837  7584 91731 44088 60257 66

In [11]:
test_file = '/raid/mpsych/OMAMA/DATA/data/train/sample_40069.png'
test_image = cv2.imread(test_file, cv2.IMREAD_GRAYSCALE)
test_image = cv2.resize(test_image, (img_width, img_height))
test_image = np.expand_dims(test_image, axis=-1)
test_image = test_image / 255.0
test_image = np.expand_dims(test_image, axis=0)  # Add batch dimension

predictions = model.predict(test_image)
print("Predictions:", predictions)
predicted_class = np.argmax(predictions)
print("Predicted Class:", predicted_class)

Predictions: [[0.49999705 0.5000029 ]]
Predicted Class: 1
