<a href="https://colab.research.google.com/github/The-cheater/Deep_Learning_Models/blob/main/ideal_sir.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# ===========================
# 1️⃣ Google Drive & Extraction
# ===========================
from google.colab import drive
drive.mount('/content/drive')

import os
import zipfile

# Unzip datasets
def unzip_dataset(zip_path, extract_to):
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(extract_to)

unzip_dataset('/content/drive/MyDrive/dataset/GAF_Images.zip', '/content/GAF_Images')
unzip_dataset('/content/drive/MyDrive/dataset/MTF_Images.zip', '/content/MTF_Images')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
# ===========================
# 2️⃣ Imports and Setup
# ===========================
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping, ModelCheckpoint
import numpy as np
import matplotlib.pyplot as plt


In [3]:
# ===========================
# 3️⃣ Optimized PairedDataset
# ===========================
class PairedDataGenerator(keras.utils.Sequence):
    def __init__(self, gaf_dir, mtf_dir, batch_size=16, img_size=(224,224), shuffle=True):
        self.gaf_paths, self.mtf_paths, self.labels = self._load_pairs(gaf_dir, mtf_dir)
        self.batch_size = batch_size
        self.img_size = img_size
        self.shuffle = shuffle
        self.on_epoch_end()

    def _load_pairs(self, gaf_dir, mtf_dir):
        gaf_paths, mtf_paths, labels = [], [], []
        for root, _, files in os.walk(gaf_dir):
            for fname in files:
                if fname.endswith('_gaf.png'):
                    gaf_path = os.path.join(root, fname)
                    mtf_path = gaf_path.replace('GAF_Images', 'MTF_Images').replace('_gaf.png', '_mtf.png')
                    if os.path.exists(mtf_path):
                        gaf_paths.append(gaf_path)
                        mtf_paths.append(mtf_path)
                        labels.append(0 if '/EL/' in gaf_path else 1 if '/PD/' in gaf_path else 2)
        return gaf_paths, mtf_paths, np.array(labels)

    def __len__(self):
        return int(np.ceil(len(self.labels) / self.batch_size))

    def __getitem__(self, index):
        batch_indices = self.indices[index*self.batch_size : (index+1)*self.batch_size]
        gaf_batch = [self._load_image(self.gaf_paths[i]) for i in batch_indices]
        mtf_batch = [self._load_image(self.mtf_paths[i]) for i in batch_indices]
        # Change the return type for the inputs from a list to a tuple
        return (np.array(gaf_batch), np.array(mtf_batch)), self.labels[batch_indices]

    def _load_image(self, path):
        img = tf.io.read_file(path)
        img = tf.image.decode_png(img, channels=3)
        img = tf.image.resize(img, self.img_size)
        img = tf.keras.applications.efficientnet.preprocess_input(img)
        return img

    def on_epoch_end(self):
        self.indices = np.arange(len(self.labels))
        if self.shuffle:
            np.random.shuffle(self.indices)

In [4]:
# ===========================
# 4️⃣ TensorFlow Model Definition
# ===========================
def conv_block(x, filters, n_convs, name):
    for i in range(n_convs):
        x = layers.Conv2D(filters, 3, padding='same', name=f'{name}_conv{i+1}')(x)
        x = layers.BatchNormalization(name=f'{name}_bn{i+1}')(x)
        x = layers.ReLU(name=f'{name}_relu{i+1}')(x)
    return x

