# In [2]:
import torch
import torch.nn.functional as F
from torch import optim
from torch.utils.data import Dataset, DataLoader
from DalleDecoder import DecoderOnlyTransformer
from DallEdVAE import dVAE
from train_transformer import train
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import time
import pickle


############################ dVAE ############################
# (I) Loading dVAE
# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Positional constructor arguments of dVAE. Their exact meaning is defined in
# DallEdVAE.dVAE, which is not part of this file.
inp_ch = 3
n_hid = 256
n_init = 128
bpg = 2
K = 8192  # also reused below as the image vocabulary size
D = 512
Beta = 6.6
path_to_dvae = 'path_to_pretrained_dVAE'

dvae = dVAE(inp_ch, n_hid, n_init, bpg, K, D, Beta).to(device)
# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was saved on a GPU can still be loaded on a CPU-only machine.
dvae.load_state_dict(torch.load(path_to_dvae, map_location=device))
# NOTE(review): the dVAE is only used to produce image codes below; consider
# calling dvae.eval() if it contains dropout / batch-norm layers -- confirm.

######################## Data Loading ########################
# (II) Data (img-txt pairs) to Train Transformer
# Path to the pickled dataset (the value looks like a placeholder that must
# be replaced with a real file before running).
path_to_data = 'path_to_img_txt_pair_dataset'
# Mini-batch size handed to the DataLoader further down.
batch_size = 256

# NOTE(security): pickle.load can execute arbitrary code while unpickling, so
# path_to_data must only ever point at a file from a trusted source.
# Each loaded sample is later unpacked as `img, (file_name, txt)`.
with open(path_to_data, 'rb') as f:
    dataset = pickle.load(f)

# torchtext's 'basic_english' tokenizer; shared by the vocabulary builder and
# by MyDataset.
tokenizer = get_tokenizer('basic_english')


def create_tokens(dataset):
    """Lazily yield the token list of every caption in *dataset*.

    The caption of a sample is read from ``sample[1][1]``.
    """
    captions = (sample[1][1] for sample in dataset)
    yield from map(tokenizer, captions)

# Build the caption vocabulary with "<start>" as its only special token.
# MyDataset.text2token treats id 0 as both padding and <bos>, so "<start>" is
# expected to sit at index 0 -- NOTE(review): confirm against torchtext.
vocab = build_vocab_from_iterator(create_tokens(dataset), specials=["<start>"])
# Out-of-vocabulary tokens are mapped to the "<start>" id as well.
vocab.set_default_index(vocab["<start>"])

# Persist the vocabulary to disk (plain string: there is nothing to format).
with open('vocab.pkl', 'wb') as f:
    pickle.dump(vocab, f)

text_vocab_size = len(vocab)  # 16384 in the paper
text_seq_len = 256

