In [15]:
import sys
import os

# 获取当前 Notebook 的目录（根目录）
notebook_dir = os.getcwd()

# 将 ALMT/models 目录添加到 sys.path
sys.path.append(os.path.join(notebook_dir, "ALMT", "models"))
sys.path.append(os.path.join(notebook_dir, "ALMT", "core"))
sys.path.append(os.path.join(notebook_dir, "ALMT"))

In [51]:
del sys.modules['almt_layer']
del sys.modules['dataset']
del sys.modules['Adapter']

In [2]:
import torch
from torch import nn
from almt_layer import Transformer,CrossTransformerEncoder, HhyperLearningEncoder, CrossTransformer, Transformer
from bert import BertTextEncoder
from einops import repeat
from dataset import MMDataset
from core.scheduler import GradualWarmupScheduler
from Adapter import TextAdapter, AudioAdapter
import numpy as np

import torch.optim as optim
from torchvision import datasets, models, transforms

from torch.utils.data import DataLoader, Dataset
import torch.optim.lr_scheduler as lr_scheduler
from PIL import Image

import time
from tqdm import tqdm



2025-03-25 17:20:32.459746: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
import gc

def clear_memory():
    torch.cuda.empty_cache()  # 释放 GPU 缓存
    gc.collect()  # 释放 CPU 内存

clear_memory()

In [4]:
train_dataset = MMDataset(dataset='mosi',mode='train', image_num=5, generate_num = 4)
valid_dataset = MMDataset(dataset='mosi',mode='valid', image_num=5, generate_num = 1)

100%|██████████| 1284/1284 [15:20<00:00,  1.39it/s]
100%|██████████| 229/229 [00:37<00:00,  6.10it/s]


In [68]:
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=8)
val_loader = DataLoader(valid_dataset, batch_size=32, shuffle=False, num_workers=8)

In [69]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [70]:
class CrossAttentionTA(nn.Module):
    def __init__(self, dim=128, heads=8, mlp_dim=128):
        super().__init__()
        self.cross_attn = CrossTransformer(source_num_frames=8, tgt_num_frames=8, dim=dim, depth=1, heads=heads, mlp_dim=mlp_dim)

    def forward(self, h_t, h_a):
        out = self.cross_attn(h_t, h_a)  # [B, 9, D]
        return out  # keep [CLS] and body for downstream split

class GateController(nn.Module):
    def __init__(self, dim=128):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(dim, dim),
            nn.ReLU(),
            nn.Linear(dim, 1),
            nn.Sigmoid()  # outputs probability ∈ [0,1]
        )

    def forward(self, state):
        p = self.fc(state)  # [B, 1]
        dist = Bernoulli(probs=p)
        action = dist.sample()  # [B, 1]
        log_prob = dist.log_prob(action)  # [B, 1]
        return action, log_prob, p


In [71]:
class EMALMTBlock(nn.Module):
    def __init__(self, dim=128, heads=8, mlp_dim=128, dropout=0.):
        super(EMALMTBlock, self).__init__()

        # 融合模块（ALMT核心）
        self.fusion = HhyperLearningEncoder(dim=dim, depth=1, heads=heads, dim_head=16, dropout=dropout)

        # 视觉生成模块（E步）
        self.visual_predictor = Transformer(
            num_frames=16,
            save_hidden=False,
            token_len=None,  # 直接输出 token_len 长度的序列
            dim=dim,
            depth=1,
            heads=heads,
            mlp_dim=mlp_dim,
            dropout=dropout
        )

        # 视觉反馈映射模块（M步）
        self.feedback_proj = nn.Sequential(
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Linear(dim, dim)
        )
        
        self.cross_ta = CrossAttentionTA(dim=dim, heads=heads, mlp_dim=mlp_dim)
        self.gate_controller = GateController(dim=2*dim)  # one per layer


    def forward(self, h_v_list, h_l, h_a, h_hyper_v, layer_idx):
        """
        输入：
        - h_v_list: List of vision encoder layer outputs
        - h_l, h_a: 融合输入
        - h_hyper_v: 上一层融合结果
        - layer_idx: 当前层 index，用于选择要更新的视觉层
        输出：
        - h_hyper_v: 当前层融合输出
        - h_v_list: 修正后的视觉层列表
        """
        # 获取 TA 融合的 [CLS] token 表示


        h_ta_full = self.cross_ta(h_l, h_a)  # [B, 9, D] with [CLS] + 8 tokens
        gate_input = torch.cat([h_hyper_v[:,0],h_ta_full[:, 0]] ,dim=-1) 
        h_ta = h_ta_full[:, 1:]  # [B, 8, D] - remaining tokens for gated fusion
        gate, log_prob, prob = self.gate_controller(gate_input)  # [B, 1]
        gate = gate.unsqueeze(2)  # [B, 1, 1]

        # Apply gate to h_ta
        h_ta_gated = gate * h_ta

        h_hyper_v = self.fusion([h_v_list[layer_idx]], h_l, h_a, h_hyper_v)
        
        h_hyper_v = h_hyper_v + h_ta_gated

        # === E步：基于融合结果预测视觉表示
        h_v_feature = self.visual_predictor(h_hyper_v)  # (B, token_len, dim)
        
        if layer_idx+1<len(h_v_list):
            # === M步：将 h_v_feature 映射并反馈更新 h_v_list[layer_idx]
            h_v_feedback = self.feedback_proj(h_v_feature)  # (B, token_len, dim)

            h_v_list[layer_idx+1] = h_v_list[layer_idx+1] + h_v_feedback
            

        return h_hyper_v, h_v_list, h_v_feature, log_prob, gate


