In [1]:
from shearletNN.shearlets import getcomplexshearlets2D
from shearletNN.shearlet_utils import frequency_shearlet_transform, spatial_shearlet_transform, ShearletTransformLoader, shifted_frequency_shearlet_transform
from shearletNN.complex_resnet import complex_resnet18, complex_resnet34, complex_resnet50
from shearletNN.layers import CGELU, CReLU

import torch
import torchvision
from torchvision.transforms import v2
from torchvision import transforms

import gc


patch_size = 32
image_size = 32

rows, cols = image_size, image_size


shearlets, shearletIdxs, RMS, dualFrameWeights = getcomplexshearlets2D(	rows, 
                                                                        cols, 
                                                                        1, 
                                                                        3, 
                                                                        1, 
                                                                        0.5,
                                                                        wavelet_eff_support = image_size,
                                                                        gaussian_eff_support = image_size
                                                                        )

shearlets = torch.tensor(shearlets).permute(2, 0, 1).type(torch.complex64).to(0)

In [2]:
class Unraveling:
    def __init__(self, n):
        self.levels = []
        for i in range(0, n // 2):
            level = []
            for j in range(i, n - i):
                level.append((j, i))
                level.append((i, j))

                level.append((j, n - (i + 1)))
                level.append((n - (i + 1), j))

            level = list(set(level))
            self.levels.append((torch.tensor([x for x, _ in level]), torch.tensor([y for _, y in level])))
    
    def __call__(self, x):
        return [x[..., a, b] for a, b in self.levels]

# this is still technically resolution-independent because that property comes from the shearlet transform crop
# in this case we do not actually need variable number of input tokens because we are always extracting a finite amount of information...
# if the whole premise is that we can utilize even larger images then why does it matter that we are using a variable size input?
class Freakformer(torch.nn.Module):
    def __init__(self, n, chans, embed_dim = 384):
        # first we need to build the tokenizer using the Unraveling transformation and some linear layers
        # this is equivalent to the patching layer
        self.unravel = Unraveling(n)
        self.layers = torch.nn.ModuleList([torch.nn.Linear(len(level), embed_dim) for level in self.unravel.layers])
        pass

    def forward(self, x):
        unraveled = self.unravel(x)
        # (B, C, 4 ... 8(n // 2) + 4)
        tokens = torch.stack([layer(level.flatten(1)) for layer, level in zip(self.layers, unraveled)], -2)
        # (B, n // 2, embed_dim)
        pass

In [3]:
from tqdm import tqdm

def train(model, optimizer, loader, accumulate=1):
    model.train()
    loss = torch.nn.CrossEntropyLoss()

    for i, (X, y) in tqdm(enumerate(loader)):
        out = model(X.to(0))
        optimizer.zero_grad()
        l = loss(out, y.to(0)) / accumulate
        l.backward()
        if i % accumulate == (accumulate - 1):
            optimizer.step()
        

def accuracy(output, target, topk=(1,)):
    output = output.to(torch.device('cpu'))
    target = target.to(torch.device('cpu'))
    maxk = max(topk)
    batch_size = target.shape[0]

    _, idx = output.sort(dim=1, descending=True)
    pred = idx.narrow(1, 0, maxk).t()
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(dim=0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


def epoch_accuracy(loader_s, student):
    student.eval()

    out_epoch_s = [accuracy(student(L.to(0)), y)[0].detach().cpu().item() for L, y in loader_s]

    student.train()

    return sum(out_epoch_s) / len(out_epoch_s)

def test(network, test_loader):
    network.eval().to(0)
    test_loss = 0
    correct = 0
    total = 0
    test_losses=[]
    with torch.no_grad():
        for data, target in test_loader:
            output = network(data.to(0))
            test_loss += torch.nn.CrossEntropyLoss()(output, target.to(0)).item()
            pred = output.data.max(1, keepdim=True)[1].cpu()
            correct += pred.eq(target.data.view_as(pred)).sum()
            total += target.shape[0]
        test_loss /= total
        test_losses.append(test_loss)
        print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, total,
        100. * correct / total))

class IndexSubsetDataset:
    def __init__(self, ds, inds):
        self.ds = ds
        self.inds = inds

    def __iter__(self):
        for i in range(len(self.inds)):
            yield self[i]
    
    def __getitem__(self, i):
        return self.ds[self.inds[i]]
    
    def __len__(self):
        return len(self.inds)
    
    
class FreqModel(torch.nn.Module):
    def __init__(
        self,
        img_size = 224,
        in_chans = 3,
        embed_dim = 384,
    ):
        super().__init__()
        self.img_size = (img_size, img_size)
        self.grid_size = (img_size // 2, 1)
        self.num_patches = img_size // 2

        self.unravel = Unraveling(img_size)
        self.layers = torch.nn.ModuleList([torch.nn.Linear(len(level[0]) * in_chans, embed_dim, dtype=torch.complex64) for level in self.unravel.levels])

    def forward(self, x):

        unraveled = self.unravel(x)
        print([(level.flatten(1).shape, layer.weight.shape) for layer, level in zip(self.layers, unraveled)])
        tokens = torch.cat([layer(level.flatten(1)) for layer, level in zip(self.layers, unraveled)], -2)
        
        return x
    

def linearleaves(module):
    # returns a list of pairs of (parent, submodule_name) pairs for all submodule leaves of the current module
    if isinstance(module, torch.nn.Linear):
        return [(module, None)]

    linear_children = []
    for name, mod in module.named_modules():
        if isinstance(mod, torch.nn.Linear):
            linear_children.append((name, module))
    return linear_children

In [4]:
batch_size_train = 256

In [5]:
def linearleaves(module):
    # returns a list of pairs of (parent, submodule_name) pairs for all submodule leaves of the current module
    if isinstance(module, torch.nn.Linear):
        return [(module, None)]

    linear_children = []
    for name, mod in module.named_modules():
        if isinstance(mod, torch.nn.Linear) or isinstance(mod, torch.nn.Conv2d):
            linear_children.append((name, module))
    return linear_children
        

def getattrrecur(mod, s):
    s = s.split('.')
    for substr in s:
        mod = getattr(mod, substr)
    return mod


def setattrrecur(mod, s, value):
    s = s.split('.')
    for substr in s[:-1]:
        mod = getattr(mod, substr)
    setattr(mod, s[-1], value)


def spectral_normalize(model):
    for name, mod in linearleaves(model):
        setattrrecur(model, name, torch.nn.utils.parametrizations.spectral_norm(getattrrecur(mod, name)))
    
    return model

In [6]:
from shearletNN.complex_deit import Attention, vit_models, Block, FreqEmbed, complex_freakformer_small_patch1_LS, complex_freakformer_small_patch2_LS
from shearletNN.layers import CReLU, ComplexLayerNorm
from functools import partial

def repeat3(x):
    return x.repeat(3, 1, 1)[:3]

train_transform = v2.Compose([
    transforms.RandomResizedCrop((image_size, image_size), scale=(0.5, 1.0)),
    # transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    repeat3,
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406], 
        std=[0.229, 0.224, 0.225]
    )
])

