<a href="https://colab.research.google.com/github/uakarsh/lightning-flash/blob/master/flash_examples/integrations/learn2learn/image_classification_imagenette_mini.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# adapted from https://github.com/learnables/learn2learn/blob/master/examples/vision/protonet_miniimagenet.py#L154

In [None]:
!pip install learn2learn
!pip install kornia
!pip install lightning-flash
!pip install 'lightning-flash[image]'

In [None]:
## Train file
!wget https://www.dropbox.com/s/9g8c6w345s2ek03/mini-imagenet-cache-train.pkl?dl=1

--2022-07-06 14:03:20--  https://www.dropbox.com/s/9g8c6w345s2ek03/mini-imagenet-cache-train.pkl?dl=1
Resolving www.dropbox.com (www.dropbox.com)... 162.125.11.18, 2620:100:601c:18::a27d:612
Connecting to www.dropbox.com (www.dropbox.com)|162.125.11.18|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: /s/dl/9g8c6w345s2ek03/mini-imagenet-cache-train.pkl [following]
--2022-07-06 14:03:20--  https://www.dropbox.com/s/dl/9g8c6w345s2ek03/mini-imagenet-cache-train.pkl
Reusing existing connection to www.dropbox.com:443.
HTTP request sent, awaiting response... 302 Found
Location: https://uc9143f4c9c9d1358c5dc2f8f70f.dl.dropboxusercontent.com/cd/0/get/BonlWaNWpPqfbmVfhuL9L_PdHqR-yBb5BiC0o5SoaVx-JVtoaYW_BVD74812lIdsNVjYR6CXoEEk85lQnLjClFCRWOBmtzcltTV9uxsJm-NguKfP-qGcUjvtYmVbIh_E1BVwTxorF9_TJAsCLYjCpNBwa30197VtA4X5QZ-DQpXSqw/file?dl=1# [following]
--2022-07-06 14:03:21--  https://uc9143f4c9c9d1358c5dc2f8f70f.dl.dropboxusercontent.com/cd/0/get/BonlWaNWpPqfb

In [None]:
## Validation File
!wget https://www.dropbox.com/s/ip1b7se3gij3r1b/mini-imagenet-cache-validation.pkl?dl=1

--2022-07-06 14:03:35--  https://www.dropbox.com/s/ip1b7se3gij3r1b/mini-imagenet-cache-validation.pkl?dl=1
Resolving www.dropbox.com (www.dropbox.com)... 162.125.6.18, 2620:100:601c:18::a27d:612
Connecting to www.dropbox.com (www.dropbox.com)|162.125.6.18|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: /s/dl/ip1b7se3gij3r1b/mini-imagenet-cache-validation.pkl [following]
--2022-07-06 14:03:35--  https://www.dropbox.com/s/dl/ip1b7se3gij3r1b/mini-imagenet-cache-validation.pkl
Reusing existing connection to www.dropbox.com:443.
HTTP request sent, awaiting response... 302 Found
Location: https://uc914baf6ffcf61a2dbe6da68796.dl.dropboxusercontent.com/cd/0/get/BomjGsUa6V-Fl0AxfbjFITON5qgxjVFWtqlEZP_XkKazHYdR86WgUNAz0N5_Y1HM52U8Q3MMNC4Nwii_ScnkYqAVUj88ezCAo79OrJ1yQtWDEFuj0dnLo9RW1VGFd04pZm9yx5YNdp1YJ5PM8cRK6TEbrsaCmJ77IFSu9BDouLzS6A/file?dl=1# [following]
--2022-07-06 14:03:36--  https://uc914baf6ffcf61a2dbe6da68796.dl.dropboxusercontent.com/cd/0/get/

In [None]:
!cp './mini-imagenet-cache-train.pkl?dl=1' './mini-imagenet-cache-train.pkl'
!cp './mini-imagenet-cache-validation.pkl?dl=1' './mini-imagenet-cache-validation.pkl'

