In [78]:
import torch, os, argparse, warnings, copy
import pytorch_lightning as pl
from tqdm.notebook import tqdm
from data_loader import load_train_data, load_test_data
from models import CConvNet, S2ConvNet

In [2]:
# Number of training samples; it is also embedded in the training file name below.
TRAIN_SAMPLES = 20000
# Epoch budget passed to the trainer. A value of 20 was used previously:
# MAX_EPOCHS = 20
MAX_EPOCHS = 3
# Currently unused settings kept for reference (presumably early stopping — TODO confirm):
# MIN_DELTA = 0.
# PATIENCE = 10

# Gzipped dataset files; the training file name depends on TRAIN_SAMPLES.
TRAIN_PATH = f"flat_mnist_train_28x28_{TRAIN_SAMPLES}.gz"
TEST_PATH = "flat_mnist_test_28x28.gz"

In [3]:
# Load the training and test splits from the configured paths.
# NOTE(review): the test file is also read with load_train_data although
# load_test_data is imported — confirm this loader choice is intentional.
train_data, test_data = load_train_data(TRAIN_PATH), load_train_data(TEST_PATH)

In [4]:
# Sanity check: report how many examples each split contains.
print(f"Total training examples: {len(train_data)}")
print(f"Total test examples: {len(test_data)}")

Total training examples: 20000
Total test examples: 10000


In [65]:
# Hyper-parameters handed to CConvNet, built in a single constructor call.
hparams = argparse.Namespace(
    # run identification, data loading and optimisation settings
    name='test_model',
    train_batch_size=32,
    test_batch_size=32,
    num_workers=0,
    lr=1e-3,
    weight_decay=0.,
    # architecture settings: one entry per layer in channels/kernels/strides
    channels=[16, 24, 32, 64],
    kernels=[3, 3, 3, 3],
    strides=[1, 1, 1, 1],
    activation_fn='LeakyReLU',
    batch_norm=True,
    nodes=[64, 32],
)

In [66]:
# Instantiate CConvNet with the hyper-parameters and datasets defined above.
model = CConvNet(hparams, train_data, test_data)
# Train on one GPU when CUDA is available; otherwise fall back to the CPU
# (gpus=None) so the notebook still runs on machines without a GPU.
trainer = pl.Trainer(gpus=1 if torch.cuda.is_available() else None, max_epochs=MAX_EPOCHS)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]


In [67]:
# Report the model's trainable-parameter count (recorded output: 35970).
model.count_trainable_parameters()

35970

In [None]:
# Train `model` with the trainer configured above (at most MAX_EPOCHS epochs).
trainer.fit(model)

In [None]:
# Run the trainer's test loop on `model`.
trainer.test(model)

In [None]:
# Paths of two additional dataset files, loaded in the following cell.
# NOTE(review): only path2 carries the 'datasets/' directory prefix — confirm
# that both locations are correct.
path1 = 's2_mnist_cs1.gz'
path2 = 'datasets/s2_mnist_test_sphere_center.gz'

In [None]:
# Load the two additional datasets.
# NOTE(review): path2 names a *test* file but is read with load_train_data,
# while path1 is read with load_test_data — confirm the loader choice.
data1, data2 = load_test_data(path1), load_train_data(path2)

In [None]:
# Scratch check of the flattening used in MLP.forward:
# take the input of the first training sample and add a leading axis of size 1.
x = train_data[0][0].unsqueeze(0)
xs = x.size()
print(xs)
# Keep the leading axis and collapse every remaining axis into one.
x = x.reshape(xs[0], -1)
print(x.size())

