In [1]:
from shearletNN.shearlets import getcomplexshearlets2D
from shearletNN.shearlet_utils import frequency_shearlet_transform, spatial_shearlet_transform, ShearletTransformLoader
from shearletNN.complex_resnet import complex_resnet18, complex_resnet34, complex_resnet50
from shearletNN.layers import CGELU, CReLU

import torch
import torchvision
from torchvision.transforms import v2
from torchvision import transforms

import gc


patch_size = 96
image_size = 128

rows, cols = image_size, image_size


shearlets, shearletIdxs, RMS, dualFrameWeights = getcomplexshearlets2D(	rows, 
                                                                        cols, 
                                                                        1, 
                                                                        3, 
                                                                        1, 
                                                                        0.5,
                                                                        wavelet_eff_support = image_size,
                                                                        gaussian_eff_support = image_size
                                                                        )

shearlets = torch.tensor(shearlets).permute(2, 0, 1).type(torch.complex64).to(0)

In [2]:
class Unraveling:
    def __init__(self, n):
        self.levels = []
        for i in range(0, n // 2):
            level = []
            for j in range(i, n - i):
                level.append((j, i))
                level.append((i, j))

                level.append((j, n - (i + 1)))
                level.append((n - (i + 1), j))

            level = list(set(level))
            self.levels.append((torch.tensor([x for x, _ in level]), torch.tensor([y for _, y in level])))
    
    def __call__(self, x):
        return [x[..., a, b] for a, b in self.levels]

# this is still technically resolution-independent because that property comes from the shearlet transform crop
# in this case we do not actually need variable number of input tokens because we are always extracting a finite amount of information...
# if the whole premise is that we can utilize even larger images then why does it matter that we are using a variable size input?
class Freakformer(torch.nn.Module):
    def __init__(self, n, chans, embed_dim = 384):
        # first we need to build the tokenizer using the Unraveling transformation and some linear layers
        # this is equivalent to the patching layer
        self.unravel = Unraveling(n)
        self.layers = torch.nn.ModuleList([torch.nn.Linear(len(level), embed_dim) for level in self.unravel.layers])
        pass

    def forward(self, x):
        unraveled = self.unravel(x)
        # (B, C, 4 ... 8(n // 2) + 4)
        tokens = torch.stack([layer(level.flatten(1)) for layer, level in zip(self.layers, unraveled)], -2)
        # (B, n // 2, embed_dim)
        pass

In [9]:
from tqdm import tqdm

def train(model, optimizer, loader, accumulate=1):
    model.train()
    loss = torch.nn.CrossEntropyLoss()

    for i, (X, y) in tqdm(enumerate(loader)):
        out = model(X.to(0))
        optimizer.zero_grad()
        l = loss(out, y.to(0)) / accumulate
        l.backward()
        if i % accumulate == (accumulate - 1):
            optimizer.step()
        

def accuracy(output, target, topk=(1,)):
    output = output.to(torch.device('cpu'))
    target = target.to(torch.device('cpu'))
    maxk = max(topk)
    batch_size = target.shape[0]

    _, idx = output.sort(dim=1, descending=True)
    pred = idx.narrow(1, 0, maxk).t()
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(dim=0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


def epoch_accuracy(loader_s, student):
    student.eval()

    out_epoch_s = [accuracy(student(L.to(0)), y)[0].detach().cpu().item() for L, y in loader_s]

    student.train()

    return sum(out_epoch_s) / len(out_epoch_s)

def test(network, test_loader):
    network.eval().to(0)
    test_loss = 0
    correct = 0
    total = 0
    test_losses=[]
    with torch.no_grad():
        for data, target in test_loader:
            output = network(data.to(0))
            test_loss += torch.nn.CrossEntropyLoss()(output, target.to(0)).item()
            pred = output.data.max(1, keepdim=True)[1].cpu()
            correct += pred.eq(target.data.view_as(pred)).sum()
            total += target.shape[0]
        test_loss /= total
        test_losses.append(test_loss)
        print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, total,
        100. * correct / total))

class IndexSubsetDataset:
    def __init__(self, ds, inds):
        self.ds = ds
        self.inds = inds

    def __iter__(self):
        for i in range(len(self.inds)):
            yield self[i]
    
    def __getitem__(self, i):
        return self.ds[self.inds[i]]
    
    def __len__(self):
        return len(self.inds)
    
    
class FreqModel(torch.nn.Module):
    def __init__(
        self,
        img_size = 224,
        in_chans = 3,
        embed_dim = 384,
    ):
        super().__init__()
        self.img_size = (img_size, img_size)
        self.grid_size = (img_size // 2, 1)
        self.num_patches = img_size // 2

        self.unravel = Unraveling(img_size)
        self.layers = torch.nn.ModuleList([torch.nn.Linear(len(level[0]) * in_chans, embed_dim, dtype=torch.complex64) for level in self.unravel.levels])

    def forward(self, x):

        unraveled = self.unravel(x)
        print([(level.flatten(1).shape, layer.weight.shape) for layer, level in zip(self.layers, unraveled)])
        tokens = torch.cat([layer(level.flatten(1)) for layer, level in zip(self.layers, unraveled)], -2)
        
        return x
    

def linearleaves(module):
    # returns a list of pairs of (parent, submodule_name) pairs for all submodule leaves of the current module
    if isinstance(module, torch.nn.Linear):
        return [(module, None)]

    linear_children = []
    for name, mod in module.named_modules():
        if isinstance(mod, torch.nn.Linear):
            linear_children.append((name, module))
    return linear_children

In [4]:
batch_size_train = 256

In [10]:
from shearletNN.complex_deit import Attention, vit_models, Block, FreqEmbed
from shearletNN.layers import CReLU, ComplexLayerNorm
from functools import partial

def repeat3(x):
    return x.repeat(3, 1, 1)[:3]

train_transform = v2.Compose([
    transforms.RandomResizedCrop((image_size, image_size), scale=(0.5, 1.0)),
    # transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    repeat3,
])

val_transform = v2.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    repeat3,
])

ds_train = torchvision.datasets.Caltech101('./', transform=train_transform, download = True)
ds_train = IndexSubsetDataset(ds_train, sum([list(range(len(ds_train)))[i::5] for i in range(1, 5)], []))

ds_val = torchvision.datasets.Caltech101('./', transform=val_transform, download = True)
ds_val = IndexSubsetDataset(ds_val, list(range(len(ds_val)))[0::5])

train_loader = torch.utils.data.DataLoader(
  ds_train,
  batch_size=batch_size_train, shuffle=True, num_workers=0)

shearlets = shearlets[:3]

def shearlet_transform(img):
    return frequency_shearlet_transform(img, shearlets, patch_size)

train_loader = ShearletTransformLoader(train_loader, shearlet_transform)

val_loader = torch.utils.data.DataLoader(
  ds_val,
  batch_size=batch_size_train, shuffle=False)

val_loader = ShearletTransformLoader(val_loader, shearlet_transform)

for x, y in tqdm(train_loader):
    assert list(x.shape) == [batch_size_train, shearlets.shape[0] * 3, patch_size, patch_size], x.shape
    break
print('building model...')

model = vit_models(
    img_size=patch_size,
    in_chans=shearlets.shape[0] * 3,
    patch_size=16,
    embed_dim=384,
    depth=12,
    num_heads=6,
    mlp_ratio=4,
    qkv_bias=True,
    norm_layer=partial(ComplexLayerNorm, eps=1e-6),
    block_layers=Block,
    Attention_block=Attention,
    act_layer=CGELU,
    Patch_layer=FreqEmbed,
)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
print('training model...')
for epoch in range(128):
    print('epoch', epoch)
    train(model.to(0), optimizer, train_loader, accumulate=1)
    gc.collect()
    if epoch % 16 == 15:
        test(model, train_loader)
        test(model, val_loader)

Files already downloaded and verified
Files already downloaded and verified


