In [4]:
if torch.backends.mps.is_available():
    device = torch.device("mps")  # Apple Metal
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(device)

mps


In [5]:
os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"
# torch.backends.mps.enable_flash_sdp(False)


In [3]:
from diffusers import StableDiffusionPipeline
from transformers import AutoTokenizer
import os
import torch
from torch import nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
from diffusers import StableDiffusionPipeline, UNet2DConditionModel, DDPMScheduler, LMSDiscreteScheduler
from transformers import CLIPTextModel, CLIPTokenizer
import clip
from tqdm import tqdm
from scripts import eng_processor
from scripts.text_preprocessing import TextImageDataset
from scripts.train_model import BertEncoder, Generator, Discriminator
from scripts.inference import generate_image_from_text
from scripts.utils import combine_dataset, Evaluator
import pandas as pd
from collections import defaultdict
from sklearn.model_selection import train_test_split
from torchvision import transforms


  from .autonotebook import tqdm as notebook_tqdm


In [7]:
combined_df = pd.read_csv("datasets/combined_dataset.csv")
combined_df.tail(5)

Unnamed: 0,image_path,caption,source
90483,datasets/coco_dataset/train2017/000000370808.jpg,A black cat looking out the window at a black ...,MSCOCO
90484,datasets/coco_dataset/train2017/000000370808.jpg,Black cat sitting on window ledge looking outs...,MSCOCO
90485,datasets/coco_dataset/train2017/000000370808.jpg,A black cat looks out the window as a crow out...,MSCOCO
90486,datasets/coco_dataset/train2017/000000370808.jpg,A cat by a window with a small bird outside.,MSCOCO
90487,datasets/coco_dataset/train2017/000000370808.jpg,A cat watches a bird through a window.,MSCOCO


In [9]:
# combined_df= combined_df.where(combined_df["source"]=="Flickr")

# combined_df = combined_df.dropna()
combined_df = combined_df.sample(frac=1, random_state=42).head(10000)
combined_df = combined_df.reset_index(drop=True)  # Reset index to sequential integers

for i in range(len(combined_df)):
    combined_df.caption[i] = eng_processor.main(combined_df.caption[i])

You are setting values through chained assignment. Currently this works in certain cases, but when using Copy-on-Write (which will become the default behaviour in pandas 3.0) this will never work to update the original DataFrame or Series, because the intermediate object on which we are setting values will behave as a copy.
A typical example is when you are setting values in a column of a DataFrame, like:

df["col"][row_indexer] = value

Use `df.loc[row_indexer, "col"] = values` instead, to perform the assignment in a single step and ensure this keeps updating the original `df`.

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy

  combined_df.caption[i] = eng_processor.main(combined_df.caption[i])


In [10]:
image_to_captions = defaultdict(list)

for _, row in combined_df.iterrows():
    img_path = row['image_path']
    image_to_captions[img_path].append(row['caption'])

In [11]:
image_paths = list(image_to_captions.keys())
train_ids, test_ids = train_test_split(image_paths, test_size=0.2, random_state=42)


In [12]:
# image transform
transform = transforms.Compose([
    transforms.Resize((256, 256)),  # Resize the image
    transforms.ToTensor(),  # Convert PIL Image to tensor
    transforms.Lambda(lambda x: (x * 2) - 1)  # Scale from [0,1] to [-1,1]
])


train_dataset = TextImageDataset(
    image_to_captions=image_to_captions,
    image_paths=train_ids,  
    transform=transform
)

test_dataset = TextImageDataset(
    image_to_captions=image_to_captions,
    image_paths=test_ids,  
    transform=transform
)

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False)

Dataset size: 8032 (all captions included)
Dataset size: 1967 (all captions included)


In [13]:
clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
clip_text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)



In [14]:
clip_text_encoder.eval()


