In [1]:
# Get hidden states by context level and by layer

In [6]:
import os
os.environ['HF_HOME'] = '/sbksvol/amurali/'

from huggingface_hub import login
login(token = "<hf_token>")

import transformers
import torch

import csv
import numpy as np
import pandas as pd

from transformers import AutoConfig, AutoModel, AutoTokenizer, AutoModelForCausalLM, pipeline, LlamaForCausalLM, GenerationConfig

from tqdm import tqdm

In [3]:
print(f"PyTorch Version: {torch.__version__}")
print(f"Transformers Version: {transformers.__version__}")
print(f"NumPy Version: {np.__version__}")

PyTorch Version: 2.2.2+cu121
Transformers Version: 4.33.3
NumPy Version: 1.26.4


In [10]:
# Load model/create pipeline
model_name = "ivnle/llamatales_jr_8b-lay8-hs512-hd8-33M"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, output_hidden_states=True).to("cuda")
# llamatales_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0)
# gen_config = GenerationConfig(
#     do_sample=True,
#     top_k=10,
#     max_new_tokens=512,
#     output_hidden_states=True,
#     return_dict_in_generate=True
# )

# Prompts
prompts = {
    1: "Once upon a time there was a dragon", 
    2: "Once upon a time there was a princess", 
    3: "Once upon a time there were two children",
    4: "Once upon a time there was a prince",
    5: "Once upon a time there was a frog",
    6: "Once upon a time there was a king",
    7: "Once upon a time there was a queen",
    8: "Once upon a time there was a wolf",
    9: "Once upon a time there was a genie",
    10: "Once upon a time there was a poor boy"
}


for prompt_id, prompt_text in prompts.items():
    print(f"Prompt {prompt_id}: \"{prompt_text}\"")
    data = []
    npz_data = {}
    hidden_state_file = f'./hidden_states/prompt_{prompt_id}.npz'
    for i in tqdm(range(1000)):
        # Generate using pipeline
        # sequences = llamatales_pipeline(
        #     prompt_text,
        #     do_sample=True,
        #     top_k=10,
        #     num_return_sequences=1,
        #     max_new_tokens=100
        # )
        # generated_story = sequences[0]['generated_text']
#         print(f"\nGenerated (pipeline):\n{generated_story}")

        # Count tokens
        # num_tokens_generated_story = len(tokenizer.encode(generated_story))

        # Generate hidden states from model.generate using same prompt
        inputs = tokenizer(prompt_text, return_tensors="pt").to("cuda")
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            do_sample=True,
            top_k=10,
            num_return_sequences=1,
            max_new_tokens=512,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            return_dict_in_generate=True,
            output_hidden_states=True,
            use_cache = False
        )

        # outputs = model.generate(
        #     inputs.input_ids,
        #     attention_mask=inputs.attention_mask,
        #     num_return_sequences=1,
        #     eos_token_id=tokenizer.eos_token_id,
        #     pad_token_id=tokenizer.pad_token_id,
        #     generation_config=gen_config
        # )
        
        generated_story = tokenizer.batch_decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)[0]
#         print("New Story: \n", new_story)
        
#         print("Length of Generated Story: ", num_tokens_generated_story)
#         print("Length of New Story: ", len(tokenizer.encode(new_story)))
        
        # Map output tokens to prompt ID
        # output_id = [prompt_id] * num_tokens_generated_story

        # Extract hidden states
        convert_hidden_states = []
#         print(f"Type of hidden_states: {type(outputs.hidden_states)}")
        for layer_hidden_states in outputs.hidden_states:
            convert_hidden_states.append([t.detach().cpu().numpy() for t in layer_hidden_states])

        print(outputs.hidden_states[0][0].shape)

        # Debug hidden state shape
        print(f"Hidden states shape: {len(convert_hidden_states)} generated tokens × "
              f"{len(convert_hidden_states[0])} layers × "
              f"{len(convert_hidden_states[0][0])} sequences × "
              f"{len(convert_hidden_states[0][0][0])} tokens × " 
              f"{len(convert_hidden_states[0][0][0][0])} dims")

        for hs in convert_hidden_states:
            print(np.linalg.norm(hs))

        for n in range(len(convert_hidden_states)):
            print(f"{len(convert_hidden_states[n])} layers × "
                  f"{len(convert_hidden_states[n][0])} sequences × "
                  f"{len(convert_hidden_states[n][0][0])} tokens × " 
                  f"{len(convert_hidden_states[n][0][0][0])} dims")

