In [1]:
# @title Installing  conllu
!pip install conllu


Collecting conllu
  Downloading conllu-6.0.0-py3-none-any.whl.metadata (21 kB)
Downloading conllu-6.0.0-py3-none-any.whl (16 kB)
Installing collected packages: conllu
Successfully installed conllu-6.0.0


In [2]:
# @title importing required modules
import torch
import numpy as np
import pandas as pd
from transformers import AutoTokenizer, AutoModel
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import accuracy_score
import json
import os

In [3]:
# @title Function to input from user about model
def get_user_model_choice():
        print("Available model options:")
        models = {
            '1': 'distilbert-base-uncased',
            '2': 'bert-base-uncased',
            '3': 'gpt2',
            '4': 'microsoft/DialoGPT-medium',
            '5': 'google/electra-small-discriminator',
            '6': 'facebook/opt-350m',
            '7': 'custom'
        }

        for key, value in models.items():
            print(f"{key}: {value}")

        choice = input("Enter your choice (1-7): ").strip()

        if choice == '7':
            model_name = input("Enter custom model name from HuggingFace: ").strip()
            return model_name
        elif choice in models:
            return models[choice]
        else:
            print("Invalid choice, using default DistilBERT")
            return 'distilbert-base-uncased'

In [4]:
# @title Function to find the attribute of getting no. of layers
def _get_num_layers(model):
        #has_attr or has attribute checks whether this attribute is available in the model or not4
        if hasattr(model.config, 'n_layers'):
            return model.config.n_layers
        elif hasattr(model.config, 'num_hidden_layers'):
            return model.config.num_hidden_layers
        elif hasattr(model.config, 'n_layer'):
            return model.config.n_layer
        else:
            print("Not able to find")

In [5]:
# @title Model initialization Phase
model_name = get_user_model_choice()
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, output_hidden_states=False,attn_implementation="eager", output_attentions=True)
model.eval()
label_encoder = LabelEncoder()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
num_layers = _get_num_layers(model)
if tokenizer.pad_token is None:
  tokenizer.pad_token = tokenizer.eos_token


Available model options:
1: distilbert-base-uncased
2: bert-base-uncased
3: gpt2
4: microsoft/DialoGPT-medium
5: google/electra-small-discriminator
6: facebook/opt-350m
7: custom
Enter your choice (1-7): 1


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

In [6]:
# @title Function to Prepare dataset
def prepare_data(test=None):
    import requests
    from io import StringIO
    from conllu import parse_incr

    # Load data
    if(not test):
        conllu_url = "https://raw.githubusercontent.com/Harry-Joseph9387/Interpreting-Layer-wise-Representations-in-LLMs-using-Sparse-Autoencoders/refs/heads/main/en_ewt-ud-train.conllu"
    else:
        conllu_url = "https://raw.githubusercontent.com/Harry-Joseph9387/Interpreting-Layer-wise-Representations-in-LLMs-using-Sparse-Autoencoders/refs/heads/main/en_lines-ud-train.conllu"

    response = requests.get(conllu_url)
    file_content = StringIO(response.text)

    sentences = []
    pos_labels = []
    dep_labels = []
    position_labels = []

    for tokenlist in parse_incr(file_content):
        if not tokenlist:
            continue

        tokens = []
        pos = []
        dep = []

        for token in tokenlist:
            if isinstance(token["id"], int) and token["form"] is not None:
                tokens.append(token["form"])
                pos.append(token["upos"])
                dep.append(token["deprel"])

        if tokens:
            sentences.append(" ".join(tokens))
            pos_labels.append(pos)
            dep_labels.append(dep)
            position_labels.append(list(range(len(tokens))))

    return sentences, {'pos': pos_labels, 'dep': dep_labels, 'position': position_labels}

In [31]:
train_sentences,train_labels_dict=prepare_data()

