# ArgLegalSumm: End-to-end Legal Text Summarization (Colab-ready)

This single notebook trains a Hugging Face encoder–decoder model on your legal dataset (train/test CSVs), resumes safely from checkpoints after restarts, evaluates ROUGE, and runs inference. Steps are idempotent and re-load settings from disk so each cell can run independently.

In [None]:
# 1) Check GPU and Install Dependencies (safe to rerun)
import sys, subprocess, json, os

# Minimum versions of the packages this notebook installs.
# transformers is pinned to >=4.46.0 because the training cell passes
# `processing_class=` to Seq2SeqTrainer, a keyword older releases reject.
REQS = [
    'transformers>=4.46.0',
    'datasets>=2.18.0',
    'accelerate>=0.32.0',
    'evaluate>=0.4.2',
    'rouge-score>=0.1.2',
    'sentencepiece>=0.1.99',
    'PyYAML>=6.0.1',
    'pandas>=2.0.0',
    'tqdm>=4.66.0',
    'nltk>=3.8.1',
    'scikit-learn>=1.2.0'
]

def pip_install(pkgs):
    """Quietly pip-install the requirement strings in *pkgs*.

    pip is run through the current interpreter (``sys.executable -m pip``).
    A non-zero pip exit status is printed and otherwise ignored, so the
    cell keeps running and can safely be re-executed.
    """
    pip_prefix = [sys.executable, '-m', 'pip', 'install', '-q']
    try:
        subprocess.check_call(pip_prefix + pkgs)
    except subprocess.CalledProcessError as err:
        print('pip install error:', err)
# Install (or upgrade to) the minimum versions listed in REQS before importing
# anything from those packages.
pip_install(REQS)

# Show versions
# These imports are deliberately placed after the install step above.
import transformers, datasets, evaluate, pandas as pd, torch, nltk
print('Python:', sys.version.split()[0])
print('Torch:', torch.__version__)
print('Transformers:', transformers.__version__)
print('Datasets:', datasets.__version__)
print('Evaluate:', evaluate.__version__)

# GPU info
if torch.cuda.is_available():
    print('CUDA available. Device count:', torch.cuda.device_count())
    print('Current device:', torch.cuda.get_device_name(0))
else:
    print('CUDA not available; training will run on CPU (slow).')

# NLTK punkt for optional scoring
# nltk.data.find raises LookupError when the resource is not on disk yet;
# only then is it downloaded, which keeps re-runs cheap.
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    import nltk
    nltk.download('punkt')

In [None]:
# 2) Connect to Google Drive and Define Paths (persistent artifacts)
import os, json

# Colab detection: the import below only succeeds inside a Colab runtime.
try:
    import google.colab  # type: ignore
except Exception:
    IN_COLAB = False
else:
    IN_COLAB = True

if not IN_COLAB:
    # Fallback local path if running outside Colab
    ROOT_DIR = os.path.abspath('./arglegalsumm_artifacts')
else:
    # Keep every artifact on Google Drive so it survives runtime resets.
    from google.colab import drive
    drive.mount('/content/drive')
    ROOT_DIR = '/content/drive/MyDrive/arglegalsumm'

# One sub-folder per artifact type; SRC_DIR is an optional copy of the project source.
DATA_DIR, CKPT_DIR, BEST_DIR, LOG_DIR, TOKENIZER_DIR, CACHE_DIR, EXPORT_DIR, SRC_DIR = (
    os.path.join(ROOT_DIR, sub_dir)
    for sub_dir in ('data', 'checkpoints', 'best', 'logs', 'tokenizer', 'cache', 'exports', 'src')
)

for artifact_dir in (ROOT_DIR, DATA_DIR, CKPT_DIR, BEST_DIR, LOG_DIR,
                     TOKENIZER_DIR, CACHE_DIR, EXPORT_DIR, SRC_DIR):
    os.makedirs(artifact_dir, exist_ok=True)

for label, location in (
    ('ROOT_DIR:', ROOT_DIR),
    ('DATA_DIR:', DATA_DIR),
    ('CKPT_DIR:', CKPT_DIR),
    ('BEST_DIR:', BEST_DIR),
    ('SRC_DIR:', SRC_DIR),
):
    print(label, location)

In [None]:
# 3) Sync Project Files and Python Path (optional)
import os, sys, shutil, glob

# Option A (recommended): the notebook is self-contained and needs nothing from src.
# Option B: if the repo's 'src' folder was uploaded/unzipped to /content, mirror it
# into SRC_DIR (persistent) and make it importable.
repo_uploaded = os.path.isdir('/content/src')
already_synced = os.path.isdir(os.path.join(SRC_DIR, 'summarization'))
if repo_uploaded and not already_synced:
    shutil.copytree('/content/src', SRC_DIR, dirs_exist_ok=True)

# Put SRC_DIR at the front of sys.path once, so re-running the cell adds no duplicates.
if os.path.isdir(SRC_DIR) and SRC_DIR not in sys.path:
    sys.path.insert(0, SRC_DIR)

print('sys.path contains SRC_DIR:', SRC_DIR in sys.path)
print('Found modules under SRC_DIR:', os.listdir(SRC_DIR))

In [None]:
# 4) Load Training Config (YAML from repo if available) and persist JSON config
import os, json, yaml

# Defaults mirror src/summarization/config/train_config.yml
# This dict is the single source of settings: it is written to
# pipeline_config.json below and re-read from disk by the later cells.
DEFAULTS = {
    # Checkpoint name and the per-device batch size used for train and eval.
    'model': {
        'checkpoint': 'allenai/led-base-16384',
        'batchsize': 1
    },
    # Truncation lengths (in tokens) for source texts and target summaries.
    'data': {
        'max_input_length': 1024,
        'max_output_length': 512
    },
    # Values forwarded to Seq2SeqTrainingArguments in the training-arguments cell.
    'training': {
        'eval_steps': 250,
        'save_steps': 250,
        'epochs': 3,  # lower by default for Colab; override if YAML present
        'gradient_accumulation_steps': 4,
        'learning_rate': 5e-5,
        'logging_steps': 50
    },
    # Artifact locations, taken from the path constants defined in the paths cell.
    'io': {
        'output_dir': CKPT_DIR,
        'best_dir': BEST_DIR,
        'tokenizer_dir': TOKENIZER_DIR,
        'export_dir': EXPORT_DIR,
        'cache_dir': CACHE_DIR
    },
    # Column names tried, in order, when locating the text and summary columns.
    'data_schema': {
        'input_col_candidates': ['article', 'text', 'document', 'content'],
        'summary_col_candidates': ['summary', 'target', 'abstract']
    }
}

# Try to load YAML from the repo structure if present under SRC_DIR
YAML_PATHS = [
    os.path.join(SRC_DIR, 'summarization', 'config', 'train_config.yml'),
    os.path.join('/content', 'src', 'summarization', 'config', 'train_config.yml'),
]
loaded_yml = None
for yp in YAML_PATHS:
    if os.path.isfile(yp):
        with open(yp, 'r') as f:
            try:
                loaded_yml = yaml.safe_load(f)
                break
            except Exception as e:
                print('Failed to parse YAML at', yp, e)

# Keys the YAML file may override, grouped by config section.
_YAML_OVERRIDES = {
    'model': ('checkpoint', 'batchsize'),
    'data': ('max_input_length', 'max_output_length'),
    'training': ('eval_steps', 'save_steps', 'epochs'),
}

