In [1]:
# %pip install transformers datasets torch

In [2]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Trainer, TrainingArguments
from datasets import Dataset, DatasetDict
import torch
import json

In [3]:
# ✅ STEP 1: Setup Model and Tokenizer (CodeT5+)
model_name = "Salesforce/codet5p-220m"

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Check GPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print("✅ Using device:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
model.to(device)

✅ Using device: NVIDIA GeForce RTX 2080 Ti


T5ForConditionalGeneration(
  (shared): Embedding(32100, 768)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32100, 768)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=768, out_features=768, bias=False)
              (k): Linear(in_features=768, out_features=768, bias=False)
              (v): Linear(in_features=768, out_features=768, bias=False)
              (o): Linear(in_features=768, out_features=768, bias=False)
              (relative_attention_bias): Embedding(32, 12)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseActDense(
              (wi): Linear(in_features=768, out_features=3072, bias=False)
              (wo): Linear(in_features=3072, out_features=768, bias=False)
              (dropout): Dro

In [4]:
# from google.colab import files
# uploaded = files.upload()

In [5]:
# ✅ STEP 2: Load and preprocess dataset
def load_json_dataset(path):
    with open(path) as f:
        data = json.load(f)
    return [{"input": item["input"], "output": item["output"]} for item in data]

# train_data = load_json_dataset("/home/sysadm/Music/unitime_nlp/data/processed/train.json")
# val_data = load_json_dataset("/home/sysadm/Music/unitime_nlp/data/processed/val.json")
# test_data = load_json_dataset("/home/sysadm/Music/unitime_nlp/data/processed/test.json")

train_data = load_json_dataset("/home/sysadm/Music/unitime_nlp/data/processed/train_data.json")
val_data = load_json_dataset("/home/sysadm/Music/unitime_nlp/data/processed/val_data.json")
test_data = load_json_dataset("/home/sysadm/Music/unitime_nlp/data/processed/test_data.json")

# Organize into HuggingFace dataset
full_dataset = DatasetDict({
    "train": Dataset.from_list(train_data),
    "validation": Dataset.from_list(val_data),
    "test": Dataset.from_list(test_data)
})

In [6]:
# ✅ STEP 3: Tokenize the data
def tokenize(batch):
    model_inputs = tokenizer(batch["input"], max_length=512, truncation=True, padding="max_length")
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(batch["output"], max_length=512, truncation=True, padding="max_length")
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

train_tokenized = full_dataset["train"].map(tokenize, batched=True)
val_tokenized = full_dataset["validation"].map(tokenize, batched=True)
test_tokenized = full_dataset["test"].map(tokenize, batched=True)


Map:   0%|          | 0/2457 [00:00<?, ? examples/s]



Map:   0%|          | 0/455 [00:00<?, ? examples/s]

Map:   0%|          | 0/456 [00:00<?, ? examples/s]

In [7]:
# ✅ STEP 4: Set training arguments
training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    eval_strategy="steps",
    save_strategy="epoch",
    logging_dir="./logs",
    num_train_epochs=5,
    fp16=True if torch.cuda.is_available() else False,
    report_to="none"
)


# ✅ STEP 5: Train the model
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_tokenized,
    eval_dataset=val_tokenized,
    tokenizer=tokenizer
)

