In [None]:
import os
import json
import copy
import time
import torch
import pickle
import random
import string
import logging
import numpy as np
import torch.nn as nn
from tqdm import tqdm
from G2MoE.utils import *
from G2MoE.model.Loss import Triplet
from G2MoE.model.Testing_Process import test_e2e
from G2MoE.model.Training_Process import train_e2e_moe, val_e2e_moe
from G2MoE.model.Graph_Constructor import Graph
from G2MoE.model.Graph_Encoder import Contrast_Encoder, End2End_Encoder
from G2MoE.model.MoEGAT import StepWiseGraphConvLayerMoE



In [6]:
args = {'gpu': 2, 'seed': 42, 'batch_size': 1, 'input': 768, 'hidden': 2048, 'heads': 128,
       'epochs': 100, 'log_every': 20, 'lr': 0.0003, 'dropout': 0.3, 'num_layers': 3, 'max_word_num': 185, 'kappa': 0.2, 'target_main_contribution': 0.6}

triple_args = {'margin': 0.5, 'topk': 5, 'thred': [0.4, 0.5]}

# if use_dynamic_topk is True --> top_k is the maximum of available topk
moe_args = {'num_experts': 3,
            'top_k': 3,
            'use_dynamic_topk': True}

# Config
EARLY_STOPPING_PATIENCE = 20 
EARLY_STOPPING_METRIC = "val_rouge2" 

dataset_name = "Abmusu_weight_blend_concateV2_contribute"  

model_save_root_path = '/kaggle/working/'
c_patient, s_patient = 30,30
best_r2, best_c_loss, best_s_loss = 0, 10000, 10000
history = {'loss': [], 'val_loss': []}
model_save_path = ""

seed_everything(args['seed'])
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
import pickle

# Load train
with open("/kaggle/input/vlsp-graphs-sent/trainGraphs.pkl", "rb") as f:
    trainGraphs = pickle.load(f)

# Load test
with open("/kaggle/input/vlsp-test-graph/testGraphs.pkl", "rb") as f:
    testGraphs = pickle.load(f)

print(f"Train graphs: {len(trainGraphs)} samples")
print(f"Test graphs: {len(testGraphs)} samples")

Train graphs: 200 samples
Test graphs: 150 samples


In [8]:
!pip install wandb



In [None]:
import wandb
wandb.login(key="...")  # Add your wandb API key here