In [7]:
# @title Function to align tokens with labels
def _align_tokens_with_labels(sentences, labels, all_tokens, task_type, start_idx=0):
    aligned_data = []

    for batch_sent_idx, (sentence, original_labels, tokens) in enumerate(zip(sentences, labels, all_tokens)):
        sent_idx = start_idx + batch_sent_idx
        original_words = sentence.split()
        word_idx = 0

        for token_idx, token in enumerate(tokens):
            # 🚫 Skip special and pad tokens
            if token in ['[CLS]', '[SEP]', '[PAD]', '<s>', '</s>', '<pad>']:
                continue

            # 🔗 Handle subword tokens (e.g., ##word)
            if token.startswith('##'):
                if aligned_data and aligned_data[-1]['sentence_idx'] == sent_idx:
                    label = aligned_data[-1]['label']
                else:
                    label = original_labels[word_idx] if word_idx < len(original_labels) else 'O'
            else:
                if word_idx < len(original_labels):
                    label = original_labels[word_idx]
                    word_idx += 1
                else:
                    label = 'O'

            aligned_data.append({
                'sentence_idx': sent_idx,
                'token_idx': token_idx,
                'label': label,
                'task': task_type,
                'token': token,
                'batch_sent_idx': batch_sent_idx
            })

    print(f"Aligned data size (excluding special tokens): {len(aligned_data)}")
    return aligned_data


In [8]:
# @title Function to extract hidden representations from model
def _extract_representations_(sentences, labels_dict, tokenizer, num_layers, model, device, test=None):
    all_hidden_states = {i: [] for i in range(num_layers)}
    all_tokens = []
    batch_size = 8

    for i in range(0, len(sentences), batch_size):
        batch = sentences[i:i + batch_size]

        encoded = tokenizer(batch, padding=True, truncation=True,
                            max_length=128, return_tensors='pt')

        input_ids = encoded['input_ids'].to(device)
        attention_mask = encoded['attention_mask'].to(device)

        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)

            hidden_states = outputs.hidden_states if hasattr(outputs, 'hidden_states') else [outputs.last_hidden_state]

            print(f"{i // batch_size}th batch is processed out of {len(sentences) // batch_size}")

            tokens_batch = [tokenizer.convert_ids_to_tokens(ids) for ids in encoded['input_ids']]
            all_tokens.extend(tokens_batch)  # Save for alignment step

            # Filter and save hidden states layer by layer
            for layer_idx in range(1, len(hidden_states)):  # Skip layer 0 (embedding layer)
                layer_hidden = hidden_states[layer_idx]  # (batch_size, seq_len, hidden_dim)
                filtered_batch = []

                for b in range(len(batch)):
                    sentence_tokens = tokens_batch[b]
                    sentence_hidden = layer_hidden[b]  # (seq_len, hidden_dim)
                    sentence_mask = attention_mask[b]  # (seq_len,)

                    filtered = []

                    for t_idx, (tok, h_vec, attn) in enumerate(zip(sentence_tokens, sentence_hidden, sentence_mask)):
                        if attn == 0:
                            continue  # Skip PAD
                        if tok in ['[CLS]', '[SEP]', '[PAD]', '<s>', '</s>', '<pad>']:
                            continue  # Skip special tokens
                        filtered.append(h_vec.cpu())

                    if filtered:
                        filtered_batch.append(torch.stack(filtered))  # (filtered_seq_len, hidden_dim)
                    else:
                        filtered_batch.append(torch.empty((0, model.config.hidden_size)))  # fallback for empty case

                all_hidden_states[layer_idx - 1].extend(filtered_batch)

    print("all batches processed completely")

    # Initialize aligned data for all tasks
    all_task_alignedData = {task_type: [] for task_type in ['pos', 'dep', 'position']}

    # Process alignment batch by batch to maintain correct correspondence
    for i in range(0, len(sentences), batch_size):
        batch_sentences = sentences[i:i + batch_size]
        batch_tokens = all_tokens[i:i + batch_size]

        for task_type in ['pos', 'dep', 'position']:
            batch_labels = labels_dict[task_type][i:i + batch_size]
            batch_aligned_data = _align_tokens_with_labels(batch_sentences, batch_labels, batch_tokens, task_type, i)
            all_task_alignedData[task_type].extend(batch_aligned_data)

    return all_hidden_states, all_task_alignedData


