In [1]:
import torch, argparse, gzip, os, pickle, warnings, copy, time
import numpy as np, pytorch_lightning as pl
from tqdm.notebook import tqdm
from s2cnn import s2_near_identity_grid, so3_near_identity_grid, SO3Convolution, S2Convolution, so3_integrate

warnings.filterwarnings('ignore')

In [2]:
# Gzipped pickle that load_test_data() reads the held-out evaluation set from.
TEST_PATH = "s2_mnist.gz"
# Upper bound on training epochs handed to pl.Trainer(max_epochs=...) in train_model().
MAX_EPOCHS = 20

In [3]:
class S2ConvNet(pl.LightningModule):
    """Classifier made of one S2 convolution, a stack of SO(3) convolutions and a dense head.

    The convolutional part is ``S2Convolution -> act -> (SO3Convolution -> act) * k``;
    its output is reduced with ``so3_integrate`` and fed to a stack of
    ``[BatchNorm1d ->] Linear [-> act]`` layers that ends in ``NUM_CLASSES`` logits.

    Expected ``hparams`` attributes (read in this class):
        channels: list of output channel counts, one per convolution layer.
        bandlimit: list of output bandlimits, one per convolution layer.
        kernel_max_beta: list with one value per convolution layer, or a single
            number that is reused for every layer. Each value is multiplied by pi
            before being passed to the s2cnn grid helpers.
        activation_fn: ``'ReLU'`` or ``'LeakyReLU'``.
        batch_norm: if truthy, a ``BatchNorm1d`` precedes every ``Linear`` layer.
        nodes: list of hidden widths of the dense head.
        train_batch_size, test_batch_size, num_workers: DataLoader settings.
        lr, weight_decay: AdamW settings.

    Args:
        hparams: namespace carrying the attributes listed above. The lists it
            holds are copied, never modified.
        train_data: dataset served by ``train_dataloader``.
        test_data: dataset served by both ``val_dataloader`` and ``test_dataloader``.
    """

    INPUT_CHANNELS = 1      # greyscale input
    INPUT_BANDLIMIT = 30    # depends on image size
    NUM_CLASSES = 10        # width of the final Linear layer
    POSSIBLE_ACTIVATION_FNS = ['ReLU', 'LeakyReLU']

    def __init__(self, hparams, train_data, test_data):
        super().__init__()

        self.hparams = hparams
        self.train_data = train_data
        self.test_data = test_data

        self.activation_fn = hparams.activation_fn
        self.batch_norm = hparams.batch_norm

        self.loss_function = torch.nn.CrossEntropyLoss()

        # A scalar kernel_max_beta is broadcast to one entry per convolution layer.
        kernel_max_beta = hparams.kernel_max_beta
        if isinstance(kernel_max_beta, (int, float)):
            kernel_max_beta = [kernel_max_beta] * len(hparams.channels)
        self.kernel_max_beta = list(kernel_max_beta)

        assert len(hparams.channels) == len(hparams.bandlimit) == len(self.kernel_max_beta)
        assert self.activation_fn in self.POSSIBLE_ACTIVATION_FNS

        # Build new lists instead of insert()/append() on the hparams lists, so the
        # namespace passed in by the caller is left exactly as it was.
        self.channels = [self.INPUT_CHANNELS] + list(hparams.channels)
        self.bandlimit = [self.INPUT_BANDLIMIT] + list(hparams.bandlimit)
        self.nodes = [self.channels[-1]] + list(hparams.nodes) + [self.NUM_CLASSES]

        grid_s2 = s2_near_identity_grid(max_beta=self.kernel_max_beta[0] * np.pi, n_alpha=6, n_beta=1)
        grids_so3 = [
            so3_near_identity_grid(max_beta=max_beta * np.pi, n_alpha=6, n_beta=1, n_gamma=6)
            for max_beta in self.kernel_max_beta[1:]
        ]

        # --- convolutional part: the first layer is S2, all later ones are SO(3) ---
        conv_layers = [
            S2Convolution(
                nfeature_in=self.channels[0],
                nfeature_out=self.channels[1],
                b_in=self.bandlimit[0],
                b_out=self.bandlimit[1],
                grid=grid_s2,
            ),
            self._make_activation(),
        ]
        for i in range(1, len(self.channels) - 1):
            conv_layers.append(
                SO3Convolution(
                    nfeature_in=self.channels[i],
                    nfeature_out=self.channels[i + 1],
                    b_in=self.bandlimit[i],
                    b_out=self.bandlimit[i + 1],
                    grid=grids_so3[i - 1],
                )
            )
            conv_layers.append(self._make_activation())
        self.conv = torch.nn.Sequential(*conv_layers)

        # --- dense head: no activation after the final (logit) layer ---
        dense_layers = []
        last_layer = len(self.nodes) - 2
        for i, (in_nodes, out_nodes) in enumerate(zip(self.nodes[:-1], self.nodes[1:])):
            if self.batch_norm:
                dense_layers.append(torch.nn.BatchNorm1d(in_nodes))
            dense_layers.append(torch.nn.Linear(in_features=in_nodes, out_features=out_nodes))
            if i != last_layer:
                dense_layers.append(self._make_activation())
        self.dense = torch.nn.Sequential(*dense_layers)

    def _make_activation(self):
        """Return a fresh activation module matching ``self.activation_fn``."""
        if self.activation_fn == 'ReLU':
            return torch.nn.ReLU()
        if self.activation_fn == 'LeakyReLU':
            return torch.nn.LeakyReLU()
        raise NotImplementedError(f"Activation function must be in {self.POSSIBLE_ACTIVATION_FNS}.")

    def _loss_and_correct(self, x, y_true):
        """Run one forward pass and return ``(loss, number of correct predictions)``."""
        outputs = self(x)
        loss = self.loss_function(outputs, y_true)
        _, y_pred = torch.max(outputs, 1)
        correct = (y_pred == y_true).long().sum()
        return loss, correct

    def forward(self, x):
        """Apply the convolutions, integrate with ``so3_integrate``, then the dense head."""
        x = self.conv(x)
        x = so3_integrate(x)
        x = self.dense(x)
        return x

    def loss(self, x, y_true):
        """Return the cross-entropy loss of the model output for ``x`` against ``y_true``."""
        y_pred = self(x)
        loss = self.loss_function(y_pred, y_true)
        return loss

    def correct_predictions(self, x, y_true):
        """Return, as a scalar long tensor, how many arg-max predictions equal ``y_true``."""
        outputs = self(x)
        _, y_pred = torch.max(outputs, 1)
        correct = (y_pred == y_true).long().sum()
        return correct

    def prepare_data(self):
        """Nothing to prepare: the datasets are handed to the constructor."""
        pass

    def train_dataloader(self):
        """Shuffled loader over ``self.train_data``."""
        return torch.utils.data.DataLoader(dataset=self.train_data,
                                           batch_size=self.hparams.train_batch_size,
                                           shuffle=True, num_workers=self.hparams.num_workers)

    def val_dataloader(self):
        """Unshuffled loader over ``self.test_data`` (the test set doubles as validation set)."""
        return torch.utils.data.DataLoader(dataset=self.test_data,
                                           batch_size=self.hparams.test_batch_size,
                                           shuffle=False, num_workers=self.hparams.num_workers)

    def test_dataloader(self):
        """Unshuffled loader over ``self.test_data``."""
        return torch.utils.data.DataLoader(dataset=self.test_data,
                                           batch_size=self.hparams.test_batch_size,
                                           shuffle=False, num_workers=self.hparams.num_workers)

    def configure_optimizers(self):
        """Create the AdamW optimizer from ``hparams.lr`` and ``hparams.weight_decay``."""
        self._optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.lr,
                                            weight_decay=self.hparams.weight_decay, amsgrad=False)

        return {'optimizer': self._optimizer}

    def training_step(self, batch, batch_idx):
        """Compute the loss of one training batch and log it under ``'loss'``."""
        x, y = batch
        loss = self.loss(x, y)
        logs = {'loss': loss.cpu().item()}
        return {'loss': loss, 'log': logs}

    def training_epoch_end(self, outputs):
        """Average the per-batch training losses of the epoch."""
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean().cpu().item()
        return {'avg_train_loss': avg_loss}

    def validation_step(self, batch, batch_idx):
        """Return loss and correct-prediction count of one validation batch."""
        x, y = batch
        # One forward pass serves both the loss and the accuracy count.
        loss, correct = self._loss_and_correct(x, y)
        return {'val_loss': loss, 'val_correct': correct}

    def validation_epoch_end(self, outputs):
        """Aggregate validation batches into ``val_loss`` and ``val_acc``."""
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean().cpu().item()
        test_correct = torch.stack([x['val_correct'] for x in outputs]).sum().cpu()
        # Cast before dividing: on PyTorch versions where `/` on integer tensors is
        # integer division, the accuracy would otherwise be floored to 0.
        test_acc = test_correct.float() / len(self.test_data)

        logs = {'val_loss': avg_loss, 'val_acc': test_acc}
        return {'val_loss': avg_loss, 'val_acc': test_acc, 'log': logs}

    def test_step(self, batch, batch_idx):
        """Return loss and correct-prediction count of one test batch."""
        x, y = batch
        # One forward pass serves both the loss and the accuracy count.
        loss, correct = self._loss_and_correct(x, y)
        return {'test_loss': loss, 'test_correct': correct}

    def test_epoch_end(self, outputs):
        """Aggregate test batches into ``test_loss`` and ``test_acc``."""
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean().cpu().item()
        test_correct = torch.stack([x['test_correct'] for x in outputs]).sum().cpu()
        # Same float cast as in validation_epoch_end, for the same reason.
        test_acc = test_correct.float() / len(self.test_data)

        logs = {'test_loss': avg_loss, 'test_acc': test_acc}
        return {'test_loss': avg_loss, 'test_acc': test_acc, 'log': logs}

    def get_progress_bar_dict(self):
        """Build the progress-bar entries: running loss and learning rate in scientific notation."""
        # call .item() only once but store elements without graphs
        running_train_loss = self.trainer.running_loss.mean()
        avg_training_loss = running_train_loss.cpu().item() if running_train_loss is not None else float('NaN')
        lr = self.hparams.lr

        tqdm_dict = {
            'loss': '{:.2E}'.format(avg_training_loss),
            'lr': '{:.2E}'.format(lr),
        }

        if self.trainer.truncated_bptt_steps is not None:
            tqdm_dict['split_idx'] = self.trainer.split_idx

        if self.trainer.logger is not None and self.trainer.logger.version is not None:
            tqdm_dict['v_num'] = self.trainer.logger.version

        return tqdm_dict

    def count_trainable_parameters(self):
        """Number of parameter elements with ``requires_grad`` set."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def count_parameters(self):
        """Number of parameter elements, trainable or not."""
        return sum(p.numel() for p in self.parameters())

In [4]:
def load_train_data(path):
    """Read a gzipped pickle and wrap its images and labels in a TensorDataset.

    The pickle must hold a mapping with top-level ``"images"`` and ``"labels"``
    arrays. A singleton axis is inserted at position 1 of the images, which are
    converted to float32; the labels are converted to int64.

    Args:
        path: location of the ``.gz`` file.

    Returns:
        ``torch.utils.data.TensorDataset`` of ``(image, label)`` pairs.
    """
    # NOTE: unpickling runs arbitrary code - only open files from a trusted source.
    with gzip.open(path, 'rb') as handle:
        payload = pickle.load(handle)

    images = payload["images"][:, None, :, :].astype(np.float32)
    labels = payload["labels"].astype(np.int64)

    return torch.utils.data.TensorDataset(
        torch.from_numpy(images),
        torch.from_numpy(labels),
    )
    
def load_test_data(path):
    """Read a gzipped pickle and wrap its ``"test"`` split in a TensorDataset.

    The pickle must hold a mapping whose ``"test"`` entry contains ``"images"``
    and ``"labels"`` arrays. A singleton axis is inserted at position 1 of the
    images, which are converted to float32; the labels are converted to int64.

    Args:
        path: location of the ``.gz`` file.

    Returns:
        ``torch.utils.data.TensorDataset`` of ``(image, label)`` pairs.
    """
    # NOTE: unpickling runs arbitrary code - only open files from a trusted source.
    with gzip.open(path, 'rb') as handle:
        payload = pickle.load(handle)

    split = payload["test"]
    images = split["images"][:, None, :, :].astype(np.float32)
    labels = split["labels"].astype(np.int64)

    return torch.utils.data.TensorDataset(
        torch.from_numpy(images),
        torch.from_numpy(labels),
    )

In [5]:
# Stop right away on a machine without CUDA; otherwise report which device was found.
if not torch.cuda.is_available():
    raise RuntimeError('No GPU found.')
print('GPU available: ' + torch.cuda.get_device_name())

GPU available: NVIDIA GeForce RTX 2070 SUPER


In [6]:
# Hyper-parameters consumed by S2ConvNet and train_model, declared in one place.
hparams = argparse.Namespace(
    # run label - presumably refers to the ~157k parameter count; TODO confirm
    name='157k',
    # DataLoader settings
    train_batch_size=32,
    test_batch_size=32,
    num_workers=0,
    # AdamW settings
    lr=1e-4,
    weight_decay=0.,
    # one entry per convolution layer (the three lists must have equal length)
    channels=[8, 16, 16, 24, 24, 32, 64],
    bandlimit=[30, 15, 15, 8, 8, 4, 2],
    kernel_max_beta=[0.0625, 0.0625, 0.125, 0.125, 0.25, 0.25, 0.5],  # multiplied by pi in the model
    activation_fn='ReLU',
    # dense head: BatchNorm1d before every Linear layer, hidden widths in `nodes`
    batch_norm=True,
    nodes=[64, 32],
)

In [7]:
def train_model(hparams, train_data, test_data):
    """Train one S2ConvNet, reload its best checkpoint and evaluate it on the test set.

    Training stops after ``MAX_EPOCHS`` epochs, or earlier when ``val_acc`` has not
    improved for 10 epochs. The checkpoint with the highest ``val_acc`` is
    loaded back into the model before ``trainer.test`` is called.

    Args:
        hparams: ``argparse.Namespace`` of hyper-parameters; it is not modified.
        train_data: dataset used for training.
        test_data: dataset used for validation during training and for the final test.

    Returns:
        Tuple ``(test_results, hparams_dict)`` where ``test_results`` is whatever
        ``trainer.test`` returns and ``hparams_dict`` is a deep copy of
        ``vars(hparams)``.
    """
    # Give the model a private deep copy, so that any in-place edits made to the
    # list-valued hyper-parameters during model construction cannot leak back.
    args_copy = copy.deepcopy(vars(hparams))
    hparams1 = argparse.Namespace(**args_copy)

    model = S2ConvNet(hparams1, train_data, test_data)

    # Second value is the total count (previously the trainable count was printed twice).
    print(f"Number of trainable / total parameters: {model.count_trainable_parameters(), model.count_parameters()}")

    monitor = 'val_acc'
    mode = 'max'
    early_stopping = pl.callbacks.EarlyStopping(monitor=monitor, min_delta=0., patience=10, mode=mode)
    checkpoint = pl.callbacks.model_checkpoint.ModelCheckpoint(monitor=monitor, mode=mode)

    trainer = pl.Trainer(gpus=1, max_epochs=MAX_EPOCHS, logger=False, early_stop_callback=early_stopping, checkpoint_callback=checkpoint)

    trainer.fit(model)

    # Evaluate the best checkpoint rather than the weights of the last epoch.
    best_model = torch.load(checkpoint.best_model_path)
    model.load_state_dict(best_model['state_dict'])
    model.eval()
    test_results = trainer.test(model)

    return test_results, copy.deepcopy(vars(hparams))

In [None]:
# Every run is evaluated on the same held-out set.
test_data = load_test_data(TEST_PATH)
print("Total test examples: {}".format(len(test_data)))

training_set_sizes = [10000, 20000, 30000, 40000, 50000]

# middle_dummy[k] gathers the repeated runs for training_set_sizes[k];
# each run is stored as a [test_results, hparams_dict] pair.
middle_dummy = []
for TRAIN_SAMPLES in tqdm(training_set_sizes):
    TRAIN_PATH = "s2_mnist_train_dwr_" + str(TRAIN_SAMPLES) + ".gz"
    train_data = load_train_data(TRAIN_PATH)
    print("Total training examples: {}".format(len(train_data)))

    # Three independent trainings per training-set size.
    inner_dummy = []
    for run_idx in tqdm(range(3)):
        inner_dummy.append(list(train_model(hparams, train_data, test_data)))

    middle_dummy.append(inner_dummy)

Total test examples: 10000


  0%|          | 0/5 [00:00<?, ?it/s]

Total training examples: 10000


  0%|          | 0/3 [00:00<?, ?it/s]

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]


Number of trainable / total parameters: (156882, 156882)



  | Name          | Type             | Params
---------------------------------------------------
0 | loss_function | CrossEntropyLoss | 0     
1 | conv          | Sequential       | 149 K 
2 | dense         | Sequential       | 6 K   


Validation sanity check: 0it [00:00, ?it/s]

load 0.pkl.gz... done
load 0.pkl.gz... done
load 1.pkl.gz... done
load 2.pkl.gz... done
load 1.pkl.gz... done
load 3.pkl.gz... done
load 4.pkl.gz... done
load 2.pkl.gz... done
load 14.pkl.gz... done
compute 14.pkl.gz... save 14.pkl.gz... done
load 15.pkl.gz... done
load 16.pkl.gz... done
compute 15.pkl.gz... save 15.pkl.gz... done
load 17.pkl.gz... done
compute 16.pkl.gz... save 16.pkl.gz... done
load 18.pkl.gz... done
load 19.pkl.gz... done
compute 17.pkl.gz... save 17.pkl.gz... done
load 20.pkl.gz... done


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

In [None]:
# Persist the collected results; an existing file is never overwritten.
filename = 'S2CNN_first_tests_2_smaller_training_sets.pickle'

name_taken = os.path.isfile(filename)
if name_taken:
    # Prefix the current UNIX time so the earlier results file stays intact.
    filename = str(time.time()) + filename
    print('File already existed, timestamp was prepended to filename.')

with open(filename, 'wb') as out_file:
    pickle.dump(middle_dummy, out_file)