In [17]:
import torch
import torch.nn as nn

def linear(x, W, b):
    """
    x: (1, in_features)
    W: (out_features, in_features)
    b: (out_features,)
    returns: (1, out_features)
    """
    return x @ W.T + b

def rnn_step_concat(x_t, h_prev, W_i2h, b_i2h):
    """
    x_t:   (1, E)
    h_prev:(1, H)
    W_i2h: (H, E+H)   (because Linear(E+H -> H))
    b_i2h: (H,)
    returns:
      h_t: (1, H)
    """
    combined = torch.cat([x_t, h_prev], dim=1)         # (1, E+H)
    preact = linear(combined, W_i2h, b_i2h)            # (1, H)
    h_t = torch.tanh(preact)                           # (1, H)
    return h_t, combined, preact

# -----------------------------
# Example sizes like your note:
# E=5, H=4  => combined size = 9
# -----------------------------
E = 5
H = 4

# 3 time-step inputs (x1, x2, x3), each (1,5)
x1 = torch.tensor([[1., 1., 1., 1., 1.]])
x2 = torch.tensor([[2., 0., 0., 0., 0.]])
x3 = torch.tensor([[0., 2., 0., 0., 0.]])

# initial hidden (1,4)
h0 = torch.zeros(1, H)

# "Linear(9->4)" parameters from scratch:
# W_i2h shape (4,9), b_i2h shape (4,)
# We'll choose easy numbers so you can compute.
W_i2h = torch.tensor([
    [0.10,0.10,0.10,0.10,0.10, 0.10,0.10,0.10,0.10],  # neuron 1
    [0.20,0.20,0.20,0.20,0.20, 0.20,0.20,0.20,0.20],  # neuron 2
    [0.30,0.30,0.30,0.30,0.30, 0.30,0.30,0.30,0.30],  # neuron 3
    [0.40,0.40,0.40,0.40,0.40, 0.40,0.40,0.40,0.40],  # neuron 4
])
b_i2h = torch.tensor([0., 0., 0., 0.])

# Run 3 steps
h1, c1, p1 = rnn_step_concat(x1, h0, W_i2h, b_i2h)
h2, c2, p2 = rnn_step_concat(x2, h1, W_i2h, b_i2h)
h3, c3, p3 = rnn_step_concat(x3, h2, W_i2h, b_i2h)

print("combined1:", c1)
print("preact1:", p1)
print("h1:", h1)
print("------"*10)

print("combined2:", c2)
print("preact2:", p2)
print("h2:", h2)
print("------"*10)

print("combined3:", c3)
print("preact3:", p3)
print("h3:", h3)

combined1: tensor([[1., 1., 1., 1., 1., 0., 0., 0., 0.]])
preact1: tensor([[0.5000, 1.0000, 1.5000, 2.0000]])
h1: tensor([[0.4621, 0.7616, 0.9051, 0.9640]])
------------------------------------------------------------
combined2: tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.4621, 0.7616, 0.9051, 0.9640]])
preact2: tensor([[0.5093, 1.0186, 1.5279, 2.0372]])
h2: tensor([[0.4694, 0.7693, 0.9101, 0.9666]])
------------------------------------------------------------
combined3: tensor([[0.0000, 2.0000, 0.0000, 0.0000, 0.0000, 0.4694, 0.7693, 0.9101, 0.9666]])
preact3: tensor([[0.5115, 1.0231, 1.5346, 2.0461]])
h3: tensor([[0.4711, 0.7711, 0.9112, 0.9671]])


In [18]:
class VanillaRNN(nn.Module):
    def __init__(self, input_dim, hidden_size, output_dim):
        super().__init__()
        self.hidden_size = hidden_size

        self.i2h = nn.Linear(input_dim+hidden_size, hidden_size)
        self.h2o = nn.Linear(hidden_size, output_dim)

    def init_hidden(self):
        return torch.zeros(1, self.hidden_size)
    
    def forward(self, x_t, h_prev):
        combined = torch.cat((x_t, h_prev), dim=1)
        h_t = torch.tanh(self.i2h(combined))
        y_t = self.h2o(h_t)
        return h_t, y_t
        

