# 导入库

In [1]:
# 导入所需的库和模块
# - 数据处理和操作库：pandas, numpy
# - PyTorch相关库：torch及其各种子模块
# - 图像处理库：PIL, albumentations
# - 模型和评估工具：transformers, torchmetrics
# - 实验跟踪和日志：wandb, tqdm
# - 其他实用工具：warnings, colorama, os, random等

import copy
import gc
import os
import random
from collections import defaultdict
from typing import Dict, List, Optional, Tuple

import albumentations as A
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import wandb
import networkx as nx
import matplotlib.pyplot as plt
import seaborn as sns
from albumentations.pytorch import ToTensorV2
from colorama import Back, Fore, Style
from PIL import Image
from sklearn.model_selection import StratifiedKFold, StratifiedGroupKFold
from sklearn.metrics import confusion_matrix, multilabel_confusion_matrix
from torch import nn, optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader, Dataset
from torchmetrics import AUROC, Accuracy, F1Score, Precision, Recall
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer

import warnings
warnings.simplefilter('ignore')

from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

c_ = Fore.GREEN
sr_ = Style.RESET_ALL

# 配置类

In [2]:
# 配置类，定义了模型训练和评估的所有超参数和设置
class Config:
    seed = 101
    debug = False  # set debug=False for Full Training
    exp_name = "vit/sbert-multilabel"
    model_name = "vit-sbert-multimodal-multilabel"
    backbone = "google/vit-base-patch16-224+sentence-transformers/all-mpnet-base-v2-ep10"
    tokenizer = "sentence-transformers/all-mpnet-base-v2"
    image_encoder = "google/vit-base-patch16-224"
    train_bs = 16
    valid_bs = 32
    img_size = [224, 224]
    max_len = 128
    epochs = 10
    competition = "memotions-7k"

    # 多标签配置
    label_names = ['humour', 'sarcasm', 'offensive', 'motivational', 'overall_sentiment']
    humour_classes = 4       # not_funny, funny, very_funny, hilarious
    sarcasm_classes = 4      # not_sarcastic, general, twisted_meaning, very_twisted
    offensive_classes = 3    # not_offensive, slight, very_offensive
    motivational_classes = 2 # not_motivational, motivational
    sentiment_classes = 5    # very_negative, negative, neutral, positive, very_positive

    # 标签映射
    humour_map = {'not_funny': 0, 'funny': 1, 'very_funny': 2, 'hilarious': 3}
    sarcasm_map = {'not_sarcastic': 0, 'general': 1, 'twisted_meaning': 2, 'very_twisted': 3}
    offensive_map = {'not_offensive': 0, 'slight': 1, 'very_offensive': 2}
    motivational_map = {'not_motivational': 0, 'motivational': 1}
    sentiment_map = {'very_negative': 0, 'negative': 1, 'neutral': 2, 'positive': 3, 'very_positive': 4}

    # 交叉注意力参数
    ca_hidden_size = 256
    ca_num_heads = 8
    ca_dropout = 0.1

    # 优化器参数
    optimizer = 'Adam'
    learning_rate = 3e-4
    rho = 0.9
    eps = 1e-6
    lr_decay = 0
    betas = (0.9, 0.999)
    momentum = 0
    alpha = 0.99

    # 调度器参数
    scheduler = 'CosineAnnealingLR'
    min_lr = 1e-6
    T_max = int(30000/train_bs*epochs)+50
    T_0 = 25
    warmup_epochs = 0
    weight_decay = 1e-6

    # 训练配置
    n_accumulate = max(1, 32//train_bs)
    num_folds = 5
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 工具函数

In [3]:
# 设置随机种子函数
def set_seed(seed: int = 42):
    """设置随机种子，确保实验结果可复现"""
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ["PYTHONHASHSEED"] = str(seed)
    print(">>> SEEDED <<<")

set_seed(Config.seed)

>>> SEEDED <<<


In [4]:

# 配置Weights & Biases (WandB) 用于实验跟踪和可视化
try:
    # 尝试使用已有的登录凭证
    wandb.init(project="multimodal-multilabel-sentiment-analysis", resume=True)
    wandb.finish()  # 立即结束初始化的运行
    print("已成功连接到wandb")
except Exception as e:
    print(f"wandb连接异常: {e}")
    print("尝试重新登录...")
    wandb.login()  # 如果需要，会提示您在终端输入API密钥

wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: hanser33 (hanser33-nanjing-tech-university) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


已成功连接到wandb


## 选择优化器

In [5]:

# 根据配置选择并返回合适的优化器
def get_optimizer(model: nn.Module):
    """返回基于配置的优化器"""
    if Config.optimizer == "Adadelta":
        optimizer = optim.Adadelta(
            model.parameters(), lr=Config.learning_rate, rho=Config.rho, eps=Config.eps
        )
    elif Config.optimizer == "Adagrad":
        optimizer = optim.Adagrad(
            model.parameters(),
            lr=Config.learning_rate,
            lr_decay=Config.lr_decay,
            weight_decay=Config.weight_decay,
        )
    elif Config.optimizer == "Adam":
        optimizer = optim.Adam(
            model.parameters(),
            lr=Config.learning_rate,
            betas=Config.betas,
            eps=Config.eps,
        )
    elif Config.optimizer == "RMSProp":
        optimizer = optim.RMSprop(
            model.parameters(),
            lr=Config.learning_rate,
            alpha=Config.alpha,
            eps=Config.eps,
            weight_decay=Config.weight_decay,
            momentum=Config.momentum,
        )
    else:
        raise NotImplementedError(
            f"优化器 {Config.optimizer} 尚未实现。"
        )
    return optimizer

## 选择学习率调度器

In [6]:

# 根据配置选择并返回合适的学习率调度器
def get_scheduler(optimizer: optim):
    """返回基于配置的学习率调度器"""
    if Config.scheduler == "CosineAnnealingLR":
        scheduler = lr_scheduler.CosineAnnealingLR(
            optimizer=optimizer, T_max=Config.T_max, eta_min=Config.min_lr
        )
    elif Config.scheduler == "CosineAnnealingWarmRestarts":
        scheduler = lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer=optimizer, T_0=Config.T_0, eta_min=Config.min_lr
        )
    elif Config.scheduler == "ReduceLROnPlateau":
        scheduler = lr_scheduler.ReduceLROnPlateau(
            optimizer=optimizer,
            mode="min",
            factor=0.1,
            patience=10,
            threshold=0.0001,
            min_lr=Config.min_lr,
        )
    elif Config.scheduler == "ExponentialLR":
        scheduler = lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.85)
    elif Config.scheduler is None:
        scheduler = None
    else:
        raise NotImplementedError(
            "请求的调度器尚未实现"
        )
    return scheduler

## 使用相关性矩阵

In [7]:

# 标签相关性分析工具
def analyze_label_correlation(df: pd.DataFrame) -> Tuple[pd.DataFrame, np.ndarray]:
    """分析并返回标签间的相关性矩阵"""
    label_columns = Config.label_names

    # 计算各标签间的相关系数
    correlation_matrix = df[label_columns].corr()

    # 可视化相关性矩阵
    plt.figure(figsize=(10, 8))
    sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', vmin=-1, vmax=1)
    plt.title('标签间相关性矩阵')
    plt.tight_layout()
    plt.savefig('label_correlation.png')
    plt.close()

    # 计算标签共现矩阵
    label_values = {}
    for label in label_columns:
        label_values[label] = df[label].unique()

    # 返回相关性矩阵和可视化图像
    return correlation_matrix, 'label_correlation.png'

