In [1]:
import torch
import torch.nn as nn
import numpy as np
import os

from tqdm.notebook import tqdm

from data import TSPDataset, _normalize_np_array
from config import Config, Checkpoint, MetaData, Metrics
from model import ConvolutionalSalesmanNet, construct_path, calc_path_metric, calc_path_cost

In [2]:
# Location of the run configuration; overridable through the CONFIG_PATH
# environment variable, defaulting to ./config.json.
CONFIG_PATH = os.environ.get("CONFIG_PATH", "./config.json")

if not os.path.exists(CONFIG_PATH):
    # First run: start from the default configuration and write it to
    # CONFIG_PATH so later runs take the from_json branch below.
    config = Config()
    config.store_as_json(CONFIG_PATH)
else:
    # A configuration file already exists at CONFIG_PATH: load it.
    config = Config.from_json(CONFIG_PATH)


In [3]:
# Load every example whose problem size is within the training cutoff.
train_dataset = TSPDataset.from_disk(
    config.data_path,
    config.num_path_variations_per_example,
    problem_size_upper_bound=config.train_problem_size_cutoff,
)

# Reuse the validation split recorded in the config when there is one, so that
# the same examples are held out across runs; otherwise create a new split
# and record its UUIDs.
# NOTE(review): the recorded output (218000 loaded -> 217000 train + 1000
# validation) suggests both split methods remove the selected examples from
# train_dataset -- confirm in data.py.
validation_uuids = config.get_validation_uuids()
if validation_uuids is not None:
    validation_dataset = train_dataset.split_by_uuids(validation_uuids)
else:
    validation_dataset = train_dataset.stratified_split(config.validation_tot_size)
    validation_uuids = validation_dataset.get_uuids()
    config.store_validation_uuids(validation_uuids)

train_size = len(train_dataset)
validation_size = len(validation_dataset)
print(f"Training dataset size: {train_size}")
print(f"Validation dataset size: {validation_size}")

Loading data from disk...:   0%|          | 0/218000 [00:00<?, ?it/s]

Training dataset size: 217000
Validation dataset size: 1000


In [4]:

# The training loop calls .squeeze(0) on every item it draws, so batch_size=1
# only adds a leading dimension that is stripped again.
# NOTE(review): presumably each dataset item is already a full batch of path
# variations -- confirm in data.py.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True)

# Validation is not shuffled: the metrics are averages over the whole set, so
# order cannot change them, and a shuffling loader would draw from the global
# torch RNG on every validation pass and thereby perturb training randomness.
validation_loader = torch.utils.data.DataLoader(validation_dataset, batch_size=1, shuffle=False)

In [5]:
# Most recent stored checkpoint, or None when there is none yet (the model
# setup cell creates an empty Checkpoint in that case).
curr_checkpoint = config.get_curr_checkpoint()
# Best path-construction metric recorded so far, or None. The training loop
# replaces it whenever a checkpoint's average metric is strictly lower.
# NOTE(review): "bssf" presumably stands for "best solution so far" -- confirm.
bssf_path_const_metric = config.get_bssf_metric()

In [6]:
# Build the network, the regression loss, the optimizer and the one-cycle
# learning-rate schedule.
model = ConvolutionalSalesmanNet().to(config.device)
loss_fn = nn.MSELoss().to(config.device)
optimizer = torch.optim.AdamW(model.parameters(), lr=config.min_lr)
lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=config.max_lr,
    total_steps=config.tot_train_batches,
)

if curr_checkpoint is not None:
    # Resume: restore model, optimizer and scheduler from the stored checkpoint.
    print("Loading checkpoint state dicts")
    model.load_state_dict(curr_checkpoint.model_state_dict)
    optimizer.load_state_dict(curr_checkpoint.optimizer_state_dict)
    lr_scheduler.load_state_dict(curr_checkpoint.lr_scheduler_state_dict)
else:
    # Fresh run: an empty checkpoint whose three state-dict slots are filled in
    # by the training loop the first time it stores a checkpoint.
    curr_checkpoint = Checkpoint(None, None, None, Metrics(), MetaData())

Loading checkpoint state dicts


In [7]:
# Main training loop.
#
# Trains until curr_checkpoint.metadata.num_batches_trained reaches
# config.tot_train_batches. Every config.batches_per_checkpoint batches it runs
# a full validation pass, records the metrics on curr_checkpoint, stores a new
# best model when the average path-construction metric improves (lower is
# better), and persists the checkpoint. A KeyboardInterrupt stores a checkpoint
# before the cell ends, so the run can be resumed later.

# Number of training batches between refreshes of the progress-bar postfix.
POSTFIX_REFRESH_INTERVAL = 100


def _store_current_checkpoint():
    """Copy the current model/optimizer/scheduler state into curr_checkpoint and persist it."""
    curr_checkpoint.model_state_dict = model.state_dict()
    curr_checkpoint.optimizer_state_dict = optimizer.state_dict()
    curr_checkpoint.lr_scheduler_state_dict = lr_scheduler.state_dict()
    config.store_new_checkpoint(curr_checkpoint)


def _latest(values):
    """Return the last recorded value of a metric list, or None when it is empty."""
    return values[-1] if len(values) > 0 else None


train_iter = iter(train_loader)
master_p_bar = tqdm(range(config.tot_train_batches), desc="Training")
# When resuming, advance the bar past the batches that were already trained.
master_p_bar.update(curr_checkpoint.metadata.num_batches_trained)
model.train()

# Bound up front so the generic exception handler below can report on it even
# when the failure happens before the first batch has been drawn.
batch = None

try:
    # Loss summed over, and number of, the batches trained in this session
    # since the last checkpoint. The count is tracked explicitly because after
    # resuming in the middle of a checkpoint interval it is smaller than
    # config.batches_per_checkpoint.
    curr_checkpoint_tot_loss = 0
    curr_checkpoint_num_batches = 0

    while curr_checkpoint.metadata.num_batches_trained < config.tot_train_batches:
        master_p_bar.set_description("Training")

        # Cycle through the training loader indefinitely.
        try:
            (batch, target) = next(train_iter)
        except StopIteration:
            train_iter = iter(train_loader)
            (batch, target) = next(train_iter)

        optimizer.zero_grad()

        # Drop the leading dimension added by the loader (batch_size=1) and
        # move the data to the configured device.
        batch = batch.squeeze(0).to(config.device)
        target = target.squeeze(0).to(config.device)

        # Predict and calculate loss
        path_predictions: torch.Tensor = model(batch)
        loss: torch.Tensor = loss_fn(path_predictions, target)

        # Adjust model
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # Update progress
        curr_checkpoint.metadata.num_batches_trained += 1
        curr_checkpoint_num_batches += 1
        master_p_bar.update(1)
        curr_checkpoint_tot_loss += loss.item()

        if curr_checkpoint.metadata.num_batches_trained % POSTFIX_REFRESH_INTERVAL == 0:
            # Running mean over exactly the batches accumulated so far.
            master_p_bar.set_postfix(
                train_loss=curr_checkpoint_tot_loss / curr_checkpoint_num_batches,
                vald_loss=_latest(curr_checkpoint.metrics.validation_loss),
                path_construction_metric=_latest(curr_checkpoint.metrics.path_construction_metrics))

        if curr_checkpoint.metadata.num_batches_trained % config.batches_per_checkpoint == 0:
            optimizer.zero_grad()
            master_p_bar.set_description("Validating")

            checkpoint_training_loss = curr_checkpoint_tot_loss / curr_checkpoint_num_batches
            curr_checkpoint_tot_loss = 0
            curr_checkpoint_num_batches = 0

            # Run validation
            model.eval()
            with torch.no_grad():
                tot_validation_loss = 0
                path_construction_metric = []
                for (batch, target) in validation_loader:
                    batch = batch.squeeze(0).to(config.device)
                    target = target.squeeze(0).to(config.device)

                    path_predictions = model(batch)
                    loss = loss_fn(path_predictions, target)

                    tot_validation_loss += loss.item()

                    # Build a path from the first element of the batch and
                    # score it against the matching target.
                    # NOTE(review): batch[0, 0] is presumably the problem's
                    # distance matrix -- confirm in model.py.
                    constructed_path = construct_path(model, batch[0, 0])
                    path_metric = calc_path_metric(constructed_path, target[0], batch[0, 0])
                    path_construction_metric.append(path_metric)

            checkpoint_path_construction_average = np.mean(path_construction_metric)

            # Record all metrics together, only after validation has finished,
            # so an interrupt during validation cannot leave the metric lists
            # with different lengths in the stored checkpoint.
            curr_checkpoint.metrics.training_loss.append(checkpoint_training_loss)
            curr_checkpoint.metrics.validation_loss.append(tot_validation_loss / len(validation_loader))
            curr_checkpoint.metrics.path_construction_metrics.append(checkpoint_path_construction_average)
            curr_checkpoint.metrics.learning_rate.append(lr_scheduler.get_last_lr()[0])

            # Update postfix
            master_p_bar.set_postfix(train_loss=curr_checkpoint.metrics.training_loss[-1],
                                     vald_loss=curr_checkpoint.metrics.validation_loss[-1],
                                     path_construction_metric=curr_checkpoint.metrics.path_construction_metrics[-1])

            # Keep the best model seen so far (lower metric wins).
            if bssf_path_const_metric is None or checkpoint_path_construction_average < bssf_path_const_metric:
                bssf_path_const_metric = checkpoint_path_construction_average
                config.store_new_bssf(model, curr_checkpoint.metrics)

            _store_current_checkpoint()
            model.train()

except KeyboardInterrupt:
    print("Interrupted storing checkpoint gracefully...")
    _store_current_checkpoint()

except Exception as e:
    # Report the shape of the batch being processed, if one had been drawn.
    if batch is not None:
        print(batch.shape)
    print(e)

    # Dang memory leaks: drop this cell's reference to the model. The model
    # setup cell must be re-run before training again.
    del model

    # Bare raise re-raises the active exception with its original traceback.
    raise




Training:   0%|          | 0/600000 [00:00<?, ?it/s]

Interrupted storing checkpoint gracefully...


In [None]:
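# TODO: build the test dataset. Stale draft -- DATA_PATH and TEST_PROBLEM_SIZE_CUTOFF
# are not defined in the cells shown (the rest of the notebook reads such values from `config`):
# test_dataset = TSPDataset(DATA_PATH, problem_size_lower_bound=TEST_PROBLEM_SIZE_CUTOFF)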