In [None]:
import io
import json
import keyword
import math
import os
import random
import re
import shutil
import tokenize
import warnings
from dataclasses import dataclass
from typing import List, Dict, Optional

import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

In [None]:
# ============ Required: HuggingFace tokenizers (used to train BPE from scratch) ============
try:
    from tokenizers import Tokenizer
    from tokenizers.models import BPE
    from tokenizers.trainers import BpeTrainer
    from tokenizers.pre_tokenizers import Whitespace
    from tokenizers.normalizers import NFKC
except ImportError as e:
    # Chain the original error so the real cause (package missing vs. a broken
    # install) stays visible in the traceback.
    raise RuntimeError(
        "The 'tokenizers' package is required. Install with: pip install tokenizers"
    ) from e

# Silence only the known-noisy nn.TransformerEncoder notice that is emitted for
# norm_first=True encoders; a blanket UserWarning filter would also hide
# genuinely useful warnings from torch / pandas.
warnings.filterwarnings(
    "ignore",
    message=r".*enable_nested_tensor.*",
    category=UserWarning,
)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# =============================================================================
# 0) Small utilities
# =============================================================================

# Special tokens reserved in the tokenizer vocabulary. The list is handed to the
# BPE trainer; the individual names are used for id lookups and role tagging.
PAD, CLS, SEP, MASK, UNK = "[PAD]", "[CLS]", "[SEP]", "[MASK]", "[UNK]"
SPECIAL_TOKENS = [PAD, CLS, SEP, MASK, UNK]

def simple_format_code_py(code: str) -> str:
    """Very light formatting: strip comments and tidy whitespace.

    Comments are removed with the stdlib tokenizer. ``untokenize`` is fed only
    (type, string) pairs, so spacing between tokens is regenerated and may
    differ from the input. If tokenization fails (e.g. the snippet is not valid
    Python), comment stripping is skipped and only the whitespace clean-up runs.
    For real pipelines a dedicated formatter (black/ruff, ...) is recommended.

    Whitespace clean-up: tabs -> 4 spaces, trailing spaces removed per line,
    runs of 3+ newlines collapsed to a single blank line.
    """
    try:
        kept = [
            (tok.type, tok.string)
            for tok in tokenize.generate_tokens(io.StringIO(code).readline)
            if tok.type != tokenize.COMMENT
        ]
        code = tokenize.untokenize(kept)
    except Exception:
        # Best effort by design: keep the original text if it cannot be tokenized.
        pass
    expanded = code.replace("\t", "    ")
    stripped_lines = [line.rstrip() for line in expanded.splitlines()]
    return re.sub(r"\n{3,}", "\n\n", "\n".join(stripped_lines))

# ---- Role tagging schema (extensible) ----
# Coarse lexical role per token. The ids index the role embedding table and the
# role-classification head, so ROLE_VOCAB_SIZE must cover every id.
ROLE2ID = {
    role_name: role_id
    for role_id, role_name in enumerate(
        (
            "PAD",         # padding positions
            "UNK",
            "KEYWORD",
            "IDENTIFIER",
            "OP",
            "LITERAL",
            "SEP",         # intended for special tokens (CLS/SEP/MASK/PAD, ...)
        )
    )
}
ROLE_VOCAB_SIZE = len(ROLE2ID)

# Python keywords are case-sensitive ("None" is a keyword, "none"/"NONE" are
# plain identifiers), so match exactly against the running interpreter's
# keyword list (which also covers async/await).
_PY_KEYWORDS = frozenset(keyword.kwlist)
# Integer or simple decimal number.
_NUMBER_RE = re.compile(r"\d+(\.\d+)?")
# Token made only of non-word, non-space characters: operators, brackets and
# any other punctuation ('#', '\\', '?', '$', mixed runs like '("', ...).
_PUNCT_RE = re.compile(r"[^\w\s]+")


def heuristic_role_tagging(tokens: List[str]) -> List[int]:
    """Assign a coarse lexical role id (a value of ROLE2ID) to each token string.

    Rules, checked in order:
      1. special tokens ([PAD]/[CLS]/[SEP]/[MASK]/[UNK]) -> "SEP"
      2. exact-case Python keywords (incl. True/False/None) -> "KEYWORD"
      3. integer / simple decimal numbers                  -> "LITERAL"
      4. anything starting with a quote character          -> "LITERAL"
      5. punctuation/operator-only tokens                  -> "OP"
      6. everything else                                   -> "IDENTIFIER"

    Args:
        tokens: token strings (e.g. from ``Tokenizer.id_to_token``).

    Returns:
        List of role ids, same length as ``tokens``.
    """
    special = (PAD, CLS, SEP, MASK, UNK)
    role_ids = []
    for tk in tokens:
        if tk in special:
            role = "SEP"
        elif tk in _PY_KEYWORDS:
            role = "KEYWORD"
        elif _NUMBER_RE.fullmatch(tk):
            role = "LITERAL"
        elif tk[:1] in ("'", '"'):
            # String delimiter or string fragment, including a lone quote.
            role = "LITERAL"
        elif _PUNCT_RE.fullmatch(tk):
            role = "OP"
        else:
            role = "IDENTIFIER"
        role_ids.append(ROLE2ID[role])
    return role_ids