def create_l3_fusion_model(input_shape=(224,224,3), num_classes=3):
    # Branch 1 (GAF)
    input_gaf = layers.Input(shape=input_shape, name='gaf_input')
    x1 = conv_block(input_gaf, 64, 2, 'branch1_conv1')
    x1 = layers.MaxPool2D(2, 2, name='branch1_pool1')(x1)
    x1 = conv_block(x1, 128, 2, 'branch1_conv2')
    x1 = layers.MaxPool2D(2, 2, name='branch1_pool2')(x1)
    x1 = conv_block(x1, 256, 3, 'branch1_conv3')
    branch1_out = layers.MaxPool2D(2, 2, name='branch1_pool3')(x1)

    # Branch 2 (MTF)
    input_mtf = layers.Input(shape=input_shape, name='mtf_input')
    x2 = conv_block(input_mtf, 64, 2, 'branch2_conv1')
    x2 = layers.MaxPool2D(2, 2, name='branch2_pool1')(x2)
    x2 = conv_block(x2, 128, 2, 'branch2_conv2')
    x2 = layers.MaxPool2D(2, 2, name='branch2_pool2')(x2)
    x2 = conv_block(x2, 256, 3, 'branch2_conv3')
    branch2_out = layers.MaxPool2D(2, 2, name='branch2_pool3')(x2)

    # Fusion
    fused = layers.Concatenate(axis=-1)([
        layers.Conv2D(256, 3, padding='same')(branch1_out),
        layers.Conv2D(256, 3, padding='same')(branch2_out)
    ])

    # Common trunk
    x = conv_block(fused, 512, 3, 'fusion_conv4')
    x = layers.MaxPool2D(2, 2, name='pool4')(x)
    x = conv_block(x, 512, 3, 'conv5')
    x = layers.MaxPool2D(2, 2, name='pool5')(x)

    # Classification head
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(4096, activation='relu')(x)
    x = layers.Dropout(0.5)(x)
    x = layers.Dense(4096, activation='relu')(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)

    return Model(inputs=[input_gaf, input_mtf], outputs=outputs)

In [5]:
# ===========================
# 5️⃣ Data Preparation
# ===========================
train_gen = PairedDataGenerator(
    '/content/GAF_Images/GAF_Images_train',
    '/content/MTF_Images/MTF_Images_train',
    batch_size=32,
    shuffle=True
)

val_gen = PairedDataGenerator(
    '/content/GAF_Images/GAF_Images_train',
    '/content/MTF_Images/MTF_Images_train',
    batch_size=32,
    shuffle=False
)


