In [123]:
import torch
import torch.fft as fft
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
import numpy as np
import warnings
warnings.filterwarnings('ignore')
def imscatter(X, images, zoom=2):
    ax = plt.gca()
    for i, img in enumerate(images):
        x, y = X[i, :]
        im = OffsetImage(img.numpy().T, zoom=zoom)
        ab = AnnotationBbox(im, (x, y), xycoords='data', frameon=False)
        ax.add_artist(ab)
    ax.update_datalim(X)
    ax.autoscale()
    ax.set_xticks([])
    ax.set_yticks([])

    
def plotMNIST(images, zoom = 2):
    N = len(images)
    dgrid = int(np.ceil(np.sqrt(N)))
    ex = np.arange(dgrid)
    x, y = np.meshgrid(ex, ex)
    X = np.zeros((N, 2))
    X[:, 0] = x.flatten()[0:N]
    X[:, 1] = y.flatten()[0:N]
    imscatter(X, images, zoom)

        
class TransformLowPass(object):

    def __init__(self):
        self.fft = True
        self.lowpass = True
        self.flatten = True
        self.norm = False
        rows, cols = 32,32 
        crow, ccol = int(rows / 2), int(cols / 2)
        mask = np.zeros((rows, cols), np.uint8)
        r = 6
        center = [crow, ccol]
        x, y = np.ogrid[:rows, :cols]
        mask_area = (x - center[0]) ** 2 + (y - center[1]) ** 2 <= r*r
        mask[mask_area] = 1
        self.mask = torch.from_numpy(mask).bool()
#         print(torch.masked_select(self.mask,self.mask).shape)
        
    def __call__(self, img):
        if self.fft:
            img = fft.fft2(img)
        if self.lowpass:
            img = img*self.mask
        if self.flatten:
            img = torch.masked_select(img,self.mask)
        if self.norm:
            img = torch.abs(img)
        return img

In [124]:

epochs = 0
batch_size = 100
lr = .001
# ,transforms.Normalize((.5,.5,.5),(.5,.5,.5)),
transform = transforms.Compose([transforms.ToTensor(),TransformLowPass()])

train_dataset = torchvision.datasets.CIFAR10(root="./data",train=True, download=True,transform=transform)
test_dataset = torchvision.datasets.CIFAR10(root="./data",train=False, download=True,transform=transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

classes = ('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')

device = torch.device('cpu')


Files already downloaded and verified
Files already downloaded and verified


In [125]:
images, labels = iter(train_loader).next()
images[0].shape

torch.Size([339])

In [126]:

# plotMNIST(images,zoom = 2)

In [127]:
from complexPyTorch.complexLayers import ComplexBatchNorm2d, ComplexConv2d, ComplexLinear
from complexPyTorch.complexFunctions import complex_relu, complex_max_pool2d

class NeuralNet(nn.Module):
    
    def __init__(self):
        super(NeuralNet, self).__init__()
        self.fc1 = ComplexLinear(339,10)   
#         self.fc2 = ComplexLinear(100,100)   
#         self.fc3 = ComplexLinear(100,10)   
             
    def forward(self,x):
        x = complex_relu(self.fc1(x))
#         x = self.fc3(x)
        x = x.abs()
        return x
        
model = NeuralNet().to(device)
print(sum(p.numel() for p in model.parameters() if p.requires_grad))

criterion = nn.CrossEntropyLoss()


im = model(images[0])


6800


In [128]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(),lr=lr)

for epoch in range(50):
    for batch_idx, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        
        output = model(images)
        loss = criterion(output, labels)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if batch_idx % 100 == 0:
            print('Train Epoch: {:3} [{:6}/{:6} ({:3.0f}%)]\tLoss: {:.6f}'.format(
                epoch,
                batch_idx * len(images), 
                len(train_loader.dataset),
                100. * batch_idx / len(train_loader), 
                loss.item())
            )

# Test
with torch.no_grad():
    n_correct=0 
    n_samples = 0
    n_class_correct = [0 for _ in range(10)]
    n_class_samples = [0 for _ in range(10)]
    
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        output = model(images)
        
        _,predicted = torch.max(output,1)
        n_samples += labels.size(0)
        n_correct += (predicted == labels).sum().item()
        
        for i in range(batch_size):
            label = labels[i]
            pred = predicted[i]
            
            if (label == pred):
                n_class_correct[label]+=1
            n_class_samples[label] += 1

acc = 100.0 *n_correct / n_samples
print(f"Accuracy of network: {acc}")
for i in range(10):
    acc = 100.0 *n_class_correct[i] / n_class_samples[i]
    print(f"Accuracy of class {classes[i]}: {acc}")
    



Accuracy of network: 13.04
Accuracy of class plane: 12.5
Accuracy of class car: 16.7
Accuracy of class bird: 6.3
Accuracy of class cat: 18.3
Accuracy of class deer: 8.0
Accuracy of class dog: 6.3
Accuracy of class frog: 23.6
Accuracy of class horse: 15.4
Accuracy of class ship: 6.6
Accuracy of class truck: 16.7