In [None]:
# =============================================================================
# 1) Train BPE Tokenizer from scratch
# =============================================================================

def iter_corpus_texts(parquet_path: str, text_col_hint: str = "text"):
    """Yield formatted code strings from one text column of a parquet file.

    Args:
        parquet_path: path to the parquet corpus.
        text_col_hint: preferred column name. If it is absent, the first column
            whose name contains "code" or "text" (case-insensitive) is used.

    Yields:
        Every non-null cell of the chosen column, cast to ``str`` and passed
        through ``simple_format_code_py``.

    Raises:
        AssertionError: if no suitable text/code column exists.
    """
    df = pd.read_parquet(parquet_path)
    if text_col_hint not in df.columns:
        cands = [c for c in df.columns if "code" in str(c).lower() or "text" in str(c).lower()]
        assert cands, f"No text/code column in: {list(df.columns)}"
        text_col_hint = cands[0]
    # Drop nulls first: astype(str) would otherwise turn None/NaN into the
    # literal strings "None"/"nan" and feed them to the tokenizer as code.
    texts = df[text_col_hint].dropna().astype(str)
    # The generator frame lives for the whole consumer loop; release the
    # other columns now instead of keeping the full frame alive.
    del df
    for x in texts:
        yield simple_format_code_py(x)

def train_bpe_tokenizer_from_corpus(
    parquet_path: str,
    vocab_size: int = 32000,
    min_frequency: int = 2,
    save_dir: str = "./tokenizer",
    keep_corpus_file: bool = False,
    text_col: str = "text",
) -> str:
    """Train a BPE tokenizer (NFKC + Whitespace pre-tokenizer) on the code corpus.

    Samples from ``iter_corpus_texts`` are written to a temporary text file,
    one sample per line with inner newlines flattened to spaces. That file is
    then passed to ``Tokenizer.train``.

    Args:
        parquet_path: parquet corpus to read.
        vocab_size: target vocabulary size (special tokens included).
        min_frequency: minimum pair frequency for a BPE merge.
        save_dir: directory for ``tokenizer.json`` (created if needed).
        keep_corpus_file: keep the intermediate ``corpus.tmp.txt`` instead of
            deleting it after training.
        text_col: preferred text column, forwarded to ``iter_corpus_texts``.

    Returns:
        Path to the saved ``tokenizer.json``.
    """
    os.makedirs(save_dir, exist_ok=True)
    tok = Tokenizer(BPE(unk_token=UNK))
    tok.normalizer = NFKC()
    tok.pre_tokenizer = Whitespace()

    trainer = BpeTrainer(
        vocab_size=vocab_size,
        min_frequency=min_frequency,
        special_tokens=SPECIAL_TOKENS,
        show_progress=True,
    )

    # Training here is file-based, so the in-memory corpus is first dumped to a
    # temporary text file.
    tmp_txt = os.path.join(save_dir, "corpus.tmp.txt")
    try:
        with open(tmp_txt, "w", encoding="utf-8") as f:
            for i, text in enumerate(iter_corpus_texts(parquet_path, text_col_hint=text_col), start=1):
                f.write(text.replace("\n", " ") + "\n")
                if i % 10000 == 0:
                    print(f"[Tokenizer] fed {i} lines")
        tok.train(files=[tmp_txt], trainer=trainer)
    finally:
        # The flattened corpus can be as large as the dataset; remove it (also
        # when training fails or is interrupted) unless explicitly requested.
        if not keep_corpus_file and os.path.exists(tmp_txt):
            os.remove(tmp_txt)

    tok_path = os.path.join(save_dir, "tokenizer.json")
    tok.save(tok_path)
    print(f"[Tokenizer] saved to {tok_path}")
    return tok_path

In [None]:
# =============================================================================
# 2) Dataset & Collator
# =============================================================================