def calculate_label_weights(df: pd.DataFrame) -> Dict[str, torch.Tensor]:
    """计算每个标签类别的权重，用于处理类别不平衡"""
    weights = {}

    # 为每个标签计算类别权重
    for label in Config.label_names:
        if label == 'humour':
            class_counts = df[label].value_counts().sort_index()
            weights[label] = torch.tensor([1.0 / (count / len(df)) for count in class_counts],
                                        dtype=torch.float32)

        elif label == 'sarcasm':
            class_counts = df[label].value_counts().sort_index()
            weights[label] = torch.tensor([1.0 / (count / len(df)) for count in class_counts],
                                        dtype=torch.float32)

        elif label == 'offensive':
            class_counts = df[label].value_counts().sort_index()
            weights[label] = torch.tensor([1.0 / (count / len(df)) for count in class_counts],
                                        dtype=torch.float32)

        elif label == 'motivational':
            class_counts = df[label].value_counts().sort_index()
            weights[label] = torch.tensor([1.0 / (count / len(df)) for count in class_counts],
                                        dtype=torch.float32)

        elif label == 'overall_sentiment':
            class_counts = df[label].value_counts().sort_index()
            weights[label] = torch.tensor([1.0 / (count / len(df)) for count in class_counts],
                                        dtype=torch.float32)

    return weights

# 数据处理

In [None]:
# 创建训练数据的多标签分层K折交叉验证划分
def create_folds():
    """创建数据的多标签分层K折交叉验证划分"""
    df = pd.read_csv('E:\\DL\\NLPtest\\example\\memotion_dataset_7k\\labels.csv')

    # 删除不需要的列并重置索引
    if 'Unnamed: 0' in df.columns:
        df = df.drop('Unnamed: 0', axis=1)
    df = df.sample(frac=1, random_state=Config.seed).reset_index(drop=True)

    # 处理标签：将文本标签映射为数值
    df['humour'] = df['humour'].map(Config.humour_map)
    df['sarcasm'] = df['sarcasm'].map(Config.sarcasm_map)
    df['offensive'] = df['offensive'].map(Config.offensive_map)
    df['motivational'] = df['motivational'].map(Config.motivational_map)
    df['overall_sentiment'] = df['overall_sentiment'].map(Config.sentiment_map)

    # 创建一个组合标签用于分层抽样
    df['stratify_label'] = df['humour'].astype(str) + "_" + \
                           df['sarcasm'].astype(str) + "_" + \
                           df['offensive'].astype(str) + "_" + \
                           df['motivational'].astype(str) + "_" + \
                           df['overall_sentiment'].astype(str)

    # 使用分层K折交叉验证
    skf = StratifiedKFold(n_splits=Config.num_folds, shuffle=True, random_state=Config.seed)

    df['kfold'] = -1
    for fold, (train_idx, val_idx) in enumerate(skf.split(X=df, y=df['stratify_label'])):
        df.loc[val_idx, 'kfold'] = fold

    # 保存折叠信息到CSV
    df.to_csv('multilabel_folds.csv', index=False)

    # 进行标签相关性分析
    correlation_matrix, corr_image = analyze_label_correlation(df)

    print(f"数据集划分完成，共 {len(df)} 条数据，分为 {Config.num_folds} 折")
    print(f"标签相关性分析结果已保存到 {corr_image}")

    return df

# 数据集类

In [19]:

# 创建数据集类
class MultiLabelMemotionDataset(Dataset):
    """多标签情感分析数据集类"""
    def __init__(self, df: pd.DataFrame) -> None:
        super().__init__()
        self.df = df
        self.tokenizer = AutoTokenizer.from_pretrained(Config.tokenizer)
        self.transforms = A.Compose([
            A.Resize(height=Config.img_size[0], width=Config.img_size[1]),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2(),
        ])

    def __len__(self) -> int:
        return self.df.shape[0]

    def __getitem__(self, ix: int) -> Dict[str, torch.Tensor]:
        row = self.df.iloc[ix]

        # 图像处理
        image_path = os.path.join('E:\\DL\\NLPtest\\example\\memotion_dataset_7k\\images', row['image_name'])
        img = np.array(Image.open(image_path).convert('RGB'))
        img = self.transforms(image=img)['image']

        # 文本处理
        text = str(row['text_corrected']).lower()
        out = self.tokenizer(
            text=text,
            max_length=Config.max_len,
            padding='max_length',
            truncation=True,
            return_tensors="pt"
        )

        # 安全地转换标签为长整型
        try:
            humour_label = torch.tensor(int(row['humour']) if not pd.isna(row['humour']) else 0, dtype=torch.long)
        except (ValueError, TypeError):
            humour_label = torch.tensor(0, dtype=torch.long)  # 默认标签为0

        try:
            sarcasm_label = torch.tensor(int(row['sarcasm']) if not pd.isna(row['sarcasm']) else 0, dtype=torch.long)
        except (ValueError, TypeError):
            sarcasm_label = torch.tensor(0, dtype=torch.long)

        try:
            offensive_label = torch.tensor(int(row['offensive']) if not pd.isna(row['offensive']) else 0, dtype=torch.long)
        except (ValueError, TypeError):
            offensive_label = torch.tensor(0, dtype=torch.long)

        try:
            motivational_label = torch.tensor(int(row['motivational']) if not pd.isna(row['motivational']) else 0, dtype=torch.long)
        except (ValueError, TypeError):
            motivational_label = torch.tensor(0, dtype=torch.long)

        try:
            sentiment_label = torch.tensor(int(row['overall_sentiment']) if not pd.isna(row['overall_sentiment']) else 0, dtype=torch.long)
        except (ValueError, TypeError):
            sentiment_label = torch.tensor(0, dtype=torch.long)

        return {
            'image': img,
            'input_ids': out['input_ids'].squeeze(),
            'attention_mask': out['attention_mask'].squeeze(),
            'humour': humour_label,
            'sarcasm': sarcasm_label,
            'offensive': offensive_label,
            'motivational': motivational_label,
            'overall_sentiment': sentiment_label
        }

# 模型

In [20]:
# 定义多模态多标签情感分析模型
class MultiLabelMemotionModel(nn.Module):
    """多模态多标签情感分析模型"""
    def __init__(self):
        super().__init__()
        # 图像编码器
        self.image_encoder = AutoModel.from_pretrained(Config.image_encoder)
        self.image_fc = nn.Linear(self.image_encoder.config.hidden_size, Config.ca_hidden_size)

        # 文本编码器
        self.text_encoder = AutoModel.from_pretrained(Config.tokenizer)
        self.text_fc = nn.Linear(self.text_encoder.config.hidden_size, Config.ca_hidden_size)

        # 交叉注意力机制
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=Config.ca_hidden_size,
            num_heads=Config.ca_num_heads,
            dropout=Config.ca_dropout,
            batch_first=True
        )

        # 分类头
        self.humour_head = nn.Linear(Config.ca_hidden_size, Config.humour_classes)
        self.sarcasm_head = nn.Linear(Config.ca_hidden_size, Config.sarcasm_classes)
        self.offensive_head = nn.Linear(Config.ca_hidden_size, Config.offensive_classes)
        self.motivational_head = nn.Linear(Config.ca_hidden_size, Config.motivational_classes)
        self.sentiment_head = nn.Linear(Config.ca_hidden_size, Config.sentiment_classes)

    def forward(self, image, input_ids, attention_mask):
        # 图像特征提取
        img_features = self.image_encoder(pixel_values=image).last_hidden_state
        img_features = self.image_fc(img_features)

        # 文本特征提取
        text_features = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        text_features = self.text_fc(text_features)

        # 交叉注意力融合
        fused_features, _ = self.cross_attention(img_features, text_features, text_features)

        # 分类头输出
        humour = self.humour_head(fused_features[:, 0, :])
        sarcasm = self.sarcasm_head(fused_features[:, 0, :])
        offensive = self.offensive_head(fused_features[:, 0, :])
        motivational = self.motivational_head(fused_features[:, 0, :])
        sentiment = self.sentiment_head(fused_features[:, 0, :])

        return humour, sarcasm, offensive, motivational, sentiment

