In [None]:
#loss_fn_kd 함수 (net)

def loss_fn_kd(outputs, labels, teacher_outputs, params):
    """Compute the knowledge-distillation (KD) loss for a student network.

    The loss is a weighted sum of two terms:

    1. Soft-target term: the KL divergence between the temperature-softened
       teacher distribution ``softmax(teacher_outputs / T)`` and the
       temperature-softened student distribution ``softmax(outputs / T)``,
       scaled by ``alpha * T * T``. The ``T * T`` factor compensates for the
       ``1 / T**2`` scaling that the temperature introduces into the
       soft-target gradients.
    2. Hard-target term: the ordinary cross entropy between ``outputs`` and
       ``labels``, scaled by ``1 - alpha``.

    Args:
        outputs: Student logits; the class axis is ``dim=1``.
        labels: Ground-truth targets, passed unchanged to ``F.cross_entropy``.
        teacher_outputs: Teacher logits; expected to have the same shape as
            ``outputs``.
        params: Object exposing ``alpha`` (weight of the soft-target term)
            and ``temperature`` (softmax temperature ``T``).

    Returns:
        A scalar tensor holding the combined loss.

    Note:
        The KL term uses ``reduction='batchmean'`` (sum over classes, mean
        over the batch), which is the mathematical definition of KL
        divergence. The default ``'mean'`` reduction additionally divides by
        the number of classes and therefore under-weights the soft-target
        term; hyper-parameters tuned against that smaller value may need to
        be revisited.
    """
    alpha = params.alpha
    T = params.temperature

    # KL(teacher || student) on the softened distributions. F.kl_div expects
    # log-probabilities as input and probabilities as target.
    soft_loss = F.kl_div(
        F.log_softmax(outputs / T, dim=1),
        F.softmax(teacher_outputs / T, dim=1),
        reduction='batchmean',
    )

    # Standard supervised loss against the ground-truth labels.
    hard_loss = F.cross_entropy(outputs, labels)

    return soft_loss * (alpha * T * T) + hard_loss * (1. - alpha)




In [None]:
#train_kd 함수 (train)

def train_kd(model, teacher_model, optimizer, loss_fn_kd, dataloader, metrics, params):
    """Train the student ``model`` for one pass over ``dataloader`` with KD.

    For every batch the student and the (frozen) teacher are run on the same
    inputs, the distillation loss is computed with ``loss_fn_kd`` and one
    optimizer step is taken. Every ``params.save_summary_steps`` batches the
    metrics in ``metrics`` are evaluated on the current batch; their means
    over those sampled batches are logged once the pass is finished.

    Args:
        model: Student network being trained; put into train mode.
        teacher_model: Teacher network; put into eval mode and evaluated
            without gradient tracking.
        optimizer: Optimizer for the student; ``zero_grad()`` and ``step()``
            are called once per batch.
        loss_fn_kd: Callable invoked as
            ``loss_fn_kd(student_out, labels, teacher_out, params)`` that
            returns a scalar loss tensor.
        dataloader: Sized iterable yielding ``(inputs, labels)`` batches.
        metrics: Mapping from metric name to a callable
            ``fn(outputs_ndarray, labels_ndarray)``.
        params: Object exposing ``cuda`` (move batches to the GPU when
            truthy) and ``save_summary_steps``; also forwarded to
            ``loss_fn_kd``.

    Returns:
        None. The averaged metrics are reported through ``logging``.
    """
    # Student is trained; the teacher only performs inference.
    model.train()
    teacher_model.eval()

    summ = []
    # Running average of the loss, shown in the progress bar.
    loss_avg = utils.RunningAverage()

    with tqdm(total=len(dataloader)) as t:
        for i, (train_batch, labels_batch) in enumerate(dataloader):
            if params.cuda:
                # `async` is a reserved keyword since Python 3.7; the
                # equivalent tensor-transfer argument is `non_blocking`.
                train_batch = train_batch.cuda(non_blocking=True)
                labels_batch = labels_batch.cuda(non_blocking=True)

            # Forward pass: student with gradients, teacher without.
            output_batch = model(train_batch)
            with torch.no_grad():
                output_teacher_batch = teacher_model(train_batch)
            if params.cuda:
                output_teacher_batch = output_teacher_batch.cuda(non_blocking=True)

            loss = loss_fn_kd(output_batch, labels_batch, output_teacher_batch, params)

            # Clear gradients left over from the previous step, then update.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Extract the Python float once; indexing a 0-dim tensor with
            # `[0]` is an error on current PyTorch.
            loss_value = loss.item()

            # Periodically evaluate the metrics on this batch.
            if i % params.save_summary_steps == 0:
                output_np = output_batch.detach().cpu().numpy()
                labels_np = labels_batch.detach().cpu().numpy()

                summary_batch = {metric: metrics[metric](output_np, labels_np)
                                 for metric in metrics}
                summary_batch['loss'] = loss_value
                summ.append(summary_batch)

            loss_avg.update(loss_value)

            t.set_postfix(loss='{:05.3f}'.format(loss_avg()))
            t.update()

    # An empty dataloader produces no summaries; avoid indexing summ[0].
    if not summ:
        logging.warning("- Train metrics: no batches were processed")
        return

    # Average each recorded metric over the sampled batches.
    metrics_mean = {metric: np.mean([x[metric] for x in summ]) for metric in summ[0]}
    metrics_string = " ; ".join("{}: {:05.3f}".format(k, v) for k, v in metrics_mean.items())
    logging.info("- Train metrics: " + metrics_string)
