In [1]:
import torch
from transformers import AutoModel, AutoTokenizer, BertForSequenceClassification
from functorch import jacrev, make_functional_with_buffers
import gc
from torch import nn
from torch.nn.functional import relu
from torch.autograd import grad

import numpy as np
import torch
from datasets import load_dataset, load_metric
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, TrainingArguments, Trainer
from torch.utils.data import Dataset
import logging

from datasets import load_dataset

# Load the "sst2" configuration of the "glue" dataset collection (all splits).
raw_datasets = load_dataset("glue", "sst2")

In [2]:
from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoConfig
#from roberta import RobertaForSequenceClassification


# Hub identifier of the pretrained checkpoint; reused for both the tokenizer
# and the model so the two always match.
model_name = "FacebookAI/roberta-base"

# Tokenizer that belongs to the checkpoint named above.
tokenizer = AutoTokenizer.from_pretrained(model_name)



In [3]:
from transformers import AutoTokenizer, DataCollatorWithPadding




def preprocessing_function(examples):
    """Tokenize the ``'sentence'`` field of a batch of examples.

    Args:
        examples: mapping with a ``'sentence'`` entry holding the text(s) to
            encode.

    Returns:
        Whatever the module-level ``tokenizer`` returns for those sentences,
        with truncation enabled and a maximum length of 128.
    """
    sentences = examples['sentence']
    encoded = tokenizer(sentences, truncation=True, max_length=128)
    return encoded


# Run the tokenizer over every split, one batch of examples at a time.
tokenized_datasets = raw_datasets.map(preprocessing_function, batched=True)

# Switch the output format so that indexing the datasets yields torch tensors.
tokenized_datasets.set_format("torch")

# Collator that pads each batch dynamically (to the longest example in that
# batch) using the tokenizer's padding settings.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [4]:
import torch
import torch.nn as nn
from transformers import RobertaForSequenceClassification
from transformers.activations import ACT2FN
import random



# Prefer the GPU, but fall back to CPU so this cell does not crash on machines
# without CUDA (the same selection is used later when inputs are moved).
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Pretrained encoder with a 2-label sequence-classification head on top.
model = RobertaForSequenceClassification.from_pretrained(model_name, num_labels=2).to(device)

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at FacebookAI/roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
import leader

# Apply the project-local `leader` PEFT routine to the model.
# NOTE(review): the return value is discarded, so this presumably modifies
# `model` in place — confirm against the `leader` module.
# Alternatives recorded by the author:
#   targets: 'key', 'value', 'dense', 'query'
#   method:  'row', 'column', 'random'
leader.PEFT(
    model,
    method='row',
    targets=['key'],
    rank=1,
)

In [6]:
import evaluate
import numpy as np
from sklearn import metrics
import torch
import numpy as np

def compute_metrics(eval_pred):
    """Turn raw evaluation output into classification scores.

    Args:
        eval_pred: pair ``(logits, labels)``; the predicted class for each row
            is the arg-max of the logits over the last axis.

    Returns:
        Dict with macro-averaged ``"precision"``, ``"recall"`` and
        ``"f1-score"``, plus plain ``'accuracy'``.
    """
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)

    # The three macro-averaged scores share one call signature, so compute
    # them in a single pass; accuracy takes no averaging argument.
    macro_scorers = {
        "precision": metrics.precision_score,
        "recall": metrics.recall_score,
        "f1-score": metrics.f1_score,
    }
    scores = {
        label: scorer(labels, predictions, average="macro")
        for label, scorer in macro_scorers.items()
    }
    scores['accuracy'] = metrics.accuracy_score(labels, predictions)
    return scores

In [7]:
from transformers import TrainingArguments, Trainer

import time
from transformers import Trainer, TrainingArguments
training_args = TrainingArguments(
    output_dir='dir',
    # --- optimisation ---
    learning_rate=2e-4,
    weight_decay=0.00,
    num_train_epochs=3,
    lr_scheduler_type="cosine",  # other options include 'linear', 'cosine_with_restarts', 'polynomial'
    warmup_steps=100,
    # --- batch sizes ---
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    # --- evaluation / logging ---
    evaluation_strategy="steps",
    logging_steps=100,
    # --- checkpointing ---
    # NOTE(review): save_steps is far larger than any plausible run length, so
    # presumably no intermediate checkpoint is ever written — confirm that
    # load_best_model_at_end still has the intended effect.
    save_strategy="steps",
    save_steps=10000000,
    save_total_limit=2,
    load_best_model_at_end=True,
)

