In [1]:
import os

from acevedo_clss_and_fcns import * 
def _pick_device() -> str:
    """Return 'cuda:0' when CUDA is available and initializes, otherwise 'cpu'."""
    if not torch.cuda.is_available():
        return 'cpu'
    # Explicitly initialize the CUDA runtime, then confirm it actually came up.
    torch.cuda.init()
    return 'cuda:0' if torch.cuda.is_initialized() else 'cpu'


device = _pick_device()
print(f"{device = }")



  from .autonotebook import tqdm as notebook_tqdm


device = 'cuda:0'


In [9]:
def train_one_epoch(modelo: "GIN_classifier_to_explain",
                    optimizer: "torch.optim.Optimizer",
                    train_loader: "torch_geometric.loader.dataloader.DataLoader",
                    loss_fun: "torch.nn.Module",
                    device: str = 'cpu') -> float:
    """Run one optimisation pass over ``train_loader`` and return the training accuracy.

    Every batch goes through forward -> loss -> backward -> optimizer step on
    whatever device is requested. (Previously the whole step was skipped unless
    ``device`` was exactly ``'cuda'``/``'cuda:0'``, so a CPU run never trained.)

    Args:
        modelo: graph classifier called as ``modelo(x, edge_index, batch)``;
            it must already live on ``device``.
        optimizer: optimiser over ``modelo``'s parameters.
        train_loader: iterable of batches exposing ``x``, ``edge_index``,
            ``batch``, ``y`` and ``is_cuda``; ``train_loader.dataset`` must
            support ``len``.
        loss_fun: criterion called as ``loss_fun(predictions, y)``.
        device: device string; any value starting with ``'cuda'`` moves each
            batch to that device before the forward pass.

    Returns:
        Fraction of training samples whose argmax prediction equals ``y``,
        measured on each batch's predictions *before* its optimizer step.
    """
    # validate() leaves the model in eval mode; switch back so any
    # dropout / batch-norm layers behave in training mode for this epoch.
    modelo.train()
    on_gpu = str(device).startswith('cuda')
    correct = 0
    for data in train_loader:
        # The loader is expected to hand out CPU batches; they are moved here.
        assert not data.is_cuda
        if on_gpu:
            data = data.to(device, non_blocking=True)
            assert data.is_cuda

        optimizer.zero_grad(set_to_none=True)  # clear gradients from the previous batch
        predictions = modelo(data.x, data.edge_index, data.batch)
        loss = loss_fun(predictions, data.y)
        loss.backward()   # derive gradients
        optimizer.step()  # update parameters from those gradients

        pred = predictions.argmax(dim=1)  # class with the highest score
        correct += int((pred == data.y).sum())

    return correct / len(train_loader.dataset)

def validate(modelo: "GIN_classifier_to_explain", loader: "DataLoader", device: str = 'cpu') -> float:
    """Return the classification accuracy of ``modelo`` over ``loader``.

    The model is switched to eval mode and run under ``torch.no_grad()`` so no
    autograd graph is built during evaluation. The model is left in eval mode
    on return.

    Args:
        modelo: graph classifier called as ``modelo(x, edge_index, batch)``;
            it must already live on ``device``.
        loader: iterable of batches exposing ``x``, ``edge_index``, ``batch``,
            ``y`` and ``is_cuda``; ``loader.dataset`` must support ``len``.
        device: device string; any value starting with ``'cuda'`` moves each
            batch to that device before the forward pass.

    Returns:
        Fraction of samples whose argmax prediction equals ``y``.
    """
    modelo.eval()
    on_gpu = str(device).startswith('cuda')
    correct = 0
    with torch.no_grad():  # inference only: skip building the autograd graph
        for val_data in loader:
            # The loader is expected to hand out CPU batches; they are moved here.
            assert not val_data.is_cuda
            if on_gpu:
                val_data = val_data.to(device, non_blocking=True)
                assert val_data.is_cuda

            val_predictions = modelo(val_data.x, val_data.edge_index, val_data.batch)
            pred = val_predictions.argmax(dim=1)  # class with the highest score
            correct += int((pred == val_data.y).sum())

    return correct / len(loader.dataset)

