# 準備
1. 必要なパッケージのインストール

In [None]:
! pip install torchinfo timm

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torchinfo import summary
import timm
import random
import numpy as np

def set_seed(seed=42):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy's global RNG, the torch CPU RNG and the
    CUDA RNGs (current device and all devices), then configures cuDNN for
    deterministic kernels with autotuning disabled.

    Args:
        seed: Integer seed applied to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    # Deterministic cuDNN kernels; benchmark mode would pick kernels per run.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    print(f"Seed set to {seed}")

# MNISTデータセットの準備（ダウンロード・データ拡張・DataLoader作成）

In [None]:
# Training-time augmentation, applied on the fly every time a sample is loaded
train_transform = transforms.Compose([
    # With probability 0.5, rotate by an angle drawn uniformly from [-10, 10] degrees
    # (so small angles close to 0 are possible too)
    transforms.RandomApply([transforms.RandomRotation(degrees=(-10, 10))], p=0.5),
    # With probability 0.5, apply a random shift + zoom (no additional rotation)
    transforms.RandomApply([transforms.RandomAffine(
        degrees=0,  # rotation is handled by the transform above
        translate=(3/28, 3/28),  # shift of 0..3 px per axis, given as a fraction of the 28 px image
        scale=(0.9, 1.1)  # zoom factor drawn from [0.9, 1.1]
    )], p=0.5),
    transforms.ToTensor(),
    # Standard MNIST normalization constants (mean 0.1307, std 0.3081)
    transforms.Normalize((0.1307,), (0.3081,))
])

# Evaluation uses no augmentation, only the same tensor conversion + normalization
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

training_set = datasets.MNIST("./data", train=True, download=True, transform=train_transform)
test_set = datasets.MNIST("./data", train=False, transform=test_transform)

# Fixed-seed generator so the training shuffle order is reproducible
g = torch.Generator()
g.manual_seed(42)

# Pinned host memory only pays off when batches are copied to a GPU; on a
# CPU-only machine pin_memory=True just emits a warning, so enable it conditionally.
use_pin_memory = torch.cuda.is_available()

train_loader = torch.utils.data.DataLoader(
    training_set,
    batch_size=64,
    shuffle=True,
    generator=g,
    num_workers=2,  # load batches in 2 background worker processes
    pin_memory=use_pin_memory,  # page-locked memory speeds up host->GPU copies
    persistent_workers=True  # keep workers alive across epochs instead of respawning them
)
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=1000,
    num_workers=2,
    pin_memory=use_pin_memory,
    persistent_workers=True
)

# 教師モデルの定義

In [None]:
# Teacher model: any architecture works here, as long as TeacherModel() can be
# built without arguments and maps a (N, 1, 28, 28) batch to (N, 10) logits.
class TeacherModel(nn.Module):
    """VGG-style CNN used as the knowledge-distillation teacher.

    The template previously left ``__init__``/``forward`` as ``pass``, i.e. a
    module with no parameters whose ``forward`` returned ``None``. AdamW cannot
    be built on an empty parameter list and no loss can be computed from
    ``None``. This default is intentionally much larger than the student; feel
    free to replace it.

    Layout for a 28x28 input:
        [conv-bn-relu x2, 64ch]  -> maxpool -> 14x14
        [conv-bn-relu x2, 128ch] -> maxpool -> 7x7
        [conv-bn-relu, 256ch]    -> global average pool -> 1x1
        dropout -> linear -> ``num_classes`` logits

    Args:
        num_classes: Number of output logits (10 for MNIST).
        dropout: Dropout probability applied before the final linear layer.
    """

    def __init__(self, num_classes=10, dropout=0.3):
        super().__init__()
        block = self._conv_bn_relu
        self.features = nn.Sequential(
            block(1, 64),
            block(64, 64),
            nn.MaxPool2d(2),
            block(64, 128),
            block(128, 128),
            nn.MaxPool2d(2),
            block(128, 256),
            nn.AdaptiveAvgPool2d(1),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(dropout),
            nn.Linear(256, num_classes),
        )

    @staticmethod
    def _conv_bn_relu(in_ch, out_ch):
        """3x3 conv (padding 1, keeps H/W) -> BatchNorm -> ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Return class logits of shape ``(N, num_classes)``."""
        return self.classifier(self.features(x))