# 训练
## 单个训练周期

In [21]:
# 单个训练周期优化
def train_one_epoch(model, optimizer, scheduler, dataloader):
    model.train()
    epoch_loss = 0.0
    step = 0

    for batch in tqdm(dataloader, desc="Training"):
        optimizer.zero_grad()

        # 获取输入数据
        image = batch['image'].to(Config.device)
        input_ids = batch['input_ids'].to(Config.device)
        attention_mask = batch['attention_mask'].to(Config.device)

        # 获取标签
        humour = batch['humour'].to(Config.device)
        sarcasm = batch['sarcasm'].to(Config.device)
        offensive = batch['offensive'].to(Config.device)
        motivational = batch['motivational'].to(Config.device)
        sentiment = batch['overall_sentiment'].to(Config.device)

        # 前向传播
        outputs = model(image, input_ids, attention_mask)

        # 计算各标签的损失
        humour_loss = F.cross_entropy(outputs[0], humour)
        sarcasm_loss = F.cross_entropy(outputs[1], sarcasm)
        offensive_loss = F.cross_entropy(outputs[2], offensive)
        motivational_loss = F.cross_entropy(outputs[3], motivational)
        sentiment_loss = F.cross_entropy(outputs[4], sentiment)

        # 总损失
        loss = humour_loss + sarcasm_loss + offensive_loss + motivational_loss + sentiment_loss

        # 反向传播
        loss.backward()

        # 梯度累积
        if (step + 1) % Config.n_accumulate == 0:
            optimizer.step()
            if scheduler:
                scheduler.step()

        # 记录训练详情到 wandb
        wandb.log({
            "train/loss": loss.item(),
            "train/humour_loss": humour_loss.item(),
            "train/sarcasm_loss": sarcasm_loss.item(),
            "train/offensive_loss": offensive_loss.item(),
            "train/motivational_loss": motivational_loss.item(),
            "train/sentiment_loss": sentiment_loss.item(),
            "learning_rate": optimizer.param_groups[0]['lr']
        }, step=step)

        epoch_loss += loss.item()
        step += 1

    return epoch_loss / len(dataloader)

## 单个验证周期

In [25]:
# 单个验证周期优化
def validate_one_epoch(model, dataloader, step=None):
    model.eval()
    epoch_loss = 0.0
    val_scores = defaultdict(list)

    # 为每个标签创建评估指标 - 修复：添加任务类型参数
    accuracy_metrics = {
        'humour': Accuracy(task="multiclass", num_classes=Config.humour_classes).to(Config.device),
        'sarcasm': Accuracy(task="multiclass", num_classes=Config.sarcasm_classes).to(Config.device),
        'offensive': Accuracy(task="multiclass", num_classes=Config.offensive_classes).to(Config.device),
        'motivational': Accuracy(task="multiclass", num_classes=Config.motivational_classes).to(Config.device),
        'overall_sentiment': Accuracy(task="multiclass", num_classes=Config.sentiment_classes).to(Config.device)
    }

    # 同样为其他指标添加任务类型
    precision_metrics = {
        'humour': Precision(task="multiclass", num_classes=Config.humour_classes).to(Config.device),
        'sarcasm': Precision(task="multiclass", num_classes=Config.sarcasm_classes).to(Config.device),
        'offensive': Precision(task="multiclass", num_classes=Config.offensive_classes).to(Config.device),
        'motivational': Precision(task="multiclass", num_classes=Config.motivational_classes).to(Config.device),
        'overall_sentiment': Precision(task="multiclass", num_classes=Config.sentiment_classes).to(Config.device)
    }

    recall_metrics = {
        'humour': Recall(task="multiclass", num_classes=Config.humour_classes).to(Config.device),
        'sarcasm': Recall(task="multiclass", num_classes=Config.sarcasm_classes).to(Config.device),
        'offensive': Recall(task="multiclass", num_classes=Config.offensive_classes).to(Config.device),
        'motivational': Recall(task="multiclass", num_classes=Config.motivational_classes).to(Config.device),
        'overall_sentiment': Recall(task="multiclass", num_classes=Config.sentiment_classes).to(Config.device)
    }

    f1_metrics = {
        'humour': F1Score(task="multiclass", num_classes=Config.humour_classes).to(Config.device),
        'sarcasm': F1Score(task="multiclass", num_classes=Config.sarcasm_classes).to(Config.device),
        'offensive': F1Score(task="multiclass", num_classes=Config.offensive_classes).to(Config.device),
        'motivational': F1Score(task="multiclass", num_classes=Config.motivational_classes).to(Config.device),
        'overall_sentiment': F1Score(task="multiclass", num_classes=Config.sentiment_classes).to(Config.device)
    }

    # 其余函数保持不变
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Validating"):
            # 获取输入数据
            image = batch['image'].to(Config.device)
            input_ids = batch['input_ids'].to(Config.device)
            attention_mask = batch['attention_mask'].to(Config.device)

            # 获取标签
            humour = batch['humour'].to(Config.device)
            sarcasm = batch['sarcasm'].to(Config.device)
            offensive = batch['offensive'].to(Config.device)
            motivational = batch['motivational'].to(Config.device)
            sentiment = batch['overall_sentiment'].to(Config.device)

            labels = [humour, sarcasm, offensive, motivational, sentiment]
            label_names = Config.label_names

            # 前向传播
            outputs = model(image, input_ids, attention_mask)

            # 计算总损失
            loss = (
                F.cross_entropy(outputs[0], humour) +
                F.cross_entropy(outputs[1], sarcasm) +
                F.cross_entropy(outputs[2], offensive) +
                F.cross_entropy(outputs[3], motivational) +
                F.cross_entropy(outputs[4], sentiment)
            )

            epoch_loss += loss.item()

            # 计算每个标签的评估指标
            for i, (output, label, label_name) in enumerate(zip(outputs, labels, label_names)):
                preds = torch.argmax(output, dim=1)
                acc = accuracy_metrics[label_name](preds, label)
                prec = precision_metrics[label_name](preds, label)
                rec = recall_metrics[label_name](preds, label)
                f1 = f1_metrics[label_name](preds, label)

                val_scores[f"{label_name}_accuracy"].append(float(acc.item()))
                val_scores[f"{label_name}_precision"].append(float(prec.item()))
                val_scores[f"{label_name}_recall"].append(float(rec.item()))
                val_scores[f"{label_name}_f1"].append(float(f1.item()))

            # 记录详细的验证指标到 wandb
            if step is not None:
                # 只记录批次级别的损失，其他指标在epoch结束时记录
                wandb.log({"valid/batch_loss": loss.item()}, step=step)

    # 计算平均指标
    avg_metrics = {}
    for key, values in val_scores.items():
        avg_metrics[key] = float(np.mean(values))

    # 在epoch结束时记录所有平均指标
    if step is not None:
        log_dict = {"valid/loss": epoch_loss / len(dataloader)}
        for key, value in avg_metrics.items():
            log_dict[f"valid/{key}"] = value
        wandb.log(log_dict, step=step)

    return epoch_loss / len(dataloader), val_scores

