In [1]:
from s4d import S4Model
import s4d
from s4d import DropoutNd
import torch.nn as nn
import torch
import torchvision
import torchvision.transforms as transforms


# Dropout broke in PyTorch 1.11
def select_dropout_fn(version):
    """Pick the dropout layer class that matches a PyTorch version string.

    Args:
        version: A version string such as ``torch.__version__``; only the
            leading ``major.minor`` components are parsed.

    Returns:
        ``nn.Dropout`` for 1.11 (after printing a warning), ``nn.Dropout1d``
        for 1.12 and later, and ``nn.Dropout2d`` for anything older.
    """
    major_minor = tuple(map(int, version.split('.')[:2]))
    if major_minor == (1, 11):
        print("WARNING: Dropout is bugged in PyTorch 1.11. Results may be worse.")
        return nn.Dropout
    # The three cases are mutually exclusive. Previously the 1.11 choice was
    # immediately overwritten by the trailing `else` branch (nn.Dropout2d).
    if major_minor >= (1, 12):
        return nn.Dropout1d
    return nn.Dropout2d


dropout_fn = select_dropout_fn(torch.__version__)


# Experiment settings; 'grayscale' selects the 1-channel or 3-channel pipeline below.
config = {
    'd_model': 128,
    'n_layers': 4,
    'dropout': 0.1,
    'grayscale': False,
    'prenorm': False,
}


# Both pipelines end by reshaping the image tensor to (channels, 1024) and
# transposing it, so each sample becomes a (1024, channels) sequence.
if not config['grayscale']:
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        transforms.Lambda(lambda x: x.view(3, 1024).t()),
    ])
else:
    transform = transforms.Compose([
        transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Normalize(mean=122.6 / 255.0, std=61.0 / 255.0),
        transforms.Lambda(lambda x: x.view(1, 1024).t()),
    ])

# S4 is trained on sequences with no data augmentation!
transform_train = transform
transform_test = transform

# Input feature size follows the channel count; there are 10 output classes.
d_input = 1 if config['grayscale'] else 3
d_output = 10


  from .autonotebook import tqdm as notebook_tqdm


In [48]:
# Model
# Sizes and dropout come from `config`; this run passes layer_norm=False and
# skip_connection=False.
model = S4Model(
    d_input=d_input,
    d_output=d_output,
    d_model=config['d_model'],
    n_layers=config['n_layers'],
    dropout=config['dropout'],
    dropout_fn=dropout_fn,
    prenorm=config['prenorm'],
    layer_norm=False,
    skip_connection=False,
)

In [49]:
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
# map_location='cpu' remaps stored tensors to the CPU so that a checkpoint
# written from GPU tensors still loads on a machine without CUDA.
state = torch.load('ckpt.pth', map_location='cpu')

In [50]:
# Restore the parameters stored under the checkpoint's 'model' key; as the
# cell's last expression, the key-matching report is displayed below.
model.load_state_dict(state['model'])

<All keys matched successfully>

In [51]:
# Run on the CPU and put the model into evaluation mode for inference.
device = "cpu"
model = model.to(device)
model.eval()  # last expression of the cell, so the module repr is displayed

