In [None]:
# !pip install transformers
# import os
# import json
# import torch
# import torch.nn as nn
# import torch.optim as optim
# from torch.utils.data import Dataset, DataLoader
# from torch.optim import XLMRobertaModel, XLMRobertaTokenizerFast, AdamW, get_linear_schedule_with_warmup
# import numpy as np
# from tqdm import tqdm
# import logging
# from sklearn.metrics import precision_recall_fscore_support

!pip install transformers

import os
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import AdamW  # AdamW comes from torch.optim
from torch.utils.data import Dataset, DataLoader
from transformers import XLMRobertaModel, XLMRobertaTokenizerFast, get_linear_schedule_with_warmup

import numpy as np
from tqdm import tqdm
import logging
from sklearn.metrics import precision_recall_fscore_support

# Module-level debug switch; it is not read anywhere in the visible code.
DEBUG = False

# Set up logging: INFO level, each record prefixed with timestamp and level.
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)

# Define constants (hyper-parameters and label-space sizes)
MAX_SEQ_LENGTH = 512  # default padding/truncation length used by SentimentDataset
NUM_EPOCHS = 5  # default number of training epochs used by main()
BATCH_SIZE = 8  # default DataLoader batch size used by main()
LEARNING_RATE = 2e-5  # AdamW learning rate used by main()
WARMUP_STEPS = 0  # num_warmup_steps handed to the linear schedule in main()
ADAPTER_SIZE = 128  # default bottleneck width of LanguageAdapter
NUM_LABELS_SPAN = 3  # 0:O, 1:B, 2:I
NUM_LABELS_POLARITY = 4  # Positive, Negative, Neutral, None
NUM_LABELS_INTENSITY = 3  # Strong, Average, Weak
NUM_LABELS_RELATION = 2  # Related, Not Related
SPAN_EMBEDDING_DIM = 768  # not referenced in the visible code
RELATION_EMBEDDING_DIM = 256  # output width of CrossSpanAttention; input width of the heads

# Custom collate function for DataLoader
def custom_collate(batch):
    """Merge dataset items into one batch dictionary.

    Items that are ``None`` are discarded first.  The five fixed-size tensor
    fields are stacked along a new leading dimension (only those present in
    the first remaining item); every other field of the first item is
    gathered into a plain Python list, one entry per item.

    Raises:
        IndexError: if no item remains after the ``None`` entries are removed.
    """
    tensor_fields = ['input_ids', 'attention_mask', 'holder_labels', 'target_labels', 'expression_labels']
    samples = [sample for sample in batch if sample is not None]
    reference = samples[0]

    # Stacked tensor fields come first, in the fixed order above.
    merged = {
        field: torch.stack([sample[field] for sample in samples])
        for field in tensor_fields
        if field in reference
    }
    # Variable-length / non-tensor fields stay as per-example lists.
    for field in reference:
        if field not in tensor_fields:
            merged[field] = [sample[field] for sample in samples]
    return merged

