In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [2]:
import sys
import os

sys.path.append('../src')

In [3]:
import argparse
import torch
import pytorch_warmup as warmup
import wandb
from tqdm import tqdm
import yaml
import sys
import os
import matplotlib.pyplot as plt

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F

import torch
import torch.optim as optim
import torch.nn as nn
import numpy as np

from mamba_ssm.models.config_mamba import MambaConfig

from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorWithPadding
from torch.utils.data import DataLoader
from peft import LoraConfig, get_peft_model

from utils import print_model_size, fix_seed
# from models.MambaWithEmbeddings import MambaLMHeadModelWithEmbeddings
# from training_functions import add_special_token, inference, unpack_batch
from models.PromptTuning import PromptTuning
from models.PeriodicTuning import PeriodicTuning
from data.Needle import NeedleCfg, make_dataloader


  from .autonotebook import tqdm as notebook_tqdm


In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [6]:
def get_logits(out):
    if hasattr(out, "logits"):
        return out.logits
    return out


def paste_pos_for_model(model, lengths: torch.Tensor) -> torch.Tensor:
    """Return positions (0-based) of the PASTE token in the *model's* logits tensor.

    `Needle` defines PASTE as the last real token, so without any prefix it is at `lengths-1`.
    Prompt/periodic wrappers change sequence length -> this index must be shifted.
    """
    # Base model: PASTE is the last real token.
    t = lengths - 1

    # PeriodicTuning: token positions are remapped as in `PeriodicTuning.forward`.
    if hasattr(model, "period") and hasattr(model, "n_prompt"):
        K = int(model.period)
        P = int(model.n_prompt)
        block = t // K
        offset = t % K
        return block * (K + P) + P + offset

    # PromptTuning: a fixed prefix of `n_prompt` tokens is prepended.
    if hasattr(model, "n_prompt") and hasattr(model, "prompt_embed"):
        return t + int(model.n_prompt)

    return t


def logits_at_paste(logits: torch.Tensor, paste_pos: torch.Tensor) -> torch.Tensor:
    b_idx = torch.arange(logits.size(0), device=logits.device)
    return logits[b_idx, paste_pos, :]  # [B, V]

@torch.no_grad()
def eval_loop(dl, model, criterion):
    model.eval()
    tot_loss, tot_correct, tot_n = 0.0, 0, 0

    for x, y, attn_mask, lengths in tqdm(dl, desc="test", leave=False):
        x = x.to(device)
        y = y.to(device)
        attn_mask = attn_mask.to(device)
        lengths = lengths.to(device)

        out = model(input_ids=x, attention_mask=attn_mask)
        logits = get_logits(out)

        paste_pos = paste_pos_for_model(model, lengths)
        last_logits = logits_at_paste(logits, paste_pos)  # [B, V]
        loss = criterion(last_logits, y)

        bs = x.size(0)
        tot_loss += loss.item() * bs
        tot_correct += (last_logits.argmax(dim=-1) == y).sum().item()
        tot_n += bs

    return tot_loss / max(tot_n, 1), tot_correct / max(tot_n, 1)

## Training model

In [7]:
VOCAB_SIZE = 16

In [8]:
cfg_train = NeedleCfg(vocab_size=VOCAB_SIZE, max_seq_len=16, seed=42, vary_length=True)
cfg_test = NeedleCfg(vocab_size=VOCAB_SIZE, max_seq_len=256, seed=42, vary_length=False)
train_dataloader = make_dataloader(n_examples=16_000, batch_size=512, cfg=cfg_train, shuffle=True, num_workers=8)
test_dataloader = make_dataloader(n_examples=100_000, batch_size=512, cfg=cfg_test, shuffle=False, num_workers=8)

In [9]:
from transformers import MambaConfig, MambaModel, MambaForCausalLM

configuration = MambaConfig(
    vocab_size = VOCAB_SIZE + 3,
    hidden_size = 8,
    state_size = 4,
    num_hidden_layers = 2,
)

fix_seed(42)
model = MambaForCausalLM(configuration)

In [10]:
model

