In [13]:
#!pip install transformers -U

In [1]:
import sys
import os
from datasets import load_dataset
import torch
from torch.utils.data import DataLoader
import random
import torch.nn.functional as F
import importlib
import numpy as np

In [2]:
# Add the correct path to the local transformers directory
local_path = os.path.abspath('../src/')
print("Adding path:", local_path)  # Verify the path to be added
sys.path.insert(0, local_path)

Adding path: /Users/johnschroter/IdeaProjects/Sigma-GPT/src


Creating local Path to files

Confirming local copies are being used

In [3]:
# Import your modified GPT2 classes
from transformers.models.gpt2.tokenization_gpt2 import *
from transformers.models.gpt2.modeling_gpt2 import *

# Verify that the modules are being loaded from the correct path
import transformers.models.gpt2.tokenization_gpt2
import transformers.models.gpt2.modeling_gpt2

print(transformers.models.gpt2.tokenization_gpt2.__file__)  # Should point to your local file
print(transformers.models.gpt2.modeling_gpt2.__file__)  # Should point to your local file

/Users/johnschroter/IdeaProjects/Sigma-GPT/src/transformers/models/gpt2/tokenization_gpt2.py
/Users/johnschroter/IdeaProjects/Sigma-GPT/src/transformers/models/gpt2/modeling_gpt2.py


Randomly initilizing sigma-gpt

In [4]:
# Initialize the tokenizer (pre-trained vocab is fine for tokenizer)
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
# Initialize the configuration with random parameters
config = GPT2Config()

# Initialize the model with the custom configuration
#model = CustomGPT2LMHeadModel(config)
model = CustomGPT2LMHeadModel.from_pretrained('gpt2')

# Initialize weights randomly
#model.init_weights()

Some weights of CustomGPT2LMHeadModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['lm_head.weight', 'wte.LayerNorm.bias', 'wte.LayerNorm.weight', 'wte.next_position_embeddings.weight', 'wte.position_embeddings.weight', 'wte.word_embeddings.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


load in dataset

In [18]:
# Load Wikitext-2 dataset
dataset = load_dataset("wikitext", "wikitext-2-raw-v1")

# Preprocess the dataset
def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=32)

tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=["text"])
train_dataset = tokenized_datasets["train"]
eval_dataset = tokenized_datasets["validation"]

Map:   0%|          | 0/4358 [00:00<?, ? examples/s]

Map:   0%|          | 0/36718 [00:00<?, ? examples/s]


KeyboardInterrupt



In [5]:
# Load Penn Treebank dataset
dataset = load_dataset("ptb_text_only")

# Preprocess the dataset
def tokenize_function(examples):
    return tokenizer(examples["sentence"], padding="max_length", truncation=True, max_length=32)

tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=["sentence"])
train_dataset = tokenized_datasets["train"]
eval_dataset = tokenized_datasets["validation"]



You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


In [6]:
class AdaptiveShuffle:
    def __init__(self, initial_shuffle_percentage=1.0, max_adjustment_per_epoch=0.05, performance_threshold=0.25):
        self.shuffle_percentage = initial_shuffle_percentage
        self.max_adjustment_per_epoch = max_adjustment_per_epoch
        self.performance_threshold = performance_threshold
        self.previous_loss = None

    def adjust_shuffle_percentage(self, current_loss):
        if self.previous_loss is not None:
            improvement = (self.previous_loss - current_loss) / self.previous_loss
            if improvement > self.performance_threshold:
                self.shuffle_percentage = min(self.shuffle_percentage + self.max_adjustment_per_epoch, 1.0)
            elif improvement < -self.performance_threshold:
                self.shuffle_percentage = max(self.shuffle_percentage - self.max_adjustment_per_epoch, 0.0)
        self.previous_loss = current_loss

    def get_current_shuffle_percentage(self):
        return self.shuffle_percentage



In [7]:
from torch.utils.data import Dataset