In [None]:
# ===========================
# 6️⃣ Model and Training Setup
# ===========================
model = create_l3_fusion_model()
optimizer = Adam(learning_rate=1e-4)
model.compile(optimizer=optimizer,
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

callbacks = [
    ReduceLROnPlateau(monitor='val_accuracy', factor=0.5, patience=3),
    EarlyStopping(monitor='val_accuracy', patience=5, restore_best_weights=True),
    ModelCheckpoint('best_model.h5', save_best_only=True)
]
# ===========================
# 7️⃣ Training Execution
# ===========================
history = model.fit(
    train_gen,
    validation_data=val_gen,
    epochs=50,
    callbacks=callbacks,
    verbose=1
)

  self._warn_if_super_not_called()


Epoch 1/50
[1m252/252[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1s/step - accuracy: 0.5208 - loss: 0.9937



[1m252/252[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m506s[0m 2s/step - accuracy: 0.5209 - loss: 0.9935 - val_accuracy: 0.5690 - val_loss: 0.9856 - learning_rate: 1.0000e-04
Epoch 2/50
[1m252/252[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1s/step - accuracy: 0.5707 - loss: 0.8837



[1m252/252[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m433s[0m 2s/step - accuracy: 0.5708 - loss: 0.8836 - val_accuracy: 0.5355 - val_loss: 0.9608 - learning_rate: 1.0000e-04
Epoch 3/50
[1m252/252[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m427s[0m 2s/step - accuracy: 0.6251 - loss: 0.8255 - val_accuracy: 0.5763 - val_loss: 1.1086 - learning_rate: 1.0000e-04
Epoch 4/50
[1m252/252[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1s/step - accuracy: 0.6621 - loss: 0.7645



[1m252/252[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m436s[0m 2s/step - accuracy: 0.6620 - loss: 0.7645 - val_accuracy: 0.5943 - val_loss: 0.9291 - learning_rate: 1.0000e-04
Epoch 5/50
[1m252/252[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m423s[0m 2s/step - accuracy: 0.6947 - loss: 0.6795 - val_accuracy: 0.5127 - val_loss: 1.1210 - learning_rate: 1.0000e-04
Epoch 6/50
[1m252/252[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1s/step - accuracy: 0.7385 - loss: 0.6076



[1m252/252[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m425s[0m 2s/step - accuracy: 0.7385 - loss: 0.6076 - val_accuracy: 0.6458 - val_loss: 0.8960 - learning_rate: 1.0000e-04
Epoch 7/50
[1m252/252[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1s/step - accuracy: 0.7847 - loss: 0.5083



[1m252/252[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m429s[0m 2s/step - accuracy: 0.7847 - loss: 0.5083 - val_accuracy: 0.6593 - val_loss: 0.7040 - learning_rate: 1.0000e-04
Epoch 8/50
[1m252/252[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1s/step - accuracy: 0.8181 - loss: 0.4305



[1m252/252[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m425s[0m 2s/step - accuracy: 0.8181 - loss: 0.4305 - val_accuracy: 0.7981 - val_loss: 0.5470 - learning_rate: 1.0000e-04
Epoch 9/50
[1m252/252[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1s/step - accuracy: 0.8450 - loss: 0.3838



[1m252/252[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m436s[0m 2s/step - accuracy: 0.8451 - loss: 0.3837 - val_accuracy: 0.8279 - val_loss: 0.4235 - learning_rate: 1.0000e-04
Epoch 10/50
[1m252/252[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1s/step - accuracy: 0.8980 - loss: 0.2651



[1m252/252[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m428s[0m 2s/step - accuracy: 0.8980 - loss: 0.2651 - val_accuracy: 0.9327 - val_loss: 0.1741 - learning_rate: 1.0000e-04
Epoch 11/50
[1m252/252[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m419s[0m 2s/step - accuracy: 0.9208 - loss: 0.1968 - val_accuracy: 0.6853 - val_loss: 0.7754 - learning_rate: 1.0000e-04
Epoch 12/50
[1m252/252[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m415s[0m 2s/step - accuracy: 0.9450 - loss: 0.1503 - val_accuracy: 0.6507 - val_loss: 1.3524 - learning_rate: 1.0000e-04
Epoch 13/50
[1m252/252[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m449s[0m 2s/step - accuracy: 0.9542 - loss: 0.1255 - val_accuracy: 0.9235 - val_loss: 0.2086 - learning_rate: 1.0000e-04
Epoch 14/50
[1m252/252[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m417s[0m 2s/step - accuracy: 0.9760 - loss: 0.0640 - val_accuracy: 0.8303 - val_loss: 0.4045 - learning_rate: 5.0000e-05
Epoch 15/50
[1m252/252[0m [32m━━━━━━━━━━━━━━━━━━━



[1m252/252[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m423s[0m 2s/step - accuracy: 0.9780 - loss: 0.0678 - val_accuracy: 0.9937 - val_loss: 0.0211 - learning_rate: 5.0000e-05
Epoch 16/50
[1m252/252[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m417s[0m 2s/step - accuracy: 0.9922 - loss: 0.0238 - val_accuracy: 0.9386 - val_loss: 0.1688 - learning_rate: 5.0000e-05
Epoch 17/50
[1m252/252[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m418s[0m 2s/step - accuracy: 0.9823 - loss: 0.0529 - val_accuracy: 0.9917 - val_loss: 0.0251 - learning_rate: 5.0000e-05
Epoch 18/50
[1m252/252[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m416s[0m 2s/step - accuracy: 0.9910 - loss: 0.0255 - val_accuracy: 0.9917 - val_loss: 0.0261 - learning_rate: 5.0000e-05
Epoch 19/50
[1m252/252[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1s/step - accuracy: 0.9963 - loss: 0.0127



[1m252/252[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m420s[0m 2s/step - accuracy: 0.9963 - loss: 0.0127 - val_accuracy: 1.0000 - val_loss: 9.5035e-04 - learning_rate: 2.5000e-05
Epoch 20/50
[1m170/252[0m [32m━━━━━━━━━━━━━[0m[37m━━━━━━━[0m [1m1:31[0m 1s/step - accuracy: 0.9892 - loss: 0.0366

In [None]:
# ===========================
# 8️⃣ Visualization
# ===========================
plt.figure(figsize=(12,5))
plt.subplot(1,2,1)
plt.plot(history.history['loss'], label='Train')
plt.plot(history.history['val_loss'], label='Val')
plt.legend()
plt.title('Loss')

plt.subplot(1,2,2)
plt.plot(history.history['accuracy'], label='Train')
plt.plot(history.history['val_accuracy'], label='Val')
plt.legend()
plt.title('Accuracy')
plt.show()

# ===========================
# 9️⃣ Evaluation
# ===========================
test_gen = PairedDataGenerator(
    '/content/GAF_Images/GAF_Images_test',
    '/content/MTF_Images/MTF_Images_test',
    batch_size=32,
    shuffle=False
)

model.load_weights('best_model.h5')
test_loss, test_acc = model.evaluate(test_gen)
print(f"✅ Final Test Accuracy: {test_acc:.4f}")