In [32]:
all_hidden_states,train_all_task_alignedData=_extract_representations_(train_sentences, train_labels_dict,tokenizer,num_layers,model,device)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
1273th batch is processed out of 1568
1274th batch is processed out of 1568
1275th batch is processed out of 1568
1276th batch is processed out of 1568
1277th batch is processed out of 1568
1278th batch is processed out of 1568
1279th batch is processed out of 1568
1280th batch is processed out of 1568
1281th batch is processed out of 1568
1282th batch is processed out of 1568
1283th batch is processed out of 1568
1284th batch is processed out of 1568
1285th batch is processed out of 1568
1286th batch is processed out of 1568
1287th batch is processed out of 1568
1288th batch is processed out of 1568
1289th batch is processed out of 1568
1290th batch is processed out of 1568
1291th batch is processed out of 1568
1292th batch is processed out of 1568
1293th batch is processed out of 1568
1294th batch is processed out of 1568
1295th batch is processed out of 1568
1296th batch is processed out of 1568
1297th batch is process

In [28]:
# @title check alignment hidden vector with label and token on the basis of count
def check_alignment(hidden_states, all_task_alignedData, task_type='pos'):
    print(f"\nChecking alignment for task: {task_type}")

    num_sentences = len(hidden_states[0])  # assuming all layers have same count
    print(f"Total sentences: {num_sentences}")

    for sent_idx in range(num_sentences):
        # Hidden vectors (from any one layer, e.g., layer 0)
        hidden_vecs = hidden_states[0][sent_idx]  # shape: (seq_len, hidden_dim)
        hidden_token_count = hidden_vecs.shape[0]

        # From aligned data (tokens + labels) — skip special tokens
        aligned_entries = [
            d for d in all_task_alignedData[task_type]
            if d['sentence_idx'] == sent_idx and d['label'] != 'SPECIAL'
        ]
        aligned_token_count = len(aligned_entries)
        aligned_label_count = aligned_token_count  # 1 label per token

        if hidden_token_count != aligned_token_count or hidden_token_count != aligned_label_count:
            print(f"[Mismatch] Sentence {sent_idx} → Hidden: {hidden_token_count}, Tokens: {aligned_token_count}, Labels: {aligned_label_count}")
        # else:
        #     print(f"[OK] Sentence {sent_idx} → Count: {hidden_token_count}")

    print("✅ Alignment check complete.\n")
check_alignment(hidden_states, all_task_alignedData, task_type='pos')


Checking alignment for task: pos
Total sentences: 3457
✅ Alignment check complete.



In [1]:
# @title Code to check the label count and token count of each sentence
def _check_label_token_count_(sentences,all_task_alignedData):
  for sent_id in range(len(sentences)):
    result = [d for d in all_task_alignedData['pos'] if d.get('sentence_idx') == sent_id]
    sentence_ith= [d['token'] for d in result if 'token_idx' in d]
    sentence_ith_label= [d['label'] for d in result if 'token_idx' in d]
    # print(sentence_ith)
    # print(sentence_ith_label)
    if(len(sentence_ith)!=len(sentence_ith_label)):
      print(f"mismatch at f{sent_id}")
_check_label_token_count_(sentences,all_task_alignedData)

NameError: name 'sentences' is not defined

In [33]:
# @title Function to save model's current state to drive
def save_checkpoint(layer_idx, encoder, decoder, optimizer, current_epoch,task):
    checkpoint = {
        'encoder_state_dict': encoder.state_dict(),
        'decoder_state_dict': decoder.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': current_epoch
    }
    CHECKPOINT_DIR=f"/content/drive/MyDrive/SAE_outputs/sae_checkpoints_{task}"
    if(not os.path.exists(CHECKPOINT_DIR)):
          os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    torch.save(checkpoint, f"{CHECKPOINT_DIR}/layer_{layer_idx}_model.pt")

