In [1]:
import os

from acevedo_clss_and_fcns import * 
# Choose the compute device: the first CUDA GPU if the CUDA runtime initialises, else the CPU.
device = 'cpu'
if torch.cuda.is_available():
    torch.cuda.init()  # eagerly bring up the CUDA runtime so the check below is meaningful
    device = 'cuda:0' if torch.cuda.is_initialized() else device
print(f"{device = }")



  from .autonotebook import tqdm as notebook_tqdm


device = 'cuda:0'


In [2]:
def train_one_epoch(modelo: "GIN_classifier_to_explain_v2",
                    optimizer,
                    train_loader: "torch_geometric.loader.dataloader.DataLoader",
                    loss_fun: torch.nn.Module,
                    device: str = 'cpu') -> float:
    """Run one optimisation pass over ``train_loader`` and return the training accuracy.

    Args:
        modelo: graph classifier called as ``modelo(x, edge_index, batch)``; its output is
            passed to ``loss_fun`` and arg-maxed along dim 1 to obtain class predictions.
        optimizer: torch optimizer over ``modelo``'s parameters.
        train_loader: iterable of batches exposing ``x``, ``edge_index``, ``batch`` and ``y``,
            plus a ``.dataset`` attribute whose length is the accuracy denominator.
        loss_fun: loss module, e.g. ``torch.nn.NLLLoss()``.
        device: device string batches are moved to (``'cpu'``, ``'cuda'``, ``'cuda:0'``, ...).

    Returns:
        Fraction of training samples whose prediction (taken before each optimizer step)
        matched ``y``; 0.0 when the dataset is empty.
    """
    # validate() leaves the model in eval mode; switch back so training-mode behaviour
    # (e.g. dropout / batch-norm statistics, if the model has any) is active while fitting.
    modelo.train()
    correct = 0
    for data in train_loader:
        # Unconditional move: a no-op for CPU batches on 'cpu', and it covers any CUDA index.
        data = data.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)  # clear gradients left by the previous batch
        predictions = modelo(data.x, data.edge_index, data.batch)
        loss = loss_fun(predictions, data.y)
        loss.backward()   # derive gradients
        optimizer.step()  # update parameters

        # Accuracy bookkeeping does not need the autograd graph.
        pred = predictions.detach().argmax(dim=1)
        correct += int((pred == data.y).sum())

    n_samples = len(train_loader.dataset)
    return correct / n_samples if n_samples else 0.0

def validate(modelo: "GIN_classifier_to_explain_v2", loader: "DataLoader", device: str = 'cpu') -> float:
    """Compute classification accuracy of ``modelo`` over ``loader`` without updating it.

    The model is switched to eval mode (and left there); inference runs under
    ``torch.no_grad()`` so no autograd graph is recorded.

    Args:
        modelo: graph classifier called as ``modelo(x, edge_index, batch)``; its output is
            arg-maxed along dim 1 to obtain class predictions.
        loader: iterable of batches exposing ``x``, ``edge_index``, ``batch`` and ``y``,
            plus a ``.dataset`` attribute whose length is the accuracy denominator.
        device: device string batches are moved to (``'cpu'``, ``'cuda'``, ``'cuda:0'``, ...).

    Returns:
        Fraction of samples whose prediction matched ``y``; 0.0 when the dataset is empty.
    """
    modelo.eval()
    correct = 0
    with torch.no_grad():  # evaluation only: skip building the autograd graph
        for val_data in loader:
            # Unconditional move: a no-op for CPU batches on 'cpu', and it covers any CUDA index.
            val_data = val_data.to(device, non_blocking=True)
            val_predictions = modelo(val_data.x, val_data.edge_index, val_data.batch)
            pred = val_predictions.argmax(dim=1)
            correct += int((pred == val_data.y).sum())

    n_samples = len(loader.dataset)
    return correct / n_samples if n_samples else 0.0

# Configuration for the non-masked (Concentrations + Fluxes) training run.
saving_path     = "./results/trained_pytorch_models"
model_subfolder = "/Non_masked_Phe"
loader_path     = "./results/dataloaders/loader_Concen_plus_Fluxes.pt"
# torch.load unpickles the whole loader object, so only load trusted files.
# NOTE(review): torch >= 2.6 defaults to weights_only=True, which rejects arbitrary pickled
# objects such as this loader; on those versions pass weights_only=False explicitly.
loader = torch.load(loader_path)