val_transform = v2.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    repeat3,
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406], 
        std=[0.229, 0.224, 0.225]
    )
])

ds_train = torchvision.datasets.CIFAR10('../', transform=train_transform, download = True, train=True)

ds_val = torchvision.datasets.CIFAR10('../', transform=val_transform, download = True, train=False)

train_loader = torch.utils.data.DataLoader(
  ds_train,
  batch_size=batch_size_train, shuffle=True, num_workers=0)

shearlets = shearlets[:3]

def shearlet_transform(img):
    return shifted_frequency_shearlet_transform(img.to(0), shearlets, patch_size)

train_loader = ShearletTransformLoader(train_loader, shearlet_transform)

val_loader = torch.utils.data.DataLoader(
  ds_val,
  batch_size=batch_size_train, shuffle=False)

val_loader = ShearletTransformLoader(val_loader, shearlet_transform)

#for x, y in tqdm(train_loader):
#    assert list(x.shape) == [batch_size_train, shearlets.shape[0] * 3, patch_size, patch_size], x.shape
#    break
print('building model...')

model = spectral_normalize(complex_freakformer_small_patch2_LS(
    img_size=patch_size,
    in_chans=shearlets.shape[0] * 3,
))

# model = complex_resnet18(in_dim=shearlets.shape[0] * 3)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
print('training model...')
for epoch in range(128):
    print('epoch', epoch)
    train(model.to(0), optimizer, train_loader, accumulate=1)
    gc.collect()
    if epoch % 8 == 7:
        test(model, train_loader)
        test(model, val_loader)



Files already downloaded and verified
Files already downloaded and verified
building model...
training model...
epoch 0


391it [00:36, 10.58it/s]


epoch 1


391it [00:37, 10.50it/s]


epoch 2


391it [00:36, 10.66it/s]


epoch 3


391it [00:36, 10.69it/s]


epoch 4


391it [00:36, 10.65it/s]


epoch 5


391it [00:36, 10.60it/s]


epoch 6


391it [00:36, 10.66it/s]


epoch 7


391it [00:36, 10.74it/s]



Test set: Avg. loss: 0.0137, Accuracy: 19488/50000 (39%)


Test set: Avg. loss: 0.0136, Accuracy: 4062/10000 (41%)

epoch 8


391it [00:36, 10.77it/s]


epoch 9


391it [00:36, 10.79it/s]


epoch 10


391it [00:36, 10.79it/s]


epoch 11


391it [00:36, 10.80it/s]


epoch 12


391it [00:36, 10.83it/s]


epoch 13


391it [00:35, 10.86it/s]


epoch 14


391it [00:36, 10.85it/s]


epoch 15


391it [00:35, 10.87it/s]



Test set: Avg. loss: 0.0116, Accuracy: 24306/50000 (49%)


Test set: Avg. loss: 0.0120, Accuracy: 4756/10000 (48%)

epoch 16


391it [00:35, 10.92it/s]


epoch 17


391it [00:36, 10.86it/s]


epoch 18


391it [00:36, 10.71it/s]


epoch 19


391it [00:35, 10.87it/s]


epoch 20


391it [00:36, 10.85it/s]


epoch 21


391it [00:35, 10.92it/s]


epoch 22


391it [00:36, 10.73it/s]


epoch 23


391it [00:35, 10.88it/s]



Test set: Avg. loss: 0.0104, Accuracy: 26908/50000 (54%)


Test set: Avg. loss: 0.0113, Accuracy: 5170/10000 (52%)

epoch 24


391it [00:36, 10.82it/s]


epoch 25


391it [00:36, 10.70it/s]


epoch 26


391it [00:36, 10.71it/s]


epoch 27


391it [00:36, 10.63it/s]


epoch 28


391it [00:36, 10.69it/s]


epoch 29


391it [00:36, 10.66it/s]


epoch 30


391it [00:36, 10.73it/s]


epoch 31


391it [00:36, 10.64it/s]



Test set: Avg. loss: 0.0093, Accuracy: 29472/50000 (59%)


Test set: Avg. loss: 0.0109, Accuracy: 5341/10000 (53%)

epoch 32


391it [00:36, 10.65it/s]


epoch 33


391it [00:36, 10.68it/s]


epoch 34


391it [00:36, 10.80it/s]


epoch 35


391it [00:35, 10.88it/s]


epoch 36


391it [00:36, 10.80it/s]


epoch 37


391it [00:36, 10.83it/s]


epoch 38


391it [00:36, 10.74it/s]


epoch 39


391it [00:36, 10.68it/s]



Test set: Avg. loss: 0.0085, Accuracy: 31461/50000 (63%)


Test set: Avg. loss: 0.0106, Accuracy: 5544/10000 (55%)

epoch 40


391it [00:36, 10.72it/s]


epoch 41


391it [00:36, 10.76it/s]


epoch 42


391it [00:36, 10.72it/s]


epoch 43


391it [00:36, 10.84it/s]


epoch 44


391it [00:36, 10.79it/s]


epoch 45


391it [00:36, 10.86it/s]


epoch 46


391it [00:35, 10.93it/s]


epoch 47


391it [00:36, 10.83it/s]



Test set: Avg. loss: 0.0075, Accuracy: 33685/50000 (67%)


Test set: Avg. loss: 0.0108, Accuracy: 5610/10000 (56%)

epoch 48


391it [00:36, 10.86it/s]


epoch 49


391it [00:35, 10.89it/s]


epoch 50


391it [00:36, 10.86it/s]


epoch 51


391it [00:36, 10.83it/s]


epoch 52


391it [00:36, 10.76it/s]


epoch 53


391it [00:36, 10.65it/s]


epoch 54


391it [00:35, 10.91it/s]


epoch 55


391it [00:36, 10.81it/s]



Test set: Avg. loss: 0.0068, Accuracy: 35418/50000 (71%)


Test set: Avg. loss: 0.0111, Accuracy: 5680/10000 (57%)

