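# Experiment 2: Latent-Space Visualization

Encode a stratified 10% sample of the Plant Village dataset with two saved encoders (an autoencoder and an encoder + MLP model) and log the resulting embeddings to TensorBoard so the two latent spaces can be compared.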

In [56]:
# Automatically re-import edited local modules (e.g. latent_vector,
# plant_village_dataset) before each cell executes.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Imports

In [57]:
import torch
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader, SubsetRandomSampler
from sklearn.model_selection import train_test_split
from latent_vector import visualize_latent_space, create_metadata_list, log_embeddings

from plant_village_dataset import PlantVillageDataset

## Prepare Data

In [58]:
# Prefer the Apple-silicon GPU (MPS) when it is available; otherwise fall back
# to CPU so the notebook can still run (more slowly) on other machines.
DEVICE = 'mps' if torch.backends.mps.is_available() else 'cpu'

In [59]:
# Root folder of the Plant Village images. Building the dataset also runs a
# normalization pass (per-channel mean / std), which took ~1m40s in the
# recorded run.
IMAGES_DIR = 'images'
dataset = PlantVillageDataset(IMAGES_DIR)

Loading Plant Village
 - Normalizing dataset


 - Calculating mean and standard deviation: 100%|██████████| 867/867 [01:42<00:00,  8.43batch/s]

 - Normalized dataset:
  - Mean: [0.4671, 0.4895, 0.4123]
  - Standard deviation: [0.1709, 0.1443, 0.1880]





In [60]:
# Reverse lookup: numeric class index -> class name (inverse of class_to_idx).
idx_to_class = dict(zip(dataset.class_to_idx.values(), dataset.class_to_idx.keys()))

In [61]:
# Visualize a stratified sample of the dataset: each class keeps its share of
# points while the embedding job stays small. A fixed seed makes the sample
# identical across Restart & Run All.
SAMPLE_FRACTION = 0.1
SPLIT_SEED = 42
BATCH_SIZE = 64
NUM_WORKERS = 12

# Named `dataset_labels` (not `labels`) because the visualization cells below
# rebind `labels` to the labels returned alongside the latents.
dataset_labels = np.array(dataset.get_labels())

unused_indices, used_indices = train_test_split(
    np.arange(len(dataset)),
    test_size=SAMPLE_FRACTION,
    stratify=dataset_labels,
    random_state=SPLIT_SEED,
)

# NOTE(review): the sampler order is still random; presumably harmless because
# visualize_latent_space returns labels aligned with the latents — confirm.
used_sampler = SubsetRandomSampler(used_indices)

dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    sampler=used_sampler,
    num_workers=NUM_WORKERS,
)

## Write Latent-Space Visualizations to TensorBoard

##### Autoencoder

In [62]:
# Load the whole pickled autoencoder and keep only its `.encoder` submodule.
# NOTE: weights_only=False unpickles arbitrary Python objects. It is required
# for a full-module checkpoint (PyTorch >= 2.6 defaults to weights_only=True),
# so only use it on checkpoints you trust.
encoder = torch.load(
    './models/uae_1.pth',
    map_location=torch.device(DEVICE),
    weights_only=False,
).encoder
# eval() switches any dropout / BatchNorm layers to inference behaviour, so the
# embeddings do not depend on training-mode randomness or batch statistics.
# (visualize_latent_space may already do this; calling it here is harmless.)
# Assigning the result also avoids echoing the module repr without a `pass`.
encoder = encoder.to(DEVICE).eval()

In [63]:
# Encode the sampled images with the autoencoder's encoder, reduce the latents
# (use_tsne=False selects the non-t-SNE path), and log them to TensorBoard with
# class-name metadata built from idx_to_class.
# The context manager closes the writer on exit, which flushes pending events
# and releases the event file instead of leaving them open in the kernel.
with SummaryWriter('runs/latent_space_visualization/uae') as writer:
    reduced_latents, labels = visualize_latent_space(encoder, dataloader, DEVICE, use_tsne=False)

    metadata = create_metadata_list(labels, idx_to_class)

    log_embeddings(writer, reduced_latents, metadata)

##### Encoder + MLP

In [64]:
# Load the whole pickled encoder + MLP model and keep only its `.encoder` submodule.
# NOTE: weights_only=False unpickles arbitrary Python objects. It is required
# for a full-module checkpoint (PyTorch >= 2.6 defaults to weights_only=True),
# so only use it on checkpoints you trust.
encoder = torch.load(
    './models/emlp_2_b.pth',
    map_location=torch.device(DEVICE),
    weights_only=False,
).encoder
# eval() switches any dropout / BatchNorm layers to inference behaviour, so the
# embeddings do not depend on training-mode randomness or batch statistics.
# (visualize_latent_space may already do this; calling it here is harmless.)
# Assigning the result also avoids echoing the module repr without a `pass`.
encoder = encoder.to(DEVICE).eval()

In [65]:
# Encode the sampled images with the encoder + MLP model's encoder, reduce the
# latents (use_tsne=False selects the non-t-SNE path), and log them to
# TensorBoard with class-name metadata built from idx_to_class.
# The context manager closes the writer on exit, which flushes pending events
# and releases the event file instead of leaving them open in the kernel.
with SummaryWriter('runs/latent_space_visualization/emlp') as writer:
    reduced_latents, labels = visualize_latent_space(encoder, dataloader, DEVICE, use_tsne=False)

    metadata = create_metadata_list(labels, idx_to_class)

    log_embeddings(writer, reduced_latents, metadata)