In [1]:
!pip install sentencepiece protobuf --quiet

In [2]:
import json
import logging
import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer
import numpy as np

# üéØ Multi-EAOS Model - Production Ready

## Improvements in This Notebook

### ‚úÖ Data Management
- Train/Validation split (80/20) for proper model evaluation
- Prevents overfitting and enables model selection

### ‚úÖ Training Process
- **Validation loop** to monitor generalization performance
- **Early stopping** (patience: 20 epochs) to prevent overfitting
- **Best model selection** based on validation F1-score
- **Periodic checkpoints** (every 10 epochs) for training resume

### ‚úÖ Evaluation Metrics
- **Accuracy**: Overall correctness of predictions
- **Precision**: Ratio of correct predictions among all predictions
- **Recall**: Ratio of correct predictions among all ground truth samples
- **F1-Score**: Harmonic mean of Precision and Recall
- Real-time metrics tracking during training and validation
- Comprehensive evaluation function for test sets

### ‚úÖ Model Persistence
- Organized folder structure:
  - `models/checkpoints/` - Training checkpoints
  - `models/best_model/` - Best model + config for deployment
- Saves model configuration (config.json) with:
  - Label mappings
  - Training metrics (accuracy, precision, recall, F1-score)
  - Model hyperparameters
- Easy model loading for inference

### ‚úÖ Backend Integration
- `EAOSInference` class for production use
- Confidence threshold filtering
- Batch prediction support
- Ready for FastAPI/Flask integration

### üöÄ Quick Start
1. Run all cells to prepare data and define model
2. Execute training: `run_training(train_dataset, val_dataset)`
3. Best model automatically saved to `models/best_model/`
4. Use `EAOSInference` class in your backend
5. Evaluate model performance with comprehensive metrics

---

In [3]:
# C·∫•u h√¨nh log ƒë·ªÉ d·ªÖ theo d√µi l·ªói
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def find_span_indices(text, phrase):
    """
    T√¨m v·ªã tr√≠ b·∫Øt ƒë·∫ßu v√† k·∫øt th√∫c c·ªßa m·ªôt c·ª•m t·ª´ trong c√¢u.
    Tr·∫£ v·ªÅ (start_char_idx, end_char_idx).
    """
    if not phrase or phrase.lower() == "null" or phrase == "":
        return (-1, -1)  # X·ª≠ l√Ω tr∆∞·ªùng h·ª£p Implicit (·∫©n)

    start_idx = text.lower().find(phrase.lower())
    if start_idx == -1:
        return None  # Kh√¥ng t√¨m th·∫•y kh·ªõp

    end_idx = start_idx + len(phrase)
    return (start_idx, end_idx)

# --- PH·∫¶N THAY ƒê·ªîI L·ªöN: ƒê·ªåC JSON M·ªöI ---
def process_json_array_data(file_path):
    processed_data = []

    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            full_content = f.read()

        # Ph√¢n t√≠ch JSON
        raw_json = json.loads(full_content)

        # L·∫•y m·∫£ng t·ª´ key "results" thay v√¨ l√† JSON Array tr·ª±c ti·∫øp
        raw_entries = raw_json.get("results", [])
        if not isinstance(raw_entries, list):
            logger.error("Key 'results' kh√¥ng ph·∫£i l√† m·∫£ng. Ki·ªÉm tra l·∫°i file JSON.")
            return []

    except json.JSONDecodeError as e:
        logger.error(f"L·ªñI L·ªöN: T·ªáp tin kh√¥ng ph·∫£i l√† JSON h·ª£p l·ªá. L·ªói t·∫°i char {e.pos}: {e}")
        return []
    except Exception as e:
        logger.error(f"L·ªói khi ƒë·ªçc file: {e}")
        return []

    for i, entry in enumerate(raw_entries):
        try:
            text = entry["text"]
            labels = entry["labels"]
        except KeyError as e:
            logger.warning(f"M·∫´u {i+1}: Thi·∫øu key {e} trong ƒë·ªëi t∆∞·ª£ng JSON. B·ªè qua.")
            continue

        valid_labels = []
        for label in labels:
            entity_span = find_span_indices(text, label['entity'])
            opinion_span = find_span_indices(text, label['opinion'])

            if entity_span is None:
                logger.warning(f"M·∫´u {i+1}: Kh√¥ng t√¨m th·∫•y Entity '{label.get('entity')}' trong text.")
                continue
            if opinion_span is None:
                logger.warning(f"M·∫´u {i+1}: Kh√¥ng t√¨m th·∫•y Opinion '{label.get('opinion')}' trong text.")
                continue

            new_label = {
                "entity_text": label['entity'],
                "entity_span": entity_span,
                "opinion_text": label['opinion'],
                "opinion_span": opinion_span,
                "aspect": label['aspect'],
                "sentiment": label['sentiment']
            }
            valid_labels.append(new_label)

        if valid_labels:
            processed_data.append({
                "text": text,
                "labels": valid_labels
            })

    return processed_data

In [6]:
data = process_json_array_data('./filtered.json')
print(f"ƒê√£ x·ª≠ l√Ω th√†nh c√¥ng {len(data)} m·∫´u d·ªØ li·ªáu.")
print(data[0])

ERROR:__main__:L·ªói khi ƒë·ªçc file: [Errno 2] No such file or directory: './filtered.json'


ƒê√£ x·ª≠ l√Ω th√†nh c√¥ng 0 m·∫´u d·ªØ li·ªáu.


IndexError: list index out of range

In [None]:
# 1. C·∫•u h√¨nh c√°c nh√£n (Mapping)
# D·ª±a theo b√†i b√°o: 5 nh√≥m Aspect v√† 3 nh√≥m Sentiment
ASPECT_MAP = {
    "ƒê·ªãa ƒëi·ªÉm": 1,
    "K·ªãch b·∫£n": 2,
    "D√†n d·ª±ng": 3,
    "D√†n cast": 4,
    "Kh√°ch m·ªùi": 5,
    "Kh·∫£ nƒÉng ch∆°i tr√≤ ch∆°i": 6,
    "Qu·∫£ng c√°o": 7,
    "Th·ª≠ th√°ch": 8,
    "T∆∞∆°ng t√°c gi·ªØa c√°c th√†nh vi√™n": 9,
    "Tinh th·∫ßn ƒë·ªìng ƒë·ªôi": 10,
    "Kh√°c": 0
}
SENTIMENT_MAP = {
    "t√≠ch c·ª±c": 1,
    "ti√™u c·ª±c": 2,
    "trung t√≠nh": 0,
}

