In [6]:
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from transformers import GPT2Model, GPT2Config

# Prefer GPU index 1 when it exists (presumably to leave GPU 0 free for other work), but fall
# back to GPU 0 on single-GPU machines instead of failing with an invalid device ordinal.
PREFERRED_CUDA_INDEX = 1
if torch.cuda.is_available():
    device = f'cuda:{min(PREFERRED_CUDA_INDEX, torch.cuda.device_count() - 1)}'
else:
    device = 'cpu'
print(device)


def to_numpy(x):
    """Detach ``x`` from the autograd graph, move it to host memory and return a NumPy array."""
    return x.detach().cpu().numpy()

cuda:1


In [7]:
class ICLTransformer(nn.Module):
    """GPT-2 backbone for in-context regression over interleaved (x, y) token sequences.

    Inputs ``x_in`` and their ``values`` are projected to ``embed_dim`` and interleaved as
    ``[x_0, y_0, x_1, y_1, ..., x_{T-1}, y_{T-1}]``. GPT-2 attends causally, so the hidden
    state at token ``x_i`` only sees the earlier pairs plus ``x_i`` itself; the read-out at
    that token is the model's in-context prediction of ``y_i``.

    Args:
        input_dim: size of the last axis of ``x_in``.
        val_dim: size of the last axis of ``values``.
        embed_dim: transformer width; both input projections map into this size.
        num_heads: attention heads per GPT-2 layer.
        num_layers: number of GPT-2 layers.
        image_size, patch_size: only used to set ``grid_h``/``grid_w``/``num_patches``,
            which nothing in this class reads (kept for backward compatibility).
        max_seq_len: maximum number of (x, y) PAIRS per sequence, i.e. the largest T that
            ``forward`` accepts. GPT-2 is built with ``2 * max_seq_len`` positions because
            every pair occupies two tokens.
    """

    def __init__(self, input_dim=100, val_dim=1, embed_dim=256, num_heads=8, num_layers=12,
                 image_size=16, patch_size=4, max_seq_len=128):
        super().__init__()

        # Not read anywhere in this class; kept so existing attribute access keeps working.
        self.grid_h = image_size // patch_size
        self.grid_w = image_size // patch_size
        self.num_patches = self.grid_h * self.grid_w

        # NOTE: not added to the embeddings in forward (GPT-2 is configured with its own
        # n_positions position table). Kept so state_dict keys of saved checkpoints match.
        self.time_embed = nn.Parameter(torch.randn(1, max_seq_len * 2, embed_dim))

        self.val_dim = val_dim
        self.value_proj = nn.Linear(val_dim, embed_dim)

        # Must project to embed_dim (not a fixed width) so x and value embeddings share a last
        # dimension and can be interleaved.
        self.fc_in = nn.Linear(input_dim, embed_dim)

        config = GPT2Config(
            n_positions=2 * max_seq_len,
            n_embd=embed_dim,
            n_layer=num_layers,
            n_head=num_heads,
            resid_pdrop=0.0,
            embd_pdrop=0.0,
            attn_pdrop=0.0,
            use_cache=False,
        )
        self.transformer = GPT2Model(config)

        self.fc_out = nn.Linear(embed_dim, val_dim)

    def forward(self, x_in, values):
        """Predict the value at every x position from its causal context.

        Args:
            x_in: tensor of shape (B, T, input_dim).
            values: tensor of shape (B, T, val_dim), aligned position-by-position with ``x_in``.

        Returns:
            Tensor of shape (B, T): output channel 0 read out at each x token. Only channel 0
            is returned even when ``val_dim > 1``.

        Raises:
            ValueError: if T exceeds the number of pairs the GPT-2 position table can hold.
        """
        x_embeds = self.fc_in(x_in)
        _, T, _ = x_embeds.shape

        # Each pair uses two positions; fail early with a readable message instead of an
        # out-of-range position-embedding lookup deep inside GPT-2.
        max_pairs = self.transformer.config.n_positions // 2
        if T > max_pairs:
            raise ValueError(
                f'sequence has {T} (x, y) pairs but the model supports at most {max_pairs} '
                f'(built with n_positions={self.transformer.config.n_positions})'
            )

        val_embeds = self.value_proj(values)
        embeddings = self.interleave(x_embeds, val_embeds)
        hidden = self.transformer(inputs_embeds=embeddings).last_hidden_state
        predictions = self.fc_out(hidden)

        # Even token positions are the x tokens.
        return predictions[:, ::2, 0]

    def interleave(self, xs, ys):
        """Interleave two (B, T, D) tensors along time into (B, 2T, D): xs_0, ys_0, xs_1, ys_1, ..."""
        B, T, D = xs.shape
        stacked = torch.stack((xs, ys), dim=2)  # [B, T, 2, D]; contiguous, so view is safe
        return stacked.view(B, 2 * T, D)

