In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pip install transformers datasets evaluate bert_score nltk

Collecting evaluate
  Downloading evaluate-0.4.4-py3-none-any.whl.metadata (9.5 kB)
Collecting bert_score
  Downloading bert_score-0.3.13-py3-none-any.whl.metadata (15 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.0.0->bert_score)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.0.0->bert_score)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.0.0->bert_score)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=1.0.0->bert_score)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=1.0.0->bert_score)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model_path = "/content/drive/MyDrive/nlp_project/stage1_model"
tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)
model = AutoModelForCausalLM.from_pretrained(model_path, local_files_only=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

import pandas as pd

train_df = pd.read_csv("/content/drive/MyDrive/nlp_project/civitai_train_transformed.csv")
val_df = pd.read_csv("/content/drive/MyDrive/nlp_project/civitai_val_transformed.csv")
test_df = pd.read_csv("/content/drive/MyDrive/nlp_project/civitai_test_transformed.csv")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
def format_caption_prompt(row):
    return f"### Caption:\n{row['gen_caption']}\n\n### Prompt:\n{row['prompt']}"

train_df["full_text"] = train_df.apply(format_caption_prompt, axis=1)
val_df["full_text"]   = val_df.apply(format_caption_prompt, axis=1)
test_df["full_text"]  = test_df.apply(format_caption_prompt, axis=1)

from datasets import Dataset

train_dataset = Dataset.from_pandas(train_df[["full_text"]])
val_dataset   = Dataset.from_pandas(val_df[["full_text"]])
test_dataset  = Dataset.from_pandas(test_df[["full_text"]])

# Tokenize to only predict prompt
def tokenize(example):
    result = tokenizer(
        example["full_text"],
        truncation=True,
        padding="max_length",
        max_length=512,
    )
    result["labels"] = result["input_ids"].copy()
    return result

train_dataset = train_dataset.map(tokenize, batched=True, remove_columns=["full_text"])
val_dataset   = val_dataset.map(tokenize, batched=True, remove_columns=["full_text"])
test_dataset  = test_dataset.map(tokenize, batched=True, remove_columns=["full_text"])

# Let it learn prompts
def mask_caption_tokens(example):
    input_ids = example["input_ids"]
    labels = example["labels"]
    caption_start = tokenizer("### Prompt:")["input_ids"][-1]

    try:
        start_idx = input_ids.index(caption_start)
    except ValueError:
        start_idx = 0

    labels = [-100 if i < start_idx else t for i, t in enumerate(labels)]
    return {"labels": labels}

train_dataset = train_dataset.map(mask_caption_tokens)
val_dataset   = val_dataset.map(mask_caption_tokens)
test_dataset  = test_dataset.map(mask_caption_tokens)

from torch.utils.data import DataLoader

def collate_fn(batch):
    return {
        "input_ids": torch.tensor([x["input_ids"] for x in batch]),
        "attention_mask": torch.tensor([x["attention_mask"] for x in batch]),
        "labels": torch.tensor([x["labels"] for x in batch])
    }

train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)
val_loader   = DataLoader(val_dataset, batch_size=2, shuffle=False, collate_fn=collate_fn)

Map:   0%|          | 0/1058 [00:00<?, ? examples/s]

Map:   0%|          | 0/132 [00:00<?, ? examples/s]

Map:   0%|          | 0/133 [00:00<?, ? examples/s]

Map:   0%|          | 0/1058 [00:00<?, ? examples/s]

Map:   0%|          | 0/132 [00:00<?, ? examples/s]

Map:   0%|          | 0/133 [00:00<?, ? examples/s]

In [None]:
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32022, 2048)
    (layers): ModuleList(
      (0-23): 24 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (v_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=5504, bias=False)
          (up_proj): Linear(in_features=2048, out_features=5504, bias=False)
          (down_proj): Linear(in_features=5504, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-06)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-06)
      )
    )
    (norm): LlamaRMSNorm((2048,), eps=1e-06)
    (rotary_emb)