# Dataset class
class SentimentDataset(Dataset):
    """Tokenized structured-sentiment examples read from a JSON or JSON-lines file.

    Every usable entry becomes a dict with the encoded text, one BIO label
    tensor per span role (holder, target, expression) and the list of gold
    opinions carrying token-level spans plus polarity and intensity ids.
    """

    # Span returned by the span helpers when a role has no usable annotation.
    # ``_get_token_span`` starts from index 1 and skips tokens whose offset is
    # (0, 0), so it never produces this value for a real span.
    # NOTE(review): this relies on the tokenizer placing a special token with a
    # (0, 0) offset at index 0 (presumably true for the XLM-R tokenizer that
    # main() passes in) -- confirm if another tokenizer is used.
    _NULL_SPAN = (0, 0)

    def __init__(self, file_path, tokenizer, max_length=MAX_SEQ_LENGTH):
        """Load ``file_path`` and pre-process every dict entry.

        The file is parsed as a single JSON document when its first line
        starts with ``[``; otherwise each non-blank line is parsed as its own
        JSON value.  Entries that are not dicts, or that fail processing, are
        left out of ``self.examples``.
        """
        self.max_length = max_length
        self.tokenizer = tokenizer
        self.examples = []
        with open(file_path, 'r', encoding='utf-8') as f:
            first_line = f.readline().strip()
            f.seek(0)  # rewind: the first line was only peeked at
            if first_line.startswith('['):
                data = json.load(f)
            else:
                data = [json.loads(line) for line in f if line.strip()]
        logger.info(f"Loaded {len(data)} examples from {file_path}")
        for entry in data:
            if isinstance(entry, dict):
                processed = self.process_example(entry)
                if processed:
                    self.examples.append(processed)

    def process_example(self, entry):
        """Convert one raw entry into model-ready tensors and gold opinions.

        Reads the ``text``, ``sent_id`` and ``opinions`` keys of ``entry``.
        Returns ``None`` when the text is empty or when anything raises while
        processing (the error is logged).
        """
        try:
            text = entry.get('text', '')
            sent_id = entry.get('sent_id', '')
            opinions = entry.get('opinions', [])
            if not text:
                logger.warning(f"Empty text for entry with sent_id {sent_id}")
                return None

            tokenized = self.tokenizer(
                text,
                max_length=self.max_length,
                padding='max_length',
                truncation=True,
                return_offsets_mapping=True,
                return_tensors='pt'
            )
            offset_mapping = tokenized['offset_mapping'][0]
            input_ids = tokenized['input_ids'][0]
            attention_mask = tokenized['attention_mask'][0]

            holder_labels = torch.zeros(self.max_length, dtype=torch.long)
            target_labels = torch.zeros(self.max_length, dtype=torch.long)
            expression_labels = torch.zeros(self.max_length, dtype=torch.long)

            opinion_data = []
            for opinion in opinions:
                if not isinstance(opinion, dict):
                    continue
                holder_span = self._extract_span_safely(opinion, 'Source', offset_mapping)
                target_span = self._extract_span_safely(opinion, 'Target', offset_mapping)
                expression_span = self._extract_span_safely(opinion, 'Polar_expression', offset_mapping)
                polarity = self._get_polarity_label(opinion.get('Polarity', 'None'))
                intensity = self._get_intensity_label(opinion.get('Intensity', 'Average'))

                role_spans = (
                    (holder_labels, holder_span),
                    (target_labels, target_span),
                    (expression_labels, expression_span),
                )
                if any(span is None for _, span in role_spans):
                    continue

                # The opinion is kept even when a role is missing; the null
                # span then stands in for that role.
                opinion_data.append({
                    'holder_span': holder_span,
                    'target_span': target_span,
                    'expression_span': expression_span,
                    'polarity': polarity,
                    'intensity': intensity,
                })
                for labels, span in role_spans:
                    # A missing role must not tag position 0 as a span start,
                    # otherwise the first token is labelled "B" for every
                    # opinion that lacks a holder/target/expression.
                    if tuple(span) != self._NULL_SPAN:
                        self._mark_span(labels, span[0], span[1])

            return {
                'sent_id': sent_id,
                'text': text,
                'input_ids': input_ids,
                'attention_mask': attention_mask,
                'holder_labels': holder_labels,
                'target_labels': target_labels,
                'expression_labels': expression_labels,
                'opinion_data': opinion_data,
                'num_opinions': len(opinion_data)
            }
        except Exception as e:
            # Best effort: one malformed entry must not abort the whole load.
            logger.error(f"Error processing example: {e}")
            return None

    def _extract_span_safely(self, opinion, key, offset_mapping):
        """Return the token span for ``opinion[key]`` or the null span ``(0, 0)``.

        Accepted shapes of the annotation value:
          * a list with more than one element -- the second element is used
            (its first item when that element is a non-empty list);
          * a one-element list -- the element itself, or its second item when
            the element is a list with more than one item;
          * a plain string.
        The selected value must contain ``":"``; anything else, including any
        exception raised on the way, yields ``(0, 0)``.
        """
        span_data = opinion.get(key)
        if not span_data:
            return (0, 0)
        try:
            span_str = None
            if isinstance(span_data, list):
                if len(span_data) > 1:
                    candidate = span_data[1]
                    span_str = candidate[0] if isinstance(candidate, list) and candidate else candidate
                else:
                    candidate = span_data[0]
                    span_str = candidate[1] if isinstance(candidate, list) and len(candidate) > 1 else candidate
            elif isinstance(span_data, str):
                span_str = span_data
            else:
                return (0, 0)
            if not span_str or ":" not in span_str:
                return (0, 0)
            return self._get_token_span(span_str, offset_mapping)
        except Exception:
            return (0, 0)

    def _get_token_span(self, span_str, offset_mapping):
        """Map a ``"start:end"`` character span to ``(start_token, end_token)``.

        ``start_token`` is the last token whose offsets contain ``start``;
        ``end_token`` is the last token that begins before ``end``.  Both
        default to 1 when no token matches.  Tokens with a (0, 0) offset are
        ignored.  Returns ``(0, 0)`` if parsing or iteration raises.
        """
        try:
            start, end = map(int, span_str.split(':'))
            start_token = end_token = 1
            for idx, (ts, te) in enumerate(offset_mapping):
                ts, te = int(ts), int(te)
                if ts == 0 and te == 0:
                    continue
                if ts <= start < te:
                    start_token = idx
                if ts < end:
                    end_token = idx
            return (start_token, end_token)
        except Exception:
            return (0, 0)

    def _mark_span(self, labels, start_idx, end_idx):
        """Write B (1) at ``start_idx`` and I (2) up to and including ``end_idx``."""
        try:
            labels[start_idx] = 1  # B
            if end_idx > start_idx:
                labels[start_idx+1:end_idx+1] = 2  # I
        except Exception as e:
            logger.warning(f"Error marking span {start_idx}:{end_idx}: {e}")

    def _get_polarity_label(self, polarity):
        """Return the polarity class id; unknown values map to 3 ('None')."""
        polarity_map = {'Positive': 0, 'Negative': 1, 'Neutral': 2, 'None': 3}
        return polarity_map.get(polarity, 3)

    def _get_intensity_label(self, intensity):
        """Return the intensity class id; unknown values map to 1 ('Average')."""
        intensity_map = {'Strong': 0, 'Average': 1, 'Weak': 2}
        return intensity_map.get(intensity, 1)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]

