In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn, optim
from torchvision import datasets
from torch.utils.data import DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import save_image

In [None]:
# Device configuration: run on the GPU when CUDA is usable, otherwise on the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

In [None]:
# Hyperparameters shared by the data loaders and the optimizer
batch_size = 64
learning_rate = 0.001

# Preprocessing pipeline applied to every sample: a single tensor conversion step
transform = transforms.Compose(transforms=[transforms.ToTensor()])

In [None]:
# Download and extract the EuroSAT archive unless it is already on disk
import requests
import zipfile
import os

# Directory whose presence indicates the archive has already been extracted
directory_path = '../data/eurosat/2750/'

# URL of the file to be downloaded
url = 'https://madm.dfki.de/files/sentinel/EuroSAT.zip'
filename = url.split('/')[-1]  # Extract filename
zip_path = os.path.join('./', filename)  # Path to save the zip file

# Check if the specific directory exists
if not os.path.exists(directory_path):
    print(f"Directory {directory_path} not found. Proceeding with download...")

    try:
        # NOTE(security): verify=False disables TLS certificate validation, so the
        # downloaded archive is not authenticated. Replace it with the path to a
        # valid CA bundle (or drop the argument) once the server certificate
        # validates. Not recommended for production.
        #
        # stream=True avoids holding the whole archive in memory, and the
        # (connect, read) timeout keeps a stalled connection from hanging the kernel.
        with requests.get(url, verify=False, stream=True, timeout=(10, 60)) as response:
            # Fail loudly on 4xx/5xx instead of saving an error page as the "zip"
            response.raise_for_status()
            with open(zip_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        print("Download complete.")

        # Unzip the file
        print("Unzipping file...")
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall('../data/eurosat/')  # Extract to the specified directory
        print("Unzipping complete.")
    except requests.exceptions.SSLError as e:
        print(f"SSL Error occurred: {e}")
    finally:
        # Remove the archive whether or not the download/extraction succeeded.
        # The file may not exist if the request failed before anything was
        # written, so guard the removal to avoid masking the real error.
        if os.path.exists(zip_path):
            os.remove(zip_path)
else:
    print(f"Directory {directory_path} already exists. Aborting download.")

In [None]:
# Load the already-extracted EuroSAT images (download=False: nothing is fetched here)
dataset = datasets.EuroSAT(root='../data/', download=False, transform=transform)

# Split dataset into training (80%) and validation (remaining samples)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

# Use a dedicated, seeded generator so the train/validation membership is the
# same on every kernel restart; otherwise validation numbers are not comparable
# across runs and a reloaded model may be validated on images it was trained on.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size], generator=split_generator
)

# Only the training loader is shuffled; validation order is kept fixed
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)

In [None]:
import matplotlib.pyplot as plt
from torchvision.utils import make_grid

def show_images(images, nmax=2):
    """Display the first ``nmax`` images of a batch side by side in one row.

    Args:
        images: Batch tensor accepted by ``make_grid``; it is indexed along its
            first dimension and the resulting grid is permuted from
            channel-first to channel-last for ``imshow``.
        nmax: Maximum number of images to show; also sets the figure width.
    """
    _, ax = plt.subplots(figsize=(nmax * 5, 5))
    ax.set_xticks([]); ax.set_yticks([])
    # Move to host memory first: matplotlib cannot consume CUDA tensors, so
    # without .cpu() this fails for batches that already live on the GPU
    # (e.g. model outputs).
    grid = make_grid(images.detach().cpu()[:nmax], nrow=nmax)
    ax.imshow(grid.permute(1, 2, 0))

# Pull a single mini-batch from the training loader; the labels are not needed
batch_iterator = iter(train_loader)
images, _ = next(batch_iterator)

# Preview the first five images of that batch
show_images(images, 5)

In [None]:
def add_gaussian_noise(images, mean=0., std=0.2):
    """Return ``images`` corrupted with Gaussian noise, limited to [0, 1].

    Standard-normal samples with the same shape as ``images`` are scaled by
    ``std``, added to the input, shifted by ``mean``, and the result is
    clamped to the range [0, 1].
    """
    scaled_noise = torch.randn_like(images) * std
    noisy = images + scaled_noise + mean
    return torch.clamp(noisy, min=0, max=1)