MambaForCausalLM(
  (backbone): MambaModel(
    (embeddings): Embedding(19, 8)
    (layers): ModuleList(
      (0-1): 2 x MambaBlock(
        (norm): MambaRMSNorm(8, eps=1e-05)
        (mixer): MambaMixer(
          (conv1d): Conv1d(16, 16, kernel_size=(4,), stride=(1,), padding=(3,), groups=16)
          (act): SiLUActivation()
          (in_proj): Linear(in_features=8, out_features=32, bias=False)
          (x_proj): Linear(in_features=16, out_features=9, bias=False)
          (dt_proj): Linear(in_features=1, out_features=16, bias=True)
          (out_proj): Linear(in_features=16, out_features=8, bias=False)
        )
      )
    )
    (norm_f): MambaRMSNorm(8, eps=1e-05)
  )
  (lm_head): Linear(in_features=8, out_features=19, bias=False)
)

In [11]:
model.to(device)

opt = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=1e-2)
criterion = torch.nn.CrossEntropyLoss()

epochs = 15
for ep in range(1, epochs + 1):
    model.train()
    tot_loss, tot_correct, tot_n = 0.0, 0, 0

    for x, y, attn_mask, lengths in tqdm(train_dataloader, desc=f"train {ep}/{epochs}"):
        x = x.to(device)
        y = y.to(device)
        attn_mask = attn_mask.to(device)
        lengths = lengths.to(device)

        opt.zero_grad(set_to_none=True)

        out = model(input_ids=x, attention_mask=attn_mask)
        logits = get_logits(out)

        paste_pos = paste_pos_for_model(model, lengths)
        last_logits = logits_at_paste(logits, paste_pos)  # [B, V]
        loss = criterion(last_logits, y)
        loss.backward()
        opt.step()

        bs = x.size(0)
        tot_loss += loss.item() * bs
        tot_correct += (last_logits.argmax(dim=-1) == y).sum().item()
        tot_n += bs

    train_loss = tot_loss / max(tot_n, 1)
    train_acc = tot_correct / max(tot_n, 1)

    test_loss, test_acc = eval_loop(test_dataloader, model, criterion)
    print(f"epoch {ep}: train_loss={train_loss:.4f} train_acc={train_acc:.3f} | test_loss={test_loss:.4f} test_acc={test_acc:.3f}")


train 1/15: 100%|██████████| 32/32 [00:01<00:00, 21.23it/s]
                                                        

epoch 1: train_loss=2.7163 train_acc=0.206 | test_loss=2.8685 test_acc=0.070


train 2/15: 100%|██████████| 32/32 [00:00<00:00, 34.30it/s]
                                                        

epoch 2: train_loss=2.0748 train_acc=0.482 | test_loss=2.8236 test_acc=0.078


train 3/15: 100%|██████████| 32/32 [00:00<00:00, 37.15it/s]
                                                        

epoch 3: train_loss=0.7316 train_acc=0.839 | test_loss=4.6736 test_acc=0.148


train 4/15: 100%|██████████| 32/32 [00:00<00:00, 32.74it/s]
                                                        

epoch 4: train_loss=0.0679 train_acc=1.000 | test_loss=5.7545 test_acc=0.188


train 5/15: 100%|██████████| 32/32 [00:00<00:00, 34.02it/s]
                                                        

epoch 5: train_loss=0.0230 train_acc=1.000 | test_loss=6.0507 test_acc=0.213


train 6/15: 100%|██████████| 32/32 [00:01<00:00, 31.42it/s]
                                                        

epoch 6: train_loss=0.0138 train_acc=1.000 | test_loss=6.2124 test_acc=0.224


train 7/15: 100%|██████████| 32/32 [00:00<00:00, 37.08it/s]
                                                        

epoch 7: train_loss=0.0096 train_acc=1.000 | test_loss=6.2585 test_acc=0.237


train 8/15: 100%|██████████| 32/32 [00:00<00:00, 37.72it/s]
                                                        

epoch 8: train_loss=0.0071 train_acc=1.000 | test_loss=6.3631 test_acc=0.242


train 9/15: 100%|██████████| 32/32 [00:00<00:00, 38.04it/s]
                                                        

epoch 9: train_loss=0.0055 train_acc=1.000 | test_loss=6.4146 test_acc=0.248


train 10/15: 100%|██████████| 32/32 [00:01<00:00, 31.73it/s]
                                                        

epoch 10: train_loss=0.0044 train_acc=1.000 | test_loss=6.4269 test_acc=0.254


train 11/15: 100%|██████████| 32/32 [00:00<00:00, 35.93it/s]
                                                        

