# 1. Setup & Dependencies

In [None]:
!pip install tensorflow tensorflow-addons isic-api
import tensorflow as tf
from tensorflow.keras import layers, Model
import numpy as np
import tensorflow_addons as tfa

In [None]:
# 2. Data Pipeline (Mock ISIC Data - Replace with Actual API Calls)

In [None]:
class ISICLoader:
    """Mock episodic data source standing in for a real ISIC-backed loader.

    Every call to :meth:`get_meta_batch` fabricates random images and random
    integer labels; no dataset or API is contacted.

    Args:
        ways: Number of classes per task. Labels are drawn from
            ``[0, ways)``.
        shots: Number of support images generated per task.
        query_shots: Number of query images generated per task.
        image_size: Height and width (in pixels) of the generated square
            RGB images.
    """

    def __init__(self, ways=5, shots=5, *, query_shots=15, image_size=224):
        # Replace with ISIC API calls
        self.num_classes = ways
        self.shots = shots
        self.query_shots = query_shots
        self.image_size = image_size

    def get_meta_batch(self, batch_size=4):
        """Generate one batch of mock meta-learning tasks.

        Args:
            batch_size: Number of tasks in the batch.

        Returns:
            A 4-tuple ``(support_x, support_y, query_x, query_y)`` where the
            image tensors have shape
            ``(batch_size, n, image_size, image_size, 3)`` and the int32
            label tensors have shape ``(batch_size, n)``, with ``n`` equal to
            ``shots`` for the support set and ``query_shots`` for the query
            set.
        """
        image_shape = (self.image_size, self.image_size, 3)

        support_x = tf.random.normal((batch_size, self.shots, *image_shape))
        support_y = tf.random.uniform(
            (batch_size, self.shots),
            minval=0,
            maxval=self.num_classes,
            dtype=tf.int32,
        )
        query_x = tf.random.normal(
            (batch_size, self.query_shots, *image_shape)
        )
        query_y = tf.random.uniform(
            (batch_size, self.query_shots),
            minval=0,
            maxval=self.num_classes,
            dtype=tf.int32,
        )
        return support_x, support_y, query_x, query_y

# 3. Model Implementation

In [None]:
class MaskedMultiHeadAttention(layers.MultiHeadAttention):
    """Multi-head attention that zeroes the output of a fixed subset of heads.

    A static 0/1 mask over heads is built once; heads whose index is at or
    above ``(1 - mask_ratio) * num_heads`` keep their output, all others are
    multiplied by zero. At least one head is always kept so the layer never
    degenerates to an all-zero attention output (which is what the default
    ``mask_ratio=0.2`` would otherwise produce with 4 heads).

    Args:
        mask_ratio: Fraction of heads (taken from the highest indices) whose
            output is kept.
        **kwargs: Forwarded unchanged to ``layers.MultiHeadAttention``.
    """

    def __init__(self, mask_ratio=0.2, **kwargs):
        super().__init__(**kwargs)
        self.mask_ratio = mask_ratio

    def _build_attention(self, rank):
        """Build the parent attention internals, then the static head mask."""
        super()._build_attention(rank)
        num_heads = self._num_heads
        threshold = (1 - self.mask_ratio) * num_heads
        head_mask = [int(index >= threshold) for index in range(num_heads)]
        if num_heads and not any(head_mask):
            # The threshold rounded every head away; keep the last one so
            # the layer still passes information through.
            head_mask[-1] = 1
        self._trainable_heads = head_mask

    def _compute_attention(self, *args, **kwargs):
        """Run the parent attention and zero the masked heads' outputs.

        Returns:
            The ``(attention_output, attention_scores)`` pair produced by the
            parent class, with ``attention_output`` multiplied by the head
            mask. ``attention_scores`` is passed through untouched.
        """
        attention_output, attention_scores = super()._compute_attention(
            *args, **kwargs
        )
        # Zero out non-trainable heads (simplified static masking).
        # NOTE(review): the broadcast assumes attention_output is laid out as
        # (batch, seq, heads, key_dim), i.e. a single attention axis — confirm
        # before using custom `attention_axes`.
        head_mask = tf.constant(
            self._trainable_heads, dtype=attention_output.dtype
        )
        attention_output = attention_output * head_mask[None, None, :, None]
        return attention_output, attention_scores

    def get_config(self):
        """Return the layer config including ``mask_ratio``."""
        config = super().get_config()
        config["mask_ratio"] = self.mask_ratio
        return config

