In [6]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR

import numpy as np


In [2]:
# includes residual
       
class Lconv_core(nn.Module):
    """ L-conv layer with full L """
    def __init__(self,d,num_L=1,cin=1,cout=1,rank=8):
        """
        L:(num_L, d, d)
        Wi: (num_L, cout, cin)
        """
        super().__init__()
        self.L = nn.Parameter(torch.Tensor(num_L, d, d))
        self.Wi = nn.Parameter(torch.Tensor(num_L, cout, cin))
        
        # initialize weights and biases
        nn.init.kaiming_normal_(self.L) 
        nn.init.kaiming_normal_(self.Wi)
                
    def forward(self, x):
        # x:(batch, channel, flat_d)
        # h = (x + Li x Wi) W0
        y = torch.einsum('kdf,bcf,koc->bod', self.L, x, self.Wi ) +x #+ self.b        
        return y


In [30]:
class Net(nn.Module):
    def __init__(self,d,c,device,k=1, rec=1, hid = 5):
        super().__init__()
        self._rec = rec
        self._c = c
#         with torch.no_grad():
        self.ones = torch.ones((1,c,1)).to(device)
        self.LC = Lconv_core(d=d,num_L=k,cin=c,cout=c)
        self.flat = nn.Flatten(2)
        self.lin1 = nn.Linear(c,hid)
        self.out = nn.Linear(hid,1)
        
    def forward(self, inp):
        print(inp.shape)
        x,y = inp
        # x,y->(batch, channel, flat_d)
        x = self.flat(x)
        y = self.flat(y)
        # copy input to c channels
        x = x * self.ones
        # pass through L-conv
        for _ in range(self._rec):
            x = self.LC(x)
        x = torch.einsum('bcd,bad->bc', x, y)
        x = F.tanh(x)
        x = self.lin1(x)
        x = F.tanh(x) 
        return self.out(x)

In [31]:
class Reshape(nn.Module):
    def __init__(self,shape=None):
        self.shape = shape
        super().__init__()
    def forward(self,x):
        return x.view(-1,*self.shape)

In [32]:
from time import time
class trainer():
    def __init__(self, model, device, optimizer, dataset_class = None, 
                 train_loader=None, test_loader=None,
                 batch_size = 64, test_batch_size = 1000,loss_func = F.nll_loss,
                ):
        """
        usage:
            t = trainer(...)
            t.fit(epochs)
            
        methods:
            .fit(epochs) : train + test; print results; stores results in trainer.history <dict>
            .train(epoch)
            .test()
        """
        self.device = device #torch.device(device)
        self.optimizer = optimizer
        self.model = model
        self.loss_func = loss_func

        self.scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
        self.history = {'train loss':[], 'test loss':[], 'train acc':[], 'test acc':[], 'train time':[]}
        
        if dataset_class:
            self.make_dataloaders(dataset_class, batch_size, test_batch_size)
        else:
            self.train_loader = train_loader 
            self.test_loader = test_loader
        
    def make_dataloaders(self, dataset_class, batch_size, test_batch_size):
        train_kwargs = {'batch_size': batch_size}
        test_kwargs = {'batch_size': test_batch_size}
        if self.device.type =='cuda':        
            cuda_kwargs = {'num_workers': 1,
                           'pin_memory': True,
                           'shuffle': True}
            train_kwargs.update(cuda_kwargs)
            test_kwargs.update(cuda_kwargs)
        
        print('Creating data loaders...',end='')
        dataset1 = dataset_class('../data', train=True, download=True,)
        dataset2 = dataset_class('../data', train=False,)

        self.train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
        self.test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
        print('Done')

