In [1]:
from shearletNN.shearlets import getcomplexshearlets2D
from shearletNN.shearlet_utils import frequency_shearlet_transform, spatial_shearlet_transform, ShearletTransformLoader
from shearletNN.complex_resnet import complex_resnet18, complex_resnet34, complex_resnet50
from shearletNN.layers import CGELU, CReLU

import torch
import torchvision
from torchvision.transforms import v2
from torchvision import transforms

import gc


patch_size = 64
image_size = 128

rows, cols = image_size, image_size


shearlets, shearletIdxs, RMS, dualFrameWeights = getcomplexshearlets2D(	rows, 
                                                                        cols, 
                                                                        1, 
                                                                        3, 
                                                                        1, 
                                                                        0.5,
                                                                        wavelet_eff_support = image_size,
                                                                        gaussian_eff_support = image_size
                                                                        )

shearlets = torch.tensor(shearlets).permute(2, 0, 1).type(torch.complex64).to(0)

In [2]:
class Unraveling:
    def __init__(self, n):
        self.levels = []
        for i in range(0, n // 2):
            level = []
            for j in range(i, n - i):
                level.append((j, i))
                level.append((i, j))

                level.append((j, n - (i + 1)))
                level.append((n - (i + 1), j))

            level = list(set(level))
            self.levels.append((torch.tensor([x for x, _ in level]), torch.tensor([y for _, y in level])))
    
    def __call__(self, x):
        return [x[..., a, b] for a, b in self.levels]

In [3]:
from tqdm import tqdm

def train(model, optimizer, loader, accumulate=1):
    model.train()
    loss = torch.nn.CrossEntropyLoss()

    for i, (X, y) in tqdm(enumerate(loader)):
        out = model(X.to(0))
        optimizer.zero_grad()
        l = loss(out, y.to(0)) / accumulate
        l.backward()
        if i % accumulate == (accumulate - 1):
            optimizer.step()
        

def accuracy(output, target, topk=(1,)):
    output = output.to(torch.device('cpu'))
    target = target.to(torch.device('cpu'))
    maxk = max(topk)
    batch_size = target.shape[0]

    _, idx = output.sort(dim=1, descending=True)
    pred = idx.narrow(1, 0, maxk).t()
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(dim=0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


def epoch_accuracy(loader_s, student):
    student.eval()

    out_epoch_s = [accuracy(student(L.to(0)), y)[0].detach().cpu().item() for L, y in loader_s]

    student.train()

    return sum(out_epoch_s) / len(out_epoch_s)

def test(network, test_loader):
    network.eval().to(0)
    test_loss = 0
    correct = 0
    total = 0
    test_losses=[]
    with torch.no_grad():
        for data, target in test_loader:
            output = network(data.to(0))
            test_loss += torch.nn.CrossEntropyLoss()(output, target.to(0)).item()
            pred = output.data.max(1, keepdim=True)[1].cpu()
            correct += pred.eq(target.data.view_as(pred)).sum()
            total += target.shape[0]
        test_loss /= total
        test_losses.append(test_loss)
        print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, total,
        100. * correct / total))

class IndexSubsetDataset:
    def __init__(self, ds, inds):
        self.ds = ds
        self.inds = inds

    def __iter__(self):
        for i in range(len(self.inds)):
            yield self[i]
    
    def __getitem__(self, i):
        return self.ds[self.inds[i]]
    
    def __len__(self):
        return len(self.inds)

In [4]:
batch_size_train = 64

In [7]:
from shearletNN.deit import deit_small_patch16_LS
torch.autograd.set_detect_anomaly(True)
def repeat3(x):
    return x.repeat(3, 1, 1)[:3]

transform = v2.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    repeat3,
])

ds_train = torchvision.datasets.Caltech101('./', transform=transform, download = True)
ds_train = IndexSubsetDataset(ds_train, sum([list(range(len(ds_train)))[i::5] for i in range(1, 5)], []))

ds_val = torchvision.datasets.Caltech101('./', transform=transform, download = True)
ds_val = IndexSubsetDataset(ds_val, list(range(len(ds_val)))[0::5])

