In [1]:
import torch
import torch.nn as nn
from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer
from transformers.models.gpt2.modeling_gpt2 import GPT2Block
from flash_attn_interface import flash_attention
from torch.utils.data import DataLoader, Dataset
from glob import glob
from tqdm import tqdm
import time
from transformers import AdamW
import os
from torch.cuda.amp import autocast, GradScaler
torch.manual_seed(0)
class FlashAttentionLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.n_embd / config.num_attention_heads)
        self.query = nn.Linear(config.n_embd, config.n_embd)
        self.key = nn.Linear(config.n_embd, config.n_embd)
        self.value = nn.Linear(config.n_embd, config.n_embd)

    def forward(self, hidden_states, attention_mask=None):
        batch_size, seq_length, hidden_size = hidden_states.size()

        # Apply linear transformations and reshape for multi-head attention
        q = self.query(hidden_states).view(batch_size, seq_length, self.num_attention_heads, self.attention_head_size)
        k = self.key(hidden_states).view(batch_size, seq_length, self.num_attention_heads, self.attention_head_size)
        v = self.value(hidden_states).view(batch_size, seq_length, self.num_attention_heads, self.attention_head_size)

        # Transpose to get the shape (batch_size, num_heads, seq_length, head_size)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Call flash_attention
        attention_output = flash_attention(q, k, v, dropout_prob=0.0, causal=False)
        # Reshape attention output back to (batch_size, seq_length, hidden_size)
        attention_output = attention_output.transpose(1, 2).contiguous().view(batch_size, seq_length, hidden_size)
        return (attention_output,)

class CustomGPT2Block(GPT2Block):
    def __init__(self, config):
        super().__init__(config)
        # Replace the standard attention with FlashAttentionLayer
        self.attn = FlashAttentionLayer(config)

    def forward(self, hidden_states, layer_past=None, attention_mask=None, head_mask=None, use_cache=False, output_attentions=False):
        # Attention block
        attn_outputs = self.attn(hidden_states, attention_mask)
        a = attn_outputs[0]  # Output of the flash attention

        # Feed Forward block
        mlp_output = self.mlp(self.ln_2(a))
        hidden_states = mlp_output + a

        outputs = (hidden_states,) + attn_outputs[1:]  # Add attention outputs if they are present

        return outputs

class CustomGPT2Model(GPT2LMHeadModel):
    def __init__(self, config):
        super().__init__(config)
        # Replace standard GPT2 blocks with CustomGPT2Block
        self.h = nn.ModuleList([CustomGPT2Block(config) for _ in range(config.n_layer)])
        # Initialize the embedding layers
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        default_layer_norm_eps = 1e-5

        # Initialize the final layer normalization with the default epsilon
        self.ln_f = nn.LayerNorm(config.n_embd, eps=default_layer_norm_eps)

        # Initialize the language modeling head
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, labels=None, use_cache=None, output_attentions=None, output_hidden_states=None, return_dict=None):
        # Prepare inputs to the model
        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])
        batch_size = input_ids.shape[0]
        if input_ids is not None:
            input_shape = input_ids.size()
            device = input_ids.device
        else:
            input_shape = inputs_embeds.size()[:-1]
            device = inputs_embeds.device

        if position_ids is None:
            # Create default position_ids if None provided
            position_ids = torch.arange(0, input_shape[-1], dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])

        if attention_mask is None:
            attention_mask = torch.ones(batch_size, input_shape[-1]).to(input_ids.device)

        # Get embeddings from GPT2Model
        inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        hidden_states = inputs_embeds + position_embeds

        for block in self.h:
            outputs = block(hidden_states, attention_mask=attention_mask)
            hidden_states = outputs[0]

        # Final layer normalization and head
        hidden_states = self.ln_f(hidden_states)
        logits = self.lm_head(hidden_states)

        return logits


# Initialize the tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# Retrieve the correct vocabulary size from the tokenizer
correct_vocab_size = tokenizer.vocab_size

# Update your GPT2 configuration with the correct vocabulary size
config = GPT2Config(
    vocab_size=correct_vocab_size,
    n_positions=1024,
    n_ctx=1024,
    n_embd=768,
    n_layer=12,
    n_head=12
)
trad_model = GPT2LMHeadModel(config)
# Initialize the custom model with Flash Attention
custom_model = CustomGPT2Model(config)
traditional_params = set(trad_model.state_dict().keys())
custom_params = set(custom_model.state_dict().keys())

# Find common parameters
common_params = traditional_params & custom_params

# Copy only common parameters
common_state_dict = {name: trad_model.state_dict()[name] for name in common_params}
custom_model.load_state_dict(common_state_dict, strict=False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Transfer the custom_model to the GPU (if available)
custom_model.to(device)

# Define your TextFolderDataset
# Define your TextFolderDataset
class TextFolderDataset(Dataset):
    def __init__(self, file_directory, file_pattern, max_length):
        self.filepaths = glob(os.path.join(file_directory, file_pattern))
        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        self.tokenizer.pad_token = self.tokenizer.eos_token  # Set pad_token
        self.max_length = max_length

    def __len__(self):
        return len(self.filepaths)

    def __getitem__(self, idx):
        with open(self.filepaths[idx], 'r', encoding='utf-8') as file:
            text = file.read()

        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True, 
            max_length=self.max_length, 
            truncation=True, 
            padding='max_length', 
            return_tensors='pt'
        )

        input_ids = encoding['input_ids'][0]
        attention_mask = encoding['attention_mask'][0]
        labels = input_ids.clone()

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }

# Create your dataset and data loader
file_directory = "openwebtext/"
file_pattern = "urlsf_subset01-3[2,3]*"
max_length = 512
dataset = TextFolderDataset(file_directory, file_pattern, max_length)
train_dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

# Training Loop

from tqdm import tqdm

optimizer = AdamW(custom_model.parameters(), lr=5e-5)
accumulation_steps = 2
num_epochs = 50

start_time = time.time()
scaler = GradScaler()
# Define the loss function
loss_fn = nn.CrossEntropyLoss()

# Training loop
for epoch in tqdm(range(num_epochs)):
    total_loss = 0
    optimizer.zero_grad()
    for step, batch in enumerate(train_dataloader):
        # Transfer the data to the GPU
        inputs = batch['input_ids'].to(device)
        labels = batch['labels'].to(device)

        # Forward pass
        with autocast():
            logits = custom_model(inputs)
            # Reshape logits and labels for CrossEntropyLoss
            # logits shape: [batch_size, seq_length, vocab_size]
            # labels shape: [batch_size, seq_length]
            loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))

        # Backward pass and optimization
        scaler.scale(loss).backward()
        if (step + 1) % accumulation_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        total_loss += loss.item()

    avg_loss = total_loss / (len(train_dataloader) / accumulation_steps)
    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")

    if avg_loss < 1.5:
        break


end_time = time.time()
print(f"The time needed for training is {end_time - start_time} seconds.")


  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda


  2%|█▋                                                                                  | 1/50 [00:26<21:22, 26.17s/it]

Epoch 1/50, Average Loss: 19.2852


  4%|███▎                                                                                | 2/50 [00:48<19:12, 24.01s/it]

Epoch 2/50, Average Loss: 13.0474


  6%|█████                                                                               | 3/50 [01:11<18:17, 23.35s/it]

Epoch 3/50, Average Loss: 7.5801


  8%|██████▋                                                                             | 4/50 [01:33<17:38, 23.01s/it]

Epoch 4/50, Average Loss: 6.0390


 10%|████████▍                                                                           | 5/50 [01:56<17:08, 22.85s/it]

Epoch 5/50, Average Loss: 5.9575


 12%|██████████                                                                          | 6/50 [02:18<16:37, 22.67s/it]

Epoch 6/50, Average Loss: 5.9379


 14%|███████████▊                                                                        | 7/50 [02:41<16:12, 22.62s/it]

Epoch 7/50, Average Loss: 5.8414


 16%|█████████████▍                                                                      | 8/50 [03:03<15:51, 22.64s/it]

Epoch 8/50, Average Loss: 5.7599


 18%|███████████████                                                                     | 9/50 [03:26<15:25, 22.57s/it]

Epoch 9/50, Average Loss: 5.7190


 20%|████████████████▌                                                                  | 10/50 [03:48<14:57, 22.44s/it]

Epoch 10/50, Average Loss: 5.6318


 22%|██████████████████▎                                                                | 11/50 [04:10<14:33, 22.40s/it]

Epoch 11/50, Average Loss: 5.5698


 24%|███████████████████▉                                                               | 12/50 [04:33<14:12, 22.42s/it]

Epoch 12/50, Average Loss: 5.5100


 26%|█████████████████████▌                                                             | 13/50 [04:55<13:50, 22.43s/it]

Epoch 13/50, Average Loss: 5.4351


 28%|███████████████████████▏                                                           | 14/50 [05:17<13:25, 22.39s/it]

Epoch 14/50, Average Loss: 5.3676


 30%|████████████████████████▉                                                          | 15/50 [05:40<13:03, 22.39s/it]

Epoch 15/50, Average Loss: 5.3104


 32%|██████████████████████████▌                                                        | 16/50 [06:02<12:43, 22.47s/it]

Epoch 16/50, Average Loss: 5.2441


 34%|████████████████████████████▏                                                      | 17/50 [06:25<12:21, 22.48s/it]

Epoch 17/50, Average Loss: 5.1822


 36%|█████████████████████████████▉                                                     | 18/50 [06:47<11:58, 22.44s/it]

Epoch 18/50, Average Loss: 5.1238


 38%|███████████████████████████████▌                                                   | 19/50 [07:09<11:32, 22.35s/it]

Epoch 19/50, Average Loss: 5.0694


 40%|█████████████████████████████████▏                                                 | 20/50 [07:32<11:13, 22.46s/it]

Epoch 20/50, Average Loss: 5.0124


 42%|██████████████████████████████████▊                                                | 21/50 [07:55<10:53, 22.52s/it]