# T·∫£i Tokenizer c·ªßa PhoBERT (d√πng b·∫£n base ho·∫∑c v2 ƒë·ªÅu ƒë∆∞·ª£c)
tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base")

In [None]:
try:
    tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base", use_fast=True)
    print("ƒê√£ load Fast Tokenizer th√†nh c√¥ng!")
except Exception as e:
    print(f"Kh√¥ng th·ªÉ load Fast Tokenizer: {e}")
    # N·∫øu l·ªói n√†y x·∫£y ra, b·∫°n bu·ªôc ph·∫£i d√πng C√°ch 2 b√™n d∆∞·ªõi

In [None]:
class EAOSDatasetManual(Dataset):
    def __init__(self, data, tokenizer, max_len=256, max_quads=4):
        self.data = data
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.max_quads = max_quads

    def __len__(self):
        return len(self.data)

    def manual_find_token_span(self, encoding, char_span):
        """
        H√†m th·ªß c√¥ng ƒë·ªÉ √°nh x·∫° (char_start, char_end) -> (token_start, token_end)
        D√πng cho Slow Tokenizer kh√¥ng c√≥ offset_mapping.
        """
        char_start, char_end = char_span
        if char_start == -1: return 0, 0 # Implicit

        # Tokenizer Slow tr·∫£ v·ªÅ input_ids. Ta c·∫ßn convert ng∆∞·ª£c l·∫°i tokens ƒë·ªÉ ki·ªÉm tra ƒë·ªô d√†i
        tokens = self.tokenizer.convert_ids_to_tokens(encoding['input_ids'][0])

        # T·∫°o l·∫°i map th·ªß c√¥ng
        current_char_idx = 0
        token_spans = [] # List l∆∞u [(start, end), (start, end)...] cho t·ª´ng token

        # L∆∞u √Ω: PhoBERT d√πng BPE v√† k√Ω t·ª± '_' thay cho kho·∫£ng tr·∫Øng, ho·∫∑c '@@'
        # Logic n√†y ch·ªâ mang t√≠nh t∆∞∆°ng ƒë·ªëi cho PhoBERT, c·∫ßn c·∫©n th·∫≠n v·ªõi special tokens
        for token in tokens:
            # B·ªè qua special tokens ban ƒë·∫ßu (nh∆∞ <s>)
            if token in [self.tokenizer.bos_token, self.tokenizer.eos_token, self.tokenizer.pad_token]:
                token_len = 0
            else:
                # Clean token ƒë·ªÉ t√≠nh ƒë·ªô d√†i th·ª±c t·∫ø trong c√¢u
                clean_token = token.replace('@@', '').replace('_', ' ')
                token_len = len(clean_token)

            # N·∫øu l√† token ƒë·∫ßu ti√™n (sau <s>), c√≥ th·ªÉ c·∫ßn strip kho·∫£ng tr·∫Øng th·ª´a n·∫øu c√≥
            start = current_char_idx
            end = current_char_idx + token_len
            token_spans.append((start, end))

            # C·∫≠p nh·∫≠t v·ªã tr√≠ con tr·ªè (gi·∫£ ƒë·ªãnh token n·ªëi ti·∫øp nhau kh√≠t)
            # V·ªõi PhoBERT Slow, logic n√†y c√≥ th·ªÉ l·ªách 1-2 k√Ω t·ª± do c√°ch x·ª≠ l√Ω d·∫•u c√°ch
            # ƒê√¢y l√† ƒëi·ªÉm y·∫øu c·ªßa Slow Tokenizer.
            current_char_idx += token_len

        # T√¨m token index d·ª±a tr√™n char index
        token_start, token_end = -1, -1

        # Logic so kh·ªõp g·∫ßn ƒë√∫ng (Approximate Matching)
        for idx, (t_start, t_end) in enumerate(token_spans):
            # N·∫øu span c·ªßa token giao nhau v·ªõi span c·ªßa entity
            if t_start <= char_start < t_end:
                token_start = idx
            if t_start < char_end <= t_end:
                token_end = idx

        if token_start != -1 and token_end == -1: token_end = token_start # S·ª≠a l·ªói n·∫øu ch·ªâ b·∫Øt ƒë∆∞·ª£c start

        return (token_start, token_end) if (token_start != -1) else (0, 0)

    def __getitem__(self, idx):
        item = self.data[idx]
        text = item['text']
        labels = item['labels']

        # KH√îNG D√ôNG return_offsets_mapping n·ªØa
        encoding = self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=self.max_len,
            return_tensors='pt'
        )

        # N·∫øu d√πng C√°ch 1 th√†nh c√¥ng (Fast), b·∫°n c√≥ th·ªÉ uncomment d√≤ng n√†y ƒë·ªÉ l·∫•y offset
        # offset_mapping = encoding.offset_mapping[0] if encoding.is_fast else None

        target_matrix = np.full((self.max_quads, 6), -1, dtype=int)

        for i, label in enumerate(labels):
            if i >= self.max_quads: break

            # D√πng h√†m manual
            e_s, e_e = self.manual_find_token_span(encoding, label['entity_span'])
            o_s, o_e = self.manual_find_token_span(encoding, label['opinion_span'])

            asp_id = ASPECT_MAP.get(label['aspect'], 4)
            sent_id = SENTIMENT_MAP.get(label['sentiment'], 2)

            target_matrix[i] = [e_s, e_e, o_s, o_e, asp_id, sent_id]

        return {
            'input_ids': encoding['input_ids'].squeeze(),
            'attention_mask': encoding['attention_mask'].squeeze(),
            'targets': torch.tensor(target_matrix, dtype=torch.long)
        }

In [None]:
ds = EAOSDatasetManual(data, tokenizer)

In [None]:
# Split data into train/validation sets (80/20)
from sklearn.model_selection import train_test_split

train_data, val_data = train_test_split(data, test_size=0.2, random_state=42)

# Create datasets
train_dataset = EAOSDatasetManual(train_data, tokenizer)
val_dataset = EAOSDatasetManual(val_data, tokenizer)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

In [None]:
# L·∫•y th·ª≠ m·∫´u ƒë·∫ßu ti√™n
sample = ds[0]

