**COLAB SPECIFIC CODE:**

In [None]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)


!python -m pip install -U scikit-image

Mounted at /content/drive
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting scikit-image
  Downloading scikit_image-0.19.3-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (13.5 MB)
[K     |████████████████████████████████| 13.5 MB 29.0 MB/s 
Installing collected packages: scikit-image
  Attempting uninstall: scikit-image
    Found existing installation: scikit-image 0.18.3
    Uninstalling scikit-image-0.18.3:
      Successfully uninstalled scikit-image-0.18.3
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
albumentations 0.1.12 requires imgaug<0.2.7,>=0.2.5, but you have imgaug 0.2.9 which is incompatible.[0m
Successfully installed scikit-image-0.19.3


In [7]:
import sys
# sys.path.append('/content/drive/MyDrive/FYP/attngan')

**IMPORTS:**

In [8]:
from miscc.utils import mkdir_p
from miscc.utils import build_super_images
from miscc.losses import sent_loss, words_loss
from miscc.config import cfg, cfg_from_file

from dataset import prepare_data, TextDataset as TextFashionGenDataset

from model import RNN_ENCODER, CNN_ENCODER

import os
import time
import random
import pprint
import datetime
import dateutil.tz
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms



UPDATE_INTERVAL = 200

In [9]:
def build_models(dataset):
    # build model ############################################################
    text_encoder = RNN_ENCODER(dataset.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    labels = Variable(torch.LongTensor(range(cfg.TRAIN.BATCH_SIZE)))
    start_epoch = 0
    if cfg.TRAIN.NET_E != "":
        state_dict = torch.load(cfg.TRAIN.NET_E)
        text_encoder.load_state_dict(state_dict)
        print("Load ", cfg.TRAIN.NET_E)
        #
        name = cfg.TRAIN.NET_E.replace("text_encoder", "image_encoder")
        state_dict = torch.load(name)
        image_encoder.load_state_dict(state_dict)
        print("Load ", name)

        istart = cfg.TRAIN.NET_E.rfind("_") + 8
        iend = cfg.TRAIN.NET_E.rfind(".")
        start_epoch = cfg.TRAIN.NET_E[istart:iend]
        start_epoch = int(start_epoch) + 1
        print("start_epoch", start_epoch)
    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
        image_encoder = image_encoder.cuda()
        labels = labels.cuda()

    return text_encoder, image_encoder, labels, start_epoch

In [10]:
def train(
    dataloader,
    cnn_model,
    rnn_model,
    batch_size,
    labels,
    optimizer,
    epoch,
    ixtoword,
    image_dir,
):
    cnn_model.train()
    rnn_model.train()
    s_total_loss0 = 0
    s_total_loss1 = 0
    w_total_loss0 = 0
    w_total_loss1 = 0
    count = (epoch + 1) * len(dataloader)
    start_time = time.time()

    for step, data in enumerate(dataloader, 0):
        # print('step', step)
        rnn_model.zero_grad()
        cnn_model.zero_grad()

        imgs, captions, cap_lens, class_ids, keys = prepare_data(data)

        # nef -- cfg.TEXT.EMBEDDING_DIM = 256 (for FashionGen)

        # words_features: batch_size x nef x 17 x 17
        # sent_code: batch_size x nef
        words_features, sent_code = cnn_model(imgs[-1])
        # --> batch_size x nef x 17*17
        nef, att_sze = words_features.size(1), words_features.size(2)
        # words_features = words_features.view(batch_size, nef, -1)

        hidden = rnn_model.init_hidden(batch_size)
        # words_emb: batch_size x nef x seq_len
        # sent_emb: batch_size x nef
        words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)

        w_loss0, w_loss1, attn_maps = words_loss(
            words_features, words_emb, labels, cap_lens, class_ids, batch_size
        )
        w_total_loss0 += w_loss0.data
        w_total_loss1 += w_loss1.data
        loss = w_loss0 + w_loss1

        s_loss0, s_loss1 = sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
        loss += s_loss0 + s_loss1
        s_total_loss0 += s_loss0.data
        s_total_loss1 += s_loss1.data

        loss.backward()

        # `clip_grad_norm` helps prevent
        # the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(rnn_model.parameters(), cfg.TRAIN.RNN_GRAD_CLIP)
        optimizer.step()

        if step % UPDATE_INTERVAL == 0:
            count = epoch * len(dataloader) + step

            s_cur_loss0 = s_total_loss0.item() / UPDATE_INTERVAL
            s_cur_loss1 = s_total_loss1.item() / UPDATE_INTERVAL

            w_cur_loss0 = w_total_loss0.item() / UPDATE_INTERVAL
            w_cur_loss1 = w_total_loss1.item() / UPDATE_INTERVAL

            elapsed = time.time() - start_time
            print(
                "| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | "
                "s_loss {:5.2f} {:5.2f} | "
                "w_loss {:5.2f} {:5.2f}".format(
                    epoch,
                    step,
                    len(dataloader),
                    elapsed * 1000.0 / UPDATE_INTERVAL,
                    s_cur_loss0,
                    s_cur_loss1,
                    w_cur_loss0,
                    w_cur_loss1,
                )
            )
            s_total_loss0 = 0
            s_total_loss1 = 0
            w_total_loss0 = 0
            w_total_loss1 = 0
            start_time = time.time()

            # attention Maps
            img_set, _ = build_super_images(
                imgs[-1].cpu(), captions, ixtoword, attn_maps, att_sze
            )

            if img_set is not None:
                im = Image.fromarray(img_set)
                fullpath = "%s/attention_maps%d.png" % (image_dir, step)
                im.save(fullpath)

    return count

def evaluate(dataloader, cnn_model, rnn_model, batch_size, labels):
    cnn_model.eval()
    rnn_model.eval()
    s_total_loss = 0
    w_total_loss = 0

    for step, data in enumerate(dataloader, 0):
        real_imgs, captions, cap_lens, class_ids, keys = prepare_data(data)

        words_features, sent_code = cnn_model(real_imgs[-1])
        # nef = words_features.size(1)
        # words_features = words_features.view(batch_size, nef, -1)

        hidden = rnn_model.init_hidden(batch_size)
        words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)

        w_loss0, w_loss1, attn = words_loss(
            words_features, words_emb, labels, cap_lens, class_ids, batch_size
        )
        w_total_loss += (w_loss0 + w_loss1).data

        s_loss0, s_loss1 = sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
        s_total_loss += (s_loss0 + s_loss1).data

        if step == 50:
            break

    s_cur_loss = s_total_loss.item() / step
    w_cur_loss = w_total_loss.item() / step

    return s_cur_loss, w_cur_loss

In [11]:
def init_DAMSM(cfgdirectory):
    cfg_from_file(cfgdirectory)

    cfg.GPU_ID = 0  # -1 if on cpu, not tested so might break

    print("Using config:")
    pprint.pprint(cfg)

    # seeds
    manual_seed = None
    if not cfg.TRAIN.FLAG:
        manual_seed = 100
    elif manual_seed is None:
        manual_seed = random.randint(1, 10000)
    random.seed(manual_seed)
    np.random.seed(manual_seed)
    torch.manual_seed(manual_seed)
    if cfg.CUDA:
        torch.cuda.manual_seed_all(manual_seed)

    ##########################################################################
    now = datetime.datetime.now(dateutil.tz.tzlocal())
    timestamp = now.strftime("%Y_%m_%d_%H_%M_%S")
    output_dir = "./drive/MyDrive/FYP/attngan/data/%s_%s_%s" % (cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp,)

    model_dir = os.path.join(output_dir, "Model")
    image_dir = os.path.join(output_dir, "Image")
    mkdir_p(model_dir)
    mkdir_p(image_dir)

    torch.cuda.set_device(cfg.GPU_ID)
    cudnn.benchmark = True

    # Get data loader ##################################################
    imsize = cfg.TREE.BASE_SIZE * (2 ** (cfg.TREE.BRANCH_NUM - 1))
    batch_size = cfg.TRAIN.BATCH_SIZE

    # dataset images transforms

    image_transform = transforms.Compose(
        [transforms.Resize(imsize), transforms.RandomHorizontalFlip()]
    )

    # train data
    dataset = TextFashionGenDataset(
        cfg.DATA_DIR, "train", base_size=cfg.TREE.BASE_SIZE, transform=image_transform,
    )

    # print(dataset.n_words, dataset.embeddings_num)
    assert dataset
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        drop_last=True,
        shuffle=True,
        num_workers=int(cfg.WORKERS),
    )

    # # validation data #

    dataset_val = TextFashionGenDataset(
        cfg.DATA_DIR, "test", base_size=cfg.TREE.BASE_SIZE, transform=image_transform,
    )

    dataloader_val = torch.utils.data.DataLoader(
        dataset_val,
        batch_size=batch_size,
        drop_last=True,
        shuffle=True,
        num_workers=int(cfg.WORKERS),
    )

    # Train ##############################################################
    text_encoder, image_encoder, labels, start_epoch = build_models(dataset)

    para = list(text_encoder.parameters())
    for v in image_encoder.parameters():
        if v.requires_grad:
            para.append(v)
    # optimizer = optim.Adam(para, lr=cfg.TRAIN.ENCODER_LR, betas=(0.5, 0.999))
    # At any point you can hit Ctrl + C to break out of training early.
    try:
        lr = cfg.TRAIN.ENCODER_LR
        for epoch in range(start_epoch, cfg.TRAIN.MAX_EPOCH):
            optimizer = optim.Adam(para, lr=lr, betas=(0.5, 0.999))
            epoch_start_time = time.time()

            count = train(
                dataloader,
                image_encoder,
                text_encoder,
                batch_size,
                labels,
                optimizer,
                epoch,
                dataset.ixtoword,
                image_dir,
            )

            print("-" * 89)
            if len(dataloader_val) > 0:
                s_loss, w_loss = evaluate(
                    dataloader_val, image_encoder, text_encoder, batch_size, labels
                )
                print(
                    "| end epoch {:3d} | valid loss "
                    "{:5.2f} {:5.2f} | lr {:.5f}|".format(epoch, s_loss, w_loss, lr)
                )
            print("-" * 89)
            if lr > cfg.TRAIN.ENCODER_LR / 10.0:
                lr *= 0.98

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0 or epoch == cfg.TRAIN.MAX_EPOCH:
                torch.save(
                    image_encoder.state_dict(),
                    "%s/image_encoder%d.pth" % (model_dir, epoch),
                )
                torch.save(
                    text_encoder.state_dict(),
                    "%s/text_encoder%d.pth" % (model_dir, epoch),
                )
                print("Save G/Ds models.")
    except KeyboardInterrupt:
        print("-" * 89)
        print("Exiting from training early")

