In [None]:
import os

import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt

from autoencoder import AutoencoderDataset, ConvAutoencoder

In [None]:
# Fix PyTorch's global RNG so torch-driven randomness (e.g. random_split and
# DataLoader shuffling in the cells below) is repeatable between runs.
torch.manual_seed(0)

# Pick the fastest available backend: CUDA first, then Apple's MPS, else CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"Using {device} device")

In [None]:
# Load the dataset, report its size, and hold out 10% of it for evaluation.
ad = AutoencoderDataset(DATA_DIR='./data', transform=torch.tensor)
print(len(ad))

BATCH_SIZE = 5
train, test = random_split(ad, [0.9, 0.1])

# Only the training loader shuffles; the test order stays fixed so evaluation is stable.
train_dl = DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)
test_dl = DataLoader(test, batch_size=BATCH_SIZE)

In [None]:
# Set to the step of an existing checkpoint to resume from it; 0 trains from scratch.
global_step = 0
checkpoint_dir = './checkpoints/unet'
# torch.save does not create missing directories, so ensure the target exists up front.
os.makedirs(checkpoint_dir, exist_ok=True)

writer = SummaryWriter()

model = ConvAutoencoder()
if global_step != 0:
    # map_location lets a checkpoint saved on one device (e.g. CUDA) load on another;
    # weights_only avoids unpickling arbitrary objects from the checkpoint file.
    model.load_state_dict(
        torch.load(f'{checkpoint_dir}/{global_step}.pth', map_location=device, weights_only=True)
    )

# Cast parameters to float64 -- presumably to match the dataset tensors' dtype (TODO confirm).
model.double()
model.to(device)

# Define the loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train the autoencoder
num_epochs = 1500
log_every = 500  # steps between console logs / checkpoint saves
try:
    for epoch in range(num_epochs):
        model.train()
        for i, (x, _) in enumerate(train_dl):
            x = x.to(device)
            optimizer.zero_grad()
            output = model(x)
            train_loss = criterion(output, x)
            train_loss.backward()
            optimizer.step()

            global_step += 1

            if global_step % log_every == 0:
                print('Epoch [{}/{}], Batch: {}, Global Step: {}, Train Loss: {:.4f}'.format(epoch+1, num_epochs, i, global_step, train_loss.item()))
                torch.save(model.state_dict(), f'{checkpoint_dir}/{global_step}.pth')

            writer.add_scalar('Loss/train', train_loss.item(), global_step=global_step)

        # Evaluate without building autograd graphs, with layers in inference mode.
        model.eval()
        test_loss = 0.0
        with torch.no_grad():
            for x, _ in test_dl:
                x = x.to(device)
                output = model(x)
                test_loss += criterion(output, x).item()
        # Average over batches so Loss/test is on the same per-batch scale as Loss/train.
        test_loss /= max(len(test_dl), 1)

        writer.add_scalar('Loss/test', test_loss, global_step=global_step)
finally:
    # Flush pending TensorBoard events even if training is interrupted.
    writer.close()

In [None]:
# Visual sanity check: overlay one test sample with its reconstruction.
# TODO: choose this checkpoint from the Loss/test curve rather than hard-coding it.
best_checkpoint = './checkpoints/unet/14500.pth'

x, _ = next(iter(test_dl))
best_model = ConvAutoencoder()
# map_location='cpu' lets a checkpoint saved from a CUDA model load on a CPU-only machine.
best_model.load_state_dict(torch.load(best_checkpoint, map_location='cpu', weights_only=True))
best_model.to('cpu')
best_model.double()
best_model.eval()

# Inference only: no autograd graph needed, so no .detach() either.
with torch.no_grad():
    y = best_model(x).numpy()

fig, ax = plt.subplots()
ax.set_title('Reconstrução (DTG)')
# Plots index 1 along the second axis of the first sample -- presumably a channel (TODO confirm).
ax.plot(x[0][1], label='Original')
ax.plot(y[0][1], label='Autoencoder')
ax.legend()
plt.show()
plt.close(fig)

In [None]:
# Encode every sample in the dataset with a trained checkpoint and save the codes as text.
# NOTE(review): unlike the training cell, DATA_DIR is not passed here, so the class
# default is used -- confirm it points at the intended data.
ad = AutoencoderDataset(transform=torch.tensor)
dl = DataLoader(ad, batch_size=1)

model = ConvAutoencoder()
# map_location='cpu' lets a checkpoint saved from a CUDA model load on a CPU-only machine.
model.load_state_dict(torch.load('checkpoints/unet/14500.pth', map_location='cpu', weights_only=True))
model.to('cpu')
model.double()
model.eval()

# Collect per-batch codes and concatenate once at the end; concatenating inside the
# loop re-copies the growing array every iteration (quadratic in dataset size).
code_batches = []
with torch.no_grad():
    for x, _ in dl:
        code_batches.append(model.encode(x).numpy())

if not code_batches:
    raise ValueError('dataset is empty; nothing to encode')
codes = np.concatenate(code_batches, axis=0)

np.savetxt('encoding.txt', codes)