In [3]:
import torch
from transformer_lens import HookedTransformer
import warnings

import os
from tqdm.auto import tqdm
from collections import Counter

from sklearn.model_selection import train_test_split
import numpy as np

import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

import matplotlib.pyplot as plt


  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Suppress a specific UserWarning from SentencePiece
warnings.filterwarnings("ignore", message=r".*Ignoring tokenizer_config\.json since it is not set\. It is likely that you are loading a tokenizer from a previous version of the library which does not contain this file\..*")

print("--- Environment Setup ---")

# Check if a GPU is available and set the device
if torch.cuda.is_available():
    device = "cuda"
    print(f"GPU detected: {torch.cuda.get_device_name(0)}")
    # Clear cache to free up memory on the GPU
    torch.cuda.empty_cache()
else:
    device = "cpu"
    print("No GPU detected. Using CPU. This will be very slow.")


--- Environment Setup ---
GPU detected: NVIDIA GeForce RTX 3090


In [5]:

# Load Llama-3-8B-Instruct directly into a HookedTransformer in bfloat16.
# No quantization (e.g. `load_in_4bit`) is used: the full-precision bf16 weights
# (~16 GB) fit on a 24 GB RTX 3090, which natively supports bfloat16.
model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
print(f"\n--- Loading Model: {model_name} ---")
print("This will download and load ~16 GB of model weights. This may take several minutes.")

# `dtype` is HookedTransformer.from_pretrained's own precision argument.
# The legacy `torch_dtype=` spelling is only a backwards-compatibility alias, and
# using it triggered transformers' "`torch_dtype` is deprecated! Use `dtype`
# instead!" warning.
model = HookedTransformer.from_pretrained(
    model_name,
    device=device,
    dtype=torch.bfloat16,
)

# Fail fast if the weights silently fell back to float32 (~32 GB): that would
# not fit alongside activations on a 24 GB card.
assert model.cfg.dtype == torch.bfloat16, f"unexpected model dtype: {model.cfg.dtype}"




--- Loading Model: meta-llama/Meta-Llama-3-8B-Instruct ---
This will download and load ~16 GB of model weights. This may take several minutes.


`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 39.18it/s]


Loaded pretrained model meta-llama/Meta-Llama-3-8B-Instruct into HookedTransformer
