In [2]:
import os
from tqdm import tqdm
from matplotlib import pyplot as plt
from torch.utils.tensorboard.writer import SummaryWriter
# from torch.profiler import profile, record_function, ProfilerActivity
import torch
from model import GVAE
from loss import reconstruction_loss, kl_loss
from dataset import SketchDataset
from torch.utils.data import DataLoader, Subset, random_split
# Import sketchgraphs with the working directory temporarily switched to the
# SketchGraphs/ folder, then switch back to the parent directory.
# NOTE(review): presumably done so the package resolves from the local
# checkout; the chdir/import/chdir order matters.
os.chdir('SketchGraphs/')
import sketchgraphs.data as datalib
os.chdir('../')
# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
class MultiGPUTrainer:
    """Train and validate a model on one device, logging to TensorBoard.

    The model is called as ``model(nodes, edges)`` and must return the tuple
    ``(pred_nodes, pred_edges, means, logvars)``. Every batch produced by the
    datasets must unpack into ``(nodes, edges, node_params_mask)``.

    TensorBoard logging, checkpointing and the preview figure are only
    performed when ``gpu_id == 0``.
    """

    def __init__(
            self,
            model: torch.nn.Module,
            train_set: Subset,
            validate_set: Subset,
            optimizer: torch.optim.Optimizer,
            scheduler: torch.optim.lr_scheduler.LRScheduler,
            gpu_id: int,
            num_epochs: int,
            experiment_string: str,
            batch_size: int,
            num_workers: int = 0
            ):
        """Move the model to ``gpu_id`` and build the data loaders.

        Args:
            model: Network to train; it is moved to ``gpu_id`` in place.
            train_set: Dataset iterated (shuffled) during training.
            validate_set: Dataset iterated (shuffled) during validation.
            optimizer: Optimizer already bound to ``model``'s parameters.
            scheduler: Scheduler stepped with a loss value, i.e. it must
                accept ``scheduler.step(metric)``.
            gpu_id: Device index the model and batches are moved to.
            num_epochs: Number of epochs run by :meth:`train`.
            experiment_string: Sub-directory of ``runs/`` used for the
                TensorBoard event files.
            batch_size: Batch size of both data loaders.
            num_workers: Worker processes per data loader. Defaults to 0
                (load in the main process), which is the DataLoader default.
        """
        model.device = gpu_id
        self.model = model.to(gpu_id)
        self.train_loader = DataLoader(dataset = train_set,
                                       batch_size = batch_size,
                                       shuffle = True,
                                       num_workers = num_workers
                                      )
        self.validate_loader = DataLoader(dataset = validate_set,
                                          batch_size = batch_size,
                                          shuffle = True,
                                          num_workers = num_workers
                                         )
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.gpu_id = gpu_id
        self.writer = SummaryWriter(f'runs/{experiment_string}')
        self.num_epochs = num_epochs

        self.global_step = 0
        self.curr_epoch = 0
        self.min_validation_loss = float('inf')


    def train_batch(self, nodes : torch.Tensor, edges : torch.Tensor, node_params_mask : torch.Tensor) -> float:
        """Run one optimisation step and return the batch loss as a float.

        Raises:
            AssertionError: If any model output or the loss is non-finite.
        """
        self.optimizer.zero_grad()

        nodes = nodes.to(self.gpu_id)
        edges = edges.to(self.gpu_id)
        node_params_mask = node_params_mask.to(self.gpu_id)

        pred_nodes, pred_edges, means, logvars = self.model(nodes, edges)

        assert pred_nodes.isfinite().all(), "Model output for nodes has non finite values"
        assert pred_edges.isfinite().all(), "Model output for edges has non finite values"
        assert means.isfinite().all(),      "Model output for means has non finite values"
        assert logvars.isfinite().all(),    "Model output for logvars has non finite values"

        # Only the reconstruction term is optimised; the KL term is disabled.
        loss = reconstruction_loss(pred_nodes, pred_edges, nodes, edges, node_params_mask)
        # loss += 0.1*kl_loss(means, logvars)

        assert loss.isfinite().all(), "Loss is non finite value"

        loss.backward()
        self.optimizer.step()

        return loss.item()

    def train_epoch(self):
        """Iterate once over the training loader.

        Every tenth step the latest batch loss is logged (device 0 only) and
        passed to the scheduler.
        """
        pbar = tqdm(self.train_loader)
        for nodes, edges, node_params_mask in pbar:
            iter_loss = self.train_batch(nodes, edges, node_params_mask)

            self.global_step += 1

            if (self.global_step % 10 == 9):
                if self.gpu_id == 0: self.writer.add_scalar("Training Loss", iter_loss, self.global_step)
                self.scheduler.step(iter_loss)

            pbar.set_description(f"Training Epoch {self.curr_epoch} Iter Loss: {iter_loss}  ")

    @torch.no_grad()
    def validate(self):
        """Evaluate on the validation loader and record the results.

        The mean batch loss is logged every epoch (device 0 only). When it
        improves on the best value seen so far, the best value is updated and
        device 0 saves a checkpoint. Device 0 also logs a figure comparing up
        to four targets with their predictions, taken from the last batch.

        Raises:
            AssertionError: If a batch loss is non-finite.
        """
        pbar = tqdm(self.validate_loader)
        total_loss = 0.0
        num_batches = 0
        for nodes, edges, node_params_mask in pbar:
            nodes = nodes.to(self.gpu_id)
            edges = edges.to(self.gpu_id)
            node_params_mask = node_params_mask.to(self.gpu_id)

            pred_nodes, pred_edges, means, logvars = self.model(nodes, edges)

            loss = reconstruction_loss(pred_nodes, pred_edges, nodes, edges, node_params_mask)
            # loss += 0.1*kl_loss(means, logvars)

            assert loss.isfinite().all(), "Loss is non finite value"

            # Accumulate a Python float so no tensors are kept alive and the
            # best-loss bookkeeping below stays a plain number.
            total_loss += loss.item()
            num_batches += 1

            pbar.set_description(f"Validating Epoch {self.curr_epoch}  ")

        if num_batches == 0:
            # Nothing to average or visualise.
            return

        avg_loss = total_loss / num_batches

        # Log every epoch so the validation curve has no gaps on epochs that
        # did not improve.
        if self.gpu_id == 0:
            self.writer.add_scalar("Validation Loss", avg_loss, self.curr_epoch)

        if avg_loss < self.min_validation_loss:
            self.min_validation_loss = avg_loss
            if self.gpu_id == 0:
                self.save_checkpoint()

        if self.gpu_id == 0:
            fig, axes = plt.subplots(nrows = 4, ncols = 2, figsize=(8, 16))
            fig.suptitle(f"Target (left) vs Preds (right) for epoch {self.curr_epoch}")
            # The last batch may hold fewer than four samples.
            for i in range(min(4, nodes.shape[0])):
                target_sketch = SketchDataset.preds_to_sketch(nodes[i].cpu(), edges[i].cpu())
                pred_sketch = SketchDataset.preds_to_sketch(pred_nodes[i].cpu(), pred_edges[i].cpu())

                datalib.render_sketch(target_sketch, axes[i, 0])
                datalib.render_sketch(pred_sketch, axes[i, 1])

            self.writer.add_figure(f"Epoch result visualization", fig, self.curr_epoch)
            # Close this specific figure rather than whichever is "current".
            plt.close(fig)

    def train(self):
        """Reset the counters and run ``num_epochs`` train/validate cycles."""
        self.global_step = 0
        self.curr_epoch = 0

        while (self.curr_epoch < self.num_epochs):
            self.model.train()
            self.train_epoch()
            self.model.eval()
            self.validate()
            self.curr_epoch += 1

    def save_checkpoint(self):
        """Write the model's state dict to ``best_model_checkpoint.pth``."""
        checkpoint = self.model.state_dict()
        torch.save(checkpoint, "best_model_checkpoint.pth")


