# To Do

Most of the work is done; only the following remains:

AutoUnit **training step** (https://pytorch.org/tnt/stable/framework/auto_unit.html)
- add W&B (Weights & Biases) login/logging
- add validation step


# Imports

In [None]:
import os
import torch
import torch.nn as nn
from datasets import load_dataset
from transformers import GPT2Tokenizer
from torch.utils.data import DataLoader, Subset
import torch.optim as optim
import wandb

# Cache locations for the tokenised training / validation datasets
train_file, valid_file = 'training_dataset.pt', 'validation_dataset.pt'


from torchtnt.framework.auto_unit import AutoUnit
from torchtnt.framework.fit import fit


In [None]:
# Start a Weights & Biases run; the wandb.log calls made during training and
# evaluation report to this run.
wandb.init(project="my_transformer_project")

# Variables

In [None]:
# Dataset (Hugging Face Hub repo id) and tokenisation / batching settings
data = "SouthernCrossAI/Project_Gutenberg_Australia"
sequence_length, batch_size = 512, 16

# Cache files for the datasets after the first (tokenising) transformation,
# so later runs can skip re-tokenising
train_file, valid_file = 'training_dataset.pt', 'validation_dataset.pt'

# Model configuration
# vocab_size is intended to match the 'gpt2' tokenizer loaded below — verify
# if the tokenizer changes.
vocab_size = 50257
n_embd, n_head = 512, 8
# NOTE(review): n_layer is not referenced by the model code in this notebook;
# the decoder depth actually comes from n_layer_decoder.
n_layer = 8
n_layer_decoder = 1

# Import Data

In [None]:
# Load the dataset identified by `data`. Called without `split`, this returns
# a DatasetDict; its 'train' split is re-split into train/test below.
ds = load_dataset(data)

# Split the dataset

In [None]:
# Hold out 20% of the original 'train' split for validation.
# A fixed seed makes the split reproducible across runs and kernel restarts,
# so a rebuilt cache always gets the same train/validation partition.
# NOTE(review): the tokenised datasets are cached to disk further down and
# only rebuilt when those files are missing, so changing the seed alone does
# not refresh an existing cache.
dataset = ds['train'].train_test_split(test_size=0.2, seed=42)
print(f"test = {len(dataset['test'])} and train = {len(dataset['train'])}")

# Tokenizer and function to prepare the data

In [None]:
# Load the pretrained GPT-2 tokenizer.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2', clean_up_tokenization_spaces=True)
# GPT-2 ships without a padding token; reuse EOS so padding='max_length' works.
tokenizer.pad_token = tokenizer.eos_token

In [None]:
def tokenize_function(dataset):
    """Tokenise the 'Paragraph' texts of a batch of rows.

    Intended for ``Dataset.map(..., batched=True)``, so ``dataset`` is a
    mapping whose 'Paragraph' entry is a list of strings. Every text is
    truncated/padded to exactly ``sequence_length`` tokens, and an attention
    mask is returned alongside the token ids.
    """
    encode_options = dict(
        truncation=True,
        padding='max_length',
        max_length=sequence_length,
        return_attention_mask=True,
    )
    return tokenizer(dataset['Paragraph'], **encode_options)

# Convert the datasets

In [None]:
# Tokenising the whole corpus is slow, so the tokenised splits are cached to
# `train_file` / `valid_file` and reused on later runs.
# NOTE(review): torch.save pickles the Hugging Face Dataset object itself;
# Dataset.save_to_disk / datasets.load_from_disk is the library's own,
# more portable persistence format — consider switching.
if os.path.exists(train_file) and os.path.exists(valid_file):
    # weights_only=False: these files hold a pickled Dataset, not tensors, and
    # PyTorch >= 2.6 defaults torch.load to weights_only=True, which refuses
    # arbitrary objects (requires PyTorch >= 1.13 for the argument).
    # Unpickling can execute code — only load cache files this notebook wrote.
    training_dataset = torch.load(train_file, weights_only=False)
    validation_dataset = torch.load(valid_file, weights_only=False)
    print("Loaded existing transformed datasets.")
else:
    # Tokenise both splits in batches.
    training_dataset = dataset['train'].map(tokenize_function, batched=True)
    validation_dataset = dataset['test'].map(tokenize_function, batched=True)

    # Expose only the model inputs, as torch tensors, so the DataLoader can
    # collate them into batch tensors.
    training_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask'])
    validation_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask'])

    # Save the transformed datasets for future use.
    torch.save(training_dataset, train_file)
    torch.save(validation_dataset, valid_file)
    print("Transformed datasets created and saved.")

# Dataloaders

In [None]:
# Train/evaluate on a small random subset of each split (presumably for quick
# iteration — raise num_samples for a real training run).
num_samples = 20
# .tolist() turns the sampled indices into plain Python ints. Subset forwards
# each stored index unchanged to the Hugging Face Dataset, whose indexing is
# documented for int/slice/list keys rather than 0-d tensors.
train_indices = torch.randperm(len(training_dataset))[:num_samples].tolist()
valid_indices = torch.randperm(len(validation_dataset))[:num_samples].tolist()

small_train_subset = Subset(training_dataset, train_indices)
small_valid_subset = Subset(validation_dataset, valid_indices)
# Shuffle only the training data; keep validation order fixed.
training_dataloader = DataLoader(small_train_subset, batch_size=batch_size, shuffle=True)
validation_dataloader = DataLoader(small_valid_subset, batch_size=batch_size, shuffle=False)

In [None]:
# Similarly, for the validation DataLoader
for batch in validation_dataloader:
    print("Batch input_ids shape:", batch['input_ids'].shape)
    print("Batch attention_mask shape:", batch['attention_mask'].shape)
    break  # Exit after printing the size of the first batch


# Model

Contextual embedding

### Embeddings

In [None]:
class Embeddings(nn.Module):
    """Sum of learned token embeddings and learned absolute position embeddings.

    Args:
        num_tokens: vocabulary size; defaults to the module-level ``vocab_size``.
        embed_dim: embedding width; defaults to the module-level ``n_embd``.
        max_positions: longest supported sequence; defaults to the module-level
            ``sequence_length``.

    With no arguments, ``Embeddings()`` reads the module-level configuration
    exactly as before. The arguments allow building the layer with other
    sizes (e.g. small configurations for tests).
    """

    def __init__(self, num_tokens=None, embed_dim=None, max_positions=None):
        super(Embeddings, self).__init__()
        num_tokens = vocab_size if num_tokens is None else num_tokens
        embed_dim = n_embd if embed_dim is None else embed_dim
        max_positions = sequence_length if max_positions is None else max_positions
        # Creation order (token table first) is kept so seeded initialisation
        # matches the original.
        self.token_embedding = nn.Embedding(num_tokens, embed_dim)
        self.position_embedding = nn.Embedding(max_positions, embed_dim)

    def forward(self, x):
        """Embed token ids ``x`` of shape (batch, seq_len) into (batch, seq_len, embed_dim).

        Raises:
            ValueError: if ``seq_len`` exceeds the number of position embeddings.
        """
        seq_len = x.size(1)
        max_positions = self.position_embedding.num_embeddings
        if seq_len > max_positions:
            raise ValueError(
                f"sequence length {seq_len} exceeds the {max_positions} "
                f"available position embeddings"
            )
        tokens = self.token_embedding(x)
        # (1, seq_len, embed_dim) broadcasts over the batch dimension, so no
        # per-batch copy of the position ids is needed.
        positions = self.position_embedding(
            torch.arange(seq_len, device=x.device)
        ).unsqueeze(0)
        return tokens + positions

### Transformer

In [None]:
# Transformer block with causal self-attention that also honours a key padding mask.
class TransformerBlock(nn.Module):
    """Self-attention + MLP block with post-LayerNorm residual connections.

    Args:
        embed_dim: model width; defaults to the module-level ``n_embd``.
        num_heads: number of attention heads; defaults to the module-level ``n_head``.
    """

    def __init__(self, embed_dim=None, num_heads=None):
        super(TransformerBlock, self).__init__()
        embed_dim = n_embd if embed_dim is None else embed_dim
        num_heads = n_head if num_heads is None else num_heads
        # Submodules are created in the original order so seeded initialisation
        # and state_dict keys are unchanged.
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.ln1 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            nn.ReLU(),
            nn.Linear(4 * embed_dim, embed_dim),
        )
        self.ln2 = nn.LayerNorm(embed_dim)

    def forward(self, x, key_padding_mask=None, causal=True):
        """Apply the block to ``x`` of shape (batch, seq_len, embed_dim).

        Args:
            x: input activations, batch-first.
            key_padding_mask: optional (batch, seq_len) bool tensor; True marks a
                key position (padding) that attention must ignore.
            causal: when True (default), position i may attend only to positions
                <= i. This is required for next-token training; without it each
                position can see the token it is asked to predict.

        Returns:
            Tensor with the same shape as ``x``.
        """
        seq_len = x.size(1)
        causal_mask = None
        if causal:
            # True above the diagonal masks out "future" key positions.
            causal_mask = torch.triu(
                torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device),
                diagonal=1,
            )

        # The attention module is not batch_first: it expects (seq_len, batch, embed_dim).
        x = x.transpose(0, 1)
        # need_weights=False: the attention weights are discarded anyway.
        attn_output, _ = self.attn(
            x, x, x,
            key_padding_mask=key_padding_mask,
            attn_mask=causal_mask,
            need_weights=False,
        )
        x = self.ln1(x + attn_output)

        # Back to (batch, seq_len, embed_dim) for the position-wise MLP.
        x = x.transpose(0, 1)
        x = self.ln2(x + self.mlp(x))
        return x


