In [6]:
!pip install skorch pytorch_lightning einops torcheval tqdm

^C
[31mERROR: Operation cancelled by user[0m[31m
[0m

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import transforms
from torchvision.datasets import MNIST


def create_dataloader(batch_size, root='~/mnist_data'):
    """Build the MNIST training and test data loaders.

    Args:
        batch_size: Number of samples per mini-batch, used for both loaders.
        root: Directory passed to ``MNIST`` as the dataset location; the
            files are downloaded there when missing (``download=True``).

    Returns:
        A ``(data_train, data_test)`` tuple of
        ``torch.utils.data.DataLoader`` objects. Images go through
        ``transforms.ToTensor()``; labels are left exactly as the dataset
        yields them (no target transform).
    """
    to_tensor = transforms.ToTensor()

    # Training data is reshuffled every epoch.
    data_train = torch.utils.data.DataLoader(
        MNIST(root, train=True, download=True, transform=to_tensor),
        batch_size=batch_size,
        shuffle=True,
    )

    # Evaluation gains nothing from shuffling; a fixed order keeps the test
    # batches (including the short final one) identical from epoch to epoch.
    data_test = torch.utils.data.DataLoader(
        MNIST(root, train=False, download=True, transform=to_tensor),
        batch_size=batch_size,
        shuffle=False,
    )

    return data_train, data_test


In [2]:
# Mini-batch size used for both the training and the test loader.
batch_size = 32

# Build the MNIST loaders once so later cells can iterate over them.
trainloader, testloader = create_dataloader(batch_size=batch_size)

In [6]:
from ssm import SingleS4Classifier

# S4 classifier: 784 input features (a flattened 28x28 image) -> 10 classes.
# The dropout list has three entries, matching n_layers=3 -- presumably one
# rate per layer; confirm against the SingleS4Classifier definition.
model = SingleS4Classifier(
    d_input=784,
    d_output=10,
    d_model=512,
    n_layers=3,
    dropout=[0.2, 0.2, 0.2],
    transposed=False,
    s4d=True,
)

In [7]:
# Display the module tree to inspect the layer sizes.
model

SingleS4Classifier(
  (encoder): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): GELU(approximate='none')
  )
  (decoder): Sequential(
    (0): Linear(in_features=512, out_features=10, bias=True)
    (1): LogSoftmax(dim=1)
  )
)

In [None]:
import torch
import torch.nn as nn

from tqdm.contrib import tenumerate

# Loss function. CrossEntropyLoss takes integer class indices directly, so the
# labels are used as-is instead of being one-hot encoded; the loss value is
# the same as with one-hot float targets.
# NOTE(review): the model repr shows the decoder already ends in LogSoftmax.
# CrossEntropyLoss applies log-softmax again, which is redundant but harmless
# (log-softmax of log-probabilities returns them unchanged). nn.NLLLoss would
# be the conventional pairing -- confirm which is intended.
criterion = nn.CrossEntropyLoss()

# Optimizer.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

for epoch in range(10):
    print(f"Epoch: {epoch}")
    # Losses are accumulated per sample (batch mean * batch size) so that a
    # short final batch is not over-weighted in the epoch average.
    train_loss, test_loss = 0.0, 0.0
    n_train, n_test = 0, 0

    model.train()
    for _, (data, label) in tenumerate(trainloader):
        # Take the batch dimension from the tensor itself rather than the
        # global batch_size, so a smaller last batch still reshapes cleanly.
        n = data.size(0)
        inputs = data.view(n, -1)  # flatten each image: (n, 784)
        targets = label            # class indices, shape (n,)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item() * n
        n_train += n

    print("train loss: ", train_loss / max(n_train, 1))

    model.eval()
    with torch.no_grad():
        for data, label in testloader:
            n = data.size(0)
            inputs = data.view(n, -1)
            targets = label

            outputs = model(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item() * n
            n_test += n

    print("test loss: ", test_loss / max(n_test, 1))



In [3]:
from ssm import MambaClassifier, ModelArgs

# Hyper-parameters for the Mamba classifier (784 inputs -> 10 classes),
# placed on the CPU.
args = ModelArgs(
    d_input=784,
    d_output=10,
    d_model=1,
    n_layer=2,
    d_state=10,
    expand=2,
    dt_rank="auto",
    d_conv=4,
    conv_bias=True,
    bias=False,
    device="cpu",
)

# Replaces the `model` global with the Mamba variant.
model = MambaClassifier(args)

CUDA extension for structured kernels (Cauchy and Vandermonde multiplication) not found. Install by going to extensions/kernels/ and running `python setup.py install`, for improved speed and memory efficiency. Note that the kernel changed for state-spaces 4.0 and must be recompiled.
Falling back on slow Cauchy and Vandermonde kernel. Install at least one of pykeops or the CUDA extension for better speed and memory efficiency.


No module named 'extensions'


In [4]:
# Display the module tree to inspect the layer sizes.
model

MambaClassifier(
  (encoder): Sequential(
    (0): Linear(in_features=784, out_features=1, bias=True)
    (1): GELU(approximate='none')
  )
  (layers): ModuleList(
    (0-1): 2 x ResidualBlock(
      (mixer): MambaBlock(
        (in_proj): Linear(in_features=1, out_features=4, bias=False)
        (conv1d): Conv1d(2, 2, kernel_size=(4,), stride=(1,), padding=(3,), groups=2)
        (x_proj): Linear(in_features=2, out_features=21, bias=False)
        (dt_proj): Linear(in_features=1, out_features=2, bias=True)
        (out_proj): Linear(in_features=2, out_features=1, bias=False)
      )
      (norm): RMSNorm()
    )
  )
  (norm_f): RMSNorm()
  (decoder): Sequential(
    (0): Linear(in_features=784, out_features=10, bias=True)
    (1): LogSoftmax(dim=1)
  )
)

In [5]:
import torch
import torch.nn as nn

from tqdm.contrib import tenumerate

# Loss function. CrossEntropyLoss takes integer class indices directly, so the
# labels are used as-is instead of being one-hot encoded; the loss value is
# the same as with one-hot float targets.
# NOTE(review): the model repr shows the decoder already ends in LogSoftmax.
# CrossEntropyLoss applies log-softmax again, which is redundant but harmless
# (log-softmax of log-probabilities returns them unchanged). nn.NLLLoss would
# be the conventional pairing -- confirm which is intended.
criterion = nn.CrossEntropyLoss()

# Optimizer.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

for epoch in range(10):
    print(f"Epoch: {epoch}")
    # Losses are accumulated per sample (batch mean * batch size) so that a
    # short final batch is not over-weighted in the epoch average.
    train_loss, test_loss = 0.0, 0.0
    n_train, n_test = 0, 0

    model.train()
    for _, (data, label) in tenumerate(trainloader):
        # data: (n, 1, 28, 28), label: (n,)
        # Take the batch dimension from the tensor itself rather than the
        # global batch_size, so a smaller last batch still reshapes cleanly.
        n = data.size(0)
        inputs = data.view(n, -1, 1)  # one pixel per step: (n, 784, 1)
        targets = label               # class indices, shape (n,)

        optimizer.zero_grad()

        outputs = model(inputs)
        assert outputs.shape[1] == 10, f"{outputs.shape=}, {targets.shape=}"

        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item() * n
        n_train += n

    print("train loss: ", train_loss / max(n_train, 1))

    model.eval()
    with torch.no_grad():
        for data, label in testloader:
            n = data.size(0)
            # Same (n, 784, 1) layout as in training; evaluation previously
            # used a flat (n, 784) view, which did not match what the model
            # is trained on.
            inputs = data.view(n, -1, 1)
            targets = label

            outputs = model(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item() * n
            n_test += n

    print("test loss: ", test_loss / max(n_test, 1))


Epoch: 0


  0%|          | 0/1875 [00:00<?, ?it/s]