In [34]:
# @title Function to load model's current state from drive
def load_checkpoint(layer_idx, encoder, decoder, optimizer,task,resume):
  if resume=='c':
    return 0

    checkpoint_path = f"/content/drive/MyDrive/SAE_outputs/sae_checkpoints_{task}/layer_{layer_idx}_model.pt"
    print(checkpoint_path)
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path)
        encoder.load_state_dict(checkpoint['encoder_state_dict'])
        decoder.load_state_dict(checkpoint['decoder_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if(resume=='b'):
          return checkpoint.get('epoch', 0)
        else:
          return 0
    print('couldnt found the file, starting from scretch')
    return 0

In [35]:
# @title Function that applies topK sparsity
def _apply_top_k_sparsity(activations, k):
    top_values, top_indices = torch.topk(torch.abs(activations), k, dim=1)
    sparse_activations = torch.zeros_like(activations)
    row_idx = torch.arange(activations.size(0)).unsqueeze(1).expand(-1, k)
    sparse_activations[row_idx, top_indices] = activations[row_idx, top_indices]
    return sparse_activations



In [36]:
# @title Function to train Sparse autoencoder for a layer
def train_sparse_autoencoder(hidden_states, layer_idx,device, resume, task, sae_dim=2048, sparsity_coeff=0.01,
                           top_k=50, epochs=100, lr=0.001 , batch_size=131072):

        all_tokens = []
        hidden_dim =hidden_states[layer_idx][0].shape[-1]
        with torch.no_grad():
          all_tokens = torch.cat([
              batch.view(-1, hidden_dim) for batch in hidden_states[layer_idx]
          ], dim=0).to(device)





        encoder = torch.nn.Linear(hidden_dim, sae_dim, bias=True).to(device)
        decoder = torch.nn.Linear(sae_dim, hidden_dim, bias=False).to(device)
        torch.nn.init.xavier_uniform_(encoder.weight)
        torch.nn.init.xavier_uniform_(decoder.weight)
        optimizer = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=lr)


        start_epoch = load_checkpoint(layer_idx, encoder, decoder, optimizer,task,resume)

        dataset = torch.utils.data.TensorDataset(all_tokens)
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True) #splits into batches of shuffled hidden vectors while preserving the shape

        try:

            for epoch in range(start_epoch,epochs):
                total_loss = 0
                total_recon_loss = 0
                total_sparsity_loss = 0

                for batch_idx, (batch_vectors,) in enumerate(dataloader):
                    batch_vectors = batch_vectors.to(device)
                    raw_activations = encoder(batch_vectors)
                    sparse_activations = _apply_top_k_sparsity(raw_activations, top_k)
                    reconstructed = decoder(sparse_activations)
                    recon_loss = torch.nn.functional.mse_loss(reconstructed, batch_vectors)
                    sparsity_loss = torch.mean(torch.abs(sparse_activations))



                    total_loss_batch = recon_loss + sparsity_coeff * sparsity_loss
                    optimizer.zero_grad()#clears previous gradients
                    total_loss_batch.backward()#computes gradients
                    optimizer.step()#updates the weights θ=θ−η⋅gradient n is the learning rate

                    total_loss += total_loss_batch.item()
                    total_recon_loss += recon_loss.item()
                    total_sparsity_loss += sparsity_loss.item()

                if epoch % 10 == 0:
                    avg_loss = total_loss / len(dataloader)
                    avg_recon = total_recon_loss / len(dataloader)
                    avg_sparsity = total_sparsity_loss / len(dataloader)
                    print(f"Epoch {epoch}: avg-Loss={avg_loss:.4f}, Recon={avg_recon:.4f}, Sparsity={avg_sparsity:.4f} of layer{layer_idx}")

        except KeyboardInterrupt:
          print(f"\nTraining interrupted at epoch {epoch}. Saving checkpoint for layer {layer_idx}...")
          save_checkpoint(layer_idx, encoder, decoder, optimizer, epoch,task)


        save_checkpoint(layer_idx, encoder, decoder, optimizer, epoch,task)

        encoder.eval()
        decoder.eval()




In [37]:
# @title Function to Train the Encoder for all layers based on a task
import torch
from concurrent.futures import ThreadPoolExecutor
def _training_all_encoders_(task,hidden_states,device):
  max_parallel=3
  # resume=input('retrain or resume from last epoch or start training from scratch? a/b/c ')
  resume='c'
  def wrapped(layer_idx):
    print(f"🔧 Training SAE for layer {layer_idx} of task '{task}'", flush=True)
    return train_sparse_autoencoder(hidden_states, layer_idx, device, resume, task)

  with ThreadPoolExecutor(max_workers=max_parallel) as executor:
    executor.map(wrapped, range(num_layers))

  print(f"✅ Finished all layers for task '{task}'")


In [10]:
# @title Code to mount drive
import os
from google.colab import drive
if( not os.path.exists('/content/drive')):
      drive.mount('/content/drive')

