In [1]:
from shearletNN.shearlets import getcomplexshearlets2D
from shearletNN.shearlet_utils import frequency_shearlet_transform, spatial_shearlet_transform, ShearletTransformLoader
from shearletNN.complex_resnet import complex_resnet18, complex_resnet34, complex_resnet50
from shearletNN.layers import CGELU

import torch
import torchvision
from torchvision.transforms import v2
from torchvision import transforms

import gc


patch_size = 128
image_size = 256

rows, cols = image_size, image_size


shearlets, shearletIdxs, RMS, dualFrameWeights = getcomplexshearlets2D(	rows, 
                                                                        cols, 
                                                                        1, 
                                                                        3, 
                                                                        1, 
                                                                        0.5,
                                                                        wavelet_eff_support = image_size,
                                                                        gaussian_eff_support = image_size
                                                                        )

shearlets = torch.tensor(shearlets).permute(2, 0, 1).type(torch.complex64).to(0)

In [2]:
from tqdm import tqdm

def train(model, optimizer, loader, accumulate=1):
    model.train()
    loss = torch.nn.CrossEntropyLoss()

    for i, (X, y) in tqdm(enumerate(loader)):
        out = model(X.to(0))
        optimizer.zero_grad()
        l = loss(out, y.to(0)) / accumulate
        l.backward()
        if i % accumulate == (accumulate - 1):
            optimizer.step()
        

def accuracy(output, target, topk=(1,)):
    output = output.to(torch.device('cpu'))
    target = target.to(torch.device('cpu'))
    maxk = max(topk)
    batch_size = target.shape[0]

    _, idx = output.sort(dim=1, descending=True)
    pred = idx.narrow(1, 0, maxk).t()
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(dim=0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


def epoch_accuracy(loader_s, student):
    student.eval()

    out_epoch_s = [accuracy(student(L.to(0)), y)[0].detach().cpu().item() for L, y in loader_s]

    student.train()

    return sum(out_epoch_s) / len(out_epoch_s)

def test(network, test_loader):
    network.eval().to(0)
    test_loss = 0
    correct = 0
    total = 0
    test_losses=[]
    with torch.no_grad():
        for data, target in test_loader:
            output = network(data.to(0))
            test_loss += torch.nn.CrossEntropyLoss()(output, target.to(0)).item()
            pred = output.data.max(1, keepdim=True)[1].cpu()
            correct += pred.eq(target.data.view_as(pred)).sum()
            total += target.shape[0]
        test_loss /= total
        test_losses.append(test_loss)
        print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, total,
        100. * correct / total))

class IndexSubsetDataset:
    def __init__(self, ds, inds):
        self.ds = ds
        self.inds = inds

    def __iter__(self):
        for i in range(len(self.inds)):
            yield self[i]
    
    def __getitem__(self, i):
        return self.ds[self.inds[i]]
    
    def __len__(self):
        return len(self.inds)

In [3]:
batch_size_train = 64

In [4]:
from shearletNN.complex_deit import *
from shearletNN.layers import CReLU

def repeat3(x):
    return x.repeat(3, 1, 1)[:3]

transform = v2.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    repeat3,
])

ds_train = torchvision.datasets.Caltech101('./', transform=transform, download = True)
ds_train = IndexSubsetDataset(ds_train, sum([list(range(len(ds_train)))[i::5] for i in range(1, 5)], []))

ds_val = torchvision.datasets.Caltech101('./', transform=transform, download = True)
ds_val = IndexSubsetDataset(ds_val, list(range(len(ds_val)))[0::5])

train_loader = torch.utils.data.DataLoader(
  ds_train,
  batch_size=batch_size_train, shuffle=True, num_workers=0)

def shearlet_transform(img):
    return frequency_shearlet_transform(img, shearlets, patch_size)

train_loader = ShearletTransformLoader(train_loader, shearlet_transform)

val_loader = torch.utils.data.DataLoader(
  ds_val,
  batch_size=batch_size_train, shuffle=False)

val_loader = ShearletTransformLoader(val_loader, shearlet_transform)

