<a href="https://colab.research.google.com/github/HarshitaBadiyasar/SuperAGI-Assignment/blob/main/GPT_2_Model_%26_Checkpoints.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
import torch.nn as nn
import math

# Configuration for a small GPT-2 model
class GPT2Config:
    def __init__(self):
        self.vocab_size = 50257
        self.max_position_embeddings = 1024
        self.n_layers = 12
        self.n_heads = 12
        self.n_embd = 768
        self.layer_norm_epsilon = 1e-5
        self.initializer_range = 0.02

# Scaled dot-product attention function
def scaled_dot_product_attention(query, key, value):
    temp = query.bmm(key.transpose(1, 2)) / math.sqrt(query.size(-1))
    softmax = nn.Softmax(dim=-1)
    return softmax(temp).bmm(value)

# Single attention head
class AttentionHead(nn.Module):
    def __init__(self, embd_dim):
        super().__init__()
        self.query = nn.Linear(embd_dim, embd_dim)
        self.key = nn.Linear(embd_dim, embd_dim)
        self.value = nn.Linear(embd_dim, embd_dim)

    def forward(self, hidden_state):
        return scaled_dot_product_attention(
            self.query(hidden_state), self.key(hidden_state), self.value(hidden_state)
        )

# Multi-head attention
class MultiHeadAttention(nn.Module):
    def __init__(self, embd_dim, n_heads):
        super().__init__()
        self.heads = nn.ModuleList([AttentionHead(embd_dim) for _ in range(n_heads)])
        self.linear = nn.Linear(n_heads * embd_dim, embd_dim)

    def forward(self, hidden_state):
        attention = [head(hidden_state) for head in self.heads]
        concatenated = torch.cat(attention, dim=-1)
        return self.linear(concatenated)

# Pointwise Feed Forward layer
class PointwiseFeedForward(nn.Module):
    def __init__(self, embd_dim, ff_dim):
        super().__init__()
        self.linear1 = nn.Linear(embd_dim, ff_dim)
        self.linear2 = nn.Linear(ff_dim, embd_dim)

    def forward(self, hidden_state):
        return self.linear2(nn.functional.relu(self.linear1(hidden_state)))

# Transformer block
class TransformerBlock(nn.Module):
    def __init__(self, embd_dim, n_heads, ff_dim, layer_norm_epsilon):
        super().__init__()
        self.attention = MultiHeadAttention(embd_dim, n_heads)
        self.feed_forward = PointwiseFeedForward(embd_dim, ff_dim)
        self.layer_norm1 = nn.LayerNorm(embd_dim, eps=layer_norm_epsilon)
        self.layer_norm2 = nn.LayerNorm(embd_dim, eps=layer_norm_epsilon)

    def forward(self, hidden_state):
        attention_output = self.attention(hidden_state)
        norm1 = self.layer_norm1(hidden_state + attention_output)
        feed_forward_output = self.feed_forward(norm1)
        norm2 = self.layer_norm2(norm1 + feed_forward_output)
        return norm2

# GPT-2 model
class GPT2(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embd_dim = config.n_embd
        self.token_embedding = nn.Embedding(config.vocab_size, self.embd_dim)
        self.position_embedding = nn.Embedding(config.max_position_embeddings, self.embd_dim)
        self.blocks = nn.ModuleList([
            TransformerBlock(self.embd_dim, config.n_heads, 4 * self.embd_dim, config.layer_norm_epsilon)
            for _ in range(config.n_layers)
        ])
        self.layer_norm = nn.LayerNorm(self.embd_dim, eps=config.layer_norm_epsilon)

    def forward(self, input_ids, positions_ids=None):
        if positions_ids is None:
            positions_ids = torch.arange(0, input_ids.size(1)).unsqueeze(0).to(input_ids.device)
        tokens = self.token_embedding(input_ids)
        positions = self.position_embedding(positions_ids)

        x = tokens + positions

        for block in self.blocks:
            x = block(x)

        x = self.layer_norm(x)
        return x

# Example usage
if __name__ == "__main__":
    # Configuration setup
    config = GPT2Config()
    # Create GPT-2 model instance
    model = GPT2(config)

    # Generate random input for demonstration
    input_ids = torch.randint(0, config.vocab_size, (1, 1024))
    # Obtain model output
    output = model(input_ids)
    print(output)


tensor([[[ 1.5960, -1.6826, -0.8329,  ...,  0.6209, -0.2323,  1.3404],
         [-1.4362, -0.8451,  0.8915,  ..., -0.5067, -1.3798,  0.0387],
         [ 0.8700, -1.5147, -0.7995,  ...,  0.5665,  0.2666, -0.5611],
         ...,
         [ 0.4757, -0.7607,  0.2302,  ..., -0.2831,  0.4556,  0.9543],
         [ 2.2712, -0.3046, -0.1703,  ...,  0.1419,  0.6453, -0.8552],
         [-0.2020, -1.7010, -1.1221,  ...,  0.0676, -0.0682,  0.6557]]],
       grad_fn=<NativeLayerNormBackward0>)