# Bundle the model, the run configuration, the train/validation splits, the
# padding collator and the metric function into a single Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
)



In [8]:
# Preview the first ten validation sentences.
tokenized_datasets["validation"]["sentence"][:10]

["it 's a charming and often affecting journey . ",
 'unflinchingly bleak and desperate ',
 'allows us to hope that nolan is poised to embark a major career as a commercial yet inventive filmmaker . ',
 "the acting , costumes , music , cinematography and sound are all astounding given the production 's austere locales . ",
 "it 's slow -- very , very slow . ",
 'although laced with humor and a few fanciful touches , the film is a refreshingly serious look at young women . ',
 'a sometimes tedious film . ',
 "or doing last year 's taxes with your ex-wife . ",
 "you do n't have to know about music to appreciate the film 's easygoing blend of comedy and romance . ",
 "in exactly 89 minutes , most of which passed as slowly as if i 'd been sitting naked on an igloo , formula 51 sank from quirky to jerky to utter turkey . "]

In [9]:
# Fine-tuning is disabled: uncomment the next line to train before running the cells below.
#trainer.train()

In [10]:
# Leave only the classification head trainable: a parameter requires gradients
# exactly when its name contains 'classifier'; everything else is frozen.
for name, param in model.named_parameters():
    param.requires_grad = 'classifier' in name

In [11]:
# Function to tokenize input text
def tokenize_input(texts, tokenizer, max_length=10):
    """Encode ``texts`` to fixed-length PyTorch tensors.

    Args:
        texts: the text(s) to encode, passed to ``tokenizer`` unchanged.
        tokenizer: callable that performs the encoding.
        max_length: length every sequence is padded or truncated to.

    Returns:
        The tokenizer's output for ``texts``, requested with
        ``padding='max_length'``, truncation and ``return_tensors="pt"``.
    """
    encode_options = {
        "padding": 'max_length',
        "truncation": True,
        "max_length": max_length,
        "return_tensors": "pt",
    }
    return tokenizer(texts, **encode_options)

# Example input sets
validation_sentences = tokenized_datasets["validation"]['sentence']
texts_set_1 = validation_sentences[:5]   # first five validation sentences
texts_set_2 = validation_sentences[5:8]  # the next three

# Use the GPU when one is available, otherwise stay on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Tokenize each set, then move every returned tensor to the chosen device.
input_set_1 = {
    key: tensor.to(device)
    for key, tensor in tokenize_input(texts_set_1, tokenizer).items()
}
input_set_2 = {
    key: tensor.to(device)
    for key, tensor in tokenize_input(texts_set_2, tokenizer).items()
}

# Keep the model on the same device as its inputs.
model = model.to(device)

In [13]:
import torch
from transformers import BertForSequenceClassification, BertTokenizer
import torch.nn as nn
from functorch import make_functional_with_buffers, vmap, jacrev


# Get word embeddings instead of text input
def get_word_embeddings(input_ids):
    """Look up the raw word (token) embeddings for a batch of token ids.

    Only the ``word_embeddings`` lookup table of the module-level ``model`` is
    applied. The result is meant to be passed back to the model as
    ``inputs_embeds``: the model's embedding module then adds position and
    token-type embeddings and applies its normalisation itself, so running the
    full ``model.roberta.embeddings`` module here would apply those steps
    twice.

    Args:
        input_ids: integer tensor of token ids (assumed ``(batch, seq_len)``,
            as produced by the tokenizer).

    Returns:
        Tensor with one embedding vector per token id, computed under
        ``torch.no_grad()`` so it carries no autograd history.
    """
    with torch.no_grad():
        return model.roberta.embeddings.word_embeddings(input_ids)

# Wrapper that maps input embeddings to the classifier's output logits
class BertLastLayer(nn.Module):
    """Expose a wrapped model as a function from input embeddings to logits.

    Despite the class name, ``forward`` returns the ``logits`` attribute of the
    wrapped model's output, not a hidden state.

    Args:
        bert_model: module that accepts an ``inputs_embeds`` keyword argument
            and returns an object with a ``logits`` attribute.
    """

    def __init__(self, bert_model):
        super().__init__()
        self.bert_model = bert_model

    def forward(self, embeddings, attention_mask=None):
        """Run the wrapped model on precomputed embeddings.

        Args:
            embeddings: tensor passed to the wrapped model as ``inputs_embeds``.
            attention_mask: optional mask forwarded to the wrapped model so
                that padded positions can be ignored. When ``None`` (the
                default) the wrapped model is called exactly as before,
                without an ``attention_mask`` argument.

        Returns:
            The ``logits`` of the wrapped model's output.
        """
        # Only forward the mask when one was supplied, so wrapped models that
        # do not accept the keyword keep working.
        extra_kwargs = {} if attention_mask is None else {"attention_mask": attention_mask}
        outputs = self.bert_model(inputs_embeds=embeddings, **extra_kwargs)
        return outputs.logits