if isinstance(loaded_yml, dict):
    # Merge defensively: a section written as `model:` with no children parses
    # to None, and an explicit null value must not replace a working default.
    for section, keys in _YAML_OVERRIDES.items():
        section_values = loaded_yml.get(section)
        if not isinstance(section_values, dict):
            continue
        for key in keys:
            if section_values.get(key) is not None:
                DEFAULTS[section][key] = section_values[key]
elif loaded_yml is not None:
    print('Ignoring YAML config: expected a mapping at the top level, got', type(loaded_yml).__name__)

# Persist JSON config for restart-safe execution
CONFIG_JSON = os.path.join(ROOT_DIR, 'pipeline_config.json')

# The CSV-loading cell stores the resolved 'input_col'/'summary_col' in this same
# file. Carry them over so re-running this cell after a restart does not erase
# what later cells read from data_schema.
if os.path.isfile(CONFIG_JSON):
    try:
        with open(CONFIG_JSON, 'r') as f:
            previous_schema = json.load(f).get('data_schema', {})
    except (OSError, ValueError, AttributeError) as e:
        print('Could not read previous config at', CONFIG_JSON, e)
        previous_schema = {}
    if isinstance(previous_schema, dict):
        for key in ('input_col', 'summary_col'):
            if key in previous_schema:
                DEFAULTS['data_schema'][key] = previous_schema[key]

with open(CONFIG_JSON, 'w') as f:
    json.dump(DEFAULTS, f, indent=2)

print('Resolved config:')
print(json.dumps(DEFAULTS, indent=2))
print('Saved to', CONFIG_JSON)

In [None]:
# 5) Load Train/Test CSVs and Validate Schema (idempotent + case-insensitive + custom paths)
import os, json, shutil
import pandas as pd
from pathlib import Path

# reload config
with open(os.path.join(ROOT_DIR, 'pipeline_config.json'), 'r') as f:
    CFG = json.load(f)

# Allow explicit paths via environment variables (recommended in Colab)
train_path_env = os.environ.get('TRAIN_CSV_PATH')
test_path_env = os.environ.get('TEST_CSV_PATH')


def _default_train_file_name():
    """Return 'train (1).csv' when that file exists in DATA_DIR, else 'train.csv'."""
    if os.path.isfile(os.path.join(DATA_DIR, 'train (1).csv')):
        return 'train (1).csv'
    return 'train.csv'


# Expected filenames under DATA_DIR if env vars not provided
TRAIN_FILE_NAME = _default_train_file_name()
TEST_FILE_NAME = 'test.csv'

if train_path_env and os.path.isfile(train_path_env):
    train_path = train_path_env
else:
    train_path = os.path.join(DATA_DIR, TRAIN_FILE_NAME)

if test_path_env and os.path.isfile(test_path_env):
    test_path = test_path_env
else:
    test_path = os.path.join(DATA_DIR, TEST_FILE_NAME)

# Provide upload fallback in Colab if files aren't found and no env vars
if not (os.path.isfile(train_path) and os.path.isfile(test_path)):
    print('CSV files not found. DATA_DIR =', DATA_DIR)
    print('You can set TRAIN_CSV_PATH and TEST_CSV_PATH env vars, or upload files now...')
    try:
        from google.colab import files
        uploaded = files.upload()
        os.makedirs(DATA_DIR, exist_ok=True)
        for name in uploaded.keys():
            shutil.move(name, os.path.join(DATA_DIR, name))
        # The upload may just have added 'train (1).csv', so the name chosen
        # before the upload can be stale; pick it again.
        TRAIN_FILE_NAME = _default_train_file_name()
        # Only redirect the paths that are still missing, so a valid path given
        # through an env var is not thrown away.
        if not os.path.isfile(train_path):
            train_path = os.path.join(DATA_DIR, TRAIN_FILE_NAME)
        if not os.path.isfile(test_path):
            test_path = os.path.join(DATA_DIR, TEST_FILE_NAME)
        print('Uploaded. DATA_DIR contains:', os.listdir(DATA_DIR))
    except Exception as e:
        print('Upload failed; please ensure your CSVs exist in DATA_DIR or set env vars.', e)

assert os.path.isfile(train_path), f'Missing train CSV: {train_path}'
assert os.path.isfile(test_path), f'Missing test CSV: {test_path}'

print('Loading CSVs...')
train_df = pd.read_csv(train_path)
test_df = pd.read_csv(test_path)
print('Train shape:', train_df.shape, '| Test shape:', test_df.shape)
print('Train columns:', list(train_df.columns))
print('Test columns:', list(test_df.columns))

# Case-insensitive normalization of column names
train_df.columns = [str(c).strip().lower() for c in train_df.columns]
test_df.columns = [str(c).strip().lower() for c in test_df.columns]

# Normalize common schema names
def resolve_cols(df):
    """Return ``(input_col, summary_col)`` for *df* from the candidate lists in CFG.

    Candidates are tried in list order after lower-casing, and the first one
    present in ``df.columns`` wins. A role with no match yields ``None``.
    """
    def _first_present(candidates):
        return next(
            (name.lower() for name in candidates if name.lower() in df.columns),
            None,
        )

    input_col = _first_present(CFG['data_schema']['input_col_candidates'])
    summary_col = _first_present(CFG['data_schema']['summary_col_candidates'])
    return input_col, summary_col

train_in_col, train_sum_col = resolve_cols(train_df)
# Prefer summary col on test too (often 'summary')
test_in_col, test_sum_col = resolve_cols(test_df)

# If test has no summaries, allow test_sum_col to be None
assert train_in_col is not None, 'Could not find an input text column (tried: {}) in train CSV'.format(CFG['data_schema']['input_col_candidates'])
assert train_sum_col is not None, 'Could not find a summary column (tried: {}) in train CSV'.format(CFG['data_schema']['summary_col_candidates'])

# Clean NAs and strip
# fillna must run BEFORE astype(str): casting first turns NaN into the literal
# text 'nan', after which there is nothing left for fillna to replace.
for df, in_c, sum_c in [
    (train_df, train_in_col, train_sum_col),
    (test_df, test_in_col, test_sum_col),
]:
    if in_c is not None:
        df[in_c] = df[in_c].fillna('').astype(str).str.strip()
    if sum_c is not None:
        df[sum_c] = df[sum_c].fillna('').astype(str).str.strip()

# Persist cleaned copies (idempotent)
clean_train_csv = os.path.join(DATA_DIR, 'train_clean.csv')
clean_test_csv = os.path.join(DATA_DIR, 'test_clean.csv')
train_df.to_csv(clean_train_csv, index=False)
test_df.to_csv(clean_test_csv, index=False)

# Save resolved schema (lowercase) to config JSON for later cells
CFG['data_schema']['input_col'] = train_in_col
CFG['data_schema']['summary_col'] = train_sum_col
with open(os.path.join(ROOT_DIR, 'pipeline_config.json'), 'w') as f:
    json.dump(CFG, f, indent=2)

print('Resolved columns -> input:', train_in_col, '| summary:', train_sum_col)
print('Saved cleaned datasets to:')
print(' -', clean_train_csv)
print(' -', clean_test_csv)
print('Tip: To use explicit paths next time, set env vars:')
print('  TRAIN_CSV_PATH="/content/drive/MyDrive/.../train.csv"  TEST_CSV_PATH="/content/drive/MyDrive/.../test.csv"')

