In [1]:
from shearletNN.shearlets import getcomplexshearlets2D
from shearletNN.shearlet_utils import frequency_shearlet_transform, spatial_shearlet_transform, ShearletTransformLoader
from shearletNN.complex_resnet import complex_resnet18, complex_resnet34, complex_resnet50
from shearletNN.layers import CGELU, CReLU

import torch
import torchvision
from torchvision.transforms import v2
from torchvision import transforms

import gc


patch_size = 64
image_size = 128

rows, cols = image_size, image_size


shearlets, shearletIdxs, RMS, dualFrameWeights = getcomplexshearlets2D(	rows, 
                                                                        cols, 
                                                                        1, 
                                                                        3, 
                                                                        1, 
                                                                        0.5,
                                                                        wavelet_eff_support = image_size,
                                                                        gaussian_eff_support = image_size
                                                                        )

shearlets = torch.tensor(shearlets).permute(2, 0, 1).type(torch.complex64).to(0)

In [2]:
class Unraveling:
    def __init__(self, n):
        self.levels = []
        for i in range(0, n // 2):
            level = []
            for j in range(i, n - i):
                level.append((j, i))
                level.append((i, j))

                level.append((j, n - (i + 1)))
                level.append((n - (i + 1), j))

            level = list(set(level))
            self.levels.append((torch.tensor([x for x, _ in level]), torch.tensor([y for _, y in level])))
    
    def __call__(self, x):
        return [x[..., a, b] for a, b in self.levels]

In [3]:
from tqdm import tqdm

def train(model, optimizer, loader, accumulate=1):
    model.train()
    loss = torch.nn.CrossEntropyLoss()

    for i, (X, y) in tqdm(enumerate(loader)):
        out = model(X.to(0))
        optimizer.zero_grad()
        l = loss(out, y.to(0)) / accumulate
        l.backward()
        if i % accumulate == (accumulate - 1):
            optimizer.step()
        

def accuracy(output, target, topk=(1,)):
    output = output.to(torch.device('cpu'))
    target = target.to(torch.device('cpu'))
    maxk = max(topk)
    batch_size = target.shape[0]

    _, idx = output.sort(dim=1, descending=True)
    pred = idx.narrow(1, 0, maxk).t()
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(dim=0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


def epoch_accuracy(loader_s, student):
    student.eval()

    out_epoch_s = [accuracy(student(L.to(0)), y)[0].detach().cpu().item() for L, y in loader_s]

    student.train()

    return sum(out_epoch_s) / len(out_epoch_s)

def test(network, test_loader):
    network.eval().to(0)
    test_loss = 0
    correct = 0
    total = 0
    test_losses=[]
    with torch.no_grad():
        for data, target in test_loader:
            output = network(data.to(0))
            test_loss += torch.nn.CrossEntropyLoss()(output, target.to(0)).item()
            pred = output.data.max(1, keepdim=True)[1].cpu()
            correct += pred.eq(target.data.view_as(pred)).sum()
            total += target.shape[0]
        test_loss /= total
        test_losses.append(test_loss)
        print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, total,
        100. * correct / total))

class IndexSubsetDataset:
    def __init__(self, ds, inds):
        self.ds = ds
        self.inds = inds

    def __iter__(self):
        for i in range(len(self.inds)):
            yield self[i]
    
    def __getitem__(self, i):
        return self.ds[self.inds[i]]
    
    def __len__(self):
        return len(self.inds)

In [4]:
batch_size_train = 64

In [9]:
def repeat3(x):
    return x.repeat(3, 1, 1)[:3]

transform = v2.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    repeat3,
])

ds_train = torchvision.datasets.Caltech101('./', transform=transform, download = True)
ds_train = IndexSubsetDataset(ds_train, sum([list(range(len(ds_train)))[i::5] for i in range(1, 5)], []))

ds_val = torchvision.datasets.Caltech101('./', transform=transform, download = True)
ds_val = IndexSubsetDataset(ds_val, list(range(len(ds_val)))[0::5])