In [3]:
# Build the full sketch dataset; "data/" is passed as the dataset root.
dataset = SketchDataset(root="data/")

In [4]:
# Split 90% / 3.3% / 6.7% into train / validation / test subsets.
# A seeded generator keeps the partition identical across kernel restarts, so
# a checkpoint trained in one session is never evaluated on samples that were
# training data in another.
split_generator = torch.Generator().manual_seed(0)
train_set, validate_set, test_set = random_split(
    dataset = dataset,
    lengths = [0.9, 0.033, 0.067],
    generator = split_generator,
)

print("Number of Graphs in total: ", len(dataset))
print("Number of Graphs for training: ", len(train_set))
print("Number of Graphs for validation: ", len(validate_set))
print("Number of Graphs for testing: ", len(test_set))

Number of Graphs in total:  3981513
Number of Graphs for training:  3583362
Number of Graphs for validation:  131390
Number of Graphs for testing:  266761


In [5]:
# Run configuration read by the cells below.
batch_size = 768
# NOTE(review): 1e-10 is an extremely small SGD step size; the later
# configuration cell uses 5e-4 — confirm this value is intentional.
learning_rate = 1e-10
num_epochs = 50
# Name of the TensorBoard run directory under runs/.
experiment_string = "gvae_experiment_ddp_2"

