In [1]:
# Print the name of the host running this kernel, so results can be traced to a machine.
!hostname

gpu19


In [2]:
import pytorch_lightning as pl

from pytorch_lightning.callbacks import ModelCheckpoint, Callback
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
import pickle

from Oasis_DataModule_file import Oasis_PL
# NOTE(review): these model classes are not referenced by name in the cells shown — TODO confirm they are still needed.
from UKBB_All_Substructs_Net_file import UKBB_All_Substructs_Net2, Oasis_PCA_Net, UKBB_All_Substructs_Net2_Binary

# NOTE(review): confusion_matrix (and SGD below) are not used in the cells shown.
from sklearn.metrics import confusion_matrix
import os
import torch
from torch.nn import Linear
from torch.optim import Adam, SGD

import sys
# Make modules in the testbrain project directory importable (hard-coded absolute, machine-specific path).
sys.path.insert(0, "/vol/bitbucket/wjb120/data/ukbb_mesh/testbrain")
import UKBB_All_Substructs_DataModule_file


In [3]:
# Sub-structure list exported by the Oasis data module, aliased as `substructs`.
# NOTE(review): `substructs` is not used in the cells shown.
from Oasis_DataModule_file import all_sub_structs_list
substructs = all_sub_structs_list

In [4]:
# Fine-tuning is repeated once per seed below; every run trains for between
# `min_epochs` and `max_epochs` epochs.
rand_seeds = [3, 4, 8, 9, 27]
min_epochs, max_epochs = 10, 100

# Set output filenames, then rebuild the model and fit it (e.g. cast to double precision)

In [5]:
# Path prefix of the pickled pre-trained models; the seed number is appended per run.
pre_model_base = (
    "/vol/bitbucket/wjb120/data/ukbb_mesh/testbrain/"
    "pretraining_models/hippocampi_only/GCN_"
)

In [6]:
# Output directory for this experiment, relative to the notebook's working directory.
folder = "/".join((".", "pre_training_results", "hippocampi_only"))

In [7]:
# Filename stem for saved fine-tuned models: architecture tag + version tag.
models_filename_base = "GCN_" "v2_"

In [8]:
# Turn the bare model filename stem into full output paths under `folder`.
# Taking the basename first makes this cell idempotent: re-running it no longer
# nests `folder` a second time (the original rebound the name to its own join).
models_filename_base = os.path.join(folder, os.path.basename(models_filename_base))
# Prediction files share the model stem, with an extra "preds_" tag.
predictions_filename_base = models_filename_base + "preds_"
print("models: ", models_filename_base)
print("predis: ", predictions_filename_base)

models:  ./pre_training_results/hippocampi_only/GCN_v2_
predis:  ./pre_training_results/hippocampi_only/GCN_v2_preds_


In [9]:
class Replace_Optimizer(Callback):
    """Install a pre-built optimiser on the Trainer when training starts.

    Only the freshly attached linear heads should be trained, so whatever
    optimiser the Trainer set up is replaced at the start of epoch 0.

    Args:
        optimizer: the optimiser to install.
    """

    def __init__(self, optimizer):
        super().__init__()
        # Held explicitly (not read from a global) so each callback is bound to
        # the optimiser built for its own seed's run.
        self.optimizer = optimizer

    def on_train_epoch_start(self, trainer, pl_module):
        if trainer.current_epoch == 0:
            print("Changing optimisers")
            trainer.optimizers = [self.optimizer]


def _to_1d_numpy(tensor):
    """Detach `tensor`, drop size-1 dims, and return a NumPy array of rank >= 1.

    `squeeze()` alone turns a single-element batch into a 0-d array, which
    `list.extend` cannot iterate over; re-adding one dimension keeps it iterable.
    """
    squeezed = tensor.detach().squeeze()
    if squeezed.dim() == 0:
        squeezed = squeezed.unsqueeze(0)
    return squeezed.cpu().numpy()


def _predict(model, dataloader):
    """Run `model` over `dataloader` and return `(true_list, pred_list)`.

    Assumes each batch is a dict holding labels under "y" (first dim = number of
    graphs) and a mutable sequence of input tensors under "x", matching the
    model's `forward(x_vals, num_graphs)` call — TODO confirm.
    """
    true_list = []
    pred_list = []
    model.eval()
    # Inference only: skip building autograd graphs (the new heads require grad).
    with torch.no_grad():
        for batch in dataloader:
            num_graphs = batch["y"].shape[0]
            true_list.extend(_to_1d_numpy(batch["y"]))
            x_vals = batch["x"]
            for i, x in enumerate(x_vals):
                x_vals[i] = x.to(model.device)
            preds = model.forward(x_vals, num_graphs)
            pred_list.extend(_to_1d_numpy(preds))
    return true_list, pred_list


# Pick the device once and request a GPU from the Trainer only if one exists,
# so the two always agree (a hard-coded gpus=1 fails on CPU-only hosts).
use_gpu = torch.cuda.is_available()
device = torch.device("cuda" if use_gpu else "cpu")