epoch 56


391it [00:36, 10.82it/s]


epoch 57


391it [00:35, 10.90it/s]


epoch 58


391it [00:35, 10.87it/s]


epoch 59


391it [00:35, 10.90it/s]


epoch 60


391it [00:35, 10.86it/s]


epoch 61


391it [00:35, 10.89it/s]


epoch 62


391it [00:36, 10.83it/s]


epoch 63


391it [00:36, 10.71it/s]



Test set: Avg. loss: 0.0061, Accuracy: 36951/50000 (74%)


Test set: Avg. loss: 0.0113, Accuracy: 5806/10000 (58%)

epoch 64


391it [00:36, 10.68it/s]


epoch 65


391it [00:36, 10.78it/s]


epoch 66


391it [00:36, 10.77it/s]


epoch 67


391it [00:35, 10.92it/s]


epoch 68


391it [00:35, 10.89it/s]


epoch 69


391it [00:35, 10.86it/s]


epoch 70


391it [00:35, 10.95it/s]


epoch 71


391it [00:35, 10.87it/s]



Test set: Avg. loss: 0.0056, Accuracy: 38071/50000 (76%)


Test set: Avg. loss: 0.0114, Accuracy: 5788/10000 (58%)

epoch 72


391it [00:35, 10.88it/s]


epoch 73


391it [00:36, 10.73it/s]


epoch 74


391it [00:36, 10.79it/s]


epoch 75


391it [00:35, 10.86it/s]


epoch 76


391it [00:35, 10.89it/s]


epoch 77


391it [00:36, 10.73it/s]


epoch 78


391it [00:36, 10.75it/s]


epoch 79


391it [00:36, 10.64it/s]



Test set: Avg. loss: 0.0051, Accuracy: 39192/50000 (78%)


Test set: Avg. loss: 0.0117, Accuracy: 5749/10000 (57%)

epoch 80


391it [00:36, 10.77it/s]


epoch 81


391it [00:36, 10.74it/s]


epoch 82


391it [00:36, 10.63it/s]


epoch 83


391it [00:36, 10.78it/s]


epoch 84


391it [00:36, 10.81it/s]


epoch 85


391it [00:35, 10.87it/s]


epoch 86


391it [00:36, 10.80it/s]


epoch 87


391it [00:35, 10.87it/s]



Test set: Avg. loss: 0.0047, Accuracy: 40127/50000 (80%)


Test set: Avg. loss: 0.0119, Accuracy: 5870/10000 (59%)

epoch 88


391it [00:36, 10.85it/s]


epoch 89


391it [00:36, 10.82it/s]


epoch 90


391it [00:35, 10.91it/s]


epoch 91


391it [00:36, 10.77it/s]


epoch 92


391it [00:36, 10.74it/s]


epoch 93


391it [00:36, 10.74it/s]


epoch 94


391it [00:36, 10.84it/s]


epoch 95


391it [00:36, 10.83it/s]



Test set: Avg. loss: 0.0041, Accuracy: 41298/50000 (83%)


Test set: Avg. loss: 0.0126, Accuracy: 5881/10000 (59%)

epoch 96


391it [00:36, 10.86it/s]


epoch 97


391it [00:35, 10.91it/s]


epoch 98


391it [00:36, 10.84it/s]


epoch 99


391it [00:36, 10.73it/s]


epoch 100


391it [00:36, 10.72it/s]


epoch 101


391it [00:36, 10.69it/s]


epoch 102


391it [00:36, 10.77it/s]


epoch 103


391it [00:35, 10.92it/s]



Test set: Avg. loss: 0.0039, Accuracy: 41921/50000 (84%)


Test set: Avg. loss: 0.0128, Accuracy: 5926/10000 (59%)

epoch 104


391it [00:35, 10.89it/s]


epoch 105


391it [00:36, 10.85it/s]


epoch 106


391it [00:35, 10.90it/s]


epoch 107


391it [00:36, 10.84it/s]


epoch 108


391it [00:36, 10.81it/s]


epoch 109


391it [00:35, 10.87it/s]


epoch 110


391it [00:36, 10.74it/s]


epoch 111


86it [00:07, 10.76it/s]


KeyboardInterrupt: 

In [6]:
from shearletNN.complex_deit import vit_models, LPatchEmbed, Attention, Block, ComplexLayerNorm, complex_Lfreakformer_small_patch1_LS
from shearletNN.layers import CReLU, ComplexLayerNorm
from functools import partial
import os

os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

patch_size = 32
image_size = 32

def repeat3(x):
    return x.repeat(3, 1, 1)[:3]

train_transform = v2.Compose([
    transforms.RandomResizedCrop((image_size, image_size), scale=(0.5, 1.0)),
    # transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    repeat3,
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406], 
        std=[0.229, 0.224, 0.225]
    )
])

val_transform = v2.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    repeat3,
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406], 
        std=[0.229, 0.224, 0.225]
    )
])

ds_train = torchvision.datasets.Caltech101('./', transform=train_transform, download = True)
ds_train = IndexSubsetDataset(ds_train, sum([list(range(len(ds_train)))[i::5] for i in range(1, 5)], []))

ds_val = torchvision.datasets.Caltech101('./', transform=val_transform, download = True)
ds_val = IndexSubsetDataset(ds_val, list(range(len(ds_val)))[0::5])

ds_train = torchvision.datasets.CIFAR10('../', transform=train_transform, download = True, train=True)

ds_val = torchvision.datasets.CIFAR10('../', transform=val_transform, download = True, train=False)

train_loader = torch.utils.data.DataLoader(
  ds_train,
  batch_size=batch_size_train, shuffle=True, num_workers=0)

shearlets = shearlets[:3]

def shearlet_transform(img):
    return shifted_frequency_shearlet_transform(img.to(0), shearlets.to(0), image_size)

train_loader = ShearletTransformLoader(train_loader, shearlet_transform)

val_loader = torch.utils.data.DataLoader(
  ds_val,
  batch_size=batch_size_train, shuffle=False)

val_loader = ShearletTransformLoader(val_loader, shearlet_transform)

for x, y in tqdm(train_loader):
    assert list(x.shape) == [batch_size_train, shearlets.shape[0] * 3, image_size, image_size], x.shape
    break
print('building model...')

model = complex_Lfreakformer_small_patch1_LS(img_size = image_size, in_chans = shearlets.shape[0] * 3)


with torch.no_grad():
    for param in model.parameters():
        if not ((param == 0).all() or (param == 1).all()):
            param.data /= 2**(0.5)

