# 1. Setup & Dependencies

In [None]:
!pip install torch torchvision torchmeta pillow pandas numpy transformers
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torchmeta.datasets.helpers import miniimagenet
from torchmeta.utils.data import BatchMetaDataLoader
from transformers import ViTModel, ViTConfig
import numpy as np

# 2. Simplified Data Loading (MiniImagenet stand-in for ISIC — replace with the real ISIC data pipeline)

In [None]:
# For prototyping: use MiniImagenet as a stand-in for ISIC.
# torchmeta's miniimagenet helper resizes images to 84x84 by default, but
# MetaRegViT's ViT is configured for 224x224 inputs (HF ViTModel raises on a
# size mismatch unless interpolate_pos_encoding is used), so override the
# transform to upsample every image to 224x224.
vit_input_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
train_dataset = miniimagenet(
    "data", ways=5, shots=5, test_shots=15, meta_train=True, download=True,
    transform=vit_input_transform,
)
# Each batch is a dict {"train": (x, y), "test": (x, y)} holding 2 tasks.
train_loader = BatchMetaDataLoader(train_dataset, batch_size=2, num_workers=2)

# 3. Model Implementation

In [None]:
class MetaRegViT(nn.Module):
    """Small ViT encoder whose leading attention heads are kept frozen.

    Args:
        num_base_classes: stored on the ViT config as ``num_classes``.
            NOTE(review): ``ViTModel`` is the bare encoder with no
            classification head, so this value does not affect the output;
            ``forward`` returns the ``hidden_size``-dim CLS embedding.
        mask_ratio: fraction of attention heads per layer left trainable. The
            remaining ``round(num_heads * (1 - mask_ratio))`` heads are frozen,
            i.e. 3 of 4 at the default 0.2. NOTE(review): this reading is
            inferred from the original "freeze 80% of heads" comment; confirm.
    """

    def __init__(self, num_base_classes=8, mask_ratio=0.2):
        super().__init__()
        self.mask_ratio = mask_ratio

        # Compact ViT (hidden 128, 4 layers, 4 heads). This is much smaller
        # than the standard ViT-Tiny config; the exact parameter count depends
        # on HF's default intermediate_size.
        config = ViTConfig(
            image_size=224, patch_size=16, num_classes=num_base_classes,
            hidden_size=128, num_hidden_layers=4, num_attention_heads=4
        )
        self.vit = ViTModel(config)

        self._freeze_heads()

    def _freeze_heads(self):
        """Freeze the first ``num_frozen_heads`` heads of every self-attention layer.

        ``requires_grad`` is per-tensor, and the query/key/value projections
        pack all heads into one ``Linear``. A single head therefore cannot be
        frozen by ``requires_grad``. Instead, a gradient hook zeroes the output
        rows belonging to the frozen heads. Head ``h`` owns rows
        ``[h * head_dim, (h + 1) * head_dim)``, because HF ViT splits the
        projection output head-major.

        Sets ``self.num_frozen_heads``.

        NOTE: optimizers with decoupled weight decay (e.g. AdamW) still shrink
        the masked rows. Use ``weight_decay=0`` for these tensors if exact
        freezing is required.
        """
        config = self.vit.config
        num_heads = config.num_attention_heads
        head_dim = config.hidden_size // num_heads
        num_frozen = round(num_heads * (1 - self.mask_ratio))
        num_frozen = min(num_heads, max(0, num_frozen))
        self.num_frozen_heads = num_frozen

        frozen_rows = num_frozen * head_dim
        if frozen_rows == 0:
            return

        def _mask_frozen_rows(grad):
            # Clone so the hook never mutates a gradient autograd may reuse.
            grad = grad.clone()
            grad[:frozen_rows] = 0
            return grad

        for layer in self.vit.encoder.layer:
            self_attention = layer.attention.attention
            for proj in (self_attention.query, self_attention.key, self_attention.value):
                proj.weight.register_hook(_mask_frozen_rows)
                if proj.bias is not None:
                    proj.bias.register_hook(_mask_frozen_rows)

    def forward(self, x):
        """Return the final-layer CLS-token embedding, shape (batch, hidden_size)."""
        return self.vit(x).last_hidden_state[:, 0]

