In [1]:
import os

from acevedo_clss_and_fcns import * 
# Fall back to the CPU unless CUDA is present *and* its runtime initialises.
device = 'cpu'
if torch.cuda.is_available():
    torch.cuda.init()
    device = 'cuda:0' if torch.cuda.is_initialized() else device
print(f"{device = }")



  from .autonotebook import tqdm as notebook_tqdm


device = 'cuda:0'


In [2]:
def train_one_epoch(modelo: "GIN_classifier_to_explain_v2",
                    optimizer: torch.optim.Optimizer,
                    train_loader: "torch_geometric.loader.dataloader.DataLoader",
                    loss_fun: torch.nn.Module,
                    device: str = 'cpu') -> float:
    """Run one optimisation pass over ``train_loader`` and return the training accuracy.

    Args:
        modelo: Graph classifier called as ``modelo(x, edge_index, batch)``; it must
            already live on ``device``.
        optimizer: Optimiser bound to ``modelo``'s parameters.
        train_loader: Iterable of batches exposing ``x``, ``edge_index``, ``batch``,
            ``y`` and ``to()``; it must also expose ``dataset`` (used for the
            accuracy denominator).
        loss_fun: Callable mapping ``(predictions, targets)`` to a scalar loss.
        device: Device every batch is moved to before the forward pass.

    Returns:
        Number of correctly classified samples divided by ``len(train_loader.dataset)``.
    """
    # validate() leaves the model in eval mode, so training mode has to be
    # restored explicitly at the start of every epoch.
    modelo.train()

    correct = 0
    for data in train_loader:
        # Rebind the result: this is correct whether ``to`` works in place or
        # returns a copy, and it is a no-op when the batch is already on ``device``.
        data = data.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)  # Drop gradients left over from the previous batch.

        # The update step runs on every device; previously it was skipped on CPU.
        predictions = modelo(data.x, data.edge_index, data.batch)
        loss = loss_fun(predictions, data.y)
        loss.backward()   # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.

        pred = predictions.argmax(dim=1)        # Class with the highest score.
        correct += int((pred == data.y).sum())  # Compare against ground-truth labels.

    return correct / len(train_loader.dataset)

def validate(modelo: "GIN_classifier_to_explain_v2", loader: "DataLoader", device: str = 'cpu') -> float:
    """Compute classification accuracy of ``modelo`` over every batch in ``loader``.

    The model is switched to eval mode and is left in eval mode on return.

    Args:
        modelo: Graph classifier called as ``modelo(x, edge_index, batch)``; it must
            already live on ``device``.
        loader: Iterable of batches exposing ``x``, ``edge_index``, ``batch``, ``y``
            and ``to()``; it must also expose ``dataset`` (used for the accuracy
            denominator).
        device: Device every batch is moved to before the forward pass.

    Returns:
        Number of correctly classified samples divided by ``len(loader.dataset)``.
    """
    modelo.eval()
    correct = 0
    # No gradients are needed for evaluation; disabling autograd avoids building
    # a graph for every batch.
    with torch.no_grad():
        for val_data in loader:
            # Rebinding covers both in-place and copying ``to`` implementations.
            val_data = val_data.to(device, non_blocking=True)

            val_predictions = modelo(val_data.x, val_data.edge_index, val_data.batch)
            pred = val_predictions.argmax(dim=1)

            correct += int((pred == val_data.y).sum())

    return correct / len(loader.dataset)

# --- Experiment: non-masked Phe data ------------------------------------------
# Where checkpoints go (saving_path + model_subfolder) and which pickled loader
# object to read.
saving_path     = "./results/trained_pytorch_models"
model_subfolder = "/Non_masked_Phe"
loader_path     = "./results/dataloaders/loader_Concen_plus_Fluxes.pt"

# NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
# only load loader files you produced yourself.
loader = torch.load(loader_path)

# Peek at a single training graph to read the dimensions the network needs.
a_batch = next(iter(loader.get_train_loader()))
a_graph = a_batch[0]

