In [1]:
import torch
import transformers
from transformers import GPT2Model, GPT2LMHeadModel, GPT2Config



# Dataset generation

In [2]:
gen_config = GPT2Config(
    vocab_size=32,
    n_positions=1024,
    n_embd=16,
    n_layer=2,
    n_head=4,
    n_inner=None,
    activation_function="gelu_new",
    resid_pdrop=0.1,
    embd_pdrop=0.1,
    attn_pdrop=0.1,
)

gen_model = GPT2LMHeadModel(gen_config)

In [3]:
n_seqs = 1100
seq_length = 100
test_size = 100

prompts = torch.randint(16, (n_seqs, 32))

in_seqs = gen_model.generate(
    inputs = prompts,
    max_new_tokens=seq_length,
    do_sample=True,
    top_k=0,
)[:,32:].detach()
in_seqs.shape

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


torch.Size([1100, 100])

In [24]:
gen_model.generate(
    inputs = prompts,
    max_new_tokens=132,
    do_sample=True,
    top_k=0,
).shape

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


torch.Size([1100, 164])

In [4]:
from transformers import BertModel, BertConfig
import torch

def backward(self, *args, **kwargs):
    print(n_passes)
    n_passes += 1
    return self._old_backward(*args, **kwargs)
    

class BertEmbeddor(BertModel):
    def __init__(self, config=None, **kwargs):
        super().__init__(config, **kwargs)
        self.out_size = getattr(config, 'out_size', 0) or config.hidden_size
        assert self.out_size <= config.hidden_size
        self.loss = torch.nn.MSELoss()
        
    def forward(self, input_ids, labels=None, **kwargs):
        #global n_passes
        #n_passes = 0
        with torch.autograd.detect_anomaly():
            output = super().forward(input_ids=input_ids, **kwargs).last_hidden_state[...,:self.out_size]
            if labels is not None:
                loss = self.loss(output, labels)
                #loss._old_backward = loss.backward
                #loss.backward = backward.__get__(loss)
                return loss, output.detach()
        return output        
        

In [22]:
def method(self, arg):
    print(self, arg)
    
class Fun:
    pass
f = Fun()
print(vars(f))

f = Fun()
f.run = method.__get__(f)
f.run(2)

{}
<__main__.Fun object at 0x7f0371d840d0> 2


In [20]:
vars(f)

{'run': <bound method method of <__main__.Fun object at 0x7f014dbe9030>>}

In [6]:
true_config = BertConfig(
    vocab_size=32,
    hidden_size=16,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=32,
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    max_position_embeddings=1024,
)

true_process = BertEmbeddor(true_config)

In [7]:
out_seq = true_process(in_seqs).detach()
out_seq.shape

  with torch.autograd.detect_anomaly():
We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.


torch.Size([1100, 100, 16])

In [29]:
in_seqs.numpy()

(1100, 100)

In [32]:
out_seq.numpy()

