In [1]:
from acevedo_clss_and_fcns import * 
# Select the compute device: the first CUDA GPU when CUDA is available and
# its runtime initialises successfully, otherwise fall back to the CPU.
cuda_ready = False
if torch.cuda.is_available():
    torch.cuda.init()
    cuda_ready = torch.cuda.is_initialized()
device = 'cuda:0' if cuda_ready else 'cpu'
print(f"{device = }")
def train_one_epoch(modelo: "GIN_classifier_to_explain_v2",
                    optimizer,
                    train_loader: "torch_geometric.loader.DataLoader",
                    loss_fun: torch.nn.Module,
                    device: str = 'cpu'):
    """Run one optimisation pass over ``train_loader`` and return the training accuracy.

    Args:
        modelo: Model called as ``modelo(x, edge_index, batch)``; its output is
            reduced with ``argmax(dim=1)`` to obtain the predicted class.
        optimizer: Optimizer holding ``modelo``'s parameters.
        train_loader: Iterable of batches exposing ``x``, ``edge_index``,
            ``batch``, ``y`` and ``to(device)``; must also expose ``dataset``.
        loss_fun: Callable ``loss_fun(predictions, y)`` returning a scalar loss.
        device: Device every batch is moved to before the forward pass. The
            model is assumed to already live on this device.

    Returns:
        Number of correctly classified samples divided by
        ``len(train_loader.dataset)``.
    """
    # validate() leaves the model in eval mode; switch back so layers such as
    # dropout / batch-norm behave correctly on every epoch, not only the first.
    modelo.train()

    correct = 0
    for data in train_loader:
        # Move the batch for *any* device string; previously the whole training
        # step was skipped unless device was exactly 'cuda' or 'cuda:0'.
        data = data.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)  # reset gradients for this batch
        predictions = modelo(data.x, data.edge_index, data.batch)
        loss = loss_fun(predictions, data.y)
        loss.backward()   # derive gradients
        optimizer.step()  # update parameters based on gradients

        pred = predictions.argmax(dim=1)        # class with the highest score
        correct += int((pred == data.y).sum())  # compare with ground-truth labels

    return correct / len(train_loader.dataset)

def validate(modelo: "GIN_classifier_to_explain_v2", loader: "DataLoader", device: str = 'cpu'):
    """Evaluate ``modelo`` on ``loader`` and return its accuracy.

    The model is switched to eval mode and is left in that mode on return.

    Args:
        modelo: Model called as ``modelo(x, edge_index, batch)``; its output is
            reduced with ``argmax(dim=1)`` to obtain the predicted class.
        loader: Iterable of batches exposing ``x``, ``edge_index``, ``batch``,
            ``y`` and ``to(device)``; must also expose ``dataset``.
        device: Device every batch is moved to before the forward pass. The
            model is assumed to already live on this device.

    Returns:
        Number of correctly classified samples divided by ``len(loader.dataset)``.
    """
    modelo.eval()
    correct = 0
    # No gradients are needed for evaluation; disabling autograd avoids building
    # (and holding in memory) a computation graph for every validation batch.
    with torch.no_grad():
        for val_data in loader:
            val_data = val_data.to(device, non_blocking=True)

            val_predictions = modelo(val_data.x, val_data.edge_index, val_data.batch)
            pred = val_predictions.argmax(dim=1)

            correct += int((pred == val_data.y).sum())

    return correct / len(loader.dataset)

  from .autonotebook import tqdm as notebook_tqdm


device = 'cuda:0'


### Unmasked

In [2]:
# Set up the run on the non-masked dataset: output locations, data loader,
# model, optimizer, loss and training hyper-parameters.
loader_path      = "./results/dataloaders/loader_Concen_plus_Fluxes.pt"
saving_folder    = "./results/trained_models"
saving_subfolder = "/Non_masked_Phe"
# NOTE(review): torch.load unpickles arbitrary Python objects; only load
# loader files produced by this project.
loader = torch.load(loader_path)