print("Input IDs shape:", sample['input_ids'].shape) # N√™n l√† [256]
print("Targets shape:", sample['targets'].shape)     # N√™n l√† [5, 6]
print("M·∫´u Targets ƒë·∫ßu ti√™n:\n", sample['targets'])
print("Decode th·ª≠ input:", tokenizer.decode(sample['input_ids']))

In [None]:
import torch
import torch.nn as nn
from transformers import AutoModel

class MultiEAOSModel(nn.Module):
    def __init__(self, model_name="vinai/phobert-base",
                 num_aspects=11,
                 num_sentiments=3,
                 max_len=256,
                 max_quads=4,
                 hidden_dim=256):
        super(MultiEAOSModel, self).__init__()

        # 1. BERT Encoder (PhoBERT) [cite: 12, 45]
        # D√πng ƒë·ªÉ m√£ h√≥a ng·ªØ nghƒ©a ng·ªØ c·∫£nh ti·∫øng Vi·ªát
        self.bert = AutoModel.from_pretrained(model_name)
        self.bert_hidden_size = self.bert.config.hidden_size # Th∆∞·ªùng l√† 768

        # 2. BiLSTM Layer [cite: 161, 163]
        # H·ªçc ph·ª• thu·ªôc xa v√† ng·ªØ c·∫£nh hai chi·ªÅu
        self.lstm = nn.LSTM(
            input_size=self.bert_hidden_size,
            hidden_size=hidden_dim,
            num_layers=1,
            batch_first=True,
            bidirectional=True
        )
        # BiLSTM output size = hidden_dim * 2 (do 2 chi·ªÅu)
        self.lstm_out_dim = hidden_dim * 2

        # 3. Learnable Queries & Attention
        # ƒê·ªÉ d·ª± ƒëo√°n Max_Quads b·ªô t·ª© c√πng l√∫c, ta t·∫°o ra c√°c "Query Vectors"
        # M·ªói Query ƒë·∫°i di·ªán cho m·ªôt "khe ch·ª©a" b·ªô t·ª© ti·ªÅm nƒÉng
        self.quad_queries = nn.Parameter(torch.randn(max_quads, self.lstm_out_dim))

        # Multi-Head Attention: Queries (Quad Slots) t√¨m ki·∫øm th√¥ng tin t·ª´ Key/Value (BiLSTM Output)
        self.attention = nn.MultiheadAttention(
            embed_dim=self.lstm_out_dim,
            num_heads=4,
            batch_first=True
        )

        # 4. Prediction Heads (6 ƒë·∫ßu ra cho m·ªói Quad) [cite: 169]
        # Output: (Entity_Start, Entity_End, Opinion_Start, Opinion_End, Aspect, Sentiment)

        # D·ª± ƒëo√°n v·ªã tr√≠ trong c√¢u (Pointer Network) -> Output size = max_len
        self.fc_e_start = nn.Linear(self.lstm_out_dim, max_len)
        self.fc_e_end   = nn.Linear(self.lstm_out_dim, max_len)
        self.fc_o_start = nn.Linear(self.lstm_out_dim, max_len)
        self.fc_o_end   = nn.Linear(self.lstm_out_dim, max_len)

        # D·ª± ƒëo√°n ph√¢n lo·∫°i
        self.fc_aspect    = nn.Linear(self.lstm_out_dim, num_aspects)
        self.fc_sentiment = nn.Linear(self.lstm_out_dim, num_sentiments)

        self.dropout = nn.Dropout(0.1)

    def forward(self, input_ids, attention_mask):
        # --- A. Encoding Phase ---
        # Output PhoBERT: (Batch, Seq_Len, 768)
        bert_out = self.bert(input_ids=input_ids, attention_mask=attention_mask)[0]

        # Output BiLSTM: (Batch, Seq_Len, Hidden_Dim * 2)
        lstm_out, _ = self.lstm(bert_out)

        # --- B. Multi-EAOS Decoding Phase ---
        batch_size = input_ids.size(0)

        # Nh√¢n b·∫£n Queries cho c·∫£ Batch: (Batch, Max_Quads, Hidden_Dim * 2)
        queries = self.quad_queries.unsqueeze(0).expand(batch_size, -1, -1)

        # Attention: Queries (Q) soi v√†o LSTM Output (K, V) ƒë·ªÉ l·∫•y th√¥ng tin
        # Output: (Batch, Max_Quads, Hidden_Dim * 2)
        attn_out, _ = self.attention(query=queries, key=lstm_out, value=lstm_out)

        attn_out = self.dropout(attn_out)

        # --- C. Prediction Phase ---
        # M·ªói vector trong attn_out ƒë·∫°i di·ªán cho 1 b·ªô t·ª© (Quad)

        # 1. D·ª± ƒëo√°n v·ªã tr√≠ (Logits: Batch, Max_Quads, Max_Len)
        e_start_logits = self.fc_e_start(attn_out)
        e_end_logits   = self.fc_e_end(attn_out)
        o_start_logits = self.fc_o_start(attn_out)
        o_end_logits   = self.fc_o_end(attn_out)

        # 2. D·ª± ƒëo√°n ph√¢n lo·∫°i (Logits: Batch, Max_Quads, Num_Classes)
        aspect_logits    = self.fc_aspect(attn_out)
        sentiment_logits = self.fc_sentiment(attn_out)

        return {
            "e_start": e_start_logits,
            "e_end": e_end_logits,
            "o_start": o_start_logits,
            "o_end": o_end_logits,
            "aspect": aspect_logits,
            "sentiment": sentiment_logits
        }

In [None]:
# Gi·∫£ l·∫≠p input t·ª´ Step 2
dummy_input_ids = torch.randint(0, 1000, (2, 256)) # Batch size = 2, Max len = 256
dummy_mask = torch.ones((2, 256))

# Kh·ªüi t·∫°o model
model = MultiEAOSModel(max_quads=4)

# Forward pass
outputs = model(dummy_input_ids, dummy_mask)

print("K√≠ch th∆∞·ªõc ƒë·∫ßu ra:")
print("Entity Start Logits:", outputs['e_start'].shape) # K√¨ v·ªçng: [2, 4, 256]
print("Aspect Logits:      ", outputs['aspect'].shape)  # K√¨ v·ªçng: [2, 4, 11] (11 Aspect categories)
print("Sentiment Logits:   ", outputs['sentiment'].shape) # K√¨ v·ªçng: [2, 4, 3] (3 Sentiment classes)