#         self.optimizer = optimizer
#         self.model, self.device, self.train_loader,  = model, device, train_loader,
#         self.test_loader = test_loader
        
    def progbar(self,percent, N=10):
        n = int(percent//N)
        return '[' + '='*n + '>' +'.'*(N-n-1) +']'
    
    def train(self,epoch):
        self.model.train()
        training_loss = 0
        correct = 0
        t0 = time()
        for batch_idx, (data, target) in enumerate(self.train_loader):
            if type(data)==list:
                data = [d.to(self.device) for d in data]
            else:
                data = data.to(self.device)
            if type(target)==list:
                target = [d.to(self.device) for d in target]
            else:
                target = target.to(self.device)
                
#             data, target = data.to(self.device), target.to(self.device)
            self.optimizer.zero_grad()
            print(self.model(data).shape)
            output = self.model(data)
            loss = self.loss_func(output, target)
            training_loss += loss.sum().item() # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

            loss.backward(retain_graph=True)
            self.optimizer.step()
            if batch_idx % 10 == 0:
                perc = 100. * batch_idx / len(self.train_loader)
                t1 = time()
                print('Train Epoch: {} {} {:.1f}s [{}/{} ({:.0f}%)]\tLoss: {:.4g}'.format(
                    epoch, self.progbar(perc), t1-t0,
                    batch_idx * len(data), len(self.train_loader.dataset), # n/N
                    perc, # % passed
                    loss.item()), end='\r')
        
        training_loss /= len(self.train_loader.dataset)
        acc = correct / len(self.train_loader.dataset)    
        print('\nTraining: loss: {:.4g}, Acc: {:.2f}%'.format(training_loss, 100.*acc))
            
        return {'loss':training_loss, 'acc':acc , 'time':t1-t0}
                
        
    def test(self):
        self.model.eval()
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for data, target in self.test_loader:
                if type(data)==list:
                    data = [d.to(self.device) for d in data]
                else:
                    data = data.to(self.device)
                if type(target)==list:
                    target = [d.to(self.device) for d in target]
                else:
                    target = target.to(self.device)
                #data, target = data.to(self.device), target.to(self.device)
                print(self.model(data).shape)
                output = self.model(data)
                
                test_loss += self.loss_func(output, target, reduction='sum').item()  # sum up batch loss
                pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
                correct += pred.eq(target.view_as(pred)).sum().item()

        
        test_loss /= len(self.test_loader.dataset)
        test_acc = correct / len(self.test_loader.dataset)
        
        print('Test loss: {:.4g}, Test acc.: {:.2f}%'.format( test_loss, 100.*test_acc))
        return {'loss':test_loss, 'acc':test_acc}
    
    def fit(self,epochs=1):
        for epoch in range(1, epochs + 1):
            r = self.train(epoch)
            self.history['train loss'] += [r['loss']]
            self.history['train acc'] += [r['acc']]
            self.history['train time'] += [r['time']]
            
            r = self.test()
            self.history['test loss'] += [r['loss']]
            self.history['test acc'] += [r['acc']]
            self.scheduler.step()

In [33]:
import numpy as np 

DIM = 20 

# TODO rather than  generating a bunch of
# small perturbed 
class FixedRot(datasets.VisionDataset):
    num_targets = 10
    def __init__(self,*args,angle =np.pi/3,N=50000,size=(7,7),
                 train=True,dataseed=0, fp=None,**kwargs):
        super().__init__(*args,**kwargs)

        if fp is not None:
            npz = fp
            resx = np.linspace(-2., 1., 20)
            self.data = np.zeros([N, 1, DIM, DIM])
            for i in range(0, N):
                #print(type(npz[i]), npz[i])
                resy = (np.rint(npz[i]*9)).astype(int) // 2
                resy = -resy + 9
                for xi in range(0, DIM):
                    #print(xi, resy.shape)
                    self.data[i][0][resy[xi]][xi] = 1.

        self.data = torch.Tensor(self.data)
        #if not train: 
        #    dataseed += 1
        #    N = int(0.2*N)
        torch.manual_seed(dataseed)
        angles = torch.ones(N)*angle # torch.rand(N)*2*np.pi
        #self.data = torch.rand(N,1,*size)-.5
        print(N, self.data.shape)
        with torch.no_grad():
            # Build affine matrices for random translation of each image
            affineMatrices = torch.zeros(N,2,3)
            affineMatrices[:,0,0] = angles.cos()
            affineMatrices[:,1,1] = angles.cos()
            affineMatrices[:,0,1] = angles.sin()
            affineMatrices[:,1,0] = -angles.sin()
            
            flowgrid = F.affine_grid(affineMatrices, size = self.data.shape) # self.data.size()
            self.data_rot = F.grid_sample(self.data, flowgrid)

    def __getitem__(self,idx):
        return self.data[idx], self.data_rot[idx] # , self.data[idx]
    
    def __len__(self):
        return len(self.data)
    
    def default_aug_layers(self):
        return RandomRotateTranslate(0)# no translation


In [34]:
npz = np.load("ode10000x20.npz")
DIM  = 20 
# dataset = CIFAR100
dataset = FixedRot #AugRotMNIST__0
# dataset = AugRotMNIST__0
num_targets = dataset.num_targets

batch_size = 150

test_batch_size = 200 #1000 #150
device = torch.device("cuda")
# device = torch.device("cpu")

train_kwargs = {'batch_size': batch_size}
test_kwargs = {'batch_size': test_batch_size}
# if use_cuda:
cuda_kwargs = {'num_workers': 1,
               'pin_memory': True,
               'shuffle': True}
train_kwargs.update(cuda_kwargs)
test_kwargs.update(cuda_kwargs)

ang_n = 45
ang = np.pi/ang_n
dataset1 = dataset('../data', fp=npz['arr_0'], N=8500, train=True, size=(DIM,DIM), angle=ang)
dataset2 = dataset('../data', fp=npz['arr_0'], N=2500, train=False, size=(DIM,DIM), angle=ang)

train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)


8500 torch.Size([8500, 1, 20, 20])
2500 torch.Size([2500, 1, 20, 20])




In [35]:
for (x,y) in train_loader:
    x = x.to(device)
    y = y.to(device)

In [36]:
x.shape, y.shape

(torch.Size([100, 1, 20, 20]), torch.Size([100, 1, 20, 20]))

In [37]:
s = dataset1.data[0].shape
d = np.prod(s)
# lc = Lconv_2(k=2,d=d, cin=1, cout=1, rank=50)
rec = 8
# model = Net(d,10,k=1,device=device,rec=rec).to(device)

model = Net(d,10,k=1,device=device,rec=rec).to(device)

optimizer = optim.Adam(model.parameters(),lr = 5e-3)


In [38]:
t = trainer(model, device, optimizer, train_loader=train_loader, test_loader=test_loader, 
            #dataset_class=AugRotMNIST__0, #dataset, 
            loss_func=F.mse_loss)
t.fit(20)

torch.Size([150, 1, 20, 20])


ValueError: too many values to unpack (expected 2)