In [6]:
# 必要ライブラリのインストール
!pip install torch-geometric -q

In [7]:
# インポート
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from torch_geometric.nn import GCNConv
from sklearn.metrics import top_k_accuracy_score
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [9]:
# データセット準備
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

train_dataset = datasets.CIFAR10(root="./data", train=True, transform=transform, download=True)
val_dataset = datasets.CIFAR10(root="./data", train=False, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=2)

class VGG(nn.Module):
    def __init__(self):
        super().__init__()
        # VGG16の代替構造（32x32入力対応）
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),  # (B, 64, 32, 32)
            nn.ReLU(True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),  # (B, 64, 32, 32)
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2),        # (B, 64, 16, 16)

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2),        # (B, 128, 8, 8)

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2),        # (B, 256, 4, 4)
        )

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256 * 4 * 4, 512),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(512, 512),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.features(x)
        return self.classifier(x)


class VGG_GNN(nn.Module):
    def __init__(self):
        super().__init__()
        # CIFAR-10向けVGGスタイルの特徴抽出部
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2),        # 16x16

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2),        # 8x8

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2),        # 4x4
        )

        self.gnn = GCNConv(256, 256)
        self.classifier = nn.Linear(256, 10)

    def forward(self, x):
        x = self.features(x)  # (B, 256, 4, 4)
        B, C, H, W = x.shape
        x = x.view(B, C, -1).permute(0, 2, 1)  # (B, 16, 256)

        edge_index = self._create_edges(H, W).to(x.device)
        out = []
        for i in range(B):
            gnn_out = self.gnn(x[i], edge_index)  # (16, 256)
            pooled = gnn_out.mean(dim=0)          # (256,)
            out.append(pooled)
        out = torch.stack(out)  # (B, 256)
        return self.classifier(out)

    def _create_edges(self, H, W):
        # 格子構造の隣接ノードを定義
        edges = []
        for i in range(H):
            for j in range(W):
                idx = i * W + j
                if i < H - 1:
                    edges.append([idx, (i + 1) * W + j])
                if j < W - 1:
                    edges.append([idx, i * W + (j + 1)])
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
        return edge_index


# EarlyStoppingのクラス定義
class EarlyStopping:
    def __init__(self, patience=10):
        self.patience = patience
        self.best_score = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, score):
        if self.best_score is None or score > self.best_score:
            self.best_score = score
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

# 評価関数
def evaluate(model, loader):
    model.eval()
    top1_correct = 0
    preds_all, labels_all = [], []
    loop = tqdm(loader, desc="Evaluating")
    with torch.no_grad():
        for images, labels in loop:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            preds = outputs.argmax(dim=1)
            top1_correct += (preds == labels).sum().item()
            preds_all.append(outputs.cpu())
            labels_all.append(labels.cpu())
    preds_all = torch.cat(preds_all).numpy()
    labels_all = torch.cat(labels_all).numpy()
    top1 = 100 * top1_correct / len(loader.dataset)
    top5 = 100 * top_k_accuracy_score(labels_all, preds_all, k=5)
    return top1, top5

def train(model, loader, optimizer, criterion, val_loader, max_epochs=50, model_name="model"):
    model.to(device)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=2, factor=0.5, verbose=True)
    early_stopper = EarlyStopping(patience=3)

    train_losses, train_accuracies, val_accuracies = [], [], []

    best_model_state = None
    best_val_acc = -1
    best_epoch = -1

    for epoch in range(max_epochs):
        model.train()
        total, correct, total_loss = 0, 0, 0
        loop = tqdm(loader, desc=f"Epoch {epoch+1}")
        for images, labels in loop:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            preds = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (preds == labels).sum().item()
            total_loss += loss.item()
            acc = 100 * correct / total
            loop.set_postfix(loss=loss.item(), acc=f"{acc:.2f}%")

        train_losses.append(total_loss / len(loader))
        train_accuracies.append(100 * correct / total)

        val_top1, _ = evaluate(model, val_loader)
        val_accuracies.append(val_top1)
        scheduler.step(val_top1)

        # ベストモデルの保存
        if val_top1 > best_val_acc:
            best_val_acc = val_top1
            best_model_state = model.state_dict()
            best_epoch = epoch + 1  # 1-based

        early_stopper(val_top1)
        if early_stopper.early_stop:
            print(f"⏹️ Early stopping at epoch {epoch+1}")
            break

    # 保存処理
    os.makedirs("checkpoints", exist_ok=True)
    model_path = f"checkpoints/{model_name}_best.pt"
    torch.save(best_model_state, model_path)
    print(f"💾 ベストモデル（epoch {best_epoch}）を保存しました: {model_path}")

    # 最後にベストモデルを読み込み直して返す
    model.load_state_dict(best_model_state)
    return model, train_losses, train_accuracies, val_accuracies, best_epoch

def plot_training_curves(train_losses, train_accuracies, val_accuracies, output_path="plots/training_summary.png"):
    import matplotlib.pyplot as plt
    import os

    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label="Train Loss")
    plt.plot(train_accuracies, label="Train Accuracy")
    plt.plot(val_accuracies, label="Validation Accuracy")
    plt.xlabel("Epoch")
    plt.ylabel("Value")
    plt.title("Training Summary")
    plt.legend()
    plt.grid(True)
    plt.savefig(output_path)
    print(f"📈 学習曲線を保存しました: {output_path}")
    plt.close()

In [10]:
# Plain VGG16 のトレーニングと保存
print("▶️ Training Plain VGG16...")
vgg_plain = PlainVGG()
optimizer_plain = torch.optim.Adam(vgg_plain.parameters(), lr=0.001)
vgg_plain, plain_losses, plain_accs, plain_val_accs, plain_best_epoch = train(
    vgg_plain, train_loader, optimizer_plain, nn.CrossEntropyLoss(), val_loader, model_name="vgg_plain")

