In [33]:
import random
import numpy as np
from model import *
import torch
from torch import nn
from torch.optim import Adam

index_to_color = {0:'white',1: 'red', 2: 'green', 3: 'blue', 4: 'yellow', 5: 'purple',6:'black'}

def generate_data(batch_size=100):
    rand_index_list = np.array([sorted(np.random.choice(range(1, 20), 5, replace=False)) for _ in range(batch_size)])
    colors = np.random.choice(range(1,6), (batch_size,5))
    target = np.zeros((batch_size,25))
    target[:, 20:] = colors
    data = np.zeros((batch_size,25))
    data[np.arange(batch_size)[:, None], rand_index_list] = colors
    data[:, 20:] = 6
    return [data, target]

data = generate_data(batch_size=100)
print(data[0][0],data[1][0])

[0. 0. 0. 0. 5. 0. 5. 0. 3. 0. 0. 0. 4. 0. 0. 0. 5. 0. 0. 0. 6. 6. 6. 6.
 6.] [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 5. 5. 3. 4.
 5.]


In [41]:
# Define model parameters
args = ModelArgs(
    d_model=3,
    n_layer=6, 
    vocab_size=7,
    expand=4, 
    d_state=1,
)

# Initialize the Mamba model
model = Mamba(args)

# Define a loss function
loss_fn = nn.CrossEntropyLoss(reduction='sum')

# Define an optimizer
optimizer = Adam(model.parameters(), lr=1e-3)

# For each epoch
for epoch in range(3000):
    # Generate data
    inputs, targets = generate_data(batch_size=20)
    inputs = torch.from_numpy(inputs).long()
    targets = torch.from_numpy(targets).long()
    outputs = model(inputs)
    outputs = nn.Softmax(dim=2)(outputs)
    loss = loss_fn(outputs.reshape(-1,8), targets.reshape(-1))
    if epoch % 100 == 0:
        print(f'Epoch {epoch}: loss={loss.item()}')

    # Backward pass: compute the gradients of the loss with respect to the model's parameters
    loss.backward()

    # Update the model's parameters
    optimizer.step()

    # Zero the gradients
    optimizer.zero_grad()

Epoch 0: loss=866.26953125
Epoch 100: loss=808.976318359375
Epoch 200: loss=779.822265625
Epoch 300: loss=750.4093627929688
Epoch 400: loss=749.6165161132812
Epoch 500: loss=757.1575317382812
Epoch 600: loss=745.1011962890625
Epoch 700: loss=746.2762451171875
Epoch 800: loss=752.2911376953125
Epoch 900: loss=744.0885009765625
Epoch 1000: loss=750.1497802734375
Epoch 1100: loss=753.6925048828125
Epoch 1200: loss=747.4166259765625
Epoch 1300: loss=748.4771728515625
Epoch 1400: loss=749.7415161132812
Epoch 1500: loss=750.9656372070312
Epoch 1600: loss=747.675048828125
Epoch 1700: loss=747.0822143554688
Epoch 1800: loss=742.825439453125
Epoch 1900: loss=738.3043212890625
Epoch 2000: loss=734.7710571289062
Epoch 2100: loss=707.7189331054688
Epoch 2200: loss=706.251953125
Epoch 2300: loss=710.7028198242188
Epoch 2400: loss=710.2073364257812
Epoch 2500: loss=703.957275390625
Epoch 2600: loss=704.3675537109375
Epoch 2700: loss=704.9950561523438
Epoch 2800: loss=697.6864624023438
Epoch 2900: lo

In [45]:
inputs, targets = generate_data(batch_size=1)
inputs = torch.from_numpy(inputs).long()
targets = torch.from_numpy(targets).long()
outputs = model(inputs)
outputs = nn.Softmax(dim=2)(outputs)
print(inputs[0])
print(outputs[0].argmax(dim=1))
print(targets[0])

tensor([0, 4, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 4, 4, 0, 1, 0, 0, 6, 6, 6, 6,
        6])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1,
        1])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 3, 4, 4,
        1])
