In [1]:
!pip install gradio
import os
import torch
import numpy as np
import pandas as pd
from typing import List, Dict, Optional, Union
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel, Trainer, TrainingArguments
from sentence_transformers import SentenceTransformer
import logging
from sklearn.metrics.pairwise import cosine_similarity
import datasets
from torch.utils.data import Dataset
import json
import gradio as gr

# Configure logging: INFO level, each record shows timestamp, logger name, severity and message
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Module-level logger shared by the classes and helpers defined below
logger = logging.getLogger(__name__)

class ArgumentRelationDataset(Dataset):
    """Torch dataset that tokenizes argument pairs for relation classification.

    Every record in ``data`` is a mapping providing ``source_text``,
    ``target_text`` and ``label_id``; a ``context`` entry is optional.
    """

    def __init__(self, data, tokenizer, max_length=512):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data[idx]
        source_text = record["source_text"]
        target_text = record["target_text"]
        context_text = record.get("context", "")

        # The context line is only emitted when a non-empty context exists.
        pair_text = f"Source: {source_text}\nTarget: {target_text}"
        model_text = f"Context: {context_text}\n{pair_text}" if context_text else pair_text

        tokenized = self.tokenizer(
            model_text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        # The tokenizer returns tensors with a leading batch axis of size 1;
        # drop it so a DataLoader can stack the samples itself.
        sample = {}
        for field_name, tensor in tokenized.items():
            sample[field_name] = tensor.squeeze(0)

        sample["labels"] = torch.tensor(record["label_id"])
        return sample

class EnhancedRBAMModel:
    """
    Enhanced Relation-Based Argument Mining Model
    
    This model combines transformer-based classification with semantic similarity
    analysis and contextual argument clustering to provide more accurate relation
    detection between arguments.
    """
    
    def __init__(self, 
                 model_path: str = "models/deberta-v3-large-relation-finetuned",
                 embedding_model: str = "all-MiniLM-L6-v2",
                 device: Optional[str] = None):
        """
        Initialize the enhanced RBAM model.
        
        Args:
            model_path: Path to the fine-tuned classification model
            embedding_model: SentenceTransformer model for semantic analysis
            device: Device to run the model on (None for auto-detection)

        Raises:
            Exception: If the tokenizer or classification model cannot be
                loaded. The original error is chained as ``__cause__``.
        """
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {self.device}")
        self.model_path = model_path
        
        # Load classification model
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            self.model = AutoModelForSequenceClassification.from_pretrained(model_path).to(self.device)
            self.model.eval()
            logger.info(f"Successfully loaded classification model from {model_path}")
            
            # Determine relation labels from model config; fall back to the
            # default label set when the config carries no (or an empty) mapping.
            self.id2label = getattr(self.model.config, 'id2label', None) or {
                0: "support", 1: "attack", 2: "neutral", 3: "detail"
            }
            self.label2id = {v: k for k, v in self.id2label.items()}
            logger.info(f"Loaded relation labels: {self.id2label}")
            
        except Exception as e:
            logger.error(f"Failed to load classification model: {e}")
            # Chain the cause so the underlying traceback is not lost.
            raise Exception(f"Model initialization failed: {str(e)}") from e
            
        # Load embedding model for semantic analysis. This is best-effort:
        # without it, predictions simply carry no similarity feature.
        try:
            self.embedding_model = SentenceTransformer(embedding_model).to(self.device)
            logger.info(f"Successfully loaded embedding model: {embedding_model}")
        except Exception as e:
            logger.error(f"Failed to load embedding model: {e}")
            self.embedding_model = None
            
        # Initialize cache
        self.cache = {}
        
    def train_model(self, train_data: List[Dict], validation_data: Optional[List[Dict]] = None, 
                   output_dir: str = "./model_output", epochs: int = 3, batch_size: int = 16):
        """
        Train or fine-tune the model on new data.
        
        Args:
            train_data: List of dictionaries with source_text, target_text, optional context, and label_id
            validation_data: Optional validation data with the same format as train_data
            output_dir: Directory to save the trained model
            epochs: Number of training epochs
            batch_size: Batch size for training
        """
        logger.info(f"Starting model training with {len(train_data)} examples")
        
        # Create datasets
        train_dataset = ArgumentRelationDataset(train_data, self.tokenizer)
        eval_dataset = ArgumentRelationDataset(validation_data, self.tokenizer) if validation_data else None
        
        # Define training arguments
        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=epochs,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            warmup_steps=500,
            weight_decay=0.01,
            logging_dir=f"{output_dir}/logs",
            logging_steps=100,
            evaluation_strategy="epoch" if eval_dataset else "no",
            save_strategy="epoch",
            load_best_model_at_end=True if eval_dataset else False,
        )
        
        # Create trainer
        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset
        )
        
        # Train model
        trainer.train()
        
        # Save model
        self.model.save_pretrained(output_dir)
        self.tokenizer.save_pretrained(output_dir)
        
        # Update current model
        self.model = trainer.model
        logger.info(f"Training completed, model saved to {output_dir}")
        
        # Clear cache after training
        self.cache = {}
        
        return {
            "status": "success",
            "message": f"Model trained on {len(train_data)} examples and saved to {output_dir}"
        }
        
    def predict_relation(self, 
                        source_text: str, 
                        target_text: str, 
                        context: Optional[str] = None) -> Dict:
        """
        Predict the relation between source and target arguments.
        
        Args:
            source_text: The source argument text
            target_text: The target argument text
            context: Optional context for the arguments
            
        Returns:
            Dictionary with relation type, confidence score, and features
        """
        # Create cache key
        cache_key = f"{source_text[:100]}|{target_text[:100]}|{context[:100] if context else 'no_context'}"
        
        # Check if result is cached
        if cache_key in self.cache:
            return self.cache[cache_key]
            
        # Prepare input
        input_text = self._prepare_input(source_text, target_text, context)
        inputs = self.tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512).to(self.device)
        
        # Get model prediction
        with torch.no_grad():
            outputs = self.model(**inputs)
            logits = outputs.logits
            probabilities = torch.nn.functional.softmax(logits, dim=1)[0]
            prediction = torch.argmax(probabilities).item()
            confidence = probabilities[prediction].item()
            
        # Get additional semantic features if embedding model is available
        semantic_features = {}
        if self.embedding_model:
            semantic_features = self._get_semantic_features(source_text, target_text)
        
        # Ensemble the predictions
        relation_type = self.id2label[prediction]
        
        # Apply confidence adjustment based on semantic features
        if semantic_features and 'similarity' in semantic_features:
            # Adjust confidence based on semantic similarity
            if relation_type in ['support', 'detail'] and semantic_features['similarity'] < 0.3:
                confidence = confidence * 0.8  # Reduce confidence for support/detail with low similarity
            elif relation_type == 'attack' and semantic_features['similarity'] > 0.8:
                confidence = confidence * 0.8  # Reduce confidence for attack with high similarity
        
        # Build result
        result = {
            'relation_type': relation_type,
            'confidence': confidence,
            'features': {
                'probabilities': {self.id2label[i]: prob.item() for i, prob in enumerate(probabilities)},
                **semantic_features
            }
        }
        
        # Cache the result
        self.cache[cache_key] = result
        
        return result
    
    def _prepare_input(self, source_text: str, target_text: str, context: Optional[str] = None) -> str:
        """
        Prepare input text for the model.
        """
        if context:
            return f"Context: {context}\nSource: {source_text}\nTarget: {target_text}"
        else:
            return f"Source: {source_text}\nTarget: {target_text}"
    
    def _get_semantic_features(self, source_text: str, target_text: str) -> Dict:
        """
        Extract semantic features for the argument pair.
        """
        # Get embeddings
        source_embedding = self.embedding_model.encode(source_text, convert_to_tensor=True)
        target_embedding = self.embedding_model.encode(target_text, convert_to_tensor=True)
        
        # Calculate similarity
        similarity = torch.cosine_similarity(source_embedding.unsqueeze(0), 
                                           target_embedding.unsqueeze(0)).item()
        
        return {
            'similarity': similarity,
        }
    
    def batch_predict(self, argument_pairs: List[Dict]) -> List[Dict]:
        """
        Predict relations for multiple argument pairs.
        
        Args:
            argument_pairs: List of dictionaries with source_text, target_text, and optional context
            
        Returns:
            List of prediction dictionaries
        """
        results = []
        
        for pair in argument_pairs:
            source = pair.get('source_text')
            target = pair.get('target_text')
            context = pair.get('context')
            
            if not source or not target:
                results.append({'error': 'Missing source or target text'})
                continue
                
            result = self.predict_relation(source, target, context)
            results.append(result)
            
        return results
    
    def analyze_graph(self, argument_pairs: List[Dict]) -> Dict:
        """
        Perform graph-based analysis of the argument network.
        
        Args:
            argument_pairs: List of dictionaries with source_text, target_text, and optional context
            
        Returns:
            Dictionary with graph analysis results
        """
        # Get predictions for all pairs
        predictions = self.batch_predict(argument_pairs)
        
        # Extract unique arguments
        unique_args = set()
        for pair in argument_pairs:
            unique_args.add(pair['source_text'])
            unique_args.add(pair['target_text'])
        
        # Build adjacency matrix
        n_args = len(unique_args)
        arg_list = list(unique_args)
        arg_to_idx = {arg: i for i, arg in enumerate(arg_list)}
        
        # Initialize matrices
        support_matrix = np.zeros((n_args, n_args))
        attack_matrix = np.zeros((n_args, n_args))
        
        # Fill matrices
        for i, pair in enumerate(argument_pairs):
            source_idx = arg_to_idx[pair['source_text']]
            target_idx = arg_to_idx[pair['target_text']]
            rel_type = predictions[i]['relation_type']
            confidence = predictions[i]['confidence']
            
            if rel_type == 'support':
                support_matrix[source_idx, target_idx] = confidence
            elif rel_type == 'attack':
                attack_matrix[source_idx, target_idx] = confidence
        
        # Calculate centrality metrics
        support_centrality = np.sum(support_matrix, axis=0)
        attack_centrality = np.sum(attack_matrix, axis=0)
        
        # Identify key arguments
        key_args = []
        for i, arg in enumerate(arg_list):
            total_centrality = support_centrality[i] + attack_centrality[i]
            if total_centrality > 0:
                key_args.append({
                    'text': arg[:100] + '...' if len(arg) > 100 else arg,
                    'support_centrality': float(support_centrality[i]),
                    'attack_centrality': float(attack_centrality[i]),
                    'total_centrality': float(total_centrality)
                })
        
        # Sort by total centrality
        key_args.sort(key=lambda x: x['total_centrality'], reverse=True)
        
        # Identify inconsistencies
        inconsistencies = self._identify_inconsistencies(argument_pairs, predictions)
        
        return {
            'key_arguments': key_args[:5],  # Top 5 key arguments
            'argument_count': len(unique_args),
            'relation_count': len(argument_pairs),
            'inconsistencies': inconsistencies
        }
    
    def _identify_inconsistencies(self, argument_pairs: List[Dict], predictions: List[Dict]) -> List[Dict]:
        """
        Identify potential inconsistencies in the argument network.
        """
        inconsistencies = []
        
        # Build graph
        graph = {}
        for i, pair in enumerate(argument_pairs):
            source = pair['source_text']
            target = pair['target_text']
            rel_type = predictions[i]['relation_type']
            
            if source not in graph:
                graph[source] = {}
            if target not in graph:
                graph[target] = {}
                
            graph[source][target] = rel_type
        
        # Check for circular support/attack patterns
        for source in graph:
            for target in graph.get(source, {}):
                rel_type = graph[source][target]
                
                # Check if there's a reverse relation
                if target in graph and source in graph.get(target, {}):
                    reverse_rel = graph[target][source]
                    
                    # Check if there's an inconsistency
                    if rel_type == reverse_rel and rel_type in ['support', 'attack']:
                        inconsistencies.append({
                            'type': 'circular_relation',
                            'source': source[:50] + '...' if len(source) > 50 else source,
                            'target': target[:50] + '...' if len(target) > 50 else target,
                            'relation': rel_type,
                            'description': f"Circular {rel_type} relation detected"
                        })
        
        # Check for support-attack contradictions
        for arg1 in graph:
            for arg2 in graph.get(arg1, {}):
                rel1 = graph[arg1][arg2]
                
                # Check if arg1 and arg2 have relations to the same arg3
                for arg3 in graph:
                    if arg3 != arg1 and arg3 != arg2 and arg2 in graph.get(arg3, {}):
                        rel2 = graph[arg3][arg2]
                        
                        # Check if one supports while the other attacks
                        if rel1 == 'support' and rel2 == 'attack' or rel1 == 'attack' and rel2 == 'support':
                            inconsistencies.append({
                                'type': 'competing_relations',
                                'arguments': [
                                    arg1[:50] + '...' if len(arg1) > 50 else arg1,
                                    arg2[:50] + '...' if len(arg2) > 50 else arg2,
                                    arg3[:50] + '...' if len(arg3) > 50 else arg3
                                ],
                                'description': f"Competing {rel1}/{rel2} relations detected"
                            })
        
        return inconsistencies

