# In [None]:
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from google.colab import drive
from sklearn.model_selection import train_test_split
from tensorflow.keras.callbacks import Callback, EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.layers import (
    Add, BatchNormalization, Conv2D, Dropout, Input, LeakyReLU, MaxPooling2D, UpSampling2D,
)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import img_to_array

# === Mount Google Drive ===
drive.mount('/content/drive', force_remount=True)

# === GPU Setup ===
# Memory growth can only be configured before TensorFlow initializes the GPUs. If this
# cell is re-run in a live notebook session after a model has touched the GPU,
# set_memory_growth raises RuntimeError; report it instead of aborting the cell.
physical_devices = tf.config.list_physical_devices('GPU')
for gpu in physical_devices:
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as err:
        print(f"Could not enable memory growth on {gpu.name}: {err}")

# === Paths ===
tumor_dir = '/content/drive/MyDrive/NewData/blurtumor'
no_tumor_dir = '/content/drive/MyDrive/NewData/blurnormal'
img_size = 256  # square side length every image is resized to

# === Image Loader ===
def load_images(folder, size=None):
    """Load every decodable image in *folder* as an RGB float array scaled to [-1, 1].

    Files are visited in sorted filename order. ``os.listdir`` returns entries in
    arbitrary order, and without sorting any seeded split of the result would not be
    reproducible across machines.

    Args:
        folder: Directory to read images from (non-recursive).
        size: Side length of the square output images. Defaults to the module-level
            ``img_size``.

    Returns:
        A list of float arrays of shape (size, size, 3) with values in [-1, 1].

    Raises:
        FileNotFoundError: If *folder* is not a directory, is empty, or contains no
            file OpenCV can decode as an image.
    """
    if not os.path.isdir(folder):
        raise FileNotFoundError(f"Folder issue: {folder}")
    filenames = sorted(os.listdir(folder))
    if not filenames:
        raise FileNotFoundError(f"Folder issue: {folder}")

    target = img_size if size is None else size
    images = []
    for filename in filenames:
        img = cv2.imread(os.path.join(folder, filename))
        if img is None:
            # Not decodable by OpenCV (sub-directory, hidden/metadata file, ...): skip it.
            continue
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV decodes as BGR
        img = cv2.resize(img, (target, target))
        images.append(img_to_array(img) / 127.5 - 1)  # Normalize to [-1, 1]

    if not images:
        raise FileNotFoundError(f"No readable images in: {folder}")
    return images

# === Load Tumor and NoTumor Images ===
tumor_imgs = load_images(tumor_dir)
no_tumor_imgs = load_images(no_tumor_dir)

# === Combine ===
all_imgs = np.array(tumor_imgs + no_tumor_imgs, dtype='float32')
# 1 = tumor, 0 = no tumor, in the same order as all_imgs. The autoencoder never sees
# these labels; they are only used to stratify the split so that both classes appear
# in the same proportion in the training and validation sets.
labels = np.array([1] * len(tumor_imgs) + [0] * len(no_tumor_imgs))

# === Split ===
x_train, x_test = train_test_split(
    all_imgs, test_size=0.2, random_state=42, stratify=labels,
)

# === SSIM + L1 Loss ===
def combined_loss(y_true, y_pred, ssim_weight=0.5):
    """Blend of SSIM loss and mean absolute error for images in [-1, 1].

    ``tf.image.ssim``'s ``max_val`` is the dynamic range of its inputs. For [-1, 1]
    data that range is 2.0, so passing 1.0 makes the SSIM stabilising constants four
    times too small. Negative intensities also let the luminance term
    (2*mu_x*mu_y + C1) / (mu_x^2 + mu_y^2 + C1) go negative when local means differ
    in sign. Both tensors are therefore mapped to [0, 1] before SSIM, where
    ``max_val=1.0`` is correct. The L1 term is still computed on the original scale.

    Args:
        y_true: Target images, values in [-1, 1].
        y_pred: Predicted images, values in [-1, 1].
        ssim_weight: Weight of the SSIM term. The L1 term gets ``1 - ssim_weight``.
            The default of 0.5 reproduces the original equal weighting.

    Returns:
        Scalar loss tensor.
    """
    y_true = tf.cast(y_true, y_pred.dtype)  # tf.image.ssim requires matching dtypes
    ssim = tf.image.ssim((y_true + 1.0) / 2.0, (y_pred + 1.0) / 2.0, max_val=1.0)
    ssim_loss = 1 - tf.reduce_mean(ssim)
    l1_loss = tf.reduce_mean(tf.abs(y_true - y_pred))
    return ssim_weight * ssim_loss + (1 - ssim_weight) * l1_loss