class _SelfAttentionBlock(layers.Layer):
    """One encoder stage: masked self-attention followed by a Dense layer."""

    def __init__(self, num_heads, key_dim, units, **kwargs):
        super().__init__(**kwargs)
        self.attention = MaskedMultiHeadAttention(
            num_heads=num_heads, key_dim=key_dim
        )
        self.ffn = layers.Dense(units)  # FFN

    def call(self, tokens):
        # MultiHeadAttention needs both query and value; self-attention uses
        # the same token sequence for both.
        attended = self.attention(tokens, tokens)
        return self.ffn(attended)


class MetaRegViT(Model):
    """Small ViT-style classifier built from masked-attention encoder stages.

    Images are cut into non-overlapping patches by a strided convolution,
    the patch grid is flattened into a token sequence, passed through four
    attention + Dense stages, and the first token is fed to a Dense
    classification head that outputs logits.

    Args:
        num_base_classes: Number of output logits.
    """

    def __init__(self, num_base_classes=8):
        super().__init__()
        self.patch_size = 16
        self.embed_dim = 128

        # Input preprocessing: one conv step per patch, no overlap.
        self.patch_extract = layers.Conv2D(
            self.embed_dim,
            (self.patch_size, self.patch_size),
            strides=self.patch_size,
            activation='linear',
        )

        # Transformer Encoder with Masked Attention. Each stage is wrapped in
        # a single-input layer so that Sequential can chain them; placing
        # MultiHeadAttention directly in a Sequential fails because it is
        # called with one argument but requires query and value.
        self.encoder = tf.keras.Sequential([
            _SelfAttentionBlock(num_heads=4, key_dim=32, units=self.embed_dim)
            for _ in range(4)
        ])

        # Classification head
        self.head = layers.Dense(num_base_classes)

    def call(self, inputs):
        """Compute class logits for a batch of images.

        Args:
            inputs: Image batch; assumed to be ``(batch, height, width, 3)``
                with the Conv2D default channels-last layout — TODO confirm
                against callers.

        Returns:
            Tensor of shape ``(batch, num_base_classes)`` holding logits.
        """
        # Patch embedding
        patches = self.patch_extract(inputs)
        batch = tf.shape(patches)[0]
        # Flatten the patch grid; -1 lets the sequence length follow the
        # input resolution instead of assuming 224x224 images.
        tokens = tf.reshape(patches, (batch, -1, self.embed_dim))

        # Transformer encoder
        encoded = self.encoder(tokens)

        # No dedicated CLS token is prepended; the first patch token is used
        # as the sequence summary for prediction.
        return self.head(encoded[:, 0])

# 4. MAML Training (Phase 2)

