In [1]:
import os
import yaml
import time
import datetime
import warnings
from tqdm import tqdm

import torch
import torch.nn.functional as F
import timm
import pandas as pd
import numpy as np
import torch.nn as nn

from box import Box
from torch.utils.data import DataLoader
from loguru import logger
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, CosineAnnealingLR

import sys
sys.path.append("../scr/")
sys.path.append("..")
from utils.create_dataset import BirdDataset
from utils.base_utils import set_seed
from utils.metrics import validation_epoch_end

In [2]:
# Training configuration for the distillation run.
class Config:
    # Run on the GPU when one is visible, otherwise fall back to the CPU.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    debug = False  # when True, the training loop stops after a handful of batches
    seed = 1771  # passed to set_seed right after the config is built
    path_to_files_base = "../data"  # folder holding data.csv and the audio files
    batch_size = 4
    optimizer_lr = 0.006
    optimizer_wd = 0
    # Descriptive labels only: the cells below construct the scheduler,
    # loss and optimizer objects explicitly rather than reading these strings.
    scheduler = "CosineAnnealingWarmRestarts"
    metric = "custom"
    loss_f = "nn.BCEWithLogitsLoss()"
    optimizer = "Adam"
    epochs = 5
    num_workers = 0  # DataLoader workers; 0 loads batches in the main process


config = Config()
set_seed(seed=Config.seed)

