In [1]:
%load_ext autoreload
%autoreload 2

In [3]:
import sys

sys.path.append("/home/jovyan/vmeshchaninov/DiffusionTextGeneration-cond-ca")

In [11]:
import torch
from torch import nn
from torch.nn import functional as F
import math
from transformers import AutoModel, AutoTokenizer

from autoencoder.perceiver_resampler import PerceiverResampler

In [7]:
encoder = AutoModel.from_pretrained("bert-base-uncased")

Downloading config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Process ForkProcess-11:
Process ForkProcess-31:
Process ForkProcess-7:
Process ForkProcess-19:
Process ForkProcess-25:
Process ForkProcess-8:
Process ForkProcess-12:
Process ForkProcess-32:
Process ForkProcess-13:
Process ForkProcess-18:
Process ForkProcess-3:
Process ForkProcess-20:
Process ForkProcess-5:
Process ForkProcess-6:
Process ForkProcess-17:
Process ForkProcess-29:
Process ForkProcess-21:
Process ForkProcess-14:
Process ForkProcess-26:
Process ForkProcess-4:
Process ForkProcess-16:
Process ForkProcess-30:
Process ForkProcess-10:
Process ForkProcess-27:
Process ForkProcess-28:
Process ForkProcess-15:
Process ForkProcess-9:
Process ForkProcess-1:
Process ForkProcess-2:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last)