train_loader = torch.utils.data.DataLoader(
  ds_train,
  batch_size=batch_size_train, shuffle=True, num_workers=0)

val_loader = torch.utils.data.DataLoader(
  ds_val,
  batch_size=batch_size_train, shuffle=False)

for x, y in tqdm(train_loader):
    print(x.dtype)
    assert list(x.shape) == [batch_size_train, 3, image_size, image_size], x.shape
    break
print('building model...')
model = deit_small_patch16_LS(img_size=128, in_chans=3, num_classes=101)

for p in model.parameters():
    assert not p.isnan().any()

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
print('training model...')
for epoch in range(32):
    print('epoch', epoch)
    train(model.to(0), optimizer, train_loader, accumulate=1)
    gc.collect()
    test(model, train_loader)
    test(model, val_loader)

Files already downloaded and verified
Files already downloaded and verified


  0%|          | 0/28 [00:00<?, ?it/s]


torch.float32
building model...
training model...
epoch 0


28it [00:10,  2.59it/s]



Test set: Avg. loss: 0.0168, Accuracy: 707/6941 (10%)


Test set: Avg. loss: 0.0169, Accuracy: 179/1736 (10%)

epoch 1


28it [00:10,  2.57it/s]



Test set: Avg. loss: 0.0163, Accuracy: 693/6941 (10%)


Test set: Avg. loss: 0.0164, Accuracy: 174/1736 (10%)

epoch 2


28it [00:10,  2.66it/s]



Test set: Avg. loss: 0.0157, Accuracy: 1234/6941 (18%)


Test set: Avg. loss: 0.0159, Accuracy: 295/1736 (17%)

epoch 3


28it [00:10,  2.59it/s]



Test set: Avg. loss: 0.0148, Accuracy: 1532/6941 (22%)


Test set: Avg. loss: 0.0150, Accuracy: 378/1736 (22%)

epoch 4


28it [00:10,  2.57it/s]



Test set: Avg. loss: 0.0139, Accuracy: 1771/6941 (26%)


Test set: Avg. loss: 0.0142, Accuracy: 439/1736 (25%)

epoch 5


28it [00:11,  2.51it/s]



Test set: Avg. loss: 0.0127, Accuracy: 2205/6941 (32%)


Test set: Avg. loss: 0.0132, Accuracy: 546/1736 (31%)

epoch 6


28it [00:10,  2.63it/s]



Test set: Avg. loss: 0.0120, Accuracy: 2362/6941 (34%)


Test set: Avg. loss: 0.0126, Accuracy: 563/1736 (32%)

epoch 7


28it [00:10,  2.63it/s]



Test set: Avg. loss: 0.0110, Accuracy: 2698/6941 (39%)


Test set: Avg. loss: 0.0118, Accuracy: 636/1736 (37%)

epoch 8


28it [00:11,  2.53it/s]



Test set: Avg. loss: 0.0103, Accuracy: 2946/6941 (42%)


Test set: Avg. loss: 0.0115, Accuracy: 649/1736 (37%)

epoch 9


28it [00:10,  2.64it/s]



Test set: Avg. loss: 0.0091, Accuracy: 3402/6941 (49%)


Test set: Avg. loss: 0.0108, Accuracy: 700/1736 (40%)

epoch 10


28it [00:10,  2.59it/s]



Test set: Avg. loss: 0.0083, Accuracy: 3833/6941 (55%)


Test set: Avg. loss: 0.0105, Accuracy: 742/1736 (43%)

epoch 11


28it [00:10,  2.58it/s]



Test set: Avg. loss: 0.0074, Accuracy: 4099/6941 (59%)


Test set: Avg. loss: 0.0101, Accuracy: 748/1736 (43%)

epoch 12


28it [00:11,  2.52it/s]



Test set: Avg. loss: 0.0066, Accuracy: 4526/6941 (65%)


Test set: Avg. loss: 0.0101, Accuracy: 772/1736 (44%)

epoch 13


28it [00:10,  2.62it/s]



Test set: Avg. loss: 0.0058, Accuracy: 5046/6941 (73%)


