In [10]:
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

# 超参数

In [11]:
epochs = 50
batch_size = 16
save_steps = 10
num_workers = 4
lr = 0.001
lr_step = 40

alpha = 0.7
temperature = 7

# 数据

In [12]:
import os
import json
import random
from PIL import Image
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms


def read_split_data(root, val_rate=0.2):
    
    assert os.path.exists(root), "dataset root: {} does not exist.".format(root)

    flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
    flower_class.sort()
    class_indices = dict((k, v) for v, k in enumerate(flower_class))
    print(class_indices) 

    train_images_path = []
    train_images_label = []
    val_images_path = []
    val_images_label = []
    every_class_num = []
   
    for cla in flower_class:
        cla_path = os.path.join(root, cla)
        images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)]
        images.sort()
        image_class = class_indices[cla]
        
        every_class_num.append(len(images))
        val_path = random.sample(images, k=int(len(images) * val_rate))
        
        # 划分训练集 和 验证集
        for img_path in images:
            if img_path in val_path:  
                val_images_path.append(img_path)
                val_images_label.append(image_class)
            else:  
                train_images_path.append(img_path)
                train_images_label.append(image_class)

    print("{} images were found in the dataset.".format(sum(every_class_num)))
    print("{} images for training.".format(len(train_images_path)))
    print("{} images for validation.".format(len(val_images_path)))
    assert len(train_images_path) > 0, "number of training images must greater than 0."
    assert len(val_images_path) > 0, "number of validation images must greater than 0."

    return train_images_path, train_images_label, val_images_path, val_images_label


class MyDataSet(Dataset):

    def __init__(self, images_path, images_class, transform=None):
        self.images_path = images_path
        self.images_class = images_class
        self.transform = transform

    def __len__(self):
        return len(self.images_path)

    def __getitem__(self, item):
        img = Image.open(self.images_path[item])
        label = self.images_class[item]

        if self.transform is not None:
            img = self.transform(img)

        return torch.as_tensor(img), torch.as_tensor(label)


data_root = "/kaggle/input/flowers/flower_photos"
train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(data_root)

