In [None]:
# Install required packages
!pip install transformers torch

# Import libraries
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import XLMRobertaModel, XLMRobertaTokenizerFast
import json
import logging
from tqdm import tqdm
from google.colab import files
import os

# Logging: timestamped records at INFO level and above.
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Global configuration
# ---------------------------------------------------------------------------
MAX_SEQ_LENGTH = 512   # tokenizer truncation / padding length
BATCH_SIZE = 8
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Output sizes of the classification heads.
NUM_LABELS_SPAN = 3        # BIO tags as decoded by extract_spans: 0 = O, 1 = B, 2 = I
NUM_LABELS_POLARITY = 4
NUM_LABELS_INTENSITY = 3

# Hidden sizes of the span / relation / adapter layers.
SPAN_EMBEDDING_DIM = 768
RELATION_EMBEDDING_DIM = 256
ADAPTER_SIZE = 128

# Label name -> id mappings and their id -> name inverses.
polarity_map = {'Positive': 0, 'Negative': 1, 'Neutral': 2, 'None': 3}
intensity_map = {'Strong': 0, 'Average': 1, 'Weak': 2}
polarity_reverse_map = dict(zip(polarity_map.values(), polarity_map.keys()))
intensity_reverse_map = dict(zip(intensity_map.values(), intensity_map.keys()))

# Neural network modules
class SelfAttentionLayer(nn.Module):
    """Multi-head scaled dot-product self-attention with a residual + LayerNorm.

    Args:
        input_dim: feature size of both the input and the output.
        num_heads: number of attention heads.
        head_dim: per-head projection size (the inner width is num_heads * head_dim,
            which does not have to equal input_dim).
    """

    def __init__(self, input_dim, num_heads=8, head_dim=96):
        super().__init__()
        self.input_dim = input_dim
        self.num_heads = num_heads
        self.head_dim = head_dim
        inner_dim = num_heads * head_dim
        # Attribute names below are state_dict keys of saved checkpoints; keep them.
        self.query = nn.Linear(input_dim, inner_dim)
        self.key = nn.Linear(input_dim, inner_dim)
        self.value = nn.Linear(input_dim, inner_dim)
        self.output_projection = nn.Linear(inner_dim, input_dim)
        self.layer_norm = nn.LayerNorm(input_dim)

    def _split_heads(self, projected, batch_size, seq_len):
        """Reshape (batch, seq, heads * head_dim) into (batch, heads, seq, head_dim)."""
        return projected.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x, mask=None):
        """Attend over the sequence dimension of ``x``.

        Args:
            x: (batch, seq_len, input_dim) tensor.
            mask: optional (batch, seq_len) tensor; positions equal to 0 are never
                attended to.
        Returns:
            (batch, seq_len, input_dim) tensor: LayerNorm(attention output + x).
        """
        batch_size, seq_len, _ = x.size()
        queries = self._split_heads(self.query(x), batch_size, seq_len)
        keys = self._split_heads(self.key(x), batch_size, seq_len)
        values = self._split_heads(self.value(x), batch_size, seq_len)

        scale = self.head_dim ** 0.5
        scores = queries @ keys.transpose(-2, -1) / scale
        if mask is not None:
            # (batch, seq) -> (batch, 1, 1, seq): broadcast over heads and query rows.
            key_mask = mask.unsqueeze(1).unsqueeze(2)
            scores = scores.masked_fill(key_mask == 0, -1e9)
        probs = scores.softmax(dim=-1)

        merged = (probs @ values).transpose(1, 2).contiguous()
        merged = merged.view(batch_size, seq_len, self.num_heads * self.head_dim)
        return self.layer_norm(self.output_projection(merged) + x)

class SpanDetector(nn.Module):
    """Token-level tagger head: Linear -> ReLU -> Linear.

    Produces ``num_labels`` logits per position; the rest of the file decodes
    them as BIO tags (0 = O, 1 = B, 2 = I).
    """

    def __init__(self, input_dim, hidden_dim=256, num_labels=NUM_LABELS_SPAN):
        super().__init__()
        # Attribute names are state_dict keys; keep them stable.
        self.hidden = nn.Linear(input_dim, hidden_dim)
        self.activation = nn.ReLU()
        self.classifier = nn.Linear(hidden_dim, num_labels)

    def forward(self, x):
        """Return logits of shape ``x.shape[:-1] + (num_labels,)``."""
        return self.classifier(self.activation(self.hidden(x)))