In [6]:
# Fresh model; restoring from an existing checkpoint is currently disabled.
model = GVAE(device)
# if os.path.exists(f"best_model_checkpoint.pth"):
   #  model.load_state_dict(torch.load(f"best_model_checkpoint.pth"))

# Plain SGD. ReduceLROnPlateau in 'min' mode lowers the learning rate once the
# metric passed to scheduler.step() has stopped decreasing for more than
# `patience` consecutive calls.
optimizer = torch.optim.SGD(model.parameters(), lr = learning_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience = 5)

In [7]:
# Single-device run: gpu_id = 0 is the index for which the trainer performs
# TensorBoard logging, checkpointing and the preview figure.
trainer = MultiGPUTrainer(
    model = model,
    train_set = train_set,
    validate_set = validate_set,
    optimizer = optimizer,
    scheduler = scheduler,
    gpu_id = 0,
    num_epochs = num_epochs,
    experiment_string = experiment_string,
    batch_size = batch_size
)

In [None]:
# Run the full train/validate schedule (num_epochs epochs).
trainer.train()

In [None]:
def train_on_multiple_gpus(rank: int, world_size: int):
    """Per-process entry point: set up the process group, train, tear down.

    Reads ``device``, ``learning_rate``, ``train_set``, ``validate_set``,
    ``num_epochs``, ``experiment_string`` and ``batch_size`` from the
    notebook's global scope.

    Args:
        rank: Index of this process; also used as the trainer's ``gpu_id``.
        world_size: Total number of processes taking part.
    """
    # Imported here so the function does not depend on a later notebook cell
    # having been executed first.
    from torch.distributed import destroy_process_group

    # NOTE(review): `ddp_setup` and the `num_workers` keyword used below are
    # not provided by the MultiGPUTrainer class defined in this notebook;
    # presumably they come from the `distributed_trainer` module — confirm.
    MultiGPUTrainer.ddp_setup(rank, world_size)

    try:
        model = GVAE(device)
        optimizer = torch.optim.SGD(model.parameters(), lr = learning_rate)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode = "min", patience = 5)

        trainer = MultiGPUTrainer(
            model = model,
            train_set = train_set,
            validate_set = validate_set,
            optimizer = optimizer,
            scheduler = scheduler,
            gpu_id = rank,
            num_epochs = num_epochs,
            experiment_string = experiment_string,
            batch_size = batch_size,
            num_workers = 32
        )

        trainer.train()
    finally:
        # Always release the process group, even when training raises, so the
        # other ranks are not left waiting on a dead peer.
        destroy_process_group()

