In [None]:
import random
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizerFast, BertModel
import inflect
from collections import Counter
import numpy as np
from tqdm import tqdm

In [None]:
p = inflect.engine()

entities = ["man", "woman", "person", "child", "robot", "dog"]
actions = ["pick up", "drop", "push", "pull", "kick", "hold", "throw", "walk", "run"]
objects = ["ball", "bottle", "chair", "box", "book", "cup"]
locations = [
    "in the room",
    "on the table",
    "near the door",
    "in the park",
    "beside the chair",
    "under the table",
    "next to the shelf"
]

In [None]:
NEG_TOKEN = "__NEG__"
LOC_TOKEN = "__LOC__"


POSITIVE_TEMPLATES = [
    "{e} will {a} the {o} {l}",
    "{e} is going to {a} the {o} {l}",
    "{e} should {a} the {o} {l}",
    "{e} is {a_ing} the {o} {l}",
    "please {a} the {o} {l}",
    "can you {a} the {o} {l}?"
]

NEGATIVE_TEMPLATES = [
    "{e} will not {a} the {o} {l}",
    "{e} is not {a_ing} the {o} {l}",
    "{e} does not {a} the {o} {l}",
    "{e} should not {a} the {o} {l}",
    "do not let {e} {a} the {o} {l}",
    "{e} must never {a} the {o} {l}"
]

EXCLUSION_TEMPLATES = [
    "{e} should {a} the {o} {l}, but not the other one",
    "{e} must {a} everything except the {o} {l}"
]

CONDITIONAL_TEMPLATES = [
    "{e} will {a} the {o} {l} if it is needed",
    "{e} will not {a} the {o} {l} unless required"
]

In [None]:
def inflect_action(a):
    if " " in a:
        base, rest = a.split(" ", 1)
    else:
        base = a
        rest = ""

    prog = base + "ing"
    if base.endswith("e") and not base.endswith("ee"):
        prog = base[:-1] + "ing"

    return {
        "base": a,
        "ing": prog + (" " + rest if rest else "")
    }