class CrossSpanAttention(nn.Module):
    """Self-attention across span embeddings, projected and layer-normalized.

    Args:
        input_dim: span embedding size (must be divisible by the 4 attention heads).
        output_dim: size of the returned relation-aware span embeddings.
    """

    def __init__(self, input_dim, output_dim=RELATION_EMBEDDING_DIM):
        super(CrossSpanAttention, self).__init__()
        # Attribute names are state_dict keys; keep them stable.
        self.attention = nn.MultiheadAttention(input_dim, num_heads=4, batch_first=True)
        self.projection = nn.Linear(input_dim, output_dim)
        self.layer_norm = nn.LayerNorm(output_dim)

    def forward(self, spans, span_masks=None):
        """Contextualize every span with all other (real) spans.

        Args:
            spans: (batch, num_spans, input_dim) span embeddings.
            span_masks: optional (batch, num_spans) tensor, truthy for real spans
                and falsy for padding.
        Returns:
            (batch, num_spans, output_dim) tensor.
        """
        batch_size, num_spans, _ = spans.shape
        if num_spans == 0:
            # Nothing to attend over: return a correctly shaped empty tensor instead
            # of feeding a zero-length sequence to MultiheadAttention.
            return spans.new_zeros((batch_size, 0, self.projection.out_features))

        key_padding_mask = None
        if span_masks is not None:
            # .bool() so that integer masks are negated logically, not bitwise.
            key_padding_mask = ~span_masks.bool()
            # A row whose keys are all padded can make the attention softmax emit NaN,
            # which LayerNorm then propagates. Such rows hold only padding, so let
            # them attend freely; their outputs are ignored downstream anyway.
            all_padded = key_padding_mask.all(dim=1)
            if bool(all_padded.any()):
                key_padding_mask = key_padding_mask.masked_fill(all_padded.unsqueeze(1), False)

        context, _ = self.attention(spans, spans, spans, key_padding_mask=key_padding_mask)
        return self.layer_norm(self.projection(context))

class RelationClassifier(nn.Module):
    """MLP head scoring span pairs.

    Input has last dimension ``2 * input_dim`` (two span embeddings concatenated);
    output has ``num_labels`` logits per pair. Layers: Linear -> ReLU ->
    Dropout(0.1) -> Linear.
    """

    def __init__(self, input_dim, hidden_dim=256, num_labels=2):
        super().__init__()
        pair_dim = input_dim * 2
        # Attribute names are state_dict keys; keep them stable.
        self.hidden = nn.Linear(pair_dim, hidden_dim)
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(hidden_dim, num_labels)

    def forward(self, span_pairs):
        """Return relation logits with shape ``span_pairs.shape[:-1] + (num_labels,)``."""
        features = self.dropout(self.activation(self.hidden(span_pairs)))
        return self.classifier(features)

class PolarityClassifier(nn.Module):
    """MLP head (Linear -> ReLU -> Dropout(0.1) -> Linear) giving polarity logits.

    Logit indices follow ``polarity_map`` (Positive, Negative, Neutral, None).
    """

    def __init__(self, input_dim, hidden_dim=128, num_labels=NUM_LABELS_POLARITY):
        super().__init__()
        # Attribute names are state_dict keys; keep them stable.
        self.hidden = nn.Linear(input_dim, hidden_dim)
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(hidden_dim, num_labels)

    def forward(self, x):
        """Return logits of shape ``x.shape[:-1] + (num_labels,)``."""
        features = self.hidden(x)
        features = self.dropout(self.activation(features))
        return self.classifier(features)

class IntensityClassifier(nn.Module):
    """MLP head (Linear -> ReLU -> Dropout(0.1) -> Linear) giving intensity logits.

    Logit indices follow ``intensity_map`` (Strong, Average, Weak).
    """

    def __init__(self, input_dim, hidden_dim=128, num_labels=NUM_LABELS_INTENSITY):
        super().__init__()
        # Attribute names are state_dict keys; keep them stable.
        self.hidden = nn.Linear(input_dim, hidden_dim)
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(hidden_dim, num_labels)

    def forward(self, x):
        """Return logits of shape ``x.shape[:-1] + (num_labels,)``."""
        features = self.hidden(x)
        features = self.dropout(self.activation(features))
        return self.classifier(features)

class LanguageAdapter(nn.Module):
    """Bottleneck adapter with a residual connection.

    Computes ``LayerNorm(x + up(relu(down(x))))``; input and output size are
    both ``input_dim``.
    """

    def __init__(self, input_dim, bottleneck_dim=ADAPTER_SIZE):
        super().__init__()
        # Attribute names are state_dict keys; keep them stable.
        self.down_project = nn.Linear(input_dim, bottleneck_dim)
        self.activation = nn.ReLU()
        self.up_project = nn.Linear(bottleneck_dim, input_dim)
        self.layer_norm = nn.LayerNorm(input_dim)

    def forward(self, x):
        delta = self.up_project(self.activation(self.down_project(x)))
        return self.layer_norm(delta + x)

