<a href="https://colab.research.google.com/github/KhashayarEhteshami/Text-to-Image/blob/main/Text_to_Image.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import io
import h5py
import torch
import ipywidgets
import numpy as np
from torch import nn
from PIL import Image
from IPython import display
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from sklearn.model_selection import StratifiedShuffleSplit

# Run-wide configuration for the notebook.
# `training` is multiplied into the epoch count of the training loop, so 0 runs no epochs.
training = 1
# Reassigned by later cells before the training loop reads it.
num_epoch = 30
# Batch size passed to both DataLoaders.
batch_size = 128
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Kaggle input directories: the flowers HDF5 dataset, and a kernel output directory
# that is checked for a cached "data.npy" when the dataset is built.
dataset_root = os.path.join("/", "kaggle", "input", "flowershd5dataset")
kernel_root = os.path.join("/", "kaggle", "input", "text-to-image-DistilBERT-pytorch")

In [None]:
# Overrides the epoch count from the configuration cell; it is reassigned once more
# in the cell just before the training loop.
num_epoch = 60

In [None]:
from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer

In [None]:
class TrainDataset:
    """Flowers image/caption dataset that yields matching and mismatching pairs.

    Indexing returns ``((right_image, right_text), (wrong_image, wrong_text))``:
    the requested sample plus a second sample drawn at random from the other
    entries. Images are 64x64 tensors produced by ``self.transforms``; each
    text is a dict with ``input_ids``, ``token_type_ids`` and
    ``attention_mask`` lists left-padded to the longest caption in the data.
    """

    # NOTE(review): 5 is the padding id used throughout this notebook (visualize()
    # filters on the same value). Verify it matches ``tokenizer.pad_token_id`` for
    # "distilbert-base-uncased" before changing it, since saved checkpoints were
    # trained against this padding.
    PAD_TOKEN_ID = 5
    PAD_TOKEN_TYPE_ID = 2

    def __init__(self, dataset_root, kernel_root):
        """Load (or build) the sample list and set up the image transforms.

        Args:
            dataset_root: directory containing ``data/flowers/flowers.hdf5``.
            kernel_root: directory checked for a cached ``data.npy``.
        """
        self.tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

        cache_path = os.path.join(kernel_root, "data.npy")
        if os.path.exists(cache_path):
            # SECURITY: allow_pickle=True executes pickled objects; only load
            # caches that this notebook produced itself.
            self.data = np.load(cache_path, allow_pickle=True)
        else:
            hdf5_path = os.path.join(dataset_root, "data", "flowers", "flowers.hdf5")
            # prepareData copies everything it needs into memory, so the file
            # can be closed as soon as it returns.
            with h5py.File(hdf5_path, mode="r") as f:
                self.data = self.prepareData(f['train'])
        # Written to the working directory on every run (also after a cache hit).
        np.save('data.npy', self.data)
        self.max_seq_len = max(map(lambda x: len(x["text"]["input_ids"]), self.data))

        # NOTE(review): the three rotations use degenerate (a, a) ranges and the
        # vertical flip has p=1, so only the horizontal flip looks random here.
        # Confirm that this is the intended augmentation.
        self.transforms = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(degrees=(90, 90)),
            transforms.RandomRotation(degrees=(180, 180)),
            transforms.RandomRotation(degrees=(270, 270)),
            transforms.RandomVerticalFlip(p=1),
            transforms.Resize((64, 64)),
            transforms.ToTensor(),
        ])

    def prepareData(self, data):
        """Decode every HDF5 entry into an in-memory record.

        Args:
            data: mapping of sample name to an entry exposing ``'img'``
                (encoded image bytes) and ``'txt'`` (caption).

        Returns:
            List of dicts with a 256x256 ``image`` array and a ``text`` dict
            of unpadded ``input_ids``, ``token_type_ids``, ``attention_mask``.
        """
        preparedData = []
        for img_name in tqdm(data):
            entry = data[img_name]
            raw_bytes = bytes(np.array(entry['img']))
            image = np.array(Image.open(io.BytesIO(raw_bytes)).resize((256, 256)))
            text = np.array(entry['txt']).item().strip()
            input_ids = self.tokenizer.encode(text)
            # All positions are segment 0 except the final token, which is 1.
            token_type_ids = [0] * (len(input_ids) - 1) + [1]
            attention_mask = [1] * len(token_type_ids)
            preparedData.append({
                "image": image,
                "text": {
                    "input_ids": input_ids,
                    "token_type_ids": token_type_ids,
                    "attention_mask": attention_mask
                },
            })
        return preparedData

    def padTokens(self, text_dict):
        """Left-pad the three token lists of ``text_dict`` to ``max_seq_len``.

        The lists are replaced by new lists on ``text_dict`` (the dict itself
        is modified and returned). Inputs already at or beyond ``max_seq_len``
        are returned unpadded.
        """
        pad_len = self.max_seq_len - sum(text_dict["attention_mask"])
        text_dict['input_ids'] = [self.PAD_TOKEN_ID] * pad_len + text_dict['input_ids']
        text_dict['token_type_ids'] = [self.PAD_TOKEN_TYPE_ID] * pad_len + text_dict['token_type_ids']
        text_dict['attention_mask'] = [0] * pad_len + text_dict['attention_mask']
        return text_dict

    @staticmethod
    def collate_fn_module(batch, idx, target_device=None):
        """Stack one side (right or wrong) of every item in ``batch``.

        Args:
            batch: list of dataset items as returned by ``__getitem__``.
            idx: 0 to collate the matching pairs, 1 for the mismatching ones.
            target_device: device for the resulting tensors; defaults to the
                notebook-level ``device``.

        Returns:
            ``(images, texts)`` where ``images`` is the stacked image tensor
            and ``texts`` maps each text key to a 2-D tensor.
        """
        if target_device is None:
            target_device = device

        images, texts = [], {}
        for data in batch:
            image, text_dict = data[idx]
            images.append(image)
            # Read the texts from the same side as the image. Always reading
            # side 0 would make the "wrong" texts identical to the right ones.
            for key, values in text_dict.items():
                texts.setdefault(key, []).append(values)

        images = torch.stack(images).to(target_device)
        texts = {key: torch.tensor(values).to(target_device) for key, values in texts.items()}
        return images, texts

    def collate_fn(self, batch):
        """Collate a batch into ``((right_images, right_texts), (wrong_images, wrong_texts))``."""
        right_images, right_texts = self.collate_fn_module(batch, 0)
        wrong_images, wrong_texts = self.collate_fn_module(batch, 1)
        return (right_images, right_texts), (wrong_images, wrong_texts)

    def __len__(self):
        return len(self.data)

    def _load_sample(self, index):
        """Return the transformed image and padded text dict for one record."""
        record = self.data[index]
        image = self.transforms(Image.fromarray(record["image"]))
        # Copy so padTokens does not touch the stored record.
        text = self.padTokens(record["text"].copy())
        return image, text

    def __getitem__(self, right_idx):
        """Return the sample at ``right_idx`` paired with a different random sample.

        Raises:
            ValueError: if the dataset holds a single record, so no
                mismatching sample exists.
        """
        right_image, right_text = self._load_sample(right_idx)

        # Uniform draw over all indices except right_idx without building the
        # full candidate list: sample from n-1 slots and skip over right_idx.
        total = len(self.data)
        wrong_idx = np.random.randint(total - 1)
        if wrong_idx >= right_idx % total:
            wrong_idx += 1
        wrong_image, wrong_text = self._load_sample(wrong_idx)
        return (right_image, right_text), (wrong_image, wrong_text)

