In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras import layers, Model
import numpy as np

# Load dataset
IMG_SIZE = 224
BATCH_SIZE = 32

dataset, info = tfds.load("tf_flowers", as_supervised=True, with_info=True)
train_ds = dataset['train']

# Preprocessing function (Resizing, Normalizing, One-hot encoding)
def preprocess(image, label):
    image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE)) / 255.0
    label = tf.one_hot(label, depth=info.features['label'].num_classes)
    return image, label

train_ds = train_ds.map(preprocess).batch(BATCH_SIZE).shuffle(1000)

# Data Augmentation
data_augmentation = tf.keras.Sequential([
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.1),
    layers.RandomZoom(0.1),
])

# Define Base Model (Feature Extractor)
base_model = EfficientNetB0(include_top=False, input_shape=(IMG_SIZE, IMG_SIZE, 3))
base_model.trainable = False  # Freeze feature extractor

# MAML Model
class MAMLModel(Model):
    def __init__(self, base_model):
        super(MAMLModel, self).__init__()
        self.base_model = base_model
        self.global_pool = layers.GlobalAveragePooling2D()
        self.dense1 = layers.Dense(128, activation='relu')
        self.output_layer = layers.Dense(info.features['label'].num_classes, activation='softmax')

    def call(self, inputs, training=False):
        x = self.base_model(inputs, training=training)
        x = self.global_pool(x)
        x = self.dense1(x)
        return self.output_layer(x)

# Instantiate MAML Model
meta_model = MAMLModel(base_model)

# Loss & Optimizer
loss_fn = tf.keras.losses.CategoricalCrossentropy()
meta_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0005, clipnorm=1.0)  # Lower LR + Gradient Clipping

# MAML Training Step (Inner & Outer Loop)
@tf.function
def train_step(images, labels):
    # Inner loop
    with tf.GradientTape() as tape:
        predictions = meta_model(images, training=True)
        loss = loss_fn(labels, predictions)

    gradients = tape.gradient(loss, meta_model.trainable_variables)
    meta_optimizer.apply_gradients(zip(gradients, meta_model.trainable_variables))

    return loss

# Training Loop
EPOCHS = 10
for epoch in range(EPOCHS):
    total_loss = 0
    for images, labels in train_ds:
        images = data_augmentation(images)  # Apply augmentation
        loss = train_step(images, labels)
        total_loss += loss.numpy()

    print(f"Epoch {epoch+1}, Meta Loss: {total_loss/len(train_ds):.4f}")

# Fine-Tuning (Unfreeze Base Model)
base_model.trainable = True
fine_tune_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5)
meta_model.compile(optimizer=fine_tune_optimizer, loss=loss_fn, metrics=['accuracy'])

# Fine-tune for better accuracy
meta_model.fit(train_ds, epochs=5)

# Evaluate Model
test_loss, test_acc = meta_model.evaluate(train_ds)
print(f"Test Accuracy: {test_acc:.4f}")



Downloading and preparing dataset Unknown size (download: Unknown size, generated: Unknown size, total: Unknown size) to /root/tensorflow_datasets/tf_flowers/3.0.1...


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Generating splits...:   0%|          | 0/1 [00:00<?, ? splits/s]

Generating train examples...: 0 examples [00:00, ? examples/s]

Shuffling /root/tensorflow_datasets/tf_flowers/incomplete.MUGC31_3.0.1/tf_flowers-train.tfrecord*...:   0%|   …