class StructuredSentimentModel(nn.Module):
    """Structured sentiment model on top of an XLM-R encoder.

    Inference path (``labels is None``):
      1. encode tokens with XLM-R, optionally through a per-language adapter;
      2. refine token states with ``SelfAttentionLayer``;
      3. tag holder / target / expression tokens with three BIO ``SpanDetector``s;
      4. decode the tags into token spans and mean-pool each span;
      5. contextualize all spans jointly with ``CrossSpanAttention``;
      6. score (holder, expression) and (target, expression) pairs with the
         ``RelationClassifier`` and classify every expression's polarity/intensity.

    NOTE: sub-module attribute names are checkpoint ``state_dict`` keys; do not rename.
    """

    def __init__(self, pretrained_model_name="xlm-roberta-base", use_adapters=False, num_languages=8):
        super(StructuredSentimentModel, self).__init__()
        self.encoder = XLMRobertaModel.from_pretrained(pretrained_model_name)
        self.hidden_size = self.encoder.config.hidden_size
        self.span_attention = SelfAttentionLayer(self.hidden_size)
        self.holder_detector = SpanDetector(self.hidden_size)
        self.target_detector = SpanDetector(self.hidden_size)
        self.expression_detector = SpanDetector(self.hidden_size)
        self.cross_span_attention = CrossSpanAttention(self.hidden_size)
        self.relation_classifier = RelationClassifier(RELATION_EMBEDDING_DIM)
        self.polarity_classifier = PolarityClassifier(RELATION_EMBEDDING_DIM)
        self.intensity_classifier = IntensityClassifier(RELATION_EMBEDDING_DIM)
        self.use_adapters = use_adapters
        if use_adapters:
            self.language_adapters = nn.ModuleList([LanguageAdapter(self.hidden_size) for _ in range(num_languages)])
        self._init_weights()

    def _init_weights(self):
        """Xavier-init weight matrices and zero biases of the task-specific heads.

        The pretrained encoder and the language adapters are left untouched.
        """
        modules = [self.span_attention, self.holder_detector, self.target_detector,
                   self.expression_detector, self.cross_span_attention,
                   self.relation_classifier, self.polarity_classifier, self.intensity_classifier]
        for module in modules:
            for name, param in module.named_parameters():
                if 'weight' in name and len(param.shape) >= 2:
                    nn.init.xavier_uniform_(param)
                elif 'bias' in name:
                    nn.init.zeros_(param)

    def extract_spans(self, span_logits, attention_mask):
        """Decode per-token BIO predictions into inclusive (start, end) token spans.

        Tag ids: 0 = O, 1 = B, 2 = I. A B closes any open span and opens a new one;
        an I with no open span also opens one.

        Args:
            span_logits: (batch, seq_len, num_tags) logits.
            attention_mask: (batch, seq_len); only unmasked positions are decoded.
                Returned indices are positions within the unmasked tokens, which
                equal the real token positions when padding is on the right.
        Returns:
            One list of ``(start, end)`` int tuples per batch item.
        """
        # argmax over logits picks the same tag as argmax over softmax(logits).
        tag_ids = span_logits.argmax(dim=-1)
        all_spans = []
        for i in range(span_logits.size(0)):
            tags = tag_ids[i][attention_mask[i].bool()].tolist()
            spans = []
            start = None
            for pos, tag in enumerate(tags):
                if tag == 1:  # B
                    if start is not None:
                        spans.append((start, pos - 1))
                    start = pos
                elif tag == 0:  # O
                    if start is not None:
                        spans.append((start, pos - 1))
                        start = None
                elif tag == 2 and start is None:  # I without a preceding B
                    start = pos
            if start is not None:
                spans.append((start, len(tags) - 1))
            all_spans.append(spans)
        return all_spans

    def get_span_embeddings(self, hidden_states, spans, attention_mask):
        """Mean-pool token states over each span, padding to the longest span list.

        Args:
            hidden_states: (batch, seq_len, hidden_size) token states.
            spans: per-item lists of inclusive (start, end) token spans.
            attention_mask: unused; kept for interface compatibility.
        Returns:
            ``(embeddings, masks)``: a (batch, max_spans, hidden_size) float tensor and
            a (batch, max_spans) bool tensor that is True where a real span is stored.
        """
        batch_size = hidden_states.size(0)
        device = hidden_states.device
        max_spans = max((len(s) for s in spans), default=0)
        span_embeddings = torch.zeros((batch_size, max_spans, self.hidden_size), device=device)
        span_masks = torch.zeros((batch_size, max_spans), dtype=torch.bool, device=device)
        for i in range(batch_size):
            for j, (start, end) in enumerate(spans[i]):
                span_embeddings[i, j] = hidden_states[i, start:end + 1].mean(dim=0)
                span_masks[i, j] = True
        return span_embeddings, span_masks

    def _combine_spans(self, holder_emb, holder_mask, target_emb, target_mask, expr_emb, expr_mask):
        """Concatenate holder, target and expression spans along the span axis.

        The result is laid out as [holders | targets | expressions]; ``forward``
        and ``_create_span_pairs`` rely on this order.
        """
        combined_emb = torch.cat([holder_emb, target_emb, expr_emb], dim=1)
        combined_mask = torch.cat([holder_mask, target_mask, expr_mask], dim=1)
        return combined_emb, combined_mask

    def _create_span_pairs(self, span_embeddings, holder_mask, target_mask, expr_mask):
        """Build concatenated (holder, expression) and (target, expression) pairs.

        ``span_embeddings`` is laid out as [holders | targets | expressions] along
        dim 1; the mask widths give the segment boundaries.

        Returns:
            ``(pair_embeddings, pair_indices)`` with shapes (batch, max_pairs, 2 * D)
            and (batch, max_pairs, 2). Each index row is (holder-or-target position,
            expression position) into ``span_embeddings``. Holder pairs come before
            target pairs, and rows past an item's pair count are zero padding.
            Returns ``(None, None)`` when no item has any pair.
        """
        batch_size = span_embeddings.size(0)
        holder_size = holder_mask.size(1)
        target_size = target_mask.size(1)
        expr_offset = holder_size + target_size
        num_sources = holder_mask.sum(dim=1) + target_mask.sum(dim=1)
        max_pairs = int(torch.max(num_sources * expr_mask.sum(dim=1)))
        if max_pairs == 0:
            return None, None
        pair_embeddings = span_embeddings.new_zeros((batch_size, max_pairs, 2 * span_embeddings.size(-1)))
        pair_indices = torch.zeros((batch_size, max_pairs, 2), dtype=torch.long, device=span_embeddings.device)
        for i in range(batch_size):
            sources = [h for h, ok in enumerate(holder_mask[i].tolist()) if ok]
            sources += [holder_size + t for t, ok in enumerate(target_mask[i].tolist()) if ok]
            expressions = [expr_offset + e for e, ok in enumerate(expr_mask[i].tolist()) if ok]
            pair_idx = 0
            for src in sources:
                for expr in expressions:
                    pair_embeddings[i, pair_idx] = torch.cat([span_embeddings[i, src], span_embeddings[i, expr]])
                    pair_indices[i, pair_idx, 0] = src
                    pair_indices[i, pair_idx, 1] = expr
                    pair_idx += 1
        return pair_embeddings, pair_indices

    def forward(self, input_ids, attention_mask, language_id=None, labels=None):
        """Run the model.

        Args:
            input_ids, attention_mask: (batch, seq_len) encoder inputs.
            language_id: optional (batch,) tensor of adapter indices; only used when
                the model was built with ``use_adapters=True``.
            labels: if given, the gold-span training path is requested. That path
                is not implemented, so only the token-level logits are returned.
        Returns:
            dict with token logits ``holder_logits`` / ``target_logits`` /
            ``expression_logits`` and, on the inference path, the decoded span
            lists, ``relation_logits`` over span pairs, per-expression
            ``polarity_logits`` / ``intensity_logits`` and ``pair_indices``.
            Entries that cannot be computed (e.g. no spans found) are ``None``.
        """
        batch_size = input_ids.size(0)
        hidden_states = self.encoder(input_ids, attention_mask=attention_mask).last_hidden_state
        if self.use_adapters and language_id is not None:
            # Route every example through the adapter of its own language.
            hidden_states = torch.stack([
                self.language_adapters[language_id[i].item()](hidden_states[i])
                for i in range(batch_size)
            ])
        span_aware_states = self.span_attention(hidden_states, attention_mask)
        holder_logits = self.holder_detector(span_aware_states)
        target_logits = self.target_detector(span_aware_states)
        expression_logits = self.expression_detector(span_aware_states)

        outputs = {
            'holder_logits': holder_logits,
            'target_logits': target_logits,
            'expression_logits': expression_logits,
            'relation_logits': None,
            'polarity_logits': None,
            'intensity_logits': None,
            'holder_spans': None,
            'target_spans': None,
            'expression_spans': None,
            'pair_indices': None,
        }
        if labels is not None:
            # Gold-span (teacher-forced) training path is not implemented here.
            return outputs

        holder_spans = self.extract_spans(holder_logits, attention_mask)
        target_spans = self.extract_spans(target_logits, attention_mask)
        expression_spans = self.extract_spans(expression_logits, attention_mask)
        outputs['holder_spans'] = holder_spans
        outputs['target_spans'] = target_spans
        outputs['expression_spans'] = expression_spans

        holder_embeddings, holder_masks = self.get_span_embeddings(span_aware_states, holder_spans, attention_mask)
        target_embeddings, target_masks = self.get_span_embeddings(span_aware_states, target_spans, attention_mask)
        expression_embeddings, expression_masks = self.get_span_embeddings(span_aware_states, expression_spans, attention_mask)
        all_span_embeddings, all_span_masks = self._combine_spans(
            holder_embeddings, holder_masks, target_embeddings, target_masks,
            expression_embeddings, expression_masks)
        if all_span_embeddings.size(1) == 0:
            # No span of any kind was detected: nothing to relate or classify.
            return outputs

        relation_aware = self.cross_span_attention(all_span_embeddings, all_span_masks)
        relation_pairs, pair_indices = self._create_span_pairs(
            relation_aware, holder_masks, target_masks, expression_masks)
        outputs['pair_indices'] = pair_indices
        if relation_pairs is not None:
            outputs['relation_logits'] = self.relation_classifier(relation_pairs)
        if expression_embeddings.size(1) > 0:
            # Polarity/intensity depend only on the expression spans, so classify
            # them even when no holder/target was found to pair them with.
            expr_start = holder_embeddings.size(1) + target_embeddings.size(1)
            expression_states = relation_aware[:, expr_start:, :]
            outputs['polarity_logits'] = self.polarity_classifier(expression_states)
            outputs['intensity_logits'] = self.intensity_classifier(expression_states)
        return outputs

