In [1]:
import os

from acevedo_clss_and_fcns import * 
# Default to the CPU; switch to the first GPU only once CUDA reports itself as initialised.
device = 'cpu'
if torch.cuda.is_available():
    torch.cuda.init()
    device = 'cuda:0' if torch.cuda.is_initialized() else device
print(f"{device = }")



  from .autonotebook import tqdm as notebook_tqdm


device = 'cuda:0'


In [4]:
def train_one_epoch(modelo: "GIN_classifier_to_explain",
                    optimizer,
                    train_loader: "torch_geometric.loader.DataLoader",
                    loss_fun: torch.nn.Module,
                    device: str = 'cpu'):
    """Run one optimisation pass over ``train_loader`` and return the training accuracy.

    Args:
        modelo: Model called as ``modelo(x, edge_index, None)``; it is expected to
            already live on ``device`` (this function does not move it).
        optimizer: Optimizer that owns ``modelo``'s parameters.
        train_loader: Iterable of batches exposing ``x``, ``edge_index``, ``y`` and
            ``to()``; it must also expose ``dataset`` so the accuracy can be normalised.
        loss_fun: Callable invoked as ``loss_fun(predictions, batch.y)``.
        device: Device every batch is moved to before the forward pass.

    Returns:
        Number of correct predictions divided by ``len(train_loader.dataset)``.
    """
    # validate() leaves the model in eval mode, so training mode has to be
    # re-enabled at the start of every epoch.
    modelo.train()

    correct = 0
    for data in train_loader:
        # Keep the returned object instead of relying on ``to`` mutating in place.
        data = data.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)  # Reset gradients for every batch.
        predictions = modelo(data.x, data.edge_index, None)
        loss = loss_fun(predictions, data.y)
        loss.backward()   # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.

        pred = predictions.argmax(dim=1)  # Class with the highest score.
        correct += int((pred == data.y).sum())

    return correct / len(train_loader.dataset)

def validate(modelo: "GIN_classifier_to_explain", loader: "DataLoader", device: str = 'cpu'):
    """Return the accuracy of ``modelo`` over every batch produced by ``loader``.

    Args:
        modelo: Model called as ``modelo(x, edge_index, None)``; it is switched to
            eval mode and left in that mode on return.
        loader: Iterable of batches exposing ``x``, ``edge_index``, ``y`` and ``to()``;
            it must also expose ``dataset`` so the accuracy can be normalised.
        device: Device every batch is moved to before the forward pass.

    Returns:
        Number of correct predictions divided by ``len(loader.dataset)``.
    """
    modelo.eval()
    correct = 0
    # No gradients are needed for evaluation; disabling autograd avoids building
    # a graph for every validation batch.
    with torch.no_grad():
        for val_data in loader:
            # Keep the returned object instead of relying on ``to`` mutating in place.
            val_data = val_data.to(device, non_blocking=True)

            val_predictions = modelo(val_data.x, val_data.edge_index, None)
            pred = val_predictions.argmax(dim=1)  # Class with the highest score.
            correct += int((pred == val_data.y).sum())

    return correct / len(loader.dataset)

# --- Data -------------------------------------------------------------------
loader_path = "./results/dataloaders/MASKED_loader_Concen_plus_Fluxes.pt"
# NOTE(review): torch.load unpickles arbitrary Python objects; only load files
# that were produced by this project.
loader = torch.load(loader_path)

# Peek at one training batch to size the network.
a_batch = next(iter(loader.get_train_loader()))
a_graph = a_batch[0]
# Presumably `ptr` holds one boundary more than there are graphs in the batch -- TODO confirm.
batch_size = len(a_batch.ptr) - 1

# --- Model ------------------------------------------------------------------
# Free whatever earlier cells left behind before allocating the model.
gc.collect()
torch.cuda.empty_cache()

model = GIN_classifier_to_explain(batch_size=batch_size,
                                  n_nodes=a_graph.num_nodes,
                                  num_features=a_graph.num_node_features,
                                  hidden_dim=16)
model = model.to(device, non_blocking=True)

optimizer = torch.optim.Adam(model.parameters())
loss_function = torch.nn.NLLLoss()

# --- Training configuration -------------------------------------------------
# Tiny positive starting value: any validation accuracy above it counts as an improvement.
best_validation_accuracy = 1e-10
EPOCHS = 220
verbose = True


In [None]:
saving_path   = "results/trained_pytorch_models"
# torch.save does not create missing directories, so make sure the target exists.
os.makedirs(saving_path, exist_ok=True)