# Working Code

In [1]:
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group

In [2]:
import torch
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader, Subset, random_split, TensorDataset
from dataset import SketchDataset
from model import GVAE
from loss import reconstruction_loss, kl_loss
from distributed_trainer import MultiGPUTrainer, train_on_multiple_gpus
# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Re-wrap the node and edge tensors of the sketch dataset as a TensorDataset,
# so every sample is a (nodes, edges) pair.
dataset = SketchDataset(root="data/")
dataset = TensorDataset(dataset.nodes, dataset.edges)

# Split 90% / 3.3% / 6.7% into train / validation / test subsets.
# A seeded generator keeps the partition identical across kernel restarts, so
# a saved checkpoint is never evaluated on samples it was trained on.
split_generator = torch.Generator().manual_seed(0)
train_set, validate_set, test_set = random_split(
    dataset = dataset,
    lengths = [0.9, 0.033, 0.067],
    generator = split_generator,
)

print("Number of Graphs in total: ", len(dataset))
print("Number of Graphs for training: ", len(train_set))
print("Number of Graphs for validation: ", len(validate_set))
print("Number of Graphs for testing: ", len(test_set))

# Run configuration passed to the spawned training processes.
batch_size = 400
learning_rate = 5e-4
num_epochs = 25
experiment_string = "gvae_ddp_embedsize_768_num_head_8_num_tflayers_4_no_batchlayernorm"

Number of Graphs in total:  3981513
Number of Graphs for training:  3583362
Number of Graphs for validation:  131390
Number of Graphs for testing:  266761


In [3]:
# Start one training process per visible CUDA device. mp.spawn supplies the
# process index as the first positional argument of the target function; the
# tuple below provides the remaining ones.
world_size = torch.cuda.device_count()
mp.spawn(
    fn = train_on_multiple_gpus,
    nprocs = world_size,
    args = (
        world_size,
        train_set,
        validate_set,
        learning_rate,
        batch_size,
        num_epochs,
        experiment_string,
    ),
)

Training Epoch 0 Iter Loss: 1.8633977174758911  : 100%|██████████| 1120/1120 [15:34<00:00,  1.20it/s]]
Training Epoch 0 Iter Loss: 1.8966702222824097  : 100%|██████████| 1120/1120 [15:34<00:00,  1.20it/s]
Training Epoch 0 Iter Loss: 1.8579736948013306  : 100%|██████████| 1120/1120 [15:34<00:00,  1.20it/s]

Training Epoch 0 Iter Loss: 1.8943809270858765  : 100%|██████████| 1120/1120 [15:34<00:00,  1.20it/s]
Training Epoch 0 Iter Loss: 1.9076563119888306  : 100%|██████████| 1120/1120 [15:34<00:00,  1.20it/s]
Training Epoch 0 Iter Loss: 1.9261585474014282  : 100%|██████████| 1120/1120 [15:34<00:00,  1.20it/s]
Training Epoch 0 Iter Loss: 1.8773307800292969  : 100%|██████████| 1120/1120 [15:34<00:00,  1.20it/s]
Validating Epoch 0  : 100%|██████████| 42/42 [00:13<00:00,  3.23it/s]
Validating Epoch 0  : 100%|██████████| 42/42 [00:12<00:00,  3.23it/s]
Validating Epoch 0  : 100%|██████████| 42/42 [00:13<00:00,  3.18it/s]
Validating Epoch 0  : 100%|██████████| 42/42 [00:13<00:00,  3.16it/s]
Vali

In [None]:
from torch.utils.tensorboard.writer import SummaryWriter

