# Dataset Loading

### Do initial data cleaning and summary statistics (e.g. mean, max length etc.)

In [1]:
import json
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
import nltk
from nltk.tokenize import word_tokenize
from collections import Counter
import matplotlib.pyplot as plt

files = {
    "train": "train.json",
    "test": "test.json",
    "validation": "validation.json"

}

def load_json(path):
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

def analyze_sequences(premises, hypotheses, labels, dataset_name):
    premise_lengths = [len(str(p).split()) for p in premises]
    hypothesis_lengths = [len(str(h).split()) for h in hypotheses]
    total_lengths = [p_len + h_len for p_len, h_len in zip(premise_lengths, hypothesis_lengths)]
    
    complex_words = ['although', 'however', 'therefore', 'because', 'if', 'then', 
                    'unless', 'while', 'despite', 'according to', 'based on',
                    'estimate', 'calculate', 'determine', 'analyze', 'conclude']
    
    complexity_scores = []
    for p, h in zip(premises, hypotheses):
        text = f"{p} {h}".lower()
        score = sum(1 for word in complex_words if word in text)
        complexity_scores.append(score)
    
    label_counts = Counter(labels)
    
    print(f"\n {dataset_name.upper()} DATASET ANALYSIS")
    print("=" * 50)
    
    print(f"Total entries: {len(premises)}")
    print(f"Label distribution: {dict(label_counts)}")
    
    print(f"\nPremise length - Avg: {np.mean(premise_lengths):.1f}, Max: {max(premise_lengths)}, 95th %ile: {np.percentile(premise_lengths, 95):.1f}")
    print(f"Hypothesis length - Avg: {np.mean(hypothesis_lengths):.1f}, Max: {max(hypothesis_lengths)}, 95th %ile: {np.percentile(hypothesis_lengths, 95):.1f}")

    #Add total length string of Premise and Hypothesis to get TOTAL string length
    print(f"Total length (P+H) - Avg: {np.mean(total_lengths):.1f}, Max: {max(total_lengths)}, 95th %ile: {np.percentile(total_lengths, 95):.1f}")
    
    print(f"\nComplexity score - Avg: {np.mean(complexity_scores):.1f}, Max: {max(complexity_scores)}")
    print(f"Entries with complex logic (>2): {sum(1 for score in complexity_scores if score > 2)}/{len(complexity_scores)}")
    
    print(f"\nLongest premises (>30 words):")
    long_indices = [i for i, length in enumerate(premise_lengths) if length > 30]
    for i in long_indices[:3]: 
        print(f"  {i}: {premises[i][:100]}...")
    
    print(f"\nMost complex examples (score >= 3):")
    complex_indices = [i for i, score in enumerate(complexity_scores) if score >= 3]
    for i in complex_indices[:3]:  
        print(f"  {i}: Score {complexity_scores[i]}")
        print(f"     Premise: {premises[i][:80]}...")
        print(f"     Hypothesis: {hypotheses[i]}")
    
    return {
        'premise_lengths': premise_lengths,
        'hypothesis_lengths': hypothesis_lengths,
        'total_lengths': total_lengths,
        'complexity_scores': complexity_scores,
        'label_counts': label_counts
    }

all_analyses = {}

for name, file_path in files.items():
    try:
        data = load_json(file_path)
        
        premises_dict = data['premise']
        hypotheses_dict = data['hypothesis'] 
        labels_dict = data['label']
        
        premises = [premises_dict[key] for key in sorted(premises_dict.keys(), key=int)]
        hypotheses = [hypotheses_dict[key] for key in sorted(hypotheses_dict.keys(), key=int)]
        labels = [labels_dict[key] for key in sorted(labels_dict.keys(), key=int)]
        
        analysis = analyze_sequences(premises, hypotheses, labels, name)
        all_analyses[name] = analysis
        
    except Exception as e:
        print(f"✗ Error analyzing {file_path}: {e}")
        continue

print(f"\n RECOMMENDATION ANALYSIS")
print("=" * 50)

all_total_lengths = []
all_complexity_scores = []

for name, analysis in all_analyses.items():
    all_total_lengths.extend(analysis['total_lengths'])
    all_complexity_scores.extend(analysis['complexity_scores'])

avg_total_length = np.mean(all_total_lengths)
avg_complexity = np.mean(all_complexity_scores)
pct_long_sequences = sum(1 for length in all_total_lengths if length > 50) / len(all_total_lengths) * 100
pct_complex_sequences = sum(1 for score in all_complexity_scores if score > 2) / len(all_complexity_scores) * 100

print(f"Overall average sequence length (P+H): {avg_total_length:.1f} words")
print(f"Overall average complexity score: {avg_complexity:.1f}")
print(f"Sequences > 50 words: {pct_long_sequences:.1f}%")
print(f"Sequences with complex logic (>2): {pct_complex_sequences:.1f}%")

print(f"\n ARCHITECTURE RECOMMENDATION:")

if avg_total_length > 40 or pct_long_sequences > 20 or avg_complexity > 2.5:
    print("RECOMMEND: LSTM")
    print("   - Longer sequences benefit from LSTM's cell state")
    print("   - Complex reasoning patterns need fine-grained gating")
    print("   - Better for preserving long-range dependencies")
elif avg_total_length < 20 and avg_complexity < 1.5:
    print("RECOMMEND: GRU") 
    print("   - Shorter sequences work well with GRU")
    print("   - Simpler patterns don't need LSTM's complexity")
    print("   - Faster training and simpler implementation")
else:
    print("RECOMMEND: LSTM (default for NLI tasks)")
    print("   - NLI typically involves moderate complexity")
    print("   - LSTM provides better performance margin")
    print("   - More impressive for academic purposes")

print(f"\n Additional factors:")
print("- Academic context: LSTM shows deeper understanding")
print("- Attention requirement: LSTM+attention is a strong combination")
print("- You can implement GRU later as comparative model")


 TRAIN DATASET ANALYSIS
Total entries: 23088
Label distribution: {'neutral': 14618, 'entails': 8470}

Premise length - Avg: 18.1, Max: 10587, 95th %ile: 33.0
Hypothesis length - Avg: 11.7, Max: 36, 95th %ile: 20.0
Total length (P+H) - Avg: 29.8, Max: 10605, 95th %ile: 48.0

Complexity score - Avg: 0.3, Max: 9
Entries with complex logic (>2): 49/23088

Longest premises (>30 words):
  5: Pluto is about 39 times more distant from the Sun than is the Earth, and it takes about 250 Earth ye...
  19: Even if you don't live near a stream or river (though most people do), there are many little things ...
  31: Slugs, believe it or not have a very important purpose. They are decomposers, which means they eats ...

Most complex examples (score >= 3):
  186: Score 3
     Premise: If a child is having difficulty, she explains, "then everyone, parents, child an...
     Hypothesis: Children resemble their parents because they have similar dna.
  270: Score 3
     Premise: When the National Academy o

Found outlier of max length of 10587. This is alarming as this can skew the data, given that max length of premise of test and validation is 45 -55 we can restrict the training data to have this range.

### First how many does it catch if premise  x <= 80 words. Arbitrarily chose 80 as starting point

In [2]:
import json

MAX_TOTAL_LENGTH = 100    
MAX_PREMISE_LENGTH = 80    
MAX_HYPOTHESIS_LENGTH = 40  

def find_outliers(data, thresholds):
    outliers = []
    
    premises_dict = data['premise']
    hypotheses_dict = data['hypothesis']
    labels_dict = data['label']
    
    premises = [premises_dict[key] for key in sorted(premises_dict.keys(), key=int)]
    hypotheses = [hypotheses_dict[key] for key in sorted(hypotheses_dict.keys(), key=int)]
    labels = [labels_dict[key] for key in sorted(labels_dict.keys(), key=int)]
    
    for i, (premise, hypothesis, label) in enumerate(zip(premises, hypotheses, labels)):
        premise_words = len(str(premise).split())
        hypothesis_words = len(str(hypothesis).split())
        total_words = premise_words + hypothesis_words
        
        if (total_words > thresholds['max_total'] or 
            premise_words > thresholds['max_premise'] or 
            hypothesis_words > thresholds['max_hypothesis']):
            
            outliers.append({
                'index': i,
                'premise': premise,
                'hypothesis': hypothesis, 
                'label': label,
                'premise_length': premise_words,
                'hypothesis_length': hypothesis_words,
                'total_length': total_words
            })
    
    return outliers, premises, hypotheses, labels