Mounted at /content/drive


In [38]:
# @title Code to execute the training of Encoder
import os, torch, pickle



tasks = ['pos','dep',"position"]


if os.path.exists("last_task.txt"):
  with open("last_task.txt", "r") as f:
    completed_tasks=f.read().strip()

else:
  completed_tasks=''
  open("last_task.txt", "w")


hidden_states=all_hidden_states

for each_task in tasks:
  if each_task in completed_tasks.split(' '):
    continue
  try:
    print(f"started training for task {each_task}")
    sae_models=_training_all_encoders_(each_task,hidden_states,device)

  except KeyboardInterrupt:
    print(f"\nTraining interrupted at task {each_task}.")
    break


  with open("last_task.txt", "a") as f:
    f.write(each_task+' ')

open("last_task.txt", "w")

started training for task pos
🔧 Training SAE for layer 0 of task 'pos'
🔧 Training SAE for layer 1 of task 'pos'
🔧 Training SAE for layer 2 of task 'pos'
Epoch 0: avg-Loss=0.4586, Recon=0.4584, Sparsity=0.0287 of layer1
Epoch 0: avg-Loss=0.5230, Recon=0.5227, Sparsity=0.0303 of layer2
Epoch 0: avg-Loss=0.4231, Recon=0.4228, Sparsity=0.0275 of layer0
Epoch 10: avg-Loss=0.2150, Recon=0.2146, Sparsity=0.0404 of layer1
Epoch 10: avg-Loss=0.1951, Recon=0.1947, Sparsity=0.0372 of layer0
Epoch 10: avg-Loss=0.2395, Recon=0.2390, Sparsity=0.0491 of layer2
Epoch 20: avg-Loss=0.1365, Recon=0.1360, Sparsity=0.0503 of layer1
Epoch 20: avg-Loss=0.1257, Recon=0.1253, Sparsity=0.0441 of layer0
Epoch 20: avg-Loss=0.1489, Recon=0.1482, Sparsity=0.0662 of layer2
Epoch 30: avg-Loss=0.1082, Recon=0.1077, Sparsity=0.0499 of layer1
Epoch 30: avg-Loss=0.0991, Recon=0.0987, Sparsity=0.0442 of layer0
Epoch 30: avg-Loss=0.1200, Recon=0.1194, Sparsity=0.0628 of layer2
Epoch 40: avg-Loss=0.0930, Recon=0.0925, Spars

<_io.TextIOWrapper name='last_task.txt' mode='w' encoding='utf-8'>

In [50]:
# @title Function to interpret strongly activated features for a layer
def _find_interpretable_features(all_features, aligned_data, top_n=10):

        interpretable_features = []
        num_features = all_features.shape[1]

        for feature_idx in range(num_features):
            feature_activations = all_features[:, feature_idx]
            threshold = torch.quantile(feature_activations, 0.90)
            strong_positions = torch.where(feature_activations > threshold)[0]

            if len(strong_positions) < 5:
                continue

            labels_for_feature = []
            for pos in strong_positions:
                labels_for_feature.append(aligned_data[pos.item()]['label']) #aligned data is a list of dictionary representing each token , each token data is just appended

            if labels_for_feature:
                from collections import Counter
                label_counts = Counter(labels_for_feature)
                most_common_count = label_counts.most_common(1)[0][1]
                consistency = most_common_count / len(labels_for_feature)

                interpretable_features.append({
                    'feature_idx': feature_idx,
                    'consistency': consistency,
                    'dominant_label': label_counts.most_common(1)[0][0],
                    'activation_count': len(strong_positions),
                    'label_distribution': dict(label_counts.most_common(5))
                })

        '''
            interpretable_features = [
        {'feature_idx': 5, 'consistency': 0.87},
        {'feature_idx': 2, 'consistency': 0.93},
        {'feature_idx': 7, 'consistency': 0.76}
        ]

        '''

        interpretable_features.sort(key=lambda x: x['consistency'], reverse=True)
        return interpretable_features[:top_n]


In [42]:
# @title Function to compute relation of label and activation score using MI
from sklearn.metrics import mutual_info_score
from collections import defaultdict, Counter

