In [1]:
# Training the brain, using text alone.

In [2]:
from visual_transformer import *

In [3]:
from pathlib import Path
from tokenizers import ByteLevelBPETokenizer
from tokenizers.processors import BertProcessing
from torch.utils.data import Dataset

In [4]:
# --- Compute device ---------------------------------------------------------
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
#device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # -- penguins.farm version
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") # -- penguins.army version

# --- Corpus / vocabulary settings -------------------------------------------
# NOTE(review): `paths` and `vocab_size` are not referenced by any cell visible
# in this notebook; they look like leftovers from tokenizer training (see the
# commented-out `save_model` call below) — confirm before removing.
vocab_size = 10000
paths = ["text_pretraining_data/eng_sentences_pruned-train.txt"]
# tokenizer.save_model(".", "tokenizer/eng_sentences_tokenizer_vc10000")

# --- Tokenizer --------------------------------------------------------------
# Load the byte-level BPE tokenizer from its saved vocab / merges files.
tokenizer = ByteLevelBPETokenizer(
    "./text_pretraining_tokenizer/eng_sentences_tokenizer_vc10000_v2-vocab.json",
    "./text_pretraining_tokenizer/eng_sentences_tokenizer_vc10000_v2-merges.txt",
)

# BertProcessing takes (sep, cls) pairs of (token, id): "</s>" acts as the
# separator and "<s>" as the leading classifier/start token.
_sep_pair = ("</s>", tokenizer.token_to_id("</s>"))
_cls_pair = ("<s>", tokenizer.token_to_id("<s>"))
tokenizer._tokenizer.post_processor = BertProcessing(_sep_pair, _cls_pair)

# Cap every encoding at 32 tokens and pad batches to their longest member.
tokenizer.enable_truncation(max_length=32)
tokenizer.enable_padding()

## Dataset
class SampleDataset(Dataset):
    """Sentence dataset: one tokenized, fixed-length example per text line.

    Reads every ``*-train.txt`` (or ``*-eval.txt`` when ``evaluate`` is true)
    file under ``./text_pretraining_data/``, encodes each line with the given
    tokenizer and keeps the resulting token-id lists in ``self.examples``.

    Args:
        seq_length: Maximum number of tokens per example. Encodings are
            truncated to this length and padded up to it, so every example
            has exactly ``seq_length`` ids.
        evaluate: When true, load the ``*-eval.txt`` files instead of the
            ``*-train.txt`` files.
        tokenizer: Tokenizer to use. It is reconfigured in place (post
            processor, truncation, padding). When ``None`` a
            ``ByteLevelBPETokenizer`` is loaded from disk.
        device: Device that ``__getitem__`` moves each example to;
            defaults to ``'cpu'``.
    """

    def __init__(self, seq_length = 32, evaluate: bool = False, tokenizer=None, device = None):
        self.device = 'cpu' if device is None else device
        self.seq_length = seq_length

        if tokenizer is None:
            # NOTE(review): these file names differ from the "vc10000_v2"
            # files loaded for the notebook-level tokenizer — confirm that
            # these fallback files exist and are the intended vocabulary.
            tokenizer = ByteLevelBPETokenizer(
                "./text_pretraining_tokenizer/eng_sentences_tokenizer_v2-vocab.json",
                "./text_pretraining_tokenizer/eng_sentences_tokenizer_v2-merges.txt",
            )

        # BertProcessing takes (sep, cls) pairs of (token, id).
        tokenizer._tokenizer.post_processor = BertProcessing(
            ("</s>", tokenizer.token_to_id("</s>")),
            ("<s>", tokenizer.token_to_id("<s>")),
        )
        tokenizer.enable_truncation(max_length=self.seq_length)
        # Pad to a fixed length. Without `length`, encode_batch only pads to
        # the longest line of each call (i.e. of each file), so examples coming
        # from different files could end up with different lengths and could
        # not be stacked into one batch.
        tokenizer.enable_padding(length=self.seq_length)
        # or use the RobertaTokenizer from `transformers` directly.

        self.examples = []

        pattern = "*-eval.txt" if evaluate else "*-train.txt"
        # Sorted so the example order does not depend on filesystem listing order.
        for src_file in sorted(Path("./text_pretraining_data/").glob(pattern)):
            print("🔥", src_file)
            lines = src_file.read_text(encoding="utf-8").splitlines()
            self.examples.extend(encoding.ids for encoding in tokenizer.encode_batch(lines))

    def __len__(self):
        """Return the number of encoded lines."""
        return len(self.examples)

    def __getitem__(self, i):
        """Return example ``i`` as a 1-D tensor of token ids on ``self.device``."""
        # Examples were already truncated/padded when they were encoded.
        return torch.tensor(self.examples[i]).to(self.device)