# One shuffling DataLoader per split, all using the configured batch size.
train_loader = DataLoader(dataset = train_set, batch_size = batch_size, shuffle = True)
validate_loader = DataLoader(dataset = validate_set, batch_size = batch_size, shuffle = True)
test_loader = DataLoader(dataset = test_set, batch_size = batch_size, shuffle = True)
#torch.multiprocessing.set_sharing_strategy('file_system')



# validate_batches = [batch for batch in tqdm(DataLoader(dataset = validate_set, batch_size = batch_size, shuffle = True, persistent_workers = True, num_workers = 32))]
# test_batches = [batch for batch in tqdm(DataLoader(dataset = test_set, batch_size = batch_size, shuffle = True, persistent_workers = True, num_workers = 32))]

# Fresh model; restoring from an existing checkpoint is currently disabled.
model = GVAE(device)
# if os.path.exists(f"best_model_checkpoint.pth"):
   #  model.load_state_dict(torch.load(f"best_model_checkpoint.pth"))

# Plain SGD. ReduceLROnPlateau in 'min' mode lowers the learning rate once the
# metric passed to scheduler.step() has stopped decreasing for more than
# `patience` consecutive calls.
optimizer = torch.optim.SGD(model.parameters(), lr = learning_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience = 5)

# TensorBoard event files for this run go to runs/gvae_experiment_8.
writer = SummaryWriter('runs/gvae_experiment_8')

In [5]:
# Show which device was selected (CUDA device 0 or CPU).
print(device)

cuda:0


In [None]:
# Standalone training loop: for every epoch, train over train_loader, then
# evaluate on validate_loader, log to TensorBoard and checkpoint the model
# whenever the average validation loss improves.
# Batches are unpacked as (nodes, edges) pairs.
# NOTE(review): reconstruction_loss is called with four arguments here, while
# the definitions further down in this notebook also take node_params_mask —
# confirm which version of the loss this loop is meant to run against.
best_validation_loss = float('inf')
global_step = 0
log_interval = 100                          # mini-batches per training-loss log entry
cuda_available = torch.cuda.is_available()  # synchronize() raises without CUDA

