# Train Autoencoder for Embedding Extraction

In [None]:
%load_ext autoreload
%autoreload 2

## Imports

In [None]:
import tlc
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from timm import create_model
from tlc_tools.common import infer_torch_device

## Project Setup

In [None]:
# Notebook-wide configuration.
PROJECT_NAME = "3LC Tutorials"
BACKBONE = "resnet50"

# Scratch directory (relative to this notebook) and the checkpoint file inside it.
TRANSIENT_DATA_PATH = "../../transient_data"
CHECKPOINT_PATH = f"{TRANSIENT_DATA_PATH}/autoencoder_model.pth"

## Load Input Table

In [None]:
# Look up the "initial" table of the "CIFAR-10-val" dataset; the project name
# comes from PROJECT_NAME so it is defined in exactly one place.
table = tlc.Table.from_names("initial", "CIFAR-10-val", PROJECT_NAME)

In [None]:
# Prepare Data
# One-step preprocessing pipeline: convert each image to a tensor.
transform = transforms.Compose([transforms.ToTensor()])

def map_fn(sample):
    """Return the first element of ``sample`` after passing it through ``transform``.

    Any remaining elements of ``sample`` are discarded.
    """
    return transform(sample[0])

# Clear existing maps before registering map_fn
# (presumably so that re-running this cell does not stack maps - TODO confirm).
table.clear_maps()
table.map(map_fn)

# Last expression in the cell: shows the first mapped sample as a sanity check.
table[0]

In [None]:
class Autoencoder(nn.Module):
    """Autoencoder with a timm backbone encoder and a small MLP decoder.

    The backbone is created with ``num_classes=0`` (no classifier head), its
    output is projected down to ``embedding_dim``, and a two-layer MLP ending in
    a sigmoid maps the embedding back to a flattened image that is reshaped to
    ``image_shape``.

    Args:
        backbone_name: Name of the timm model used as the encoder.
        embedding_dim: Size of the embedding produced by the projection layer.
        image_shape: ``(channels, height, width)`` of the reconstructed image.
            Defaults to ``(3, 32, 32)``, the shape that was previously hard-coded.
    """

    def __init__(self, backbone_name='resnet50', embedding_dim=512, image_shape=(3, 32, 32)):
        super().__init__()

        channels, height, width = image_shape
        self.image_shape = (channels, height, width)

        # Load the pretrained backbone as an encoder (no classification head)
        self.encoder = create_model(backbone_name, pretrained=True, num_classes=0)
        # `num_features` is the width of the encoder's feature output; the last
        # `feature_info` entry only describes the final stage's channel count,
        # which is not the same value for every backbone.
        encoder_output_dim = self.encoder.num_features

        # Projection layer that reduces the encoder features to embedding_dim
        self.projector = nn.Linear(encoder_output_dim, embedding_dim)

        # Decoder: embedding -> hidden layer -> flattened image, squashed to (0, 1)
        self.decoder = nn.Sequential(
            nn.Linear(embedding_dim, encoder_output_dim),
            nn.ReLU(),
            nn.Linear(encoder_output_dim, channels * height * width),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Encode ``x`` and reconstruct it from the embedding.

        Args:
            x: Input batch; only its first dimension (batch size) is read here.

        Returns:
            A tuple ``(embeddings, reconstructions)`` where ``reconstructions``
            has shape ``(x.size(0), *image_shape)``.
        """
        # Encoder
        features = self.encoder(x)
        embeddings = self.projector(features)

        # Decoder
        reconstructions = self.decoder(embeddings)
        reconstructions = reconstructions.view(x.size(0), *self.image_shape)
        return embeddings, reconstructions


In [None]:
# Device the model will be placed on
device = infer_torch_device()

# Autoencoder built on the configured backbone
embedding_dim = 512  # Desired embedding dimension
model = Autoencoder(backbone_name=BACKBONE, embedding_dim=embedding_dim)

# Reconstruction loss and optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Shuffled mini-batches drawn from the mapped table
dataloader = DataLoader(table, batch_size=32, shuffle=True)

# Kept as the last expression so the notebook still displays the model
model.to(device)

In [None]:
# Training loop
num_epochs = 20
for epoch in range(num_epochs):
    model.train()
    batch_losses = []

    for batch in dataloader:
        inputs = batch.to(device)

        # Reset gradients left over from the previous step
        optimizer.zero_grad()

        # Forward pass; only the reconstructions are needed for the loss
        embeddings, reconstructions = model(inputs)
        loss = criterion(reconstructions, inputs)

        # Backward pass and parameter update
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())

    # Report the mean of the per-batch losses for this epoch
    epoch_loss = sum(batch_losses)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss / len(dataloader):.4f}")

In [None]:
unreduced_loss = nn.MSELoss(reduction="none")  # Element-wise reconstruction loss


def metrics_fn(batch, predictor_output):
    """Compute per-sample metrics from the autoencoder output.

    Args:
        batch: Batch of input images; compared element-wise against the
            reconstructions, so it must have the same shape as them.
        predictor_output: Object whose ``forward`` attribute unpacks into
            ``(embeddings, reconstructions)``.

    Returns:
        A dict with the keys ``"embeddings"`` (NumPy array),
        ``"reconstructions"`` (list of PIL images, one per sample) and
        ``"reconstruction_loss"`` (NumPy array, one mean squared error per sample).
    """
    embeddings, reconstructions = predictor_output.forward

    # Detach first so nothing below depends on (or keeps alive) an autograd graph.
    embeddings = embeddings.detach()
    reconstructions = reconstructions.detach()

    # Build the converter once and move the whole batch to the CPU in one transfer.
    to_pil = transforms.ToPILImage()
    reconstructed_images = [to_pil(image) for image in reconstructions.cpu()]

    # Compare on the device the reconstructions already live on, rather than
    # relying on a notebook-level `device` variable.
    targets = batch.to(reconstructions.device)
    reconstruction_loss = unreduced_loss(reconstructions, targets).mean(dim=(1, 2, 3))

    return {
        "embeddings": embeddings.cpu().numpy(),
        "reconstructions": reconstructed_images,
        "reconstruction_loss": reconstruction_loss.cpu().numpy(),
    }

In [None]:
# Start a run in the project named by PROJECT_NAME, so the name is defined in one place.
run = tlc.init(project_name=PROJECT_NAME)

# Run the model over the table and record the values returned by metrics_fn.
tlc.collect_metrics(
    table,
    metrics_fn,
    model,
    collect_aggregates=False,
    dataloader_args={"batch_size": 32},
)

run.set_status_completed()

In [None]:
# Reduce the "embeddings" column collected for this table, using the "pacmap" method.
run.reduce_embeddings_by_foreign_table_url(
    table.url,
    source_embedding_column="embeddings",
    method="pacmap",
)

In [None]:
# Save the model
from pathlib import Path

# torch.save does not create missing directories, so make sure the target exists.
Path(CHECKPOINT_PATH).parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), CHECKPOINT_PATH)
print(f"Model saved to {CHECKPOINT_PATH}")