In [1]:
# Import necessary packages and modules
import torch
from titan.memory_modules import MAGModule, MACModule, MALModule

# Set device (CPU/GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fix the RNG so parameter initialisation and the dummy input (and therefore the
# printed memory losses) are reproducible under Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)

# Define model hyperparameters
d_model = 64  # Hidden size
batch_size = 2
seq_len = 32

# Instantiate the model variant, e.g., MAGModule
model = MAGModule(d_model=d_model).to(device)

# Print total number of trainable parameters
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total trainable parameters: {total_params}")

# Create dummy input data directly on the target device (no host->device copy)
x = torch.randn(batch_size, seq_len, d_model, device=device)

# Run a forward pass; MAGModule returns (output, associative memory loss)
output, mem_loss = model(x)
assert output.shape == (batch_size, seq_len, d_model), output.shape

# Print output shape and associative memory loss
print("Output shape:", output.shape)
print("Associative memory loss:", mem_loss.item())

# --- Optionally, test other variants ---

# Testing MACModule:
# Split the dummy sequence in half: the first half plays the current segment and
# the second half plays (dummy) historical memory. The split point is derived from
# seq_len so it stays a true half if seq_len is changed above.
half = seq_len // 2
current_segment = x[:, :half, :]
historical_memory = x[:, half:, :]
mac_model = MACModule(d_model=d_model).to(device)
mac_output = mac_model(current_segment, historical_memory)
# MAC output length is not asserted: it includes extra (persistent) positions
# whose count depends on MACModule's defaults.
print("MAC output shape:", mac_output.shape)

# Testing MALModule:
mal_model = MALModule(d_model=d_model).to(device)
mal_output, mal_mem_loss = mal_model(x)
assert mal_output.shape == x.shape, mal_output.shape
print("MAL output shape:", mal_output.shape)
print("MAL associative memory loss:", mal_mem_loss.item())


Total trainable parameters: 63232
Output shape: torch.Size([2, 32, 64])
Associative memory loss: 0.05827818438410759
MAC output shape: torch.Size([2, 42, 64])
MAL output shape: torch.Size([2, 32, 64])
MAL associative memory loss: 0.04597906768321991


In [None]:
!python trainMAG.py

Epoch 1:
  Train Loss: 0.4591  Train Accuracy: 0.7748
  Val Loss:   0.4756  Val Accuracy:   0.7718
Epoch 2:
  Train Loss: 0.2717  Train Accuracy: 0.8901
  Val Loss:   0.4942  Val Accuracy:   0.7993
Epoch 3:
  Train Loss: 0.2056  Train Accuracy: 0.9201
  Val Loss:   0.6838  Val Accuracy:   0.7901


To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development

Generating train split:   0%|          | 0/67349 [00:00<?, ? examples/s]
Generating train split: 100%|██████████| 67349/67349 [00:00<00:00, 1413513.44 examples/s]

Generating validation split:   0%|          | 0/872 [00:00<?, ? examples/s]
Generating validation split: 100%|██████████| 872/872 [00:00<00:00, 176909.79 examples/s]

Generating test split:   0%|          | 0/1821 [00:00<?, ? examples/s]
Generating test split: 100%|██████████| 1821/1821 [00:00<00:00, 672403.17 examples/s]


In [5]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
from transformers import AutoTokenizer
import numpy as np
from titan.memory_modules import MAGModule  # or change to MACModule or MALModule as needed