# 训练流程

In [26]:
# 完整训练流程优化
def run_training(
    model: nn.Module,
    optimizer: optim,
    trainloader: DataLoader,
    validloader: DataLoader,
    run: wandb.run,
    fold: int,
    scheduler: lr_scheduler = None,
) -> Tuple[nn.Module, defaultdict]:

    # 监控模型权重和梯度
    wandb.watch(models=[model], log_freq=100)

    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = float('inf')
    best_epoch = -1
    history = defaultdict(list)

    for epoch in range(Config.epochs):
        gc.collect()
        print(f"\t\t\t\t########## EPOCH [{epoch+1}/{Config.epochs}] ##########")

        # 训练
        train_loss = train_one_epoch(
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
            dataloader=trainloader,
        )

        # 验证
        step = (epoch + 1) * len(trainloader)
        valid_loss, valid_scores = validate_one_epoch(
            model=model,
            dataloader=validloader,
            step=step
        )

        # 计算平均评估指标
        avg_metrics = {}
        for label in Config.label_names:
            for metric in ["accuracy", "precision", "recall", "f1"]:
                metric_key = f"{label}_{metric}"
                avg_metrics[metric_key] = float(np.mean(valid_scores[metric_key]))

        # 记录每轮训练的主要指标
        wandb.log({
            "train/epoch/loss": float(train_loss),
            "valid/epoch/loss": float(valid_loss),
            "current_lr": float(optimizer.param_groups[0]["lr"]),
            "epoch": epoch
        })

        # 为每个标签记录指标并保存到历史记录
        for label in Config.label_names:
            for metric in ["accuracy", "precision", "recall", "f1"]:
                metric_key = f"{label}_{metric}"
                metric_value = avg_metrics[metric_key]
                wandb.log({f"valid/epoch/{metric_key}": metric_value})
                history[metric_key].append(metric_value)

        # 计算总体平均指标
        overall_metrics = {
            "accuracy": float(np.mean([avg_metrics[f"{label}_accuracy"] for label in Config.label_names])),
            "precision": float(np.mean([avg_metrics[f"{label}_precision"] for label in Config.label_names])),
            "recall": float(np.mean([avg_metrics[f"{label}_recall"] for label in Config.label_names])),
            "f1": float(np.mean([avg_metrics[f"{label}_f1"] for label in Config.label_names]))
        }

        # 记录总体平均指标
        for metric, value in overall_metrics.items():
            wandb.log({f"valid/epoch/overall_{metric}": value})
            history[f"overall_{metric}"].append(value)

        print(f'Train Loss: {train_loss:.5f} | Valid Loss: {valid_loss:.5f}')
        print(f'Overall Accuracy: {overall_metrics["accuracy"]:.5f}')

        # 保存最佳模型
        if valid_loss < best_loss:
            print(f"{c_}Validation Score Improved from {best_loss:.5f} to {valid_loss:.5f}{sr_}")
            best_epoch = epoch + 1
            best_loss = valid_loss

            # 更新 wandb 摘要 - 使用简单类型
            run.summary["Best_Loss"] = float(best_loss)
            run.summary["Best_Epoch"] = int(best_epoch)

            # 记录各标签的最佳指标
            for label in Config.label_names:
                for metric in ["accuracy", "precision", "recall", "f1"]:
                    metric_key = f"{label}_{metric}"
                    run.summary[f"Best_{label}_{metric}"] = float(avg_metrics[metric_key])

            # 记录总体最佳指标
            for metric, value in overall_metrics.items():
                run.summary[f"Best_Overall_{metric}"] = float(value)

            # 保存最佳模型
            best_model_wts = copy.deepcopy(model.state_dict())
            PATH = f"models/best/best_epoch-fold{fold:02d}.bin"
            torch.save(obj=best_model_wts, f=PATH)
            wandb.save(PATH)
            print(f"MODEL SAVED!{sr_}")

        # 保存最后一轮模型
        last_model_wts = copy.deepcopy(model.state_dict())
        PATH = f"models/last/last_epoch-fold{fold:02d}.bin"
        torch.save(last_model_wts, PATH)

    # 加载最佳模型
    model.load_state_dict(best_model_wts, strict=True)

    # 保存训练历史
    torch.save(history, f=f"history/fold-{fold:02d}.pth")

    # 绘制训练历史
    fig, axs = plt.subplots(2, 2, figsize=(20, 15))
    fig.suptitle('Training History', fontsize=16)
    metrics_plot = ["accuracy", "precision", "recall", "f1"]

    for i, metric in enumerate(metrics_plot):
        row, col = i // 2, i % 2
        for label in Config.label_names:
            metric_key = f"{label}_{metric}"
            axs[row, col].plot(history[metric_key], label=f"{label}")
        axs[row, col].plot(history[f"overall_{metric}"], label="overall", linewidth=2, color="black")
        axs[row, col].set_title(f'{metric.capitalize()}')
        axs[row, col].set_xlabel('Epoch')
        axs[row, col].set_ylabel(f'{metric.capitalize()}')
        axs[row, col].legend()
        axs[row, col].grid(True)

    plt.tight_layout(rect=[0, 0, 1, 0.96])
    plt.savefig(f"history/metrics-fold{fold:02d}.png")
    wandb.log({"metrics_history": wandb.Image(plt)})
    plt.close(fig)

    return model, history

In [27]:
# 运行训练优化
def prepare_dataloaders(fold, df) -> Tuple[DataLoader, DataLoader]:
    """准备数据加载器"""
    train_df = df[df['kfold'] != fold].reset_index(drop=True)
    valid_df = df[df['kfold'] == fold].reset_index(drop=True)

    # 创建数据集
    train_dataset = MultiLabelMemotionDataset(train_df)
    valid_dataset = MultiLabelMemotionDataset(valid_df)

    # 创建数据加载器
    trainloader = DataLoader(dataset=train_dataset, batch_size=Config.train_bs, shuffle=True)
    validloader = DataLoader(dataset=valid_dataset, batch_size=Config.valid_bs, shuffle=False)

    return trainloader, validloader

# 创建保存目录
os.makedirs('models', exist_ok=True)
os.makedirs('models/best', exist_ok=True)
os.makedirs('models/last', exist_ok=True)
os.makedirs('history', exist_ok=True)

# 运行训练
df = create_folds()

# 计算标签权重用于处理不平衡类别
label_weights = calculate_label_weights(df)

for fold in range(Config.num_folds):
    print('#'*50)
    print(f'### Fold [{fold+1}/{Config.num_folds}]')
    print('#'*50)

    # 初始化wandb运行
    run = wandb.init(
        project='multimodal-multilabel-sentiment-analysis',
        config={k:v for k, v in dict(vars(Config)).items() if not k.startswith('__')},
        name=f'FOLD-{fold+1}|MODEL-{Config.model_name}',
        group=f'MODEL-{Config.model_name}',
        job_type=f'fold-{fold}',
        reinit=True
    )

    # 记录标签分布和权重 - 修复 NaN 处理
    label_dist = {}
    for label in Config.label_names:
        # 过滤掉 NaN 值
        label_data = df[label].dropna()
        values, counts = np.unique(label_data, return_counts=True)
        # 安全转换
        dist = {str(int(val) if not np.isnan(val) else "NaN"):
                int(count) for val, count in zip(values, counts)}
        label_dist[label] = dist

    wandb.log({"label_distributions": label_dist})

    # 准备数据加载器
    trainloader, validloader = prepare_dataloaders(fold=fold, df=df)

    # 初始化模型、优化器和调度器
    model = MultiLabelMemotionModel().to(Config.device)
    optimizer = get_optimizer(model=model)
    scheduler = get_scheduler(optimizer=optimizer)

    # 运行训练
    model, history = run_training(
        model=model,
        optimizer=optimizer,
        trainloader=trainloader,
        validloader=validloader,
        run=run,
        fold=fold,
        scheduler=scheduler
    )

    # 完成当前折叠的训练
    run.finish()