class SentimentAnalyzer:
    """Inference wrapper: tokenizes raw text, runs ``StructuredSentimentModel`` and
    converts its token-level outputs into character-level opinions."""

    def __init__(self, model_path, pretrained_model="xlm-roberta-base", device=None):
        """Load tokenizer and model weights.

        Args:
            model_path: checkpoint file containing a ``'model_state_dict'`` entry.
            pretrained_model: Hugging Face model name for tokenizer and encoder.
            device: torch.device to run on; defaults to CUDA when available.
        """
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = device
        logger.info(f"Using device: {self.device}")
        self.tokenizer = XLMRobertaTokenizerFast.from_pretrained(pretrained_model)
        self.model = StructuredSentimentModel(pretrained_model_name=pretrained_model)
        # NOTE(review): newer PyTorch releases default torch.load to weights_only=True;
        # if the checkpoint stores non-tensor objects this may need adjusting
        # (only ever load trusted checkpoint files).
        checkpoint = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.to(self.device)
        self.model.eval()
        logger.info(f"Model loaded from {model_path}")
        self.polarity_map = {0: "Positive", 1: "Negative", 2: "Neutral", 3: "None"}
        self.intensity_map = {0: "Strong", 1: "Average", 2: "Weak"}

    def _get_actual_text_spans(self, text, tokens, spans, keep_alignment=False):
        """Map inclusive token spans to ``(start_char, end_char, span_text)`` tuples.

        A span is unmappable when it lies outside the offset mapping or covers no
        characters (e.g. only special/padding tokens). Such spans are dropped,
        unless ``keep_alignment`` is True; then ``None`` takes their place, so
        output index ``k`` still corresponds to input span ``k``.
        """
        offsets = tokens.offset_mapping[0].tolist()
        text_spans = []
        for start_token, end_token in spans:
            mapped = None
            if start_token < len(offsets) and end_token < len(offsets):
                start_char = offsets[start_token][0]
                end_char = offsets[end_token][1]
                if start_char < end_char and end_char <= len(text):
                    mapped = (start_char, end_char, text[start_char:end_char])
            if mapped is not None or keep_alignment:
                text_spans.append(mapped)
        return text_spans

    @staticmethod
    def _collect_links(pair_indices, num_holders, num_targets):
        """Group the first batch item's ``pair_indices`` by expression index.

        Returns a dict ``{expression_idx: [(kind, idx), ...]}`` in pair order. Here
        ``kind`` is ``"holder"`` or ``"target"`` and ``idx`` indexes that span list.
        """
        links = {}
        if pair_indices is None:
            return links
        expr_offset = num_holders + num_targets
        for src, dst in pair_indices[0].tolist():
            if src < num_holders:
                link = ("holder", src)
            else:
                link = ("target", src - num_holders)
            links.setdefault(dst - expr_offset, []).append(link)
        return links

    def analyze(self, text):
        """Extract holders, targets, expressions and opinions from one sentence.

        Returns:
            dict with the input ``text``, the detected ``holders`` / ``targets`` /
            ``expressions`` strings, and ``opinions``. There is one opinion per
            expression that has polarity/intensity predictions. Each carries the
            expression text and "start:end" character span, its polarity and
            intensity, and the linked holder/target. When nothing is linked, the
            holder/target text is "" and its span is "0:0".
        """
        tokens = self.tokenizer(
            text,
            max_length=MAX_SEQ_LENGTH,
            padding='max_length',
            truncation=True,
            return_offsets_mapping=True,
            return_tensors='pt'
        )
        input_ids = tokens['input_ids'].to(self.device)
        attention_mask = tokens['attention_mask'].to(self.device)
        with torch.no_grad():
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        holder_spans = outputs['holder_spans'][0]
        target_spans = outputs['target_spans'][0]
        expression_spans = outputs['expression_spans'][0]
        # Keep the text lists aligned with the model's span lists. polarity/intensity
        # logits and pair_indices index into these lists, so dropping an unmappable
        # span must not shift the following ones.
        holder_text = self._get_actual_text_spans(text, tokens, holder_spans, keep_alignment=True)
        target_text = self._get_actual_text_spans(text, tokens, target_spans, keep_alignment=True)
        expression_text = self._get_actual_text_spans(text, tokens, expression_spans, keep_alignment=True)

        sentiment_opinions = []
        polarity_logits = outputs['polarity_logits']
        intensity_logits = outputs['intensity_logits']
        if polarity_logits is not None and intensity_logits is not None:
            polarity_ids = polarity_logits.argmax(dim=-1)[0].tolist()
            intensity_ids = intensity_logits.argmax(dim=-1)[0].tolist()
            # NOTE(review): relation_logits are not consulted; every paired
            # holder/target is linked to the expression. Confirm whether pairs
            # should be filtered by the relation classifier's prediction.
            links = self._collect_links(outputs['pair_indices'], len(holder_spans), len(target_spans))
            for i in range(min(len(expression_text), len(polarity_ids))):
                expression = expression_text[i]
                if expression is None:
                    continue
                opinion = {
                    "expression": expression[2],
                    "expression_span": f"{expression[0]}:{expression[1]}",
                    "polarity": self.polarity_map[polarity_ids[i]],
                    "intensity": self.intensity_map[intensity_ids[i]],
                    "holder": "",
                    "holder_span": "0:0",
                    "target": "",
                    "target_span": "0:0",
                }
                # When several holders (or targets) are paired, the last valid one wins.
                for kind, idx in links.get(i, ()):
                    candidates = holder_text if kind == "holder" else target_text
                    if idx < len(candidates) and candidates[idx] is not None:
                        start_char, end_char, span_text = candidates[idx]
                        opinion[kind] = span_text
                        opinion[f"{kind}_span"] = f"{start_char}:{end_char}"
                sentiment_opinions.append(opinion)
        return {
            "text": text,
            "holders": [span[2] for span in holder_text if span is not None],
            "targets": [span[2] for span in target_text if span is not None],
            "expressions": [span[2] for span in expression_text if span is not None],
            "opinions": sentiment_opinions
        }