model = spectral_normalize(model)


optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
print('training model...')
for epoch in range(240):
    print('epoch', epoch)
    train(model.to(0), optimizer, train_loader, accumulate=1)
    gc.collect()
    if epoch % 8 == 7:
        test(model, train_loader)
        test(model, val_loader)



Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


0it [00:00, ?it/s]


building model...
32 16
training model...
epoch 0


391it [00:23, 16.93it/s]


epoch 1


391it [00:23, 16.75it/s]


epoch 2


391it [00:23, 16.79it/s]


epoch 3


391it [00:23, 16.98it/s]


epoch 4


391it [00:22, 17.34it/s]


epoch 5


391it [00:23, 16.91it/s]


epoch 6


391it [00:23, 16.81it/s]


epoch 7


391it [00:23, 16.85it/s]



Test set: Avg. loss: 0.0172, Accuracy: 9547/50000 (19%)


Test set: Avg. loss: 0.0172, Accuracy: 2053/10000 (21%)

epoch 8


391it [00:23, 16.58it/s]


epoch 9


391it [00:23, 16.70it/s]


epoch 10


391it [00:23, 16.83it/s]


epoch 11


391it [00:23, 16.78it/s]


epoch 12


391it [00:23, 16.72it/s]


epoch 13


391it [00:23, 16.80it/s]


epoch 14


391it [00:23, 16.82it/s]


epoch 15


391it [00:23, 16.77it/s]



Test set: Avg. loss: 0.0156, Accuracy: 13996/50000 (28%)


Test set: Avg. loss: 0.0153, Accuracy: 2947/10000 (29%)

epoch 16


391it [00:23, 16.63it/s]


epoch 17


391it [00:23, 16.77it/s]


epoch 18


391it [00:23, 16.75it/s]


epoch 19


391it [00:23, 16.84it/s]


epoch 20


391it [00:23, 16.80it/s]


epoch 21


391it [00:23, 16.76it/s]


epoch 22


391it [00:23, 16.81it/s]


epoch 23


391it [00:23, 16.75it/s]



Test set: Avg. loss: 0.0138, Accuracy: 18942/50000 (38%)


Test set: Avg. loss: 0.0134, Accuracy: 3996/10000 (40%)

epoch 24


391it [00:23, 16.75it/s]


epoch 25


391it [00:23, 16.75it/s]


epoch 26


391it [00:23, 16.79it/s]


epoch 27


391it [00:23, 16.69it/s]


epoch 28


391it [00:23, 16.81it/s]


epoch 29


391it [00:23, 16.78it/s]


epoch 30


391it [00:23, 16.83it/s]


epoch 31


391it [00:23, 16.80it/s]



Test set: Avg. loss: 0.0122, Accuracy: 22763/50000 (46%)


Test set: Avg. loss: 0.0121, Accuracy: 4647/10000 (46%)

epoch 32


391it [00:23, 16.78it/s]


epoch 33


391it [00:23, 16.81it/s]


epoch 34


391it [00:23, 16.80it/s]


epoch 35


391it [00:23, 16.77it/s]


epoch 36


391it [00:23, 16.86it/s]


epoch 37


391it [00:23, 16.84it/s]


epoch 38


391it [00:23, 16.78it/s]


epoch 39


391it [00:23, 16.71it/s]



Test set: Avg. loss: 0.0113, Accuracy: 25162/50000 (50%)


Test set: Avg. loss: 0.0114, Accuracy: 4959/10000 (50%)

epoch 40


391it [00:23, 16.69it/s]


epoch 41


391it [00:23, 16.65it/s]


epoch 42


391it [00:23, 16.69it/s]


epoch 43


391it [00:23, 16.81it/s]


epoch 44


391it [00:23, 16.80it/s]


epoch 45


391it [00:23, 16.87it/s]


epoch 46


391it [00:23, 16.83it/s]


epoch 47


391it [00:23, 16.76it/s]



Test set: Avg. loss: 0.0104, Accuracy: 27081/50000 (54%)


Test set: Avg. loss: 0.0109, Accuracy: 5256/10000 (53%)

epoch 48


391it [00:23, 16.79it/s]


epoch 49


391it [00:23, 16.64it/s]


epoch 50


391it [00:23, 16.77it/s]


epoch 51


391it [00:23, 16.79it/s]


epoch 52


391it [00:23, 16.81it/s]


epoch 53


391it [00:23, 16.79it/s]


epoch 54


391it [00:23, 16.78it/s]


epoch 55


391it [00:23, 16.87it/s]



Test set: Avg. loss: 0.0095, Accuracy: 28954/50000 (58%)


Test set: Avg. loss: 0.0107, Accuracy: 5348/10000 (53%)

epoch 56


391it [00:23, 16.75it/s]


epoch 57


391it [00:23, 16.84it/s]


epoch 58


391it [00:23, 16.74it/s]


epoch 59


391it [00:23, 16.81it/s]


epoch 60


391it [00:23, 16.55it/s]


epoch 61


391it [00:23, 16.71it/s]


epoch 62


391it [00:23, 16.92it/s]


epoch 63


391it [00:23, 16.90it/s]



Test set: Avg. loss: 0.0083, Accuracy: 31525/50000 (63%)


Test set: Avg. loss: 0.0102, Accuracy: 5718/10000 (57%)

epoch 64


391it [00:22, 17.06it/s]


epoch 65


391it [00:22, 17.07it/s]


epoch 66


391it [00:22, 17.08it/s]


epoch 67


391it [00:22, 17.23it/s]


epoch 68


391it [00:23, 16.92it/s]


epoch 69


391it [00:23, 16.85it/s]


epoch 70


391it [00:23, 16.80it/s]


epoch 71


391it [00:23, 16.81it/s]



Test set: Avg. loss: 0.0074, Accuracy: 33736/50000 (67%)


Test set: Avg. loss: 0.0100, Accuracy: 5836/10000 (58%)

epoch 72


391it [00:23, 16.77it/s]


epoch 73


391it [00:23, 16.67it/s]


epoch 74


391it [00:23, 16.60it/s]


epoch 75


391it [00:23, 16.69it/s]


epoch 76


391it [00:23, 16.75it/s]


epoch 77


391it [00:23, 16.70it/s]


epoch 78


391it [00:23, 16.65it/s]


epoch 79


391it [00:23, 16.66it/s]



Test set: Avg. loss: 0.0067, Accuracy: 35453/50000 (71%)