epoch 11: train_loss=0.0036 train_acc=1.000 | test_loss=6.4458 test_acc=0.259


train 12/15: 100%|██████████| 32/32 [00:00<00:00, 37.99it/s]
                                                        

epoch 12: train_loss=0.0030 train_acc=1.000 | test_loss=6.4142 test_acc=0.266


train 13/15: 100%|██████████| 32/32 [00:00<00:00, 34.19it/s]
                                                        

epoch 13: train_loss=0.0026 train_acc=1.000 | test_loss=6.4142 test_acc=0.270


train 14/15: 100%|██████████| 32/32 [00:00<00:00, 34.48it/s]
                                                        

epoch 14: train_loss=0.0022 train_acc=1.000 | test_loss=6.3935 test_acc=0.275


train 15/15: 100%|██████████| 32/32 [00:00<00:00, 33.07it/s]
                                                        

epoch 15: train_loss=0.0019 train_acc=1.000 | test_loss=6.4047 test_acc=0.277




## Training prompt

In [11]:
fix_seed(42)
model_with_prompt = PromptTuning(model, n_prompt=20)
model_with_prompt.to(device)

PromptTuning(
  (base_model): MambaForCausalLM(
    (backbone): MambaModel(
      (embeddings): Embedding(19, 8)
      (layers): ModuleList(
        (0-1): 2 x MambaBlock(
          (norm): MambaRMSNorm(8, eps=1e-05)
          (mixer): MambaMixer(
            (conv1d): Conv1d(16, 16, kernel_size=(4,), stride=(1,), padding=(3,), groups=16)
            (act): SiLUActivation()
            (in_proj): Linear(in_features=8, out_features=32, bias=False)
            (x_proj): Linear(in_features=16, out_features=9, bias=False)
            (dt_proj): Linear(in_features=1, out_features=16, bias=True)
            (out_proj): Linear(in_features=16, out_features=8, bias=False)
          )
        )
      )
      (norm_f): MambaRMSNorm(8, eps=1e-05)
    )
    (lm_head): Linear(in_features=8, out_features=19, bias=False)
  )
  (prompt_embed): Embedding(20, 8)
)

In [12]:
model_with_prompt.base_model.backbone.embeddings.weight

