In [2]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Input, Conv2D, Dense, GlobalAveragePooling2D, MaxPooling2D,
    BatchNormalization, Dropout, Concatenate, Add, LeakyReLU,
    SeparableConv2D, Activation, AveragePooling2D,
    GlobalMaxPooling2D, Reshape, Multiply
)
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, BackupAndRestore
from tensorflow.keras.regularizers import l2
from tensorflow_addons.optimizers import AdamW
from sklearn.metrics import classification_report

# Configuration
IMG_SIZE = 128
BATCH_SIZE = 4
NUM_CLASSES = 3
EPOCHS = 75
BASE_LR = 2e-4
WEIGHT_DECAY = 2e-5

# Enhanced GPU settings with mixed precision
policy = tf.keras.mixed_precision.Policy('mixed_float16')
tf.keras.mixed_precision.set_global_policy(policy)

gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    print(f"Using {len(gpus)} GPU(s) with mixed precision")
else:
    print("Using CPU with mixed precision")

# Enhanced attention mechanism with spatial attention
def cbam_block(input_tensor, ratio=8):
    channel_avg = GlobalAveragePooling2D()(input_tensor)
    channel_max = GlobalMaxPooling2D()(input_tensor)
    channel = Concatenate()([channel_avg, channel_max])
    channel = Reshape((1, 1, -1))(channel)
    channel = Dense(input_tensor.shape[-1] // ratio, activation='relu')(channel)
    channel = Dense(input_tensor.shape[-1], activation='sigmoid')(channel)
    
    channel_out = Multiply()([input_tensor, channel])
    
    spatial_avg = tf.reduce_mean(channel_out, axis=-1, keepdims=True)
    spatial_max = tf.reduce_max(channel_out, axis=-1, keepdims=True)
    spatial = Concatenate(axis=-1)([spatial_avg, spatial_max])
    spatial = Conv2D(1, (7, 7), padding='same', activation='sigmoid')(spatial)
    
    return Multiply()([channel_out, spatial])

# Enhanced residual block
def residual_block(x, filters, kernel_size=3, stride=1):
    shortcut = x
    
    if stride != 1 or x.shape[-1] != filters:
        shortcut = Conv2D(filters, (1, 1), strides=stride, padding='same')(x)
        shortcut = BatchNormalization()(shortcut)
    
    y = SeparableConv2D(filters, kernel_size, strides=stride, padding='same')(x)
    y = BatchNormalization()(y)
    y = LeakyReLU(0.1)(y)
    
    y = SeparableConv2D(filters, kernel_size, padding='same')(y)
    y = BatchNormalization()(y)
    y = cbam_block(y)
    
    output = Add()([shortcut, y])
    return LeakyReLU(0.1)(output)

# Simplified ensemble model
def create_ensemble_model():
    inputs = Input(shape=(IMG_SIZE, IMG_SIZE, 3))
    efficientnet = EfficientNetB0(weights='imagenet', include_top=False, input_tensor=inputs)
    for layer in efficientnet.layers[-40:]:
        layer.trainable = True
    eff_features = efficientnet.output
    
    x = Conv2D(32, (7, 7), strides=2, padding='same')(inputs)
    x = BatchNormalization()(x)
    x = LeakyReLU(0.1)(x)
    x = MaxPooling2D((3, 3), strides=2, padding='same')(x)
    x = residual_block(x, 64)
    x = residual_block(x, 128)
    
    eff_avg = GlobalAveragePooling2D()(eff_features)
    custom_avg = GlobalAveragePooling2D()(x)
    combined = Concatenate()([eff_avg, custom_avg])
    
    combined = Dense(512, kernel_regularizer=l2(WEIGHT_DECAY))(combined)
    combined = BatchNormalization()(combined)
    combined = LeakyReLU(0.1)(combined)
    combined = Dropout(0.3)(combined)
    outputs = Dense(NUM_CLASSES, activation='softmax', dtype='float32')(combined)
    
    return Model(inputs=inputs, outputs=outputs)

# Enhanced data augmentation
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=30,
    width_shift_range=0.3,
    height_shift_range=0.3,
    shear_range=0.2,
    zoom_range=0.3,
    horizontal_flip=True,
    vertical_flip=True,
    validation_split=0.2,
    fill_mode='nearest',
    brightness_range=[0.7, 1.3],
    channel_shift_range=20.0
)

