<a href="https://colab.research.google.com/github/JithinBinoy-sudo/Improve-Downward-Monotonicity/blob/main/Downward_Monotonicity_Improvement.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install --upgrade transformers


Collecting transformers
  Downloading transformers-4.55.4-py3-none-any.whl.metadata (41 kB)
     42.0/42.0 kB 2.8 MB/s eta 0:00:00
Downloading transformers-4.55.4-py3-none-any.whl (11.3 MB)
   11.3/11.3 MB 88.0 MB/s eta 0:00:00
Installing collected packages: transformers
  Attempting uninstall: transformers
    Found existing installation: transformers 4.55.2
    Uninstalling transformers-4.55.2:
      Successfully uninstalled transformers-4.55.2
Successfully installed transformers-4.55.4


In [None]:
import os
import random
import pandas as pd
import torch
from collections import Counter
from sklearn.model_selection import train_test_split
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification, Trainer, TrainingArguments

# ----------------------------
# 1. DISABLE W&B LOGGING
# ----------------------------
os.environ["WANDB_DISABLED"] = "true"

# ----------------------------
# 2. DOWNLOAD HELP DATASET
# ----------------------------
if not os.path.exists("help_dataset.tsv"):
    os.system("wget https://github.com/verypluming/HELP/raw/master/output_en/pmb_train_v1.0.tsv -O help_dataset.tsv")

# ----------------------------
# 3. LOAD HELP DATASET
# ----------------------------
df_help = pd.read_csv("help_dataset.tsv", sep="\t")
df_help = df_help[['ori_sentence', 'new_sentence', 'gold_label']]
df_help.columns = ['premise', 'hypothesis', 'label']

label_mapping = {'entailment': 0, 'neutral': 1, 'contradiction': 2}
df_help['label'] = df_help['label'].map(label_mapping)

help_data = list(zip(df_help['premise'], df_help['hypothesis'], df_help['label']))
print(f"Loaded HELP dataset with {len(help_data)} examples.")

# ----------------------------
# 4. POLARITY TAGGING
# ----------------------------
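import re

# Match cue words on word boundaries so that "no" does not fire inside "not", "cannot" or "noise".
# NOTE: "exactly" is non-monotone rather than downward-entailing; it is kept to preserve the original cue list.
_DOWN_CUES = re.compile(r"\b(?:no|none|never|few|at most|not all|without|exactly)\b")

def add_polarity_tags(text):
    """Prefix text with [DOWN] if it contains a downward cue word, otherwise with [UP]."""
    if _DOWN_CUES.search(text.lower()):
        return "[DOWN] " + text
    return "[UP] " + text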

# ----------------------------
# 5. CUSTOM LOGICAL DATASET (130 EXAMPLES)
# ----------------------------
sample_data = [
    # --- ENTAILMENT (43) ---
    ("All cats are mammals", "Some cats are mammals", 0),
    ("No dogs can fly", "Some dogs cannot fly", 0),
    ("Every student passed the exam", "Some students passed", 0),
    ("At least three birds are singing", "Some birds are singing", 0),
    ("Each car has wheels", "All cars have wheels", 0),
    ("All apples are red", "Some apples are red", 0),
    ("Every teacher attended the meeting", "Some teachers attended", 0),
    ("Not all birds can fly", "Some birds cannot fly", 0),
    ("Exactly two players scored", "Two players scored", 0),
    ("Every child got a gift", "Each child received a present", 0),
    ("Some men are teachers", "Some people are teachers", 0),
    ("All cars have wheels", "Some cars have wheels", 0),
    ("Every dog barked loudly", "Some dogs barked", 0),
    ("At least one window is open", "Some windows are open", 0),
    ("Not all lights are on", "Some lights are off", 0),
    ("Exactly three students passed", "Three students passed", 0),
    ("Each room has a window", "Every room has a window", 0),
    ("All birds can fly", "Some birds can fly", 0),
    ("Some kids are laughing", "Some kids are happy", 0),
    ("Every student joined", "All students joined", 0),
    ("At least two chairs are broken", "Some chairs are broken", 0),
    ("Some phones are charging", "Some devices are charging", 0),
    ("No cats are swimming", "No felines are swimming", 0),
    ("Each house has a roof", "Every house has a roof", 0),
    ("Some dogs are barking", "Some animals are making noise", 0),
    ("All students received books", "Some students received books", 0),
    ("Every flower is blooming", "Some flowers are blooming", 0),
    ("At most five chairs are broken", "At most six chairs are broken", 0),
    ("Exactly one window is open", "One window is open", 0),
    ("All cars have brakes", "Some cars have brakes", 0),
    ("Every child is happy", "Some children are happy", 0),
    ("Some trees are tall", "Some plants are tall", 0),
    ("Not all birds sing", "Some birds are silent", 0),
    ("All laptops are charged", "Some laptops are charged", 0),
    ("Every student answered correctly", "Some students answered correctly", 0),
    ("Some cats are playful", "Some animals are playful", 0),
    ("All houses have doors", "Some houses have doors", 0),
    ("Each person attended the meeting", "Every person attended", 0),
    ("At most three students failed", "At most four students failed", 0),
    ("Every bird can fly", "Some birds can fly", 0),
    ("Some flowers are red", "Some plants are red", 0),
    ("All cars are clean", "Some cars are clean", 0),
    ("Every dog is barking", "Some dogs are barking", 0),

    # --- NEUTRAL (43) ---
    ("Few dogs barked", "Some dogs barked", 1),
    ("Many cats slept", "All cats slept", 1),
    ("Most students passed", "Every student passed", 1),
    ("Some children are playing", "Some children are studying", 1),
    ("Every dog barked", "Some dogs barked", 1),
    ("At least 5 students joined", "At least 10 students joined", 1),
    ("John visited Paris", "Mary visited London", 1),
    ("Some books are new", "Some chairs are new", 1),
    ("Many birds are singing", "All birds are flying", 1),
    ("At least one car stopped", "At least three cars stopped", 1),
    ("Some people are dancing", "Some people are talking", 1),
    ("Every child laughed", "Some children cried", 1),
    ("A few dogs barked", "A few cats meowed", 1),
    ("Some laptops are on", "Some tablets are on", 1),
    ("All chairs are brown", "All tables are brown", 1),
    ("Most kids enjoyed the show", "Most kids disliked the show", 1),
    ("The store opens at 8 AM", "The store closes at 8 PM", 1),
    ("Some players scored", "Some coaches cheered", 1),
    ("He went to school", "She stayed home", 1),
    ("Some streets are closed", "Some houses are closed", 1),
    ("Every tree is tall", "Some trees are tall", 1),
    ("A few birds are flying", "Some animals are flying", 1),
    ("Most chairs are occupied", "Some chairs are empty", 1),
    ("At least three doors are open", "Some doors are open", 1),
    ("Some people are walking", "Some people are running", 1),
    ("Every child is smiling", "Some children are frowning", 1),
    ("Some phones are off", "Some devices are on", 1),
    ("Most students studied", "Some students did not study", 1),
    ("A few cats are sleeping", "Some animals are sleeping", 1),
    ("Some windows are closed", "Some doors are closed", 1),
    ("Every dog is awake", "Some dogs are asleep", 1),
    ("Some tables are round", "Some chairs are round", 1),
    ("Most flowers are blooming", "Some flowers are not blooming", 1),
    ("At least two cars stopped", "Some cars stopped", 1),
    ("Some people are talking", "Some people are listening", 1),
    ("Every student is present", "Some students are absent", 1),
    ("Some birds are chirping", "Some animals are making noise", 1),
    ("Most laptops are working", "Some laptops are broken", 1),
    ("A few windows are open", "Some doors are open", 1),
    ("Some children are running", "Some children are walking", 1),
    ("Most lights are on", "Some lights are off", 1),
    ("Some chairs are broken", "Some tables are broken", 1),
    ("Every phone is charged", "Some phones are not charged", 1),

    # --- CONTRADICTION (44) ---
    ("No cars are electric", "Some cars are electric", 2),
    ("Every student passed", "Some students did not pass", 2),
    ("Some children are playing", "No children are playing", 2),
    ("Without any help, she succeeded", "She did not succeed", 2),
    ("All cats are black", "No cats are black", 2),
    ("Most birds can fly", "No birds can fly", 2),
    ("Exactly one student attended", "No student attended", 2),
    ("She is alive", "She is dead", 2),
    ("No one entered the room", "Someone entered the room", 2),
    ("Every light is off", "Some lights are on", 2),
    ("No apples are red", "Some apples are red", 2),
    ("All students are happy", "No students are happy", 2),
    ("Most chairs are broken", "No chairs are broken", 2),
    ("Every dog is barking", "No dog is barking", 2),
    ("Some windows are open", "No windows are open", 2),
    ("All birds are singing", "No birds are singing", 2),
    ("The box is empty", "The box is full", 2),
    ("She passed the test", "She failed the test", 2),
    ("No phones are charging", "Some phones are charging", 2),
    ("He is present", "He is absent", 2),
    ("All tables are round", "No tables are round", 2),
    ("Some people are outside", "Nobody is outside", 2),
    ("No dogs are barking", "Some dogs are barking", 2),
    ("Every cat is sleeping", "Some cats are awake", 2),
    ("No students attended", "Some students attended", 2),
    ("All birds are flying", "No birds are flying", 2),
    ("Every light is on", "Some lights are off", 2),
    ("No chairs are broken", "Some chairs are broken", 2),
    ("All phones are off", "Some phones are on", 2),
    ("Every tree is tall", "Some trees are short", 2),
    ("No windows are open", "Some windows are open", 2),
    ("All doors are closed", "Some doors are open", 2),
    ("Every student is late", "Some students are on time", 2),
    ("No birds are singing", "Some birds are singing", 2),
    ("All laptops are off", "Some laptops are on", 2),
    ("Every child is crying", "Some children are laughing", 2),
    ("No cats are playing", "Some cats are playing", 2),
    ("All flowers are dead", "Some flowers are alive", 2),
    ("Every dog is asleep", "Some dogs are awake", 2),
    ("No students passed", "Some students passed", 2),
    ("All tables are broken", "Some tables are fine", 2),
    ("Every phone is broken", "Some phones are working", 2),
    ("No lights are on", "Some lights are on", 2),
    ("All birds are dead", "Some birds are alive", 2),
]


# ----------------------------
# 6. BALANCE HELP DATASET
# ----------------------------
counts = Counter([label for _, _, label in sample_data])
min_count = min(counts.values())

help_by_label = {0: [], 1: [], 2: []}
for p, h, l in help_data:
    help_by_label[l].append((p, h, l))

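random.seed(42)  # fix the seed so the sampled HELP subset is reproducible

# HELP contains no contradiction examples, so help_by_label[2] stays empty
# and only entailment and neutral pairs are sampled.
help_balanced = []
for l in [0, 1, 2]:
    help_balanced.extend(random.sample(help_by_label[l], min(min_count, len(help_by_label[l]))))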

print(f"Balanced HELP dataset size: {len(help_balanced)}")
counts_help = Counter([label for _, _, label in help_balanced])
print(f"Class distribution in balanced HELP dataset: {counts_help}")

# ----------------------------
# 7. MERGE CUSTOM + HELP
# ----------------------------
all_data = sample_data + help_balanced
all_data = [(add_polarity_tags(p), add_polarity_tags(h), l) for p,h,l in all_data]

# ----------------------------
# 8. TRAIN/TEST SPLIT
# ----------------------------
train_texts, test_texts, train_labels, test_labels = train_test_split(
    [(p,h) for p,h,l in all_data],
    [l for _,_,l in all_data],
    test_size=0.1,
    stratify=[l for _,_,l in all_data],
    random_state=42
)
train_data = [(p,h,l) for (p,h),l in zip(train_texts, train_labels)]
test_data  = [(p,h,l) for (p,h),l in zip(test_texts, test_labels)]

# ----------------------------
# 9. TOKENIZER
# ----------------------------
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
tokenizer.add_special_tokens({"additional_special_tokens": ["[UP]", "[DOWN]"]})

# ----------------------------
# 10. DATASET CLASS
# ----------------------------
from torch.utils.data import Dataset

class MonotonicityDataset(Dataset):
    def __init__(self, data, tokenizer, max_len=64):
        self.data = data
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        p, h, label = self.data[idx]
        enc = self.tokenizer(
            p, h,
            truncation=True,
            padding="max_length",
            max_length=self.max_len,
            return_tensors="pt"
        )
        return {
            "input_ids": enc["input_ids"].squeeze(0),
            "attention_mask": enc["attention_mask"].squeeze(0),
            "labels": torch.tensor(label, dtype=torch.long)
        }

train_ds = MonotonicityDataset(train_data, tokenizer)
test_ds  = MonotonicityDataset(test_data, tokenizer)

# ----------------------------
# 11. MODEL
# ----------------------------
model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=3)
model.resize_token_embeddings(len(tokenizer))

# ----------------------------
# 12. METRICS
# ----------------------------
from sklearn.metrics import accuracy_score, f1_score

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = logits.argmax(axis=-1)
    return {
        "accuracy": accuracy_score(labels, preds),
        "f1_macro": f1_score(labels, preds, average="macro")
    }

# ----------------------------
# 13. TRAINING
# ----------------------------
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=10,
    per_device_train_batch_size=4,
    learning_rate=2e-5,
    eval_strategy="epoch",
    save_strategy="no",
    logging_steps=10,
    logging_dir="./logs",
    report_to="none"  # disable W&B and other logging integrations explicitly
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    processing_class=tokenizer,  # replaces the deprecated `tokenizer` argument
    compute_metrics=compute_metrics
)

print("Training started...")
trainer.train()

# ----------------------------
# 14. EVALUATION
# ----------------------------
results = trainer.evaluate(eval_dataset=test_ds)
print("Overall Evaluation:", results)

# ----------------------------
# 15. UP/DOWN ACCURACY REPORTING
# ----------------------------
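def polarity_metrics(trainer, dataset):
    """Report accuracy separately for examples whose premise is tagged [UP] or [DOWN]."""
    model = trainer.model
    model.eval()  # disable dropout so predictions are deterministic
    device = next(model.parameters()).device  # inputs must live on the model's device
    up_id = tokenizer.convert_tokens_to_ids("[UP]")

    up_preds, up_labels = [], []
    down_preds, down_labels = [], []

    for item in dataset:
        input_ids = item['input_ids'].unsqueeze(0).to(device)
        attention_mask = item['attention_mask'].unsqueeze(0).to(device)
        label = item['labels'].item()
        with torch.no_grad():  # inference only, no gradients needed
            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
        pred = logits.argmax(dim=-1).item()

        # Index 0 is [CLS]; index 1 is the premise's [UP]/[DOWN] tag.
        if input_ids[0, 1].item() == up_id:
            up_preds.append(pred)
            up_labels.append(label)
        else:
            down_preds.append(pred)
            down_labels.append(label)

    # Skip a group that has no examples instead of scoring empty lists.
    if up_labels:
        print(f"UP Accuracy: {accuracy_score(up_labels, up_preds):.2f}")
    if down_labels:
        print(f"DOWN Accuracy: {accuracy_score(down_labels, down_preds):.2f}")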

polarity_metrics(trainer, test_ds)


Loaded HELP dataset with 35891 examples.
Balanced HELP dataset size: 86
Class distribution in balanced HELP dataset: Counter({0: 43, 1: 43})
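
The distribution printed above has no label 2, so the sampled HELP subset adds only entailment and neutral pairs. The sketch below works out what that implies for the merged dataset. It assumes the custom list holds 43 entailment, 43 neutral and 44 contradiction pairs (counted from the cell above), and that `train_test_split` rounds the 10% test share up; the variable names are illustrative and not part of the notebook.

```python
import math
from collections import Counter

# Assumed per-class counts: custom list (counted by hand) and the HELP sample printed above.
custom_counts = Counter({0: 43, 1: 43, 2: 44})
help_counts = Counter({0: 43, 1: 43})

merged = custom_counts + help_counts
total = sum(merged.values())

names = ["entailment", "neutral", "contradiction"]
for label, name in enumerate(names):
    share = merged[label] / total
    print(f"{name:<14} {merged[label]:>3}  ({share:.1%})")

# Assumption: the test split size is the 10% share rounded up.
test_size = math.ceil(0.1 * total)
print(f"total={total}, expected test examples={test_size}")
```

Under these assumptions contradiction is the minority class in the merged data, which is worth keeping in mind when reading the macro F1 scores.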


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Training started...




Epoch,Training Loss,Validation Loss,Accuracy,F1 Macro
1,1.0197,0.961019,0.409091,0.193548
2,0.9931,0.884485,0.545455,0.4
3,0.7865,0.802888,0.590909,0.599567
4,0.6592,0.769332,0.636364,0.62735
5,0.5015,0.789682,0.681818,0.679739
6,0.2478,0.967043,0.545455,0.574675
7,0.1845,0.88895,0.727273,0.753813
8,0.2793,0.910888,0.772727,0.79085
9,0.1229,0.969113,0.772727,0.79085
10,0.0431,1.026856,0.727273,0.730908
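
In the log above, training loss keeps falling while validation loss stops improving partway through, so the final epoch is not necessarily the best one. A minimal sketch that picks the best epoch by each criterion, using rows copied from the table (the names `log`, `best_by_loss` and `best_by_f1` are illustrative):

```python
# Rows copied from the training log:
# (epoch, training loss, validation loss, accuracy, macro F1)
log = [
    (1, 1.0197, 0.961019, 0.409091, 0.193548),
    (2, 0.9931, 0.884485, 0.545455, 0.4),
    (3, 0.7865, 0.802888, 0.590909, 0.599567),
    (4, 0.6592, 0.769332, 0.636364, 0.62735),
    (5, 0.5015, 0.789682, 0.681818, 0.679739),
    (6, 0.2478, 0.967043, 0.545455, 0.574675),
    (7, 0.1845, 0.88895, 0.727273, 0.753813),
    (8, 0.2793, 0.910888, 0.772727, 0.79085),
    (9, 0.1229, 0.969113, 0.772727, 0.79085),
    (10, 0.0431, 1.026856, 0.727273, 0.730908),
]

# Lowest validation loss and highest macro F1; max() keeps the first of tied rows.
best_by_loss = min(log, key=lambda row: row[2])
best_by_f1 = max(log, key=lambda row: row[4])

print(f"Lowest validation loss: epoch {best_by_loss[0]} ({best_by_loss[2]:.4f})")
print(f"Highest macro F1:       epoch {best_by_f1[0]} ({best_by_f1[4]:.4f})")
```

The two criteria disagree here, and with a small evaluation set either choice is noisy. Restoring the chosen checkpoint would also require saving checkpoints, which this run disables with `save_strategy="no"`.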


Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.

Overall Evaluation: {'eval_loss': 1.0268563032150269, 'eval_accuracy': 0.7272727272727273, 'eval_f1_macro': 0.730908152734778, 'eval_runtime': 2.3694, 'eval_samples_per_second': 9.285, 'eval_steps_per_second': 1.266, 'epoch': 10.0}
UP Accuracy: 0.67
DOWN Accuracy: 0.86
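
The UP/DOWN split depends entirely on which tag each premise receives, and cue-word tagging is sensitive to how matches are detected. The sketch below compares plain substring matching with word-boundary matching on sentences taken from the custom dataset. The helper names `tag_by_substring` and `tag_by_word` are hypothetical and exist only for this illustration.

```python
import re

CUES = ["no", "none", "never", "few", "at most", "not all", "without", "exactly"]

# Word-boundary pattern: a cue only counts when it appears as a whole word or phrase.
CUE_PATTERN = re.compile(r"\b(?:" + "|".join(re.escape(c) for c in CUES) + r")\b")

def tag_by_substring(text):
    """Tag [DOWN] if any cue occurs anywhere in the text, even inside another word."""
    lowered = text.lower()
    return "[DOWN]" if any(cue in lowered for cue in CUES) else "[UP]"

def tag_by_word(text):
    """Tag [DOWN] only if a cue occurs as a whole word or phrase."""
    return "[DOWN]" if CUE_PATTERN.search(text.lower()) else "[UP]"

sentences = [
    "No dogs can fly",
    "Some dogs cannot fly",
    "Some animals are making noise",
    "At most five chairs are broken",
]
for sentence in sentences:
    print(f"{sentence:<32} substring={tag_by_substring(sentence)}  word={tag_by_word(sentence)}")
```

Substring matching treats "cannot" and "noise" as containing the cue "no", so the two strategies can put the same sentence into different groups and change the per-polarity accuracies.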