Test set: Avg. loss: 0.0099, Accuracy: 791/1736 (46%)

epoch 14


28it [00:11,  2.50it/s]



Test set: Avg. loss: 0.0049, Accuracy: 5505/6941 (79%)


Test set: Avg. loss: 0.0098, Accuracy: 805/1736 (46%)

epoch 15


28it [00:11,  2.54it/s]



Test set: Avg. loss: 0.0039, Accuracy: 5855/6941 (84%)


Test set: Avg. loss: 0.0096, Accuracy: 789/1736 (45%)

epoch 16


28it [00:10,  2.56it/s]



Test set: Avg. loss: 0.0034, Accuracy: 6087/6941 (88%)


Test set: Avg. loss: 0.0099, Accuracy: 799/1736 (46%)

epoch 17


28it [00:10,  2.61it/s]



Test set: Avg. loss: 0.0025, Accuracy: 6548/6941 (94%)


Test set: Avg. loss: 0.0097, Accuracy: 817/1736 (47%)

epoch 18


28it [00:10,  2.62it/s]



Test set: Avg. loss: 0.0018, Accuracy: 6671/6941 (96%)


Test set: Avg. loss: 0.0096, Accuracy: 816/1736 (47%)

epoch 19


28it [00:11,  2.54it/s]



Test set: Avg. loss: 0.0013, Accuracy: 6806/6941 (98%)


Test set: Avg. loss: 0.0098, Accuracy: 821/1736 (47%)

epoch 20


28it [00:10,  2.60it/s]



Test set: Avg. loss: 0.0008, Accuracy: 6882/6941 (99%)


Test set: Avg. loss: 0.0098, Accuracy: 816/1736 (47%)

epoch 21


28it [00:10,  2.67it/s]



Test set: Avg. loss: 0.0005, Accuracy: 6915/6941 (100%)


Test set: Avg. loss: 0.0099, Accuracy: 814/1736 (47%)

epoch 22


28it [00:10,  2.67it/s]



Test set: Avg. loss: 0.0004, Accuracy: 6925/6941 (100%)


Test set: Avg. loss: 0.0099, Accuracy: 828/1736 (48%)

epoch 23


28it [00:10,  2.69it/s]



Test set: Avg. loss: 0.0003, Accuracy: 6927/6941 (100%)


Test set: Avg. loss: 0.0099, Accuracy: 828/1736 (48%)

epoch 24


28it [00:10,  2.62it/s]



Test set: Avg. loss: 0.0002, Accuracy: 6932/6941 (100%)


Test set: Avg. loss: 0.0099, Accuracy: 827/1736 (48%)

epoch 25


28it [00:10,  2.67it/s]



Test set: Avg. loss: 0.0002, Accuracy: 6932/6941 (100%)


Test set: Avg. loss: 0.0101, Accuracy: 829/1736 (48%)

epoch 26


28it [00:10,  2.74it/s]



Test set: Avg. loss: 0.0001, Accuracy: 6934/6941 (100%)


Test set: Avg. loss: 0.0100, Accuracy: 828/1736 (48%)

epoch 27


28it [00:10,  2.68it/s]



Test set: Avg. loss: 0.0001, Accuracy: 6931/6941 (100%)


Test set: Avg. loss: 0.0102, Accuracy: 816/1736 (47%)

epoch 28


28it [00:10,  2.65it/s]



Test set: Avg. loss: 0.0002, Accuracy: 6932/6941 (100%)


Test set: Avg. loss: 0.0102, Accuracy: 828/1736 (48%)

epoch 29


28it [00:10,  2.71it/s]



Test set: Avg. loss: 0.0001, Accuracy: 6936/6941 (100%)


Test set: Avg. loss: 0.0102, Accuracy: 828/1736 (48%)

epoch 30


28it [00:10,  2.67it/s]



Test set: Avg. loss: 0.0001, Accuracy: 6936/6941 (100%)


Test set: Avg. loss: 0.0102, Accuracy: 838/1736 (48%)

epoch 31


28it [00:10,  2.58it/s]



Test set: Avg. loss: 0.0001, Accuracy: 6935/6941 (100%)


