In [1]:
import os
import warnings
from typing import Any, Dict, List, Tuple

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchmeta.datasets.helpers import omniglot
from torchmeta.utils.data import BatchMetaDataLoader
from tqdm import tqdm

warnings.filterwarnings("ignore")

### 학습 모델 정의

In [3]:
class PrototypicalNet(nn.Module):
    """Prototypical network for few-shot classification.

    Images are embedded by four conv blocks followed by a linear layer. Each
    class prototype is the mean embedding of that class's support samples, and
    every query is scored by its squared Euclidean distance to each prototype.

    Args:
        in_channels: Number of channels of the input images.
        num_ways: Number of classes per episode (size of the prototype axis).
        num_shots: Nominal number of support samples per class.
    """

    def __init__(self, in_channels: int, num_ways: int, num_shots: int) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.emb_size = 64
        self.num_ways = num_ways
        # Nominal episode sizes, kept as attributes for existing callers.
        # forward() reads the actual support/query counts from its inputs.
        self.num_support = num_ways * num_shots
        self.num_query = self.num_support

        self.embedding_net = nn.Sequential(
            self.convBlock(self.in_channels, self.emb_size, 3),
            self.convBlock(self.emb_size, self.emb_size, 3),
            self.convBlock(self.emb_size, self.emb_size, 3),
            # Each block halves H and W (floor); 28x28 inputs end as [N, emb_size, 1, 1].
            self.convBlock(self.emb_size, self.emb_size, 3),
            nn.Flatten(start_dim=1),  # [N, emb_size] when the feature map is 1x1
            nn.Linear(self.emb_size, self.emb_size),  # [N, emb_size]
        )

    @classmethod
    def convBlock(cls, in_channels: int, out_channels: int, kernel_size: int) -> nn.Sequential:
        """Return a Conv2d -> BatchNorm2d -> ReLU -> MaxPool2d(2) block.

        Padding is fixed at 1, so the convolution preserves the spatial size
        only for ``kernel_size=3``. Batch norm keeps no running statistics and
        therefore normalises with the statistics of the current batch.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=1),
            nn.BatchNorm2d(out_channels, momentum=1.0, track_running_stats=False),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

    def get_prototypes(self, embeddings: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Compute one prototype (mean support embedding) per class.

        Args:
            embeddings: Support embeddings, [batch_size, num_support, emb_size].
            targets: Integer class indices in ``[0, num_ways)``,
                [batch_size, num_support].

        Returns:
            Prototypes of shape [batch_size, num_ways, emb_size]. A class with
            no support sample in an episode gets an all-zero prototype.
        """
        batch_size = embeddings.shape[0]
        indices = targets.unsqueeze(-1).expand_as(embeddings)  # [batch_size, num_support, emb_size]

        # Sum the embeddings of every class into its prototype slot.
        prototypes = embeddings.new_zeros((batch_size, self.num_ways, self.emb_size))
        prototypes.scatter_add_(1, indices, embeddings)

        # Divide by the real number of samples per class instead of assuming a
        # fixed shots-per-class, so unbalanced or resized episodes still yield a mean.
        counts = embeddings.new_zeros((batch_size, self.num_ways))
        counts.scatter_add_(1, targets, torch.ones_like(targets, dtype=embeddings.dtype))
        counts.clamp_(min=1.0)  # avoid 0/0 for classes that are absent from the episode
        return prototypes / counts.unsqueeze(-1)

    def forward(
        self, support_x: torch.Tensor, support_y: torch.Tensor, query_x: torch.Tensor
    ) -> torch.Tensor:
        """Return squared Euclidean distances between queries and prototypes.

        Args:
            support_x: Support images, [batch_size, num_support, C, H, W].
            support_y: Support labels, [batch_size, num_support].
            query_x: Query images, [batch_size, num_query, C, H, W]; num_query
                may differ from num_support.

        Returns:
            Distances of shape [batch_size, num_query, num_ways]; smaller means
            closer, so ``-distance`` can be used as classification logits.
        """
        # Take the episode sizes from the inputs rather than from the constructor.
        batch_size, num_support = support_x.shape[:2]
        num_query = query_x.shape[1]

        support_emb = self.embedding_net(support_x.flatten(start_dim=0, end_dim=1)).unflatten(
            dim=0, sizes=[batch_size, num_support]
        )  # [batch_size, num_support, emb_size]
        query_emb = self.embedding_net(query_x.flatten(start_dim=0, end_dim=1)).unflatten(
            dim=0, sizes=[batch_size, num_query]
        )  # [batch_size, num_query, emb_size]
        proto_emb = self.get_prototypes(support_emb, support_y)  # [batch_size, num_ways, emb_size]

        # Broadcast queries against prototypes: [batch_size, num_query, num_ways, emb_size].
        squared_diff = (query_emb.unsqueeze(2) - proto_emb.unsqueeze(1)) ** 2
        return squared_diff.sum(dim=-1)  # [batch_size, num_query, num_ways]