# === Encoder + Decoder ===
def build_encoder(input_img):
    """Downsampling half of the autoencoder.

    Each of the three stages applies Conv2D(3x3) -> BatchNormalization ->
    LeakyReLU(0.1) -> Dropout(0.3). It records that activation as a skip tensor and
    then halves the spatial size with 2x2 max pooling.

    Args:
        input_img: Keras input tensor.

    Returns:
        ``(bottleneck, skips)``. ``bottleneck`` is the tensor after the last pooling
        (32x32x256 for the default 256x256 input). ``skips`` holds the three
        pre-pooling activations, shallowest (64 channels) first.
    """
    skip_tensors = []
    features = input_img
    for n_filters in (64, 128, 256):
        conv_out = Conv2D(n_filters, (3, 3), padding='same')(features)
        normed = BatchNormalization()(conv_out)
        activated = LeakyReLU(0.1)(normed)
        features = Dropout(0.3)(activated)
        skip_tensors.append(features)
        features = MaxPooling2D((2, 2), padding='same')(features)
    return features, skip_tensors

def build_decoder(encoded, skips, filters=(256, 128, 64)):
    """Upsampling half of the autoencoder with additive skip connections.

    Each stage doubles the spatial size (UpSampling2D), then applies Conv2D(3x3) ->
    BatchNormalization -> LeakyReLU(0.1). While skip tensors remain, it adds the next
    one, deepest first. A final 3-channel tanh convolution produces output in [-1, 1].

    Args:
        encoded: Bottleneck tensor from the encoder.
        skips: Encoder skip tensors, shallowest first (as ``build_encoder`` returns
            them). The list is NOT modified.
        filters: Channel count per decoder stage. Each count must equal the channel
            count of the skip tensor it is added to, or ``Add`` will fail.

    Returns:
        Output tensor with 3 channels and tanh activation.
    """
    # Copy in reverse instead of calling skips.reverse(). The in-place version mutated
    # the caller's list, so a second call with the same list would pair the wrong skips.
    skips_deep_first = list(reversed(skips))
    x = encoded
    for i, f in enumerate(filters):
        x = UpSampling2D((2, 2))(x)
        x = Conv2D(f, (3, 3), padding='same')(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(0.1)(x)
        if i < len(skips_deep_first):
            x = Add()([x, skips_deep_first[i]])
    return Conv2D(3, (3, 3), activation='tanh', padding='same')(x)

def build_autoencoder(input_shape=(256, 256, 3)):
    """Assemble and compile the encoder-decoder model.

    The model is compiled with Adam (learning rate 1e-4) and ``combined_loss``. The
    encoder pools three times and the decoder upsamples three times. Height and width
    should therefore be divisible by 8, otherwise the skip ``Add`` layers see
    mismatched shapes.

    Args:
        input_shape: ``(height, width, channels)`` of the input images.

    Returns:
        A compiled ``tf.keras`` ``Model`` mapping images to reconstructions.
    """
    inputs = Input(shape=input_shape)
    bottleneck, skip_maps = build_encoder(inputs)
    outputs = build_decoder(bottleneck, skip_maps)
    autoencoder_model = Model(inputs, outputs)
    autoencoder_model.compile(optimizer=Adam(1e-4), loss=combined_loss)
    return autoencoder_model

# === Accuracy Display Callback ===
class AccuracyCallback(Callback):
    """Print the validation loss and a rough "accuracy" figure after each epoch.

    NOTE(review): the printed "accuracy" is just ``(1 - val_loss) * 100``. It is a
    readability aid derived from the reconstruction loss, not a classification
    accuracy, and it can go negative if the loss exceeds 1.
    """

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        val_loss = logs.get('val_loss')
        if val_loss is None:
            # fit() was called without validation data: nothing to report.
            return
        accuracy = (1 - val_loss) * 100
        print(f"Epoch {epoch+1}: Val Loss = {val_loss:.4f}, Accuracy ~ {accuracy:.2f}%")

# === Build + Train ===
autoencoder = build_autoencoder()
autoencoder.summary()

# The model is saved below as "best...". Without EarlyStopping, the weights written
# would simply be those left after the final epoch. restore_best_weights puts back the
# weights from the epoch with the lowest val_loss. patience=20 is larger than
# ReduceLROnPlateau's patience=5, so the learning rate can be halved a few times
# before training gives up.
# NOTE(review): in older tf.keras (Keras 2.x) the best weights are restored only if
# training actually stops early, not when all 300 epochs run. Confirm for the
# installed version.
training_callbacks = [
    ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5),
    EarlyStopping(monitor='val_loss', patience=20, restore_best_weights=True),
    AccuracyCallback(),
]