model = GIN_classifier_to_explain_v2(
    n_nodes=a_graph.num_nodes,
    num_features=a_graph.num_node_features,
    n_classes=a_graph.num_classes,
    hidden_dim=8,
    num_layers=2,
)
model = model.to(device, non_blocking=True)

# Training hyper-parameters and bookkeeping for the loop in the next cell.
EPOCHS                   = 10
verbose                  = True
best_validation_accuracy = 1e-10  # Tiny positive start value; improvement is tested with a strict `>`.
loss_function            = torch.nn.NLLLoss()
optimizer                = torch.optim.Adam(model.parameters())


In [3]:
# Free Python garbage and cached CUDA blocks left over from earlier cells.
gc.collect()
torch.cuda.empty_cache()

# Create the checkpoint folder up front so torch.save cannot fail on a fresh checkout.
checkpoint_dir = saving_path + model_subfolder
os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in tqdm.tqdm(range(EPOCHS)):

    train_accuracy = train_one_epoch(model,
                                     optimizer=optimizer,
                                     train_loader=loader.get_train_loader(),
                                     loss_fun=loss_function,
                                     device=device)

    validation_accuracy = validate(model, loader.get_validation_loader(), device)
    if validation_accuracy <= best_validation_accuracy:
        continue  # No improvement on the validation set: nothing to checkpoint.

    # New best epoch: keep in-memory snapshots and write a checkpoint to disk.
    best_validation_accuracy = validation_accuracy
    best_val_state_dict      = copy.deepcopy(model.state_dict())
    best_val_model           = copy.deepcopy(model)

    timestamp  = datetime.now().strftime('%d-%m-%Y_%Hh_%Mmin')
    model_path = checkpoint_dir+'/Model_{}_{}_best_ValAcc_{}_epoch_{}.pt'.format(model.__class__.__name__,timestamp, best_validation_accuracy, epoch)
    # Saving is independent of `verbose`; that flag only controls the log lines.
    torch.save(best_val_model, model_path)

    if verbose:
        print(f'Epoch: {epoch:03d}, train_accuracy: {train_accuracy:.4f}, best_validation_accuracy: {best_validation_accuracy:.4f}')
        print(f"saved as {model_path}")



 10%|█         | 1/10 [00:04<00:41,  4.61s/it]

Epoch: 000, train_accuracy: 0.4948, best_validation_accuracy: 0.4975
saved as ./results/trained_pytorch_models/Non_masked_Phe/Model_GIN_classifier_to_explain_v2_27-10-2022_16h_22min_best_ValAcc_0.4975480950584685_epoch_0.pt


 20%|██        | 2/10 [00:08<00:31,  3.93s/it]

Epoch: 001, train_accuracy: 0.8204, best_validation_accuracy: 0.9623
saved as ./results/trained_pytorch_models/Non_masked_Phe/Model_GIN_classifier_to_explain_v2_27-10-2022_16h_22min_best_ValAcc_0.9622783855149001_epoch_1.pt


 30%|███       | 3/10 [00:11<00:25,  3.71s/it]

Epoch: 002, train_accuracy: 0.9699, best_validation_accuracy: 0.9740
saved as ./results/trained_pytorch_models/Non_masked_Phe/Model_GIN_classifier_to_explain_v2_27-10-2022_16h_22min_best_ValAcc_0.9739720860052811_epoch_2.pt


 40%|████      | 4/10 [00:15<00:21,  3.62s/it]

Epoch: 003, train_accuracy: 0.9760, best_validation_accuracy: 0.9777
saved as ./results/trained_pytorch_models/Non_masked_Phe/Model_GIN_classifier_to_explain_v2_27-10-2022_16h_22min_best_ValAcc_0.9777442474537911_epoch_3.pt


 50%|█████     | 5/10 [00:18<00:17,  3.57s/it]