class ShuffledDataset(Dataset):
    def __init__(self, input_ids, position_ids, next_position_ids, attention_mask):
        self.input_ids = input_ids
        self.position_ids = position_ids
        self.next_position_ids = next_position_ids
        self.attention_mask = attention_mask

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return {
            'input_ids': torch.tensor(self.input_ids[idx], dtype=torch.long),
            'position_ids': torch.tensor(self.position_ids[idx], dtype=torch.long),
            'next_position_ids': torch.tensor(self.next_position_ids[idx], dtype=torch.long),
            'attention_mask': torch.tensor(self.attention_mask[idx], dtype=torch.long)
        }

In [8]:
# Function to shuffle a percentage of tokens within each sequence
def shuffle_with_positional_ids(dataset, shuffle_percentage):
    shuffled_input_ids_list = []
    shuffled_pos_ids_list = []
    next_pos_ids_list = []
    attention_mask_list = []

    for example in dataset:
        input_ids = example['input_ids']
        attention_mask = example['attention_mask']

        # Calculate the number of tokens to shuffle
        seq_length = len(input_ids)
        num_shuffled_tokens = int(seq_length * shuffle_percentage)

        # Get indices to shuffle
        indices = list(range(seq_length))
        indices_to_shuffle = np.random.choice(indices, num_shuffled_tokens, replace=False)

        # Create a permutation for the selected indices
        permutation = np.random.permutation(num_shuffled_tokens)

        # Create shuffled input_ids, pos_ids, and attention_mask
        shuffled_input_ids = input_ids.copy()
        pos_ids = list(range(seq_length))
        shuffled_pos_ids = pos_ids.copy()
        shuffled_attention_mask = attention_mask.copy()

        for i, idx in enumerate(indices_to_shuffle):
            shuffled_input_ids[idx] = input_ids[indices_to_shuffle[permutation[i]]]
            shuffled_pos_ids[idx] = pos_ids[indices_to_shuffle[permutation[i]]]
            shuffled_attention_mask[idx] = attention_mask[indices_to_shuffle[permutation[i]]]

        # Create the next shuffled pos ids
        next_pos_ids = shuffled_pos_ids[1:] + [shuffled_pos_ids[0]]

        # Append to lists
        shuffled_input_ids_list.append(shuffled_input_ids)
        shuffled_pos_ids_list.append(shuffled_pos_ids)
        next_pos_ids_list.append(next_pos_ids)
        attention_mask_list.append(shuffled_attention_mask)

    return ShuffledDataset(
        shuffled_input_ids_list,
        shuffled_pos_ids_list,
        next_pos_ids_list,
        attention_mask_list
    )

In [9]:
from torch.utils.tensorboard import SummaryWriter
import time
import os
import shutil
import torch
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader

def plot_attention_heatmap(attention_matrix, epoch, layer_idx, head_idx):
    fig, ax = plt.subplots(figsize=(8, 6))
    cax = ax.matshow(attention_matrix, cmap='viridis')
    fig.colorbar(cax)
    ax.set_title(f'Epoch {epoch+1}, Layer {layer_idx+1}, Head {head_idx+1}')
    return fig