validation_datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=0.2
)

# Load data
train_generator = train_datagen.flow_from_directory(
    'Train',
    target_size=(IMG_SIZE, IMG_SIZE),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    subset='training',
    shuffle=True
)

validation_generator = validation_datagen.flow_from_directory(
    'Train',
    target_size=(IMG_SIZE, IMG_SIZE),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    subset='validation',
    shuffle=False
)

# Create and compile model
base_optimizer = AdamW(learning_rate=BASE_LR, weight_decay=WEIGHT_DECAY)
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(base_optimizer)
model = create_ensemble_model()
model.compile(
    optimizer=optimizer,
    loss='categorical_crossentropy',
    metrics=['accuracy', tf.keras.metrics.AUC()]
)

# Enhanced callbacks (removed ModelCheckpoint due to issue)
callbacks = [
    EarlyStopping(monitor='val_accuracy', patience=15, restore_best_weights=True, verbose=1),
    ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5, min_lr=1e-7, verbose=1),
    BackupAndRestore(backup_dir='./backup')
]

# Gradient accumulation training function with manual checkpointing
def train_with_gradient_accumulation(model, train_generator, validation_generator, steps_per_epoch, validation_steps, epochs, callbacks):
    optimizer = model.optimizer
    loss_fn = tf.keras.losses.CategoricalCrossentropy()
    train_acc_metric = tf.keras.metrics.CategoricalAccuracy()
    val_acc_metric = tf.keras.metrics.CategoricalAccuracy()
    accum_steps = 4
    effective_batch_size = BATCH_SIZE * accum_steps
    best_val_acc = -float('inf')  # For manual checkpointing
    
    for epoch in range(epochs):
        print(f"\nEpoch {epoch + 1}/{epochs}")
        train_iterator = iter(train_generator)
        val_iterator = iter(validation_generator)
        
        train_acc_metric.reset_states()
        val_acc_metric.reset_states()
        total_train_loss = 0
        total_val_loss = 0
        
        for step in range(steps_per_epoch):
            gradients = [tf.zeros_like(var) for var in model.trainable_variables]
            step_train_loss = 0
            
            for _ in range(accum_steps):
                try:
                    x_batch, y_batch = next(train_iterator)
                except StopIteration:
                    train_iterator = iter(train_generator)
                    x_batch, y_batch = next(train_iterator)
                
                with tf.GradientTape() as tape:
                    logits = model(x_batch, training=True)
                    loss = loss_fn(y_batch, logits)
                    scaled_loss = optimizer.get_scaled_loss(loss)
                
                scaled_grads = tape.gradient(scaled_loss, model.trainable_variables)
                grads = optimizer.get_unscaled_gradients(scaled_grads)
                gradients = [g + acc_g for g, acc_g in zip(grads, gradients)]
                step_train_loss += loss
                train_acc_metric.update_state(y_batch, logits)
            
            step_train_loss /= accum_steps
            gradients = [g / accum_steps for g in gradients]
            optimizer.apply_gradients(zip(gradients, model.trainable_variables))
            total_train_loss += step_train_loss
            
            if (step + 1) % 10 == 0 or step == steps_per_epoch - 1:
                print(f"Step {step + 1}/{steps_per_epoch}, "
                      f"Loss: {step_train_loss:.4f}, "
                      f"Accuracy: {train_acc_metric.result():.4f}")
        
        for val_step in range(validation_steps):
            try:
                x_val, y_val = next(val_iterator)
            except StopIteration:
                val_iterator = iter(validation_generator)
                x_val, y_val = next(val_iterator)
            
            val_logits = model(x_val, training=False)
            val_loss = loss_fn(y_val, val_logits)
            total_val_loss += val_loss
            val_acc_metric.update_state(y_val, val_logits)
        
        avg_train_loss = total_train_loss / steps_per_epoch
        avg_val_loss = total_val_loss / validation_steps
        train_acc = train_acc_metric.result()
        val_acc = val_acc_metric.result()
        
        print(f"Epoch {epoch + 1} Summary: "
              f"Training Loss: {avg_train_loss:.4f}, "
              f"Training Accuracy: {train_acc:.4f}, "
              f"Validation Loss: {avg_val_loss:.4f}, "
              f"Validation Accuracy: {val_acc:.4f}")
        
        # Manual checkpointing
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            model.save_weights('best_model.h5')
            print(f"Saved best model weights with validation accuracy: {val_acc:.4f}")
        
        logs = {
            'loss': float(avg_train_loss),
            'accuracy': float(train_acc),
            'val_loss': float(avg_val_loss),
            'val_accuracy': float(val_acc)
        }
        for callback in callbacks:
            callback.on_epoch_end(epoch, logs=logs)

