# Synthetic Thalamus for bAbI QA Tasks

This notebook demonstrates how to use the Synthetic Thalamus for reasoning tasks on the bAbI dataset, which contains 20 types of text-based question-answering tasks.

In [None]:
import sys
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader
import re
import urllib.request
import tarfile
import random

# Add the parent directory to the path
sys.path.append('..')
from core.thalamus import SyntheticThalamus
from core.adapters import TextAdapter

## Download and Prepare the bAbI Dataset

In [None]:
# Download the bAbI dataset
def download_babi_data():
    """Download and unpack the bAbI archive, returning the ``en-10k`` directory.

    The archive is stored under ``data/babi`` (relative to the current working
    directory). It is downloaded only when the archive file is missing and
    unpacked only when the ``tasks_1-20_v1-2`` directory is missing.

    Returns:
        Path to the ``en-10k`` directory inside the extracted archive.

    Raises:
        tarfile.TarError: If an archive member would be written outside
            ``data/babi``.
    """
    url = 'https://s3.amazonaws.com/text-datasets/babi_tasks_1-20_v1-2.tar.gz'
    filename = 'babi_tasks_1-20_v1-2.tar.gz'
    data_dir = 'data/babi'
    archive_path = os.path.join(data_dir, filename)
    extracted_dir = os.path.join(data_dir, 'tasks_1-20_v1-2')

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(data_dir, exist_ok=True)

    if not os.path.exists(archive_path):
        print("Downloading bAbI dataset...")
        # Download to a temporary name and rename on success, so an
        # interrupted download never leaves a truncated file that later runs
        # would mistake for the complete archive.
        partial_path = archive_path + '.part'
        urllib.request.urlretrieve(url, partial_path)
        os.replace(partial_path, archive_path)

    if not os.path.exists(extracted_dir):
        print("Extracting bAbI dataset...")
        root = os.path.realpath(data_dir)
        with tarfile.open(archive_path, 'r:gz') as tar:
            # The archive comes from the network: refuse members whose path
            # would escape the target directory (e.g. "../" or absolute names).
            for member in tar.getmembers():
                target = os.path.realpath(os.path.join(root, member.name))
                if os.path.commonpath([root, target]) != root:
                    raise tarfile.TarError(
                        f"Refusing to extract {member.name!r} outside {data_dir!r}"
                    )
            if hasattr(tarfile, 'data_filter'):
                # Interpreters with extraction filters also vet links and
                # special files.
                tar.extractall(data_dir, filter='data')
            else:
                tar.extractall(data_dir)

    print("Dataset ready.")
    return os.path.join(extracted_dir, 'en-10k')

# Fetch/unpack the dataset if needed. `babi_dir` is the en-10k task directory
# returned by download_babi_data(); parse_babi_task() reads task files from it.
babi_dir = download_babi_data()

## Parse the bAbI Dataset

