In [None]:
from google.colab import drive

# Mount Google Drive; the model checkpoint and SNOMED CT files used below live under /content/drive.
drive.mount('/content/drive')

In [None]:
# 0) Remove any previously installed copies so the pinned versions below install cleanly.
!pip -q uninstall -y torch torchvision torchaudio bitsandbytes triton transformers accelerate peft datasets einops

In [None]:

# 1) PyTorch, pinned to the CUDA 12.1 (cu121) wheel index.
!pip -q install --index-url https://download.pytorch.org/whl/cu121 torch==2.4.1 torchvision torchaudio

In [None]:
# 2) Triton (torch 2.4.x와 궁합)
!pip -q install triton==2.3.0

In [None]:
# 3) bitsandbytes (a release that ships cu121 binaries).
# NOTE(review): a later cell upgrades this to 0.44.1 and asks for a restart;
# pinning 0.44.1 here directly would avoid the second install.
!pip -q install bitsandbytes==0.43.1


In [None]:
# 4) 나머지 라이브러리
!pip -q install "transformers>=4.46.1" "accelerate>=0.33.0" "peft>=0.11.1" "datasets>=2.20.0" "einops"

In [None]:
import torch, bitsandbytes as bnb, triton, transformers
print("torch:", torch.__version__, "CUDA:", torch.version.cuda)
print("triton:", triton.__version__)
print("bitsandbytes:", bnb.__version__)
# transformers was imported but never reported. The model-loading cells pass `dtype=`
# to from_pretrained, which only newer releases accept, so show the installed version.
print("transformers:", transformers.__version__)
# The install cells pin the cu121 PyTorch wheels; fail fast if a different build is active.
assert torch.version.cuda == "12.1", f"expected a CUDA 12.1 torch build, got {torch.version.cuda}"
import triton.ops  # import check: raises if this Triton build has no `triton.ops`


In [None]:
import torch, bitsandbytes as bnb
from bitsandbytes.nn import Linear4bit

# GPU sanity check: is CUDA visible, and which device is it?
print("CUDA:", torch.cuda.is_available(), "| device:", torch.cuda.get_device_name(0))

# Minimal 4-bit Linear smoke test: a single forward pass through a bitsandbytes
# Linear4bit layer with bfloat16 compute on the GPU.
lin = Linear4bit(1024, 1024, compute_dtype=torch.bfloat16)
lin = lin.cuda()
x = torch.randn(2, 1024, device="cuda")
y = lin(x)
print("OK: 4bit matmul ->", y.shape)


In [None]:
# Upgrade only bitsandbytes. A runtime restart is required, not just recommended:
# earlier cells already imported bitsandbytes, and the loaded module (with its
# native library) stays in memory until the kernel restarts.
!pip -q install --upgrade bitsandbytes==0.44.1


In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model
import torch

# Local copy of the Qwen3-8B checkpoint on Google Drive.
QWEN_DIR = "/content/drive/MyDrive/DILAB/qwen3-8b"

# 4-bit NF4 weights with double quantization; matmuls computed in bfloat16
# (bf16 is the recommended compute dtype on A100-class GPUs).
bnb_cfg = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tok = AutoTokenizer.from_pretrained(QWEN_DIR, trust_remote_code=True, use_fast=False)

base = AutoModelForCausalLM.from_pretrained(
    QWEN_DIR,
    quantization_config=bnb_cfg,
    dtype=torch.bfloat16,            # `dtype` is the newer name for `torch_dtype`
    device_map="auto",
    attn_implementation="eager",     # conservative attention backend to avoid kernel issues
    low_cpu_mem_usage=True,
    trust_remote_code=True,
)

# LoRA on the attention projections; adjust target_modules if the module names differ.
lora_cfg = LoraConfig(
    task_type="FEATURE_EXTRACTION",
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)

# NOTE(review): PEFT's documented order is prepare_model_for_kbit_training(base)
# before get_peft_model; the training section instead calls it on this wrapped model.
model = get_peft_model(base, lora_cfg)
model.print_trainable_parameters()


### 1. RF2 스냅샷 읽기

In [None]:
import duckdb, pandas as pd

# RF2 snapshot files (SNOMED CT International edition, 2025-10 release).
DESC_FILE   = "/content/drive/MyDrive/DILAB/OK/DI_LAB/MARS_Datathon/Datasets/SNOMED_CT_datasets/SNOMED_International_2025-10/SnomedCT_InternationalRF2_PRODUCTION_20251001T120000Z/Snapshot/Terminology/sct2_Description_Snapshot-en_INT_20251001.txt"
LANGREF_FILE= "/content/drive/MyDrive/DILAB/OK/DI_LAB/MARS_Datathon/Datasets/SNOMED_CT_datasets/SNOMED_International_2025-10/SnomedCT_InternationalRF2_PRODUCTION_20251001T120000Z/Snapshot/Refset/Language/der2_cRefset_LanguageSnapshot-en_INT_20251001.txt"

