In [1]:
# Install and import the necessary libraries
!pip install -q -U torch
!pip install -q -U accelerate peft bitsandbytes transformers trl einops evaluate
!pip install -q -U tqdm
!pip install -q -U git+https://github.com/sissa-data-science/DADApy

In [2]:
import os
import torch
from datasets import load_dataset
from datasets import load_from_disk
from peft import LoraConfig, prepare_model_for_kbit_training, PeftModel
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    AutoTokenizer,
    TrainingArguments,
    pipeline,
    logging,
    DistilBertModel,
    DistilBertTokenizer,
)
from tqdm import tqdm
from trl import SFTTrainer
from tqdm import tqdm
import gc
import matplotlib.pyplot as plt
from dadapy.data import Data
import numpy as np

2024-06-11 05:00:17.362864: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-06-11 05:00:17.362967: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-06-11 05:00:17.496611: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [3]:
# Every tensor created from here on is allocated on the GPU by default
torch.set_default_device("cuda")

# Checkpoint name, reused later when the model itself is loaded
model_name = "distilbert/distilbert-base-uncased"

# Tokenizer matching the checkpoint
tokenizer = DistilBertTokenizer.from_pretrained(
    model_name,
    add_eos_token=True,
    trust_remote_code=True,
)

# Over-long inputs are cut from the right; padding uses the [PAD] token
tokenizer.truncation_side = "right"
tokenizer.pad_token = '[PAD]'

# First 4500 rows of the train split of the sentiment-analysis dataset
dataset = load_dataset(
    "Sp1786/multiclass-sentiment-analysis-dataset",
    split="train[:4500]",
)

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]



config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/1.72k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.56M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/601k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/586k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/31232 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/5205 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/5206 [00:00<?, ? examples/s]

In [4]:
# Function that tokenizes the text
def tokenize(rows):
    """Tokenize a batch of dataset rows with the module-level ``tokenizer``.

    Args:
        rows: Batch dictionary as handed over by ``Dataset.map(..., batched=True)``.
            ``rows["text"]`` is a list whose entries are strings or ``None``.

    Returns:
        The tokenizer's encoding, moved to the CUDA device, holding
        ``input_ids`` and ``attention_mask`` with one row per entry of
        ``rows["text"]``; every row is padded or truncated to 260 tokens.
    """
    # Handle every row of the batch (not only the first one) so the function
    # stays correct for any batch_size. Missing texts become empty strings and
    # double quotes are backslash-escaped, as before.
    texts = [
        "" if text is None else text.replace('"', r'\"')
        for text in rows["text"]
    ]

    # Tokenize
    encoded = tokenizer(
          texts,
          add_special_tokens=True,
          max_length=260,
          # Without explicit truncation, inputs longer than max_length would
          # produce rows longer than 260 tokens.
          truncation=True,
          return_token_type_ids=False,
          return_attention_mask=True,
          return_tensors='pt',
          padding='max_length',
        ).to("cuda")

    return encoded

In [5]:
# Run `tokenize` over the dataset one example at a time and keep only the
# tokenizer outputs (all original columns are dropped).
tokenized_dataset = dataset.map(
    tokenize,
    remove_columns=["id", "text", "label", "sentiment"],
    batched=True,
    batch_size=1,
)

Map:   0%|          | 0/4500 [00:00<?, ? examples/s]

In [6]:
# Load the pretrained DistilBERT encoder for `model_name` directly onto the GPU
model = DistilBertModel.from_pretrained(
    model_name,
    device_map="cuda",
)

model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

In [7]:
# Computing intrinsic dimensions
#
# The tokenized dataset is split into `batches` disjoint chunks. For each chunk,
# every sample is passed through the model, each hidden-state layer is
# mean-pooled over the token axis, and the intrinsic dimension of the pooled
# vectors is estimated per layer with DADApy's 2NN scaling estimator.
# intrinsic_dims[x][j] ends up holding the estimate for chunk x and layer j.

torch.cuda.empty_cache()
gc.collect()

# total samples (taken from the dataset so it cannot drift from the loaded slice)
num_data = len(tokenized_dataset)

# number of chunks; also reused as the number of progress-bar steps per chunk
batches = 10

# samples per batch
batch_data = num_data // batches

# Samples per sub-batch within each batch
per_batch = batch_data // batches

if per_batch == 0:
    raise ValueError(
        f"Need at least {batches * batches} samples to form {batches} chunks, got {num_data}"
    )

# Number of sub-batches per batch
number_batches = batch_data // per_batch

# Width of one pooled embedding, read from the model config so the zero-fill
# fallback below always matches the real vectors.
hidden_size = model.config.hidden_size

# initializing intrinsic dimension lists for each batch
intrinsic_dims = [[] for _ in range(batches)]

for x in range(batches):
    torch.set_default_device("cuda")

    # pooled_layers[j] collects one mean-pooled vector per sample for layer j.
    # It is created lazily because the number of layers is only known once the
    # first forward pass has returned.
    pooled_layers = None

    # Collect pooled hidden layers per batch
    for batch in tqdm(range(number_batches)):
        for i in range(per_batch):
            # Rebuild single-example input tensors from the stored token lists
            index = batch * per_batch + i + x * batch_data
            inputs = {k: torch.tensor(v).unsqueeze(0).to("cuda") for k, v in tokenized_dataset[index].items()}

            # Perform forward pass through the model
            with torch.no_grad():
                outputs = model(**inputs, output_hidden_states=True)

            if pooled_layers is None:
                pooled_layers = [[] for _ in outputs.hidden_states]

            # Pool right away and keep only the pooled vector instead of the
            # full per-token hidden states. The mean is still computed on the
            # CPU copy, exactly as before.
            # NOTE(review): the mean runs over every token position, padding
            # included; a mask-weighted mean may be preferable - confirm.
            for j, hidden in enumerate(outputs.hidden_states):
                pooled = torch.mean(hidden.detach().cpu().squeeze(dim=0), dim=0)

                # Handle empty tensors
                if pooled.shape == torch.Size([]):
                    print("Encountered empty tensor. Filling with zeros.")
                    pooled = torch.zeros(hidden_size, device="cpu")

                pooled_layers[j].append(pooled)

            del outputs, inputs

    # Release cached GPU memory once per chunk rather than after every sample
    torch.cuda.empty_cache()
    gc.collect()

    # Move back to CPU
    torch.set_default_device("cpu")

    # Stack hidden layers: one (samples, hidden_size) matrix per layer
    n = len(pooled_layers)
    hidden_layers_stacked = torch.stack([torch.stack(layer) for layer in pooled_layers])

    del pooled_layers
    gc.collect()

    # Compute intrinsic dimensions
    for i in range(n):
        X = hidden_layers_stacked[i].numpy()
        data = Data(X)
        data.remove_identical_points()
        id_list_2NN, _, _ = data.return_id_scaling_2NN()
        # Keep the second entry of the returned ID list
        # (presumably the estimate at the second scale - confirm against the DADApy docs).
        intrinsic_dims[x].append(id_list_2NN[1])

100%|██████████| 10/10 [02:26<00:00, 14.68s/it]


No identical identical points were found
No identical identical points were found
No identical identical points were found
No identical identical points were found
No identical identical points were found
No identical identical points were found
No identical identical points were found


100%|██████████| 10/10 [02:24<00:00, 14.46s/it]


No identical identical points were found
No identical identical points were found
No identical identical points were found
No identical identical points were found
No identical identical points were found
No identical identical points were found
No identical identical points were found


100%|██████████| 10/10 [02:26<00:00, 14.69s/it]


No identical identical points were found
No identical identical points were found
No identical identical points were found
No identical identical points were found
No identical identical points were found
No identical identical points were found
No identical identical points were found


100%|██████████| 10/10 [02:26<00:00, 14.66s/it]


No identical identical points were found
No identical identical points were found
No identical identical points were found
No identical identical points were found
No identical identical points were found
No identical identical points were found
No identical identical points were found


100%|██████████| 10/10 [02:25<00:00, 14.51s/it]


No identical identical points were found
No identical identical points were found
No identical identical points were found
No identical identical points were found
No identical identical points were found
No identical identical points were found
No identical identical points were found


100%|██████████| 10/10 [02:26<00:00, 14.68s/it]


No identical identical points were found
No identical identical points were found
No identical identical points were found
No identical identical points were found
No identical identical points were found
No identical identical points were found
No identical identical points were found


100%|██████████| 10/10 [02:23<00:00, 14.38s/it]


No identical identical points were found
No identical identical points were found
No identical identical points were found
No identical identical points were found
No identical identical points were found
No identical identical points were found
No identical identical points were found


100%|██████████| 10/10 [02:24<00:00, 14.48s/it]


No identical identical points were found
No identical identical points were found
No identical identical points were found
No identical identical points were found
No identical identical points were found
No identical identical points were found
No identical identical points were found


100%|██████████| 10/10 [02:23<00:00, 14.31s/it]


No identical identical points were found
No identical identical points were found
No identical identical points were found
No identical identical points were found
No identical identical points were found
No identical identical points were found
No identical identical points were found


100%|██████████| 10/10 [02:22<00:00, 14.27s/it]


No identical identical points were found
No identical identical points were found
No identical identical points were found
No identical identical points were found
No identical identical points were found
No identical identical points were found
No identical identical points were found


In [8]:
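# (Superseded draft) manual mean and standard deviation of the intrinsic dimensions for each layer.
# It only averages the first 5 chunks; the next cell computes the same statistics over all chunks with NumPy.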

# ids = len(intrinsic_dims)

# # total layers
# n = len(intrinsic_dims[0])
# mean = []
# std = []

# for i in range(n):
#   layer_mean = np.mean([intrinsic_dims[0][i],intrinsic_dims[1][i],intrinsic_dims[2][i],intrinsic_dims[3][i],intrinsic_dims[4][i]])
#   mean.append(layer_mean)
#   layer_std = np.std([intrinsic_dims[0][i],intrinsic_dims[1][i],intrinsic_dims[2][i],intrinsic_dims[3][i],intrinsic_dims[4][i]])
#   std.append(layer_std)

In [9]:
# Per-layer statistics across chunks: each row of intrinsic_dims is one chunk,
# each column one layer, so reducing over axis 0 gives one value per layer.
mean = np.asarray(intrinsic_dims).mean(axis=0)
std = np.asarray(intrinsic_dims).std(axis=0)

In [10]:
# Display the mean intrinsic dimension of each layer, averaged over the chunks
mean

array([25.37, 20.83, 20.57, 21.32, 24.01, 25.68, 24.69])

In [11]:
# Display the standard deviation of each layer's intrinsic dimension across the chunks
std

array([1.48, 1.16, 0.99, 1.05, 1.13, 1.15, 1.07])