In [None]:
def parse_babi_task(task_id, data_dir=None):
    """Parse a bAbI training file into parallel story/question/answer lists.

    Every non-empty line has the form ``<id> <text>``. An id of 1 starts a new
    story. A line whose text contains a tab is a question line; its
    tab-separated fields are the question, the answer and, in the bAbI files,
    the supporting-fact ids. All other lines are story sentences.

    Args:
        task_id: Task number used to build the ``qa{task_id}_train.txt`` name.
        data_dir: Directory holding the task files. Defaults to the
            module-level ``babi_dir``.

    Returns:
        A ``(stories, questions, answers)`` tuple of equally long lists. Each
        story is the space-joined text of the sentences seen in the current
        story before the question.
    """
    if data_dir is None:
        data_dir = babi_dir
    task_file = os.path.join(data_dir, f"qa{task_id}_train.txt")

    stories, questions, answers = [], [], []
    story = []

    with open(task_file, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue

            nid, text = line.split(' ', 1)
            if int(nid) == 1:
                # Line numbering restarts at 1 for every new story.
                story = []

            if '\t' in text:  # This line contains a question
                # Only the first two fields are used. A plain two-way unpack
                # fails on lines with a third field (supporting-fact ids).
                q, a, *_ = text.split('\t')

                # Snapshot the story so far together with the question/answer.
                stories.append(' '.join(story))
                questions.append(q.strip())
                answers.append(a.strip())
            else:
                story.append(text)

    return stories, questions, answers

# Load task 1 (single supporting fact) and show the first parsed example
stories, questions, answers = parse_babi_task(1)

print(f"Number of examples: {len(stories)}")
print("\nExample:")
for label, items in (("Story", stories), ("Question", questions), ("Answer", answers)):
    print(f"{label}: {items[0]}")

## Prepare Tokenization and Vocabulary

In [None]:
def create_vocabulary(stories, questions, answers):
    """Build a word-to-index mapping covering stories, questions and answers.

    Stories and questions are split on whitespace; each answer is added as a
    single entry. All entries are lower-cased. Indices 0 and 1 are reserved
    for ``<PAD>`` and ``<UNK>``.

    Args:
        stories: Iterable of story strings.
        questions: Iterable of question strings.
        answers: Iterable of answer strings.

    Returns:
        Dict mapping each token to a unique integer index.
    """
    vocab = set()
    for text in (*stories, *questions):
        vocab.update(word.lower() for word in text.split())
    vocab.update(answer.lower() for answer in answers)

    # Assign indices in sorted order: set iteration order for strings changes
    # between interpreter runs, which would make the mapping (and therefore
    # any saved model's embedding rows) irreproducible.
    word_to_idx = {'<PAD>': 0, '<UNK>': 1}
    for word in sorted(vocab):
        word_to_idx[word] = len(word_to_idx)

    return word_to_idx

def tokenize(text, word_to_idx, max_len=None):
    """Map whitespace-separated words of ``text`` to vocabulary indices.

    Words are lower-cased before lookup; unknown words map to ``<UNK>``.
    When ``max_len`` is given the result is right-padded with ``<PAD>`` or
    truncated so that it has exactly ``max_len`` entries.
    """
    ids = []
    for piece in text.split():
        ids.append(word_to_idx.get(piece.lower(), word_to_idx['<UNK>']))

    if max_len is None:
        return ids

    missing = max_len - len(ids)
    if missing > 0:
        # Too short: fill the tail with padding.
        return ids + [word_to_idx['<PAD>']] * missing
    # Long enough: keep only the leading max_len tokens.
    return ids[:max_len]

# Create vocabulary
word_to_idx = create_vocabulary(stories, questions, answers)
vocab_size = len(word_to_idx)
print(f"Vocabulary size: {vocab_size}")

# Create answer-to-index mapping.
# Lower-case *before* de-duplicating: otherwise two answers differing only in
# case overwrite the same key with a later index, leaving an index that is
# >= num_answers. Sorting keeps class indices stable across runs.
answer_to_idx = {
    answer: idx
    for idx, answer in enumerate(sorted({answer.lower() for answer in answers}))
}
num_answers = len(answer_to_idx)
print(f"Number of unique answers: {num_answers}")

## Create bAbI Dataset and DataLoader

In [None]:
class BabiDataset(Dataset):
    """Map-style dataset yielding tokenized bAbI (story, question, answer) triples.

    Stories and questions are tokenized with ``tokenize`` to fixed lengths
    (``max_story_len`` / ``max_question_len``); the answer is looked up in
    ``answer_to_idx`` by its lower-cased text.
    """

    def __init__(self, stories, questions, answers, word_to_idx, answer_to_idx, max_story_len=100, max_question_len=20):
        self.stories, self.questions, self.answers = stories, questions, answers
        self.word_to_idx, self.answer_to_idx = word_to_idx, answer_to_idx
        self.max_story_len, self.max_question_len = max_story_len, max_question_len

    def __len__(self):
        # One example per parsed story entry.
        return len(self.stories)

    def __getitem__(self, idx):
        vocab = self.word_to_idx
        fields = {
            'story': tokenize(self.stories[idx], vocab, self.max_story_len),
            'question': tokenize(self.questions[idx], vocab, self.max_question_len),
            'answer': self.answer_to_idx[self.answers[idx].lower()],
        }
        # Every field is returned as an int64 tensor.
        return {name: torch.tensor(value, dtype=torch.long) for name, value in fields.items()}

# Create the full task-1 dataset (default lengths: 100 story tokens,
# 20 question tokens) and a shuffled loader over all of it.
# NOTE(review): training below builds its own loaders from a train/val split
# of `dataset`; this `dataloader` does not appear to be used again.
dataset = BabiDataset(stories, questions, answers, word_to_idx, answer_to_idx)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

## Define the Thalamus-based QA Model

In [None]:
class ThalamusQAModel(pl.LightningModule):
    """bAbI answer classifier that gates story tokens with a SyntheticThalamus.

    Pipeline (see ``forward``): shared embedding -> separate bidirectional
    LSTM encoders for story and question -> thalamus applied to the encoded
    story with the mean question encoding as context -> transformer
    "workspace" -> mean pooling -> answer classifier.

    Args:
        vocab_size: Number of rows in the embedding table.
        num_answers: Number of answer classes.
        d_model: Embedding size and per-direction LSTM hidden size.
        task_id: Task id used when ``forward`` is called without one.
    """

    def __init__(self, vocab_size, num_answers, d_model=128, task_id=0):
        super().__init__()
        self.save_hyperparameters()

        phase_dim = 16
        encoded_dim = d_model * 2                  # bidirectional LSTM output width
        workspace_dim = encoded_dim + phase_dim    # thalamus output width

        def make_encoder():
            return nn.LSTM(
                input_size=d_model,
                hidden_size=d_model,
                batch_first=True,
                bidirectional=True,
            )

        # One embedding table shared by story and question tokens.
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.story_encoder = make_encoder()
        self.question_encoder = make_encoder()

        # Token gating; num_tasks covers the 20 bAbI tasks.
        self.thalamus = SyntheticThalamus(
            d_model=encoded_dim,
            n_heads=4,
            k=16,
            phase_dim=phase_dim,
            task_dim=64,
            num_tasks=20,
        )

        # Transformer "workspace" that reasons over the gated tokens.
        workspace_layer = nn.TransformerEncoderLayer(
            d_model=workspace_dim,
            nhead=4,
            batch_first=True,
        )
        self.workspace = nn.TransformerEncoder(workspace_layer, num_layers=2)

        # Classifier head over the pooled workspace output.
        self.answer_predictor = nn.Sequential(
            nn.Linear(workspace_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_answers),
        )

        self.task_id = task_id
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, story, question, task_id=None):
        """Return answer logits ``[B, num_answers]`` for token-id batches."""
        if task_id is None:
            # Fall back to the model's default task for every batch element.
            batch_size = story.size(0)
            task_id = torch.tensor([self.task_id] * batch_size, dtype=torch.long, device=story.device)

        # [B, story_len, 2*d_model] and [B, question_len, 2*d_model]
        story_encoded, _ = self.story_encoder(self.embedding(story))
        question_encoded, _ = self.question_encoder(self.embedding(question))

        # Collapse the question into a single context vector: [B, 1, 2*d_model]
        question_context = question_encoded.mean(dim=1, keepdim=True)

        # [B, k, 2*d_model+phase_dim]
        gated_tokens = self.thalamus(story_encoded, task_id, question_context)
        workspace_output = self.workspace(gated_tokens)

        # Mean-pool over the gated tokens, then classify.
        pooled = workspace_output.mean(dim=1)
        return self.answer_predictor(pooled)

    def _step(self, batch, loss_name, acc_name):
        """Compute loss and accuracy for one batch and log both."""
        target = batch['answer']
        logits = self(batch['story'], batch['question'])
        loss = self.loss_fn(logits, target)

        accuracy = torch.argmax(logits, dim=1).eq(target).float().mean()

        self.log(loss_name, loss)
        self.log(acc_name, accuracy)
        return loss

    def training_step(self, batch, batch_idx):
        return self._step(batch, 'train_loss', 'train_acc')

    def validation_step(self, batch, batch_idx):
        return self._step(batch, 'val_loss', 'val_acc')

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