CLIPTextModel(
  (text_model): CLIPTextTransformer(
    (embeddings): CLIPTextEmbeddings(
      (token_embedding): Embedding(49408, 768)
      (position_embedding): Embedding(77, 768)
    )
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPSdpaAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
          )
          (layer_norm2): LayerNorm((768,), ep

In [15]:
# Load the pretrained pipeline
pipeline = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # or any other model
    torch_dtype=torch.float16,         # Use float16 only if GPU/MLCompute is supported
)

pipeline.to("mps")  # Use M1/M2 GPU via Metal backend

Cannot initialize model with low cpu memory usage because `accelerate` was not found in the environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install `accelerate` for faster and less memory-intense model loading. You can do so with: 
```
pip install accelerate
```
.
Loading pipeline components...: 100%|█████████████| 7/7 [00:03<00:00,  2.03it/s]


StableDiffusionPipeline {
  "_class_name": "StableDiffusionPipeline",
  "_diffusers_version": "0.33.1",
  "_name_or_path": "runwayml/stable-diffusion-v1-5",
  "feature_extractor": [
    "transformers",
    "CLIPImageProcessor"
  ],
  "image_encoder": [
    null,
    null
  ],
  "requires_safety_checker": true,
  "safety_checker": [
    "stable_diffusion",
    "StableDiffusionSafetyChecker"
  ],
  "scheduler": [
    "diffusers",
    "PNDMScheduler"
  ],
  "text_encoder": [
    "transformers",
    "CLIPTextModel"
  ],
  "tokenizer": [
    "transformers",
    "CLIPTokenizer"
  ],
  "unet": [
    "diffusers",
    "UNet2DConditionModel"
  ],
  "vae": [
    "diffusers",
    "AutoencoderKL"
  ]
}

In [16]:
class Projection(nn.Module):
    def __init__(self, in_dim=768, out_dim=768):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.Tanh()
        )

    def forward(self, x):
        return self.proj(x)

projection = Projection().to(device)

# === 4. Stable Diffusion Decoder ===
vae = pipeline.vae.to(device)
unet = pipeline.unet.to(device)
scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)



In [17]:
# Freeze everything except projection
for param in vae.parameters(): param.requires_grad = False
for param in unet.parameters(): param.requires_grad = False

optimizer = torch.optim.Adam(projection.parameters(), lr=1e-4)
loss_fn = nn.MSELoss()


In [None]:
for epoch in range(10):

    unet.train()
    clip_text_encoder.eval()
    vae.eval()
    projection.train()


    for i, batch in enumerate(tqdm(train_loader, desc=f"Training Epoch {epoch+1}")):
        images = batch["image"].to(device)
        texts = batch["caption_text"]
        
        # Match input dtype with model
        dtype = next(vae.parameters()).dtype
        images = images.to(dtype=dtype)

        with torch.no_grad():
            # Text encoding
            inputs = clip_tokenizer(texts, return_tensors="pt", padding=True, truncation=True).to(device)
            inputs = {key: val.to(device) for key, val in inputs.items()}

            text_feats = clip_text_encoder(**inputs).last_hidden_state  # (B, seq_len, 768)


            # get latent for vae
            latents = vae.encode(images).latent_dist.sample() * 0.18215


        pooled_feats = text_feats.mean(dim=1)  # (B, 768)
        projected_feats = projection(pooled_feats)

        # add noise to it
        with torch.no_grad():

            noise = torch.randn_like(latents)
            timesteps = torch.randint(0, 1000, (latents.shape[0],), device=device).long()
            noisy_latents = scheduler.add_noise(latents, noise, timesteps)

        projected_feats = projected_feats.to(dtype=noisy_latents.dtype).unsqueeze(1)  # (B, 1, D)

        noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=projected_feats).sample

        loss = loss_fn(noise_pred, noise)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()


        if i % 100 == 0:
            print(f"[Batch {i}/{len(train_loader)}] [Loss: {loss.item():.4f}]")

        # Clear memory
        del images, texts, inputs, text_feats, pooled_feats, projected_feats
        del latents, noise, timesteps, noisy_latents, noise_pred
        torch.mps.empty_cache()

    print(f"Epoch {epoch} finished")


# Save model after training
torch.save({
    'unet': unet.state_dict(),
    'clip_text_encoder': clip_text_encoder.state_dict(),
    'projection': projection.state_dict()
}, "final_model_checkpoint.pt")

