In [1]:
# 必要ライブラリのインストール
!pip install torch-geometric -q

In [2]:
# インポート
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from torch_geometric.nn import GCNConv
from torchvision.models import resnet18, ResNet18_Weights
from sklearn.metrics import top_k_accuracy_score
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
# データセット準備
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

train_dataset = datasets.CIFAR10(root="./data", train=True, transform=transform, download=True)
val_dataset = datasets.CIFAR10(root="./data", train=False, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=2)

class RNNClassifier(nn.Module):
    def __init__(self, input_size=96, hidden_size=128, num_layers=2, num_classes=10):
        super().__init__()
        self.rnn = nn.RNN(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        B, C, H, W = x.shape  # CIFAR-10: (B, 3, 32, 32)
        x = x.view(B, H, C * W)  # (B, 32, 96) -> 32ステップ, 各ステップ96次元
        out, _ = self.rnn(x)     # (B, 32, hidden_size)
        out = out[:, -1, :]      # 最終ステップの出力だけ使う
        return self.fc(out)

class ResNet_GNN(nn.Module):
    def __init__(self):
        super().__init__()
        base = resnet18(weights=ResNet18_Weights.DEFAULT)

        # CIFAR-10対応: 最初の畳み込みとMaxPoolを変更
        base.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        base.maxpool = nn.Identity()  # MaxPoolを除去

        # 最終的に (B, 512, 4, 4) の特徴マップが得られるようにする
        self.features = nn.Sequential(*list(base.children())[:-2])

        self.gnn = GCNConv(512, 512)
        self.classifier = nn.Linear(512, 10)

    def forward(self, x):
        x = self.features(x)  # (B, 512, 4, 4)
        B, C, H, W = x.shape
        x = x.view(B, C, -1).permute(0, 2, 1)  # (B, 16, 512)

        edge_index = self._create_edges(H, W).to(x.device)
        out = []
        for i in range(B):
            gnn_out = self.gnn(x[i], edge_index)  # (16, 512)
            pooled = gnn_out.mean(dim=0)         # (512,)
            out.append(pooled)
        return self.classifier(torch.stack(out))  # (B, 10)

    def _create_edges(self, H, W):
        edges = []
        for i in range(H):
            for j in range(W):
                idx = i * W + j
                if i < H - 1:
                    edges.append([idx, (i + 1) * W + j])
                if j < W - 1:
                    edges.append([idx, i * W + (j + 1)])
        return torch.tensor(edges, dtype=torch.long).t().contiguous()

  # EarlyStoppingのクラス定義
class EarlyStopping:
    def __init__(self, patience=10):
        self.patience = patience
        self.best_score = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, score):
        if self.best_score is None or score > self.best_score:
            self.best_score = score
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

# 評価関数
def evaluate(model, loader):
    model.eval()
    top1_correct = 0
    preds_all, labels_all = [], []
    loop = tqdm(loader, desc="Evaluating")
    with torch.no_grad():
        for images, labels in loop:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            preds = outputs.argmax(dim=1)
            top1_correct += (preds == labels).sum().item()
            preds_all.append(outputs.cpu())
            labels_all.append(labels.cpu())
    preds_all = torch.cat(preds_all).numpy()
    labels_all = torch.cat(labels_all).numpy()
    top1 = 100 * top1_correct / len(loader.dataset)
    top5 = 100 * top_k_accuracy_score(labels_all, preds_all, k=5)
    return top1, top5

