In [1]:
# Training the brain, using text alone.

In [2]:
from visual_transformer import *

In [3]:
from pathlib import Path
from tokenizers import ByteLevelBPETokenizer
from tokenizers.processors import BertProcessing
from torch.utils.data import Dataset

In [4]:
# Use the first CUDA device when one is visible, otherwise fall back to the CPU.
# ("cuda:0" is the penguins.army setting; the penguins.farm variant used plain "cuda".)
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Training text file and the vocabulary size of the tokenizer loaded below.
paths = ["text_pretraining_data/eng_sentences_pruned-train.txt"]
vocab_size = 10000

# Load the byte-level BPE tokenizer from its saved vocab / merges files.
# (It was originally written out with
#  tokenizer.save_model(".", "tokenizer/eng_sentences_tokenizer_vc10000").)
tokenizer = ByteLevelBPETokenizer(
    "./text_pretraining_tokenizer/eng_sentences_tokenizer_vc10000_v2-vocab.json",
    "./text_pretraining_tokenizer/eng_sentences_tokenizer_vc10000_v2-merges.txt",
)

# BertProcessing takes the (token, id) pair for the separator first and the
# pair for the leading token second.
_sep_entry = ("</s>", tokenizer.token_to_id("</s>"))
_cls_entry = ("<s>", tokenizer.token_to_id("<s>"))
tokenizer._tokenizer.post_processor = BertProcessing(_sep_entry, _cls_entry)

# Cap encodings at 32 tokens and pad batches (no fixed pad length is given).
tokenizer.enable_truncation(max_length=32)
tokenizer.enable_padding()

## Dataset
class SampleDataset(Dataset):
    """Tokenized sentence dataset read from line-per-example text files.

    Every file in ``data_dir`` matching ``*-train.txt`` (or ``*-eval.txt`` when
    ``evaluate`` is true) is read, split into lines, and encoded with
    ``tokenizer``. Encodings are truncated to ``seq_length`` tokens and padded
    to the longest encoding across *all* matching files, so every example has
    the same length and can be stacked by the default ``DataLoader`` collate.

    Args:
        seq_length: Maximum number of token ids per example (truncation length).
        evaluate: Read the ``*-eval.txt`` files instead of ``*-train.txt``.
        tokenizer: Tokenizer to use. NOTE: it is configured in place
            (post-processor, truncation and padding are overwritten).
            When None, a ByteLevelBPETokenizer is loaded from disk.
        device: Device each returned tensor is placed on; defaults to 'cpu'.
        data_dir: Directory searched for the source text files.
    """

    def __init__(self, seq_length = 32, evaluate: bool = False, tokenizer=None, device = None,
                 data_dir="./text_pretraining_data/"):
        if device is None:
            device = 'cpu'
        self.device = device
        self.seq_length = seq_length
        if tokenizer is None:
            # NOTE(review): these default file names lack the "_vc10000" part used by
            # the notebook-level tokenizer -- confirm they are the intended files.
            tokenizer = ByteLevelBPETokenizer(
                "./text_pretraining_tokenizer/eng_sentences_tokenizer_v2-vocab.json",
                "./text_pretraining_tokenizer/eng_sentences_tokenizer_v2-merges.txt",
            )
        # BertProcessing takes the separator (token, id) pair first, then the leading-token pair.
        tokenizer._tokenizer.post_processor = BertProcessing(
            ("</s>", tokenizer.token_to_id("</s>")),
            ("<s>", tokenizer.token_to_id("<s>")),
        )
        tokenizer.enable_truncation(max_length=self.seq_length)
        # No fixed length: encode_batch pads to the longest sequence in the batch it is given.
        tokenizer.enable_padding()

        pattern = "*-eval.txt" if evaluate else "*-train.txt"
        lines = []
        # Sort the matches so the example order does not depend on directory listing order.
        for src_file in sorted(Path(data_dir).glob(pattern)):
            print("🔥", src_file)
            lines.extend(src_file.read_text(encoding="utf-8").splitlines())

        # Encode everything in ONE batch so that all examples share a single padded
        # length; encoding file by file would give each file its own length.
        self.examples = [enc.ids for enc in tokenizer.encode_batch(lines)] if lines else []

    def __len__(self):
        """Number of encoded examples."""
        return len(self.examples)

    def __getitem__(self, i):
        """Return the token ids of example ``i`` as a tensor on ``self.device``."""
        # Examples were already padded when they were encoded in __init__.
        return torch.tensor(self.examples[i], device=self.device)

In [5]:
# Build the brain with its default constructor arguments, then move it onto `device`.
brain = DefaultAgentBrain()
brain = brain.to(device)