files_to_check = ["train.json", "test.json", "validation.json"]
thresholds = {
    'max_total': MAX_TOTAL_LENGTH,
    'max_premise': MAX_PREMISE_LENGTH,
    'max_hypothesis': MAX_HYPOTHESIS_LENGTH
}

print("OUTLIER DETECTION ANALYSIS")
print("=" * 50)

for file_path in files_to_check:
    try:
        print(f"\nAnalyzing: {file_path}")
        data = json.load(open(file_path, "r", encoding="utf-8"))
        
        outliers, premises, hypotheses, labels = find_outliers(data, thresholds)
        
        print(f"Total entries: {len(premises)}")
        print(f"Outliers found: {len(outliers)}")
        print(f"Outlier percentage: {len(outliers)/len(premises)*100:.2f}%")
        
        if outliers:
            print(f"\nTop 5 outliers (showing worst offenders):")
            # Sort by total length descending
            sorted_outliers = sorted(outliers, key=lambda x: x['total_length'], reverse=True)
            for outlier in sorted_outliers[:5]:
                print(f"Index {outlier['index']}: {outlier['total_length']} words total")
                print(f"  Premise: {outlier['premise_length']} words")
                print(f"  Hypothesis: {outlier['hypothesis_length']} words")
                print(f"  Label: {outlier['label']}")
                print(f"  Premise preview: {outlier['premise'][:100]}...")
                print()
                
        else:
            print("✓ No outliers found!")
            
    except Exception as e:
        print(f"Error analyzing {file_path}: {e}")

OUTLIER DETECTION ANALYSIS

Analyzing: train.json
Total entries: 23088
Outliers found: 9
Outlier percentage: 0.04%

Top 5 outliers (showing worst offenders):
Index 1104: 10605 words total
  Premise: 10587 words
  Hypothesis: 18 words
  Label: neutral
  Premise preview: The History of Polar Front and Air Mass Concept in the United States -	The best explanation for how ...

Index 1129: 6258 words total
  Premise: 6247 words
  Hypothesis: 11 words
  Label: entails
  Premise preview: Structure of Plant Cell Walls XLIII.	A cell wall is not present in animal cells.	neutral	A cell wall...

Index 1484: 2413 words total
  Premise: 2388 words
  Hypothesis: 25 words
  Label: neutral
  Premise preview: The fat in whole milk helps one-year olds get all the fat they need to be healthy and grow well, pro...

Index 995: 1284 words total
  Premise: 1270 words
  Hypothesis: 14 words
  Label: neutral
  Premise preview: We have all these flowering apple trees, and they're all in bloom.	A sign that an appl

Analysis

9 outliers were found

### How about if restrict the threshold further -- premise x <= 55

In [3]:
import json

MAX_TOTAL_LENGTH = 100    
MAX_PREMISE_LENGTH = 55 
MAX_HYPOTHESIS_LENGTH = 40  

def find_outliers(data, thresholds):
    outliers = []
    
    premises_dict = data['premise']
    hypotheses_dict = data['hypothesis']
    labels_dict = data['label']
    
    premises = [premises_dict[key] for key in sorted(premises_dict.keys(), key=int)]
    hypotheses = [hypotheses_dict[key] for key in sorted(hypotheses_dict.keys(), key=int)]
    labels = [labels_dict[key] for key in sorted(labels_dict.keys(), key=int)]
    
    for i, (premise, hypothesis, label) in enumerate(zip(premises, hypotheses, labels)):
        premise_words = len(str(premise).split())
        hypothesis_words = len(str(hypothesis).split())
        total_words = premise_words + hypothesis_words
        
        if (total_words > thresholds['max_total'] or 
            premise_words > thresholds['max_premise'] or 
            hypothesis_words > thresholds['max_hypothesis']):
            
            outliers.append({
                'index': i,
                'premise': premise,
                'hypothesis': hypothesis, 
                'label': label,
                'premise_length': premise_words,
                'hypothesis_length': hypothesis_words,
                'total_length': total_words
            })
    
    return outliers, premises, hypotheses, labels

files_to_check = ["train.json", "test.json", "validation.json"]
thresholds = {
    'max_total': MAX_TOTAL_LENGTH,
    'max_premise': MAX_PREMISE_LENGTH,
    'max_hypothesis': MAX_HYPOTHESIS_LENGTH
}

print("OUTLIER DETECTION ANALYSIS")
print("=" * 50)

for file_path in files_to_check:
    try:
        print(f"\nAnalyzing: {file_path}")
        data = json.load(open(file_path, "r", encoding="utf-8"))
        
        outliers, premises, hypotheses, labels = find_outliers(data, thresholds)
        
        print(f"Total entries: {len(premises)}")
        print(f"Outliers found: {len(outliers)}")
        print(f"Outlier percentage: {len(outliers)/len(premises)*100:.2f}%")
        
        if outliers:
            print(f"\nTop 5 outliers (showing worst offenders):")
            # Sort by total length descending
            sorted_outliers = sorted(outliers, key=lambda x: x['total_length'], reverse=True)
            for outlier in sorted_outliers[:5]:
                print(f"Index {outlier['index']}: {outlier['total_length']} words total")
                print(f"  Premise: {outlier['premise_length']} words")
                print(f"  Hypothesis: {outlier['hypothesis_length']} words")
                print(f"  Label: {outlier['label']}")
                print(f"  Premise preview: {outlier['premise'][:100]}...")
                print()
                
        else:
            print("✓ No outliers found!")
            
    except Exception as e:
        print(f"Error analyzing {file_path}: {e}")

OUTLIER DETECTION ANALYSIS

Analyzing: train.json
Total entries: 23088
Outliers found: 10
Outlier percentage: 0.04%

Top 5 outliers (showing worst offenders):
Index 1104: 10605 words total
  Premise: 10587 words
  Hypothesis: 18 words
  Label: neutral
  Premise preview: The History of Polar Front and Air Mass Concept in the United States -	The best explanation for how ...

Index 1129: 6258 words total
  Premise: 6247 words
  Hypothesis: 11 words
  Label: entails
  Premise preview: Structure of Plant Cell Walls XLIII.	A cell wall is not present in animal cells.	neutral	A cell wall...

Index 1484: 2413 words total
  Premise: 2388 words
  Hypothesis: 25 words
  Label: neutral
  Premise preview: The fat in whole milk helps one-year olds get all the fat they need to be healthy and grow well, pro...

Index 995: 1284 words total
  Premise: 1270 words
  Hypothesis: 14 words
  Label: neutral
  Premise preview: We have all these flowering apple trees, and they're all in bloom.	A sign that an app

Now there are 10 outliers found

# Data Cleaning

### For 9 outliers

In [4]:
# RUN CODE TO EXPORT OUTLIERS TO A TXT FILE -> train_outliers.txt

# Configuration
MAX_TOTAL_LENGTH = 100
MAX_PREMISE_LENGTH = 80  
MAX_HYPOTHESIS_LENGTH = 40