con = duckdb.connect()

FSN   = 900000000000003001  # description typeId: Fully Specified Name
SYN   = 900000000000013009  # description typeId: Synonym
PREF  = 900000000000548007  # acceptabilityId: Preferred
US_EN = 900000000000509007  # refsetId: US English language reference set
GB_EN = 900000000000508004  # refsetId: GB English language reference set
LANG_REFSET = US_EN         # dialect whose acceptability decides `pref_syn`

# Active English descriptions. IDs are cast to BIGINT so the join below compares
# like types regardless of what the CSV sniffer inferred.
# NOTE(review): RF2 files are tab-separated without quoting; confirm that
# read_csv_auto's quote detection does not mis-split terms containing '"'.
con.execute(f"""
  CREATE OR REPLACE TEMP VIEW d AS
  SELECT
    CAST(conceptId AS BIGINT)  AS concept_id,
    CAST(id AS BIGINT)         AS description_id,
    languageCode               AS lang,
    CAST(typeId AS BIGINT)     AS type_id,
    term
  FROM read_csv_auto('{DESC_FILE}', delim='\\t', header=1, sample_size=-1)
  WHERE active='1' AND languageCode='en';
""")

# Active acceptability rows for ONE dialect. The International language snapshot
# holds both the US and GB English refsets, so a description can have two rows.
# Without the refsetId filter, MAX(...) below would pick between the US and GB
# preferred terms by string order (e.g. 'Anemia' vs 'Anaemia').
con.execute(f"""
  CREATE OR REPLACE TEMP VIEW l AS
  SELECT
    CAST(referencedComponentId AS BIGINT) AS description_id,
    CAST(acceptabilityId AS BIGINT)       AS acceptability_id
  FROM read_csv_auto('{LANGREF_FILE}', delim='\\t', header=1, sample_size=-1)
  WHERE active='1' AND CAST(refsetId AS BIGINT) = {LANG_REFSET};
""")

# One row per concept: FSN, dialect-preferred synonym, and every distinct synonym.
# Synonyms absent from the chosen refset stay in all_synonyms via the LEFT JOIN.
# The CASE yields NULL for non-synonym rows; to_syn_list() drops those later.
df_en = con.execute(f"""
  WITH joined AS (
    SELECT d.concept_id, d.description_id, d.type_id, d.term, l.acceptability_id
    FROM d LEFT JOIN l USING(description_id)
  )
  SELECT
    concept_id,
    MAX(CASE WHEN type_id={FSN} THEN term END)                                 AS fsn,
    MAX(CASE WHEN type_id={SYN} AND acceptability_id={PREF} THEN term END)     AS pref_syn,
    LIST(DISTINCT CASE WHEN type_id={SYN} THEN term END)                        AS all_synonyms
  FROM joined
  GROUP BY concept_id
""").df()

len(df_en), df_en.head(3)


### 2. "같은 개념 동의어"로 Positive Pair 생성

In [None]:
import numpy as np
import random

# 1) ndarray → list, None/공백 제거, 중복 제거(순서 보존)
def to_syn_list(x):
    """Normalize one `all_synonyms` cell into a clean list of unique strings.

    Accepts a numpy array or a Python list; any other value (None, NaN, ...)
    yields an empty list. Non-string and blank entries are dropped, the rest
    are stripped of surrounding whitespace, and duplicates are removed while
    keeping first-occurrence order.
    """
    if isinstance(x, np.ndarray):
        items = x.tolist()
    elif isinstance(x, list):
        items = x
    else:
        return []

    stripped = (item.strip() for item in items if isinstance(item, str) and item.strip())
    # dict keeps insertion order, so this de-duplicates without reordering.
    return list(dict.fromkeys(stripped))

df_en["syn_list"] = df_en["all_synonyms"].apply(to_syn_list)

# 2) Count concepts that have at least two distinct synonyms.
ok_mask = df_en["syn_list"].apply(lambda lst: len(lst) >= 2)
print("concepts with >=2 synonyms:", ok_mask.sum())