Test set: Avg. loss: 0.0098, Accuracy: 6005/10000 (60%)

epoch 80


391it [00:23, 16.74it/s]


epoch 81


391it [00:23, 16.81it/s]


epoch 82


391it [00:22, 17.06it/s]


epoch 83


391it [00:22, 17.03it/s]


epoch 84


391it [00:23, 16.95it/s]


epoch 85


391it [00:23, 16.80it/s]


epoch 86


391it [00:23, 16.85it/s]


epoch 87


391it [00:23, 16.84it/s]



Test set: Avg. loss: 0.0060, Accuracy: 37046/50000 (74%)


Test set: Avg. loss: 0.0100, Accuracy: 6115/10000 (61%)

epoch 88


391it [00:23, 16.79it/s]


epoch 89


391it [00:23, 16.71it/s]


epoch 90


391it [00:23, 16.90it/s]


epoch 91


391it [00:23, 16.93it/s]


epoch 92


391it [00:23, 16.79it/s]


epoch 93


391it [00:23, 16.86it/s]


epoch 94


391it [00:23, 16.97it/s]


epoch 95


391it [00:23, 16.82it/s]



Test set: Avg. loss: 0.0055, Accuracy: 38199/50000 (76%)


Test set: Avg. loss: 0.0108, Accuracy: 5971/10000 (60%)

epoch 96


154it [00:09, 17.38it/s]

In [None]:
from shearletNN.complex_deit import Attention, vit_models, Block, FreqEmbed, complex_freakformer_small_patch2_LS
from shearletNN.layers import CReLU, ComplexLayerNorm
from functools import partial

def repeat3(x):
    return x.repeat(3, 1, 1)[:3]

train_transform = v2.Compose([
    transforms.RandomResizedCrop((image_size, image_size), scale=(0.5, 1.0)),
    # transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    repeat3,
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406], 
        std=[0.229, 0.224, 0.225]
    )
])

val_transform = v2.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    repeat3,
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406], 
        std=[0.229, 0.224, 0.225]
    )
])

ds_train = torchvision.datasets.CIFAR10('../', transform=train_transform, download = True, train=True)
ds_train = IndexSubsetDataset(ds_train, sum([list(range(len(ds_train)))[i::5] for i in range(1, 5)], []))

ds_val = torchvision.datasets.CIFAR10('../', transform=val_transform, download = True, train=False)
ds_val = IndexSubsetDataset(ds_val, list(range(len(ds_val)))[0::5])

train_loader = torch.utils.data.DataLoader(
  ds_train,
  batch_size=batch_size_train, shuffle=True, num_workers=0)


val_loader = torch.utils.data.DataLoader(
  ds_val,
  batch_size=batch_size_train, shuffle=False)


model = torchvision.models.resnet18()

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
print('training model...')
for epoch in range(128):
    print('epoch', epoch)
    train(model.to(0), optimizer, train_loader, accumulate=1)
    gc.collect()

    test(model, train_loader)
    test(model, val_loader)

In [None]:
print('training model...')
for epoch in range(256):
    print('epoch', epoch)
    train(model.to(0), optimizer, train_loader, accumulate=1)
    gc.collect()
    if epoch % 16 == 15:
        test(model, train_loader)
        test(model, val_loader)

In [6]:
from shearletNN.complex_deit import vit_models, LPatchEmbed, Attention, Block, ComplexLayerNorm
from shearletNN.layers import CReLU, ComplexLayerNorm
from functools import partial

patch_size = 64
image_size = 128

def repeat3(x):
    return x.repeat(3, 1, 1)[:3]

train_transform = v2.Compose([
    transforms.RandomResizedCrop((image_size, image_size), scale=(0.5, 1.0)),
    # transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    repeat3,
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406], 
        std=[0.229, 0.224, 0.225]
    )
])

val_transform = v2.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    repeat3,
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406], 
        std=[0.229, 0.224, 0.225]
    )
])

ds_train = torchvision.datasets.Caltech101('./', transform=train_transform, download = True)
ds_train = IndexSubsetDataset(ds_train, sum([list(range(len(ds_train)))[i::5] for i in range(1, 5)], []))

ds_val = torchvision.datasets.Caltech101('./', transform=val_transform, download = True)
ds_val = IndexSubsetDataset(ds_val, list(range(len(ds_val)))[0::5])

train_loader = torch.utils.data.DataLoader(
  ds_train,
  batch_size=batch_size_train, shuffle=True, num_workers=0)

shearlets = shearlets[:3]

def shearlet_transform(img):
    return shifted_frequency_shearlet_transform(img.to(0), shearlets.to(0), image_size)

train_loader = ShearletTransformLoader(train_loader, shearlet_transform)

val_loader = torch.utils.data.DataLoader(
  ds_val,
  batch_size=batch_size_train, shuffle=False)

val_loader = ShearletTransformLoader(val_loader, shearlet_transform)

for x, y in tqdm(train_loader):
    assert list(x.shape) == [batch_size_train, shearlets.shape[0] * 3, image_size, image_size], x.shape
    break
print('building model...')

model = vit_models(
        in_chans=shearlets.shape[0] * 3,
        img_size=image_size,
        patch_size=32,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(ComplexLayerNorm, eps=1e-6),
        block_layers=Block,
        Attention_block=Attention,
        act_layer=CGELU,
        Patch_layer=LPatchEmbed,
    )


with torch.no_grad():
    for param in model.parameters():
        if not ((param == 0).all() or (param == 1).all()):
            param.data /= 2**(0.5)

model = spectral_normalize(model)


optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
print('training model...')
for epoch in range(128):
    print('epoch', epoch)
    train(model.to(0), optimizer, train_loader, accumulate=2)
    gc.collect()
    if epoch % 8 == 7:
        test(model, train_loader)
        test(model, val_loader)



Files already downloaded and verified
Files already downloaded and verified


0it [00:00, ?it/s]


building model...
training model...
epoch 0


55it [00:09,  5.82it/s]


epoch 1


55it [00:09,  5.74it/s]


epoch 2


55it [00:09,  5.59it/s]


epoch 3


55it [00:09,  5.75it/s]


epoch 4


55it [00:09,  5.92it/s]


epoch 5


55it [00:09,  6.05it/s]


epoch 6


55it [00:09,  6.00it/s]


epoch 7


55it [00:09,  5.77it/s]



Test set: Avg. loss: 25.8199, Accuracy: 1011/6941 (15%)


Test set: Avg. loss: 30.7178, Accuracy: 191/1736 (11%)

epoch 8