# Create a function to run the interactive web UI
def run_interactive_rbam_system():
    """
    Build and launch the Gradio UI for training, prediction and graph analysis.

    A DistilBERT checkpoint with a four-way relation head
    (support / attack / neutral / detail) is bootstrapped into
    ``models/distilbert-rbam`` together with its tokenizer, then loaded
    through ``EnhancedRBAMModel``. The head is randomly initialised until the
    model is trained from the "Model Training" tab.

    Returns:
        A status string once ``demo.launch`` returns.
    """
    base_model_name = "distilbert-base-uncased"
    local_model_dir = "models/distilbert-rbam"
    id2label = {0: "support", 1: "attack", 2: "neutral", 3: "detail"}
    label2id = {label: idx for idx, label in id2label.items()}

    # Loading the bare base checkpoint directly yields a 2-class head with
    # LABEL_0/LABEL_1 names, so none of the support/attack logic would ever
    # trigger and label ids 2-3 would break training. Always start from a
    # local copy that has the 4-label head AND the tokenizer files, since
    # EnhancedRBAMModel loads both from the same path.
    required_files = ("config.json", "tokenizer_config.json")
    bootstrapped = all(
        os.path.isfile(os.path.join(local_model_dir, file_name))
        for file_name in required_files
    )
    if not bootstrapped:
        base_model = AutoModelForSequenceClassification.from_pretrained(
            base_model_name,
            num_labels=len(id2label),
            id2label=id2label,
            label2id=label2id,
        )
        base_tokenizer = AutoTokenizer.from_pretrained(base_model_name)

        os.makedirs(local_model_dir, exist_ok=True)
        base_model.save_pretrained(local_model_dir)
        base_tokenizer.save_pretrained(local_model_dir)

    rbam_model = EnhancedRBAMModel(model_path=local_model_dir)
    model_status = "Initialized with base model (requires training)"
    
    # Sample training data
    sample_train_data = [
        {
            "source_text": "Climate change is primarily caused by human activities.",
            "target_text": "The rise in global temperatures correlates with increased CO2 emissions.",
            "context": "Environmental science debate",
            "label_id": 0  # support
        },
        {
            "source_text": "Renewable energy is too expensive to replace fossil fuels.",
            "target_text": "Solar panel costs have decreased by 90% in the last decade.",
            "context": "Energy policy discussion",
            "label_id": 1  # attack
        },
        {
            "source_text": "Excessive social media use is harmful to mental health.",
            "target_text": "Studies show correlation between screen time and anxiety in teens.",
            "context": "Public health forum",
            "label_id": 0  # support
        }
    ]
    
    # Function to train the model
    def train_model_from_ui(train_data_json, epochs, batch_size):
        try:
            # Parse training data
            train_data = json.loads(train_data_json)
            
            # Validate data format
            if not isinstance(train_data, list):
                return "Error: Training data must be a list of examples"
            
            valid_label_ids = set(rbam_model.id2label)
            for item in train_data:
                # Non-dict items would turn `k in item` into a substring or
                # membership test and slip through.
                if not isinstance(item, dict) or not all(
                    k in item for k in ["source_text", "target_text", "label_id"]
                ):
                    return "Error: Each training example must have source_text, target_text, and label_id"
                
                label = item["label_id"]
                if not isinstance(label, int) or isinstance(label, bool) or label not in valid_label_ids:
                    # An out-of-range label makes the loss computation fail
                    # deep inside training; reject it up front.
                    return f"Error: label_id must be one of {sorted(valid_label_ids)}"
            
            # Train the model
            result = rbam_model.train_model(
                train_data=train_data,
                epochs=int(epochs),
                batch_size=int(batch_size),
                output_dir="models/rbam-custom"
            )
            
            return f"Training completed: {result['message']}"
        except Exception as e:
            return f"Training error: {str(e)}"
    
    # Function to predict relations
    def predict_relation_from_ui(source_text, target_text, context):
        if not source_text or not target_text:
            return "Error: Source and target arguments are required"
        
        # Report failures in the output panel, like the other callbacks do.
        try:
            result = rbam_model.predict_relation(
                source_text=source_text,
                target_text=target_text,
                context=context if context else None
            )
            
            # Format the result for display
            formatted_result = f"""
## Prediction Result

**Relation Type:** {result['relation_type']}
**Confidence:** {result['confidence']:.4f} ({result['confidence']*100:.1f}%)

### Probability Distribution:
"""
            
            # Add probability distribution
            probs = result['features']['probabilities']
            for rel, prob in probs.items():
                formatted_result += f"- {rel}: {prob:.4f} ({prob*100:.1f}%)\n"
            
            # Add semantic similarity if available
            if 'similarity' in result['features']:
                formatted_result += f"\n**Semantic Similarity:** {result['features']['similarity']:.4f}\n"
                
            return formatted_result
        except Exception as e:
            return f"Prediction error: {str(e)}"
    
    # Function to analyze argument graph
    def analyze_graph_from_ui(argument_pairs_json):
        try:
            # Parse argument pairs
            argument_pairs = json.loads(argument_pairs_json)
            
            # Validate data format
            if not isinstance(argument_pairs, list):
                return "Error: Argument pairs must be a list"
            
            for item in argument_pairs:
                if not isinstance(item, dict) or not all(k in item for k in ["source_text", "target_text"]):
                    return "Error: Each pair must have source_text and target_text"
            
            # Analyze the graph
            result = rbam_model.analyze_graph(argument_pairs)
            
            # Format the result for display
            formatted_result = f"""
## Graph Analysis Result

**Arguments:** {result['argument_count']}
**Relations:** {result['relation_count']}

### Key Arguments:
"""
            
            # Add key arguments
            for i, arg in enumerate(result['key_arguments']):
                formatted_result += f"{i+1}. **{arg['text']}**\n"
                formatted_result += f"   - Support: {arg['support_centrality']:.2f}\n"
                formatted_result += f"   - Attack: {arg['attack_centrality']:.2f}\n"
                formatted_result += f"   - Total: {arg['total_centrality']:.2f}\n\n"
            
            # Add inconsistencies
            formatted_result += f"\n### Inconsistencies ({len(result['inconsistencies'])}):\n"
            
            for i, inconsistency in enumerate(result['inconsistencies']):
                formatted_result += f"{i+1}. **{inconsistency['type']}**\n"
                formatted_result += f"   {inconsistency['description']}\n\n"
            
            return formatted_result
        except Exception as e:
            return f"Analysis error: {str(e)}"
    
    # Create the Gradio interface
    with gr.Blocks(title="Enhanced RBAM System") as demo:
        gr.Markdown("# Enhanced Relation-Based Argument Mining (RBAM) System")
        gr.Markdown(f"**Model Status:** {model_status}")
        
        with gr.Tab("Model Training"):
            gr.Markdown("### Train the RBAM Model")
            gr.Markdown("Use this tab to fine-tune the model on your own data.")
            
            train_data_input = gr.Textbox(
                label="Training Data (JSON format)",
                placeholder=json.dumps(sample_train_data, indent=2),
                lines=10
            )
            
            with gr.Row():
                epochs_input = gr.Number(label="Epochs", value=3, minimum=1)
                batch_size_input = gr.Number(label="Batch Size", value=8, minimum=1)
            
            train_button = gr.Button("Train Model")
            train_output = gr.Textbox(label="Training Result")
            
            train_button.click(
                fn=train_model_from_ui,
                inputs=[train_data_input, epochs_input, batch_size_input],
                outputs=train_output
            )
        
        with gr.Tab("Relation Prediction"):
            gr.Markdown("### Predict Relations Between Arguments")
            
            source_input = gr.Textbox(label="Source Argument", lines=4)
            target_input = gr.Textbox(label="Target Argument", lines=4)
            context_input = gr.Textbox(label="Context (optional)", lines=2)
            
            predict_button = gr.Button("Predict Relation")
            prediction_output = gr.Markdown(label="Prediction Result")
            
            predict_button.click(
                fn=predict_relation_from_ui,
                inputs=[source_input, target_input, context_input],
                outputs=prediction_output
            )
        
        with gr.Tab("Graph Analysis"):
            gr.Markdown("### Analyze Argument Network")
            
            sample_graph_data = [
                {"source_text": "Climate change is primarily caused by human activities.", 
                 "target_text": "The rise in global temperatures correlates with increased CO2 emissions."},
                {"source_text": "Renewable energy is too expensive to replace fossil fuels.", 
                 "target_text": "Solar panel costs have decreased by 90% in the last decade."},
                {"source_text": "Solar panel costs have decreased by 90% in the last decade.", 
                 "target_text": "Renewable energy is becoming more economically viable."}
            ]
            
            graph_data_input = gr.Textbox(
                label="Argument Pairs (JSON format)",
                placeholder=json.dumps(sample_graph_data, indent=2),
                lines=10
            )
            
            analyze_button = gr.Button("Analyze Graph")
            analysis_output = gr.Markdown(label="Analysis Result")
            
            analyze_button.click(
                fn=analyze_graph_from_ui,
                inputs=graph_data_input,
                outputs=analysis_output
            )
    
    # Launch the interface.
    # NOTE(review): share=True publishes a public URL through which anyone
    # with the link can trigger training on this machine - confirm this is
    # intended outside of throwaway notebook sessions.
    demo.launch(share=True)
    return "Interactive RBAM system started"