In [None]:
%%time
# Build the dataset once; the cell magic above reports how long construction takes.
train_dataset = TrainDataset(dataset_root, kernel_root)

In [None]:
def visualize(image, text, tokenizer=None, ax = None, title = None):
    """Draw ``image`` on an axis, optionally captioned with its decoded text.

    Args:
        image: tensor whose axes are permuted with ``(1, 2, 0)`` before
            being passed to ``imshow``.
        text: dict holding an ``"input_ids"`` list; only read when a
            tokenizer is given.
        tokenizer: object with a ``decode`` method; when ``None`` no caption
            is drawn.
        ax: axis to draw on; a new 10x10 inch figure is created when omitted.
        title: optional axis title.
    """
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=(10, 10))
    if title:
        ax.set_title(title)
    ax.imshow(image.permute(1, 2, 0))
    if tokenizer is not None:
        # Token id 5 is the padding id used by the dataset; drop it before decoding.
        caption_ids = [t for t in text["input_ids"] if t != 5]
        ax.set_xlabel(tokenizer.decode(caption_ids))

In [None]:
# Pick a random item and show its matching and mismatching images side by side.
idx = np.random.randint(len(train_dataset))
right, wrong = train_dataset[idx]
_, axs = plt.subplots(1, 2, figsize=(5, 5))
# Only the matching sample is passed a tokenizer, so only it gets a caption label.
visualize(*right, train_dataset.tokenizer, axs[0], "Right")
visualize(*wrong, ax = axs[1], title = "Wrong")
plt.show()