# Set up a GIN classifier for the NON-masked concentration + flux graphs.
# (A masked variant of the data exists at
#  ./results/dataloaders/MASKED_loader_Concen_plus_Fluxes.pt.)
saving_path     = "./results/trained_pytorch_models"
model_subfolder = "/Non_masked_Phe"
loader_path     = "./results/dataloaders/loader_Concen_plus_Fluxes.pt"
# NOTE(security): torch.load unpickles an arbitrary Python object (here a
# pickled loader wrapper) — only load files you trust. On torch >= 2.6 the
# default weights_only=True will reject this file; pass weights_only=False there.
loader = torch.load(loader_path)

# Infer the model's input/output sizes from the first graph of the first batch.
a_batch = next(iter(loader.get_train_loader()))
a_graph = a_batch[0]

torch.manual_seed(0)  # fixed seed so weight initialisation is reproducible
model = GIN_classifier_to_explain_v2(
    n_nodes=a_graph.num_nodes,
    num_features=a_graph.num_node_features,
    n_classes=a_graph.num_classes,
    hidden_dim=8,
    num_layers=2,
).to(device, non_blocking=True)

optimizer     = torch.optim.Adam(model.parameters())
# NLLLoss expects log-probabilities — presumably the model ends in log_softmax; TODO confirm.
loss_function = torch.nn.NLLLoss()
# -inf so the first epoch always beats it and produces a checkpoint.
best_validation_accuracy = float('-inf')
EPOCHS  = 10
verbose = True


In [10]:
# Train for EPOCHS epochs, checkpointing whenever validation accuracy improves.
# Checkpoints are written regardless of `verbose`; `verbose` only controls the
# progress messages (previously a quiet run silently saved nothing).
gc.collect()
torch.cuda.empty_cache()  # release cached GPU memory left over from earlier runs
checkpoint_dir = saving_path + model_subfolder
os.makedirs(checkpoint_dir, exist_ok=True)  # torch.save does not create missing folders

for epoch in tqdm.tqdm(range(EPOCHS)):
    train_accuracy = train_one_epoch(model,
                                     optimizer=optimizer,
                                     train_loader=loader.get_train_loader(),
                                     loss_fun=loss_function,
                                     device=device)
    validation_accuracy = validate(model, loader.get_validation_loader(), device)

    if validation_accuracy > best_validation_accuracy:
        best_validation_accuracy = validation_accuracy
        # Snapshot both the weights and a full copy of the best model so far.
        best_val_state_dict = copy.deepcopy(model.state_dict())
        best_val_model      = copy.deepcopy(model)

        timestamp  = datetime.now().strftime('%d-%m-%Y_%Hh_%Mmin')
        model_path = checkpoint_dir + '/Model_{}_{}_best_ValAcc_{}_epoch_{}.pt'.format(
            model.__class__.__name__, timestamp, best_validation_accuracy, epoch)
        # Pickles the whole module: its class must be importable when loading.
        torch.save(best_val_model, model_path)

        if verbose:
            print(f'Epoch: {epoch:03d}, train_accuracy: {train_accuracy:.4f}, best_validation_accuracy: {best_validation_accuracy:.4f}')
            print(f"saved as {model_path}")



 10%|█         | 1/10 [00:03<00:34,  3.82s/it]

Epoch: 000, train_accuracy: 0.7494, best_validation_accuracy: 0.9619
saved as ./results/trained_pytorch_models/Non_masked_Phe/Model_GIN_classifier_to_explain_v2_23-10-2022_15h_00min_best_ValAcc_0.961901169370049_epoch_0.pt


 20%|██        | 2/10 [00:07<00:30,  3.77s/it]

Epoch: 001, train_accuracy: 0.9774, best_validation_accuracy: 0.9800
saved as ./results/trained_pytorch_models/Non_masked_Phe/Model_GIN_classifier_to_explain_v2_23-10-2022_15h_00min_best_ValAcc_0.980007544322897_epoch_1.pt


 40%|████      | 4/10 [00:15<00:22,  3.76s/it]

Epoch: 003, train_accuracy: 0.9841, best_validation_accuracy: 0.9876
saved as ./results/trained_pytorch_models/Non_masked_Phe/Model_GIN_classifier_to_explain_v2_23-10-2022_15h_00min_best_ValAcc_0.9875518672199171_epoch_3.pt


 60%|██████    | 6/10 [00:22<00:15,  3.76s/it]

