In [13]:
import torch
from transformers import BertTokenizer, BertModel
import torch.nn as nn
from google.colab import drive
# Mount Google Drive so load_model can read the fine-tuned checkpoint stored under /content/drive.
drive.mount(mountpoint='/content/drive')
# Device configuration: use CUDA when torch reports a GPU is available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Define the model class
class BERTClass(nn.Module):
    """Multi-label classification head on top of a pretrained BERT encoder.

    The BERT pooler output (``output[1]``) goes through dropout and a single
    linear layer that yields one raw logit per label (apply ``sigmoid`` for
    per-label probabilities).

    The attribute names ``bert_model``, ``dropout`` and ``linear`` become the
    ``state_dict`` key prefixes. They must match the checkpoint being loaded,
    so do not rename them.

    Args:
        n_labels: Number of output logits. Defaults to 32, the number of
            entries in ``target_list``.
        pretrained_name: Model id passed to ``BertModel.from_pretrained``.
        dropout_prob: Dropout probability applied to the pooled output.
    """

    def __init__(self, n_labels=32, pretrained_name='bert-base-uncased', dropout_prob=0.3):
        super().__init__()
        self.bert_model = BertModel.from_pretrained(pretrained_name)
        self.dropout = nn.Dropout(dropout_prob)
        # Take the encoder width from its config rather than hard-coding 768,
        # so a different BERT size does not silently need a code edit.
        self.linear = nn.Linear(self.bert_model.config.hidden_size, n_labels)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Return raw (pre-sigmoid) logits with one column per label."""
        output = self.bert_model(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # Index 1 of the BertModel output is the pooler output.
        pooled = output[1]
        return self.linear(self.dropout(pooled))

# Define target labels (32 labels based on the context)
# Label names, ordered exactly like the classifier's output columns:
# predict_text maps output index i to target_list[i], so never reorder this list
# without retraining. 32 entries, one per output unit of BERTClass.
target_list = [
    'ad hominem',
    'anecdotal fallacy',
    'appeal to authority',
    'appeal to consequences',
    'appeal to emotion',
    'appeal to fear',
    'appeal to novelty',
    'appeal to popularity',
    'appeal to ridicule',
    'appeal to tradition',
    'argument from ignorance',
    'bandwagon fallacy',
    'circular reasoning',
    'correlation vs. causation',
    'equivocation',
    'false analogy',
    'false attribution',
    'false dilemma',
    'genetic fallacy',
    'guilt by association',
    'hasty generalization',
    'no true scotsman',
    'red herring',
    'slippery slope',
    'straw man',
    'tu quoque',
    'appeal to motive',
    'loaded question',
    'misleading vividness',
    'none',
    'composition/division',
    'other',
]

def load_model(checkpoint_path="/content/drive/MyDrive/bert_model/best_model.pt",
               weights_only=False):
    """Build the tokenizer and a BERTClass model, and restore the fine-tuned weights.

    Args:
        checkpoint_path: Path to the ``.pt`` file saved during training. The
            default is the mounted Google Drive location; the weights can be
            downloaded from
            https://drive.google.com/file/d/1-8nspRmZ0x6pMZrOdWgPV9RINjFZYm0B/view?usp=share_link
        weights_only: Forwarded to ``torch.load``. Defaults to ``False`` to keep
            loading checkpoints that contain arbitrary pickled objects; pass
            ``True`` when the file holds only tensors and plain containers.

    Returns:
        tuple: ``(model, tokenizer)`` — the model moved to ``device`` and put in
        eval mode, and the ``bert-base-uncased`` tokenizer.
    """
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BERTClass()

    # NOTE(security): torch.load unpickles the file. With weights_only=False a
    # tampered checkpoint (e.g. one fetched from a share link) can execute
    # arbitrary code, so only load files you trust. Passing the flag explicitly
    # also silences torch's FutureWarning. It keeps behavior the same across the
    # torch release that changes the default.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=weights_only)

    # Accept both the training-loop format ({'model_state_dict': ..., ...})
    # and a checkpoint that is just a bare state_dict.
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)

    model = model.to(device)
    model.eval()  # disable dropout for inference
    return model, tokenizer

def predict_text(raw_text, model, tokenizer, threshold=0.5, max_length=256,
                 labels=None, device=None):
    """Detect fallacies in one string, print a report, and return the labels found.

    Args:
        raw_text: Text to classify.
        model: Module called as ``model(input_ids, attention_mask, token_type_ids)``
            that returns one logit per label (e.g. ``BERTClass``).
        tokenizer: Tokenizer exposing ``encode_plus`` (e.g. ``BertTokenizer``).
        threshold: A label is reported when its sigmoid probability is strictly
            greater than this value.
        max_length: Length the input is padded/truncated to.
        labels: Label names indexed like the model's output columns; defaults
            to ``target_list``.
        device: Device for the input tensors; defaults to the device of the
            model's first parameter (CPU if the model has no parameters).

    Returns:
        list[str]: Detected fallacy names in label order. The ``'none'`` label is
        never reported as a fallacy, so an empty list means nothing was detected.

    Raises:
        ValueError: if the number of model outputs differs from ``len(labels)``.
    """
    if labels is None:
        labels = target_list
    if device is None:
        # Keep inputs on the same device as the model instead of relying on a global.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    encoded_text = tokenizer.encode_plus(
        raw_text,
        max_length=max_length,
        add_special_tokens=True,
        return_token_type_ids=True,
        padding='max_length',
        return_attention_mask=True,
        return_tensors='pt',
        truncation=True,
    )
    input_ids = encoded_text['input_ids'].to(device, dtype=torch.long)
    attention_mask = encoded_text['attention_mask'].to(device, dtype=torch.long)
    token_type_ids = encoded_text['token_type_ids'].to(device, dtype=torch.long)

    with torch.no_grad():
        logits = model(input_ids, attention_mask, token_type_ids)
    # Single input text, so take row 0 of the batch.
    probs = torch.sigmoid(logits)[0].tolist()

    if len(probs) != len(labels):
        raise ValueError(
            f"model produced {len(probs)} outputs but {len(labels)} labels were given"
        )

    # 'none' is an output the model can emit, but it is not a fallacy: don't list it.
    detected = [
        label for label, prob in zip(labels, probs)
        if prob > threshold and label != 'none'
    ]

    print(f"\nAnalyzing: {raw_text}")
    print("\nDetected fallacies:")
    for label in detected:
        print(f"- {label}")
    if not detected:
        print("No fallacies detected.")
    return detected

# Main execution
if __name__ == "__main__":
    # Load the tokenizer and fine-tuned classifier once; the example cells
    # below reuse the module-level `model` and `tokenizer` names.
    print("Loading model...")
    (model, tokenizer) = load_model()
    print("Model loaded successfully!")



Loading model...


  checkpoint = torch.load("/content/drive/MyDrive/bert_model/best_model.pt",


Model loaded successfully!


In [24]:
text = "I saw a red apple so all the apples are red"
# The trailing ';' stops Jupyter from echoing any return value below the printed report.
predict_text(raw_text=text, model=model, tokenizer=tokenizer);



Analyzing: I saw a red apple so all the apples are red

Detected fallacies:
- hasty generalization


In [25]:
text = "I like your shoes."
# The trailing ';' stops Jupyter from echoing any return value below the printed report.
predict_text(raw_text=text, model=model, tokenizer=tokenizer);


Analyzing: I like your shoes.

Detected fallacies:
No fallacies detected.
