In [None]:
!pip install -U bitsandbytes

Collecting bitsandbytes
  Downloading bitsandbytes-0.46.0-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch<3,>=2.2->bitsandbytes)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch<3,>=2.2->bitsandbytes)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch<3,>=2.2->bitsandbytes)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch<3,>=2.2->bitsandbytes)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch<3,>=2.2->bitsandbytes)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-c

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [9]:
import os, torch, random
from transformers import (
    AutoModel, AutoTokenizer,
    AutoModelForCausalLM, BitsAndBytesConfig
)
from torch import nn
from torch.utils.data import DataLoader
from datasets import load_dataset
import pathlib, torch, datetime as dt

In [10]:
CKPT_DIR = pathlib.Path("/content/drive/MyDrive/fMRIwork/checkpoints")
CKPT_DIR.mkdir(exist_ok=True, parents=True)

In [None]:
# # ----------------------- 3. Toy training data --------------------------
# corpus = [
#     "The cat ran up the street.",
#     "A quick brown fox jumps over the lazy dog.",
#     "Deep learning models require lots of data.",
#     "Transformers have changed natural language processing.",
#     "Neural networks can approximate complex functions.",
#     "The sky was clear and the stars were bright.",
#     "Reinforcement learning agents learn through rewards.",
#     "Graph neural networks operate on relational data.",
#     "Diffusion models generate high-quality images.",
#     "Optimization is at the heart of machine learning."
# ] * 2000  # ~2000 sentences – tiny but enough for demo

In [2]:
path = "/content/drive/MyDrive/fMRIwork/c4training_large.txt"
with open(path, encoding="utf-8") as f:
    corpus = [line.strip() for line in f]
print(f"Loaded {len(corpus):,} lines")


Loaded 1,000,000 lines


In [3]:
corpus[:5]

['Beginners BBQ Class Taking Place in Missoula!',
 "Discussion in 'Mac OS X Lion (10.7)' started by axboi87, Jan 20, 2012.",
 'Foil plaid lycra and spandex shortall with metallic slinky insets. Attached metallic elastic belt with O-ring. Headband included. Great hip hop or jazz dance costume. Made in the USA.',
 'How many backlinks per day for new site?',
 'The Denver Board of Education opened the 2017-18 school year with an update on projects that include new construction, upgrades, heat mitigation and quality learning environments.']

In [4]:
# ------------------------------ CONFIG --------------------------------
NEOBERT_CKPT   = "bert-base-uncased"
LLAMA_CKPT     = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
BATCH_SIZE     = 64
LR             = 2e-4
NUM_EPOCHS     = 10
MAX_LEN        = 64                             # truncate long sentences
DEVICE         = "cuda" if torch.cuda.is_available() else "cpu"
SEED           = 42

torch.manual_seed(SEED); random.seed(SEED)

In [5]:
# -------------------------- Load models -----------------------------
print("Loading NeoBERT …")
neo_tok  = AutoTokenizer.from_pretrained(NEOBERT_CKPT)
neo_bert = AutoModel.from_pretrained(NEOBERT_CKPT).requires_grad_(False).to(DEVICE)

print("Loading TinyLlama …")
bnb_cfg  = BitsAndBytesConfig(load_in_8bit=True)
llm_tok  = AutoTokenizer.from_pretrained(LLAMA_CKPT)
llm = AutoModelForCausalLM.from_pretrained(
    LLAMA_CKPT,
    device_map="auto",
    torch_dtype=torch.float32,      # load weights in half-precision
).eval().requires_grad_(False)


bert_dim  = neo_bert.config.hidden_size                   # 768
llama_dim = llm.config.hidden_size                        # 2048


Loading NeoBERT …


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Loading TinyLlama …


tokenizer_config.json:   0%|          | 0.00/1.29k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/551 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/608 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.20G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

In [6]:
# ------------------------ 2. Projector net -----------------------------
class Bert2Llama(nn.Module):
    """4-layer GELU MLP: 768 ➜ 2048"""
    def __init__(self, in_d=bert_dim, out_d=llama_dim, h_d=1536, n=4):
        super().__init__()
        layers = []
        for i in range(n):
            layers.append(nn.Linear(in_d if i == 0 else h_d,
                                    out_d if i == n-1 else h_d))
            if i < n-1:
                layers.append(nn.GELU())
        self.net = nn.Sequential(*layers)

    def forward(self, x):          # x: (B, T, 768)
        return self.net(x)         #    (B, T, 2048)

projector = Bert2Llama().to(DEVICE)