Epoch 21/50, Average Loss: 4.9566


 44%|████████████████████████████████████▌                                              | 22/50 [08:18<10:32, 22.58s/it]

Epoch 22/50, Average Loss: 4.9050


 46%|██████████████████████████████████████▏                                            | 23/50 [08:40<10:09, 22.58s/it]

Epoch 23/50, Average Loss: 4.8510


 48%|███████████████████████████████████████▊                                           | 24/50 [09:03<09:47, 22.61s/it]

Epoch 24/50, Average Loss: 4.8028


 50%|█████████████████████████████████████████▌                                         | 25/50 [09:26<09:25, 22.64s/it]

Epoch 25/50, Average Loss: 4.7303


 52%|███████████████████████████████████████████▏                                       | 26/50 [09:48<09:04, 22.68s/it]

Epoch 26/50, Average Loss: 15.4651


 54%|████████████████████████████████████████████▊                                      | 27/50 [10:11<08:42, 22.71s/it]

Epoch 27/50, Average Loss: 4.7066


 56%|██████████████████████████████████████████████▍                                    | 28/50 [10:34<08:20, 22.75s/it]

Epoch 28/50, Average Loss: 4.2944


 58%|████████████████████████████████████████████████▏                                  | 29/50 [10:57<07:57, 22.74s/it]

Epoch 29/50, Average Loss: 4.0510


 60%|█████████████████████████████████████████████████▊                                 | 30/50 [11:19<07:35, 22.76s/it]

Epoch 30/50, Average Loss: 3.8031


 62%|███████████████████████████████████████████████████▍                               | 31/50 [11:42<07:09, 22.63s/it]

Epoch 31/50, Average Loss: 3.7104


 64%|█████████████████████████████████████████████████████                              | 32/50 [12:04<06:46, 22.59s/it]

Epoch 32/50, Average Loss: 3.5743


 66%|██████████████████████████████████████████████████████▊                            | 33/50 [12:27<06:23, 22.59s/it]

Epoch 33/50, Average Loss: 3.5211


 68%|████████████████████████████████████████████████████████▍                          | 34/50 [12:49<06:01, 22.61s/it]

Epoch 34/50, Average Loss: 3.4405


 70%|██████████████████████████████████████████████████████████                         | 35/50 [13:13<05:41, 22.80s/it]

Epoch 35/50, Average Loss: 3.3907


 72%|███████████████████████████████████████████████████████████▊                       | 36/50 [13:35<05:18, 22.78s/it]

Epoch 36/50, Average Loss: 3.3493


 74%|█████████████████████████████████████████████████████████████▍                     | 37/50 [13:58<04:57, 22.85s/it]

Epoch 37/50, Average Loss: 3.2985


 76%|███████████████████████████████████████████████████████████████                    | 38/50 [14:21<04:34, 22.88s/it]

Epoch 38/50, Average Loss: 3.2724


 78%|████████████████████████████████████████████████████████████████▋                  | 39/50 [14:45<04:12, 22.99s/it]

Epoch 39/50, Average Loss: 3.2332


 80%|██████████████████████████████████████████████████████████████████▍                | 40/50 [15:08<03:50, 23.02s/it]

Epoch 40/50, Average Loss: 3.1991


 82%|████████████████████████████████████████████████████████████████████               | 41/50 [15:31<03:27, 23.03s/it]

Epoch 41/50, Average Loss: 3.1735


 84%|█████████████████████████████████████████████████████████████████████▋             | 42/50 [15:54<03:03, 22.98s/it]

Epoch 42/50, Average Loss: 3.1463


 86%|███████████████████████████████████████████████████████████████████████▍           | 43/50 [16:16<02:40, 22.87s/it]

Epoch 43/50, Average Loss: 3.1223


 88%|█████████████████████████████████████████████████████████████████████████          | 44/50 [16:39<02:16, 22.82s/it]

Epoch 44/50, Average Loss: 3.1007


 90%|██████████████████████████████████████████████████████████████████████████▋        | 45/50 [17:02<01:54, 22.82s/it]

Epoch 45/50, Average Loss: 3.0800


 92%|████████████████████████████████████████████████████████████████████████████▎      | 46/50 [17:25<01:31, 22.82s/it]

Epoch 46/50, Average Loss: 3.0640


 94%|██████████████████████████████████████████████████████████████████████████████     | 47/50 [17:47<01:07, 22.67s/it]

Epoch 47/50, Average Loss: 3.0469


 96%|███████████████████████████████████████████████████████████████████████████████▋   | 48/50 [18:09<00:45, 22.52s/it]

Epoch 48/50, Average Loss: 3.0325


 98%|█████████████████████████████████████████████████████████████████████████████████▎ | 49/50 [18:31<00:22, 22.41s/it]

Epoch 49/50, Average Loss: 3.0114


100%|███████████████████████████████████████████████████████████████████████████████████| 50/50 [18:53<00:00, 22.68s/it]

Epoch 50/50, Average Loss: 3.0106
The time needed for training is 1133.9099576473236 seconds.



