# In [None]:
import torch
import random
import os
import sys
import logging
import numpy as np
import pandas as pd
from shutil import copy
from datetime import datetime
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, accuracy_score
import seaborn as sns
from sklearn.metrics import confusion_matrix

def plot_confusion_matrix(y_true, y_pred, classes, cmap_name='Blues'):
    """
    Draw the confusion matrix of a set of predictions as an annotated heatmap.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels.
        classes: Class labels used as tick labels on both axes.
        cmap_name: Name of the colormap; defaults to 'Blues'.
    """
    # Compute the matrix before touching matplotlib so a bad input fails early.
    matrix = confusion_matrix(y_true, y_pred)

    heatmap_options = dict(
        annot=True,
        fmt='d',
        cmap=cmap_name,
        cbar=False,
        xticklabels=classes,
        yticklabels=classes,
    )

    plt.figure(figsize=(8, 6))
    sns.heatmap(matrix, **heatmap_options)
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted labels')
    plt.ylabel('True labels')
    plt.show()



def compute_class_frequencies(targets, num_classes):
    """
    Return normalized inverse-count weights, one entry per class.

    Despite the name, the result is not the raw class frequency: each class
    gets the reciprocal of its sample count, and the reciprocals are then
    rescaled so that they sum to 1.

    Args:
        targets: Tensor of class indices, passed to ``torch.bincount``.
        num_classes: Minimum number of entries in the returned tensor.

    Returns:
        Tensor of weights that sums to 1.
    """
    counts = torch.bincount(targets, minlength=num_classes)

    # The small epsilon keeps the reciprocal finite for classes with no samples.
    inverse_counts = 1.0 / (counts + 1e-6)

    return inverse_counts / inverse_counts.sum()