class CodeDataset(Dataset):
    """Map-style dataset of code snippets encoded as ``[CLS] tokens [SEP]``.

    Each item is a dict of equal-length 1-D ``torch.long`` tensors (length
    <= ``max_len``):
      * ``input_ids``: token ids;
      * ``token_type_ids``: all zeros (single segment);
      * ``role_ids``: roles from ``heuristic_role_tagging``.
    """

    def __init__(
        self,
        parquet_path: str,
        tokenizer_json: str,
        text_col: str = "text",
        max_len: int = 256,
        do_format: bool = True,
    ):
        """
        Args:
            parquet_path: parquet file holding the corpus.
            tokenizer_json: ``tokenizers`` JSON file that contains the special tokens.
            text_col: preferred text column. Falls back to the first column whose
                name contains "code" or "text" (case-insensitive).
            max_len: maximum sequence length including [CLS] and [SEP]; must be >= 2.
            do_format: apply ``simple_format_code_py`` before encoding.

        Raises:
            ValueError: if ``max_len < 2``.
            AssertionError: if no text column is found, or the tokenizer lacks
                [PAD], [CLS], [SEP] or [MASK].
        """
        super().__init__()
        if max_len < 2:
            # encode() reserves two slots for [CLS]/[SEP]. With max_len < 2 the
            # slice bound max_len - 2 is negative and would keep (almost) all tokens.
            raise ValueError(f"max_len must be >= 2 to fit [CLS] and [SEP], got {max_len}")
        df = pd.read_parquet(parquet_path)
        if text_col not in df.columns:
            cands = [c for c in df.columns if "code" in str(c).lower() or "text" in str(c).lower()]
            assert cands, f"No text/code column in: {list(df.columns)}"
            text_col = cands[0]
        # Drop null cells: str() would turn them into the literal text "None"/"nan".
        self.df = df[df[text_col].notna()].reset_index(drop=True)
        self.text_col = text_col
        self.max_len = max_len
        self.do_format = do_format

        self.tokenizer = Tokenizer.from_file(tokenizer_json)
        # Resolve special-token ids.
        self.pad_id = self.tokenizer.token_to_id(PAD)
        assert self.pad_id is not None, "Tokenizer must contain [PAD]"
        self.cls_id = self.tokenizer.token_to_id(CLS)
        self.sep_id = self.tokenizer.token_to_id(SEP)
        self.mask_id = self.tokenizer.token_to_id(MASK)
        self.unk_id = self.tokenizer.token_to_id(UNK)
        # encode() always inserts [CLS]/[SEP], and MLM needs [MASK]. A missing id
        # would otherwise surface later as an opaque tensor-construction error.
        for name, tid in ((CLS, self.cls_id), (SEP, self.sep_id), (MASK, self.mask_id)):
            assert tid is not None, f"Tokenizer must contain {name}"

    def encode(self, text: str) -> Dict[str, List[int]]:
        """Tokenize ``text`` as ``[CLS] ... [SEP]``, truncating the body to ``max_len - 2`` ids."""
        ids = self.tokenizer.encode(text).ids
        ids = [self.cls_id] + ids[: self.max_len - 2] + [self.sep_id]
        # Single-sequence input: every position is segment 0.
        seg = [0] * len(ids)
        return {"input_ids": ids, "token_type_ids": seg}

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx: int):
        # Column-first positional access avoids building a full row Series per item.
        code = str(self.df[self.text_col].iloc[idx])
        if self.do_format:
            code = simple_format_code_py(code)
        out = self.encode(code)
        # Recover token strings so roles can be tagged per token.
        tokens = [self.tokenizer.id_to_token(i) for i in out["input_ids"]]
        role_ids = heuristic_role_tagging(tokens)
        assert len(role_ids) == len(out["input_ids"])
        return {
            "input_ids": torch.tensor(out["input_ids"], dtype=torch.long),
            "token_type_ids": torch.tensor(out["token_type_ids"], dtype=torch.long),
            "role_ids": torch.tensor(role_ids, dtype=torch.long),
        }