class SentimentDataset(Dataset):
    """Pre-tokenized dataset built from a JSON list of annotated sentences.

    Each example holds the tokenizer outputs and BIO label tensors (0 = O, 1 = B,
    2 = I) for holders, targets and expressions. It also holds ``opinion_data``:
    token-level (start, end) spans plus integer polarity/intensity ids, used as
    gold data for evaluation. A token span of (0, 0) means "no span".
    """

    def __init__(self, file_path, tokenizer, max_length=MAX_SEQ_LENGTH):
        self.max_length = max_length
        self.tokenizer = tokenizer
        self.examples = []
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        for entry in data:
            processed = self.process_example(entry)
            # Entries without text yield None and are skipped.
            if processed:
                self.examples.append(processed)

    def process_example(self, entry):
        """Tokenize one JSON entry and build its label tensors and gold opinions.

        Returns None when the entry has no ``text``.
        """
        text = entry.get('text', '')
        sent_id = entry.get('sent_id', '')
        opinions = entry.get('opinions', [])
        if not text:
            return None
        tokenized = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_offsets_mapping=True,
            return_tensors='pt'
        )
        offset_mapping = tokenized['offset_mapping'][0]
        input_ids = tokenized['input_ids'][0]
        attention_mask = tokenized['attention_mask'][0]
        holder_labels = torch.zeros(self.max_length, dtype=torch.long)
        target_labels = torch.zeros(self.max_length, dtype=torch.long)
        expression_labels = torch.zeros(self.max_length, dtype=torch.long)
        opinion_data = []
        for opinion in opinions:
            holder_span = self._extract_span(opinion, 'Source', offset_mapping)
            target_span = self._extract_span(opinion, 'Target', offset_mapping)
            expression_span = self._extract_span(opinion, 'Polar_expression', offset_mapping)
            polarity = polarity_map.get(opinion.get('Polarity', 'None'), 3)
            intensity = intensity_map.get(opinion.get('Intensity', 'Average'), 1)
            if any(span is None for span in (holder_span, target_span, expression_span)):
                continue
            opinion_data.append({
                'holder_span': holder_span,
                'target_span': target_span,
                'expression_span': expression_span,
                'polarity': polarity,
                'intensity': intensity,
            })
            for labels, span in ((holder_labels, holder_span),
                                 (target_labels, target_span),
                                 (expression_labels, expression_span)):
                # (0, 0) encodes "no span"; marking it would wrongly tag
                # position 0 as B.
                if span != (0, 0):
                    self._mark_span(labels, span[0], span[1])
        return {
            'sent_id': sent_id,
            'text': text,
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'holder_labels': holder_labels,
            'target_labels': target_labels,
            'expression_labels': expression_labels,
            'opinion_data': opinion_data,
            'offset_mapping': offset_mapping
        }

    def _extract_span(self, opinion, key, offset_mapping):
        """Convert the first "start:end" char offset in ``opinion[key]`` to tokens.

        Returns an inclusive ``(start_token, end_token)`` pair. It returns (0, 0)
        when the field is empty or its offset string is malformed. Tokens whose
        offset is (0, 0) (special/padding) are skipped, and a boundary not covered
        by any token stays 0.
        """
        span_data = opinion.get(key, [[], []])[1]
        if not span_data:
            return (0, 0)
        try:
            start, end = map(int, span_data[0].split(':'))
        except (ValueError, TypeError, AttributeError):
            return (0, 0)
        offsets = offset_mapping.tolist() if hasattr(offset_mapping, 'tolist') else offset_mapping
        start_token = end_token = 0
        for idx, (ts, te) in enumerate(offsets):
            if ts == 0 and te == 0:
                continue
            if ts <= start < te:
                start_token = idx
            if ts < end <= te:
                end_token = idx
        return (start_token, end_token)

    def _mark_span(self, labels, start_idx, end_idx):
        """Write B at ``start_idx`` and I over ``start_idx + 1 .. end_idx`` (inclusive)."""
        try:
            labels[start_idx] = 1
            if end_idx > start_idx:
                labels[start_idx+1:end_idx+1] = 2
        except Exception as e:
            logger.warning(f"Span marking error {start_idx}:{end_idx}: {e}")

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]