In [5]:
# Build the agent model with its default configuration, then move its
# parameters onto the selected device.
brain = DefaultAgentBrain()
brain = brain.to(device)

In [6]:
# Training and evaluation splits share the notebook-level tokenizer.
# No `device` is passed, so examples are produced as CPU tensors.
sdt = SampleDataset(evaluate=False, tokenizer=tokenizer)
sdv = SampleDataset(evaluate=True, tokenizer=tokenizer)

🔥 text_pretraining_data/eng_sentences_pruned-train.txt
🔥 text_pretraining_data/eng_sentences_pruned-eval.txt


In [7]:
from torch.utils.data import DataLoader

# Loader settings: sentences per optimisation step, and number of background
# worker processes (0 = load in the main process).
batch_size = 80 #32*4*2#8
num_workers = 0

# Reshuffle the training split every epoch.
train_loader = DataLoader(
    dataset=sdt,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
)

In [8]:
# Cross-entropy over token classes; positions whose target id is 0 are skipped.
criterion = nn.CrossEntropyLoss(ignore_index=0)

def get_loss(res, inputs):
    """Next-token loss: prediction at position t is scored against token t+1.

    Args:
        res: Model output, sliced here as a 3-D tensor along its last axis.
            Presumably (batch, vocab, seq), since CrossEntropyLoss expects
            the class dimension at axis 1 — TODO confirm against the decoder.
        inputs: Token ids, indexed here as (batch, seq).

    Returns:
        The value of ``criterion`` on the shifted pair, summed to a scalar.
    """
    # Drop the last prediction (nothing follows it) and the first target
    # (nothing predicts it) so the two line up.
    shifted_predictions = res[:, :, :-1]
    shifted_targets = inputs[:, 1:]
    return criterion(shifted_predictions, shifted_targets).sum()

In [None]:
from tqdm import tqdm

# The loss (`criterion` / `get_loss`) is defined in the loss cell above; it is
# intentionally not redefined here.
optimizer = optim.Adam(brain.parameters(), lr=0.0001*80/256, eps=1e-9)#, #betas=(0.9, 0.98), eps=1e-9)

epochs = 16

# torch.save does not create missing directories, so make sure it exists
# before the first checkpoint is written.
Path('brain_checkpoints').mkdir(parents=True, exist_ok=True)


def _save_text_checkpoint():
    """Overwrite the text encoder / decoder checkpoint files with the current weights."""
    torch.save(brain.text_enc.state_dict(), 'brain_checkpoints/text_encoder_weights_v3.pth')
    torch.save(brain.text_dec.state_dict(), 'brain_checkpoints/text_decoder_weights_v3.pth')


for epoch in range(epochs):
    brain.train()
    train_loss = 0
    for i, inputs in enumerate(tqdm(train_loader, desc=f"Training Epoch {epoch + 1}/{epochs}")):
        inputs = inputs.to(device)
        # Random stand-in for the image context (easier for pretraining).
        # Sized from the actual batch so a smaller final batch still lines up
        # with `inputs`.
        img_context = torch.randn((inputs.size(0), 256, 768), device=inputs.device)
        src_attention_mask, src_key_padding_mask = brain.get_masks(inputs, use_masks=True)
        text_encoding = brain.get_text_encoding(inputs, src_attention_mask, src_key_padding_mask)
        res = brain.get_text_decoding(text_encoding, src_attention_mask, src_key_padding_mask, img_context, return_full=True)
        loss = get_loss(res, inputs)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        train_loss += loss.item()
        # Every 1000 batches (including the very first one): report the running
        # mean loss of this epoch and refresh the checkpoint.
        if i % 1000 == 0:
            avg_train_loss = train_loss / (i + 1)
            print(f"Average Training Loss at batch {i + 1}: {avg_train_loss}")
            _save_text_checkpoint()
    avg_train_loss = train_loss / len(train_loader)
    # 1-based epoch number, matching the progress-bar description.
    print(f"Average Training Loss after epoch {epoch + 1}: {avg_train_loss}")
    _save_text_checkpoint()

Training Epoch 1/16:   0%|                                                                                                                                                                     | 0/20000 [00:00<?, ?it/s]

Average Training Loss at batch 1: 9.316092491149902


Training Epoch 1/16:   5%|███████▌                                                                                                                                                | 1000/20000 [04:29<1:25:08,  3.72it/s]

Average Training Loss at batch 1001: 4.725765901607471


