In [1]:
import torch
import torchvision
from mamba_ssm import Mamba
from sklearn.model_selection import train_test_split

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
data = torchvision.datasets.CIFAR10("data/cifar10", download=False)

In [3]:
data.classes

['airplane',
 'automobile',
 'bird',
 'cat',
 'deer',
 'dog',
 'frog',
 'horse',
 'ship',
 'truck']

In [4]:
X = data.data
y = data.targets

In [6]:
X = X / 255.0

In [7]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

In [10]:
train_loader = torch.utils.data.DataLoader(dataset=torch.utils.data.TensorDataset(torch.tensor(X_train), torch.tensor(y_train)), batch_size=32, shuffle=True)
val_loader = torch.utils.data.DataLoader(dataset=torch.utils.data.TensorDataset(torch.tensor(X_test), torch.tensor(y_test)), batch_size=32, shuffle=True)

In [8]:
model = Mamba(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model= 32 * 3, # Model dimension d_model
    d_state=16,  # SSM state expansion factor
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
).to("cuda")

In [11]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [12]:
# Function to Train the Model
def train_model(model, train_loader, val_loader, optimizer, device, num_epochs):
    # Loop through the specified number of epochs
    for epoch in range(num_epochs):
        # Set the model to training mode
        model.train()
        # Initialize total loss for the current epoch
        total_loss = 0

        # Loop through the batches in the training data
        for batch in train_loader:
            inputs, labels = [t.to(device) for t in batch]

            optimizer.zero_grad()

            outputs = model(inputs)
            loss = outputs.loss
            total_loss += loss.item()
            
            # TODO: y los labels?

            loss.backward()
            optimizer.step()

        model.eval()  # Set the model to evaluation mode
        val_loss = 0

        # Disable gradient computation during validation
        with torch.no_grad():
            for batch in val_loader:
                inputs, labels = [t.to(device) for t in batch]

                outputs = model(inputs)
                loss = outputs.loss
                val_loss += loss.item()
        # Print the average loss for the current epoch
        print(
            f'Epoch {epoch+1}, Training Loss: {total_loss/len(train_loader)},Validation loss:{val_loss/len(val_loader)}')


# Call the function to train the model
train_model(model, train_loader, val_loader, optimizer, "cuda", num_epochs=30)

ValueError: not enough values to unpack (expected 3, got 2)