### Imports

In [1]:
import os
import warnings
import json
from architectures import SimpleCNN
from torch.utils.tensorboard import SummaryWriter
from Img_Dataset import Imgdataset
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from matplotlib import pyplot as plt

### Utils

In [2]:
def plot(inputs, targets, predictions, path, update):
    """Plotting the inputs, targets and predictions to file ``path``."""
    os.makedirs(path, exist_ok=True)
    fig, axes = plt.subplots(ncols=3, figsize=(15, 5))
    
    for i in range(len(inputs)):
        for ax, data, title in zip(axes, [inputs, targets, predictions], ["Input", "Target", "Prediction"]):
            ax.clear()
            ax.set_title(title)
            ax.imshow(data[i, 0], cmap="gray", interpolation="none")
            ax.set_axis_off()
        fig.savefig(os.path.join(path, f"{update:07d}_{i:02d}.png"), dpi=100)
    
    plt.close(fig)

### Dataset

In [3]:
ds = Imgdataset("C:\\Users\\Bashar Hanna\\Desktop\\py2proj\\training\\000")
training_set = torch.utils.data.Subset(
    ds,
        indices=np.arange(int(len(ds) * (3 / 5)))
    )
validation_set = torch.utils.data.Subset(
        ds,
        indices=np.arange(int(len(ds) * (3 / 5)), int(len(ds) * (4 / 5)))
    )
test_set = torch.utils.data.Subset(
        ds,
        indices=np.arange(int(len(ds) * (4 / 5)), len(ds))
    )

### Collate_fn

In [4]:
import torch


def collate_fn(batch):
    max_X = max([item[0].shape[1] for item in batch])
    max_Y = max([item[0].shape[2] for item in batch])
    n_feature_channels = batch[0][0].shape[0]
    n_samples = len(batch)
    #print(max([item[3].shape[2] for item in batch]))
    stacked_pixelated_images = torch.zeros(
        (n_samples, n_feature_channels, max_X, max_Y), dtype=torch.float32)
    stacked_known_arrays = torch.zeros(
        (n_samples, n_feature_channels, max_X, max_Y), dtype=torch.bool)
    stacked_target_arrays = torch.zeros(
        n_samples, n_feature_channels, max_X, max_Y, dtype=torch.float32)

    for i, (pixelated_image, known_array, target_array, target_array_with_padding) in enumerate(batch):
        stacked_pixelated_images[i, :, :pixelated_image.shape[1],
                                 :pixelated_image.shape[2]] = pixelated_image
        stacked_known_arrays[i, :, :known_array.shape[1],
                             :known_array.shape[2]] = known_array
        stacked_target_arrays[i, :, :target_array_with_padding.shape[1],
                              :target_array_with_padding.shape[2]] = target_array_with_padding

    return stacked_pixelated_images, stacked_known_arrays, stacked_target_arrays


### Dataloader

In [5]:
train_loader = torch.utils.data.DataLoader(training_set, batch_size=2, shuffle=False, num_workers=0, collate_fn=collate_fn)
val_loader = torch.utils.data.DataLoader(validation_set, batch_size=2, shuffle=False, num_workers=0, collate_fn=collate_fn)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=2, shuffle=False, num_workers=0, collate_fn=collate_fn)

### CNN