Epoch: 005, train_accuracy: 0.9872, best_validation_accuracy: 0.9887
saved as ./results/trained_pytorch_models/Non_masked_Phe/Model_GIN_classifier_to_explain_v2_23-10-2022_15h_00min_best_ValAcc_0.98868351565447_epoch_5.pt


 70%|███████   | 7/10 [00:26<00:11,  3.75s/it]

Epoch: 006, train_accuracy: 0.9862, best_validation_accuracy: 0.9891
saved as ./results/trained_pytorch_models/Non_masked_Phe/Model_GIN_classifier_to_explain_v2_23-10-2022_15h_00min_best_ValAcc_0.989060731799321_epoch_6.pt


100%|██████████| 10/10 [00:37<00:00,  3.75s/it]


In [15]:
# Same training procedure on the MASKED concentration + flux graphs.
saving_path     = "./results/trained_pytorch_models"
model_subfolder = "/Masked_Phe"
loader_path     = "./results/dataloaders/MASKED_loader_Concen_plus_Fluxes.pt"
# NOTE(security): torch.load unpickles an arbitrary Python object — only load
# trusted files. On torch >= 2.6 pass weights_only=False to load this loader.
loader = torch.load(loader_path)

# Re-derive the model dimensions from THIS loader instead of reusing `a_graph`
# left over from the non-masked run, so the cell does not depend on earlier
# kernel state.
a_batch = next(iter(loader.get_train_loader()))
a_graph = a_batch[0]

torch.manual_seed(0)  # fixed seed so weight initialisation is reproducible
model = GIN_classifier_to_explain_v2(
    n_nodes=a_graph.num_nodes,
    num_features=a_graph.num_node_features,
    n_classes=a_graph.num_classes,
    hidden_dim=8,
    num_layers=2,
).to(device, non_blocking=True)

optimizer     = torch.optim.Adam(model.parameters())
loss_function = torch.nn.NLLLoss()
# -inf so the first epoch always beats it and produces a checkpoint.
best_validation_accuracy = float('-inf')
EPOCHS  = 50
verbose = True

gc.collect()
torch.cuda.empty_cache()  # release cached GPU memory left over from earlier runs
checkpoint_dir = saving_path + model_subfolder
os.makedirs(checkpoint_dir, exist_ok=True)  # torch.save does not create missing folders

# Checkpoint on every validation improvement; `verbose` only gates the messages.
for epoch in tqdm.tqdm(range(EPOCHS)):
    train_accuracy = train_one_epoch(model,
                                     optimizer=optimizer,
                                     train_loader=loader.get_train_loader(),
                                     loss_fun=loss_function,
                                     device=device)
    validation_accuracy = validate(model, loader.get_validation_loader(), device)

    if validation_accuracy > best_validation_accuracy:
        best_validation_accuracy = validation_accuracy
        # Snapshot both the weights and a full copy of the best model so far.
        best_val_state_dict = copy.deepcopy(model.state_dict())
        best_val_model      = copy.deepcopy(model)

        timestamp  = datetime.now().strftime('%d-%m-%Y_%Hh_%Mmin')
        model_path = checkpoint_dir + '/Model_{}_{}_best_ValAcc_{}_epoch_{}.pt'.format(
            model.__class__.__name__, timestamp, best_validation_accuracy, epoch)
        # Pickles the whole module: its class must be importable when loading.
        torch.save(best_val_model, model_path)

        if verbose:
            print(f'Epoch: {epoch:03d}, train_accuracy: {train_accuracy:.4f}, best_validation_accuracy: {best_validation_accuracy:.4f}')
            print(f"saved as {model_path}")

  2%|▏         | 1/50 [00:03<03:04,  3.77s/it]

Epoch: 000, train_accuracy: 0.5097, best_validation_accuracy: 0.4855
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_23-10-2022_15h_06min_best_ValAcc_0.4854771784232365_epoch_0.pt


  4%|▍         | 2/50 [00:07<03:00,  3.76s/it]

Epoch: 001, train_accuracy: 0.6570, best_validation_accuracy: 0.7292
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_23-10-2022_15h_06min_best_ValAcc_0.7291588079969823_epoch_1.pt


  6%|▌         | 3/50 [00:11<02:55,  3.74s/it]