[32m2023-10-23 17:25:46.717[0m | [1mINFO    [0m | [36mutils.base_utils[0m:[36mset_seed[0m:[36m19[0m - [1mSet seed: 1771[0m


In [3]:
# Back to our dataset.
# The teacher model was tf_efficientnet_b4;
# tf_efficientnet_b0 will be the student model.
# Load the teacher, and split the dataset in the same proportion that was used to train the teacher.

In [4]:
# Load the models: the fine-tuned tf_efficientnet_b4 teacher and an
# ImageNet-pretrained tf_efficientnet_b0 student with a new 264-way head.

# The checkpoint is a whole pickled nn.Module (it is used directly as a model
# below), so it must be loaded with weights_only=False; recent PyTorch
# releases default to weights_only=True. Unpickling can execute arbitrary
# code — only load checkpoints you produced yourself.
model_teacher = torch.load(
    "../experiment/23_October_2023_14_17/model_tf_efficientnet_b4_last_version.pt",
    map_location=config.device,
    weights_only=False,
).to(config.device)

model_student = timm.create_model("tf_efficientnet_b0", pretrained=True)
# Swap in the 264-class head *before* moving to the device so the freshly
# created layer is moved together with the rest of the network.
model_student.classifier = nn.Sequential(
    nn.Linear(model_student.classifier.in_features, 264)
)
model_student.to(config.device)

# The teacher only provides soft targets: keep it in eval mode and freeze its
# parameters, otherwise every loss.backward() would also compute (and, since
# the optimizer never zeroes them, endlessly accumulate) teacher gradients.
model_teacher.eval()
model_teacher.requires_grad_(False)

metric = validation_epoch_end
optimizer = torch.optim.Adam(model_student.parameters(),
                             lr=config.optimizer_lr,
                             weight_decay=config.optimizer_wd
                             )
# NOTE(review): not used by the training loop below — the distillation helper
# builds its own BCEWithLogitsLoss.
loss_f = nn.BCEWithLogitsLoss()

logger.info(f"Scheduler - {config.scheduler}")
# NOTE(review): the training loop calls scheduler.step() after every batch,
# so T_0 / T_mult are counted in batches (restarts after 10, 30, 70, ...
# steps), not epochs — confirm this is intended.
scheduler = CosineAnnealingWarmRestarts(optimizer,
                                        T_0=10,
                                        T_mult=2,
                                        eta_min=0.000001,
                                        last_epoch=-1)

print('OK')

[32m2023-10-23 17:25:47.535[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m19[0m - [1mScheduler - CosineAnnealingWarmRestarts[0m


OK


In [5]:
class ImageClassificationKnowledgeDistillationTrainer():
    """Loss helper for distilling a (frozen) teacher classifier into a student.

    The combined loss is::

        alpha * BCEWithLogits(student_logits, y)
        + (1 - alpha) * T**2 * KL(softmax(teacher_logits / T) || softmax(student_logits / T))

    Args:
        teacher_model: network whose logits serve as soft targets; it is only
            run forward, never trained.
        alpha: weight of the hard-label BCE term; ``1 - alpha`` weights the
            distillation term.
        temperature: softening temperature ``T`` applied to both logit sets.
    """

    def __init__(self, teacher_model, alpha, temperature):
        self.teacher_model = teacher_model
        self.alpha = alpha
        self.temperature = temperature
        # Loss modules hold no state between calls, so build them once here
        # instead of on every batch.
        self.loss_ce_fct = nn.BCEWithLogitsLoss()
        self.loss_kd_fct = nn.KLDivLoss(reduction="batchmean")

    def compute_loss(self, model, X_batch, y_batch, return_outputs=False):
        """Compute the distillation loss for one batch.

        Args:
            model: the student network being trained.
            X_batch: input batch fed to both student and teacher.
            y_batch: hard targets; must have the same shape as the student
                logits (required by ``BCEWithLogitsLoss``).
            return_outputs: when True, also return the student logits.

        Returns:
            The scalar loss tensor, or ``(loss, student_logits)`` when
            ``return_outputs`` is True.
        """
        # Call the module (not .forward) so registered hooks still run.
        outputs_student = model(X_batch)
        loss_ce = self.loss_ce_fct(outputs_student.float(), y_batch)

        # The teacher only supplies targets. Without no_grad its forward would
        # build an autograd graph and loss.backward() would compute gradients
        # for every teacher parameter.
        with torch.no_grad():
            outputs_teacher = self.teacher_model(X_batch)

        # NOTE(review): softmax-based KD treats the outputs as mutually
        # exclusive classes, while the hard-label term uses per-class sigmoids
        # (BCE). A per-class sigmoid distillation term may fit this setup
        # better — confirm the intended formulation.
        temperature = self.temperature
        loss_kd = temperature ** 2 * self.loss_kd_fct(
            F.log_softmax(outputs_student / temperature, dim=-1),
            F.softmax(outputs_teacher / temperature, dim=-1)
        )

        # Combine classification loss and distillation loss with weight alpha.
        loss = self.alpha * loss_ce + (1. - self.alpha) * loss_kd

        return (loss, outputs_student) if return_outputs else loss
        
# Equal weighting of the hard-label term and the distillation term, with
# logits softened by a temperature of 2.
KD_ALPHA = 0.5
KD_TEMPERATURE = 2.0
trainer_for_find_loss = ImageClassificationKnowledgeDistillationTrainer(
    teacher_model=model_teacher, alpha=KD_ALPHA, temperature=KD_TEMPERATURE
)

In [6]:
# Load the metadata table; fold 3 is held out for validation and every other
# fold is used for training.
df = pd.read_csv("../data/data.csv")
is_valid_fold = df.fold == 3
df_train = df[~is_valid_fold].reset_index(drop=True)
df_test = df[is_valid_fold].reset_index(drop=True)

logger.info(f"Size df_train- {df_train.shape[0]}")
logger.info(f"Size df_test- {df_test.shape[0]}")

# Both splits read audio from the same base folder.
dataset_train, dataset_test = (
    BirdDataset(df=split_df, path_to_folder_with_audio=config.path_to_files_base)
    for split_df in (df_train, df_test)
)

# Shared loader settings; only the training loader reshuffles each epoch.
loader_kwargs = dict(batch_size=config.batch_size, num_workers=config.num_workers)
train_loader = DataLoader(dataset_train, shuffle=True, **loader_kwargs)
valid_loader = DataLoader(dataset_test, **loader_kwargs)

[32m2023-10-23 17:25:47.588[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m7[0m - [1mSize df_train- 12326[0m
[32m2023-10-23 17:25:47.589[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m8[0m - [1mSize df_test- 3082[0m


In [None]:
# Distillation training loop: each epoch optimises the student on the
# combined BCE + KD loss, then evaluates it on the held-out fold.
# NOTE(review): scheduler.step() runs once per *batch*, so the warm-restart
# period T_0 is measured in batches — confirm this is intended.
for epoch_i in range(1, config.epochs + 1):
    start = time.time()
    logger.info(f'---------------------epoch:{epoch_i}/{config.epochs}---------------------')

    # Running sums of per-sample loss and sample counts. Plain Python floats
    # keep no autograd graph alive across batches.
    train_loss_sum, n_train_seen = 0.0, 0
    val_loss_sum, n_val_seen = 0.0, 0
    # Per-batch prediction/target arrays, concatenated once after validation.
    predicted_chunks, true_chunks = [], []

    ############## Train #############
    model_student.train()
    train_pbar = tqdm(train_loader, desc="Training")
    for train_step_i, batch in enumerate(train_pbar, start=1):
        X_batch = batch[0].to(config.device)
        y_batch = batch[1].to(config.device)

        optimizer.zero_grad()
        loss, res = trainer_for_find_loss.compute_loss(
            model=model_student,
            X_batch=X_batch,
            y_batch=y_batch,
            return_outputs=True
        )

        if torch.cuda.is_available():
            train_pbar.set_postfix(gpu_load=f"{torch.cuda.memory_allocated() / 1024 ** 3:.2f}GB",
                                   loss=f"{loss.item():.4f}")
        else:
            train_pbar.set_postfix(loss=f"{loss.item():.4f}")

        loss.backward()
        optimizer.step()

        # .item() detaches the value: summing the tensor itself would chain
        # every batch's graph into one ever-growing expression.
        train_loss_sum += loss.item() * len(y_batch)
        n_train_seen += len(y_batch)
        del batch, res

        if config.scheduler:
            scheduler.step()

        # Debug mode: stop after 6 training batches.
        if config.debug and train_step_i > 5:
            break

    model_student.eval()

    ########## VALIDATION ###############
    with torch.no_grad():
        for val_step_i, batch in enumerate(valid_loader, start=1):
            X_batch = batch[0].to(config.device)
            y_batch = batch[1].to(config.device)

            loss, res = trainer_for_find_loss.compute_loss(
                model=model_student,
                X_batch=X_batch,
                y_batch=y_batch,
                return_outputs=True
            )

            val_loss_sum += loss.item() * len(y_batch)
            n_val_seen += len(y_batch)

            # y_batch already matches the logits' shape (BCEWithLogitsLoss
            # requires it), so keep it as-is; the former unsqueeze/squeeze
            # round trip collapsed a 1-sample batch to 1-D and broke the
            # concatenation below.
            predicted_chunks.append(res.detach().cpu().sigmoid().numpy())
            true_chunks.append(y_batch.detach().cpu().numpy())

            del batch, res

            # Debug mode: evaluate at most 11 validation batches (own counter,
            # independent of how many training batches ran).
            if config.debug and val_step_i > 10:
                break

    torch.cuda.empty_cache()

    # Average over the samples actually processed; this equals the dataset
    # size except when debug mode stops early.
    avg_train_loss = train_loss_sum / max(n_train_seen, 1)
    avg_val_loss = val_loss_sum / max(n_val_seen, 1)

    # One concatenation instead of re-concatenating on every batch (quadratic).
    all_predicted_labels = np.concatenate(predicted_chunks, axis=0)
    all_true_labels = np.squeeze(np.concatenate(true_chunks, axis=0))
    # Targets strictly between 0 and 1 are scored as negatives by the metric.
    mask = (all_true_labels > 0) & (all_true_labels < 1)
    all_true_labels[mask] = 0
    avg_metric = metric(all_true_labels, all_predicted_labels)

    logger.info(f'epoch: {epoch_i}')

    logger.info("loss_train: %0.4f| loss_valid: %0.4f|" % (avg_train_loss, avg_val_loss))
    for m in avg_metric:
        logger.info(f"metric {m} : {avg_metric[m]:.<5g}")

    elapsed_time = time.time() - start
    hours = int(elapsed_time // 3600)
    minutes = int((elapsed_time % 3600) // 60)
    seconds = int(elapsed_time % 60)
    logger.info(f"Elapsed time: {hours:02d}:{minutes:02d}:{seconds:02d}")


[32m2023-10-23 17:25:47.608[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m4[0m - [1m---------------------epoch:1/5---------------------[0m
Training: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 3082/3082 [19:51<00:00,  2.59it/s, gpu_load=2.13GB, loss=0.7887]
[32m2023-10-23 17:50:05.730[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m96[0m - [1mepoch: 1[0m
[32m2023-10-23 17:50:05.732[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m98[0m - [1mloss_train: 1.8046| loss_valid: 1.3870|[0m
[32m2023-10-23 17:50:05.739[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m100[0m - [1mmetric val_RMAP : 0.283333[0m
[32m2023-10-23 17:50:05.740[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m106[0m - [1mElapsed time: 00:24:18[0m
[32m2023-10-23 17:50:05.740[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m4[0m - [1m----

In [None]:
# Per-epoch results of the teacher model (tf_efficientnet_b4), kept for
# comparison with the distilled student.
# NOTE(review): the teacher was presumably trained on BCE alone, while the
# student's logged loss also includes the T**2-scaled KD term, so compare
# val_RMAP rather than the loss columns.
_teacher_history = [
    # (epoch, loss_train, loss_valid, val_RMAP)
    (1, 0.0236, 0.0198, 0.277),
    (2, 0.0192, 0.0199, 0.300846),
    (3, 0.0177, 0.0174, 0.454),
    (4, 0.0175, 0.0240, 0.364596),
    (5, 0.0167, 0.0376, 0.473289),
]
results_teacher = [
    {"epoch": epoch, "loss_train": loss_train, "loss_valid": loss_valid, "val_RMAP": rmap}
    for epoch, loss_train, loss_valid, rmap in _teacher_history
]