# =============================== Transform ===============================
img_size = 224
train_transform = transforms.Compose([transforms.RandomResizedCrop(img_size),
                                        transforms.RandomHorizontalFlip(),
                                        transforms.ToTensor(),
                                        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

val_transform = transforms.Compose([transforms.Resize(int(img_size * 1.143)),
                                      transforms.CenterCrop(img_size),
                                      transforms.ToTensor(),
                                      transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

# =============================== DataSet ===============================

train_dataset = MyDataSet(images_path=train_images_path,
                          images_class=train_images_label,
                          transform=train_transform)

val_dataset = MyDataSet(images_path=val_images_path,
                        images_class=val_images_label,
                        transform=val_transform)

# =============================== DataLoader ===============================

train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           pin_memory=True,
                                           num_workers=num_workers)

val_loader = torch.utils.data.DataLoader(val_dataset,
                                         batch_size=batch_size,
                                         shuffle=False,
                                         pin_memory=True,
                                         num_workers=num_workers)

print("len(train_loader) = {}".format(len(train_loader)))
print("len(val_loader) = {}".format(len(val_loader)))

{'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}
3670 images were found in the dataset.
2939 images for training.
731 images for validation.
len(train_loader) = 184
len(val_loader) = 46


# Train and Validation

In [13]:
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

class RunningAverage():

    def __init__(self):
        self.steps = 0
        self.loss_sum = 0
        self.acc_sum = 0
    
    def update(self, loss, acc):
        self.loss_sum += loss
        self.acc_sum += acc
        self.steps += 1
    
    def __call__(self):
        return self.loss_sum/float(self.steps), self.acc_sum/float(self.steps)



def kd_train_and_evaluate(teacher_model, student_model, train_dataloader, val_dataloader, criteria, optimizer, scheduler, alpha, temperature):
    
    best_val_acc = 0.0
    for epoch in range(epochs):
        
        print("Epoch {}/{}".format(epoch + 1, epochs))

        # ---------- train ------------
        
        student_model.train()
        metric_avg = RunningAverage()
        
        for i, (train_batch, labels_batch) in enumerate(train_dataloader):
            train_batch, labels_batch = train_batch.to(device), labels_batch.to(device)
            
            student_outputs = student_model(train_batch)
            
            with torch.no_grad():
                teacher_outputs = teacher_model(train_batch)
                
            loss = criteria(student_outputs, teacher_outputs, labels_batch, alpha, temperature)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            if i % save_steps == 0:
                student_outputs = student_outputs.data.cpu().numpy()
                labels_batch = labels_batch.data.cpu().numpy()

                predict_labels = np.argmax(student_outputs, axis=1)
                acc = np.sum(predict_labels == labels_batch) / float(labels_batch.size)
                metric_avg.update(loss.item(), acc)
            
        scheduler.step()
        train_loss, train_acc = metric_avg()
        print("- Train metrics: loss={}, acc={}".format(train_loss, train_acc))      

        # ---------- validate ------------

        student_model.eval()
        metric_avg = RunningAverage()

        for val_batch, labels_batch in val_dataloader:
            val_batch, labels_batch = val_batch.to(device), labels_batch.to(device)

            student_outputs = student_model(val_batch)
            loss = 0
            
            student_outputs = student_outputs.data.cpu().numpy()
            labels_batch = labels_batch.data.cpu().numpy()

            predict_labels = np.argmax(student_outputs, axis=1)
            acc = np.sum(predict_labels == labels_batch) / float(labels_batch.size)
            metric_avg.update(loss, acc)
        _, val_acc = metric_avg()
        print("- Validate metrics: acc={}".format(val_acc))

# 教师网络

In [14]:
import torch.nn as nn
import torchvision.models as models

teacher_model = nn.Sequential(models.resnet50(),
                              nn.Dropout(0.5),
                              nn.ReLU(),
                              nn.Linear(in_features=1000, out_features=5, bias=True))

checkpoint = torch.load('/kaggle/input/resnet-50-weights/best.pth')
teacher_model.load_state_dict(checkpoint["state_dict"])
teacher_model.to(device)

print("prepare the teacher model --- done")

prepare the teacher model --- done


# 学生网络

In [15]:
student_model = nn.Sequential(models.resnet34(),
                         nn.ReLU(),
                         nn.Dropout(0.5),
                         nn.Linear(in_features=1000, out_features=5, bias=True))

student_model.to(device)
print("prepare the student model --- done")

prepare the student model --- done


# 损失函数

In [16]:
def loss_fn_kd(student_outputs, teacher_outputs, labels, alpha, temperature):

    alpha = alpha
    T = temperature

    KD_loss = nn.KLDivLoss(reduction='batchmean')(F.log_softmax(student_outputs/T, dim=1),
                                                  F.softmax(teacher_outputs/T, dim=1)) * (alpha * T * T) + \
              F.cross_entropy(student_outputs, labels) * (1.- alpha)

    return KD_loss

# 优化器 与 学习率调度器

In [17]:
import torchvision.models as models
import torch.nn.functional as F

optimizer = optim.SGD(student_model.parameters(), lr=lr, momentum=0.9)
scheduler = StepLR(optimizer, step_size=lr_step, gamma=0.1)

# 开始训练

In [18]:
kd_train_and_evaluate(teacher_model, student_model, train_loader, val_loader, loss_fn_kd, optimizer, scheduler, alpha, temperature)

Epoch 1/50
- Train metrics: loss=9.806779359516344, acc=0.4473684210526316
- Validate metrics: acc=0.5328557312252965
Epoch 2/50
- Train metrics: loss=8.057040189441881, acc=0.5657894736842105
- Validate metrics: acc=0.5926383399209486
Epoch 3/50
- Train metrics: loss=7.9696633941248844, acc=0.5657894736842105
- Validate metrics: acc=0.6131422924901185
Epoch 4/50
- Train metrics: loss=7.228650682850888, acc=0.5921052631578947
- Validate metrics: acc=0.6904644268774703
Epoch 5/50
- Train metrics: loss=7.567201990830271, acc=0.5921052631578947
- Validate metrics: acc=0.6957756916996047
Epoch 6/50
- Train metrics: loss=6.20396812338578, acc=0.6578947368421053
- Validate metrics: acc=0.6341403162055336
Epoch 7/50
- Train metrics: loss=5.5960291310360555, acc=0.6447368421052632
- Validate metrics: acc=0.7120800395256918
Epoch 8/50
- Train metrics: loss=5.658480305420725, acc=0.6710526315789473
- Validate metrics: acc=0.7251729249011858
Epoch 9/50
- Train metrics: loss=6.2022033114182324, ac