print("训练完成！")

数据集划分完成，共 6992 条数据，分为 5 折
标签相关性分析结果已保存到 label_correlation.png
##################################################
### Fold [1/5]
##################################################


0,1
learning_rate,████████▇▇▇▇▇▇▇▇▆▆▆▆▆▅▅▅▅▅▅▄▄▄▃▃▃▃▂▂▂▁▁▁
train/humour_loss,▄▃▂█▃▃▃▂▆▄▁▄▂▃▁▅▃▃▄▃▄▃▃▃▃▃▃▅▃▄▂▂▃▄▂▁▂▂▆▂
train/loss,█▇▅▆▅▆▅▁▄▃▃▅▃▁▇▇▅▃▅▂█▆▅▆▄▄▃▅▅▅▆▇▃▃▄▄▅▃▆▄
train/motivational_loss,▄▃▃▂▄▄▃▃▂█▄▄▅▃▃▄▄▄▃▄▄▂▂▄▄▄▃▃▃▄▂▂▁▅▄▃▄▃▃▄
train/offensive_loss,▄▇▄▅▆▁▃▂▁▄▂█▃▄▂▅▇▄█▇▂▆▂▆▃▄▄▂▂▂▂▄▃▄▆▃▄▇▆▆
train/sarcasm_loss,▆▄▅▅▆▃▅▆▅▇▁▂▄▅▃▃▄▂▁▅█▆▇▅▄▄▆▄▂▂▃▄▆▆▄▅▇▃▃▅
train/sentiment_loss,█▇▇▆▅▁█▅▇▅▃▅▄▄▅▄▃▅▃▇▅▃█▄▃▃▂▄▅▃▃▃▃▃▃▃▃▄▁▅

0,1
learning_rate,0.0003
train/humour_loss,1.25652
train/loss,5.38031
train/motivational_loss,0.6943
train/offensive_loss,1.09648
train/sarcasm_loss,1.14621
train/sentiment_loss,1.18679


Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


				########## EPOCH [1/10] ##########


Training: 100%|██████████| 350/350 [01:25<00:00,  4.10it/s]
Validating: 100%|██████████| 44/44 [00:15<00:00,  2.83it/s]


Train Loss: 5.50464 | Valid Loss: 5.49403
Overall Accuracy: 0.47062
Validation Score Improved from inf to 5.49403
MODEL SAVED!
				########## EPOCH [2/10] ##########


Training: 100%|██████████| 350/350 [01:26<00:00,  4.04it/s]
Validating: 100%|██████████| 44/44 [00:11<00:00,  3.69it/s]


Train Loss: 5.47564 | Valid Loss: 5.48259
Overall Accuracy: 0.46476
Validation Score Improved from 5.49403 to 5.48259
MODEL SAVED!
				########## EPOCH [3/10] ##########


Training: 100%|██████████| 350/350 [01:25<00:00,  4.12it/s]
Validating: 100%|██████████| 44/44 [00:11<00:00,  3.75it/s]


Train Loss: 5.47652 | Valid Loss: 5.48551
Overall Accuracy: 0.47062
				########## EPOCH [4/10] ##########


Training: 100%|██████████| 350/350 [01:24<00:00,  4.12it/s]
Validating: 100%|██████████| 44/44 [00:11<00:00,  3.67it/s]


Train Loss: 5.46741 | Valid Loss: 5.48426
Overall Accuracy: 0.47062
				########## EPOCH [5/10] ##########


Training: 100%|██████████| 350/350 [01:26<00:00,  4.07it/s]
Validating: 100%|██████████| 44/44 [00:12<00:00,  3.60it/s]


Train Loss: 5.46289 | Valid Loss: 5.47420
Overall Accuracy: 0.47062
Validation Score Improved from 5.48259 to 5.47420
MODEL SAVED!
				########## EPOCH [6/10] ##########


Training: 100%|██████████| 350/350 [01:35<00:00,  3.66it/s]
Validating: 100%|██████████| 44/44 [00:17<00:00,  2.53it/s]


Train Loss: 5.46746 | Valid Loss: 5.47284
Overall Accuracy: 0.47062
Validation Score Improved from 5.47420 to 5.47284
MODEL SAVED!
				########## EPOCH [7/10] ##########


Training: 100%|██████████| 350/350 [01:44<00:00,  3.36it/s]
Validating: 100%|██████████| 44/44 [00:17<00:00,  2.52it/s]


Train Loss: 5.46458 | Valid Loss: 5.47176
Overall Accuracy: 0.47062
Validation Score Improved from 5.47284 to 5.47176
MODEL SAVED!
				########## EPOCH [8/10] ##########


Training: 100%|██████████| 350/350 [01:32<00:00,  3.79it/s]
Validating: 100%|██████████| 44/44 [00:11<00:00,  3.69it/s]


Train Loss: 5.46535 | Valid Loss: 5.47677
Overall Accuracy: 0.47062
				########## EPOCH [9/10] ##########


Training: 100%|██████████| 350/350 [01:24<00:00,  4.15it/s]
Validating: 100%|██████████| 44/44 [00:11<00:00,  3.75it/s]


Train Loss: 5.45900 | Valid Loss: 5.48448
Overall Accuracy: 0.47062
				########## EPOCH [10/10] ##########


Training: 100%|██████████| 350/350 [01:24<00:00,  4.13it/s]
Validating: 100%|██████████| 44/44 [00:11<00:00,  3.72it/s]


Train Loss: 5.46033 | Valid Loss: 5.49273
Overall Accuracy: 0.47062


0,1
current_lr,██▇▇▆▆▅▄▂▁
epoch,▁▂▃▃▄▅▆▆▇█
learning_rate,██████████▇▇▇▇▇▆▆▆▆▆▆▆▆▅▅▅▅▅▄▄▄▄▃▃▃▃▃▃▂▁
train/epoch/loss,█▄▄▂▂▂▂▂▁▁
train/humour_loss,█▄▃▇▁▅▅▄▄▄▄▆▇▄▂▄▇▆▄▂▄▆▃▄▄▄▅▄▄▆▂▇▃▄▅▆▁▂▆▇
train/loss,█▆▆▃▇█▆▇▄▇█▃▅█▅▆▅▅▅▆▆▅▅▅▁▄▆▄▄▄▂▅▄▅▄▅▇▅▄▅
train/motivational_loss,▅▆▅▅▇▆▄▅▇▁▄▆▇▆▆▄▄▃▄▂▅▇▅▃▃▇▇▅▁█▇▅▃▁▅▂▅▄▄▅
train/offensive_loss,█▆▅▆▄▄▆▇▆▂▆▅▄▆▆▅▅▄▁▄▃▆▄▃▃▇▃▆▆▅▄▄▄▄▃▄▅▅▆▅
train/sarcasm_loss,▇▇▇▅▅█▆▅▅▅▅▆▄▂█▃▂█▆▆▇▄▆▇▁▆▃▅▅▄█▄▄▃▄▄▆▄▃▄
train/sentiment_loss,▆▄▄▄▄▆▂▅▄▃▅▁▅▇▅▃▄▃▅▅▄▄█▄█▃▄▄▆▁▄▃▄▃▃▄▂▅▅▆

