In [1]:
#load packages
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import json


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from tqdm import tqdm
import numpy as np

class BlameDetectorDa(object):
    """Batched inference wrapper around a Hugging Face sequence-classification model.

    The model and tokenizer are both loaded from ``model_path``. Predictions are
    returned as class indices together with the softmax probability of the chosen
    class; what each index means is defined by the loaded model, not by this class.
    """

    def __init__(self, model_path, max_length=512, batch_size=32):
        """Store the inference settings and load the model.

        Args:
            model_path: Identifier or path handed to ``from_pretrained``.
            max_length: Token limit passed to the tokenizer (inputs are truncated).
            batch_size: Number of texts per forward pass in ``run_batch_prediction``.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        # A negative step would make the batching loop silently yield nothing,
        # so reject bad sizes before spending time on loading the model.
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

        self.model_path = model_path
        self.max_length = max_length
        self.batch_size = batch_size

        self.model_initialization()

    def model_initialization(self):
        """Load model and tokenizer and place the model on one device in eval mode."""
        # Use the GPU if available, otherwise fall back to CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Loaded without ``device_map``: the model is moved explicitly below, and
        # calling ``.to()`` on a model that accelerate has already dispatched
        # triggers a warning (or an error when weights are offloaded).
        self.model = AutoModelForSequenceClassification.from_pretrained(self.model_path)
        self.model.to(self.device)
        self.model.eval()

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)

        print(f"Model loaded successfully on {self.device}")

    def _class_probabilities(self, texts):
        """Tokenize ``texts`` and return the softmax of the model's logits (one row per input)."""
        inputs = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )

        # Inputs must live on the same device as the model weights.
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

        with torch.no_grad():
            logits = self.model(**inputs).logits
            return torch.softmax(logits, dim=1)

    def predict(self, text):
        """Make a prediction on a single text input.

        Returns:
            ``(predicted_class, confidence, probabilities)`` where the first two are
            a Python int and float and the last is a 1-D numpy array of the class
            probabilities for this text.
        """
        probabilities = self._class_probabilities(text)
        confidence, predicted_class = probabilities.max(dim=1)

        # ``.item()`` requires exactly one row, i.e. exactly one input text.
        return predicted_class.item(), confidence.item(), probabilities[0].cpu().numpy()

    def predict_batch(self, texts):
        """Make predictions on a batch of texts with a single forward pass.

        Returns:
            ``(predicted_classes, confidences, probabilities)`` as numpy arrays; the
            first two have one entry per text, the last one row per text.
        """
        probabilities = self._class_probabilities(texts)

        # max over the class dimension yields the winning probability and its index.
        confidences, predicted_classes = probabilities.max(dim=1)

        return (
            predicted_classes.cpu().numpy(),
            confidences.cpu().numpy(),
            probabilities.cpu().numpy(),
        )

    def run_prediction(self, text):
        """Single text prediction (backward compatibility): returns ``(class, confidence)``."""
        predicted_class, confidence, _ = self.predict(text)
        return predicted_class, confidence

    def run_batch_prediction(self, texts, show_progress=True):
        """
        Run predictions on a list of texts with batching and progress bar.

        Args:
            texts: List of text strings to predict on
            show_progress: Whether to show progress bar (default: True)

        Returns:
            predicted_classes: numpy array of predicted class indices
            confidences: numpy array of confidence scores
            Both arrays are empty when ``texts`` is empty.
        """
        batch_predictions = []
        batch_confidences = []

        pbar = tqdm(total=len(texts), desc="Processing texts", unit="text") if show_progress else None

        try:
            for start in range(0, len(texts), self.batch_size):
                batch_texts = texts[start:start + self.batch_size]

                predicted_classes, confidences, _ = self.predict_batch(batch_texts)

                batch_predictions.append(predicted_classes)
                batch_confidences.append(confidences)

                if pbar is not None:
                    pbar.update(len(batch_texts))
        finally:
            # Close the bar even when a batch fails, so it does not linger in the output.
            if pbar is not None:
                pbar.close()

        # np.concatenate rejects an empty list, so handle "no texts" explicitly.
        if not batch_predictions:
            return np.array([]), np.array([])

        return np.concatenate(batch_predictions), np.concatenate(batch_confidences)

    def predict_from_json(self, json_path, text_key = 'text', output_key='prediction',
                          confidence_key='confidence', show_progress=True):
        """
        Load JSON, predict, and return results with predictions added.

        The file may contain a list of objects, a dict whose values are all objects
        (each value is treated as one item), or a single object (treated as one item).

        Args:
            json_path: Path to JSON file
            text_key: Key in JSON objects containing the text to classify
            output_key: Key name for storing predictions (default: 'prediction')
            confidence_key: Key name for storing confidence scores (default: 'confidence')
            show_progress: Whether to show progress bar

        Returns:
            List of dictionaries with predictions added (the loaded items are
            updated in place and returned).

        Raises:
            ValueError: If the top-level JSON value is neither a list nor a dict.
            TypeError: If an item is not a JSON object.
            KeyError: If an item has no ``text_key`` entry.
        """
        import json

        # Load JSON data
        print(f"Loading data from {json_path}...")
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        print('json file loaded')

        # Handle different JSON structures
        if isinstance(data, list):
            items = data
        elif isinstance(data, dict):
            values = list(data.values())
            items = values if all(isinstance(v, dict) for v in values) else [data]
        else:
            raise ValueError("Unsupported JSON structure")

        print(f"Found {len(items)} items. extracting texts...")

        # Extract texts, reporting which item is malformed instead of a bare KeyError.
        texts = []
        for index, item in enumerate(items):
            if not isinstance(item, dict):
                raise TypeError(
                    f"Item {index} is a {type(item).__name__}, expected a JSON object"
                )
            if text_key not in item:
                raise KeyError(f"Item {index} has no {text_key!r} key")
            texts.append(item[text_key])

        print('start batch prediction')

        # Run predictions
        predictions, confidences = self.run_batch_prediction(texts, show_progress=show_progress)

        # Convert numpy scalars to plain Python numbers so the items stay JSON-serializable.
        for item, pred, conf in zip(items, predictions, confidences):
            item[output_key] = int(pred)
            item[confidence_key] = float(conf)

        return items

    def predict_from_json_to_file(self, input_path, output_path, text_key='text',
                                   output_key='prediction', confidence_key='confidence',
                                   show_progress=True):
        """
        Load JSON, predict, and save results to a new file.

        Args:
            input_path: Path to input JSON file
            output_path: Path to save output JSON file
            text_key: Key containing text to classify (default: 'text', matching
                ``predict_from_json``)
            output_key: Key name for predictions
            confidence_key: Key name for confidence scores
            show_progress: Whether to show progress bar

        Returns:
            The list of items that was written to ``output_path``.

        Raises:
            FileNotFoundError: If the directory of ``output_path`` does not exist;
                checked before any prediction is run.
        """
        import json
        import os

        # Inference can take a long time; detect an unwritable destination up front
        # instead of losing the results when the file is opened at the very end.
        if isinstance(output_path, (str, bytes, os.PathLike)):
            output_dir = os.path.dirname(os.path.abspath(output_path))
            if not os.path.isdir(output_dir):
                raise FileNotFoundError(f"Output directory does not exist: {output_dir!r}")

        # Get predictions
        results = self.predict_from_json(
            input_path,
            text_key=text_key,
            output_key=output_key,
            confidence_key=confidence_key,
            show_progress=show_progress,
        )

        # Save to file
        print(f"Saving results to {output_path}...")
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2, ensure_ascii=False)

        print(f"Done! Processed {len(results)} items")

        return results

