<a href="https://colab.research.google.com/github/Sridipta-Roy/Protein-Function-Prediction/blob/main/03_Training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
from google.colab import drive

# --- Mount Google Drive and make the project folder the working directory ---
# After the chdir, relative paths used later in the notebook resolve against
# the project folder on the mounted Drive.
drive.mount('/content/drive')
os.chdir('/content/drive/MyDrive/protein-multimodal')

Mounted at /content/drive


In [None]:
!pip install -q transformers accelerate peft bitsandbytes datasets sentencepiece

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.4/59.4 MB[0m [31m45.2 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoConfig,
    get_linear_schedule_with_warmup,
)

from peft import LoraConfig, get_peft_model
from tqdm import tqdm

In [None]:
# Root of the project on the mounted Drive; every data path below hangs off it.
PROJECT_ROOT = "/content/drive/MyDrive/protein-multimodal"
DATA_DIR      = f"{PROJECT_ROOT}/data"
PROCESSED_DIR = f"{DATA_DIR}/processed"
# Folder the *_metadata.json and *_embeddings.pt files are read from.
EMB_DIR       = f"{DATA_DIR}/embeddings"

MODEL_NAME = "microsoft/phi-2"  # checkpoint name passed to the tokenizer / causal-LM from_pretrained calls
# Use the GPU when torch can see one, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Weights of the auxiliary terms added to the LM loss in forward_with_prefix:
#   total = lm + ALIGNMENT_WEIGHT * align + CONTRASTIVE_WEIGHT * contrastive
ALIGNMENT_WEIGHT = 1.0
CONTRASTIVE_WEIGHT = 0.5   # modest extra contrastive term
# Divides the protein/text cosine-similarity logits before the cross-entropy.
CONTRASTIVE_TEMPERATURE = 0.07

In [None]:
# Metadata for each split: one JSON document per split.
with open(f"{EMB_DIR}/train_metadata.json") as train_meta_file:
    train_meta = json.load(train_meta_file)

with open(f"{EMB_DIR}/val_metadata.json") as val_meta_file:
    val_meta = json.load(val_meta_file)

# Embedding tensors are deserialised onto the CPU; batches are moved to DEVICE later.
train_emb = torch.load(f"{EMB_DIR}/train_embeddings.pt", map_location="cpu")
val_emb   = torch.load(f"{EMB_DIR}/val_embeddings.pt", map_location="cpu")

# Quick summary of what was loaded.
print(f"Train embeddings: {train_emb.shape}, dtype={train_emb.dtype}")
print(f"Val embeddings:   {val_emb.shape}, dtype={val_emb.dtype}")
print(f"Train metadata entries: {len(train_meta)}")
print(f"Val metadata entries:   {len(val_meta)}")

# The dataset indexes embeddings and metadata with the same idx, so each split
# must have exactly one metadata entry per embedding row.
for split_emb, split_meta in ((train_emb, train_meta), (val_emb, val_meta)):
    assert split_emb.shape[0] == len(split_meta)

Train embeddings: torch.Size([7949, 1280]), dtype=torch.float16
Val embeddings:   torch.Size([991, 1280]), dtype=torch.float16
Train metadata entries: 7949
Val metadata entries:   991


In [None]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# collate_fn pads batches with tokenizer.pad_token_id, so a pad token must exist.
# When the checkpoint defines none, EOS is reused as the padding token; padded
# positions are then indistinguishable from EOS by id alone and must be told
# apart through the attention mask.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print("Pad token ID:", tokenizer.pad_token_id)


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

added_tokens.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/99.0 [00:00<?, ?B/s]

Pad token ID: 50256


In [None]:
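# Superseded prompt builder, kept for reference only: it placed the full,
# untruncated sequence in the prompt. The active definition is in the next cell.
# def build_prompt_for_sequence(seq: str) -> str:
#     seq = seq.strip()
#     return (
#         "[PROTEIN]\n"
#         f"LENGTH: {len(seq)}\n"
#         f"SEQUENCE: {seq}\n"
#         "[/PROTEIN]\n"
#         "FUNCTION:\n"
#     )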

In [None]:
def build_prompt_for_sequence(seq, go_terms=None, ec_numbers=None):
    """
    Build the text prompt that precedes the function description.

    Args:
        seq: amino-acid sequence string. Its full length is reported in the
            LENGTH line, but only the first 50 characters are embedded; the
            "...(truncated)" marker is always appended, even for short sequences.
        go_terms: optional list of dicts. Only the first five entries are
            considered, and entries whose "go_term" value is missing or empty
            are skipped (a missing key no longer raises KeyError).
        ec_numbers: optional list of EC number strings, joined with ", ".

    Returns:
        The prompt string, ending with "FUNCTION:\\n" so the target text can be
        appended directly.
    """
    lines = [
        "[PROTEIN]",
        f"LENGTH: {len(seq)}",
        f"SEQUENCE: {seq[:50]}...(truncated)",
    ]

    # Optional annotation lines are only emitted when the caller supplies them.
    if go_terms:
        # Slice first, then filter, so at most five entries are ever inspected.
        names = [g.get("go_term") for g in go_terms[:5]]
        go_str = ", ".join(name for name in names if name)
        lines.append(f"GO_TERMS: {go_str}")

    if ec_numbers:
        lines.append(f"EC_NUMBERS: {', '.join(ec_numbers)}")

    lines.append("[/PROTEIN]")
    lines.append("FUNCTION:")
    return "\n".join(lines) + "\n"

In [None]:
class ProteinTextDataset(Dataset):
    """
    Pairs each protein's global embedding with a tokenized training text.

    Every item is built from:
      - the embedding row at the requested index
      - a prompt produced by build_prompt_for_sequence from the entry's
        "sequence" field (empty string when the field is absent)
      - the entry's "function" text, appended to the prompt as the target
    """
    def __init__(self, embeddings, metadata, tokenizer, max_length=384):
        self.embeddings, self.meta = embeddings, metadata
        self.tokenizer, self.max_length = tokenizer, max_length

    def __len__(self):
        # One sample per metadata entry.
        return len(self.meta)

    def __getitem__(self, idx):
        protein_vec = self.embeddings[idx]        # (D,)
        record = self.meta[idx]

        answer = record["function"]
        prompt = build_prompt_for_sequence(record.get("sequence", ""))

        # Prompt and target are tokenized together and cut off at max_length.
        tokens = self.tokenizer(
            prompt + answer,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

        sample = {"embedding": protein_vec}
        for key in ("input_ids", "attention_mask"):
            # The tokenizer returns a leading batch axis of size 1; drop it -> (T,)
            sample[key] = tokens[key].squeeze(0)
        return sample

In [None]:
def collate_fn(batch, pad_token_id=None):
    """
    Collate a list of dataset items into one right-padded batch.

    Args:
        batch: list of dicts with keys "embedding" (D,), "input_ids" (T_i,)
            and "attention_mask" (T_i,).
        pad_token_id: id used to pad "input_ids". Defaults to the notebook's
            global ``tokenizer.pad_token_id``, so the function can still be
            passed to DataLoader as a one-argument callable.

    Returns:
        dict with
          - "embedding": (B, D), the stacked embeddings
          - "input_ids": (B, T_max), padded with ``pad_token_id``
          - "attention_mask": (B, T_max), padded with 0
        Padded tensors keep the dtype of the per-item tensors.
    """
    if pad_token_id is None:
        pad_token_id = tokenizer.pad_token_id

    # Embeddings already share one size, so they only need stacking.
    embeddings = torch.stack([item["embedding"] for item in batch])  # (B, D)

    pad_sequence = torch.nn.utils.rnn.pad_sequence
    input_ids = pad_sequence(
        [item["input_ids"] for item in batch],
        batch_first=True,
        padding_value=pad_token_id,
    )                                                                # (B, T)
    attention_mask = pad_sequence(
        [item["attention_mask"] for item in batch],
        batch_first=True,
        padding_value=0,
    )                                                                # (B, T)

    return {
        "embedding": embeddings,
        "input_ids": input_ids,
        "attention_mask": attention_mask,
    }


In [None]:
# One dataset per split, sharing the same tokenizer.
train_ds, val_ds = (
    ProteinTextDataset(split_emb, split_meta, tokenizer)
    for split_emb, split_meta in ((train_emb, train_meta), (val_emb, val_meta))
)

# Settings shared by both loaders; only shuffling differs between splits.
loader_kwargs = {"batch_size": 16, "collate_fn": collate_fn}

train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)

In [None]:
class ProteinProjector(nn.Module):
    """
    Two-layer MLP (Linear -> GELU -> Linear) that maps a protein embedding of
    width ``input_dim`` to a vector of width ``output_dim``, which is used as a
    prefix token for the language model.
    """
    def __init__(self, input_dim, output_dim=2560):
        super().__init__()
        hidden_dim = 2048
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        # Kept under the attribute name `mlp` so state_dict keys stay "mlp.*".
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        projected = self.mlp(x)
        return projected

# Size the projector input from the width (second axis) of the training embeddings.
esm_dim = train_emb.shape[1]
proj = ProteinProjector(input_dim=esm_dim).to(DEVICE)

# Report where the projector lives and in which precision.
first_proj_param = next(proj.parameters())
print("Projector param dtype / device:", first_proj_param.dtype, first_proj_param.device)

Projector param dtype / device: torch.float32 cuda:0


In [None]:
# Disabled: the config is not loaded separately.
#config = AutoConfig.from_pretrained(MODEL_NAME)

# Load the base causal LM. The weights are loaded in float16 (NOT fp32); the
# training and evaluation loops run the forward pass under fp16 autocast.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,   # half-precision base weights
)
model.to(DEVICE)

# LoRA adapter settings: rank 16, alpha 32, dropout 0.05, attached to the
# modules named "q_proj" and "v_proj"; no bias terms are trained.
lora_cfg = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

# Wrap the base model with the adapters; `model` now refers to the PEFT wrapper.
model = get_peft_model(model, lora_cfg)
# NOTE(review): the base model was already moved above; this second move looks
# redundant unless the adapter weights are created on another device. Confirm.
model.to(DEVICE)

model.print_trainable_parameters()
print("Model first param dtype / device:",
      next(model.parameters()).dtype, next(model.parameters()).device)

config.json:   0%|          | 0.00/735 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/564M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

trainable params: 5,242,880 || all params: 2,784,926,720 || trainable%: 0.1883
Model first param dtype / device: torch.float16 cuda:0


In [None]:
def forward_with_prefix(batch):
    """
    Run the LM with the projected protein embedding prepended as one prefix
    token and compute the combined training loss.

    batch:
      - embedding: (B, D_esm)
      - input_ids: (B, T)
      - attention_mask: (B, T), 1 for real tokens and 0 for padding

    Returns:
      - total_loss: lm + ALIGNMENT_WEIGHT * align + CONTRASTIVE_WEIGHT * contrastive
        (scalar tensor; differentiable when gradients are enabled)
      - lm_loss_val, align_loss_val, contrastive_loss_val (Python floats for logging)

    Uses the notebook globals `model`, `proj`, `DEVICE` and the loss-weight
    constants. The notebook calls it inside a `torch.amp.autocast` context.
    """
    embedding = batch["embedding"].to(DEVICE).to(model.dtype)      # (B, D_esm)
    input_ids = batch["input_ids"].to(DEVICE)                      # (B, T)
    attn_mask = batch["attention_mask"].to(DEVICE)                 # (B, T)

    B = input_ids.size(0)

    # Project protein embedding to prefix token (B, 1, D_llm)
    prefix_token = proj(embedding).unsqueeze(1)                    # (B,1,D_llm)

    # Token embeddings from LM
    token_embeds = model.get_input_embeddings()(input_ids)         # (B,T,D_llm)

    # Concatenate prefix in front of the text tokens
    inputs_embeds = torch.cat([prefix_token, token_embeds], dim=1) # (B,T+1,D)

    # Attention mask: the prefix position is always visible
    prefix_mask = torch.ones((B, 1), dtype=attn_mask.dtype, device=DEVICE)
    full_attn_mask = torch.cat([prefix_mask, attn_mask], dim=1)    # (B,T+1)

    # Labels: -100 (ignored by the LM loss) at the prefix position AND at every
    # padded position. Padding is filled with the pad id, which equals the EOS id
    # for this tokenizer, so unmasked padding would be scored as real targets and
    # would distort the LM loss.
    # NOTE(review): with padding masked, the LM only learns to emit EOS if an EOS
    # token is part of the tokenized text (attention_mask == 1). Confirm whether
    # generation relies on it.
    token_labels = input_ids.masked_fill(attn_mask == 0, -100)     # (B,T)
    prefix_labels = torch.full(
        (B, 1),
        fill_value=-100,
        dtype=input_ids.dtype,
        device=DEVICE,
    )
    labels = torch.cat([prefix_labels, token_labels], dim=1)       # (B,T+1)

    # Forward with hidden states for alignment
    outputs = model(
        inputs_embeds=inputs_embeds,
        attention_mask=full_attn_mask,
        labels=labels,
        output_hidden_states=True,
    )

    lm_loss = outputs.loss

    # Text representation: masked mean of the last hidden layer, prefix excluded
    last_hidden = outputs.hidden_states[-1]             # (B, T+1, D_llm)
    text_hidden = last_hidden[:, 1:, :]                 # (B, T, D_llm)
    text_mask = full_attn_mask[:, 1:]                   # (B, T)

    text_mask_f = text_mask.unsqueeze(-1).float()
    # The epsilon guards against division by zero for a row with no real tokens.
    text_repr = (text_hidden * text_mask_f).sum(dim=1) / (text_mask_f.sum(dim=1) + 1e-8)  # (B, D_llm)

    # Protein representation: the prefix token
    prot_repr = prefix_token.squeeze(1)                 # (B, D_llm)

    # Normalize for cosine + contrastive
    prot_norm = F.normalize(prot_repr, dim=-1)
    text_norm = F.normalize(text_repr, dim=-1)

    # Alignment loss: 1 - mean cosine similarity of matching (protein, text) pairs
    cos_sim = (prot_norm * text_norm).sum(dim=-1)       # (B,)
    align_loss = 1.0 - cos_sim.mean()

    # Contrastive loss: symmetric cross-entropy over the in-batch similarity
    # matrix, with the diagonal (matching pairs) as the targets
    logits = prot_norm @ text_norm.t()                  # (B, B)
    logits = logits / CONTRASTIVE_TEMPERATURE
    targets = torch.arange(B, device=logits.device)

    loss_i2t = F.cross_entropy(logits, targets)
    loss_t2i = F.cross_entropy(logits.t(), targets)
    contrastive_loss = 0.5 * (loss_i2t + loss_t2i)

    total_loss = lm_loss + ALIGNMENT_WEIGHT * align_loss + CONTRASTIVE_WEIGHT * contrastive_loss

    # detached versions for logging only
    lm_loss_val = lm_loss.detach().item()
    align_loss_val = align_loss.detach().item()
    contrastive_loss_val = contrastive_loss.detach().item()

    return total_loss, lm_loss_val, align_loss_val, contrastive_loss_val

In [None]:
@torch.no_grad()
def eval_loop(dataloader):
    """
    Evaluate on every batch of ``dataloader`` with gradients disabled.

    Puts the global `model` and `proj` into eval mode (they are not switched
    back here) and returns the per-batch averages
    ``(total, lm, align, contrastive)``. An empty dataloader yields zeros.
    """
    model.eval()
    proj.eval()

    use_amp = DEVICE == "cuda"
    sums = [0.0, 0.0, 0.0, 0.0]   # total, lm, align, contrastive
    seen = 0

    for seen, batch in enumerate(dataloader, start=1):
        with torch.amp.autocast("cuda", dtype=torch.float16, enabled=use_amp):
            loss, lm_l, al_l, cl_l = forward_with_prefix(batch)
        for slot, value in enumerate((loss.item(), lm_l, al_l, cl_l)):
            sums[slot] += value

    # Avoid dividing by zero when the dataloader produced no batches.
    denom = max(1, seen)
    return tuple(running / denom for running in sums)

# Sanity check before training: precision and placement of both modules and
# the dtype of the stored embeddings.
for module_label, module in (("Model", model), ("Projector", proj)):
    first_param = next(module.parameters())
    print(f"{module_label} first param dtype/device:", first_param.dtype, first_param.device)
print("Train embeddings dtype:", train_emb.dtype)

Model first param dtype/device: torch.float16 cuda:0
Projector first param dtype/device: torch.float32 cuda:0
Train embeddings dtype: torch.float16


In [None]:
EPOCHS = 10
LR = 5e-4

# Give the optimizer only the model parameters that can receive gradients,
# plus the projector's parameters. Frozen parameters never get a gradient, so
# leaving them out does not change the updates.
trainable_params = [p for p in model.parameters() if p.requires_grad]
trainable_params += list(proj.parameters())
optimizer = torch.optim.AdamW(trainable_params, lr=LR)

# Linear warmup over the first 10% of steps, then linear decay.
num_training_steps = EPOCHS * len(train_loader)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.1 * num_training_steps),
    num_training_steps=num_training_steps,
)

use_amp = DEVICE == "cuda"
# torch.cuda.amp.GradScaler is deprecated; torch.amp.GradScaler takes the device type.
scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

for epoch in range(1, EPOCHS + 1):
    print(f"\n===== Epoch {epoch} / {EPOCHS} =====")
    # eval_loop leaves both modules in eval mode, so re-enable training mode each epoch.
    model.train()
    proj.train()

    total_loss = 0.0
    lm_accum = 0.0
    align_accum = 0.0
    contr_accum = 0.0

    for i, batch in enumerate(tqdm(train_loader, desc="Training")):
        optimizer.zero_grad(set_to_none=True)

        with torch.amp.autocast("cuda", dtype=torch.float16, enabled=use_amp):
            loss, lm_l, al_l, cl_l = forward_with_prefix(batch)

        scaler.scale(loss).backward()
        # Unscale first so clipping and the logged norm see the true gradient magnitudes.
        scaler.unscale_(optimizer)

        # Clip the LM and the projector separately. clip_grad_norm_ returns the
        # norm measured before clipping, which is reused for logging below.
        model_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        proj_norm = torch.nn.utils.clip_grad_norm_(proj.parameters(),  max_norm=1.0)

        # Log the combined gradient norm every 100 steps. A nan/inf value means
        # the gradients overflowed; the scaler then skips that optimizer step.
        if i % 100 == 0:
            total_norm = (model_norm.item() ** 2 + proj_norm.item() ** 2) ** 0.5
            print(f"\nStep {i} | Gradient norm: {total_norm:.4f}")

        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        total_loss += loss.item()
        lm_accum += lm_l
        align_accum += al_l
        contr_accum += cl_l

    # Per-batch averages for the epoch.
    n_train = len(train_loader)
    avg_train = total_loss / n_train
    avg_lm    = lm_accum / n_train
    avg_align = align_accum / n_train
    avg_contr = contr_accum / n_train

    val_total, val_lm, val_align, val_contr = eval_loop(val_loader)

    print(f"[Train] total={avg_train:.4f} | lm={avg_lm:.4f} | align={avg_align:.4f} | contr={avg_contr:.4f}")
    print(f"[Val]   total={val_total:.4f} | lm={val_lm:.4f} | align={val_align:.4f} | contr={val_contr:.4f}")


  scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE == "cuda"))



===== Epoch 1 / 10 =====


Training:   0%|          | 1/497 [00:02<17:43,  2.14s/it]


Step 0 | Gradient norm: nan


Training:  20%|██        | 101/497 [01:24<05:30,  1.20it/s]


Step 100 | Gradient norm: 0.6422


Training:  40%|████      | 201/497 [02:44<04:05,  1.21it/s]


Step 200 | Gradient norm: 1.4108


Training:  61%|██████    | 301/497 [04:06<02:43,  1.20it/s]


Step 300 | Gradient norm: 0.7016


Training:  81%|████████  | 401/497 [05:27<01:15,  1.28it/s]


Step 400 | Gradient norm: 0.7186


Training: 100%|██████████| 497/497 [06:45<00:00,  1.23it/s]


[Train] total=2.1189 | lm=1.5176 | align=0.1743 | contr=0.8539
[Val]   total=1.2420 | lm=1.1573 | align=0.0423 | contr=0.0847

===== Epoch 2 / 10 =====


Training:   0%|          | 1/497 [00:00<06:57,  1.19it/s]


Step 0 | Gradient norm: 0.4119


Training:  20%|██        | 101/497 [01:21<05:27,  1.21it/s]


Step 100 | Gradient norm: 0.3890


Training:  40%|████      | 201/497 [02:42<04:07,  1.20it/s]


Step 200 | Gradient norm: 0.3411


Training:  61%|██████    | 301/497 [04:05<02:43,  1.20it/s]


Step 300 | Gradient norm: 0.3040


Training:  81%|████████  | 401/497 [05:26<01:20,  1.20it/s]


Step 400 | Gradient norm: 0.2980


Training: 100%|██████████| 497/497 [06:45<00:00,  1.23it/s]


[Train] total=1.2443 | lm=1.1714 | align=0.0393 | contr=0.0671
[Val]   total=1.1770 | lm=1.1176 | align=0.0290 | contr=0.0606

===== Epoch 3 / 10 =====


Training:   0%|          | 1/497 [00:00<06:58,  1.18it/s]


Step 0 | Gradient norm: 0.3135


Training:  20%|██        | 101/497 [01:23<05:29,  1.20it/s]


Step 100 | Gradient norm: 0.2476


Training:  40%|████      | 201/497 [02:44<03:57,  1.24it/s]


Step 200 | Gradient norm: 0.2945


Training:  61%|██████    | 301/497 [04:05<02:43,  1.20it/s]


Step 300 | Gradient norm: 0.3983


Training:  81%|████████  | 401/497 [05:27<01:17,  1.24it/s]


Step 400 | Gradient norm: 0.3256


Training: 100%|██████████| 497/497 [06:45<00:00,  1.22it/s]


[Train] total=1.1880 | lm=1.1328 | align=0.0292 | contr=0.0519
[Val]   total=1.1403 | lm=1.0915 | align=0.0237 | contr=0.0503

===== Epoch 4 / 10 =====


Training:   0%|          | 1/497 [00:00<06:56,  1.19it/s]


Step 0 | Gradient norm: 0.3637


Training:  20%|██        | 101/497 [01:22<05:18,  1.24it/s]


Step 100 | Gradient norm: 0.2596


Training:  40%|████      | 201/497 [02:43<03:56,  1.25it/s]


Step 200 | Gradient norm: 0.3139


Training:  61%|██████    | 301/497 [04:05<02:44,  1.19it/s]


Step 300 | Gradient norm: 0.3130


Training:  81%|████████  | 401/497 [05:27<01:20,  1.20it/s]


Step 400 | Gradient norm: 0.3259


Training: 100%|██████████| 497/497 [06:45<00:00,  1.23it/s]


[Train] total=1.1543 | lm=1.1040 | align=0.0252 | contr=0.0502
[Val]   total=1.1228 | lm=1.0790 | align=0.0214 | contr=0.0447

===== Epoch 5 / 10 =====


Training:   0%|          | 1/497 [00:00<07:00,  1.18it/s]


Step 0 | Gradient norm: 0.3245


Training:  20%|██        | 101/497 [01:21<05:25,  1.21it/s]


Step 100 | Gradient norm: 0.3663


Training:  40%|████      | 201/497 [02:43<04:06,  1.20it/s]


Step 200 | Gradient norm: 0.2895


Training:  61%|██████    | 301/497 [04:06<02:42,  1.21it/s]


Step 300 | Gradient norm: 0.3209


Training:  81%|████████  | 401/497 [05:28<01:19,  1.21it/s]


Step 400 | Gradient norm: 0.3547


Training: 100%|██████████| 497/497 [06:47<00:00,  1.22it/s]


[Train] total=1.1228 | lm=1.0772 | align=0.0225 | contr=0.0463
[Val]   total=1.1083 | lm=1.0684 | align=0.0200 | contr=0.0398

===== Epoch 6 / 10 =====


Training:   0%|          | 1/497 [00:00<06:56,  1.19it/s]


Step 0 | Gradient norm: 0.3218


Training:  20%|██        | 101/497 [01:22<05:29,  1.20it/s]


Step 100 | Gradient norm: 0.2921


Training:  40%|████      | 201/497 [02:43<04:06,  1.20it/s]


Step 200 | Gradient norm: 0.2964


Training:  61%|██████    | 301/497 [04:05<02:43,  1.20it/s]


Step 300 | Gradient norm: 0.3426


Training:  81%|████████  | 401/497 [05:26<01:19,  1.21it/s]


Step 400 | Gradient norm: 0.2785


Training: 100%|██████████| 497/497 [06:44<00:00,  1.23it/s]


[Train] total=1.1012 | lm=1.0590 | align=0.0205 | contr=0.0434
[Val]   total=1.0935 | lm=1.0555 | align=0.0181 | contr=0.0397

===== Epoch 7 / 10 =====


Training:   0%|          | 1/497 [00:00<06:59,  1.18it/s]


Step 0 | Gradient norm: 0.3364


Training:  20%|██        | 101/497 [01:22<05:30,  1.20it/s]


Step 100 | Gradient norm: 0.2829


Training:  40%|████      | 201/497 [02:44<04:06,  1.20it/s]


Step 200 | Gradient norm: 0.3312


Training:  61%|██████    | 301/497 [04:06<02:43,  1.20it/s]


Step 300 | Gradient norm: 0.3498


Training:  81%|████████  | 401/497 [05:28<01:19,  1.20it/s]


Step 400 | Gradient norm: 0.3539


Training: 100%|██████████| 497/497 [06:46<00:00,  1.22it/s]


[Train] total=1.0768 | lm=1.0357 | align=0.0191 | contr=0.0441
[Val]   total=1.0829 | lm=1.0463 | align=0.0169 | contr=0.0394

===== Epoch 8 / 10 =====


Training:   0%|          | 1/497 [00:00<06:57,  1.19it/s]


Step 0 | Gradient norm: 0.3212


Training:  20%|██        | 101/497 [01:22<05:10,  1.28it/s]


Step 100 | Gradient norm: 0.3775


Training:  40%|████      | 201/497 [02:44<04:06,  1.20it/s]


Step 200 | Gradient norm: 0.2608


Training:  61%|██████    | 301/497 [04:06<02:43,  1.20it/s]


Step 300 | Gradient norm: 0.2922


Training:  81%|████████  | 401/497 [05:27<01:16,  1.25it/s]


Step 400 | Gradient norm: 0.3811


Training: 100%|██████████| 497/497 [06:45<00:00,  1.23it/s]


[Train] total=1.0612 | lm=1.0231 | align=0.0176 | contr=0.0410
[Val]   total=1.0789 | lm=1.0442 | align=0.0162 | contr=0.0369

===== Epoch 9 / 10 =====


Training:   0%|          | 1/497 [00:00<06:56,  1.19it/s]


Step 0 | Gradient norm: 0.3325


Training:  20%|██        | 101/497 [01:21<05:26,  1.21it/s]


Step 100 | Gradient norm: 0.3884


Training:  40%|████      | 201/497 [02:43<04:06,  1.20it/s]


Step 200 | Gradient norm: 0.4385


Training:  61%|██████    | 301/497 [04:05<02:43,  1.20it/s]


Step 300 | Gradient norm: 0.4008


Training:  81%|████████  | 401/497 [05:27<01:19,  1.21it/s]


Step 400 | Gradient norm: 0.3373


Training: 100%|██████████| 497/497 [06:46<00:00,  1.22it/s]


[Train] total=1.0420 | lm=1.0069 | align=0.0166 | contr=0.0370
[Val]   total=1.0743 | lm=1.0411 | align=0.0150 | contr=0.0363

===== Epoch 10 / 10 =====


Training:   0%|          | 1/497 [00:00<06:57,  1.19it/s]


Step 0 | Gradient norm: 0.2729


Training:  20%|██        | 101/497 [01:22<05:28,  1.20it/s]


Step 100 | Gradient norm: 0.3551


Training:  40%|████      | 201/497 [02:44<03:47,  1.30it/s]


Step 200 | Gradient norm: 0.5108


Training:  61%|██████    | 301/497 [04:06<02:42,  1.20it/s]


Step 300 | Gradient norm: 0.3298


Training:  81%|████████  | 401/497 [05:27<01:20,  1.20it/s]


Step 400 | Gradient norm: 0.3327


Training: 100%|██████████| 497/497 [06:45<00:00,  1.23it/s]


[Train] total=1.0348 | lm=1.0003 | align=0.0156 | contr=0.0377
[Val]   total=1.0724 | lm=1.0399 | align=0.0143 | contr=0.0364


In [None]:
# Output folder for this run and the projector checkpoint inside it.
SAVE_DIR = f"{PROJECT_ROOT}/trained_mllm_v3"
proj_path = os.path.join(SAVE_DIR, "protein_projector.pt")
os.makedirs(SAVE_DIR, exist_ok=True)

# Model first, then tokenizer, both via save_pretrained into the same folder.
for artifact in (model, tokenizer):
    artifact.save_pretrained(SAVE_DIR)

# The projector is a plain nn.Module, so only its state_dict is stored.
torch.save(proj.state_dict(), proj_path)

print("Saved model + tokenizer to:", SAVE_DIR)
print("Saved projector to:", proj_path)

Saved model + tokenizer to: /content/drive/MyDrive/protein-multimodal/trained_mllm_v3
Saved projector to: /content/drive/MyDrive/protein-multimodal/trained_mllm_v3/protein_projector.pt
