In [1]:
import torch
import torch.nn as nn
import pickle
from transformers import AutoTokenizer, AutoModel
from peft import PeftModel


In [2]:
# ============================================
# LOAD MODEL AND ARTIFACTS
# ============================================

print("Loading model and artifacts...")

# Directory that holds every artifact read below (config, tokenizer,
# label binarizer, LoRA adapter and classifier head).
RUN_DIR = "./experiments/lr=0.0005_ep=8"

# 1. Load configuration
# NOTE(security): pickle.load can execute arbitrary code embedded in the file.
# Only point RUN_DIR at artifacts produced by a training run you trust.
with open(f"{RUN_DIR}/model_config.pkl", "rb") as f:
    config_info = pickle.load(f)

NUM_LABELS = config_info["num_labels"]
MODEL_NAME = config_info["model_name"]
GENRES = config_info["genres"]

print(f"‚úì Config loaded - {NUM_LABELS} genres")

# 2. Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    f"{RUN_DIR}/saved_tokenizer",
    trust_remote_code=True
)
# Reuse EOS as the padding token.
# NOTE(review): presumably mirrors the training setup - confirm.
tokenizer.pad_token = tokenizer.eos_token
print("‚úì Tokenizer loaded")

# 3. Load MultiLabelBinarizer
# NOTE(security): same pickle caveat as the config file above.
with open(f"{RUN_DIR}/mlb.pkl", "rb") as f:
    mlb = pickle.load(f)

# predict_genres maps logit index i to mlb.classes_[i]; fail fast here if the
# binarizer does not describe exactly NUM_LABELS classes instead of hitting an
# IndexError (or silently dropping genres) at inference time.
if len(mlb.classes_) != NUM_LABELS:
    raise ValueError(
        f"Label mismatch: mlb.pkl has {len(mlb.classes_)} classes but "
        f"model_config.pkl declares num_labels={NUM_LABELS}"
    )
print("‚úì MultiLabelBinarizer loaded")