0it [00:00, ?it/s]


building model...


AttributeError: 'vit_models' object has no attribute 'weight'

In [18]:
print('training model...')
for epoch in range(256):
    print('epoch', epoch)
    train(model.to(0), optimizer, train_loader, accumulate=1)
    gc.collect()
    if epoch % 16 == 15:
        test(model, train_loader)
        test(model, val_loader)

training model...
epoch 0


28it [00:08,  3.38it/s]


epoch 1


28it [00:08,  3.43it/s]


epoch 2


28it [00:07,  3.52it/s]


epoch 3


28it [00:08,  3.48it/s]


epoch 4


28it [00:07,  3.52it/s]


epoch 5


28it [00:07,  3.54it/s]


epoch 6


28it [00:08,  3.50it/s]


epoch 7


28it [00:07,  3.54it/s]


epoch 8


28it [00:07,  3.52it/s]


epoch 9


28it [00:07,  3.51it/s]


epoch 10


28it [00:07,  3.51it/s]


epoch 11


28it [00:07,  3.50it/s]


epoch 12


28it [00:07,  3.53it/s]


epoch 13


28it [00:07,  3.50it/s]


epoch 14


28it [00:07,  3.53it/s]


epoch 15


28it [00:07,  3.53it/s]



Test set: Avg. loss: 0.0018, Accuracy: 6090/6941 (88%)


Test set: Avg. loss: 0.0162, Accuracy: 669/1736 (39%)

epoch 16


28it [00:07,  3.51it/s]


epoch 17


28it [00:07,  3.53it/s]


epoch 18


28it [00:07,  3.55it/s]


epoch 19


28it [00:08,  3.50it/s]


epoch 20


28it [00:07,  3.50it/s]


epoch 21


28it [00:08,  3.49it/s]


epoch 22


28it [00:08,  3.49it/s]


epoch 23


28it [00:07,  3.52it/s]


epoch 24


28it [00:07,  3.51it/s]


epoch 25


28it [00:07,  3.53it/s]


epoch 26


28it [00:07,  3.51it/s]


epoch 27


28it [00:07,  3.51it/s]


epoch 28


28it [00:08,  3.50it/s]


epoch 29


28it [00:07,  3.52it/s]


epoch 30


28it [00:07,  3.51it/s]


epoch 31


28it [00:07,  3.54it/s]



Test set: Avg. loss: 0.0017, Accuracy: 6160/6941 (89%)


Test set: Avg. loss: 0.0167, Accuracy: 644/1736 (37%)

epoch 32


28it [00:08,  3.48it/s]


epoch 33


28it [00:08,  3.49it/s]


epoch 34


28it [00:08,  3.44it/s]


epoch 35


28it [00:08,  3.47it/s]


epoch 36


28it [00:08,  3.49it/s]


epoch 37


28it [00:08,  3.46it/s]


epoch 38


28it [00:08,  3.49it/s]


epoch 39


28it [00:07,  3.54it/s]


epoch 40


28it [00:07,  3.52it/s]


epoch 41


28it [00:08,  3.47it/s]


epoch 42


28it [00:07,  3.53it/s]


epoch 43


28it [00:07,  3.52it/s]


epoch 44


28it [00:07,  3.52it/s]


epoch 45


28it [00:07,  3.53it/s]


epoch 46


28it [00:07,  3.53it/s]


epoch 47


28it [00:07,  3.53it/s]



Test set: Avg. loss: 0.0016, Accuracy: 6161/6941 (89%)


Test set: Avg. loss: 0.0175, Accuracy: 601/1736 (35%)

epoch 48


28it [00:07,  3.53it/s]


epoch 49


28it [00:07,  3.51it/s]


epoch 50


28it [00:07,  3.50it/s]


epoch 51


28it [00:07,  3.51it/s]


epoch 52


28it [00:08,  3.50it/s]


epoch 53


28it [00:07,  3.51it/s]


epoch 54


28it [00:08,  3.47it/s]


epoch 55


28it [00:07,  3.52it/s]


epoch 56


28it [00:07,  3.54it/s]


epoch 57


28it [00:07,  3.55it/s]


epoch 58


28it [00:07,  3.51it/s]


epoch 59


28it [00:07,  3.51it/s]


epoch 60


28it [00:07,  3.53it/s]


epoch 61


28it [00:07,  3.50it/s]


epoch 62


28it [00:07,  3.50it/s]


epoch 63


28it [00:08,  3.46it/s]



Test set: Avg. loss: 0.0015, Accuracy: 6251/6941 (90%)


Test set: Avg. loss: 0.0171, Accuracy: 648/1736 (37%)

epoch 64


28it [00:08,  3.48it/s]


epoch 65


28it [00:07,  3.55it/s]


epoch 66


28it [00:07,  3.53it/s]


epoch 67


28it [00:07,  3.52it/s]


epoch 68


28it [00:08,  3.49it/s]


epoch 69


28it [00:07,  3.51it/s]


epoch 70


28it [00:07,  3.51it/s]


epoch 71


28it [00:07,  3.57it/s]


epoch 72


28it [00:07,  3.52it/s]


epoch 73


28it [00:07,  3.54it/s]


epoch 74


28it [00:07,  3.51it/s]


epoch 75


28it [00:07,  3.56it/s]


epoch 76


28it [00:07,  3.55it/s]


epoch 77


28it [00:08,  3.49it/s]


epoch 78


28it [00:08,  3.49it/s]


epoch 79


28it [00:07,  3.52it/s]



Test set: Avg. loss: 0.0013, Accuracy: 6317/6941 (91%)


Test set: Avg. loss: 0.0179, Accuracy: 650/1736 (37%)

epoch 80


28it [00:07,  3.51it/s]


epoch 81


28it [00:07,  3.52it/s]


epoch 82


28it [00:08,  3.49it/s]


epoch 83


28it [00:07,  3.51it/s]


epoch 84


28it [00:07,  3.52it/s]


epoch 85


28it [00:07,  3.53it/s]


epoch 86


28it [00:07,  3.53it/s]


epoch 87


28it [00:07,  3.52it/s]


epoch 88


28it [00:07,  3.53it/s]


epoch 89


28it [00:07,  3.52it/s]


epoch 90


28it [00:07,  3.54it/s]


epoch 91


28it [00:07,  3.52it/s]


epoch 92


28it [00:07,  3.51it/s]


epoch 93


28it [00:08,  3.48it/s]


epoch 94


28it [00:07,  3.50it/s]


epoch 95


28it [00:08,  3.49it/s]



Test set: Avg. loss: 0.0012, Accuracy: 6355/6941 (92%)


Test set: Avg. loss: 0.0181, Accuracy: 613/1736 (35%)

epoch 96


28it [00:07,  3.50it/s]


epoch 97


28it [00:07,  3.54it/s]


epoch 98


28it [00:08,  3.49it/s]


epoch 99


28it [00:07,  3.52it/s]


epoch 100


28it [00:07,  3.55it/s]


epoch 101


28it [00:07,  3.52it/s]


epoch 102


28it [00:08,  3.48it/s]


epoch 103


28it [00:07,  3.50it/s]


epoch 104


28it [00:07,  3.50it/s]


epoch 105


28it [00:07,  3.50it/s]


epoch 106


28it [00:07,  3.50it/s]


epoch 107


28it [00:07,  3.51it/s]


epoch 108


28it [00:08,  3.48it/s]


epoch 109


28it [00:07,  3.52it/s]


epoch 110


28it [00:07,  3.51it/s]


epoch 111


28it [00:07,  3.52it/s]



Test set: Avg. loss: 0.0012, Accuracy: 6329/6941 (91%)


Test set: Avg. loss: 0.0179, Accuracy: 644/1736 (37%)

epoch 112


28it [00:07,  3.54it/s]


epoch 113


28it [00:07,  3.52it/s]


epoch 114


28it [00:07,  3.51it/s]