### Model and Forward Pass

In [None]:

class BabyJoey(nn.Module):
    """Language model: embeddings -> stacked TransformerBlocks -> LayerNorm -> vocabulary projection.

    The number of blocks is ``n_layer_decoder``; the output widths come from the
    module-level ``n_embd`` and ``vocab_size``.
    """

    def __init__(self):
        super().__init__()
        # Token + position embeddings.
        self.embeddings = Embeddings()
        # One TransformerBlock per decoder layer.
        self.decoder_blocks = nn.ModuleList(
            TransformerBlock() for _ in range(n_layer_decoder)
        )
        # Final normalisation and projection to vocabulary logits.
        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)

    def forward(self, x, attn_mask=None):
        """Map token ids ``x`` to per-position vocabulary logits.

        Despite its name, ``attn_mask`` is passed to every block as its
        ``key_padding_mask`` (the training unit supplies True at padded positions).
        """
        hidden = self.embeddings(x)
        for decoder_block in self.decoder_blocks:
            hidden = decoder_block(hidden, key_padding_mask=attn_mask)
        return self.head(self.ln_f(hidden))


In [None]:
# Initialize the BabyJoey model.
# NOTE(review): the training cell further down constructs a new BabyJoey and
# moves it to the device, so this instance (and the optimizer built from it in
# the next cell) is not the one that gets trained.
model = BabyJoey()

