In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from Dataloader import prepare_data_loader
from Models import ConvAutoencoder

In [2]:
# training_data_dir = './data/lunar/training/data/S12_GradeA/'
# training_labels_file = './data/lunar/training/catalogs/apollo12_catalog_GradeA_final.csv'

# train_loader = prepare_data_loader(overlap=0.25, window_length=1, decimation_factor=3, spect_nfft=128, spect_nperseg=128, batch_size=128, data_dir = training_data_dir, labels_file_path=training_labels_file)

In [None]:
# Build the DataLoader of spectrogram windows used to train the autoencoder.
# NOTE(review): the data directory is a dedicated autoencoder dataset while the labels file is the
# Apollo 12 Grade-A catalogue -- confirm this pairing is intended.
training_data_dir = './data/seismic_autoencoder_data/'
training_labels_file = './data/lunar/training/catalogs/apollo12_catalog_GradeA_final.csv'

# training_data_dir = './data/apollo/'
autoencoder_loader = prepare_data_loader(
    data_dir=training_data_dir,
    labels_file_path=training_labels_file,
    overlap=0.25,
    window_length=1,
    decimation_factor=3,
    spect_nfft=128,       # presumably yields 128 // 2 + 1 = 65 frequency rows -- confirm in Dataloader
    spect_nperseg=128,
    batch_size=128,
)

In [4]:
class Encoder(nn.Module):
    """Convolutional encoder: single-channel image -> ``hidden_size`` latent vector in (-1, 1).

    Two (conv -> ReLU -> 2x2 max-pool) stages shrink the input to a 2-channel feature map,
    which is flattened and projected by a linear layer followed by tanh.

    ``fc1`` expects exactly 480 flattened features (2 channels x 240 positions), i.e. inputs whose
    pooled map is 2 x 15 x 16 -- for example a (B, 1, 65, 70) batch, the same size the companion
    Decoder produces. Inputs of other sizes raise a shape error in ``fc1``.

    Args:
        hidden_size: length of the latent vector produced per sample.
    """

    def __init__(self, hidden_size):
        super().__init__()
        # Keep the creation order fixed: it determines both the state_dict keys and the order in
        # which random initial weights are drawn.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=4, stride=1, padding=1)
        self.max_pool_1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=3, out_channels=2, kernel_size=5, stride=1, padding=1)
        self.max_pool_2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(in_features=480, out_features=hidden_size)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(self, x):
        """Map a (B, 1, H, W) batch to a (B, hidden_size) tanh-bounded embedding."""
        features = self.max_pool_1(self.relu(self.conv1(x)))
        features = self.max_pool_2(self.relu(self.conv2(features)))
        flat = torch.flatten(features, start_dim=1)
        return self.tanh(self.fc1(flat))

class Decoder(nn.Module):
    """Transposed-convolution decoder: ``hidden_size`` latent vector -> (B, 1, 65, 70) image.

    A linear layer expands the code to 480 values, reshaped into a 2 x 15 x 16 feature map; two
    stride-2 transposed convolutions then upsample it (15x16 -> 32x34 -> 65x70).

    NOTE(review): every stage, including the last, ends in ReLU, so reconstructions are always
    >= 0. That only suits non-negative inputs -- confirm the spectrogram values are never negative.

    Args:
        hidden_size: length of the latent vector consumed per sample.
    """

    def __init__(self, hidden_size):
        super().__init__()
        # Construction order (conv1, conv2, then fc1) is preserved so parameter initialisation
        # and the state_dict layout stay the same.
        self.conv1 = nn.ConvTranspose2d(
            in_channels=2, out_channels=3, kernel_size=5,
            stride=2, padding=1, output_padding=1,
        )
        self.conv2 = nn.ConvTranspose2d(
            in_channels=3, out_channels=1, kernel_size=(4, 5),
            stride=2, padding=1, output_padding=1,
        )
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(in_features=hidden_size, out_features=480)

    def forward(self, x):
        """Decode a (B, hidden_size) batch into a (B, 1, 65, 70) non-negative image."""
        grid = self.relu(self.fc1(x)).reshape(-1, 2, 15, 16)
        grid = self.relu(self.conv1(grid))
        return self.relu(self.conv2(grid))


class ConvAutoencoder(nn.Module):
    """Encoder/Decoder pair trained to reproduce its own input.

    NOTE: this definition shadows the ``ConvAutoencoder`` imported from ``Models`` in the first
    cell; once this cell has run, the name refers to this notebook-local class.

    Args:
        hidden_size: latent dimensionality shared by the encoder and decoder.
    """

    def __init__(self, hidden_size):
        super().__init__()
        # The encoder is built before the decoder, which fixes the parameter initialisation order.
        self.encoder = Encoder(hidden_size)
        self.decoder = Decoder(hidden_size)

    def forward(self, x):
        """Encode ``x`` into the latent space and decode it back into an image."""
        return self.decoder(self.encoder(x))

    def save_encoder(self, path):
        """Write only the encoder's ``state_dict`` to ``path`` with ``torch.save``."""
        encoder_weights = self.encoder.state_dict()
        torch.save(encoder_weights, path)

In [5]:
# Seed PyTorch's global RNG so the weight initialisation below (and later torch RNG draws in this
# session) is reproducible across Restart & Run All.
torch.manual_seed(0)

HIDDEN_SIZE = 100  # length of the encoder's latent vector
model = ConvAutoencoder(hidden_size=HIDDEN_SIZE)

# To start from previously saved weights instead of training from scratch:
# model.load_state_dict(torch.load('./models/witek_autoencoder.pth'))