Parameter containing:
tensor([[ 2.4183e-01,  1.5556e-01,  7.2920e-01, -1.2094e+00,  1.2886e+00,
          4.8679e-02, -1.7976e-01, -8.4520e-03],
        [-1.5212e-01,  6.0555e-02, -1.1119e+00,  8.5851e-01,  4.6452e-01,
         -9.6510e-01,  5.1332e-01,  8.7872e-01],
        [ 9.4900e-02,  8.0048e-01,  1.2665e+00, -5.3615e-01, -1.1320e+00,
          2.6695e-01, -1.9365e-01,  4.9843e-01],
        [ 7.1381e-01,  4.9384e-01,  5.5582e-01,  1.4978e-01,  7.9409e-01,
         -1.1119e+00, -1.6840e-01,  9.7445e-01],
        [-8.1004e-02,  5.9683e-01,  3.6063e-01,  3.6787e-01, -1.1515e+00,
         -1.1423e+00,  7.4047e-01,  9.3772e-01],
        [-3.3474e-01, -8.0430e-01,  5.3718e-01,  1.0140e+00,  1.2552e+00,
         -2.2025e-01,  3.1017e-01,  9.3360e-01],
        [ 1.7565e-01,  1.0261e+00, -1.5965e-01, -1.5328e+00, -4.2284e-01,
         -7.9183e-01,  2.9912e-03,  5.6345e-01],
        [-4.8277e-01,  1.6674e-01, -1.2614e+00, -7.1049e-01, -6.6729e-01,
         -7.1693e-01,  1.2395e+00,  4.6473e

In [13]:
model_with_prompt.prompt_embed.weight

Parameter containing:
tensor([[-3.8011e-02,  4.5715e-03,  4.9719e-04, -6.9190e-03,  5.7367e-03,
         -1.4617e-02,  3.4964e-03, -2.1879e-02],
        [-3.2043e-02,  2.7058e-02,  2.5777e-02,  1.0459e-03, -3.0937e-02,
          1.5134e-02,  1.5510e-02,  4.0531e-02],
        [ 7.1635e-04,  2.4118e-03, -1.6113e-02, -4.1515e-03, -1.8639e-02,
         -3.1819e-02, -2.2720e-02, -1.0452e-02],
        [-1.0375e-02, -3.0026e-02, -3.8533e-02,  2.5570e-03,  2.0458e-02,
         -1.1116e-02,  1.4085e-02,  1.4198e-02],
        [ 3.5488e-02, -1.8431e-02,  1.9249e-02, -6.7403e-03, -2.3507e-02,
          7.1611e-03,  9.5754e-03,  2.7074e-02],
        [ 1.0521e-02,  4.2241e-02, -1.0415e-02, -1.8640e-02,  3.7032e-03,
          2.1374e-02,  2.6131e-02,  9.1967e-03],
        [-1.6293e-02, -2.0425e-02, -9.8985e-03, -1.1845e-02,  3.0863e-03,
          8.8153e-03, -2.9658e-03, -4.6369e-02],
        [-7.9599e-03,  2.1610e-02, -3.5617e-02,  3.0161e-02,  6.1886e-03,
         -1.0006e-02,  2.0700e-02,  3.3793e

In [14]:
cfg_train_prompt = NeedleCfg(vocab_size=VOCAB_SIZE, max_seq_len=256, seed=42, vary_length=True)
cfg_test_prompt = NeedleCfg(vocab_size=VOCAB_SIZE, max_seq_len=512, seed=42, vary_length=False)
train_dataloader_prompt = make_dataloader(n_examples=100_000, batch_size=512, shuffle=False, cfg=cfg_train_prompt, num_workers=8)
test_dataloader_prompt = make_dataloader(n_examples=10_000, batch_size=512, shuffle=False, cfg=cfg_test_prompt, num_workers=8)

In [15]:
model_with_prompt.to(device)

opt = torch.optim.Adam([p for p in model_with_prompt.parameters() if p.requires_grad], lr=3e-3)
criterion = torch.nn.CrossEntropyLoss()

epochs = 15
for ep in range(1, epochs + 1):
    model_with_prompt.train()
    tot_loss, tot_correct, tot_n = 0.0, 0, 0

    for x, y, attn_mask, lengths in tqdm(train_dataloader_prompt, desc=f"train {ep}/{epochs}"):
        x = x.to(device)
        y = y.to(device)
        attn_mask = attn_mask.to(device)
        lengths = lengths.to(device)

        opt.zero_grad(set_to_none=True)

        out = model_with_prompt(input_ids=x, attention_mask=attn_mask)
        logits = get_logits(out)

        paste_pos = paste_pos_for_model(model_with_prompt, lengths)
        last_logits = logits_at_paste(logits, paste_pos)

        loss = criterion(last_logits, y)
        loss.backward()
        opt.step()

        bs = x.size(0)
        tot_loss += loss.item() * bs
        tot_correct += (last_logits.argmax(dim=-1) == y).sum().item()
        tot_n += bs

    train_loss = tot_loss / max(tot_n, 1)
    train_acc = tot_correct / max(tot_n, 1)

    test_loss, test_acc = eval_loop(test_dataloader_prompt, model_with_prompt, criterion)
    print(f"epoch {ep}: train_loss={train_loss:.4f} train_acc={train_acc:.3f} | test_loss={test_loss:.4f} test_acc={test_acc:.3f}")


train 1/15:   0%|          | 0/196 [00:00<?, ?it/s]

train 1/15: 100%|██████████| 196/196 [00:02<00:00, 74.11it/s] 
                                                    

epoch 1: train_loss=1.8537 train_acc=0.642 | test_loss=8.6868 test_acc=0.169


train 2/15: 100%|██████████| 196/196 [00:03<00:00, 63.38it/s]
                                                    

epoch 2: train_loss=1.7317 train_acc=0.651 | test_loss=8.6734 test_acc=0.170


train 3/15: 100%|██████████| 196/196 [00:03<00:00, 58.93it/s]
                                                     

epoch 3: train_loss=1.7068 train_acc=0.653 | test_loss=8.6668 test_acc=0.170


train 4/15: 100%|██████████| 196/196 [00:02<00:00, 65.66it/s]
                                                     

epoch 4: train_loss=1.6967 train_acc=0.654 | test_loss=8.6719 test_acc=0.170


train 5/15: 100%|██████████| 196/196 [00:03<00:00, 63.07it/s]
                                                     

epoch 5: train_loss=1.6606 train_acc=0.657 | test_loss=8.6661 test_acc=0.170


train 6/15: 100%|██████████| 196/196 [00:02<00:00, 67.26it/s]
                                                     

epoch 6: train_loss=1.6455 train_acc=0.657 | test_loss=8.6586 test_acc=0.170


train 7/15: 100%|██████████| 196/196 [00:03<00:00, 62.21it/s]
                                                     

epoch 7: train_loss=1.6421 train_acc=0.658 | test_loss=8.6556 test_acc=0.171


train 8/15: 100%|██████████| 196/196 [00:03<00:00, 62.33it/s]
                                                     

epoch 8: train_loss=1.6378 train_acc=0.659 | test_loss=8.6497 test_acc=0.171


train 9/15: 100%|██████████| 196/196 [00:03<00:00, 62.15it/s]
                                                     

epoch 9: train_loss=1.6338 train_acc=0.659 | test_loss=8.6455 test_acc=0.170


train 10/15: 100%|██████████| 196/196 [00:03<00:00, 60.22it/s]
                                                     

epoch 10: train_loss=1.6324 train_acc=0.659 | test_loss=8.6405 test_acc=0.171


train 11/15: 100%|██████████| 196/196 [00:03<00:00, 63.53it/s]
                                                     

epoch 11: train_loss=1.6240 train_acc=0.659 | test_loss=8.6438 test_acc=0.171


train 12/15: 100%|██████████| 196/196 [00:03<00:00, 61.09it/s]
                                                     

epoch 12: train_loss=1.6253 train_acc=0.660 | test_loss=8.6401 test_acc=0.171


train 13/15: 100%|██████████| 196/196 [00:03<00:00, 61.27it/s]
                                                     

epoch 13: train_loss=1.6189 train_acc=0.659 | test_loss=8.6437 test_acc=0.171


train 14/15: 100%|██████████| 196/196 [00:03<00:00, 61.38it/s]
                                                     

epoch 14: train_loss=1.6176 train_acc=0.659 | test_loss=8.6450 test_acc=0.171


train 15/15: 100%|██████████| 196/196 [00:03<00:00, 60.83it/s]
                                                     

epoch 15: train_loss=1.6162 train_acc=0.660 | test_loss=8.6446 test_acc=0.171




In [16]:
model_with_prompt.base_model.backbone.embeddings.weight

Parameter containing:
tensor([[ 2.4183e-01,  1.5556e-01,  7.2920e-01, -1.2094e+00,  1.2886e+00,
          4.8679e-02, -1.7976e-01, -8.4520e-03],
        [-1.5212e-01,  6.0555e-02, -1.1119e+00,  8.5851e-01,  4.6452e-01,
         -9.6510e-01,  5.1332e-01,  8.7872e-01],
        [ 9.4900e-02,  8.0048e-01,  1.2665e+00, -5.3615e-01, -1.1320e+00,
          2.6695e-01, -1.9365e-01,  4.9843e-01],
        [ 7.1381e-01,  4.9384e-01,  5.5582e-01,  1.4978e-01,  7.9409e-01,
         -1.1119e+00, -1.6840e-01,  9.7445e-01],
        [-8.1004e-02,  5.9683e-01,  3.6063e-01,  3.6787e-01, -1.1515e+00,
         -1.1423e+00,  7.4047e-01,  9.3772e-01],
        [-3.3474e-01, -8.0430e-01,  5.3718e-01,  1.0140e+00,  1.2552e+00,
         -2.2025e-01,  3.1017e-01,  9.3360e-01],
        [ 1.7565e-01,  1.0261e+00, -1.5965e-01, -1.5328e+00, -4.2284e-01,
         -7.9183e-01,  2.9912e-03,  5.6345e-01],
        [-4.8277e-01,  1.6674e-01, -1.2614e+00, -7.1049e-01, -6.6729e-01,
         -7.1693e-01,  1.2395e+00,  4.6473e

In [17]:
model_with_prompt.prompt_embed.weight

Parameter containing:
tensor([[-0.0280,  0.1420,  0.0309, -0.0155,  0.2733,  0.0565,  0.1327,  0.0802],
        [-0.0886,  0.1137,  0.0175,  0.0678,  0.0887,  0.0892,  0.0547,  0.1461],
        [-0.0112, -0.0725, -0.1391,  0.0138, -0.1602, -0.0238, -0.1162, -0.0665],
        [ 0.0070,  0.0904,  0.0047, -0.0330,  0.1343, -0.0117,  0.0825,  0.0401],
        [-0.0557, -0.0796, -0.0765, -0.0560, -0.0400, -0.1036, -0.0578, -0.0970],
        [-0.0070,  0.1483,  0.0082, -0.0291,  0.1745,  0.0224,  0.1041,  0.0654],
        [-0.0194, -0.0850,  0.1024,  0.0701, -0.1315, -0.0092, -0.1197, -0.1099],
        [-0.0414, -0.0475,  0.0071, -0.0205,  0.1008, -0.0452,  0.0554, -0.0098],
        [ 0.0987,  0.0255, -0.0739, -0.0225, -0.1081,  0.0069,  0.0287,  0.0370],
        [ 0.0178,  0.0631, -0.0518,  0.0487,  0.1245,  0.0597,  0.1480, -0.0291],
        [-0.0703, -0.0044, -0.0640,  0.0743,  0.1457,  0.0636,  0.0935, -0.0589],
        [-0.0670, -0.1153, -0.0171, -0.0914, -0.0391, -0.0118, -0.0432, -0.1

## Training periodic

In [18]:
fix_seed(42)
model_with_periodic = PeriodicTuning(model, n_prompt=20, period=40)

In [19]:
model_with_periodic.base_model.backbone.embeddings.weight

Parameter containing:
tensor([[ 2.4183e-01,  1.5556e-01,  7.2920e-01, -1.2094e+00,  1.2886e+00,
          4.8679e-02, -1.7976e-01, -8.4520e-03],
        [-1.5212e-01,  6.0555e-02, -1.1119e+00,  8.5851e-01,  4.6452e-01,
         -9.6510e-01,  5.1332e-01,  8.7872e-01],
        [ 9.4900e-02,  8.0048e-01,  1.2665e+00, -5.3615e-01, -1.1320e+00,
          2.6695e-01, -1.9365e-01,  4.9843e-01],
        [ 7.1381e-01,  4.9384e-01,  5.5582e-01,  1.4978e-01,  7.9409e-01,
         -1.1119e+00, -1.6840e-01,  9.7445e-01],
        [-8.1004e-02,  5.9683e-01,  3.6063e-01,  3.6787e-01, -1.1515e+00,
         -1.1423e+00,  7.4047e-01,  9.3772e-01],
        [-3.3474e-01, -8.0430e-01,  5.3718e-01,  1.0140e+00,  1.2552e+00,
         -2.2025e-01,  3.1017e-01,  9.3360e-01],
        [ 1.7565e-01,  1.0261e+00, -1.5965e-01, -1.5328e+00, -4.2284e-01,
         -7.9183e-01,  2.9912e-03,  5.6345e-01],
        [-4.8277e-01,  1.6674e-01, -1.2614e+00, -7.1049e-01, -6.6729e-01,
         -7.1693e-01,  1.2395e+00,  4.6473e

In [20]:
model_with_periodic.prompt_embed.weight

Parameter containing:
tensor([[-3.8011e-02,  4.5715e-03,  4.9719e-04, -6.9190e-03,  5.7367e-03,
         -1.4617e-02,  3.4964e-03, -2.1879e-02],
        [-3.2043e-02,  2.7058e-02,  2.5777e-02,  1.0459e-03, -3.0937e-02,
          1.5134e-02,  1.5510e-02,  4.0531e-02],
        [ 7.1635e-04,  2.4118e-03, -1.6113e-02, -4.1515e-03, -1.8639e-02,
         -3.1819e-02, -2.2720e-02, -1.0452e-02],
        [-1.0375e-02, -3.0026e-02, -3.8533e-02,  2.5570e-03,  2.0458e-02,
         -1.1116e-02,  1.4085e-02,  1.4198e-02],
        [ 3.5488e-02, -1.8431e-02,  1.9249e-02, -6.7403e-03, -2.3507e-02,
          7.1611e-03,  9.5754e-03,  2.7074e-02],
        [ 1.0521e-02,  4.2241e-02, -1.0415e-02, -1.8640e-02,  3.7032e-03,
          2.1374e-02,  2.6131e-02,  9.1967e-03],
        [-1.6293e-02, -2.0425e-02, -9.8985e-03, -1.1845e-02,  3.0863e-03,
          8.8153e-03, -2.9658e-03, -4.6369e-02],
        [-7.9599e-03,  2.1610e-02, -3.5617e-02,  3.0161e-02,  6.1886e-03,
         -1.0006e-02,  2.0700e-02,  3.3793e

In [21]:
model_with_periodic.to(device)

opt = torch.optim.Adam([p for p in model_with_periodic.parameters() if p.requires_grad], lr=1e-3)
criterion = torch.nn.CrossEntropyLoss()

epochs = 15
for ep in range(1, epochs + 1):
    model_with_periodic.train()
    tot_loss, tot_correct, tot_n = 0.0, 0, 0

    for x, y, attn_mask, lengths in tqdm(train_dataloader_prompt, desc=f"train {ep}/{epochs}"):
        x = x.to(device)
        y = y.to(device)
        attn_mask = attn_mask.to(device)
        lengths = lengths.to(device)

        opt.zero_grad(set_to_none=True)

        out = model_with_periodic(input_ids=x, attention_mask=attn_mask)
        logits = get_logits(out)

        paste_pos = paste_pos_for_model(model_with_periodic, lengths)
        last_logits = logits_at_paste(logits, paste_pos)

        loss = criterion(last_logits, y)
        loss.backward()
        opt.step()

        bs = x.size(0)
        tot_loss += loss.item() * bs
        tot_correct += (last_logits.argmax(dim=-1) == y).sum().item()
        tot_n += bs

    train_loss = tot_loss / max(tot_n, 1)
    train_acc = tot_correct / max(tot_n, 1)

    test_loss, test_acc = eval_loop(test_dataloader_prompt, model_with_periodic, criterion)
    print(f"epoch {ep}: train_loss={train_loss:.4f} train_acc={train_acc:.3f} | test_loss={test_loss:.4f} test_acc={test_acc:.3f}")


train 1/15: 100%|██████████| 196/196 [00:03<00:00, 58.61it/s]
                                                     

epoch 1: train_loss=1.5627 train_acc=0.633 | test_loss=4.2040 test_acc=0.220


train 2/15: 100%|██████████| 196/196 [00:03<00:00, 59.39it/s]
                                                     

epoch 2: train_loss=1.2288 train_acc=0.667 | test_loss=4.3582 test_acc=0.225


train 3/15: 100%|██████████| 196/196 [00:03<00:00, 60.23it/s]
                                                     

epoch 3: train_loss=1.1356 train_acc=0.692 | test_loss=4.3665 test_acc=0.235


train 4/15: 100%|██████████| 196/196 [00:03<00:00, 61.34it/s]
                                                     

epoch 4: train_loss=1.0785 train_acc=0.711 | test_loss=4.3851 test_acc=0.237


train 5/15: 100%|██████████| 196/196 [00:03<00:00, 61.08it/s]
                                                     

epoch 5: train_loss=1.0308 train_acc=0.723 | test_loss=4.3585 test_acc=0.241


train 6/15: 100%|██████████| 196/196 [00:03<00:00, 61.79it/s]
                                                     

epoch 6: train_loss=1.0065 train_acc=0.730 | test_loss=4.3032 test_acc=0.244


train 7/15: 100%|██████████| 196/196 [00:03<00:00, 59.89it/s]
                                                     

epoch 7: train_loss=0.9880 train_acc=0.734 | test_loss=4.2478 test_acc=0.246


train 8/15: 100%|██████████| 196/196 [00:03<00:00, 60.62it/s]
                                                    

epoch 8: train_loss=0.9730 train_acc=0.738 | test_loss=4.2084 test_acc=0.246


train 9/15: 100%|██████████| 196/196 [00:03<00:00, 58.79it/s]
                                                     

epoch 9: train_loss=0.9617 train_acc=0.740 | test_loss=4.2269 test_acc=0.249


train 10/15: 100%|██████████| 196/196 [00:03<00:00, 60.19it/s]
                                                     

epoch 10: train_loss=0.9549 train_acc=0.742 | test_loss=4.1931 test_acc=0.251


train 11/15: 100%|██████████| 196/196 [00:03<00:00, 60.17it/s]
                                                     

epoch 11: train_loss=0.9459 train_acc=0.744 | test_loss=4.3036 test_acc=0.251


train 12/15: 100%|██████████| 196/196 [00:03<00:00, 60.74it/s]
                                                     

epoch 12: train_loss=0.9380 train_acc=0.746 | test_loss=4.3924 test_acc=0.253


train 13/15: 100%|██████████| 196/196 [00:03<00:00, 58.55it/s]
                                                     

epoch 13: train_loss=0.9342 train_acc=0.746 | test_loss=4.4192 test_acc=0.254


train 14/15: 100%|██████████| 196/196 [00:03<00:00, 62.12it/s]
                                                     

epoch 14: train_loss=0.9266 train_acc=0.748 | test_loss=4.4998 test_acc=0.252


train 15/15: 100%|██████████| 196/196 [00:03<00:00, 61.22it/s]
                                                     

epoch 15: train_loss=0.9155 train_acc=0.751 | test_loss=4.4584 test_acc=0.254




In [22]:
model_with_periodic.base_model.backbone.embeddings.weight

Parameter containing:
tensor([[ 2.4183e-01,  1.5556e-01,  7.2920e-01, -1.2094e+00,  1.2886e+00,
          4.8679e-02, -1.7976e-01, -8.4520e-03],
        [-1.5212e-01,  6.0555e-02, -1.1119e+00,  8.5851e-01,  4.6452e-01,
         -9.6510e-01,  5.1332e-01,  8.7872e-01],
        [ 9.4900e-02,  8.0048e-01,  1.2665e+00, -5.3615e-01, -1.1320e+00,
          2.6695e-01, -1.9365e-01,  4.9843e-01],
        [ 7.1381e-01,  4.9384e-01,  5.5582e-01,  1.4978e-01,  7.9409e-01,
         -1.1119e+00, -1.6840e-01,  9.7445e-01],
        [-8.1004e-02,  5.9683e-01,  3.6063e-01,  3.6787e-01, -1.1515e+00,
         -1.1423e+00,  7.4047e-01,  9.3772e-01],
        [-3.3474e-01, -8.0430e-01,  5.3718e-01,  1.0140e+00,  1.2552e+00,
         -2.2025e-01,  3.1017e-01,  9.3360e-01],
        [ 1.7565e-01,  1.0261e+00, -1.5965e-01, -1.5328e+00, -4.2284e-01,
         -7.9183e-01,  2.9912e-03,  5.6345e-01],
        [-4.8277e-01,  1.6674e-01, -1.2614e+00, -7.1049e-01, -6.6729e-01,
         -7.1693e-01,  1.2395e+00,  4.6473e

In [23]:
model_with_periodic.prompt_embed.weight

Parameter containing:
tensor([[-0.0601,  0.0605,  0.0152,  0.0140, -0.0243,  0.0013, -0.0131,  0.0038],
        [-0.0229,  0.1141,  0.0576,  0.0282, -0.0324, -0.0101,  0.0030,  0.0182],
        [-0.0279, -0.0049, -0.0149, -0.0209,  0.0130, -0.0046, -0.0323, -0.0343],
        [-0.0203, -0.0145, -0.0390,  0.0074,  0.0103,  0.0168, -0.0348,  0.0213],
        [ 0.0210, -0.0305,  0.0225, -0.0204, -0.0488, -0.0071,  0.0186,  0.0432],
        [ 0.0358,  0.0547,  0.0564, -0.0489, -0.0746,  0.0285,  0.0048, -0.0122],
        [-0.0347, -0.0166,  0.0041, -0.0240, -0.0026, -0.0174,  0.0501, -0.0409],
        [-0.0026, -0.0173,  0.0213, -0.0057, -0.0062,  0.0163,  0.0742, -0.0329],
        [ 0.0264,  0.0556,  0.0036,  0.0306, -0.0446,  0.1295,  0.0458, -0.0354],
        [-0.0065, -0.0247, -0.0344, -0.0074,  0.0116,  0.0306,  0.0331,  0.0013],
        [-0.0331,  0.0019, -0.0113,  0.0375, -0.0341,  0.0678,  0.0084, -0.0102],
        [-0.0079,  0.0104, -0.0051,  0.0403, -0.0162,  0.0506, -0.0270, -0.0