class VisionFeatureAggregator(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=dim, nhead=8),
            num_layers=1
        )

    def forward(self, x):  
        return self.transformer(x)
    
    
class ResNetWithDropout(nn.Module):
    def __init__(self, dropout_rate=0.3):
        super().__init__()
        resnet18 = models.resnet18(pretrained=True)
        self.features = nn.Sequential(*list(resnet18.children())[:-1])  # 去掉最后一层 FC
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        x = self.features(x)
        x = x.squeeze(-1).squeeze(-1)
        x = self.dropout(x)
        return x


class Model(nn.Module):
    def __init__(self, AHL_depth=3, fusion_layer_depth=2, bert_pretrained='bert-base-uncased'):
        super(Model, self).__init__()

        self.h_hyper = nn.Parameter(torch.ones(1, 8, 128))

        # 基础编码器
        self.bertmodel = BertTextEncoder(use_finetune=True, transformers='bert', pretrained=bert_pretrained)
        self.img_extractor = ResNetWithDropout(dropout_rate=0.3)
        self.vf_aggregator = VisionFeatureAggregator(dim=512)

        # 投影层
        self.proj_l0 = nn.Linear(768, 128)
        self.proj_a0 = nn.Linear(5, 128)
        self.proj_v0 = nn.Sequential(nn.Linear(512, 128), nn.Dropout(0.3))

        # 序列处理
        self.proj_l = Transformer(num_frames=50, save_hidden=False, token_len=8, dim=128, depth=1, heads=8, mlp_dim=128)
        self.proj_a = Transformer(num_frames=50, save_hidden=False, token_len=8, dim=128, depth=1, heads=8, mlp_dim=128)
        self.proj_v = Transformer(num_frames=5, save_hidden=False, token_len=8, dim=128, depth=1, heads=8, mlp_dim=128)

        

        # 视觉主干编码器（输出多个中间层）
        self.vision_encoder = Transformer(num_frames=8, save_hidden=True, token_len=None, dim=128, depth=AHL_depth-1, heads=8, mlp_dim=128)

        # 新增：多层 EMALMTBlock 替代 h_hyper_layer_v
        self.em_almt_blocks = nn.ModuleList([
            EMALMTBlock(dim=128, heads=8, mlp_dim=128, dropout=0.)
            for _ in range(AHL_depth)
        ])

        # 跨模态融合 & 情感预测
        self.fusion_layer = CrossTransformer(source_num_frames=8, tgt_num_frames=8, dim=128, depth=fusion_layer_depth, heads=8, mlp_dim=128)
        self.cls_head = nn.Linear(128, 1)
        
        self.vision_feature_extractor = nn.Sequential(
            Transformer(num_frames=8, save_hidden=False, token_len=1, dim=128, depth=1, heads=8, mlp_dim=64, dropout=0.),   
            nn.Dropout(0.3)
        )

        
    def forward(self, x_visual, x_audio, x_text):
        b = x_visual.size(0)
        h_hyper_v = repeat(self.h_hyper, '1 n d -> b n d', b=b)

        # 视觉特征提取
        x_visual = self.img_extractor(x_visual.view(-1, 3, 224, 224))
        x_visual = x_visual.view(b, 5, 512)
        x_visual = self.proj_v0(self.vf_aggregator(x_visual))



        x_audio = self.proj_a0(x_audio)
        x_text = self.bertmodel(x_text)
        x_text = self.proj_l0(x_text)

        h_v = self.proj_v(x_visual)[:, :8]
        h_a = self.proj_a(x_audio)[:, :8]
        h_l = self.proj_l(x_text)[:, :8]
        
        
        
        h_v_list = list(self.vision_encoder(h_v))  # 多层视觉表示
        
        log_probs = []
        gates = []

        for i, block in enumerate(self.em_almt_blocks):
            h_hyper_v, h_v_list, h_v_feature, log_prob, gate = block(h_v_list, h_l, h_a, h_hyper_v, layer_idx=i)
            log_probs.append(log_prob)
            gates.append(gate)


        # 情感预测
        feat = self.fusion_layer(h_hyper_v, h_v_list[-1])[:, 0]
        output = self.cls_head(feat)

        return output, self.vision_feature_extractor(h_v_feature)[:,0], log_probs, gates