# Example usage:
# Use the GPU when one is available instead of assuming CUDA, matching the
# device selection used when the inputs were moved to `device`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
bert_last_layer = BertLastLayer(model).to(device)
# Switch off training-only behaviour (e.g. dropout) so the Jacobians computed
# below are deterministic.
bert_last_layer.eval()

# Example sentence (not used by the computation below; kept from the original cell)
sentence = ["hello", 'world']

# Token ids of the first input set, and their embeddings
input_ids_train = input_set_1['input_ids']
x_train = get_word_embeddings(input_ids_train)

# Token ids of the second input set, and their embeddings
input_ids_test = input_set_2['input_ids']
x_test = get_word_embeddings(input_ids_test)

# NOTE(review): only 'input_ids' is used; the attention masks in input_set_1 /
# input_set_2 are not passed on, so any padded positions are presumably
# attended to like real tokens — confirm this is intended.

# Convert the wrapper to a functional model using functorch, including buffers
fnet, params, buffers = make_functional_with_buffers(bert_last_layer)

# Function for a single pass through the functional model
def fnet_single(params, buffers, x):
    """Evaluate the functional model ``fnet`` on one un-batched sample.

    A leading batch axis of size 1 is added before the call and removed from
    the result, so the function can be mapped over individual samples.

    Args:
        params: parameters passed through to ``fnet``.
        buffers: buffers passed through to ``fnet``.
        x: a single input sample without a batch axis.

    Returns:
        The output of ``fnet`` for ``x`` with the batch axis removed.
    """
    batched_input = x.unsqueeze(0)
    batched_output = fnet(params, buffers, batched_input)
    return batched_output.squeeze(0)

# Empirical NTK via Jacobian contraction: J(x1) @ J(x2).T, summed over all parameter tensors
def empirical_ntk_jacobian_contraction(fnet_single, params, buffers, x1, x2):
    """Compute the empirical NTK between two batches by contracting Jacobians.

    For every sample in ``x1`` and ``x2`` the Jacobian of ``fnet_single`` with
    respect to ``params`` is computed, each parameter's Jacobian is flattened
    over its parameter axes, and the two sets are contracted over that
    flattened axis. The per-parameter contributions are then summed.

    Args:
        fnet_single: function ``(params, buffers, x) -> output`` for one sample.
        params: sequence of parameter tensors to differentiate with respect to.
        buffers: buffers passed through to ``fnet_single`` unchanged.
        x1: first batch of inputs; its leading axis (size N) is mapped over.
        x2: second batch of inputs; its leading axis (size M) is mapped over.

    Returns:
        Tensor indexed as ``[N, M, a, b]`` (per the einsum below), where ``a``
        and ``b`` index the output axis for the ``x1`` and ``x2`` sample.
    """
    # Per-sample Jacobian: params and buffers are shared, inputs are mapped
    # over their leading axis.
    per_sample_jacobian = vmap(jacrev(fnet_single), (None, None, 0))

    def flat_jacobians(inputs):
        # One (batch, out, n_params_in_tensor) tensor per parameter tensor.
        return [jac.flatten(2) for jac in per_sample_jacobian(params, buffers, inputs)]

    jacobians_1 = flat_jacobians(x1)
    jacobians_2 = flat_jacobians(x2)

    # J(x1) @ J(x2).T for each parameter tensor, then summed over parameters.
    contractions = [
        torch.einsum('Naf,Mbf->NMab', left, right)
        for left, right in zip(jacobians_1, jacobians_2)
    ]
    return torch.stack(contractions).sum(0)



# Compute NTK
# NTK between the first input set (x_train) and the second (x_test).
result = empirical_ntk_jacobian_contraction(
    fnet_single,
    params,
    buffers,
    x_train,
    x_test,
)
print(result.shape)


torch.Size([5, 3, 2, 2])


  warn_deprecated('make_functional_with_buffers', 'torch.func.functional_call')
  warn_deprecated('jacrev')
  warn_deprecated('vmap', 'torch.vmap')