# 4. MAML Training Loop (Phase 2)

In [None]:
def _set_grads_to_none(params):
    """Drop accumulated gradients so the next backward() starts from zero."""
    for p in params:
        p.grad = None


def _fomaml_task_grads(model, params, criterion, support, query, inner_lr, inner_steps):
    """Return first-order MAML query gradients for a single task.

    Takes ``inner_steps`` plain-SGD steps on the support set by updating
    ``params`` in place. It then backpropagates the query loss at the adapted
    weights and always restores the original parameter values before
    returning. Gradients come from ``backward()`` on the model's own
    parameters, so any gradient hooks registered on them apply to both the
    inner and the query step.

    Returns:
        A list aligned with ``params``. Each entry is that parameter's query
        gradient, or None if it received none.
    """
    support_x, support_y = support
    query_x, query_y = query
    backup = [p.detach().clone() for p in params]
    try:
        for _ in range(inner_steps):
            _set_grads_to_none(params)
            criterion(model(support_x), support_y).backward()
            with torch.no_grad():
                for p in params:
                    if p.grad is not None:
                        p.sub_(inner_lr * p.grad)
        _set_grads_to_none(params)
        criterion(model(query_x), query_y).backward()
        return [p.grad for p in params]
    finally:
        # Undo the task-specific adaptation so every task starts from the
        # shared meta-parameters.
        with torch.no_grad():
            for p, saved in zip(params, backup):
                p.copy_(saved)
        _set_grads_to_none(params)