epoch 115


28it [00:07,  3.55it/s]


epoch 116


28it [00:07,  3.53it/s]


epoch 117


28it [00:07,  3.52it/s]


epoch 118


28it [00:07,  3.55it/s]


epoch 119


28it [00:07,  3.54it/s]


epoch 120


28it [00:08,  3.50it/s]


epoch 121


28it [00:07,  3.54it/s]


epoch 122


28it [00:07,  3.53it/s]


epoch 123


28it [00:07,  3.53it/s]


epoch 124


28it [00:07,  3.53it/s]


epoch 125


28it [00:07,  3.53it/s]


epoch 126


28it [00:07,  3.52it/s]


epoch 127


28it [00:07,  3.52it/s]



Test set: Avg. loss: 0.0011, Accuracy: 6392/6941 (92%)


Test set: Avg. loss: 0.0182, Accuracy: 631/1736 (36%)

epoch 128


28it [00:08,  3.47it/s]


epoch 129


28it [00:08,  3.46it/s]


epoch 130


28it [00:08,  3.46it/s]


epoch 131


28it [00:08,  3.43it/s]


epoch 132


28it [00:08,  3.48it/s]


epoch 133


28it [00:08,  3.45it/s]


epoch 134


28it [00:08,  3.46it/s]


epoch 135


28it [00:08,  3.47it/s]


epoch 136


28it [00:08,  3.45it/s]


epoch 137


28it [00:08,  3.44it/s]


epoch 138


28it [00:08,  3.45it/s]


epoch 139


28it [00:08,  3.40it/s]


epoch 140


28it [00:08,  3.45it/s]


epoch 141


28it [00:08,  3.43it/s]


epoch 142


28it [00:08,  3.43it/s]


epoch 143


28it [00:08,  3.48it/s]



Test set: Avg. loss: 0.0011, Accuracy: 6384/6941 (92%)


Test set: Avg. loss: 0.0189, Accuracy: 587/1736 (34%)

epoch 144


28it [00:08,  3.49it/s]


epoch 145


28it [00:08,  3.48it/s]


epoch 146


28it [00:07,  3.55it/s]


epoch 147


28it [00:07,  3.51it/s]


epoch 148


28it [00:07,  3.57it/s]


epoch 149


28it [00:07,  3.50it/s]


epoch 150


28it [00:07,  3.56it/s]


epoch 151


28it [00:07,  3.54it/s]


epoch 152


28it [00:07,  3.50it/s]


epoch 153


28it [00:07,  3.53it/s]


epoch 154


28it [00:07,  3.51it/s]


epoch 155


28it [00:08,  3.50it/s]


epoch 156


28it [00:07,  3.52it/s]


epoch 157


28it [00:07,  3.51it/s]


epoch 158


28it [00:07,  3.53it/s]


epoch 159


28it [00:07,  3.53it/s]



Test set: Avg. loss: 0.0011, Accuracy: 6399/6941 (92%)


Test set: Avg. loss: 0.0190, Accuracy: 646/1736 (37%)

epoch 160


28it [00:07,  3.50it/s]


epoch 161


28it [00:08,  3.50it/s]


epoch 162


28it [00:08,  3.46it/s]


epoch 163


28it [00:07,  3.52it/s]


epoch 164


28it [00:07,  3.52it/s]


epoch 165


28it [00:07,  3.55it/s]


epoch 166


28it [00:07,  3.55it/s]


epoch 167


28it [00:07,  3.55it/s]


epoch 168


28it [00:07,  3.54it/s]


epoch 169


28it [00:07,  3.54it/s]


epoch 170


28it [00:07,  3.50it/s]


epoch 171


28it [00:07,  3.52it/s]


epoch 172


28it [00:08,  3.48it/s]


epoch 173


28it [00:07,  3.54it/s]


epoch 174


28it [00:07,  3.52it/s]


epoch 175


28it [00:07,  3.55it/s]



Test set: Avg. loss: 0.0009, Accuracy: 6500/6941 (94%)


Test set: Avg. loss: 0.0183, Accuracy: 635/1736 (37%)

epoch 176


28it [00:07,  3.51it/s]


epoch 177


28it [00:07,  3.54it/s]


epoch 178


28it [00:07,  3.53it/s]


epoch 179


28it [00:07,  3.53it/s]


epoch 180


28it [00:07,  3.52it/s]


epoch 181


28it [00:07,  3.51it/s]


epoch 182


28it [00:07,  3.51it/s]


epoch 183


28it [00:07,  3.52it/s]


epoch 184


28it [00:07,  3.51it/s]


epoch 185


28it [00:07,  3.57it/s]


epoch 186


28it [00:07,  3.55it/s]


epoch 187


28it [00:07,  3.54it/s]


epoch 188


28it [00:07,  3.55it/s]


epoch 189


28it [00:07,  3.54it/s]


epoch 190


28it [00:07,  3.58it/s]


epoch 191


28it [00:07,  3.56it/s]



Test set: Avg. loss: 0.0009, Accuracy: 6486/6941 (93%)


Test set: Avg. loss: 0.0185, Accuracy: 614/1736 (35%)

epoch 192


28it [00:07,  3.55it/s]


epoch 193


28it [00:07,  3.53it/s]


epoch 194


28it [00:07,  3.53it/s]


epoch 195


28it [00:07,  3.53it/s]


epoch 196


28it [00:07,  3.54it/s]


epoch 197


28it [00:07,  3.52it/s]


epoch 198


28it [00:07,  3.52it/s]


epoch 199


28it [00:07,  3.53it/s]


epoch 200


28it [00:07,  3.51it/s]


epoch 201


28it [00:07,  3.53it/s]


epoch 202


28it [00:07,  3.52it/s]


epoch 203


28it [00:07,  3.50it/s]


epoch 204


28it [00:07,  3.54it/s]


epoch 205


28it [00:07,  3.55it/s]


epoch 206


28it [00:07,  3.58it/s]


epoch 207


28it [00:07,  3.58it/s]



Test set: Avg. loss: 0.0009, Accuracy: 6518/6941 (94%)


Test set: Avg. loss: 0.0185, Accuracy: 630/1736 (36%)

epoch 208


28it [00:07,  3.53it/s]


epoch 209


28it [00:07,  3.52it/s]


epoch 210


28it [00:07,  3.54it/s]


epoch 211


28it [00:07,  3.55it/s]


epoch 212


28it [00:08,  3.48it/s]


epoch 213


28it [00:08,  3.45it/s]


epoch 214


28it [00:07,  3.54it/s]


epoch 215


28it [00:07,  3.53it/s]


epoch 216


28it [00:07,  3.53it/s]


epoch 217


28it [00:07,  3.51it/s]


epoch 218


28it [00:07,  3.52it/s]


epoch 219


28it [00:07,  3.53it/s]


epoch 220


28it [00:07,  3.51it/s]


epoch 221


28it [00:07,  3.52it/s]


epoch 222


28it [00:07,  3.55it/s]


epoch 223


28it [00:07,  3.52it/s]



Test set: Avg. loss: 0.0009, Accuracy: 6511/6941 (94%)


Test set: Avg. loss: 0.0189, Accuracy: 623/1736 (36%)

epoch 224


28it [00:08,  3.48it/s]


epoch 225


28it [00:08,  3.46it/s]


epoch 226


28it [00:07,  3.54it/s]


epoch 227


28it [00:08,  3.47it/s]


epoch 228


28it [00:07,  3.51it/s]


epoch 229


28it [00:08,  3.42it/s]


epoch 230


28it [00:07,  3.53it/s]


epoch 231


28it [00:07,  3.54it/s]


epoch 232


28it [00:07,  3.52it/s]


epoch 233


28it [00:07,  3.52it/s]


epoch 234


28it [00:07,  3.53it/s]


epoch 235


28it [00:07,  3.52it/s]


epoch 236


28it [00:08,  3.50it/s]


