In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim

import data_loader
import evaluate
import models
import train

In [2]:
# Data: Fashion-MNIST train/test iterators. `resize=224` presumably upsamples
# each image to 224x224 (the input size the models below are exercised with).
batch_size = 256
train_iter, test_iter = data_loader.load_data_fashion_mnist(
    batch_size,
    resize=224,
)

In [3]:
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed PyTorch's global RNG (torch.manual_seed covers the CPU and all CUDA
# devices) so any random weight initialisation done by the model constructors
# is reproducible across Restart & Run All.
# NOTE(review): the data loaders were built in an earlier cell; if they shuffle
# via the global torch RNG (unverified), their order is affected by this too.
SEED = 0
torch.manual_seed(SEED)

# Teacher and student, both moved to `device`.
teacher = models.teacher5().to(device)
student = models.FitNet1().to(device)

In [None]:
# Evaluate the teacher on the test iterator; the accuracy is multiplied by 100
# for display (assumes `evaluate.evaluate` returns it as a fraction — confirm).
teacher_loss, teacher_acc = evaluate.evaluate(teacher, test_iter, device)
print("Test Loss: {}, Acc: {}%".format(teacher_loss, 100. * teacher_acc))

In [9]:
# Train the teacher with plain cross-entropy and SGD + momentum.
# NOTE(review): the helper is `train_student`, but the model (and the
# optimizer's parameters) here are the teacher's — presumably the helper is a
# generic training loop; confirm this is intended.
train.train_student(
    teacher,
    train_iter,
    test_iter,
    device=device,
    num_epochs=20,
    criterion=nn.CrossEntropyLoss(),
    optimizer=optim.SGD(
        teacher.parameters(),
        lr=1e-3,
        momentum=0.9,
    ),
)

epoch 1, loss 0.1940, train acc 0.928, test loss 0.2637, test acc 0.907
epoch 2, loss 0.1897, train acc 0.931, test loss 0.2672, test acc 0.903
epoch 3, loss 0.1862, train acc 0.932, test loss 0.2712, test acc 0.906
epoch 4, loss 0.1845, train acc 0.932, test loss 0.2632, test acc 0.909
epoch 5, loss 0.1839, train acc 0.933, test loss 0.2731, test acc 0.904
epoch 6, loss 0.1806, train acc 0.934, test loss 0.2668, test acc 0.908
epoch 7, loss 0.1824, train acc 0.933, test loss 0.2651, test acc 0.905
epoch 8, loss 0.1759, train acc 0.936, test loss 0.2706, test acc 0.901
epoch 9, loss 0.1719, train acc 0.937, test loss 0.2682, test acc 0.907
epoch 10, loss 0.1752, train acc 0.936, test loss 0.2563, test acc 0.910
epoch 11, loss 0.1679, train acc 0.939, test loss 0.2541, test acc 0.910
epoch 12, loss 0.1657, train acc 0.940, test loss 0.2519, test acc 0.913
epoch 13, loss 0.1699, train acc 0.938, test loss 0.2566, test acc 0.908
epoch 14, loss 0.1635, train acc 0.940, test loss 0.2498, te

In [8]:
# Train the FitNet1 student on hard labels only: cross-entropy loss, and no
# teacher model is passed to this call.
train.train_student(
    student,
    train_iter,
    test_iter,
    device=device,
    num_epochs=20,
    criterion=nn.CrossEntropyLoss(),
    optimizer=optim.SGD(
        student.parameters(),
        lr=1e-3,
        momentum=0.9,
    ),
)

epoch 1, loss 0.0759, train acc 0.972, test loss 0.3060, test acc 0.913
epoch 2, loss 0.0660, train acc 0.976, test loss 0.3398, test acc 0.911
epoch 3, loss 0.0693, train acc 0.975, test loss 0.3274, test acc 0.912
epoch 4, loss 0.0610, train acc 0.978, test loss 0.3654, test acc 0.910
epoch 5, loss 0.0631, train acc 0.977, test loss 0.3634, test acc 0.913
epoch 6, loss 0.0592, train acc 0.978, test loss 0.3612, test acc 0.913
epoch 7, loss 0.0673, train acc 0.975, test loss 0.3287, test acc 0.915
epoch 8, loss 0.0537, train acc 0.980, test loss 0.3497, test acc 0.911
epoch 9, loss 0.0521, train acc 0.981, test loss 0.3814, test acc 0.912
epoch 10, loss 0.0427, train acc 0.985, test loss 0.4104, test acc 0.911
epoch 11, loss 0.0516, train acc 0.981, test loss 0.4043, test acc 0.909
epoch 12, loss 0.0462, train acc 0.983, test loss 0.3865, test acc 0.913
epoch 13, loss 0.0394, train acc 0.986, test loss 0.4446, test acc 0.912
epoch 14, loss 0.0454, train acc 0.983, test loss 0.4030, te