In [7]:
def collate(batch):
    # Encode with both tokenizers
    neo_enc  = neo_tok(batch, padding=True, truncation=True,
                       max_length=MAX_LEN, return_tensors="pt")
    llama_enc= llm_tok(batch, padding=True, truncation=True,
                       max_length=MAX_LEN, return_tensors="pt")
    return neo_enc, llama_enc

loader = DataLoader(corpus, batch_size=BATCH_SIZE,
                    shuffle=True, collate_fn=collate, num_workers=2)

In [12]:
# ---------------------- Training loop -------------------------------
optim = torch.optim.AdamW(projector.parameters(), lr=LR)
loss_fn = nn.CrossEntropyLoss(ignore_index=llm_tok.pad_token_id)
global_step  = 0

print("\n=== Training projector ===")
for epoch in range(1, NUM_EPOCHS + 1):
    total, n_tok = 0.0, 0
    for neo_enc, llama_enc in loader:
        # Move to GPU
        neo_enc  = {k:v.to(DEVICE) for k,v in neo_enc.items()}
        llama_ids= llama_enc["input_ids"].to(DEVICE)

        with torch.no_grad():
            neo_out = neo_bert(**neo_enc).last_hidden_state  # (B,T,768)

        # Align sequence lengths
        T = min(neo_out.size(1), llama_ids.size(1))
        neo_out   = neo_out[:, :T, :]                        # (B,T,768)
        labels    = llama_ids[:, :T]                         # (B,T)

        # Forward through projector & LLaMA
        proj_emb = projector(neo_out)                  # (B, T, 2048), float32
        proj_emb = proj_emb.to(dtype=next(llm.parameters()).dtype)
        out = llm(inputs_embeds=proj_emb, labels=labels)

        loss = out.loss
        loss.backward(); optim.step(); optim.zero_grad()

        global_step += 1
        total  += loss.item() * labels.numel()
        n_tok  += labels.numel()

        if global_step % 1_000 == 0:
            torch.save(
                {
                    "epoch": epoch,
                    "step": global_step,
                    "proj_state": projector.state_dict(),
                    "opt_state":  optim.state_dict(),
                    "rng_state":  torch.get_rng_state(),
                },
                CKPT_DIR / f"step_{global_step:07d}.pt"
            )

    # Epoch checkpoint
    torch.save(
        {
            "epoch": epoch,
            "proj_state": projector.state_dict(),
            "opt_state":  optim.state_dict(),
            "rng_state":  torch.get_rng_state(),
        },
        CKPT_DIR / f"epoch_{epoch:02d}.pt"
    )

    ppl = torch.exp(torch.tensor(total / n_tok))
    print(f"Epoch {epoch}: loss {total/n_tok:.4f}  |  ppl {ppl:.2f}")


=== Training projector ===


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f0b280dfc40>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1582, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/usr/lib/python3.11/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/popen_fork.py", line 40, in wait
    if not wait([self.sentinel], timeout):
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/connection.py", line 948, in wait
    ready = selector.select(timeout)
            ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/selectors.py", line 415, in select
    fd_event_list = self._selector.poll(timeout)
     

KeyboardInterrupt: 

In [None]:
def neo2llama(text, max_new=0):
    """Feed NeoBERT → projector → LLaMA.
       If max_new==0 → teacher forcing (should echo).
       Else           → free generation (should still echo first part)."""
    with torch.no_grad():
        neo_in = neo_tok(text, return_tensors="pt").to(DEVICE)
        bert_h = neo_bert(**neo_in).last_hidden_state       # (1, B_len, 768)
        bert_h = bert_h.to(dtype=projector.net[0].weight.dtype)
        proj_h = projector(bert_h)

        llama_in = llm_tok(text, return_tensors="pt").to(DEVICE)
        labels   = llama_in["input_ids"]  # (1, L_len)


        mask   = torch.ones(proj_h.size()[:2], dtype=torch.long, device=DEVICE)


        if max_new == 0:
            T = min(proj_h.size(1), labels.size(1))
            proj_h = proj_h[:, :T, :]
            labels = labels[:, :T]


            out    = llm(inputs_embeds=proj_h, labels=labels)
            ids    = out.logits.argmax(-1)


        else:
            ids = llm.generate(
                inputs_embeds=proj_h,
                input_ids=torch.full(
                  (1, proj_h.size(1)), fill_value=llm_tok.bos_token_id, device=DEVICE
                ),  # <-- dummy IDs to satisfy HF internals
                attention_mask=mask,
                max_new_tokens=max_new,
                eos_token_id=llm_tok.eos_token_id,
                pad_token_id=llm_tok.pad_token_id,
                do_sample=False,
            )

    return llm_tok.decode(ids[0], skip_special_tokens=True)