def find_outliers(data, thresholds):
    """Find entries that exceed length thresholds."""
    outliers = []
    
    premises_dict = data['premise']
    hypotheses_dict = data['hypothesis']
    labels_dict = data['label']
    
    # Convert to lists and sort by numeric keys
    premises = [premises_dict[key] for key in sorted(premises_dict.keys(), key=int)]
    hypotheses = [hypotheses_dict[key] for key in sorted(hypotheses_dict.keys(), key=int)]
    labels = [labels_dict[key] for key in sorted(labels_dict.keys(), key=int)]
    
    for i, (premise, hypothesis, label) in enumerate(zip(premises, hypotheses, labels)):
        premise_words = len(str(premise).split())
        hypothesis_words = len(str(hypothesis).split())
        total_words = premise_words + hypothesis_words
        
        if (total_words > thresholds['max_total'] or 
            premise_words > thresholds['max_premise'] or 
            hypothesis_words > thresholds['max_hypothesis']):
            
            outliers.append({
                'index': i,
                'premise': premise,
                'hypothesis': hypothesis, 
                'label': label,
                'premise_length': premise_words,
                'hypothesis_length': hypothesis_words,
                'total_length': total_words
            })
    
    return outliers, premises, hypotheses, labels

def export_outliers(outliers, output_path):
    """Export outliers to a TXT file."""
    with open(output_path, "w", encoding="utf-8") as f:
        f.write("OUTLIERS DETECTED\n")
        f.write("=" * 60 + "\n\n")
        f.write(f"Total outliers found: {len(outliers)}\n\n")
        
        for outlier in outliers:
            f.write(f"INDEX: {outlier['index']}\n")
            f.write(f"TOTAL WORDS: {outlier['total_length']}\n")
            f.write(f"PREMISE WORDS: {outlier['premise_length']}\n")
            f.write(f"HYPOTHESIS WORDS: {outlier['hypothesis_length']}\n")
            f.write(f"LABEL: {outlier['label']}\n")
            f.write(f"PREMISE: {outlier['premise']}\n")
            f.write(f"HYPOTHESIS: {outlier['hypothesis']}\n")
            f.write("-" * 60 + "\n\n")

def create_clean_dataset(premises, hypotheses, labels, outlier_indices, output_path):
    """Create a clean dataset without outliers."""
    clean_premises = [p for i, p in enumerate(premises) if i not in outlier_indices]
    clean_hypotheses = [h for i, h in enumerate(hypotheses) if i not in outlier_indices]
    clean_labels = [l for i, l in enumerate(labels) if i not in outlier_indices]
    
    clean_data = {
        'premise': dict(enumerate(clean_premises)),
        'hypothesis': dict(enumerate(clean_hypotheses)),
        'label': dict(enumerate(clean_labels))
    }
    
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(clean_data, f, indent=2)
    
    return clean_premises, clean_hypotheses, clean_labels

# Main processing pipeline
print(" COMPLETE OUTLIER HANDLING PIPELINE")
print("=" * 60)

files_to_process = ["train.json", "validation.json", "test.json"]
thresholds = {
    'max_total': MAX_TOTAL_LENGTH,
    'max_premise': MAX_PREMISE_LENGTH,
    'max_hypothesis': MAX_HYPOTHESIS_LENGTH
}

all_clean_data = {}

for file_path in files_to_process:
    print(f"\nProcessing: {file_path}")
    
    try:
        # Load and analyze
        data = json.load(open(file_path, "r", encoding="utf-8"))
        outliers, premises, hypotheses, labels = find_outliers(data, thresholds)
        
        print(f"Total entries: {len(premises)}")
        print(f"Outliers found: {len(outliers)}")
        
        if outliers:
            # Export outliers
            outlier_file = file_path.replace(".json", "_outliers.txt")
            export_outliers(outliers, outlier_file)
            print(f"✓ Outliers exported to: {outlier_file}")
            
            # Create clean version
            outlier_indices = [outlier['index'] for outlier in outliers]
            clean_file = file_path.replace(".json", "_clean.json")
            clean_premises, clean_hypotheses, clean_labels = create_clean_dataset(
                premises, hypotheses, labels, outlier_indices, clean_file
            )
            
            print(f"✓ Clean dataset created: {clean_file}")
            print(f"  Removed {len(outliers)} outliers, {len(clean_premises)} entries remaining")
            
            # Store clean data for later use
            all_clean_data[file_path] = {
                'premises': clean_premises,
                'hypotheses': clean_hypotheses,
                'labels': clean_labels
            }
            
        else:
            print("✓ No outliers found - using original data")
            # Store original data
            all_clean_data[file_path] = {
                'premises': premises,
                'hypotheses': hypotheses,
                'labels': labels
            }
            
    except Exception as e:
        print(f"✗ Error processing {file_path}: {e}")
        continue

# Continue with processing clean data
print(f"\n OUTLIER PROCESSING COMPLETE")
print("=" * 60)
print("Now continuing with clean data...")

# Show summary
for file_path, data in all_clean_data.items():
    print(f"{file_path}: {len(data['premises'])} clean entries")

# ==================== CONTINUE WITH YOUR NEXT STEPS HERE ====================
# Now you can proceed with text preprocessing, model training, etc.
# Example: access clean train data with all_clean_data['train.json']['premises']

print(f"\n READY FOR NEXT STEPS:")
print("- Text preprocessing (tokenization, cleaning)")
print("- Label encoding (entails/neutral to numerical)")
print("- LSTM model implementation")
print("- Training with clean data")

# Quick verification
if 'train.json' in all_clean_data:
    train_data = all_clean_data['train.json']
    print(f"\n Sample from clean training data:")
    print(f"First premise: {train_data['premises'][0][:100]}...")
    print(f"First hypothesis: {train_data['hypotheses'][0]}")
    print(f"First label: {train_data['labels'][0]}")

 COMPLETE OUTLIER HANDLING PIPELINE

Processing: train.json
Total entries: 23088
Outliers found: 9
✓ Outliers exported to: train_outliers.txt
✓ Clean dataset created: train_clean.json
  Removed 9 outliers, 23079 entries remaining

Processing: validation.json
Total entries: 1304
Outliers found: 0
✓ No outliers found - using original data

Processing: test.json
Total entries: 2126
Outliers found: 0
✓ No outliers found - using original data

 OUTLIER PROCESSING COMPLETE
Now continuing with clean data...
train.json: 23079 clean entries
validation.json: 1304 clean entries
test.json: 2126 clean entries

 READY FOR NEXT STEPS:
- Text preprocessing (tokenization, cleaning)
- Label encoding (entails/neutral to numerical)
- LSTM model implementation
- Training with clean data

 Sample from clean training data:
First premise: Pluto rotates once on its axis every 6.39 Earth days;...
First hypothesis: Earth rotates on its axis once times in one day.
First label: neutral


### For 10 outliers

In [5]:
MAX_TOTAL_LENGTH = 100
MAX_PREMISE_LENGTH = 55
MAX_HYPOTHESIS_LENGTH = 40

def find_outliers(data, thresholds):
    outliers = []
    
    premises_dict = data['premise']
    hypotheses_dict = data['hypothesis']
    labels_dict = data['label']
    
    premises = [premises_dict[key] for key in sorted(premises_dict.keys(), key=int)]
    hypotheses = [hypotheses_dict[key] for key in sorted(hypotheses_dict.keys(), key=int)]
    labels = [labels_dict[key] for key in sorted(labels_dict.keys(), key=int)]
    
    for i, (premise, hypothesis, label) in enumerate(zip(premises, hypotheses, labels)):
        premise_words = len(str(premise).split())
        hypothesis_words = len(str(hypothesis).split())
        total_words = premise_words + hypothesis_words
        
        if (total_words > thresholds['max_total'] or 
            premise_words > thresholds['max_premise'] or 
            hypothesis_words > thresholds['max_hypothesis']):
            
            outliers.append({
                'index': i,
                'premise': premise,
                'hypothesis': hypothesis, 
                'label': label,
                'premise_length': premise_words,
                'hypothesis_length': hypothesis_words,
                'total_length': total_words
            })
    
    return outliers, premises, hypotheses, labels

