In [1]:
import learn2learn as l2l
import torch
import torch.nn as nn
import torchvision as tv
from tqdm.notebook import tqdm

In [2]:
# Small demo setup: a 784 -> 10 linear model, a learnable optimizer wrapping it,
# and a plain SGD optimizer over the learnable optimizer's own parameters.
linear = nn.Linear(in_features=784, out_features=10)
transform = l2l.optim.ModuleTransform(nn.Linear)
metaopt = l2l.optim.LearnableOptimizer(
    model=linear,
    transform=transform,
    lr=0.01,
)
opt = torch.optim.SGD(metaopt.parameters(), lr=0.001)


In [3]:
# Clear gradients on both optimizers, the learnable optimizer first.
for optimizer in (metaopt, opt):
    optimizer.zero_grad()


In [4]:
# CIFAR-FS training split rooted at ./data, with downloading turned off.
train_dataset = l2l.vision.datasets.CIFARFS(
    root='./data',
    mode='train',
    download=False,
)

In [5]:
# Wrap the dataset in a MetaDataset. This cell rebinds `train_dataset` to the
# wrapper, so the isinstance guard keeps a re-run of the cell from wrapping an
# already-wrapped dataset a second time.
if not isinstance(train_dataset, l2l.data.MetaDataset):
    train_dataset = l2l.data.MetaDataset(train_dataset)


In [6]:
# Display the object stored in the wrapper's `dataset` attribute.
train_dataset.dataset

Dataset CIFARFS
    Number of datapoints: 38401
    Root location: ./data\cifarfs\processed\train

In [7]:
# Loop sizes for the commented-out skeleton below; no cell visible in this
# notebook reads them.
N_ITERATIONS = 1
N_TASKS = 1


# Commented-out loop skeleton (names aligned with the constants above):
# for iteration in range(N_ITERATIONS):
#     for task in range(N_TASKS):
#         pass


In [7]:
from lstmopt import Net, HypergradTransform, accuracy

# Seed torch's global RNG before the model is constructed.
torch.manual_seed(2134)

# Use CUDA when PyTorch reports it as available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model = Net()
model.to(device)

# Learnable optimizer for `model`, moved to the same device as the model.
metaopt = l2l.optim.LearnableOptimizer(model=model, transform=HypergradTransform, lr=0.1)
metaopt.to(device)

# Adam optimizes the learnable optimizer's own parameters.
opt = torch.optim.Adam(metaopt.parameters(), lr=3e-4)
loss = nn.NLLLoss()


In [8]:
 # Extra DataLoader arguments: one worker and pinned memory when CUDA is
 # available, nothing otherwise.
 # NOTE(review): this cell's code is indented by one space; kept as found.
 if torch.cuda.is_available():
     kwargs = {'num_workers': 1, 'pin_memory': True}
 else:
     kwargs = {}


In [10]:
# Both splits share one preprocessing pipeline: images are only converted to
# tensors (the Normalize step stays commented out, as before).
mnist_transform = tv.transforms.Compose([
    tv.transforms.ToTensor(),
    # tv.transforms.Normalize((0.1307,), (0.3081,))
])

# Training split: shuffled batches of 128.
mnist_train = tv.datasets.MNIST('./data/mnist', train=True, download=False,
                                transform=mnist_transform)
train_loader = torch.utils.data.DataLoader(
    mnist_train, batch_size=128, shuffle=True, **kwargs)

# Test split: same batch size, fixed order.
mnist_test = tv.datasets.MNIST('./data/mnist', train=False,
                               transform=mnist_transform)
test_loader = torch.utils.data.DataLoader(
    mnist_test, batch_size=128, shuffle=False, **kwargs)

In [11]:
def train():
    """Run one pass over ``train_loader``, stepping both optimizers per batch.

    Reads the notebook-level ``model``, ``metaopt``, ``opt``, ``loss``,
    ``train_loader`` and ``device``. Returns nothing.
    """
    # NOTE(review): the saved output of this cell ends in
    # "RuntimeError: mat1 dim 1 must match mat2 dim 0" after printing
    # torch.Size([128, 1, 28, 28]); presumably Net expects flattened
    # images -- confirm in lstmopt.
    model.train()
    batches = tqdm(train_loader, leave=False)
    for inputs, targets in batches:
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Reset gradients on the learnable optimizer, then on its optimizer.
        metaopt.zero_grad()
        opt.zero_grad()

        batch_loss = loss(model(inputs), targets)
        batch_loss.backward()

        # `opt` (built over metaopt.parameters()) steps first, then `metaopt`.
        opt.step()
        metaopt.step()


train()

  0%|          | 0/469 [00:00<?, ?it/s]

torch.Size([128, 1, 28, 28])


RuntimeError: mat1 dim 1 must match mat2 dim 0

In [None]:
def test(epoch=1):
    """Evaluate ``model`` on ``test_loader`` and print mean loss and accuracy.

    Per-batch results are weighted by batch size before averaging, so a
    shorter final batch does not count as much as a full one.

    Reads the notebook-level ``model``, ``loss``, ``accuracy``,
    ``test_loader`` and ``device``. Returns nothing.

    Args:
        epoch: Label printed at the start of the report line.
    """
    model.eval()
    total_loss = 0.0
    total_accuracy = 0.0
    n_samples = 0
    with torch.no_grad():
        for X, y in test_loader:
            X, y = X.to(device), y.to(device)
            preds = model(X)
            batch_size = y.size(0)

            # NLLLoss averages over the batch by default, so scale the batch
            # mean back up to a per-batch sum before accumulating.
            total_loss += loss(preds, y).item() * batch_size
            # assumes accuracy() returns a per-batch mean -- TODO confirm in lstmopt
            total_accuracy += float(accuracy(preds, y)) * batch_size
            n_samples += batch_size
    test_error = total_loss / n_samples
    test_accuracy = total_accuracy / n_samples
    print(f'Epoch: {epoch}, loss: {test_error}, accuracy: {test_accuracy}')

test()

In [None]:
# Print each parameter of the learnable optimizer in turn.
for meta_param in metaopt.parameters():
    print(meta_param)

In [None]:
# Alternate one training pass with one evaluation for a fixed number of epochs.
N_EPOCHS = 10
for epoch in range(N_EPOCHS):
    train()
    test(epoch)