def compute_label_mi(all_features, aligned_data, interpretable_features, threshold=0.5):
    label_mi_scores = defaultdict(float)  # cumulative MI per label

    # Flattened labels per sample
    all_labels = [data["label"] for data in aligned_data]  # or data["labels"] if multilabel
    all_labels = np.array(all_labels)

    for feature in interpretable_features:
        idx = feature["feature_idx"]
        activations = all_features[:, idx].detach().cpu().numpy()

        # Binarize activations (top-k or threshold)
        top_indices = activations.argsort()[-100:]  # top 100 activations (tune as needed)
        bin_activations = np.zeros_like(activations)
        bin_activations[top_indices] = 1

        # Align labels (assumes one label per sample, adapt for multilabel)
        mi = mutual_info_score(bin_activations, all_labels)

        # Count labels among top activations
        top_labels = [all_labels[i] for i in top_indices]
        top_label_counts = Counter(top_labels)

        for label, count in top_label_counts.items():
            label_mi_scores[label] += mi * (count / len(top_indices))  # weight MI by how dominant label is

    # Sort by MI
    sorted_mi_labels = sorted(label_mi_scores.items(), key=lambda x: x[1], reverse=True)
    return sorted_mi_labels[:5]  # top 5 labels

In [48]:
# @title Function to test a Trained encoder
def test_sparse_autoencoder(sae_model, hidden_states, aligned_data,device,layer_idx,top_k=50):

        layer_idx = sae_model['layer_idx']
        encoder = sae_model['encoder']
        top_k = top_k


        all_features = []


        encoder.eval()

        with torch.no_grad():
            for each_batch in hidden_states[layer_idx]:
                seq_len, hidden_dim = each_batch.shape
                each_batch_combined_tokens = each_batch.view(-1, hidden_dim).to(device)
                raw_activations = encoder(each_batch_combined_tokens)
                sparse_features = _apply_top_k_sparsity(raw_activations, top_k)
                all_features.append(sparse_features.cpu())

        all_features = torch.cat(all_features, dim=0)
        print(f"Completed processing all batches for layer {layer_idx}")


        interpretable_features =_find_interpretable_features(all_features, aligned_data, top_n=10)

        label_mi_list = compute_label_mi(all_features, aligned_data, interpretable_features)

        return {
            'layer_idx': layer_idx,
            'features': all_features,
            'interpretable_features': interpretable_features,
            'label_mi_list': label_mi_list

        }


In [51]:
# @title Function to test Trained encoders from all layers
from concurrent.futures import ThreadPoolExecutor

def _testing_all_encoders_(sae_models,hidden_states,aligned_data,device,task):
  max_parallel=3
  each_task_sae_results = {}

  def wrapped(layer_idx):
    print(f" Testing SAE for layer {layer_idx} of task '{task}'", flush=True)
    return test_sparse_autoencoder(sae_models[layer_idx], hidden_states, aligned_data,device,layer_idx)

  with ThreadPoolExecutor(max_workers=max_parallel) as executor:
    futures = executor.map(wrapped, range(num_layers))

    for layer_idx, result in enumerate(futures):
      each_task_sae_results[layer_idx] = result
  print(f" Finished testing all layers for task '{task}'")
  return each_task_sae_results


In [None]:
# @title Code to prepare testing dataset
test_sentences,test_labels_dict=prepare_data(test=True)


hidden_states,test_all_task_alignedData=_extract_representations_(test_sentences, test_labels_dict,tokenizer,num_layers,model,device)

In [None]:
check_alignment(hidden_states, test_all_task_alignedData, task_type='pos')

In [40]:
# @title Function to load the trained sae models
def _load_trained_sae_models_(task):
  models={}
  base_dir=f"/content/drive/MyDrive/SAE_outputs/sae_checkpoints_{task}/"
  for layer_idx in range(num_layers):
    model_path = os.path.join(base_dir, f"layer_{layer_idx}_model.pt")
    checkpoint = torch.load(model_path, map_location=device)
    hidden_dim = checkpoint['encoder_state_dict']['weight'].shape[1]
    sae_dim = checkpoint['encoder_state_dict']['weight'].shape[0]

    encoder = torch.nn.Linear(hidden_dim, sae_dim, bias=True).to(device)
    encoder.load_state_dict(checkpoint['encoder_state_dict'])


    models[layer_idx] = {
        'encoder': encoder,
        'layer_idx': layer_idx
    }
  return models

