In [2]:
import torch
from torch import nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
from diffusers import StableDiffusionPipeline, UNet2DConditionModel, DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
import clip
from tqdm import tqdm
from scripts import eng_processor
from scripts.text_preprocessing import TextImageDataset
from scripts.train_model import BertEncoder, Generator, Discriminator
from scripts.inference import generate_image_from_text
from scripts.utils import combine_dataset, Evaluator
import pandas as pd

from torchvision import datasets, models, transforms
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
from pycocotools.coco import COCO
import os
from collections import defaultdict

  from .autonotebook import tqdm as notebook_tqdm
  warn("The installed version of bitsandbytes was compiled without GPU support. "


'NoneType' object has no attribute 'cadam32bit_grad_fp32'


In [3]:
os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"
# torch.backends.mps.enable_flash_sdp(False)

In [4]:
if torch.backends.mps.is_available():
    device = torch.device("mps")  # Apple Metal
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(device)

mps


In [5]:
combined_df = pd.read_csv("datasets/combined_dataset.csv")
combined_df.tail(5)

Unnamed: 0,image_path,caption,source
90483,datasets/coco_dataset/train2017/000000370808.jpg,A black cat looking out the window at a black ...,MSCOCO
90484,datasets/coco_dataset/train2017/000000370808.jpg,Black cat sitting on window ledge looking outs...,MSCOCO
90485,datasets/coco_dataset/train2017/000000370808.jpg,A black cat looks out the window as a crow out...,MSCOCO
90486,datasets/coco_dataset/train2017/000000370808.jpg,A cat by a window with a small bird outside.,MSCOCO
90487,datasets/coco_dataset/train2017/000000370808.jpg,A cat watches a bird through a window.,MSCOCO


In [6]:
# combined_df= combined_df.where(combined_df["source"]=="Flickr")

# combined_df = combined_df.dropna()
combined_df = combined_df.sample(frac=1, random_state=42).head(10000)
combined_df = combined_df.reset_index(drop=True)  # Reset index to sequential integers

for i in range(len(combined_df)):
    combined_df.caption[i] = eng_processor.main(combined_df.caption[i])

You are setting values through chained assignment. Currently this works in certain cases, but when using Copy-on-Write (which will become the default behaviour in pandas 3.0) this will never work to update the original DataFrame or Series, because the intermediate object on which we are setting values will behave as a copy.
A typical example is when you are setting values in a column of a DataFrame, like:

df["col"][row_indexer] = value

Use `df.loc[row_indexer, "col"] = values` instead, to perform the assignment in a single step and ensure this keeps updating the original `df`.

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy

  combined_df.caption[i] = eng_processor.main(combined_df.caption[i])


In [7]:
combined_df

Unnamed: 0,image_path,caption,source
0,datasets/Flickr8k_Dataset/1957371533_62bc720ba...,a dog shaking water off,Flickr
1,datasets/Flickr8k_Dataset/3411022255_210eefc37...,four people and a child walking in the street,Flickr
2,datasets/Flickr8k_Dataset/2294516804_11e255807...,a baby sits by his red and green toy pulling a...,Flickr
3,datasets/coco_dataset/train2017/000000331372.jpg,an image from the outside of a window with flo...,MSCOCO
4,datasets/coco_dataset/train2017/000000527711.jpg,a close up of a person talking on a cell phone,MSCOCO
...,...,...,...
9995,datasets/coco_dataset/train2017/000000189461.jpg,an assortment of donuts in front of a coffee cup,MSCOCO
9996,datasets/coco_dataset/train2017/000000344314.jpg,the two zebras are standing behind the warning...,MSCOCO
9997,datasets/Flickr8k_Dataset/505929313_7668f021ab...,a black dog standing in shallow water with a p...,Flickr
9998,datasets/coco_dataset/train2017/000000119384.jpg,a close up of a car parked on top of a small f...,MSCOCO


In [8]:
image_to_captions = defaultdict(list)

for _, row in combined_df.iterrows():
    img_path = row['image_path']
    image_to_captions[img_path].append(row['caption'])

In [9]:
image_paths = list(image_to_captions.keys())
train_ids, test_ids = train_test_split(image_paths, test_size=0.2, random_state=42)


In [10]:
def collate_fn(batch):
    # Collect images and tokenize captions
    images = torch.stack([example['image'] for example in batch])  # Stack images
    input_ids = torch.stack([example['input_ids'] for example in batch])  # Stack input_ids
    attention_mask = torch.stack([example['attention_mask'] for example in batch])  # Stack attention_mask

    return {
        'pixel_values': images,
        'input_ids': input_ids,
        'attention_mask': attention_mask
    }

In [11]:
# image transform
transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])