In [None]:
# Resume from a previously saved checkpoint when one exists; otherwise start
# from an empty dict so every later `checkpoint.get(...)` falls back to its default.
# The checkpoint is only read here: writing it straight back to the path it was
# loaded from adds nothing and risks truncating the file if the run is interrupted.
checkpoint_path = './The Flower checkpoint3.pt'
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=device)
else:
    print('the file is not exist')
    checkpoint = {}

In [None]:
# Reuse the train/validation split stored in the checkpoint, if any, so a
# resumed run keeps the same partition.
train_indices = checkpoint.get('train_indices', None)
valid_indices = checkpoint.get('valid_indices', None)
if train_indices is None or valid_indices is None:
    # No stored split: hold out 20% for validation, stratified on the caption's
    # token count modulo 10.
    sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2)
    splitter = sss.split(train_dataset, np.array([len(x["text"]["input_ids"])%10 for x in train_dataset.data]))
    train_indices, valid_indices = next(splitter)
# Creating PT data samplers and loaders
# Both loaders wrap the same dataset object and differ only in the index subset
# they sample from; drop_last=True discards a final batch smaller than batch_size.
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(valid_indices)
train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, collate_fn=train_dataset.collate_fn, drop_last=True)
valid_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=valid_sampler, collate_fn=train_dataset.collate_fn, drop_last=True)

In [None]:
# Defaults picked up by the model classes below as constructor default arguments.
# Name of the Hugging Face checkpoint used for the text encoder.
text_encoder_model = "distilbert-base-uncased"
# Default for the `pretrained` constructor argument of the model classes.
pretrained = True
# Default for `trainable`; the text encoder assigns it to requires_grad on its weights.
trainable = False

In [None]:
class text(nn.Module):
    """DistilBERT wrapper that returns the per-token hidden states of a caption.

    Args:
        model_name: Hugging Face checkpoint name used when ``pretrained``.
        pretrained: load the published weights when true; otherwise build a
            randomly initialised DistilBERT from the default config.
        trainable: value assigned to ``requires_grad`` on every encoder weight.
    """

    def __init__(self, model_name=text_encoder_model, pretrained=pretrained, trainable=trainable):
        super().__init__()

        # Honour the `pretrained` flag instead of always loading the weights.
        if pretrained:
            self.transformer = DistilBertModel.from_pretrained(model_name)
        else:
            self.transformer = DistilBertModel(config=DistilBertConfig())

        for p in self.transformer.parameters():
            p.requires_grad = trainable

        # Position of the [CLS] token. Not used by forward(), which returns the
        # whole sequence; kept so existing attribute access keeps working.
        self.target_token_idx = 0

    def forward(self, input_ids, token_type_ids, attention_mask):
        """Encode a batch of token ids.

        ``token_type_ids`` is accepted so the dataset's text dict can be
        splatted into this call, but it is not forwarded to the transformer.

        Returns:
            The transformer's ``last_hidden_state`` for every position.
        """
        encoded = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        return encoded.last_hidden_state

