# The below code is for the single attention 

In [2]:
# ```python
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns
import re
import os

# Print versions and device info
print("PyTorch Version:", torch.__version__)
print("NumPy Version:", np.__version__)
print("Matplotlib Version:", matplotlib.__version__)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
if torch.cuda.is_available():
    print("GPU Name:", torch.cuda.get_device_name(0))

# Set random seed for reproducibility
torch.manual_seed(42)

# Sample text (Indus Valley Civilization discovery)
text = """
In 1856, British colonial officials in India were busy monitoring the construction of a railway connecting the cities of Lahore and Karachi in modern-day Pakistan along the Indus River valley.
As they continued to work, some of the laborers discovered many fire-baked bricks lodged in the dry terrain. There were hundreds of thousands of fairly uniform bricks, which seemed to be quite old. Nonetheless, the workers used some of them to construct the road bed, unaware that they were using ancient artifacts. They soon found among the bricks stone artifacts made of soapstone, featuring intricate artistic markings.
Though they did not know it then, and though the first major excavations did not take place until the 1920s, these railway workers had happened upon the remnants of the Indus Valley Civilization, also known as the Harappan Civilization, after Harappa, the first of its sites to be excavated, in what was then the Punjab province of British India and is now in Pakistan. Initially, many archaeologists thought they had found ruins of the ancient Maurya Empire, a large empire which dominated ancient India between c. 322 and 185 BCE.
Before the excavation of these Harappan cities, scholars thought that Indian civilization had begun in the Ganges valley as Aryan immigrants from Persia and central Asia populated the region around 1250 BCE. The discovery of ancient Harappan cities unsettled that conception and moved the timeline back another 1500 years,situating the Indus Valley Civilization in an entirely different environmental context.
"""

# Preprocess text: lowercase, remove extra spaces
text = re.sub(r'\s+', ' ', text.lower()).strip()
chars = sorted(list(set(text)))
char_to_int = {c: i for i, c in enumerate(chars)}
int_to_char = {i: c for i, c in enumerate(chars)}
n_chars = len(text)
n_vocab = len(chars)

# Prepare data: sequences of length seq_length, predict next character
seq_length = 20
dataX, dataY = [], []
for i in range(0, n_chars - seq_length):
    seq_in = text[i:i + seq_length]
    seq_out = text[i + seq_length]
    dataX.append([char_to_int[char] for char in seq_in])
    dataY.append(char_to_int[seq_out])
n_patterns = len(dataX)
print(f"Number of training sequences: {n_patterns}")

# Convert to tensors
X = torch.tensor(dataX, dtype=torch.float32).reshape(n_patterns, seq_length, 1) / float(n_vocab)
y = torch.tensor(dataY, dtype=torch.long)

# Define LSTM with enhanced attention (Transformer-style)
class LSTMWithAttention(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, seq_length, temperature=1.0):
        super(LSTMWithAttention, self).__init__()
        self.hidden_dim = hidden_dim
        self.seq_length = seq_length
        self.temperature = temperature
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True, dropout=0.3)
        self.query = nn.Linear(hidden_dim, hidden_dim)
        self.key = nn.Linear(hidden_dim, hidden_dim)
        self.value = nn.Linear(hidden_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        lstm_out, _ = self.lstm(x)  # lstm_out: (batch_size, seq_length, hidden_dim)
        q = self.query(lstm_out)  # (batch_size, seq_length, hidden_dim)
        k = self.key(lstm_out)  # (batch_size, seq_length, hidden_dim)
        v = self.value(lstm_out)  # (batch_size, seq_length, hidden_dim)
        scores = torch.bmm(q, k.transpose(1, 2)) / (self.hidden_dim ** 0.5)  # (batch_size, seq_length, seq_length)
        scores = scores / self.temperature  # Sharpen attention
        attn_weights = self.softmax(scores)  # (batch_size, seq_length, seq_length)
        context = torch.bmm(attn_weights, v)  # (batch_size, seq_length, hidden_dim)
        last_context = context[:, -1, :]  # (batch_size, hidden_dim)
        out = self.fc(last_context)  # (batch_size, output_dim)
        return out, attn_weights

# Hyperparameters
input_dim = 1
hidden_dim = 256
output_dim = n_vocab
batch_size = 32
temperature = 0.3  # Further sharpened attention

# Initialize model, loss, and optimizer
model = LSTMWithAttention(input_dim, hidden_dim, output_dim, seq_length, temperature).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)