In [None]:
class MAMLTrainer:
    """First-order MAML-style trainer for a Keras model.

    For every task the model takes one gradient step on the support set,
    the query loss is evaluated at those adapted weights, and the resulting
    gradients are summed over tasks and applied to the original weights.
    Because the adapted weights are written into the model variables, the
    outer gradient does not differentiate through the inner step (first-order
    approximation).

    Args:
        model: Callable Keras model returning logits.
        inner_lr: Step size of the per-task inner update.
        optimizer: Optimizer for the outer update. Defaults to Adam with its
            default settings; may also be replaced later by assigning to the
            ``optimizer`` attribute.
    """

    def __init__(self, model, inner_lr=0.01, optimizer=None):
        self.model = model
        self.inner_lr = inner_lr
        self.loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        if optimizer is None:
            optimizer = tf.keras.optimizers.Adam()
        self.optimizer = optimizer

    def train_step(self, support_x, support_y, query_x, query_y):
        """Run one meta-update over a batch of tasks.

        Args:
            support_x: Support inputs, indexed by task along axis 0.
            support_y: Support labels, indexed by task along axis 0.
            query_x: Query inputs, indexed by task along axis 0.
            query_y: Query labels, indexed by task along axis 0.

        Returns:
            Scalar tensor: the sum of the per-task query losses.
        """
        num_tasks = support_x.shape[0]  # Batch dimension
        total_loss = tf.constant(0.0)
        if not num_tasks:
            return total_loss

        variables = None
        meta_grads = None

        for task in range(num_tasks):
            # Inner loop: one gradient step on the support set.
            with tf.GradientTape() as inner_tape:
                support_logits = self.model(support_x[task])
                support_loss = self.loss_fn(support_y[task], support_logits)

            if variables is None:
                # Read after the first forward pass so a lazily built model
                # has created its weights.
                variables = list(self.model.trainable_weights)
                meta_grads = [None] * len(variables)

            inner_grads = inner_tape.gradient(support_loss, variables)
            snapshots = [tf.identity(var) for var in variables]

            try:
                # Apply temporary (task-adapted) weights.
                for var, grad in zip(variables, inner_grads):
                    if grad is not None:
                        var.assign_sub(
                            self.inner_lr * tf.convert_to_tensor(grad)
                        )

                # Outer loop: query loss evaluated at the adapted weights.
                with tf.GradientTape() as outer_tape:
                    query_logits = self.model(query_x[task])
                    query_loss = self.loss_fn(query_y[task], query_logits)
                task_grads = outer_tape.gradient(query_loss, variables)
            finally:
                # Always restore the shared initialization, even if the
                # query pass raised.
                for var, saved in zip(variables, snapshots):
                    var.assign(saved)

            total_loss = total_loss + query_loss
            for index, grad in enumerate(task_grads):
                if grad is None:
                    continue
                grad = tf.convert_to_tensor(grad)
                if meta_grads[index] is None:
                    meta_grads[index] = grad
                else:
                    meta_grads[index] = meta_grads[index] + grad

        # Update base model, skipping variables that received no gradient.
        updates = [
            (grad, var)
            for grad, var in zip(meta_grads, variables)
            if grad is not None
        ]
        if updates:
            self.optimizer.apply_gradients(updates)
        return total_loss

# 5. Continual Learning with EWC (Phase 3)