In [None]:
# 6) Text Cleaning and Normalization (legal-specific; idempotent)
import os, re, json
import pandas as pd

# Re-read the persisted config so this cell does not depend on in-memory state
# from the CSV-loading cell.
with open(os.path.join(ROOT_DIR, 'pipeline_config.json'), 'r') as f:
    CFG = json.load(f)

# Column names resolved and saved by the CSV-loading cell.
in_col = CFG['data_schema']['input_col']
sum_col = CFG['data_schema']['summary_col']

# Work on the cleaned copies; they are overwritten in place at the end of this cell.
clean_train_csv = os.path.join(DATA_DIR, 'train_clean.csv')
clean_test_csv = os.path.join(DATA_DIR, 'test_clean.csv')
train_df = pd.read_csv(clean_train_csv)
test_df = pd.read_csv(clean_test_csv)

# Set to False to re-save the CSVs without applying normalize_text.
ENABLE_CLEANING = True

_ws_re = re.compile(r'\s+')
# Docket-style numbers such as "12-345", "12/345" or "12-345CV". Only letters
# attached directly to the number count as part of it; allowing whitespace
# before the letters would also delete the next ordinary word of the sentence.
_docket_re = re.compile(r'\b\d{1,4}[-/]\d{1,6}[A-Za-z]*\b')


def normalize_text(s: str) -> str:
    """Return *s* with quotes, docket numbers and whitespace normalised.

    Steps, in order:
      1. non-string input is converted with ``str()``;
      2. curly double quotes and the right single quote become ASCII quotes;
      3. docket-style numbers (see ``_docket_re``) are replaced by a space;
      4. runs of whitespace collapse to one space and the ends are trimmed;
      5. a '.' is appended when the result ends in a letter or digit.

    Applying the function to its own output returns it unchanged.
    """
    if not isinstance(s, str):
        s = str(s)
    s = s.replace('\u201c', '"').replace('\u201d', '"').replace('\u2019', "'")
    s = _docket_re.sub(' ', s)
    s = _ws_re.sub(' ', s)
    s = s.strip()
    if s and s[-1].isalnum():
        s += '.'
    return s

if ENABLE_CLEANING:
    # The training input column is required; every other column is cleaned only if present.
    train_df[in_col] = train_df[in_col].map(normalize_text)
    for frame, column in ((train_df, sum_col), (test_df, in_col), (test_df, sum_col)):
        if column in frame.columns:
            frame[column] = frame[column].map(normalize_text)

# Overwrite cleaned files
for frame, csv_path in ((train_df, clean_train_csv), (test_df, clean_test_csv)):
    frame.to_csv(csv_path, index=False)
print('Re-saved normalized CSVs:', clean_train_csv, clean_test_csv)

In [None]:
# 7) Train/Validation Split and Persisted Splits
import os, json
import pandas as pd
from sklearn.model_selection import train_test_split

with open(os.path.join(ROOT_DIR, 'pipeline_config.json'), 'r') as f:
    CFG = json.load(f)

in_col = CFG['data_schema']['input_col']
sum_col = CFG['data_schema']['summary_col']

clean_train_csv = os.path.join(DATA_DIR, 'train_clean.csv')
train_df = pd.read_csv(clean_train_csv)

train_split_csv = os.path.join(DATA_DIR, 'train_split.csv')
val_split_csv = os.path.join(DATA_DIR, 'val_split.csv')

if os.path.isfile(train_split_csv) and os.path.isfile(val_split_csv):
    # Reuse the persisted split so validation rows stay the same across restarts.
    print('Using existing persisted splits:')
    print(train_split_csv, val_split_csv)
else:
    # Create stratification bins by summary length if available
    stratify = None
    if sum_col in train_df.columns:
        lengths = train_df[sum_col].fillna('').astype(str).str.split().map(len)
        try:
            stratify = pd.qcut(lengths, q=min(10, max(2, lengths.nunique())), duplicates='drop')
        except ValueError as e:
            # e.g. every summary has the same length, so no usable bin edges exist
            print('Could not build length bins; using a random split instead:', e)
    try:
        tr_df, val_df = train_test_split(train_df, test_size=0.1, random_state=42, stratify=stratify)
    except ValueError as e:
        if stratify is None:
            raise
        # Stratification needs enough rows per bin and per split; small datasets
        # cannot satisfy that, so fall back to a plain seeded random split.
        print('Stratified split failed; using a random split instead:', e)
        tr_df, val_df = train_test_split(train_df, test_size=0.1, random_state=42)
    tr_df.to_csv(train_split_csv, index=False)
    val_df.to_csv(val_split_csv, index=False)

print(pd.read_csv(train_split_csv).shape, pd.read_csv(val_split_csv).shape)

In [None]:
# 8) Tokenizer Setup (pretrained + optional custom tokens)
import os, json
from transformers import AutoTokenizer

with open(os.path.join(ROOT_DIR, 'pipeline_config.json'), 'r') as cfg_file:
    CFG = json.load(cfg_file)

MODEL_CHECKPOINT = CFG['model']['checkpoint']
TOKENIZER_DIR = CFG['io']['tokenizer_dir']

# Load tokenizer
print('Loading tokenizer:', MODEL_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT, use_fast=True)

# Optional: collect argument tokens (one per non-blank line) from the repo copy under SRC_DIR
extra_tokens = []
for token_file in ('binary_tokens.txt', 'fine_grained_tokens.txt'):
    token_path = os.path.join(SRC_DIR, 'summarization', token_file)
    if not os.path.isfile(token_path):
        continue
    with open(token_path, 'r') as handle:
        extra_tokens.extend(line.strip() for line in handle if line.strip())

if extra_tokens:
    print(f'Adding {len(extra_tokens)} special tokens from repo to tokenizer')
    tokenizer.add_special_tokens({'additional_special_tokens': extra_tokens})

# Persist tokenizer for reuse after restarts
os.makedirs(TOKENIZER_DIR, exist_ok=True)
tokenizer.save_pretrained(TOKENIZER_DIR)
print('Saved tokenizer to', TOKENIZER_DIR)

In [None]:
# 9) Dataset and Tokenization Functions (cache to disk)
import os, json
import pandas as pd
from datasets import Dataset, load_dataset

# Re-read the persisted config so the cell works after a restart.
with open(os.path.join(ROOT_DIR, 'pipeline_config.json'), 'r') as f:
    CFG = json.load(f)

# Resolved column names and truncation lengths used by tokenize_fn below.
in_col = CFG['data_schema']['input_col']
sum_col = CFG['data_schema']['summary_col']
max_src_len = CFG['data']['max_input_length']
max_tgt_len = CFG['data']['max_output_length']

train_split_csv = os.path.join(DATA_DIR, 'train_split.csv')
val_split_csv = os.path.join(DATA_DIR, 'val_split.csv')
clean_test_csv = os.path.join(DATA_DIR, 'test_clean.csv')

# Load splits into HF Datasets
# Each CSV is registered under the 'train' key of data_files, so the resulting
# dataset is read back from ['train'] regardless of which split it holds.
train_ds = load_dataset('csv', data_files={'train': train_split_csv})['train']
val_ds = load_dataset('csv', data_files={'train': val_split_csv})['train']
test_ds = load_dataset('csv', data_files={'train': clean_test_csv})['train']

