----
----
# <b> DLMI Challenge </b>
# <b> Autoencoders training </b>
# <b> Matteo MARENGO | matteo.marengo@ens-paris-saclay.fr </b>
# <b> Manal MEFTAH | manal.meftah@ens-paris-saclay.fr </b>

----
----
# <b> Import libraries </b>

In [None]:
import os
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import matplotlib.pyplot as plt
from tqdm import tqdm

----
----
# <b> Define the dataset and the model </b>

In [None]:
# Hyperparameters
image_size = 224       # side length (in pixels) every image is resized to
batch_size = 32        # number of images per DataLoader batch
epochs = 50            # number of passes over the combined dataset
learning_rate = 1e-3   # learning rate handed to the Adam optimizer

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

class CombinedDataset(Dataset):
    """Image dataset pooling every picture found under one or more folders.

    Each folder is walked recursively and every file whose extension is
    ``.png``, ``.jpg`` or ``.jpeg`` (case-insensitive) is collected. Within
    a folder the paths are sorted, so the index -> file mapping is stable
    across runs and machines.

    Items are ``(image, image)`` pairs because an autoencoder uses its input
    as its own reconstruction target.

    Args:
        folders: A single directory path, or an iterable of directory paths.
        transform: Optional callable applied to the loaded RGB PIL image.
    """

    # Extensions (lower-case, with leading dot) treated as images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, folders, transform=None):
        # A lone path would otherwise be iterated character by character.
        if isinstance(folders, (str, os.PathLike)):
            folders = [folders]
        self.folders = folders
        self.transform = transform
        self.file_list = []
        for folder_path in folders:
            found = [
                os.path.join(dirpath, filename)
                for dirpath, _dirnames, filenames in os.walk(folder_path)
                for filename in filenames
                if os.path.splitext(filename)[1].lower() in self.IMAGE_EXTENSIONS
            ]
            # os.walk yields entries in filesystem order; sort for determinism
            # while keeping the folders themselves in the order given.
            found.sort()
            self.file_list.extend(found)

    def __len__(self):
        """Return the number of image files collected."""
        return len(self.file_list)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and return it as an ``(input, target)`` pair."""
        img_path = self.file_list[idx]
        # convert() loads the pixel data and returns a new image, so the
        # underlying file can be closed as soon as the conversion is done.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, image  # Autoencoder input and output are the same

# Preprocessing applied to every image: resize to a square of side
# `image_size`, convert to a tensor, then normalize each of the three
# channels with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

# The train and test folders are pooled into a single dataset.
dataset_folders = [
    '/kaggle/input/dlmi-mms-data/dlmi-lymphocytosis-classification/trainset',
    '/kaggle/input/dlmi-mms-data/dlmi-lymphocytosis-classification/testset',
]
combined_dataset = CombinedDataset(dataset_folders, transform=transform)
combined_loader = DataLoader(
    combined_dataset,
    batch_size=batch_size,
    shuffle=True,
)

class Autoencoder(nn.Module):
    """Convolutional autoencoder with a fully connected bottleneck.

    The layer arithmetic is fixed for 3-channel 224x224 inputs: the encoder
    reduces them to a (256, 7, 7) feature map, flattens it and projects it to
    a ``latent_dim`` vector; the decoder mirrors this back to (3, 224, 224)
    and ends with ``Tanh``.

    The order of the layers inside ``encoder`` and ``decoder`` is part of the
    saved ``state_dict`` keys and must not change.

    Args:
        latent_dim: Width of the bottleneck vector. Defaults to 4096.

    Raises:
        ValueError: If ``latent_dim`` is smaller than 1.
    """

    def __init__(self, latent_dim=4096):
        super().__init__()
        if latent_dim < 1:
            raise ValueError(f"latent_dim must be >= 1, got {latent_dim}")
        self.latent_dim = latent_dim

        # Shape of the last convolutional feature map and its flattened size.
        feature_shape = (256, 7, 7)
        flat_features = 256 * 7 * 7

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),  # Output: (64, 112, 112)
            nn.ReLU(True),
            nn.AvgPool2d(kernel_size=2, stride=2, padding=0),  # Output: (64, 56, 56)
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # Output: (128, 28, 28)
            nn.ReLU(True),
            nn.AvgPool2d(kernel_size=2, stride=2, padding=0),  # Output: (128, 14, 14)
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),  # Output: (256, 7, 7)
            nn.ReLU(True),
            nn.Flatten(),  # Flatten the output for the fully connected layer
            nn.Linear(flat_features, latent_dim),  # Compress to latent space
            nn.ReLU(True)
        )

        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, flat_features),  # Expand from latent space
            nn.ReLU(True),
            nn.Unflatten(1, feature_shape),  # Unflatten to get back to the convolutional tensor shape
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),  # Output: (128, 14, 14)
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # Output: (64, 28, 28)
            nn.ReLU(True),
            nn.Upsample(scale_factor=2, mode='nearest'),  # Output: (64, 56, 56)
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),  # Output: (3, 112, 112)
            nn.Upsample(scale_factor=2, mode='nearest'),  # Output: (3, 224, 224)
            nn.Tanh()
        )

    def forward(self, x):
        """Encode ``x`` to the latent vector, then decode it back to an image."""
        x = self.encoder(x)
        x = self.decoder(x)
        return x

----
----
# <b> Train the model </b>

In [None]:
# Fail early with a clear message: with zero batches the per-epoch average
# below would otherwise end in a ZeroDivisionError.
if len(combined_loader) == 0:
    raise RuntimeError("combined_loader yields no batches; check the dataset folders.")

# Model, loss function, and optimizer setup
model = Autoencoder().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training the model with tqdm progress bar
epoch_losses = []  # mean of the per-batch losses, one entry per epoch
model.train()
for epoch in range(epochs):
    epoch_loss = 0.0
    n_batches = 0
    with tqdm(total=len(combined_loader), desc=f'Epoch {epoch+1}/{epochs}', position=0, leave=True) as pbar:
        # The dataset yields (input, target) with target identical to input,
        # so only the first element is needed.
        for imgs, _ in combined_loader:
            imgs = imgs.to(device)  # Move inputs to the device
            optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, imgs)
            loss.backward()
            optimizer.step()
            # Read the scalar once and reuse it for both the running sum
            # and the progress bar.
            batch_loss = loss.item()
            epoch_loss += batch_loss
            n_batches += 1
            pbar.update(1)
            pbar.set_postfix({'Loss': batch_loss})
    average_epoch_loss = epoch_loss / n_batches
    epoch_losses.append(average_epoch_loss)
    print(f'Epoch {epoch+1}, Average Loss: {average_epoch_loss}')

# After training, plot the loss evolution
plt.figure(figsize=(10, 6))
plt.plot(epoch_losses, label='Training Loss')
plt.title('Loss Evolution')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
plt.savefig('/kaggle/working/loss_evolution.png')
plt.show()

# Save the encoder weights
torch.save(model.encoder.state_dict(), '/kaggle/working/encoder_weights_full_images.pth')

# Save the decoder weights
torch.save(model.decoder.state_dict(), '/kaggle/working/decoder_weights_full_images.pth')

print("Model trained and encoder weights saved!")
