Self-Attention for Transformers with Graph Attention Networks

This repository contains an implementation of GPT in which the original self-attention mechanism is replaced by Graph Attention Networks: GAT (from now on referred to as GATv1) and GATv2.

Motivation

My area of expertise is graph neural networks (GNNs), which utilize a message-passing mechanism to exchange information between nodes during training. Interestingly, the self-attention mechanism, which is the "heart" of modern transformer networks, can also be thought of as a communication mechanism between nodes in a directed graph. In this graph, the tokens are represented as nodes connected by directed edges. In the decoder part of the transformer implemented here, each token in a given context is connected to itself and all following tokens. An example of a context graph of size 4 is visualized below, where T1 through T4 represent individual tokens within the given context. You can find a more detailed comparison between transformers and GNNs in this blog post.
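To make the graph construction concrete, here is a minimal sketch (not code from this repository) of how such a causal context graph could be expressed as a PyTorch Geometric edge_index, where every token sends a message to itself and to all following tokens:

import torch

def causal_edge_index(context_size: int) -> torch.Tensor:
    # One directed edge (src, dst) for every src <= dst, i.e. token T_src
    # is visible to itself and to every token that follows it.
    src, dst = zip(*[(s, d) for d in range(context_size) for s in range(d + 1)])
    return torch.tensor([src, dst], dtype=torch.long)

print(causal_edge_index(4))
# tensor([[0, 0, 1, 0, 1, 2, 0, 1, 2, 3],
#         [0, 1, 1, 2, 2, 2, 3, 3, 3, 3]])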

Thus, the first goal of this project was to familiarize myself with transformers and create a direct connection to my previous research. The second goal was to create a plug-and-play framework in which the self-attention can be replaced by any GNN. Since everything is implemented using classes from PyTorch Geometric, it is very simple to swap the GAT layers for any other convolutional GNN layer from the library; an illustrative sketch of such a swap is shown below. So feel free to fork the repository, try out different GNN layers, and let me know the results! :)
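For illustration, here is a minimal sketch of what such a swap could look like. The wrapper class GNNHead and the conv_cls argument are assumptions made up for this example and are not the repository's actual code; only the PyTorch Geometric layer classes are real.

import torch.nn as nn
from torch_geometric.nn import GATConv, GATv2Conv, TransformerConv

class GNNHead(nn.Module):
    # Hypothetical drop-in "attention head" backed by an arbitrary PyG convolution.
    def __init__(self, n_embd: int, head_size: int, conv_cls=GATConv):
        super().__init__()
        # Any PyG layer with an (in_channels, out_channels) constructor works here.
        self.conv = conv_cls(n_embd, head_size)

    def forward(self, x, edge_index):
        # x: (num_tokens, n_embd) node features; edge_index: the causal context graph.
        return self.conv(x, edge_index)

# Swapping architectures is then a one-argument change:
# head = GNNHead(n_embd=64, head_size=16, conv_cls=GATv2Conv)
# head = GNNHead(n_embd=64, head_size=16, conv_cls=TransformerConv)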

Usage

First, install the conda environment:

conda env create -f gpt.yml

Then, run the code as:

python pyg-gpt.py --gat_version GATConv

where the --gat_version argument is set to either GATConv or GATv2Conv.

Results

The following plot shows the validation losses for the different models. They were all trained for 5000 epochs on the tinyshakespeare dataset:

GATv2, which fixes the static attention issue of vanilla GAT, reaches a similarly low loss as the original transformer implementation. You can look at example outputs of all models in the model_outputs folder. Although they all produce nonsense, there is a very obvious improvement in the models that use attention over the simple bigram model, and their output undoubtedly starts to resemble Shakespeare's style.

Note that I trained these models on an A100 GPU, so reproduction on your local machine might require downsizing the model parameters quite a bit.
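As a rough guide, these are the kinds of settings to shrink. The names follow the hyperparameters in Karpathy's gpt.py; the values below are illustrative guesses for a small local GPU, not the configuration used for the results above.

batch_size = 16   # sequences per training step
block_size = 64   # context length (number of tokens / graph nodes)
n_embd = 64       # embedding dimension
n_head = 4        # attention / GAT heads per block
n_layer = 4       # number of transformer blocks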

Acknowledgements

This codebase, and specifically the gpt.py and bigram.py code, was originally created by Andrej Karpathy in his Neural Networks: Zero To Hero video lecture series, in the lecture on nanoGPT. Many thanks to him for making the code available and for teaching me the basics of LLMs in an intuitive manner.