#         print("Num Tokens: ", len(convert_hidden_states[0][0][0])) #same as context (length of original generated story)

        # Save hidden states to file
        # arr = np.empty(len(convert_hidden_states), dtype=object)
        # arr[:] = convert_hidden_states
        # npz_data[f"arr_{i}"] = arr
        # npz_data[f"arr_{i}"] = convert_hidden_states
        
        num_tokens_generated_story = len(tokenizer.encode(generated_story))
        
        # Store results
        data.append([prompt_id, prompt_text, generated_story, hidden_state_file, num_tokens_generated_story])
    
    np.savez_compressed(hidden_state_file, **npz_data)
    
    df = pd.DataFrame(data, columns=["prompt_id", "prompt", "story", "hidden_state_file", "len_generated_story"])
    if(prompt_id == 1): df.to_csv("story_dataset_new.csv", index = False)
    else:
        df.to_csv('story_dataset_new.csv', mode='a', header = False, index = False)
        


Prompt 1: "Once upon a time there was a dragon"


  0%|                                                                                  | 1/1000 [00:02<34:10,  2.05s/it]

torch.Size([1, 9, 512])
Hidden states shape: 263 generated tokens × 9 layers × 1 sequences × 9 tokens × 512 dims
3053.3718
3094.3384
3126.7314
3157.2288
3200.8206
3247.9143
3281.729
3312.0684
3350.6958
3380.3496
3406.9055
3436.44
3478.0977
3515.0015
3554.1714
3586.1753
3612.1433
3641.903
3668.8616
3696.6523
3721.286
3748.007
3773.895
3802.56
3829.883
3860.369
3887.3433
3912.3137
3942.7844
3977.7188
4003.6897
4027.5972
4051.592
4072.6257
4098.155
4119.4585
4143.108
4166.0205
4189.029
4215.0415
4240.1675
4265.683
4290.179
4313.0977
4337.406
4365.9434
4393.7812
4425.125
4446.6313
4472.0
4500.589
4518.868
4539.7314
4566.8457
4594.654
4616.971
4637.8447
4660.352
4685.661
4709.8403
4730.2104
4750.6304
4777.8066
4803.3613
4829.498
4850.305
4870.337
4897.1577
4925.2246
4952.6206
4972.509
4993.458
5010.1777
5029.0845
5050.328
5069.8184
5088.608
5111.9604
5137.178
5153.6147
5171.9263
5189.284
5207.274
5226.2266
5246.4565
5267.2095
5282.442
5300.578
5337.922
5357.214
5378.009
5393.27
5409.5938
54

  0%|▏                                                                                 | 2/1000 [00:04<34:30,  2.08s/it]

torch.Size([1, 9, 512])
Hidden states shape: 266 generated tokens × 9 layers × 1 sequences × 9 tokens × 512 dims
3053.3718
3084.9817
3162.695
3209.57
3245.1594
3308.5852
3339.002
3376.4055
3407.385
3433.5278
3467.8738
3509.7734
3554.0366
3587.3052
3614.9897
3638.2275
3663.7222
3687.9421
3718.2493
3746.8508
3769.3975
3799.533
3827.0156
3854.8447
3884.5405
3923.6511
3949.67
3978.7317
4002.0762
4028.5505
4058.5583
4084.4858
4105.7007
4131.026
4156.4634
4182.852
4208.631
4238.416
4270.171
4296.577
4322.2026
4345.0664
4367.879
4392.205
4418.022
4440.5347
4459.6455
4484.522
4507.443
4529.0825
4552.2886
4576.919
4596.5005
4619.138
4642.4585
4673.472
4694.5625
4717.0854
4738.962
4762.275
4782.666
4804.354
4830.5254
4853.838
4877.42
4901.7407
4922.7837
4942.674
4961.7134
4986.156
5006.023
5028.996
5045.7485
5062.5576
5080.734
5101.058
5126.545
5148.3564
5166.5303
5203.4175
5222.1055
5239.52
5259.7563
5279.5225
5296.346
5314.228
5334.584
5352.3135
5369.8867
5390.8525
5409.5693
5440.2803
5457.527

  0%|▏                                                                                 | 3/1000 [00:05<30:21,  1.83s/it]