# Neural network modules
class SelfAttentionLayer(nn.Module):
    """Multi-head scaled dot-product self-attention with residual + LayerNorm.

    ``forward`` takes ``x`` of shape (batch, seq_len, input_dim) and an
    optional ``mask`` of shape (batch, seq_len) in which 0 marks key
    positions that must receive no attention.
    """

    def __init__(self, input_dim, num_heads=8, head_dim=96):
        super().__init__()
        inner_dim = num_heads * head_dim
        self.input_dim = input_dim
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.query = nn.Linear(input_dim, inner_dim)
        self.key = nn.Linear(input_dim, inner_dim)
        self.value = nn.Linear(input_dim, inner_dim)
        self.output_projection = nn.Linear(inner_dim, input_dim)
        self.layer_norm = nn.LayerNorm(input_dim)

    def _split_heads(self, projected):
        """Reshape (batch, seq, heads * head_dim) into (batch, heads, seq, head_dim)."""
        batch_size, seq_len, _ = projected.size()
        return projected.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.size()
        queries = self._split_heads(self.query(x))
        keys = self._split_heads(self.key(x))
        values = self._split_heads(self.value(x))

        scores = torch.matmul(queries, keys.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if mask is not None:
            # Broadcast the key mask over heads and query positions.
            key_mask = mask.unsqueeze(1).unsqueeze(2)
            scores = scores.masked_fill(key_mask == 0, -1e9)
        attn_weights = torch.softmax(scores, dim=-1)

        # Merge the heads back into a single feature axis.
        context = torch.matmul(attn_weights, values).transpose(1, 2).contiguous()
        context = context.view(batch_size, seq_len, self.num_heads * self.head_dim)
        projected = self.output_projection(context)
        return self.layer_norm(projected + x)

class SpanDetector(nn.Module):
    """Token-level tagger: Linear -> ReLU -> Linear producing per-token label logits."""

    def __init__(self, input_dim, hidden_dim=256, num_labels=NUM_LABELS_SPAN):
        super().__init__()
        self.hidden = nn.Linear(input_dim, hidden_dim)
        self.activation = nn.ReLU()
        self.classifier = nn.Linear(hidden_dim, num_labels)

    def forward(self, x):
        # Project, apply the non-linearity, then score each label.
        features = self.activation(self.hidden(x))
        return self.classifier(features)

class CrossSpanAttention(nn.Module):
    """Let span embeddings attend to each other, then project and normalise them.

    ``forward`` takes ``spans`` of shape (batch, num_spans, input_dim) and an
    optional boolean ``span_masks`` of shape (batch, num_spans) where True
    marks a real span.  It returns (batch, num_spans, output_dim).
    """

    def __init__(self, input_dim, output_dim=RELATION_EMBEDDING_DIM):
        super(CrossSpanAttention, self).__init__()
        self.attention = nn.MultiheadAttention(input_dim, num_heads=4, batch_first=True)
        self.projection = nn.Linear(input_dim, output_dim)
        self.layer_norm = nn.LayerNorm(output_dim)

    def forward(self, spans, span_masks=None):
        # Nothing to attend over: return an empty result of the right width
        # instead of feeding a zero-length sequence to the attention module.
        if spans.size(1) == 0:
            return spans.new_zeros((spans.size(0), 0, self.projection.out_features))

        key_padding_mask = None
        if span_masks is not None:
            key_padding_mask = ~span_masks
            # A row in which every key is padding makes the attention softmax
            # run over an all-masked row, which yields NaN.  Expose the first
            # (padding) slot of such rows so the output stays finite; callers
            # still know through span_masks that those rows hold no real span.
            fully_padded = key_padding_mask.all(dim=1)
            key_padding_mask[:, 0] = key_padding_mask[:, 0] & ~fully_padded

        context, _ = self.attention(spans, spans, spans, key_padding_mask=key_padding_mask)
        output = self.projection(context)
        return self.layer_norm(output)

class RelationClassifier(nn.Module):
    """Score span pairs: Linear -> ReLU -> Dropout -> Linear.

    The first layer expects inputs of width ``input_dim * 2``, i.e. two span
    embeddings concatenated along the feature axis.
    """

    def __init__(self, input_dim, hidden_dim=256, num_labels=NUM_LABELS_RELATION):
        super().__init__()
        self.hidden = nn.Linear(input_dim * 2, hidden_dim)
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(hidden_dim, num_labels)

    def forward(self, span_pairs):
        features = self.activation(self.hidden(span_pairs))
        return self.classifier(self.dropout(features))

class PolarityClassifier(nn.Module):
    """Polarity head: Linear -> ReLU -> Dropout -> Linear over span embeddings."""

    def __init__(self, input_dim, hidden_dim=128, num_labels=NUM_LABELS_POLARITY):
        super().__init__()
        self.hidden = nn.Linear(input_dim, hidden_dim)
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(hidden_dim, num_labels)

    def forward(self, x):
        features = self.activation(self.hidden(x))
        return self.classifier(self.dropout(features))

class IntensityClassifier(nn.Module):
    """Intensity head: Linear -> ReLU -> Dropout -> Linear over span embeddings."""

    def __init__(self, input_dim, hidden_dim=128, num_labels=NUM_LABELS_INTENSITY):
        super().__init__()
        self.hidden = nn.Linear(input_dim, hidden_dim)
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(hidden_dim, num_labels)

    def forward(self, x):
        features = self.activation(self.hidden(x))
        return self.classifier(self.dropout(features))

class LanguageAdapter(nn.Module):
    """Bottleneck adapter: down-project, ReLU, up-project, residual add, LayerNorm."""

    def __init__(self, input_dim, bottleneck_dim=ADAPTER_SIZE):
        super().__init__()
        self.down_project = nn.Linear(input_dim, bottleneck_dim)
        self.activation = nn.ReLU()
        self.up_project = nn.Linear(bottleneck_dim, input_dim)
        self.layer_norm = nn.LayerNorm(input_dim)

    def forward(self, x):
        bottleneck = self.activation(self.down_project(x))
        # The input is added back before normalisation (residual connection).
        return self.layer_norm(self.up_project(bottleneck) + x)

# Main model class
class StructuredSentimentModel(nn.Module):
    """XLM-R encoder with span taggers and relation / polarity / intensity heads.

    With ``labels`` the forward pass returns a training loss built from the
    three BIO taggers plus polarity and intensity losses computed on the gold
    expression spans.  Without ``labels`` it decodes spans from the tagger
    logits and scores candidate pairs and expression spans.
    """

    def __init__(self, pretrained_model_name="xlm-roberta-base", use_adapters=False, num_languages=8):
        super().__init__()
        self.encoder = XLMRobertaModel.from_pretrained(pretrained_model_name)
        self.hidden_size = self.encoder.config.hidden_size
        self.span_attention = SelfAttentionLayer(self.hidden_size)
        self.holder_detector = SpanDetector(self.hidden_size)
        self.target_detector = SpanDetector(self.hidden_size)
        self.expression_detector = SpanDetector(self.hidden_size)
        self.cross_span_attention = CrossSpanAttention(self.hidden_size)
        self.relation_classifier = RelationClassifier(RELATION_EMBEDDING_DIM)
        self.polarity_classifier = PolarityClassifier(RELATION_EMBEDDING_DIM)
        self.intensity_classifier = IntensityClassifier(RELATION_EMBEDDING_DIM)
        self.use_adapters = use_adapters
        if use_adapters:
            self.language_adapters = nn.ModuleList([LanguageAdapter(self.hidden_size) for _ in range(num_languages)])
        self._init_weights()

    def _init_weights(self):
        """Xavier-initialise matrix weights and zero the biases of the task heads.

        The encoder and the optional language adapters are not touched.
        """
        modules = [self.span_attention, self.holder_detector, self.target_detector,
                   self.expression_detector, self.cross_span_attention,
                   self.relation_classifier, self.polarity_classifier, self.intensity_classifier]
        for module in modules:
            for name, param in module.named_parameters():
                if 'weight' in name and len(param.shape) >= 2:
                    nn.init.xavier_uniform_(param)
                elif 'bias' in name:
                    nn.init.zeros_(param)

    def extract_spans(self, span_logits, attention_mask):
        """Decode BIO logits into inclusive ``(start, end)`` token spans.

        Args:
            span_logits: (batch, seq_len, num_labels) tagger scores.
            attention_mask: (batch, seq_len); only positions with a non-zero
                mask are decoded, and span indices refer to that filtered
                sequence.

        Returns:
            One list of spans per example.  A span opens at a B (1) label and
            closes just before the next B or O (0) label, or at the last
            decoded position.  I (2) labels without an open span are ignored.
        """
        batch_size = span_logits.size(0)
        # Softmax is monotonic, so the argmax of the raw logits is enough.
        span_preds = torch.argmax(span_logits, dim=-1)
        all_spans = []
        for i in range(batch_size):
            # One transfer per example instead of one per token.
            preds = span_preds[i][attention_mask[i].bool()].tolist()
            spans = []
            start_idx = None
            for j, label in enumerate(preds):
                if label == 1:  # B
                    if start_idx is not None:
                        spans.append((start_idx, j - 1))
                    start_idx = j
                elif label == 0:  # O
                    if start_idx is not None:
                        spans.append((start_idx, j - 1))
                        start_idx = None
            if start_idx is not None:
                spans.append((start_idx, len(preds) - 1))
            all_spans.append(spans)
        return all_spans

    def get_span_embeddings(self, hidden_states, spans, attention_mask):
        """Mean-pool the hidden states of every span.

        Args:
            hidden_states: (batch, seq_len, hidden_size).
            spans: per-example lists of inclusive ``(start, end)`` indices.
            attention_mask: accepted for interface compatibility; not used.

        Returns:
            ``(span_embeddings, span_masks)`` of shapes
            (batch, max_spans, hidden_size) and (batch, max_spans); the mask
            is True where a real span was written, padding slots stay zero.
        """
        batch_size = hidden_states.size(0)
        device = hidden_states.device
        max_spans = max((len(s) for s in spans), default=0)
        if max_spans == 0:
            return (torch.zeros((batch_size, 0, self.hidden_size), device=device),
                    torch.zeros((batch_size, 0), dtype=torch.bool, device=device))
        span_embeddings = torch.zeros((batch_size, max_spans, self.hidden_size), device=device)
        span_masks = torch.zeros((batch_size, max_spans), dtype=torch.bool, device=device)
        for i in range(batch_size):
            for j, (start, end) in enumerate(spans[i]):
                span_embeddings[i, j] = hidden_states[i, start:end+1].mean(dim=0)
                span_masks[i, j] = True
        return span_embeddings, span_masks

    def _combine_spans(self, holder_emb, holder_mask, target_emb, target_mask, expr_emb, expr_mask):
        """Concatenate holder, target and expression spans along the span axis.

        The result keeps that order, so holder slots come first, then target
        slots, then expression slots.
        """
        combined_emb = torch.cat([holder_emb, target_emb, expr_emb], dim=1)
        combined_mask = torch.cat([holder_mask, target_mask, expr_mask], dim=1)
        return combined_emb, combined_mask

    def _create_span_pairs(self, span_embeddings, holder_mask, target_mask, expr_mask):
        """Build (holder, expression) and (target, expression) candidate pairs.

        Args:
            span_embeddings: (batch, holders + targets + expressions, dim) in
                the slot order produced by ``_combine_spans``.
            holder_mask, target_mask, expr_mask: boolean (batch, n) masks of
                the real spans of each role.

        Returns:
            ``(pair_embeddings, pair_indices)`` of shapes
            (batch, max_pairs, 2 * dim) and (batch, max_pairs, 2), or
            ``(None, None)`` when no example has any pair.  Rows with fewer
            pairs than ``max_pairs`` are zero-padded.
        """
        batch_size = span_embeddings.size(0)
        device = span_embeddings.device
        holder_size = holder_mask.size(1)
        target_size = target_mask.size(1)
        offset = holder_size + target_size

        # Move the masks to Python once instead of testing single elements.
        holder_rows = holder_mask.tolist()
        target_rows = target_mask.tolist()
        expr_rows = expr_mask.tolist()

        pairs_per_example = []
        for i in range(batch_size):
            left_slots = [h for h, valid in enumerate(holder_rows[i]) if valid]
            left_slots += [holder_size + t for t, valid in enumerate(target_rows[i]) if valid]
            expr_slots = [offset + e for e, valid in enumerate(expr_rows[i]) if valid]
            # Holder pairs first, then target pairs, each over all expressions.
            pairs_per_example.append([(left, right) for left in left_slots for right in expr_slots])

        max_pairs = max((len(pairs) for pairs in pairs_per_example), default=0)
        if max_pairs == 0:
            return None, None

        pair_dim = 2 * span_embeddings.size(-1)
        pair_embeddings = torch.zeros((batch_size, max_pairs, pair_dim), device=device)
        pair_indices = torch.zeros((batch_size, max_pairs, 2), dtype=torch.long, device=device)
        for i, pairs in enumerate(pairs_per_example):
            if not pairs:
                continue
            index = torch.tensor(pairs, dtype=torch.long, device=device)
            pair_indices[i, :len(pairs)] = index
            pair_embeddings[i, :len(pairs)] = torch.cat(
                [span_embeddings[i, index[:, 0]], span_embeddings[i, index[:, 1]]], dim=-1
            )
        return pair_embeddings, pair_indices

    def _compute_loss(self, span_aware_states, attention_mask, logits, labels):
        """Sum the three tagging losses and the gold-span polarity/intensity losses."""
        holder_logits, target_logits, expression_logits = logits
        holder_labels, target_labels, expression_labels, gold_opinions = labels
        loss_fct = nn.CrossEntropyLoss()

        # Only positions covered by the attention mask contribute.
        active = attention_mask.view(-1) == 1

        def tagging_loss(role_logits, role_labels):
            return loss_fct(role_logits.view(-1, NUM_LABELS_SPAN)[active], role_labels.view(-1)[active])

        span_loss = (tagging_loss(holder_logits, holder_labels)
                     + tagging_loss(target_logits, target_labels)
                     + tagging_loss(expression_logits, expression_labels))

        # Compute polarity and intensity losses using gold opinions
        all_expr_emb = []
        all_polarity_labels = []
        all_intensity_labels = []
        for i, opinions in enumerate(gold_opinions):
            if not opinions:
                continue
            states_i = span_aware_states[i:i+1]
            mask_i = attention_mask[i:i+1]
            gold_expr_spans = [op['expression_span'] for op in opinions]
            holder_emb, _ = self.get_span_embeddings(states_i, [[op['holder_span'] for op in opinions]], mask_i)
            target_emb, _ = self.get_span_embeddings(states_i, [[op['target_span'] for op in opinions]], mask_i)
            expr_emb, _ = self.get_span_embeddings(states_i, [gold_expr_spans], mask_i)
            # Every gold slot counts as a real span here.
            all_emb, all_mask = self._combine_spans(
                holder_emb, torch.ones_like(holder_emb[..., 0], dtype=torch.bool),
                target_emb, torch.ones_like(target_emb[..., 0], dtype=torch.bool),
                expr_emb, torch.ones_like(expr_emb[..., 0], dtype=torch.bool)
            )
            relation_aware_emb = self.cross_span_attention(all_emb, all_mask)
            # Expression slots are the trailing ones of the combined sequence.
            all_expr_emb.append(relation_aware_emb[0, -len(gold_expr_spans):, :])
            all_polarity_labels.extend(op['polarity'] for op in opinions)
            all_intensity_labels.extend(op['intensity'] for op in opinions)

        if not all_expr_emb:
            return span_loss

        expr_features = torch.cat(all_expr_emb, dim=0)
        device = span_aware_states.device
        gold_polarity = torch.tensor(all_polarity_labels, dtype=torch.long, device=device)
        gold_intensity = torch.tensor(all_intensity_labels, dtype=torch.long, device=device)
        polarity_loss = loss_fct(self.polarity_classifier(expr_features), gold_polarity)
        intensity_loss = loss_fct(self.intensity_classifier(expr_features), gold_intensity)
        return span_loss + polarity_loss + intensity_loss

    def forward(self, input_ids, attention_mask, language_id=None, labels=None):
        """Run the model in training (``labels`` given) or inference mode.

        Args:
            input_ids, attention_mask: (batch, seq_len) encoder inputs.
            language_id: optional per-example adapter index; only used when
                the model was built with ``use_adapters=True``.
            labels: optional tuple ``(holder_labels, target_labels,
                expression_labels, gold_opinions)`` where ``gold_opinions`` is
                a per-example list of opinion dicts.

        Returns:
            Training: dict with ``loss`` and the three tagger logits.
            Inference: dict with the tagger logits, decoded spans,
            ``pair_indices`` and the relation / polarity / intensity logits;
            the latter four are ``None`` when no candidate pair exists.
        """
        batch_size = input_ids.size(0)
        encoder_outputs = self.encoder(input_ids, attention_mask=attention_mask)
        hidden_states = encoder_outputs.last_hidden_state

        if self.use_adapters and language_id is not None:
            adapted_states = torch.zeros_like(hidden_states)
            for i in range(batch_size):
                adapted_states[i] = self.language_adapters[language_id[i].item()](hidden_states[i])
            hidden_states = adapted_states

        span_aware_states = self.span_attention(hidden_states, attention_mask)
        holder_logits = self.holder_detector(span_aware_states)
        target_logits = self.target_detector(span_aware_states)
        expression_logits = self.expression_detector(span_aware_states)

        if labels is not None:
            total_loss = self._compute_loss(
                span_aware_states, attention_mask,
                (holder_logits, target_logits, expression_logits), labels
            )
            return {
                'loss': total_loss,
                'holder_logits': holder_logits,
                'target_logits': target_logits,
                'expression_logits': expression_logits
            }

        # Inference mode
        holder_spans = self.extract_spans(holder_logits, attention_mask)
        target_spans = self.extract_spans(target_logits, attention_mask)
        expression_spans = self.extract_spans(expression_logits, attention_mask)
        holder_embeddings, holder_masks = self.get_span_embeddings(span_aware_states, holder_spans, attention_mask)
        target_embeddings, target_masks = self.get_span_embeddings(span_aware_states, target_spans, attention_mask)
        expression_embeddings, expression_masks = self.get_span_embeddings(span_aware_states, expression_spans, attention_mask)
        all_span_embeddings, all_span_masks = self._combine_spans(holder_embeddings, holder_masks, target_embeddings, target_masks, expression_embeddings, expression_masks)

        relation_aware_embeddings = None
        relation_pairs = None
        pair_indices = None
        # With no predicted span at all there is nothing to attend over; skip
        # the relation stage instead of running attention on an empty sequence.
        if all_span_embeddings.size(1) > 0:
            relation_aware_embeddings = self.cross_span_attention(all_span_embeddings, all_span_masks)
            relation_pairs, pair_indices = self._create_span_pairs(relation_aware_embeddings, holder_masks, target_masks, expression_masks)

        relation_logits = None
        polarity_logits = None
        intensity_logits = None
        if relation_pairs is not None:
            relation_logits = self.relation_classifier(relation_pairs)
            expression_relation_aware = relation_aware_embeddings[:, holder_embeddings.size(1) + target_embeddings.size(1):, :]
            polarity_logits = self.polarity_classifier(expression_relation_aware)
            intensity_logits = self.intensity_classifier(expression_relation_aware)

        return {
            'holder_logits': holder_logits,
            'target_logits': target_logits,
            'expression_logits': expression_logits,
            'relation_logits': relation_logits,
            'polarity_logits': polarity_logits,
            'intensity_logits': intensity_logits,
            'holder_spans': holder_spans,
            'target_spans': target_spans,
            'expression_spans': expression_spans,
            'pair_indices': pair_indices
        }

# Training function
def train_model(model, train_dataloader, dev_dataloader, optimizer, scheduler, device, num_epochs, output_dir):
    """Train ``model`` for ``num_epochs`` epochs and checkpoint on dev-set F1.

    Each epoch runs one pass over ``train_dataloader`` (one optimizer and one
    scheduler step per batch), then calls ``evaluate_model`` on
    ``dev_dataloader``.  The mean of the per-category F1 scores is compared
    with the best value so far; on improvement the model and optimizer state
    are saved under ``output_dir``.

    Returns:
        The trained model (the same object that was passed in).
    """
    tensor_fields = ('input_ids', 'attention_mask', 'holder_labels', 'target_labels', 'expression_labels')
    best_f1 = 0.0
    for epoch in range(num_epochs):
        logger.info(f"Starting epoch {epoch+1}/{num_epochs}")
        model.train()
        train_loss, train_steps = 0.0, 0
        for batch in tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            on_device = {field: batch[field].to(device) for field in tensor_fields}
            # Gold opinions stay on the host as a list of lists of dicts.
            supervision = (
                on_device['holder_labels'],
                on_device['target_labels'],
                on_device['expression_labels'],
                batch['opinion_data'],
            )
            outputs = model(
                input_ids=on_device['input_ids'],
                attention_mask=on_device['attention_mask'],
                labels=supervision
            )
            loss = outputs['loss']

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

            train_loss += loss.item()
            train_steps += 1

        avg_train_loss = train_loss / train_steps if train_steps > 0 else 0
        logger.info(f"Epoch {epoch+1} - Average training loss: {avg_train_loss:.4f}")

        eval_results = evaluate_model(model, dev_dataloader, device)
        f1_scores = [metrics['f1'] for metrics in eval_results.values()]
        avg_f1 = sum(f1_scores) / len(eval_results)
        logger.info(f"Epoch {epoch+1} - Evaluation F1: {avg_f1:.4f}")

        if avg_f1 > best_f1:
            best_f1 = avg_f1
            checkpoint = {
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_f1': best_f1
            }
            torch.save(checkpoint, os.path.join(output_dir, f"best_model_f1_{best_f1:.4f}.pt"))
    return model

# Evaluation function
def evaluate_model(model, dataloader, device, output_json="evaluation_results.json"):
    """Evaluate ``model`` on ``dataloader`` and write the metrics to ``output_json``.

    Span tagging is scored per token over the attended positions.  Polarity
    and intensity are scored by pairing the j-th predicted expression of an
    example with its j-th gold opinion; whichever side is missing is filled
    with the default class (polarity 3 = None, intensity 1 = Average).

    Returns:
        Dict keyed by ``holder``, ``target``, ``expression``, ``polarity`` and
        ``intensity``; each value holds macro-averaged ``precision``,
        ``recall`` and ``f1`` as plain Python floats.
    """
    model.eval()
    span_roles = ('holder', 'target', 'expression')
    span_preds = {role: [] for role in span_roles}
    span_true = {role: [] for role in span_roles}
    polarity_preds, polarity_true = [], []
    intensity_preds, intensity_true = [], []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            gold_labels = {role: batch[f'{role}_labels'].to(device) for role in span_roles}
            gold_opinions = batch['opinion_data']  # List of lists of opinion dicts

            outputs = model(input_ids=input_ids, attention_mask=attention_mask)

            # Token-level span predictions, restricted to attended positions.
            seq_lens = attention_mask.bool().sum(dim=1).tolist()
            for role in span_roles:
                role_preds = torch.argmax(outputs[f'{role}_logits'], dim=-1)
                for i, seq_len in enumerate(seq_lens):
                    span_preds[role].extend(role_preds[i, :seq_len].cpu().tolist())
                    span_true[role].extend(gold_labels[role][i, :seq_len].cpu().tolist())

            # Polarity and intensity predictions.  When the model produced no
            # logits for this batch every example simply has zero predictions,
            # so its gold opinions are still scored instead of being dropped.
            if outputs['polarity_logits'] is not None and outputs['intensity_logits'] is not None:
                batch_polarity_preds = torch.argmax(outputs['polarity_logits'], dim=-1).cpu().tolist()
                batch_intensity_preds = torch.argmax(outputs['intensity_logits'], dim=-1).cpu().tolist()
            else:
                batch_polarity_preds, batch_intensity_preds = [], []
            predicted_expressions = outputs.get('expression_spans')

            for i, opinions in enumerate(gold_opinions):
                num_pred_expressions = len(batch_polarity_preds[i]) if i < len(batch_polarity_preds) else 0
                if predicted_expressions is not None and i < len(predicted_expressions):
                    # The logits are padded to the longest example in the
                    # batch; only the leading slots are real predictions.
                    num_pred_expressions = min(num_pred_expressions, len(predicted_expressions[i]))
                num_gold_opinions = len(opinions)

                # Align predictions with ground truth
                for j in range(max(num_pred_expressions, num_gold_opinions)):
                    if j < num_pred_expressions:
                        pred_pol = batch_polarity_preds[i][j]
                        pred_int = batch_intensity_preds[i][j]
                    else:
                        pred_pol = 3  # None for polarity if no prediction
                        pred_int = 1  # Average for intensity if no prediction

                    if j < num_gold_opinions:
                        true_pol = opinions[j]['polarity']
                        true_int = opinions[j]['intensity']
                    else:
                        true_pol = 3  # None for polarity if no ground truth
                        true_int = 1  # Average for intensity if no ground truth

                    polarity_preds.append(pred_pol)
                    polarity_true.append(true_pol)
                    intensity_preds.append(pred_int)
                    intensity_true.append(true_int)

    # Macro-averaged precision / recall / F1 for every category.
    scored = [(role, span_true[role], span_preds[role]) for role in span_roles]
    scored.append(('polarity', polarity_true, polarity_preds))
    scored.append(('intensity', intensity_true, intensity_preds))

    results = {}
    for name, y_true, y_pred in scored:
        precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='macro', zero_division=0)
        # Plain floats keep the JSON file and saved checkpoints free of numpy scalar types.
        results[name] = {'precision': float(precision), 'recall': float(recall), 'f1': float(f1)}

    # Save to JSON
    with open(output_json, "w", encoding="utf-8") as f:
        json.dump({'metrics': results}, f, indent=2)

    return results

