In [2]:
!conda install -c conda-forge transformers -y

Channels:
 - conda-forge
 - defaults
 - intel
 - pytorch
Platform: linux-64
Collecting package metadata (repodata.json): done
Solving environment: done


    current version: 23.7.4
    latest version: 23.9.0

Please update conda by running

    $ conda update -n base -c conda-forge conda



## Package Plan ##

  environment location: /localdisk/dmitriim/miniconda3/envs/pytorch_onednn

  added / updated specs:
    - transformers


The following packages will be downloaded:

    package                    |            build
    ---------------------------|-----------------
    _libgcc_mutex-0.1          |      conda_forge           3 KB  conda-forge
    _openmp_mutex-4.5          |            2_gnu          23 KB  conda-forge
    aiohttp-3.8.6              |  py311h459d7ec_1         738 KB  conda-forge
    aiosignal-1.3.1            |     pyhd8ed1ab_0          12 KB  conda-forge
    aom-3.5.0                  |       h27087fc_0         2.7 MB  conda-forge
    async-timeout-4.0.3        

In [21]:
import transformers

import torch
import torch.nn.functional as F
from torch import nn

In [25]:
class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super(TransformerBlock, self).__init__()
        self.attention = nn.MultiheadAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size),
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        attention_out, _ = self.attention(value, key, query, attn_mask=mask)
        x = self.norm1(attention_out + query)
        x = self.dropout(x)
        forward = self.feed_forward(x)
        out = self.norm2(forward + x)
        return out

class GPT(nn.Module):
    def __init__(
        self,
        vocab_size,
        embed_size,
        num_layers,
        heads,
        device,
        forward_expansion,
        dropout,
        max_length,
    ):
        super(GPT, self).__init__()
        self.embed_size = embed_size
        self.device = device
        self.word_embeddings = nn.Embedding(vocab_size, embed_size)
        self.position_embeddings = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    embed_size,
                    heads,
                    dropout=dropout,
                    forward_expansion=forward_expansion,
                )
                for _ in range(num_layers)
            ]
        )

        self.fc_out = nn.Linear(embed_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        N, seq_length = x.shape
        positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
        out = self.dropout(
            (self.word_embeddings(x) + self.position_embeddings(positions))
        )

        for layer in self.layers:
            out = layer(out, out, out, mask)

        return self.fc_out(out)

# Example usage:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GPT(
    vocab_size=10000,
    embed_size=4096,
    num_layers=28,
    heads=16,
    device=device,
    forward_expansion=4,
    dropout=0.1,
    max_length=128,
).to(device)

In [None]:
model_name = "EleutherAI/gpt-j-6B"
gpt = transformers.AutoModelForCausalLM.from_pretrained(model_name)

In [None]:
gpt.

In [26]:
from datasets import load_dataset
from tqdm.auto import tqdm
import torch.optim as optim

# model.gradient_checkpointing_enable()
tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
codeparrot = load_dataset("transformersbook/codeparrot-train", streaming=True)
optimizer = optim.Adam(model.parameters(), lr=1e-5)

prompt = tokenizer("A cat sat on a mat", return_tensors='pt')
prompt = {key: value.to(device) for key, value in prompt.items()}

with torch.cuda.amp.autocast():
    for row in tqdm(codeparrot["train"]):
        if len(row["content"]) <= 1:
            continue

        batch = tokenizer(row["content"], truncation=True, max_length=128, return_tensors='pt')
        print(batch)
        batch = {k: v.to(device) for k, v in batch.items()}
        print(batch["input_ids"].shape)
        print(batch["attention_mask"].shape)

        out = model.forward(batch["input_ids"], batch["attention_mask"])

        loss = F.cross_entropy(out.logits[:, :-1, :].flatten(0, -2), batch['input_ids'][:, 1:].flatten(),
                               reduction='mean')
        print(loss)
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Repo card metadata block was not found. Setting CardData to empty.


Resolving data files:   0%|          | 0/183 [00:00<?, ?it/s]

0it [00:00, ?it/s]

{'input_ids': tensor([[29113, 29113,  7804,  4242, 21017,   198,  2235,   198,  2235,   220,
         15069,   357,    34,     8,  2211,    12,  4967, 41489, 31110,   402,
          2022,    39,   198,  2235,   198,  2235,   220, 49962,   739,   262,
         24843, 13789,    11, 10628,   362,    13,    15,   357,  1169,   366,
         34156, 15341,   198,  2235,   220,   345,   743,   407,   779,   428,
          2393,  2845,   287, 11846,   351,   262, 13789,    13,   198,  2235,
           220,   921,   743,  7330,   257,  4866,   286,   262, 13789,   379,
           198,  2235,   198,  2235,   220,   220,   220,   220,   220,  2638,
          1378,  2503,    13, 43073,    13,  2398,    14,   677,  4541,    14,
            43,  2149, 24290,    12,    17,    13,    15,   198,  2235,   198,
          2235,   220, 17486,  2672,   416,  9723,  1099,   393,  4987,   284,
           287,  3597,    11,  3788,   198,  2235,   220,  9387,   739,   262,
         13789,   318,  9387,   319,  

IndexError: index out of range in self

In [10]:
print(**batch)

TypeError: 'input_ids' is an invalid keyword argument for print()