## Train/Validation Split and Trainer Setup

In [None]:
# Create train/validation split (80% / 20%)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)

# Initialize model
model = ThalamusQAModel(vocab_size=vocab_size, num_answers=num_answers, task_id=0)

# Initialize trainer.
# Use a single device on either accelerator: recent Lightning releases reject
# `devices=None` (the CPU accelerator requires a positive int), while
# `devices=1` is accepted for both 'gpu' and 'cpu'.
trainer = pl.Trainer(
    max_epochs=10,
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    devices=1,
    callbacks=[
        # Keep the checkpoint with the best validation accuracy...
        pl.callbacks.ModelCheckpoint(monitor='val_acc', mode='max'),
        # ...and stop once validation loss has not improved for 3 checks.
        pl.callbacks.EarlyStopping(monitor='val_loss', patience=3)
    ]
)

## Train the Model

In [None]:
# Train the single-task model; val_loader feeds the val_loss / val_acc
# metrics that the checkpoint and early-stopping callbacks monitor.
trainer.fit(model, train_loader, val_loader)

## Analyze the Thalamus Behavior

In [None]:
# Function to visualize which story tokens were gated by the thalamus
def visualize_token_gating(model, story, question, answer, word_to_idx, idx_to_word=None):
    """Plot per-token salience scores for one story and list the top-k tokens.

    Salience is recomputed outside the thalamus from its sub-modules
    (``task_embed``, ``task_proj``, ``scorer``): the projected task embedding
    is added to the encoded story, attended against the question context, and
    the L2 norm of the result is used as the score.
    NOTE(review): this mirrors how the notebook assumes SyntheticThalamus
    ranks tokens; the thalamus implementation is not visible here, so confirm
    it matches.

    Args:
        model: Trained ``ThalamusQAModel``.
        story: Story text (whitespace-separated words).
        question: Question text.
        answer: Ground-truth answer, used only in the plot title.
        word_to_idx: Vocabulary mapping used for tokenization.
        idx_to_word: Optional inverse vocabulary; derived from ``word_to_idx``
            when omitted.
    """
    if idx_to_word is None:
        idx_to_word = {idx: word for word, idx in word_to_idx.items()}

    pad_idx = word_to_idx['<PAD>']

    # Tokenize inputs
    story_tokens = tokenize(story, word_to_idx, max_len=100)
    question_tokens = tokenize(question, word_to_idx, max_len=20)

    # Build inputs on the model's own device so this also works when the
    # model lives on an accelerator.
    device = next(model.parameters()).device
    story_tensor = torch.tensor(story_tokens, dtype=torch.long, device=device).unsqueeze(0)
    question_tensor = torch.tensor(question_tokens, dtype=torch.long, device=device).unsqueeze(0)
    # Score with the same task id that model(...) uses for the prediction
    # below, so the plotted salience and the predicted answer are consistent.
    task_id = torch.tensor([getattr(model, 'task_id', 0)], dtype=torch.long, device=device)

    model.eval()
    with torch.no_grad():
        # Embed and encode story and question
        story_encoded, _ = model.story_encoder(model.embedding(story_tensor))
        question_encoded, _ = model.question_encoder(model.embedding(question_tensor))

        # Use the mean question encoding as context for the thalamus
        question_context = question_encoded.mean(dim=1, keepdim=True)

        # Extract salience scores from the thalamus (without gating)
        B, N, D = story_encoded.size()
        task_embedding = model.thalamus.task_embed(task_id)
        task_embedding = model.thalamus.task_proj(task_embedding)
        task_embedding = task_embedding.unsqueeze(1).expand(-1, N, -1)
        x_combined = story_encoded + task_embedding
        x_attn, _ = model.thalamus.scorer(x_combined, question_context, question_context)
        scores = x_attn.norm(dim=-1)  # [B, N]

        # For visualization, get logits and prediction
        logits = model(story_tensor, question_tensor)
        pred = torch.argmax(logits, dim=1).item()

        # Get topk indices
        _, topk_indices = scores[0].topk(model.thalamus.k)
        topk_indices = topk_indices.cpu().numpy()

    # Single-example scores as a NumPy array for plotting/printing
    score_values = scores[0].cpu().numpy()

    # Map tokens back to words; tokenize() pads at the end, so positions of
    # the remaining words line up with the score positions.
    story_words = [idx_to_word[idx] for idx in story_tokens if idx != pad_idx]
    num_words = len(story_words)
    # Top-k positions that fall on real (non-padding) tokens
    visible_topk = [int(idx) for idx in topk_indices if idx < num_words]

    # Visualize
    plt.figure(figsize=(15, 5))
    plt.bar(range(num_words), score_values[:num_words])

    # Highlight top-k tokens
    for idx in visible_topk:
        plt.bar(idx, score_values[idx], color='red')

    plt.xticks(range(num_words), story_words, rotation=45, ha='right')
    plt.title(f"Question: {question} | Predicted Answer: {pred} | True Answer: {answer}")
    plt.xlabel('Story Tokens')
    plt.ylabel('Salience Score')
    plt.tight_layout()
    plt.show()

    # Print the gated tokens
    print("Gated tokens:")
    for idx in visible_topk:
        print(f"- {story_words[idx]} (score: {score_values[idx]:.4f})")