55it [00:09,  5.90it/s]


epoch 9


55it [00:09,  6.02it/s]


epoch 10


55it [00:09,  5.89it/s]


epoch 11


55it [00:09,  5.85it/s]


epoch 12


55it [00:09,  5.85it/s]


epoch 13


55it [00:09,  6.01it/s]


epoch 14


55it [00:09,  5.96it/s]


epoch 15


55it [00:09,  5.75it/s]



Test set: Avg. loss: 15.6713, Accuracy: 1021/6941 (15%)


Test set: Avg. loss: 18.0983, Accuracy: 214/1736 (12%)

epoch 16


55it [00:09,  5.92it/s]


epoch 17


55it [00:09,  6.08it/s]


epoch 18


55it [00:09,  5.97it/s]


epoch 19


55it [00:09,  5.92it/s]


epoch 20


55it [00:09,  5.94it/s]


epoch 21


55it [00:09,  5.96it/s]


epoch 22


55it [00:09,  5.95it/s]


epoch 23


55it [00:09,  5.98it/s]



Test set: Avg. loss: 27.5682, Accuracy: 910/6941 (13%)


Test set: Avg. loss: 33.7617, Accuracy: 162/1736 (9%)

epoch 24


55it [00:09,  6.05it/s]


epoch 25


55it [00:09,  6.04it/s]


epoch 26


55it [00:09,  5.98it/s]


epoch 27


55it [00:09,  5.96it/s]


epoch 28


55it [00:09,  5.83it/s]


epoch 29


55it [00:09,  5.80it/s]


epoch 30


55it [00:09,  5.71it/s]


epoch 31


55it [00:09,  6.00it/s]



Test set: Avg. loss: 8.6829, Accuracy: 1209/6941 (17%)


Test set: Avg. loss: 10.1437, Accuracy: 243/1736 (14%)

epoch 32


55it [00:09,  6.04it/s]


epoch 33


55it [00:09,  5.98it/s]


epoch 34


55it [00:09,  6.05it/s]


epoch 35


55it [00:09,  6.06it/s]


epoch 36


55it [00:09,  6.05it/s]


epoch 37


55it [00:09,  5.97it/s]


epoch 38


55it [00:09,  6.01it/s]


epoch 39


55it [00:09,  5.88it/s]



Test set: Avg. loss: 6.5396, Accuracy: 1338/6941 (19%)


Test set: Avg. loss: 8.0872, Accuracy: 249/1736 (14%)

epoch 40


55it [00:09,  5.83it/s]


epoch 41


55it [00:09,  5.58it/s]


epoch 42


55it [00:09,  5.98it/s]


epoch 43


55it [00:09,  6.05it/s]


epoch 44


55it [00:09,  6.11it/s]


epoch 45


55it [00:09,  6.00it/s]


epoch 46


55it [00:09,  6.03it/s]


epoch 47


55it [00:09,  5.79it/s]



Test set: Avg. loss: 5.8299, Accuracy: 1323/6941 (19%)


Test set: Avg. loss: 6.9055, Accuracy: 244/1736 (14%)

epoch 48


55it [00:09,  5.98it/s]


epoch 49


55it [00:09,  6.04it/s]


epoch 50


55it [00:09,  6.04it/s]


epoch 51


55it [00:09,  6.05it/s]


epoch 52


55it [00:08,  6.14it/s]


epoch 53


55it [00:08,  6.15it/s]


epoch 54


55it [00:08,  6.20it/s]


epoch 55


55it [00:08,  6.12it/s]



Test set: Avg. loss: 4.7187, Accuracy: 1374/6941 (20%)


Test set: Avg. loss: 5.7180, Accuracy: 252/1736 (15%)

epoch 56


55it [00:08,  6.19it/s]


epoch 57


55it [00:09,  6.07it/s]


epoch 58


55it [00:08,  6.18it/s]


epoch 59


55it [00:09,  6.02it/s]


epoch 60


55it [00:08,  6.15it/s]


epoch 61


55it [00:09,  6.08it/s]


epoch 62


55it [00:08,  6.13it/s]


epoch 63


55it [00:09,  5.86it/s]



Test set: Avg. loss: 3.9129, Accuracy: 1492/6941 (21%)


Test set: Avg. loss: 4.8920, Accuracy: 241/1736 (14%)

epoch 64


55it [00:08,  6.13it/s]


epoch 65


55it [00:09,  6.07it/s]


epoch 66


55it [00:09,  6.08it/s]


epoch 67


55it [00:09,  6.03it/s]


epoch 68


55it [00:08,  6.13it/s]


epoch 69


55it [00:09,  6.05it/s]


epoch 70


55it [00:09,  6.09it/s]


epoch 71


55it [00:09,  6.06it/s]



Test set: Avg. loss: 3.4768, Accuracy: 1550/6941 (22%)


Test set: Avg. loss: 4.5090, Accuracy: 236/1736 (14%)

epoch 72


55it [00:09,  6.04it/s]


epoch 73


55it [00:09,  5.97it/s]


epoch 74


55it [00:09,  5.97it/s]


epoch 75


55it [00:09,  6.00it/s]


epoch 76


55it [00:09,  6.05it/s]


epoch 77


55it [00:09,  5.96it/s]


epoch 78


55it [00:09,  5.98it/s]


epoch 79


55it [00:09,  6.03it/s]



Test set: Avg. loss: 2.9842, Accuracy: 1550/6941 (22%)


Test set: Avg. loss: 3.8391, Accuracy: 234/1736 (13%)

epoch 80


55it [00:09,  5.96it/s]


epoch 81


55it [00:09,  5.87it/s]


epoch 82


55it [00:09,  5.97it/s]


epoch 83


55it [00:09,  6.02it/s]


epoch 84


55it [00:09,  5.89it/s]


epoch 85


55it [00:09,  6.04it/s]


epoch 86


55it [00:09,  5.99it/s]


epoch 87


55it [00:09,  5.90it/s]



Test set: Avg. loss: 2.5861, Accuracy: 1679/6941 (24%)


Test set: Avg. loss: 3.4592, Accuracy: 247/1736 (14%)

epoch 88


55it [00:09,  6.00it/s]


epoch 89


55it [00:09,  6.01it/s]


epoch 90


55it [00:09,  5.97it/s]


epoch 91


55it [00:09,  5.98it/s]


epoch 92


55it [00:09,  5.97it/s]


epoch 93


55it [00:09,  6.04it/s]


epoch 94


55it [00:09,  5.93it/s]


epoch 95