Epoch: 002, train_accuracy: 0.7123, best_validation_accuracy: 0.7627
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_23-10-2022_15h_06min_best_ValAcc_0.7627310448887212_epoch_2.pt


  8%|▊         | 4/50 [00:14<02:52,  3.74s/it]

Epoch: 003, train_accuracy: 0.7545, best_validation_accuracy: 0.7974
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_23-10-2022_15h_06min_best_ValAcc_0.7974349302150132_epoch_3.pt


 12%|█▏        | 6/50 [00:22<02:44,  3.75s/it]

Epoch: 005, train_accuracy: 0.7897, best_validation_accuracy: 0.8231
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_23-10-2022_15h_06min_best_ValAcc_0.8230856280648812_epoch_5.pt


 14%|█▍        | 7/50 [00:26<02:42,  3.78s/it]

Epoch: 006, train_accuracy: 0.8055, best_validation_accuracy: 0.8336
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_23-10-2022_15h_06min_best_ValAcc_0.8336476801207091_epoch_6.pt


 16%|█▌        | 8/50 [00:30<02:38,  3.77s/it]

Epoch: 007, train_accuracy: 0.8086, best_validation_accuracy: 0.8393
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_23-10-2022_15h_06min_best_ValAcc_0.8393059222934741_epoch_7.pt


 22%|██▏       | 11/50 [00:41<02:26,  3.76s/it]

Epoch: 010, train_accuracy: 0.8362, best_validation_accuracy: 0.8570
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_23-10-2022_15h_06min_best_ValAcc_0.8570350811014712_epoch_10.pt


 36%|███▌      | 18/50 [01:07<02:00,  3.76s/it]

Epoch: 017, train_accuracy: 0.8409, best_validation_accuracy: 0.8650
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_23-10-2022_15h_07min_best_ValAcc_0.8649566201433422_epoch_17.pt


 42%|████▏     | 21/50 [01:18<01:48,  3.75s/it]

Epoch: 020, train_accuracy: 0.8689, best_validation_accuracy: 0.8770
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_23-10-2022_15h_07min_best_ValAcc_0.8770275367785741_epoch_20.pt


 46%|████▌     | 23/50 [01:26<01:41,  3.75s/it]

Epoch: 022, train_accuracy: 0.8588, best_validation_accuracy: 0.8857
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_23-10-2022_15h_07min_best_ValAcc_0.8857035081101471_epoch_22.pt


 50%|█████     | 25/50 [01:33<01:33,  3.75s/it]

Epoch: 024, train_accuracy: 0.8736, best_validation_accuracy: 0.8910
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_23-10-2022_15h_07min_best_ValAcc_0.8909845341380611_epoch_24.pt


 62%|██████▏   | 31/50 [01:56<01:11,  3.75s/it]

Epoch: 030, train_accuracy: 0.8903, best_validation_accuracy: 0.8948
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_23-10-2022_15h_08min_best_ValAcc_0.8947566955865711_epoch_30.pt


 66%|██████▌   | 33/50 [02:03<01:03,  3.75s/it]

Epoch: 032, train_accuracy: 0.8862, best_validation_accuracy: 0.8982
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_23-10-2022_15h_08min_best_ValAcc_0.8981516408902301_epoch_32.pt


 72%|███████▏  | 36/50 [02:15<00:52,  3.75s/it]

Epoch: 035, train_accuracy: 0.8970, best_validation_accuracy: 0.9008
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_23-10-2022_15h_08min_best_ValAcc_0.9007921539041871_epoch_35.pt


 80%|████████  | 40/50 [02:30<00:37,  3.76s/it]

Epoch: 039, train_accuracy: 0.9075, best_validation_accuracy: 0.9095
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_23-10-2022_15h_08min_best_ValAcc_0.90946812523576_epoch_39.pt


 94%|█████████▍| 47/50 [02:56<00:11,  3.76s/it]

Epoch: 046, train_accuracy: 0.9132, best_validation_accuracy: 0.9121
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_23-10-2022_15h_09min_best_ValAcc_0.9121086382497171_epoch_46.pt


100%|██████████| 50/50 [03:07<00:00,  3.76s/it]

Epoch: 049, train_accuracy: 0.9179, best_validation_accuracy: 0.9132
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_23-10-2022_15h_09min_best_ValAcc_0.91324028668427_epoch_49.pt





# -------------end-----------------