In [None]:
import random
import numpy as np
from model import *
import torch
from torch import nn
from torch.optim import Adam
from tqdm import tqdm
import matplotlib.pyplot as plt

# Display color for each token id produced by generate_data:
# 0 = empty slot, 1-5 = the colors to be copied, 6 = filler in the last five input slots.
index_to_color = dict(enumerate(['white', 'red', 'green', 'blue', 'yellow', 'purple', 'black']))

def generate_data(batch_size=100):
    """Build one batch for a selective-copying task.

    Every row has 25 slots. In ``data``, five color tokens (values 1-5) sit at
    sorted random positions drawn from 1..19, the rest of the first 20 slots
    are 0, and the last five slots hold token 6. In ``target``, the first 20
    slots are 0 and the last five hold the same colors, in the order they
    appear in ``data``.

    Args:
        batch_size: number of rows to generate.

    Returns:
        ``[data, target]``: two float64 arrays of shape ``(batch_size, 25)``.
    """
    # Draw positions first, then colors: keeps the global NumPy RNG call order.
    positions = np.array(
        [sorted(np.random.choice(range(1, 20), 5, replace=False)) for _ in range(batch_size)]
    )
    palette = np.random.choice(range(1, 6), (batch_size, 5))

    row_ids = np.arange(batch_size)[:, None]
    inputs = np.zeros((batch_size, 25))
    inputs[row_ids, positions] = palette
    inputs[:, 20:] = 6

    labels = np.zeros((batch_size, 25))
    labels[:, 20:] = palette
    return [inputs, labels]

# Sanity check: print the first input row next to its target row.
data = generate_data(batch_size=100)
print(*(part[0] for part in data))

In [None]:
# Define model parameters
# Original note: with d_model=64 the model did not learn. That was observed while a
# Softmax was applied before CrossEntropyLoss (a double softmax); worth re-checking.
args = ModelArgs(
    d_model=16,
    n_layer=2,
    vocab_size=7,  # token ids 0-6 as produced by generate_data
)

# Training configuration
SEED = 0             # seeds Python, NumPy and torch so runs are reproducible
N_STEPS = 10_000     # each step draws a fresh batch, so "epoch" below means one batch
BATCH_SIZE = 64
LOG_EVERY = 1_000
LEARNING_RATE = 1e-4

random.seed(SEED)
np.random.seed(SEED)  # generate_data draws from the global NumPy RNG
torch.manual_seed(SEED)

# Initialize the Mamba model
model = Mamba(args)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# CrossEntropyLoss applies log-softmax internally and therefore expects raw,
# unnormalized logits. Passing softmax probabilities applies softmax twice and
# flattens the gradients.
loss_fn = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=LEARNING_RATE)

# One loss value per training step, for the plot below.
loss_values = []

model.train()
for epoch in tqdm(range(N_STEPS)):
    inputs, targets = generate_data(batch_size=BATCH_SIZE)
    inputs = torch.from_numpy(inputs).long().to(device)
    targets = torch.from_numpy(targets).long().to(device)

    # Raw logits; assumed layout (batch, seq, n_logits) -- TODO confirm in model.py.
    outputs = model(inputs)
    # Flatten using the model's actual logit width rather than a hard-coded 8.
    # The previous reshape(-1, 8) implies the model emits 8 logits for vocab_size=7
    # (presumably vocab padding inside Mamba -- confirm in model.py).
    # The loss covers all 25 positions, including the first 20 whose target is always 0.
    loss = loss_fn(outputs.reshape(-1, outputs.size(-1)), targets.reshape(-1))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    loss_value = loss.item()
    loss_values.append(loss_value)
    if epoch % LOG_EVERY == 0:
        print(f'Epoch {epoch}: loss={loss_value}')

# Plot loss values
fig, ax = plt.subplots()
ax.plot(loss_values)
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
plt.show()

In [None]:
# Run the trained model on one freshly generated example and print the input,
# the per-position predicted token ids, and the target.
inputs, targets = generate_data(batch_size=1)
# Move the inputs to wherever the model's parameters live; otherwise this cell
# fails with a device mismatch whenever training ran on CUDA.
model_device = next(model.parameters()).device
inputs = torch.from_numpy(inputs).long().to(model_device)
targets = torch.from_numpy(targets).long()

was_training = model.training
model.eval()
with torch.no_grad():  # inference only: no autograd graph needed
    logits = model(inputs)
model.train(was_training)  # restore whatever mode the model was in

# Keep `outputs` as probabilities, as before; argmax is the same on logits or probabilities.
outputs = logits.softmax(dim=2)
print(inputs[0].cpu())
print(outputs[0].argmax(dim=1).cpu())
print(targets[0])