In [None]:
!pip install boto3
!pip install python-dotenv

In [None]:
from dotenv import load_dotenv
import os

# Populate os.environ from the env file (presumably the AWS_* variables read
# later with os.getenv -- confirm against the file's contents).
# load_dotenv returns False when no variable was loaded (missing or empty file);
# report that here instead of letting a later cell fail with an opaque error.
env_loaded = load_dotenv(dotenv_path="/content/env")
if not env_loaded:
    print(" Warning: no variables were loaded from /content/env; credentials must already be set in the environment.")

In [None]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
from peft import LoraConfig, get_peft_model
from tqdm import tqdm
import boto3
import tempfile
import random

In [None]:
# ==== Step 3.5: download the previously fine-tuned LoRA checkpoint (lora_unet_2epoch.pt) from S3 ====
# Credentials are read from the environment; os.getenv yields None for a variable
# that is not set.
s3 = boto3.client(
    's3',
    aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
    aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY")
)
s3_bucket = "data298youcook2"
s3_key = "ModelScope_T2V_finetuned/lora_unet_2epoch.pt"
resume_path = "checkpoints/lora_unet_2epoch.pt"

# Skip the download when a local copy already exists, so re-running the cell is cheap.
if not os.path.exists(resume_path):
    # Create the directory that resume_path actually points into, instead of a
    # separately hard-coded name that could drift from it.
    checkpoint_dir = os.path.dirname(resume_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)
    s3.download_file(s3_bucket, s3_key, resume_path)
    print(" Download complete.")

In [None]:
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
import torch
import imageio
import numpy as np
from IPython.display import HTML, Video
from base64 import b64encode
import os

# === Configuration ===
# Identifier of the pre-trained pipeline and path of the LoRA weights file.
model_id = "damo-vilab/text-to-video-ms-1.7b"
lora_path = "checkpoints/lora_unet_2epoch.pt"

# Text prompt for generation. Alternative prompts kept for reference:
#   "bake the onions in the oven"
#   "place the pan back on high flame and cook the dosa"
prompt = "flip the pancakes over"

# Number of frames to generate and frame rate of the saved video.
num_frames, fps = 16, 4

# The output file is named after the prompt, with spaces turned into underscores.
out_path = "generated_" + "_".join(prompt.split(" ")) + ".mp4"

# === Load pre-trained pipeline ===
# Half-precision weights are requested via torch_dtype and the "fp16" variant.
pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, variant="fp16")
# Swap in the DPM-Solver multistep scheduler, reusing the existing scheduler's config.
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
# The pipeline is intentionally not moved to CUDA here: enable_model_cpu_offload()
# below manages device placement, and a preceding pipe.to("cuda") would only add a
# full-pipeline GPU allocation before offloading takes effect.

# === Load fine-tuned LoRA weights ===
# NOTE(review): torch.load unpickles the checkpoint, so only load files from a
# trusted source (consider weights_only=True if the installed torch supports it).
lora_weights = torch.load(lora_path, map_location="cpu")
# strict=False tolerates UNet parameters that are absent from the checkpoint, but it
# also silently skips checkpoint keys the UNet does not have -- inspect the result so
# a key-name mismatch cannot go unnoticed.
load_report = pipe.unet.load_state_dict(lora_weights, strict=False)
skipped_keys = load_report.unexpected_keys
if lora_weights and len(skipped_keys) == len(lora_weights):
    raise RuntimeError(
        f"None of the {len(lora_weights)} tensors in {lora_path} match the UNet's parameter "
        "names, so no fine-tuned weights were applied. If the checkpoint was saved from a "
        "PEFT-wrapped UNet, wrap pipe.unet with the same LoraConfig before loading."
    )
if skipped_keys:
    print(f" Warning: {len(skipped_keys)} of {len(lora_weights)} checkpoint tensors were not applied, e.g. {skipped_keys[0]}")

# === Inference ===
# Let the pipeline move its sub-models to the GPU only while each one is running.
pipe.enable_model_cpu_offload()
pipe.unet.eval()

with torch.no_grad(), torch.autocast("cuda"):
    result = pipe(prompt=prompt, num_frames=num_frames, num_inference_steps=25)

# === Process frames and convert to uint8 format ===
def _frames_to_uint8(clip):
    """Convert one clip of frames to a list of uint8 arrays.

    Args:
        clip: array-like whose first axis indexes frames.

    Returns:
        A list with one uint8 array per frame.

    Floating-point clips whose maximum is <= 1 are treated as normalised and
    scaled by 255; other floating-point clips are treated as already being on a
    0-255 scale. Values are clipped to the valid range before the cast so
    out-of-range floats cannot wrap around. Integer clips are cast directly: the
    scaling decision depends on dtype, so a nearly black integer frame (max <= 1)
    is not mistaken for normalised data and multiplied by 255.
    """
    clip = np.asarray(clip)
    if np.issubdtype(clip.dtype, np.floating):
        # Decide the value range once for the whole clip so that every frame is
        # scaled consistently.
        if clip.size == 0 or clip.max() <= 1:
            clip = np.clip(clip, 0.0, 1.0) * 255
        else:
            clip = np.clip(clip, 0.0, 255.0)
    return [frame.astype(np.uint8) for frame in clip]


raw_frames = np.asarray(result.frames)
# A 5-D result carries a leading batch axis; keep the first (only) video.
# NOTE(review): some diffusers releases appear to return an unbatched sequence of
# frames instead -- a 4-D result is therefore used as-is; confirm against the
# installed version.
if raw_frames.ndim == 5:
    raw_frames = raw_frames[0]
frames = _frames_to_uint8(raw_frames)  # list of [H, W, C] uint8 arrays

# === Save video ===
# Ensure the destination folder exists; a bare filename has no folder part to create.
target_dir = os.path.dirname(out_path)
if target_dir != "":
    os.makedirs(target_dir, exist_ok=True)

# Write the frame sequence to out_path at the configured frame rate.
imageio.mimsave(out_path, frames, fps=fps)
frame_count = len(frames)
print(f" Video saved: {out_path}, Frames: {frame_count}, Duration: {frame_count/fps:.2f} seconds")

# === Preview video (in Colab or Jupyter) ===
# Read the file inside a context manager so the handle is closed right away
# instead of being left open until garbage collection.
with open(out_path, "rb") as video_file:
    mp4 = video_file.read()
# Embed the bytes as a base64 data URL so the <video> element carries the clip itself.
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
# Final expression of the cell: the HTML object is rendered as the cell output.
HTML(f"<video width=512 controls><source src='{data_url}' type='video/mp4'></video>")