# Run the system when this cell is executed
if __name__ == "__main__":
    # Install required packages if not already installed.
    # NOTE: the top-of-file `import gradio as gr` runs before this guard, so
    # this branch only helps when the module-level imports are themselves
    # deferred or the cell is re-run after installation.
    try:
        import gradio
    except ImportError:
        import subprocess
        import sys

        print("Installing required packages...")
        # pip.main() is an internal pip API; pip's supported programmatic
        # entry point is running it as a subprocess of the current interpreter.
        subprocess.check_call([
            sys.executable, '-m', 'pip', 'install',
            'gradio', 'transformers', 'sentence-transformers', 'torch', 'datasets',
        ])
    
    print("Starting Enhanced RBAM Interactive System...")
    run_interactive_rbam_system()



2025-08-01 19:00:43.465691: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1754074843.726232      37 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1754074843.795063      37 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Starting Enhanced RBAM Interactive System...


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

* Running on local URL:  http://127.0.0.1:7860
* Running on public URL: https://5097fa65f43e6b9303.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


Code with file loading access

In [4]:
!pip install gradio
import os
import torch
import numpy as np
import pandas as pd
from typing import List, Dict, Optional, Union
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel, Trainer, TrainingArguments
from sentence_transformers import SentenceTransformer
import logging
from sklearn.metrics.pairwise import cosine_similarity
import datasets
from torch.utils.data import Dataset
import json
import gradio as gr