for x, y in tqdm(train_loader):
    assert list(x.shape) == [batch_size_train, shearlets.shape[0] * 3, patch_size, patch_size], x.shape
    break
print('building model...')
model = rope_mixed_deit_small_patch16_LS(img_size=128, in_chans=3, act_layer=CGELU, num_classes=101)
for p in model.parameters():
    assert not p.isnan().any()
optimizer = torch.optim.Adam(model.parameters(), lr=4e-4)
print('training model...')
for epoch in range(32):
    print('epoch', epoch)
    train(model.to(0), optimizer, train_loader, accumulate=4)
    gc.collect()
    test(model, train_loader)
    test(model, val_loader)



Files already downloaded and verified
Files already downloaded and verified


0it [00:00, ?it/s]


AssertionError: torch.Size([64, 3, 128, 128])

In [None]:
model = rope_mixed_ape_deit_small_patch16_LS(img_size=64, in_chans=shearlets.shape[0] * 3, act_layer=CGELU)

for p in model.parameters():
    assert not p.isnan().any()

In [None]:
model(x)

RuntimeError: Input type (CUDAComplexFloatType) and weight type (CPUComplexFloatType) should be the same

In [None]:
def repeat3(x):
    return x.repeat(3, 1, 1)[:3]

transform = v2.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    repeat3,
])

ds_train = torchvision.datasets.Caltech101('./', transform=transform, download = True)
ds_train = IndexSubsetDataset(ds_train, sum([list(range(len(ds_train)))[i::5] for i in range(1, 5)], []))

ds_val = torchvision.datasets.Caltech101('./', transform=transform, download = True)
ds_val = IndexSubsetDataset(ds_val, list(range(len(ds_val)))[0::5])

train_loader = torch.utils.data.DataLoader(
  ds_train,
  batch_size=batch_size_train, shuffle=True, num_workers=0)

def shearlet_transform(img):
    return frequency_shearlet_transform(img, shearlets, patch_size)

train_loader = ShearletTransformLoader(train_loader, shearlet_transform)

val_loader = torch.utils.data.DataLoader(
  ds_val,
  batch_size=batch_size_train, shuffle=False)

val_loader = ShearletTransformLoader(val_loader, shearlet_transform)

for x, y in tqdm(train_loader):
    assert list(x.shape) == [batch_size_train, shearlets.shape[0] * 3, patch_size, patch_size], x.shape
    break