[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mlexuanhung1234576[0m ([33mTripleAIML[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [None]:
# Initialize wandb
import wandb

use_dynamic_topk = moe_args['use_dynamic_topk']
target_main_contribution = args['target_main_contribution']

run_name = f"{dataset_name}_DynTopK-{use_dynamic_topk}_TMC-{target_main_contribution}"

wandb.init(
    project="MoEGAT",
    name=run_name,
    config={
        "learning_rate": args['lr'],
        "epochs": args['epochs'],
        "hidden_dim": args['hidden'],
        "dropout": args['dropout'],
        "heads": args['heads'],
        "max_word_num": args['max_word_num'],
        "kappa": args['kappa'],
        "num_experts": moe_args['num_experts'],
        "top_k": moe_args['top_k'],
        "use_dynamic_topk": use_dynamic_topk,
        "triplet_margin": triple_args['margin'],
        "triplet_topk": triple_args['topk'],
        "triple_thres0": triple_args['thred'][0],
        "triple_thres1": triple_args['thred'][1],
        "target_main_contribution": target_main_contribution,
    }
)

# Model initialization
c_graph_encoder = StepWiseGraphConvLayerMoE(
    in_dim=args['input'], out_dim=args['hidden'], hid_dim=args['hidden'],
    dropout_p=args['dropout'], act=nn.LeakyReLU(), nheads=args['heads'], iter=1, 
    num_experts=moe_args['num_experts'], top_k=moe_args['top_k'], 
    use_dynamic_topk=moe_args['use_dynamic_topk'], target_main_contribution=args['target_main_contribution']
).to(device)

contrast_filter = Contrast_Encoder(c_graph_encoder, args['hidden'], dropout_p=args['dropout']).to(device)
summarization_encoder = End2End_Encoder(args['input'], args['hidden'], args['dropout']).to(device)
loss_method = Triplet(margin=triple_args['margin'], topk=triple_args['topk'])

trainset = trainGraphs

optimizer = torch.optim.Adam([
    {'params': summarization_encoder.parameters()},
    {'params': contrast_filter.parameters()}
], lr=args['lr'], weight_decay=1e-5)

[34m[1mwandb[0m: Tracking run with wandb version 0.21.0
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/kaggle/working/wandb/run-20251122_172003-1agf5aeo[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33mAbmusu_weight_blend_concateV2_contribute_DynTopK-True_TMC-0.6[0m
[34m[1mwandb[0m: ‚≠êÔ∏è View project at [34m[4mhttps://wandb.ai/TripleAIML/MoEGAT[0m
[34m[1mwandb[0m: üöÄ View run at [34m[4mhttps://wandb.ai/TripleAIML/MoEGAT/runs/1agf5aeo[0m


In [None]:
best_loss = float('inf')
early_stopping_counter = 0

wandb.watch(contrast_filter, log="all", log_freq=100)
wandb.watch(summarization_encoder, log="all", log_freq=100)

for i in range(args['epochs']):
    random.shuffle(trainset)

    if c_patient < 0:
        for p in contrast_filter.parameters():
            p.requires_grad = False

    model = [contrast_filter, summarization_encoder]

    train_loss, train_rouge2, train_c_loss, train_s_loss, train_moe_loss = train_e2e_moe(
        trainset, model, optimizer, loss_method,
        args['max_word_num'], args['kappa'], triple_args['thred'], epoch=i
    )

    history['loss'].append(train_loss)

    wandb.log({
        "epoch": i,
        "train/loss": train_loss,
        "train/c_loss": train_c_loss,
        "train/s_loss": train_s_loss,
        "train/moe_loss": train_moe_loss,
        "train/rouge2": train_rouge2
    })

    torch.cuda.empty_cache()

    if train_loss < best_loss:
        best_loss = train_loss
        early_stopping_counter = 0

        encoder_path = os.path.join(model_save_root_path, f"best_encoder_epoch{i}_loss{best_loss:.4f}.mdl")
        contrast_path = os.path.join(model_save_root_path, f"best_contrast_epoch{i}_loss{best_loss:.4f}.mdl")

        torch.save(summarization_encoder.state_dict(), encoder_path)
        torch.save(contrast_filter.state_dict(), contrast_path)

        wandb.log({
            "best_loss": best_loss,
            "best_epoch": i
        })

    else:
        early_stopping_counter += 1

    if early_stopping_counter >= EARLY_STOPPING_PATIENCE:
        break

wandb.finish()

Epoch 0
Batch 0, Loss: 1.1995265483856201
Batch 0, C-Loss: 1.575150489807129
Batch 0, S-Loss: 2.0224266052246094
Batch 0, MoE-Loss: 0.0010024189250543714
Main contribution: 49.98%
Dynamic Top-K: {'mean': 2.5142858028411865, 'std': 0.5034045577049255, 'min': 2.0, 'max': 3.0}
Batch 20, Loss: 1.073665976524353
Batch 20, C-Loss: 1.7633060216903687
Batch 20, S-Loss: 1.4570229053497314
Batch 20, MoE-Loss: 0.0006688821595162153
Main contribution: 58.05%
Dynamic Top-K: {'mean': 2.5303030014038086, 'std': 0.5029052495956421, 'min': 2.0, 'max': 3.0}
Batch 40, Loss: 1.070842981338501
Batch 40, C-Loss: 1.6801612377166748
Batch 40, S-Loss: 1.5319209098815918
Batch 40, MoE-Loss: 0.0004467074468266219
Main contribution: 58.26%
Dynamic Top-K: {'mean': 2.6500000953674316, 'std': 0.48304590582847595, 'min': 2.0, 'max': 3.0}
Batch 60, Loss: 1.0348056554794312
Batch 60, C-Loss: 1.5173969268798828
Batch 60, S-Loss: 1.5866889953613281
Batch 60, MoE-Loss: 0.00033027486642822623
Main contribution: 59.60%
Dyna



At Epoch 0, Val Loss: 0.6614705920219421, Val CLoss: 0.5021767020225525, Val SLoss: 1.4822044372558594, Val MoELoss: 3.1000970921013504e-05, Val R2: 0.30034733333333336
Val Main contribution: 59.69%
Val Dynamic Top-K: {'mean': 2.5305773337682087, 'std': 0.46984738945960997, 'min': 2.013333333333333, 'max': 3.0}
Epoch 0 Has best R2 Score of 0.30034733333333336, saved Model to /kaggle/working/c_0_0.30034733333333336.mdl
Epoch 1
Batch 0, Loss: 0.6203859448432922
Batch 0, C-Loss: 0.5209908485412598
Batch 0, S-Loss: 1.3401020765304565
Batch 0, MoE-Loss: 6.485879566753283e-05
Main contribution: 59.35%
Dynamic Top-K: {'mean': 2.481818199157715, 'std': 0.5019561648368835, 'min': 2.0, 'max': 3.0}
Batch 20, Loss: 0.6890371441841125
Batch 20, C-Loss: 0.5110191106796265
Batch 20, S-Loss: 1.5560725927352905
Batch 20, MoE-Loss: 1.9500910639180802e-05
Main contribution: 59.89%
Dynamic Top-K: {'mean': 2.4250001907348633, 'std': 0.49641573429107666, 'min': 2.0, 'max': 3.0}
Batch 40, Loss: 0.66862422227

[34m[1mwandb[0m: uploading output.log; uploading config.yaml; uploading working/c_1_0.3118606666666667.mdl; uploading working/c_3_0.312058.mdl; uploading working/c_7_0.3130266666666667.mdl (+ 4 more)
[34m[1mwandb[0m: uploading working/c_1_0.3118606666666667.mdl; uploading working/c_3_0.312058.mdl; uploading working/c_7_0.3130266666666667.mdl; uploading working/c_15_0.313472.mdl; uploading working/c_19_0.31512199999999996.mdl (+ 2 more)
[34m[1mwandb[0m: uploading working/c_1_0.3118606666666667.mdl; uploading working/c_3_0.312058.mdl; uploading working/c_7_0.3130266666666667.mdl; uploading working/c_15_0.313472.mdl; uploading working/c_23_0.3158586666666667.mdl (+ 1 more)
[34m[1mwandb[0m: uploading history steps 534-534, summary, console lines 2901-2907
[34m[1mwandb[0m:                                                                                
[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run history:
[34m[1mwandb[0m:                    best_epoch ‚ñÅ‚ñÅ‚ñÇ‚ñÉ‚ñÜ‚ñá‚ñà


# Evaluation

In [12]:
c_file_path = model_save_path
e_file_path = c_file_path.replace("/c_", "/e_")

In [13]:
c_file_path

'/kaggle/working/c_23_0.3158586666666667.mdl'

In [14]:
c_graph_encoder = StepWiseGraphConvLayerMoE(in_dim=args['input'], out_dim=args['hidden'], hid_dim=args['hidden'],
                                         dropout_p=args['dropout'], act=nn.LeakyReLU(), nheads=args['heads'], iter=1, 
                                            num_experts=moe_args['num_experts'], top_k=moe_args['top_k'], use_dynamic_topk=moe_args['use_dynamic_topk']).to(device)


contrast_filter = Contrast_Encoder(c_graph_encoder, args['hidden'], dropout_p=args['dropout']).to(device)
summarization_encoder = End2End_Encoder(args['input'], args['hidden'], args['dropout']).to(device)

In [15]:
summarization_encoder.load_state_dict(torch.load(e_file_path, map_location=torch.device('cuda')), strict=False)
contrast_filter.load_state_dict(torch.load(c_file_path, map_location=torch.device('cuda')), strict=False)
model = [contrast_filter, summarization_encoder]

In [16]:
predicts, goldens = test_e2e(testGraphs, model, args["max_word_num"], args["kappa"])
rouge_scores = get_rouges(goldens, predicts)
bert_score = get_bert_score(goldens, predicts)

tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/482 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

2025-11-22 17:51:11.650839: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1763833871.827130      19 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1763833871.876588      19 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


model.safetensors:   0%|          | 0.00/1.42G [00:00<?, ?B/s]

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


calculating scores...
computing bert embedding.


  0%|          | 0/5 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/3 [00:00<?, ?it/s]

done in 12.39 seconds, 12.10 sentences/sec


In [17]:
def print_evaluation_scores(rouge_means, bert):
    for rouge_type, metrics in rouge_means.items():
        print(f"{rouge_type.upper()} Scores:")
        for metric, value in metrics.items():
            print(f"  {metric}: {value:.4f}")
        print()
        
    print("Bert Scores:")
    for metric, value in bert.items():
        print(f"  {metric}: {value:.4f}")

In [18]:
print_evaluation_scores(rouge_scores, bert_score)

ROUGE1 Scores:
  p: 0.4668
  r: 0.5999
  f: 0.5151

ROUGE2 Scores:
  p: 0.2801
  r: 0.3888
  f: 0.3159

ROUGEL Scores:
  p: 0.4376
  r: 0.5618
  f: 0.4826

Bert Scores:
  precision: 0.9052
  recall: 0.9086
  fmeasure: 0.9069


In [19]:
def print_for_copy(rouge_means, bert):
    for rouge_type, metrics in rouge_means.items():
        for metric, value in metrics.items():
            print(f"{value:.4f}", end="\t")
        
    for metric, value in bert.items():
        print(f" {value:.4f}", end="\t")

In [20]:
print_for_copy(rouge_scores, bert_score)

0.4668	0.5999	0.5151	0.2801	0.3888	0.3159	0.4376	0.5618	0.4826	 0.9052	 0.9086	 0.9069	

In [21]:
# L∆∞u predicts ra file
with open("predicts.txt", "w", encoding="utf-8") as f:
    for line in predicts:
        f.write(line.strip() + "\n")

# L∆∞u goldens ra file
with open("goldens.txt", "w", encoding="utf-8") as f:
    for line in goldens:
        f.write(line.strip() + "\n")