In [None]:
import torch.optim as optim
from tqdm import tqdm # Th∆∞ vi·ªán t·∫°o thanh ti·∫øn tr√¨nh (loading bar)

class MultiEAOSLoss(nn.Module):
    def __init__(self):
        super(MultiEAOSLoss, self).__init__()
        # ignore_index=-1 gi√∫p b·ªè qua c√°c v·ªã tr√≠ padding trong qu√° tr√¨nh t√≠nh loss
        self.criterion = nn.CrossEntropyLoss(ignore_index=-1)

    def forward(self, outputs, targets):
        """
        outputs: Dictionary ch·ª©a logits t·ª´ model (e_start, e_end, ..., sentiment)
        targets: Tensor [Batch, Quads, 6] ch·ª©a nh√£n ƒë√∫ng
                 Th·ª© t·ª± c·ªôt: 0:e_s, 1:e_e, 2:o_s, 3:o_e, 4:aspect, 5:sentiment
        """
        # Ch√∫ng ta g·ªôp (Flatten) dimension Batch v√† Quads l·∫°i ƒë·ªÉ t√≠nh Loss m·ªôt th·ªÉ
        # Shape Logits: (Batch * Quads, Num_Classes)
        # Shape Targets: (Batch * Quads)

        # 1. Loss cho v·ªã tr√≠ Entity (Start & End)
        loss_e_start = self.criterion(outputs['e_start'].view(-1, outputs['e_start'].shape[-1]),
                                      targets[:, :, 0].view(-1))
        loss_e_end   = self.criterion(outputs['e_end'].view(-1, outputs['e_end'].shape[-1]),
                                      targets[:, :, 1].view(-1))

        # 2. Loss cho v·ªã tr√≠ Opinion (Start & End)
        loss_o_start = self.criterion(outputs['o_start'].view(-1, outputs['o_start'].shape[-1]),
                                      targets[:, :, 2].view(-1))
        loss_o_end   = self.criterion(outputs['o_end'].view(-1, outputs['o_end'].shape[-1]),
                                      targets[:, :, 3].view(-1))

        # 3. Loss cho Aspect & Sentiment
        loss_aspect    = self.criterion(outputs['aspect'].view(-1, outputs['aspect'].shape[-1]),
                                        targets[:, :, 4].view(-1))
        loss_sentiment = self.criterion(outputs['sentiment'].view(-1, outputs['sentiment'].shape[-1]),
                                        targets[:, :, 5].view(-1))

        # T·ªïng h·ª£p Loss (c√≥ th·ªÉ th√™m tr·ªçng s·ªë weight n·∫øu mu·ªën ∆∞u ti√™n task n√†o h∆°n)
        total_loss = loss_e_start + loss_e_end + loss_o_start + loss_o_end + loss_aspect + loss_sentiment
        return total_loss

In [None]:
import os
import torch
import json
from datetime import datetime

# Create model directories
MODEL_DIR = "./models"
CHECKPOINT_DIR = os.path.join(MODEL_DIR, "checkpoints")
BEST_MODEL_DIR = os.path.join(MODEL_DIR, "best_model")

os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(BEST_MODEL_DIR, exist_ok=True)

def save_checkpoint(model, optimizer, epoch, train_loss, val_loss, filename):
    """
    Save training checkpoint with all necessary information
    """
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
        'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    }
    torch.save(checkpoint, filename)
    print(f"‚úÖ Saved checkpoint: {filename}")

def save_best_model(model, tokenizer, epoch, val_loss, metrics=None):
    """
    Save the best model with all necessary files for deployment
    """
    # Save model weights
    model_path = os.path.join(BEST_MODEL_DIR, "model.pth")
    torch.save(model.state_dict(), model_path)
    
    # Save model configuration
    config = {
        "model_name": "vinai/phobert-base",
        "num_aspects": 11,
        "num_sentiments": 3,
        "max_len": 256,
        "max_quads": 4,
        "hidden_dim": 256,
        "best_epoch": epoch,
        "best_val_loss": float(val_loss),
        "aspect_map": ASPECT_MAP,
        "sentiment_map": SENTIMENT_MAP,
        "saved_at": datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    }
    
    if metrics:
        config.update(metrics)
    
    config_path = os.path.join(BEST_MODEL_DIR, "config.json")
    with open(config_path, 'w', encoding='utf-8') as f:
        json.dump(config, f, ensure_ascii=False, indent=2)
    
    print(f"üèÜ Saved best model at epoch {epoch} with val_loss: {val_loss:.4f}")
    
def load_checkpoint(model, optimizer, filename, device):
    """
    Load training checkpoint to resume training
    """
    if not os.path.isfile(filename):
        print(f"‚ö†Ô∏è  Checkpoint not found: {filename}. Training from scratch.")
        return model, optimizer, 0, float('inf')
    
    print(f"üîÑ Loading checkpoint from: {filename}")
    checkpoint = torch.load(filename, map_location=device, weights_only=False)
    
    model.load_state_dict(checkpoint['model_state_dict'])
    
    if optimizer:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    
    start_epoch = checkpoint['epoch'] + 1
    best_val_loss = checkpoint.get('val_loss', float('inf'))
    
    print(f"‚úÖ Resumed from epoch {checkpoint['epoch']}, best val_loss: {best_val_loss:.4f}")
    return model, optimizer, start_epoch, best_val_loss

def load_model_for_inference(model_class, device):
    """
    Load the best saved model for inference (backend use)
    """
    config_path = os.path.join(BEST_MODEL_DIR, "config.json")
    model_path = os.path.join(BEST_MODEL_DIR, "model.pth")
    
    if not os.path.exists(config_path) or not os.path.exists(model_path):
        raise FileNotFoundError("Best model not found. Please train the model first.")
    
    # Load configuration
    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)
    
    # Initialize model with saved config
    model = model_class(
        model_name=config['model_name'],
        num_aspects=config['num_aspects'],
        num_sentiments=config['num_sentiments'],
        max_len=config['max_len'],
        max_quads=config['max_quads'],
        hidden_dim=config['hidden_dim']
    ).to(device)
    
    # Load weights
    model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
    model.eval()
    
    print(f"‚úÖ Loaded model from epoch {config['best_epoch']} with val_loss: {config['best_val_loss']:.4f}")
    return model, config