# 3) Build positive pairs, at most MAX_PAIRS_PER_CONCEPT per concept.
#    PAIR_MODE = "adjacent": consecutive synonyms ["a","b","c"] -> ("a","b"), ("b","c"), ...
#    PAIR_MODE = "shuffle" : same, after a seeded shuffle of a *copy* of each list,
#                            so df_en["syn_list"] is left unchanged and runs are reproducible.
PAIR_MODE = "adjacent"
MAX_PAIRS_PER_CONCEPT = 5
PAIR_SEED = 42

if PAIR_MODE not in ("adjacent", "shuffle"):
    raise ValueError(f"unknown PAIR_MODE: {PAIR_MODE!r}")

_pair_rng = random.Random(PAIR_SEED)
pairs = []
for syns in df_en.loc[ok_mask, "syn_list"]:
    if PAIR_MODE == "shuffle":
        syns = list(syns)
        _pair_rng.shuffle(syns)
    for i in range(min(MAX_PAIRS_PER_CONCEPT, len(syns) - 1)):
        a, b = syns[i], syns[i + 1]
        if a != b:  # defensive; syn_list is already de-duplicated
            pairs.append((a, b))

print("pairs:", len(pairs), "| example:", pairs[:5])


### 4. InfoNCE 학습 루프(in-batch negatives)

In [None]:
# Progress bars for the training loop below (tqdm.auto).
!pip -q install tqdm

import os, math, torch, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from dataclasses import dataclass
from tqdm.auto import tqdm

# Reduce CUDA allocator fragmentation (optional).
# NOTE(review): PyTorch reads PYTORCH_CUDA_ALLOC_CONF when its CUDA allocator is first
# initialised. Earlier cells already put the model on the GPU, so this assignment most
# likely has no effect here; set it in the first cell (before importing torch) instead.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# ---- Memory-saving model settings (on top of the LoRA setup above)
from peft import prepare_model_for_kbit_training
model.config.output_hidden_states = True
model.config.use_cache = False
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)

# prepare_model_for_kbit_training sets requires_grad=False on EVERY parameter of the
# model it receives. `model` is already a PeftModel here, so that also freezes the
# LoRA adapters and the optimizer would update nothing. Re-enable the adapter weights.
for _name, _param in model.named_parameters():
    if "lora_" in _name:
        _param.requires_grad = True
model.print_trainable_parameters()

# ===== Hyperparameters (memory-saving example) =====
SEED        = 42
MAX_PAIRS   = 30000
MAX_LENGTH  = 64
BATCH_SIZE  = 8
EPOCHS      = 1
LR          = 1e-4
GRAD_ACCUM  = 4
TEMPERATURE = 0.05

torch.manual_seed(SEED)  # makes DataLoader shuffling reproducible

device = model.device

class PairDataset(Dataset):
    """Map-style dataset over (anchor, positive) string pairs.

    Each item is a dict with keys "a" and "b", which `collate` tokenizes.
    """

    def __init__(self, pairs):
        self.pairs = pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, i):
        first, second = self.pairs[i]
        return dict(a=first, b=second)

@dataclass(init=True, repr=True, eq=True)
class Batch:
    """Tokenized pair batch: padded ids and masks for side "a" and side "b"."""

    # Side "a" (anchor) token ids and attention mask.
    input_ids_a: torch.Tensor
    attention_mask_a: torch.Tensor
    # Side "b" (positive) token ids and attention mask.
    input_ids_b: torch.Tensor
    attention_mask_b: torch.Tensor

def collate(batch):
    """Tokenize a list of {"a", "b"} items into a padded `Batch`.

    Each side is padded to its own longest sequence and truncated to MAX_LENGTH.
    """
    def _encode(texts):
        return tok(texts, padding=True, truncation=True, max_length=MAX_LENGTH, return_tensors="pt")

    enc_a = _encode([item["a"] for item in batch])
    enc_b = _encode([item["b"] for item in batch])
    return Batch(
        input_ids_a=enc_a["input_ids"],
        attention_mask_a=enc_a["attention_mask"],
        input_ids_b=enc_b["input_ids"],
        attention_mask_b=enc_b["attention_mask"],
    )

def mean_pool(last_hidden_state, attention_mask):
    """Average token vectors over the sequence axis, ignoring padded positions.

    The mask is broadcast over the last (feature) axis. The token count is
    clamped so an all-padding row yields zeros instead of dividing by zero.
    """
    keep = attention_mask.unsqueeze(-1).type_as(last_hidden_state)
    n_tokens = keep.sum(dim=1).clamp(min=1e-6)
    return (last_hidden_state * keep).sum(dim=1) / n_tokens

