In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
from transformers import BertTokenizer
from sklearn.metrics import accuracy_score
from lmu import LMUCell

KeyboardInterrupt: 

In [None]:
dataset = load_dataset("dair-ai/emotion", "split")
label_list = dataset['train'].features['label'].names  # emotion labels


In [None]:
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
MAX_LEN = 128

def tokenize_batch(batch):
    return tokenizer(batch["text"], padding="max_length", truncation=True, max_length=MAX_LEN)

tokenized_train = dataset["train"].map(tokenize_batch, batched=True)
tokenized_val = dataset["validation"].map(tokenize_batch, batched=True)
tokenized_test = dataset["test"].map(tokenize_batch, batched=True)


In [None]:
class EmotionDataset(Dataset):
    def __init__(self, hf_dataset):
        self.input_ids = hf_dataset["input_ids"]
        self.attention_mask = hf_dataset["attention_mask"]
        self.labels = hf_dataset["label"]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return {
            "input_ids": torch.tensor(self.input_ids[idx]),
            "attention_mask": torch.tensor(self.attention_mask[idx]),
            "labels": torch.tensor(self.labels[idx]),
        }

train_ds = EmotionDataset(tokenized_train)
val_ds = EmotionDataset(tokenized_val)

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=32)


In [None]:
# class LMUWrapper(nn.Module):
#     def __init__(self, input_size, hidden_size, memory_size):
#         super().__init__()
#         self.lmu_cell = LMUCell(input_size, hidden_size, memory_size)
#         self.hidden_size = hidden_size

#     def forward(self, x):
#         batch_size, seq_len, input_size = x.shape
#         h = torch.zeros(batch_size, self.hidden_size, device=x.device)
#         c = torch.zeros(batch_size, self.hidden_size, device=x.device)
#         outputs = []
#         for t in range(seq_len):
#             h, c = self.lmu_cell(x[:, t, :], (h, c))
#             outputs.append(h.unsqueeze(1))
#         return torch.cat(outputs, dim=1)  # [batch, seq_len, hidden_size]


class LMUWrapper(nn.Module):
    def __init__(self, input_size, hidden_size, memory_size):
        super().__init__()
        hidden_cell = nn.Linear(input_size + memory_size, hidden_size)
        self.lmu_cell = LMUCell(input_size, hidden_size, memory_size, hidden_cell)
        self.hidden_size = hidden_size

    def forward(self, x):
        batch_size, seq_len, input_size = x.shape
        h = torch.zeros(batch_size, self.hidden_size, device=x.device)
        c = torch.zeros(batch_size, self.hidden_size, device=x.device)
        outputs = []
        for t in range(seq_len):
            h, c = self.lmu_cell(x[:, t, :], (h, c))
            outputs.append(h.unsqueeze(1))
        return torch.cat(outputs, dim=1)  # [batch, seq_len, hidden_size]


In [None]:
class EmotionLMUClassifier(nn.Module):
    def __init__(self, num_labels, hidden_size=128, memory_size=64):
        super().__init__()
        self.bert = nn.Embedding(30522, 768)  # BERT vocab size
        self.bert.weight.data = tokenizer.get_vocab().values()
        self.embedding_dim = 768
        self.lmu = LMUWrapper(self.embedding_dim, hidden_size, memory_size)
        self.fc = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        x = self.bert(input_ids)  # [batch, seq_len, 768]
        lmu_out = self.lmu(x)
        final_state = lmu_out[:, -1, :]  # use last hidden
        return self.fc(final_state)

In [None]:
import torch
import torch.nn as nn
import math


class SimpleLMU(nn.Module):
    def __init__(self, input_size, hidden_size, memory_size, theta=100):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.memory_size = memory_size

        # Learnable input -> hidden
        self.input_to_hidden = nn.Linear(input_size, hidden_size)
        self.input_to_memory = nn.Linear(input_size, memory_size)

        # Hidden -> output
        self.fc = nn.Linear(hidden_size, hidden_size)

        # Legendre memory matrices (fixed)
        Q = torch.arange(memory_size, dtype=torch.float32)
        R = 2 * Q + 1
        A = torch.outer(R, Q) - torch.outer(Q, R)
        B = R.unsqueeze(1)

        A = A * (1 / theta)
        B = B * (1 / theta)

        self.register_buffer('A', A)
        self.register_buffer('B', B)
        self.U = None  # Internal state placeholder

    def forward(self, x):
        batch_size, seq_len, _ = x.shape
        device = x.device

        H = torch.zeros(batch_size, self.hidden_size, device=device)
        M = torch.zeros(batch_size, self.memory_size, device=device)

        outputs = []
        for t in range(seq_len):
            u_t = x[:, t, :]
            m_t = M @ self.A.T + self.input_to_memory(u_t) @ self.B  # FIX here, no B transpose
            h_t = torch.tanh(self.input_to_hidden(u_t) + self.fc(H))
            H = h_t
            M = m_t
            outputs.append(h_t.unsqueeze(1))


        return torch.cat(outputs, dim=1)  # [batch, seq_len, hidden]


In [None]:
from transformers import BertModel

class EmotionLMUClassifier(nn.Module):
    def __init__(self, num_labels, hidden_size=128, memory_size=64):
        super().__init__()
        self.bert = BertModel.from_pretrained("bert-base-uncased")
        self.embedding_dim = self.bert.config.hidden_size
        self.lmu = SimpleLMU(self.embedding_dim, hidden_size, memory_size)
        self.fc = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        with torch.no_grad():
            outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        embeddings = outputs.last_hidden_state  # [batch, seq_len, 768]
        lmu_out = self.lmu(embeddings)
        final_state = lmu_out[:, -1, :]  # use last time step
        return self.fc(final_state)


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print (device)
model = EmotionLMUClassifier(num_labels=len(label_list)).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
loss_fn = nn.CrossEntropyLoss()

# Training loop
EPOCHS = 100
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0
    for batch in train_loader:
        input_ids = batch["input_ids"].to(device)
        attn_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        optimizer.zero_grad()
        outputs = model(input_ids, attn_mask)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"Epoch {epoch+1} | Train Loss: {total_loss/len(train_loader):.4f}")
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch["input_ids"].to(device)
            attn_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(input_ids, attn_mask)
            preds = torch.argmax(outputs, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    acc = accuracy_score(all_labels, all_preds)
    print(f"Validation Accuracy: {acc:.4f}")


cuda
Epoch 1 | Train Loss: 1.3683
Validation Accuracy: 0.5360


KeyboardInterrupt: 