Test set: Avg. loss: 0.0103, Accuracy: 828/1736 (48%)



In [4]:
def repeat3(x):
    return x.repeat(3, 1, 1)[:3]

transform = v2.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    repeat3,
])

ds_train = torchvision.datasets.Caltech101('./', transform=transform, download = True)
ds_train = IndexSubsetDataset(ds_train, sum([list(range(len(ds_train)))[i::5] for i in range(1, 5)], []))

ds_val = torchvision.datasets.Caltech101('./', transform=transform, download = True)
ds_val = IndexSubsetDataset(ds_val, list(range(len(ds_val)))[0::5])

train_loader = torch.utils.data.DataLoader(
  ds_train,
  batch_size=batch_size_train, shuffle=True, num_workers=0)

def shearlet_transform(img):
    return frequency_shearlet_transform(img, shearlets, patch_size)

train_loader = ShearletTransformLoader(train_loader, shearlet_transform)

val_loader = torch.utils.data.DataLoader(
  ds_val,
  batch_size=batch_size_train, shuffle=False)

val_loader = ShearletTransformLoader(val_loader, shearlet_transform)

for x, y in tqdm(train_loader):
    assert list(x.shape) == [batch_size_train, 6, patch_size, patch_size], x.shape
    break
print('building model...')
model = complex_resnet18(in_dim=6, complex=True, phasor=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
print('training model...')
for epoch in range(16):
    print('epoch', epoch)
    train(model.to(0), optimizer, train_loader, accumulate=1)
    gc.collect()
    test(model, train_loader)
    test(model, val_loader)

Files already downloaded and verified
Files already downloaded and verified


0it [00:00, ?it/s]


AssertionError: torch.Size([64, 30, 128, 128])

In [None]:
class UnravelNN(torch.nn.Module):
    def __init__(self, n, embed_dim):
        sizes = [4 + i*8 for i in range(n)][::-1]
        self.layers = torch.nn.ModuleList([torch.nn.Linear(size, embed_dim, dtype=torch.complex64) for size in sizes])
        self.act = CReLU()

    def forward(self, x):
        x = torch.stack([layer(a) for layer, a in zip(self.layers, x)], -2) # (B, C, n, embed_dim)
        # if we flatten the C and embed_dim together we have n tokens, one for each frequency level.  
        # This is kind of what we would want as a transformer input
        return
        

In [4]:
from shearletNN.complex_deit import complex_rope_mixed_deit_small_patch16_LS
from shearletNN.layers import CGELU

def repeat3(x):
    return x.repeat(3, 1, 1)[:3]

transform = v2.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    repeat3,
])

ds_train = torchvision.datasets.Caltech101('./', transform=transform, download = True)
ds_train = IndexSubsetDataset(ds_train, sum([list(range(len(ds_train)))[i::5] for i in range(1, 5)], []))

ds_val = torchvision.datasets.Caltech101('./', transform=transform, download = True)
ds_val = IndexSubsetDataset(ds_val, list(range(len(ds_val)))[0::5])

train_loader = torch.utils.data.DataLoader(
  ds_train,
  batch_size=batch_size_train, shuffle=True, num_workers=0)

def shearlet_transform(img):
    return frequency_shearlet_transform(img, shearlets, patch_size)

train_loader = ShearletTransformLoader(train_loader, shearlet_transform)

val_loader = torch.utils.data.DataLoader(
  ds_val,
  batch_size=batch_size_train, shuffle=False)

val_loader = ShearletTransformLoader(val_loader, shearlet_transform)

for x, y in tqdm(train_loader):
    assert list(x.shape) == [batch_size_train, 3 * shearlets.shape[0], patch_size, patch_size], x.shape
    break
print('building model...')
model = complex_rope_mixed_deit_small_patch16_LS(img_size=patch_size, in_chans=3 * shearlets.shape[0], act_layer=CGELU, num_classes=101)
for p in model.parameters():
    assert not p.isnan().any()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
print('training model...')
for epoch in range(32):
    print('epoch', epoch)
    train(model.to(0), optimizer, train_loader, accumulate=4)
    gc.collect()
    test(model, train_loader)
    test(model, val_loader)