def gpu_mem_gb():
    """Return a one-line GPU memory summary in GiB, or "CPU" when CUDA is unavailable."""
    if torch.cuda.is_available():
        gib = 1024 ** 3
        allocated = torch.cuda.memory_allocated() / gib
        reserved = torch.cuda.memory_reserved() / gib
        total = torch.cuda.get_device_properties(0).total_memory / gib
        return f"alloc {allocated:.2f}G / reserved {reserved:.2f}G / total {total:.0f}G"
    return "CPU"

# DataLoader / optimizer
ds = PairDataset(pairs[:MAX_PAIRS])
if len(ds) == 0:
    raise RuntimeError("pairs가 0개입니다. pairs 생성 과정을 확인하세요.")

dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, collate_fn=collate)
# With drop_last=True, fewer than BATCH_SIZE pairs gives zero batches; the loop would
# then never set `ema` and the epoch summary print would crash.
if len(dl) == 0:
    raise RuntimeError(f"need at least BATCH_SIZE={BATCH_SIZE} pairs, got {len(ds)}")

# Optimize only parameters that can receive gradients (the LoRA adapters). An empty
# list means everything is frozen and training would silently update nothing.
trainable_params = [p for p in model.parameters() if p.requires_grad]
if not trainable_params:
    raise RuntimeError(
        "model has no trainable parameters; prepare_model_for_kbit_training freezes "
        "all parameters, including LoRA adapters when it runs after get_peft_model"
    )
optim = torch.optim.AdamW(trainable_params, lr=LR)

# ===== Training loop (progress bar + EMA loss + memory log) =====
model.train()
ema = None
beta = 0.98  # EMA smoothing factor
autocast_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
steps_per_epoch = len(dl)

for epoch in range(EPOCHS):
    pbar = tqdm(enumerate(dl, 1), total=steps_per_epoch, desc=f"Epoch {epoch+1}/{EPOCHS}")
    optim.zero_grad(set_to_none=True)

    for step, batch in pbar:
        batch = Batch(
            batch.input_ids_a.to(device, non_blocking=True),
            batch.attention_mask_a.to(device, non_blocking=True),
            batch.input_ids_b.to(device, non_blocking=True),
            batch.attention_mask_b.to(device, non_blocking=True),
        )

        # Only the two transformer forward passes run under autocast.
        with torch.autocast(device_type="cuda", dtype=autocast_dtype):
            out_a = model(
                input_ids=batch.input_ids_a,
                attention_mask=batch.attention_mask_a,
                output_hidden_states=True, use_cache=False, return_dict=True,
            )
            out_b = model(
                input_ids=batch.input_ids_b,
                attention_mask=batch.attention_mask_b,
                output_hidden_states=True, use_cache=False, return_dict=True,
            )

        # Pooling, similarities and the loss run in float32 outside autocast, since
        # autocast would cast the matmul back to bf16/fp16. With TEMPERATURE=0.05 the
        # logits span [-20, 20], where bf16 can only resolve steps of up to 0.125.
        emb_a = F.normalize(mean_pool(out_a.hidden_states[-1].float(), batch.attention_mask_a), p=2, dim=1)
        emb_b = F.normalize(mean_pool(out_b.hidden_states[-1].float(), batch.attention_mask_b), p=2, dim=1)

        # Symmetric InfoNCE with in-batch negatives: row i's positive is column i.
        logits = (emb_a @ emb_b.T) / TEMPERATURE
        target = torch.arange(logits.size(0), device=device)
        loss = 0.5 * (F.cross_entropy(logits, target) + F.cross_entropy(logits.T, target))

        # Scale so the accumulated gradient is the mean over GRAD_ACCUM micro-batches
        # rather than their sum.
        (loss / GRAD_ACCUM).backward()

        # Step every GRAD_ACCUM micro-batches, and also on the epoch's last batch, so a
        # trailing partial window is applied instead of discarded by the next zero_grad.
        if step % GRAD_ACCUM == 0 or step == steps_per_epoch:
            torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
            optim.step()
            optim.zero_grad(set_to_none=True)

        # EMA update + progress-bar display (unscaled loss).
        loss_val = loss.item()
        ema = loss_val if ema is None else beta * ema + (1 - beta) * loss_val
        pbar.set_postfix({
            "loss": f"{loss_val:.4f}",
            "ema": f"{ema:.4f}",
            "mem": gpu_mem_gb()
        })

        # Occasionally release cached blocks to reduce fragmentation.
        if step % (GRAD_ACCUM*50) == 0:
            torch.cuda.empty_cache()

    print(f"[epoch {epoch+1}] EMA loss: {ema:.4f} | {gpu_mem_gb()}")

print("✅ training loop done")