print("üìÅ Model directories created:")
print(f"  - Checkpoints: {CHECKPOINT_DIR}")
print(f"  - Best model: {BEST_MODEL_DIR}")

In [None]:
from torch.utils.data import DataLoader

# --- CONFIGURATION ---
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 8
LR = 2e-5
EPOCHS = 200
SAVE_EVERY = 10  # Save checkpoint every N epochs
EARLY_STOP_PATIENCE = 20  # Stop if no improvement for N epochs

def calculate_metrics(outputs, targets):
    """
    Calculate Accuracy, Precision, Recall, F1-score for EAOS predictions
    
    A prediction is considered correct (TP) if ALL 6 components match:
    - entity_start, entity_end, opinion_start, opinion_end, aspect, sentiment
    
    Returns:
        dict: Contains accuracy, precision, recall, f1_score
    """
    batch_size, max_quads = targets.shape[0], targets.shape[1]
    
    # Get predictions (argmax)
    pred_e_start = torch.argmax(outputs['e_start'], dim=-1)  # [batch, quads]
    pred_e_end = torch.argmax(outputs['e_end'], dim=-1)
    pred_o_start = torch.argmax(outputs['o_start'], dim=-1)
    pred_o_end = torch.argmax(outputs['o_end'], dim=-1)
    pred_aspect = torch.argmax(outputs['aspect'], dim=-1)
    pred_sentiment = torch.argmax(outputs['sentiment'], dim=-1)
    
    # Flatten predictions and targets
    # Shape: [batch * quads]
    pred_e_start = pred_e_start.view(-1)
    pred_e_end = pred_e_end.view(-1)
    pred_o_start = pred_o_start.view(-1)
    pred_o_end = pred_o_end.view(-1)
    pred_aspect = pred_aspect.view(-1)
    pred_sentiment = pred_sentiment.view(-1)
    
    true_e_start = targets[:, :, 0].view(-1)
    true_e_end = targets[:, :, 1].view(-1)
    true_o_start = targets[:, :, 2].view(-1)
    true_o_end = targets[:, :, 3].view(-1)
    true_aspect = targets[:, :, 4].view(-1)
    true_sentiment = targets[:, :, 5].view(-1)
    
    # Create mask for valid targets (not padding, i.e., != -1)
    valid_mask = (true_e_start != -1)
    
    if valid_mask.sum() == 0:
        return {'accuracy': 0.0, 'precision': 0.0, 'recall': 0.0, 'f1_score': 0.0}
    
    # Check if all 6 components match (strict matching)
    all_match = (
        (pred_e_start == true_e_start) &
        (pred_e_end == true_e_end) &
        (pred_o_start == true_o_start) &
        (pred_o_end == true_o_end) &
        (pred_aspect == true_aspect) &
        (pred_sentiment == true_sentiment)
    )
    
    # Apply mask to only count valid predictions
    all_match = all_match & valid_mask
    
    # True Positives: predictions that match ground truth
    tp = all_match.sum().item()
    
    # Total valid ground truth samples
    total_valid = valid_mask.sum().item()
    
    # For EAOS, we consider:
    # - TP: Correct predictions
    # - FP: Incorrect predictions (where target exists)
    # - FN: Missed predictions (same as FP in this case since we have fixed slots)
    fp = total_valid - tp
    fn = total_valid - tp
    
    # Calculate metrics
    accuracy = tp / total_valid if total_valid > 0 else 0.0
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
    
    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1_score': f1_score
    }

def train_epoch(model, data_loader, optimizer, loss_fn, device):
    """Train for one epoch"""
    model.train()
    total_loss = 0
    all_metrics = {'accuracy': 0, 'precision': 0, 'recall': 0, 'f1_score': 0}
    
    progress_bar = tqdm(data_loader, desc="Training")
    
    for batch in progress_bar:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        targets = batch['targets'].to(device)
        
        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask)
        loss = loss_fn(outputs, targets)
        loss.backward()
        optimizer.step()
        
        # Calculate metrics
        metrics = calculate_metrics(outputs, targets)
        for key in all_metrics:
            all_metrics[key] += metrics[key]
        
        total_loss += loss.item()
        progress_bar.set_postfix({
            'loss': f"{loss.item():.4f}",
            'f1': f"{metrics['f1_score']:.4f}"
        })
    
    avg_loss = total_loss / len(data_loader)
    avg_metrics = {key: val / len(data_loader) for key, val in all_metrics.items()}
    
    return avg_loss, avg_metrics

def validate_epoch(model, data_loader, loss_fn, device):
    """Validate the model"""
    model.eval()
    total_loss = 0
    all_metrics = {'accuracy': 0, 'precision': 0, 'recall': 0, 'f1_score': 0}
    
    progress_bar = tqdm(data_loader, desc="Validation")
    
    with torch.no_grad():
        for batch in progress_bar:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            targets = batch['targets'].to(device)
            
            outputs = model(input_ids, attention_mask)
            loss = loss_fn(outputs, targets)
            
            # Calculate metrics
            metrics = calculate_metrics(outputs, targets)
            for key in all_metrics:
                all_metrics[key] += metrics[key]
            
            total_loss += loss.item()
            progress_bar.set_postfix({
                'val_loss': f"{loss.item():.4f}",
                'f1': f"{metrics['f1_score']:.4f}"
            })
    
    avg_loss = total_loss / len(data_loader)
    avg_metrics = {key: val / len(data_loader) for key, val in all_metrics.items()}
    
    return avg_loss, avg_metrics