55it [00:09,  6.03it/s]



Test set: Avg. loss: 2.3804, Accuracy: 1689/6941 (24%)


Test set: Avg. loss: 3.0743, Accuracy: 270/1736 (16%)

epoch 96


55it [00:09,  5.95it/s]


epoch 97


55it [00:09,  5.97it/s]


epoch 98


55it [00:09,  5.98it/s]


epoch 99


55it [00:09,  5.98it/s]


epoch 100


55it [00:09,  5.93it/s]


epoch 101


55it [00:09,  5.92it/s]


epoch 102


55it [00:09,  5.85it/s]


epoch 103


55it [00:09,  5.98it/s]



Test set: Avg. loss: 2.0931, Accuracy: 1710/6941 (25%)


Test set: Avg. loss: 2.7644, Accuracy: 299/1736 (17%)

epoch 104


55it [00:09,  6.00it/s]


epoch 105


55it [00:09,  6.04it/s]


epoch 106


55it [00:09,  5.83it/s]


epoch 107


55it [00:09,  6.04it/s]


epoch 108


55it [00:09,  6.03it/s]


epoch 109


55it [00:09,  6.07it/s]


epoch 110


55it [00:09,  6.00it/s]


epoch 111


55it [00:09,  5.99it/s]



Test set: Avg. loss: 2.0038, Accuracy: 1634/6941 (24%)


Test set: Avg. loss: 2.6723, Accuracy: 245/1736 (14%)

epoch 112


55it [00:09,  5.99it/s]


epoch 113


55it [00:09,  5.95it/s]


epoch 114


55it [00:09,  6.01it/s]


epoch 115


55it [00:09,  5.94it/s]


epoch 116


55it [00:09,  5.96it/s]


epoch 117


55it [00:09,  5.89it/s]


epoch 118


55it [00:09,  6.05it/s]


epoch 119


55it [00:09,  5.96it/s]



Test set: Avg. loss: 1.7236, Accuracy: 1731/6941 (25%)


Test set: Avg. loss: 2.3976, Accuracy: 319/1736 (18%)

epoch 120


55it [00:09,  6.03it/s]


epoch 121


55it [00:09,  5.96it/s]


epoch 122


55it [00:09,  6.02it/s]


epoch 123


55it [00:09,  5.95it/s]


epoch 124


55it [00:09,  5.99it/s]


epoch 125


55it [00:09,  5.95it/s]


epoch 126


55it [00:09,  5.96it/s]


epoch 127


55it [00:09,  5.89it/s]



Test set: Avg. loss: 1.5677, Accuracy: 1779/6941 (26%)


Test set: Avg. loss: 2.1764, Accuracy: 290/1736 (17%)



In [9]:
for epoch in range(256):
    print('epoch', epoch)
    train(model.to(0), optimizer, train_loader, accumulate=2)
    gc.collect()
    if epoch % 8 == 7:
        test(model, train_loader)
        test(model, val_loader)

epoch 0


55it [00:09,  5.90it/s]


epoch 1


55it [00:09,  6.06it/s]


epoch 2


55it [00:08,  6.13it/s]


epoch 3


55it [00:08,  6.14it/s]


epoch 4


55it [00:08,  6.12it/s]


epoch 5


55it [00:09,  6.10it/s]


epoch 6


55it [00:08,  6.16it/s]


epoch 7


55it [00:08,  6.16it/s]



Test set: Avg. loss: 0.0306, Accuracy: 2123/6941 (31%)


Test set: Avg. loss: 0.0416, Accuracy: 345/1736 (20%)

epoch 8


55it [00:08,  6.12it/s]


epoch 9


55it [00:09,  6.00it/s]


epoch 10


55it [00:09,  5.83it/s]


epoch 11


55it [00:09,  5.81it/s]


epoch 12


55it [00:09,  5.83it/s]


epoch 13


55it [00:09,  5.83it/s]


epoch 14


55it [00:09,  5.93it/s]


epoch 15


55it [00:09,  5.84it/s]



Test set: Avg. loss: 0.0284, Accuracy: 2314/6941 (33%)


Test set: Avg. loss: 0.0371, Accuracy: 403/1736 (23%)

epoch 16


55it [00:09,  5.68it/s]


epoch 17


55it [00:09,  5.81it/s]


epoch 18


55it [00:09,  5.83it/s]


epoch 19


55it [00:09,  5.95it/s]


epoch 20


55it [00:09,  5.82it/s]


epoch 21


55it [00:09,  6.03it/s]


epoch 22


55it [00:09,  6.05it/s]


epoch 23


55it [00:09,  6.01it/s]



Test set: Avg. loss: 0.0266, Accuracy: 2388/6941 (34%)


Test set: Avg. loss: 0.0349, Accuracy: 434/1736 (25%)

epoch 24


55it [00:09,  6.07it/s]


epoch 25


55it [00:09,  6.04it/s]


epoch 26


55it [00:09,  5.98it/s]


epoch 27


55it [00:09,  6.05it/s]


epoch 28


55it [00:09,  6.01it/s]


epoch 29


55it [00:09,  6.03it/s]


epoch 30


55it [00:09,  5.80it/s]


epoch 31


55it [00:09,  5.91it/s]



Test set: Avg. loss: 0.0257, Accuracy: 2417/6941 (35%)


Test set: Avg. loss: 0.0322, Accuracy: 461/1736 (27%)

epoch 32


55it [00:09,  5.87it/s]


epoch 33


55it [00:09,  5.69it/s]


epoch 34


55it [00:09,  5.84it/s]


epoch 35


55it [00:09,  5.93it/s]


epoch 36


55it [00:09,  5.94it/s]


epoch 37


55it [00:09,  5.84it/s]


epoch 38


55it [00:09,  5.76it/s]


epoch 39


55it [00:09,  5.76it/s]



Test set: Avg. loss: 0.0247, Accuracy: 2513/6941 (36%)


Test set: Avg. loss: 0.0326, Accuracy: 438/1736 (25%)

epoch 40


55it [00:10,  5.39it/s]


epoch 41


55it [00:09,  5.72it/s]


epoch 42


55it [00:09,  5.77it/s]


epoch 43


55it [00:09,  5.59it/s]


epoch 44


55it [00:09,  5.67it/s]


epoch 45


55it [00:09,  5.85it/s]


epoch 46


55it [00:09,  5.72it/s]


epoch 47


55it [00:09,  5.94it/s]



Test set: Avg. loss: 0.0239, Accuracy: 2582/6941 (37%)


