In [3]:
import bittensor as bt
from datetime import datetime
import pandas as pd
import numpy as np

from bitmind.utils.uids import get_random_uids
from bitmind.utils.data import sample_dataset_index_name
from bitmind.protocol import prepare_image_synapse
from bitmind.validator.reward import get_rewards
from bitmind.image_transforms import random_aug_transforms
from bitmind.base.validator import BaseValidatorNeuron
from bitmind.synthetic_image_generation.synthetic_image_generator import SyntheticImageGenerator
from bitmind.image_dataset import ImageDataset
from bitmind.constants import VALIDATOR_DATASET_META


class CustomValidator(BaseValidatorNeuron):
    """Validator neuron that challenges miners with a real or synthetic image.

    Each forward pass selects either a real image (label 0) or a freshly
    generated synthetic image (label 1), applies random augmentations, sends
    the result to a random sample of miners, and updates the miners' scores
    from the computed rewards.
    """

    # Maximum attempts at producing a NaN-free synthetic image per forward pass.
    _MAX_GENERATION_RETRIES = 10

    def __init__(self, config=None):
        """Restore saved state and build the real/synthetic image sources.

        Args:
            config: Optional configuration forwarded to ``BaseValidatorNeuron``.
        """
        super().__init__(config=config)

        bt.logging.info("load_state()")
        self.load_state()

        bt.logging.info("Loading real datasets")
        self.real_image_datasets = [
            ImageDataset(ds['path'], 'train', ds.get('name', None))
            for ds in VALIDATOR_DATASET_META['real']
        ]

        self.synthetic_image_generator = SyntheticImageGenerator(
            prompt_type='annotation', use_random_diffuser=True, diffuser_name=None)

        # Threshold compared against a uniform draw in forward(): a draw
        # above it yields a real image, otherwise a synthetic one.
        self._fake_prob = self.config.get('fake_prob', 0.5)

    def _sample_real_images(self, k=1):
        """Sample ``k`` images from one randomly chosen real dataset."""
        dataset_index, _ = sample_dataset_index_name(self.real_image_datasets)
        images, _ = self.real_image_datasets[dataset_index].sample(k=k)
        return images

    def _generate_synthetic_sample(self):
        """Generate one synthetic sample according to ``neuron.prompt_type``.

        Raises:
            NotImplementedError: if ``neuron.prompt_type`` is neither
                ``'annotation'`` nor ``'random'``.
        """
        prompt_type = self.config.neuron.prompt_type

        if prompt_type == 'annotation':
            bt.logging.info('generating fake image from annotation of real image')

            for _ in range(self._MAX_GENERATION_RETRIES):
                # Sample a real image to caption, then generate a synthetic
                # image from that caption.
                images_to_caption = self._sample_real_images(k=1)
                sample = self.synthetic_image_generator.generate(
                    k=1, real_images=images_to_caption)[0]

                if not np.any(np.isnan(sample['image'])):
                    break

                bt.logging.warning("NaN encountered in prompt/image generation, retrying...")
            else:
                # Best effort, as before: continue with the last sample, but
                # make the exhausted retry budget visible in the logs.
                bt.logging.error(
                    f"NaN still present after {self._MAX_GENERATION_RETRIES} "
                    "generation attempts; using last generated sample")
            return sample

        if prompt_type == 'random':
            bt.logging.info('generating fake image using prompt_generator')
            return self.synthetic_image_generator.generate(k=1)[0]

        bt.logging.error(f'unsupported neuron.prompt_type: {prompt_type}')
        raise NotImplementedError

    async def forward(self):
        """
        Validator forward pass. Consists of:
        - Generating the query
        - Querying the miners
        - Getting the responses
        - Rewarding the miners
        - Updating the scores

        Returns:
            A tuple ``(miner_uids, responses, label, rewards, scores)`` where
            ``label`` is 0 for a real image and 1 for a synthetic one, and
            ``scores`` is ``self.scores`` after the update.

        Raises:
            NotImplementedError: if a synthetic image is requested and
                ``neuron.prompt_type`` is unsupported.
        """
        miner_uids = get_random_uids(self, k=self.config.neuron.sample_size)

        if np.random.rand() > self._fake_prob:
            bt.logging.info('sampling real image')
            label = 0
            sample = self._sample_real_images(k=1)[0]
        else:
            label = 1
            sample = self._generate_synthetic_sample()

        image = random_aug_transforms(sample['image'])

        bt.logging.info(f"Querying {len(miner_uids)} miners...")
        axons = [self.metagraph.axons[uid] for uid in miner_uids]
        responses = await self.dendrite(
            axons=axons,
            synapse=prepare_image_synapse(image=image),
            deserialize=True
        )

        # TODO: Swap in custom reward function here
        rewards = get_rewards(label=label, responses=responses)

        bt.logging.info(f"Received responses: {responses}")
        bt.logging.info(f"Scored responses: {rewards}")

        self.update_scores(rewards, miner_uids)
        # Return the sampled miner UIDs; the original referenced an undefined
        # name `uids` here, which raised NameError on every call.
        return miner_uids, responses, label, rewards, self.scores

In [None]:
import asyncio
from concurrent.futures import ThreadPoolExecutor

vali = CustomValidator()

# forward() is a coroutine function, so calling it only creates a coroutine
# object; it must be driven by an event loop before its result can be
# unpacked. Running it via asyncio.run on a worker thread gives it a fresh
# loop, which works both in a plain script and in a notebook kernel whose
# main thread already has a running loop.
with ThreadPoolExecutor(max_workers=1) as pool:
    uids, responses, label, rewards, scores = pool.submit(
        asyncio.run, vali.forward()).result()