In [1]:
import random
import torch
from torch import nn

In [2]:
def create_context_examples(word_pairs, k=3):
    """
    Formats data for in-context learning:
    Input: "sing1 -> plur1; sing2 -> plur2; ...; singN -> "
    Target: "plurN"
    """
    context = []
    for _ in range(k):
        s, p = random.choice(word_pairs)
        context.append(f"{s} -> {p}")
    
    test_singular, test_plural = random.choice(word_pairs)
    context.append(f"{test_singular} -> ")
    
    return " ; ".join(context), test_plural


In [3]:
def make_batches(dataset, batch_size):
    """Groups dataset items into batches of specified size"""
    batch = []
    for item in dataset:
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:  # Yield remaining items
        yield batch


In [4]:
class InContextTransformer(nn.Module):
    def __init__(self, vocab_size, d_model=64, nhead=2, num_layers=2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Parameter(torch.randn(1, 512, d_model))
        self.layers = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, nhead, 128),
            num_layers
        )
        self.head = nn.Linear(d_model, vocab_size)
        
    def forward(self, x):
        x = self.embed(x) + self.pos_embed[:, :x.size(1)]
        x = self.layers(x)
        return self.head(x)


In [17]:
def train_cycle(model, dataset, epochs=10, batch_size=32):
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss(ignore_index=0)

    for epoch in range(epochs):
        random.shuffle(dataset)
        total_loss = 0
        for batch in make_batches(dataset, batch_size):
            inputs = torch.stack([ex[0] for ex in batch])
            targets = torch.stack([ex[1] for ex in batch])

            outputs = model(inputs)

            loss = loss_fn(outputs.view(-1, outputs.size(-1)), targets.view(-1))
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            total_loss += loss.item()
        
        print(f"Epoch {epoch + 1}, Loss: {total_loss:.4f}")


In [18]:
word_pairs = [
    # Regular plurals
    ("cat", "cats"),
    ("dog", "dogs"),
    ("car", "cars"),
    ("apple", "apples"),
    ("book", "books"),
    ("tree", "trees"),
    ("cup", "cups"),
    ("pen", "pens"),
    ("chair", "chairs"),
    ("table", "tables"),
    ("house", "houses"),
    ("phone", "phones"),
    ("shoe", "shoes"),
    ("bag", "bags"),
    ("door", "doors"),
    ("window", "windows"),
    ("computer", "computers"),
    ("student", "students"),
    ("teacher", "teachers"),
    ("doctor", "doctors"),
    ("friend", "friends"),
    ("child", "children"),
    ("person", "people"),
    ("man", "men"),
    ("woman", "women"),
    ("mouse", "mice"),
    ("goose", "geese"),
    ("tooth", "teeth"),
    ("foot", "feet"),
    ("fish", "fish"),
    ("sheep", "sheep"),
    ("deer", "deer"),
    ("cactus", "cacti"),
    ("focus", "foci"),
    ("fungus", "fungi"),
    ("nucleus", "nuclei"),
    ("syllabus", "syllabi"),
    ("analysis", "analyses"),
    ("crisis", "crises"),
    ("thesis", "theses"),
    ("phenomenon", "phenomena"),
    ("criterion", "criteria"),
    ("datum", "data"),
    ("bus", "buses"),
    ("box", "boxes"),
    ("fox", "foxes"),
    ("watch", "watches"),
    ("wish", "wishes"),
    ("dish", "dishes"),
    ("baby", "babies"),
    ("city", "cities"),
    ("party", "parties"),
    ("story", "stories"),
    ("berry", "berries"),
    ("family", "families"),
    ("country", "countries"),
    ("lady", "ladies"),
    ("boy", "boys"),
    ("toy", "toys"),
    ("key", "keys"),
    ("day", "days"),
    ("monkey", "monkeys"),
    ("leaf", "leaves"),
    ("wolf", "wolves"),
    ("knife", "knives"),
    ("life", "lives"),
    ("wife", "wives"),
    ("calf", "calves"),
    ("half", "halves"),
    ("loaf", "loaves"),
    ("scarf", "scarves"),
    ("chief", "chiefs"),
    ("roof", "roofs"),
    ("belief", "beliefs"),
    ("chef", "chefs"),
    ("photo", "photos"),
    ("piano", "pianos"),
    ("halo", "halos"),
    ("potato", "potatoes"),
    ("tomato", "tomatoes"),
    ("hero", "heroes"),
    ("echo", "echoes"),
    ("zero", "zeroes"),
    ("kangaroo", "kangaroos"),
    ("radio", "radios"),
    ("studio", "studios"),
    ("video", "videos"),
    ("zoo", "zoos"),
    ("bamboo", "bamboos"),
    ("cargo", "cargoes"),
    ("volcano", "volcanoes"),
    ("tornado", "tornadoes"),
    ("mosquito", "mosquitoes"),
    ("buffalo", "buffaloes"),
    ("domino", "dominoes"),
    ("torpedo", "torpedoes"),
    ("veto", "vetoes"),
    ("alumnus", "alumni"),
    ("alumna", "alumnae"),
    ("medium", "media"),
    ("memorandum", "memoranda"),
    ("appendix", "appendices"),
    ("index", "indices"),
    ("matrix", "matrices"),
    ("vertex", "vertices"),
    ("axis", "axes"),
    ("ox", "oxen"),
    ("quiz", "quizzes"),
    ("church", "churches"),
    ("match", "matches"),
    ("branch", "branches"),
    ("peach", "peaches"),
    ("lunch", "lunches"),
    ("sandwich", "sandwiches"),
    ("witch", "witches"),
    ("pass", "passes"),
    ("glass", "glasses"),
    ("class", "classes"),
    ("kiss", "kisses"),
    ("bus", "buses"),
    ("gas", "gases"),
    ("status", "statuses"),
    ("octopus", "octopuses"),
    ("virus", "viruses"),
    ("radius", "radii"),
    ("genius", "geniuses"),
    ("species", "species"),
    ("series", "series"),
    ("aircraft", "aircraft"),
    ("means", "means"),
    ("barracks", "barracks"),
    ("salmon", "salmon"),
    ("shrimp", "shrimp"),
    ("trout", "trout"),
    ("swine", "swine"),
    ("hovercraft", "hovercraft"),
    ("crossroads", "crossroads"),
    ("headquarters", "headquarters"),
]