In [None]:
class Generator(nn.Module):
    """Text-conditioned image generator.

    A bidirectional LSTM summarises the encoded caption into a 512-d vector,
    the noise vector is projected to 512-d, and the concatenated 1024 values
    are reshaped to 64x4x4 and upsampled by transposed convolutions to a
    3x64x64 image in [-1, 1] (Tanh output).

    Args:
        hidden_dim: feature size of each encoded text token (LSTM input size).
        latent_dim: size of the noise vector.
        model_name, pretrained, trainable: accepted for signature
            compatibility; not used by this module.
    """

    def __init__(self, hidden_dim =768, latent_dim = 768, model_name=text_encoder_model, pretrained=pretrained, trainable=trainable):
        super().__init__()

        # hidden_dim / latent_dim now drive the layer sizes instead of a
        # hard-coded 768; the defaults reproduce the previous architecture.
        self.lstm = nn.LSTM(hidden_dim, 64, batch_first = True, bidirectional = True)
        self.linear = nn.Sequential(nn.Linear(64*2, 256),
                                    nn.ReLU(),
                                    nn.Linear(256,512))

        self.latent = nn.Sequential(
            nn.Linear(latent_dim, 512),
            nn.LeakyReLU(0.1)
            )

        # (in_channels, out_channels, kernel, stride) for each upsampling stage.
        # Spatial size: 4 -> 5 -> 6 -> 12 -> 13 -> 14 -> 15 -> 16 -> 16 -> 32, then 64.
        stages = [
            (16, 32, 2, 1),
            (32, 128, 2, 2),
            (128, 256, 2, 1),
            (256, 64, 2, 1),
            (64, 512, 2, 1),
            (512, 128, 2, 1),
            (128, 32, 1, 1),
            (32, 64, 2, 2),
        ]
        # Layer order (and therefore the state_dict keys) is unchanged: the
        # first stage has no dropout, every later stage does.
        layers = self._up_block(64, 16, 2, 1, dropout=False)
        for in_channels, out_channels, kernel_size, stride in stages:
            layers += self._up_block(in_channels, out_channels, kernel_size, stride)
        layers += [nn.ConvTranspose2d(64, 3, 2, 2), nn.Tanh()]
        self.model = nn.Sequential(*layers)

    @staticmethod
    def _up_block(in_channels, out_channels, kernel_size, stride=1, dropout=True):
        """Return [ConvTranspose2d, (Dropout2d), BatchNorm2d, ReLU] as a list."""
        block = [nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride)]
        if dropout:
            block.append(nn.Dropout2d(0.5))
        block += [nn.BatchNorm2d(out_channels), nn.ReLU()]
        return block

    def forward(self,noise, encoded_text):
        """Generate images from noise and encoded captions.

        Args:
            noise: tensor of shape ``(batch, latent_dim)``.
            encoded_text: tensor of shape ``(batch, seq_len, hidden_dim)``.

        Returns:
            Tensor of shape ``(batch, 3, 64, 64)``.
        """
        _, (hidden, _) = self.lstm(encoded_text)
        # Final states of the forward and backward directions, side by side.
        sentence = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim = 1)
        sentence = self.linear(sentence)
        latent = self.latent(noise)
        x = torch.cat([latent, sentence], dim=1)  # (batch, 1024)
        x = x.view(x.shape[0], 64, 4, 4)          # (batch, 64, 4, 4)
        return self.model(x)

class Discriminator(nn.Module):
    """Scores how well an image matches an encoded caption.

    The caption is summarised by a bidirectional LSTM, projected to a
    3x64x64 map, concatenated with the image along the channel axis and
    reduced by strided convolutions to a single logit per sample.

    Args:
        model_name, pretrained, trainable: accepted for signature
            compatibility; not used by this module.
    """

    def __init__(self, model_name=text_encoder_model, pretrained=pretrained, trainable=trainable):
        super().__init__()

        self.lstm = nn.LSTM(768,64, batch_first = True, bidirectional = True)
        self.linear = nn.Sequential(nn.Linear(64*2, 256),
                                    nn.ReLU(),
                                    nn.Linear(256,64*64*3))

        # Spatial size: 64 -> 32 -> 16 -> 8 -> 7 -> 5 -> 3, i.e. 32*3*3 = 288 features.
        self.model = nn.Sequential(
            # input is 6 x 64 x 64
            nn.Conv2d(6, 64, 2, 2, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 2, 2, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 2,2 ,bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256 , 64, 2,bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(64 , 16, 3, bias=False),
            nn.BatchNorm2d(16),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(16 , 32, 3, bias=False),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Flatten(),
            nn.Linear(288,1))

    def forward(self, x,encoded_text):
        """Return one logit per (image, caption) pair.

        Args:
            x: image batch of shape ``(batch, 3, 64, 64)``.
            encoded_text: tensor of shape ``(batch, seq_len, 768)``.

        Returns:
            Tensor of shape ``(batch, 1)`` with unnormalised scores.
        """
        _, (hidden, _) = self.lstm(encoded_text)
        # Final states of the forward and backward directions, side by side.
        hidden = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim = 1)
        hidden = self.linear(hidden)
        hidden = hidden.view(x.shape[0], -1, 64, 64)
        # Cast to float32 on the device the module already lives on.
        # `.type(torch.FloatTensor)` copied the batch to the CPU first, forcing
        # a device round trip on every call.
        x = x.to(device=hidden.device, dtype=torch.float32)
        x = torch.cat([x, hidden], dim=1)
        return self.model(x)



