What we want to know: if we train a model on synthetic data, does the quality of its output degrade as predicted by "The Curse of Recursion: Training on Generated Data Makes Models Forget" by Shumailov et al.?  What if we train tiny models on separate subsets of the data and then have them generate data for a new, bigger model?

We don't know whether text on the internet was generated by LLMs or not, so our problem shouldn't assume we have "real" data.  Instead, what if we cluster the data and expose small samples from each sub-cluster to different generative models?  Will the sum of the different probability distributions adequately capture the nuance we need?

https://arxiv.org/pdf/2305.17493.pdf

In [None]:
#%pip install huggingface-hub huggingface-cli datasets accelerate evaluate
import torch
import torchvision
import numpy
#from datasets import load_dataset

In [None]:
# Prefer the GPU when one is present, but fall back to the CPU so the notebook
# still runs on machines without CUDA instead of aborting on an assertion
# (which would also be silently skipped under `python -O`).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
def _load_mnist_split(train: bool):
    """Return one MNIST split, downloading it into ``data/`` if needed.

    Args:
        train: ``True`` for the training split, ``False`` for the test split.

    Returns:
        A ``torchvision.datasets.MNIST`` dataset whose images are converted
        with ``torchvision.transforms.ToTensor()``.
    """
    return torchvision.datasets.MNIST(
        root="data",
        train=train,
        download=True,
        transform=torchvision.transforms.ToTensor(),
    )


# Training split first, then the held-out test split.
training_data = _load_mnist_split(train=True)
test_data = _load_mnist_split(train=False)

In [None]:
batch_size = 32  # Can we do higher on my machine?

# Both loaders share the same batch size and use the DataLoader defaults otherwise.
train_dataloader = torch.utils.data.DataLoader(dataset=training_data, batch_size=batch_size)
test_dataloader = torch.utils.data.DataLoader(dataset=test_data, batch_size=batch_size)

# Sanity check: peek at the first training batch (if any) and show its shapes.
first_batch = next(iter(train_dataloader), None)
if first_batch is not None:
    x, y = first_batch
    print(x.shape)
    print(y.shape)

In [None]:
import torch.nn as nn

