In [10]:
import tensorflow as tf
from tensorflow.keras import layers, Model, applications
import torch
import numpy as np
import os

print(f"TensorFlow Version: {tf.__version__}")
print(f"PyTorch Version: {torch.__version__}")

TensorFlow Version: 2.18.0
PyTorch Version: 2.5.1


## 1. Define TensorFlow Architectures
These must match exactly what is used in `DFGAN_clip.ipynb`.

In [11]:
class RNN_Encoder(Model):
    """
    Bi-Directional LSTM Text Encoder.
    Matches DAMSM.py architecture.
    """
    def __init__(self, ntoken, ninput=300, nhidden=256, nlayers=1, drop_prob=0.5):
        super(RNN_Encoder, self).__init__()
        self.nhidden = nhidden // 2  # Because bidirectional doubles it
        self.ninput = ninput
        self.nlayers = nlayers
        self.drop_prob = drop_prob
        
        # Embedding: vocab_size -> 300
        self.embedding = layers.Embedding(ntoken, ninput,
                                        embeddings_initializer=tf.initializers.RandomUniform(-0.1, 0.1))
        self.drop = layers.Dropout(drop_prob)
        
        # Bi-LSTM: outputs nhidden*2 = 256
        self.rnn = layers.Bidirectional(
            layers.LSTM(self.nhidden, return_sequences=True, return_state=True, dropout=drop_prob)
        )

    def call(self, captions, cap_lens=None, training=False):
        # captions: [B, Max_Seq_Len]
        emb = self.embedding(captions)
        emb = self.drop(emb, training=training)
        
        # Create mask if cap_lens provided
        if cap_lens is not None:
            mask = tf.sequence_mask(cap_lens, maxlen=tf.shape(captions)[1])
        else:
            mask = None
        
        # RNN Forward
        # output: [B, Seq, Hidden*2]
        # states: forward_h, forward_c, backward_h, backward_c
        output, f_h, f_c, b_h, b_c = self.rnn(emb, mask=mask, training=training)
        
        # Words Embedding: [B, Hidden*2, Seq]
        # Transpose to match official PyTorch output [B, Hidden*2, Seq]
        words_emb = tf.transpose(output, [0, 2, 1])
        
        # Sentence Embedding: [B, Hidden*2]
        # Concatenate final hidden states of forward and backward
        sent_emb = tf.concat([f_h, b_h], axis=1)
        
        return words_emb, sent_emb

class CNN_Encoder(Model):
    """
    InceptionV3 Image Encoder.
    Matches DAMSM.py architecture.
    """
    def __init__(self, nef=256):
        super(CNN_Encoder, self).__init__()
        self.nef = 256  # Hardcoded like official!
        
        # Load InceptionV3 (Pretrained on ImageNet)
        base_model = applications.InceptionV3(include_top=False, weights='imagenet', input_shape=(299, 299, 3))
        base_model.trainable = False # Freeze base model
        
        # Define outputs
        # 'mixed7' is the last 17x17 block (Matches PyTorch Mixed_6e)
        # 'mixed10' is the last 8x8 block (Matches PyTorch Mixed_7c)
        layer_names = ['mixed7', 'mixed10']
        outputs = [base_model.get_layer(name).output for name in layer_names]
        
        self.inception = Model(inputs=base_model.input, outputs=outputs)
        
        # Projections
        # 1x1 Conv for local features (768 -> nef) - NO BIAS
        self.emb_features = layers.Conv2D(self.nef, 1, strides=1, padding='valid', use_bias=False,
                                          kernel_initializer=tf.initializers.RandomUniform(-0.1, 0.1))
        
        # Linear for global features (2048 -> nef) - WITH BIAS
        self.emb_cnn_code = layers.Dense(self.nef, use_bias=True,
                                         kernel_initializer=tf.initializers.RandomUniform(-0.1, 0.1))

    def call(self, inputs, training=False):
        # inputs: [B, H, W, 3] - will be resized to 299x299
        x = tf.image.resize(inputs, [299, 299])
        
        # Normalize to [-1, 1] for InceptionV3
        x = (x - 0.5) * 2.0
        
        # Get Inception Features
        feat_local, feat_global = self.inception(x, training=False)
        
        # --- Local Features ---
        # [B, 17, 17, 768] -> [B, 17, 17, nef]
        local_emb = self.emb_features(feat_local)
        
        # Transpose to [B, nef, 17, 17] (NCHW) to match official PyTorch shape
        local_emb = tf.transpose(local_emb, [0, 3, 1, 2])
        
        # --- Global Features ---
        # [B, 8, 8, 2048] -> [B, 2048]
        global_pool = tf.reduce_mean(feat_global, axis=[1, 2])
        
        # Project: [B, nef]
        global_emb = self.emb_cnn_code(global_pool)
        
        return local_emb, global_emb

## 2. Conversion Logic