class MyDataset(Dataset):
    """Dataset that turns (image, caption) samples into token tensors.

    Captions are tokenized, mapped to vocabulary ids and brought to a fixed
    length; images are converted to discrete codes by the dVAE.

    Args:
        dataset: Indexable collection whose items unpack as
            ``img, (file_name, txt)``.
        dvae: Object exposing ``get_code_book(batch)``.
        tokenizer: Callable mapping a caption string to a list of tokens.
        vocab: Callable mapping a list of tokens to a list of integer ids.
        text_seq_len: Number of text positions (excluding the leading <bos>).
        text_vocab_size: Offset at which the per-position padding ids start.
        device: Device the dVAE input and the returned tensors are moved to.
    """

    def __init__(self, dataset, dvae, tokenizer, vocab, text_seq_len, text_vocab_size, device):
        super().__init__()
        self.dataset = dataset
        self.dvae = dvae
        self.tokenizer = tokenizer
        self.vocab = vocab
        self.text_seq_len = text_seq_len
        self.text_vocab_size = text_vocab_size
        self.device = device

    def _sent_padding(self, sent_vec, maxlen):
        """Return *sent_vec* as an int64 tensor holding exactly *maxlen* ids.

        Shorter inputs are right-padded with 0; longer inputs are truncated.
        """
        # The explicit dtype keeps an empty caption from producing a float32
        # tensor (torch.tensor([]) defaults to float).
        ids = torch.tensor(sent_vec, dtype=torch.long)[:maxlen]
        return F.pad(ids, (0, maxlen - ids.numel()))

    def text2token(self, text):
        """Convert a caption into a tensor of ``text_seq_len + 1`` token ids.

        Every position whose id is 0 (padding, or any token the vocabulary
        maps to 0) is replaced by the position-specific id
        ``text_vocab_size + position``; a 0 is then prepended as <bos>.
        """
        ids = self.vocab(self.tokenizer(text))
        text_vector = self._sent_padding(ids, maxlen=self.text_seq_len)
        text_range = torch.arange(self.text_seq_len) + self.text_vocab_size
        text = torch.where(text_vector == 0, text_range, text_vector)  # <pad_i> tokens
        return F.pad(text, (1, 0), value=0)  # add <bos>

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(txt_tokens, img_tokens)`` for sample *idx*, on ``device``."""
        img, (_file_name, txt) = self.dataset[idx]
        txt_tokens = self.text2token(txt)
        # txt_tokens: [257] (for text_seq_len = 256)
        # The image is moved to `device` before it is encoded; the caller is
        # expected to keep the dVAE on that same device. no_grad avoids
        # building an autograd graph for what is pure preprocessing.
        with torch.no_grad():
            img_tokens = self.dvae.get_code_book(img.unsqueeze(0).to(self.device))  # inp: [1, 3, 256, 256]
        # img_tokens: [1, 1024]
        img_tokens = img_tokens.squeeze(0)
        # img_tokens: [1024]
        return (txt_tokens.to(self.device), img_tokens.to(self.device))

# Wrap the raw samples in MyDataset (rebinding `dataset`) and batch them in
# shuffled order.
dataset = MyDataset(
    dataset=dataset,
    dvae=dvae,
    tokenizer=tokenizer,
    vocab=vocab,
    text_seq_len=text_seq_len,
    text_vocab_size=text_vocab_size,
    device=device,
)
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

######################## Transformer #########################
# (III) Loading Transformer
# Constructor arguments of DecoderOnlyTransformer; their exact meaning is
# defined in DalleDecoder, which is not part of this file.
image_vocab_size = K  # 8192
image_seq_len = 1024
d_model = 512
N = 64
heads = 64
d_ff = 2048
# Optional checkpoint to resume from; None trains from scratch.
path_to_transformer = None

# Vocabulary ids plus one dedicated padding id per text position.
total_len_text_vocab = text_vocab_size + text_seq_len  # 16384 + 256 in the paper

# Passed to train(); presumably the weight of the image term in the loss.
loss_img_weight = 7

epochs = 20
lr = 1e-3
print_step = 10

dec_only_transformer = DecoderOnlyTransformer(text_vocab_size, text_seq_len,
                                              image_vocab_size, image_seq_len,
                                              d_model, N, heads, d_ff).to(device)

if path_to_transformer is not None:
    # map_location lets a GPU-saved checkpoint load on whatever `device` is.
    dec_only_transformer.load_state_dict(
        torch.load(path_to_transformer, map_location=device))

optimizer = optim.Adam(dec_only_transformer.parameters(), lr=lr)

########################## Training ###########################
# (IV) Training the Transformer
# perf_counter is monotonic, so the measured duration is not distorted by
# system clock adjustments during a long training run.
start_time = time.perf_counter()
# total_L holds whatever train() returns (defined in train_transformer).
total_L = train(dec_only_transformer, optimizer, dataloader, epochs, text_seq_len,
                total_len_text_vocab, loss_img_weight, print_step)

end_time = time.perf_counter()
print(f"Training time: {end_time- start_time:.3f} seconds")