In [72]:
model = Model().to(device)

emo_loss_fn = torch.nn.MSELoss()

def SupervisedContrastiveLoss(h_v_cls, labels, sigma=0.1):
    """
    :param h_v_cls: (B, D) - visual embeddings after projection (final representation)
    :param labels: (B,) - continuous sentiment labels
    :param sigma: float - Gaussian kernel width (smaller = stricter similarity)
    :return: scalar contrastive loss
    """
    batch_size = h_v_cls.shape[0]

    # Normalize embeddings
    h_v_cls = torch.nn.functional.normalize(h_v_cls, dim=1)  # (B, D)

    # Cosine similarity matrix: (B, B)
    similarity_matrix = torch.cosine_similarity(h_v_cls.unsqueeze(1), h_v_cls.unsqueeze(0), dim=-1)

    # Compute label distance matrix
    labels = labels.contiguous().view(-1, 1)  # (B, 1)
    label_diff = labels - labels.T  # (B, B)

    # Gaussian weight based on label similarity
    weight_matrix = torch.exp(- (label_diff ** 2) / (2 * sigma ** 2)).to(h_v_cls.device)  # (B, B)

    # Exponential of cosine similarity
    exp_sim = torch.exp(similarity_matrix)  # (B, B)

    # Avoid self-comparison by masking diagonal to 0
    identity_mask = torch.eye(batch_size, device=h_v_cls.device)
    weight_matrix = weight_matrix * (1 - identity_mask)
    exp_sim = exp_sim * (1 - identity_mask)

    # Compute weighted contrastive loss
    numerator = torch.sum(exp_sim * weight_matrix, dim=1)  # (B,)
    denominator = torch.sum(exp_sim, dim=1) + 1e-8          # (B,)

    loss = -torch.log(numerator / denominator + 1e-8)       # (B,)
    return loss.mean()


In [73]:
# =========================
# 设置训练参数
# =========================
epochs = 200
train_mae, val_mae = [], []

# 优化器设置：降低 weight_decay 并调整 betas 提高稳定性
optimizer = torch.optim.AdamW(model.parameters(),lr=1e-4, weight_decay=1e-4)
scheduler_steplr = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=0.9 * epochs)
scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=0.1 * epochs, after_scheduler=scheduler_steplr)



# =========================
# AverageMeter 类用于记录损失
# =========================
class AverageMeter(object):
    def __init__(self):
        self.value = 0
        self.value_avg = 0
        self.value_sum = 0
        self.count = 0

    def reset(self):
        self.value = 0
        self.value_avg = 0
        self.value_sum = 0
        self.count = 0

    def update(self, value, count):
        self.value = value
        self.value_sum += value * count
        self.count += count
        self.value_avg = self.value_sum / self.count

# =========================
# Train 函数
# =========================
def train(model, train_loader, optimizer, epoch):
    train_pbar = tqdm(enumerate(train_loader), total=len(train_loader), dynamic_ncols=True)
    losses = AverageMeter()
    l1s = AverageMeter()
    l2s = AverageMeter()
    rls = AverageMeter()  # NEW

    y_pred, y_true = [], []

    model.train()
    moving_baseline = 0.6  # 初始 baseline
    alpha = 0.9            # 滑动平均系数
    
    lambda_recon = 0.3
        
    lambda_rl = 0.5        # NEW: 可调 RL loss 权重

    for cur_iter, data in train_pbar:
        img = data['images'].to(device)
        audio = data['audio'].to(device)
        text = data['text'].to(device)
        label = data['labels'].to(device).view(-1, 1)
        batchsize = img.shape[0]

        optimizer.zero_grad()

        # === 模型输出（多返回） ===
        output, h_v_generated, log_probs, gates = model(img, audio, text)

        l1 = emo_loss_fn(output, label)
        l2 = SupervisedContrastiveLoss(h_v_generated, label)

        # === RL loss ===
        mae_clamped = l1.detach().clamp(0, 2)
        reward = 1.0 - mae_clamped / 2.0
        advantage = reward - moving_baseline
        loss_rl = sum([-(lp.squeeze() * advantage).mean() for lp in log_probs])  # sum over layers

        moving_baseline = alpha * moving_baseline + (1 - alpha) * reward.item()  # EMA 更新

        # === 总损失 ===
        loss = l1 + lambda_recon * l2 + lambda_rl * loss_rl

        loss.backward()
        optimizer.step()

        losses.update(loss.item(), batchsize)
        l1s.update(l1.item(), batchsize)
        l2s.update(l2.item(), batchsize)
        rls.update(loss_rl.item(), batchsize)

        y_pred.append(output.cpu())
        y_true.append(label.cpu())

        train_pbar.set_description(f'train [Epoch {epoch}]')
        train_pbar.set_postfix({
            'loss': '{:.5f}'.format(losses.value_avg),
            'emo_loss': '{:.4f}'.format(l1s.value_avg),
            'visual_loss': '{:.4f}'.format(l2s.value_avg),
            'RL_loss': '{:.4f}'.format(rls.value_avg),  # NEW
            'λ_recon': '{:.2f}'.format(lambda_recon),
            'λ_RL': '{:.2f}'.format(lambda_rl),
            'lr': '{:.2e}'.format(optimizer.param_groups[0]['lr'])
        }, refresh=False)

    pred, true = torch.cat(y_pred), torch.cat(y_true)
    mae = torch.mean(torch.abs(pred - true)).item()
    tqdm.write(f"train MAE: {mae:.4f}")
    train_mae.append(mae)