In [16]:
init_DAMSM("./cfg/DAMSM/fashiongen2.yml")

Using config:
{'B_VALIDATION': False,
 'CONFIG_NAME': 'DAMSM',
 'CUDA': True,
 'DATASET_NAME': 'fashiongen2',
 'DATA_DIR': './data/fashiongen',
 'GAN': {'B_ATTENTION': True,
         'B_DCGAN': False,
         'CONDITION_DIM': 100,
         'DF_DIM': 64,
         'GF_DIM': 128,
         'R_NUM': 2,
         'Z_DIM': 100},
 'GPU_ID': 0,
 'RNN_TYPE': 'LSTM',
 'TEXT': {'CAPTIONS_PER_IMAGE': 1, 'EMBEDDING_DIM': 256, 'WORDS_NUM': 10},
 'TRAIN': {'BATCH_SIZE': 32,
           'B_NET_D': True,
           'DISCRIMINATOR_LR': 0.0002,
           'ENCODER_LR': 0.001,
           'FLAG': True,
           'GENERATOR_LR': 0.0002,
           'MAX_EPOCH': 401,
           'NET_E': '/home/jupyter/temp/attngan/drive/MyDrive/FYP/attngan/data/fashiongen2_DAMSM_2022_06_30_03_57_51/Model/text_encoder132.pth',
           'NET_G': '',
           'RNN_GRAD_CLIP': 0.25,
           'SMOOTH': {'GAMMA1': 4.0,
                      'GAMMA2': 5.0,
                      'GAMMA3': 10.0,
                      'LAMBDA': 1.

In [None]:
init_DAMSM("/content/drive/MyDrive/FYP/attngan/cfg/DAMSM/fashiongen2.yml")

Using config:
{'B_VALIDATION': False,
 'CONFIG_NAME': 'DAMSM',
 'CUDA': True,
 'DATASET_NAME': 'fashiongen2',
 'DATA_DIR': './drive/MyDrive/FYP/attngan/data/fashiongen',
 'GAN': {'B_ATTENTION': True,
         'B_DCGAN': False,
         'CONDITION_DIM': 100,
         'DF_DIM': 64,
         'GF_DIM': 128,
         'R_NUM': 2,
         'Z_DIM': 100},
 'GPU_ID': 0,
 'RNN_TYPE': 'LSTM',
 'TEXT': {'CAPTIONS_PER_IMAGE': 1, 'EMBEDDING_DIM': 256, 'WORDS_NUM': 10},
 'TRAIN': {'BATCH_SIZE': 32,
           'B_NET_D': True,
           'DISCRIMINATOR_LR': 0.0002,
           'ENCODER_LR': 0.001,
           'FLAG': True,
           'GENERATOR_LR': 0.0002,
           'MAX_EPOCH': 401,
           'NET_E': '',
           'NET_G': '',
           'RNN_GRAD_CLIP': 0.25,
           'SMOOTH': {'GAMMA1': 4.0,
                      'GAMMA2': 5.0,
                      'GAMMA3': 10.0,
                      'LAMBDA': 1.0},
           'SNAPSHOT_INTERVAL': 2},
 'TREE': {'BASE_SIZE': 299, 'BRANCH_NUM': 1},
 'WORKERS

  "num_layers={}".format(dropout, num_layers))
Downloading: "https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth" to /root/.cache/torch/hub/checkpoints/inception_v3_google-1a9a5a14.pth


  0%|          | 0.00/104M [00:00<?, ?B/s]

Load pretrained model from  https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth


  similarities.data.masked_fill_(masks, -float('inf'))
  scores0.data.masked_fill_(masks, -float('inf'))


| epoch   0 |     0/  445 batches | ms/batch 31.65 | s_loss  0.02  0.02 | w_loss  0.02  0.02
| epoch   0 |   200/  445 batches | ms/batch 496.13 | s_loss  2.71  2.75 | w_loss  2.80  2.68
| epoch   0 |   400/  445 batches | ms/batch 349.77 | s_loss  2.32  2.35 | w_loss  2.10  2.13
-----------------------------------------------------------------------------------------
| end epoch   0 | valid loss  4.72  4.24 | lr 0.00100|
-----------------------------------------------------------------------------------------
Save G/Ds models.
| epoch   1 |     0/  445 batches | ms/batch  4.91 | s_loss  0.01  0.01 | w_loss  0.01  0.01
| epoch   1 |   200/  445 batches | ms/batch 345.23 | s_loss  2.11  2.16 | w_loss  1.85  1.88
| epoch   1 |   400/  445 batches | ms/batch 327.68 | s_loss  2.05  2.09 | w_loss  1.77  1.82
-----------------------------------------------------------------------------------------
| end epoch   1 | valid loss  4.64  4.06 | lr 0.00098|
----------------------------------------