In [2]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from statistics import mean

from custom_classes import NasaDataset, SimpleAE

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f'{device=}')

### Training

In [None]:
# Build the train/test datasets from their CSV paths (NasaDataset is project-local).
train_dataset = NasaDataset('../datasets/train.csv')
test_dataset = NasaDataset('../datasets/test.csv')

# NOTE: every `...` in this cell is a template placeholder and must be replaced
# with a concrete value before the cell can run.
train_loader = DataLoader(train_dataset, batch_size=..., shuffle=...)
test_loader = DataLoader(test_dataset, batch_size=..., shuffle=...)

input_shape = train_dataset.get_input_shape()
layers_sizes = (..., ..., ...)

# Move the model to the selected device *before* constructing the optimiser:
# the training loop sends every batch to `device`, so leaving the parameters on
# the CPU would raise a device-mismatch error as soon as CUDA is available.
model_ae = SimpleAE(input_shape, layers_sizes).to(device)
loss_func = nn.MSELoss()
optimiser = optim.Adam(model_ae.parameters(),
                       lr=1e-3)

In [None]:
epochs = ...  # template placeholder: replace with the number of training epochs
# One (epoch, mean train loss, mean test loss) tuple is appended per epoch.
history = list()

for epoch in range(epochs):
    # Training phase: enable training-mode behaviour (affects layers such as
    # dropout / batch-norm, if the model has any).
    model_ae.train()
    train_losses = list()
    for (sample, _) in train_loader:
        sample = sample.to(device)
        # The first element of the model output is not used here; the second
        # is taken as the reconstruction of the input.
        _, reconstruction = model_ae(sample)

        # Autoencoder objective: the reconstruction is compared with the input.
        loss = loss_func(reconstruction, sample)

        optimiser.zero_grad()
        loss.backward()
        optimiser.step()

        # Store a plain Python float: keeping the tensor would retain its
        # autograd graph for the whole epoch, and statistics.mean() cannot
        # average Tensor objects.
        train_losses.append(loss.item())

    # Evaluation phase: switch to inference-mode behaviour and disable
    # gradient tracking.
    model_ae.eval()
    with torch.no_grad():
        test_losses = list()
        for (sample, _) in test_loader:
            sample = sample.to(device)
            _, reconstruction = model_ae(sample)
            loss = loss_func(reconstruction, sample)
            test_losses.append(loss.item())

    history.append((epoch, mean(train_losses), mean(test_losses)))