In [None]:
# !pip install ray[tune]
# !pip install grpcio
# !pip install grpcio-tools

In [1]:
import os
import torch
import tempfile
import lightning.pytorch as pl
import torch.nn.functional as F
import lightning as L
from torch import nn
from torch.utils.data import Dataset
from filelock import FileLock
from torchmetrics import Accuracy
from torch.utils.data import DataLoader, random_split
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchvision import transforms
from torchvision import models
from torchvision.transforms import v2
from ray import tune
from ray.tune.schedulers import ASHAScheduler
from ray.train.lightning import (
    RayDDPStrategy,
    RayLightningEnvironment,
    RayTrainReportCallback,
    prepare_trainer,
)
from sklearn.model_selection import train_test_split
import random
from PIL import Image, ImageOps
from tqdm.notebook import tqdm
from glob import glob

In [2]:
print(torch.cuda.get_arch_list())

['sm_50', 'sm_60', 'sm_70', 'sm_75', 'sm_80', 'sm_86', 'sm_37', 'sm_90']


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
OBJECTS_ROOT = r"/home/shared/datasets/pollen_dataset_2024_05_08_objects"

In [5]:
def pad2size(image, size=(224, 224)):
    return ImageOps.fit(image, size)

class PollenTripletDataset(Dataset):
    def __init__(self, img_dir, classes, 
                 static_transform=None, 
                 random_transform=None,
                inflation_rate=1):
        self.known_classes = classes
        self.known_classes.sort()
        self.transform = random_transform
        self._images = {}
        self._labels = {}
        self.inflation_rate=inflation_rate

        total_samples = 0
        for c in tqdm(self.known_classes):
            self._images[c] = glob(os.path.join(img_dir, c, "*.png"))
            print(f"{len(self._images[c])} samples in the class {c}")
            total_samples += len(self._images[c])
            self._images[c] = [ Image.open(img_path) for img_path in tqdm(self._images[c])] 
            self._labels[c] = [c] * len(self._images[c])  # Метки остаются строками
            if static_transform is not None:
                self._images[c] = [ static_transform(img) for img in tqdm(self._images[c])] 
        print(f"Total samples: {total_samples}")
        self._images = list(self._images.values())
        self._labels = list(self._labels.values())

    def __len__(self):
        return len(self._images) * self.inflation_rate

    def __getitem__(self, idx):
        idx = idx % len(self._images)
        anchor = random.choice(self._images[idx])
        label = self._labels[idx][0]
        positive = random.choice(self._images[idx])
        neg_weights = [1]*len(self._images)
        neg_weights[idx] = 0
        negative = random.choices(self._images, weights=neg_weights)[0]
        negative = random.choice(negative)
        
        if self.transform:
            anchor = self.transform(anchor)
            positive = self.transform(positive)
            negative = self.transform(negative)

        return anchor, positive, negative, label


known_classes = os.listdir(OBJECTS_ROOT)

print(f"All known classes: {known_classes}")
train_classes, test_classes = train_test_split(known_classes, test_size=0.2, random_state=42)
train_classes, val_classes = train_test_split(train_classes, test_size=0.2, random_state=42)

print(f"Train classes: {train_classes}")
print(f"Validation classes: {val_classes}")
print(f"Test classes: {test_classes}")

input_size=(224, 224)

