In [1]:
#----------------------------------
# T5 Legal Summarization Training Script
# Organized with proper folder structure and safety checks
#----------------------------------

#----------------------------------
# 1️⃣ Install Required Libraries
#----------------------------------
!pip install transformers sentencepiece datasets evaluate rouge_score
!pip install datasets torch matplotlib pandas


Collecting evaluate
  Downloading evaluate-0.4.6-py3-none-any.whl.metadata (9.5 kB)
Collecting rouge_score
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Downloading evaluate-0.4.6-py3-none-any.whl (84 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m1.5 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: rouge_score
  Building wheel for rouge_score (setup.py) ... [?25l[?25hdone
  Created wheel for rouge_score: filename=rouge_score-0.1.2-py3-none-any.whl size=24934 sha256=0b5a03e002def6a5b7c063070e8f5ed310f6898121ddb04b73be0fe1dfde318e
  Stored in directory: /root/.cache/pip/wheels/85/9d/af/01feefbe7d55ef5468796f0c68225b6788e85d9d0a281e7a70
Successfully built rouge_score
Installing collected packages: rouge_score, evaluate
Successfully installed evaluate-0.4.6 rouge_score-0.1.2


In [2]:
#----------------------------------
# 2️⃣ Import Libraries and Setup
#----------------------------------
import os
import torch
import json
import shutil
from datetime import datetime
from pathlib import Path
import pandas as pd
import matplotlib.pyplot as plt

from datasets import load_dataset
from transformers import (
    T5Tokenizer,
    T5ForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    EarlyStoppingCallback
)
import evaluate

# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
#----------------------------------
# 3️⃣ Setup Organized Folder Structure
#----------------------------------
class ProjectStructure:
    def __init__(self, base_path="/content/drive/MyDrive/Legal_Summarizer_Project"):
        self.base_path = Path(base_path)
        self.setup_directories()

    def setup_directories(self):
        """Create organized directory structure"""
        self.dirs = {
            'base': self.base_path,
            'models': self.base_path / 'models',
            'data': self.base_path / 'data',
            'outputs': self.base_path / 'outputs',
            'logs': self.base_path / 'logs',
            'checkpoints': self.base_path / 'checkpoints',
            'evaluation': self.base_path / 'evaluation',
            'configs': self.base_path / 'configs'
        }

        # Create all directories
        for dir_path in self.dirs.values():
            dir_path.mkdir(parents=True, exist_ok=True)

        print("📁 Directory structure created:")
        for name, path in self.dirs.items():
            print(f"  {name}: {path}")

    def get_path(self, dir_name):
        return self.dirs.get(dir_name, self.base_path)

# Initialize project structure
project = ProjectStructure()


📁 Directory structure created:
  base: /content/drive/MyDrive/Legal_Summarizer_Project
  models: /content/drive/MyDrive/Legal_Summarizer_Project/models
  data: /content/drive/MyDrive/Legal_Summarizer_Project/data
  outputs: /content/drive/MyDrive/Legal_Summarizer_Project/outputs
  logs: /content/drive/MyDrive/Legal_Summarizer_Project/logs
  checkpoints: /content/drive/MyDrive/Legal_Summarizer_Project/checkpoints
  evaluation: /content/drive/MyDrive/Legal_Summarizer_Project/evaluation
  configs: /content/drive/MyDrive/Legal_Summarizer_Project/configs


In [4]:
#----------------------------------
# 4️⃣ Configuration Management
#----------------------------------
config = {
    "model_name": "t5-base",
    "dataset_name": "d0r1h/ILC",
    "max_input_length": 512,
    "max_target_length": 150,
    "learning_rate": 3e-5,
    "batch_size": 4,
    "num_epochs": 10,
    "weight_decay": 0.01,
    "num_beams": 6,
    "repetition_penalty": 2.5,
    "length_penalty": 1.5,
    "min_length": 50,
    "max_length": 150,
    "eval_samples": 200
}

# Save configuration
config_path = project.get_path('configs') / 'training_config.json'
with open(config_path, 'w') as f:
    json.dump(config, f, indent=2)
print(f"💾 Configuration saved to: {config_path}")


💾 Configuration saved to: /content/drive/MyDrive/Legal_Summarizer_Project/configs/training_config.json


In [5]:
#----------------------------------
# 5️⃣ Safe Model Loading/Saving Functions
#----------------------------------
class ModelManager:
    def __init__(self, project_structure, config):
        self.project = project_structure
        self.config = config
        self.model_path = self.project.get_path('models') / 't5-legal-summarizer'
        self.checkpoint_path = self.project.get_path('checkpoints')

    def load_or_create_model(self):
        """Safely load existing model or create new one"""
        try:
            if self.model_path.exists() and any(self.model_path.iterdir()):
                print("🔄 Loading existing model from:", self.model_path)
                tokenizer = T5Tokenizer.from_pretrained(str(self.model_path))
                model = T5ForConditionalGeneration.from_pretrained(str(self.model_path))
                print("✅ Successfully loaded existing model")
            else:
                print("🆕 Creating new model from:", self.config["model_name"])
                tokenizer = T5Tokenizer.from_pretrained(self.config["model_name"])
                model = T5ForConditionalGeneration.from_pretrained(self.config["model_name"])
                print("✅ Successfully created new model")

        except Exception as e:
            print(f"❌ Error loading model: {e}")
            print("🔄 Falling back to creating new model...")
            tokenizer = T5Tokenizer.from_pretrained(self.config["model_name"])
            model = T5ForConditionalGeneration.from_pretrained(self.config["model_name"])

        return tokenizer, model

    def save_model(self, tokenizer, model, suffix=""):
        """Safely save model with backup"""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        save_path = self.model_path if not suffix else self.model_path.parent / f"t5-legal-summarizer_{suffix}"

        try:
            # Create backup if model already exists
            if save_path.exists():
                backup_path = self.checkpoint_path / f"backup_{timestamp}"
                shutil.copytree(save_path, backup_path)
                print(f"📦 Backup created at: {backup_path}")

            # Save model and tokenizer
            tokenizer.save_pretrained(str(save_path))
            model.save_pretrained(str(save_path))

            # Save additional checkpoint info
            checkpoint_info = {
                "timestamp": timestamp,
                "model_path": str(save_path),
                "config": self.config
            }

            checkpoint_file = self.checkpoint_path / f"checkpoint_{timestamp}.json"
            with open(checkpoint_file, 'w') as f:
                json.dump(checkpoint_info, f, indent=2)

            print(f"💾 Model saved successfully to: {save_path}")
            print(f"📋 Checkpoint info saved to: {checkpoint_file}")

        except Exception as e:
            print(f"❌ Error saving model: {e}")

# Initialize model manager
model_manager = ModelManager(project, config)


In [6]:
#----------------------------------
# 6️⃣ Load Dataset with Error Handling
#----------------------------------
def load_dataset_safely():
    try:
        print(f"📥 Loading dataset: {config['dataset_name']}")
        dataset = load_dataset(config["dataset_name"])

        # Save dataset info
        dataset_info = {
            "name": config["dataset_name"],
            "train_size": len(dataset["train"]) if "train" in dataset else 0,
            "test_size": len(dataset["test"]) if "test" in dataset else 0,
            "features": list(dataset["train"].features.keys()) if "train" in dataset else []
        }

        info_path = project.get_path('data') / 'dataset_info.json'
        with open(info_path, 'w') as f:
            json.dump(dataset_info, f, indent=2)

        print("✅ Dataset loaded successfully")
        print(f"📊 Train samples: {dataset_info['train_size']}")
        print(f"📊 Test samples: {dataset_info['test_size']}")

        return dataset

    except Exception as e:
        print(f"❌ Error loading dataset: {e}")
        return None

dataset = load_dataset_safely()
if dataset is None:
    raise Exception("Failed to load dataset. Please check your internet connection and dataset name.")


📥 Loading dataset: d0r1h/ILC


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/368 [00:00<?, ?B/s]

train.csv:   0%|          | 0.00/36.5M [00:00<?, ?B/s]

test.csv:   0%|          | 0.00/17.7M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/2058 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1015 [00:00<?, ? examples/s]

✅ Dataset loaded successfully
📊 Train samples: 2058
📊 Test samples: 1015


In [7]:
#----------------------------------
# 7️⃣ Load Model
#----------------------------------
tokenizer, model = model_manager.load_or_create_model()


🆕 Creating new model from: t5-base


spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


model.safetensors:   0%|          | 0.00/892M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

✅ Successfully created new model


In [8]:
#----------------------------------
# 8️⃣ Data Preprocessing
#----------------------------------
def preprocess_function(examples):
    """Preprocess data for T5 model"""
    # Add "summarize:" prefix (T5 convention)
    inputs = ["summarize: " + text for text in examples["Case"]]

    # Tokenize input cases
    model_inputs = tokenizer(
        inputs,
        max_length=config["max_input_length"],
        truncation=True,
        padding="max_length"
    )

    # Tokenize target summaries
    labels = tokenizer(
        examples["Summary"],
        max_length=config["max_target_length"],
        truncation=True,
        padding="max_length"
    )

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

print("🔄 Preprocessing dataset...")
tokenized_dataset = dataset.map(preprocess_function, batched=True)
print("✅ Dataset preprocessing completed")


🔄 Preprocessing dataset...


Map:   0%|          | 0/2058 [00:00<?, ? examples/s]

Map:   0%|          | 0/1015 [00:00<?, ? examples/s]

✅ Dataset preprocessing completed


In [9]:
#----------------------------------
# 9️⃣ Setup Training Arguments
#----------------------------------
training_args = Seq2SeqTrainingArguments(
    output_dir=str(project.get_path('logs')),
    eval_strategy="epoch",
    logging_strategy="steps",
    logging_steps=100,
    learning_rate=config["learning_rate"],
    per_device_train_batch_size=config["batch_size"],
    per_device_eval_batch_size=config["batch_size"],
    num_train_epochs=config["num_epochs"],
    weight_decay=config["weight_decay"],
    save_total_limit=3,
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    predict_with_generate=True,
    report_to=None  # Disable wandb
)


In [10]:
#----------------------------------
# 🔟 Initialize Trainer with Early Stopping
#----------------------------------
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
    tokenizer=tokenizer,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
)


  trainer = Seq2SeqTrainer(


In [11]:
#----------------------------------
# 1️⃣1️⃣ Train the Model
#----------------------------------
print("🚀 Starting model training...")
try:
    trainer.train()
    print("✅ Training completed successfully!")

    # Save training history
    training_history = trainer.state.log_history
    history_path = project.get_path('logs') / 'training_history.json'
    with open(history_path, 'w') as f:
        json.dump(training_history, f, indent=2)

except Exception as e:
    print(f"❌ Training failed: {e}")
    # Save model state anyway
    model_manager.save_model(tokenizer, model, "interrupted")


🚀 Starting model training...


  | |_| | '_ \/ _` / _` |  _/ -_)


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mhw2781022[0m ([33mhw2781022-shree-l-r-tiwari-college-of-engineering[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin




Epoch,Training Loss,Validation Loss


KeyboardInterrupt: 

In [None]:
----------------------------------
# 1️⃣2️⃣ Save Final Model
#----------------------------------
model_manager.save_model(tokenizer, model, "final")




In [None]:
#----------------------------------
# 1️⃣3️⃣ Summary Generation Function
#----------------------------------
def generate_summary(text, model, tokenizer):
    """Generate summary for given text"""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    inputs = tokenizer(
        "summarize: " + text,
        return_tensors="pt",
        truncation=True,
        max_length=config["max_input_length"]
    ).to(device)

    summary_ids = model.generate(
        inputs["input_ids"],
        max_length=config["max_length"],
        min_length=config["min_length"],
        length_penalty=config["length_penalty"],
        num_beams=config["num_beams"],
        repetition_penalty=config["repetition_penalty"],
        early_stopping=True
    )

    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)

In [None]:

#----------------------------------
# 1️⃣4️⃣ Evaluation and Export
#----------------------------------
print("📊 Starting evaluation...")
eval_set = dataset["test"].select(range(min(config["eval_samples"], len(dataset["test"]))))

# Generate summaries
batch_summaries = []
reference_summaries = []

for i, example in enumerate(eval_set):
    if i % 50 == 0:
        print(f"Processing {i+1}/{len(eval_set)} samples...")

    try:
        generated_summary = generate_summary(example["Case"], model, tokenizer)
        batch_summaries.append(generated_summary)
        reference_summaries.append(example["Summary"])
    except Exception as e:
        print(f"Error generating summary for sample {i}: {e}")
        batch_summaries.append("")
        reference_summaries.append(example["Summary"])

# Create results dataframe
results_df = pd.DataFrame({
    "Case": [x["Case"] for x in eval_set],
    "Reference_Summary": reference_summaries,
    "Generated_Summary": batch_summaries
})

# Save results
output_file = project.get_path('outputs') / f"legal_summaries_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
results_df.to_csv(output_file, index=False)
print(f"📁 Results exported to: {output_file}")


In [None]:
#----------------------------------
# 1️⃣5️⃣ ROUGE Evaluation
#----------------------------------
print("📈 Calculating ROUGE scores...")
try:
    rouge = evaluate.load("rouge")
    valid_summaries = [(pred, ref) for pred, ref in zip(batch_summaries, reference_summaries) if pred.strip()]

    if valid_summaries:
        predictions, references = zip(*valid_summaries)
        scores = rouge.compute(predictions=list(predictions), references=list(references))

        print("📊 ROUGE Scores:")
        for k, v in scores.items():
            print(f"  {k}: {v:.4f}")

        # Save scores
        scores_path = project.get_path('evaluation') / f"rouge_scores_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
        with open(scores_path, 'w') as f:
            json.dump(scores, f, indent=2)

        # Plot scores
        def plot_rouge_scores(rouge_scores):
            plt.figure(figsize=(10, 6))
            keys = list(rouge_scores.keys())
            values = [rouge_scores[k] for k in keys]

            bars = plt.bar(keys, values, color=['skyblue', 'lightcoral', 'lightgreen', 'gold'][:len(keys)])
            plt.title("ROUGE Scores for Legal Summarizer", fontsize=16, fontweight='bold')
            plt.ylabel("Score", fontsize=12)
            plt.ylim(0, 1)

            # Add value labels on bars
            for bar, value in zip(bars, values):
                plt.text(bar.get_x() + bar.get_width()/2, value + 0.02,
                        f"{value:.3f}", ha='center', fontweight='bold')

            plt.grid(axis='y', alpha=0.3)
            plt.tight_layout()

            # Save plot
            plot_path = project.get_path('evaluation') / f"rouge_plot_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png"
            plt.savefig(plot_path, dpi=300, bbox_inches='tight')
            plt.show()

        plot_rouge_scores(scores)

    else:
        print("⚠️ No valid summaries generated for ROUGE evaluation")

except Exception as e:
    print(f"❌ Error in ROUGE evaluation: {e}")


In [None]:


#----------------------------------
# 1️⃣6️⃣ Test Custom Case
#----------------------------------
print("\n🧾 Testing custom case summarization...")
custom_case_text = """
This judgment will dispose of the present application seeking interim injunction against the Defendants.
The plaintiff has approached this Court seeking relief in respect of trademark infringement and passing off.
The plaintiff claims exclusive rights over the trademark and seeks to restrain the defendants from using
any mark that is deceptively similar to their registered trademark. The matter involves questions of
likelihood of confusion among consumers and the strength of the plaintiff's trademark rights.
"""

try:
    custom_summary = generate_summary(custom_case_text, model, tokenizer)
    print("Generated Summary:")
    print(f"📝 {custom_summary}")

    # Save custom test
    custom_test_path = project.get_path('outputs') / 'custom_case_test.txt'
    with open(custom_test_path, 'w') as f:
        f.write(f"Input Case:\n{custom_case_text}\n\nGenerated Summary:\n{custom_summary}")

except Exception as e:
    print(f"❌ Error generating custom summary: {e}")


In [None]:

#----------------------------------
# 1️⃣7️⃣ Final Summary Report
#----------------------------------
print("\n" + "="*60)
print("🎉 TRAINING AND EVALUATION COMPLETED")
print("="*60)
print(f"📁 Project folder: {project.base_path}")
print(f"💾 Model saved at: {model_manager.model_path}")
print(f"📊 Results saved in: {project.get_path('outputs')}")
print(f"📈 Evaluation data in: {project.get_path('evaluation')}")
print(f"📋 Logs available in: {project.get_path('logs')}")
print("="*60)

# Create summary report
report = {
    "project_completed": datetime.now().isoformat(),
    "model_path": str(model_manager.model_path),
    "config": config,
    "dataset_info": {
        "name": config["dataset_name"],
        "samples_processed": len(eval_set)
    },
    "files_created": {
        "model": str(model_manager.model_path),
        "results": str(output_file),
        "config": str(config_path)
    }
}

report_path = project.get_path('base') / 'project_summary.json'
with open(report_path, 'w') as f:
    json.dump(report, f, indent=2)

print(f"📋 Complete project summary saved to: {report_path}")