# Configure logging: INFO level, each record shows timestamp, logger name, severity and message
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Module-level logger shared by the classes and helpers defined below
logger = logging.getLogger(__name__)

class ArgumentRelationDataset(Dataset):
    """Torch dataset that tokenizes argument pairs for relation classification.

    Every record in ``data`` is a mapping providing ``source_text``,
    ``target_text`` and ``label_id``; a ``context`` entry is optional.
    """

    def __init__(self, data, tokenizer, max_length=512):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data[idx]
        source_text = record["source_text"]
        target_text = record["target_text"]
        context_text = record.get("context", "")

        # The context line is only emitted when a non-empty context exists.
        pair_text = f"Source: {source_text}\nTarget: {target_text}"
        model_text = f"Context: {context_text}\n{pair_text}" if context_text else pair_text

        tokenized = self.tokenizer(
            model_text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        # The tokenizer returns tensors with a leading batch axis of size 1;
        # drop it so a DataLoader can stack the samples itself.
        sample = {}
        for field_name, tensor in tokenized.items():
            sample[field_name] = tensor.squeeze(0)

        sample["labels"] = torch.tensor(record["label_id"])
        return sample

def load_data_from_json_files(file_paths: List[str]) -> List[Dict]:
    """
    Load and combine data from multiple JSON files.

    Supported top-level layouts per file:
      - a list of entries: all entries are added
      - an object with a list under ``"items"``: those entries are added
      - any other object: added as a single entry

    Files that cannot be read or parsed, or whose top-level value is neither
    a list nor an object, are logged and skipped; loading continues with the
    remaining files.
    
    Args:
        file_paths: List of paths to JSON files
        
    Returns:
        Combined list of data entries from all files, in file order
    """
    # Same logger object as the module-level one (loggers are keyed by name).
    log = logging.getLogger(__name__)
    combined_data = []
    
    for file_path in file_paths:
        # Keep the try body to the I/O and parsing that can actually fail.
        # JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            log.error(f"Error loading {file_path}: {str(e)}")
            continue
        
        # Handle both list and dictionary formats
        if isinstance(data, list):
            entries = data
        elif isinstance(data, dict) and isinstance(data.get('items'), list):
            entries = data['items']
        elif isinstance(data, dict):
            entries = [data]
        else:
            log.warning(f"Skipping {file_path}: unsupported top-level JSON type {type(data).__name__}")
            continue
        
        combined_data.extend(entries)
        # Report the number of entries actually added from this file.
        log.info(f"Loaded {len(entries)} items from {file_path}")
    
    return combined_data

class EnhancedRBAMModel:
    """
    Enhanced Relation-Based Argument Mining Model

    Wraps a transformer sequence classifier that labels the relation between a
    source and a target argument. When a sentence-embedding model is available,
    the classifier's confidence is tempered by the cosine similarity of the two
    arguments. Also offers simple graph-level analysis (incoming-relation
    centrality and inconsistency detection) over many argument pairs.
    """

    # Label mapping used when the loaded model config does not provide one.
    DEFAULT_ID2LABEL = {0: "support", 1: "attack", 2: "neutral", 3: "detail"}

    def __init__(self,
                 model_path: str = "models/deberta-v3-large-relation-finetuned",
                 embedding_model: str = "all-MiniLM-L6-v2",
                 device: Optional[str] = None):
        """
        Initialize the enhanced RBAM model.

        Args:
            model_path: Path to the fine-tuned classification model
            embedding_model: SentenceTransformer model for semantic analysis
            device: Device to run the model on (None for auto-detection)

        Raises:
            RuntimeError: If the tokenizer or classification model cannot be
                loaded. A failure to load the embedding model is only logged;
                predictions then run without semantic features.
        """
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {self.device}")

        # Load classification model
        try:
            self.model_path = model_path
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            self.model = AutoModelForSequenceClassification.from_pretrained(model_path).to(self.device)
            self.model.eval()
            logger.info(f"Successfully loaded classification model from {model_path}")

            # Determine relation labels from model config
            self.id2label = getattr(self.model.config, 'id2label', None) or dict(self.DEFAULT_ID2LABEL)
            self.label2id = {v: k for k, v in self.id2label.items()}
            logger.info(f"Loaded relation labels: {self.id2label}")

        except Exception as e:
            logger.error(f"Failed to load classification model: {e}")
            # Chain the cause so the original traceback stays visible.
            raise RuntimeError(f"Model initialization failed: {str(e)}") from e

        # Load embedding model for semantic analysis (optional component)
        try:
            self.embedding_model = SentenceTransformer(embedding_model).to(self.device)
            logger.info(f"Successfully loaded embedding model: {embedding_model}")
        except Exception as e:
            logger.error(f"Failed to load embedding model: {e}")
            self.embedding_model = None

        # Prediction cache keyed by (source_text, target_text, context)
        self.cache = {}

    def train_model(self, train_data: List[Dict], validation_data: Optional[List[Dict]] = None,
                   output_dir: str = "./model_output", epochs: int = 3, batch_size: int = 16):
        """
        Train or fine-tune the model on new data.

        Args:
            train_data: List of dictionaries with source_text, target_text, optional context, and label_id
            validation_data: Optional validation data with the same format as train_data
            output_dir: Directory to save the trained model
            epochs: Number of training epochs
            batch_size: Batch size for training

        Returns:
            Dictionary with a ``status`` and a human-readable ``message``.

        Raises:
            ValueError: If ``train_data`` is empty.
        """
        if not train_data:
            raise ValueError("train_data must contain at least one example")

        logger.info(f"Starting model training with {len(train_data)} examples")

        # Create datasets
        train_dataset = ArgumentRelationDataset(train_data, self.tokenizer)
        eval_dataset = ArgumentRelationDataset(validation_data, self.tokenizer) if validation_data else None

        # Define training arguments.
        # NOTE(review): newer transformers releases name this argument
        # `eval_strategy` -- confirm against the installed version.
        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=epochs,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            warmup_steps=500,
            weight_decay=0.01,
            logging_dir=f"{output_dir}/logs",
            logging_steps=100,
            evaluation_strategy="epoch" if eval_dataset else "no",
            save_strategy="epoch",
            load_best_model_at_end=True if eval_dataset else False,
        )

        # Create trainer
        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset
        )

        # Train model
        trainer.train()

        # Adopt the trained weights and switch back to inference mode: training
        # puts the network in train mode, which would keep dropout active and
        # make later predictions nondeterministic.
        self.model = trainer.model
        self.model.eval()

        # Save model
        self.model.save_pretrained(output_dir)
        self.tokenizer.save_pretrained(output_dir)
        logger.info(f"Training completed, model saved to {output_dir}")

        # Cached predictions were produced by the old weights
        self.cache = {}

        return {
            "status": "success",
            "message": f"Model trained on {len(train_data)} examples and saved to {output_dir}"
        }

    def predict_relation(self,
                        source_text: str,
                        target_text: str,
                        context: Optional[str] = None) -> Dict:
        """
        Predict the relation between source and target arguments.

        Args:
            source_text: The source argument text
            target_text: The target argument text
            context: Optional context for the arguments

        Returns:
            Dictionary with ``relation_type``, ``confidence`` and ``features``
            (per-label ``probabilities`` plus ``similarity`` when the embedding
            model is available). Results are cached, so repeated calls with the
            same inputs return the same dictionary object.
        """
        # Key on the complete texts: truncated keys would let two different
        # inputs that share a prefix return each other's prediction.
        cache_key = (source_text, target_text, context or "")

        # Check if result is cached
        if cache_key in self.cache:
            return self.cache[cache_key]

        # Prepare input
        input_text = self._prepare_input(source_text, target_text, context)
        inputs = self.tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512).to(self.device)

        # Get model prediction
        with torch.no_grad():
            outputs = self.model(**inputs)
            logits = outputs.logits
            probabilities = torch.nn.functional.softmax(logits, dim=1)[0]
            prediction = torch.argmax(probabilities).item()
            confidence = probabilities[prediction].item()

        # Get additional semantic features if embedding model is available
        semantic_features = {}
        if self.embedding_model is not None:
            semantic_features = self._get_semantic_features(source_text, target_text)

        relation_type = self.id2label[prediction]

        # Temper the confidence when the embedding similarity disagrees with
        # the predicted relation.
        if semantic_features and 'similarity' in semantic_features:
            if relation_type in ['support', 'detail'] and semantic_features['similarity'] < 0.3:
                confidence = confidence * 0.8  # Reduce confidence for support/detail with low similarity
            elif relation_type == 'attack' and semantic_features['similarity'] > 0.8:
                confidence = confidence * 0.8  # Reduce confidence for attack with high similarity

        # Build result
        result = {
            'relation_type': relation_type,
            'confidence': confidence,
            'features': {
                'probabilities': {self.id2label[i]: prob.item() for i, prob in enumerate(probabilities)},
                **semantic_features
            }
        }

        # Cache the result
        self.cache[cache_key] = result

        return result

    def _prepare_input(self, source_text: str, target_text: str, context: Optional[str] = None) -> str:
        """
        Prepare input text for the model.

        The context line is only included when ``context`` is non-empty.
        """
        if context:
            return f"Context: {context}\nSource: {source_text}\nTarget: {target_text}"
        return f"Source: {source_text}\nTarget: {target_text}"

    def _get_semantic_features(self, source_text: str, target_text: str) -> Dict:
        """
        Extract semantic features for the argument pair.

        Returns:
            Dictionary with ``similarity``: the cosine similarity between the
            sentence embeddings of the two texts.
        """
        # Get embeddings
        source_embedding = self.embedding_model.encode(source_text, convert_to_tensor=True)
        target_embedding = self.embedding_model.encode(target_text, convert_to_tensor=True)

        # Calculate similarity (unsqueeze adds the batch axis cosine_similarity compares along)
        similarity = torch.cosine_similarity(source_embedding.unsqueeze(0),
                                           target_embedding.unsqueeze(0)).item()

        return {
            'similarity': similarity,
        }

    def batch_predict(self, argument_pairs: List[Dict]) -> List[Dict]:
        """
        Predict relations for multiple argument pairs.

        Args:
            argument_pairs: List of dictionaries with source_text, target_text, and optional context

        Returns:
            List of prediction dictionaries, one per input pair and in the same
            order. A pair with a missing or empty source/target yields
            ``{'error': ...}`` instead of a prediction.
        """
        results = []

        for pair in argument_pairs:
            source = pair.get('source_text')
            target = pair.get('target_text')
            context = pair.get('context')

            if not source or not target:
                results.append({'error': 'Missing source or target text'})
                continue

            results.append(self.predict_relation(source, target, context))

        return results

    def analyze_graph(self, argument_pairs: List[Dict]) -> Dict:
        """
        Perform graph-based analysis of the argument network.

        Pairs that could not be classified (missing or empty source/target) are
        skipped and excluded from every count.

        Args:
            argument_pairs: List of dictionaries with source_text, target_text, and optional context

        Returns:
            Dictionary with ``key_arguments`` (up to five arguments ranked by
            the summed confidence of incoming support/attack relations),
            ``argument_count``, ``relation_count`` and ``inconsistencies``.
        """
        # Get predictions for all pairs
        predictions = self.batch_predict(argument_pairs)

        # batch_predict reports unusable pairs as {'error': ...}; drop them so
        # they do not break the graph construction below.
        valid_pairs = []
        valid_predictions = []
        for pair, prediction in zip(argument_pairs, predictions):
            if 'relation_type' in prediction:
                valid_pairs.append(pair)
                valid_predictions.append(prediction)
        skipped = len(argument_pairs) - len(valid_pairs)
        if skipped:
            logger.warning(f"Skipped {skipped} argument pairs without a usable prediction")

        # Index unique arguments in first-seen order so tie-breaking between
        # equally central arguments is deterministic across runs.
        arg_to_idx = {}
        for pair in valid_pairs:
            for text in (pair['source_text'], pair['target_text']):
                arg_to_idx.setdefault(text, len(arg_to_idx))
        arg_list = list(arg_to_idx)
        n_args = len(arg_list)

        # Adjacency matrices: rows are sources, columns are targets
        support_matrix = np.zeros((n_args, n_args))
        attack_matrix = np.zeros((n_args, n_args))

        for pair, prediction in zip(valid_pairs, valid_predictions):
            source_idx = arg_to_idx[pair['source_text']]
            target_idx = arg_to_idx[pair['target_text']]
            rel_type = prediction['relation_type']

            if rel_type == 'support':
                support_matrix[source_idx, target_idx] = prediction['confidence']
            elif rel_type == 'attack':
                attack_matrix[source_idx, target_idx] = prediction['confidence']

        # Column sums: total confidence of relations pointing at each argument
        support_centrality = np.sum(support_matrix, axis=0)
        attack_centrality = np.sum(attack_matrix, axis=0)

        # Identify key arguments
        key_args = []
        for i, arg in enumerate(arg_list):
            total_centrality = support_centrality[i] + attack_centrality[i]
            if total_centrality > 0:
                key_args.append({
                    'text': self._truncate(arg, 100),
                    'support_centrality': float(support_centrality[i]),
                    'attack_centrality': float(attack_centrality[i]),
                    'total_centrality': float(total_centrality)
                })

        # Sort by total centrality (stable, so ties keep first-seen order)
        key_args.sort(key=lambda x: x['total_centrality'], reverse=True)

        # Identify inconsistencies
        inconsistencies = self._identify_inconsistencies(valid_pairs, valid_predictions)

        return {
            'key_arguments': key_args[:5],  # Top 5 key arguments
            'argument_count': n_args,
            'relation_count': len(valid_pairs),
            'inconsistencies': inconsistencies
        }

    @staticmethod
    def _truncate(text: str, limit: int) -> str:
        """Return ``text`` cut to ``limit`` characters plus '...' when it is longer."""
        return text[:limit] + '...' if len(text) > limit else text

    def _identify_inconsistencies(self, argument_pairs: List[Dict], predictions: List[Dict]) -> List[Dict]:
        """
        Identify potential inconsistencies in the argument network.

        Two patterns are reported, each at most once per set of arguments:

        * ``circular_relation`` -- two arguments support (or attack) each other;
        * ``competing_relations`` -- one argument supports a target that
          another argument attacks.

        Predictions without a ``relation_type`` (e.g. error entries) are ignored.
        """
        inconsistencies = []

        # Build graph: graph[source][target] = relation type
        graph = {}
        for pair, prediction in zip(argument_pairs, predictions):
            rel_type = prediction.get('relation_type')
            if rel_type is None:
                continue

            source = pair['source_text']
            target = pair['target_text']
            graph.setdefault(source, {})
            graph.setdefault(target, {})
            graph[source][target] = rel_type

        # Check for circular support/attack patterns. Each cycle is visible
        # from both of its edges, so remember the unordered pair to report once.
        seen_cycles = set()
        for source, edges in graph.items():
            for target, rel_type in edges.items():
                if rel_type not in ('support', 'attack'):
                    continue
                if graph[target].get(source) != rel_type:
                    continue

                cycle_key = frozenset((source, target))
                if cycle_key in seen_cycles:
                    continue
                seen_cycles.add(cycle_key)

                inconsistencies.append({
                    'type': 'circular_relation',
                    'source': self._truncate(source, 50),
                    'target': self._truncate(target, 50),
                    'relation': rel_type,
                    'description': f"Circular {rel_type} relation detected"
                })

        # Check for support-attack contradictions on a shared target. The same
        # conflict is reachable starting from either contender, hence the
        # unordered contender pair in the de-duplication key.
        opposing = {'support': 'attack', 'attack': 'support'}
        seen_conflicts = set()
        for arg1, edges in graph.items():
            for arg2, rel1 in edges.items():
                wanted = opposing.get(rel1)
                if wanted is None:
                    continue

                for arg3, other_edges in graph.items():
                    if arg3 == arg1 or arg3 == arg2:
                        continue
                    rel2 = other_edges.get(arg2)
                    if rel2 != wanted:
                        continue

                    conflict_key = (frozenset((arg1, arg3)), arg2)
                    if conflict_key in seen_conflicts:
                        continue
                    seen_conflicts.add(conflict_key)

                    inconsistencies.append({
                        'type': 'competing_relations',
                        'arguments': [
                            self._truncate(arg1, 50),
                            self._truncate(arg2, 50),
                            self._truncate(arg3, 50)
                        ],
                        'description': f"Competing {rel1}/{rel2} relations detected"
                    })

        return inconsistencies