In [12]:
def convert_rnn_weights(pt_path, tf_model):
    print(f"Loading PyTorch weights from {pt_path}...")
    # Load state dict
    state_dict = torch.load(pt_path, map_location='cpu')
    
    # 1. Embedding
    # PT: encoder.weight [Vocab, Dim]
    # TF: embedding.embeddings [Vocab, Dim]
    pt_emb = state_dict['encoder.weight'].numpy()
    tf_model.embedding.set_weights([pt_emb])
    print(" - Embeddings converted.")
    
    # 2. LSTM
    # PT: rnn.weight_ih_l0, rnn.weight_hh_l0, rnn.bias_ih_l0, rnn.bias_hh_l0 (Forward)
    # PT: rnn.weight_ih_l0_reverse, ... (Backward)
    
    def set_lstm_layer(tf_lstm_layer, suffix=''):
        # Weights
        w_ih = state_dict[f'rnn.weight_ih_l0{suffix}'].numpy() # [4*H, Input]
        w_hh = state_dict[f'rnn.weight_hh_l0{suffix}'].numpy() # [4*H, H]
        
        # Biases
        b_ih = state_dict[f'rnn.bias_ih_l0{suffix}'].numpy()   # [4*H]
        b_hh = state_dict[f'rnn.bias_hh_l0{suffix}'].numpy()   # [4*H]
        bias = b_ih + b_hh
        
        # Transpose Weights for TF [Input, 4*H]
        w_ih = w_ih.T
        w_hh = w_hh.T
        
        # Set Weights
        # TF expects [kernel, recurrent_kernel, bias]
        tf_lstm_layer.set_weights([w_ih, w_hh, bias])
        
    # Forward Layer
    print(" - Converting Forward LSTM...")
    set_lstm_layer(tf_model.rnn.forward_layer, suffix='')
    
    # Backward Layer
    print(" - Converting Backward LSTM...")
    set_lstm_layer(tf_model.rnn.backward_layer, suffix='_reverse')
    
    print("RNN Conversion Complete.")

def convert_cnn_weights(pt_path, tf_model):
    print(f"Loading PyTorch weights from {pt_path}...")
    state_dict = torch.load(pt_path, map_location='cpu')
    
    # We only convert the projection layers, assuming InceptionV3 is standard ImageNet
    
    # 1. emb_features (Conv2d 1x1)
    # PT: emb_features.weight [Out, In, kH, kW] -> [256, 768, 1, 1]
    # TF: emb_features.kernel [kH, kW, In, Out] -> [1, 1, 768, 256]
    w_conv = state_dict['emb_features.weight'].numpy()
    w_conv = np.transpose(w_conv, (2, 3, 1, 0))
    tf_model.emb_features.set_weights([w_conv])
    print(" - Local features projection converted.")
    
    # 2. emb_cnn_code (Linear)
    # PT: emb_cnn_code.weight [Out, In] -> [256, 2048]
    # PT: emb_cnn_code.bias [Out] -> [256]
    # TF: emb_cnn_code.kernel [In, Out] -> [2048, 256]
    w_dense = state_dict['emb_cnn_code.weight'].numpy()
    b_dense = state_dict['emb_cnn_code.bias'].numpy()
    
    w_dense = w_dense.T
    
    tf_model.emb_cnn_code.set_weights([w_dense, b_dense])
    print(" - Global features projection converted.")
    
    print("CNN Conversion Complete.")

## 3. Execute Conversion
Set your paths here.

In [None]:
# PATHS TO YOUR PYTORCH MODELS
PT_TEXT_PATH = './DAMSMencoders/text_encoder_best.pth'  # <--- CHANGE THIS
PT_IMAGE_PATH = './DAMSMencoders/image_encoder_best.pth' # <--- CHANGE THIS

# OUTPUT PATHS
# Keras 3 (TF 2.18+) requires .weights.h5 extension for save_weights
TF_TEXT_PATH = './damsm_checkpoints/text_encoder.weights.h5'
TF_IMAGE_PATH = './damsm_checkpoints/image_encoder.weights.h5'

if not os.path.exists('./damsm_checkpoints'):
    os.makedirs('./damsm_checkpoints')

# 1. Initialize TF Models (Need to build them first)
# We need the vocab size to initialize RNN
vocab_size = 5429 # Updated to match training (Max ID 5428 + 1)
# Or load it:
# vocab = np.load('./dictionary/vocab.npy')
# vocab_size = len(vocab)

text_encoder = RNN_Encoder(ntoken=vocab_size, nhidden=256)
image_encoder = CNN_Encoder(nef=256)

# Build models with dummy input
print("Building TF models...")
dummy_cap = tf.zeros((1, 10), dtype=tf.int32)
text_encoder(dummy_cap)

dummy_img = tf.zeros((1, 299, 299, 3), dtype=tf.float32)
image_encoder(dummy_img)
print("Models built.")

# 2. Convert
if os.path.exists(PT_TEXT_PATH):
    convert_rnn_weights(PT_TEXT_PATH, text_encoder)
    text_encoder.save_weights(TF_TEXT_PATH)
    print(f"Saved TF Text Encoder to {TF_TEXT_PATH}")
else:
    print(f"Text encoder path not found: {PT_TEXT_PATH}")

if os.path.exists(PT_IMAGE_PATH):
    convert_cnn_weights(PT_IMAGE_PATH, image_encoder)
    image_encoder.save_weights(TF_IMAGE_PATH)
    print(f"Saved TF Image Encoder to {TF_IMAGE_PATH}")
else:
    print(f"Image encoder path not found: {PT_IMAGE_PATH}")

Building TF models...
Models built.
Loading PyTorch weights from ./DAMSMencoders/text_encoder_best.pth...
 - Embeddings converted.
 - Converting Forward LSTM...
 - Converting Backward LSTM...
RNN Conversion Complete.
Models built.
Loading PyTorch weights from ./DAMSMencoders/text_encoder_best.pth...
 - Embeddings converted.
 - Converting Forward LSTM...
 - Converting Backward LSTM...
RNN Conversion Complete.


  state_dict = torch.load(pt_path, map_location='cpu')


ValueError: The filename must end in `.weights.h5`. Received: filepath=./damsm_checkpoints/text_encoder_weights