def count_parameters(model):
    """Return the total element count of the parameters of ``model`` that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


class AverageMeter(object):
    """
    Keep the latest value, weighted sum, sample count and running average.

    Attributes (all reset to 0 by ``reset``):
        val: Value passed to the most recent ``update`` call.
        sum: Accumulated ``val * n`` over all updates.
        count: Accumulated ``n`` over all updates.
        avg: ``sum / count``, or 0 while ``count`` is still zero.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """
        Record a new observation.

        Args:
            val: The observed value.
            n: Weight of the observation (added to ``count``); defaults to 1.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # An update with n=0 on an empty meter used to raise ZeroDivisionError;
        # keep the average at 0 until at least one sample has been counted.
        self.avg = self.sum / self.count if self.count else 0


def fix_randomness(SEED):
    """
    Seed Python, NumPy and PyTorch (CPU and CUDA) and set the cuDNN flags.

    Args:
        SEED: Seed value handed to every random number generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for seed_fn in seeders:
        seed_fn(SEED)

    # Turn on the deterministic flag and turn off cuDNN benchmarking.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def _logger(logger_name, level=logging.DEBUG):
    """
    Return a logger that writes bare messages to stdout and to a log file.

    ``logger_name`` is used both as the name passed to ``logging.getLogger``
    and as the path of the log file, which is opened in append mode.

    Args:
        logger_name: Logger name, also the path of the log file.
        level: Logging level applied to the logger; defaults to DEBUG.

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger(logger_name)
    logger.setLevel(level)

    # logging.getLogger hands back the same object for the same name, so a
    # second call would stack another pair of handlers and every message would
    # be emitted twice (and another file handle left open). Reuse the handlers
    # that are already attached instead.
    if logger.handlers:
        return logger

    format_string = "%(message)s"
    log_format = logging.Formatter(format_string)
    # Creating and adding the console handler
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(log_format)
    logger.addHandler(console_handler)
    # Creating and adding the file handler
    file_handler = logging.FileHandler(logger_name, mode='a')
    file_handler.setFormatter(log_format)
    logger.addHandler(file_handler)
    return logger


def starting_logs(data_type, exp_log_dir, seed_id):
    """
    Create the per-seed log directory, open a timestamped log file and write a header.

    Args:
        data_type: Dataset identifier written to the log header.
        exp_log_dir: Experiment directory under which the seed folder is created.
        seed_id: Seed identifier; used in the folder name and the log header.

    Returns:
        Tuple ``(logger, log_dir)`` with the configured logger and the
        directory that holds the log file.
    """
    # The module imports the class with ``from datetime import datetime``, so
    # the former ``datetime.datetime.now()`` raised AttributeError. Importing
    # the class locally under a private alias works whatever the module-level
    # name ``datetime`` happens to be bound to.
    from datetime import datetime as _datetime

    log_dir = os.path.join(exp_log_dir, "_seed_" + str(seed_id))
    os.makedirs(log_dir, exist_ok=True)
    timestamp = _datetime.now().strftime('%d_%m_%Y_%H_%M_%S')
    log_file_name = os.path.join(log_dir, f"logs_{timestamp}.log")
    logger = _logger(log_file_name)
    logger.debug("=" * 45)
    logger.debug(f'Dataset: {data_type}')
    logger.debug("=" * 45)
    logger.debug(f'Seed: {seed_id}')
    logger.debug("=" * 45)
    return logger, log_dir


def save_checkpoint(exp_log_dir, model, dataset, dataset_configs, hparams, status):
    """
    Serialize the model weights and run metadata to ``checkpoint_<status>.pt``.

    Args:
        exp_log_dir: Directory in which the checkpoint file is written.
        model: Object whose ``state_dict()`` is stored under ``"model"``.
        dataset: Value stored under ``"dataset"``.
        dataset_configs: Object whose ``__dict__`` is stored under ``"configs"``.
        hparams: Mapping copied into a plain dict and stored under ``"hparams"``.
        status: Tag embedded in the checkpoint file name.
    """
    payload = dict(
        dataset=dataset,
        configs=dataset_configs.__dict__,
        hparams=dict(hparams),
        model=model.state_dict(),
    )
    checkpoint_path = os.path.join(exp_log_dir, f"checkpoint_{status}.pt")
    torch.save(payload, checkpoint_path)


def _calc_metrics(pred_labels, true_labels, classes_names):
    """
    Compute accuracy and macro-averaged F1, both expressed as percentages.

    Args:
        pred_labels: Predicted labels; converted to an integer array.
        true_labels: Ground-truth labels; converted to an integer array.
        classes_names: Names passed to ``classification_report`` as ``target_names``.

    Returns:
        Tuple ``(accuracy * 100, macro_f1 * 100)``.
    """
    y_pred = np.array(pred_labels).astype(int)
    y_true = np.array(true_labels).astype(int)

    report = classification_report(
        y_true,
        y_pred,
        target_names=classes_names,
        digits=6,
        output_dict=True,
    )
    macro_f1 = report["macro avg"]["f1-score"]
    acc = accuracy_score(y_true, y_pred)

    return acc * 100, macro_f1 * 100


def _save_metrics(pred_labels, true_labels, log_dir, status):
    """
    Write the classification report for the predictions to an Excel file.

    The report is turned into a DataFrame, its ``accuracy`` column is set to
    the overall accuracy, and the whole table is multiplied by 100 before
    being saved as ``classification_report_<status>.xlsx`` in ``log_dir``.
    Note that the scaling is applied to every cell of the table.

    Args:
        pred_labels: Predicted labels; converted to an integer array.
        true_labels: Ground-truth labels; converted to an integer array.
        log_dir: Directory in which the report is written.
        status: Tag embedded in the report file name.
    """
    y_pred = np.array(pred_labels).astype(int)
    y_true = np.array(true_labels).astype(int)

    report = classification_report(y_true, y_pred, digits=6, output_dict=True)

    table = pd.DataFrame(report)
    table["accuracy"] = accuracy_score(y_true, y_pred)
    table = table * 100

    # Write the scaled report next to the other run artifacts.
    target_path = os.path.join(log_dir, f"classification_report_{status}.xlsx")
    table.to_excel(target_path)


import collections


def to_device(input, device):
    """
    Recursively move every tensor inside ``input`` to ``device``.

    Tensors are moved with ``Tensor.to``; strings are returned unchanged;
    mappings are rebuilt as plain dicts with converted values; any other
    sequence is rebuilt as a list with converted items.

    Args:
        input: A tensor, a string, or a (nested) mapping/sequence of those.
        device: Target device forwarded to ``Tensor.to``.

    Returns:
        The converted structure.

    Raises:
        TypeError: If ``input`` (or a nested item) is none of the supported kinds.
    """
    # Import the ABCs explicitly: ``import collections`` alone does not
    # guarantee that the ``collections.abc`` submodule has been loaded.
    from collections.abc import Mapping, Sequence

    if torch.is_tensor(input):
        return input.to(device=device)
    # Strings are sequences too, so they must be handled before the Sequence branch.
    if isinstance(input, str):
        return input
    if isinstance(input, Mapping):
        return {k: to_device(sample, device=device) for k, sample in input.items()}
    if isinstance(input, Sequence):
        return [to_device(sample, device=device) for sample in input]
    # The message was missing its f-prefix and printed "{type(input)}" verbatim.
    raise TypeError(f"Input must contain tensor, dict or list, found {type(input)}")



# Exponential smoothing
def exponential_smoothing(data, alpha):
    """
    Apply simple exponential smoothing along the first axis of ``data``.

    The first output equals the first input; each later output is
    ``alpha * data[i] + (1 - alpha) * smoothed[i - 1]``.

    Args:
        data: Sequence or array of values to smooth.
        alpha: Smoothing factor weighting the current sample.

    Returns:
        NumPy array with the same shape as ``data``. Floating-point inputs
        keep their dtype; integer/boolean inputs are promoted to float so the
        smoothed values are not truncated. Empty input yields an empty array.
    """
    values = np.asarray(data)
    # np.zeros_like on integer input produced an integer buffer, silently
    # truncating every smoothed value; promote non-float input to float.
    if not np.issubdtype(values.dtype, np.inexact):
        values = values.astype(float)

    smoothed = np.zeros_like(values)
    # Nothing to smooth; the original raised IndexError on data[0] here.
    if len(values) == 0:
        return smoothed

    smoothed[0] = values[0]
    for i in range(1, len(values)):
        smoothed[i] = alpha * values[i] + (1 - alpha) * smoothed[i - 1]
    return smoothed

def _tensors_to_numpy(values):
    """Return ``values`` as a list with every torch tensor converted to NumPy."""
    # detach() is required: Tensor.numpy() raises on tensors that require grad.
    return [
        v.detach().cpu().numpy() if isinstance(v, torch.Tensor) else v
        for v in values
    ]


def plot_metrics(train_losses, val_losses, train_accuracies, val_accuracies):
    """
    Plot smoothed loss and accuracy curves for training and validation.

    Each series is converted to NumPy-compatible values, smoothed twice with
    ``exponential_smoothing`` (first with a factor of 0.25, then 0.5) and
    drawn in a figure with two side-by-side subplots: losses on the left,
    accuracies on the right.

    Args:
        train_losses: Training losses (numbers or tensors).
        val_losses: Validation losses (numbers or tensors).
        train_accuracies: Training accuracies (numbers or tensors).
        val_accuracies: Validation accuracies (numbers or tensors).
    """
    # Smoothing factor of the first pass; the second pass uses 0.5.
    alpha = 0.25

    def smooth(series):
        # Two smoothing passes, applied to every series alike.
        first_pass = exponential_smoothing(_tensors_to_numpy(series), alpha)
        return exponential_smoothing(first_pass, 0.5)

    smoothed_train_losses = smooth(train_losses)
    smoothed_val_losses = smooth(val_losses)
    smoothed_train_accuracies = smooth(train_accuracies)
    smoothed_val_accuracies = smooth(val_accuracies)

    plt.figure(figsize=(12, 5))

    # First subplot: smoothed loss curves.
    plt.subplot(1, 2, 1)  # 1 row, 2 columns, first panel
    plt.plot(smoothed_train_losses, label='Train Loss')
    plt.plot(smoothed_val_losses, label='Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title(' Loss Curves')
    plt.legend()

    # Second subplot: smoothed accuracy curves.
    plt.subplot(1, 2, 2)  # 1 row, 2 columns, second panel
    plt.plot(smoothed_train_accuracies, label='Train Accuracy')
    plt.plot(smoothed_val_accuracies, label='Validation Accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.title(' Accuracy Curves')
    plt.legend()

    # Show the whole figure.
    plt.show()