torch.Size([1, 9, 512])
Hidden states shape: 204 generated tokens × 9 layers × 1 sequences × 9 tokens × 512 dims
3053.3718
3094.3384
3135.8726
3173.915
3218.2703
3251.4375
3293.0483
3326.848
3356.2063
3384.368
3412.8853
3441.429
3473.3438
3497.2078
3525.0715
3551.7527
3585.8325
3628.6038
3648.4036
3670.772
3699.4575
3729.054
3754.7102
3780.9346
3814.5903
3851.502
3876.7834
3911.9404
3942.3826
3971.4307
3996.7178
4020.0205
4044.272
4067.0764
4093.7175
4140.015
4166.7773
4191.1147
4224.1
4272.9834
4304.8584
4326.267
4346.816
4374.379
4399.142
4423.176
4451.3643
4473.34
4501.5747
4526.7017
4555.4517
4579.4897
4600.7812
4622.114
4640.894
4662.201
4684.391
4704.2627
4726.5137
4752.74
4774.3066
4796.685
4814.4565
4833.0195
4857.6997
4879.4004
4905.841
4937.4717
4956.5176
4975.0107
5000.6064
5023.722
5047.782
5075.606
5106.9155
5123.4785
5141.8325
5163.687
5183.208
5201.716
5222.6636
5249.0674
5268.402
5286.0215
5326.5527
5344.173
5360.341
5379.605
5396.7183
5416.225
5437.233
5462.1113
5479.2

  0%|▎                                                                                 | 4/1000 [00:07<27:57,  1.68s/it]

torch.Size([1, 9, 512])
Hidden states shape: 195 generated tokens × 9 layers × 1 sequences × 9 tokens × 512 dims
3053.3718
3094.3384
3132.9514
3170.9998
3207.078
3244.5637
3285.4312
3318.6035
3348.8723
3384.4553
3418.5063
3443.9575
3478.2354
3515.3135
3541.5234
3570.2812
3597.6345
3623.321
3649.4077
3679.0771
3706.751
3731.1816
3757.1375
3783.252
3806.8452
3831.1716
3879.3894
3916.083
3964.4126
4007.195
4048.2988
4091.7747
4118.5835
4143.769
4166.2754
4189.18
4212.4985
4238.6626
4262.2153
4286.932
4308.5835
4328.8257
4351.171
4382.9585
4419.634
4444.719
4471.3027
4495.77
4517.443
4540.443
4565.2456
4585.6772
4606.9927
4627.013
4651.183
4678.764
4700.5967
4721.237
4743.1343
4770.42
4788.244
4808.0537
4827.079
4846.3335
4871.319
4890.5337
4910.824
4930.6724
4951.7505
4978.421
5001.6895
5026.669
5050.2324
5069.726
5091.3916
5108.7163
5127.241
5145.837
5166.1787
5184.6416
5202.7993
5219.6177
5236.288
5253.173
5269.4526
5285.345
5305.3633
5330.0894
5347.542
5365.0493
5382.4478
5400.684
5418

  0%|▍                                                                                 | 5/1000 [00:09<29:43,  1.79s/it]

torch.Size([1, 9, 512])
Hidden states shape: 255 generated tokens × 9 layers × 1 sequences × 9 tokens × 512 dims
3053.3718
3094.3384
3126.7314
3157.2288
3191.1196
3235.018
3271.4104
3314.175
3343.6672
3372.3755
3398.9868
3432.9717
3466.172
3511.7366
3537.5554
3565.1208
3592.4048
3619.269
3645.9539
3673.1414
3700.103
3735.3691
3774.7014
3807.0557
3836.5903
3866.0344
3895.9182
3922.8691
3946.9414
3971.5286
3996.5854
4022.6978
4048.9336
4072.8557
4099.2837
4128.2876
4153.0776
4182.579
4207.2583
4230.105
4253.6636
4276.6245
4304.459
4328.3525
4352.2603
4373.6787
4394.5776
4418.764
4440.251
4462.8604
4486.592
4510.462
4530.7944
4551.6206
4572.77
4596.452
4624.471
4653.663
4678.448
4707.508
4734.6187
4761.238
4786.8228
4815.6777
4836.0386
4857.036
4875.5024
4896.682
4915.994
4937.044
4958.5947
4979.7207
5001.5205
5022.178
5054.6943
5074.5703
5092.872
5113.5273
5132.087
5151.6714
5170.647
5188.6665
5205.895
5222.348
5239.381
5256.775
5280.6616
5297.9624
5318.134
5336.4463
5354.2505
5373.201
5

  1%|▍                                                                                 | 6/1000 [00:11<31:22,  1.89s/it]