# =========================
# Evaluate 函数
# =========================
def evaluate(model, eval_loader, optimizer, epoch):
    test_pbar = tqdm(enumerate(eval_loader), total=len(eval_loader))
    losses = AverageMeter()
    l1s = AverageMeter()
    l2s = AverageMeter()

    y_pred, y_true = [], []

    model.eval()
    with torch.no_grad():
        for cur_iter, data in test_pbar:
            img = data['images'].to(device)
            audio = data['audio'].to(device)
            text = data['text'].to(device)
            label = data['labels'].to(device).view(-1, 1)
            batchsize = img.shape[0]

            output, h_v_generated, log_probs, gates = model(img, audio, text)
            
           

            l1 = emo_loss_fn(output, label)
            l2 = SupervisedContrastiveLoss(h_v_generated, label)

            
            
            
            lambda_recon = 0.3


            loss = l1 + lambda_recon * l2
            
            y_pred.append(output.cpu())
            y_true.append(label.cpu())

            losses.update(loss.item(), batchsize)
            l1s.update(l1.item(), batchsize)
            l2s.update(l2.item(), batchsize)


            test_pbar.set_description('eval')
            test_pbar.set_postfix({
                'loss': '{:.5f}'.format(losses.value_avg),
                'emo_loss': '{:.4f}'.format(l1s.value_avg),
                'visual_loss': '{:.4f}'.format(l2s.value_avg),
                'lambda_recon': '{:.2f}'.format(lambda_recon),
                'lr': '{:.2e}'.format(optimizer.param_groups[0]['lr'])
            }, refresh=False)

        pred, true = torch.cat(y_pred), torch.cat(y_true)
        mae = torch.mean(torch.abs(pred - true)).item()
        corr = np.corrcoef(pred.squeeze(), true.squeeze())[0, 1]
        print('evaluate MAE:', mae, ' Corr:', corr)
        gate_stats = torch.cat(gates, dim=1)  # [B, num_layers]
        print('Gate Activation Ratio:', gate_stats.float().mean(dim=0).cpu().numpy())

In [74]:
for epoch in range(1, epochs + 1):
    train(model, train_loader, optimizer, epoch)
    evaluate(model, val_loader, optimizer, epoch)  
    scheduler_warmup.step()

train [Epoch 1]: 100%|██████████| 161/161 [01:05<00:00,  2.46it/s, loss=3.94266, emo_loss=2.8632, visual_loss=3.7234, RL_loss=-0.0751, λ_recon=0.30, λ_RL=0.50, lr=0.00e+00]


train MAE: 1.4093


eval: 100%|██████████| 8/8 [00:02<00:00,  3.10it/s, loss=4.34516, emo_loss=3.2178, visual_loss=3.7578, lambda_recon=0.30, lr=0.00e+00]

evaluate MAE: 1.4769006967544556  Corr: -0.09858677744650485
Gate Activation Ratio: [[0.8]
 [0. ]
 [0.8]]



train [Epoch 2]: 100%|██████████| 161/161 [01:05<00:00,  2.46it/s, loss=3.08975, emo_loss=1.9874, visual_loss=3.7336, RL_loss=-0.0355, λ_recon=0.30, λ_RL=0.50, lr=5.00e-06]


train MAE: 1.1970


eval: 100%|██████████| 8/8 [00:02<00:00,  2.76it/s, loss=2.63152, emo_loss=1.5052, visual_loss=3.7546, lambda_recon=0.30, lr=5.00e-06]

evaluate MAE: 0.9874880313873291  Corr: 0.6802205984537061
Gate Activation Ratio: [[0.4]
 [0.4]
 [0.6]]



train [Epoch 3]: 100%|██████████| 161/161 [01:05<00:00,  2.47it/s, loss=1.68538, emo_loss=0.5850, visual_loss=3.6170, RL_loss=0.0306, λ_recon=0.30, λ_RL=0.50, lr=1.00e-05]


train MAE: 0.5884


eval: 100%|██████████| 8/8 [00:02<00:00,  2.90it/s, loss=2.25929, emo_loss=1.1629, visual_loss=3.6546, lambda_recon=0.30, lr=1.00e-05]