def train_model(model, tokenizer, adaptive_shuffle, train_dataset, eval_dataset, num_epochs=10, batch_size=64, log_dir='./logs'):

    optimizer = torch.optim.Adam(model.parameters(), lr=3e-3)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)
    loss_fn = torch.nn.CrossEntropyLoss()
    writer = SummaryWriter(log_dir=log_dir)

    def compute_accuracy(logits, labels):
        preds = torch.argmax(logits, dim=-1)
        correct = (preds == labels).float()
        return correct.sum() / correct.numel()

    # Log the model architecture
    dummy_input_ids = torch.randint(0, 100, (1, 64)).to(model.device)
    dummy_position_ids = torch.randint(0, 100, (1, 64)).to(model.device)
    dummy_next_position_ids = torch.randint(0, 100, (1, 64)).to(model.device)
    dummy_attention_mask = torch.ones((1, 64)).to(model.device)
    #writer.add_graph(model, (dummy_input_ids, dummy_position_ids, dummy_next_position_ids, dummy_attention_mask))

    for epoch in range(num_epochs):
        shuffle_percentage = adaptive_shuffle.get_current_shuffle_percentage()
        print(f"Epoch {epoch + 1}: Shuffle Percentage={shuffle_percentage}")

        # Shuffle the sequences based on the current shuffle percentage
        shuffled_train_dataset = shuffle_with_positional_ids(train_dataset, shuffle_percentage)
        train_loader = DataLoader(shuffled_train_dataset, batch_size=batch_size, shuffle=True)


        # Select the first batch for logging attention heatmaps
        first_batch = next(iter(train_loader))
        fixed_input_ids = first_batch['input_ids'].to(model.device)
        fixed_position_ids = first_batch['position_ids'].to(model.device)
        fixed_next_position_ids = first_batch['next_position_ids'].to(model.device)
        fixed_attention_mask = first_batch['attention_mask'].to(model.device)
        
        
        model.train()
        total_loss = 0
        total_correct = 0
        total_samples = 0
        start_time = time.time()
        
        count = 0
        for batch_idx, batch in enumerate(train_loader):
            optimizer.zero_grad()
            # Get input and target sequences
            input_ids = batch['input_ids'].to(model.device)
            position_ids = batch['position_ids'].to(model.device)
            next_position_ids = batch['next_position_ids'].to(model.device)
            attention_mask = batch['attention_mask'].to(model.device)
            # Forward pass
            outputs = model(input_ids=input_ids, position_ids=position_ids, next_position_ids=next_position_ids, attention_mask=attention_mask, output_attentions=True)
            logits = outputs.logits
            # Compute loss
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = input_ids[..., 1:].contiguous()
            loss = loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            # Compute accuracy
            total_correct += compute_accuracy(shift_logits, shift_labels).item() * shift_labels.numel()
            total_samples += shift_labels.numel()


            if batch_idx % (1 * 1) == 0:
                print(f"Epoch {epoch + 1}, Batch {batch_idx}, Loss: {loss.item()}")
            
            count = count + 1
            if count > 5:
                break
        # Log gradient norms
        total_norm = 0
        for p in model.parameters():
            if p.grad is not None:
                param_norm = p.grad.data.norm(2)
                total_norm += param_norm.item() ** 2
        total_norm = total_norm ** (1. / 2)
        writer.add_scalar('Gradient Norm/Train', total_norm, epoch)

        # Log learning rate
        writer.add_scalar('Learning Rate', scheduler.get_last_lr()[0], epoch)
        # Log attention heatmaps
        with torch.no_grad():
            model.eval()
            fixed_outputs = model(input_ids=fixed_input_ids, position_ids=fixed_position_ids, next_position_ids=fixed_next_position_ids, attention_mask=fixed_attention_mask, output_attentions=True)
            for layer_idx, layer_attention in enumerate(fixed_outputs.attentions):
                for head_idx, head_attention in enumerate(layer_attention[0]):
                    attention_matrix = head_attention.detach().cpu().numpy()
                    fig = plot_attention_heatmap(attention_matrix, epoch, layer_idx, head_idx)
                    writer.add_figure(f'Attention/Layer_{layer_idx+1}_Head_{head_idx+1}', fig, epoch)
                    
        scheduler.step()
        average_loss = total_loss / len(train_loader)
        train_accuracy = total_correct / total_samples
        writer.add_scalar('Loss/Train', average_loss, epoch)
        writer.add_scalar('Accuracy/Train', train_accuracy, epoch)
        
        adaptive_shuffle.adjust_shuffle_percentage(average_loss)
        print(f"Epoch {epoch + 1}: Average Loss={average_loss}, Train Accuracy={train_accuracy}")

        # Evaluation part
        model.eval()
        eval_loss = 0
        eval_correct = 0
        eval_samples = 0
        shuffled_eval_dataset = shuffle_with_positional_ids(eval_dataset, shuffle_percentage)
        eval_loader = DataLoader(shuffled_eval_dataset, batch_size=batch_size, shuffle=False)
        
        eval_count = 0
        with torch.no_grad():
            for batch in eval_loader:
                input_ids = batch['input_ids'].to(model.device)
                position_ids = batch['position_ids'].to(model.device)
                next_position_ids = batch['next_position_ids'].to(model.device)
                attention_mask = batch['attention_mask'].to(model.device)
                labels = input_ids.clone()
                outputs = model(input_ids=input_ids, position_ids=position_ids, next_position_ids=next_position_ids, attention_mask=attention_mask)
                logits = outputs.logits
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = input_ids[..., 1:].contiguous()
                loss = loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
                eval_loss += loss.item()
                eval_correct += compute_accuracy(shift_logits, shift_labels).item() * shift_labels.numel()
                eval_samples += shift_labels.numel()

                eval_count = eval_count +1
                if eval_count > 5:
                    break

        average_eval_loss = eval_loss / len(eval_loader)
        eval_accuracy = eval_correct / eval_samples
        writer.add_scalar('Loss/Eval', average_eval_loss, epoch)
        writer.add_scalar('Accuracy/Eval', eval_accuracy, epoch)
        print(f"Epoch {epoch + 1}: Evaluation Loss={average_eval_loss}, Eval Accuracy={eval_accuracy}")

        epoch_time = time.time() - start_time
        writer.add_scalar('Time/Epoch', epoch_time, epoch)
        print(f"Epoch {epoch + 1}: Time Taken={epoch_time}s")

        # Log weight and bias histograms
        for name, param in model.named_parameters():
            writer.add_histogram(f"{name}/weight", param, epoch)
            if param.grad is not None:
                writer.add_histogram(f"{name}/grad", param.grad, epoch)

    print("Training completed")
    writer.close()

