In [None]:
!git clone https://github.com/declare-lab/RECCON.git

In [None]:
"""
Loads a dialogue dataset from a JSON file, 
retains only the 'turn', 'speaker', and 'utterance' fields in each utterance, 
and writes the cleaned dialogue data to a new JSON file.

Steps:
1. Load the original dialogue JSON file.
2. Traverse through the nested structure of conversation data.
3. Remove any unwanted fields from each utterance.
4. Save the cleaned data to 'clean_conversation.json'.
"""

import json

# Load the DailyDialog validation annotations shipped with the cloned RECCON repo.
# The parsed object is kept in the module-level name `data`; clean_json() below
# walks it as {key: [list of utterance dicts, ...]} and edits it in place.
with open("/kaggle/working/RECCON/data/original_annotation/dailydialog_valid.json", "r", encoding="utf-8") as f:
    data = json.load(f)

# Function to clean the data
def clean_json(data):
    """Strip every utterance dict down to 'turn', 'speaker' and 'utterance'.

    The structure is walked as {key: [conversation, ...]} where each
    conversation is a list of utterance dicts. The dicts are modified in
    place and nothing is returned.
    """
    wanted = {"turn", "speaker", "utterance"}
    for conversations in data.values():
        for conversation in conversations:
            for utterance in conversation:
                # Snapshot the unwanted names first: a dict must not change
                # size while it is being iterated.
                unwanted = [name for name in utterance if name not in wanted]
                for name in unwanted:
                    utterance.pop(name)

# Clean the JSON structure (clean_json edits `data` in place and returns None)
clean_json(data)

# Save cleaned JSON to a new file in the current working directory
with open("clean_conversation.json", "w", encoding="utf-8") as f:
    json.dump(data, f, indent=4)

# Fixed typo in the status message ("t0" -> "to")
print("Processing complete. Cleaned JSON saved to clean_conversation.json")


In [None]:
"""
This script processes dialogue data in JSON format to extract meaningful clauses 
from utterances using spaCy's dependency parser. It then applies post-processing 
to merge certain clause patterns for better quality. The workflow includes:
1. Clause extraction from each utterance.
2. Removal of redundant or nested clauses.
3. Merging of 'so'-starting and single-word clauses.
4. Saving the processed output into a new JSON file.
"""


#Task 1
import re
import json
import spacy

# Load the spaCy English model
nlp = spacy.load("en_core_web_sm")

def clause_span(token, doc):
    """Return the text of the contiguous doc slice covering token's subtree.

    The slice runs from the lowest to the highest token index found in
    ``token.subtree`` (inclusive), so any tokens lying between them are
    part of the returned text as well.
    """
    positions = [node.i for node in token.subtree]
    positions.sort()
    first, last = positions[0], positions[-1]
    return doc[first:last + 1].text

def extract_clauses(text):
    """
    Extract the outermost clauses of *text* as whitespace-normalised strings.

    Every token whose dependency label is ROOT, advcl, ccomp, acl or xcomp is
    treated as a clause head; its clause is the contiguous token range that
    covers the head's subtree. A clause is dropped when its token range lies
    inside the range of another clause.

    Nesting is decided on token positions, not on the clause text: comparing
    strings would also discard a clause whose text merely happens to occur
    inside an unrelated clause (e.g. "OK." inside "Is it OK.").

    Returns a list of clause strings in the order their heads appear in the
    parsed document. Clauses whose text is empty after normalisation are
    skipped.
    """
    doc = nlp(text)
    clause_deps = {"ROOT", "advcl", "ccomp", "acl", "xcomp"}

    # Collect the (start, end) token range of every clause head, once each.
    ranges = []
    seen = set()
    for token in doc:
        if token.dep_ not in clause_deps:
            continue
        indices = [t.i for t in token.subtree]
        bounds = (min(indices), max(indices) + 1)
        if bounds not in seen:
            seen.add(bounds)
            ranges.append(bounds)

    clauses = []
    for start, end in ranges:
        # Skip ranges that sit inside a different (necessarily larger) range.
        nested = any(
            (o_start, o_end) != (start, end) and o_start <= start and end <= o_end
            for o_start, o_end in ranges
        )
        if nested:
            continue
        # remove extra space before punctuation
        span = re.sub(r'\s+([,.;?!])', r'\1', doc[start:end].text).strip()
        if span:
            clauses.append(span)

    return clauses

def process_conversations(input_file, output_file):
    """
    Add a 'clauses' list to every turn of a conversation JSON file.

    The input is read as {conv_id: [dialogue, ...]} where each dialogue is a
    list of turn dicts. Each turn's 'utterance' (empty string when missing) is
    passed to extract_clauses and the result stored under 'clauses'. The
    augmented structure is written to output_file as UTF-8 JSON.
    """
    with open(input_file, "r", encoding="utf-8") as src:
        conversations = json.load(src)

    for dialogues in conversations.values():
        for dialogue in dialogues:
            for turn in dialogue:
                turn["clauses"] = extract_clauses(turn.get("utterance", ""))

    with open(output_file, "w", encoding="utf-8") as dst:
        json.dump(conversations, dst, indent=4, ensure_ascii=False)

    print(f"Processed data saved to {output_file}")

if __name__ == "__main__":
    # Source annotations and destination of the raw clause extraction;
    # adjust these paths as needed.
    input_file, output_file = (
        "/kaggle/working/RECCON/data/original_annotation/dailydialog_train.json",
        "extracted_clauses_corrected_train.json",
    )
    process_conversations(input_file=input_file, output_file=output_file)


import json

def remerge_so_and_single_clauses(clauses):
    """
    Merge 'so' clauses and single-word clauses into their neighbours.

    Two passes are applied to the list of clause strings:
    1. A clause starting with 'so ' (case-insensitive, leading whitespace
       ignored) is appended to the previous clause. The previous clause loses
       its trailing spaces and '.', '?', '!' characters; the 'so' clause gets
       its first character upper-cased and is otherwise left untouched.
    2. A clause consisting of one word (ignoring surrounding '.', '?', '!')
       is prefixed to the following clause, whose first character is
       lower-cased. A single-word clause in last position is kept as is.

    Returns a new list; the input list is not modified.
    """
    # Step 1: merge 'so ' clauses into the previous one
    merged = []
    for clause in clauses:
        stripped = clause.lstrip()
        if merged and stripped.lower().startswith("so "):
            prev = merged[-1].rstrip(" .?!")
            # Upper-case only the first character. str.capitalize() would also
            # lower-case the rest of the clause ("so I saw Bob" -> "So i saw bob").
            merged[-1] = prev + " " + stripped[0].upper() + stripped[1:]
        else:
            merged.append(clause)

    # Step 2: merge single-word clauses into the next one
    final = []
    i = 0
    while i < len(merged):
        clause = merged[i].strip()
        # count words ignoring surrounding sentence punctuation
        word_only = clause.strip(".?!")
        if i + 1 < len(merged) and len(word_only.split()) == 1:
            # merge this single word with the next clause
            next_clause = merged[i + 1].lstrip()
            # lowercase first character of next clause for smooth merging
            if next_clause:
                merged_next = word_only + " " + next_clause[0].lower() + next_clause[1:]
            else:
                merged_next = word_only
            final.append(merged_next)
            i += 2  # skip the next clause since we've merged it
        else:
            final.append(merged[i])
            i += 1

    return final


def postprocess_file(input_file, output_file):
    """
    Re-merge the clause lists stored in a conversation JSON file.

    The input is read as {conv_id: [dialogue, ...]} with each dialogue a list
    of turn dicts. Every turn whose 'clauses' entry is a list has that list
    replaced by remerge_so_and_single_clauses(list); other turns are left
    alone. The result is written to output_file as UTF-8 JSON.
    """
    with open(input_file, 'r', encoding='utf-8') as src:
        data = json.load(src)

    for dialogues in data.values():
        for dialogue in dialogues:
            for turn in dialogue:
                current = turn.get('clauses') if 'clauses' in turn else None
                if isinstance(current, list):
                    turn['clauses'] = remerge_so_and_single_clauses(current)

    with open(output_file, 'w', encoding='utf-8') as dst:
        json.dump(data, dst, indent=4, ensure_ascii=False)

    print(f"Post-processed data saved to {output_file}")

if __name__ == "__main__":
    # Reads the raw extraction written by the step above and writes the
    # re-merged clauses to a second file.
    postprocess_file(
        input_file="extracted_clauses_corrected_train.json",
        output_file="extracted_clauses_train.json",
    )


In [None]:
"""
This program extracts clause-level emotion and cause labels from annotated dialogues.

Steps:
1. Load the input JSON file containing dialogue turns with emotion and cause annotations.
2. Extract all emotion cause spans for each dialogue.
3. For each clause in a turn, check if it expresses emotion, cause, both, or neither.
4. Assign labels accordingly and write the result to a CSV file with columns:
   clause_text, emotion_label, clause_label.
"""


import json
import csv

def build_clause_dataset(input_json, output_csv):
    """
    Write one CSV row per clause with its emotion label and clause label.

    The input JSON is read as {conv_id: [dialogue, ...]}, each dialogue being
    a list of turn dicts that may carry 'emotion', 'clauses' and
    'expanded emotion cause span'.

    For every dialogue all cause spans of all its turns are pooled. A clause
    is then labelled:
      - "both"    if its turn's emotion is not "neutral" and it contains a span,
      - "emotion" if only the first holds,
      - "cause"   if only the second holds,
      - "neither" otherwise.

    Spans are stripped of surrounding whitespace and blank spans are ignored:
    an empty string is a substring of every clause and would otherwise mark
    the whole dialogue as cause.

    The CSV has the header clause_text, emotion_label, clause_label.
    """
    with open(input_json, 'r', encoding='utf-8') as f:
        data = json.load(f)

    with open(output_csv, 'w', newline='', encoding='utf-8') as fout:
        writer = csv.writer(fout)
        writer.writerow(["clause_text", "emotion_label", "clause_label"])

        for conv in data.values():
            for dialogue in conv:
                # Step 1: gather all non-blank cause spans from the dialogue
                all_cause_spans = []
                for turn in dialogue:
                    for span in turn.get("expanded emotion cause span", []):
                        span = span.strip()
                        if span:
                            all_cause_spans.append(span)

                # Step 2: label each clause against the pooled cause spans
                for turn in dialogue:
                    emotion_label = turn.get("emotion", "neutral")
                    # The emotion flag depends on the turn only, not the clause.
                    is_emotion = emotion_label != "neutral"

                    for clause in turn.get("clauses", []):
                        is_cause = any(span in clause for span in all_cause_spans)

                        if is_emotion and is_cause:
                            clause_label = "both"
                        elif is_emotion:
                            clause_label = "emotion"
                        elif is_cause:
                            clause_label = "cause"
                        else:
                            clause_label = "neither"

                        writer.writerow([clause, emotion_label, clause_label])

    print(f"✅ Dataset written to {output_csv}")

if __name__ == "__main__":
    # Builds the clause-level CSV for the test split.
    build_clause_dataset(
        input_json="extracted_clauses_test.json",
        output_csv="clause_dataset_test.csv",
    )


In [None]:
"""
This program generates clause-level embeddings for a train-test dataset using the RoBERTa model.
Each clause is associated with an emotion label, and the embeddings are stored in a PyTorch .pt file.

Steps:
1. Load a JSON file containing dialogues and clause-level annotations.
2. Use the RoBERTa transformer to compute embeddings for each clause.
3. Associate each clause with its corresponding emotion label.
4. Organize data per conversation ID and save embeddings, clauses, and emotion labels as a dictionary to a .pt file.
"""

import json
import torch
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm

# Setup: load the roberta-base tokenizer and encoder, switch the encoder to
# eval mode and move it to the GPU when one is available.
tokenizer = AutoTokenizer.from_pretrained("roberta-base")
model = AutoModel.from_pretrained("roberta-base")
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Load the post-processed clause file (turn dicts carrying a 'clauses' list)
with open("/kaggle/working/extracted_clauses_train.json", "r", encoding="utf-8") as f:
    data = json.load(f)

# Clause embeddings: {conv_id: {"clauses": [...], "embeddings": Tensor [n,768], "emotions": [...]}}
# Filled by the loop below; conversations without any clause get no entry.
clause_data = {}

def embed_clause(text):
    """Encode one clause and return the hidden state at sequence position 0.

    The text is tokenized (with truncation), run through the module-level
    `model` without gradient tracking, and the first-position vector of the
    last hidden layer is returned on the CPU with the batch axis removed.
    """
    encoded = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
    with torch.no_grad():
        hidden = model(**encoded).last_hidden_state
    first_position = hidden[:, 0, :].squeeze(0)
    return first_position.cpu()  # [768]

# Process each conversation: embed every clause one at a time and collect,
# per conversation id, the clause texts, their embeddings and their emotions.
# Naming note: `dialogue` is the list stored under a conversation id, `turn`
# is one element of that list, and `utt` is one dict inside it.
for conv_id, dialogue in tqdm(data.items(), desc="Embedding clauses"):
    all_clauses = []
    all_embeddings = []
    all_emotions = []

    for turn in dialogue:
        for utt in turn:
            # Dicts without an 'emotion' key are treated as neutral.
            emotion = utt.get("emotion", "neutral")
            clauses = utt.get("clauses", [])
            for clause in clauses:
                emb = embed_clause(clause)
                all_clauses.append(clause)
                all_embeddings.append(emb)
                all_emotions.append(emotion)  # associate this clause with its turn's emotion

    # torch.stack cannot stack an empty list, so clause-less conversations are skipped.
    if all_embeddings:
        clause_data[conv_id] = {
            "clauses": all_clauses,
            "embeddings": torch.stack(all_embeddings),  # shape: [num_clauses, 768]
            "emotions": all_emotions                    # shape: [num_clauses]
        }

# Save the whole dictionary to a .pt file in the current working directory
torch.save(clause_data, "clause_embeddings_with_emotions_train.pt")
print("✅ Saved clause embeddings and emotion labels to clause_embeddings_with_emotions file")


In [None]:
"""
This script extracts pairs of emotion and cause clauses from a given JSON file of dialogues.
It then labels these pairs as positive or negative based on the relationship between the emotion and cause clauses.

Steps:
1. Load the input JSON file containing dialogues.
2. Flatten the clauses for each dialogue and collect emotion and cause spans.
3. For each emotion clause, find its corresponding cause clause and create positive pairs.
4. For each positive pair, sample one negative pair that does not have an emotion-cause relationship.
5. Label the pairs (positive: 1, negative: 0) and add their relationship type (if any).
6. Save the generated pairs to an output JSON file and print the pair counts.

The output contains a dictionary with conversation IDs as keys and a list of clause pairs with their attributes.
Each pair has:
- "emotion_idx": Index of the emotion clause
- "cause_idx": Index of the cause clause
- "label": 1 for positive pairs, 0 for negative pairs
- "type": Relationship type (e.g., "no-context")
"""


import json
import random

def extract_clause_pairs_with_types(input_json, output_json):
    """
    Build labelled (emotion clause, cause clause) index pairs per conversation.

    The input JSON is read as {conv_id: [dialogue, ...]}, each dialogue being
    a list of turn dicts that may carry 'clauses', 'emotion', 'type' and
    'expanded emotion cause span'. All clauses of a conversation are numbered
    consecutively in reading order.

    For every turn whose emotion is not "neutral", each of its clauses is an
    emotion clause. Each non-blank cause span of that turn marks every clause
    (anywhere in the conversation) containing the span as a cause, giving a
    positive pair typed with the first entry of the turn's 'type' list
    ("no-context" when missing or empty).

    For each positive pair exactly one negative pair is sampled for the same
    emotion clause, whenever a candidate exists (a clause other than the
    emotion clause itself that is not one of its causes). Sampling uses the
    `random` module, so seed it for reproducible output.

    Writes {conv_id: [{"emotion_idx", "cause_idx", "label", "type"}, ...]} to
    output_json (conversations without pairs are omitted) and prints the
    positive/negative totals.
    """
    with open(input_json, "r", encoding="utf-8") as f:
        data = json.load(f)

    output = {}
    pair_type_counts = {"positive": 0, "negative": 0}  # To count types

    for conv_id, dialogue in data.items():
        # Step 1: Flatten all clauses, remembering each turn's index range
        flat_clauses = []
        turn_clause_ranges = []
        for turn in dialogue:
            for utt in turn:
                start_idx = len(flat_clauses)
                flat_clauses.extend(utt.get("clauses", []))
                turn_clause_ranges.append((utt, start_idx, len(flat_clauses)))

        if not flat_clauses:
            continue  # skip conversations with no clauses

        # Step 2: Find emotion clauses and collect gold pairs
        emotion_clause_idxs = set()
        gold_pairs = set()
        pair_types = {}
        for utt, start_idx, end_idx in turn_clause_ranges:
            if utt.get("emotion", "neutral") == "neutral":
                continue  # neutral turns contribute neither emotions nor pairs

            # `or` also covers an explicit empty list, which [0] would reject.
            type_ = (utt.get("type") or ["no-context"])[0]
            # Blank spans are dropped: "" is a substring of every clause.
            cause_spans = [
                s.strip()
                for s in utt.get("expanded emotion cause span", [])
                if s.strip()
            ]

            emotion_idxs = range(start_idx, end_idx)
            emotion_clause_idxs.update(emotion_idxs)

            for span in cause_spans:
                for j, clause in enumerate(flat_clauses):
                    if span in clause:
                        for i in emotion_idxs:
                            gold_pairs.add((i, j))
                            pair_types[(i, j)] = type_

        # Step 3: Construct final pairs with 1 negative sample per positive
        final_pairs = []
        all_idxs = range(len(flat_clauses))
        # sorted() keeps the output order independent of set iteration order.
        for i in sorted(emotion_clause_idxs):
            positives = [j for j in all_idxs if (i, j) in gold_pairs]
            negatives = [j for j in all_idxs if j != i and (i, j) not in gold_pairs]

            for j in positives:
                final_pairs.append({
                    "emotion_idx": i,
                    "cause_idx": j,
                    "label": 1,
                    "type": pair_types.get((i, j), "no-context")
                })
                pair_type_counts["positive"] += 1

                # Exactly one negative per positive (the previous random 0-or-1
                # draw produced roughly half as many negatives as documented).
                if negatives:
                    final_pairs.append({
                        "emotion_idx": i,
                        "cause_idx": random.choice(negatives),
                        "label": 0,
                        "type": "no-context"
                    })
                    pair_type_counts["negative"] += 1

        if final_pairs:
            output[conv_id] = final_pairs

    with open(output_json, "w", encoding="utf-8") as f:
        json.dump(output, f, indent=2)
    print(f"✅ Saved labeled clause pairs with types to {output_json}")

    # Print the pair counts
    print("Pair counts:")
    print(f"Positive pairs: {pair_type_counts['positive']}")
    print(f"Negative pairs: {pair_type_counts['negative']}")

# Example usage
if __name__ == "__main__":
    # NOTE(review): the GAT cell later in this notebook reads
    # "/kaggle/working/clause_pair_labels_with_types_train.json", which is not
    # the file name written here - confirm which name is intended.
    extract_clause_pairs_with_types(
        input_json="/kaggle/working/extracted_clauses_train.json",
        output_json="clause_pair_labels_with_cause_type_train.json",
    )


In [None]:
"""
This script processes a dataset containing emotional labels for text clauses and consolidates synonyms into a single label to reduce label sparsity.

Steps:
1. Load the dataset from a CSV file.
2. Compute the support (frequency) of each label in the 'emotion_label' column.
3. Define semantic groups of synonyms (e.g., 'angry' and 'anger' are treated as the same).
4. Build a mapping from lower-support synonyms to the higher-support representative label within each semantic group.
5. Replace all occurrences of synonyms with the corresponding higher-support label.
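6. Save the cleaned dataset to 'clause_dataset_test.csv' (the same file name as the input, so the input is overwritten when the working directory is /kaggle/working).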
7. Optionally, print out the label replacements that occurred.

The goal is to make the dataset less sparse and more consistent by consolidating similar emotion labels into one representative label.
"""


import pandas as pd
from collections import defaultdict

# Load the clause-level CSV (columns written by build_clause_dataset:
# clause_text, emotion_label, clause_label)
df = pd.read_csv("/kaggle/working/clause_dataset_test.csv")

# Column whose values are consolidated
label_col = 'emotion_label'

# Step 1: Compute label support (number of rows per label)
support = df[label_col].value_counts().to_dict()

# Step 2: Define semantic equivalents; single-element groups produce no mapping
semantic_groups = [
    ['anger', 'angry'],
    ['happiness', 'happy', 'excited'],
    ['surprise', 'surprised'],
    ['disgust'],
    ['fear'],
    ['sadness'],
    ['neutral']
]

# Step 3: Build mapping from lower-support synonyms to higher-support one
replacement_map = {}
for group in semantic_groups:
    # Get support values for all in group (0 for labels absent from the data)
    group_supports = {label: support.get(label, 0) for label in group}
    # Find the label with max support; on a tie max() keeps the earliest
    # label in the group
    target_label = max(group_supports, key=group_supports.get)
    # Map other labels to this one
    for label in group:
        if label != target_label:
            replacement_map[label] = target_label

# Step 4: Replace labels
df[label_col] = df[label_col].replace(replacement_map)

# Step 5: Save cleaned CSV under the input's file name, relative to the
# current working directory
df.to_csv("clause_dataset_test.csv", index=False)

# Optional: print the mapping that was applied. Every entry of the mapping is
# listed, including labels that never occurred in the data.
print("Replaced labels:")
for old, new in replacement_map.items():
    print(f"{old} → {new}")


In [None]:
'''
This script is designed for training a multi-class classification model 
to predict emotions and clauses from text using a transformer-based 
approach. The code uses pre-trained models (RoBERTa for emotion classification 
and BERT for clause classification) and fine-tunes them on custom datasets.

The script performs the following steps:
1. Data Preprocessing: Loads and preprocesses emotion and clause datasets, 
   including fixing typos, label encoding, and saving the encoders.
2. Class Weight Calculation: Computes class weights to handle class imbalance 
   for the clause classification task.
3. Tokenizer Initialization: Loads tokenizers for both emotion and clause models.
4. Dataset Creation: Creates custom dataset classes for both emotion and clause tasks.
5. Weighted Sampling: Implements weighted random samplers to address class imbalance 
   during training.
6. Model Definition: Defines transformer-based models for both emotion and clause classification, 
   using BERT-like architectures with a linear classifier head.
7. Loss Functions: Implements Focal Loss for clause classification to improve performance 
   on imbalanced classes, and Cross-Entropy Loss for emotion classification.
8. Training Loop: Defines the training loop, which optimizes the models using AdamW 
   and uses a learning rate scheduler to adjust the learning rate based on performance.
9. Model Evaluation: Evaluates the models on test datasets and generates classification reports.

Key components:
- Emotion classification with RoBERTa (Twitter model)
- Clause classification with BERT (base model)
- Handling class imbalance using weighted random sampling and loss functions
- Saving the tokenizers and label encoders for future use
'''



# Task 2
import os
# Make CUDA errors synchronous for accurate tracebacks
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import pickle
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from transformers import AutoTokenizer, AutoModel
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import classification_report
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
import numpy as np
import joblib

# Device setup: use the GPU when available, otherwise the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Config
MODEL_NAME_EMOTION = 'cardiffnlp/twitter-roberta-base-emotion'  # encoder for the emotion classifier
MODEL_NAME_CLAUSE  = 'bert-base-uncased'                        # encoder for the clause classifier
MAX_LEN    = 128   # tokenizer max_length (inputs are padded/truncated to it)
BATCH_SIZE = 16    # batch size of all four dataloaders
EPOCHS     = 12    # upper bound on training epochs
LR         = 1e-5  # AdamW learning rate for both models
WD         = 0.01  # AdamW weight decay for both models
PATIENCE   = 3     # epochs without improvement before early stopping

# Paths of the clause-level CSVs (clause_text, emotion_label, clause_label)
TRAIN_CSV = '/kaggle/working/clause_dataset_train.csv'
TEST_CSV  = '/kaggle/working/clause_dataset_test.csv'

# Load & preprocess
df_train = pd.read_csv(TRAIN_CSV)
df_test  = pd.read_csv(TEST_CSV)

# Fix typos / variant spellings of emotion labels in both splits
fix_map = {"happines": "happiness", "sad": "sadness"}
df_train['emotion_label'] = df_train['emotion_label'].replace(fix_map)
df_test ['emotion_label'] = df_test ['emotion_label'].replace(fix_map)

# Encode labels. The encoders are fitted on train + test together, so a label
# that occurs only in the test split still receives an id.
emotion_encoder = LabelEncoder()
clause_encoder  = LabelEncoder()
emotion_encoder.fit(df_train['emotion_label'].tolist() + df_test['emotion_label'].tolist())
clause_encoder .fit(df_train['clause_label'].tolist()  + df_test['clause_label'].tolist())

for df in (df_train, df_test):
    df['emotion_label_enc'] = emotion_encoder.transform(df['emotion_label'])
    df['clause_label_enc']  = clause_encoder.transform(df['clause_label'])

# Save encoders
with open('encoder_emotion.pkl', 'wb') as f:
    pickle.dump(emotion_encoder, f)
with open('encoder_clause.pkl', 'wb') as f:
    pickle.dump(clause_encoder, f)
    
# Compute class weights for clause loss. Balanced weights are computed only
# for the classes present in the training split; classes absent from it keep
# weight 0 in the full-length weight vector.
num_clause_classes = len(clause_encoder.classes_)
present = np.unique(df_train['clause_label_enc'])
w_present = compute_class_weight(class_weight='balanced',
                                 classes=present,
                                 y=df_train['clause_label_enc'])
clause_weights = np.zeros(num_clause_classes, dtype=np.float32)
for cls, w in zip(present, w_present):
    clause_weights[int(cls)] = w
clause_weights_tensor = torch.tensor(clause_weights, dtype=torch.float)

# Tokenizers: load one per model and pickle each for later reuse

tokenizer_emotion = AutoTokenizer.from_pretrained(MODEL_NAME_EMOTION)
with open('tokenizer_emotion.pkl', 'wb') as f:
    pickle.dump(tokenizer_emotion, f)
tokenizer_clause  = AutoTokenizer.from_pretrained(MODEL_NAME_CLAUSE)
with open('tokenizer_clause.pkl', 'wb') as f:
    pickle.dump(tokenizer_clause, f)

# Also store the class names as plain arrays (index = encoded id)
np.save('emotion_label_classes.npy', emotion_encoder.classes_)
np.save('clause_label_classes.npy',  clause_encoder.classes_)

# Dataset definition
TEXT_COL = 'clause_text'
class ClauseDataset(Dataset):
    """Clause texts with their encoded emotion and clause labels.

    Texts are tokenized lazily, one item at a time, padded/truncated to
    MAX_LEN. Each item is a dict with 'input_ids', 'attention_mask',
    'emotion_label' and 'clause_label'.
    """

    def __init__(self, df, tokenizer):
        # Copy the needed columns out of the frame as plain Python lists.
        self.texts = df[TEXT_COL].tolist()
        self.em_ls = df['emotion_label_enc'].tolist()
        self.cl_ls = df['clause_label_enc'].tolist()
        self.tok   = tokenizer

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        encoded = self.tok(
            self.texts[idx],
            truncation=True,
            padding='max_length',
            max_length=MAX_LEN,
            return_tensors='pt',
        )
        # The tokenizer returns a batch of one; drop the batch axis.
        item = {
            field: encoded[field].squeeze(0)
            for field in ('input_ids', 'attention_mask')
        }
        item['emotion_label'] = torch.tensor(self.em_ls[idx], dtype=torch.long)
        item['clause_label'] = torch.tensor(self.cl_ls[idx], dtype=torch.long)
        return item

# Samplers for oversampling
# Per-class training counts, aligned to the encoder's full class list.
# The encoders are fitted on train + test, so a class may be missing from the
# training split; without reindexing, the count array would be shorter than
# the id range and w_em[...] below would pick the wrong weight (or fail).
num_emotion_classes = len(emotion_encoder.classes_)
e_counts = df_train['emotion_label_enc'].value_counts().reindex(
    np.arange(num_emotion_classes), fill_value=0).values
w_em     = 1.0 / (e_counts + 1e-6)   # epsilon guards the zero-count classes
# One sampling weight per training row: the inverse frequency of its class.
sampler_em = WeightedRandomSampler(w_em[df_train['emotion_label_enc'].to_numpy()],
                                   num_samples=len(df_train), replacement=True)

c_counts = df_train['clause_label_enc'].value_counts().reindex(
    np.arange(num_clause_classes), fill_value=0).values
w_cl     = 1.0 / (c_counts + 1e-6)
sampler_cl = WeightedRandomSampler(w_cl[df_train['clause_label_enc'].to_numpy()],
                                   num_samples=len(df_train), replacement=True)

# Dataloaders: the training loaders draw through the weighted samplers, the
# test loaders iterate the data in order.
ds_tr_em = ClauseDataset(df_train, tokenizer_emotion)
ds_te_em = ClauseDataset(df_test,  tokenizer_emotion)
train_loader_em = DataLoader(ds_tr_em, batch_size=BATCH_SIZE, sampler=sampler_em)
test_loader_em  = DataLoader(ds_te_em, batch_size=BATCH_SIZE)

ds_tr_cl = ClauseDataset(df_train, tokenizer_clause)
ds_te_cl = ClauseDataset(df_test,  tokenizer_clause)
train_loader_cl = DataLoader(ds_tr_cl, batch_size=BATCH_SIZE, sampler=sampler_cl)
test_loader_cl  = DataLoader(ds_te_cl,  batch_size=BATCH_SIZE)

# Loss & Models
class FocalLoss(nn.Module):
    """Class-weighted focal loss for multi-class logits.

    Per sample: weight[target] * (1 - p_t) ** gamma * (-log p_t), where p_t is
    the softmax probability of the target class. The batch mean is returned.

    Args:
        weights: optional 1-D float tensor with one weight per class. It is
            registered as a buffer, so it follows the module across devices.
        gamma: focusing exponent; 0 reduces the loss to (weighted) cross-entropy.
    """

    def __init__(self, weights=None, gamma=2.0):
        super().__init__()
        self.gamma = gamma
        self.register_buffer('weights', weights)
        # Unweighted on purpose: p_t must come from the plain cross-entropy.
        # Deriving it from a class-weighted CE gives exp(-w * ce), which is no
        # longer a probability and distorts the modulating factor.
        self.ce = nn.CrossEntropyLoss(reduction='none')

    def forward(self, logits, targets):
        ce = self.ce(logits, targets)          # -log p_t per sample
        pt = torch.exp(-ce)                    # p_t per sample
        loss = (1 - pt) ** self.gamma * ce
        if self.weights is not None:
            # Apply the class weight as a separate multiplicative factor.
            loss = loss * self.weights[targets]
        return loss.mean()

class EmotionModel(nn.Module):
    """Encoder loaded from MODEL_NAME_EMOTION with a linear classification head.

    The hidden state at sequence position 0 goes through dropout and layer
    normalisation before the linear layer; forward returns raw logits of
    shape (batch, num_labels).
    """

    def __init__(self, num_labels):
        super().__init__()
        self.bert = AutoModel.from_pretrained(MODEL_NAME_EMOTION)
        width = self.bert.config.hidden_size
        self.dropout = nn.Dropout(0.4)
        self.norm    = nn.LayerNorm(width)
        self.classifier = nn.Linear(width, num_labels)

    def forward(self, input_ids, attention_mask):
        hidden = self.bert(input_ids=input_ids,
                           attention_mask=attention_mask).last_hidden_state
        first_position = hidden[:, 0, :]
        features = self.norm(self.dropout(first_position))
        logits = self.classifier(features)
        return logits

class ClauseModel(nn.Module):
    """Encoder loaded from MODEL_NAME_CLAUSE with a linear classification head.

    The hidden state at sequence position 0 goes through dropout and layer
    normalisation before the linear layer; forward returns raw logits of
    shape (batch, num_labels).
    """

    def __init__(self, num_labels):
        super().__init__()
        self.bert = AutoModel.from_pretrained(MODEL_NAME_CLAUSE)
        width = self.bert.config.hidden_size
        self.dropout = nn.Dropout(0.4)
        self.norm    = nn.LayerNorm(width)
        self.classifier = nn.Linear(width, num_labels)

    def forward(self, input_ids, attention_mask):
        hidden = self.bert(input_ids=input_ids,
                           attention_mask=attention_mask).last_hidden_state
        first_position = hidden[:, 0, :]
        features = self.norm(self.dropout(first_position))
        logits = self.classifier(features)
        return logits

# One classifier per task, sized by the number of classes each encoder knows.
emotion_model = EmotionModel(len(emotion_encoder.classes_)).to(device)
clause_model  = ClauseModel(len(clause_encoder.classes_)).to(device)
# Emotion task: plain cross-entropy. Clause task: focal loss with the class
# weights computed from the training split.
crit_em = nn.CrossEntropyLoss()
crit_cl = FocalLoss(weights=clause_weights_tensor).to(device)

optim_em = AdamW(emotion_model.parameters(), lr=LR, weight_decay=WD)
optim_cl = AdamW(clause_model.parameters(),  lr=LR, weight_decay=WD)
# NOTE(review): the scheduler wraps optim_em only and is stepped with the
# emotion validation loss; optim_cl keeps a constant learning rate - confirm
# this is intended.
sched    = ReduceLROnPlateau(optim_em, mode='min', patience=2, factor=0.5, verbose=True)

# Training and evaluation routines for the emotion and clause models.
# ── Training & Evaluation ─────────────────────────────────────────────────────
def train_epoch_emotion():
    """Run one optimisation pass of emotion_model over train_loader_em.

    Gradients are clipped to a norm of 1.0 before each optimizer step.
    Returns (mean loss per batch, accuracy over the samples seen).
    """
    emotion_model.train()
    running_loss = 0
    n_correct = 0
    n_seen = 0
    for batch in tqdm(train_loader_em, desc='Train Emotion'):
        input_ids = batch['input_ids'].to(device)
        mask = batch['attention_mask'].to(device)
        labels = batch['emotion_label'].to(device)

        optim_em.zero_grad()
        logits = emotion_model(input_ids, mask)
        loss = crit_em(logits, labels)
        loss.backward()
        nn.utils.clip_grad_norm_(emotion_model.parameters(), 1.0)
        optim_em.step()

        running_loss += loss.item()
        n_correct += (logits.argmax(1) == labels).sum().item()
        n_seen += labels.size(0)

    mean_loss = running_loss / len(train_loader_em)
    accuracy = n_correct / n_seen
    return mean_loss, accuracy

@torch.no_grad()
def evaluate_emotion():
    """Score emotion_model on test_loader_em without gradient tracking.

    Returns (mean loss per batch, accuracy, classification report text). The
    report lists every class known to emotion_encoder, including classes
    that do not occur in the evaluated data.
    """
    emotion_model.eval()
    gold, predicted = [], []
    running_loss = 0
    n_correct = 0
    n_seen = 0
    for batch in tqdm(test_loader_em, desc='Eval Emotion'):
        input_ids = batch['input_ids'].to(device)
        mask = batch['attention_mask'].to(device)
        labels = batch['emotion_label'].to(device)

        logits = emotion_model(input_ids, mask)
        running_loss += crit_em(logits, labels).item()

        guesses = logits.argmax(1)
        n_correct += (guesses == labels).sum().item()
        n_seen += labels.size(0)

        gold.extend(labels.cpu().tolist())
        predicted.extend(guesses.cpu().tolist())

    class_ids = list(range(len(emotion_encoder.classes_)))
    report = classification_report(
        gold,
        predicted,
        labels=class_ids,
        target_names=emotion_encoder.classes_,
        zero_division=0
    )
    return running_loss / len(test_loader_em), n_correct / n_seen, report

def train_epoch_clause():
    """Run one optimisation pass of clause_model over train_loader_cl.

    Gradients are clipped to a norm of 1.0 before each optimizer step.
    Returns (mean loss per batch, accuracy over the samples seen).
    """
    clause_model.train()
    running_loss = 0
    n_correct = 0
    n_seen = 0
    for batch in tqdm(train_loader_cl, desc='Train Clause'):
        input_ids = batch['input_ids'].to(device)
        mask = batch['attention_mask'].to(device)
        labels = batch['clause_label'].to(device)

        optim_cl.zero_grad()
        logits = clause_model(input_ids, mask)
        loss = crit_cl(logits, labels)
        loss.backward()
        nn.utils.clip_grad_norm_(clause_model.parameters(), 1.0)
        optim_cl.step()

        running_loss += loss.item()
        n_correct += (logits.argmax(1) == labels).sum().item()
        n_seen += labels.size(0)

    mean_loss = running_loss / len(train_loader_cl)
    accuracy = n_correct / n_seen
    return mean_loss, accuracy

@torch.no_grad()
def evaluate_clause():
    """Score clause_model on test_loader_cl without gradient tracking.

    Returns (mean loss per batch, accuracy, classification report text). The
    report lists every class known to clause_encoder, including classes that
    do not occur in the evaluated data.
    """
    clause_model.eval()
    gold, predicted = [], []
    running_loss = 0
    n_correct = 0
    n_seen = 0
    for batch in tqdm(test_loader_cl, desc='Eval Clause'):
        input_ids = batch['input_ids'].to(device)
        mask = batch['attention_mask'].to(device)
        labels = batch['clause_label'].to(device)

        logits = clause_model(input_ids, mask)
        running_loss += crit_cl(logits, labels).item()

        guesses = logits.argmax(1)
        n_correct += (guesses == labels).sum().item()
        n_seen += labels.size(0)

        gold.extend(labels.cpu().tolist())
        predicted.extend(guesses.cpu().tolist())

    class_ids = list(range(len(clause_encoder.classes_)))
    report = classification_report(
        gold,
        predicted,
        labels=class_ids,
        target_names=clause_encoder.classes_,
        zero_division=0
    )
    return running_loss / len(test_loader_cl), n_correct / n_seen, report

# ── Main Loop ──────────────────────────────────────────────────────────────────
# Train both models side by side. Each model's checkpoint is written only when
# that model's own validation accuracy improves; saving both whenever either
# improved could overwrite a model's best checkpoint with a worse state.
# Training stops after PATIENCE consecutive epochs in which neither improved.
best_em_acc, best_cl_acc, no_imp = 0.0, 0.0, 0
for ep in range(1, EPOCHS+1):
    print(f"\n── Epoch {ep}/{EPOCHS} ─────────────────────────")
    tr_em_loss, tr_em_acc = train_epoch_emotion()
    tr_cl_loss, tr_cl_acc = train_epoch_clause()
    val_em_loss, val_em_acc, rpt_em = evaluate_emotion()
    val_cl_loss, val_cl_acc, rpt_cl = evaluate_clause()
    # The scheduler controls the emotion optimizer only.
    sched.step(val_em_loss)

    print(f"Emotion | train_loss={tr_em_loss:.4f} acc={tr_em_acc:.4f} | "
          f"val_loss={val_em_loss:.4f} acc={val_em_acc:.4f}")
    print("Emotion classification report:\n", rpt_em)
    print(f"Clause  | train_loss={tr_cl_loss:.4f} acc={tr_cl_acc:.4f} | "
          f"val_loss={val_cl_loss:.4f} acc={val_cl_acc:.4f}")
    print("Clause classification report:\n", rpt_cl)

    # early stopping & save
    em_improved = val_em_acc > best_em_acc
    cl_improved = val_cl_acc > best_cl_acc

    if em_improved:
        best_em_acc = val_em_acc
        torch.save(emotion_model.state_dict(), 'best_emotion_model.pt')
        print("✅ Emotion model improved and saved.")
    if cl_improved:
        best_cl_acc = val_cl_acc
        torch.save(clause_model.state_dict(), 'best_clause_model.pt')
        print("✅ Clause model improved and saved.")

    if em_improved or cl_improved:
        no_imp = 0
    else:
        no_imp += 1
        print(f"⚠️ No improvement ({no_imp}/{PATIENCE})")
        if no_imp >= PATIENCE:
            print("⏹️ Early stopping triggered.")
            break

# ── Final Evaluation ─────────────────────────────────────────────────────────
print("\n📊 Final evaluation with best checkpoints:")
# Restore the weights saved by the training loop. Both files must exist, i.e.
# each model must have been saved at least once.
emotion_model.load_state_dict(torch.load('best_emotion_model.pt', map_location=device))
clause_model .load_state_dict(torch.load('best_clause_model.pt',  map_location=device))

# Re-run evaluation on the test loaders; the returned loss is not reported.
_, fe_acc, fe_rpt = evaluate_emotion()
_, fc_acc, fc_rpt = evaluate_clause()
print(f"Final Emotion Acc: {fe_acc:.4f}\n", fe_rpt)
print(f"Final Clause  Acc: {fc_acc:.4f}\n", fc_rpt)

In [None]:
!pip install torch-geometric -q

In [None]:
'''
The code is for training a binary classification model to predict emotion-cause 
relationships between clause pairs using a Graph Attention Network (GAT). It builds 
a graph from conversation clause embeddings and annotated links, and trains the model 
using focal loss to handle class imbalance effectively.

The code follows these steps:
1. Data loading and conversion of clause-pair annotations into graph structures.
2. Definition of the GAT-based link prediction model.
3. Training loop with Focal Loss, evaluation using accuracy and classification report.
4. Early stopping based on test-set accuracy (no separate validation split is used) and saving the best model.
'''


# Task 3
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GATConv
import json
from sklearn.metrics import classification_report, accuracy_score
import os

# ==================== Hyperparameters ====================
BATCH_SIZE = 8   # number of conversation graphs per training mini-batch
HIDDEN = 48      # hidden_channels passed to GATLinkPredictor
EPOCHS = 50      # upper bound on training epochs
PATIENCE = 8     # epochs without an accuracy improvement before early stopping
MODEL_OUT = "best_gat_model.pt"         # receives model.state_dict() of the best epoch
MODEL_FULL = "best_gat_model_full.pt"   # receives the whole pickled model of the best epoch

# ==================== Data Loading ====================
# Clause files are read with torch.load and indexed by conversation id; each
# entry exposes an "embeddings" tensor (see build_data).
TRAIN_CLAUSE_F = "/kaggle/input/g40-007/clause_embeddings_with_emotions_train.pt"
TEST_CLAUSE_F = "/kaggle/input/g40-007/clause_embeddings_with_emotions_test.pt"
# Pair files are JSON: conversation id -> list of dicts with the keys
# "emotion_idx", "cause_idx" and "label" (see build_data).
TRAIN_PAIR_F = "/kaggle/working/clause_pair_labels_with_types_train.json"
TEST_PAIR_F = "/kaggle/working/clause_pair_labels_with_types_test.json"

def build_data(conv_id, clause_data, pair_data):
    """Build the PyG graph for one conversation.

    Args:
        conv_id: Key identifying the conversation in both mappings.
        clause_data: Mapping conv_id -> dict holding an "embeddings" tensor
            (one row per clause; used directly as node features).
        pair_data: Mapping conv_id -> list of annotated clause pairs, each a
            dict with "emotion_idx", "cause_idx" and "label" (0 or 1).
            A missing conv_id is treated as "no pairs".

    Returns:
        A ``Data`` object with
          * ``x``          - the clause embeddings,
          * ``edge_index`` - long tensor of shape [2, E], one column per
            annotated (emotion_idx, cause_idx) pair,
          * ``y_link``     - float tensor of shape [E] with the pair labels.

    Raises:
        KeyError: if ``conv_id`` is not in ``clause_data`` or a pair lacks a
            required key.
    """
    info = clause_data[conv_id]
    emb = info["embeddings"]    # [n_clauses, emb_dim]

    # Use exactly the annotated pairs (they already include label=0 or 1).
    pairs = pair_data.get(conv_id, [])
    edge_list = [(p["emotion_idx"], p["cause_idx"]) for p in pairs]
    edge_label = [p["label"] for p in pairs]

    # reshape(-1, 2) is a no-op for a non-empty list but guarantees a [0, 2]
    # tensor when there are no pairs, so the transpose always yields the
    # [2, E] layout PyG expects (a bare empty tensor would have shape [0]).
    edge_index = (
        torch.tensor(edge_list, dtype=torch.long)
        .reshape(-1, 2)
        .t()
        .contiguous()
    )                                                           # [2, E]
    edge_label = torch.tensor(edge_label, dtype=torch.float)    # [E]

    return Data(x=emb, edge_index=edge_index, y_link=edge_label)


def build_dataset(clause_data, pair_data):
    """Return one graph per conversation that has at least one annotated pair.

    Conversations are visited in ``clause_data`` iteration order; those whose
    entry in ``pair_data`` is missing or empty are skipped.
    """
    graphs = []
    for conv_id in clause_data:
        annotated = pair_data.get(conv_id, [])
        if len(annotated) > 0:
            graphs.append(build_data(conv_id, clause_data, pair_data))
    return graphs
    
class FocalLoss(nn.Module):
    """Binary focal loss computed from raw logits.

    Each element's BCE-with-logits loss is scaled by
    ``alpha * (1 - p_t) ** gamma``, where ``p_t`` is the predicted probability
    of the element's target class. Note that ``alpha`` is a single scalar
    applied to every element (it does not weight the two classes differently).

    Args:
        gamma: Exponent of the focusing term.
        alpha: Global scale factor.
        reduction: 'mean', 'sum', or anything else for the unreduced tensor.
    """

    def __init__(self, gamma=3.0, alpha=1.0, reduction='mean'):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction
        # Unreduced so the focal weight can be applied element-wise.
        self.bce = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, logits, targets):
        """Return the focal loss of ``logits`` against float ``targets``."""
        prob_pos = torch.sigmoid(logits)
        # Probability the model assigns to the class each target names.
        prob_true = targets * prob_pos + (1 - targets) * (1 - prob_pos)
        focal = self.alpha * (1 - prob_true) ** self.gamma * self.bce(logits, targets)

        if self.reduction == 'mean':
            return focal.mean()
        if self.reduction == 'sum':
            return focal.sum()
        return focal

# ==================== Model ====================
class GATLinkPredictor(nn.Module):
    """Two GAT layers followed by a linear scorer over edge endpoints.

    Node features go through a 4-head GAT layer (heads concatenated), layer
    norm, ReLU and dropout, then a single-head GAT layer and layer norm. Each
    edge is scored from the concatenation of its two endpoint embeddings.

    Args:
        in_channels: Size of the input node features.
        hidden_channels: Per-head output size of both GAT layers.
        dropout: Dropout probability, used both inside the GAT layers and on
            the intermediate / edge representations.
    """

    def __init__(self, in_channels, hidden_channels, dropout=0.0):
        super().__init__()
        n_heads = 4
        wide = hidden_channels * n_heads   # width after concatenating the heads

        self.dropout = nn.Dropout(dropout)
        self.gat1 = GATConv(in_channels, hidden_channels,
                            heads=n_heads, concat=True, dropout=dropout)
        self.gat2 = GATConv(wide, hidden_channels,
                            heads=1, concat=True, dropout=dropout)
        # Scores an edge from [source embedding ; target embedding].
        self.lin = nn.Linear(hidden_channels * 2, 1)
        self.norm1 = nn.LayerNorm(wide)
        self.norm2 = nn.LayerNorm(hidden_channels)

    def forward(self, data):
        """Return one logit per column of ``data.edge_index``.

        ``data`` must expose ``x`` (node features) and ``edge_index``; the
        same edges are used for message passing and as the pairs to score.
        """
        edges = data.edge_index

        hidden = self.gat1(data.x, edges)
        hidden = F.relu(self.norm1(hidden))
        hidden = self.dropout(hidden)
        hidden = self.norm2(self.gat2(hidden, edges))

        src, dst = edges
        pair_repr = torch.cat([hidden[src], hidden[dst]], dim=1)
        pair_repr = self.dropout(pair_repr)
        return self.lin(pair_repr).squeeze(1)


# ==================== Evaluation ====================
def evaluate(model, loader, threshold=0.55):
    """Score ``model`` on every batch of ``loader``.

    Args:
        model: Link predictor; called as ``model(batch)`` and expected to
            return one logit per edge of the batch.
        loader: Iterable of batches exposing ``.to(device)`` and ``y_link``.
        threshold: An edge is predicted positive when its sigmoid probability
            is strictly greater than this value. Defaults to the 0.55 that
            was previously hard-coded.

    Returns:
        ``(accuracy, report)`` where ``report`` is the text produced by
        ``sklearn.metrics.classification_report`` (4 digits, zero_division=0).

    Note:
        Batches are moved to the module-level ``device``, which must be
        defined before this function is called.
    """
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            link_logits = model(batch)
            preds = torch.sigmoid(link_logits) > threshold
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(batch.y_link.cpu().tolist())

    acc = accuracy_score(all_labels, all_preds)
    report = classification_report(all_labels, all_preds, digits=4, zero_division=0)
    return acc, report

# ==================== Load and Prepare ====================

# Per-conversation clause records; weights_only=True restricts unpickling to
# tensors and plain containers.
train_clause_data = torch.load(TRAIN_CLAUSE_F,weights_only=True)
test_clause_data = torch.load(TEST_CLAUSE_F,weights_only=True)

# Annotated clause pairs (conversation id -> list of pair dicts).
with open(TRAIN_PAIR_F, "r") as f:
    train_pair_data = json.load(f)
with open(TEST_PAIR_F, "r") as f:
    test_pair_data = json.load(f)

# One graph per conversation; conversations without any annotated pair are
# dropped by build_dataset.
train_data = build_dataset(train_clause_data, train_pair_data)
test_data = build_dataset(test_clause_data, test_pair_data)

print(f"✅ Loaded {len(train_data)} train examples, {len(test_data)} test examples")

# Training graphs are shuffled and batched; test graphs are fed one at a time.
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_data, batch_size=1)

# ==================== Train Loop ====================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Input width is read from the first training graph's node features.
model = GATLinkPredictor(in_channels=train_data[0].x.size(1), hidden_channels=HIDDEN).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-3)
criterion = FocalLoss(gamma= 1.0, alpha=3.0)  # you can tweak gamma/alpha

best_acc = 0.0    # best test accuracy seen so far
no_improve = 0    # consecutive epochs without a new best accuracy

print("🚀 Starting training...")
for epoch in range(1, EPOCHS + 1):
    model.train()
    total_loss = 0

    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        link_logits = model(batch)
        loss = criterion(link_logits, batch.y_link)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # NOTE(review): checkpoint selection and early stopping both use the test
    # split, so the final test accuracy is an optimistic estimate.
    acc, _ = evaluate(model, test_loader)
    # Reported loss is the mean of the per-batch losses for this epoch.
    print(f"Epoch {epoch:02d}  Loss: {total_loss/len(train_loader):.4f}  Test Acc: {acc:.4f}")

    if acc > best_acc:
        best_acc = acc
        no_improve = 0
        # Save both the weights alone and the whole pickled model; the final
        # evaluation below reloads the latter.
        torch.save(model.state_dict(), MODEL_OUT)
        torch.save(model, MODEL_FULL)
        print(f"✅ Best model saved (acc={acc:.4f})")
    else:
        no_improve += 1
        if no_improve >= PATIENCE:
            print(f"⏹️ Early stopping triggered at epoch {epoch}")
            break

# ==================== Final Evaluation ====================
print("\n📊 Final Evaluation:")
# MODEL_FULL holds a whole pickled module, so weights_only must be False;
# only load files this notebook produced itself.
model = torch.load(MODEL_FULL, weights_only=False).to(device)
model.eval()
acc, report = evaluate(model, test_loader)
print(f"Accuracy: {acc:.4f}\n")
print(report)


In [None]:
'''
- Extracts clauses from conversations using spaCy's dependency parsing.
- Uses pre-trained transformer models to classify each clause's emotion and type.
- Applies a Graph Attention Network (GAT) to predict emotion-cause relationships between clauses.
- Loads all necessary artifacts (tokenizers, models, encoders) and processes input JSON data.
- Outputs labeled results including emotions, clause types, and cause-effect pairs to a JSON file.
'''


import argparse
import json
import pickle
import re

import spacy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GATConv
from transformers import AutoTokenizer, AutoModel


class GATLinkPredictor(nn.Module):
    """Two GAT layers followed by a linear scorer over edge endpoints.

    The layer layout (attribute names and sizes) matches the class used for
    training so that a pickled model or state_dict from training loads here.

    Args:
        in_channels: Size of the input node features.
        hidden_channels: Per-head output size of both GAT layers.
        dropout: Dropout probability, used inside the GAT layers and on the
            intermediate / edge representations.
    """

    def __init__(self, in_channels, hidden_channels, dropout=0.0):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.gat1 = GATConv(in_channels, hidden_channels, heads=4,
                            concat=True, dropout=dropout)
        self.gat2 = GATConv(hidden_channels * 4, hidden_channels,
                            heads=1, concat=True, dropout=dropout)

        # Scores an edge from [source embedding ; target embedding].
        self.lin = nn.Linear(hidden_channels * 2, 1)
        self.norm1 = nn.LayerNorm(hidden_channels * 4)
        self.norm2 = nn.LayerNorm(hidden_channels)

    def forward(self, feats, edge_index=None):
        """Return one logit per column of ``edge_index``.

        Args:
            feats: Node feature tensor, or - when ``edge_index`` is omitted -
                a graph object exposing ``x`` and ``edge_index`` (the calling
                convention used by the training code).
            edge_index: Long tensor of shape [2, E]. The same edges are used
                for message passing and as the pairs to score.

        Returns:
            Tensor of shape [E] with the raw (pre-sigmoid) edge scores.
        """
        if edge_index is None:
            # Training-style call: model(data). Redefining this class would
            # otherwise break code that still calls the model that way.
            feats, edge_index = feats.x, feats.edge_index

        x = F.relu(self.norm1(self.gat1(feats, edge_index)))
        x = self.dropout(x)
        x = self.norm2(self.gat2(x, edge_index))

        row, col = edge_index
        edge_feat = torch.cat([x[row], x[col]], dim=1)
        edge_feat = self.dropout(edge_feat)
        return self.lin(edge_feat).squeeze(1)


def clause_span(token, doc):
    """Return the text of ``doc`` from the first to the last token of ``token``'s subtree.

    The slice is contiguous, so it also includes any tokens that lie between
    those two positions without belonging to the subtree.
    """
    positions = [node.i for node in token.subtree]
    positions.sort()
    start, stop = positions[0], positions[-1] + 1
    return doc[start:stop].text


def extract_clauses(text, nlp):
    """Split ``text`` into clause strings using the dependency parse from ``nlp``.

    Every token whose dependency label is ROOT, advcl, ccomp, acl or xcomp
    contributes the span of its subtree (whitespace before , . ; ? ! removed,
    ends stripped). A candidate is then dropped when it is a proper substring
    of a different candidate, so only the outermost spans remain. Candidates
    with identical text are all kept.
    """
    head_labels = {"ROOT", "advcl", "ccomp", "acl", "xcomp"}
    parsed = nlp(text)

    candidates = [
        re.sub(r'\s+([,.;?!])', r'\1', clause_span(tok, parsed)).strip()
        for tok in parsed
        if tok.dep_ in head_labels
    ]

    kept = []
    for cand in candidates:
        is_nested = False
        for other in candidates:
            if cand != other and cand in other:
                is_nested = True
                break
        if not is_nested:
            kept.append(cand)
    return kept


def remerge_so_and_single_clauses(clauses):
    """Glue 'so ...' clauses and one-word clauses onto their neighbours.

    Pass 1: a clause starting with "so " (case-insensitive, after left
    stripping) is appended to the previous clause, whose trailing spaces and
    . ? ! are removed first. The appended text goes through
    ``str.capitalize()``, which upper-cases its first letter and lower-cases
    everything after it.

    Pass 2: a clause that is a single word (ignoring surrounding whitespace
    and . ? !) and is not the last one is joined with the following clause,
    whose first character is lower-cased; both are consumed.

    Returns a new list; the input list is not modified.
    """
    # ---- pass 1: fold "so ..." clauses into their predecessor -------------
    so_merged = []
    for raw in clauses:
        text = raw.lstrip()
        if not (so_merged and text.lower().startswith("so ")):
            so_merged.append(raw)
            continue
        head = so_merged.pop().rstrip(" .?!")
        so_merged.append(head + " " + text.capitalize())

    # ---- pass 2: attach one-word clauses to the clause that follows -------
    result = []
    total = len(so_merged)
    pos = 0
    while pos < total:
        current = so_merged[pos]
        bare = current.strip().strip(".?!")
        has_follower = pos + 1 < total
        if not (has_follower and len(bare.split()) == 1):
            # Kept verbatim (including any surrounding whitespace).
            result.append(current)
            pos += 1
            continue

        follower = so_merged[pos + 1].lstrip()
        if follower:
            combined = bare + " " + follower[0].lower() + follower[1:]
        else:
            combined = bare
        result.append(combined)
        pos += 2

    return result


def build_graph(node_feats):
    """Return the edge index of a fully connected directed graph without self-loops.

    The number of nodes is ``node_feats.size(0)``. The result is a long tensor
    of shape [2, n * (n - 1)] listing every ordered pair (i, j) with i != j,
    ordered by i and then by j.
    """
    n_nodes = node_feats.size(0)
    ordered_pairs = [
        (src, dst)
        for src in range(n_nodes)
        for dst in range(n_nodes)
        if src != dst
    ]
    sources = [src for src, _ in ordered_pairs]
    targets = [dst for _, dst in ordered_pairs]
    return torch.tensor([sources, targets], dtype=torch.long)


def _iter_turns(conversation):
    """Yield utterance dicts from ``conversation``, descending into nested lists."""
    for item in conversation:
        if isinstance(item, dict):
            yield item
        elif isinstance(item, (list, tuple)):
            yield from _iter_turns(item)
        else:
            raise TypeError(
                "Expected an utterance dict or a list of them, "
                f"got {type(item).__name__}"
            )


def _to_builtin(value):
    """Turn a NumPy-style scalar into a plain Python value; pass others through."""
    return value.item() if hasattr(value, 'item') else value


def inference_pipeline(conversation,
                       tokenizer_emotion, tokenizer_clause,
                       model_emotion, model_clause, gat_model,
                       emotion_encoder, clause_encoder, type_decoder,
                       device, nlp, threshold=0.55):
    """Label every clause of a conversation and score all ordered clause pairs.

    Args:
        conversation: Utterance dicts (each with an 'utterance' string), either
            as a flat list or nested in lists - the cleaned dialogue JSON
            stores each conversation id as a list of utterance lists.
        tokenizer_emotion, tokenizer_clause: Tokenizers for the two classifiers.
        model_emotion, model_clause: Classifiers called as
            ``model(input_ids, attention_mask)`` and returning class logits.
            ``model_clause.bert`` supplies the node features (first-token
            hidden state of each clause).
        gat_model: Edge scorer called as ``gat_model(feats, edge_index)``.
        emotion_encoder, clause_encoder, type_decoder: Objects whose
            ``inverse_transform`` maps predicted indices back to labels.
        device: Torch device the models live on.
        nlp: spaCy pipeline used for clause splitting.
        threshold: An edge is labelled 1 when its sigmoid probability is
            strictly greater than this value (default 0.55, as before).

    Returns:
        ``{'clauses': [...], 'cause_pairs': [...]}`` containing only plain
        Python values so the result can be passed to ``json.dump``. Both lists
        are empty when no clause could be extracted.

    Raises:
        TypeError: if ``conversation`` contains something that is neither an
            utterance dict nor a list of them.
    """
    clauses = []
    for turn in _iter_turns(conversation):
        text = turn.get('utterance', '')
        if not text or not text.strip():
            continue  # nothing to split
        cs = extract_clauses(text, nlp)
        cs = remerge_so_and_single_clauses(cs)
        clauses.extend(cs)

    results = {
        'clauses': [],
        'cause_pairs': []
    }
    if not clauses:
        # The tokenizers cannot encode an empty batch.
        return results

    enc_em = tokenizer_emotion(clauses, padding=True, truncation=True, return_tensors='pt', max_length=128).to(device)
    enc_cl = tokenizer_clause(clauses, padding=True, truncation=True, return_tensors='pt', max_length=128).to(device)

    with torch.no_grad():
        logits_em = model_emotion(enc_em['input_ids'], enc_em['attention_mask'])
        logits_cl = model_clause(enc_cl['input_ids'], enc_cl['attention_mask'])
        # First-token hidden state of each clause serves as its graph node feature.
        feats = model_clause.bert(**enc_cl).last_hidden_state[:, 0, :]

    em_preds = logits_em.argmax(dim=1).cpu().numpy()
    cl_preds = logits_cl.argmax(dim=1).cpu().numpy()
    emotions = emotion_encoder.inverse_transform(em_preds)
    clause_types = clause_encoder.inverse_transform(cl_preds)

    # Every ordered pair of distinct clauses is a candidate cause link.
    edge_index = build_graph(feats).to(device)
    with torch.no_grad():
        edge_logits = gat_model(feats, edge_index).cpu()
    edge_type_pred = (torch.sigmoid(edge_logits) > threshold).long().numpy()
    cause_types = type_decoder.inverse_transform(edge_type_pred)

    for idx, cl in enumerate(clauses):
        results['clauses'].append({
            'text': cl,
            'emotion': _to_builtin(emotions[idx]),
            'clause_type': _to_builtin(clause_types[idx])
        })

    # tolist() yields Python ints; NumPy integers are rejected by json.dump.
    rows, cols = edge_index.cpu().tolist()
    for k, (i, j) in enumerate(zip(rows, cols)):
        results['cause_pairs'].append({
            'source_clause': i,
            'target_clause': j,
            'cause_type': _to_builtin(cause_types[k])
        })
    return results


def _load_full_model(path, device):
    """Load a whole pickled ``torch.nn.Module`` from ``path`` in eval mode."""
    # weights_only=False is needed to unpickle a complete module; recent
    # PyTorch releases default to weights_only=True and would refuse it.
    # SECURITY: this executes pickle code - only load trusted checkpoints.
    model = torch.load(path, map_location=device, weights_only=False)
    if not isinstance(model, torch.nn.Module):
        raise TypeError(
            f"{path} contains a {type(model).__name__}, not a full model; "
            "save the model with torch.save(model, path) or rebuild the "
            "architecture and use load_state_dict()."
        )
    model.eval()
    return model


def load_artifacts(args, device):
    """Load tokenizers, label encoders and the three trained models.

    Args:
        args: Namespace with ``tokenizer_emotion``, ``tokenizer_clause``,
            ``encoder_path``, ``emotion_model_path``, ``clause_model_path``
            and ``gat_model_path``.
        device: Torch device the models are mapped to.

    Returns:
        ``(tokenizer_emotion, tokenizer_clause, model_emotion, model_clause,
        gat_model, emotion_encoder, clause_encoder, type_decoder)``; all
        models are in eval mode.

    Raises:
        TypeError: if a model file does not contain a full ``nn.Module``
            (for example when it holds only a state_dict).
        KeyError: if the encoder pickle lacks 'emotion', 'clause' or
            'cause_type'.
    """
    tokenizer_emotion = AutoTokenizer.from_pretrained(args.tokenizer_emotion)
    tokenizer_clause = AutoTokenizer.from_pretrained(args.tokenizer_clause)

    # SECURITY: pickle.load runs arbitrary code from the file; the encoder
    # file must come from a trusted source.
    with open(args.encoder_path, 'rb') as f:
        encoders = pickle.load(f)
    emotion_encoder = encoders['emotion']
    clause_encoder = encoders['clause']
    type_decoder = encoders['cause_type']

    model_emotion = _load_full_model(args.emotion_model_path, device)
    model_clause = _load_full_model(args.clause_model_path, device)
    gat_model = _load_full_model(args.gat_model_path, device)

    return tokenizer_emotion, tokenizer_clause, model_emotion, model_clause, gat_model, emotion_encoder, clause_encoder, type_decoder


def _flatten_turns(turns):
    """Flatten arbitrarily nested lists of utterance dicts into one flat list."""
    flat = []
    for item in turns:
        if isinstance(item, list):
            flat.extend(_flatten_turns(item))
        else:
            flat.append(item)
    return flat


def _json_default(obj):
    """``json.dump`` fallback that converts NumPy-style scalars/arrays to builtins."""
    if hasattr(obj, 'tolist'):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def main():
    """Command-line entry point: label every conversation of a JSON file.

    Reads the conversations from ``--input``, runs ``inference_pipeline`` on
    each one and writes ``{conversation_id: result}`` to ``--output``.
    """
    parser = argparse.ArgumentParser("Emotion-Cause Inference Pipeline")
    parser.add_argument('--input', type=str, required=True,
                        help='Input conversations JSON (e.g. /kaggle/working/clean_conversation.json)')
    parser.add_argument('--output', type=str, required=True,
                        help='Output JSON for the labelled results (e.g. /kaggle/working/clean_conversation_labelled.json)')
    parser.add_argument('--tokenizer-emotion', type=str, required=True,
                        help='Emotion tokenizer name or path, passed to AutoTokenizer.from_pretrained')
    parser.add_argument('--tokenizer-clause', type=str, required=True,
                        help='Clause-type tokenizer name or path, passed to AutoTokenizer.from_pretrained')
    parser.add_argument('--encoder-path', type=str, required=True, help='Pickle file with LabelEncoders')
    parser.add_argument('--emotion-model-path', type=str, required=True, help='Path to emotion classifier .pth')
    parser.add_argument('--clause-model-path', type=str, required=True, help='Path to clause-type classifier .pth')
    parser.add_argument('--gat-model-path', type=str, required=True, help='Path to GAT model .pth')
    parser.add_argument('--spacy-model', type=str, default='en_core_web_sm', help='spaCy model for clause splitting')
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    nlp = spacy.load(args.spacy_model)

    tokenizer_emotion, tokenizer_clause, model_emotion, model_clause, gat_model, emotion_encoder, clause_encoder, type_decoder = load_artifacts(args, device)

    with open(args.input, 'r', encoding='utf-8') as f:
        convs = json.load(f)

    all_results = {}
    for conv_id, turns in convs.items():
        # The cleaned dialogue file stores each id as a list of utterance
        # lists; the pipeline iterates over utterance dicts, so flatten first.
        res = inference_pipeline(
            _flatten_turns(turns), tokenizer_emotion, tokenizer_clause,
            model_emotion, model_clause, gat_model,
            emotion_encoder, clause_encoder, type_decoder,
            device, nlp
        )
        all_results[conv_id] = res

    with open(args.output, 'w', encoding='utf-8') as f:
        # default= converts any NumPy scalars left in the results, which
        # json would otherwise reject.
        json.dump(all_results, f, indent=2, default=_json_default)
    print(f"Results written to {args.output}")

if __name__ == '__main__':
    main()