In [19]:
import torch.nn.functional as F
import math

In [None]:
# ----------------------------
# 1) Tiny Dataset
# ----------------------------
sentences = [
    "what is the capital of india",
    "what is the capital of usa",
    "what is the capital of france",
]

def tokenize(text):
    return text.lower().split()


# ----------------------------
# 2) Build Vocabulary
# ----------------------------
tokens = []
for s in sentences:
    tokens += tokenize(s)

vocab = sorted(set(tokens))
vocab = ["<unk>"] + vocab

stoi = {w:i for i,w in enumerate(vocab)}
itos = {i:w for w,i in stoi.items()}
V = len(vocab)

def encode(text):
    return [stoi.get(w, stoi["<unk>"]) for w in tokenize(text)]


# ----------------------------
# 3) YOUR VanillaRNN
# ----------------------------
class VanillaRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.input2hidden = nn.Linear(input_size+hidden_size, hidden_size)
        self.hidden2output = nn.Linear(hidden_size, output_size)

    def init_hidden(self):
        return torch.zeros(1, self.hidden_size)
    
    def forward(self, x_t, h_prev):
        combined = torch.cat((x_t, h_prev), dim=1)
        h_t = torch.tanh(self.input2hidden(combined))
        y_t = self.hidden2output(h_t)
        return h_t, y_t
    
# ----------------------------
# 4) Embedding Layer (separate)
# ----------------------------
class EmbeddingLayer(nn.Module):
    def __init__(self, vocab_size, emb_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim)

    def forward(self, token_id):
        return self.embedding(token_id)


In [21]:
# ----------------------------
# 5) Create Model
# ----------------------------
torch.manual_seed(123)
emb_dim = 8
hidden_size = 16
embedding = EmbeddingLayer(V, emb_dim)
rnn = VanillaRNN(emb_dim, hidden_size, V)
params = list(embedding.parameters()) + list(rnn.parameters())
optimizer = torch.optim.SGD(params, lr=0.5)

In [22]:
# ----------------------------
# 6) Training Loop
# ----------------------------
for epoch in range(1, 101):
    total_loss = 0

    for sentence in sentences:
        token_ids = encode(sentence)
        token_ids = torch.tensor(token_ids, dtype=torch.long)

        h = rnn.init_hidden()

        loss = 0

        # next word prediction
        for t in range(len(token_ids) - 1):
            x_t = embedding(token_ids[t]).unsqueeze(0)   # (1, emb_dim)
            target = token_ids[t+1].unsqueeze(0)         # (1,)

            h, logits = rnn(x_t, h)

            loss += F.cross_entropy(logits, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    if epoch % 20 == 0:
        print(f"Epoch {epoch}, Loss: {total_loss:.4f}")


Epoch 20, Loss: 23.0505
Epoch 40, Loss: 25.1173
Epoch 60, Loss: 25.4774
Epoch 80, Loss: 24.9905
Epoch 100, Loss: 25.4188


In [23]:
# ----------------------------
# 7) Test Prediction
# ----------------------------
def predict_next_word(prefix):
    rnn.eval()
    embedding.eval()

    with torch.no_grad():
        ids = encode(prefix)
        ids = torch.tensor(ids, dtype=torch.long)

        h = rnn.init_hidden()

        for t in range(len(ids)):
            x_t = embedding(ids[t]).unsqueeze(0)
            h, logits = rnn(x_t, h)

        probs = torch.softmax(logits, dim=1)
        predicted_id = torch.argmax(probs).item()
        return itos[predicted_id]

print("\nTesting...")
print("Input: what is the capital of")
print("Prediction:", predict_next_word("what is the capital of"))


Testing...
Input: what is the capital of
Prediction: france