In [None]:
from tqdm import tqdm

num_epochs = 3
for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}")
    model.train()

    for step, batch in enumerate(tqdm(train_loader)):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        if input_ids.size(1) == 0:
            continue

        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels, use_cache=False)
        loss = outputs.loss

        if torch.isnan(loss):
            continue

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if step % 20 == 0:
            print(f"Step {step} | Loss: {loss.item():.4f}")


Epoch 1


  0%|          | 1/529 [00:01<14:50,  1.69s/it]

Step 0 | Loss: 5.9889


  4%|▍         | 21/529 [00:13<05:21,  1.58it/s]

Step 20 | Loss: 3.7775


  8%|▊         | 41/529 [00:25<05:08,  1.58it/s]

Step 40 | Loss: 3.1869


 12%|█▏        | 61/529 [00:36<04:55,  1.58it/s]

Step 60 | Loss: 2.3857


 15%|█▌        | 81/529 [00:48<04:42,  1.58it/s]

Step 80 | Loss: 3.5380


 19%|█▉        | 101/529 [01:00<04:30,  1.58it/s]

Step 100 | Loss: 3.1394


 23%|██▎       | 121/529 [01:11<04:17,  1.58it/s]

Step 120 | Loss: 2.5463


 27%|██▋       | 141/529 [01:23<04:05,  1.58it/s]

Step 140 | Loss: 2.7831


 30%|███       | 161/529 [01:35<03:52,  1.58it/s]

Step 160 | Loss: 3.4148


 34%|███▍      | 181/529 [01:46<03:39,  1.58it/s]

Step 180 | Loss: 2.5943


 38%|███▊      | 201/529 [01:58<03:27,  1.58it/s]

Step 200 | Loss: 3.4027


 42%|████▏     | 221/529 [02:10<03:14,  1.58it/s]

Step 220 | Loss: 2.6908


 46%|████▌     | 241/529 [02:21<03:01,  1.58it/s]

Step 240 | Loss: 1.7207


 49%|████▉     | 261/529 [02:33<02:49,  1.58it/s]

Step 260 | Loss: 2.1737


 53%|█████▎    | 281/529 [02:45<02:36,  1.58it/s]

Step 280 | Loss: 2.0102


 57%|█████▋    | 301/529 [02:56<02:24,  1.58it/s]

Step 300 | Loss: 3.2722


 61%|██████    | 321/529 [03:08<02:11,  1.58it/s]

Step 320 | Loss: 3.0652


 64%|██████▍   | 341/529 [03:20<01:58,  1.58it/s]

Step 340 | Loss: 3.1406


 68%|██████▊   | 361/529 [03:32<01:46,  1.58it/s]

Step 360 | Loss: 2.2369


 72%|███████▏  | 381/529 [03:43<01:33,  1.58it/s]

Step 380 | Loss: 2.4929


 76%|███████▌  | 401/529 [03:55<01:20,  1.58it/s]

Step 400 | Loss: 2.9762


 80%|███████▉  | 421/529 [04:07<01:08,  1.58it/s]

Step 420 | Loss: 2.0086


 83%|████████▎ | 441/529 [04:18<00:55,  1.58it/s]

Step 440 | Loss: 2.5843


 87%|████████▋ | 461/529 [04:30<00:42,  1.58it/s]

Step 460 | Loss: 1.2836


 91%|█████████ | 481/529 [04:42<00:30,  1.58it/s]

Step 480 | Loss: 1.1730


 95%|█████████▍| 501/529 [04:53<00:17,  1.58it/s]

Step 500 | Loss: 1.8330


 98%|█████████▊| 521/529 [05:05<00:05,  1.58it/s]

Step 520 | Loss: 2.3658


100%|██████████| 529/529 [05:10<00:00,  1.71it/s]



Epoch 2


  0%|          | 1/529 [00:00<06:29,  1.35it/s]

Step 0 | Loss: 1.5159


  4%|▍         | 21/529 [00:12<05:20,  1.58it/s]

