## Imports and config

In [1]:
import os
import torch
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split
import torch.nn as nn
from tqdm import tqdm
import config

In [2]:
if config.USE_UNET: import unet as autoencoder
else: import convolutional_autoencoder as autoencoder

In [3]:
# LABELED_TRAIN_SET_ABSOLUTE_SIZE is presumably a fraction of the whole dataset; dividing by the
# labeled-pool share (1 - UNLABELED_SET_SIZE) re-expresses it relative to that pool.
labeled_train_relative_set_size = round(config.LABELED_TRAIN_SET_ABSOLUTE_SIZE / (1 - config.UNLABELED_SET_SIZE), 2)
# Round the complement too: e.g. 1 - 0.7 == 0.30000000000000004, and train_test_split sizes a float
# test split as ceil(test_size * n_samples), so that float noise can shift the split by one sample.
labeled_test_relative_set_size = round(1 - labeled_train_relative_set_size, 2)

# train_test_split needs a float test_size strictly inside (0, 1); fail here with a clear message.
if not 0 < labeled_test_relative_set_size < 1:
    raise ValueError(
        "LABELED_TRAIN_SET_ABSOLUTE_SIZE / (1 - UNLABELED_SET_SIZE) must lie strictly between 0 and 1, "
        f"got {labeled_train_relative_set_size}"
    )

In [4]:
# Pick the compute device: Apple's MPS backend first, then CUDA, otherwise the CPU.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

print(f"Using device: {device}")

Using device: mps


## Datasets and Dataloaders

In [5]:
def _resized_tensor_pipeline(*extra_steps):
    """Resize to 224x224, apply ToTensor (8-bit PIL images become floats in [0, 1]), then any extra steps."""
    return transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor(), *extra_steps])


# Image pipelines keyed by name; config.AE_TRANSFORMS selects the one used for the autoencoder.
# 'mse' additionally standardizes with fixed statistics (presumably estimated by the disabled cell 10).
data_transforms = {
    'bce': _resized_tensor_pipeline(),
    'mse': _resized_tensor_pipeline(transforms.Normalize(mean=0.4543, std=0.1757)),
    'dispersion_calc': _resized_tensor_pipeline(),
}

In [6]:
# The clean and the noisy image trees are read with the same transform, chosen by config.AE_TRANSFORMS.
ae_transform = data_transforms[config.AE_TRANSFORMS]
full_dataset = datasets.ImageFolder(config.BASE_DIR_RAW, transform=ae_transform)
full_noisy_dataset = datasets.ImageFolder(config.BASE_DIR_NOISY, transform=ae_transform)

In [7]:
# Positional indices into each ImageFolder; the split cell below partitions these.
indices = [*range(len(full_dataset))]
noisy_indices = [*range(len(full_noisy_dataset))]

# File path of every image, in ImageFolder's sample order (samples are (path, class_index) pairs).
image_paths = [path for path, _ in full_dataset.samples]
noisy_image_paths = [path for path, _ in full_noisy_dataset.samples]

# An image's label is the name of the directory that directly contains it.
labels = [os.path.basename(os.path.dirname(path)) for path in image_paths]
noisy_labels = [os.path.basename(os.path.dirname(path)) for path in noisy_image_paths]

In [8]:
# First split: `train_indices` receives a config.UNLABELED_SET_SIZE share of the images (the set the
# autoencoder is trained on); `val_indices` receives the remaining share, i.e. the labeled pool.
val_indices, train_indices = train_test_split(indices, test_size=config.UNLABELED_SET_SIZE, stratify=labels, random_state=42)
noisy_val_indices, noisy_train_indices = train_test_split(noisy_indices, test_size=config.UNLABELED_SET_SIZE, stratify=noisy_labels, random_state=42)
# NOTE(review): the clean and noisy splits pick the same positions only when `labels` and `noisy_labels`
# are identical lists, and those positions are the same picture only if both ImageFolders list matching
# files in the same order. The denoising setup relies on this alignment -- confirm it.

# Labels of the labeled pool, used to stratify the second split.
val_labels = [labels[i] for i in val_indices]
noisy_val_labels = [noisy_labels[i] for i in noisy_val_indices]

# Second split: keep only the labeled_test_relative_set_size share of the labeled pool as the
# validation set; the discarded share is presumably the labeled training set used elsewhere.
_, val_indices = train_test_split(val_indices, test_size=labeled_test_relative_set_size, stratify=val_labels, random_state=42)
_, noisy_val_indices = train_test_split(noisy_val_indices, test_size=labeled_test_relative_set_size, stratify=noisy_val_labels, random_state=42)

train_dataset = Subset(full_dataset, train_indices)
noisy_train_dataset = Subset(full_noisy_dataset, noisy_train_indices)
val_dataset = Subset(full_dataset, val_indices)
# The noisy validation subset must index the noisy dataset; indexing the clean one would make
# denoising validation score clean -> clean reconstructions.
noisy_val_dataset = Subset(full_noisy_dataset, noisy_val_indices)

In [9]:
# Loader settings shared by every split.
BATCH_SIZE = 4
NUM_WORKERS = 4

# Training loaders reshuffle every epoch (each with its own sampler, so clean and noisy batches
# are not paired by position); validation loaders keep dataset order.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
noisy_loader = DataLoader(noisy_train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
noisy_val_loader = DataLoader(noisy_val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

# Label the noisy counts explicitly so the four lines can be told apart.
print(f"Número de imágenes en el conjunto de entrenamiento: {len(train_loader.dataset)}")
print(f"Número de imágenes en el conjunto de entrenamiento (con ruido): {len(noisy_loader.dataset)}")
print(f"Número de imágenes en el conjunto de validación: {len(val_loader.dataset)}")
print(f"Número de imágenes en el conjunto de validación (con ruido): {len(noisy_val_loader.dataset)}")

Número de imágenes en el conjunto de entrenamiento: 49189
Número de imágenes en el conjunto de entrenamiento: 49189
Número de imágenes en el conjunto de validación: 6149
Número de imágenes en el conjunto de validación: 6149


In [10]:
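# # Disabled cell: presumably how the Normalize(mean=0.4543, std=0.1757) constants in cell 5 were estimated.
# # NOTE(review): averaging per-image std values does not give the pooled pixel std of the dataset; recompute before relying on it.
# # Initialize variables for mean and std calculation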
# mean = 0.0
# std = 0.0
# nb_samples = 0

# for data, _ in train_loader:
#     data = data.to(device)
#     batch_samples = data.size(0)  # number of images in the batch
#     data = data.view(batch_samples, -1)  # flatten the channel and spatial dimensions
#     mean += data.mean(1).sum(0)
#     std += data.std(1).sum(0)
#     nb_samples += batch_samples

# mean /= nb_samples
# std /= nb_samples

# print("Mean:", mean)
# print("Std:", std)

## Function definitions

In [11]:
def train_model(model, criterion, optimizer, train_loader, val_loader, device, num_epochs=10, patience=3):
    """Train a reconstruction autoencoder with early stopping.

    Every batch is passed through ``model``, and ``criterion(outputs, inputs)`` scores the output
    against the batch itself; dataset labels are ignored. After each epoch the mean validation loss
    is computed, and whenever it beats the best so far the weights are saved to
    ``config.AUTOENCODER_SAVE_PATH``. Training stops early after ``patience`` consecutive epochs
    without improvement.

    Returns:
        ``model`` in its state after the last epoch that ran (the best weights are the ones on disk).
    """
    best_val_loss = float('inf')
    stale_epochs = 0

    for epoch_idx in range(num_epochs):
        # Training pass.
        model.train()
        train_total = 0.0
        for batch, _ in tqdm(train_loader):
            batch = batch.to(device)
            optimizer.zero_grad()
            reconstruction = model(batch)
            batch_loss = criterion(reconstruction, batch)
            batch_loss.backward()
            optimizer.step()
            train_total += batch_loss.item()
        train_loss = train_total / len(train_loader)

        # Validation pass, without gradient tracking.
        model.eval()
        val_total = 0.0
        with torch.no_grad():
            for batch, _ in val_loader:
                batch = batch.to(device)
                val_total += criterion(model(batch), batch).item()
        val_loss = val_total / len(val_loader)

        print(f'Epoch {epoch_idx+1}/{num_epochs}')
        print(f'Train Loss: {train_loss:.4f}')
        print(f'Val Loss: {val_loss:.4f}')

        # Checkpoint on improvement; otherwise count towards early stopping.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            stale_epochs = 0
            torch.save(model.state_dict(), config.AUTOENCODER_SAVE_PATH)
        else:
            stale_epochs += 1
        if stale_epochs >= patience:
            print("Early stopping triggered!")
            break

    return model

In [12]:
def _paired_batches(noisy_loader, clean_loader):
    """Yield ``(noisy, clean)`` image batches whose row ``i`` comes from the same dataset index.

    Zipping two loaders that each shuffle on their own pairs unrelated images. Instead, both datasets
    are re-read through one shared index order. That order is a fresh random permutation when
    ``noisy_loader`` uses a ``RandomSampler`` (it was built with ``shuffle=True``); otherwise it is
    dataset order. The temporary loaders reuse ``noisy_loader``'s batch size and worker count.
    Labels are dropped.

    NOTE(review): assumes index ``i`` of both datasets is the same picture (noisy vs. clean) -- confirm
    that the clean and noisy splits are aligned.

    Raises:
        ValueError: if the two datasets differ in length.
    """
    noisy_ds, clean_ds = noisy_loader.dataset, clean_loader.dataset
    if len(noisy_ds) != len(clean_ds):
        raise ValueError(
            f"Cannot pair noisy and clean datasets of different lengths ({len(noisy_ds)} vs {len(clean_ds)})"
        )
    n_samples = len(noisy_ds)
    if isinstance(noisy_loader.sampler, torch.utils.data.RandomSampler):
        order = torch.randperm(n_samples).tolist()
    else:
        order = list(range(n_samples))
    # A plain list of indices is a valid DataLoader sampler; sharing it keeps both loaders in lockstep.
    shared = dict(batch_size=noisy_loader.batch_size, sampler=order, num_workers=noisy_loader.num_workers)
    for (noisy, _), (clean, _) in zip(DataLoader(noisy_ds, **shared), DataLoader(clean_ds, **shared)):
        yield noisy, clean


def train_denoising_model(model, criterion, optimizer, noisy_loader, original_loader, val_loader, noisy_val_loader, device, num_epochs=10, patience=3):
    """Train a denoising autoencoder with early stopping.

    The model receives noisy images, and ``criterion(outputs, clean)`` compares its output with the
    clean image at the same dataset index. After each epoch the mean validation loss is computed, and
    whenever it beats the best so far the weights are saved to ``config.DENOISING_AUTOENCODER_SAVE_PATH``.
    Training stops early after ``patience`` consecutive epochs without improvement.

    Args:
        noisy_loader / original_loader: training loaders over the noisy / clean datasets. Only their
            ``dataset`` (and ``noisy_loader``'s sampler type, batch size and worker count) are used, so
            that every noisy input is paired with its own clean target.
        val_loader / noisy_val_loader: clean / noisy validation loaders, paired the same way.
        num_epochs: maximum number of epochs.
        patience: epochs without validation improvement before stopping.

    Returns:
        ``model`` in its state after the last epoch that ran (the best weights are the ones on disk).

    Raises:
        ValueError: if a noisy/clean dataset pair differs in length.
    """
    best_val_loss = float('inf')
    epochs_no_improve = 0

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        n_train_batches = 0

        for inputs, targets in tqdm(_paired_batches(noisy_loader, original_loader), total=len(noisy_loader)):
            inputs = inputs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            n_train_batches += 1

        # Average over the batches actually processed (not over a global loader);
        # max(..., 1) avoids ZeroDivisionError on an empty dataset.
        train_loss = running_loss / max(n_train_batches, 1)

        model.eval()
        val_loss = 0.0
        n_val_batches = 0

        with torch.no_grad():
            for inputs, targets in _paired_batches(noisy_val_loader, val_loader):
                inputs = inputs.to(device)
                targets = targets.to(device)
                outputs = model(inputs)
                val_loss += criterion(outputs, targets).item()
                n_val_batches += 1

        val_loss = val_loss / max(n_val_batches, 1)

        print(f'Epoch {epoch+1}/{num_epochs}')
        print(f'Train Loss: {train_loss:.4f}')
        print(f'Val Loss: {val_loss:.4f}')

        # Checkpoint on improvement; otherwise count towards early stopping.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
            torch.save(model.state_dict(), config.DENOISING_AUTOENCODER_SAVE_PATH)
        else:
            epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print("Early stopping triggered!")
            break

    return model

In [13]:
def load_model(model_path, device):
    """Rebuild an ``autoencoder.AutoEncoder`` from the state dict saved at ``model_path``.

    ``map_location=device`` remaps the stored tensors onto ``device`` while loading, and the
    returned model is moved to ``device`` as well.
    """
    net = autoencoder.AutoEncoder()
    state_dict = torch.load(model_path, map_location=device)
    net.load_state_dict(state_dict)
    return net.to(device)

In [14]:
def evaluate_model(model, dataloader, device, mean=0.4543, std=0.1757):
    """Save the first image of ``dataloader`` and its reconstruction as JPEGs for visual inspection.

    Writes ``input.jpg`` and ``output.jpg`` to the working directory from the first batch, then stops.
    Nothing is returned, and nothing is written if ``dataloader`` is empty.

    Args:
        mean, std: both images are un-normalized as ``x * std + mean`` before saving. The defaults invert
            the ``Normalize(mean=0.4543, std=0.1757)`` step of the 'mse' transform. For a transform without
            normalization (e.g. 'bce'), pass ``mean=0.0, std=1.0``.
    """
    model.eval()
    with torch.no_grad():
        for inputs, _ in dataloader:
            inputs = inputs.to(device)
            outputs = model(inputs)
            # Un-normalize copies instead of writing back into the batch tensors.
            save_image(inputs[0] * std + mean, 'input.jpg')
            save_image(outputs[0] * std + mean, 'output.jpg')
            break

In [15]:
def evaluate_denoising_model(model, noisy_loader, original_loader, device, mean=0.4543, std=0.1757):
    """Save the first noisy image of ``noisy_loader`` and the model's denoised output as JPEGs.

    Writes ``noisy_input.jpg`` and ``denoised_output.jpg`` to the working directory from the first
    batch, then stops. Nothing is returned, and nothing is written if ``noisy_loader`` is empty.

    Args:
        original_loader: kept for call compatibility but not read; the clean batch was never used.
        mean, std: both images are un-normalized as ``x * std + mean`` before saving. The defaults invert
            the ``Normalize(mean=0.4543, std=0.1757)`` step of the 'mse' transform, which the clean and
            noisy ImageFolders both use. For a transform without normalization, pass ``mean=0.0, std=1.0``.
    """
    model.eval()
    with torch.no_grad():
        for inputs, _ in noisy_loader:
            inputs = inputs.to(device)
            outputs = model(inputs)
            # Un-normalize copies instead of writing back into the batch tensors.
            save_image(inputs[0] * std + mean, 'noisy_input.jpg')
            save_image(outputs[0] * std + mean, 'denoised_output.jpg')
            break

## Autoencoder training (U-Net or convolutional, selected by `config.USE_UNET`)

In [16]:
# Build the autoencoder chosen in cell 2 and move it to the selected device (Module.to returns the module).
model = autoencoder.AutoEncoder().to(device)

# Loss function and learning rate both come from the project config.
criterion = config.LOSS_FUNCTION
optimizer = torch.optim.Adam(model.parameters(), lr=config.LEARNING_RATE)

In [17]:
# Train either the denoising autoencoder (noisy -> clean) or the plain reconstruction one, per config.
if config.USE_DENOISING_AUTOENCODER:
    train_denoising_model(model, criterion, optimizer, noisy_loader, train_loader, val_loader, noisy_val_loader, device, num_epochs=config.EPOCHS, patience=config.PATIENCE)
else:
    train_model(model, criterion, optimizer, train_loader, val_loader, device, num_epochs=config.EPOCHS, patience=config.PATIENCE)

  1%|          | 146/12298 [00:20<04:21, 46.51it/s] 

In [None]:
# Reload the checkpoint saved on the best validation epoch; the training functions return last-epoch weights.
checkpoint_path = config.DENOISING_AUTOENCODER_SAVE_PATH if config.USE_DENOISING_AUTOENCODER else config.AUTOENCODER_SAVE_PATH
model = load_model(checkpoint_path, device)

In [None]:
# Dump one input/output image pair to disk for visual inspection.
# NOTE(review): the denoising path samples from the training loaders while the plain path uses
# val_loader; consider evaluating the denoiser on noisy_val_loader / val_loader instead.
if config.USE_DENOISING_AUTOENCODER:
    evaluate_denoising_model(model, noisy_loader, train_loader, device)
else:
    evaluate_model(model, val_loader, device)