In [1]:
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor, Normalize
import numpy as np

In [2]:
# Import necessary libraries
import torch
from torchvision.datasets import FashionMNIST
from torchvision import transforms
from torch.utils.data import DataLoader, Subset
import numpy as np

# Define a function to split the FashionMNIST dataset into tasks
def create_fashion_mnist_tasks(root_dir, num_tasks=5, transform=None):
    # Load the entire FashionMNIST dataset
    full_dataset = FashionMNIST(root=root_dir, train=True, download=True, transform=transform)

    # Determine the number of classes per task
    classes_per_task = len(full_dataset.classes) // num_tasks

    # Create tasks by splitting the dataset
    tasks = []
    for task_idx in range(num_tasks):
        # Calculate class indices for the current task
        class_start = task_idx * classes_per_task
        class_end = class_start + classes_per_task
        task_classes = list(range(class_start, class_end))

        # Find indices of images belonging to the current task's classes
        task_indices = [i for i, (_, label) in enumerate(full_dataset) if label in task_classes]

        # Create a Subset for the current task
        task_dataset = Subset(full_dataset, task_indices)
        tasks.append(task_dataset)

    return tasks

# Define the transform
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Create tasks
root_dir = './data'
tasks = create_fashion_mnist_tasks(root_dir, num_tasks=5, transform=transform)

# Preview the number of samples in each task
task_sizes = [len(task) for task in tasks]
task_sizes


Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:02<00:00, 11923559.03it/s]


Extracting ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 202936.77it/s]


Extracting ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:01<00:00, 3728233.31it/s]


Extracting ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<00:00, 5369877.39it/s]


Extracting ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw



[12000, 12000, 12000, 12000, 12000]

In [3]:
device = ('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using {device} device.")

Using cpu device.


In [4]:
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = nn.functional.relu(nn.functional.max_pool2d(self.conv1(x), 2))
        x = nn.functional.relu(nn.functional.max_pool2d(self.conv2(x), 2))
        x = x.view(x.size(0), -1)
        x = nn.functional.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [5]:
def maml_update(model, optimizer, loss_fn, data_loader, steps=1, alpha=0.001):
    # Create a copy of the model's initial state
    initial_state = {name: param.clone() for name, param in model.named_parameters()}

    # Task-specific update
    for step in range(steps):
        for inputs, labels in data_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()

            # Apply the update to the model (simulating gradient descent)
            with torch.no_grad():
                for name, param in model.named_parameters():
                    param -= alpha * param.grad

    # Update the meta-model parameters
    meta_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    meta_loss = 0
    for inputs, labels in data_loader:
        outputs = model(inputs)
        meta_loss += loss_fn(outputs, labels)

    meta_optimizer.zero_grad()
    meta_loss.backward()
    meta_optimizer.step()

    # Restore the model to its initial state
    with torch.no_grad():
        for name, param in model.named_parameters():
            param.copy_(initial_state[name])

In [7]:
model = SimpleCNN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [8]:
# Assuming tasks is a list of DataLoader instances, each representing a task
for task_id, task_loader in enumerate(tasks):
    print(f"Training on task {task_id}")
    maml_update(model, optimizer, nn.CrossEntropyLoss(), task_loader)

Training on task 0


RuntimeError: mat1 and mat2 shapes cannot be multiplied (64x49 and 3136x128)

In [None]:
def evaluate_model(model, tasks):
    model.eval()
    total_accuracy = 0
    for task_loader in tasks:
        correct = 0
        total = 0
        for inputs, labels in task_loader:
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        accuracy = correct / total
        total_accuracy += accuracy
        print(f"Task Accuracy: {accuracy}")
    print(f"Average Accuracy: {total_accuracy / len(tasks)}")

In [None]:
evaluate_model(model, tasks)