Epoch: 004, train_accuracy: 0.9797, best_validation_accuracy: 0.9808
saved as ./results/trained_pytorch_models/Non_masked_Phe/Model_GIN_classifier_to_explain_v2_27-10-2022_16h_22min_best_ValAcc_0.980761976612599_epoch_4.pt


 60%|██████    | 6/10 [00:21<00:14,  3.55s/it]

Epoch: 005, train_accuracy: 0.9819, best_validation_accuracy: 0.9838
saved as ./results/trained_pytorch_models/Non_masked_Phe/Model_GIN_classifier_to_explain_v2_27-10-2022_16h_22min_best_ValAcc_0.983779705771407_epoch_5.pt


 70%|███████   | 7/10 [00:25<00:10,  3.54s/it]

Epoch: 006, train_accuracy: 0.9825, best_validation_accuracy: 0.9853
saved as ./results/trained_pytorch_models/Non_masked_Phe/Model_GIN_classifier_to_explain_v2_27-10-2022_16h_22min_best_ValAcc_0.985288570350811_epoch_6.pt


 80%|████████  | 8/10 [00:29<00:07,  3.53s/it]

Epoch: 007, train_accuracy: 0.9846, best_validation_accuracy: 0.9879
saved as ./results/trained_pytorch_models/Non_masked_Phe/Model_GIN_classifier_to_explain_v2_27-10-2022_16h_22min_best_ValAcc_0.987929083364768_epoch_7.pt


100%|██████████| 10/10 [00:36<00:00,  3.61s/it]

Epoch: 009, train_accuracy: 0.9866, best_validation_accuracy: 0.9883
saved as ./results/trained_pytorch_models/Non_masked_Phe/Model_GIN_classifier_to_explain_v2_27-10-2022_16h_22min_best_ValAcc_0.988306299509619_epoch_9.pt





# Masked

In [5]:
# --- Experiment: masked Phe data ----------------------------------------------
saving_path     = "./results/trained_pytorch_models"
model_subfolder = "/Masked_Phe"
loader_path     = "./results/dataloaders/MASKED_loader_Concen_plus_Fluxes.pt"

# NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
# only load loader files you produced yourself.
loader = torch.load(loader_path)

# Size the network from the masked loader itself, not from the `a_graph` left
# behind by the non-masked cell, so this cell does not depend on stale state.
masked_batch = next(iter(loader.get_train_loader()))
masked_graph = masked_batch[0]

model = GIN_classifier_to_explain_v2(
    n_nodes=masked_graph.num_nodes,
    num_features=masked_graph.num_node_features,
    n_classes=masked_graph.num_classes,
    hidden_dim=8,
    num_layers=2,
).to(device, non_blocking=True)

optimizer       = torch.optim.Adam(model.parameters())
loss_function   = torch.nn.NLLLoss()
best_validation_accuracy = 1e-10  # Tiny positive start value; improvement is tested with a strict `>`.
EPOCHS = 100
verbose = True

# Free Python garbage and cached CUDA blocks left over from earlier cells.
gc.collect()
torch.cuda.empty_cache()

# Create the checkpoint folder up front so torch.save cannot fail on a fresh checkout.
checkpoint_dir = saving_path + model_subfolder
os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in tqdm.tqdm(range(EPOCHS)):

    train_accuracy = train_one_epoch(model,
                                     optimizer=optimizer,
                                     train_loader=loader.get_train_loader(),
                                     loss_fun=loss_function,
                                     device=device)

    validation_accuracy = validate(model, loader.get_validation_loader(), device)
    if validation_accuracy <= best_validation_accuracy:
        continue  # No improvement on the validation set: nothing to checkpoint.

    # New best epoch: keep in-memory snapshots and write a checkpoint to disk.
    best_validation_accuracy = validation_accuracy
    best_val_state_dict      = copy.deepcopy(model.state_dict())
    best_val_model           = copy.deepcopy(model)

    timestamp  = datetime.now().strftime('%d-%m-%Y_%Hh_%Mmin')
    model_path = checkpoint_dir+'/Model_{}_{}_best_ValAcc_{}_epoch_{}.pt'.format(model.__class__.__name__,timestamp, best_validation_accuracy, epoch)
    # Saving is independent of `verbose`; that flag only controls the log lines.
    torch.save(best_val_model, model_path)

    if verbose:
        print(f'Epoch: {epoch:03d}, train_accuracy: {train_accuracy:.4f}, best_validation_accuracy: {best_validation_accuracy:.4f}')
        print(f"saved as {model_path}")

  1%|          | 1/100 [00:03<05:53,  3.58s/it]