epoch 237


28it [00:07,  3.54it/s]


epoch 238


28it [00:08,  3.48it/s]


epoch 239


28it [00:08,  3.49it/s]



Test set: Avg. loss: 0.0007, Accuracy: 6583/6941 (95%)


Test set: Avg. loss: 0.0192, Accuracy: 636/1736 (37%)

epoch 240


28it [00:08,  3.48it/s]


epoch 241


28it [00:07,  3.51it/s]


epoch 242


28it [00:07,  3.51it/s]


epoch 243


28it [00:08,  3.49it/s]


epoch 244


28it [00:08,  3.48it/s]


epoch 245


28it [00:07,  3.55it/s]


epoch 246


28it [00:08,  3.45it/s]


epoch 247


28it [00:07,  3.53it/s]


epoch 248


28it [00:07,  3.50it/s]


epoch 249


28it [00:07,  3.52it/s]


epoch 250


28it [00:07,  3.51it/s]


epoch 251


28it [00:07,  3.54it/s]


epoch 252


28it [00:07,  3.53it/s]


epoch 253


28it [00:07,  3.55it/s]


epoch 254


28it [00:08,  3.45it/s]


epoch 255


28it [00:08,  3.39it/s]



Test set: Avg. loss: 0.0008, Accuracy: 6532/6941 (94%)


Test set: Avg. loss: 0.0187, Accuracy: 642/1736 (37%)



In [26]:
from shearletNN.complex_deit import Attention, vit_models, Block
from shearletNN.layers import CReLU, ComplexLayerNorm, ComplexConcat, ComplexFlatten
from functools import partial


def repeat3(x):
    return x.repeat(3, 1, 1)[:3]

train_transform = v2.Compose([
    transforms.RandomResizedCrop((image_size, image_size), scale=(0.5, 1.0)),
    # transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    repeat3,
])

val_transform = v2.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    repeat3,
])

ds_train = torchvision.datasets.Caltech101('./', transform=train_transform, download = True)
ds_train = IndexSubsetDataset(ds_train, sum([list(range(len(ds_train)))[i::5] for i in range(1, 5)], []))

ds_val = torchvision.datasets.Caltech101('./', transform=val_transform, download = True)
ds_val = IndexSubsetDataset(ds_val, list(range(len(ds_val)))[0::5])

train_loader = torch.utils.data.DataLoader(
  ds_train,
  batch_size=batch_size_train, shuffle=True, num_workers=0)

shearlets = shearlets[:3]

def shearlet_transform(img):
    return frequency_shearlet_transform(img, shearlets, patch_size)

train_loader = ShearletTransformLoader(train_loader, shearlet_transform)

val_loader = torch.utils.data.DataLoader(
  ds_val,
  batch_size=batch_size_train, shuffle=False)

val_loader = ShearletTransformLoader(val_loader, shearlet_transform)

for x, y in tqdm(train_loader):
    assert list(x.shape) == [batch_size_train, shearlets.shape[0] * 3, patch_size, patch_size], x.shape
    break
print('building model...')

model = torch.nn.Sequential(torch.nn.Flatten(), 
                            torch.nn.Linear(shearlets.shape[0] * 3 * patch_size * patch_size, 2048, dtype=torch.complex64), 
                            CReLU(), 
                            torch.nn.Linear(2048, 1024, dtype=torch.complex64), 
                            CReLU(), 
                            ComplexConcat(), 
                            torch.nn.Linear(2048, 2048),
                            torch.nn.Linear(2048, 1024),
                            torch.nn.Linear(1024, 101))

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
print('training model...')
for epoch in range(256):
    print('epoch', epoch)
    train(model.to(0), optimizer, train_loader, accumulate=1)
    gc.collect()
    if epoch % 16 == 15:
        test(model, train_loader)
        test(model, val_loader)

Files already downloaded and verified
Files already downloaded and verified


0it [00:00, ?it/s]


building model...
training model...
epoch 0


28it [00:05,  5.38it/s]


epoch 1


28it [00:05,  5.35it/s]


epoch 2


28it [00:05,  5.37it/s]


epoch 3


28it [00:05,  5.39it/s]


epoch 4


28it [00:05,  5.32it/s]


epoch 5


28it [00:05,  5.40it/s]


epoch 6


28it [00:05,  5.39it/s]


epoch 7


28it [00:05,  5.39it/s]


epoch 8


28it [00:05,  5.37it/s]


epoch 9


28it [00:05,  5.39it/s]


epoch 10


28it [00:05,  5.40it/s]


epoch 11


28it [00:05,  5.44it/s]


epoch 12


28it [00:05,  5.37it/s]


epoch 13


28it [00:05,  5.41it/s]


epoch 14


28it [00:05,  5.36it/s]


epoch 15


28it [00:05,  5.38it/s]



Test set: Avg. loss: 0.0139, Accuracy: 1772/6941 (26%)


Test set: Avg. loss: 0.0159, Accuracy: 342/1736 (20%)

epoch 16


28it [00:05,  5.40it/s]


epoch 17


28it [00:05,  5.43it/s]


epoch 18


28it [00:05,  5.28it/s]


epoch 19


28it [00:05,  5.39it/s]


epoch 20


28it [00:05,  5.31it/s]


epoch 21


28it [00:05,  5.33it/s]


epoch 22


28it [00:05,  5.40it/s]


epoch 23


28it [00:05,  5.39it/s]


epoch 24


28it [00:05,  5.38it/s]


epoch 25


28it [00:05,  5.39it/s]


epoch 26


28it [00:05,  5.39it/s]


epoch 27


28it [00:05,  5.38it/s]


epoch 28


28it [00:05,  5.29it/s]


epoch 29


28it [00:05,  5.36it/s]


epoch 30


28it [00:05,  5.41it/s]


epoch 31


28it [00:05,  5.41it/s]



Test set: Avg. loss: 0.0127, Accuracy: 2117/6941 (30%)


Test set: Avg. loss: 0.0150, Accuracy: 388/1736 (22%)

epoch 32


28it [00:05,  5.40it/s]


epoch 33


28it [00:05,  5.34it/s]


epoch 34


28it [00:05,  5.16it/s]


epoch 35


28it [00:05,  5.37it/s]


epoch 36


28it [00:05,  5.38it/s]


epoch 37


28it [00:05,  5.34it/s]


epoch 38


28it [00:05,  5.38it/s]


epoch 39


28it [00:05,  5.39it/s]


epoch 40


28it [00:05,  5.41it/s]


epoch 41


28it [00:05,  5.39it/s]


epoch 42


28it [00:05,  5.32it/s]


epoch 43


28it [00:05,  5.33it/s]


epoch 44


28it [00:05,  5.31it/s]


epoch 45


28it [00:05,  5.31it/s]


epoch 46


28it [00:05,  5.31it/s]


epoch 47


28it [00:05,  5.40it/s]



Test set: Avg. loss: 0.0120, Accuracy: 2363/6941 (34%)


Test set: Avg. loss: 0.0141, Accuracy: 459/1736 (26%)

epoch 48


28it [00:05,  5.40it/s]


epoch 49


28it [00:05,  5.41it/s]


epoch 50


28it [00:05,  5.36it/s]


epoch 51


28it [00:05,  5.40it/s]


epoch 52


28it [00:05,  5.37it/s]


epoch 53


28it [00:05,  5.35it/s]


epoch 54


28it [00:05,  5.36it/s]


epoch 55


28it [00:05,  5.36it/s]


epoch 56


28it [00:05,  5.34it/s]


epoch 57


28it [00:05,  5.35it/s]


epoch 58


28it [00:05,  5.29it/s]


epoch 59


28it [00:05,  5.36it/s]


epoch 60


28it [00:05,  5.37it/s]


epoch 61


28it [00:05,  5.28it/s]


epoch 62


28it [00:05,  5.38it/s]


epoch 63


