In [2]:
import random
import numpy as np
from model import *
import torch
from torch import nn
from torch.optim import Adam

index_to_color = {0:'white',1: 'red', 2: 'green', 3: 'blue', 4: 'yellow', 5: 'purple',6:'black'}

def generate_data(batch_size=100):
    rand_index_list = np.array([sorted(np.random.choice(range(1, 20), 5, replace=False)) for _ in range(batch_size)])
    colors = np.random.choice(range(1,6), (batch_size,5))
    target = np.zeros((batch_size,25))
    target[:, 20:] = colors
    data = np.zeros((batch_size,25))
    data[np.arange(batch_size)[:, None], rand_index_list] = colors
    data[:, 20:] = 6
    return [data, target]

data = generate_data(batch_size=100)
print(data[0][0],data[1][0])

[0. 2. 0. 0. 0. 0. 4. 0. 0. 0. 0. 0. 4. 0. 2. 1. 0. 0. 0. 0. 6. 6. 6. 6.
 6.] [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 2. 4. 4. 2.
 1.]


In [8]:
# Define model parameters
args = ModelArgs(
    d_model=16,
    n_layer=2, 
    vocab_size=7,
    d_state=8, 
)

# Initialize the Mamba model
model = Mamba(args)

# Define a loss function
loss_fn = nn.CrossEntropyLoss(reduction='sum')

# Define an optimizer
optimizer = Adam(model.parameters(), lr=1e-4)

# For each epoch
for epoch in range(3000):
    # Generate data
    inputs, targets = generate_data(batch_size=64)
    inputs = torch.from_numpy(inputs).long()
    targets = torch.from_numpy(targets).long()
    outputs = model(inputs)
    outputs = nn.Softmax(dim=2)(outputs)
    loss = loss_fn(outputs.reshape(-1,8), targets.reshape(-1))
    if epoch % 100 == 0:
        print(f'Epoch {epoch}: loss={loss.item()}')

    # Backward pass: compute the gradients of the loss with respect to the model's parameters
    loss.backward()

    # Update the model's parameters
    optimizer.step()

    # Zero the gradients
    optimizer.zero_grad()

Epoch 0: loss=2676.363037109375
Epoch 100: loss=2673.0322265625
Epoch 200: loss=2613.73291015625
Epoch 300: loss=2612.4677734375
Epoch 400: loss=2511.308349609375
Epoch 500: loss=2475.52392578125
Epoch 600: loss=2428.9716796875
Epoch 700: loss=2408.201171875
Epoch 800: loss=2418.505859375
Epoch 900: loss=2403.130859375
Epoch 1000: loss=2419.111328125
Epoch 1100: loss=2405.806884765625
Epoch 1200: loss=2408.31494140625
Epoch 1300: loss=2427.336181640625
Epoch 1400: loss=2417.472412109375
Epoch 1500: loss=2414.2177734375
Epoch 1600: loss=2421.294921875
Epoch 1700: loss=2413.44091796875
Epoch 1800: loss=2417.180908203125
Epoch 1900: loss=2403.21142578125
Epoch 2000: loss=2335.81201171875
Epoch 2100: loss=2337.158447265625
Epoch 2200: loss=2340.873291015625
Epoch 2300: loss=2347.194580078125
Epoch 2400: loss=2351.195556640625
Epoch 2500: loss=2357.284423828125
Epoch 2600: loss=2347.50439453125
Epoch 2700: loss=2333.29296875
Epoch 2800: loss=2331.7978515625
Epoch 2900: loss=2343.54736328125

In [9]:
inputs, targets = generate_data(batch_size=1)
inputs = torch.from_numpy(inputs).long()
targets = torch.from_numpy(targets).long()
outputs = model(inputs)
outputs = nn.Softmax(dim=2)(outputs)
print(inputs[0])
print(outputs[0].argmax(dim=1))
print(targets[0])

tensor([0, 0, 1, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 1, 4, 0, 0, 0, 1, 6, 6, 6, 6,
        6])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 5, 5,
        5])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 1, 4,
        1])