print('Train/Val/Test sizes:', len(train_ds), len(val_ds), len(test_ds))

# Tokenization function including global_attention_mask for LED
from functools import partial

def tokenize_fn(batch, tokenizer=None, in_col=None, sum_col=None, max_src_len=1024, max_tgt_len=256):
    """Tokenize one batch of examples into model inputs.

    Args:
        batch: mapping from column name to a list of values.
        tokenizer: callable tokenizer exposing ``pad_token_id``.
        in_col: name of the source-text column.
        sum_col: name of the summary column; labels are built only when it
            is present in *batch*.
        max_src_len: padded/truncated length of the source sequences.
        max_tgt_len: padded/truncated length of the label sequences.

    Returns:
        dict with ``input_ids``, ``attention_mask``, optionally ``labels``
        (pad ids replaced by -100 so the loss ignores them) and
        ``global_attention_mask`` (1 on the first position, 0 elsewhere).
    """
    encoded = tokenizer(
        batch[in_col],
        padding='max_length',
        truncation=True,
        max_length=max_src_len,
    )
    features = {key: encoded[key] for key in ('input_ids', 'attention_mask')}

    if sum_col in batch:
        target_ids = tokenizer(
            batch[sum_col],
            padding='max_length',
            truncation=True,
            max_length=max_tgt_len,
        )['input_ids']
        # Mask padding positions with -100 so they do not contribute to the loss.
        features['labels'] = [
            [(-100 if token == tokenizer.pad_token_id else token) for token in row]
            for row in target_ids
        ]

    # LED specific: global attention on the first token only.
    features['global_attention_mask'] = [
        ([1] + [0] * (len(mask) - 1)) if len(mask) > 0 else []
        for mask in encoded['attention_mask']
    ]
    return features

# Keyword arguments shared by every tokenization call.
_tokenize_kwargs = dict(
    tokenizer=tokenizer,
    in_col=in_col,
    sum_col=sum_col,
    max_src_len=max_src_len,
    max_tgt_len=max_tgt_len,
)


def _columns_to_drop(ds):
    """List every column of *ds* except the text and summary columns."""
    kept = [in_col, sum_col]
    return [name for name in ds.column_names if name not in kept]


map_fn = partial(tokenize_fn, **_tokenize_kwargs)

train_tokenized = train_ds.map(map_fn, batched=True, remove_columns=_columns_to_drop(train_ds))
val_tokenized = val_ds.map(map_fn, batched=True, remove_columns=_columns_to_drop(val_ds))
# For test, sum_col may be missing; tokenize_fn then emits inputs without labels
map_fn_test = partial(tokenize_fn, **_tokenize_kwargs)
test_tokenized = test_ds.map(map_fn_test, batched=True, remove_columns=_columns_to_drop(test_ds))

print(train_tokenized[0].keys())

In [None]:
# 10) Data collator and dataset formatting
from transformers import DataCollatorForSeq2Seq

data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=None,
    label_pad_token_id=-100,
    padding=True,
    return_tensors='pt',
)

# set format for PyTorch
_input_cols = ['input_ids', 'attention_mask', 'global_attention_mask']
for labelled_split in (train_tokenized, val_tokenized):
    labelled_split.set_format(type='torch', columns=_input_cols + ['labels'])

# On test we might not have 'labels'; expose only the input columns that exist
cols_test = [c for c in _input_cols if c in test_tokenized.column_names]
test_tokenized.set_format(type='torch', columns=cols_test)

print('Batch example shapes:')
sample_count = min(2, len(train_tokenized))
batch = data_collator([train_tokenized[idx] for idx in range(sample_count)])
for name, value in batch.items():
    try:
        print(name, value.shape)
    except Exception:
        print(name, type(value))

In [None]:
# 11) Initialize Encoder-Decoder Model (with LED defaults)
import os, json, torch
from transformers import AutoModelForSeq2SeqLM

with open(os.path.join(ROOT_DIR, 'pipeline_config.json'), 'r') as f:
    CFG = json.load(f)

MODEL_CHECKPOINT = CFG['model']['checkpoint']

print('Loading model:', MODEL_CHECKPOINT)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_CHECKPOINT)

# Grow the embedding matrix only when the tokenizer really holds more tokens
# than the model (i.e. extra tokens were added). Resizing unconditionally could
# shrink checkpoints whose embedding matrix is larger than the tokenizer.
embedding_rows = model.get_input_embeddings().weight.shape[0]
if len(tokenizer) > embedding_rows:
    model.resize_token_embeddings(len(tokenizer))
    print(f'Resized token embeddings: {embedding_rows} -> {len(tokenizer)}')

# LED best practices for memory
try:
    model.gradient_checkpointing_enable()
except Exception as e:
    # Best effort: training still works without it, but report why it is off.
    print('Gradient checkpointing not enabled:', e)
model.config.use_cache = False  # important with gradient checkpointing

# Ensure EOS/BOS tokens are set (helps beam search stop cleanly)
if getattr(model.config, 'eos_token_id', None) is None and getattr(tokenizer, 'eos_token_id', None) is not None:
    model.config.eos_token_id = tokenizer.eos_token_id
if getattr(model.config, 'bos_token_id', None) is None and getattr(tokenizer, 'bos_token_id', None) is not None:
    model.config.bos_token_id = tokenizer.bos_token_id