# Main execution
def main():
    """Train on the configured language, evaluate on the test split and save the model.

    Data files are looked up as ``<data_dir>/<language>/<split>.json`` first
    and, when that file does not exist, as ``<input_path>/<split>.json``.
    """
    input_path = '/content'
    output_dir = '/content'
    data_dir = os.path.join(input_path, 'sentiment-data')
    language = 'opener_en'
    pretrained_model = 'xlm-roberta-base'
    use_adapters = False
    num_epochs = NUM_EPOCHS
    batch_size = BATCH_SIZE
    learning_rate = LEARNING_RATE

    os.makedirs(output_dir, exist_ok=True)

    # Attach the file handler only once: calling main() again in the same
    # process (e.g. re-running a notebook cell) would otherwise duplicate
    # every log line in training.log.
    log_path = os.path.abspath(os.path.join(output_dir, 'training.log'))
    already_logging_to_file = any(
        isinstance(handler, logging.FileHandler) and handler.baseFilename == log_path
        for handler in logger.handlers
    )
    if not already_logging_to_file:
        logger.addHandler(logging.FileHandler(log_path))

    def resolve_split(split):
        # An absolute component makes os.path.join drop everything before it,
        # so the previous join(data_dir, language, "/content/<split>.json")
        # always resolved to "/content/<split>.json".  Prefer the per-language
        # file and keep that old location as the fallback.
        preferred = os.path.join(data_dir, language, f"{split}.json")
        if os.path.exists(preferred):
            return preferred
        return os.path.join(input_path, f"{split}.json")

    train_file = resolve_split("train")
    dev_file = resolve_split("dev")
    test_file = resolve_split("test")

    tokenizer = XLMRobertaTokenizerFast.from_pretrained(pretrained_model)
    train_dataset = SentimentDataset(train_file, tokenizer)
    dev_dataset = SentimentDataset(dev_file, tokenizer)
    test_dataset = SentimentDataset(test_file, tokenizer)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Pinned host memory only helps host-to-GPU copies.
    pin_memory = device.type == "cuda"

    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=custom_collate, num_workers=2, pin_memory=pin_memory)
    dev_dataloader = DataLoader(dev_dataset, batch_size=batch_size, shuffle=False, collate_fn=custom_collate, num_workers=2, pin_memory=pin_memory)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=custom_collate, num_workers=2, pin_memory=pin_memory)

    model = StructuredSentimentModel(pretrained_model_name=pretrained_model, use_adapters=use_adapters).to(device)
    optimizer = AdamW(model.parameters(), lr=learning_rate)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=WARMUP_STEPS, num_training_steps=len(train_dataloader) * num_epochs)

    model = train_model(model, train_dataloader, dev_dataloader, optimizer, scheduler, device, num_epochs, output_dir)
    test_results = evaluate_model(model, test_dataloader, device)
    logger.info("Test Results:")
    for span_type, metrics in test_results.items():
        logger.info(f"{span_type}: Precision={metrics['precision']:.4f}, Recall={metrics['recall']:.4f}, F1={metrics['f1']:.4f}")
    torch.save({
        'model_state_dict': model.state_dict(),
        'test_results': test_results
    }, os.path.join(output_dir, "final_model.pt"))

if __name__ == '__main__':
    main()