# In [None]:
import numpy as np
import pandas as pd
import random
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel, get_linear_schedule_with_warmup
from torch_geometric.nn import GATConv
from torch_geometric.utils import dense_to_sparse
from torch.optim import AdamW
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from tqdm import tqdm
import os
from sklearn.utils.class_weight import compute_class_weight


def set_seed(seed):
    """
    Seed every random number generator used by this script.

    Args:
        seed (int): Value handed to Python's ``random``, NumPy, and PyTorch
            (CPU generator and all CUDA generators).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


# Seed once at import time so that everything below starts from a fixed state.
set_seed(42)


def load_and_preprocess_data(filepath):
    """
    Load the ABSA spreadsheet and return one row per (comment, aspect, polarity).

    The sheet must contain a column whose name includes both 'Aspect' and
    'Polarity'. Each cell of that column holds one or more pairs written as
    "{aspect, polarity}". Every pair becomes its own output row with new
    'Aspect' and 'Polarity' columns; the original pair column is dropped.

    Rows are discarded when:
      * the cell is not a string or contains no parsable pair,
      * the polarity cannot be resolved to positive/negative/neutral,
      * the polarity is 'neutral' (only positive/negative rows are returned).

    Args:
        filepath (str): Path to the Excel file, passed to ``pd.read_excel``.

    Returns:
        pd.DataFrame: The remaining rows with 'Aspect' and 'Polarity' columns.

    Raises:
        KeyError: If no aspect/polarity column can be found.
    """
    df = pd.read_excel(filepath)
    df.columns = df.columns.str.strip()

    aspect_polarity_col = next(
        (col for col in df.columns
         if isinstance(col, str) and 'Aspect' in col and 'Polarity' in col),
        None,
    )
    if not aspect_polarity_col:
        raise KeyError("Could not find the '{Aspect category, Sentiment Polarity}' column. Please verify the column names.")

    pair_pattern = re.compile(r"\{([^,{}]+),\s*([^,{}]+)\}")

    def parse_aspect_polarities(aspect_polarities_str):
        """
        Parses a string containing multiple aspect-polarity pairs in the format:
        "{aspect1, polarity1}, {aspect2, polarity2}, ..."
        Returns a list of (aspect, polarity) tuples; non-string cells (for
        example NaN from an empty spreadsheet cell) yield an empty list.
        """
        if not isinstance(aspect_polarities_str, str):
            return []
        return [
            (aspect.strip(), polarity.strip())
            for aspect, polarity in pair_pattern.findall(aspect_polarities_str)
        ]

    df['Aspect_Polarities_List'] = df[aspect_polarity_col].apply(parse_aspect_polarities)
    df_exploded = df.explode('Aspect_Polarities_List')

    # explode() turns an empty list into NaN; such rows carry no usable pair
    # and would break the tuple unpacking below, so drop them explicitly.
    unparsed = df_exploded['Aspect_Polarities_List'].isna()
    if unparsed.any():
        print(f"Dropped {int(unparsed.sum())} rows without a parsable aspect-polarity pair.\n")
        df_exploded = df_exploded[~unparsed]
    df_exploded = df_exploded.reset_index(drop=True)

    pairs = df_exploded['Aspect_Polarities_List'].tolist()
    df_exploded['Aspect'] = [aspect for aspect, _ in pairs]
    df_exploded['Polarity'] = [polarity for _, polarity in pairs]
    df_exploded = df_exploded.drop(columns=[aspect_polarity_col, 'Aspect_Polarities_List'])

    expected_labels = {'positive', 'negative', 'neutral'}
    initial_label_counts = df_exploded['Polarity'].value_counts()
    print(f"Label distribution before cleaning:\n{initial_label_counts}\n")

    malformed_labels = df_exploded[~df_exploded['Polarity'].isin(expected_labels)]['Polarity'].unique()
    if len(malformed_labels) > 0:
        print(f"Found malformed labels: {malformed_labels}")
        print("Attempting to clean them...\n")

        def extract_polarity(label):
            # Keep the label only if its last comma-separated piece is a known polarity.
            last_token = label.split(',')[-1].strip()
            return last_token if last_token in expected_labels else None

        df_exploded['Clean_Polarity'] = df_exploded['Polarity'].apply(extract_polarity)
        valid_df = df_exploded.dropna(subset=['Clean_Polarity']).copy()
        invalid_df = df_exploded[df_exploded['Clean_Polarity'].isna()].copy()

        print(f"Dropped {len(invalid_df)} samples due to unresolvable labels.\n")

        valid_df['Polarity'] = valid_df['Clean_Polarity']
        valid_df = valid_df.drop(columns=['Clean_Polarity'])

        cleaned_label_counts = valid_df['Polarity'].value_counts()
        print(f"Label distribution after cleaning:\n{cleaned_label_counts}\n")
    else:
        valid_df = df_exploded.copy()
        print("No malformed labels found.\n")

    # Only positive/negative samples are kept; neutral rows are removed here.
    balanced_df = valid_df[valid_df['Polarity'] != 'neutral']
    label_counts_after_drop = balanced_df['Polarity'].value_counts()
    print(f"Label distribution after dropping 'neutral':\n{label_counts_after_drop}\n")

    return balanced_df

# Hand-written lexicon mapping a POS tag to the set of Bengali words that
# receive that tag on an exact match.
#
# tag_word() walks this dict in insertion order and returns the first tag
# whose set contains the word, so for a word listed under several tags the
# earliest tag wins. For example, the "-ভাবে" adverbs at the end of 'VERB'
# are also listed under 'ADV' but are tagged 'VERB' because 'VERB' comes first.
#
# Entries containing a space cannot match tokens produced by tokenize(),
# because its pattern never yields a token with whitespace in it.
POS_TAGS = {
    'PRON': set([
        'আমি', 'তুমি', 'আপনি', 'সে', 'তারা', 'আমরা', 'উনি', 'তোমরা', 'তোমাদের',
        'এরা', 'এইগুলি', 'সেইগুলি', 'তাহারা', 'এটা', 'এগুলো', 'আপনারা', 'আমাদের'
    ]),
    'DET': set([
        'একটি', 'সেই', 'এই', 'তাই', 'কোনো', 'সব', 'এমন', 'যে', 'যত', 'প্রতিটি',
        'কিছু', 'কিছুই', 'প্রত্যেকটি', 'অনেকগুলো', 'তিনটি', 'দুইটি', 'একত্রে',
        'এখানে', 'সেখানে', 'খুবই', 'মোটেও'
    ]),
    'VERB': set([
        'খাওয়া', 'পড়া', 'যাওয়া', 'আসা', 'করা', 'হওয়া', 'লিখা', 'বলা', 'শোনা',
        'খাও', 'পড়', 'যাও', 'আস', 'কর', 'হয়', 'লিখ', 'বল', 'শুন',
        'ছিলাম', 'ছিল', 'থাক', 'থাকছি', 'থাকবো', 'করবো', 'হবে',
        'দেখা', 'শোনা', 'ভালোবাসা', 'হাসা', 'কান্না', 'নাচা', 'কাজ করা',
        'শেখা', 'গণনা করা', 'বিস্তারিত লেখা', 'বিশ্লেষণ করা',
        'আলোচনা করা', 'পরিচালনা করা', 'নির্মাণ করা', 'সমাধান করা', 'উন্নয়ন করা',
        'করানো', 'নাচানো', 'দেখানো', 'শেখানো', 'লেখানো', 'নতুন', 'বজানো',
        'শোনানো', 'পড়ানো', 'আসানো', 'লুকানো', 'বাঁচানো', 'ধোঁয়া', 'গড়ানো',
        'বন্ধ করা', 'খেলানো', 'জীবন', 'ভালোবাসা', 'হাসা', 'কান্না',
        'উঠানো', 'সাজানো', 'সাফ করা', 'সমাধান করা', 'উন্নয়ন করা',
        'পরিচালনা করা', 'আলোচনা করা', 'নির্মাণ করা', 'সংরক্ষণ করা',
        'আনন্দ করা', 'দূর করা', 'শুধু করা', 'অধ্যয়ন করা', 'আঁকা',
        'খোলা', 'চলাচল করা',
        # Domain-specific verbs and verb phrases (same entries as the tail of VERB_SUFFIXES)
        'ড্রাইভ করা',    # To drive
        'কল করা',       # To call
        'স্ক্রিন দেখা',    # To view the screen
        'রিজিস্টার করা',  # To register
        'অর্ডার করা',      # To order
        'প্রচার করা',      # To promote
        'বিজ্ঞাপন করা',    # To advertise
        'অ্যাপডেট করা',    # To update
        'লঞ্চ করা',        # To launch
        'রক্ষণাবেক্ষণ করা',  # To maintain
        'আনন্দ করা',        # To enjoy
        'দূর করা',          # To remove
        'শুধু করা',        # To do only
        'অধ্যয়ন করা',      # To study
        'আঁকা',              # To draw
        'খোলা',              # To open
        'চলাচল করা',        # To move
        'সাফ করা',          # To clean
        # NOTE(review): the entries below also appear under 'ADV'; listing them
        # here makes tag_word() return 'VERB' for them. Confirm this is intended.
        'কঠিনভাবে',          # To do with difficulty
        'সহজে',              # To do easily
        'সম্পূর্ণভাবে',      # To do completely
        'পরিচ্ছন্নভাবে',    # To do cleanly
        'গভীরভাবে',          # To do deeply
        'তীক্ষ্ণভাবে',        # To do sharply
        'অবিলম্বে'           # To do immediately
    ]),
    'NOUN': set([
        'বই', 'স্কুল', 'কলেজ', 'বিশ্ববিদ্যালয়', 'ছাত্র', 'ছাত্রী', 'শিক্ষক',
        'মানুষ', 'গাছ', 'পাখি', 'বাড়ি', 'পথ', 'মাটি', 'জল', 'আকাশ',
        'চাকরি', 'বাজার', 'দোকান', 'রান্না', 'খেলা', 'সাহিত্য',
        'চিঠি', 'গান', 'ছবি', 'কবিতা', 'চলচ্চিত্র', 'প্রেম', 'বন্ধুত্ব',
        'কম্পিউটার', 'মোবাইল', 'ইন্টারনেট', 'তথ্য', 'বিজ্ঞান', 'গবেষণা',
        'দেশ', 'শহর', 'গ্রাম', 'রাজধানী', 'বাংলাদেশ', 'ভারত', 'পাকিস্তান',
        'পরীক্ষা', 'শিক্ষা', 'কাজ', 'পরিবার', 'সম্পর্ক', 'সময়', 'ব্যবসা',
        'স্বাস্থ্য', 'দ্রব্য', 'পরিবেশ', 'নিয়ম', 'আইন', 'নীতি',
        'বিজনেস', 'প্রযুক্তি', 'অর্থনীতি', 'রাজনীতি', 'মনোবিজ্ঞান', 'ভাষা',
        'সংস্কৃতি', 'ধর্ম', 'ক্রীড়া', 'সামাজিক', 'পরিবহন', 'উদ্যোগ', 'সাহস',
        'সমাজ', 'বিজ্ঞান', 'গবেষণা', 'নিয়ম', 'আইন', 'নীতি', 'অর্থ', 'স্বাস্থ্য',
        'দ্রব্য', 'পরিবেশ', 'সময়', 'ব্যবসা',
        # Domain-Specific Nouns
        'গাড়ি', 'ফোন', 'মুভি', 'রেস্টুরেন্ট', 'মডেল', 'ইঞ্জিন', 'টেকনোলজি',
        'ক্যামেরা', 'স্ক্রিন', 'হেডসেট', 'স্টাইল', 'সিনেমা', 'ক্যাটারিং',
        'বেড়া', 'টেবিল', 'চেয়ার', 'রিং', 'ডিসপ্লে', 'হ্যান্ডসেট',
        'পার্টি', 'বুট', 'ট্রান্সমিশন'
    ]),
    'ADJ': set([
        'ভাল', 'মন্দ', 'বড়', 'ছোট', 'লাল', 'নীল', 'সবুজ', 'সুন্দর', 'অসুন্দর',
        'তাড়াতাড়ি', 'ধীরে', 'শান্ত', 'গুরুত্বপূর্ণ', 'জটিল', 'সহজ', 'বিরাট', 'বিশাল',
        'উজ্জ্বল', 'গাঢ়', 'হালকা', 'নরম', 'কঠিন', 'দৃঢ়', 'নীরব', 'উচ্চ', 'নিম্ন',
        'তীক্ষ্ণ', 'নরমাল', 'উজ্জ্বল', 'মধুর', 'দ্রুত', 'সুস্বাদু', 'কঠোর', 'স্বচ্ছ',
        'অস্বচ্ছ', 'শীতল', 'উষ্ণ', 'অত্যন্ত', 'সামান্য', 'মাঝারি', 'গভীর',
        'হাই-টেক', 'অটো', 'স্মার্ট', 'প্রিমিয়াম', 'অ্যাডভান্সড', 'কাজকরি',
        'লাইটওয়েট', 'স্টাইলিশ', 'আনন্দদায়ক', 'কমফোর্টেবল'
    ]),
    'ADV': set([
        'তাড়াতাড়ি', 'ধীরে', 'ভালভাবে', 'মন্দভাবে', 'একটু', 'অনেক', 'খুব',
        'যথেষ্ট', 'তেমন', 'মোটামুটি', 'ধীরে-ধীরে', 'আস্তে', 'যথাযথভাবে', 'বিশেষভাবে',
        'পরিচ্ছন্নভাবে', 'সম্পূর্ণভাবে', 'অবিলম্বে', 'সহজে', 'কঠিনভাবে', 'গভীরভাবে',
        'হাই-টেকলি', 'স্মার্টলি', 'প্রিমিয়ামলি', 'অ্যাডভান্সডলি'
    ]),
    'CONJ': set([
        'এবং', 'কিন্তু', 'তবুও', 'অথবা', 'কারণ', 'যদিও', 'তাই', 'যেহেতু',
        'তবে', 'তবেই', 'যেমন', 'যেহেতু', 'অতএব', 'নেইলে', 'অথবা', 'যদি'
    ]),
    'INTJ': set([
        'আরে', 'ওহ', 'বাহ', 'আহা', 'হায়', 'ধন্যবাদ', 'কি', 'হ্যাঁ', 'না', 'ঠিক আছে',
        'অচ্ছা', 'মারিয়া', 'ওরে', 'দয়া করে', 'দুঃখিত', 'ওই', 'হুম', 'উফ', 'আসসালামু আলাইকুম'
    ]),
    'PREP': set([
        'এর', 'উপর', 'নিচে', 'পাশে', 'পরে', 'মধ্যে', 'জন্য', 'দ্বারা', 'দিকে', 'সঙ্গে',
        'সহ', 'পর', 'অধীন', 'বাবে', 'নিম্নে', 'উর্ধ্বে', 'পর্যন্ত', 'অন্তर्गत', 'আগে',
        'পিছনে', 'সঙ্গে', 'পাশের'
    ]),
    # Punctuation tokens; tokenize() emits each non-word, non-space character
    # as its own token, so these are all single characters.
    'PUNCT': set([
        '।', ',', '.', '!', '?', ';', ':', '-', '–', '—', '(', ')', '[', ']', '{', '}', '"', "'"
    ])
}

# Suffix-based rules for POS tagging.
#
# NOUN_SUFFIXES holds word endings that tag_word() treats as evidence for a
# noun. tag_word() only consults it after the exact-match lexicon (POS_TAGS)
# and the verb suffixes have failed, and its fallback tag is also 'NOUN'.
# The set mixes short morphological endings with whole nouns.
NOUN_SUFFIXES = set([
    'তা', 'ি', 'া', 'মান', 'পনা', 'কর্ম', 'গৃহ', 'পথ', 'কথা',
    'নাম', 'গতি', 'বোধ', 'বাণী', 'প্রণয়', 'দর্শন', 'সংগঠন',
    'উদ্দেশ্য', 'চিন্তা', 'প্রভাব', 'প্রত্যাশা', 'পরিচয়', 'ব্যবহার',
    'বিজনেস', 'প্রযুক্তি', 'অর্থনীতি', 'রাজনীতি', 'মনোবিজ্ঞান', 'ভাষা',
    'সংস্কৃতি', 'ধর্ম', 'ক্রীড়া', 'সামাজিক', 'পরিবহন', 'উদ্যোগ', 'সাহস',
    'সমাজ', 'বিজ্ঞান', 'গবেষণা', 'নিয়ম', 'আইন', 'নীতি', 'অর্থ', 'স্বাস্থ্য',
    'দ্রব্য', 'পরিবেশ', 'সময়', 'ব্যবসা',
    # Domain-Specific Noun Suffixes
    'গাড়ি', 'ফোন', 'মুভি', 'রেস্টুরেন্ট', 'মডেল', 'ইঞ্জিন', 'টেকনোলজি',
    'ক্যামেরা', 'স্ক্রিন', 'হেডসেট', 'স্টাইল', 'সিনেমা', 'ক্যাটারিং',
    'বেড়া', 'টেবিল', 'চেয়ার', 'রিং', 'ডিসপ্লে', 'হ্যান্ডসেট',
    'পার্টি', 'বুট', 'ট্রান্সমিশন'
])

# Word endings that make tag_word() return 'VERB' when the word is not found
# in POS_TAGS. These are checked before NOUN_SUFFIXES, so an ending present
# in both sets (for example 'া') resolves to 'VERB'.
#
# Entries containing a space cannot match tokens produced by tokenize(),
# because its pattern never yields a token with whitespace in it.
VERB_SUFFIXES = set([
    'তে', 'া', 'চ্ছি', 'চ্ছিল', 'বে', 'ল', 'য়', 'ছে', 'ছি',
    'নো', 'ানো', 'ইয়া', 'আনা', 'যাওয়া', 'করানো', 'নাচানো',
    'দেখানো', 'শেখানো', 'লেখানো', 'নতুন', 'বজানো', 'শোনানো',
    'পড়ানো', 'আসানো', 'লুকানো', 'বাঁচানো', 'ধোঁয়া', 'গড়ানো',
    'বন্ধ করা', 'খেলানো', 'জীবন', 'ভালোবাসা', 'হাসা', 'কান্না',
    'উঠানো', 'সাজানো', 'সাফ করা', 'সমাধান করা', 'উন্নয়ন করা',
    'পরিচালনা করা', 'আলোচনা করা', 'নির্মাণ করা', 'সংরক্ষণ করা',
    'আনন্দ করা', 'দূর করা', 'শুধু করা', 'অধ্যয়ন করা', 'আঁকা',
    'খোলা', 'চলাচল করা',
    # Domain-Specific Verb Suffixes
    'ড্রাইভ করা',    # To drive
    'কল করা',       # To call
    'স্ক্রিন দেখা',    # To view the screen
    'রিজিস্টার করা',  # To register
    'অর্ডার করা',      # To order
    'প্রচার করা',      # To promote
    'বিজ্ঞাপন করা',    # To advertise
    'অ্যাপডেট করা',    # To update
    'লঞ্চ করা',        # To launch
    'রক্ষণাবেক্ষণ করা',  # To maintain
    'আনন্দ করা',        # To enjoy
    'দূর করা',          # To remove
    'শুধু করা',        # To do only
    'অধ্যয়ন করা',      # To study
    'আঁকা',              # To draw
    'খোলা',              # To open
    'চলাচল করা',        # To move
    'সাফ করা',          # To clean
    'কঠিনভাবে',          # To do with difficulty
    'সহজে',              # To do easily
    'সম্পূর্ণভাবে',      # To do completely
    'পরিচ্ছন্নভাবে',    # To do cleanly
    'গভীরভাবে',          # To do deeply
    'তীক্ষ্ণভাবে',        # To do sharply
    'অবিলম্বে'           # To do immediately
])


def tokenize(sentence):
    """
    Tokenizes a Bengali sentence by separating punctuation from words.

    A word is a maximal run of alphanumeric characters, underscores,
    combining marks (dependent vowel signs, virama, nukta, ...) and the
    zero-width joiner/non-joiner. Whitespace separates tokens and is dropped.
    Every other character becomes a one-character token.

    Python's ``\\w`` does not match combining marks, so a plain
    ``\\w+|[^\\w\\s]`` pattern breaks Bengali words apart at every vowel sign;
    the character classification is therefore done with ``unicodedata``.

    Args:
        sentence (str): Input sentence.

    Returns:
        list: List of tokens.
    """
    import unicodedata  # local import keeps this function self-contained

    joiners = ('\u200c', '\u200d')  # ZWNJ / ZWJ occur inside Bengali conjuncts

    tokens = []
    current_word = []
    for ch in sentence:
        is_word_char = (
            ch.isalnum()
            or ch == '_'
            or ch in joiners
            or unicodedata.category(ch).startswith('M')
        )
        if is_word_char:
            current_word.append(ch)
            continue

        # Any non-word character ends the word collected so far.
        if current_word:
            tokens.append(''.join(current_word))
            current_word = []
        if not ch.isspace():
            tokens.append(ch)

    if current_word:
        tokens.append(''.join(current_word))
    return tokens

def tag_word(word):
    """
    Tags a single word based on predefined dictionaries and morphological suffix rules.

    Lookup order:
      1. Exact match in POS_TAGS; the first tag (in dict order) whose word
         set contains the word is returned.
      2. If the word ends with any entry of VERB_SUFFIXES, 'VERB'.
      3. Otherwise 'NOUN'.

    NOUN_SUFFIXES is not consulted: a noun-suffix match and the fallback both
    yield 'NOUN', so checking it cannot change the result.

    Args:
        word (str): The word to tag.

    Returns:
        str: The POS tag.
    """
    for pos, words in POS_TAGS.items():
        if word in words:
            return pos

    # str.endswith accepts a tuple and tests all suffixes in one call; the
    # order of suffixes is irrelevant because every match gives the same tag.
    if word.endswith(tuple(VERB_SUFFIXES)):
        return 'VERB'

    return 'NOUN'

def define_relations_from_pos(pos_tags):
    """
    Define syntactic relations based on POS tags.

    Heuristics, applied token by token from left to right:
      * VERB: the nearest preceding PRON/NOUN yields
        (subject_index, verb_index, 'subj'); the nearest following NOUN
        yields (verb_index, object_index, 'obj').
      * ADJ immediately followed by a NOUN yields
        (noun_index, adjective_index, 'amod').

    NOTE(review): 'subj' tuples put the subject first and the verb second,
    whereas 'obj' and 'amod' put the governing word first. Confirm whether
    this asymmetry is intended before relying on edge direction.

    Args:
        pos_tags (list of tuples): List containing (word, POS_tag) pairs.

    Returns:
        list of tuples: Each tuple contains (first_index, second_index, relation).
    """
    tags = [tag for _, tag in pos_tags]
    n = len(tags)
    relations = []

    for i, tag in enumerate(tags):
        if tag == 'VERB':
            # Closest nominal to the left acts as the subject.
            subject = next(
                (j for j in range(i - 1, -1, -1) if tags[j] in {'PRON', 'NOUN'}),
                None,
            )
            if subject is not None:
                relations.append((subject, i, 'subj'))

            # Closest noun to the right acts as the object.
            obj = next((j for j in range(i + 1, n) if tags[j] == 'NOUN'), None)
            if obj is not None:
                relations.append((i, obj, 'obj'))
        elif tag == 'ADJ' and i + 1 < n and tags[i + 1] == 'NOUN':
            relations.append((i + 1, i, 'amod'))

    return relations

def pos_tagging_from_scratch(sentence):
    """
    Tags every token of a sentence using the dictionary/suffix rules and
    derives heuristic syntactic relations from the resulting tags.

    An empty token list (empty or whitespace-only sentence) prints a warning
    and returns two empty lists.

    Args:
        sentence (str): Input sentence in Bengali.

    Returns:
        list of tuples: List containing (word, POS_tag) pairs.
        list of tuples: Relation triples as produced by define_relations_from_pos.
    """
    words = tokenize(sentence)
    if not words:
        print("Warning: Encountered an empty sentence.")
        return [], []

    # One tag per token, so pos_tags always has exactly len(words) entries.
    pos_tags = [(word, tag_word(word)) for word in words]
    relations = define_relations_from_pos(pos_tags)
    return pos_tags, relations

def construct_position_graph(length):
    """
    Build a dense position-proximity matrix.

    Entry (i, j) equals 1 / (|i - j| + 1): 1 on the diagonal and decaying
    with the distance between the two positions.

    Args:
        length (int): Number of positions.

    Returns:
        np.ndarray: Float array of shape (length, length).
    """
    positions = np.arange(length)
    # Broadcasting a column against a row gives all pairwise distances at once.
    distances = np.abs(positions[:, None] - positions[None, :])
    return 1.0 / (distances + 1)

def construct_semantic_similarity_graph(embeddings):
    """
    Compute pairwise cosine similarity between the rows of ``embeddings``.

    The denominator carries a 1e-8 term so that all-zero rows produce a
    similarity of 0 instead of a division by zero.

    Args:
        embeddings (array-like): 2-D array with one vector per row. Integer
            input is accepted; the result is floating point.

    Returns:
        np.ndarray: Square matrix with one row/column per input row.
    """
    embeddings = np.asarray(embeddings)
    dot_products = np.matmul(embeddings, embeddings.T)
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    # Out-of-place division: an in-place "/=" fails when dot_products is integer.
    return dot_products / (np.matmul(norms, norms.T) + 1e-8)

def construct_syntax_graph(pos_tags, relation_weights=None):
    """
    Turn the heuristic relations of a tagged sentence into an adjacency matrix.

    For every relation triple returned by define_relations_from_pos, the cell
    addressed by its first two indices is set to the weight of its relation
    type. All other cells stay 0, and the matrix is not symmetrised.

    Args:
        pos_tags (list of tuples): List containing (word, POS_tag) pairs.
        relation_weights (dict, optional): Weight per relation name. Defaults
            to 1.0 for 'subj', 'obj' and 'amod'; names missing from the dict
            also get 1.0.

    Returns:
        np.ndarray: Matrix of shape (len(pos_tags), len(pos_tags)).
    """
    weights = relation_weights
    if weights is None:
        weights = {'subj': 1.0, 'obj': 1.0, 'amod': 1.0}

    size = len(pos_tags)
    graph = np.zeros((size, size))
    for row, col, relation in define_relations_from_pos(pos_tags):
        graph[row][col] = weights.get(relation, 1.0)

    return graph

class BidirectionalCrossAttention(nn.Module):
    """
    Single-head cross-attention from one embedding sequence onto another.

    Queries are projected from the first input, keys and values from the
    second. Scores are raw dot products (no scaling factor) normalised with a
    softmax over the last dimension. Only this one direction is computed.
    """

    def __init__(self, hidden_dim):
        """
        Args:
            hidden_dim (int): Feature size of both inputs and of the output.
        """
        super().__init__()
        self.query_proj = nn.Linear(hidden_dim, hidden_dim)
        self.key_proj = nn.Linear(hidden_dim, hidden_dim)
        self.value_proj = nn.Linear(hidden_dim, hidden_dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, bert_embeddings, gcn_embeddings):
        """
        Args:
            bert_embeddings (Tensor): Source of the queries.
            gcn_embeddings (Tensor): Source of the keys and values.

        Returns:
            Tensor: Values aggregated per query position.
        """
        queries = self.query_proj(bert_embeddings)
        keys_t = self.key_proj(gcn_embeddings).transpose(-1, -2)
        values = self.value_proj(gcn_embeddings)

        weights = self.softmax(torch.matmul(queries, keys_t))
        return torch.matmul(weights, values)

class HighwayNetworkGate(nn.Module):
    """
    Highway-style gate: blends a transformed input with the input itself.

    ``H`` is a two-layer MLP producing the candidate features. ``T`` is a
    two-layer MLP ending in a sigmoid that produces a per-feature gate in
    (0, 1). The output is ``T(x) * H(x) + (1 - T(x)) * x``.
    """

    def __init__(self, input_dim, gate_dim):
        """
        Args:
            input_dim (int): Feature size of the input and of the output.
            gate_dim (int): Hidden width of both MLPs.
        """
        super().__init__()
        transform_layers = [
            nn.Linear(input_dim, gate_dim),
            nn.ReLU(),
            nn.Linear(gate_dim, input_dim),
        ]
        self.H = nn.Sequential(*transform_layers)

        gate_layers = [
            nn.Linear(input_dim, gate_dim),
            nn.ReLU(),
            nn.Linear(gate_dim, input_dim),
            nn.Sigmoid(),
        ]
        self.T = nn.Sequential(*gate_layers)

    def forward(self, x):
        """Return the gated mix of ``H(x)`` and ``x``."""
        candidate = self.H(x)
        gate = self.T(x)
        carry = 1 - gate
        return gate * candidate + carry * x

def compute_dynamic_adjacency(pos_tags, max_length):
    """
    Build a fixed-size adjacency matrix from a tagged sentence.

    The syntax graph and the position-proximity graph of the first
    ``max_length`` tokens are summed, divided by their maximum value, and
    placed in the top-left corner of a zero matrix.

    An empty ``pos_tags`` (as returned for an empty sentence) yields an
    all-zero matrix instead of failing on the max of an empty array.

    NOTE(review): indices here are positions in ``pos_tags`` (word tokens).
    Confirm that callers who pair this matrix with subword-level encoder
    outputs account for the different indexing.

    Args:
        pos_tags (list of tuples): List containing (word, POS_tag) pairs.
        max_length (int): Side length of the returned matrix.

    Returns:
        torch.Tensor: Float tensor of shape (max_length, max_length).
    """
    n = min(len(pos_tags), max_length)
    padded_matrix = np.zeros((max_length, max_length))

    if n <= 0:
        # Nothing to connect; .max() below would raise on a 0x0 array.
        return torch.tensor(padded_matrix, dtype=torch.float)

    pos_tags = pos_tags[:n]
    syntax_matrix = construct_syntax_graph(pos_tags)
    position_matrix = construct_position_graph(n)
    combined_matrix = syntax_matrix + position_matrix

    max_val = combined_matrix.max()
    if max_val != 0:
        combined_matrix = combined_matrix / max_val

    padded_matrix[:n, :n] = combined_matrix
    return torch.tensor(padded_matrix, dtype=torch.float)

def format_input(comment, aspect):
    """
    Combine a comment and an aspect into the single text fed to the tokenizer.

    Args:
        comment: Review text.
        aspect: Aspect category the sentiment refers to.

    Returns:
        str: "Comment: <comment> [SEP] Aspect: <aspect>".
    """
    model_input = f"Comment: {comment} [SEP] Aspect: {aspect}"
    return model_input


class ABSA_Dataset(Dataset):
    """
    Torch dataset yielding one encoded (comment, aspect) pair per item.

    Each item is a dict with:
      * 'input_ids', 'attention_mask': tokenizer output padded/truncated to
        ``max_length``.
      * 'dynamic_adjacency': (max_length, max_length) matrix built from the
        POS-tagged comment.
      * 'label': class id from ``label2id``.
      * 'aspect_position': index of the first occurrence of the aspect's
        token ids inside 'input_ids', or 0 when not found.
    """

    def __init__(self, df, tokenizer, max_length=128):
        """
        Args:
            df: Table with 'Comment', 'Aspect' and 'Polarity' columns,
                indexed positionally via ``.iloc``.
            tokenizer: Tokenizer providing ``encode_plus``, ``tokenize`` and
                ``convert_tokens_to_ids``.
            max_length (int): Padded sequence length.
        """
        self.df = df
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.label2id = {'negative': 0, 'neutral': 1, 'positive': 2}

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _as_text(value):
        """Return ``value`` as a string; None and NaN become an empty string."""
        if isinstance(value, str):
            return value
        if value is None or (isinstance(value, float) and value != value):
            return ''
        return str(value)

    @staticmethod
    def _find_sublist(haystack, needle):
        """Return the first index at which ``needle`` occurs in ``haystack``, else 0."""
        width = len(needle)
        if width == 0:
            return 0
        for start in range(len(haystack) - width + 1):
            if haystack[start:start + width] == needle:
                return start
        return 0

    def __getitem__(self, idx):
        # Fetch the row once instead of one positional lookup per column.
        row = self.df.iloc[idx]
        # Spreadsheet cells may be NaN or numeric; downstream code needs str.
        comment = self._as_text(row['Comment'])
        aspect = self._as_text(row['Aspect'])
        polarity = self._as_text(row['Polarity'])

        formatted_input = format_input(comment, aspect)

        encoding = self.tokenizer.encode_plus(
            text=formatted_input,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )

        # Drop only the leading batch dimension; a bare squeeze() would also
        # remove the sequence dimension when max_length is 1.
        input_ids = encoding['input_ids'].squeeze(0)
        attention_mask = encoding['attention_mask'].squeeze(0)

        aspect_tokens = self.tokenizer.tokenize(aspect)
        aspect_token_ids = self.tokenizer.convert_tokens_to_ids(aspect_tokens)
        aspect_position = self._find_sublist(input_ids.tolist(), aspect_token_ids)

        pos_tags, _ = pos_tagging_from_scratch(comment)
        dynamic_adjacency = compute_dynamic_adjacency(pos_tags, self.max_length)

        # Unknown polarity strings fall back to the 'neutral' class id.
        label = self.label2id.get(polarity.strip().lower(), self.label2id['neutral'])

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'dynamic_adjacency': dynamic_adjacency,
            'label': torch.tensor(label, dtype=torch.long),
            'aspect_position': torch.tensor(aspect_position, dtype=torch.long)
        }

class Hybrid_Model(nn.Module):
    """
    BERT encoder combined with three graph-attention branches for ABSA.

    Pipeline implemented by ``forward``:
      1. BERT encodes the input; the hidden state at each sample's aspect
         position is kept as the aspect embedding.
      2. A cosine-similarity ("semantic") graph is built per sample from the
         detached BERT states; the caller supplies the "syntax" graph.
      3. Two GATConv layers run over the BERT states, one per graph, and BERT
         states cross-attend to each GAT output.
      4. The two attention outputs are concatenated and passed through a
         2-layer TransformerEncoder (attribute ``mambaformer``).
      5. A third GATConv (``gcn_aspect``) runs over the caller-supplied graph.
      6. Mean-pooled encoder output and mean-pooled aspect-GAT output are
         concatenated, gated by a HighwayNetworkGate (``kangate``),
         concatenated with the aspect embedding and classified linearly.
    """

    def __init__(self, bert_model_name='bert-base-multilingual-cased', hidden_dim=768, num_classes=3, max_length=128):
        """
        Args:
            bert_model_name (str): Name or path passed to ``BertModel.from_pretrained``.
            hidden_dim (int): Feature size used for every GAT/attention layer;
                presumably must equal the BERT hidden size — verify when
                changing the encoder.
            num_classes (int): Number of output logits.
            max_length (int): Side length of the padded semantic adjacency matrix.
        """
        super(Hybrid_Model, self).__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        
        # NOTE(review): the GATConv layers are created without edge_dim, yet
        # forward() passes edge_attr to them — confirm whether the edge
        # weights have any effect in the installed torch_geometric version.
        self.gcn_syntax = GATConv(in_channels=hidden_dim, out_channels=hidden_dim, heads=1, concat=False)
        self.gcn_semantic = GATConv(in_channels=hidden_dim, out_channels=hidden_dim, heads=1, concat=False)
        
        self.cross_attention_syntax = BidirectionalCrossAttention(hidden_dim)
        self.cross_attention_semantic = BidirectionalCrossAttention(hidden_dim)
        
        # Operates on the concatenation of both attention outputs, hence hidden_dim*2.
        encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim*2, nhead=8, dim_feedforward=2048, dropout=0.1)
        self.mambaformer = nn.TransformerEncoder(encoder_layer, num_layers=2)
        
        self.gcn_aspect = GATConv(in_channels=hidden_dim, out_channels=hidden_dim, heads=1, concat=False)
        
        # Gate input: pooled encoder output (hidden_dim*2) + pooled aspect GAT output (hidden_dim).
        self.kangate = HighwayNetworkGate(input_dim=hidden_dim*3, gate_dim=hidden_dim)  
        # Classifier input: gated features (hidden_dim*3) + aspect embedding (hidden_dim).
        self.classification_head = nn.Linear(hidden_dim * 4, num_classes)  

        self.max_length = max_length  

    def adjacency_to_edge_index(self, adjacency_matrix):
        """
        Convert an adjacency matrix to edge_index and edge_attr for PyTorch Geometric.

        Thin wrapper around ``dense_to_sparse``; returns its two outputs unchanged.
        """
        edge_index, edge_attr = dense_to_sparse(adjacency_matrix)
        return edge_index, edge_attr

    def forward(self, input_ids, attention_mask, adjacency_matrices, aspect_positions):
        """
        Compute class logits for a batch.

        Args:
            input_ids: Token ids passed to BERT.
            attention_mask: Attention mask passed to BERT (not used elsewhere
                in this method).
            adjacency_matrices: One adjacency matrix per sample, indexed along
                the first dimension; used for both the syntax and the aspect
                GAT branches.
            aspect_positions: One token index per sample selecting the aspect
                embedding.

        Returns:
            Tensor: Output of the classification head, one row per sample.
        """
        bert_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        H = bert_outputs.last_hidden_state 

        batch_size, seq_len, hidden_dim = H.size()
        # Pick the hidden state at each sample's aspect position.
        aspect_embeddings = []
        for i in range(batch_size):
            pos = aspect_positions[i]
            if pos >= seq_len:
                pos = 0  # out-of-range position falls back to the first token
            aspect_emb = H[i, pos, :]  
            aspect_embeddings.append(aspect_emb)
        aspect_embeddings = torch.stack(aspect_embeddings, dim=0)

        # Semantic graph: cosine similarity between token states, computed in
        # NumPy on detached CPU copies, so no gradient flows through it.
        semantic_adjacency_matrices = []
        for i in range(batch_size):
            H_i = H[i].detach().cpu().numpy()  
            semantic_matrix = construct_semantic_similarity_graph(H_i)
            if semantic_matrix.max() > 0:
                semantic_matrix = semantic_matrix / semantic_matrix.max()
            # Pad (or crop) to max_length x max_length so all samples stack.
            padded_semantic = np.zeros((self.max_length, self.max_length))
            current_len = min(H_i.shape[0], self.max_length)
            padded_semantic[:current_len, :current_len] = semantic_matrix[:self.max_length, :self.max_length]
            semantic_adjacency_matrices.append(torch.tensor(padded_semantic, dtype=torch.float))
        semantic_adjacency_matrices = torch.stack(semantic_adjacency_matrices, dim=0).to(input_ids.device) 

        syntax_adjacency = adjacency_matrices  
        semantic_adjacency = semantic_adjacency_matrices  

        # Syntax branch: one GAT pass per sample; a graph without edges
        # passes the BERT states through unchanged.
        H_gcn_syntax_list = []
        for i in range(batch_size):
            A_syntax = syntax_adjacency[i]
            edge_index_syntax, edge_attr_syntax = self.adjacency_to_edge_index(A_syntax)
            if edge_index_syntax.numel() == 0:
                H_gcn_syntax = H[i]
            else:
                H_gcn_syntax = self.gcn_syntax(H[i], edge_index_syntax, edge_attr_syntax)
            H_gcn_syntax_list.append(H_gcn_syntax)

        H_gcn_syntax = torch.stack(H_gcn_syntax_list, dim=0)  

        # Semantic branch: same procedure with the similarity graph.
        H_gcn_semantic_list = []
        for i in range(batch_size):
            A_semantic = semantic_adjacency[i]
            edge_index_semantic, edge_attr_semantic = self.adjacency_to_edge_index(A_semantic)
            if edge_index_semantic.numel() == 0:
                H_gcn_semantic = H[i]
            else:
                H_gcn_semantic = self.gcn_semantic(H[i], edge_index_semantic, edge_attr_semantic)
            H_gcn_semantic_list.append(H_gcn_semantic)

        H_gcn_semantic = torch.stack(H_gcn_semantic_list, dim=0) 
        # BERT states attend over each GAT output.
        attention_output_syntax = self.cross_attention_syntax(H, H_gcn_syntax)  

        attention_output_semantic = self.cross_attention_semantic(H, H_gcn_semantic)  
        concatenated_attentions = torch.cat((attention_output_syntax, attention_output_semantic), dim=-1) 

        # Swap the first two dims around the encoder call and back again;
        # presumably because the encoder layer is built without batch_first — confirm.
        concatenated_attentions = concatenated_attentions.permute(1, 0, 2)
        mambaformer_output = self.mambaformer(concatenated_attentions)  
        mambaformer_output = mambaformer_output.permute(1, 0, 2)

        # Aspect branch: a separate GAT over the same caller-supplied graph.
        H_gcn_aspect_list = []
        for i in range(batch_size):
            A_aspect = adjacency_matrices[i]
            edge_index_aspect, edge_attr_aspect = self.adjacency_to_edge_index(A_aspect)
            if edge_index_aspect.numel() == 0:
                H_gcn_aspect = H[i]
            else:
                H_gcn_aspect = self.gcn_aspect(H[i], edge_index_aspect, edge_attr_aspect)
            H_gcn_aspect_list.append(H_gcn_aspect)

        H_gcn_aspect = torch.stack(H_gcn_aspect_list, dim=0) 

        # Plain mean over the sequence dimension; attention_mask is not applied here.
        mambaformer_pooled = mambaformer_output.mean(dim=1) 
        gcn_aspect_pooled = H_gcn_aspect.mean(dim=1)       

        combined_features = torch.cat((mambaformer_pooled, gcn_aspect_pooled), dim=1) 
        gated_features = self.kangate(combined_features)
        final_representation = torch.cat((gated_features, aspect_embeddings), dim=1)  

        logits = self.classification_head(final_representation) 

        return logits

    def set_max_length(self, max_length):
        """Update the side length used when padding the semantic adjacency matrix."""
        self.max_length = max_length

def train_epoch(model, dataloader, optimizer, scheduler, device):
    """
    Run one training pass over ``dataloader``.

    For every batch: move tensors to ``device``, compute cross-entropy on the
    model's logits, back-propagate, then step the optimizer and the scheduler.

    Args:
        model: Callable as ``model(input_ids, attention_mask, adjacency, aspect_positions)``.
        dataloader: Iterable of dicts with keys 'input_ids', 'attention_mask',
            'dynamic_adjacency', 'label' and 'aspect_position'.
        optimizer: Optimizer stepped once per batch.
        scheduler: Learning-rate scheduler stepped once per batch.
        device: Target device for the batch tensors.

    Returns:
        tuple: (mean of the per-batch losses, accuracy over all examples).
    """
    model.train()
    running_loss = 0
    n_correct = 0
    n_seen = 0
    tensor_keys = ('input_ids', 'attention_mask', 'dynamic_adjacency', 'label', 'aspect_position')

    progress = tqdm(dataloader, desc="Training", leave=False)
    for batch in progress:
        ids, mask, adjacency, targets, aspect_pos = (
            batch[key].to(device) for key in tensor_keys
        )

        optimizer.zero_grad()
        logits = model(ids, mask, adjacency, aspect_pos)

        loss = F.cross_entropy(logits, targets)
        loss.backward()
        optimizer.step()
        scheduler.step()

        loss_value = loss.item()
        running_loss += loss_value

        hits = (torch.argmax(logits, dim=1) == targets).sum().item()
        batch_size = targets.size(0)
        n_correct += hits
        n_seen += batch_size

        # Live per-batch figures on the progress bar.
        progress.set_postfix({'loss': f"{loss_value:.4f}", 'acc': f"{hits / batch_size:.4f}"})

    mean_loss = running_loss / len(dataloader)
    return mean_loss, n_correct / n_seen

def eval_model(model, dataloader, device):
    """
    Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Args:
        model: Callable as ``model(input_ids, attention_mask, adjacency, aspect_positions)``.
        dataloader: Iterable of dicts with keys 'input_ids', 'attention_mask',
            'dynamic_adjacency', 'label' and 'aspect_position'.
        device: Target device for the batch tensors.

    Returns:
        dict: 'loss' (mean of per-batch cross-entropy), 'accuracy', and the
        weighted-average 'precision', 'recall' and 'f1'.
    """
    model.eval()
    all_preds = []
    all_targets = []
    running_loss = 0
    n_correct = 0
    n_seen = 0
    tensor_keys = ('input_ids', 'attention_mask', 'dynamic_adjacency', 'label', 'aspect_position')

    progress = tqdm(dataloader, desc="Evaluating", leave=False)
    with torch.no_grad():
        for batch in progress:
            ids, mask, adjacency, targets, aspect_pos = (
                batch[key].to(device) for key in tensor_keys
            )

            logits = model(ids, mask, adjacency, aspect_pos)
            loss_value = F.cross_entropy(logits, targets).item()
            running_loss += loss_value

            batch_preds = torch.argmax(logits, dim=1)
            hits = (batch_preds == targets).sum().item()
            batch_size = targets.size(0)
            n_correct += hits
            n_seen += batch_size

            # Collected on CPU for the scikit-learn metrics below.
            all_preds.extend(batch_preds.cpu().numpy())
            all_targets.extend(targets.cpu().numpy())

            progress.set_postfix({'loss': f"{loss_value:.4f}", 'acc': f"{hits / batch_size:.4f}"})

    mean_loss = running_loss / len(dataloader)
    overall_accuracy = n_correct / n_seen
    precision, recall, f1, _ = precision_recall_fscore_support(all_targets, all_preds, average='weighted')
    return {
        'loss': mean_loss,
        'accuracy': overall_accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
    }


def main(data_path=None):
    """
    Train, validate and test the hybrid ABSA model end to end.

    Steps: load and clean the spreadsheet, shuffle and split it 80/10/10,
    build datasets and loaders, train with early stopping on validation F1
    (saving the best weights), then reload the best weights and report test
    metrics.

    Args:
        data_path (str, optional): Excel file to load. Defaults to the
            original hard-coded dataset location.
    """
    DATA_PATH = data_path if data_path is not None else r"E:\Bengali Aspect\BANGLA_ABSA dataset\Restauant\Restaurant_ABSA.xlsx"
    CHECKPOINT_PATH = "best_model_double_graph_enhanced.pth"
    BATCH_SIZE = 16
    EPOCHS = 20
    LEARNING_RATE = 2e-5
    MAX_LEN = 128
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print("Loading and preprocessing data...")
    df = load_and_preprocess_data(DATA_PATH)

    # Positional slicing of the shuffled frame; equivalent to the previous
    # np.split call, which is deprecated for DataFrames.
    shuffled = df.sample(frac=1, random_state=42)
    train_end = int(.8 * len(df))
    val_end = int(.9 * len(df))
    train_df = shuffled.iloc[:train_end]
    val_df = shuffled.iloc[train_end:val_end]
    test_df = shuffled.iloc[val_end:]
    
    print(f"Number of training samples: {len(train_df)}")
    print("Sample training data:")
    print(train_df.head(), "\n")
    
    print(f"Number of validation samples: {len(val_df)}")
    print("Sample validation data:")
    print(val_df.head(), "\n")
    
    print(f"Number of test samples: {len(test_df)}")
    print("Sample test data:")
    print(test_df.head(), "\n")

    tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')

    print("Creating datasets...")
    train_dataset = ABSA_Dataset(train_df, tokenizer, max_length=MAX_LEN)
    val_dataset = ABSA_Dataset(val_df, tokenizer, max_length=MAX_LEN)
    test_dataset = ABSA_Dataset(test_df, tokenizer, max_length=MAX_LEN)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    print("Initializing model...")
    model = Hybrid_Model(bert_model_name='bert-base-multilingual-cased', hidden_dim=768, num_classes=3, max_length=MAX_LEN)
    model.to(DEVICE)

    optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, eps=1e-8)
    total_steps = len(train_loader) * EPOCHS
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)

    # Start below any reachable F1 so the first epoch always writes a
    # checkpoint; with 0.0 an F1 of exactly 0 would never be saved and the
    # reload after training would fail or pick up a stale file.
    best_val_f1 = float('-inf')
    patience = 3
    patience_counter = 0

    for epoch in range(EPOCHS):
        print(f"\nEpoch {epoch + 1}/{EPOCHS}")
        train_loss, train_acc = train_epoch(model, train_loader, optimizer, scheduler, DEVICE)
        val_metrics = eval_model(model, val_loader, DEVICE)
        current_lr = optimizer.param_groups[0]['lr']

        print(f"Learning Rate: {current_lr:.6f}")
        print(f"Training Loss: {train_loss:.4f}, Training Accuracy: {train_acc:.4f}")
        print(f"Validation Loss: {val_metrics['loss']:.4f}, Validation Accuracy: {val_metrics['accuracy']:.4f}")
        print(f"Validation Metrics: Precision: {val_metrics['precision']:.4f}, Recall: {val_metrics['recall']:.4f}, F1: {val_metrics['f1']:.4f}")

        if val_metrics['f1'] > best_val_f1:
            best_val_f1 = val_metrics['f1']
            torch.save(model.state_dict(), CHECKPOINT_PATH)
            print("Saved the best double graph enhanced model!")
            patience_counter = 0
        else:
            patience_counter += 1
            print(f"No improvement in F1. Patience counter: {patience_counter}/{patience}")
            if patience_counter >= patience:
                print("Early stopping triggered!")
                break

    print("\nEvaluating on Test Data...")
    # map_location keeps the reload working when the device differs from the
    # one the checkpoint was saved on.
    model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=DEVICE))
    test_metrics = eval_model(model, test_loader, DEVICE)

    print(f"Test Loss: {test_metrics['loss']:.4f}, Test Accuracy: {test_metrics['accuracy']:.4f}")
    print(f"Test Metrics: Precision: {test_metrics['precision']:.4f}, Recall: {test_metrics['recall']:.4f}, F1: {test_metrics['f1']:.4f}")


if __name__ == "__main__":
    main()