train_loader = torch.utils.data.DataLoader(
  ds_train,
  batch_size=batch_size_train, shuffle=True, num_workers=0)

def shearlet_transform(img):
    return frequency_shearlet_transform(img, shearlets, patch_size)

train_loader = ShearletTransformLoader(train_loader, shearlet_transform)

val_loader = torch.utils.data.DataLoader(
  ds_val,
  batch_size=batch_size_train, shuffle=False)

val_loader = ShearletTransformLoader(val_loader, shearlet_transform)

for x, y in tqdm(train_loader):
    assert list(x.shape) == [batch_size_train, shearlets.shape[0] * 3, patch_size, patch_size], x.shape
    print(x.shape)
    unravel = Unraveling(64)
    print(unravel(x)[0].shape)
    break

for i, l in enumerate(unravel.levels[::-1]):
    assert len(l[0]) == (i * 8) + 4, (i, len(l[0]), l)
    for x, y in zip(*l):
        if x != 31 - i and y != 31 - i:
            if x != 32 + i and y != 32 + i:
                assert False, (i, (x, y))

Files already downloaded and verified
Files already downloaded and verified


0it [00:00, ?it/s]


torch.Size([64, 30, 64, 64])
torch.Size([64, 30, 252])


In [None]:
class UnravelNN(torch.nn.Module):
    def __init__(self, n, embed_dim):
        sizes = [4 + i*8 for i in range(n)][::-1]
        self.layers = torch.nn.ModuleList([torch.nn.Linear(size, embed_dim, dtype=torch.complex64) for size in sizes])
        self.act = CReLU()

    def forward(self, x):
        x = torch.stack([layer(a) for layer, a in zip(self.layers, x)], -2) # (B, C, n, embed_dim)
        # if we flatten the C and embed_dim together we have n tokens, one for each frequency level.  
        # This is kind of what we would want as a transformer input
        return
        

In [4]:
def repeat3(x):
    return x.repeat(3, 1, 1)[:3]

transform = v2.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    repeat3,
])

ds_train = torchvision.datasets.Caltech101('./', transform=transform, download = True)
ds_train = IndexSubsetDataset(ds_train, sum([list(range(len(ds_train)))[i::5] for i in range(1, 5)], []))

ds_val = torchvision.datasets.Caltech101('./', transform=transform, download = True)
ds_val = IndexSubsetDataset(ds_val, list(range(len(ds_val)))[0::5])

train_loader = torch.utils.data.DataLoader(
  ds_train,
  batch_size=batch_size_train, shuffle=True, num_workers=0)

def shearlet_transform(img):
    return frequency_shearlet_transform(img, shearlets, patch_size)

train_loader = ShearletTransformLoader(train_loader, shearlet_transform)

val_loader = torch.utils.data.DataLoader(
  ds_val,
  batch_size=batch_size_train, shuffle=False)

val_loader = ShearletTransformLoader(val_loader, shearlet_transform)

for x, y in tqdm(train_loader):
    assert list(x.shape) == [batch_size_train, shearlets.shape[0] * 3, patch_size, patch_size], x.shape
    break
