In [212]:
import sys
sys.path.append("../")
from hyena.standalone_hyena import HyenaOperator
from torch.utils.data import DataLoader, Dataset
import torch
from torchinfo import summary
from einops import rearrange
import timm
from PIL import Image
import numpy as np
from  torchvision import datasets, transforms
import torch.optim as optim
import torch.nn as nn
import os
import random
from tqdm import tqdm


: 

In [139]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
seed = 0
seed_everything(seed)
device = "cuda" if torch.cuda.is_available() else "cpu"

In [198]:

class HyenaNet(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.hyena =  HyenaOperator(
                        d_model=224, 
                        l_max=224, 
                        order=10, 
                        filter_order=64,
                    )
        self.flat = nn.Flatten()
        self.fc = nn.Linear(50176,10)
        self.softmax = nn.Softmax()
        
    def forward(self, x):
        x = self.hyena(x)
        x = self.flat(x)
        x = self.fc(x)
        x = self.softmax(x)
        return x

def get_transform(test=False):
    if test:
        transform = transforms.Compose([
            transforms.Resize((224,224)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5), (0.5))
        ])
    else:
        transform = transforms.Compose([
            transforms.Resize((224,224)),
            transforms.ToTensor(),
            transforms.Normalize((0.5), (0.5))
        ])
    return transform

vit = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=10)

In [164]:
hyena = HyenaNet()

In [169]:
fm_train_data = datasets.FashionMNIST("../images/", train=True, transform=get_transform(), download=True)
train_loader = DataLoader(fm_train_data, batch_size=1024)
fm_test_data = datasets.FashionMNIST("../images/", train=False, transform=get_transform(test=True), download=True)
test_loader = DataLoader(fm_test_data, batch_size=1024)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(hyena.parameters(), lr=1e-6)

## hyena train

In [183]:
hyena.to(device)
train_loss = []
train_corrects = []
for epoch in range(30):
    epoch_loss = 0.0
    epoch_corrects = 0
    for i, data in enumerate(tqdm(train_loader)):
        optimizer.zero_grad()
        img, label = data
        img = img.to(device)
        label = label.to(device)
        output = hyena(img.squeeze())
        loss = criterion(output, label)
        corrects = torch.sum(output.argmax(dim=1) == label).detach().cpu().item()
        epoch_loss += loss.detach().cpu().item()
        epoch_corrects += corrects
        loss.backward()
        optimizer.step()
    train_loss.append(epoch_loss/(i+1))
    train_corrects.append(epoch_corrects/len(fm_train_data))
    

  x = self.softmax(x)
100%|██████████| 59/59 [00:25<00:00,  2.31it/s]
100%|██████████| 59/59 [00:26<00:00,  2.25it/s]
100%|██████████| 59/59 [00:25<00:00,  2.30it/s]
100%|██████████| 59/59 [00:25<00:00,  2.33it/s]
100%|██████████| 59/59 [00:24<00:00,  2.36it/s]
100%|██████████| 59/59 [00:25<00:00,  2.28it/s]
100%|██████████| 59/59 [00:25<00:00,  2.31it/s]
100%|██████████| 59/59 [00:25<00:00,  2.32it/s]
100%|██████████| 59/59 [00:24<00:00,  2.37it/s]
100%|██████████| 59/59 [00:25<00:00,  2.36it/s]
100%|██████████| 59/59 [00:25<00:00,  2.27it/s]
100%|██████████| 59/59 [00:26<00:00,  2.24it/s]
100%|██████████| 59/59 [00:25<00:00,  2.28it/s]
100%|██████████| 59/59 [00:25<00:00,  2.28it/s]
100%|██████████| 59/59 [00:25<00:00,  2.30it/s]
100%|██████████| 59/59 [00:25<00:00,  2.28it/s]
100%|██████████| 59/59 [00:24<00:00,  2.36it/s]
100%|██████████| 59/59 [00:25<00:00,  2.27it/s]
100%|██████████| 59/59 [00:25<00:00,  2.28it/s]
100%|██████████| 59/59 [00:25<00:00,  2.35it/s]
100%|██████████| 5

In [211]:
import matplotlib.pyplot as plt


30