In [5]:
# Inference settings: model identifier passed to from_pretrained, tokenizer
# truncation limit, and number of texts per forward pass.
MODEL_PATH = "Lundsfryd/Pol_Blame_Detection_Da"
MAX_LENGTH = 512
BATCH_SIZE = 64

detector = BlameDetectorDa(model_path=MODEL_PATH, max_length=MAX_LENGTH, batch_size=BATCH_SIZE)

Model loaded successfully on cuda


In [None]:

# Classify every item in the input JSON and write the annotated items to a new file.
INPUT_PATH = "/work/MarkusLundsfrydJensen#1865/final_inference_data.json"
OUTPUT_PATH = "/work/MarkusLundsfrydJensen#1865/final_final_inference_data.json"

# Bind the return value to a name: the call returns the full list of annotated
# items, and leaving it as the cell's last expression would make Jupyter render
# every item as cell output.
annotated_items = detector.predict_from_json_to_file(
    input_path=INPUT_PATH,
    output_path=OUTPUT_PATH,
    text_key="text",
)


Loading data from /work/MarkusLundsfrydJensen#1865/final_inference_data.json...


Processing texts:   1%|          | 19328/2191051 [01:03<1:06:37, 543.30text/s]

json file loaded
Found 2191051 items. extracting texts...
start batch prediction