In [None]:
# Scratch cell: push a single batch through the text encoder and stop.
# NOTE(review): in the visible notebook `models` is first assigned in a later
# cell, so this cell only works after that cell has been run.
# NOTE(review): the progress bar created here is never iterated; the loop
# below reads train_loader directly.
loader = tqdm(train_loader, desc = f"Training")
for idx, ((right_images, right_texts), (wrong_images, wrong_texts)) in enumerate(train_loader, start=1):
        # Execute
        enc_right_texts = models["text_encoder"](**right_texts).detach()
        enc_wrong_texts = models["text_encoder"](**wrong_texts).detach()
        break

In [None]:
# Instantiate the three networks on the selected device.
models = {
    "text_encoder": text().to(device),
    "generator": Generator().to(device),
    "discriminator": Discriminator().to(device)
}
optimizers = {}
schedulers = {}
# Loss applied to the discriminator's raw logits.
criterion = nn.BCEWithLogitsLoss()
# The "text_encoder" entry is never read: the loop below skips that model.
learning_rates = {"text_encoder": 1e-4, "generator": 5e-4, "discriminator": 1e-4}
for key in models:
    # The text encoder gets no optimizer or scheduler, and no checkpoint
    # weights are loaded into it.
    if key == "text_encoder": continue
    # Restore weights from the checkpoint when present; otherwise this re-loads
    # the model's own freshly initialised state.
    models[key].load_state_dict(checkpoint.get("models", {}).get(key, models[key].state_dict()))
    # Prepare optimizer and step schedule (lr multiplied by 0.1 every 50 scheduler steps)
    optimizers[key] = torch.optim.Adam(models[key].parameters(), lr=learning_rates[key], weight_decay=1e-6)
    optimizers[key].load_state_dict(checkpoint.get("optimizers", {}).get(key, optimizers[key].state_dict()))
    schedulers[key] = torch.optim.lr_scheduler.StepLR(optimizers[key], step_size=50, gamma=0.1, last_epoch=-1)
    schedulers[key].load_state_dict(checkpoint.get("schedulers", {}).get(key, schedulers[key].state_dict()))

In [None]:
def count_parameters(model):
    """Return how many scalar weights of ``model`` have ``requires_grad`` set."""
    trainable_total = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_total += param.numel()
    return trainable_total

# Report the number of trainable weights per network, with thousands separators.
print(f'The generator has {count_parameters(models["generator"]):,} trainable parameters')
print(f'The discriminator has {count_parameters(models["discriminator"]):,} trainable parameters')
print(f'The Text Encoder has {count_parameters(models["text_encoder"]):,} trainable parameters')