evaluate MAE: 0.7906757593154907  Corr: 0.7734984026085001
Gate Activation Ratio: [[0.4]
 [0.2]
 [0.8]]



train [Epoch 4]: 100%|██████████| 161/161 [01:06<00:00,  2.41it/s, loss=1.22492, emo_loss=0.1796, visual_loss=3.4152, RL_loss=0.0416, λ_recon=0.30, λ_RL=0.50, lr=1.50e-05]


train MAE: 0.3224


eval: 100%|██████████| 8/8 [00:02<00:00,  2.87it/s, loss=2.11142, emo_loss=1.0282, visual_loss=3.6108, lambda_recon=0.30, lr=1.50e-05]

evaluate MAE: 0.7295584082603455  Corr: 0.7916502947221236
Gate Activation Ratio: [[0.4]
 [0.4]
 [0.6]]



train [Epoch 5]: 100%|██████████| 161/161 [01:06<00:00,  2.43it/s, loss=1.12116, emo_loss=0.0967, visual_loss=3.3401, RL_loss=0.0448, λ_recon=0.30, λ_RL=0.50, lr=2.00e-05]


train MAE: 0.2418


eval: 100%|██████████| 8/8 [00:02<00:00,  2.95it/s, loss=2.13520, emo_loss=1.0504, visual_loss=3.6161, lambda_recon=0.30, lr=2.00e-05]

evaluate MAE: 0.7299193143844604  Corr: 0.7906522032539811
Gate Activation Ratio: [[0.8]
 [0.2]
 [0.6]]



train [Epoch 6]: 100%|██████████| 161/161 [01:06<00:00,  2.43it/s, loss=1.09370, emo_loss=0.0674, visual_loss=3.3469, RL_loss=0.0444, λ_recon=0.30, λ_RL=0.50, lr=2.50e-05]


train MAE: 0.2023


eval: 100%|██████████| 8/8 [00:02<00:00,  3.16it/s, loss=2.07552, emo_loss=0.9981, visual_loss=3.5914, lambda_recon=0.30, lr=2.50e-05]

evaluate MAE: 0.7144607305526733  Corr: 0.7981809908396389
Gate Activation Ratio: [[0.4]
 [0.2]
 [0.8]]



train [Epoch 7]: 100%|██████████| 161/161 [01:05<00:00,  2.46it/s, loss=1.05329, emo_loss=0.0505, visual_loss=3.2694, RL_loss=0.0439, λ_recon=0.30, λ_RL=0.50, lr=3.00e-05]


train MAE: 0.1759


eval: 100%|██████████| 8/8 [00:02<00:00,  2.99it/s, loss=2.03307, emo_loss=0.9599, visual_loss=3.5774, lambda_recon=0.30, lr=3.00e-05]

evaluate MAE: 0.71439129114151  Corr: 0.8060413763563077
Gate Activation Ratio: [[0.4]
 [0.2]
 [0.2]]



train [Epoch 8]: 100%|██████████| 161/161 [01:06<00:00,  2.44it/s, loss=1.02061, emo_loss=0.0351, visual_loss=3.2141, RL_loss=0.0426, λ_recon=0.30, λ_RL=0.50, lr=3.50e-05]


train MAE: 0.1460


eval: 100%|██████████| 8/8 [00:02<00:00,  3.37it/s, loss=2.00172, emo_loss=0.9347, visual_loss=3.5567, lambda_recon=0.30, lr=3.50e-05]

evaluate MAE: 0.6827972531318665  Corr: 0.8121473592201877
Gate Activation Ratio: [[0.4]
 [0.6]
 [1. ]]



train [Epoch 9]: 100%|██████████| 161/161 [01:06<00:00,  2.42it/s, loss=1.01786, emo_loss=0.0350, visual_loss=3.2058, RL_loss=0.0421, λ_recon=0.30, λ_RL=0.50, lr=4.00e-05]


train MAE: 0.1458


eval: 100%|██████████| 8/8 [00:02<00:00,  3.09it/s, loss=2.02071, emo_loss=0.9522, visual_loss=3.5618, lambda_recon=0.30, lr=4.00e-05]

evaluate MAE: 0.6974905133247375  Corr: 0.8085733065793647
Gate Activation Ratio: [[0.4]
 [0. ]
 [1. ]]



train [Epoch 10]: 100%|██████████| 161/161 [01:05<00:00,  2.45it/s, loss=1.00952, emo_loss=0.0313, visual_loss=3.1837, RL_loss=0.0462, λ_recon=0.30, λ_RL=0.50, lr=4.50e-05]


train MAE: 0.1374


eval: 100%|██████████| 8/8 [00:02<00:00,  2.85it/s, loss=2.00696, emo_loss=0.9450, visual_loss=3.5398, lambda_recon=0.30, lr=4.50e-05]

evaluate MAE: 0.6790040135383606  Corr: 0.8139536679058663
Gate Activation Ratio: [[0.6]
 [0. ]
 [0.6]]