In [45]:
# @title Code to execute the testing phase
tasks = ['pos','dep',"position"]


sae_models_results={}
if os.path.exists("last_task.txt"):
  with open("last_task.txt", "r") as f:
    completed_tasks=f.read().strip()

else:
  completed_tasks=''
  open("last_task.txt", "w")



for each_task in tasks:
  if each_task in completed_tasks.split(' '):
    continue

  try:
    print(f"started testing for task {each_task}")

    sae_models=_load_trained_sae_models_(each_task)
    sae_models_results[each_task]=_testing_all_encoders_(sae_models,hidden_states,test_all_task_alignedData[each_task],device,each_task)

  except KeyboardInterrupt:
          print(f"\nTraining interrupted at task {each_task}.")
          break
  with open("last_task.txt", "a") as f:
    f.write(each_task+' ')



started testing for task pos
 Testing SAE for layer 0 of task 'pos'
 Testing SAE for layer 1 of task 'pos'
 Testing SAE for layer 2 of task 'pos'
 Testing SAE for layer 3 of task 'pos'
 Testing SAE for layer 4 of task 'pos'
 Testing SAE for layer 5 of task 'pos'
 Finished testing all layers for task 'pos'
started testing for task dep
 Testing SAE for layer 0 of task 'dep' Testing SAE for layer 1 of task 'dep'

 Testing SAE for layer 2 of task 'dep'
 Testing SAE for layer 3 of task 'dep'
 Testing SAE for layer 4 of task 'dep'
 Testing SAE for layer 5 of task 'dep'
 Finished testing all layers for task 'dep'
started testing for task position
 Testing SAE for layer 0 of task 'position'
 Testing SAE for layer 1 of task 'position'
 Testing SAE for layer 2 of task 'position'
 Testing SAE for layer 3 of task 'position'
 Testing SAE for layer 4 of task 'position'
 Testing SAE for layer 5 of task 'position'
 Finished testing all layers for task 'position'


In [52]:
# @title Code to display the interpretation of each layer
from collections import defaultdict, Counter

def summarize_labels_per_layer_across_tasks(sae_output):
    layer_label_map = defaultdict(list)  # {layer_idx: [dominant_labels from all tasks]}

    for task_data in sae_output.values():
        for layer_idx, layer_data in task_data.items():
            for feat in layer_data['interpretable_features']:
                label = feat['dominant_label']
                layer_label_map[layer_idx].append(label)

    print("GLOBAL SAE LAYER SUMMARY (All Tasks)\n")
    for layer_idx in sorted(layer_label_map.keys()):
        label_counts = Counter(layer_label_map[layer_idx])
        top_labels = label_counts.most_common(3)

        print(f"  Layer {layer_idx}:")
        print(f"   - Dominant label     : {top_labels[0][0]} ({top_labels[0][1]} features)")
        if len(top_labels) > 1:
            print("   - Other strong labels:")
            for label, count in top_labels[1:]:
                print(f"       - {label}: {count}")
        else:
            print("   - No other strong label")

        total_feats = sum(label_counts.values())
        print(f"  - Total features     : {total_feats}")
        print()

summarize_labels_per_layer_across_tasks(sae_models_results)

GLOBAL SAE LAYER SUMMARY (All Tasks)

  Layer 0:
   - Dominant label     : O (13 features)
   - Other strong labels:
       - PRON: 7
       - det: 3
  - Total features     : 30

  Layer 1:
   - Dominant label     : O (25 features)
   - Other strong labels:
       - PRON: 3
       - nsubj: 1
  - Total features     : 30

  Layer 2:
   - Dominant label     : O (26 features)
   - Other strong labels:
       - ADP: 1
       - PROPN: 1
  - Total features     : 30

  Layer 3:
   - Dominant label     : O (25 features)
   - Other strong labels:
       - punct: 2
       - PUNCT: 1
  - Total features     : 30

  Layer 4:
   - Dominant label     : O (19 features)
   - Other strong labels:
       - PUNCT: 6
       - punct: 3
  - Total features     : 30

  Layer 5:
   - Dominant label     : O (28 features)
   - Other strong labels:
       - 0: 2
  - Total features     : 30