# Inverse vocabulary (index -> word), reused by the later analysis cells
idx_to_word = {idx: word for word, idx in word_to_idx.items()}

# Pick a random validation example
sample_idx = random.randint(0, len(val_dataset) - 1)
sample = val_dataset[sample_idx]
story_tokens = sample['story'].tolist()
question_tokens = sample['question'].tolist()
answer_idx = sample['answer'].item()

# Turn token ids back into text, dropping padding
pad_idx = word_to_idx['<PAD>']
story_words = ' '.join(idx_to_word[tok] for tok in story_tokens if tok != pad_idx)
question_words = ' '.join(idx_to_word[tok] for tok in question_tokens if tok != pad_idx)
# First answer string whose class index matches the sample's label
answer_word = next(word for word, cls in answer_to_idx.items() if cls == answer_idx)

for heading, text in (("Sample Story:", story_words),
                      ("\nQuestion:", question_words),
                      ("\nAnswer:", answer_word)):
    print(heading)
    print(text)

# Visualize the token gating
visualize_token_gating(model, story_words, question_words, answer_word, word_to_idx, idx_to_word)

## Experiment with Different Task IDs

The synthetic thalamus can adapt its gating behavior based on the task ID. Let's investigate how different task IDs affect the model's performance and gating pattern.

