In [1]:
from datasets import load_dataset
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
from torch.utils.data import DataLoader, TensorDataset
torch.cuda.empty_cache()

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the TweetEval dataset
tweet_eval_dataset = load_dataset('tweet_eval',"emoji")

# Accessing different splits
train_dataset = tweet_eval_dataset['train']
test_dataset = tweet_eval_dataset['test']
validation_dataset = tweet_eval_dataset['validation']

In [3]:
import pandas as pd
import re

# Concatenate datasets together
all_data = pd.concat([train_dataset.to_pandas(), validation_dataset.to_pandas(), test_dataset.to_pandas()])

# Drop all labels
all_data = all_data.drop(columns=['label'])
all_data['text'] = all_data['text']

def clean_text(text):
    # 1. Remove all characters except punctuation and English characters
    text = re.sub(r'[^a-zA-Z\s.,!?]', '', text)
    # 2. Remove all space at the beginning of the sentence
    text = text.lstrip()
    # 3. Remove all extra space
    text = re.sub(r'\s+', ' ', text)
    return text

all_data = all_data.head(100)
texts = all_data['text'].tolist()


texts = [clean_text(text) for text in texts]

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [5]:
class T5WithVAE(T5ForConditionalGeneration):
    def __init__(self, config):
        super().__init__(config)
        self.sequence_length = 128
        self.feature_size = config.d_model
        # here we assume that the input max length is 512, AND the latent size is 512
        self.latent_size = self.sequence_length *  self.feature_size
        # self.dense = torch.nn.Linear(self.latent_size, self.latent_size)
        self.to_style = torch.nn.Linear(self.latent_size // 4, self.latent_size // 2)
        
        self.lm_head = torch.nn.Linear(config.d_model, config.vocab_size, bias=False)
        
        self.ignore_index = -100  # usually the index for padding tokens in Hugging Face models
        self.vocab_size = config.vocab_size

        self.tie_weights()
        
    def reparameterize(self, mean, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mean)
    
    def compute_loss(self, lm_logits, labels):
        loss_fct = torch.nn.CrossEntropyLoss(ignore_index=self.ignore_index)
        logits_flat = lm_logits.view(-1, self.vocab_size)
        labels_flat = labels.view(-1)
        return loss_fct(logits_flat, labels_flat)
    
    def forward(self,input_ids=None, **kwargs):
        encoder_outputs = self.encoder(input_ids=input_ids,return_dict=True)
        latent_vector = encoder_outputs.last_hidden_state
        # flatten the latent vector
        latent_vector = latent_vector.view(latent_vector.size(0), -1)
        #latent_vector = self.dense(latent_vector) 
        
        # split the latent vector into three parts
        # content_vector have first 1/2 size of the latent vector
        # style_mean have second 1/4 size of the latent vector
        # style_var have last 1/4 size of the latent vector
        content_vector = latent_vector[:, :self.latent_size // 2]
        style_mean = latent_vector[:, self.latent_size // 2: self.latent_size // 2 + self.latent_size // 4]
        style_logvar  = latent_vector[:, self.latent_size // 2 + self.latent_size // 4:]
        
        # reparameterization trick
        style_vector = self.reparameterize(style_mean, style_logvar )
        style_vector = self.to_style(style_vector)
        
        # concatenate the content vector and style vector
        combined_vector = torch.cat([content_vector, style_vector], dim=1)
        combined_vector = combined_vector.view(combined_vector.size(0), self.sequence_length, self.feature_size)
        
        # decoder, the decoder_input_ids should be the same as the encoder input_ids
        decoder_outputs = self.decoder(input_ids =input_ids, encoder_hidden_states=combined_vector, return_dict=True)
        sequence_output = decoder_outputs[0]
        lm_logits = self.lm_head(sequence_output)
        
        # calculate the loss
        #lm_logits = decoder_outputs.logits
        reconstruction_loss = self.compute_loss(
            lm_logits, 
            input_ids)
        
        # Calculate KL divergence
        kl_loss = -0.5 * torch.sum(1 + style_logvar - style_mean.pow(2) - style_logvar.exp())
    
        # Combine the losses
        loss = reconstruction_loss + kl_loss
        
        return loss, decoder_outputs
    
    def generate(self, input_ids=None, **kwargs):
        decoder_outputs = self.decoder(input_ids=input_ids, return_dict=True)
        sequence_output = decoder_outputs[0]
        lm_logits = self.lm_head(sequence_output)
        return lm_logits
    
    
model = T5WithVAE.from_pretrained('t5-small').to(device)       

Some weights of T5WithVAE were not initialized from the model checkpoint at t5-small and are newly initialized: ['to_style.weight', 'to_style.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
tokenizer = T5Tokenizer.from_pretrained('t5-small')


# Tokenization
inputs = tokenizer(texts, return_tensors="pt", padding="max_length", truncation=True, max_length=128)
input_ids = inputs['input_ids']

# DataLoader
batch_size = 4
dataset = TensorDataset(input_ids)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [7]:
# Training
model.train()

# Define the number of epochs
epochs = 3

for epoch in range(epochs):
    # Tracking variables
    train_loss = 0
    for batch in loader:
        # Assuming that 'batch' is a tuple of (input_ids, labels)
        input_ids = batch[0].to(device)
        
        # Forward pass
        loss, outputs = model(input_ids=input_ids)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        
        # Update parameters
        optimizer.step()
        
        # Update training loss
        train_loss += loss.item()
    
    # Calculate the average loss over the training data
    avg_train_loss = train_loss / len(loader)
    
    # Print progress
    print(f"Epoch {epoch+1}/{epochs} | Train Loss: {avg_train_loss:.3f}")

Epoch 1/3 | Train Loss: 2400.233
Epoch 2/3 | Train Loss: 2084.003
Epoch 3/3 | Train Loss: 1773.296
