In [1]:
# Mount Google Drive at /content/drive; the checkpoint copied in a later cell
# is read from /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!cp /content/drive/MyDrive/coreference/coreference_model.pth .

In [3]:
!pip install streamlit pyngrok

Collecting streamlit
  Downloading streamlit-1.49.1-py3-none-any.whl.metadata (9.5 kB)
Collecting pyngrok
  Downloading pyngrok-7.3.0-py3-none-any.whl.metadata (8.1 kB)
Collecting pydeck<1,>=0.8.0b4 (from streamlit)
  Downloading pydeck-0.9.1-py2.py3-none-any.whl.metadata (4.1 kB)
Downloading streamlit-1.49.1-py3-none-any.whl (10.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.0/10.0 MB[0m [31m80.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pyngrok-7.3.0-py3-none-any.whl (25 kB)
Downloading pydeck-0.9.1-py2.py3-none-any.whl (6.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.9/6.9 MB[0m [31m121.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pyngrok, pydeck, streamlit
Successfully installed pydeck-0.9.1 pyngrok-7.3.0 streamlit-1.49.1


In [9]:
%%writefile app.py
import streamlit as st
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
import networkx as nx
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import plotly.express as px
import pandas as pd
import numpy as np
from matplotlib.patches import FancyBboxPatch, ConnectionPatch
import matplotlib.patches as mpatches

# Define the model class (same as training)
class CoreferenceModel(nn.Module):
    """Pretrained transformer encoder followed by a two-way linear classifier.

    The attribute names (`bert`, `dropout`, `classifier`) are the keys of the
    saved state dict, so they must not be renamed.
    """

    def __init__(self, model_name):
        """Build the encoder named `model_name` and the classification head."""
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(p=0.3)
        self.classifier = nn.Linear(
            in_features=self.bert.config.hidden_size,
            out_features=2,
        )

    def forward(self, input_ids, attention_mask):
        """Return the classifier output computed from the encoder's pooled output."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Dropout is applied to the pooled representation before the linear head.
        return self.classifier(self.dropout(encoded.pooler_output))

@st.cache_resource
def _load_model_cached():
    """Load the checkpoint once per server process; raises on any failure.

    Kept separate from `load_model` so that only a successful load is stored
    by `st.cache_resource` (a raised exception is not cached, so a later rerun
    can retry after the checkpoint file has been put in place).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only runtime.
    # SECURITY NOTE: weights_only=False unpickles arbitrary objects (needed here
    # because the checkpoint stores the tokenizer object). Only load checkpoints
    # from a trusted source.
    checkpoint = torch.load(
        '/content/coreference_model.pth',
        map_location=device,
        weights_only=False,
    )
    model = CoreferenceModel(checkpoint['model_name']).to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()  # inference mode: disables dropout
    return model, checkpoint['tokenizer'], device


def load_model():
    """Load the trained coreference model.

    Returns:
        tuple: `(model, tokenizer, device)` on success. If loading fails, the
        error is shown with `st.error` and `(None, None, None)` is returned.
    """
    try:
        return _load_model_cached()
    except Exception as e:
        st.error(e)
        return None, None, None

def predict_coreference(model, tokenizer, device, token1, token2):
    """Score how likely `token1` and `token2` are to be coreferent.

    NOTE: a second function with this same name is defined further down in
    this file; at import time that later definition replaces this one.

    Returns 0.5 when no model is available, otherwise the softmax probability
    of class index 1 for the pair.
    """
    if model is None:
        return 0.5

    # The pair is fed to the tokenizer as one string with a literal "[SEP]" between the tokens.
    encoded = tokenizer(
        f"{token1} [SEP] {token2}",
        truncation=True,
        padding='max_length',
        max_length=128,
        return_tensors='pt',
    )
    ids = encoded['input_ids'].to(device)
    mask = encoded['attention_mask'].to(device)

    with torch.no_grad():
        logits = model(input_ids=ids, attention_mask=mask)
        class_probs = torch.softmax(logits, dim=1)
        # Index 1 is treated as the "coreferent" class.
        return class_probs[0][1].item()

def create_clusters_from_text(text, model, tokenizer, device, threshold=0.7):
    """Create coreference clusters from input text.

    The text is split into `\\w+` tokens, de-duplicated in order of first
    appearance, and clustered greedily: each not-yet-clustered token seeds a
    cluster, and every later unclustered token is compared against that seed
    only. A token joins when its predicted probability exceeds `threshold`.

    Args:
        text (str): Raw input text.
        model, tokenizer, device: Passed through to `predict_coreference`.
        threshold (float): Minimum probability (exclusive) to join a cluster.

    Returns:
        tuple: `(clusters, token_to_id)`. `clusters` is a list of token lists
        holding only clusters with more than one member. `token_to_id` maps
        every unique token to "T1", "T2", ... in first-appearance order.
        Returns `([], {})` when the text has no tokens.
    """
    import re
    tokens = re.findall(r'\b\w+\b', text)

    if not tokens:
        return [], {}

    # dict.fromkeys de-duplicates while keeping first-appearance order. Iterating
    # a set() here would make both the IDs and the greedy seeds vary between runs.
    unique_tokens = list(dict.fromkeys(tokens))
    token_to_id = {token: f"T{i+1}" for i, token in enumerate(unique_tokens)}

    clusters = []
    used_tokens = set()

    for i, token1 in enumerate(unique_tokens):
        if token1 in used_tokens:
            continue

        current_cluster = [token1]
        used_tokens.add(token1)

        for token2 in unique_tokens[i+1:]:
            if token2 in used_tokens:
                continue

            prob = predict_coreference(model, tokenizer, device, token1, token2)
            if prob > threshold:
                current_cluster.append(token2)
                used_tokens.add(token2)

        # Singletons are not reported as clusters.
        if len(current_cluster) > 1:
            clusters.append(current_cluster)

    return clusters, token_to_id

def create_cluster_table(clusters, token_to_id):
    """Flatten clusters into a DataFrame with one row per (cluster, token).

    Cluster IDs are "C1", "C2", ... in list order; token IDs are looked up in
    `token_to_id`.
    """
    rows = [
        {
            'Cluster ID': f"C{number}",
            'Token': member,
            'Token ID': token_to_id[member],
        }
        for number, members in enumerate(clusters, start=1)
        for member in members
    ]
    return pd.DataFrame(rows)

def create_network_graph(clusters, token_to_id):
    """Build a Plotly figure showing each cluster as a fully connected group.

    Nodes are labelled with token IDs (the original token is in the hover
    text), every pair of tokens within a cluster is joined by an edge, and each
    cluster gets its own colour and legend entry.
    """
    palette = px.colors.qualitative.Set3
    labels = [f"C{number}" for number in range(1, len(clusters) + 1)]
    # Colours cycle through the palette when there are more clusters than colours.
    colour_of = {label: palette[k % len(palette)] for k, label in enumerate(labels)}

    # Build the graph: nodes first, then all within-cluster pairs.
    graph = nx.Graph()
    for label, members in zip(labels, clusters):
        for member in members:
            graph.add_node(token_to_id[member], cluster=label)

        for offset, first in enumerate(members):
            for second in members[offset + 1:]:
                graph.add_edge(token_to_id[first], token_to_id[second])

    layout = nx.spring_layout(graph, k=3, iterations=50)

    # All edges go into one trace; a None entry breaks the line between segments.
    edge_x, edge_y = [], []
    for source, target in graph.edges():
        x0, y0 = layout[source]
        x1, y1 = layout[target]
        edge_x += [x0, x1, None]
        edge_y += [y0, y1, None]

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=2, color='lightgray'),
        hoverinfo='none',
        mode='lines'
    ))

    # One node trace per cluster so that each cluster has its own legend entry.
    for label, members in zip(labels, clusters):
        ids = [token_to_id[member] for member in members]
        points = [layout[node_id] for node_id in ids]

        fig.add_trace(go.Scatter(
            x=[point[0] for point in points],
            y=[point[1] for point in points],
            mode='markers+text',
            text=ids,
            textposition='middle center',
            hovertext=[
                f"Token: {member}<br>ID: {node_id}<br>Cluster: {label}"
                for member, node_id in zip(members, ids)
            ],
            hoverinfo='text',
            marker=dict(
                size=30,
                color=colour_of[label],
                line=dict(width=2, color='black')
            ),
            name=label
        ))

    hint = dict(
        text="Hover over nodes to see original tokens",
        showarrow=False,
        xref="paper", yref="paper",
        x=0.005, y=-0.002,
        xanchor='left', yanchor='bottom',
        font=dict(color='gray', size=12)
    )
    hidden_axis = dict(showgrid=False, zeroline=False, showticklabels=False)

    fig.update_layout(
        title="Coreference Clusters Network",
        showlegend=True,
        hovermode='closest',
        margin=dict(b=20, l=5, r=5, t=40),
        annotations=[hint],
        xaxis=dict(hidden_axis),
        yaxis=dict(hidden_axis),
        plot_bgcolor='white'
    )

    return fig


def create_arc_diagram(text, clusters, token_to_id):
    """Draw the tokens on a horizontal baseline with arcs linking cluster members.

    Tokens are laid out left to right in text order and labelled with their ID
    (or the token itself when it has no ID). Within each cluster, consecutive
    occurrences are connected by a half-ellipse arc.

    Returns:
        The matplotlib figure, or None when the text has no tokens or there
        are no clusters.
    """
    import re
    tokens = re.findall(r'\b\w+\b', text)

    if not (tokens and clusters):
        return None

    fig, ax = plt.subplots(figsize=(16, 6))

    baseline = 0
    palette = ['#2E8B57', '#4682B4', '#DC143C', '#FF8C00', '#9932CC', '#008B8B', '#B22222', '#228B22']
    labels = [token_to_id.get(token, token) for token in tokens]

    # Baseline, with one tick and one label per token.
    ax.axhline(y=baseline, color='black', linewidth=2, alpha=0.8)
    for slot, label in enumerate(labels):
        ax.plot([slot, slot], [baseline-0.05, baseline+0.05], 'k-', linewidth=2)
        ax.text(slot, baseline-0.15, str(label), ha='center', va='top',
               fontsize=12, weight='bold')

    # Arcs: consecutive occurrences within a cluster are chained, not fully connected.
    tallest = 0
    theta = np.linspace(0, np.pi, 100)
    for cluster_idx, members in enumerate(clusters):
        colour = palette[cluster_idx % len(palette)]
        hits = sorted(slot for slot, token in enumerate(tokens) if token in members)

        for left, right in zip(hits, hits[1:]):
            span = right - left
            middle = (left + right) / 2

            # Wider spans and later clusters get taller arcs so they overlap less.
            arc_height = (0.3 + (span * 0.1)) + (cluster_idx * 0.2)
            tallest = max(tallest, arc_height)

            ax.plot(middle + (span/2) * np.cos(theta),
                    baseline + arc_height * np.sin(theta),
                    color=colour, linewidth=3, alpha=0.8)

            # Short stubs mark where the arc meets the baseline.
            for anchor in (left, right):
                ax.plot([anchor, anchor], [baseline, baseline + 0.1],
                       color=colour, linewidth=3, alpha=0.8)

    ax.set_xlim(-0.5, len(tokens) - 0.5)
    ax.set_ylim(-0.3, tallest + 0.3)

    # Hide the default frame and ticks; only the custom drawing remains.
    for side in ('top', 'right', 'bottom', 'left'):
        ax.spines[side].set_visible(False)
    ax.set_xticks([])
    ax.set_yticks([])

    ax.set_title('Coreference Resolution - Arc Diagram', fontsize=16, weight='bold', pad=20)

    # One legend entry per cluster (clusters is non-empty at this point).
    handles = [
        plt.Line2D([0], [0], color=palette[k % len(palette)], lw=3,
                   label=f'Cluster C{k+1}')
        for k in range(len(clusters))
    ]
    ax.legend(handles=handles, loc='upper right', frameon=False)

    plt.tight_layout()
    return fig

import pandas as pd
import torch
import re
from collections import defaultdict

def create_word_cluster_table(clusters, word_to_id):
    """Flatten word clusters into a DataFrame with one row per (cluster, word).

    Cluster IDs are "C1", "C2", ... in list order; word IDs are looked up in
    `word_to_id`.
    """
    records = []
    for number, members in enumerate(clusters, start=1):
        label = f"C{number}"
        records.extend(
            {'Cluster ID': label, 'Word': member, 'Word ID': word_to_id[member]}
            for member in members
        )
    return pd.DataFrame(records)

def extract_words_from_tamil_text(text):
    """Return the words in `text`, in order, dropping single-character ones.

    A word is a run of code points in U+0B80-U+0BFF (Tamil) or a run of ASCII
    letters; digits, punctuation and whitespace act as separators.
    """
    kept = []
    for candidate in re.findall(r'[\u0B80-\u0BFF]+|[a-zA-Z]+', text):
        trimmed = candidate.strip()
        # Runs of length 1 are discarded.
        if len(trimmed) > 1:
            kept.append(trimmed)
    return kept

def predict_coreference(model, tokenizer, device, word1, word2):
    """Score how likely `word1` and `word2` are to be coreferent.

    Without a model, a case-insensitive string heuristic is used: 0.9 for equal
    words, 0.6 when one contains the other, 0.3 otherwise. With a model, the
    result is the softmax probability of class index 1. Any exception during
    tokenization or inference is printed and turned into a score of 0.0.
    """
    if model is None:
        # Demo fallback: plain string similarity.
        left, right = word1.lower(), word2.lower()
        if left == right:
            return 0.9
        if left in right or right in left:
            return 0.6
        return 0.3

    # The pair is fed to the tokenizer as one string with a literal "[SEP]" between the words.
    pair_text = f"{word1} [SEP] {word2}"

    try:
        encoded = tokenizer(
            pair_text,
            truncation=True,
            padding='max_length',
            max_length=128,
            return_tensors='pt',
        )
        ids = encoded['input_ids'].to(device)
        mask = encoded['attention_mask'].to(device)

        with torch.no_grad():
            logits = model(input_ids=ids, attention_mask=mask)
            class_probs = torch.softmax(logits, dim=1)
            # Index 1 is treated as the "coreferent" class.
            score = class_probs[0][1].item()
    except Exception as e:
        print(f"Error predicting coreference: {e}")
        return 0.0

    return score

def create_word_clusters_from_tamil_sentence(text, model, tokenizer, device, threshold=0.7):
    """Create coreference clusters of words from a Tamil sentence.

    Words come from `extract_words_from_tamil_text`, are de-duplicated in order
    of first appearance, and are clustered greedily: each not-yet-clustered
    word seeds a cluster, and every later unclustered word is compared against
    that seed only. A word joins when `predict_coreference` returns more than
    `threshold`.

    Args:
        text (str): Input sentence.
        model, tokenizer, device: Passed through to `predict_coreference`.
        threshold (float): Minimum probability (exclusive) to join a cluster.

    Returns:
        tuple: `(clusters, word_to_id)`. `clusters` is a list of word lists
        holding only clusters with more than one member. `word_to_id` maps
        every unique word to "W1", "W2", ... in first-appearance order.
        Returns `([], {})` when no words are extracted.
    """
    words = extract_words_from_tamil_text(text)

    if not words:
        return [], {}

    print(f"Extracted words: {words}")

    # dict.fromkeys de-duplicates while keeping first-appearance order. Iterating
    # a set() here would make both the IDs and the greedy seeds vary between runs.
    unique_words = list(dict.fromkeys(words))
    word_to_id = {word: f"W{i+1}" for i, word in enumerate(unique_words)}

    clusters = []
    used_words = set()

    for i, word1 in enumerate(unique_words):
        if word1 in used_words:
            continue

        current_cluster = [word1]
        used_words.add(word1)

        for word2 in unique_words[i+1:]:
            if word2 in used_words:
                continue

            prob = predict_coreference(model, tokenizer, device, word1, word2)
            if prob > threshold:
                current_cluster.append(word2)
                used_words.add(word2)

        # Only keep clusters with more than one word.
        if len(current_cluster) > 1:
            clusters.append(current_cluster)

    return clusters, word_to_id

def get_coreference_clusters_tamil_words(text, model=None, tokenizer=None, device=None, threshold=0.7):
    """
    Entry point: turn a Tamil sentence into word coreference clusters plus a
    table of them for display.

    Args:
        text (str): Tamil sentence
        model: Trained coreference model (optional)
        tokenizer: Model tokenizer (optional)
        device: PyTorch device (optional; CUDA if available, else CPU)
        threshold (float): Coreference probability threshold

    Returns:
        clusters (list): List of word clusters
        cluster_df (DataFrame): One row per clustered word. It has the columns
            'Cluster ID', 'Word', 'Word ID' but no rows when nothing clustered,
            and is a completely empty DataFrame when `text` is not a non-blank
            string.
    """
    # Reject non-strings and blank strings up front.
    is_usable = isinstance(text, str) and bool(text.strip())
    if not is_usable:
        return [], pd.DataFrame()

    print(f"Input text: {text}")

    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    clusters, word_to_id = create_word_clusters_from_tamil_sentence(
        text=text,
        model=model,
        tokenizer=tokenizer,
        device=device,
        threshold=threshold,
    )

    if not clusters:
        # Keep the column layout stable even when there is nothing to show.
        return clusters, pd.DataFrame(columns=['Cluster ID', 'Word', 'Word ID'])

    return clusters, create_word_cluster_table(clusters, word_to_id)

def main():
    """Render the Streamlit page: text input, threshold slider and cluster results.

    The model is loaded through `load_model()`. When that fails, a warning is
    shown and the analysis continues without a model. Pressing the button runs
    `get_coreference_clusters_tamil_words` on the entered text with the
    threshold chosen on the slider, then shows the cluster table and one
    expander per cluster.
    """
    st.set_page_config(
        page_title="Coreference Resolution",
        page_icon="🔗",
        layout="wide"
    )

    st.title("🔗 Coreference Resolution System")
    st.markdown("Enter text to identify coreference clusters and visualize relationships between entities.")

    # Load model
    model, tokenizer, device = load_model()

    if model is None:
        st.warning("Model not loaded. Running in demo mode with sample predictions.")

    # Input section
    st.header("Input Text")
    input_text = st.text_area(
        "Enter your text here:",
        height=150,
        placeholder="Enter text containing entities that might refer to the same thing..."
    )

    # Threshold slider
    threshold = st.slider("Coreference Threshold", 0.1, 0.9, 0.7, 0.1)

    if st.button("Analyze Coreference", type="primary"):
        if input_text.strip():
            with st.spinner("Analyzing coreference relationships..."):
                # Use the slider value: a hard-coded threshold here would make
                # the slider (and the "try lowering the threshold" hint) a no-op.
                clusters, cluster_df = get_coreference_clusters_tamil_words(
                    text=input_text,
                    model=model,
                    tokenizer=tokenizer,
                    device=device,
                    threshold=threshold
                )

                if clusters:
                    st.success(f"Found {len(clusters)} coreference clusters!")

                    st.header("📊 Cluster Table")
                    st.dataframe(cluster_df, use_container_width=True)

                    # Cluster details
                    st.header("📝 Cluster Details")
                    for i, cluster in enumerate(clusters):
                        with st.expander(f"Cluster C{i+1} - {len(cluster)} tokens"):
                            # Same 1-based label as the expander title and the table.
                            st.write(f"  Cluster C{i+1}: {cluster}")

                else:
                    st.info("No coreference clusters found. Try lowering the threshold or using different text.")

        else:
            st.error("Please enter some text to analyze.")

    # Sample texts
    # st.header("💡 Sample Texts")
    # col1, col2 = st.columns(2)

    # with col1:
    #     if st.button("Load Sample 1"):
    #         sample1 = "John went to the store. He bought some milk. The man was happy with his purchase."
    #         st.session_state['sample_text'] = sample1

    # with col2:
    #     if st.button("Load Sample 2"):
    #         sample2 = "Mary is a teacher. She works at the school. The woman loves her job very much."
    #         st.session_state['sample_text'] = sample2

    # # Load sample text if button was clicked
    # if 'sample_text' in st.session_state:
    #     st.text_area("Sample loaded:", value=st.session_state['sample_text'], key="sample_display")
    #     if st.button("Use This Sample"):
    #         st.session_state['input_text'] = st.session_state['sample_text']
    #         st.rerun()

# Entry point: build the page when this file is run as a script
# (the notebook launches it with `streamlit run app.py`).
if __name__ == "__main__":
    main()

Overwriting app.py


In [10]:
import os
from getpass import getpass

from pyngrok import ngrok

# Never commit the ngrok auth token to the notebook. Read it from the
# NGROK_AUTHTOKEN environment variable, or prompt for it without echoing.
# (The token that was previously hard-coded here should be revoked in the
# ngrok dashboard, since it has been exposed in the notebook source.)
ngrok_key = os.environ.get("NGROK_AUTHTOKEN") or getpass("ngrok auth token: ")
port = 8501  # port the Streamlit app is expected to listen on

ngrok.set_auth_token(ngrok_key)
ngrok.connect(port).public_url

'https://8069cf7895b5.ngrok-free.app'

In [None]:
!rm -rf logs.txt && streamlit run app.py &>/content/logs.txt

In [None]:
# tamil_samples = {
#     "Tamil Sample 1": """
#     உன்னைப் பார்த்தால் தெரியவில்லையா ? நீ தப்பி ஓடி ஒளிந்து கொள்ள வந்திருக்கிறவன் என்று நேற்றைக்கே ஊகித்தேன் .
#     """,

#     "Tamil Sample 2": """
#     அவன் பள்ளிக்குச் சென்றான் . அவன் நல்ல மாணவன் . பையன் படிப்பில் சிறந்து விளங்குகிறான் .
#     """,

#     "Tamil Sample 3": """
#     பெண் கடைக்குச் சென்றாள் . அவள் பழம் வாங்கினாள் . பெண்மணி மகிழ்ச்சியாக இருந்தாள் .
#     """
# }
