# hyena-vitの実装

In [1]:
import numpy as np
import torch
from torch import Tensor
import torch.nn as nn
from torch.utils.data import DataLoader
from einops import rearrange
from einops.layers.torch import Rearrange
from torchinfo import summary
import sys
sys.path.append("../")
from hyena.standalone_hyena import HyenaOperator

import random
import os


from  torchvision import datasets, transforms
import torch.optim as optim


import timm

In [2]:
class Patching(nn.Module):
    def __init__(self, patch_size):
        """ [input]
            - patch_size (int) : パッチの縦の長さ（=横の長さ）
        """
        super().__init__()
        self.net = Rearrange("b c (h ph) (w pw) -> b (h w) (ph pw c)", ph = patch_size, pw = patch_size)
    
    def forward(self, x):
        """ [input]
            - x (torch.Tensor) : 画像データ
                - x.shape = torch.Size([batch_size, channels, image_height, image_width])
        """
        x = self.net(x)
        return x


In [3]:
class LinearProjection(nn.Module):
    def __init__(self, patch_dim, d_model):
        """ [input]
            - patch_dim (int) : 一枚あたりのパッチの次元（= channels * (patch_size ** 2)
            - dim (int) : パッチが変換されたベクトルの次元 
        """
        super().__init__()
        self.net = nn.Linear(patch_dim, d_model)

    def forward(self, x):
        """ [input]
            - x (torch.Tensor) 
                - x.shape = torch.Size([batch_size, n_patches, patch_dim])
        """
        x = self.net(x)
        return x


In [4]:
class ImgToHyena(nn.Module):
    def __init__(
        self,
        patch_size: int = 8, 
        patch_dim: int = 192,
        d_model: int = 10
        ):
        super().__init__()
        
        self.patch = Patching(patch_size=patch_size)
        self.linearprojection = LinearProjection(patch_dim=patch_dim, d_model=d_model)
        
    def forward(self, img):
        """ [input]
            - img (torch.Tensor) 
                - x.shape = torch.Size([batch_size, channel, height, width)
        """
        
        x = self.patch(img)
        x = self.linearprojection(x)
        
        return x
        

In [5]:

class HyenaNet(nn.Module):
    def __init__(
        self,
        patch_size: int = 8, 
        patch_dim: int = 192,
        d_model: int = 10, # model_width
        l_max: int = 784, # max_seq_len
        order: int = 10, # v, x1, ... x10
        filter_order: int = 64
    ):
        super().__init__()
        
        self.imgtohyena = ImgToHyena(
            patch_size=patch_size,
            patch_dim=patch_dim
        )
        
        self.norm = nn.LayerNorm(d_model)
        
        self.hyena =  HyenaOperator(
            d_model=d_model, 
            l_max=l_max, 
            order=order, 
            filter_order=filter_order
        )
        
        self.flat = nn.Flatten()
        self.fc = nn.Linear(784*10,10)
        self.softmax = nn.Softmax(dim=1)
        
    def forward(self, img):
        x = self.imgtohyena(img)
        x = self.norm(x)
        x = self.hyena(x)
        x = self.flat(x)
        x = self.fc(x)
        x = self.softmax(x)
        
        return x


In [6]:
def get_transform(test=False):
    if test:
        transform = transforms.Compose([
            transforms.Resize((224,224)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5), (0.5))
        ])
    else:
        transform = transforms.Compose([
            transforms.Resize((224,224)),
            transforms.ToTensor(),
            transforms.Normalize((0.5), (0.5))
        ])
    return transform

In [7]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
seed = 0
seed_everything(seed)
device = "cuda" if torch.cuda.is_available() else "cpu"


train_data = datasets.CIFAR10('../images/', #データを保存するdir
                              train = True,  #True : 学習用データ False : テストデータ 
                              download=True,  # downloadするか否か
                              transform = get_transform()) #前処理の設定
test_data = datasets.CIFAR10('../images/', #データを保存するdir
                              train = False,  #True : 学習用データ False : テストデータ 
                              download=True,  # downloadするか否か
                              transform = get_transform(test=True)) #前処理の設定
train_loader = DataLoader(train_data,batch_size=64,shuffle=True,num_workers=4,pin_memory=True)
test_loader = DataLoader(test_data,batch_size=64,shuffle=False,num_workers=4,pin_memory=True)
criterion = nn.CrossEntropyLoss()


Files already downloaded and verified
Files already downloaded and verified


In [21]:
hyena = HyenaNet()
optimizer = optim.Adam(hyena.parameters(), lr=1e-6)


In [22]:

hyena.to(device)
hyena_train_loss = []
hyena_train_corrects = []
hyena_test_loss = []
hyena_test_corrects = []
for epoch in range(20):
    hyena.train()
    epoch_train_loss = 0.0
    epoch_train_corrects = 0
    for i, data in enumerate(train_loader):
        optimizer.zero_grad()
        img, label = data
        img = img.to(device)
        label = label.to(device)
        output = hyena(img)
        loss = criterion(output, label)
        corrects = torch.sum(output.argmax(dim=1) == label).detach().cpu().item()
        epoch_train_loss += loss.detach().cpu().item()
        epoch_train_corrects += corrects
        loss.backward()
        optimizer.step()
    hyena_train_loss.append(epoch_train_loss/(i+1))
    hyena_train_corrects.append(epoch_train_corrects/len(train_data))
    
    hyena.eval()
    epoch_test_loss = 0.0
    epoch_test_corrects = 0
    for j, data in enumerate(test_loader):

        img, label = data
        img = img.to(device)
        label = label.to(device)
        output = hyena(img)
        loss = criterion(output, label)
        corrects = torch.sum(output.argmax(dim=1) == label).detach().cpu().item()
        epoch_test_loss += loss.detach().cpu().item()
        epoch_test_corrects += corrects

    hyena_test_loss.append(epoch_test_loss/(i+1))
    hyena_test_corrects.append(epoch_test_corrects/len(test_data))

    print("-"*50)
    print("epoch : {}".format(epoch))
    print("train  loss : {:.4f}, acc : {:.4f}".format(epoch_train_loss/(i+1),epoch_train_corrects/len(train_data)))
    print("test   loss : {:.4f}, acc : {:.4f}".format(epoch_test_loss/(j+1),epoch_test_corrects/len(test_data)))    
    

--------------------------------------------------
epoch : 0
train  loss : 2.3320, acc : 0.1072
test   loss : 2.3191, acc : 0.1195
--------------------------------------------------
epoch : 1
train  loss : 2.3072, acc : 0.1281
test   loss : 2.2928, acc : 0.1451
--------------------------------------------------
epoch : 2
train  loss : 2.2789, acc : 0.1586
test   loss : 2.2702, acc : 0.1659
--------------------------------------------------
epoch : 3
train  loss : 2.2581, acc : 0.1809
test   loss : 2.2516, acc : 0.1892
--------------------------------------------------
epoch : 4
train  loss : 2.2424, acc : 0.1992
test   loss : 2.2352, acc : 0.2072
--------------------------------------------------
epoch : 5
train  loss : 2.2289, acc : 0.2141
test   loss : 2.2231, acc : 0.2205
--------------------------------------------------
epoch : 6
train  loss : 2.2181, acc : 0.2269
test   loss : 2.2161, acc : 0.2288
--------------------------------------------------
epoch : 7
train  loss : 2.2083, 

## ViT

In [8]:
vit = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=10)
optimizer = optim.Adam(vit.parameters(), lr=1e-6)

In [9]:

vit.to(device)
vit_train_loss = []
vit_train_corrects = []
vit_test_loss = []
vit_test_corrects = []
for epoch in range(20):
    vit.train()
    epoch_train_loss = 0.0
    epoch_train_corrects = 0
    for i, data in enumerate(train_loader):
        optimizer.zero_grad()
        img, label = data
        img = img.to(device)
        label = label.to(device)
        output = vit(img)
        loss = criterion(output, label)
        corrects = torch.sum(output.argmax(dim=1) == label).detach().cpu().item()
        epoch_train_loss += loss.detach().cpu().item()
        epoch_train_corrects += corrects
        loss.backward()
        optimizer.step()
    vit_train_loss.append(epoch_train_loss/(i+1))
    vit_train_corrects.append(epoch_train_corrects/len(train_data))
    
    vit.eval()
    epoch_test_loss = 0.0
    epoch_test_corrects = 0
    for j, data in enumerate(test_loader):

        img, label = data
        img = img.to(device)
        label = label.to(device)
        output = vit(img)
        loss = criterion(output, label)
        corrects = torch.sum(output.argmax(dim=1) == label).detach().cpu().item()
        epoch_test_loss += loss.detach().cpu().item()
        epoch_test_corrects += corrects

    vit_test_loss.append(epoch_test_loss/(i+1))
    vit_test_corrects.append(epoch_test_corrects/len(test_data))

    print("-"*50)
    print("epoch : {}".format(epoch))
    print("train  loss : {:.4f}, acc : {:.4f}".format(epoch_train_loss/(i+1),epoch_train_corrects/len(train_data)))
    print("test   loss : {:.4f}, acc : {:.4f}".format(epoch_test_loss/(j+1),epoch_test_corrects/len(test_data)))    
    


  return F.conv2d(input, weight, bias, self.stride,
