#### Testing if the model is working:

---

In [15]:
import torch
from transformers import AutoTokenizer
import torch.nn as nn
import torch.optim as optim 
import torch.nn.functional as F
from PIL import Image 

In [3]:
# Tokenizer that turns captions into token ids; it must be the same one used in
# training. Generator.__init__ reads `tokenizer.vocab_size` from this
# module-level name to size its embedding table.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# NOTE(review): `vocab_size` is not used anywhere in this section (Generator uses
# tokenizer.vocab_size instead); the two values may differ if tokens were added
# to the tokenizer -- confirm which one training used.
vocab_size = len(tokenizer)

In [2]:
## Config params -- same as in the earlier notebook
# Z_DIM, EMBED_DIM and IMG_CHANNELS determine layer shapes in Generator, so they
# must match the training values or the checkpoint's state dict will not load.
Z_DIM = 100  # length of the noise vector fed to the generator
EMBED_DIM = 256  # size of each caption token embedding
IMG_CHANNELS = 3  # channels of the generated image
# NOTE(review): Generator only stores this as `self.img_size`; its transposed-conv
# stack actually produces 64x64 images, not 256x256 -- confirm against training.
IMG_SIZE = 256

In [10]:
## Rebuilding the Generator -- Ctrl-V from the previous notebook
class Generator(nn.Module):
    """Text-conditioned DCGAN-style generator.

    A caption (token ids) is embedded, mean-pooled over the token axis,
    projected to ``z_dim`` and concatenated with the noise vector ``z``. The
    resulting ``(batch, 2 * z_dim, 1, 1)`` tensor is upsampled by transposed
    convolutions to ``(batch, img_channels, 64, 64)`` with values in
    ``[-1, 1]`` (Tanh).

    The layer layout must match the training notebook exactly, otherwise the
    saved state dict will not load.

    Args:
        z_dim: Size of the noise vector.
        embed_dim: Size of each token embedding.
        img_channels: Number of channels in the generated image.
        vocab_size: Rows in the embedding table. Defaults to the module-level
            ``tokenizer.vocab_size`` (the original behaviour).
        img_size: Value stored on ``self.img_size``; defaults to ``IMG_SIZE``.
            NOTE(review): this is informational only and does not affect the
            network -- the layer stack always yields 64x64 images.
    """

    def __init__(self, z_dim, embed_dim, img_channels, *, vocab_size=None, img_size=None):
        super().__init__()
        # Fall back to the notebook globals so existing call sites are unchanged.
        if vocab_size is None:
            vocab_size = tokenizer.vocab_size
        self.img_size = IMG_SIZE if img_size is None else img_size

        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.fc_embed = nn.Linear(embed_dim, z_dim)

        # Spatial size per stage: 1 -> 4 -> 8 -> 16 -> 32 -> 64.
        self.main = nn.Sequential(
            nn.ConvTranspose2d(z_dim * 2, 1024, 4, 1, 0, bias=False),
            nn.BatchNorm2d(1024),
            nn.ReLU(True),
            nn.ConvTranspose2d(1024, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, img_channels, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, z, captions):
        """Generate images from noise and caption token ids.

        Args:
            z: Noise tensor of shape ``(batch, z_dim)``.
            captions: Token-id tensor of shape ``(batch, seq_len)``.

        Returns:
            Tensor of shape ``(batch, img_channels, 64, 64)`` in ``[-1, 1]``.
        """
        # (batch, seq_len, embed_dim) -> mean over tokens -> (batch, embed_dim).
        # NOTE(review): padding positions are averaged in too (no mask);
        # presumably this matches training, so it is kept as-is.
        embed = self.embedding(captions).mean(dim=1)
        embed = self.fc_embed(embed)  # (batch, z_dim)
        combined = torch.cat([z, embed], dim=1)  # (batch, 2 * z_dim)
        # Treat the conditioning vector as a 1x1 feature map for the deconv stack.
        combined = combined.unsqueeze(-1).unsqueeze(-1)
        return self.main(combined)

In [11]:
def load_generator(checkpoint_path, device="cuda"):
    """Load a trained Generator and the tokenizer used to encode captions.

    Args:
        checkpoint_path: Path to a checkpoint whose ``'generator'`` entry holds
            the generator's state dict.
        device: Device the model is created on and the checkpoint is mapped to.

    Returns:
        ``(generator, tokenizer)``: the generator in eval mode and the
        ``bert-base-uncased`` tokenizer.
    """
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # Same tokenizer as training
    generator = Generator(Z_DIM, EMBED_DIM, IMG_CHANNELS).to(device)

    # NOTE(review): torch.load unpickles arbitrary objects -- only load trusted
    # checkpoints. Newer PyTorch warns here and suggests weights_only=True;
    # confirm the checkpoint contains only tensors/plain containers first.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    state_dict = checkpoint["generator"]

    # Weights saved from an nn.DataParallel-wrapped model carry a "module."
    # prefix on every key. Strip it so they load into the plain Generator on
    # any device, instead of wrapping the model in DataParallel just to load.
    prefix = "module."
    if state_dict and all(key.startswith(prefix) for key in state_dict):
        state_dict = {key[len(prefix):]: value for key, value in state_dict.items()}

    generator.load_state_dict(state_dict)
    generator.eval()  # inference mode: BatchNorm uses running statistics
    return generator, tokenizer

In [12]:
def process_caption(caption, tokenizer, max_length=16, device="cuda"):
    """Encode a caption as a fixed-length tensor of token ids.

    The caption is padded/truncated to exactly ``max_length`` tokens so every
    caption has the same sequence length.

    Args:
        caption: Text to encode (a single string in this notebook).
        tokenizer: Hugging Face tokenizer, called with ``return_tensors="pt"``.
        max_length: Number of tokens after padding/truncation.
        device: Device the returned tensor is moved to.

    Returns:
        The ``input_ids`` tensor with its leading batch dimension squeezed out.
    """
    encoding_options = dict(
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    batch = tokenizer(caption, **encoding_options)
    token_ids = batch.input_ids.squeeze(0)
    return token_ids.to(device)

In [13]:
def generate_image(generator, caption, device="cuda"):
    """Generate one image for ``caption`` with a trained generator.

    Encodes the caption with the module-level ``tokenizer`` and samples a fresh
    noise vector, so each call yields a different image.

    Args:
        generator: Model called as ``generator(z, caption_ids)``, returning a
            ``(1, C, H, W)`` tensor presumably in ``[-1, 1]`` (Tanh output).
        caption: Text prompt.
        device: Device the generator lives on; all inputs are placed there.

    Returns:
        The generated sample as a ``PIL.Image.Image``.
    """
    # Pass `device` through: process_caption otherwise defaults to "cuda" and
    # crashes on CPU-only machines.
    caption_ids = process_caption(caption, tokenizer, device=device)

    # Noise vector, batch size 1
    z = torch.randn(1, Z_DIM).to(device)

    with torch.no_grad():
        fake_image = generator(z, caption_ids.unsqueeze(0))  # add batch dimension

    # (1, C, H, W) -> (H, W, C) numpy array, as PIL expects
    array = fake_image.squeeze(0).permute(1, 2, 0).cpu().numpy()
    # Map [-1, 1] -> [0, 255]; clip guards against values just outside the range
    array = ((array * 0.5 + 0.5) * 255).clip(0, 255).astype("uint8")
    return Image.fromarray(array)

In [22]:
# Driver cell: load the epoch-3 checkpoint and generate one image for a prompt.
device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint_path = "checkpoints/epoch_3.pt"
# Also rebinds the module-level `tokenizer`, which generate_image reads.
generator, tokenizer = load_generator(checkpoint_path, device)

# Generate an image (a new random noise vector is drawn on every run)
caption = "A bird on top of car"
image = generate_image(generator, caption, device)

# Display/save the image
image.save("generated_image.png")  # relative path: saved in the working directory
image.show()  # NOTE(review): presumably opens an external viewer; may not work headless

  checkpoint = torch.load(checkpoint_path, map_location=device)


So the pipeline runs end to end and produces an image. However, the model was trained for only a few epochs (this is the epoch-3 checkpoint), so the output does not yet look the way it is supposed to.

---