28it [00:05,  5.38it/s]



Test set: Avg. loss: 0.0116, Accuracy: 2492/6941 (36%)


Test set: Avg. loss: 0.0136, Accuracy: 471/1736 (27%)

epoch 64


28it [00:05,  5.37it/s]


epoch 65


28it [00:05,  5.37it/s]


epoch 66


28it [00:05,  5.39it/s]


epoch 67


28it [00:05,  5.36it/s]


epoch 68


28it [00:05,  5.41it/s]


epoch 69


28it [00:05,  5.35it/s]


epoch 70


28it [00:05,  5.35it/s]


epoch 71


28it [00:05,  5.38it/s]


epoch 72


28it [00:05,  5.38it/s]


epoch 73


28it [00:05,  5.36it/s]


epoch 74


28it [00:05,  5.37it/s]


epoch 75


28it [00:05,  5.35it/s]


epoch 76


28it [00:05,  5.36it/s]


epoch 77


28it [00:05,  5.37it/s]


epoch 78


28it [00:05,  5.32it/s]


epoch 79


28it [00:05,  5.36it/s]



Test set: Avg. loss: 0.0111, Accuracy: 2621/6941 (38%)


Test set: Avg. loss: 0.0137, Accuracy: 454/1736 (26%)

epoch 80


28it [00:05,  5.42it/s]


epoch 81


28it [00:05,  5.34it/s]


epoch 82


28it [00:05,  5.33it/s]


epoch 83


28it [00:05,  5.37it/s]


epoch 84


28it [00:05,  5.42it/s]


epoch 85


28it [00:05,  5.39it/s]


epoch 86


28it [00:05,  5.35it/s]


epoch 87


28it [00:05,  5.39it/s]


epoch 88


28it [00:05,  5.35it/s]


epoch 89


28it [00:05,  5.42it/s]


epoch 90


28it [00:05,  5.38it/s]


epoch 91


28it [00:05,  5.36it/s]


epoch 92


28it [00:05,  5.37it/s]


epoch 93


28it [00:05,  5.37it/s]


epoch 94


28it [00:05,  5.33it/s]


epoch 95


28it [00:05,  5.32it/s]



Test set: Avg. loss: 0.0106, Accuracy: 2729/6941 (39%)


Test set: Avg. loss: 0.0132, Accuracy: 513/1736 (30%)

epoch 96


28it [00:05,  5.40it/s]


epoch 97


28it [00:05,  5.39it/s]


epoch 98


28it [00:05,  5.38it/s]


epoch 99


28it [00:05,  5.37it/s]


epoch 100


28it [00:05,  5.35it/s]


epoch 101


28it [00:05,  5.33it/s]


epoch 102


28it [00:05,  5.33it/s]


epoch 103


28it [00:05,  5.38it/s]


epoch 104


28it [00:05,  5.33it/s]


epoch 105


28it [00:05,  5.37it/s]


epoch 106


28it [00:05,  5.38it/s]


epoch 107


28it [00:05,  5.42it/s]


epoch 108


28it [00:05,  5.39it/s]


epoch 109


28it [00:05,  5.43it/s]


epoch 110


28it [00:05,  5.37it/s]


epoch 111


28it [00:05,  5.35it/s]



Test set: Avg. loss: 0.0105, Accuracy: 2891/6941 (42%)


Test set: Avg. loss: 0.0132, Accuracy: 509/1736 (29%)

epoch 112


28it [00:05,  5.35it/s]


epoch 113


28it [00:05,  5.26it/s]


epoch 114


28it [00:05,  5.39it/s]


epoch 115


28it [00:05,  5.40it/s]


epoch 116


28it [00:05,  5.39it/s]


epoch 117


28it [00:05,  5.36it/s]


epoch 118


28it [00:05,  5.37it/s]


epoch 119


28it [00:05,  5.40it/s]


epoch 120


28it [00:05,  5.35it/s]


epoch 121


28it [00:05,  5.35it/s]


epoch 122


28it [00:05,  5.25it/s]


epoch 123


28it [00:05,  5.34it/s]


epoch 124


28it [00:05,  5.32it/s]


epoch 125


28it [00:05,  5.30it/s]


epoch 126


28it [00:05,  5.36it/s]


epoch 127


28it [00:05,  5.36it/s]



Test set: Avg. loss: 0.0101, Accuracy: 2940/6941 (42%)


Test set: Avg. loss: 0.0128, Accuracy: 515/1736 (30%)

epoch 128


28it [00:05,  5.29it/s]


epoch 129


28it [00:05,  5.36it/s]


epoch 130


28it [00:05,  5.34it/s]


epoch 131


28it [00:05,  5.30it/s]


epoch 132


28it [00:05,  5.33it/s]


epoch 133


28it [00:05,  5.36it/s]


epoch 134


28it [00:05,  5.36it/s]


epoch 135


28it [00:05,  5.31it/s]


epoch 136


28it [00:05,  5.36it/s]


epoch 137


28it [00:05,  5.39it/s]


epoch 138


28it [00:05,  5.39it/s]


epoch 139


28it [00:05,  5.38it/s]


epoch 140


28it [00:05,  5.37it/s]


epoch 141


28it [00:05,  5.33it/s]


epoch 142


28it [00:05,  5.34it/s]


epoch 143


28it [00:05,  5.22it/s]



Test set: Avg. loss: 0.0097, Accuracy: 3005/6941 (43%)


Test set: Avg. loss: 0.0128, Accuracy: 519/1736 (30%)

epoch 144


28it [00:05,  5.33it/s]


epoch 145


28it [00:05,  5.32it/s]


epoch 146


28it [00:05,  5.35it/s]


epoch 147


28it [00:05,  5.37it/s]


epoch 148


28it [00:05,  5.26it/s]


epoch 149


28it [00:05,  5.39it/s]


epoch 150


28it [00:05,  5.34it/s]


epoch 151


28it [00:05,  5.36it/s]


epoch 152


28it [00:05,  5.34it/s]


epoch 153


28it [00:05,  5.36it/s]


epoch 154


28it [00:05,  5.34it/s]


epoch 155


28it [00:05,  5.35it/s]


epoch 156


28it [00:05,  5.36it/s]


epoch 157


28it [00:05,  5.33it/s]


epoch 158


28it [00:05,  5.34it/s]


epoch 159


28it [00:05,  5.34it/s]



Test set: Avg. loss: 0.0096, Accuracy: 3110/6941 (45%)


Test set: Avg. loss: 0.0123, Accuracy: 563/1736 (32%)

epoch 160


28it [00:05,  5.32it/s]


epoch 161


28it [00:05,  5.26it/s]


epoch 162


28it [00:05,  5.26it/s]


epoch 163


28it [00:05,  5.07it/s]


epoch 164


28it [00:05,  5.37it/s]


epoch 165


28it [00:05,  5.36it/s]


epoch 166


28it [00:05,  5.34it/s]


epoch 167


28it [00:05,  5.37it/s]


epoch 168


28it [00:05,  5.35it/s]


epoch 169


28it [00:05,  5.26it/s]


epoch 170


28it [00:05,  5.35it/s]


epoch 171


28it [00:05,  5.34it/s]


epoch 172


28it [00:05,  5.34it/s]


epoch 173


28it [00:05,  5.32it/s]


epoch 174


28it [00:05,  5.32it/s]


epoch 175


28it [00:05,  5.36it/s]



Test set: Avg. loss: 0.0092, Accuracy: 3209/6941 (46%)


Test set: Avg. loss: 0.0123, Accuracy: 577/1736 (33%)

epoch 176


28it [00:05,  5.34it/s]


epoch 177


28it [00:05,  5.34it/s]


epoch 178


28it [00:05,  5.34it/s]


epoch 179


28it [00:05,  5.27it/s]


epoch 180


28it [00:05,  5.40it/s]


epoch 181


28it [00:05,  5.37it/s]


epoch 182