def train(model, loader, optimizer, criterion, val_loader, max_epochs=50, model_name="model"):
    model.to(device)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=2, factor=0.5, verbose=True)
    early_stopper = EarlyStopping(patience=3)

    train_losses, train_accuracies, val_accuracies = [], [], []

    best_model_state = None
    best_val_acc = -1
    best_epoch = -1

    for epoch in range(max_epochs):
        model.train()
        total, correct, total_loss = 0, 0, 0
        loop = tqdm(loader, desc=f"Epoch {epoch+1}")
        for images, labels in loop:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            preds = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (preds == labels).sum().item()
            total_loss += loss.item()
            acc = 100 * correct / total
            loop.set_postfix(loss=loss.item(), acc=f"{acc:.2f}%")

        train_losses.append(total_loss / len(loader))
        train_accuracies.append(100 * correct / total)

        val_top1, _ = evaluate(model, val_loader)
        val_accuracies.append(val_top1)
        scheduler.step(val_top1)

        # ベストモデルの保存
        if val_top1 > best_val_acc:
            best_val_acc = val_top1
            best_model_state = model.state_dict()
            best_epoch = epoch + 1  # 1-based

        early_stopper(val_top1)
        if early_stopper.early_stop:
            print(f"⏹️ Early stopping at epoch {epoch+1}")
            break

    # 保存処理
    os.makedirs("checkpoints", exist_ok=True)
    model_path = f"checkpoints/{model_name}_best.pt"
    torch.save(best_model_state, model_path)
    print(f"💾 ベストモデル（epoch {best_epoch}）を保存しました: {model_path}")

    # 最後にベストモデルを読み込み直して返す
    model.load_state_dict(best_model_state)
    return model, train_losses, train_accuracies, val_accuracies, best_epoch

def plot_training_curves(train_losses, train_accuracies, val_accuracies, output_path="plots/training_summary.png"):
    import matplotlib.pyplot as plt
    import os

    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label="Train Loss")
    plt.plot(train_accuracies, label="Train Accuracy")
    plt.plot(val_accuracies, label="Validation Accuracy")
    plt.xlabel("Epoch")
    plt.ylabel("Value")
    plt.title("Training Summary")
    plt.legend()
    plt.grid(True)
    plt.savefig(output_path)
    print(f"📈 学習曲線を保存しました: {output_path}")
    plt.close()

In [5]:
print("\n▶️ Training RNN Classifier...")
rnn_model = RNNClassifier()
optimizer_rnn = torch.optim.Adam(rnn_model.parameters(), lr=0.001)

rnn_model, rnn_losses, rnn_accs, rnn_val_accs, rnn_best_epoch = train(
    rnn_model, train_loader, optimizer_rnn, nn.CrossEntropyLoss(), val_loader, model_name="rnn_classifier"
)

rnn_top1, rnn_top5 = evaluate(rnn_model, val_loader)
print(f"\n📊 RNN Classifier - Top-1 Accuracy: {rnn_top1:.2f}%, Top-5 Accuracy: {rnn_top5:.2f}% (Best Epoch: {rnn_best_epoch})")

plot_training_curves(rnn_losses, rnn_accs, rnn_val_accs, output_path="plots/rnn_training_summary.png")


print("\n▶️ Training ResNet18 + GNN...")
resnet_gnn = ResNet_GNN()
optimizer_resnet = torch.optim.Adam(resnet_gnn.parameters(), lr=0.001)
resnet_gnn, resnet_losses, resnet_accs, resnet_val_accs, resnet_best_epoch = train(
    resnet_gnn, train_loader, optimizer_resnet, nn.CrossEntropyLoss(), val_loader, model_name="resnet_gnn")

resnet_top1, resnet_top5 = evaluate(resnet_gnn, val_loader)
print(f"\n📊 ResNet18 + GNN - Top-1 Accuracy: {resnet_top1:.2f}%, Top-5 Accuracy: {resnet_top5:.2f}% (Best Epoch: {resnet_best_epoch})")
plot_training_curves(resnet_losses, resnet_accs, resnet_val_accs, output_path="plots/resnet_gnn_summary.png")




▶️ Training RNN Classifier...


