In [None]:
import numpy as np
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torch import Tensor
from torch.nn import Linear
from torch.nn import ReLU
from torch.nn import Sigmoid
from torch.nn import Module
from torch.optim import SGD
from torch.nn import MSELoss
from torch.nn.init import kaiming_uniform_
from torch.nn.init import xavier_uniform_

from models.neuralnetwork.fully_connected import FullyConnected4Layers

In [None]:
class DWDataset(Dataset):
    """Map-style dataset of (measurements, spherical harmonics coefficients) pairs.

    Parameters
    ----------
    spherical_harmonics_coefficients : indexable
        Targets; element ``i`` is returned as the second item of sample ``i``.
    measurements : sized, indexable
        Inputs; its length defines the dataset length and element ``i`` is
        returned as the first item of sample ``i``.
    """

    def __init__(self, spherical_harmonics_coefficients, measurements):
        # Inputs (X) and targets (y), in the order the training loop unpacks them.
        self.X, self.y = measurements, spherical_harmonics_coefficients

    def __len__(self):
        # The number of samples is defined by the input collection.
        sample_count = len(self.X)
        return sample_count

    def __getitem__(self, index):
        """Return ``[input, target]`` for sample ``index`` as a two-element list."""
        features = self.X[index]
        target = self.y[index]
        return [features, target]

## Data preparation

In [None]:
# Fixed seed so the split, the mini-batch shuffling and any torch-RNG-based weight
# initialisation in later cells are reproducible under Restart & Run All.
SEED = 1
torch.manual_seed(SEED)

# TODO: replace the placeholders with the real data. Note the argument order of
# DWDataset: the targets (spherical harmonics coefficients) come first, the inputs
# (measurements) second — keywords make that explicit.
dataset = DWDataset(spherical_harmonics_coefficients=..., measurements=...)

generator = torch.Generator().manual_seed(SEED)

# 60 % train / 20 % validation / 20 % test. Fractional lengths need PyTorch >= 1.13.
train, validation, test = random_split(dataset, [0.6, 0.2, 0.2], generator=generator)

TRAIN_BATCH_SIZE = 32
VALIDATION_BATCH_SIZE = 32
TEST_BATCH_SIZE = 1024

train_data_loader = DataLoader(train, batch_size=TRAIN_BATCH_SIZE, shuffle=True)
# Validation/test loaders are only used to compute losses, so their order is irrelevant
# and shuffling them would just consume RNG state.
validation_data_loader = DataLoader(validation, batch_size=VALIDATION_BATCH_SIZE, shuffle=False)
test_data_loader = DataLoader(test, batch_size=TEST_BATCH_SIZE, shuffle=False)

## Model training

In [None]:
loss = MSELoss()

# The model must be built before the optimizer: SGD needs model.parameters().
# (The original order raised NameError on a fresh kernel.)
# TODO: fill in the input size (presumably measurements per sample) and the output
# size (presumably spherical harmonics coefficients per sample).
model = FullyConnected4Layers(number_of_inputs=..., number_of_outputs=...)
optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9)

In [None]:
NUMBER_OF_EPOCHS = 100

# Per-epoch mean losses, kept so the learning curve can be inspected or plotted afterwards.
training_losses, validation_losses = list(), list()

for epoch in range(NUMBER_OF_EPOCHS):
    # Training mode: enables dropout / batch-norm statistic updates if the model has such layers.
    model.train()
    training_loss_sum, training_sample_count = 0.0, 0

    # Mini batches
    for inputs, targets in train_data_loader:
        optimizer.zero_grad()

        yhat = model(inputs)

        loss_evaluation = loss(yhat, targets)

        loss_evaluation.backward()

        optimizer.step()

        # MSELoss (default reduction) returns the batch mean; weight it by the batch
        # size so the epoch value is a mean over samples, not over batches.
        training_loss_sum += loss_evaluation.item() * len(inputs)
        training_sample_count += len(inputs)

    training_losses.append(
        training_loss_sum / training_sample_count if training_sample_count else float("nan")
    )

    # Validation pass: the held-out split only monitors generalisation, so no
    # gradients are tracked and the model is switched to evaluation mode.
    model.eval()
    validation_loss_sum, validation_sample_count = 0.0, 0
    with torch.no_grad():
        for inputs, targets in validation_data_loader:
            batch_loss = loss(model(inputs), targets)
            validation_loss_sum += batch_loss.item() * len(inputs)
            validation_sample_count += len(inputs)

    validation_losses.append(
        validation_loss_sum / validation_sample_count if validation_sample_count else float("nan")
    )

## Model evaluation

In [None]:
# Evaluate the trained model on the held-out test split.
# Evaluation mode: disables dropout / uses running batch-norm statistics, if present.
model.eval()

predictions, actuals = list(), list()

with torch.no_grad():  # no autograd graph is needed for inference
    for inputs, targets in test_data_loader:
        yhat = model(inputs).cpu().numpy()
        actual = targets.cpu().numpy()

        # One row per sample. `-1` keeps multi-output targets intact; the previous
        # reshape to (n, 1) failed whenever a sample had more than one target value.
        predictions.append(yhat.reshape((len(yhat), -1)))
        actuals.append(actual.reshape((len(actual), -1)))

assert predictions, "test split is empty; nothing to evaluate"

predictions, actuals = np.vstack(predictions), np.vstack(actuals)

# Guard against silent broadcasting, e.g. (n, 1) - (n, k), which would give a wrong MSE.
assert predictions.shape == actuals.shape, (predictions.shape, actuals.shape)

# Same quantity MSELoss computes with its default 'mean' reduction (mean over all
# elements), but evaluated on the NumPy arrays — MSELoss itself only accepts tensors.
mse_final_loss = float(np.mean(np.square(predictions - actuals)))