for epoch in range(num_epochs):
    # ---- Train for one epoch ----
    model.train()
    total_train_loss = 0.0
    num_train_batches = 0
    pbar = tqdm(train_loader)
    for target_nodes, target_edges in pbar:
        optimizer.zero_grad()

        target_nodes = target_nodes.to(device)
        target_edges = target_edges.to(device)

        pred_nodes, pred_edges, means, logvars = model(target_nodes, target_edges)

        # Fail fast as soon as the model diverges.
        if (not pred_nodes.isfinite().all()):
            raise ValueError("pred nodes is not finite")
        if (not pred_edges.isfinite().all()):
            raise ValueError("pred edges is not finite")
        if (not means.isfinite().all()):
            raise ValueError("means is not finite")
        if (not logvars.isfinite().all()):
            raise ValueError("logvars is not finite")

        loss = reconstruction_loss(pred_nodes, pred_edges, target_nodes, target_edges)
        loss += 0.1*kl_loss(means, logvars)

        if (not loss.isfinite().all()):
            raise ValueError("Loss is not finite")

        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        pbar.set_description(f"Epoch {epoch} --Training-- Iter loss: {loss_value} -")

        total_train_loss += loss_value
        num_train_batches += 1
        # Log the running average and step the scheduler once per full window
        # of `log_interval` mini-batches, then start a new window.
        if num_train_batches == log_interval:
            avg_train_loss = total_train_loss / num_train_batches
            writer.add_scalar("Training Loss", avg_train_loss, global_step)

            scheduler.step(avg_train_loss)

            total_train_loss = 0.0
            num_train_batches = 0

        global_step += 1

    # Drop the references to the last training batch before validating.
    target_nodes = target_edges = node_params_mask = None
    pred_nodes = pred_edges = means = logvars = None
    if cuda_available:
        torch.cuda.synchronize()

    # ---- Validate at the end of the epoch ----
    model.eval()
    with torch.no_grad():
        total_validate_loss = 0.0
        num_validate_batches = 0
        for nodes, edges in tqdm(validate_loader):
            target_nodes = nodes.to(device)
            target_edges = edges.to(device)

            pred_nodes, pred_edges, means, logvars = model(target_nodes, target_edges)

            loss = reconstruction_loss(pred_nodes, pred_edges, target_nodes, target_edges)
            loss += 0.1*kl_loss(means, logvars)

            if (not loss.isfinite().all()):
                raise ValueError("Loss is not finite")

            # Log up to four target/prediction pairs from the first batch.
            if num_validate_batches == 0:
                fig, axes = plt.subplots(nrows = 4, ncols = 2, figsize=(8, 16))
                fig.suptitle(f"Target (left) vs Preds (right) for epoch {epoch}")
                # The batch may hold fewer than four samples.
                for i in range(min(4, target_nodes.shape[0])):
                    target_sketch = SketchDataset.preds_to_sketch(target_nodes[i].cpu(), target_edges[i].cpu())
                    pred_sketch = SketchDataset.preds_to_sketch(pred_nodes[i].detach().cpu(), pred_edges[i].detach().cpu())

                    datalib.render_sketch(target_sketch, axes[i, 0])
                    datalib.render_sketch(pred_sketch, axes[i, 1])

                writer.add_figure(f"Epoch result visualization", fig, epoch)
                # Close this specific figure rather than whichever is "current".
                plt.close(fig)

            total_validate_loss += loss.item()
            num_validate_batches += 1

        # An empty validation loader gives nothing to average or compare.
        if num_validate_batches > 0:
            avg_validate_loss = total_validate_loss / num_validate_batches
            writer.add_scalar("Validation Loss", avg_validate_loss, epoch)

            # Keep only the best-performing weights on disk.
            if avg_validate_loss < best_validation_loss:
                best_validation_loss = avg_validate_loss
                checkpoint_path = f"best_model_checkpoint.pth"
                torch.save(model.state_dict(), checkpoint_path)
                print(f"Saved model checkpoint with validation loss: {best_validation_loss:.4f} to {checkpoint_path}")

    # Drop the references to the last validation batch before the next epoch.
    target_nodes = target_edges = node_params_mask = None
    pred_nodes = pred_edges = means = logvars = None
    if cuda_available:
        torch.cuda.synchronize()


In [None]:

# Test Model
# Average the reconstruction + KL loss over the test loader and log it.
# NOTE(review): this cell unpacks (nodes, edges, node_params_mask) triples,
# whereas the training loop above unpacks (nodes, edges) pairs — confirm which
# dataset variant test_loader was built from before running.
print(f"Testing model ---")
model.eval()
total_test_loss = 0.0
num_test_batches = 0
# Evaluation needs no gradients; disabling them avoids building autograd
# graphs for every test batch.
with torch.no_grad():
    for nodes, edges, node_params_mask in test_loader:
        target_nodes = nodes.to(device)
        target_edges = edges.to(device)
        node_params_mask = node_params_mask.to(device)

        pred_nodes, pred_edges, means, logvars = model(target_nodes, target_edges)

        loss = reconstruction_loss(pred_nodes, pred_edges, target_nodes, target_edges, node_params_mask) + kl_loss(means, logvars)

        total_test_loss += loss.item()
        num_test_batches += 1

# Skip logging instead of dividing by zero when the loader yielded no batches.
if num_test_batches > 0:
    writer.add_scalar("Test Loss", total_test_loss / num_test_batches)


In [31]:
import torch
import torch.nn.functional as F
from torch import Tensor
from config import NUM_PRIMITIVE_TYPES, NUM_CONSTRAINT_TYPES

