In [1]:
!pip install transformers datasets conceptnet-lite torch tqdm



In [2]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModel, get_cosine_schedule_with_warmup
from datasets import load_dataset
from tqdm import tqdm
import conceptnet_lite
import re

conceptnet_lite.connect()

In [3]:
# Load CommonsenseQA subset
raw_train = load_dataset("tau/commonsense_qa", split="train[:-1000]")
raw_valid = load_dataset("tau/commonsense_qa", split="train[-1000:]")
raw_test = load_dataset("tau/commonsense_qa", split="validation")

def preprocess(example):
    inputs = []
    for choice in example["choices"]["text"]:
        text = example["question"] + " " + choice
        inputs.append(text)
    return {"inputs": inputs, "answer_idx": ord(example["answerKey"]) - ord("A")}

train_data = [preprocess(ex) for ex in raw_train]
valid_data = [preprocess(ex) for ex in raw_valid]
test_data  = [preprocess(ex) for ex in raw_test]


In [4]:
def extract_conceptnet_facts(text):
    words = re.findall(r'\b\w{4,}\b', text.lower())  # filter short words
    facts = []
    for word in words:
        try:
            concept = conceptnet_lite.Concept.get(label=word)
            edges = concept.edges(direction='both')[:3]
            for e in edges:
                facts.append(f"{e.start} {e.relation} {e.end}")
        except Exception:
            continue
    return ". ".join(facts[:5]) if facts else ""


In [5]:
symbolic_tokenizer = AutoTokenizer.from_pretrained("roberta-large")
symbolic_encoder = AutoModel.from_pretrained("roberta-large")

def get_symbolic_embedding(fact_text, device):
    if not fact_text:
        return torch.zeros(1, 256).to(device)
    tokens = symbolic_tokenizer(fact_text, truncation=True, padding="max_length", max_length=16, return_tensors="pt").to(device)
    with torch.no_grad():
        out = symbolic_encoder(**tokens)
        return out.last_hidden_state[:, 0, :256]  # CLS, first 256 dims

tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/482 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.42G [00:00<?, ?B/s]

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
def collate_fn(batch):
    input_ids, attention_masks, symbolic_vecs, labels = [], [], [], []
    for ex in batch:
        choice_ids, choice_masks, choice_syms = [], [], []
        for inp in ex["inputs"]:
            tokens = symbolic_tokenizer(inp, truncation=True, padding="max_length", max_length=32, return_tensors="pt")
            choice_ids.append(tokens["input_ids"])
            choice_masks.append(tokens["attention_mask"])
            symb_text = extract_conceptnet_facts(inp)
            symb_vec = get_symbolic_embedding(symb_text, device='cpu')
            choice_syms.append(symb_vec)
        input_ids.append(torch.cat(choice_ids))
        attention_masks.append(torch.cat(choice_masks))
        symbolic_vecs.append(torch.cat(choice_syms))
        labels.append(torch.tensor(ex["answer_idx"]))
    return (
        torch.stack(input_ids),
        torch.stack(attention_masks),
        torch.stack(symbolic_vecs).squeeze(2),
        torch.stack(labels)
    )

train_loader = DataLoader(train_data, batch_size=16, shuffle=True, collate_fn=collate_fn)
valid_loader = DataLoader(valid_data, batch_size=16, shuffle=False, collate_fn=collate_fn)
test_loader  = DataLoader(test_data,  batch_size=16, shuffle=False, collate_fn=collate_fn)


In [7]:
class NeuroSymbolicQA(nn.Module):
    def __init__(self, model_name="roberta-large", symbolic_dim=256, dropout=0.2):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        self.symbolic_proj = nn.Linear(symbolic_dim, symbolic_dim)
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=self.encoder.config.hidden_size,
            kdim=symbolic_dim,
            vdim=symbolic_dim,
            num_heads=4,
            batch_first=True
        )
        self.gate = nn.Sequential(
            nn.Linear(self.encoder.config.hidden_size + symbolic_dim, 1),
            nn.Sigmoid()
        )
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(self.encoder.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask, symbolic_vec):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        neural_out = outputs.last_hidden_state[:, 0]  # CLS token
        symbolic_out = self.symbolic_proj(symbolic_vec)
        # Cross-attention: neural queries, symbolic keys/values
        attn_out, _ = self.cross_attn(
            neural_out.unsqueeze(1), 
            symbolic_out.unsqueeze(1), 
            symbolic_out.unsqueeze(1)
        )
        attn_out = attn_out.squeeze(1)
        # Gated fusion
        combined = torch.cat([neural_out, symbolic_out], dim=1)
        gate = self.gate(combined)
        gated_out = gate * neural_out + (1 - gate) * attn_out
        gated_out = self.dropout(gated_out)
        logits = self.classifier(gated_out)
        return logits