Step 20 | Loss: 1.4198


  8%|▊         | 41/529 [00:24<05:08,  1.58it/s]

Step 40 | Loss: 1.9496


 12%|█▏        | 61/529 [00:35<04:55,  1.58it/s]

Step 60 | Loss: 1.6195


 15%|█▌        | 81/529 [00:47<04:42,  1.58it/s]

Step 80 | Loss: 1.4577


 19%|█▉        | 101/529 [00:59<04:30,  1.58it/s]

Step 100 | Loss: 0.7580


 23%|██▎       | 121/529 [01:10<04:17,  1.58it/s]

Step 120 | Loss: 1.4475


 27%|██▋       | 141/529 [01:22<04:05,  1.58it/s]

Step 140 | Loss: 1.5553


 30%|███       | 161/529 [01:34<03:52,  1.58it/s]

Step 160 | Loss: 1.9155


 34%|███▍      | 181/529 [01:45<03:39,  1.58it/s]

Step 180 | Loss: 1.8320


 38%|███▊      | 201/529 [01:57<03:27,  1.58it/s]

Step 200 | Loss: 1.5937


 42%|████▏     | 221/529 [02:09<03:14,  1.58it/s]

Step 220 | Loss: 1.8418


 46%|████▌     | 241/529 [02:20<03:01,  1.58it/s]

Step 240 | Loss: 1.1054


 49%|████▉     | 261/529 [02:32<02:49,  1.58it/s]

Step 260 | Loss: 1.4453


 53%|█████▎    | 281/529 [02:44<02:36,  1.58it/s]

Step 280 | Loss: 1.2753


 57%|█████▋    | 301/529 [02:55<02:23,  1.58it/s]

Step 300 | Loss: 2.1363


 61%|██████    | 321/529 [03:07<02:11,  1.58it/s]

Step 320 | Loss: 1.5383


 64%|██████▍   | 341/529 [03:19<01:58,  1.58it/s]

Step 340 | Loss: 1.6287


 68%|██████▊   | 361/529 [03:31<01:46,  1.58it/s]

Step 360 | Loss: 1.2015


 72%|███████▏  | 381/529 [03:42<01:33,  1.58it/s]

Step 380 | Loss: 1.9274


 76%|███████▌  | 401/529 [03:54<01:20,  1.58it/s]

Step 400 | Loss: 1.3184


 80%|███████▉  | 421/529 [04:06<01:08,  1.58it/s]

Step 420 | Loss: 1.0356


 83%|████████▎ | 441/529 [04:17<00:55,  1.58it/s]

Step 440 | Loss: 1.4483


 87%|████████▋ | 461/529 [04:29<00:42,  1.58it/s]

Step 460 | Loss: 0.9767


 91%|█████████ | 481/529 [04:41<00:30,  1.58it/s]

Step 480 | Loss: 1.1395


 95%|█████████▍| 501/529 [04:52<00:17,  1.58it/s]

Step 500 | Loss: 2.0969


 98%|█████████▊| 521/529 [05:04<00:05,  1.58it/s]

Step 520 | Loss: 1.6133


100%|██████████| 529/529 [05:09<00:00,  1.71it/s]



Epoch 3


  0%|          | 1/529 [00:00<06:29,  1.35it/s]

Step 0 | Loss: 0.4876


  4%|▍         | 21/529 [00:12<05:20,  1.58it/s]

Step 20 | Loss: 0.8747


  8%|▊         | 41/529 [00:24<05:08,  1.58it/s]

Step 40 | Loss: 0.7157


 12%|█▏        | 61/529 [00:35<04:55,  1.58it/s]

Step 60 | Loss: 0.4680


 15%|█▌        | 81/529 [00:47<04:43,  1.58it/s]

Step 80 | Loss: 0.2676


 19%|█▉        | 101/529 [00:59<04:30,  1.58it/s]