# Define the Titan-based classifier model (same as used during training)
class TitanClassifier(nn.Module):
    """Sequence classifier: embedding -> MAG memory module -> mean pool -> linear.

    Args:
        vocab_size: number of rows in the token embedding table.
        d_model: embedding width, also passed to ``MAGModule``.
        num_classes: number of output logits per example.
        max_length: accepted to keep the training-time signature; unused here.

    ``forward(input_ids)`` returns ``(logits, mem_loss)``: ``logits`` is
    (batch, num_classes) and ``mem_loss`` is the second value produced by the
    MAG module, passed through unchanged.
    """

    def __init__(self, vocab_size, d_model, num_classes, max_length):
        super().__init__()
        # Attribute names become state_dict keys -- they must match the checkpoint.
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.titan = MAGModule(d_model=d_model)  # MAG variant; change as needed
        self.classifier = nn.Linear(d_model, num_classes)

    def forward(self, input_ids):
        hidden, mem_loss = self.titan(self.embedding(input_ids))
        # Collapse the sequence axis (dim 1) into one vector per example.
        logits = self.classifier(hidden.mean(dim=1))
        return logits, mem_loss

# Define a Dataset class for SST2 (validation split)
class SST2Dataset(Dataset):
    """Map-style view of one GLUE SST-2 split, tokenized lazily per item.

    Items are ``(input_ids, label)``: ``input_ids`` is a 1-D tensor of length
    ``max_length`` (truncated / padded by the tokenizer) and ``label`` is the
    row's raw label value.
    """

    def __init__(self, split, tokenizer, max_length):
        # load_dataset yields every split; keep only the one requested.
        self.samples = load_dataset("glue", "sst2")[split]
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        row = self.samples[idx]  # fetch the row once, then read both fields
        text, label = row["sentence"], row["label"]
        encoded = self.tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )
        # return_tensors="pt" adds a leading batch axis of size 1; drop it.
        return encoded["input_ids"].squeeze(0), label

# Hyperparameters (should match those used during training)
d_model, num_classes = 64, 2
max_length, batch_size = 64, 16

# The tokenizer must be the one used in training so token ids index the trained
# embedding table; vocab_size sizes that table.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
vocab_size = tokenizer.vocab_size

# Rebuild the architecture on the target device, then restore the trained weights.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TitanClassifier(vocab_size, d_model, num_classes, max_length).to(device)
model.load_state_dict(torch.load("titan_classifier.pt", map_location=device))
model.eval()

# Validation split, iterated in its stored order.
val_dataset = SST2Dataset("validation", tokenizer, max_length)
val_loader = DataLoader(val_dataset, batch_size=batch_size)

# One pass over the validation split without building autograd graphs.
total_correct = total_samples = 0
all_logits, all_labels = [], []
with torch.no_grad():
    for input_ids, labels in val_loader:
        input_ids, labels = input_ids.to(device), labels.to(device)
        logits, _ = model(input_ids)  # memory loss is not needed for evaluation
        preds = logits.argmax(dim=-1)
        total_correct += int((preds == labels).sum())
        total_samples += len(input_ids)
        all_logits.append(logits.cpu().numpy())
        all_labels.append(labels.cpu().numpy())

accuracy = total_correct / total_samples
print(f"Validation Accuracy: {accuracy:.4f}")

# Flatten the per-batch arrays (one row per validation example) for further
# analysis, e.g. F1 score or a confusion matrix.
all_logits = np.concatenate(all_logits, axis=0)
all_labels = np.concatenate(all_labels, axis=0)


  from .autonotebook import tqdm as notebook_tqdm


Validation Accuracy: 0.7901


In [6]:
!python trainMAC.py

Epoch 1:
  Train Loss: 0.4561  Train Accuracy: 0.7825
  Val Loss:   0.4548  Val Accuracy:   0.7924
Epoch 2:
  Train Loss: 0.2780  Train Accuracy: 0.8876
  Val Loss:   0.4573  Val Accuracy:   0.7867
Epoch 3:
  Train Loss: 0.2144  Train Accuracy: 0.9166
  Val Loss:   0.4962  Val Accuracy:   0.7878


In [7]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
from transformers import AutoTokenizer
import numpy as np

# Import the MAC variant classifier.
# If you defined TitanClassifierMAC in your train_mac.py, ensure it's accessible in your PYTHONPATH.
from titan.memory_modules import MACModule
import torch.nn.functional as F