# Autoencoder: the input is also the target.
history = autoencoder.fit(
    x_train, x_train,
    epochs=300,
    batch_size=32,
    shuffle=True,
    validation_data=(x_test, x_test),
    callbacks=training_callbacks,
)

# === Save Model ===
# Reloading requires custom_objects={'combined_loss': combined_loss}.
model_path = '/content/drive/MyDrive/NewData/bestkidney_model.h5'
autoencoder.save(model_path)
print(f"Saved: {model_path}")

# === Plot ===
def plot_loss(history):
    """Plot training and validation loss per epoch from a Keras ``History`` object."""
    curves = history.history
    plt.figure(figsize=(10, 5))
    for metric_key, curve_label in (('loss', 'Train'), ('val_loss', 'Val')):
        plt.plot(curves[metric_key], label=curve_label)
    plt.legend()
    plt.title("Training vs Validation Loss")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.show()


plot_loss(history)

# === Postprocessing ===
def enhance_and_sharpen(img, clip_limit=2.0, tile_grid_size=(8, 8), gain=1.15):
    """Turn a [-1, 1] RGB reconstruction into a contrast-enhanced, sharpened uint8 image.

    Steps: rescale to uint8, convert to grayscale, apply CLAHE, multiply brightness by
    *gain*, then apply a 3x3 Laplacian-style sharpening kernel. Colour information is
    discarded. The grayscale result is replicated into 3 RGB channels.

    Args:
        img: Array of shape (H, W, 3) with values nominally in [-1, 1]. Values outside
            that range are clipped.
        clip_limit: CLAHE contrast clip limit.
        tile_grid_size: CLAHE tile grid as ``(cols, rows)``.
        gain: Brightness multiplier applied after CLAHE (result clipped to 255).

    Returns:
        uint8 array of shape (H, W, 3).
    """
    img = np.clip((img + 1) * 127.5, 0, 255).astype(np.uint8)
    gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
    bright = clahe.apply(gray)
    bright = np.clip(bright * gain, 0, 255).astype(np.uint8)
    # OpenCV documents the filter2D kernel as a floating-point matrix. A plain
    # np.array of ints is int64, which some OpenCV builds reject.
    kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]], dtype=np.float32)
    sharpened = cv2.filter2D(bright, -1, kernel)
    return cv2.cvtColor(sharpened, cv2.COLOR_GRAY2RGB)

def display_images(originals, reconstructions, n=5):
    """Show *n* originals (top row) above their post-processed reconstructions.

    Both inputs are expected in [-1, 1]. Originals are mapped to [0, 1] for ``imshow``.
    Reconstructions are passed through ``enhance_and_sharpen``.

    Args:
        originals: Sequence of (H, W, 3) images in [-1, 1].
        reconstructions: Sequence of model outputs aligned with *originals*.
        n: Number of pairs to show. It is clamped to the number of available pairs,
            so a small test split does not raise IndexError. Nothing is drawn if n <= 0.
    """
    n = min(n, len(originals), len(reconstructions))
    if n <= 0:
        return

    plt.figure(figsize=(12, 6))
    for i in range(n):
        plt.subplot(2, n, i + 1)
        plt.imshow((originals[i] + 1) / 2)
        plt.title("Original")
        plt.axis('off')

        enhanced = enhance_and_sharpen(reconstructions[i])
        plt.subplot(2, n, n + i + 1)
        plt.imshow(enhanced)
        plt.title("Enhanced")
        plt.axis('off')
    plt.tight_layout()
    plt.show()

# === Predict + Show ===
# Reconstruct the held-out images. The tanh output layer already bounds values to
# [-1, 1]; the clip is a cheap guard before post-processing.
recons = autoencoder.predict(x_test)
recons = np.clip(recons, -1, 1)
display_images(x_test, recons, n=5)