Epoch: 000, train_accuracy: 0.5072, best_validation_accuracy: 0.4791
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_27-10-2022_16h_26min_best_ValAcc_0.4790645039607695_epoch_0.pt


  2%|▏         | 2/100 [00:07<05:47,  3.55s/it]

Epoch: 001, train_accuracy: 0.6283, best_validation_accuracy: 0.6416
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_27-10-2022_16h_26min_best_ValAcc_0.6416446623915504_epoch_1.pt


  3%|▎         | 3/100 [00:10<05:43,  3.54s/it]

Epoch: 002, train_accuracy: 0.6893, best_validation_accuracy: 0.7043
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_27-10-2022_16h_26min_best_ValAcc_0.7042625424368163_epoch_2.pt


  4%|▍         | 4/100 [00:14<05:38,  3.52s/it]

Epoch: 003, train_accuracy: 0.7104, best_validation_accuracy: 0.7261
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_27-10-2022_16h_26min_best_ValAcc_0.7261410788381742_epoch_3.pt


  5%|▌         | 5/100 [00:17<05:35,  3.53s/it]

Epoch: 004, train_accuracy: 0.7232, best_validation_accuracy: 0.7348
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_27-10-2022_16h_26min_best_ValAcc_0.7348170501697473_epoch_4.pt


  6%|▌         | 6/100 [00:21<05:31,  3.53s/it]

Epoch: 005, train_accuracy: 0.7488, best_validation_accuracy: 0.7605
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_27-10-2022_16h_26min_best_ValAcc_0.7604677480196153_epoch_5.pt


  7%|▋         | 7/100 [00:24<05:26,  3.51s/it]

Epoch: 006, train_accuracy: 0.7635, best_validation_accuracy: 0.7869
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_27-10-2022_16h_26min_best_ValAcc_0.7868728781591852_epoch_6.pt


  8%|▊         | 8/100 [00:28<05:24,  3.52s/it]

Epoch: 007, train_accuracy: 0.7653, best_validation_accuracy: 0.7959
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_27-10-2022_16h_26min_best_ValAcc_0.7959260656356092_epoch_7.pt


  9%|▉         | 9/100 [00:31<05:20,  3.52s/it]

Epoch: 008, train_accuracy: 0.7887, best_validation_accuracy: 0.8008
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_27-10-2022_16h_26min_best_ValAcc_0.8008298755186722_epoch_8.pt


 10%|█         | 10/100 [00:35<05:16,  3.52s/it]

Epoch: 009, train_accuracy: 0.7929, best_validation_accuracy: 0.8110
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_27-10-2022_16h_26min_best_ValAcc_0.8110147114296492_epoch_9.pt


 12%|█▏        | 12/100 [00:42<05:10,  3.53s/it]

Epoch: 011, train_accuracy: 0.8080, best_validation_accuracy: 0.8223
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_27-10-2022_16h_26min_best_ValAcc_0.8223311957751792_epoch_11.pt


 13%|█▎        | 13/100 [00:45<05:07,  3.53s/it]

Epoch: 012, train_accuracy: 0.8112, best_validation_accuracy: 0.8257
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_27-10-2022_16h_26min_best_ValAcc_0.8257261410788381_epoch_12.pt


 14%|█▍        | 14/100 [00:49<05:03,  3.53s/it]