# Two-stage training with gradient accumulation
print("Stage 1: Feature extraction")
train_with_gradient_accumulation(
    model,
    train_generator,
    validation_generator,
    steps_per_epoch=train_generator.samples // BATCH_SIZE,
    validation_steps=validation_generator.samples // BATCH_SIZE,
    epochs=30,
    callbacks=callbacks
)

# Fine-tuning
print("Stage 2: Fine-tuning")
for layer in model.layers:
    layer.trainable = True

base_optimizer = AdamW(learning_rate=BASE_LR/10, weight_decay=WEIGHT_DECAY)
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(base_optimizer)
model.compile(
    optimizer=optimizer,
    loss='categorical_crossentropy',
    metrics=['accuracy', tf.keras.metrics.AUC()]
)

train_with_gradient_accumulation(
    model,
    train_generator,
    validation_generator,
    steps_per_epoch=train_generator.samples // BATCH_SIZE,
    validation_steps=validation_generator.samples // BATCH_SIZE,
    epochs=EPOCHS-30,
    callbacks=callbacks
)

# [Rest of your code for TTA and evaluation remains unchanged]
# ...
print("Training and evaluation complete!")

Using 1 GPU(s) with mixed precision
Found 27123 images belonging to 3 classes.
Found 6779 images belonging to 3 classes.
Stage 1: Feature extraction

Epoch 1/30
Step 10/6780, Loss: 1.7359, Accuracy: 0.2375
Step 20/6780, Loss: 1.3998, Accuracy: 0.3063
Step 30/6780, Loss: 0.9298, Accuracy: 0.3417
Step 40/6780, Loss: 1.1533, Accuracy: 0.3938
Step 50/6780, Loss: 0.8338, Accuracy: 0.4300
Step 60/6780, Loss: 1.4761, Accuracy: 0.4510
Step 70/6780, Loss: 0.9967, Accuracy: 0.4670
Step 80/6780, Loss: 0.9745, Accuracy: 0.4820
Step 90/6780, Loss: 0.9505, Accuracy: 0.4903
Step 100/6780, Loss: 0.5936, Accuracy: 0.5025
Step 110/6780, Loss: 0.8306, Accuracy: 0.5136
Step 120/6780, Loss: 0.5965, Accuracy: 0.5271
Step 130/6780, Loss: 1.2447, Accuracy: 0.5361
Step 140/6780, Loss: 0.6057, Accuracy: 0.5437
Step 150/6780, Loss: 0.8607, Accuracy: 0.5554
Step 160/6780, Loss: 0.8380, Accuracy: 0.5652
Step 170/6780, Loss: 0.7780, Accuracy: 0.5684
Step 180/6780, Loss: 0.6232, Accuracy: 0.5708
Step 190/6780, Loss:

KeyboardInterrupt: 