def reconstruction_loss(pred_nodes : Tensor, pred_edges : Tensor, target_nodes : Tensor, target_edges : Tensor, node_params_mask : Tensor):
    '''Combine the node and edge reconstruction terms into one scalar.

    Node channels: 0 is scored with binary cross entropy, 1:6 is a
    5-way primitive type (targets one-hot), 6: holds the primitive
    parameters. Edge channels: 0:4 and 4:8 are two 4-way sub-node choices,
    8: is the constraint type (targets one-hot).

    The class scores are fed to nll_loss unchanged, so the predictions are
    expected to be in the form nll_loss consumes.

    Returns:
        node_loss + 0.3 * edge_loss, where
        node_loss = bce + 4 * type_nll + 8 * param_mse.
    '''
    def edge_nll(channels : slice, reduction : str) -> Tensor:
        # NLL of one categorical slice of the edge features; nll_loss wants
        # the class axis in position 1.
        labels = target_edges[:, :, :, channels].argmax(dim = 3)
        scores = pred_edges[:, :, :, channels].permute(0, 3, 1, 2).contiguous()
        return F.nll_loss(input = scores, target = labels, reduction = reduction)

    # ----- nodes -----
    flag_bce = F.binary_cross_entropy(
        input = pred_nodes[:, :, 0],
        target = target_nodes[:, :, 0],
        reduction = 'sum')

    # Circles, arcs and points are weighted higher because they are much
    # rarer than the line and none types.
    class_weights = torch.tensor([1.0, 4.0, 4.0, 3.0, 1.0]).to(pred_nodes.device)
    type_labels = target_nodes[:, :, 1:6].argmax(dim = 2)                # batch_size x num_nodes
    type_scores = pred_nodes[:, :, 1:6].permute(0, 2, 1).contiguous()    # batch_size x num_primitive_types x num_nodes
    type_nll = F.nll_loss(
        input = type_scores,
        target = type_labels,
        weight = class_weights,
        reduction = 'sum')

    # The mask zeroes predicted parameters that are irrelevant for a node.
    param_mse = F.mse_loss(
        input = pred_nodes[:, :, 6:] * node_params_mask,
        target = target_nodes[:, :, 6:],
        reduction = 'sum')

    node_loss = flag_bce + 4 * type_nll + 8 * param_mse

    # ----- edges -----
    # The two sub-node terms are summed, the constraint term is averaged.
    sub_a_nll = edge_nll(slice(0, 4), 'sum')
    sub_b_nll = edge_nll(slice(4, 8), 'sum')
    constraint_nll = edge_nll(slice(8, None), 'mean')
    edge_loss = sub_a_nll + sub_b_nll + constraint_nll

    return node_loss + 0.3 * edge_loss

def kl_loss(means : Tensor, logvars : Tensor):
    '''Summed KL term -0.5 * sum(1 + logvar - mean^2 - exp(logvar)), clamped.

    Log-variances are capped at 20 before use and the final value is capped
    at 1000, so the result never exceeds 1000.
    '''
    LOGVAR_CEILING = 20
    KLD_CEILING = 1000

    capped_logvars = logvars.clamp(max = LOGVAR_CEILING)
    per_element = 1 + capped_logvars - means * means - capped_logvars.exp()
    return (-0.5 * per_element.sum()).clamp(max = KLD_CEILING)

In [8]:
import torch
import torch.nn.functional as F
from torch import Tensor
from config import NUM_PRIMITIVE_TYPES, NUM_CONSTRAINT_TYPES