In [None]:
def evaluate_with_different_tasks(model, dataset, num_tasks=5):
    """Evaluate the model under different task IDs and compare gating patterns.

    For each task id in ``range(num_tasks)`` the whole dataset is scored with
    that id forced for every example and the accuracy is plotted. Then one
    random example is re-scored under (at most) the first four task ids and
    its per-token salience is plotted in a 2x2 grid.

    Relies on the notebook-level ``idx_to_word``, ``word_to_idx`` and
    ``answer_to_idx`` mappings.
    NOTE(review): salience is recomputed from the thalamus sub-modules
    (``task_embed``, ``task_proj``, ``scorer``); confirm this matches the
    thalamus' own scoring.

    Args:
        model: Trained ``ThalamusQAModel``.
        dataset: Dataset whose items have 'story', 'question' and 'answer'.
        num_tasks: Number of task ids (starting at 0) to evaluate.
    """
    dataloader = DataLoader(dataset, batch_size=32)
    device = next(model.parameters()).device

    task_accuracies = []

    model.eval()
    for task_id in range(num_tasks):
        correct = 0
        total = 0

        with torch.no_grad():
            for batch in dataloader:
                story = batch['story'].to(device)
                question = batch['question'].to(device)
                answer = batch['answer'].to(device)

                # Set the task ID explicitly
                task_ids = torch.full((story.size(0),), task_id, dtype=torch.long, device=device)

                logits = model(story, question, task_ids)
                preds = torch.argmax(logits, dim=1)

                total += answer.size(0)
                correct += (preds == answer).sum().item()

        # Guard against an empty dataset instead of dividing by zero.
        accuracy = correct / total if total else 0.0
        task_accuracies.append(accuracy)
        print(f"Task ID {task_id}: Accuracy = {accuracy:.4f}")

    # Plot accuracies
    plt.figure(figsize=(10, 5))
    plt.bar(range(num_tasks), task_accuracies)
    plt.xlabel('Task ID')
    plt.ylabel('Accuracy')
    plt.title('Model Performance with Different Task IDs')
    plt.xticks(range(num_tasks))
    plt.ylim(0, 1.0)
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    plt.show()

    # Visualize the gating pattern for the same example with different task IDs
    sample_idx = random.randint(0, len(dataset) - 1)
    sample = dataset[sample_idx]

    story_tokens = sample['story'].tolist()
    question_tokens = sample['question'].tolist()
    answer_idx = sample['answer'].item()

    # Convert back to words for display. Keep the word *list* as well: the
    # plots must be sized by the number of tokens, not by the number of
    # characters in the joined string.
    pad_idx = word_to_idx['<PAD>']
    story_word_list = [idx_to_word[idx] for idx in story_tokens if idx != pad_idx]
    num_story_tokens = len(story_word_list)
    story_words = ' '.join(story_word_list)
    question_words = ' '.join([idx_to_word[idx] for idx in question_tokens if idx != pad_idx])
    idx_to_answer = {idx: word for word, idx in answer_to_idx.items()}
    answer_word = idx_to_answer[answer_idx]

    print("\nSample Story:")
    print(story_words)
    print("\nQuestion:")
    print(question_words)
    print("\nAnswer:")
    print(answer_word)

    # Compare gating patterns for each task ID (2x2 grid => at most 4 tasks)
    plt.figure(figsize=(15, 10))

    story_tensor = torch.tensor(story_tokens, dtype=torch.long, device=device).unsqueeze(0)
    question_tensor = torch.tensor(question_tokens, dtype=torch.long, device=device).unsqueeze(0)

    for i, task_id in enumerate(range(min(4, num_tasks))):
        task_tensor = torch.tensor([task_id], dtype=torch.long, device=device)

        with torch.no_grad():
            # Process through the model
            story_encoded, _ = model.story_encoder(model.embedding(story_tensor))
            question_encoded, _ = model.question_encoder(model.embedding(question_tensor))

            question_context = question_encoded.mean(dim=1, keepdim=True)

            # Extract salience scores
            B, N, D = story_encoded.size()
            task_embedding = model.thalamus.task_embed(task_tensor)
            task_embedding = model.thalamus.task_proj(task_embedding)
            task_embedding = task_embedding.unsqueeze(1).expand(-1, N, -1)
            x_combined = story_encoded + task_embedding
            x_attn, _ = model.thalamus.scorer(x_combined, question_context, question_context)
            # tokenize() pads at the end, so the leading positions are the
            # real tokens.
            scores = x_attn.norm(dim=-1)[0, :num_story_tokens].cpu().numpy()

            # Make prediction
            logits = model(story_tensor, question_tensor, task_tensor)
            pred = torch.argmax(logits, dim=1).item()
            pred_word = idx_to_answer[pred]

        # Plot
        plt.subplot(2, 2, i+1)
        plt.bar(range(num_story_tokens), scores)
        plt.title(f"Task ID {task_id} | Predicted: {pred_word}")
        plt.xticks(range(num_story_tokens), story_word_list,
                   rotation=45, ha='right', fontsize=8)
        plt.ylabel('Salience Score')

    plt.tight_layout()
    plt.show()