In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = NeuroSymbolicQA().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=0.01)
num_epochs = 30
criterion = nn.CrossEntropyLoss()
best_val_acc = 0.0
best_model_path = "best_neurosymbolic_model.pt"
patience = 5
epochs_no_improve = 0

total_steps = num_epochs * len(train_loader)
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for input_ids, attn_masks, symb_vecs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        input_ids = input_ids.to(device)
        attn_masks = attn_masks.to(device)
        symb_vecs = symb_vecs.to(device)
        labels = labels.to(device)
        logits = []
        for i in range(input_ids.size(1)):  # For each choice
            logit = model(input_ids[:, i], attn_masks[:, i], symb_vecs[:, i])
            logits.append(logit)
        logits = torch.cat(logits, dim=1)
        loss = criterion(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1} - Loss: {total_loss/len(train_loader):.4f}")

    # Validation
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for input_ids, attn_masks, symb_vecs, labels in valid_loader:
            input_ids = input_ids.to(device)
            attn_masks = attn_masks.to(device)
            symb_vecs = symb_vecs.to(device)
            labels = labels.to(device)
            logits = []
            for i in range(input_ids.size(1)):
                logit = model(input_ids[:, i], attn_masks[:, i], symb_vecs[:, i])
                logits.append(logit)
            logits = torch.cat(logits, dim=1)
            preds = logits.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    val_acc = correct / total
    print(f"Validation Accuracy: {val_acc:.2%}")

    # Early stopping & checkpoint
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), best_model_path)
        print("Best model saved.")
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print("Early stopping.")
            break

print(f"\nBest Validation Accuracy: {best_val_acc:.2%}")


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Epoch 1/30: 100%|██████████| 63/63 [02:48<00:00,  2.68s/it]


Epoch 1 - Loss: 1.6094
Validation Accuracy: 31.33%
Best model saved.


Epoch 2/30: 100%|██████████| 63/63 [02:48<00:00,  2.67s/it]


Epoch 2 - Loss: 1.5879
Validation Accuracy: 38.00%
Best model saved.


Epoch 3/30: 100%|██████████| 63/63 [02:48<00:00,  2.67s/it]


Epoch 3 - Loss: 1.4424
Validation Accuracy: 53.67%
Best model saved.


Epoch 4/30: 100%|██████████| 63/63 [02:48<00:00,  2.67s/it]


Epoch 4 - Loss: 1.2115
Validation Accuracy: 55.33%
Best model saved.


Epoch 5/30: 100%|██████████| 63/63 [02:48<00:00,  2.67s/it]


Epoch 5 - Loss: 0.9348
Validation Accuracy: 59.00%
Best model saved.


Epoch 6/30: 100%|██████████| 63/63 [02:48<00:00,  2.67s/it]


Epoch 6 - Loss: 0.7055
Validation Accuracy: 57.67%


Epoch 7/30: 100%|██████████| 63/63 [02:48<00:00,  2.67s/it]


Epoch 7 - Loss: 0.4822
Validation Accuracy: 59.33%
Best model saved.


Epoch 8/30: 100%|██████████| 63/63 [02:48<00:00,  2.67s/it]


Epoch 8 - Loss: 0.3753
Validation Accuracy: 59.33%


Epoch 9/30: 100%|██████████| 63/63 [02:48<00:00,  2.67s/it]


Epoch 9 - Loss: 0.3135
Validation Accuracy: 57.33%


Epoch 10/30: 100%|██████████| 63/63 [02:48<00:00,  2.67s/it]


Epoch 10 - Loss: 0.2614
Validation Accuracy: 58.00%


Epoch 11/30: 100%|██████████| 63/63 [02:48<00:00,  2.67s/it]


Epoch 11 - Loss: 0.2126
Validation Accuracy: 58.33%


Epoch 12/30: 100%|██████████| 63/63 [02:48<00:00,  2.67s/it]


Epoch 12 - Loss: 0.1859
Validation Accuracy: 57.00%
Early stopping.

Best Validation Accuracy: 59.33%


In [None]:
# Load best model and evaluate on test set
model.load_state_dict(torch.load(best_model_path))
model.eval()
yhat = []
with torch.no_grad():
    for input_ids, attn_masks, symb_vecs, labels in test_loader:
        input_ids = input_ids.to(device)
        attn_masks = attn_masks.to(device)
        symb_vecs = symb_vecs.to(device)
        logits = []
        for i in range(input_ids.size(1)):
            logit = model(input_ids[:, i], attn_masks[:, i], symb_vecs[:, i])
            logits.append(logit)
        logits = torch.cat(logits, dim=1)
        preds = logits.argmax(dim=1)
        yhat.extend(preds.cpu().tolist())

print("Test set predictions (yhat):", yhat)