def export_outliers(outliers, output_path):
    with open(output_path, "w", encoding="utf-8") as f:
        f.write("OUTLIERS DETECTED\n")
        f.write("=" * 60 + "\n\n")
        f.write(f"Total outliers found: {len(outliers)}\n\n")
        
        for outlier in outliers:
            f.write(f"INDEX: {outlier['index']}\n")
            f.write(f"TOTAL WORDS: {outlier['total_length']}\n")
            f.write(f"PREMISE WORDS: {outlier['premise_length']}\n")
            f.write(f"HYPOTHESIS WORDS: {outlier['hypothesis_length']}\n")
            f.write(f"LABEL: {outlier['label']}\n")
            f.write(f"PREMISE: {outlier['premise']}\n")
            f.write(f"HYPOTHESIS: {outlier['hypothesis']}\n")
            f.write("-" * 60 + "\n\n")

def create_clean_dataset(premises, hypotheses, labels, outlier_indices, output_path):
    clean_premises = [p for i, p in enumerate(premises) if i not in outlier_indices]
    clean_hypotheses = [h for i, h in enumerate(hypotheses) if i not in outlier_indices]
    clean_labels = [l for i, l in enumerate(labels) if i not in outlier_indices]
    
    clean_data = {
        'premise': dict(enumerate(clean_premises)),
        'hypothesis': dict(enumerate(clean_hypotheses)),
        'label': dict(enumerate(clean_labels))
    }
    
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(clean_data, f, indent=2)
    
    return clean_premises, clean_hypotheses, clean_labels

print(" COMPLETE OUTLIER HANDLING PIPELINE")
print("=" * 60)

files_to_process = ["train.json", "validation.json", "test.json"]
thresholds = {
    'max_total': MAX_TOTAL_LENGTH,
    'max_premise': MAX_PREMISE_LENGTH,
    'max_hypothesis': MAX_HYPOTHESIS_LENGTH
}

all_clean_data = {}

for file_path in files_to_process:
    print(f"\nProcessing: {file_path}")
    
    try:
        data = json.load(open(file_path, "r", encoding="utf-8"))
        outliers, premises, hypotheses, labels = find_outliers(data, thresholds)
        
        print(f"Total entries: {len(premises)}")
        print(f"Outliers found: {len(outliers)}")
        
        if outliers:
            outlier_file = file_path.replace(".json", "_outliers.txt")
            export_outliers(outliers, outlier_file)
            print(f"✓ Outliers exported to: {outlier_file}")
            
            outlier_indices = [outlier['index'] for outlier in outliers]
            clean_file = file_path.replace(".json", "_clean.json")
            clean_premises, clean_hypotheses, clean_labels = create_clean_dataset(
                premises, hypotheses, labels, outlier_indices, clean_file
            )
            
            print(f"✓ Clean dataset created: {clean_file}")
            print(f"  Removed {len(outliers)} outliers, {len(clean_premises)} entries remaining")
            
            all_clean_data[file_path] = {
                'premises': clean_premises,
                'hypotheses': clean_hypotheses,
                'labels': clean_labels
            }
            
        else:
            print("✓ No outliers found - using original data")
            all_clean_data[file_path] = {
                'premises': premises,
                'hypotheses': hypotheses,
                'labels': labels
            }
            
    except Exception as e:
        print(f"✗ Error processing {file_path}: {e}")
        continue

print(f"\n OUTLIER PROCESSING COMPLETE")
print("=" * 60)
print("Now continuing with clean data...")

for file_path, data in all_clean_data.items():
    print(f"{file_path}: {len(data['premises'])} clean entries")

print(f"\n READY FOR NEXT STEPS:")
print("- Text preprocessing (tokenization, cleaning)")
print("- Label encoding (entails/neutral to numerical)")
print("- LSTM model implementation")
print("- Training with clean data")

if 'train.json' in all_clean_data:
    train_data = all_clean_data['train.json']
    print(f"\n Sample from clean training data:")
    print(f"First premise: {train_data['premises'][0][:100]}...")
    print(f"First hypothesis: {train_data['hypotheses'][0]}")
    print(f"First label: {train_data['labels'][0]}")

 COMPLETE OUTLIER HANDLING PIPELINE

Processing: train.json
Total entries: 23088
Outliers found: 10
✓ Outliers exported to: train_outliers.txt
✓ Clean dataset created: train_clean.json
  Removed 10 outliers, 23078 entries remaining

Processing: validation.json
Total entries: 1304
Outliers found: 0
✓ No outliers found - using original data

Processing: test.json
Total entries: 2126
Outliers found: 0
✓ No outliers found - using original data

 OUTLIER PROCESSING COMPLETE
Now continuing with clean data...
train.json: 23078 clean entries
validation.json: 1304 clean entries
test.json: 2126 clean entries

 READY FOR NEXT STEPS:
- Text preprocessing (tokenization, cleaning)
- Label encoding (entails/neutral to numerical)
- LSTM model implementation
- Training with clean data

 Sample from clean training data:
First premise: Pluto rotates once on its axis every 6.39 Earth days;...
First hypothesis: Earth rotates on its axis once times in one day.
First label: neutral


#### AFTER inspecting outlier txt file, manually input index to remove outliers from training data

In [6]:
#INSPECTING the outlier txt file, FIND A WAY TO REMOVE IT
import json

# Known outlier indices from train.json
# TRAIN_OUTLIER_INDICES = [270, 537, 606, 608, 760 , 995, 1104, 1129,1484 ] # Add all 9 outlier indices here
TRAIN_OUTLIER_INDICES = [270, 537, 606, 608, 760 , 995, 1104, 1129,1484, 19888] # Add all 10 outlier indices here

print("LOADING AND CLEANING TRAIN DATA")
print("=" * 50)

try:
    # Load the original train.json
    with open("train.json", "r", encoding="utf-8") as f:
        train_data = json.load(f)
    
    # Extract and sort the data
    premises_dict = train_data['premise']
    hypotheses_dict = train_data['hypothesis']
    labels_dict = train_data['label']
    
    train_premises = [premises_dict[key] for key in sorted(premises_dict.keys(), key=int)]
    train_hypotheses = [hypotheses_dict[key] for key in sorted(hypotheses_dict.keys(), key=int)]
    train_labels = [labels_dict[key] for key in sorted(labels_dict.keys(), key=int)]
    
    print(f"Original train entries: {len(train_premises)}")
    
    # Remove outliers
    clean_train_premises = [p for i, p in enumerate(train_premises) if i not in TRAIN_OUTLIER_INDICES]
    clean_train_hypotheses = [h for i, h in enumerate(train_hypotheses) if i not in TRAIN_OUTLIER_INDICES]
    clean_train_labels = [l for i, l in enumerate(train_labels) if i not in TRAIN_OUTLIER_INDICES]
    
    print(f"Outliers removed: {len(TRAIN_OUTLIER_INDICES)}")
    print(f"Clean train entries: {len(clean_train_premises)}")
    
    # Save cleaned data to JSON
    cleaned_train_data = {
        'premise': dict(enumerate(clean_train_premises)),
        'hypothesis': dict(enumerate(clean_train_hypotheses)),
        'label': dict(enumerate(clean_train_labels))
    }
    
    with open("train_cleaned.json", "w", encoding="utf-8") as f:
        json.dump(cleaned_train_data, f, indent=2)
        print("✓ train_cleaned.json saved")
    
    # Save cleaned data to TXT
    with open("train_cleaned.txt", "w", encoding="utf-8") as f:
        f.write("TRAIN DATASET - CLEANED (OUTLIERS REMOVED)\n")
        f.write("=" * 80 + "\n\n")
        f.write(f"Outliers removed: {len(TRAIN_OUTLIER_INDICES)}\n")
        f.write(f"Total entries: {len(clean_train_premises)}\n\n")
        
        for i, (premise, hypothesis, label) in enumerate(zip(clean_train_premises, clean_train_hypotheses, clean_train_labels)):
            f.write(f"ENTRY {i}:\n")
            f.write(f"PREMISE: {premise}\n")
            f.write(f"HYPOTHESIS: {hypothesis}\n")
            f.write(f"LABEL: {label}\n")
            f.write("-" * 80 + "\n\n")
        
        print("✓ train_cleaned.txt saved")
    
    print("\n OUTLIER REMOVAL COMPLETE!")
    print("=" * 50)
    print("Files created:")
    print("  - train_cleaned.json (for model training)")
    print("  - train_cleaned.txt (for human inspection)")
    
