In [10]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from statistics import mean

from custom_classes import NasaDataset, SimpleAE, split_dataset, fix_seeds

In [11]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f'{device=}')

device='cuda'


### Training

In [12]:
train_dataset = NasaDataset('../datasets/clean_train_data.csv')
fix_seeds(37)
train_dataset, test_dataset = split_dataset(train_dataset, test_size=0.1)
fix_seeds(37)
train_dataset, val_dataset = split_dataset(train_dataset, test_size=0.3)
# test_dataset = NasaDataset('../datasets/test.csv')

fix_seeds(37)
train_dataset.to(device)
train_loader = DataLoader(train_dataset, batch_size=20, shuffle=True)
fix_seeds(37)
test_dataset.to(device)
test_loader = DataLoader(val_dataset, batch_size=20, shuffle=True)
fix_seeds(37)
val_dataset.to(device)
val_loader = DataLoader(train_dataset, batch_size=20, shuffle=True)
# test_loader = DataLoader(test_dataset, batch_size=..., shuffle=...)
print(f'Train: {len(train_dataset)}\nValidation: {len(val_dataset)}\nTest: {len(test_dataset)}')

input_shape = train_dataset.get_input_shape()
layers_sizes = (10, 5, 2)

model_ae = SimpleAE(input_shape, layers_sizes).to(device)
loss_func = nn.MSELoss()
optimiser = optim.Adam(model_ae.parameters(),
                       lr=1e-3)

Train: 13723
Validation: 4840
Test: 2068


In [13]:
epochs = 100
history = list()

for epoch in range(epochs):
    train_losses = list()
    for dta in train_loader:
        sample = dta['sensors']
        sample = sample.to(device)
        _, reconstruction = model_ae(sample)

        loss = loss_func(reconstruction, sample)

        optimiser.zero_grad()
        loss.backward()
        optimiser.step()

        train_losses.append(loss.item())
    
    with torch.no_grad():
        val_losses = list()
        for dta in val_loader:
            sample = dta['sensors']
            sample = sample.to(device)
            _, reconstruction = model_ae(sample)
            loss = loss_func(reconstruction, sample)
            val_losses.append(loss.item())
    
    train_loss = mean(train_losses)
    val_loss = mean(val_losses)
    history.append((epoch, train_loss, val_loss))
    if (epoch + 1) % 10 == 0 or epoch == epochs - 1:
        print(f'{epoch+1:>3}/{epochs:>3}: {train_loss=}, {val_loss=}')

with torch.no_grad():
    test_losses = list()
    for sample in test_loader:
        sample = dta['sensors']
        sample = sample.to(device)
        _, reconstruction = model_ae(sample)
        loss = loss_func(reconstruction, sample)
        test_losses.append(loss.item())
    test_loss = mean(test_losses)

1/100: train_loss=9191527.774381367, val_loss=9161167.890829694
11/100: train_loss=330.76452005933294, val_loss=312.93900405683894
21/100: train_loss=312.34488664652065, val_loss=312.04952779711596
31/100: train_loss=313.08681269261166, val_loss=315.1933172411967
41/100: train_loss=314.1661903125924, val_loss=312.0009934315688
51/100: train_loss=312.7718149598856, val_loss=312.67325780832437
61/100: train_loss=313.0810111655245, val_loss=315.82958942174565
71/100: train_loss=312.86564441886355, val_loss=313.55725058787425
81/100: train_loss=313.1939313019554, val_loss=311.9957031627583
91/100: train_loss=312.87294622696123, val_loss=312.4889407539645
100/100: train_loss=312.9467437166661, val_loss=315.7480779922685


In [17]:
test_loss

181.62550354003906