def maml_train(model, train_loader, epochs=100, inner_lr=0.01, outer_lr=3e-4, inner_steps=1):
    """Meta-train ``model`` in place with first-order MAML (FOMAML).

    Each item of ``train_loader`` must be a mapping
    ``{"train": (x, y), "test": (x, y)}`` whose tensors carry a leading task
    axis. This is the torchmeta ``BatchMetaDataLoader`` layout:
    x is (tasks, examples, ...) and y is (tasks, examples).

    For every task, the model is adapted on its support ("train") split with
    ``inner_steps`` SGD steps of size ``inner_lr``. The query ("test") gradient
    at the adapted weights is collected and the adaptation is undone. The
    per-task query gradients are averaged and applied with one AdamW step
    (``lr=outer_lr``) per meta-batch.

    Only parameters with ``requires_grad=True`` are adapted and meta-updated.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=outer_lr)
    params = [p for p in model.parameters() if p.requires_grad]
    device = params[0].device if params else torch.device("cpu")
    model.train()

    for _ in range(epochs):
        for batch in train_loader:
            support_x, support_y = (t.to(device) for t in batch["train"])
            query_x, query_y = (t.to(device) for t in batch["test"])
            num_tasks = support_x.size(0)

            meta_grads = [None] * len(params)
            for task in range(num_tasks):
                task_grads = _fomaml_task_grads(
                    model, params, criterion,
                    (support_x[task], support_y[task]),
                    (query_x[task], query_y[task]),
                    inner_lr, inner_steps,
                )
                for i, grad in enumerate(task_grads):
                    if grad is not None:
                        meta_grads[i] = grad if meta_grads[i] is None else meta_grads[i] + grad

            # Mean meta-loss over the tasks in this meta-batch. Parameters that
            # never received a gradient keep grad=None so the optimizer skips them.
            for p, grad in zip(params, meta_grads):
                p.grad = None if grad is None else grad / num_tasks
            optimizer.step()

# 5. Continual Learning with EWC (Phase 3)

In [None]:
class EWC:
    """Elastic Weight Consolidation penalty around a parameter snapshot.

    The anchor parameters are snapshotted when the object is created.
    ``compute_fisher`` estimates a diagonal (empirical) Fisher from squared
    batch gradients. ``penalty`` returns
    ``fisher_lambda * sum_n sum(F_n * (theta_n - anchor_n) ** 2)``.
    """

    def __init__(self, model, fisher_lambda=1e3):
        self.model = model
        self.fisher_lambda = fisher_lambda
        self.fisher = {}
        # detach() is essential: a bare clone() stays on the autograd graph, so
        # d/dp (p - clone(p))**2 receives equal and opposite contributions and
        # the penalty's gradient would be identically zero.
        self.params = {n: p.detach().clone() for n, p in model.named_parameters()}

    @staticmethod
    def _unpack(batch):
        """Return ``(inputs, targets)`` for one batch.

        Accepts either an indexable ``(inputs, targets)`` pair or a torchmeta
        meta-batch dict. For a dict, the ``"train"`` split's leading task axis
        is flattened into the example axis.
        """
        if isinstance(batch, dict):
            inputs, targets = batch["train"]
            return inputs.flatten(0, 1), targets.flatten(0, 1)
        return batch[0], batch[1]

    def compute_fisher(self, dataset, num_batches=1):
        """Estimate the diagonal Fisher from the first ``num_batches`` batches (at least one).

        Each batch contributes the element-wise square of the gradient of its
        mean cross-entropy loss, and the contributions are averaged.
        Parameters whose ``.grad`` is None after backward are left out of
        ``self.fisher`` and hence out of the penalty. Gradients are cleared
        before and after, so stale gradients neither leak in nor leak into the
        next optimizer step. If ``dataset`` yields nothing, ``self.fisher`` is
        left unchanged.
        """
        criterion = nn.CrossEntropyLoss()
        totals = {}
        seen = 0
        for batch in dataset:
            inputs, targets = self._unpack(batch)
            self.model.zero_grad()
            criterion(self.model(inputs), targets).backward()
            for n, p in self.model.named_parameters():
                if p.grad is not None:
                    sq = p.grad.detach().pow(2)
                    totals[n] = totals[n] + sq if n in totals else sq
            seen += 1
            # Stop before pulling another (possibly expensive) batch.
            if seen >= num_batches:
                break
        self.model.zero_grad()
        if seen:
            self.fisher = {n: t / seen for n, t in totals.items()}

    def penalty(self):
        """Return the EWC penalty as a scalar tensor (zero before compute_fisher)."""
        terms = [
            (self.fisher[n] * (p - self.params[n]).pow(2)).sum()
            for n, p in self.model.named_parameters()
            if n in self.fisher
        ]
        if not terms:
            return torch.zeros(())
        return sum(terms) * self.fisher_lambda

# 6. Training Pipeline

In [None]:
def main():
    """Run the three-phase prototype pipeline end to end.

    Phase 1 (base training) is skipped. Phase 2 meta-trains with
    ``maml_train`` on the module-level ``train_loader``. Phase 3 estimates an
    EWC Fisher and runs a few mock incremental SGD steps under the EWC
    penalty.
    """
    # Phase 1: Base Training
    model = MetaRegViT(num_base_classes=8)
    criterion = nn.CrossEntropyLoss()

    # (Replace with actual base training loop)
    print("Skipping base training for prototyping...")

    # Phase 2: Meta-Learning
    maml_train(model, train_loader, epochs=10)

    # Phase 3: Continual Learning
    ewc = EWC(model)
    # The meta-loader yields {"train": (x, y), "test": (x, y)} dicts with a
    # leading task axis. Flatten that axis so EWC.compute_fisher receives plain
    # (inputs, targets) pairs, which it reads via batch[0] / batch[1].
    fisher_batches = (
        (batch["train"][0].flatten(0, 1), batch["train"][1].flatten(0, 1))
        for batch in train_loader
    )
    ewc.compute_fisher(fisher_batches)  # Mock Fisher

    # Mock incremental task: random images with placeholder labels.
    incremental_optim = optim.SGD(model.parameters(), lr=1e-4)
    for _ in range(5):  # 5 epochs
        inputs = torch.randn(2, 3, 224, 224)
        targets = torch.tensor([0, 1])
        loss = criterion(model(inputs), targets) + ewc.penalty()
        incremental_optim.zero_grad()
        loss.backward()
        incremental_optim.step()

    print("Training complete!")

if __name__ == "__main__":
    main()