@dataclass
class CollatorMLMRole:
    """Pad CodeDataset items to the batch max length and apply BERT-style MLM masking.

    Attributes:
        pad_id: token id used to pad ``input_ids`` (also defines ``attention_mask``).
        mask_id: id of the ``[MASK]`` token.
        vocab_size: exclusive upper bound for random replacement ids.
        mlm_prob: selection probability for each maskable position.

    ``__call__`` returns (batch, max_len) tensors:
      * ``input_ids``: after masking;
      * ``token_type_ids`` and ``attention_mask``;
      * ``labels``: -100 = not an MLM target;
      * ``role_ids``;
      * ``role_labels``: -100 = ignored.
    """

    pad_id: int
    mask_id: int
    vocab_size: int
    mlm_prob: float = 0.15

    def __call__(self, features: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        # Pad to the max length in this batch.
        max_len = max(len(f["input_ids"]) for f in features)

        def pad_tensor(t: torch.Tensor, pad_val: int):
            if t.dim() != 1:
                raise ValueError("expected 1D tensors")
            out = torch.full((max_len,), pad_val, dtype=t.dtype)
            out[: len(t)] = t
            return out

        input_ids = torch.stack([pad_tensor(f["input_ids"], self.pad_id) for f in features])
        token_type_ids = torch.stack([pad_tensor(f["token_type_ids"], 0) for f in features])
        role_ids = torch.stack([pad_tensor(f["role_ids"], ROLE2ID["PAD"]) for f in features])
        attention_mask = (input_ids != self.pad_id).long()

        # Never-maskable positions, computed BEFORE input_ids is modified below.
        # They are padding plus special tokens ([CLS]/[SEP]/[MASK]/[UNK]/[PAD]),
        # which heuristic_role_tagging marks with the "SEP" role. As in standard
        # BERT pretraining, special tokens are not MLM targets.
        padding_mask = attention_mask == 0
        special_mask = padding_mask | (role_ids == ROLE2ID["SEP"])

        # ---- MLM labels ----
        labels = input_ids.clone()
        prob = torch.full_like(input_ids, fill_value=self.mlm_prob, dtype=torch.float)
        prob[special_mask] = 0.0
        masked = torch.bernoulli(prob).bool()
        labels[~masked] = -100

        # 80% of selected positions -> [MASK]
        replace_mask = masked & (torch.rand_like(prob) < 0.8)
        input_ids[replace_mask] = self.mask_id

        # Half of the remaining 20% (10% overall) -> random token; the rest stay unchanged.
        random_mask = masked & (~replace_mask) & (torch.rand_like(prob) < 0.5)
        random_words = torch.randint(low=0, high=self.vocab_size, size=input_ids.size(), dtype=torch.long)
        input_ids[random_mask] = random_words[random_mask]

        # Role labels ignore padding AND special tokens. This uses the
        # pre-masking special_mask: a random replacement may have written pad_id
        # into input_ids, and that must not hide a real token's role label.
        # NOTE(review): role_ids is also returned as a model input, so the role
        # target is visible to the model at every position (MLM-masked ones
        # included). Consider hiding role_ids where `masked` is True.
        role_labels = role_ids.clone()
        role_labels[special_mask] = -100

        return {
            "input_ids": input_ids,
            "token_type_ids": token_type_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "role_ids": role_ids,
            "role_labels": role_labels,
        }

In [None]:
# =============================================================================
# 3) Model (BERT-like Encoder with Role Embeddings)
# =============================================================================

class RoleBertConfig:
    """Plain hyper-parameter container for RoleBertForPretraining.

    Every constructor argument is stored unchanged as an attribute of the same
    name; no validation happens here.
    """

    def __init__(
        self,
        vocab_size: int,
        role_vocab_size: int = ROLE_VOCAB_SIZE,
        hidden_size: int = 768,
        num_hidden_layers: int = 6,
        num_attention_heads: int = 12,
        intermediate_size: int = 3072,
        max_position_embeddings: int = 512,
        type_vocab_size: int = 2,
        layer_norm_eps: float = 1e-12,
        hidden_dropout_prob: float = 0.1,
    ):
        # Vocabulary sizes of the token and role embedding tables.
        self.vocab_size, self.role_vocab_size = vocab_size, role_vocab_size
        # Encoder geometry.
        self.hidden_size, self.num_hidden_layers = hidden_size, num_hidden_layers
        self.num_attention_heads, self.intermediate_size = num_attention_heads, intermediate_size
        # Position / segment embedding table sizes.
        self.max_position_embeddings, self.type_vocab_size = max_position_embeddings, type_vocab_size
        # Numerics and regularisation.
        self.layer_norm_eps, self.hidden_dropout_prob = layer_norm_eps, hidden_dropout_prob

class RoleBertForPretraining(nn.Module):
    """BERT-style encoder with an extra role embedding and two pretraining heads.

    Input embedding = token + position + segment + role embeddings, then
    LayerNorm and dropout, then a pre-LN ``nn.TransformerEncoder``.

    Heads:
      * MLM: dense -> GELU -> LayerNorm -> vocabulary projection whose weight is
        tied to ``word_embeddings``;
      * role: a linear classifier over ``cfg.role_vocab_size`` roles.
    """

    def __init__(self, cfg: "RoleBertConfig", pad_id: int):
        super().__init__()
        self.cfg = cfg
        self.pad_id = pad_id

        # Embeddings
        self.word_embeddings = nn.Embedding(cfg.vocab_size, cfg.hidden_size, padding_idx=pad_id)
        self.pos_embeddings = nn.Embedding(cfg.max_position_embeddings, cfg.hidden_size)
        self.seg_embeddings = nn.Embedding(cfg.type_vocab_size, cfg.hidden_size)
        self.role_embeddings = nn.Embedding(cfg.role_vocab_size, cfg.hidden_size)

        self.emb_ln = nn.LayerNorm(cfg.hidden_size, eps=cfg.layer_norm_eps)
        self.emb_drop = nn.Dropout(cfg.hidden_dropout_prob)

        # Encoder (PyTorch Transformer, pre-LN)
        enc_layer = nn.TransformerEncoderLayer(
            d_model=cfg.hidden_size,
            nhead=cfg.num_attention_heads,
            dim_feedforward=cfg.intermediate_size,
            dropout=cfg.hidden_dropout_prob,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=cfg.num_hidden_layers)

        # Heads
        self.mlm_transform = nn.Sequential(
            nn.Linear(cfg.hidden_size, cfg.hidden_size),
            nn.GELU(),
            nn.LayerNorm(cfg.hidden_size, eps=cfg.layer_norm_eps),
        )
        self.mlm_decoder = nn.Linear(cfg.hidden_size, cfg.vocab_size, bias=False)
        # Separate output bias registered both as `mlm_bias` and `mlm_decoder.bias`
        # (same Parameter; both state_dict keys are kept for checkpoint compatibility).
        self.mlm_bias = nn.Parameter(torch.zeros(cfg.vocab_size))
        self.mlm_decoder.bias = self.mlm_bias

        self.role_classifier = nn.Linear(cfg.hidden_size, cfg.role_vocab_size)

        # Weight tying: decoder.weight = embeddings.weight
        self.mlm_decoder.weight = self.word_embeddings.weight

        self._init_weights()

    def _init_weights(self):
        """Initialise weights BERT-style.

        Linear/Embedding weights get a truncated normal (std 0.02), biases are
        zeroed, and LayerNorms are set to identity.

        Padding rows of embeddings with ``padding_idx`` are re-zeroed at the end.
        The generic init overwrites the zero row nn.Embedding creates, and the
        tied MLM decoder is visited after ``word_embeddings``, re-initialising
        the shared weight.
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Embedding):
                nn.init.trunc_normal_(m.weight, std=0.02)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
        with torch.no_grad():
            for m in self.modules():
                if isinstance(m, nn.Embedding) and m.padding_idx is not None:
                    m.weight[m.padding_idx].zero_()

    def _embed(self, input_ids, token_type_ids, role_ids):
        """Sum token/position/segment/role embeddings, then LayerNorm + dropout.

        Raises:
            ValueError: if the sequence is longer than ``cfg.max_position_embeddings``.
        """
        bsz, seqlen = input_ids.size()
        if seqlen > self.cfg.max_position_embeddings:
            # Fail with a clear message instead of an index error in pos_embeddings.
            raise ValueError(
                f"sequence length {seqlen} exceeds max_position_embeddings "
                f"({self.cfg.max_position_embeddings})"
            )
        pos_ids = torch.arange(seqlen, device=input_ids.device).unsqueeze(0).expand(bsz, seqlen)

        x = (
            self.word_embeddings(input_ids)
            + self.pos_embeddings(pos_ids)
            + self.seg_embeddings(token_type_ids)
            # Clamp keeps out-of-range role ids inside the embedding table.
            + self.role_embeddings(role_ids.clamp(0, self.cfg.role_vocab_size - 1))
        )
        x = self.emb_ln(x)
        x = self.emb_drop(x)
        return x

    @staticmethod
    def _masked_cross_entropy(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Mean cross-entropy over positions whose label is not -100.

        Equals ``nn.CrossEntropyLoss(ignore_index=-100)``, except that when
        every label is ignored it returns 0 (still attached to the graph)
        instead of NaN. A NaN loss would poison the gradients of the whole step.
        """
        flat_labels = labels.reshape(-1)
        loss_sum = torch.nn.functional.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            flat_labels,
            ignore_index=-100,
            reduction="sum",
        )
        n_valid = (flat_labels != -100).sum().clamp(min=1)
        return loss_sum / n_valid

    def forward(
        self,
        input_ids: torch.Tensor,
        token_type_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        role_ids: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        role_labels: Optional[torch.Tensor] = None,
        lambda_mlm: float = 1.0,
        lambda_role: float = 1.0,
    ):
        """Encode a batch and optionally compute the MLM and role losses.

        Args:
            input_ids, token_type_ids, role_ids: (batch, seq) integer tensors.
            attention_mask: (batch, seq); 0 marks padding.
            labels: MLM targets, -100 = ignored (optional).
            role_labels: role targets, -100 = ignored (optional).
            lambda_mlm, lambda_role: weights of the two losses in ``loss``.

        Returns:
            Dict with:
              * ``loss``: weighted sum, or None if no labels were given;
              * ``mlm_loss`` and ``role_loss``: None when the labels are absent;
              * ``logits``: (batch, seq, vocab);
              * ``role_logits``: (batch, seq, roles);
              * ``last_hidden_state``.
        """
        x = self._embed(input_ids, token_type_ids, role_ids)
        # PyTorch Transformer expects True where to mask (padding positions).
        key_padding_mask = (attention_mask == 0)
        h = self.encoder(x, src_key_padding_mask=key_padding_mask)

        # MLM
        mlm_hidden = self.mlm_transform(h)
        logits = self.mlm_decoder(mlm_hidden)

        # Role
        role_logits = self.role_classifier(h)

        loss = None
        mlm_loss = None
        role_loss = None
        if labels is not None:
            mlm_loss = self._masked_cross_entropy(logits, labels)
        if role_labels is not None:
            role_loss = self._masked_cross_entropy(role_logits, role_labels)
        if (mlm_loss is not None) or (role_loss is not None):
            loss = (lambda_mlm * (mlm_loss if mlm_loss is not None else 0.0)) + \
                   (lambda_role * (role_loss if role_loss is not None else 0.0))

        return {
            "loss": loss,
            "mlm_loss": mlm_loss,
            "role_loss": role_loss,
            "logits": logits,
            "role_logits": role_logits,
            "last_hidden_state": h,
        }