Step 100 | Loss: 0.8556


 23%|██▎       | 121/529 [01:10<04:17,  1.58it/s]

Step 120 | Loss: 0.4408


 27%|██▋       | 141/529 [01:22<04:05,  1.58it/s]

Step 140 | Loss: 0.7108


 30%|███       | 161/529 [01:34<03:52,  1.58it/s]

Step 160 | Loss: 0.5490


 34%|███▍      | 181/529 [01:45<03:39,  1.58it/s]

Step 180 | Loss: 0.7485


 38%|███▊      | 201/529 [01:57<03:27,  1.58it/s]

Step 200 | Loss: 0.6331


 42%|████▏     | 221/529 [02:09<03:14,  1.58it/s]

Step 220 | Loss: 0.4542


 46%|████▌     | 241/529 [02:20<03:01,  1.58it/s]

Step 240 | Loss: 0.5004


 49%|████▉     | 261/529 [02:32<02:49,  1.58it/s]

Step 260 | Loss: 0.6122


 53%|█████▎    | 281/529 [02:44<02:36,  1.58it/s]

Step 280 | Loss: 0.4804


 57%|█████▋    | 301/529 [02:56<02:24,  1.58it/s]

Step 300 | Loss: 0.5144


 61%|██████    | 321/529 [03:07<02:11,  1.58it/s]

Step 320 | Loss: 0.3905


 64%|██████▍   | 341/529 [03:19<01:58,  1.58it/s]

Step 340 | Loss: 0.5074


 68%|██████▊   | 361/529 [03:31<01:46,  1.58it/s]

Step 360 | Loss: 0.7278


 72%|███████▏  | 381/529 [03:42<01:33,  1.58it/s]

Step 380 | Loss: 0.3618


 76%|███████▌  | 401/529 [03:54<01:20,  1.58it/s]

Step 400 | Loss: 0.4825


 80%|███████▉  | 421/529 [04:06<01:08,  1.58it/s]

Step 420 | Loss: 0.7044


 83%|████████▎ | 441/529 [04:17<00:55,  1.58it/s]

Step 440 | Loss: 0.3991


 87%|████████▋ | 461/529 [04:29<00:42,  1.58it/s]

Step 460 | Loss: 1.1263


 91%|█████████ | 481/529 [04:41<00:30,  1.58it/s]

Step 480 | Loss: 0.7902


 95%|█████████▍| 501/529 [04:52<00:17,  1.58it/s]

Step 500 | Loss: 0.7335


 98%|█████████▊| 521/529 [05:04<00:05,  1.58it/s]

Step 520 | Loss: 0.5813


100%|██████████| 529/529 [05:09<00:00,  1.71it/s]


In [None]:
save_path = "/content/drive/MyDrive/nlp_project/stage2_prompt_enhancer"
model.save_pretrained(save_path)
tokenizer.save_pretrained(save_path)

('/content/drive/MyDrive/nlp_project/stage2_prompt_enhancer/tokenizer_config.json',
 '/content/drive/MyDrive/nlp_project/stage2_prompt_enhancer/special_tokens_map.json',
 '/content/drive/MyDrive/nlp_project/stage2_prompt_enhancer/tokenizer.json')

In [None]:
import evaluate
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model_path = "/content/drive/MyDrive/nlp_project/stage2_prompt_enhancer"
tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)
model = AutoModelForCausalLM.from_pretrained(model_path, local_files_only=True).cuda()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
# Testing
input_caption = "futuristic city skyline"

inputs = tokenizer(input_caption, return_tensors="pt").to("cuda")
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=64)
generated_prompt = tokenizer.decode(output[0], skip_special_tokens=True)

print("Input:\n", input_caption)
print("\nGenerated Prompt:\n", generated_prompt)

Setting `pad_token_id` to `eos_token_id`:32014 for open-end generation.


Input:
 futuristic city skyline

Generated Prompt:
 futuristic city skyline at night, with towering skyscrapers and a large, glowing moon in a vibrant starry backdrop.