In [6]:
class MyCNN(torch.nn.Module):
    def __init__(self, n_in_channels: int = 1, n_hidden_layers: int = 3, n_kernels: int = 32, kernel_size: int = 7) -> None:
        super().__init__()
        cnn = []
        for i in range(n_hidden_layers):
            cnn.append(torch.nn.Conv2d(
                in_channels=n_in_channels,
                out_channels=n_kernels,
                kernel_size=kernel_size,
                padding=kernel_size // 2
            ))
            cnn.append(torch.nn.ReLU())
            n_in_channels = n_kernels
        self.hidden_layers = torch.nn.Sequential(*cnn)

        self.output_layer = torch.nn.Conv2d(
            in_channels=n_in_channels,
            out_channels=1,
            kernel_size=kernel_size,
            padding=kernel_size // 2)

    def forward(self, x):
        """Apply CNN to input ``x`` of shape ``(N, n_channels, X, Y)``, where
        ``N=n_samples`` and ``X``, ``Y`` are spatial dimensions."""
        # Apply hidden layers: (N, n_in_channels, X, Y) -> (N, n_kernels, X, Y)
        cnn_out = self.hidden_layers(x)
        # Apply output layer: (N, n_kernels, X, Y) -> (N, 1, X, Y)
        predictions = self.output_layer(cnn_out)
        return predictions


### Evaluate Model

In [7]:
def evaluate_model(model: torch.nn.Module, loader: torch.utils.data.DataLoader, loss_fn, device: torch.device):
    """Function for evaluation of a model ``model`` on the data in
    ``loader`` on device ``device``, using the specified ``loss_fn`` loss
    function."""
    model.eval()
    # We will accumulate the mean loss
    loss = 0
    with torch.no_grad():  # We do not need gradients for evaluation
        # Loop over all samples in the specified data loader
        for pixelated_images, known_arrays, target_arrays in tqdm(loader, desc="Evaluating", position=0, leave=False):
            # Get a sample and move inputs and targets to device
            inputs = torch.cat((pixelated_images, known_arrays),1).to(device)
            targets = target_arrays.to(device)
            
            # Get outputs of the specified model
            outputs = model(inputs)
            outputs = outputs*((~known_arrays)*1)
            outputs = torch.clamp(outputs, min=0, max=255)
            # Here, we could clamp the outputs to the minimum and maximum values
            # of the inputs for better performance
            
            # Add the current loss, which is the mean loss over all minibatch
            # samples (unless explicitly otherwise specified when creating the
            # loss function!)
            loss += loss_fn(outputs, targets).item()
    # Get final mean loss by dividing by the number of minibatch iterations
    # (which we summed up in the above loop)
    loss /= len(loader)
    model.train()
    return loss

## Main

In [8]:
def main(
        results_path,
        network_config: dict,
        learning_rate: float = 1e-3,
        weight_decay: float = 1e-5,
        n_updates: int = 50_000,
        device: str = "cuda"
):
    device = torch.device(device)
    if "cuda" in device.type and not torch.cuda.is_available():
        warnings.warn("CUDA not available, falling back to CPU")
        device = torch.device("cpu")
    np.random.seed(4)
    torch.manual_seed(4)

    # Prepare a path to plot to
    plot_path = os.path.join(results_path, "plots")
    os.makedirs(plot_path, exist_ok=True)
    writer = SummaryWriter(log_dir=os.path.join(results_path, "tensorboard"))
    net = SimpleCNN(**network_config)
    net.to(device)
    mse = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate, weight_decay=weight_decay)
    write_stats_at = 1000  # Write status to TensorBoard every x updates
    plot_at = 10_000  # Plot every x updates
    validate_at = 5000  # Evaluate model on validation set and check for new best model every x updates
    update = 0  # Current update counter
    best_validation_loss = np.inf  # Best validation loss so far
    update_progress_bar = tqdm(total=n_updates, desc=f"loss: {np.nan:7.5f}", position=0)
    saved_model_file = os.path.join(results_path, "best_model.pt")
    torch.save(net, saved_model_file)
    while update < n_updates:
        for pixelated_images, known_arrays, targets in train_loader:
            # Get next samples
            inputs = torch.cat((pixelated_images, known_arrays),1).to(device)
            targets = targets.to(device)
            
            # Reset gradients
            optimizer.zero_grad()
            
            # Get outputs of our network
            outputs = net(inputs)
            outputs = outputs*((~known_arrays)*1)
            outputs = torch.clamp(outputs, min=0, max=255)
            
            # Calculate loss, do backward pass and update weights
            loss = mse(outputs, targets)
            loss.backward()
            optimizer.step()
            
            # Write current training status
            if (update + 1) % write_stats_at == 0:
                writer.add_scalar(tag="Loss/training", scalar_value=loss.cpu(), global_step=update)
                for i, (name, param) in enumerate(net.named_parameters()):
                    writer.add_histogram(tag=f"Parameters/[{i}] {name}", values=param.cpu(), global_step=update)
                    writer.add_histogram(tag=f"Gradients/[{i}] {name}", values=param.grad.cpu(), global_step=update)
            
            # Plot output
            if (update + 1) % plot_at == 0:
                plot(inputs.detach().cpu().numpy(), targets.detach().cpu().numpy(), outputs.detach().cpu().numpy(),
                     plot_path, update)
            
            # Evaluate model on validation set
            if (update + 1) % validate_at == 0:
                val_loss = evaluate_model(net, loader=val_loader, loss_fn=mse, device=device)
                writer.add_scalar(tag="Loss/validation", scalar_value=val_loss, global_step=update)
                # Save best model for early stopping
                if val_loss < best_validation_loss:
                    best_validation_loss = val_loss
                    torch.save(net, saved_model_file)
            
            update_progress_bar.set_description(f"loss: {loss:7.5f}", refresh=True)
            update_progress_bar.update()
            
            # Increment update counter, exit if maximum number of updates is
            # reached. Here, we could apply some early stopping heuristic and
            # also exit if its stopping criterion is met
            update += 1
            if update >= n_updates:
                break
    update_progress_bar.close()
    writer.close()
    print("Finished Training!")
    
    # Load best model and compute score on test set
    print(f"Computing scores for best model")
    net = torch.load(saved_model_file)
    train_loss = evaluate_model(net, loader=train_loader, loss_fn=mse, device=device)
    val_loss = evaluate_model(net, loader=val_loader, loss_fn=mse, device=device)
    test_loss = evaluate_model(net, loader=test_loader, loss_fn=mse, device=device)
    
    print(f"Scores:")
    print(f"  training loss: {train_loss}")
    print(f"validation loss: {val_loss}")
    print(f"      test loss: {test_loss}")
    
    # Write result to file
    with open(os.path.join(results_path, "results.txt"), "w") as rf:
        print(f"Scores:", file=rf)
        print(f"  training loss: {train_loss}", file=rf)
        print(f"validation loss: {val_loss}", file=rf)
        print(f"      test loss: {test_loss}", file=rf)

# Start training

In [9]:
import json


with open("C:\\Users\\Bashar Hanna\\Desktop\\py2proj\\config_file.json") as cf:
    config = json.load(cf)
main(**config)

loss: 25.06704: 100%|██████████| 50000/50000 [50:45<00:00, 16.42it/s]   


Finished Training!
Computing scores for best model


                                                           

Scores:
  training loss: 24.931518743435543
validation loss: 61.228685569763186
      test loss: 69.21344728469849