# Peek at one training graph to size the model (nodes, node features, classes).
a_batch         = next(iter(loader.get_train_loader()))
a_graph         = a_batch[0]
model           = GIN_classifier_to_explain_v2(
                                            n_nodes = a_graph.num_nodes, 
                                            num_features = a_graph.num_node_features, 
                                            n_classes = a_graph.num_classes,
                                            hidden_dim=8,
                                            num_layers=2).to(device, non_blocking=True)


optimizer       = torch.optim.Adam(model.parameters())
loss_function   = torch.nn.NLLLoss()
# Any real accuracy (including 0.0) beats -inf, so the first epoch always records a best model.
best_validation_accuracy = float('-inf')
EPOCHS = 10
verbose = True


In [3]:
gc.collect()
torch.cuda.empty_cache() 
# Create the checkpoint folder up front so torch.save does not fail on a fresh checkout.
os.makedirs(saving_path + model_subfolder, exist_ok=True)
for epoch in tqdm.tqdm(range(EPOCHS)):

    train_accuracy = train_one_epoch(model,
                        optimizer=optimizer, 
                        train_loader=loader.get_train_loader(),
                        loss_fun=loss_function,
                        device = device)

    validation_accuracy = validate(model, loader.get_validation_loader(), device)
    if validation_accuracy > best_validation_accuracy:
        best_validation_accuracy = validation_accuracy
        best_val_state_dict   = copy.deepcopy(model.state_dict())
        best_val_model        = copy.deepcopy(model)
        # Checkpoint every improvement; `verbose` only controls what gets printed.
        timestamp     = datetime.now().strftime('%d-%m-%Y_%Hh_%Mmin')              
        model_path = saving_path+model_subfolder+'/Model_{}_{}_best_ValAcc_{}_epoch_{}.pt'.format(model.__class__.__name__,timestamp, best_validation_accuracy, epoch)
        if verbose:
            print(f'Epoch: {epoch:03d}, train_accuracy: {train_accuracy:.4f}, best_validation_accuracy: {best_validation_accuracy:.4f}')
        torch.save(best_val_model, model_path)
        if verbose:
            print(f"saved as {model_path}")



 10%|█         | 1/10 [00:04<00:41,  4.64s/it]

Epoch: 000, train_accuracy: 0.5794, best_validation_accuracy: 0.9604
saved as ./results/trained_pytorch_models/Non_masked_Phe/Model_GIN_classifier_to_explain_v2_25-10-2022_14h_05min_best_ValAcc_0.9603923047906451_epoch_0.pt


 20%|██        | 2/10 [00:08<00:32,  4.01s/it]

Epoch: 001, train_accuracy: 0.9616, best_validation_accuracy: 0.9770
saved as ./results/trained_pytorch_models/Non_masked_Phe/Model_GIN_classifier_to_explain_v2_25-10-2022_14h_05min_best_ValAcc_0.976989815164089_epoch_1.pt


 30%|███       | 3/10 [00:11<00:26,  3.82s/it]

Epoch: 002, train_accuracy: 0.9795, best_validation_accuracy: 0.9826
saved as ./results/trained_pytorch_models/Non_masked_Phe/Model_GIN_classifier_to_explain_v2_25-10-2022_14h_06min_best_ValAcc_0.982648057336854_epoch_2.pt


 40%|████      | 4/10 [00:15<00:22,  3.73s/it]

Epoch: 003, train_accuracy: 0.9850, best_validation_accuracy: 0.9879
saved as ./results/trained_pytorch_models/Non_masked_Phe/Model_GIN_classifier_to_explain_v2_25-10-2022_14h_06min_best_ValAcc_0.987929083364768_epoch_3.pt


 50%|█████     | 5/10 [00:18<00:18,  3.68s/it]

Epoch: 004, train_accuracy: 0.9860, best_validation_accuracy: 0.9891
saved as ./results/trained_pytorch_models/Non_masked_Phe/Model_GIN_classifier_to_explain_v2_25-10-2022_14h_06min_best_ValAcc_0.989060731799321_epoch_4.pt


 60%|██████    | 6/10 [00:22<00:14,  3.64s/it]

