In [1]:
# Install dependencies
!pip install transformers datasets peft accelerate bitsandbytes evaluate rouge_score -q
!pip install nltk -q

# Import libraries
import os
import re
import nltk
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from peft import PeftModel
from datasets import load_dataset
from evaluate import load
from torch.utils.data import DataLoader, Dataset
from google.colab import files
import time
import gc

# Download NLTK data
try:
    nltk.download('punkt', quiet=True)
    nltk.download('punkt_tab', quiet=True)
except Exception as e:
    print(f"Error downloading NLTK data: {e}")
    exit(1)

# Verify NLTK
try:
    nltk.tokenize.word_tokenize("Test sentence.")
except LookupError:
    print("NLTK punkt/punkt_tab not found.")
    exit(1)

print("Environment setup complete.")

  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.2/491.2 kB[0m [31m31.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m76.1/76.1 MB[0m [31m11.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.0/84.0 kB[0m [31m8.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m12.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m183.9/183.9 kB[0m [31m19.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m143.5/143.5 kB[0m [31m15.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m51.8 MB/s[0m eta [36

In [2]:
# Initialize tokenizer for token count filtering
tokenizer = AutoTokenizer.from_pretrained("t5-base", add_eos_token=True)
tokenizer.pad_token = tokenizer.eos_token

# Preprocessing functions (reused from your project)
def clean_text(text):
    """Clean text by removing boilerplate, citations, HTML, and normalizing whitespace."""
    if not isinstance(text, str):
        return ""
    text = re.sub(r'<[^>]+>', '', text)  # Remove HTML
    text = re.sub(r'\[H\.R\.\s*\d+\]|\[S\.\s*\d+\]|Section\s*\d+\([a-zA-Z]\)', '', text)  # Remove citations
    text = re.sub(r'(?i)(Be it enacted by the Senate and House|In the Senate of the United States|Congress finds that)', '', text)  # Remove boilerplate
    text = re.sub(r'[^\w\s.,!?]', '', text)  # Remove special characters
    text = re.sub(r'\s+', ' ', text).strip()
    return text

def clean_context(context):
    """Clean and join context paragraphs from dict or string."""
    if isinstance(context, dict) and 'contexts' in context:
        return " ".join(clean_text(c) for c in context['contexts'])
    return clean_text(context)

# Load and preprocess BillSum
bill_sum_ds = load_dataset("billsum", split="test")  # US bills
bill_sum_data = []
for item in bill_sum_ds:
    text = clean_text(item['text'])
    summary = clean_text(item['summary'])
    if not text or not summary:
        continue
    text_words = len(nltk.tokenize.word_tokenize(text))
    summary_words = len(nltk.tokenize.word_tokenize(summary))
    input_tokens = len(tokenizer(text, add_special_tokens=False)['input_ids'])
    if (100 <= text_words <= 5000 and 20 <= summary_words <= 500 and input_tokens <= 2000):
        input_text = f"[CLS] Summarize: {text} [SEP]"
        bill_sum_data.append({"input_text": input_text, "output_text": summary})

bill_sum_test_df = pd.DataFrame(bill_sum_data)
bill_sum_test_df = bill_sum_test_df.drop_duplicates(subset=['input_text']).dropna()
print("BillSum Test DataFrame size after preprocessing:", len(bill_sum_test_df))

# Load and preprocess PubMedQA
pubmedqa_ds = load_dataset("pubmed_qa", "pqa_labeled", split="train")  # ~1,000 samples
pubmedqa_data = []
for item in pubmedqa_ds:
    question = clean_text(item['question'])
    context = clean_context(item['context'])
    answer = item['final_decision']  # Yes/no/maybe
    if not question or not context:
        continue
    question_words = len(nltk.tokenize.word_tokenize(question))
    input_tokens = len(tokenizer(f"Question: {question} Context: {context}", add_special_tokens=False)['input_ids'])
    if question_words >= 5 and input_tokens <= 512:
        input_text = f"[CLS] Question: {question} Context: {context} [SEP]"
        pubmedqa_data.append({"input_text": input_text, "output_text": f"Answer: {answer}"})

# Split PubMedQA (80/10/10)
pubmedqa_df = pd.DataFrame(pubmedqa_data).sample(frac=1, random_state=42).reset_index(drop=True)
train_size = int(0.8 * len(pubmedqa_df))
val_size = int(0.1 * len(pubmedqa_df))
pubmedqa_test_df = pubmedqa_df[train_size + val_size:].drop_duplicates(subset=['input_text']).dropna()
print("PubMedQA Test DataFrame size after preprocessing:", len(pubmedqa_test_df))

print("Datasets preprocessed successfully.")

config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

README.md:   0%|          | 0.00/7.27k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/91.8M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/15.8M [00:00<?, ?B/s]

ca_test-00000-of-00001.parquet:   0%|          | 0.00/6.12M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/18949 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/3269 [00:00<?, ? examples/s]

Generating ca_test split:   0%|          | 0/1237 [00:00<?, ? examples/s]

BillSum Test DataFrame size after preprocessing: 2222


README.md:   0%|          | 0.00/5.19k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/1.08M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

PubMedQA Test DataFrame size after preprocessing: 97
Datasets preprocessed successfully.


In [3]:
def format_for_summarization(text):
    """Format input for summarization task."""
    cleaned_text = clean_text(text)
    if not cleaned_text:
        raise ValueError("Input text is empty after cleaning.")
    return f"[CLS] Summarize: {cleaned_text} [SEP]"

def format_for_qa(question, context=""):
    """Format input for QA task."""
    cleaned_question = clean_text(question)
    cleaned_context = clean_context(context)
    if not cleaned_question:
        raise ValueError("Question is empty after cleaning.")
    if not cleaned_context:
        cleaned_context = "No context provided"
    return f"[CLS] Question: {cleaned_question} Context: {cleaned_context} [SEP]"

def validate_input(text, task):
    """Validate input length and content."""
    word_count = len(nltk.tokenize.word_tokenize(text)) if text else 0
    if task == "summarization" and word_count < 50:
        raise ValueError("Summarization input must be at least 50 words.")
    if task == "qa" and word_count < 5:
        raise ValueError("QA question must be at least 5 words.")
    return True

print("Preprocessing functions for user inputs defined.")

Preprocessing functions for user inputs defined.


In [4]:
class AdapterManager:
    def __init__(self, model_name="t5-base", device="cuda" if torch.cuda.is_available() else "cpu"):
        """Initialize base model and adapter mappings."""
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, add_eos_token=True)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.base_model = AutoModelForSeq2SeqLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto"
        ).to(device)
        self.model = self.base_model  # Current model (base or with adapter)
        self.current_task = None
        self.adapter_paths = {
            "summarization": "/content/legal_adapters",
            "qa": "/content/qa_biomed_lora"
            # Add translation: "./lora_adapters/translation_trained" if implemented
        }

    def load_adapter(self, task):
        """Load LoRA adapter for the specified task."""
        if task not in self.adapter_paths:
            raise ValueError(f"Unknown task: {task}. Supported: {list(self.adapter_paths.keys())}")
        if self.current_task == task:
            print(f"Adapter for {task} already loaded.")
            return
        adapter_path = self.adapter_paths[task]
        if not os.path.exists(adapter_path):
            raise FileNotFoundError(f"Adapter not found at {adapter_path}")
        self.model = PeftModel.from_pretrained(self.base_model, adapter_path).to(self.device)
        self.current_task = task
        print(f"Loaded adapter for {task}. Active adapters: {self.model.active_adapters}")
        print(f"Trainable parameters: {self.model.print_trainable_parameters()}")
        gc.collect()
        torch.cuda.empty_cache()

    def get_model(self):
        """Return the current model."""
        return self.model

    def get_tokenizer(self):
        """Return the tokenizer."""
        return self.tokenizer

# Initialize AdapterManager
adapter_manager = AdapterManager()
print("AdapterManager initialized.")

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/892M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

AdapterManager initialized.


In [5]:
def detect_task(input_text, adapter_manager):
    """Detect task using zero-shot prompting and rule-based fallback."""
    # Rule-based detection
    if (
        re.search(r'Summarize:', input_text, re.IGNORECASE) or
        any(keyword in input_text.lower() for keyword in ["bill", "act", "amends"]) or
        len(nltk.tokenize.word_tokenize(input_text)) > 50
    ):
        return "summarization"
    if (
        re.search(r'Question:.*Context:', input_text, re.IGNORECASE) or
        any(keyword in input_text.lower() for keyword in ["patient", "disease", "treatment", "does", "is", "what"]) or
        input_text.strip().endswith("?")
    ):
        return "qa"

    # Zero-shot prompting
    prompt = (
        f"Classify the task for this input: {input_text[:200]}... "
        "Output only: Summarization or Question Answering. "
        "Summarization generates a concise summary of a long text, often legislative. "
        "Question Answering answers a question based on a context, often biomedical."
    )
    inputs = adapter_manager.get_tokenizer()(
        prompt,
        truncation=True,
        padding='max_length',
        max_length=512,
        return_tensors='pt'
    ).to(adapter_manager.device)
    with torch.no_grad():
        outputs = adapter_manager.base_model.generate(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            max_length=20,
            num_beams=2
        )
    raw_task = adapter_manager.get_tokenizer().decode(outputs[0], skip_special_tokens=True)
    print(f"Debug: Raw task output: '{raw_task}'")  # Debug log
    task = raw_task.lower().strip()
    if "summar" in task:
        return "summarization"
    if "question" in task or "qa" in task:
        return "qa"
    print(f"Warning: Task detection unclear ('{raw_task}'). Defaulting to summarization.")
    return "summarization"

# Test task detection
test_inputs = [
    "This bill amends the Tax Code to provide tax credits for small businesses.",
    "Does aspirin reduce heart attack risk?",
    "[CLS] Summarize: This bill amends the Tax Code... [SEP]",
    "[CLS] Question: Does aspirin reduce heart attack risk? Context: Studies show... [SEP]"
]
for input_text in test_inputs:
    task = detect_task(input_text, adapter_manager)
    print(f"Input: {input_text[:50]}... -> Detected task: {task}")

Input: This bill amends the Tax Code to provide tax credi... -> Detected task: summarization
Input: Does aspirin reduce heart attack risk?... -> Detected task: qa
Input: [CLS] Summarize: This bill amends the Tax Code... ... -> Detected task: summarization
Input: [CLS] Question: Does aspirin reduce heart attack r... -> Detected task: qa


In [6]:
import torch
import torch.cuda as cuda
import gc
import os

# Print GPU memory usage for debugging
def print_memory_usage():
    if cuda.is_available():
        allocated = cuda.memory_allocated() / 1024**3
        reserved = cuda.memory_reserved() / 1024**3
        print(f"GPU Memory - Allocated: {allocated:.2f} GiB, Reserved: {reserved:.2f} GiB")
    else:
        print("No GPU available.")

def run_inference(raw_input, question=None, context="", adapter_manager=adapter_manager):
    """Run inference on user input, detecting task and generating output."""
    start_time = time.time()

    # Preprocess input
    try:
        task = detect_task(raw_input, adapter_manager)
        if task == "summarization":
            formatted_input = format_for_summarization(raw_input)
            max_length = 512
            validate_input(raw_input, task)
        elif task == "qa":
            formatted_input = format_for_qa(question if question else raw_input, context)
            max_length = 512
            validate_input(question if question else raw_input, task)
        else:
            raise ValueError(f"Unsupported task: {task}")
    except ValueError as e:
        print(f"Preprocessing error: {str(e)}")
        return f"Error: {str(e)}"

    # Log preprocessing
    print(f"Raw input: {raw_input[:50]}...")
    if context:
        print(f"Context: {context[:50]}...")
    print(f"Task detected: {task}")
    print(f"Formatted input: {formatted_input[:50]}...")

    # Load adapter
    try:
        adapter_manager.load_adapter(task)
        # Verify trainable parameters and adapter config
        model = adapter_manager.get_model()
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"Trainable parameters: {trainable_params}")
        if trainable_params == 0:
            print("Warning: No trainable parameters in adapter!")
            # Check adapter files
            adapter_path = "/content/legal_adapters" if task == "summarization" else "/content/qa_biomed_lora"
            if os.path.exists(adapter_path):
                print(f"Adapter directory contents: {os.listdir(adapter_path)}")
            else:
                print(f"Adapter directory {adapter_path} does not exist!")
    except Exception as e:
        print(f"Adapter loading error: {str(e)}")
        return f"Error: {str(e)}"

    # Tokenize
    try:
        inputs = adapter_manager.get_tokenizer()(
            formatted_input,
            truncation=True,
            padding='max_length',
            max_length=max_length,
            return_tensors='pt'
        ).to(adapter_manager.device)
    except Exception as e:
        print(f"Tokenization error: {str(e)}")
        return f"Error: {str(e)}"

    # Generate output
    try:
        print("Starting model.generate...")
        print_memory_usage()
        with torch.no_grad():
            outputs = adapter_manager.get_model().generate(
                input_ids=inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                max_length=100 if task == "summarization" else 32,
                num_beams=2,
                length_penalty=1.0
            )
        print("Completed model.generate")
    except Exception as e:
        print(f"Error in model.generate: {str(e)}")
        return f"Error: {str(e)}"

    # Decode output
    result = adapter_manager.get_tokenizer().decode(outputs[0], skip_special_tokens=True)

    # Clear memory
    gc.collect()
    cuda.empty_cache()
    print_memory_usage()

    # Log runtime
    runtime = time.time() - start_time
    print(f"Output: {result}")
    print(f"Runtime: {runtime:.2f} seconds")

    return result

# Test inference with longer inputs (>50 words)
print("\nTesting summarization with explicit keywords:")
test_input = (
    "Summarize: This bill amends the Tax Code to provide tax credits for small businesses "
    "employing fewer than 50 employees. The credits aim to encourage job creation and economic "
    "growth in local communities. Eligible businesses must demonstrate compliance with federal "
    "regulations and submit annual reports to qualify for the credits. The bill also includes "
    "provisions for auditing to ensure proper use of funds."
)
result = run_inference(test_input, adapter_manager=adapter_manager)
print(f"Result: {result}")

print("\nTesting summarization with raw input:")
test_raw_input = (
    "This bill amends the Tax Code to provide tax credits for small businesses employing fewer "
    "than 50 employees. The credits aim to encourage job creation and economic growth in local "
    "communities. Eligible businesses must demonstrate compliance with federal regulations and "
    "submit annual reports to qualify for the credits. The bill also includes provisions for "
    "auditing to ensure proper use of funds."
)
result_raw = run_inference(test_raw_input, adapter_manager=adapter_manager)
print(f"Result: {result_raw}")



print("\nTesting QA with explicit keywords:")
test_qa = "Question: Does aspirin reduce heart attack risk?"
result_qa = run_inference(test_qa, question=test_qa, context="Studies show aspirin reduces clotting.", adapter_manager=adapter_manager)
print(f"Result: {result_qa}")

print("\nTesting QA with raw input:")
test_raw_qa = "Does aspirin reduce heart attack risk?"
result_raw_qa = run_inference(test_raw_qa, question=test_raw_qa, context="Studies show aspirin reduces clotting.", adapter_manager=adapter_manager)
print(f"Result: {result_raw_qa}")



Testing summarization with explicit keywords:
Raw input: Summarize: This bill amends the Tax Code to provid...
Task detected: summarization
Formatted input: [CLS] Summarize: Summarize This bill amends the Ta...
Loaded adapter for summarization. Active adapters: ['default']
trainable params: 0 || all params: 224,673,024 || trainable%: 0.0000
Trainable parameters: None
Trainable parameters: 0
Adapter directory contents: ['adapter_config.json', 'adapter_model.safetensors', 'README.md']
Starting model.generate...
GPU Memory - Allocated: 0.53 GiB, Reserved: 0.67 GiB
Completed model.generate
GPU Memory - Allocated: 0.54 GiB, Reserved: 0.67 GiB
Output: Amends the Tax Code to provide tax credits for small businesses employing fewer than 50 employees. [EOS]
Runtime: 7.08 seconds
Result: Amends the Tax Code to provide tax credits for small businesses employing fewer than 50 employees. [EOS]

Testing summarization with raw input:
Raw input: This bill amends the Tax Code to provide tax credi...
T

In [7]:
!ls /content/legal_adapters

adapter_config.json  adapter_model.safetensors	README.md


In [13]:
import torch.cuda as cuda
import nltk
import pandas as pd
from evaluate import load
from google.colab import files
from sklearn.metrics import accuracy_score, f1_score


# Print GPU memory usage for debugging
def print_memory_usage():
    if cuda.is_available():
        allocated = cuda.memory_allocated() / 1024**3
        reserved = cuda.memory_reserved() / 1024**3
        print(f"GPU Memory - Allocated: {allocated:.2f} GiB, Reserved: {reserved:.2f} GiB")
    else:
        print("No GPU available.")

# Filter BillSum samples to shorter inputs (word count <500)
def get_word_count(text):
    clean_text = text.replace("[CLS] Summarize: ", "").replace(" [SEP]", "")
    return len(nltk.tokenize.word_tokenize(clean_text))

bill_sum_short_df = bill_sum_test_df[bill_sum_test_df['input_text'].apply(get_word_count) < 500]
print(f"Filtered BillSum samples (word count <500): {len(bill_sum_short_df)}")

# Create test suite (summarization focus, QA commented out)
#Build your test suite
test_suite = (
    # 1) The first block: BillSum short-input summarization examples
    [
        {
            "input": bill_sum_short_df['input_text'].iloc[i],
            "expected_task": "summarization",
            "expected_output": bill_sum_short_df['output_text'].iloc[i]
        }
        for i in range(min(5, len(bill_sum_short_df)))
    ]
    +
    # 2) Edge-case examples
    [
        {
            "input": "Raw bill text with typos!!! Amends tax code...",
            "expected_task": "summarization",
            "expected_output": "Summary of tax code amendment."
        },
        {
            "input": "",
            "expected_task": None,
            "expected_output": "Error: Input text is empty after cleaning."
        },
        # QA edge case example (still in the list, just change task if needed)
        {
            "input": "Is this a question without context?",
            "expected_task": "qa",
            "expected_output": "Answer: Maybe"
        },  # ← note this trailing comma
    ]
    # Uncomment this block once pubmedqa_test_df is available:
    +
    [
        {
            "input": pubmedqa_test_df['input_text'].iloc[i],
            "expected_task": "qa",
            "expected_output": pubmedqa_test_df['output_text'].iloc[i]
        }
        for i in range(min(5, len(pubmedqa_test_df)))
    ]
)
# Run tests
rouge = load("rouge")
test_results = []
for idx, test_case in enumerate(test_suite):
    input_text = test_case['input']
    expected_task = test_case['expected_task']
    expected_output = test_case['expected_output']

    print(f"\nProcessing test case {idx + 1}/{len(test_suite)}")
    print_memory_usage()

    # Run inference
        # Run inference
    try:
        result = run_inference(input_text, adapter_manager=adapter_manager)
        print(f"Inference result: {result}")
    except Exception as e:
        result = f"Error: {str(e)}"
        print(f"Error in inference: {str(e)}")

    # Evaluate
    detected_task = detect_task(input_text, adapter_manager)
    task_correct = (detected_task == expected_task) if expected_task else True
    metrics = {}

    # Summarization → ROUGE
    if expected_task == "summarization" and "Error" not in result:
        try:
            rouge_scores = rouge.compute(
                predictions=[result],
                references=[expected_output],  # unchanged
                use_stemmer=True
            )
            # Unpack mid‐fmeasure if present
            rf = rouge_scores["rouge1"]
            metrics = {
                "rouge1_f1": rf.mid.fmeasure if hasattr(rf, "mid") else rf
            }
        except Exception as e:
            print(f"ROUGE evaluation error: {str(e)}")
            metrics = {"rouge1_f1": 0.0}

    # QA → Accuracy & F1
    elif expected_task == "qa" and "Error" not in result:
        # strip prefix / lowercase
        pred_clean = result.replace("Answer:", "").strip().lower()
        gold_clean = expected_output.replace("Answer:", "").strip().lower()

        acc = accuracy_score([gold_clean], [pred_clean])
        f1  = f1_score([gold_clean], [pred_clean], average="weighted")
        metrics = {"accuracy": acc, "f1": f1}

    # Append results
    test_results.append({
        "input":            input_text[:50],
        "detected_task":    detected_task,
        "task_correct":     task_correct,
        "output":           result,
        "expected_output":  expected_output,
        "metrics":          metrics
    })


    # Clear memory
    gc.collect()
    cuda.empty_cache()
    print_memory_usage()

# Save and display results
test_results_df = pd.DataFrame(test_results)
test_results_df.to_csv("mole_test_results.csv", index=False)
print("\nTest Results:")
print(test_results_df)

# Download results
files.download("mole_test_results.csv")

Filtered BillSum samples (word count <500): 10

Processing test case 1/13
GPU Memory - Allocated: 0.53 GiB, Reserved: 0.67 GiB
Raw input: [CLS] Summarize: SECTION 1. TEMPORARY DUTY SUSPENS...
Task detected: summarization
Formatted input: [CLS] Summarize: CLS Summarize SECTION 1. TEMPORAR...
Loaded adapter for summarization. Active adapters: ['default']
trainable params: 0 || all params: 224,673,024 || trainable%: 0.0000
Trainable parameters: None
Trainable parameters: 0
Adapter directory contents: ['adapter_config.json', 'adapter_model.safetensors', 'README.md']
Starting model.generate...
GPU Memory - Allocated: 0.54 GiB, Reserved: 0.67 GiB
Completed model.generate
GPU Memory - Allocated: 0.54 GiB, Reserved: 0.67 GiB
Output: Amends the Harmonized Tariff Schedule of the United States (HTS) to amend the Harmonized Tariff Schedule of the United States (HTS) by inserting in numerical sequence the following new headings: (1) SNtertbutyl 1,2,3,4 tetrahydro3 isoquinoline carboxamide hydrochlo

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [16]:
from IPython.display import display
import ipywidgets as widgets
import torch.cuda as cuda

# Print GPU memory usage for debugging
def print_memory_usage():
    if cuda.is_available():
        allocated = cuda.memory_allocated() / 1024**3
        reserved = cuda.memory_reserved() / 1024**3
        print(f"GPU Memory - Allocated: {allocated:.2f} GiB, Reserved: {reserved:.2f} GiB")
    else:
        print("No GPU available.")

# Create input widget
input_box = widgets.Textarea(
    value='',
    placeholder='Enter text (e.g., legislative bill, biomedical question)',
    description='Input:',
    layout={'width': '600px', 'height': '200px'}
)
context_box = widgets.Textarea(
    value='',
    placeholder='Optional context for QA',
    description='Context:',
    layout={'width': '600px', 'height': '100px'}
)
button = widgets.Button(description="Run MoLE")
output = widgets.Output()

def on_button_click(b):
    with output:
        output.clear_output()
        raw_input = input_box.value
        context = context_box.value
        print(f"Raw Input: {raw_input[:50]}...")
        if context:
            print(f"Context: {context[:50]}...")
        print_memory_usage()
        try:
            result = run_inference(
                raw_input,
                question=raw_input if not context else None,
                context=context,
                adapter_manager=adapter_manager
            )
            print(f"Final Output: {result}")
        except Exception as e:
            print(f"Error during inference: {str(e)}")
        print_memory_usage()
        gc.collect()
        cuda.empty_cache()

button.on_click(on_button_click)
display(input_box, context_box, button, output)

# Demo with sample inputs
print("\nSample Demonstrations:")
try:
    # Summarization (short input to reduce memory)
    bill_sum_sample = "Summarize: SECTION 1. EMPLOYEE WAGE PROTECTION.—Notwithstanding any other provision of law, every employer engaged in interstate commerce shall pay to each employee no less than the prevailing wage as determined by the Department of Labor. No employer shall withhold any portion of any wage or salary unless expressly authorized by law or by a written agreement signed by the employee. All wage disputes shall be subject to adjudication in the Federal District Court for the district in which the employers principal place of business is located."
    print("\nSummarization Example:")
    print_memory_usage()
    result = run_inference(bill_sum_sample, adapter_manager=adapter_manager)
    print(f"Output: {result}")
    print_memory_usage()
    gc.collect()
    cuda.empty_cache()

    # QA
    pubmedqa_sample = "Question: Does aspirin reduce heart attack risk? Context: Studies show aspirin reduces clotting."
    print("\nQA Example:")
    print_memory_usage()
    result = run_inference(pubmedqa_sample, adapter_manager=adapter_manager)
    print(f"Output: {result}")
    print_memory_usage()
    gc.collect()
    cuda.empty_cache()

    # Edge case: Noisy input
    noisy_input = """
<html><body>Be it enacted by the Senate and House of Representatives of the United States:
Section 2(a) [H.R. 5678] “Funding Authorization” — The Secretary shall allocate
$1,000,000—$2,000,000 per fiscal year to § communities.*1 <i>See appendix</i> for definitions.
All terms not defined herein shall have the meanings given in 25 U.S.C. § 3001(f).
</body></html>
"""
    print("\nEdge Case (Noisy Input):")
    print_memory_usage()
    result = run_inference(noisy_input, adapter_manager=adapter_manager)
    print(f"Output: {result}")
    print_memory_usage()
    gc.collect()
    cuda.empty_cache()

except Exception as e:
    print(f"Error in sample demonstrations: {str(e)}")

Textarea(value='', description='Input:', layout=Layout(height='200px', width='600px'), placeholder='Enter text…

Textarea(value='', description='Context:', layout=Layout(height='100px', width='600px'), placeholder='Optional…

Button(description='Run MoLE', style=ButtonStyle())

Output()


Sample Demonstrations:

Summarization Example:
GPU Memory - Allocated: 0.53 GiB, Reserved: 0.67 GiB
Raw input: Summarize: SECTION 1. EMPLOYEE WAGE PROTECTION.—No...
Task detected: summarization
Formatted input: [CLS] Summarize: Summarize SECTION 1. EMPLOYEE WAG...
Loaded adapter for summarization. Active adapters: ['default']
trainable params: 0 || all params: 224,673,024 || trainable%: 0.0000
Trainable parameters: None
Trainable parameters: 0
Adapter directory contents: ['adapter_config.json', 'adapter_model.safetensors', 'README.md']
Starting model.generate...
GPU Memory - Allocated: 0.54 GiB, Reserved: 0.67 GiB
Completed model.generate
GPU Memory - Allocated: 0.54 GiB, Reserved: 0.67 GiB
Output: Employers engaged in interstate commerce shall pay to each employee no less than the prevailing wage as determined by the Department of Labor. Employers engaged in interstate commerce shall pay to each employee no less than the prevailing wage as determined by the Department of Labor. [EOS]