In [6]:
# Spatial size of one input window; later cells reshape every batch to (B, 1, height, width).
height, width = autoencoder_loader.dataset.tensors[0].shape[1:]

# The Decoder's output size is fixed by its hard-coded layer geometry, so check once, up front,
# that the model reproduces this input size. Without this, a mismatch only shows up later as a
# shape error (or a broadcasting warning) inside the reconstruction loss.
with torch.no_grad():
    _probe = torch.zeros(1, 1, height, width, device=next(model.parameters()).device)
    _recon_hw = tuple(model(_probe).shape[-2:])
assert _recon_hw == (height, width), (
    f'model reconstructs {_recon_hw} windows but the data is {(height, width)}; '
    'adjust the Encoder/Decoder layer sizes'
)

In [None]:
# Train on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

# nn.Module.to moves the parameters in place and returns the module, which the cell then displays.
model.to(device)

In [8]:
# # Define the loss function and optimizer
# criterion = nn.L1Loss()
# # criterion = nn.KLDivLoss(reduction='batchmean')
# # criterion = nn.KLDivLoss()
# optimizer = optim.Adam(model.parameters(), lr=0.00005)

# # Train the autoencoder
# num_epochs = 100
# for epoch in range(num_epochs):
# 	epoch_loss = 0
# 	epoch_recon_loss = 0
# 	epoch_stretch_loss = 0
# 	for data in train_loader:
# 		img, _ = data
# 		img = img.reshape(-1,1,height, width).to(device)
# 		optimizer.zero_grad()
# 		output = model(img)
# 		embd_space = model.encoder(img)
# 		recon_loss = criterion(output, img)
# 		# stretch_loss = torch.abs(torch.mean(1-(embd_space[:,0]*embd_space[:,0] + embd_space[:,1]*embd_space[:,1])))
# 		# loss = recon_loss+0.0*stretch_loss
# 		# loss.backward()
# 		recon_loss.backward()
# 		# epoch_loss += loss.item()
# 		epoch_recon_loss += recon_loss.item()
# 		# epoch_stretch_loss += stretch_loss.item()
# 		optimizer.step()
# 	if epoch % 5== 0:
# 		# print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_recon_loss:.4f}')
# 		print(f'Epoch [{epoch+1}/{num_epochs}], Recon Loss: {epoch_recon_loss:.4f}, Stretch Loss: {epoch_stretch_loss:.4f}')


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# Reconstruction objective: mean absolute error between each input window and its reconstruction.
criterion_recon = nn.L1Loss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

num_epochs = 200
log_every = 5  # print progress every `log_every` epochs, and always after the final epoch

# A later cell switches the model to eval mode; make sure re-running this cell trains in train mode.
model.train()
for epoch in range(num_epochs):
    epoch_recon_loss = 0.0
    num_batches = 0

    for batch in autoencoder_loader:
        # Only batch[0] (the input windows) is used for training.
        x_data = batch[0].reshape(-1, 1, height, width).to(device)

        optimizer.zero_grad()
        recon = model(x_data)
        recon_loss = criterion_recon(recon, x_data)
        recon_loss.backward()
        optimizer.step()

        epoch_recon_loss += recon_loss.item()
        num_batches += 1

    if epoch % log_every == 0 or epoch == num_epochs - 1:
        # The summed loss scales with the number of batches; the mean is comparable across datasets.
        mean_batch_loss = epoch_recon_loss / max(num_batches, 1)
        print(f'Epoch [{epoch+1}/{num_epochs}], Recon Loss: {epoch_recon_loss:.4f} '
              f'(mean per batch: {mean_batch_loss:.4f})')

In [None]:
import matplotlib.pyplot as plt

# Show a few input windows (top row) above their reconstructions (bottom row).
n_show = 7
model.eval()
with torch.no_grad():
    batch = next(iter(autoencoder_loader))
    data = batch[0].reshape(-1, 1, height, width).to(device)
    recon = model(data)

n_show = min(n_show, data.shape[0])  # a batch may hold fewer samples than requested
# squeeze=False keeps `ax` two-dimensional even when only one column is shown.
fig, ax = plt.subplots(2, n_show, figsize=(15, 4), squeeze=False)
for i in range(n_show):
    original = data[i].cpu().numpy().squeeze(0)
    reconstruction = recon[i].cpu().numpy().squeeze(0)
    # Share the colour range within each column so input and reconstruction are directly comparable.
    vmin, vmax = original.min(), original.max()
    ax[0, i].imshow(original, vmin=vmin, vmax=vmax)
    ax[1, i].imshow(reconstruction, vmin=vmin, vmax=vmax)
    ax[0, i].axis('off')
    ax[1, i].axis('off')
ax[0, 0].set_title('input', loc='left', fontsize=8)
ax[1, 0].set_title('reconstruction', loc='left', fontsize=8)
plt.show()

In [None]:
import matplotlib.pyplot as plt

# Inspect one reconstruction at full size.
model.eval()
with torch.no_grad():
    batch = next(iter(autoencoder_loader))
    # Second tensor of the batch for the first sample (presumably its label -- confirm in Dataloader).
    print(batch[1][0])
    img = batch[0].reshape(-1, 1, height, width).to(device)
    output = model(img)

fig, ax = plt.subplots(1, 2, figsize=(20, 10))
ax[0].imshow(img[0].cpu().numpy().squeeze())
ax[0].set_title('input')
ax[1].imshow(output[0].cpu().numpy().squeeze())
ax[1].set_title('reconstruction')
plt.show()

In [29]:
# model.save_encoder('./models/encoder.pth')
# torch.save(model.state_dict(), "./models/witek_autoencoder_2.pth")