In [1]:
import os
import sys

module_path = os.path.join(os.getcwd(), '../src')
sys.path.append(module_path)

import torch 
from gan_t2i.models.CLIP import CLIPModel
from gan_t2i.utils.model_loading import download_CLIP_model , CLIP_DATASETS


from gan_t2i.models.GAN import WGAN

from gan_t2i.datasets.DatasetFactory import DatasetFactory
import torchvision.transforms as transforms
import clip
from PIL import Image

In [2]:
# Sanity check: report whether PyTorch can see a CUDA device in this kernel.
print(torch.cuda.is_available())

True


# Loading the CLIP model from a checkpoint

In [3]:
# Fetch the CLIP weights for the FLOWERS dataset; the returned value is handed
# to CLIPModel.load in the next cell.
checkpoints_path = download_CLIP_model(CLIP_DATASETS.FLOWERS)

[92mCLIP model FLOWERS already exits at /home/xxx/Desktop/Deep Learning/Deep-Learning-Final-Project/examples/models_weights/CLIP/CLIP~FT_FLOWERS/CLIP~FT_FLOWERS.pt[0m


In [4]:
# Restore the CLIP model from the checkpoint obtained in the previous cell.
clip_model = CLIPModel.load(checkpoints_path)

Model loaded on device: cuda


------------------------------------

In [5]:
from torch.utils.data import DataLoader , SubsetRandomSampler

# Loading dataset

In [6]:
""" Text Transformation

You need to tokenize your text before passing it to the model.
"""
def tokenize_text(text):
    """Tokenize a single caption with the CLIP tokenizer.

    CLIP limits the context to 77 tokens and ``clip.tokenize`` raises a
    ``RuntimeError`` when the input does not fit. In that case only the text
    up to the first ``"."`` is kept, and that shortened text is tokenized with
    ``truncate=True`` so an over-long first sentence cannot fail again.

    Args:
        text: Caption string to tokenize.

    Returns:
        The first (and only) row of the batch produced by ``clip.tokenize``.

    Raises:
        Any exception from ``clip.tokenize`` other than ``RuntimeError`` is
        propagated unchanged.
    """
    try:
        return clip.tokenize([text])[0]
    except RuntimeError:
        # Only the "input too long" failure is handled here; a bare `except`
        # would also hide unrelated errors and swallow KeyboardInterrupt.
        first_sentence = text.split(".")[0]
        return clip.tokenize([first_sentence], truncate=True)[0]
    

In [7]:
""" Image transformations """
# Preprocessing pipeline applied to every image of the dataset.
transform_img = transforms.Compose([
    # Resize, then crop the central 224x224 region.
    # InterpolationMode is the enum torchvision documents for `interpolation`;
    # passing the PIL integer constant (Image.BICUBIC) is the legacy form.
    transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(224),
    transforms.ToTensor(),

    # Per-channel mean and std, computed beforehand.
    transforms.Normalize([0.4355, 0.3777, 0.2879], [0.2571, 0.2028, 0.2101]),
])

In [8]:
# Build the Flowers dataset rooted at <cwd>/../data, wiring in the image
# pipeline and the caption tokenizer defined in the cells above.
dataset = DatasetFactory.Flowers(
    os.path.join(os.getcwd(), "..", "data"),
    transform_img=transform_img,
    transform_caption=tokenize_text,
)

Captions already downloaded
images already downloaded
Captions already extracted
images already extracted
[92mThe dataset is already stored in HDF5 format[0m


In [9]:
# Carve train / validation / test subsets out of the dataset.
# NOTE: only a small fraction of the dataset is used in this example; this may
#       (and will) cause overfitting -- it is only a demo.
dataset_len = len(dataset)
train_size = int(0.05 * dataset_len)
val_size = int(0.02 * dataset_len)
test_size = int(0.02 * dataset_len)

# Contiguous, non-overlapping index ranges laid out as [train | val | test].
val_start = train_size
test_start = val_start + val_size
train_indices = list(range(val_start))
val_indices = list(range(val_start, test_start))
test_indices = list(range(test_start, test_start + test_size))

# One random sampler per split, restricted to that split's indices.
train_sampler, val_sampler, test_sampler = (
    SubsetRandomSampler(split_indices)
    for split_indices in (train_indices, val_indices, test_indices)
)

# One dataloader per split, all sharing the same batch size and memory pinning.
train_loader, val_loader, test_loader = (
    DataLoader(dataset, batch_size=16, sampler=split_sampler, pin_memory=True)
    for split_sampler in (train_sampler, val_sampler, test_sampler)
)

---------------------------------------

# Creating the GAN model and training it

- Get the image and text embedding dimensions

In [10]:
# Get the dimensions of the model's last layer (image and text outputs).
# Only index 1 of each is printed below -- presumably the embedding width; TODO confirm.
dim_img_size , dim_text_size = clip_model.get_output_dimensions()
print("Ultimo livello del modello:", dim_img_size[1] , dim_text_size[1])

tensor([[49406, 34246,  4160, 49407,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0]], device='cuda:0',
       dtype=torch.int32)
Ultimo livello del modello: 512 512


In [11]:
# Use the text-side output dimension reported by the CLIP model as the embedding size.
embedding_size = dim_text_size[1]
# Third positional argument of WGAN -- presumably the size of a projected
# embedding; TODO confirm against the WGAN constructor.
p_emb_dim = 128
WGAN_model = WGAN(clip_model,embedding_size,p_emb_dim)


In [12]:
# Directory under the current working directory, passed to fit() as save_path.
# NOTE(review): this rebinds `checkpoints_path`, which cell 3 set to the value
# returned by download_CLIP_model. Re-running the CLIPModel.load cell after
# this one would hand it this directory instead -- consider a distinct name.
checkpoints_path = os.path.join(os.getcwd(), "checkpoints")

# Single-epoch demo run on the small train / validation loaders.
WGAN_model.fit(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    num_epochs=1,
    save_path=checkpoints_path,
)

Training on device:  cuda
One : tensor([1.], device='cuda:0') | Mone : tensor([-1.], device='cuda:0') 
epoch iter 0 


Epoch [1/1] Batch [256/256]: Loss Discriminator: -0.015878Loss Generator: -0.269301: 100%|██████████| 256/256 [01:22<00:00,  3.12it/s]