Epoch: 005, train_accuracy: 0.9872, best_validation_accuracy: 0.9894
saved as ./results/trained_pytorch_models/Non_masked_Phe/Model_GIN_classifier_to_explain_v2_25-10-2022_14h_06min_best_ValAcc_0.9894379479441721_epoch_5.pt


100%|██████████| 10/10 [00:36<00:00,  3.69s/it]

Epoch: 009, train_accuracy: 0.9884, best_validation_accuracy: 0.9898
saved as ./results/trained_pytorch_models/Non_masked_Phe/Model_GIN_classifier_to_explain_v2_25-10-2022_14h_06min_best_ValAcc_0.989815164089023_epoch_9.pt





# Masked dataset: retrain the same model on `MASKED_loader_Concen_plus_Fluxes.pt`

In [4]:
# Same architecture and training recipe, retrained on the masked dataset.
saving_path     = "./results/trained_pytorch_models"
model_subfolder = "/Masked_Phe"
loader_path     = "./results/dataloaders/MASKED_loader_Concen_plus_Fluxes.pt"
# torch.load unpickles the whole loader object, so only load trusted files.
loader = torch.load(loader_path)
# Size the model from a graph of *this* loader instead of reusing `a_graph` from the
# non-masked cell, so this cell no longer depends on earlier kernel state.
a_batch         = next(iter(loader.get_train_loader()))
a_graph         = a_batch[0]
model           = GIN_classifier_to_explain_v2(
                                            n_nodes = a_graph.num_nodes, 
                                            num_features = a_graph.num_node_features, 
                                            n_classes = a_graph.num_classes,
                                            hidden_dim=8,
                                            num_layers=2).to(device, non_blocking=True)

optimizer       = torch.optim.Adam(model.parameters())
loss_function   = torch.nn.NLLLoss()
# Any real accuracy (including 0.0) beats -inf, so the first epoch always records a best model.
best_validation_accuracy = float('-inf')
EPOCHS = 50
verbose = True
gc.collect()
torch.cuda.empty_cache() 
# Create the checkpoint folder up front so torch.save does not fail on a fresh checkout.
os.makedirs(saving_path + model_subfolder, exist_ok=True)
for epoch in tqdm.tqdm(range(EPOCHS)):

    train_accuracy = train_one_epoch(model,
                        optimizer=optimizer, 
                        train_loader=loader.get_train_loader(),
                        loss_fun=loss_function,
                        device = device)

    validation_accuracy = validate(model, loader.get_validation_loader(), device)
    if validation_accuracy > best_validation_accuracy:
        best_validation_accuracy = validation_accuracy
        best_val_state_dict   = copy.deepcopy(model.state_dict())
        best_val_model        = copy.deepcopy(model)
        # Checkpoint every improvement; `verbose` only controls what gets printed.
        timestamp     = datetime.now().strftime('%d-%m-%Y_%Hh_%Mmin')              
        model_path = saving_path+model_subfolder+'/Model_{}_{}_best_ValAcc_{}_epoch_{}.pt'.format(model.__class__.__name__,timestamp, best_validation_accuracy, epoch)
        if verbose:
            print(f'Epoch: {epoch:03d}, train_accuracy: {train_accuracy:.4f}, best_validation_accuracy: {best_validation_accuracy:.4f}')
        torch.save(best_val_model, model_path)
        if verbose:
            print(f"saved as {model_path}")

  2%|▏         | 1/50 [00:03<02:57,  3.62s/it]

Epoch: 000, train_accuracy: 0.5237, best_validation_accuracy: 0.6813
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_25-10-2022_14h_07min_best_ValAcc_0.6812523576009053_epoch_0.pt


  6%|▌         | 3/50 [00:10<02:48,  3.59s/it]

Epoch: 002, train_accuracy: 0.6696, best_validation_accuracy: 0.6937
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_25-10-2022_14h_07min_best_ValAcc_0.6937004903809884_epoch_2.pt


  8%|▊         | 4/50 [00:14<02:44,  3.59s/it]

Epoch: 003, train_accuracy: 0.7019, best_validation_accuracy: 0.7077
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_25-10-2022_14h_07min_best_ValAcc_0.7076574877404753_epoch_3.pt


 10%|█         | 5/50 [00:17<02:41,  3.58s/it]