# Define TitanClassifierMAC (same as in train_mac.py)
class TitanClassifierMAC(nn.Module):
    """Sequence classifier built on the MAC variant of the titan memory modules.

    The embedded sequence is cut at ``seq_len // 2``: the first part is passed
    to ``MACModule`` as the current segment, the rest as historical memory.
    The MAC output is mean-pooled over its token axis and fed to a linear head.

    Args:
        vocab_size: number of rows in the token embedding table.
        d_model: embedding width, also passed to ``MACModule``.
        num_classes: number of output logits per example.
        max_length: accepted to keep the training-time signature; unused here.
        persistent_len: forwarded to ``MACModule``.

    ``forward(input_ids)`` returns logits of shape (batch, num_classes).
    """

    def __init__(self, vocab_size, d_model, num_classes, max_length, persistent_len=10):
        super().__init__()
        # Attribute names become state_dict keys -- they must match the checkpoint.
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.mac = MACModule(d_model, persistent_len=persistent_len)
        self.classifier = nn.Linear(d_model, num_classes)

    def forward(self, input_ids):
        embedded = self.embedding(input_ids)
        split = embedded.size(1) // 2
        # First half -> current segment, second half -> historical memory
        # (an odd length puts the extra token in the historical half).
        mac_out = self.mac(embedded[:, :split, :], embedded[:, split:, :])
        # MACModule is documented as concatenating persistent memory, historical
        # memory and the current segment, so the mean also covers those extra
        # positions.
        return self.classifier(mac_out.mean(dim=1))

# Custom dataset for SST2
class SST2Dataset(Dataset):
    """Map-style view of one GLUE SST-2 split, tokenized lazily per item.

    Items are ``(input_ids, label)``: ``input_ids`` is a 1-D tensor of length
    ``max_length`` (truncated / padded by the tokenizer) and ``label`` is the
    row's raw label value.
    """

    def __init__(self, split, tokenizer, max_length):
        # load_dataset yields every split; keep only the one requested.
        self.samples = load_dataset("glue", "sst2")[split]
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        row = self.samples[idx]  # fetch the row once, then read both fields
        encoded = self.tokenizer(
            row["sentence"],
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )
        # return_tensors="pt" adds a leading batch axis of size 1; drop it.
        return encoded["input_ids"].squeeze(0), row["label"]

# Hyperparameters (must match training configuration)
d_model, num_classes = 64, 2
max_length, batch_size = 64, 16

# Same tokenizer as training so token ids index the trained embedding table.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
vocab_size = tokenizer.vocab_size
val_dataset = SST2Dataset("validation", tokenizer, max_length)
val_loader = DataLoader(val_dataset, batch_size=batch_size)

# Rebuild the architecture (default persistent_len) and restore trained weights.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TitanClassifierMAC(vocab_size, d_model, num_classes, max_length)
model = model.to(device)
model.load_state_dict(torch.load("titan_classifier_mac.pt", map_location=device))
model.eval()

# One no-grad pass over the validation split, keeping per-batch predictions.
total_correct = total_samples = 0
all_preds, all_labels = [], []

with torch.no_grad():
    for input_ids, labels in val_loader:
        input_ids, labels = input_ids.to(device), labels.to(device)
        logits = model(input_ids)
        preds = logits.argmax(dim=-1)
        total_correct += int((preds == labels).sum())
        total_samples += len(input_ids)
        all_preds.append(preds.cpu().numpy())
        all_labels.append(labels.cpu().numpy())

accuracy = total_correct / total_samples
print(f"Validation Accuracy: {accuracy:.4f}")

# Flatten per-batch arrays for further metrics (F1-score, confusion matrix, ...).
all_preds = np.concatenate(all_preds)
all_labels = np.concatenate(all_labels)


Validation Accuracy: 0.7878


In [8]:
!python trainMAL.py