# Take one graph from the first training batch to size the model.
a_batch = next(iter(loader.get_train_loader()))
a_graph = a_batch[0]
model   = GIN_classifier_to_explain_v2(
                                            n_nodes = a_graph.num_nodes,
                                            num_features = a_graph.num_node_features,
                                            n_classes = a_graph.num_classes,
                                            hidden_dim=8,
                                            num_layers=2).to(device, non_blocking=True)  # a single transfer is enough


optimizer       = torch.optim.Adam(model.parameters())
loss_function   = torch.nn.NLLLoss()
best_validation_accuracy = 1e-10  # any real validation accuracy above this counts as an improvement
EPOCHS = 10
verbose = True


In [4]:
# Train for EPOCHS epochs and checkpoint the model every time the validation
# accuracy improves on the best value seen so far.
gc.collect()
torch.cuda.empty_cache()
for epoch in tqdm.tqdm(range(EPOCHS)):

    train_accuracy = train_one_epoch(model,
                        optimizer=optimizer,
                        train_loader=loader.get_train_loader(),
                        loss_fun=loss_function,
                        device = device)

    validation_accuracy = validate(model, loader.get_validation_loader(), device)
    if validation_accuracy > best_validation_accuracy:
        best_validation_accuracy = validation_accuracy
        best_val_state_dict   = copy.deepcopy(model.state_dict())
        best_val_model        = copy.deepcopy(model)

        if verbose:
            print(f'Epoch: {epoch:03d}, train_accuracy: {train_accuracy:.4f}, best_validation_accuracy: {best_validation_accuracy:.4f}')

        # Save on every improvement regardless of `verbose`; previously the
        # checkpoint was only written when progress printing was enabled.
        timestamp  = datetime.now().strftime('%d-%m-%Y_%Hh_%Mmin')
        model_path = saving_folder+saving_subfolder+'/Model_{}_{}_best_ValAcc_{}_epoch_{}.pt'.format(model.__class__.__name__,timestamp, best_validation_accuracy, epoch)
        torch.save(best_val_model, model_path)

        if verbose:
            print(f"saved as {model_path}")

 10%|█         | 1/10 [00:03<00:31,  3.46s/it]

Epoch: 000, train_accuracy: 0.9524, best_validation_accuracy: 0.9611
saved as ./results/trained_models/Non_masked_Phe/Model_GIN_classifier_to_explain_v2_06-11-2022_19h_04min_best_ValAcc_0.961146737080347_epoch_0.pt


 20%|██        | 2/10 [00:06<00:27,  3.48s/it]

Epoch: 001, train_accuracy: 0.9663, best_validation_accuracy: 0.9777
saved as ./results/trained_models/Non_masked_Phe/Model_GIN_classifier_to_explain_v2_06-11-2022_19h_05min_best_ValAcc_0.9777442474537911_epoch_1.pt


 30%|███       | 3/10 [00:10<00:24,  3.49s/it]

Epoch: 002, train_accuracy: 0.9785, best_validation_accuracy: 0.9842
saved as ./results/trained_models/Non_masked_Phe/Model_GIN_classifier_to_explain_v2_06-11-2022_19h_05min_best_ValAcc_0.984156921916258_epoch_2.pt


 40%|████      | 4/10 [00:14<00:21,  3.53s/it]

Epoch: 003, train_accuracy: 0.9839, best_validation_accuracy: 0.9876
saved as ./results/trained_models/Non_masked_Phe/Model_GIN_classifier_to_explain_v2_06-11-2022_19h_05min_best_ValAcc_0.9875518672199171_epoch_3.pt


 50%|█████     | 5/10 [00:17<00:17,  3.55s/it]

Epoch: 004, train_accuracy: 0.9854, best_validation_accuracy: 0.9883
saved as ./results/trained_models/Non_masked_Phe/Model_GIN_classifier_to_explain_v2_06-11-2022_19h_05min_best_ValAcc_0.988306299509619_epoch_4.pt


 60%|██████    | 6/10 [00:21<00:14,  3.54s/it]

Epoch: 005, train_accuracy: 0.9860, best_validation_accuracy: 0.9887
saved as ./results/trained_models/Non_masked_Phe/Model_GIN_classifier_to_explain_v2_06-11-2022_19h_05min_best_ValAcc_0.98868351565447_epoch_5.pt


 70%|███████   | 7/10 [00:24<00:10,  3.54s/it]

