In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split

import  os, json
from tqdm import tqdm 

from dataset import dataset
from models.encoder.custom_autoencoder import CustomAutoEncoder
from utils.model_control import run_train_test_1_input
from utils.custom_logger import get_logger

In [15]:
# Run on the first CUDA GPU when one is visible, otherwise fall back to CPU.
# Kept as a plain string because it is forwarded to model.to() and to the
# training helper.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

In [12]:
# Load the pre-built dataset. It is used below with len() and random_split,
# so it is presumably a pickled Dataset object rather than bare tensors.
# torch>=2.6 defaults torch.load to weights_only=True, which refuses to
# unpickle arbitrary classes; opt out explicitly (this matches the default of
# older torch versions).
# SECURITY: full unpickling can execute arbitrary code -- only load this file
# from a trusted source.
portion_dataset = torch.load('dataset/portion_dataset.pth', weights_only=False)

In [13]:
# Split the dataset into train/test subsets. random_split otherwise draws
# from torch's global RNG, which makes the split differ on every kernel
# restart; a dedicated seeded generator makes it reproducible.
TRAIN_FRACTION = 0.8
SPLIT_SEED = 42

train_size = int(len(portion_dataset) * TRAIN_FRACTION)
test_size = len(portion_dataset) - train_size

train, test = random_split(
    portion_dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [14]:
BATCH_SIZE = 256

# The model contains BatchNorm1d layers, which raise an error in training
# mode when a batch holds a single sample. Drop the final partial training
# batch only when it would be that singleton, so no other data is discarded.
train_loader = DataLoader(
    train,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=len(train) % BATCH_SIZE == 1,
)
# Evaluation order is kept fixed. NOTE(review): this assumes the training
# helper switches the model to eval() before testing; otherwise a 1-sample
# final test batch would hit the same BatchNorm error.
test_loader = DataLoader(test, batch_size=BATCH_SIZE, shuffle=False)

In [16]:
model = CustomAutoEncoder(input_dim=4, latent_dim=512)
model.requires_grad_(True)   # mark every parameter as trainable
model.train()                # BatchNorm layers use batch statistics in train mode
model.to(device)             # last expression: the notebook displays the module repr

CustomAutoEncoder(
  (encoder): Sequential(
    (0): Linear(in_features=12, out_features=1024, bias=True)
    (1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): GELU(approximate='none')
    (3): Linear(in_features=1024, out_features=512, bias=True)
    (4): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): GELU(approximate='none')
  )
  (decoder): Sequential(
    (0): Linear(in_features=512, out_features=1024, bias=True)
    (1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): GELU(approximate='none')
    (3): Linear(in_features=1024, out_features=12, bias=True)
    (4): BatchNorm1d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
  )
)

In [17]:
# L1 loss (mean absolute error with the default 'mean' reduction),
# presumably applied between the reconstruction and the input.
criterion = nn.L1Loss()

learning_rate = 1e-5
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Schedule knobs forwarded to run_train_test_1_input; test_step presumably
# controls how often (in epochs) the test split is evaluated -- confirm in
# utils.model_control.
num_epochs, test_step = 100, 10

logger = get_logger('autoencoder_portion')

In [None]:
# Checkpoint destination handed to the training helper (by its name, where
# the best model is stored -- confirm in utils.model_control).
checkpoint_path = '.checkpoints/best/ae_portion_best.pth'

# Nothing in this notebook creates the checkpoint directory; create it up
# front so a save does not fail with FileNotFoundError after training has
# already been running (harmless if the helper also creates it).
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

run_train_test_1_input(
    model=model,
    train_loader=train_loader,
    test_loader=test_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=num_epochs,
    test_step=test_step,
    logger=logger,
    device=device,
    path=checkpoint_path,
)