In [None]:
!git clone https://github.com/PhocaHiro/s4_Phoca.git

In [None]:
%cd /content/s4
# ============= Set Up =============
# Requirements
!conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.6 -c pytorch -c nvidia
!pip install -r requirements.txt
# Structured Kernels
%cd extensions/kernels/
!python setup.py install
%cd /content/s4

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import os
import argparse

from models.s4.s4 import S4Block as S4  # Can use full version instead of minimal S4D standalone below
from models.s4.s4d import S4D
from tqdm.auto import tqdm
import copy

def select_dropout_fn(version=None):
    """Pick the dropout class used by the S4 layers for a given PyTorch version.

    Parameters
    ----------
    version : str, optional
        A PyTorch version string such as ``"1.13.1"`` or ``"2.1.0+cu118"``.
        Defaults to the installed ``torch.__version__``.

    Returns
    -------
    type
        ``nn.Dropout`` on PyTorch 1.11, where the channel-wise variants are bugged
        (a warning is printed). ``nn.Dropout1d`` on >= 1.12. ``nn.Dropout2d`` otherwise.
    """
    if version is None:
        version = torch.__version__
    # Only major.minor matters; local suffixes like "+cu118" live in later fields.
    major_minor = tuple(map(int, str(version).split('.')[:2]))
    # Dropout broke in PyTorch 1.11.
    # This must short-circuit: falling through to the checks below would replace
    # nn.Dropout with nn.Dropout2d.
    if major_minor == (1, 11):
        print("WARNING: Dropout is bugged in PyTorch 1.11. Results may be worse.")
        return nn.Dropout
    if major_minor >= (1, 12):
        return nn.Dropout1d
    return nn.Dropout2d


dropout_fn = select_dropout_fn()

CUDA extension for structured kernels (Cauchy and Vandermonde multiplication) not found. Install by going to extensions/kernels/ and running `python setup.py install`, for improved speed and memory efficiency. Note that the kernel changed for state-spaces 4.0 and must be recompiled.
Falling back on slow Cauchy and Vandermonde kernel. Install at least one of pykeops or the CUDA extension for better speed and memory efficiency.




  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Run on the GPU whenever PyTorch can see one; otherwise use the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Training bookkeeping: best test accuracy seen so far, and the epoch training
# starts from (0 for a fresh run; the last checkpoint epoch when resuming).
best_acc, start_epoch = 0, 0

# ハイパラの設定

In [None]:
import json
import datetime

lr = 0.001
weight_decay = 0.01
num_workers = 4
batch_size = 8
d_model = 512
d_mlp = 512
prenorm = True
dropout = 0.1
grad_clip = 1000

hyperparameters = {
    "lr": lr,
    "weight_decay": weight_decay,
    "num_workers":  num_workers,
    "batch_size": batch_size,
    "d_model": d_model,
    "d_mlp": d_mlp,
    "prenorm": prenorm,
    "dropout": dropout,
    "grad_clip": grad_clip,
}

# Save this run's hyperparameters as JSON.
# The set of hyperparameters may grow over time, so the file name carries a
# schema version (hyparaVxx).
# The timestamp has minute resolution, so two runs in the same minute share a file.
current_time = datetime.datetime.now()
current_time_str = current_time.strftime("%Y%m%d_%H%M")
# open(..., 'w') does not create missing directories; without this the cell
# raises FileNotFoundError on a fresh checkout.
os.makedirs('hyparams', exist_ok=True)
with open(f'hyparams/hyparaV1_{current_time_str}.json', 'w') as f:
    json.dump(hyperparameters, f, indent=4)

# S4Block (p26 Figure21参照)

