In [1]:
import sys
sys.path.append("../")
from hyena.standalone_hyena import HyenaOperator
from torch.utils.data import DataLoader, Dataset
import torch
from torchinfo import summary
from einops import rearrange
import timm
from PIL import Image
import numpy as np
from  torchvision import datasets, transforms
import torch.optim as optim
import torch.nn as nn
import os
import random
from tqdm import tqdm


In [2]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
seed = 0
seed_everything(seed)
device = "cuda" if torch.cuda.is_available() else "cpu"

In [26]:

class HyenaNet(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        
        self.hyena =  HyenaOperator(
                        d_model=3, 
                        l_max=50176, 
                        order=10, 
                        filter_order=64,
                    )
        self.flat = nn.Flatten()
        self.fc = nn.Linear(50176*3,10)
        self.softmax = nn.Softmax()
        
    def forward(self, x):
        x = self.hyena(x)
        x = self.flat(x)
        x = self.fc(x)
        x = self.softmax(x)
        return x

def get_transform(test=False):
    if test:
        transform = transforms.Compose([
            transforms.Resize((224,224)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5), (0.5))
        ])
    else:
        transform = transforms.Compose([
            transforms.Resize((224,224)),
            transforms.ToTensor(),
            transforms.Normalize((0.5), (0.5))
        ])
    return transform


In [27]:
hyena = HyenaNet()

In [28]:
train_data = datasets.CIFAR10('../images/', #データを保存するdir
                              train = True,  #True : 学習用データ False : テストデータ 
                              download=True,  # downloadするか否か
                              transform = get_transform()) #前処理の設定
test_data = datasets.CIFAR10('../images/', #データを保存するdir
                              train = False,  #True : 学習用データ False : テストデータ 
                              download=True,  # downloadするか否か
                              transform = get_transform(test=True)) #前処理の設定
train_loader = DataLoader(train_data,batch_size=64,shuffle=True,num_workers=4,pin_memory=True)
test_loader = DataLoader(test_data,batch_size=64,shuffle=False,num_workers=4,pin_memory=True)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(hyena.parameters(), lr=1e-6)

Files already downloaded and verified


In [31]:
hyena.to(device)
hyena_train_loss = []
hyena_train_corrects = []
for epoch in range(30):
    epoch_loss = 0.0
    epoch_corrects = 0
    for i, data in enumerate(tqdm(train_loader)):
        optimizer.zero_grad()
        img, label = data
        img = img.to(device)
        img = rearrange(img, "b c w h -> b (w h) c")
        label = label.to(device)
        output = hyena(img)
        loss = criterion(output, label)
        corrects = torch.sum(output.argmax(dim=1) == label).detach().cpu().item()
        epoch_loss += loss.detach().cpu().item()
        epoch_corrects += corrects
        loss.backward()
        optimizer.step()
    hyena_train_loss.append(epoch_loss/(i+1))
    hyena_train_corrects.append(epoch_corrects/len(train_data))

  x = self.softmax(x)
100%|██████████| 782/782 [00:51<00:00, 15.05it/s]
100%|██████████| 782/782 [00:51<00:00, 15.11it/s]
100%|██████████| 782/782 [00:52<00:00, 15.02it/s]
100%|██████████| 782/782 [00:52<00:00, 14.85it/s]
100%|██████████| 782/782 [00:52<00:00, 14.91it/s]
100%|██████████| 782/782 [00:51<00:00, 15.11it/s]
100%|██████████| 782/782 [00:51<00:00, 15.07it/s]
100%|██████████| 782/782 [00:51<00:00, 15.10it/s]
100%|██████████| 782/782 [00:51<00:00, 15.07it/s]
100%|██████████| 782/782 [00:52<00:00, 14.96it/s]
100%|██████████| 782/782 [00:52<00:00, 14.90it/s]
100%|██████████| 782/782 [00:52<00:00, 14.92it/s]
100%|██████████| 782/782 [00:52<00:00, 14.92it/s]
100%|██████████| 782/782 [00:52<00:00, 14.99it/s]
100%|██████████| 782/782 [00:52<00:00, 15.00it/s]
100%|██████████| 782/782 [00:52<00:00, 15.00it/s]
100%|██████████| 782/782 [00:51<00:00, 15.13it/s]
100%|██████████| 782/782 [00:52<00:00, 15.03it/s]
100%|██████████| 782/782 [00:52<00:00, 14.97it/s]
100%|██████████| 782/782 [00

In [33]:
hyena_train_loss

[2.3668214405893973,
 2.3668813830446402,
 2.3666416132236687,
 2.3668214405893973,
 2.3668813830446402,
 2.3669413254998832,
 2.3669413254998832,
 2.3668214405893973,
 2.3667614981341547,
 2.3669413254998832,
 2.3669413254998832,
 2.3669413254998832,
 2.3669413254998832,
 2.3669413254998832,
 2.3668813830446402,
 2.3668214405893973,
 2.3668813830446402,
 2.3668214405893973,
 2.3668813830446402,
 2.3668813830446402,
 2.3669413254998832,
 2.3667614981341547,
 2.3669413254998832,
 2.3669413254998832,
 2.3668813830446402,
 2.3667614981341547,
 2.3667614981341547,
 2.3668214405893973,
 2.3669413254998832,
 2.3667015556789117]