# Извлечение признаков из nanoGPT

In [8]:
from backbones.nanoGPT import GPT, GPTConfig

In [10]:
# Build the nanoGPT model and load the pretrained "gpt2" checkpoint into it.
model = GPT.from_pretrained("gpt2")

loading weights from pretrained gpt: gpt2
forcing vocab_size=50257, block_size=1024, bias=True
number of parameters: 123.65M


In [11]:
import tiktoken

# Tokenize a long instruction prompt (one sentence repeated 32 times) with the
# GPT-2 BPE vocabulary, then report its token count and whether that count
# stays within the 1024-token limit.
enc = tiktoken.get_encoding("gpt2")
text = 32 * "You are a minecraft agent, you will be given an image history and actions history along with the instruction what to do. Your goal is to predict an action"
token_num = len(enc.encode(text))
print(token_num, token_num <= 1024, sep="\n")


1024
True


Как видим, в контекст GPT-2 (1024 токена) помещается довольно большая инструкция.

Теперь реализуем простой экстрактор признаков. Для этого изменим исходный код nanoGPT: добавим в метод `forward` параметр `only_latents`, при котором метод возвращает скрытые состояния после последней нормализации:

```python
def forward(self, idx, targets=None, only_latents = False):
        """Run the transformer stack and optionally return only the hidden states.

        Args:
            idx: token ids of shape (b, t); t must not exceed ``config.block_size``.
            targets: not used in the part of the method shown here.
            only_latents: when true, return early with the output of the final
                layer norm instead of continuing to the elided remainder of
                the method.

        Returns:
            When ``only_latents`` is true: ``(None, None, x)`` where ``x`` is the
            output of ``ln_f`` (presumably of shape (b, t, n_embd), following
            the embedding shapes noted below). The regular return value is
            produced by the elided part of the method.
        """
        device = idx.device
        b, t = idx.size()
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)

        # forward the GPT model itself
        tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        
        if only_latents:
            # Return a 3-tuple with the hidden states in the last slot so that
            # callers can unpack it the same way as the regular return value
            # (NOTE(review): the regular return is elided here; confirm it is
            # also a 3-tuple).
            return None, None, x
        
        ...
```

In [16]:
import torch.nn as nn
import torch

class NanoGPTFeatureExtractor(nn.Module):
    """Turn a batch of strings into GPT-2 hidden-state features.

    Texts are tokenized with the GPT-2 BPE vocabulary, right-padded to a
    common length and passed through the nanoGPT model with
    ``only_latents=True``. The per-token hidden states are then reduced
    according to the requested pooling mode.

    Args:
        device: device the model and all created tensors are placed on.
        padding_id: token id used to right-pad shorter sequences.
        *args, **kwargs: forwarded to ``nn.Module.__init__``.
    """

    # Accepted values of the ``pooling`` argument.
    _POOLING_MODES = (None, "none", "last", "mean")

    def __init__(self, device = "cuda", padding_id = 0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model = GPT.from_pretrained("gpt2").to(device)
        self.device = device
        self.tokenizer = tiktoken.get_encoding("gpt2")
        self.padding_id = padding_id

    @staticmethod
    def _unknown_pooling(pooling):
        """Build the error raised for an unsupported pooling mode."""
        return ValueError(f"pooling can be either none | last | mean but {pooling} was given")

    @torch.no_grad()
    def forward(self, texts : list[str], pooling : str = "last", add_eos = False, max_tokens = 1024):
        """Extract features from a batch of texts.

        Args:
            texts: non-empty batch of strings.
            pooling: how per-token features are reduced:
                ``None`` or ``'none'`` - return all tokens as is,
                ``'mean'`` - average over the real (non-padding) tokens,
                ``'last'`` - features of the last real token.
            add_eos: append the tokenizer's end-of-text token to every text.
            max_tokens: keep only the last ``max_tokens`` tokens of each text.

        Returns:
            ``(features, lengths)``: ``features`` is (B, T, D) without pooling
            and (B, D) otherwise; ``lengths`` holds the number of real tokens
            of every text.

        Raises:
            ValueError: for an unknown pooling mode, an empty batch or a
                non-positive ``max_tokens``.
        """
        # Validate up front so a typo does not cost a full forward pass.
        if pooling not in self._POOLING_MODES:
            raise self._unknown_pooling(pooling)
        if max_tokens < 1:
            # seq[-0:] would silently keep the whole sequence.
            raise ValueError(f"max_tokens must be positive but {max_tokens} was given")

        token_ids = [self.tokenizer.encode(text) for text in texts]
        if not token_ids:
            raise ValueError("texts must contain at least one string")
        if add_eos:
            eos = self.tokenizer.eot_token
            token_ids = [seq + [eos] for seq in token_ids]
        token_ids = [seq[-max_tokens:] for seq in token_ids]

        # Right-pad on the host and move the whole batch to the device at once.
        seq_lens = [len(seq) for seq in token_ids]
        longest = max(seq_lens)
        padded = [seq + [self.padding_id] * (longest - len(seq)) for seq in token_ids]
        padded_ids = torch.tensor(padded, device=self.device, dtype=torch.long)
        lengths = torch.tensor(seq_lens, device=self.device, dtype=torch.long)

        _, _, latents = self.model(padded_ids, only_latents=True)
        return self._pooling(lengths, latents, pooling)

    def _pooling(self, lengths, latents, pooling):
        """Reduce padded per-token features according to ``pooling``.

        Args:
            lengths: (B,) integer tensor with the real token count per row.
            latents: (B, T, D) tensor of per-token features.
            pooling: ``None``/``'none'``, ``'last'`` or ``'mean'``.

        Returns:
            ``(features, lengths)`` with ``features`` of shape (B, T, D) for
            no pooling and (B, D) for ``'last'`` / ``'mean'``.

        Raises:
            ValueError: if ``pooling`` is not one of the supported modes.
        """
        if pooling is None or pooling == "none":
            return latents, lengths  # B, T (padded), D
        if pooling == "last":
            rows = torch.arange(latents.size(0), device=latents.device)
            return latents[rows, lengths - 1], lengths  # B, D
        if pooling == "mean":
            positions = torch.arange(latents.size(1), device=latents.device).unsqueeze(0)  # 1, T
            # 1.0 for real tokens, 0.0 for padding; trailing dim broadcasts over D.
            mask = (positions < lengths.unsqueeze(1)).float().unsqueeze(-1)  # B, T, 1
            total = (latents * mask).sum(dim=1)
            # Clamp so that a zero-length row yields zeros instead of NaN.
            count = mask.sum(dim=1).clamp(min=1.0)
            return total / count, lengths  # B, D
        raise self._unknown_pooling(pooling)


In [19]:

# Two short example texts of different length, so padding is exercised.
texts = ["hello, world!", "i love minecraft so much"]

# Run the extractor on CPU.
extractor = NanoGPTFeatureExtractor(device="cpu")

# Last-token pooling yields one feature vector per text.
latents, lengths = extractor(texts, pooling = "last")
# Notebook display of the pooled feature shape.
latents.shape

loading weights from pretrained gpt: gpt2
forcing vocab_size=50257, block_size=1024, bias=True
number of parameters: 123.65M


torch.Size([2, 768])

Теперь с помощью этого кода можно извлекать признаки из текстовых инструкций.