class TinyMNISTClassifier(nn.Module):
    """Very small convolutional network that maps an image batch to raw scores.

    The stack is two stride-3, 3x3 convolutions, a LeakyReLU, a flatten and a
    linear head. The linear layer is sized for 288 flattened features, i.e. a
    32x3x3 feature map, which is what a 1x28x28 input produces.

    Args:
        outputs: Width of the final linear layer (number of scores per example).
    """

    def __init__(self, outputs: int = 10):
        super().__init__()
        layers = [
            # No padding: 28x28 -> 9x9.
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=3),
            # No padding: 9x9 -> 3x3.
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=3),
            nn.LeakyReLU(),
            # 32 channels x 3 x 3 -> 288 features.
            nn.Flatten(),
            nn.Linear(in_features=288, out_features=outputs),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return unnormalized scores for ``x``; no softmax is applied."""
        scores = self.net(x)
        return scores

"""
Reference note: VAE reparameterization trick (sample z from a mean and a log-variance).
epsilon = Normal(0, 1).sample((batch, dim)).to(z_mean.device)
return z_mean + torch.exp(0.5 * z_log_var) * epsilon
"""

class TinyMNISTGenerator(nn.Module):
    """Fully connected decoder that turns latent vectors into 28x28 images.

    Args:
        latent_size: Dimensionality of the latent vector fed to the network.
    """

    def __init__(self, latent_size: int):
        super().__init__()
        # TODO: This is stupid and we should make a better generator.
        self.net = nn.Sequential(
            nn.Linear(latent_size, 28*28),
            nn.LeakyReLU(),
            nn.Linear(28*28, 28*28),
            nn.LeakyReLU(),
            nn.Linear(28*28, 28*28),
        )

    def sample_latent(self, z_mean, z_log_var):
        """Draw ``z ~ N(z_mean, exp(z_log_var))`` with the reparameterization trick.

        The second argument is a log-variance, so it is converted to a standard
        deviation with ``exp(0.5 * z_log_var)`` before scaling the noise. Feeding
        it to ``torch.normal`` as a std (as before) fails for negative values,
        uses the wrong scale, and does not let gradients reach the inputs.

        Args:
            z_mean: Mean of the latent Gaussian.
            z_log_var: Log-variance of the latent Gaussian; expected to have
                the same shape as ``z_mean``.

        Returns:
            A sample that is differentiable with respect to both arguments.
        """
        std = torch.exp(0.5 * z_log_var)
        # All randomness lives in epsilon, which keeps the sample differentiable.
        epsilon = torch.randn_like(std)
        return z_mean + std * epsilon

    def generate(self, z_sample):
        """Decode latent samples into a tensor reshaped to ``(-1, 28, 28)``."""
        return torch.reshape(self.net(z_sample), (-1, 28, 28))

    def forward(self, z_mean, z_log_var):
        """Alias for ``generate(sample_latent(z_mean, z_log_var))``."""
        return self.generate(self.sample_latent(z_mean, z_log_var))

class TinyMNISTVAE(nn.Module):
    """Variational autoencoder assembled from the tiny encoder/decoder above.

    Two separate ``TinyMNISTClassifier`` networks (each with ``latent_size``
    outputs) produce the latent mean and the value passed as the decoder's
    ``z_log_var`` argument; a ``TinyMNISTGenerator`` decodes the sample.

    Args:
        latent_size: Dimensionality of the latent space.
    """

    def __init__(self, latent_size: int=30):
        # nn.Module must be initialized before submodules are assigned;
        # without this call the first assignment raises AttributeError.
        super().__init__()
        self.z_mean_encoder = TinyMNISTClassifier(latent_size)  # We could maybe do this all in one fell swoop?
        self.z_log_var_encoder = TinyMNISTClassifier(latent_size)
        self.decoder = TinyMNISTGenerator(latent_size)

    def forward(self, x):
        """Encode ``x``, sample a latent vector, and return the decoder output."""
        z_mean = self.z_mean_encoder(x)
        z_log_var = self.z_log_var_encoder(x)
        return self.decoder(z_mean, z_log_var)

In [None]:
# Build the classifier with 10 outputs, then move it onto the selected device.
model = TinyMNISTClassifier(outputs=10)
model = model.to(device)

In [None]:
# Cross-entropy over the classifier's raw scores, optimized with Adam at its
# default hyperparameters.
loss_fn = torch.nn.CrossEntropyLoss()
opt = torch.optim.Adam(model.parameters())

In [None]:
def train(dataloader, model, loss_fn, optimizer):
    """Run one pass over ``dataloader``, stepping ``optimizer`` after every batch.

    Every 100 batches (including batch 0) a progress line is printed with the
    mean and standard deviation of the losses collected since the previous
    report, plus an exponentially smoothed running loss.

    Args:
        dataloader: Iterable yielding ``(example, target)`` batches.
        model: Module to train; it is switched to training mode.
        loss_fn: Callable taking ``(prediction, target)`` and returning a loss tensor.
        optimizer: Optimizer over the model's parameters.
    """
    model.train()
    smoothed = 0.0
    window = []
    for step, (inputs, labels) in enumerate(dataloader):
        # Batches are moved to the module-level `device` before the forward pass.
        inputs, labels = inputs.to(device), labels.to(device)

        batch_loss = loss_fn(model(inputs), labels)
        batch_loss.backward()
        optimizer.step()
        # Gradients are cleared after the update, ready for the next batch.
        optimizer.zero_grad()

        scalar_loss = batch_loss.item()
        window.append(scalar_loss)
        smoothed = 0.9*smoothed + 0.1*scalar_loss

        if step % 100 != 0:
            continue
        stats = numpy.asarray(window)
        print(f"Batch ({step}) loss: Mean: {stats.mean()} \t Stddev: {stats.std()} \t Running: {smoothed}")
        # Start a fresh reporting window.
        window = []

def test(data_loader, model, loss_fn, num_classes: int = 10):
    """Evaluate ``model`` on ``data_loader`` without tracking gradients.

    Args:
        data_loader: Iterable yielding ``(x, y)`` batches of inputs and integer labels.
        model: Module to evaluate; it is switched to eval mode.
        loss_fn: Callable taking ``(prediction, target)`` and returning a loss tensor.
        num_classes: Side length of the confusion matrix (defaults to 10).

    Returns:
        A tuple ``(test_loss, correct_count, accuracy, confusion_matrix)`` where
        ``test_loss`` is the sum of the per-batch loss values, ``accuracy`` is a
        fraction in ``[0, 1]`` (``0.0`` if the loader yielded no examples), and
        ``confusion_matrix[predicted][actual]`` counts examples.
    """
    model.eval()
    test_loss = 0.0
    correct_count = 0
    total_examples = 0
    # int32 rather than uint8: per-cell counts above 255 would silently wrap
    # around in uint8. int32 is still a dtype PIL's Image.fromarray accepts.
    confusion_matrix = numpy.zeros((num_classes, num_classes), dtype=numpy.int32)
    with torch.no_grad():
        for (x, y) in data_loader:
            x = x.to(device)
            y = y.to(device)
            pred = model(x)
            test_loss += loss_fn(pred, y).item()
            # Bring both predictions and labels to host memory once per batch so
            # the bookkeeping below works on plain NumPy integers.
            predicted_labels = pred.argmax(1).cpu().numpy()
            actual_labels = y.cpu().numpy()
            for predicted, correct in zip(predicted_labels, actual_labels):
                confusion_matrix[predicted, correct] += 1
            correct_count += int((predicted_labels == actual_labels).sum())
            total_examples += len(actual_labels)
    # Guard against an empty loader instead of dividing by zero.
    accuracy = correct_count / float(total_examples) if total_examples else 0.0
    return test_loss, correct_count, accuracy, confusion_matrix

In [None]:
# Baseline: evaluate the model on the test set. With the train() call below
# commented out, this measures the model as it is before the epoch loop further
# down (assuming the cells are run in order). The call is left as a bare trailing
# expression so its result tuple (summed loss, correct count, accuracy fraction,
# confusion matrix) is shown as the cell's output.
#train(train_dataloader, model, loss_fn, opt)
test(test_dataloader, model, loss_fn)

In [None]:
from PIL import Image # For visualizing the confusion matrix.

# Train for 10 epochs; after each one, report test accuracy and show the
# confusion matrix (rows: predicted class, columns: actual class) as an image.
for epoch_idx in range(0, 10):
    train(train_dataloader, model, loss_fn, opt)
    _, _, correct_percent, confusion_matrix = test(test_dataloader, model, loss_fn)
    # The accuracy used to be computed and then thrown away; surface it.
    print(f"Epoch {epoch_idx}: test accuracy {correct_percent:.2%}")
    # Rescale counts so the largest cell maps to 255; raw counts are not a
    # meaningful 8-bit brightness.
    scaled_counts = numpy.asarray(confusion_matrix, dtype=numpy.float64)
    peak = scaled_counts.max()
    if peak > 0:
        scaled_counts = scaled_counts * (255.0 / peak)
    confusion_matrix = Image.fromarray(scaled_counts.astype(numpy.uint8))
    display(confusion_matrix.resize((100, 100)))