def run_training(train_dataset, val_dataset, resume_from=None):
    """
    Main training function with validation and model saving
    
    Args:
        train_dataset: Training dataset
        val_dataset: Validation dataset
        resume_from: Path to checkpoint to resume from (optional)
    """
    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
    
    # Initialize model and loss
    model = MultiEAOSModel(max_quads=4).to(DEVICE)
    loss_fn = MultiEAOSLoss().to(DEVICE)
    optimizer = optim.AdamW(model.parameters(), lr=LR)
    
    # Resume from checkpoint if specified
    start_epoch = 0
    best_val_loss = float('inf')
    best_val_f1 = 0.0
    patience_counter = 0
    
    if resume_from:
        model, optimizer, start_epoch, best_val_loss = load_checkpoint(
            model, optimizer, resume_from, DEVICE
        )
    
    print(f"üöÄ Starting training on: {DEVICE}")
    print(f"üìä Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}")
    print(f"üéØ Batch size: {BATCH_SIZE}, Learning rate: {LR}")
    print("-" * 70)
    
    # Training loop
    for epoch in range(start_epoch, EPOCHS):
        print(f"\n{'='*70}")
        print(f"Epoch {epoch + 1}/{EPOCHS}")
        print(f"{'='*70}")
        
        # Train
        train_loss, train_metrics = train_epoch(model, train_loader, optimizer, loss_fn, DEVICE)
        
        # Validate
        val_loss, val_metrics = validate_epoch(model, val_loader, loss_fn, DEVICE)
        
        print(f"\nüìà Epoch {epoch + 1} Results:")
        print(f"   Train Loss: {train_loss:.4f}")
        print(f"   Train Metrics:")
        print(f"      - Accuracy:  {train_metrics['accuracy']:.4f}")
        print(f"      - Precision: {train_metrics['precision']:.4f}")
        print(f"      - Recall:    {train_metrics['recall']:.4f}")
        print(f"      - F1-Score:  {train_metrics['f1_score']:.4f}")
        print(f"\n   Val Loss:   {val_loss:.4f}")
        print(f"   Val Metrics:")
        print(f"      - Accuracy:  {val_metrics['accuracy']:.4f}")
        print(f"      - Precision: {val_metrics['precision']:.4f}")
        print(f"      - Recall:    {val_metrics['recall']:.4f}")
        print(f"      - F1-Score:  {val_metrics['f1_score']:.4f}")
        
        # Save checkpoint periodically
        if (epoch + 1) % SAVE_EVERY == 0:
            checkpoint_path = os.path.join(CHECKPOINT_DIR, f"checkpoint_epoch_{epoch+1}.pth")
            save_checkpoint(model, optimizer, epoch, train_loss, val_loss, checkpoint_path)
        
        # Save best model based on validation F1-score
        if val_metrics['f1_score'] > best_val_f1:
            best_val_f1 = val_metrics['f1_score']
            best_val_loss = val_loss
            patience_counter = 0
            
            # Save with metrics
            metrics_dict = {
                'train_accuracy': train_metrics['accuracy'],
                'train_precision': train_metrics['precision'],
                'train_recall': train_metrics['recall'],
                'train_f1_score': train_metrics['f1_score'],
                'val_accuracy': val_metrics['accuracy'],
                'val_precision': val_metrics['precision'],
                'val_recall': val_metrics['recall'],
                'val_f1_score': val_metrics['f1_score']
            }
            save_best_model(model, tokenizer, epoch + 1, val_loss, metrics=metrics_dict)
            print(f"   üèÜ New best model saved! (F1: {best_val_f1:.4f})")
        else:
            patience_counter += 1
            print(f"   No improvement ({patience_counter}/{EARLY_STOP_PATIENCE})")
        
        # Early stopping
        if patience_counter >= EARLY_STOP_PATIENCE:
            print(f"\n‚ö†Ô∏è  Early stopping triggered after {epoch + 1} epochs")
            print(f"   Best val F1-score: {best_val_f1:.4f}")
            print(f"   Best val loss: {best_val_loss:.4f}")
            break
        
        # Save latest checkpoint for resuming
        latest_checkpoint = os.path.join(CHECKPOINT_DIR, "latest_checkpoint.pth")
        save_checkpoint(model, optimizer, epoch, train_loss, val_loss, latest_checkpoint)
    
    print("\n" + "="*70)
    print("‚úÖ Training completed!")
    print(f"üèÜ Best validation F1-score: {best_val_f1:.4f}")
    print(f"üìÅ Best model saved in: {BEST_MODEL_DIR}")
    print("="*70)
    
    return model

# Example usage (commented out - uncomment to run):
# trained_model = run_training(train_dataset, val_dataset)
# 
# To resume training from checkpoint:
# trained_model = run_training(train_dataset, val_dataset, 
#                              resume_from="./models/checkpoints/latest_checkpoint.pth")

In [None]:
# ============================================================================
# EVALUATION FUNCTION FOR TEST SET
# ============================================================================

def evaluate_model(model, test_dataset, device=DEVICE, batch_size=8):
    """
    Comprehensive evaluation of the model on test set
    
    Args:
        model: Trained MultiEAOSModel
        test_dataset: Test dataset
        device: torch device
        batch_size: Batch size for evaluation
    
    Returns:
        dict: Evaluation metrics
    """
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    loss_fn = MultiEAOSLoss().to(device)
    
    model.eval()
    total_loss = 0
    all_metrics = {'accuracy': 0, 'precision': 0, 'recall': 0, 'f1_score': 0}
    
    print("üîç Evaluating model on test set...")
    progress_bar = tqdm(test_loader, desc="Evaluation")
    
    with torch.no_grad():
        for batch in progress_bar:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            targets = batch['targets'].to(device)
            
            outputs = model(input_ids, attention_mask)
            loss = loss_fn(outputs, targets)
            
            # Calculate metrics
            metrics = calculate_metrics(outputs, targets)
            for key in all_metrics:
                all_metrics[key] += metrics[key]
            
            total_loss += loss.item()
            progress_bar.set_postfix({
                'loss': f"{loss.item():.4f}",
                'f1': f"{metrics['f1_score']:.4f}"
            })
    
    avg_loss = total_loss / len(test_loader)
    avg_metrics = {key: val / len(test_loader) for key, val in all_metrics.items()}
    
    # Print detailed results
    print("\n" + "="*70)
    print("üìä EVALUATION RESULTS")
    print("="*70)
    print(f"Test Loss:      {avg_loss:.4f}")
    print(f"Accuracy:       {avg_metrics['accuracy']:.4f} ({avg_metrics['accuracy']*100:.2f}%)")
    print(f"Precision:      {avg_metrics['precision']:.4f} ({avg_metrics['precision']*100:.2f}%)")
    print(f"Recall:         {avg_metrics['recall']:.4f} ({avg_metrics['recall']*100:.2f}%)")
    print(f"F1-Score:       {avg_metrics['f1_score']:.4f} ({avg_metrics['f1_score']*100:.2f}%)")
    print("="*70)
    
    return {
        'test_loss': avg_loss,
        **avg_metrics
    }