# Move data to device
X = X.to(device)
y = y.to(device)

# Train the model
epochs = 300  # Increased for better convergence
for epoch in range(epochs):
    model.train()
    total_loss = 0
    num_batches = max(1, (n_patterns + batch_size - 1) // batch_size)
    for i in range(0, n_patterns, batch_size):
        batch_X = X[i:i+batch_size]
        batch_y = y[i:i+batch_size]
        optimizer.zero_grad()
        outputs, _ = model(batch_X)
        loss = criterion(outputs, batch_y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # Gradient clipping
        optimizer.step()
        total_loss += loss.item()
    scheduler.step()
    avg_loss = total_loss / num_batches if num_batches > 0 else total_loss
    print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}, LR: {scheduler.get_last_lr()[0]:.6f}")

# Compute connectivity for a sample input
model.eval()
sample_input = text[:seq_length]
input_data = torch.tensor([[char_to_int[c] for c in sample_input]], dtype=torch.float32).reshape(1, seq_length, 1) / float(n_vocab)
input_data = input_data.to(device)
with torch.no_grad():
    _, attn_weights = model(input_data)
attn_weights = attn_weights.squeeze(0).cpu().numpy()  # Shape: (seq_length, seq_length)

# Debug attention weights
print("Attention weights shape:", attn_weights.shape)
print("Attention weights min:", attn_weights.min(), "max:", attn_weights.max())
print("Attention weights sample:", attn_weights[:3, :3])
print("Input characters:", [text[i] for i in range(seq_length)])
print("Output characters:", [text[i+1] for i in range(seq_length)])
# Highlight focus for multiple output characters
for output_idx in range(5, 20):
    focus = attn_weights[output_idx]
    max_focus_idx = np.argmax(focus)
    print(f"When decoding output character {output_idx+1} ('{text[output_idx+1]}'), the model focuses most on input character {max_focus_idx+1} ('{text[max_focus_idx]}') with weight {focus[max_focus_idx]:.4f}")

# Normalize attention weights for better contrast
attn_weights = (attn_weights - attn_weights.min()) / (attn_weights.max() - attn_weights.min())


# Plot heatmap
plt.figure(figsize=(14, 12))
sns.heatmap(attn_weights, xticklabels=[text[i] for i in range(seq_length)],
            yticklabels=[text[i+1] for i in range(seq_length)],
            cmap='Blues', annot=True, fmt='.2f', cbar_kws={'label': 'Attention Weight'})
plt.title('LSTM Connectivity: Input-Output Character Dependencies', fontsize=16)
plt.xlabel('Input Characters', fontsize=14)
plt.ylabel('Output Characters', fontsize=14)
plt.xticks(rotation=45)
plt.yticks(rotation=0)

# Save heatmap to an absolute path
save_path = os.path.join(os.getcwd(), 'connectivity_heatmap_1.png')
plt.savefig(save_path, bbox_inches='tight', dpi=300)
print(f"Heatmap saved to: {save_path}")
plt.show()

PyTorch Version: 2.6.0+cu124
NumPy Version: 1.26.4
Matplotlib Version: 3.9.2
Using device: cuda
GPU Name: NVIDIA GeForce RTX 3050
Number of training sequences: 1538




Epoch 1/300, Loss: 3.1646, LR: 0.001000
Epoch 2/300, Loss: 3.0214, LR: 0.001000
Epoch 3/300, Loss: 3.0140, LR: 0.001000
Epoch 4/300, Loss: 3.0083, LR: 0.001000
Epoch 5/300, Loss: 3.0043, LR: 0.001000
Epoch 6/300, Loss: 3.0017, LR: 0.001000
Epoch 7/300, Loss: 2.9999, LR: 0.001000
Epoch 8/300, Loss: 2.9983, LR: 0.001000
Epoch 9/300, Loss: 2.9970, LR: 0.001000
Epoch 10/300, Loss: 2.9957, LR: 0.001000
Epoch 11/300, Loss: 2.9952, LR: 0.001000
Epoch 12/300, Loss: 2.9937, LR: 0.001000
Epoch 13/300, Loss: 2.9917, LR: 0.001000
Epoch 14/300, Loss: 2.9880, LR: 0.001000
Epoch 15/300, Loss: 2.9819, LR: 0.001000
Epoch 16/300, Loss: 2.9709, LR: 0.001000
Epoch 17/300, Loss: 2.9565, LR: 0.001000
Epoch 18/300, Loss: 2.9548, LR: 0.001000
Epoch 19/300, Loss: 2.9485, LR: 0.001000
Epoch 20/300, Loss: 2.9424, LR: 0.001000
Epoch 21/300, Loss: 2.9389, LR: 0.001000
Epoch 22/300, Loss: 2.9361, LR: 0.001000
Epoch 23/300, Loss: 2.9294, LR: 0.001000
Epoch 24/300, Loss: 2.9266, LR: 0.001000
Epoch 25/300, Loss: 2.923

  plt.show()


## Multi head attention code 

In [22]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns
import re
import os

# Print versions and device info
print("PyTorch Version:", torch.__version__)
print("NumPy Version:", np.__version__)
print("Matplotlib Version:", matplotlib.__version__)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
if torch.cuda.is_available():
    print("GPU Name:", torch.cuda.get_device_name(0))

# Set random seed for reproducibility
torch.manual_seed(42)

# Sample text (Indus Valley Civilization discovery)
text = """
In 1856, British colonial officials in India were busy monitoring the construction of a railway connecting the cities of Lahore and Karachi in modern-day Pakistan along the Indus River valley.
As they continued to work, some of the laborers discovered many fire-baked bricks lodged in the dry terrain. There were hundreds of thousands of fairly uniform bricks, which seemed to be quite old. Nonetheless, the workers used some of them to construct the road bed, unaware that they were using ancient artifacts. They soon found among the bricks stone artifacts made of soapstone, featuring intricate artistic markings.
Though they did not know it then, and though the first major excavations did not take place until the 1920s, these railway workers had happened upon the remnants of the Indus Valley Civilization, also known as the Harappan Civilization, after Harappa, the first of its sites to be excavated, in what was then the Punjab province of British India and is now in Pakistan. Initially, many archaeologists thought they had found ruins of the ancient Maurya Empire, a large empire which dominated ancient India between c. 322 and 185 BCE.
Before the excavation of these Harappan cities, scholars thought that Indian civilization had begun in the Ganges valley as Aryan immigrants from Persia and central Asia populated the region around 1250 BCE. The discovery of ancient Harappan cities unsettled that conception and moved the timeline back another 1500 years,situating the Indus Valley Civilization in an entirely different environmental context.
"""

# Preprocess text: lowercase, remove extra spaces
text = re.sub(r'\s+', ' ', text.lower()).strip()
chars = sorted(list(set(text)))
char_to_int = {c: i for i, c in enumerate(chars)}
int_to_char = {i: c for i, c in enumerate(chars)}
n_chars = len(text)
n_vocab = len(chars)

# Prepare data: sequences of length seq_length, predict next character
seq_length = 20
dataX, dataY = [], []
for i in range(0, n_chars - seq_length):
    seq_in = text[i:i + seq_length]
    seq_out = text[i + seq_length]
    dataX.append([char_to_int[char] for char in seq_in])
    dataY.append(char_to_int[seq_out])
n_patterns = len(dataX)
print(f"Number of training sequences: {n_patterns}")

# Convert to tensors
X = torch.tensor(dataX, dtype=torch.float32).reshape(n_patterns, seq_length, 1) / float(n_vocab)
y = torch.tensor(dataY, dtype=torch.long)

# Define LSTM with multi-head attention
class LSTMWithMultiHeadAttention(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, seq_length, num_heads=4, temperature=1.0):
        super(LSTMWithMultiHeadAttention, self).__init__()
        self.hidden_dim = hidden_dim
        self.seq_length = seq_length
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        self.temperature = temperature
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True, dropout=0.3)
        self.query = nn.Linear(hidden_dim, hidden_dim)
        self.key = nn.Linear(hidden_dim, hidden_dim)
        self.value = nn.Linear(hidden_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        batch_size = x.size(0)
        lstm_out, _ = self.lstm(x)  # (batch_size, seq_length, hidden_dim)
        
        # Multi-head attention
        q = self.query(lstm_out).view(batch_size, self.seq_length, self.num_heads, self.head_dim)  # (batch_size, seq_length, num_heads, head_dim)
        k = self.key(lstm_out).view(batch_size, self.seq_length, self.num_heads, self.head_dim)
        v = self.value(lstm_out).view(batch_size, self.seq_length, self.num_heads, self.head_dim)
        
        q = q.transpose(1, 2)  # (batch_size, num_heads, seq_length, head_dim)
        k = k.transpose(1, 2)  # (batch_size, num_heads, seq_length, head_dim)
        v = v.transpose(1, 2)  # (batch_size, num_heads, seq_length, head_dim)
        
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)  # (batch_size, num_heads, seq_length, seq_length)
        scores = scores / self.temperature
        attn_weights = self.softmax(scores)  # (batch_size, num_heads, seq_length, seq_length)
        
        context = torch.matmul(attn_weights, v)  # (batch_size, num_heads, seq_length, head_dim)
        context = context.transpose(1, 2).contiguous().view(batch_size, self.seq_length, self.hidden_dim)  # (batch_size, seq_length, hidden_dim)
        
        # Average attention weights across heads
        attn_weights = attn_weights.mean(dim=1)  # (batch_size, seq_length, seq_length)
        
        last_context = context[:, -1, :]  # (batch_size, hidden_dim)
        out = self.fc(last_context)  # (batch_size, output_dim)
        return out, attn_weights