# Evaluate the single-task model on the validation split while forcing
# task IDs 0-4, then compare gating patterns on one random example.
evaluate_with_different_tasks(model, val_dataset, num_tasks=5)

## Phase Analysis

In [None]:
def visualize_phase_embeddings(model, sample):
    """Visualize the phase embeddings produced by the thalamus.

    Runs one dataset item through the encoders and the thalamus, takes the
    last ``phase_dim`` features of every gated token as its phase embedding,
    and plots them as a heatmap plus line plots for the first 8 tokens.

    Relies on the notebook-level ``idx_to_word`` and ``word_to_idx``.

    Args:
        model: Trained ``ThalamusQAModel``.
        sample: Dataset item with 'story', 'question' and 'answer' tensors
            (no batch dimension).
    """
    story_tokens = sample['story']
    question_tokens = sample['question']
    answer_idx = sample['answer'].item()  # read but not used below
    
    # Add batch dimension
    story_tensor = story_tokens.unsqueeze(0)
    question_tensor = question_tokens.unsqueeze(0)
    # Task id is fixed to 0 here, independent of model.task_id.
    task_id = torch.tensor([0], dtype=torch.long)
    
    model.eval()
    with torch.no_grad():
        # Get the embeddings
        story_emb = model.embedding(story_tensor)
        question_emb = model.embedding(question_tensor)
        
        # Encode story and question
        story_encoded, _ = model.story_encoder(story_emb)
        question_encoded, _ = model.question_encoder(question_emb)
        
        # Use question encoding as context for the thalamus
        question_context = question_encoded.mean(dim=1, keepdim=True)
        
        # Apply thalamus
        gated_tokens = model.thalamus(story_encoded, task_id, question_context)
        
        # Extract the phase embeddings (last phase_dim dimensions)
        # NOTE(review): assumes the thalamus appends the phase features after
        # the token features — confirm against SyntheticThalamus.
        phase_dim = model.thalamus.phase_dim
        phase_embeddings = gated_tokens[0, :, -phase_dim:].cpu().numpy()
        
        # Get token indices for the top-k gated tokens
        # This recomputes salience from the thalamus sub-modules rather than
        # reading the indices the thalamus actually selected.
        # NOTE(review): row i of phase_embeddings is labelled with the i-th
        # recomputed top-k index; that only holds if the thalamus scores
        # tokens the same way and returns them in the same order — confirm.
        B, N, D = story_encoded.size()
        task_embedding = model.thalamus.task_embed(task_id)
        task_embedding = model.thalamus.task_proj(task_embedding)
        task_embedding = task_embedding.unsqueeze(1).expand(-1, N, -1)
        x_combined = story_encoded + task_embedding
        x_attn, _ = model.thalamus.scorer(x_combined, question_context, question_context)
        scores = x_attn.norm(dim=-1)  # [B, N]
        _, topk_indices = scores[0].topk(model.thalamus.k)
        topk_indices = topk_indices.cpu().numpy()
    
    # Convert token indices to words for display
    # Padding tokens are dropped from story_words, so any top-k index at or
    # beyond its length is labelled '<PAD>'.
    story_words = [idx_to_word[idx.item()] for idx in story_tokens if idx != word_to_idx['<PAD>']]
    gated_words = [story_words[idx] if idx < len(story_words) else '<PAD>' for idx in topk_indices]
    
    # Visualize phase embeddings
    plt.figure(figsize=(12, 8))
    
    # Plot as heatmap
    # Colour scale is clamped to [-1, 1].
    plt.subplot(2, 1, 1)
    im = plt.imshow(phase_embeddings, cmap='coolwarm', vmin=-1, vmax=1, aspect='auto')
    plt.colorbar(im, label='Phase Value')
    plt.yticks(range(len(gated_words)), gated_words)
    plt.xlabel('Phase Dimension')
    plt.title('Phase Embeddings by Token')
    
    # Plot phase values as sine waves
    # One line per token: phase value against phase-dimension index.
    plt.subplot(2, 1, 2)
    for i, word in enumerate(gated_words[:8]):  # Show only first 8 for clarity
        plt.plot(phase_embeddings[i], label=f"{word}")
    plt.axhline(y=0, color='k', linestyle='--', alpha=0.3)
    plt.xlabel('Phase Dimension')
    plt.ylabel('Phase Value')
    plt.title('Phase Embeddings as Sine Waves')
    plt.legend()
    
    plt.tight_layout()
    plt.show()