train [Epoch 11]: 100%|██████████| 161/161 [01:06<00:00,  2.42it/s, loss=1.00061, emo_loss=0.0237, visual_loss=3.1781, RL_loss=0.0470, λ_recon=0.30, λ_RL=0.50, lr=5.00e-05]


train MAE: 0.1202


eval: 100%|██████████| 8/8 [00:02<00:00,  3.06it/s, loss=2.06477, emo_loss=0.9960, visual_loss=3.5626, lambda_recon=0.30, lr=5.00e-05]

evaluate MAE: 0.6925716996192932  Corr: 0.8017197943446017
Gate Activation Ratio: [[0.6]
 [0.8]
 [0.6]]



train [Epoch 12]: 100%|██████████| 161/161 [01:05<00:00,  2.45it/s, loss=1.00364, emo_loss=0.0210, visual_loss=3.1975, RL_loss=0.0469, λ_recon=0.30, λ_RL=0.50, lr=5.50e-05]


train MAE: 0.1131


eval: 100%|██████████| 8/8 [00:02<00:00,  3.07it/s, loss=2.04962, emo_loss=0.9790, visual_loss=3.5688, lambda_recon=0.30, lr=5.50e-05]

evaluate MAE: 0.702087938785553  Corr: 0.8033829768293697
Gate Activation Ratio: [[0.4]
 [0.6]
 [0.8]]



train [Epoch 13]: 100%|██████████| 161/161 [01:06<00:00,  2.44it/s, loss=0.99189, emo_loss=0.0269, visual_loss=3.1412, RL_loss=0.0453, λ_recon=0.30, λ_RL=0.50, lr=6.00e-05]


train MAE: 0.1292


eval: 100%|██████████| 8/8 [00:02<00:00,  2.95it/s, loss=2.00043, emo_loss=0.9339, visual_loss=3.5551, lambda_recon=0.30, lr=6.00e-05]

evaluate MAE: 0.6722085475921631  Corr: 0.81267617341162
Gate Activation Ratio: [[0.8]
 [0.4]
 [0.6]]



train [Epoch 14]: 100%|██████████| 161/161 [01:06<00:00,  2.44it/s, loss=0.99873, emo_loss=0.0214, visual_loss=3.1854, RL_loss=0.0434, λ_recon=0.30, λ_RL=0.50, lr=6.50e-05]


train MAE: 0.1142


eval: 100%|██████████| 8/8 [00:02<00:00,  2.85it/s, loss=2.01790, emo_loss=0.9486, visual_loss=3.5643, lambda_recon=0.30, lr=6.50e-05]

evaluate MAE: 0.6622635126113892  Corr: 0.8183754778734837
Gate Activation Ratio: [[0.8]
 [0.6]
 [1. ]]



train [Epoch 15]: 100%|██████████| 161/161 [01:06<00:00,  2.44it/s, loss=0.98819, emo_loss=0.0199, visual_loss=3.1568, RL_loss=0.0425, λ_recon=0.30, λ_RL=0.50, lr=7.00e-05]

train MAE: 0.1104



eval: 100%|██████████| 8/8 [00:02<00:00,  3.16it/s, loss=2.01276, emo_loss=0.9446, visual_loss=3.5605, lambda_recon=0.30, lr=7.00e-05]

evaluate MAE: 0.6885285973548889  Corr: 0.8144479419306822
Gate Activation Ratio: [[0.6]
 [0.6]
 [1. ]]



train [Epoch 16]: 100%|██████████| 161/161 [01:06<00:00,  2.43it/s, loss=1.11038, emo_loss=0.1420, visual_loss=3.2005, RL_loss=0.0165, λ_recon=0.30, λ_RL=0.50, lr=7.50e-05]


train MAE: 0.2138


eval: 100%|██████████| 8/8 [00:02<00:00,  3.04it/s, loss=3.15552, emo_loss=2.0581, visual_loss=3.6580, lambda_recon=0.30, lr=7.50e-05]

evaluate MAE: 1.1169084310531616  Corr: 0.6571122467844689
Gate Activation Ratio: [[0.8]
 [0.2]
 [0.8]]



train [Epoch 17]: 100%|██████████| 161/161 [01:06<00:00,  2.43it/s, loss=1.31627, emo_loss=0.2874, visual_loss=3.3732, RL_loss=0.0338, λ_recon=0.30, λ_RL=0.50, lr=8.00e-05]


train MAE: 0.3677


eval: 100%|██████████| 8/8 [00:02<00:00,  3.21it/s, loss=2.23356, emo_loss=1.1496, visual_loss=3.6133, lambda_recon=0.30, lr=8.00e-05]

evaluate MAE: 0.7821822762489319  Corr: 0.7607688400366236
Gate Activation Ratio: [[1. ]
 [1. ]
 [0.8]]