In [None]:
# =============================================================================
# 4) Optimizer & Scheduler
# =============================================================================

class CosineWithWarmup:
    """Step-based linear warmup followed by cosine decay.

    With ``W = max(1, warmup_steps)`` and ``T = max(W + 1, total_steps)``, the
    lr of each param group after ``k`` calls to :meth:`step` is::

        base_lr * k / W                                    for k <= W
        base_lr * (r + (1 - r) * 0.5 * (1 + cos(pi * p)))  for k >  W

    where ``r = min_lr_ratio`` and ``p = min(1, (k - W) / (T - W))``.

    The k = 0 value (0.0) is written at construction. The first optimizer
    update therefore respects the warmup instead of running at the full base
    lr. Intended usage per iteration: ``optimizer.step(); scheduler.step()``.
    """

    def __init__(self, optimizer, warmup_steps, total_steps, min_lr_ratio=0.0):
        self.optimizer = optimizer
        self.warmup_steps = max(1, warmup_steps)
        self.total_steps = max(self.warmup_steps + 1, total_steps)
        self.min_lr_ratio = min_lr_ratio
        self.step_num = 0
        # Peak lrs, captured before the schedule overwrites group["lr"].
        self.base_lrs = [g["lr"] for g in optimizer.param_groups]
        self._apply()

    def _lr_at(self, base_lr, step_num):
        """Scheduled lr for a group with peak ``base_lr`` after ``step_num`` steps."""
        if step_num <= self.warmup_steps:
            return base_lr * step_num / self.warmup_steps
        progress = (step_num - self.warmup_steps) / (self.total_steps - self.warmup_steps)
        # Clamp so steps beyond total_steps stay at the floor instead of
        # climbing back up the cosine.
        progress = min(1.0, progress)
        cosine = 0.5 * (1 + math.cos(math.pi * progress))
        return base_lr * (self.min_lr_ratio + (1 - self.min_lr_ratio) * cosine)

    def _apply(self):
        """Write the lr for the current ``step_num`` into every param group."""
        for i, group in enumerate(self.optimizer.param_groups):
            group["lr"] = self._lr_at(self.base_lrs[i], self.step_num)

    def step(self):
        """Advance the schedule by one step and update every param group's lr."""
        self.step_num += 1
        self._apply()

    def get_last_lr(self):
        """Current lr of each param group (same name as torch lr schedulers)."""
        return [g["lr"] for g in self.optimizer.param_groups]