Epoch: 004, train_accuracy: 0.7098, best_validation_accuracy: 0.7416
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_25-10-2022_14h_07min_best_ValAcc_0.7416069407770652_epoch_4.pt


 12%|█▏        | 6/50 [00:21<02:37,  3.58s/it]

Epoch: 005, train_accuracy: 0.7492, best_validation_accuracy: 0.7714
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_25-10-2022_14h_07min_best_ValAcc_0.7714070162202943_epoch_5.pt


 14%|█▍        | 7/50 [00:25<02:33,  3.58s/it]

Epoch: 006, train_accuracy: 0.7838, best_validation_accuracy: 0.8197
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_25-10-2022_14h_07min_best_ValAcc_0.8196906827612221_epoch_6.pt


 16%|█▌        | 8/50 [00:28<02:30,  3.58s/it]

Epoch: 007, train_accuracy: 0.8232, best_validation_accuracy: 0.8510
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_25-10-2022_14h_07min_best_ValAcc_0.8509996227838551_epoch_7.pt


 20%|██        | 10/50 [00:35<02:23,  3.58s/it]

Epoch: 009, train_accuracy: 0.8728, best_validation_accuracy: 0.8684
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_25-10-2022_14h_07min_best_ValAcc_0.8683515654470011_epoch_9.pt


 22%|██▏       | 11/50 [00:39<02:19,  3.58s/it]

Epoch: 010, train_accuracy: 0.8854, best_validation_accuracy: 0.8853
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_25-10-2022_14h_07min_best_ValAcc_0.8853262919652961_epoch_10.pt


 24%|██▍       | 12/50 [00:42<02:15,  3.58s/it]

Epoch: 011, train_accuracy: 0.8907, best_validation_accuracy: 0.8982
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_25-10-2022_14h_07min_best_ValAcc_0.8981516408902301_epoch_11.pt


 28%|██▊       | 14/50 [00:50<02:08,  3.57s/it]

Epoch: 013, train_accuracy: 0.9021, best_validation_accuracy: 0.9000
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_25-10-2022_14h_07min_best_ValAcc_0.9000377216144851_epoch_13.pt


 30%|███       | 15/50 [00:53<02:05,  3.58s/it]

Epoch: 014, train_accuracy: 0.9051, best_validation_accuracy: 0.9027
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_25-10-2022_14h_07min_best_ValAcc_0.9026782346284421_epoch_14.pt


 34%|███▍      | 17/50 [01:00<01:58,  3.58s/it]

Epoch: 016, train_accuracy: 0.9140, best_validation_accuracy: 0.9034
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_25-10-2022_14h_08min_best_ValAcc_0.9034326669181441_epoch_16.pt


 36%|███▌      | 18/50 [01:04<01:54,  3.58s/it]

Epoch: 017, train_accuracy: 0.9161, best_validation_accuracy: 0.9046
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_25-10-2022_14h_08min_best_ValAcc_0.9045643153526971_epoch_17.pt


 42%|████▏     | 21/50 [01:15<01:43,  3.58s/it]

Epoch: 020, train_accuracy: 0.9165, best_validation_accuracy: 0.9087
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_25-10-2022_14h_08min_best_ValAcc_0.9087136929460581_epoch_20.pt


 48%|████▊     | 24/50 [01:25<01:32,  3.58s/it]

Epoch: 023, train_accuracy: 0.9199, best_validation_accuracy: 0.9091
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_25-10-2022_14h_08min_best_ValAcc_0.9090909090909091_epoch_23.pt


 60%|██████    | 30/50 [01:47<01:11,  3.58s/it]

Epoch: 029, train_accuracy: 0.9222, best_validation_accuracy: 0.9095
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_25-10-2022_14h_08min_best_ValAcc_0.90946812523576_epoch_29.pt


 78%|███████▊  | 39/50 [02:19<00:39,  3.57s/it]

Epoch: 038, train_accuracy: 0.9256, best_validation_accuracy: 0.9106
saved as ./results/trained_pytorch_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_25-10-2022_14h_09min_best_ValAcc_0.9105997736703131_epoch_38.pt


100%|██████████| 50/50 [02:58<00:00,  3.57s/it]


# -------------end-----------------