In [None]:
test_sentences = [
    "Quantum computers promise speed-ups that classical machines cannot match.",
    "The Aurora Borealis shimmered like green silk across the night sky.",
    "After twelve grueling innings, the underdogs finally clinched the pennant.",
    "Please email the revised PDF to marketing by 09:30 a.m. tomorrow.",
    "Serendipity often favors the curious, the prepared, and the persistent.",
    "“We’re out of espresso beans,” whispered the barista, glancing at the queue.",
    "A tiny gecko clung effortlessly to the laboratory’s glass wall.",
    "GDP growth slowed to 2.1 percent in Q2, missing analysts’ expectations.",
    "Einstein’s handwriting was notoriously hard to read, even for his assistants.",
    "The ancient oak, felled by lightning, revealed concentric rings of drought.",
    "Machine-learning models excel at interpolation but struggle with extrapolation.",
    "During the blackout, neighbors gathered in the hallway and shared candles.",
    "My passport expires in eleven months; I’d better renew it soon.",
    "Some mushrooms glow faintly in the dark due to bioluminescent enzymes.",
    "Arctic sea-ice extent hit a record low, alarming climate scientists worldwide.",
    "He typed rm -rf /*, realized the mistake, and pulled the network cable.",
    "To brew oolong correctly, let the kettle cool to eighty-five degrees Celsius.",
    "The violin’s G string snapped mid-concerto, yet the soloist played on.",
    "Cicadas emerge every seventeen years, blanketing trees in vibrating sound.",
    "Static friction exceeds kinetic friction—hence the sudden jerk when boxes slide.",
    "In 2029 NASA aims to redirect a small asteroid, proving planetary-defense tech.",
    "Her résumé boasted fluency in Sinhala, Japanese, and American Sign Language.",
    "The bookstore’s resident cat sleeps on whatever stack you’re trying to reach.",
    "Deep-sea vents host microbes that survive without sunlight, using chemosynthesis.",
    "“Ctrl + Z is my best friend,” the developer joked after a bad deploy.",
    "Sourdough starters are tiny ecosystems of yeast, lactobacilli, and hype.",
    "The committee postponed the vote until October 3rd, citing budget ambiguities.",
    "Long-form journalism thrives when readers trust slow, meticulous storytelling.",
    "A single photon triggered the avalanche photodiode, logging a digital ‘1’.",
    "The chef garnished the ramen with nori, scallions, and a seven-minute egg.",
    "For centuries, cartographers added mythical islands to fill white space.",
    "The smartwatch recorded 10,438 steps, but she still felt sedentary.",
    "PyTorch 2.3 introduced compile(), merging graph-mode speed with eager simplicity.",
    "Call me old-fashioned, but I still sync MP3s instead of streaming music.",
    "Two koalas dozed in the eucalyptus, oblivious to the drone overhead.",
    "His haiku read: autumn wind / carries lost umbrellas / into gray rivers.",
    "The conference Wi-Fi password was taped, in Comic Sans, to every chair.",
    "Riemann’s hypothesis remains unproven, yet primes obey its silent rhythm.",
    "“No spoilers!” she yelled, scrolling past the season-finale tweets.",
    "Solar panels blanketed the barn roof, outshining the rusty weather vane.",
    "A cashierless store feels eerie until you realize how fast you can leave."
]


In [None]:
#-------------------------- Generation Test ------------------------------
for sentence in test_sentences:
  print("Input sentence:   ", sentence)
  print(f"Teacher-forcing:  , {neo2llama(sentence, max_new=0)}\n")

Input sentence:    Quantum computers promise speed-ups that classical machines cannot match.
Teacher-forcing:  , Aotic computers promise--upsss that machines match match.

Input sentence:    The Aurora Borealis shimmered like green silk across the night sky.
Teacher-forcing:  , The blue Solb Nova sh shished like green w w across across night sky.

Input sentence:    After twelve grueling innings, the underdogs finally clinched the pennant.
Teacher-forcing:  , The a consecutive grulinging,ations the thesaw finally, the the the.

Input sentence:    Please email the revised PDF to marketing by 09:30 a.m. tomorrow.
Teacher-forcing:  , Please register your PDF version with customerseting by 0 033 a a..

Input sentence:    Serendipity often favors the curious, the prepared, and the persistent.
Teacher-forcing:  , Aorenceisity helpsves the curious curious the prepared,, and the persistent.

Input sentence:    “We’re out of espresso beans,” whispered the barista, glancing at the queue.
Teacher