Test set: Avg. loss: 0.0307, Accuracy: 481/1736 (28%)

epoch 48


55it [00:09,  5.79it/s]


epoch 49


55it [00:09,  5.98it/s]


epoch 50


55it [00:09,  5.91it/s]


epoch 51


55it [00:09,  5.99it/s]


epoch 52


55it [00:09,  5.83it/s]


epoch 53


55it [00:09,  5.93it/s]


epoch 54


55it [00:09,  6.09it/s]


epoch 55


55it [00:09,  5.94it/s]



Test set: Avg. loss: 0.0243, Accuracy: 2591/6941 (37%)


Test set: Avg. loss: 0.0321, Accuracy: 472/1736 (27%)

epoch 56


55it [00:09,  5.85it/s]


epoch 57


55it [00:09,  5.97it/s]


epoch 58


55it [00:09,  6.03it/s]


epoch 59


55it [00:09,  6.08it/s]


epoch 60


55it [00:09,  6.00it/s]


epoch 61


55it [00:09,  5.85it/s]


epoch 62


55it [00:09,  6.02it/s]


epoch 63


55it [00:09,  6.05it/s]



Test set: Avg. loss: 0.0234, Accuracy: 2697/6941 (39%)


Test set: Avg. loss: 0.0316, Accuracy: 477/1736 (27%)

epoch 64


55it [00:09,  6.06it/s]


epoch 65


55it [00:09,  5.80it/s]


epoch 66


55it [00:09,  5.68it/s]


epoch 67


55it [00:09,  5.70it/s]


epoch 68


55it [00:09,  5.86it/s]


epoch 69


55it [00:09,  5.89it/s]


epoch 70


55it [00:09,  5.84it/s]


epoch 71


55it [00:09,  5.75it/s]



Test set: Avg. loss: 0.0236, Accuracy: 2644/6941 (38%)


Test set: Avg. loss: 0.0308, Accuracy: 483/1736 (28%)

epoch 72


55it [00:09,  5.95it/s]


epoch 73


55it [00:09,  5.87it/s]


epoch 74


55it [00:09,  5.89it/s]


epoch 75


55it [00:09,  5.67it/s]


epoch 76


55it [00:09,  6.01it/s]


epoch 77


55it [00:09,  6.01it/s]


epoch 78


55it [00:09,  5.92it/s]


epoch 79


55it [00:09,  5.95it/s]



Test set: Avg. loss: 0.0231, Accuracy: 2764/6941 (40%)


Test set: Avg. loss: 0.0303, Accuracy: 487/1736 (28%)

epoch 80


55it [00:09,  6.03it/s]


epoch 81


55it [00:09,  6.05it/s]


epoch 82


55it [00:09,  5.75it/s]


epoch 83


55it [00:09,  5.70it/s]


epoch 84


55it [00:09,  5.79it/s]


epoch 85


55it [00:09,  5.82it/s]


epoch 86


55it [00:09,  5.73it/s]


epoch 87


55it [00:09,  5.91it/s]



Test set: Avg. loss: 0.0223, Accuracy: 2815/6941 (41%)


Test set: Avg. loss: 0.0293, Accuracy: 537/1736 (31%)

epoch 88


55it [00:09,  5.62it/s]


epoch 89


55it [00:09,  5.94it/s]


epoch 90


55it [00:09,  5.78it/s]


epoch 91


55it [00:09,  5.81it/s]


epoch 92


55it [00:09,  5.83it/s]


epoch 93


55it [00:09,  5.94it/s]


epoch 94


55it [00:09,  5.83it/s]


epoch 95


55it [00:09,  5.94it/s]



Test set: Avg. loss: 0.0223, Accuracy: 2812/6941 (41%)


Test set: Avg. loss: 0.0310, Accuracy: 492/1736 (28%)

epoch 96


55it [00:09,  5.95it/s]


epoch 97


55it [00:09,  6.03it/s]


epoch 98


55it [00:09,  6.00it/s]


epoch 99


55it [00:09,  5.95it/s]


epoch 100


55it [00:09,  6.03it/s]


epoch 101


55it [00:09,  5.99it/s]


epoch 102


55it [00:09,  6.03it/s]


epoch 103


55it [00:09,  6.06it/s]



Test set: Avg. loss: 0.0216, Accuracy: 2896/6941 (42%)


Test set: Avg. loss: 0.0302, Accuracy: 496/1736 (29%)

epoch 104


55it [00:08,  6.15it/s]


epoch 105


47it [00:07,  5.95it/s]


KeyboardInterrupt: 



Files already downloaded and verified
Files already downloaded and verified


0it [00:00, ?it/s]


building model...
training model...
epoch 0


0it [00:00, ?it/s]


RuntimeError: The size of tensor a (4) must match the size of tensor b (16) at non-singleton dimension 1

In [None]:
# amplitude:
amp = torch.abs(x)
print((amp.real.max(), amp.real.min()))

phase = torch.arctan(x.imag / x.real)

phase = torch.nan_to_num(torch.arctan(x.imag / x.real), posinf=torch.math.pi / 2, neginf= -torch.math.pi / 2)

print((phase.real.max(), phase.real.min()))

def to_magnitude_phase(x):
    """
    return magnitude/phase representation 
    """
    return torch.complex(torch.abs(x), torch.nan_to_num(torch.arctan(x.imag / x.real), posinf=torch.math.pi / 2, neginf= -torch.math.pi / 2))

(tensor(5.4802, device='cuda:0'), tensor(0., device='cuda:0'))
(tensor(1.5708, device='cuda:0'), tensor(-1.5708, device='cuda:0'))


In [None]:
x = torch.tensor([[1.5,.0,.0,.0]])
layerNorm = torch.nn.LayerNorm(4, elementwise_affine = False)
y1 = layerNorm(x)
mean = x.mean(-1, keepdim = True)
var = x.var(-1, keepdim = True, unbiased=False)
y2 = (x-mean)/torch.sqrt(var+layerNorm.eps)

torch.allclose(y1, y2)

True

In [None]:
x = torch.randn((50,20,100))

layerNorm = torch.nn.LayerNorm(x.shape[-1], elementwise_affine = True)
y1 = layerNorm(x)

mean = torch.mean(x, dim=-1, keepdim=True)
var = torch.square(x - mean).mean(dim=-1, keepdim=True)
y2 = ((x - mean) / torch.sqrt(var + layerNorm.eps)) * layerNorm.weight + layerNorm.bias

torch.allclose(y1, y2, atol=1e-5, rtol=1e-5)

True