In [5]:
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from transformers import GPT2Model, GPT2Config

# Pick the compute device. The second GPU ('cuda:1') is preferred, but it is only
# selected when the machine actually exposes more than one CUDA device; otherwise
# fall back to the first GPU, and to the CPU when CUDA is unavailable. Selecting
# 'cuda:1' on a single-GPU host would fail on the first `.to(device)` call.
if torch.cuda.is_available():
    device = 'cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0'
else:
    device = 'cpu'
print(device)


def to_numpy(x):
    """Return tensor `x` as a NumPy array.

    The tensor is detached from the autograd graph and copied to host memory
    first, so this works for tensors that require grad and/or live on a GPU.
    """
    return x.detach().cpu().numpy()

cuda:1


In [6]:

class ICLTransformer(nn.Module):
    """GPT-2 backbone applied to an interleaved sequence of (x, value) tokens.

    Each input vector and each value is linearly projected to `embed_dim`, the
    two streams are interleaved as x_0, v_0, x_1, v_1, ..., and the resulting
    2*T token sequence is passed through a GPT-2 model via `inputs_embeds`.
    The output read off at every x-token position is returned as the prediction
    for the corresponding value.

    Args:
        input_dim: Size of each input vector x.
        val_dim: Size of each value vector.
        embed_dim: Transformer hidden size.
        num_heads: Number of attention heads.
        num_layers: Number of transformer layers.
        image_size, patch_size: Only used to derive `grid_h`, `grid_w` and
            `num_patches`; `forward` does not use these attributes.
        max_seq_len: Maximum number of (x, value) pairs per sequence. The
            GPT-2 config is built with `n_positions = 2 * max_seq_len`, i.e.
            one position per interleaved token.
    """

    def __init__(self, input_dim=100, val_dim=1, embed_dim=256, num_heads=8, num_layers=12,
                 image_size=16, patch_size=4, max_seq_len=128):
        super().__init__()

        # Not read by forward(); kept so existing attribute access keeps working.
        self.grid_h = image_size // patch_size
        self.grid_w = image_size // patch_size
        self.num_patches = self.grid_h * self.grid_w

        # Not added to the embeddings in forward(). Kept so the module's
        # parameter names (and therefore saved state_dicts) are unchanged.
        self.time_embed = nn.Parameter(torch.randn(1, max_seq_len * 2, embed_dim))

        self.val_dim = val_dim
        self.value_proj = nn.Linear(val_dim, embed_dim)

        # Project inputs to the transformer width. This must be `embed_dim`
        # (not a fixed 256) so it can be stacked with `value_proj` outputs.
        self.fc_in = nn.Linear(input_dim, embed_dim)

        config = GPT2Config(
            n_positions=2 * max_seq_len,
            n_embd=embed_dim,
            n_layer=num_layers,
            n_head=num_heads,
            resid_pdrop=0.0,
            embd_pdrop=0.0,
            attn_pdrop=0.0,
            use_cache=False,
        )
        self.transformer = GPT2Model(config)

        self.fc_out = nn.Linear(embed_dim, val_dim)

    def forward(self, x_in, values):
        """Predict the first value component at every x-token position.

        Args:
            x_in: Tensor whose last dimension is `input_dim`; the code unpacks
                its projection as [B, T, embed_dim].
            values: Tensor whose last dimension is `val_dim`, with the same
                leading [B, T] dimensions as `x_in`.

        Returns:
            Tensor of shape [B, T]: component 0 of `fc_out` applied to the
            transformer output at each even (x-token) position.

        Raises:
            ValueError: If the interleaved sequence (2*T tokens) is longer
                than the transformer's `n_positions`.
        """
        x_embeddings = self.fc_in(x_in)
        val_embeds = self.value_proj(values)

        num_pairs = x_embeddings.shape[1]
        max_tokens = self.transformer.config.n_positions
        if 2 * num_pairs > max_tokens:
            raise ValueError(
                f'sequence of {num_pairs} (x, value) pairs needs {2 * num_pairs} positions, '
                f'but the transformer only supports {max_tokens}'
            )

        embeddings = self.interleave(x_embeddings, val_embeds)
        hidden = self.transformer(inputs_embeds=embeddings).last_hidden_state
        predictions = self.fc_out(hidden)

        # Even token positions are the x tokens (see interleave); taking every
        # second position already yields exactly T entries per sequence.
        return predictions[:, ::2, 0]

    def interleave(self, xs, ys):
        """Interleave two [B, T, D] tensors along time into [B, 2*T, D].

        The result is ordered xs[:, 0], ys[:, 0], xs[:, 1], ys[:, 1], ...
        """
        B, T, D = xs.shape
        stacked = torch.stack((xs, ys), dim=2)  # [B, T, 2, D]
        return stacked.reshape(B, 2 * T, D)

In [None]:
# Train the in-context regressor on random linear functions y = <w, x>, using a
# curriculum that grows the number of active input dimensions over time.
torch.manual_seed(0)  # make the weight init and the sampled tasks reproducible

n_epochs = 500_000
batch_size = 64
d_max = 100
d_cur = 5                      # number of non-zero input dimensions right now
n_samples = 2 * d_cur + 1      # (x, y) pairs per prompt
losses = []
final_losses = []

# The curriculum can grow the prompt up to 2 * d_max + 1 pairs, so size the
# model's positional capacity for that instead of relying on the default.
transformer = ICLTransformer(d_max, max_seq_len=2 * d_max + 1).to(device)
optim = torch.optim.AdamW(transformer.parameters(), 1e-4)

for epoch in range(1, n_epochs + 1):
    if epoch % 2000 == 0:
        # Cap at d_max: beyond it there are no further dimensions to activate,
        # and the prompt would outgrow the model's positional capacity.
        d_cur = min(d_cur + 1, d_max)
        n_samples = 2 * d_cur + 1

    # Sample a fresh batch of tasks directly on the target device.
    xs = torch.randn(batch_size, n_samples, d_max, device=device)
    xs[:, :, d_cur:] = 0  # only the first d_cur coordinates are active
    ws = torch.randn(batch_size, 1, d_max, device=device)
    ys = (ws * xs).sum(-1, keepdim=True)

    y_preds = transformer(xs, ys)
    y_targets = ys.squeeze(-1)
    loss = F.mse_loss(y_preds, y_targets)
    # Error on the last query only, i.e. with the longest available context.
    final_loss = F.mse_loss(y_preds[:, -1], y_targets[:, -1])
    losses.append(loss.item())
    final_losses.append(final_loss.item())

    optim.zero_grad()
    loss.backward()
    optim.step()

    # Log after recording, so the printed numbers belong to this step.
    if epoch % 100 == 0:
        print(f'{epoch}: {losses[-1]}, {final_losses[-1]}')
    

  loss = F.mse_loss(y_preds, y_targets)
  final_loss = F.mse_loss(y_preds[:, -1], y_targets[:, -1])


100: 4.592724800109863, 4.225244522094727


KeyboardInterrupt: 