28it [00:05,  5.39it/s]


epoch 183


28it [00:05,  5.37it/s]


epoch 184


28it [00:05,  5.37it/s]


epoch 185


28it [00:05,  5.39it/s]


epoch 186


28it [00:05,  5.33it/s]


epoch 187


28it [00:05,  5.38it/s]


epoch 188


28it [00:05,  5.37it/s]


epoch 189


28it [00:05,  5.36it/s]


epoch 190


28it [00:05,  5.29it/s]


epoch 191


28it [00:05,  5.40it/s]



Test set: Avg. loss: 0.0092, Accuracy: 3249/6941 (47%)


Test set: Avg. loss: 0.0125, Accuracy: 560/1736 (32%)

epoch 192


28it [00:05,  5.38it/s]


epoch 193


28it [00:05,  5.36it/s]


epoch 194


28it [00:05,  5.40it/s]


epoch 195


28it [00:05,  5.42it/s]


epoch 196


28it [00:05,  5.40it/s]


epoch 197


28it [00:05,  5.31it/s]


epoch 198


28it [00:05,  5.29it/s]


epoch 199


28it [00:05,  5.32it/s]


epoch 200


28it [00:05,  5.37it/s]


epoch 201


28it [00:05,  5.41it/s]


epoch 202


28it [00:05,  5.39it/s]


epoch 203


28it [00:05,  5.41it/s]


epoch 204


28it [00:05,  5.40it/s]


epoch 205


28it [00:05,  5.37it/s]


epoch 206


28it [00:05,  5.43it/s]


epoch 207


28it [00:05,  5.42it/s]



Test set: Avg. loss: 0.0088, Accuracy: 3322/6941 (48%)


Test set: Avg. loss: 0.0124, Accuracy: 578/1736 (33%)

epoch 208


28it [00:05,  5.37it/s]


epoch 209


28it [00:05,  5.36it/s]


epoch 210


28it [00:05,  5.34it/s]


epoch 211


28it [00:05,  5.34it/s]


epoch 212


28it [00:05,  5.34it/s]


epoch 213


28it [00:05,  5.38it/s]


epoch 214


28it [00:05,  5.34it/s]


epoch 215


28it [00:05,  5.29it/s]


epoch 216


28it [00:05,  5.35it/s]


epoch 217


28it [00:05,  5.32it/s]


epoch 218


28it [00:05,  5.37it/s]


epoch 219


28it [00:05,  5.33it/s]


epoch 220


28it [00:05,  5.42it/s]


epoch 221


28it [00:05,  5.38it/s]


epoch 222


28it [00:05,  5.35it/s]


epoch 223


28it [00:05,  5.34it/s]



Test set: Avg. loss: 0.0087, Accuracy: 3347/6941 (48%)


Test set: Avg. loss: 0.0122, Accuracy: 615/1736 (35%)

epoch 224


28it [00:05,  5.37it/s]


epoch 225


28it [00:05,  5.36it/s]


epoch 226


28it [00:05,  5.34it/s]


epoch 227


28it [00:05,  5.34it/s]


epoch 228


28it [00:05,  5.33it/s]


epoch 229


28it [00:05,  5.35it/s]


epoch 230


28it [00:05,  5.36it/s]


epoch 231


28it [00:05,  5.38it/s]


epoch 232


28it [00:05,  5.37it/s]


epoch 233


28it [00:05,  5.38it/s]


epoch 234


28it [00:05,  5.37it/s]


epoch 235


28it [00:05,  5.27it/s]


epoch 236


28it [00:05,  5.35it/s]


epoch 237


28it [00:05,  5.32it/s]


epoch 238


28it [00:05,  5.37it/s]


epoch 239


28it [00:05,  5.34it/s]



Test set: Avg. loss: 0.0085, Accuracy: 3454/6941 (50%)


Test set: Avg. loss: 0.0127, Accuracy: 582/1736 (34%)

epoch 240


28it [00:05,  5.20it/s]


epoch 241


28it [00:05,  5.30it/s]


epoch 242


28it [00:05,  5.21it/s]


epoch 243


28it [00:05,  5.35it/s]


epoch 244


28it [00:05,  5.34it/s]


epoch 245


28it [00:05,  5.31it/s]


epoch 246


28it [00:05,  5.35it/s]


epoch 247


28it [00:05,  5.40it/s]


epoch 248


28it [00:05,  5.36it/s]


epoch 249


28it [00:05,  5.39it/s]


epoch 250


28it [00:05,  5.34it/s]


epoch 251


28it [00:05,  5.35it/s]


epoch 252


28it [00:05,  5.38it/s]


epoch 253


28it [00:05,  5.42it/s]


epoch 254


28it [00:05,  5.35it/s]


epoch 255


28it [00:05,  5.40it/s]



Test set: Avg. loss: 0.0083, Accuracy: 3542/6941 (51%)


Test set: Avg. loss: 0.0122, Accuracy: 586/1736 (34%)



In [27]:
for epoch in range(256):
    print('epoch', epoch)
    train(model.to(0), optimizer, train_loader, accumulate=1)
    gc.collect()
    if epoch % 16 == 15:
        test(model, train_loader)
        test(model, val_loader)

epoch 0


28it [00:05,  5.23it/s]


epoch 1


28it [00:05,  5.36it/s]


epoch 2


28it [00:05,  5.46it/s]


epoch 3


28it [00:05,  5.50it/s]


epoch 4


28it [00:05,  5.50it/s]


epoch 5


28it [00:05,  5.38it/s]


epoch 6


28it [00:05,  5.35it/s]


epoch 7


28it [00:05,  5.30it/s]


epoch 8


28it [00:05,  5.39it/s]


epoch 9


28it [00:05,  5.40it/s]


epoch 10


28it [00:05,  5.39it/s]


epoch 11


28it [00:05,  5.42it/s]


epoch 12


28it [00:05,  5.46it/s]


epoch 13


28it [00:05,  5.50it/s]


epoch 14


28it [00:05,  5.51it/s]


epoch 15


28it [00:05,  5.51it/s]



Test set: Avg. loss: 0.0081, Accuracy: 3559/6941 (51%)


Test set: Avg. loss: 0.0121, Accuracy: 612/1736 (35%)

epoch 16


28it [00:05,  5.47it/s]


epoch 17


28it [00:05,  5.40it/s]


epoch 18


28it [00:05,  5.39it/s]


epoch 19


28it [00:05,  5.46it/s]


epoch 20


28it [00:05,  5.46it/s]


epoch 21


28it [00:05,  5.41it/s]


epoch 22


28it [00:05,  5.42it/s]


epoch 23


28it [00:05,  5.42it/s]


epoch 24


28it [00:05,  5.42it/s]


epoch 25


28it [00:05,  5.30it/s]


epoch 26


28it [00:05,  5.47it/s]


epoch 27


28it [00:05,  5.42it/s]


epoch 28


28it [00:05,  5.40it/s]


epoch 29


28it [00:05,  5.33it/s]


epoch 30


28it [00:05,  5.33it/s]


epoch 31


28it [00:05,  5.45it/s]



Test set: Avg. loss: 0.0078, Accuracy: 3635/6941 (52%)


Test set: Avg. loss: 0.0119, Accuracy: 635/1736 (37%)

epoch 32


28it [00:05,  5.36it/s]


epoch 33


28it [00:05,  5.33it/s]


epoch 34


28it [00:05,  5.36it/s]


epoch 35


28it [00:05,  5.33it/s]


epoch 36


28it [00:05,  5.36it/s]


epoch 37


28it [00:05,  5.36it/s]


epoch 38


28it [00:05,  5.37it/s]


epoch 39


28it [00:05,  5.33it/s]


epoch 40


28it [00:05,  5.33it/s]


epoch 41


28it [00:05,  5.34it/s]


epoch 42


28it [00:05,  5.35it/s]


epoch 43


28it [00:05,  5.34it/s]