# Get a random validation example for visualization (rebinds the
# notebook-level `sample_idx` / `sample`)
sample_idx = random.randint(0, len(val_dataset) - 1)
sample = val_dataset[sample_idx]

# Visualize the phase embeddings the thalamus produces for that example
visualize_phase_embeddings(model, sample)

## Multi-Task Learning

The synthetic thalamus architecture is particularly well-suited for multi-task learning. Let's create a simple experiment with multiple bAbI tasks.

In [None]:
class MultitaskBabiDataset(Dataset):
    """bAbI dataset whose items also carry the (0-indexed) task id.

    Defined at module level rather than inside ``prepare_multitask_data`` so
    the class can be referenced by name outside that function.
    """

    def __init__(self, stories, questions, answers, task_ids, word_to_idx, answer_to_idx):
        self.stories = stories
        self.questions = questions
        self.answers = answers
        self.task_ids = task_ids
        self.word_to_idx = word_to_idx
        self.answer_to_idx = answer_to_idx

    def __len__(self):
        return len(self.stories)

    def __getitem__(self, idx):
        # Fixed lengths: 100 story tokens, 20 question tokens.
        story = tokenize(self.stories[idx], self.word_to_idx, max_len=100)
        question = tokenize(self.questions[idx], self.word_to_idx, max_len=20)
        answer = self.answer_to_idx[self.answers[idx].lower()]
        task_id = self.task_ids[idx]

        return {
            'story': torch.tensor(story, dtype=torch.long),
            'question': torch.tensor(question, dtype=torch.long),
            'answer': torch.tensor(answer, dtype=torch.long),
            'task_id': torch.tensor(task_id, dtype=torch.long)
        }


def prepare_multitask_data(task_ids=(1, 2, 3)):
    """Prepare a dataset with multiple bAbI tasks.

    Parses every requested task, concatenates the examples, and builds one
    shared vocabulary and one shared answer index over all of them.

    Args:
        task_ids: 1-based bAbI task numbers to include. Each example is tagged
            with ``task_id - 1``.

    Returns:
        ``(dataset, word_to_idx, answer_to_idx)``.
    """
    all_stories, all_questions, all_answers = [], [], []
    task_markers = []

    for task_id in task_ids:
        stories, questions, answers = parse_babi_task(task_id)
        all_stories.extend(stories)
        all_questions.extend(questions)
        all_answers.extend(answers)
        task_markers.extend([task_id - 1] * len(stories))  # 0-indexed task IDs

    # Create vocabulary and answer mapping
    word_to_idx = create_vocabulary(all_stories, all_questions, all_answers)
    vocab_size = len(word_to_idx)

    # Lower-case before de-duplicating so answers that differ only in case
    # share one class index without leaving gaps; sort so the class order is
    # stable across runs.
    unique_answers = sorted({answer.lower() for answer in all_answers})
    answer_to_idx = {answer: idx for idx, answer in enumerate(unique_answers)}
    num_answers = len(answer_to_idx)

    print(f"Total examples: {len(all_stories)}")
    print(f"Vocabulary size: {vocab_size}")
    print(f"Number of unique answers: {num_answers}")

    # Create the dataset
    dataset = MultitaskBabiDataset(all_stories, all_questions, all_answers, task_markers,
                                   word_to_idx, answer_to_idx)

    return dataset, word_to_idx, answer_to_idx

# Build the combined dataset for bAbI tasks 1-3
multitask_dataset, mt_word_to_idx, mt_answer_to_idx = prepare_multitask_data(task_ids=[1, 2, 3])

# 80/20 train/validation split
train_size = int(0.8 * len(multitask_dataset))
val_size = len(multitask_dataset) - train_size
mt_train_dataset, mt_val_dataset = torch.utils.data.random_split(
    multitask_dataset,
    [train_size, val_size],
)

mt_train_loader = DataLoader(mt_train_dataset, batch_size=32, shuffle=True)
mt_val_loader = DataLoader(mt_val_dataset, batch_size=32)

# Fresh model sized for the multitask vocabulary and answer set
multitask_model = ThalamusQAModel(
    vocab_size=len(mt_word_to_idx),
    num_answers=len(mt_answer_to_idx),
)

# Modify training_step to use the provided task_id
def training_step(self, batch, batch_idx):
    """Training step that forwards the batch's own task ids to the model.

    Logs 'train_loss' and 'train_acc' and returns the loss.
    """
    target = batch['answer']
    logits = self(batch['story'], batch['question'], batch['task_id'])
    loss = self.loss_fn(logits, target)

    # Fraction of examples whose arg-max class equals the target
    accuracy = torch.argmax(logits, dim=1).eq(target).float().mean()

    self.log('train_loss', loss)
    self.log('train_acc', accuracy)

    return loss