except Exception as e:
    print(f"Error: {e}")
    print("Make sure train.json is in the current directory")

LOADING AND CLEANING TRAIN DATA
Original train entries: 23088
Outliers removed: 10
Clean train entries: 23078
✓ train_cleaned.json saved
✓ train_cleaned.txt saved

 OUTLIER REMOVAL COMPLETE!
Files created:
  - train_cleaned.json (for model training)
  - train_cleaned.txt (for human inspection)


#### **Verify if the outliers were successfully removed in the cleaned training JSON file**

In [7]:
import json

MAX_TOTAL_LENGTH = 100
MAX_PREMISE_LENGTH = 55 
MAX_HYPOTHESIS_LENGTH = 40

def check_outliers(premises, hypotheses, labels, dataset_name):
    outliers = []
    
    for i, (premise, hypothesis, label) in enumerate(zip(premises, hypotheses, labels)):
        premise_words = len(str(premise).split())
        hypothesis_words = len(str(hypothesis).split())
        total_words = premise_words + hypothesis_words
        
        if (total_words > MAX_TOTAL_LENGTH or 
            premise_words > MAX_PREMISE_LENGTH or 
            hypothesis_words > MAX_HYPOTHESIS_LENGTH):
            
            outliers.append({
                'index': i,
                'premise_length': premise_words,
                'hypothesis_length': hypothesis_words,
                'total_length': total_words,
                'label': label,
                'premise': premise,
                'hypothesis': hypothesis
            })
    
    return outliers

print("🔍 VERIFYING OUTLIER REMOVAL SUCCESS")
print("=" * 60)

print("\n1. CHECKING ORIGINAL train.json:")
try:
    with open("train.json", "r", encoding="utf-8") as f:
        original_data = json.load(f)
    
    original_premises = [original_data['premise'][key] for key in sorted(original_data['premise'].keys(), key=int)]
    original_hypotheses = [original_data['hypothesis'][key] for key in sorted(original_data['hypothesis'].keys(), key=int)]
    original_labels = [original_data['label'][key] for key in sorted(original_data['label'].keys(), key=int)]
    
    original_outliers = check_outliers(original_premises, original_hypotheses, original_labels, "original")
    print(f"   Original entries: {len(original_premises)}")
    print(f"   Original outliers: {len(original_outliers)}")
    
    if original_outliers:
        print("   Original outliers found:")
        for outlier in original_outliers:
            print(f"     Index {outlier['index']}: {outlier['total_length']} words (P:{outlier['premise_length']}, H:{outlier['hypothesis_length']})")
            
except Exception as e:
    print(f"   Error loading original: {e}")

print("\n2. CHECKING CLEANED train_cleaned.json:")
try:
    with open("train_cleaned.json", "r", encoding="utf-8") as f:
        cleaned_data = json.load(f)
    
    cleaned_premises = list(cleaned_data['premise'].values())
    cleaned_hypotheses = list(cleaned_data['hypothesis'].values())
    cleaned_labels = list(cleaned_data['label'].values())
    
    cleaned_outliers = check_outliers(cleaned_premises, cleaned_hypotheses, cleaned_labels, "cleaned")
    print(f"   Cleaned entries: {len(cleaned_premises)}")
    print(f"   Cleaned outliers: {len(cleaned_outliers)}")
    
    if cleaned_outliers:
        print("   ❌ STILL HAVE OUTLIERS:")
        for outlier in cleaned_outliers:
            print(f"     Index {outlier['index']}: {outlier['total_length']} words (P:{outlier['premise_length']}, H:{outlier['hypothesis_length']})")
            print(f"       Premise: {outlier['premise'][:100]}...")
            print(f"       Hypothesis: {outlier['hypothesis']}")
            print(f"       Label: {outlier['label']}")
            print("       " + "-" * 50)
    else:
        print("   ✅ SUCCESS: No outliers found in cleaned data!")
        
except Exception as e:
    print(f"   Error loading cleaned: {e}")

print("\n3. SUMMARY COMPARISON:")
print("=" * 40)
try:
    original_count = len(original_premises)
    cleaned_count = len(cleaned_premises)
    outliers_removed = original_count - cleaned_count
    
    print(f"   Original entries: {original_count}")
    print(f"   Cleaned entries: {cleaned_count}")
    print(f"   Outliers removed: {outliers_removed}")

    longest_cleaned = max(len(str(p).split()) + len(str(h).split()) 
                         for p, h in zip(cleaned_premises, cleaned_hypotheses))
    print(f"   Longest sequence in cleaned data: {longest_cleaned} words")
    
    if longest_cleaned <= MAX_TOTAL_LENGTH:
        print("   ✅ SUCCESS: All sequences are within reasonable length!")
    else:
        print(f"   ⚠️  WARNING: Still have long sequences (> {MAX_TOTAL_LENGTH} words)")
        
except NameError:
    print("   Could not compare - data not loaded properly")

print("\n" + "=" * 60)
print("VERIFICATION COMPLETE")
print("=" * 60)

🔍 VERIFYING OUTLIER REMOVAL SUCCESS

1. CHECKING ORIGINAL train.json:
   Original entries: 23088
   Original outliers: 10
   Original outliers found:
     Index 270: 1026 words (P:1016, H:10)
     Index 537: 791 words (P:775, H:16)
     Index 606: 663 words (P:646, H:17)
     Index 608: 358 words (P:344, H:14)
     Index 760: 615 words (P:595, H:20)
     Index 995: 1284 words (P:1270, H:14)
     Index 1104: 10605 words (P:10587, H:18)
     Index 1129: 6258 words (P:6247, H:11)
     Index 1484: 2413 words (P:2388, H:25)
     Index 19888: 68 words (P:56, H:12)

2. CHECKING CLEANED train_cleaned.json:
   Cleaned entries: 23078
   Cleaned outliers: 0
   ✅ SUCCESS: No outliers found in cleaned data!

3. SUMMARY COMPARISON:
   Original entries: 23088
   Cleaned entries: 23078
   Outliers removed: 10
   Longest sequence in cleaned data: 81 words
   ✅ SUCCESS: All sequences are within reasonable length!

VERIFICATION COMPLETE


# FURTHER DATA CLEANING

#### RegEx Text Wrangling

In [8]:
import json
import re

def clean_text(text):
    """
    Minimal text cleaning: remove URLs, HTML tags, and extra whitespace only
    Keep most symbols and punctuation, only remove excessive repetitive symbols
    """
    original_text = str(text)
    
    # Remove URLs
    text = re.sub(r'https?://\S+|www\.\S+', '', original_text, flags=re.MULTILINE)
    
    # Remove HTML tags
    text = re.sub(r'<.*?>', '', text)
    
    # Remove excessive repetitive symbols (3+ of the same symbol in a row)
    # But keep double dashes (--) as requested
    text = re.sub(r'\.{3,}', ' ', text)    # Remove multiple dots (....)
    text = re.sub(r'\*{3,}', ' ', text)    # Remove multiple asterisks (***)
    text = re.sub(r'={3,}', ' ', text)     # Remove multiple equals (===)
    text = re.sub(r'\!{3,}', ' ', text)    # Remove multiple exclamations (!!!)
    text = re.sub(r'\?{2,}', ' ', text)    # Remove multiple question marks (??)
    
    # Remove extra whitespace (multiple spaces/tabs/newlines to single space)
    text = re.sub(r'\s+', ' ', text).strip()
    
    return original_text, text