Epoch: 006, train_accuracy: 0.9868, best_validation_accuracy: 0.9891
saved as ./results/trained_models/Non_masked_Phe/Model_GIN_classifier_to_explain_v2_06-11-2022_19h_05min_best_ValAcc_0.989060731799321_epoch_6.pt


100%|██████████| 10/10 [00:35<00:00,  3.52s/it]

Epoch: 009, train_accuracy: 0.9882, best_validation_accuracy: 0.9894
saved as ./results/trained_models/Non_masked_Phe/Model_GIN_classifier_to_explain_v2_06-11-2022_19h_05min_best_ValAcc_0.9894379479441721_epoch_9.pt





### Masked

In [5]:
# Repeat the training run on the masked dataset (MASKED_loader_*), saving the
# checkpoints under the Masked_Phe subfolder.
saving_folder    = "./results/trained_models"
saving_subfolder = "/Masked_Phe"
loader_path      = "./results/dataloaders/MASKED_loader_Concen_plus_Fluxes.pt"
# NOTE(review): torch.load unpickles arbitrary Python objects; only load
# loader files produced by this project.
loader = torch.load(loader_path)

# Size the model from a graph of the loader actually being trained on, instead
# of relying on an `a_graph` left in the kernel by an earlier cell.
a_batch = next(iter(loader.get_train_loader()))
a_graph = a_batch[0]
model   = GIN_classifier_to_explain_v2(
                                            n_nodes = a_graph.num_nodes,
                                            num_features = a_graph.num_node_features,
                                            n_classes = a_graph.num_classes,
                                            hidden_dim=8,
                                            num_layers=2).to(device, non_blocking=True)  # a single transfer is enough

optimizer       = torch.optim.Adam(model.parameters())
loss_function   = torch.nn.NLLLoss()
best_validation_accuracy = 1e-10  # any real validation accuracy above this counts as an improvement
EPOCHS = 100
verbose = True
gc.collect()
torch.cuda.empty_cache()
for epoch in tqdm.tqdm(range(EPOCHS)):

    train_accuracy = train_one_epoch(model,
                        optimizer=optimizer,
                        train_loader=loader.get_train_loader(),
                        loss_fun=loss_function,
                        device = device)

    validation_accuracy = validate(model, loader.get_validation_loader(), device)
    if validation_accuracy > best_validation_accuracy:
        best_validation_accuracy = validation_accuracy
        best_val_state_dict   = copy.deepcopy(model.state_dict())
        best_val_model        = copy.deepcopy(model)

        if verbose:
            print(f'Epoch: {epoch:03d}, train_accuracy: {train_accuracy:.4f}, best_validation_accuracy: {best_validation_accuracy:.4f}')

        # Save on every improvement regardless of `verbose`; previously the
        # checkpoint was only written when progress printing was enabled.
        timestamp  = datetime.now().strftime('%d-%m-%Y_%Hh_%Mmin')
        model_path = saving_folder+saving_subfolder+'/Model_{}_{}_best_ValAcc_{}_epoch_{}.pt'.format(model.__class__.__name__,timestamp, best_validation_accuracy, epoch)
        torch.save(best_val_model, model_path)

        if verbose:
            print(f"saved as {model_path}")

  1%|          | 1/100 [00:03<05:52,  3.56s/it]

Epoch: 000, train_accuracy: 0.5670, best_validation_accuracy: 0.7258
saved as ./results/trained_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_06-11-2022_19h_07min_best_ValAcc_0.7257638626933233_epoch_0.pt


  3%|▎         | 3/100 [00:10<05:46,  3.57s/it]

Epoch: 002, train_accuracy: 0.7507, best_validation_accuracy: 0.7744
saved as ./results/trained_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_06-11-2022_19h_07min_best_ValAcc_0.7744247453791022_epoch_2.pt


  4%|▍         | 4/100 [00:14<05:39,  3.54s/it]