# ============================================================================
# TRAINING EXECUTION
# ============================================================================

# Run training with train/validation split
trained_model = run_training(train_dataset, val_dataset)

# Evaluate on validation set (can be used as test set if you don't have separate test data)
print("\n\n" + "üéØ FINAL EVALUATION ON VALIDATION SET " + "\n")
final_metrics = evaluate_model(trained_model, val_dataset)

# To resume training from a checkpoint, uncomment this line:
# trained_model = run_training(train_dataset, val_dataset, 
#                              resume_from="./models/checkpoints/latest_checkpoint.pth")

In [None]:
import torch

# C√°c map ng∆∞·ª£c ƒë·ªÉ chuy·ªÉn s·ªë th√†nh ch·ªØ
ID2ASPECT = {v: k for k, v in ASPECT_MAP.items() if v != -1}
ID2SENTIMENT = {v: k for k, v in SENTIMENT_MAP.items() if v != -1}

def decode_prediction(model, tokenizer, text, device, max_len=256):
    model.eval()

    # 1. Ti·ªÅn x·ª≠ l√Ω input (Tokenize)
    inputs = tokenizer(
        text,
        padding='max_length',
        truncation=True,
        max_length=max_len,
        return_tensors="pt"
    )
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)

    # 2. Ch·∫°y m√¥ h√¨nh
    with torch.no_grad():
        outputs = model(input_ids, attention_mask)

    # 3. Gi·∫£i m√£ k·∫øt qu·∫£ (Decoding)
    results = []

    # L·∫•y ra c√°c index c√≥ x√°c su·∫•t cao nh·∫•t (Argmax)
    # Shape: [1, 5] (Batch=1, Quads=5)
    pred_e_start = torch.argmax(outputs['e_start'], dim=-1)[0]
    pred_e_end   = torch.argmax(outputs['e_end'], dim=-1)[0]
    pred_o_start = torch.argmax(outputs['o_start'], dim=-1)[0]
    pred_o_end   = torch.argmax(outputs['o_end'], dim=-1)[0]
    pred_aspect  = torch.argmax(outputs['aspect'], dim=-1)[0]
    pred_sent    = torch.argmax(outputs['sentiment'], dim=-1)[0]

    # L·∫•y token g·ªëc ƒë·ªÉ decode text
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])

    # Duy·ªát qua t·ª´ng slot trong 5 slots d·ª± ƒëo√°n
    for i in range(len(pred_e_start)):
        e_s, e_e = pred_e_start[i].item(), pred_e_end[i].item()
        o_s, o_e = pred_o_start[i].item(), pred_o_end[i].item()

        # --- B·ªò L·ªåC R√ÅC (HEURISTICS) ---
        # Lo·∫°i b·ªè n·∫øu Start > End ho·∫∑c ch·ªâ tr·ªè v√†o token ƒë·∫∑c bi·ªát (CLS/SEP/PAD)
        # Tokenizer PhoBERT: 0=<s>, 2=</s>, 1=<pad>
        if e_s > e_e or o_s > o_e: continue
        if e_s == 0 or e_e == 0: continue # B·ªè qua n·∫øu tr·ªè v√†o [CLS]
        if e_s >= len(tokens) or o_s >= len(tokens): continue

        # Decode text t·ª´ token index
        # convert_tokens_to_string s·∫Ω n·ªëi l·∫°i c√°c t·ª´ v√† x·ª≠ l√Ω d·∫•u '_'
        entity_tokens = tokens[e_s : e_e + 1]
        opinion_tokens = tokens[o_s : o_e + 1]

        entity_text = tokenizer.convert_tokens_to_string(entity_tokens).replace('_', ' ')
        opinion_text = tokenizer.convert_tokens_to_string(opinion_tokens).replace('_', ' ')

        # L·∫•y nh√£n ph√¢n lo·∫°i
        aspect_label = ID2ASPECT.get(pred_aspect[i].item(), "Kh√°c")
        sentiment_label = ID2SENTIMENT.get(pred_sent[i].item(), "Trung t√≠nh")

        # Ch·ªâ l·∫•y k·∫øt qu·∫£ n·∫øu text kh√¥ng r·ªóng
        if entity_text.strip() and opinion_text.strip():
            results.append({
                "entity": entity_text,
                "aspect": aspect_label,
                "opinion": opinion_text,
                "sentiment": sentiment_label
            })

    return results

# --- CH·∫†Y TH·ª¨ NGHI·ªÜM ---
# sample_text = "Ch∆∞∆°ng tr√¨nh m√πa n√†y ch√°n qu√°, MC d·∫´n nh·∫°t nh·∫Ωo"
# preds = decode_prediction(trained_model, tokenizer, sample_text, DEVICE)
# print("K·∫øt qu·∫£ d·ª± ƒëo√°n:", preds)

In [None]:
# ============================================================================
# BACKEND DEPLOYMENT UTILITY
# ============================================================================