for seed in rand_seeds:
    # Re-seed everything (including dataloader workers) for a reproducible run.
    pl.seed_everything(seed, workers=True)

    # Data: static labels, class-rebalanced. Rebuilt after seeding each run since
    # the rebalancing/sampler may depend on RNG state.
    pl_test = Oasis_PL(
        batch_size = 20,
        reload_path = False,
        balance_classes = True,
        static = True
    )
    pl_test.prepare_data()
    pl_test.setup(reload_sampler = True)
    # Umeyama reference shape(s) are computed from 100 training examples.
    pl_test.set_umeyama(pl_test.train_set, n_examples = 100)

    # Load this seed's pre-trained model.
    # NOTE: pickle.load can execute arbitrary code - only load trusted files.
    filename = pre_model_base + str(seed)
    with open(filename, 'rb') as next_model:
        model = pickle.load(next_model)

    # Freeze all pre-trained weights, then attach fresh linear heads to every
    # sub-model and to the combined model; only these heads get optimised.
    model.freeze()
    for model_ in model.all_gnns:
        # NOTE(review): assumes each sub-model's embedding width is 64 - confirm.
        model_.lin = Linear(64, 1)
    # NOTE(review): 2 presumably equals len(model.all_gnns) - confirm.
    model.lin = Linear(2, 1)

    # Move and cast before constructing the optimiser (as torch.optim advises),
    # so it holds the final on-device float64 parameters.
    model = model.to(device).double()
    list_params_to_optimise = [p for model_ in model.all_gnns for p in model_.lin.parameters()]
    list_params_to_optimise.extend(model.lin.parameters())
    new_optim = Adam(list_params_to_optimise, lr=0.0001)

    # Stop once val_loss stops improving for 3 epochs (min_epochs still applies);
    # keep the single best checkpoint by val_loss.
    early_stop_callback = EarlyStopping(
        monitor='val_loss',
        min_delta=0.00,
        patience=3,
        verbose=False,
        mode='min'
    )
    checkpoint_callback = ModelCheckpoint(
        monitor='val_loss',
        save_top_k=1,
        mode='min',
    )
    callbacks_list = [early_stop_callback, checkpoint_callback, Replace_Optimizer(new_optim)]

    trainer = pl.Trainer(
        gpus=1 if use_gpu else 0,
        deterministic=True,
        min_epochs=min_epochs,
        max_epochs=max_epochs,
        callbacks=callbacks_list,
    )

    trainer.fit(model, pl_test)
    model = model.to(device)

    # NOTE(review): predictions use the final-epoch weights, not the best
    # checkpoint at checkpoint_callback.best_model_path; with patience=3 these
    # can differ - confirm which is intended.
    dl = pl_test.test_dataloader()
    true_list, pred_list = _predict(model, dl)
    my_lists = [true_list, pred_list]

    # Save [true labels, predictions] for this seed.
    predictions_name = predictions_filename_base + str(seed)
    with open(predictions_name, 'wb') as pred_file:
        pickle.dump(my_lists, pred_file)

    # Save the fine-tuned model for this seed.
    model_name = models_filename_base + str(seed)
    with open(model_name, 'wb') as model_file:
        pickle.dump(model, model_file)

Global seed set to 3
  0%|          | 4/1218 [00:00<00:33, 36.01it/s]

Initialising underlying dataset... 
...loaded flat_list from ./cached_files/static_oasis_multi_per_sub_cache  with 1740 subjects...
done! 

...rebalancing dataset...
train...


100%|██████████| 1218/1218 [00:31<00:00, 38.95it/s]
  2%|▏         | 4/174 [00:00<00:04, 38.82it/s]

done! Now val...


100%|██████████| 174/174 [00:04<00:00, 36.35it/s]
  5%|▌         | 5/100 [00:00<00:02, 42.44it/s]

done!
Calculating reference shape(s) for umeyama


100%|██████████| 100/100 [00:02<00:00, 39.66it/s]
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type       | Params
-----------------------------------------
0 | all_gnns  | ModuleList | 34.6 K
1 | criterion | BCELoss    | 0     
2 | lin       | Linear     | 3     
-----------------------------------------
133       Trainable params
34.4 K    Non-trainable params
34.6 K    Total params
0.138     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Global seed set to 3


Training: 0it [00:00, ?it/s]

Changing optimisers


Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Trainer was signaled to stop but required minimum epochs (10) or minimum steps (None) has not been met. Training will continue...


Validating: 0it [00:00, ?it/s]

Trainer was signaled to stop but required minimum epochs (10) or minimum steps (None) has not been met. Training will continue...


Validating: 0it [00:00, ?it/s]

Trainer was signaled to stop but required minimum epochs (10) or minimum steps (None) has not been met. Training will continue...


Validating: 0it [00:00, ?it/s]