def build_optimizer(model, lr=1e-4, weight_decay=0.01):
    """Create ``torch.optim.AdamW`` with weight decay only on matrix-shaped parameters.

    1-D parameters get ``weight_decay=0.0``. These are the biases and the
    LayerNorm scales/shifts, following the usual BERT/AdamW recipe. Tied
    (shared) parameters appear once, because ``named_parameters``
    de-duplicates them.
    """
    decay, no_decay = [], []
    for _, param in model.named_parameters():
        (decay if param.ndim >= 2 else no_decay).append(param)
    groups = [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]
    # Drop empty groups. If both are empty, AdamW raises as it did before.
    return torch.optim.AdamW(
        [g for g in groups if g["params"]], lr=lr, weight_decay=weight_decay
    )

In [None]:
# =============================================================================
# 5) Train loop
# =============================================================================

def train(
    model: RoleBertForPretraining,
    loader: DataLoader,
    num_epochs: int = 1,
    lr: float = 1e-4,
    weight_decay: float = 0.01,
    warmup_ratio: float = 0.06,
    lambda_mlm: float = 1.0,
    lambda_role: float = 1.0,
    max_grad_norm: float = 1.0,
    log_every: int = 100,
):
    """Pretrain ``model`` in place with AdamW + cosine/warmup and gradient clipping.

    Args:
        model: model to train; it is moved to ``DEVICE`` and put in train mode.
        loader: yields collated batches with input_ids, token_type_ids,
            attention_mask, role_ids, labels and role_labels.
        num_epochs: passes over ``loader``.
        lr, weight_decay: AdamW settings (see ``build_optimizer``).
        warmup_ratio: fraction of total steps used for linear warmup.
        lambda_mlm, lambda_role: loss weights forwarded to the model.
        max_grad_norm: gradient-norm clipping threshold.
        log_every: print losses every ``log_every`` steps (<= 0 disables logging).

    Returns:
        The trained model (the same object that was passed in).
    """
    model.to(DEVICE)
    model.train()

    optimizer = build_optimizer(model, lr=lr, weight_decay=weight_decay)
    total_steps = num_epochs * len(loader)
    warmup_steps = int(total_steps * warmup_ratio)
    scheduler = CosineWithWarmup(optimizer, warmup_steps, total_steps, min_lr_ratio=0.0)

    batch_keys = ("input_ids", "token_type_ids", "attention_mask", "role_ids", "labels", "role_labels")
    # Asynchronous host->device copies pair with pin_memory=True on the loader.
    non_blocking = DEVICE == "cuda"

    step = 0
    for epoch in range(num_epochs):
        for batch in loader:
            inputs = {k: batch[k].to(DEVICE, non_blocking=non_blocking) for k in batch_keys}
            out = model(**inputs, lambda_mlm=lambda_mlm, lambda_role=lambda_role)
            loss = out["loss"]

            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            scheduler.step()

            if log_every > 0 and step % log_every == 0:
                mlm_l = float(out["mlm_loss"]) if out["mlm_loss"] is not None else 0.0
                role_l = float(out["role_loss"]) if out["role_loss"] is not None else 0.0
                print(f"ep {epoch} step {step}/{total_steps} | loss {float(loss):.4f} | mlm {mlm_l:.4f} | role {role_l:.4f}")
            step += 1
    return model