class EAOSInference:
    """
    Production-ready inference class for backend deployment
    This class can be imported and used in your FastAPI/Flask backend
    """
    
    def __init__(self, model_dir="./models/best_model", device=None):
        """
        Initialize the inference model
        
        Args:
            model_dir: Directory containing model.pth and config.json
            device: torch device (auto-detected if None)
        """
        self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model_dir = model_dir
        
        # Load configuration
        config_path = os.path.join(model_dir, "config.json")
        with open(config_path, 'r', encoding='utf-8') as f:
            self.config = json.load(f)
        
        # Create reverse mappings
        self.id2aspect = {v: k for k, v in self.config['aspect_map'].items()}
        self.id2sentiment = {v: k for k, v in self.config['sentiment_map'].items()}
        
        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(self.config['model_name'])
        
        # Initialize and load model
        self.model = MultiEAOSModel(
            model_name=self.config['model_name'],
            num_aspects=self.config['num_aspects'],
            num_sentiments=self.config['num_sentiments'],
            max_len=self.config['max_len'],
            max_quads=self.config['max_quads'],
            hidden_dim=self.config['hidden_dim']
        ).to(self.device)
        
        # Load weights
        model_path = os.path.join(model_dir, "model.pth")
        self.model.load_state_dict(torch.load(model_path, map_location=self.device, weights_only=True))
        self.model.eval()
        
        print(f"‚úÖ Model loaded successfully")
        print(f"   Device: {self.device}")
        print(f"   Model from epoch: {self.config['best_epoch']}")
        print(f"   Best val_loss: {self.config['best_val_loss']:.4f}")
    
    def predict(self, text, confidence_threshold=0.5):
        """
        Predict EAOS quadruples from input text
        
        Args:
            text: Input Vietnamese text
            confidence_threshold: Minimum confidence score (0-1)
        
        Returns:
            List of dictionaries containing predictions
        """
        # Tokenize
        inputs = self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=self.config['max_len'],
            return_tensors="pt"
        )
        input_ids = inputs['input_ids'].to(self.device)
        attention_mask = inputs['attention_mask'].to(self.device)
        
        # Run inference
        with torch.no_grad():
            outputs = self.model(input_ids, attention_mask)
        
        # Decode predictions
        results = []
        tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0])
        
        # Get predictions with argmax
        pred_e_start = torch.argmax(outputs['e_start'], dim=-1)[0]
        pred_e_end = torch.argmax(outputs['e_end'], dim=-1)[0]
        pred_o_start = torch.argmax(outputs['o_start'], dim=-1)[0]
        pred_o_end = torch.argmax(outputs['o_end'], dim=-1)[0]
        pred_aspect = torch.argmax(outputs['aspect'], dim=-1)[0]
        pred_sent = torch.argmax(outputs['sentiment'], dim=-1)[0]
        
        # Get confidence scores (max softmax probability)
        aspect_probs = torch.softmax(outputs['aspect'], dim=-1)[0]
        sent_probs = torch.softmax(outputs['sentiment'], dim=-1)[0]
        
        for i in range(len(pred_e_start)):
            e_s, e_e = pred_e_start[i].item(), pred_e_end[i].item()
            o_s, o_e = pred_o_start[i].item(), pred_o_end[i].item()
            
            # Filter invalid predictions
            if e_s > e_e or o_s > o_e:
                continue
            if e_s == 0 or e_e == 0:
                continue
            if e_s >= len(tokens) or o_s >= len(tokens):
                continue
            
            # Get confidence scores
            aspect_conf = aspect_probs[i][pred_aspect[i]].item()
            sent_conf = sent_probs[i][pred_sent[i]].item()
            avg_confidence = (aspect_conf + sent_conf) / 2
            
            # Apply confidence threshold
            if avg_confidence < confidence_threshold:
                continue
            
            # Decode text
            entity_text = self.tokenizer.convert_tokens_to_string(
                tokens[e_s:e_e+1]
            ).replace('_', ' ').strip()
            
            opinion_text = self.tokenizer.convert_tokens_to_string(
                tokens[o_s:o_e+1]
            ).replace('_', ' ').strip()
            
            # Get labels
            aspect_label = self.id2aspect.get(pred_aspect[i].item(), "Kh√°c")
            sentiment_label = self.id2sentiment.get(pred_sent[i].item(), "Trung t√≠nh")
            
            if entity_text and opinion_text:
                results.append({
                    "entity": entity_text,
                    "aspect": aspect_label,
                    "opinion": opinion_text,
                    "sentiment": sentiment_label,
                    "confidence": round(avg_confidence, 3)
                })
        
        return results
    
    def predict_batch(self, texts, confidence_threshold=0.5):
        """
        Predict for multiple texts
        
        Args:
            texts: List of input texts
            confidence_threshold: Minimum confidence score
        
        Returns:
            List of prediction lists
        """
        return [self.predict(text, confidence_threshold) for text in texts]

# Example usage for backend deployment:
# inferencer = EAOSInference(model_dir="./models/best_model")
# result = inferencer.predict("Ch∆∞∆°ng tr√¨nh r·∫•t hay, MC d·∫´n t·ªët")
# print(result)

In [None]:
# Test the inference class (after training)
# Uncomment to test:

# inferencer = EAOSInference(model_dir="./models/best_model")
# 
# sample_text = "t√¥i th·∫•y m√πa 2 kh√¥ng hay b·∫±ng m√πa 1 v√¨ m√πa 1 c√≥ tr·∫•n th√†nh m√πa 2 l·∫°i kh√¥ng c√≥"
# predictions = inferencer.predict(sample_text, confidence_threshold=0.3)
# 
# print("Input:", sample_text)
# print("\nPredictions:")
# for i, pred in enumerate(predictions, 1):
#     print(f"{i}. Entity: {pred['entity']}")
#     print(f"   Aspect: {pred['aspect']}")
#     print(f"   Opinion: {pred['opinion']}")
#     print(f"   Sentiment: {pred['sentiment']}")
#     print(f"   Confidence: {pred['confidence']}")
#     print()

# üì¶ Backend Integration Guide

## How to Use This Model in Your Backend

After training, you'll have a `models/best_model/` folder containing:
- `model.pth` - The trained model weights
- `config.json` - Model configuration and label mappings

### Option 1: FastAPI Example

```python
# backend/main.py
from fastapi import FastAPI
from pydantic import BaseModel
import sys
sys.path.append('../Stage2')  # Add notebook path
from multi_eaos import EAOSInference  # Import the inference class

app = FastAPI()
inferencer = EAOSInference(model_dir="../Stage2/models/best_model")

class TextInput(BaseModel):
    text: str
    confidence_threshold: float = 0.5

@app.post("/predict")
def predict_eaos(input_data: TextInput):
    predictions = inferencer.predict(
        input_data.text, 
        confidence_threshold=input_data.confidence_threshold
    )
    return {
        "input": input_data.text,
        "predictions": predictions,
        "count": len(predictions)
    }
```

### Option 2: Export as Standalone Python Module

Create a file `backend/eaos_model.py`:
1. Copy the `MultiEAOSModel` class
2. Copy the `EAOSInference` class
3. Import and use in your backend

```python
# backend/eaos_model.py
# [Copy MultiEAOSModel and EAOSInference classes here]

# backend/api.py
from eaos_model import EAOSInference

model = EAOSInference("../models/best_model")
result = model.predict("Ch∆∞∆°ng tr√¨nh hay qu√°!")
```

### Model Loading in Production

```python
import torch
from eaos_model import EAOSInference

# Load once at startup (not per request!)
model = EAOSInference(
    model_dir="./models/best_model",
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
)

# Use for predictions
def analyze_text(text: str):
    return model.predict(text, confidence_threshold=0.5)
```