epoch 44


28it [00:05,  5.37it/s]


epoch 45


28it [00:05,  5.34it/s]


epoch 46


28it [00:05,  5.33it/s]


epoch 47


28it [00:05,  5.33it/s]



Test set: Avg. loss: 0.0077, Accuracy: 3636/6941 (52%)


Test set: Avg. loss: 0.0119, Accuracy: 650/1736 (37%)

epoch 48


28it [00:05,  5.31it/s]


epoch 49


28it [00:05,  5.36it/s]


epoch 50


28it [00:05,  5.34it/s]


epoch 51


28it [00:05,  5.35it/s]


epoch 52


28it [00:05,  5.36it/s]


epoch 53


28it [00:05,  5.38it/s]


epoch 54


28it [00:05,  5.29it/s]


epoch 55


28it [00:05,  5.24it/s]


epoch 56


28it [00:05,  5.35it/s]


epoch 57


28it [00:05,  5.32it/s]


epoch 58


28it [00:05,  5.20it/s]


epoch 59


28it [00:05,  5.32it/s]


epoch 60


28it [00:05,  5.31it/s]


epoch 61


28it [00:05,  5.34it/s]


epoch 62


28it [00:05,  5.35it/s]


epoch 63


28it [00:05,  5.37it/s]



Test set: Avg. loss: 0.0075, Accuracy: 3823/6941 (55%)


Test set: Avg. loss: 0.0122, Accuracy: 641/1736 (37%)

epoch 64


28it [00:05,  5.36it/s]


epoch 65


28it [00:05,  5.19it/s]


epoch 66


28it [00:05,  5.34it/s]


epoch 67


28it [00:05,  5.38it/s]


epoch 68


28it [00:05,  5.38it/s]


epoch 69


28it [00:05,  5.28it/s]


epoch 70


28it [00:05,  5.34it/s]


epoch 71


28it [00:05,  5.35it/s]


epoch 72


28it [00:05,  5.15it/s]


epoch 73


28it [00:05,  5.10it/s]


epoch 74


28it [00:05,  5.32it/s]


epoch 75


28it [00:05,  5.41it/s]


epoch 76


28it [00:05,  5.27it/s]


epoch 77


28it [00:05,  5.29it/s]


epoch 78


28it [00:05,  5.32it/s]


epoch 79


28it [00:05,  5.38it/s]



Test set: Avg. loss: 0.0072, Accuracy: 3896/6941 (56%)


Test set: Avg. loss: 0.0120, Accuracy: 664/1736 (38%)

epoch 80


28it [00:05,  5.23it/s]


epoch 81


28it [00:05,  5.24it/s]


epoch 82


28it [00:05,  5.34it/s]


epoch 83


28it [00:05,  5.19it/s]


epoch 84


28it [00:05,  5.13it/s]


epoch 85


28it [00:05,  5.27it/s]


epoch 86


28it [00:05,  5.34it/s]


epoch 87


28it [00:05,  5.36it/s]


epoch 88


28it [00:05,  5.29it/s]


epoch 89


28it [00:05,  5.09it/s]


epoch 90


28it [00:05,  5.31it/s]


epoch 91


28it [00:05,  5.40it/s]


epoch 92


28it [00:05,  5.39it/s]


epoch 93


28it [00:05,  5.33it/s]


epoch 94


28it [00:05,  5.38it/s]


epoch 95


28it [00:05,  5.36it/s]



Test set: Avg. loss: 0.0071, Accuracy: 3911/6941 (56%)


Test set: Avg. loss: 0.0122, Accuracy: 654/1736 (38%)

epoch 96


28it [00:05,  5.35it/s]


epoch 97


28it [00:05,  5.42it/s]


epoch 98


28it [00:05,  5.44it/s]


epoch 99


28it [00:05,  5.31it/s]


epoch 100


28it [00:05,  5.48it/s]


epoch 101


28it [00:05,  5.32it/s]


epoch 102


28it [00:05,  5.43it/s]


epoch 103


28it [00:05,  5.32it/s]


epoch 104


28it [00:05,  5.38it/s]


epoch 105


28it [00:05,  5.41it/s]


epoch 106


28it [00:05,  5.39it/s]


epoch 107


28it [00:05,  5.43it/s]


epoch 108


28it [00:05,  5.42it/s]


epoch 109


28it [00:05,  5.44it/s]


epoch 110


28it [00:05,  5.37it/s]


epoch 111


28it [00:05,  5.38it/s]



Test set: Avg. loss: 0.0069, Accuracy: 3983/6941 (57%)


Test set: Avg. loss: 0.0118, Accuracy: 672/1736 (39%)

epoch 112


28it [00:05,  5.40it/s]


epoch 113


28it [00:05,  5.42it/s]


epoch 114


28it [00:05,  5.37it/s]


epoch 115


28it [00:05,  5.31it/s]


epoch 116


28it [00:05,  5.40it/s]


epoch 117


28it [00:05,  5.42it/s]


epoch 118


28it [00:05,  5.35it/s]


epoch 119


28it [00:05,  5.42it/s]


epoch 120


28it [00:05,  5.36it/s]


epoch 121


28it [00:05,  5.39it/s]


epoch 122


28it [00:05,  5.39it/s]


epoch 123


28it [00:05,  5.46it/s]


epoch 124


28it [00:05,  5.42it/s]


epoch 125


28it [00:05,  5.38it/s]


epoch 126


28it [00:05,  5.41it/s]


epoch 127


28it [00:05,  5.39it/s]



Test set: Avg. loss: 0.0067, Accuracy: 4049/6941 (58%)


Test set: Avg. loss: 0.0123, Accuracy: 638/1736 (37%)

epoch 128


28it [00:05,  5.42it/s]


epoch 129


28it [00:05,  5.43it/s]


epoch 130


28it [00:05,  5.40it/s]


epoch 131


28it [00:05,  5.43it/s]


epoch 132


28it [00:05,  5.38it/s]


epoch 133


28it [00:05,  5.28it/s]


epoch 134


28it [00:05,  5.42it/s]


epoch 135


28it [00:05,  5.41it/s]


epoch 136


28it [00:05,  5.42it/s]


epoch 137


28it [00:05,  5.39it/s]


epoch 138


28it [00:05,  5.41it/s]


epoch 139


28it [00:05,  5.44it/s]


epoch 140


28it [00:05,  5.41it/s]


epoch 141


28it [00:05,  5.43it/s]


epoch 142


28it [00:05,  5.44it/s]


epoch 143


28it [00:05,  5.44it/s]



Test set: Avg. loss: 0.0067, Accuracy: 4099/6941 (59%)


Test set: Avg. loss: 0.0123, Accuracy: 660/1736 (38%)

epoch 144


28it [00:05,  5.42it/s]


epoch 145


28it [00:05,  5.42it/s]


epoch 146


28it [00:05,  5.41it/s]


epoch 147


28it [00:05,  5.40it/s]


epoch 148


28it [00:05,  5.39it/s]


epoch 149


28it [00:05,  5.36it/s]


epoch 150


28it [00:05,  5.35it/s]


epoch 151


28it [00:05,  5.35it/s]


epoch 152


28it [00:05,  5.43it/s]


epoch 153


28it [00:05,  5.44it/s]


epoch 154


28it [00:05,  5.36it/s]


epoch 155


28it [00:05,  5.32it/s]


epoch 156


28it [00:05,  5.42it/s]


epoch 157


28it [00:05,  5.40it/s]


epoch 158


28it [00:05,  5.41it/s]


epoch 159


28it [00:05,  5.43it/s]



Test set: Avg. loss: 0.0065, Accuracy: 4175/6941 (60%)


Test set: Avg. loss: 0.0121, Accuracy: 653/1736 (38%)

epoch 160


28it [00:05,  5.44it/s]


epoch 161


28it [00:05,  5.22it/s]


epoch 162


28it [00:05,  5.33it/s]