In [None]:
def valid():
    """Evaluate the GAN on ``valid_loader`` without updating any weights.

    All models are switched to eval mode and every batch is processed under
    ``torch.no_grad()``.

    Returns:
        Tuple ``(dis_loss, dis_score, gen_loss, gen_score)`` averaged over
        the batches. The two scores are placeholders and are always 0.

    Raises:
        RuntimeError: if ``valid_loader`` yields no batches.
    """
    total_dis_loss, total_gen_loss = 0.0, 0.0
    total_dis_score, total_gen_score = 0.0, 0.0
    template = "Dis Loss: {:.4f}, Dis Score: {:.4f}, Gen Loss: {:.4f}, Gen Score: {:.4f}"
    for key in models:
        models[key].eval()
    loader = tqdm(valid_loader, desc = "Validating")
    num_batches = 0
    for (right_images, right_texts), (wrong_images, wrong_texts) in loader:
        num_batches += 1
        with torch.no_grad():
            enc_right_texts = models["text_encoder"](**right_texts)
            enc_wrong_texts = models["text_encoder"](**wrong_texts)

            batch_len = len(right_images)
            noise = torch.randn((batch_len, 768), device=device)
            ones = torch.ones((batch_len, 1), dtype=torch.float, device=device)
            zeros = torch.zeros((batch_len, 1), dtype=torch.float, device=device)
            fake_images = models["generator"](noise, enc_right_texts)

            dis_right_logits = models["discriminator"](right_images, enc_right_texts)
            dis_wrong_text_logits = models["discriminator"](right_images, enc_wrong_texts)
            dis_wrong_image_logits = models["discriminator"](wrong_images, enc_right_texts)
            dis_fake_logits = models["discriminator"](fake_images, enc_right_texts)

            dis_real_loss = criterion(dis_right_logits, ones)
            dis_wrong_text_loss = criterion(dis_wrong_text_logits, zeros)
            dis_wrong_image_loss = criterion(dis_wrong_image_logits, zeros)
            dis_fake_loss = criterion(dis_fake_logits, zeros)
            dis_loss = dis_real_loss + 0.25*dis_wrong_text_loss + 0.25*dis_wrong_image_loss + 0.5*dis_fake_loss

            # In eval mode under no_grad a second discriminator pass over the
            # same fake images gives the same logits, so reuse them.
            gen_loss = criterion(dis_fake_logits, ones)

        dis_loss_value, gen_loss_value = dis_loss.item(), gen_loss.item()
        dis_score, gen_score = 0.0, 0.0  # placeholder metrics

        total_dis_loss += dis_loss_value; total_gen_loss += gen_loss_value
        total_dis_score += dis_score; total_gen_score += gen_score
        # print statistics (iterating the tqdm object already advances the bar,
        # so no manual update() here)
        loader.set_postfix_str(template.format(dis_loss_value, dis_score, gen_loss_value, gen_score))
        # Drop the batch tensors before asking CUDA to release cached memory.
        del right_images, right_texts, wrong_images, wrong_texts
        del enc_right_texts, enc_wrong_texts, noise, fake_images, ones, zeros
        del dis_right_logits, dis_wrong_text_logits, dis_wrong_image_logits, dis_fake_logits
        del dis_real_loss, dis_wrong_text_loss, dis_wrong_image_loss, dis_fake_loss
        del dis_loss, gen_loss
        torch.cuda.empty_cache()

    if num_batches == 0:
        raise RuntimeError(
            "valid_loader produced no batches; with drop_last=True the "
            "validation subset must hold at least batch_size samples"
        )
    averages = (
        total_dis_loss / num_batches,
        total_dis_score / num_batches,
        total_gen_loss / num_batches,
        total_gen_score / num_batches,
    )
    loader.write(f"Validated | {template.format(*averages)}")
    return averages


In [None]:
def train():
    """Run one training epoch over ``train_loader``.

    For every batch the discriminator is updated on real, wrong-text,
    wrong-image and generated pairs, then the generator is updated to make
    the discriminator score its images as real. All LR schedulers are
    stepped once at the end of the epoch.

    Returns:
        Tuple ``(dis_loss, dis_score, gen_loss, gen_score)`` averaged over
        the batches. The two scores are placeholders and are always 0.

    Raises:
        RuntimeError: if ``train_loader`` yields no batches.
    """
    total_dis_loss, total_gen_loss = 0.0, 0.0
    total_dis_score, total_gen_score = 0.0, 0.0
    template = "Dis Loss: {:.4f}, Dis Score: {:.4f}, Gen Loss: {:.4f}, Gen Score: {:.4f}"
    # NOTE(review): this also puts the text encoder in train mode, which looks
    # unintended for an encoder with no optimizer; confirm before changing.
    for key in models:
        models[key].train()
    loader = tqdm(train_loader, desc = "Training")
    num_batches = 0
    for (right_images, right_texts), (wrong_images, wrong_texts) in loader:
        num_batches += 1
        # The text encoder has no optimizer and its output was always detached,
        # so skip building an autograd graph for it.
        with torch.no_grad():
            enc_right_texts = models["text_encoder"](**right_texts)
            enc_wrong_texts = models["text_encoder"](**wrong_texts)

        batch_len = len(right_images)
        noise = torch.randn((batch_len, 768), device=device)
        ones = torch.ones((batch_len, 1), dtype=torch.float, device=device)
        zeros = torch.zeros((batch_len, 1), dtype=torch.float, device=device)
        fake_images = models["generator"](noise, enc_right_texts)

        # Discriminator step. The fake images are detached so this backward
        # pass does not reach the generator.
        dis_right_logits = models["discriminator"](right_images, enc_right_texts)
        dis_wrong_text_logits = models["discriminator"](right_images, enc_wrong_texts)
        dis_wrong_image_logits = models["discriminator"](wrong_images, enc_right_texts)
        dis_fake_logits = models["discriminator"](fake_images.detach(), enc_right_texts)

        dis_real_loss = criterion(dis_right_logits, ones)
        dis_wrong_text_loss = criterion(dis_wrong_text_logits, zeros)
        dis_wrong_image_loss = criterion(dis_wrong_image_logits, zeros)
        dis_fake_loss = criterion(dis_fake_logits, zeros)
        dis_loss = dis_real_loss + 0.25*dis_wrong_text_loss + 0.25*dis_wrong_image_loss + 0.5*dis_fake_loss

        optimizers["discriminator"].zero_grad()
        dis_loss.backward()
        optimizers["discriminator"].step()

        # Generator step against the freshly updated discriminator.
        dis_fake_logits = models["discriminator"](fake_images, enc_right_texts)
        gen_loss = criterion(dis_fake_logits, ones)
        optimizers["generator"].zero_grad()
        gen_loss.backward()
        optimizers["generator"].step()

        dis_loss_value, gen_loss_value = dis_loss.item(), gen_loss.item()
        dis_score, gen_score = 0.0, 0.0  # placeholder metrics

        total_dis_loss += dis_loss_value; total_gen_loss += gen_loss_value
        total_dis_score += dis_score; total_gen_score += gen_score
        # print statistics (iterating the tqdm object already advances the bar,
        # so no manual update() here)
        loader.set_postfix_str(template.format(dis_loss_value, dis_score, gen_loss_value, gen_score))
        # Drop the batch tensors before asking CUDA to release cached memory.
        del right_images, right_texts, wrong_images, wrong_texts
        del enc_right_texts, enc_wrong_texts, noise, fake_images, ones, zeros
        del dis_right_logits, dis_wrong_text_logits, dis_wrong_image_logits, dis_fake_logits
        del dis_real_loss, dis_wrong_text_loss, dis_wrong_image_loss, dis_fake_loss
        del dis_loss, gen_loss
        torch.cuda.empty_cache()

    if num_batches == 0:
        raise RuntimeError(
            "train_loader produced no batches; with drop_last=True the "
            "training subset must hold at least batch_size samples"
        )
    averages = (
        total_dis_loss / num_batches,
        total_dis_score / num_batches,
        total_gen_loss / num_batches,
        total_gen_score / num_batches,
    )
    loader.write(f"Trained | {template.format(*averages)}")
    for key in schedulers:
        schedulers[key].step()
    return averages