# Set default generation hyperparameters via generation_config (avoids warnings)
# Align max_length to config to prevent premature truncation
_gen_max = int(CFG['data'].get('max_output_length', 256))
model.generation_config.num_beams = 4
model.generation_config.max_length = _gen_max
# Choose a reasonable min_length relative to the max (and ensure < max)
model.generation_config.min_length = min(max(32, _gen_max // 6), _gen_max - 1)
model.generation_config.length_penalty = 1.0
model.generation_config.early_stopping = True
model.generation_config.no_repeat_ngram_size = 3

if torch.cuda.is_available():
    model = model.cuda()
    print('Moved model to CUDA')

# Parameter counts
num_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Params: total={num_params:,} trainable={trainable_params:,}')

In [None]:
# 12) TrainingArguments with robust checkpointing/auto-resume
import os, json
from transformers import Seq2SeqTrainingArguments

with open(os.path.join(ROOT_DIR, 'pipeline_config.json'), 'r') as f:
    CFG = json.load(f)

args = Seq2SeqTrainingArguments(
    output_dir=CKPT_DIR,
    # `eval_strategy` is the current name of the deprecated `evaluation_strategy`
    # and is accepted by every transformers version this notebook installs.
    eval_strategy='steps',
    save_strategy='steps',
    logging_steps=CFG['training'].get('logging_steps', 50),
    eval_steps=CFG['training']['eval_steps'],
    save_steps=CFG['training']['save_steps'],
    save_total_limit=2,
    num_train_epochs=CFG['training']['epochs'],
    per_device_train_batch_size=CFG['model']['batchsize'],
    per_device_eval_batch_size=CFG['model']['batchsize'],
    gradient_accumulation_steps=CFG['training']['gradient_accumulation_steps'],
    learning_rate=CFG['training']['learning_rate'],
    predict_with_generate=True,
    load_best_model_at_end=True,
    # Use a metric that exists in compute_metrics/evaluate output
    metric_for_best_model='eval_rougeLsum',
    greater_is_better=True,
    fp16=torch.cuda.is_available(),
    report_to=['none'],  # disable wandb
    # Generation defaults to keep Trainer.generate consistent with our model config
    generation_max_length=CFG['data'].get('max_output_length', 256),
    generation_num_beams=4,
    # Stabilization
    warmup_steps=CFG['training'].get('warmup_steps', 500),
    label_smoothing_factor=CFG['training'].get('label_smoothing_factor', 0.1),
)

print(args)

In [None]:
# 13) Start or Resume Training from Latest Checkpoint (restart-safe)
import os, json
from transformers import Seq2SeqTrainer, EarlyStoppingCallback
from transformers.trainer_utils import get_last_checkpoint
import evaluate, torch

# Rehydrate state after crash
# NOTE(review): ensure_state_for_training and load_cfg are not defined in the
# cells shown above; presumably they come from the utility cell at the end of
# the notebook, which must then be run before this one. Confirm.
ensure_state_for_training()

# NOTE(review): `cfg` is not read again in this cell.
cfg = load_cfg()
# ROUGE scorer shared by compute_metrics below.
rouge_metric = evaluate.load('rouge')

def compute_metrics(eval_pred):
    """Compute ROUGE scores and the mean generated length for one evaluation run.

    Args:
        eval_pred: ``(predictions, labels)`` pair of token-id sequences.
            When predictions arrive as a tuple, its first element is used.

    Returns:
        dict of the ROUGE scores rounded to 4 decimals, plus ``gen_len``, the
        mean number of whitespace-separated words per decoded prediction.
    """
    preds, labels = eval_pred
    if isinstance(preds, tuple):
        preds = preds[0]
    pad_id = tokenizer.pad_token_id
    # -100 is a padding sentinel, not a vocabulary id, and cannot be decoded.
    # It can appear in predictions as well as labels, so map it back in both.
    preds = [[(p if p != -100 else pad_id) for p in pred] for pred in preds]
    labels = [[(l if l != -100 else pad_id) for l in label] for label in labels]
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    result = rouge_metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    result = {k: round(v, 4) for k, v in result.items()}
    gen_lens = [len(pred.split()) for pred in decoded_preds]
    # max(1, ...) avoids a division by zero on an empty evaluation set.
    result['gen_len'] = sum(gen_lens) / max(1, len(gen_lens))
    return result

trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=train_tokenized,
    eval_dataset=val_tokenized,
    data_collator=data_collator,
    processing_class=tokenizer,  # use processing_class (replaces deprecated `tokenizer` arg)
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)

# Resume from the newest checkpoint in CKPT_DIR when there is one; None starts fresh.
ckpt_dir_exists = os.path.isdir(CKPT_DIR)
last_ckpt = get_last_checkpoint(CKPT_DIR) if ckpt_dir_exists else None
if ckpt_dir_exists:
    print('Last checkpoint:', last_ckpt)

train_result = trainer.train(resume_from_checkpoint=last_ckpt)
trainer.save_state()

# Save training logs (appended, so earlier runs stay in the file)
os.makedirs(LOG_DIR, exist_ok=True)
train_logs_path = os.path.join(LOG_DIR, 'train_logs.txt')
with open(train_logs_path, 'a') as log_file:
    log_file.write(f'{train_result.metrics}\n')
print('Training complete. Logs ->', train_logs_path)

In [None]:
# 14) Evaluate with ROUGE on Validation and Test (restart-safe)
import os, json
import numpy as np
import pandas as pd

# Rehydrate state
ensure_state_for_training()
CFG = load_cfg_ensured()
in_col = CFG['data_schema']['input_col']
sum_col = CFG['data_schema']['summary_col']

# Validation evaluation (use default 'eval_' prefix to avoid key mismatches)
val_metrics = trainer.evaluate(eval_dataset=val_tokenized)
print('Validation metrics:', val_metrics)

# Test evaluation: ROUGE needs a non-empty reference column in the CSV and
# tokenized labels in the dataset.
test_df = pd.read_csv(os.path.join(DATA_DIR, 'test_clean.csv'))
has_test_refs = (
    (sum_col in test_df.columns)
    and test_df[sum_col].notna().any()
    and ('labels' in getattr(test_tokenized, 'column_names', []))
)

if not has_test_refs:
    # No references -> run predict with compute_metrics switched off, then restore it
    saved_cm = trainer.compute_metrics
    trainer.compute_metrics = None
    pred = trainer.predict(test_tokenized, metric_key_prefix='test')
    trainer.compute_metrics = saved_cm
    test_metrics = pred.metrics
else:
    # References available -> compute ROUGE with the 'test_' prefix
    test_metrics = trainer.evaluate(eval_dataset=test_tokenized, metric_key_prefix='test')

print('Test metrics:', test_metrics)

# Helper to safely cast numbers for JSON
def _to_float_dict(d):
    """Return a copy of *d* whose values are floats wherever ``float()`` accepts them.

    Values that cannot be converted are kept unchanged. Key order is preserved.
    """
    def _as_float(value):
        try:
            return float(value)
        except Exception:
            return value

    return {key: _as_float(value) for key, value in d.items()}

# Save metrics as JSON files directly under ROOT_DIR
for metrics_file, metrics in (
    ('metrics_val.json', val_metrics),
    ('metrics_test.json', test_metrics),
):
    with open(os.path.join(ROOT_DIR, metrics_file), 'w') as out_file:
        json.dump(_to_float_dict(metrics), out_file, indent=2)
print('Saved metrics to ROOT_DIR')

In [None]:
# 15) Save Best Model/Tokenizer to Drive and verify reload
import os

trainer.save_model(BEST_DIR)  # saves model + tokenizer (if passed) + training args
try:
    tokenizer.save_pretrained(BEST_DIR)
except Exception as e:
    # Best effort, but say so: the reload check below needs tokenizer files in BEST_DIR.
    print('Could not save tokenizer to', BEST_DIR, '->', e)

print('Saved best model to', BEST_DIR)

# Quick reload test
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
re_tok = AutoTokenizer.from_pretrained(BEST_DIR)
re_model = AutoModelForSeq2SeqLM.from_pretrained(BEST_DIR)
if torch.cuda.is_available():
    re_model = re_model.cuda()

# Encode with the RELOADED tokenizer so the check exercises the saved files
# end to end instead of the tokenizer that is still in memory.
dummy = re_tok('Test input.', return_tensors='pt')
if torch.cuda.is_available():
    dummy = {k: v.cuda() for k, v in dummy.items()}
out = re_model.generate(**dummy, max_length=32, num_beams=2)
print('Reloaded model generated:', re_tok.batch_decode(out, skip_special_tokens=True))

In [None]:
# 16) Batch Inference on Test Set and Export Summaries (restart-safe)
import os, json
import pandas as pd
from tqdm import tqdm
import torch

CFG = load_cfg_ensured()

# Reload best tokenizer/model for inference
infer_tokenizer = load_tokenizer_for_inference(CFG)
infer_model = load_model_for_inference(CFG, tokenizer=infer_tokenizer)

in_col_cfg = CFG['data_schema']['input_col']
sum_col = CFG['data_schema']['summary_col']  # may be absent in test

# Optional quick override for longer outputs (no retrain needed)
# You can also set GEN_MAX_LEN env var to control this dynamically.
OVERRIDE_MAX_LEN = int(os.environ.get('GEN_MAX_LEN', '640'))  # try 640 or 768
if OVERRIDE_MAX_LEN > 0:
    infer_model.generation_config.max_length = OVERRIDE_MAX_LEN
    # keep a reasonable min_length but ensure < max
    infer_model.generation_config.min_length = min(max(32, OVERRIDE_MAX_LEN // 6), OVERRIDE_MAX_LEN - 1)
    infer_model.generation_config.length_penalty = 1.0

print(
    'Generation config ->',
    'max_length=', infer_model.generation_config.max_length,
    '| min_length=', infer_model.generation_config.min_length,
    '| num_beams=', infer_model.generation_config.num_beams,
    '| length_penalty=', infer_model.generation_config.length_penalty,
)

# Reload test CSV for text and reference
test_df = pd.read_csv(os.path.join(DATA_DIR, 'test_clean.csv'))
# Ensure lowercase columns (safety)
test_df.columns = [str(c).strip().lower() for c in test_df.columns]

# Determine effective input column for test
candidates = [in_col_cfg] + CFG['data_schema'].get('input_col_candidates', []) + ['text', 'article', 'document', 'content']
effective_in_col = next((c for c in candidates if c in test_df.columns), None)
if effective_in_col is None:
    raise KeyError(f"None of the candidate input columns {list(dict.fromkeys(candidates))} were found in test_clean.csv. Available columns: {list(test_df.columns)}")

# Generation parameters (align with generation_config)
gen_kwargs = dict(
    max_length=infer_model.generation_config.max_length,
    min_length=infer_model.generation_config.min_length,
    num_beams=infer_model.generation_config.num_beams,
    length_penalty=infer_model.generation_config.length_penalty,
    no_repeat_ngram_size=infer_model.generation_config.no_repeat_ngram_size,
)

preds = []
bsz = 4
# Empty cells are read back from CSV as NaN (a float). Hand the tokenizer plain
# strings only, so a single missing row cannot abort the whole export.
source_texts = test_df[effective_in_col].fillna('').astype(str).tolist()
for i in tqdm(range(0, len(source_texts), bsz)):
    batch_texts = source_texts[i:i+bsz]
    inputs = infer_tokenizer(batch_texts, return_tensors='pt', padding=True, truncation=True, max_length=CFG['data']['max_input_length'])
    # LED: add global attention on first token
    attn = inputs['attention_mask']
    global_attention_mask = torch.zeros_like(attn)
    global_attention_mask[:, 0] = 1
    if torch.cuda.is_available():
        inputs = {k: v.cuda() for k, v in inputs.items()}
        global_attention_mask = global_attention_mask.cuda()
    with torch.no_grad():
        out_ids = infer_model.generate(**inputs, global_attention_mask=global_attention_mask, **gen_kwargs)
    gen_texts = infer_tokenizer.batch_decode(out_ids, skip_special_tokens=True)
    preds.extend(gen_texts)

export_path = os.path.join(EXPORT_DIR, 'test_predictions.csv')
export_df = pd.DataFrame({
    'id': range(len(preds)),
    'text': test_df[effective_in_col].tolist(),
    'generated_summary': preds,
})
if sum_col in test_df.columns:
    export_df['reference_summary'] = test_df[sum_col].tolist()

os.makedirs(EXPORT_DIR, exist_ok=True)
export_df.to_csv(export_path, index=False)
print('Saved predictions to', export_path)

# Quick peek
export_df.head(3)

In [None]:
# 17) Ad-hoc Inference on Custom Legal Texts (restart-safe)
# NOTE(review): this cell has no imports of its own; it relies on `os`, `torch`
# and the load_* helpers already being defined by previously executed cells.
CFG = load_cfg_ensured()

# Reload best tokenizer/model for inference
infer_tokenizer = load_tokenizer_for_inference(CFG)
infer_model = load_model_for_inference(CFG, tokenizer=infer_tokenizer)

# Optional quick override for longer outputs (no retrain needed)
# GEN_MAX_LEN defaults to 640; a value <= 0 keeps the model's own generation settings.
OVERRIDE_MAX_LEN = int(os.environ.get('GEN_MAX_LEN', '640'))  # try 640 or 768
if OVERRIDE_MAX_LEN > 0:
    infer_model.generation_config.max_length = OVERRIDE_MAX_LEN
    # min_length is max(32, max_length // 6), capped so it stays below max_length
    infer_model.generation_config.min_length = min(max(32, OVERRIDE_MAX_LEN // 6), OVERRIDE_MAX_LEN - 1)
    infer_model.generation_config.length_penalty = 1.0

print(
    'Generation config ->',
    'max_length=', infer_model.generation_config.max_length,
    '| min_length=', infer_model.generation_config.min_length,
    '| num_beams=', infer_model.generation_config.num_beams,
    '| length_penalty=', infer_model.generation_config.length_penalty,
)

# Two short example passages; replace with your own texts.
samples = [
    "The appellant challenges the lower court's ruling on evidentiary grounds, alleging improper admission of hearsay statements.",
    "In this contractual dispute, the plaintiff claims breach due to non-delivery within the stipulated time frame, seeking damages and specific performance.",
]

inputs = infer_tokenizer(samples, return_tensors='pt', padding=True, truncation=True, max_length=CFG['data']['max_input_length'])
# Global attention on the first token of every sequence, as in the training tokenization.
attn = inputs['attention_mask']
global_attention_mask = torch.zeros_like(attn)
global_attention_mask[:, 0] = 1
if torch.cuda.is_available():
    inputs = {k: v.cuda() for k, v in inputs.items()}
    global_attention_mask = global_attention_mask.cuda()

out_ids = infer_model.generate(
    **inputs,
    global_attention_mask=global_attention_mask,
    max_length=infer_model.generation_config.max_length,
    min_length=infer_model.generation_config.min_length,
    num_beams=infer_model.generation_config.num_beams,
    length_penalty=infer_model.generation_config.length_penalty,
)
# Print each source next to its decoded summary.
for src, pred in zip(samples, infer_tokenizer.batch_decode(out_ids, skip_special_tokens=True)):
    print('\nSOURCE:\n', src)
    print('\nSUMMARY:\n', pred)


In [None]:
# 18) Optional: Compute ROUGE via src/summarization/score_summaries.py (cross-validate)
import os, subprocess, sys

summ_script = os.path.join(SRC_DIR, 'summarization', 'score_summaries.py')
export_path = os.path.join(EXPORT_DIR, 'test_predictions.csv')

if os.path.isfile(summ_script) and os.path.isfile(export_path):
    print('Running SummEval-based scoring...')
    try:
        # summ_eval requires additional install; install quietly if missing.
        # Any import-time failure is treated as "not usable" (best effort).
        import summ_eval  # type: ignore
    except Exception:
        try:
            # The PyPI distribution is named 'summ-eval'. pip normalizes
            # 'SummEval' to 'summeval', which is a different project name,
            # so the hyphenated name must be used here.
            subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', 'summ-eval'])
        except subprocess.CalledProcessError as e:
            # This step is optional: report the failure instead of aborting
            # the cell; the scoring script below will surface any hard error.
            print('Could not install summ-eval:', e)
    # Run the scoring script by path with the current interpreter.
    cmd = [sys.executable, summ_script, '-summary_out', export_path]
    print('Command:', ' '.join(cmd))
    rc = subprocess.call(cmd)
    if rc != 0:
        print('SummEval scoring script returned non-zero exit code:', rc)
else:
    print('Skipping SummEval scoring (script or export not found).')

In [None]:
# 19) Optional: Argument signals from src/argument_classification (augment inputs)
import os, pandas as pd

# Location where the argument-classification step is expected to leave its
# predictions.
preds_path = os.path.join(
    SRC_DIR,
    'argument_classification',
    'artifacts',
    'legal_bert_predicts.txt',
)
if not os.path.isfile(preds_path):
    print('No argument classification artifacts found; skipping augmentation.')
else:
    # Placeholder only: nothing is merged yet. A real integration could join
    # these predictions onto the train/test rows by ID and, for example,
    # prepend a marker to argumentative sections of the input text.
    print('Found argument predictions at', preds_path)

In [None]:
# (Optional) Set seeds for reproducibility
import random, os
import numpy as np
import torch
from transformers import set_seed

# One fixed seed shared by every random number generator used in the notebook.
SEED = 42
# Python's `random`, NumPy and torch (CPU) are seeded in that order.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)
# Seed every visible CUDA device as well when a GPU is present.
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
# transformers' helper is applied last, exactly as before.
set_seed(SEED)
print('Seeds set to', SEED)

In [None]:
# Utility: Bootstrap functions to rehydrate state after a crash
import os, json, torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq
from datasets import load_dataset
from functools import partial
import pandas as pd


def load_cfg():
    """Read ``pipeline_config.json`` from ``ROOT_DIR`` and return the parsed JSON."""
    cfg_path = os.path.join(ROOT_DIR, 'pipeline_config.json')
    with open(cfg_path, 'r') as handle:
        cfg = json.load(handle)
    return cfg


def _resolve_cols_from_df(df, cfg):
    cols_lower = [str(c).strip().lower() for c in df.columns]
    # Prefer configured candidates
    in_c = None
    sum_c = None
    for c in cfg['data_schema'].get('input_col_candidates', []):
        if c.lower() in cols_lower:
            in_c = c.lower()
            break
    for c in cfg['data_schema'].get('summary_col_candidates', []):
        if c.lower() in cols_lower:
            sum_c = c.lower()
            break
    # Common fallbacks
    if in_c is None and 'text' in cols_lower:
        in_c = 'text'
    if sum_c is None and 'summary' in cols_lower:
        sum_c = 'summary'
    # Last resort fallback: pick the first textual-looking column
    if in_c is None and len(cols_lower) > 0:
        in_c = cols_lower[0]
    return in_c, sum_c


def ensure_cfg_schema(cfg):
    """Make sure ``cfg['data_schema']`` names the input and summary columns.

    When either ``input_col`` or ``summary_col`` is missing, the first 100
    rows of candidate CSVs in ``DATA_DIR`` are inspected (``train_clean.csv``,
    ``train_split.csv``, ``train.csv``, then every other ``*.csv`` in sorted
    order) and the columns are resolved with ``_resolve_cols_from_df``. As
    soon as an input column is found, the updated config is written back to
    ``pipeline_config.json`` under ``ROOT_DIR`` and returned.

    Args:
        cfg: Pipeline config dict; it is updated in place.

    Returns:
        The same ``cfg`` object. If nothing could be inferred it is returned
        without the missing keys so that downstream code raises a clearer
        error.
    """
    ds = cfg.get('data_schema')
    if not isinstance(ds, dict):
        # A config without a schema section used to fail with a swallowed
        # KeyError further down, so inference never happened. Create it.
        ds = {}
        cfg['data_schema'] = ds
    if 'input_col' in ds and 'summary_col' in ds:
        return cfg

    # Try to infer from cleaned train csv, then raw train, then any csv in DATA_DIR
    candidate_paths = [
        os.path.join(DATA_DIR, 'train_clean.csv'),
        os.path.join(DATA_DIR, 'train_split.csv'),
        os.path.join(DATA_DIR, 'train.csv'),
    ]
    # Add any other CSVs in DATA_DIR as last resort (sorted for determinism).
    try:
        extra_names = sorted(os.listdir(DATA_DIR))
    except OSError as exc:
        print(f"Cannot list {DATA_DIR} while inferring schema: {exc!r}")
        extra_names = []
    for fname in extra_names:
        if fname.lower().endswith('.csv'):
            p = os.path.join(DATA_DIR, fname)
            if p not in candidate_paths:
                candidate_paths.append(p)

    cfg_path = os.path.join(ROOT_DIR, 'pipeline_config.json')
    for p in candidate_paths:
        if not os.path.isfile(p):
            continue
        try:
            df = pd.read_csv(p, nrows=100)
            df.columns = [str(c).strip().lower() for c in df.columns]
            in_c, sum_c = _resolve_cols_from_df(df, cfg)
            if in_c is not None:
                ds['input_col'] = in_c
            if sum_c is not None:
                ds['summary_col'] = sum_c
            # Persist and return if input_col found
            if in_c is not None:
                with open(cfg_path, 'w') as f:
                    json.dump(cfg, f, indent=2)
                print(f"Schema ensured from {os.path.basename(p)} -> input_col='{in_c}', summary_col='{sum_c}'")
                return cfg
        except Exception as exc:
            # Still best effort (move on to the next candidate), but say why
            # this file was skipped instead of hiding the failure.
            print(f"Schema inference skipped {os.path.basename(p)}: {exc!r}")
    # If still not found, leave as-is (downstream will raise a clearer error)
    return cfg


def load_cfg_ensured():
    """Load the pipeline config from disk and run it through ``ensure_cfg_schema``."""
    return ensure_cfg_schema(load_cfg())


def load_tokenizer_for_inference(cfg):
    """Load a fast tokenizer, preferring locally saved artifacts.

    Sources are tried in order: ``BEST_DIR``, then ``cfg['io']['tokenizer_dir']``,
    then the base checkpoint named in ``cfg['model']['checkpoint']``. A
    directory that exists but fails to load is reported and skipped, so a
    fallback to a different tokenizer is visible in the cell output instead
    of happening silently.

    Args:
        cfg: Pipeline config dict providing ``io.tokenizer_dir`` and
            ``model.checkpoint``.

    Returns:
        The tokenizer returned by ``AutoTokenizer.from_pretrained``.

    Raises:
        Whatever ``AutoTokenizer.from_pretrained`` raises for the base
        checkpoint when both local sources were unavailable.
    """
    # Prefer best dir -> tokenizer dir -> checkpoint
    if os.path.isdir(BEST_DIR):
        try:
            return AutoTokenizer.from_pretrained(BEST_DIR, use_fast=True)
        except Exception as exc:
            print(f"Could not load tokenizer from {BEST_DIR}: {exc!r}; trying next source.")
    tokenizer_dir = cfg['io']['tokenizer_dir']
    if os.path.isdir(tokenizer_dir):
        try:
            return AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=True)
        except Exception as exc:
            print(f"Could not load tokenizer from {tokenizer_dir}: {exc!r}; falling back to base checkpoint.")
    return AutoTokenizer.from_pretrained(cfg['model']['checkpoint'], use_fast=True)


def _apply_generation_defaults(model):
    # Keep generation behavior consistent across reloads (use CFG if available)
    try:
        with open(os.path.join(ROOT_DIR, 'pipeline_config.json'), 'r') as _f:
            _cfg = json.load(_f)
        _gen_max = int(_cfg.get('data', {}).get('max_output_length', 256))
    except Exception:
        _gen_max = 256
    model.generation_config.num_beams = 4
    model.generation_config.max_length = _gen_max
    model.generation_config.min_length = min(max(32, _gen_max // 6), _gen_max - 1)
    model.generation_config.length_penalty = 1.0
    model.generation_config.early_stopping = True
    model.generation_config.no_repeat_ngram_size = 3


def load_model_for_inference(cfg, tokenizer=None):
    """Load the seq2seq model from ``BEST_DIR`` if it exists, else the base checkpoint.

    When a tokenizer is given, the embedding matrix is resized to its length
    (failures are ignored). Gradient checkpointing is enabled where the model
    supports it, ``config.use_cache`` is turned off, the notebook's generation
    defaults are applied, and the model is moved to the GPU when one is
    available.

    Args:
        cfg: Pipeline config dict; ``cfg['model']['checkpoint']`` is only read
            when ``BEST_DIR`` is not a directory.
        tokenizer: Optional tokenizer whose length sets the embedding size.

    Returns:
        The configured model.
    """
    if os.path.isdir(BEST_DIR):
        source = BEST_DIR
    else:
        source = cfg['model']['checkpoint']
    model = AutoModelForSeq2SeqLM.from_pretrained(source)

    if tokenizer is not None:
        # Best effort: keep the embedding size in step with the tokenizer.
        try:
            model.resize_token_embeddings(len(tokenizer))
        except Exception:
            pass

    # Best effort as well: not every architecture exposes this hook.
    try:
        model.gradient_checkpointing_enable()
    except Exception:
        pass

    model.config.use_cache = False
    _apply_generation_defaults(model)
    return model.cuda() if torch.cuda.is_available() else model


def prepare_tokenized_datasets(cfg, tokenizer):
    """Tokenize the train/val/test CSV splits and build the seq2seq collator.

    Reads ``train_split.csv``, ``val_split.csv`` and ``test_clean.csv`` from
    ``DATA_DIR``. Every row gets ``input_ids``, ``attention_mask`` and a
    ``global_attention_mask`` with only the first position set (LED
    convention); rows also get ``labels`` when the summary column is present,
    with padding positions replaced by -100.

    Args:
        cfg: Pipeline config providing ``data_schema.input_col``,
            ``data_schema.summary_col``, ``data.max_input_length`` and
            ``data.max_output_length``.
        tokenizer: Tokenizer used for both the inputs and the summaries.

    Returns:
        ``(train_tok, val_tok, test_tok, data_collator)`` with the datasets
        formatted to return torch tensors.
    """
    in_col = cfg['data_schema']['input_col']
    sum_col = cfg['data_schema']['summary_col']
    max_src_len = cfg['data']['max_input_length']
    max_tgt_len = cfg['data']['max_output_length']

    def _read_split(csv_name):
        # Each file is loaded as its own single-split ('train') dataset.
        csv_path = os.path.join(DATA_DIR, csv_name)
        return load_dataset('csv', data_files={'train': csv_path})['train']

    train_ds = _read_split('train_split.csv')
    val_ds = _read_split('val_split.csv')
    test_ds = _read_split('test_clean.csv')

    def _encode_batch(batch, tokenizer=None, in_col=None, sum_col=None, max_src_len=1024, max_tgt_len=256):
        encoded = tokenizer(
            batch[in_col],
            padding='max_length',
            truncation=True,
            max_length=max_src_len,
        )
        features = {
            'input_ids': encoded['input_ids'],
            'attention_mask': encoded['attention_mask'],
        }
        if sum_col in batch:
            target_ids = tokenizer(
                batch[sum_col],
                padding='max_length',
                truncation=True,
                max_length=max_tgt_len,
            )['input_ids']
            # Replace padding ids with -100 (the collator's label_pad_token_id).
            features['labels'] = [
                [(-100 if tok == tokenizer.pad_token_id else tok) for tok in seq]
                for seq in target_ids
            ]
        # LED: global attention on first token
        features['global_attention_mask'] = [
            ([1] + [0] * (len(mask) - 1)) if len(mask) > 0 else []
            for mask in encoded['attention_mask']
        ]
        return features

    map_fn = partial(
        _encode_batch,
        tokenizer=tokenizer,
        in_col=in_col,
        sum_col=sum_col,
        max_src_len=max_src_len,
        max_tgt_len=max_tgt_len,
    )

    kept_columns = [in_col, sum_col]

    def _tokenize_split(ds):
        # Drop every original column except the raw input/summary text.
        dropped = [name for name in ds.column_names if name not in kept_columns]
        return ds.map(map_fn, batched=True, remove_columns=dropped)

    train_tok = _tokenize_split(train_ds)
    val_tok = _tokenize_split(val_ds)
    test_tok = _tokenize_split(test_ds)

    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=None, label_pad_token_id=-100, padding=True, return_tensors='pt')

    supervised_cols = ['input_ids', 'attention_mask', 'global_attention_mask', 'labels']
    train_tok.set_format(type='torch', columns=supervised_cols)
    val_tok.set_format(type='torch', columns=list(supervised_cols))
    # The test split only exposes whichever model-input columns it actually has.
    cols_test = [
        name
        for name in ['input_ids', 'attention_mask', 'global_attention_mask']
        if name in test_tok.column_names
    ]
    test_tok.set_format(type='torch', columns=cols_test)

    return train_tok, val_tok, test_tok, data_collator


def ensure_state_for_training():
    """Recreate missing notebook globals needed for training.

    After loading the schema-checked config, fills in (only when absent or
    ``None``) the globals ``tokenizer`` and ``model``, and, if any of
    ``train_tokenized``, ``val_tokenized`` or ``data_collator`` is missing,
    rebuilds all tokenized splits plus the collator (this also sets
    ``test_tokenized``). Returns nothing.
    """
    ns = globals()
    cfg = load_cfg_ensured()

    if ns.get('tokenizer') is None:
        ns['tokenizer'] = load_tokenizer_for_inference(cfg)

    if ns.get('model') is None:
        # Register the model in globals first, then configure it in place.
        model = ns['model'] = AutoModelForSeq2SeqLM.from_pretrained(cfg['model']['checkpoint'])
        try:
            model.resize_token_embeddings(len(ns['tokenizer']))
        except Exception:
            pass
        try:
            model.gradient_checkpointing_enable()
        except Exception:
            pass
        model.config.use_cache = False
        _apply_generation_defaults(model)
        if torch.cuda.is_available():
            ns['model'] = model.cuda()

    required = ('train_tokenized', 'val_tokenized', 'data_collator')
    if any(ns.get(name) is None for name in required):
        (
            ns['train_tokenized'],
            ns['val_tokenized'],
            ns['test_tokenized'],
            ns['data_collator'],
        ) = prepare_tokenized_datasets(cfg, ns['tokenizer'])

print('Bootstrap utilities ready. Use ensure_state_for_training() in cells to rehydrate.')

# 20) Notes on warnings you may see
- The FutureWarning about `tokenizer` in Trainer is now resolved by using `processing_class`.
- A message about moving generation attributes to `generation_config` is expected; we now set them directly to avoid it.
- If you see a warning about missing keys like `embed_tokens.weight` or `lm_head.weight` when resuming or reloading, it's typically harmless when we've resized token embeddings after adding custom tokens. Those extra rows are newly initialized if the checkpoint was created before the tokenizer grew. To avoid it, keep the tokenizer vocabulary stable across runs (we persist it under `TOKENIZER_DIR`).