In [8]:
class SimpleMLP(nn.Module):
    """One-hidden-layer ReLU MLP mapping ``input_dim`` features to a single output.

    Args:
        input_dim: number of input features (last axis of the input).
        hidden_dim: width of the hidden layer.
    """

    def __init__(self, input_dim=100, hidden_dim=100):
        super().__init__()
        # Stored because _initialize_weights scales the output layer by it.
        self.hidden_dim = hidden_dim
        self.input = nn.Linear(input_dim, hidden_dim)
        self.output = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Return ``output(relu(input(x)))``; the last axis goes from input_dim to 1."""
        x = F.relu(self.input(x))
        return self.output(x)

    def _initialize_weights(self):
        """Re-draw Linear/Conv2d weights from normal distributions; biases are left untouched.

        Every matching layer except the last gets std 1; the last gets std ``2 / hidden_dim``.
        Not called from ``__init__`` -- callers must invoke it explicitly.
        NOTE(review): ``2 / hidden_dim`` may have been intended as ``sqrt(2 / hidden_dim)``
        (the commented-out ReLU experiment scales by ``(2 / d_max)**0.5``) -- confirm.
        """
        layers = [m for m in self.modules() if isinstance(m, (nn.Linear, nn.Conv2d))]
        last_index = len(layers) - 1
        for i, layer in enumerate(layers):
            if i < last_index:
                nn.init.normal_(layer.weight)
            else:
                nn.init.normal_(layer.weight, std=2 / self.hidden_dim)

In [9]:
# Curriculum training for in-context linear regression: every CURRICULUM_EVERY steps the number
# of active input dimensions d_cur grows by one (capped at d_max), and each sequence holds
# 2 * d_cur + 1 (x, y) pairs.
SEED = 0
n_epochs = 500_000
batch_size = 64
d_max = 100
d_cur = 5
CURRICULUM_EVERY = 2000  # steps between increments of d_cur
LOG_EVERY = 100
LR = 1e-4

torch.manual_seed(SEED)


def n_samples_for(d):
    """Number of (x, y) pairs per sequence when d input dimensions are active."""
    return 2 * d + 1


n_samples = n_samples_for(d_cur)
losses = []
final_losses = []
# Size GPT-2's position table for the longest sequence the capped curriculum reaches; the
# constructor default (128 pairs) would be exceeded once d_cur reaches 64.
transformer = ICLTransformer(d_max, max_seq_len=n_samples_for(d_max)).to(device)
optim = torch.optim.AdamW(transformer.parameters(), LR)

for epoch in range(1, n_epochs + 1):
    # Growing d_cur past d_max would only lengthen sequences without adding dimensions.
    if epoch % CURRICULUM_EVERY == 0 and d_cur < d_max:
        d_cur += 1
        n_samples = n_samples_for(d_cur)

    # Gaussian inputs with only the first d_cur coordinates active.
    xs = torch.randn(batch_size, n_samples, d_max, device=device)
    xs[:, :, d_cur:] = 0
    # One random linear function per sequence: y = <w, x>.
    ws = torch.randn(batch_size, 1, d_max, device=device)
    ys = (ws * xs).sum(-1, keepdim=True)  # (B, n_samples, 1)

    y_preds = transformer(xs, ys)  # (B, n_samples)
    y_targets = ys.squeeze(-1)
    loss = F.mse_loss(y_preds, y_targets)
    # Error on the last query only, i.e. with the most in-context examples.
    final_loss = F.mse_loss(y_preds[:, -1], y_targets[:, -1])
    losses.append(loss.item())
    final_losses.append(final_loss.item())

    optim.zero_grad()
    loss.backward()
    optim.step()

    # Log after recording this step so the printed numbers belong to `epoch`.
    if epoch % LOG_EVERY == 0:
        print(f'{epoch}: {losses[-1]}, {final_losses[-1]}')
    

100: 5.224421501159668, 6.439401626586914
200: 4.115876197814941, 4.443578720092773
300: 4.406532287597656, 3.1530070304870605
400: 4.095528602600098, 3.0782384872436523
500: 3.84342622756958, 3.2245068550109863
600: 3.4031782150268555, 3.1673550605773926
700: 4.011458873748779, 3.2145540714263916
800: 3.410419225692749, 2.5860767364501953
900: 2.815985918045044, 3.127397060394287
1000: 2.7383556365966797, 1.4801092147827148
1100: 2.7133703231811523, 1.680066466331482
1200: 2.337367057800293, 1.5427826642990112
1300: 2.147916793823242, 0.8538510799407959
1400: 1.8265982866287231, 0.4219081401824951
1500: 2.0315144062042236, 0.6993225812911987
1600: 1.6928085088729858, 0.3308414816856384
1700: 1.65704345703125, 0.26988545060157776
1800: 1.9510554075241089, 0.3630797564983368
1900: 1.5730009078979492, 0.176939457654953
2000: 1.6106585264205933, 0.2578014135360718
2100: 3.820971727371216, 3.325739860534668
2200: 3.862192153930664, 1.2951825857162476
2300: 2.6715681552886963, 1.23356044292

KeyboardInterrupt: 

In [None]:
# n_epochs = 500_000
# batch_size = 64
# d_max = 100
# d_cur = 5
# n_samples = 5 * d_cur + 1
# losses = []
# final_losses = []
# transformer = ICLTransformer(d_max).to(device)
# optim = torch.optim.AdamW(transformer.parameters(), 1e-4)

# for epoch in range(1, n_epochs+1):
#     if epoch % 2000 == 0:
#         d_cur += 1
#         n_samples = 5 * d_cur + 1
#     if epoch % 100 == 0:
#         print(f'{epoch}: {losses[-1]}, {final_losses[-1]}')
#     xs = torch.randn(batch_size, n_samples, d_max).to(device)
#     xs[:, :, d_cur:] = 0
#     ws_1 = torch.randn(batch_size, 1, d_max).to(device)
#     ws_2 = torch.randn(batch_size, 1, d_max).to(device) * (2 / d_max)**0.5
#     ys = (ws_2 * F.relu(ws_1 * xs)).sum(-1, keepdim=True)
#     y_preds = transformer(xs, ys)
#     y_targets = ys.squeeze(-1)  # must be (B, T) like y_preds; (B, T, 1) broadcasts to (B, T, T) in mse_loss
#     loss = F.mse_loss(y_preds, y_targets)
#     final_loss = F.mse_loss(y_preds[:, -1], y_targets[:, -1])
#     losses.append(loss.item())
#     final_losses.append(final_loss.item())
#     optim.zero_grad()
#     loss.backward()
#     optim.step()

100: 0.048953838646411896, 0.033080246299505234
200: 0.029762040823698044, 0.029734712094068527
300: 0.030237438157200813, 0.02668357826769352
400: 0.035440947860479355, 0.021727513521909714
500: 0.04698522761464119, 0.05671188235282898
600: 0.04262048751115799, 0.021662253886461258
700: 0.033532582223415375, 0.040008530020713806
800: 0.03731117397546768, 0.021442417055368423
900: 0.03133833408355713, 0.03280448913574219
1000: 0.031016496941447258, 0.02640456147491932
1100: 0.031058307737112045, 0.020216336473822594
1200: 0.03279753029346466, 0.025422774255275726
1300: 0.03680151700973511, 0.04315032437443733
1400: 0.052755169570446014, 0.03241295367479324
1500: 0.04150332137942314, 0.05401502922177315
1600: 0.03237580880522728, 0.02893785759806633
1700: 0.04033122584223747, 0.04666707664728165
1800: 0.046846844255924225, 0.09029857814311981
1900: 0.026266159489750862, 0.016595646739006042
2000: 0.03872684761881828, 0.018839705735445023
2100: 0.055760517716407776, 0.02602071687579155
2

In [None]:
# Show predictions, then targets, at the final query position of the most recent batch.
for last_position in (y_preds[:, -1], y_targets[:, -1]):
    print(last_position)

tensor([-1.5038, -1.4061, -0.4550, -1.7284,  0.2009, -0.8888, -1.2016,  2.2628,
         0.2417,  3.3320,  2.4030, -2.4408,  2.5768,  0.3298,  1.2715,  1.2522,
        -1.9192,  2.8138, -1.6523,  3.4755, -5.4497, -0.6134,  0.7242, -1.3211,
         3.4456,  2.2038,  0.3460,  0.7415,  0.3938,  0.9828,  1.4486, -5.0805,
         1.9532, -0.9781,  0.9409,  0.4682,  0.6703, -5.7165,  0.7405, -0.6933,
         2.5001, -1.4460,  1.5262,  2.3915,  1.1792, -5.3400, -0.0508, -2.6798,
        -1.3692, -3.3743,  1.2782,  3.0200, -1.7513, -0.6687,  2.6614, -1.3479,
         2.2771,  1.7686,  3.2906, -1.9842,  2.5961, -0.4701, -2.1348,  1.6573],
       device='cuda:1', grad_fn=<SelectBackward0>)
tensor([-1.6063,  0.6132, -0.0369, -2.3265, -2.0652, -3.0482, -1.1844, -1.6541,
        -0.7453,  3.5628,  2.2337, -0.4224,  3.4141, -0.1282,  1.3983,  1.4424,
        -2.6016,  0.8868, -1.6059,  3.2607, -7.0381, -0.6906,  0.6043, -0.4372,
         2.9834,  2.7182,  0.0267,  0.3544, -3.5444,  1.9392,  2.308