print('building model...')
model = complex_resnet18(in_dim=shearlets.shape[0] * 3, complex=True, phasor=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
print('training model...')
for epoch in range(16):
    print('epoch', epoch)
    train(model.to(0), optimizer, train_loader, accumulate=4)
    gc.collect()
    test(model, train_loader)
    test(model, val_loader)

Files already downloaded and verified
Files already downloaded and verified


0it [00:03, ?it/s]


building model...
training model...
epoch 0


109it [00:35,  3.09it/s]



Test set: Avg. loss: 0.0672, Accuracy: 1217/6941 (18%)


Test set: Avg. loss: 0.0695, Accuracy: 296/1736 (17%)

epoch 1


109it [00:28,  3.77it/s]



Test set: Avg. loss: 0.0626, Accuracy: 1458/6941 (21%)


Test set: Avg. loss: 0.0660, Accuracy: 352/1736 (20%)

epoch 2


109it [00:43,  2.51it/s]



Test set: Avg. loss: 0.0597, Accuracy: 1548/6941 (22%)


Test set: Avg. loss: 0.0634, Accuracy: 354/1736 (20%)

epoch 3


109it [00:37,  2.94it/s]



Test set: Avg. loss: 0.0567, Accuracy: 1661/6941 (24%)


Test set: Avg. loss: 0.0622, Accuracy: 385/1736 (22%)

epoch 4


109it [00:34,  3.12it/s]



Test set: Avg. loss: 0.0541, Accuracy: 1678/6941 (24%)


Test set: Avg. loss: 0.0599, Accuracy: 368/1736 (21%)

epoch 5


109it [00:38,  2.87it/s]



Test set: Avg. loss: 0.0490, Accuracy: 2276/6941 (33%)


Test set: Avg. loss: 0.0565, Accuracy: 485/1736 (28%)

epoch 6


109it [00:39,  2.79it/s]



Test set: Avg. loss: 0.0457, Accuracy: 2374/6941 (34%)


Test set: Avg. loss: 0.0560, Accuracy: 516/1736 (30%)

epoch 7


109it [00:43,  2.50it/s]



Test set: Avg. loss: 0.0423, Accuracy: 2784/6941 (40%)


Test set: Avg. loss: 0.0540, Accuracy: 529/1736 (30%)

epoch 8


109it [00:39,  2.77it/s]



Test set: Avg. loss: 0.0403, Accuracy: 3146/6941 (45%)


Test set: Avg. loss: 0.0541, Accuracy: 534/1736 (31%)

epoch 9


109it [00:39,  2.77it/s]



Test set: Avg. loss: 0.0369, Accuracy: 3295/6941 (47%)


Test set: Avg. loss: 0.0524, Accuracy: 550/1736 (32%)

epoch 10


109it [00:40,  2.71it/s]



Test set: Avg. loss: 0.0335, Accuracy: 3763/6941 (54%)


Test set: Avg. loss: 0.0519, Accuracy: 559/1736 (32%)

epoch 11


109it [00:41,  2.65it/s]



Test set: Avg. loss: 0.0314, Accuracy: 3835/6941 (55%)


Test set: Avg. loss: 0.0523, Accuracy: 568/1736 (33%)

epoch 12


109it [00:40,  2.71it/s]



Test set: Avg. loss: 0.0273, Accuracy: 4450/6941 (64%)


Test set: Avg. loss: 0.0509, Accuracy: 570/1736 (33%)

epoch 13


109it [00:39,  2.73it/s]



Test set: Avg. loss: 0.0238, Accuracy: 4895/6941 (71%)


Test set: Avg. loss: 0.0505, Accuracy: 584/1736 (34%)

epoch 14


109it [00:39,  2.76it/s]



Test set: Avg. loss: 0.0205, Accuracy: 5139/6941 (74%)


Test set: Avg. loss: 0.0506, Accuracy: 589/1736 (34%)

epoch 15


109it [00:40,  2.68it/s]



Test set: Avg. loss: 0.0182, Accuracy: 5483/6941 (79%)


Test set: Avg. loss: 0.0507, Accuracy: 565/1736 (33%)



In [4]:
from shearletNN.complex_deit import rope_mixed_ape_deit_small_patch8_LS, rope_mixed_ape_deit_small_patch16_LS
from shearletNN.layers import CReLU

def repeat3(x):
    return x.repeat(3, 1, 1)[:3]

transform = v2.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    repeat3,
])

ds_train = torchvision.datasets.Caltech101('./', transform=transform, download = True)
ds_train = IndexSubsetDataset(ds_train, sum([list(range(len(ds_train)))[i::5] for i in range(1, 5)], []))

ds_val = torchvision.datasets.Caltech101('./', transform=transform, download = True)
ds_val = IndexSubsetDataset(ds_val, list(range(len(ds_val)))[0::5])

