# In [None]:
import matplotlib.pyplot as plt
import torch
from tqdm import tqdm


def train(model, dataloader, criterion, optimizer, device):
    """Run one training pass over ``dataloader``.

    For every ``(images, maps)`` pair the batch is moved to ``device``, the
    model output is scored with ``criterion``, and ``optimizer`` performs one
    update step. A tqdm progress bar wraps the iteration. Nothing is
    returned; ``model`` and ``optimizer`` are updated in place.

    Args:
        model: Module being trained; put into training mode first.
        dataloader: Iterable yielding ``(images, maps)`` batch pairs.
        criterion: Loss callable invoked as ``criterion(outputs, maps)``.
        optimizer: Optimizer whose ``zero_grad``/``step`` run once per batch.
        device: Device every input and target batch is moved to.
    """
    model.train()  # switch layers such as dropout / batch-norm to training behaviour
    progress = tqdm(dataloader)
    for batch_inputs, batch_targets in progress:
        batch_inputs, batch_targets = batch_inputs.to(device), batch_targets.to(device)

        # Gradients accumulate by default, so clear the previous step's first.
        optimizer.zero_grad()

        batch_loss = criterion(model(batch_inputs), batch_targets)
        batch_loss.backward()  # back-propagate
        optimizer.step()       # apply the weight update

def test(model, dataloader, criterion, device):
    model.eval()

    report = dict()
    report["input_images"] = []
    report["target_images"] = []
    report["prediction_images"] = []
    report["loss"] = []

    total_loss = 0
    correct = 0  # 正しく分類されたサンプル数

    for batch_idx,(images, maps) in enumerate(dataloader):
        images = images.to(device)
        maps = maps.to(device)

        with torch.no_grad():
            outputs = model(images)
            loss = criterion(outputs, maps)
            total_loss += loss.item()

            # 予測ラベルを取得
            predicted = torch.argmax(outputs, dim=1)
            correct += (predicted == maps).sum().item()

            # 画像を収集
            report["input_images"].append(images.cpu())
            report["target_images"].append(maps.cpu())
            report["prediction_images"].append(outputs.cpu())
            report["loss"].append(loss.item())

    average_loss = total_loss / len(dataloader)
    report["average_loss"] = average_loss
    accuracy = 100. * correct / len(dataloader.dataset)

    print(
        f'\nTest set: Average loss: {average_loss:.4f}, Accuracy: {correct}/{len(dataloader.dataset)} ({accuracy:.2f}%)\n'
    )

    return report


# Visualize test results (optional)
def _to_displayable(image):
    """Convert one image tensor into an array that ``plt.imshow`` accepts.

    A 3-D tensor is treated as channels-first ``(C, H, W)`` and permuted to
    ``(H, W, C)``. A single trailing channel is squeezed to ``(H, W)``, since
    imshow documents only ``(H, W)``, ``(H, W, 3)`` and ``(H, W, 4)``. Any
    other tensor (e.g. a 2-D label map) is passed through unchanged.
    """
    image = image.detach().cpu()
    if image.dim() == 3:
        image = image.permute(1, 2, 0)
        if image.shape[-1] == 1:
            image = image.squeeze(-1)
    return image.numpy()


def visualize_results(report, num_images=5):
    """Show input, target and prediction side by side for the first samples.

    ``report`` is the dict produced by ``test``. Its ``"input_images"``,
    ``"target_images"`` and ``"prediction_images"`` lists hold one tensor per
    *batch*. They are concatenated along dim 0 first, so ``num_images`` counts
    individual samples, not batches.

    Args:
        report: Dict with the three per-batch image lists described above.
        num_images: Maximum number of samples to plot (one figure each).

    NOTE(review): assumes each sample is ``(C, H, W)`` or ``(H, W)``. Raw
    model outputs with a channel count other than 1, 3 or 4 will still be
    rejected by ``imshow``.
    """
    if not report["input_images"]:
        return  # nothing was evaluated, so there is nothing to draw

    # Merge the per-batch tensors so that index i addresses a single sample.
    panels = (
        ("Input Image", torch.cat(report["input_images"])),
        ("Target Image", torch.cat(report["target_images"])),
        ("Predicted Image", torch.cat(report["prediction_images"])),
    )
    num_samples = panels[0][1].shape[0]

    for i in range(min(num_images, num_samples)):
        plt.figure(figsize=(12, 4))
        for position, (title, images) in enumerate(panels, start=1):
            plt.subplot(1, 3, position)
            plt.title(title)
            plt.imshow(_to_displayable(images[i]))
            plt.axis("off")
        plt.show()