0,1
Best_Epoch,7.0
Best_Loss,5.47176
Best_Overall_accuracy,0.47062
Best_Overall_f1,0.47062
Best_Overall_precision,0.47062
Best_Overall_recall,0.47062
Best_humour_accuracy,0.34854
Best_humour_f1,0.34854
Best_humour_precision,0.34854
Best_humour_recall,0.34854


##################################################
### Fold [2/5]
##################################################


Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


				########## EPOCH [1/10] ##########


Training: 100%|██████████| 350/350 [01:25<00:00,  4.10it/s]
Validating: 100%|██████████| 44/44 [00:11<00:00,  3.83it/s]


Train Loss: 5.50962 | Valid Loss: 5.48391
Overall Accuracy: 0.47397
Validation Score Improved from inf to 5.48391
MODEL SAVED!
				########## EPOCH [2/10] ##########


Training: 100%|██████████| 350/350 [01:25<00:00,  4.11it/s]
Validating: 100%|██████████| 44/44 [00:11<00:00,  3.84it/s]


Train Loss: 5.47511 | Valid Loss: 5.46753
Overall Accuracy: 0.47397
Validation Score Improved from 5.48391 to 5.46753
MODEL SAVED!
				########## EPOCH [3/10] ##########


Training: 100%|██████████| 350/350 [01:25<00:00,  4.10it/s]
Validating: 100%|██████████| 44/44 [00:11<00:00,  3.79it/s]


Train Loss: 5.46822 | Valid Loss: 5.47010
Overall Accuracy: 0.47397
				########## EPOCH [4/10] ##########


Training: 100%|██████████| 350/350 [01:25<00:00,  4.10it/s]
Validating: 100%|██████████| 44/44 [00:11<00:00,  3.81it/s]


Train Loss: 5.46628 | Valid Loss: 5.50798
Overall Accuracy: 0.47397
				########## EPOCH [5/10] ##########


Training: 100%|██████████| 350/350 [01:25<00:00,  4.11it/s]
Validating: 100%|██████████| 44/44 [00:11<00:00,  3.81it/s]


Train Loss: 5.46724 | Valid Loss: 5.46980
Overall Accuracy: 0.47397
				########## EPOCH [6/10] ##########


Training: 100%|██████████| 350/350 [01:25<00:00,  4.10it/s]
Validating: 100%|██████████| 44/44 [00:11<00:00,  3.78it/s]


Train Loss: 5.46662 | Valid Loss: 5.47343
Overall Accuracy: 0.46825
				########## EPOCH [7/10] ##########


Training: 100%|██████████| 350/350 [01:25<00:00,  4.12it/s]
Validating: 100%|██████████| 44/44 [00:11<00:00,  3.84it/s]


Train Loss: 5.47333 | Valid Loss: 5.47727
Overall Accuracy: 0.47397
				########## EPOCH [8/10] ##########


Training: 100%|██████████| 350/350 [01:24<00:00,  4.12it/s]
Validating: 100%|██████████| 44/44 [00:11<00:00,  3.81it/s]


Train Loss: 5.46056 | Valid Loss: 5.46886
Overall Accuracy: 0.46222
				########## EPOCH [9/10] ##########


Training: 100%|██████████| 350/350 [01:25<00:00,  4.08it/s]
Validating: 100%|██████████| 44/44 [00:11<00:00,  3.77it/s]


Train Loss: 5.46224 | Valid Loss: 5.50533
Overall Accuracy: 0.46222
				########## EPOCH [10/10] ##########


Training: 100%|██████████| 350/350 [01:25<00:00,  4.10it/s]
Validating: 100%|██████████| 44/44 [00:11<00:00,  3.82it/s]


Train Loss: 5.46043 | Valid Loss: 5.49195
Overall Accuracy: 0.47397


0,1
current_lr,██▇▇▆▆▅▄▂▁
epoch,▁▂▃▃▄▅▆▆▇█
learning_rate,███████████▇▇▇▇▇▇▇▇▇▆▆▆▆▆▆▅▅▅▄▄▄▄▄▃▂▂▂▁▁
train/epoch/loss,█▃▂▂▂▂▃▁▁▁
train/humour_loss,▄▄▂▅▁▃▂▁▁▃▃█▄▂▄▁▃▄▄▂▃▃▅▃▂▃▂▂▂▃▃▂▂▃▂▅▃▃▃▁
train/loss,██▆▄▅▁▃▃▄▄▅▄▆▅▄▅▆▆▆▃▅▆▃▄▃▅▆█▂▄▅▂▃▇▄▃▄▂▅▁
train/motivational_loss,▄▄▆▁▃▄▃▅▃▄▃▅▅▃▃▄▄▄▇▅▅▄▅▄▅▃▅▃▆▃▃▃▂▄▆▅▃█▃▄
train/offensive_loss,▂▃█▅▂▄▂▅▁▂▂▃▂▃▃▄▂▂▂▁▄▃▃▄▃▆▃▂▂▄▃▄▅▂▃▄▄▃▄▂
train/sarcasm_loss,▆▅▆▄▅▃▅▇▄▂▃▃▆▆▅▄▅▅▆█▃▅▂▂▄▄▆▄▆▅▂▃▇▄▂▆▂▁█▅
train/sentiment_loss,▄▄▄▃▆▄▂▄▂▇▃▁▅▃▃▅▃▂▅▄▅▄▅▅▂▄▁▃▁▃▆▄▃█▃▃▂▃▄▇

0,1
Best_Epoch,2.0
Best_Loss,5.46753
Best_Overall_accuracy,0.47397
Best_Overall_f1,0.47397
Best_Overall_precision,0.47397
Best_Overall_recall,0.47397
Best_humour_accuracy,0.35295
Best_humour_f1,0.35295
Best_humour_precision,0.35295
Best_humour_recall,0.35295


##################################################
### Fold [3/5]
##################################################


Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


				########## EPOCH [1/10] ##########


Training: 100%|██████████| 350/350 [01:37<00:00,  3.61it/s]
Validating: 100%|██████████| 44/44 [00:12<00:00,  3.60it/s]


Train Loss: 5.52135 | Valid Loss: 5.45702
Overall Accuracy: 0.46512
Validation Score Improved from inf to 5.45702
MODEL SAVED!
				########## EPOCH [2/10] ##########


Training: 100%|██████████| 350/350 [01:26<00:00,  4.06it/s]
Validating: 100%|██████████| 44/44 [00:11<00:00,  3.79it/s]


Train Loss: 5.48704 | Valid Loss: 5.45339
Overall Accuracy: 0.46965
Validation Score Improved from 5.45702 to 5.45339
MODEL SAVED!
				########## EPOCH [3/10] ##########


Training: 100%|██████████| 350/350 [01:26<00:00,  4.06it/s]
Validating: 100%|██████████| 44/44 [00:11<00:00,  3.78it/s]


Train Loss: 5.47508 | Valid Loss: 5.47468
Overall Accuracy: 0.47474
				########## EPOCH [4/10] ##########


Training: 100%|██████████| 350/350 [01:26<00:00,  4.07it/s]
Validating: 100%|██████████| 44/44 [00:11<00:00,  3.78it/s]


Train Loss: 5.46991 | Valid Loss: 5.49468
Overall Accuracy: 0.46512
				########## EPOCH [5/10] ##########


