In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2LMHeadModel, GPT2Tokenizer, AdamW, get_linear_schedule_with_warmup
import json
import numpy as np
from tqdm import tqdm

class ArxivDataset(Dataset):
    def __init__(self, file_path, tokenizer, max_length=512, max_samples=1000):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = []
        
        print("Loading and processing dataset...")
        with open(file_path, 'r') as f:
            for i, line in enumerate(tqdm(f, desc="Loading papers")):
                if i >= max_samples:
                    break
                paper = json.loads(line)
                # Combine title and abstract with proper formatting
                text = f"Title: {paper['title']}\nAbstract: {paper['abstract']}"
                
                # Tokenize the text
                encodings = tokenizer(
                    text,
                    truncation=True,
                    padding="max_length",
                    max_length=max_length,
                    return_tensors="pt"
                )
                
                # Remove the batch dimension
                encodings = {key: val.squeeze(0) for key, val in encodings.items()}
                self.data.append(encodings)
    
    def __getitem__(self, idx):
        item = {key: val.clone() for key, val in self.data[idx].items()}
        item["labels"] = item["input_ids"].clone()
        return item
    
    def __len__(self):
        return len(self.data)

def train_model(model, train_dataloader, optimizer, scheduler, device, num_epochs=10):
    model.train()
    
    for epoch in range(num_epochs):
        total_loss = 0
        progress_bar = tqdm(train_dataloader, desc=f'Epoch {epoch + 1}')
        
        # Store gradients for balancing on every 10th epoch
        if (epoch + 1) % 2 == 0:
            param_gradients = []
            
        for batch_idx, batch in enumerate(progress_bar):
            # Move batch to device
            batch = {k: v.to(device) for k, v in batch.items()}
            
            # Clear gradients
            optimizer.zero_grad()
            
            # Forward pass
            outputs = model(**batch)
            loss = outputs.loss
            
            # Backward pass
            loss.backward()
            
            # On 10th epoch, collect gradients before updating
            if (epoch + 1) % 2 == 0:
                for param in model.parameters():
                    if param.grad is not None:
                        param_gradients.append(param.grad.abs().mean().item())
            
            # Clip gradients
            # torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            
            # Update parameters
            optimizer.step()
            scheduler.step()
            
            # Update progress bar
            total_loss += loss.item()
            avg_loss = total_loss / (batch_idx + 1)
            progress_bar.set_postfix({'avg_loss': avg_loss})
            
            # Save checkpoint every 1000 steps
            if (batch_idx + 1) % 1000 == 0:
                checkpoint = {
                    'epoch': epoch,
                    'batch_idx': batch_idx,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'loss': loss.item(),
                }
                torch.save(checkpoint, f'checkpoint_epoch{epoch}_batch{batch_idx}.pt')
        
        # After collecting gradients on 10th epoch, balance the weights
        if (epoch + 1) % 2 == 0:
            mean_gradient = sum(param_gradients) / len(param_gradients)
            
            with torch.no_grad():
                for param in model.parameters():
                    if param.grad is not None:
                        current_grad_mean = param.grad.abs().mean()
                        if current_grad_mean > 0:  # Avoid division by zero
                            # Calculate scaling factor
                            scale_factor = (mean_gradient / current_grad_mean)/10
                            
                            # Add random noise to determine sign
                            random_signs = torch.randint(0, 2, param.shape, device=param.device) * 2 - 1
                            
                            # Scale the weights and apply random signs
                            param.data *= scale_factor * random_signs.float()
        
        print(f"Average loss for epoch {epoch + 1}: {total_loss / len(train_dataloader)}")

def main():
    # Load GPT-2 XL model and tokenizer
    print("Loading GPT-2 XL model and tokenizer...")
    model_name = "gpt2-large"
    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    model = GPT2LMHeadModel.from_pretrained(model_name)
    
    # Set pad token
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = model.config.eos_token_id
    
    # Load local arXiv dataset
    dataset = ArxivDataset(
        file_path="arxiv-metadata-oai-snapshot.json",
        tokenizer=tokenizer,
        max_length=512,
        max_samples=1000  # Adjust this number as needed
    )
    
    # Create dataloader
    train_dataloader = DataLoader(
        dataset,
        batch_size=2,  # Adjust based on your GPU memory
        shuffle=True,
        num_workers=4  # Adjust based on your CPU cores
    )
    
    # Setup training parameters
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    model.to(device)
    
    # Optimizer and scheduler setup
    optimizer = torch.optim.SGD(model.parameters(), lr=5e-5)
    num_training_steps = len(train_dataloader) * 10  # 3 epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=100,
        num_training_steps=num_training_steps
    )
    
    # Train the model
    print("Starting training...")
    train_model(model, train_dataloader, optimizer, scheduler, device)
    
    # Save the trained model
    print("Saving model...")
    model.save_pretrained("./trained_gpt2_xl")
    tokenizer.save_pretrained("./trained_gpt2_xl")

if __name__ == "__main__":
    main()

Loading GPT-2 XL model and tokenizer...




Loading and processing dataset...


Loading papers: 1000it [00:01, 669.60it/s]


Using device: cuda
Starting training...


Epoch 1: 100%|██████████| 500/500 [03:03<00:00,  2.73it/s, avg_loss=3.13]


Average loss for epoch 1: 3.1321546308994295


Epoch 2: 100%|██████████| 500/500 [03:16<00:00,  2.55it/s, avg_loss=1.45]


Average loss for epoch 2: 1.4528525344133376


Epoch 3: 100%|██████████| 500/500 [03:02<00:00,  2.73it/s, avg_loss=7.06]


Average loss for epoch 3: 7.06100291800499


Epoch 4: 100%|██████████| 500/500 [03:16<00:00,  2.54it/s, avg_loss=5.43]


Average loss for epoch 4: 5.431151164531708


Epoch 5: 100%|██████████| 500/500 [03:00<00:00,  2.77it/s, avg_loss=187] 


Average loss for epoch 5: 187.3257192955017


Epoch 6: 100%|██████████| 500/500 [03:15<00:00,  2.56it/s, avg_loss=42.1]


Average loss for epoch 6: 42.148122616767886


Epoch 7: 100%|██████████| 500/500 [02:47<00:00,  2.98it/s, avg_loss=nan]  


Average loss for epoch 7: nan


Epoch 8: 100%|██████████| 500/500 [03:02<00:00,  2.75it/s, avg_loss=nan]


Average loss for epoch 8: nan


Epoch 9:  67%|██████▋   | 333/500 [01:52<00:56,  2.97it/s, avg_loss=nan]


KeyboardInterrupt: 