In [15]:
import collections
import os
import warnings
from typing import Any, Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchmeta.datasets.helpers import omniglot
from torchmeta.modules import MetaBatchNorm2d, MetaConv2d, MetaLinear, MetaModule, MetaSequential
from torchmeta.utils.data import BatchMetaDataLoader
from torchmeta.utils.gradient_based import gradient_update_parameters
from tqdm import tqdm

warnings.filterwarnings("ignore")

### 모델 정의

In [16]:
# Classification model definition
class ConvNet(MetaModule):
    """Small convolutional classifier assembled from torchmeta meta-modules.

    The network is four identical conv blocks (see ``convBlock``) followed by
    one linear layer that maps ``hidden_size`` features to ``out_features``
    logits. Every sub-module accepts an optional ``params`` mapping, so the
    forward pass can be evaluated with externally supplied (adapted) weights.
    """

    def __init__(self, in_channels: int, out_features: int) -> None:
        """Build the conv stack and the linear head.

        Args:
            in_channels: Number of channels of the input images.
            out_features: Number of logits produced by the linear head.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_features = out_features
        self.hidden_size = 64

        # Only the first block consumes the image channels; the other three
        # blocks map hidden_size -> hidden_size.
        block_in_channels = [self.in_channels] + [self.hidden_size] * 3
        self.convs = MetaSequential(
            *[
                self.convBlock(n_in, self.hidden_size, 3)
                for n_in in block_in_channels
            ]
        )

        self.linear = MetaLinear(self.hidden_size, self.out_features)

    @classmethod
    def convBlock(cls, in_channels: int, out_channels: int, kernel_size: int) -> MetaSequential:
        """Return one conv -> batch-norm -> ReLU block.

        The convolution uses ``padding=1`` and ``stride=3``; there is no
        pooling layer, so the stride is what shrinks the spatial size.
        Batch norm is created with ``track_running_stats=False``.
        """
        conv = MetaConv2d(in_channels, out_channels, kernel_size, padding=1, stride=3)
        norm = MetaBatchNorm2d(out_channels, momentum=1.0, track_running_stats=False)
        activation = nn.ReLU()
        return MetaSequential(conv, norm, activation)

    def forward(
        self, x: torch.Tensor, params: Optional[collections.OrderedDict] = None
    ) -> torch.Tensor:
        """Compute logits for ``x``.

        Args:
            x: Batch of input images.
            params: Optional parameter mapping; the entries under the
                ``"convs"`` and ``"linear"`` prefixes are forwarded to the
                corresponding sub-modules.

        Returns:
            Tensor of logits with ``out_features`` entries per sample.
        """
        # The linear head takes hidden_size inputs, so the conv output must
        # flatten to hidden_size features per sample (original note:
        # [batch_size, hidden_size, 1, 1]).
        features = self.convs(x, params=self.get_subdict(params, "convs"))
        flat_features = features.flatten(start_dim=1)
        # [batch_size, self.out_features]
        logits = self.linear(flat_features, params=self.get_subdict(params, "linear"))
        return logits

### 훈련 및 테스트

In [26]:
# Training step definition -- this is the core of the MAML procedure.
def train_maml(
    device: str,
    task_batch_size: int,
    task_batch: Dict[str, List[torch.Tensor]],
    model: ConvNet,
    criterion: nn.CrossEntropyLoss,  # classification task
    optimizer: torch.optim.Adam,
    step_size: float = 0.4,
    first_order: bool = True,
) -> Tuple[float, float]:
    """Run one meta-training step over a batch of tasks.

    For every task the model is adapted with a single gradient step on the
    support set (inner loop), the adapted parameters are evaluated on the
    query set, and the query losses are accumulated. The accumulated loss is
    then back-propagated and ``optimizer`` takes one step (outer loop).

    Args:
        device: Device the batch tensors are moved to.
        task_batch_size: Divisor used to average the loss and accuracy;
            assumed to equal the number of tasks in ``task_batch``.
        task_batch: Mapping with ``"train"`` (support) and ``"test"`` (query)
            entries, each indexable as ``[inputs, targets]``.
        model: Network to meta-train.
        criterion: Loss used for both the inner and the outer objective.
        optimizer: Optimizer that updates the model's meta-parameters.
        step_size: Inner-loop learning rate forwarded to
            ``gradient_update_parameters``.
        first_order: Forwarded to ``gradient_update_parameters``.

    Returns:
        ``(accuracy, loss)``: query accuracy and query loss, each summed over
        the tasks and divided by ``task_batch_size``.
    """
    model.train()
    optimizer.zero_grad()

    # Leading dimension indexes tasks. Per the original notes the inputs are
    # [tasks, num_shots*num_ways, 1, 28, 28]; the targets hold one class index
    # per sample, i.e. presumably [tasks, num_shots*num_ways] -- TODO confirm.
    support_xs = task_batch["train"][0].to(device=device)
    support_ys = task_batch["train"][1].to(device=device)
    query_xs = task_batch["test"][0].to(device=device)
    query_ys = task_batch["test"][1].to(device=device)

    outer_loss = torch.tensor(0.0, device=device)
    accuracy = torch.tensor(0.0, device=device)
    # Run the inner loop per task and accumulate what the outer loop needs.
    for support_x, support_y, query_x, query_y in zip(support_xs, support_ys, query_xs, query_ys):
        # Inner loop: one adaptation step on the support set.
        support_pred = model(support_x)
        inner_loss = criterion(support_pred, support_y)

        params = gradient_update_parameters(
            model, inner_loss, step_size=step_size, first_order=first_order
        )

        # Outer objective: loss of the adapted parameters on the query set.
        query_pred = model(query_x, params=params)
        outer_loss += criterion(query_pred, query_y)

        # Accuracy is only reported, so keep it out of the autograd graph.
        with torch.no_grad():
            _, predicted_labels = torch.max(query_pred, dim=-1)
            accuracy += torch.mean(predicted_labels.eq(query_y).float())

    outer_loss.div_(task_batch_size)

    outer_loss.backward()
    optimizer.step()

    accuracy.div_(task_batch_size)
    return accuracy.item(), outer_loss.item()

In [27]:
# Evaluation step definition
def test_maml(
    device: str,
    task_batch_size: int,
    task_batch: Dict[str, List[torch.Tensor]],
    model: ConvNet,
    criterion: nn.CrossEntropyLoss,  # classification task
    step_size: float = 0.4,
) -> Tuple[float, float]:
    """Evaluate the model on a batch of tasks without updating it.

    For every task the model is adapted with a single gradient step on the
    support set and the adapted parameters are scored on the query set. No
    optimizer step is taken, so the model's own parameters are left untouched.

    Args:
        device: Device the batch tensors are moved to.
        task_batch_size: Divisor used to average the loss and accuracy;
            assumed to equal the number of tasks in ``task_batch``.
        task_batch: Mapping with ``"train"`` (support) and ``"test"`` (query)
            entries, each indexable as ``[inputs, targets]``.
        model: Network to evaluate.
        criterion: Loss used for both the adaptation step and the query score.
        step_size: Inner-loop learning rate forwarded to
            ``gradient_update_parameters``.

    Returns:
        ``(accuracy, loss)``: query accuracy and query loss, each summed over
        the tasks and divided by ``task_batch_size``.
    """
    model.eval()

    # Leading dimension indexes tasks. Per the original notes the inputs are
    # [tasks, num_shots*num_ways, 1, 28, 28]; the targets hold one class index
    # per sample, i.e. presumably [tasks, num_shots*num_ways] -- TODO confirm.
    support_xs = task_batch["train"][0].to(device=device)
    support_ys = task_batch["train"][1].to(device=device)
    query_xs = task_batch["test"][0].to(device=device)
    query_ys = task_batch["test"][1].to(device=device)

    outer_loss = torch.tensor(0.0, device=device)
    accuracy = torch.tensor(0.0, device=device)
    # Adapt per task, then score the adapted parameters on the query set.
    for support_x, support_y, query_x, query_y in zip(support_xs, support_ys, query_xs, query_ys):
        # The adaptation step differentiates the support loss, so gradients
        # must be enabled here even if the caller wrapped evaluation in
        # torch.no_grad().
        with torch.enable_grad():
            support_pred = model(support_x)
            inner_loss = criterion(support_pred, support_y)

            # Nothing is back-propagated through the adaptation during
            # evaluation, so the first-order update is sufficient.
            params = gradient_update_parameters(
                model, inner_loss, step_size=step_size, first_order=True
            )

        # The query loss is only reported, never back-propagated: computing it
        # under no_grad avoids building and retaining an autograd graph for
        # every task in the batch.
        with torch.no_grad():
            query_pred = model(query_x, params=params)
            outer_loss += criterion(query_pred, query_y)

            _, predicted_labels = torch.max(query_pred, dim=-1)
            accuracy += torch.mean(predicted_labels.eq(query_y).float())

    outer_loss.div_(task_batch_size)
    accuracy.div_(task_batch_size)
    return accuracy.item(), outer_loss.item()

In [28]:
def get_dataloader(
    config: Dict[str, Any]
) -> Tuple[BatchMetaDataLoader, BatchMetaDataLoader, BatchMetaDataLoader]:
    """Create the Omniglot meta-train, meta-validation and meta-test loaders.

    Args:
        config: Settings mapping; the keys read here are ``"folder_name"``,
            ``"num_shots"``, ``"num_ways"``, ``"download"`` and
            ``"task_batch_size"``.

    Returns:
        ``(train_dataloader, val_dataloader, test_dataloader)``.
    """

    def _build_loader(split_flag: str) -> BatchMetaDataLoader:
        # split_flag names the omniglot keyword (meta_train / meta_val /
        # meta_test) that is switched on for this split.
        split_dataset = omniglot(
            folder=config["folder_name"],
            shots=config["num_shots"],
            # test_shots is left unset (original note: default = shots)
            ways=config["num_ways"],
            shuffle=True,
            download=config["download"],
            **{split_flag: True},
        )
        return BatchMetaDataLoader(
            split_dataset,
            batch_size=config["task_batch_size"],
            shuffle=True,
            num_workers=1,
        )

    # Built in the same order as before: train, then validation, then test.
    train_dataloader = _build_loader("meta_train")
    val_dataloader = _build_loader("meta_val")
    test_dataloader = _build_loader("meta_test")
    return train_dataloader, val_dataloader, test_dataloader

In [29]:
# Experiment configuration shared by the cells below.
config = {
    "folder_name": "../load_dataset/dataset",
    "download": False,
    "num_shots": 5,
    "num_ways": 5,
    "output_folder": "saved_model",
    "task_batch_size": 32,  # required
    "num_task_batch_train": 600,  # required
    "num_task_batch_test": 200,  # required
    # required; use the GPU when one is available, otherwise fall back to the
    # CPU so the notebook still runs on machines without CUDA
    "device": "cuda" if torch.cuda.is_available() else "cpu",
}

train_dataloader, val_dataloader, test_dataloader = get_dataloader(config)

# One output logit per class of an N-way task.
model = ConvNet(in_channels=1, out_features=config["num_ways"]).to(
    device=config["device"]
)

criterion = nn.CrossEntropyLoss()
# Outer-loop (meta) optimizer over the model's parameters.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [30]:
def save_model(output_folder: str, model: nn.Module, title: str) -> None:
    """Write ``model.state_dict()`` to ``output_folder/title``.

    Args:
        output_folder: Directory that receives the checkpoint. It is created
            (including missing parent directories) when it does not exist.
        model: Module whose state dict is saved.
        title: File name of the checkpoint inside ``output_folder``.
    """
    # makedirs handles nested paths and an already-existing directory, both of
    # which the previous isdir()/mkdir() pair could fail on.
    os.makedirs(output_folder, exist_ok=True)
    filename = os.path.join(output_folder, title)

    with open(filename, "wb") as f:
        torch.save(model.state_dict(), f)
    print("Model is saved in", filename)


def load_model(output_folder: str, model: nn.Module, title: str) -> None:
    """Load the state dict stored at ``output_folder/title`` into ``model``.

    Args:
        output_folder: Directory that contains the checkpoint.
        model: Module that receives the loaded parameters (modified in place).
        title: File name of the checkpoint inside ``output_folder``.
    """
    filename = os.path.join(output_folder, title)
    # NOTE(review): torch.load unpickles the file -- only load checkpoints
    # from a trusted source.
    # Map the stored tensors to the CPU first so a checkpoint written on a GPU
    # machine can be read on a host without CUDA; load_state_dict then copies
    # the values into the model's existing parameters.
    state_dict = torch.load(filename, map_location="cpu")
    model.load_state_dict(state_dict)
    print("Model is loaded")

In [31]:
def print_graph(
    train_accuracies: List[float],
    val_accuracies: List[float],
    train_losses: List[float],
    val_losses: List[float],
) -> None:
    """Plot training and validation curves side by side.

    The left panel shows the accuracies, the right panel the losses; each
    list is plotted against its index.

    Args:
        train_accuracies: Training accuracy per recorded step.
        val_accuracies: Validation accuracy per recorded step.
        train_losses: Training loss per recorded step.
        val_losses: Validation loss per recorded step.
    """
    fig, axs = plt.subplots(1, 2, figsize=(14, 6))

    # The second series in each panel is the validation data, so label it as
    # such (it was previously labelled "test").
    axs[0].plot(train_accuracies, label="train_acc")
    axs[0].plot(val_accuracies, label="val_acc")
    axs[0].set_title("Accuracy")
    axs[0].set_xlabel("Iteration")
    axs[0].set_ylabel("Accuracy")
    axs[0].legend()

    axs[1].plot(train_losses, label="train_loss")
    axs[1].plot(val_losses, label="val_loss")
    axs[1].set_title("Loss")
    axs[1].set_xlabel("Iteration")
    axs[1].set_ylabel("Loss")
    axs[1].legend()

    # pyplot.show() is the supported way to display figures created through
    # pyplot; Figure.show() warns on non-GUI backends such as the inline one.
    plt.show()

In [32]:
# Meta-training
# Per-task-batch history, plotted once training has finished.
train_accuracies = []
val_accuracies = []
train_losses = []
val_losses = []

with tqdm(
    zip(train_dataloader, val_dataloader), total=config["num_task_batch_train"]
) as pbar:
    for task_batch_idx, (train_batch, val_batch) in enumerate(pbar):
        # Stop explicitly after the configured number of task batches.
        if task_batch_idx >= config["num_task_batch_train"]:
            break

        # One meta-update on a training task batch ...
        train_accuracy, train_loss = train_maml(
            device=config["device"],
            task_batch_size=config["task_batch_size"],
            task_batch=train_batch,
            model=model,
            criterion=criterion,
            optimizer=optimizer,
        )
        # ... followed by an evaluation on a validation task batch.
        val_accuracy, val_loss = test_maml(
            device=config["device"],
            task_batch_size=config["task_batch_size"],
            task_batch=val_batch,
            model=model,
            criterion=criterion,
        )

        train_accuracies.append(train_accuracy)
        val_accuracies.append(val_accuracy)
        train_losses.append(train_loss)
        val_losses.append(val_loss)

        pbar.set_postfix(
            train_accuracy="{0:.4f}".format(train_accuracy),
            val_accuracy="{0:.4f}".format(val_accuracy),
            train_loss="{0:.4f}".format(train_loss),
            val_loss="{0:.4f}".format(val_loss),
        )

# Save the trained model. The meta-testing cell loads this checkpoint, so the
# save must actually run for the notebook to work on a fresh kernel.
save_model(
    output_folder=config["output_folder"], model=model, title="maml_classification.th"
)

print_graph(
    train_accuracies=train_accuracies,
    val_accuracies=val_accuracies,
    train_losses=train_losses,
    val_losses=val_losses,
)

100%|██████████| 600/600 [09:17<00:00,  1.08it/s, train_accuracy=0.9437, train_loss=0.1825, val_accuracy=0.9162, val_loss=0.2088]


: 

: 

In [None]:
# Load the saved model
load_model(
    output_folder=config["output_folder"], model=model, title="maml_classification.th"
)

# Meta-testing
# Running totals; the progress bar shows their mean over the batches seen.
sum_test_accuracies = 0.0
sum_test_losses = 0.0
format_metric = "{0:.4f}".format

with tqdm(test_dataloader, total=config["num_task_batch_test"]) as pbar:
    for task_batch_idx, test_batch in enumerate(pbar):
        # Stop explicitly after the configured number of task batches.
        if task_batch_idx >= config["num_task_batch_test"]:
            break

        test_accuracy, test_loss = test_maml(
            device=config["device"],
            task_batch_size=config["task_batch_size"],
            task_batch=test_batch,
            model=model,
            criterion=criterion,
        )

        sum_test_accuracies += test_accuracy
        sum_test_losses += test_loss

        num_batches_seen = task_batch_idx + 1
        pbar.set_postfix(
            test_accuracy=format_metric(sum_test_accuracies / num_batches_seen),
            test_loss=format_metric(sum_test_losses / num_batches_seen),
        )