def custom_collate(batch):
    """Collate dataset examples for a DataLoader.

    Tensor fields (ids, masks, BIO labels, offset mapping) are stacked along a new
    batch dimension; ``sent_id``, ``text`` and ``opinion_data`` stay Python lists.
    """
    stacked_fields = ('input_ids', 'attention_mask', 'holder_labels', 'target_labels', 'expression_labels')
    listed_fields = ('sent_id', 'text', 'opinion_data')
    collated = {field: torch.stack([example[field] for example in batch]) for field in stacked_fields}
    for field in listed_fields:
        collated[field] = [example[field] for example in batch]
    collated['offset_mapping'] = torch.stack([example['offset_mapping'] for example in batch])
    return collated

def token_span_to_char_span(token_span, offset_mapping, text):
    """Convert an inclusive token span to the ``[[text], ["start:end"]]`` format.

    ``(0, 0)`` means "no span" and yields ``[[], []]``. ``offset_mapping`` is
    indexed as ``offset_mapping[token, 0 or 1]`` (e.g. a 2-D tensor of char offsets).
    """
    if token_span == (0, 0):
        return [[], []]
    first_token, last_token = token_span
    begin = offset_mapping[first_token, 0].item()
    finish = offset_mapping[last_token, 1].item()
    return [[text[begin:finish]], [f"{begin}:{finish}"]]