Files already downloaded and verified
Files already downloaded and verified


0it [00:03, ?it/s]


building model...
training model...
epoch 0


  with torch.cuda.amp.autocast(enabled=False):
  xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
109it [00:25,  4.30it/s]



Test set: Avg. loss: 0.0722, Accuracy: 638/6941 (9%)


Test set: Avg. loss: 0.0742, Accuracy: 160/1736 (9%)

epoch 1


109it [00:24,  4.39it/s]



Test set: Avg. loss: 0.0711, Accuracy: 764/6941 (11%)


Test set: Avg. loss: 0.0731, Accuracy: 199/1736 (11%)

epoch 2


109it [00:24,  4.48it/s]



Test set: Avg. loss: 0.0651, Accuracy: 756/6941 (11%)


Test set: Avg. loss: 0.0682, Accuracy: 199/1736 (11%)

epoch 3


109it [00:24,  4.50it/s]



Test set: Avg. loss: 0.0644, Accuracy: 860/6941 (12%)


Test set: Avg. loss: 0.0671, Accuracy: 212/1736 (12%)

epoch 4


109it [00:23,  4.56it/s]



Test set: Avg. loss: 0.0641, Accuracy: 836/6941 (12%)


Test set: Avg. loss: 0.0667, Accuracy: 208/1736 (12%)

epoch 5


109it [00:25,  4.35it/s]



Test set: Avg. loss: 0.0637, Accuracy: 849/6941 (12%)


Test set: Avg. loss: 0.0670, Accuracy: 208/1736 (12%)

epoch 6


109it [00:24,  4.49it/s]



Test set: Avg. loss: 0.0633, Accuracy: 924/6941 (13%)


Test set: Avg. loss: 0.0670, Accuracy: 228/1736 (13%)

epoch 7


109it [00:24,  4.51it/s]



Test set: Avg. loss: 0.0626, Accuracy: 974/6941 (14%)


Test set: Avg. loss: 0.0662, Accuracy: 237/1736 (14%)

epoch 8


109it [00:23,  4.56it/s]



Test set: Avg. loss: 0.0620, Accuracy: 1099/6941 (16%)


Test set: Avg. loss: 0.0661, Accuracy: 265/1736 (15%)

epoch 9


109it [00:24,  4.48it/s]



Test set: Avg. loss: 0.0613, Accuracy: 1161/6941 (17%)


Test set: Avg. loss: 0.0663, Accuracy: 278/1736 (16%)

epoch 10


109it [00:23,  4.54it/s]



Test set: Avg. loss: 0.0606, Accuracy: 1216/6941 (18%)


Test set: Avg. loss: 0.0657, Accuracy: 289/1736 (17%)

epoch 11


109it [00:23,  4.61it/s]



Test set: Avg. loss: 0.0598, Accuracy: 1218/6941 (18%)


Test set: Avg. loss: 0.0646, Accuracy: 286/1736 (16%)

epoch 12


109it [00:23,  4.67it/s]



Test set: Avg. loss: 0.0595, Accuracy: 1243/6941 (18%)


Test set: Avg. loss: 0.0649, Accuracy: 296/1736 (17%)

epoch 13


109it [00:23,  4.68it/s]



Test set: Avg. loss: 0.0584, Accuracy: 1297/6941 (19%)


Test set: Avg. loss: 0.0643, Accuracy: 299/1736 (17%)

epoch 14


109it [00:24,  4.41it/s]



Test set: Avg. loss: 0.0581, Accuracy: 1382/6941 (20%)


Test set: Avg. loss: 0.0640, Accuracy: 329/1736 (19%)

epoch 15


109it [00:24,  4.48it/s]



Test set: Avg. loss: 0.0575, Accuracy: 1402/6941 (20%)


Test set: Avg. loss: 0.0647, Accuracy: 330/1736 (19%)

epoch 16


109it [00:24,  4.40it/s]



Test set: Avg. loss: 0.0567, Accuracy: 1430/6941 (21%)


Test set: Avg. loss: 0.0634, Accuracy: 327/1736 (19%)

