In [1]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from Dataloader import prepare_data_loader

In [None]:
# Paths (relative to the notebook) to the S12 Grade-A training data and its Apollo 12 catalog.
training_data_dir = './data/lunar/training/data/S12_GradeA/'
training_labels_file = './data/lunar/training/catalogs/apollo12_catalog_GradeA_final.csv'

# Windowing / spectrogram / batching settings forwarded to the project loader.
# NOTE(review): exact semantics of each knob live in Dataloader.prepare_data_loader — confirm there.
train_loader = prepare_data_loader(
    data_dir=training_data_dir,
    labels_file_path=training_labels_file,
    overlap=0.25,
    window_length=1,
    decimation_factor=3,
    spect_nfft=128,
    spect_nperseg=128,
    batch_size=32,
)

In [3]:
class Encoder(nn.Module):
    """Convolutional encoder mapping a 1-channel image to a flat latent vector.

    Two ``conv -> ReLU -> 2x2 max-pool`` stages shrink each spatial dimension
    by 4x, then a linear layer projects the flattened feature map to
    ``hidden_size`` features.

    Args:
        hidden_size: Width of the latent vector returned by ``forward``.

    NOTE(review): the linear layer expects exactly 2176 = 8 * 16 * 17 flattened
    features, i.e. pooled maps of 16 x 17 (inputs of 64-67 rows by 68-71
    columns). Any other input size fails inside ``self.fc``.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.max_pool_1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(16, 8, kernel_size=3, stride=1, padding=1)
        self.max_pool_2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc = nn.Linear(2176, hidden_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Encode ``x`` (assumed ``(batch, 1, H, W)``) into ``(batch, hidden_size)``."""
        stage_one = self.max_pool_1(self.relu(self.conv1(x)))
        stage_two = self.max_pool_2(self.relu(self.conv2(stage_one)))
        # Keep the batch axis, collapse channels x rows x cols into one feature axis.
        flat_features = torch.flatten(stage_two, start_dim=1)
        return self.fc(flat_features)


class Decoder(nn.Module):
    """Convolutional decoder mirroring ``Encoder``: latent vector -> 1-channel map.

    A linear layer expands the latent vector to 2176 = 8 * 16 * 17 values.
    These are reshaped to an ``(8, 16, 17)`` feature map, which two stride-2
    transposed convolutions upsample to a ``(1, 65, 70)`` output.

    Args:
        hidden_size: Width of the latent vector accepted by ``forward``.
    """

    def __init__(self, hidden_size):
        super().__init__()
        # (8, 16, 17) -> (16, 32, 34)
        self.conv1 = nn.ConvTranspose2d(8, 16, kernel_size=3, stride=2, padding=1, output_padding=1)
        # (16, 32, 34) -> (1, 65, 70); the [4, 5] kernel gives 65 rows but 70 columns.
        self.conv2 = nn.ConvTranspose2d(16, 1, kernel_size=[4, 5], stride=2, padding=1, output_padding=1)
        self.relu = nn.ReLU()
        # NOTE(review): never used in forward(); the output is passed through ReLU instead.
        self.sigmoid = nn.Sigmoid()
        self.fc1 = nn.Linear(hidden_size, 2176)

    def forward(self, x):
        """Decode a latent batch (assumed ``(batch, hidden_size)``) into ``(batch, 1, 65, 70)``."""
        feature_map = self.fc1(x).reshape(-1, 8, 16, 17)
        upsampled = self.relu(self.conv1(feature_map))
        # Final ReLU keeps the reconstruction non-negative.
        return self.relu(self.conv2(upsampled))


class ConvAutoencoder(nn.Module):
    """Autoencoder composed of an ``Encoder`` and a matching ``Decoder``.

    Args:
        hidden_size: Latent-vector width shared by the encoder and the decoder.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.encoder = Encoder(hidden_size)
        self.decoder = Decoder(hidden_size)

    def forward(self, x):
        """Return the reconstruction of ``x``: encode to a latent vector, then decode."""
        latent = self.encoder(x)
        return self.decoder(latent)

    def save_encoder(self, path):
        """Write only the encoder's ``state_dict`` to ``path`` with ``torch.save``."""
        encoder_weights = self.encoder.state_dict()
        torch.save(encoder_weights, path)

In [4]:
# Define the autoencoder architecture
class Autoencoder(nn.Module):
    """Plain ``nn.Sequential`` convolutional autoencoder (no linear bottleneck).

    The encoder halves the spatial size twice (``conv -> ReLU -> max-pool``,
    twice) down to an 8-channel map. The decoder upsamples twice with
    transposed convolutions and squashes the output into [0, 1] with a sigmoid.

    NOTE(review): the visible notebook cells instantiate ``ConvAutoencoder``,
    not this class.
    """

    def __init__(self):
        super().__init__()
        encoder_layers = [
            nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 8, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        decoder_layers = [
            nn.ConvTranspose2d(8, 16, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            # The [4, 5] kernel sizes rows and columns differently on the way back up.
            nn.ConvTranspose2d(16, 1, kernel_size=[4, 5], stride=2, padding=1, output_padding=1),
            nn.Sigmoid(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode then decode ``x``; the single-channel output lies in [0, 1]."""
        return self.decoder(self.encoder(x))


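# Legacy Flowers102 image pipeline, kept commented out for reference only; the
# `prepare_data_loader` lunar data loader from an earlier cell is used instead.
# The autoencoder itself is instantiated in the next cell.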

# Define transform
# transform = transforms.Compose([
# 	transforms.Resize((64, 64)),
# 	transforms.ToTensor(),
# ])

# # Load dataset
# train_dataset = datasets.Flowers102(root='flowers', 
# 									split='train', 
# 									transform=transform, 
# 									download=True)
# test_dataset = datasets.Flowers102(root='flowers', 
# 								split='test', 
# 								transform=transform)
# # Define the dataloader
# train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 
# 										batch_size=128, 
# 										shuffle=True)
# test_loader = torch.utils.data.DataLoader(dataset=test_dataset, 
# 		
# 								batch_size=128)

In [5]:
# Seed torch's global RNG so weight initialisation is reproducible under
# Restart & Run All. This presumably also fixes the DataLoader shuffle order,
# if the project loader shuffles without its own generator — confirm in Dataloader.
# NOTE: GPU kernels may still introduce nondeterminism.
SEED = 0
torch.manual_seed(SEED)

# model = Autoencoder()  # alternative: Sequential autoencoder with a sigmoid output
model = ConvAutoencoder(hidden_size=100)

In [6]:
# The loader's first dataset tensor is assumed to be (N, H, W) — TODO confirm in Dataloader.
# Keep its trailing two dims so batches can be reshaped to (B, 1, H, W) for Conv2d.
height, width = train_loader.dataset.tensors[0].size()[1:]

In [None]:
# Prefer the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)
# Module.to moves parameters in place and returns the module, which is echoed as the cell output.
model.to(device)

In [None]:
# Define the loss function and optimizer.
# L1 (mean absolute error) between each input and its reconstruction.
criterion = nn.L1Loss()
optimizer = optim.Adam(model.parameters(), lr=0.0005)

# Train the autoencoder.
num_epochs = 100
log_every = 5  # print every `log_every` epochs, and always on the final epoch

# A later cell switches the model to eval mode; ensure re-running this cell trains in train mode.
model.train()
for epoch in range(num_epochs):
    epoch_loss = 0.0
    n_batches = 0
    # Labels are ignored: the autoencoder is trained to reconstruct its own input.
    for img, _labels in train_loader:
        img = img.reshape(-1, 1, height, width).to(device)
        optimizer.zero_grad()
        output = model(img)
        loss = criterion(output, img)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        n_batches += 1
    # Report the mean per-batch loss so the number does not scale with the number of batches.
    mean_loss = epoch_loss / max(n_batches, 1)
    if epoch % log_every == 0 or epoch == num_epochs - 1:
        print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, mean_loss))

# Save the trained weights.
torch.save(model.state_dict(), 'conv_autoencoder.pth')


In [None]:
# Reconstruct the first training batch in eval mode with gradients disabled.
model.eval()
with torch.no_grad():
    data, _labels = next(iter(train_loader))
    data = data.reshape(-1, 1, height, width).to(device)
    recon = model(data)

# Top row: inputs; bottom row: reconstructions. The singleton channel is squeezed away for imshow.
# (Pass dpi=... to plt.subplots for a higher-resolution figure.)
n_show = min(7, data.shape[0])  # a batch may hold fewer than 7 samples
fig, ax = plt.subplots(2, n_show, figsize=(15, 4), squeeze=False)
for i in range(n_show):
    ax[0, i].imshow(data[i].cpu().numpy().squeeze(0))
    ax[1, i].imshow(recon[i].cpu().numpy().squeeze(0))
    ax[0, i].axis('off')
    ax[1, i].axis('off')
ax[0, 0].set_title('input', loc='left')
ax[1, 0].set_title('reconstruction', loc='left')
plt.show()


In [None]:
# Compare one input with its reconstruction, side by side.
model.eval()
with torch.no_grad():
    img, label = next(iter(train_loader))
    img = img.reshape(-1, 1, height, width).to(device)
    output = model(img)

print(label[0])
fig, ax = plt.subplots(1, 2, figsize=(20, 10))
ax[0].imshow(img[0].cpu().numpy().squeeze())
ax[0].set_title('input')
ax[1].imshow(output[0].cpu().numpy().squeeze())
ax[1].set_title('reconstruction')
plt.show()