In [None]:
# Colab only: mount Google Drive so the parquet corpus under
# /content/drive/MyDrive (read by the main cell) is accessible.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# =============================================================================
# 6) Main: end-to-end
# =============================================================================

if __name__ == "__main__":
    # ---- Paths/Params ----
    parquet_path = "/content/drive/MyDrive/code_corpus_preprocessed_light.parquet"
    # NOTE: /from_scratch is on the runtime's local disk (presumably ephemeral
    # on Colab); back results up to Drive afterwards.
    save_root = "/from_scratch"
    seed = 42
    retrain_tokenizer = False  # set True to rebuild the BPE vocab even if one is cached
    os.makedirs(save_root, exist_ok=True)

    # Seed the RNGs used below (weight init, DataLoader shuffling, MLM masking).
    random.seed(seed)
    torch.manual_seed(seed)

    # ---- 1) Train tokenizer from your corpus (cached across runs) ----
    tok_dir = os.path.join(save_root, "tokenizer")
    tok_path = os.path.join(tok_dir, "tokenizer.json")
    if retrain_tokenizer or not os.path.exists(tok_path):
        tok_path = train_bpe_tokenizer_from_corpus(
            parquet_path=parquet_path,
            vocab_size=32000,
            min_frequency=2,
            save_dir=tok_dir,
        )
    else:
        # A cached vocab goes stale if the corpus changes; flip retrain_tokenizer then.
        print(f"[Tokenizer] reusing {tok_path}")

    # ---- 2) Build dataset/loader ----
    max_len = 256
    batch_size = 16
    max_position_embeddings = 512
    assert max_len <= max_position_embeddings, "max_len must fit the position embedding table"

    ds = CodeDataset(
        parquet_path=parquet_path,
        tokenizer_json=tok_path,
        text_col="text",        # change to the actual column name if needed
        max_len=max_len,
        do_format=True,
    )
    # Reuse the ids the dataset already resolved instead of loading tokenizer.json again.
    pad_id = ds.pad_id
    mask_id = ds.mask_id
    vocab_size = ds.tokenizer.get_vocab_size()

    collate = CollatorMLMRole(
        pad_id=pad_id, mask_id=mask_id, vocab_size=vocab_size, mlm_prob=0.15
    )
    loader = DataLoader(
        ds, batch_size=batch_size, shuffle=True, num_workers=2, collate_fn=collate, pin_memory=(DEVICE=="cuda")
    )

    # ---- 3) Build model ----
    cfg = RoleBertConfig(
        vocab_size=vocab_size,
        role_vocab_size=ROLE_VOCAB_SIZE,
        hidden_size=768,
        num_hidden_layers=6,         # scale up to 12 if resources allow
        num_attention_heads=12,
        intermediate_size=3072,
        max_position_embeddings=max_position_embeddings,
        type_vocab_size=2,
        hidden_dropout_prob=0.1,
    )
    model = RoleBertForPretraining(cfg, pad_id=pad_id)

    # ---- 4) Train ----
    train(
        model,
        loader,
        num_epochs=1,
        lr=1e-4,
        weight_decay=0.01,
        warmup_ratio=0.06,
        lambda_mlm=1.0,
        lambda_role=1.0,
        max_grad_norm=1.0,
    )

    # ---- 5) Save ----
    ckpt_dir = os.path.join(save_root, "checkpoints")
    os.makedirs(ckpt_dir, exist_ok=True)
    ckpt_path = os.path.join(ckpt_dir, "rolebert_scratch.pt")
    torch.save(model.state_dict(), ckpt_path)
    # The state_dict alone cannot rebuild the model, so store the
    # hyper-parameters (and pad_id) next to it.
    cfg_path = os.path.join(ckpt_dir, "rolebert_config.json")
    with open(cfg_path, "w", encoding="utf-8") as f:
        json.dump({**vars(cfg), "pad_id": pad_id}, f, indent=2)
    print(f"[Done] Saved model to {ckpt_path}")
    print(f"[Done] Saved config to {cfg_path}")