In [9]:
# Disabled shape smoke test: the whole snippet is wrapped in a string literal,
# so none of it executes when the cell runs.
# NOTE(review): `num_query` is assigned but never used -- `query_x` is built
# with num_ways*num_shots samples, the same count as the support set.
'''batch_size = 32
num_ways = 5
num_shots = 3
num_query = 4
support_x = torch.randn(batch_size, num_ways*num_shots, 1, 28, 28)
support_y = torch.zeros(batch_size, num_ways*num_shots, dtype=torch.int64)
query_x = torch.randn(batch_size, num_ways*num_shots, 1, 28, 28)

model = PrototypicalNet(1, num_ways, num_shots)
out = model(support_x, support_y, query_x)
print(out.shape)'''

torch.Size([32, 15, 5])


### 학습 및 테스트

In [19]:
def train_proto(
    device: str,
    task_batch: Dict[str, List[torch.Tensor]],
    model: PrototypicalNet,
    criterion: nn.CrossEntropyLoss,
    optimizer: torch.optim.Adam,
) -> Tuple[float, float]:
    """Run one optimisation step on a batch of episodes.

    Args:
        device: Device the batch tensors are moved to.
        task_batch: Mapping with "train" (support) and "test" (query) entries,
            each indexed as [inputs, targets].
        model: Network that returns query-to-prototype distances.
        criterion: Loss applied to the negated distances (used as logits).
        optimizer: Optimizer stepped once per call.

    Returns:
        (accuracy, loss) for this batch as Python floats.
    """
    model.train()
    optimizer.zero_grad()

    support_batch, query_batch = task_batch["train"], task_batch["test"]
    support_xs = support_batch[0].to(device=device)  # support images
    support_ys = support_batch[1].to(device=device)  # support labels
    query_xs = query_batch[0].to(device=device)  # query images
    query_ys = query_batch[1].to(device=device)  # query labels

    # Smaller distance means a better match, so negated distances act as logits.
    distance = model(support_xs, support_ys, query_xs)
    logits = -distance
    loss = criterion(logits.flatten(0, 1), query_ys.flatten(0, 1))

    loss.backward()
    optimizer.step()

    with torch.no_grad():
        query_pred = logits.max(dim=-1).indices
        accuracy = query_pred.eq(query_ys).float().mean()
    return accuracy.item(), loss.item()

In [20]:
# Evaluation step (no parameter updates)
def test_proto(
    device: str,
    task_batch: Dict[str, List[torch.Tensor]],
    model: PrototypicalNet,
    criterion: nn.CrossEntropyLoss,
) -> Tuple[float, float]:
    """Evaluate the model on one batch of episodes without updating it.

    Args:
        device: Device the batch tensors are moved to.
        task_batch: Mapping with "train" (support) and "test" (query) entries,
            each indexed as [inputs, targets].
        model: Network that returns query-to-prototype distances.
        criterion: Loss applied to the negated distances (used as logits).

    Returns:
        (accuracy, loss) for this batch as Python floats.
    """
    model.eval()

    support_xs = task_batch['train'][0].to(device=device)  # support images
    support_ys = task_batch['train'][1].to(device=device)  # support labels
    query_xs = task_batch['test'][0].to(device=device)  # query images
    query_ys = task_batch['test'][1].to(device=device)  # query labels

    # Evaluation needs no gradients: running the forward pass and the loss
    # under no_grad avoids building (and holding in memory) an autograd graph.
    with torch.no_grad():
        distance = model(support_xs, support_ys, query_xs)  # [batch, num_query, num_ways]
        distance_flat = distance.flatten(start_dim=0, end_dim=1)
        query_ys_flat = query_ys.flatten(start_dim=0, end_dim=1)
        loss = criterion(-distance_flat, query_ys_flat)

        # The predicted class is the prototype with the smallest distance.
        _, query_pred = torch.max(-distance, dim=-1)
        accuracy = torch.mean(query_pred.eq(query_ys).float())
    return accuracy.item(), loss.item()