def save_uploaded_files(files):
    """
    Save uploaded files to a temporary directory and return their paths.

    Each element of ``files`` may be:

    * a path (``str`` or ``os.PathLike``) -- it already points at a file on
      disk, so it is returned as-is without copying;
    * a file-like object with ``read()`` -- its content is written to
      ``temp_uploads/upload_<index>_<basename>``;
    * any other object exposing a ``name`` path attribute -- the file at that
      path is copied to ``temp_uploads/upload_<index>_<basename>``.

    Args:
        files: Iterable of uploaded files, or None.

    Returns:
        List of file paths, one per input element and in the same order.
    """
    upload_dir = "temp_uploads"
    os.makedirs(upload_dir, exist_ok=True)

    file_paths = []
    for i, file in enumerate(files or []):
        # Paths must be recognised before touching `.name`, which plain
        # strings do not have.
        if isinstance(file, (str, os.PathLike)):
            file_paths.append(os.fspath(file))
            continue

        # In-memory streams may carry no name at all; fall back to a generic one.
        source_name = getattr(file, "name", None)
        base_name = os.path.basename(str(source_name)) if source_name else "upload"
        temp_path = os.path.join(upload_dir, f"upload_{i}_{base_name}")

        if hasattr(file, 'read'):
            payload = file.read()
            # Text-mode streams return str, which cannot be written in "wb" mode.
            if isinstance(payload, str):
                payload = payload.encode("utf-8")
            with open(temp_path, "wb") as dst:
                dst.write(payload)
        else:
            # Copy the file content from the path the upload object refers to
            with open(file.name, "rb") as src, open(temp_path, "wb") as dst:
                dst.write(src.read())

        file_paths.append(temp_path)

    return file_paths