print("Training completed ✅")

Training Epoch 1:   0%|                      | 1/4016 [00:01<1:23:43,  1.25s/it]

[Batch 0/4016] [Loss: 0.3232]


Training Epoch 1:   3%|▌                     | 101/4016 [00:40<25:30,  2.56it/s]

[Batch 100/4016] [Loss: 0.3823]


Training Epoch 1:   5%|█                     | 201/4016 [01:19<25:03,  2.54it/s]

[Batch 200/4016] [Loss: 0.3982]


Training Epoch 1:   7%|█▋                    | 301/4016 [02:01<26:38,  2.32it/s]

[Batch 300/4016] [Loss: 0.3384]


Training Epoch 1:  10%|██▏                   | 401/4016 [02:40<23:27,  2.57it/s]

[Batch 400/4016] [Loss: 0.3679]


Training Epoch 1:  12%|██▋                   | 501/4016 [03:20<24:02,  2.44it/s]

[Batch 500/4016] [Loss: 0.4478]


Training Epoch 1:  15%|███▎                  | 601/4016 [03:59<21:57,  2.59it/s]

[Batch 600/4016] [Loss: 0.3013]


Training Epoch 1:  17%|███▊                  | 701/4016 [04:38<22:00,  2.51it/s]

[Batch 700/4016] [Loss: 0.4221]


Training Epoch 1:  20%|████▍                 | 801/4016 [05:18<20:42,  2.59it/s]

[Batch 800/4016] [Loss: 0.4031]


Training Epoch 1:  22%|████▉                 | 901/4016 [05:57<19:57,  2.60it/s]

[Batch 900/4016] [Loss: 0.5649]


Training Epoch 1:  25%|█████▏               | 1001/4016 [06:36<20:12,  2.49it/s]

[Batch 1000/4016] [Loss: 0.2632]


Training Epoch 1:  27%|█████▊               | 1101/4016 [07:15<19:02,  2.55it/s]

[Batch 1100/4016] [Loss: 0.3904]


Training Epoch 1:  30%|██████▎              | 1201/4016 [07:55<17:54,  2.62it/s]

[Batch 1200/4016] [Loss: 0.3252]


Training Epoch 1:  32%|██████▊              | 1301/4016 [08:34<17:43,  2.55it/s]

[Batch 1300/4016] [Loss: 0.2568]


Training Epoch 1:  35%|███████▎             | 1401/4016 [09:13<17:29,  2.49it/s]

[Batch 1400/4016] [Loss: 0.3064]


Training Epoch 1:  37%|███████▊             | 1501/4016 [09:53<16:23,  2.56it/s]

[Batch 1500/4016] [Loss: 0.3442]


Training Epoch 1:  40%|████████▎            | 1601/4016 [10:32<16:00,  2.52it/s]

[Batch 1600/4016] [Loss: 0.3782]


Training Epoch 1:  42%|████████▉            | 1701/4016 [11:11<14:51,  2.60it/s]

[Batch 1700/4016] [Loss: 0.3552]


Training Epoch 1:  45%|█████████▍           | 1801/4016 [11:50<14:26,  2.56it/s]

[Batch 1800/4016] [Loss: 0.3550]


Training Epoch 1:  47%|█████████▉           | 1901/4016 [12:30<13:34,  2.60it/s]

[Batch 1900/4016] [Loss: 0.5303]


Training Epoch 1:  50%|██████████▍          | 2001/4016 [13:09<13:15,  2.53it/s]

[Batch 2000/4016] [Loss: 0.3474]


Training Epoch 1:  52%|██████████▉          | 2101/4016 [13:48<12:22,  2.58it/s]

[Batch 2100/4016] [Loss: 0.3489]


Training Epoch 1:  55%|███████████▌         | 2201/4016 [14:28<11:58,  2.53it/s]

[Batch 2200/4016] [Loss: 0.2708]


Training Epoch 1:  57%|████████████         | 2301/4016 [15:07<11:02,  2.59it/s]

[Batch 2300/4016] [Loss: 0.4895]