torch.Size([1, 9, 512])
Hidden states shape: 268 generated tokens × 9 layers × 1 sequences × 9 tokens × 512 dims
3053.3718
3094.3384
3126.7314
3157.2288
3191.1196
3226.9062
3271.578
3306.4849
3350.8716
3391.2053
3433.23
3477.2722
3507.8438
3538.1504
3587.7693
3626.0886
3661.412
3692.3315
3732.9658
3768.448
3799.6426
3824.119
3852.9492
3879.0242
3903.955
3936.7842
3971.1675
3995.78
4024.998
4050.8655
4078.4226
4108.35
4142.106
4167.183
4194.7593
4220.511
4243.9004
4269.7505
4298.6187
4327.177
4354.5605
4390.948
4414.0435
4438.066
4458.9487
4483.4736
4504.8984
4525.06
4550.3325
4571.527
4597.229
4621.308
4642.793
4670.217
4697.199
4719.224
4739.4487
4760.296
4778.5947
4797.641
4816.951
4834.2437
4852.6904
4876.1943
4899.394
4916.1934
4934.0596
4955.6836
4974.099
4997.486
5017.384
5035.4185
5059.568
5080.7847
5098.877
5120.27
5139.3984
5156.1865
5173.461
5191.3604
5208.407
5226.544
5244.4614
5263.9546
5280.402
5299.3833
5315.651
5332.336
5349.5073
5366.8086
5383.965
5401.4497
5416.6123
54

  1%|▌                                                                                 | 7/1000 [00:12<29:47,  1.80s/it]

torch.Size([1, 9, 512])
Hidden states shape: 215 generated tokens × 9 layers × 1 sequences × 9 tokens × 512 dims
3053.3718
3094.3384
3126.664
3161.242
3190.6506
3224.5154
3258.2031
3294.883
3327.5876
3356.62
3386.0637
3415.5413
3449.8762
3479.3818
3504.5295
3544.2415
3585.6182
3612.5876
3641.754
3677.5908
3707.6274
3740.2822
3778.6252
3808.657
3843.932
3872.2454
3898.2944
3923.009
3948.36
3973.0933
3999.568
4025.4966
4046.541
4074.7527
4112.829
4142.8145
4162.3687
4185.6074
4208.4927
4231.548
4257.3394
4279.2446
4303.5815
4325.398
4345.6143
4367.111
4390.5396
4412.9434
4432.3687
4460.854
4484.703
4506.1665
4530.726
4550.9243
4571.3027
4591.85
4613.514
4633.2305
4653.7026
4673.8413
4695.4746
4714.3354
4733.9277
4752.5195
4775.3315
4795.466
4814.9907
4834.1274
4852.351
4871.6377
4890.369
4910.9897
4939.5557
4963.599
4981.765
5002.8687
5021.7207
5040.8384
5059.5103
5080.339
5100.6753
5123.162
5145.219
5170.0786
5196.067
5217.217
5234.8276
5257.148
5274.8486
5295.1543
5314.331
5334.179
535

  1%|▌                                                                                 | 7/1000 [00:14<34:38,  2.09s/it]


KeyboardInterrupt: 

In [None]:
# Load model/create pipeline
model_name = "ivnle/llamatales_jr_8b-lay8-hs512-hd8-33M"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, output_hidden_states=True).to("cuda")

# Prompts
prompts = {
    1: "Once upon a time there was a dragon", 
    2: "Once upon a time there was a princess", 
    3: "Once upon a time there were two children",
    4: "Once upon a time there was a prince",
    5: "Once upon a time there was a frog",
    6: "Once upon a time there was a king",
    7: "Once upon a time there was a queen",
    8: "Once upon a time there was a wolf",
    9: "Once upon a time there was a genie",
    10: "Once upon a time there was a poor boy"
}

In [6]:
for prompt in prompts.values():
    print(len(tokenizer.encode(prompt)))

9
9
9
9
9
9
9
9
9
10


In [None]:
# for i in range(1000):
#     if(len(npz_data[f"arr_{i}"][0][0][0][0]) != 512):
#         # print(i)
#         print(len(npz_data[f"arr_{i}"]))

In [None]:
# npz_data["arr_1"][0][0].shape

In [None]:
#100x9x1x9x512

In [None]:
df = pd.read_csv("story_dataset_new.csv")
df