adaptive_shuffle = AdaptiveShuffle()
train_model(model, tokenizer, adaptive_shuffle, train_dataset, eval_dataset)


Epoch 1: Shuffle Percentage=1.0
Epoch 1, Batch 0, Loss: 13.54886245727539
Epoch 1, Batch 1, Loss: 13.046292304992676
Epoch 1, Batch 2, Loss: 12.963492393493652
Epoch 1, Batch 3, Loss: 8.204950332641602
Epoch 1, Batch 4, Loss: 6.171317100524902
Epoch 1, Batch 5, Loss: 6.219213962554932
Epoch 1: Average Loss=0.09141964825453366, Train Accuracy=0.14474126455994943
Epoch 1: Evaluation Loss=0.7370542490257407, Eval Accuracy=0.25646841526031494
Epoch 1: Time Taken=41.761093854904175s
Epoch 2: Shuffle Percentage=1.0
Epoch 2, Batch 0, Loss: 6.213573455810547
Epoch 2, Batch 1, Loss: 6.324640274047852
Epoch 2, Batch 2, Loss: 6.55735969543457
Epoch 2, Batch 3, Loss: 6.454240322113037
Epoch 2, Batch 4, Loss: 6.219867706298828
Epoch 2, Batch 5, Loss: 6.089297771453857
Epoch 2: Average Loss=0.05753644259142658, Train Accuracy=0.20220094298322996
Epoch 2: Evaluation Loss=0.7239555322899008, Eval Accuracy=0.17977150281270346
Epoch 2: Time Taken=43.42031502723694s
Epoch 3: Shuffle Percentage=1.0
Epoch 

KeyboardInterrupt: 

In [None]:
# Shuffle the sequences based on the current shuffle percentage
shuffled_train_dataset = shuffle_with_positional_ids(train_dataset, 0)

# Create DataLoader
train_loader = DataLoader(shuffled_train_dataset, batch_size=32, shuffle=True)

In [None]:
count = 0
for batch_idx, batch in enumerate(train_loader):
    count = count + 1


In [None]:
count

In [None]:
shuffled_train_dataset[14]

In [None]:
model.att