In [1]:
# Install required packages
!pip install transformers datasets conceptnet-lite torch tqdm


Collecting conceptnet-lite
  Downloading conceptnet_lite-0.2.0-py3-none-any.whl.metadata (14 kB)
Collecting lmdb<2.0,>=1.0 (from conceptnet-lite)
  Downloading lmdb-1.6.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.1 kB)
Collecting peewee<4.0,>=3.10 (from conceptnet-lite)
  Downloading peewee-3.18.1.tar.gz (3.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.0/3.0 MB[0m [31m51.8 MB/s[0m eta [36m0:00:00[0m
  Installing build dependencies ... [?done
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
[?25hCollecting pysmartdl<2.0,>=1.3 (from conceptnet-lite)
  Downloading pySmartDL-1.3.4-py3-none-any.whl.metadata (2.8 kB)
Downloading conceptnet_lite-0.2.0-py3-none-any.whl (16 kB)
Downloading lmdb-1.6.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (300 kB)
Downloading pySmartDL-1.3.4-py3-none-any.whl (20 kB)
Building wheels for collected packages: peewee


In [2]:
import torch
from torch import nn
from transformers import AutoTokenizer, AutoModel
from datasets import load_dataset
from tqdm import tqdm
import conceptnet_lite
import re

# Connect to ConceptNet Lite (downloads DB if not present)
conceptnet_lite.connect()


File not found: /home/jovyan/conceptnet.db
Download compressed database
[*] 1.85 GB / 1.85 GB @ 90.2 MB/s [##################] [100%, 0s left]    
Extract compressed database (this can take a few minutes)


In [3]:
# Load CommonsenseQA subset
full_train = load_dataset("tau/commonsense_qa", split="train")
raw_train = full_train.select(range(1000))
raw_valid = full_train.select(range(1000, 1300))

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

def preprocess(example):
    inputs = []
    for choice in example["choices"]["text"]:
        text = example["question"] + " " + choice
        inputs.append(text)
    return {"inputs": inputs, "answer_idx": ord(example["answerKey"]) - ord("A")}

train_data = [preprocess(ex) for ex in raw_train]
valid_data = [preprocess(ex) for ex in raw_valid]


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

In [9]:
def extract_conceptnet_facts(text):
    words = re.findall(r'\b\w+\b', text.lower())
    facts = []
    for word in words:
        # Use edges_for(word) instead of edges(start=word)
        edges = list(conceptnet_lite.edges_for(word))
        for e in edges[:2]:  # Limit for speed
            facts.append(f"{e.start} {e.relation} {e.end}")
    return " ".join(facts[:3]) if facts else ""


In [5]:
class NeuroSymbolicQA(nn.Module):
    def __init__(self, model_name="bert-base-uncased", symbolic_dim=128):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        self.symbolic_proj = nn.Linear(128, symbolic_dim)
        self.classifier = nn.Linear(self.encoder.config.hidden_size + symbolic_dim, 1)
        self.symbolic_dim = symbolic_dim

    def forward(self, input_ids, attention_mask, symbolic_vec):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        pooled = outputs.last_hidden_state[:, 0]  # CLS token
        symb = self.symbolic_proj(symbolic_vec)
        combined = torch.cat([pooled, symb], dim=1)
        logits = self.classifier(combined)
        return logits


In [6]:
symbolic_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
symbolic_encoder = AutoModel.from_pretrained("bert-base-uncased")

def get_symbolic_embedding(fact_text):
    if not fact_text:
        return torch.zeros(1, 128)
    tokens = symbolic_tokenizer(fact_text, truncation=True, padding="max_length", max_length=16, return_tensors="pt")
    with torch.no_grad():
        out = symbolic_encoder(**tokens)
        return out.last_hidden_state[:, 0, :128]  # Take CLS, first 128 dims


model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

In [7]:
from torch.utils.data import DataLoader

def collate_fn(batch):
    input_ids, attention_masks, symbolic_vecs, labels = [], [], [], []
    for ex in batch:
        choice_ids, choice_masks, choice_syms = [], [], []
        for inp in ex["inputs"]:
            tokens = tokenizer(inp, truncation=True, padding="max_length", max_length=32, return_tensors="pt")
            choice_ids.append(tokens["input_ids"])
            choice_masks.append(tokens["attention_mask"])
            symb_text = extract_conceptnet_facts(inp)
            symb_vec = get_symbolic_embedding(symb_text)
            choice_syms.append(symb_vec)
        input_ids.append(torch.cat(choice_ids))
        attention_masks.append(torch.cat(choice_masks))
        symbolic_vecs.append(torch.cat(choice_syms))
        labels.append(torch.tensor(ex["answer_idx"]))
    return (
        torch.stack(input_ids),
        torch.stack(attention_masks),
        torch.stack(symbolic_vecs).squeeze(2),
        torch.stack(labels)
    )

train_loader = DataLoader(train_data, batch_size=8, shuffle=True, collate_fn=collate_fn)
valid_loader = DataLoader(valid_data, batch_size=8, shuffle=False, collate_fn=collate_fn)


In [12]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = NeuroSymbolicQA().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
criterion = nn.CrossEntropyLoss()

num_epochs = 20
best_val_acc = 0.0
best_model_path = "best_neurosymbolic_model.pt"

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for input_ids, attn_masks, symb_vecs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        input_ids = input_ids.to(device)
        attn_masks = attn_masks.to(device)
        symb_vecs = symb_vecs.to(device)
        labels = labels.to(device)
        logits = []
        for i in range(input_ids.size(1)):  # For each choice
            logit = model(input_ids[:, i], attn_masks[:, i], symb_vecs[:, i])
            logits.append(logit)
        logits = torch.cat(logits, dim=1)
        loss = criterion(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1} - Loss: {total_loss/len(train_loader):.4f}")

    # Validation after each epoch
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for input_ids, attn_masks, symb_vecs, labels in valid_loader:
            input_ids = input_ids.to(device)
            attn_masks = attn_masks.to(device)
            symb_vecs = symb_vecs.to(device)
            labels = labels.to(device)
            logits = []
            for i in range(input_ids.size(1)):
                logit = model(input_ids[:, i], attn_masks[:, i], symb_vecs[:, i])
                logits.append(logit)
            logits = torch.cat(logits, dim=1)
            preds = logits.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    val_acc = correct / total
    print(f"Validation Accuracy: {val_acc:.2%}")

    # Save the best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), best_model_path)
        print("Best model saved.")

print(f"\nBest Validation Accuracy: {best_val_acc:.2%}")


Epoch 1/20: 100%|██████████| 125/125 [03:52<00:00,  1.86s/it]


Epoch 1 - Loss: 1.5842
Validation Accuracy: 34.33%
Best model saved.


Epoch 2/20: 100%|██████████| 125/125 [03:58<00:00,  1.91s/it]


Epoch 2 - Loss: 1.3940
Validation Accuracy: 39.00%
Best model saved.


Epoch 3/20: 100%|██████████| 125/125 [05:49<00:00,  2.80s/it]


Epoch 3 - Loss: 0.8935
Validation Accuracy: 42.67%
Best model saved.


Epoch 4/20: 100%|██████████| 125/125 [09:22<00:00,  4.50s/it]


Epoch 4 - Loss: 0.4275
Validation Accuracy: 38.67%


Epoch 5/20: 100%|██████████| 125/125 [03:56<00:00,  1.89s/it]


Epoch 5 - Loss: 0.2291
Validation Accuracy: 41.00%


Epoch 6/20: 100%|██████████| 125/125 [03:43<00:00,  1.79s/it]


Epoch 6 - Loss: 0.1585
Validation Accuracy: 38.67%


Epoch 7/20: 100%|██████████| 125/125 [03:35<00:00,  1.72s/it]


Epoch 7 - Loss: 0.1231
Validation Accuracy: 40.67%


Epoch 8/20: 100%|██████████| 125/125 [04:04<00:00,  1.96s/it]


Epoch 8 - Loss: 0.1131
Validation Accuracy: 39.00%


Epoch 9/20: 100%|██████████| 125/125 [03:36<00:00,  1.73s/it]


Epoch 9 - Loss: 0.1006
Validation Accuracy: 38.33%


Epoch 10/20: 100%|██████████| 125/125 [03:29<00:00,  1.68s/it]


Epoch 10 - Loss: 0.0807
Validation Accuracy: 38.67%


Epoch 11/20: 100%|██████████| 125/125 [03:26<00:00,  1.65s/it]


Epoch 11 - Loss: 0.0852
Validation Accuracy: 40.67%


Epoch 12/20: 100%|██████████| 125/125 [04:01<00:00,  1.93s/it]


Epoch 12 - Loss: 0.0688
Validation Accuracy: 38.67%


Epoch 13/20: 100%|██████████| 125/125 [03:48<00:00,  1.83s/it]


Epoch 13 - Loss: 0.0735
Validation Accuracy: 36.33%


Epoch 14/20: 100%|██████████| 125/125 [03:52<00:00,  1.86s/it]


Epoch 14 - Loss: 0.0633
Validation Accuracy: 36.33%


Epoch 15/20: 100%|██████████| 125/125 [03:34<00:00,  1.72s/it]


Epoch 15 - Loss: 0.0634
Validation Accuracy: 36.33%


Epoch 16/20: 100%|██████████| 125/125 [03:36<00:00,  1.73s/it]


Epoch 16 - Loss: 0.0878
Validation Accuracy: 38.00%


Epoch 17/20: 100%|██████████| 125/125 [03:46<00:00,  1.81s/it]


Epoch 17 - Loss: 0.0669
Validation Accuracy: 38.00%


Epoch 18/20: 100%|██████████| 125/125 [10:10<00:00,  4.89s/it]


Epoch 18 - Loss: 0.0648
Validation Accuracy: 38.00%


Epoch 19/20: 100%|██████████| 125/125 [05:25<00:00,  2.60s/it]


Epoch 19 - Loss: 0.0592
Validation Accuracy: 38.33%


Epoch 20/20: 100%|██████████| 125/125 [03:55<00:00,  1.89s/it]


Epoch 20 - Loss: 0.0594
Validation Accuracy: 37.00%

Best Validation Accuracy: 42.67%