In [None]:
# Metric history restored from the checkpoint. A fresh run starts with empty
# (0, 4) arrays and an empty epoch list; each row later appended holds the four
# values returned by train() / valid().
train_data = checkpoint.get('train_data', np.empty((0, 4)))
valid_data = checkpoint.get('valid_data', np.empty((0, 4)))
epoch_data = checkpoint.get('epoch_data', [])
# del checkpoint
def plot_training_data(ax, epoch_data, train_data, valid_data):
    """Redraw the four metric panels from the recorded history.

    Panel ``k`` shows column ``k`` of ``train_data`` and ``valid_data``
    against ``epoch_data``; each legend entry carries the latest value.
    """
    metric_names = ("dis loss", "dis score", "gen loss", "gen score")
    panel_count = len(metric_names)

    for column in range(panel_count):
        ax[column].clear()

    for column, metric in enumerate(metric_names):
        panel = ax[column]
        latest_train = train_data[-1, column]
        panel.plot(epoch_data, train_data[:, column], label = f"Train {metric} {latest_train:.4f}")
        latest_valid = valid_data[-1, column]
        panel.plot(epoch_data, valid_data[:, column], label = f"Valid {metric} {latest_valid:.4f}")

    for column in range(panel_count):
        ax[column].legend()
# When the checkpoint already holds history, redraw the curves recorded so far.
if epoch_data:
    fig, ax = plt.subplots(1, 4, figsize=(10*4, 5))
    plot_training_data(ax, epoch_data, train_data, valid_data)
    plt.show()

In [None]:
# Final epoch count read by the training loop below; replaces the earlier assignments.
num_epoch = 50