def process_dataset(input_path, output_path):
    """
    Process a complete dataset file with minimal cleaning
    """
    with open(input_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    
    cleaning_report = []
    
    # Clean premises
    cleaned_premises = {}
    for key, premise in data['premise'].items():
        original, cleaned = clean_text(premise)
        cleaned_premises[key] = cleaned
        if original != cleaned:
            cleaning_report.append({
                'type': 'premise',
                'key': key,
                'original': original,
                'cleaned': cleaned
            })
    
    # Clean hypotheses
    cleaned_hypotheses = {}
    for key, hypothesis in data['hypothesis'].items():
        original, cleaned = clean_text(hypothesis)
        cleaned_hypotheses[key] = cleaned
        if original != cleaned:
            cleaning_report.append({
                'type': 'hypothesis', 
                'key': key,
                'original': original,
                'cleaned': cleaned
            })
    
    # Create cleaned dataset
    cleaned_data = {
        'premise': cleaned_premises,
        'hypothesis': cleaned_hypotheses,
        'label': data['label']
    }
    
    # Save cleaned dataset
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(cleaned_data, f, indent=2, ensure_ascii=False)
    
    return cleaned_data, cleaning_report

# Files to process
datasets = [
    ("train_cleaned.json", "train_final_cleaned.json", "TRAINING"),
    ("test.json", "test_final_cleaned.json", "TEST"), 
    ("validation.json", "validation_final_cleaned.json", "VALIDATION")
]

print("MINIMAL DATA CLEANING PIPELINE")
print("=" * 60)
print("Cleaning actions: Remove URLs, HTML tags, and excessive repetitive symbols")
print("Retaining: Case, punctuation, numbers, single/double symbols, -- dashes")
print("=" * 60)

cleaning_reports = {}

# Process all datasets
for input_file, output_file, dataset_name in datasets:
    try:
        print(f"\nProcessing {dataset_name}: {input_file} -> {output_file}")
        
        cleaned_data, report = process_dataset(input_file, output_file)
        cleaning_reports[dataset_name] = report
        
        print(f"Successfully cleaned and saved {output_file}")
        
        # Show statistics
        total_entries = len(cleaned_data['premise'])
        changes = len(report)
        
        print(f"{dataset_name} Statistics:")
        print(f"   Total entries: {total_entries}")
        print(f"   Cleaning changes: {changes}")
        
    except FileNotFoundError:
        print(f"File not found: {input_file}")
    except Exception as e:
        print(f"Error processing {input_file}: {e}")

print("\n" + "=" * 60)
print("CLEANING COMPLETE - READY FOR TOKENIZATION")
print("=" * 60)

MINIMAL DATA CLEANING PIPELINE
Cleaning actions: Remove URLs, HTML tags, and excessive repetitive symbols
Retaining: Case, punctuation, numbers, single/double symbols, -- dashes

Processing TRAINING: train_cleaned.json -> train_final_cleaned.json
Successfully cleaned and saved train_final_cleaned.json
TRAINING Statistics:
   Total entries: 23078
   Cleaning changes: 927

Processing TEST: test.json -> test_final_cleaned.json
Successfully cleaned and saved test_final_cleaned.json
TEST Statistics:
   Total entries: 2126
   Cleaning changes: 84

Processing VALIDATION: validation.json -> validation_final_cleaned.json
Successfully cleaned and saved validation_final_cleaned.json
VALIDATION Statistics:
   Total entries: 1304
   Cleaning changes: 80

CLEANING COMPLETE - READY FOR TOKENIZATION


### Tokenization

In [9]:
import json
import numpy as np
from collections import Counter
import nltk
from nltk.tokenize import word_tokenize

nltk.download('punkt')

def build_vocab(texts, max_vocab_size=20000):
    """Build vocabulary from all texts"""
    counter = Counter()
    for text in texts:
        tokens = word_tokenize(text)
        counter.update(tokens)
    
    # Create vocabulary with special tokens
    vocab = {'<PAD>': 0, '<UNK>': 1}
    
    # Add most common words
    for word, _ in counter.most_common(max_vocab_size - 2):  # Reserve 2 for special tokens
        vocab[word] = len(vocab)
    
    return vocab

def text_to_sequence(text, vocab, max_length=None):
    """Convert text to sequence of integers"""
    tokens = word_tokenize(text)
    sequence = [vocab.get(token, vocab['<UNK>']) for token in tokens]
    
    if max_length:
        if len(sequence) > max_length:
            sequence = sequence[:max_length]  # Truncate
        else:
            sequence = sequence + [vocab['<PAD>']] * (max_length - len(sequence))  # Pad
    
    return sequence

def load_and_tokenize_data(file_path, vocab, max_length=50):
    """Load data and tokenize premises and hypotheses"""
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    
    premises = list(data['premise'].values())
    hypotheses = list(data['hypothesis'].values())
    labels = list(data['label'].values())
    
    # Tokenize and convert to sequences
    premises_seq = [text_to_sequence(premise, vocab, max_length) for premise in premises]
    hypotheses_seq = [text_to_sequence(hypothesis, vocab, max_length) for hypothesis in hypotheses]
    
    return premises_seq, hypotheses_seq, labels

print("🔤 BUILDING VOCABULARY AND TOKENIZING DATA...")
print("=" * 60)

# Load all text data to build comprehensive vocabulary
all_texts = []
for file_path in ["train_final_cleaned.json", "test_final_cleaned.json", "validation_final_cleaned.json"]:
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        all_texts.extend(list(data['premise'].values()))
        all_texts.extend(list(data['hypothesis'].values()))
    except:
        continue

# Build vocabulary
vocab = build_vocab(all_texts, max_vocab_size=20000)
print(f"Vocabulary size: {len(vocab)}")

# Tokenize all datasets
max_seq_length = 50  # Adjust based on your data analysis

train_premises_seq, train_hypotheses_seq, train_labels = load_and_tokenize_data(
    "train_final_cleaned.json", vocab, max_seq_length
)
val_premises_seq, val_hypotheses_seq, val_labels = load_and_tokenize_data(
    "validation_final_cleaned.json", vocab, max_seq_length
)
test_premises_seq, test_hypotheses_seq, test_labels = load_and_tokenize_data(
    "test_final_cleaned.json", vocab, max_seq_length
)

print(f"Train sequences: {len(train_premises_seq)}")
print(f"Validation sequences: {len(val_premises_seq)}")
print(f"Test sequences: {len(test_premises_seq)}")
print(f"Sequence length: {max_seq_length}")

# Convert labels to numerical format if needed (assuming they're already numerical)
# If labels are strings like "entails"/"neutral", you'd need to encode them here

print("✅ TOKENIZATION COMPLETE - READY FOR MODEL TRAINING")
print("=" * 60)

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\msi\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


🔤 BUILDING VOCABULARY AND TOKENIZING DATA...
Vocabulary size: 20000
Train sequences: 23078
Validation sequences: 1304
Test sequences: 2126
Sequence length: 50
✅ TOKENIZATION COMPLETE - READY FOR MODEL TRAINING


In [10]:
def preview_tokenization(premises, hypotheses, vocab, num_samples=3):
    """Preview tokenization results for a few samples"""
    print("\n" + "🔍 TOKENIZATION PREVIEW")
    print("=" * 80)
    
    # Create reverse vocabulary for decoding
    reverse_vocab = {v: k for k, v in vocab.items()}
    
    for i in range(min(num_samples, len(premises))):
        print(f"\nSample {i+1}:")
        print(f"Original Premise: {premises[i]}")
        print(f"Tokenized Premise: {[reverse_vocab.get(idx, '<UNK>') for idx in train_premises_seq[i] if idx != vocab['<PAD>']]}")
        print(f"Token IDs: {[idx for idx in train_premises_seq[i] if idx != vocab['<PAD>']]}")
        
        print(f"Original Hypothesis: {hypotheses[i]}")
        print(f"Tokenized Hypothesis: {[reverse_vocab.get(idx, '<UNK>') for idx in train_hypotheses_seq[i] if idx != vocab['<PAD>']]}")
        print(f"Token IDs: {[idx for idx in train_hypotheses_seq[i] if idx != vocab['<PAD>']]}")
        print("-" * 60)

# Add this right after your tokenization code
preview_tokenization(train_premises[:5], train_hypotheses[:5], vocab, num_samples=3)


🔍 TOKENIZATION PREVIEW

Sample 1:
Original Premise: Pluto rotates once on its axis every 6.39 Earth days;
Tokenized Premise: ['Pluto', 'rotates', 'once', 'on', 'its', 'axis', 'every', '6.39', 'Earth', 'days', ';']
Token IDs: [1202, 1227, 453, 25, 56, 387, 406, 14694, 111, 893, 44]
Original Hypothesis: Earth rotates on its axis once times in one day.
Tokenized Hypothesis: ['Earth', 'rotates', 'on', 'its', 'axis', 'once', 'times', 'in', 'one', 'day', '.']
Token IDs: [111, 1227, 25, 56, 387, 453, 276, 9, 35, 302, 2]
------------------------------------------------------------

Sample 2:
Tokenized Premise: ['--', '-Glenn', 'Once', 'per', 'day', ',', 'the', 'earth', 'rotates', 'about', 'its', 'axis', '.']
Token IDs: [230, 14695, 2895, 313, 302, 5, 3, 62, 1227, 165, 56, 387, 2]
Original Hypothesis: Earth rotates on its axis once times in one day.
Tokenized Hypothesis: ['Earth', 'rotates', 'on', 'its', 'axis', 'once', 'times', 'in', 'one', 'day', '.']
Token IDs: [111, 1227, 25, 56, 387, 453,

### LSTM (WITH ATTENTION)

In [15]:
# ==================== FIXED DATASET CLASS WITH PADDING ====================
from torch.nn.utils.rnn import pad_sequence

class NLIDataset(Dataset):
    def __init__(self, premises_seq, hypotheses_seq, labels):
        self.premises_seq = premises_seq
        self.hypotheses_seq = hypotheses_seq
        self.labels = labels
        
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        return (
            torch.tensor(self.premises_seq[idx], dtype=torch.long),
            torch.tensor(self.hypotheses_seq[idx], dtype=torch.long),
            torch.tensor(self.labels[idx], dtype=torch.long)
        )

def collate_fn(batch):
    """Custom collate function to pad sequences in the same batch."""
    premises, hypotheses, labels = zip(*batch)
    
    # Pad sequences to the maximum length in the batch
    premises_padded = pad_sequence(premises, batch_first=True, padding_value=0)
    hypotheses_padded = pad_sequence(hypotheses, batch_first=True, padding_value=0)
    labels = torch.stack(labels)
    
    return premises_padded, hypotheses_padded, labels

# Create datasets with the custom collate function
train_dataset = NLIDataset(train_premises_seq, train_hypotheses_seq, train_labels_enc)
val_dataset = NLIDataset(val_premises_seq, val_hypotheses_seq, val_labels_enc)
test_dataset = NLIDataset(test_premises_seq, test_hypotheses_seq, test_labels_enc)

batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

#MODEL 1: LSTM with Attention
class LSTMAttentionModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim=100, hidden_dim=128, num_classes=2, num_layers=2, dropout=0.3):
        super(LSTMAttentionModel, self).__init__()
        
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.lstm_premise = nn.LSTM(embedding_dim, hidden_dim, num_layers, 
                                   batch_first=True, bidirectional=True, dropout=dropout)
        self.lstm_hypothesis = nn.LSTM(embedding_dim, hidden_dim, num_layers,
                                      batch_first=True, bidirectional=True, dropout=dropout)
        
        # Attention mechanism
        self.attention = nn.MultiheadAttention(hidden_dim * 2, num_heads=4, dropout=dropout, batch_first=True)
        
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim * 6, hidden_dim),  # Increased because of bidirectional
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes)
        )
        
    def forward(self, premise, hypothesis):
        # Embedding
        prem_embed = self.embedding(premise)
        hyp_embed = self.embedding(hypothesis)
        
        # LSTM encoding
        prem_lstm, _ = self.lstm_premise(prem_embed)
        hyp_lstm, _ = self.lstm_hypothesis(hyp_embed)
        
        # Attention between premise and hypothesis
        attn_output, attn_weights = self.attention(prem_lstm, hyp_lstm, hyp_lstm)
        
        # Pooling
        prem_pooled = torch.mean(prem_lstm, dim=1)
        hyp_pooled = torch.mean(hyp_lstm, dim=1)
        attn_pooled = torch.mean(attn_output, dim=1)
        
        # Concatenate features
        combined = torch.cat([prem_pooled, hyp_pooled, attn_pooled], dim=1)
        
        return self.classifier(combined), attn_weights