In [None]:
class EWCRegularizer:
    """Elastic Weight Consolidation penalty for a Keras model.

    :meth:`compute_fisher` estimates a diagonal Fisher term from squared loss
    gradients and snapshots the current trainable weights. Calling the
    instance then returns ``fisher_lambda`` times the Fisher-weighted squared
    distance between the current weights and that snapshot.

    Args:
        model: Keras model whose trainable weights are regularized.
        fisher_lambda: Multiplier applied to the penalty.
        loss_fn: Loss used for the Fisher estimate, called as
            ``loss_fn(y, logits)``. Defaults to sparse categorical
            cross-entropy on logits.
    """

    def __init__(self, model, fisher_lambda=1e3, loss_fn=None):
        self.model = model
        self.fisher = {}
        self.opt_states = {}
        self.lambda_ = fisher_lambda
        if loss_fn is None:
            loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
                from_logits=True
            )
        self.loss_fn = loss_fn

    def _named_trainable_weights(self):
        """Return ``(key, variable)`` pairs with keys unique per variable.

        The position is part of the key because variable names alone are not
        guaranteed to be unique across layers.
        """
        return [
            (f"{index}:{var.name}", var)
            for index, var in enumerate(self.model.trainable_weights)
        ]

    def compute_fisher(self, dataset, max_batches=100):
        """Estimate the Fisher diagonal and snapshot the current weights.

        Args:
            dataset: ``tf.data.Dataset`` yielding ``(x, y)`` elements; each
                element is fed to the model as one batch.
            max_batches: Maximum number of dataset elements to use.

        If ``dataset`` yields nothing, the stored Fisher terms and snapshot
        are cleared, so the penalty evaluates to zero.
        """
        fisher = None
        named = []
        num_batches = 0

        for x, y in dataset.take(max_batches):
            with tf.GradientTape() as tape:
                logits = self.model(x)
                loss = self.loss_fn(y, logits)

            if fisher is None:
                # Created after the first forward pass so a lazily built
                # model has its weights.
                named = self._named_trainable_weights()
                fisher = {name: tf.zeros_like(var) for name, var in named}

            grads = tape.gradient(loss, [var for _, var in named])
            for (name, _), grad in zip(named, grads):
                if grad is not None:
                    fisher[name] += tf.square(tf.convert_to_tensor(grad))
            num_batches += 1

        if fisher is None:
            self.fisher = {}
            self.opt_states = {}
            return

        # Average so the penalty scale does not grow with the batch count.
        self.fisher = {
            name: value / num_batches for name, value in fisher.items()
        }
        # Reference point the penalty pulls the weights back towards.
        self.opt_states = {name: tf.identity(var) for name, var in named}

    def __call__(self):
        """Return the scaled EWC penalty as a scalar tensor.

        Evaluates to zero when :meth:`compute_fisher` has not stored any
        Fisher terms.
        """
        penalty = tf.constant(0.0)
        for name, var in self._named_trainable_weights():
            if name not in self.fisher:
                continue
            drift = var - self.opt_states[name]
            penalty += tf.reduce_sum(self.fisher[name] * tf.square(drift))
        return self.lambda_ * penalty

# 6. Training Pipeline

In [None]:
def main():
    """Run the mock pipeline: meta-learning, then EWC-regularized training.

    All data is randomly generated; base pretraining (phase 1) is skipped.
    """
    # Initialize
    model = MetaRegViT()
    dataloader = ISICLoader()
    maml = MAMLTrainer(model)
    maml.optimizer = tf.keras.optimizers.Adam(3e-4)

    # Phase 1: Base Training (Simplified)
    print("Skipping base pretraining...")

    # Phase 2: Meta-Learning
    for epoch in range(10):
        support_x, support_y, query_x, query_y = dataloader.get_meta_batch()
        loss = maml.train_step(support_x, support_y, query_x, query_y)
        print(f"Meta-Epoch {epoch}: Loss={loss.numpy():.2f}")

    # Phase 3: Continual Learning with EWC
    ewc = EWCRegularizer(model)
    fisher_data = tf.data.Dataset.from_tensor_slices((
        tf.random.normal((100, 224, 224, 3)),  # Mock data
        tf.random.uniform((100,), 0, 8, dtype=tf.int32),
    )).batch(4)
    ewc.compute_fisher(fisher_data)

    # Mock incremental task training
    opt = tf.keras.optimizers.SGD(1e-4)
    for _ in range(5):
        with tf.GradientTape() as tape:
            logits = model(tf.random.normal((4, 224, 224, 3)))
            labels = tf.random.uniform((4,), 0, 8, dtype=tf.int32)
            per_sample_loss = tf.keras.losses.sparse_categorical_crossentropy(
                labels, logits, from_logits=True)
            # Reduce to a scalar before adding the penalty: adding it to the
            # per-sample vector would count the EWC term once per sample
            # when the tape sums the vector.
            loss = tf.reduce_mean(per_sample_loss) + ewc()
        grads = tape.gradient(loss, model.trainable_weights)
        opt.apply_gradients(zip(grads, model.trainable_weights))

    print("Training complete!")

# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()