In [2]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from statistics import mean

from custom_classes import NasaDataset, SimpleAE, split_dataset

In [None]:
# Global configuration for the run.
SEED = 42

# Seed torch's global RNG (CPU and CUDA) so weight initialisation and
# DataLoader shuffling are reproducible across Restart & Run All.
# NOTE(review): split_dataset's randomness source is not visible here; if it
# uses numpy or `random` instead of torch, seed those as well.
torch.manual_seed(SEED)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f'{device=}')

### Training

In [None]:
# Split the labelled training file into three subsets:
#   - 10% held out as the final test set;
#   - 20% of the remainder used for validation;
#   - the rest used for training.
# Alternative: evaluate on the separate '../datasets/test.csv' (via NasaDataset)
# instead of an internal hold-out split.
full_dataset = NasaDataset('../datasets/train.csv')
train_val_dataset, test_dataset = split_dataset(full_dataset, test_size=0.1)
train_dataset, val_dataset = split_dataset(train_val_dataset, test_size=0.2)

BATCH_SIZE = ...  # TODO: choose a batch size
LEARNING_RATE = 1e-3

# Each loader must wrap its matching subset. Only the training data is shuffled;
# evaluation order does not change the averaged loss.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

input_shape = train_dataset.get_input_shape()
layers_sizes = (..., ..., ...)  # TODO: fill in the layer sizes passed to SimpleAE

# Move the model to `device` before building the optimiser. The training loop
# sends every batch to `device`, so the parameters must live there too.
model_ae = SimpleAE(input_shape, layers_sizes).to(device)
loss_func = nn.MSELoss()
optimiser = optim.Adam(model_ae.parameters(), lr=LEARNING_RATE)

In [None]:
def evaluate(model, loader, loss_func, device):
    """Return the mean per-batch reconstruction loss of `model` over `loader`.

    The model is switched to eval mode and gradients are disabled. Each item
    yielded by `loader` is unpacked as `(sample, _)`; only `sample` is used,
    and it is both the model input and the reconstruction target.
    """
    model.eval()
    batch_losses = []
    with torch.no_grad():
        for sample, _ in loader:
            sample = sample.to(device)
            _, reconstruction = model(sample)
            batch_losses.append(loss_func(reconstruction, sample).item())
    return mean(batch_losses)


epochs = 100
history = list()

for epoch in range(epochs):
    # evaluate() leaves the model in eval mode, so switch back every epoch.
    model_ae.train()
    train_losses = list()
    for (sample, _) in train_loader:
        sample = sample.to(device)
        _, reconstruction = model_ae(sample)

        loss = loss_func(reconstruction, sample)

        optimiser.zero_grad()
        loss.backward()
        optimiser.step()

        # .item() converts the loss to a plain float. Storing the tensor itself
        # would keep each batch's autograd graph alive for the whole epoch, and
        # statistics.mean expects plain numbers.
        train_losses.append(loss.item())

    train_loss = mean(train_losses)
    val_loss = evaluate(model_ae, val_loader, loss_func, device)
    history.append((epoch, train_loss, val_loss))
    if epoch % 10 == 0 or epoch == epochs - 1:
        print(f'{epoch+1}/{epochs}: {train_loss=}, {val_loss=}')

test_loss = evaluate(model_ae, test_loader, loss_func, device)
print(f'{test_loss=}')