array([[[-2.32138336e-01,  4.04119998e-01,  4.49637115e-01, ...,
          1.12798786e+00,  9.37545538e-01, -1.25721264e+00],
        [-1.09192860e+00,  2.27750745e-02,  1.65723300e+00, ...,
         -1.72552323e+00,  5.52578509e-01,  1.43263984e+00],
        [ 1.96849740e+00, -1.89418435e-01, -5.67110538e-01, ...,
          7.32037649e-02,  1.34004724e+00, -2.08085492e-01],
        ...,
        [ 7.38770962e-01,  7.88826168e-01, -7.56235957e-01, ...,
          5.25630005e-02,  2.06012702e+00, -1.36556506e+00],
        [ 2.48281121e-01,  1.48547143e-01,  3.16754691e-02, ...,
          1.30367851e+00,  1.03389490e+00, -3.73424813e-02],
        [-7.96971619e-01, -1.79041788e-01, -8.16561997e-01, ...,
          1.45417917e+00,  1.06448293e+00,  1.95972309e-01]],

       [[ 6.25927508e-01,  5.30399442e-01,  9.77656305e-01, ...,
          1.01721370e+00,  6.38967812e-01, -7.20833302e-01],
        [ 6.34040087e-02, -1.71405911e+00,  3.50597128e-02, ...,
         -7.55399346e-01,  6.27350748e

In [8]:
sum(p.numel() for p in true_process.parameters())

21680

# Model training

In [9]:
class EmbeddingDataset(torch.utils.data.Dataset):
    def __init__(self, input_ids, labels, attention_mask=None):
        self.input_ids = torch.as_tensor(input_ids, dtype=torch.long)
        self.labels = torch.as_tensor(labels, dtype=torch.float)
        self.attention_mask = torch.as_tensor(attention_mask) if attention_mask is not None else None

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        if self.attention_mask is not None:
            return {
                'input_ids': self.input_ids[idx],
                'attention_mask': self.attention_mask[idx],
                'labels': self.labels[idx],   # embeddings as labels
            }
        return {
            'input_ids': self.input_ids[idx],
            'labels': self.labels[idx],  # embeddings as labels
        }

test_data = EmbeddingDataset(in_seqs[:test_size], labels=out_seq[:test_size])
train_data = EmbeddingDataset(in_seqs[test_size:], labels=out_seq[test_size:])


In [10]:
model_config = BertConfig(
    out_size=16,
    vocab_size=32,
    hidden_size=32,
    num_hidden_layers=4,
    num_attention_heads=8,
    intermediate_size=64,
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    max_position_embeddings=1024,
)

model = BertEmbeddor(model_config)

In [57]:
from transformers import TrainerCallback, Trainer, TrainingArguments, EarlyStoppingCallback
import csv
import os

class TestLossCallback(TrainerCallback):

    def __init__(self, logging_dir, file_name="test_loss.csv"):
        # Set the output file path to the specified logging directory and file name
        self.output_file = os.path.join(logging_dir, file_name)
        
        # Write header to the CSV file
        with open(self.output_file, mode='w', newline='') as file:
            writer = csv.writer(file)
            writer.writerow(["epoch", "test_loss"])

    def on_evaluate(self, args, state, control, **kwargs):
        # Get validation loss from the logs
        test_loss = state.log_history[-1]["eval_loss"]
        epoch = state.epoch

        # Append validation loss to the file
        with open(self.output_file, mode='a', newline='') as file:
            writer = csv.writer(file)
            writer.writerow([epoch, test_loss])

training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=100,
    per_device_train_batch_size=8,
    no_cuda=False,
    warmup_steps=10,
    weight_decay=0.01,
    save_steps=100,
    eval_steps=100,
    save_total_limit=2,
    save_strategy='epoch',
    logging_dir='./logs',
    evaluation_strategy="epoch",
    load_best_model_at_end = True
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=test_data,
    callbacks=[
        EarlyStoppingCallback(early_stopping_patience=10, early_stopping_threshold=1e-5),
        TestLossCallback(logging_dir=training_args.logging_dir, file_name="test_loss.csv")
    ], 
)



In [53]:
torch.set_anomaly_enabled(True, True)
torch.is_anomaly_enabled()
#model

True

In [58]:
trainer.train()

  with torch.autograd.detect_anomaly():


Epoch,Training Loss,Validation Loss
1,No log,0.110869
2,No log,0.105959
3,No log,0.10381
4,0.154100,0.10291
5,0.154100,0.102388
6,0.154100,0.101862
7,0.154100,0.101634
8,0.141900,0.101308
9,0.141900,0.101115
10,0.141900,0.100903


  with torch.autograd.detect_anomaly():
  with torch.autograd.detect_anomaly():
  with torch.autograd.detect_anomaly():
  with torch.autograd.detect_anomaly():
  with torch.autograd.detect_anomaly():
  with torch.autograd.detect_anomaly():
  with torch.autograd.detect_anomaly():
  with torch.autograd.detect_anomaly():
  with torch.autograd.detect_anomaly():
  with torch.autograd.detect_anomaly():
  with torch.autograd.detect_anomaly():
  with torch.autograd.detect_anomaly():
  with torch.autograd.detect_anomaly():
  with torch.autograd.detect_anomaly():
  with torch.autograd.detect_anomaly():
  with torch.autograd.detect_anomaly():
  with torch.autograd.detect_anomaly():
  with torch.autograd.detect_anomaly():
  with torch.autograd.detect_anomaly():
  with torch.autograd.detect_anomaly():
  with torch.autograd.detect_anomaly():
  with torch.autograd.detect_anomaly():
  with torch.autograd.detect_anomaly():
  with torch.autograd.detect_anomaly():
  with torch.autograd.detect_anomaly():


TrainOutput(global_step=9875, training_loss=0.13115360385556765, metrics={'train_runtime': 1161.4403, 'train_samples_per_second': 86.1, 'train_steps_per_second': 10.762, 'total_flos': 1673030400000.0, 'train_loss': 0.13115360385556765, 'epoch': 79.0})

In [16]:
transformers.__version__

'4.46.0'