train_dataset = TextImageDataset(
    image_to_captions=image_to_captions,
    image_paths=train_ids,  
    transform=transform
)

test_dataset = TextImageDataset(
    image_to_captions=image_to_captions,
    image_paths=test_ids,  
    transform=transform
)

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True,)
test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False, )

Dataset size: 7988 (all captions included)
Dataset size: 2011 (all captions included)


### Model

In [12]:
# !pip install -q diffusers accelerate transformers safetensors peft bitsandbytes datasets


In [13]:
# !pip install -q git+https://github.com/huggingface/peft.git


In [14]:
from peft import get_peft_model, LoraConfig, TaskType, prepare_model_for_kbit_training
from transformers import BertModel, BertTokenizer
from peft import LoraModel


In [15]:
from transformers import BertModel, BertTokenizer

# Load the BERT model to get embeddings
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert = BertModel.from_pretrained("bert-base-uncased").to(device)

In [16]:
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.to(device)

pipe.vae.requires_grad_(False)
pipe.unet.requires_grad_(False)
pipe.text_encoder.requires_grad_(False)

Loading pipeline components...: 100%|███████████████| 7/7 [00:01<00:00,  5.01it/s]


CLIPTextModel(
  (text_model): CLIPTextTransformer(
    (embeddings): CLIPTextEmbeddings(
      (token_embedding): Embedding(49408, 768)
      (position_embedding): Embedding(77, 768)
    )
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPSdpaAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
          )
          (layer_norm2): LayerNorm((768,), ep

In [17]:
vae = pipe.vae
unet = pipe.unet

In [18]:
class TextToLatent(nn.Module):
    def __init__(self, text_dim=768, latent_dim=4*64*64):  # SD latent size: (4, 64, 64)
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(text_dim, 1024),
            nn.ReLU(),
            nn.Linear(1024, latent_dim)
        )

    def forward(self, x):
        return self.fc(x).view(-1, 4, 64, 64)

In [19]:
mapper = TextToLatent().to(device)
optimizer = torch.optim.Adam(mapper.parameters(), lr=1e-4)
loss_fn = nn.MSELoss()

In [None]:
vae = vae.to("cpu").float()   # VAE stays on CPU
bert = bert.to(device).float()
mapper = mapper.to(device).float()

for epoch in range(5):
    total_loss = 0.0
    mapper.train()
    bert.eval()
    vae.eval()

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/5")

    for batch in pbar:
        input_ids = batch["input_ids"].to(device)
        attn_mask = batch["attention_mask"].to(device)
        images = batch["image"].float().to("cpu")  # images to CPU for VAE

        # 1. Get text embeddings
        with torch.no_grad():
            text_embeds = bert(input_ids=input_ids, attention_mask=attn_mask).last_hidden_state[:, 0, :]

        # 2. Get latent encodings from VAE on CPU
        with torch.no_grad():
            latents_gt = vae.encode(images).latent_dist.sample() * 0.18215
            latents_gt = latents_gt.to(device)  # move latents to MPS for training

        # 3. Predict and compute loss
        latents_pred = mapper(text_embeds)
        loss = loss_fn(latents_pred, latents_gt)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        pbar.set_postfix({"loss": total_loss / (pbar.n + 1)})

        # Free memory
        del input_ids, attn_mask, images, text_embeds, latents_gt, latents_pred
        torch.mps.empty_cache()

    print(f"Epoch {epoch+1} Avg Loss: {total_loss / len(train_loader):.4f}")

# Save trained mapper
torch.save(mapper.state_dict(), "text_to_latent_mapper.pth")
print("✅ Done training!")


Epoch 1/5:   1%|▏                 | 30/3994 [01:31<3:22:57,  3.07s/it, loss=0.749]