Training Epoch 1/16:  10%|███████████████▏                                                                                                                                        | 2000/20000 [09:00<1:20:48,  3.71it/s]

Average Training Loss at batch 2001: 4.40001886811988


Training Epoch 1/16:  15%|██████████████████████▊                                                                                                                                 | 3000/20000 [13:31<1:16:24,  3.71it/s]

Average Training Loss at batch 3001: 4.2339056343763755


Training Epoch 1/16:  20%|██████████████████████████████▍                                                                                                                         | 4000/20000 [18:02<1:11:56,  3.71it/s]

Average Training Loss at batch 4001: 4.127308309748601


Training Epoch 1/16:  25%|██████████████████████████████████████                                                                                                                  | 5000/20000 [22:33<1:07:24,  3.71it/s]

Average Training Loss at batch 5001: 4.04621469686089


Training Epoch 1/16:  30%|█████████████████████████████████████████████▌                                                                                                          | 6000/20000 [27:04<1:02:53,  3.71it/s]

Average Training Loss at batch 6001: 3.9802095581741854


Training Epoch 1/16:  35%|█████████████████████████████████████████████████████▉                                                                                                    | 7000/20000 [31:35<58:19,  3.71it/s]

Average Training Loss at batch 7001: 3.9264698808422396


Training Epoch 1/16:  40%|█████████████████████████████████████████████████████████████▌                                                                                            | 8000/20000 [36:06<53:51,  3.71it/s]

Average Training Loss at batch 8001: 3.881814607693663


Training Epoch 1/16:  45%|█████████████████████████████████████████████████████████████████████▎                                                                                    | 9000/20000 [40:37<49:30,  3.70it/s]

Average Training Loss at batch 9001: 3.841321705262034


Training Epoch 1/16:  50%|████████████████████████████████████████████████████████████████████████████▌                                                                            | 10000/20000 [45:08<44:52,  3.71it/s]

Average Training Loss at batch 10001: 3.8063207524691256


Training Epoch 1/16:  55%|████████████████████████████████████████████████████████████████████████████████████▏                                                                    | 11000/20000 [49:40<40:31,  3.70it/s]

Average Training Loss at batch 11001: 3.77447204205808


Training Epoch 1/16:  60%|███████████████████████████████████████████████████████████████████████████████████████████▊                                                             | 12000/20000 [54:12<35:56,  3.71it/s]

Average Training Loss at batch 12001: 3.7459280223471354


Training Epoch 1/16:  65%|███████████████████████████████████████████████████████████████████████████████████████████████████▍                                                     | 13000/20000 [58:43<31:28,  3.71it/s]

Average Training Loss at batch 13001: 3.7190133005772616


Training Epoch 1/16:  70%|█████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                             | 14000/20000 [1:03:14<27:02,  3.70it/s]

Average Training Loss at batch 14001: 3.6944150092661205


Training Epoch 1/16:  75%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                                     | 15000/20000 [1:07:45<22:31,  3.70it/s]

Average Training Loss at batch 15001: 3.6709845734233815


Training Epoch 1/16:  80%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                              | 16000/20000 [1:12:17<17:58,  3.71it/s]

Average Training Loss at batch 16001: 3.6496240171847676


Training Epoch 1/16:  85%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                      | 17000/20000 [1:16:48<13:30,  3.70it/s]

Average Training Loss at batch 17001: 3.6288734807749927


Training Epoch 1/16:  90%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉               | 18000/20000 [1:21:19<08:59,  3.71it/s]

Average Training Loss at batch 18001: 3.6098645449307196


Training Epoch 1/16:  95%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍       | 19000/20000 [1:25:50<04:29,  3.71it/s]

Average Training Loss at batch 19001: 3.5923654388411475


Training Epoch 1/16: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20000/20000 [1:30:22<00:00,  3.69it/s]


Average Training Loss after epoch 0: 3.5751865940213206


Training Epoch 2/16:   0%|                                                                                                                                                                     | 0/20000 [00:00<?, ?it/s]

Average Training Loss at batch 1: 3.109647750854492


Training Epoch 2/16:   5%|███████▌                                                                                                                                                | 1000/20000 [04:30<1:25:26,  3.71it/s]

Average Training Loss at batch 1001: 3.2007551610053


Training Epoch 2/16:  10%|███████████████▏                                                                                                                                        | 2000/20000 [09:02<1:21:09,  3.70it/s]

Average Training Loss at batch 2001: 3.193572289225222


Training Epoch 2/16:  15%|██████████████████████▊                                                                                                                                 | 3000/20000 [13:33<1:16:32,  3.70it/s]

Average Training Loss at batch 3001: 3.1835153334063713


