# Experiment 1

In [1]:
# Enable IPython's autoreload extension so edits to the imported project modules
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Imports

In [2]:
import gc
import numpy as np
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torch import nn, optim
from torchinfo import summary

from plant_village_dataset import PlantVillageDataset
from runner import Runner
from convnext import ConvNext
from unet_autoencoder import UNetAutoencoder

## Prepare Data

In [3]:
# Mini-batch size handed to split() for every DataLoader built in this notebook.
BATCH_SIZE = 128

In [4]:
def split(dataset, batch_size, labeled_ratio, test_ratio,
          ul_val_ratio=0.1, val_ratio=0.2, random_state=None):
    """Partition ``dataset`` into unlabeled and labeled subsets and build loaders.

    The dataset is first divided (stratified by label) into an "unlabeled"
    part and a "labeled" part. The unlabeled part is split into train/val
    without stratification. The labeled part is split (stratified) into
    train/val/test.

    Args:
        dataset: Indexable dataset yielding ``(sample, label)`` pairs. It is
            iterated once in full to collect the labels, which reads every
            sample.
        batch_size: Batch size used for all five loaders.
        labeled_ratio: Share of the whole dataset assigned to the labeled part.
        test_ratio: Share of the whole dataset assigned to the test set. It is
            rescaled internally to a share of the labeled part, so it must be
            smaller than ``labeled_ratio``.
        ul_val_ratio: Share of the unlabeled part held out for validation.
        val_ratio: Share of the labeled train+val indices held out for
            validation.
        random_state: Forwarded to every ``train_test_split`` call so the index
            partition can be reproduced. ``None`` keeps the previous
            non-deterministic behaviour. The batch order produced by the
            samplers is not affected by this argument.

    Returns:
        Tuple ``(ul_train_loader, ul_val_loader, train_loader, val_loader,
        test_loader)``.

    Raises:
        ValueError: If ``test_ratio`` is not strictly between 0 and
            ``labeled_ratio``.
    """
    # Fail fast, before the expensive label scan below, with a message that
    # names the caller's arguments instead of the derived relative ratio.
    if not 0 < test_ratio < labeled_ratio:
        raise ValueError(
            f"test_ratio must satisfy 0 < test_ratio < labeled_ratio, "
            f"got test_ratio={test_ratio}, labeled_ratio={labeled_ratio}"
        )

    labels = np.array([label for _, label in dataset])

    unlabeled_indices, labeled_indices = train_test_split(
        np.arange(len(dataset)),
        test_size=labeled_ratio,
        stratify=labels,
        random_state=random_state,
    )

    ul_train_indices, ul_val_indices = train_test_split(
        unlabeled_indices,
        test_size=ul_val_ratio,
        random_state=random_state,
    )

    # test_ratio is expressed relative to the full dataset; convert it to a
    # fraction of the labeled subset that is actually being split here.
    relative_test_ratio = test_ratio / labeled_ratio

    train_val_indices, test_indices = train_test_split(
        labeled_indices,
        test_size=relative_test_ratio,
        stratify=labels[labeled_indices],
        random_state=random_state,
    )

    train_indices, val_indices = train_test_split(
        train_val_indices,
        test_size=val_ratio,
        stratify=labels[train_val_indices],
        random_state=random_state,
    )

    def _loader(indices):
        # All loaders share the same dataset and differ only in which indices they draw.
        return DataLoader(dataset, batch_size=batch_size,
                          sampler=SubsetRandomSampler(indices))

    return (
        _loader(ul_train_indices),
        _loader(ul_val_indices),
        _loader(train_indices),
        _loader(val_indices),
        _loader(test_indices),
    )

In [5]:
class ReconstructionDataLoader:
    """Wrap a loader so each batch is yielded as an ``(images, images)`` pair.

    The wrapped loader is expected to yield sequences whose first element is
    the image batch. Everything after the first element (labels or any other
    extra fields) is discarded, so the images serve as both input and target.
    """

    def __init__(self, base_loader):
        """Store the underlying iterable of batches."""
        self.base_loader = base_loader

    def __iter__(self):
        """Yield ``(images, images)`` for every batch of the wrapped loader."""
        for batch in self.base_loader:
            # Star-unpacking keeps the first field and tolerates any number of
            # trailing fields, not just a single label.
            images, *_ = batch
            yield images, images

    def __len__(self):
        """Return the number of batches reported by the wrapped loader."""
        return len(self.base_loader)

In [6]:
# Build the Plant Village dataset from 'images' (presumably a directory relative to the
# notebook — confirm). Per the cell output, construction also computes the per-channel
# mean and standard deviation, which took about a minute in the recorded run.
dataset = PlantVillageDataset('images')

Loading Plant Village
 - Normalizing dataset


 - Calculating mean and standard deviation: 100%|██████████| 434/434 [01:08<00:00,  6.29batch/s]

 - Normalized dataset:
  - Mean: [0.4671, 0.4895, 0.4123]
  - Standard deviation: [0.1709, 0.1443, 0.1880]





## Run 1

In [7]:
# Run 1: 20% of the data is treated as labeled, with 10% of the whole set held out for testing.
(
    ul_train_loader,
    ul_val_loader,
    train_loader,
    val_loader,
    test_loader,
) = split(dataset, batch_size=BATCH_SIZE, labeled_ratio=0.2, test_ratio=0.1)

# The unlabeled loaders feed the autoencoder, so they must yield (images, images) pairs.
ul_train_loader, ul_val_loader = (
    ReconstructionDataLoader(ul_train_loader),
    ReconstructionDataLoader(ul_val_loader),
)

##### CNN

In [None]:
# Baseline for Run 1: a ConvNext classifier trained on the labeled loaders only.
cnn = ConvNext(num_classes=len(dataset.classes))
cnn_optim = optim.Adam(cnn.parameters(), lr=1e-3)
cnn_criterion = nn.CrossEntropyLoss()
# NOTE(review): device is hard-coded to 'mps'; presumably this needs changing on
# machines without that backend — confirm how Runner handles it.
cnn_runner = Runner('cnn_1', cnn, cnn_optim, cnn_criterion, device='mps')
cnn_runner.train(train_loader, val_loader, num_epochs=3)
# The trailing semicolon stops the notebook from echoing test()'s return value.
cnn_runner.test(test_loader);

Downloading: "https://download.pytorch.org/models/convnext_tiny-983f1562.pth" to /Users/ariel.arevalo/.cache/torch/hub/checkpoints/convnext_tiny-983f1562.pth
100%|██████████| 109M/109M [01:27<00:00, 1.31MB/s] 


Training:   0%|          | 0/3 [00:00<?, ? epoch/s]

Training:   0%|          | 0/35 [00:00<?, ?batch/s]

##### Autoencoder

In [None]:
# Run 1 autoencoder: trained on the reconstruction loaders with a mean-squared-error loss.
uae = UNetAutoencoder()
uae_criterion = nn.MSELoss()
uae_optim = optim.Adam(uae.parameters(), lr=1e-3)
uae_runner = Runner(
    'uae_1',
    uae,
    uae_optim,
    uae_criterion,
    device='mps',
)
uae_runner.train(ul_train_loader, ul_val_loader, num_epochs=3)

# Keep a handle on the encoder; presumably it is reused by the encoder + MLP cells below.
enc = uae.encoder

##### Frozen Encoder + MLP

In [None]:
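# TODO: train one Frankenstein (encoder + MLP) with the encoder frozen, i.e. requires_grad=False
# on its parameters. Note that eval() alone only switches layer modes and does not stop weight updates.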

##### Live Encoder + MLP

In [None]:
# Train the second Frankenstein normally

##### Cleanup

In [None]:
# Drop Run 1's loaders and trigger a garbage collection before Run 2 rebuilds them.
# NOTE(review): the models and runners (cnn, uae, ...) are not deleted here — confirm that is intended.
del ul_train_loader, ul_val_loader, train_loader, val_loader, test_loader
gc.collect()

## Run 2

In [None]:
# Run 2: 50% of the data is treated as labeled, with 15% of the whole set held out for testing.
(
    ul_train_loader,
    ul_val_loader,
    train_loader,
    val_loader,
    test_loader,
) = split(dataset, batch_size=BATCH_SIZE, labeled_ratio=0.5, test_ratio=0.15)

# The unlabeled loaders feed the autoencoder, so they must yield (images, images) pairs.
ul_train_loader, ul_val_loader = (
    ReconstructionDataLoader(ul_train_loader),
    ReconstructionDataLoader(ul_val_loader),
)

##### CNN

In [None]:
# Baseline for Run 2: a fresh ConvNext classifier trained on the labeled loaders only.
cnn = ConvNext(num_classes=len(dataset.classes))
cnn_optim = optim.Adam(cnn.parameters(), lr=1e-3)
cnn_criterion = nn.CrossEntropyLoss()
# NOTE(review): device is hard-coded to 'mps'; presumably this needs changing on
# machines without that backend — confirm how Runner handles it.
cnn_runner = Runner('cnn_2', cnn, cnn_optim, cnn_criterion, device='mps')
cnn_runner.train(train_loader, val_loader, num_epochs=3)
# The trailing semicolon stops the notebook from echoing test()'s return value.
cnn_runner.test(test_loader);

##### Autoencoder

In [None]:
# TODO: declare UNetAutoencoder
# TODO: train UNetAutoencoder
# TODO: extract the encoder from UNetAutoencoder

##### Frozen Encoder + MLP

In [None]:
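# TODO: train one Frankenstein (encoder + MLP) with the encoder frozen, i.e. requires_grad=False
# on its parameters. Note that eval() alone only switches layer modes and does not stop weight updates.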

##### Live Encoder + MLP

In [None]:
# Train the second Frankenstein normally