def opinions_match(pred_op, gt_op):
    """Return True when two opinion tuples agree on every field.

    Both tuples are ``(holder, target, expression, polarity, intensity)``. The
    intensity labels 'Average' and 'Standard' are treated as equal.
    """
    def _canonical_intensity(label):
        return 'Standard' if label == 'Average' else label

    p_holder, p_target, p_expr, p_pol, p_int = pred_op
    g_holder, g_target, g_expr, g_pol, g_int = gt_op
    field_pairs = (
        (p_holder, g_holder),
        (p_target, g_target),
        (p_expr, g_expr),
        (p_pol, g_pol),
        (_canonical_intensity(p_int), _canonical_intensity(g_int)),
    )
    return all(pred == gold for pred, gold in field_pairs)

def evaluate_predictions(all_predicted, all_ground_truth):
    """Micro-averaged precision, recall and F1 over exact opinion matches.

    Sentences are paired with ``zip``, so extra entries in the longer list are
    ignored. Each gold opinion is matched greedily to the first unused prediction
    of the same sentence for which ``opinions_match`` holds. Unmatched gold
    opinions count as false negatives and unused predictions as false positives.
    Prints the raw counts and returns ``(precision, recall, f1)``.
    """
    tp = fp = fn = 0
    for predictions, gold_opinions in zip(all_predicted, all_ground_truth):
        used = set()
        for gold_op in gold_opinions:
            hit = next(
                (k for k, pred_op in enumerate(predictions)
                 if k not in used and opinions_match(pred_op, gold_op)),
                None,
            )
            if hit is None:
                fn += 1
            else:
                tp += 1
                used.add(hit)
        fp += len(predictions) - len(used)
    print(f"TP: {tp}, FP: {fp}, FN: {fn}")
    precision = tp / (tp + fp) if tp + fp else 0
    recall = tp / (tp + fn) if tp + fn else 0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0
    return precision, recall, f1