epoch 163


28it [00:05,  5.31it/s]


epoch 164


28it [00:05,  5.32it/s]


epoch 165


28it [00:05,  5.27it/s]


epoch 166


28it [00:05,  5.25it/s]


epoch 167


28it [00:05,  5.31it/s]


epoch 168


28it [00:05,  5.31it/s]


epoch 169


28it [00:05,  5.30it/s]


epoch 170


28it [00:05,  5.28it/s]


epoch 171


28it [00:05,  5.37it/s]


epoch 172


28it [00:05,  5.30it/s]


epoch 173


28it [00:05,  5.25it/s]


epoch 174


28it [00:05,  5.31it/s]


epoch 175


28it [00:05,  5.29it/s]



Test set: Avg. loss: 0.0062, Accuracy: 4190/6941 (60%)


Test set: Avg. loss: 0.0120, Accuracy: 673/1736 (39%)

epoch 176


28it [00:05,  5.28it/s]


epoch 177


28it [00:05,  5.34it/s]


epoch 178


28it [00:05,  5.30it/s]


epoch 179


28it [00:05,  5.31it/s]


epoch 180


28it [00:05,  5.30it/s]


epoch 181


28it [00:05,  5.33it/s]


epoch 182


28it [00:05,  5.35it/s]


epoch 183


28it [00:05,  5.33it/s]


epoch 184


28it [00:05,  5.31it/s]


epoch 185


28it [00:05,  5.29it/s]


epoch 186


28it [00:05,  5.31it/s]


epoch 187


28it [00:05,  5.37it/s]


epoch 188


28it [00:05,  5.32it/s]


epoch 189


28it [00:05,  5.33it/s]


epoch 190


28it [00:05,  5.31it/s]


epoch 191


28it [00:05,  5.27it/s]



Test set: Avg. loss: 0.0060, Accuracy: 4330/6941 (62%)


Test set: Avg. loss: 0.0123, Accuracy: 665/1736 (38%)

epoch 192


28it [00:05,  5.30it/s]


epoch 193


28it [00:05,  5.32it/s]


epoch 194


28it [00:05,  5.32it/s]


epoch 195


28it [00:05,  5.30it/s]


epoch 196


28it [00:05,  5.34it/s]


epoch 197


28it [00:05,  5.38it/s]


epoch 198


28it [00:05,  5.33it/s]


epoch 199


28it [00:05,  5.30it/s]


epoch 200


28it [00:05,  5.27it/s]


epoch 201


28it [00:05,  5.30it/s]


epoch 202


28it [00:05,  5.30it/s]


epoch 203


28it [00:05,  5.30it/s]


epoch 204


28it [00:05,  5.29it/s]


epoch 205


28it [00:05,  5.31it/s]


epoch 206


28it [00:05,  5.28it/s]


epoch 207


28it [00:05,  5.30it/s]



Test set: Avg. loss: 0.0059, Accuracy: 4334/6941 (62%)


Test set: Avg. loss: 0.0121, Accuracy: 709/1736 (41%)

epoch 208


28it [00:05,  5.33it/s]


epoch 209


28it [00:05,  5.33it/s]


epoch 210


28it [00:05,  5.29it/s]


epoch 211


28it [00:05,  5.31it/s]


epoch 212


28it [00:05,  5.31it/s]


epoch 213


28it [00:05,  5.34it/s]


epoch 214


28it [00:05,  5.30it/s]


epoch 215


28it [00:05,  5.34it/s]


epoch 216


28it [00:05,  5.27it/s]


epoch 217


28it [00:05,  5.29it/s]


epoch 218


28it [00:05,  5.30it/s]


epoch 219


28it [00:05,  5.28it/s]


epoch 220


28it [00:05,  5.33it/s]


epoch 221


28it [00:05,  5.32it/s]


epoch 222


28it [00:05,  5.35it/s]


epoch 223


28it [00:05,  5.37it/s]



Test set: Avg. loss: 0.0058, Accuracy: 4455/6941 (64%)


Test set: Avg. loss: 0.0122, Accuracy: 698/1736 (40%)

epoch 224


28it [00:05,  5.33it/s]


epoch 225


28it [00:05,  5.34it/s]


epoch 226


28it [00:05,  5.32it/s]


epoch 227


28it [00:05,  5.32it/s]


epoch 228


28it [00:05,  5.34it/s]


epoch 229


28it [00:05,  5.34it/s]


epoch 230


28it [00:05,  5.27it/s]


epoch 231


28it [00:05,  5.32it/s]


epoch 232


28it [00:05,  5.33it/s]


epoch 233


28it [00:05,  5.37it/s]


epoch 234


28it [00:05,  5.35it/s]


epoch 235


28it [00:05,  5.39it/s]


epoch 236


28it [00:05,  5.26it/s]


epoch 237


28it [00:05,  5.31it/s]


epoch 238


28it [00:05,  5.35it/s]


epoch 239


28it [00:05,  5.30it/s]



Test set: Avg. loss: 0.0057, Accuracy: 4473/6941 (64%)


Test set: Avg. loss: 0.0124, Accuracy: 689/1736 (40%)

epoch 240


28it [00:05,  5.35it/s]


epoch 241


28it [00:05,  5.30it/s]


epoch 242


28it [00:05,  5.29it/s]


epoch 243


28it [00:05,  5.24it/s]


epoch 244


28it [00:05,  5.37it/s]


epoch 245


28it [00:05,  5.36it/s]


epoch 246


28it [00:05,  5.34it/s]


epoch 247


28it [00:05,  5.36it/s]


epoch 248


28it [00:05,  5.35it/s]


epoch 249


28it [00:05,  5.38it/s]


epoch 250


28it [00:05,  5.34it/s]


epoch 251


28it [00:05,  5.38it/s]


epoch 252


28it [00:05,  5.37it/s]


epoch 253


28it [00:05,  5.36it/s]


epoch 254


28it [00:05,  5.35it/s]


epoch 255


28it [00:05,  5.35it/s]



Test set: Avg. loss: 0.0056, Accuracy: 4533/6941 (65%)


Test set: Avg. loss: 0.0125, Accuracy: 689/1736 (40%)



In [None]:
# amplitude:
amp = torch.abs(x)
print((amp.real.max(), amp.real.min()))

phase = torch.arctan(x.imag / x.real)

phase = torch.nan_to_num(torch.arctan(x.imag / x.real), posinf=torch.math.pi / 2, neginf= -torch.math.pi / 2)

print((phase.real.max(), phase.real.min()))

def to_magnitude_phase(x):
    """
    return magnitude/phase representation 
    """
    return torch.complex(torch.abs(x), torch.nan_to_num(torch.arctan(x.imag / x.real), posinf=torch.math.pi / 2, neginf= -torch.math.pi / 2))

(tensor(5.4802, device='cuda:0'), tensor(0., device='cuda:0'))
(tensor(1.5708, device='cuda:0'), tensor(-1.5708, device='cuda:0'))


In [None]:
x = torch.tensor([[1.5,.0,.0,.0]])
layerNorm = torch.nn.LayerNorm(4, elementwise_affine = False)
y1 = layerNorm(x)
mean = x.mean(-1, keepdim = True)
var = x.var(-1, keepdim = True, unbiased=False)
y2 = (x-mean)/torch.sqrt(var+layerNorm.eps)

torch.allclose(y1, y2)

True

In [None]:
x = torch.randn((50,20,100))

layerNorm = torch.nn.LayerNorm(x.shape[-1], elementwise_affine = True)
y1 = layerNorm(x)

mean = torch.mean(x, dim=-1, keepdim=True)
var = torch.square(x - mean).mean(dim=-1, keepdim=True)
y2 = ((x - mean) / torch.sqrt(var + layerNorm.eps)) * layerNorm.weight + layerNorm.bias

torch.allclose(y1, y2, atol=1e-5, rtol=1e-5)

True