In [7]:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import pandas as pd


In [78]:
# Maps the classifier head's output index to a human-readable movement label.
# The labels are listed alphabetically, which suggests the ids came from a
# sorted label encoding at training time. NOTE(review): confirm this order
# matches the checkpoint, otherwise every prediction is mislabeled.
ID2LABEL = dict(
    enumerate(["Drastic Fall", "Drastic Rise", "Fall", "Rise", "Stable"])
)

In [79]:
def predict_movement(model_path, input_texts, batch_size=32, max_length=256):
    """Classify texts with a locally saved sequence-classification checkpoint.

    Loads the model and tokenizer from ``model_path`` (local files only, no
    download), moves the model to CUDA when available, and scores the texts
    in mini-batches. Each predicted class id is mapped through the
    module-level ``ID2LABEL`` dict.

    Args:
        model_path: Directory containing a saved Hugging Face checkpoint
            (model weights/config plus tokenizer files).
        input_texts: Iterable of strings to classify. A single ``str`` is
            treated as one input rather than being iterated character by
            character.
        batch_size: Number of texts tokenized and scored per forward pass.
            Must be >= 1.
        max_length: Maximum number of tokens per text; longer inputs are
            truncated.

    Returns:
        List of label strings, one per input text, in input order. An empty
        input returns ``[]`` without loading the model.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
        KeyError: If the model predicts a class id missing from ``ID2LABEL``.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # A bare string is iterable, so wrap it to avoid per-character predictions.
    if isinstance(input_texts, str):
        input_texts = [input_texts]
    texts = list(input_texts)
    if not texts:
        return []

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = AutoModelForSequenceClassification.from_pretrained(model_path, local_files_only=True)
    tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)
    model.eval()
    model.to(device)

    predictions = []
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]
            # padding=True pads to the longest text in this batch, and the
            # returned attention mask marks the padding. Results are expected to
            # match one-at-a-time inference up to floating-point noise; verify
            # for this architecture if exact parity matters.
            inputs = tokenizer(
                batch,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_length,
            ).to(device)
            logits = model(**inputs).logits
            class_ids = torch.argmax(logits, dim=-1).tolist()
            predictions.extend(ID2LABEL[class_id] for class_id in class_ids)

    return predictions

In [115]:
# Saved checkpoint directory, presumably written by the training run.
# NOTE(review): absolute path tied to the original environment; consider
# moving it to a config cell.
saved_model_directory = "/results/checkpoint-10350"

# Each sample joins a bracketed market-context prefix (quarter, VIX, price vs
# 50-day average) to a news headline with " | Headline: ". Presumably this
# mirrors the text format used during fine-tuning; confirm.
new_data = [
    "Context: [Quarter: Q4, VIX: 15.20, Price vs 50-Day Avg: 10.10%] | Headline: Jury: X aquires Y after paying more than $6 billion",
    "Context: [Quarter: Q2, VIX: 28.50, Price vs 50-Day Avg: -7.80%] | Headline: Global: supply chain intensify as major funding stop.",
    "Context: [Quarter: Q2, VIX: 15.13, Price vs 50-Day Avg: -0.05%] | Headline: Marketplace Africa: Welcome to our new look show!",
]

model_predictions = predict_movement(saved_model_directory, new_data)

# One row per input, with the predicted label beside the text it came from.
results_df = pd.DataFrame({"Input Text": new_data}).assign(
    **{"AI Prediction": model_predictions}
)

# to_string() renders every row without column-width truncation, so the long
# headlines stay fully readable.
print(results_df.to_string())

                                                                                                                          Input Text AI Prediction
0    Context: [Quarter: Q4, VIX: 15.20, Price vs 50-Day Avg: 10.10%] | Headline: Jury: X aquires Y after paying more than $6 billion          Rise
1  Context: [Quarter: Q2, VIX: 28.50, Price vs 50-Day Avg: -7.80%] | Headline: Global: supply chain intensify as major funding stop.  Drastic Fall
2      Context: [Quarter: Q2, VIX: 15.13, Price vs 50-Day Avg: -0.05%] | Headline: Marketplace Africa: Welcome to our new look show!          Fall
