In [58]:

import matplotlib.pyplot as plt
import torchvision
import torch
from torch import optim

from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
import datetime

import random
import numpy as np

representation_shape = (1,4)
num_classes = 2

''' Only generate the activations once, so that the same values are used for train and test datasets.
Use torch.rand so that values are in the range [0,1] as this is more likely as a representation'''
def generate_activations(num_classes, representation_shape):
    activations = torch.empty((num_classes, ) + representation_shape)
    for c in range(num_classes):
        activations[c] = torch.rand(representation_shape)
    return activations

activations = generate_activations(num_classes, representation_shape)


class SyntheticDataset(torch.utils.data.Dataset):
    def __init__(self, train, activations, noise=0.0, num_classes = 10):
        """
        Args:
            train (bool)        : for train=True set, return 60_000 samples, otherwise 10_000
            activations         : base set of activations: one per class. values in range [0,1]
            noise (float) = 0.0 : factor by which to scale randn values before adding to the base activations. 
            num_classes = 10    : standard for MNIST and CIFAR-10
        """
        self.train = train
        self.size = 60_000 if train else 10_000
        
        self.data = torch.empty((self.size,) + activations[0].shape)  # start off empty
        self.targets = torch.empty(self.size)  #, dtype=torch.int)  # Random targets
        
        for c in range(num_classes):
            #activations = torch.randn(self.shape)
            #print(activations.shape)
            # FIXME do I need to use clone()? This may not be producing the ball of noise I thought it was.
            self.data[c * self.size // num_classes : ((c+1) * self.size // num_classes)]    = activations[c].clone()
            self.targets[c * self.size // num_classes : ((c+1) * self.size // num_classes)] = c
        print(activations[0])
        print(self.data[0])
        print(self.data[1])
        self.data += noise * torch.randn(self.data.shape)  # add a configurable amount of noise
        self.data = torch.clamp(self.data, 0, 1)           # and clamp to [0,1] range again so that it is batchnorm compatible
        print(activations[0])
        print(self.data[0])
        print(self.data[1])
        print(f"Check: {activations[0] is activations[0]}")
            
    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        sample = self.data[idx], int(self.targets[idx])
        return sample



class DynamicSyntheticDataset(torch.utils.data.Dataset):
    def __init__(self, num_samples, activations, noise=0.0, local_seed = 9):
        """
        Args:
            num_samples         : would use 60000/10000 for MNIST equivalent. May well require 1000 per class = 1 million for imagenet
            activations         : base set of activations: one per class. values in range [0,1]
            noise (float) = 0.0 : factor by which to scale randn values before adding to the base activations.             
            local_seed          : used to create random noise without upsetting the global generators
        """
        num_classes = len(activations)
        self.targets = torch.empty(num_samples)  #, dtype=torch.int)  # Random targets
        self.local_seed = local_seed
        self.noise = noise
        
        for c in range(num_classes):
            self.targets[c * num_samples // num_classes : ((c+1) * num_samples // num_classes)] = c

        self.activations = activations
        
        # Create a local random number generator
        self.local_rng = torch.Generator()
        
    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):

        sample_class = int(self.targets[idx])
        # Seed the local generator according to this get request
        self.local_rng.manual_seed(self.local_seed + idx)

        print(f"before adding noise: {sample_class=} {self.activations[sample_class][0]=}") 
        data = self.activations[sample_class].clone() 
        data += self.noise * torch.randn(data.shape, generator=self.local_rng)  # add a configurable amount of noise
        data = torch.clamp(data, 0, 1)                # and clamp to [0,1] range again so that it is batchnorm compatible 
        print(f"after adding noise: {self.activations[sample_class][0]=}") 
             
        
        sample = data, sample_class
        return sample

# dataset = SyntheticDataset(False, activations, noise=0.1, num_classes=num_classes)
dyn_dataset = DynamicSyntheticDataset(10, activations, noise=0.1, local_seed=1)
print(dyn_dataset.__getitem__(0))
print(dyn_dataset.__getitem__(1))
print(dyn_dataset.__getitem__(1))
print(dyn_dataset.__getitem__(0))


before adding noise: sample_class=0 self.activations[sample_class][0]=tensor([0.5830, 0.1675, 0.1069, 0.2999])
after adding noise: self.activations[sample_class][0]=tensor([0.5830, 0.1675, 0.1069, 0.2999])
(tensor([[0.6492, 0.1942, 0.1131, 0.3620]]), 0)
before adding noise: sample_class=0 self.activations[sample_class][0]=tensor([0.5830, 0.1675, 0.1069, 0.2999])
after adding noise: self.activations[sample_class][0]=tensor([0.5830, 0.1675, 0.1069, 0.2999])
(tensor([[0.6223, 0.1451, 0.0750, 0.1794]]), 0)
before adding noise: sample_class=0 self.activations[sample_class][0]=tensor([0.5830, 0.1675, 0.1069, 0.2999])
after adding noise: self.activations[sample_class][0]=tensor([0.5830, 0.1675, 0.1069, 0.2999])
(tensor([[0.6223, 0.1451, 0.0750, 0.1794]]), 0)
before adding noise: sample_class=0 self.activations[sample_class][0]=tensor([0.5830, 0.1675, 0.1069, 0.2999])
after adding noise: self.activations[sample_class][0]=tensor([0.5830, 0.1675, 0.1069, 0.2999])
(tensor([[0.6492, 0.1942, 0.1131