print("🚀 Starting training...")
trainer.train()
print("✅ Training complete")


  trainer = Trainer(


🚀 Starting training...


Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Step,Training Loss,Validation Loss
500,1.5685,0.060667
1000,0.0461,0.036025
1500,0.0326,0.031495
2000,0.0291,0.030057
2500,0.0275,0.029299
3000,0.0258,0.029086


✅ Training complete


In [None]:
from huggingface_hub import login

# 🔐 Optional: Log in with your HF token (only needed once per environment)
login(token="")  # Or set env var: HUGGINGFACE_TOKEN

# ✅ Save model locally
trainer.save_model("codet5-unitime")

# ✅ Push model to Hugging Face Hub
model.push_to_hub("unitime-codet5")
tokenizer.push_to_hub("unitime-codet5")


model.safetensors:   0%|          | 0.00/892M [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

CommitInfo(commit_url='https://huggingface.co/Rai02x/unitime-codet5/commit/2ca57087b4fefbca103ff0d7175ffe3a978620f9', commit_message='Upload tokenizer', commit_description='', oid='2ca57087b4fefbca103ff0d7175ffe3a978620f9', pr_url=None, repo_url=RepoUrl('https://huggingface.co/Rai02x/unitime-codet5', endpoint='https://huggingface.co', repo_type='model', repo_id='Rai02x/unitime-codet5'), pr_revision=None, pr_num=None)

In [9]:
import re

# def extract_id_map(xml: str) -> dict:
#     """
#     Extracts key entity IDs from ground truth XML.
#     Returns a dictionary like {'course_id': '904', ...}
#     """
#     id_map = {}
#     id_map["course_id"] = re.search(r'<course\s+id="(\d+)"', xml).group(1)
#     id_map["offering_id"] = re.search(r'<offering\s+id="(\d+)"', xml).group(1)
#     id_map["class_id"] = re.search(r'<class\s+id="(\d+)"', xml).group(1)
#     id_map["instructor_id"] = re.search(r'<instructor\s+id="(\d+)"', xml).group(1)
#     return id_map

# def extract_id_map(xml_text):
#     """
#     Extracts tag and id attribute value pairs from the XML.
#     Example: <course id="904" ...> → {'course': '904'}
#     """
#     return {
#         f"{tag} id": id_val
#         for tag, id_val in re.findall(r'<(\w+)[^>]*?\bid="(\d+)"', xml_text)
#     }



In [10]:
# def replace_ids(xml: str, id_map: dict) -> str:
#     """
#     Replace ID fields in predicted XML with the true values from id_map.
#     """
#     xml = re.sub(r'(course\s+id=")\d+(")', rf'\1{id_map["course_id"]}\2', xml)
#     xml = re.sub(r'(offering\s+id=")\d+(")', rf'\1{id_map["offering_id"]}\2', xml)
#     xml = re.sub(r'(class\s+id=")\d+(")', rf'\1{id_map["class_id"]}\2', xml)
#     xml = re.sub(r'(instructor\s+id=")\d+(")', rf'\1{id_map["instructor_id"]}\2', xml)
#     return xml

# def replace_ids(xml_text, id_map):
#     """
#     Replaces ids in the prediction based on the tag context using id_map.
#     Example: Replace course id="XYZ" with id_map["course id"]
#     """
#     def replacer(match):
#         tag = match.group(1)
#         attr = match.group(2)
#         value = match.group(3)

#         key = f"{tag} {attr}"
#         if key in id_map:
#             return f'{attr}="{id_map[key]}"'
#         else:
#             return match.group(0)  # leave unchanged

#     # Match: <tag ... id="value" ...>
#     return re.sub(r'<(\w+)[^>]*?\b(id)="(\d+)"', replacer, xml_text)



In [11]:
import re

def extract_id_map(xml_text):
    """
    Extract all ID-like attributes and return a mapping:
    {'course id': '904', 'offering id': '6728', ...}
    """
    return dict(re.findall(r'(\w+)\s*=\s*"(\d+)"', xml_text))

def replace_ids(xml_text, id_map):
    """
    Replace all id="..." fields in xml_text with values from id_map
    based on their attribute name.
    """
    def replacer(match):
        attr = match.group(1)
        if attr in id_map:
            return f'{attr}="{id_map[attr]}"'
        else:
            return match.group(0)  # leave unchanged if not in map

    return re.sub(r'(\w+)\s*=\s*"\d+"', replacer, xml_text)


In [12]:

# # ✅ STEP 6: Predict and fix XML
# def fix_xml(text):
#     text = text.strip()
#     if not text.startswith("<"):
#         text = "<" + text
#     if text.count("<") > text.count(">"):
#         text += ">"
#     return text

# print("🔍 Running prediction on test set...")
# raw_test = full_dataset["test"]  # Needed for original input/output

# for i in range(300,305):
#     example = test_tokenized[i]
#     input_text = raw_test[i]["input"]
#     ground_truth = raw_test[i]["output"]

#     # Prepare inputs for model
#     inputs = {k: torch.tensor(v).unsqueeze(0).to(device) for k, v in example.items() if k in tokenizer.model_input_names}

#     # Generate prediction
#     outputs = model.generate(**inputs, max_length=512)
#     prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)

#     # Fix malformed XML
#     fixed_prediction = fix_xml(prediction)

#     # Correct ID values based on ground truth
#     id_map = extract_id_map(ground_truth)
#     final_prediction = replace_ids(fixed_prediction, id_map)

#     # Output result
#     print("📥 Input:", input_text)
#     print("✅ Raw Prediction:", prediction)
#     print("🛠 Fixed XML:", fixed_prediction)
#     print("🔁 Final with Correct IDs:", final_prediction)
#     print("🎯 Ground Truth:", ground_truth)
#     print("-" * 50)



In [13]:
# ✅ STEP 6: Predict and fix XML with Training Metrics
import xml.etree.ElementTree as ET
from difflib import SequenceMatcher

def fix_xml(text):
    text = text.strip()
    if not text.startswith("<"):
        text = "<" + text
    if text.count("<") > text.count(">"):
        text += ">"
    return text

def calculate_exact_match(pred, truth):
    """Check if prediction exactly matches ground truth"""
    return pred.strip() == truth.strip()

def calculate_bleu_score(pred, truth):
    """Simple BLEU-like score using character overlap"""
    matcher = SequenceMatcher(None, pred, truth)
    return matcher.ratio()

def extract_xml_elements(xml_string):
    """Extract key XML elements for semantic comparison"""
    try:
        root = ET.fromstring(xml_string)
        elements = {}
        
        # Extract preferences attributes
        elements['term'] = root.get('term', '')
        elements['year'] = root.get('year', '')
        elements['campus'] = root.get('campus', '')
        
        # Extract subpart details
        subpart = root.find('subpart')
        if subpart is not None:
            elements['subject'] = subpart.get('subject', '')
            elements['course'] = subpart.get('course', '')
            elements['type'] = subpart.get('type', '')
        
        # Extract time preferences
        time_prefs = []
        for pref in root.findall('.//pref'):
            time_prefs.append({
                'days': pref.get('days', ''),
                'start': pref.get('start', ''),
                'stop': pref.get('stop', ''),
                'level': pref.get('level', '')
            })
        elements['time_prefs'] = time_prefs
        
        return elements
    except:
        return {}

def calculate_semantic_accuracy(pred, truth):
    """Check if key XML elements match semantically"""
    pred_elements = extract_xml_elements(pred)
    truth_elements = extract_xml_elements(truth)
    
    if not pred_elements or not truth_elements:
        return 0.0
    
    matches = 0
    total = 0
    
    # Check basic attributes
    for key in ['term', 'year', 'campus', 'subject', 'course', 'type']:
        if key in pred_elements and key in truth_elements:
            total += 1
            if pred_elements[key] == truth_elements[key]:
                matches += 1
    
    # Check time preferences
    pred_prefs = pred_elements.get('time_prefs', [])
    truth_prefs = truth_elements.get('time_prefs', [])
    
    if len(pred_prefs) == len(truth_prefs):
        for p_pref, t_pref in zip(pred_prefs, truth_prefs):
            for attr in ['days', 'start', 'stop', 'level']:
                total += 1
                if p_pref.get(attr) == t_pref.get(attr):
                    matches += 1
    
    return matches / total if total > 0 else 0.0

def is_valid_xml(xml_string):
    """Check if XML is well-formed"""
    try:
        ET.fromstring(xml_string)
        return True
    except:
        return False

# Tracking metrics
metrics = {
    'exact_matches': 0,
    'semantic_matches': 0,
    'valid_xml_count': 0,
    'total_predictions': 0,
    'bleu_scores': [],
    'semantic_accuracies': []
}

print("🔍 Running prediction on test set with metrics...")
raw_test = full_dataset["test"]

for i in range(300, 305):
    example = test_tokenized[i]
    input_text = raw_test[i]["input"]
    ground_truth = raw_test[i]["output"]

    # Prepare inputs for model
    inputs = {k: torch.tensor(v).unsqueeze(0).to(device) for k, v in example.items() if k in tokenizer.model_input_names}

    # Generate prediction
    outputs = model.generate(**inputs, max_length=512)
    prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Fix malformed XML
    fixed_prediction = fix_xml(prediction)

    # Correct ID values based on ground truth
    id_map = extract_id_map(ground_truth)
    final_prediction = replace_ids(fixed_prediction, id_map)

    # Calculate metrics
    exact_match = calculate_exact_match(final_prediction, ground_truth)
    bleu_score = calculate_bleu_score(final_prediction, ground_truth)
    semantic_accuracy = calculate_semantic_accuracy(final_prediction, ground_truth)
    is_valid = is_valid_xml(final_prediction)

    # Update tracking
    metrics['total_predictions'] += 1
    if exact_match:
        metrics['exact_matches'] += 1
    if semantic_accuracy > 0.9:  # Consider >90% semantic match as success
        metrics['semantic_matches'] += 1
    if is_valid:
        metrics['valid_xml_count'] += 1
    
    metrics['bleu_scores'].append(bleu_score)
    metrics['semantic_accuracies'].append(semantic_accuracy)

    # Output result with metrics
    print(f"📥 Input: {input_text}")
    print(f"✅ Raw Prediction: {prediction}")
    print(f"🛠 Fixed XML: {fixed_prediction}")
    print(f"🔁 Final with Correct IDs: {final_prediction}")
    print(f"🎯 Ground Truth: {ground_truth}")
    print(f"📊 METRICS:")
    print(f"   ✓ Exact Match: {'✅' if exact_match else '❌'}")
    print(f"   ✓ BLEU Score: {bleu_score:.3f}")
    print(f"   ✓ Semantic Accuracy: {semantic_accuracy:.3f}")
    print(f"   ✓ Valid XML: {'✅' if is_valid else '❌'}")
    print("-" * 50)

# Final metrics summary
print("📈 TRAINING QUALITY METRICS:")
print(f"🎯 Exact Match Rate: {metrics['exact_matches']}/{metrics['total_predictions']} ({metrics['exact_matches']/metrics['total_predictions']*100:.1f}%)")
print(f"🧠 Semantic Match Rate: {metrics['semantic_matches']}/{metrics['total_predictions']} ({metrics['semantic_matches']/metrics['total_predictions']*100:.1f}%)")
print(f"📝 Valid XML Rate: {metrics['valid_xml_count']}/{metrics['total_predictions']} ({metrics['valid_xml_count']/metrics['total_predictions']*100:.1f}%)")
print(f"📊 Average BLEU Score: {sum(metrics['bleu_scores'])/len(metrics['bleu_scores']):.3f}")
print(f"🔍 Average Semantic Accuracy: {sum(metrics['semantic_accuracies'])/len(metrics['semantic_accuracies']):.3f}")

# Training quality assessment
overall_quality = (metrics['exact_matches'] + metrics['semantic_matches']) / (2 * metrics['total_predictions'])
if overall_quality > 0.9:
    print("🟢 EXCELLENT TRAINING: Model is performing very well!")
elif overall_quality > 0.7:
    print("🟡 GOOD TRAINING: Model is performing well with room for improvement")
else:
    print("🔴 NEEDS IMPROVEMENT: Consider more training or data quality checks")

🔍 Running prediction on test set with metrics...
📥 Input: Set preferences for ENG 301 section C: essential for MWF schedule and likes Conference Room rooms.
✅ Raw Prediction: <preferences term="Fall" year="2010" campus="woebegon">
  <class subject="ENG" course="301" suffix="C" type="Lab">
    <timePref level="R">
      <pref days="MWF" start="1300" stop="1700" level="R"/>
    </timePref>
    <groupPref group="Conference Room" level="1"/>
  </class>
</preferences>
🛠 Fixed XML: <preferences term="Fall" year="2010" campus="woebegon">
  <class subject="ENG" course="301" suffix="C" type="Lab">
    <timePref level="R">
      <pref days="MWF" start="1300" stop="1700" level="R"/>
    </timePref>
    <groupPref group="Conference Room" level="1"/>
  </class>
</preferences>
🔁 Final with Correct IDs: <preferences term="Fall" year="2010" campus="woebegon">
  <class subject="ENG" course="301" suffix="C" type="Lab">
    <timePref level="R">
      <pref days="MWF" start="1300" stop="1500" level="R"/>
