In [1]:
import torch

In [7]:
VOCAB_SIZE = 50257  # GPT-2 vocabulary size (valid token ids 0..50256); must match GPT2LMHeadModel's embedding rows
DATA_PER_CLASS = 10  # not referenced by the cells below — TODO confirm intended use
SYNTHETIC_SIZE = 120  # number of synthetic sequences (rows of the synthetic id tensors)
DEVICE = "cpu"

In [9]:
# Random token ids stored as float so the tensor can carry requires_grad=True
# (PyTorch only allows gradients on floating-point tensors).
sentences_syn = torch.randint(
    high=VOCAB_SIZE,
    size=(SYNTHETIC_SIZE, 128),  # 128 is presumably the sequence length — confirm
    requires_grad=True,
    device=DEVICE,
    dtype=torch.float,
)

In [10]:
# Display the random synthetic token ids created above (float dtype, requires_grad=True).
sentences_syn

tensor([[35756., 48711., 15608.,  ..., 14561., 26030., 27185.],
        [12290.,  4941., 33795.,  ..., 42057.,  7235., 24361.],
        [31802., 42214., 48225.,  ..., 19891., 30722., 49903.],
        ...,
        [ 7481., 16925., 26879.,  ..., 16968., 33761., 10144.],
        [10534., 30321., 11366.,  ..., 23064., 43659., 13331.],
        [19262., 25335., 36413.,  ..., 42492., 29445., 37697.]],
       requires_grad=True)

In [None]:
from typing import Tuple

def decode_embeddings(embeddings: torch.nn.Embedding, embedded_data: torch.Tensor, device=None):
    """Map embedding vectors back to token ids by nearest-neighbour lookup.

    For every position, returns the index of the row of ``embeddings.weight``
    with the smallest Euclidean distance to that position's vector.

    Args:
        embeddings: embedding layer whose weight matrix acts as the vocabulary
            (one row per token id).
        embedded_data: tensor of shape (num_sentences, sentence_len, embed_dim).
        device: device on which distances are computed and the result is
            allocated. Defaults to the notebook-level ``DEVICE``.

    Returns:
        A ``torch.long`` tensor of shape (num_sentences, sentence_len).

    Raises:
        ValueError: if ``embedded_data`` is not 3-dimensional.
    """
    if embedded_data.dim() != 3:
        raise ValueError(
            f"embedded_data must be 3-D (num_sentences, sentence_len, embed_dim), "
            f"got shape {tuple(embedded_data.shape)}"
        )
    if device is None:
        device = DEVICE

    num_sentences, sentence_len = embedded_data.shape[0], embedded_data.shape[1]
    sentences = torch.zeros(num_sentences, sentence_len, device=device, dtype=torch.long)

    # argmin is not differentiable, so there is no point recording an autograd graph.
    with torch.no_grad():
        # Loop-invariant: move the (vocab x dim) matrix once, not once per sentence.
        vocab = embeddings.weight.to(device)
        # One sentence at a time keeps the distance matrix at (sentence_len x vocab)
        # instead of materialising (num_sentences x sentence_len x vocab).
        for i in range(num_sentences):
            dists = torch.cdist(embedded_data[i].to(device), vocab, p=2)
            sentences[i] = dists.argmin(dim=-1)

    return sentences


def distance_wb(gwr, gws, eps: float = 1e-6):
    """Gradient-matching distance between a real and a synthetic gradient tensor.

    Tensors of rank >= 2 are viewed as a matrix with one row per index of dim 0
    and all remaining dims flattened. The result is the sum over rows of
    ``1 - cos(row_real, row_syn)``.

    Args:
        gwr: gradient of the real-data loss w.r.t. one parameter.
        gws: gradient of the synthetic-data loss w.r.t. the same parameter
            (assumed to have the same shape as ``gwr``).
        eps: added to the product of norms to avoid division by zero.

    Returns:
        A 0-d float tensor. 1-D gradients (biases, norm-layer scales) always
        contribute 0.
    """
    if gwr.dim() == 1:
        # batchnorm/instancenorm, C; groupnorm x, bias: deliberately not matched.
        return torch.tensor(0, dtype=torch.float, device=gwr.device)
    if gwr.dim() >= 2:
        # linear (out*in), layernorm (C*h*w), conv (out*in*h*w) and any higher-rank
        # weight: keep dim 0, flatten the rest so every rank is compared the same way.
        gwr = gwr.flatten(start_dim=1)
        gws = gws.flatten(start_dim=1)

    cosine = torch.sum(gwr * gws, dim=-1) / (
        torch.norm(gwr, dim=-1) * torch.norm(gws, dim=-1) + eps
    )
    return torch.sum(1 - cosine)


def match_loss(
    gradient_weights_syn: Tuple[torch.Tensor, ...],
    gradient_weights_real: Tuple[torch.Tensor, ...],
    device,
) -> torch.Tensor:
    """Total gradient-matching distance summed over all parameters.

    Args:
        gradient_weights_syn: per-parameter gradients from the synthetic batch.
        gradient_weights_real: per-parameter gradients from the real batch,
            in the same parameter order.
        device: device of the returned scalar accumulator.

    Returns:
        0-d tensor: sum of ``distance_wb(real, syn)`` over parameter pairs.
    """
    total = torch.tensor(0.0).to(device)
    # Iterate over the real gradients and index the synthetic ones in lock-step.
    for idx, real_grad in enumerate(gradient_weights_real):
        syn_grad = gradient_weights_syn[idx]
        total += distance_wb(real_grad, syn_grad)
    return total

In [None]:
from copy import deepcopy
import copy
import torch
from torch.utils.data import DataLoader
from transformers import GPT2LMHeadModel

# Outer-loop settings for the condensation loop.
K, T, C = (
    1,    # K: number of model initializations
    10,   # T: number of epochs per initialization
    100,  # C: number of minibatches (not referenced by the training loop below)
)

def initialize_model() -> GPT2LMHeadModel:
    """Intended to return a GPT2LMHeadModel; the training loop calls it once per k.

    Not implemented yet. Raising here makes the failure explicit instead of
    silently returning ``None``, which would only fail later at ``model.train()``.
    """
    raise NotImplementedError("initialize_model() is not implemented yet")

def sample_batch(real_data, syn_data):
    """Sample a batch from the real and the synthetic dataset. Both have the same class.

    The training loop iterates over the result expecting ``(batch_real, batch_syn)``
    pairs, where ``batch_real`` has ``'input_ids'`` and ``'attention_mask'`` entries
    and ``batch_syn`` has ``'input_ids'``.

    Not implemented yet. Raising makes this explicit instead of returning ``None``,
    which would surface as a confusing "'NoneType' object is not iterable" error.
    """
    raise NotImplementedError("sample_batch() is not implemented yet")

def evaluate():
    """Evaluation step (not implemented yet).

    Raises instead of silently returning ``None`` so an accidental call is noticed.
    """
    raise NotImplementedError("evaluate() is not implemented yet")


# Placeholder for the real dataset; nothing is loaded yet.
real_data = torch.Tensor()
# Random synthetic token ids of shape (SYNTHETIC_SIZE, 128). randint's `high` is exclusive,
# so ids lie in [0, 50255]; presumably id 50256 (GPT-2's <|endoftext|>) is excluded on
# purpose — confirm.
syn_data = torch.randint(0, 50257-1, (SYNTHETIC_SIZE, 128)).to(DEVICE)

criterion_real = torch.nn.CrossEntropyLoss()
criterion_syn = torch.nn.CrossEntropyLoss()

criterion_create_syn = torch.nn.CrossEntropyLoss()
# NOTE(review): syn_data is an integer tensor without requires_grad, so it never receives
# a .grad and step() on this optimizer leaves it unchanged. The training loop optimises
# the synthetic *embeddings* with its own optimizer instead.
optimizer_create_syn = torch.optim.SGD([syn_data, ], lr=0.1, momentum=0.5)


for k in range(K):
    model = initialize_model()
    model.train()
    # Explicit lr: older PyTorch releases have no default learning rate for SGD.
    optimizer_model = torch.optim.SGD(model.parameters(), lr=1e-3)
    optimizer_model.zero_grad()
    model_parameters = list(model.parameters())

    for t in range(T):  # Epochs

        for batch_real, batch_syn in sample_batch(real_data, syn_data):
            # --- Real batch: loss and per-parameter gradients (the matching targets) ---
            x_real = batch_real['input_ids']
            attn_mask_real = batch_real['attention_mask']
            # Padded positions (attention_mask == 0) get label -100 so the LM loss ignores them.
            y_real = x_real.masked_fill(attn_mask_real == 0, -100)
            out_real = model(x_real, attention_mask=attn_mask_real, labels=y_real)
            # With `labels` given, the HF model returns its LM loss in `.loss`; the output
            # object itself is not a tensor and cannot be passed to CrossEntropyLoss.
            loss_real = out_real.loss
            gradient_weights_real = torch.autograd.grad(loss_real, model_parameters)

            # --- Synthetic batch: optimise in embedding space ---
            x_syn = batch_syn['input_ids']
            y_syn = x_syn.clone()
            # detach() yields a fresh leaf tensor: changing requires_grad on the non-leaf
            # embedding output would raise, and optimizers only accept leaf tensors.
            syn_embed = model.get_input_embeddings()(x_syn).detach().requires_grad_(True)
            # TODO: this optimizer and the updated syn_embed are recreated every batch,
            # so embedding updates are not carried over to later batches/epochs.
            optimizer_data = torch.optim.SGD([syn_embed], lr=1e-3)
            out_syn = model(inputs_embeds=syn_embed, labels=y_syn)
            loss_syn = out_syn.loss  # scored against the synthetic labels y_syn
            # create_graph=True keeps these parameter gradients differentiable w.r.t. syn_embed.
            gradient_weights_syn = torch.autograd.grad(loss_syn, model_parameters, create_graph=True)

            loss = match_loss(
                gradient_weights_syn,
                gradient_weights_real,
                device=DEVICE,
            )

            optimizer_data.zero_grad()
            # inputs=[syn_embed]: only the synthetic embeddings receive .grad; otherwise the
            # double-backward would also accumulate gradients into the model parameters.
            loss.backward(inputs=[syn_embed])
            optimizer_data.step()  # Update the synthetic embeddings

            # Obtain synthetic tokens: nearest vocabulary entry for each updated embedding.
            syn_embed_train = syn_embed.detach().clone()
            syn_data_train = decode_embeddings(model.get_input_embeddings(), syn_embed_train)

            # Now we can update our network using synthetic tokens

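## Crux of the algorithm: one gradient-matching step

The cell below repeats a single iteration of the inner training loop above, outside the loop, so each step can be inspected. It reuses `model`, `model_parameters`, `batch_real` and `batch_syn` left over from that loop.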

In [None]:
# NOTE: one iteration of the inner training loop, run at top level for inspection.
# It depends on `model`, `model_parameters`, `batch_real` and `batch_syn` left over
# from that loop (hidden state) and fails on a fresh kernel without them.

# --- Real batch: loss and per-parameter gradients (the matching targets) ---
x_real = batch_real['input_ids']
attn_mask_real = batch_real['attention_mask']
# Padded positions (attention_mask == 0) get label -100 so the LM loss ignores them.
y_real = x_real.masked_fill(attn_mask_real == 0, -100)
out_real = model(x_real, attention_mask=attn_mask_real, labels=y_real)
# With `labels` given, the HF model returns its LM loss in `.loss`; the output
# object itself is not a tensor and cannot be passed to CrossEntropyLoss.
loss_real = out_real.loss
gradient_weights_real = torch.autograd.grad(loss_real, model_parameters)

# --- Synthetic batch: optimise in embedding space ---
x_syn = batch_syn['input_ids']
y_syn = x_syn.clone()
# detach() yields a fresh leaf tensor: changing requires_grad on the non-leaf
# embedding output would raise, and optimizers only accept leaf tensors.
syn_embed = model.get_input_embeddings()(x_syn).detach().requires_grad_(True)
optimizer_data = torch.optim.SGD([syn_embed], lr=1e-3)
out_syn = model(inputs_embeds=syn_embed, labels=y_syn)
loss_syn = out_syn.loss  # scored against the synthetic labels y_syn
# create_graph=True keeps these parameter gradients differentiable w.r.t. syn_embed.
gradient_weights_syn = torch.autograd.grad(loss_syn, model_parameters, create_graph=True)

loss = match_loss(
    gradient_weights_syn,
    gradient_weights_real,
    device=DEVICE,
)

optimizer_data.zero_grad()
# inputs=[syn_embed]: only the synthetic embeddings receive .grad; otherwise the
# double-backward would also accumulate gradients into the model parameters.
loss.backward(inputs=[syn_embed])
optimizer_data.step()  # Update the synthetic embeddings

# Obtain synthetic tokens: nearest vocabulary entry for each updated embedding.
syn_embed_train = syn_embed.detach().clone()
syn_data_train = decode_embeddings(model.get_input_embeddings(), syn_embed_train)

# Now we can update our network using synthetic tokens