# Hyperparameters
input_dim = 1
hidden_dim = 256
output_dim = n_vocab
batch_size = 32
num_heads = 4
temperature = 1.0  # Softer attention for more distributed focus

# Initialize model, loss, and optimizer
model = LSTMWithMultiHeadAttention(input_dim, hidden_dim, output_dim, seq_length, num_heads, temperature).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)

# Move data to device
X = X.to(device)
y = y.to(device)

# Train the model
epochs = 300
for epoch in range(epochs):
    model.train()
    total_loss = 0
    num_batches = max(1, (n_patterns + batch_size - 1) // batch_size)
    for i in range(0, n_patterns, batch_size):
        batch_X = X[i:i+batch_size]
        batch_y = y[i:i+batch_size]
        optimizer.zero_grad()
        outputs, _ = model(batch_X)
        loss = criterion(outputs, batch_y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        total_loss += loss.item()
    scheduler.step()
    avg_loss = total_loss / num_batches if num_batches > 0 else total_loss
    print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}, LR: {scheduler.get_last_lr()[0]:.6f}")

# Compute connectivity for a sample input
model.eval()
sample_input = text[:seq_length]
input_data = torch.tensor([[char_to_int[c] for c in sample_input]], dtype=torch.float32).reshape(1, seq_length, 1) / float(n_vocab)
input_data = input_data.to(device)
with torch.no_grad():
    _, attn_weights = model(input_data)
attn_weights = attn_weights.squeeze(0).cpu().numpy()  # Shape: (seq_length, seq_length)

# Debug attention weights
print("Attention weights shape:", attn_weights.shape)
print("Attention weights min:", attn_weights.min(), "max:", attn_weights.max())
print("Attention weights sample:", attn_weights[:3, :3])
print("Input characters:", [text[i] for i in range(seq_length)])
print("Output characters:", [text[i+1] for i in range(seq_length)])

# Highlight focus for output characters
for output_idx in range(5, 15):
    focus = attn_weights[output_idx]
    max_focus_idx = np.argmax(focus)
    print(f"When decoding output character {output_idx+1} ('{text[output_idx+1]}'), the model focuses most on input character {max_focus_idx+1} ('{text[max_focus_idx]}') with weight {focus[max_focus_idx]:.4f}")

# Analyze attention weights for input 't' (index 12)
t_idx = 12  # 0-based index of 't'
print("\nAttention weights for input character 't' (index 13):")
for output_idx in range(seq_length):
    weight = attn_weights[output_idx, t_idx]
    if weight > 0.1:  # Threshold to highlight significant weights
        print(f"Output character {output_idx+1} ('{text[output_idx+1]}'): {weight:.4f}")

# Normalize attention weights for better contrast
attn_weights = (attn_weights - attn_weights.min()) / (attn_weights.max() - attn_weights.min())

# Apply log scale for better visualization
attn_weights = np.log1p(attn_weights)  # log(1 + x) to handle small values

# Plot heatmap
plt.figure(figsize=(14, 12))
sns.heatmap(attn_weights, xticklabels=[text[i] for i in range(seq_length)],
            yticklabels=[text[i+1] for i in range(seq_length)],
            cmap='Blues', annot=True, fmt='.2f', cbar_kws={'label': 'Log Attention Weight'})
plt.title('LSTM Connectivity: Input-Output Character Dependencies', fontsize=16)
plt.xlabel('Input Characters', fontsize=14)
plt.ylabel('Output Characters', fontsize=14)
plt.xticks(rotation=45)
plt.yticks(rotation=0)

# Save heatmap to an absolute path
save_path = os.path.join(os.getcwd(), 'connectivity_heatmap_adjusted.png')
plt.savefig(save_path, bbox_inches='tight', dpi=300)
print(f"Heatmap saved to: {save_path}")
plt.close()

# Optional: Display plot in Jupyter (comment out if running as script)
# plt.show()


PyTorch Version: 2.6.0+cu124
NumPy Version: 1.26.4
Matplotlib Version: 3.9.2
Using device: cuda
GPU Name: NVIDIA GeForce RTX 3050
Number of training sequences: 1538
Epoch 1/300, Loss: 3.1501, LR: 0.001000




Epoch 2/300, Loss: 3.0202, LR: 0.001000
Epoch 3/300, Loss: 3.0140, LR: 0.001000
Epoch 4/300, Loss: 3.0084, LR: 0.001000
Epoch 5/300, Loss: 3.0046, LR: 0.001000
Epoch 6/300, Loss: 3.0022, LR: 0.001000
Epoch 7/300, Loss: 3.0003, LR: 0.001000
Epoch 8/300, Loss: 2.9988, LR: 0.001000
Epoch 9/300, Loss: 2.9976, LR: 0.001000
Epoch 10/300, Loss: 2.9966, LR: 0.001000
Epoch 11/300, Loss: 2.9956, LR: 0.001000
Epoch 12/300, Loss: 2.9948, LR: 0.001000
Epoch 13/300, Loss: 2.9941, LR: 0.001000
Epoch 14/300, Loss: 2.9934, LR: 0.001000
Epoch 15/300, Loss: 2.9928, LR: 0.001000
Epoch 16/300, Loss: 2.9922, LR: 0.001000
Epoch 17/300, Loss: 2.9916, LR: 0.001000
Epoch 18/300, Loss: 2.9911, LR: 0.001000
Epoch 19/300, Loss: 2.9906, LR: 0.001000
Epoch 20/300, Loss: 2.9900, LR: 0.001000
Epoch 21/300, Loss: 2.9888, LR: 0.001000
Epoch 22/300, Loss: 2.9860, LR: 0.001000
Epoch 23/300, Loss: 2.9882, LR: 0.001000
Epoch 24/300, Loss: 2.9952, LR: 0.001000
Epoch 25/300, Loss: 2.9901, LR: 0.001000
Epoch 26/300, Loss: 2.98

In [None]:

# # Optional: Display plot in Jupyter (comment out if running as script)
# # plt.show()

# # Save model weights (optional)
# torch.save(model.state_dict(), 'lstm_model_weights.pth')
# ```

# ### Changes Made
# 1. **Attention Mechanism**:
#    - Added a `value` layer (`self.value`) to make the attention mechanism more Transformer-like.
#    - Reduced `temperature` to 0.3 for even sharper focus.
#    - Increased dropout to 0.3 to prevent overfitting.

# 2. **Training**:
#    - Extended to 300 epochs for better convergence.
#    - Added gradient clipping (`clip_grad_norm_`) to stabilize training.
#    - Kept the learning rate scheduler.

# 3. **Visualization**:
#    - Increased figure size to `(14, 12)` for better readability.
#    - Added `annot=True` to display attention weights on the heatmap.
#    - Added a colorbar label (“Attention Weight”) for clarity.
#    - Rotated x-axis labels for better visibility.

# 4. **Debugging**:
#    - Expanded the focus analysis to cover output characters 6 to 15 (indices 5 to 14).
#    - Kept detailed debugging of attention weights (shape, min, max, sample).

# 5. **Distill Style**:
#    - Ensured each row has a clear peak (bright spot) by sharpening attention weights.
#    - Used `Blues` colormap and normalized weights for high contrast, matching the Distill article’s style.

# ### Expected Heatmap Appearance
# The heatmap should now closely resemble the Distill article’s “Connectivity” figure:
# - **Shape**: 20x20 grid (`seq_length=20`).
# - **Axes**:
#   - X-axis: Input characters (e.g., `"in 1856, british col"`).
#   - Y-axis: Output characters (e.g., `"n 1856, british colo"`).
# - **Patterns**:
#   - Each row (output character) should have a distinct bright spot (dark blue) indicating the input character with the highest attention weight.
#   - You might see diagonal or localized patterns, e.g., the model focusing on nearby characters or punctuation (e.g., `","` after `"1856"`).
#   - Example: When predicting the character after `"in 1856,"`, the model might focus on `","` or `"6"`, shown as a dark blue cell in that row.
# - **Annotations**: Each cell shows the attention weight (e.g., 0.85), making it easy to identify the primary focus.
# - **Debug Output**: The console will show specific focuses, e.g.:
#   ```
#   When decoding output character 6 ('5'), the model focuses most on input character 4 ('1') with weight 0.7234
#   ```
#   This indicates the model is using the `"1"` in `"1856"` to predict the next digit `"5"`, which is a reasonable dependency.

# ### Instructions to Run
# 1. **Verify Environment**:
#    - Ensure you’re in the `pytorch_env`:
#      ```bash
#      conda activate pytorch_env
#      ```
#    - Check versions:
#      ```bash
#      pip show pytorch numpy matplotlib
#      ```
#      Expected:
#      - PyTorch: 2.4.0
#      - NumPy: 2.2.4
#      - Matplotlib: >=3.9.0
#    - If missing, reinstall:
#      ```bash
#      conda install -c pytorch -c nvidia pytorch torchvision torchaudio pytorch-cuda=12.2
#      pip install numpy==2.2.4 matplotlib>=3.9.0 seaborn optree>=0.13.0
#      ```

# 2. **Check GPU**:
#    - Verify driver:
#      ```bash
#      nvidia-smi
#      ```
#      Ensure version >=522.06.
#    - Test CUDA:
#      ```python
#      import torch
#      print(torch.cuda.is_available())
#      print(torch.cuda.get_device_name(0))
#      ```

# 3. **Run the Code**:
#    - Save as `lstm_connectivity_pytorch_distill_final.py` or copy into a Jupyter Notebook cell.
#    - Run:
#      ```bash
#      python lstm_connectivity_pytorch_distill_final.py
#      ```
#    - Output:
#      - Console:
#        ```
#        PyTorch Version: 2.4.0
#        NumPy Version: 2.2.4
#        Matplotlib Version: 3.9.0
#        Using device: cuda
#        GPU Name: NVIDIA GeForce RTX 3060
#        Number of training sequences: ~594
#        Epoch 1/300, Loss: ..., LR: 0.001000
#        ...
#        Attention weights shape: (20, 20)
#        Attention weights min: ... max: ...
#        Attention weights sample: [[...]]
#        Input characters: ['i', 'n', ' ', '1', '8', '5', '6', ',', ' ', 'b', 'r', 'i', 't', 'i', 's', 'h', ' ', 'c', 'o', 'l']
#        Output characters: ['n', ' ', '1', '8', '5', '6', ',', ' ', 'b', 'r', 'i', 't', 'i', 's', 'h', ' ', 'c', 'o', 'l', 'o']
#        When decoding output character 6 ('5'), the model focuses most on input character ... ('...') with weight ...
#        ...
#        Heatmap saved to: C:\path\to\current\directory\connectivity_heatmap.png
#        ```
#      - Heatmap (`connectivity_heatmap.png`):
#        - Rows: Output characters.
#        - Columns: Input characters.
#        - Values: Attention weights with annotations, showing clear focus for each output character.

# 4. **Inspect the Heatmap**:
#    - Open `connectivity_heatmap.png` in the directory printed in the console.
#    - Look for a bright spot in each row indicating the input character the model focuses on.
#    - Verify that the patterns make sense (e.g., focus on nearby characters, punctuation, or word boundaries).

# ### Troubleshooting
# - **Uniform Heatmap**:
#   - Check the attention weights range:
#     - If `min` and `max` are close, reduce `temperature` further:
#       ```python
#       temperature = 0.1
#       ```
#     - Increase `epochs` to 400:
#       ```python
#       epochs = 400
#       ```

# - **Random Patterns**:
#   - Increase dropout:
#     ```python
#     self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True, dropout=0.4)
#     ```
#   - Adjust learning rate decay:
#     ```python
#     scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
#     ```

# - **Unintuitive Focus**:
#   - Check the focus output for output characters 6–15. If the model focuses on unrelated characters (e.g., focusing on `"i"` when predicting `"5"`), the attention mechanism might need further tuning.
#   - Add a multi-head attention mechanism (I can provide this in a follow-up if needed).

# - **Visualization Issues**:
#   - If annotations are hard to read, reduce the font size:
#     ```python
#     sns.heatmap(attn_weights, ..., annot=True, fmt='.2f', annot_kws={'size': 8})
#     ```
#   - Try a different colormap:
#     ```python
#     sns.heatmap(attn_weights, ..., cmap='viridis')
#     ```

# ### Follow-Up
# - **Share Heatmap Details**: Describe the heatmap’s appearance (e.g., “each row has a clear bright spot,” “still uniform,” “focuses on punctuation”). Share the debug output (attention weights shape, min, max, sample, and focus for output characters 6–15).
# - **Expected Patterns**: If you expect specific focus (e.g., “should focus on the previous character,” “should focus on punctuation”), I can adjust the attention mechanism or training.
# - **Further Refinement**: Need more epochs, a different temperature, or a larger dataset? Let me know.
# - **GPU Check**: If training is slow (should take ~2-3 minutes for 300 epochs), share `nvidia-smi` output.

# Let me know how the heatmap looks now or share the debug output for further refinement!