Trainer was signaled to stop but required minimum epochs (10) or minimum steps (None) has not been met. Training will continue...


Validating: 0it [00:00, ?it/s]

Trainer was signaled to stop but required minimum epochs (10) or minimum steps (None) has not been met. Training will continue...


Validating: 0it [00:00, ?it/s]

Trainer was signaled to stop but required minimum epochs (10) or minimum steps (None) has not been met. Training will continue...


Validating: 0it [00:00, ?it/s]

Global seed set to 4
  0%|          | 5/1218 [00:00<00:27, 43.44it/s]

Initialising underlying dataset... 
...loaded flat_list from ./cached_files/static_oasis_multi_per_sub_cache  with 1740 subjects...
done! 

...rebalancing dataset...
train...


100%|██████████| 1218/1218 [00:30<00:00, 39.38it/s]
  3%|▎         | 5/174 [00:00<00:04, 40.60it/s]

done! Now val...


100%|██████████| 174/174 [00:04<00:00, 39.19it/s]
  5%|▌         | 5/100 [00:00<00:02, 42.39it/s]

done!
Calculating reference shape(s) for umeyama


100%|██████████| 100/100 [00:02<00:00, 39.19it/s]
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type       | Params
-----------------------------------------
0 | all_gnns  | ModuleList | 34.6 K
1 | criterion | BCELoss    | 0     
2 | lin       | Linear     | 3     
-----------------------------------------
133       Trainable params
34.4 K    Non-trainable params
34.6 K    Total params
0.138     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Global seed set to 4


Training: 0it [00:00, ?it/s]

Changing optimisers


Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Global seed set to 8
  0%|          | 0/1218 [00:00<?, ?it/s]

Initialising underlying dataset... 
...loaded flat_list from ./cached_files/static_oasis_multi_per_sub_cache  with 1740 subjects...
done! 

...rebalancing dataset...
train...


100%|██████████| 1218/1218 [00:32<00:00, 37.73it/s]
  3%|▎         | 5/174 [00:00<00:04, 41.70it/s]

done! Now val...


100%|██████████| 174/174 [00:04<00:00, 39.97it/s]
  4%|▍         | 4/100 [00:00<00:02, 37.32it/s]

done!
Calculating reference shape(s) for umeyama


100%|██████████| 100/100 [00:02<00:00, 39.04it/s]
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type       | Params
-----------------------------------------
0 | all_gnns  | ModuleList | 34.6 K
1 | criterion | BCELoss    | 0     
2 | lin       | Linear     | 3     
-----------------------------------------
133       Trainable params
34.4 K    Non-trainable params
34.6 K    Total params
0.138     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Global seed set to 8


Training: 0it [00:00, ?it/s]

Changing optimisers


Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Global seed set to 9
  0%|          | 5/1218 [00:00<00:30, 40.28it/s]

Initialising underlying dataset... 
...loaded flat_list from ./cached_files/static_oasis_multi_per_sub_cache  with 1740 subjects...
done! 

...rebalancing dataset...
train...


100%|██████████| 1218/1218 [00:31<00:00, 38.81it/s]
  3%|▎         | 5/174 [00:00<00:04, 42.00it/s]

done! Now val...


100%|██████████| 174/174 [00:04<00:00, 39.54it/s]
  4%|▍         | 4/100 [00:00<00:02, 39.80it/s]

done!
Calculating reference shape(s) for umeyama


100%|██████████| 100/100 [00:02<00:00, 39.30it/s]
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type       | Params
-----------------------------------------
0 | all_gnns  | ModuleList | 34.6 K
1 | criterion | BCELoss    | 0     
2 | lin       | Linear     | 3     
-----------------------------------------
133       Trainable params
34.4 K    Non-trainable params
34.6 K    Total params
0.138     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Global seed set to 9


Training: 0it [00:00, ?it/s]

Changing optimisers


Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Global seed set to 27
  0%|          | 5/1218 [00:00<00:28, 42.41it/s]

Initialising underlying dataset... 
...loaded flat_list from ./cached_files/static_oasis_multi_per_sub_cache  with 1740 subjects...
done! 

...rebalancing dataset...
train...


100%|██████████| 1218/1218 [00:30<00:00, 39.81it/s]
  3%|▎         | 5/174 [00:00<00:03, 42.34it/s]

done! Now val...


100%|██████████| 174/174 [00:04<00:00, 40.04it/s]
  4%|▍         | 4/100 [00:00<00:02, 37.77it/s]

done!
Calculating reference shape(s) for umeyama


100%|██████████| 100/100 [00:02<00:00, 38.60it/s]
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type       | Params
-----------------------------------------
0 | all_gnns  | ModuleList | 34.6 K
1 | criterion | BCELoss    | 0     
2 | lin       | Linear     | 3     
-----------------------------------------
133       Trainable params
34.4 K    Non-trainable params
34.6 K    Total params
0.138     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Global seed set to 27


Training: 0it [00:00, ?it/s]

Changing optimisers


Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]