In [1]:
import pandas as pd
from datasets import Dataset, DatasetDict
from transformers import AutoTokenizer
from sklearn.model_selection import train_test_split
# =========================
# 0) Configuration
# =========================
CSV_PATH = "/root/autodl-tmp/CommitFit/dataset/Ghadhab/dataset.csv"
MODEL_NAME = "/root/autodl-tmp/models/codet5-base"

# =========================
# 1) Read the CSV and encode the class labels as integers
# =========================
# Mapping from the textual commit category to its integer id.
label2id = dict(Adaptive=0, Corrective=1, Perfective=2)

# Load the raw table and swap the values of the "labels" column for their ids.
df = pd.read_csv(CSV_PATH).replace({"labels": label2id})
df

Unnamed: 0,user,repo,commit,labels,msgs,diffs,feature
0,ponsonio,RxJava,0531b8bff5c14d9504beefb4ad47f473e3a22932,2,Change hasException to hasThrowable--,diff --git a/rxjava-core/src/main/java/rx/Noti...,"[1, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
1,ponsonio,RxJava,0950c46beda335819928585f1262dfe1dca78a0b,0,Trying to extend the Scheduler interface accor...,diff --git a/rxjava-core/src/main/java/rx/Sche...,"[2, 44, 0, 0, 30, 0, 0, 1, 18, 0, 0, 0, 0, 0, ..."
2,ponsonio,RxJava,0f92fdd8e6422d5b79c610a7fd8409d222315a49,0,RunAsync method for outputting multiple values--,diff --git a/rxjava-contrib/rxjava-async-util/...,"[2, 53, 0, 0, 42, 0, 0, 1, 45, 1, 0, 0, 0, 0, ..."
3,ponsonio,RxJava,100f571c9a2835d5a30a55374b9be74c147e031f,1,forEach with Action1 but not Observer--I re-re...,diff --git a/language-adaptors/rxjava-groovy/s...,"[1, 5, 122, 9, 10, 9, 4, 1, 5, 18, 2, 0, 0, 0,..."
4,ponsonio,RxJava,191f023cf5253ea90647bc091dcaf55ccdce81cc,1,1.x: Fix Completable swallows- OnErrorNotImple...,diff --git a/src/main/java/rx/Completable.java...,"[1, 1, 0, 0, 0, 0, 0, 1, 21, 0, 0, 0, 0, 0, 0,..."
...,...,...,...,...,...,...,...
1776,jenkinsci,clearcase-plugin,51e9da224f80254476a7dc446bca817b505381d8,2,Use a temporary file to decrease memory consum...,diff --git a/src/main/java/hudson/plugins/clea...,"[2, 12, 0, 4, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0,..."
1777,jexp,batch-import,609d6c4b1eea2c33d9fb950fcbb9ba9dc1f80fc3,2,added a more memory efficient structure for st...,diff --git a/src/main/java/org/neo4j/batchimpo...,"[10, 159, 29, 35, 9, 2, 1, 5, 106, 0, 4, 8, 0,..."
1778,hdiv,hdiv,19b650c78a1c76f4fd90274d7f163f863c0d39e4,2,Memory and performance optimizations,diff --git a/hdiv-config/src/main/java/org/hdi...,"[31, 302, 131, 140, 170, 89, 53, 7, 88, 14, 17..."
1779,casidiablo,persistence,d7bf95159df37a3d338ca267dddd3d26b38ec37c,2,Now it is possible to specify the sqlite open ...,diff --git a/pom.xml b/pom.xml\nindex 394263b....,"[5, 57, 20, 9, 21, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."


In [2]:
import re

# Unified-diff hunk header, e.g. "@@ -92,5 +92,5 @@ ...". The two captured
# groups are the optional old/new line counts; an omitted count means 1.
_HUNK_HEADER_RE = re.compile(r"^@@ -\d+(?:,(\d+))? \+\d+(?:,(\d+))? @@")