epoch 17


109it [00:25,  4.27it/s]



Test set: Avg. loss: 0.0559, Accuracy: 1495/6941 (22%)


Test set: Avg. loss: 0.0632, Accuracy: 335/1736 (19%)

epoch 18


109it [00:25,  4.34it/s]



Test set: Avg. loss: 0.0558, Accuracy: 1537/6941 (22%)


Test set: Avg. loss: 0.0625, Accuracy: 356/1736 (21%)

epoch 19


109it [00:24,  4.42it/s]



Test set: Avg. loss: 0.0552, Accuracy: 1547/6941 (22%)


Test set: Avg. loss: 0.0626, Accuracy: 334/1736 (19%)

epoch 20


109it [00:24,  4.45it/s]



Test set: Avg. loss: 0.0547, Accuracy: 1624/6941 (23%)


Test set: Avg. loss: 0.0619, Accuracy: 352/1736 (20%)

epoch 21


109it [00:24,  4.41it/s]



Test set: Avg. loss: 0.0544, Accuracy: 1679/6941 (24%)


Test set: Avg. loss: 0.0623, Accuracy: 354/1736 (20%)

epoch 22


109it [00:24,  4.48it/s]



Test set: Avg. loss: 0.0543, Accuracy: 1668/6941 (24%)


Test set: Avg. loss: 0.0618, Accuracy: 354/1736 (20%)

epoch 23


109it [00:24,  4.40it/s]



Test set: Avg. loss: 0.0538, Accuracy: 1697/6941 (24%)


Test set: Avg. loss: 0.0618, Accuracy: 369/1736 (21%)

epoch 24


109it [00:25,  4.29it/s]



Test set: Avg. loss: 0.0534, Accuracy: 1729/6941 (25%)


Test set: Avg. loss: 0.0619, Accuracy: 355/1736 (20%)

epoch 25


109it [00:24,  4.53it/s]



Test set: Avg. loss: 0.0529, Accuracy: 1780/6941 (26%)


Test set: Avg. loss: 0.0621, Accuracy: 378/1736 (22%)

epoch 26


109it [00:23,  4.55it/s]



Test set: Avg. loss: 0.0531, Accuracy: 1738/6941 (25%)


Test set: Avg. loss: 0.0620, Accuracy: 362/1736 (21%)

epoch 27


109it [00:23,  4.63it/s]



Test set: Avg. loss: 0.0527, Accuracy: 1829/6941 (26%)


Test set: Avg. loss: 0.0613, Accuracy: 376/1736 (22%)

epoch 28


109it [00:24,  4.47it/s]



Test set: Avg. loss: 0.0526, Accuracy: 1806/6941 (26%)


Test set: Avg. loss: 0.0615, Accuracy: 368/1736 (21%)

epoch 29


109it [00:24,  4.42it/s]



Test set: Avg. loss: 0.0522, Accuracy: 1809/6941 (26%)


Test set: Avg. loss: 0.0617, Accuracy: 383/1736 (22%)

epoch 30


109it [00:23,  4.56it/s]



Test set: Avg. loss: 0.0517, Accuracy: 1855/6941 (27%)


Test set: Avg. loss: 0.0616, Accuracy: 374/1736 (22%)

epoch 31


109it [00:24,  4.50it/s]



Test set: Avg. loss: 0.0520, Accuracy: 1823/6941 (26%)


Test set: Avg. loss: 0.0619, Accuracy: 384/1736 (22%)



In [4]:
from shearletNN.complex_deit import rope_mixed_ape_deit_small_patch8_LS, rope_mixed_ape_deit_small_patch16_LS
from shearletNN.layers import CReLU

def repeat3(x):
    return x.repeat(3, 1, 1)[:3]

transform = v2.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    repeat3,
])

ds_train = torchvision.datasets.Caltech101('./', transform=transform, download = True)
ds_train = IndexSubsetDataset(ds_train, sum([list(range(len(ds_train)))[i::5] for i in range(1, 5)], []))

ds_val = torchvision.datasets.Caltech101('./', transform=transform, download = True)
ds_val = IndexSubsetDataset(ds_val, list(range(len(ds_val)))[0::5])