def validation_step(self, batch, batch_idx):
    """Validation step that forwards the batch's own task ids to the model.

    Logs 'val_loss' and 'val_acc' and returns the loss.
    """
    target = batch['answer']
    logits = self(batch['story'], batch['question'], batch['task_id'])
    loss = self.loss_fn(logits, target)

    # Fraction of examples whose arg-max class equals the target
    accuracy = torch.argmax(logits, dim=1).eq(target).float().mean()

    self.log('val_loss', loss)
    self.log('val_acc', accuracy)

    return loss

# Replace the training and validation step methods on this instance only:
# `__get__` binds the plain functions above to `multitask_model`, so the
# steps pass each batch's task ids to the model.
multitask_model.training_step = training_step.__get__(multitask_model)
multitask_model.validation_step = validation_step.__get__(multitask_model)

# Initialize trainer.
# Use a single device on either accelerator: recent Lightning releases reject
# `devices=None` (the CPU accelerator requires a positive int), while
# `devices=1` is accepted for both 'gpu' and 'cpu'.
mt_trainer = pl.Trainer(
    max_epochs=10,
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    devices=1,
    callbacks=[
        pl.callbacks.ModelCheckpoint(monitor='val_acc', mode='max'),
        pl.callbacks.EarlyStopping(monitor='val_loss', patience=3)
    ]
)

## Train the Multitask Model

In [None]:
# Train the multitask model; the rebound training/validation steps feed each
# batch's task ids to the model.
mt_trainer.fit(multitask_model, mt_train_loader, mt_val_loader)

## Evaluate Task-Specific Performance

In [None]:
def evaluate_per_task(model, dataset):
    """Evaluate the model's performance on each task separately.

    Counts correct predictions per task id found in the data (any set of task
    ids is supported), prints one accuracy line per task, and draws a bar
    chart. Task ids are 0-indexed in the data and shown 1-indexed.

    Args:
        model: Model called as ``model(story, question, task_id)``.
        dataset: Dataset whose items have 'story', 'question', 'answer' and
            'task_id' tensors.
    """
    task_correct = {}
    task_total = {}

    dataloader = DataLoader(dataset, batch_size=32)
    device = next(model.parameters()).device

    model.eval()
    with torch.no_grad():
        for batch in dataloader:
            story = batch['story'].to(device)
            question = batch['question'].to(device)
            answer = batch['answer'].to(device)
            task_id = batch['task_id'].to(device)

            logits = model(story, question, task_id)
            hits = torch.argmax(logits, dim=1).eq(answer)

            # Update counters per task; .get() lets unseen task ids start at
            # zero instead of raising KeyError.
            for t_id, hit in zip(task_id.tolist(), hits.tolist()):
                task_total[t_id] = task_total.get(t_id, 0) + 1
                task_correct[t_id] = task_correct.get(t_id, 0) + int(hit)

    # Calculate and print accuracies (only tasks that actually occurred)
    print("Task-specific performance:")
    accuracies = {}
    for task_id in sorted(task_total):
        acc = task_correct[task_id] / task_total[task_id]
        accuracies[task_id] = acc
        print(f"Task {task_id+1}: Accuracy = {acc:.4f} ({task_correct[task_id]}/{task_total[task_id]})")

    # Plot results
    plt.figure(figsize=(10, 5))
    task_ids = list(accuracies)
    accs = [accuracies[t] for t in task_ids]
    display_ids = [t + 1 for t in task_ids]

    plt.bar(display_ids, accs)
    plt.xlabel('Task ID')
    plt.ylabel('Accuracy')
    plt.title('Model Performance by Task')
    plt.xticks(display_ids)
    plt.ylim(0, 1.0)
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    plt.show()

# Evaluate the multitask model on the held-out multitask validation split,
# reporting accuracy separately for each task
evaluate_per_task(multitask_model, mt_val_dataset)

## Conclusion

In this notebook, we've demonstrated how the Synthetic Thalamus can be used for question-answering tasks, particularly using the bAbI dataset. The key insights include:

1. **Task-Conditioned Processing**: The thalamus can adapt its gating behavior based on the task ID, allowing the model to share parameters while maintaining task-specific processing.

2. **Selective Attention**: The thalamus focuses computational resources on the most relevant parts of the input story, similar to how human attention works.

3. **Phase Embeddings**: The phase tags add extra dimensions to each token's representation, which could enable more complex context-dependent processing.

4. **Multi-Task Learning**: The thalamus architecture naturally supports multi-task learning, allowing the model to learn shared representations while maintaining task-specific behaviors.

These properties make the Synthetic Thalamus a promising approach for building more efficient and adaptable neural network architectures for language understanding and reasoning tasks.