In [None]:
# sdt reads the "*-train.txt" files, sdv the "*-eval.txt" files; both reuse the
# notebook-level tokenizer (which SampleDataset reconfigures in place).
sdt = SampleDataset(evaluate=False, tokenizer=tokenizer)
sdv = SampleDataset(evaluate=True, tokenizer=tokenizer)

🔥 text_pretraining_data/eng_sentences_pruned-train.txt


In [None]:
from torch.utils.data import DataLoader

# Mini-batch size (the earlier inline note listed 32*4*2 and 8, presumably past settings)
# and the number of loader worker processes (0 = load in the main process).
batch_size = 80
num_workers = 0

# Shuffled loader over the training dataset.
train_loader = DataLoader(
    sdt,
    shuffle=True,
    batch_size=batch_size,
    num_workers=num_workers,
)

In [None]:
# Cross entropy over token classes; targets equal to 0 do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=0)


def get_loss(res, inputs):
    """Next-token loss: the prediction at position t is scored against token t + 1.

    Args:
        res: Model output, sliced on its third axis; it is handed to
            CrossEntropyLoss as the input, which puts the class scores on axis 1.
        inputs: Token ids, sliced on their second axis.

    Returns:
        The criterion's loss for the shifted pair, passed through ``torch.sum``.
    """
    shifted_scores = res[:, :, :-1]   # drop the last position: it has no following token
    next_tokens = inputs[:, 1:]       # drop the first token: nothing predicts it
    loss = criterion(shifted_scores, next_tokens)
    return torch.sum(loss)

In [None]:
import os

from tqdm import tqdm

# Checkpoint files; every save overwrites the previous one.
TEXT_ENCODER_CKPT = 'brain_checkpoints/text_encoder_weights_v4.pth'
TEXT_DECODER_CKPT = 'brain_checkpoints/text_decoder_weights_v4.pth'


def _save_text_checkpoints(model):
    """Save the state dicts of `model.text_enc` and `model.text_dec` to the checkpoint files."""
    # torch.save does not create missing directories, so make sure they exist first.
    os.makedirs(os.path.dirname(TEXT_ENCODER_CKPT), exist_ok=True)
    os.makedirs(os.path.dirname(TEXT_DECODER_CKPT), exist_ok=True)
    torch.save(model.text_enc.state_dict(), TEXT_ENCODER_CKPT)
    torch.save(model.text_dec.state_dict(), TEXT_DECODER_CKPT)


# get_loss looks up the module-level `criterion` when it is called, so this is the
# loss actually used below (targets equal to 0 are ignored).
criterion = nn.CrossEntropyLoss(ignore_index=0)
# NOTE(review): the learning rate looks like a 1e-4 base rate scaled by
# batch size (80) / 256 -- confirm, and tie it to `batch_size` if so.
optimizer = optim.Adam(brain.parameters(), lr=0.0001*80/256, eps=1e-9)#, #betas=(0.9, 0.98), eps=1e-9)

epochs = 16

for epoch in range(epochs):
    brain.train()
    train_loss = 0
    for i, inputs in enumerate(tqdm(train_loader, desc=f"Training Epoch {epoch + 1}/{epochs}")):
        inputs = inputs.to(device)
        # Random stand-in for the image context (easier for pretraining). Its batch
        # dimension follows the actual batch: the last batch of an epoch can be
        # smaller than `batch_size` because the loader does not drop it.
        # NOTE(review): (256, 768) presumably matches the context shape the brain expects -- confirm.
        img_context = torch.randn((inputs.size(0), 256, 768), device=inputs.device)
        src_attention_mask, src_key_padding_mask = brain.get_masks(inputs, use_masks=True)
        text_encoding = brain.get_text_encoding(inputs, src_attention_mask, src_key_padding_mask)
        res = brain.get_text_decoding(text_encoding, src_attention_mask, src_key_padding_mask, img_context, return_full=True)
        loss = get_loss(res, inputs)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        train_loss += loss.item()
        # Every 1000 batches (including the very first one): report the running
        # mean loss for this epoch and checkpoint the text encoder / decoder.
        if i % 1000 == 0:
            avg_train_loss = train_loss / (i + 1)
            print(f"Average Training Loss at batch {i + 1}: {avg_train_loss}")
            _save_text_checkpoints(brain)
    avg_train_loss = train_loss / len(train_loader)
    # Report the epoch 1-based, matching the progress-bar description above.
    print(f"Average Training Loss after epoch {epoch + 1}: {avg_train_loss}")
    _save_text_checkpoints(brain)