In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
import numpy as np
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader, SubsetRandomSampler
import torchvision
import torchvision.transforms as T
from sklearn.model_selection import KFold
import os
import copy
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Print the full path of every file found below the Kaggle input directory.
for folder, _, names in os.walk('/kaggle/input'):
    for name in names:
        print(os.path.join(folder, name))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
!python --version

In [None]:
%pip install codecarbon comet_ml

In [None]:
from comet_ml import Experiment
from codecarbon import EmissionsTracker
from datetime import datetime

# Initialise and start the CodeCarbon tracker so that the emissions of
# everything executed after this cell are measured.
tracker = EmissionsTracker()
tracker.start()

# Wall-clock start of the run; a later cell uses it to compute the elapsed time.
start_time = datetime.now()
print(f'Start time is {start_time}')

# Initialise the Comet experiment.
# The API key is read from the COMET_API_KEY environment variable instead of
# being written into the notebook, so the secret is never saved or shared
# together with the code. NOTE(review): on Kaggle the variable can be filled
# from a notebook secret before this cell runs - confirm the preferred source.
experiment = Experiment(
    api_key=os.environ.get("COMET_API_KEY"),
    project_name="",
    workspace="",
)

In [None]:
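# Hand-written digit sequences: each row holds 5 digit labels (the first 4 become the input frames, the 5th the target frame); 216 rows in total.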

sequences = [
    [1, 2, 3, 6, 8],
    [1, 4, 3, 6, 8],
    [1, 5, 3, 6, 8],

    [1, 2, 3, 6, 9],
    [1, 4, 3, 6, 9],
    [1, 5, 3, 6, 9],

    [1, 2, 3, 7, 8],
    [1, 4, 3, 7, 8],
    [1, 5, 3, 7, 8],

    [1, 2, 3, 7, 9],
    [1, 4, 3, 7, 9],
    [1, 5, 3, 7, 9],
    #________________ 12

    [1, 2, 6, 7, 8],
    [1, 4, 6, 7, 8],
    [1, 5, 6, 7, 8],

    [1, 2, 6, 7, 9],
    [1, 4, 6, 7, 9],
    [1, 5, 6, 7, 9],

    [1, 2, 6, 3, 8],
    [1, 4, 6, 3, 8],
    [1, 5, 6, 3, 8],

    [1, 2, 6, 3, 9],
    [1, 4, 6, 3, 9],
    [1, 5, 6, 3, 9],
    #________________ 24
    [1, 2, 7, 3, 8],
    [1, 4, 7, 3, 8],
    [1, 5, 7, 3, 8],

    [1, 2, 7, 3, 9],
    [1, 4, 7, 3, 9],
    [1, 5, 7, 3, 9],

    [1, 2, 7, 9, 8],
    [1, 4, 7, 9, 8],
    [1, 5, 7, 9, 8],

    [1, 2, 7, 8, 9],
    [1, 4, 7, 8, 9],
    [1, 5, 7, 8, 9],
    #________________ 36
    [2, 3, 4, 7, 8],
    [2, 5, 4, 7, 8],
    [2, 6, 4, 7, 8],

    [2, 3, 4, 7, 9],
    [2, 5, 4, 7, 9],
    [2, 6, 4, 7, 9],

    [2, 3, 4, 1, 8],
    [2, 5, 4, 1, 8],
    [2, 6, 4, 1, 8],

    [2, 3, 4, 1, 9],
    [2, 5, 4, 1, 9],
    [2, 6, 4, 1, 9],
    #________________ 48
    [2, 3, 7, 8, 9],
    [2, 5, 7, 8, 9],
    [2, 6, 7, 8, 9],

    [2, 3, 7, 9, 8],
    [2, 5, 7, 9, 8],
    [2, 6, 7, 9, 8],

    [2, 3, 7, 8, 1],
    [2, 5, 7, 8, 1],
    [2, 6, 7, 8, 1],

    [2, 3, 7, 9, 1],
    [2, 5, 7, 9, 1],
    [2, 6, 7, 9, 1],
    #________________ 60
    [2, 3, 8, 9, 1],
    [2, 5, 8, 9, 1],
    [2, 6, 8, 9, 1],

    [2, 3, 8, 9, 7],
    [2, 5, 8, 9, 7],
    [2, 6, 8, 9, 7],

    [2, 3, 8, 7, 1],
    [2, 5, 8, 7, 1],
    [2, 6, 8, 7, 1],

    [2, 3, 8, 1, 7],
    [2, 5, 8, 1, 7],
    [2, 6, 8, 1, 7],
    #________________ 72
    [3, 4, 5, 8, 2],
    [3, 6, 5, 8, 2],
    [3, 7, 5, 8, 2],

    [3, 4, 5, 2, 8],
    [3, 6, 5, 2, 8],
    [3, 7, 5, 2, 8],

    [3, 4, 5, 8, 1],
    [3, 6, 5, 8, 1],
    [3, 7, 5, 8, 1],

    [3, 4, 5, 1, 8],
    [3, 6, 5, 1, 8],
    [3, 7, 5, 1, 8],
    #_______________ 84
    [3, 4, 8, 9, 2],
    [3, 6, 8, 9, 2],
    [3, 7, 8, 9, 2],

    [3, 4, 8, 2, 9],
    [3, 6, 8, 2, 9],
    [3, 7, 8, 2, 9],

    [3, 4, 8, 9, 1],
    [3, 6, 8, 9, 1],
    [3, 7, 8, 9, 1],

    [3, 4, 8, 1, 9],
    [3, 6, 8, 1, 9],
    [3, 7, 8, 1, 9],
    #_______________ 96
    [3, 4, 9, 8, 2],
    [3, 6, 9, 8, 2],
    [3, 7, 9, 8, 2],

    [3, 4, 9, 2, 8],
    [3, 6, 9, 2, 8],
    [3, 7, 9, 2, 8],

    [3, 4, 9, 8, 1],
    [3, 6, 9, 8, 1],
    [3, 7, 9, 8, 1],

    [3, 4, 9, 1, 8],
    [3, 6, 9, 1, 8],
    [3, 7, 9, 1, 8],
    #_______________ 108
    [4, 5, 6, 9, 2],
    [4, 7, 6, 9, 2],
    [4, 8, 6, 9, 2],

    [4, 5, 6, 2, 9],
    [4, 7, 6, 2, 9],
    [4, 8, 6, 2, 9],

    [4, 5, 6, 9, 3],
    [4, 7, 6, 9, 3],
    [4, 8, 6, 9, 3],

    [4, 5, 6, 3, 9],
    [4, 7, 6, 3, 9],
    [4, 8, 6, 3, 9],
    #_______________ 120
    [4, 5, 9, 3, 2],
    [4, 7, 9, 3, 2],
    [4, 8, 9, 3, 2],

    [4, 5, 9, 2, 3],
    [4, 7, 9, 2, 3],
    [4, 8, 9, 2, 3],

    [4, 5, 9, 1, 3],
    [4, 7, 9, 1, 3],
    [4, 8, 9, 1, 3],

    [4, 5, 9, 3, 1],
    [4, 7, 9, 3, 1],
    [4, 8, 9, 3, 1],
    #_______________ 132
    [4, 5, 1, 3, 2],
    [4, 7, 1, 3, 2],
    [4, 8, 1, 3, 2],

    [4, 5, 1, 2, 3],
    [4, 7, 1, 2, 3],
    [4, 8, 1, 2, 3],

    [4, 5, 1, 9, 3],
    [4, 7, 1, 9, 3],
    [4, 8, 1, 9, 3],

    [4, 5, 1, 3, 9],
    [4, 7, 1, 3, 9],
    [4, 8, 1, 3, 9],
    #_______________ 144
    [5, 6, 7, 1, 2],
    [5, 8, 7, 1, 2],
    [5, 9, 7, 1, 2],

    [5, 6, 7, 2, 1],
    [5, 8, 7, 2, 1],
    [5, 9, 7, 2, 1],

    [5, 6, 7, 3, 2],
    [5, 8, 7, 3, 2],
    [5, 9, 7, 3, 2],

    [5, 6, 7, 2, 3],
    [5, 8, 7, 2, 3],
    [5, 9, 7, 2, 3],
    #_______________ 156
    [5, 6, 1, 2, 3],
    [5, 8, 1, 2, 3],
    [5, 9, 1, 2, 3],

    [5, 6, 1, 3, 2],
    [5, 8, 1, 3, 2],
    [5, 9, 1, 3, 2],

    [5, 6, 1, 3, 4],
    [5, 8, 1, 3, 4],
    [5, 9, 1, 3, 4],

    [5, 6, 1, 4, 3],
    [5, 8, 1, 4, 3],
    [5, 9, 1, 4, 3],
    #_______________ 168
    [5, 6, 3, 4, 1],
    [5, 8, 3, 4, 1],
    [5, 9, 3, 4, 1],

    [5, 6, 3, 1, 4],
    [5, 8, 3, 1, 4],
    [5, 9, 3, 1, 4],

    [5, 6, 3, 4, 2],
    [5, 8, 3, 4, 2],
    [5, 9, 3, 4, 2],

    [5, 6, 3, 2, 4],
    [5, 8, 3, 2, 4],
    [5, 9, 3, 2, 4],
    #_______________ 180
    [6, 7, 8, 2, 3],
    [6, 9, 8, 2, 3],
    [6, 1, 8, 2, 3],

    [6, 7, 8, 3, 2],
    [6, 9, 8, 3, 2],
    [6, 1, 8, 3, 2],

    [6, 7, 8, 2, 5],
    [6, 9, 8, 2, 5],
    [6, 1, 8, 2, 5],

    [6, 7, 8, 5, 2],
    [6, 9, 8, 5, 2],
    [6, 1, 8, 5, 2],
    #_______________ 192
    [6, 7, 2, 3, 4],
    [6, 9, 2, 3, 4],
    [6, 1, 2, 3, 4],

    [6, 7, 2, 4, 3],
    [6, 9, 2, 4, 3],
    [6, 1, 2, 4, 3],

    [6, 7, 2, 4, 5],
    [6, 9, 2, 4, 5],
    [6, 1, 2, 4, 5],

    [6, 7, 2, 5, 4],
    [6, 9, 2, 5, 4],
    [6, 1, 2, 5, 4],
    #_______________ 204
    [6, 7, 3, 4, 5],
    [6, 9, 3, 4, 5],
    [6, 1, 3, 4, 5],

    [6, 7, 3, 5, 4],
    [6, 9, 3, 5, 4],
    [6, 1, 3, 5, 4],

    [6, 7, 3, 2, 5],
    [6, 9, 3, 2, 5],
    [6, 1, 3, 2, 5],

    [6, 7, 3, 5, 2],
    [6, 9, 3, 5, 2],
    [6, 1, 3, 5, 2],
    #_______________ 216
]

In [None]:
# MNIST with train=False (the test split), downloaded into the current directory
# if it is not already there. No transform is passed here; the cells below read
# each entry as entry[0] (the image) and entry[1] (the digit label).
# NOTE: the name `dataset` is rebound to the sequence tensor in the K-fold cell,
# so the cells that look up digits must run before that cell.
dataset = torchvision.datasets.MNIST("./", train = False, download = True)

In [None]:
# Patch images with labels
def find_match(number, entries=None):
    """Return the first entry whose label equals ``number``.

    Args:
        number: Label to look for; compared with ``==`` against ``entry[1]``.
        entries: Optional iterable of indexable entries (``entry[1]`` is the
            label). When omitted, the notebook-level ``dataset`` is searched.

    Returns:
        The first entry whose second element equals ``number``.

    Raises:
        ValueError: If no entry carries the requested label.
    """
    if entries is None:
        entries = dataset
    for entry in entries:
        if entry[1] == number:
            return entry
    # Fail with a clear message instead of hitting an unbound local variable.
    raise ValueError(f"no entry with label {number!r} found")

# For every digit sequence, look up one matching dataset entry per digit.
# `image_label_sequences` keeps the complete entries returned by find_match,
# `image_sequences` keeps only their first element (the image).
image_label_sequences = [
    [find_match(digit) for digit in digits]
    for digits in sequences
]
image_sequences = [
    [pair[0] for pair in row]
    for row in image_label_sequences
]

In [None]:
# Convert every image to a tensor and assemble the sequence tensor.

tf = T.Compose([
     T.Resize(28),
     T.ToTensor() # Returns a tensor with normalized values between 0 and 1
])

# One stacked tensor per sequence: the frames of a sequence are stacked along
# a new leading axis.
seqs = [torch.stack([tf(img) for img in seq]) for seq in image_sequences]

# Stack the sequences: (num_sequences, seq_len, channels, height, width).
seqs_stack = torch.stack(seqs)

# Move the channel axis in front of the time axis:
# (num_sequences, channels, seq_len, height, width).
# Swapping the axes explicitly (instead of reshaping to a hard-coded
# 216 x 1 x 5 x 28 x 28) keeps this correct for any number of sequences,
# frames or channels; `contiguous()` materialises the new layout.
seqs_reshaped = seqs_stack.permute(0, 2, 1, 3, 4).contiguous()

In [None]:
### Utils
def collate(batch):
    """Stack samples into one batch on ``device`` and split inputs from target.

    Args:
        batch: Sequence of equally shaped tensors, one per sample. In this
            notebook each sample is (channels, 5, height, width).

    Returns:
        A tuple ``(inputs, target)``: ``inputs`` holds indices 0-3 of the
        third axis of the stacked batch, ``target`` holds index 4 of that axis
        (the axis is dropped for the target).
    """
    stacked = torch.stack(batch).to(device)
    inputs = stacked[:, :, 0:4]
    target = stacked[:, :, 4]
    return inputs, target

def reset_weights(m):
    '''
    Re-initialise the direct children of ``m`` to avoid weight leakage.

    Every child module that exposes a ``reset_parameters`` method is
    announced with a print and then reset; other children are skipped.
    Nothing is returned.
    '''
    for child in m.children():
        if not hasattr(child, 'reset_parameters'):
            continue
        print(f'Reset trainable parameters of layer = {child}')
        child.reset_parameters()

In [None]:
# Smoke test: pull a single one-sample batch through `collate` so the shapes of
# the input frames (`data`) and of the target frame (`target`) can be inspected.
train_loader = DataLoader(
    seqs_reshaped,
    batch_size=1,
    collate_fn=collate,
    drop_last=True,
)
data, target = next(iter(train_loader))

In [None]:
### ConvLSTM cell and layer
### ConvLSTM cell
class ConvLSTMCell(nn.Module):
    """One ConvLSTM step with element-wise (Hadamard) cell-state weights.

    A single convolution over the concatenation of the input frame and the
    previous hidden state produces the pre-activations of the input, forget,
    candidate and output gates. The learnable tensors ``W_ci``, ``W_cf`` and
    ``W_co`` are multiplied element-wise with the cell state inside the gates.

    Args:
        in_channels: Channels of the input frame.
        out_channels: Channels of the hidden and cell state.
        kernel_size: Kernel size of the gate convolution.
        padding: Padding of the gate convolution.
        activation: ``"tanh"`` or ``"relu"``.
        frame_size: ``(height, width)`` of the frames; fixes the shape of the
            element-wise weights.

    Raises:
        ValueError: If ``activation`` is not one of the supported names.
    """

    # Supported activation names and the functions they select.
    _ACTIVATIONS = {"tanh": torch.tanh, "relu": torch.relu}

    def __init__(self, in_channels, out_channels,
    kernel_size, padding, activation, frame_size):
        super().__init__()
        # Reject unknown names here; otherwise `self.activation` would stay
        # unset and the failure would only surface inside `forward`.
        if activation not in self._ACTIVATIONS:
            raise ValueError(
                f"activation must be one of {sorted(self._ACTIVATIONS)}, "
                f"got {activation!r}")
        self.activation = self._ACTIVATIONS[activation]

        # Idea adapted from https://github.com/ndrplz/ConvLSTM_pytorch
        # One convolution yields all four gate pre-activations at once, hence
        # 4 * out_channels output channels.
        self.conv = nn.Conv2d(
            in_channels=in_channels + out_channels,
            out_channels=4 * out_channels,
            kernel_size=kernel_size,
            padding=padding)

        # Weights for the Hadamard products - torch.rand draws values in [0, 1).
        self.W_ci = nn.Parameter(torch.rand(out_channels, *frame_size))
        self.W_co = nn.Parameter(torch.rand(out_channels, *frame_size))
        self.W_cf = nn.Parameter(torch.rand(out_channels, *frame_size))

    def forward(self, X, H_prev, C_prev):
        """Advance the cell by one time step.

        Args:
            X: Current input frame, (batch, in_channels, height, width).
            H_prev: Previous hidden state, (batch, out_channels, height, width).
            C_prev: Previous cell state, same shape as ``H_prev``.

        Returns:
            Tuple ``(current_hidden_state, current_cell_state)``.
        """
        # Concatenate the frame and the hidden state along the channel axis.
        conv_output = self.conv(torch.cat([X, H_prev], dim=1))
        i_conv, f_conv, C_conv, o_conv = torch.chunk(conv_output, chunks=4, dim=1)

        # nan_to_num replaces NaNs in the gates with numbers so that a single
        # NaN does not spread through the recurrence. Note that this hides
        # NaNs produced upstream rather than removing their cause.
        input_gate = torch.nan_to_num(torch.sigmoid(i_conv + self.W_ci * C_prev))
        forget_gate = torch.nan_to_num(torch.sigmoid(f_conv + self.W_cf * C_prev))

        # Current cell state
        current_cell_state = forget_gate * C_prev + input_gate * self.activation(C_conv)

        output_gate = torch.nan_to_num(
            torch.sigmoid(o_conv + self.W_co * current_cell_state))

        # Current hidden state
        current_hidden_state = output_gate * self.activation(current_cell_state)

        return current_hidden_state, current_cell_state

### ConvLSTM layer
class ConvLSTM(nn.Module):
    """ConvLSTM layer: unrolls one ``ConvLSTMCell`` over the time axis.

    Args:
        in_channels: Channels of each input frame.
        out_channels: Channels of the hidden/cell state and of the output.
        kernel_size: Kernel size forwarded to the cell.
        padding: Padding forwarded to the cell.
        activation: Activation name forwarded to the cell.
        frame_size: ``(height, width)`` forwarded to the cell.
    """

    def __init__(self, in_channels, out_channels, 
    kernel_size, padding, activation, frame_size):

        super().__init__()
        self.out_channels = out_channels

        # We will unroll this over time steps
        self.convLSTMcell = ConvLSTMCell(in_channels, out_channels, 
        kernel_size, padding, activation, frame_size)

    def forward(self, X):
        """Run the cell over every time step of ``X``.

        Args:
            X: Frame sequence (batch_size, num_channels, seq_len, height, width).

        Returns:
            Hidden states of all time steps,
            (batch_size, out_channels, seq_len, height, width).
        """
        batch_size, _, seq_len, height, width = X.size()

        # Allocate the buffers with the device and dtype of the input rather
        # than a notebook-level global, so the layer works wherever its input
        # lives.
        output = X.new_zeros(batch_size, self.out_channels, seq_len,
                             height, width)

        # Initial hidden state
        H = X.new_zeros(batch_size, self.out_channels, height, width)

        # Initial cell state
        C = X.new_zeros(batch_size, self.out_channels, height, width)

        # Unroll over time steps
        for time_step in range(seq_len):
            H, C = self.convLSTMcell(X[:,:,time_step], H, C)
            output[:,:,time_step] = H
        return output

In [None]:
## Try convLSTM cell initialisation
# Scratch cell: build the state tensors and one cell by hand to check shapes.
batch_size=1
in_channels=1
out_channels=28
kernel_size=(3, 3)
padding=(1, 1)
activation="tanh"
frame_size=(28, 28)

# These sizes must exist at notebook level for the cell to run on a fresh
# kernel: the frame size gives height/width, and the number of input frames is
# read from the sample batch `data` (time is its third axis).
height, width = frame_size
seq_len = data.shape[2]

# Initialize output
output = torch.zeros(batch_size, out_channels, seq_len, height, width, device=device)
print(output.shape)
        
# Initialize Hidden State
H = torch.zeros(batch_size, out_channels, height, width, device=device)
print(H.shape)

# Initialize Cell Input
C = torch.zeros(batch_size, out_channels, height, width, device=device)
print(C.shape)

convLSTMcell = ConvLSTMCell(in_channels, out_channels, 
        kernel_size, padding, activation, frame_size)

In [None]:
### ConvLSTM model 1 layer

class Seq2Seq(nn.Module):
    """One-layer ConvLSTM model that predicts a single frame from a sequence.

    The input sequence goes through a ConvLSTM layer and a 3-D batch norm;
    the last time step of the result is mapped back to ``num_channels`` by a
    2-D convolution and squashed with a sigmoid.

    Args:
        num_channels: Channels of the input frames and of the predicted frame.
        num_kernels: Hidden channels of the ConvLSTM layer.
        kernel_size: Kernel size for the ConvLSTM and the output convolution.
        padding: Padding for the ConvLSTM and the output convolution.
        activation: Activation name forwarded to the ConvLSTM layer.
        frame_size: ``(height, width)`` forwarded to the ConvLSTM layer.
    """

    def __init__(self, num_channels, num_kernels, kernel_size, padding,
                 activation, frame_size):
        super().__init__()

        # Recurrent layer followed by batch normalisation over its output.
        recurrent = ConvLSTM(
            in_channels=num_channels,
            out_channels=num_kernels,
            kernel_size=kernel_size,
            padding=padding,
            activation=activation,
            frame_size=frame_size,
        )
        normalisation = nn.BatchNorm3d(num_features=num_kernels)

        self.sequential = nn.Sequential()
        self.sequential.add_module("convlstm1", recurrent)
        self.sequential.add_module("batchnorm1", normalisation)

        # Convolution that turns the hidden channels into the output frame.
        self.conv = nn.Conv2d(
            in_channels=num_kernels,
            out_channels=num_channels,
            kernel_size=kernel_size,
            padding=padding,
        )

    def forward(self, X):
        """Return the predicted frame for the sequence batch ``X``."""
        features = self.sequential(X)

        # Only the last time step is used for the prediction.
        last_step = features[:, :, -1]
        return torch.sigmoid(self.conv(last_step))

In [None]:
### K-fold Cross Validator
# Params
SEED = 42
BATCH_SIZE = 15
torch.manual_seed(SEED)
num_epochs = 100
# reduction='sum' adds the loss over every element of a batch; the epoch totals
# are divided by the number of samples further down.
criterion = nn.BCELoss(reduction='sum')

# Fold results storage objects: loss after the first / the last epoch per fold
train_start_results = {}
val_start_results = {}

train_end_results = {}
val_end_results = {}

# One list of per-epoch losses per fold
train_results = []
val_results = []

# Define the K-fold Cross Validator. A fixed random_state makes the shuffled
# split identical on every run; torch.manual_seed alone does not control it.
k_folds = 5
kfold = KFold(n_splits=k_folds, shuffle=True, random_state=SEED)

# Whole dataset
# NOTE: this rebinds `dataset` (previously the MNIST dataset) to the sequence
# tensor; the inference cell below relies on this binding.
dataset = seqs_reshaped

# K-fold Cross Validation model evaluation
for fold, (train_ids, val_ids) in enumerate(kfold.split(dataset)):
    print(f'FOLD {fold}')

    # Sample elements randomly from a given list of ids, no replacement.
    train_subsampler = SubsetRandomSampler(train_ids)
    val_subsampler = SubsetRandomSampler(val_ids)

    # Define data loaders for training and validation data in this fold
    train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, collate_fn=collate, sampler=train_subsampler)
    val_loader = DataLoader(dataset, batch_size=BATCH_SIZE, collate_fn=collate, sampler=val_subsampler)

    # Initialization: a fresh model and optimizer for every fold
    model = Seq2Seq(num_channels=1, num_kernels=28, kernel_size=(3, 3), padding=(1, 1), activation="tanh", frame_size=(28, 28)).to(device)
    model.apply(reset_weights)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    train_results_per_epoch = []
    val_results_per_epoch = []
    for epoch in range(1, num_epochs+1):
        train_loss = 0
        model.train()
        for x, y in train_loader:
            output = model(x)

            optimizer.zero_grad()

            loss_train = criterion(output.flatten(), y.flatten())
            loss_train.backward()

            optimizer.step()

            train_loss += loss_train.item()
        # Normalise by the number of samples actually drawn in this fold.
        # len(train_loader.dataset) would be the size of the WHOLE dataset,
        # because the sampler, not the dataset, selects the fold.
        total_train_loss = train_loss / len(train_ids)
        train_results_per_epoch.append(total_train_loss)

        val_loss = 0
        model.eval()
        with torch.no_grad():
            for x, y in val_loader:
                output = model(x)
                loss_val = criterion(output.flatten(), y.flatten())
                val_loss += loss_val.item()
        total_val_loss = val_loss / len(val_ids)
        val_results_per_epoch.append(total_val_loss)

        # Store scores: the first epoch is the "start" result; the "end" result
        # is overwritten every epoch so it finally holds the last epoch, even
        # when only a single epoch is run.
        if epoch == 1:
            val_start_results[fold] = total_val_loss
            train_start_results[fold] = total_train_loss
        val_end_results[fold] = total_val_loss
        train_end_results[fold] = total_train_loss

        print("Epoch:{} Training Loss:{:.2f} Validation Loss:{:.2f}\n".format(
            epoch, total_train_loss, total_val_loss))
    train_results.append(train_results_per_epoch)
    val_results.append(val_results_per_epoch)

    # Saving the model
    save_path = f'./convlstm-seqgmnist200-model-fold-{fold}.pth'
    torch.save(model.state_dict(), save_path)

In [None]:
# Stop CO2 tracker and print emissions

emissions: float = tracker.stop()
print(f"Emissions: {emissions} kg")

# Calculate the time spent: `stop_time` is the moment the run ended and
# `time_spend` the elapsed duration since `start_time`.
stop_time = datetime.now()
time_spend = stop_time - start_time

# Time logs: the timestamps are stored as ISO strings alongside the
# experiment, the duration as a numeric metric in seconds.
experiment.log_other("start_time", start_time.isoformat())
experiment.log_other("stop_time", stop_time.isoformat())
experiment.log_metric("time_spend", time_spend.total_seconds())

# Turn off Comet
experiment.end()

In [None]:
# Print start fold results (loss after the first epoch of every fold)
print(f'Start K-FOLD RESULTS FOR {k_folds} FOLDS')
for split_name, fold_losses in (('train', train_start_results), ('val', val_start_results)):
    # A local accumulator is used so the builtin `sum` is not overwritten
    # for the rest of the kernel session.
    total = 0.0
    for key, value in fold_losses.items():
        print(f'Fold {key}: {value}')
        total += value
    if fold_losses:
        print(f'Average {split_name}: {total/len(fold_losses)}')
    else:
        # Nothing recorded: report it instead of dividing by zero.
        print(f'Average {split_name}: n/a (no results recorded)')

In [None]:
# Print final fold results (loss after the last epoch of every fold)
print(f'End K-FOLD RESULTS FOR {k_folds} FOLDS')
for split_name, fold_losses in (('train', train_end_results), ('val', val_end_results)):
    # A local accumulator is used so the builtin `sum` is not overwritten
    # for the rest of the kernel session.
    total = 0.0
    for key, value in fold_losses.items():
        print(f'Fold {key}: {value}')
        total += value
    if fold_losses:
        print(f'Average {split_name}: {total/len(fold_losses)}')
    else:
        # Nothing recorded: report it instead of dividing by zero.
        print(f'Average {split_name}: n/a (no results recorded)')

In [None]:
# Train and validation results

import matplotlib.pyplot as plt
# One colour per fold; solid lines are training losses, dashed lines
# validation losses. Colours repeat if there are more folds than colours.
fold_colors = ("blue", "brown", "green", "orange", "magenta")

fig, ax = plt.subplots()
handles = []
labels = []
for split_name, fold_curves, line_style in (
        ("train", train_results, "solid"),
        ("val", val_results, "dashed")):
    for fold_idx, losses in enumerate(fold_curves):
        number = fold_idx + 1
        # English ordinal suffix: 1st, 2nd, 3rd, 4th, ... 11th, 12th, 13th ...
        if 10 <= number % 100 <= 20:
            suffix = "th"
        else:
            suffix = {1: "st", 2: "nd", 3: "rd"}.get(number % 10, "th")
        # The x axis follows the recorded curve, so the plot stays valid when
        # the number of epochs changes.
        line, = ax.plot(range(len(losses)), losses,
                        c=fold_colors[fold_idx % len(fold_colors)], ls=line_style)
        handles.append(line)
        labels.append(f"{number}{suffix} {split_name} fold")

ax.legend(handles, labels, loc='upper right', shadow=True)
ax.set_xlabel('epochs')
ax.set_ylabel('loss')
ax.set_title(f'Train and validation results for {len(train_results)} folds')
plt.show()

In [None]:
# Train and validation results

import matplotlib.pyplot as plt
# Close-up on the first folds only, with the same colours and line styles as
# the overview plot (solid = train, dashed = validation).
n_shown = 2
shown_colors = ("blue", "brown")
shown_names = ("1st", "2nd")

fig, ax = plt.subplots()
handles = []
labels = []
for split_name, fold_curves, line_style in (
        ("train", train_results, "solid"),
        ("val", val_results, "dashed")):
    for fold_idx, losses in enumerate(fold_curves[:n_shown]):
        # The x axis follows the recorded curve instead of a fixed 100 epochs.
        line, = ax.plot(range(len(losses)), losses,
                        c=shown_colors[fold_idx], ls=line_style)
        handles.append(line)
        labels.append(f"{shown_names[fold_idx]} {split_name} fold")

ax.legend(handles, labels, loc='upper right', shadow=True)
ax.set_xlabel('epochs')
ax.set_ylabel('loss')
# The title states what is actually drawn: a subset of the folds.
ax.set_title(f'Train and validation results for the first {n_shown} of {len(train_results)} folds')
plt.show()

In [None]:
# Inference for a batch of sequences.
# NOTE: `model` is whatever the training loop left behind, i.e. the model of
# the last fold; `dataset` is the full sequence tensor, so the batch may
# contain sequences that model was trained on.
n_samples = 15
data_loader = DataLoader(dataset, batch_size=n_samples, collate_fn=collate, drop_last=True, shuffle=True)
data, target = next(iter(data_loader))

model.eval()
with torch.no_grad():
    output = model(data)

# Move the channel axis last for plotting:
# (batch, channels, height, width) -> (batch, height, width, channels).
# An explicit permute stays correct for any batch size and channel count.
targets = target.permute(0, 2, 3, 1)
imgs_gen = output.permute(0, 2, 3, 1)

# Join tensors for a single image: targets first, generated frames second
combined = torch.cat((targets, imgs_gen), 0)
combined.shape

In [None]:
# `combined` holds the target frames followed by the generated frames, so half
# of its length is the number of columns of the two-row grid.
n_cols = len(combined) // 2

fig = plt.figure(figsize=(15, 15))
grid = ImageGrid(fig, 111,  # similar to subplot(111)
                 nrows_ncols=(2, n_cols),  # two rows of n_cols axes each
                 axes_pad=0.05,  # pad between axes in inch.
                 )

# matplotlib needs host memory: detach and copy the tensors to the CPU first,
# otherwise imshow fails when the notebook runs on a GPU.
for ax, im in zip(grid, combined.detach().cpu()):
    # Iterating over the grid returns the Axes.
    ax.imshow(im)

plt.show()

In [None]:
import json

# Write the per-fold, per-epoch loss curves to JSON files in the working
# directory, one file per split.
for curves, file_name in (
        (train_results, "train_results.json"),
        (val_results, "val_results.json")):
    with open("./" + file_name, 'w') as outfile:
        json.dump(curves, outfile)