Epoch 1:
  Train Loss: 0.4594  Train Accuracy: 0.7784
  Val Loss:   0.4682  Val Accuracy:   0.7867
Epoch 2:
  Train Loss: 0.2732  Train Accuracy: 0.8880
  Val Loss:   0.5034  Val Accuracy:   0.7867
Epoch 3:
  Train Loss: 0.2084  Train Accuracy: 0.9184
  Val Loss:   0.5879  Val Accuracy:   0.7752


In [9]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
from transformers import AutoTokenizer
import numpy as np

# Define TitanClassifierMAL (should match your training definition)
# This model uses the MAL variant from your titan.memory_modules module.
from titan.memory_modules import MALModule

class TitanClassifierMAL(nn.Module):
    """Sequence classifier using the MAL module as a standalone memory layer.

    Pipeline: embedding -> ``MALModule`` -> mean over the sequence axis -> linear.

    Args:
        vocab_size: number of rows in the token embedding table.
        d_model: embedding width, also passed to ``MALModule``.
        num_classes: number of output logits per example.
        max_length: accepted to keep the training-time signature; unused here.

    ``forward(input_ids)`` returns only the logits, shape (batch, num_classes);
    the memory loss produced by ``MALModule`` is discarded.
    """

    def __init__(self, vocab_size, d_model, num_classes, max_length):
        super().__init__()
        # Attribute names become state_dict keys -- they must match the checkpoint.
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.mal = MALModule(d_model=d_model)
        self.classifier = nn.Linear(d_model, num_classes)

    def forward(self, input_ids):
        mal_out, _mem_loss = self.mal(self.embedding(input_ids))
        return self.classifier(mal_out.mean(dim=1))

# Custom dataset for SST2 validation.
class SST2Dataset(Dataset):
    """Map-style view of one GLUE SST-2 split, tokenized lazily per item.

    Items are ``(input_ids, label)``: ``input_ids`` is a 1-D tensor of length
    ``max_length`` (truncated / padded by the tokenizer) and ``label`` is the
    row's raw label value.
    """

    def __init__(self, split, tokenizer, max_length):
        # load_dataset yields every split; keep only the one requested.
        self.samples = load_dataset("glue", "sst2")[split]
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        row = self.samples[idx]  # fetch the row once, then read both fields
        tokenizer_kwargs = dict(
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )
        encoded = self.tokenizer(row["sentence"], **tokenizer_kwargs)
        # return_tensors="pt" adds a leading batch axis of size 1; drop it.
        return encoded["input_ids"].squeeze(0), row["label"]

# Hyperparameters (must match those used during training)
d_model, num_classes = 64, 2
max_length, batch_size = 64, 16

# Same tokenizer as training so token ids index the trained embedding table.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
vocab_size = tokenizer.vocab_size

val_dataset = SST2Dataset("validation", tokenizer, max_length)
val_loader = DataLoader(val_dataset, batch_size=batch_size)

# Rebuild the architecture on the target device and restore trained weights.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TitanClassifierMAL(vocab_size, d_model, num_classes, max_length).to(device)
model.load_state_dict(torch.load("titan_classifier_mal.pt", map_location=device))
model.eval()

# One no-grad pass over the validation split, keeping per-batch predictions.
total_correct = total_samples = 0
all_preds, all_labels = [], []

with torch.no_grad():
    for input_ids, labels in val_loader:
        input_ids, labels = input_ids.to(device), labels.to(device)
        logits = model(input_ids)
        preds = logits.argmax(dim=-1)
        total_correct += int((preds == labels).sum())
        total_samples += len(input_ids)
        all_preds.append(preds.cpu().numpy())
        all_labels.append(labels.cpu().numpy())

accuracy = total_correct / total_samples
print(f"Validation Accuracy: {accuracy:.4f}")

# Flatten per-batch arrays so further metrics can be computed on the whole split.
all_preds = np.concatenate(all_preds)
all_labels = np.concatenate(all_labels)


Validation Accuracy: 0.7752