Training: 100%|██████████| 350/350 [01:25<00:00,  4.08it/s]
Validating: 100%|██████████| 44/44 [00:11<00:00,  3.78it/s]


Train Loss: 5.47357 | Valid Loss: 5.45629
Overall Accuracy: 0.47474
				########## EPOCH [6/10] ##########


Training: 100%|██████████| 350/350 [01:26<00:00,  4.07it/s]
Validating: 100%|██████████| 44/44 [00:11<00:00,  3.77it/s]


Train Loss: 5.47103 | Valid Loss: 5.46331
Overall Accuracy: 0.47474
				########## EPOCH [7/10] ##########


Training: 100%|██████████| 350/350 [01:25<00:00,  4.08it/s]
Validating: 100%|██████████| 44/44 [00:11<00:00,  3.79it/s]


Train Loss: 5.47419 | Valid Loss: 5.46833
Overall Accuracy: 0.47474
				########## EPOCH [8/10] ##########


Training: 100%|██████████| 350/350 [01:26<00:00,  4.07it/s]
Validating: 100%|██████████| 44/44 [00:11<00:00,  3.82it/s]


Train Loss: 5.46813 | Valid Loss: 5.45281
Overall Accuracy: 0.47474
Validation Score Improved from 5.45339 to 5.45281
MODEL SAVED!
				########## EPOCH [9/10] ##########


Training: 100%|██████████| 350/350 [01:25<00:00,  4.07it/s]
Validating: 100%|██████████| 44/44 [00:11<00:00,  3.77it/s]


Train Loss: 5.47083 | Valid Loss: 5.46852
Overall Accuracy: 0.46965
				########## EPOCH [10/10] ##########


Training: 100%|██████████| 350/350 [01:25<00:00,  4.07it/s]
Validating: 100%|██████████| 44/44 [00:11<00:00,  3.75it/s]


Train Loss: 5.47104 | Valid Loss: 5.44767
Overall Accuracy: 0.47474
Validation Score Improved from 5.45281 to 5.44767
MODEL SAVED!


0,1
current_lr,██▇▇▆▆▅▄▂▁
epoch,▁▂▃▃▄▅▆▆▇█
learning_rate,█████████▇▇▇▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▄▄▃▃▃▃▃▂▂▂▂▁
train/epoch/loss,█▃▂▁▂▁▂▁▁▁
train/humour_loss,▅▄▄▃▄▃▃▆▂▁▂▅▂▄▄▁█▆▂▆▅▃▅▅▄█▄▂▃▅▆▃▆▃▄▄▄▂▄▃
train/loss,▆▆▄▃▅▅▃▄▆▃▁▅▂▃▁█▂▅▃▃▂▄▆▄▂▆▇▄▃▃▄▃▄▃▄▂▆▇▆▃
train/motivational_loss,▅▄▄▄▄▃▄▅▅▄▃▅▅▃▅▅▆▄▅▆▄▅▆▅▄▂▂▆▄▂▁▃▇▄█▂▄▆▅▄
train/offensive_loss,▄▇▄▃▆▆█▆█▃▄▄▆▆▄▅▅▃▇▅▆▆▂▄▇▃▇▅▆▇▄▅▁▁▇▇▃▄▄▄
train/sarcasm_loss,▅▄▃▂▃▅▄▄▄▂▁█▁▁▅▂▂▃▅▅▄▄▄▅▄▅▂▂▆▂▄▄▄▃▃▂▂▆▅▄
train/sentiment_loss,▆▄▃▄▃▃▅▅▂▂▂▄▄▂▃▃▄▃▃▂▂▂▃▄▁▃▃▄▂▃▃▃▅▅▃█▃█▆▁

0,1
Best_Epoch,10.0
Best_Loss,5.44767
Best_Overall_accuracy,0.47474
Best_Overall_f1,0.47474
Best_Overall_precision,0.47474
Best_Overall_recall,0.47474
Best_humour_accuracy,0.34846
Best_humour_f1,0.34846
Best_humour_precision,0.34846
Best_humour_recall,0.34846


##################################################
### Fold [4/5]
##################################################


Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


				########## EPOCH [1/10] ##########


Training: 100%|██████████| 350/350 [01:26<00:00,  4.06it/s]
Validating: 100%|██████████| 44/44 [00:11<00:00,  3.82it/s]


Train Loss: 5.51509 | Valid Loss: 5.46349
Overall Accuracy: 0.46832
Validation Score Improved from inf to 5.46349
MODEL SAVED!
				########## EPOCH [2/10] ##########


Training: 100%|██████████| 350/350 [01:28<00:00,  3.97it/s]
Validating: 100%|██████████| 44/44 [00:12<00:00,  3.58it/s]


Train Loss: 5.48953 | Valid Loss: 5.44798
Overall Accuracy: 0.47499
Validation Score Improved from 5.46349 to 5.44798
MODEL SAVED!
				########## EPOCH [3/10] ##########


Training: 100%|██████████| 350/350 [01:29<00:00,  3.89it/s]
Validating: 100%|██████████| 44/44 [00:13<00:00,  3.30it/s]


Train Loss: 5.47236 | Valid Loss: 5.44673
Overall Accuracy: 0.46832
Validation Score Improved from 5.44798 to 5.44673
MODEL SAVED!
				########## EPOCH [4/10] ##########


Training: 100%|██████████| 350/350 [01:26<00:00,  4.05it/s]
Validating: 100%|██████████| 44/44 [00:11<00:00,  3.71it/s]


Train Loss: 5.47837 | Valid Loss: 5.44817
Overall Accuracy: 0.47499
				########## EPOCH [5/10] ##########


Training: 100%|██████████| 350/350 [01:25<00:00,  4.08it/s]
Validating: 100%|██████████| 44/44 [00:11<00:00,  3.84it/s]


Train Loss: 5.47147 | Valid Loss: 5.44648
Overall Accuracy: 0.47499
Validation Score Improved from 5.44673 to 5.44648
MODEL SAVED!
				########## EPOCH [6/10] ##########


Training: 100%|██████████| 350/350 [01:26<00:00,  4.07it/s]
Validating: 100%|██████████| 44/44 [00:11<00:00,  3.83it/s]


Train Loss: 5.47118 | Valid Loss: 5.45080
Overall Accuracy: 0.47499
				########## EPOCH [7/10] ##########


Training: 100%|██████████| 350/350 [01:25<00:00,  4.07it/s]
Validating: 100%|██████████| 44/44 [00:11<00:00,  3.81it/s]


Train Loss: 5.47230 | Valid Loss: 5.44553
Overall Accuracy: 0.47499
Validation Score Improved from 5.44648 to 5.44553
MODEL SAVED!
				########## EPOCH [8/10] ##########


Training: 100%|██████████| 350/350 [01:26<00:00,  4.06it/s]
Validating: 100%|██████████| 44/44 [00:11<00:00,  3.80it/s]


Train Loss: 5.46652 | Valid Loss: 5.44462
Overall Accuracy: 0.46521
Validation Score Improved from 5.44553 to 5.44462
MODEL SAVED!
				########## EPOCH [9/10] ##########


Training: 100%|██████████| 350/350 [01:25<00:00,  4.08it/s]
Validating: 100%|██████████| 44/44 [00:11<00:00,  3.81it/s]


Train Loss: 5.46710 | Valid Loss: 5.45582
Overall Accuracy: 0.46832
				########## EPOCH [10/10] ##########