def reconstruction_loss(pred_nodes : Tensor, pred_edges : Tensor, target_nodes : Tensor, target_edges : Tensor, node_params_mask : Tensor):
    '''Combine the node and edge reconstruction terms into one scalar.

    Predicted class channels are passed through log() before nll_loss, i.e.
    they are treated as probabilities. Node channels: 0 is scored with binary
    cross entropy, 1:6 is a 5-way primitive type, 6: holds the primitive
    parameters. Edge channels: 0:4 and 4:8 are two 4-way sub-node choices,
    8: is a 9-way constraint type. Targets are one-hot per categorical slice.

    Nodes whose target type is index 4 (none) are left out of the BCE term;
    edges whose target constraint is index 8 (none) are left out of both
    sub-node terms.

    Returns:
        node_loss + 0.3 * edge_loss, where
        node_loss = bce + 4 * type_nll + 8 * param_mse.
    '''
    NONE_PRIMITIVE = 4
    NONE_CONSTRAINT = 8

    # ----- nodes -----
    # Circles, arcs and points are weighted up, the none type is weighted down.
    class_weights = torch.tensor([1.0, 4.0, 4.0, 3.0, 0.1]).to(pred_nodes.device)
    type_labels = target_nodes[:, :, 1:6].argmax(dim = 2)               # batch_size x num_nodes
    type_probs = pred_nodes[:, :, 1:6].permute(0, 2, 1).contiguous()    # batch_size x num_primitive_types x num_nodes
    type_nll = F.nll_loss(
        input = type_probs.log(),
        target = type_labels,
        weight = class_weights,
        reduction = 'sum')

    # Channel 0 only contributes for nodes that are real primitives.
    is_real_node = type_labels != NONE_PRIMITIVE
    flag_bce = F.binary_cross_entropy(
        input = pred_nodes[is_real_node][:, 0],
        target = target_nodes[is_real_node][:, 0],
        reduction = 'sum')

    # The mask zeroes predicted parameters that are irrelevant for a node.
    param_mse = F.mse_loss(
        input = pred_nodes[:, :, 6:] * node_params_mask,
        target = target_nodes[:, :, 6:],
        reduction = 'sum')

    node_loss = flag_bce + 4 * type_nll + 8 * param_mse

    # ----- edges -----
    constraint_labels = target_edges[:, :, :, 8:].argmax(dim = 3)
    constraint_probs = pred_edges[:, :, :, 8:].permute(0, 3, 1, 2).contiguous()
    # None constraints vastly outnumber the rest, so they get a small weight.
    constraint_weights = torch.tensor([1.0] * 8 + [0.05]).to(pred_edges.device)
    constraint_nll = F.nll_loss(
        input = constraint_probs.log(),
        target = constraint_labels,
        weight = constraint_weights,
        reduction = 'sum')

    has_constraint = constraint_labels != NONE_CONSTRAINT

    def subnode_nll(channels : slice) -> Tensor:
        # NLL of one sub-node slice, restricted to edges with a real constraint.
        labels = target_edges[:, :, :, channels].argmax(dim = 3)[has_constraint]
        probs = pred_edges[:, :, :, channels][has_constraint]
        return F.nll_loss(input = probs.log(), target = labels, reduction = 'sum')

    sub_a_nll = subnode_nll(slice(0, 4))
    sub_b_nll = subnode_nll(slice(4, 8))
    edge_loss = sub_a_nll + sub_b_nll + constraint_nll

    return node_loss + 0.3 * edge_loss

def kl_loss(means : Tensor, logvars : Tensor):
    '''Summed KL term -0.5 * sum(1 + logvar - mean^2 - exp(logvar)).

    No clamping is applied to the log-variances or to the result.
    '''
    per_element = 1 + logvars - means * means - logvars.exp()
    return -0.5 * per_element.sum()

In [5]:
# Manual iterator over the training loader, used to pull single batches for
# the loss sanity check below.
train_ldr = iter(train_loader)

In [6]:
# Take one batch and make a copy of the nodes in which a single entry
# (sample 0, node 0, channel 0) is overwritten with 1.
nodes, edges, node_params_mask = next(train_ldr)
nodes2 = nodes.clone()
nodes2[0,0,0] = 1

In [9]:
# Sanity check: use the targets themselves as predictions, except for the one
# entry altered in nodes2.
# NOTE(review): the recorded result of 100 is presumably that single entry's
# BCE term, since PyTorch clamps the log inside binary cross entropy at -100 —
# confirm.
reconstruction_loss(nodes2, edges, nodes, edges, node_params_mask)

tensor(100.)