train_loader = torch.utils.data.DataLoader(
  ds_train,
  batch_size=batch_size_train, shuffle=True, num_workers=0)

def shearlet_transform(img):
    return frequency_shearlet_transform(img, shearlets, patch_size)

train_loader = ShearletTransformLoader(train_loader, shearlet_transform)

val_loader = torch.utils.data.DataLoader(
  ds_val,
  batch_size=batch_size_train, shuffle=False)

val_loader = ShearletTransformLoader(val_loader, shearlet_transform)

for x, y in tqdm(train_loader):
    assert list(x.shape) == [batch_size_train, shearlets.shape[0] * 3, patch_size, patch_size], x.shape
    break
print('building model...')
model = rope_mixed_ape_deit_small_patch16_LS(img_size=64, in_chans=shearlets.shape[0] * 3, act_layer=CGELU)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
print('training model...')
for epoch in range(16):
    print('epoch', epoch)
    train(model.to(0), optimizer, train_loader, accumulate=4)
    gc.collect()
    test(model, train_loader)
    test(model, val_loader)



Files already downloaded and verified
Files already downloaded and verified


0it [00:04, ?it/s]


building model...
training model...
epoch 0


  with torch.cuda.amp.autocast(enabled=False):


0


  xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
0it [00:06, ?it/s]

1
2
3
4
5
6
7





AssertionError: nan in fc1 output

In [5]:
for p in model.parameters():
    assert not p.isnan().any()

AssertionError: 

In [None]:
model(x)

tensor([[  2.5989,   2.7786,  -1.9066,  ..., -16.6108, -14.2221, -14.1932],
        [  5.6196,   4.4043,  -0.3138,  ..., -22.7475, -19.6062, -20.9539],
        [  2.9905,   2.7827,  -3.0569,  ..., -16.7674, -13.6852, -11.4824],
        ...,
        [  7.8631,   6.3224,   0.2821,  ..., -23.9034, -23.6231, -24.1045],
        [  3.3916,   2.9951,  -3.1254,  ..., -20.9101, -17.9030, -17.2719],
        [  4.3800,   4.5667,   0.8396,  ..., -17.4205, -17.9941, -18.4154]],
       device='cuda:0', grad_fn=<AddmmBackward0>)

In [None]:
# amplitude:
amp = torch.abs(x)
print((amp.real.max(), amp.real.min()))

phase = torch.arctan(x.imag / x.real)

phase = torch.nan_to_num(torch.arctan(x.imag / x.real), posinf=torch.math.pi / 2, neginf= -torch.math.pi / 2)

print((phase.real.max(), phase.real.min()))

def to_magnitude_phase(x):
    """
    return magnitude/phase representation 
    """
    return torch.complex(torch.abs(x), torch.nan_to_num(torch.arctan(x.imag / x.real), posinf=torch.math.pi / 2, neginf= -torch.math.pi / 2))

(tensor(5.4802, device='cuda:0'), tensor(0., device='cuda:0'))
(tensor(1.5708, device='cuda:0'), tensor(-1.5708, device='cuda:0'))


In [None]:
x = torch.tensor([[1.5,.0,.0,.0]])
layerNorm = torch.nn.LayerNorm(4, elementwise_affine = False)
y1 = layerNorm(x)
mean = x.mean(-1, keepdim = True)
var = x.var(-1, keepdim = True, unbiased=False)
y2 = (x-mean)/torch.sqrt(var+layerNorm.eps)

torch.allclose(y1, y2)

True

In [None]:
x = torch.randn((50,20,100))

layerNorm = torch.nn.LayerNorm(x.shape[-1], elementwise_affine = True)
y1 = layerNorm(x)

mean = torch.mean(x, dim=-1, keepdim=True)
var = torch.square(x - mean).mean(dim=-1, keepdim=True)
y2 = ((x - mean) / torch.sqrt(var + layerNorm.eps)) * layerNorm.weight + layerNorm.bias

torch.allclose(y1, y2, atol=1e-5, rtol=1e-5)

True