plain_top1, plain_top5 = evaluate(vgg_plain, val_loader)
print(f"\n📊 VGG16 - Top-1 Accuracy: {plain_top1:.2f}%, Top-5 Accuracy: {plain_top5:.2f}% (Best Epoch: {plain_best_epoch})")
plot_training_curves(plain_losses, plain_accs, plain_val_accs, output_path="plots/vgg_plain_summary.png")


# VGG + GNN のトレーニングと保存
print("\n▶️ Training VGG16 + GNN...")
vgg_gnn = VGG_GNN()
optimizer_gnn = torch.optim.Adam(vgg_gnn.parameters(), lr=0.001)
vgg_gnn, gnn_losses, gnn_accs, gnn_val_accs, gnn_best_epoch = train(
    vgg_gnn, train_loader, optimizer_gnn, nn.CrossEntropyLoss(), val_loader, model_name="vgg_gnn")

gnn_top1, gnn_top5 = evaluate(vgg_gnn, val_loader)
print(f"\n📊 VGG16 + GNN - Top-1 Accuracy: {gnn_top1:.2f}%, Top-5 Accuracy: {gnn_top5:.2f}% (Best Epoch: {gnn_best_epoch})")
plot_training_curves(gnn_losses, gnn_accs, gnn_val_accs, output_path="plots/vgg_gnn_summary.png")



▶️ Training Plain VGG16...


Epoch 1: 100%|██████████| 782/782 [00:07<00:00, 107.42it/s, acc=31.79%, loss=1.29]
Evaluating: 100%|██████████| 157/157 [00:01<00:00, 138.92it/s]
Epoch 2: 100%|██████████| 782/782 [00:06<00:00, 117.30it/s, acc=50.61%, loss=1.06]
Evaluating: 100%|██████████| 157/157 [00:01<00:00, 133.13it/s]
Epoch 3: 100%|██████████| 782/782 [00:06<00:00, 118.72it/s, acc=59.59%, loss=1.27]
Evaluating: 100%|██████████| 157/157 [00:01<00:00, 137.06it/s]
Epoch 4: 100%|██████████| 782/782 [00:06<00:00, 118.82it/s, acc=65.32%, loss=1.14]
Evaluating: 100%|██████████| 157/157 [00:01<00:00, 132.98it/s]
Epoch 5: 100%|██████████| 782/782 [00:06<00:00, 120.53it/s, acc=69.29%, loss=1.28]
Evaluating: 100%|██████████| 157/157 [00:01<00:00, 135.11it/s]
Epoch 6: 100%|██████████| 782/782 [00:06<00:00, 118.53it/s, acc=71.99%, loss=1.25]
Evaluating: 100%|██████████| 157/157 [00:01<00:00, 133.89it/s]
Epoch 7: 100%|██████████| 782/782 [00:06<00:00, 117.80it/s, acc=73.81%, loss=0.726]
Evaluating: 100%|██████████| 157/157 [00

⏹️ Early stopping at epoch 19
💾 ベストモデル（epoch 16）を保存しました: checkpoints/vgg_plain_best.pt


Evaluating: 100%|██████████| 157/157 [00:01<00:00, 142.52it/s]



📊 VGG16 - Top-1 Accuracy: 77.55%, Top-5 Accuracy: 98.08% (Best Epoch: 16)
📈 学習曲線を保存しました: plots/vgg_plain_summary.png

▶️ Training VGG16 + GNN...


Epoch 1: 100%|██████████| 782/782 [00:59<00:00, 13.21it/s, acc=31.42%, loss=1.4]
Evaluating: 100%|██████████| 157/157 [00:07<00:00, 20.74it/s]
Epoch 2: 100%|██████████| 782/782 [00:59<00:00, 13.10it/s, acc=51.33%, loss=1.14]
Evaluating: 100%|██████████| 157/157 [00:07<00:00, 20.41it/s]
Epoch 3: 100%|██████████| 782/782 [00:59<00:00, 13.12it/s, acc=62.76%, loss=1.08]
Evaluating: 100%|██████████| 157/157 [00:07<00:00, 20.33it/s]
Epoch 4: 100%|██████████| 782/782 [00:59<00:00, 13.21it/s, acc=69.25%, loss=1.51]
Evaluating: 100%|██████████| 157/157 [00:07<00:00, 20.74it/s]
Epoch 5: 100%|██████████| 782/782 [00:59<00:00, 13.17it/s, acc=74.22%, loss=0.893]
Evaluating: 100%|██████████| 157/157 [00:07<00:00, 20.04it/s]
Epoch 6: 100%|██████████| 782/782 [00:59<00:00, 13.20it/s, acc=77.14%, loss=0.833]
Evaluating: 100%|██████████| 157/157 [00:07<00:00, 20.56it/s]
Epoch 7: 100%|██████████| 782/782 [00:59<00:00, 13.22it/s, acc=79.79%, loss=0.874]
Evaluating: 100%|██████████| 157/157 [00:08<00:00, 1

⏹️ Early stopping at epoch 12
💾 ベストモデル（epoch 9）を保存しました: checkpoints/vgg_gnn_best.pt


Evaluating: 100%|██████████| 157/157 [00:07<00:00, 20.76it/s]


📊 VGG16 + GNN - Top-1 Accuracy: 80.16%, Top-5 Accuracy: 99.08% (Best Epoch: 9)
📈 学習曲線を保存しました: plots/vgg_gnn_summary.png