train [Epoch 18]: 100%|██████████| 161/161 [01:06<00:00,  2.42it/s, loss=1.02571, emo_loss=0.0485, visual_loss=3.2063, RL_loss=0.0307, λ_recon=0.30, λ_RL=0.50, lr=8.50e-05]


train MAE: 0.1593


eval: 100%|██████████| 8/8 [00:02<00:00,  3.60it/s, loss=2.23285, emo_loss=1.1481, visual_loss=3.6158, lambda_recon=0.30, lr=8.50e-05]

evaluate MAE: 0.7852872610092163  Corr: 0.7672892618766178
Gate Activation Ratio: [[0.8]
 [1. ]
 [0.6]]



train [Epoch 19]: 100%|██████████| 161/161 [01:05<00:00,  2.48it/s, loss=0.97943, emo_loss=0.0264, visual_loss=3.1261, RL_loss=0.0303, λ_recon=0.30, λ_RL=0.50, lr=9.00e-05]


train MAE: 0.1236


eval: 100%|██████████| 8/8 [00:02<00:00,  2.99it/s, loss=2.17500, emo_loss=1.0964, visual_loss=3.5952, lambda_recon=0.30, lr=9.00e-05]

evaluate MAE: 0.7728356122970581  Corr: 0.7748519320060059
Gate Activation Ratio: [[1. ]
 [0.8]
 [0.8]]



train [Epoch 20]: 100%|██████████| 161/161 [01:06<00:00,  2.43it/s, loss=1.03043, emo_loss=0.0548, visual_loss=3.1979, RL_loss=0.0325, λ_recon=0.30, λ_RL=0.50, lr=9.50e-05]


train MAE: 0.1468


eval: 100%|██████████| 8/8 [00:02<00:00,  2.81it/s, loss=2.23293, emo_loss=1.1556, visual_loss=3.5912, lambda_recon=0.30, lr=9.50e-05]

evaluate MAE: 0.7834370136260986  Corr: 0.7643250780230577
Gate Activation Ratio: [[1. ]
 [1. ]
 [0.2]]



train [Epoch 21]: 100%|██████████| 161/161 [01:07<00:00,  2.40it/s, loss=0.99029, emo_loss=0.0179, visual_loss=3.1840, RL_loss=0.0345, λ_recon=0.30, λ_RL=0.50, lr=1.00e-04]


train MAE: 0.1020


eval: 100%|██████████| 8/8 [00:02<00:00,  2.82it/s, loss=2.21414, emo_loss=1.1342, visual_loss=3.5998, lambda_recon=0.30, lr=1.00e-04]

evaluate MAE: 0.7812797427177429  Corr: 0.7680000734270104
Gate Activation Ratio: [[1. ]
 [0.8]
 [0.8]]



train [Epoch 22]: 100%|██████████| 161/161 [01:06<00:00,  2.43it/s, loss=0.96726, emo_loss=0.0112, visual_loss=3.1394, RL_loss=0.0284, λ_recon=0.30, λ_RL=0.50, lr=1.00e-04]


train MAE: 0.0798


eval: 100%|██████████| 8/8 [00:02<00:00,  2.99it/s, loss=2.44490, emo_loss=1.3479, visual_loss=3.6565, lambda_recon=0.30, lr=1.00e-04]

evaluate MAE: 0.8325932025909424  Corr: 0.7238955540217584
Gate Activation Ratio: [[0.6]
 [0.8]
 [0.8]]



train [Epoch 23]: 100%|██████████| 161/161 [01:05<00:00,  2.47it/s, loss=0.96100, emo_loss=0.0101, visual_loss=3.1294, RL_loss=0.0241, λ_recon=0.30, λ_RL=0.50, lr=1.00e-04]


train MAE: 0.0773


eval: 100%|██████████| 8/8 [00:02<00:00,  3.05it/s, loss=2.22572, emo_loss=1.1398, visual_loss=3.6196, lambda_recon=0.30, lr=1.00e-04]

evaluate MAE: 0.7779790759086609  Corr: 0.7663909387500535
Gate Activation Ratio: [[1. ]
 [0.8]
 [0.8]]



train [Epoch 24]: 100%|██████████| 161/161 [01:06<00:00,  2.42it/s, loss=0.93966, emo_loss=0.0076, visual_loss=3.0758, RL_loss=0.0186, λ_recon=0.30, λ_RL=0.50, lr=1.00e-04]


train MAE: 0.0682


eval: 100%|██████████| 8/8 [00:02<00:00,  2.96it/s, loss=2.20222, emo_loss=1.1158, visual_loss=3.6215, lambda_recon=0.30, lr=1.00e-04]

evaluate MAE: 0.764259934425354  Corr: 0.773798200195511
Gate Activation Ratio: [[0.8]
 [1. ]
 [1. ]]



train [Epoch 25]: 100%|██████████| 161/161 [01:06<00:00,  2.43it/s, loss=0.95755, emo_loss=0.0080, visual_loss=3.1335, RL_loss=0.0190, λ_recon=0.30, λ_RL=0.50, lr=9.99e-05]


