In [None]:
from daily_dialog import DailyDialog

# Instantiate the DailyDialog builder, run its download/prepare step, and
# materialize the result as `dataset` (tokenized per split further down).
builder = DailyDialog()
builder.download_and_prepare()
dataset = builder.as_dataset()

from transformers import AutoTokenizer
from datasets import DatasetDict

# Load the pre-trained 'bert-base-uncased' tokenizer. tokenize_function below
# reads this module-level `tokenizer` name at call time.
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# Tokenize the dialogues (flattening lists of utterances)
def tokenize_function(example):
    """Tokenize the 'dialog' entry of a dataset example or batch.

    When ``example['dialog']`` is a list, every element that is itself a list
    (the utterances of one dialogue) is joined into a single space-separated
    string; elements that are not lists are passed through unchanged.

    The joined text is held in a local variable instead of being written back
    to ``example['dialog']``, so the mapping supplied by the caller is left
    unmodified.

    Args:
        example: Mapping with a 'dialog' entry, as supplied by ``dataset.map``
            (this notebook calls it with ``batched=True``).

    Returns:
        Whatever the module-level ``tokenizer`` returns when called with
        ``padding='max_length'``, ``truncation=True`` and ``max_length=128``.
    """
    dialogs = example['dialog']
    if isinstance(dialogs, list):
        # Flatten each dialogue (list of utterances) into one string.
        dialogs = [
            ' '.join(dialog) if isinstance(dialog, list) else dialog
            for dialog in dialogs
        ]

    return tokenizer(dialogs, padding='max_length', truncation=True, max_length=128)

# Apply tokenize_function to the entire dataset. With batched=True the
# function is handed batches of examples rather than one example at a time.
tokenized_datasets = dataset.map(tokenize_function, batched=True)


# Expose only 'input_ids' and 'attention_mask', formatted as torch tensors.
# The return value is not captured: this relies on set_format updating
# tokenized_datasets in place.
tokenized_datasets.set_format(type='torch', columns=['input_ids', 'attention_mask'])


In [None]:
def _preview_value(value, limit=10):
    """Return a printable form of ``value``, shortened if it is a long sequence."""
    # Strings are printed whole; only sequence-like fields are shortened.
    if not isinstance(value, (str, bytes)):
        try:
            if len(value) > limit:
                return f"{value[:limit]}... (truncated)"
        except (TypeError, KeyError, IndexError):
            # No len() (scalars, 0-d tensors) or no slicing (dicts, sets):
            # fall through and print the value unchanged.
            pass
    return f"{value}"


def print_dataset_structure(dataset):
    """Print the row count, column names and first row of every split.

    Args:
        dataset: Mapping of split name to split data. Each split must support
            ``len()``, expose ``column_names`` and return a mapping of
            column name to value for ``split[0]``.

    Values in the sample row that are longer than 10 elements are shortened
    to their first 10. This applies to any sliceable sequence (lists, tuples,
    tensors, arrays), not only to ``list``, so formatted tensor columns are
    abbreviated as well. A split with no rows is reported instead of raising
    on ``split[0]``.
    """
    for split_name, split_data in dataset.items():
        print(f"\n--- {split_name.upper()} SPLIT ---")
        print(f"Number of rows: {len(split_data)}")
        print("Columns:", split_data.column_names)

        # Nothing to sample from: avoid indexing an empty split.
        if len(split_data) == 0:
            print("\nSample data: (split is empty)")
            continue

        # Print a sample of the data (the first entry)
        print("\nSample data (first entry):")
        sample = split_data[0]
        for key, value in sample.items():
            print(f"{key}: {_preview_value(value)}")

# Print row count, column names and the first entry of every split in
# tokenized_datasets.
print_dataset_structure(tokenized_datasets)


In [None]:
import train

# Launch training for 50 epochs, with best_loss starting at +infinity.
# NOTE(review): a later cell calls train.main(total_epoch=1, batch_size=32)
# instead of train.run(...) — confirm which entry point the train module
# actually exposes and whether this cell is still needed.
train.run(total_epoch=50, best_loss=float('inf'))

In [1]:
import matplotlib.pyplot as plt
import seaborn as sns
import train 
%matplotlib inline
import analysis.sae_analyzer as SAEAnalyzer

# Run the training
# When training:
model, valid_iter, device, sae_metadata = train.main(total_epoch=1, batch_size=32)

Repo card metadata block was not found. Setting CardData to empty.


Initialized DialogLoader with bert-base-uncased tokenizer
Loading DailyDialog dataset...
Dataset loaded successfully
Tokenizing dataset...
Formatting dataset for PyTorch...
Dataset preparation complete
Training step: 0/348 (0.0%), Loss: 75423.3359
Training step: 34/348 (9.8%), Loss: 107.2147
Training step: 68/348 (19.5%), Loss: 17.0420
Training step: 102/348 (29.3%), Loss: 7.9631
Training step: 136/348 (39.1%), Loss: 3.8269
Training step: 170/348 (48.9%), Loss: 3.6748
Training step: 204/348 (58.6%), Loss: 3.1389
Training step: 238/348 (68.4%), Loss: 3.2853
Training step: 272/348 (78.2%), Loss: 3.0937
Training step: 306/348 (87.9%), Loss: 2.9659
Training step: 340/348 (97.7%), Loss: 3.2965
Training step: 347/348 (99.7%), Loss: 4.0055
Epoch: 1 | Time: 0m 49s
	Train Loss: 1474.612
	Valid Loss: 1.882
	Act Accuracy: 0.549
	Emotion Accuracy: 0.702

Training completed! Best checkpoint saved at:
saved/model-1.882.pt
Best validation loss: 1.882

Collecting SAE metadata...
Debug - encoded shape:

In [3]:
import sys

# Forget any cached copy of the analyzer module so that the import below
# executes the module again instead of reusing the cached one.
sys.modules.pop('analysis.sae_analyzer', None)

from analysis.sae_analyzer import SAEAnalyzer

In [4]:
# Build the analyzer straight from the metadata returned by the training cell.
analyzer = SAEAnalyzer.from_metadata(sae_metadata)

# Draw each diagnostic figure in turn, showing it before the next one is made.
for plot_method in (
    "plot_activation_distribution",
    "plot_feature_correlations",
    "plot_reconstruction_error",
):
    getattr(analyzer, plot_method)()
    plt.show()

KeyError: 'activations'

In [None]:
# Feature-correlation plot on its own; the same call also appears in the
# combined analysis cell above.
analyzer.plot_feature_correlations()
plt.show()

In [None]:
# NOTE(review): this cell is identical to the one directly above it — one of
# the two can probably be removed.
analyzer.plot_feature_correlations()
plt.show()

In [None]:
# Reconstruction-error plot on its own; the same call also appears in the
# combined analysis cell above.
analyzer.plot_reconstruction_error()
plt.show()