Training Epoch 1:  60%|████████████▌        | 2401/4016 [15:46<10:32,  2.55it/s]

[Batch 2400/4016] [Loss: 0.2231]


Training Epoch 1:  62%|█████████████        | 2501/4016 [16:25<09:52,  2.56it/s]

[Batch 2500/4016] [Loss: 0.4858]


Training Epoch 1:  65%|█████████████▌       | 2601/4016 [17:04<09:15,  2.55it/s]

[Batch 2600/4016] [Loss: 0.3733]


Training Epoch 1:  67%|██████████████       | 2701/4016 [17:44<08:44,  2.51it/s]

[Batch 2700/4016] [Loss: 0.1864]


Training Epoch 1:  70%|██████████████▋      | 2801/4016 [18:23<07:53,  2.56it/s]

[Batch 2800/4016] [Loss: 0.4731]


Training Epoch 1:  72%|███████████████▏     | 2901/4016 [19:02<07:17,  2.55it/s]

[Batch 2900/4016] [Loss: 0.4316]


Training Epoch 1:  75%|███████████████▋     | 3001/4016 [19:42<06:37,  2.55it/s]

[Batch 3000/4016] [Loss: 0.3027]


Training Epoch 1:  77%|████████████████▏    | 3101/4016 [20:21<05:57,  2.56it/s]

[Batch 3100/4016] [Loss: 0.3328]


Training Epoch 1:  80%|████████████████▋    | 3201/4016 [21:00<05:32,  2.45it/s]

[Batch 3200/4016] [Loss: 0.3018]


Training Epoch 1:  82%|█████████████████▎   | 3301/4016 [21:39<04:38,  2.56it/s]

[Batch 3300/4016] [Loss: 0.3267]


Training Epoch 1:  85%|█████████████████▊   | 3401/4016 [22:19<04:05,  2.51it/s]

[Batch 3400/4016] [Loss: 0.2256]


Training Epoch 1:  87%|██████████████████▎  | 3501/4016 [22:58<03:21,  2.56it/s]

[Batch 3500/4016] [Loss: 0.4500]


Training Epoch 1:  90%|██████████████████▊  | 3601/4016 [23:37<02:43,  2.54it/s]

[Batch 3600/4016] [Loss: 0.3853]


Training Epoch 1:  92%|███████████████████▎ | 3701/4016 [24:16<02:02,  2.57it/s]

[Batch 3700/4016] [Loss: 0.2593]


Training Epoch 1:  95%|███████████████████▉ | 3801/4016 [25:05<02:24,  1.49it/s]

[Batch 3800/4016] [Loss: 0.3101]


Training Epoch 1:  97%|████████████████████▍| 3901/4016 [26:13<01:16,  1.51it/s]

[Batch 3900/4016] [Loss: 0.3748]


Training Epoch 1: 100%|████████████████████▉| 4001/4016 [27:18<00:09,  1.53it/s]

[Batch 4000/4016] [Loss: 0.3196]


Training Epoch 1: 100%|█████████████████████| 4016/4016 [27:28<00:00,  2.44it/s]


Epoch 0 finished


Training Epoch 2:   0%|                        | 1/4016 [00:00<43:03,  1.55it/s]

[Batch 0/4016] [Loss: 0.4053]


Training Epoch 2:   3%|▌                     | 101/4016 [01:06<42:21,  1.54it/s]

[Batch 100/4016] [Loss: 0.2140]


Training Epoch 2:   5%|█                     | 201/4016 [02:11<42:17,  1.50it/s]

[Batch 200/4016] [Loss: 0.3354]


Training Epoch 2:   7%|█▋                    | 301/4016 [03:17<40:12,  1.54it/s]

[Batch 300/4016] [Loss: 0.1713]


Training Epoch 2:  10%|██▏                   | 401/4016 [22:08<42:07,  1.43it/s]

[Batch 400/4016] [Loss: 0.3389]


Training Epoch 2:  12%|██▋                   | 501/4016 [23:16<41:09,  1.42it/s]

[Batch 500/4016] [Loss: 0.3582]


Training Epoch 2:  15%|███▎                  | 601/4016 [24:29<44:26,  1.28it/s]