# 4. Rebuild model architecture (UNCHANGED)
class QwenForMultiLabelClassification(nn.Module):
    """Backbone model topped with a single linear multi-label head.

    The head maps the backbone's hidden state at the final sequence position
    to ``num_labels`` independent logits.

    Args:
        base_model: Module exposing ``config.hidden_size`` and returning an
            object with a ``last_hidden_state`` tensor when called with
            ``input_ids`` and ``attention_mask``.
        num_labels: Number of output logits.
    """

    def __init__(self, base_model, num_labels):
        super().__init__()
        self.base_model = base_model
        hidden_dim = base_model.config.hidden_size
        self.classifier = nn.Linear(hidden_dim, num_labels)

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Run the backbone and classify from the last position.

        Args:
            input_ids: Token ids forwarded to the backbone.
            attention_mask: Optional mask forwarded to the backbone.
            labels: Optional multi-hot targets; cast to float for the loss.

        Returns:
            Dict with ``"loss"`` (BCE-with-logits loss, or ``None`` when no
            labels are given) and ``"logits"``.
        """
        backbone_out = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        # Index -1 is taken regardless of attention_mask.
        # NOTE(review): with right-side padding this position is a pad token;
        # confirm the tokenizer's padding side matches what training used.
        final_position = backbone_out.last_hidden_state[:, -1, :]
        logits = self.classifier(final_position)

        if labels is None:
            return {"loss": None, "logits": logits}

        criterion = nn.BCEWithLogitsLoss()
        return {"loss": criterion(logits, labels.float()), "logits": logits}

# 5. Load base model and LoRA weights
base_model = AutoModel.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True
)

# Load LoRA adapter
base_model = PeftModel.from_pretrained(
    base_model,
    f"{RUN_DIR}/saved_model"
)
print("‚úì Base model and LoRA weights loaded")

# 6. Create full model
model = QwenForMultiLabelClassification(base_model, NUM_LABELS)

# Load classifier head weights.
# map_location="cpu" lets a head that was saved from a GPU run be restored on
# a machine without that device; load_state_dict then copies the values into
# the (CPU-resident) classifier parameters.
classifier_state = torch.load(
    f"{RUN_DIR}/saved_model/classifier_head.pt",
    map_location="cpu"
)
model.classifier.load_state_dict(classifier_state["classifier"])
print("‚úì Classifier head loaded")

# Switch to evaluation mode before any inference is run.
model.eval()
print("‚úì Model ready for inference")

print("\n" + "=" * 60)
print("MODEL LOADED SUCCESSFULLY!")
# Closing rule now matches the 60-character opening rule above
# (previously a single "=" was printed).
print("=" * 60)


Loading model and artifacts...
‚úì Config loaded - 23 genres
‚úì Tokenizer loaded
‚úì MultiLabelBinarizer loaded
‚úì Base model and LoRA weights loaded
‚úì Classifier head loaded
‚úì Model ready for inference

MODEL LOADED SUCCESSFULLY!
=


In [5]:

# ============================================
# INFERENCE FUNCTION
# ============================================

def predict_genres(text, threshold=0.5, top_k=None, max_length=128):
    """
    Predict genres for a given movie overview.

    Args:
        text: Movie overview text (a single string)
        threshold: Minimum probability for a genre to be reported
            (default: 0.5)
        top_k: Keep only the K most probable genres among those that pass
            the threshold (optional)
        max_length: Token length the input is truncated/padded to
            (default: 128)

    Returns:
        List of ``{"genre": ..., "probability": ...}`` dicts sorted by
        descending probability; empty when no genre reaches the threshold.
    """
    # Tokenize input. Padding to a fixed length is kept as-is because the
    # model classifies from sequence position -1, so changing the padding
    # strategy would change which position it reads.
    inputs = tokenizer(
        text,
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt"
    )

    # Feed the tensors on whatever device the model currently lives on.
    device = next(model.parameters()).device

    # Run inference
    with torch.no_grad():
        outputs = model(
            input_ids=inputs["input_ids"].to(device),
            attention_mask=inputs["attention_mask"].to(device)
        )
        logits = outputs["logits"]
        # squeeze(0) removes only the batch axis, so a single-label model
        # still yields a 1-D array; float()/cpu() make the tensor convertible
        # to NumPy regardless of the model's dtype or device.
        probs = torch.sigmoid(logits).squeeze(0).float().cpu().numpy()

    # Keep genres at or above the threshold
    predictions = [
        {"genre": mlb.classes_[i], "probability": float(prob)}
        for i, prob in enumerate(probs)
        if prob >= threshold
    ]

    # Most probable first
    predictions.sort(key=lambda x: x["probability"], reverse=True)

    # Return top K if specified
    if top_k is not None:
        predictions = predictions[:top_k]

    return predictions


In [6]:

# ============================================
# TEST WITH CUSTOM EXAMPLES
# ============================================

print("\n" + "="*60)
print("TESTING WITH CUSTOM MOVIE OVERVIEWS")
print("="*60)


def _run_example(number, overview, threshold=0.3, top_k=5):
    """Print one numbered overview with its predicted genres.

    Replaces the print/predict sequence that was previously copy-pasted for
    every example.

    Args:
        number: Example number shown in the heading.
        overview: Movie overview text passed to predict_genres unchanged.
        threshold: Probability threshold forwarded to predict_genres.
        top_k: Maximum number of genres forwarded to predict_genres.

    Returns:
        The prediction list returned by predict_genres.
    """
    print(f"\nüìΩÔ∏è Example {number}:")
    print(f"Overview: {overview.strip()}")
    print("\nPredicted Genres:")
    predictions = predict_genres(overview, threshold=threshold, top_k=top_k)
    for pred in predictions:
        print(f"  - {pred['genre']}: {pred['probability']:.3f}")
    return predictions


# Example 1: Action movie
text1 = """
A group of elite soldiers must infiltrate a heavily guarded compound 
to rescue hostages before time runs out. Explosions, car chases, and 
intense combat sequences define this high-octane thriller.
"""

# Example 2: Romantic comedy
text2 = """
A quirky bookstore owner accidentally spills coffee on a charming 
businessman, sparking an unlikely romance. Through hilarious 
misunderstandings and heartfelt moments, they discover love in 
the most unexpected places.
"""

# Example 3: Sci-fi horror
text3 = """
On a remote space station, the crew awakens to find themselves 
hunted by an unknown alien entity. As systems fail and crew members 
disappear, they must uncover the terrifying truth before it's too late.
"""

# Run in order so the printed output keeps its original sequence; the
# per-example names stay available to later cells.
results1 = _run_example(1, text1)
results2 = _run_example(2, text2)
results3 = _run_example(3, text3)


TESTING WITH CUSTOM MOVIE OVERVIEWS

üìΩÔ∏è Example 1:
Overview: A group of elite soldiers must infiltrate a heavily guarded compound 
to rescue hostages before time runs out. Explosions, car chases, and 
intense combat sequences define this high-octane thriller.

Predicted Genres:
  - Action: 1.000
  - Thriller: 0.895

üìΩÔ∏è Example 2:
Overview: A quirky bookstore owner accidentally spills coffee on a charming 
businessman, sparking an unlikely romance. Through hilarious 
misunderstandings and heartfelt moments, they discover love in 
the most unexpected places.

Predicted Genres:
  - Comedy: 1.000
  - Romance: 0.977
  - Drama: 0.611

üìΩÔ∏è Example 3:
Overview: On a remote space station, the crew awakens to find themselves 
hunted by an unknown alien entity. As systems fail and crew members 
disappear, they must uncover the terrifying truth before it's too late.

Predicted Genres:
  - Sci-Fi: 0.995
  - Adventure: 0.602
  - Thriller: 0.587
  - Action: 0.551


In [None]:

# ============================================
# INTERACTIVE MODE
# ============================================

print("\n" + "="*60)
print("INTERACTIVE MODE")
print("="*60)
print("Enter your custom movie overview (or 'quit' to exit):\n")

# Read overviews until the user asks to stop.
while True:
    try:
        custom_text = input("Movie Overview: ")
    except (EOFError, KeyboardInterrupt):
        # Closed stdin or an interrupted kernel ends the session cleanly
        # instead of leaving a traceback in the cell output.
        print("\nGoodbye!")
        break

    # Normalise only for command matching; the text handed to the model
    # below is left exactly as typed.
    command = custom_text.strip().lower()

    if command in ('quit', 'exit', 'q'):
        print("Goodbye!")
        break

    if not command:
        print("Please enter a valid overview.\n")
        continue

    print("\nPredicted Genres:")
    results = predict_genres(custom_text, threshold=0.3, top_k=5)

    if not results:
        print("  No genres predicted above threshold.")
    else:
        for pred in results:
            print(f"  - {pred['genre']}: {pred['probability']:.3f}")

    print("\n" + "-"*60 + "\n")


INTERACTIVE MODE
Enter your custom movie overview (or 'quit' to exit):


Predicted Genres:
  - Action: 0.897
  - Comedy: 0.561
  - Family: 0.519
  - Music: 0.375

------------------------------------------------------------