Dataset tf_flowers downloaded and prepared to /root/tensorflow_datasets/tf_flowers/3.0.1. Subsequent calls will reuse this data.
Downloading data from https://storage.googleapis.com/keras-applications/efficientnetb0_notop.h5
[1m16705208/16705208[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step
Epoch 1, Meta Loss: 1.6163
Epoch 2, Meta Loss: 1.6274
Epoch 3, Meta Loss: 1.6122
Epoch 4, Meta Loss: 1.6085
Epoch 5, Meta Loss: 1.6071
Epoch 6, Meta Loss: 1.6030
Epoch 7, Meta Loss: 1.6023
Epoch 8, Meta Loss: 1.6016
Epoch 9, Meta Loss: 1.6011
Epoch 10, Meta Loss: 1.6012
Epoch 1/5
[1m115/115[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m132s[0m 445ms/step - accuracy: 0.2885 - loss: 1.5728
Epoch 2/5
[1m115/115[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m63s[0m 119ms/step - accuracy: 0.6606 - loss: 1.1489
Epoch 3/5
[1m115/115[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 122ms/step - accuracy: 0.7709 - loss: 0.8632
Epoch 4/5
[1m115/115[0m [32m━━━━━━━━━━━━━━━━━━━━

In [None]:
# Install necessary libraries
!pip install tensorflow-datasets higher

# Imports
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import tensorflow_datasets as tfds
import higher
import random
import numpy as np

# Check device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load TF-Flowers dataset from TensorFlow Datasets
print("Loading TF-Flowers...")
ds_train = tfds.load('tf_flowers', split='train', as_supervised=True)

# Convert to PyTorch Dataset
class FlowersDataset(Dataset):
    def __init__(self, tf_dataset, transform=None):
        self.data = list(tf_dataset)
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img, label = self.data[idx]
        # Ensure img is in the correct format (HWC) and data type (uint8)
        img = img.numpy()
        img = img.astype(np.uint8) # Cast to uint8

        # Check if image has 3 dimensions and the first dimension is not 1, 3, or 4
        if img.ndim == 3 and img.shape[0] not in (1, 3, 4):
            # Assuming channels-first, transpose to channels-last
            img = img.transpose(1, 2, 0)

        # Check if image has more than 4 channels and reduce if necessary
        if img.shape[-1] > 4:
            img = img[:,:,:3]  # Keep only the first 3 channels (RGB)

        img = transforms.ToPILImage()(img)
        if self.transform:
            img = self.transform(img)
        return img, label.numpy()

# Transform
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])

# Prepare dataset and shuffle
flowers_dataset = FlowersDataset(ds_train, transform=transform)

# Helper to create few-shot tasks
def create_task(dataset, num_classes=5, num_samples=5):
    class_indices = {}
    for idx, (_, label) in enumerate(dataset):
        label = int(label)
        if label not in class_indices:
            class_indices[label] = []
        class_indices[label].append(idx)

    selected_classes = random.sample(list(class_indices.keys()), num_classes)
    task_samples = []

    for cls in selected_classes:
        cls_samples = random.sample(class_indices[cls], num_samples * 2)  # support + query
        support_set = cls_samples[:num_samples]
        query_set = cls_samples[num_samples:]
        task_samples.append((support_set, query_set))

    return task_samples

# Load Pretrained Model
model = models.resnet18(pretrained=True)
# Change the final layer later per task; base model stores weights
model.to(device)

# Define MAML Training loop components
meta_optimizer = optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss()

# In-depth MAML Training loop
epochs = 10  # Number of meta-training epochs
inner_steps = 5  # Gradient steps in the inner loop
inner_lr = 0.01  # Inner loop learning rate
num_classes = 5  # Number of classes per task
num_samples = 5  # Shots per class (support set size)

print("Starting MAML meta-training...")

for epoch in range(epochs):
    model.train()
    meta_loss = 0.0

    for task_i in range(4):  # number of tasks per meta-update
        task = create_task(flowers_dataset, num_classes, num_samples)

        # Create task-specific model head
        task_model = models.resnet18(pretrained=True)
        # Store the in_features of the original fc layer
        in_features = task_model.fc.in_features
        # Replace the final fully connected layer with an Identity layer
        task_model.fc = nn.Identity()
        task_model.to(device)
        # Load state dict of the base model
        task_model.load_state_dict(model.state_dict(), strict=False)

        # Add a new fully connected layer for the task, using the stored in_features
        task_model.fc = nn.Linear(in_features, num_classes)
        task_model.to(device)

        # Optimizer for inner loop
        optimizer = optim.SGD(task_model.parameters(), lr=inner_lr)

        # Use higher to make the inner loop differentiable
        with higher.innerloop_ctx(task_model, optimizer, copy_initial_weights=False) as (fmodel, diffopt):

            # Inner-loop training on support set
            for support_set, _ in task:
                support_images, support_labels = [], []
                for idx in support_set:
                    img, lbl = flowers_dataset[idx]
                    support_images.append(img.unsqueeze(0))
                    support_labels.append(lbl)
                support_images = torch.cat(support_images).to(device)
                support_labels = torch.tensor(support_labels).to(device)

                support_preds = fmodel(support_images)
                support_loss = loss_fn(support_preds, support_labels)
                diffopt.step(support_loss)

            # Outer-loop evaluation on query set
            query_loss = 0.0
            for _, query_set in task:
                query_images, query_labels = [], []
                for idx in query_set:
                    img, lbl = flowers_dataset[idx]
                    query_images.append(img.unsqueeze(0))
                    query_labels.append(lbl)

                query_images = torch.cat(query_images).to(device)
                query_labels = torch.tensor(query_labels).to(device)

                query_preds = fmodel(query_images)
                query_loss += loss_fn(query_preds, query_labels)

            # Accumulate meta-loss across tasks
            meta_loss += query_loss

    meta_loss /= 4  # average over tasks
    meta_optimizer.zero_grad()
    meta_loss.backward()
    meta_optimizer.step()

    print(f"Epoch {epoch + 1}/{epochs}, Meta Loss: {meta_loss.item():.4f}")

# Save meta-trained model
torch.save(model.state_dict(), 'maml_tf_flowers.pth')
print("Model saved as maml_tf_flowers.pth")

Using device: cuda
Loading TF-Flowers...




Starting MAML meta-training...




Epoch 1/10, Meta Loss: 16.3598




Epoch 2/10, Meta Loss: 16.3026




Epoch 3/10, Meta Loss: 16.9612




Epoch 4/10, Meta Loss: 16.6634




Epoch 5/10, Meta Loss: 16.7027




Epoch 6/10, Meta Loss: 17.4599




Epoch 7/10, Meta Loss: 16.5216




Epoch 8/10, Meta Loss: 16.0793




Epoch 9/10, Meta Loss: 17.2957




Epoch 10/10, Meta Loss: 16.9557
Model saved as maml_tf_flowers.pth


In [None]:
!pip install higher
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import tensorflow_datasets as tfds
import higher
import random
import numpy as np

# Check device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load TF-Flowers dataset
ds_train = tfds.load('tf_flowers', split='train', as_supervised=True)

# Convert to PyTorch Dataset
class FlowersDataset(Dataset):
    def __init__(self, tf_dataset, transform=None):
        self.data = list(tf_dataset)
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img, label = self.data[idx]
        img = img.numpy()
        if img.ndim == 3 and img.shape[-1] not in (1, 3, 4):
            img = img[:,:,:3]  # Ensure RGB
        elif img.shape[-1] == 1:
            img = np.repeat(img, 3, axis=-1)
        img = transforms.ToPILImage()(img)
        if self.transform:
            img = self.transform(img)
        return img, int(label.numpy())

# Define transformations
transform = transforms.Compose([
    transforms.Resize((160, 160)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
])

flowers_dataset = FlowersDataset(ds_train, transform=transform)
data_loader = DataLoader(flowers_dataset, batch_size=32, shuffle=True)

# Few-shot task creator
def create_task(dataset, num_classes=5, num_samples=5):
    class_indices = {}
    for idx, (_, label) in enumerate(dataset):
        class_indices.setdefault(label, []).append(idx)
    selected_classes = random.sample(list(class_indices.keys()), num_classes)
    return [(random.sample(class_indices[cls], num_samples * 2), cls) for cls in selected_classes]

# Load ResNet50 model
model = models.resnet50(pretrained=True)
model.fc = nn.Identity()  # Remove final layer for feature extraction
model.to(device)

# Optimizers and loss function
meta_optimizer = optim.Adam(model.parameters(), lr=5e-5, weight_decay=1e-4)
loss_fn = nn.CrossEntropyLoss()

# Training hyperparameters
epochs = 20
inner_lr = 0.01
num_classes = 5
num_samples = 5

# MAML meta-training
print("Starting MAML meta-training...")
for epoch in range(epochs):
    model.train()
    meta_loss = 0.0

    for _ in range(5):  # More tasks per meta-update
        task = create_task(flowers_dataset, num_classes, num_samples)
        task_head = nn.Linear(2048, num_classes).to(device)
        task_optimizer = optim.SGD(task_head.parameters(), lr=inner_lr)

with higher.innerloop_ctx(task_head, task_optimizer, copy_initial_weights=True) as (fhead, diffopt):
    for support_set, _ in task:
        support_images, support_labels = zip(*[flowers_dataset[idx] for idx in support_set[:num_samples]])
        support_images = torch.stack([model(img.unsqueeze(0).to(device)).detach() for img in support_images])
        support_labels = torch.tensor(support_labels).to(device)

        # Remove one-hot encoding for support_labels
        # support_labels = nn.functional.one_hot(support_labels, num_classes=num_classes).float() # Convert to one-hot and float

        diffopt.step(loss_fn(fhead(support_images), support_labels))

    query_loss = 0.0
    for _, query_set in task:
        query_images, query_labels = zip(*[flowers_dataset[idx] for idx in query_set[num_samples:]])
        query_images = torch.stack([model(img.unsqueeze(0).to(device)).detach() for img in query_images])
        query_labels = torch.tensor(query_labels).to(device)

        # Remove one-hot encoding for query_labels
        # query_labels = nn.functional.one_hot(query_labels, num_classes=num_classes).float() # Convert to one-hot and float

        query_loss += loss_fn(fhead(query_images), query_labels)
    meta_loss += query_loss

    meta_loss /= 5
    meta_optimizer.zero_grad()
    meta_loss.backward()
    meta_optimizer.step()
    print(f"Epoch {epoch + 1}/{epochs}, Meta Loss: {meta_loss.item():.4f}")

# Save best model
torch.save(model.state_dict(), 'best_maml_tf_flowers_resnet50.pth')
print("Meta-trained model saved.")

Using device: cuda




Starting MAML meta-training...


RuntimeError: Expected target size [5, 5], got [5]