In [None]:
# Epoch numbering continues from the history already recorded, and `training`
# (0 or 1) switches the whole loop off when it is 0.
loader = tqdm(range(len(epoch_data), len(epoch_data) + num_epoch * training), desc = "Epoch")
# Output widget that captures the per-batch progress bars of train() / valid().
board = ipywidgets.Output()
if training: display.display(board)
graph = display.display(ipywidgets.widgets.HTML(f"<b>Training started: {bool(training)}</b>"), display_id = True)
for i in loader:
    with board:
        epoch_data.append(i+1)
        # Make grid
        fig, ax = plt.subplots(1, 4, figsize=(10*4, 5))
        # Close figure so it is only shown through the display handle below
        plt.close(fig)
        train_data = np.append(train_data, [train()], axis = 0)
        valid_data = np.append(valid_data, [valid()], axis = 0)
        # Visualize
        plot_training_data(ax, epoch_data, train_data, valid_data)
        graph.update(fig)
        # Clear all progress bar with in board widget
        display.clear_output()
        graph = display.display(fig, display_id = True)
    # Save model: everything needed to resume (weights, optimizer and scheduler
    # state, the data split and the metric history) after every epoch.
    params = {
        'models': {key: model.state_dict() for key, model in models.items()},
        'optimizers': {key: optimizer.state_dict() for key, optimizer in optimizers.items()},
        'schedulers': {key: scheduler.state_dict() for key, scheduler in schedulers.items()},
        'train_indices': train_indices,
        'valid_indices': valid_indices,
        'train_data': train_data,
        'valid_data': valid_data,
        'epoch_data': epoch_data
    }
    torch.save(params, "The Flower checkpoint4.pt")
loader.write("Done!")

In [None]:
# Generate an image for the caption of a random dataset item and show it next
# to the real image.
idx = np.random.randint(len(train_dataset))
# Local names deliberately avoid `text` (the encoder class) and `params`
# (the checkpoint dict built by the training loop) so neither is overwritten.
(image, sample_text), _ = train_dataset[idx]
# Add a leading batch dimension of 1 to every token list.
encoder_inputs = {
    key: torch.tensor(values, device=device)[None]
    for key, values in sample_text.items()
}
for key in models:
    models[key].eval()
# Pure inference: no autograd graph is needed.
with torch.no_grad():
    enc_text = models["text_encoder"](**encoder_inputs)
    noise = torch.randn((1, 768), device=device)
    gen_image = models["generator"](noise, enc_text).squeeze().cpu()
_, axs = plt.subplots(1, 2, figsize=(5, 5))
visualize(image, sample_text, train_dataset.tokenizer, axs[0], "Right")
visualize(gen_image, sample_text, None, axs[1], "Generated")
plt.show()

In [None]:
# Generate an image for a caption typed by the user.
txt = input()
caption = txt.strip()
# Reuse the dataset's tokenizer (same "distilbert-base-uncased" vocabulary)
# instead of loading a second copy.
tokenizer = train_dataset.tokenizer
input_ids = tokenizer.encode(caption)
token_type_ids = [0] * (len(input_ids) - 1) + [1]
attention_mask = [1] * len(token_type_ids)
# Pad exactly like the training samples so the networks see the same input
# layout they were trained on. Local names avoid `text` (the encoder class)
# and `params` (the checkpoint dict built by the training loop).
caption_tokens = train_dataset.padTokens({
    "input_ids": input_ids,
    "token_type_ids": token_type_ids,
    "attention_mask": attention_mask
})
for key in models:
    models[key].eval()
# Add a leading batch dimension of 1 to every token list.
encoder_inputs = {
    key: torch.tensor(values, device=device)[None]
    for key, values in caption_tokens.items()
}
# Pure inference: no autograd graph is needed.
with torch.no_grad():
    enc_text = models["text_encoder"](**encoder_inputs)
    noise = torch.randn((1, 768), device=device)
    gen_image = models["generator"](noise, enc_text).squeeze().cpu()

_, ax = plt.subplots(1, 1, figsize=(5, 5))
visualize(gen_image, caption_tokens, train_dataset.tokenizer, ax, title = "Generated Image")
plt.show()

In [None]:
# Render a download link for the checkpoint file so the weights can be fetched from Kaggle.
# NOTE(review): in the visible notebook this file is only written by the final
# cell, so the link is valid once that cell (or an earlier run) has produced it.
from IPython.display import FileLink
FileLink('./The Flower checkpoint3.pt')

In [None]:
# Write the resumable training state to the path the notebook loads on start-up.
# The dict is rebuilt here rather than taken from `params`, because the
# inference cells reuse that name for tokenised text, and saving it would
# replace the checkpoint with encoder inputs.
torch.save(
    {
        'models': {key: model.state_dict() for key, model in models.items()},
        'optimizers': {key: optimizer.state_dict() for key, optimizer in optimizers.items()},
        'schedulers': {key: scheduler.state_dict() for key, scheduler in schedulers.items()},
        'train_indices': train_indices,
        'valid_indices': valid_indices,
        'train_data': train_data,
        'valid_data': valid_data,
        'epoch_data': epoch_data
    },
    "./The Flower checkpoint3.pt",
)