Epoch: 003, train_accuracy: 0.7887, best_validation_accuracy: 0.8182
saved as ./results/trained_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_06-11-2022_19h_07min_best_ValAcc_0.8181818181818182_epoch_3.pt


  5%|▌         | 5/100 [00:17<05:34,  3.52s/it]

Epoch: 004, train_accuracy: 0.8348, best_validation_accuracy: 0.8487
saved as ./results/trained_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_06-11-2022_19h_07min_best_ValAcc_0.8487363259147491_epoch_4.pt


  6%|▌         | 6/100 [00:21<05:29,  3.51s/it]

Epoch: 005, train_accuracy: 0.8447, best_validation_accuracy: 0.8514
saved as ./results/trained_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_06-11-2022_19h_07min_best_ValAcc_0.8513768389287062_epoch_5.pt


  7%|▋         | 7/100 [00:24<05:25,  3.50s/it]

Epoch: 006, train_accuracy: 0.8588, best_validation_accuracy: 0.8525
saved as ./results/trained_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_06-11-2022_19h_07min_best_ValAcc_0.8525084873632591_epoch_6.pt


  8%|▊         | 8/100 [00:28<05:21,  3.49s/it]

Epoch: 007, train_accuracy: 0.8663, best_validation_accuracy: 0.8676
saved as ./results/trained_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_06-11-2022_19h_07min_best_ValAcc_0.8675971331572991_epoch_7.pt


 11%|█         | 11/100 [00:38<05:09,  3.48s/it]

Epoch: 010, train_accuracy: 0.8862, best_validation_accuracy: 0.8789
saved as ./results/trained_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_06-11-2022_19h_08min_best_ValAcc_0.8789136175028291_epoch_10.pt


 12%|█▏        | 12/100 [00:42<05:06,  3.48s/it]

Epoch: 011, train_accuracy: 0.8880, best_validation_accuracy: 0.8849
saved as ./results/trained_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_06-11-2022_19h_08min_best_ValAcc_0.8849490758204451_epoch_11.pt


 13%|█▎        | 13/100 [00:45<05:03,  3.49s/it]

Epoch: 012, train_accuracy: 0.8955, best_validation_accuracy: 0.8853
saved as ./results/trained_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_06-11-2022_19h_08min_best_ValAcc_0.8853262919652961_epoch_12.pt


 15%|█▌        | 15/100 [00:52<04:56,  3.48s/it]

Epoch: 014, train_accuracy: 0.8955, best_validation_accuracy: 0.8944
saved as ./results/trained_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_06-11-2022_19h_08min_best_ValAcc_0.8943794794417201_epoch_14.pt


 17%|█▋        | 17/100 [00:59<04:49,  3.49s/it]

Epoch: 016, train_accuracy: 0.9092, best_validation_accuracy: 0.8974
saved as ./results/trained_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_06-11-2022_19h_08min_best_ValAcc_0.8973972086005281_epoch_16.pt


 21%|██        | 21/100 [01:13<04:34,  3.48s/it]

Epoch: 020, train_accuracy: 0.9053, best_validation_accuracy: 0.9068
saved as ./results/trained_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_06-11-2022_19h_08min_best_ValAcc_0.9068276122218031_epoch_20.pt


 26%|██▌       | 26/100 [01:30<04:18,  3.49s/it]

Epoch: 025, train_accuracy: 0.9212, best_validation_accuracy: 0.9080
saved as ./results/trained_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_06-11-2022_19h_08min_best_ValAcc_0.9079592606563561_epoch_25.pt


 28%|██▊       | 28/100 [01:37<04:12,  3.51s/it]

Epoch: 027, train_accuracy: 0.9228, best_validation_accuracy: 0.9110
saved as ./results/trained_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_06-11-2022_19h_09min_best_ValAcc_0.9109769898151641_epoch_27.pt


 29%|██▉       | 29/100 [01:41<04:08,  3.50s/it]

Epoch: 028, train_accuracy: 0.9220, best_validation_accuracy: 0.9114
saved as ./results/trained_models/Masked_Phe/Model_GIN_classifier_to_explain_v2_06-11-2022_19h_09min_best_ValAcc_0.911354205960015_epoch_28.pt


 30%|███       | 30/100 [01:44<04:05,  3.50s/it]