Training Epoch 2/16:  20%|██████████████████████████████▍                                                                                                                         | 4000/20000 [18:04<1:12:06,  3.70it/s]

Average Training Loss at batch 4001: 3.177908931961956


Training Epoch 2/16:  25%|██████████████████████████████████████                                                                                                                  | 5000/20000 [22:35<1:07:19,  3.71it/s]

Average Training Loss at batch 5001: 3.172420911516244


Training Epoch 2/16:  30%|█████████████████████████████████████████████▌                                                                                                          | 6000/20000 [27:07<1:03:06,  3.70it/s]

Average Training Loss at batch 6001: 3.166583309569293


Training Epoch 2/16:  35%|█████████████████████████████████████████████████████▉                                                                                                    | 7000/20000 [31:38<58:35,  3.70it/s]

Average Training Loss at batch 7001: 3.1612249050050476


Training Epoch 2/16:  40%|█████████████████████████████████████████████████████████████▌                                                                                            | 8000/20000 [36:09<53:51,  3.71it/s]

Average Training Loss at batch 8001: 3.155896200267781


Training Epoch 2/16:  45%|█████████████████████████████████████████████████████████████████████▎                                                                                    | 9000/20000 [40:40<49:28,  3.71it/s]

Average Training Loss at batch 9001: 3.1501564208487247


Training Epoch 2/16:  50%|████████████████████████████████████████████████████████████████████████████▌                                                                            | 10000/20000 [45:11<44:56,  3.71it/s]

Average Training Loss at batch 10001: 3.145067584143914


Training Epoch 2/16:  55%|████████████████████████████████████████████████████████████████████████████████████▏                                                                    | 11000/20000 [49:43<40:25,  3.71it/s]

Average Training Loss at batch 11001: 3.1398323572892384


Training Epoch 2/16:  60%|███████████████████████████████████████████████████████████████████████████████████████████▊                                                             | 12000/20000 [54:14<36:03,  3.70it/s]

Average Training Loss at batch 12001: 3.1342914981570265


Training Epoch 2/16:  65%|███████████████████████████████████████████████████████████████████████████████████████████████████▍                                                     | 13000/20000 [58:45<31:29,  3.70it/s]

Average Training Loss at batch 13001: 3.129219017160112


Training Epoch 2/16:  70%|█████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                             | 14000/20000 [1:03:16<27:02,  3.70it/s]

Average Training Loss at batch 14001: 3.1235915145808635


Training Epoch 2/16:  75%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                                     | 15000/20000 [1:07:48<22:27,  3.71it/s]

Average Training Loss at batch 15001: 3.119052573908377


Training Epoch 2/16:  80%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                              | 16000/20000 [1:12:20<17:59,  3.71it/s]

Average Training Loss at batch 16001: 3.114170181550843


Training Epoch 2/16:  85%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                      | 17000/20000 [1:16:51<13:27,  3.72it/s]

Average Training Loss at batch 17001: 3.1096745298144524


Training Epoch 2/16:  90%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉               | 18000/20000 [1:21:22<08:58,  3.71it/s]

Average Training Loss at batch 18001: 3.1053871576948024


Training Epoch 2/16:  95%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍       | 19000/20000 [1:25:53<04:30,  3.70it/s]

Average Training Loss at batch 19001: 3.101329814044496


Training Epoch 2/16: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20000/20000 [1:30:25<00:00,  3.69it/s]


Average Training Loss after epoch 1: 3.0971908952951432


Training Epoch 3/16:   0%|                                                                                                                                                                     | 0/20000 [00:00<?, ?it/s]

Average Training Loss at batch 1: 2.7698516845703125


Training Epoch 3/16:   5%|███████▌                                                                                                                                                | 1000/20000 [04:31<1:25:41,  3.70it/s]

Average Training Loss at batch 1001: 2.970735984367805


Training Epoch 3/16:  10%|███████████████▏                                                                                                                                        | 2000/20000 [09:02<1:20:46,  3.71it/s]

Average Training Loss at batch 2001: 2.9645126723814226


Training Epoch 3/16:  15%|██████████████████████▊                                                                                                                                 | 3000/20000 [13:33<1:16:31,  3.70it/s]

Average Training Loss at batch 3001: 2.9605469156288775


Training Epoch 3/16:  20%|██████████████████████████████▍                                                                                                                         | 4000/20000 [18:05<1:12:05,  3.70it/s]

Average Training Loss at batch 4001: 2.959237880034615


Training Epoch 3/16:  20%|██████████████████████████████▍                                                                                                                         | 4002/20000 [18:11<7:35:46,  1.71s/it]