Epoch: 013, train_accuracy: 0.8198, best_validation_accuracy: 0.8261
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_27-10-2022_16h_26min_best_ValAcc_0.8261033572236892_epoch_13.pt


 15%|█▌        | 15/100 [00:52<04:59,  3.52s/it]

Epoch: 014, train_accuracy: 0.8261, best_validation_accuracy: 0.8303
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_27-10-2022_16h_26min_best_ValAcc_0.8302527348170502_epoch_14.pt


 16%|█▌        | 16/100 [00:56<04:56,  3.53s/it]

Epoch: 015, train_accuracy: 0.8258, best_validation_accuracy: 0.8374
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_27-10-2022_16h_27min_best_ValAcc_0.8374198415692191_epoch_15.pt


 17%|█▋        | 17/100 [00:59<04:52,  3.53s/it]

Epoch: 016, train_accuracy: 0.8378, best_validation_accuracy: 0.8397
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_27-10-2022_16h_27min_best_ValAcc_0.8396831384383252_epoch_16.pt


 18%|█▊        | 18/100 [01:03<04:49,  3.53s/it]

Epoch: 017, train_accuracy: 0.8393, best_validation_accuracy: 0.8465
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_27-10-2022_16h_27min_best_ValAcc_0.8464730290456431_epoch_17.pt


 19%|█▉        | 19/100 [01:07<04:45,  3.53s/it]

Epoch: 018, train_accuracy: 0.8384, best_validation_accuracy: 0.8469
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_27-10-2022_16h_27min_best_ValAcc_0.8468502451904941_epoch_18.pt


 22%|██▏       | 22/100 [01:17<04:34,  3.52s/it]

Epoch: 021, train_accuracy: 0.8525, best_validation_accuracy: 0.8472
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_27-10-2022_16h_27min_best_ValAcc_0.8472274613353451_epoch_21.pt


 23%|██▎       | 23/100 [01:21<04:31,  3.53s/it]

Epoch: 022, train_accuracy: 0.8571, best_validation_accuracy: 0.8574
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_27-10-2022_16h_27min_best_ValAcc_0.8574122972463222_epoch_22.pt


 24%|██▍       | 24/100 [01:24<04:28,  3.53s/it]

Epoch: 023, train_accuracy: 0.8549, best_validation_accuracy: 0.8601
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_27-10-2022_16h_27min_best_ValAcc_0.8600528102602791_epoch_23.pt


 25%|██▌       | 25/100 [01:28<04:25,  3.54s/it]

Epoch: 024, train_accuracy: 0.8602, best_validation_accuracy: 0.8616
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_27-10-2022_16h_27min_best_ValAcc_0.8615616748396832_epoch_24.pt


 26%|██▌       | 26/100 [01:31<04:21,  3.54s/it]

Epoch: 025, train_accuracy: 0.8632, best_validation_accuracy: 0.8665
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_27-10-2022_16h_27min_best_ValAcc_0.8664654847227461_epoch_25.pt


 27%|██▋       | 27/100 [01:35<04:18,  3.54s/it]

Epoch: 026, train_accuracy: 0.8663, best_validation_accuracy: 0.8699
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_27-10-2022_16h_27min_best_ValAcc_0.8698604300264051_epoch_26.pt


 32%|███▏      | 32/100 [01:52<04:00,  3.54s/it]

Epoch: 031, train_accuracy: 0.8758, best_validation_accuracy: 0.8763
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_27-10-2022_16h_27min_best_ValAcc_0.8762731044888721_epoch_31.pt


 33%|███▎      | 33/100 [01:56<03:57,  3.54s/it]

Epoch: 032, train_accuracy: 0.8793, best_validation_accuracy: 0.8774
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_27-10-2022_16h_28min_best_ValAcc_0.8774047529234251_epoch_32.pt


 36%|███▌      | 36/100 [02:07<03:45,  3.53s/it]