# Metadata lines that only repeat file names / modes and are dropped when they
# appear outside of a hunk body.
_DIFF_NOISE_PREFIXES = (
    "index ", "new file mode", "deleted file mode",
    "similarity index", "rename from", "rename to",
    "--- ", "+++ ",
)


def compress_diff_minimal(
    diff_text: str,
    max_changed_lines: int = 120,   # maximum number of +/- lines to keep
    max_chars: int = 3500,          # final cut-off applied to the joined text (characters)
) -> str:
    """Shrink a git diff down to file headers, hunk headers and changed lines.

    Kept: ``diff --git`` lines, ``@@`` hunk headers, and added/removed lines.
    Dropped: context lines and metadata such as ``index``, ``---``/``+++``
    file headers, mode/rename lines and the ``GIT binary patch`` marker.

    The line counts in each ``@@ -a,b +c,d @@`` header are used to know when
    the parser is inside a hunk body. Inside a hunk every line starting with
    ``+`` or ``-`` is a change, even if its content makes it look like a file
    header (a removed ``-- comment`` line shows up as ``--- comment``). When a
    hunk header cannot be parsed, the prefix-based heuristics are used instead.

    Args:
        diff_text: Raw diff text. Non-string values (e.g. the ``NaN`` pandas
            produces for an empty CSV cell, or ``None``) and empty strings
            yield ``""``.
        max_changed_lines: Maximum number of changed lines to keep. When the
            diff contains more, a truncation marker line is appended and
            parsing stops.
        max_chars: If the joined result is longer than this, it is cut to
            ``max_chars`` characters and a truncation marker line is appended
            (so the returned string can be slightly longer than ``max_chars``).

    Returns:
        The compacted diff as a single newline-joined string.
    """
    # Guard against NaN/None coming from DataFrame.map over missing cells.
    if not isinstance(diff_text, str) or not diff_text:
        return ""

    kept = []
    changed_cnt = 0
    # Remaining old-side / new-side lines of the hunk currently being read.
    old_left = new_left = 0

    for ln in diff_text.splitlines():
        # File header: always kept, and it terminates any open hunk.
        if ln.startswith("diff --git "):
            kept.append(ln)
            old_left = new_left = 0
            continue

        # Hunk header: kept; its counts tell us how long the hunk body is.
        if ln.startswith("@@"):
            kept.append(ln)
            header = _HUNK_HEADER_RE.match(ln)
            if header:
                old_left, new_left = (
                    1 if count is None else int(count) for count in header.groups()
                )
            else:
                # Unparseable header (e.g. combined diff): use heuristics below.
                old_left = new_left = 0
            continue

        if old_left > 0 or new_left > 0:
            # Inside a hunk body the first character alone decides the line kind.
            if ln.startswith("+"):
                new_left = max(new_left - 1, 0)
            elif ln.startswith("-"):
                old_left = max(old_left - 1, 0)
            else:
                # "\ No newline at end of file" does not count toward either
                # side; anything else is a context line counting toward both.
                if not ln.startswith("\\"):
                    old_left = max(old_left - 1, 0)
                    new_left = max(new_left - 1, 0)
                continue
        else:
            # Outside a hunk: skip metadata noise and binary patch markers.
            if ln.startswith(_DIFF_NOISE_PREFIXES) or "GIT binary patch" in ln:
                continue
            # Keep only +/- lines that are not file-header look-alikes.
            if not ln.startswith(("+", "-")) or ln.startswith(("+++", "---")):
                continue

        # Only flag truncation when a changed line is actually being dropped.
        if changed_cnt >= max_changed_lines:
            kept.append("... (diff truncated: too many changed lines)")
            break
        kept.append(ln)
        changed_cnt += 1

    out = "\n".join(kept).strip()
    if len(out) > max_chars:
        out = out[:max_chars] + "\n... (diff truncated: max_chars)"
    return out