In [86]:
class AutoEncoder(nn.Module):
    def __init__(self,):
        super().__init__()

        self.encoder = AutoModel.from_pretrained("bert-base-uncased")
        self.compressor = PerceiverResampler(
            num_latents=32,
            embedding_dim=768,
            hidden_size=768,
            n_heads=12,
            n_layer=4,
        )
        self.decoder = PerceiverResampler(
            num_latents=64,
            embedding_dim=768,
            hidden_size=768,
            n_heads=12,
            n_layer=4,
        )
        self.projector = nn.Linear(768, 30522, bias=False)

    def encode(self, input_ids, attention_mask):
        return self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state

    def compress(self, encodings, attention_mask):
        return self.compressor(
            x=encodings,
            mask_x=attention_mask
        )
    
    def reconstruct(self, latents, attention_mask):
        return self.decoder(
            x=latents, 
            mask_x=None,
            mask_latent=attention_mask,
        )

    def forward(self, input_ids, attention_mask):
        encodings = self.encode(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        
        latents = self.compress(
            encodings=encodings,
            attention_mask=attention_mask
        )
        
        recon_x = self.reconstruct(
            latents=latents,
            attention_mask=attention_mask,
        )
        
        logits = self.projector(recon_x)
        return logits

In [105]:
from datasets import Dataset
from torch.utils.data import DataLoader
from torch.nn.functional import cross_entropy
from tqdm import tqdm

In [88]:
path = "/home/jovyan/vmeshchaninov/DiffusionLanguageModel/data/wikipedia/test/data-00000-of-00001.arrow"

In [90]:
dt = Dataset.from_file(path)

In [95]:
loader = DataLoader(dt, batch_size=32)

In [80]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

In [96]:
model = AutoEncoder().cuda()

In [104]:
optim = torch.optim.AdamW(
    params=list(model.compressor.parameters()) + list(model.decoder.parameters()) + list(model.projector.parameters()),
    lr=0.0001
)

In [103]:
T = tqdm(loader)

for text in loader:
    text = text["text"]
    input = tokenizer(texts, padding=True, return_tensors="pt", max_length=64, truncation=True)
    input_ids = input["input_ids"].cuda()
    mask = input["attention_mask"].cuda()
    logits = model(input_ids, attention_mask=mask)

    losses = cross_entropy(
        input=logits.reshape(-1, logits.shape[-1]),
        target=input_ids.reshape(-1),
        reduce=False,
    )
    losses = losses * mask.reshape(-1)
    loss = torch.sum(losses) / torch.sum(mask)

    optim.zero_grad()
    loss.backward()
    optim.step()
    T.set_description(f"{loss.item():0.5f}")
    print()



10.49067
6.37974
4.79527
4.51669
3.39336
2.69950
2.00125
1.79654
1.67591
1.61225
1.52871
1.50826
1.48204
1.46782
1.45248
1.41414
1.39425
1.40005
1.36578
1.32847
1.30277
1.28754
1.27114
1.27353
1.22971
1.19865
1.15081
1.15060
1.08809
1.10260
1.06970
1.06063
1.05947
1.05111
1.01788
1.02450
1.02670
1.02647
1.02363
0.97871
0.97138
0.98084
1.33228
1.12187
1.09918
1.04309
1.01486
0.95238
0.92084
0.83751
0.81326
0.73488
0.82370
0.81744
0.64869
0.63558
0.56373
0.57624
0.53440
0.45482
0.47869
0.40392
0.35399
0.36725
0.30026
0.24288
0.23812
0.23430
0.16673
0.12307
0.12235
0.09811
0.08660
0.09735
0.14878
0.13983
0.09283
0.08151
0.08959
0.09666
0.07437
0.10754
0.08335
0.10753
0.08441
0.12820
0.06251
0.12643
0.05676
0.12769
0.07077
0.10115
0.05099
0.08740
0.04846
0.05459
0.05865
0.09664
0.07130
0.09791
0.05496
0.06824
0.04932
0.04184
0.04048
0.08433
0.02659
0.04288
0.05942
0.10725
0.09566
0.11559
0.07737
0.11681
0.13818
0.14848
0.15301
0.25541
0.15790
0.15247
0.10750
0.10007
0.11159
0.06882
0.04934

KeyboardInterrupt: 

In [82]:
input = tokenizer(texts, padding=True, return_tensors="pt", max_length=64, truncation=True)

In [85]:
logits = model(input["input_ids"], attention_mask=input["attention_mask"])

latents torch.Size([5, 32, 768])
tensor([[-0.0284, -2.0810,  0.9054,  ..., -1.8932,  0.2363, -0.3202],
        [-0.0770, -0.0791, -0.0988,  ...,  0.6037,  0.7646, -1.2721],
        [-0.0058,  1.7168, -0.5959,  ...,  0.5185, -0.9237,  1.0923],
        ...,
        [-2.7830, -0.3116,  0.0957,  ..., -2.3433,  0.7221, -1.3160],
        [-0.9613, -0.7737, -1.5189,  ..., -0.3138, -1.4499, -1.7092],
        [-2.5439,  1.1114, -1.0132,  ..., -1.2546,  0.2382, -1.8102]],
       grad_fn=<SelectBackward0>)


In [84]:
logits

tensor([[[-0.4396, -0.8253, -0.5283,  ..., -0.0704,  0.5757, -0.9302],
         [-0.7846,  0.3032, -0.4085,  ...,  0.7150,  0.1543,  0.1302],
         [-0.3213,  0.1252,  0.8472,  ...,  0.8545, -1.0129, -0.5052],
         ...,
         [-1.2573,  0.2779, -0.7370,  ..., -0.1225,  0.5838,  0.3687],
         [-0.0331, -0.5335, -0.3597,  ...,  1.3123,  1.1517,  0.2731],
         [ 1.0163, -0.8445, -0.5798,  ..., -0.4975,  0.1629,  0.6302]],

        [[-0.6367, -0.6490, -0.5346,  ..., -0.0164,  0.6423, -1.1062],
         [-0.8178,  0.1542, -0.4438,  ...,  0.7381, -0.1106,  0.2450],
         [-0.2838,  0.0880,  0.8183,  ...,  0.8263, -1.0323, -0.5716],
         ...,
         [-1.4605,  0.3859, -0.7631,  ..., -0.0396,  0.6596,  0.2489],
         [ 0.1524, -0.5320, -0.4250,  ...,  1.4705,  1.0991,  0.2965],
         [ 0.8439, -0.8857, -0.6250,  ..., -0.4668,  0.1538,  0.6959]],

        [[-0.4414, -0.5522, -0.4304,  ..., -0.0622,  0.7180, -1.0274],
         [-0.8686,  0.3472, -0.3957,  ...,  0

In [30]:
a = torch.Tensor(
    [
        [1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1]
    ]
)

In [31]:
b = torch.Tensor(
    [
        [1, 1, 0, 0],
        [1, 1, 1, 0]
    ]
)

In [37]:
(a.view(2, 1, 5).repeat(1, 4, 1) * b.view(2, 4, 1)).shape

torch.Size([2, 4, 5])