In [21]:
def _make_meta_dataloader(config: Dict[str, Any], split_flag: str) -> BatchMetaDataLoader:
    """Build the Omniglot episode loader for one meta-split.

    Args:
        config: Settings dict; reads "folder_name", "num_shots", "num_ways",
            "download", "task_batch_size" and, optionally, "num_workers".
        split_flag: Keyword switched on in the dataset helper, one of
            "meta_train", "meta_val" or "meta_test".
    """
    dataset = omniglot(
        folder=config["folder_name"],
        shots=config["num_shots"],
        # test_shots is not passed, so the helper's default applies
        # (per the original author's note, that default equals `shots`).
        ways=config["num_ways"],
        shuffle=True,
        download=config["download"],
        **{split_flag: True},
    )
    return BatchMetaDataLoader(
        dataset,
        batch_size=config["task_batch_size"],
        shuffle=True,
        # Defaults to 1 worker; override with config["num_workers"] if needed.
        num_workers=config.get("num_workers", 1),
    )


def get_dataloader(
    config: Dict[str, Any]
) -> Tuple[BatchMetaDataLoader, BatchMetaDataLoader, BatchMetaDataLoader]:
    """Create the meta-train, meta-validation and meta-test loaders.

    All three loaders share the same settings and differ only in the
    meta-split they draw episodes from.

    Args:
        config: Settings dict (see ``_make_meta_dataloader`` for the keys read).

    Returns:
        (train_dataloader, val_dataloader, test_dataloader).
    """
    train_dataloader = _make_meta_dataloader(config, "meta_train")
    val_dataloader = _make_meta_dataloader(config, "meta_val")
    test_dataloader = _make_meta_dataloader(config, "meta_test")
    return train_dataloader, val_dataloader, test_dataloader

In [22]:
def save_model(output_folder: str, model: PrototypicalNet, title: str) -> None:
    """Write the model's state dict to ``output_folder/title``.

    Args:
        output_folder: Directory to save into; created if it does not exist.
        model: Model whose ``state_dict()`` is serialised.
        title: File name of the checkpoint inside ``output_folder``.
    """
    # makedirs(exist_ok=True) also creates missing parent directories and does
    # not fail if the directory appears between a check and the creation.
    os.makedirs(output_folder, exist_ok=True)
    filename = os.path.join(output_folder, title)

    with open(filename, "wb") as f:
        torch.save(model.state_dict(), f)
    print("Model is saved in", filename)


def load_model(output_folder: str, model: PrototypicalNet, title: str) -> None:
    """Load a state dict from ``output_folder/title`` into ``model`` in place.

    Args:
        output_folder: Directory containing the checkpoint.
        model: Model that receives the weights via ``load_state_dict``.
        title: File name of the checkpoint inside ``output_folder``.
    """
    filename = os.path.join(output_folder, title)
    # Deserialise onto the CPU so a checkpoint saved from a GPU also loads on a
    # machine without one; load_state_dict then copies the values into the
    # model's existing parameters.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(filename, map_location="cpu")
    model.load_state_dict(state_dict)
    print("Model is loaded")