[Batch 600/4016] [Loss: 0.2600]


Training Epoch 2:  17%|███▊                  | 701/4016 [25:43<37:33,  1.47it/s]

[Batch 700/4016] [Loss: 0.3657]


Training Epoch 2:  20%|████▍                 | 801/4016 [26:51<26:47,  2.00it/s]

[Batch 800/4016] [Loss: 0.5713]


Training Epoch 2:  22%|████▉                 | 901/4016 [27:31<20:32,  2.53it/s]

[Batch 900/4016] [Loss: 0.3525]


Training Epoch 2:  25%|█████▏               | 1001/4016 [28:11<19:33,  2.57it/s]

[Batch 1000/4016] [Loss: 0.2876]


Training Epoch 2:  27%|█████▊               | 1101/4016 [28:53<21:27,  2.26it/s]

[Batch 1100/4016] [Loss: 0.1978]


Training Epoch 2:  30%|██████▎              | 1201/4016 [29:37<22:10,  2.12it/s]

[Batch 1200/4016] [Loss: 0.5640]


Training Epoch 2:  32%|██████▊              | 1301/4016 [30:19<17:22,  2.60it/s]

[Batch 1300/4016] [Loss: 0.2062]


Training Epoch 2:  35%|███████▎             | 1401/4016 [30:58<17:07,  2.55it/s]

[Batch 1400/4016] [Loss: 0.2712]


Training Epoch 2:  37%|███████▊             | 1501/4016 [31:39<17:27,  2.40it/s]

[Batch 1500/4016] [Loss: 0.3359]


Training Epoch 2:  40%|████████▎            | 1601/4016 [32:19<16:10,  2.49it/s]

[Batch 1600/4016] [Loss: 0.3535]


Training Epoch 2:  42%|████████▉            | 1701/4016 [33:00<16:07,  2.39it/s]

[Batch 1700/4016] [Loss: 0.2106]


Training Epoch 2:  45%|█████████▍           | 1801/4016 [33:40<14:16,  2.59it/s]

[Batch 1800/4016] [Loss: 0.2549]


Training Epoch 2:  47%|█████████▉           | 1901/4016 [34:22<15:41,  2.25it/s]

[Batch 1900/4016] [Loss: 0.3665]


Training Epoch 2:  50%|██████████▍          | 2001/4016 [35:02<13:14,  2.54it/s]

[Batch 2000/4016] [Loss: 0.3787]


Training Epoch 2:  52%|██████████▉          | 2101/4016 [35:41<12:27,  2.56it/s]

[Batch 2100/4016] [Loss: 0.1720]


Training Epoch 2:  55%|███████████▌         | 2201/4016 [36:24<12:05,  2.50it/s]

[Batch 2200/4016] [Loss: 0.6113]


Training Epoch 2:  57%|████████████         | 2301/4016 [37:04<10:57,  2.61it/s]

[Batch 2300/4016] [Loss: 0.3308]


Training Epoch 2:  60%|████████████▌        | 2401/4016 [37:43<10:29,  2.57it/s]

[Batch 2400/4016] [Loss: 0.4697]


Training Epoch 2:  62%|█████████████        | 2501/4016 [38:23<09:58,  2.53it/s]

[Batch 2500/4016] [Loss: 0.2725]


Training Epoch 2:  65%|█████████████▌       | 2601/4016 [39:03<09:22,  2.52it/s]

[Batch 2600/4016] [Loss: 0.2915]


Training Epoch 2:  67%|██████████████       | 2701/4016 [39:42<08:43,  2.51it/s]

[Batch 2700/4016] [Loss: 0.3047]


Training Epoch 2:  70%|██████████████▋      | 2801/4016 [40:22<08:09,  2.48it/s]

[Batch 2800/4016] [Loss: 0.3347]


Training Epoch 2:  72%|███████████████▏     | 2901/4016 [41:01<07:17,  2.55it/s]

[Batch 2900/4016] [Loss: 0.2917]


Training Epoch 2:  73%|███████████████▏     | 2914/4016 [41:07<07:23,  2.48it/s]