In [None]:
# Install necessary packages
!pip install transformers torch scikit-learn

import pandas as pd
import torch
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split
from torch.nn import BCEWithLogitsLoss
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer

# Custom dataset class for the incident descriptions
class IncidentDataset(Dataset):
    """Torch dataset that pairs incident description texts with their labels.

    Items are tokenized lazily, on access, with the supplied tokenizer and are
    padded/truncated to exactly ``max_len`` tokens.
    """

    def __init__(self, texts, labels, tokenizer, max_len):
        self.texts, self.labels = texts, labels
        self.tokenizer, self.max_len = tokenizer, max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, index):
        """Return the raw text, token ids, attention mask and float label at ``index``."""
        sample_text = self.texts[index]
        sample_label = self.labels[index]

        tokenizer_options = dict(
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        encoded = self.tokenizer.encode_plus(sample_text, **tokenizer_options)

        item = {'text': sample_text}
        # The tokenizer returns tensors with a leading batch axis; flatten to 1-D
        # so the DataLoader can stack items into a batch itself.
        for field in ('input_ids', 'attention_mask'):
            item[field] = encoded[field].flatten()
        item['labels'] = torch.tensor(sample_label, dtype=torch.float)
        return item

# Load the tokenizer and the pretrained SciBERT encoder (both from the same checkpoint)
tokenizer = AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased")
model = AutoModel.from_pretrained("allenai/scibert_scivocab_uncased")

max_len = 256  # Token length every sample is padded/truncated to; adjust based on your text length

# Split dataset into train and test sets.
# NOTE(review): `df` is not defined in this cell; presumably an earlier cell creates it
# with 'Processed_Description' and 'MI_Incident' columns — confirm.
X = df['Processed_Description'].values
y = df['MI_Incident'].values

# 80/20 split, stratified on the label; random_state pins the split so reruns match.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, stratify=y, random_state=42)

# Wrap each split so items are tokenized on access
train_dataset = IncidentDataset(X_train, y_train, tokenizer, max_len)
test_dataset = IncidentDataset(X_test, y_test, tokenizer, max_len)

# Only the training loader shuffles; the test loader keeps dataset order.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

# Move model to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Define a classifier based on SciBERT
class SciBERTClassifier(torch.nn.Module):
    """Classification head on top of a BERT-style encoder.

    The hidden state at sequence position 0 (the CLS token) goes through dropout
    (p=0.3) and a single linear layer that emits ``num_labels`` raw logits.
    """

    def __init__(self, bert_model, num_labels=1):
        super().__init__()
        self.bert = bert_model
        self.dropout = torch.nn.Dropout(0.3)
        hidden_dim = self.bert.config.hidden_size
        self.classifier = torch.nn.Linear(hidden_dim, num_labels)

    def forward(self, input_ids, attention_mask):
        """Return raw (pre-sigmoid) logits for a batch of token ids."""
        encoder_out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # First position of the last hidden layer = CLS token representation
        first_token_state = encoder_out.last_hidden_state[:, 0, :]
        return self.classifier(self.dropout(first_token_state))

# From here on `model` is the classifier wrapper, not the bare encoder loaded above.
model = SciBERTClassifier(model).to(device)

# Use weighted BinaryCrossEntropy for imbalanced classes.
# pos_weight multiplies the loss contribution of positive (label 1) samples; it needs
# one entry per output logit, and the classifier emits a single logit.
class_weights = torch.tensor([0.55]).to(device)  # Adjust the weights based on class imbalance
criterion = BCEWithLogitsLoss(pos_weight=class_weights)

# All parameters (encoder and head) are handed to the optimizer, i.e. full fine-tuning.
optimizer = AdamW(model.parameters(), lr=2e-5)
epochs = 3

# Training loop
for epoch in range(epochs):
    model.train()  # enable dropout
    total_loss = 0
    for batch in tqdm(train_loader):
        optimizer.zero_grad()  # clear gradients left over from the previous step

        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        # Labels arrive as a 1-D batch; add a trailing axis to match the (batch, 1) logits.
        labels = batch['labels'].to(device).unsqueeze(1)

        logits = model(input_ids=input_ids, attention_mask=attention_mask)
        loss = criterion(logits, labels)
        total_loss += loss.item()

        loss.backward()
        optimizer.step()

    # Mean of the per-batch losses for this epoch
    avg_loss = total_loss / len(train_loader)
    print(f'Epoch {epoch + 1}/{epochs} - Loss: {avg_loss:.4f}')

# Define the evaluation function for unique misclassification checking
def evaluate(model, loader, dataset_name="Test", df=None):
    """Evaluate a binary classifier on a loader and print metrics and its mistakes.

    Prints accuracy, a classification report and, when ``df`` is given, every
    misclassified sample with its predicted probability.

    Args:
        model: Callable taking ``input_ids`` and ``attention_mask`` keyword
            arguments and returning one logit per sample.
        loader: Iterable of batches with 'text', 'input_ids', 'attention_mask'
            and 'labels' entries.
        dataset_name (str): Label used as a prefix in the printed output.
        df: Optional DataFrame. When it has both 'Processed_Description' and
            'Description' columns, the original description of each
            misclassified text is printed as well.

    Returns:
        None. Everything is reported through ``print``.
    """
    model.eval()
    y_preds = []
    y_true = []
    misclassified_samples = []

    with torch.no_grad():
        for batch in loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            # Labels are stored as floats for the loss; metrics want integer classes.
            labels = batch['labels'].cpu().numpy().astype(int)

            logits = model(input_ids=input_ids, attention_mask=attention_mask)
            probs = torch.sigmoid(logits).cpu().numpy().flatten()
            preds = (probs > 0.5).astype(int)

            y_preds.extend(preds.tolist())
            y_true.extend(labels.tolist())

            # Track misclassified samples by their text. A positional index into the
            # loader says nothing about the row in `df`: the splits are shuffled
            # subsets of it and the training loader reshuffles on every pass.
            for text, pred, true, prob in zip(batch['text'], preds, labels, probs):
                if pred != true:
                    misclassified_samples.append((text, int(pred), int(true), float(prob)))

    # Print classification report
    accuracy = accuracy_score(y_true, y_preds)
    print(f'{dataset_name} Accuracy: {accuracy * 100:.2f}%')
    print(f'{dataset_name} Classification Report:\n')
    # Passing labels explicitly keeps the report working when a split (or the
    # predictions) happen to contain only one of the two classes.
    print(classification_report(y_true, y_preds, labels=[0, 1], target_names=['Non-MI', 'MI']))

    # Print misclassified descriptions with probabilities
    if df is not None:
        # Map processed text -> original description; the first occurrence wins
        # when the same processed text appears more than once.
        description_by_text = {}
        if {'Processed_Description', 'Description'}.issubset(df.columns):
            for processed, original in zip(df['Processed_Description'], df['Description']):
                description_by_text.setdefault(processed, original)

        print(f'\n{dataset_name} Misclassified Samples:')
        for text, pred, true, prob in misclassified_samples:
            if text in description_by_text:
                print(f"Description: {description_by_text[text]}")
            print(f"Text: {text}\nPredicted Label: {pred}, True Label: {true}, Probability: {prob:.4f}\n")


# Report metrics on both splits
evaluate(model, train_loader, "Train", df)
evaluate(model, test_loader, "Test", df)

# Save the trained model and tokenizer
model.bert.save_pretrained('scibert_cls_model')
tokenizer.save_pretrained('scibert_cls_tokenizer')
# save_pretrained above only writes the encoder (model.bert). The linear head that
# turns the CLS state into the MI logit lives on the wrapper, so store it as well;
# otherwise the classifier cannot be restored from disk.
torch.save(model.classifier.state_dict(), 'scibert_cls_model/classifier_head.pt')


## Helper Functions

In [None]:
def prepare_data(texts, tokenizer, max_len):
    """
    Tokenizes a batch of texts and moves the resulting tensors to the active device.

    Args:
        texts (list of str): Input texts.
        tokenizer: Loaded tokenizer.
        max_len (int): Length every sequence is padded or truncated to.

    Returns:
        dict: 'input_ids' and 'attention_mask' tensors ready for the model.
    """
    tokenizer_options = {
        'add_special_tokens': True,
        'max_length': max_len,
        'return_token_type_ids': False,
        'padding': 'max_length',
        'truncation': True,
        'return_attention_mask': True,
        'return_tensors': 'pt',
    }
    encoded_batch = tokenizer.batch_encode_plus(texts, **tokenizer_options)
    # Only these two tensors are consumed by the classifier's forward pass.
    return {
        field: encoded_batch[field].to(device)
        for field in ('input_ids', 'attention_mask')
    }

def predict(texts, batch_size=32):
    """
    Predicts the class and confidence score for given texts using the fine-tuned model.

    Texts are sent through the model in chunks of ``batch_size`` so that a large
    input list does not have to fit into memory as a single batch.

    Args:
        texts (list of str): Input texts to predict.
        batch_size (int): Number of texts per forward pass. Must be at least 1.

    Returns:
        list: One dict per text with 'predicted_label' ('MI' or 'Non-MI') and
        'confidence' (probability of the predicted class). Empty input gives [].

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    texts = list(texts)
    if not texts:
        # The tokenizer cannot encode an empty batch; there is nothing to predict.
        return []

    model.eval()  # Ensure the model is in evaluation mode

    probs = []
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            chunk = texts[start:start + batch_size]
            data = prepare_data(chunk, tokenizer, max_len=256)  # Use the same max_len as during training
            logits = model(data['input_ids'], data['attention_mask'])
            # Sigmoid to get probability scores for the positive (MI) class
            probs.extend(torch.sigmoid(logits).cpu().numpy().flatten().tolist())

    results = []
    for prob in probs:
        is_mi = prob > 0.5  # Threshold at 0.5 for binary classification
        results.append({
            'predicted_label': 'MI' if is_mi else 'Non-MI',
            'confidence': prob if is_mi else 1 - prob  # Confidence for the predicted class
        })

    return results

# Prepare the text data from df2
# Prepare the text data from df2
texts = df2['Processed_Description'].tolist()

# Get predictions
predictions = predict(texts)

# Convert predictions to DataFrame for easier manipulation.
# Naming the columns keeps the frame well-formed even when there are no predictions.
pred_df = pd.DataFrame(predictions, columns=['predicted_label', 'confidence'])

# Add predictions and confidence to df2.
# Assign by position: pred_df has a fresh 0..n-1 index, and assigning a Series aligns
# on index labels, which would give NaN/misplaced rows if df2's index is anything else.
df2['Scibert_pred'] = pred_df['predicted_label'].to_numpy()
df2['Scibert_prob'] = pred_df['confidence'].to_numpy()

# Print the updated DataFrame with predictions
print(df2[['Processed_Description', 'Scibert_pred', 'Scibert_prob']])


## Integrated Gradients Helper Function

In [None]:
# Install Captum for Integrated Gradients if not already installed
!pip install captum

import torch
from captum.attr import IntegratedGradients

# Helper function to compute and return important tokens
def compute_important_tokens(model, tokenizer, text, label, max_len=256, baseline_text="[PAD]", n_steps=50):
    """
    Computes the most influential tokens of a text using Integrated Gradients.

    Token ids are integers and cannot be differentiated, so attributions are taken
    with respect to the output of the encoder's embedding layer
    (Layer Integrated Gradients) and summed over the embedding dimension to get
    one score per token.

    Args:
        model: The fine-tuned classifier. It must expose its encoder as
            ``model.bert`` with an ``embeddings`` sub-module and return a single
            logit per sample.
        tokenizer: The tokenizer corresponding to the model.
        text (str): The input text.
        label (int): The predicted class (0 or 1). The model has one logit, so the
            attribution target is always that logit; for ``label == 0`` the scores
            are negated so that a positive score always supports ``label``.
        max_len (int): The maximum token length of input text.
        baseline_text (str): A single vocabulary token used to fill the baseline
            sequence, typically the pad token.
        n_steps (int): Number of steps for integration in IG.

    Returns:
        list of tuples: Up to five ``(token, score)`` pairs, ordered by absolute score.
    """
    # Imported locally: the notebook's import cell only brings in IntegratedGradients.
    from captum.attr import LayerIntegratedGradients

    model.eval()

    # Tokenize the input (special tokens are added by default)
    input_ids = tokenizer.encode(text, return_tensors='pt', max_length=max_len, truncation=True).to(device)
    attention_mask = torch.ones_like(input_ids)

    # The baseline must have the same shape as the input: fill it with the baseline
    # token and keep the first/last (special) tokens so only the content differs.
    baseline_token_id = tokenizer.convert_tokens_to_ids(baseline_text)
    if baseline_token_id is None:
        baseline_token_id = tokenizer.pad_token_id
    baseline_ids = torch.full_like(input_ids, baseline_token_id)
    baseline_ids[:, 0] = input_ids[:, 0]
    baseline_ids[:, -1] = input_ids[:, -1]

    # Attribute with respect to the embedding layer's output
    # NOTE(review): assumes a BERT-style encoder exposing `.embeddings` — confirm for other backbones.
    lig = LayerIntegratedGradients(model, model.bert.embeddings)
    attributions = lig.attribute(
        inputs=input_ids,
        baselines=baseline_ids,
        target=0,  # the only output column
        additional_forward_args=(attention_mask,),
        n_steps=n_steps,
        return_convergence_delta=False,
    )

    # Collapse (1, seq_len, hidden) -> (seq_len,): one score per token
    token_scores = attributions.sum(dim=-1).squeeze(0)
    if int(label) == 0:
        # The class-0 logit of a sigmoid classifier is the negated class-1 logit.
        token_scores = -token_scores
    token_scores = token_scores.detach().cpu().numpy()

    # Decode tokens and match with attributions
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())

    # Zip tokens with their attributions and sort by importance
    important_tokens = sorted(
        ((token, float(score)) for token, score in zip(tokens, token_scores)),
        key=lambda pair: abs(pair[1]),
        reverse=True,
    )
    return important_tokens[:5]  # Return top 5 important tokens


# Modified predict function incorporating important token computation
def predict(texts, batch_size=32):
    """
    Predicts the class, confidence score, and important tokens for given texts.

    Probabilities are computed in chunks of ``batch_size`` with gradients disabled;
    token attributions are computed afterwards, one text at a time, with gradients
    enabled because Integrated Gradients needs them.

    Args:
        texts (list of str): Input texts to predict.
        batch_size (int): Number of texts per forward pass. Must be at least 1.

    Returns:
        list: One dict per text with 'predicted_label' ('MI' or 'Non-MI'),
        'confidence' (probability of the predicted class) and 'important_tokens'.
        Empty input gives [].

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    texts = list(texts)
    if not texts:
        # The tokenizer cannot encode an empty batch; there is nothing to predict.
        return []

    model.eval()  # Ensure the model is in evaluation mode

    probs = []
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            chunk = texts[start:start + batch_size]
            data = prepare_data(chunk, tokenizer, max_len=256)  # Use the same max_len as during training
            logits = model(data['input_ids'], data['attention_mask'])
            # Sigmoid to get probability scores for the positive (MI) class
            probs.extend(torch.sigmoid(logits).cpu().numpy().flatten().tolist())

    results = []
    # Compute important tokens for each prediction (outside no_grad on purpose)
    for text, prob in zip(texts, probs):
        pred = int(prob > 0.5)  # Threshold at 0.5; plain int so it is a valid attribution target
        important_tokens = compute_important_tokens(model, tokenizer, text, pred)
        results.append({
            'predicted_label': 'MI' if pred == 1 else 'Non-MI',
            'confidence': prob if pred == 1 else 1 - prob,  # Confidence for the predicted class
            'important_tokens': important_tokens
        })

    return results

# Prepare the text data from df2
# Prepare the text data from df2
texts = df2['Processed_Description'].tolist()

# Get predictions with important tokens
predictions = predict(texts)

# Convert predictions to DataFrame for easier manipulation.
# Naming the columns keeps the frame well-formed even when there are no predictions.
pred_df = pd.DataFrame(predictions, columns=['predicted_label', 'confidence', 'important_tokens'])

# Add predictions, confidence, and important tokens to df2.
# Assign by position: pred_df has a fresh 0..n-1 index, and assigning a Series aligns
# on index labels, which would give NaN/misplaced rows if df2's index is anything else.
df2['Scibert_pred'] = pred_df['predicted_label'].to_numpy()
df2['Scibert_prob'] = pred_df['confidence'].to_numpy()
df2['Important_Tokens'] = pred_df['important_tokens'].to_numpy()

# Print the updated DataFrame with predictions and important tokens
print(df2[['Processed_Description', 'Scibert_pred', 'Scibert_prob', 'Important_Tokens']])


# Freezing SciBERT

1. **Freezing parameters:** The parameters of the SciBERT model are frozen by setting `param.requires_grad = False` on each of them, so no gradients are computed for the encoder.

2. **Optimizer scope:** The optimizer is given only `model.classifier.parameters()`, so only the weights of the classifier layer are updated during training.

In [None]:
# Start from a freshly loaded pretrained encoder. By this point `model` may already be
# a SciBERTClassifier from the fine-tuning cell; wrapping that again would fail because
# the wrapper has no `.config`, and freezing it would also freeze its trained head.
bert_encoder = AutoModel.from_pretrained("allenai/scibert_scivocab_uncased")

# Freeze all the parameters of the SciBERT model
for param in bert_encoder.parameters():
    param.requires_grad = False

# Define a classifier based on SciBERT
class SciBERTClassifier(torch.nn.Module):
    """Classification head (dropout + linear) on the CLS state of a BERT-style encoder."""

    def __init__(self, bert_model, num_labels=1):
        super(SciBERTClassifier, self).__init__()
        self.bert = bert_model
        self.dropout = torch.nn.Dropout(0.3)
        self.classifier = torch.nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls_output = outputs.last_hidden_state[:, 0, :]  # CLS token output
        cls_output = self.dropout(cls_output)
        logits = self.classifier(cls_output)
        return logits

# Create the SciBERT classifier and move it to the device
model = SciBERTClassifier(bert_encoder).to(device)

# Use weighted BinaryCrossEntropy for imbalanced classes.
# pos_weight takes one weight per output logit (here: one), applied to positive samples.
# A [negative, positive] pair such as [1.0, 5.0] is the CrossEntropyLoss convention;
# the equivalent here is the ratio positive/negative = 5.0.
class_weights = torch.tensor([5.0]).to(device)  # Adjust the weights based on class imbalance
criterion = BCEWithLogitsLoss(pos_weight=class_weights)

# Only optimize the classifier's parameters
optimizer = AdamW(model.classifier.parameters(), lr=2e-5)

# Training loop
epochs = 3
for epoch in range(epochs):
    model.train()
    # The encoder is frozen, so keep it in eval mode: otherwise its internal dropout
    # would add noise to features that cannot adapt to it.
    model.bert.eval()
    total_loss = 0
    for batch in tqdm(train_loader):
        optimizer.zero_grad()

        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        # Add a trailing axis so labels match the (batch, 1) logits
        labels = batch['labels'].to(device).unsqueeze(1)

        logits = model(input_ids=input_ids, attention_mask=attention_mask)
        loss = criterion(logits, labels)
        total_loss += loss.item()

        loss.backward()
        optimizer.step()

    avg_loss = total_loss / len(train_loader)
    print(f'Epoch {epoch + 1}/{epochs} - Loss: {avg_loss:.4f}')

# Evaluate the model on the training and test datasets
evaluate(model, train_loader, "Train", df)
evaluate(model, test_loader, "Test", df)

# Save the trained model and tokenizer
# NOTE(review): these are the same output directories the fine-tuning cell writes to,
# so running both cells overwrites the earlier encoder — confirm this is intended.
model.bert.save_pretrained('scibert_cls_model')
tokenizer.save_pretrained('scibert_cls_tokenizer')
# The encoder is frozen, so the head is the only thing this cell actually trained;
# save_pretrained on model.bert does not include it.
torch.save(model.classifier.state_dict(), 'scibert_cls_model/frozen_classifier_head.pt')


## Plotting Loss

In [None]:
import matplotlib.pyplot as plt
from torch.nn import BCEWithLogitsLoss
# transformers' own AdamW is deprecated and absent from recent releases; torch's
# implementation is the supported one and keeps the `AdamW` name bound consistently.
from torch.optim import AdamW
from tqdm import tqdm

# Initialize lists to store losses (one entry per epoch)
train_losses = []
val_losses = []

# Training loop.
# Continues training the current `model` with the `optimizer`, `criterion` and `epochs`
# that an earlier cell defined.
for epoch in range(epochs):
    model.train()
    total_train_loss = 0
    for batch in tqdm(train_loader):
        optimizer.zero_grad()

        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        # Add a trailing axis so labels match the (batch, 1) logits
        labels = batch['labels'].to(device).unsqueeze(1)

        logits = model(input_ids=input_ids, attention_mask=attention_mask)
        loss = criterion(logits, labels)
        total_train_loss += loss.item()

        loss.backward()
        optimizer.step()

    avg_train_loss = total_train_loss / len(train_loader)
    train_losses.append(avg_train_loss)
    print(f'Epoch {epoch + 1}/{epochs} - Train Loss: {avg_train_loss:.4f}')

    # Validation phase.
    # NOTE(review): the held-out test split doubles as the validation set here, so the
    # curve below is for monitoring only and should not drive model selection.
    model.eval()
    total_val_loss = 0
    with torch.no_grad():
        for batch in tqdm(test_loader):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device).unsqueeze(1)

            logits = model(input_ids=input_ids, attention_mask=attention_mask)
            loss = criterion(logits, labels)
            total_val_loss += loss.item()

    avg_val_loss = total_val_loss / len(test_loader)
    val_losses.append(avg_val_loss)
    print(f'Epoch {epoch + 1}/{epochs} - Validation Loss: {avg_val_loss:.4f}')

# Plotting the losses
plt.figure(figsize=(12, 6))
plt.plot(range(1, epochs + 1), train_losses, label='Train Loss')
plt.plot(range(1, epochs + 1), val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Train and Validation Loss')
plt.legend()
plt.grid(True)
plt.show()