# Create a function to run the interactive web UI
def run_interactive_rbam_system():
    """
    Build the RBAM model and launch the Gradio web UI around it.

    The UI has four tabs: model training, single-pair relation prediction,
    argument-graph analysis, and help. If the default checkpoint cannot be
    loaded, or does not carry the four relation labels, a four-label
    checkpoint (model and tokenizer) is created under
    ``models/distilbert-rbam`` and used instead.

    Returns:
        The string "Interactive RBAM system started", once ``demo.launch``
        returns.
    """
    # Create the model
    model_path = "distilbert-base-uncased"  # Default model, will be fine-tuned
    bootstrap_dir = "models/distilbert-rbam"
    relation_labels = {0: "support", 1: "attack", 2: "neutral", 3: "detail"}

    rbam_model = None
    try:
        candidate = EnhancedRBAMModel(model_path=model_path)
        # A plain base checkpoint loads without error but with a generic
        # two-label head; only accept models that know the relation labels.
        if set(candidate.id2label.values()) == set(relation_labels.values()):
            rbam_model = candidate
            model_status = "Model loaded successfully"
        else:
            logger.warning(
                f"{model_path} has labels {candidate.id2label}; creating a relation-labelled checkpoint instead"
            )
        candidate = None
    except Exception as e:
        logger.warning(f"Could not load {model_path} directly: {e}")

    if rbam_model is None:
        # Build a checkpoint with a four-way classification head
        base_model = AutoModelForSequenceClassification.from_pretrained(
            model_path,
            num_labels=len(relation_labels),
            id2label=relation_labels,
            label2id={name: idx for idx, name in relation_labels.items()},
        )

        # Save the tokenizer next to the model: EnhancedRBAMModel loads both
        # from the same directory.
        os.makedirs(bootstrap_dir, exist_ok=True)
        base_model.save_pretrained(bootstrap_dir)
        AutoTokenizer.from_pretrained(model_path).save_pretrained(bootstrap_dir)
        del base_model

        # Now initialize with the saved model
        rbam_model = EnhancedRBAMModel(model_path=bootstrap_dir)
        model_status = "Initialized with base model (requires training)"

    # Sample training data
    sample_train_data = [
        {
            "source_text": "Climate change is primarily caused by human activities.",
            "target_text": "The rise in global temperatures correlates with increased CO2 emissions.",
            "context": "Environmental science debate",
            "label_id": 0  # support
        },
        {
            "source_text": "Renewable energy is too expensive to replace fossil fuels.",
            "target_text": "Solar panel costs have decreased by 90% in the last decade.",
            "context": "Energy policy discussion",
            "label_id": 1  # attack
        },
        {
            "source_text": "Excessive social media use is harmful to mental health.",
            "target_text": "Studies show correlation between screen time and anxiety in teens.",
            "context": "Public health forum",
            "label_id": 0  # support
        }
    ]

    # Sample argument network data
    sample_graph_data = [
        {"source_text": "Climate change is primarily caused by human activities.",
         "target_text": "The rise in global temperatures correlates with increased CO2 emissions."},
        {"source_text": "Renewable energy is too expensive to replace fossil fuels.",
         "target_text": "Solar panel costs have decreased by 90% in the last decade."},
        {"source_text": "Solar panel costs have decreased by 90% in the last decade.",
         "target_text": "Renewable energy is becoming more economically viable."}
    ]

    # Function to train the model
    def train_model_from_ui(file_inputs, train_data_json, epochs, batch_size):
        """Collect training examples from uploads and pasted JSON, validate them, and train."""
        try:
            train_data = []

            # Process uploaded files if any
            if file_inputs:
                file_paths = save_uploaded_files(file_inputs)
                loaded_data = load_data_from_json_files(file_paths)
                train_data.extend(loaded_data)

            # Process pasted JSON if not empty
            if train_data_json.strip():
                pasted_data = json.loads(train_data_json)
                if isinstance(pasted_data, list):
                    train_data.extend(pasted_data)
                else:
                    train_data.append(pasted_data)

            # Validate data format
            if not train_data:
                return "Error: No training data provided"

            for item in train_data:
                if not isinstance(item, dict) or not all(
                    k in item for k in ["source_text", "target_text", "label_id"]
                ):
                    return "Error: Each training example must have source_text, target_text, and label_id"
                # A label outside the classifier head fails deep inside the
                # training loop; reject it here with a readable message.
                if item["label_id"] not in rbam_model.id2label:
                    return f"Error: label_id must be one of {sorted(rbam_model.id2label)}"

            # Train the model
            result = rbam_model.train_model(
                train_data=train_data,
                epochs=int(epochs),
                batch_size=int(batch_size),
                output_dir="models/rbam-custom"
            )

            return f"Training completed with {len(train_data)} examples: {result['message']}"
        except Exception as e:
            return f"Training error: {str(e)}"

    # Function to predict relations
    def predict_relation_from_ui(source_text, target_text, context):
        """Predict the relation for one argument pair and render it as Markdown."""
        if not source_text or not target_text:
            return "Error: Source and target arguments are required"

        # Report failures in the UI, like the training and analysis handlers do.
        try:
            result = rbam_model.predict_relation(
                source_text=source_text,
                target_text=target_text,
                context=context if context else None
            )
        except Exception as e:
            return f"Prediction error: {str(e)}"

        # Format the result for display
        formatted_result = f"""
## Prediction Result

**Relation Type:** {result['relation_type']}
**Confidence:** {result['confidence']:.4f} ({result['confidence']*100:.1f}%)

### Probability Distribution:
"""

        # Add probability distribution
        probs = result['features']['probabilities']
        for rel, prob in probs.items():
            formatted_result += f"- {rel}: {prob:.4f} ({prob*100:.1f}%)\n"

        # Add semantic similarity if available
        if 'similarity' in result['features']:
            formatted_result += f"\n**Semantic Similarity:** {result['features']['similarity']:.4f}\n"

        return formatted_result

    # Function to analyze argument graph
    def analyze_graph_from_ui(file_inputs, graph_data_json):
        """Collect argument pairs from uploads and pasted JSON, analyze them, and render Markdown."""
        try:
            argument_pairs = []

            # Process uploaded files if any
            if file_inputs:
                file_paths = save_uploaded_files(file_inputs)
                loaded_data = load_data_from_json_files(file_paths)
                argument_pairs.extend(loaded_data)

            # Process pasted JSON if not empty
            if graph_data_json.strip():
                pasted_data = json.loads(graph_data_json)
                if isinstance(pasted_data, list):
                    argument_pairs.extend(pasted_data)
                else:
                    argument_pairs.append(pasted_data)

            # Validate data format
            if not argument_pairs:
                return "Error: No argument pairs provided"

            for item in argument_pairs:
                # Empty texts cannot be classified, so require non-empty values
                # rather than just the presence of the keys.
                if not isinstance(item, dict) or not all(
                    item.get(k) for k in ["source_text", "target_text"]
                ):
                    return "Error: Each pair must have source_text and target_text"

            # Analyze the graph
            result = rbam_model.analyze_graph(argument_pairs)

            # Format the result for display
            formatted_result = f"""
## Graph Analysis Result

**Arguments:** {result['argument_count']}
**Relations:** {result['relation_count']}

### Key Arguments:
"""

            # Add key arguments
            for i, arg in enumerate(result['key_arguments']):
                formatted_result += f"{i+1}. **{arg['text']}**\n"
                formatted_result += f"   - Support: {arg['support_centrality']:.2f}\n"
                formatted_result += f"   - Attack: {arg['attack_centrality']:.2f}\n"
                formatted_result += f"   - Total: {arg['total_centrality']:.2f}\n\n"

            # Add inconsistencies
            formatted_result += f"\n### Inconsistencies ({len(result['inconsistencies'])}):\n"

            for i, inconsistency in enumerate(result['inconsistencies']):
                formatted_result += f"{i+1}. **{inconsistency['type']}**\n"
                formatted_result += f"   {inconsistency['description']}\n\n"

            return formatted_result
        except Exception as e:
            return f"Analysis error: {str(e)}"

    # Create the Gradio interface
    with gr.Blocks(title="Enhanced RBAM System") as demo:
        gr.Markdown("# Enhanced Relation-Based Argument Mining (RBAM) System")
        gr.Markdown(f"**Model Status:** {model_status}")

        with gr.Tab("Model Training"):
            gr.Markdown("### Train the RBAM Model")
            gr.Markdown("Use this tab to fine-tune the model on your own data.")

            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("**Option 1: Upload JSON Files**")
                    file_inputs = gr.File(
                        file_count="multiple",
                        label="Upload JSON Files",
                        file_types=[".json"]
                    )

                with gr.Column(scale=1):
                    gr.Markdown("**Option 2: Paste JSON Data**")
                    train_data_input = gr.Textbox(
                        label="Training Data (JSON format)",
                        placeholder=json.dumps(sample_train_data, indent=2),
                        lines=10
                    )

            with gr.Row():
                epochs_input = gr.Number(label="Epochs", value=3, minimum=1)
                batch_size_input = gr.Number(label="Batch Size", value=8, minimum=1)

            train_button = gr.Button("Train Model")
            train_output = gr.Textbox(label="Training Result")

            train_button.click(
                fn=train_model_from_ui,
                inputs=[file_inputs, train_data_input, epochs_input, batch_size_input],
                outputs=train_output
            )

        with gr.Tab("Relation Prediction"):
            gr.Markdown("### Predict Relations Between Arguments")

            source_input = gr.Textbox(label="Source Argument", lines=4)
            target_input = gr.Textbox(label="Target Argument", lines=4)
            context_input = gr.Textbox(label="Context (optional)", lines=2)

            predict_button = gr.Button("Predict Relation")
            prediction_output = gr.Markdown(label="Prediction Result")

            predict_button.click(
                fn=predict_relation_from_ui,
                inputs=[source_input, target_input, context_input],
                outputs=prediction_output
            )

        with gr.Tab("Graph Analysis"):
            gr.Markdown("### Analyze Argument Network")

            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("**Option 1: Upload JSON Files**")
                    graph_file_inputs = gr.File(
                        file_count="multiple",
                        label="Upload JSON Files with Argument Pairs",
                        file_types=[".json"]
                    )

                with gr.Column(scale=1):
                    gr.Markdown("**Option 2: Paste JSON Data**")
                    graph_data_input = gr.Textbox(
                        label="Argument Pairs (JSON format)",
                        placeholder=json.dumps(sample_graph_data, indent=2),
                        lines=10
                    )

            analyze_button = gr.Button("Analyze Graph")
            analysis_output = gr.Markdown(label="Analysis Result")

            analyze_button.click(
                fn=analyze_graph_from_ui,
                inputs=[graph_file_inputs, graph_data_input],
                outputs=analysis_output
            )

        with gr.Tab("Help"):
            gr.Markdown("""
            ## How to Use the RBAM System
            
            ### Data Format
            
            #### Training Data Format
            ```json
            [
                {
                    "source_text": "First argument text",
                    "target_text": "Second argument text",
                    "context": "Optional context information",
                    "label_id": 0  // 0: support, 1: attack, 2: neutral, 3: detail
                },
                ...
            ]
            ```
            
            #### Argument Pairs Format for Graph Analysis
            ```json
            [
                {
                    "source_text": "First argument text",
                    "target_text": "Second argument text",
                    "context": "Optional context information"
                },
                ...
            ]
            ```
            
            ### Relation Types
            - **Support**: Source argument supports or reinforces the target argument
            - **Attack**: Source argument contradicts or weakens the target argument
            - **Neutral**: Source argument neither supports nor attacks the target argument
            - **Detail**: Source argument provides additional details about the target argument
            
            ### Tips
            - For best results, provide clear and concise arguments
            - Context is optional but can improve prediction accuracy
            - The model performs better when trained on domain-specific data
            - In graph analysis, larger networks may take longer to process
            """)

    # Launch the interface.
    # NOTE(review): share=True publishes a public URL through which anyone can
    # upload files and trigger training -- confirm this exposure is intended.
    demo.launch(share=True)
    return "Interactive RBAM system started"

# Run the system when this script is executed
if __name__ == "__main__":
    # Install required packages if not already installed.
    # NOTE(review): the module-level `import gradio as gr` at the top of the
    # file runs before this guard, so a missing gradio fails there first --
    # confirm whether this bootstrap should move ahead of those imports.
    try:
        import gradio
    except ImportError:
        print("Installing required packages...")
        import subprocess
        import sys

        # pip exposes no supported in-process API (pip.main was removed in
        # pip 10); run it as a subprocess of the current interpreter instead.
        subprocess.check_call([
            sys.executable, "-m", "pip", "install",
            'gradio', 'transformers', 'sentence-transformers', 'torch', 'datasets',
        ])

    print("Starting Enhanced RBAM Interactive System...")
    run_interactive_rbam_system()

Starting Enhanced RBAM Interactive System...


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


* Running on local URL:  http://127.0.0.1:7863
* Running on public URL: https://9e2ee11232c31f4df3.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)