In [None]:
# Define the autoencoder
class Autoencoder(nn.Module):
    """Convolutional autoencoder with a mirrored encoder/decoder.

    The encoder is five Conv2d stages and the decoder five ConvTranspose2d
    stages, all with kernel_size=4, stride=2, padding=1. Every stage except
    the last decoder stage is followed by BatchNorm2d, in-place ReLU and
    Dropout(0.25); the last decoder stage maps back to 3 channels and ends
    in a Sigmoid, so outputs lie in [0, 1].
    """

    # (in_channels, out_channels) of each encoder stage, in order
    _ENCODER_CHANNELS = ((3, 64), (64, 128), (128, 256), (256, 512), (512, 512))
    # (in_channels, out_channels) of each decoder stage that has norm/activation/dropout
    _DECODER_CHANNELS = ((512, 512), (512, 256), (256, 128), (128, 64))

    def __init__(self):
        super().__init__()
        conv_args = dict(kernel_size=4, stride=2, padding=1)

        # Encoder: layers are created strictly in order so the Sequential
        # indices (and therefore the state_dict keys) stay stable.
        encoder_layers = []
        for in_ch, out_ch in self._ENCODER_CHANNELS:
            encoder_layers.extend([
                nn.Conv2d(in_ch, out_ch, **conv_args),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(True),
                nn.Dropout(0.25),
            ])
        self.encoder = nn.Sequential(*encoder_layers)

        # Decoder: mirrored transposed convolutions, then the output stage
        decoder_layers = []
        for in_ch, out_ch in self._DECODER_CHANNELS:
            decoder_layers.extend([
                nn.ConvTranspose2d(in_ch, out_ch, **conv_args),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(True),
                nn.Dropout(0.25),
            ])
        decoder_layers.append(nn.ConvTranspose2d(64, 3, **conv_args))
        decoder_layers.append(nn.Sigmoid())
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` and return its reconstruction from the decoder."""
        latent = self.encoder(x)
        return self.decoder(latent)


In [None]:
# Build the network on the selected device, with a mean-squared-error
# reconstruction loss and an Adam optimizer over all model parameters
model = Autoencoder()
model = model.to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

In [None]:
# TensorBoard writer; event files for this run are written under the directory below
writer = SummaryWriter(log_dir='../data/nn/runs/autoencoder_experiment_1')

In [None]:
# Initialize early stopping parameters
best_val_loss = float('inf')  # Initialize with a very high value
patience = 20  # Number of epochs to wait for improvement before stopping the training
patience_counter = 0  # Counter to keep track of epochs without improvement
best_state = None  # Copy of the weights that produced best_val_loss


# Training
num_epochs = 300
# Training loop: each epoch trains on noisy inputs against the clean targets,
# then measures the same denoising loss on the validation set.
for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    for img, _ in train_loader:
        img = img.to(device)
        # randn_like creates the noise on img's device, so no extra transfer is needed
        noise_img = add_gaussian_noise(img)
        optimizer.zero_grad()
        output = model(noise_img)
        loss = criterion(output, img)
        loss.backward()
        optimizer.step()
        # Weight by batch size so the epoch figure is a per-sample mean even
        # when the last batch is smaller
        train_loss += loss.item() * img.size(0)
    train_loss = train_loss / len(train_loader.dataset)

    # Validation phase
    # NOTE(review): the validation noise is re-sampled every epoch, so val_loss
    # carries some sampling jitter; consider fixing the noise if early stopping
    # turns out to be too sensitive.
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for img, _ in val_loader:
            img = img.to(device)
            noise_img = add_gaussian_noise(img)
            output = model(noise_img)
            loss = criterion(output, img)
            val_loss += loss.item() * img.size(0)
    val_loss = val_loss / len(val_loader.dataset)

    # Log the losses before the early-stopping check so the final epoch is
    # recorded too (previously the break skipped its logging)
    writer.add_scalar('Loss/train', train_loss, epoch)
    writer.add_scalar('Loss/val', val_loss, epoch)
    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

    # Check for improvement
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0  # Reset counter if there's an improvement
        # state_dict() returns references to the live tensors, so clone them;
        # otherwise the snapshot would keep changing with further training
        best_state = {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}
    else:
        patience_counter += 1  # Increment counter if no improvement

    # Stop once exactly `patience` consecutive epochs brought no improvement
    # (the previous `>` comparison waited one epoch longer than the message states)
    if patience_counter >= patience:
        print(f'Early stopping triggered at epoch {epoch+1}. No improvement in validation loss for {patience} consecutive epochs.')
        break  # Break out of the loop to stop training

# Roll back to the best validation checkpoint so that whatever is saved or
# evaluated afterwards is the best model, not the one from the last stale epoch
if best_state is not None:
    model.load_state_dict(best_state)

# Make sure buffered scalars reach disk; the writer stays open for later cells
writer.flush()
print('Training completed.')

In [None]:
# Persist only the model's state_dict (not the module object itself)
torch.save(obj=model.state_dict(), f="../data/nn/ae_trained.pth")