In [None]:
!rm './mini-imagenet-cache-train.pkl?dl=1'
!rm './mini-imagenet-cache-validation.pkl?dl=1'

In [None]:
import warnings
from dataclasses import dataclass
from typing import Callable, Tuple, Union

# Must come *after* `import warnings` (calling it first raises NameError on a fresh kernel),
# but *before* the third-party imports so their import-time warnings are silenced too.
warnings.simplefilter("ignore")

import kornia.augmentation as Ka
import kornia.geometry as Kg
import learn2learn as l2l
import numpy as np
import torch
import torchvision
import torchvision.transforms as T
from PIL import Image
from torch import nn

import flash
from flash.core.data.io.input import DataKeys
from flash.core.data.io.input_transform import InputTransform
from flash.core.data.transforms import ApplyToKeys, kornia_collate
from flash.image import ImageClassificationData, ImageClassifier

In [None]:
# Load the MiniImagenet train/validation splits from the pickles placed in the working
# directory by the download cells above. `download=False`: nothing is fetched here.
train_dataset = l2l.vision.datasets.MiniImagenet(mode="train", root="./", download=False)
val_dataset = l2l.vision.datasets.MiniImagenet(mode="validation", root="./", download=False)

In [None]:
@dataclass
class ImageClassificationInputTransform(InputTransform):
    """Flash input transform for the MiniImagenet few-shot example.

    Attributes:
        image_size: ``(height, width)`` that every image is resized to.
        mean: Per-channel mean used by the final normalization step.
        std: Per-channel standard deviation used by the final normalization step.

    Both per-sample pipelines resize to ``image_size`` and normalize with ``mean``/``std``
    as their *last* step, so every stage feeds the model inputs on the same scale.
    Presumably Flash uses the ``train_``-prefixed hook for the training stage and the
    unprefixed hooks elsewhere (e.g. validation) -- confirm against the installed Flash version.
    """

    image_size: Tuple[int, int] = (196, 196)
    mean: Union[float, Tuple[float, float, float]] = (0.485, 0.456, 0.406)
    std: Union[float, Tuple[float, float, float]] = (0.229, 0.224, 0.225)

    def per_sample_transform(self):
        """Per-sample pipeline for the unprefixed stage: Kornia augmentations, then normalization."""
        return T.Compose(
            [
                ApplyToKeys(
                    DataKeys.INPUT,
                    T.Compose(
                        [
                            T.ToTensor(),
                            # Honour the configurable size instead of a hard-coded (196, 196).
                            Kg.Resize(self.image_size),
                            # SPATIAL
                            Ka.RandomHorizontalFlip(p=0.25),
                            Ka.RandomRotation(degrees=90.0, p=0.25),
                            Ka.RandomAffine(degrees=1 * 5.0, shear=1 / 5, translate=1 / 20, p=0.25),
                            Ka.RandomPerspective(distortion_scale=1 / 25, p=0.25),
                            # PIXEL-LEVEL
                            Ka.ColorJitter(brightness=1 / 30, p=0.25),  # brightness
                            Ka.ColorJitter(saturation=1 / 30, p=0.25),  # saturation
                            Ka.ColorJitter(contrast=1 / 30, p=0.25),  # contrast
                            Ka.ColorJitter(hue=1 / 30, p=0.25),  # hue
                            Ka.RandomMotionBlur(kernel_size=2 * (4 // 3) + 1, angle=1, direction=1.0, p=0.25),
                            Ka.RandomErasing(scale=(1 / 100, 1 / 50), ratio=(1 / 20, 1), p=0.25),
                            # Normalize last, like the training pipeline; previously this stage was
                            # never normalized, so its inputs were on a different scale.
                            T.Normalize(self.mean, self.std),
                        ]
                    ),
                ),
                ApplyToKeys(DataKeys.TARGET, torch.as_tensor),
            ]
        )

    def train_per_sample_transform(self):
        """Per-sample training pipeline: torchvision augmentations, then normalization."""
        return T.Compose(
            [
                ApplyToKeys(
                    DataKeys.INPUT,
                    T.Compose(
                        [
                            T.ToTensor(),
                            T.Resize(self.image_size),
                            T.RandomHorizontalFlip(),
                            # NOTE: with no arguments ColorJitter leaves the image unchanged.
                            T.ColorJitter(),
                            T.RandomAutocontrast(),
                            T.RandomPerspective(),
                            # Normalize *last*: RandomAutocontrast rescales each channel of a float
                            # tensor back into [0, 1], which would undo an earlier Normalize on the
                            # images it fires on.
                            T.Normalize(self.mean, self.std),
                        ]
                    ),
                ),
                ApplyToKeys(DataKeys.TARGET, torch.as_tensor),
            ]
        )

    def per_batch_transform_on_device(self):
        """Batch-level Kornia horizontal flip applied to the input key on the device."""
        # NOTE(review): this hook has no ``train_`` prefix, so it presumably also runs on
        # validation batches; rename to ``train_per_batch_transform_on_device`` if unintended.
        return ApplyToKeys(DataKeys.INPUT, Ka.RandomHorizontalFlip(p=0.25))

    def collate(self):
        """Collate samples with Flash's ``kornia_collate``."""
        return kornia_collate


In [None]:
# Wrap the MiniImagenet image tensors and integer labels in a Flash datamodule. Both splits use
# the same InputTransform class (passed uninstantiated) and one sample per batch.
datamodule = ImageClassificationData.from_tensors(
    train_data=train_dataset.x,
    val_data=val_dataset.x,
    train_targets=torch.from_numpy(train_dataset.y.astype(int)),
    val_targets=torch.from_numpy(val_dataset.y.astype(int)),
    train_transform=ImageClassificationInputTransform,
    val_transform=ImageClassificationInputTransform,
    batch_size=1,
)

In [None]:
# Prototypical-network few-shot classifier on a ResNet-18 backbone (the strategy is provided by
# learn2learn). Training tasks are `datamodule.num_classes`-way, 1-shot; test tasks are 5-way,
# 1-shot with `test_queries=15`. Optimized with Adam at lr 1e-3.
model = ImageClassifier(
    backbone="resnet18",
    training_strategy="prototypicalnetworks",
    training_strategy_kwargs=dict(
        epoch_length=10 * 16,
        meta_batch_size=1,
        num_tasks=200,
        test_num_tasks=2000,
        ways=datamodule.num_classes,
        shots=1,
        test_ways=5,
        test_shots=1,
        test_queries=15,
    ),
    optimizer=torch.optim.Adam,
    learning_rate=1e-3,
)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /home/studio-lab-user/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


  0%|          | 0.00/44.7M [00:00<?, ?B/s]

Using 'prototypicalnetworks' provided by learnables/learn2learn (https://github.com/learnables/learn2learn).


In [None]:
# Fine-tune all weights ("no_freeze") for one epoch. Use one GPU with 16-bit mixed precision
# when CUDA is available; otherwise run on CPU at 32-bit instead of failing at Trainer
# construction (`gpus=1` is rejected on a machine without a GPU).
use_cuda = torch.cuda.is_available()
trainer = flash.Trainer(
    max_epochs=1,
    gpus=1 if use_cuda else 0,
    precision=16 if use_cuda else 32,
)
trainer.finetune(model, datamodule=datamodule, strategy="no_freeze")

Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: /home/studio-lab-user/notebooks/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type               | Params
-----------------------------------------------------
0 | train_metrics | ModuleDict         | 0     
1 | val_metrics   | ModuleDict         | 0     
2 | test_metrics  | ModuleDict         | 0     
3 | adapter       | Learn2LearnAdapter | 11.2 M
-----------------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
22.419    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]