for epoch in tqdm.tqdm(range(EPOCHS)):

    train_accuracy = train_one_epoch(model,
                                     optimizer=optimizer,
                                     train_loader=loader.get_train_loader(),
                                     loss_fun=loss_function,
                                     device=device)
    validation_accuracy = validate(model, loader.get_validation_loader(), device)

    # Only epochs that improve on the best validation accuracy produce a checkpoint.
    if validation_accuracy <= best_validation_accuracy:
        continue

    best_validation_accuracy = validation_accuracy
    # Snapshot the weights and the whole model: training keeps mutating `model`.
    best_val_state_dict = copy.deepcopy(model.state_dict())
    best_val_model = copy.deepcopy(model)

    # Checkpointing is independent of `verbose`; the flag only controls logging.
    timestamp = datetime.now().strftime('%d-%m-%Y_%Hh_%Mmin')
    model_path = saving_path+'/Model_{}_{}_best_ValAcc_{}_epoch_{}.pt'.format(model.__class__.__name__,timestamp, best_validation_accuracy, epoch)
    torch.save(best_val_model, model_path)

    if verbose:
        print(f'Epoch: {epoch:03d}, train_accuracy: {train_accuracy:.4f}, best_validation_accuracy: {best_validation_accuracy:.4f}')
        print(f"saved as {model_path}")



In [6]:
# Persist the best model kept by the training loop under a timestamped file name.
saving_path = "results/trained_pytorch_models"
timestamp = datetime.now().strftime('%d-%m-%Y_%Hh_%Mmin')

# NOTE(review): `epoch` is whatever value the training loop left behind, which is
# not necessarily the epoch at which `best_validation_accuracy` was reached.
file_name = '/Model_{}_{}_best_ValAcc_{}_epoch_{}.pt'.format(
    model.__class__.__name__,
    timestamp,
    best_validation_accuracy,
    epoch,
)
model_path = saving_path + file_name
torch.save(best_val_model, model_path)

# -------------end-----------------

In [8]:
def simple_train_validate_from_loader(loader_path, EPOCHS:int=10, device:str='cpu'):
    """Train a fresh ``GIN_classifier`` on a saved loader and print per-epoch accuracy.

    Args:
        loader_path: File passed to ``torch.load``; the loaded object must provide
            ``get_train_loader()`` and ``get_validation_loader()``.
        EPOCHS: Number of training epochs to run.
        device: Device used for both the model and every batch.

    Returns:
        None. Train and validation accuracy are printed after each epoch.
    """
    gc.collect()
    torch.cuda.empty_cache()
    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
    loader = torch.load(loader_path)
    a_batch = next(iter(loader.get_train_loader()))
    a_graph = a_batch[0]
    model = GIN_classifier(0, a_graph.num_nodes, a_graph.num_node_features)
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    criterion = torch.nn.NLLLoss()

    def train(batches):
        """Run one optimisation pass over ``batches``."""
        model.train()
        for data in batches:
            # Batches follow the requested device instead of a hard-coded 'cuda',
            # so the model and its inputs always sit on the same device.
            data = data.to(device)
            optimizer.zero_grad()          # Clear gradients.
            out = model(data)              # Single forward pass.
            loss = criterion(out, data.y)  # Compute the loss.
            loss.backward()                # Derive gradients.
            optimizer.step()               # Update parameters based on gradients.

    def test(batches):
        """Return the fraction of correctly classified samples in ``batches``."""
        model.eval()
        correct = 0
        with torch.no_grad():  # Evaluation needs no autograd graph.
            for data in batches:
                data = data.to(device)
                out = model(data)
                pred = out.argmax(dim=1)  # Class with the highest score.
                correct += int((pred == data.y).sum())
        return correct / len(batches.dataset)

    # range(1, EPOCHS) would stop one epoch short of the requested count.
    for epoch in range(1, EPOCHS + 1):
        train(loader.get_train_loader())
        train_acc = test(loader.get_train_loader())
        test_acc = test(loader.get_validation_loader())
        print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')
        
def Advanced_train_validate_from_loader(loader_path: str=None, EPOCHS:int=10, device:str='cpu'):
    """Train a fresh ``GIN_classifier`` through ``train_and_validate`` and return its results.

    Args:
        loader_path: File passed to ``torch.load``; the loaded object must provide
            ``get_train_loader()`` and ``get_validation_loader()``. Required, even
            though the signature keeps ``None`` as its default.
        EPOCHS: Forwarded to ``train_and_validate`` as ``EPOCHS``.
        device: Forwarded to ``train_and_validate`` as ``device``.

    Returns:
        The tuple ``(best_val_model, last_best_state_dict_path, last_best_model_path)``
        exactly as produced by ``train_and_validate``.

    Raises:
        ValueError: If ``loader_path`` is ``None``.
    """
    if loader_path is None:
        # Fail early with a clear message instead of letting torch.load(None) blow up.
        raise ValueError("loader_path must point to a saved loader file")

    gc.collect()
    torch.cuda.empty_cache()
    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
    loader = torch.load(loader_path)

    a_batch = next(iter(loader.get_train_loader()))
    a_graph = a_batch[0]
    model = GIN_classifier(0, a_graph.num_nodes, a_graph.num_node_features)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    loss_function = torch.nn.NLLLoss()
    best_val_model, last_best_state_dict_path, last_best_model_path = train_and_validate(
        model, loss_function, optimizer, EPOCHS=EPOCHS,
        train_loader=loader.get_train_loader(),
        validation_loader=loader.get_validation_loader(),
        save_state_dict=False, save_entire_model=False,
        verbose=True,
        saving_path='results', device=device)

    # Nothing is written to disk with both save flags off, so hand the results
    # back to the caller instead of discarding them.
    return best_val_model, last_best_state_dict_path, last_best_model_path
  