[Tokenizer] fed 10000 lines
[Tokenizer] fed 20000 lines
[Tokenizer] fed 30000 lines
[Tokenizer] fed 40000 lines
[Tokenizer] fed 50000 lines
[Tokenizer] fed 60000 lines
[Tokenizer] fed 70000 lines
[Tokenizer] fed 80000 lines
[Tokenizer] fed 90000 lines
[Tokenizer] fed 100000 lines
[Tokenizer] fed 110000 lines
[Tokenizer] fed 120000 lines
[Tokenizer] fed 130000 lines
[Tokenizer] fed 140000 lines
[Tokenizer] fed 150000 lines
[Tokenizer] fed 160000 lines
[Tokenizer] fed 170000 lines
[Tokenizer] fed 180000 lines
[Tokenizer] fed 190000 lines
[Tokenizer] fed 200000 lines
[Tokenizer] fed 210000 lines
[Tokenizer] fed 220000 lines
[Tokenizer] fed 230000 lines
[Tokenizer] fed 240000 lines
[Tokenizer] fed 250000 lines
[Tokenizer] fed 260000 lines
[Tokenizer] fed 270000 lines
[Tokenizer] fed 280000 lines
[Tokenizer] fed 290000 lines
[Tokenizer] fed 300000 lines
[Tokenizer] fed 310000 lines
[Tokenizer] fed 320000 lines
[Tokenizer] fed 330000 lines
[Tokenizer] fed 340000 lines
[Tokenizer] fed 350000 

In [None]:
# Sanity check: the checkpoint written by the main cell should be listed here.
!ls -lh /from_scratch/checkpoints

total 260M
-rw-r--r-- 1 root root 260M Oct 21 14:18 rolebert_scratch.pt


In [None]:
# Back up the checkpoint and tokenizer to Drive, because /from_scratch is on the
# runtime's local disk. force_remount re-establishes the Drive connection,
# which may have dropped during a long training run.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

backup_dir = "/content/drive/MyDrive/models/rolebert"
os.makedirs(backup_dir, exist_ok=True)
required_files = [
    "/from_scratch/checkpoints/rolebert_scratch.pt",
    "/from_scratch/tokenizer/tokenizer.json",
]
# Hyper-parameter file: copied only when present.
optional_files = ["/from_scratch/checkpoints/rolebert_config.json"]

# Unlike `!cp`, shutil raises on failure, so a missing or failed backup stops
# the cell instead of passing silently.
for src in required_files + [p for p in optional_files if os.path.exists(p)]:
    dst = shutil.copyfile(src, os.path.join(backup_dir, os.path.basename(src)))
    print(f"'{src}' -> '{dst}'")

Mounted at /content/drive
'/from_scratch/checkpoints/rolebert_scratch.pt' -> '/content/drive/MyDrive/models/rolebert/rolebert_scratch.pt'
'/from_scratch/tokenizer/tokenizer.json' -> '/content/drive/MyDrive/models/rolebert/tokenizer.json'