Epoch 1: 100%|██████████| 782/782 [00:06<00:00, 120.82it/s, acc=30.37%, loss=1.66]
Evaluating: 100%|██████████| 157/157 [00:01<00:00, 145.51it/s]
Epoch 2: 100%|██████████| 782/782 [00:05<00:00, 143.07it/s, acc=36.43%, loss=1.73]
Evaluating: 100%|██████████| 157/157 [00:01<00:00, 148.11it/s]
Epoch 3: 100%|██████████| 782/782 [00:05<00:00, 142.20it/s, acc=39.43%, loss=2.08]
Evaluating: 100%|██████████| 157/157 [00:01<00:00, 151.45it/s]
Epoch 4: 100%|██████████| 782/782 [00:05<00:00, 140.44it/s, acc=41.48%, loss=1.86]
Evaluating: 100%|██████████| 157/157 [00:01<00:00, 141.63it/s]
Epoch 5: 100%|██████████| 782/782 [00:05<00:00, 143.38it/s, acc=43.12%, loss=1.42]
Evaluating: 100%|██████████| 157/157 [00:01<00:00, 151.72it/s]
Epoch 6: 100%|██████████| 782/782 [00:05<00:00, 140.10it/s, acc=44.33%, loss=1.05]
Evaluating: 100%|██████████| 157/157 [00:01<00:00, 150.74it/s]
Epoch 7: 100%|██████████| 782/782 [00:05<00:00, 146.65it/s, acc=45.15%, loss=1.53]
Evaluating: 100%|██████████| 157/157 [00:

⏹️ Early stopping at epoch 17
💾 ベストモデル（epoch 14）を保存しました: checkpoints/rnn_classifier_best.pt


Evaluating: 100%|██████████| 157/157 [00:01<00:00, 141.22it/s]



📊 RNN Classifier - Top-1 Accuracy: 45.74%, Top-5 Accuracy: 91.50% (Best Epoch: 14)
📈 学習曲線を保存しました: plots/rnn_training_summary.png

▶️ Training ResNet18 + GNN...


Epoch 1: 100%|██████████| 782/782 [01:09<00:00, 11.24it/s, acc=71.93%, loss=1.28]
Evaluating: 100%|██████████| 157/157 [00:08<00:00, 18.90it/s]
Epoch 2: 100%|██████████| 782/782 [01:09<00:00, 11.18it/s, acc=84.00%, loss=0.535]
Evaluating: 100%|██████████| 157/157 [00:08<00:00, 18.46it/s]
Epoch 3: 100%|██████████| 782/782 [01:09<00:00, 11.28it/s, acc=87.86%, loss=0.294]
Evaluating: 100%|██████████| 157/157 [00:08<00:00, 18.51it/s]
Epoch 4: 100%|██████████| 782/782 [01:08<00:00, 11.36it/s, acc=89.93%, loss=0.452]
Evaluating: 100%|██████████| 157/157 [00:08<00:00, 18.67it/s]
Epoch 5: 100%|██████████| 782/782 [01:09<00:00, 11.25it/s, acc=91.82%, loss=0.324]
Evaluating: 100%|██████████| 157/157 [00:08<00:00, 18.90it/s]
Epoch 6: 100%|██████████| 782/782 [01:09<00:00, 11.31it/s, acc=93.09%, loss=0.255]
Evaluating: 100%|██████████| 157/157 [00:08<00:00, 18.48it/s]
Epoch 7: 100%|██████████| 782/782 [01:09<00:00, 11.28it/s, acc=94.18%, loss=0.303]
Evaluating: 100%|██████████| 157/157 [00:08<00:0

⏹️ Early stopping at epoch 25
💾 ベストモデル（epoch 22）を保存しました: checkpoints/resnet_gnn_best.pt


Evaluating: 100%|██████████| 157/157 [00:08<00:00, 18.80it/s]



📊 ResNet18 + GNN - Top-1 Accuracy: 90.72%, Top-5 Accuracy: 99.59% (Best Epoch: 22)
📈 学習曲線を保存しました: plots/resnet_gnn_summary.png