In [3]:
def build_prompt(diff_compact: str) -> str:
    """Wrap a compacted diff in the commit-message generation prompt.

    The prompt has three sections separated by blank lines: the instruction,
    the diff (introduced by a ``Diff:`` line), and a trailing
    ``Commit message:`` cue.
    """
    instruction = (
        "Write a concise Git commit message (imperative mood). "
        "Focus on what changed."
    )
    sections = (instruction, f"Diff:\n{diff_compact}", "Commit message:")
    return "\n\n".join(sections)


In [4]:
# Derive the compacted diff first, then the label-free generation prompt from it.
df = df.assign(
    diff_compact=lambda frame: frame["diffs"].map(compress_diff_minimal),
    gen_prompt=lambda frame: frame["diff_compact"].map(build_prompt),
)

In [14]:
# Rename the commit message and the prompt to the target_text / source_text
# column names that are read back when predicting below.
seq2seq_column_names = {"msgs": "target_text", "gen_prompt": "source_text"}
df = df.rename(columns=seq2seq_column_names)
df

Unnamed: 0,user,repo,commit,labels,target_text,diffs,feature,diff_compact,source_text
0,ponsonio,RxJava,0531b8bff5c14d9504beefb4ad47f473e3a22932,2,Change hasException to hasThrowable--,diff --git a/rxjava-core/src/main/java/rx/Noti...,"[1, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",diff --git a/rxjava-core/src/main/java/rx/Noti...,Write a concise Git commit message (imperative...
1,ponsonio,RxJava,0950c46beda335819928585f1262dfe1dca78a0b,0,Trying to extend the Scheduler interface accor...,diff --git a/rxjava-core/src/main/java/rx/Sche...,"[2, 44, 0, 0, 30, 0, 0, 1, 18, 0, 0, 0, 0, 0, ...",diff --git a/rxjava-core/src/main/java/rx/Sche...,Write a concise Git commit message (imperative...
2,ponsonio,RxJava,0f92fdd8e6422d5b79c610a7fd8409d222315a49,0,RunAsync method for outputting multiple values--,diff --git a/rxjava-contrib/rxjava-async-util/...,"[2, 53, 0, 0, 42, 0, 0, 1, 45, 1, 0, 0, 0, 0, ...",diff --git a/rxjava-contrib/rxjava-async-util/...,Write a concise Git commit message (imperative...
3,ponsonio,RxJava,100f571c9a2835d5a30a55374b9be74c147e031f,1,forEach with Action1 but not Observer--I re-re...,diff --git a/language-adaptors/rxjava-groovy/s...,"[1, 5, 122, 9, 10, 9, 4, 1, 5, 18, 2, 0, 0, 0,...",diff --git a/language-adaptors/rxjava-groovy/s...,Write a concise Git commit message (imperative...
4,ponsonio,RxJava,191f023cf5253ea90647bc091dcaf55ccdce81cc,1,1.x: Fix Completable swallows- OnErrorNotImple...,diff --git a/src/main/java/rx/Completable.java...,"[1, 1, 0, 0, 0, 0, 0, 1, 21, 0, 0, 0, 0, 0, 0,...",diff --git a/src/main/java/rx/Completable.java...,Write a concise Git commit message (imperative...
...,...,...,...,...,...,...,...,...,...
1776,jenkinsci,clearcase-plugin,51e9da224f80254476a7dc446bca817b505381d8,2,Use a temporary file to decrease memory consum...,diff --git a/src/main/java/hudson/plugins/clea...,"[2, 12, 0, 4, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0,...",diff --git a/src/main/java/hudson/plugins/clea...,Write a concise Git commit message (imperative...
1777,jexp,batch-import,609d6c4b1eea2c33d9fb950fcbb9ba9dc1f80fc3,2,added a more memory efficient structure for st...,diff --git a/src/main/java/org/neo4j/batchimpo...,"[10, 159, 29, 35, 9, 2, 1, 5, 106, 0, 4, 8, 0,...",diff --git a/src/main/java/org/neo4j/batchimpo...,Write a concise Git commit message (imperative...
1778,hdiv,hdiv,19b650c78a1c76f4fd90274d7f163f863c0d39e4,2,Memory and performance optimizations,diff --git a/hdiv-config/src/main/java/org/hdi...,"[31, 302, 131, 140, 170, 89, 53, 7, 88, 14, 17...",diff --git a/hdiv-config/src/main/java/org/hdi...,Write a concise Git commit message (imperative...
1779,casidiablo,persistence,d7bf95159df37a3d338ca267dddd3d26b38ec37c,2,Now it is possible to specify the sqlite open ...,diff --git a/pom.xml b/pom.xml\nindex 394263b....,"[5, 57, 20, 9, 21, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","diff --git a/pom.xml b/pom.xml\n@@ -23,5 +23,5...",Write a concise Git commit message (imperative...


In [6]:
# df['source_text'][0]

In [7]:
# df['diff_compact'][0]

In [8]:
# Hold out 30% of the rows, then cut that hold-out in half to obtain the
# validation and test sets; the fixed random_state keeps both splits repeatable.
train, temp_df = train_test_split(df, random_state=42, test_size=0.3)
val, test = train_test_split(temp_df, random_state=42, test_size=0.5)

In [9]:
from simplet5_trl import SimpleT5_TRL
import pandas as pd

# # Prepare data
# train_df = pd.DataFrame({
#     "source_text": ["summarize: This is a long article..."],
#     "target_text": ["Short summary"]
# })

# Fine-tune: load the local T5 base weights, then train on the train split
# while evaluating on the validation split.
t5_base_dir = "/root/autodl-tmp/models/google-t5/t5-base"
model = SimpleT5_TRL()
model.from_pretrained(t5_base_dir)
model.train(max_epochs=4, train_df=train, eval_df=val)

Loading weights:   0%|          | 0/257 [00:00<?, ?it/s]

Map:   0%|          | 0/1246 [00:00<?, ? examples/s]

Map:   0%|          | 0/267 [00:00<?, ? examples/s]

2026-01-05 10:31:29.660685: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-01-05 10:31:29.738536: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-01-05 10:31:31.432386: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
  if not hasattr(np, "object"):


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,2.371777,2.773134,16.237,8.7199,15.7203,15.7203
2,3.531259,2.64396,21.0024,11.6768,20.1587,20.1587
3,1.734208,2.595615,23.6681,13.2092,22.8988,22.8988
4,3.850724,2.585294,23.2085,12.9804,22.5262,22.5262


In [15]:
# Show the full prompt built for the row with index label 0.
df.loc[0, "source_text"]

'Write a concise Git commit message (imperative mood). Focus on what changed.\n\nDiff:\ndiff --git a/rxjava-core/src/main/java/rx/Notification.java b/rxjava-core/src/main/java/rx/Notification.java\n@@ -92,5 +92,5 @@ public class Notification<T> {\n-    public boolean hasException() {\n+    public boolean hasThrowable() {\n@@ -126,5 +126,5 @@ public class Notification<T> {\n-        if (hasException())\n+        if (hasThrowable())\n@@ -137,5 +137,5 @@ public class Notification<T> {\n-        if (hasException())\n+        if (hasThrowable())\n@@ -155,5 +155,5 @@ public class Notification<T> {\n-        if (hasException() && !getThrowable().equals(notification.getThrowable()))\n+        if (hasThrowable() && !getThrowable().equals(notification.getThrowable()))\n\nCommit message:'

In [17]:
# Load the saved checkpoint and print the prediction for the first prompt.
model.load_model("outputs/checkpoint-624", use_gpu=True)
first_prompt = df["source_text"][0]
first_prediction = model.predict(first_prompt)
print(first_prediction)

Loading weights:   0%|          | 0/257 [00:00<?, ?it/s]

['Added hasException() and hasThrowable() functionality']


In [19]:
# df["target_text"][0]

'Change hasException to hasThrowable--'