In [19]:
from collections import defaultdict

def build_vocab(word_pairs):
    special_tokens = ["<pad>", "<unk>"]
    tokens = set()
    for s, p in word_pairs:
        tokens.update(s)
        tokens.update(p)
        tokens.add("->")
        tokens.add(";")

    token_list = special_tokens + sorted(tokens)
    stoi = defaultdict(lambda: 1, {tok: i for i, tok in enumerate(token_list)})  # <unk> = 1
    itos = {i: tok for tok, i in stoi.items()}
    return stoi, itos

def tokenize(text, stoi):
    return [stoi[c] for c in text]

def detokenize(indices, itos):
    return ''.join(itos[i] for i in indices if i > 1)


In [20]:
def encode_dataset(word_pairs, stoi, k=3, max_len=128):
    dataset = []
    for _ in range(1000):  # number of examples
        context, target = create_context_examples(word_pairs, k=k)
        input_ids = tokenize(context, stoi)
        target_ids = tokenize(target, stoi)
        
        input_tensor = torch.tensor(input_ids + [0] * (max_len - len(input_ids)), dtype=torch.long)[:max_len]
        target_tensor = torch.tensor(target_ids + [0] * (max_len - len(target_ids)), dtype=torch.long)[:max_len]
        
        dataset.append((input_tensor, target_tensor))
    return dataset


In [21]:
def evaluate(model, word_pairs, stoi, itos, k=3, max_len=128):
    model.eval()
    correct = 0
    total = 20
    with torch.no_grad():
        for _ in range(total):
            context, target = create_context_examples(word_pairs, k=k)
            input_ids = tokenize(context, stoi)
            input_tensor = torch.tensor(input_ids + [0] * (max_len - len(input_ids)), dtype=torch.long)[:max_len].unsqueeze(0)

            output_logits = model(input_tensor)
            output_ids = output_logits.argmax(dim=-1)[0].tolist()

            predicted = detokenize(output_ids[len(input_ids):], itos)
            print(f"{context}{predicted.strip()} (Expected: {target})")
            if predicted.strip().startswith(target):
                correct += 1

    print(f"Accuracy: {correct}/{total}")
    model.train()


In [22]:
stoi, itos = build_vocab(word_pairs)
dataset = encode_dataset(word_pairs, stoi, k=3, max_len=64)
model = InContextTransformer(vocab_size=len(stoi), d_model=64, nhead=2, num_layers=2)

train_cycle(model, dataset, epochs=10, batch_size=16)
evaluate(model, word_pairs, stoi, itos, k=3, max_len=64)


Epoch 1, Loss: 173.5855
Epoch 2, Loss: 160.7125
Epoch 3, Loss: 158.8412
Epoch 4, Loss: 158.3738
Epoch 5, Loss: 157.8524
Epoch 6, Loss: 157.2612
Epoch 7, Loss: 156.6422
Epoch 8, Loss: 156.4576
Epoch 9, Loss: 155.6631
Epoch 10, Loss: 156.1328
crossroads -> crossroads ; city -> cities ; bamboo -> bamboos ; series ->  (Expected: series)
glass -> glasses ; cat -> cats ; shrimp -> shrimp ; crisis -> sa (Expected: crises)
chief -> chiefs ; genius -> geniuses ; halo -> halos ; index ->  (Expected: indices)
syllabus -> syllabi ; car -> cars ; medium -> media ; toy -> ssa (Expected: toys)
class -> classes ; hovercraft -> hovercraft ; bus -> buses ; window ->  (Expected: windows)
appendix -> appendices ; criterion -> criteria ; syllabus -> syllabi ; datum ->  (Expected: data)
photo -> photos ; torpedo -> torpedoes ; headquarters -> headquarters ; aircraft ->  (Expected: aircraft)
studio -> studios ; scarf -> scarves ; shoe -> shoes ; friend ->  (Expected: friends)
foot -> feet ; virus -> viruses 