In [None]:
pip install flask flask-cors torch transformers




In [None]:
from flask import Flask, request, jsonify
from transformers import BartForConditionalGeneration, BartTokenizer
from flask_cors import CORS
import torch
import logging

# --- Logging ---------------------------------------------------------------
# INFO level so the model-loading progress messages are printed.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)

# --- Web application -------------------------------------------------------
# The Flask app serving the summarization API. CORS is enabled with the
# flask-cors defaults (which, per flask-cors docs, allow any origin — tighten
# this if the API must not be callable from arbitrary sites).
app = Flask(import_name=__name__)
CORS(app)

# Load the BART summarization model and tokenizer once, at import time.
# On any failure both globals are set to None so the app can still start;
# generate_summary() then raises RuntimeError and /summarize answers 503.
_MODEL_NAME = "facebook/bart-large-cnn"

try:
    logger.info("Loading BART model and tokenizer...")
    tokenizer = BartTokenizer.from_pretrained(_MODEL_NAME)
    model = BartForConditionalGeneration.from_pretrained(_MODEL_NAME)
    if torch.cuda.is_available():
        model = model.cuda()
        logger.info("Model loaded on GPU")
    else:
        logger.info("Model loaded on CPU")
except Exception as e:
    # logger.exception also records the traceback; dependency/import errors
    # raised while loading are hard to diagnose from the message alone.
    logger.exception("Error loading BART model: %s", e)
    model = None
    tokenizer = None

def generate_summary(text):
    """Summarize ``text`` with the module-level BART model.

    Texts shorter than 40 words are returned unchanged. The model's output is
    also discarded in favour of ``text`` when it has fewer than 10 words or is
    identical to the input.

    Args:
        text: The text to summarize. Input beyond 1024 tokens is truncated by
            the tokenizer before generation.

    Returns:
        The generated summary, or ``text`` itself when summarization is
        skipped or the output is rejected.

    Raises:
        ValueError: If ``text`` is empty or whitespace-only.
        RuntimeError: If the model or tokenizer failed to load.
    """
    if not text or not text.strip():
        raise ValueError("No text provided for summarization.")

    if model is None or tokenizer is None:
        raise RuntimeError("Summarization model not properly initialized.")

    try:
        word_count = len(text.split())
        # Too short to be worth summarizing; hand the text back as-is.
        if word_count < 40:
            return text

        # Scale the generation bounds with the input size. generate() counts
        # these in tokens, while word_count is whitespace-separated words.
        max_length = min(150, max(50, word_count // 2))
        min_length = min(30, max(10, word_count // 4))

        inputs = tokenizer(text, return_tensors="pt", max_length=1024, truncation=True)

        # Put the tensors on whatever device the model lives on. BatchEncoding.to()
        # returns the encoding itself (not a plain dict), so key access below
        # works identically on CPU and GPU.
        inputs = inputs.to(model.device)

        summary_ids = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=max_length,
            min_length=min_length,
            length_penalty=2.0,
            num_beams=4,
            early_stopping=True,
        )

        summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)

        # Reject degenerate output and fall back to the original text.
        if len(summary.split()) < 10 or summary.strip() == text.strip():
            logger.warning("Generated summary too short or identical to input")
            return text

        return summary

    except Exception as e:
        logger.exception("Error in generate_summary: %s", e)
        raise

@app.route('/summarize', methods=['POST'])
def summarize():
    """Summarize the ``text`` field of a JSON POST body.

    Expects a JSON object such as ``{"text": "..."}``.

    Responses:
        200: ``success``, ``summary``, ``original_length`` and
             ``summary_length`` (whitespace-separated word counts).
        400: missing/malformed JSON, a non-object body, or a missing,
             non-string or empty ``text`` field (also ValueError from
             generate_summary).
        503: RuntimeError from generate_summary (e.g. model not loaded).
        500: any other unexpected failure.
    """
    try:
        # silent=True returns None for a malformed body or non-JSON
        # Content-Type instead of raising werkzeug's BadRequest /
        # UnsupportedMediaType, which the catch-all below would turn into 500.
        data = request.get_json(silent=True)
        if not data:
            return jsonify({
                'success': False,
                'error': 'No JSON data provided'
            }), 400

        if not isinstance(data, dict):
            return jsonify({
                'success': False,
                'error': 'JSON body must be an object'
            }), 400

        text = data.get('text')
        # null is treated as "not provided"; any other non-string is rejected
        # rather than crashing on .strip().
        if text is not None and not isinstance(text, str):
            return jsonify({
                'success': False,
                'error': "Field 'text' must be a string"
            }), 400

        text = (text or '').strip()
        if not text:
            return jsonify({
                'success': False,
                'error': 'No text provided for summarization'
            }), 400

        summary = generate_summary(text)
        return jsonify({
            'success': True,
            'summary': summary,
            'original_length': len(text.split()),
            'summary_length': len(summary.split())
        })

    except ValueError as e:
        return jsonify({
            'success': False,
            'error': str(e)
        }), 400

    except RuntimeError as e:
        return jsonify({
            'success': False,
            'error': str(e)
        }), 503

    except Exception as e:
        logger.exception("Unexpected error in summarize endpoint: %s", e)
        return jsonify({
            'success': False,
            'error': 'Internal server error occurred'
        }), 500

@app.route('/health', methods=['GET'])
def health_check():
    """Report liveness and whether the model and tokenizer loaded.

    Registered at module level (not under ``__main__``) so the endpoint also
    exists when the app is served by a WSGI server that imports this module.
    """
    return jsonify({
        'status': 'healthy',
        'model_loaded': model is not None and tokenizer is not None
    })


if __name__ == '__main__':
    import os

    # Configuration
    host = '0.0.0.0'
    port = 5000
    # The Werkzeug debugger permits arbitrary code execution from the browser,
    # so it must not be on by default while bound to all interfaces.
    # Opt in explicitly with FLASK_DEBUG=1 (or "true").
    debug = os.environ.get('FLASK_DEBUG', '').lower() in ('1', 'true')

    logger.info(f"Starting server at http://{host}:{port}")
    # use_reloader=False: the reloader re-executes the process, which breaks
    # inside a notebook kernel and would load the model a second time.
    app.run(debug=debug, host=host, port=port, use_reloader=False)

RuntimeError: Failed to import transformers.models.bart.modeling_bart because of the following error (look up to see its traceback):
module 'sympy' has no attribute 'printing'