### GRU

In [None]:

class GRUModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim=100, hidden_dim=128, num_classes=2, num_layers=2, dropout=0.3):
        super(GRUModel, self).__init__()
        
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.gru_premise = nn.GRU(embedding_dim, hidden_dim, num_layers, 
                                 batch_first=True, bidirectional=True, dropout=dropout)
        self.gru_hypothesis = nn.GRU(embedding_dim, hidden_dim, num_layers,
                                    batch_first=True, bidirectional=True, dropout=dropout)
        
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim * 4, hidden_dim),  # bidirectional * 2 for both premise and hypothesis
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes)
        )
        
    def forward(self, premise, hypothesis):
        # Embedding
        prem_embed = self.embedding(premise)
        hyp_embed = self.embedding(hypothesis)
        
        # GRU encoding
        prem_gru, _ = self.gru_premise(prem_embed)
        hyp_gru, _ = self.gru_hypothesis(hyp_embed)
        
        # Pooling
        prem_pooled = torch.mean(prem_gru, dim=1)
        hyp_pooled = torch.mean(hyp_gru, dim=1)
        
        # Concatenate and classify
        combined = torch.cat([prem_pooled, hyp_pooled], dim=1)
        return self.classifier(combined), None  # Return None for attention weights for compatibility

# ==================== FIXED TRAINING FUNCTION ====================
def train_model(model, train_loader, val_loader, num_epochs=10, learning_rate=0.001, model_name="model"):
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
    train_losses, val_losses, val_accuracies = [], [], []
    
    print(f"\n TRAINING {model_name.upper()}")
    print("=" * 50)
    
    for epoch in range(num_epochs):
        # Training
        model.train()
        total_loss = 0
        
        for premise, hypothesis, labels in train_loader:
            premise, hypothesis, labels = premise.to(device), hypothesis.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs, _ = model(premise, hypothesis)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        # Validation
        model.eval()
        val_loss = 0
        all_preds, all_labels = [], []
        
        with torch.no_grad():
            for premise, hypothesis, labels in val_loader:
                premise, hypothesis, labels = premise.to(device), hypothesis.to(device), labels.to(device)
                
                outputs, _ = model(premise, hypothesis)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                
                preds = torch.argmax(outputs, dim=1)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
        
        train_losses.append(total_loss / len(train_loader))
        val_losses.append(val_loss / len(val_loader))
        accuracy = accuracy_score(all_labels, all_preds)
        val_accuracies.append(accuracy)
        
        print(f"Epoch {epoch+1}/{num_epochs} | "
              f"Train Loss: {train_losses[-1]:.4f} | "
              f"Val Loss: {val_losses[-1]:.4f} | "
              f"Val Acc: {accuracy:.4f}")
    
    return train_losses, val_losses, val_accuracies

# ==================== INSTANTIATE AND TRAIN MODELS ====================
print(" FIXED MODELS - TRAINING STARTING...")

# Model 1: LSTM with Attention
lstm_model = LSTMAttentionModel(len(vocab), embedding_dim=100, hidden_dim=128)
lstm_train_loss, lstm_val_loss, lstm_val_acc = train_model(
    lstm_model, train_loader, val_loader, num_epochs=5, model_name="LSTM with Attention"
)

# Model 2: GRU Model
gru_model = GRUModel(len(vocab), embedding_dim=100, hidden_dim=128)
gru_train_loss, gru_val_loss, gru_val_acc = train_model(
    gru_model, train_loader, val_loader, num_epochs=5, model_name="GRU"
)

print("\n MODELS TRAINED SUCCESSFULLY!")

 FIXED MODELS - TRAINING STARTING...

 TRAINING LSTM WITH ATTENTION