static_transforms = v2.Compose([
    v2.Lambda(pad2size),
    v2.PILToTensor(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


print("\nLoad Train Dataset")
dataset_train = PollenTripletDataset(OBJECTS_ROOT, train_classes, static_transform=static_transforms, inflation_rate=100)
print("\nLoad Validation Dataset")
dataset_val = PollenTripletDataset(OBJECTS_ROOT, val_classes, static_transform=static_transforms, inflation_rate=100)
print("\nLoad Test Dataset")
dataset_test = PollenTripletDataset(OBJECTS_ROOT, test_classes, static_transform=static_transforms, inflation_rate=100)

All known classes: ['Quercus', 'Tilia', 'Corylus', 'Acer', 'Populus tremula', 'Betula', 'Alnus', 'Pinus', 'Salix']
Train classes: ['Acer', 'Salix', 'Populus tremula', 'Corylus', 'Alnus']
Validation classes: ['Betula', 'Quercus']
Test classes: ['Pinus', 'Tilia']

Load Train Dataset


  0%|          | 0/5 [00:00<?, ?it/s]

319 samples in the class Acer


  0%|          | 0/319 [00:00<?, ?it/s]

  0%|          | 0/319 [00:00<?, ?it/s]

25 samples in the class Alnus


  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

141 samples in the class Corylus


  0%|          | 0/141 [00:00<?, ?it/s]

  0%|          | 0/141 [00:00<?, ?it/s]

201 samples in the class Populus tremula


  0%|          | 0/201 [00:00<?, ?it/s]

  0%|          | 0/201 [00:00<?, ?it/s]

227 samples in the class Salix


  0%|          | 0/227 [00:00<?, ?it/s]

  0%|          | 0/227 [00:00<?, ?it/s]

Total samples: 913

Load Validation Dataset


  0%|          | 0/2 [00:00<?, ?it/s]

206 samples in the class Betula


  0%|          | 0/206 [00:00<?, ?it/s]

  0%|          | 0/206 [00:00<?, ?it/s]

90 samples in the class Quercus


  0%|          | 0/90 [00:00<?, ?it/s]

  0%|          | 0/90 [00:00<?, ?it/s]

Total samples: 296

Load Test Dataset


  0%|          | 0/2 [00:00<?, ?it/s]

67 samples in the class Pinus


  0%|          | 0/67 [00:00<?, ?it/s]

  0%|          | 0/67 [00:00<?, ?it/s]

114 samples in the class Tilia


  0%|          | 0/114 [00:00<?, ?it/s]

  0%|          | 0/114 [00:00<?, ?it/s]

Total samples: 181


In [6]:
training_loader = torch.utils.data.DataLoader(
    dataset_train,
    batch_size= 16,
    shuffle= True,
    num_workers= 6
)
val_loader = torch.utils.data.DataLoader(
    dataset_val, 
    batch_size= 4,
    shuffle= True,
    num_workers= 6
)
test_loader = torch.utils.data.DataLoader(
    dataset_test, 
    batch_size= 1,
    shuffle= True,
    num_workers= 6
)

In [7]:
class EmbeddingModel(nn.Module):
    def __init__(self, backbone, embedings):
        super().__init__()
        if backbone == "resnet18":
            self.model = models.resnet18(weights='DEFAULT')
        elif backbone == "resnet34":
            self.model = models.resnet34(weights='DEFAULT')
        elif backbone == "resnet50":
            self.model = models.resnet50(weights='DEFAULT')
        elif backbone == "resnet101":
            self.model = models.resnet101(weights='DEFAULT')
        elif backbone == "resnet152":
            self.model = models.resnet152(weights='DEFAULT')
        else:
            raise ValueError("Unsupported backbone")
        num_features = self.model.fc.in_features
        self.embedings = nn.Linear(num_features, embedings)
        self.model.fc = nn.Identity()

    def forward(self, x):
        x = self.model(x)
        x = self.embedings(x)
        x = nn.functional.normalize(x) # L2 normalization to put all values on a sphere
        return x

class PollenEmbedingsModule(L.LightningModule):
    def __init__(self, config):
        super().__init__()
        self.optim_lr=config["optim_lr"]
        self.model = EmbeddingModel(config["backbone"], config["embedings_size"])
        
        config["check_val"]=20
        config["optim_betas"]=(0.9, 0.999)
        config["optim_eps"]=1e-08
        config["optim_weight_decay"]=0
        # call this to save (arguments) to the checkpoint
        self.save_hyperparameters(config)

        self.loss_function = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7)
        self.best_score = 0
        self.best_val_epoch = -1

        self.train_step_outputs = []
        self.validation_step_outputs = []
        self.train_losses = []
        self.val_losses = []

    def forward(self, inputs):
        return self.model(inputs)

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        anchor, positive, negative, label = batch

        anchor_out = self.model(anchor)
        positive_out = self.model(positive)
        negative_out = self.model(negative)
        
        loss = self.loss_function(anchor_out, positive_out, negative_out)
        pred = {"train_loss": loss, "train_number": len(anchor_out)}
        self.train_step_outputs.append(pred)
        return loss

    def on_train_epoch_end(self):
        outputs = self.train_step_outputs
        train_loss, num_items = 0, 0
        for output in outputs:
            train_loss += output["train_loss"].sum().item()
            num_items += output["train_number"]

        mean_train_loss = torch.tensor(train_loss / num_items)
        self.log("train_loss", mean_train_loss, sync_dist=True)
        self.train_losses.append(mean_train_loss.item())
        self.train_step_outputs.clear()  # free memory

    def validation_step(self, batch, batch_idx):
        # this is the validation loop
        anchor, positive, negative, label = batch

        anchor_out = self.model(anchor)
        positive_out = self.model(positive)
        negative_out = self.model(negative)
        
        loss = self.loss_function(anchor_out, positive_out, negative_out)

        pred = {"val_loss": loss, "val_number": len(anchor_out)}
        self.validation_step_outputs.append(pred)
        return {"val_loss": loss}

    def on_validation_epoch_end(self):
        outputs = self.validation_step_outputs
        val_loss, num_items = 0, 0
        for output in outputs:
            val_loss += output["val_loss"].sum().item()
            num_items += output["val_number"]

        mean_val_loss = torch.tensor(val_loss / num_items)
        self.val_losses.append(mean_val_loss.item())

        tensorboard_logs = {
            "val_loss": mean_val_loss
        }
        self.log("val_loss", mean_val_loss, sync_dist=True)
        self.validation_step_outputs.clear()  # free memory

        if mean_val_loss > self.best_score:
            self.best_score = mean_val_loss
            self.best_val_epoch = self.current_epoch

        return {"log": tensorboard_logs}

    def configure_optimizers(self):
        self.optimizer = torch.optim.Adam(self.parameters(),
                                          lr=self.hparams.optim_lr,
                                          betas=self.hparams.optim_betas,
                                          eps=self.hparams.optim_eps,
                                          weight_decay=self.hparams.optim_weight_decay)
        return {
            "optimizer": self.optimizer,
            "lr_scheduler": {
                "scheduler": ReduceLROnPlateau(self.optimizer, factor=0.1, patience=10),
                "frequency": self.hparams.check_val,
                "monitor": "val_loss",
                "interval": "epoch",
            }
        }

    def plot_losses(self):
        plt.figure(figsize=(10, 5))
        plt.plot(self.train_losses, label='Train Loss')
        plt.plot([i * self.hparams.check_val for i in range(len(self.val_losses))], self.val_losses, label='Validation Loss')
        plt.xlabel('Iterations')
        plt.ylabel('Loss')
        plt.title('Training and Validation Losses')
        plt.legend()
        plt.show()

In [8]:
default_config = {
    "backbone": "resnet50",
    "optim_lr": 0.001,
    "embedings_size": 128,
}

In [9]:
def train_func(config):
    model = PollenEmbedingsModule(config)

    trainer = pl.Trainer(
        devices="auto",
        accelerator="auto",
        strategy=RayDDPStrategy(),
        callbacks=[RayTrainReportCallback()],
        plugins=[RayLightningEnvironment()],
        enable_progress_bar=False,
    )
    trainer = prepare_trainer(trainer)
    trainer.fit(model, train_dataloaders=training_loader, val_dataloaders=val_loader,)

In [10]:
search_space = {
    "backbone": tune.choice(["resnet18", "resnet34", "resnet50", "resnet101", "resnet152"]),
    "optim_lr": tune.loguniform(1e-5, 1e-2),
    "embedings_size": tune.choice([64, 128, 256]),
}

In [12]:
# The maximum training epochs
num_epochs = 20

# Number of sampls from parameter space
num_samples = 10

In [13]:
scheduler = ASHAScheduler(max_t=num_epochs, grace_period=1, reduction_factor=2)

In [14]:
from ray.train import RunConfig, ScalingConfig, CheckpointConfig

scaling_config = ScalingConfig(
    num_workers=1, use_gpu=True, resources_per_worker={"CPU": 1, "GPU": 1}
)

run_config = RunConfig(
    checkpoint_config=CheckpointConfig(
        num_to_keep=2,
        checkpoint_score_attribute="val_loss",
        checkpoint_score_order="min",
    ),
)

In [15]:
from ray.train.torch import TorchTrainer

# Define a TorchTrainer without hyper-parameters for Tuner
ray_trainer = TorchTrainer(
    train_func,
    scaling_config=scaling_config,
    run_config=run_config,
)

In [16]:
def tune_pollen_asha(num_samples=10):
    scheduler = ASHAScheduler(max_t=num_epochs, grace_period=1, reduction_factor=2)

    tuner = tune.Tuner(
        ray_trainer,
        param_space={"train_loop_config": search_space},
        tune_config=tune.TuneConfig(
            metric="val_loss",
            mode="min",
            num_samples=num_samples,
            scheduler=scheduler,
        ),
    )
    return tuner.fit()

results = tune_pollen_asha(num_samples=num_samples)

0,1
Current time:,2024-06-21 07:42:50
Running for:,00:02:49.93
Memory:,14.0/125.7 GiB

Trial name,status,loc,train_loop_config/ba ckbone,train_loop_config/em bedings_size,train_loop_config/op tim_lr,iter,total time (s),val_loss,epoch,step
TorchTrainer_7720c_00000,TERMINATED,10.1.147.149:422335,resnet34,256,0.00648277,5,160.733,0.0729473,4,160


[36m(TorchTrainer pid=422335)[0m Started distributed worker processes: 
[36m(TorchTrainer pid=422335)[0m - (ip=10.1.147.149, pid=422453) world_rank=0, local_rank=0, node_rank=0
[36m(RayTrainWorker pid=422453)[0m Setting up process group for: env:// [rank=0, world_size=1]
[36m(RayTrainWorker pid=422453)[0m GPU available: True (cuda), used: True
[36m(RayTrainWorker pid=422453)[0m TPU available: False, using: 0 TPU cores
[36m(RayTrainWorker pid=422453)[0m IPU available: False, using: 0 IPUs
[36m(RayTrainWorker pid=422453)[0m HPU available: False, using: 0 HPUs
[36m(RayTrainWorker pid=422453)[0m   rank_zero_warn(
[36m(RayTrainWorker pid=422453)[0m Missing logger folder: /tmp/ray/session_2024-06-21_07-39-56_023663_420378/artifacts/2024-06-21_07-40-00/TorchTrainer_2024-06-21_07-39-55/working_dirs/TorchTrainer_7720c_00000_0_backbone=resnet34,embedings_size=256,optim_lr=0.0065_2024-06-21_07-40-00/lightning_logs
[36m(RayTrainWorker pid=422453)[0m LOCAL_RANK: 0 - CUDA_VISIBLE

In [17]:
results.get_best_result(metric="val_loss", mode="min")

Result(
  metrics={'val_loss': 0.07294729351997375, 'train_loss': 0.020982593297958374, 'epoch': 4, 'step': 160},
  path='/home/jovyan/ray_results/TorchTrainer_2024-06-21_07-39-55/TorchTrainer_7720c_00000_0_backbone=resnet34,embedings_size=256,optim_lr=0.0065_2024-06-21_07-40-00',
  filesystem='local',
  checkpoint=Checkpoint(filesystem=local, path=/home/jovyan/ray_results/TorchTrainer_2024-06-21_07-39-55/TorchTrainer_7720c_00000_0_backbone=resnet34,embedings_size=256,optim_lr=0.0065_2024-06-21_07-40-00/checkpoint_000004)
)