# Loading the AUGUST model with trained weights

This cell initializes the AUGUST model and loads pretrained weights from a checkpoint.

Notes and recommendations:
- Replace `"checkpoint_path"` with the actual path to your checkpoint file.
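- If you load checkpoint tensors yourself with `torch.load`, pass `map_location=device` so they are placed on the device selected below rather than on the device they were saved from.
- After loading the weights, move the model to the device and switch it to evaluation mode for inference:
```python
model.to(device)
model.eval()
```
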
In [None]:
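import torch

from august.models.model import AUGUST

# Use the GPU when available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Replace "checkpoint_path" with the path to your trained checkpoint.
model = AUGUST(pretrained="checkpoint_path")
model.to(device)  # move parameters to the selected device
model.eval()  # inference behaviour for layers such as dropout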
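If you ever need to load a raw checkpoint yourself instead of through the `pretrained` argument, the usual PyTorch pattern is `torch.load` with `map_location`, followed by `load_state_dict`. The sketch below assumes the checkpoint is a plain `state_dict`; we do not know whether AUGUST checkpoints use that layout, and the `TinyNet` module and temporary file are hypothetical stand-ins used only to make the example runnable.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for a real model such as AUGUST."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Save a state_dict so there is a checkpoint file to load.
source_net = TinyNet()
ckpt_path = os.path.join(tempfile.mkdtemp(), "tiny.pt")
torch.save(source_net.state_dict(), ckpt_path)

# map_location remaps every stored tensor onto the chosen device,
# so a checkpoint saved on GPU can still be opened on a CPU-only machine.
state_dict = torch.load(ckpt_path, map_location=device)

restored_net = TinyNet()
restored_net.load_state_dict(state_dict)  # strict by default: keys and shapes must match
restored_net.to(device)
restored_net.eval()  # switch dropout/batch-norm layers to inference behaviour
```

Variable names differ from the notebook's `model` on purpose, so running this cell does not overwrite the loaded AUGUST model.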

# Human-in-the-loop Question Answering Sample

In [None]:
hitl_conversations = {
    "features_path": "./august/data/sample_features.pt",
    "conversation": [
            {
                "question": "Is this slide positive or negative for Helicobacter pylori infection?\nPlease format your answer as: 'This slide is {positive/negative} for Helicobacter pylori infection.'.",
                "answer": "This slide is negative for Helicobacter pylori infection.",
                "label": [
                    "negative"
                ],
                "task": [
                    "negative",
                    "positive"
                ]
            },
            {
                "question": "Considering the spectrum of gastric pathology, determine the category of diagnosis represented in this slide: inflammatory disease, benign tumor, dysplasia, or cancer.\nPlease format your answer as: 'The category of diagnosis represented in this slide is {inflammatory disease/benign tumor/dysplasia/cancer}.'.",
                "answer": "The category of diagnosis represented in this slide is cancer.",
                "label": [
                    "cancer"
                ],
                "task": [
                    "inflammatory disease",
                    "benign tumor",
                    "dysplasia",
                    "cancer"
                ]
            },
            {
                "question": "This slide shows malignant features. Determine the type of malignant tumor: carcinoma, malignant lymphoma, or NOS. If it is a carcinoma, also identify the histologic subtype.\nPlease format your answer as:\n* For malignant lymphoma or NOS: 'The malignant tumor is a {malignant lymphoma/NOS}.'\n* For carcinoma: 'The malignant tumor is a carcinoma, {adenocarcinoma/squamous cell carcinoma/neuroendocrine tumor}.'.",
                "answer": "The malignant tumor is a carcinoma, adenocarcinoma.",
                "label": [
                    "carcinoma, adenocarcinoma"
                ],
                "task": [
                    "malignant lymphoma",
                    "carcinoma, adenocarcinoma",
                    "carcinoma, squamous cell carcinoma",
                    "carcinoma, neuroendocrine tumor",
                    "nos"
                ]
            },
            {
                "question": "Given the malignant tumor type carcinoma, adenocarcinoma, specify the relevant detail or additional condition.\nPlease format your answer as: 'For adenocarcinoma, the tumor is {well differentiated/moderately differentiated/poorly cohesive carcinoma}.",
                "answer": "For carcinoma, adenocarcinoma, the tumor is poorly cohesive carcinoma.",
                "label": [
                    "poorly cohesive carcinoma"
                ],
                "task": [
                    "well differentiated",
                    "moderately differentiated",
                    "poorly cohesive carcinoma"
                ]
            }
        ]
}
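Each turn pairs a `question` with a reference `answer`, the gold `label`, and `task`, which appears to list the candidate options for that turn. A small consistency check can catch malformed samples before running inference. This is a hypothetical helper (`validate_conversation` is not part of AUGUST), written under the assumption that every `label` entry must be one of the `task` options; the example paths and strings are placeholders.

```python
from typing import Any

REQUIRED_KEYS = ("question", "answer", "label", "task")


def validate_conversation(sample: dict[str, Any]) -> list[str]:
    """Return a list of problems found in a HITL sample (empty if none)."""
    problems: list[str] = []
    if "features_path" not in sample:
        problems.append("missing 'features_path'")
    for i, turn in enumerate(sample.get("conversation", [])):
        missing = [k for k in REQUIRED_KEYS if k not in turn]
        if missing:
            problems.append(f"turn {i}: missing keys {missing}")
            continue
        # Assumption: every gold label should be one of the candidate options.
        for lab in turn["label"]:
            if lab not in turn["task"]:
                problems.append(f"turn {i}: label {lab!r} not in task options")
    return problems


example = {
    "features_path": "./features.pt",  # placeholder path
    "conversation": [
        {"question": "Q1", "answer": "A1", "label": ["negative"], "task": ["negative", "positive"]},
        {"question": "Q2", "answer": "A2", "label": ["unknown"], "task": ["yes", "no"]},
    ],
}
print(validate_conversation(example))  # ["turn 1: label 'unknown' not in task options"]
```

Calling `validate_conversation(hitl_conversations)` on the sample above should return an empty list, since each of its labels appears among that turn's options.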

### Diagnosis with AUGUST

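In [None]:
# Precomputed slide features; map_location loads them directly onto `device`.
features = torch.load(hitl_conversations["features_path"], map_location=device)
history = ""  # previous questions and reference answers, in order

for turn in hitl_conversations["conversation"]:
    answer = turn["answer"]  # reference answer, printed for comparison
    # turn["label"] (gold label) and turn["task"] (candidate options) are not
    # needed for generation; they are available for scoring.
    # Prepend the dialogue so far so the model sees the earlier turns.
    question = history + turn["question"]

    # inference_mode() already disables gradient tracking, so no_grad() is not needed.
    # bfloat16 autocast is enabled only on CUDA; on CPU the model runs in float32.
    with torch.inference_mode(), torch.autocast(
        device.type, dtype=torch.bfloat16, enabled=device.type == "cuda"
    ):
        slide_embedding, question_embedding = model.get_slide_embedding(features, question)
        g_answer = model(slide_embedding, history, question_embedding)

    print("Question:", question)
    print("Generated Answer:", g_answer)
    print("Reference Answer:", answer)

    # Human-in-the-loop: append the reference answer (not the generated one),
    # so later turns are conditioned on the correct earlier diagnosis.
    history += f"{turn['question']}\n{answer}\n\n"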
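The loop appends each turn's reference `answer`, not the generated one, to `history`, so later questions are conditioned on the correct earlier diagnoses. The sketch below isolates that bookkeeping and adds a simple scorer that maps generated text onto the turn's `task` options. `stub_generate` is a hypothetical placeholder for the AUGUST call (it takes history and question separately, whereas the notebook prefixes history into the question), and whole-word, longest-match option scoring is our assumption, not a documented AUGUST evaluation protocol.

```python
import re
from typing import Callable, Optional


def match_option(generated: str, options: list[str]) -> Optional[str]:
    """Map free text to one candidate option (whole-word, case-insensitive).

    If several options match, the longest wins; returns None if none match.
    Word boundaries stop short options such as "nos" matching inside "diagnosis".
    """
    text = generated.lower()
    hits = [
        opt for opt in options
        if re.search(r"\b" + re.escape(opt.lower()) + r"\b", text)
    ]
    return max(hits, key=len) if hits else None


def run_dialogue(turns: list[dict], generate: Callable[[str, str], str]) -> list[bool]:
    """Replay a conversation, feeding reference answers back as history."""
    history = ""
    correct: list[bool] = []
    for turn in turns:
        generated = generate(history, turn["question"])
        predicted = match_option(generated, turn["task"])
        correct.append(predicted in turn["label"])
        # Human-in-the-loop: the reference answer, not `generated`, enters history.
        history += f"{turn['question']}\n{turn['answer']}\n\n"
    return correct


def stub_generate(history: str, question: str) -> str:
    """Hypothetical stand-in for the model; always answers 'negative'."""
    return "This slide is negative."


turns = [
    {"question": "Positive or negative?", "answer": "This slide is negative.",
     "label": ["negative"], "task": ["negative", "positive"]},
    {"question": "Which category?", "answer": "The category is cancer.",
     "label": ["cancer"], "task": ["inflammatory disease", "cancer"]},
]
print(run_dialogue(turns, stub_generate))  # [True, False]
```

Feeding back reference answers keeps an early mistake from cascading into later turns, which is what makes per-turn accuracy meaningful in this setting.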