# Training Procedures

### Optimizer and loss function

In [None]:
# Define the loss function.
# NOTE(review): not used by BabyJoeyUnit, which creates its own
# nn.CrossEntropyLoss in its __init__.
loss_function = nn.CrossEntropyLoss()

# Define the optimizer (AdamW, lr=1e-4, weight_decay=1e-2).
# NOTE(review): BabyJoeyUnit.configure_optimizers_and_lr_scheduler builds its
# own AdamW with the same settings for the model it trains, so this optimizer
# is unused during training.
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)

### Training Loop Functions and Class

In [None]:
class BabyJoeyUnit(AutoUnit):
    """torchtnt AutoUnit that trains BabyJoey on shifted next-token prediction.

    AutoUnit calls ``compute_loss`` for both training and evaluation steps.
    The train loss is therefore logged to W&B only while the module is in
    training mode. The validation loss is logged from ``on_eval_step_end``
    using the loss AutoUnit already computed.
    """

    def __init__(self, module, device=None):
        super().__init__(module=module, device=device)
        # Padded target positions are set to this loss's ignore_index
        # (default -100) so they do not contribute to the loss.
        self.loss_fn = nn.CrossEntropyLoss()
        self.metrics = {}

    def compute_loss(self, state, data):
        """Return ``(loss, logits)`` for next-token prediction on one batch.

        ``data`` is a batch dict with 'input_ids' and 'attention_mask'
        (1 = real token, 0 = padding). The returned logits cover positions
        0..seq_len-2, aligned with targets 1..seq_len-1.
        """
        input_ids, attention_mask = data['input_ids'], data['attention_mask']

        # key_padding_mask convention: True = padded position to ignore.
        key_padding_mask = attention_mask == 0
        # NOTE(review): a row that is entirely padding (e.g. an empty paragraph)
        # leaves attention with no valid keys and can produce NaN; consider
        # filtering empty texts before tokenising.

        logits = self.module(input_ids, attn_mask=key_padding_mask)

        # Position t predicts token t + 1: drop the last logit and the first target.
        logits = logits[:, :-1, :].contiguous()
        # Padding is not real text: exclude padded targets from the loss.
        targets = input_ids[:, 1:].masked_fill(
            key_padding_mask[:, 1:], self.loss_fn.ignore_index
        )

        loss = self.loss_fn(logits.view(-1, logits.size(-1)), targets.reshape(-1))

        # This method also runs during evaluation; only report it as train loss
        # while the module is in training mode.
        if self.module.training:
            wandb.log({"train_loss": loss.item()})
        return loss, logits

    def on_train_end(self, state):
        """Save the trained module's weights to 'baby_joey_model.pth'."""
        super().on_train_end(state)
        torch.save(self.module.state_dict(), "baby_joey_model.pth")
        # Prints (in Chinese): "Model saved as 'baby_joey_model.pth'".
        print("模型已保存为 'baby_joey_model.pth'")

    def configure_optimizers_and_lr_scheduler(self, module):
        """Return an AdamW optimizer (lr=1e-4, weight_decay=1e-2) and no LR scheduler."""
        optimizer = optim.AdamW(module.parameters(), lr=1e-4, weight_decay=1e-2)
        return optimizer, None

    def on_eval_step_end(self, state, data, step, loss, outputs):
        """Record and log the validation loss computed for this eval step.

        ``loss`` is supplied by AutoUnit for this batch, so no second forward
        pass is needed. ``self.metrics['val_loss']`` holds the most recent batch's value.
        """
        val_loss = loss.item()
        self.metrics.update({'val_loss': val_loss})
        wandb.log({'validation_loss': val_loss})

### Running TNT training 

In [None]:
# NOTE(review): `train` is imported but unused here; `fit` drives the loop.
from torchtnt.framework.train import train

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# A fresh BabyJoey, moved to the chosen device before being handed to the unit.
model = BabyJoey().to(device)
baby_joey_unit = BabyJoeyUnit(module=model, device=device)

# Run training for two epochs, evaluating on the validation loader.
fit(
    baby_joey_unit,
    train_dataloader=training_dataloader,
    eval_dataloader=validation_dataloader,
    max_epochs=2,
)
    