Training: 100%|██████████| 350/350 [01:26<00:00,  4.07it/s]
Validating: 100%|██████████| 44/44 [00:11<00:00,  3.84it/s]


Train Loss: 5.46220 | Valid Loss: 5.45348
Overall Accuracy: 0.47499


0,1
current_lr,██▇▇▆▆▅▄▂▁
epoch,▁▂▃▃▄▅▆▆▇█
learning_rate,██████████████▇▇▇▇▇▇▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▃▃▂▂▁
train/epoch/loss,█▅▂▃▂▂▂▂▂▁
train/humour_loss,▆▅▄▅▅▁▄▄▃▆▅▅▄▁▆▃▄▃█▄▅▅▆▄▅▃█▂▂▇▄▆▅▆▃▃▆█▄█
train/loss,▆▅▄▃▄▁▅▆▄▃▃▁▇▄▃▆▅▄▃▃▂▆▃▃▃▂▂▅▅▁▄█▄▃▂▄▃▃▁▄
train/motivational_loss,▅▅▆▆▆▅▁▃▇█▄▅▅▄▅▆▄▄▅▅▆▆▆▇▆▂▃▅▅▅▅▇▂▅▅▅▂▅▅▃
train/offensive_loss,▆▆▅▅▆▅▅▁█▅▇▅▅▄▆▆▃▅▆▇▃▇▇▆▃▇▅▅▇▅▅▇▅▆▆▅▅▁▃▆
train/sarcasm_loss,▆▅▄▅▇▄▄▅▆▆▄▆▇▄▃▃▆▅▅▆▃▄▃▂▁▅▇▅▅▆▅▄▅▅▃▄▇█▃▃
train/sentiment_loss,▆▅▄█▃▅▂▄▅▃▃▂▄▄▄▅▄▄▃▃▃▇▁▅▄▄▅▄▂▅▄▂▄▃▃▅▃▅▃▅

0,1
Best_Epoch,8.0
Best_Loss,5.44462
Best_Overall_accuracy,0.46521
Best_Overall_f1,0.46521
Best_Overall_precision,0.46521
Best_Overall_recall,0.46521
Best_humour_accuracy,0.35021
Best_humour_f1,0.35021
Best_humour_precision,0.35021
Best_humour_recall,0.35021


##################################################
### Fold [5/5]
##################################################


Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


				########## EPOCH [1/10] ##########


Training: 100%|██████████| 350/350 [01:27<00:00,  4.01it/s]
Validating: 100%|██████████| 44/44 [00:11<00:00,  3.78it/s]


Train Loss: 5.51839 | Valid Loss: 5.47710
Overall Accuracy: 0.46495
Validation Score Improved from inf to 5.47710
MODEL SAVED!
				########## EPOCH [2/10] ##########


Training: 100%|██████████| 350/350 [01:26<00:00,  4.03it/s]
Validating: 100%|██████████| 44/44 [00:11<00:00,  3.86it/s]


Train Loss: 5.47724 | Valid Loss: 5.46156
Overall Accuracy: 0.47191
Validation Score Improved from 5.47710 to 5.46156
MODEL SAVED!
				########## EPOCH [3/10] ##########


Training: 100%|██████████| 350/350 [01:27<00:00,  4.01it/s]
Validating: 100%|██████████| 44/44 [00:11<00:00,  3.85it/s]


Train Loss: 5.47475 | Valid Loss: 5.47588
Overall Accuracy: 0.47191
				########## EPOCH [4/10] ##########


Training: 100%|██████████| 350/350 [01:26<00:00,  4.03it/s]
Validating: 100%|██████████| 44/44 [00:11<00:00,  3.82it/s]


Train Loss: 5.46969 | Valid Loss: 5.46963
Overall Accuracy: 0.47191
				########## EPOCH [5/10] ##########


Training: 100%|██████████| 350/350 [01:27<00:00,  4.01it/s]
Validating: 100%|██████████| 44/44 [00:11<00:00,  3.84it/s]


Train Loss: 5.46732 | Valid Loss: 5.46164
Overall Accuracy: 0.47191
				########## EPOCH [6/10] ##########


Training: 100%|██████████| 350/350 [01:26<00:00,  4.03it/s]
Validating: 100%|██████████| 44/44 [00:11<00:00,  3.84it/s]


Train Loss: 5.46481 | Valid Loss: 5.46702
Overall Accuracy: 0.47191
				########## EPOCH [7/10] ##########


Training: 100%|██████████| 350/350 [01:26<00:00,  4.06it/s]
Validating: 100%|██████████| 44/44 [00:11<00:00,  3.85it/s]


Train Loss: 5.46750 | Valid Loss: 5.47169
Overall Accuracy: 0.47191
				########## EPOCH [8/10] ##########


Training: 100%|██████████| 350/350 [01:26<00:00,  4.03it/s]
Validating: 100%|██████████| 44/44 [00:11<00:00,  3.86it/s]


Train Loss: 5.46743 | Valid Loss: 5.47586
Overall Accuracy: 0.46340
				########## EPOCH [9/10] ##########


Training: 100%|██████████| 350/350 [01:26<00:00,  4.04it/s]
Validating: 100%|██████████| 44/44 [00:11<00:00,  3.82it/s]


Train Loss: 5.46683 | Valid Loss: 5.46656
Overall Accuracy: 0.47191
				########## EPOCH [10/10] ##########


Training: 100%|██████████| 350/350 [01:26<00:00,  4.03it/s]
Validating: 100%|██████████| 44/44 [00:11<00:00,  3.88it/s]


Train Loss: 5.46603 | Valid Loss: 5.46561
Overall Accuracy: 0.47191


0,1
current_lr,██▇▇▆▆▅▄▂▁
epoch,▁▂▃▃▄▅▆▆▇█
learning_rate,██████████▇▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▅▅▄▄▄▃▃▃▃▂▂▁▁
train/epoch/loss,█▃▂▂▁▁▁▁▁▁
train/humour_loss,▃▃▃▁▁▃▅▃▂▄▂▄▄▄▁▄▅▂▄▃▂▃▂▃▄▃▃▄▅▁█▃▃▄▄▂▃▄▅▅
train/loss,▆▆▆▆▅▅▅▇▅▅▄▅▇▅▅▆▁▂▄▅▄▅▂▅▅█▄▆▄▄▂▆█▄▅▄▆▂▅▅
train/motivational_loss,▅▇▆▅▆▇▁▆▅▃▆▃▆▅█▆▇▂▃▄▆▆▅▆▇▁▄▅▄▅▃▅▃▆▅▆▆▂▂▅
train/offensive_loss,▅▁▅▆▄▆▄▃▃▅▄▂▄▅▅▅▃▄▄█▄▇▄▆▃▅▄▅▆▆▃▂▆▄▅▅▅▂▃▃
train/sarcasm_loss,▇▇▆▃▅▃██▄▃▅▆▅▃▇▆▅▄▇▅▃▄▆▄▃▅▄▅▆▃▁▇▆▄█▃▆▄▃▇
train/sentiment_loss,█▅▇▄▆▁▇▃▃▂▂▃▃▇▂▃▂▄▃▂▅██▄▃▁▅▂▄▅▃▂▃▅▃▂▂▅▃▃

0,1
Best_Epoch,2.0
Best_Loss,5.46156
Best_Overall_accuracy,0.47191
Best_Overall_f1,0.47191
Best_Overall_precision,0.47191
Best_Overall_recall,0.47191
Best_humour_accuracy,0.3524
Best_humour_f1,0.3524
Best_humour_precision,0.3524
Best_humour_recall,0.3524


训练完成！