# Upload test.json. files.upload() returns {filename: content bytes}.
print("Upload your test.json file:")
uploaded = files.upload()
if not uploaded:
    raise RuntimeError("No file was uploaded; please upload test.json and re-run this cell.")
test_file = next(iter(uploaded))
# Persist the uploaded bytes under the same name so later code can open the file.
with open(test_file, 'wb') as f:
    f.write(uploaded[test_file])

# Build the inference wrapper from the trained checkpoint.
# NOTE(review): this path assumes Google Drive is already mounted at /content/drive.
model_path = "/content/drive/MyDrive/NLP-Project/opener_en_best_model_f1_0.7344.pt"
pretrained_model = "xlm-roberta-base"
analyzer = SentimentAnalyzer(model_path, pretrained_model, device=DEVICE)

# Reuse the analyzer's tokenizer so gold token offsets match the model's tokenization.
tokenizer = analyzer.tokenizer
test_dataset = SentimentDataset(test_file, tokenizer)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=custom_collate,
)

# Prediction loop: analyze each sentence and map its character spans back to token
# spans, so they can be compared with the dataset's gold token spans. Collect the
# evaluation tuples and the JSON output.
all_predicted = []
all_ground_truth = []
all_sentences = []


def _char_span_to_token_span(span_str, offsets):
    """Map a "start:end" character span to an inclusive (start_token, end_token) span.

    ``offsets`` is a list of (token_start, token_end) character offsets. "0:0" means
    "no span" and yields (0, 0). A boundary not covered by any token stays 0; when
    several tokens match a boundary, the last one wins.
    """
    if span_str == "0:0":
        return (0, 0)
    start_char, end_char = map(int, span_str.split(':'))
    start_tok = end_tok = 0
    for idx, (ts, te) in enumerate(offsets):
        if ts <= start_char < te:
            start_tok = idx
        if ts < end_char <= te:
            end_tok = idx
    return (start_tok, end_tok)


def _to_annotation(span_text, span_str):
    """Format a span as [[text], ["start:end"]], or [[], []] when the text is empty."""
    return [[span_text], [span_str]] if span_text else [[], []]


with torch.no_grad():
    for batch in tqdm(test_dataloader, desc="Predicting"):
        for i, (sent_id, text) in enumerate(zip(batch['sent_id'], batch['text'])):
            # Plain Python ints: much cheaper to scan than per-element tensor compares.
            offsets = batch['offset_mapping'][i].tolist()
            analysis = analyzer.analyze(text)
            predicted_opinions = [
                (
                    _char_span_to_token_span(opinion['holder_span'], offsets),
                    _char_span_to_token_span(opinion['target_span'], offsets),
                    _char_span_to_token_span(opinion['expression_span'], offsets),
                    opinion['polarity'],
                    opinion['intensity'],
                )
                for opinion in analysis['opinions']
            ]
            if len(all_predicted) < 3:
                print(f"\nSentence: {text}")
                print(f"Predicted: {predicted_opinions}")
                print(f"Ground Truth: {batch['opinion_data'][i]}")
            opinions_list = [
                {
                    "Source": _to_annotation(opinion['holder'], opinion['holder_span']),
                    "Target": _to_annotation(opinion['target'], opinion['target_span']),
                    "Polar_expression": _to_annotation(opinion['expression'], opinion['expression_span']),
                    "Polarity": opinion['polarity'],
                    "Intensity": opinion['intensity']
                }
                for opinion in analysis['opinions']
            ]
            all_sentences.append({
                "sent_id": sent_id,
                "text": text,
                "opinions": opinions_list
            })
            gt_opinions = [
                (
                    op['holder_span'],
                    op['target_span'],
                    op['expression_span'],
                    polarity_reverse_map[op['polarity']],
                    intensity_reverse_map[op['intensity']]
                ) for op in batch['opinion_data'][i]
            ]
            all_predicted.append(predicted_opinions)
            all_ground_truth.append(gt_opinions)

# Write the predictions in the dataset's JSON format.
output_file = 'output.json'
with open(output_file, 'w', encoding='utf-8') as out_f:
    json.dump(all_sentences, out_f, indent=2)
logger.info(f"Predictions saved to {output_file}")

# Score the predictions against the gold token spans.
precision, recall, f1 = evaluate_predictions(all_predicted, all_ground_truth)
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1:.4f}")

# Count sentences for which no opinion was predicted at all.
empty_count = sum(not pred for pred in all_predicted)
print(f"Empty predictions: {empty_count}/{len(all_predicted)}")

# Offer the output file as a browser download (Colab).
files.download(output_file)