# Layer-by-layer summary of the teacher for a batch of 64 single-channel 28x28 images.
print(summary(model=TeacherModel(), input_size=(64, 1, 28, 28)))

# 教師モデルの学習

In [None]:
def train_teacher(model, device, train_loader, optimizer, epoch, scaler):
    """Train the teacher for one epoch with (CUDA-only) mixed precision.

    Uses cross-entropy with label smoothing 0.1 and logs the loss every
    100 batches.

    Args:
        model: Network to optimize; switched to train mode.
        device: ``torch.device`` (or device string) the batches are moved to.
        train_loader: Iterable of ``(inputs, labels)`` batches.
        optimizer: Optimizer over ``model``'s parameters.
        epoch: Zero-based epoch index, used only for logging.
        scaler: GradScaler for loss scaling; a disabled scaler is fine on CPU.
    """
    device = torch.device(device)
    # torch.autocast replaces the deprecated torch.cuda.amp.autocast. It is
    # enabled only on CUDA, matching the old CUDA-only context, which disabled
    # itself (with a warning) on CPU.
    use_amp = device.type == "cuda"
    amp_device = "cuda" if use_amp else "cpu"

    model.train()
    for batch_idx, (x, y) in enumerate(train_loader):
        # non_blocking overlaps the host->device copy with compute when the
        # loader pins memory; it has no effect for CPU tensors.
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        with torch.autocast(device_type=amp_device, enabled=use_amp):
            logits = model(x)
            loss = F.cross_entropy(logits, y, label_smoothing=0.1)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        if batch_idx % 100 == 0:
            print(f"[Teacher] Epoch={epoch+1}, Batch={batch_idx+1:03}, Loss={loss.item():.4f}")

def test_teacher(model, device, test_loader):
    """Evaluate the teacher on ``test_loader`` and return its top-1 accuracy.

    Args:
        model: Network to evaluate; switched to eval mode.
        device: ``torch.device`` (or device string) the batches are moved to.
        test_loader: DataLoader of ``(inputs, labels)`` batches; the length of
            its ``.dataset`` is the accuracy denominator.

    Returns:
        Fraction of samples whose argmax logit equals the label.
    """
    device = torch.device(device)
    # torch.autocast replaces the deprecated torch.cuda.amp.autocast; enabled
    # only on CUDA, matching the old CUDA-only context, which disabled itself on CPU.
    use_amp = device.type == "cuda"
    amp_device = "cuda" if use_amp else "cpu"

    model.eval()
    correct = 0
    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            with torch.autocast(device_type=amp_device, enabled=use_amp):
                logits = model(x)

            y_hat = logits.argmax(dim=1, keepdim=True)
            correct += y_hat.eq(y.view_as(y_hat)).sum().item()

    accuracy = correct / len(test_loader.dataset)
    print(f"[Teacher] Test-set accuracy={accuracy:.04f}\n")
    return accuracy