print('building model...')
model = complex_resnet18(in_dim=shearlets.shape[0] * 3, complex=True, phasor=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
print('training model...')
for epoch in range(16):
    print('epoch', epoch)
    train(model.to(0), optimizer, train_loader, accumulate=4)
    gc.collect()
    test(model, train_loader)
    test(model, val_loader)

Files already downloaded and verified
Files already downloaded and verified


0it [00:00, ?it/s]


building model...
training model...
epoch 0


109it [00:34,  3.20it/s]



Test set: Avg. loss: 0.0667, Accuracy: 718/6941 (10%)


Test set: Avg. loss: 0.0690, Accuracy: 184/1736 (11%)

epoch 1


109it [00:33,  3.23it/s]



Test set: Avg. loss: 0.0617, Accuracy: 1468/6941 (21%)


Test set: Avg. loss: 0.0652, Accuracy: 359/1736 (21%)

epoch 2


109it [00:33,  3.23it/s]



Test set: Avg. loss: 0.0578, Accuracy: 1570/6941 (23%)


Test set: Avg. loss: 0.0627, Accuracy: 369/1736 (21%)

epoch 3


109it [00:33,  3.25it/s]



Test set: Avg. loss: 0.0526, Accuracy: 1973/6941 (28%)


Test set: Avg. loss: 0.0593, Accuracy: 444/1736 (26%)

epoch 4


109it [00:34,  3.18it/s]



Test set: Avg. loss: 0.0484, Accuracy: 2275/6941 (33%)


Test set: Avg. loss: 0.0562, Accuracy: 509/1736 (29%)

epoch 5


109it [00:33,  3.23it/s]



Test set: Avg. loss: 0.0459, Accuracy: 2535/6941 (37%)


Test set: Avg. loss: 0.0553, Accuracy: 491/1736 (28%)

epoch 6


109it [00:33,  3.25it/s]



Test set: Avg. loss: 0.0406, Accuracy: 3207/6941 (46%)


Test set: Avg. loss: 0.0518, Accuracy: 553/1736 (32%)

epoch 7


109it [00:33,  3.24it/s]



Test set: Avg. loss: 0.0373, Accuracy: 3300/6941 (48%)


Test set: Avg. loss: 0.0516, Accuracy: 577/1736 (33%)

epoch 8


109it [00:33,  3.25it/s]



Test set: Avg. loss: 0.0346, Accuracy: 3563/6941 (51%)


Test set: Avg. loss: 0.0508, Accuracy: 568/1736 (33%)

epoch 9


109it [00:33,  3.23it/s]



Test set: Avg. loss: 0.0311, Accuracy: 3961/6941 (57%)


Test set: Avg. loss: 0.0497, Accuracy: 602/1736 (35%)

epoch 10


109it [00:33,  3.26it/s]



Test set: Avg. loss: 0.0268, Accuracy: 4605/6941 (66%)


Test set: Avg. loss: 0.0498, Accuracy: 612/1736 (35%)

epoch 11


109it [00:33,  3.26it/s]



Test set: Avg. loss: 0.0242, Accuracy: 4916/6941 (71%)


Test set: Avg. loss: 0.0483, Accuracy: 611/1736 (35%)

epoch 12


109it [00:33,  3.27it/s]



Test set: Avg. loss: 0.0204, Accuracy: 5310/6941 (77%)


Test set: Avg. loss: 0.0478, Accuracy: 630/1736 (36%)

epoch 13


109it [00:33,  3.27it/s]



Test set: Avg. loss: 0.0173, Accuracy: 5544/6941 (80%)


Test set: Avg. loss: 0.0473, Accuracy: 638/1736 (37%)

epoch 14


109it [00:33,  3.27it/s]



Test set: Avg. loss: 0.0156, Accuracy: 5723/6941 (82%)


Test set: Avg. loss: 0.0478, Accuracy: 618/1736 (36%)

epoch 15


109it [00:33,  3.27it/s]



Test set: Avg. loss: 0.0130, Accuracy: 5916/6941 (85%)


Test set: Avg. loss: 0.0481, Accuracy: 632/1736 (36%)



In [None]:
# amplitude:
amp = torch.abs(x)
print((amp.real.max(), amp.real.min()))

phase = torch.arctan(x.imag / x.real)

phase = torch.nan_to_num(torch.arctan(x.imag / x.real), posinf=torch.math.pi / 2, neginf= -torch.math.pi / 2)

print((phase.real.max(), phase.real.min()))

def to_magnitude_phase(x):
    """
    return magnitude/phase representation 
    """
    return torch.complex(torch.abs(x), torch.nan_to_num(torch.arctan(x.imag / x.real), posinf=torch.math.pi / 2, neginf= -torch.math.pi / 2))

(tensor(5.4802, device='cuda:0'), tensor(0., device='cuda:0'))
(tensor(1.5708, device='cuda:0'), tensor(-1.5708, device='cuda:0'))


In [None]:
x = torch.tensor([[1.5,.0,.0,.0]])
layerNorm = torch.nn.LayerNorm(4, elementwise_affine = False)
y1 = layerNorm(x)
mean = x.mean(-1, keepdim = True)
var = x.var(-1, keepdim = True, unbiased=False)
y2 = (x-mean)/torch.sqrt(var+layerNorm.eps)

torch.allclose(y1, y2)

True

In [None]:
x = torch.randn((50,20,100))

layerNorm = torch.nn.LayerNorm(x.shape[-1], elementwise_affine = True)
y1 = layerNorm(x)

mean = torch.mean(x, dim=-1, keepdim=True)
var = torch.square(x - mean).mean(dim=-1, keepdim=True)
y2 = ((x - mean) / torch.sqrt(var + layerNorm.eps)) * layerNorm.weight + layerNorm.bias

torch.allclose(y1, y2, atol=1e-5, rtol=1e-5)

True