# Guided Learning
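Use this notebook when the third generator stage has suffered mode collapse and its weights need to be re-initialized.  
The new weights are saved in `manual_attngan.pt`, which should be loaded when training restarts. The first part of the notebook also converts the pretrained AttnGAN encoder weights and saves them to `best_encoders_weights.pt`.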

In [1]:
# Reload edited project modules automatically, so changes under ../src are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2
from IPython.core.interactiveshell import InteractiveShell
# Display the value of every bare expression in a cell, not only the last one.
InteractiveShell.ast_node_interactivity = "all"

import sys
# Make the project packages (dataset.*, models.*) importable.
sys.path.append("../src")

In [2]:
import numpy as np
from PIL import Image
from skimage import io
from torchvision import transforms
import torch.nn.functional as F
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
%matplotlib inline

In [3]:
import torch

from dataset.birds_dataset import BirdsDataset
from models.discriminator import Discriminator
from models.generator import GeneratorNetwork
from models.conditional_augmentation import ConditioningAugmentation
from models.text_encoder import TextEncoder
from models.image_encoder import ImageEncoder

In [4]:
# Pretrained bird encoder checkpoints (image encoder, text encoder) that are converted below.
image_encoder_weights_path, text_encoder_weights_path = (
    "/home/user_2/AttnGAN/trained_weights/bird/image_encoder200.pth",
    "/home/user_2/AttnGAN/trained_weights/bird/text_encoder200.pth",
)

## Image Encoder

In [5]:
# Build the image encoder; per its debug log below, construction loads an Inception-v3 model.
image_encoder = ImageEncoder()

2019-10-16 15:33:51.169 | DEBUG    | models.image_encoder:__init__:23 - Started loading the Inception-v3 model
2019-10-16 15:33:52.933 | DEBUG    | models.image_encoder:__init__:25 - Finished loading the Inception-v3 model


#### Load Image Encoder Weights

In [6]:
# Load the checkpoint onto the CPU. The encoder is moved to the CPU below anyway, and this keeps
# the cell working on machines without a GPU even if the checkpoint was saved from CUDA tensors.
image_encoder_stat_dict = torch.load(image_encoder_weights_path, map_location='cpu')

In [7]:
# Inspect the shapes of the checkpoint tensors that are copied into the encoder in the next cells.
{name: tuple(image_encoder_stat_dict[name].shape)
 for name in ('emb_features.weight', 'emb_cnn_code.weight', 'emb_cnn_code.bias')}

#### Assign Image Encoder Weights

In [8]:
# The checkpoint calls this layer `emb_features`; in our ImageEncoder it is `fc_local`.
setattr(image_encoder.fc_local, 'weight',
        torch.nn.Parameter(image_encoder_stat_dict['emb_features.weight']))

In [9]:
# The checkpoint calls this layer `emb_cnn_code`; in our ImageEncoder it is `fc_global`.
setattr(image_encoder.fc_global, 'weight',
        torch.nn.Parameter(image_encoder_stat_dict['emb_cnn_code.weight']))

In [10]:
# Bias of the same `emb_cnn_code` -> `fc_global` layer.
setattr(image_encoder.fc_global, 'bias',
        torch.nn.Parameter(image_encoder_stat_dict['emb_cnn_code.bias']))

In [11]:
# Keep the encoder on the CPU for the smoke test below (nn.Module.to returns the module itself).
image_encoder = image_encoder.to('cpu')

In [12]:
# Dummy batch for the smoke test: one 3-channel 299x299 image (Inception-v3 input size), values in [0, 1).
x = torch.rand((1, 3, 299, 299))

In [13]:
# Smoke-test the forward pass on the dummy batch.
# eval() stops BatchNorm layers (e.g. inside the Inception-v3 model) from updating their running
# statistics with random noise; if those layers are registered submodules, their buffers are part
# of the state_dict saved at the end of this section. no_grad() skips building an unused graph.
image_encoder = image_encoder.eval()
with torch.no_grad():
    _ = image_encoder(x)

## Text Encoder

In [14]:
# Preprocessed CUB-200-2011 birds dataset; its vocab_size sizes the TextEncoder built below.
bird_dataset = BirdsDataset("/home/user_2/AttnGAN/datasets/cub200-2011/preprocessing")

In [15]:
# Vocabulary size of our preprocessing (output: 4055).
bird_dataset.vocab_size

4055

### Load Text Encoder Weights

In [16]:
# Load the checkpoint onto the CPU so the RNN weights copied below live on the same device as the
# freshly built (CPU) TextEncoder, even if the checkpoint was saved from CUDA tensors.
text_encoder_stat_dict = torch.load(text_encoder_weights_path, map_location='cpu')

In [17]:
# Shape of the checkpoint's 'encoder.weight' (output: 5450 x 300).
# NOTE(review): 5450 rows vs. bird_dataset.vocab_size == 4055 — presumably why only the RNN
# weights are copied into our TextEncoder below; confirm.
text_encoder_stat_dict['encoder.weight'].shape

torch.Size([5450, 300])

In [18]:
# List the checkpoint's parameter names (the RNN weights copied below use the `rnn.` prefix).
sorted(text_encoder_stat_dict.keys())

#### Assign text encoder weights

In [None]:
# Fresh TextEncoder sized to our vocabulary. Only its `RNN` parameters are overwritten from the
# checkpoint in the next cell.
# NOTE(review): nothing here loads the checkpoint's 'encoder.weight' (5450 rows vs. vocab_size 4055),
# so any embedding weights of TextEncoder keep their initialization — confirm this is intended.
text_encoder = TextEncoder(bird_dataset.vocab_size)

In [66]:
# Copy the single-layer (`l0`), bidirectional (`_reverse`) recurrent weights from the checkpoint's
# `rnn.*` entries into our TextEncoder's `RNN` module, replacing each parameter with a new Parameter.
for _ckpt_key in ('rnn.weight_ih_l0', 'rnn.weight_hh_l0',
                  'rnn.bias_ih_l0', 'rnn.bias_hh_l0',
                  'rnn.weight_ih_l0_reverse', 'rnn.weight_hh_l0_reverse',
                  'rnn.bias_ih_l0_reverse', 'rnn.bias_hh_l0_reverse'):
    # Attribute name on our module = checkpoint key without the leading "rnn.".
    setattr(text_encoder.RNN, _ckpt_key[len('rnn.'):],
            torch.nn.Parameter(text_encoder_stat_dict[_ckpt_key]))


### Save Results

In [68]:
# Bundle both converted encoders into a single checkpoint file.
encoders_checkpoint = {
    "text_encoder": text_encoder.state_dict(),
    "image_encoder": image_encoder.state_dict(),
}
torch.save(encoders_checkpoint,
           "/home/user_2/AttnGAN/Matan/AttnGAN/models/best_encoders_weights.pt")
    

## Restoring the pre-mode-collapse state of the last G & D stage

In [22]:
from models.generator import GeneratorNetwork
from models.conditional_augmentation import ConditioningAugmentation
from models.discriminator import Discriminator

In [48]:
# Freshly initialized generator; its state_dict supplies the re-initialized weights for the last stage.
# NOTE: no random seed is set in this notebook, so the re-initialization is not reproducible.
gn = GeneratorNetwork()

In [49]:
# Fresh discriminator for the third stage (stored at index 2 below); 256 is presumably that stage's image size.
D2 = Discriminator(256)

In [50]:
# Fresh Adam optimizer for D2; only its state_dict (no moments yet + hyper-parameters) is saved below.
# NOTE(review): lr=0.00002 is 10x smaller than the generator's lr=0.0002 set further down. An optimizer
# state_dict stores lr in its param_groups, so loading it on restart applies this lr — confirm intended.
opt2 = torch.optim.Adam(D2.parameters(), lr=0.00002, betas=(0.5, 0.999))

In [52]:
# Load the latest training checkpoint as a template; the next cell modifies it in place.
tmp_model = torch.load("../models/attngan/epoch_checkpoint_attngan.pt", map_location='cpu')

In [53]:
# Run this if you want a random model:
# replace the generator, the third discriminator, and that discriminator's optimizer state with
# the freshly initialized objects created above, then write the result to ./tmp.pt.
# Reloading from disk (next cell) gives tensors that no longer share storage with `gn`/`D2`.
tmp_model['models']['generator'] = gn.state_dict()
tmp_model['models']['discriminators'][2] = D2.state_dict()
tmp_model['optimizers']['D_optimizers'][2] = opt2.state_dict()
torch.save(tmp_model, "./tmp.pt")

In [54]:
# Source checkpoint for the weights being restored. When the previous cell was run, this holds the
# freshly initialized weights written to ./tmp.pt; the name presumably dates from restoring an
# actual epoch-50 snapshot.
epoch_50 = torch.load("./tmp.pt")

In [55]:
# Reload an unmodified copy of the latest checkpoint (`tmp_model` above was mutated in place).
last_epoch = torch.load("../models/attngan/epoch_checkpoint_attngan.pt", map_location='cpu')

In [56]:
# Generator state_dict keys that get replaced by the values in `epoch_50`: the stage-2 residual
# and upsample blocks, the attention layer at index 1 (presumably the one used by stage 2), and
# the stage-2 image output layer. BatchNorm buffers (running stats, counters) are included too.
# NOTE(review): maintained by hand — a parameter added to these modules would be silently missed;
# selecting keys by prefix would be more robust.
g_2_weights_names = [
"stages_layers.2.residual_layer.0.residual_layer_1.0.weight",
"stages_layers.2.residual_layer.0.residual_layer_1.1.weight",
"stages_layers.2.residual_layer.0.residual_layer_1.1.bias",
"stages_layers.2.residual_layer.0.residual_layer_1.1.running_mean",
"stages_layers.2.residual_layer.0.residual_layer_1.1.running_var",
"stages_layers.2.residual_layer.0.residual_layer_1.1.num_batches_tracked",
"stages_layers.2.residual_layer.0.residual_layer_2.0.weight",
"stages_layers.2.residual_layer.0.residual_layer_2.1.weight",
"stages_layers.2.residual_layer.0.residual_layer_2.1.bias",
"stages_layers.2.residual_layer.0.residual_layer_2.1.running_mean",
"stages_layers.2.residual_layer.0.residual_layer_2.1.running_var",
"stages_layers.2.residual_layer.0.residual_layer_2.1.num_batches_tracked",
"stages_layers.2.residual_layer.1.residual_layer_1.0.weight",
"stages_layers.2.residual_layer.1.residual_layer_1.1.weight",
"stages_layers.2.residual_layer.1.residual_layer_1.1.bias",
"stages_layers.2.residual_layer.1.residual_layer_1.1.running_mean",
"stages_layers.2.residual_layer.1.residual_layer_1.1.running_var",
"stages_layers.2.residual_layer.1.residual_layer_1.1.num_batches_tracked",
"stages_layers.2.residual_layer.1.residual_layer_2.0.weight",
"stages_layers.2.residual_layer.1.residual_layer_2.1.weight",
"stages_layers.2.residual_layer.1.residual_layer_2.1.bias",
"stages_layers.2.residual_layer.1.residual_layer_2.1.running_mean",
"stages_layers.2.residual_layer.1.residual_layer_2.1.running_var",
"stages_layers.2.residual_layer.1.residual_layer_2.1.num_batches_tracked",
"stages_layers.2.upsample.upsample_block.0.weight",
"stages_layers.2.upsample.upsample_block.0.bias",
"stages_layers.2.upsample.upsample_block.1.weight",
"stages_layers.2.upsample.upsample_block.2.weight",
"stages_layers.2.upsample.upsample_block.2.bias",
"stages_layers.2.upsample.upsample_block.2.running_mean",
"stages_layers.2.upsample.upsample_block.2.running_var",
"stages_layers.2.upsample.upsample_block.2.num_batches_tracked",
"attention_layers.1.fc.weight",
"attention_layers.1.fc.bias",
"image_generators.2.image.0.weight"]

In [57]:
# Copy the selected last-stage generator tensors, plus the third discriminator and its optimizer
# state, from `epoch_50` into the latest checkpoint.
# Plain tensors are sufficient: load_state_dict() copies values into the module's own Parameters
# and ignores the requires_grad flag of state-dict entries. Wrapping them in torch.nn.Parameter and
# toggling requires_grad therefore had no effect. clone() keeps the two checkpoints from sharing storage.
_src_generator = epoch_50['models']['generator']
_dst_generator = last_epoch['models']['generator']
for name in g_2_weights_names:
    _dst_generator[name] = _src_generator[name].clone()
last_epoch['models']['discriminators'][2] = epoch_50['models']['discriminators'][2]
last_epoch['optimizers']['D_optimizers'][2] = epoch_50['optimizers']['D_optimizers'][2]

In [58]:
# Sanity check: the patched generator state_dict loads into a fresh GeneratorNetwork with no
# missing or unexpected keys (see output). `gn` now holds the patched weights.
gn = GeneratorNetwork()
gn.load_state_dict(last_epoch['models']['generator'])

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [59]:
# Restore the conditioning-augmentation module from the latest checkpoint; its parameters are
# part of the generator optimizer rebuilt in the next cell.
ca = ConditioningAugmentation()
ca.load_state_dict(last_epoch['models']['conditional_augmentation'])

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [60]:
# Rebuild the generator optimizer over the conditioning-augmentation parameters followed by the
# generator parameters (presumably the same order as the training script; confirm).
# A plain list replaces the former ParameterList: ParameterList.extend() returns the list, and with
# ast_node_interactivity="all" that dumped every parameter's repr. Adam and its state_dict
# (parameters referenced by index) are identical either way.
# NOTE(review): a fresh optimizer discards the Adam moments of *all* generator parameters, not only
# the re-initialized last stage — confirm intended.
G_params = list(ca.parameters()) + list(gn.parameters())

g_opt = torch.optim.Adam(G_params, lr=0.0002, betas=(0.5, 0.999))
last_epoch['optimizers']['G_optimizer'] = g_opt.state_dict()

In [61]:
# Write the patched checkpoint; load manual_attngan.pt when training restarts.
torch.save(last_epoch, '../models/attngan/manual_attngan.pt')