# Run teacher training
def train_teacher_model(epochs=20, seed=42):
    """Train the teacher from scratch and return it with its best weights loaded.

    After every epoch the model is evaluated on ``test_loader``. A CPU copy of
    the weights is kept whenever accuracy improves, and that snapshot is
    restored at the end.

    NOTE(review): the "best" epoch is chosen on the test set, so the reported
    best accuracy is optimistically biased. A held-out validation split would
    give an unbiased estimate.

    Args:
        epochs: Number of training epochs. With 0 epochs, the freshly
            initialized model is returned instead of crashing.
        seed: Seed passed to ``set_seed`` before the model is created.

    Returns:
        The teacher model, on the training device, with the best weights loaded.
    """
    set_seed(seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    teacher = TeacherModel().to(device)
    optimizer = torch.optim.AdamW(teacher.parameters(), lr=0.001)
    # Loss scaling only matters for float16 autocast on CUDA. Disabling it
    # explicitly on CPU avoids the "CUDA is not available" warning.
    scaler = torch.cuda.amp.GradScaler(enabled=device.type == "cuda")

    best_accuracy = 0.0
    best_state = None

    print("=" * 50)
    print("教師モデルの学習開始")
    print("=" * 50)

    for epoch in range(epochs):
        train_teacher(teacher, device, train_loader, optimizer, epoch, scaler)
        accuracy = test_teacher(teacher, device, test_loader)

        # Snapshot the weights on improvement. The first epoch always produces
        # a snapshot, so best_state is never None when at least one epoch ran,
        # even if the accuracy stays at 0.
        if best_state is None or accuracy > best_accuracy:
            best_accuracy = accuracy
            best_state = {k: v.cpu().clone() for k, v in teacher.state_dict().items()}
            print(f"*** ベスト精度更新: {best_accuracy:.04f} (Epoch {epoch+1}) ***\n")

    # Restore the best weights. load_state_dict(None) would raise, so skip it
    # when no epoch ran.
    if best_state is not None:
        teacher.load_state_dict(best_state)
    print("=" * 50)
    print(f"教師モデルの学習完了 (ベスト精度: {best_accuracy:.04f})")
    print("=" * 50)
    print()

    return teacher

# Train the teacher; about 20 epochs is enough for MNIST.
teacher_model = train_teacher_model(epochs=20, seed=42)

# 学生モデルの定義

In [None]:
# Student model
class HSwish(nn.Module):
    """Hard-swish activation: ``x * relu6(x + 3) / 6``."""

    def forward(self, x):
        shifted = F.relu6(x + 3)
        # Multiply first, then divide, to keep the exact evaluation order.
        return (x * shifted) / 6


class MobileNetV4Block(nn.Module):
    """
    MobileNetV4 Universal Inverted Bottleneck (UIB)
    - optional dw1: depthwise BEFORE expansion
    - expand: 1x1
    - optional dw2: depthwise AFTER expansion
    - project: 1x1

    Every conv is followed by BatchNorm. All but the final projection are also
    followed by the activation. A residual connection is added when the block
    keeps both resolution and width (``stride == 1 and inp == oup``).

    ``stride`` is applied exactly once: on dw2 when it is enabled, otherwise on
    dw1. Applying it on both depthwise convs would downsample by ``stride ** 2``.
    Putting the stride on the middle depthwise conv when one exists mirrors
    reference UIB implementations.

    Args:
        inp: Number of input channels.
        oup: Number of output channels.
        expand_ratio: Hidden-width multiplier (``hidden = inp * expand_ratio``).
        kernel_size: Kernel size of the depthwise convs (padding ``kernel_size // 2``).
        stride: Spatial stride of the whole block.
        use_dw1: Add a depthwise conv before the expansion ("ExtraDW").
        use_dw2: Add a depthwise conv after the expansion.
        activation: ``'relu'`` selects in-place ReLU; any other value selects HSwish.

    Raises:
        ValueError: If ``stride != 1`` but both depthwise convs are disabled,
            because the remaining 1x1 convs cannot downsample.
    """
    def __init__(
        self,
        inp,
        oup,
        expand_ratio=4,
        kernel_size=3,
        stride=1,
        use_dw1=True,   # optional depthwise before expand
        use_dw2=True,    # optional depthwise after expand
        activation='hswish'
    ):
        super().__init__()

        if stride != 1 and not (use_dw1 or use_dw2):
            raise ValueError(
                "stride != 1 requires at least one depthwise conv (use_dw1 or use_dw2)"
            )

        hidden_dim = inp * expand_ratio
        self.use_res = (stride == 1 and inp == oup)

        # Downsample only once: on dw2 if it exists, otherwise on dw1.
        dw1_stride = 1 if use_dw2 else stride
        dw2_stride = stride

        if activation == 'relu':
            act = nn.ReLU(inplace=True)
        else:
            act = HSwish()

        layers = []

        # Optional depthwise BEFORE expand (ExtraDW)
        if use_dw1:
            layers += [
                nn.Conv2d(inp, inp, kernel_size, dw1_stride,
                          kernel_size // 2, groups=inp, bias=False),
                nn.BatchNorm2d(inp),
                act,
            ]

        # Expand 1x1
        layers += [
            nn.Conv2d(inp, hidden_dim, 1, 1, bias=False),
            nn.BatchNorm2d(hidden_dim),
            act,
        ]

        # Optional depthwise AFTER expand
        if use_dw2:
            layers += [
                nn.Conv2d(hidden_dim, hidden_dim, kernel_size, dw2_stride,
                          kernel_size // 2, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                act,
            ]

        # Project 1x1 (linear bottleneck: no activation)
        layers += [
            nn.Conv2d(hidden_dim, oup, 1, 1, bias=False),
            nn.BatchNorm2d(oup)
        ]

        self.block = nn.Sequential(*layers)

    def forward(self, x):
        out = self.block(x)
        if self.use_res:
            return out + x
        return out


class MyMnistNet(nn.Module):
    """Compact MobileNetV4-style student network for MNIST.

    Pipeline: stride-2 3x3 stem (1 -> 12 channels; 28x28 -> 14x14 for a 28x28
    input), two UIB blocks at stride 1 (12 -> 16 -> 16 channels, the second one
    residual), global average pooling, and a linear layer to 10 logits.
    """

    def __init__(self):
        super().__init__()

        stem_channels, block_channels, num_classes = 12, 16, 10

        # Stem: 3x3 conv with stride 2 halves the spatial resolution
        self.stem = nn.Sequential(
            nn.Conv2d(1, stem_channels, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(stem_channels),
            HSwish(),
        )

        # Two UIB blocks, built in order: 12 -> 16, then 16 -> 16
        self.blocks = nn.Sequential(*[
            MobileNetV4Block(in_channels, block_channels, stride=1, expand_ratio=2)
            for in_channels in (stem_channels, block_channels)
        ])

        # Head: global average pooling + linear classifier
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(block_channels, num_classes)

    def forward(self, x):
        features = self.blocks(self.stem(x))
        pooled = torch.flatten(self.gap(features), start_dim=1)
        return self.fc(pooled)
# Layer-by-layer summary of the student for a batch of 64 single-channel 28x28 images.
print(summary(model=MyMnistNet(), input_size=(64, 1, 28, 28)))

# 学生モデルの学習（知識蒸留）

In [None]:
def distillation_loss(student_logits, teacher_logits, labels, temperature=3.0, alpha=0.7):
    """Knowledge-distillation loss mixing soft (teacher) and hard (label) targets.

    soft term: KL(teacher || student) between temperature-softened
        distributions, averaged over the batch and scaled by ``temperature**2``
        to keep gradient magnitudes comparable across temperatures.
    hard term: cross-entropy against ``labels`` with label smoothing 0.1.

    Args:
        student_logits: Raw student outputs.
        teacher_logits: Raw teacher outputs, same shape as ``student_logits``.
        labels: Integer class targets.
        temperature: Softening temperature (higher = smoother distributions).
        alpha: Weight of the soft term; the hard term gets ``1 - alpha``.

    Returns:
        Scalar loss tensor ``alpha * soft + (1 - alpha) * hard``.
    """
    t = temperature

    # Soft-target term (KL divergence on softened distributions)
    teacher_probs = F.softmax(teacher_logits / t, dim=1)
    student_log_probs = F.log_softmax(student_logits / t, dim=1)
    kd_term = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean') * (t ** 2)

    # Hard-target term (ordinary cross-entropy on the unsoftened logits)
    ce_term = F.cross_entropy(student_logits, labels, label_smoothing=0.1)

    hard_weight = 1 - alpha
    return alpha * kd_term + hard_weight * ce_term


def train_student(model, teacher, device, train_loader, optimizer, epoch, scaler,
                  temperature=3.0, alpha=0.7):
    """Train the student for one epoch by distilling from ``teacher``.

    Args:
        model: Student network to optimize; switched to train mode.
        teacher: Teacher network; switched to eval mode and run without gradients.
        device: ``torch.device`` (or device string) the batches are moved to.
        train_loader: Iterable of ``(inputs, labels)`` batches.
        optimizer: Optimizer over the student's parameters.
        epoch: Zero-based epoch index, used only for logging.
        scaler: GradScaler for loss scaling; a disabled scaler is fine on CPU.
        temperature: Softening temperature forwarded to ``distillation_loss``.
        alpha: Soft-target weight forwarded to ``distillation_loss``.

    Returns:
        The loss of the most recently logged batch (logging happens every
        100 batches, starting with the first), or 0 if no batch was seen.
    """
    device = torch.device(device)
    # torch.autocast replaces the deprecated torch.cuda.amp.autocast; enabled
    # only on CUDA, matching the old CUDA-only context, which disabled itself on CPU.
    use_amp = device.type == "cuda"
    amp_device = "cuda" if use_amp else "cpu"

    model.train()
    teacher.eval()  # teacher runs in inference mode (BatchNorm/Dropout frozen)
    out = 0

    for batch_idx, (x, y) in enumerate(train_loader):
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        with torch.autocast(device_type=amp_device, enabled=use_amp):
            # Student forward pass (tracked for gradients)
            student_logits = model(x)

            # Teacher forward pass (no gradients needed)
            with torch.no_grad():
                teacher_logits = teacher(x)

            # Distillation loss
            loss = distillation_loss(student_logits, teacher_logits, y,
                                     temperature=temperature, alpha=alpha)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        if batch_idx % 100 == 0:
            print(f"[Student] Epoch={epoch+1}, Batch={batch_idx+1:03}, Loss={loss.item():.4f}")
            out = loss.item()
    return out


def test(model, device, test_loader):
    """Evaluate the student on ``test_loader`` and return its top-1 accuracy.

    Args:
        model: Network to evaluate; switched to eval mode.
        device: ``torch.device`` (or device string) the batches are moved to.
        test_loader: DataLoader of ``(inputs, labels)`` batches; the length of
            its ``.dataset`` is the accuracy denominator.

    Returns:
        Fraction of samples whose argmax logit equals the label.
    """
    device = torch.device(device)
    # torch.autocast replaces the deprecated torch.cuda.amp.autocast; enabled
    # only on CUDA, matching the old CUDA-only context, which disabled itself on CPU.
    use_amp = device.type == "cuda"
    amp_device = "cuda" if use_amp else "cpu"

    model.eval()
    correct = 0
    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            with torch.autocast(device_type=amp_device, enabled=use_amp):
                p_y_hat = model(x)

            y_hat = p_y_hat.argmax(dim=1, keepdim=True)
            correct += y_hat.eq(y.view_as(y_hat)).sum().item()

    accuracy = correct / len(test_loader.dataset)
    print(f"[Student] Test-set accuracy={accuracy:.04f}\n")
    return accuracy


def main_distillation(teacher, seed=42, epochs=100):
    """Distill ``teacher`` into a freshly initialized ``MyMnistNet`` student.

    Each epoch runs ``train_student`` on ``train_loader`` and then evaluates
    the student with ``test`` on ``test_loader``.

    Args:
        teacher: Trained teacher network; moved to the training device.
        seed: Seed passed to ``set_seed`` before the student is created.
        epochs: Number of distillation epochs (default 100, the previously
            hard-coded value).

    Returns:
        The trained student model (final-epoch weights). Returning it lets
        later cells evaluate, save or export the student; previously the
        trained weights were discarded when the function returned ``None``.
    """
    set_seed(seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    teacher = teacher.to(device)
    student = MyMnistNet().to(device)
    optimizer = torch.optim.AdamW(student.parameters(), lr=0.001)
    # Loss scaling only matters for float16 autocast on CUDA. Disabling it
    # explicitly on CPU avoids the "CUDA is not available" warning.
    scaler = torch.cuda.amp.GradScaler(enabled=device.type == "cuda")

    print("=" * 50)
    print("学生モデルの知識蒸留学習開始")
    print("=" * 50)

    for epoch in range(epochs):
        train_student(student, teacher, device, train_loader, optimizer, epoch, scaler)
        test(student, device, test_loader)

    return student


# Train the student by knowledge distillation from the trained teacher.
main_distillation(teacher_model, seed=42)