train_loader = torch.utils.data.DataLoader(
  ds_train,
  batch_size=batch_size_train, shuffle=True, num_workers=0)

def shearlet_transform(img):
    return frequency_shearlet_transform(img, shearlets, patch_size)

train_loader = ShearletTransformLoader(train_loader, shearlet_transform)

val_loader = torch.utils.data.DataLoader(
  ds_val,
  batch_size=batch_size_train, shuffle=False)

val_loader = ShearletTransformLoader(val_loader, shearlet_transform)

for x, y in tqdm(train_loader):
    assert list(x.shape) == [batch_size_train, shearlets.shape[0] * 3, patch_size, patch_size], x.shape
    break
print('building model...')
model = rope_mixed_ape_deit_small_patch16_LS(img_size=64, in_chans=shearlets.shape[0] * 3, act_layer=CGELU)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
print('training model...')
for epoch in range(16):
    print('epoch', epoch)
    train(model.to(0), optimizer, train_loader, accumulate=4)
    gc.collect()
    test(model, train_loader)
    test(model, val_loader)



Files already downloaded and verified
Files already downloaded and verified


0it [00:04, ?it/s]


building model...
training model...
epoch 0


  with torch.cuda.amp.autocast(enabled=False):


0


  xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
0it [00:06, ?it/s]

1
2
3
4
5
6
7





AssertionError: nan in fc1 output

In [5]:
for p in model.parameters():
    assert not p.isnan().any()

AssertionError: 

In [None]:
model(x)

tensor([[  2.5989,   2.7786,  -1.9066,  ..., -16.6108, -14.2221, -14.1932],
        [  5.6196,   4.4043,  -0.3138,  ..., -22.7475, -19.6062, -20.9539],
        [  2.9905,   2.7827,  -3.0569,  ..., -16.7674, -13.6852, -11.4824],
        ...,
        [  7.8631,   6.3224,   0.2821,  ..., -23.9034, -23.6231, -24.1045],
        [  3.3916,   2.9951,  -3.1254,  ..., -20.9101, -17.9030, -17.2719],
        [  4.3800,   4.5667,   0.8396,  ..., -17.4205, -17.9941, -18.4154]],
       device='cuda:0', grad_fn=<AddmmBackward0>)

In [None]:
# amplitude:
amp = torch.abs(x)
print((amp.real.max(), amp.real.min()))

phase = torch.arctan(x.imag / x.real)

phase = torch.nan_to_num(torch.arctan(x.imag / x.real), posinf=torch.math.pi / 2, neginf= -torch.math.pi / 2)

print((phase.real.max(), phase.real.min()))

def to_magnitude_phase(x):
    """
    return magnitude/phase representation 
    """
    return torch.complex(torch.abs(x), torch.nan_to_num(torch.arctan(x.imag / x.real), posinf=torch.math.pi / 2, neginf= -torch.math.pi / 2))

(tensor(5.4802, device='cuda:0'), tensor(0., device='cuda:0'))
(tensor(1.5708, device='cuda:0'), tensor(-1.5708, device='cuda:0'))


In [None]:
x = torch.tensor([[1.5,.0,.0,.0]])
layerNorm = torch.nn.LayerNorm(4, elementwise_affine = False)
y1 = layerNorm(x)
mean = x.mean(-1, keepdim = True)
var = x.var(-1, keepdim = True, unbiased=False)
y2 = (x-mean)/torch.sqrt(var+layerNorm.eps)

torch.allclose(y1, y2)

True

In [None]:
x = torch.randn((50,20,100))

layerNorm = torch.nn.LayerNorm(x.shape[-1], elementwise_affine = True)
y1 = layerNorm(x)

mean = torch.mean(x, dim=-1, keepdim=True)
var = torch.square(x - mean).mean(dim=-1, keepdim=True)
y2 = ((x - mean) / torch.sqrt(var + layerNorm.eps)) * layerNorm.weight + layerNorm.bias

torch.allclose(y1, y2, atol=1e-5, rtol=1e-5)

True