In [None]:
# Keep at most the first five top-level child modules of the current `student`
# and chain them in an nn.Sequential. nn.Sequential only runs the children in
# order, so any extra logic in the original model's forward() (not visible
# here) is not carried over. This rebinds `student`, so re-running the cell
# slices the already-truncated model.
leading_children = list(student.children())[:5]
student = nn.Sequential(*leading_children)

In [None]:
# Replace the first layer with a freshly initialised single-channel
# 7x7 / stride-2 convolution; any trained weights in the old layer are dropped.
# `.to(device)` puts it on the same device as the rest of `student` (moved
# there when it was created). A bare nn.Conv2d lives on the CPU and would fail
# with a device mismatch on the first forward pass when running on CUDA.
# NOTE(review): whether 64 output channels match what student[1] expects
# depends on FitNet1's layout — confirm.
student[0] = nn.Conv2d(
    1, 64, kernel_size=7, stride=2, padding=3, bias=False
).to(device)
student

In [None]:
import torch.nn as nn


def conv_block(in_channels, out_channels, kernel_size, stride, padding):
    """Return ``Conv2d -> BatchNorm2d -> ReLU`` wrapped in one ``nn.Sequential``.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution (and normalised by
            the BatchNorm).
        kernel_size, stride, padding: forwarded unchanged to ``nn.Conv2d``.
    """
    return nn.Sequential(
        # No conv bias: the following BatchNorm2d has its own learnable shift,
        # which makes a bias here redundant.
        nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        ),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
    )


class teacher5_big(nn.Module):
    """128-channel CNN: three conv blocks with max-pooling, then a linear head.

    The classifier's input width is derived from ``input_size`` instead of
    being hard-coded, so the model can be built for inputs other than 224x224.
    With the defaults it is identical to the original ``Linear(56448, 10)``
    head (128 channels x 21 x 21 after the last pool).

    Args:
        input_size: side length of the square, single-channel input images.
        num_classes: number of output logits.

    Raises:
        ValueError: if ``input_size`` is too small to survive the feature stack.
    """

    def __init__(self, input_size: int = 224, num_classes: int = 10):
        super().__init__()
        self.features = nn.Sequential(
            conv_block(1, 128, kernel_size=7, stride=2, padding=3),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            conv_block(128, 128, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            conv_block(128, 128, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=8, stride=1, padding=0),
        )
        channels, side = self._feature_shape(input_size)
        self.classifier = nn.Sequential(
            nn.Flatten(), nn.Linear(channels * side * side, num_classes)
        )

    def _feature_shape(self, input_size: int) -> tuple:
        """Return ``(channels, side)`` that ``self.features`` produces for a
        square ``input_size`` x ``input_size`` input.

        Applies the Conv2d/MaxPool2d output-size formula
        ``(side + 2*padding - kernel) // stride + 1`` to each such layer in
        forward order. Assumes dilation 1 and ``ceil_mode=False``, which holds
        for every layer built in ``__init__``.
        """
        channels, side = 1, input_size
        # modules() walks the nested Sequentials depth-first in insertion
        # order, i.e. in the order forward() applies the layers.
        for layer in self.features.modules():
            if not isinstance(layer, (nn.Conv2d, nn.MaxPool2d)):
                continue
            # Conv2d stores these as tuples, MaxPool2d (as built here) as ints;
            # all layers are square, so the first element is enough.
            kernel, stride, padding = (
                value if isinstance(value, int) else value[0]
                for value in (layer.kernel_size, layer.stride, layer.padding)
            )
            side = (side + 2 * padding - kernel) // stride + 1
            if side < 1:
                raise ValueError(
                    f"input_size={input_size} is too small for teacher5_big: "
                    f"the feature map shrinks to {side} at {layer}"
                )
            if isinstance(layer, nn.Conv2d):
                channels = layer.out_channels
        return channels, side

    def forward(self, x):
        """Feature stack, then flatten + linear head; returns raw logits."""
        x = self.features(x)
        x = self.classifier(x)
        return x


# Create a model instance.
# NOTE(review): this "big teacher" is bound to the name `student`, replacing
# the model used by the earlier cells; the next cell runs it under that name.
student = teacher5_big()

In [None]:
# Shape smoke test: push one random 224x224 single-channel image through
# `student`. Only the output shape matters here, so no_grad() skips recording
# an autograd graph.
# NOTE(review): a freshly built nn.Module is in training mode, so this forward
# pass still updates any BatchNorm running statistics.
X = torch.rand(1, 1, 224, 224)
with torch.no_grad():
    logits = student(X)
assert tuple(logits.shape) == (1, 10), f"unexpected output shape {tuple(logits.shape)}"
logits