In [68]:
class MLP(pl.LightningModule):
    """Fully connected classifier with 28*28 input features and 10 outputs.

    The network is a stack of ``Linear`` layers whose hidden sizes come from
    ``hparams.nodes``. Every hidden layer is followed by the configured
    activation, and when ``hparams.batch_norm`` is true a ``BatchNorm1d`` is
    placed in front of every ``Linear`` layer except the first one.

    Args:
        hparams: namespace providing ``activation_fn`` ('ReLU' or 'LeakyReLU'),
            ``batch_norm``, ``nodes`` (list of hidden layer sizes),
            ``train_batch_size``, ``test_batch_size``, ``num_workers``,
            ``lr`` and ``weight_decay``. A deep copy is stored on the module.
        train_data: dataset served by ``train_dataloader``.
        test_data: dataset served by both ``val_dataloader`` and
            ``test_dataloader``.

    Raises:
        AssertionError: if ``hparams.activation_fn`` is not supported.
    """

    def __init__(self, hparams, train_data, test_data):
        super().__init__()

        self.hparams = copy.deepcopy(hparams)
        self.train_data = train_data
        self.test_data = test_data

        self.activation_fn = self.hparams.activation_fn
        self.batch_norm = self.hparams.batch_norm
        self.nodes = self.hparams.nodes.copy()

        self.loss_function = torch.nn.CrossEntropyLoss()

        # Resolve the activation class once instead of re-checking it per layer.
        activation_classes = {'ReLU': torch.nn.ReLU, 'LeakyReLU': torch.nn.LeakyReLU}
        possible_activation_fns = list(activation_classes)
        assert self.activation_fn in possible_activation_fns
        activation_cls = activation_classes.get(self.activation_fn)
        if activation_cls is None:
            # Only reachable when assertions are disabled (python -O).
            raise NotImplementedError(f"Activation function must be in {possible_activation_fns}.")

        # Full list of layer widths: input features, hidden sizes, class scores.
        self.nodes.insert(0, 28*28)
        self.nodes.append(10)

        module_list = []
        last_layer = len(self.nodes) - 2
        for i, (in_nodes, out_nodes) in enumerate(zip(self.nodes[:-1], self.nodes[1:])):
            # Batch norm is applied to the input of every layer but the first.
            if self.batch_norm and i > 0:
                module_list.append(torch.nn.BatchNorm1d(in_nodes))
            module_list.append(torch.nn.Linear(in_features=in_nodes, out_features=out_nodes))
            # The final layer emits raw scores, so it gets no activation.
            if i != last_layer:
                module_list.append(activation_cls())

        self.dense = torch.nn.Sequential(*module_list)

    def forward(self, x):
        """Flatten everything after the first axis and apply the dense stack."""
        x = x.reshape(x.size(0), -1)
        return self.dense(x)

    def loss(self, x, y_true):
        """Return the cross-entropy loss of the predictions for ``x``."""
        return self.loss_function(self(x), y_true)

    def correct_predictions(self, x, y_true):
        """Return the number of samples in ``x`` classified as ``y_true``."""
        return self._count_correct(self(x), y_true)

    @staticmethod
    def _count_correct(outputs, y_true):
        """Count rows of ``outputs`` whose arg-max equals the target label."""
        _, y_pred = torch.max(outputs, 1)
        return (y_pred == y_true).long().sum()

    def _loss_and_correct(self, x, y_true):
        """Compute loss and correct count from a single forward pass.

        Running the network once per batch avoids doubling the compute and,
        in training mode, avoids updating the BatchNorm running statistics
        twice for the same batch.
        """
        outputs = self(x)
        return self.loss_function(outputs, y_true), self._count_correct(outputs, y_true)

    @staticmethod
    def _epoch_metrics(outputs, loss_key, correct_key, num_samples):
        """Aggregate step outputs into ``(mean batch loss, accuracy)`` floats.

        The division is done on Python numbers: dividing an integer tensor by
        an int is integer division on PyTorch versions before 1.6, which would
        collapse the accuracy to 0 or 1.
        """
        avg_loss = torch.stack([out[loss_key] for out in outputs]).mean().item()
        correct = torch.stack([out[correct_key] for out in outputs]).sum().item()
        return avg_loss, correct / num_samples

    def prepare_data(self):
        """Nothing to prepare: the datasets are passed to the constructor."""
        pass

    def train_dataloader(self):
        """Shuffled loader over ``train_data``."""
        return torch.utils.data.DataLoader(dataset=self.train_data,
                                           batch_size=self.hparams.train_batch_size,
                                           shuffle=True, num_workers=self.hparams.num_workers)

    def val_dataloader(self):
        """Unshuffled loader over ``test_data`` (validation reuses the test set)."""
        return torch.utils.data.DataLoader(dataset=self.test_data,
                                           batch_size=self.hparams.test_batch_size,
                                           shuffle=False, num_workers=self.hparams.num_workers)

    def test_dataloader(self):
        """Unshuffled loader over ``test_data``."""
        return torch.utils.data.DataLoader(dataset=self.test_data,
                                           batch_size=self.hparams.test_batch_size,
                                           shuffle=False, num_workers=self.hparams.num_workers)

    def configure_optimizers(self):
        """Create the AdamW optimizer from ``hparams.lr`` and ``hparams.weight_decay``."""
        self._optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.lr,
                                            weight_decay=self.hparams.weight_decay, amsgrad=False)

        return {'optimizer': self._optimizer}

    def training_step(self, batch, batch_idx):
        """Return the batch loss, the number of correct predictions and the log entry."""
        x, y = batch
        loss, correct = self._loss_and_correct(x, y)

        logs = {'loss': loss.item()}
        return {'loss': loss, 'train_correct': correct, 'log': logs}

    def training_epoch_end(self, outputs):
        """Summarise the epoch as mean batch loss and accuracy over ``train_data``."""
        avg_loss, train_acc = self._epoch_metrics(outputs, 'loss', 'train_correct',
                                                  len(self.train_data))

        logs = {'train_loss': avg_loss, 'train_acc': train_acc}
        return {'train_loss': avg_loss, 'train_acc': train_acc, 'log': logs}

    def validation_step(self, batch, batch_idx):
        """Return the batch loss and the number of correct predictions."""
        x, y = batch
        loss, correct = self._loss_and_correct(x, y)
        return {'val_loss': loss, 'val_correct': correct}

    def validation_epoch_end(self, outputs):
        """Summarise validation as mean batch loss and accuracy over ``test_data``."""
        avg_loss, val_acc = self._epoch_metrics(outputs, 'val_loss', 'val_correct',
                                                len(self.test_data))

        logs = {'val_loss': avg_loss, 'val_acc': val_acc}
        return {'val_loss': avg_loss, 'val_acc': val_acc, 'log': logs}

    def test_step(self, batch, batch_idx):
        """Return the batch loss and the number of correct predictions."""
        x, y = batch
        loss, correct = self._loss_and_correct(x, y)
        return {'test_loss': loss, 'test_correct': correct}

    def test_epoch_end(self, outputs):
        """Summarise testing as mean batch loss and accuracy over ``test_data``."""
        avg_loss, test_acc = self._epoch_metrics(outputs, 'test_loss', 'test_correct',
                                                 len(self.test_data))

        logs = {'test_loss': avg_loss, 'test_acc': test_acc}
        return {'test_loss': avg_loss, 'test_acc': test_acc, 'log': logs}

    def get_progress_bar_dict(self):
        """Build the progress-bar entries: running loss and learning rate in
        scientific notation, plus split index and logger version when set."""
        running_train_loss = self.trainer.running_loss.mean()
        avg_training_loss = running_train_loss.cpu().item() if running_train_loss is not None else float('NaN')
        lr = self.hparams.lr

        tqdm_dict = {
            'loss': '{:.2E}'.format(avg_training_loss),
            'lr': '{:.2E}'.format(lr),
        }

        if self.trainer.truncated_bptt_steps is not None:
            tqdm_dict['split_idx'] = self.trainer.split_idx

        if self.trainer.logger is not None and self.trainer.logger.version is not None:
            tqdm_dict['v_num'] = self.trainer.logger.version

        return tqdm_dict

    def count_trainable_parameters(self):
        """Total number of elements in parameters with ``requires_grad`` set."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def count_parameters(self):
        """Total number of elements across all parameters."""
        return sum(p.numel() for p in self.parameters())

In [101]:
# Hyper-parameters handed to the MLP, built in a single constructor call.
MLP_hparams = argparse.Namespace(
    # run identification, data loading and optimisation settings
    name='test_model',
    train_batch_size=32,
    test_batch_size=32,
    num_workers=0,
    lr=1e-3,
    weight_decay=0.,
    # architecture settings
    activation_fn='LeakyReLU',
    batch_norm=True,
    # hidden layer sizes; a smaller alternative tried earlier was [44, 23]
    nodes=[500, 150],
)

In [102]:
# Instantiate the MLP with MLP_hparams and the datasets loaded earlier.
MLP_model = MLP(MLP_hparams, train_data, test_data)

In [103]:
# Number of parameters with requires_grad set (recorded output: 470460).
MLP_model.count_trainable_parameters()

470460

In [48]:
# Display the module structure of MLP_model.
# NOTE(review): the output recorded below comes from an earlier execution
# (In [48]) and shows hidden layers of 42 and 22 units, not the [500, 150]
# configured in MLP_hparams — re-run this cell to refresh it.
MLP_model

MLP(
  (loss_function): CrossEntropyLoss()
  (dense): Sequential(
    (0): Linear(in_features=784, out_features=42, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): BatchNorm1d(42, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Linear(in_features=42, out_features=22, bias=True)
    (4): LeakyReLU(negative_slope=0.01)
    (5): BatchNorm1d(22, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): Linear(in_features=22, out_features=10, bias=True)
  )
)

In [90]:
# Hyper-parameters handed to S2ConvNet, built in a single constructor call.
S2_hparams = argparse.Namespace(
    # run identification, data loading and optimisation settings
    name='test_model',
    train_batch_size=32,
    test_batch_size=32,
    num_workers=0,
    lr=1e-4,
    weight_decay=0.,
    # architecture settings: one entry per layer in channels/bandlimit/kernel_max_beta
    channels=[8, 11, 64],
    bandlimit=[14, 8, 2],
    kernel_max_beta=[0.05, 0.125, 0.5],
    activation_fn='LeakyReLU',
    batch_norm=True,
    nodes=[64, 32],
)

In [91]:
# Build the S2ConvNet from S2_hparams and the datasets loaded earlier, then
# report its trainable-parameter count (recorded output: 35533).
S2_model = S2ConvNet(S2_hparams, train_data, test_data)
S2_model.count_trainable_parameters()

35533