In [1]:
!pip install skorch pytorch_lightning einops torcheval tqdm


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.1[0m[39;49m -> [0m[32;49m24.0[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3.11 -m pip install --upgrade pip[0m


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import transforms
from torchvision.datasets import MNIST


def create_dataloader(batch_size):
    # target_transform = transforms.Compose([
    #     lambda x: print("hoge"),
    #     lambda x: print(x),
    #     torch.tensor,
    #     lambda x: F.one_hot(x, num_classes=10),
    # ])
    
    data_train = torch.utils.data.DataLoader(
        MNIST(
            '~/mnist_data', train=True, download=True,
            transform=transforms.ToTensor(),
            # target_transform=target_transform,
        ),
        batch_size=batch_size,
        shuffle=True
    )

    data_test = torch.utils.data.DataLoader(
        MNIST(
            '~/mnist_data', train=False, download=True,
            transform=transforms.ToTensor(),
            # target_transform=target_transform,
        ),
        batch_size=batch_size,
        shuffle=True
    )

    return data_train, data_test


In [3]:
batch_size = 32

trainloader, testloader = create_dataloader(batch_size=batch_size)

In [4]:
from s4.single import SingleS4Classifier

model = SingleS4Classifier(
    d_input=784,
    d_output=10,
    d_model=512,
    n_layers=3,
    dropout=[0.2]*3,
    transposed=False,
    s4d=True,
)

  from .autonotebook import tqdm as notebook_tqdm
CUDA extension for structured kernels (Cauchy and Vandermonde multiplication) not found. Install by going to extensions/kernels/ and running `python setup.py install`, for improved speed and memory efficiency. Note that the kernel changed for state-spaces 4.0 and must be recompiled.
Falling back on slow Cauchy and Vandermonde kernel. Install at least one of pykeops or the CUDA extension for better speed and memory efficiency.


No module named 'extensions'


In [5]:
model

SingleS4Classifier(
  (encoder): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): GELU(approximate='none')
  )
  (decoder): Sequential(
    (0): Linear(in_features=512, out_features=10, bias=True)
    (1): LogSoftmax(dim=1)
  )
)

In [6]:
import torch
import torch.nn as nn

from tqdm.contrib import tenumerate

# 損失関数  criterion：基準
criterion = nn.CrossEntropyLoss()

# 最適化法の指定  optimizer：最適化
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

for epoch in range(10):
    print(f"Epoch: {epoch}")
    train_loss, test_loss = 0.0, 0.0

    model.train()
    for idx, samples in tenumerate(trainloader):
        data, label = samples
        inputs = data.view(batch_size, -1) #.cuda()

        # print(f"inputs = {inputs.shape}")
        
        targets = F.one_hot(label.view(batch_size), num_classes=10).float() #.cuda()

        optimizer.zero_grad()
        
        outputs = model(inputs)
        # print(f"outputs = {outputs}, targets = {targets}")
        
        loss = criterion(outputs, targets)

        loss.backward()
        
        optimizer.step()
        
        train_loss += loss.item()

    print("train loss: ", train_loss / len(trainloader))


    model.eval()
    with torch.no_grad():
        for idx, samples in enumerate(testloader):
            data, label = samples
            inputs = data.view(-1, 784) #.cuda()
            targets = F.one_hot(label, num_classes=10).float() #.cuda()

            outputs = model(inputs)
            
            loss = criterion(outputs, targets)
            
            test_loss += loss.item()

    print("test loss: ", test_loss / len(testloader))



Epoch: 0


100%|██████████| 1875/1875 [01:24<00:00, 22.11it/s]


train loss:  1.5813925941785176
test loss:  1.1434066592694851
Epoch: 1


100%|██████████| 1875/1875 [01:23<00:00, 22.55it/s]


train loss:  1.0009166517893473
test loss:  0.8894488156413118
Epoch: 2


100%|██████████| 1875/1875 [01:22<00:00, 22.65it/s]


train loss:  0.8005918872038523
test loss:  0.7293340910357028
Epoch: 3


100%|██████████| 1875/1875 [01:23<00:00, 22.58it/s]


train loss:  0.6935171256224314
test loss:  0.6394799912509065
Epoch: 4


100%|██████████| 1875/1875 [01:22<00:00, 22.85it/s]


train loss:  0.6119460261503855
test loss:  0.5679113820147591
Epoch: 5


100%|██████████| 1875/1875 [01:22<00:00, 22.82it/s]


train loss:  0.5526404209454854
test loss:  0.5223766243971956
Epoch: 6


100%|██████████| 1875/1875 [01:22<00:00, 22.66it/s]


train loss:  0.512718423851331
test loss:  0.48793672720273845
Epoch: 7


100%|██████████| 1875/1875 [01:22<00:00, 22.73it/s]


train loss:  0.4725037154277166
test loss:  0.4526738094064755
Epoch: 8


100%|██████████| 1875/1875 [01:20<00:00, 23.17it/s]


train loss:  0.44421250182787575
test loss:  0.42611014186002955
Epoch: 9


100%|██████████| 1875/1875 [01:34<00:00, 19.81it/s]


train loss:  0.42066710186799366
test loss:  0.40667796501526815