In [None]:
# ==================== COMPREHENSIVE MODEL EVALUATION ====================
def evaluate_model(model, test_loader, model_name="Model"):
    """Comprehensive evaluation of model performance."""
    model.eval()
    all_preds, all_labels, all_probs = [], [], []
    test_loss = 0
    criterion = nn.CrossEntropyLoss()
    
    with torch.no_grad():
        for premise, hypothesis, labels in test_loader:
            premise, hypothesis, labels = premise.to(device), hypothesis.to(device), labels.to(device)
            
            outputs, _ = model(premise, hypothesis)
            loss = criterion(outputs, labels)
            test_loss += loss.item()
            
            probs = torch.softmax(outputs, dim=1)
            preds = torch.argmax(outputs, dim=1)
            
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())
    
    # Calculate metrics
    from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score
    
    accuracy = accuracy_score(all_labels, all_preds)
    class_report = classification_report(all_labels, all_preds, 
                                       target_names=['entails', 'neutral'])
    conf_matrix = confusion_matrix(all_labels, all_preds)
    
    # ROC AUC (requires probability scores)
    try:
        roc_auc = roc_auc_score(all_labels, [p[1] for p in all_probs])  # Use probability for class 1
    except:
        roc_auc = "N/A"
    
    print(f"\n {model_name.upper()} EVALUATION RESULTS")
    print("=" * 60)
    print(f"Test Loss: {test_loss/len(test_loader):.4f}")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"ROC AUC: {roc_auc}")
    print("\nClassification Report:")
    print(class_report)
    print("\nConfusion Matrix:")
    print(conf_matrix)
    
    return {
        'accuracy': accuracy,
        'loss': test_loss/len(test_loader),
        'roc_auc': roc_auc,
        'predictions': all_preds,
        'probabilities': all_probs,
        'true_labels': all_labels
    }

# ==================== COMPARATIVE ANALYSIS ====================
def compare_models(results_dict):
    """Compare performance of multiple models."""
    print("\n MODEL COMPARISON ANALYSIS")
    print("=" * 60)
    
    models = list(results_dict.keys())
    metrics = ['accuracy', 'loss', 'roc_auc']
    
    for metric in metrics:
        print(f"\n{metric.upper():<12} {' | '.join(f'{m}: {results_dict[m][metric]:.4f}' for m in models)}")
    
    # Find best model
    best_acc_model = max(models, key=lambda m: results_dict[m]['accuracy'])
    best_loss_model = min(models, key=lambda m: results_dict[m]['loss'])
    
    print(f"\n Best Accuracy: {best_acc_model} ({results_dict[best_acc_model]['accuracy']:.4f})")
    print(f" Best Loss: {best_loss_model} ({results_dict[best_loss_model]['loss']:.4f})")

# ==================== VISUALIZATION ====================
def plot_training_history(lstm_train_loss, lstm_val_loss, lstm_val_acc,
                         gru_train_loss, gru_val_loss, gru_val_acc):
    """Plot training history for both models."""
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))
    
    # Training Loss
    ax1.plot(lstm_train_loss, label='LSTM Train', marker='o')
    ax1.plot(gru_train_loss, label='GRU Train', marker='s')
    ax1.set_title('Training Loss')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.legend()
    ax1.grid(True)
    
    # Validation Loss
    ax2.plot(lstm_val_loss, label='LSTM Val', marker='o')
    ax2.plot(gru_val_loss, label='GRU Val', marker='s')
    ax2.set_title('Validation Loss')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Loss')
    ax2.legend()
    ax2.grid(True)
    
    # Validation Accuracy
    ax3.plot(lstm_val_acc, label='LSTM', marker='o')
    ax3.plot(gru_val_acc, label='GRU', marker='s')
    ax3.set_title('Validation Accuracy')
    ax3.set_xlabel('Epoch')
    ax3.set_ylabel('Accuracy')
    ax3.legend()
    ax3.grid(True)
    
    # Final comparison bar chart
    models = ['LSTM', 'GRU']
    final_acc = [lstm_val_acc[-1], gru_val_acc[-1]]
    final_loss = [lstm_val_loss[-1], gru_val_loss[-1]]
    
    x = np.arange(len(models))
    width = 0.35
    
    ax4.bar(x - width/2, final_acc, width, label='Accuracy', alpha=0.8)
    ax4.bar(x + width/2, final_loss, width, label='Loss', alpha=0.8)
    ax4.set_title('Final Epoch Comparison')
    ax4.set_xticks(x)
    ax4.set_xticklabels(models)
    ax4.legend()
    ax4.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()

# ==================== ERROR ANALYSIS ====================
def error_analysis(model, test_loader, model_name, num_samples=5):
    """Analyze specific errors made by the model."""
    model.eval()
    errors = []
    
    with torch.no_grad():
        for premise, hypothesis, labels in test_loader:
            premise, hypothesis, labels = premise.to(device), hypothesis.to(device), labels.to(device)
            
            outputs, _ = model(premise, hypothesis)
            preds = torch.argmax(outputs, dim=1)
            
            # Find incorrect predictions
            incorrect_mask = (preds != labels)
            if incorrect_mask.any():
                incorrect_indices = torch.where(incorrect_mask)[0]
                for idx in incorrect_indices:
                    errors.append({
                        'premise': premise[idx].cpu().numpy(),
                        'hypothesis': hypothesis[idx].cpu().numpy(),
                        'true_label': labels[idx].item(),
                        'predicted_label': preds[idx].item(),
                        'probabilities': torch.softmax(outputs[idx], dim=0).cpu().numpy()
                    })
                    if len(errors) >= num_samples:
                        break
            if len(errors) >= num_samples:
                break
    
    print(f"\n {model_name.upper()} ERROR ANALYSIS (First {num_samples} errors)")
    print("=" * 60)
    
    for i, error in enumerate(errors):
        print(f"\nError {i+1}:")
        print(f"True: {error['true_label']} ({'entails' if error['true_label'] == 0 else 'neutral'})")
        print(f"Pred: {error['predicted_label']} ({'entails' if error['predicted_label'] == 0 else 'neutral'})")
        print(f"Confidence: {max(error['probabilities']):.3f}")
        # You can add actual text here if you have the reverse mapping

# ==================== RUN COMPREHENSIVE EVALUATION ====================
print(" RUNNING COMPREHENSIVE EVALUATION...")

# Evaluate both models on test set
lstm_results = evaluate_model(lstm_model, test_loader, "LSTM with Attention")
gru_results = evaluate_model(gru_model, test_loader, "GRU")

# Compare models
compare_models({
    'LSTM': lstm_results,
    'GRU': gru_results
})

# Plot training history
plot_training_history(lstm_train_loss, lstm_val_loss, lstm_val_acc,
                     gru_train_loss, gru_val_loss, gru_val_acc)

# Error analysis
error_analysis(lstm_model, test_loader, "LSTM")
error_analysis(gru_model, test_loader, "GRU")

print("\n COMPREHENSIVE EVALUATION COMPLETE!")

### LSTM (WITHOUT ABLATION)

In [None]:
### Attention Mechanism Ablation

# Create LSTM model WITHOUT attention
class LSTMModelNoAttention(nn.Module):
    # Same as LSTM but remove the attention parts
    def __init__(self, vocab_size, embedding_dim=100, hidden_dim=128, num_classes=2):
        super(LSTMModelNoAttention, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.lstm_premise = nn.LSTM(embedding_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.lstm_hypothesis = nn.LSTM(embedding_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.classifier = nn.Linear(hidden_dim * 4, num_classes)  # No attention features
    
    def forward(self, premise, hypothesis):
        prem_embed = self.embedding(premise)
        hyp_embed = self.embedding(hypothesis)
        prem_lstm, _ = self.lstm_premise(prem_embed)
        hyp_lstm, _ = self.lstm_hypothesis(hyp_embed)
        prem_pooled = torch.mean(prem_lstm, dim=1)
        hyp_pooled = torch.mean(hyp_lstm, dim=1)
        combined = torch.cat([prem_pooled, hyp_pooled], dim=1)
        return self.classifier(combined)