In [None]:
def generate_dataset(n=10000):
    rows = []
    others = NEGATIVE_TEMPLATES + EXCLUSION_TEMPLATES + CONDITIONAL_TEMPLATES

    for _ in range(n//2):
        e = random.choice(entities)
        a = random.choice(actions)
        o = random.choice(objects)
        l = random.choice(locations)
        t = random.choice(POSITIVE_TEMPLATES)

        f = inflect_action(a)

        s = t.format(e=e, a=f["base"], a_ing=f["ing"], o=o, l=f"{LOC_TOKEN} {l}")
        rows.append([s, e, a, o, l, 0])

        t = random.choice(others)
        s = t.format(e=e, a=f["base"], a_ing=f["ing"], o=o, l=f"{LOC_TOKEN} {l}")
        s = s.replace("not", f"{NEG_TOKEN} not")
        rows.append([s, e, a, o, l, 1])

    random.shuffle(rows)
    df = pd.DataFrame(rows, columns=["sentence", "entity", "action", "object", "location", "neg"])
    return df

df = generate_dataset(12000)

In [None]:
lengths = df["sentence"].str.split().apply(len)
print("Max sentence length:", lengths.max())

Max sentence length: 17


In [None]:
def build_vocab(texts, min_freq=1):
    words = []
    for s in texts:
        words += s.replace(NEG_TOKEN,"").replace(LOC_TOKEN,"").lower().split()

    freq = Counter(words)
    print(freq)

    vocab = {"<pad>":0, "<unk>":1, "<cls>":2}
    for w,c in freq.items():
          vocab[w] = len(vocab)

    return vocab

vocab = build_vocab(df["sentence"])
vocab_size = len(vocab)
print(vocab)
print(vocab_size)

def encode(sentence, max_len=40):
    words = sentence.replace(NEG_TOKEN,"").replace(LOC_TOKEN,"").lower().split()
    ids = [vocab.get(w,1) for w in words]
    ids = ids[:max_len]
    ids += [0] * (max_len - len(ids))
    return torch.tensor(ids)

Counter({'the': 24576, 'not': 4196, 'chair': 3485, 'in': 3374, 'is': 3246, 'table': 2950, 'will': 2825, 'to': 2747, 'should': 2211, 'ball': 2164, 'book': 2032, 'box': 2016, 'bottle': 1948, 'cup': 1904, 'on': 1802, 'woman': 1779, 'beside': 1768, 'next': 1758, 'child': 1702, 'near': 1692, 'person': 1685, 'man': 1640, 'dog': 1623, 'robot': 1613, 'under': 1606, 'shelf': 1522, 'door': 1486, 'room': 1472, 'park': 1471, 'up': 1242, 'push': 1212, 'must': 1211, 'pull': 1208, 'throw': 1192, 'run': 1174, 'kick': 1156, 'hold': 1146, 'drop': 1139, 'pick': 1082, 'walk': 1027, 'going': 989, 'please': 984, 'can': 974, 'you': 974, 'never': 619, 'if': 593, 'it': 593, 'needed': 593, 'everything': 592, 'except': 592, 'unless': 585, 'required': 585, 'but': 576, 'other': 576, 'one': 576, 'do': 574, 'let': 574, 'does': 568, 'table?': 286, 'holding': 206, 'runing': 200, 'pushing': 196, 'kicking': 196, 'walking': 185, 'pulling': 182, 'throwing': 172, 'table,': 172, 'droping': 167, 'picking': 160, 'shelf?': 156

In [None]:
entity2id = {e:i for i,e in enumerate(sorted(set(df.entity)))}
action2id = {a:i for i,a in enumerate(sorted(set(df.action)))}
object2id = {o:i for i,o in enumerate(sorted(set(df.object)))}
location2id = {l:i for i,l in enumerate(sorted(set(df.location)))}

df["entity_id"] = df["entity"].map(entity2id)
df["action_id"] = df["action"].map(action2id)
df["object_id"] = df["object"].map(object2id)
df["location_id"] = df["location"].map(location2id)

train_df = df.sample(frac=0.9)
test_df = df.drop(train_df.index)
df

Unnamed: 0,sentence,entity,action,object,location,neg,entity_id,action_id,object_id,location_id
0,can you pick up the book __LOC__ beside the ch...,person,pick up,book,beside the chair,0,3,3,1,0
1,robot does __NEG__ not pull the box __LOC__ ne...,robot,pull,box,near the door,1,4,4,3,3
2,woman is going to throw the cup __LOC__ beside...,woman,throw,cup,beside the chair,0,5,7,5,0
3,child is walking the chair __LOC__ in the room,child,walk,chair,in the room,0,0,8,4,2
4,man does __NEG__ not push the chair __LOC__ ne...,man,push,chair,near the door,1,2,5,4,3
...,...,...,...,...,...,...,...,...,...,...
11995,child will __NEG__ not throw the ball __LOC__ ...,child,throw,ball,near the door,1,0,7,0,3
11996,woman will __NEG__ not throw the chair __LOC__...,woman,throw,chair,beside the chair,1,5,7,4,0
11997,child does __NEG__ not kick the ball __LOC__ i...,child,kick,ball,in the park,1,0,2,0,1
11998,man is __NEG__ not picking up the book __LOC__...,man,pick up,book,in the park,1,2,3,1,1


In [None]:
class CmdDataset(Dataset):
    def __init__(self, df):
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.loc[idx]
        sentence = row["sentence"]
        ids = encode(sentence)

        neg_mask = torch.tensor([1 if NEG_TOKEN in sentence else 0] * 40).float()
        loc_mask = torch.tensor([1 if LOC_TOKEN in sentence else 0] * 40).float()

        return (
            ids,
            torch.tensor(row.entity_id),
            torch.tensor(row.action_id),
            torch.tensor(row.object_id),
            torch.tensor(row.location_id),
            torch.tensor(row.neg).float(),
            neg_mask,
            loc_mask
        )

train_ds = CmdDataset(train_df)
test_ds = CmdDataset(test_df)

train_loader = DataLoader(train_ds, batch_size=16, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=16)

In [None]:
class Model(nn.Module):
    def __init__(self, vocab_size, d_model=256, heads=4, layers=4):
        super().__init__()

        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos = nn.Embedding(40, d_model)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=heads,
            dim_feedforward=512,
            batch_first=True
        )

        self.enc_head = nn.TransformerEncoder(encoder_layer, num_layers=layers)
        self.entity_head = nn.Linear(d_model, len(entity2id))
        self.action_head = nn.Linear(d_model, len(action2id))
        self.object_head = nn.Linear(d_model, len(object2id))
        self.location_head = nn.Linear(d_model, len(location2id))
        self.neg_head = nn.Linear(d_model, 1)

    def forward(self, x):
        b, seq = x.shape
        pos = self.pos(torch.arange(seq).to(x.device)).unsqueeze(0).expand(b,seq,-1)
        e = self.embed(x) + pos
        h = self.enc_head(e)
        cls = h[:,0,:]

        att = (h @ h.transpose(1,2))

        return {
            "entity": self.entity_head(cls),
            "action": self.action_head(cls),
            "object": self.object_head(cls),
            "location": self.location_head(cls),
            "neg": self.neg_head(cls).squeeze(),
            "att": att
        }

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Model(vocab_size).to(device)
ce = nn.CrossEntropyLoss()
bce = nn.BCEWithLogitsLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

In [None]:
for epoch in range(6):
    model.train()
    total_loss = 0

    for ids, e,a, o, l, n, neg_mask, loc_mask in tqdm(train_loader):
        ids = ids.to(device)
        e,a, o, l, n = e.to(device),a.to(device), o.to(device), l.to(device), n.to(device)
        neg_mask, loc_mask = neg_mask.to(device), loc_mask.to(device)

        out = model(ids)

        loss = (
            ce(out["entity"], e)+
            ce(out["action"], a)+
            ce(out["object"], o)+
            ce(out["location"], l)+
            bce(out["neg"], n)
        )

        att = out["att"]
        cls_att = att[:,0,:]
        mask = torch.clamp(neg_mask + loc_mask, max=1)
        att_weight = (cls_att * mask).sum() / (mask.sum()+1e-6)
        loss += 0.2*(1-att_weight)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print("Epoch:", epoch, "Loss:", total_loss/len(train_loader))


100%|██████████| 675/675 [02:54<00:00,  3.88it/s]


Epoch: 0 Loss: -1011.3351414207176


100%|██████████| 675/675 [02:53<00:00,  3.90it/s]


Epoch: 1 Loss: -1219.8768981481483


100%|██████████| 675/675 [02:53<00:00,  3.89it/s]


Epoch: 2 Loss: -1449.8444283492477


100%|██████████| 675/675 [02:53<00:00,  3.88it/s]


Epoch: 3 Loss: -1717.4691921657986


100%|██████████| 675/675 [02:53<00:00,  3.89it/s]


Epoch: 4 Loss: -2009.2550674551505


100%|██████████| 675/675 [02:53<00:00,  3.89it/s]

Epoch: 5 Loss: -2308.6568880208333





In [None]:
def eval_epoch(model):
    model.eval()
    correct = {k: 0 for k in [
        "entity","object","location","action","negation"
    ]}
    total = 0
    full_correct = 0

    with torch.no_grad():
        for ids, e, a, o, l, n, _, _ in test_loader:
            ids = ids.to(device)
            out = model(ids)

            pe = out["entity"].argmax(-1).cpu()
            pa = out["action"].argmax(-1).cpu()
            po = out["object"].argmax(-1).cpu()
            pl = out["location"].argmax(-1).cpu()
            pn = (torch.sigmoid(out["neg"])>0.5).cpu()

            correct["entity"] += (pe == e).sum()
            correct["action"] += (pa == a).sum()
            correct["object"] += (po == o).sum()
            correct["location"] += (pl == l).sum()
            correct["negation"] += (pn == n).sum()

            full_correct += ((pe==e) & (pa==a) & (po==o) & (pl==l) & (pn==n)).sum()

            total += len(a)


    accuracies = {k: (correct[k] / total) for k in correct}
    accuracies["full_semantic_accuracy"] = full_correct / total

    return accuracies

In [None]:
val_acc = eval_epoch(model)
print("Validation accuracy:", val_acc)

Validation accuracy: {'entity': tensor(0.8617), 'object': tensor(1.), 'location': tensor(1.), 'action': tensor(1.), 'negation': tensor(1.), 'full_semantic_accuracy': tensor(0.8617)}


In [None]:
def predict(sentence):

    out = model(encode(sentence).unsqueeze(0).to(device))
    # model.eval()

    ans = {
        "entity": list(entity2id.keys())[out["entity"].argmax().item()],
        "action" : list(action2id.keys())[out["action"].argmax().item()],
        "obj" : list(object2id.keys())[out["object"].argmax().item()],
        "loc" : list(location2id.keys())[out["location"].argmax().item()],
        "neg" : torch.sigmoid(out["neg"]).item() > 0.5
    }

    print(ans)
    # return ans


predict("the man should not pick up the bottle near the door")
predict("woman please hold the cup in the park")
predict("do not let the robot throw the box on the table")
predict("child is running beside the chair")
predict("never push the chair next to the shelf")
predict("can the dog walk in the room?")


{'entity': 'man', 'action': 'pick up', 'obj': 'bottle', 'loc': 'near the door', 'neg': True}
{'entity': 'woman', 'action': 'hold', 'obj': 'cup', 'loc': 'in the park', 'neg': False}
{'entity': 'robot', 'action': 'throw', 'obj': 'box', 'loc': 'on the table', 'neg': True}
{'entity': 'child', 'action': 'drop', 'obj': 'chair', 'loc': 'beside the chair', 'neg': False}
{'entity': 'man', 'action': 'push', 'obj': 'chair', 'loc': 'next to the shelf', 'neg': False}
{'entity': 'man', 'action': 'walk', 'obj': 'chair', 'loc': 'in the room', 'neg': False}