In [23]:
def print_graph(
    train_accuracies: List[float],
    val_accuracies: List[float],
    train_losses: List[float],
    val_losses: List[float],
) -> None:
    """Plot training and validation curves side by side.

    The left panel shows accuracy and the right panel shows loss; each list is
    plotted against its index.

    Args:
        train_accuracies: Training accuracy per step.
        val_accuracies: Validation accuracy per step.
        train_losses: Training loss per step.
        val_losses: Validation loss per step.
    """
    fig, axs = plt.subplots(1, 2, figsize=(14, 6))

    # These curves come from the validation lists, so they are labelled "val".
    axs[0].plot(train_accuracies, label="train_acc")
    axs[0].plot(val_accuracies, label="val_acc")
    axs[0].set_title("Accuracy")
    axs[0].set_xlabel("Step")
    axs[0].legend()

    axs[1].plot(train_losses, label="train_loss")
    axs[1].plot(val_losses, label="val_loss")
    axs[1].set_title("Loss")
    axs[1].set_xlabel("Step")
    axs[1].legend()

    # plt.show() renders on every backend; Figure.show() only warns on
    # non-GUI backends such as the notebook inline backend.
    plt.show()

In [24]:
# Experiment configuration shared by the data loaders, the model and the loops below.
config = {
    "folder_name": "../load_dataset/dataset",
    "download": False,
    "num_shots": 5,
    "num_ways": 5,
    "output_folder": "saved_model",
    "task_batch_size": 32,  # required
    "num_task_batch_train": 600,  # required: task batches used for meta-training
    "num_task_batch_test": 200,  # required: task batches used for meta-testing
    # required; falls back to the CPU so the notebook still runs without a GPU
    "device": "cuda" if torch.cuda.is_available() else "cpu",
}

train_dataloader, val_dataloader, test_dataloader = get_dataloader(config)

model = PrototypicalNet(
    in_channels=1, num_ways=config["num_ways"], num_shots=config["num_shots"]
).to(device=config["device"])

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [25]:
# Meta-training: one optimisation step per train batch, followed by an
# evaluation on the paired validation batch.
with tqdm(
    zip(train_dataloader, val_dataloader), total=config["num_task_batch_train"]
) as pbar:
    train_accuracies, train_losses = [], []
    val_accuracies, val_losses = [], []

    for task_batch_idx, (train_batch, val_batch) in enumerate(pbar):
        # The loaders yield episodes indefinitely, so stop after the configured count.
        if task_batch_idx >= config["num_task_batch_train"]:
            break

        train_accuracy, train_loss = train_proto(
            config["device"], train_batch, model, criterion, optimizer
        )
        val_accuracy, val_loss = test_proto(
            config["device"], val_batch, model, criterion
        )

        # Keep the per-step history for the curves plotted after training.
        train_accuracies.append(train_accuracy)
        train_losses.append(train_loss)
        val_accuracies.append(val_accuracy)
        val_losses.append(val_loss)

        pbar.set_postfix(
            train_accuracy=f"{train_accuracy:.4f}",
            val_accuracy=f"{val_accuracy:.4f}",
            train_loss=f"{train_loss:.4f}",
            val_loss=f"{val_loss:.4f}",
        )

    # Save the trained model
    save_model(config["output_folder"], model, "prototypical_network.th")

    print_graph(train_accuracies, val_accuracies, train_losses, val_losses)

100%|██████████| 600/600 [08:45<00:00,  1.14it/s, train_accuracy=0.9912, train_loss=0.0259, val_accuracy=0.9850, val_loss=0.0374]


Model is saved in saved_model\prototypical_network.th


: 

: 

In [None]:
# Load the saved model weights
load_model(config["output_folder"], model, "prototypical_network.th")

# Meta-testing: the progress bar shows the running mean of accuracy and loss
# over the test task batches seen so far.
with tqdm(test_dataloader, total=config["num_task_batch_test"]) as pbar:
    sum_test_accuracies, sum_test_losses = 0.0, 0.0

    for task_batch_idx, test_batch in enumerate(pbar):
        # Stop after the configured number of test task batches.
        if task_batch_idx >= config["num_task_batch_test"]:
            break

        test_accuracy, test_loss = test_proto(
            config["device"], test_batch, model, criterion
        )
        sum_test_accuracies += test_accuracy
        sum_test_losses += test_loss

        pbar.set_postfix(
            test_accuracy=f"{sum_test_accuracies / (task_batch_idx + 1):.4f}",
            test_loss=f"{sum_test_losses / (task_batch_idx + 1):.4f}",
        )