train MAE: 0.0697


eval: 100%|██████████| 8/8 [00:02<00:00,  3.26it/s, loss=2.20956, emo_loss=1.1189, visual_loss=3.6357, lambda_recon=0.30, lr=9.99e-05]

evaluate MAE: 0.765745222568512  Corr: 0.7725198335940116
Gate Activation Ratio: [[1.]
 [1.]
 [1.]]



train [Epoch 26]: 100%|██████████| 161/161 [01:05<00:00,  2.48it/s, loss=0.96425, emo_loss=0.0067, visual_loss=3.1557, RL_loss=0.0218, λ_recon=0.30, λ_RL=0.50, lr=9.99e-05]


train MAE: 0.0635


eval: 100%|██████████| 8/8 [00:02<00:00,  3.09it/s, loss=2.15979, emo_loss=1.0767, visual_loss=3.6104, lambda_recon=0.30, lr=9.99e-05]

evaluate MAE: 0.7567517757415771  Corr: 0.7794375695378685
Gate Activation Ratio: [[0.6]
 [1. ]
 [1. ]]



train [Epoch 27]: 100%|██████████| 161/161 [01:06<00:00,  2.43it/s, loss=0.94474, emo_loss=0.0068, visual_loss=3.0995, RL_loss=0.0162, λ_recon=0.30, λ_RL=0.50, lr=9.98e-05]


train MAE: 0.0645


eval: 100%|██████████| 8/8 [00:02<00:00,  3.23it/s, loss=2.17324, emo_loss=1.0890, visual_loss=3.6141, lambda_recon=0.30, lr=9.98e-05]

evaluate MAE: 0.7668777108192444  Corr: 0.7751052888545593
Gate Activation Ratio: [[1.]
 [1.]
 [1.]]



train [Epoch 28]: 100%|██████████| 161/161 [01:06<00:00,  2.41it/s, loss=0.96671, emo_loss=0.0064, visual_loss=3.1795, RL_loss=0.0130, λ_recon=0.30, λ_RL=0.50, lr=9.97e-05]


train MAE: 0.0626


eval: 100%|██████████| 8/8 [00:02<00:00,  2.86it/s, loss=2.18453, emo_loss=1.1005, visual_loss=3.6133, lambda_recon=0.30, lr=9.97e-05]

evaluate MAE: 0.7687922716140747  Corr: 0.7744642826526671
Gate Activation Ratio: [[0.8]
 [1. ]
 [1. ]]



train [Epoch 29]: 100%|██████████| 161/161 [01:06<00:00,  2.44it/s, loss=0.95538, emo_loss=0.0063, visual_loss=3.1412, RL_loss=0.0135, λ_recon=0.30, λ_RL=0.50, lr=9.96e-05]


train MAE: 0.0615


eval: 100%|██████████| 8/8 [00:02<00:00,  2.84it/s, loss=2.15984, emo_loss=1.0735, visual_loss=3.6210, lambda_recon=0.30, lr=9.96e-05]

evaluate MAE: 0.7601327896118164  Corr: 0.7795404084015222
Gate Activation Ratio: [[0.8]
 [1. ]
 [1. ]]



train [Epoch 30]: 100%|██████████| 161/161 [01:06<00:00,  2.43it/s, loss=0.95451, emo_loss=0.0077, visual_loss=3.1354, RL_loss=0.0125, λ_recon=0.30, λ_RL=0.50, lr=9.95e-05]


train MAE: 0.0687


eval: 100%|██████████| 8/8 [00:02<00:00,  2.98it/s, loss=2.17548, emo_loss=1.0894, visual_loss=3.6201, lambda_recon=0.30, lr=9.95e-05]

evaluate MAE: 0.7654860019683838  Corr: 0.7761794814521925
Gate Activation Ratio: [[1.]
 [1.]
 [1.]]



train [Epoch 31]: 100%|██████████| 161/161 [01:05<00:00,  2.47it/s, loss=0.95325, emo_loss=0.0061, visual_loss=3.1390, RL_loss=0.0109, λ_recon=0.30, λ_RL=0.50, lr=9.94e-05]


train MAE: 0.0604


eval: 100%|██████████| 8/8 [00:02<00:00,  3.00it/s, loss=2.20907, emo_loss=1.1159, visual_loss=3.6441, lambda_recon=0.30, lr=9.94e-05]

evaluate MAE: 0.7593504190444946  Corr: 0.7734879195865404
Gate Activation Ratio: [[0.8]
 [1. ]
 [1. ]]



train [Epoch 32]:  65%|██████▌   | 105/161 [00:43<00:23,  2.43it/s, loss=0.95108, emo_loss=0.0064, visual_loss=3.1183, RL_loss=0.0184, λ_recon=0.30, λ_RL=0.50, lr=9.92e-05]


KeyboardInterrupt: 