Epoch: 035, train_accuracy: 0.8771, best_validation_accuracy: 0.8842
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_27-10-2022_16h_28min_best_ValAcc_0.8841946435307431_epoch_35.pt


 40%|████      | 40/100 [02:21<03:32,  3.53s/it]

Epoch: 039, train_accuracy: 0.8821, best_validation_accuracy: 0.8883
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_27-10-2022_16h_28min_best_ValAcc_0.8883440211241042_epoch_39.pt


 49%|████▉     | 49/100 [02:53<02:59,  3.52s/it]

Epoch: 048, train_accuracy: 0.8874, best_validation_accuracy: 0.8951
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_27-10-2022_16h_28min_best_ValAcc_0.8951339117314221_epoch_48.pt


 50%|█████     | 50/100 [02:56<02:56,  3.53s/it]

Epoch: 049, train_accuracy: 0.8974, best_validation_accuracy: 0.8963
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_27-10-2022_16h_29min_best_ValAcc_0.8962655601659751_epoch_49.pt


 53%|█████▎    | 53/100 [03:07<02:46,  3.54s/it]

Epoch: 052, train_accuracy: 0.8970, best_validation_accuracy: 0.8985
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_27-10-2022_16h_29min_best_ValAcc_0.8985288570350811_epoch_52.pt


 57%|█████▋    | 57/100 [03:21<02:31,  3.52s/it]

Epoch: 056, train_accuracy: 0.8951, best_validation_accuracy: 0.9038
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_27-10-2022_16h_29min_best_ValAcc_0.9038098830629951_epoch_56.pt


 58%|█████▊    | 58/100 [03:24<02:28,  3.53s/it]

Epoch: 057, train_accuracy: 0.8980, best_validation_accuracy: 0.9057
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_27-10-2022_16h_29min_best_ValAcc_0.9056959637872501_epoch_57.pt


 60%|██████    | 60/100 [03:31<02:20,  3.51s/it]

Epoch: 059, train_accuracy: 0.8994, best_validation_accuracy: 0.9065
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_27-10-2022_16h_29min_best_ValAcc_0.9064503960769521_epoch_59.pt


 62%|██████▏   | 62/100 [03:38<02:13,  3.52s/it]

Epoch: 061, train_accuracy: 0.9031, best_validation_accuracy: 0.9076
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_27-10-2022_16h_29min_best_ValAcc_0.907582044511505_epoch_61.pt


 63%|██████▎   | 63/100 [03:42<02:10,  3.53s/it]

Epoch: 062, train_accuracy: 0.9037, best_validation_accuracy: 0.9087
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_27-10-2022_16h_29min_best_ValAcc_0.9087136929460581_epoch_62.pt


 74%|███████▍  | 74/100 [04:21<01:31,  3.51s/it]

Epoch: 073, train_accuracy: 0.9047, best_validation_accuracy: 0.9114
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_27-10-2022_16h_30min_best_ValAcc_0.911354205960015_epoch_73.pt


 76%|███████▌  | 76/100 [04:28<01:24,  3.52s/it]

Epoch: 075, train_accuracy: 0.9084, best_validation_accuracy: 0.9121
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_27-10-2022_16h_30min_best_ValAcc_0.9121086382497171_epoch_75.pt


 83%|████████▎ | 83/100 [04:52<00:59,  3.53s/it]

Epoch: 082, train_accuracy: 0.9094, best_validation_accuracy: 0.9132
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_27-10-2022_16h_30min_best_ValAcc_0.91324028668427_epoch_82.pt


 92%|█████████▏| 92/100 [05:24<00:28,  3.52s/it]

Epoch: 091, train_accuracy: 0.9120, best_validation_accuracy: 0.9147
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_27-10-2022_16h_31min_best_ValAcc_0.9147491512636741_epoch_91.pt


100%|██████████| 100/100 [05:52<00:00,  3.53s/it]


# -------------end-----------------