S4Model(
  (encoder): Linear(in_features=3, out_features=128, bias=True)
  (s4_layers): ModuleList(
    (0): S4D(
      (kernel): S4DKernel()
      (activation): GELU(approximate='none')
      (dropout): DropoutNd()
      (output_linear): Sequential(
        (0): Conv1d(128, 256, kernel_size=(1,), stride=(1,))
        (1): GLU(dim=-2)
      )
    )
    (1): S4D(
      (kernel): S4DKernel()
      (activation): GELU(approximate='none')
      (dropout): DropoutNd()
      (output_linear): Sequential(
        (0): Conv1d(128, 256, kernel_size=(1,), stride=(1,))
        (1): GLU(dim=-2)
      )
    )
    (2): S4D(
      (kernel): S4DKernel()
      (activation): GELU(approximate='none')
      (dropout): DropoutNd()
      (output_linear): Sequential(
        (0): Conv1d(128, 256, kernel_size=(1,), stride=(1,))
        (1): GLU(dim=-2)
      )
    )
    (3): S4D(
      (kernel): S4DKernel()
      (activation): GELU(approximate='none')
      (dropout): DropoutNd()
      (output_linear): Sequenti

In [52]:
# CIFAR-10 with train=False, preprocessed by the evaluation transform;
# download=True fetches the files into `root` if they are not already there.
data = torchvision.datasets.CIFAR10(
    root='./data/cifar/',
    train=False,
    download=True,
    transform=transform_test,
)

Files already downloaded and verified


In [53]:
import random

# Seed Python's `random` module so later random sampling is repeatable.
random.seed(42)

batch_size = 32
# shuffle=False: the loader walks the dataset in its stored order.
data_loader = torch.utils.data.DataLoader(
    data,
    batch_size=batch_size,
    shuffle=False,
)
# Keep only element 0 of the first batch (presumably the inputs, with the
# labels discarded - confirm against the dataset's item layout).
batch = next(iter(data_loader))[0]

In [54]:
# Inference only: run the forward pass without autograd so the returned
# tensors carry no graph or requires_grad flag. This keeps the saved file lean
# and lets the tensors be converted to numpy when they are plotted later.
with torch.no_grad():
    output, all_hidden_states = model(batch.to(device))

# Use a dedicated name instead of rebinding `data`, which still has to refer to
# the CIFAR-10 dataset if the DataLoader cell is re-run.
# The file name encodes the skip_connection / layer_norm flags the model was
# built with; keep it in sync when those flags change.
saved_outputs = (output, all_hidden_states)
torch.save(saved_outputs, 'S4_Skip=False_Norm=False.pt')

In [30]:
import matplotlib.pyplot as plt
import numpy as np

def plot_residual(ratios, aggregate="all", c="C0", ax=None, **kwargs):
    """Plot per-layer ratios, with the mean across samples drawn as a solid line.

    Args:
        ratios: Array-like whose first two axes are (layers, samples); it must
            support ``.shape``, ``.mean(axis=-1)`` and ``.std(axis=-1)``.
        aggregate: ``"all"`` additionally draws every sample as a faint line;
            ``"std"`` additionally shades mean +/- one standard deviation.
            Any other value draws the mean line only.
        c: Matplotlib color used for every artist.
        ax: Axes to draw on. Defaults to the current pyplot axes.
        **kwargs: Extra keyword arguments forwarded to ``ax.plot``.
    """
    # The declared default used to crash on `ax.plot`; fall back to the
    # current axes so the function is usable without an explicit `ax`.
    if ax is None:
        ax = plt.gca()

    num_layers, num_samples = ratios.shape[:2]
    layer_idx = np.arange(num_layers)

    if aggregate == "all":
        for sample_idx in range(num_samples):
            ax.plot(layer_idx, ratios[:, sample_idx],
                    c=c, alpha=.1, **kwargs)

    mean_value = ratios.mean(axis=-1)
    std_value = ratios.std(axis=-1)

    ax.plot(layer_idx, mean_value, c=c, **kwargs)

    if aggregate == "std":
        ax.fill_between(layer_idx, mean_value - std_value, mean_value + std_value,
                        color=c, alpha=.2)

    # Decorate the axes that were drawn on, not whatever pyplot currently
    # considers active (the two differ in multi-panel figures).
    ax.set_xlabel("layer index")
    ax.set_ylim([0 - 0.01, 1 + 0.01])
    ax.grid(alpha=.3)

In [31]:
import torch
import jax.numpy as jnp
import jax
import random

def compute_low_rank(x, k=1):
    """Return the rank-``k`` truncated-SVD approximation of ``x``.

    Args:
        x: A matrix of shape ``(m, n)`` or a stack of matrices of shape
            ``(..., m, n)``.
        k: Number of leading singular triplets to keep.

    Returns:
        An array with the same shape as ``x`` built from the ``k`` largest
        singular values and their singular vectors.
    """
    # jnp.linalg.svd already handles leading batch axes, so no vmap is needed.
    # The previous vmap produced batched factors that were then sliced and
    # contracted as if they were 2-D, which failed for 2-D and batched input.
    U, s, Vh = jnp.linalg.svd(x, full_matrices=False)
    return jnp.einsum(
        "...ij,...j,...jk->...ik",
        U[..., :, :k],
        s[..., :k],
        Vh[..., :k, :],
    )

def l1_matrix_norm(x):
    """Return the largest absolute column sum over the last two axes of ``x``.

    Absolute values are summed along the second-to-last axis and the maximum
    of those sums is taken along the last axis; leading axes are preserved.
    """
    # `-2 % x.ndim` expresses the second-to-last axis as a non-negative index.
    row_axis = -2 % x.ndim
    column_sums = x.abs().sum(axis=row_axis)
    return column_sums.max(axis=-1).values

def linf_matrix_norm(x):
    """Return ``l1_matrix_norm`` of ``x`` with its last two axes swapped.

    Swapping the axes turns the column sums into row sums, so this is the
    largest absolute row sum of each matrix.
    """
    transposed = x.transpose(-2, -1)
    return l1_matrix_norm(transposed)

def composite_norm(x):
    """Return ``sqrt(l1_matrix_norm(x) * linf_matrix_norm(x))``."""
    l1_value = l1_matrix_norm(x)
    linf_value = linf_matrix_norm(x)
    return torch.sqrt(l1_value * linf_value)

# Registry of the available norm functions, keyed by display name.
all_norms = {
    "l1": l1_matrix_norm,
    # torch.norm with p=2, reduced jointly over the last two axes.
    "l2": lambda mat: torch.norm(mat, p=2, dim=(-2, -1)),
    "l_inf": linf_matrix_norm,
    "l1 * l_inf": composite_norm,
}

# Names in registration order.
all_norms_names = [name for name in all_norms]

def sample_path(depth, num_layers, num_heads):
    """Draw a random path of ``depth`` (layer, head) choices.

    Layers are sampled without replacement and returned in ascending order;
    heads are sampled independently with replacement.

    Returns:
        A ``(selected_layers, selected_heads)`` pair of lists of length ``depth``.
    """
    layer_pool = list(range(num_layers))
    head_pool = list(range(num_heads))

    layers = random.sample(layer_pool, depth)
    layers.sort()
    heads = random.choices(head_pool, k=depth)
    return layers, heads

def sample_P_matrix(attentions, depth: int):
    """Multiply the attention matrices along one randomly sampled path.

    Args:
        attentions: Tensor whose shape unpacks as
            ``(num_layers, num_samples, num_heads, t, t)``.
        depth: Number of layers on the sampled path.

    Returns:
        The ``(t, t)`` product of the selected attention matrices, taken in
        ascending layer order for a single randomly chosen sample.
    """
    num_layers, num_samples, num_heads, t, _ = attentions.shape
    selected_layers, selected_heads = sample_path(depth, num_layers, num_heads)
    sample_idx = random.choice(list(range(num_samples)))
    # Start from an identity with the same dtype and device as the input; a
    # default float32 CPU identity makes the matmul fail for float64 or GPU
    # attentions.
    P = torch.eye(t, dtype=attentions.dtype, device=attentions.device)
    for layer, head in zip(selected_layers, selected_heads):
        P = P @ attentions[layer, sample_idx, head]
    return P

In [33]:
plt.rcParams["text.usetex"] = False

fig, ax = plt.subplots(figsize=(10, 6))

norm_fn = all_norms["l1 * l_inf"]

# (skip, layer_norm) settings in plotting order; entry i is drawn in color "C{i}".
run_settings = [
    (skip, layer_norm)
    for skip in [True, False]
    for layer_norm in [True, False]
]

for color_idx, (skip, layer_norm) in enumerate(run_settings):
    color = f"C{color_idx}"

    # Each file holds an (output, all_hidden_states) tuple written by an
    # inference run; one file per flag combination must already exist.
    # NOTE: torch.load unpickles the file, so only load trusted files.
    saved = torch.load(f"S4_Skip={skip}_Norm={layer_norm}.pt", map_location="cpu")

    all_hidden_states = saved[1]
    # Detach so files saved with autograd enabled can still be converted to
    # numpy for plotting.
    hidden_states = torch.stack([h.detach() for h in all_hidden_states])
    # Subtract the mean over the second-to-last axis
    # (presumably the sequence axis - TODO confirm against S4Model's output layout).
    residuals = hidden_states - hidden_states.mean(dim=-2, keepdim=True)

    ratio = norm_fn(residuals) / norm_fn(hidden_states)

    layer_idx = np.arange(ratio.shape[0])
    mean_value = ratio.mean(dim=-1).numpy()
    std_value = ratio.std(dim=-1).numpy()

    ax.fill_between(layer_idx, mean_value - std_value, mean_value + std_value,
                    color=color, alpha=0.2)
    ax.plot(layer_idx, mean_value, color=color,
            label=f"Skip={skip}, Norm={layer_norm}")

ax.set_xlabel("Layer")
ax.set_ylabel("Ratio")
ax.set_title("Ratio of Residuals to Hidden States")
ax.legend()
ax.grid(True, linestyle='--', alpha=0.7)
fig.tight_layout()
plt.show()

AttributeError: 'list' object has no attribute 'detach'

<Figure size 1000x600 with 0 Axes>