### 5. LoRA 어댑터를 Drive에 저장하고 저장되었는지 확인

In [None]:
import os

SAVE_DIR = "/content/drive/MyDrive/DILAB/OK/DI_LAB/MARS_Datathon/Models/qwen3-8b-snomed-embed-lora"
model.save_pretrained(SAVE_DIR)
tok.save_pretrained(SAVE_DIR)
print("saved to:", SAVE_DIR)

# Verify the save, as the section title promises: list what landed on Drive and make
# sure the PEFT adapter config is among the files.
saved_files = sorted(os.listdir(SAVE_DIR))
print("files:", saved_files)
assert "adapter_config.json" in saved_files, "adapter_config.json missing - adapter was not saved"


### 6. 유사도 확인

In [None]:
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch, torch.nn.functional as F

# Base checkpoint and the LoRA adapter saved by the training section.
QWEN_DIR = "/content/drive/MyDrive/DILAB/qwen3-8b"  # base model path
LORA_DIR = "/content/drive/MyDrive/DILAB/OK/DI_LAB/MARS_Datathon/Models/qwen3-8b-snomed-embed-lora"

# Same 4-bit NF4 quantization settings as the training model.
bnb_cfg = BitsAndBytesConfig(
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    load_in_4bit=True,
)

tok2 = AutoTokenizer.from_pretrained(QWEN_DIR, trust_remote_code=True, use_fast=False)
base2 = AutoModelForCausalLM.from_pretrained(
    QWEN_DIR,
    quantization_config=bnb_cfg,
    device_map="auto",
    dtype=torch.bfloat16,
    attn_implementation="eager",
    trust_remote_code=True,
)
# Attach the trained LoRA adapter on top of the quantized base model.
mdl2 = PeftModel.from_pretrained(base2, LORA_DIR)

# Inference only: eval mode, no KV cache, and forward passes return hidden states.
mdl2.eval()
mdl2.config.use_cache = False
mdl2.config.output_hidden_states = True

def mean_pool(hidden, mask):
    """Masked mean over the sequence axis (standalone copy for the inference section).

    Padded positions (mask == 0) are excluded; the count is clamped so an
    all-padding row yields zeros instead of dividing by zero.
    """
    weights = mask[..., None].type_as(hidden)
    total = torch.sum(hidden * weights, dim=1)
    count = torch.clamp(torch.sum(weights, dim=1), min=1e-6)
    return total / count

@torch.no_grad()
def encode(texts, max_length=96):
    """Embed `texts` as L2-normalised, mask-averaged last-layer hidden states.

    Args:
        texts: list of strings to embed.
        max_length: tokenizer truncation length (the training cell used MAX_LENGTH=64).

    Returns:
        float32 CPU tensor with one unit-norm row per input text.
    """
    b = tok2(texts, padding=True, truncation=True, max_length=max_length, return_tensors="pt").to(mdl2.device)
    o = mdl2(**b, output_hidden_states=True, use_cache=False, return_dict=True)
    h = o.hidden_states[-1]                 # last-layer hidden states
    # Pool and normalise in float32. The model computes in bf16, which quantises
    # cosine similarities near 1.0 to steps of about 0.004 - too coarse to rank
    # close synonyms at the printed 3-decimal precision.
    e = mean_pool(h.float(), b["attention_mask"])
    return F.normalize(e, p=2, dim=1).cpu()

# 간단 유사도 테스트
# Quick similarity check: rank a few candidate terms against one query term.
cands = ["liver cirrhosis", "hepatic cirrhosis", "chronic obstructive pulmonary disease", "hepatitis C"]
q = "hepatic cirrhosis"
qe = encode([q])
ce = encode(cands)
# encode() L2-normalises its output, so the dot product is cosine similarity.
sims = (qe @ ce.T)[0].tolist()
for c, s in sorted(zip(cands, sims), key=lambda pair: pair[1], reverse=True):
    print(f"{c:<45s} {s:.3f}")


In [None]:
# Same check with negative controls, to see how far unrelated terms fall.
cands = [
    "liver cirrhosis",                        # true synonym
    "hepatitis C",                            # related but a different disease
    "chronic obstructive pulmonary disease",  # different organ
    "banana", "quantum entanglement", "table tennis", "K-pop idol"  # unrelated
]
q = "hepatic cirrhosis"
qe = encode([q])
ce = encode(cands)
sims = (qe @ ce.T)[0].tolist()
ranked = sorted(zip(cands, sims), key=lambda pair: pair[1], reverse=True)
for c, s in ranked:
    print(f"{c:<35s} {s:.3f}")