In [5]:
class S4Block(nn.Module):
    """Stack of S4D residual layers followed by a residual position-wise MLP.

    Each of the ``n_layers`` S4 layers runs in the channel-first (B, d_model, L)
    layout that S4D uses with ``transposed=True``:

        [LayerNorm if prenorm] -> S4D -> dropout -> Linear(d_model, 2*d_model)
        -> GLU -> dropout -> + residual -> [LayerNorm if not prenorm]

    The stack is followed by a pre-norm MLP with its own residual, in the
    (B, L, d_model) layout:

        x + Dropout(Linear(Dropout(GELU(Linear(LayerNorm(x))))))

    Parameters
    ----------
    d_model : int
        Feature dimension of the input and output.
    d_mlp : int
        Hidden width of the final MLP.
    n_layers : int
        Number of stacked S4 residual layers.
    dropout : float
        Dropout probability after each S4D layer, after each GLU, and inside the MLP.
    prenorm : bool
        If True, apply LayerNorm before each S4 layer; otherwise after the residual add.
    lr : float
        S4D receives ``lr=min(0.001, lr)``. This is presumably the learning rate for
        its SSM kernel parameters — confirm in models/s4/s4d.py.
    """

    def __init__(
        self,
        d_model=256,
        d_mlp=512,
        n_layers=2,
        dropout=0.2,
        prenorm=True,
        lr=0.001,
    ):
        super().__init__()

        self.prenorm = prenorm

        # Stack S4 layers as residual blocks
        self.s4_layers = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.dropouts1 = nn.ModuleList()
        self.linears = nn.ModuleList()
        self.glus = nn.ModuleList()
        self.dropouts2 = nn.ModuleList()
        for _ in range(n_layers):
            # S4D's own dropout is disabled for now; this block applies dropout itself.
            self.s4_layers.append(
                S4D(d_model, dropout=0.0, transposed=True, lr=min(0.001, lr))
            )
            # Projects to 2*d_model so that the GLU below returns d_model features.
            self.linears.append(nn.Linear(d_model, 2 * d_model))
            self.norms.append(nn.LayerNorm(d_model))
            # dropout_fn is chosen from the PyTorch version at import time.
            # With Dropout1d, dim 1 (d_model here) is treated as channels.
            self.dropouts1.append(dropout_fn(dropout))
            # GLU splits the last dim into halves (a, b) and returns a * sigmoid(b).
            self.glus.append(nn.GLU(dim=-1))
            self.dropouts2.append(dropout_fn(dropout))

        # Must be nn.Sequential: an nn.ModuleList has no forward() and cannot be called.
        # The second Linear maps d_mlp back to d_model so the residual add in forward() works.
        # Plain nn.Dropout is used here: the layout is (B, L, d_model), where the
        # channel-wise dropout_fn would zero entire time steps instead of features.
        self.norm_mlp = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_mlp),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_mlp, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """
        Parameters
        ----------
        x : torch.Tensor (B, L, d_model)

        Returns
        -------
        torch.Tensor (B, L, d_model)
        """
        x = x.transpose(-1, -2)  # (B, L, d_model) -> (B, d_model, L)
        for s4, norm, dropout1, linear, glu, dropout2 in zip(
            self.s4_layers, self.norms, self.dropouts1, self.linears, self.glus, self.dropouts2
        ):
            # Each iteration maps (B, d_model, L) -> (B, d_model, L)
            z = x
            if self.prenorm:
                # Prenorm. LayerNorm normalizes the last dim, so move features there and back.
                z = norm(z.transpose(-1, -2)).transpose(-1, -2)

            # S4D returns (output, state); the state is not used here.
            z, _ = s4(z)

            # Dropout on the output of the S4 layer
            z = dropout1(z)

            # Channel mixing. Linear and GLU must both act on the feature axis:
            # (B, L, d_model) -> (B, L, 2*d_model) -> (B, L, d_model).
            # Applying GLU in the (B, d_model, L) layout would halve L instead.
            z = glu(linear(z.transpose(-1, -2))).transpose(-1, -2)
            z = dropout2(z)

            # Residual connection
            x = z + x

            if not self.prenorm:
                # Postnorm
                x = norm(x.transpose(-1, -2)).transpose(-1, -2)

        x = x.transpose(-1, -2)  # (B, d_model, L) -> (B, L, d_model)

        # Residual MLP. `x + ...` allocates a new tensor; nothing is modified in place.
        return x + self.norm_mlp(x)

# Encoder(画像にする必要はあるか？評価する上では画像にする必要はありそうだけど実際にモデルとしては軽いほうがいい)

In [None]:
class Encoder(nn.Module):
    """
    Convolutional encoder that maps (input_dim, 64, 64) images to 1024-dim vectors.

    Four kernel-4, stride-2 convolutions, each followed by ReLU, shrink the
    spatial size 64 -> 31 -> 14 -> 6 -> 2. The final (256, 2, 2) feature map is
    flattened to 1024 values. Any leading dimensions of the input, such as batch
    and sequence length, are preserved in the output.
    """

    def __init__(
        self,
        input_dim=3,  # 1 for grayscale images
        hidden_dim=30,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.cv1 = nn.Conv2d(input_dim, 32, kernel_size=4, stride=2)  # (input_dim, 64, 64) -> (32, 31, 31)
        self.cv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)  # (32, 31, 31) -> (64, 14, 14)
        self.cv3 = nn.Conv2d(64, 128, kernel_size=4, stride=2)  # (64, 14, 14) -> (128, 6, 6)
        self.cv4 = nn.Conv2d(128, 256, kernel_size=4, stride=2)  # (128, 6, 6) -> (256, 2, 2)
        # NOTE(review): not used by forward(). It is kept so callers and checkpoints that
        # reference it keep working; presumably meant for a later posterior head.
        self.mean_posterior = nn.Linear(1024, hidden_dim)

    def forward(self, obs):
        """
        Parameters
        ----------
        obs : torch.Tensor (..., input_dim, 64, 64)
            Observation images from the environment, for example
            (batch_size, L, input_dim, 64, 64) or (batch_size, input_dim, 64, 64).

        Returns
        ----------
        torch.Tensor (..., 1024)
            Each image embedded as a 1024-dim vector. The leading dimensions
            match those of ``obs``, e.g. (batch_size, L, 1024).
        """
        lead_shape = obs.shape[:-3]
        # Conv2d only accepts (N, C, H, W), so fold all leading dims into one batch dim.
        # Using numel() instead of -1 keeps this valid for empty batches.
        hidden = obs.reshape(lead_shape.numel(), *obs.shape[-3:])
        hidden = F.relu(self.cv1(hidden))
        hidden = F.relu(self.cv2(hidden))
        hidden = F.relu(self.cv3(hidden))
        hidden = F.relu(self.cv4(hidden))
        flat = hidden.flatten(start_dim=1)  # (N, 256 * 2 * 2) = (N, 1024)
        embedded_obs = flat.reshape(*lead_shape, flat.shape[-1])
        return embedded_obs